mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-24 15:18:46 +00:00)

add inject / eject test for concept extender (+ better errors)

commit 0e77ef1720 (parent 93270ec2d7)
src/refiners/foundationals/clip/concepts.py

@@ -97,6 +97,8 @@ class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
         with self.setup_adapter(target):
             super().__init__(target)
 
+        self._ensure_no_nesting()
+
         try:
             token_encoder, token_encoder_parent = next(target.walk(TokenEncoder))
             self._token_encoder_parent = [token_encoder_parent]
@@ -113,6 +115,11 @@ class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
         self._embedding_extender = [EmbeddingExtender(token_encoder)]
         self._token_extender = [TokenExtender(clip_tokenizer)]
 
+    def _ensure_no_nesting(self) -> None:
+        assert not isinstance(
+            self.target.parent, ConceptExtender
+        ), "ConceptExtender cannot be nested, add concepts to the injected instance instead."
+
     @property
     def embedding_extender(self) -> EmbeddingExtender:
         assert len(self._embedding_extender) == 1, "EmbeddingExtender not found."
@@ -138,6 +145,7 @@ class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
         self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)
 
     def inject(self: "ConceptExtender", parent: fl.Chain | None = None) -> "ConceptExtender":
+        self._ensure_no_nesting()
         self.embedding_extender.inject(self.token_encoder_parent)
         self.token_extender.inject(self.clip_tokenizer_parent)
         return super().inject(parent)
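
For orientation, a minimal sketch of the inject / eject flow the new _ensure_no_nesting guard protects (adapted from the test below; device and dtype arguments are elided, so treat it as illustrative rather than canonical):

    import torch
    from refiners.foundationals.clip.concepts import ConceptExtender
    from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL

    encoder = CLIPTextEncoderL()
    extender = ConceptExtender(encoder)
    extender.add_concept(token="<token1>", embedding=torch.randn((encoder.embedding_dim,)))

    extender.inject()           # patches the tokenizer and token embeddings in place
    # ConceptExtender(encoder)  # would now raise: the target is already inside an extender
    extender.eject()            # restores the original encoder
    ConceptExtender(encoder)    # fine again once ejected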
tests/adapters/test_concept_extender.py (new file, 63 lines)

@@ -0,0 +1,63 @@
+import pytest
+import torch
+
+from refiners.fluxion.utils import no_grad
+from refiners.foundationals.clip.concepts import ConceptExtender
+from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, CLIPTextEncoderL
+from refiners.foundationals.clip.tokenizer import CLIPTokenizer
+
+
+@no_grad()
+@pytest.mark.parametrize("k_encoder", [CLIPTextEncoderL])
+def test_inject_eject(k_encoder: type[CLIPTextEncoder], test_device: torch.device):
+    encoder = k_encoder(device=test_device)
+    extender = ConceptExtender(encoder)
+
+    cat_embedding = torch.randn((encoder.embedding_dim,), device=test_device)
+    extender.add_concept(token="<token1>", embedding=cat_embedding)
+
+    extender_2 = ConceptExtender(encoder)
+
+    extender.inject()
+
+    with pytest.raises(AssertionError) as no_nesting:
+        extender_2.inject()
+    assert str(no_nesting.value) == "ConceptExtender cannot be nested, add concepts to the injected instance instead."
+
+    with pytest.raises(AssertionError) as no_nesting:
+        ConceptExtender(encoder)
+    assert str(no_nesting.value) == "ConceptExtender cannot be nested, add concepts to the injected instance instead."
+
+    dog_embedding = torch.randn((encoder.embedding_dim,), device=test_device)
+    extender.add_concept(token="<token2>", embedding=dog_embedding)
+    extender.eject()
+
+    extender_2.inject().eject()
+    ConceptExtender(encoder)  # no exception
+
+    tokenizer = encoder.ensure_find(CLIPTokenizer)
+    assert len(tokenizer.encode("<token1>")) > 3
+    assert len(tokenizer.encode("<token2>")) > 3
+
+    extender.inject()
+
+    tokenizer = encoder.ensure_find(CLIPTokenizer)
+    assert tokenizer.encode("<token1>").equal(
+        torch.tensor(
+            [
+                tokenizer.start_of_text_token_id,
+                tokenizer.end_of_text_token_id + 1,
+                tokenizer.end_of_text_token_id,
+            ]
+        )
+    )
+    assert tokenizer.encode("<token2>").equal(
+        torch.tensor(
+            [
+                tokenizer.start_of_text_token_id,
+                tokenizer.end_of_text_token_id + 2,
+                tokenizer.end_of_text_token_id,
+            ]
+        )
+    )
+    assert len(tokenizer.encode("<token3>")) > 3
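
A note on the expected ids in the assertions above: CLIP's byte-pair vocabulary ends with the end-of-text token, so concept tokens appended by the extender receive the next ids in insertion order. A worked example, assuming the standard CLIP vocabulary size of 49408 (start-of-text id 49406, end-of-text id 49407):

    eot = 49407            # end_of_text_token_id, last id of the base vocabulary (assumed)
    token1_id = eot + 1    # first concept added -> 49408
    token2_id = eot + 2    # second concept added -> 49409
    # encode("<token1>") == [start_of_text_token_id, token1_id, end_of_text_token_id],
    # i.e. [49406, 49408, 49407] -- hence the eot + 1 and eot + 2 in the tensors above.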