Easy R1 训练环境搭建与配置实战
Easy R1 框架在 H800 双卡环境下基于 GRPO 算法的训练环境搭建全流程。涵盖虚拟环境依赖版本清单、PyTorch 与 FlashAttention 安装命令、vLLM 0.15.1 API 兼容补丁、渗透测试奖励函数实现、训练配置文件优化及启动脚本编写。解决了依赖冲突与接口变更问题,提供了可直接复用的配置方案,适用于大模型强化学习场景。
表中的依赖和版本经过验证可以正常进行训练,vllm 0.15.1 由于 API 变动导致训练的问题已在本文后续部分给出兼容代码。如果安装过程中出现部分依赖兼容问题,可以查阅此表。
| Package | Version | 说明 |
|---|---|---|
| accelerate | 1.12.0 | Hugging Face 的模型训练和推理加速库 |
| aiohappyeyeballs | 2.6.1 | 异步 DNS 解析库 |
| aiohttp | 3.13.3 | 异步 HTTP 客户端/服务器框架 |
| ... | ... | ... |
| torch | 2.9.1+cu128 | PyTorch 深度学习框架 (CUDA 12.8) |
| transformers | 4.56.2 | Hugging Face 预训练模型库 |
| vllm | 0.15.1 | 大语言模型推理和服务引擎 |
| verl | 0.3.3.dev0 | RL 框架 |
环境特征总结:
# Create an isolated conda environment (Python 3.12)
conda create -n easy-r1 python=3.12 -y
# Install PyTorch 2.9.1 built against CUDA 12.8
conda activate easy-r1
pip install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128
# Verify the install: print torch/CUDA versions and the detected GPU
python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA: {torch.version.cuda}'); print(f'GPU: {torch.cuda.get_device_name(0)}')"
apt update
# Fetch the EasyR1 project
apt install git
git clone https://github.com/hiyouga/EasyR1.git
cd EasyR1
# Flash Attention compilation needs packaging and ninja
pip install packaging ninja psutil
# Install Flash Attention 2 (--no-build-isolation so the build sees the installed torch)
pip install flash-attn==2.8.3 --no-build-isolation
# Verify flash_attn imports and report its version
python -c "import flash_attn; print(flash_attn.__version__)"
# Install the project's own dependencies in editable mode
pip install -e .
pip install swanlab
swanlab login
创建 test_flash_attn.py:
"""Smoke test: confirm torch and flash_attn import cleanly and report versions."""
import torch
import flash_attn

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"Flash Attention: {flash_attn.__version__}")
# Confirm the main kernel entry point is importable as well
from flash_attn import flash_attn_func
print("Flash Attention 导入成功!")
运行验证:
# Run the smoke test created above
python test_flash_attn.py
/xxx/EasyR1/verl/utils/vllm_utils.py 使用的 vllm api 过时,这里提供一个兼容 vllm 0.15.1 版本的 vllm_utils.py 源码内容。
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from importlib.metadata import version
from typing import List
from msgspec import field
from packaging import version as vs
from vllm.lora.lora_model import LoRAModel
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
class TensorLoRARequest(LoRARequest):
    # LoRARequest is a msgspec Struct, so extra fields use msgspec's `field`.
    # Both default to None for ordinary path-based requests.
    # peft_config: the HF PEFT adapter config as a plain dict.
    # lora_tensors: adapter weights carried in memory — presumably a
    # name -> tensor mapping (forwarded as `tensors=` below); confirm against caller.
    peft_config: dict = field(default=None)
    lora_tensors: dict = field(default=None)
