mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
fix 37425fb609
Things to understand: - subscripted generic basic types (e.g. `list[int]`) are types.GenericAlias; - subscripted generic classes are `typing._GenericAlias`; - neither can be used with `isinstance()`; - get_origin is the cleanest way to check for this.
This commit is contained in:
parent
f9305aa416
commit
98fce82853
|
@ -1,5 +1,5 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Generic, TypeVar, cast
|
from typing import Any, Generic, Iterator, TypeVar, cast
|
||||||
|
|
||||||
from torch import Tensor, device as Device, dtype as DType
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
from torch.nn import Parameter as TorchParameter
|
from torch.nn import Parameter as TorchParameter
|
||||||
|
@ -385,20 +385,25 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
|
||||||
with self.setup_adapter(target):
|
with self.setup_adapter(target):
|
||||||
super().__init__(target, *loras)
|
super().__init__(target, *loras)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lora_layers(self) -> Iterator[Lora[Any]]:
|
||||||
|
"""The LoRA layers."""
|
||||||
|
return cast(Iterator[Lora[Any]], self.layers(Lora))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def names(self) -> list[str]:
|
def names(self) -> list[str]:
|
||||||
"""The names of the LoRA layers."""
|
"""The names of the LoRA layers."""
|
||||||
return [lora.name for lora in self.layers(Lora[Any])]
|
return [lora.name for lora in self.lora_layers]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def loras(self) -> dict[str, Lora[Any]]:
|
def loras(self) -> dict[str, Lora[Any]]:
|
||||||
"""The LoRA layers indexed by name."""
|
"""The LoRA layers indexed by name."""
|
||||||
return {lora.name: lora for lora in self.layers(Lora[Any])}
|
return {lora.name: lora for lora in self.lora_layers}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scales(self) -> dict[str, float]:
|
def scales(self) -> dict[str, float]:
|
||||||
"""The scales of the LoRA layers indexed by names."""
|
"""The scales of the LoRA layers indexed by names."""
|
||||||
return {lora.name: lora.scale for lora in self.layers(Lora[Any])}
|
return {lora.name: lora.scale for lora in self.lora_layers}
|
||||||
|
|
||||||
@scales.setter
|
@scales.setter
|
||||||
def scale(self, values: dict[str, float]) -> None:
|
def scale(self, values: dict[str, float]) -> None:
|
||||||
|
|
|
@ -3,7 +3,7 @@ import re
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, overload
|
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, get_origin, overload
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, cat, device as Device, dtype as DType
|
from torch import Tensor, cat, device as Device, dtype as DType
|
||||||
|
@ -349,6 +349,10 @@ class Chain(ContextModule):
|
||||||
Yields:
|
Yields:
|
||||||
Each module that matches the predicate.
|
Each module that matches the predicate.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if get_origin(predicate) is not None:
|
||||||
|
raise ValueError(f"subscripted generics cannot be used as predicates")
|
||||||
|
|
||||||
if isinstance(predicate, type):
|
if isinstance(predicate, type):
|
||||||
# if the predicate is a Module type
|
# if the predicate is a Module type
|
||||||
# build a predicate function that matches the type
|
# build a predicate function that matches the type
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any
|
from typing import Any, Iterator, cast
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
@ -193,7 +193,9 @@ class SDLoraManager:
|
||||||
@property
|
@property
|
||||||
def loras(self) -> list[Lora[Any]]:
|
def loras(self) -> list[Lora[Any]]:
|
||||||
"""List of all the LoRA layers managed by the SDLoraManager."""
|
"""List of all the LoRA layers managed by the SDLoraManager."""
|
||||||
return list(self.unet.layers(Lora[Any])) + list(self.clip_text_encoder.layers(Lora[Any]))
|
unet_layers = cast(Iterator[Lora[Any]], self.unet.layers(Lora))
|
||||||
|
text_encoder_layers = cast(Iterator[Lora[Any]], self.clip_text_encoder.layers(Lora))
|
||||||
|
return [*unet_layers, *text_encoder_layers]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def names(self) -> list[str]:
|
def names(self) -> list[str]:
|
||||||
|
|
Loading…
Reference in a new issue