跳到主要内容 Stable Diffusion 3.5 LoRA 微调指南 | 极客日志
Python AI 算法
Stable Diffusion 3.5 LoRA 微调指南 Stable Diffusion 3.5 LoRA 微调技术详解。涵盖数据集准备、LoRA 原理、模型加载配置、训练循环实现及权重保存加载。重点解析 Flow Matching 机制下的损失计算与时间步采样策略,提供最佳实践与常见问题解决方案,助力高效定制模型风格。
概述
在之前的章节中,我们学习了如何获取和调用 Stable Diffusion 3.5 模型,以及深入理解了其核心的 Flow Matching 机制。本章将聚焦于LoRA(Low-Rank Adaptation)微调技术 ,这是一种高效的模型定制方法,能够在保持原有模型性能的同时,仅通过少量参数更新即可实现特定任务的定制化。
1. 数据集准备
1.1 数据集格式
微调 Stable Diffusion 3.5 模型需要图像 - 文本对 数据集,每个数据项应包含以下两个核心字段:
img_path:图像文件的路径(支持绝对路径或相对路径)
caption:与图像内容精准匹配的文本描述
示例 JSON 数据集格式 [
{"img_path" : "/path/to/image1.jpg" , "caption" : "A beautiful sunset over the mountains" },
{"img_path" : "/path/to/image2.jpg" , "caption" : "A group of people at a conference" }
]
1.2 数据处理 为了方便加载和预处理数据,我们实现了一个自定义的 PyTorch 数据集类 StableDiffusionDataset。该类封装了以下核心功能:
从 JSON 文件加载数据集元信息
图像自动预处理(缩放、转换为张量、归一化)
数据加载错误处理
数据集类实现 import json
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
class StableDiffusionDataset (Dataset ):
def __init__ (self, json_path ):
"""
初始化 Stable Diffusion 微调数据集
Args:
json_path: JSON 文件路径,包含 img_path 和 caption 字段
"""
super ().__init__()
with open (json_path, 'r' , encoding='utf-8' ) as f:
self .data = json.load(f)
self .transform = transforms.Compose([
transforms.Resize((512 , 512 )),
transforms.ToTensor(),
transforms.Normalize([0.5 ], [0.5 ])
])
def __len__ (self ):
"""返回数据集样本数量"""
return len (self .data)
def __getitem__ (self, idx ):
"""
获取单个数据样本
Args:
idx: 样本索引
Returns:
tuple: (image_tensor, caption) - image_tensor: 处理后的图像张量,形状为 [3, 512, 512] - caption: 文本描述字符串
"""
item = self .data[idx]
img_path = item['img_path' ]
if not os.path.exists(img_path):
raise FileNotFoundError(f"图像文件不存在:{img_path} " )
try :
image = Image.open (img_path).convert('RGB' )
except Exception as e:
raise ValueError(f"无法读取图像 {img_path} : {str (e)} " )
image_tensor = self .transform(image)
caption = item['caption' ]
return image_tensor, caption
使用示例 以下是如何使用 StableDiffusionDataset 类的完整示例,包括数据集创建、样本查看和 DataLoader 构建:
dataset = StableDiffusionDataset("data.json" )
print (f"数据集包含 {len (dataset)} 个样本" )
image, caption = dataset[0 ]
print (f"图像维度:{image.shape} " )
print (f"文本描述:{caption} " )
from torch.utils.data import DataLoader
dataloader = DataLoader(
dataset,
batch_size=4 ,
shuffle=True ,
num_workers=2 ,
pin_memory=True
)
2. LoRA 微调原理
冻结原有模型 :保持预训练模型的权重不变,避免灾难性遗忘
添加低秩适配器 :在关键层(如注意力层)插入低秩矩阵对(A 和 B)
仅训练低秩矩阵 :通过少量参数更新即可实现模型定制
训练参数仅为原有模型的 1%-5%,大幅降低内存占用
训练速度显著提升,减少计算资源消耗
微调结果易于保存和分享,单个 LoRA 权重文件通常仅几 MB
支持多 LoRA 权重组合使用,实现灵活的风格控制
3. 模型加载与 LoRA 配置
3.1 加载预训练模型 首先需要加载 Stable Diffusion 3.5 预训练模型:
from diffusers import StableDiffusion3Pipeline
import torch
model_id = "stabilityai/stable-diffusion-3.5-large"
pipeline = StableDiffusion3Pipeline.from_pretrained(
model_id,
torch_dtype=torch.float16,
).to("cuda" )
3.2 配置 LoRA 参数
r:低秩矩阵的秩,控制适配器的容量(常用值:4, 8, 16, 32)
alpha:缩放因子,控制 LoRA 对模型的影响程度
target_modules:需要添加 LoRA 适配器的目标层
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=16 ,
lora_alpha=32 ,
target_modules=['to_k' ,
'to_q' ,
'to_v' ],
lora_dropout=0.05 ,
bias="none" ,
task_type="TEXT_TO_IMAGE"
)
pipeline.transformer = get_peft_model(pipeline.transformer, lora_config)
pipeline.vae.requires_grad_(False )
pipeline.text_encoder.requires_grad_(False )
pipeline.text_encoder_2.requires_grad_(False )
pipeline.text_encoder_3.requires_grad_(False )
print ("可训练参数数量:" , sum (p.numel() for p in pipeline.transformer.parameters() if p.requires_grad))
4. 训练循环实现
4.1 定义训练参数 import torch.nn.functional as F
from transformers import get_scheduler
epochs = 5
batch_size = 4
learning_rate = 1e-4
weight_decay = 1e-2
optimizer = torch.optim.AdamW(
params=filter (lambda p: p.requires_grad, pipeline.transformer.parameters()),
lr=learning_rate,
weight_decay=weight_decay
)
num_training_steps = len (dataloader) * epochs
scheduler = get_scheduler(
name="cosine" ,
optimizer=optimizer,
num_warmup_steps=0 ,
num_training_steps=num_training_steps
)
4.2 执行训练循环
pipeline.transformer.train()
device = 'cuda:0'
def compute_density_for_timestep_sampling (batch_size, device ):
"""
基于正态分布的时间步采样,增加训练稳定性
Args:
batch_size: 批次大小
device: 设备类型
Returns:
torch.Tensor: 采样的时间步权重,形状为 [batch_size]
"""
u = torch.normal(0 , 1 , (batch_size,), device=device)
u = torch.sigmoid(u)
return u
def get_sigmas (timesteps, n_dim, device ):
"""
获取对应时间步的噪声方差(sigmas)
Args:
timesteps: 时间步张量
n_dim: 目标维度,用于广播 sigma
device: 设备类型
Returns:
torch.Tensor: 噪声方差,形状为 [batch_size, 1, 1, 1]
"""
scheduler_timesteps = pipeline.scheduler.timesteps.to(device)
sigmas = pipeline.scheduler.sigmas.to(device)
timesteps = timesteps.to(device)
step_indices = [(scheduler_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len (sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1 )
return sigma
for epoch in range (epochs):
print (f"Epoch {epoch+1 } /{epochs} " )
total_loss = 0
for step, (images, captions) in enumerate (dataloader):
with torch.no_grad():
prompt_embeds, _, pooled_prompt_embeds, _ = pipeline.encode_prompt(
prompt=captions,
prompt_2=captions,
prompt_3=captions,
device=device,
negative_prompt='' ,
negative_prompt_2='' ,
negative_prompt_3='' ,
do_classifier_free_guidance=True ,
)
images = images.to(device, dtype=torch.float16)
with torch.no_grad():
vae_output = pipeline.vae.encode(images)
latents = vae_output.latent_dist.sample()
latents = (latents - pipeline.vae.config.shift_factor) * pipeline.vae.config.scaling_factor
u = compute_density_for_timestep_sampling(
batch_size=batch_size,
device=device
)
indices = (u * pipeline.scheduler.config.num_train_timesteps).long()
timesteps = pipeline.scheduler.timesteps.to(device)[indices]
sigmas = get_sigmas(timesteps, n_dim=latents.ndim, device=device)
noise = torch.randn_like(latents, device=device)
noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
with torch.autocast("cuda" ):
model_pred = pipeline.transformer(
hidden_states=noisy_latents,
timestep=timesteps,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
return_dict=False
)[0 ]
pred = model_pred * (-sigmas) + noisy_latents
loss = F.mse_loss(pred, latents)
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
total_loss += loss.item()
if step % 100 == 0 :
avg_loss = total_loss / (step + 1 )
print (f"Step {step} , Loss: {loss.item():.4 f} , Avg Loss: {avg_loss:.4 f} " )
avg_epoch_loss = total_loss / len (dataloader)
print (f"Epoch {epoch+1 } 完成,平均损失:{avg_epoch_loss:.4 f} " )
4.3 关于损失计算的说明 SD 3.5 的 Flow Matching 训练与标准 Flow Matching 有两点关键不同:
插值方向相反 :
标准 Flow Matching:0 时刻是噪声,1 时刻是图像
SD 3.5:0 时刻是图像的压缩态(latents),1 时刻是噪声
中间状态:(1.0 - sigmas) * latents + sigmas * noise
模型预测目标不同 :
模型预测的是平均速度 (方向:从压缩态到噪声)
距离:model_pred * sigmas,平均速度乘以时间就是距离
预测公式:pred = model_pred * (-sigmas) + noisy_latents
损失计算:MSE(pred, latents),即中间点往回走一段距离的位置和起点的差距
这种设计能够更好地适应扩散模型的训练特性,提高生成质量和训练稳定性。
5. LoRA 权重保存与加载
5.1 保存 LoRA 权重
pipeline.transformer.save_pretrained("lora-sd35-finetuned" )
print ("LoRA 权重保存完成" )
5.2 加载 LoRA 权重 from peft import PeftModel
pipeline = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3.5-large" ,
torch_dtype=torch.float16
).to("cuda" )
pipeline.transformer = PeftModel.from_pretrained(
pipeline.transformer,
"lora-sd35-finetuned"
)
pipeline.transformer.eval ()
print ("LoRA 权重加载完成" )
6. 推理 推理方式与之前的 Stable Diffusion 3.5 开发指南完全相同,加载 LoRA 权重后可直接使用 pipeline 进行图像生成:
prompt = "A beautiful sunset over the mountains in my style"
generated_image = pipeline(
prompt=prompt,
negative_prompt="blur, low quality, distortion" ,
num_inference_steps=30 ,
guidance_scale=7.5
).images[0 ]
generated_image.save("generated_image.png" )
7. 最佳实践
7.1 数据准备
数据质量 :确保图像清晰(建议分辨率 ≥ 512x512),文本描述准确且详细
数据多样性 :包含多种场景、角度和风格的图像,避免过拟合
数据量 :建议至少准备 100-500 个样本,具体取决于任务复杂度
文本描述优化 :使用详细的关键词,例如 'a photo of a cat, detailed fur, blue eyes, sunny day' 而非 'a cat"
7.2 训练参数调整 参数 建议范围 说明 批次大小 2-8 根据 GPU 内存调整,A100 可使用 8-16 学习率 5e-5 - 2e-4 建议使用余弦退火调度,逐渐降低 LoRA 秩 r 4-32 小数据集使用小 r(4-8),大数据集使用大 r(16-32) 训练轮次 5-20 监控损失曲线,避免过拟合 权重衰减 1e-2 - 1e-3 防止过拟合,正则化模型
7.3 常见问题与解决方案 问题 可能原因 解决方案 生成图像模糊 训练轮次不足或学习率过低 增加训练轮次或提高学习率 过拟合(生成图像与训练集高度相似) 数据量不足或 LoRA 秩过大 增加数据量、减小 LoRA 秩或增加 dropout 训练速度慢 批次大小过大或使用全精度 减小批次大小或使用半精度(float16) 内存不足 模型过大或批次大小过大 使用更小的模型版本、减小批次大小或启用梯度检查点 生成图像与文本描述不符 文本描述质量差或 LoRA 影响过大 优化文本描述、调整 LoRA alpha 或减小 r 值
7.4 高级技巧
多 LoRA 组合 :同时加载多个 LoRA 权重,实现风格混合
LoRA 缩放 :加载时调整 LoRA 权重的缩放因子,控制风格强度
梯度检查点 :启用 gradient_checkpointing 减少内存占用
文本编码器微调 :在数据量充足时,可解冻部分文本编码器层进行微调
评估指标 :使用 FID、CLIP 分数等指标评估生成质量
总结 本章详细介绍了使用 LoRA 技术微调 Stable Diffusion 3.5 模型的完整流程,包括:
数据集准备与处理:创建图像 - 文本对数据集,实现自定义数据加载器
LoRA 微调原理:理解低秩适配器的工作机制和优势
模型加载与配置:加载预训练模型,配置 LoRA 参数
训练循环实现:实现 Flow Matching 训练逻辑,理解 SD 3.5 的特殊损失计算
权重保存与加载:保存和加载 LoRA 权重,实现模型复用
通过 LoRA 微调,您可以高效地定制 Stable Diffusion 3.5 模型,使其适应特定领域或风格的图像生成需求。在实际应用中,建议根据具体任务调整参数和流程,以获得最佳效果。
相关免费在线工具 加密/解密文本 使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
RSA密钥对生成器 生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
Mermaid 预览与可视化编辑 基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
curl 转代码 解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
Base64 字符串编码/解码 将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
Base64 文件转换器 将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online