diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py index d57acac..5983a49 100644 --- a/scripts/prepare_test_weights.py +++ b/scripts/prepare_test_weights.py @@ -678,7 +678,7 @@ def convert_lcm_base(): "tests/weights/latent-consistency/lcm-sdxl", "tests/weights/sdxl-lcm-unet.safetensors", half=True, - expected_hash="242cf440", + expected_hash="e161b20c", ) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm.py index abbd1d3..4869d8f 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm.py @@ -29,7 +29,7 @@ def compute_sinusoidal_embedding( return embedding -class ResidualBlock(fl.Residual): +class ConditionScaleBlock(fl.Residual): def __init__( self, in_channels: int, @@ -85,7 +85,7 @@ class SDXLLcmAdapter(fl.Chain, Adapter[SDXLUNet]): def inject(self: "SDXLLcmAdapter", parent: fl.Chain | None = None) -> "SDXLLcmAdapter": ra = self.target.ensure_find(RangeEncoder) - block = ResidualBlock( + block = ConditionScaleBlock( in_channels=self.condition_scale_embedding_dim, out_channels=ra.sinusoidal_embedding_dim, device=self.target.device, @@ -96,5 +96,5 @@ class SDXLLcmAdapter(fl.Chain, Adapter[SDXLUNet]): def eject(self) -> None: ra = self.target.ensure_find(RangeEncoder) - ra.remove(ra.ensure_find(ResidualBlock)) + ra.remove(ra.ensure_find(ConditionScaleBlock)) super().eject()