跳到主要内容
极客日志极客日志面向AI+效率的开发者社区
首页博客GitHub 精选镜像工具UI配色美学隐私政策关于联系
搜索内容 / 工具 / 仓库 / 镜像...⌘K搜索
注册
博客列表
PythonAI算法

Stable Diffusion 的工程实现:训练、推理与 LoRA 微调

Stable Diffusion基于潜空间扩散,将图像通过VAE压缩至低维空间进行DDPM训练与推理,大幅降低计算成本。训练时在潜空间加噪,UNet结合CLIP文本嵌入预测噪声,用MSE损失优化。推理使用DDIM采样和CFG增强文本控制,50步去噪后VAE解码回图像。LoRA方法通过低秩适配器在注意力层微调,以百万参数实现风格迁移,显著节省显存。关键细节包括数据增强限于像素空间、时间步采样范围、CFG系数调节等。

日志猎手发布于 2026/6/300 浏览
Stable Diffusion 的工程实现:训练、推理与 LoRA 微调

SD 的思路来自 LDM 论文,把扩散模型从像素空间搬到了 VAE 的潜空间,计算量降得明显,同时用交叉注意力接住文本条件。下面按训练和推理两条线,把关键步骤和代码理一遍。

文章配图

论文:https://arxiv.org/pdf/2112.10752
代码:https://github.com/CompVis/latent-diffusion
补充复现(简化版):https://github.com/wenwenqqq/sd-demo

环境与数据准备

依赖

主力是 PyTorch,加上 diffusers、peft。基本够用。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

from diffusers import AutoencoderKL, CLIPTextModel, CLIPTokenizer, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.utils import logging

from peft import LoraConfig, get_peft_model, PeftModel

logging.set_verbosity_info()

数据集

图像-文本配对就行。图像统一缩放到 512×512,归一化到 [-1, 1],和 VAE 的输入对齐。文本后面再用 CLIP 编码。

数据预处理

基础封装

把图像路径和文本读进来,做 resize、toTensor、归一化。

class ImageTextDataset(Dataset):
    def __init__(self, image_dir, caption_csv, transform=None):
        self.image_dir = image_dir
        self.captions = pd.read_csv(caption_csv)
        self.transform = transform

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        image_name = self.captions.iloc[idx]['image_name']
        image_path = os.path.join(self.image_dir, image_name)
        image = Image.open(image_path).convert("RGB")
        text = self.captions.iloc[idx]['text'].strip()
        if self.transform is not None:
            image = self.transform(image)
        return {"image": image, "text": text}

image_transform = transforms.Compose([
    transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

base_dataset = ImageTextDataset(
    image_dir="dataset/images",
    caption_csv="dataset/captions.csv",
    transform=image_transform
)

增强与潜空间编码

数据增强放在像素空间做,比如随机翻转、锐度调整,之后再用 VAE 编码到潜空间。潜空间里面就别做增强了,会破坏 VAE 的分布。VAE 冻结,只用来编码。

class AugmentedLatentDataset(Dataset):
    def __init__(self, base_dataset, vae, augment_transform=None):
        self.base_dataset = base_dataset
        self.vae = vae
        self.augment_transform = augment_transform
        self.vae.eval()

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        data = self.base_dataset[idx]
        image = data["image"]
        text = data["text"]

        if self.augment_transform is not None:
            image = self.augment_transform(image)

        with torch.no_grad():
            latent = self.vae.encode(image.unsqueeze(0)).latent_dist.sample()
            latent = latent * 0.18215  # SD 固定缩放

        return {"latent": latent.squeeze(0), "text": text}

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae.requires_grad_(False)

augment_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAdjustSharpness(sharpness_factor=1.5, p=0.3),
])

latent_dataset = AugmentedLatentDataset(
    base_dataset=base_dataset,
    vae=vae,
    augment_transform=augment_transform
)

DataLoader

封装一下,训练时 shuffle 开着,drop_last 可以避免最后 batch 尺寸不一致。pin_memory 和多 worker 能快一些。

def create_dataloader(latent_dataset, batch_size=4, shuffle=True, drop_last=True):
    return DataLoader(
        dataset=latent_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        pin_memory=True,
        num_workers=4
    )

train_dataloader = create_dataloader(latent_dataset, batch_size=4, shuffle=True, drop_last=True)

训练:在潜空间里预测噪声

