llamafactory/hparams/finetuning_args.py 【源码解析】
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
pure_bf16: bool = field(
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)
finetuning_type: Literal["lora", "freeze", "full"] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."},
)
use_llama_pro: bool = field(
default=False,
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
)
freeze_vision_tower: bool = field(
default=True,
metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
)
train_mm_proj_only: bool = field(
default=False,
metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
)
plot_loss: bool = field(
default=False,
metadata={"help": "Whether or not to save the training loss curves."},
)
我们逐行解释这段代码。
数据类定义和继承
python
复制
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
- 使用
@dataclass
装饰器定义一个名为FinetuningArguments
的数据类。 - 该类继承自多个基类:
FreezeArguments
、LoraArguments
、RLHFArguments
、GaloreArguments
和BAdamArgument
。 - 文档字符串说明:该类包含与微调技术相关的参数。
pure_bf16
字段
python
复制
pure_bf16: bool = field(
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
- 定义一个名为
pure_bf16
的字段,类型为bool
。 - 默认值为
False
。 metadata
字典包含一个帮助信息,说明该字段表示是否使用纯 bf16 精度(不使用 AMP)进行模型训练。
stage
字段
python
复制
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)
- 定义一个名为
stage
的字段,类型为Literal
,可取值为"pt"
、"sft"
、"rm"
、"ppo"
、"dpo"
或"kto"
。 - 默认值为
"sft"
。 metadata
字典包含一个帮助信息,说明该字段表示训练过程中执行的阶段。
finetuning_type
字段
python
复制
finetuning_type: Literal["lora", "freeze", "full"] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."},
)
- 定义一个名为
finetuning_type
的字段,类型为Literal
,可取值为"lora"
、"freeze"
或"full"
。 - 默认值为
"lora"
。 metadata
字典包含一个帮助信息,说明该字段表示使用的微调方法。
use_llama_pro
字段
python
复制
use_llama_pro: bool = field(
default=False,
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
)
- 定义一个名为
use_llama_pro
的字段,类型为bool
。 - 默认值为
False
。 metadata
字典包含一个帮助信息,说明该字段表示是否仅使扩展块中的参数可训练。
freeze_vision_tower
字段
python
复制
freeze_vision_tower: bool = field(
default=True,
metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
)
- 定义一个名为
freeze_vision_tower
的字段,类型为bool
。 - 默认值为
True
。 metadata
字典包含一个帮助信息,说明该字段表示在 MLLM 训练中是否冻结视觉塔(vision tower)。
train_mm_proj_only
字段
python
复制
train_mm_proj_only: bool = field(
default=False,
metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
)
- 定义一个名为
train_mm_proj_only
的字段,类型为bool
。 - 默认值为
False
。 metadata
字典包含一个帮助信息,说明该字段表示是否仅训练用于 MLLM 的多模态投影器。
plot_loss
字段
python
复制
plot_loss: bool = field(
default=False,
metadata={"help": "Whether or not to save the training loss curves."},
)
- 定义一个名为
plot_loss
的字段,类型为bool
。 - 默认值为
False
。 metadata
字典包含一个帮助信息,说明该字段表示是否保存训练损失曲线。
通过以上逐行解释,我们详细了解了 FinetuningArguments
数据类的定义和每个字段的具体含义。