Press "Enter" to skip to content

TensorFlow模型的签名推荐与快速上线

TensorFlow模型的签名推荐与快速上线

简介

往期文章 我们给你推荐一种TensorFlow模型格式 介绍过, TensorFlow官方推荐SavedModel格式作为在线服务的模型文件格式。近期TensorFlow SavedModel模块又推出了simple_save接口,简化了模型签名的构建和模型导出的成本,这期就结合 simple_tensorflow_serving 来做模型签名推荐以及快速上线相关的介绍。

新旧接口

回顾一下过去导出SavedModel的函数接口,由于一个模型可以有多signature,每个签名可以有对应的method name,因此我们需要引入saved_model_builder、signature_constants、signature_def_utils、tag_constants等变量,复制粘贴代码较多。

而新增的simple_save()接口则简化很多,默认的签名就是DEFAULT_SERVING _SIGNATURE_DEF_KEY,默认的method就是PREDICT_METHOD_NAME,除了提供默认值用户也不需要调用utils.build_tensor_info()来封装op了,简化代码如下。

可以看出,使用新的接口彻底解决了每次导出模型都要Ctrl-c/Ctrl-v的烦恼,接下来就介绍几个快速上线的模型。

Simplest模型上线

Simplest模型是我们用于TensorFlow Session性能测试的模型,最简化的模型签名就是一个input op和一个output op,而且不混杂add/minus/multiple/divide/convolution等kernel的实现,理论上就可以测出TensorFlow本身与模型无关的性能,代码地址 tobegit3hub/tensorflow_examples 。

import tensorflow as tf

input_keys_placeholder = tf.placeholder(tf.int32, shape=[None, 1])
output_keys = tf.identity(input_keys_placeholder)

session = tf.Session()
tf.saved_model.simple_save(session, "./model/1", inputs={"keys": input_keys_placeholder}, outputs={"keys": output_keys})

可以看出,简化后的代码只有5行,只需要import tensorflow即可,两个简单的op,以及创建Session和导出模型。这个模型可以用simple_tensorflow_serving上线。

simple_tensorflow_serving --model_base_path="./model"

模型上线后也可以预估,例如构造一个请求的JSON,在浏览器、命令行或者任意编程语言实现的HTTP客户端请求即可。

Linear模型上线

对于其他机器学习模型,模型导出方法也是类似的,首先是定义训练的Graph,然后指定模型签名的op,最后是调用simple_save接口来导出模型文件。这里以Linear模型为例,主要是训练数据可以在内存构造不需要依赖外网下载或者本地数据文件,完整代码在 tobegit3hub/tensorflow_examples 。

import numpy as np
import tensorflow as tf

# Prepare train data
train_X = np.linspace(-1, 1, 100)
train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.33 + 10

# Define the model
X = tf.placeholder(tf.float32, shape=[2])
Y = tf.placeholder(tf.float32, shape=[2])
w = tf.Variable(0.0, name="weight")
b = tf.Variable(0.0, name="bias")
predict = X * w + b

loss = tf.square(Y - predict)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

# Create session to run
with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())

  epoch = 1
  for i in range(10):
    for (x, y) in zip(train_X, train_Y):
      _, w_value, b_value = sess.run([train_op, w, b], feed_dict={X: [x], Y: [y]})
    print("Epoch: {}, w: {}, b: {}".format(epoch, w_value, b_value))
    epoch += 1

  export_dir = "./model/1/"
  print("Try to export the model in {}".format(export_dir))
  tf.saved_model.simple_save(sess, export_dir, inputs={"x": X}, outputs={"y": predict})

可以看出代码也不复杂,通过定义一个X * W + b的op来进行模型的预估,以及后续模型的预估,输入为x,正好训练的Graph以及预估的Graph是重合的无序额外定义op。这个模型上线方式是一样的,这里顺便介绍simple_tensorflow_serving的code gen的功能,我们可以在浏览器选择想要生成的编程语言就可以自动生成客户端代码。

simple_tensorflow_serving --model_base_path="./model"

除此之外,在前端还有一个JSON Inference的功能,使用code gen生成的JSON请求示例,我们在前端就可以做深度学习模型的在线预估了。当然,这个请求数据是根据TensorFlow SavedModel的Signature来生成的,几乎所有的模型都可以这种方式自动生成请求数据以及客户端代码,Inference结果或错误信息都会在浏览器中展示。

