add missing device and dtype to SD1UNet's UpBlocks

This commit is contained in:
Laurent 2024-09-08 13:47:35 +00:00
parent 444882a734
commit 04fd98ce20
No known key found for this signature in database

View file

@ -188,7 +188,7 @@ class SD1UNet(fl.Chain):
fl.UseContext(context="unet", key="residuals").compose(lambda x: x[-1]), fl.UseContext(context="unet", key="residuals").compose(lambda x: x[-1]),
MiddleBlock(device=device, dtype=dtype), MiddleBlock(device=device, dtype=dtype),
), ),
UpBlocks(), UpBlocks(device=device, dtype=dtype),
fl.Chain( fl.Chain(
fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype), fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),
fl.SiLU(), fl.SiLU(),