Workspace of LionHeart

Focus on Spark/TensorFlow and other BigData and ML platforms

Tensorflow模型保存与加载方法

导航[-]

本文档介绍如何保存和读取Tensorflow变量和模型。

官方文档可参考:

保存/读取变量

本节介绍如何存取Tensorflow变量。注意Estimators会自动在model_dir中存取变量。

tf.train.Saver 提供了存取模型的接口,其构造函数会在图中增加存取模型的op。

示例化的Saver对象提供方法来运行这些op并,设置checkpoint文件用于保存和恢复模型信息。

Saver关于将模型中定义的所有变量保存,如果你不了解加载模型的graph结构,可以参考后面的保存恢复模型一节。

TensorFlow将变量以二进制形式存储在文件中,保存的信息主要是变量名以及对应的Tensor的值。

保存变量

调用tf.train.Saver()来管理模型变量,示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

恢复变量

同样使用tf.train.Saver来从checkpoint中恢复变量,示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

存取指定的变量

有时可能只需要存取模型graph中的部分变量,可以list或者dict两种形式传给tf.train.Saver()来指定需要存取的变量。

  • list形式,模型的变量列表。
  • dict形式,name为新的保存后的变量名,value为模型中的变量名。

示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
tf.reset_default_graph()
# Create some variables.
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)

# Add ops to save and restore only `v2` using the name "v2"
saver = tf.train.Saver({"v2": v2})

# Use the saver object normally after that.
with tf.Session() as sess:
  # Initialize v1 since the saver will not.
  v1.initializer.run()
  saver.restore(sess, "/tmp/model.ckpt")

  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

注意:

  • 可以创建多个模型Saver对象来保存模型变量,同样的变量可被不同的Saver保存多次,restore操作用于从checkpoint中恢复变量的值。
  • 如果checkpoint中只保存了部分变量,那么恢复后,graph中其他的变量仍需要被初始化。
  • 查看checkpoint中的变量可以使用inspect_checkpoint库,特别是print_tensors_in_checkpoint_file函数。
  • Saver默认使用tf.Variable.name属性作为每个变量的变量名,然而你可以在Saver对象中未变量指定存储在checkpoint中的新名字。

保存/读取模型概述

当你想要保存整个模型(变量、模型graph以及graph的meta信息)时,我们推荐使用SavedModel

SavedModel是一种面向多语言的,可恢复的高度序列化封装的格式。

SavedModel运行上层系统或工具来生产、消费或者转换Tensorflow模型。

Tensorflow提供了多种机制来同SavedModel进行交互,包括tf.saved_model API, Estimator API 以及 CLI方式。

操作SavedModel API

本节聚焦在使用底层Tensorflow API时需要用到的保存或加载SavedModel的API。

构建SavedModel

我们提供了SavedModel builder的python实现。SavedModelBuilder提供保存MetaGraphDef结构的功能。

  • MetaGraphDef是MetaGraph的proto buffer表达形式。
  • MetaGraph是一个数据流图,以及相关的变量、资源和signatures。
  • signature是一个graph的输入与输出的集合。

每个加入到 SavedModel中的MetaGraphDef需要以用户指定的tag标注。tag提供了区分特定MetaGraphDef的方法。通常这些tag会标注MetaGraphDef的功能以及一些可选的硬件相关的信息。

SavedModelBuilder的使用示例代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
export_dir = ...
...
builder = tf.saved_model_builder.SavedModelBuilder(export_dir)
with tf.Session(graph=tf.Graph()) as sess:
  ...
  builder.add_meta_graph_and_variables(sess,
                                       [tag_constants.TRAINING],
                                       signature_def_map=foo_signatures,
                                       assets_collection=foo_assets)
...
# Add a second MetaGraphDef for inference.
with tf.Session(graph=tf.Graph()) as sess:
  ...
  builder.add_meta_graph([tag_constants.SERVING])
...
builder.save()

加载SaveModel

调用python版本的SaveModel loader需要提供一下信息:

  • 保存graph定义和变量的session。
  • 用来标识MetaGraphDef的tag。
  • SavedModel对应的目录位置。

示例代码如下:

1
2
3
4
5
export_dir = ...
...
with tf.Session(graph=tf.Graph()) as sess:
  tf.saved_model.loader.load(sess, [tag_constants.TRAINING], export_dir)
  ...

C++版本的SavedModel loader也提供从指定目录恢复模型的API,并支持指定SessionOptions和RunOptions参数。

同样需要指定taq参数,被加载的SavedModel被SavedModelBundle引用,包含了MetaGraphDef以及加载它的session信息。

示例代码:

1
2
3
4
5
const string export_dir = ...
SavedModelBundle bundle;
...
LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagTrain},
               &bundle);

另外Tensorflow提供了一组MetaGraphDef和SignatureDef相关的常量供用户使用。

MetaGraphDef常量:

SignatureDef常量

在Estimators中使用SaveModel

当训练好Estimator模型后,你可能需要部署预测服务。你可以选择在本机启动一个本地服务,或在云端进行扩展。

