mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
Add scale_decay parameter for SD1 ControlNet
This commit is contained in:
parent
567cc9c6d9
commit
15ccdb38f3
|
@ -70,8 +70,15 @@ class ConditionEncoder(Chain):
|
|||
|
||||
|
||||
class Controlnet(Passthrough):
|
||||
scale_decays: list[float]
|
||||
|
||||
def __init__(
|
||||
self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
|
||||
self,
|
||||
name: str,
|
||||
scale: float = 1.0,
|
||||
scale_decay: float = 1.0,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
"""Controlnet is a Half-UNet that collects residuals from the UNet and uses them to condition the UNet.
|
||||
|
||||
|
@ -79,9 +86,16 @@ class Controlnet(Passthrough):
|
|||
stored in the context.
|
||||
|
||||
It has to use the same context as the UNet: `unet` and `sampling`.
|
||||
|
||||
Scale decay of 0.825 corresponds to the "Prompt is more important" Control Mode of sd-webui-controlnet plugin
|
||||
https://github.com/Mikubill/sd-webui-controlnet/blob/8e143d3545140b8f0398dfbe1d95a0a766019283/scripts/hook.py#L472
|
||||
See also the so-called "Guess Mode" in the official ControlNet demos which uses such scales:
|
||||
https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode
|
||||
"""
|
||||
self.name = name
|
||||
self.scale = scale
|
||||
self._scale_decay = scale_decay
|
||||
self.compute_scale_decays()
|
||||
super().__init__(
|
||||
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
|
||||
Slicing(dim=1, end=4), # support inpainting
|
||||
|
@ -134,19 +148,38 @@ class Controlnet(Passthrough):
|
|||
def _store_nth_residual(self, n: int):
|
||||
def _store_residual(x: Tensor):
|
||||
residuals = self.use_context("unet")["residuals"]
|
||||
residuals[n] = residuals[n] + x * self.scale
|
||||
residuals[n] = residuals[n] + x * self.scale * self.scale_decays[n]
|
||||
return x
|
||||
|
||||
return _store_residual
|
||||
|
||||
@property
|
||||
def scale_decay(self) -> float:
|
||||
return self._scale_decay
|
||||
|
||||
@scale_decay.setter
|
||||
def scale_decay(self, value: float) -> None:
|
||||
self._scale_decay = value
|
||||
self.compute_scale_decays()
|
||||
|
||||
def compute_scale_decays(self) -> None:
|
||||
self.scale_decays = [self.scale_decay ** float(12 - i) for i in range(13)]
|
||||
|
||||
|
||||
class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
|
||||
def __init__(
|
||||
self, target: SD1UNet, name: str, scale: float = 1.0, weights: dict[str, Tensor] | None = None
|
||||
self,
|
||||
target: SD1UNet,
|
||||
name: str,
|
||||
scale: float = 1.0,
|
||||
scale_decay: float = 1.0,
|
||||
weights: dict[str, Tensor] | None = None,
|
||||
) -> None:
|
||||
self.name = name
|
||||
|
||||
controlnet = Controlnet(name=name, scale=scale, device=target.device, dtype=target.dtype)
|
||||
controlnet = Controlnet(
|
||||
name=name, scale=scale, scale_decay=scale_decay, device=target.device, dtype=target.dtype
|
||||
)
|
||||
if weights is not None:
|
||||
controlnet.load_state_dict(weights)
|
||||
self._controlnet: list[Controlnet] = [controlnet] # not registered by PyTorch
|
||||
|
@ -176,11 +209,19 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
|
|||
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
return self._controlnet[0].scale
|
||||
return self.controlnet.scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, value: float) -> None:
|
||||
self._controlnet[0].scale = value
|
||||
self.controlnet.scale = value
|
||||
|
||||
@property
|
||||
def scale_decay(self) -> float:
|
||||
return self.controlnet.scale_decay
|
||||
|
||||
@scale_decay.setter
|
||||
def scale_decay(self, value: float) -> None:
|
||||
self.controlnet.scale_decay = value
|
||||
|
||||
def set_controlnet_condition(self, condition: Tensor) -> None:
|
||||
self.set_context("controlnet", {f"condition_{self.name}": condition})
|
||||
|
|
|
@ -179,6 +179,21 @@ def controlnet_data(
|
|||
yield (cn_name, condition_image, expected_image, weights_path)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=["canny"])
|
||||
def controlnet_data_scale_decay(
|
||||
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
|
||||
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
|
||||
cn_name: str = request.param
|
||||
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB")
|
||||
weights_fn = {
|
||||
"canny": "lllyasviel_control_v11p_sd15_canny",
|
||||
}
|
||||
|
||||
weights_path = test_weights_path / "controlnet" / f"{weights_fn[cn_name]}.safetensors"
|
||||
yield (cn_name, condition_image, expected_image, weights_path)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||
cn_name = "canny"
|
||||
|
@ -1076,6 +1091,50 @@ def test_diffusion_controlnet(
|
|||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_diffusion_controlnet_scale_decay(
|
||||
sd15_std: StableDiffusion_1,
|
||||
controlnet_data_scale_decay: tuple[str, Image.Image, Image.Image, Path],
|
||||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_std
|
||||
|
||||
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_scale_decay
|
||||
|
||||
if not cn_weights_path.is_file():
|
||||
warn(f"could not find weights at {cn_weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_inference_steps(30)
|
||||
|
||||
# Using default value of 0.825 chosen by lvmin
|
||||
# https://github.com/Mikubill/sd-webui-controlnet/blob/8e143d3545140b8f0398dfbe1d95a0a766019283/scripts/hook.py#L472
|
||||
controlnet = SD1ControlnetAdapter(
|
||||
sd15.unet, name=cn_name, scale=0.5, scale_decay=0.825, weights=load_from_safetensors(cn_weights_path)
|
||||
).inject()
|
||||
|
||||
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||
|
||||
for step in sd15.steps:
|
||||
controlnet.set_controlnet_condition(cn_condition)
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_diffusion_controlnet_structural_copy(
|
||||
sd15_std: StableDiffusion_1,
|
||||
|
|
|
@ -57,6 +57,8 @@ Special cases:
|
|||
- `expected_controllora_PyraCanny+CPDS.png`
|
||||
- `expected_controllora_disabled.png`
|
||||
- `expected_style_aligned.png`
|
||||
- `expected_controlnet_<name>.png` (canny|depth|lineart|normals|sam|stack)
|
||||
- `expected_controlnet_<name>_scale_decay.png` (canny)
|
||||
|
||||
## Other images
|
||||
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 432 KiB |
Loading…
Reference in a new issue