Stable Diffusion(SD)完整训练+推理流程详解(含伪代码,新手友好)
Stable Diffusion(SD)的核心理论基石源自论文《High-Resolution Image Synthesis with Latent Diffusion Models》(LDM),其革命性创新在于将扩散模型从高维像素空间迁移至 VAE 预训练的低维潜空间,在大幅降低训练与推理的计算成本(相比像素级扩散模型节省大量 GPU 资源)的同时,通过跨注意力机制实现文本、布局等多模态条件控制,兼顾了生成质量与灵活性。本文将基于这一核心思想,从数据预处理、模型训练、推理生成到 LoRA 轻量化训练,一步步拆解 SD 的完整技术流程,每个关键环节均搭配伪代码,结合实操场景,理解 SD 的工程实现。

论文地址:https://arxiv.org/pdf/2112.10752
论文代码:https://github.com/CompVis/latent-diffusion
复现代码(基于非官方的复现,简化版):https://github.com/wenwenqqq/sd-demo
核心前提:SD的核心设计是「潜空间扩散」——用VAE将图片映射到低维潜空间,在潜空间内完成DDPM的训练与推理,大幅降低计算量和显存消耗,这也是SD能高效训练大尺寸图片的关键。
一、前期准备与核心依赖
在开始流程前,需准备好核心依赖库和数据集,这里列出博客实操所需的基础依赖(基于PyTorch框架),以及数据集的基础要求。(以下伪代码仅供参考)
1.1 核心依赖库
SD的训练/推理依赖VAE、CLIP、UNet三大核心模型,以及数据处理、扩散模型相关的工具库,伪代码如下:
# 基础依赖 import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader import torchvision.transforms as transforms # SD核心依赖(可直接用diffusers库简化实现) from diffusers import AutoencoderKL, CLIPTextModel, CLIPTokenizer, UNet2DConditionModel from diffusers.optimization import get_scheduler from diffusers.utils import logging # 轻量化训练依赖(LoRA相关) from peft import LoraConfig, get_peft_model, PeftModel # 日志配置(方便调试) logging.set_verbosity_info() 1.2 数据集要求
本文以「图像-文本配对数据集」为例
二、数据预处理(核心:从原始数据到潜空间张量)
数据预处理是SD训练的基础,核心目标是:将原始2K图像缩放归一化、文本编码,最终转换为模型可直接输入的潜空间张量和文本嵌入,分为3个关键步骤。
2.1 基础数据集封装(图像+文本配对)
首先读取原始图像和文本,对图像进行缩放、归一化等基础预处理,将两者封装为{image, text}的配对格式,适配后续数据增强和VAE编码。
关键注意点:图像需缩放到SD标准训练尺寸(512×512),归一化到[-1, 1](匹配VAE输入要求);文本暂不编码,仅做基础清洗。
class ImageTextDataset(Dataset): def __init__(self, image_dir, caption_csv, transform=None): """ Args: image_dir: 图像文件夹路径 caption_csv: 文本描述csv文件路径 transform: 图像预处理transform """ self.image_dir = image_dir self.captions = pd.read_csv(caption_csv) # 读取文本描述 self.transform = transform def __len__(self): return len(self.captions) # 数据集总样本数 def __getitem__(self, idx): # 1. 读取图像 image_name = self.captions.iloc[idx]['image_name'] image_path = os.path.join(self.image_dir, image_name) image = Image.open(image_path).convert("RGB") # 转为RGB三通道 # 2. 读取文本(基础清洗) text = self.captions.iloc[idx]['text'].strip() # 3. 图像预处理(缩放、归一化) if self.transform is not None: image = self.transform(image) # 返回配对数据(image: [3,512,512], text: 字符串) return {"image": image, "text": text} # ------------------- 伪代码调用 ------------------- # 定义图像预处理transform(核心:缩放+归一化) image_transform = transforms.Compose([ transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR), # 缩放到512×512 transforms.ToTensor(), # 转为张量 [3,512,512],像素值[0,1] transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 归一化到[-1,1] ]) # 初始化基础数据集 base_dataset = ImageTextDataset( image_dir="dataset/images", caption_csv="dataset/captions.csv", transform=image_transform ) # 查看数据集输出维度(BS=4时,后续dataloader输出参考) sample = base_dataset[0] print("预处理后图像维度:", sample["image"].shape) # torch.Size([3, 512, 512]) print("文本示例:", sample["text"]) # "a red cat sitting on a chair, high resolution" 2.2 增强型潜空间数据集(AugmentedLatentDataset)
核心作用:在像素空间做数据增强(提升模型泛化性),再将增强后的图像通过VAE编码为潜空间张量(64×64×4)——数据增强仅在像素空间进行,潜空间不做增强(避免破坏VAE的压缩特征)。
常见数据增强:随机水平翻转、随机裁切、亮度/对比度调整等,增强后需保持512×512尺寸,再送入VAE编码。
class AugmentedLatentDataset(Dataset): def __init__(self, base_dataset, vae, augment_transform=None): """ Args: base_dataset: 基础ImageTextDataset vae: VAE编码器(用于将像素空间转为潜空间) augment_transform: 像素空间的数据增强transform """ self.base_dataset = base_dataset self.vae = vae self.augment_transform = augment_transform # VAE设置为评估模式(不训练VAE,仅用于编码) self.vae.eval() def __len__(self): return len(self.base_dataset) def __getitem__(self, idx): # 1. 获取基础数据(预处理后的图像+文本) data = self.base_dataset[idx] image = data["image"] # [3,512,512] text = data["text"] # 2. 像素空间数据增强(可选,提升泛化性) if self.augment_transform is not None: image = self.augment_transform(image) # 3. VAE编码:将像素空间图像转为潜空间张量(64×64×4) # 注意:VAE输入需加batch维度,编码后去除batch维度,缩放潜空间(SD标准操作) with torch.no_grad(): # 编码时不计算梯度,节省显存 latent = self.vae.encode(image.unsqueeze(0)).latent_dist.sample() # [1,4,64,64] latent = latent * 0.18215 # SD固定缩放系数,匹配VAE训练时的归一化 # 返回潜空间张量+文本(latent: [4,64,64], text: 字符串) return {"latent": latent.squeeze(0), "text": text} # ------------------- 伪代码调用 ------------------- # 初始化VAE(使用SD预训练VAE,冻结参数) vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae") vae.requires_grad_(False) # 冻结VAE,不参与训练 # 定义像素空间数据增强(仅在训练时使用) augment_transform = transforms.Compose([ transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转(概率50%) transforms.RandomAdjustSharpness(sharpness_factor=1.5, p=0.3), # 随机调整锐度 ]) # 初始化增强型潜空间数据集 latent_dataset = AugmentedLatentDataset( base_dataset=base_dataset, vae=vae, augment_transform=augment_transform ) # 查看潜空间数据维度 sample = latent_dataset[0] print("VAE编码后潜空间维度:", sample["latent"].shape) # torch.Size([4, 64, 64]) 2.3 DataLoader封装(批量处理)
将潜空间数据集封装为DataLoader,完成批量读取、打乱、丢弃最后不足一个batch的样本等操作,适配模型训练的批量输入需求,核心参数:batch_size=4(本文示例)、shuffle=True(训练时打乱数据)、drop_last=True(避免最后一个不完整batch影响训练)。
def create_dataloader(latent_dataset, batch_size=4, shuffle=True, drop_last=True): """创建DataLoader,批量输出潜空间张量和文本""" dataloader = DataLoader( dataset=latent_dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, pin_memory=True, # 加速数据读取,适配GPU训练 num_workers=4 # 多线程读取,根据CPU核心数调整 ) return dataloader # ------------------- 伪代码调用 ------------------- # 训练集DataLoader(shuffle=True) train_dataloader = create_dataloader( latent_dataset=latent_dataset, batch_size=4, shuffle=True, drop_last=True ) # 验证集DataLoader(shuffle=False,仅用于评估) # val_dataloader = create_dataloader(latent_dataset=val_latent_dataset, batch_size=4, shuffle=False, drop_last=True) # 查看DataLoader输出维度(BS=4) for batch in train_dataloader: print("Batch潜空间维度:", batch["latent"].shape) # torch.Size([4, 4, 64, 64]) print("Batch文本数量:", len(batch["text"])) # 4(每个样本对应1条文本) break三、SD模型训练流程(核心:潜空间DDPM训练)
SD的训练核心是「在潜空间内训练DDPM」,模型输入为:加噪后的潜空间张量(noisy_latents)、时间步(timesteps)、文本嵌入(text_embeddings),目标是让UNet精准预测加进去的噪声,全程不涉及像素空间,仅在潜空间操作。
训练流程分为:时间步采样与加噪、文本编码、UNet前向传播、损失计算、反向传播与参数更新,共5个关键环节。
3.1 初始化核心模型与优化器
SD训练需初始化3个核心模型:CLIP Text Encoder(文本编码)、UNet(扩散模型核心,预测噪声)、VAE(已在数据预处理时初始化,冻结),以及优化器、学习率调度器。
关键注意点:训练时仅更新UNet参数,CLIP和VAE预训练后冻结,大幅降低计算量和显存消耗。
def init_training_components(): # 1. 初始化CLIP Text Encoder和Tokenizer(文本编码) tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder") text_encoder.requires_grad_(False) # 冻结CLIP,不参与训练 # 2. 初始化UNet(扩散模型核心,预测噪声) unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") unet.train() # UNet设为训练模式 # 3. 初始化优化器(AdamW是SD训练的标准优化器) optimizer = optim.AdamW( unet.parameters(), lr=1e-4, # 基础学习率,可根据batchsize调整 betas=(0.9, 0.999), weight_decay=0.01 ) # 4. 初始化学习率调度器(线性衰减,适配SD训练) num_epochs = 10 # 训练总轮次 num_training_steps = num_epochs * len(train_dataloader) lr_scheduler = get_scheduler( name="linear", optimizer=optimizer, num_warmup_steps=num_training_steps * 0.1, # 预热步数(10%) num_training_steps=num_training_steps ) return tokenizer, text_encoder, unet, optimizer, lr_scheduler # ------------------- 伪代码调用 ------------------- tokenizer, text_encoder, unet, optimizer, lr_scheduler = init_training_components() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 优先使用GPU unet.to(device) text_encoder.to(device) vae.to(device) print("核心模型初始化完成,设备:", device) 3.2 时间步采样与潜空间加噪(训练的核心前提)
DDPM的训练核心是「加噪-去噪」的迭代学习,这里的加噪操作仅在潜空间进行(VAE编码后的张量),步骤如下:
- 为每个batch的样本,随机采样时间步t(范围1~1000,t=0为无噪声,不采样);
- 生成与潜空间张量形状一致的标准正态噪声ε(训练必需的噪声信号);
- 根据DDPM前向公式,计算加噪后的潜空间张量noisy_latents。
DDPM前向加噪公式:

其中:

,

,是预定义的固定噪声调度序列(1e-4~0.02线性分布)。
def add_noise_to_latents(latents, timesteps, noise_scheduler): """ 对潜空间张量加噪,生成noisy_latents Args: latents: 原始潜空间张量 [BS,4,64,64] timesteps: 随机采样的时间步 [BS] noise_scheduler: DDPM噪声调度器(预定义β序列) Returns: noisy_latents: 加噪后的潜空间张量 [BS,4,64,64] noise: 真实加噪的噪声 [BS,4,64,64] """ # 1. 生成标准正态噪声(与潜空间张量形状一致) noise = torch.randn_like(latents, device=latents.device) # 2. 用噪声调度器计算加噪后的latents(DDPM前向公式) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) return noisy_latents, noise # ------------------- 伪代码调用 ------------------- # 初始化DDPM噪声调度器(SD标准配置:T=1000,β从1e-4到0.02线性分布) from diffusers import DDPMScheduler noise_scheduler = DDPMScheduler( num_train_timesteps=1000, beta_start=1e-4, beta_end=0.02, beta_schedule="linear" ) # 从dataloader取一个batch,进行加噪操作(BS=4) for batch in train_dataloader: latents = batch["latent"].to(device) # [4,4,64,64] texts = batch["text"] # 1. 随机采样时间步t(1~1000,每个样本的t不同) timesteps = torch.randint(1, noise_scheduler.num_train_timesteps, (latents.shape[0],), device=device) # 2. 潜空间加噪 noisy_latents, real_noise = add_noise_to_latents(latents, timesteps, noise_scheduler) print("原始潜空间维度:", latents.shape) # [4,4,64,64] print("加噪后潜空间维度:", noisy_latents.shape) # [4,4,64,64] print("真实噪声维度:", real_noise.shape) # [4,4,64,64] print("随机时间步:", timesteps) # 示例:tensor([345, 890, 120, 780], device='cuda:0') break 3.3 文本编码(text→text_embeddings)
将batch中的文本字符串,通过CLIP Tokenizer转为token张量,再通过CLIP Text Encoder编码为文本嵌入(text_embeddings),用于后续UNet的Cross-Attention融合。
关键注意点:CLIP Tokenizer默认将文本转为77维token(不足77维补0,超过77维截断),编码后得到[BS, 77, 768]的文本嵌入,需与UNet的注意力维度适配。
def encode_text(texts, tokenizer, text_encoder): """ 将文本转为text_embeddings Args: texts: batch文本列表(长度=BS) tokenizer: CLIP Tokenizer text_encoder: CLIP Text Encoder Returns: text_embeddings: 文本嵌入 [BS, 77, 768] """ # 1. Tokenizer编码:文本→token张量 [BS, 77] inputs = tokenizer( texts,, # 补全到77维 max_length=tokenizer.model_max_length, # 77 truncation=True, # 截断超过77维的文本 return_tensors="pt" # 返回PyTorch张量 ).to(text_encoder.device) # 2. Text Encoder编码:token→文本嵌入 [BS, 77, 768] with torch.no_grad(): # CLIP冻结,不计算梯度 text_embeddings = text_encoder(**inputs).last_hidden_state return text_embeddings # ------------------- 伪代码调用 ------------------- # 对当前batch的文本进行编码(BS=4) text_embeddings = encode_text(texts, tokenizer, text_encoder) print("文本嵌入维度:", text_embeddings.shape) # torch.Size([4, 77, 768]) 3.4 UNet前向传播(预测噪声)
UNet是SD的核心,输入为3个部分:noisy_latents(加噪潜空间张量)、timesteps(时间步)、text_embeddings(文本嵌入),输出为与真实噪声形状一致的预测噪声(ε_θ)。
关键细节:
- timesteps:需先做位置编码→MLP投影→广播,与noisy_latents的特征图相加,实现时间步信息的融入;
- text_embeddings:仅在UNet的Cross-Attention层融合,投影后作为K/V,与图像特征(Q)做注意力计算,实现文本-图像的关联。
def unet_forward(noisy_latents, timesteps, text_embeddings, unet): """ UNet前向传播,预测噪声 Args: noisy_latents: 加噪潜空间张量 [BS,4,64,64] timesteps: 时间步 [BS] text_embeddings: 文本嵌入 [BS,77,768] unet: UNet模型 Returns: noise_pred: 预测噪声 [BS,4,64,64] """ # UNet直接接收三个输入,内部自动完成timesteps和text_embeddings的维度适配 # 1. timesteps:内部做位置编码→投影→广播,与noisy_latents特征相加 # 2. text_embeddings:内部投影后,在Cross-Attention层作为K/V融合 noise_pred = unet( sample=noisy_latents, timestep=timesteps, encoder_hidden_states=text_embeddings ).sample # sample是UNet输出的预测噪声 return noise_pred # ------------------- 伪代码调用 ------------------- # UNet前向传播,预测噪声 noise_pred = unet_forward(noisy_latents, timesteps, text_embeddings, unet) print("预测噪声维度:", noise_pred.shape) # torch.Size([4,4,64,64])(与真实噪声维度一致) 3.5 损失计算与反向传播
SD训练的核心损失是「预测噪声与真实噪声的MSE Loss」——无需反推潜空间张量,直接对比UNet输出的noise_pred和加噪时的real_noise,计算均方误差,再反向传播更新UNet参数。
关键注意点:Loss仅计算噪声的差异,这是DDPM的核心简化设计,让模型专注于“猜中加进去的噪声”,后续推理时通过采样器反向去噪即可生成图像。
def train_one_batch(noisy_latents, timesteps, text_embeddings, real_noise, unet, optimizer, lr_scheduler): """训练一个batch,完成前向、损失计算、反向传播、参数更新""" # 1. 前向传播,预测噪声 noise_pred = unet_forward(noisy_latents, timesteps, text_embeddings, unet) # 2. 计算MSE Loss(预测噪声 vs 真实噪声) loss_fn = nn.MSELoss() loss = loss_fn(noise_pred, real_noise) # 3. 反向传播(仅更新UNet参数) optimizer.zero_grad() # 清空梯度 loss.backward() # 计算梯度 optimizer.step() # 更新参数 lr_scheduler.step() # 学习率调度 return loss.item() # ------------------- 伪代码调用 ------------------- # 训练一个batch,查看Loss loss = train_one_batch(noisy_latents, timesteps, text_embeddings, real_noise, unet, optimizer, lr_scheduler) print("当前batch的Loss:", loss) # 示例:0.035(训练初期Loss较高,后期逐步下降)3.6 完整训练循环(多Epoch迭代)
将上述环节整合,实现多Epoch的完整训练,定期保存模型权重(checkpoint),用于后续推理生成。
def full_training_loop(num_epochs, train_dataloader, noise_scheduler, tokenizer, text_encoder, unet, optimizer, lr_scheduler): """完整训练循环""" unet.train() for epoch in range(num_epochs): epoch_loss = 0.0 for step, batch in enumerate(train_dataloader): # 1. 读取batch数据 latents = batch["latent"].to(device) texts = batch["text"] # 2. 时间步采样与潜空间加噪 timesteps = torch.randint(1, noise_scheduler.num_train_timesteps, (latents.shape[0],), device=device) noisy_latents, real_noise = add_noise_to_latents(latents, timesteps, noise_scheduler) # 3. 文本编码 text_embeddings = encode_text(texts, tokenizer, text_encoder) # 4. 训练一个batch,计算Loss batch_loss = train_one_batch(noisy_latents, timesteps, text_embeddings, real_noise, unet, optimizer, lr_scheduler) epoch_loss += batch_loss # 打印日志(每100步打印一次) if (step + 1) % 100 == 0: print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{len(train_dataloader)}], Batch Loss: {batch_loss:.4f}") # 计算当前Epoch的平均Loss avg_epoch_loss = epoch_loss / len(train_dataloader) print(f"Epoch [{epoch+1}/{num_epochs}] Finished, Average Loss: {avg_epoch_loss:.4f}") # 定期保存模型权重(每1个Epoch保存一次) torch.save(unet.state_dict(), f"unet_epoch_{epoch+1}.pth") print(f"Model saved to unet_epoch_{epoch+1}.pth") # ------------------- 伪代码调用 ------------------- # 启动完整训练(10个Epoch) num_epochs = 10 full_training_loop( num_epochs=num_epochs, train_dataloader=train_dataloader, noise_scheduler=noise_scheduler, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, optimizer=optimizer, lr_scheduler=lr_scheduler ) 四、SD推理流程(核心:潜空间逐步去噪)
推理阶段的核心是「反向去噪」:从纯高斯噪声(t=1000)开始,按设定的步数逐步去噪,最终得到清晰的潜空间张量,再通过VAE解码为像素空间图像。
关键环节:逐步去噪(每步一次UNet前向)、CFG增强(强化文本控制)、随机噪声(增加生成多样性)。
4.1 推理前准备(加载模型与参数)
def init_inference_components(unet_ckpt_path): """初始化推理所需组件,加载训练好的UNet权重""" # 1. 初始化VAE(用于最终解码潜空间→像素空间) vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae") vae.eval() vae.requires_grad_(False) # 2. 初始化CLIP(文本编码) tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder") text_encoder.eval() text_encoder.requires_grad_(False) # 3. 初始化UNet,加载训练好的权重 unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") unet.load_state_dict(torch.load(unet_ckpt_path)) # 加载训练权重 unet.eval() unet.requires_grad_(False) # 4. 初始化推理用噪声调度器(与训练时一致) noise_scheduler = DDPMScheduler( num_train_timesteps=1000, beta_start=1e-4, beta_end=0.02, beta_schedule="linear" ) # 5. 初始化采样器(这里用DDIM采样器,加速推理,步数20~50步) from diffusers import DDIMScheduler sampler = DDIMScheduler.from_config(noise_scheduler.config) sampler.set_timesteps(num_inference_steps=50) # 推理步数(50步,比训练时1000步快20倍) return vae, tokenizer, text_encoder, unet, sampler # ------------------- 伪代码调用 ------------------- # 加载训练好的UNet权重(示例:第10个Epoch的权重) unet_ckpt_path = "unet_epoch_10.pth" vae, tokenizer, text_encoder, unet, sampler = init_inference_components(unet_ckpt_path) # 移动到GPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") vae.to(device) text_encoder.to(device) unet.to(device) 4.2 文本编码(推理时与训练一致)
推理时的文本编码流程与训练完全一致,将输入的文本描述转为text_embeddings,同时生成“空文本嵌入”(用于CFG增强)。
def encode_text_inference(prompt, tokenizer, text_encoder): """推理时的文本编码,同时生成有文本和无文本(空文本)的嵌入""" # 1. 有文本的嵌入(prompt为输入文本) prompt_inputs = tokenizer( prompt,, max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt" ).to(text_encoder.device) with torch.no_grad(): text_embeddings = text_encoder(**prompt_inputs).last_hidden_state # 2. 无文本的嵌入(空文本,用于CFG增强) null_inputs = tokenizer( null_prompt,, max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt" ).to(text_encoder.device) with torch.no_grad(): null_text_embeddings = text_encoder(**null_inputs).last_hidden_state return text_embeddings, null_text_embeddings # ------------------- 伪代码调用 ------------------- # 输入推理文本(示例:"a red cat sitting on a chair, high resolution") prompt = "a red cat sitting on a chair, high resolution" text_embeddings, null_text_embeddings = encode_text_inference(prompt, tokenizer, text_encoder) print("推理文本嵌入维度:", text_embeddings.shape) # [1,77,768](推理时BS=1,单张生成) 4.3 逐步去噪与CFG增强(推理核心)
推理时的去噪流程:从t=1000的纯高斯噪声开始,按采样器设定的步数(50步)逐步从t=1000→1去噪,每步执行一次UNet前向传播,通过CFG增强文本控制,加入随机噪声增加多样性。
CFG核心公式(强化文本引导):
$$\epsilon_{cfg} = \epsilon_{null} + cfg\_scale \times (\epsilon_{text} - \epsilon_{null})$$
其中:cfg_scale默认7.5,值越大,文本控制越强(过高会导致图像失真)。
def inference(prompt, vae, tokenizer, text_encoder, unet, sampler, cfg_scale=7.5): """ SD推理生成图像 Args: prompt: 文本描述 vae: VAE解码器 tokenizer: CLIP Tokenizer text_encoder: CLIP Text Encoder unet: 训练好的UNet sampler: 采样器(DDIM) cfg_scale: CFG系数,控制文本引导强度 Returns: generated_image: 生成的像素空间图像 [3,512,512] """ # 1. 文本编码,得到有文本/无文本嵌入 text_embeddings, null_text_embeddings = encode_text_inference(prompt, tokenizer, text_encoder) # 拼接有文本和无文本嵌入(适配CFG计算) text_embeddings = torch.cat([null_text_embeddings, text_embeddings]) # [2,77,768] # 2. 初始化潜空间噪声(t=1000,纯高斯噪声) batch_size = 1 latent_dim = 4 latent_size = 64 noise = torch.randn( (batch_size, latent_dim, latent_size, latent_size), device=unet.device ) latents = noise # 初始潜空间噪声(t=1000) # 3. 逐步去噪(按采样器的时间步迭代) with torch.no_grad(): # 推理时不计算梯度 for t in sampler.timesteps: # 3.1 扩展latents和timesteps,适配CFG的双输入(有文本/无文本) latent_model_input = torch.cat([latents] * 2) # [2,4,64,64] timestep = torch.tensor([t] * batch_size * 2, device=unet.device) # 3.2 UNet前向传播,预测噪声(一次预测有文本/无文本两种情况) noise_pred = unet( sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings ).sample # [2,4,64,64] # 3.3 CFG增强:分离无文本/有文本的噪声预测,计算最终噪声 noise_pred_null, noise_pred_text = noise_pred.chunk(2) # 各[1,4,64,64] noise_pred = noise_pred_null + cfg_scale * (noise_pred_text - noise_pred_null) # 3.4 采样器去噪,得到t-1的潜空间张量 latents = sampler.step(noise_pred, t, latents).prev_sample # 4. VAE解码:潜空间→像素空间(512×512) latents = latents / 0.18215 # 反缩放(与训练时的缩放对应) with torch.no_grad(): generated_image = vae.decode(latents).sample # [1,3,512,512] # 5. 图像后处理:从[-1,1]转回[0,255],转为PIL图像 generated_image = (generated_image / 2 + 0.5).clamp(0, 1) # 归一化到[0,1] generated_image = generated_image.cpu().permute(0, 2, 3, 1).numpy()[0] # [512,512,3] generated_image = (generated_image * 255).astype(np.uint8) generated_image = Image.fromarray(generated_image) return generated_image # ------------------- 伪代码调用 ------------------- # 执行推理,生成图像 generated_image = inference( prompt="a red cat sitting on a chair, high resolution", vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, sampler=sampler, cfg_scale=7.5 ) # 保存生成的图像 generated_image.save("generated_image.jpg") print("图像生成完成,已保存为generated_image.jpg") 五、LoRA轻量化训练(可选,核心:冻结主模型,训练适配器)
SD的UNet参数量数十亿,全量训练显存消耗大(需40GB以上),LoRA(Low-Rank Adaptation)通过在UNet的注意力层挂载轻量化适配器,仅训练适配器参数(参数量仅百万级),大幅降低显存消耗,同时实现特定风格/内容的微调。
LoRA训练流程分为两种方式:手动实现适配器、用PEFT库简化实现(推荐新手)。
5.1 PEFT库简化实现LoRA训练(推荐)
PEFT库已封装好LoRA逻辑,只需定义LoRA配置,挂载到UNet,即可实现轻量化训练,无需手动编写适配器。
def init_lora_training(): """初始化LoRA训练组件,冻结主模型,挂载LoRA适配器""" # 1. 初始化基础模型(与之前一致) tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder") vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae") unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") # 2. 冻结主模型(仅训练LoRA适配器) text_encoder.requires_grad_(False) vae.requires_grad_(False) unet.requires_grad_(False) # 3. 定义LoRA配置(核心参数) lora_config = LoraConfig( r=8, # LoRA秩,越小参数量越少,一般取4~16 lora_alpha=16, # 缩放系数,通常是r的2倍 target_modules=["q_proj", "v_proj"], # 挂载到UNet的注意力层(Q/V投影层) lora_dropout=0.05, bias="none", task_type="CAUSAL_LM" ) # 4. 挂载LoRA适配器到UNet unet = get_peft_model(unet, lora_config) unet.print_trainable_parameters() # 查看可训练参数(通常仅百万级) # 5. 初始化优化器和调度器(仅优化LoRA参数) optimizer = optim.AdamW( unet.parameters(), lr=5e-5, # LoRA学习率可略低 betas=(0.9, 0.999), weight_decay=0.01 ) num_epochs = 5 num_training_steps = num_epochs * len(train_dataloader) lr_scheduler = get_scheduler( name="linear", optimizer=optimizer, num_warmup_steps=num_training_steps * 0.1, num_training_steps=num_training_steps ) return tokenizer, text_encoder, vae, unet, optimizer, lr_scheduler # ------------------- 伪代码调用 ------------------- # 初始化LoRA训练组件 tokenizer_lora, text_encoder_lora, vae_lora, unet_lora, optimizer_lora, lr_scheduler_lora = init_lora_training() # 启动LoRA训练(训练流程与全量训练一致,仅训练LoRA参数) full_training_loop( num_epochs=5, train_dataloader=train_dataloader, noise_scheduler=noise_scheduler, tokenizer=tokenizer_lora, text_encoder=text_encoder_lora, unet=unet_lora, optimizer=optimizer_lora, lr_scheduler=lr_scheduler_lora ) # 保存LoRA权重(仅保存适配器参数,文件体积小,约几MB) unet_lora.save_pretrained("lora_weights") print("LoRA权重保存完成,路径:lora_weights") 5.2 LoRA推理(挂载适配器)
LoRA推理时,需加载预训练的UNet主模型,再挂载LoRA适配器,即可实现微调后的生成效果,无需加载完整的微调UNet权重。
def lora_inference(prompt, lora_path, cfg_scale=7.5): """LoRA推理,挂载适配器""" # 1. 初始化基础模型(与推理时一致) vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae") tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder") unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") # 2. 挂载LoRA适配器 unet = PeftModel.from_pretrained(unet, lora_path) unet.eval() unet.requires_grad_(False) # 3. 初始化采样器 sampler = DDIMScheduler.from_config(DDPMScheduler(num_train_timesteps=1000).config) sampler.set_timesteps(num_inference_steps=50) # 4. 执行推理(与普通推理流程一致) generated_image = inference( prompt=prompt, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, sampler=sampler, cfg_scale=cfg_scale ) return generated_image # ------------------- 伪代码调用 ------------------- # LoRA推理(示例:加载训练好的LoRA权重) lora_path = "lora_weights" lora_generated_image = lora_inference( prompt="a red cat sitting on a chair, high resolution, lora style", lora_path=lora_path, cfg_scale=7.5 ) # 保存LoRA生成的图像 lora_generated_image.save("lora_generated_image.jpg") print("LoRA图像生成完成,已保存") 六、常见问题与注意事项
- 显存不足:可降低、使用LoRA、开启梯度检查点(gradient checkpointing);
- 训练Loss不下降:检查时间步采样范围(需1~1000)、噪声调度器配置、学习率是否过高;
- 生成图像不贴合文本:调大CFG_scale(7~10)、增加推理步数(50步)、检查文本编码是否正确;
- LoRA训练无效:确认target_modules是否为UNet的注意力层(q_proj、v_proj)、LoRA秩r是否合理。
七、总结
Stable Diffusion的核心是「潜空间扩散」,全程围绕VAE(潜空间映射)、CLIP(文本编码)、UNet(噪声预测)三大模型展开,训练时在潜空间加噪、让UNet预测噪声,推理时逐步去噪、用CFG强化文本控制,LoRA则实现轻量化微调。