图像模型上线

前面已经介绍过通过TensorFlow模型的上线了,对于SavedModel模型来说,所有输入都是Tensor,所谓”tensor in, tensor out”,这是通用Serving + 任意Model的基础,但在绝大部分CV的场景下,模型的数据都是图片文件。

针对图像模型签名的优化,首先参考Google官方TensorFlow Serving的用法,例如MNIST模型的输入是[None, 28*28]也是可以支持的,但要求客户端自己把JPG、PNG等图片文件转成28*28的数组。然后我们参考AWS的mxnet model server,用户不经可以通过HTTP + application/json的方式请求,还可以通过form-data来做预估,这样用户在浏览器上传的图片文件就可以直接请求到预估服务而不需要额外的转化。最终我们采用的是base64 + decode op + reshape op的方案,下面以MNIST模型为例。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

def inference(input):
  weights = tf.get_variable(
      "weights", [784, 10], initializer=tf.random_normal_initializer())
  bias = tf.get_variable(
      "bias", [10], initializer=tf.random_normal_initializer())
  logits = tf.matmul(input, weights) + bias

  return logits

def main():
  mnist = input_data.read_data_sets("./input_data")

  x = tf.placeholder(tf.float32, [None, 784])
  logits = inference(x)
  y_ = tf.placeholder(tf.int64, [None])
  cross_entropy = tf.losses.sparse_softmax_cross_entropy(
      labels=y_, logits=logits)
  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

  init_op = tf.global_variables_initializer()

  # Define op for model signature
  tf.get_variable_scope().reuse_variables()

  model_base64_placeholder = tf.placeholder(
      shape=[None], dtype=tf.string, name="model_input_b64_images")
  model_base64_string = tf.decode_base64(model_base64_placeholder)
  model_base64_input = tf.map_fn(lambda x: tf.image.resize_images(tf.image.decode_jpeg(x, channels=1), [28, 28]), model_base64_string, dtype=tf.float32)
  model_base64_reshape_input = tf.reshape(model_base64_input, [-1, 28 * 28])
  model_logits = inference(model_base64_reshape_input)
  model_predict_softmax = tf.nn.softmax(model_logits)
  model_predict = tf.argmax(model_predict_softmax, 1)

  with tf.Session() as sess:
    sess.run(init_op)

    for i in range(938):
      batch_xs, batch_ys = mnist.train.next_batch(64)
      sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # Export image model
    export_dir = "./model/1"
    print("Try to export the model in {}".format(export_dir))
    tf.saved_model.simple_save(
        sess,
        export_dir,
        inputs={"images": model_base64_placeholder},
        outputs={
            "predict": model_predict,
            "probability": model_predict_softmax
        })

MNIST的代码会稍微复杂一些,首先模型运行和模型签名用的op不同,因为我们希望预估的时候允许用户上传图片文件而不是训练时用的Tensor,其次是要复用模型的权重用了get_variable()以及reuse_variables(),然后是在模型签名的输出op上做了各种tf.decode_base64、tf.reshape、tf.argmax等操作,这部分逻辑与训练无关却需要加到模型训练的脚本中方便导出模型。

通过上面的代码,我们就可以上线一个图像模型,可以接收以form-data形式传输的图片文件,图片文件的文件类型以及长、宽、通道数都没有要求,进入模型后会被统一处理成模型输入的shape。而在simple_tensorflow_serving中,我们可以用到Image Inference功能,直接在浏览器页面上传图片文件,后台不需要额外处理直接请求TensorFlow SavedModel进行模型预测,结果返回到浏览器前端,这个步骤对于任意的CV场景以及CV模型都是通用的。

总结

最后总结下,本文只是介绍SavedModel模型签名的几种用法,并没有创造新的使用接口,但通过约定图像模型的输入字段以及TensorFlow内置op的预处理就可以实现更多“高级”的Inference接口。

大家在导出TensorFlow SavedModel时,根据预估接口在定义inputs和outputs,即使输入op或输出op与训练的Graph不同也没关系,可以新增op并且通过reuse_variables()等方式来共享权重实现更灵活的接口。而且一个SavedModel可以有多个Signature,可以导出多个模型签名来保证接口的多样性,也可以使用本文介绍的simple_save()接口简单地导出模型,或者使用simple_tensorflow_serving来快速上线和验证模型。

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注