diff --git a/scripts/conversion/convert_transformers_clip_image_model.py b/scripts/conversion/convert_transformers_clip_image_model.py
index fd6a7b5..53e896d 100644
--- a/scripts/conversion/convert_transformers_clip_image_model.py
+++ b/scripts/conversion/convert_transformers_clip_image_model.py
@@ -63,8 +63,7 @@ def setup_converter(args: Args) -> ModelConverter:
     target_state_dict = target.state_dict()
 
     # Remove the class embedding from state dict since it was not mapped by the model converter
-    class_embedding = target.find(fl.Parameter)
-    assert class_embedding is not None
+    class_embedding = target.ensure_find(fl.Parameter)
     class_embedding_key = next(
         (n for n, p in target.named_parameters() if id(p) == id(class_embedding.parameter)), None
     )
diff --git a/scripts/conversion/convert_transformers_clip_text_model.py b/scripts/conversion/convert_transformers_clip_text_model.py
index b8819cd..bbe8679 100644
--- a/scripts/conversion/convert_transformers_clip_text_model.py
+++ b/scripts/conversion/convert_transformers_clip_text_model.py
@@ -46,8 +46,7 @@ def setup_converter(args: Args) -> ModelConverter:
         case _:
             raise RuntimeError(f"Unsupported architecture: {architecture}")
     text = "What a nice cat you have there!"
-    tokenizer = target.find(layer_type=CLIPTokenizer)
-    assert tokenizer is not None, "Could not find tokenizer"
+    tokenizer = target.ensure_find(CLIPTokenizer)
     tokens = tokenizer(text)
     converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
     if not converter.run(source_args=(tokens,), target_args=(text,)):
diff --git a/scripts/training/finetune-ldm-textual-inversion.py b/scripts/training/finetune-ldm-textual-inversion.py
index cf7d3f5..4d704b3 100644
--- a/scripts/training/finetune-ldm-textual-inversion.py
+++ b/scripts/training/finetune-ldm-textual-inversion.py
@@ -97,18 +97,15 @@ class TextualInversionConfig(BaseModel):
 
     def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> None:
         adapter = ConceptExtender(target=text_encoder)
-        tokenizer = text_encoder.find(layer_type=CLIPTokenizer)
-        assert tokenizer is not None, "Tokenizer not found in text encoder."
-        token_encoder = text_encoder.find(layer_type=TokenEncoder)
-        assert token_encoder is not None, "Token encoder not found in text encoder."
+        tokenizer = text_encoder.ensure_find(CLIPTokenizer)
+        token_encoder = text_encoder.ensure_find(TokenEncoder)
         if self.initializer_token is not None:
             bpe = tokenizer.byte_pair_encoding(token=self.initializer_token)
             assert " " not in bpe, "This initializer_token is not a single token."
             token = Tensor([tokenizer.token_to_id_mapping[bpe]]).int().to(text_encoder.device)
             init_embedding = token_encoder(token).squeeze(0)
         else:
-            token_encoder = text_encoder.find(layer_type=TokenEncoder)
-            assert token_encoder is not None, "Token encoder not found in text encoder."
+            token_encoder = text_encoder.ensure_find(TokenEncoder)
             init_embedding = randn(token_encoder.embedding_dim)
         adapter.add_concept(self.placeholder_token, init_embedding)
         adapter.inject()
@@ -146,8 +143,7 @@ class LoadTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
 
 class SaveTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
     def on_checkpoint_save(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
-        embedding_extender = trainer.text_encoder.find(layer_type=EmbeddingExtender)
-        assert embedding_extender is not None, "Embedding extender not found in text encoder."
+        embedding_extender = trainer.text_encoder.ensure_find(EmbeddingExtender)
         tensors = {trainer.config.textual_inversion.placeholder_token: embedding_extender.new_weight.squeeze(0)}
 
         save_to_safetensors(
diff --git a/src/refiners/fluxion/adapters/adapter.py b/src/refiners/fluxion/adapters/adapter.py
index 03a5146..9f27153 100644
--- a/src/refiners/fluxion/adapters/adapter.py
+++ b/src/refiners/fluxion/adapters/adapter.py
@@ -53,10 +53,7 @@ class Adapter(Generic[T]):
 
         # In general, `true_parent` is `parent`. We do this to support multiple adaptation,
        # i.e. initializing two adapters before injecting them.
-        true_parent = parent.find_parent(self.target)
-        if true_parent is None:
-            raise ValueError(f"{self.target} is not in {parent}")
-
+        true_parent = parent.ensure_find_parent(self.target)
         true_parent.replace(
             old_module=self.target,
             new_module=self,
diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py
index 1a36747..b9853a3 100644
--- a/src/refiners/fluxion/layers/chain.py
+++ b/src/refiners/fluxion/layers/chain.py
@@ -292,6 +292,11 @@ class Chain(ContextModule):
     def find(self, layer_type: type[T]) -> T | None:
         return next(self.layers(layer_type=layer_type), None)
 
+    def ensure_find(self, layer_type: type[T]) -> T:
+        r = self.find(layer_type)
+        assert r is not None, f"could not find {layer_type} in {self}"
+        return r
+
     def find_parent(self, module: Module) -> "Chain | None":
         if module in self:  # avoid DFS-crawling the whole tree
             return self
@@ -299,6 +304,11 @@ class Chain(ContextModule):
                 return parent
         return None
 
+    def ensure_find_parent(self, module: Module) -> "Chain":
+        r = self.find_parent(module)
+        assert r is not None, f"could not find {module} in {self}"
+        return r
+
     def insert(self, index: int, module: Module) -> None:
         if index < 0:
             index = max(0, len(self._modules) + index + 1)
diff --git a/src/refiners/foundationals/clip/concepts.py b/src/refiners/foundationals/clip/concepts.py
index a247a52..902112d 100644
--- a/src/refiners/foundationals/clip/concepts.py
+++ b/src/refiners/foundationals/clip/concepts.py
@@ -114,8 +114,7 @@ class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):
 
     def add_token(self, token: str, token_id: int) -> None:
         token = token.lower()
-        tokenizer = self.find(layer_type=CLIPTokenizer)
-        assert tokenizer is not None, "Tokenizer not found."
+        tokenizer = self.ensure_find(CLIPTokenizer)
         assert token_id not in tokenizer.token_to_id_mapping.values()
         tokenizer.token_to_id_mapping[token] = token_id
         current_pattern = tokenizer.token_pattern.pattern
diff --git a/src/refiners/foundationals/latent_diffusion/reference_only_control.py b/src/refiners/foundationals/latent_diffusion/reference_only_control.py
index f56707e..9e45f5c 100644
--- a/src/refiners/foundationals/latent_diffusion/reference_only_control.py
+++ b/src/refiners/foundationals/latent_diffusion/reference_only_control.py
@@ -58,8 +58,8 @@ class SelfAttentionInjectionPassthrough(Passthrough):
     def __init__(self, target: SD1UNet) -> None:
         guide_unet = target.structural_copy()
         for i, attention_block in enumerate(guide_unet.layers(CrossAttentionBlock)):
-            sa = attention_block.find(SelfAttention)
-            assert sa is not None and sa.parent is not None
+            sa = attention_block.ensure_find(SelfAttention)
+            assert sa.parent is not None
             SaveLayerNormAdapter(sa, context=f"self_attention_context_{i}").inject()
 
         super().__init__(
@@ -109,8 +109,8 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SD1UNet]):
         for i, attention_block in enumerate(target.layers(CrossAttentionBlock)):
             self.set_context(f"self_attention_context_{i}", {"norm": None})
 
-            sa = attention_block.find(SelfAttention)
-            assert sa is not None and sa.parent is not None
+            sa = attention_block.ensure_find(SelfAttention)
+            assert sa.parent is not None
 
             self.sub_adapters.append(
                 SelfAttentionInjectionAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py
index e334c3d..12b32db 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py
@@ -16,8 +16,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
         projection: fl.Linear | None = None,
     ) -> None:
         with self.setup_adapter(target=target):
-            tokenizer = target.find(layer_type=CLIPTokenizer)
-            assert tokenizer is not None, "Tokenizer not found."
+            tokenizer = target.ensure_find(CLIPTokenizer)
             super().__init__(
                 tokenizer,
                 fl.SetContext(
@@ -45,9 +44,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
 
     @property
     def tokenizer(self) -> CLIPTokenizer:
-        tokenizer = self.find(layer_type=CLIPTokenizer)
-        assert tokenizer is not None, "Tokenizer not found."
-        return tokenizer
+        return self.ensure_find(CLIPTokenizer)
 
     def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
         position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()
diff --git a/tests/foundationals/clip/test_concepts.py b/tests/foundationals/clip/test_concepts.py
index d25c04e..bf42fe8 100644
--- a/tests/foundationals/clip/test_concepts.py
+++ b/tests/foundationals/clip/test_concepts.py
@@ -92,8 +92,7 @@ def test_tokenizer_with_special_character():
     token_extender.add_token("*", new_token_id)
     token_extender.inject(clip_tokenizer)
 
-    adapted_clip_tokenizer = clip_tokenizer.find(layer_type=CLIPTokenizer)
-    assert adapted_clip_tokenizer is not None
+    adapted_clip_tokenizer = clip_tokenizer.ensure_find(CLIPTokenizer)
 
     assert torch.allclose(
         adapted_clip_tokenizer.encode("*"),
@@ -122,8 +121,7 @@ def test_encoder(
         return_tensors="pt",
     ).input_ids
     assert isinstance(ref_tokens, torch.Tensor)
-    tokenizer = our_encoder_with_new_concepts.find(layer_type=CLIPTokenizer)
-    assert tokenizer is not None
+    tokenizer = our_encoder_with_new_concepts.ensure_find(CLIPTokenizer)
     our_tokens = tokenizer(prompt)
     assert torch.equal(our_tokens, ref_tokens)
 
diff --git a/tests/foundationals/clip/test_text_encoder.py b/tests/foundationals/clip/test_text_encoder.py
index b5d4693..eeb501a 100644
--- a/tests/foundationals/clip/test_text_encoder.py
+++ b/tests/foundationals/clip/test_text_encoder.py
@@ -87,8 +87,7 @@ def test_encoder(
         return_tensors="pt",
     ).input_ids
     assert isinstance(ref_tokens, torch.Tensor)
-    tokenizer = our_encoder.find(layer_type=CLIPTokenizer)
-    assert tokenizer is not None
+    tokenizer = our_encoder.ensure_find(CLIPTokenizer)
     our_tokens = tokenizer(prompt)
     assert torch.equal(our_tokens, ref_tokens)
 
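Usage note: a minimal sketch of how the new helpers read at a call site. This assumes a refiners CLIPTextEncoderL as the chain being searched; the concrete encoder is illustrative, and any fl.Chain containing a CLIPTokenizer behaves the same way.

    import refiners.fluxion.layers as fl
    from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
    from refiners.foundationals.clip.tokenizer import CLIPTokenizer

    encoder = CLIPTextEncoderL()

    # Before this change: find() returns `T | None`, so every call site
    # needed its own None check to satisfy the type checker.
    tokenizer = encoder.find(CLIPTokenizer)
    assert tokenizer is not None, "Tokenizer not found."

    # After: ensure_find() returns `T` and asserts internally.
    tokenizer = encoder.ensure_find(CLIPTokenizer)

    # ensure_find_parent() is the same convenience over find_parent():
    # here the tokenizer is a direct child, so this returns `encoder` itself.
    parent = encoder.ensure_find_parent(tokenizer)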