跳到主要内容
极客日志极客日志面向AI+效率的开发者社区
首页博客GitHub 精选镜像工具UI配色美学隐私政策关于联系
搜索内容 / 工具 / 仓库 / 镜像...⌘K搜索
注册
博客列表
PythonAI算法

Stable Diffusion 3.5 LoRA 微调指南

Stable Diffusion 3.5 LoRA 微调技术详解。涵盖数据集准备、LoRA 原理、模型加载配置、训练循环实现及权重保存加载。重点解析 Flow Matching 机制下的损失计算与时间步采样策略,提供最佳实践与常见问题解决方案,助力高效定制模型风格。

Eee_123发布于 2026/4/10更新于 2026/6/218 浏览

概述

在之前的章节中,我们学习了如何获取和调用 Stable Diffusion 3.5 模型,以及深入理解了其核心的 Flow Matching 机制。本章将聚焦于LoRA(Low-Rank Adaptation)微调技术,这是一种高效的模型定制方法,能够在保持原有模型性能的同时,仅通过少量参数更新即可实现特定任务的定制化。

1. 数据集准备

1.1 数据集格式

微调 Stable Diffusion 3.5 模型需要图像 - 文本对数据集,每个数据项应包含以下两个核心字段:

  • img_path:图像文件的路径(支持绝对路径或相对路径)
  • caption:与图像内容精准匹配的文本描述
示例 JSON 数据集格式
[
    {"img_path": "/path/to/image1.jpg", "caption": "A beautiful sunset over the mountains"},
    {"img_path": "/path/to/image2.jpg", "caption": "A group of people at a conference"}
]

1.2 数据处理

为了方便加载和预处理数据,我们实现了一个自定义的 PyTorch 数据集类 StableDiffusionDataset。该类封装了以下核心功能:

  • 从 JSON 文件加载数据集元信息
  • 图像自动预处理(缩放、转换为张量、归一化)
  • 数据加载错误处理
数据集类实现
import json
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class StableDiffusionDataset(Dataset):
    def __init__(self, json_path):
        """
        初始化 Stable Diffusion 微调数据集
        Args:
            json_path: JSON 文件路径,包含 img_path 和 caption 字段
        """
        super().__init__()
        # 读取 JSON 文件
        with (json_path, , encoding=)  f:
            .data = json.load(f)
        
        
        .transform = transforms.Compose([
            transforms.Resize((, )),  
            transforms.ToTensor(),  
            transforms.Normalize([], [])  
        ])

     ():
        
         (.data)

     ():
        
        item = .data[idx]
        
        img_path = item[]
        
          os.path.exists(img_path):
             FileNotFoundError()
        
        :
            
            image = Image.(img_path).convert()
         Exception  e:
             ValueError()
        
        image_tensor = .transform(image)
        
        caption = item[]
         image_tensor, caption
open
'r'
'utf-8'
as
self
# 定义图像预处理 pipeline
# 将图像调整为 512x512(SD 3.5 模型的默认输入尺寸),转换为张量并归一化到 [-1, 1] 范围
self
512
512
# 调整图像大小为 512x512
# 转换为张量 [0, 1]
0.5
0.5
# 归一化到 [-1, 1]
def
__len__
self
"""返回数据集样本数量"""
return
len
self
def
__getitem__
self, idx
""" 获取单个数据样本 Args: idx: 样本索引 Returns: tuple: (image_tensor, caption) - image_tensor: 处理后的图像张量,形状为 [3, 512, 512] - caption: 文本描述字符串 """
self
# 读取图像
'img_path'
# 检查文件是否存在
if
not
raise
f"图像文件不存在:{img_path}"
# 打开并转换图像
try
# 确保图像为 RGB 格式(丢弃 alpha 通道)
open
'RGB'
except
as
raise
f"无法读取图像 {img_path}: {str(e)}"
# 应用预处理转换
self
# 获取文本描述
'caption'
return
使用示例

以下是如何使用 StableDiffusionDataset 类的完整示例,包括数据集创建、样本查看和 DataLoader 构建:

