controlnet: replace Lambda w/ Slicing basic layer

This commit is contained in:
Cédric Deltheil 2023-09-12 15:33:19 +02:00 committed by Cédric Deltheil
parent 7a32699cc6
commit 12e37f5d85

View file

@ -1,5 +1,5 @@
from refiners.fluxion.context import Contexts
from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity
from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity, Slicing
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
SD1UNet,
DownBlocks,
@ -88,7 +88,7 @@ class Controlnet(Passthrough):
self.scale = scale
super().__init__(
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
Lambda(lambda x: x.narrow(dim=1, start=0, length=4)), # support inpainting
Slicing(dim=1, start=0, length=4), # support inpainting
DownBlocks(in_channels=4, device=device, dtype=dtype),
MiddleBlock(device=device, dtype=dtype),
)