TensorFlow有多火?

2015年11月,Google发布深度学习框架TensorFlow并宣布开源。一年多来,TensorFlow一骑绝尘,已然成为深度学习的代名词,将其他深度学习框架Theano、Caffe、Torch甩出了好几个身段。我们看一下StackOverflow上各个框架的提问数来直观地感受一下TensorFlow的火爆:一大波数据挖掘工程师正走在成为掏粪男孩(TF Boy)的路上!

本文的目的是帮助懂一些机器学习或统计学的朋友理解深度学习与TensorFlow的关键概念和用法。

Hello Linear Regression

TensorFlow官网的Get Started教程,演示了如何使用TensorFlow的Python API训练线性回归模型。这个例子特别好:线性回归模型大家都很熟,学这段代码不用纠结于高大上的深度学习模型理论;但是TensorFlow的精髓又都包含在这里面。

#-*- coding=utf-8 -*-

import tensorflow as tf
import numpy as np

# 阶段1:定义计算图
# 创建线性回归输入数据
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

# 定义待学习参数:W为斜率,b为截距
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))

# 定义线性回归模型
y = W * x_data + b

# 定义损失(目标函数)
loss = tf.reduce_mean(tf.square(y - y_data))

# 定义参数求解器,代表一次参数更新操作
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

# 定义变量初始化操作
init = tf.global_variables_initializer()

# 阶段2:执行计算图
sess = tf.Session()
sess.run(init)

for step in range(201):
    # 执行参数更新操作    
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(W), sess.run(b))

基本概念

TensorFlow使用数据流图(Data Flow Graph)来描述一个深度学习任务。图由节点(Node)和边(Edge)组成;每个节点代表一个数学运算(简称op),指向一个节点的边代表该节点的输入,从节点引出来的边代表该节点的输出;输入和输出都是多维数组(Multidimensional Array),数学上又称Tensor。这是TensorFlow名字的由来。

双阶段编程模式

从代码里,我们发现,TensorFlow的代码往往由两个阶段组成:

A construction phase, that assembles a graph, and an execution phase that uses a session to execute ops in the graph.

TensorFlow通过tf.Session类,将两个阶段连接起来,作用是把在Python里定义好的数据流图部署到设备(Devices,如CPU、GPU),并提供具体执行这些op的方法。

为什么要这么设计呢?考虑到Python运行性能较低,我们在执行数值计算的时候,都会尽量使用非Python语言编写的代码,比如使用NumPy这种预编译好的C代码来做矩阵运算。在Python内部计算环境和外部计算环境(如NumPy)切换需要花费的时间称为overhead cost。对于一个简单运算,比如矩阵运算,从Python环境切换到Numpy,Numpy运算得到结果,再从Numpy切回Python,这个成本,比纯粹在Python内部做同类运算的成本要低很多。但是,一个复杂数值运算由多个基本运算组合而成,如果每个基本运算来一次这种环境切换,overhead cost就不可忽视了。为了减少来回的环境切换,TensorFlow的做法是,先在Python内定义好整个Graph(即一套复杂的数值运算集合),然后在Python外运行整个完整的Graph。因此TensorFlow的代码结构也就对应为两个阶段了。

下面我们来逐行解读一下这段代码。

阶段1:定义Graph

W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))

tf.Variable是TensorFlow的一个类,是取值可变的Tensor,构造函数的第一个参数是初始值initial_value。

initial_value: A Tensor, or Python object convertible to a Tensor, which is the initial value for the Variable.

tf.zeros(shape, dtype=tf.float32, name=None)是一个op,用于生成取值全是0的Constant Value Tensor。

tf.random_uniform(shape, minval=0, maxval=None, dtype=tf.float32, seed=None, name=None)是一个op,用于生成服从uniform distribution的Random Tensor。

y = W * x_data + b

y是线性回归运算产生的Tensor。运算符*和+,等价为tf.multiple()tf.add()这两个TensorFlow提供的数学类ops。tf.multiple()的输入是W和x_data;W是Variable,属于Tensor,可以直接作为op的输入;x_data是numpy的多维数组ndarray,TensorFlow的ops接收到ndarray的输入时,会将其转化为tensor。tf.multiple()的输出是一个tensor,和b一起交给tf.add(),得到输出结果y。

至此,线性回归的模型已经建立好,但这只是Graph的一部分,还需要定义损失。

loss = tf.reduce_mean(tf.square(y - y_data))

loss是最小二乘法需要的目标函数,是一个Tensor,具体的op不再赘述。

optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

这一步指定求解器,并设定求解器的最小化目标为损失。train代表了求解器执行一次的输出Tensor。这里我们使用了梯度下降求解器,每一步会对输入loss求一次梯度,然后将loss里Variable类型的Tensor按照梯度更新取值。

init = tf.global_variables_initializer()

定义Graph阶段的代码,只是在Python内定义了Graph的结构,并不会真正执行。在执行Graph阶段,所有的变量要先进行初始化。每个变量可以单独初始化,但这样做有些繁琐,所以TensorFlow提供了一个方便的函数tf.global_variables_initializer()可以在Graph中添加一个初始化所有变量的op。

阶段2:执行Graph

sess = tf.Session()
sess.run(init)

Session是连接两个阶段的纽带。在进行任何计算以前,先给Variable赋初始值。

for step in range(201):
    sess.run(train)

train操作对应梯度下降法的一步迭代。当step为0时,train里的variable取值为初始值,根据初始值可以计算出梯度,然后将初始值根据梯度更新为更好的取值;当step为1时,train里的variable为上一步更新的值,根据这一步的值可以计算出一个新的梯度,然后将variable的取值更新为更好的取值;以此类推,直到达到最大迭代次数。

print(step, sess.run(W), sess.run(b))

如果我们将sess.run()赋值给Python环境的变量,或者传给Python环境的print,可以fetch执行op的输出Tensor取值,这些取值会转化为numpy的ndarray结构。因此,这就需要一次环境的切换,会增加overhead cost。所以我们一般会每隔一定步骤才fetch一下计算结果,以减少时间开销。

小结

TensorFlow使用数据流图表示深度学习模型,节点是操作,边是Tensor。TensorFlow的代码分为两个阶段:第一阶段定义图,包括定义变量、模型、损失、训练op、初始化op等;第二阶段执行图,创建Session,包括执行初始化op,执行训练op,获取被更新的变量取值等。TensorFlow训练线性回归的代码,麻雀虽小,五脏俱全,只需要根据使用的模型替换掉线性回归模型那一行代码即可。入了门的TF Boy们,怎么进一步提高自己在这方面的技能呢?一是去钻研高大上的深度学习理论,二是去学习TensorFlow工程部署上的一些API(字符串处理、Tensor变换、保存变量、读取文件、分布式等),三是在项目中积累实战经验。祝各位TF Boy们掏粪快乐!