rename ResidualBlock to ConditionScaleBlock in LCM

This commit is contained in:
Pierre Chapuis 2024-02-21 15:53:58 +01:00
parent 5f21922925
commit 03b79d6d34
2 changed files with 4 additions and 4 deletions

View file

@ -678,7 +678,7 @@ def convert_lcm_base():
"tests/weights/latent-consistency/lcm-sdxl", "tests/weights/latent-consistency/lcm-sdxl",
"tests/weights/sdxl-lcm-unet.safetensors", "tests/weights/sdxl-lcm-unet.safetensors",
half=True, half=True,
expected_hash="242cf440", expected_hash="e161b20c",
) )

View file

@ -29,7 +29,7 @@ def compute_sinusoidal_embedding(
return embedding return embedding
class ResidualBlock(fl.Residual): class ConditionScaleBlock(fl.Residual):
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
@ -85,7 +85,7 @@ class SDXLLcmAdapter(fl.Chain, Adapter[SDXLUNet]):
def inject(self: "SDXLLcmAdapter", parent: fl.Chain | None = None) -> "SDXLLcmAdapter": def inject(self: "SDXLLcmAdapter", parent: fl.Chain | None = None) -> "SDXLLcmAdapter":
ra = self.target.ensure_find(RangeEncoder) ra = self.target.ensure_find(RangeEncoder)
block = ResidualBlock( block = ConditionScaleBlock(
in_channels=self.condition_scale_embedding_dim, in_channels=self.condition_scale_embedding_dim,
out_channels=ra.sinusoidal_embedding_dim, out_channels=ra.sinusoidal_embedding_dim,
device=self.target.device, device=self.target.device,
@ -96,5 +96,5 @@ class SDXLLcmAdapter(fl.Chain, Adapter[SDXLUNet]):
def eject(self) -> None: def eject(self) -> None:
ra = self.target.ensure_find(RangeEncoder) ra = self.target.ensure_find(RangeEncoder)
ra.remove(ra.ensure_find(ResidualBlock)) ra.remove(ra.ensure_find(ConditionScaleBlock))
super().eject() super().eject()