mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
add ensure_find and ensure_find_parent helpers
This commit is contained in:
parent
b69dbc4e5c
commit
7a32699cc6
|
@ -63,8 +63,7 @@ def setup_converter(args: Args) -> ModelConverter:
|
|||
target_state_dict = target.state_dict()
|
||||
|
||||
# Remove the class embedding from state dict since it was not mapped by the model converter
|
||||
class_embedding = target.find(fl.Parameter)
|
||||
assert class_embedding is not None
|
||||
class_embedding = target.ensure_find(fl.Parameter)
|
||||
class_embedding_key = next(
|
||||
(n for n, p in target.named_parameters() if id(p) == id(class_embedding.parameter)), None
|
||||
)
|
||||
|
|
|
@ -46,8 +46,7 @@ def setup_converter(args: Args) -> ModelConverter:
|
|||
case _:
|
||||
raise RuntimeError(f"Unsupported architecture: {architecture}")
|
||||
text = "What a nice cat you have there!"
|
||||
tokenizer = target.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None, "Could not find tokenizer"
|
||||
tokenizer = target.ensure_find(CLIPTokenizer)
|
||||
tokens = tokenizer(text)
|
||||
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
|
||||
if not converter.run(source_args=(tokens,), target_args=(text,)):
|
||||
|
|
|
@ -97,18 +97,15 @@ class TextualInversionConfig(BaseModel):
|
|||
|
||||
def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> None:
|
||||
adapter = ConceptExtender(target=text_encoder)
|
||||
tokenizer = text_encoder.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None, "Tokenizer not found in text encoder."
|
||||
token_encoder = text_encoder.find(layer_type=TokenEncoder)
|
||||
assert token_encoder is not None, "Token encoder not found in text encoder."
|
||||
tokenizer = text_encoder.ensure_find(CLIPTokenizer)
|
||||
token_encoder = text_encoder.ensure_find(TokenEncoder)
|
||||
if self.initializer_token is not None:
|
||||
bpe = tokenizer.byte_pair_encoding(token=self.initializer_token)
|
||||
assert " " not in bpe, "This initializer_token is not a single token."
|
||||
token = Tensor([tokenizer.token_to_id_mapping[bpe]]).int().to(text_encoder.device)
|
||||
init_embedding = token_encoder(token).squeeze(0)
|
||||
else:
|
||||
token_encoder = text_encoder.find(layer_type=TokenEncoder)
|
||||
assert token_encoder is not None, "Token encoder not found in text encoder."
|
||||
token_encoder = text_encoder.ensure_find(TokenEncoder)
|
||||
init_embedding = randn(token_encoder.embedding_dim)
|
||||
adapter.add_concept(self.placeholder_token, init_embedding)
|
||||
adapter.inject()
|
||||
|
@ -146,8 +143,7 @@ class LoadTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
|
|||
|
||||
class SaveTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
|
||||
def on_checkpoint_save(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
|
||||
embedding_extender = trainer.text_encoder.find(layer_type=EmbeddingExtender)
|
||||
assert embedding_extender is not None, "Embedding extender not found in text encoder."
|
||||
embedding_extender = trainer.text_encoder.ensure_find(EmbeddingExtender)
|
||||
tensors = {trainer.config.textual_inversion.placeholder_token: embedding_extender.new_weight.squeeze(0)}
|
||||
|
||||
save_to_safetensors(
|
||||
|
|
|
@ -53,10 +53,7 @@ class Adapter(Generic[T]):
|
|||
|
||||
# In general, `true_parent` is `parent`. We do this to support multiple adaptation,
|
||||
# i.e. initializing two adapters before injecting them.
|
||||
true_parent = parent.find_parent(self.target)
|
||||
if true_parent is None:
|
||||
raise ValueError(f"{self.target} is not in {parent}")
|
||||
|
||||
true_parent = parent.ensure_find_parent(self.target)
|
||||
true_parent.replace(
|
||||
old_module=self.target,
|
||||
new_module=self,
|
||||
|
|
|
@ -292,6 +292,11 @@ class Chain(ContextModule):
|
|||
def find(self, layer_type: type[T]) -> T | None:
|
||||
return next(self.layers(layer_type=layer_type), None)
|
||||
|
||||
def ensure_find(self, layer_type: type[T]) -> T:
|
||||
r = self.find(layer_type)
|
||||
assert r is not None, f"could not find {layer_type} in {self}"
|
||||
return r
|
||||
|
||||
def find_parent(self, module: Module) -> "Chain | None":
|
||||
if module in self: # avoid DFS-crawling the whole tree
|
||||
return self
|
||||
|
@ -299,6 +304,11 @@ class Chain(ContextModule):
|
|||
return parent
|
||||
return None
|
||||
|
||||
def ensure_find_parent(self, module: Module) -> "Chain":
|
||||
r = self.find_parent(module)
|
||||
assert r is not None, f"could not find {module} in {self}"
|
||||
return r
|
||||
|
||||
def insert(self, index: int, module: Module) -> None:
|
||||
if index < 0:
|
||||
index = max(0, len(self._modules) + index + 1)
|
||||
|
|
|
@ -114,8 +114,7 @@ class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):
|
|||
|
||||
def add_token(self, token: str, token_id: int) -> None:
|
||||
token = token.lower()
|
||||
tokenizer = self.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None, "Tokenizer not found."
|
||||
tokenizer = self.ensure_find(CLIPTokenizer)
|
||||
assert token_id not in tokenizer.token_to_id_mapping.values()
|
||||
tokenizer.token_to_id_mapping[token] = token_id
|
||||
current_pattern = tokenizer.token_pattern.pattern
|
||||
|
|
|
@ -58,8 +58,8 @@ class SelfAttentionInjectionPassthrough(Passthrough):
|
|||
def __init__(self, target: SD1UNet) -> None:
|
||||
guide_unet = target.structural_copy()
|
||||
for i, attention_block in enumerate(guide_unet.layers(CrossAttentionBlock)):
|
||||
sa = attention_block.find(SelfAttention)
|
||||
assert sa is not None and sa.parent is not None
|
||||
sa = attention_block.ensure_find(SelfAttention)
|
||||
assert sa.parent is not None
|
||||
SaveLayerNormAdapter(sa, context=f"self_attention_context_{i}").inject()
|
||||
|
||||
super().__init__(
|
||||
|
@ -109,8 +109,8 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SD1UNet]):
|
|||
for i, attention_block in enumerate(target.layers(CrossAttentionBlock)):
|
||||
self.set_context(f"self_attention_context_{i}", {"norm": None})
|
||||
|
||||
sa = attention_block.find(SelfAttention)
|
||||
assert sa is not None and sa.parent is not None
|
||||
sa = attention_block.ensure_find(SelfAttention)
|
||||
assert sa.parent is not None
|
||||
|
||||
self.sub_adapters.append(
|
||||
SelfAttentionInjectionAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)
|
||||
|
|
|
@ -16,8 +16,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
|
|||
projection: fl.Linear | None = None,
|
||||
) -> None:
|
||||
with self.setup_adapter(target=target):
|
||||
tokenizer = target.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None, "Tokenizer not found."
|
||||
tokenizer = target.ensure_find(CLIPTokenizer)
|
||||
super().__init__(
|
||||
tokenizer,
|
||||
fl.SetContext(
|
||||
|
@ -45,9 +44,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
|
|||
|
||||
@property
|
||||
def tokenizer(self) -> CLIPTokenizer:
|
||||
tokenizer = self.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None, "Tokenizer not found."
|
||||
return tokenizer
|
||||
return self.ensure_find(CLIPTokenizer)
|
||||
|
||||
def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
|
||||
position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()
|
||||
|
|
|
@ -92,8 +92,7 @@ def test_tokenizer_with_special_character():
|
|||
token_extender.add_token("*", new_token_id)
|
||||
token_extender.inject(clip_tokenizer)
|
||||
|
||||
adapted_clip_tokenizer = clip_tokenizer.find(layer_type=CLIPTokenizer)
|
||||
assert adapted_clip_tokenizer is not None
|
||||
adapted_clip_tokenizer = clip_tokenizer.ensure_find(CLIPTokenizer)
|
||||
|
||||
assert torch.allclose(
|
||||
adapted_clip_tokenizer.encode("*"),
|
||||
|
@ -122,8 +121,7 @@ def test_encoder(
|
|||
return_tensors="pt",
|
||||
).input_ids
|
||||
assert isinstance(ref_tokens, torch.Tensor)
|
||||
tokenizer = our_encoder_with_new_concepts.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None
|
||||
tokenizer = our_encoder_with_new_concepts.ensure_find(CLIPTokenizer)
|
||||
our_tokens = tokenizer(prompt)
|
||||
assert torch.equal(our_tokens, ref_tokens)
|
||||
|
||||
|
|
|
@ -87,8 +87,7 @@ def test_encoder(
|
|||
return_tensors="pt",
|
||||
).input_ids
|
||||
assert isinstance(ref_tokens, torch.Tensor)
|
||||
tokenizer = our_encoder.find(layer_type=CLIPTokenizer)
|
||||
assert tokenizer is not None
|
||||
tokenizer = our_encoder.ensure_find(CLIPTokenizer)
|
||||
our_tokens = tokenizer(prompt)
|
||||
assert torch.equal(our_tokens, ref_tokens)
|
||||
|
||||
|
|
Loading…
Reference in a new issue