tensorflow模型优化技巧

    xiaoxiao2021-03-25  188

    当把模型跑起来后,开始考虑如何优化model,提升性能,从网上找了找资料,并结合实际,整理一下分享给大家。

    预处理数据

    说道预处理数据,我觉得我自己做的还是不少,学习tensorflow时候,把mnist_soft.py跑起来以后,就开始思考mnist数据是什么数据?带着这个疑问我开始尝试制作自己的数据集,期间使用了很多的方法,如二进制文件,直接读取图片进内存等等。其实仔细想想可以知道,如果预处理数据没有制作好,会直接影响后续tensorflow读取数据的速度。但如果你觉得提高机器性能可以的话,那我只能说,就像是兰博基尼跑在泥泞的道路上。所有一定得使预处理数据干净, tensorflow官方提供的数据格式TFRecord 是一个很不错的选择,可以试着制作一下 最近发布了预处理组件:tf.Transfrom() 有兴趣点击如下地址了解 http://www.leiphone.com/news/201702/Yi4oU1mSwKLc8Rad.html

    使用队列

    队列的优势就不说了,把预处理数据放进队列,怎么出自己控制。有一种发现昂贵的预处理管道的方法是查看 Tensorboard 的队列图。如果你使用框架 QueueRunners并将摘要存储在文件中,这些图都是自动生成的。这些图会显示你的计算机是否能够保持队列处在排满的状态。如果你发现图当中出现了负峰值,则系统无法在计算机要处理一个批次的时间内生成新的数据。其中的一个原因上面已经说过了。根据我的经验,最常见的原因是 min_after_dequeue 值很大。如果队列试图在内存中保留大量记录,你的容量很容易就饱和了,这会导致交换(swapping),并且显著降低队列的速度。其他的原因还包括硬盘问题(例如磁盘速度慢),以及单纯的是数据大,大过了你系统可以处理的程度。无论原因为何,修复这个问题都会加快你的训练过程。

    注意内存

    确定整个模型的内存消耗没有超出机器内存,如果超出了,必然使用swapping,而 swapping 肯定会让输入流程放慢,会让你的 GPU 开始坐等新数据。如何侦探这个行为呢?一个简单地 top,就像下文讲到的 TensorBoard 队列图就应当足够侦测到这样的行为。

    tensorboard

    说道tensorboard,不得不说就是它对于tensorflow的可视化分析太有用了,不仅可以对当前运行的graph 进行流式图分析,还能进行性能监控。

    # Collect tracing information during the fifth step. if global_step == 5: # Create an object to hold the tracing data run_metadata = tf.RunMetadata() # Run one step and collect the tracing data _, loss = sess.run([train_op, loss_op], options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE), run_metadata=run_metadata) # Add summary to the summary writer summary_writer.add_run_metadata(run_metadata, 'step%d', global_step)

    之后,一个 timeline.json 文件会被保存到当前文件夹,跟踪数据可以在 Tensorboard 找到。现在,你可以很容易地看到一个操作花了多长时间来计算,以及这个操作消耗了多少内存。打开Tensorboard的图视图,选择左侧的最新运行,你就能在右边看到性能的详细信息。一方面,这方便你调整模型,尽可能多地使用机器;另一方面,这方便你在训练管道中发现瓶颈。如果你更喜欢时间轴视图,在 Google Chromes 跟踪事件分析工具(Trace Event Profiling Tool)中加载timeline.json 文件就行了。   另一个不错的工具是 tfprof,tfprof 使用相同的功能做内存和执行时间分析,不过提供了更多的便利功能(feature)。额外的统计信息需要更改代码。

    Debug

    作为开发人员,这个我就不说了。 提示: TensorFlow 1.0 推出了新的 TFDebugger,应该很有用的,这是一篇关于它的介绍

    设置运算超时时间

    当我们点击运行的时候,session 也启动了,但没有事情都没有什么发生?这通常是由空队列引起的。但是,如果你不知道是哪一个队列导致的,那么有一个简单的修复方法:只需在创建会话时启用一个操作执行超时,这样当操作超过限制时,脚本就会崩溃:

    config = tf.ConfigProto() config.operation_timeout_in_ms=5000 sess = tf.Session(config=config)

    使用堆栈跟踪,你就可以找出是哪个操作产生了问题,修复错误,继续训练吧。 参考文章: http://www.deeplearningweekly.com/blog/tensorflow-quick-tips http://it.sohu.com/20170221/n481292049.shtml

    转载请注明原文地址: https://ju.6miu.com/read-416.html

    最新回复(0)