mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
cosmetics
This commit is contained in:
parent
ea05f3d327
commit
ca9e89b22a
|
@ -70,11 +70,10 @@ class DoubleTextEncoder(fl.Chain):
|
||||||
text_encoder_g = text_encoder_g or CLIPTextEncoderG(device=device, dtype=dtype)
|
text_encoder_g = text_encoder_g or CLIPTextEncoderG(device=device, dtype=dtype)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Parallel(text_encoder_l[:-2], text_encoder_g),
|
fl.Parallel(text_encoder_l[:-2], text_encoder_g),
|
||||||
fl.Lambda(func=self.concatenate_embeddings),
|
fl.Lambda(self.concatenate_embeddings),
|
||||||
)
|
|
||||||
TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(
|
|
||||||
parent=self.layer("Parallel", fl.Parallel)
|
|
||||||
)
|
)
|
||||||
|
tep = TextEncoderWithPooling(target=text_encoder_g, projection=projection)
|
||||||
|
tep.inject(self.layer("Parallel", fl.Parallel))
|
||||||
|
|
||||||
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
|
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
|
||||||
return super().__call__(text)
|
return super().__call__(text)
|
||||||
|
@ -83,5 +82,5 @@ class DoubleTextEncoder(fl.Chain):
|
||||||
self, text_embedding_l: Tensor, text_embedding_with_pooling: tuple[Tensor, Tensor]
|
self, text_embedding_l: Tensor, text_embedding_with_pooling: tuple[Tensor, Tensor]
|
||||||
) -> tuple[Tensor, Tensor]:
|
) -> tuple[Tensor, Tensor]:
|
||||||
text_embedding_g, pooled_text_embedding = text_embedding_with_pooling
|
text_embedding_g, pooled_text_embedding = text_embedding_with_pooling
|
||||||
text_embedding = cat(tensors=[text_embedding_l, text_embedding_g], dim=-1)
|
text_embedding = cat((text_embedding_l, text_embedding_g), dim=-1)
|
||||||
return text_embedding, pooled_text_embedding
|
return text_embedding, pooled_text_embedding
|
||||||
|
|
|
@ -1624,7 +1624,7 @@ def test_sdxl_random_init(
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_sdxl_random_init_sag(
|
def test_diffusion_sdxl_random_init_sag(
|
||||||
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device
|
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device
|
||||||
) -> None:
|
) -> None:
|
||||||
sdxl = sdxl_ddim
|
sdxl = sdxl_ddim
|
||||||
|
|
Loading…
Reference in a new issue