class VLLMHijack:
    """Monkey-patches vLLM's LoRA loading so adapters can be supplied as
    in-memory tensors (TensorLoRARequest), not only from disk, and so the
    call works across vLLM releases whose keyword arguments differ."""

    @staticmethod
    def hijack():
        def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel:
            # ... (abridged in the article for brevity; keep the full implementation in real use)
            # Expand packed module names (e.g. fused projections) into the
            # concrete LoRA target modules vLLM expects.
            supported_lora_modules = self._adapter_manager.supported_lora_modules
            packed_modules_mapping = self._adapter_manager.packed_modules_mapping
            expected_lora_modules: List[str] = []
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)
            expected_lora_modules = list(set(expected_lora_modules))
            lora_tensors = None
            from vllm.lora.peft_helper import PEFTHelper
            if isinstance(lora_request, TensorLoRARequest):
                # In-memory path: config and weights travel with the request.
                peft_config = lora_request.peft_config
                lora_tensors = lora_request.lora_tensors
                peft_helper = PEFTHelper.from_dict(peft_config)
            else:
                # Disk path: resolve the adapter directory as stock vLLM does.
                lora_path = get_adapter_absolute_path(lora_request.lora_path)
                peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings)
            # Validates the LoRA configuration against requirements before loading weights
            peft_helper.validate_legal(self.lora_config)
            model = self._adapter_manager.model
            hf_to_vllm_mapper = None
            if hasattr(model, "hf_to_vllm_mapper") and model.hf_to_vllm_mapper is not None:
                hf_to_vllm_mapper = model.hf_to_vllm_mapper
            # vLLM 0.15.1+ compatibility: use getattr for embedding_modules and embedding_padding_modules
            embedding_modules = getattr(self, 'embedding_modules', None)
            embedding_padding_modules = getattr(self, 'embedding_padding_modules', None)
            # Build kwargs dynamically based on vLLM version
            kwargs = {'lora_model_id': lora_request.lora_int_id, 'device': "cpu", 'dtype': self.lora_config.lora_dtype, }
            import inspect
            # Inspect the installed vLLM's signature so only keyword arguments
            # it actually accepts are forwarded (the API changed across releases).
            sig = inspect.signature(self._lora_model_cls.from_lora_tensors)
            sig_params = list(sig.parameters.keys())
            if 'tensors' in sig_params:
                kwargs['tensors'] = lora_tensors
            if 'peft_helper' in sig_params:
                kwargs['peft_helper'] = peft_helper
            if 'target_embedding_padding' in sig_params:
                kwargs['target_embedding_padding'] = self.vocab_size + getattr(self.lora_config, 'lora_extra_vocab_size', 0)
            if 'embedding_modules' in sig_params:
                kwargs['embedding_modules'] = embedding_modules
            if 'embedding_padding_modules' in sig_params:
                kwargs['embedding_padding_modules'] = embedding_padding_modules
            if 'weights_mapper' in sig_params:
                kwargs['weights_mapper'] = hf_to_vllm_mapper
            if isinstance(lora_request, TensorLoRARequest):
                lora = self._lora_model_cls.from_lora_tensors(**kwargs)
            else:
                # Same signature-filtering dance for the on-disk constructor.
                local_kwargs = {'lora_path': lora_path, 'expected_lora_modules': expected_lora_modules, 'peft_helper': peft_helper, 'lora_model_id': lora_request.lora_int_id, 'device': "cpu", 'dtype': self.lora_config.lora_dtype, }
                local_sig = inspect.signature(self._lora_model_cls.from_local_checkpoint)
                local_sig_params = list(local_sig.parameters.keys())
                if 'target_embedding_padding' in local_sig_params:
                    local_kwargs['target_embedding_padding'] = self.vocab_size + getattr(self.lora_config, 'lora_extra_vocab_size', 0)
                if 'embedding_modules' in local_sig_params:
                    local_kwargs['embedding_modules'] = embedding_modules
                if 'embedding_padding_modules' in local_sig_params:
                    local_kwargs['embedding_padding_modules'] = embedding_padding_modules
                if 'weights_mapper' in local_sig_params:
                    local_kwargs['weights_mapper'] = hf_to_vllm_mapper
                lora = self._lora_model_cls.from_local_checkpoint(**local_kwargs)
            # Guard: the adapter's extra vocabulary must fit in what the engine reserved.
            lora_extra_vocab_size = getattr(lora, 'extra_vocab_size', 0)
            config_extra_vocab_size = getattr(self.lora_config, 'lora_extra_vocab_size', 0)
            if lora_extra_vocab_size > config_extra_vocab_size:
                raise ValueError(
                    f"LoRA added vocab size {lora_extra_vocab_size} "
                    f"is greater than lora_extra_vocab_size "
                    f"{config_extra_vocab_size}."
                )
            return lora

        # Swap the manager's loader for the version-tolerant implementation above.
        setattr(LRUCacheWorkerLoRAManager, "_load_adapter", hijack__load_adapter)
