在PyTorch框架中model.train() 和 model.eval()的作用是什么?

在PyTorch框架中model.train() 和 model.eval()的作用是什么?

在PyTorch框架中model.train() 和 model.eval()的作用是什么?

2024年11月14日 00:00 广东

本人是某双一流大学硕士生,也最近刚好准备参加 2024年秋招,在找大模型算法岗实习中,遇到了很多有意思的面试,所以将这些面试题记录下来,并分享给那些和我一样在为一份满意的offer努力着的小伙伴们!!!

在PyTorch框架中model.train() 和 model.eval()的作用是什么?

面试题

大家都知道,PyTorch框架中 model.train() 和 model.eval() 两种方式

那在PyTorch框架中model.train() 和 model.eval()的作用是什么?

简单答案

分别用于:

model.train():模型训练

model.eval():模型评估(推断)

所以一般情况下,在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。

问题延申

虽然你这样回答,基本能够回答上面试官的问题,但是面试官可能并不希望你就这样草草了事,可能更希望你对该问题做一层延申,下面,我将教你如何给面试官一个王者级回答!!!

1、从PyTorch框架源代码角度,分析 model.train()与model.eval()?

要回答这个问题,需要从代码进行分析,通过查看 PyTorch框架源代码,可以发现train()和eval()的源代码位于torch/nn/modules/module.py,属于torch.nn.Module的成员函数,而torch.nn.Module类是我们在自定义网络层和自定义模型时所继承的类,其中相关的函数定义如下:

model.train()

www.zeeklog.com  - 在PyTorch框架中model.train() 和 model.eval()的作用是什么?

model.train()方法,可以看到这里有个操作是遍历模型的子层并对其调用train()方法,这样子层比如Dropout的training也会被设置为相应的值。

model.eval()

www.zeeklog.com  - 在PyTorch框架中model.train() 和 model.eval()的作用是什么?

回答

从 torch.nn.Module(以下简称Module)的代码实现来看, model.eval() 无参,直接调用了 model.train(),而后者是有布尔参的。model.train()与model.eval()改变了model的training属性值(True/False),从而让model处于训练或验证模式。

注:model.train()与model.eval() 这两个接口都有动到了self.training这个标志位,而该标志位在Module初始化时是置为True的,因此,模型网络默认是工作在training模式下的。

2、为什么 model.train()与model.eval() 要这样写?

在 model.train() 详细注释中有这样一段描述:

This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:Dropout, :class:BatchNorm, etc.

从上述描述中可以看出,这两个接口是一个标志位开关,用于控制一些 子类(eg:Dropout 和 BatchNorm 等Ops)在训练推理时的工作状态。

3、为什么 model.train()与model.eval() 会影响到 Dropout?

Dropout 网络结构分析:

简单理解就是,Dropout在训练和验证模式下表现不同。所以training的值的改变用来通知Dropout层在训练和验证模式间切换。

如下图,训练模式下,Dropout会以一定的概率让一些神经元值为0。

换句话说,对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接。

www.zeeklog.com  - 在PyTorch框架中model.train() 和 model.eval()的作用是什么?

Dropout 代码实现分析:

def dropout(input,
            p: float,
            train: bool):
    use_cuda = input.is_cuda
    # lowering is specialized for cuda because cuda fuser can efficiently fuse those operations
    # for cpu backend, where fusions are disabled, a different lowering that is more efficient
    # in the absence of fusion is used
    p1m = 1. - p
    if train:
        if use_cuda:
            mask = torch.rand_like(input, memory_format=1) < p1m
            res = mask.type_as(input) * input * (1./p1m)
        else:
            mask = torch.empty_like(input, memory_format=1)
            mask.bernoulli_(p1m)
            res = mask * input / p1m
    else:
        p1m = 1.
        res = input
        mask = torch.empty_like(input, memory_format=1)

Dropout在训练时会按照所给概率随机丢弃神经元,可避免过拟合,但是在推理时则需要所有神经元参与,否则推理时结果不固定。

4、为什么 model.train()与model.eval() 会影响到 BatchNorm?

BatchNorm 网络结构分析:

model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;

BatchNorm 代码实现分析:

def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
    if training:
        norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
    else:
        norm_mean = torch._unwrap_optional(running_mean)
        norm_var = torch._unwrap_optional(running_var)
    norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
    norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
    norm_invstd = 1 / (torch.sqrt(norm_var + eps))
    return ((input - norm_mean) * norm_invstd)

BatchNorm同样在训练时,需要根据训练数据更新维护 norm_mean 和 norm_var 这两个变量,而在推理时则需要使用全局的mean和var,否则效果会打折扣。

由于 model.train() 和 model.eval() 都有对 model.training 产生影响,因此,我们在定义模型网络时,亦可以利用该标志位进行一些分支设计。

3、model.train()与model.eval() 会影响梯度反向传播么?

model.train()与model.eval()  这两个接口并不会影响到梯度反向传播部分,因此,当我们在推理时,最好能加上如下的类似代码,主要是“with torch.no_grad():”这个上下文管理器结构,这样可以减小内存耗用和加快推理速度。

# evaluate model
model.eval()

with torch.no_grad():
    ...
    output = model(input)
    ...

(完)

Read more

深入理解 Proxy 和 Object.defineProperty

在JavaScript中,对象是一种核心的数据结构,而对对象的操作也是开发中经常遇到的任务。在这个过程中,我们经常会使用到两个重要的特性:Proxy和Object.defineProperty。这两者都允许我们在对象上进行拦截和自定义操作,但它们在实现方式、应用场景和灵活性等方面存在一些显著的区别。本文将深入比较Proxy和Object.defineProperty,包括它们的基本概念、使用示例以及适用场景,以帮助读者更好地理解和运用这两个特性。 1. Object.defineProperty 1.1 基本概念 Object.defineProperty 是 ECMAScript 5 引入的一个方法,用于直接在对象上定义新属性或修改已有属性。它的基本语法如下: javascript 代码解读复制代码Object.defineProperty(obj, prop, descriptor); 其中,obj是目标对象,prop是要定义或修改的属性名,descriptor是一个描述符对象,用于定义属性的特性。 1.2 使用示例 javascript 代码解读复制代码//

By Ne0inhk

Proxy 和 Object.defineProperty 的区别

Proxy 和 Object.defineProperty 是 JavaScript 中两个不同的特性,它们的作用也不完全相同。 Object.defineProperty 允许你在一个对象上定义一个新属性或者修改一个已有属性。通过这个方法你可以精确地定义属性的特征,比如它是否可写、可枚举、可配置等。该方法的使用场景通常是需要在一个对象上创建一个属性,然后控制这个属性的行为。 Proxy 也可以用来代理一个对象,但是相比于 Object.defineProperty,它提供了更加强大的功能。使用 Proxy 可以截获并重定义对象的基本操作,比如访问属性、赋值、函数调用等等。在这些操作被执行之前,可以通过拦截器函数对这些操作进行拦截和修改。因此,通过 Proxy,你可以完全重写一个对象的默认行为。该方法的使用场景通常是需要对一个对象的行为进行定制化,或者需要在对象上添加额外的功能。 对比 以下是 Proxy 和 Object.defineProperty 的一些区别对比: 方面ProxyObject.defineProperty语法使用 new Proxy(target,

By Ne0inhk