mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
controlnet: replace Lambda w/ Slicing basic layer
This commit is contained in:
parent
7a32699cc6
commit
12e37f5d85
|
@ -1,5 +1,5 @@
|
|||
from refiners.fluxion.context import Contexts
|
||||
from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity
|
||||
from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity, Slicing
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
|
||||
SD1UNet,
|
||||
DownBlocks,
|
||||
|
@ -88,7 +88,7 @@ class Controlnet(Passthrough):
|
|||
self.scale = scale
|
||||
super().__init__(
|
||||
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
|
||||
Lambda(lambda x: x.narrow(dim=1, start=0, length=4)), # support inpainting
|
||||
Slicing(dim=1, start=0, length=4), # support inpainting
|
||||
DownBlocks(in_channels=4, device=device, dtype=dtype),
|
||||
MiddleBlock(device=device, dtype=dtype),
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue