controlnet: replace Lambda w/ Slicing basic layer

This commit is contained in:
Cédric Deltheil 2023-09-12 15:33:19 +02:00 committed by Cédric Deltheil
parent 7a32699cc6
commit 12e37f5d85

View file

@ -1,5 +1,5 @@
from refiners.fluxion.context import Contexts from refiners.fluxion.context import Contexts
from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity, Slicing
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ( from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
SD1UNet, SD1UNet,
DownBlocks, DownBlocks,
@ -88,7 +88,7 @@ class Controlnet(Passthrough):
self.scale = scale self.scale = scale
super().__init__( super().__init__(
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype), TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
Lambda(lambda x: x.narrow(dim=1, start=0, length=4)), # support inpainting Slicing(dim=1, start=0, length=4), # support inpainting
DownBlocks(in_channels=4, device=device, dtype=dtype), DownBlocks(in_channels=4, device=device, dtype=dtype),
MiddleBlock(device=device, dtype=dtype), MiddleBlock(device=device, dtype=dtype),
) )