部署Estimator训练出的模型需要先把模型导出为SaveModel格式,本节介绍:

  • 如何制定输出节点以及相关的API
  • 使用SavedModel导出模型
  • 发请求给本地模型预测服务

准备Serving输入

训练过程中 input_fn()用于提供数据输入,类似地,预测阶段输入数据由 serving_input_receiver_fn提供。

serving_input_receiver_fn有如下两个功能:

  • 将预测需要输入数据的placeholder添加到graph中。
  • 添加额外的用于将输入数据格式转换为feature Tensors格式的op。

函数返回 tf.estimator.export.ServingInputReceiver对象,封装了placeholders以及feature Tensor。

当编写 serving_input_receiver_fn时,需要提供一个 tf.parse_example的特定parser描述来说明数据解析的方式。

Parser说明是一个dict的形式,包含:

  • tf.FixedLenFeature
  • tf.VarLenFeature
  • tf.SparseFeature

示例代码:

1
2
3
4
5
6
7
8
9
10
11
feature_spec = {'foo': tf.FixedLenFeature(...),
                'bar': tf.VarLenFeature(...)}

def serving_input_receiver_fn():
  """An input receiver that expects a serialized tf.Example."""
  serialized_tf_example = tf.placeholder(dtype=tf.string,
                                         shape=[default_batch_size],
                                         name='input_example_tensor')
  receiver_tensors = {'examples': serialized_tf_example}
  features = tf.parse_example(serialized_tf_example, feature_spec)
  return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

tf.estimator.export.build_parsing_serving_input_receiver_fn 工具函数给出了一个通用实现。

导出模型

调用 tf.estimator.Estimator.export_savedmodel ,提供导出路径以及serving_input_receiver_fn进行模型导出:

1
estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn)

该方法创建一个新的graph并调用serving_input_receiver_fn来获取输入tensor,随后调用Estimator的model_fn() 来产生模型的graph。

最终会创建一个带时间戳的目录(export_dir_base/)并将模型导出为SavedModel。

指定模型输出

通过export_outputs指定,其类型为 tf.estimator.EstimatorSpec,是一个形如 {name: output} 的dict,用于描述预测阶段的输出。

预测输出的值类型必须为 ExportOutput 的某个实现,例如:

  • tf.estimator.export.ClassificationOutput,
  • tf.estimator.export.RegressionOutput
  • tf.estimator.export.PredictOutput.

部署本地预测服务

本地部署预测服务需要使用TensorFlow Serving。

TensorFlow Serving是一个独立的开源项目,功能是加载SavedModel模型并对外提供gRPC服务。

首先安装Tensorflow Serving。

部署服务命令如下,将 $export_dir_base 替换为SavedModel导出的目录。

1
2
bazel build //tensorflow_serving/model_servers:tensorflow_model_server
bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_base_path=$export_dir_base

执行后在9000端口会启动一个gRPC预测服务。

向本地Server发送请求

发送预测请求需要通过PredictionService gRPC API。相关API依赖:

1
2
3
4
5
6
  deps = [
    "//tensorflow_serving/apis:classification_proto_py_pb2",
    "//tensorflow_serving/apis:regression_proto_py_pb2",
    "//tensorflow_serving/apis:predict_proto_py_pb2",
    "//tensorflow_serving/apis:prediction_service_proto_py_pb2"
  ]

python代码中可以像如下示例使用:

1
2
3
4
from tensorflow_serving.apis import classification_pb2
from tensorflow_serving.apis import regression_pb2
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2

请求的数据会以proto buffer的形式发送,发送请求的示例代码:

1
2
3
4
5
6
7
8
9
10
from grpc.beta import implementations

channel = implementations.insecure_channel(host, int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

request = classification_pb2.ClassificationRequest()
example = request.input.example_list.examples.add()
example.features.feature['x'].float_list.value.extend(image[0].astype(float))

result = stub.Classify(request, 10.0)  # 10 secs timeout

本例中的返回值是一个 ClassificationResponse格式的PB数据。

SavedModel目录结构

Tensorflow为每个SavedModel组织目录结构形式如下:

1
2
3
4
5
6
assets/
assets.extra/
variables/
    variables.data-?????-of-?????
    variables.index
saved_model.pb|saved_model.pbtxt

说明如下:

  • assets:是一个子目录,包含了一些外部文件,例如词表等,这些资源文件会被特定的MetaGraphDef读取使用。
  • assets.extra:是一个子目录,用于上层应用或者用户添加一些自己的资源文件,但不会被模型的graph加载。该目录不由SavedModel管理。
  • variables:是一个子目录,用于存储tf.train.Saver的输出。
  • saved_model.pb/saved_model.pbtxt SavedModel的Proto Buffer描述。包含了MetaGraphDef的proto buffer形式的定义。

一个单独的SavedModel可以表达多个graph,SavedModel的多个graph共享一组checkpoint(变量和资源文件)。

组织形式如下图所示: