import pytest
import torch

from refiners.fluxion.utils import no_grad
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, CLIPTextEncoderL
from refiners.foundationals.clip.tokenizer import CLIPTokenizer


@no_grad()
@pytest.mark.parametrize("k_encoder", [CLIPTextEncoderL])
def test_inject_eject(k_encoder: type[CLIPTextEncoder], test_device: torch.device):
    """Check ConceptExtender injection/ejection rules and their effect on tokenization.

    The test verifies, in order, that:

    * while one extender is injected, injecting a second one or constructing a
      new one on the same encoder raises ``AssertionError`` with the
      "cannot be nested" message;
    * a concept can still be added to the extender that is currently injected;
    * after ejecting, another extender can be injected/ejected and a new
      extender can be constructed without error;
    * without injection each concept token encodes to more than 3 ids;
    * with injection each concept token encodes to exactly
      ``[start_of_text, end_of_text + k, end_of_text]`` (k = 1, 2 in the order
      the concepts were added), while an unregistered token still encodes to
      more than 3 ids.

    Args:
        k_encoder: The CLIP text encoder class to instantiate.
        test_device: Device on which the encoder and embeddings are created
            (presumably a conftest fixture - not visible in this file).
    """
    # NOTE(review): the token literals were all empty strings, which made the
    # assertions below mutually contradictory (the same input was expected to
    # produce three different encodings). Distinct placeholder tokens are used
    # here; confirm they match the tokens originally intended.
    cat_token = "<cat>"
    dog_token = "<dog>"
    unknown_token = "<unknown>"  # never registered as a concept

    nesting_error = "ConceptExtender cannot be nested, add concepts to the injected instance instead."

    encoder = k_encoder(device=test_device)
    extender = ConceptExtender(encoder)

    cat_embedding = torch.randn((encoder.embedding_dim,), device=test_device)
    extender.add_concept(token=cat_token, embedding=cat_embedding)

    # Built before any injection, so construction itself is allowed.
    extender_2 = ConceptExtender(encoder)

    extender.inject()

    # A second extender must not be injected on top of the first one.
    with pytest.raises(AssertionError) as no_nesting:
        extender_2.inject()
    assert str(no_nesting.value) == nesting_error

    # Nor may a new extender be constructed on an already-extended encoder.
    with pytest.raises(AssertionError) as no_nesting:
        ConceptExtender(encoder)
    assert str(no_nesting.value) == nesting_error

    # Adding a concept while the extender is injected is allowed.
    dog_embedding = torch.randn((encoder.embedding_dim,), device=test_device)
    extender.add_concept(token=dog_token, embedding=dog_embedding)

    extender.eject()
    extender_2.inject().eject()

    ConceptExtender(encoder)  # no exception

    # Not injected: the concept tokens are not single ids, so each one encodes
    # to more than start + one id + end.
    tokenizer = encoder.ensure_find(CLIPTokenizer)
    assert len(tokenizer.encode(cat_token)) > 3
    assert len(tokenizer.encode(dog_token)) > 3

    extender.inject()
    # Look the tokenizer up again now that the extender is injected.
    tokenizer = encoder.ensure_find(CLIPTokenizer)

    # Each registered concept now maps to a single new id located right after
    # end_of_text_token_id, in the order the concepts were added.
    assert tokenizer.encode(cat_token).equal(
        torch.tensor(
            [
                tokenizer.start_of_text_token_id,
                tokenizer.end_of_text_token_id + 1,
                tokenizer.end_of_text_token_id,
            ]
        )
    )
    assert tokenizer.encode(dog_token).equal(
        torch.tensor(
            [
                tokenizer.start_of_text_token_id,
                tokenizer.end_of_text_token_id + 2,
                tokenizer.end_of_text_token_id,
            ]
        )
    )

    # A token that was never registered is still split into several ids.
    assert len(tokenizer.encode(unknown_token)) > 3