mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
cosmetics
This commit is contained in:
parent
ea05f3d327
commit
ca9e89b22a
|
@ -70,11 +70,10 @@ class DoubleTextEncoder(fl.Chain):
|
|||
text_encoder_g = text_encoder_g or CLIPTextEncoderG(device=device, dtype=dtype)
|
||||
super().__init__(
|
||||
fl.Parallel(text_encoder_l[:-2], text_encoder_g),
|
||||
fl.Lambda(func=self.concatenate_embeddings),
|
||||
)
|
||||
TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(
|
||||
parent=self.layer("Parallel", fl.Parallel)
|
||||
fl.Lambda(self.concatenate_embeddings),
|
||||
)
|
||||
tep = TextEncoderWithPooling(target=text_encoder_g, projection=projection)
|
||||
tep.inject(self.layer("Parallel", fl.Parallel))
|
||||
|
||||
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
|
||||
return super().__call__(text)
|
||||
|
@ -83,5 +82,5 @@ class DoubleTextEncoder(fl.Chain):
|
|||
self, text_embedding_l: Tensor, text_embedding_with_pooling: tuple[Tensor, Tensor]
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
text_embedding_g, pooled_text_embedding = text_embedding_with_pooling
|
||||
text_embedding = cat(tensors=[text_embedding_l, text_embedding_g], dim=-1)
|
||||
text_embedding = cat((text_embedding_l, text_embedding_g), dim=-1)
|
||||
return text_embedding, pooled_text_embedding
|
||||
|
|
|
@ -1624,7 +1624,7 @@ def test_sdxl_random_init(
|
|||
|
||||
|
||||
@no_grad()
|
||||
def test_sdxl_random_init_sag(
|
||||
def test_diffusion_sdxl_random_init_sag(
|
||||
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device
|
||||
) -> None:
|
||||
sdxl = sdxl_ddim
|
||||
|
|
Loading…
Reference in a new issue