核心思路:每次都往干净的 latent 上加噪,让 UNet 根据时间步和文本去猜那个噪声,损失就用 MSE 算预测噪声和真实噪声的差距。整个过程全在潜空间,不通像素。

模型与优化器

CLIP 和 VAE 只读不训,UNet 是唯一要更新的。优化器用 AdamW + 线性 warmup 和衰减,这在 SD 里很常见。

def init_training_components():
    tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
    text_encoder.requires_grad_(False)

    unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
    unet.train()

    optimizer = optim.AdamW(
        unet.parameters(),
        lr=1e-4,
        betas=(0.9, 0.999),
        weight_decay=0.01
    )

    num_epochs = 10
    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=int(num_training_steps * 0.1),
        num_training_steps=num_training_steps
    )
    return tokenizer, text_encoder, unet, optimizer, lr_scheduler

tokenizer, text_encoder, unet, optimizer, lr_scheduler = init_training_components()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet.to(device)
text_encoder.to(device)
vae.to(device)

加噪

对每个 batch,随机抽一组时间步 t(1~1000),按 DDPM 的调度参数把噪声加进去。调度器用 diffusers.DDPMScheduler,配置 β 从 1e-4 到 0.02 的线性噪声表。

文章配图

其中:

文章配图

,

文章配图

是预先算好的噪声调度序列。代码里直接调调度器的 add_noise 就完成了。

from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_start=1e-4,
    beta_end=0.02,
    beta_schedule="linear"
)

def add_noise_to_latents(latents, timesteps, noise_scheduler):
    noise = torch.randn_like(latents)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    return noisy_latents, noise

# ---- 取一个 batch 示例 ----
batch = next(iter(train_dataloader))
latents = batch["latent"].to(device)
timesteps = torch.randint(1, noise_scheduler.num_train_timesteps, (latents.shape[0],), device=device)
noisy_latents, real_noise = add_noise_to_latents(latents, timesteps, noise_scheduler)

文本编码

用 CLIP 把文本映射成 [BS, 77, 768] 的嵌入。tokenizer 会补到 77 个 token,不够的补 0,超出的截断。CLIP 冻结,不参与梯度。

def encode_text(texts, tokenizer, text_encoder):
    inputs = tokenizer(
        texts,
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    ).to(text_encoder.device)
    with torch.no_grad():
        text_embeddings = text_encoder(**inputs).last_hidden_state
    return text_embeddings

text_embeddings = encode_text(batch["text"], tokenizer, text_encoder)

UNet 前向与损失

UNet 直接接受加噪 latent、时间步和文本嵌入,返回预测的噪声。时间步会经过位置编码和 MLP 广播到特征图,文本嵌入在交叉注意力层作为 K、V。损失只算 noise_pred 和 real_noise 的 MSE,简单直接。

def train_one_batch(noisy_latents, timesteps, text_embeddings, real_noise, unet, optimizer, lr_scheduler):
    noise_pred = unet(
        sample=noisy_latents,
        timestep=timesteps,
        encoder_hidden_states=text_embeddings
    ).sample
    loss = nn.MSELoss()(noise_pred, real_noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    return loss.item()

loss = train_one_batch(noisy_latents, timesteps, text_embeddings, real_noise, unet, optimizer, lr_scheduler)

完整训练循环

把上面的步骤串起来,每个 epoch 结束顺手保存 checkpoint。单个 GPU 跑 10 个 epoch 大概要一阵子,显存不够的话后面会聊到 LoRA。

def full_training_loop(num_epochs, train_dataloader, noise_scheduler, tokenizer, text_encoder, unet, optimizer, lr_scheduler):
    unet.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            latents = batch["latent"].to(device)
            texts = batch["text"]

            timesteps = torch.randint(1, noise_scheduler.num_train_timesteps, (latents.shape[0],), device=device)
            noisy_latents, real_noise = add_noise_to_latents(latents, timesteps, noise_scheduler)
            text_embeddings = encode_text(texts, tokenizer, text_encoder)

            batch_loss = train_one_batch(noisy_latents, timesteps, text_embeddings, real_noise, unet, optimizer, lr_scheduler)
            epoch_loss += batch_loss

            if (step + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{len(train_dataloader)}], Loss: {batch_loss:.4f}")

        avg_loss = epoch_loss / len(train_dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}] avg loss: {avg_loss:.4f}")
        torch.save(unet.state_dict(), f"unet_epoch_{epoch+1}.pth")
        print(f"Checkpoint saved to unet_epoch_{epoch+1}.pth")

