From 44e184d4d5ce5c0ea35c20b655bf72d1d965dcdc Mon Sep 17 00:00:00 2001 From: Doryan Kaced Date: Fri, 1 Sep 2023 18:26:24 +0200 Subject: [PATCH] Init dtype and device correctly for OutputBlock --- .../foundationals/latent_diffusion/stable_diffusion_xl/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py index 753472a..a35fd9b 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py @@ -242,7 +242,7 @@ class OutputBlock(fl.Chain): def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None: super().__init__( - fl.GroupNorm(channels=320, num_groups=32), + fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype), fl.SiLU(), fl.Conv2d(in_channels=320, out_channels=4, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype), )