mirror of https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 00:58:13 +00:00

prevent setattr pytorch module to register on the Chain class

parent d02be0d10e · commit a663375dc7
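For context, the stock PyTorch behavior this commit guards against, as a minimal sketch using plain torch only: torch.nn.Module.__setattr__ silently registers any Module assigned as an attribute, so a stray assignment leaks into the module tree and the state dict, while a module hidden inside a list goes unregistered.

import torch

holder = torch.nn.Sequential(torch.nn.Linear(1, 1))
holder.foo = torch.nn.Linear(1, 1)  # silently auto-registered as a submodule
print(sorted(holder.state_dict().keys()))
# ['0.bias', '0.weight', 'foo.bias', 'foo.weight'] -- the stray module leaked in

holder.bar = [torch.nn.Linear(1, 1)]  # a list is opaque to registration
print("bar" in dict(holder.named_children()))  # False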
@@ -143,6 +143,14 @@ class Chain(ContextModule):
         if isinstance(module, ContextModule) and module.parent != self:
             module._set_parent(self)
 
+    def __setattr__(self, name: str, value: Any) -> None:
+        if isinstance(value, torch.nn.Module):
+            raise ValueError(
+                "Chain does not support setting modules by attribute. Instead, use a mutation method like `append` or"
+                " wrap it within a single element list to prevent pytorch from registering it as a submodule."
+            )
+        super().__setattr__(name, value)
+
     @property
     def provider(self) -> ContextProvider:
         return self._provider
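A quick sketch of the two escape hatches the new error message names, assuming the usual `import refiners.fluxion.layers as fl` convention (`_scratch` is a hypothetical attribute name):

import refiners.fluxion.layers as fl

chain = fl.Chain(fl.Linear(in_features=1, out_features=1))

# Supported: grow the chain through a mutation method.
chain.append(fl.Linear(in_features=1, out_features=1))

# Supported: stash a module as a plain attribute by wrapping it in a
# single-element list, which pytorch does not register as a submodule.
chain._scratch = [fl.Linear(in_features=1, out_features=1)]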
@@ -11,7 +11,7 @@ from torch.nn.modules.module import Module as TorchModule
 from refiners.fluxion.utils import load_from_safetensors
 from refiners.fluxion.context import Context, ContextProvider
 
-from typing import Callable, TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Sequence
 
 if TYPE_CHECKING:
     from refiners.fluxion.layers.chain import Chain
@@ -26,11 +26,14 @@ class Module(TorchModule):
     _buffers: dict[str, Any]
     _tag: str = ""
 
-    __getattr__: Callable[["Module", str], Any]  # type: ignore
-    __setattr__: Callable[["Module", str, Any], None]  # type: ignore
-
     def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, *kwargs)  # type: ignore
+        super().__init__(*args, *kwargs)  # type: ignore[reportUnknownMemberType]
+
+    def __getattr__(self, name: str) -> Any:
+        return super().__getattr__(name=name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        return super().__setattr__(name=name, value=value)
 
     def load_from_safetensors(self, tensors_path: str | Path, strict: bool = True) -> "Module":
         state_dict = load_from_safetensors(tensors_path)
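Replacing the Callable class attributes with real pass-through overrides gives Module precisely typed __getattr__ and __setattr__ signatures, so an override in a subclass (such as the Chain.__setattr__ guard above) can defer to super() without fighting torch's Any-typed dynamic hooks. A hedged sketch of the pattern, with `Guarded` as a hypothetical subclass of the fluxion Module from this hunk:

from typing import Any

class Guarded(Module):  # Module: the fluxion base class shown above
    def __setattr__(self, name: str, value: Any) -> None:
        # custom validation would go here, then defer to the typed base method
        super().__setattr__(name=name, value=value)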
@@ -10,64 +10,6 @@ from torch.nn import Parameter
 import re
 
 
-class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
-    """
-    Extends the vocabulary of a CLIPTextEncoder with one or multiple new concepts, e.g. obtained via the Textual Inversion technique.
-
-    Example:
-        import torch
-        from refiners.foundationals.clip.concepts import ConceptExtender
-        from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
-        from refiners.fluxion.utils import load_from_safetensors
-
-        encoder = CLIPTextEncoderL(device="cuda")
-        tensors = load_from_safetensors("CLIPTextEncoderL.safetensors")
-        encoder.load_state_dict(tensors)
-
-        cat_embedding = torch.load("cat_embedding.bin")["<this-cat>"]
-        dog_embedding = torch.load("dog_embedding.bin")["<that-dog>"]
-
-        extender = ConceptExtender(encoder)
-        extender.add_concept(token="<this-cat>", embedding=cat_embedding)
-        extender.inject()
-        # New concepts can be added at any time
-        extender.add_concept(token="<that-dog>", embedding=dog_embedding)
-
-        # Now the encoder can be used with the new concepts
-    """
-
-    def __init__(self, target: CLIPTextEncoder) -> None:
-        with self.setup_adapter(target):
-            super().__init__(target)
-
-        try:
-            token_encoder, self.token_encoder_parent = next(target.walk(TokenEncoder))
-        except StopIteration:
-            raise RuntimeError("TokenEncoder not found.")
-
-        try:
-            clip_tokenizer, self.clip_tokenizer_parent = next(target.walk(CLIPTokenizer))
-        except StopIteration:
-            raise RuntimeError("Tokenizer not found.")
-
-        self.embedding_extender = EmbeddingExtender(token_encoder)
-        self.token_extender = TokenExtender(clip_tokenizer)
-
-    def add_concept(self, token: str, embedding: Tensor) -> None:
-        self.embedding_extender.add_embedding(embedding)
-        self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)
-
-    def inject(self: "ConceptExtender", parent: fl.Chain | None = None) -> "ConceptExtender":
-        self.embedding_extender.inject(self.token_encoder_parent)
-        self.token_extender.inject(self.clip_tokenizer_parent)
-        return super().inject(parent)
-
-    def eject(self) -> None:
-        self.embedding_extender.eject()
-        self.token_extender.eject()
-        super().eject()
-
-
 class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
     old_weight: Parameter
     new_weight: Parameter
@@ -122,3 +64,84 @@ class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):
         tokenizer.token_pattern = re.compile(new_pattern, re.IGNORECASE)
         # Define the keyword as its own smallest subtoken
         tokenizer.byte_pair_encoding_cache[token] = token
+
+
+class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
+    """
+    Extends the vocabulary of a CLIPTextEncoder with one or multiple new concepts, e.g. obtained via the Textual Inversion technique.
+
+    Example:
+        import torch
+        from refiners.foundationals.clip.concepts import ConceptExtender
+        from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
+        from refiners.fluxion.utils import load_from_safetensors
+
+        encoder = CLIPTextEncoderL(device="cuda")
+        tensors = load_from_safetensors("CLIPTextEncoderL.safetensors")
+        encoder.load_state_dict(tensors)
+
+        cat_embedding = torch.load("cat_embedding.bin")["<this-cat>"]
+        dog_embedding = torch.load("dog_embedding.bin")["<that-dog>"]
+
+        extender = ConceptExtender(encoder)
+        extender.add_concept(token="<this-cat>", embedding=cat_embedding)
+        extender.inject()
+        # New concepts can be added at any time
+        extender.add_concept(token="<that-dog>", embedding=dog_embedding)
+
+        # Now the encoder can be used with the new concepts
+    """
+
+    def __init__(self, target: CLIPTextEncoder) -> None:
+        with self.setup_adapter(target):
+            super().__init__(target)
+
+        try:
+            token_encoder, token_encoder_parent = next(target.walk(TokenEncoder))
+            self._token_encoder_parent = [token_encoder_parent]
+
+        except StopIteration:
+            raise RuntimeError("TokenEncoder not found.")
+
+        try:
+            clip_tokenizer, clip_tokenizer_parent = next(target.walk(CLIPTokenizer))
+            self._clip_tokenizer_parent = [clip_tokenizer_parent]
+        except StopIteration:
+            raise RuntimeError("Tokenizer not found.")
+
+        self._embedding_extender = [EmbeddingExtender(token_encoder)]
+        self._token_extender = [TokenExtender(clip_tokenizer)]
+
+    @property
+    def embedding_extender(self) -> EmbeddingExtender:
+        assert len(self._embedding_extender) == 1, "EmbeddingExtender not found."
+        return self._embedding_extender[0]
+
+    @property
+    def token_extender(self) -> TokenExtender:
+        assert len(self._token_extender) == 1, "TokenExtender not found."
+        return self._token_extender[0]
+
+    @property
+    def token_encoder_parent(self) -> fl.Chain:
+        assert len(self._token_encoder_parent) == 1, "TokenEncoder parent not found."
+        return self._token_encoder_parent[0]
+
+    @property
+    def clip_tokenizer_parent(self) -> fl.Chain:
+        assert len(self._clip_tokenizer_parent) == 1, "Tokenizer parent not found."
+        return self._clip_tokenizer_parent[0]
+
+    def add_concept(self, token: str, embedding: Tensor) -> None:
+        self.embedding_extender.add_embedding(embedding)
+        self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)
+
+    def inject(self: "ConceptExtender", parent: fl.Chain | None = None) -> "ConceptExtender":
+        self.embedding_extender.inject(self.token_encoder_parent)
+        self.token_extender.inject(self.clip_tokenizer_parent)
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        self.embedding_extender.eject()
+        self.token_extender.eject()
+        super().eject()
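The single-element lists are the workaround the new Chain.__setattr__ error message recommends: the extenders and parent references are plain list attributes, so torch never registers them as submodules of the adapter, and the properties unwrap them on demand. A small sketch of the resulting behavior, reusing the `encoder` from the docstring example above:

extender = ConceptExtender(encoder)

# None of the list-wrapped helpers were registered in the module tree:
assert all(not name.startswith("_") for name, _ in extender.named_children())

# The properties still hand back the wrapped objects:
extender.embedding_extender  # unwraps self._embedding_extender[0]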
@@ -217,3 +217,12 @@ def test_chain_structural_copy() -> None:
     y2 = m2(x)
     assert y2.shape == (7, 12)
     torch.equal(y2, y)
+
+
+def test_setattr_dont_register() -> None:
+    chain = fl.Chain(fl.Linear(in_features=1, out_features=1), fl.Linear(in_features=1, out_features=1))
+
+    with pytest.raises(expected_exception=ValueError):
+        chain.foo = fl.Linear(in_features=1, out_features=1)
+
+    assert module_keys(chain=chain) == ["Linear_1", "Linear_2"]
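The `module_keys` helper comes from the same test module and is not part of this diff; a plausible shape for it (hypothetical reconstruction) is simply the names under which torch registered the chain's submodules, which is exactly what the guard must keep free of strays:

def module_keys(chain: fl.Chain) -> list[str]:
    # registered child names, e.g. ["Linear_1", "Linear_2"]
    return [name for name, _ in chain.named_children()]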