SD 的思路来自 LDM 论文,把扩散模型从像素空间搬到了 VAE 的潜空间,计算量降得明显,同时用交叉注意力接住文本条件。下面按训练和推理两条线,把关键步骤和代码理一遍。

论文:https://arxiv.org/pdf/2112.10752
代码:https://github.com/CompVis/latent-diffusion
补充复现(简化版):https://github.com/wenwenqqq/sd-demo
环境与数据准备
依赖
主力是 PyTorch,加上 diffusers、transformers(CLIP 的 tokenizer 和文本编码器由它提供)、peft;读数据和存图还会用到 pandas、Pillow 和 numpy。基本够用。
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.utils import logging
from peft import LoraConfig, get_peft_model, PeftModel
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPTextModel, CLIPTokenizer
# Raise the diffusers logger to INFO verbosity.
logging.set_verbosity_info()
数据集
图像-文本配对就行。图像统一缩放到 512×512,归一化到 [-1, 1],和 VAE 的输入对齐。文本后面再用 CLIP 编码。
数据预处理
基础封装
把图像路径和文本读进来,做 resize、toTensor、归一化。
class ImageTextDataset(Dataset):
    """Image/caption pairs backed by an image directory and a caption CSV.

    The CSV is read with ``pd.read_csv`` and must contain an ``image_name``
    column (file name joined onto ``image_dir``) and a ``text`` column.
    """

    def __init__(self, image_dir, caption_csv, transform=None):
        """Load the caption table; images are opened lazily per item.

        Args:
            image_dir: Directory that contains the image files.
            caption_csv: Path to the CSV with ``image_name`` / ``text`` columns.
            transform: Optional callable applied to each RGB PIL image.
        """
        self.image_dir = image_dir
        self.captions = pd.read_csv(caption_csv)
        self.transform = transform

    def __len__(self):
        """Return the number of caption rows."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return ``{"image": ..., "text": ...}`` for row ``idx``."""
        row = self.captions.iloc[idx]
        image_path = os.path.join(self.image_dir, row["image_name"])
        # Close the file handle as soon as the pixels are converted.
        with Image.open(image_path) as handle:
            image = handle.convert("RGB")
        # An empty CSV cell is read back as NaN (a float), which has no
        # .strip(); treat it as an empty caption instead of crashing.
        raw_text = row["text"]
        text = "" if pd.isna(raw_text) else str(raw_text).strip()
        if self.transform is not None:
            image = self.transform(image)
        return {"image": image, "text": text}
# Pixel preprocessing: resize to 512x512, convert to a tensor, then apply
# (x - 0.5) / 0.5 per channel.
_TARGET_SIZE = (512, 512)
_HALF = [0.5, 0.5, 0.5]
image_transform = transforms.Compose(
    [
        transforms.Resize(_TARGET_SIZE, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(mean=_HALF, std=_HALF),
    ]
)

# Base dataset of preprocessed images paired with their captions.
base_dataset = ImageTextDataset("dataset/images", "dataset/captions.csv", transform=image_transform)
增强与潜空间编码
数据增强放在像素空间做,比如随机翻转、锐度调整,之后再用 VAE 编码到潜空间。潜空间里面就别做增强了,会破坏 VAE 的分布。VAE 冻结,只用来编码。
class AugmentedLatentDataset(Dataset):
    """Wrap an image/text dataset and yield VAE latents instead of pixels.

    Each item of ``base_dataset`` must be a dict with an ``"image"`` tensor
    and a ``"text"`` entry. Augmentation is applied in pixel space before the
    image is encoded by the (frozen) VAE.
    """

    def __init__(self, base_dataset, vae, augment_transform=None):
        """Store the wrapped dataset and put the VAE into eval mode.

        Args:
            base_dataset: Indexable dataset of ``{"image", "text"}`` dicts.
            vae: Model exposing ``encode(x).latent_dist.sample()``.
            augment_transform: Optional callable applied to the image tensor.
        """
        self.base_dataset = base_dataset
        self.vae = vae
        self.augment_transform = augment_transform
        self.vae.eval()

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.base_dataset)

    def _scaling_factor(self):
        """Return the VAE's configured latent scale, or the SD v1 constant."""
        config = getattr(self.vae, "config", None)
        return getattr(config, "scaling_factor", 0.18215)

    def __getitem__(self, idx):
        """Return ``{"latent": scaled latent on CPU, "text": caption}``."""
        data = self.base_dataset[idx]
        image, text = data["image"], data["text"]
        if self.augment_transform is not None:
            image = self.augment_transform(image)
        pixels = image.unsqueeze(0)
        # The VAE may have been moved (e.g. to the GPU) after this dataset was
        # built; feed it inputs on its own device and dtype.
        reference = next(self.vae.parameters(), None)
        if reference is not None:
            pixels = pixels.to(device=reference.device, dtype=reference.dtype)
        with torch.no_grad():
            latent = self.vae.encode(pixels).latent_dist.sample()
        latent = latent * self._scaling_factor()
        # Hand CPU tensors to the DataLoader: pin_memory only works on CPU
        # tensors, and the training loop moves the batch to its device itself.
        return {"latent": latent.squeeze(0).cpu(), "text": text}
# Pretrained SD v1.5 VAE with gradients disabled; it is only used to encode.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae.requires_grad_(False)

# The base dataset already normalises images to [-1, 1], but torchvision's
# sharpness adjustment clamps float images to [0, 1], which would zero every
# negative value. Map back to [0, 1] for the augmentations and re-normalise
# afterwards so the VAE still receives [-1, 1] inputs.
augment_transform = transforms.Compose([
    transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),  # [-1, 1] -> [0, 1]
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAdjustSharpness(sharpness_factor=1.5, p=0.3),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # [0, 1] -> [-1, 1]
])

latent_dataset = AugmentedLatentDataset(
    base_dataset=base_dataset,
    vae=vae,
    augment_transform=augment_transform,
)
DataLoader
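封装一下,训练时 shuffle 开着,drop_last 可以避免最后 batch 尺寸不一致。pin_memory 和多 worker 一般能快一些;不过这里 VAE 编码放在 Dataset 里,VAE 一旦搬到 GPU,多 worker 会在 fork 出来的子进程里碰 CUDA 而报错,这种情况下把 num_workers 设成 0,或者提前把 latent 离线算好。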
def create_dataloader(latent_dataset, batch_size=4, shuffle=True, drop_last=True, num_workers=4, pin_memory=True):
    """Build a DataLoader over ``latent_dataset``.

    Args:
        latent_dataset: Dataset to iterate over.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the data every epoch.
        drop_last: Whether to drop the final, smaller batch.
        num_workers: Number of loader subprocesses (0 loads in the main process).
        pin_memory: Whether batches are copied into pinned host memory.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        dataset=latent_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        pin_memory=pin_memory,
        num_workers=num_workers,
    )


# latent_dataset runs the VAE inside __getitem__, and that VAE is moved to the
# training device later in this script. CUDA cannot be used from forked
# DataLoader workers, so the samples are produced in the main process.
train_dataloader = create_dataloader(latent_dataset, batch_size=4, shuffle=True, drop_last=True, num_workers=0)
训练:在潜空间里预测噪声
核心思路:每次都往干净的 latent 上加噪,让 UNet 根据时间步和文本去猜那个噪声,损失就用 MSE 算预测噪声和真实噪声的差距。整个过程全在潜空间,不通像素。
模型与优化器
CLIP 和 VAE 只读不训,UNet 是唯一要更新的。优化器用 AdamW + 线性 warmup 和衰减,这在 SD 里很常见。
def init_training_components(num_epochs=10, steps_per_epoch=None):
    """Create the tokenizer, text encoder, UNet, optimizer and LR scheduler.

    The text encoder has gradients disabled; the UNet is put in train mode
    and is the only model handed to the optimizer.

    Args:
        num_epochs: Number of epochs the LR schedule should span.
        steps_per_epoch: Optimizer steps per epoch. Defaults to
            ``len(train_dataloader)`` of the module-level dataloader.

    Returns:
        Tuple ``(tokenizer, text_encoder, unet, optimizer, lr_scheduler)``.
    """
    model_id = "runwayml/stable-diffusion-v1-5"
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
    text_encoder.requires_grad_(False)
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
    unet.train()

    optimizer = optim.AdamW(
        unet.parameters(),
        lr=1e-4,
        betas=(0.9, 0.999),
        weight_decay=0.01,
    )

    if steps_per_epoch is None:
        steps_per_epoch = len(train_dataloader)
    num_training_steps = num_epochs * steps_per_epoch
    # Linear schedule whose first 10% of steps are warmup.
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=int(num_training_steps * 0.1),
        num_training_steps=num_training_steps,
    )
    return tokenizer, text_encoder, unet, optimizer, lr_scheduler
# Build the training components, then place every model on the GPU when one
# is available and on the CPU otherwise.
tokenizer, text_encoder, unet, optimizer, lr_scheduler = init_training_components()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for _module in (unet, text_encoder, vae):
    _module.to(device)
加噪
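对每个 batch,随机抽一组时间步 t(数学上是 1~1000,对应调度器里的索引 0~999),按 DDPM 的调度参数把噪声加进去。调度器用 diffusers.DDPMScheduler,配置 β 从 1e-4 到 0.02 的线性噪声表。

$$x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\quad \epsilon \sim \mathcal{N}(0, I)$$

其中 $\alpha_t = 1-\beta_t$,$\bar{\alpha}_t = \prod_{s=1}^{t}\alpha_s$ 是预先算好的噪声调度序列。代码里直接调调度器的 add_noise 就完成了。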
from diffusers import DDPMScheduler
# DDPM noise schedule for training: 1000 steps, linear betas from 1e-4 to 0.02.
# NOTE(review): the UNet is loaded from a pretrained checkpoint that presumably
# ships its own scheduler config; confirm this hand-written table matches it.
_ddpm_schedule = {
    "num_train_timesteps": 1000,
    "beta_start": 1e-4,
    "beta_end": 0.02,
    "beta_schedule": "linear",
}
noise_scheduler = DDPMScheduler(**_ddpm_schedule)
def add_noise_to_latents(latents, timesteps, noise_scheduler):
    """Noise ``latents`` at the given timesteps.

    Args:
        latents: Tensor of clean latents.
        timesteps: Timesteps forwarded to ``noise_scheduler.add_noise``.
        noise_scheduler: Object exposing ``add_noise(latents, noise, timesteps)``.

    Returns:
        Tuple ``(noisy_latents, noise)``, where ``noise`` is the freshly drawn
        standard-normal tensor with the same shape as ``latents``.
    """
    noise = torch.randn_like(latents)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    return noisy_latents, noise
# ---- Example: push a single batch through the noising step ----
batch = next(iter(train_dataloader))
latents = batch["latent"].to(device)
# Scheduler timesteps are 0-based indices; torch.randint's upper bound is
# exclusive, so this draws from [0, num_train_timesteps).
timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device
)
noisy_latents, real_noise = add_noise_to_latents(latents, timesteps, noise_scheduler)
文本编码
用 CLIP 把文本映射成 [BS, 77, 768] 的嵌入。tokenizer 要显式指定 padding="max_length" 才会把每条文本补齐到 77 个 token:不够的用 pad token 补齐(不是补 0),超出的截断。CLIP 冻结,不参与梯度。
def encode_text(texts, tokenizer, text_encoder):
    """Encode captions into text-encoder hidden states.

    Args:
        texts: A caption or list of captions.
        tokenizer: Callable tokenizer exposing ``model_max_length``.
        text_encoder: Model exposing ``device`` and returning an object with
            ``last_hidden_state``.

    Returns:
        The encoder's ``last_hidden_state`` for the tokenized captions.
    """
    inputs = tokenizer(
        texts,
        # Pad every caption to the same length; otherwise captions of
        # different lengths cannot be stacked into one "pt" tensor.
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).to(text_encoder.device)
    with torch.no_grad():
        text_embeddings = text_encoder(**inputs).last_hidden_state
    return text_embeddings
# Encode the example batch's captions.
text_embeddings = encode_text(texts=batch["text"], tokenizer=tokenizer, text_encoder=text_encoder)
UNet 前向与损失
UNet 直接接受加噪 latent、时间步和文本嵌入,返回预测的噪声。时间步会经过位置编码和 MLP 广播到特征图,文本嵌入在交叉注意力层作为 K、V。损失只算 noise_pred 和 real_noise 的 MSE,简单直接。
def train_one_batch(noisy_latents, timesteps, text_embeddings, real_noise, unet, optimizer, lr_scheduler):
    """Run one optimisation step of the noise-prediction objective.

    Args:
        noisy_latents: Noised latents fed to the UNet as ``sample``.
        timesteps: Timesteps fed to the UNet as ``timestep``.
        text_embeddings: Conditioning passed as ``encoder_hidden_states``.
        real_noise: The noise that was added; regression target.
        unet: Model whose output exposes ``.sample``.
        optimizer: Optimizer stepped once per call.
        lr_scheduler: LR scheduler stepped once per call.

    Returns:
        The batch MSE loss as a Python float.
    """
    noise_pred = unet(
        sample=noisy_latents,
        timestep=timesteps,
        encoder_hidden_states=text_embeddings,
    ).sample
    # Compute the loss in fp32 so it stays stable if the model runs in half
    # precision; for fp32 inputs the casts are no-ops.
    loss = nn.functional.mse_loss(noise_pred.float(), real_noise.float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    return loss.item()
# One optimisation step on the example batch.
loss = train_one_batch(
    noisy_latents=noisy_latents,
    timesteps=timesteps,
    text_embeddings=text_embeddings,
    real_noise=real_noise,
    unet=unet,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
)
完整训练循环
把上面的步骤串起来,每个 epoch 结束顺手保存 checkpoint。单个 GPU 跑 10 个 epoch 大概要一阵子,显存不够的话后面会聊到 LoRA。
def full_training_loop(num_epochs, train_dataloader, noise_scheduler, tokenizer, text_encoder, unet, optimizer, lr_scheduler):
    """Train the UNet for ``num_epochs`` epochs and checkpoint after each one.

    Every batch is noised at random timesteps, the captions are encoded, and
    one optimisation step is taken. Progress is printed every 100 steps, and
    ``unet.state_dict()`` is saved to ``unet_epoch_{n}.pth`` after each epoch.

    Args:
        num_epochs: Number of passes over ``train_dataloader``.
        train_dataloader: Yields dicts with ``"latent"`` and ``"text"``.
        noise_scheduler: Scheduler exposing ``config.num_train_timesteps``
            and ``add_noise``.
        tokenizer: Tokenizer forwarded to ``encode_text``.
        text_encoder: Text encoder forwarded to ``encode_text``.
        unet: Model being trained.
        optimizer: Optimizer for the UNet parameters.
        lr_scheduler: LR scheduler stepped once per batch.
    """
    unet.train()
    # Follow the UNet's own device so the loop does not depend on a global and
    # works wherever the caller placed the model.
    train_device = next(unet.parameters()).device
    num_timesteps = noise_scheduler.config.num_train_timesteps
    steps_per_epoch = len(train_dataloader)
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            latents = batch["latent"].to(train_device)
            texts = batch["text"]
            # Timesteps are 0-based indices; the upper bound is exclusive.
            timesteps = torch.randint(0, num_timesteps, (latents.shape[0],), device=train_device)
            noisy_latents, real_noise = add_noise_to_latents(latents, timesteps, noise_scheduler)
            text_embeddings = encode_text(texts, tokenizer, text_encoder).to(train_device)
            batch_loss = train_one_batch(noisy_latents, timesteps, text_embeddings, real_noise, unet, optimizer, lr_scheduler)
            epoch_loss += batch_loss
            if (step + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{steps_per_epoch}], Loss: {batch_loss:.4f}")
        # Guard against an empty dataloader.
        avg_loss = epoch_loss / max(steps_per_epoch, 1)
        print(f"Epoch [{epoch+1}/{num_epochs}] avg loss: {avg_loss:.4f}")
        torch.save(unet.state_dict(), f"unet_epoch_{epoch+1}.pth")
        print(f"Checkpoint saved to unet_epoch_{epoch+1}.pth")
# Run the full fine-tuning: 10 epochs over the training dataloader.
num_epochs = 10
full_training_loop(
    num_epochs=num_epochs,
    train_dataloader=train_dataloader,
    noise_scheduler=noise_scheduler,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    unet=unet,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
)
推理:从纯噪声走回图像
推理时从 t=1000 的随机噪声开始,用 DDIM 采样跳过步骤,50 步足够。每步跑两次 UNet(有文本和无文本),按 CFG 公式插值,提升文本贴合度。最后 VAE 解码回像素空间。
准备组件
加载训练好的 UNet 权重,VAE 和 CLIP 直接用预训练版本。采样器用 DDIMScheduler,设置推理步数。
# Pretrained VAE / tokenizer / text encoder in eval mode with gradients off;
# only the UNet weights come from the training run above.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").eval().requires_grad_(False)
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").eval().requires_grad_(False)
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
# map_location lets a checkpoint saved from a GPU run load on a CPU-only
# machine; the model is moved to `device` below.
unet.load_state_dict(torch.load("unet_epoch_10.pth", map_location="cpu"))
unet.eval().requires_grad_(False)

from diffusers import DDIMScheduler
# Build the DDIM sampler from a DDPM config, but switch off clip_sample:
# clipping the predicted x0 to [-1, 1] is meant for pixel-space models and
# would truncate latent values.
sampler = DDIMScheduler.from_config(
    DDPMScheduler(num_train_timesteps=1000).config,
    clip_sample=False,
)
sampler.set_timesteps(num_inference_steps=50)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae.to(device)
text_encoder.to(device)
unet.to(device)
文本编码与 CFG
生成一对嵌入:真正 prompt 和空字符串(null prompt)。推理时 batch 扩展成 2,一次同时算,再拆开做 CFG。系数默认 7.5,调高能让图像更听话,但太高容易过饱和。
def encode_text_inference(prompt, tokenizer, text_encoder):
    """Encode a prompt and the empty prompt for classifier-free guidance.

    Args:
        prompt: The text prompt.
        tokenizer: Callable tokenizer exposing ``model_max_length``.
        text_encoder: Model exposing ``device`` and returning an object with
            ``last_hidden_state``.

    Returns:
        Tuple ``(prompt_embedding, empty_prompt_embedding)``.
    """

    def encode(text):
        inputs = tokenizer(
            text,
            # Pad to a fixed length so the prompt and the empty prompt yield
            # tensors of the same shape and can be concatenated afterwards.
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).to(text_encoder.device)
        with torch.no_grad():
            return text_encoder(**inputs).last_hidden_state

    return encode(prompt), encode("")
# Encode the prompt and the empty prompt, then stack them with the empty
# prompt first.
prompt = "a red cat sitting on a chair, high resolution"
text_emb, null_emb = encode_text_inference(prompt=prompt, tokenizer=tokenizer, text_encoder=text_encoder)
text_embeddings = torch.cat((null_emb, text_emb), dim=0)
逐步去噪
从纯噪声开始,按采样器的时间步循环。每一步用 torch.cat([latents]*2) 扩展输入,得到两个噪声预测,再用 CFG 公式合成去噪方向,交给采样器更新到下一个时间步。结束后用 VAE 解码。
@torch.no_grad()
def inference_step(latents, t, text_embeddings, unet, sampler, cfg_scale=7.5):
    """Run one denoising step with classifier-free guidance.

    Args:
        latents: Current latents.
        t: Current timestep (an element of ``sampler.timesteps``).
        text_embeddings: Embeddings stacked as ``[unconditional, conditional]``
            along the batch dimension.
        unet: Noise-prediction model whose output exposes ``.sample``.
        sampler: Scheduler exposing ``step(noise_pred, t, latents)``.
        cfg_scale: Guidance scale.

    Returns:
        The latents for the previous timestep (``prev_sample``).
    """
    # Duplicate the latents so both branches run in a single forward pass.
    latent_input = torch.cat([latents] * 2)
    noise_pred = unet(
        sample=latent_input,
        # A single timestep applies to both halves of the batch.
        timestep=t,
        encoder_hidden_states=text_embeddings,
    ).sample
    noise_pred_null, noise_pred_text = noise_pred.chunk(2)
    guided = noise_pred_null + cfg_scale * (noise_pred_text - noise_pred_null)
    return sampler.step(guided, t, latents).prev_sample
# Start from Gaussian noise in latent space and denoise over the sampler's
# timesteps; no gradients are needed anywhere in this section.
latents = torch.randn((1, 4, 64, 64), device=device)
with torch.no_grad():
    for t in sampler.timesteps:
        latents = inference_step(latents, t, text_embeddings, unet, sampler, cfg_scale=7.5)
    # Undo the latent scaling applied at encoding time before decoding.
    latents = latents / 0.18215
    image = vae.decode(latents).sample  # [1,3,512,512]
# Map [-1, 1] to [0, 1] and convert to an HWC numpy array.
image = (image / 2 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()[0]
# Round before the uint8 cast; a bare cast truncates and darkens the image.
image = Image.fromarray((image * 255).round().astype(np.uint8))
image.save("generated_image.jpg")
print("Saved generated_image.jpg")
LoRA 微调
全量训练 UNet 显存消耗太大,单卡 40G 起步。LoRA 只在注意力层的 Q、V 投影上挂低秩矩阵,可训参数降到百万级,显存省很多,适合风格微调。
使用 PEFT 库
定义 LoraConfig,挂到 UNet 上,冻结原有权重。训练循环和前面一样,只是 optimizer 只动 LoRA 参数。
def init_lora_training():
    """Create frozen base models, a LoRA-wrapped UNet, optimizer and scheduler.

    Returns:
        Tuple ``(tokenizer, text_encoder, vae, unet, optimizer, lr_scheduler)``
        where ``unet`` is the PEFT-wrapped model.
    """
    model_id = "runwayml/stable-diffusion-v1-5"
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").requires_grad_(False)
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").requires_grad_(False)
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").requires_grad_(False)

    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        # The diffusers UNet names its attention projections to_q / to_k /
        # to_v; "q_proj" / "v_proj" match no module in it.
        target_modules=["to_q", "to_v"],
        lora_dropout=0.05,
        bias="none",
    )
    unet = get_peft_model(unet, lora_config)
    unet.print_trainable_parameters()

    # Hand only the trainable (adapter) parameters to the optimizer.
    trainable_params = [p for p in unet.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=5e-5, betas=(0.9, 0.999), weight_decay=0.01)

    num_epochs = 5
    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_scheduler(
        "linear",
        optimizer,
        num_warmup_steps=int(num_training_steps * 0.1),
        num_training_steps=num_training_steps,
    )
    return tokenizer, text_encoder, vae, unet, optimizer, lr_scheduler
tokenizer, text_encoder, vae, unet, optimizer, lr_scheduler = init_lora_training()
# The freshly created models live on the CPU; move the ones used during
# training to the device that the batches are sent to.
unet.to(device)
text_encoder.to(device)
full_training_loop(5, train_dataloader, noise_scheduler, tokenizer, text_encoder, unet, optimizer, lr_scheduler)
unet.save_pretrained("lora_weights")
LoRA 推理
加载基础 UNet 后,用 PeftModel.from_pretrained 挂上适配器,其余流程不变。
def lora_inference(prompt, lora_path, cfg_scale=7.5):
    """Generate one image with the base UNet plus a saved LoRA adapter.

    Args:
        prompt: Text prompt.
        lora_path: Path passed to ``PeftModel.from_pretrained``.
        cfg_scale: Classifier-free guidance scale.

    Returns:
        The generated image as a PIL ``Image``.
    """
    model_id = "runwayml/stable-diffusion-v1-5"
    # Move every model to `device`: the prompt tensors and the latents below
    # are created there.
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").eval().requires_grad_(False).to(device)
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").eval().requires_grad_(False).to(device)
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
    unet = PeftModel.from_pretrained(unet, lora_path).eval().requires_grad_(False).to(device)

    # clip_sample=False: clipping the predicted x0 to [-1, 1] is meant for
    # pixel-space models and would truncate latent values.
    sampler = DDIMScheduler.from_config(
        DDPMScheduler(num_train_timesteps=1000).config,
        clip_sample=False,
    )
    sampler.set_timesteps(50)

    # Same sampling loop as the plain inference path.
    text_emb, null_emb = encode_text_inference(prompt, tokenizer, text_encoder)
    text_embeddings = torch.cat([null_emb, text_emb], dim=0)
    latents = torch.randn((1, 4, 64, 64), device=device)
    with torch.no_grad():
        for t in sampler.timesteps:
            latents = inference_step(latents, t, text_embeddings, unet, sampler, cfg_scale)
        # Undo the latent scaling, then decode to pixel space.
        latents = latents / 0.18215
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()[0]
    # Round before the uint8 cast; a bare cast truncates.
    return Image.fromarray((image * 255).round().astype(np.uint8))
# Generate one image with the saved adapter and write it to disk.
img = lora_inference(prompt="a red cat, lora style", lora_path="lora_weights")
img.save("lora_output.jpg")
踩坑备忘
- 显存不足:降分辨率、开 gradient checkpointing,或者直接用 LoRA。
- Loss 降不下来:检查时间步采样范围、噪声调度器参数和学习率。
- 生成图不跟 prompt:调大 CFG_scale(7~10),加点推理步数,确认文本编码没问题。
- LoRA 不生效:确认 target_modules 命中 UNet 的注意力投影层(diffusers 的 UNet 里叫 to_q / to_k / to_v / to_out.0),秩 r 不要太小。
SD 里绝大部分计算都在潜空间,VAE 负责压缩,UNet 负责猜噪声,几步走通以后,剩下的就是调参和扩展了。


