tensorflow API使用记录
文章目录
0. 常用函数
- tf.train.list_variables(ckpt_dir_or_file) Returns list of all variables in the checkpoint
1. tf.nn.bias_add
tf.nn.bias_add(value, bias, name = None),把bias添加到value上。其中bias必须为一维的,若value的维度大于1则为广播相加。具体来说,value最后一维的维度必须要和bias的维度一致才可以。
import tensorflow as tf a=tf.constant([[[1,1],[2,2],[3,3]]],dtype=tf.float32) b=tf.constant([1,-1],dtype=tf.float32) c=tf.constant([1],dtype=tf.float32) with tf.Session() as sess: print('bias_add:') print(a.shape) print(b.shape) print(sess.run(tf.nn.bias_add(a, b))) #执行下面语句错误 #print(sess.run(tf.nn.bias_add(a, c))) 2. tf.tensordot
tf.tensordot(a, b, axis),axis=0的时候,数据的维度就是拼接,如下所示:
a = tf.constant([1, 2, 1, 1, 1, 1, 1, 1, 1,1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], shape=[2,3,4]) print(a.shape) b = tf.constant([1,1,1,1,5,6,7,8,9,10,11,12], shape=[4,3]) print(b.shape) c = tf.tensordot(a, b, axes=0) print(c.shape) with tf.Session(): print(a.eval()) print(b.eval()) print(c.eval()) 相关shape的结果如下所示:
(2, 3, 4) (4, 3) (2, 3, 4, 4, 3) axis = 1的时候,是矩阵的乘法,这里的乘法是一种乘法的扩展版本。也就是广播+矩阵乘法的形式。如果不加广播,这样写就是错的,如下所示:
a = tf.constant([1, 2, 1, 1, 1, 1, 1, 1, 1,1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], shape=[2,3,4]) print(a.shape) b = tf.constant([1,1,1,1,5,6,7,8,9,10,11,12], shape=[4,3]) print(b.shape) c = tf.matmul(a, b) #该行会报错,因为多维张量的乘法,前面的维度必须相同,然后是最后两维对应的矩阵相乘 print(c.shape) with tf.Session(): print(a.eval()) print(b.eval()) print(c.eval()) 而使用tensordot就不会报错,具体代码如下所示:
a = tf.constant([1, 2, 1, 1, 1, 1, 1, 1, 1,1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], shape=[2,3,4]) print(a.shape) b = tf.constant([1,1,1,1,5,6,7,8,9,10,11,12], shape=[4,3]) print(b.shape) c = tf.tensordot(a, b, axes = 1) print(c.shape) with tf.Session(): print(a.eval()) print(b.eval()) print(c.eval()) 相关shape的结果如下所示:
(2, 3, 4) (4, 3) (2, 3, 3) 假设三维数据和一维数据进行tensordot,正确的写法之一是
a = tf.constant([1, 2, 1, 1, 1, 1, 1, 1, 1,1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], shape=[3, 2, 4]) print(a.shape) b = tf.constant([1, 1, 1, 1], shape=[4]) print(b.shape) c = tf.tensordot(a, b, axes=1) print(c.shape) with tf.Session(): print(a.eval()) print(b.eval()) print(c.eval()) 这种情况下会对b进行扩展,结果就是[4,1],然后a和b进行matmul,结果为[3,2,1],再进行sequeeze最后一维,结果是[3,2]。其中这里的矩阵表示的都是shape。
可参考链接:https://www.machenxiao.com/blog/tensordot
3. 交叉熵的不同API
- tf.keras.losses.BinaryCrossentropy(y_true, y_pred, from_logits=False, label_smoothing=0)
- categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1),需要注意的是y_true是one-hot表示,而y_pred表示了每种类别对应的概率。
- sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1),sparse相比于上式而言,y_true的输入是整数值,而不是one-hot表示。
参考链接:https://zhuanlan.zhihu.com/p/112314557