clip text, lda encode batch inputs

* text_encoder([str1, str2]): the CLIP text encoder and tokenizer now accept a list of prompts
* lda: encode_image/decode_latents become image_to_latents/latents_to_image, plus batched images_to_latents/latents_to_images (decode_latents is kept as a backward-compatibility alias)
* new fluxion helpers: images_to_tensor, tensor_to_images
---------
Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
Colle 2024-02-01 15:05:43 +01:00 committed by Cédric Deltheil
parent 12aa0b23f6
commit 4a6146bb6c
9 changed files with 107 additions and 48 deletions
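In short, the new batched entry points compose as follows. This is an illustrative sketch, not part of the diff: sd15 stands for any loaded StableDiffusion_1 pipeline and the file names are placeholders.

from PIL import Image

# Two prompts in, one embedding tensor out (batch dimension of 2).
clip_text_embedding = sd15.clip_text_encoder(["a cute cat", "a cute dog"])

# Two images in, one latents tensor out: (2, 4, H/8, W/8) for the SD 1.5 autoencoder.
images = [Image.open("cat.png"), Image.open("dog.png")]
latents = sd15.lda.images_to_latents(images)

# And back: a list of two PIL images.
decoded = sd15.lda.latents_to_images(latents)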

View file

@@ -10,6 +10,7 @@ from safetensors import safe_open as _safe_open  # type: ignore
 from safetensors.torch import save_file as _save_file  # type: ignore
 from torch import (
     Tensor,
+    cat,
     device as Device,
     dtype as DType,
     manual_seed as _manual_seed,  # type: ignore
@@ -113,6 +114,12 @@ def gaussian_blur(
     return tensor


+def images_to_tensor(
+    images: list[Image.Image], device: Device | str | None = None, dtype: DType | None = None
+) -> Tensor:
+    return cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])
+
+
 def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
     """
     Convert a PIL Image to a Tensor.
@@ -135,6 +142,10 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp
     return image_tensor.unsqueeze(0)


+def tensor_to_images(tensor: Tensor) -> list[Image.Image]:
+    return [tensor_to_image(t) for t in tensor.split(1)]  # type: ignore
+
+
 def tensor_to_image(tensor: Tensor) -> Image.Image:
     """
     Convert a Tensor to a PIL Image.
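Usage sketch for the two new helpers above (assumes all images share one size, since the per-image tensors are concatenated along the batch dimension):

from PIL import Image
from refiners.fluxion.utils import images_to_tensor, tensor_to_images

images = [Image.new("RGB", (64, 64)), Image.new("RGB", (64, 64))]
batch = images_to_tensor(images)    # shape (2, 3, 64, 64), values in [0, 1]
restored = tensor_to_images(batch)  # back to a list of 2 PIL images
assert len(restored) == 2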

View file

@@ -4,7 +4,7 @@ from functools import lru_cache
 from itertools import islice
 from pathlib import Path

-from torch import Tensor, tensor
+from torch import Tensor, cat, tensor

 import refiners.fluxion.layers as fl
 from refiners.fluxion import pad
@@ -51,11 +51,19 @@ class CLIPTokenizer(fl.Module):
         self.end_of_text_token_id: int = end_of_text_token_id
         self.pad_token_id: int = pad_token_id

-    def forward(self, text: str) -> Tensor:
+    def forward(self, text: str | list[str]) -> Tensor:
+        if isinstance(text, str):
+            return self.tokenize_str(text)
+        else:
+            assert isinstance(text, list), f"Expected type `str` or `list[str]`, got {type(text)}"
+            return cat([self.tokenize_str(txt) for txt in text])
+
+    def tokenize_str(self, text: str) -> Tensor:
         tokens = self.encode(text=text, max_length=self.sequence_length).unsqueeze(dim=0)
         assert (
             tokens.shape[1] <= self.sequence_length
-        ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
+        ), f"Text is too long ({len(text)}): tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
         return pad(x=tokens, pad=(0, self.sequence_length - tokens.shape[1]), value=self.pad_token_id)

     @lru_cache()
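With this change the tokenizer pads every prompt to sequence_length and stacks the rows, so a list of prompts yields one row per prompt. A small sketch, assuming the default CLIPTokenizer configuration (sequence_length of 77):

from refiners.foundationals.clip.tokenizer import CLIPTokenizer

tokenizer = CLIPTokenizer()
tokens = tokenizer(["a photo of a cat", "a dog"])
assert tokens.shape == (2, 77)  # one padded row per prompt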

View file

@@ -15,7 +15,7 @@ from refiners.fluxion.layers import (
     Sum,
     Upsample,
 )
-from refiners.fluxion.utils import image_to_tensor, tensor_to_image
+from refiners.fluxion.utils import images_to_tensor, tensor_to_images


 class Resnet(Sum):
@@ -210,12 +210,25 @@ class LatentDiffusionAutoencoder(Chain):
         x = decoder(x / self.encoder_scale)
         return x

-    def encode_image(self, image: Image.Image) -> Tensor:
-        x = image_to_tensor(image, device=self.device, dtype=self.dtype)
+    def image_to_latents(self, image: Image.Image) -> Tensor:
+        return self.images_to_latents([image])
+
+    def images_to_latents(self, images: list[Image.Image]) -> Tensor:
+        x = images_to_tensor(images, device=self.device, dtype=self.dtype)
         x = 2 * x - 1
         return self.encode(x)

+    # backward-compatibility alias
     def decode_latents(self, x: Tensor) -> Image.Image:
+        return self.latents_to_image(x)
+
+    def latents_to_image(self, x: Tensor) -> Image.Image:
+        if x.shape[0] != 1:
+            raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")
+        return self.latents_to_images(x)[0]
+
+    def latents_to_images(self, x: Tensor) -> list[Image.Image]:
         x = self.decode(x)
         x = (x + 1) / 2
-        return tensor_to_image(x)
+        return tensor_to_images(x)
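Assuming lda is a loaded LatentDiffusionAutoencoder and image a PIL image, the renamed API behaves as follows (a sketch, not part of the diff):

latents = lda.image_to_latents(image)          # single image, batch of 1
batch = lda.images_to_latents([image, image])  # batched encode
images = lda.latents_to_images(batch)          # list of 2 PIL images
single = lda.latents_to_image(latents)         # raises ValueError unless batch size is 1
legacy = lda.decode_latents(latents)           # deprecated alias for latents_to_image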

View file

@@ -50,7 +50,7 @@ class LatentDiffusionModel(fl.Module, ABC):
         ], f"noise shape is not compatible: {noise.shape}, with size: {size}"
         if init_image is None:
             return noise
-        encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
+        encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height)))
         return self.solver.add_noise(
             x=encoded_image,
             noise=noise,

View file

@@ -83,8 +83,15 @@ class MultiDiffusion(Generic[T, D], ABC):
     def dtype(self) -> DType:
         return self.ldm.dtype

+    # backward-compatibility alias
     def decode_latents(self, x: Tensor) -> Image.Image:
-        return self.ldm.lda.decode_latents(x=x)
+        return self.latents_to_image(x=x)
+
+    def latents_to_image(self, x: Tensor) -> Image.Image:
+        return self.ldm.lda.latents_to_image(x=x)
+
+    def latents_to_images(self, x: Tensor) -> list[Image.Image]:
+        return self.ldm.lda.latents_to_images(x=x)

     @staticmethod
     def generate_offset_grid(size: tuple[int, int], stride: int = 8) -> list[tuple[int, int]]:

View file

@@ -113,7 +113,7 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
             max_size=self.config.dataset.resize_image_max_size,
         )
         processed_image = self.process_image(resized_image)
-        latents = self.lda.encode_image(image=processed_image).to(device=self.device)
+        latents = self.lda.image_to_latents(image=processed_image).to(device=self.device)
         processed_caption = self.process_caption(caption=caption)
         clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
         return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)
@@ -202,7 +202,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
                     step=step,
                     clip_text_embedding=clip_text_embedding,
                 )
-            canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i))
+            canvas_image.paste(sd.lda.latents_to_image(x=x), box=(0, 512 * i))
             images[prompt] = canvas_image
         self.log(data=images)

View file

@@ -666,7 +666,7 @@ def test_diffusion_std_random_init(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image_std_random_init)
@@ -696,7 +696,7 @@ def test_diffusion_std_random_init_euler(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image_std_random_init_euler)
@@ -721,7 +721,7 @@ def test_diffusion_karras_random_init(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98)
@@ -749,7 +749,7 @@ def test_diffusion_std_random_init_float16(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)
@@ -777,7 +777,7 @@ def test_diffusion_std_random_init_sag(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image_std_random_init_sag)
@@ -806,7 +806,7 @@ def test_diffusion_std_init_image(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image_std_init_image)
@@ -823,7 +823,7 @@ def test_rectangular_init_latents(
     rect_init_image = cutecat_init.crop((0, 0, width, height))
     x = sd15.init_latents((height, width), rect_init_image)
-    assert sd15.lda.decode_latents(x).size == (width, height)
+    assert sd15.lda.latents_to_image(x).size == (width, height)


 @no_grad()
@@ -853,7 +853,7 @@ def test_diffusion_inpainting(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     # PSNR and SSIM values are large because with float32 we get large differences even v.s. ourselves.
     ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95)
@@ -887,7 +887,7 @@ def test_diffusion_inpainting_float16(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     # PSNR and SSIM values are large because float16 is even worse than float32.
     ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
@@ -930,7 +930,7 @@ def test_diffusion_controlnet(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@@ -973,7 +973,7 @@ def test_diffusion_controlnet_structural_copy(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@@ -1015,7 +1015,7 @@ def test_diffusion_controlnet_float16(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@@ -1069,7 +1069,7 @@ def test_diffusion_controlnet_stack(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
@@ -1101,7 +1101,7 @@ def test_diffusion_lora(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@@ -1144,7 +1144,7 @@ def test_diffusion_sdxl_lora(
             condition_scale=guidance_scale,
         )
-    predicted_image = sdxl.lda.decode_latents(x)
+    predicted_image = sdxl.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@@ -1192,7 +1192,7 @@ def test_diffusion_sdxl_multiple_loras(
             condition_scale=guidance_scale,
         )
-    predicted_image = sdxl.lda.decode_latents(x)
+    predicted_image = sdxl.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@@ -1211,7 +1211,7 @@ def test_diffusion_refonly(
     refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()

-    guide = sd15.lda.encode_image(condition_image_refonly)
+    guide = sd15.lda.image_to_latents(condition_image_refonly)
     guide = torch.cat((guide, guide))

     manual_seed(2)
@@ -1228,7 +1228,7 @@ def test_diffusion_refonly(
            condition_scale=7.5,
        )
        torch.randn(2, 4, 64, 64, device=test_device)  # for SD Web UI reproductibility only
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     # min_psnr lowered to 33 because this reference image was generated without noise removal (see #192)
     ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=33, min_ssim=0.99)
@@ -1253,7 +1253,7 @@ def test_diffusion_inpainting_refonly(
     sd15.set_inference_steps(30)
     sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)

-    guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
+    guide = sd15.lda.image_to_latents(scene_image_inpainting_refonly)
     guide = torch.cat((guide, guide))

     manual_seed(2)
@@ -1273,7 +1273,7 @@ def test_diffusion_inpainting_refonly(
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)
@@ -1306,7 +1306,7 @@ def test_diffusion_textual_inversion_random_init(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
@@ -1351,7 +1351,7 @@ def test_diffusion_ip_adapter(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
@@ -1440,7 +1440,7 @@ def test_diffusion_sdxl_ip_adapter(
     # See https://huggingface.co/madebyollin/sdxl-vae-fp16-fix: "SDXL-VAE generates NaNs in fp16 because the
     # internal activation values are too big"
     sdxl.lda.to(dtype=torch.float32)
-    predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
+    predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))
     ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)
@@ -1496,7 +1496,7 @@ def test_diffusion_ip_adapter_controlnet(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet)
@@ -1537,7 +1537,7 @@ def test_diffusion_ip_adapter_plus(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98)
@@ -1584,7 +1584,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
             condition_scale=5,
         )
     sdxl.lda.to(dtype=torch.float32)
-    predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
+    predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))
     ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)
@@ -1618,7 +1618,7 @@ def test_sdxl_random_init(
             time_ids=time_ids,
             condition_scale=5,
         )
-    predicted_image = sdxl.lda.decode_latents(x=x)
+    predicted_image = sdxl.lda.latents_to_image(x=x)
     ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)
@@ -1653,7 +1653,7 @@ def test_sdxl_random_init_sag(
             time_ids=time_ids,
             condition_scale=5,
         )
-    predicted_image = sdxl.lda.decode_latents(x=x)
+    predicted_image = sdxl.lda.latents_to_image(x=x)
     ensure_similar_images(img_1=predicted_image, img_2=expected_image)
@@ -1766,7 +1766,7 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
             step=step,
             targets=[target_1, target_2],
         )
-    result = sd.lda.decode_latents(x=x)
+    result = sd.lda.latents_to_image(x=x)
     ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)
@@ -1805,7 +1805,7 @@ def test_t2i_adapter_depth(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image)
@@ -1853,7 +1853,7 @@ def test_t2i_adapter_xl_canny(
             time_ids=time_ids,
             condition_scale=7.5,
         )
-    predicted_image = sdxl.lda.decode_latents(x)
+    predicted_image = sdxl.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image)
@@ -1892,7 +1892,7 @@ def test_restart(
            condition_scale=8,
        )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)
@@ -1924,7 +1924,7 @@ def test_freeu(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_freeu)
@@ -1980,6 +1980,6 @@ def test_hello_world(
             pooled_text_embedding=pooled_text_embedding,
             time_ids=time_ids,
         )
-    predicted_image = sdxl.lda.decode_latents(x)
+    predicted_image = sdxl.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image)

View file

@@ -101,3 +101,14 @@ def test_encoder(
     # numerical differences depending on the backend.
     # Also we use FP16 weights.
     assert (our_embeddings - ref_embeddings).abs().max() < 0.01
+
+
+def test_list_string_tokenizer(
+    prompt: str,
+    our_encoder: CLIPTextEncoderL,
+):
+    tokenizer = our_encoder.ensure_find(CLIPTokenizer)
+
+    # batched inputs
+    double_tokens = tokenizer([prompt, prompt[0:3]])
+    assert double_tokens.shape[0] == 2

View file

@@ -39,9 +39,9 @@ def sample_image(ref_path: Path) -> Image.Image:
 @no_grad()
-def test_encode_decode(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
-    encoded = encoder.encode_image(sample_image)
-    decoded = encoder.decode_latents(encoded)
+def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    encoded = encoder.image_to_latents(sample_image)
+    decoded = encoder.latents_to_image(encoded)

     assert decoded.mode == "RGB"
@@ -49,3 +49,12 @@ def test_encode_decode(encoder: LatentDiffusionAutoencoder, sample_image: Image.
     assert max(iter(decoded.getdata(band=1))) < 255  # type: ignore

     ensure_similar_images(sample_image, decoded, min_psnr=20, min_ssim=0.9)
+
+
+@no_grad()
+def test_encode_decode_images(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    encoded = encoder.images_to_latents([sample_image, sample_image])
+    images = encoder.latents_to_images(encoded)
+    assert isinstance(images, list)
+    assert len(images) == 2
+    ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9)