Stable Diffusion 3.5 开发指南(三):Stable Diffusion 3.5 LoRA 微调

概述

在之前的章节中,我们学习了如何获取和调用 Stable Diffusion 3.5 模型,以及深入理解了其核心的 Flow Matching 机制。本章将聚焦于LoRA(Low-Rank Adaptation)微调技术,这是一种高效的模型定制方法,能够在保持原有模型性能的同时,仅通过少量参数更新即可实现特定任务的定制化。

1. 数据集准备

1.1 数据集格式

微调 Stable Diffusion 3.5 模型需要图像-文本对数据集,每个数据项应包含以下两个核心字段:

  • img_path:图像文件的路径(支持绝对路径或相对路径)
  • caption:与图像内容精准匹配的文本描述

示例 JSON 数据集格式

[{"img_path":"/path/to/image1.jpg","caption":"A beautiful sunset over the mountains"},{"img_path":"/path/to/image2.jpg","caption":"A group of people at a conference"}]

1.2 数据处理

为了方便加载和预处理数据,我们实现了一个自定义的 PyTorch 数据集类 StableDiffusionDataset。该类封装了以下核心功能:

  • 从 JSON 文件加载数据集元信息
  • 图像自动预处理(缩放、转换为张量、归一化)
  • 数据加载错误处理

数据集类实现

import json import os from PIL import Image import torch from torch.utils.data import Dataset from torchvision import transforms classStableDiffusionDataset(Dataset):def__init__(self, json_path):""" 初始化 Stable Diffusion 微调数据集 Args: json_path: JSON 文件路径,包含 img_path 和 caption 字段 """super().__init__()# 读取 JSON 文件withopen(json_path,'r', encoding='utf-8')as f: self.data = json.load(f)# 定义图像预处理 pipeline# 将图像调整为 512x512(SD 3.5 模型的默认输入尺寸),转换为张量并归一化到 [-1, 1] 范围 self.transform = transforms.Compose([ transforms.Resize((512,512)),# 调整图像大小为 512x512 transforms.ToTensor(),# 转换为张量 [0, 1] transforms.Normalize([0.5],[0.5])# 归一化到 [-1, 1]])def__len__(self):"""返回数据集样本数量"""returnlen(self.data)def__getitem__(self, idx):""" 获取单个数据样本 Args: idx: 样本索引 Returns: tuple: (image_tensor, caption) - image_tensor: 处理后的图像张量,形状为 [3, 512, 512] - caption: 文本描述字符串 """ item = self.data[idx]# 读取图像 img_path = item['img_path']# 检查文件是否存在ifnot os.path.exists(img_path):raise FileNotFoundError(f"图像文件不存在: {img_path}")# 打开并转换图像try:# 确保图像为 RGB 格式(丢弃 alpha 通道) image = Image.open(img_path).convert('RGB')except Exception as e:raise ValueError(f"无法读取图像 {img_path}: {str(e)}")# 应用预处理转换 image_tensor = self.transform(image)# 获取文本描述 caption = item['caption']return image_tensor, caption 

使用示例

以下是如何使用 StableDiffusionDataset 类的完整示例,包括数据集创建、样本查看和 DataLoader 构建:

# 创建数据集实例 dataset = StableDiffusionDataset("data.json")# 查看数据集大小print(f"数据集包含 {len(dataset)} 个样本")# 获取单个样本 image, caption = dataset[0]print(f"图像维度: {image.shape}")print(f"文本描述: {caption}")# 创建 DataLoader 用于批量训练from torch.utils.data import DataLoader dataloader = DataLoader( dataset, batch_size=4,# 批次大小,可根据 GPU 内存调整 shuffle=True,# 训练时打乱数据,增加随机性 num_workers=2,# 并行加载进程数,加速数据加载 pin_memory=True# 启用内存锁定,加速数据传输到 GPU)

2. LoRA 微调原理

LoRA 是一种参数高效微调技术,其核心思想是:

  1. 冻结原有模型:保持预训练模型的权重不变,避免灾难性遗忘
  2. 添加低秩适配器:在关键层(如注意力层)插入低秩矩阵对(A 和 B)
  3. 仅训练低秩矩阵:通过少量参数更新即可实现模型定制

