clip text, lda encode batch inputs

* text_encoder([str1, str2]): the CLIP text encoder now accepts a list of prompts and returns batched embeddings
* lda: encode_image/decode_latents become image_to_latents/latents_to_image, with batched variants images_to_latents/latents_to_images (decode_latents is kept as a backward-compatibility alias)
* new fluxion utils: images_to_tensor, tensor_to_images
---------
Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
Authored by Colle on 2024-02-01 15:05:43 +01:00, committed by Cédric Deltheil
parent 12aa0b23f6
commit 4a6146bb6c
9 changed files with 107 additions and 48 deletions
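
As a usage sketch of the new batched entry points (not part of the diff; assumes a StableDiffusion_1 with properly loaded weights, and uses the model's clip_text_encoder and lda attributes):

import torch
from PIL import Image
from refiners.foundationals.latent_diffusion import StableDiffusion_1

sd15 = StableDiffusion_1(device="cpu", dtype=torch.float32)  # weights loading omitted

# CLIP text encoder: a list of prompts yields embeddings with a batch dimension of 2
embeddings = sd15.clip_text_encoder(["a cute cat", "a cute dog"])

# LDA: PIL images in, batched latents out, and back again
images = [Image.new("RGB", (512, 512)), Image.new("RGB", (512, 512))]
latents = sd15.lda.images_to_latents(images)   # shape: (2, 4, 64, 64)
decoded = sd15.lda.latents_to_images(latents)  # list of 2 PIL images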

View file

@@ -10,6 +10,7 @@ from safetensors import safe_open as _safe_open  # type: ignore
 from safetensors.torch import save_file as _save_file  # type: ignore
 from torch import (
     Tensor,
+    cat,
     device as Device,
     dtype as DType,
     manual_seed as _manual_seed,  # type: ignore
@@ -113,6 +114,12 @@ def gaussian_blur(
     return tensor


+def images_to_tensor(
+    images: list[Image.Image], device: Device | str | None = None, dtype: DType | None = None
+) -> Tensor:
+    return cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])
+
+
 def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
     """
     Convert a PIL Image to a Tensor.
@@ -135,6 +142,10 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp
     return image_tensor.unsqueeze(0)


+def tensor_to_images(tensor: Tensor) -> list[Image.Image]:
+    return [tensor_to_image(t) for t in tensor.split(1)]  # type: ignore
+
+
 def tensor_to_image(tensor: Tensor) -> Image.Image:
     """
     Convert a Tensor to a PIL Image.
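
A quick round trip through the two new helpers (sketch; the images must share the same size, since cat requires matching shapes):

from PIL import Image
from refiners.fluxion.utils import images_to_tensor, tensor_to_images

images = [Image.new("RGB", (64, 64), "red"), Image.new("RGB", (64, 64), "blue")]
batch = images_to_tensor(images)    # shape: (2, 3, 64, 64), values in [0, 1]
restored = tensor_to_images(batch)  # back to a list of 2 PIL images
assert len(restored) == 2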

View file

@@ -4,7 +4,7 @@ from functools import lru_cache
 from itertools import islice
 from pathlib import Path

-from torch import Tensor, tensor
+from torch import Tensor, cat, tensor

 import refiners.fluxion.layers as fl
 from refiners.fluxion import pad
@@ -51,11 +51,19 @@ class CLIPTokenizer(fl.Module):
         self.end_of_text_token_id: int = end_of_text_token_id
         self.pad_token_id: int = pad_token_id

-    def forward(self, text: str) -> Tensor:
+    def forward(self, text: str | list[str]) -> Tensor:
+        if isinstance(text, str):
+            return self.tokenize_str(text)
+        else:
+            assert isinstance(text, list), f"Expected type `str` or `list[str]`, got {type(text)}"
+            return cat([self.tokenize_str(txt) for txt in text])
+
+    def tokenize_str(self, text: str) -> Tensor:
         tokens = self.encode(text=text, max_length=self.sequence_length).unsqueeze(dim=0)
         assert (
             tokens.shape[1] <= self.sequence_length
-        ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
+        ), f"Text is too long ({len(text)}): tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
         return pad(x=tokens, pad=(0, self.sequence_length - tokens.shape[1]), value=self.pad_token_id)

     @lru_cache()
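
The tokenizer keeps its single-string behavior and gains a batched path; both outputs are padded to sequence_length (sketch; assumes the bundled BPE vocabulary and the default sequence length of 77):

from refiners.foundationals.clip.tokenizer import CLIPTokenizer

tokenizer = CLIPTokenizer()
single = tokenizer("a photo of a cat")         # shape: (1, 77)
batch = tokenizer(["a photo of a cat", "hi"])  # shape: (2, 77), each row padded with pad_token_id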

View file

@@ -15,7 +15,7 @@ from refiners.fluxion.layers import (
     Sum,
     Upsample,
 )
-from refiners.fluxion.utils import image_to_tensor, tensor_to_image
+from refiners.fluxion.utils import images_to_tensor, tensor_to_images


 class Resnet(Sum):
@@ -210,12 +210,25 @@ class LatentDiffusionAutoencoder(Chain):
         x = decoder(x / self.encoder_scale)
         return x

-    def encode_image(self, image: Image.Image) -> Tensor:
-        x = image_to_tensor(image, device=self.device, dtype=self.dtype)
+    def image_to_latents(self, image: Image.Image) -> Tensor:
+        return self.images_to_latents([image])
+
+    def images_to_latents(self, images: list[Image.Image]) -> Tensor:
+        x = images_to_tensor(images, device=self.device, dtype=self.dtype)
         x = 2 * x - 1
         return self.encode(x)

+    # backward-compatibility alias
     def decode_latents(self, x: Tensor) -> Image.Image:
+        return self.latents_to_image(x)
+
+    def latents_to_image(self, x: Tensor) -> Image.Image:
+        if x.shape[0] != 1:
+            raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")
+        return self.latents_to_images(x)[0]
+
+    def latents_to_images(self, x: Tensor) -> list[Image.Image]:
         x = self.decode(x)
         x = (x + 1) / 2
-        return tensor_to_image(x)
+        return tensor_to_images(x)
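
The autoencoder thus exposes symmetric single and batched conversions, with decode_latents preserved for existing callers (sketch; a real run needs converted VAE weights, and the 8x spatial downscale is the usual SD 1.5 setup):

from PIL import Image
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder

lda = LatentDiffusionAutoencoder()
image = Image.new("RGB", (512, 512))
latents = lda.images_to_latents([image, image])  # shape: (2, 4, 64, 64)
images = lda.latents_to_images(latents)          # list of 2 PIL images
one = lda.latents_to_image(latents[:1])          # batch size must be 1, else ValueError
legacy = lda.decode_latents(latents[:1])         # backward-compatibility alias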

View file

@@ -50,7 +50,7 @@ class LatentDiffusionModel(fl.Module, ABC):
         ], f"noise shape is not compatible: {noise.shape}, with size: {size}"
         if init_image is None:
             return noise
-        encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
+        encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height)))
         return self.solver.add_noise(
             x=encoded_image,
             noise=noise,
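
Image-to-image initialization is unchanged for callers: init_latents resizes the init image, encodes it via the renamed image_to_latents, and noises the result (sketch; assumes sd15 is a loaded StableDiffusion_1 and init_image is a PIL image):

x = sd15.init_latents(size=(512, 512), init_image=init_image)  # noised latents, shape (1, 4, 64, 64)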

View file

@@ -83,8 +83,15 @@ class MultiDiffusion(Generic[T, D], ABC):
     def dtype(self) -> DType:
         return self.ldm.dtype

+    # backward-compatibility alias
     def decode_latents(self, x: Tensor) -> Image.Image:
-        return self.ldm.lda.decode_latents(x=x)
+        return self.latents_to_image(x=x)
+
+    def latents_to_image(self, x: Tensor) -> Image.Image:
+        return self.ldm.lda.latents_to_image(x=x)
+
+    def latents_to_images(self, x: Tensor) -> list[Image.Image]:
+        return self.ldm.lda.latents_to_images(x=x)

     @staticmethod
     def generate_offset_grid(size: tuple[int, int], stride: int = 8) -> list[tuple[int, int]]:

View file

@@ -113,7 +113,7 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
             max_size=self.config.dataset.resize_image_max_size,
         )
         processed_image = self.process_image(resized_image)
-        latents = self.lda.encode_image(image=processed_image).to(device=self.device)
+        latents = self.lda.image_to_latents(image=processed_image).to(device=self.device)
         processed_caption = self.process_caption(caption=caption)
         clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
         return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)
@@ -202,7 +202,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
                     step=step,
                     clip_text_embedding=clip_text_embedding,
                 )
-            canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i))
+            canvas_image.paste(sd.lda.latents_to_image(x=x), box=(0, 512 * i))
             images[prompt] = canvas_image
         self.log(data=images)

View file

@@ -666,7 +666,7 @@ def test_diffusion_std_random_init(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image_std_random_init)
@@ -696,7 +696,7 @@ def test_diffusion_std_random_init_euler(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image_std_random_init_euler)
@@ -721,7 +721,7 @@ def test_diffusion_karras_random_init(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98)
@@ -749,7 +749,7 @@ def test_diffusion_std_random_init_float16(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)
@@ -777,7 +777,7 @@ def test_diffusion_std_random_init_sag(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image_std_random_init_sag)
@@ -806,7 +806,7 @@ def test_diffusion_std_init_image(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image_std_init_image)
@@ -823,7 +823,7 @@ def test_rectangular_init_latents(
     rect_init_image = cutecat_init.crop((0, 0, width, height))
     x = sd15.init_latents((height, width), rect_init_image)
-    assert sd15.lda.decode_latents(x).size == (width, height)
+    assert sd15.lda.latents_to_image(x).size == (width, height)


 @no_grad()
@@ -853,7 +853,7 @@ def test_diffusion_inpainting(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     # PSNR and SSIM values are large because with float32 we get large differences even v.s. ourselves.
     ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95)
@@ -887,7 +887,7 @@ def test_diffusion_inpainting_float16(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     # PSNR and SSIM values are large because float16 is even worse than float32.
     ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
@@ -930,7 +930,7 @@ def test_diffusion_controlnet(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@@ -973,7 +973,7 @@ def test_diffusion_controlnet_structural_copy(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@@ -1015,7 +1015,7 @@ def test_diffusion_controlnet_float16(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@@ -1069,7 +1069,7 @@ def test_diffusion_controlnet_stack(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
@@ -1101,7 +1101,7 @@ def test_diffusion_lora(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@@ -1144,7 +1144,7 @@ def test_diffusion_sdxl_lora(
             condition_scale=guidance_scale,
         )
-    predicted_image = sdxl.lda.decode_latents(x)
+    predicted_image = sdxl.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@@ -1192,7 +1192,7 @@ def test_diffusion_sdxl_multiple_loras(
             condition_scale=guidance_scale,
         )
-    predicted_image = sdxl.lda.decode_latents(x)
+    predicted_image = sdxl.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@@ -1211,7 +1211,7 @@ def test_diffusion_refonly(
     refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()

-    guide = sd15.lda.encode_image(condition_image_refonly)
+    guide = sd15.lda.image_to_latents(condition_image_refonly)
     guide = torch.cat((guide, guide))

     manual_seed(2)
@@ -1228,7 +1228,7 @@ def test_diffusion_refonly(
             condition_scale=7.5,
         )
         torch.randn(2, 4, 64, 64, device=test_device)  # for SD Web UI reproductibility only
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     # min_psnr lowered to 33 because this reference image was generated without noise removal (see #192)
     ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=33, min_ssim=0.99)
@@ -1253,7 +1253,7 @@ def test_diffusion_inpainting_refonly(
     sd15.set_inference_steps(30)
     sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)

-    guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
+    guide = sd15.lda.image_to_latents(scene_image_inpainting_refonly)
     guide = torch.cat((guide, guide))

     manual_seed(2)
@@ -1273,7 +1273,7 @@ def test_diffusion_inpainting_refonly(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)
@@ -1306,7 +1306,7 @@ def test_diffusion_textual_inversion_random_init(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
@@ -1351,7 +1351,7 @@ def test_diffusion_ip_adapter(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
@@ -1440,7 +1440,7 @@ def test_diffusion_sdxl_ip_adapter(
     # See https://huggingface.co/madebyollin/sdxl-vae-fp16-fix: "SDXL-VAE generates NaNs in fp16 because the
     # internal activation values are too big"
     sdxl.lda.to(dtype=torch.float32)
-    predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
+    predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))

     ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)
@@ -1496,7 +1496,7 @@ def test_diffusion_ip_adapter_controlnet(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet)
@@ -1537,7 +1537,7 @@ def test_diffusion_ip_adapter_plus(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98)
@@ -1584,7 +1584,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
             condition_scale=5,
         )
     sdxl.lda.to(dtype=torch.float32)
-    predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
+    predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))

     ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)
@@ -1618,7 +1618,7 @@ def test_sdxl_random_init(
             time_ids=time_ids,
             condition_scale=5,
         )
-    predicted_image = sdxl.lda.decode_latents(x=x)
+    predicted_image = sdxl.lda.latents_to_image(x=x)

     ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)
@@ -1653,7 +1653,7 @@ def test_sdxl_random_init_sag(
             time_ids=time_ids,
             condition_scale=5,
         )
-    predicted_image = sdxl.lda.decode_latents(x=x)
+    predicted_image = sdxl.lda.latents_to_image(x=x)

     ensure_similar_images(img_1=predicted_image, img_2=expected_image)
@@ -1766,7 +1766,7 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
             step=step,
             targets=[target_1, target_2],
         )
-    result = sd.lda.decode_latents(x=x)
+    result = sd.lda.latents_to_image(x=x)

     ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)
@@ -1805,7 +1805,7 @@ def test_t2i_adapter_depth(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image)
@@ -1853,7 +1853,7 @@ def test_t2i_adapter_xl_canny(
             time_ids=time_ids,
             condition_scale=7.5,
         )
-    predicted_image = sdxl.lda.decode_latents(x)
+    predicted_image = sdxl.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_image)
@@ -1892,7 +1892,7 @@ def test_restart(
             condition_scale=8,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)
@@ -1924,7 +1924,7 @@ def test_freeu(
             clip_text_embedding=clip_text_embedding,
             condition_scale=7.5,
         )
-    predicted_image = sd15.lda.decode_latents(x)
+    predicted_image = sd15.lda.latents_to_image(x)

     ensure_similar_images(predicted_image, expected_freeu)
@@ -1980,6 +1980,6 @@ def test_hello_world(
             pooled_text_embedding=pooled_text_embedding,
             time_ids=time_ids,
         )
-    predicted_image = sdxl.lda.decode_latents(x)
+    predicted_image = sdxl.lda.latents_to_image(x)
     ensure_similar_images(predicted_image, expected_image)

View file

@@ -101,3 +101,14 @@ def test_encoder(
     # numerical differences depending on the backend.
     # Also we use FP16 weights.
     assert (our_embeddings - ref_embeddings).abs().max() < 0.01
+
+
+def test_list_string_tokenizer(
+    prompt: str,
+    our_encoder: CLIPTextEncoderL,
+):
+    tokenizer = our_encoder.ensure_find(CLIPTokenizer)
+
+    # batched inputs
+    double_tokens = tokenizer([prompt, prompt[0:3]])
+    assert double_tokens.shape[0] == 2

View file

@@ -39,9 +39,9 @@ def sample_image(ref_path: Path) -> Image.Image:

 @no_grad()
-def test_encode_decode(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
-    encoded = encoder.encode_image(sample_image)
-    decoded = encoder.decode_latents(encoded)
+def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    encoded = encoder.image_to_latents(sample_image)
+    decoded = encoder.latents_to_image(encoded)

     assert decoded.mode == "RGB"
@@ -49,3 +49,12 @@ def test_encode_decode(encoder: LatentDiffusionAutoencoder, sample_image: Image.
     assert max(iter(decoded.getdata(band=1))) < 255  # type: ignore

     ensure_similar_images(sample_image, decoded, min_psnr=20, min_ssim=0.9)
+
+
+@no_grad()
+def test_encode_decode_images(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    encoded = encoder.images_to_latents([sample_image, sample_image])
+    images = encoder.latents_to_images(encoded)
+    assert isinstance(images, list)
+    assert len(images) == 2
+    ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9)