配置: jit_compile=False (默认关闭 XLA 编译)
提示词: "There is a pink BMW Mini at the exhibition where the lights focus"
Batch Size: 3
import time
import keras_cv
import keras
# Baseline benchmark: Stable Diffusion at 512x512 with XLA compilation disabled.
model = keras_cv.models.StableDiffusion(
    img_width=512,
    img_height=512,
    jit_compile=False,
)
# Untimed warm-up call, so the timed run below is not the model's first invocation.
model.text_to_image("warming up the model", batch_size=3)
# perf_counter() is a monotonic, high-resolution clock; time.time() is a wall
# clock that can jump if the system time is adjusted mid-benchmark.
start = time.perf_counter()
images = model.text_to_image(
    "There is a pink BMW Mini at the exhibition where the lights focus",
    batch_size=3,
)
elapsed = time.perf_counter() - start
print(f"Standard model: {elapsed:.2f} seconds")
plot_images(images)
# Reset Keras global state before the next configuration is built.
keras.backend.clear_session()
配置: mixed_float16 策略,jit_compile=False
提示词: "There is a black BMW Mini at the exhibition where the lights focus"
# Mixed-precision benchmark: the global dtype policy is switched to
# mixed_float16 before the model is constructed; XLA compilation stays disabled.
keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion(jit_compile=False)
# Report the dtypes the diffusion model actually ended up with.
print("Compute dtype:", model.diffusion_model.compute_dtype)
print("Variable dtype:", model.diffusion_model.variable_dtype)
# Untimed warm-up call, so the timed run below is not the model's first invocation.
model.text_to_image("warming up the model", batch_size=3)
# perf_counter() is a monotonic, high-resolution clock; time.time() is a wall
# clock that can jump if the system time is adjusted mid-benchmark.
start = time.perf_counter()
images = model.text_to_image(
    "There is a black BMW Mini at the exhibition where the lights focus",
    batch_size=3,
)
elapsed = time.perf_counter() - start
print(f"Mixed precision model: {elapsed:.2f} seconds")
plot_images(images)
# Reset Keras global state before the next configuration is built.
keras.backend.clear_session()
配置: float32 策略,jit_compile=True
提示词: "There is a black ford mustang at the exhibition where the lights focus"
# XLA benchmark: the global dtype policy is set back to float32 before the
# model is constructed, and XLA compilation is enabled.
keras.mixed_precision.set_global_policy("float32")
model = keras_cv.models.StableDiffusion(jit_compile=True)
# Untimed warm-up call, so the timed run below is not the model's first invocation.
model.text_to_image("warming up the model", batch_size=3)
# perf_counter() is a monotonic, high-resolution clock; time.time() is a wall
# clock that can jump if the system time is adjusted mid-benchmark.
start = time.perf_counter()
images = model.text_to_image(
    "There is a black ford mustang at the exhibition where the lights focus",
    batch_size=3,
)
elapsed = time.perf_counter() - start
print(f"With XLA: {elapsed:.2f} seconds")
plot_images(images)
# Reset Keras global state before the next configuration is built.
keras.backend.clear_session()
配置: mixed_float16 策略,jit_compile=True
提示词: "There is a purple ford mustang at the exhibition where the lights focus"
# Combined benchmark: mixed_float16 global dtype policy (set before the model
# is constructed) together with XLA compilation.
keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion(jit_compile=True)
# Untimed warm-up call, so the timed run below is not the model's first invocation.
model.text_to_image("warming up the model", batch_size=3)
# perf_counter() is a monotonic, high-resolution clock; time.time() is a wall
# clock that can jump if the system time is adjusted mid-benchmark.
start = time.perf_counter()
images = model.text_to_image(
    "There is a purple ford mustang at the exhibition where the lights focus",
    batch_size=3,
)
elapsed = time.perf_counter() - start
print(f"XLA + mixed precision: {elapsed:.2f} seconds")
plot_images(images)
# Reset Keras global state now that this configuration is finished.
keras.backend.clear_session()