From 1eb71077aafdf16d622e750fc3e6175062a7f3dd Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Fri, 8 Mar 2024 10:56:27 +0100 Subject: [PATCH] use same scale setter / getter interface for all controls --- .../foundationals/latent_diffusion/image_prompt.py | 4 ---- .../latent_diffusion/stable_diffusion_1/controlnet.py | 9 +++++++-- .../foundationals/latent_diffusion/t2i_adapter.py | 9 +++++++-- tests/e2e/test_diffusion.py | 6 +++--- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py index 415b676..768a110 100644 --- a/src/refiners/foundationals/latent_diffusion/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py @@ -424,10 +424,6 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): for cross_attn in self.sub_adapters: cross_attn.scale = value - def set_scale(self, scale: float) -> None: - for cross_attn in self.sub_adapters: - cross_attn.scale = scale - def set_clip_image_embedding(self, image_embedding: Tensor) -> None: """Set the CLIP image embedding context. diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py index 8729bd4..523b6db 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py @@ -174,8 +174,13 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]): def init_context(self) -> Contexts: return {"controlnet": {f"condition_{self.name}": None}} - def set_scale(self, scale: float) -> None: - self._controlnet[0].scale = scale + @property + def scale(self) -> float: + return self._controlnet[0].scale + + @scale.setter + def scale(self, value: float) -> None: + self._controlnet[0].scale = value def set_controlnet_condition(self, condition: Tensor) -> None: self.set_context("controlnet", {f"condition_{self.name}": condition}) diff --git a/src/refiners/foundationals/latent_diffusion/t2i_adapter.py b/src/refiners/foundationals/latent_diffusion/t2i_adapter.py index ba49f5e..95a1530 100644 --- a/src/refiners/foundationals/latent_diffusion/t2i_adapter.py +++ b/src/refiners/foundationals/latent_diffusion/t2i_adapter.py @@ -204,9 +204,14 @@ class T2IAdapter(Generic[T], fl.Chain, Adapter[T]): def set_condition_features(self, features: tuple[Tensor, ...]) -> None: self.set_context("t2iadapter", {f"condition_features_{self.name}": features}) - def set_scale(self, scale: float) -> None: + @property + def scale(self) -> float: + return self._features[0].scale + + @scale.setter + def scale(self, value: float) -> None: for f in self._features: - f.scale = scale + f.scale = value def init_context(self) -> Contexts: return {"t2iadapter": {f"condition_features_{self.name}": None}} diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 8fdab23..0c4f71e 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -2109,7 +2109,7 @@ def test_t2i_adapter_xl_canny( sdxl.set_inference_steps(30) t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject() - t2i_adapter.set_scale(0.8) + t2i_adapter.scale = 0.8 condition = image_to_tensor(condition_image.convert("RGB"), device=test_device) t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition)) @@ -2238,8 +2238,8 @@ def test_hello_world( condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype) t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition)) - ip_adapter.set_scale(0.85) - t2i_adapter.set_scale(0.8) + ip_adapter.scale = 0.85 + t2i_adapter.scale = 0.8 sdxl.set_inference_steps(50, first_step=1) sdxl.set_self_attention_guidance(enable=True, scale=0.75)