diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
index e019ae0..db4d3cc 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
@@ -29,6 +29,41 @@ class StableDiffusion_1(LatentDiffusionModel):
         unet: The U-Net model.
         clip_text_encoder: The text encoder.
         lda: The image autoencoder.
+
+    Example:
+        ```py
+        import torch
+
+        from refiners.fluxion.utils import manual_seed, no_grad
+        from refiners.foundationals.latent_diffusion.stable_diffusion_1 import StableDiffusion_1
+
+        # Load SD
+        sd15 = StableDiffusion_1(device="cuda", dtype=torch.float16)
+
+        sd15.clip_text_encoder.load_from_safetensors("sd1_5.text_encoder.safetensors")
+        sd15.unet.load_from_safetensors("sd1_5.unet.safetensors")
+        sd15.lda.load_from_safetensors("sd1_5.autoencoder.safetensors")
+
+        # Hyperparameters
+        prompt = "a cute cat, best quality, high quality"
+        negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+        seed = 42
+
+        sd15.set_inference_steps(50)
+
+        with no_grad():  # Disable gradient calculation for memory-efficient inference
+            clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+            manual_seed(seed)
+
+            x = sd15.init_latents((512, 512)).to(sd15.device, sd15.dtype)
+
+            # Diffusion process
+            for step in sd15.steps:
+                x = sd15(x, step=step, clip_text_embedding=clip_text_embedding)
+
+            predicted_image = sd15.lda.decode_latents(x)
+            predicted_image.save("output.png")
+        ```
     """
 
     unet: SD1UNet