如何高效使用TensorFlow2.0
Install: pip install tensorflow==2.0.0-alpha0
or Build From source https://www.weaf.top/posts/c05ce623/
——————————————————-分割线—————————————————————————————–
首先在2.0中,为了使得用户效率提高有许多改变。2.0中移除了许多冗余的APIs,使得更多的API一致,而且很好的和Python集成了Eager execution。(着实让人很舒服)Eager execution
许多的RFCs对于2.0做的改变都有详细的解释.RFCs 假定你已经对TensorFlow1.x已经很熟悉的情况下,我们来介绍一下TensorFlow在开发中应该是什么样子的。
主要修改的简单总结
API 更加干净
许多的APIs在TF2.0中已经被移出掉。一些主要的改变包括tf.app
,tf.flags
和tf.logging
,赞成现在开源的项目absl-py, 重新安置了tf.contrib
项目。清理了一些在tf.*
模块到比如tf.math
模块下。还有一些模块在TF2.0中具有相同的功能,比如tf.summary
, tf.keras.metrics
, 和 tf.keras.optimizers
.最简单将这些名字命名为2.0的格式是通过v2 upgrade script
Eager execution
TensorFlow1.x要求通过tf.* API用户手动去缝合一棵抽象的语法树(就是那个graph)。要求用户手动编译抽象的语法树,通过session.run()
输入张量来得到输出张量。在TensorFlow2.0中,用户可以像Python一样,立马得到结果,graph和sessions这些要求是去实现一些细节时候所使用到的功能。
没有更多的Global
TensorFlow1.x中都使用了隐式的全局空间。当你调用tf.Variable()
.这样就会放入到默认的图中。他就会留在图中,即使你忘了将它赋值给一个Python变量保存。当然你可以通过tf.Variable()
来恢复它,哈哈哈,但是你只有在知道它构建时候的名字才能找到它,所以之前我们在TensorFlow1.x中建议大家,尽可能给创建的变量一个名字,而且每一层次使用tf.variable_scope()
去管理是多么的重要。这样导致的结果呢,就会有很多机制去帮助用户去找到他们的创建过的变量,并在框架中找到他们的变量。比如Variable scope, global collections,帮助的方法比如tf.get_global_step()
, tf.global_variables_initializer()
而且优化器隐式的计算了所有变量的梯度等等。TensorFlow2.0将这些机制都去除了,所以,哥们们,实在是太难学了。Variables2.0RFC
还是要遵循默认的机制,保留对变量的跟踪,如果你失去了你的变量tf.Variable()
,那么你将永远试去它,请珍惜它。它会被当作垃圾回收。
这样的话,我们就需要提出新的让用户跟踪自己变量的机制。但是只要是和Keras的对象打交道(随后会讲到),那么你会觉得很轻松。
Functions, not sessions
一个session.run()
的调用,类似于一个函数的调用。你给一个函数输入,这个函数将给你一个输出。在TensorFlow2.0中。你可以将你的Python函数使用tf.function
装饰器装饰。将其标记为JIT编译。因此TensorFlow2.0会将它作为一个单个图运行Functions 2.0 RFC,这种机制会让TensorFlow2.0获得所有图机制的好处(这个改进相当于将每一步运算可以加入到一个子图中,这样可以将所有的图联合起来就是一个模型,这个确实很牛叉)。
- 性能:这个函数可以被优化(节点剪枝,核融合等等)
- 可移植性:这个函数可以被导出核重导入(SavedModel 2.0 RFC)运行用户可以重用核共享模块函数。(我个人感觉这种方式重用性更高)
运行方式:
1 | # TensorFlow 1.X |
这种方式可以让你在TF和Python中随意切换。我们希望用户可以充分的使用Python的表现力,但是方便的同时,在一些移动、c++、和JS中没有Python编译器时候,添加@tf.function
这样的代码就需要用户重构许多代码。所以AutoGraph将转换一个部分Python的子集到与TensorFlow等价的操作。
for
/while
->tf.while_loop
(break
andcontinue
are supported)if
->tf.cond
for _ in dataset
->dataset.reduce
AutoGraph支持任意的嵌套控制流,它可以尽可能的简明扼要实现许多ML程序比如序列模型,强化学习模型,定制的训练循环等等。
TensorFlow2.0中的惯用方式
将你的code重构到一些更加小的函数(这是tf.function装饰器带来的趋势)
在TensorFlow1.x中你习惯了“厨房水槽”方法。将所有的计算联合起来。然后选中Tensor,最后通过session.run()
得到结果。在TensorFlow2.0中。用户应该将你们的代码重构到一些很小的函数中。只是在需要的时候调用。一般,没有必要将所有的函数都加上tf.function
装饰器。你只需要给最高级的计算加上,直接调用即可。举例:一部的训练,或者你模型的一次前馈操作。
使用Keras层和模型管理你的变量
Keras模型和层都提供了方便的variables
和trainable_variables
属性,以递归的方式收集所有依赖的变量。这使得变量管理变得十分简单。
对比:
1 | def dense(x, W, b): |
with the Keras version:
1 | # Each layer can be called, with a signature equivalent to linear(x) |
Keras layers/models 都是继承自tf.train.Checkpointable
也被集成了@tf.function
, 这样可以使得直接从Keras 对象获取到checkpoint或者导出SavedModels。你都没有必要使用Keras的fit()
API去获取这些集成特性。
下面是一个迁移学习的例子,描述了怎么使用Keras方便的收集相关变量的子集。假设你在共享分支上训练一个多输入模型。
1 | trunk = tf.keras.Sequential([...]) # shared trunk |
(上面的例子是在一个共享分支上训练两个模型,先使用一个主要的数据集,来训练trunk分支,然后再通过一个小的数据集来微调trunk分支,典型的迁移思想)。
将tf.data.Datasets
和@tf.function
结合
当你迭代将训练数据填充到内存中,你可以随意使用Python规则的迭代。不然,你也可以使用tf.data.Datasets
,这是一种最好的方式将训练数据从硬盘填充到内存中。Datasets是可迭代的不是迭代器iterables (not iterators) 在Eager模型下,和Python迭代一样工作。你可以完全利用Dataset异步预装载/流特征通过tf.funtion
来装饰你的code。它使用AutoGraph替换了Python迭代的等效图操作。
1 |
|
如果你使用Keras 的fit API,则你不用担心dataset的迭代:
1 | model.compile(optimizer=optimizer, loss=loss_fn) |
利用AutoGraph来代替Python的控制流
AutoGraph提供了一些方法,去转换数据依赖的依赖流进入图中,就像tf.cond
和tf.while_loop
一个常用的数据依赖流出现的地方就是序列模型。tf.keras.layers.RNN
封装了一个RNN单元,允许你静态或者动态的循环展开。为了示范,你可以重新实现如下动态展开:
1 | class DynamicRNN(tf.keras.Model): |
有关更多详细的AutoGraph的功能,请看the guide
使用tf.metrics
去集成data,随后使用tf.summary
去log它
为了log summary,使用tf.summary.(scalar | histogram | ...)
需要重定位它到上下文管理器(如果你忘记了上下文管理器,将什么都不会被记录下来)不想TF1.x, 所有的summaries都会被直接写入到writer中,没有单独的merge操作,和add_summary()调用。这意味着step值必须在调用点提供。
1 | summary_writer = tf.summary.create_file_writer('/tmp/summaries') |
在log他们为summary前为了聚合数据,使用tf.metrics
。Metrics是有状态的,它可以累计数据然后当你调用.result()
时返回一个累计的结果。使用.reset_states()
去清除累计值。
1 | def train(model, optimizer, dataset, log_freq=10): |
在log目录通过Tensorboard可视化生成的结果: tensorboard --logdir /tmp/summaries
.
QQ:329804334
Mail:mizeshuang@gmail.com