add ensure_find and ensure_find_parent helpers

Pierre Chapuis 2023-09-12 11:50:56 +02:00
parent b69dbc4e5c
commit 7a32699cc6
10 changed files with 27 additions and 33 deletions

@@ -63,8 +63,7 @@ def setup_converter(args: Args) -> ModelConverter:
     target_state_dict = target.state_dict()

     # Remove the class embedding from state dict since it was not mapped by the model converter
-    class_embedding = target.find(fl.Parameter)
-    assert class_embedding is not None
+    class_embedding = target.ensure_find(fl.Parameter)
     class_embedding_key = next(
         (n for n, p in target.named_parameters() if id(p) == id(class_embedding.parameter)), None
     )

@@ -46,8 +46,7 @@ def setup_converter(args: Args) -> ModelConverter:
         case _:
             raise RuntimeError(f"Unsupported architecture: {architecture}")
     text = "What a nice cat you have there!"
-    tokenizer = target.find(layer_type=CLIPTokenizer)
-    assert tokenizer is not None, "Could not find tokenizer"
+    tokenizer = target.ensure_find(CLIPTokenizer)
     tokens = tokenizer(text)
     converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
     if not converter.run(source_args=(tokens,), target_args=(text,)):

@@ -97,18 +97,15 @@ class TextualInversionConfig(BaseModel):

     def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> None:
         adapter = ConceptExtender(target=text_encoder)
-        tokenizer = text_encoder.find(layer_type=CLIPTokenizer)
-        assert tokenizer is not None, "Tokenizer not found in text encoder."
-        token_encoder = text_encoder.find(layer_type=TokenEncoder)
-        assert token_encoder is not None, "Token encoder not found in text encoder."
+        tokenizer = text_encoder.ensure_find(CLIPTokenizer)
+        token_encoder = text_encoder.ensure_find(TokenEncoder)
         if self.initializer_token is not None:
             bpe = tokenizer.byte_pair_encoding(token=self.initializer_token)
             assert " " not in bpe, "This initializer_token is not a single token."
             token = Tensor([tokenizer.token_to_id_mapping[bpe]]).int().to(text_encoder.device)
             init_embedding = token_encoder(token).squeeze(0)
         else:
-            token_encoder = text_encoder.find(layer_type=TokenEncoder)
-            assert token_encoder is not None, "Token encoder not found in text encoder."
+            token_encoder = text_encoder.ensure_find(TokenEncoder)
             init_embedding = randn(token_encoder.embedding_dim)
         adapter.add_concept(self.placeholder_token, init_embedding)
         adapter.inject()
@@ -146,8 +143,7 @@ class LoadTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):

 class SaveTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
     def on_checkpoint_save(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
-        embedding_extender = trainer.text_encoder.find(layer_type=EmbeddingExtender)
-        assert embedding_extender is not None, "Embedding extender not found in text encoder."
+        embedding_extender = trainer.text_encoder.ensure_find(EmbeddingExtender)
         tensors = {trainer.config.textual_inversion.placeholder_token: embedding_extender.new_weight.squeeze(0)}

         save_to_safetensors(

@@ -53,10 +53,7 @@ class Adapter(Generic[T]):

         # In general, `true_parent` is `parent`. We do this to support multiple adaptation,
         # i.e. initializing two adapters before injecting them.
-        true_parent = parent.find_parent(self.target)
-        if true_parent is None:
-            raise ValueError(f"{self.target} is not in {parent}")
-
+        true_parent = parent.ensure_find_parent(self.target)
         true_parent.replace(
             old_module=self.target,
             new_module=self,
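One side effect of this hunk: when the target is not in the parent, Adapter.inject now surfaces the AssertionError raised inside ensure_find_parent (added in the Chain hunk below) instead of its former explicit ValueError. A minimal sketch of that failure mode, for illustration only; the toy chain and the stray layer are assumptions, not part of this commit, and the usual `refiners.fluxion.layers` alias is assumed:

import refiners.fluxion.layers as fl

chain = fl.Chain(fl.Linear(in_features=4, out_features=4))
stray = fl.Linear(in_features=4, out_features=4)  # deliberately not registered in `chain`

# find_parent still returns None when the module is not in the tree...
assert chain.find_parent(stray) is None

# ...while ensure_find_parent turns that into an AssertionError,
# which is what Adapter.inject now relies on instead of raising ValueError itself.
try:
    chain.ensure_find_parent(stray)
except AssertionError as error:
    print(error)  # the assert message embeds the reprs of `stray` and `chain`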

@@ -292,6 +292,11 @@ class Chain(ContextModule):
     def find(self, layer_type: type[T]) -> T | None:
         return next(self.layers(layer_type=layer_type), None)

+    def ensure_find(self, layer_type: type[T]) -> T:
+        r = self.find(layer_type)
+        assert r is not None, f"could not find {layer_type} in {self}"
+        return r
+
     def find_parent(self, module: Module) -> "Chain | None":
         if module in self:  # avoid DFS-crawling the whole tree
             return self
@@ -299,6 +304,11 @@ class Chain(ContextModule):
             return parent
         return None

+    def ensure_find_parent(self, module: Module) -> "Chain":
+        r = self.find_parent(module)
+        assert r is not None, f"could not find {module} in {self}"
+        return r
+
     def insert(self, index: int, module: Module) -> None:
         if index < 0:
             index = max(0, len(self._modules) + index + 1)
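For illustration only (not part of the commit), a minimal usage sketch of the new helper; the toy chain is an assumption, but it shows how ensure_find collapses the find-plus-assert pattern removed throughout this diff:

import refiners.fluxion.layers as fl

chain = fl.Chain(fl.Linear(in_features=4, out_features=4))

# Old pattern, repeated at every call site: the Optional return forces a manual check.
linear = chain.find(fl.Linear)
assert linear is not None, "could not find Linear"

# New pattern: the check lives inside ensure_find and the return type is already narrowed.
linear = chain.ensure_find(fl.Linear)
print(linear.in_features)  # 4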

@@ -114,8 +114,7 @@ class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):

     def add_token(self, token: str, token_id: int) -> None:
         token = token.lower()
-        tokenizer = self.find(layer_type=CLIPTokenizer)
-        assert tokenizer is not None, "Tokenizer not found."
+        tokenizer = self.ensure_find(CLIPTokenizer)
         assert token_id not in tokenizer.token_to_id_mapping.values()
         tokenizer.token_to_id_mapping[token] = token_id
         current_pattern = tokenizer.token_pattern.pattern

@@ -58,8 +58,8 @@ class SelfAttentionInjectionPassthrough(Passthrough):
     def __init__(self, target: SD1UNet) -> None:
         guide_unet = target.structural_copy()
         for i, attention_block in enumerate(guide_unet.layers(CrossAttentionBlock)):
-            sa = attention_block.find(SelfAttention)
-            assert sa is not None and sa.parent is not None
+            sa = attention_block.ensure_find(SelfAttention)
+            assert sa.parent is not None
             SaveLayerNormAdapter(sa, context=f"self_attention_context_{i}").inject()

         super().__init__(
@@ -109,8 +109,8 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SD1UNet]):

         for i, attention_block in enumerate(target.layers(CrossAttentionBlock)):
             self.set_context(f"self_attention_context_{i}", {"norm": None})

-            sa = attention_block.find(SelfAttention)
-            assert sa is not None and sa.parent is not None
+            sa = attention_block.ensure_find(SelfAttention)
+            assert sa.parent is not None
             self.sub_adapters.append(
                 SelfAttentionInjectionAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)

@@ -16,8 +16,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
         projection: fl.Linear | None = None,
     ) -> None:
         with self.setup_adapter(target=target):
-            tokenizer = target.find(layer_type=CLIPTokenizer)
-            assert tokenizer is not None, "Tokenizer not found."
+            tokenizer = target.ensure_find(CLIPTokenizer)
             super().__init__(
                 tokenizer,
                 fl.SetContext(
@@ -45,9 +44,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):

     @property
     def tokenizer(self) -> CLIPTokenizer:
-        tokenizer = self.find(layer_type=CLIPTokenizer)
-        assert tokenizer is not None, "Tokenizer not found."
-        return tokenizer
+        return self.ensure_find(CLIPTokenizer)

     def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
         position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()

@@ -92,8 +92,7 @@ def test_tokenizer_with_special_character():
     token_extender.add_token("*", new_token_id)
     token_extender.inject(clip_tokenizer)

-    adapted_clip_tokenizer = clip_tokenizer.find(layer_type=CLIPTokenizer)
-    assert adapted_clip_tokenizer is not None
+    adapted_clip_tokenizer = clip_tokenizer.ensure_find(CLIPTokenizer)

     assert torch.allclose(
         adapted_clip_tokenizer.encode("*"),
@@ -122,8 +121,7 @@ def test_encoder(
         return_tensors="pt",
     ).input_ids
     assert isinstance(ref_tokens, torch.Tensor)
-    tokenizer = our_encoder_with_new_concepts.find(layer_type=CLIPTokenizer)
-    assert tokenizer is not None
+    tokenizer = our_encoder_with_new_concepts.ensure_find(CLIPTokenizer)
     our_tokens = tokenizer(prompt)
     assert torch.equal(our_tokens, ref_tokens)


@@ -87,8 +87,7 @@ def test_encoder(
         return_tensors="pt",
     ).input_ids
     assert isinstance(ref_tokens, torch.Tensor)
-    tokenizer = our_encoder.find(layer_type=CLIPTokenizer)
-    assert tokenizer is not None
+    tokenizer = our_encoder.ensure_find(CLIPTokenizer)
     our_tokens = tokenizer(prompt)
     assert torch.equal(our_tokens, ref_tokens)
