LLaMA-Factory/src/llamafactory/hparams/parser.py源码解析
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return _parse_args(parser, args)
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return _parse_args(parser, args)
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return _parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
# Setup logging
if training_args.should_log:
_set_transformers_logging()
# Check arguments
if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
逐行解释这段代码:
_parse_train_args
函数
python
复制
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return _parse_args(parser, args)
- 定义了一个名为
_parse_train_args
的函数,用于解析训练参数。 - 创建一个
HfArgumentParser
对象,传入_TRAIN_ARGS
,它可能是一个包含训练相关参数定义的类或类列表。 - 调用私有的
_parse_args
函数,传入解析器和可选的参数字典args
,返回解析后的训练参数对象。
_parse_infer_args
函数
python
复制
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return _parse_args(parser, args)
- 定义了一个名为
_parse_infer_args
的函数,用于解析推理参数。 - 创建一个
HfArgumentParser
对象,传入_INFER_ARGS
,它可能是一个包含推理相关参数定义的类或类列表。 - 调用私有的
_parse_args
函数,传入解析器和可选的参数字典args
,返回解析后的推理参数对象。
_parse_eval_args
函数
python
复制
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return _parse_args(parser, args)
- 定义了一个名为
_parse_eval_args
的函数,用于解析评估参数。 - 创建一个
HfArgumentParser
对象,传入_EVAL_ARGS
,它可能是一个包含评估相关参数定义的类或类列表。 - 调用私有的
_parse_args
函数,传入解析器和可选的参数字典args
,返回解析后的评估参数对象。
get_train_args
函数
python
复制
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
- 定义了一个名为
get_train_args
的函数,用于获取训练参数。 - 调用
_parse_train_args
函数,解析传入的参数字典args
,并将其拆解为model_args
、data_args
、training_args
、finetuning_args
和generating_args
。
˙˙˙˙python
复制
# Setup logging
if training_args.should_log:
_set_transformers_logging()
- 如果
training_args
中的should_log
为真,调用_set_transformers_logging
函数设置日志记录。
python
复制
# Check arguments
if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")
- 检查参数的有效性:
- 如果微调阶段(
stage
)不是 "pt"(可能表示预训练),并且data_args
中没有指定模板(template
),则抛出ValueError
。
python
复制
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
- 如果微调阶段不是 "sft"(可能表示特定的微调阶段),并且
training_args
中的predict_with_generate
被设置为真,则抛出ValueError
。
python
复制
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
- 如果微调阶段是 "sft",并且
training_args
中的do_predict
为真,但predict_with_generate
为假,则抛出ValueError
,要求启用predict_with_generate
以保存模型预测。
总结
这段代码定义了一些函数,用于解析不同任务(训练、推理、评估)的参数,并进行必要的参数检查。它们使用了 HfArgumentParser
来解析参数定义,并通过 _parse_args
函数来实际解析传入的参数字典。get_train_args
函数还包含了对