PyTorch 提供了一种非常方便的节省显存的方式,就是 Checkpoint 机制。这篇文章的目的在于更透彻地了解其内在的机制。
Checkpoint 机制
该技术的核心是一种使用时间换空间的策略。在现有的许多方法中被大量使用,例如 DenseNet、Swin Transformer 源码中都可以看到它的身影。
为了了解它的工作原理,我们先得弄明白的一个问题是,PyTorch 模型在训练过程中显存占用主要是用来存储什么?
关于这一点,相关技术文档介绍的非常详细:
开门见山的说,PyTorch 在进行深度学习训练的时候,有 4 大部分的显存开销,分别是模型参数 (parameters) ,模型参数的梯度 (gradients) ,优化器状态 (optimizer states) 以及 中间激活值 (intermediate activations) 或者叫中间结果 (intermediate results)。
而通过 Checkpoint 技术,我们可以通过一种取巧的方式,使用 PyTorch 提供的 torch.no_grad() 模式来避免将这部分运算被 autograd 记录到反向图'backward graph'中,从而避免了对于中间激活值的存储需求。
个人理解(欢迎指出错误):
前向传播时 autograd 记录各个操作反向传播需要的一些信息和中间变量。反向传播之后,用于计算梯度的中间结果会被释放。也就是说,模型参数、优化器状态和参数梯度是始终在占用着存储空间的,中间激活值在反向传播之后就自动被清空了。具体显存占用变化可见相关测试,这里我简单修改了示例。
这里实际上会引申出另一个问题,为什么自定义 Function 一般情况下会减少显存占用? (在 Vision Longformer 中各种实现的对比里可以明显看到这一现象)
我觉得主要是因为自定义 Function 的时候,我们可以从一整个模块的角度来更有针对性的在
ctx中存储中间变量,而自动求导引擎可能关注的太细了,导致存储许多不必要的中间变量。关于这一点暂时不知道如何验证。
这可以避免存储模型特定层中间运算结果,从而有效降低了前向传播中显存的占用。 这些中间结果会在反向传播的时候被即时重新计算一次。要注意,被 checkpoint 包裹的层反向传播时仍然会在第一次反向传播的时候开辟存储梯度的空间。
因为 checkpoint 是在
torch.no_grad()模式下计算的目标操作的前向函数,这并不会修改原本的叶子结点的状态,有梯度的还会保持。只是关联这些叶子结点的临时生成的中间变量会被设置为不需要梯度,因此梯度链式关系会被断开。
通过这样的方式,虽然延长了反向传播的时间,但是却也在一定程度上缓解了存储大量中间变量带来的显存占用。
源码解析
以下代码来自 PyTorch v1.10.1 版本。最新的版本中补充了一些新的内容,待其最终发布后再说吧,下面的内容本身已经将 checkpoint 的核心介绍了。
辅助函数
这部分代码中首先构造了数个辅助函数,主要是用来做一些针对输入的检查和处理,同时也要处理好随机种子的问题。
def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
if isinstance(inputs, tuple):
out = []
for inp in inputs:
if not isinstance(inp, torch.Tensor):
out.append(inp)
x = inp.detach()
x.requires_grad = inp.requires_grad
out.append(x)
(out)
:
RuntimeError(
, (inputs).__name__)
() -> :
(inp.requires_grad inp inputs (inp, torch.Tensor)):
warnings.warn()


