diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py index 25e7502..54365ac 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py @@ -188,7 +188,7 @@ class SD1UNet(fl.Chain): fl.UseContext(context="unet", key="residuals").compose(lambda x: x[-1]), MiddleBlock(device=device, dtype=dtype), ), - UpBlocks(), + UpBlocks(device=device, dtype=dtype), fl.Chain( fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype), fl.SiLU(),