除去从检查点恢复前向传播结果张量带来的主要开销,DTR 的额外开销在于寻找应该被释放的最优张量,即计算上图张量 t 的 f(t)值。为了降低这一部分的计算量,MegEngine 还采用了两种运行时优化:
- 不考虑小的张量,它们不加入候选集
- 每次在需要释放张量的时候,随机采样并遍历少部分张量,以节省计算开销
最难的是工程实现
虽然 DTR 看上去原理也不复杂,但真正的难题在于提高易用性,即将所有细节都隐藏到框架的底层,只为开发者提供最简单的接口。
在此就用一个最简单的计算例子,跟着框架演算一遍,看看 MegEngine 是如何利用动态图的计算历史恢复与释放张量的。
现在假设输入有 a 和 b 两个张量,并希望计算 a*b 与 a b,但是显存最大只能保存三个张量。在黄框计算 c=a b 时,显存还能保留张量 c,然而在下一步绿框计算 d=a*b 时只能先释放 c 才能保存 d。
不巧的是,下一步灰框需要获取黄框的计算结果,然而为了节省显存,c 已经被释放了。所以,MegEngine 现在需要做的是重新运行灰框的计算图,计算 c=a b,并加载到显存中。显然,这样做必然需要释放 d 的显存。
这样一来,鉴于显存的限制,MegEngine 就会自动选择合适的张量释放,并在需要时重新计算。如果需要重新计算某个张量的结果,例如上图的 d,就需要具体的历史计算信息(在这里就是 a b 这样的计算路径),与此同时还需要知道 a 和 b 这两个输入张量。
所有这样的历史计算信息都由 MegEngine 自动获取与保存,MegEngine 的工程师已经在底层用 C 处理完毕,用户完全不需要考虑。
struct ComputePath {
std::shared_ptr<OpDef> op;
SmallVector<TensorInfo*> inputs;
SmallVector<TensorInfo*> outputs;
double compute_time = 0;
} *producer;
SmallVector<ComputePath*> users;
size_t ref_cnt = 0;
以上为 MegEngine 底层用于追踪计算路径信息的结构体。其中 op 表示产生该张量的算子;inputs 和 outputs 分别表示这个算子需要的输入与输出张量;compute_time 表示该算子实际的运行时间。
实际上,在使用 MegEngine 的过程中,全都是用 Python 接口创建张量,只不过框架会对应追踪每个张量的具体信息。每当需要访问张量,不用考虑张量是否在显存中时,没有也能立刻恢复出来。所有这些复杂的工程化的操作与运算逻辑都隐藏在了 MegEngine C 底层。
Python 代码会翻译成 C 底层实现,C 代码会通过指针管理显卡内存中真正的张量(右图绿色部分)。
幸好这样的复杂操作不需要算法工程师完成,都交给 MegEngine 好了。
MegEngine 能做的事情远不止于此,只不过大多是像动态图显存优化这种技术一样,润物细无声地把用户的实际问题解决于无形。2020 年 3 月开源的 MegEngine 在以肉眼可见的速度快速成长,从静态计算图到动态计算图,再到持续提升的训练能力、移动端推理性能优化、动态显存优化…… 这也许就是开源的魅力。只有不断优化和创新,才能吸引和满足「挑剔」的开发者。MegEngine 下一个推出的功能会是什么?让我们拭目以待。