Press "Enter" to skip to content

tensorflow 三种模型:ckpt、pb、pb-savemodel

本站内容均来自兴趣收集,如不慎侵害了您的相关权益,请留言告知,我们将尽快删除.谢谢.

1、CKPT

 

目录结构

 

checkpoint:

 

model.ckpt-1000.index

 

    model.ckpt-1000.data-00000-of-00001
    model.ckpt-1000.meta

 

特点:

 

首先这种模型文件是依赖 TensorFlow 的,只能在其框架下使用;

 

数据和图是分开的

 

这种在训练的时候用的比较多。

 

代码:就省略了

 

2、pb模型-只有模型

 

这种方式只保存了模型的图结构,可以在保留隐私的前提下公布到网上。

 

感觉一些水的论文会用这种方式。

 

代码:

 

thanks: https://www.jianshu.com/p/9221fbf52c55

 

·

 

  1 import os
  2 import tensorflow as tf
  3 from tensorflow.python.saved_model import builder as saved_model_builder
  4 from tensorflow.python.saved_model import (signature_constants, signature_def_utils, tag_constants, utils)
  5 
  6 class model():
  7     def __init__(self):
  8         self.a = tf.placeholder(tf.float32, [None])
  9         self.w = tf.Variable(tf.constant(2.0, shape=[1]), name="w")
 10         b = tf.Variable(tf.constant(0.5, shape=[1]), name="b")
 11         self.y = self.a * self.w + b
 12 
 13 #模型保存为ckpt
 14 def save_model():
 15     graph1 = tf.Graph()
 16     with graph1.as_default():
 17         m = model()
 18     with tf.Session(graph=graph1) as session:
 19         session.run(tf.global_variables_initializer())
 20         update = tf.assign(m.w, [10])
 21         session.run(update)
 22         predict_y = session.run(m.y,feed_dict={m.a:[3.0]})
 23         print(predict_y)
 24 
 25         saver = tf.train.Saver()
 26         saver.save(session,"model_pb/model.ckpt")
 27 
 28 
 29 #保存为pb模型
 30 def export_model(session, m):
 31 
 32 
 33    #只需要修改这一段,定义输入输出,其他保持默认即可
 34     model_signature = signature_def_utils.build_signature_def(
 35         inputs={"input": utils.build_tensor_info(m.a)},
 36         outputs={
 37             "output": utils.build_tensor_info(m.y)},
 38 
 39         method_name=signature_constants.PREDICT_METHOD_NAME)
 40 
 41     export_path = "pb_model/1"
 42     if os.path.exists(export_path):
 43         os.system("rm -rf "+ export_path)
 44     print("Export the model to {}".format(export_path))
 45 
 46     try:
 47         legacy_init_op = tf.group(
 48             tf.tables_initializer(), name='legacy_init_op')
 49         builder = saved_model_builder.SavedModelBuilder(export_path)
 50         builder.add_meta_graph_and_variables(
 51             session, [tag_constants.SERVING],
 52             clear_devices=True,
 53             signature_def_map={
 54                 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
 55                     model_signature,
 56             },
 57             legacy_init_op=legacy_init_op)
 58 
 59         builder.save()
 60     except Exception as e:
 61         print("Fail to export saved model, exception: {}".format(e))
 62 
 63 #加载pb模型
 64 def load_pb():
 65     session = tf.Session(graph=tf.Graph())
 66     model_file_path = "pb_model/1"
 67     meta_graph = tf.saved_model.loader.load(session, [tf.saved_model.tag_constants.SERVING], model_file_path)
 68 
 69     model_graph_signature = list(meta_graph.signature_def.items())[0][1]
 70     output_tensor_names = []
 71     output_op_names = []
 72     for output_item in model_graph_signature.outputs.items():
 73         output_op_name = output_item[0]
 74         output_op_names.append(output_op_name)
 75         output_tensor_name = output_item[1].name
 76         output_tensor_names.append(output_tensor_name)
 77     print("load model finish!")
 78     sentences = {}
 79     # 测试pb模型
 80     for test_x in [[1],[2],[3],[4],[5]]:
 81         sentences["input"] = test_x
 82         feed_dict_map = {}
 83         for input_item in model_graph_signature.inputs.items():
 84             input_op_name = input_item[0]
 85             input_tensor_name = input_item[1].name
 86             feed_dict_map[input_tensor_name] = sentences[input_op_name]
 87         predict_y = session.run(output_tensor_names, feed_dict=feed_dict_map)
 88         print("predict pb y:",predict_y)
 89 
 90 if __name__ == "__main__":
 91 
 92     save_model()
 93 
 94     graph2 = tf.Graph()
 95     with graph2.as_default():
 96         m = model()
 97         saver = tf.train.Saver()
 98     with tf.Session(graph=graph2) as session:
 99         saver.restore(session, "model_pb/model.ckpt") #加载ckpt模型