这种方法的优势在于:

  • 训练参数仅为原有模型的 1%-5%,大幅降低内存占用
  • 训练速度显著提升,减少计算资源消耗
  • 微调结果易于保存和分享,单个 LoRA 权重文件通常仅几 MB
  • 支持多 LoRA 权重组合使用,实现灵活的风格控制

3. 模型加载与 LoRA 配置

3.1 加载预训练模型

首先需要加载 Stable Diffusion 3.5 预训练模型:

from diffusers import StableDiffusion3Pipeline import torch # 加载预训练模型 model_id ="stabilityai/stable-diffusion-3.5-large" pipeline = StableDiffusion3Pipeline.from_pretrained( model_id, torch_dtype=torch.float16,# 使用半精度(float16)加速计算并减少内存占用).to("cuda")# 移至 GPU 设备

3.2 配置 LoRA 参数

LoRA 的关键参数包括:

  • r:低秩矩阵的秩,控制适配器的容量(常用值:4, 8, 16, 32)
  • alpha:缩放因子,控制 LoRA 对模型的影响程度
  • target_modules:需要添加 LoRA 适配器的目标层
from peft import LoraConfig, get_peft_model # 配置 LoRA 参数 lora_config = LoraConfig( r=16,# 低秩矩阵的秩,r 越大,适配器容量越大,但参数也越多 lora_alpha=32,# 缩放因子,通常设置为 r 的 2 倍 target_modules=['to_k',# 注意力层的键(Key)投影层'to_q',# 注意力层的查询(Query)投影层'to_v',# 注意力层的值(Value)投影层], lora_dropout=0.05,# Dropout 率,防止过拟合 bias="none",# 不对偏置项应用 LoRA task_type="TEXT_TO_IMAGE"# 任务类型)# 为模型添加 LoRA 适配器 pipeline.transformer = get_peft_model(pipeline.transformer, lora_config)# 冻结不需要训练的组件,仅训练 LoRA 适配器 pipeline.vae.requires_grad_(False)# 冻结 VAE 编码器/解码器 pipeline.text_encoder.requires_grad_(False)# 冻结文本编码器 1 pipeline.text_encoder_2.requires_grad_(False)# 冻结文本编码器 2 pipeline.text_encoder_3.requires_grad_(False)# 冻结文本编码器 3# 打印可训练参数数量print("可训练参数数量:",sum(p.numel()for p in pipeline.transformer.parameters()if p.requires_grad))

4. 训练循环实现

4.1 定义训练参数

import torch.nn.functional as F from transformers import get_scheduler # 训练参数配置 epochs =5# 训练轮次 batch_size =4# 批次大小 learning_rate =1e-4# 学习率 weight_decay =1e-2# 权重衰减,防止过拟合# 优化器配置:仅优化可训练参数(LoRA 适配器参数) optimizer = torch.optim.AdamW( params=filter(lambda p: p.requires_grad, pipeline.transformer.parameters()), lr=learning_rate, weight_decay=weight_decay )# 学习率调度器:使用余弦退火调度,逐步降低学习率 num_training_steps =len(dataloader)* epochs scheduler = get_scheduler( name="cosine",# 调度器类型 optimizer=optimizer, num_warmup_steps=0,# 预热步数 num_training_steps=num_training_steps )

4.2 执行训练循环

