[llamafactory预训练源码解析trainer]src >llamafactory >train >pt>trainer.py CustomTrainer 类通过添加自定义回调、优化器和调度器
from types import MethodType
from typing import TYPE_CHECKING, Optional
from transformers import Trainer
from ...extras.logging import get_logger
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
import torch
from transformers import ProcessorMixin
from ...hparams import FinetuningArguments
logger = get_logger(__name__)
class CustomTrainer(Trainer):
r"""
Inherits Trainer for custom optimizer.
"""
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
) -> None:
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.pissa_convert:
self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
llama factory
框架的代码逐行解释:
from types import MethodType
from typing import TYPE_CHECKING, Optional
from transformers import Trainer
from ...extras.logging import get_logger
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
- 导入需要的模块和函数:
MethodType
:用于动态绑定方法。TYPE_CHECKING
和Optional
:用于类型检查和可选类型。Trainer
:transformers
库中的训练器类。- 自定义模块的导入,用于日志记录、回调函数和自定义优化器、调度器创建函数。
python
复制
if TYPE_CHECKING:
import torch
from transformers import ProcessorMixin
from ...hparams import FinetuningArguments
- 类型检查时导入的模块和类型:
torch
:用于深度学习。ProcessorMixin
:transformers
库中的处理器混合类。FinetuningArguments
:自定义的微调参数类。
python
复制
logger = get_logger(__name__)
- 初始化日志记录器。
class CustomTrainer(Trainer):
class CustomTrainer(Trainer):
r"""
Inherits Trainer for custom optimizer.
"""
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
) -> None:
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.pissa_convert:
self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
- 定义
CustomTrainer
类,继承自Trainer
,用于自定义优化器。 - 初始化方法
__init__
: - 调用父类的初始化方法。
- 保存微调参数。
- 如果提供了处理器,添加
SaveProcessorCallback
回调。 - 如果启用了
pissa_convert
,添加PissaConvertCallback
回调。 - 如果使用 BAdam 优化器:
- 从
badam
模块导入BAdamCallback
和旧版的梯度裁剪函数clip_grad_norm_old_version
。 - 动态绑定旧版的梯度裁剪函数到
accelerator
对象。 - 添加
BAdamCallback
回调。
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if processor is not None:
这行代码检查 processor
变量是否不为 None
。如果 processor
有值(不是 None
),则执行下面的代码块。
self.add_callback(SaveProcessorCallback(processor))
这里调用 SaveProcessorCallback
类的构造函数,并传递 processor
作为参数,然后将返回的对象传递给 self.add_callback
方法。
processing_utils.py定义ProcessorMixin(PushToHubMixin)类SaveProcessorCallback(processor)
很可能是一个回调函数或回调类,用于处理与 processor
相关的保存操作。
python
复制
if finetuning_args.pissa_convert:
self.add_callback(PissaConvertCallback)
if finetuning_args.pissa_convert:
检查 finetuning_args
对象的 pissa_convert
属性是否为真。如果为真,则执行下面的代码块。
self.add_callback(PissaConvertCallback)
直接将 PissaConvertCallback
传递给 self.add_callback
方法。PissaConvertCallback
可能是一个类或函数,用于进行 PISSA 转换操作。
python
复制
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
if finetuning_args.use_badam:
检查 finetuning_args
对象的 use_badam
属性是否为真。如果为真,则执行下面的代码块。
from badam import BAdamCallback, clip_grad_norm_old_version
动态导入 badam
模块中的 BAdamCallback
和 clip_grad_norm_old_version
。这通常用来实现某些特定功能,只有在需要时才导入,避免不必要的开销。
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
这里将 clip_grad_norm_old_version
方法绑定到 self.accelerator
对象的 clip_grad_norm_
属性上。这种方式允许你动态替换或添加方法到对象实例上。
self.add_callback(BAdamCallback)
将 BAdamCallback
回调添加到 self
的回调列表中。BAdamCallback
可能是一个类或函数,用于处理 BAdam 优化器相关的逻辑。
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
- 重写
create_optimizer
方法: - 如果优化器未定义,则创建自定义优化器。
- 调用父类的
create_optimizer
方法。
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
- 重写
create_scheduler
方法: - 调用自定义调度器创建函数。
- 调用父类的
create_scheduler
方法。
继续解释 CustomTrainer
类的 create_scheduler
方法:
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
- 重写
create_scheduler
方法: create_scheduler
方法的目的是创建一个学习率调度器。- 参数:
num_training_steps
:训练步骤的数量。optimizer
(可选):优化器对象。- 调用
create_custom_scheduler
函数来创建一个自定义的学习率调度器。 - 调用父类的
create_scheduler
方法,将num_training_steps
和optimizer
传递给它。
总结:
- 这个
CustomTrainer
类扩展了transformers
库中的Trainer
类,增加了对自定义优化器和调度器的支持。 - 它还可以根据微调参数条件添加特定的回调函数,比如保存处理器和应用
PissaConvert
或BAdam
优化器相关的回调。 - 通过重写
create_optimizer
和create_scheduler
方法,这个类可以使用自定义的优化器和调度器。
完整代码:
python
复制
from types import MethodType
from typing import TYPE_CHECKING, Optional
from transformers import Trainer
from ...extras.logging import get_logger
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
import torch
from transformers import ProcessorMixin
from ...hparams import FinetuningArguments
logger = get_logger(__name__)
class CustomTrainer(Trainer):
r"""
Inherits Trainer for custom optimizer.
"""
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
) -> None:
super().__init__(**kwargs)
self.finetuning_args = finetuning_args
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.pissa_convert:
self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
这个 CustomTrainer
类通过添加自定义回调、优化器和调度器,实现了对 Trainer
的功能扩展,使其更适合特定的微调需求。
好的,我们逐行解释这段代码,并详细讨论 processor
可能是什么。
代码逐行解释
python
复制
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if processor is not None:
这行代码检查 processor
变量是否不为 None
。如果 processor
有值(不是 None
),则执行下面的代码块。
self.add_callback(SaveProcessorCallback(processor))
这里调用 SaveProcessorCallback
类的构造函数,并传递 processor
作为参数,然后将返回的对象传递给 self.add_callback
方法。SaveProcessorCallback(processor)
很可能是一个回调函数或回调类,用于处理与 processor
相关的保存操作。
python
复制
if finetuning_args.pissa_convert:
self.add_callback(PissaConvertCallback)
if finetuning_args.pissa_convert:
检查 finetuning_args
对象的 pissa_convert
属性是否为真。如果为真,则执行下面的代码块。
self.add_callback(PissaConvertCallback)
直接将 PissaConvertCallback
传递给 self.add_callback
方法。PissaConvertCallback
可能是一个类或函数,用于进行 PISSA 转换操作。
python
复制
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
if finetuning_args.use_badam:
检查 finetuning_args
对象的 use_badam
属性是否为真。如果为真,则执行下面的代码块。
from badam import BAdamCallback, clip_grad_norm_old_version
动态导入 badam
模块中的 BAdamCallback
和 clip_grad_norm_old_version
。这通常用来实现某些特定功能,只有在需要时才导入,避免不必要的开销。
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
这里将 clip_grad_norm_old_version
方法绑定到 self.accelerator
对象的 clip_grad_norm_
属性上。这种方式允许你动态替换或添加方法到对象实例上。
self.add_callback(BAdamCallback)
将 BAdamCallback
回调添加到 self
的回调列表中。BAdamCallback
可能是一个类或函数,用于处理 BAdam 优化器相关的逻辑。
猜想或解释 processor
是什么
根据代码的上下文和常见的设计模式,我们可以做出以下几个合理的猜想:
处理器对象:processor
可能是一个处理器对象,负责处理某种数据或任务。例如,在机器学习或数据处理管道中,processor
可能是一个预处理器、特征提取器或数据转换器。
模型组件:在某些深度学习或机器学习框架中,processor
可能是一个模型组件或模块,负责处理输入数据并将其传递给模型的其他部分。
回调依赖对象:processor
可能是某些回调函数或类所依赖的对象。比如,SaveProcessorCallback
可能需要 processor
来完成其特定的保存任务。
代码上下文
为了更准确地理解 processor
,了解更多上下文是非常重要的。通常在类的构造函数或初始化方法中,可以看到 processor
是如何被赋值的。以下是一个可能的类定义示例:
python
复制
class SomeClass:
def __init__(self, processor=None, finetuning_args=None):
self.processor = processor
self.finetuning_args = finetuning_args
if self.processor is not None:
self.add_callback(SaveProcessorCallback(self.processor))
if self.finetuning_args.pissa_convert:
self.add_callback(PissaConvertCallback)
if self.finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
在这个示例中,processor
和 finetuning_args
作为初始化参数传递给类的构造函数,然后在类的方法中使用。
结论
processor
可能是一个负责处理某种任务的对象或组件,它被传递给 SaveProcessorCallback
,以进行特定的保存操作。通过检查 processor
是否为 None
,可以决定是否需要添加这个回调。更详细的信息需要查看类的完整定义以及 processor
是如何被初始化和使用的。