diff --git a/src/refiners/fluxion/adapters/lora.py b/src/refiners/fluxion/adapters/lora.py
index 007ecce..cb3e348 100644
--- a/src/refiners/fluxion/adapters/lora.py
+++ b/src/refiners/fluxion/adapters/lora.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Any, Generic, TypeVar, cast
 
 from torch import Tensor, device as Device, dtype as DType
 from torch.nn import Parameter as TorchParameter
@@ -7,8 +8,10 @@ from torch.nn.init import normal_, zeros_
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
 
+T = TypeVar("T", bound=fl.WeightedModule)
 
-class Lora(fl.Chain, ABC):
+
+class Lora(Generic[T], fl.Chain, ABC):
     """Low-Rank Adaptation (LoRA) layer.
 
     This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]:
@@ -55,9 +58,7 @@ class Lora(fl.Chain, ABC):
         zeros_(tensor=self.up.weight)
 
     @abstractmethod
-    def lora_layers(
-        self, device: Device | str | None = None, dtype: DType | None = None
-    ) -> tuple[fl.WeightedModule, fl.WeightedModule]:
+    def lora_layers(self, device: Device | str | None = None, dtype: DType | None = None) -> tuple[T, T]:
         """Create the down and up layers of the LoRA.
 
         Args:
@@ -67,18 +68,18 @@ class Lora(fl.Chain, ABC):
         ...
 
     @property
-    def down(self) -> fl.WeightedModule:
+    def down(self) -> T:
         """The down layer."""
         down_layer = self[0]
         assert isinstance(down_layer, fl.WeightedModule)
-        return down_layer
+        return cast(T, down_layer)
 
     @property
-    def up(self) -> fl.WeightedModule:
+    def up(self) -> T:
         """The up layer."""
         up_layer = self[1]
         assert isinstance(up_layer, fl.WeightedModule)
-        return up_layer
+        return cast(T, up_layer)
 
     @property
     def rank(self) -> int:
@@ -102,7 +103,7 @@ class Lora(fl.Chain, ABC):
         /,
         down: Tensor,
         up: Tensor,
-    ) -> "Lora":
+    ) -> "Lora[Any]":
         match (up.ndim, down.ndim):
             case (2, 2):
                 return LinearLora.from_weights(name, up=up, down=down)
@@ -112,14 +113,14 @@ class Lora(fl.Chain, ABC):
                 raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
 
     @classmethod
-    def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora"]:
+    def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora[Any]"]:
         """
         Create a dictionary of LoRA layers from a state dict.
 
         Expects the state dict to be a succession of down and up weights.
         """
         state_dict = {k: v for k, v in state_dict.items() if ".weight" in k}
-        loras: dict[str, Lora] = {}
+        loras: dict[str, Lora[Any]] = {}
         for down_key, down_tensor, up_tensor in zip(
             list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
         ):
@@ -168,7 +169,7 @@ class Lora(fl.Chain, ABC):
         self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
 
 
-class LinearLora(Lora):
+class LinearLora(Lora[fl.Linear]):
     """Low-Rank Adaptation (LoRA) layer for linear layers.
 
     This layer uses two [`Linear`][refiners.fluxion.layers.Linear] layers as its down and up layers.
@@ -254,7 +255,7 @@ class LinearLora(Lora):
         return False
 
 
-class Conv2dLora(Lora):
+class Conv2dLora(Lora[fl.Conv2d]):
     """Low-Rank Adaptation (LoRA) layer for 2D convolutional layers.
 
     This layer uses two [`Conv2d`][refiners.fluxion.layers.Conv2d] layers as its down and up layers.
@@ -374,7 +375,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
 
     This adapter simply sums the target layer with the given LoRA layers.
     """
 
-    def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None:
+    def __init__(self, target: fl.WeightedModule, /, *loras: Lora[Any]) -> None:
         """Initialize the adapter.
 
         Args:
@@ -387,24 +388,24 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
     @property
     def names(self) -> list[str]:
         """The names of the LoRA layers."""
-        return [lora.name for lora in self.layers(Lora)]
+        return [lora.name for lora in self.layers(Lora[Any])]
 
     @property
-    def loras(self) -> dict[str, Lora]:
+    def loras(self) -> dict[str, Lora[Any]]:
         """The LoRA layers indexed by name."""
-        return {lora.name: lora for lora in self.layers(Lora)}
+        return {lora.name: lora for lora in self.layers(Lora[Any])}
 
     @property
     def scales(self) -> dict[str, float]:
         """The scales of the LoRA layers indexed by names."""
-        return {lora.name: lora.scale for lora in self.layers(Lora)}
+        return {lora.name: lora.scale for lora in self.layers(Lora[Any])}
 
     @scales.setter
     def scale(self, values: dict[str, float]) -> None:
         for name, value in values.items():
             self.loras[name].scale = value
 
-    def add_lora(self, lora: Lora, /) -> None:
+    def add_lora(self, lora: Lora[Any], /) -> None:
         """Add a LoRA layer to the adapter.
 
         Raises:
@@ -416,7 +417,7 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
         assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists"
         self.append(lora)
 
-    def remove_lora(self, name: str, /) -> Lora | None:
+    def remove_lora(self, name: str, /) -> Lora[Any] | None:
         """Remove a LoRA layer from the adapter.
 
         Note:
diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py
index e2debf2..0c120b3 100644
--- a/src/refiners/foundationals/latent_diffusion/lora.py
+++ b/src/refiners/foundationals/latent_diffusion/lora.py
@@ -1,3 +1,4 @@
+from typing import Any
 from warnings import warn
 
 from torch import Tensor
@@ -106,7 +107,7 @@ class SDLoraManager:
         for name, lora_tensors in tensors.items():
             self.add_loras(name, tensors=lora_tensors, scale=scale[name] if scale else 1.0)
 
-    def add_loras_to_text_encoder(self, loras: dict[str, Lora], /) -> None:
+    def add_loras_to_text_encoder(self, loras: dict[str, Lora[Any]], /) -> None:
         """Add multiple LoRAs to the text encoder.
 
         Args:
@@ -116,7 +117,7 @@ class SDLoraManager:
         text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
         SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
 
-    def add_loras_to_unet(self, loras: dict[str, Lora], /) -> None:
+    def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
         """Add multiple LoRAs to the U-Net.
 
         Args:
@@ -147,7 +148,7 @@ class SDLoraManager:
         for lora_adapter in self.lora_adapters:
             lora_adapter.eject()
 
-    def get_loras_by_name(self, name: str, /) -> list[Lora]:
+    def get_loras_by_name(self, name: str, /) -> list[Lora[Any]]:
         """Get the LoRA layers with the given name.
 
         Args:
@@ -190,9 +191,9 @@ class SDLoraManager:
             lora.scale = scale
 
     @property
-    def loras(self) -> list[Lora]:
+    def loras(self) -> list[Lora[Any]]:
         """List of all the LoRA layers managed by the SDLoraManager."""
-        return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
+        return list(self.unet.layers(Lora[Any])) + list(self.clip_text_encoder.layers(Lora[Any]))
 
     @property
     def names(self) -> list[str]:
@@ -239,12 +240,12 @@ class SDLoraManager:
 
     @staticmethod
     def auto_attach(
-        loras: dict[str, Lora],
+        loras: dict[str, Lora[Any]],
         target: fl.Chain,
         /,
         exclude: list[str] | None = None,
     ) -> None:
-        failed_loras: dict[str, Lora] = {}
+        failed_loras: dict[str, Lora[Any]] = {}
         for key, lora in loras.items():
             if attach := lora.auto_attach(target, exclude=exclude):
                 adapter, parent = attach
diff --git a/tests/adapters/test_lora.py b/tests/adapters/test_lora.py
index 8666982..0aa74ba 100644
--- a/tests/adapters/test_lora.py
+++ b/tests/adapters/test_lora.py
@@ -11,11 +11,11 @@ def lora() -> LinearLora:
 
 
 @pytest.fixture
-def conv_lora() -> Lora:
+def conv_lora() -> Conv2dLora:
     return Conv2dLora("conv_test", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4)
 
 
-def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
+def test_properties(lora: LinearLora, conv_lora: Conv2dLora) -> None:
     assert lora.name == "test"
     assert lora.rank == lora.down.out_features == lora.up.in_features == 16
     assert lora.scale == 1.0
@@ -27,7 +27,6 @@ def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
     assert conv_lora.scale == 1.0
     assert conv_lora.in_channels == conv_lora.down.in_channels == 16
     assert conv_lora.out_channels == conv_lora.up.out_channels == 8
-    assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
     assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)
     # padding is set so the spatial dimensions are preserved
     assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1)
@@ -40,12 +39,10 @@ def test_scale_setter(lora: LinearLora) -> None:
 
 
 def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None:
-    assert isinstance(lora.down, fl.Linear) and isinstance(lora.up, fl.Linear)
     new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight)
    x = torch.randn(1, 320)
     assert torch.allclose(lora(x), new_lora(x))
 
-    assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
     new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight)
     x = torch.randn(1, 16, 64, 64)
     assert torch.allclose(conv_lora(x), new_conv_lora(x))
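
Not part of the patch: a minimal sketch of what the new generics buy callers, assuming `refiners` is importable. The constructor keyword arguments follow the test fixtures above (the `out_features` value and the variable names are made up for illustration).

```python
# Sketch only (not part of this diff): with LinearLora subclassing Lora[fl.Linear]
# and Conv2dLora subclassing Lora[fl.Conv2d], `down` and `up` are statically typed,
# so layer-specific attributes type-check without the isinstance asserts removed above.
from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora

# Arguments mirror the test fixtures; out_features is an arbitrary choice here.
my_linear_lora = LinearLora("test", in_features=320, out_features=128, rank=16)
my_conv_lora = Conv2dLora("conv_test", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4)

print(my_linear_lora.down.out_features)  # `down` is an fl.Linear, not just fl.WeightedModule
print(my_conv_lora.up.kernel_size)       # `up` is an fl.Conv2d
```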