# 设置模型为训练模式 pipeline.transformer.train() device ='cuda:0'# 指定 GPU 设备defcompute_density_for_timestep_sampling(batch_size, device):""" 基于正态分布的时间步采样,增加训练稳定性 Args: batch_size: 批次大小 device: 设备类型 Returns: torch.Tensor: 采样的时间步权重,形状为 [batch_size] """ u = torch.normal(0,1,(batch_size,), device=device) u = torch.sigmoid(u)# 将正态分布转换到 [0, 1] 区间return u defget_sigmas(timesteps, n_dim, device):""" 获取对应时间步的噪声方差(sigmas) Args: timesteps: 时间步张量 n_dim: 目标维度,用于广播 sigma device: 设备类型 Returns: torch.Tensor: 噪声方差,形状为 [batch_size, 1, 1, 1] """# 将调度器的时间步和 sigmas 移动到当前设备 scheduler_timesteps = pipeline.scheduler.timesteps.to(device) sigmas = pipeline.scheduler.sigmas.to(device)# 确保输入 timesteps 也在同一设备 timesteps = timesteps.to(device)# 查找每个时间步对应的索引 step_indices =[(scheduler_timesteps == t).nonzero().item()for t in timesteps] sigma = sigmas[step_indices].flatten()# 广播 sigma 到目标维度(适配 latent 形状)whilelen(sigma.shape)< n_dim: sigma = sigma.unsqueeze(-1)return sigma # 开始训练循环for epoch inrange(epochs):print(f"Epoch {epoch+1}/{epochs}")# 重置累积损失 total_loss =0for step,(images, captions)inenumerate(dataloader):# ----------------------------# 1. 编码文本(CLIP + T5)# ----------------------------with torch.no_grad(): prompt_embeds, _, pooled_prompt_embeds, _ = pipeline.encode_prompt( prompt=captions, prompt_2=captions, prompt_3=captions, device=device, negative_prompt='', negative_prompt_2='', negative_prompt_3='', do_classifier_free_guidance=True,)# ----------------------------# 2. 将图像编码为潜在表示(latent)# ----------------------------# 移动图像到 GPU 并转换为半精度 images = images.to(device, dtype=torch.float16)with torch.no_grad():# 使用 VAE 编码器将图像转换为潜在表示 vae_output = pipeline.vae.encode(images) latents = vae_output.latent_dist.sample()# 应用 VAE 配置的缩放和偏移因子 latents =(latents - pipeline.vae.config.shift_factor)* pipeline.vae.config.scaling_factor # ----------------------------# 3. 采样时间步(带权重方案)# ---------------------------- u = compute_density_for_timestep_sampling( batch_size=batch_size, device=device )# 将采样权重转换为时间步索引 indices =(u * pipeline.scheduler.config.num_train_timesteps).long() timesteps = pipeline.scheduler.timesteps.to(device)[indices]# ----------------------------# 4. Flow Matching:生成带噪声的潜在表示# ----------------------------# 获取对应时间步的噪声方差 sigmas = get_sigmas(timesteps, n_dim=latents.ndim, device=device)# 生成随机噪声 noise = torch.randn_like(latents, device=device)# 生成中间状态:(1-sigma)*latent + sigma*noise# 注意:SD 3.5 的 Flow Matching 插值方向与标准相反# 0 时刻是图像的压缩态(latents),1 时刻是噪声 noisy_latents =(1.0- sigmas)* latents + sigmas * noise # ----------------------------# 5. 预测流场(model_pred)# ----------------------------# 使用混合精度训练,加速计算并减少内存占用with torch.autocast("cuda"):# 模型预测平均速度(方向:从压缩态到噪声) model_pred = pipeline.transformer( hidden_states=noisy_latents, timestep=timesteps, encoder_hidden_states=prompt_embeds, pooled_projections=pooled_prompt_embeds, return_dict=False)[0]# 计算预测的 latent:当前位置 + 速度*时间(反向)# 模型预测的是平均速度,乘以 (-sigma) 表示反向移动 pred = model_pred *(-sigmas)+ noisy_latents # 计算 MSE 损失:预测 latent 与真实 latent 的差距 loss = F.mse_loss(pred, latents)# 反向传播与优化 optimizer.zero_grad()# 清空梯度 loss.backward()# 计算梯度 optimizer.step()# 更新参数 scheduler.step()# 更新学习率# 累积损失 total_loss += loss.item()# 打印训练日志if step %100==0: avg_loss = total_loss /(step +1)print(f"Step {step}, Loss: {loss.item():.4f}, Avg Loss: {avg_loss:.4f}")# 打印 epoch 日志 avg_epoch_loss = total_loss /len(dataloader)print(f"Epoch {epoch+1} 完成, 平均损失: {avg_epoch_loss:.4f}")

4.3 关于损失计算的说明

SD 3.5 的 Flow Matching 训练与标准 Flow Matching 有两点关键不同:

  1. 插值方向相反
    • 标准 Flow Matching:0 时刻是噪声,1 时刻是图像
    • SD 3.5:0 时刻是图像的压缩态(latents),1 时刻是噪声
    • 中间状态:(1.0 - sigmas) * latents + sigmas * noise
  2. 模型预测目标不同
    • 模型预测的是平均速度(方向:从压缩态到噪声)
    • 距离:model_pred * sigmas,平均速度乘以时间就是距离
    • 预测公式:pred = model_pred * (-sigmas) + noisy_latents
    • 损失计算:MSE(pred, latents),即中间点往回走一段距离的位置和起点的差距

