cosmetics

This commit is contained in:
Pierre Chapuis 2024-02-06 18:43:46 +01:00
parent ea05f3d327
commit ca9e89b22a
2 changed files with 5 additions and 6 deletions

View file

@ -70,11 +70,10 @@ class DoubleTextEncoder(fl.Chain):
text_encoder_g = text_encoder_g or CLIPTextEncoderG(device=device, dtype=dtype) text_encoder_g = text_encoder_g or CLIPTextEncoderG(device=device, dtype=dtype)
super().__init__( super().__init__(
fl.Parallel(text_encoder_l[:-2], text_encoder_g), fl.Parallel(text_encoder_l[:-2], text_encoder_g),
fl.Lambda(func=self.concatenate_embeddings), fl.Lambda(self.concatenate_embeddings),
)
TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(
parent=self.layer("Parallel", fl.Parallel)
) )
tep = TextEncoderWithPooling(target=text_encoder_g, projection=projection)
tep.inject(self.layer("Parallel", fl.Parallel))
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]: def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
return super().__call__(text) return super().__call__(text)
@ -83,5 +82,5 @@ class DoubleTextEncoder(fl.Chain):
self, text_embedding_l: Tensor, text_embedding_with_pooling: tuple[Tensor, Tensor] self, text_embedding_l: Tensor, text_embedding_with_pooling: tuple[Tensor, Tensor]
) -> tuple[Tensor, Tensor]: ) -> tuple[Tensor, Tensor]:
text_embedding_g, pooled_text_embedding = text_embedding_with_pooling text_embedding_g, pooled_text_embedding = text_embedding_with_pooling
text_embedding = cat(tensors=[text_embedding_l, text_embedding_g], dim=-1) text_embedding = cat((text_embedding_l, text_embedding_g), dim=-1)
return text_embedding, pooled_text_embedding return text_embedding, pooled_text_embedding

View file

@ -1624,7 +1624,7 @@ def test_sdxl_random_init(
@no_grad() @no_grad()
def test_sdxl_random_init_sag( def test_diffusion_sdxl_random_init_sag(
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device
) -> None: ) -> None:
sdxl = sdxl_ddim sdxl = sdxl_ddim