100         export_model(session, m)
101 
102     load_pb()

 

3、pb模型-Saved model

 

这是一种简单格式pb模型保存方式

 

目录结构

 

└── 1
    ├── saved_model.pb
    └── variables
        ├── variables.data-00000-of-00001
        └── variables.index

 

特点:

 

对于训练好的模型,我们都是用来进行使用的,也就是进行inference。

 

这个时候模型就不再变化了。这种方式就将变量的权重变成了一个常量。

 

这种方式模型会变小

 

在一些嵌入式的、用C或者C++的系统中,我们也常用.pb格式。

 

代码:

 

thanks to  https://www.jianshu.com/p/9221fbf52c55

 

  1 import os
  2 import tensorflow as tf
  3 from tensorflow.python.saved_model import builder as saved_model_builder
  4 from tensorflow.python.saved_model import (signature_constants, signature_def_utils, tag_constants, utils)
  5 
  6 class model():
  7     def __init__(self):
  8         self.a = tf.placeholder(tf.float32, [None])
  9         self.w = tf.Variable(tf.constant(2.0, shape=[1]), name="w")
 10         b = tf.Variable(tf.constant(0.5, shape=[1]), name="b")
 11         self.y = self.a * self.w + b
 12 
 13 #模型保存为ckpt
 14 def save_model():
 15     graph1 = tf.Graph()
 16     with graph1.as_default():
 17         m = model()
 18     with tf.Session(graph=graph1) as session:
 19         session.run(tf.global_variables_initializer())
 20         update = tf.assign(m.w, [10])
 21         session.run(update)
 22         predict_y = session.run(m.y,feed_dict={m.a:[3.0]})
 23         print(predict_y)
 24 
 25         saver = tf.train.Saver()
 26         saver.save(session,"model_pb/model.ckpt")
 27 
 28 
 29 #保存为pb模型
 30 def export_model(session, m):
 31 
 32 
 33    #只需要修改这一段,定义输入输出,其他保持默认即可
 34     model_signature = signature_def_utils.build_signature_def(
 35         inputs={"input": utils.build_tensor_info(m.a)},
 36         outputs={
 37             "output": utils.build_tensor_info(m.y)},
 38 
 39         method_name=signature_constants.PREDICT_METHOD_NAME)
 40 
 41     export_path = "pb_model/1"
 42     if os.path.exists(export_path):
 43         os.system("rm -rf "+ export_path)
 44     print("Export the model to {}".format(export_path))
 45 
 46     try:
 47         legacy_init_op = tf.group(
 48             tf.tables_initializer(), name='legacy_init_op')
 49         builder = saved_model_builder.SavedModelBuilder(export_path)
 50         builder.add_meta_graph_and_variables(
 51             session, [tag_constants.SERVING],
 52             clear_devices=True,
 53             signature_def_map={
 54                 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
 55                     model_signature,
 56             },
 57             legacy_init_op=legacy_init_op)
 58 
 59         builder.save()
 60     except Exception as e:
 61         print("Fail to export saved model, exception: {}".format(e))
 62 
 63 #加载pb模型
 64 def load_pb():
 65     session = tf.Session(graph=tf.Graph())
 66     model_file_path = "pb_model/1"
 67     meta_graph = tf.saved_model.loader.load(session, [tf.saved_model.tag_constants.SERVING], model_file_path)
 68 
 69     model_graph_signature = list(meta_graph.signature_def.items())[0][1]
 70     output_tensor_names = []
 71     output_op_names = []
 72     for output_item in model_graph_signature.outputs.items():
 73         output_op_name = output_item[0]
 74         output_op_names.append(output_op_name)
 75         output_tensor_name = output_item[1].name
 76         output_tensor_names.append(output_tensor_name)
 77     print("load model finish!")
 78     sentences = {}
 79     # 测试pb模型
 80     for test_x in [[1],[2],[3],[4],[5]]:
 81         sentences["input"] = test_x
 82         feed_dict_map = {}
 83         for input_item in model_graph_signature.inputs.items():
 84             input_op_name = input_item[0]
 85             input_tensor_name = input_item[1].name
 86             feed_dict_map[input_tensor_name] = sentences[input_op_name]
 87         predict_y = session.run(output_tensor_names, feed_dict=feed_dict_map)
 88         print("predict pb y:",predict_y)
 89 
 90 if __name__ == "__main__":
 91 
 92     save_model()
 93 
 94     graph2 = tf.Graph()
 95     with graph2.as_default():
 96         m = model()
 97         saver = tf.train.Saver()
 98     with tf.Session(graph=graph2) as session:
 99         saver.restore(session, "model_pb/model.ckpt") #加载ckpt模型
100         export_model(session, m)
101 
102     load_pb()

Be First to Comment

发表评论

电子邮件地址不会被公开。 必填项已用*标注