From 4a6146bb6c3f8864927a8ed8e1e845b6e8140af3 Mon Sep 17 00:00:00 2001
From: Colle
Date: Thu, 1 Feb 2024 15:05:43 +0100
Subject: [PATCH] clip text, lda encode batch inputs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* text_encoder([str1, str2])
* lda decode_latents/encode_image image_to_latent/latent_to_image
* images_to_tensor, tensor_to_images

---------

Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
---
 src/refiners/fluxion/utils.py                  | 11 +++
 src/refiners/foundationals/clip/tokenizer.py   | 14 +++-
 .../latent_diffusion/auto_encoder.py           | 21 ++++--
 .../foundationals/latent_diffusion/model.py    |  2 +-
 .../latent_diffusion/multi_diffusion.py        |  9 ++-
 .../training_utils/latent_diffusion.py         |  4 +-
 tests/e2e/test_diffusion.py                    | 68 +++++++++----------
 tests/foundationals/clip/test_text_encoder.py  | 11 +++
 .../latent_diffusion/test_auto_encoder.py      | 15 +++-
 9 files changed, 107 insertions(+), 48 deletions(-)

diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py
index 32a302b..68b97c2 100644
--- a/src/refiners/fluxion/utils.py
+++ b/src/refiners/fluxion/utils.py
@@ -10,6 +10,7 @@ from safetensors import safe_open as _safe_open  # type: ignore
 from safetensors.torch import save_file as _save_file  # type: ignore
 from torch import (
     Tensor,
+    cat,
     device as Device,
     dtype as DType,
     manual_seed as _manual_seed,  # type: ignore
@@ -113,6 +114,12 @@ def gaussian_blur(
     return tensor
 
 
+def images_to_tensor(
+    images: list[Image.Image], device: Device | str | None = None, dtype: DType | None = None
+) -> Tensor:
+    return cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])
+
+
 def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
     """
     Convert a PIL Image to a Tensor.
@@ -135,6 +142,10 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp
     return image_tensor.unsqueeze(0)
 
 
+def tensor_to_images(tensor: Tensor) -> list[Image.Image]:
+    return [tensor_to_image(t) for t in tensor.split(1)]  # type: ignore
+
+
 def tensor_to_image(tensor: Tensor) -> Image.Image:
     """
     Convert a Tensor to a PIL Image.
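
[Editor's note: a minimal usage sketch of the two helpers added above, not part of the patch. The file paths are hypothetical, and since `images_to_tensor` concatenates along the batch dimension, all images must share the same size.]

    from PIL import Image

    from refiners.fluxion.utils import images_to_tensor, tensor_to_images

    # Two same-size RGB images; the paths are placeholders.
    images = [Image.open("cat.png").convert("RGB"), Image.open("dog.png").convert("RGB")]
    batch = images_to_tensor(images, device="cpu")  # shape (2, 3, H, W), values in [0, 1]
    assert batch.shape[0] == 2

    # Round-trip back to a list of PIL images, one per batch element.
    round_trip = tensor_to_images(batch)
    assert len(round_trip) == 2
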
diff --git a/src/refiners/foundationals/clip/tokenizer.py b/src/refiners/foundationals/clip/tokenizer.py
index 585474d..9df0fcf 100644
--- a/src/refiners/foundationals/clip/tokenizer.py
+++ b/src/refiners/foundationals/clip/tokenizer.py
@@ -4,7 +4,7 @@ from functools import lru_cache
 from itertools import islice
 from pathlib import Path
 
-from torch import Tensor, tensor
+from torch import Tensor, cat, tensor
 
 import refiners.fluxion.layers as fl
 from refiners.fluxion import pad
@@ -51,11 +51,19 @@ class CLIPTokenizer(fl.Module):
         self.end_of_text_token_id: int = end_of_text_token_id
         self.pad_token_id: int = pad_token_id
 
-    def forward(self, text: str) -> Tensor:
+    def forward(self, text: str | list[str]) -> Tensor:
+        if isinstance(text, str):
+            return self.tokenize_str(text)
+        else:
+            assert isinstance(text, list), f"Expected type `str` or `list[str]`, got {type(text)}"
+            return cat([self.tokenize_str(txt) for txt in text])
+
+    def tokenize_str(self, text: str) -> Tensor:
         tokens = self.encode(text=text, max_length=self.sequence_length).unsqueeze(dim=0)
+
         assert (
             tokens.shape[1] <= self.sequence_length
-        ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
+        ), f"Text is too long ({len(text)}): tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
         return pad(x=tokens, pad=(0, self.sequence_length - tokens.shape[1]), value=self.pad_token_id)
 
     @lru_cache()
diff --git a/src/refiners/foundationals/latent_diffusion/auto_encoder.py b/src/refiners/foundationals/latent_diffusion/auto_encoder.py
index 9fee47d..70294f5 100644
--- a/src/refiners/foundationals/latent_diffusion/auto_encoder.py
+++ b/src/refiners/foundationals/latent_diffusion/auto_encoder.py
@@ -15,7 +15,7 @@ from refiners.fluxion.layers import (
     Sum,
     Upsample,
 )
-from refiners.fluxion.utils import image_to_tensor, tensor_to_image
+from refiners.fluxion.utils import images_to_tensor, tensor_to_images
 
 
 class Resnet(Sum):
@@ -210,12 +210,25 @@ class LatentDiffusionAutoencoder(Chain):
         x = decoder(x / self.encoder_scale)
         return x
 
-    def encode_image(self, image: Image.Image) -> Tensor:
-        x = image_to_tensor(image, device=self.device, dtype=self.dtype)
+    def image_to_latents(self, image: Image.Image) -> Tensor:
+        return self.images_to_latents([image])
+
+    def images_to_latents(self, images: list[Image.Image]) -> Tensor:
+        x = images_to_tensor(images, device=self.device, dtype=self.dtype)
         x = 2 * x - 1
         return self.encode(x)
 
+    # backward-compatibility alias
     def decode_latents(self, x: Tensor) -> Image.Image:
+        return self.latents_to_image(x)
+
+    def latents_to_image(self, x: Tensor) -> Image.Image:
+        if x.shape[0] != 1:
+            raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")
+
+        return self.latents_to_images(x)[0]
+
+    def latents_to_images(self, x: Tensor) -> list[Image.Image]:
         x = self.decode(x)
         x = (x + 1) / 2
-        return tensor_to_image(x)
+        return tensor_to_images(x)
diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py
index d151b1e..b4a2f82 100644
--- a/src/refiners/foundationals/latent_diffusion/model.py
+++ b/src/refiners/foundationals/latent_diffusion/model.py
@@ -50,7 +50,7 @@ class LatentDiffusionModel(fl.Module, ABC):
         ], f"noise shape is not compatible: {noise.shape}, with size: {size}"
         if init_image is None:
             return noise
-        encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
+        encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height)))
         return self.solver.add_noise(
             x=encoded_image,
             noise=noise,
diff --git a/src/refiners/foundationals/latent_diffusion/multi_diffusion.py b/src/refiners/foundationals/latent_diffusion/multi_diffusion.py
index 10f0f1f..1f9c482 100644
--- a/src/refiners/foundationals/latent_diffusion/multi_diffusion.py
+++ b/src/refiners/foundationals/latent_diffusion/multi_diffusion.py
@@ -83,8 +83,15 @@ class MultiDiffusion(Generic[T, D], ABC):
     def dtype(self) -> DType:
         return self.ldm.dtype
 
+    # backward-compatibility alias
     def decode_latents(self, x: Tensor) -> Image.Image:
-        return self.ldm.lda.decode_latents(x=x)
+        return self.latents_to_image(x=x)
+
+    def latents_to_image(self, x: Tensor) -> Image.Image:
+        return self.ldm.lda.latents_to_image(x=x)
+
+    def latents_to_images(self, x: Tensor) -> list[Image.Image]:
+        return self.ldm.lda.latents_to_images(x=x)
 
     @staticmethod
     def generate_offset_grid(size: tuple[int, int], stride: int = 8) -> list[tuple[int, int]]:
diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py
index b72357c..28aef44 100644
--- a/src/refiners/training_utils/latent_diffusion.py
+++ b/src/refiners/training_utils/latent_diffusion.py
@@ -113,7 +113,7 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
             max_size=self.config.dataset.resize_image_max_size,
         )
         processed_image = self.process_image(resized_image)
-        latents = self.lda.encode_image(image=processed_image).to(device=self.device)
+        latents = self.lda.image_to_latents(image=processed_image).to(device=self.device)
         processed_caption = self.process_caption(caption=caption)
         clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
         return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)
@@ -202,7 +202,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
                     step=step,
                     clip_text_embedding=clip_text_embedding,
                 )
-                canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i))
+                canvas_image.paste(sd.lda.latents_to_image(x=x), box=(0, 512 * i))
             images[prompt] = canvas_image
         self.log(data=images)
 
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index 5ff987d..d137ac6 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -666,7 +666,7 @@ def test_diffusion_std_random_init(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image_std_random_init)
 
@@ -696,7 +696,7 @@ def test_diffusion_std_random_init_euler(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image_std_random_init_euler)
 
@@ -721,7 +721,7 @@ def test_diffusion_karras_random_init(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98)
 
@@ -749,7 +749,7 @@ def test_diffusion_std_random_init_float16(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)
 
@@ -777,7 +777,7 @@ def test_diffusion_std_random_init_sag(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image_std_random_init_sag)
 
@@ -806,7 +806,7 @@ def test_diffusion_std_init_image(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image_std_init_image)
 
@@ -823,7 +823,7 @@ def test_rectangular_init_latents(
 
     rect_init_image = cutecat_init.crop((0, 0, width, height))
     x = sd15.init_latents((height, width), rect_init_image)
-    assert sd15.lda.decode_latents(x).size == (width, height)
+    assert sd15.lda.latents_to_image(x).size == (width, height)
 
 
 @no_grad()
@@ -853,7 +853,7 @@ def test_diffusion_inpainting(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     # PSNR and SSIM values are large because with float32 we get large differences even v.s. ourselves.
     ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95)
@@ -887,7 +887,7 @@ def test_diffusion_inpainting_float16(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     # PSNR and SSIM values are large because float16 is even worse than float32.
     ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
@@ -930,7 +930,7 @@ def test_diffusion_controlnet(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
@@ -973,7 +973,7 @@ def test_diffusion_controlnet_structural_copy(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
@@ -1015,7 +1015,7 @@ def test_diffusion_controlnet_float16(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
@@ -1069,7 +1069,7 @@ def test_diffusion_controlnet_stack(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
 
@@ -1101,7 +1101,7 @@ def test_diffusion_lora(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
@@ -1144,7 +1144,7 @@ def test_diffusion_sdxl_lora(
             condition_scale=guidance_scale,
         )
 
-    predicted_image = sdxl.lda.decode_latents(x)
+    predicted_image = sdxl.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
 
@@ -1192,7 +1192,7 @@ def test_diffusion_sdxl_multiple_loras(
             condition_scale=guidance_scale,
         )
 
-    predicted_image = sdxl.lda.decode_latents(x)
+    predicted_image = sdxl.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
 
@@ -1211,7 +1211,7 @@ def test_diffusion_refonly(
 
     refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()
 
-    guide = sd15.lda.encode_image(condition_image_refonly)
+    guide = sd15.lda.image_to_latents(condition_image_refonly)
     guide = torch.cat((guide, guide))
 
     manual_seed(2)
@@ -1228,7 +1228,7 @@ def test_diffusion_refonly(
             condition_scale=7.5,
         )
     torch.randn(2, 4, 64, 64, device=test_device)  # for SD Web UI reproductibility only
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     # min_psnr lowered to 33 because this reference image was generated without noise removal (see #192)
     ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=33, min_ssim=0.99)
@@ -1253,7 +1253,7 @@ def test_diffusion_inpainting_refonly(
     sd15.set_inference_steps(30)
     sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
 
-    guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
+    guide = sd15.lda.image_to_latents(scene_image_inpainting_refonly)
     guide = torch.cat((guide, guide))
 
     manual_seed(2)
@@ -1273,7 +1273,7 @@ def test_diffusion_inpainting_refonly(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)
 
@@ -1306,7 +1306,7 @@ def test_diffusion_textual_inversion_random_init(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
 
@@ -1351,7 +1351,7 @@ def test_diffusion_ip_adapter(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
 
@@ -1440,7 +1440,7 @@ def test_diffusion_sdxl_ip_adapter(
     # See https://huggingface.co/madebyollin/sdxl-vae-fp16-fix: "SDXL-VAE generates NaNs in fp16 because the
     # internal activation values are too big"
     sdxl.lda.to(dtype=torch.float32)
-    predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
+    predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))
 
     ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)
 
@@ -1496,7 +1496,7 @@ def test_diffusion_ip_adapter_controlnet(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet)
 
@@ -1537,7 +1537,7 @@ def test_diffusion_ip_adapter_plus(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98)
 
@@ -1584,7 +1584,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
             condition_scale=5,
         )
     sdxl.lda.to(dtype=torch.float32)
-    predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
+    predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))
 
     ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)
 
@@ -1618,7 +1618,7 @@ def test_sdxl_random_init(
             time_ids=time_ids,
             condition_scale=5,
         )
-    predicted_image = sdxl.lda.decode_latents(x=x)
+    predicted_image = sdxl.lda.latents_to_image(x=x)
 
     ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)
 
@@ -1653,7 +1653,7 @@ def test_sdxl_random_init_sag(
             time_ids=time_ids,
             condition_scale=5,
         )
-    predicted_image = sdxl.lda.decode_latents(x=x)
+    predicted_image = sdxl.lda.latents_to_image(x=x)
 
     ensure_similar_images(img_1=predicted_image, img_2=expected_image)
 
@@ -1766,7 +1766,7 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
             step=step,
             targets=[target_1, target_2],
         )
-    result = sd.lda.decode_latents(x=x)
+    result = sd.lda.latents_to_image(x=x)
 
     ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)
 
@@ -1805,7 +1805,7 @@ def test_t2i_adapter_depth(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image)
 
@@ -1853,7 +1853,7 @@ def test_t2i_adapter_xl_canny(
             time_ids=time_ids,
             condition_scale=7.5,
         )
-    predicted_image = sdxl.lda.decode_latents(x)
+    predicted_image = sdxl.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image)
 
@@ -1892,7 +1892,7 @@ def test_restart(
             condition_scale=8,
         )
 
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)
 
 
@@ -1924,7 +1924,7 @@ def test_freeu(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_freeu)
 
@@ -1980,6 +1980,6 @@ def test_hello_world(
             pooled_text_embedding=pooled_text_embedding,
             time_ids=time_ids,
         )
-    predicted_image = sdxl.lda.decode_latents(x)
+    predicted_image = sdxl.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image)
diff --git a/tests/foundationals/clip/test_text_encoder.py b/tests/foundationals/clip/test_text_encoder.py
index f1b6f07..4ea510a 100644
--- a/tests/foundationals/clip/test_text_encoder.py
+++ b/tests/foundationals/clip/test_text_encoder.py
@@ -101,3 +101,14 @@ def test_encoder(
     # numerical differences depending on the backend.
     # Also we use FP16 weights.
     assert (our_embeddings - ref_embeddings).abs().max() < 0.01
+
+
+def test_list_string_tokenizer(
+    prompt: str,
+    our_encoder: CLIPTextEncoderL,
+):
+    tokenizer = our_encoder.ensure_find(CLIPTokenizer)
+
+    # batched inputs
+    double_tokens = tokenizer([prompt, prompt[0:3]])
+    assert double_tokens.shape[0] == 2
diff --git a/tests/foundationals/latent_diffusion/test_auto_encoder.py b/tests/foundationals/latent_diffusion/test_auto_encoder.py
index 462c407..871135b 100644
--- a/tests/foundationals/latent_diffusion/test_auto_encoder.py
+++ b/tests/foundationals/latent_diffusion/test_auto_encoder.py
@@ -39,9 +39,9 @@ def sample_image(ref_path: Path) -> Image.Image:
 
 
 @no_grad()
-def test_encode_decode(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
-    encoded = encoder.encode_image(sample_image)
-    decoded = encoder.decode_latents(encoded)
+def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    encoded = encoder.image_to_latents(sample_image)
+    decoded = encoder.latents_to_image(encoded)
 
     assert decoded.mode == "RGB"
 
@@ -49,3 +49,12 @@ def test_encode_decode(encoder: LatentDiffusionAutoencoder, sample_image: Image.
     assert max(iter(decoded.getdata(band=1))) < 255  # type: ignore
 
     ensure_similar_images(sample_image, decoded, min_psnr=20, min_ssim=0.9)
+
+
+@no_grad()
+def test_encode_decode_images(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    encoded = encoder.images_to_latents([sample_image, sample_image])
+    images = encoder.latents_to_images(encoded)
+    assert isinstance(images, list)
+    assert len(images) == 2
+    ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9)
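
[Editor's note: a minimal usage sketch of the batched tokenizer behavior introduced by this patch, not part of the patch itself. It assumes `CLIPTokenizer` can be constructed with its default vocabulary and its default 77-token sequence length.]

    from refiners.foundationals.clip.tokenizer import CLIPTokenizer

    tokenizer = CLIPTokenizer()

    single = tokenizer("a cute cat")              # shape (1, 77)
    batched = tokenizer(["a cute cat", "a dog"])  # shape (2, 77), one row per prompt
    assert batched.shape[0] == 2
    assert single.shape[1] == batched.shape[1] == tokenizer.sequence_length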
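
[Editor's note: a sketch of the renamed autoencoder API, again an illustration rather than part of the patch. `image_to_latents`/`images_to_latents` replace `encode_image`, `latents_to_image`/`latents_to_images` replace `decode_latents`, and `decode_latents` remains as a backward-compatibility alias. The weights path and image file below are hypothetical, and loading via the `load_from_safetensors` method is an assumption about the surrounding API.]

    from PIL import Image

    from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder

    lda = LatentDiffusionAutoencoder()
    lda.load_from_safetensors("lda.safetensors")  # hypothetical weights path

    image = Image.open("cat.png").convert("RGB").resize((512, 512))
    latents = lda.images_to_latents([image, image])  # shape (2, 4, 64, 64)
    images = lda.latents_to_images(latents)          # list of 2 PIL images
    assert len(images) == 2

    # latents_to_image only accepts a batch of 1 and raises ValueError otherwise.
    single = lda.latents_to_image(latents[:1])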