这种设计能够更好地适应扩散模型的训练特性,提高生成质量和训练稳定性。

5. LoRA 权重保存与加载

5.1 保存 LoRA 权重

# 保存 LoRA 权重,仅保存可训练的 LoRA 参数 pipeline.transformer.save_pretrained("lora-sd35-finetuned")print("LoRA 权重保存完成")

5.2 加载 LoRA 权重

from peft import PeftModel # 加载预训练模型 pipeline = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.float16 ).to("cuda")# 加载 LoRA 权重到 transformer 组件 pipeline.transformer = PeftModel.from_pretrained( pipeline.transformer,"lora-sd35-finetuned")# 设置模型为推理模式 pipeline.transformer.eval()print("LoRA 权重加载完成")

6. 推理

推理方式与之前的 Stable Diffusion 3.5 开发指南(一)完全相同,加载 LoRA 权重后可直接使用 pipeline 进行图像生成:

# 示例:使用微调后的 LoRA 生成图像 prompt ="A beautiful sunset over the mountains in my style" generated_image = pipeline( prompt=prompt, negative_prompt="blur, low quality, distortion", num_inference_steps=30, guidance_scale=7.5).images[0]# 保存生成的图像 generated_image.save("generated_image.png")

7. 最佳实践

7.1 数据准备

  • 数据质量:确保图像清晰(建议分辨率 ≥ 512x512),文本描述准确且详细
  • 数据多样性:包含多种场景、角度和风格的图像,避免过拟合
  • 数据量:建议至少准备 100-500 个样本,具体取决于任务复杂度
  • 文本描述优化:使用详细的关键词,例如 “a photo of a cat, detailed fur, blue eyes, sunny day” 而非 “a cat”

7.2 训练参数调整

参数建议范围说明
批次大小2-8根据 GPU 内存调整,A100 可使用 8-16
学习率5e-5 - 2e-4建议使用余弦退火调度,逐渐降低
LoRA 秩 r4-32小数据集使用小 r(4-8),大数据集使用大 r(16-32)
训练轮次5-20监控损失曲线,避免过拟合
权重衰减1e-2 - 1e-3防止过拟合,正则化模型

7.3 常见问题与解决方案

问题可能原因解决方案
生成图像模糊训练轮次不足或学习率过低增加训练轮次或提高学习率
过拟合(生成图像与训练集高度相似)数据量不足或 LoRA 秩过大增加数据量、减小 LoRA 秩或增加 dropout
训练速度慢批次大小过大或使用全精度减小批次大小或使用半精度(float16)
内存不足模型过大或批次大小过大使用更小的模型版本、减小批次大小或启用梯度检查点
生成图像与文本描述不符文本描述质量差或 LoRA 影响过大优化文本描述、调整 LoRA alpha 或减小 r 值

7.4 高级技巧

  1. 多 LoRA 组合:同时加载多个 LoRA 权重,实现风格混合
  2. LoRA 缩放:加载时调整 LoRA 权重的缩放因子,控制风格强度
  3. 梯度检查点:启用 gradient_checkpointing 减少内存占用
  4. 文本编码器微调:在数据量充足时,可解冻部分文本编码器层进行微调
  5. 评估指标:使用 FID、CLIP 分数等指标评估生成质量

总结

本章详细介绍了使用 LoRA 技术微调 Stable Diffusion 3.5 模型的完整流程,包括:

  1. 数据集准备与处理:创建图像-文本对数据集,实现自定义数据加载器
  2. LoRA 微调原理:理解低秩适配器的工作机制和优势
  3. 模型加载与配置:加载预训练模型,配置 LoRA 参数
  4. 训练循环实现:实现 Flow Matching 训练逻辑,理解 SD 3.5 的特殊损失计算
  5. 权重保存与加载:保存和加载 LoRA 权重,实现模型复用

通过 LoRA 微调,您可以高效地定制 Stable Diffusion 3.5 模型,使其适应特定领域或风格的图像生成需求。在实际应用中,建议根据具体任务调整参数和流程,以获得最佳效果。

Read more

AI大模型核心概念解析:Token 究竟是什么?