num_epochs = 10
full_training_loop(num_epochs, train_dataloader, noise_scheduler, tokenizer, text_encoder, unet, optimizer, lr_scheduler)

推理:从纯噪声走回图像

推理时从 t=1000 的随机噪声开始,用 DDIM 采样跳过步骤,50 步足够。每步跑两次 UNet(有文本和无文本),按 CFG 公式插值,提升文本贴合度。最后 VAE 解码回像素空间。

准备组件

加载训练好的 UNet 权重,VAE 和 CLIP 直接用预训练版本。采样器用 DDIMScheduler,设置推理步数。

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").eval().requires_grad_(False)
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").eval().requires_grad_(False)

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.load_state_dict(torch.load("unet_epoch_10.pth"))
unet.eval().requires_grad_(False)

from diffusers import DDIMScheduler
sampler = DDIMScheduler.from_config(DDPMScheduler(num_train_timesteps=1000).config)
sampler.set_timesteps(num_inference_steps=50)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae.to(device); text_encoder.to(device); unet.to(device)

文本编码与 CFG

生成一对嵌入:真正 prompt 和空字符串(null prompt)。推理时 batch 扩展成 2,一次同时算,再拆开做 CFG。系数默认 7.5,调高能让图像更听话,但太高容易过饱和。

def encode_text_inference(prompt, tokenizer, text_encoder):
    def encode(p):
        inputs = tokenizer(p, max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt").to(device)
        with torch.no_grad():
            return text_encoder(**inputs).last_hidden_state
    return encode(prompt), encode("")

prompt = "a red cat sitting on a chair, high resolution"
text_emb, null_emb = encode_text_inference(prompt, tokenizer, text_encoder)
text_embeddings = torch.cat([null_emb, text_emb], dim=0)

逐步去噪

从纯噪声开始,按采样器的时间步循环。每一步用 torch.cat([latents]*2) 扩展输入,得到两个噪声预测,再用 CFG 公式合成去噪方向,交给采样器更新到下一个时间步。结束后用 VAE 解码。

def inference_step(latents, t, text_embeddings, unet, sampler, cfg_scale=7.5):
    latent_input = torch.cat([latents] * 2)
    noise_pred = unet(
        sample=latent_input,
        timestep=torch.tensor([t]*2, device=device),
        encoder_hidden_states=text_embeddings
    ).sample
    noise_pred_null, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_null + cfg_scale * (noise_pred_text - noise_pred_null)
    latents = sampler.step(noise_pred, t, latents).prev_sample
    return latents

latents = torch.randn((1,4,64,64), device=device)
for t in sampler.timesteps:
    latents = inference_step(latents, t, text_embeddings, unet, sampler, cfg_scale=7.5)

latents = latents / 0.18215
with torch.no_grad():
    image = vae.decode(latents).sample  # [1,3,512,512]
image = (image / 2 + 0.5).clamp(0, 1).cpu().permute(0,2,3,1).numpy()[0]
image = Image.fromarray((image*255).astype(np.uint8))
image.save("generated_image.jpg")
print("Saved generated_image.jpg")

LoRA 微调

全量训练 UNet 显存消耗太大,单卡 40G 起步。LoRA 只在注意力层的 Q、V 投影上挂低秩矩阵,可训参数降到百万级,显存省很多,适合风格微调。

使用 PEFT 库

定义 LoraConfig,挂到 UNet 上,冻结原有权重。训练循环和前面一样,只是 optimizer 只动 LoRA 参数。

def init_lora_training():
    tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").requires_grad_(False)
    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").requires_grad_(False)
    unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet").requires_grad_(False)

    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none"
    )
    unet = get_peft_model(unet, lora_config)
    unet.print_trainable_parameters()

    optimizer = optim.AdamW(unet.parameters(), lr=5e-5, betas=(0.9,0.999), weight_decay=0.01)
    num_epochs = 5
    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_scheduler("linear", optimizer, num_warmup_steps=int(num_training_steps*0.1), num_training_steps=num_training_steps)
    return tokenizer, text_encoder, vae, unet, optimizer, lr_scheduler

tokenizer, text_encoder, vae, unet, optimizer, lr_scheduler = init_lora_training()
full_training_loop(5, train_dataloader, noise_scheduler, tokenizer, text_encoder, unet, optimizer, lr_scheduler)
unet.save_pretrained("lora_weights")

