Press "Enter" to skip to content

Tensorflow加载多个模型方法实践——Graph与Session

 

Tensorflow(1.x版本)是一种符号式编程框架,首先要构造一个图(graph),然后在这个图上做运算,也就是用计算图来构建网络,用会话(Session)来具体执行网络。

 

计算框架,是通过定义placeholder、Variable和OP等构成一张完成计算图Graph;接下来通过新建Session实例启动模型运行,Session实例会分布式执行Graph,输入数据,根据优化算法更新Variable,然后返回执行结果即Tensor实例。

 

计算图graph定义了计算过程与公式,是一些加减乘除等数学运算的组合。它本身不会进行任何计算,也不保存任何中间计算结果。

 

session用来运行一个graph,或者运行graph的一部分。它类似于一个执行者,给graph灌入输入数据,得到输出,并保存中间的计算结果。同时它也给graph分配计算资源(如内存、显卡等)。

 

一个graph可以供多个session使用,而一个session不一定需要使用graph的全部,可以只使用其中的一部分。

 

通常,使用上下文管理器的代码结果如下,其中使用默认graph。

 

with tf.Session() as sess:
            saver = tf.train.import_meta_graph(DB_info.BPNetModel_graph)
            saver.restore(sess,tf.train.latest_checkpoint(DB_info.BPNetModel))
            graph = tf.get_default_graph()
            x = graph.get_tensor_by_name("x:0")
            # 输出预测结果
            y_conv = graph.get_tensor_by_name('y_conv:0')
            keep_prob = graph.get_tensor_by_name("keep_prob:0")    
            ret = sess.run(y_conv, feed_dict={
 x:dtest,keep_prob:1.0})
            y = sess.run(tf.argmax(ret,1))  # 用于分类问题,取最大概率

 

Tensorflow加载多个模型方法是在Tensorflow中创建多个Session,每个Session运行一个graph,实践案例代码如下。

 

def ChurnModelWorking(self):
        graph = tf.Graph()                                                      # 定义图1
        with tf.Session(graph = graph) as sess:                                 # Session加载所定义的图
            saver = tf.train.import_meta_graph(DB_info.BPNetModel_graph)        # 加载模型图
            saver.restore(sess,tf.train.latest_checkpoint(DB_info.BPNetModel))  # 恢复模型参数
            
            x = graph.get_tensor_by_name("x:0")                                 # 从图中获取输入定义
            # 输出预测结果
            y_conv = graph.get_tensor_by_name('y_conv:0')                       # 从图中获取输出定义
            keep_prob = graph.get_tensor_by_name("keep_prob:0")    
            ret = sess.run(y_conv, feed_dict={
 x:dtest,keep_prob:1.0})
            y = sess.run(tf.argmax(ret,1))  # 用于分类问题,取最大概率
    def ChurnOtherModelWorking(self):
        graph_other = tf.Graph()                                                # 定义另一个图2
        with tf.Session(graph=graph_other) as sess_other:                       # Session加载所定义的图2
            saver_other = tf.train.import_meta_graph(DB_info.BPNetModelOther_graph)
            saver_other.restore(sess_other,tf.train.latest_checkpoint(DB_info.BPNetModelOther))            
            x = graph_other.get_tensor_by_name("x:0")
            # 输出预测结果
            y_conv = graph_other.get_tensor_by_name('y_conv:0')
            keep_prob = graph_other.get_tensor_by_name("keep_prob:0")    
            ret = sess_other.run(y_conv, feed_dict={
 x:dtest,keep_prob:1.0})
            y = sess_other.run(tf.argmax(ret,1))  # 用于分类问题,取最大概率

 

实践总结

 

通过在Tensorflow中创建多个Session,每个Session运行一个graph,来实现加载多个Tensorflow模型进行组合应用。

 

关于Tensorflow的图

 

graph视角的关系图

 

TensorFlow是一个通过计算图的形式来表述计算的编程系统。其中的Tnesor,代表它的数据结构,而Flow代表它的计算模型。TensorFlow中的每一个计算都是计算图上的一个节点,而节点之间的线描述了计算之间的依赖关系。

 

在TensorFlow程序中,系统会自动维护一个默认的计算图,通过tf.get_default_gragh函数可以获取当前默认的计算图。除了默认的计算图,TensorFlow也支持通过tf.Graph函数来生成新的计算图。不同的计算图上的张量和运算不会共享。

 

计算图举例如下:

 

参考:

 

[1]. Echo. ​Tensorflow Session使用和浅析
. 知乎. 2020.03

 

[2]. Arkenstone.Tensorflow同时加载使用多个模型
. 博客园. 2017.06

 

[3]. 马尔代夫Maldives. tensorflow中的Graph(图)和Session(会话)的关系
. 简书. 2019.07

 

[4]. 老夫叨叨叨.tensorflow: graph
. 简书. 2019.01

 

[5]. HOU_JUN.TensorFlow计算模型—计算图
. 博客园. 2018.04

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注