diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py index 6a59f73..9a94b27 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py @@ -1,5 +1,5 @@ from refiners.fluxion.context import Contexts -from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity +from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity, Slicing from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ( SD1UNet, DownBlocks, @@ -88,7 +88,7 @@ class Controlnet(Passthrough): self.scale = scale super().__init__( TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype), - Lambda(lambda x: x.narrow(dim=1, start=0, length=4)), # support inpainting + Slicing(dim=1, start=0, length=4), # support inpainting DownBlocks(in_channels=4, device=device, dtype=dtype), MiddleBlock(device=device, dtype=dtype), )