Things to understand:

- subscripted generic basic types (e.g. `list[int]`) are types.GenericAlias;
- subscripted generic classes are `typing._GenericAlias`;
- neither can be used with `isinstance()`;
- get_origin is the cleanest way to check for this.
This commit is contained in:
Pierre Chapuis 2024-02-06 13:03:50 +01:00
parent f9305aa416
commit 98fce82853
3 changed files with 18 additions and 7 deletions

View file

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar, cast
from typing import Any, Generic, Iterator, TypeVar, cast
from torch import Tensor, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter
@ -385,20 +385,25 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
with self.setup_adapter(target):
super().__init__(target, *loras)
@property
def lora_layers(self) -> Iterator[Lora[Any]]:
"""The LoRA layers."""
return cast(Iterator[Lora[Any]], self.layers(Lora))
@property
def names(self) -> list[str]:
"""The names of the LoRA layers."""
return [lora.name for lora in self.layers(Lora[Any])]
return [lora.name for lora in self.lora_layers]
@property
def loras(self) -> dict[str, Lora[Any]]:
"""The LoRA layers indexed by name."""
return {lora.name: lora for lora in self.layers(Lora[Any])}
return {lora.name: lora for lora in self.lora_layers}
@property
def scales(self) -> dict[str, float]:
"""The scales of the LoRA layers indexed by names."""
return {lora.name: lora.scale for lora in self.layers(Lora[Any])}
return {lora.name: lora.scale for lora in self.lora_layers}
@scales.setter
def scale(self, values: dict[str, float]) -> None:

View file

@ -3,7 +3,7 @@ import re
import sys
import traceback
from collections import defaultdict
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, overload
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, get_origin, overload
import torch
from torch import Tensor, cat, device as Device, dtype as DType
@ -349,6 +349,10 @@ class Chain(ContextModule):
Yields:
Each module that matches the predicate.
"""
if get_origin(predicate) is not None:
raise ValueError(f"subscripted generics cannot be used as predicates")
if isinstance(predicate, type):
# if the predicate is a Module type
# build a predicate function that matches the type

View file

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Iterator, cast
from warnings import warn
from torch import Tensor
@ -193,7 +193,9 @@ class SDLoraManager:
@property
def loras(self) -> list[Lora[Any]]:
"""List of all the LoRA layers managed by the SDLoraManager."""
return list(self.unet.layers(Lora[Any])) + list(self.clip_text_encoder.layers(Lora[Any]))
unet_layers = cast(Iterator[Lora[Any]], self.unet.layers(Lora))
text_encoder_layers = cast(Iterator[Lora[Any]], self.clip_text_encoder.layers(Lora))
return [*unet_layers, *text_encoder_layers]
@property
def names(self) -> list[str]: