An Animal Image Generation System for 100 Species Based on LoRA and Stable Diffusion
This project presents an animal image generation system built on Stable Diffusion and LoRA. The system covers 100 animal categories, is trained with PyTorch, and uses LoRA for parameter-efficient fine-tuning. It includes a complete training pipeline, data augmentation, early stopping, and a CLIP-score-based evaluation scheme. A PyQt5 GUI supports real-time parameter adjustment and image preview. Memory optimization and color correction are implemented, and both CPU and GPU environments are supported.

Source code: https://github.com/xiaozhou-alt/Animals_Generationn
This is an animal image generation system based on Stable Diffusion and LoRA. It generates high-quality animal images from text descriptions, ships with a complete training pipeline and a user-friendly GUI, and supports custom parameter adjustment and real-time image generation.
Key Features
Some generated animal images:
Animals_Creation/
├── README.md
├── demo.gif # demo animation
├── demo.mp4 # demo video
├── demo.py # main demo script
├── icons/ # icon resources
├── train.py
├── log/ # logs
├── model/
│ └── LCM-runwayml-stable-diffusion-v1-5/ # Stable Diffusion model
│ ├── feature_extractor/ # feature extractor
│ ├── model_index.json # model index file
│ ├── safety_checker/ # safety checker
│ ├── scheduler/ # scheduler
│ ├── text_encoder/ # text encoder
│ ├── tokenizer/ # tokenizer
│ ├── unet/ # UNet model
│ └── vae/ # variational autoencoder
├── output/
│ ├── evaluation_results.xlsx # evaluation results (Excel)
│ ├── lora_models/ # LoRA model weights
│ │ └── clip-31.475.safetensors
│ ├── training_history.xlsx # training history
│ └── pic/
└── requirements.txt
The dataset used in this project contains images of 100 animal categories. Because the images were scraped from the web and cleaned entirely by hand, and the dataset is fairly large, some animal folders contain roughly 1%-1.5% noisy images. The dataset is organized as follows:
During training, data augmentation (rotation, translation, scaling, brightness adjustment, etc.) is used to expand the training samples and improve the model's generalization.
See class.txt for the list of animal categories:
antelope badger bat …
Dataset download: 100-species animal recognition dataset (ScienceDB)
Citation: if you use this dataset, please cite it as follows:
Haojing ZHOU. 100 种动物识别数据集 [DS/OL]. V1. Science Data Bank, 2025 [2025-08-30]. https://cstr.cn/31253.11.sciencedb.29221. CSTR: 31253.11.sciencedb.29221.
or
@misc{动物识别, author = {Haojing ZHOU}, title = {100 种动物识别数据集}, year = {2025}, doi = {10.57760/sciencedb.29221}, url = {https://doi.org/10.57760/sciencedb.29221}, note = {CSTR:31253.11.sciencedb.29221}, publisher = {ScienceDB}}
Stable Diffusion uses the Latent Diffusion Model architecture: by compressing high-dimensional images into a low-dimensional latent space before running the diffusion process, it greatly improves computational efficiency. The model consists of four core components: a variational autoencoder (VAE), a CLIP text encoder, a U-Net, and a noise scheduler (DDPMScheduler).
In Stable Diffusion, the VAE handles the two-way conversion between images and the latent space. Its encoder compresses an input image x into a latent representation z, and its decoder reconstructs the latent representation back into an image x̂. The code uses a pretrained AutoencoderKL model:
vae = AutoencoderKL.from_pretrained(
    config.pretrained_model_name_or_path,
    subfolder="vae"
)
The VAE learns the latent distribution of the data via variational inference. For an input image x, the encoder outputs the mean μ and variance σ² of the latent distribution, and the latent representation is sampled with the reparameterization trick: z = μ + ε · σ, ε ~ N(0, I)
In this project, the encoded latents are then scaled:
latents = vae.encode(pixel_values).latent_dist.sample()
latents = latents * 0.18215  # scaling factor
The scaling factor 0.18215 is a constant fixed when Stable Diffusion was pretrained; it normalizes the VAE's latent distribution to a range better suited to the diffusion process.
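The reparameterization trick above can be sketched in a few lines of numpy. This is an illustration with made-up toy values, not the project's actual latents:

```python
import numpy as np

# Minimal sketch of the VAE reparameterization trick (toy values, not real
# encoder output): sample z = mu + eps * sigma so the sampling step stays
# differentiable with respect to mu and sigma.
rng = np.random.default_rng(0)

mu = np.array([0.5, -1.0, 2.0])       # encoder-predicted mean
log_var = np.array([0.0, -2.0, 1.0])  # encoder-predicted log-variance
sigma = np.exp(0.5 * log_var)

eps = rng.standard_normal(mu.shape)   # eps ~ N(0, I)
z = mu + eps * sigma                  # differentiable sample

# Stable Diffusion then scales the latents by a fixed constant:
latents = z * 0.18215
```

With ε = 0 the sample collapses to the mean, which is why the gradient can flow through μ and σ during training.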
Text guidance is the defining feature of Stable Diffusion, implemented by the CLIP (Contrastive Language-Image Pretraining) text encoder. It converts a text description into a fixed-dimensional vector representation, linking text and image semantics:
text_encoder = CLIPTextModel.from_pretrained(
    config.pretrained_model_name_or_path,
    subfolder="text_encoder"
)
tokenizer = CLIPTokenizer.from_pretrained(
    config.pretrained_model_name_or_path,
    subfolder="tokenizer"
)
The CLIP text encoder is trained with contrastive learning, so its text embeddings t live in the same semantic space as the image embeddings. For an input text w (e.g. "a photo of a cat"), tokenization and encoding yield the text features: t = text_encoder(tokenizer(w))
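The shared embedding space means image-text matching reduces to cosine similarity scaled by a temperature. The toy vectors below stand in for real CLIP embeddings and only illustrate the mechanism:

```python
import numpy as np

# Sketch of CLIP-style matching: embeddings are L2-normalized and compared
# by dot product (cosine similarity), scaled by a learned temperature.
# The vectors here are toy stand-ins, not real CLIP output.
def cosine_similarity(a, b):
    a = a / np.linalg.norm(a)
    b = b / np.linalg.norm(b)
    return float(a @ b)

text_emb = np.array([0.2, 0.9, 0.1])        # "a photo of a cat"
cat_img_emb = np.array([0.25, 0.85, 0.05])  # close to the text
dog_img_emb = np.array([-0.7, 0.1, 0.7])    # far from the text

logit_scale = 100.0  # CLIP-style temperature; the "CLIP score" used later
score_cat = logit_scale * cosine_similarity(text_emb, cat_img_emb)
score_dog = logit_scale * cosine_similarity(text_emb, dog_img_emb)
```

This is the same quantity that `logits_per_image` reports in the evaluation code further below.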
In this project, a set of diverse prompt templates makes the text embeddings more robust:
self.prompt_templates = [
    "a photo of a {}",
    "a high quality image of a {}",
    # more templates...
]
The U-Net is the core diffusion module of Stable Diffusion and is responsible for predicting noise in the latent space. It takes the noisy latent z_t, the timestep t, and the text embedding c as input, and outputs the noise prediction ε_θ(z_t, t, c):
unet = UNet2DConditionModel.from_pretrained(
    config.pretrained_model_name_or_path,
    subfolder="unet"
)
# noise prediction
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
The U-Net uses an encoder-decoder structure with skip connections to preserve detail, and injects timestep embeddings and text-condition embeddings for conditional generation. The loss is the mean squared error between the predicted and true noise: L = E_{z_0, ε, t} [ ||ε - ε_θ(z_t, t, c)||² ]
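The loss itself is nothing more than an elementwise MSE over latent tensors; a numpy sketch with random arrays standing in for the model's prediction makes that concrete:

```python
import numpy as np

# Sketch of the diffusion training loss: MSE between the true noise and the
# model's prediction. The arrays stand in for latent tensors; no real model.
rng = np.random.default_rng(0)

noise = rng.standard_normal((2, 4, 8, 8))                    # "true" noise eps
noise_pred = noise + 0.1 * rng.standard_normal(noise.shape)  # imperfect prediction

mse_loss = np.mean((noise_pred - noise) ** 2)

# A perfect predictor would drive the loss to zero:
perfect_loss = np.mean((noise - noise) ** 2)
```

In the real training loop this corresponds to `F.mse_loss(noise_pred, noise, reduction="mean")`.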
The noise scheduler controls how noise is added and removed during diffusion. During training it adds noise to clean samples according to a specific schedule; at inference it generates images step by step from pure noise:
noise_scheduler = DDPMScheduler.from_pretrained(
    config.pretrained_model_name_or_path,
    subfolder="scheduler"
)
The diffusion process follows a Markov chain; in the forward process the noise level increases step by step: z_t = √α_t z_{t-1} + √(1-α_t) ε, ε ~ N(0, I)
where α_t are noise coefficients predefined by the scheduler. In the project, noise is added through the scheduler:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
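What `add_noise` computes can be written in closed form using the cumulative product of the per-step α values. The linear beta schedule below is illustrative, not necessarily the project's exact schedule:

```python
import numpy as np

# Sketch of the closed-form DDPM forward process:
#   z_t = sqrt(alpha_bar_t) * z_0 + sqrt(1 - alpha_bar_t) * eps
# where alpha_bar_t is the cumulative product of the per-step alphas.
rng = np.random.default_rng(0)

T = 1000
betas = np.linspace(1e-4, 0.02, T)  # illustrative linear schedule
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

def add_noise(z0, eps, t):
    return np.sqrt(alpha_bars[t]) * z0 + np.sqrt(1.0 - alpha_bars[t]) * eps

z0 = rng.standard_normal((4, 8, 8))
eps = rng.standard_normal(z0.shape)

z_early = add_noise(z0, eps, 10)   # still close to the clean latent
z_late = add_noise(z0, eps, 999)   # almost pure noise
```

At large t, ᾱ_t is nearly zero, so the sample is essentially pure Gaussian noise — which is exactly why inference can start from `torch.randn`.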
Fully fine-tuning a large pretrained model requires enormous compute. LoRA (Low-Rank Adaptation) freezes the pretrained weights and trains only low-rank matrices, enabling efficient fine-tuning:
def prepare_unet_for_lora(unet, rank=2, alpha=16):
    lora_config = LoraConfig(
        r=rank,
        lora_alpha=alpha,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        lora_dropout=0.0,
        bias="none",
    )
    unet = get_peft_model(unet, lora_config)
    return unet
The core idea of LoRA is to express the weight update as a low-rank factorization. For a pretrained weight W ∈ R^{d×k}, LoRA learns two low-rank matrices W_B ∈ R^{d×r} and W_A ∈ R^{r×k} with r ≪ min(d, k) to approximate the update: W' = W + W_B W_A
In the project, LoRA is applied to the U-Net's attention modules, specifically the query (to_q), key (to_k), value (to_v) projections and the output projection (to_out.0):
Attention(Q + ΔQ, K + ΔK, V + ΔV)
where ΔQ = W_B^Q W_A^Q, and ΔK, ΔV are analogous. This design lets the model retain its pretrained knowledge while efficiently learning task-specific knowledge.
In the project configuration, a small rank (rank=2) and alpha value (lora_alpha=16) are used:
# LoRA parameters
rank = 2
lora_alpha = 16
This configuration dramatically reduces the number of trainable parameters: print_trainable_parameters() shows that only a fraction of a percent of the parameters are trained (about 0.05% in the sample run below), which greatly lowers memory and compute costs. The LoRA weight file is also small (typically a few MB), making it easy to store and share.
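The parameter savings follow directly from the factorization. The numpy sketch below uses toy dimensions (d = k = 64) and the usual LoRA scaling convention α/r; the matrices are random stand-ins, not real attention weights:

```python
import numpy as np

# Sketch of a LoRA update W' = W + (alpha / r) * B @ A with toy dimensions.
rng = np.random.default_rng(0)

d, k, r, alpha = 64, 64, 2, 16
W = rng.standard_normal((d, k))         # frozen pretrained weight
B = rng.standard_normal((d, r)) * 0.01  # trainable, d x r
A = rng.standard_normal((r, k)) * 0.01  # trainable, r x k

W_prime = W + (alpha / r) * (B @ A)

full_params = d * k          # parameters a full update would train
lora_params = d * r + r * k  # parameters LoRA actually trains
```

Even in this toy case LoRA trains only 256 of 4096 parameters (6.25%); with the huge d and k of real attention layers the ratio shrinks far further.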
The Config class centralizes all key hyperparameters and reflects the optimization strategy under limited resources:
- Lowering the resolution (256x256) and reducing the LoRA rank (rank=2) significantly cut GPU memory usage, making training feasible on an ordinary GPU.
- Gradient clipping (max_grad_norm=0.5) effectively prevents gradient explosion during training.
- max_samples_per_class caps the number of samples per class, mitigating class imbalance in the animal dataset and preventing overfitting to over-represented classes.
# Configuration - key optimizations
class Config:
    # Data parameters - reduce data volume
    data_root = "/kaggle/input/animals/Animal/Animal"
    output_dir = "/kaggle/working/output"
    lora_model_dir = os.path.join(output_dir, "lora_models")
    history_file = os.path.join(output_dir, "training_history.xlsx")
    sample_output_dir = os.path.join(output_dir, "validation_samples")
    evaluation_file = os.path.join(output_dir, "evaluation_results.xlsx")
    comparison_dir = os.path.join(output_dir, "comparison_samples")
    # Model parameters - lower resolution
    pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
    resolution = 256
    center_crop = True
    random_flip = True
    # LoRA parameters - keep LoRA small
    rank = 2
    lora_alpha = 16
    # Training parameters - key optimizations
    train_batch_size = 1
    gradient_accumulation_steps = 4
    num_train_epochs = 10
    learning_rate = 1e-5
    lr_scheduler_type = "cosine_with_warmup"
    lr_warmup_steps = 200
    max_grad_norm = 0.5
    use_ema = True
    gradient_checkpointing = True
    mixed_precision = "fp16"
    # Early stopping - CLIP score as the metric
    early_stopping_patience = 5
    early_stopping_delta = 0.02
    validation_split = 0.1
    # Validation parameters
    num_validation_samples = 5
    num_inference_steps = 20
    num_final_inference_steps = 100
    guidance_scale = 7.5
    # Max samples per class
    max_samples_per_class = 100  # matches the AnimalDataset default
    num_evaluation_samples = 10  # value missing in the original; assumed
    clip_model_name = "openai/clip-vit-base-patch32"  # value missing in the original; assumed
The AnimalDataset class implements loading and preprocessing of the animal image dataset. Its key features:
- It assumes a "root / animal class / image files" directory hierarchy and discovers class names by scanning subfolders.
- Classes with more than max_samples_per_class images are randomly subsampled, keeping class sizes roughly balanced.
- Prompts mention scenes (close-up, natural habitat) and attributes (cute, wild), enriching the model's conditioning signal.
# 1. Data processing and preparation - with a per-class sample cap
class AnimalDataset(Dataset):
    def __init__(self, data_root, tokenizer, size=384, center_crop=True, random_flip=True, max_samples_per_class=100):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.center_crop = center_crop
        self.random_flip = random_flip
        self.max_samples_per_class = max_samples_per_class
        self.image_paths = []
        self.class_names = []
        subfolders = [f.name for f in os.scandir(data_root) if f.is_dir()]
        for class_name in subfolders:
            class_dir = os.path.join(data_root, class_name)
            image_files = glob.glob(os.path.join(class_dir, "*.jpg")) + \
                          glob.glob(os.path.join(class_dir, "*.png")) + \
                          glob.glob(os.path.join(class_dir, "*.jpeg"))
            if len(image_files) > max_samples_per_class:
                image_files = random.sample(image_files, max_samples_per_class)
            for img_path in image_files:
                self.image_paths.append(img_path)
                self.class_names.append(class_name)
        self.prompt_templates = [
            "a photo of a {}",
            "a high quality image of a {}",
            "a clear picture of a {}",
            "a realistic image of a {}",
            # more templates...
        ]
- Images are resized with LANCZOS resampling to preserve quality.
- Text is encoded with the CLIP tokenizer into input ids the model understands.
- Fixed-length encoding (padding and truncation) makes batching straightforward.
    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        class_name = self.class_names[idx]
        image = Image.open(image_path).convert("RGB")
        if self.center_crop:
            image = self._center_crop(image)
        else:
            image = image.resize((self.size, self.size), Image.Resampling.LANCZOS)
        if self.random_flip and random.random() < 0.5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
        # scale pixels to [-1, 1] and move channels first
        image_tensor = (torch.tensor(np.array(image).astype(np.float32) / 127.5) - 1.0).permute(2, 0, 1)
        prompt_template = random.choice(self.prompt_templates)
        prompt = prompt_template.format(class_name)
        tokenized_input = self.tokenizer(
            prompt,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        input_ids = tokenized_input.input_ids.squeeze(0)
        return {
            "pixel_values": image_tensor,
            "input_ids": input_ids,
            "prompt": prompt,
            "class_name": class_name
        }

    def _center_crop(self, image):
        # crop the largest centered square, then resize to the target size
        width, height = image.size
        new_size = min(width, height)
        left = (width - new_size) / 2
        top = (height - new_size) / 2
        right = (width + new_size) / 2
        bottom = (height + new_size) / 2
        image = image.crop((left, top, right, bottom))
        image = image.resize((self.size, self.size), Image.Resampling.LANCZOS)
        return image
Conventional loss-based early stopping may not accurately reflect the quality of a generative model, so this project stops on the CLIP score instead. Key points:
- Training stops when the CLIP score (image-text matching) no longer improves significantly, which better matches the quality goal of a generation task.
- patience: number of epochs the score may fail to improve; set to 5 to give the model room to optimize.
- delta: minimum improvement threshold; 0.02 ensures only meaningful gains count.
# Early stopping (PyTorch implementation) - CLIP score as the metric
class EarlyStopping:
    def __init__(self, patience=3, delta=0.05, verbose=False):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, clip_score):
        if self.best_score is None:
            self.best_score = clip_score
        elif clip_score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = clip_score
            self.counter = 0
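The stopping behavior is easy to sanity-check with a synthetic score sequence. The class body is copied into the snippet so it runs standalone:

```python
# Self-contained demo of CLIP-score early stopping; the class is re-declared
# here (without logging) so the snippet runs on its own.
class EarlyStopping:
    def __init__(self, patience=3, delta=0.05):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, clip_score):
        if self.best_score is None:
            self.best_score = clip_score
        elif clip_score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = clip_score
            self.counter = 0

stopper = EarlyStopping(patience=2, delta=0.02)
stopped_at = None
# scores improve, then plateau: training stops after 2 non-improving epochs
for epoch, score in enumerate([28.0, 29.5, 30.1, 30.11, 30.10]):
    stopper(score)
    if stopper.early_stop:
        stopped_at = epoch
        break
```

Note that a gain smaller than delta (30.10 → 30.11) still counts as "no improvement", which is exactly the intended behavior.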
The function below adapts the UNet, Stable Diffusion's core component, with LoRA and is the key to parameter-efficient fine-tuning:
- r (rank): rank of the low-rank matrices; 2 sharply cuts the number of trainable parameters.
- lora_alpha: scaling factor that, together with the rank, controls the magnitude of the update.
- target_modules: modules that receive LoRA adapters; here the attention query, key, value projections and the output projection.
- get_peft_model injects the LoRA adapters into the UNet, so only the small adapter is trained rather than the whole model. print_trainable_parameters() reports the trainable fraction, roughly 0.05% of the full model here.
# Prepare the UNet for LoRA - using the peft library
def prepare_unet_for_lora(unet, rank=2, alpha=16):
    # configure LoRA
    lora_config = LoraConfig(
        r=rank,
        lora_alpha=alpha,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        lora_dropout=0.0,
        bias="none",
    )
    # apply LoRA to the UNet
    unet = get_peft_model(unet, lora_config)
    unet.print_trainable_parameters()
    return unet
The CLIP (Contrastive Language-Image Pretraining) score quantifies how well a generated image matches its text description and is an important generation-quality metric:
- logits_per_image is the image-text matching score; higher means a better match.
- torch.autocast and torch.no_grad keep the computation efficient.
# Compute the validation CLIP score
def compute_validation_clip_score(config, unet, text_encoder, vae, tokenizer, device):
    animal_classes = [f.name for f in os.scandir(config.data_root) if f.is_dir()]
    selected_animals = random.sample(animal_classes, min(config.num_validation_samples, len(animal_classes)))
    print(f"Selected animals for validation CLIP score: {selected_animals}")
    pipe = StableDiffusionPipeline.from_pretrained(
        config.pretrained_model_name_or_path,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        safety_checker=None,
        torch_dtype=torch.float16 if config.mixed_precision == "fp16" else torch.float32
    ).to(device)
    clip_model = CLIPModel.from_pretrained(config.clip_model_name).to(device)
    clip_processor = CLIPProcessor.from_pretrained(config.clip_model_name)
    clip_scores = []
    for animal in selected_animals:
        prompt = f"a high quality photo of a {animal}"
        with torch.autocast(device.type):
            image = pipe(
                prompt,
                num_inference_steps=config.num_inference_steps,
                guidance_scale=config.guidance_scale,
                height=config.resolution,
                width=config.resolution
            ).images[0]
        with torch.no_grad():
            inputs = clip_processor(
                text=[prompt],
                images=image,
                return_tensors="pt",
                padding=True
            ).to(device)
            outputs = clip_model(**inputs)
            logits_per_image = outputs.logits_per_image
            clip_score = logits_per_image.item()
        print(f"Animal: {animal}, CLIP Score: {clip_score:.4f}")
        clip_scores.append(clip_score)
    avg_clip_score = np.mean(clip_scores)
    print(f"Average Validation CLIP Score: {avg_clip_score:.4f}")
    return avg_clip_score
Model initialization: the first part of the training function loads and configures Stable Diffusion's core components:
- Gradient checkpointing trades a little extra compute time for substantial memory savings.
# 2. Training function (with early stopping and history logging)
def train_lora_with_earlystopping(config):
    # initialize model components
    tokenizer = CLIPTokenizer.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="tokenizer"
    )
    text_encoder = CLIPTextModel.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="text_encoder"
    )
    vae = AutoencoderKL.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="vae"
    )
    unet = UNet2DConditionModel.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="unet"
    )
    # add the LoRA adapters to the UNet
    unet = prepare_unet_for_lora(unet, config.rank, config.lora_alpha)
    # set up the noise scheduler
    noise_scheduler = DDPMScheduler.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="scheduler"
    )
    # enable gradient checkpointing to save memory
    if config.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
    # move the models to the GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    text_encoder.to(device)
    vae.to(device)
    unet.to(device)
Optimizer and data loading:
- AdamW with a modest weight decay (0.01) guards against overfitting.
- AnimalDataset loads the data; DataLoader handles batching and multi-process preprocessing.
# set up the optimizer (only LoRA parameters are optimized)
    lora_params = []
    for name, param in unet.named_parameters():
        if param.requires_grad:
            lora_params.append(param)
    optimizer = torch.optim.AdamW(
        lora_params,
        lr=config.learning_rate,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.01
    )
    # prepare the dataset and data loaders
    full_dataset = AnimalDataset(
        config.data_root,
        tokenizer,
        size=config.resolution,
        center_crop=config.center_crop,
        random_flip=config.random_flip,
        max_samples_per_class=config.max_samples_per_class
    )
    val_size = int(len(full_dataset) * config.validation_split)
    train_size = len(full_dataset) - val_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=2
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config.train_batch_size,
        shuffle=False,
        num_workers=2
    )
Training schedule and logging setup:
    # compute the total number of training steps
    num_update_steps_per_epoch = len(train_dataloader) // config.gradient_accumulation_steps
    max_train_steps = config.num_train_epochs * num_update_steps_per_epoch
    # learning-rate scheduler
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=max_train_steps
    )
    # initialize early stopping (CLIP score as the metric)
    early_stopping = EarlyStopping(
        patience=config.early_stopping_patience,
        delta=config.early_stopping_delta,
        verbose=True
    )
    # create an Excel workbook for the training history
    history_wb = Workbook()
    history_ws = history_wb.active
    history_ws.title = "Training History"
    history_ws.append(["Epoch", "Step", "Train Loss", "Validation Loss", "CLIP Score", "Learning Rate", "Best CLIP Score", "Gradient Norm"])
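The warmup-then-cosine schedule can be sketched without any framework. The multiplier below mirrors the common linear-warmup + cosine-decay shape; it is an illustration, not the exact `transformers` implementation:

```python
import math

# Sketch of a cosine-with-warmup LR multiplier: linear warmup from 0 to 1
# over warmup_steps, then cosine decay from 1 down to 0.
def cosine_with_warmup(step, warmup_steps, total_steps):
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

warmup, total = 200, 2000  # toy numbers in the spirit of lr_warmup_steps=200
lr_at_0 = cosine_with_warmup(0, warmup, total)
lr_at_warmup_end = cosine_with_warmup(warmup, warmup, total)
lr_mid = cosine_with_warmup((warmup + total) // 2, warmup, total)
lr_at_end = cosine_with_warmup(total, warmup, total)
```

Warmup avoids large updates while the fresh LoRA adapters are still near zero; the cosine tail anneals the step size toward the end of training.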
Core training loop:
The loop implements Stable Diffusion's noise-prediction training. Key steps:
- The VAE is frozen with torch.no_grad(), and the latents are scaled by the standard 0.18215 factor used throughout Stable Diffusion.
# training loop
    global_step = 0
    best_clip_score = 0.0
    for epoch in range(config.num_train_epochs):
        unet.train()
        total_loss = 0
        optimizer.zero_grad()
        current_grad_norm = 0.0
        for step, batch in enumerate(train_dataloader):
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            # encode the images into the latent space
            with torch.no_grad():
                latents = vae.encode(pixel_values).latent_dist.sample()
                latents = latents * 0.18215  # scaling factor
            # sample noise
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=device).long()
            # add noise to the latents
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            # get the text embeddings
            with torch.no_grad():
                encoder_hidden_states = text_encoder(input_ids)[0]
            # predict the noise residual
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            # compute the loss
            loss = F.mse_loss(noise_pred, noise, reduction="mean") / config.gradient_accumulation_steps
            # backpropagation
            loss.backward()
            # gradient accumulation
            if (step + 1) % config.gradient_accumulation_steps == 0:
                # gradient clipping
                torch.nn.utils.clip_grad_norm_(lora_params, config.max_grad_norm)
                # compute the gradient norm for monitoring
                current_grad_norm = 0
                for p in lora_params:
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2)
                        current_grad_norm += param_norm.item() ** 2
                current_grad_norm = current_grad_norm ** 0.5
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                global_step += 1
            total_loss += loss.item() * config.gradient_accumulation_steps
            if global_step % 50 == 0:  # logging interval (reconstructed; the original value was lost)
                avg_loss = total_loss / (step + 1)
                current_lr = lr_scheduler.get_last_lr()[0]
                print(f"Epoch {epoch}, Step {global_step}, Loss: {avg_loss:.4f}, LR: {current_lr:.6f}, Grad Norm: {current_grad_norm:.6f}")
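Dividing the loss by `gradient_accumulation_steps` before each `backward()` is what makes accumulation equivalent to a larger batch. A toy linear-regression gradient shows the identity:

```python
import numpy as np

# Sketch of why accumulating gradients of loss/K over K micro-batches equals
# the gradient of the mean loss over the combined batch. Toy MSE gradients.
rng = np.random.default_rng(0)

K = 4                                 # like gradient_accumulation_steps = 4
w = rng.standard_normal(3)
X = rng.standard_normal((8, 3))
y = rng.standard_normal(8)

def grad_mse(w, X, y):
    # gradient of mean((Xw - y)^2) with respect to w
    return 2.0 * X.T @ (X @ w - y) / len(y)

# full-batch gradient
full_grad = grad_mse(w, X, y)

# accumulated gradient over K micro-batches, each scaled by 1/K
accum_grad = np.zeros_like(w)
for Xb, yb in zip(np.split(X, K), np.split(y, K)):
    accum_grad += grad_mse(w, Xb, yb) / K
```

The two gradients match exactly, so batch size 1 with 4 accumulation steps behaves like an effective batch of 4 at a quarter of the peak memory.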
Per-epoch post-processing:
Key operations executed at the end of each training epoch:
        # after each epoch, compute the validation loss and CLIP score
        val_loss = compute_validation_loss(unet, vae, text_encoder, val_dataloader, noise_scheduler, device)
        avg_train_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch} completed. Train Loss: {avg_train_loss:.4f}, Validation Loss: {val_loss:.4f}")
        # compute the CLIP score
        clip_score = compute_validation_clip_score(config, unet, text_encoder, vae, tokenizer, device)
        # log to the history
        current_lr = lr_scheduler.get_last_lr()[0]
        history_ws.append([epoch, global_step, avg_train_loss, val_loss, clip_score, current_lr, best_clip_score, current_grad_norm])
        # early-stopping check (based on the CLIP score)
        early_stopping(clip_score)
        # save the best model
        if clip_score > best_clip_score:
            best_clip_score = clip_score
            save_path = os.path.join(config.lora_model_dir, f"lora_weights_epoch_{epoch}.safetensors")
            save_lora_weights(unet, save_path)
            print(f"Saved best model with CLIP score: {best_clip_score:.4f}")
        # save the training history
        history_wb.save(config.history_file)
        # check early stopping
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break
    print("Training completed!")
    return unet, text_encoder, vae, tokenizer
Sample training output:
Starting LoRA training...
trainable params: 398,592 || all params: 859,919,556 || trainable%: 0.0464
Epoch 0, Step 0, Loss: 0.0044, LR: 0.000000, Grad Norm: 0.000000
...
Average Validation CLIP Score: 30.1487
Saved best model with CLIP score: 30.1487
...
Evaluation results saved to /kaggle/working/output/evaluation_results.xlsx
All done! Average CLIP Score: 31.1877
Validation sample generation:
After training, this function generates representative samples for visual inspection:
- It uses more inference steps (100) to produce higher-quality samples.
# 3. Validation and sample generation
def generate_validation_samples(config, unet, text_encoder, vae, tokenizer):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    animal_classes = [f.name for f in os.scandir(config.data_root) if f.is_dir()]
    selected_animals = random.sample(animal_classes, config.num_validation_samples)
    print(f"Selected animals for validation: {selected_animals}")
    pipe = StableDiffusionPipeline.from_pretrained(
        config.pretrained_model_name_or_path,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        safety_checker=None,  # disable the safety checker to speed up generation
        torch_dtype=torch.float16 if config.mixed_precision == "fp16" else torch.float32
    ).to(device)
    all_images = []
    all_titles = []
    for animal in selected_animals:
        prompt = f"a high quality photo of a {animal}"
        with torch.autocast(device.type):
            image = pipe(
                prompt,
                num_inference_steps=config.num_final_inference_steps,
                guidance_scale=config.guidance_scale,
                height=config.resolution,
                width=config.resolution
            ).images[0]
        save_path = os.path.join(config.sample_output_dir, f"{animal}.png")
        image.save(save_path)
        print(f"Generated image for {animal} saved at {save_path}")
        all_images.append(image)
        all_titles.append(animal)
    comparison_path = os.path.join(config.comparison_dir, "animal_comparison.png")
    create_comparison_image(all_images, all_titles, comparison_path)
    return selected_animals
Quantitative evaluation:
This function provides a comprehensive quantitative evaluation after training:
# 4. Evaluation - CLIP Score for generation quality
def evaluate_with_clip_score(config, unet, text_encoder, vae, tokenizer):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    clip_model = CLIPModel.from_pretrained(config.clip_model_name).to(device)
    clip_processor = CLIPProcessor.from_pretrained(config.clip_model_name)
    animal_classes = [f.name for f in os.scandir(config.data_root) if f.is_dir()]
    selected_animals = random.sample(animal_classes, config.num_evaluation_samples)
    print(f"Selected animals for evaluation: {selected_animals}")
    pipe = StableDiffusionPipeline.from_pretrained(
        config.pretrained_model_name_or_path,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        safety_checker=None,
        torch_dtype=torch.float16 if config.mixed_precision == "fp16" else torch.float32
    ).to(device)
    evaluation_results = []
    for animal in selected_animals:
        prompt = f"a high quality photo of a {animal}"
        with torch.autocast(device.type):
            image = pipe(
                prompt,
                num_inference_steps=config.num_final_inference_steps,
                guidance_scale=config.guidance_scale,
                height=config.resolution,
                width=config.resolution
            ).images[0]
        save_path = os.path.join(config.sample_output_dir, f"eval_{animal}.png")
        image.save(save_path)
        with torch.no_grad():
            inputs = clip_processor(
                text=[prompt],
                images=image,
                return_tensors="pt",
                padding=True
            ).to(device)
            outputs = clip_model(**inputs)
            logits_per_image = outputs.logits_per_image
            clip_score = logits_per_image.item()
        print(f"Animal: {animal}, CLIP Score: {clip_score:.4f}")
        # dict keys and sheet headers below are reconstructed from context
        evaluation_results.append({
            "animal": animal,
            "prompt": prompt,
            "clip_score": clip_score,
            "image_path": save_path
        })
    avg_clip_score = np.mean([result["clip_score"] for result in evaluation_results])
    print(f"Average CLIP Score: {avg_clip_score:.4f}")
    # write the results to Excel
    evaluation_wb = Workbook()
    evaluation_ws = evaluation_wb.active
    evaluation_ws.title = "Evaluation Results"
    evaluation_ws.append(["Animal", "Prompt", "CLIP Score", "Image Path"])
    for result in evaluation_results:
        evaluation_ws.append([result["animal"], result["prompt"], result["clip_score"], result["image_path"]])
    evaluation_ws.append([])
    evaluation_ws.append(["Average CLIP Score", avg_clip_score])
    evaluation_wb.save(config.evaluation_file)
    print(f"Evaluation results saved to {config.evaluation_file}")
    return evaluation_results, avg_clip_score
pretrained_model_name_or_path points to the local Stable Diffusion base model directory, which must contain the text_encoder, unet, vae and other core components; contrast/saturation/brightness_factor are post-processing color-correction parameters, applied in real time via PIL to counter the grayish, washed-out look that diffusion outputs can have.
# Configuration class - with color-related parameters
class Config:
    pretrained_model_name_or_path = "model/LCM-runwayml-stable-diffusion-v1-5"
    resolution = 512
    rank = 2
    lora_alpha = 16
    device = "cpu"
    num_final_inference_steps = 100
    guidance_scale = 5.0
    contrast_factor = 1.0
    saturation_factor = 1.0
    brightness_factor = 1.0
Three core utility functions handle LoRA weight loading, tokenizer loading fallbacks, and image color correction; they are the foundation for running the model and polishing its output.
LoRA (Low-Rank Adaptation) is a lightweight fine-tuning technique: loading pretrained LoRA weights quickly adapts the base model to the animal-image-generation task without retraining the whole model.
# load the LoRA weights
def load_lora_weights(unet, load_path):
    # note: for .safetensors checkpoints, safetensors.torch.load_file is needed instead of torch.load
    lora_state_dict = torch.load(load_path, map_location=torch.device(Config.device))
    unet.load_state_dict(lora_state_dict, strict=False)
    return unet
The tokenizer converts a prompt into vectors the model can understand; this function works around the common problem of a local model's tokenizer path failing to load by providing a fallback:
# work around tokenizer loading problems
def load_tokenizer_with_fix(model_path):
    try:
        tokenizer = CLIPTokenizer.from_pretrained(
            os.path.join(model_path, "tokenizer")
        )
        return tokenizer
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        print("Trying to repair the tokenizer configuration...")
        from transformers import CLIPTokenizerFast
        vocab_file = os.path.join(model_path, "tokenizer", "vocab.json")
        merges_file = os.path.join(model_path, "tokenizer", "merges.txt")
        if os.path.exists(vocab_file) and os.path.exists(merges_file):
            tokenizer = CLIPTokenizerFast(
                vocab_file=vocab_file,
                merges_file=merges_file,
                max_length=77,
                pad_token="!",
                additional_special_tokens=["<|startoftext|>", "<|endoftext|>"]
            )
            return tokenizer
        else:
            raise Exception(f"Tokenizer files not found: {vocab_file} or {merges_file}")
Diffusion outputs often suffer from low contrast and dull colors; this function post-processes generated images with PIL's ImageEnhance module to improve their look:
# image color correction
def adjust_image_colors(image):
    enhancer = ImageEnhance.Contrast(image)
    image = enhancer.enhance(Config.contrast_factor)
    enhancer = ImageEnhance.Color(image)
    image = enhancer.enhance(Config.saturation_factor)
    enhancer = ImageEnhance.Brightness(image)
    image = enhancer.enhance(Config.brightness_factor)
    return image
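Mathematically, each enhancer interpolates between the image and a degenerate version of it (black for brightness, flat gray for contrast); a factor of 1.0 is a no-op and larger factors push away from the degenerate image. The numpy sketch below approximates the contrast and brightness cases; it mimics ImageEnhance semantics but is not Pillow's exact code:

```python
import numpy as np

# Approximate ImageEnhance on a float image in [0, 255]:
# out = degenerate + factor * (image - degenerate)
def enhance_brightness(img, factor):
    return np.clip(img * factor, 0, 255)              # degenerate = black

def enhance_contrast(img, factor):
    mean = img.mean()                                 # degenerate = flat gray
    return np.clip(mean + factor * (img - mean), 0, 255)

img = np.array([[50.0, 100.0], [150.0, 200.0]])

same = enhance_contrast(img, 1.0)        # factor 1.0: unchanged
punchy = enhance_contrast(img, 1.5)      # values pushed away from the mean
brighter = enhance_brightness(img, 1.2)  # all values scaled up
```

Saturation works the same way, with the grayscale version of the image as the degenerate endpoint.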
Stable Diffusion consists of five components — tokenizer, text encoder, UNet, VAE and scheduler. The ModelLoader class loads them from local files, assembles them into a ready-to-use StableDiffusionPipeline, and integrates the LoRA weights.
The text encoder maps text to vectors, the UNet maps vectors to latents, the VAE maps latents to images, and the scheduler paces the denoising steps. Only the UNet receives LoRA weights: it is the core generative layer of a diffusion model, so fine-tuning it alone is enough to shift the generation style (toward animal images) without touching the other components.
# model loading class
class ModelLoader:
    def __init__(self, config, lora_model_path):
        self.config = config
        self.lora_model_path = lora_model_path
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipe = None

    def load_models(self):
        self.tokenizer = load_tokenizer_with_fix(self.config.pretrained_model_name_or_path)
        text_encoder_path = os.path.join(self.config.pretrained_model_name_or_path, "text_encoder")
        self.text_encoder = CLIPTextModel.from_pretrained(text_encoder_path)
        vae_path = os.path.join(self.config.pretrained_model_name_or_path, "vae")
        self.vae = AutoencoderKL.from_pretrained(vae_path)
        unet_path = os.path.join(self.config.pretrained_model_name_or_path, "unet")
        self.unet = UNet2DConditionModel.from_pretrained(unet_path)
        self.unet = load_lora_weights(self.unet, self.lora_model_path)
        self.text_encoder.to(self.config.device)
        self.vae.to(self.config.device)
        self.unet.to(self.config.device)
        scheduler_path = os.path.join(self.config.pretrained_model_name_or_path, "scheduler")
        scheduler = DDPMScheduler.from_pretrained(scheduler_path)
        self.pipe = StableDiffusionPipeline(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False
        )
        return self.pipe
Image generation is slow (especially on CPU); running it on the main thread would freeze the UI. GenerateThread subclasses QThread, moves generation into a worker thread, and reports progress back to the main (UI) thread through signals.
Three pyqtSignal signals give the worker thread non-blocking communication with the UI: progress updates stream in real time, and the finished/error signals trigger the follow-up UI actions.
# generation thread - with a color-correction step
class GenerateThread(QThread):
    finished = pyqtSignal(Image.Image)
    error = pyqtSignal(str)
    progress_updated = pyqtSignal(int, float)

    def __init__(self, pipe, animal_name, num_inference_steps, guidance_scale, contrast_factor, saturation_factor, brightness_factor):
        super().__init__()
        self.pipe = pipe
        self.animal_name = animal_name
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.contrast_factor = contrast_factor
        self.saturation_factor = saturation_factor
        self.brightness_factor = brightness_factor
        self.start_time = 0
        self.step_times = []

    def run(self):
        try:
            prompt = (
                f"a high quality photo of a {self.animal_name}, natural lighting, "
                f"realistic colors, in natural habitat, detailed texture"
            )
            with torch.no_grad():
                # encode the prompt and an empty (unconditional) prompt
                text_inputs = self.pipe.tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=self.pipe.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt"
                )
                text_input_ids = text_inputs.input_ids
                text_embeddings = self.pipe.text_encoder(text_input_ids.to(self.pipe.device))[0]
                max_length = text_input_ids.shape[-1]
                uncond_input = self.pipe.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
                uncond_embeddings = self.pipe.text_encoder(uncond_input.input_ids.to(self.pipe.device))[0]
                text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
                # start from random latents (1/8 of the pixel resolution)
                latents = torch.randn(
                    (1, self.pipe.unet.config.in_channels, Config.resolution // 8, Config.resolution // 8),
                    generator=torch.Generator(device=Config.device),
                    device=Config.device,
                )
                self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=Config.device)
                self.start_time = time.time()
                self.step_times = []
                for i, t in enumerate(self.pipe.scheduler.timesteps):
                    step_start_time = time.time()
                    # classifier-free guidance: conditional and unconditional in one batch
                    latent_model_input = torch.cat([latents] * 2)
                    latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)
                    noise_pred = self.pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                    latents = self.pipe.scheduler.step(noise_pred, t, latents).prev_sample
                    # progress / remaining-time estimate from recent step times
                    step_time = time.time() - step_start_time
                    self.step_times.append(step_time)
                    progress = int((i + 1) / self.num_inference_steps * 100)
                    steps_remaining = self.num_inference_steps - (i + 1)
                    if len(self.step_times) >= 5:
                        avg_step_time = sum(self.step_times[-5:]) / 5
                    else:
                        avg_step_time = sum(self.step_times) / len(self.step_times)
                    remaining_time = avg_step_time * steps_remaining
                    self.progress_updated.emit(progress, remaining_time)
                # decode the latents back to an image
                latents = 1 / 0.18215 * latents
                image = self.pipe.vae.decode(latents).sample
                image = (image / 2 + 0.5).clamp(0, 1)
                image = image.cpu().permute(0, 2, 3, 1).float().numpy()
                image = (image[0] * 255).round().astype("uint8")
                image = Image.fromarray(image)
                image = adjust_image_colors(image)
                if image.size != (Config.resolution, Config.resolution):
                    image = image.resize((Config.resolution, Config.resolution), Image.LANCZOS)
            self.finished.emit(image)
        except Exception as e:
            self.error.emit(str(e))
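The classifier-free guidance line in the loop above combines the two noise predictions by extrapolation. A numpy sketch with toy vectors in place of UNet outputs shows the limiting cases:

```python
import numpy as np

# Sketch of classifier-free guidance:
#   noise_pred = uncond + g * (text - uncond)
# Toy vectors stand in for the UNet's conditional/unconditional predictions.
def cfg_combine(uncond, text, guidance_scale):
    return uncond + guidance_scale * (text - uncond)

uncond = np.array([0.1, -0.2, 0.3])
text = np.array([0.4, 0.1, -0.1])

no_guidance = cfg_combine(uncond, text, 0.0)  # g=0: pure unconditional
plain = cfg_combine(uncond, text, 1.0)        # g=1: pure conditional
strong = cfg_combine(uncond, text, 7.5)       # g>1: extrapolate toward the text
```

Guidance scales above 1 (the project uses 5.0-7.5) over-weight the text condition, trading sample diversity for prompt adherence.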
AnimalGeneratorApp subclasses QMainWindow and is the interaction hub of the tool: it builds the UI layout, binds button events, and handles the thread signals (progress, image, error). The window splits into a control panel on the left and an image display area on the right; the core logic is shown below:
class AnimalGeneratorApp(QMainWindow):
    def __init__(self):
        super().__init__()
        self.pipe = None
        self.current_image = None
        self.initUI()

    def initUI(self):
        font = QFont("SimHei")
        font.setPointSize(10)
        self.setFont(font)
        self.setWindowTitle('Animal Image Generator')
        self.setGeometry(100, 100, 1100, 800)
        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        main_layout = QHBoxLayout(central_widget)
        main_layout.setContentsMargins(15, 15, 15, 15)
        main_layout.setSpacing(20)
        control_panel = self.create_control_panel()
        main_layout.addWidget(control_panel, 3)
        image_panel = self.create_image_panel()
        main_layout.addWidget(image_panel, 5)

    def generate_image(self):
        if not self.pipe:
            QMessageBox.warning(self, "Error", "Please load a model first")
            return
        animal_name = self.animal_edit.text().strip()
        if not animal_name:
            QMessageBox.warning(self, "Error", "Please enter an animal name")
            return
        Config.resolution = self.resolution_spin.value()
        num_inference_steps = self.steps_spin.value()
        guidance_scale = self.guidance_spin.value()
        contrast_factor = self.contrast_spin.value()
        saturation_factor = self.saturation_spin.value()
        brightness_factor = self.brightness_spin.value()
        self.generate_btn.setEnabled(False)
        self.save_btn.setEnabled(False)
        self.progress_bar.setVisible(True)
        self.progress_bar.setRange(0, 100)
        self.progress_bar.setValue(0)
        self.progress_label.setText("Preparing to generate (the first load may take a while)...")
        self.statusBar().showMessage("Generating image, please wait...")
        self.gen_thread = GenerateThread(
            self.pipe, animal_name, num_inference_steps, guidance_scale,
            contrast_factor, saturation_factor, brightness_factor
        )
        self.gen_thread.finished.connect(self.on_generation_finished)
        self.gen_thread.error.connect(self.on_generation_error)
        self.gen_thread.progress_updated.connect(self.on_progress_updated)
        self.gen_thread.start()

    def on_generation_finished(self, image):
        self.current_image = image
        pixmap = self.pil2pixmap(image)
        self.image_label.setPixmap(pixmap.scaled(
            self.image_label.width(), self.image_label.height(), Qt.KeepAspectRatio, Qt.SmoothTransformation
        ))
        self.generate_btn.setEnabled(True)
        self.save_btn.setEnabled(True)
        self.progress_bar.setValue(100)
        self.progress_label.setText("Done")  # status strings reconstructed; originals were lost
        self.statusBar().showMessage("Image generated")
The image display area on the right uses three stacked layers: a default placeholder, the generated result, and a watermark:
    def create_image_panel(self):
        panel = QWidget()
        layout = QVBoxLayout(panel)
        image_container = QWidget()
        # note: box-shadow is not a supported Qt style-sheet property and is ignored
        image_container.setStyleSheet("""
            background-color: white;
            border-radius: 8px;
            box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
            padding: 15px;
        """)
        image_layout = QVBoxLayout(image_container)
        self.default_image_label = QLabel()
        self.default_image_label.setAlignment(Qt.AlignCenter)
        self.default_image_label.setMinimumSize(512, 512)
        self.load_default_image()
        self.image_label = QLabel()
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setMinimumSize(512, 512)
        self.image_label.setStyleSheet("background-color: transparent;")
        self.watermark_label = QLabel("制作者:热心市民小周")
        self.watermark_label.setStyleSheet("""
            color: rgba(100, 100, 100, 150);
            font-size: 12px;
            padding: 5px;
            background-color: rgba(255, 255, 255, 100);
            border-radius: 2px;
        """)
        self.watermark_label.setAlignment(Qt.AlignRight | Qt.AlignBottom)
        # stack all three labels in the same grid cell
        grid_layout = QGridLayout()
        grid_layout.addWidget(self.default_image_label, 0, 0)
        grid_layout.addWidget(self.image_label, 0, 0)
        grid_layout.addWidget(self.watermark_label, 0, 0)
        image_layout.addLayout(grid_layout)
        layout.addWidget(image_container, 1)
        return panel
The final interface looks like this:
Gradient norm (GN): measures the scale of parameter updates
| Inference steps | Avg. inference time (CPU) | Avg. total time |
|---|---|---|
| 20 | 68.64s | 456.65s |
| 100 | 349.35s | 823.45s |
| 200 | 683.96s | 1209.65s |
| 400 | 1356.86s | 1863.25s |
The system can generate high-quality images of many animals, including:
More validation output samples can be found in
output/pic
Tiger images generated with different numbers of inference steps:
An example of using the UI:
100-class animal image generation
