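Background and Motivation
LLaMA Factory has become one of the most popular LLM fine-tuning tools in China thanks to its concise API and its wide range of training paradigms (continued pre-training, instruction tuning, DPO/ORPO, and more). It is built on the Hugging Face Transformers ecosystem and performs very well in single-node, multi-GPU setups.
However, for models with more than 10 billion parameters, or for large multi-node distributed training jobs, conventional data-parallel approaches such as ZeRO or FSDP often hit communication bottlenecks and use GPU memory inefficiently. Pairing LLaMA Factory's data-processing capabilities with NVIDIA Megatron-LM, a high-performance distributed framework built specifically for very large models, can substantially improve training throughput and scalability.
This article describes a practical way to get there: through the MCoreAdapter bridging layer, LLaMA Factory's ecosystem drives Megatron-LM to run distributed, accelerated training for mainstream fine-tuning tasks such as SFT, DPO, and ORPO.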
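Technical Principles
Megatron-LM is NVIDIA's open-source distributed training framework for very large Transformer models. Built on PyTorch and deeply optimized for NVIDIA GPU architectures, it implements highly engineered parallelism strategies and is widely used to pre-train and fine-tune models with hundreds of billions of parameters.
Its core strength is heavily optimized multi-dimensional model parallelism:
- Flexible parallelism combinations: it natively supports tensor parallelism (TP), pipeline parallelism (PP), sequence parallelism (SP), context parallelism (CP), and expert parallelism (EP), which can be combined as needed (e.g., DP + TP + PP) to fit different model sizes and hardware configurations.
- Low communication overhead: overlapping computation with communication, together with customized NCCL tuning, lets it keep training efficiency high even on clusters with thousands of GPUs.
- Built for very large models: ZeRO-3 and FSDP shard parameters, gradients, and optimizer states across data-parallel ranks, but still gather each layer's full weights before computing with it. Megatron-LM instead partitions the model's layers themselves across devices, so most computation runs on local shards and cross-device synchronization drops sharply. This makes it especially well suited to efficiently training models larger than 100B parameters.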
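To make the "partition the layers themselves" point concrete, here is a toy, pure-Python sketch of Megatron-style tensor parallelism for a two-layer MLP Z = ReLU(XA)B (ReLU stands in for GeLU; any elementwise activation works). The first weight is split by columns and the second by rows, so each simulated rank computes its share locally and only a final sum, standing in for the all-reduce, combines them. It also shows the basic rank arithmetic, assuming the usual Megatron-Core relation world_size = DP × TP × PP × CP (EP is left out for simplicity). All function names are illustrative, not Megatron-LM APIs.

```python
from typing import List

Matrix = List[List[int]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    # (n x k) @ (k x m) -> (n x m)
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]


def relu(x: Matrix) -> Matrix:
    return [[max(v, 0) for v in row] for row in x]


def mlp_single_device(x: Matrix, a: Matrix, b: Matrix) -> Matrix:
    # Reference result: Z = relu(X @ A) @ B computed on one device
    return matmul(relu(matmul(x, a)), b)


def mlp_tensor_parallel(x: Matrix, a: Matrix, b: Matrix, tp: int) -> Matrix:
    hidden = len(a[0])
    if hidden % tp:
        raise ValueError("hidden size must be divisible by the TP size")
    shard = hidden // tp
    partials = []
    for rank in range(tp):
        cols = slice(rank * shard, (rank + 1) * shard)
        a_shard = [row[cols] for row in a]   # column-parallel: this rank owns a slice of A's columns
        b_shard = b[cols]                    # row-parallel: this rank owns the matching rows of B
        h = relu(matmul(x, a_shard))         # elementwise activation needs no communication
        partials.append(matmul(h, b_shard))  # partial Z, computed entirely locally
    # The elementwise sum stands in for the all-reduce across TP ranks
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


def data_parallel_size(world_size: int, tp: int = 1, pp: int = 1, cp: int = 1) -> int:
    # Megatron-Core relation: world_size = DP * TP * PP * CP
    model_parallel = tp * pp * cp
    if world_size % model_parallel:
        raise ValueError(f"world_size {world_size} is not divisible by TP*PP*CP = {model_parallel}")
    return world_size // model_parallel


if __name__ == "__main__":
    x = [[1, -2, 3], [0, 1, -1]]
    a = [[1, 0, -1, 2], [0, 1, 1, -1], [2, -1, 0, 1]]
    b = [[1, 2], [0, -1], [3, 1], [-2, 0]]
    for tp in (1, 2, 4):
        assert mlp_tensor_parallel(x, a, b, tp) == mlp_single_device(x, a, b)
    print(data_parallel_size(64, tp=8, pp=2))  # 64 GPUs with TP=8, PP=2 -> DP=4
```

The only cross-rank step is the final reduction, which is why TP keeps per-layer communication to one all-reduce per MLP block in the forward pass instead of gathering full weights.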
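MCoreAdapter is a lightweight bridging toolkit open-sourced by Alibaba, originally developed as the Megatron integration component of the reinforcement learning framework ROLL. It combines Megatron-LM's distributed training capabilities with a Hugging Face Transformers-style interface and plugs directly into LLaMA Factory's data pipeline and training configuration system, so users keep LLaMA Factory's convenience while unlocking Megatron-LM's performance in large-scale distributed environments.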
Quick Start
Environment Setup
Prepare the following base environment in advance:
- CUDA >= 12.4
- cuDNN >= 9.1.0
- PyTorch >= 2.5.1
- SGLang >= 0.4.3
- vLLM >= 0.7.3
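A quick way to check an existing environment against these minimums is sketched below. It assumes the PyPI distribution names torch, sglang and vllm, reads the CUDA version that PyTorch was built with from torch.version.cuda, and does not check cuDNN. Since each of the example images below ships either SGLang or vLLM, one of the two may show up as missing.

```python
import re
from importlib import metadata
from typing import Dict, Optional, Tuple

# Minimum versions from the list above; "cuda" is compared against torch.version.cuda
MINIMUMS = {"torch": "2.5.1", "sglang": "0.4.3", "vllm": "0.7.3", "cuda": "12.4"}


def parse_version(text: str) -> Tuple[int, ...]:
    # Keep the leading numeric release: "2.6.0+cu124" -> (2, 6, 0), "0.4.6.post1" -> (0, 4, 6)
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that "12.4" compares equal to "12.4.0"
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def installed_versions() -> Dict[str, Optional[str]]:
    found: Dict[str, Optional[str]] = {}
    for name in ("torch", "sglang", "vllm"):
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    try:
        import torch

        found["cuda"] = torch.version.cuda  # None for CPU-only builds
    except ImportError:
        found["cuda"] = None
    return found


def check(found: Dict[str, Optional[str]]) -> Dict[str, str]:
    report = {}
    for name, minimum in MINIMUMS.items():
        version = found.get(name)
        if version is None:
            report[name] = "missing"
        elif meets_minimum(version, minimum):
            report[name] = f"ok ({version})"
        else:
            report[name] = f"too old ({version} < {minimum})"
    return report


if __name__ == "__main__":
    for name, status in check(installed_versions()).items():
        print(f"{name:>7}: {status}")
```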
You can install these with conda, or start from the official images provided by the ROLL community, for example:
- torch 2.6.0 + SGLang 0.4.6: roll-registry.cn-hangzhou.cr.aliyuncs.com/roll/pytorch:nvcr-24.05-py3-torch260-sglang046
- torch 2.6.0 + vLLM 0.8.4: roll-registry.cn-hangzhou.cr.aliyuncs.com/roll/pytorch:nvcr-24.05-py3-torch260-vllm084
On top of the base environment, install LLaMA-Factory and MCoreAdapter:
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git && \
cd LLaMA-Factory && \
pip install -e ".[torch,metrics]" && \
pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"
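After installation, a lightweight smoke test is to confirm that the interpreter can locate the new packages. The sketch below uses importlib.util.find_spec, so it finds modules without actually importing them (no CUDA initialization). The module list is an assumption: it presumes MCoreAdapter is importable as mcore_adapter (as the script below does) and that it runs on top of the megatron-core package, importable as megatron.core.

```python
import importlib.util
from typing import Dict, Iterable

# Assumed module names; adjust to your environment
REQUIRED = ("llamafactory", "mcore_adapter", "megatron.core", "transformers")


def module_status(modules: Iterable[str]) -> Dict[str, bool]:
    status = {}
    for name in modules:
        try:
            status[name] = importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:
            # Raised for dotted names whose parent package (e.g. "megatron") is absent
            status[name] = False
    return status


if __name__ == "__main__":
    for name, ok in module_status(REQUIRED).items():
        print(f"{name:<15} {'ok' if ok else 'MISSING'}")
```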
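Writing the Training Launch Script
Create a run_train.py script that bridges LLaMA Factory's data and argument handling to MCoreAdapter's Megatron trainers. It is adapted from the example in MCoreAdapter, with fixes for several argument conflicts. The script dispatches on the training stage (pt, sft, or dpo) and falls back to plain LLaMA Factory training when enable_mca is set to false.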
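import functools
import hashlib
import os
from copy import deepcopy
from dataclasses import dataclass, field, fields
from typing import Any, Dict, Optional, Sequence, Tuple

import torch
from filelock import FileLock
from huggingface_hub import snapshot_download
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.data.collator import PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
from llamafactory.hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
from llamafactory.model import load_tokenizer
from llamafactory.train.callbacks import SaveProcessorCallback
from llamafactory.train.dpo import run_dpo
from llamafactory.train.pt import run_pt
from llamafactory.train.sft import run_sft
from transformers import DataCollatorForSeq2Seq, HfArgumentParser
from transformers.trainer_callback import TrainerCallback

from mcore_adapter.models import AutoConfig, AutoModel
from mcore_adapter.trainer import DPOTrainer, McaTrainer
from mcore_adapter.trainer.dpo_config import DPOConfig
from mcore_adapter.training_args import Seq2SeqTrainingArguments

IGNORE_INDEX = -100  # label id ignored by the loss
PAD_TO_MULTIPLE_OF = 64  # assumed value: the original number was lost in extraction


def download_model(model_name_or_path: str, local_dir: Optional[str] = None) -> str:
    """Return a local checkpoint path, downloading it first if necessary."""
    if os.path.isdir(model_name_or_path):
        return model_name_or_path
    use_model_scope = os.getenv("USE_MODELSCOPE_HUB", "0") == "1"
    # One lock per model name so that ranks on the same node do not download concurrently.
    # The lock directory is an assumption; the original path was lost in extraction.
    lock_dir = os.path.expanduser("~/.cache/mcore_adapter/locks")
    os.makedirs(lock_dir, exist_ok=True)
    temp_lock_path = os.path.join(
        lock_dir,
        f"{hashlib.md5(model_name_or_path.encode()).hexdigest()}.lock",
    )
    with FileLock(temp_lock_path):
        if use_model_scope:
            from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download

            return ms_snapshot_download(model_name_or_path, local_dir=local_dir)
        return snapshot_download(model_name_or_path, local_dir=local_dir)


class ProfCallback(TrainerCallback):
    """Optional helper: advance a torch.profiler schedule once per training step."""

    def __init__(self, prof: torch.profiler.profile):
        self.prof = prof

    def on_step_end(self, args, state, control, **kwargs):
        self.prof.step()


@dataclass
class UseMcaArguments:
    enable_mca: bool = field(
        default=True,
        metadata={"help": "Whether to train with Megatron-Core through MCoreAdapter."},
    )


def get_args() -> Tuple[
    Seq2SeqTrainingArguments, ModelArguments, DataArguments,
    FinetuningArguments, GeneratingArguments, UseMcaArguments,
]:
    parser = HfArgumentParser((
        Seq2SeqTrainingArguments, ModelArguments, DataArguments,
        FinetuningArguments, GeneratingArguments, UseMcaArguments,
    ))
    (training_args, model_args, data_args,
     finetuning_args, generating_args, use_mca_args) = parser.parse_args_into_dataclasses()
    if not use_mca_args.enable_mca:
        from transformers import Seq2SeqTrainingArguments as HFSeq2SeqTrainingArguments

        # Keep only fields the Hugging Face class accepts in __init__: Megatron-specific
        # options and init=False attributes (such as _n_gpu) would raise a TypeError.
        hf_fields = {f.name for f in fields(HFSeq2SeqTrainingArguments) if f.init}
        training_args = HFSeq2SeqTrainingArguments(
            **{k: v for k, v in vars(training_args).items() if k in hf_fields}
        )
    model_args.model_name_or_path = download_model(model_args.model_name_or_path)
    return training_args, model_args, data_args, finetuning_args, generating_args, use_mca_args


def data_collator_wrapper(data_collator):
    """Pre-shift each sample for next-token prediction before collating.

    Megatron-Core computes the loss on labels as given (it does not shift them like
    Hugging Face causal-LM models), so drop the last input token and the first label.
    """

    @functools.wraps(data_collator)
    def wrapper(features: Sequence[Dict[str, Any]]):
        labels_key = [k for k in features[0].keys() if k.endswith("labels")]
        input_ids_key = [k for k in features[0].keys() if k.endswith("input_ids")]
        for feature in features:
            if len(labels_key) == 0:  # pre-training: the next input token is the label
                feature["labels"] = deepcopy(feature["input_ids"])[1:]
            for k in labels_key:
                feature[k] = feature[k][1:]
            for k in input_ids_key:
                feature[k] = feature[k][:-1]
            # Trim masks/positions to match; the suffix test also covers pairwise keys
            # such as chosen_attention_mask and rejected_attention_mask.
            for k in list(feature.keys()):
                if k.endswith(("attention_mask", "position_ids")):
                    feature[k] = feature[k][:-1]
        return data_collator(features)

    return wrapper


def pt_mca_train(training_args, model_args, data_args, finetuning_args):
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
    # Tokenize one extra token; data_collator_wrapper drops it again when shifting.
    data_args.cutoff_len += 1
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
    data_args.cutoff_len -= 1
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer, pad_to_multiple_of=PAD_TO_MULTIPLE_OF, label_pad_token_id=IGNORE_INDEX,
    )
    data_collator = data_collator_wrapper(data_collator)
    trainer = McaTrainer(
        model=model, args=training_args, tokenizer=tokenizer,
        data_collator=data_collator, **dataset_module,
    )
    if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
        trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))
    trainer.train(training_args.resume_from_checkpoint)


def sft_mca_train(training_args, model_args, data_args, finetuning_args):
    # LLaMA Factory's neat_packing and MCoreAdapter's sequence_packing must agree.
    data_args.neat_packing = training_args.sequence_packing = data_args.neat_packing or training_args.sequence_packing
    data_args.packing = data_args.neat_packing or data_args.packing
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
    data_args.cutoff_len += 1
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
    data_args.cutoff_len -= 1
    # Fixed-length batches when expert parallelism is on and variable sequence lengths are off.
    pad_to_max = (
        training_args.expert_model_parallel_size is not None
        and training_args.expert_model_parallel_size > 1
        and not training_args.variable_seq_lengths
    )
    data_collator = SFTDataCollatorWith4DAttentionMask(
        template=template,
        padding="max_length" if pad_to_max else "longest",  # max_length is ignored unless padding="max_length"
        max_length=data_args.cutoff_len if pad_to_max else None,
        pad_to_multiple_of=PAD_TO_MULTIPLE_OF, label_pad_token_id=IGNORE_INDEX,
        **tokenizer_module,
    )
    data_collator = data_collator_wrapper(data_collator)
    trainer = McaTrainer(
        model=model, args=training_args, tokenizer=tokenizer,
        data_collator=data_collator, **dataset_module,
    )
    if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
        trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))
    trainer.train(training_args.resume_from_checkpoint)


def dpo_mca_train(training_args, model_args, data_args, finetuning_args):
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
    if finetuning_args.use_ref_model:
        # Reference model initialized with the policy's weights (not used by ORPO/SimPO).
        ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
        ref_model = AutoModel.from_config(ref_config)
        ref_model.load_state_dict(model.state_dict())
    else:
        ref_model = None
    data_args.cutoff_len += 1
    # LLaMA Factory loads pairwise (chosen/rejected) data under the "rm" stage.
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
    data_args.cutoff_len -= 1
    pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
    dpo_config = DPOConfig(
        beta=finetuning_args.pref_beta, pref_loss=finetuning_args.pref_loss,
        label_smoothing=finetuning_args.dpo_label_smoothing,
    )
    data_collator = PairwiseDataCollatorWithPadding(
        template=template,
        padding="max_length" if pad_to_max else "longest",
        max_length=data_args.cutoff_len if pad_to_max else None,
        pad_to_multiple_of=PAD_TO_MULTIPLE_OF, label_pad_token_id=IGNORE_INDEX,
        **tokenizer_module,
    )
    data_collator = data_collator_wrapper(data_collator)
    trainer = DPOTrainer(
        model=model, ref_model=ref_model, args=training_args, train_config=dpo_config,
        tokenizer=tokenizer, data_collator=data_collator, **dataset_module,
    )
    if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
        trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))
    trainer.train(training_args.resume_from_checkpoint)


def mca_train(training_args, model_args, data_args, finetuning_args):
    if finetuning_args.stage == "pt":
        pt_mca_train(training_args, model_args, data_args, finetuning_args)
    elif finetuning_args.stage == "sft":
        sft_mca_train(training_args, model_args, data_args, finetuning_args)
    elif finetuning_args.stage == "dpo":
        dpo_mca_train(training_args, model_args, data_args, finetuning_args)
    else:
        raise ValueError(f"Unknown task: {finetuning_args.stage}.")


def llama_factory_train(training_args, model_args, data_args, finetuning_args, generating_args):
    # Same token budget as the MCA path, where cutoff_len + 1 tokens yield cutoff_len predictions.
    data_args.cutoff_len += 1
    callbacks = None
    if finetuning_args.stage == "pt":
        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "sft":
        run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif finetuning_args.stage == "dpo":
        run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
    else:
        raise ValueError(f"Unknown task: {finetuning_args.stage}.")


def main():
    training_args, model_args, data_args, finetuning_args, generating_args, use_mca_args = get_args()
    model_args.model_max_length = data_args.cutoff_len
    model_args.block_diag_attn = data_args.neat_packing
    # Unless set explicitly, pack sequences for pre-training only.
    data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
    if use_mca_args.enable_mca:
        mca_train(training_args, model_args, data_args, finetuning_args)
    else:
        llama_factory_train(training_args, model_args, data_args, finetuning_args, generating_args)


if __name__ == "__main__":
    main()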
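The least obvious part of the script is data_collator_wrapper. Every sample is tokenized with cutoff_len + 1 tokens; the wrapper then drops the last input token and the first label, so the label at position i is the token at position i + 1. This pre-shift is needed because Megatron-Core's GPT model computes the loss against the labels as given, whereas Hugging Face causal-LM models shift them internally. The simplified single-sample version below uses a hypothetical function name and, unlike the wrapper, returns a copy instead of mutating its input; -100 marks prompt tokens excluded from the loss.

```python
from copy import deepcopy
from typing import Dict, List

Feature = Dict[str, List[int]]


def shift_for_next_token(feature: Feature) -> Feature:
    """Hypothetical single-sample version of data_collator_wrapper's shift (returns a copy)."""
    out = deepcopy(feature)
    if "labels" in out:
        out["labels"] = out["labels"][1:]     # label at position i becomes token i + 1
    else:
        out["labels"] = out["input_ids"][1:]  # pre-training: every next token is a target
    out["input_ids"] = out["input_ids"][:-1]  # the last token has nothing left to predict
    if "attention_mask" in out:
        out["attention_mask"] = out["attention_mask"][:-1]
    return out


if __name__ == "__main__":
    # SFT sample tokenized with cutoff_len + 1 = 5 tokens; the first two are prompt tokens
    sample = {"input_ids": [10, 11, 12, 13, 14], "labels": [-100, -100, 12, 13, 14], "attention_mask": [1] * 5}
    print(shift_for_next_token(sample))
    # {'input_ids': [10, 11, 12, 13], 'labels': [-100, 12, 13, 14], 'attention_mask': [1, 1, 1, 1]}
```

After the shift the sample is exactly cutoff_len = 4 tokens long, which is why the script raises cutoff_len by one before calling get_dataset and restores it afterwards.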

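The collators also determine how long each batch ends up. Under Hugging Face padding semantics, a batch is padded to its longest sample by default, or to max_length when padding="max_length" is requested, and the result is then rounded up to pad_to_multiple_of. In the script, pad_to_max is true only when expert parallelism is on (EP > 1) and, for SFT, variable_seq_lengths is off; in that case cutoff_len is the target length. The two helpers below are hypothetical and only model this arithmetic; 64 in the example is an illustrative multiple.

```python
from typing import Optional, Sequence


def should_pad_to_max(expert_model_parallel_size: Optional[int], variable_seq_lengths: bool = False) -> bool:
    # Same condition as pad_to_max in sft_mca_train (dpo_mca_train omits the variable_seq_lengths check)
    return (
        expert_model_parallel_size is not None
        and expert_model_parallel_size > 1
        and not variable_seq_lengths
    )


def padded_length(lengths: Sequence[int], cutoff_len: int, pad_to_max: bool, multiple: Optional[int]) -> int:
    # Longest sample in the batch, or cutoff_len for fixed-length batches
    target = cutoff_len if pad_to_max else max(lengths)
    if multiple:
        target = -(-target // multiple) * multiple  # round up to the next multiple
    return target


if __name__ == "__main__":
    lengths = [100, 130]  # token counts after the shift
    print(padded_length(lengths, cutoff_len=1024, pad_to_max=should_pad_to_max(1), multiple=64))  # 192
    print(padded_length(lengths, cutoff_len=1024, pad_to_max=should_pad_to_max(8), multiple=64))  # 1024
```

Fixed-length padding trades wasted compute on pad tokens for identical tensor shapes on every rank and micro-batch.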

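For preference training, the DPO branch forwards pref_beta, pref_loss and dpo_label_smoothing to DPOConfig and builds a reference model only when use_ref_model is set. The sketch below assumes the standard sigmoid DPO loss with conservative label smoothing, the formulation used by Hugging Face TRL; MCoreAdapter's DPOTrainer may differ in details. needs_reference_model is a hypothetical helper reflecting LLaMA Factory's rule that reference-free losses such as ORPO and SimPO skip the reference model.

```python
import math


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x))
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_sigmoid_loss(
    policy_chosen: float, policy_rejected: float,
    ref_chosen: float, ref_rejected: float,
    beta: float = 0.1, label_smoothing: float = 0.0,
) -> float:
    """Sigmoid DPO loss for one preference pair, from summed response log-probabilities.

    label_smoothing > 0 gives the conservative variant that assumes a fraction of
    the preference labels are flipped.
    """
    margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    return (
        -(1 - label_smoothing) * log_sigmoid(beta * margin)
        - label_smoothing * log_sigmoid(-beta * margin)
    )


def needs_reference_model(stage: str, pref_loss: str) -> bool:
    # Hypothetical helper: reference-free losses such as ORPO and SimPO skip the reference copy
    return stage == "dpo" and pref_loss not in ("orpo", "simpo")


if __name__ == "__main__":
    # The policy prefers the chosen answer more than the reference does, so the loss is below log(2)
    print(dpo_sigmoid_loss(-12.0, -15.0, -13.0, -14.0, beta=0.1))
    print(needs_reference_model("dpo", "sigmoid"), needs_reference_model("dpo", "orpo"))  # True False
```

A larger beta sharpens the penalty on the log-ratio margin, while label smoothing keeps the loss from driving the margin to infinity on noisy preference pairs.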