基于 Stable Diffusion 的数据增强实践方案
'数据不够?那就让 AI 自己画!'——某位凌晨三点还在调 prompt 的算法工程师
当训练数据成了'稀有动物'
做 CV 的兄弟姐妹都懂,数据永远比 idea 贵。
老板一句'我要模型精度再涨 5 个点',背后往往是标注团队通宵达旦地画框、打点、写标签。更惨的是,有些场景连原始图片都凑不齐:
- 医疗影像里,某种罕见病灶一年才出现几十例;
- 工业产线上,缺陷样本比 996 的程序员还稀缺;
- 新零售商品库,长尾 SKU 的货架图只能靠采购小哥手机随手拍——光照、角度、背景全靠缘分。
传统 augmentation 三板斧(旋转、裁剪、颜色抖动)在这些场景下就像用指甲刀砍大树,语义信息没变,但也没增加多少新东西。
直到某天,我盯着 Stable Diffusion 生成的'赛博朋克猫'出神,脑子里突然蹦出一个念头:
既然它能画猫,能不能画'缺陷'?
于是,这篇'血泪踩坑史'就有了开头。
为什么偏偏是 Stable Diffusion?
先别急着抄家伙,生成式模型那么多,凭啥选它?
| 模型 | 可控性 | 开源程度 | 消费级显卡友好度 | 备注 |
|---|
| StyleGAN3 | 中 | 高 | 凑合 | 画风偏'艺术',语义控制需额外网络 |
| DALL·E 2 | 高 | 闭源 | ❌ | API 限速 + 钱包警告 |
| Midjourney | 高 | 闭源 | ❌ | 付费 + 不能本地批量 |
| Stable Diffusion | 高 | 完全开源 | RTX 3060 就能 512² 跑 batch | 社区轮子多到用不完 |
一句话:免费、本地、可批量、可微调、社区还卷。
对我们这些**'公司只给预算 0 元'**的开发者来说,它就是天降正义。
把'魔法'拆开:Stable Diffusion 到底干了啥?
'别急着念咒,先搞清楚魔杖是什么木头。'
1. 潜在空间里的'降噪游戏'
Stable Diffusion 把图像压缩到 64×64 的潜在向量(latent space),然后在这块'小画布'上做扩散——前向加噪、反向去噪。
好处?
- 比直接操作像素省显存,512² 图在 8G 显存里能跑 batch=8;
- latent 空间天生带'语义坐标',文本 embedding 像遥控器,往哪儿走它都听得懂。
2. 提示词 = 遥控器的'按钮组合'
正向 prompt:a photo of cracked phone screen, close-up, industrial inspection, 4K, sharp
负向 prompt:cartoon, painting, lowres, blurry, extra fingers
负向 prompt 是隐藏宝藏:把'不想要的'写进去,比单纯堆正向词更能减少废图。
3. ControlNet:给'画家'一把尺子
纯文本容易'抽卡',ControlNet 把 Canny 边缘、深度图、语义分割 mask 变成'草图',让生成结果结构不变、纹理随便换。
做数据增强时,原图边缘图 + 随机 prompt = 同一结构不同外观,完美。
4. LoRA:不煮大锅饭,只开小灶
全量微调 4 GB 模型?老板不给显卡。
LoRA 把权重更新拆成两个小矩阵,训练量降到 1/10,10 张图 10 分钟就能学会'某种裂纹风格',迁移学习神器。
搭一条'可控'的增强流水线
'没有流程的生成,都是玄学。'
下面给出一条工业界能落地的 Python 流水线,每一步都能打断点 debug,拒绝黑箱。
0. 环境一键复现
conda create -n sdaug python=3.10
conda activate sdaug
pip install diffusers==0.21.0 transformers accelerate xformers opencv-python safetensors
1. 原图→边缘图:保留结构
import cv2
import os
def extract_canny(img_path, low=100, high=200, output_size=512):
img = cv2.imread(img_path)
img = cv2.resize(img, (output_size, output_size))
canny = cv2.Canny(img, low, high)
canny = cv2.cvtColor(canny, cv2.COLOR_GRAY2RGB)
return canny
os.makedirs("canny_dir", exist_ok=True)
for f in os.listdir("raw_images"):
canny = extract_canny(f"raw_images/{f}")
cv2.imwrite(f"canny_dir/{f}", canny)
小贴士:Canny 阈值别手抖,低阈值太高会把细节弄丢,经验值 50/150 起步,每张图都用同样阈值,保证后续对齐。
2. prompt 模板:把'随机'装进笼子里
templates = {
"crack": [
"a photo of {defect} on {object}, industrial scene, {lighting}, 4K, sharp, no text",
"close-up shot of {defect} defect, metal surface, {lighting}, realistic, high contrast"
],
"lighting": ["under factory LED light", "natural daylight", "dim warehouse light", "fluorescent tube light"],
"object": ["aluminum panel", "steel plate", "phone screen", "car bumper"]
}
def sample_prompt(defect="crack"):
import random
t = random.choice(templates[defect])
lighting = random.choice(templates["lighting"])
obj = random.choice(templates["object"])
return t.format(defect=defect, lighting=lighting, object=obj)
模板化 = 可复现 + 可单元测试,别小看这一步,后期排查语义漂移全靠它。
3. 图像→图像:把边缘图喂给 Stable Diffusion
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.utils import load_image
import os
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()
pipe.enable_attention_slicing()
pipe.enable_model_cpu_offload()
os.makedirs("aug_images", exist_ok=True)
for idx, canny_file in enumerate(os.listdir("canny_dir")):
canny_image = load_image(f"canny_dir/{canny_file}")
prompt = sample_prompt(defect="crack")
negative = "cartoon, painting, lowres, blurry, extra fingers, text, watermark"
out = pipe(
prompt=prompt,
negative_prompt=negative,
image=canny_image,
num_inference_steps=30,
guidance_scale=7.5,
generator=torch.Generator().manual_seed(42+idx),
strength=0.9
).images[0]
out.save(f"aug_images/{idx:04d}.jpg")
strength 参数是灵魂:0.7 以下:基本只是'重新打光';0.9 左右:结构保留但纹理大换血;1.0:放飞自我,可能把裂纹画成涂鸦。
4. 自动过滤:别让'垃圾'进数据集
from transformers import CLIPProcessor, CLIPModel
import torch, os, json
from PIL import Image
clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to("cuda")
proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def clip_score(image, text):
inputs = proc(text=[text], images=image, return_tensors="pt").to("cuda")
with torch.no_grad():
outputs = clip(**inputs)
logits = outputs.logits_per_image
return logits.item()
threshold = 28
manifest = []
for imgf in os.listdir("aug_images"):
img = Image.open(f"aug_images/{imgf}")
score = clip_score(img, "a photo of cracked phone screen")
if score >= threshold:
manifest.append({"file": imgf, "clip": score})
json.dump(manifest, open("valid_images.json", "w"), ensure_ascii=False, indent=2)
print(f"过滤后剩余 {len(manifest)} 张,淘汰率 {1-len(manifest)/len(os.listdir('aug_images')):.2%}")
CLIP 分数不是圣旨,只能当'初筛'。后续还要让业务分类器回测,看加入这些图后 validation 涨不涨,再决定要不要下调阈值。
三个真实到'掉头发'的落地案例
1. 医疗影像:给罕见病灶'加戏'
背景:某三甲放射科,早期肺结节 CT 只有 87 张阳性,阴性 3000+,模型快把阳性当成'外星人'。
方案:
- 用
nnUNet 把原图结节 mask 抠出来→生成深度图→喂 ControlNet;
- prompt 模板里加入'-1mm 薄层、低剂量、40 岁患者'等医学关键词,让生成图自带 CT 质感;
- 生成 800 张后,请主任肉眼筛掉 120 张'不像病灶'的废图(医生眼光毒辣,一眼看穿伪影)。
结果:
- 召回率从 0.61 → 0.78,假阳性降 35%;
- 论文投 MICCAI,审稿人唯一质疑:'伦理批准呢?'——生成数据也要补伦理批件,别踩坑。
2. 工业质检:把'缺陷'搬到不同产线
背景:手机盖板玻璃,裂纹样本 214 张,客户要求识别 5 条产线、3 种光照、4 种角度共 60 种工况。
方案:
- 用 Blender 批量渲染虚拟边缘图(裂纹形状固定,角度/光照随意调),一天造 2W 张边缘图;
- 边缘图 + ControlNet 生成真实纹理,**prompt 里随机