# 创建数据集实例
dataset = StableDiffusionDataset("data.json")
# 查看数据集大小
print(f"数据集包含 {len(dataset)} 个样本")
# 获取单个样本
image, caption = dataset[0]
print(f"图像维度:{image.shape}")
print(f"文本描述:{caption}")
# 创建 DataLoader 用于批量训练
from torch.utils.data import DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=4,  # 批次大小,可根据 GPU 内存调整
    shuffle=True,  # 训练时打乱数据,增加随机性
    num_workers=2,  # 并行加载进程数,加速数据加载
    pin_memory=True  # 启用内存锁定,加速数据传输到 GPU
)

2. LoRA 微调原理

LoRA 是一种参数高效微调技术,其核心思想是:

  1. 冻结原有模型:保持预训练模型的权重不变,避免灾难性遗忘
  2. 添加低秩适配器:在关键层(如注意力层)插入低秩矩阵对(A 和 B)
  3. 仅训练低秩矩阵:通过少量参数更新即可实现模型定制

这种方法的优势在于:

  • 训练参数仅为原有模型的 1%-5%,大幅降低内存占用
  • 训练速度显著提升,减少计算资源消耗
  • 微调结果易于保存和分享,单个 LoRA 权重文件通常仅几 MB
  • 支持多 LoRA 权重组合使用,实现灵活的风格控制

3. 模型加载与 LoRA 配置

3.1 加载预训练模型

首先需要加载 Stable Diffusion 3.5 预训练模型:

from diffusers import StableDiffusion3Pipeline
import torch

# 加载预训练模型
model_id = "stabilityai/stable-diffusion-3.5-large"
pipeline = StableDiffusion3Pipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # 使用半精度(float16)加速计算并减少内存占用
).to("cuda")  # 移至 GPU 设备

3.2 配置 LoRA 参数

LoRA 的关键参数包括:

  • r:低秩矩阵的秩,控制适配器的容量(常用值:4, 8, 16, 32)
  • alpha:缩放因子,控制 LoRA 对模型的影响程度
  • target_modules:需要添加 LoRA 适配器的目标层
from peft import LoraConfig, get_peft_model

# 配置 LoRA 参数
lora_config = LoraConfig(
    r=16,  # 低秩矩阵的秩,r 越大,适配器容量越大,但参数也越多
    lora_alpha=32,  # 缩放因子,通常设置为 r 的 2 倍
    target_modules=['to_k',  # 注意力层的键(Key)投影层
                    'to_q',  # 注意力层的查询(Query)投影层
                    'to_v'],  # 注意力层的值(Value)投影层
    lora_dropout=0.05,  # Dropout 率,防止过拟合
    bias="none",  # 不对偏置项应用 LoRA
    task_type="TEXT_TO_IMAGE"  # 任务类型
)

# 为模型添加 LoRA 适配器
pipeline.transformer = get_peft_model(pipeline.transformer, lora_config)

# 冻结不需要训练的组件,仅训练 LoRA 适配器
pipeline.vae.requires_grad_(False)  # 冻结 VAE 编码器/解码器
pipeline.text_encoder.requires_grad_(False)  # 冻结文本编码器 1
pipeline.text_encoder_2.requires_grad_(False)  # 冻结文本编码器 2
pipeline.text_encoder_3.requires_grad_(False)  # 冻结文本编码器 3

# 打印可训练参数数量
print("可训练参数数量:", sum(p.numel() for p in pipeline.transformer.parameters() if p.requires_grad))

4. 训练循环实现

4.1 定义训练参数

import torch.nn.functional as F
from transformers import get_scheduler

# 训练参数配置
epochs = 5  # 训练轮次
batch_size = 4  # 批次大小
learning_rate = 1e-4  # 学习率
weight_decay = 1e-2  # 权重衰减,防止过拟合

# 优化器配置:仅优化可训练参数(LoRA 适配器参数)
optimizer = torch.optim.AdamW(
    params=filter(lambda p: p.requires_grad, pipeline.transformer.parameters()),
    lr=learning_rate,
    weight_decay=weight_decay
)

# 学习率调度器:使用余弦退火调度,逐步降低学习率
num_training_steps = len(dataloader) * epochs
scheduler = get_scheduler(
    name="cosine",  # 调度器类型
    optimizer=optimizer,
    num_warmup_steps=0,  # 预热步数
    num_training_steps=num_training_steps
)

4.2 执行训练循环

# 设置模型为训练模式
pipeline.transformer.train()
device = 'cuda:0'  # 指定 GPU 设备

