from unsloth import FastLanguageModel
import torch
from trl import SFTTrainer
from transformers import TrainingArguments
# Load Llama-Guard-3-8B in 4-bit (QLoRA-style) via Unsloth's optimized loader.
# Returns both the quantized model and its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="meta-llama/Llama-Guard-3-8B",
max_seq_length=2048,  # must match the max_seq_length given to SFTTrainer below
load_in_4bit=True,  # 4-bit NF4 quantization to fit an 8B model on a single GPU
)
# Wrap the base model with LoRA adapters so only low-rank deltas are trained.
model = FastLanguageModel.get_peft_model(
model,
r=16,  # LoRA rank; adapter capacity vs. memory trade-off
# Attention projections only — MLP projections (gate/up/down_proj) are not
# adapted here. NOTE(review): presumably intentional to save memory; confirm.
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_alpha=32,  # scaling factor (alpha/r = 2 effective scale)
lora_dropout=0,  # 0 enables Unsloth's fast-path kernels
)
def format_prompt(sample, eos_token="<|end_of_text|>"):
    """Render one training sample into a single prompt+completion string.

    Args:
        sample: Mapping with string fields ``instruction`` (the user query),
            ``label`` (e.g. safe/unsafe verdict) and ``category`` (violation
            category, possibly empty).
        eos_token: End-of-sequence marker appended to the example. Defaults to
            the Llama-3 ``<|end_of_text|>`` token; pass ``tokenizer.eos_token``
            to stay in sync with the loaded tokenizer.

    Returns:
        The formatted training string, terminated with *eos_token*.

    Fix: the original returned no EOS token, so the fine-tuned model would
    never learn where a completion ends and would generate indefinitely.
    NOTE(review): this template mixes the Llama-2 ``[INST]`` wrapper with the
    Llama-3 ``<|begin_of_text|>`` token — Llama-Guard-3 normally uses the
    Llama-3 header format (``<|start_header_id|>`` ...). Kept as-is pending
    confirmation of the intended template.
    """
    return (
        f"<|begin_of_text|>[INST] {sample['instruction']} [/INST]\n"
        f"{sample['label']}\n{sample['category']}{eos_token}"
    )
# Supervised fine-tuning run over the LoRA-wrapped model.
# NOTE(review): `dataset` is not defined anywhere in this file — it must be
# created/loaded before this point or this script raises NameError.
# NOTE(review): `format_prompt` above is defined but never applied; with
# dataset_text_field="text", SFTTrainer trains directly on a pre-existing
# "text" column. Either map format_prompt over the dataset to build that
# column, or pass formatting_func=format_prompt — confirm intent.
# NOTE(review): the tokenizer returned by from_pretrained is not passed to
# SFTTrainer; verify it is picked up implicitly by this trl version.
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
dataset_text_field="text",  # column containing the full formatted prompt
max_seq_length=2048,  # matches the model's max_seq_length above
args=TrainingArguments(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,  # effective batch size = 2 * 4 = 8 per device
warmup_steps=5,
max_steps=60,  # short demo-length run; step-capped rather than epoch-based
learning_rate=2e-4,
fp16=not torch.cuda.is_bf16_supported(),  # prefer bf16 on Ampere+; fp16 fallback
logging_steps=1,  # log every step (fine for a 60-step run)
output_dir="outputs",
),
)
# Kick off training; checkpoints and logs land in ./outputs.
trainer.train()