"""医疗问答 LoRA 微调完整示例"""
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForSeq2Seq
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import json
from tqdm import tqdm
def prepare_medical_data():
    """Build instruction-style records from a slice of the medical QA dataset.

    Loads the first 5000 training rows of ``medalp/medquad-zh``, pairs each
    row's ``question`` and ``answer`` fields with a fixed clinician
    instruction, writes the resulting records to
    ``medical_qa_formatted.json`` in the current working directory, and
    returns them.

    Returns:
        list[dict]: One record per row with the keys ``"instruction"``,
        ``"input"`` and ``"output"``.
    """
    system_instruction = "你是一位经验丰富的临床医生,请根据患者描述提供专业建议"
    raw_rows = load_dataset("medalp/medquad-zh", split="train[:5000]")

    # tqdm wraps the dataset iterator so formatting progress is shown.
    records = [
        {
            "instruction": system_instruction,
            "input": row['question'],
            "output": row['answer'],
        }
        for row in tqdm(raw_rows, desc="格式化数据")
    ]

    # ensure_ascii=False keeps the Chinese text readable in the dump.
    with open("medical_qa_formatted.json", "w", encoding="utf-8") as out_file:
        json.dump(records, out_file, ensure_ascii=False, indent=2)

    return records
def setup_model_and_lora():
    """Load the Qwen-1.8B-Chat base model and tokenizer and attach LoRA adapters.

    The base model is loaded in float16 with ``device_map="auto"``. A padding
    token is configured on the tokenizer, a LoRA adapter (r=16, alpha=32,
    dropout=0.05) is attached to the attention and MLP projections, and the
    trainable parameter summary is printed.

    Returns:
        tuple: ``(peft_model, tokenizer)`` - the LoRA-wrapped model and the
        tokenizer configured for right-side padding.

    Raises:
        ValueError: If the tokenizer exposes no token that can be used for
            padding.
    """
    model_name = "Qwen/Qwen-1.8B-Chat"
    print("加载预训练模型...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, trust_remote_code=True, padding_side="right"
    )

    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        elif getattr(tokenizer, "eod_id", None) is not None:
            # The Qwen (v1) tokenizer ships without eos/pad tokens; the
            # official fine-tuning script pads with the end-of-document id.
            tokenizer.pad_token_id = tokenizer.eod_id
        else:
            # Fail here rather than later inside the data collator, where a
            # missing pad token produces a much less obvious error.
            raise ValueError(
                f"Tokenizer for {model_name} has no pad, eos or eod token to pad with"
            )

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        # Module names of the Qwen (v1) remote modeling code: fused QKV
        # projection (c_attn), output projections (c_proj) and the MLP
        # layers (w1, w2). This architecture has no q_proj/k_proj/v_proj/
        # o_proj modules, so those names would match nothing.
        target_modules=["c_attn", "c_proj", "w1", "w2"],
        bias="none",
        # modules_to_save is intentionally omitted: no tokens are added to
        # the vocabulary, so the embedding matrix and LM head do not need to
        # be trained in full alongside the low-rank adapters.
    )
    print("应用 LoRA 适配器...")
    peft_model = get_peft_model(model, lora_config)

    # The base weights are float16, and depending on the PEFT version the
    # adapter weights may inherit that dtype. Mixed-precision training
    # (fp16=True in the training arguments) cannot unscale float16 gradients,
    # so keep every trainable parameter in float32.
    for param in peft_model.parameters():
        if param.requires_grad:
            param.data = param.data.to(torch.float32)

    peft_model.print_trainable_parameters()
    return peft_model, tokenizer
