mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
controlnet: replace Lambda w/ Slicing basic layer
This commit is contained in:
parent
7a32699cc6
commit
12e37f5d85
|
@ -1,5 +1,5 @@
|
||||||
from refiners.fluxion.context import Contexts
|
from refiners.fluxion.context import Contexts
|
||||||
from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity
|
from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity, Slicing
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
|
||||||
SD1UNet,
|
SD1UNet,
|
||||||
DownBlocks,
|
DownBlocks,
|
||||||
|
@ -88,7 +88,7 @@ class Controlnet(Passthrough):
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
super().__init__(
|
super().__init__(
|
||||||
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
|
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
|
||||||
Lambda(lambda x: x.narrow(dim=1, start=0, length=4)), # support inpainting
|
Slicing(dim=1, start=0, length=4), # support inpainting
|
||||||
DownBlocks(in_channels=4, device=device, dtype=dtype),
|
DownBlocks(in_channels=4, device=device, dtype=dtype),
|
||||||
MiddleBlock(device=device, dtype=dtype),
|
MiddleBlock(device=device, dtype=dtype),
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue