diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py index 523b6db..27d61fd 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py @@ -70,8 +70,15 @@ class ConditionEncoder(Chain): class Controlnet(Passthrough): + scale_decays: list[float] + def __init__( - self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None + self, + name: str, + scale: float = 1.0, + scale_decay: float = 1.0, + device: Device | str | None = None, + dtype: DType | None = None, ) -> None: """Controlnet is a Half-UNet that collects residuals from the UNet and uses them to condition the UNet. @@ -79,9 +86,16 @@ class Controlnet(Passthrough): stored in the context. It has to use the same context as the UNet: `unet` and `sampling`. + + Scale decay of 0.825 corresponds to the "Prompt is more important" Control Mode of sd-webui-controlnet plugin + https://github.com/Mikubill/sd-webui-controlnet/blob/8e143d3545140b8f0398dfbe1d95a0a766019283/scripts/hook.py#L472 + See also the so-called "Guess Mode" in the official ControlNet demos which uses such scales: + https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode """ self.name = name self.scale = scale + self._scale_decay = scale_decay + self.compute_scale_decays() super().__init__( TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype), Slicing(dim=1, end=4), # support inpainting @@ -134,19 +148,38 @@ class Controlnet(Passthrough): def _store_nth_residual(self, n: int): def _store_residual(x: Tensor): residuals = self.use_context("unet")["residuals"] - residuals[n] = residuals[n] + x * self.scale + residuals[n] = residuals[n] + x * self.scale * self.scale_decays[n] return x return _store_residual + @property + def scale_decay(self) -> 
float: + return self._scale_decay + + @scale_decay.setter + def scale_decay(self, value: float) -> None: + self._scale_decay = value + self.compute_scale_decays() + + def compute_scale_decays(self) -> None: + self.scale_decays = [self.scale_decay ** float(12 - i) for i in range(13)] + class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]): def __init__( - self, target: SD1UNet, name: str, scale: float = 1.0, weights: dict[str, Tensor] | None = None + self, + target: SD1UNet, + name: str, + scale: float = 1.0, + scale_decay: float = 1.0, + weights: dict[str, Tensor] | None = None, ) -> None: self.name = name - controlnet = Controlnet(name=name, scale=scale, device=target.device, dtype=target.dtype) + controlnet = Controlnet( + name=name, scale=scale, scale_decay=scale_decay, device=target.device, dtype=target.dtype + ) if weights is not None: controlnet.load_state_dict(weights) self._controlnet: list[Controlnet] = [controlnet] # not registered by PyTorch @@ -176,11 +209,19 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]): @property def scale(self) -> float: - return self._controlnet[0].scale + return self.controlnet.scale @scale.setter def scale(self, value: float) -> None: - self._controlnet[0].scale = value + self.controlnet.scale = value + + @property + def scale_decay(self) -> float: + return self.controlnet.scale_decay + + @scale_decay.setter + def scale_decay(self, value: float) -> None: + self.controlnet.scale_decay = value def set_controlnet_condition(self, condition: Tensor) -> None: self.set_context("controlnet", {f"condition_{self.name}": condition}) diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 0fcb0fd..c570bf1 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -179,6 +179,21 @@ def controlnet_data( yield (cn_name, condition_image, expected_image, weights_path) +@pytest.fixture(scope="module", params=["canny"]) +def controlnet_data_scale_decay( + ref_path: Path, test_weights_path: Path, 
request: pytest.FixtureRequest +) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]: + cn_name: str = request.param + condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") + expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB") + weights_fn = { + "canny": "lllyasviel_control_v11p_sd15_canny", + } + + weights_path = test_weights_path / "controlnet" / f"{weights_fn[cn_name]}.safetensors" + yield (cn_name, condition_image, expected_image, weights_path) + + @pytest.fixture(scope="module") def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: cn_name = "canny" @@ -1076,6 +1091,50 @@ def test_diffusion_controlnet( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) +@no_grad() +def test_diffusion_controlnet_scale_decay( + sd15_std: StableDiffusion_1, + controlnet_data_scale_decay: tuple[str, Image.Image, Image.Image, Path], + test_device: torch.device, +): + sd15 = sd15_std + + cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_scale_decay + + if not cn_weights_path.is_file(): + warn(f"could not find weights at {cn_weights_path}, skipping") + pytest.skip(allow_module_level=True) + + prompt = "a cute cat, detailed high-quality professional image" + negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) + + sd15.set_inference_steps(30) + + # Using default value of 0.825 chosen by lvmin + # https://github.com/Mikubill/sd-webui-controlnet/blob/8e143d3545140b8f0398dfbe1d95a0a766019283/scripts/hook.py#L472 + controlnet = SD1ControlnetAdapter( + sd15.unet, name=cn_name, scale=0.5, scale_decay=0.825, weights=load_from_safetensors(cn_weights_path) + ).inject() + + cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device) + + 
manual_seed(2) + x = torch.randn(1, 4, 64, 64, device=test_device) + + for step in sd15.steps: + controlnet.set_controlnet_condition(cn_condition) + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.latents_to_image(x) + + ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) + + @no_grad() def test_diffusion_controlnet_structural_copy( sd15_std: StableDiffusion_1, diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md index 42e3594..a5325b3 100644 --- a/tests/e2e/test_diffusion_ref/README.md +++ b/tests/e2e/test_diffusion_ref/README.md @@ -57,6 +57,8 @@ Special cases: - `expected_controllora_PyraCanny+CPDS.png` - `expected_controllora_disabled.png` - `expected_style_aligned.png` + - `expected_controlnet_<name>.png` (canny|depth|lineart|normals|sam|stack) + - `expected_controlnet_<name>_scale_decay.png` (canny) ## Other images diff --git a/tests/e2e/test_diffusion_ref/expected_controlnet_canny_scale_decay.png b/tests/e2e/test_diffusion_ref/expected_controlnet_canny_scale_decay.png new file mode 100644 index 0000000..ba1b58a Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_controlnet_canny_scale_decay.png differ