add ensure_find and ensure_find_parent helpers

This commit is contained in:
Pierre Chapuis 2023-09-12 11:50:56 +02:00
parent b69dbc4e5c
commit 7a32699cc6
10 changed files with 27 additions and 33 deletions

View file

@ -63,8 +63,7 @@ def setup_converter(args: Args) -> ModelConverter:
target_state_dict = target.state_dict() target_state_dict = target.state_dict()
# Remove the class embedding from state dict since it was not mapped by the model converter # Remove the class embedding from state dict since it was not mapped by the model converter
class_embedding = target.find(fl.Parameter) class_embedding = target.ensure_find(fl.Parameter)
assert class_embedding is not None
class_embedding_key = next( class_embedding_key = next(
(n for n, p in target.named_parameters() if id(p) == id(class_embedding.parameter)), None (n for n, p in target.named_parameters() if id(p) == id(class_embedding.parameter)), None
) )

View file

@ -46,8 +46,7 @@ def setup_converter(args: Args) -> ModelConverter:
case _: case _:
raise RuntimeError(f"Unsupported architecture: {architecture}") raise RuntimeError(f"Unsupported architecture: {architecture}")
text = "What a nice cat you have there!" text = "What a nice cat you have there!"
tokenizer = target.find(layer_type=CLIPTokenizer) tokenizer = target.ensure_find(CLIPTokenizer)
assert tokenizer is not None, "Could not find tokenizer"
tokens = tokenizer(text) tokens = tokenizer(text)
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose) converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
if not converter.run(source_args=(tokens,), target_args=(text,)): if not converter.run(source_args=(tokens,), target_args=(text,)):

View file

