在PyTorch框架中model.train() 和 model.eval()的作用是什么?
在PyTorch框架中model.train() 和 model.eval()的作用是什么?
2024年11月14日 00:00 广东
本人是某双一流大学硕士生,也最近刚好准备参加 2024年秋招,在找大模型算法岗实习中,遇到了很多有意思的面试,所以将这些面试题记录下来,并分享给那些和我一样在为一份满意的offer努力着的小伙伴们!!!
在PyTorch框架中model.train() 和 model.eval()的作用是什么?
面试题
大家都知道,PyTorch框架中 model.train() 和 model.eval() 两种方式
那在PyTorch框架中model.train() 和 model.eval()的作用是什么?
简单答案
分别用于:
model.train():模型训练
model.eval():模型评估(推断)
所以一般情况下,在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。
问题延申
虽然你这样回答,基本能够回答上面试官的问题,但是面试官可能并不希望你就这样草草了事,可能更希望你对该问题做一层延申,下面,我将教你如何给面试官一个王者级回答!!!
1、从PyTorch框架源代码角度,分析 model.train()与model.eval()?
要回答这个问题,需要从代码进行分析,通过查看 PyTorch框架源代码,可以发现train()和eval()的源代码位于torch/nn/modules/module.py,属于torch.nn.Module的成员函数,而torch.nn.Module类是我们在自定义网络层和自定义模型时所继承的类,其中相关的函数定义如下:
model.train()
model.train()方法,可以看到这里有个操作是遍历模型的子层并对其调用train()方法,这样子层比如Dropout的training也会被设置为相应的值。
model.eval()
回答
从 torch.nn.Module(以下简称Module)的代码实现来看, model.eval() 无参,直接调用了 model.train(),而后者是有布尔参的。model.train()与model.eval()改变了model的training属性值(True/False),从而让model处于训练或验证模式。
注:model.train()与model.eval() 这两个接口都有动到了self.training这个标志位,而该标志位在Module初始化时是置为True的,因此,模型网络默认是工作在training模式下的。
2、为什么 model.train()与model.eval() 要这样写?
在 model.train() 详细注释中有这样一段描述:
This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:Dropout
, :class:BatchNorm
, etc.
从上述描述中可以看出,这两个接口是一个标志位开关,用于控制一些 子类(eg:Dropout 和 BatchNorm 等Ops)在训练推理时的工作状态。
3、为什么 model.train()与model.eval() 会影响到 Dropout?
Dropout 网络结构分析:
简单理解就是,Dropout在训练和验证模式下表现不同。所以training的值的改变用来通知Dropout层在训练和验证模式间切换。
如下图,训练模式下,Dropout会以一定的概率让一些神经元值为0。
换句话说,对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接。
Dropout 代码实现分析:
def dropout(input,
p: float,
train: bool):
use_cuda = input.is_cuda
# lowering is specialized for cuda because cuda fuser can efficiently fuse those operations
# for cpu backend, where fusions are disabled, a different lowering that is more efficient
# in the absence of fusion is used
p1m = 1. - p
if train:
if use_cuda:
mask = torch.rand_like(input, memory_format=1) < p1m
res = mask.type_as(input) * input * (1./p1m)
else:
mask = torch.empty_like(input, memory_format=1)
mask.bernoulli_(p1m)
res = mask * input / p1m
else:
p1m = 1.
res = input
mask = torch.empty_like(input, memory_format=1)
Dropout在训练时会按照所给概率随机丢弃神经元,可避免过拟合,但是在推理时则需要所有神经元参与,否则推理时结果不固定。
4、为什么 model.train()与model.eval() 会影响到 BatchNorm?
BatchNorm 网络结构分析:
model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;
BatchNorm 代码实现分析:
def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
if training:
norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
else:
norm_mean = torch._unwrap_optional(running_mean)
norm_var = torch._unwrap_optional(running_var)
norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
norm_invstd = 1 / (torch.sqrt(norm_var + eps))
return ((input - norm_mean) * norm_invstd)
BatchNorm同样在训练时,需要根据训练数据更新维护 norm_mean 和 norm_var 这两个变量,而在推理时则需要使用全局的mean和var,否则效果会打折扣。
由于 model.train() 和 model.eval() 都有对 model.training 产生影响,因此,我们在定义模型网络时,亦可以利用该标志位进行一些分支设计。
3、model.train()与model.eval() 会影响梯度反向传播么?
model.train()与model.eval() 这两个接口并不会影响到梯度反向传播部分,因此,当我们在推理时,最好能加上如下的类似代码,主要是“with torch.no_grad():”这个上下文管理器结构,这样可以减小内存耗用和加快推理速度。
# evaluate model
model.eval()
with torch.no_grad():
...
output = model(input)
...
(完)