def compute_density_for_timestep_sampling(batch_size, device):
    """
    基于正态分布的时间步采样,增加训练稳定性
    Args:
        batch_size: 批次大小
        device: 设备类型
    Returns:
        torch.Tensor: 采样的时间步权重,形状为 [batch_size]
    """
    u = torch.normal(0, 1, (batch_size,), device=device)
    u = torch.sigmoid(u)  # 将正态分布转换到 [0, 1] 区间
    return u

def get_sigmas(timesteps, n_dim, device):
    """
    获取对应时间步的噪声方差(sigmas)
    Args:
        timesteps: 时间步张量
        n_dim: 目标维度,用于广播 sigma
        device: 设备类型
    Returns:
        torch.Tensor: 噪声方差,形状为 [batch_size, 1, 1, 1]
    """
    # 将调度器的时间步和 sigmas 移动到当前设备
    scheduler_timesteps = pipeline.scheduler.timesteps.to(device)
    sigmas = pipeline.scheduler.sigmas.to(device)
    # 确保输入 timesteps 也在同一设备
    timesteps = timesteps.to(device)
    # 查找每个时间步对应的索引
    step_indices = [(scheduler_timesteps == t).nonzero().item() for t in timesteps]
    sigma = sigmas[step_indices].flatten()
    # 广播 sigma 到目标维度(适配 latent 形状)
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma

# 开始训练循环
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    # 重置累积损失
    total_loss = 0
    for step, (images, captions) in enumerate(dataloader):
        # ----------------------------
        # 1. 编码文本(CLIP + T5)
        # ----------------------------
        with torch.no_grad():
            prompt_embeds, _, pooled_prompt_embeds, _ = pipeline.encode_prompt(
                prompt=captions,
                prompt_2=captions,
                prompt_3=captions,
                device=device,
                negative_prompt='',
                negative_prompt_2='',
                negative_prompt_3='',
                do_classifier_free_guidance=True,
            )
        # ----------------------------
        # 2. 将图像编码为潜在表示(latent)
        # ----------------------------
        # 移动图像到 GPU 并转换为半精度
        images = images.to(device, dtype=torch.float16)
        with torch.no_grad():
            # 使用 VAE 编码器将图像转换为潜在表示
            vae_output = pipeline.vae.encode(images)
            latents = vae_output.latent_dist.sample()
            # 应用 VAE 配置的缩放和偏移因子
            latents = (latents - pipeline.vae.config.shift_factor) * pipeline.vae.config.scaling_factor
        # ----------------------------
        # 3. 采样时间步(带权重方案)
        # ----------------------------
        u = compute_density_for_timestep_sampling(
            batch_size=batch_size,
            device=device
        )
        # 将采样权重转换为时间步索引
        indices = (u * pipeline.scheduler.config.num_train_timesteps).long()
        timesteps = pipeline.scheduler.timesteps.to(device)[indices]
        # ----------------------------
        # 4. Flow Matching:生成带噪声的潜在表示
        # ----------------------------
        # 获取对应时间步的噪声方差
        sigmas = get_sigmas(timesteps, n_dim=latents.ndim, device=device)
        # 生成随机噪声
        noise = torch.randn_like(latents, device=device)
        # 生成中间状态:(1-sigma)*latent + sigma*noise
        # 注意:SD 3.5 的 Flow Matching 插值方向与标准相反
        # 0 时刻是图像的压缩态(latents),1 时刻是噪声
        noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
        # ----------------------------
        # 5. 预测流场(model_pred)
        # ----------------------------
        # 使用混合精度训练,加速计算并减少内存占用
        with torch.autocast("cuda"):
            # 模型预测平均速度(方向:从压缩态到噪声)
            model_pred = pipeline.transformer(
                hidden_states=noisy_latents,
                timestep=timesteps,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                return_dict=False
            )[0]
            # 计算预测的 latent:当前位置 + 速度*时间(反向)
            # 模型预测的是平均速度,乘以 (-sigma) 表示反向移动
            pred = model_pred * (-sigmas) + noisy_latents
            # 计算 MSE 损失:预测 latent 与真实 latent 的差距
            loss = F.mse_loss(pred, latents)
        # 反向传播与优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        # 累积损失
        total_loss += loss.item()
        # 打印训练日志
        if step % 100 == 0:
            avg_loss = total_loss / (step + 1)
            print(f"Step {step}, Loss: {loss.item():.4f}, Avg Loss: {avg_loss:.4f}")
    # 打印 epoch 日志
    avg_epoch_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1} 完成,平均损失:{avg_epoch_loss:.4f}")

