Press "Enter" to skip to content

Tensorflow基本概念

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

Tensorflow基本概念

 

1.Tensor

 

Tensorflow张量 ,是Tensorflow中最基础的概念,也是最主要的数据结构

 

2.Variable

 

Tensorflow变量 ,一般用于表示图中的各计算参数,包括矩阵,向量等。它在图中有固定的位置。

 

3.placeholder

 

Tensorflow占位符 ,用于表示输入输出数据的格式,允许传入指定的类型和形状的数据。

 

4.Session

 

Tensorflow会话 ,在Tensorflow中是计算图的具体执行者,与图进行实际的交互。

 

5.Operation

 

Tensorflow操作 ,是Tensorflow图中的节点,它的输入和输出都是 Tensor 。它的操作都是完成各种操作,包括算数操作、矩阵操作、神经网络构建操作等。

 

6.Queue

 

Tensorflow队列 ,也是图中的一个节点,是一种有状态的节点。

 

7.QueueRunner

 

队列管理器 ,通常会使用 多个线程 来读取数据,然后使用 一个线程 来使用数据。使用队列管理器来管理这些读写队列的线程。

 

8.Coordinator

 

使用QueueRunner时,由于入队和出队由各自线程完成,且未进行同步通讯,导致程序无法正常结束的情况。为了实现线程之间的 同步 ,需要使用 Coordinator

 

Tensorflow程序步骤

 

(一)加载训练数据

 

1.生成或导入样本数据集。

 

2.归一化处理。

 

3.划分样本数据集为 训练样本集测试样本集

 

(二)构建训练模型

 

1.初始化超参数

 

2.初始化变量和占位符

 

3.定义模型结构

 

4.定义损失函数

 

(三)进行数据训练

 

1.初始化模型

 

(四)评估和预测

 

1.评估机器学习模型

 

2.调优超参数

 

3.预测结果

 

加载数据

 

在Tensorflow中加载数据的方式一共有三种:预加载数据、填充数据和从文件读取数据。

 

预加载数据

 

在Tensorflow中定义常量或变量来保存所有数据,例如:

 

a = tf.constant([1, 2])
b = tf.constant([3, 4])
x = tf.add(a, b)

 

因为常数会直接存储在数据流图数据结构中,在训练过程中,这个结构体可能会被复制多次,从而导致内存的大量消耗。

 

填充数据

 

将数据填充到任意一个张量中。然后通过会话 run() 函数中的 feed_dict 参数进行获取数据:

 

数据量大 时,填充数据的方式也存在消耗内存的问题。

 

从CVS文件中读取数据

 

要存文件中读取数据, 首先需要使用读取器将数据读取到队列中,然后从队列中获取数据进行处理:

 

1.创建队列

 

2.创建读取器获取数据

 

3.处理数据

 

读取TFRecords数据

 

Tensorflow针对处理 数据量巨大 的应用场景进行了优化,定义了 格式。

 

采用这种读取方式读取数据分为两个步骤:

 

1.把样本数据转换为 TFRecords二进制文件

 

2.读取 TFRecords 格式。

 

存储和加载模型

 

Tensorflow中提供了 tf.train.Saver 类实现训练模型的保存和加载。

 

存储模型

 

在模型的设计和训练的过程中,会消耗大量的时间。为了降低训练过程中意外情况发生造成的不良影响,所以会对训练过程中模型进行定期存储。 (模型复用,节省整体训练时间)

 

saver = tf.train.Saver(max_to_keep, keep_checkpoint_event_n_hours)

 

存储的模型,会生成四个文件:

 

 

加载模型

 

为了保证意外中断的模型能够继续训练以及训练完成的模型加载在其他数据上直接使用,会对模型进行加载使用。

 

加载存储好的模型,包括了两个步骤:

 

1.加载模型:

 

saver = tf.train.import_meta_graph("my_test_model-100.meta")

 

2.加载训练参数:

 

saver.restore(sess, tf.train.latest_checkpoint('./'))

Be First to Comment

发表评论

您的电子邮箱地址不会被公开。 必填项已用*标注