From b91a457495f2721e5e4562aa13cbf13b5701d343 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Mon, 21 Aug 2023 14:34:40 +0200 Subject: [PATCH] use Converter layer for sinusoidal embedding --- src/refiners/adapters/range_adapter.py | 3 ++- src/refiners/foundationals/latent_diffusion/sdxl_unet.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/refiners/adapters/range_adapter.py b/src/refiners/adapters/range_adapter.py index 44217d6..b30d062 100644 --- a/src/refiners/adapters/range_adapter.py +++ b/src/refiners/adapters/range_adapter.py @@ -34,13 +34,14 @@ class RangeEncoder(fl.Chain): self.embedding_dim = embedding_dim super().__init__( fl.Lambda(self.compute_sinuosoidal_embedding), + fl.Converter(set_device=False, set_dtype=True), fl.Linear(in_features=sinuosidal_embedding_dim, out_features=embedding_dim, device=device, dtype=dtype), fl.SiLU(), fl.Linear(in_features=embedding_dim, out_features=embedding_dim, device=device, dtype=dtype), ) def compute_sinuosoidal_embedding(self, x: Int[Tensor, "*batch 1"]) -> Float[Tensor, "*batch 1 embedding_dim"]: - return compute_sinusoidal_embedding(x, embedding_dim=self.sinuosidal_embedding_dim).to(self.dtype) + return compute_sinusoidal_embedding(x, embedding_dim=self.sinuosidal_embedding_dim) class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]): diff --git a/src/refiners/foundationals/latent_diffusion/sdxl_unet.py b/src/refiners/foundationals/latent_diffusion/sdxl_unet.py index 05df85a..27a7c68 100644 --- a/src/refiners/foundationals/latent_diffusion/sdxl_unet.py +++ b/src/refiners/foundationals/latent_diffusion/sdxl_unet.py @@ -25,6 +25,7 @@ class TextTimeEmbedding(fl.Chain): ), dim=1, ), + fl.Converter(set_device=False, set_dtype=True), fl.Linear( in_features=self.text_time_embedding_dim, out_features=self.timestep_embedding_dim, @@ -41,7 +42,7 @@ class TextTimeEmbedding(fl.Chain): ) def compute_sinuosoidal_embedding(self, x: Tensor) -> Tensor: - return
compute_sinusoidal_embedding(x=x, embedding_dim=self.time_ids_embedding_dim).to(dtype=self.dtype) + return compute_sinusoidal_embedding(x=x, embedding_dim=self.time_ids_embedding_dim) class TimestepEncoder(fl.Passthrough):