Stable Diffusion 3.5 开发指南(三):Stable Diffusion 3.5 LoRA 微调
概述
在之前的章节中,我们学习了如何获取和调用 Stable Diffusion 3.5 模型,以及深入理解了其核心的 Flow Matching 机制。本章将聚焦于LoRA(Low-Rank Adaptation)微调技术,这是一种高效的模型定制方法,能够在保持原有模型性能的同时,仅通过少量参数更新即可实现特定任务的定制化。
1. 数据集准备
1.1 数据集格式
微调 Stable Diffusion 3.5 模型需要图像-文本对数据集,每个数据项应包含以下两个核心字段:
img_path:图像文件的路径(支持绝对路径或相对路径)caption:与图像内容精准匹配的文本描述
示例 JSON 数据集格式
[{"img_path":"/path/to/image1.jpg","caption":"A beautiful sunset over the mountains"},{"img_path":"/path/to/image2.jpg","caption":"A group of people at a conference"}]1.2 数据处理
为了方便加载和预处理数据,我们实现了一个自定义的 PyTorch 数据集类 StableDiffusionDataset。该类封装了以下核心功能:
- 从 JSON 文件加载数据集元信息
- 图像自动预处理(缩放、转换为张量、归一化)
- 数据加载错误处理
数据集类实现
import json import os from PIL import Image import torch from torch.utils.data import Dataset from torchvision import transforms classStableDiffusionDataset(Dataset):def__init__(self, json_path):""" 初始化 Stable Diffusion 微调数据集 Args: json_path: JSON 文件路径,包含 img_path 和 caption 字段 """super().__init__()# 读取 JSON 文件withopen(json_path,'r', encoding='utf-8')as f: self.data = json.load(f)# 定义图像预处理 pipeline# 将图像调整为 512x512(SD 3.5 模型的默认输入尺寸),转换为张量并归一化到 [-1, 1] 范围 self.transform = transforms.Compose([ transforms.Resize((512,512)),# 调整图像大小为 512x512 transforms.ToTensor(),# 转换为张量 [0, 1] transforms.Normalize([0.5],[0.5])# 归一化到 [-1, 1]])def__len__(self):"""返回数据集样本数量"""returnlen(self.data)def__getitem__(self, idx):""" 获取单个数据样本 Args: idx: 样本索引 Returns: tuple: (image_tensor, caption) - image_tensor: 处理后的图像张量,形状为 [3, 512, 512] - caption: 文本描述字符串 """ item = self.data[idx]# 读取图像 img_path = item['img_path']# 检查文件是否存在ifnot os.path.exists(img_path):raise FileNotFoundError(f"图像文件不存在: {img_path}")# 打开并转换图像try:# 确保图像为 RGB 格式(丢弃 alpha 通道) image = Image.open(img_path).convert('RGB')except Exception as e:raise ValueError(f"无法读取图像 {img_path}: {str(e)}")# 应用预处理转换 image_tensor = self.transform(image)# 获取文本描述 caption = item['caption']return image_tensor, caption 使用示例
以下是如何使用 StableDiffusionDataset 类的完整示例,包括数据集创建、样本查看和 DataLoader 构建:
# 创建数据集实例 dataset = StableDiffusionDataset("data.json")# 查看数据集大小print(f"数据集包含 {len(dataset)} 个样本")# 获取单个样本 image, caption = dataset[0]print(f"图像维度: {image.shape}")print(f"文本描述: {caption}")# 创建 DataLoader 用于批量训练from torch.utils.data import DataLoader dataloader = DataLoader( dataset, batch_size=4,# 批次大小,可根据 GPU 内存调整 shuffle=True,# 训练时打乱数据,增加随机性 num_workers=2,# 并行加载进程数,加速数据加载 pin_memory=True# 启用内存锁定,加速数据传输到 GPU)2. LoRA 微调原理
LoRA 是一种参数高效微调技术,其核心思想是:
- 冻结原有模型:保持预训练模型的权重不变,避免灾难性遗忘
- 添加低秩适配器:在关键层(如注意力层)插入低秩矩阵对(A 和 B)
- 仅训练低秩矩阵:通过少量参数更新即可实现模型定制
这种方法的优势在于:
- 训练参数仅为原有模型的 1%-5%,大幅降低内存占用
- 训练速度显著提升,减少计算资源消耗
- 微调结果易于保存和分享,单个 LoRA 权重文件通常仅几 MB
- 支持多 LoRA 权重组合使用,实现灵活的风格控制
3. 模型加载与 LoRA 配置
3.1 加载预训练模型
首先需要加载 Stable Diffusion 3.5 预训练模型:
from diffusers import StableDiffusion3Pipeline import torch # 加载预训练模型 model_id ="stabilityai/stable-diffusion-3.5-large" pipeline = StableDiffusion3Pipeline.from_pretrained( model_id, torch_dtype=torch.float16,# 使用半精度(float16)加速计算并减少内存占用).to("cuda")# 移至 GPU 设备3.2 配置 LoRA 参数
LoRA 的关键参数包括:
r:低秩矩阵的秩,控制适配器的容量(常用值:4, 8, 16, 32)alpha:缩放因子,控制 LoRA 对模型的影响程度target_modules:需要添加 LoRA 适配器的目标层
from peft import LoraConfig, get_peft_model # 配置 LoRA 参数 lora_config = LoraConfig( r=16,# 低秩矩阵的秩,r 越大,适配器容量越大,但参数也越多 lora_alpha=32,# 缩放因子,通常设置为 r 的 2 倍 target_modules=['to_k',# 注意力层的键(Key)投影层'to_q',# 注意力层的查询(Query)投影层'to_v',# 注意力层的值(Value)投影层], lora_dropout=0.05,# Dropout 率,防止过拟合 bias="none",# 不对偏置项应用 LoRA task_type="TEXT_TO_IMAGE"# 任务类型)# 为模型添加 LoRA 适配器 pipeline.transformer = get_peft_model(pipeline.transformer, lora_config)# 冻结不需要训练的组件,仅训练 LoRA 适配器 pipeline.vae.requires_grad_(False)# 冻结 VAE 编码器/解码器 pipeline.text_encoder.requires_grad_(False)# 冻结文本编码器 1 pipeline.text_encoder_2.requires_grad_(False)# 冻结文本编码器 2 pipeline.text_encoder_3.requires_grad_(False)# 冻结文本编码器 3# 打印可训练参数数量print("可训练参数数量:",sum(p.numel()for p in pipeline.transformer.parameters()if p.requires_grad))4. 训练循环实现
4.1 定义训练参数
import torch.nn.functional as F from transformers import get_scheduler # 训练参数配置 epochs =5# 训练轮次 batch_size =4# 批次大小 learning_rate =1e-4# 学习率 weight_decay =1e-2# 权重衰减,防止过拟合# 优化器配置:仅优化可训练参数(LoRA 适配器参数) optimizer = torch.optim.AdamW( params=filter(lambda p: p.requires_grad, pipeline.transformer.parameters()), lr=learning_rate, weight_decay=weight_decay )# 学习率调度器:使用余弦退火调度,逐步降低学习率 num_training_steps =len(dataloader)* epochs scheduler = get_scheduler( name="cosine",# 调度器类型 optimizer=optimizer, num_warmup_steps=0,# 预热步数 num_training_steps=num_training_steps )4.2 执行训练循环
# 设置模型为训练模式 pipeline.transformer.train() device ='cuda:0'# 指定 GPU 设备defcompute_density_for_timestep_sampling(batch_size, device):""" 基于正态分布的时间步采样,增加训练稳定性 Args: batch_size: 批次大小 device: 设备类型 Returns: torch.Tensor: 采样的时间步权重,形状为 [batch_size] """ u = torch.normal(0,1,(batch_size,), device=device) u = torch.sigmoid(u)# 将正态分布转换到 [0, 1] 区间return u defget_sigmas(timesteps, n_dim, device):""" 获取对应时间步的噪声方差(sigmas) Args: timesteps: 时间步张量 n_dim: 目标维度,用于广播 sigma device: 设备类型 Returns: torch.Tensor: 噪声方差,形状为 [batch_size, 1, 1, 1] """# 将调度器的时间步和 sigmas 移动到当前设备 scheduler_timesteps = pipeline.scheduler.timesteps.to(device) sigmas = pipeline.scheduler.sigmas.to(device)# 确保输入 timesteps 也在同一设备 timesteps = timesteps.to(device)# 查找每个时间步对应的索引 step_indices =[(scheduler_timesteps == t).nonzero().item()for t in timesteps] sigma = sigmas[step_indices].flatten()# 广播 sigma 到目标维度(适配 latent 形状)whilelen(sigma.shape)< n_dim: sigma = sigma.unsqueeze(-1)return sigma # 开始训练循环for epoch inrange(epochs):print(f"Epoch {epoch+1}/{epochs}")# 重置累积损失 total_loss =0for step,(images, captions)inenumerate(dataloader):# ----------------------------# 1. 编码文本(CLIP + T5)# ----------------------------with torch.no_grad(): prompt_embeds, _, pooled_prompt_embeds, _ = pipeline.encode_prompt( prompt=captions, prompt_2=captions, prompt_3=captions, device=device, negative_prompt='', negative_prompt_2='', negative_prompt_3='', do_classifier_free_guidance=True,)# ----------------------------# 2. 将图像编码为潜在表示(latent)# ----------------------------# 移动图像到 GPU 并转换为半精度 images = images.to(device, dtype=torch.float16)with torch.no_grad():# 使用 VAE 编码器将图像转换为潜在表示 vae_output = pipeline.vae.encode(images) latents = vae_output.latent_dist.sample()# 应用 VAE 配置的缩放和偏移因子 latents =(latents - pipeline.vae.config.shift_factor)* pipeline.vae.config.scaling_factor # ----------------------------# 3. 采样时间步(带权重方案)# ---------------------------- u = compute_density_for_timestep_sampling( batch_size=batch_size, device=device )# 将采样权重转换为时间步索引 indices =(u * pipeline.scheduler.config.num_train_timesteps).long() timesteps = pipeline.scheduler.timesteps.to(device)[indices]# ----------------------------# 4. Flow Matching:生成带噪声的潜在表示# ----------------------------# 获取对应时间步的噪声方差 sigmas = get_sigmas(timesteps, n_dim=latents.ndim, device=device)# 生成随机噪声 noise = torch.randn_like(latents, device=device)# 生成中间状态:(1-sigma)*latent + sigma*noise# 注意:SD 3.5 的 Flow Matching 插值方向与标准相反# 0 时刻是图像的压缩态(latents),1 时刻是噪声 noisy_latents =(1.0- sigmas)* latents + sigmas * noise # ----------------------------# 5. 预测流场(model_pred)# ----------------------------# 使用混合精度训练,加速计算并减少内存占用with torch.autocast("cuda"):# 模型预测平均速度(方向:从压缩态到噪声) model_pred = pipeline.transformer( hidden_states=noisy_latents, timestep=timesteps, encoder_hidden_states=prompt_embeds, pooled_projections=pooled_prompt_embeds, return_dict=False)[0]# 计算预测的 latent:当前位置 + 速度*时间(反向)# 模型预测的是平均速度,乘以 (-sigma) 表示反向移动 pred = model_pred *(-sigmas)+ noisy_latents # 计算 MSE 损失:预测 latent 与真实 latent 的差距 loss = F.mse_loss(pred, latents)# 反向传播与优化 optimizer.zero_grad()# 清空梯度 loss.backward()# 计算梯度 optimizer.step()# 更新参数 scheduler.step()# 更新学习率# 累积损失 total_loss += loss.item()# 打印训练日志if step %100==0: avg_loss = total_loss /(step +1)print(f"Step {step}, Loss: {loss.item():.4f}, Avg Loss: {avg_loss:.4f}")# 打印 epoch 日志 avg_epoch_loss = total_loss /len(dataloader)print(f"Epoch {epoch+1} 完成, 平均损失: {avg_epoch_loss:.4f}")4.3 关于损失计算的说明
SD 3.5 的 Flow Matching 训练与标准 Flow Matching 有两点关键不同:
- 插值方向相反:
- 标准 Flow Matching:0 时刻是噪声,1 时刻是图像
- SD 3.5:0 时刻是图像的压缩态(latents),1 时刻是噪声
- 中间状态:
(1.0 - sigmas) * latents + sigmas * noise
- 模型预测目标不同:
- 模型预测的是平均速度(方向:从压缩态到噪声)
- 距离:
model_pred * sigmas,平均速度乘以时间就是距离 - 预测公式:
pred = model_pred * (-sigmas) + noisy_latents - 损失计算:
MSE(pred, latents),即中间点往回走一段距离的位置和起点的差距
这种设计能够更好地适应扩散模型的训练特性,提高生成质量和训练稳定性。
5. LoRA 权重保存与加载
5.1 保存 LoRA 权重
# 保存 LoRA 权重,仅保存可训练的 LoRA 参数 pipeline.transformer.save_pretrained("lora-sd35-finetuned")print("LoRA 权重保存完成")5.2 加载 LoRA 权重
from peft import PeftModel # 加载预训练模型 pipeline = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.float16 ).to("cuda")# 加载 LoRA 权重到 transformer 组件 pipeline.transformer = PeftModel.from_pretrained( pipeline.transformer,"lora-sd35-finetuned")# 设置模型为推理模式 pipeline.transformer.eval()print("LoRA 权重加载完成")6. 推理
推理方式与之前的 Stable Diffusion 3.5 开发指南(一)完全相同,加载 LoRA 权重后可直接使用 pipeline 进行图像生成:
# 示例:使用微调后的 LoRA 生成图像 prompt ="A beautiful sunset over the mountains in my style" generated_image = pipeline( prompt=prompt, negative_prompt="blur, low quality, distortion", num_inference_steps=30, guidance_scale=7.5).images[0]# 保存生成的图像 generated_image.save("generated_image.png")7. 最佳实践
7.1 数据准备
- 数据质量:确保图像清晰(建议分辨率 ≥ 512x512),文本描述准确且详细
- 数据多样性:包含多种场景、角度和风格的图像,避免过拟合
- 数据量:建议至少准备 100-500 个样本,具体取决于任务复杂度
- 文本描述优化:使用详细的关键词,例如 “a photo of a cat, detailed fur, blue eyes, sunny day” 而非 “a cat”
7.2 训练参数调整
| 参数 | 建议范围 | 说明 |
|---|---|---|
| 批次大小 | 2-8 | 根据 GPU 内存调整,A100 可使用 8-16 |
| 学习率 | 5e-5 - 2e-4 | 建议使用余弦退火调度,逐渐降低 |
| LoRA 秩 r | 4-32 | 小数据集使用小 r(4-8),大数据集使用大 r(16-32) |
| 训练轮次 | 5-20 | 监控损失曲线,避免过拟合 |
| 权重衰减 | 1e-2 - 1e-3 | 防止过拟合,正则化模型 |
7.3 常见问题与解决方案
| 问题 | 可能原因 | 解决方案 |
|---|---|---|
| 生成图像模糊 | 训练轮次不足或学习率过低 | 增加训练轮次或提高学习率 |
| 过拟合(生成图像与训练集高度相似) | 数据量不足或 LoRA 秩过大 | 增加数据量、减小 LoRA 秩或增加 dropout |
| 训练速度慢 | 批次大小过大或使用全精度 | 减小批次大小或使用半精度(float16) |
| 内存不足 | 模型过大或批次大小过大 | 使用更小的模型版本、减小批次大小或启用梯度检查点 |
| 生成图像与文本描述不符 | 文本描述质量差或 LoRA 影响过大 | 优化文本描述、调整 LoRA alpha 或减小 r 值 |
7.4 高级技巧
- 多 LoRA 组合:同时加载多个 LoRA 权重,实现风格混合
- LoRA 缩放:加载时调整 LoRA 权重的缩放因子,控制风格强度
- 梯度检查点:启用
gradient_checkpointing减少内存占用 - 文本编码器微调:在数据量充足时,可解冻部分文本编码器层进行微调
- 评估指标:使用 FID、CLIP 分数等指标评估生成质量
总结
本章详细介绍了使用 LoRA 技术微调 Stable Diffusion 3.5 模型的完整流程,包括:
- 数据集准备与处理:创建图像-文本对数据集,实现自定义数据加载器
- LoRA 微调原理:理解低秩适配器的工作机制和优势
- 模型加载与配置:加载预训练模型,配置 LoRA 参数
- 训练循环实现:实现 Flow Matching 训练逻辑,理解 SD 3.5 的特殊损失计算
- 权重保存与加载:保存和加载 LoRA 权重,实现模型复用
通过 LoRA 微调,您可以高效地定制 Stable Diffusion 3.5 模型,使其适应特定领域或风格的图像生成需求。在实际应用中,建议根据具体任务调整参数和流程,以获得最佳效果。