diff --git a/src/refiners/foundationals/clip/concepts.py b/src/refiners/foundationals/clip/concepts.py
index 3b08765..3a7c99a 100644
--- a/src/refiners/foundationals/clip/concepts.py
+++ b/src/refiners/foundationals/clip/concepts.py
@@ -97,6 +97,8 @@ class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
         with self.setup_adapter(target):
             super().__init__(target)
 
+        self._ensure_no_nesting()
+
         try:
             token_encoder, token_encoder_parent = next(target.walk(TokenEncoder))
             self._token_encoder_parent = [token_encoder_parent]
@@ -113,6 +115,11 @@ class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
         self._embedding_extender = [EmbeddingExtender(token_encoder)]
         self._token_extender = [TokenExtender(clip_tokenizer)]
 
+    def _ensure_no_nesting(self) -> None:
+        assert not isinstance(
+            self.target.parent, ConceptExtender
+        ), "ConceptExtender cannot be nested, add concepts to the injected instance instead."
+
     @property
     def embedding_extender(self) -> EmbeddingExtender:
         assert len(self._embedding_extender) == 1, "EmbeddingExtender not found."
@@ -138,6 +145,7 @@ class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
         self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)
 
     def inject(self: "ConceptExtender", parent: fl.Chain | None = None) -> "ConceptExtender":
+        self._ensure_no_nesting()
         self.embedding_extender.inject(self.token_encoder_parent)
         self.token_extender.inject(self.clip_tokenizer_parent)
         return super().inject(parent)
diff --git a/tests/adapters/test_concept_extender.py b/tests/adapters/test_concept_extender.py
new file mode 100644
index 0000000..55c3599
--- /dev/null
+++ b/tests/adapters/test_concept_extender.py
@@ -0,0 +1,63 @@
+import pytest
+import torch
+
+from refiners.fluxion.utils import no_grad
+from refiners.foundationals.clip.concepts import ConceptExtender
+from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, CLIPTextEncoderL
+from refiners.foundationals.clip.tokenizer import CLIPTokenizer
+
+
+@no_grad()
+@pytest.mark.parametrize("k_encoder", [CLIPTextEncoderL])
+def test_inject_eject(k_encoder: type[CLIPTextEncoder], test_device: torch.device):
+    encoder = k_encoder(device=test_device)
+    extender = ConceptExtender(encoder)
+
+    cat_embedding = torch.randn((encoder.embedding_dim,), device=test_device)
+    extender.add_concept(token="<cat-token>", embedding=cat_embedding)
+
+    extender_2 = ConceptExtender(encoder)
+
+    extender.inject()
+
+    with pytest.raises(AssertionError) as no_nesting:
+        extender_2.inject()
+    assert str(no_nesting.value) == "ConceptExtender cannot be nested, add concepts to the injected instance instead."
+
+    with pytest.raises(AssertionError) as no_nesting:
+        ConceptExtender(encoder)
+    assert str(no_nesting.value) == "ConceptExtender cannot be nested, add concepts to the injected instance instead."
+
+    dog_embedding = torch.randn((encoder.embedding_dim,), device=test_device)
+    extender.add_concept(token="<dog-token>", embedding=dog_embedding)
+    extender.eject()
+
+    extender_2.inject().eject()
+    ConceptExtender(encoder)  # no exception
+
+    tokenizer = encoder.ensure_find(CLIPTokenizer)
+    assert len(tokenizer.encode("<cat-token>")) > 3
+    assert len(tokenizer.encode("<dog-token>")) > 3
+
+    extender.inject()
+
+    tokenizer = encoder.ensure_find(CLIPTokenizer)
+    assert tokenizer.encode("<cat-token>").equal(
+        torch.tensor(
+            [
+                tokenizer.start_of_text_token_id,
+                tokenizer.end_of_text_token_id + 1,
+                tokenizer.end_of_text_token_id,
+            ]
+        )
+    )
+    assert tokenizer.encode("<dog-token>").equal(
+        torch.tensor(
+            [
+                tokenizer.start_of_text_token_id,
+                tokenizer.end_of_text_token_id + 2,
+                tokenizer.end_of_text_token_id,
+            ]
+        )
+    )
+    assert len(tokenizer.encode("<unknown-token>")) > 3
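
Usage sketch (not part of the diff above): with the new guard, constructing or injecting a second ConceptExtender on an encoder that already has one injected fails fast instead of silently nesting adapters. This is a minimal sketch assuming a CLIPTextEncoderL can be instantiated locally with random weights; "<my-token>" is a placeholder name, not a token from the test suite.

    import torch

    from refiners.foundationals.clip.concepts import ConceptExtender
    from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL

    encoder = CLIPTextEncoderL()
    extender = ConceptExtender(encoder)
    extender.add_concept(token="<my-token>", embedding=torch.randn(encoder.embedding_dim))
    extender.inject()

    # The encoder's parent is now the injected extender, so the guard trips
    # as early as the constructor.
    try:
        ConceptExtender(encoder)
    except AssertionError as e:
        print(e)  # ConceptExtender cannot be nested, add concepts to the injected instance instead.

    # After ejection the encoder is unwrapped again and a fresh extender is allowed.
    extender.eject()
    ConceptExtender(encoder)

Checking in both __init__ and inject() covers the two failure orders exercised by the test: an extender built after another one was injected, and an extender built earlier but injected second.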