diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
index 8c33d44..ef29151 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
@@ -68,6 +68,14 @@ class StableDiffusion_1(LatentDiffusionModel):
             dtype=dtype,
         )
 
+    def __call__(self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5) -> Tensor:
+        return super().__call__(
+            x,
+            step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=condition_scale,
+        )
+
     def compute_clip_text_embedding(self, text: str | list[str], negative_text: str | list[str] = "") -> Tensor:
         """Compute the CLIP text embedding associated with the given prompt and negative prompt.
 
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_diffusion.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_diffusion.py
index 6b18302..0355d85 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_diffusion.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_diffusion.py
@@ -16,7 +16,7 @@ class SD1MultiDiffusion(MultiDiffusion[StableDiffusion_1, DiffusionTarget]):
             x=x,
             step=step,
             clip_text_embedding=target.clip_text_embedding,
-            scale=target.condition_scale,
+            condition_scale=target.condition_scale,
         )
 
 
@@ -37,5 +37,5 @@ class SD1InpaintingMultiDiffusion(MultiDiffusion[StableDiffusion_1_Inpainting, I
             x=x,
             step=step,
             clip_text_embedding=target.clip_text_embedding,
-            scale=target.condition_scale,
+            condition_scale=target.condition_scale,
         )
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
index 4c4f8d7..74bb8f1 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py
@@ -65,6 +65,25 @@ class StableDiffusion_XL(LatentDiffusionModel):
             dtype=dtype,
         )
 
+    def __call__(
+        self,
+        x: Tensor,
+        step: int,
+        *,
+        clip_text_embedding: Tensor,
+        pooled_text_embedding: Tensor,
+        time_ids: Tensor,
+        condition_scale: float = 5.0,
+    ) -> Tensor:
+        return super().__call__(
+            x=x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            pooled_text_embedding=pooled_text_embedding,
+            time_ids=time_ids,
+            condition_scale=condition_scale,
+        )
+
     def compute_clip_text_embedding(
         self, text: str | list[str], negative_text: str | list[str] = ""
     ) -> tuple[Tensor, Tensor]:
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index 4431c95..0fcb0fd 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -2038,6 +2038,7 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
         size=(64, 64),
         offset=(0, 16),
         clip_text_embedding=clip_text_embedding,
+        condition_scale=3,
         start_step=0,
     )
     noise = torch.randn(1, 4, 64, 80, device=sd.device, dtype=sd.dtype)
diff --git a/tests/e2e/test_diffusion_ref/expected_multi_diffusion.png b/tests/e2e/test_diffusion_ref/expected_multi_diffusion.png
index da09414..4fc4b99 100644
Binary files a/tests/e2e/test_diffusion_ref/expected_multi_diffusion.png and b/tests/e2e/test_diffusion_ref/expected_multi_diffusion.png differ
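
For context, a minimal caller-side sketch of the keyword this patch standardizes on: `condition_scale` (previously passed as `scale` by the MultiDiffusion call sites). Weight loading is elided, and the prompt, device, and the exact shape of the denoising loop are illustrative assumptions, not part of the patch:

```python
import torch

from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import StableDiffusion_1

# Assumes the constructor builds default sub-models (UNet, LDA, text encoder);
# loading real weights is omitted from this sketch.
sd = StableDiffusion_1(device="cpu", dtype=torch.float32)

# Returns the conditional/unconditional embedding pair (see the signature in
# the diff above); the empty negative prompt is the default.
clip_text_embedding = sd.compute_clip_text_embedding(text="a cute cat")

x = torch.randn(1, 4, 64, 64, device=sd.device, dtype=sd.dtype)
with torch.no_grad():
    # Assumed: `sd.steps` iterates the scheduler's step indices, as in the
    # usual refiners denoising loop.
    for step in sd.steps:
        # `condition_scale` is the keyword-only argument exposed by the new
        # __call__ override; `scale` no longer exists at this call site.
        x = sd(x, step, clip_text_embedding=clip_text_embedding, condition_scale=7.5)
```

Note that the two overrides pin different defaults (7.5 for `StableDiffusion_1`, 5.0 for `StableDiffusion_XL`), so callers that relied on an implicit default may want to pass the value explicitly, as the updated `test_multi_diffusion` now does with `condition_scale=3`.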