turn CLIPTokenizer into a fl.Module

This commit is contained in:
limiteinductive 2023-08-17 11:00:47 +02:00 committed by Benjamin Trom
parent 1ad4e1a35a
commit e7c1db50e0
8 changed files with 36 additions and 25 deletions

View file

@ -6,13 +6,16 @@ from diffusers import DiffusionPipeline # type: ignore
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
@torch.no_grad()
def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
dst_model = CLIPTextEncoderL()
x = dst_model.tokenizer("Nice cat", sequence_length=77)
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
tokenizer = dst_model.find(layer_type=CLIPTokenizer)
assert tokenizer is not None, "Could not find tokenizer"
tokens = tokenizer("Nice cat")
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore
assert mapping is not None, "Model conversion failed"
state_dict = convert_state_dict(
source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping

View file

@ -4,6 +4,7 @@ from refiners.fluxion.utils import (
save_to_safetensors,
)
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.foundationals.latent_diffusion.unet import UNet
from refiners.foundationals.latent_diffusion.lora import LoraTarget
from refiners.fluxion.layers.module import Module
@ -33,9 +34,10 @@ def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: UNet) -> dic
@torch.no_grad()
def create_text_encoder_mapping(src_model: CLIPTextModel, dst_model: CLIPTextEncoderL) -> dict[str, str] | None:
x = dst_model.tokenizer("Nice cat", sequence_length=77)
return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
tokenizer = dst_model.find(layer_type=CLIPTokenizer)
assert tokenizer is not None, "Could not find tokenizer"
tokens = tokenizer("Nice cat")
return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore
def main() -> None:

View file

@ -6,6 +6,7 @@ from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
from diffusers import DiffusionPipeline # type: ignore
from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG
import refiners.fluxion.layers as fl
@ -15,8 +16,10 @@ def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
dst_model = CLIPTextEncoderG()
# Extra projection layer (see CLIPTextModelWithProjection in transformers)
dst_model.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
x = dst_model.tokenizer("Nice cat", sequence_length=77)
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
tokenizer = dst_model.find(layer_type=CLIPTokenizer)
assert tokenizer is not None, "Could not find tokenizer"
tokens = tokenizer("Nice cat")
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[tokens], target_args=["Nice cat"]) # type: ignore
if mapping is None:
raise RuntimeError("Could not create state dict mapping")
state_dict = convert_state_dict(

View file

@ -131,7 +131,6 @@ class CLIPTextEncoder(fl.Chain):
"feedforward_dim",
"layer_norm_eps",
"use_quick_gelu",
"tokenizer",
]
def __init__(
@ -156,8 +155,9 @@ class CLIPTextEncoder(fl.Chain):
self.feedforward_dim = feedforward_dim
self.layer_norm_eps = layer_norm_eps
self.use_quick_gelu = use_quick_gelu
self.tokenizer = tokenizer or CLIPTokenizer()
super().__init__(
tokenizer or CLIPTokenizer(sequence_length=max_sequence_length),
fl.Converter(set_dtype=False),
fl.Sum(
TokenEncoder(
vocabulary_size=vocabulary_size,
@ -189,13 +189,9 @@ class CLIPTextEncoder(fl.Chain):
for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
def encode(self, text: str) -> Tensor:
tokens = self.tokenizer(text, sequence_length=self.max_sequence_length).to(device=self.device)
return self(tokens)
@property
def unconditional_text_embedding(self) -> Tensor:
return self.encode(text="")
return self("")
class CLIPTextEncoderL(CLIPTextEncoder):

View file

@ -5,17 +5,21 @@ from itertools import islice
import re
from torch import Tensor, tensor
from refiners.fluxion import pad
import refiners.fluxion.layers as fl
class CLIPTokenizer:
class CLIPTokenizer(fl.Module):
def __init__(
self,
vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
sequence_length: int = 77,
start_of_text_token_id: int = 49406,
end_of_text_token_id: int = 49407,
pad_token_id: int = 49407,
) -> None:
super().__init__()
self.vocabulary_path = vocabulary_path
self.sequence_length = sequence_length
self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()}
merge_tuples = [
@ -45,12 +49,12 @@ class CLIPTokenizer:
self.end_of_text_token_id: int = end_of_text_token_id
self.pad_token_id: int = pad_token_id
def __call__(self, text: str, sequence_length: int) -> Tensor:
tokens = self.encode(text=text, max_length=sequence_length).unsqueeze(dim=0)
def forward(self, text: str) -> Tensor:
tokens = self.encode(text=text, max_length=self.sequence_length).unsqueeze(dim=0)
assert (
tokens.shape[1] <= sequence_length
), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {sequence_length}"
return pad(x=tokens, pad=(0, sequence_length - tokens.shape[1]), value=self.pad_token_id)
tokens.shape[1] <= self.sequence_length
), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
return pad(x=tokens, pad=(0, self.sequence_length - tokens.shape[1]), value=self.pad_token_id)
@lru_cache()
def get_bytes_to_unicode_mapping(self) -> dict[int, str]:

View file

@ -88,7 +88,7 @@ class LatentDiffusionModel(Module):
return self.clip_text_encoder.unconditional_text_embedding
def compute_text_embedding(self, text: str) -> Tensor:
return self.clip_text_encoder.encode(text)
return self.clip_text_encoder(text)
def forward(
self,

View file

@ -89,7 +89,7 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
processed_image: Image.Image = self.process_image(resized_image)
latents = self.lda.encode_image(image=processed_image).to(device=self.device)
processed_caption = self.process_caption(caption=caption)
clip_text_embedding = self.text_encoder.encode(text=processed_caption).to(device=self.device)
clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)
def collate_fn(self, batch: list[TextEmbeddingLatentsBatch]) -> TextEmbeddingLatentsBatch:

View file

@ -7,7 +7,8 @@ from pathlib import Path
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.fluxion.utils import load_from_safetensors
import transformers # type: ignore
import transformers # type: ignore
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
long_prompt = """
@ -86,12 +87,14 @@ def test_encoder(
return_tensors="pt",
).input_ids
assert isinstance(ref_tokens, torch.Tensor)
our_tokens = our_encoder.tokenizer(prompt, sequence_length=our_encoder.max_sequence_length)
tokenizer = our_encoder.find(layer_type=CLIPTokenizer)
assert tokenizer is not None
our_tokens = tokenizer(prompt)
assert torch.equal(our_tokens, ref_tokens)
with torch.no_grad():
ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
our_embeddings = our_encoder(our_tokens.to(test_device))
our_embeddings = our_encoder(prompt)
assert ref_embeddings.shape == (1, 77, 768)
assert our_embeddings.shape == (1, 77, 768)