如果不使用动态图显存优化技术,PyTorch 上的模型一次训练迭代最多只能处理 64 个样本,MegEngine 能处理 100 个样本。只要加上 DTR,PyTorch 模型一次迭代就能处理 140 个样本,MegEngine 能尝试处理 300 个样本。
如果换算成模型大小,加上动态图显存优化技术的 MegEngine,在相同的 GPU 及批大小情况下,能高效训练增大近乎 5 倍的模型。
MegEngine 动态图显存优化技术
深度学习模型的显存占用一般分为权重矩阵、前向传播的中间张量、反向传播的梯度矩阵(Adam 优化器)三部分。
权重矩阵和梯度矩阵占的内存很难优化,各个模型基本上都有一个定值。前向传播的中间计算结果则不然:随着 Batch Size 的增加以及模型层和数量的增加,显存必然跟着增加。如果模型比较大,中间计算结果将占据最主要的显存。
如上图所示,在前向传播中(第一行从左到右),蓝色圆圈表示模型的中间计算结果开始占用显存。一直到前向传播完成,第一行完全变为蓝色圆圈,前面计算所占用的显存都不能释放。
等到反向传播开始(第二行从右到左),随着梯度的计算与完成应用,前向传播保留在显存中的张量才可以释放。
很明显,如果要降低显存占用,就要拿前向传播保存的中间计算结果开刀,这也正是 MegEngine 动态图显存优化的主要方向。
用计算换显存
对于动态计算图,最直接的方法就是用计算或内存换显存。因此,MegEngine 首先要决定到底使用哪种技术。
MegEngine 团队通过实验发现,用计算耗时远比交换耗时少。例如从显存中节省 612.5MB 空间,用带宽换显存要比用计算换显存慢了几十上百倍。
因此很明确,动态计算图中也应该使用梯度检查点技术,用计算换显存。
如下为梯度检查点技术原理示意,前向传播中第三个点为检查点,它会一直保存在显存中。第四个点在完成计算后即可释放显存,在反向传播中如果需要第四个点的值,可以从第三个点重新计算出第四个点的值。
虽然大致原理不难理解,但具体怎么做还是比较复杂的,MegEngine 团队借鉴了论文《Dynamic Tensor Rematerialization》,将其优化并实现到 MegEngine 中。
DTR,最前沿的显存优化技术
DTR 是一种完全动态的启发式策略,核心思想是当显存超过某个阈值时,动态地释放一些合适的张量,直到显存低于阈值。一般而言,释放张量的标准有三个:重新计算出该张量的开销越小越好;占用的显存越大越好;在显存中停留的时间越长越好。