Add scale_decay parameter for SD1 ControlNet

This commit is contained in:
limiteinductive 2024-06-24 09:32:27 +00:00 committed by Benjamin Trom
parent 567cc9c6d9
commit 15ccdb38f3
4 changed files with 108 additions and 6 deletions

View file

@ -70,8 +70,15 @@ class ConditionEncoder(Chain):
class Controlnet(Passthrough):
scale_decays: list[float]
def __init__(
self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
self,
name: str,
scale: float = 1.0,
scale_decay: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
"""Controlnet is a Half-UNet that collects residuals from the UNet and uses them to condition the UNet.
@ -79,9 +86,16 @@ class Controlnet(Passthrough):
stored in the context.
It has to use the same context as the UNet: `unet` and `sampling`.
Scale decay of 0.825 corresponds to the "Prompt is more important" Control Mode of sd-webui-controlnet plugin
https://github.com/Mikubill/sd-webui-controlnet/blob/8e143d3545140b8f0398dfbe1d95a0a766019283/scripts/hook.py#L472
See also the so-called "Guess Mode" in the official ControlNet demos which uses such scales:
https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode
"""
self.name = name
self.scale = scale
self._scale_decay = scale_decay
self.compute_scale_decays()
super().__init__(
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
Slicing(dim=1, end=4), # support inpainting
@ -134,19 +148,38 @@ class Controlnet(Passthrough):
# Returns a callable bound to index `n` that accumulates its input into the n-th entry of the
# residuals stored in the "unet" context.
def _store_nth_residual(self, n: int):
# Inner callable: takes a tensor `x`, updates residuals[n], and hands `x` back.
def _store_residual(x: Tensor):
# The residuals container is looked up in the "unet" context on every call.
residuals = self.use_context("unet")["residuals"]
# NOTE(review): the two assignments below look like the removed/added pair of a diff hunk
# (plain `scale` vs. `scale * scale_decays[n]`) — confirm that only the second one is meant to apply.
residuals[n] = residuals[n] + x * self.scale
residuals[n] = residuals[n] + x * self.scale * self.scale_decays[n]
# The input is returned as-is.
return x
return _store_residual
@property
def scale_decay(self) -> float:
    """Return the decay factor currently stored in ``_scale_decay``."""
    current = self._scale_decay
    return current

@scale_decay.setter
def scale_decay(self, value: float) -> None:
    """Store a new decay factor, then call ``compute_scale_decays`` to refresh derived values."""
    self._scale_decay = value
    self.compute_scale_decays()
def compute_scale_decays(self) -> None:
    """Rebuild ``self.scale_decays`` from the current ``scale_decay``.

    The list has 13 entries; entry ``i`` is ``scale_decay ** (12 - i)``, so the
    exponent runs from 12 for the first entry down to 0 for the last one.
    """
    decay = self.scale_decay
    exponents = (float(power) for power in range(12, -1, -1))
    self.scale_decays = [decay**exponent for exponent in exponents]
class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
def __init__(
self, target: SD1UNet, name: str, scale: float = 1.0, weights: dict[str, Tensor] | None = None
self,
target: SD1UNet,
name: str,
scale: float = 1.0,
scale_decay: float = 1.0,
weights: dict[str, Tensor] | None = None,
) -> None:
self.name = name
controlnet = Controlnet(name=name, scale=scale, device=target.device, dtype=target.dtype)
controlnet = Controlnet(
name=name, scale=scale, scale_decay=scale_decay, device=target.device, dtype=target.dtype
)
if weights is not None:
controlnet.load_state_dict(weights)
self._controlnet: list[Controlnet] = [controlnet] # not registered by PyTorch
@ -176,11 +209,19 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
# `scale` getter: reads the `scale` attribute of the ControlNet held by this adapter.
# NOTE(review): the two `return` lines look like the removed/added pair of a diff hunk
# (`self._controlnet[0]` vs. `self.controlnet`) — confirm which accessor is current.
@property
def scale(self) -> float:
return self._controlnet[0].scale
return self.controlnet.scale
# `scale` setter: writes `value` to the `scale` attribute of the ControlNet held by this adapter.
# NOTE(review): same before/after pairing as in the getter — confirm only one assignment applies.
@scale.setter
def scale(self, value: float) -> None:
self._controlnet[0].scale = value
self.controlnet.scale = value
@property
def scale_decay(self) -> float:
    """Return the ``scale_decay`` of the object held in ``self.controlnet``."""
    wrapped = self.controlnet
    return wrapped.scale_decay

@scale_decay.setter
def scale_decay(self, value: float) -> None:
    """Forward ``value`` to the ``scale_decay`` attribute of ``self.controlnet``."""
    wrapped = self.controlnet
    wrapped.scale_decay = value
def set_controlnet_condition(self, condition: Tensor) -> None:
    """Store ``condition`` in the "controlnet" context under the key ``condition_<name>``."""
    payload = {f"condition_{self.name}": condition}
    self.set_context("controlnet", payload)

View file

@ -179,6 +179,21 @@ def controlnet_data(
yield (cn_name, condition_image, expected_image, weights_path)
@pytest.fixture(scope="module", params=["canny"])
def controlnet_data_scale_decay(
    ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
    """Yield ``(name, condition image, expected image, weights path)`` for the scale-decay test.

    The fixture is parametrized over ControlNet names; only "canny" is listed.
    Both images are opened from ``ref_path`` and converted to RGB.
    """
    cn_name: str = request.param
    # Maps a ControlNet name to the stem of its weights file.
    weights_fn = {"canny": "lllyasviel_control_v11p_sd15_canny"}

    guide = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
    reference = _img_open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB")

    controlnet_dir = test_weights_path / "controlnet"
    yield (cn_name, guide, reference, controlnet_dir / f"{weights_fn[cn_name]}.safetensors")
@pytest.fixture(scope="module")
def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
cn_name = "canny"
@ -1076,6 +1091,50 @@ def test_diffusion_controlnet(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_diffusion_controlnet_scale_decay(
    sd15_std: StableDiffusion_1,
    controlnet_data_scale_decay: tuple[str, Image.Image, Image.Image, Path],
    test_device: torch.device,
):
    """Run a 30-step SD1.5 generation with a scale-decayed ControlNet and compare to the reference image.

    The test is skipped (with a warning) when the ControlNet weights file is not present.
    """
    cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_scale_decay
    if not cn_weights_path.is_file():
        warn(f"could not find weights at {cn_weights_path}, skipping")
        pytest.skip(allow_module_level=True)

    pipeline = sd15_std
    text_embedding = pipeline.compute_clip_text_embedding(
        text="a cute cat, detailed high-quality professional image",
        negative_text="lowres, bad anatomy, bad hands, cropped, worst quality",
    )
    pipeline.set_inference_steps(30)

    # Using default value of 0.825 chosen by lvmin
    # https://github.com/Mikubill/sd-webui-controlnet/blob/8e143d3545140b8f0398dfbe1d95a0a766019283/scripts/hook.py#L472
    adapter = SD1ControlnetAdapter(
        pipeline.unet,
        name=cn_name,
        scale=0.5,
        scale_decay=0.825,
        weights=load_from_safetensors(cn_weights_path),
    )
    controlnet = adapter.inject()
    guide = image_to_tensor(condition_image.convert("RGB"), device=test_device)

    # Seed right before sampling the initial latents so the noise is reproducible.
    manual_seed(2)
    latents = torch.randn(1, 4, 64, 64, device=test_device)

    for step in pipeline.steps:
        # The condition is set again before every denoising step.
        controlnet.set_controlnet_condition(guide)
        latents = pipeline(
            latents,
            step=step,
            clip_text_embedding=text_embedding,
            condition_scale=7.5,
        )

    ensure_similar_images(
        pipeline.lda.latents_to_image(latents), expected_image, min_psnr=35, min_ssim=0.98
    )
@no_grad()
def test_diffusion_controlnet_structural_copy(
sd15_std: StableDiffusion_1,

View file

@ -57,6 +57,8 @@ Special cases:
- `expected_controllora_PyraCanny+CPDS.png`
- `expected_controllora_disabled.png`
- `expected_style_aligned.png`
- `expected_controlnet_<name>.png` (canny|depth|lineart|normals|sam|stack)
- `expected_controlnet_<name>_scale_decay.png` (canny)
## Other images

Binary file not shown.

After

Width:  |  Height:  |  Size: 432 KiB