使用TensorFlow Serving快速部署模型

工业产品中TensorFlow使用方法

  1. 用TensorFlow的C++/Java/Nodejs API直接使用保存的TensorFlow模型:类似Caffe,适合做桌面软件。
  2. 直接将使用TensorFlow的Python代码放到Flask等Web程序中,提供Restful接口:实现和调试方便,但效率不太高,不大适合高负荷场景,且没有版本管理、模型热更新等功能。
  3. 将TensorFlow模型托管到TensorFlow Serving中,提供RPC或Restful服务:实现方便,高效,自带版本管理、模型热更新等,很适合大规模线上业务。

参考链接:https://cloud.tencent.com/developer/article/1375668

TensorFlow Serving简介

Tensorflow Serving是Google官方提供的模型部署方式,正确导出模型后,可一分钟完成部署(官方广告)。TF1.8后,Tensorflow Serving支持RESTfull API和grpc的请求方式,模型部署完成后可很方便的利用post请求进行测试。

TensorFlow Serving服务框架

框架分为模型训练、模型上线和服务使用三部分。模型训练与正常的训练过程一致,只是导出时需要按照TF Serving的标准定义输入、输出和签名。模型上线时指定端口号和模型路径后,通过tensorflow_model_server命令启动服务。服务使用可通过grpc和RESTfull方式请求。

模型导出

需指定模型的输入和输出,并在tags中包含”serve”,在实际使用中,TF Serving要求导出模型包含”serve”这个tag。此外,还需要指定默认签名,tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY = “serving_default”,此外tf.saved_model.signature_constants定义了三类签名,分别是:

  • 分类classify
  • 回归regress
  • 预测predict
1
2
3
CLASSIFY_METHOD_NAME = "tensorflow/serving/classify"
PREDICT_METHOD_NAME = "tensorflow/serving/predict"
REGRESS_METHOD_NAME = "tensorflow/serving/regress"

一般而言,用predict就完事了。

1
2
3
4
5
6
7
8
with sess.graph.as_default() as graph:
builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
signature = tf.saved_model.signature_def_utils.predict_signature_def(inputs={'image': in_image},
outputs={'prediction': graph.get_tensor_by_name('final_result:0')},)
builder.add_meta_graph_and_variables(sess=sess,
tags=["serve"],
signature_def_map={'predict':signature, tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:signature})
builder.save()

启动服务

1
tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=模型名 --model_base_path=模型所在路径

请求服务

1
curl -d '{"inputs": [[1.1,1.2,0.8,1.3]]}' -X POST http://localhost:8501/v1/models/模型名:predict

python可以通过post请求,golang可以通过grpc服务请求。