From 98fce82853bbe61c0759eddcdcfd68114a13ebce Mon Sep 17 00:00:00 2001
From: Pierre Chapuis
Date: Tue, 6 Feb 2024 13:03:50 +0100
Subject: [PATCH] fix 37425fb609bfc1d954ad2396d3542b961d3cff78

Things to understand:

- subscripted generic basic types (e.g. `list[int]`) are
  `types.GenericAlias`;
- subscripted generic classes are `typing._GenericAlias`;
- neither can be used with `isinstance()`;
- `get_origin` is the cleanest way to check for this.
---
 src/refiners/fluxion/adapters/lora.py               | 13 +++++++++----
 src/refiners/fluxion/layers/chain.py                |  6 +++++-
 src/refiners/foundationals/latent_diffusion/lora.py |  6 ++++--
 3 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/src/refiners/fluxion/adapters/lora.py b/src/refiners/fluxion/adapters/lora.py
index cb3e348..e773359 100644
--- a/src/refiners/fluxion/adapters/lora.py
+++ b/src/refiners/fluxion/adapters/lora.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Generic, TypeVar, cast
+from typing import Any, Generic, Iterator, TypeVar, cast
 
 from torch import Tensor, device as Device, dtype as DType
 from torch.nn import Parameter as TorchParameter
@@ -385,20 +385,25 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
         with self.setup_adapter(target):
             super().__init__(target, *loras)
 
+    @property
+    def lora_layers(self) -> Iterator[Lora[Any]]:
+        """The LoRA layers."""
+        return cast(Iterator[Lora[Any]], self.layers(Lora))
+
     @property
     def names(self) -> list[str]:
         """The names of the LoRA layers."""
-        return [lora.name for lora in self.layers(Lora[Any])]
+        return [lora.name for lora in self.lora_layers]
 
     @property
     def loras(self) -> dict[str, Lora[Any]]:
         """The LoRA layers indexed by name."""
-        return {lora.name: lora for lora in self.layers(Lora[Any])}
+        return {lora.name: lora for lora in self.lora_layers}
 
     @property
     def scales(self) -> dict[str, float]:
         """The scales of the LoRA layers indexed by names."""
-        return {lora.name: lora.scale for lora in self.layers(Lora[Any])}
+        return {lora.name: lora.scale for lora in self.lora_layers}
 
     @scales.setter
     def scale(self, values: dict[str, float]) -> None:
diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py
index acd2ddc..ab65e5c 100644
--- a/src/refiners/fluxion/layers/chain.py
+++ b/src/refiners/fluxion/layers/chain.py
@@ -3,7 +3,7 @@ import re
 import sys
 import traceback
 from collections import defaultdict
-from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, overload
+from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, get_origin, overload
 
 import torch
 from torch import Tensor, cat, device as Device, dtype as DType
@@ -349,6 +349,10 @@ class Chain(ContextModule):
         Yields:
             Each module that matches the predicate.
""" + + if get_origin(predicate) is not None: + raise ValueError(f"subscripted generics cannot be used as predicates") + if isinstance(predicate, type): # if the predicate is a Module type # build a predicate function that matches the type diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index 0c120b3..ca23adb 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Iterator, cast from warnings import warn from torch import Tensor @@ -193,7 +193,9 @@ class SDLoraManager: @property def loras(self) -> list[Lora[Any]]: """List of all the LoRA layers managed by the SDLoraManager.""" - return list(self.unet.layers(Lora[Any])) + list(self.clip_text_encoder.layers(Lora[Any])) + unet_layers = cast(Iterator[Lora[Any]], self.unet.layers(Lora)) + text_encoder_layers = cast(Iterator[Lora[Any]], self.clip_text_encoder.layers(Lora)) + return [*unet_layers, *text_encoder_layers] @property def names(self) -> list[str]: