Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-23 22:58:45 +00:00)
batch sdxl + sd1 + compute_clip_text_embedding
Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
parent 8139b2dd91
commit d199cd4f24
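Usage sketch: with this change, `compute_clip_text_embedding` accepts either a single prompt or a list of prompts. A minimal example, assuming an already-loaded `StableDiffusion_1` instance named `sd15` (the prompts are illustrative; the tests added below exercise the same call):

    # `sd15` is assumed to be a loaded StableDiffusion_1 model; prompts are illustrative.
    clip_text_embedding = sd15.compute_clip_text_embedding(
        text=["a cute cat", "a cute dog"],
        negative_text=["lowres, bad anatomy", "lowres, bad hands"],
    )
    # Negative and conditional embeddings are concatenated along dim 0, so a batch
    # of N prompts yields 2*N rows, ready for classifier-free guidance.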
@@ -68,7 +68,7 @@ class StableDiffusion_1(LatentDiffusionModel):
             dtype=dtype,
         )
 
-    def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
+    def compute_clip_text_embedding(self, text: str | list[str], negative_text: str | list[str] = "") -> Tensor:
         """Compute the CLIP text embedding associated with the given prompt and negative prompt.
 
         Args:
@@ -76,12 +76,14 @@ class StableDiffusion_1(LatentDiffusionModel):
             negative_text: The negative prompt to compute the CLIP text embedding of.
                 If not provided, the negative prompt is assumed to be empty (i.e., `""`).
         """
-        conditional_embedding = self.clip_text_encoder(text)
-        if text == negative_text:
-            return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0)
+        text = [text] if isinstance(text, str) else text
+        negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
+        assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"
 
-        negative_embedding = self.clip_text_encoder(negative_text or "")
-        return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)
+        conditional_embedding = self.clip_text_encoder(text)
+        negative_embedding = self.clip_text_encoder(negative_text)
+
+        return torch.cat((negative_embedding, conditional_embedding))
 
     def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
         """Set the various context parameters required by the U-Net model.
@@ -65,7 +65,9 @@ class StableDiffusion_XL(LatentDiffusionModel):
             dtype=dtype,
         )
 
-    def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]:
+    def compute_clip_text_embedding(
+        self, text: str | list[str], negative_text: str | list[str] = ""
+    ) -> tuple[Tensor, Tensor]:
         """Compute the CLIP text embedding associated with the given prompt and negative prompt.
 
         Args:
@@ -73,14 +75,13 @@ class StableDiffusion_XL(LatentDiffusionModel):
             negative_text: The negative prompt to compute the CLIP text embedding of.
                 If not provided, the negative prompt is assumed to be empty (i.e., `""`).
         """
-        conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)
-        if text == negative_text:
-            return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0), torch.cat(
-                tensors=(conditional_pooled_embedding, conditional_pooled_embedding), dim=0
-            )
-
-        # TODO: when negative_text is None, use zero tensor?
-        negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text or "")
+        text = [text] if isinstance(text, str) else text
+        negative_text = [negative_text] if isinstance(negative_text, str) else negative_text
+        assert len(text) == len(negative_text), "The length of the text list and negative_text should be the same"
+
+        conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)
+        negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text)
 
         return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0), torch.cat(
             tensors=(negative_pooled_embedding, conditional_pooled_embedding), dim=0
@@ -1,7 +1,7 @@
 from typing import cast
 
 from jaxtyping import Float
-from torch import Tensor, cat, device as Device, dtype as DType
+from torch import Tensor, cat, device as Device, dtype as DType, split
 
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
@@ -40,7 +40,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
     def init_context(self) -> Contexts:
         return {"text_encoder_pooling": {"end_of_text_index": []}}
 
-    def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 1280"], Float[Tensor, "1 1280"]]:
+    def __call__(self, text: str | list[str]) -> tuple[Float[Tensor, "batch 77 1280"], Float[Tensor, "batch 1280"]]:
         return super().__call__(text)
 
     @property
@@ -48,13 +48,14 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
         return self.ensure_find(CLIPTokenizer)
 
     def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
-        position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()
-        end_of_text_index.append(cast(int, position))
+        for str_tokens in split(tokens, 1):
+            position = (str_tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()  # type: ignore
+            end_of_text_index.append(cast(int, position))
 
-    def pool(self, x: Float[Tensor, "1 77 1280"]) -> Float[Tensor, "1 1280"]:
+    def pool(self, x: Float[Tensor, "batch 77 1280"]) -> Float[Tensor, "batch 1280"]:
         end_of_text_index = self.use_context(context_name="text_encoder_pooling").get("end_of_text_index", [])
-        assert len(end_of_text_index) == 1, "End of text index not found."
-        return x[:, end_of_text_index[0], :]
+        assert len(end_of_text_index) == x.shape[0], "End of text index not found."
+        return cat([x[i : i + 1, end_of_text_index[i], :] for i in range(x.shape[0])], dim=0)
 
 
 class DoubleTextEncoder(fl.Chain):
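Note on the pooling change: `set_end_of_text_index` now records one end-of-text position per sample, and `pool` gathers each sample's hidden state at its own position. A self-contained sketch of that indexing pattern, with toy shapes and illustrative names:

    import torch

    x = torch.randn(2, 5, 3)    # toy batch: 2 samples, 5 token positions, hidden size 3
    end_of_text_index = [4, 2]  # one end-of-text position per sample

    # Select each sample's vector at its own end-of-text position, then stack them back.
    pooled = torch.cat([x[i : i + 1, end_of_text_index[i], :] for i in range(x.shape[0])], dim=0)
    assert pooled.shape == (2, 3)  # one pooled vector per sample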
@@ -75,7 +76,7 @@ class DoubleTextEncoder(fl.Chain):
         tep = TextEncoderWithPooling(target=text_encoder_g, projection=projection)
         tep.inject(self.layer("Parallel", fl.Parallel))
 
-    def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
+    def __call__(self, text: str | list[str]) -> tuple[Float[Tensor, "batch 77 2048"], Float[Tensor, "batch 1280"]]:
         return super().__call__(text)
 
     def concatenate_embeddings(
@@ -757,6 +757,60 @@ def test_diffusion_std_random_init(
     ensure_similar_images(predicted_image, expected_image_std_random_init)
 
 
+@no_grad()
+def test_diffusion_batch2(sd15_std: StableDiffusion_1):
+    sd15 = sd15_std
+
+    prompt1 = "a cute cat, detailed high-quality professional image"
+    negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality"
+    prompt2 = "a cute dog"
+    negative_prompt2 = "lowres, bad anatomy, bad hands"
+
+    clip_text_embedding_b2 = sd15.compute_clip_text_embedding(
+        text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2]
+    )
+
+    step = sd15.steps[0]
+
+    manual_seed(2)
+    rand_b2 = torch.randn(2, 4, 64, 64, device=sd15.device)
+
+    x_b2 = sd15(
+        rand_b2,
+        step=step,
+        clip_text_embedding=clip_text_embedding_b2,
+        condition_scale=7.5,
+    )
+
+    assert x_b2.shape == (2, 4, 64, 64)
+
+    rand_1 = rand_b2[0:1]
+    clip_text_embedding_1 = sd15.compute_clip_text_embedding(text=[prompt1], negative_text=[negative_prompt1])
+    x_1 = sd15(
+        rand_1,
+        step=step,
+        clip_text_embedding=clip_text_embedding_1,
+        condition_scale=7.5,
+    )
+
+    rand_2 = rand_b2[1:2]
+    clip_text_embedding_2 = sd15.compute_clip_text_embedding(text=[prompt2], negative_text=[negative_prompt2])
+    x_2 = sd15(
+        rand_2,
+        step=step,
+        clip_text_embedding=clip_text_embedding_2,
+        condition_scale=7.5,
+    )
+
+    # The 5e-3 tolerance is detailed in https://github.com/finegrain-ai/refiners/pull/263#issuecomment-1956404911
+    assert torch.allclose(
+        x_b2[0], x_1[0], atol=5e-3, rtol=0
+    ), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[0] - x_1[0]).abs()).item()}"
+    assert torch.allclose(
+        x_b2[1], x_2[0], atol=5e-3, rtol=0
+    ), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[1] - x_2[0]).abs()).item()}"
+
+
 @no_grad()
 def test_diffusion_std_random_init_euler(
     sd15_euler: StableDiffusion_1, expected_image_std_random_init_euler: Image.Image, test_device: torch.device
@@ -836,7 +890,6 @@ def test_diffusion_std_random_init_float16(
         condition_scale=7.5,
     )
     predicted_image = sd15.lda.latents_to_image(x)
 
     ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)
-
 
@@ -1265,6 +1318,68 @@ def test_diffusion_lora(
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
 
+@no_grad()
+def test_diffusion_sdxl_batch2(sdxl_ddim: StableDiffusion_XL) -> None:
+    sdxl = sdxl_ddim
+
+    prompt1 = "a cute cat, detailed high-quality professional image"
+    negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality"
+    prompt2 = "a cute dog"
+    negative_prompt2 = "lowres, bad anatomy, bad hands"
+
+    clip_text_embedding_b2, pooled_text_embedding_b2 = sdxl.compute_clip_text_embedding(
+        text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2]
+    )
+
+    time_ids = sdxl.default_time_ids
+    time_ids_b2 = sdxl.default_time_ids.repeat(2, 1)
+
+    manual_seed(seed=2)
+    x_b2 = torch.randn(2, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
+    x_1 = x_b2[0:1]
+    x_2 = x_b2[1:2]
+
+    x_b2 = sdxl(
+        x_b2,
+        step=sdxl.steps[0],
+        clip_text_embedding=clip_text_embedding_b2,
+        pooled_text_embedding=pooled_text_embedding_b2,
+        time_ids=time_ids_b2,
+    )
+
+    clip_text_embedding_1, pooled_text_embedding_1 = sdxl.compute_clip_text_embedding(
+        text=prompt1, negative_text=negative_prompt1
+    )
+
+    x_1 = sdxl(
+        x_1,
+        step=sdxl.steps[0],
+        clip_text_embedding=clip_text_embedding_1,
+        pooled_text_embedding=pooled_text_embedding_1,
+        time_ids=time_ids,
+    )
+
+    clip_text_embedding_2, pooled_text_embedding_2 = sdxl.compute_clip_text_embedding(
+        text=prompt2, negative_text=negative_prompt2
+    )
+
+    x_2 = sdxl(
+        x_2,
+        step=sdxl.steps[0],
+        clip_text_embedding=clip_text_embedding_2,
+        pooled_text_embedding=pooled_text_embedding_2,
+        time_ids=time_ids,
+    )
+
+    # The 5e-3 tolerance is detailed in https://github.com/finegrain-ai/refiners/pull/263#issuecomment-1956404911
+    assert torch.allclose(
+        x_b2[0], x_1[0], atol=5e-3, rtol=0
+    ), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[0] - x_1[0]).abs()).item()}"
+    assert torch.allclose(
+        x_b2[1], x_2[0], atol=5e-3, rtol=0
+    ), f"Batch 2 and batch1 output should be the same and are distant of {torch.max((x_b2[1] - x_2[0]).abs()).item()}"
+
+
 @no_grad()
 def test_diffusion_sdxl_lora(
     sdxl_ddim: StableDiffusion_XL,
@@ -100,3 +100,24 @@ def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder:
 
     assert torch.allclose(input=negative_double_embedding, other=negative_prompt_embeds, rtol=1e-3, atol=1e-3)
     assert torch.allclose(input=negative_pooled_embedding, other=negative_pooled_prompt_embeds, rtol=1e-3, atol=1e-3)
+
+
+@no_grad()
+def test_double_text_encoder_batch2(double_text_encoder: DoubleTextEncoder) -> None:
+    manual_seed(seed=0)
+    prompt1 = "A photo of a pizza."
+    prompt2 = "A giant duck."
+
+    double_embedding_b2, pooled_embedding_b2 = double_text_encoder([prompt1, prompt2])
+
+    assert double_embedding_b2.shape == torch.Size([2, 77, 2048])
+    assert pooled_embedding_b2.shape == torch.Size([2, 1280])
+
+    double_embedding_1, pooled_embedding_1 = double_text_encoder(prompt1)
+    double_embedding_2, pooled_embedding_2 = double_text_encoder(prompt2)
+
+    assert torch.allclose(input=double_embedding_1, other=double_embedding_b2[0:1], rtol=1e-3, atol=1e-3)
+    assert torch.allclose(input=pooled_embedding_1, other=pooled_embedding_b2[0:1], rtol=1e-3, atol=1e-3)
+
+    assert torch.allclose(input=double_embedding_2, other=double_embedding_b2[1:2], rtol=1e-3, atol=1e-3)
+    assert torch.allclose(input=pooled_embedding_2, other=pooled_embedding_b2[1:2], rtol=1e-3, atol=1e-3)