From 176807740ba0d768072dc8dd786a5d005c9a704c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Thu, 22 Feb 2024 08:39:55 +0000 Subject: [PATCH] control_lora: fix adapter set scale The adapter set scale did not propagate the scale to the underlying zero convolutions. The value set at CTOR time was used instead. Follow up of #285 --- .../stable_diffusion_xl/control_lora.py | 11 +++++- tests/adapters/test_control_lora.py | 34 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 tests/adapters/test_control_lora.py diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py index b9eeccd..8138d63 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py @@ -117,7 +117,7 @@ class ZeroConvolution(Passthrough): device: The PyTorch device to use. dtype: The PyTorch data type to use. """ - self.scale = scale + self._scale = scale super().__init__( Conv2d( @@ -131,6 +131,15 @@ class ZeroConvolution(Passthrough): ResidualAccumulator(n=residual_index), ) + @property + def scale(self) -> float: + return self._scale + + @scale.setter + def scale(self, value: float) -> None: + self._scale = value + self.ensure_find(Multiply).scale = value + class ControlLora(Passthrough): """ControlLora is a Half-UNet clone of the target UNet, diff --git a/tests/adapters/test_control_lora.py b/tests/adapters/test_control_lora.py new file mode 100644 index 0000000..5c7f26c --- /dev/null +++ b/tests/adapters/test_control_lora.py @@ -0,0 +1,34 @@ +import torch + +import refiners.fluxion.layers as fl +from refiners.foundationals.latent_diffusion import ControlLoraAdapter, SDXLUNet +from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ZeroConvolution + + +def test_inject_eject(test_device: torch.device): + unet = SDXLUNet(in_channels=4, device=test_device, dtype=torch.float16) + initial_repr = repr(unet) + adapter = ControlLoraAdapter(name="foo", target=unet) + assert repr(unet) == initial_repr + adapter.inject() + assert repr(unet) != initial_repr + adapter.eject() + assert repr(unet) == initial_repr + + +def test_scale(test_device: torch.device): + unet = SDXLUNet(in_channels=4, device=test_device, dtype=torch.float16) + adapter = ControlLoraAdapter(name="foo", target=unet, scale=0.75).inject() + + def predicate(m: fl.Module, p: fl.Chain) -> bool: + return isinstance(p, ZeroConvolution) and isinstance(m, fl.Multiply) + + for m, _ in unet.walk(predicate): + assert isinstance(m, fl.Multiply) + assert m.scale == 0.75 + + adapter.scale = 0.42 + assert adapter.scale == 0.42 + for m, _ in unet.walk(predicate): + assert isinstance(m, fl.Multiply) + assert m.scale == 0.42