From 5aef1408d8e8647f26dde33c7310f48059854bde Mon Sep 17 00:00:00 2001 From: Kadir Nar Date: Fri, 6 Sep 2024 10:16:09 +0300 Subject: [PATCH] =?UTF-8?q?=E2=AD=90=20Add=20example=20code=20for=20Stable?= =?UTF-8?q?=20Diffusion(1.5)=20(#409)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../stable_diffusion_1/model.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py index e019ae0..db4d3cc 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py @@ -29,6 +29,41 @@ class StableDiffusion_1(LatentDiffusionModel): unet: The U-Net model. clip_text_encoder: The text encoder. lda: The image autoencoder. + + Example: + ```py + import torch + + from refiners.fluxion.utils import manual_seed, no_grad + from refiners.foundationals.latent_diffusion.stable_diffusion_1 import StableDiffusion_1 + + # Load SD + sd15 = StableDiffusion_1(device="cuda", dtype=torch.float16) + + sd15.clip_text_encoder.load_from_safetensors("sd1_5.text_encoder.safetensors") + sd15.unet.load_from_safetensors("sd1_5.unet.safetensors") + sd15.lda.load_from_safetensors("sd1_5.autoencoder.safetensors") + + # Hyperparameters + prompt = "a cute cat, best quality, high quality" + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + seed = 42 + + sd15.set_inference_steps(50) + + with no_grad(): # Disable gradient calculation for memory-efficient inference + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) + manual_seed(seed) + + x = sd15.init_latents((512, 512)).to(sd15.device, sd15.dtype) + + # Diffusion process + for step in sd15.steps: + x = sd15(x, step=step, clip_text_embedding=clip_text_embedding) + + predicted_image = sd15.lda.decode_latents(x) + predicted_image.save("output.png") + ``` """ unet: SD1UNet