diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py index b9eeccd..8138d63 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/control_lora.py @@ -117,7 +117,7 @@ class ZeroConvolution(Passthrough): device: The PyTorch device to use. dtype: The PyTorch data type to use. """ - self.scale = scale + self._scale = scale super().__init__( Conv2d( @@ -131,6 +131,15 @@ class ZeroConvolution(Passthrough): ResidualAccumulator(n=residual_index), ) + @property + def scale(self) -> float: + return self._scale + + @scale.setter + def scale(self, value: float) -> None: + self._scale = value + self.ensure_find(Multiply).scale = value + class ControlLora(Passthrough): """ControlLora is a Half-UNet clone of the target UNet, diff --git a/tests/adapters/test_control_lora.py b/tests/adapters/test_control_lora.py new file mode 100644 index 0000000..5c7f26c --- /dev/null +++ b/tests/adapters/test_control_lora.py @@ -0,0 +1,34 @@ +import torch + +import refiners.fluxion.layers as fl +from refiners.foundationals.latent_diffusion import ControlLoraAdapter, SDXLUNet +from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ZeroConvolution + + +def test_inject_eject(test_device: torch.device): + unet = SDXLUNet(in_channels=4, device=test_device, dtype=torch.float16) + initial_repr = repr(unet) + adapter = ControlLoraAdapter(name="foo", target=unet) + assert repr(unet) == initial_repr + adapter.inject() + assert repr(unet) != initial_repr + adapter.eject() + assert repr(unet) == initial_repr + + +def test_scale(test_device: torch.device): + unet = SDXLUNet(in_channels=4, device=test_device, dtype=torch.float16) + adapter = ControlLoraAdapter(name="foo", target=unet, scale=0.75).inject() + + def predicate(m: fl.Module, p: fl.Chain) -> bool: + return isinstance(p, ZeroConvolution) and isinstance(m, fl.Multiply) + + for m, _ in unet.walk(predicate): + assert isinstance(m, fl.Multiply) + assert m.scale == 0.75 + + adapter.scale = 0.42 + assert adapter.scale == 0.42 + for m, _ in unet.walk(predicate): + assert isinstance(m, fl.Multiply) + assert m.scale == 0.42