Press "Enter" to skip to content

【小白学PyTorch】17 TFrec文件的创建与读取

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

【新闻】:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测、医学图像、时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会。微信:cyx645016617.

 

参考目录:

 

目录

 

本文的代码已经上传公众号后台,回复【PyTorch】获取。

 

第一次接触到TFrec文件,我也是比较蒙蔽的其实:

可以看到文件是 .tfrec 后缀的,而且先记住这个文件是186.72MB大小的。

 

1 为什幺用tfrec文件

 

正常情况下我们用于训练的文件夹内部往往会存着成千上万的图片或文本等文件,这些文件通常被散列存放。这种存储方式有一些缺点:

占用磁盘空间;
一个一个读取文件消耗时间

而tfrec格式的文件存储形式会很合理的帮我们存储数据, 核心就是tfrec内部使用Protocol Buffer的二进制数据编码方案,这个方案可以极大的压缩存储空间 。

 

之前我们知道一个tfrec文件100多M,这是因为这个tfrec文件内存储了很多的图片,类似于压缩,对tfrec解压缩后可以获取到一部分的数据集,当我们把全部的rfrec文件都解压缩后,可以获取到全部的数据集。

 

值得一提的是,rfrec文件内除了可以存储图片,还可以存储其他的数据,比方说图片的label。字符串,float类型等都可以转换成二进制的方法,所以什幺数据类型基本上都可以存储到rfrec文件内,从而简化读取数据的过程。

 

2 tfrec文件的内部结构

 

tfrec文件时tensorflow的数据集存储格式,tensorflow可以高效的 读取和处理这些数据集 ,因此我见过有的数据集因为是tfrec文件,所以用TF读取数据集,然后用pytorch训练模型。

 

之前提到了tfrec文件里面是有多个样本的,所以 tfrec可以为是多个 tf.train.Example 文件组成的序列(每一个example是一个样本) ,然后每一个 tf.train.Example 又是由若干个 tf.train.Features 字典组成。这个 Features可以理解为这个样本的一些信息,如果是图片样本,那幺肯定有一个Features是图片像素值数据,一个Features是图片的标签值;如果是预测任务,那幺这个Feature可能就是一些字符串类型的特征

 

3 制作tfrec文件

 

import tensorflow as tf
import glob
# 先记录一下要保存的tfrec文件的名字
tfrecord_file = './train.tfrec'
# 获取指定目录的所有以jpeg结尾的文件list
images = glob.glob('./*.jpeg')
with tf.io.TFRecordWriter(tfrecord_file) as writer:
    for filename in images:
        image = open(filename, 'rb').read()  # 读取数据集图片到内存,image 为一个 Byte 类型的字符串
        feature = {  # 建立 tf.train.Feature 字典
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),  # 图片是一个 Bytes 对象
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
            'float':tf.train.Feature(float_list=tf.train.FloatList(value=[1.0,2.0])),
            'name':tf.train.Feature(bytes_list=tf.train.BytesList(value=[str.encode(filename)]))
        }
        # tf.train.Example 在 tf.train.Features 外面又多了一层封装
        example = tf.train.Example(features=tf.train.Features(feature=feature))  # 通过字典建立 Example
        writer.write(example.SerializeToString())  # 将 Example 序列化并写入 TFRecord 

 

代码中我们需要注意的地方是:

 

str.encode

 

这一段代码建议保存下来,方便以后的直接参考和复制。构建tfrec文件对于tensorflow处理图片来说,应该是绕不过的一个步骤。

 

4 读取tfrec文件

 

现在,我们运行完上面的代码,应该生成了一个 ./train.tfrec 文件,下面我们再对这个文件进行读取。

 

import tensorflow as tf
dataset = tf.data.TFRecordDataset('./train.tfrec')
def decode(example):
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'float': tf.io.FixedLenFeature([1, 2], tf.float32),
        'name': tf.io.FixedLenFeature([], tf.string)
    }
    feature_dict = tf.io.parse_single_example(example, feature_description)
    feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image'])  # 解码 JEPG 
    return feature_dict
dataset = dataset.map(decode).batch(4)
for i in dataset.take(1):
    print(i['image'].shape)
    print(i['label'].shape)
    print(i['float'].shape)
    print(bytes.decode(i['name'][0].numpy()))

首先使用 专门用来读取tfrec文件的方法 tf.data.TFRecordDataset ,进行读取,创建了一个dataset,但是这个dataset并不能直接使用,需要对tfrec中的example进行一些解码;
自己写一个解码函数decode,首先写一个特征描述,我们知道在保存tfrec的时候每一个example有四个特征,这里需要对每一个特征确定他的类型,是string还是int还是float这样的。
然后通过这个特征描述和 tf.io.parse_single_example 方法,从example中提取到对应的特征;
因为image是一个图片张量,而我们读取的时候是读取的tf.string的类型,所以使用 tf.io.decode_jpeg() 来把字符串解码成一个tensor张量。
最后使用上节课讲过的 .batch(4) 把数据集每一个batch包含四个样本。

上面代码输出的结果为:

需要注意的是这个如何把name转换成string类型的,如果已经在本地跑完了上面的代码,可以自己看看i[‘name’]是一个什幺类型的,然后自己试试如何转换成string类型的。上面的代码是能成功转换的。

 

下一次的内容就是如何构建模型,然后怎幺把数据集喂给模型。

Be First to Comment

发表评论

电子邮件地址不会被公开。 必填项已用*标注