在大模型(LLM)的世界里,token 是一个基础且重要的概念。接下来,让我们一文读懂大模型中的 token 究竟是什么。 一、token究竟是什么? 在大语言模型(LLM)中,Token 代表模型可以理解和生成的最小意义单位,是模型处理文本的基础单元。它就像是模型世界里的 “积木块”,模型通过对这些 “积木块” 的操作来理解和生成文本。根据所使用的特定标记化方案,Token 可以表示单词、单词的一部分,甚至只表示字符。 例如,对于英文文本,“apple” 可能是一个 Token,而对于中文文本,“苹果” 可能是一个 Token。但有时候,Token 并不完全等同于我们日常理解的单词或汉字,它还可能是单词的片段,比如 “playing” 可能被拆分为 “play” 和 “ing” 两个 Token。 为了让模型能够处理这些 Token,

AI一键生成专业技术路线图(课题研究/论文 技术路线图)

AI一键生成专业技术路线图(课题研究/论文 技术路线图)

工具地址:https://draw.anqstar.com/ 一、技术背景:计算机专业学生的“路线图痛点”,你是否也遇到过? 对于计算机专业的大学生而言,从课程设计、课程论文,到最终的毕业设计、毕业论文,“技术路线图”都是不可或缺的核心组成部分——它是梳理课题思路、明确研究步骤、展示技术逻辑的关键载体,直接影响作业/论文的完整性和专业性。 但实际操作中,绝大多数同学都会陷入这样的困境,尤其是涉及MySQL、SQL Server、SQL等数据库相关课题时,痛点更为突出: 1.1 小白入门难,无从下手 刚接触课设、毕设的同学,对“技术路线图”的规范的格式、核心要素一无所知,不清楚如何将SQL查询、MySQL数据库搭建、SQL Server数据存储等技术点,合理融入路线图的各个环节,常常对着空白画布发呆,浪费大量时间。 1.2 技术梳理乱,逻辑断层

相干伊辛机在医疗领域及医疗AI领域的应用前景分析

相干伊辛机在医疗领域及医疗AI领域的应用前景分析

引言:当量子退火遇见精准医疗 21世纪的医疗健康领域正经历着一场由数据驱动的深刻变革。从基因组学到医学影像,从电子病历到可穿戴设备,医疗数据正以指数级增长。然而,海量数据的背后是经典的“组合爆炸”难题——例如,药物分子中电子的量子态搜索、多模态医疗影像的特征匹配、个性化治疗方案的组合优化等,这些问题对经典计算机,甚至对传统的超级计算机而言,都构成了难以逾越的计算壁垒。 相干伊辛机(Coherent Ising Machine, CIM)作为一种基于量子光学和量子退火原理的新型计算范式,为解决这类组合优化问题提供了全新的物理路径。它不同于通用量子计算机(如超导门模型),CIM是专为寻找复杂伊辛模型基态而设计的专用量子处理器。本文将深入探讨CIM如何凭借其强大的并行搜索能力,在药物研发、精准诊断、个性化治疗以及医疗AI优化等领域,从计算底层赋能医疗科技的未来。 一、 相干伊辛机:从统计物理到量子计算引擎 要理解CIM在医疗领域的潜力,首先需要深入其物理内核,厘清它如何通过光的相干性来高效解决现实世界的复杂问题。 1. 伊辛模型:组合优化的“通用语言” 伊辛模型最初源于统计物理学

进阶实战:CLIProxyAPI Plus + OpenClaw 零配置结合,打造你的专属 24/7 AI 超级助手(保姆级 + 原理级教程)

进阶实战:CLIProxyAPI Plus + OpenClaw 零配置结合,打造你的专属 24/7 AI 超级助手(保姆级 + 原理级教程)

相关链接(仍在活跃):CLIProxyAPI Plus、OpenClaw 昨天我们已经用 Docker 一键部署了 CLIProxyAPI Plus(简称 CPA),生成了专属 API 密钥,并通过 http://你的IP:9999/v1 实现了 OpenAI 兼容端点。今天我们继续进阶:把这个代理完美对接 OpenClaw(开源个人 AI 助手,前身 Clawdbot),让 OpenClaw 通过你的 CPA 代理调用多账号 OpenAI 模型,实现 WhatsApp/Telegram/Slack 等消息渠道的自动任务执行、代码编写、邮件处理、日历管理等全自动化能力。 为什么必须结合? CPA 负责账号集中管理、额度自动切换、