diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py index 753472a..a35fd9b 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py @@ -242,7 +242,7 @@ class OutputBlock(fl.Chain): def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None: super().__init__( - fl.GroupNorm(channels=320, num_groups=32), + fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype), fl.SiLU(), fl.Conv2d(in_channels=320, out_channels=4, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype), )