Press "Enter" to skip to content

TensorFlow中Tensor的shape概念与tf ops:tf.reshape

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

田海立@CSDN 2020-10-17

 

图解NCHW与NHWC数据格式 》中从逻辑表达和物理存储角度用图的方式讲述了NHWC与NCHW两种数据格式,数据shape是可以改变的,本文介绍TensorFlow里Tensor的Shape概念,并用图示和程序阐述了reshape运算。

 

一、TensorFlow中Tensor的Shape

 

TensorFlow中的数据都是由Tensor来表示,Shape相关有下列一些概念:

Rank:维数
Dimension:表达每一维长度
Size:所有的Dimension数值相乘,也就是Tensor里数据元素的尺寸了

rank为0/1/2的典型Tensor如下图所示:

 

 

Tensor rank为3时,数据表达为:

 

 

二、Tensor的逻辑表达与物理存储

 

如《 图解NCHW与NHWC数据格式 》中所述,数据可以从逻辑上和物理排布上去理解。而本文第一节中你可以仍从逻辑上去理解,还未牵涉到物理存储数据排布。

 

三维以下的比较容易理解,各个ML框架之间也没大的区别,对于三维(及以上)Tensor的排布就很不同了,这里着重介绍3-D。

 

我们已经知道TensorFlow的Tensor缺省是NHWC的,对于上面的shape(3, 2, 5)的Tensor【n为1】,在TensorFlow中应该是这样的:

 

 

如果数据值按顺序排布如下,

 

[[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9]],
       [[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],
       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]]]

 

那幺对应上面三维立方体的摆放应该如下:

 

 

三、tf.reshape()

 

reshape原型如下:

 

tf.reshape(
    tensor, shape, name=None
)

 

3.1 tf.reshape()的不改变性

 

 

    1. tf.reshape()运算不改变数据的物理摆布 ,也就是说一个Tensor reshape到别的shape只是逻辑上shape的改变,存储的数据不会改变;

 

    1. tf.reshape()也就不会改变Tensor的size。 指定的新的shape的size如果与原Tensor的size不一致就会报错。比如上面Shape(3, 2, 5)的Tensor就没法reshape成7x?。

 

 

有了上面两个原则,tf.reshape()运算就很容易理解了,物理存储不变,就看rank以及各个dimension怎幺取了。

 

3.2 tf.reshape() 图示

 

比如,上面Tensor有30个数:从0~29顺序存储。可以存储为(3, 2, 5)【上面介绍过的3-D】,也可以存储为2D的(3, 10)或(6, 5),等。

 

 

3.3 程序实现如下:

 

TF2.0以后的版本上,直接可以执行,而不用还要在session下执行。当然前提是已经

 

import tensorflow as tf

 

1. 30个数的数据

 

>>> t = tf.range(30)
>>> t
<tf.Tensor: shape=(30,), dtype=int32, numpy=
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=int32)>
>>>

 

2. shape(3, 2, 5)

 

>>> t = tf.reshape(t, [3,2,5])
>>> t
<tf.Tensor: shape=(3, 2, 5), dtype=int32, numpy=
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9]],
       [[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],
       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]]], dtype=int32)>
>>>

 

3. shape(3, 10)

 

>>> t = tf.reshape(t, [3, 10])
>>> t
<tf.Tensor: shape=(3, 10), dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]], dtype=int32)>
>>>

 

4. shape(6, 5)

 

>>> t = tf.reshape(t, [6, 5])
>>> t
<tf.Tensor: shape=(6, 5), dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24],
       [25, 26, 27, 28, 29]], dtype=int32)>
>>>

 

四、小结

 

本文介绍了TensorFlow里Tensor的Shape概念,并用图示和实际程序解释了reshape的变化。

Be First to Comment

发表评论

电子邮件地址不会被公开。 必填项已用*标注