Press "Enter" to skip to content

Tensorflow 使用笔记:TFRecords

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

Tensorflow 的数据输入现在主要有两种形式:直接使用 Python 和 TFRecords . 在图像的项目中看到比较多的是直接自己实现dataprovider ,在 NLP 项目中见到比较多先做生成TFRecords 然后利用 tf.data.TFRecordDataset 来读取。我习惯 TRFRecords 的方式来实现。主要因为可以把数据清洗和模型处理的过程分开,二者不是混杂在一起。TFRecords 作为中间格式存在,生成什幺样的 TFRecord 完全决定于你对要做的问题的理解,因为这里定义了你将要用到的特征。

 

我们通常有图像或文本这样的原始数据,拿图像分类或文本分类任务来说。我们的输入特征可能是图像的像素矩阵或者文本中词对应的 ID 而分类标签可能是对应标签的Id 或者甚至直接是字符串等等。Tensorflow 把这样的数据抽象成 Example 。 Example 有很多 Feature 这些 Feature 的数据类型主要有三种。TFrecord 中存储的就是 Example 对象对应的二进制数据,确切的说是使用 protobuf 序列化的二进制数据。在读取的使用 Tensorflow 提供的 DataSet API 在对序列化的数据解码的时候可以把想用的特征解码成对应的 Tensor 。简单的抽象和实现流程如下。

 

Example

 

在创建 TFRecords 的过程中需要对Example 的定义比较好的理解。 数据类型抽象成三种:bytes, float, int64 , Feature 的基本组成单元是这三种数据的 list 定义如下:

 

message BytesList {
  repeated bytes value = 1;
}
message FloatList {
  repeated float value = 1 [packed = true];
}
message Int64List {
  repeated int64 value = 1 [packed = true];
}

 

Feature 就是 BytesList, FloatList,Int64List 的封装

 

message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

 

Feature 可以组成Map 状态的 Features :

 

message Features {
  // Map from feature name to feature.
  map<string, Feature> feature = 1;
};

 

还可以组成 FeatureList

 

message FeatureList {
  repeated Feature feature = 1;
};

 

二者结合还可以产生下面类型

 

message FeatureLists {
  // Map from feature name to feature list.
  map<string, FeatureList> feature_list = 1;
};

 

如果对protobuf 的语法有了解的话,这些定义就很明了了。

 

Exmaple 的是 map 型的 Feature 的组合

 

message Example {
  Features features = 1;
};

 

序列状态的 Example

 

message SequenceExample {
  Features context = 1;
  FeatureLists feature_lists = 2;
};

 

了解这些定义之后,我们要做的就是把各种原始数据转成 bytes ,float ,int 类型然后构造成 Feature 然后组成 Example 序列到 文件中就好了 下面是一个完整的例子把 mnist 的数据序列化到 TFRecords

 

#!/usr/bin/env python
#-*- coding:utf-8 -*-
#author: wu.zheng midday.me
import mnist
import cv2
import os
import sys
import  as np
import tensorflow as tf
def _bytes_feature(value):
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def write_tfrecord(data, labels, out_data_path):
    writer = tf.python_io.TFRecordWriter(out_data_path)
    counter = 0
    total_count = len(data)
    for image, label in zip(data, labels):
        counter += 1
        image = np.array(image)
        image = image.reshape((28, 28))
        is_success, image_buffer = cv2.imencode(".jpg", image)
        if not is_success:
            continue
        label_value = [0] * 10
        label_value[label] = 1
        image_feature = _bytes_feature(image_buffer.tostring());
        label_feature = _int64_feature(label_value)
        features = tf.train.Features(feature={"image":image_feature, "label":label_feature})
        example = tf.train.Example(features=features)
        writer.write(example.SerializeToString())
        sys.stdout.write("\r>>Writing to {:s}  {:d}/{:d}".format(out_data_path, counter, total_count))
        sys.stdout.flush()
    writer.close()
    sys.stdout.write("\n")
    sys.stdout.write(">>{:s} write finish. ".format(out_data_path))

def create_mnist_tfrecord(in_data_floder, out_data_floder ):
    meta_data = mnist.MNIST(in_data_floder)
    train_data, train_labels = meta_data.load_training()
    test_data, test_labels = meta_data.load_testing()
    train_tf_record_path = os.path.join(out_data_floder, 'train_mnist.tfrecord')
    test_tf_record_path = os.path.join(out_data_floder, 'test_mnist.tfrecord')
    write_tfrecord(train_data, train_labels, train_tf_record_path)
    write_tfrecord(test_data, test_labels, test_tf_record_path)

if __name__ == "__main__":
    # datasets/mnist 下存放的是解压后的 mnist 数据, 
    in_data_floder = "./datasets/mnist"
    out_data_floder = "./datasets/mnist_tfrecord"
    create_mnist_tfrecord(in_data_floder, out_data_floder)

 

上门例子还有些需要改进的地方,通常这个构建过程相对比较慢,几百万的数据可能会花费一两天的时间,所以需要多线程处理,生成一个 TFRecords 文件可能会很大,也不方便分布式,通常会把生成的文件划分成很多份。

 

Input_fn

 

有了 TFRecords 我们可以实现一个 input_fn 就好了,如果后面我们有新的数据要添加进来继续训练我们的模型,也只需要按照上门的步骤处理成 TFRecords, input_fn 不用做改变。在 input_fn 里面我们可以做数据增强等一些处理

 

在这里有个比较麻烦的是 Example 中定义的 Feature 会有与之对应的 tf.data.Feature. 有 VarLenFeatureSparseFeature , FixedLenFeature , FixedLenSequenceFeature 使用的是后选择合适的 Feature 就好了,他们本质上是对应这不通形态的 Tensor 比如 VarLenFeature 会产生一个 SparseTensor

 

下面是 mnist 数据的 一个 input_fn 的 实现:

 

#!/usr/bin/env python
#-*- coding:utf-8 -*-
#author: wu.zheng midday.me
import tensorflow as tf
def _decode_record(record_proto):
    feature_map = {
            "image": tf.FixedLenFeature((), tf.string),
            'label': tf.VarLenFeature(tf.int64),
            }
    features = tf.parse_single_example(record_proto, features=feature_map)
    image = features['image']
    image = tf.image.decode_jpeg(image, channels=1)
    image =  tf.cast(image, tf.float32)
    paddings = tf.constant([[2, 2], [2, 2], [0,0]])
    image = tf.pad(image, paddings, mode='CONSTANT', constant_values=0 )
    image = image / 255.0
    label = features['label']
    example = {"image": image, "label": label}
    return example

def input_fn(tfrecord_path, batch_size, is_training):
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    if is_training:
        dataset = dataset.repeat().shuffle(buffer_size=10000)
    else:
        dataset = tf.repeat(1)
    dataset = dataset.map(lambda x: _decode_record(x))
    dataset = dataset.batch(batch_size=batch_size)
    return dataset.make_one_shot_iterator()

if __name__ == "__main__":
    tf_record_path = "./datasets/mnist_tfrecord/train_mnist.tfrecord"
    with tf.Session() as sess:
        iterator = input_fn(tf_record_path, 1, True)
        next_batch = iterator.get_next()
        sess.run(tf.global_variables_initializer())
        while True:
            batch = sess.run(next_batch)
            image = batch['image']
            print(image.shape)
            exit(0)

Be First to Comment

发表评论

邮箱地址不会被公开。 必填项已用*标注