mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
clip text, lda encode batch inputs
* text_encoder([str1, str2]) * lda decode_latents/encode_image image_to_latent/latent_to_image * images_to_tensor, tensor_to_images --------- Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
This commit is contained in:
parent
12aa0b23f6
commit
4a6146bb6c
|
@ -10,6 +10,7 @@ from safetensors import safe_open as _safe_open # type: ignore
|
|||
from safetensors.torch import save_file as _save_file # type: ignore
|
||||
from torch import (
|
||||
Tensor,
|
||||
cat,
|
||||
device as Device,
|
||||
dtype as DType,
|
||||
manual_seed as _manual_seed, # type: ignore
|
||||
|
@ -113,6 +114,12 @@ def gaussian_blur(
|
|||
return tensor
|
||||
|
||||
|
||||
def images_to_tensor(
|
||||
images: list[Image.Image], device: Device | str | None = None, dtype: DType | None = None
|
||||
) -> Tensor:
|
||||
return cat([image_to_tensor(image, device=device, dtype=dtype) for image in images])
|
||||
|
||||
|
||||
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
||||
"""
|
||||
Convert a PIL Image to a Tensor.
|
||||
|
@ -135,6 +142,10 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp
|
|||
return image_tensor.unsqueeze(0)
|
||||
|
||||
|
||||
def tensor_to_images(tensor: Tensor) -> list[Image.Image]:
|
||||
return [tensor_to_image(t) for t in tensor.split(1)] # type: ignore
|
||||
|
||||
|
||||
def tensor_to_image(tensor: Tensor) -> Image.Image:
|
||||
"""
|
||||
Convert a Tensor to a PIL Image.
|
||||
|
|
|
@ -4,7 +4,7 @@ from functools import lru_cache
|
|||
from itertools import islice
|
||||
from pathlib import Path
|
||||
|
||||
from torch import Tensor, tensor
|
||||
from torch import Tensor, cat, tensor
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion import pad
|
||||
|
@ -51,11 +51,19 @@ class CLIPTokenizer(fl.Module):
|
|||
self.end_of_text_token_id: int = end_of_text_token_id
|
||||
self.pad_token_id: int = pad_token_id
|
||||
|
||||
def forward(self, text: str) -> Tensor:
|
||||
def forward(self, text: str | list[str]) -> Tensor:
|
||||
if isinstance(text, str):
|
||||
return self.tokenize_str(text)
|
||||
else:
|
||||
assert isinstance(text, list), f"Expected type `str` or `list[str]`, got {type(text)}"
|
||||
return cat([self.tokenize_str(txt) for txt in text])
|
||||
|
||||
def tokenize_str(self, text: str) -> Tensor:
|
||||
tokens = self.encode(text=text, max_length=self.sequence_length).unsqueeze(dim=0)
|
||||
|
||||
assert (
|
||||
tokens.shape[1] <= self.sequence_length
|
||||
), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
|
||||
), f"Text is too long ({len(text)}): tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
|
||||
return pad(x=tokens, pad=(0, self.sequence_length - tokens.shape[1]), value=self.pad_token_id)
|
||||
|
||||
@lru_cache()
|
||||
|
|
|
@ -15,7 +15,7 @@ from refiners.fluxion.layers import (
|
|||
Sum,
|
||||
Upsample,
|
||||
)
|
||||
from refiners.fluxion.utils import image_to_tensor, tensor_to_image
|
||||
from refiners.fluxion.utils import images_to_tensor, tensor_to_images
|
||||
|
||||
|
||||
class Resnet(Sum):
|
||||
|
@ -210,12 +210,25 @@ class LatentDiffusionAutoencoder(Chain):
|
|||
x = decoder(x / self.encoder_scale)
|
||||
return x
|
||||
|
||||
def encode_image(self, image: Image.Image) -> Tensor:
|
||||
x = image_to_tensor(image, device=self.device, dtype=self.dtype)
|
||||
def image_to_latents(self, image: Image.Image) -> Tensor:
|
||||
return self.images_to_latents([image])
|
||||
|
||||
def images_to_latents(self, images: list[Image.Image]) -> Tensor:
|
||||
x = images_to_tensor(images, device=self.device, dtype=self.dtype)
|
||||
x = 2 * x - 1
|
||||
return self.encode(x)
|
||||
|
||||
# backward-compatibility alias
|
||||
def decode_latents(self, x: Tensor) -> Image.Image:
|
||||
return self.latents_to_image(x)
|
||||
|
||||
def latents_to_image(self, x: Tensor) -> Image.Image:
|
||||
if x.shape[0] != 1:
|
||||
raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")
|
||||
|
||||
return self.latents_to_images(x)[0]
|
||||
|
||||
def latents_to_images(self, x: Tensor) -> list[Image.Image]:
|
||||
x = self.decode(x)
|
||||
x = (x + 1) / 2
|
||||
return tensor_to_image(x)
|
||||
return tensor_to_images(x)
|
||||
|
|
|
@ -50,7 +50,7 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
], f"noise shape is not compatible: {noise.shape}, with size: {size}"
|
||||
if init_image is None:
|
||||
return noise
|
||||
encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
|
||||
encoded_image = self.lda.image_to_latents(image=init_image.resize(size=(width, height)))
|
||||
return self.solver.add_noise(
|
||||
x=encoded_image,
|
||||
noise=noise,
|
||||
|
|
|
@ -83,8 +83,15 @@ class MultiDiffusion(Generic[T, D], ABC):
|
|||
def dtype(self) -> DType:
|
||||
return self.ldm.dtype
|
||||
|
||||
# backward-compatibility alias
|
||||
def decode_latents(self, x: Tensor) -> Image.Image:
|
||||
return self.ldm.lda.decode_latents(x=x)
|
||||
return self.latents_to_image(x=x)
|
||||
|
||||
def latents_to_image(self, x: Tensor) -> Image.Image:
|
||||
return self.ldm.lda.latents_to_image(x=x)
|
||||
|
||||
def latents_to_images(self, x: Tensor) -> list[Image.Image]:
|
||||
return self.ldm.lda.latents_to_images(x=x)
|
||||
|
||||
@staticmethod
|
||||
def generate_offset_grid(size: tuple[int, int], stride: int = 8) -> list[tuple[int, int]]:
|
||||
|
|
|
@ -113,7 +113,7 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
|
|||
max_size=self.config.dataset.resize_image_max_size,
|
||||
)
|
||||
processed_image = self.process_image(resized_image)
|
||||
latents = self.lda.encode_image(image=processed_image).to(device=self.device)
|
||||
latents = self.lda.image_to_latents(image=processed_image).to(device=self.device)
|
||||
processed_caption = self.process_caption(caption=caption)
|
||||
clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
|
||||
return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)
|
||||
|
@ -202,7 +202,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
|||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
)
|
||||
canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i))
|
||||
canvas_image.paste(sd.lda.latents_to_image(x=x), box=(0, 512 * i))
|
||||
images[prompt] = canvas_image
|
||||
self.log(data=images)
|
||||
|
||||
|
|
|
@ -666,7 +666,7 @@ def test_diffusion_std_random_init(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_std_random_init)
|
||||
|
||||
|
@ -696,7 +696,7 @@ def test_diffusion_std_random_init_euler(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_std_random_init_euler)
|
||||
|
||||
|
@ -721,7 +721,7 @@ def test_diffusion_karras_random_init(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
@ -749,7 +749,7 @@ def test_diffusion_std_random_init_float16(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
@ -777,7 +777,7 @@ def test_diffusion_std_random_init_sag(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_std_random_init_sag)
|
||||
|
||||
|
@ -806,7 +806,7 @@ def test_diffusion_std_init_image(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_std_init_image)
|
||||
|
||||
|
@ -823,7 +823,7 @@ def test_rectangular_init_latents(
|
|||
rect_init_image = cutecat_init.crop((0, 0, width, height))
|
||||
x = sd15.init_latents((height, width), rect_init_image)
|
||||
|
||||
assert sd15.lda.decode_latents(x).size == (width, height)
|
||||
assert sd15.lda.latents_to_image(x).size == (width, height)
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
@ -853,7 +853,7 @@ def test_diffusion_inpainting(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
# PSNR and SSIM values are large because with float32 we get large differences even v.s. ourselves.
|
||||
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95)
|
||||
|
@ -887,7 +887,7 @@ def test_diffusion_inpainting_float16(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
# PSNR and SSIM values are large because float16 is even worse than float32.
|
||||
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
|
||||
|
@ -930,7 +930,7 @@ def test_diffusion_controlnet(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
@ -973,7 +973,7 @@ def test_diffusion_controlnet_structural_copy(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
@ -1015,7 +1015,7 @@ def test_diffusion_controlnet_float16(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
@ -1069,7 +1069,7 @@ def test_diffusion_controlnet_stack(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
@ -1101,7 +1101,7 @@ def test_diffusion_lora(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
@ -1144,7 +1144,7 @@ def test_diffusion_sdxl_lora(
|
|||
condition_scale=guidance_scale,
|
||||
)
|
||||
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
predicted_image = sdxl.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
@ -1192,7 +1192,7 @@ def test_diffusion_sdxl_multiple_loras(
|
|||
condition_scale=guidance_scale,
|
||||
)
|
||||
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
predicted_image = sdxl.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
@ -1211,7 +1211,7 @@ def test_diffusion_refonly(
|
|||
|
||||
refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()
|
||||
|
||||
guide = sd15.lda.encode_image(condition_image_refonly)
|
||||
guide = sd15.lda.image_to_latents(condition_image_refonly)
|
||||
guide = torch.cat((guide, guide))
|
||||
|
||||
manual_seed(2)
|
||||
|
@ -1228,7 +1228,7 @@ def test_diffusion_refonly(
|
|||
condition_scale=7.5,
|
||||
)
|
||||
torch.randn(2, 4, 64, 64, device=test_device) # for SD Web UI reproductibility only
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
# min_psnr lowered to 33 because this reference image was generated without noise removal (see #192)
|
||||
ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=33, min_ssim=0.99)
|
||||
|
@ -1253,7 +1253,7 @@ def test_diffusion_inpainting_refonly(
|
|||
sd15.set_inference_steps(30)
|
||||
sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
|
||||
|
||||
guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
|
||||
guide = sd15.lda.image_to_latents(scene_image_inpainting_refonly)
|
||||
guide = torch.cat((guide, guide))
|
||||
|
||||
manual_seed(2)
|
||||
|
@ -1273,7 +1273,7 @@ def test_diffusion_inpainting_refonly(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)
|
||||
|
||||
|
@ -1306,7 +1306,7 @@ def test_diffusion_textual_inversion_random_init(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
@ -1351,7 +1351,7 @@ def test_diffusion_ip_adapter(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
|
||||
|
||||
|
@ -1440,7 +1440,7 @@ def test_diffusion_sdxl_ip_adapter(
|
|||
# See https://huggingface.co/madebyollin/sdxl-vae-fp16-fix: "SDXL-VAE generates NaNs in fp16 because the
|
||||
# internal activation values are too big"
|
||||
sdxl.lda.to(dtype=torch.float32)
|
||||
predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
|
||||
predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)
|
||||
|
||||
|
@ -1496,7 +1496,7 @@ def test_diffusion_ip_adapter_controlnet(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet)
|
||||
|
||||
|
@ -1537,7 +1537,7 @@ def test_diffusion_ip_adapter_plus(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
@ -1584,7 +1584,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
|
|||
condition_scale=5,
|
||||
)
|
||||
sdxl.lda.to(dtype=torch.float32)
|
||||
predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
|
||||
predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)
|
||||
|
||||
|
@ -1618,7 +1618,7 @@ def test_sdxl_random_init(
|
|||
time_ids=time_ids,
|
||||
condition_scale=5,
|
||||
)
|
||||
predicted_image = sdxl.lda.decode_latents(x=x)
|
||||
predicted_image = sdxl.lda.latents_to_image(x=x)
|
||||
|
||||
ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
@ -1653,7 +1653,7 @@ def test_sdxl_random_init_sag(
|
|||
time_ids=time_ids,
|
||||
condition_scale=5,
|
||||
)
|
||||
predicted_image = sdxl.lda.decode_latents(x=x)
|
||||
predicted_image = sdxl.lda.latents_to_image(x=x)
|
||||
|
||||
ensure_similar_images(img_1=predicted_image, img_2=expected_image)
|
||||
|
||||
|
@ -1766,7 +1766,7 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
|
|||
step=step,
|
||||
targets=[target_1, target_2],
|
||||
)
|
||||
result = sd.lda.decode_latents(x=x)
|
||||
result = sd.lda.latents_to_image(x=x)
|
||||
ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
|
@ -1805,7 +1805,7 @@ def test_t2i_adapter_depth(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
|
||||
|
@ -1853,7 +1853,7 @@ def test_t2i_adapter_xl_canny(
|
|||
time_ids=time_ids,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
predicted_image = sdxl.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
|
||||
|
@ -1892,7 +1892,7 @@ def test_restart(
|
|||
condition_scale=8,
|
||||
)
|
||||
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
@ -1924,7 +1924,7 @@ def test_freeu(
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_freeu)
|
||||
|
||||
|
@ -1980,6 +1980,6 @@ def test_hello_world(
|
|||
pooled_text_embedding=pooled_text_embedding,
|
||||
time_ids=time_ids,
|
||||
)
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
predicted_image = sdxl.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
|
|
|
@ -101,3 +101,14 @@ def test_encoder(
|
|||
# numerical differences depending on the backend.
|
||||
# Also we use FP16 weights.
|
||||
assert (our_embeddings - ref_embeddings).abs().max() < 0.01
|
||||
|
||||
|
||||
def test_list_string_tokenizer(
|
||||
prompt: str,
|
||||
our_encoder: CLIPTextEncoderL,
|
||||
):
|
||||
tokenizer = our_encoder.ensure_find(CLIPTokenizer)
|
||||
|
||||
# batched inputs
|
||||
double_tokens = tokenizer([prompt, prompt[0:3]])
|
||||
assert double_tokens.shape[0] == 2
|
||||
|
|
|
@ -39,9 +39,9 @@ def sample_image(ref_path: Path) -> Image.Image:
|
|||
|
||||
|
||||
@no_grad()
|
||||
def test_encode_decode(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
||||
encoded = encoder.encode_image(sample_image)
|
||||
decoded = encoder.decode_latents(encoded)
|
||||
def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
||||
encoded = encoder.image_to_latents(sample_image)
|
||||
decoded = encoder.latents_to_image(encoded)
|
||||
|
||||
assert decoded.mode == "RGB"
|
||||
|
||||
|
@ -49,3 +49,12 @@ def test_encode_decode(encoder: LatentDiffusionAutoencoder, sample_image: Image.
|
|||
assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore
|
||||
|
||||
ensure_similar_images(sample_image, decoded, min_psnr=20, min_ssim=0.9)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_encode_decode_images(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
||||
encoded = encoder.images_to_latents([sample_image, sample_image])
|
||||
images = encoder.latents_to_images(encoded)
|
||||
assert isinstance(images, list)
|
||||
assert len(images) == 2
|
||||
ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9)
|
||||
|
|
Loading…
Reference in a new issue