4.3 关于损失计算的说明

SD 3.5 的 Flow Matching 训练与标准 Flow Matching 有两点关键不同:

  1. 插值方向相反:
    • 标准 Flow Matching:0 时刻是噪声,1 时刻是图像
    • SD 3.5:0 时刻是图像的压缩态(latents),1 时刻是噪声
    • 中间状态:(1.0 - sigmas) * latents + sigmas * noise
  2. 模型预测目标不同:
    • 模型预测的是平均速度(方向:从压缩态到噪声)
    • 距离:model_pred * sigmas,平均速度乘以时间就是距离
    • 预测公式:pred = model_pred * (-sigmas) + noisy_latents
    • 损失计算:MSE(pred, latents),即中间点往回走一段距离的位置和起点的差距

这种设计能够更好地适应扩散模型的训练特性,提高生成质量和训练稳定性。

5. LoRA 权重保存与加载

5.1 保存 LoRA 权重

# 保存 LoRA 权重,仅保存可训练的 LoRA 参数
pipeline.transformer.save_pretrained("lora-sd35-finetuned")
print("LoRA 权重保存完成")

5.2 加载 LoRA 权重

from peft import PeftModel

# 加载预训练模型
pipeline = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    torch_dtype=torch.float16
).to("cuda")

# 加载 LoRA 权重到 transformer 组件
pipeline.transformer = PeftModel.from_pretrained(
    pipeline.transformer,
    "lora-sd35-finetuned"
)

# 设置模型为推理模式
pipeline.transformer.eval()
print("LoRA 权重加载完成")

6. 推理

推理方式与之前的 Stable Diffusion 3.5 开发指南完全相同,加载 LoRA 权重后可直接使用 pipeline 进行图像生成:

# 示例:使用微调后的 LoRA 生成图像
prompt = "A beautiful sunset over the mountains in my style"
generated_image = pipeline(
    prompt=prompt,
    negative_prompt="blur, low quality, distortion",
    num_inference_steps=30,
    guidance_scale=7.5
).images[0]
# 保存生成的图像
generated_image.save("generated_image.png")

7. 最佳实践

7.1 数据准备

  • 数据质量:确保图像清晰(建议分辨率 ≥ 512x512),文本描述准确且详细
  • 数据多样性:包含多种场景、角度和风格的图像,避免过拟合
  • 数据量:建议至少准备 100-500 个样本,具体取决于任务复杂度
  • 文本描述优化:使用详细的关键词,例如 'a photo of a cat, detailed fur, blue eyes, sunny day' 而非 'a cat"

7.2 训练参数调整

参数建议范围说明
批次大小2-8根据 GPU 内存调整,A100 可使用 8-16
学习率5e-5 - 2e-4建议使用余弦退火调度,逐渐降低
LoRA 秩 r4-32小数据集使用小 r(4-8),大数据集使用大 r(16-32)
训练轮次5-20监控损失曲线,避免过拟合
权重衰减1e-2 - 1e-3防止过拟合,正则化模型

7.3 常见问题与解决方案

问题可能原因解决方案
生成图像模糊训练轮次不足或学习率过低增加训练轮次或提高学习率
过拟合(生成图像与训练集高度相似)数据量不足或 LoRA 秩过大增加数据量、减小 LoRA 秩或增加 dropout
训练速度慢批次大小过大或使用全精度减小批次大小或使用半精度(float16)
内存不足模型过大或批次大小过大使用更小的模型版本、减小批次大小或启用梯度检查点
生成图像与文本描述不符文本描述质量差或 LoRA 影响过大优化文本描述、调整 LoRA alpha 或减小 r 值

7.4 高级技巧

  1. 多 LoRA 组合:同时加载多个 LoRA 权重,实现风格混合
  2. LoRA 缩放:加载时调整 LoRA 权重的缩放因子,控制风格强度
  3. 梯度检查点:启用 gradient_checkpointing 减少内存占用
  4. 文本编码器微调:在数据量充足时,可解冻部分文本编码器层进行微调
  5. 评估指标:使用 FID、CLIP 分数等指标评估生成质量

