control_lora: fix adapter set scale

The adapter set scale did not propagate the scale to the underlying
zero convolutions. The value set at CTOR time was used instead.

Follow up of #285
This commit is contained in:
Cédric Deltheil 2024-02-22 08:39:55 +00:00 committed by Cédric Deltheil
parent 83960bdbb8
commit 176807740b
2 changed files with 44 additions and 1 deletions

View file

@ -117,7 +117,7 @@ class ZeroConvolution(Passthrough):
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
self.scale = scale
self._scale = scale
super().__init__(
Conv2d(
@ -131,6 +131,15 @@ class ZeroConvolution(Passthrough):
ResidualAccumulator(n=residual_index),
)
@property
def scale(self) -> float:
return self._scale
@scale.setter
def scale(self, value: float) -> None:
self._scale = value
self.ensure_find(Multiply).scale = value
class ControlLora(Passthrough):
"""ControlLora is a Half-UNet clone of the target UNet,

View file

@ -0,0 +1,34 @@
import torch
import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion import ControlLoraAdapter, SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ZeroConvolution
def test_inject_eject(test_device: torch.device):
unet = SDXLUNet(in_channels=4, device=test_device, dtype=torch.float16)
initial_repr = repr(unet)
adapter = ControlLoraAdapter(name="foo", target=unet)
assert repr(unet) == initial_repr
adapter.inject()
assert repr(unet) != initial_repr
adapter.eject()
assert repr(unet) == initial_repr
def test_scale(test_device: torch.device):
unet = SDXLUNet(in_channels=4, device=test_device, dtype=torch.float16)
adapter = ControlLoraAdapter(name="foo", target=unet, scale=0.75).inject()
def predicate(m: fl.Module, p: fl.Chain) -> bool:
return isinstance(p, ZeroConvolution) and isinstance(m, fl.Multiply)
for m, _ in unet.walk(predicate):
assert isinstance(m, fl.Multiply)
assert m.scale == 0.75
adapter.scale = 0.42
assert adapter.scale == 0.42
for m, _ in unet.walk(predicate):
assert isinstance(m, fl.Multiply)
assert m.scale == 0.42