@ -97,18 +97,15 @@ class TextualInversionConfig(BaseModel):
def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> None: def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> None:
adapter = ConceptExtender(target=text_encoder) adapter = ConceptExtender(target=text_encoder)
tokenizer = text_encoder.find(layer_type=CLIPTokenizer) tokenizer = text_encoder.ensure_find(CLIPTokenizer)
assert tokenizer is not None, "Tokenizer not found in text encoder." token_encoder = text_encoder.ensure_find(TokenEncoder)
token_encoder = text_encoder.find(layer_type=TokenEncoder)
assert token_encoder is not None, "Token encoder not found in text encoder."
if self.initializer_token is not None: if self.initializer_token is not None:
bpe = tokenizer.byte_pair_encoding(token=self.initializer_token) bpe = tokenizer.byte_pair_encoding(token=self.initializer_token)
assert " " not in bpe, "This initializer_token is not a single token." assert " " not in bpe, "This initializer_token is not a single token."
token = Tensor([tokenizer.token_to_id_mapping[bpe]]).int().to(text_encoder.device) token = Tensor([tokenizer.token_to_id_mapping[bpe]]).int().to(text_encoder.device)
init_embedding = token_encoder(token).squeeze(0) init_embedding = token_encoder(token).squeeze(0)
else: else:
token_encoder = text_encoder.find(layer_type=TokenEncoder) token_encoder = text_encoder.ensure_find(TokenEncoder)
assert token_encoder is not None, "Token encoder not found in text encoder."
init_embedding = randn(token_encoder.embedding_dim) init_embedding = randn(token_encoder.embedding_dim)
adapter.add_concept(self.placeholder_token, init_embedding) adapter.add_concept(self.placeholder_token, init_embedding)
adapter.inject() adapter.inject()
@ -146,8 +143,7 @@ class LoadTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
class SaveTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]): class SaveTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
def on_checkpoint_save(self, trainer: TextualInversionLatentDiffusionTrainer) -> None: def on_checkpoint_save(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
embedding_extender = trainer.text_encoder.find(layer_type=EmbeddingExtender) embedding_extender = trainer.text_encoder.ensure_find(EmbeddingExtender)
assert embedding_extender is not None, "Embedding extender not found in text encoder."
tensors = {trainer.config.textual_inversion.placeholder_token: embedding_extender.new_weight.squeeze(0)} tensors = {trainer.config.textual_inversion.placeholder_token: embedding_extender.new_weight.squeeze(0)}
save_to_safetensors( save_to_safetensors(

View file

@ -53,10 +53,7 @@ class Adapter(Generic[T]):
# In general, `true_parent` is `parent`. We do this to support multiple adaptation, # In general, `true_parent` is `parent`. We do this to support multiple adaptation,
# i.e. initializing two adapters before injecting them. # i.e. initializing two adapters before injecting them.
true_parent = parent.find_parent(self.target) true_parent = parent.ensure_find_parent(self.target)
if true_parent is None:
raise ValueError(f"{self.target} is not in {parent}")
true_parent.replace( true_parent.replace(
old_module=self.target, old_module=self.target,
new_module=self, new_module=self,

View file

@ -292,6 +292,11 @@ class Chain(ContextModule):
def find(self, layer_type: type[T]) -> T | None: def find(self, layer_type: type[T]) -> T | None:
return next(self.layers(layer_type=layer_type), None) return next(self.layers(layer_type=layer_type), None)
def ensure_find(self, layer_type: type[T]) -> T:
r = self.find(layer_type)
assert r is not None, f"could not find {layer_type} in {self}"
return r
def find_parent(self, module: Module) -> "Chain | None": def find_parent(self, module: Module) -> "Chain | None":
if module in self: # avoid DFS-crawling the whole tree if module in self: # avoid DFS-crawling the whole tree
return self return self
@ -299,6 +304,11 @@ class Chain(ContextModule):
return parent return parent
return None return None
def ensure_find_parent(self, module: Module) -> "Chain":
r = self.find_parent(module)
assert r is not None, f"could not find {module} in {self}"
return r
def insert(self, index: int, module: Module) -> None: def insert(self, index: int, module: Module) -> None:
if index < 0: if index < 0:
index = max(0, len(self._modules) + index + 1) index = max(0, len(self._modules) + index + 1)

View file

@ -114,8 +114,7 @@ class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):
def add_token(self, token: str, token_id: int) -> None: def add_token(self, token: str, token_id: int) -> None:
token = token.lower() token = token.lower()
tokenizer = self.find(layer_type=CLIPTokenizer) tokenizer = self.ensure_find(CLIPTokenizer)
assert tokenizer is not None, "Tokenizer not found."
assert token_id not in tokenizer.token_to_id_mapping.values() assert token_id not in tokenizer.token_to_id_mapping.values()
tokenizer.token_to_id_mapping[token] = token_id tokenizer.token_to_id_mapping[token] = token_id
current_pattern = tokenizer.token_pattern.pattern current_pattern = tokenizer.token_pattern.pattern

View file

@ -58,8 +58,8 @@ class SelfAttentionInjectionPassthrough(Passthrough):
def __init__(self, target: SD1UNet) -> None: def __init__(self, target: SD1UNet) -> None:
guide_unet = target.structural_copy() guide_unet = target.structural_copy()
for i, attention_block in enumerate(guide_unet.layers(CrossAttentionBlock)): for i, attention_block in enumerate(guide_unet.layers(CrossAttentionBlock)):
sa = attention_block.find(SelfAttention) sa = attention_block.ensure_find(SelfAttention)
assert sa is not None and sa.parent is not None assert sa.parent is not None
SaveLayerNormAdapter(sa, context=f"self_attention_context_{i}").inject() SaveLayerNormAdapter(sa, context=f"self_attention_context_{i}").inject()
super().__init__( super().__init__(
@ -109,8 +109,8 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SD1UNet]):
for i, attention_block in enumerate(target.layers(CrossAttentionBlock)): for i, attention_block in enumerate(target.layers(CrossAttentionBlock)):
self.set_context(f"self_attention_context_{i}", {"norm": None}) self.set_context(f"self_attention_context_{i}", {"norm": None})
sa = attention_block.find(SelfAttention) sa = attention_block.ensure_find(SelfAttention)
assert sa is not None and sa.parent is not None assert sa.parent is not None
self.sub_adapters.append( self.sub_adapters.append(
SelfAttentionInjectionAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg) SelfAttentionInjectionAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)

View file

@ -16,8 +16,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
projection: fl.Linear | None = None, projection: fl.Linear | None = None,
) -> None: ) -> None:
with self.setup_adapter(target=target): with self.setup_adapter(target=target):
tokenizer = target.find(layer_type=CLIPTokenizer) tokenizer = target.ensure_find(CLIPTokenizer)
assert tokenizer is not None, "Tokenizer not found."
super().__init__( super().__init__(
tokenizer, tokenizer,
fl.SetContext( fl.SetContext(
@ -45,9 +44,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
@property @property
def tokenizer(self) -> CLIPTokenizer: def tokenizer(self) -> CLIPTokenizer:
tokenizer = self.find(layer_type=CLIPTokenizer) return self.ensure_find(CLIPTokenizer)
assert tokenizer is not None, "Tokenizer not found."
return tokenizer
def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None: def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item() position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()

View file

@ -92,8 +92,7 @@ def test_tokenizer_with_special_character():
token_extender.add_token("*", new_token_id) token_extender.add_token("*", new_token_id)
token_extender.inject(clip_tokenizer) token_extender.inject(clip_tokenizer)
adapted_clip_tokenizer = clip_tokenizer.find(layer_type=CLIPTokenizer) adapted_clip_tokenizer = clip_tokenizer.ensure_find(CLIPTokenizer)
assert adapted_clip_tokenizer is not None
assert torch.allclose( assert torch.allclose(
adapted_clip_tokenizer.encode("*"), adapted_clip_tokenizer.encode("*"),
@ -122,8 +121,7 @@ def test_encoder(
return_tensors="pt", return_tensors="pt",
).input_ids ).input_ids
assert isinstance(ref_tokens, torch.Tensor) assert isinstance(ref_tokens, torch.Tensor)
tokenizer = our_encoder_with_new_concepts.find(layer_type=CLIPTokenizer) tokenizer = our_encoder_with_new_concepts.ensure_find(CLIPTokenizer)
assert tokenizer is not None
our_tokens = tokenizer(prompt) our_tokens = tokenizer(prompt)
assert torch.equal(our_tokens, ref_tokens) assert torch.equal(our_tokens, ref_tokens)

View file

@ -87,8 +87,7 @@ def test_encoder(
return_tensors="pt", return_tensors="pt",
).input_ids ).input_ids
assert isinstance(ref_tokens, torch.Tensor) assert isinstance(ref_tokens, torch.Tensor)
tokenizer = our_encoder.find(layer_type=CLIPTokenizer) tokenizer = our_encoder.ensure_find(CLIPTokenizer)
assert tokenizer is not None
our_tokens = tokenizer(prompt) our_tokens = tokenizer(prompt)
assert torch.equal(our_tokens, ref_tokens) assert torch.equal(our_tokens, ref_tokens)