# Version-gated patch: only applied on vLLM 0.11.0 — presumably the release
# where Qwen3-VL lacked a usable get_mm_mapping for LoRA; confirm against
# the installed vLLM before relying on this branch.
if vs.parse(version("vllm")).base_version == "0.11.0":
    from vllm.model_executor.models.module_mapping import MultiModelKeys
    from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration

    def hijack__get_mm_mapping(self) -> MultiModelKeys:
        # Declare the prefixes of the language model, vision-connector and
        # vision-tower sub-modules so LoRA can target them selectively.
        return MultiModelKeys.from_string_field(language_model="language_model", connector="visual.merger.", tower_model="visual.", )

    setattr(Qwen3VLForConditionalGeneration, "get_mm_mapping", hijack__get_mm_mapping)
这里只提供 EasyR1 要求的数据集格式,具体数据集的收集和格式转化请自行处理。
{"problem":"xxx","answer":"xxx"}
创建 EasyR1/easyr1/reward/pentest_reward.py:
import re
from typing import List, Dict, Any
# Matches the chain-of-thought section of a completion.
# BUG FIX: the original pattern was r"</think>(.*?)</think>" — a CLOSING tag
# on both sides — which can never match a normal "<think>...</think>" block,
# so the format reward's think component was always 0. The opening tag is
# now matched first.
THINK_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL | re.IGNORECASE)
# One "=== Step N ===" block containing a Thought followed by a Command.
STEP_PATTERN = re.compile(r"===\s*Step\s*\d+\s*===\s*Thought:.*?Command:", re.DOTALL | re.IGNORECASE)
# Everything after the reasoning section (the actionable answer part).
POST_THINK_PATTERN = re.compile(r"</think>(.*)", re.DOTALL)
# Captures the step number so it can be compared with the reference answer.
STEP_NUM_PATTERN = re.compile(r"===\s*Step\s*(\d+)\s*===", re.IGNORECASE)
# Captures the command text up to the end of its line.
COMMAND_PATTERN = re.compile(r"Command:\s*(.*?)(?:\n|$)", re.DOTALL | re.IGNORECASE)
class PentestRewardFunction:
    """GRPO reward function for penetration-testing traces.

    Two weighted components are combined per sample:
      * format reward   - does the completion follow the expected layout
        (a non-empty think section plus a Step/Thought/Command block)?
      * accuracy reward - does the post-think content match the reference
        answer's step number and command?
    """

    def __init__(self, format_weight: float = 0.3, accuracy_weight: float = 0.7, **kwargs):
        self.format_weight = format_weight
        self.accuracy_weight = accuracy_weight

    def __call__(self, prompts: List[str], completions: List[str], **kwargs) -> Dict[str, List[float]]:
        # Reference answers arrive via the "answer" kwarg; default to blanks.
        answers = kwargs.get("answer", [""] * len(completions))
        rewards: List[float] = []
        format_rewards: List[float] = []
        accuracy_rewards: List[float] = []
        for _prompt, completion, answer in zip(prompts, completions, answers):
            fmt = self._compute_format_reward(completion)
            acc = self._compute_accuracy_reward(completion, answer)
            format_rewards.append(fmt)
            accuracy_rewards.append(acc)
            rewards.append(self.format_weight * fmt + self.accuracy_weight * acc)
        return {"rewards": rewards, "format_reward": format_rewards, "accuracy_reward": accuracy_rewards}

    def _compute_format_reward(self, completion: str) -> float:
        """Score layout compliance: 0.3 per satisfied check (max 0.6)."""
        score = 0.0
        think = THINK_PATTERN.search(completion)
        if think is not None and think.group(1).strip():
            score += 0.3
        if STEP_PATTERN.search(completion) is not None:
            score += 0.3
        return score

    def _compute_accuracy_reward(self, completion: str, answer: str) -> float:
        """Score content agreement with the reference answer (capped at 1.5)."""
        tail = POST_THINK_PATTERN.search(completion)
        if not tail:
            # No reasoning terminator -> nothing to grade.
            return 0.0
        analysis = tail.group(1).strip()

        score = 0.0
        # +0.2 when the generated step number equals the reference one.
        gen_step = STEP_NUM_PATTERN.search(analysis)
        ref_step = STEP_NUM_PATTERN.search(answer)
        if gen_step and ref_step and gen_step.group(1) == ref_step.group(1):
            score += 0.2

        gen_cmd_match = COMMAND_PATTERN.search(analysis)
        ref_cmd_match = COMMAND_PATTERN.search(answer)
        gen_cmd = gen_cmd_match.group(1).strip() if gen_cmd_match else ""
        ref_cmd = ref_cmd_match.group(1).strip() if ref_cmd_match else ""
        if not ref_cmd:
            # Reference has no command: only the step score applies.
            return score
        if gen_cmd == ref_cmd:
            score += 1.0
        elif gen_cmd:
            # Partial credit via word-level Jaccard similarity, only above 0.5.
            gen_tokens = set(gen_cmd.split())
            ref_tokens = set(ref_cmd.split())
            overlap = len(gen_tokens & ref_tokens)
            union = len(gen_tokens) + len(ref_tokens) - overlap
            if union > 0:
                jaccard = overlap / union
                if jaccard > 0.5:
                    score += jaccard * 0.7
        return min(score, 1.5)
