mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
fix 37425fb609
Things to understand: - subscripted generic basic types (e.g. `list[int]`) are types.GenericAlias; - subscripted generic classes are `typing._GenericAlias`; - neither can be used with `isinstance()`; - get_origin is the cleanest way to check for this.
This commit is contained in:
parent
f9305aa416
commit
98fce82853
|
@ -1,5 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generic, TypeVar, cast
|
||||
from typing import Any, Generic, Iterator, TypeVar, cast
|
||||
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
from torch.nn import Parameter as TorchParameter
|
||||
|
@ -385,20 +385,25 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
|
|||
with self.setup_adapter(target):
|
||||
super().__init__(target, *loras)
|
||||
|
||||
@property
|
||||
def lora_layers(self) -> Iterator[Lora[Any]]:
|
||||
"""The LoRA layers."""
|
||||
return cast(Iterator[Lora[Any]], self.layers(Lora))
|
||||
|
||||
@property
|
||||
def names(self) -> list[str]:
|
||||
"""The names of the LoRA layers."""
|
||||
return [lora.name for lora in self.layers(Lora[Any])]
|
||||
return [lora.name for lora in self.lora_layers]
|
||||
|
||||
@property
|
||||
def loras(self) -> dict[str, Lora[Any]]:
|
||||
"""The LoRA layers indexed by name."""
|
||||
return {lora.name: lora for lora in self.layers(Lora[Any])}
|
||||
return {lora.name: lora for lora in self.lora_layers}
|
||||
|
||||
@property
|
||||
def scales(self) -> dict[str, float]:
|
||||
"""The scales of the LoRA layers indexed by names."""
|
||||
return {lora.name: lora.scale for lora in self.layers(Lora[Any])}
|
||||
return {lora.name: lora.scale for lora in self.lora_layers}
|
||||
|
||||
@scales.setter
|
||||
def scale(self, values: dict[str, float]) -> None:
|
||||
|
|
|
@ -3,7 +3,7 @@ import re
|
|||
import sys
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, overload
|
||||
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, get_origin, overload
|
||||
|
||||
import torch
|
||||
from torch import Tensor, cat, device as Device, dtype as DType
|
||||
|
@ -349,6 +349,10 @@ class Chain(ContextModule):
|
|||
Yields:
|
||||
Each module that matches the predicate.
|
||||
"""
|
||||
|
||||
if get_origin(predicate) is not None:
|
||||
raise ValueError(f"subscripted generics cannot be used as predicates")
|
||||
|
||||
if isinstance(predicate, type):
|
||||
# if the predicate is a Module type
|
||||
# build a predicate function that matches the type
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any
|
||||
from typing import Any, Iterator, cast
|
||||
from warnings import warn
|
||||
|
||||
from torch import Tensor
|
||||
|
@ -193,7 +193,9 @@ class SDLoraManager:
|
|||
@property
|
||||
def loras(self) -> list[Lora[Any]]:
|
||||
"""List of all the LoRA layers managed by the SDLoraManager."""
|
||||
return list(self.unet.layers(Lora[Any])) + list(self.clip_text_encoder.layers(Lora[Any]))
|
||||
unet_layers = cast(Iterator[Lora[Any]], self.unet.layers(Lora))
|
||||
text_encoder_layers = cast(Iterator[Lora[Any]], self.clip_text_encoder.layers(Lora))
|
||||
return [*unet_layers, *text_encoder_layers]
|
||||
|
||||
@property
|
||||
def names(self) -> list[str]:
|
||||
|
|
Loading…
Reference in a new issue