Mirror of https://github.com/finegrain-ai/refiners.git
Synced 2024-11-21 21:58:47 +00:00

Commit 7a32699cc6 (parent b69dbc4e5c)

add ensure_find and ensure_find_parent helpers

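The refactor is mechanical: wherever a find or find_parent call was immediately followed by a None check (an assert or an explicit raise), the pair collapses into the new ensure_find or ensure_find_parent helper, which asserts internally and returns a non-optional value. A minimal before/after sketch of the pattern, using names that appear in the hunks below (encoder stands in for any Chain):

    # before: find() returns `T | None`, so each call site must null-check
    tokenizer = encoder.find(layer_type=CLIPTokenizer)
    assert tokenizer is not None, "Tokenizer not found."

    # after: ensure_find() asserts internally and returns `T`
    tokenizer = encoder.ensure_find(CLIPTokenizer)
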
@@ -63,8 +63,7 @@ def setup_converter(args: Args) -> ModelConverter:
     target_state_dict = target.state_dict()
     # Remove the class embedding from state dict since it was not mapped by the model converter
-    class_embedding = target.find(fl.Parameter)
-    assert class_embedding is not None
+    class_embedding = target.ensure_find(fl.Parameter)
     class_embedding_key = next(
         (n for n, p in target.named_parameters() if id(p) == id(class_embedding.parameter)), None
     )

@@ -46,8 +46,7 @@ def setup_converter(args: Args) -> ModelConverter:
        case _:
            raise RuntimeError(f"Unsupported architecture: {architecture}")
    text = "What a nice cat you have there!"
-    tokenizer = target.find(layer_type=CLIPTokenizer)
-    assert tokenizer is not None, "Could not find tokenizer"
+    tokenizer = target.ensure_find(CLIPTokenizer)
    tokens = tokenizer(text)
    converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
    if not converter.run(source_args=(tokens,), target_args=(text,)):

@@ -97,18 +97,15 @@ class TextualInversionConfig(BaseModel):

    def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> None:
        adapter = ConceptExtender(target=text_encoder)
-        tokenizer = text_encoder.find(layer_type=CLIPTokenizer)
-        assert tokenizer is not None, "Tokenizer not found in text encoder."
-        token_encoder = text_encoder.find(layer_type=TokenEncoder)
-        assert token_encoder is not None, "Token encoder not found in text encoder."
+        tokenizer = text_encoder.ensure_find(CLIPTokenizer)
+        token_encoder = text_encoder.ensure_find(TokenEncoder)
        if self.initializer_token is not None:
            bpe = tokenizer.byte_pair_encoding(token=self.initializer_token)
            assert " " not in bpe, "This initializer_token is not a single token."
            token = Tensor([tokenizer.token_to_id_mapping[bpe]]).int().to(text_encoder.device)
            init_embedding = token_encoder(token).squeeze(0)
        else:
-            token_encoder = text_encoder.find(layer_type=TokenEncoder)
-            assert token_encoder is not None, "Token encoder not found in text encoder."
+            token_encoder = text_encoder.ensure_find(TokenEncoder)
            init_embedding = randn(token_encoder.embedding_dim)
        adapter.add_concept(self.placeholder_token, init_embedding)
        adapter.inject()

@@ -146,8 +143,7 @@ class LoadTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):

class SaveTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
    def on_checkpoint_save(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
-        embedding_extender = trainer.text_encoder.find(layer_type=EmbeddingExtender)
-        assert embedding_extender is not None, "Embedding extender not found in text encoder."
+        embedding_extender = trainer.text_encoder.ensure_find(EmbeddingExtender)
        tensors = {trainer.config.textual_inversion.placeholder_token: embedding_extender.new_weight.squeeze(0)}

        save_to_safetensors(

@@ -53,10 +53,7 @@ class Adapter(Generic[T]):

        # In general, `true_parent` is `parent`. We do this to support multiple adaptation,
        # i.e. initializing two adapters before injecting them.
-        true_parent = parent.find_parent(self.target)
-        if true_parent is None:
-            raise ValueError(f"{self.target} is not in {parent}")
-
+        true_parent = parent.ensure_find_parent(self.target)
        true_parent.replace(
            old_module=self.target,
            new_module=self,

@@ -292,6 +292,11 @@ class Chain(ContextModule):
    def find(self, layer_type: type[T]) -> T | None:
        return next(self.layers(layer_type=layer_type), None)

+    def ensure_find(self, layer_type: type[T]) -> T:
+        r = self.find(layer_type)
+        assert r is not None, f"could not find {layer_type} in {self}"
+        return r
+
    def find_parent(self, module: Module) -> "Chain | None":
        if module in self:  # avoid DFS-crawling the whole tree
            return self

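As a usage sketch of the new helper (layer choices are illustrative; fl.Chain, fl.Linear and fl.Conv2d are existing fluxion layers):

    import refiners.fluxion.layers as fl

    chain = fl.Chain(fl.Linear(in_features=8, out_features=8))

    maybe_linear = chain.find(fl.Linear)    # typed `fl.Linear | None`
    linear = chain.ensure_find(fl.Linear)   # typed `fl.Linear`, no null-check needed

    # A miss now fails loudly instead of silently returning None:
    # chain.ensure_find(fl.Conv2d) -> AssertionError: could not find ... in ...
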
@@ -299,6 +304,11 @@ class Chain(ContextModule):
            return parent
        return None

+    def ensure_find_parent(self, module: Module) -> "Chain":
+        r = self.find_parent(module)
+        assert r is not None, f"could not find {module} in {self}"
+        return r
+
    def insert(self, index: int, module: Module) -> None:
        if index < 0:
            index = max(0, len(self._modules) + index + 1)

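And a matching sketch for ensure_find_parent, which returns the Chain that directly contains the given module (again with illustrative layers):

    inner = fl.Chain(fl.GeLU())
    outer = fl.Chain(fl.Linear(in_features=4, out_features=4), inner)

    gelu = outer.ensure_find(fl.GeLU)
    parent = outer.ensure_find_parent(gelu)  # the chain whose direct child is gelu
    assert parent is inner
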
@@ -114,8 +114,7 @@ class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):

    def add_token(self, token: str, token_id: int) -> None:
        token = token.lower()
-        tokenizer = self.find(layer_type=CLIPTokenizer)
-        assert tokenizer is not None, "Tokenizer not found."
+        tokenizer = self.ensure_find(CLIPTokenizer)
        assert token_id not in tokenizer.token_to_id_mapping.values()
        tokenizer.token_to_id_mapping[token] = token_id
        current_pattern = tokenizer.token_pattern.pattern

@@ -58,8 +58,8 @@ class SelfAttentionInjectionPassthrough(Passthrough):
    def __init__(self, target: SD1UNet) -> None:
        guide_unet = target.structural_copy()
        for i, attention_block in enumerate(guide_unet.layers(CrossAttentionBlock)):
-            sa = attention_block.find(SelfAttention)
-            assert sa is not None and sa.parent is not None
+            sa = attention_block.ensure_find(SelfAttention)
+            assert sa.parent is not None
            SaveLayerNormAdapter(sa, context=f"self_attention_context_{i}").inject()

        super().__init__(

@@ -109,8 +109,8 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SD1UNet]):
        for i, attention_block in enumerate(target.layers(CrossAttentionBlock)):
            self.set_context(f"self_attention_context_{i}", {"norm": None})

-            sa = attention_block.find(SelfAttention)
-            assert sa is not None and sa.parent is not None
+            sa = attention_block.ensure_find(SelfAttention)
+            assert sa.parent is not None

            self.sub_adapters.append(
                SelfAttentionInjectionAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)

@@ -16,8 +16,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
        projection: fl.Linear | None = None,
    ) -> None:
        with self.setup_adapter(target=target):
-            tokenizer = target.find(layer_type=CLIPTokenizer)
-            assert tokenizer is not None, "Tokenizer not found."
+            tokenizer = target.ensure_find(CLIPTokenizer)
            super().__init__(
                tokenizer,
                fl.SetContext(

@@ -45,9 +44,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):

    @property
    def tokenizer(self) -> CLIPTokenizer:
-        tokenizer = self.find(layer_type=CLIPTokenizer)
-        assert tokenizer is not None, "Tokenizer not found."
-        return tokenizer
+        return self.ensure_find(CLIPTokenizer)

    def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
        position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()

@@ -92,8 +92,7 @@ def test_tokenizer_with_special_character():
    token_extender.add_token("*", new_token_id)
    token_extender.inject(clip_tokenizer)

-    adapted_clip_tokenizer = clip_tokenizer.find(layer_type=CLIPTokenizer)
-    assert adapted_clip_tokenizer is not None
+    adapted_clip_tokenizer = clip_tokenizer.ensure_find(CLIPTokenizer)

    assert torch.allclose(
        adapted_clip_tokenizer.encode("*"),

@@ -122,8 +121,7 @@ def test_encoder(
        return_tensors="pt",
    ).input_ids
    assert isinstance(ref_tokens, torch.Tensor)
-    tokenizer = our_encoder_with_new_concepts.find(layer_type=CLIPTokenizer)
-    assert tokenizer is not None
+    tokenizer = our_encoder_with_new_concepts.ensure_find(CLIPTokenizer)
    our_tokens = tokenizer(prompt)
    assert torch.equal(our_tokens, ref_tokens)

@@ -87,8 +87,7 @@ def test_encoder(
        return_tensors="pt",
    ).input_ids
    assert isinstance(ref_tokens, torch.Tensor)
-    tokenizer = our_encoder.find(layer_type=CLIPTokenizer)
-    assert tokenizer is not None
+    tokenizer = our_encoder.ensure_find(CLIPTokenizer)
    our_tokens = tokenizer(prompt)
    assert torch.equal(our_tokens, ref_tokens)