# Module-level singleton with the default 0.3/0.7 weighting.
_reward_fn_instance = PentestRewardFunction()


def create_reward_function(reward_inputs):
    """Batch reward entry point in the shape EasyR1 expects.

    Each element of ``reward_inputs`` is a dict carrying the model output
    under "response" and the reference under "ground_truth"; returns one
    score dict per input with "overall"/"format"/"accuracy" keys.
    """
    completions = [item.get("response", "") for item in reward_inputs]
    answers = [item.get("ground_truth", "") for item in reward_inputs]
    prompts = ["" for _ in reward_inputs]  # prompts are unused by the scorer
    result = _reward_fn_instance(prompts, completions, answer=answers)
    return [
        {
            "overall": result["rewards"][i],
            "format": result["format_reward"][i],
            "accuracy": result["accuracy_reward"][i],
        }
        for i in range(len(reward_inputs))
    ]


# Registration metadata consumed by the framework.
REWARD_NAME = "pentest_reward"
REWARD_TYPE = "batch"
注册奖励函数,编辑 EasyR1/easyr1/reward/__init__.py,添加:
from .pentest_reward import PentestRewardFunction, create_reward_function

# Re-export so the reward package exposes the pentest reward publicly.
__all__ = [
    "PentestRewardFunction",
    "create_reward_function",
]
同时在上一级目录创建 vim EasyR1/easyr1/__init__.py。(后续训练配置文件使用绝对路径,可能这里不需要设置)
创建 EasyR1/examples/pentest_grpo_h800_optimized.yaml,使用 verl 的配置格式。
data:
  # JSONL files with {"problem": ..., "answer": ...} records
  train_files: /xxx/EasyR1/data/pentest/pentest_grpo_train.jsonl
  val_files: /xxx/EasyR1/data/pentest/pentest_grpo_eval.jsonl
  prompt_key: problem
  answer_key: answer
  # image/video keys are part of the verl schema; presumably unused for this
  # text-only task — confirm the trainer tolerates them being absent in data
  image_key: images
  video_key: videos
  image_dir: null
  video_fps: 2.0
  max_prompt_length: 5120
  max_response_length: 3072
  rollout_batch_size: 16
  mini_rollout_batch_size: 8
  val_batch_size: 8
  shuffle: true
  seed: 1

algorithm:
  adv_estimator: grpo        # group-relative advantage estimation (GRPO)
  disable_kl: false
  use_kl_loss: true          # KL regularization against the reference policy
  kl_penalty: low_var_kl
  kl_coef: 1.0e-2

