torch.utils.checkpoint模块

当训练模型,出现了显存不足的问题时,可以利用torch.utils.checkpoint模块下的checkpoint函数来减少训练时的显存占用,但是代价是训练过程会变慢,本质上是用速度换取内存。官方文档地址:torch.utils.checkpoint — PyTorch 2.7 documentation

不使用checkpoint时,前向传播的过程中会保存中间的activation,用于后续反向传播时计算梯度。但使用checkpoint,会只保存输入tuple和模型参数,不保存中间的activation,在反向传播时,使用保存的输入重新计算中间的activation,相当于重新跑了一遍正常的前向传播。反向传播完成后,把activation释放。

checkpoint的正确使用方式应该是将网络分为若干的模块,对每个模块应用checkpoint,这样反向传播时,依从后往前的顺序对每个模块计算梯度,但只有在计算到该模块时,该模块才有activation的内存占用,计算完成或没计算到时就没有。checkpoint不应该对网络整体使用,否则就和正常的训练过程一样占用同样的显存。

注意dropout和batch normalization层不能用checkpoint

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注