turn CLIPTokenizer into a fl.Module

limiteinductive 2023-08-17 11:00:47 +02:00 committed by Benjamin Trom
parent 1ad4e1a35a
commit e7c1db50e0
8 changed files with 36 additions and 25 deletions

File 1 of 8: CLIPTextEncoderL weight-conversion script

@@ -6,13 +6,16 @@ from diffusers import DiffusionPipeline  # type: ignore
 from transformers.models.clip.modeling_clip import CLIPTextModel  # type: ignore
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
+from refiners.foundationals.clip.tokenizer import CLIPTokenizer


 @torch.no_grad()
 def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
     dst_model = CLIPTextEncoderL()
-    x = dst_model.tokenizer("Nice cat", sequence_length=77)
-    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
+    tokenizer = dst_model.find(layer_type=CLIPTokenizer)
+    assert tokenizer is not None, "Could not find tokenizer"
+    tokens = tokenizer("Nice cat")
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"])  # type: ignore
     assert mapping is not None, "Model conversion failed"
     state_dict = convert_state_dict(
         source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
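All three converter scripts in this commit switch to the same pattern: the tokenizer is no longer an attribute of the encoder but a layer inside its Chain, so it is looked up by type, and the probe prompt is now passed to the refiners model as a raw string (target_args) while the pre-tokenized tensor still feeds the transformers model (source_args). A minimal sketch of that lookup, using only calls visible in this diff:

    from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
    from refiners.foundationals.clip.tokenizer import CLIPTokenizer

    encoder = CLIPTextEncoderL()
    # find() returns the first layer of the requested type, or None (hence the assert).
    tokenizer = encoder.find(layer_type=CLIPTokenizer)
    assert tokenizer is not None, "Could not find tokenizer"
    tokens = tokenizer("Nice cat")  # padded token ids for the source (transformers) model
    # the refiners encoder itself is now called with the raw string: encoder("Nice cat")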

File 2 of 8: LoRA weight-conversion script (UNet and text-encoder mappings)

@@ -4,6 +4,7 @@ from refiners.fluxion.utils import (
     save_to_safetensors,
 )
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
+from refiners.foundationals.clip.tokenizer import CLIPTokenizer
 from refiners.foundationals.latent_diffusion.unet import UNet
 from refiners.foundationals.latent_diffusion.lora import LoraTarget
 from refiners.fluxion.layers.module import Module
@@ -33,9 +34,10 @@ def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: UNet) -> dic
 @torch.no_grad()
 def create_text_encoder_mapping(src_model: CLIPTextModel, dst_model: CLIPTextEncoderL) -> dict[str, str] | None:
-    x = dst_model.tokenizer("Nice cat", sequence_length=77)
-    return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
+    tokenizer = dst_model.find(layer_type=CLIPTokenizer)
+    assert tokenizer is not None, "Could not find tokenizer"
+    tokens = tokenizer("Nice cat")
+    return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"])  # type: ignore


 def main() -> None:

File 3 of 8: CLIPTextEncoderG weight-conversion script

@@ -6,6 +6,7 @@ from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
 from diffusers import DiffusionPipeline  # type: ignore
 from transformers.models.clip.modeling_clip import CLIPTextModel  # type: ignore
+from refiners.foundationals.clip.tokenizer import CLIPTokenizer
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG
 import refiners.fluxion.layers as fl
@@ -15,8 +16,10 @@ def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
     dst_model = CLIPTextEncoderG()
     # Extra projection layer (see CLIPTextModelWithProjection in transformers)
     dst_model.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
-    x = dst_model.tokenizer("Nice cat", sequence_length=77)
-    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
+    tokenizer = dst_model.find(layer_type=CLIPTokenizer)
+    assert tokenizer is not None, "Could not find tokenizer"
+    tokens = tokenizer("Nice cat")
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"])  # type: ignore
     if mapping is None:
         raise RuntimeError("Could not create state dict mapping")
     state_dict = convert_state_dict(

File 4 of 8: refiners.foundationals.clip.text_encoder

@@ -131,7 +131,6 @@ class CLIPTextEncoder(fl.Chain):
         "feedforward_dim",
         "layer_norm_eps",
         "use_quick_gelu",
-        "tokenizer",
     ]

     def __init__(
@@ -156,8 +155,9 @@ class CLIPTextEncoder(fl.Chain):
         self.feedforward_dim = feedforward_dim
         self.layer_norm_eps = layer_norm_eps
         self.use_quick_gelu = use_quick_gelu
-        self.tokenizer = tokenizer or CLIPTokenizer()
         super().__init__(
+            tokenizer or CLIPTokenizer(sequence_length=max_sequence_length),
+            fl.Converter(set_dtype=False),
             fl.Sum(
                 TokenEncoder(
                     vocabulary_size=vocabulary_size,
@@ -189,13 +189,9 @@ class CLIPTextEncoder(fl.Chain):
         for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
             parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())

-    def encode(self, text: str) -> Tensor:
-        tokens = self.tokenizer(text, sequence_length=self.max_sequence_length).to(device=self.device)
-        return self(tokens)
-
     @property
     def unconditional_text_embedding(self) -> Tensor:
-        return self.encode(text="")
+        return self("")


 class CLIPTextEncoderL(CLIPTextEncoder):
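With the tokenizer prepended to the Chain, followed by a Converter that takes over the .to(device=...) call the removed encode() helper used to make, the encoder is now simply called with text. A short usage sketch, assuming default construction on CPU; the expected shapes follow the test at the end of this diff:

    from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
    from refiners.foundationals.clip.tokenizer import CLIPTokenizer

    encoder = CLIPTextEncoderL()
    embedding = encoder("a cute cat")  # tokenization and device transfer happen inside the Chain
    print(embedding.shape)  # (1, 77, 768) for the L variant

    # Raw token ids are still available by fetching the tokenizer layer:
    tokenizer = encoder.find(layer_type=CLIPTokenizer)
    assert tokenizer is not None
    print(tokenizer("a cute cat").shape)  # (1, 77)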

File 5 of 8: refiners.foundationals.clip.tokenizer

@@ -5,17 +5,21 @@ from itertools import islice
 import re
 from torch import Tensor, tensor
 from refiners.fluxion import pad
+import refiners.fluxion.layers as fl


-class CLIPTokenizer:
+class CLIPTokenizer(fl.Module):
     def __init__(
         self,
         vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
+        sequence_length: int = 77,
         start_of_text_token_id: int = 49406,
         end_of_text_token_id: int = 49407,
         pad_token_id: int = 49407,
     ) -> None:
+        super().__init__()
         self.vocabulary_path = vocabulary_path
+        self.sequence_length = sequence_length
         self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
         self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()}
         merge_tuples = [
@@ -45,12 +49,12 @@ class CLIPTokenizer:
         self.end_of_text_token_id: int = end_of_text_token_id
         self.pad_token_id: int = pad_token_id

-    def __call__(self, text: str, sequence_length: int) -> Tensor:
-        tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(dim=0)
+    def forward(self, text: str) -> Tensor:
+        tokens = self.encode(text=text, max_length=self.sequence_length).unsqueeze(dim=0)
         assert (
-            tokens.shape[1] <= sequence_length
-        ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
-        return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.pad_token_id)
+            tokens.shape[1] <= self.sequence_length
+        ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
+        return pad(x=tokens, pad=(0, self.sequence_length - tokens.shape[1]), value=self.pad_token_id)

     @lru_cache()
     def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
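Since sequence_length is now a constructor argument rather than a per-call parameter, the tokenizer is a self-contained fl.Module whose forward() always yields a (1, sequence_length) tensor padded with pad_token_id. A small sketch, assuming the bundled BPE vocabulary loads with the defaults shown above; the sequence_length=16 variant is a hypothetical example, not something used in this commit:

    from refiners.foundationals.clip.tokenizer import CLIPTokenizer

    tokenizer = CLIPTokenizer()  # sequence_length defaults to 77, CLIP's context length
    tokens = tokenizer("Nice cat")  # forward() encodes, checks the length, then pads with 49407
    print(tokens.shape)  # torch.Size([1, 77])

    short = CLIPTokenizer(sequence_length=16)  # hypothetical shorter context, e.g. for quick tests
    print(short("Nice cat").shape)  # torch.Size([1, 16])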

File 6 of 8: latent diffusion model (LatentDiffusionModel)

@@ -88,7 +88,7 @@ class LatentDiffusionModel(Module):
         return self.clip_text_encoder.unconditional_text_embedding

     def compute_text_embedding(self, text: str) -> Tensor:
-        return self.clip_text_encoder.encode(text)
+        return self.clip_text_encoder(text)

     def forward(
         self,

File 7 of 8: training dataset (TextEmbeddingLatentsDataset)

@@ -89,7 +89,7 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
         processed_image: Image.Image = self.process_image(resized_image)
         latents = self.lda.encode_image(image=processed_image).to(device=self.device)
         processed_caption = self.process_caption(caption=caption)
-        clip_text_embedding = self.text_encoder.encode(text=processed_caption).to(device=self.device)
+        clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
         return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)

     def collate_fn(self, batch: list[TextEmbeddingLatentsBatch]) -> TextEmbeddingLatentsBatch:

File 8 of 8: CLIP text encoder test

@@ -7,7 +7,8 @@ from pathlib import Path
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.fluxion.utils import load_from_safetensors
 import transformers  # type: ignore
+from refiners.foundationals.clip.tokenizer import CLIPTokenizer


 long_prompt = """
@@ -86,12 +87,14 @@ def test_encoder(
         return_tensors="pt",
     ).input_ids
     assert isinstance(ref_tokens, torch.Tensor)
-    our_tokens = our_encoder.tokenizer(prompt, sequence_length=our_encoder.max_sequence_length)
+    tokenizer = our_encoder.find(layer_type=CLIPTokenizer)
+    assert tokenizer is not None
+    our_tokens = tokenizer(prompt)
     assert torch.equal(our_tokens, ref_tokens)

     with torch.no_grad():
         ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
-        our_embeddings = our_encoder(our_tokens.to(test_device))
+        our_embeddings = our_encoder(prompt)

     assert ref_embeddings.shape == (1, 77, 768)
     assert our_embeddings.shape == (1, 77, 768)