LoRA 推理

加载基础 UNet 后,用 PeftModel.from_pretrained 挂上适配器,其余流程不变。

def lora_inference(prompt, lora_path, cfg_scale=7.5):
    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").eval().requires_grad_(False)
    tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").eval().requires_grad_(False)
    unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
    unet = PeftModel.from_pretrained(unet, lora_path).eval().requires_grad_(False)

    sampler = DDIMScheduler.from_config(DDPMScheduler(num_train_timesteps=1000).config)
    sampler.set_timesteps(50)

    # 复用前面的推理循环
    text_emb, null_emb = encode_text_inference(prompt, tokenizer, text_encoder)
    text_embeddings = torch.cat([null_emb, text_emb], dim=0)
    latents = torch.randn((1,4,64,64), device=device)
    for t in sampler.timesteps:
        latents = inference_step(latents, t, text_embeddings, unet, sampler, cfg_scale)
    # 解码保存
    latents = latents / 0.18215
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0,1).cpu().permute(0,2,3,1).numpy()[0]
    return Image.fromarray((image*255).astype(np.uint8))

img = lora_inference("a red cat, lora style", "lora_weights")
img.save("lora_output.jpg")

踩坑备忘

  • 显存不足:降分辨率、开 gradient checkpointing,或者直接用 LoRA。
  • Loss 降不下来:检查时间步采样范围、噪声调度器参数和学习率。
  • 生成图不跟 prompt:调大 CFG_scale(7~10),加点推理步数,确认文本编码没问题。
  • LoRA 不生效:确认 target_modules 命中 UNet 的注意力投影层,秩 r 不要太小。

SD 里绝大部分计算都在潜空间,VAE 负责压缩,UNet 负责猜噪声,几步走通以后,剩下的就是调参和扩展了。

目录

  1. 环境与数据准备
  2. 依赖
  3. 数据集
  4. 数据预处理
  5. 基础封装
  6. 增强与潜空间编码
  7. DataLoader
  8. 训练:在潜空间里预测噪声
  9. 模型与优化器
  10. 加噪
  11. ---- 取一个 batch 示例 ----
  12. 文本编码
  13. UNet 前向与损失
  14. 完整训练循环
  15. 推理:从纯噪声走回图像
  16. 准备组件
  17. 文本编码与 CFG
  18. 逐步去噪
  19. LoRA 微调
  20. 使用 PEFT 库
  21. LoRA 推理
  22. 踩坑备忘
  • 免费图片AI生成工具免费生成了解详情
  • Magick API 一键接入全球大模型注册送1000万token查看
  • 免费图片视频在线生成30秒,将你的创意变成现实开始设计
  • X/Twitter免费视频下载器免登陆无限额度免费视频解析下载了解详情
  • 100+免费在线小游戏爽一把
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • 我的 Java 开发环境搭建手记:从 JDK 到 Hello World
  • C++ 面试通关:语法基础、内存管理与类设计
  • 前端面试深度解析:核心概念与代码实践
  • Aurora 8B/10B 配置实战:从物理层到共享逻辑的避坑笔记
  • 大型模型评估的六个关键指标
  • PX4 Offboard 控制实战:从飞行模式理解到 ROS 轨迹跟踪
  • 移植 3D 封面画廊到 Android TV
  • 浏览器里用微信网页版?这个开源插件帮你绕开限制
  • 链表细节与 Java LinkedList 实战
  • Llama-3.2V-11B-cot 读胸片实测:推理过程、准确率与落地取舍
  • 用 DRF 搞定企业 API:从视图到监控的实战经验
  • 从零搭建在线投稿系统:SSM + Vue 实战笔记
  • MCP AI Copilot 运维实践:从智能告警到故障自愈的量化复盘
  • 从表单到 JSON:Spring Boot 前后端交互三案例
  • Claude Skill-Creator 内部解读:如何把 AI 技能开发做成工程循环
  • 实际项目里用了用 Copilot、Comate 和通义灵码,聊点真实感受
  • Linux 命名管道 FIFO 实战:跨进程通信与常见坑
  • 用 Nginx 部署 Vue 项目全过程
  • 我用过的7款AI写小说工具:加上这套SOP,终于不卡文了
  • 从零部署 OpenClaw:接入 QQ 的全流程踩坑记录

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online