def train_medical_model():
    """Fine-tune the LoRA-wrapped model on the formatted medical QA records.

    Prepares the data, builds the model and tokenizer, tokenizes each record
    into a single prompt+answer sequence (labels are a copy of the input ids),
    holds out a small evaluation split, runs the Hugging Face ``Trainer``, and
    saves the adapter and tokenizer to ``./medical-chatbot-final``.

    Returns:
        Trainer: The trainer instance after training has finished.

    Raises:
        ValueError: If fewer than two examples are available, since a
            separate evaluation split cannot be formed.
    """
    from datasets import Dataset

    data = prepare_medical_data()
    model, tokenizer = setup_model_and_lora()

    def preprocess_function(examples):
        """Tokenize a batch of records into causal-LM features."""
        texts = [
            f"{inst}\n\n患者描述:{inp}\n\n医生建议:{out}"
            for inst, inp, out in zip(
                examples["instruction"], examples["input"], examples["output"]
            )
        ]
        # No return_tensors here: sequences in a batch have different lengths
        # and cannot be stacked into one tensor without padding. Padding is
        # done per training batch by the data collator instead.
        tokenized = tokenizer(texts, truncation=True, max_length=512)
        tokenized["labels"] = [list(ids) for ids in tokenized["input_ids"]]
        return tokenized

    columns = ("instruction", "input", "output")
    dataset = Dataset.from_dict(
        {column: [record[column] for record in data] for column in columns}
    )
    # Drop the raw string columns so only model features remain.
    tokenized_dataset = dataset.map(
        preprocess_function, batched=True, remove_columns=dataset.column_names
    )

    if len(tokenized_dataset) < 2:
        raise ValueError("Need at least 2 examples to form train and eval splits")

    # Hold out up to 100 examples for evaluation. Evaluating on rows that are
    # also trained on would make the loss used by load_best_model_at_end
    # meaningless.
    eval_size = min(100, max(1, len(tokenized_dataset) // 10))
    splits = tokenized_dataset.train_test_split(test_size=eval_size, seed=42)
    train_dataset = splits["train"]
    eval_dataset = splits["test"]

    training_args = TrainingArguments(
        output_dir="./medical-chatbot-lora",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=10,
        # With ~5000 examples and an effective batch of 4 * 8 = 32, a
        # single-device run has only about 460 optimizer steps over 3 epochs.
        # An interval of 500 would never evaluate or checkpoint, leaving
        # load_best_model_at_end nothing to load. save_steps must stay a
        # multiple of eval_steps.
        save_steps=100,
        eval_steps=100,
        # NOTE(review): newer transformers releases rename this argument to
        # `eval_strategy` - adjust if the installed version rejects it.
        evaluation_strategy="steps",
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="loss",
        greater_is_better=False,
        warmup_ratio=0.1,
        weight_decay=0.01,
        report_to="tensorboard"
    )
    # Pads input ids with the pad token and labels with -100, so padded
    # positions are ignored by the loss.
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer
    )
    print("开始训练医学问答模型...")
    trainer.train()
    trainer.save_model("./medical-chatbot-final")
    tokenizer.save_pretrained("./medical-chatbot-final")
    print("训练完成!")
    return trainer
def test_medical_model():
    """Load the fine-tuned adapter, merge it into the base model and print sample answers.

    Reads the adapter and tokenizer from ``./medical-chatbot-final``, merges
    the LoRA weights into the float16 base model, and samples a response of up
    to 200 new tokens for each of three fixed patient questions. Results are
    printed to stdout; nothing is returned.
    """
    from peft import PeftModel

    # trust_remote_code=True matches how the model is loaded for training;
    # this checkpoint relies on custom modeling/tokenizer code.
    base_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen-1.8B-Chat",
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base_model, "./medical-chatbot-final")
    model = model.merge_and_unload()
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(
        "./medical-chatbot-final", trust_remote_code=True
    )

    test_cases = [
        "头痛、恶心、视力模糊应该怎么办?",
        "高血压患者日常需要注意什么?",
        "糖尿病早期有哪些症状?"
    ]
    for query in test_cases:
        # Same template as used for training, stopping right after the
        # answer marker so the model continues with the answer.
        prompt = f"你是一位经验丰富的临床医生,请根据患者描述提供专业建议\n\n患者描述:{query}\n\n医生建议:"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=200, temperature=0.7, do_sample=True, top_p=0.9
            )
        # Decode only the newly generated tokens. Slicing the decoded string
        # by len(prompt) is unreliable because decoding does not always
        # reproduce the prompt text character for character.
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        print(f"问题:{query}")
        print(f"回答:{response}")
        print("-" * 50)
if __name__ == "__main__":
    # Script entry point: fine-tune first, then run the generation check
    # against the artifacts that training just saved.
    trainer = train_medical_model()
    test_medical_model()