Tensorflow是一个用来解决数学计算的计算框架,其实不管是CNN、RNN或者是各种变种神经网络模型,本质上都是各种数学公式堆叠在一起的网络结构,Tensorflow的亮点就在于采用数据流图的方式将数学公式粘合在一起,数据流图中的node一般都是各种数学操作,而edge则表示操作结果的数据流向。
Graph
先来看一下下面这段很简单的程序:
x = tf.placeholder(tf.float32, [1, 2], 'x')
y = tf.placeholder(tf.float32, [1], 'y')
w = tf.Variable([[1.], [1.]], dtype=tf.float32, name='w')
bias = tf.constant(1., dtype=tf.float32)
with tf.variable_scope("output"):
out = tf.matmul(x, w) + bias
with tf.variable_scope("loss"):
gap = tf.subtract(out, y)
with tf.Session() as sess:
feed_input = np.array([[1., 1.]], dtype='float32')
expect = np.array([1], dtype='float32')
sess.run(tf.global_variables_initializer())
loss = sess.run([gap], feed_dict={
x:feed_input,
y:expect
})
这段程序的核心就是:f(x)=x*w+bias
,然后再计算f(x)与给定y(期望)的差值。
那么对于上面这段tensorflow程序而言,它的Graph是什么样子呢,请看下图:

其实这段程序能够简单说明tensorflow中一般神经网络模型的训练原理,只是这段简单程序中没有涉及激活函数和梯度。
OP
数据流图中node其实在tensorflow中叫作op,也就是一个操作(其实也很好理解),op可以是:
- 占位符
- 变量
- 运算操作(加减乘除等等)
无论是占位符、变量、运算操作,本质上都产生了Tensor在图中流动。
对于上图中的椭圆框框,其实就是一个个op,旁边的英文就是该op的name,注意下图中add和sub,其实这两个op的名称准确来说分别为:
- output/add
- loss/sub
为什么为这样,因为op不光有name,还可以通过score进行限定,output和loss分别就是我在程序中为其指定的scope。
什么是模型文件
对于Tensorflow,模型文件就是数据流图(Graph),实际上就是数学操作、之间的依赖关系,以及涉及的参数值,那么Tensorflow中有哪些模型格式呢?
CKPT
ckpt(checkpoint)
就是tensorflow模型保存其中的一种格式,也是训练时最常用的,它的好处在于保存断点,可以基于断点再训练。
怎么保存一个ckpt:
saver = tf.train.Saver()
....
saver.save(sess, save_path, global_step)
通常,我们没有特殊指定,ckpt保存下来的的模型主要有三种文件:
.meta
保存图结构.index
和.data
保存图中变量的值
其实一旦开始训练,图的结构就固定了,唯一不确定的就是要学习的参数值,也就是Tensorflow中的Variable
,我们不断迭代训练根据loss函数然后求梯度去更新Variable
的值,所以采用ckpt
保存的话,就会把当前图中各个变量的值保存至.index和.data中。
PB(冻结图)
冻结图是什么呢,顾名思义,就是将Graph中涉及到的Variable参数进行固化,即转变成constant,也就是说一旦一个ckpt转成PB之后,就没法进行再进行参数的更新。所以PB模型可以看做是ckpt的release。
另外PB文件的大小相较于ckpt会小很多,至于为什么会小,你会在下一章节找到答案。
生成PB
下面是一段生成PB模型的代码:
'''CKPT转PB'''
def ckpt2pb(self, checkpoint_dir, output_names=None):
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
saver = tf.train.import_meta_graph(checkpoint.model_checkpoint_path + '.meta')
with tf.Session() as sess:
saver.restore(sess, checkpoint.model_checkpoint_path)
if(self.collectGraph):
self.saveNodes(checkpoint_dir, sess.graph)
input_graph_def = sess.graph.as_graph_def()
frozen_graph = convert_variables_to_constants(sess, input_graph_def, output_names)
return frozen_graph
'''保存节点'''
def saveNodes(self, checkpoint_dir, graph):
with open(os.path.join(checkpoint_dir, 'nodes.txt'), 'w') as f:
for node in graph.get_operations():
f.write(str(node.name) + "\n")
'''保存PB'''
def save2pb(self, frozen_graph, saveDir, pbName):
graph_io.write_graph(frozen_graph, saveDir, pbName, as_text=False)
其实上面这段程序的核心是:
convert_variables_to_constants(sess, input_graph_def, output_names)
我们来看一下这个函数对应的几个参数,这里转述一下源码注释(还是怀念Java,算了Python就不吐槽了):
- output_node_names graph中result节点的名称,也就是你需要进行fetch的节点
- variable_names_whitelist 转换成constant的variable节点集合,默认所有variable节点都转换
- variable_names_blacklist 不转换为constant的variable节点集合
这里主要关注一下output_node_names,这个入参取决你在实际场景下,你需要用到什么功能,同时选择哪些输出节点也决定了模型的冗余程度。
在release之后,我们使用模型会关心loss值、会关心accuracy吗?显然不会,我们只会关心infer/predict的结果,更有甚者,比如我们想从一个分类模型中抽取文本向量,那么我们只需要输出层前一隐藏层各个神经元的激活值。
因为Tensorflow中的Graph实际是有向图,所以一旦指定output节点,那么只会保存图中和指定的output节点关联的前向节点,output的后续节点都会在冻结图中丢弃。
Java加载PB
Maven依赖:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.4.0</version>
</dependency>
这边还是通过使用本文最开始那段程序对应模型为例,然后给出部分相关代码:
/**
* 加载模型
*
* @param model
*/
public Session loadModel(File model) throws IOException {
try {
Graph graph = new Graph();
graph.importGraphDef(Files.readAllBytes(model.toPath()));
return new Session(graph);
} catch (Exception e) {
throw new IOException("模型转Graph失败!", e);
}
}
public float[] infer(Session sess, float[] value, long[] shape) {
if (sess == null) {
throw new RuntimeException("Graph还未加载!");
}
Tensor<Float> input = Tensor.create(shape, FloatBuffer.wrap(value));
Tensor<?> out =
sess.runner().feed("x",input).fetch("output/add").run().get(0);
FloatBuffer buffer = FloatBuffer.allocate((int) out.shape()[0]);
out.writeTo(buffer);
return buffer.array();
}
That's all!