PyTorch 提供了一种非常方便的节省显存的方式,就是 Checkpoint 机制。这篇文章的目的在于更透彻地了解其内在的机制。
Checkpoint 机制
该技术的核心是一种使用时间换空间的策略。在现有的许多方法中被大量使用,例如 DenseNet、Swin Transformer 源码中都可以看到它的身影。
为了了解它的工作原理,我们先得弄明白的一个问题是,PyTorch 模型在训练过程中显存占用主要是用来存储什么?
关于这一点,相关技术文档介绍的非常详细:
开门见山的说,PyTorch 在进行深度学习训练的时候,有 4 大部分的显存开销,分别是模型参数 (parameters) ,模型参数的梯度 (gradients) ,优化器状态 (optimizer states) 以及 中间激活值 (intermediate activations) 或者叫中间结果 (intermediate results)。
而通过 Checkpoint 技术,我们可以通过一种取巧的方式,使用 PyTorch 提供的 torch.no_grad() 模式来避免将这部分运算被 autograd 记录到反向图'backward graph'中,从而避免了对于中间激活值的存储需求。
个人理解(欢迎指出错误):
前向传播时 autograd 记录各个操作反向传播需要的一些信息和中间变量。反向传播之后,用于计算梯度的中间结果会被释放。也就是说,模型参数、优化器状态和参数梯度是始终在占用着存储空间的,中间激活值在反向传播之后就自动被清空了。具体显存占用变化可见相关测试,这里我简单修改了示例。
这里实际上会引申出另一个问题,为什么自定义 Function 一般情况下会减少显存占用? (在 Vision Longformer 中各种实现的对比里可以明显看到这一现象)
我觉得主要是因为自定义 Function 的时候,我们可以从一整个模块的角度来更有针对性的在 ctx 中存储中间变量,而自动求导引擎可能关注的太细了,导致存储许多不必要的中间变量。关于这一点暂时不知道如何验证。
这可以避免存储模型特定层中间运算结果,从而有效降低了前向传播中显存的占用。 这些中间结果会在反向传播的时候被即时重新计算一次。要注意,被 checkpoint 包裹的层反向传播时仍然会在第一次反向传播的时候开辟存储梯度的空间。
因为 checkpoint 是在 torch.no_grad() 模式下计算的目标操作的前向函数,这并不会修改原本的叶子结点的状态,有梯度的还会保持。只是关联这些叶子结点的临时生成的中间变量会被设置为不需要梯度,因此梯度链式关系会被断开。
通过这样的方式,虽然延长了反向传播的时间,但是却也在一定程度上缓解了存储大量中间变量带来的显存占用。
源码解析
以下代码来自 PyTorch v1.10.1 版本。最新的版本中补充了一些新的内容,待其最终发布后再说吧,下面的内容本身已经将 checkpoint 的核心介绍了。
辅助函数
这部分代码中首先构造了数个辅助函数,主要是用来做一些针对输入的检查和处理,同时也要处理好随机种子的问题。
def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
if isinstance(inputs, tuple):
out = []
for inp in inputs:
if not isinstance(inp, torch.Tensor):
out.append(inp)
continue
x = inp.detach()
x.requires_grad = inp.requires_grad
out.append(x)
return tuple(out)
else:
raise RuntimeError(
"Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
def check_backward_validity(inputs: Iterable[Any]) -> None:
"""检查输入参数是否至少有一个需要记录梯度的 Tensor,这样才能确保输出也有梯度。"""
if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
由于需要重复计算,所以随机状态的一致性是需要重视的。由于前向传播的部分在反向过程中仍会计算一次,所以如果不使用原始的随机状态的话,会导致重新计算和原本正常计算过程中的随机状态不同,而影响模型的行为。
另外在这段代码的注释中提到了一点有趣的地方:
由于无法获悉被 checkpoint 处理的操作是否在运算中间会将一些参数移动到不同的设备上,这可能需要手动保存这些设备对应的随机状态。当前的实现直接保存了所有可见设备上的随机状态,但是这样有时可能是不必要的,但是目前尚没有较好的解决策略。
所以按照文档的意思,就是在说如果没有这样的移动,那就可以不用保存随机状态咯?这一点其实有些令人疑惑。
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
"""获取不同输入对应的 GPU 设备的随机数生成器的状态"""
fwd_gpu_devices = list(set(arg.get_device() for arg in args
if isinstance(arg, torch.Tensor) and arg.is_cuda))
fwd_gpu_states = []
for device in fwd_gpu_devices:
with torch.cuda.device(device):
fwd_gpu_states.append(torch.cuda.get_rng_state())
return fwd_gpu_devices, fwd_gpu_states
def set_device_states(devices, states) -> None:
"""针对不同的设备设置随机数生成器的状态"""
for device, state in zip(devices, states):
with torch.cuda.device(device):
torch.cuda.set_rng_state(state)
核心 Function
可以看到,这里的 Checkpoint 本身就是基于 PyTorch 的 torch.autograd.Function 实现的一个扩展算子,所以该部分代码也涉及到了 Function 的诸多功能。
阅读它既可以帮助我们同时复习一下相关的知识,又能进一步了解更复杂的处理逻辑该如何搭建。
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
check_backward_validity(args)
ctx.run_function = run_function
ctx.preserve_rng_state = preserve_rng_state
ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
if preserve_rng_state:
ctx.fwd_cpu_state = torch.get_rng_state()
ctx.had_cuda_in_fwd = False
if torch.cuda._initialized:
ctx.had_cuda_in_fwd = True
ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
ctx.inputs = []
ctx.tensor_indices = []
tensor_inputs = []
for i, arg in enumerate(args):
if torch.is_tensor(arg):
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
ctx.save_for_backward(*tensor_inputs)
torch.no_grad():
outputs = run_function(*args)
outputs
():
torch.autograd._is_checkpoint_valid():
RuntimeError(
)
inputs = (ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensors
i, idx (tensor_indices):
inputs[idx] = tensors[i]
rng_devices = []
ctx.preserve_rng_state ctx.had_cuda_in_fwd:
rng_devices = ctx.fwd_gpu_devices
torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
ctx.preserve_rng_state:
torch.set_rng_state(ctx.fwd_cpu_state)
ctx.had_cuda_in_fwd:
set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
detached_inputs = detach_variable((inputs))
torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
outputs = ctx.run_function(*detached_inputs)
(outputs, torch.Tensor):
outputs = (outputs,)
outputs_with_grad = []
args_with_grad = []
i ((outputs)):
torch.is_tensor(outputs[i]) outputs[i].requires_grad:
outputs_with_grad.append(outputs[i])
args_with_grad.append(args[i])
(outputs_with_grad) == :
RuntimeError(
)
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = (inp.grad (inp, torch.Tensor)
inp detached_inputs)
(, ) + grads
这里实际上就是在原始的操作和整体的计算图之间添加了一个中间层,用于信息的交互:
- 原始模型的数据传输到被包裹的目标层的时候,数据进入 checkpoint 的
forward() 中,被 checkpoint 进行检查和记录后,再送入目标层中;
- 目标层在非梯度模式下执行前向传播。该模式下,新创建的 tensor 都是不会记录梯度信息的;
- 目标层的结果通过 checkpoint 的前向传播输出,送入模型后续的其他结构中;
- 执行反向传播,损失求导,链式回传,计算梯度;
- 回传回来的对应于 checkpoint 输出的梯度被送入其对应的反向传播函数,即 checkpoint 的
backward()。
- 梯度送入 checkpoint 中后,需要进一步将梯度回传到目标层的输入上。由于在 checkpoint 的
forward 中目标层本身前向传播是处于非梯度状态下,所以回传路径上缺少目标层中操作的梯度子图。于是为了获取这部分信息,需要先梯度状态下对目标层进行一次前向传播,通过将回传回来的梯度和目标层的输出一起执行 torch.autograd.backward(outputs_with_grad, args_with_grad),从而获得对应输入的梯度信息。
- 将对应目标操作输入的梯度信息按照 checkpoint 本身 Function 的
backward 的需求,使用 None 对其他辅助参数的梯度占位后进行返回。
- 返回的对应于其他模块的输出量的梯度,被沿着反向传播的路径送入对应操作的
backward 中,一层一层回传累加到各个叶子节点上。
定义好操作后,进行一个简单的包装,同时处理一下默认参数,补充了更细致的文档:
def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
r"""Checkpoint a model or part of the model
Checkpointing works by trading compute for memory. Rather than storing all
intermediate activations of the entire computation graph for computing
backward, the checkpointed part does **not** save intermediate activations,
and instead recomputes them in backward pass. It can be applied on any part
of a model.
Specifically, in the forward pass, :attr:`function` will run in
:func:`torch.no_grad` manner, i.e., not storing the intermediate
activations. Instead, the forward pass saves the inputs tuple and the
:attr:`function` parameter. In the backwards pass, the saved inputs and
:attr:`function` is retrieved, and the forward pass is computed on
:attr:`function` again, now tracking the intermediate activations, and then
the gradients are calculated using these activation values.
这一段详细介绍了 checkpoint 的核心技术,也就是在非梯度模式下执行目标操作的前向传播,只保留输入和结构参数,省去了中间激活的保存。反向传播时在梯度模式下重新计算这些激活,重建这部分反向图,进而实现了梯度的正常回传。
The output of :attr:`function` can contain non-Tensor values and gradient
recording is only performed for the Tensor values. Note that if the output
consists of nested structures (ex: custom objects, lists, dicts etc.)
consisting of Tensors, these Tensors nested in custom structures will not
be considered as part of autograd.
因为 checkpoint 的 backward 实现的逻辑中,直接遍历目标操作的输出(会被自定转换成元组类型)并确定那些需要回流梯度的输出。如果输出中包含其他的非 tensor 结构,就会导致在遍历过程中这些输出被忽略掉。不过也确实,这样直接简化处理虽然使得灵活性下降,但是却也避免了代码过于复杂。
.. warning::
Checkpointing currently only supports :func:`torch.autograd.backward`
and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
is not supported.
.. warning::
If :attr:`function` invocation during backward does anything different
than the one during forward, e.g., due to some global variable, the
checkpointed version won't be equivalent, and unfortunately it can't be detected.
尽量保证目标操作在反向计算期间和前向期间的操作的一致性。
因为在 checkpoint 会在反向中重新计算一次前向,这可能会带来一些由于无法检测到的不确定因素而造成的与常规版本的差异。
.. warning::
If checkpointed segment contains tensors detached from the computational
graph by `detach()` or `torch.no_grad()`, the backward pass will raise an
error. This is because `checkpoint` makes all the outputs require
gradients which causes issues when a tensor is defined to have no
gradient in the model. To circumvent this, detach the tensors outside of the `checkpoint` function.
不要在目标操作中包含 detach 或者非梯度模式的处理。
**在我的实际测试中似乎并没有这个问题?**或许这里应该看一下 pytorch 提供的测试案例。
.. warning::
At least one of the inputs needs to have :code:`requires_grad=True` if
grads are needed for model inputs, otherwise the checkpointed part of the
model won't have gradients. At least one of the outputs needs to have
:code:`requires_grad=True` as well.
要保证至少有一个输入是 requires_grad 的,这样才可以保证这部分操作可以被记录梯度。
也要保证输出至少有一个需要计算梯度。
Args:
function: describes what to run in the forward pass of the model or
part of the model. It should also know how to handle the inputs
passed as the tuple. For example, in LSTM, if user passes
``(activation, hidden)``, :attr:`function` should correctly use the
first input as ``activation`` and the second input as ``hidden``
preserve_rng_state(bool, optional, default=True): Omit stashing and restoring
the RNG state during each checkpoint.
args: tuple containing inputs to the :attr:`function`
Returns:
Output of running :attr:`function` on :attr:`*args`
"""
preserve = kwargs.pop(, )
kwargs:
ValueError( + .join(arg arg kwargs))
CheckpointFunction.apply(function, preserve, *args)
应用案例
Checkpoint for Sequential
PyTorch 源码中给了一个很直接的应用案例,就是将 checkpoint 应用于 Sequential 搭建起来的模型。按照分段数 segments 指定的,将模型划分为多段。
def checkpoint_sequential(functions, segments, input, **kwargs):
r"""A helper function for checkpointing sequential models.
Sequential models execute a list of modules/functions in order
(sequentially). Therefore, we can divide such a model in various segments
and checkpoint each segment. All segments except the last will run in
:func:`torch.no_grad` manner, i.e., not storing the intermediate
activations. The inputs of each checkpointed segment will be saved for
re-running the segment in the backward pass.
See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
.. warning::
Checkpointing currently only supports :func:`torch.autograd.backward`
and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
is not supported.
.. warning:
At least one of the inputs needs to have :code:`requires_grad=True` if
grads are needed for model inputs, otherwise the checkpointed part of the
model won't have gradients.
.. warning:
Since PyTorch 1.4, it allows only one Tensor as the input and
intermediate outputs, just like :class:`torch.nn.Sequential`.
Args:
functions: A :class:`torch.nn.Sequential` or the list of modules or
functions (comprising the model) to run sequentially.
segments: Number of chunks to create in the model
input: A Tensor that is input to :attr:`functions`
preserve_rng_state(bool, optional, default=True): Omit stashing and restoring
the RNG state during each checkpoint.
Returns:
Output of running :attr:`functions` sequentially on :attr:`*inputs`
Example:
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
"""
preserve = kwargs.pop('preserve_rng_state', True)
if kwargs:
raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
def run_function(start, end, functions):
def forward(input):
for j (start, end + ):
= functions[j]()
forward
(functions, torch.nn.Sequential):
functions = (functions.children())
segment_size = (functions) // segments
end = -
start (, segment_size * (segments - ), segment_size):
end = start + segment_size -
= checkpoint(run_function(start, end, functions), ,
preserve_rng_state=preserve)
run_function(end + , (functions) - , functions)()
总结
PyTorch 的 Checkpoint 机制是深度学习训练中显存优化的重要工具。它通过牺牲一定的计算时间来换取显存空间的节省,特别适用于深层网络或 Batch Size 受限的场景。在使用时,开发者需注意以下几点:
- 梯度要求:确保至少有一个输入和一个输出设置了
requires_grad=True,否则 Checkpoint 将无效。
- 随机性一致性:开启
preserve_rng_state 以保证反向重计算时的随机状态与前向一致,这对包含 Dropout 等随机操作的模型至关重要。
- 兼容性限制:不支持
.grad() 接口,需使用 .backward();且被 Checkpoint 包裹的段内不应包含额外的 detach() 或 no_grad() 操作。
- 适用场景:适合内存敏感但计算资源相对充裕的情况,如训练 Transformer 类大模型或高分辨率图像分割网络。
合理运用 Checkpoint 机制,可以在有限的硬件资源下训练更大的模型,是构建高效深度学习系统的关键技术之一。