生成的甚么玩意 Õ_Õ ?基于 LoRA+Stable Diffusion 的100种动物图像生成
生成的甚么玩意 Õ_Õ ?基于 LoRA+Stable Diffusion 的100种动物图像生成
代码详见:https://github.com/xiaozhou-alt/Animals_Generationn
文章目录
- 生成的甚么玩意 Õ_Õ ?基于 LoRA+Stable Diffusion 的100种动物图像生成
- 一、项目介绍
- 二、文件夹结构
- 三、数据集介绍
- 四、Stable Diffusion 与 LoRA 模型介绍
- 五、项目实现
- 六、结果展示
一、项目介绍
这是一个基于Stable Diffusion和LoRA技术的动物图像生成系统,能够通过文本描述生成高质量的动物图像,包含完整的训练流程和用户友好的图形界面,支持自定义参数调整和实时图像生成。
主要特性
- 高效训练: 使用 LoRA (Low-Rank Adaptation) 技术对 Stable Diffusion 模型进行轻量级微调
- 用户友好界面: 基于 PyQt5 的图形界面,支持实时参数调整和图像预览
- 高质量生成: 经过优化的生成流程,支持色彩校正和后处理
- 跨平台支持: 支持 CPU 和 GPU 运行环境
生成的部分动物图像:
二、文件夹结构
Animals_Creation/ ├── README.md ├── demo.gif # 演示动画 ├── demo.mp4 # 演示视频 ├── demo.py # 主演示脚本 ├── icons/ # 图标资源目录 ├── train.py ├── log/ # 日志目录 ├── model/ └── LCM-runwayml-stable-diffusion-v1-5/ # Stable Diffusion模型 ├── feature_extractor/ # 特征提取器 ├── model_index.json # 模型索引文件 ├── safety_checker/ # 安全检查器 ├── scheduler/ # 调度器 ├── text_encoder/ # 文本编码器 ├── tokenizer/ # 分词器 ├── unet/ # UNet模型 └── vae/ # 变分自编码器 ├── output/ ├── evaluation_results.xlsx # 评估结果Excel文件 ├── lora_models/ # LoRA模型权重 └── clip-31.475.safetensors ├── training_history.xlsx # 训练历史记录 └── pic/ └── requirements.txt 三、数据集介绍
本项目使用的动物数据集包含 100 个不同类别的动物图片,因为使用网页图片提取下载,清洗由我一个人完全进行,数据集数据量较大,所以部分动物文件夹存在1%-1.5%的噪声图片,数据集组织结构如下:
- 总类别数:100 种动物
- 数据划分:采用 80% 作为训练集,20% 作为验证集
- 图片格式:支持常见图片格式(.jpg, .jpeg, .png 等)
在模型训练过程中,通过数据增强技术扩充了训练样本,包括旋转、平移、缩放、亮度调整等操作,以提高模型的泛化能力。
动物的类别信息请查看class.txt:
antelope
badger
bat
…
数据集下载:100种动物识别数据集(ScienceDB)
引用
如果您使用了本项目的数据集,请使用如下方式进行引用:
Haojing ZHOU.100种动物识别数据集[DS/OL]. V1. Science Data Bank,2025[2025-08-30]. https://cstr.cn/31253.11.sciencedb.29221. CSTR:31253.11.sciencedb.29221.或
@misc{动物识别, author ={Haojing ZHOU}, title ={100种动物识别数据集}, year ={2025}, doi ={10.57760/sciencedb.29221}, url ={https://doi.org/10.57760/sciencedb.29221}, note ={CSTR:31253.11.sciencedb.29221}, publisher ={ScienceDB}}四、Stable Diffusion 与 LoRA 模型介绍
1. Stable Diffusion 模型架构解析
Stable Diffusion 采用潜在扩散模型(Latent Diffusion Model)架构,通过将高维图像压缩到低维潜在空间进行扩散过程,显著提升了计算效率。该模型主要由四个核心组件构成:变分自编码器(VAE) + CLIP 文本编码器 + U-Net 模型 + 噪声调度器(DDPMScheduler)
1.1 变分自编码器(VAE)
VAE 在 Stable Diffusion 中承担着图像与潜在空间的双向转换任务。其编码器将输入图像 x x x压缩为潜在表示 z z z,解码器则将潜在表示重建为图像 x ^ \hat{x} x^。在代码实现中,我们使用预训练的 AutoencoderKL 模型:
vae = AutoencoderKL.from_pretrained( config.pretrained_model_name_or_path, subfolder="vae")VAE 的核心工作原理是通过变分推断学习数据的潜在分布。对于输入图像 x x x,编码器输出潜在分布的均值 μ \mu μ 和方差 σ 2 \sigma^2 σ2,通过重参数化技巧采样得到潜在表示:
z = μ + ϵ ⋅ σ , ϵ ∼ N ( 0 , I ) z = \mu + \epsilon \cdot \sigma, \quad \epsilon \sim \mathcal{N}(0, I) z=μ+ϵ⋅σ,ϵ∼N(0,I)
在项目中,我们将编码得到的潜在表示进行缩放:
latents = vae.encode(pixel_values).latent_dist.sample() latents = latents *0.18215# 缩放因子这里的缩放因子 0.18215 0.18215 0.18215 是 Stable Diffusion 模型预训练时确定的常数,用于将 VAE 输出的潜在空间分布标准化到更适合扩散过程的范围。
🤓🤓🤓小周有话说:
VAE 就像一位技艺精湛的 图像压缩专家。当处理输入图像时,它能将原始像素数据压缩成紧凑的"潜在代码"(类似图像的 zip 压缩文件),这个过程保留了图像的 关键特征 但大大减小了数据量。在生成图像时,它又能将 “潜在代码” 完美解压还原成高质量图像。代码中 0.18215 的缩放因子则像是统一的压缩标准,确保不同图像的 “压缩文件” 具有一致的数据范围。
1.2 CLIP 文本编码器
文本引导 是 Stable Diffusion 的核心特性,这一功能由 CLIP(Contrastive Language-Image Pretraining)文本编码器实现。它将文本描述转换为固定维度的向量表示,建立文本与图像之间的语义关联:
text_encoder = CLIPTextModel.from_pretrained( config.pretrained_model_name_or_path, subfolder="text_encoder") tokenizer = CLIPTokenizer.from_pretrained( config.pretrained_model_name_or_path, subfolder="tokenizer")CLIP 文本编码器通过对比学习训练,其输出的文本嵌入 t t t与图像嵌入在同一语义空间中。对于输入文本 w w w(如 "a photo of a cat"),经过分词和编码后得到文本特征:
t = text_encoder ( tokenizer ( w ) ) t = \text{text\_encoder}(\text{tokenizer}(w)) t=text_encoder(tokenizer(w))
在项目中,我们使用多样化的提示词模板增强文本嵌入的鲁棒性:
self.prompt_templates =["a photo of a {}","a high quality image of a {}",# 更多模板...]🤓🤓🤓小周有话说:
CLIP 文本编码器就像一位精通多语言的 翻译官,它能将人类的文字描述准确翻译成模型能理解的"向量语言"。当我们输入"一只可爱的小狗"这样的描述时,它会生成一个独特的数字向量,这个向量捕捉了文字中的语义信息。项目中使用多种提示词模板,相当于用不同表达方式描述同一种动物,确保模型能理解各种表述方式,提高生成的准确性。
1.3 U-Net 条件扩散模型
U-Net 是 Stable Diffusion 的 核心扩散模块,负责在潜在空间中 预测噪声。它以带噪声的潜在表示 z t z_t zt、时间步 t t t 和文本嵌入 t t t 作为输入,输出噪声预测 ϵ θ ( z t , t , c ) \epsilon_\theta(z_t, t, c) ϵθ(zt,t,c):
unet = UNet2DConditionModel.from_pretrained( config.pretrained_model_name_or_path, subfolder="unet")# 噪声预测 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample U-Net 采用编码器 - 解码器结构,通过跳跃连接保留细节信息,同时引入时间步嵌入和文本条件嵌入,实现条件生成。损失函数采用预测噪声与真实噪声的均方误差:
L = E z 0 , ϵ , t [ ∥ ϵ − ϵ θ ( z t , t , c ) ∥ 2 ] \mathcal{L} = \mathbb{E}_{z_0, \epsilon, t} \left[ \|\epsilon - \epsilon_\theta(z_t, t, c)\|^2 \right] L=Ez0,ϵ,t[∥ϵ−ϵθ(zt,t,c)∥2]
🤓🤓🤓小周有话说:
U-Net 就像一位技艺高超的 图像修复专家,擅长从模糊图像中还原细节。在扩散过程的每一步,它观察当前带有噪声的图像(潜在表示),并根据时间信息和文本描述,精确 预测出需要移除的噪声。随着扩散步骤推进,它逐步 “清理” 图像中的噪声,直到生成清晰、符合文本描述的图像。可以想象成一系列渐进式的修复过程,每一步都让图像更接近最终目标。
1.4 噪声调度器(DDPMScheduler)
噪声调度器控制着扩散过程中的 噪声添加和去除 策略。在训练阶段,它按照特定 schedule 向干净样本添加噪声;在推理阶段,则逐步从纯噪声中生成图像:
noise_scheduler = DDPMScheduler.from_pretrained( config.pretrained_model_name_or_path, subfolder="scheduler")扩散过程遵循 马尔可夫链,前向过程中噪声逐步增加:
z t = α t z t − 1 + 1 − α t ϵ , ϵ ∼ N ( 0 , I ) z_t = \sqrt{\alpha_t} z_{t-1} + \sqrt{1-\alpha_t} \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) zt=αtzt−1+1−αtϵ,ϵ∼N(0,I)
其中 α t \alpha_t αt 是调度器预定义的噪声系数。在项目中,我们通过调度器添加噪声:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)🤓🤓🤓小周有话说:
噪声调度器就像一位导演,精心控制着扩散过程的节奏。在 训练 时,它决定如何逐步向清晰图像添加噪声,制造出不同程度的 模糊效果;在 生成图像 时,它又指导模型如何一步步从完全的噪声中还原出清晰图像。项目中配置的 推理步数(num_inference_steps)决定了生成过程的 “精细度”,步数越多,生成的图像通常越精细,但需要的计算时间也越长。
2. LoRA 参数高效微调技术
在大规模预训练模型的微调任务中,全参数微调需要巨大的计算资源。LoRA(Low-Rank Adaptation)技术通过 冻结预训练模型权重,仅训练低秩矩阵参数,实现高效微调:
defprepare_unet_for_lora(unet, rank=2, alpha=16): lora_config = LoraConfig( r=rank, lora_alpha=alpha, target_modules=["to_q","to_k","to_v","to_out.0"], lora_dropout=0.0, bias="none",) unet = get_peft_model(unet, lora_config)return unet 2.1 LoRA 工作原理
LoRA 的核心思想是将权重更新表示为低秩矩阵分解的形式。对于预训练权重 W ∈ R d × k W \in \mathbb{R}^{d \times k} W∈Rd×k,LoRA 通过学习两个低秩矩阵 W A ∈ R d × r W_A \in \mathbb{R}^{d \times r} WA∈Rd×r 和 W B ∈ R r × k W_B \in \mathbb{R}^{r \times k} WB∈Rr×k( r ≪ min ( d , k ) r \ll \min(d,k) r≪min(d,k))来近似权重更新:
W ′ = W + W B W A W' = W + W_B W_A W′=W+WBWA
在项目中,我们将 LoRA 应用于 U-Net 的注意力模块,具体是查询(to_q)、键(to_k)、值(to_v)投影层和输出投影层(to_out.0):
Attention ( Q + Δ Q , K + Δ K , V + Δ V ) \text{Attention}(Q + \Delta Q, K + \Delta K, V + \Delta V) Attention(Q+ΔQ,K+ΔK,V+ΔV)
其中 Δ Q = W B Q W A Q \Delta Q = W_B^Q W_A^Q ΔQ=WBQWAQ, Δ K \Delta K ΔK 和 Δ V \Delta V ΔV 类似。这种设计使模型能够在保持预训练知识的同时,高效学习特定任务的知识。
🤓🤓🤓小周有话说:
LoRA 技术就像给预训练模型加装 专用插件,而不是重新设计整个模型。想象 Stable Diffusion 是一台功能强大的 通用相机,能够拍摄各种场景,但我们希望它特别擅长拍摄动物。全参数微调相当于重新设计相机的所有部件,而 LoRA 则是给相机添加一个专用的动物摄影镜头。项目中配置的rank 参数(r=2)决定了这个 “专用镜头” 的 复杂度,alpha 参数(lora_alpha=16)则控制其 影响力。这种方式不仅节省资源,还能保留模型原有的通用能力。
2.2 LoRA 参数配置与优势
在项目配置中,我们选择了较小的 秩(rank=2)和 alpha 值(lora_alpha=16):
# LoRA参数 rank =2 lora_alpha =16这种配置大大减少了可训练参数数量。通过print_trainable_parameters()可以发现,仅约 0.1% 的参数参与训练,显著降低了内存需求和计算成本。同时,LoRA 权重文件体积小(通常只有几 MB),便于存储和分享。
五、项目实现
温馨提示:本文项目实现部分分割较细,篇幅较长,读者不想深究代码原理可见 GitHub项目 中README.md 进行项目实现;若是愿意深究代码逻辑,我在下文中对于 训练代码 和 UI 界面使用代码 进行了详细说明
1. 训练代码实现
① 参数配置
Config 类集中管理了所有关键参数,体现了资源受限情况下的优化策略:
- 计算资源优化:通过降低分辨率(
256x256)、减少 LoRA 秩(rank=2)等措施,显著降低显存占用,使训练在普通 GPU 上成为可能。 - 训练稳定性控制:采用较小的学习率(1e-5)配合 余弦退火调度器,结合梯度裁剪(
max_grad_norm=0.5),有效防止训练过程中的 梯度爆炸 问题。 - 数据均衡策略:通过
max_samples_per_class限制每类样本数量,解决动物数据集类别不平衡问题,避免模型对样本多的类别过拟合。 - 评估体系设计:集成 CLIP 模型 作为客观 评估指标,相比传统的 MSE 损失能更好地反映生成图像的语义一致性。
- 工程化配置:合理规划输出目录结构,为模型权重、训练日志、生成样本等分别创建存储路径,便于实验管理。
# 参数配置 - 关键优化点classConfig:# 数据参数 - 减少数据量 data_root ="/kaggle/input/animals/Animal/Animal"# 动物数据集根路径 output_dir ="/kaggle/working/output"# 所有输出文件的目录 lora_model_dir = os.path.join(output_dir,"lora_models")# 保存LoRA模型的目录 history_file = os.path.join(output_dir,"training_history.xlsx")# 训练历史记录文件 sample_output_dir = os.path.join(output_dir,"validation_samples")# 验证样本输出目录 evaluation_file = os.path.join(output_dir,"evaluation_results.xlsx")# 评估结果文件 comparison_dir = os.path.join(output_dir,"comparison_samples")# 对比样本目录# 模型参数 - 降低分辨率 pretrained_model_name_or_path ="runwayml/stable-diffusion-v1-5"# 使用SD 1.5作为基础模型 resolution =256# 降低分辨率以减少计算量 (原为512) center_crop =True# 中心裁剪 random_flip =True# 随机水平翻转 (数据增强)# LoRA参数 - 简化LoRA rank =2# 降低LoRA的秩 (原为4) lora_alpha =16# 降低LoRA的alpha值 (原为32)# 训练参数 - 关键优化 train_batch_size =1# 批大小 gradient_accumulation_steps =4# 梯度累积步数 num_train_epochs =10# 训练轮数 learning_rate =1e-5# 学习率 lr_scheduler_type ="cosine_with_warmup"# 学习率调度器类型 lr_warmup_steps =200# 预热步数 max_grad_norm =0.5# 梯度裁剪阈值 use_ema =True# 启用EMA以提高稳定性 gradient_checkpointing =True# 梯度检查点 (节省显存) mixed_precision ="fp16"# 混合精度训练# 早停参数 - 使用CLIP分数作为指标 early_stopping_patience =5# 早停耐心值 early_stopping_delta =0.02# CLIP分数的最小改善值 validation_split =0.1# 验证集比例# 验证参数 num_validation_samples =5# 验证生成的动物种类数量 num_inference_steps =20# 验证时推理步数 num_final_inference_steps =100# 最终评估推理步数 guidance_scale =7.5# 指导尺度 (CFG)# 每类最大样本数 max_samples_per_class =100# 限制每类动物使用的最大样本数# 评估参数 num_evaluation_samples =10# 评估样本数量 clip_model_name ="openai/clip-vit-base-patch32"# CLIP模型名称② 数据处理
AnimalDataset 类实现了动物图像数据集的 加载和预处理 功能,核心特点包括:
- 层级数据加载:数据集采用
"根目录 / 动物类别 / 图像文件"的层级结构,通过扫描子文件夹自动识别类别名称。 - 样本均衡处理:对样本数量超过
max_samples_per_class的类别进行随机采样,确保各类别样本量相对均衡。 - 提示词工程:定义多种提示词模板,避免模型对单一表述产生过拟合,增强生成多样性。模板涵盖了不同角度(
特写、自然栖息地)和属性(可爱、野生)的描述,丰富了模型的条件学习信号。
# 1. 数据处理与准备 - 添加样本限制classAnimalDataset(Dataset):def__init__(self, data_root, tokenizer, size=384, center_crop=True, random_flip=True, max_samples_per_class=100): self.data_root = data_root self.tokenizer = tokenizer self.size = size # 使用新的分辨率 self.center_crop = center_crop self.random_flip = random_flip self.max_samples_per_class = max_samples_per_class # 获取所有图像路径和对应的类别(动物名称) self.image_paths =[] self.class_names =[]# 假设子文件夹以动物英文名称命名 subfolders =[f.name for f in os.scandir(data_root)if f.is_dir()]for class_name in subfolders: class_dir = os.path.join(data_root, class_name) image_files = glob.glob(os.path.join(class_dir,"*.jpg"))+ \ glob.glob(os.path.join(class_dir,"*.png"))+ \ glob.glob(os.path.join(class_dir,"*.jpeg"))# 限制每类样本数量iflen(image_files)> max_samples_per_class: image_files = random.sample(image_files, max_samples_per_class)for img_path in image_files: self.image_paths.append(img_path) self.class_names.append(class_name)# 为每个类别创建提示词模板 self.prompt_templates =["a photo of a {}","a high quality image of a {}","a clear picture of a {}","a realistic image of a {}","a cute {}","a wild {} in its natural habitat","a close-up of a {}"]- 图像预处理:
- 采用 中心裁剪 策略保留图像主体内容,避免边缘干扰
- 使用
LANCZOS重采样方法进行缩放,保证图像质量 - 随机水平翻转增强数据多样性,降低过拟合风险
- 像素值归一化到 [-1, 1] 范围,符合 Stable Diffusion 输入要求
- 文本处理:
- 随机选择提示词模板,为每个图像生成多样化描述
- 使用
CLIP tokenizer对文本进行编码,生成模型可理解的输入_ids - 确保文本长度一致(通过
padding和truncation),便于批量处理
- 返回数据结构:精心设计的字典结构包含图像张量、文本编码、原始提示和类别名称,满足训练和日志记录需求
def__len__(self):returnlen(self.image_paths)def__getitem__(self, idx): image_path = self.image_paths[idx] class_name = self.class_names[idx]# 加载和预处理图像 image = Image.open(image_path).convert("RGB")# 调整大小和中心裁剪if self.center_crop:# 保持宽高比的调整大小和中心裁剪 image = self._center_crop(image)else: image = image.resize((self.size, self.size), Image.Resampling.LANCZOS)# 随机水平翻转 (数据增强)if self.random_flip and random.random()<0.5: image = image.transpose(Image.FLIP_LEFT_RIGHT)# 将图像转换为模型输入的张量 (-1 to 1) image_tensor =(torch.tensor(np.array(image).astype(np.float32)/127.5)-1.0).permute(2,0,1)# 为图像生成随机的提示词 prompt_template = random.choice(self.prompt_templates) prompt = prompt_template.format(class_name)# 对提示词进行标记化 tokenized_input = self.tokenizer( prompt, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt") input_ids = tokenized_input.input_ids.squeeze(0)return{"pixel_values": image_tensor,"input_ids": input_ids,"prompt": prompt,"class_name": class_name }def_center_crop(self, image): width, height = image.size new_size =min(width, height) left =(width - new_size)/2 top =(height - new_size)/2 right =(width + new_size)/2 bottom =(height + new_size)/2 image = image.crop((left, top, right, bottom)) image = image.resize((self.size, self.size), Image.Resampling.LANCZOS)return image ③ 早停机制
传统的基于损失的早停可能无法准确反映生成模型的质量,本项目采用 基于 CLIP 分数 的早停策略,具有以下特点:
- 质量导向停止:当 CLIP 分数(
图像 - 文本匹配度)不再显著提升时停止训练,更符合生成任务的质量目标。 - 参数配置:
patience:允许分数不提升的轮数,设置为 5 5 5 给予模型足够的优化空间delta:最小改善阈值, 0.02 0.02 0.02 确保只有显著提升才被认可
- 工作原理:维护最佳分数记录,连续多轮未达到改善阈值则触发早停,有效防止过拟合和节省计算资源
# 早停机制 (PyTorch实现) - 使用CLIP分数作为指标classEarlyStopping:def__init__(self, patience=3, delta=0.05, verbose=False): self.patience = patience self.delta = delta self.verbose = verbose self.counter =0 self.best_score =None self.early_stop =Falsedef__call__(self, clip_score):if self.best_score isNone: self.best_score = clip_score elif clip_score < self.best_score + self.delta: self.counter +=1if self.verbose:print(f"EarlyStopping counter: {self.counter} out of {self.patience}")if self.counter >= self.patience: self.early_stop =Trueelse: self.best_score = clip_score self.counter =0④ LoRA 模型配置
函数实现了对 Stable Diffusion 核心组件UNet 的 LoRA 适配,是参数高效微调的关键:
- LoRA 核心参数:
r( r a n k rank rank):低秩矩阵的秩,设置为 2 2 2 大幅减少可训练参数lora_alpha:缩放因子,与秩配合控制更新幅度target_modules:指定需要注入 LoRA 的模块,选择注意力层的查询、键、值投影和输出层
- 参数效率:通过 PEFT 库的
get_peft_model函数将 LoRA 适配器注入UNet,仅训练少量适配器参数而非整个模型。print_trainable_parameters()会输出可训练参数比例,通常仅为原始模型的 0.1 % 0.1\% 0.1% 左右。 - 优势:相比全参数微调,LoRA 方法显著降低内存需求,加快训练速度,同时减少过拟合风险,特别适合小数据集场景
# 为UNet准备LoRA的函数 - 使用peft库defprepare_unet_for_lora(unet, rank=2, alpha=16):# 配置LoRA lora_config = LoraConfig( r=rank, lora_alpha=alpha, target_modules=["to_q","to_k","to_v","to_out.0"], lora_dropout=0.0, bias="none",)# 应用LoRA到UNet unet = get_peft_model(unet, lora_config) unet.print_trainable_parameters()return unet ⑤ CLIP 分数计算
CLIP(Contrastive Language-Image Pretraining)分数用于 量化评估生成图像与文本描述的匹配程度,是生成质量的重要指标:
- 评估流程:
- 随机选择代表性动物类别进行评估
- 使用微调后的模型生成对应动物图像
- 通过 CLIP 模型计算生成图像与文本提示的 相似度
- 技术细节:
- CLIP 模型将图像和文本编码到共享语义空间
logits_per_image表示图像与文本的匹配分数,值越高表示匹配度越好- 使用
torch.autocast和torch.no_grad优化计算效率
- 优势:相比人工评估更高效,相比 MSE 损失更能反映语义层面的质量,为早停机制和模型选择提供客观依据
# 计算验证CLIP分数的函数defcompute_validation_clip_score(config, unet, text_encoder, vae, tokenizer, device):# 获取所有动物类别 animal_classes =[f.name for f in os.scandir(config.data_root)if f.is_dir()]# 随机选择验证用的动物 selected_animals = random.sample(animal_classes,min(config.num_validation_samples,len(animal_classes)))print(f"Selected animals for validation CLIP score: {selected_animals}")# 创建生成管道 pipe = StableDiffusionPipeline.from_pretrained( config.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer, safety_checker=None, torch_dtype=torch.float16 if config.mixed_precision =="fp16"else torch.float32 ).to(device)# 加载CLIP模型和处理器 clip_model = CLIPModel.from_pretrained(config.clip_model_name).to(device) clip_processor = CLIPProcessor.from_pretrained(config.clip_model_name) clip_scores =[]for animal in selected_animals: prompt =f"a high quality photo of a {animal}"# 生成图像with torch.autocast(device.type): image = pipe( prompt, num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, height=config.resolution, width=config.resolution ).images[0]# 计算CLIP Scorewith torch.no_grad():# 处理图像和文本 inputs = clip_processor( text=[prompt], images=image, return_tensors="pt", padding=True).to(device)# 获取特征 outputs = clip_model(**inputs)# 计算相似度 (CLIP Score) logits_per_image = outputs.logits_per_image # 图像-文本相似度 clip_score = logits_per_image.item()print(f"Animal: {animal}, CLIP Score: {clip_score:.4f}") clip_scores.append(clip_score) avg_clip_score = np.mean(clip_scores)print(f"Average Validation CLIP Score: {avg_clip_score:.4f}")return avg_clip_score ⑥ 开始训练!
模型初始化:
训练函数的初始部分负责加载和配置 Stable Diffusion 的核心组件:
- 组件构成:
- CLIP 组件:tokenizer(文本分词)和 text_encoder(文本编码)
- VAE:变分自编码器,负责图像与潜空间的转换
- UNet:扩散模型核心,通过 LoRA 适配后成为主要训练对象
- 噪声调度器:控制扩散过程的噪声添加和去除
- 显存优化:
- 启用梯度检查点(
gradient checkpointing)牺牲少量计算时间换取显存节省 - 仅 UNet 添加可训练参数,其他组件保持冻结状态
- 启用梯度检查点(
# 2. 训练函数 (包含早停和历史记录)deftrain_lora_with_earlystopping(config):# 初始化模型组件 tokenizer = CLIPTokenizer.from_pretrained( config.pretrained_model_name_or_path, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained( config.pretrained_model_name_or_path, subfolder="text_encoder") vae = AutoencoderKL.from_pretrained( config.pretrained_model_name_or_path, subfolder="vae") unet = UNet2DConditionModel.from_pretrained( config.pretrained_model_name_or_path, subfolder="unet")# 添加LoRA适配器到UNet unet = prepare_unet_for_lora(unet, config.rank, config.lora_alpha)# 设置噪声调度器 noise_scheduler = DDPMScheduler.from_pretrained( config.pretrained_model_name_or_path, subfolder="scheduler")# 启用梯度检查点以节省显存if config.gradient_checkpointing: unet.enable_gradient_checkpointing()# 将模型移动到GPU device = torch.device("cuda"if torch.cuda.is_available()else"cpu") text_encoder.to(device) vae.to(device) unet.to(device)优化器与数据加载配置 :
- 优化器设置:
- 仅对 LoRA 参数进行优化,大幅减少优化器状态内存占用
- 采用
AdamW优化器,配合合理的权重衰减( 0.01 0.01 0.01)防止过拟合 - 学习率设置为 1 e − 5 1e-5 1e−5,远低于全参数微调,适合 LoRA 稳定训练
- 数据加载:
- 使用自定义
AnimalDataset加载数据 - 按 9 : 1 9:1 9:1 比例分割训练集和验证集
- 配置
DataLoader实现批量加载和多进程预处理
- 使用自定义
# 设置优化器 (只优化LoRA参数) lora_params =[]for name, param in unet.named_parameters():if param.requires_grad:# 只优化需要梯度的参数 lora_params.append(param)# 使用更稳定的优化器配置 optimizer = torch.optim.AdamW( lora_params, lr=config.learning_rate, betas=(0.9,0.999), eps=1e-8, weight_decay=0.01)# 准备数据集和数据加载器 full_dataset = AnimalDataset( config.data_root, tokenizer, size=config.resolution, center_crop=config.center_crop, random_flip=config.random_flip, max_samples_per_class=config.max_samples_per_class )# 分割训练集和验证集 val_size =int(len(full_dataset)* config.validation_split) train_size =len(full_dataset)- val_size train_dataset, val_dataset = random_split(full_dataset,[train_size, val_size]) train_dataloader = DataLoader( train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=2) val_dataloader = DataLoader( val_dataset, batch_size=config.train_batch_size, shuffle=False, num_workers=2)训练调度与记录配置:
- 学习率调度:
- 采用 余弦退火 带预热的调度策略
- 预热步数 200 200 200,使模型在初始阶段稳定收敛
- 总步数根据 e p o c h epoch epoch 数和梯度累积动态计算
- 训练记录:
- 使用 openpyxl 创建 Excel 日志
- 记录关键指标:损失、CLIP 分数、学习率、梯度范数等
- 为后续分析和模型优化提供完整数据支持
# 计算总训练步数 num_update_steps_per_epoch =len(train_dataloader)// config.gradient_accumulation_steps max_train_steps = config.num_train_epochs * num_update_steps_per_epoch # 学习率调度器 lr_scheduler = get_cosine_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=config.lr_warmup_steps, num_training_steps=max_train_steps )# 初始化早停 (使用CLIP分数作为指标) early_stopping = EarlyStopping( patience=config.early_stopping_patience, delta=config.early_stopping_delta, verbose=True)# 创建Excel工作簿用于记录历史 history_wb = Workbook() history_ws = history_wb.active history_ws.title ="Training History" history_ws.append(["Epoch","Step","Train Loss","Validation Loss","CLIP Score","Learning Rate","Best CLIP Score","Gradient Norm"])核心训练循环:
训练循环实现了 Stable Diffusion 的噪声预测训练过程,关键步骤包括:
- 潜空间编码:
- 使用 VAE 将图像编码到潜空间,
with torch.no_grad()冻结 VAE 参数 - 应用
0.18215缩放因子,这是 Stable Diffusion 的标准处理流程
- 使用 VAE 将图像编码到潜空间,
- 扩散过程模拟:
- 随机采样时间步和噪声
- 向潜变量添加噪声,模拟扩散过程
- UNet 根据带噪声的潜变量、时间步和文本嵌入预测原始噪声
- 梯度管理:
- 采用梯度累积(4 步)模拟更大批次训练
- 梯度裁剪 控制梯度范数,防止训练不稳定
- 仅在累积步数完成后更新参数,节省显存
# 训练循环 global_step =0 best_clip_score =0.0# 训练循环部分for epoch inrange(config.num_train_epochs): unet.train() total_loss =0 optimizer.zero_grad() current_grad_norm =0.0# 初始化梯度范数for step, batch inenumerate(train_dataloader):# 将批次数据移动到设备 pixel_values = batch["pixel_values"].to(device) input_ids = batch["input_ids"].to(device)# 将图像编码到潜在空间with torch.no_grad(): latents = vae.encode(pixel_values).latent_dist.sample() latents = latents *0.18215# 缩放因子# 采样噪声 noise = torch.randn_like(latents) bsz = latents.shape[0] timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,(bsz,), device=device).long()# 向潜在表示添加噪声 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)# 获取文本嵌入with torch.no_grad(): encoder_hidden_states = text_encoder(input_ids)[0]# 预测噪声残差 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample # 计算损失 loss = F.mse_loss(noise_pred, noise, reduction="mean")/ config.gradient_accumulation_steps # 反向传播 loss.backward()# 梯度累积if(step +1)% config.gradient_accumulation_steps ==0:# 梯度裁剪 torch.nn.utils.clip_grad_norm_(lora_params, config.max_grad_norm)# 计算梯度范数用于监控 current_grad_norm =0for p in lora_params:if p.grad isnotNone: param_norm = p.grad.data.norm(2) current_grad_norm += param_norm.item()**2 current_grad_norm = current_grad_norm **0.5# 更新参数 optimizer.step() lr_scheduler.step() optimizer.zero_grad() global_step +=1 total_loss += loss.item()* config.gradient_accumulation_steps # 打印训练信息if global_step %50==0:# 减少打印频率 avg_loss = total_loss /(step +1) current_lr = lr_scheduler.get_last_lr()[0]print(f"Epoch {epoch}, Step {global_step}, Loss: {avg_loss:.4f}, LR: {current_lr:.6f}, Grad Norm: {current_grad_norm:.6f}")epoch 后处理流程:
每个训练 epoch 结束后执行的关键操作:
- 验证评估:
- 计算验证集损失,评估模型泛化能力
- 生成验证图像并计算 CLIP 分数,评估生成质量
- 模型保存:
- 仅保存 CLIP 分数提升的模型,确保最佳性能
- 单独保存 LoRA 权重,文件小且便于部署
- 早停检查:
- 基于 CLIP 分数判断是否继续训练
- 达到早停条件则终止训练,避免无效计算
# 每个epoch结束后计算验证损失和CLIP分数 val_loss = compute_validation_loss(unet, vae, text_encoder, val_dataloader, noise_scheduler, device) avg_train_loss = total_loss /len(train_dataloader)print(f"Epoch {epoch} completed. Train Loss: {avg_train_loss:.4f}, Validation Loss: {val_loss:.4f}")# 计算CLIP分数 clip_score = compute_validation_clip_score(config, unet, text_encoder, vae, tokenizer, device)# 记录到历史 current_lr = lr_scheduler.get_last_lr()[0] history_ws.append([epoch, global_step, avg_train_loss, val_loss, clip_score, current_lr, best_clip_score, current_grad_norm])# 早停检查 (基于CLIP分数) early_stopping(clip_score)# 保存最佳模型if clip_score > best_clip_score: best_clip_score = clip_score # 保存LoRA权重 save_path = os.path.join(config.lora_model_dir,f"lora_weights_epoch_{epoch}.safetensors") save_lora_weights(unet, save_path)print(f"Saved best model with CLIP score: {best_clip_score:.4f}")# 保存训练历史 history_wb.save(config.history_file)# 检查早停if early_stopping.early_stop:print("Early stopping triggered")breakprint("Training completed!")return unet, text_encoder, vae, tokenizer 训练过程示例输出如下所示:
Starting LoRA training…
trainable params: 398,592 || all params: 859,919,556 || trainable%: 0.0464
Epoch 0, Step 0, Loss: 0.0044, LR: 0.000000, Grad Norm: 0.000000
Epoch 0, Step 0, Loss: 0.0077, LR: 0.000000, Grad Norm: 0.000000
Epoch 0, Step 0, Loss: 0.0216, LR: 0.000000, Grad Norm: 0.000000
Epoch 0, Step 50, Loss: 0.2194, LR: 0.000003, Grad Norm: 0.227759
Epoch 0, Step 50, Loss: 0.2198, LR: 0.000003, Grad Norm: 0.227759
…
Epoch 0, Step 2250, Loss: 0.1937, LR: 0.000010, Grad Norm: 0.048456
Epoch 0 completed. Train Loss: 0.1937, Validation Loss: 0.1842
Selected animals for validation CLIP score: [‘dolphin’, ‘zebra’, ‘sandpiper’, ‘swan’, ‘pig’]
Animal: dolphin, CLIP Score: 29.2128
Animal: zebra, CLIP Score: 32.8276
Animal: sandpiper, CLIP Score: 30.2431
Animal: swan, CLIP Score: 28.6533
Animal: pig, CLIP Score: 29.8067
Average Validation CLIP Score: 30.1487
Saved best model with CLIP score: 30.1487
…
Evaluation results saved to /kaggle/working/output/evaluation_results.xlsx
All done! Average CLIP Score: 31.1877
⑦ 验证与评估
验证样本生成:
该函数在训练完成后生成代表性样本用于可视化评估:
- 随机选择动物类别,避免评估偏差
- 使用更多推理步数(
100 步)生成高质量样本 - 保存单张图像并创建网格对比图,便于直观检查
# 3. 验证和生成样本defgenerate_validation_samples(config, unet, text_encoder, vae, tokenizer): device = torch.device("cuda"if torch.cuda.is_available()else"cpu")# 获取所有动物类别 animal_classes =[f.name for f in os.scandir(config.data_root)if f.is_dir()]# 随机选择5种动物 selected_animals = random.sample(animal_classes, config.num_validation_samples)print(f"Selected animals for validation: {selected_animals}")# 创建生成管道 pipe = StableDiffusionPipeline.from_pretrained( config.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer, safety_checker=None,# 禁用安全检查器以加快生成速度 torch_dtype=torch.float16 if config.mixed_precision =="fp16"else torch.float32 ).to(device)# 生成每种动物的图像 all_images =[] all_titles =[]for animal in selected_animals: prompt =f"a high quality photo of a {animal}"# 生成图像 (使用更多推理步数)with torch.autocast(device.type): image = pipe( prompt, num_inference_steps=config.num_final_inference_steps, guidance_scale=config.guidance_scale, height=config.resolution, width=config.resolution ).images[0]# 保存图像 save_path = os.path.join(config.sample_output_dir,f"{animal}.png") image.save(save_path)print(f"Generated image for {animal} saved at {save_path}")# 添加到列表用于创建对比图像 all_images.append(image) all_titles.append(animal)# 创建对比图像 comparison_path = os.path.join(config.comparison_dir,"animal_comparison.png") create_comparison_image(all_images, all_titles, comparison_path)return selected_animals 量化评估实现:
该函数提供训练后的全面量化评估:
- 对更多动物类别(10 种)进行评估,提高结果可靠性
- 结合生成图像和 CLIP 分数,形成完整评估记录
- 将结果保存到 Excel,包含单类分数和平均分数,便于分析模型优势和不足
# 4. 评估函数 - 使用CLIP Score评估生成质量defevaluate_with_clip_score(config, unet, text_encoder, vae, tokenizer): device = torch.device("cuda"if torch.cuda.is_available()else"cpu")# 加载CLIP模型和处理器 clip_model = CLIPModel.from_pretrained(config.clip_model_name).to(device) clip_processor = CLIPProcessor.from_pretrained(config.clip_model_name)# 获取所有动物类别 animal_classes =[f.name for f in os.scandir(config.data_root)if f.is_dir()]# 随机选择评估用的动物 selected_animals = random.sample(animal_classes, config.num_evaluation_samples)print(f"Selected animals for evaluation: {selected_animals}")# 创建生成管道 pipe = StableDiffusionPipeline.from_pretrained( config.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer, safety_checker=None, torch_dtype=torch.float16 if config.mixed_precision =="fp16"else torch.float32 ).to(device)# 存储评估结果 evaluation_results =[]for animal in selected_animals: prompt =f"a high quality photo of a {animal}"# 生成图像 (使用更多推理步数)with torch.autocast(device.type): image = pipe( prompt, num_inference_steps=config.num_final_inference_steps, guidance_scale=config.guidance_scale, height=config.resolution, width=config.resolution ).images[0]# 保存图像 save_path = os.path.join(config.sample_output_dir,f"eval_{animal}.png") image.save(save_path)# 计算CLIP Scorewith torch.no_grad():# 处理图像和文本 inputs = clip_processor( text=[prompt], images=image, return_tensors="pt", padding=True).to(device)# 获取特征 outputs = clip_model(** inputs)# 计算相似度 (CLIP Score) logits_per_image = outputs.logits_per_image # 图像-文本相似度 clip_score = logits_per_image.item()print(f"Animal: {animal}, CLIP Score: {clip_score:.4f}") evaluation_results.append({"animal": animal,"prompt": prompt,"clip_score": clip_score,"image_path": save_path })# 计算平均CLIP Score avg_clip_score = np.mean([result["clip_score"]for result in evaluation_results])print(f"Average CLIP Score: {avg_clip_score:.4f}")# 保存评估结果到Excel evaluation_wb = Workbook() evaluation_ws = evaluation_wb.active evaluation_ws.title ="Evaluation Results" evaluation_ws.append(["Animal","Prompt","CLIP Score","Image Path"])for result in evaluation_results: evaluation_ws.append([result["animal"], result["prompt"], result["clip_score"], result["image_path"]]) evaluation_ws.append([]) evaluation_ws.append(["Average CLIP Score", avg_clip_score]) evaluation_wb.save(config.evaluation_file)print(f"Evaluation results saved to {config.evaluation_file}")return evaluation_results, avg_clip_score 2. UI 界面代码实现
① 全局参数配置
- 模型路径:
pretrained_model_name_or_path指定本地 Stable Diffusion 基础模型目录,需包含text_encoder、unet、vae等核心组件; - 引导尺度(guidance_scale):核心生成参数 —— 值过高会导致图像 “过度贴合提示词”(可能色彩失真),值过低则生成结果偏离提示词,此处设为 5.0 5.0 5.0 是平衡贴合度与色彩自然度的经验值;
- 色彩因子:新增的
contrast/saturation/brightness_factor为后期图像校正参数,通过 PIL 库实现实时调整,解决扩散模型生成图像可能偏灰、色彩暗淡的问题。
# 配置类 - 增加色彩相关参数classConfig: pretrained_model_name_or_path ="model/LCM-runwayml-stable-diffusion-v1-5"# 本地基础模型路径 resolution =512# 生成图像分辨率(默认512x512) rank =2# LoRA微调秩(控制微调强度) lora_alpha =16# LoRA缩放因子 device ="cpu"# 运行设备(cpu/cuda,cuda需安装GPU版本PyTorch) num_final_inference_steps =100# 默认推理步数(步数越多生成越精细,但耗时更长) guidance_scale =5.0# 引导尺度(控制提示词对生成的影响,值越低色彩越自然) contrast_factor =1.0# 对比度调整因子(1.0为默认,<1降低对比度,>1增强) saturation_factor =1.0# 饱和度调整因子(同上,影响色彩鲜艳度) brightness_factor =1.0# 亮度调整因子(同上,影响图像明暗)② 核心技术函数
包含 3 个核心工具函数,分别解决 LoRA 权重加载、Tokenizer 加载异常、图像色彩校正 三大关键问题,是模型正常运行与生成效果优化的基础
- LoRA 权重加载函数(load_lora_weights)
LoRA(Low-Rank Adaptation)是轻量级微调技术,通过加载预训练的 LoRA 权重,可让基础模型快速适配 “动物图像生成” 场景(无需重新训练整个模型)
# 加载LoRA权重的函数defload_lora_weights(unet, load_path):# 从本地文件加载LoRA权重,指定设备(与模型一致) lora_state_dict = torch.load(load_path, map_location=torch.device(Config.device))# 非严格模式加载(LoRA权重仅覆盖unet部分层,无需匹配所有参数) unet.load_state_dict(lora_state_dict, strict=False)return unet - Tokenizer 修复函数(load_tokenizer_with_fix)
Tokenizer(文本分词器)是将 “提示词” 转换为模型可识别向量的组件,该函数解决了 “本地模型 Tokenizer 路径异常” 的常见问题,提供降级加载方案
# 修复tokenizer加载问题的函数defload_tokenizer_with_fix(model_path):try:# 尝试正常加载(默认路径:模型目录下的tokenizer文件夹) tokenizer = CLIPTokenizer.from_pretrained( os.path.join(model_path,"tokenizer"))return tokenizer except Exception as e:print(f"加载tokenizer时出错: {e}")print("尝试修复tokenizer配置...")# 降级方案:手动指定vocab.json和merges.txt文件(Tokenizer核心文件)from transformers import CLIPTokenizerFast vocab_file = os.path.join(model_path,"tokenizer","vocab.json") merges_file = os.path.join(model_path,"tokenizer","merges.txt")if os.path.exists(vocab_file)and os.path.exists(merges_file): tokenizer = CLIPTokenizerFast( vocab_file=vocab_file, merges_file=merges_file, max_length=77,# CLIP模型固定输入长度(超过截断,不足补全) pad_token="!",# 填充token(统一输入长度) additional_special_tokens=["<startoftext|>","<endoftext|>"]# 特殊分隔符)return tokenizer else:raise Exception(f"找不到tokenizer文件: {vocab_file} 或 {merges_file}")- 图像色彩校正函数(adjust_image_colors)
扩散模型生成的图像常存在 “对比度不足、色彩暗淡” 问题,该函数通过 PIL 的ImageEnhance模块,对生成图像进行后处理,提升视觉效果
# 图像色彩校正函数defadjust_image_colors(image):"""调整图像的色彩、对比度和饱和度,使其更自然"""# 1. 调整对比度(增强细节层次) enhancer = ImageEnhance.Contrast(image) image = enhancer.enhance(Config.contrast_factor)# 2. 调整饱和度(提升色彩鲜艳度,避免偏灰) enhancer = ImageEnhance.Color(image) image = enhancer.enhance(Config.saturation_factor)# 3. 调整亮度(平衡整体明暗,避免过暗/过曝) enhancer = ImageEnhance.Brightness(image) image = enhancer.enhance(Config.brightness_factor)return image ③ 模型加载器:整合 Stable Diffusion 核心组件(ModelLoader)
Stable Diffusion 由Tokenizer + Text Encoder + UNet + VAE + Scheduler五大组件构成,ModelLoader类负责将这些组件从本地加载、组装为可直接调用的StableDiffusionPipeline(生成流水线),并集成 LoRA 权重
- 组件分工:
Text Encoder负责文本→向量,UNet负责向量→latent(隐空间向量),VAE负责latent→图像,Scheduler控制去噪步数节奏; - LoRA 集成:仅对
UNet加载 LoRA 权重 —— 因为 UNet 是扩散模型的核心生成层,微调 UNet 即可快速改变生成风格(动物图像),无需调整其他组件; - 设备适配:所有组件需统一移动到同一设备(cpu/cuda),否则会出现 “设备不匹配” 错误。
# 模型加载类classModelLoader:def__init__(self, config, lora_model_path): self.config = config # 全局配置 self.lora_model_path = lora_model_path # LoRA模型路径# 初始化各组件(后续加载) self.tokenizer =None# 文本分词器 self.text_encoder =None# 文本编码器(将分词结果转为向量) self.vae =None# 变分自编码器(负责图像解码:latent→像素) self.unet =None# 核心生成网络(扩散过程核心,更新latent) self.pipe =None# 最终生成流水线defload_models(self):# 1. 加载Tokenizer(调用修复函数,避免路径异常) self.tokenizer = load_tokenizer_with_fix(self.config.pretrained_model_name_or_path)# 2. 加载Text Encoder(CLIP模型,将文本转为语义向量) text_encoder_path = os.path.join(self.config.pretrained_model_name_or_path,"text_encoder") self.text_encoder = CLIPTextModel.from_pretrained(text_encoder_path)# 3. 加载VAE(将扩散过程的 latent 向量解码为图像像素) vae_path = os.path.join(self.config.pretrained_model_name_or_path,"vae") self.vae = AutoencoderKL.from_pretrained(vae_path)# 4. 加载UNet(扩散核心,通过迭代去噪生成latent) unet_path = os.path.join(self.config.pretrained_model_name_or_path,"unet") self.unet = UNet2DConditionModel.from_pretrained(unet_path)# 5. 加载LoRA权重到UNet(让模型适配动物图像生成) self.unet = load_lora_weights(self.unet, self.lora_model_path)# 6. 将所有组件移动到指定设备(cpu/cuda) self.text_encoder.to(self.config.device) self.vae.to(self.config.device) self.unet.to(self.config.device)# 7. 加载Scheduler(扩散调度器,控制去噪步骤节奏) scheduler_path = os.path.join(self.config.pretrained_model_name_or_path,"scheduler") scheduler = DDPMScheduler.from_pretrained(scheduler_path)# 8. 组装生成流水线(整合所有组件,提供统一生成接口) self.pipe = StableDiffusionPipeline( vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet, scheduler=scheduler, safety_checker=None,# 关闭安全检查(避免误判动物图像) feature_extractor=None, requires_safety_checker=False)return self.pipe ④ 生成线程:避免 UI 卡顿(GenerateThread)
图像生成是耗时操作(尤其 CPU 运行时),若直接在主线程执行会导致 UI 卡死。GenerateThread继承QThread,将生成逻辑放入子线程,通过信号机制与主线程(UI)交互,实时反馈进度
- 信号机制:通过
pyqtSignal定义 3 3 3 类信号,实现子线程与 UI 的“无阻塞通信”—— 进度更新实时反馈,完成 / 错误信号触发 UI 后续操作; - 双嵌入设计:同时计算 “条件嵌入(有提示词)” 和 “无条件嵌入(空提示词)”,通过引导尺度控制生成结果与提示词的贴合度;
- Latent 尺寸:分辨率需除以 8(VAE 的下采样比例),例如512x512的图像对应 Latent 尺寸为64x64;
- 色彩校正集成:在图像生成完成后、返回 UI 前执行色彩调整,确保最终显示的图像色彩自然。
# 生成线程类 - 增加色彩校正步骤classGenerateThread(QThread):# 定义信号:生成完成(返回PIL图像)、错误(返回错误信息)、进度更新(进度百分比+剩余时间) finished = pyqtSignal(Image.Image) error = pyqtSignal(str) progress_updated = pyqtSignal(int,float)def__init__(self, pipe, animal_name, num_inference_steps, guidance_scale, contrast_factor, saturation_factor, brightness_factor):super().__init__() self.pipe = pipe # 生成流水线 self.animal_name = animal_name # 目标动物名称(用户输入) self.num_inference_steps = num_inference_steps # 推理步数 self.guidance_scale = guidance_scale # 引导尺度# 色彩调整参数(从UI获取,覆盖全局配置) self.contrast_factor = contrast_factor self.saturation_factor = saturation_factor self.brightness_factor = brightness_factor self.start_time =0# 生成开始时间(计算总耗时) self.step_times =[]# 每步耗时(估算剩余时间)defrun(self):try:# 1. 优化提示词(增加环境/光照描述,提升生成质量) prompt =(f"a high quality photo of a {self.animal_name}, natural lighting, "f"realistic colors, in natural habitat, detailed texture")# 2. 文本编码(生成“条件嵌入”和“无条件嵌入”,用于引导生成)with torch.no_grad():# 禁用梯度计算,减少内存占用# 条件嵌入:基于提示词的向量(引导模型生成符合提示的内容) text_inputs = self.pipe.tokenizer( prompt, padding="max_length",# 补全到77长度 max_length=self.pipe.tokenizer.model_max_length, truncation=True,# 超过77长度截断 return_tensors="pt",# 返回PyTorch张量) text_input_ids = text_inputs.input_ids text_embeddings = self.pipe.text_encoder(text_input_ids.to(self.pipe.device))[0]# 无条件嵌入:基于空提示词的向量(作为对比,让模型更“关注”条件提示词) max_length = text_input_ids.shape[-1] uncond_input = self.pipe.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt",) uncond_embeddings = self.pipe.text_encoder(uncond_input.input_ids.to(self.pipe.device))[0]# 合并两种嵌入(Stable Diffusion要求输入格式:[无条件, 条件]) text_embeddings = torch.cat([uncond_embeddings, text_embeddings])# 3. 初始化Latent(隐空间向量,扩散模型的“初始噪声”) latents = torch.randn((1, self.pipe.unet.config.in_channels, Config.resolution //8, Config.resolution //8),# 尺寸=分辨率/8(VAE下采样比例) generator=torch.Generator(device=Config.device),# 随机生成器(保证可复现) device=Config.device,)# 4. 配置调度器(设置推理步数) self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=Config.device)# 5. 迭代扩散去噪(核心步骤:逐步将噪声转为符合提示词的latent) self.start_time = time.time() self.step_times =[]for i, t inenumerate(self.pipe.scheduler.timesteps): step_start_time = time.time()# 复制latent(对应两种嵌入:无条件+条件) latent_model_input = torch.cat([latents]*2)# 调度器缩放输入(匹配当前去噪步骤的噪声水平) latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)# UNet预测噪声(输入:当前latent+时间步t+文本嵌入,输出:预测的噪声) noise_pred = self.pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # 分离噪声(无条件预测 vs 条件预测) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)# 引导式去噪(用引导尺度控制提示词影响:噪声=无条件噪声 + 引导尺度*(条件噪声-无条件噪声)) noise_pred = noise_pred_uncond + self.guidance_scale *(noise_pred_text - noise_pred_uncond)# 调度器更新latent(根据预测噪声去除当前步骤的噪声) latents = self.pipe.scheduler.step(noise_pred, t, latents).prev_sample # 6. 计算进度与剩余时间(反馈给UI) step_time = time.time()- step_start_time self.step_times.append(step_time) progress =int((i +1)/ self.num_inference_steps *100)# 进度百分比 steps_remaining = self.num_inference_steps -(i +1)# 剩余步数# 估算剩余时间(用最近5步的平均耗时,避免初始步骤波动影响)iflen(self.step_times)>=5: avg_step_time =sum(self.step_times[-5:])/5else: avg_step_time =sum(self.step_times)/len(self.step_times)if self.step_times else0 remaining_time = avg_step_time * steps_remaining # 发送进度信号(UI接收后更新进度条) self.progress_updated.emit(progress, remaining_time)# 7. 解码Latent为图像(VAE将隐空间向量转为像素) latents =1/0.18215* latents # VAE解码缩放因子(固定值,模型训练时确定)with torch.no_grad(): image = self.pipe.vae.decode(latents).sample # 8. 图像后处理(标准化→转PIL→色彩校正) image =(image /2+0.5).clamp(0,1)# 标准化:将[-1,1]转为[0,1] image = image.cpu().permute(0,2,3,1).float().numpy()# 调整维度:(1,C,H,W)→(H,W,C) image =(image[0]*255).round().astype("uint8")# 转为8位像素(0-255) image = Image.fromarray(image)# 转PIL图像# 应用色彩校正(调用之前定义的函数) image = adjust_image_colors(image)# 确保图像分辨率一致if image.size !=(Config.resolution, Config.resolution): image = image.resize((Config.resolution, Config.resolution), Image.LANCZOS)# 高质量缩放# 9. 发送生成完成信号(UI接收后显示图像) self.finished.emit(image)except Exception as e:# 发送错误信号(UI接收后弹窗提示) self.error.emit(str(e))⑤ 主窗口 UI:可视化交互入口(AnimalGeneratorApp)
AnimalGeneratorApp继承QMainWindow,是整个工具的 “用户交互中心”,负责构建 UI 布局、绑定按钮事件、处理线程信号(显示进度 / 图像 / 错误)。核心分为左侧控制面板和右侧图像显示区两部分,以下重点讲解核心功能逻辑:
- UI 布局核心逻辑
classAnimalGeneratorApp(QMainWindow):def__init__(self):super().__init__() self.pipe =None# 生成流水线(加载模型后赋值) self.current_image =None# 当前生成的图像 self.initUI()# 初始化UIdefinitUI(self):# 1. 基础设置(字体、窗口标题、尺寸) font = QFont("SimHei")# 支持中文显示(避免乱码) font.setPointSize(10) self.setFont(font) self.setWindowTitle('动物图像生成器') self.setGeometry(100,100,1100,800)# 窗口位置与尺寸# 2. 中心部件与主布局(左右分栏:控制面板+图像显示区) central_widget = QWidget() self.setCentralWidget(central_widget) main_layout = QHBoxLayout(central_widget) main_layout.setContentsMargins(15,15,15,15) main_layout.setSpacing(20)# 3. 左侧控制面板(模型设置、生成参数、进度) control_panel = self.create_control_panel() main_layout.addWidget(control_panel,3)# 占3份宽度# 4. 右侧图像显示区(默认图、生成图、水印) image_panel = self.create_image_panel() main_layout.addWidget(image_panel,5)# 占5份宽度(图像区更宽,提升体验)- 核心功能绑定(以 “生成图像” 为例)
defgenerate_image(self):# 前置校验(避免无效操作)ifnot self.pipe:# 未加载模型 QMessageBox.warning(self,"错误","请先加载模型")return animal_name = self.animal_edit.text().strip()ifnot animal_name:# 未输入动物名称 QMessageBox.warning(self,"错误","请输入动物名称")return# 1. 获取UI参数(覆盖全局配置) Config.resolution = self.resolution_spin.value() num_inference_steps = self.steps_spin.value() guidance_scale = self.guidance_spin.value() contrast_factor = self.contrast_spin.value() saturation_factor = self.saturation_spin.value() brightness_factor = self.brightness_spin.value()# 2. UI状态更新(禁用生成/保存按钮,显示进度条) self.generate_btn.setEnabled(False) self.save_btn.setEnabled(False) self.progress_bar.setVisible(True) self.progress_bar.setRange(0,100) self.progress_bar.setValue(0) self.progress_label.setText("准备生成(第一次加载请耐心等待哦)...") self.statusBar().showMessage("正在生成图像,请稍候...")# 3. 启动生成线程(传入参数,绑定信号) self.gen_thread = GenerateThread( self.pipe, animal_name, num_inference_steps, guidance_scale, contrast_factor, saturation_factor, brightness_factor )# 线程信号绑定:完成→显示图像,错误→弹窗,进度→更新进度条 self.gen_thread.finished.connect(self.on_generation_finished) self.gen_thread.error.connect(self.on_generation_error) self.gen_thread.progress_updated.connect(self.on_progress_updated) self.gen_thread.start()# 启动线程(执行run方法)# 生成完成回调(接收线程信号,显示图像)defon_generation_finished(self, image): self.current_image = image # 保存当前图像(用于后续保存) pixmap = self.pil2pixmap(image)# PIL图像转PyQt5的QPixmap(用于显示)# 显示图像(缩放至图像区尺寸,保持比例) self.image_label.setPixmap(pixmap.scaled( self.image_label.width(), self.image_label.height(), Qt.KeepAspectRatio, Qt.SmoothTransformation ))# 恢复UI状态(启用按钮,更新提示) self.generate_btn.setEnabled(True) self.save_btn.setEnabled(True) self.progress_bar.setValue(100) self.progress_label.setText("生成完成!") self.statusBar().showMessage("图像生成成功!")- 图像显示层级设计(提升用户体验)
右侧图像显示区采用三层叠加设计,兼顾“默认提示”+“生成结果”+“版权水印”:- 底层:默认图(70% 透明度,提示用户 “生成图像的样例图”);
- 中层:生成图(完全不透明,覆盖默认图);
- 顶层:水印(右下角半透明,显示 “制作者:热心市民小周”(嗨嗨嗨,就是我 (◦˙▽˙◦)))
defcreate_image_panel(self): panel = QWidget() layout = QVBoxLayout(panel)# 图像显示容器(带阴影,提升美观度) image_container = QWidget() image_container.setStyleSheet(""" background-color: white; border-radius: 8px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); padding: 15px; """) image_layout = QVBoxLayout(image_container)# 1. 底层:默认图(70%透明度) self.default_image_label = QLabel() self.default_image_label.setAlignment(Qt.AlignCenter) self.default_image_label.setMinimumSize(512,512) self.load_default_image()# 加载默认提示图(如“请生成动物图像”)# 2. 中层:生成图(初始为空,生成后显示) self.image_label = QLabel() self.image_label.setAlignment(Qt.AlignCenter) self.image_label.setMinimumSize(512,512) self.image_label.setStyleSheet("background-color: transparent;")# 透明背景,避免遮挡底层# 3. 顶层:水印(右下角对齐) self.watermark_label = QLabel("制作者:热心市民小周") self.watermark_label.setStyleSheet(""" color: rgba(100, 100, 100, 150); /* 半透明灰色 */ font-size: 12px; padding: 5px; background-color: rgba(255, 255, 255, 100); border-radius: 2px; """) self.watermark_label.setAlignment(Qt.AlignRight | Qt.AlignBottom)# 网格布局实现层级叠加(同一单元格内,后添加的控件在顶层) grid_layout = QGridLayout() grid_layout.addWidget(self.default_image_label,0,0)# 底层 grid_layout.addWidget(self.image_label,0,0)# 中层 grid_layout.addWidget(self.watermark_label,0,0)# 顶层 image_layout.addLayout(grid_layout) layout.addWidget(image_container,1)return panel 最终的界面图如下所示:
六、结果展示
训练指标
- 训练损失: 随epoch下降的MSE损失
- 验证损失: 评估模型泛化能力
- CLIP分数: 衡量生成图像与文本的匹配程度
- 学习率LR:随训练而递减
梯度范数GN:衡量参数更新规模
推理时间
| 推理步数 | 推理平均时间(CPU) | 平均总时间 |
|---|---|---|
| 20 | 68.64s | 456.65s |
| 100 | 349.35s | 823.45s |
| 200 | 683.96s | 1209.65s |
| 400 | 1356.86s | 1863.25s |
生成示例
系统能够生成多种动物的高质量图像,包括:
- 狮子、老虎、大象等大型动物
- 猫、狗等常见宠物
鸟类、鱼类等各种动物类别
更多验证输出样例可见output/pic使用不同推理步数得到的tiger图像如下所示:
UI界面的使用实例如下:
100类动物图像生成
如果你喜欢我的文章,不妨给小周一个免费的点赞和关注吧!