总结

本章详细介绍了使用 LoRA 技术微调 Stable Diffusion 3.5 模型的完整流程,包括:

  1. 数据集准备与处理:创建图像 - 文本对数据集,实现自定义数据加载器
  2. LoRA 微调原理:理解低秩适配器的工作机制和优势
  3. 模型加载与配置:加载预训练模型,配置 LoRA 参数
  4. 训练循环实现:实现 Flow Matching 训练逻辑,理解 SD 3.5 的特殊损失计算
  5. 权重保存与加载:保存和加载 LoRA 权重,实现模型复用

通过 LoRA 微调,您可以高效地定制 Stable Diffusion 3.5 模型,使其适应特定领域或风格的图像生成需求。在实际应用中,建议根据具体任务调整参数和流程,以获得最佳效果。

目录

  1. 概述
  2. 1. 数据集准备
  3. 1.1 数据集格式
  4. 示例 JSON 数据集格式
  5. 1.2 数据处理
  6. 数据集类实现
  7. 使用示例
  8. 创建数据集实例
  9. 查看数据集大小
  10. 获取单个样本
  11. 创建 DataLoader 用于批量训练
  12. 2. LoRA 微调原理
  13. 3. 模型加载与 LoRA 配置
  14. 3.1 加载预训练模型
  15. 加载预训练模型
  16. 3.2 配置 LoRA 参数
  17. 配置 LoRA 参数
  18. 为模型添加 LoRA 适配器
  19. 冻结不需要训练的组件,仅训练 LoRA 适配器
  20. 打印可训练参数数量
  21. 4. 训练循环实现
  22. 4.1 定义训练参数
  23. 训练参数配置
  24. 优化器配置:仅优化可训练参数(LoRA 适配器参数)
  25. 学习率调度器:使用余弦退火调度,逐步降低学习率
  26. 4.2 执行训练循环
  27. 设置模型为训练模式
  28. 开始训练循环
  29. 4.3 关于损失计算的说明
  30. 5. LoRA 权重保存与加载
  31. 5.1 保存 LoRA 权重
  32. 保存 LoRA 权重,仅保存可训练的 LoRA 参数
  33. 5.2 加载 LoRA 权重
  34. 加载预训练模型
  35. 加载 LoRA 权重到 transformer 组件
  36. 设置模型为推理模式
  37. 6. 推理
  38. 示例:使用微调后的 LoRA 生成图像
  39. 保存生成的图像
  40. 7. 最佳实践
  41. 7.1 数据准备
  42. 7.2 训练参数调整
  43. 7.3 常见问题与解决方案
  44. 7.4 高级技巧
  45. 总结
  • 💰 8折买阿里云服务器限时8折了解详情
  • Magick API 一键接入全球大模型注册送1000万token查看
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • GitHub Copilot:Python 开发者的 AI 助手
  • tablib:Python 数据处理与格式转换库详解
  • MySQL 数据库核心技术与实践指南
  • 选择排序算法原理、实现与复杂度分析
  • Vue Router 进阶实战:导航守卫、嵌套路由与状态管理
  • 国内大模型公司面试经验总结与技术要点分析
  • 2025 亚洲 WEB3 商业生态创新峰会将于香港举行
  • AI 元人文:自感概念与 DOS 模型深度解析
  • 大模型在医疗行业中的应用与技术解析
  • GraphRAG 提升 LLM 摘要总结能力的原理与实践
  • JavaScript 事件循环进阶:requestAnimationFrame 与 Web Workers
  • MATLAB 图像处理:冈萨雷斯 DIPUM 工具箱功能详解与实战
  • 数电设计步骤与 FPGA 实现的本质区别
  • 基于大模型的 ChatBI 实现与 Text-to-SQL 技术路线演进
  • Linux 内核 list_for_each_entry 链表遍历详解
  • C++ 面向对象核心:多态详解
  • 基于大语言模型和 RAG 的知识库问答系统
  • GitHub Copilot 主流模型对比与高效编程指南
  • WebView 并发初始化竞争风险分析
  • Kylin/Linux 服务器健康一键巡检脚本

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online