worker:
  actor:
    global_batch_size: 16
    micro_batch_size_per_device_for_update: 4
    micro_batch_size_per_device_for_experience: 4
    max_grad_norm: 1.0       # gradient clipping
    padding_free: true
    dynamic_batching: true
    model:
      model_path: unsloth/DeepSeek-R1-0528-Qwen3-8B
      enable_gradient_checkpointing: true
      trust_remote_code: true
      freeze_vision_tower: false   # presumably a no-op for this text-only model
      lora:
        rank: 32
        alpha: 64
        target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]
    optim:
      lr: 5.0e-6
      weight_decay: 0.01
      strategy: adamw
      lr_warmup_ratio: 0.0
    fsdp:
      enable_full_shard: true
      enable_cpu_offload: false
      enable_rank0_init: true
      torch_dtype: bf16
    offload:
      offload_params: false
      offload_optimizer: false
  rollout:
    n: 4                     # samples per prompt (GRPO group size)
    temperature: 0.6
    top_p: 0.9
    gpu_memory_utilization: 0.75
    enforce_eager: false
    enable_chunked_prefill: true
    tensor_parallel_size: 2  # spans both H800 GPUs
    max_num_batched_tokens: 8192
    val_override_config:     # greedier single-sample settings for validation
      temperature: 0.6
      top_p: 0.9
      n: 1
  ref:
    fsdp:
      enable_full_shard: true
      enable_cpu_offload: true   # reference model can live on CPU to save VRAM
      enable_rank0_init: true
      torch_dtype: bf16
    offload:
      offload_params: false

reward:
  # "path/to/file.py:callable" — batch-style reward entry point
  reward_function: /xxx/EasyR1/easyr1/reward/pentest_reward.py:create_reward_function

trainer:
  total_epochs: 2
  project_name: pentest
  experiment_name: grpo-h800x2-optimized
  logger: ["console", "swanlab"]
  nnodes: 1
  n_gpus_per_node: 2
  val_freq: 50               # validate every 50 steps
  val_before_train: true
  val_only: false
  save_freq: 50
  save_limit: 3              # keep at most 3 checkpoints
  save_model_only: false
  save_checkpoint_path: /xxx/EasyR1/checkpoints/pentest-optimized
  load_checkpoint_path: null
  find_last_checkpoint: true # resume from the latest checkpoint when present
在 EasyR1 项目根目录下创建启动脚本文件 start_training.sh。
#!/bin/bash
# Pentest-R1 GRPO training launcher: starts verl's trainer in the background,
# writes a timestamped log file, and records the process PID.
export CUDA_VISIBLE_DEVICES=0,1
export PYTHONPATH=/root/autodl-fs/EasyR1:$PYTHONPATH
# NOTE(review): this directory name ("pentest-r1-optimized") differs from the
# YAML's save_checkpoint_path ("pentest-optimized") — confirm which is intended.
mkdir -p /root/autodl-fs/EasyR1/logs
mkdir -p /root/autodl-fs/EasyR1/checkpoints/pentest-r1-optimized
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
LOG_FILE="/root/autodl-fs/EasyR1/logs/training_${TIMESTAMP}.log"
echo "========================================"
echo "启动 Pentest-R1 GRPO 训练"
echo "配置文件:examples/pentest_grpo_h800_optimized.yaml"
echo "日志文件:${LOG_FILE}"
echo "开始时间:$(date)"
echo "========================================"
# Launch detached via nohup so training survives the SSH session ending.
nohup python3 -m verl.trainer.main \
    config=examples/pentest_grpo_h800_optimized.yaml \
    trainer.n_gpus_per_node=2 \
    >"${LOG_FILE}" 2>&1 &
PID=$!
echo "训练进程 PID: ${PID}"
# Persist the PID so the run can be stopped later.
echo "${PID}" > /root/autodl-fs/EasyR1/logs/training.pid
echo ""
echo "训练已在后台启动!"
echo "常用命令:"
echo " 查看实时日志:tail -f ${LOG_FILE}"
echo " 查看进程状态:ps aux | grep verl.trainer"
echo " 查看 GPU 状态:nvidia-smi"
echo " 停止训练:kill ${PID}"
echo ""
echo "查看日志:tail -f ${LOG_FILE}"
训练中断再启动:可以从保存点继续训练。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online