mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
Add scale_decay parameter for SD1 ControlNet
This commit is contained in:
parent
567cc9c6d9
commit
15ccdb38f3
|
@ -70,8 +70,15 @@ class ConditionEncoder(Chain):
|
||||||
|
|
||||||
|
|
||||||
class Controlnet(Passthrough):
|
class Controlnet(Passthrough):
|
||||||
|
scale_decays: list[float]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
|
self,
|
||||||
|
name: str,
|
||||||
|
scale: float = 1.0,
|
||||||
|
scale_decay: float = 1.0,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Controlnet is a Half-UNet that collects residuals from the UNet and uses them to condition the UNet.
|
"""Controlnet is a Half-UNet that collects residuals from the UNet and uses them to condition the UNet.
|
||||||
|
|
||||||
|
@ -79,9 +86,16 @@ class Controlnet(Passthrough):
|
||||||
stored in the context.
|
stored in the context.
|
||||||
|
|
||||||
It has to use the same context as the UNet: `unet` and `sampling`.
|
It has to use the same context as the UNet: `unet` and `sampling`.
|
||||||
|
|
||||||
|
A scale decay of 0.825 corresponds to the "Prompt is more important" Control Mode of the sd-webui-controlnet plugin:
|
||||||
|
https://github.com/Mikubill/sd-webui-controlnet/blob/8e143d3545140b8f0398dfbe1d95a0a766019283/scripts/hook.py#L472
|
||||||
|
See also the so-called "Guess Mode" in the official ControlNet demos, which uses such scales:
|
||||||
|
https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode
|
||||||
"""
|
"""
|
||||||
self.name = name
|
self.name = name
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
self._scale_decay = scale_decay
|
||||||
|
self.compute_scale_decays()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
|
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
|
||||||
Slicing(dim=1, end=4), # support inpainting
|
Slicing(dim=1, end=4), # support inpainting
|
||||||
|
@ -134,19 +148,38 @@ class Controlnet(Passthrough):
|
||||||
def _store_nth_residual(self, n: int):
|
def _store_nth_residual(self, n: int):
|
||||||
def _store_residual(x: Tensor):
|
def _store_residual(x: Tensor):
|
||||||
residuals = self.use_context("unet")["residuals"]
|
residuals = self.use_context("unet")["residuals"]
|
||||||
residuals[n] = residuals[n] + x * self.scale
|
residuals[n] = residuals[n] + x * self.scale * self.scale_decays[n]
|
||||||
return x
|
return x
|
||||||
|
|
||||||
return _store_residual
|
return _store_residual
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale_decay(self) -> float:
|
||||||
|
return self._scale_decay
|
||||||
|
|
||||||
|
@scale_decay.setter
|
||||||
|
def scale_decay(self, value: float) -> None:
|
||||||
|
self._scale_decay = value
|
||||||
|
self.compute_scale_decays()
|
||||||
|
|
||||||
|
def compute_scale_decays(self) -> None:
|
||||||
|
self.scale_decays = [self.scale_decay ** float(12 - i) for i in range(13)]
|
||||||
|
|
||||||
|
|
||||||
class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
|
class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, target: SD1UNet, name: str, scale: float = 1.0, weights: dict[str, Tensor] | None = None
|
self,
|
||||||
|
target: SD1UNet,
|
||||||
|
name: str,
|
||||||
|
scale: float = 1.0,
|
||||||
|
scale_decay: float = 1.0,
|
||||||
|
weights: dict[str, Tensor] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
controlnet = Controlnet(name=name, scale=scale, device=target.device, dtype=target.dtype)
|
controlnet = Controlnet(
|
||||||
|
name=name, scale=scale, scale_decay=scale_decay, device=target.device, dtype=target.dtype
|
||||||
|
)
|
||||||
if weights is not None:
|
if weights is not None:
|
||||||
controlnet.load_state_dict(weights)
|
controlnet.load_state_dict(weights)
|
||||||
self._controlnet: list[Controlnet] = [controlnet] # not registered by PyTorch
|
self._controlnet: list[Controlnet] = [controlnet] # not registered by PyTorch
|
||||||
|
@ -176,11 +209,19 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scale(self) -> float:
|
def scale(self) -> float:
|
||||||
return self._controlnet[0].scale
|
return self.controlnet.scale
|
||||||
|
|
||||||
@scale.setter
|
@scale.setter
|
||||||
def scale(self, value: float) -> None:
|
def scale(self, value: float) -> None:
|
||||||
self._controlnet[0].scale = value
|
self.controlnet.scale = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale_decay(self) -> float:
|
||||||
|
return self.controlnet.scale_decay
|
||||||
|
|
||||||
|
@scale_decay.setter
|
||||||
|
def scale_decay(self, value: float) -> None:
|
||||||
|
self.controlnet.scale_decay = value
|
||||||
|
|
||||||
def set_controlnet_condition(self, condition: Tensor) -> None:
|
def set_controlnet_condition(self, condition: Tensor) -> None:
|
||||||
self.set_context("controlnet", {f"condition_{self.name}": condition})
|
self.set_context("controlnet", {f"condition_{self.name}": condition})
|
||||||
|
|
|
@ -179,6 +179,21 @@ def controlnet_data(
|
||||||
yield (cn_name, condition_image, expected_image, weights_path)
|
yield (cn_name, condition_image, expected_image, weights_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", params=["canny"])
|
||||||
|
def controlnet_data_scale_decay(
|
||||||
|
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
|
||||||
|
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
|
||||||
|
cn_name: str = request.param
|
||||||
|
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||||
|
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB")
|
||||||
|
weights_fn = {
|
||||||
|
"canny": "lllyasviel_control_v11p_sd15_canny",
|
||||||
|
}
|
||||||
|
|
||||||
|
weights_path = test_weights_path / "controlnet" / f"{weights_fn[cn_name]}.safetensors"
|
||||||
|
yield (cn_name, condition_image, expected_image, weights_path)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||||
cn_name = "canny"
|
cn_name = "canny"
|
||||||
|
@ -1076,6 +1091,50 @@ def test_diffusion_controlnet(
|
||||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_diffusion_controlnet_scale_decay(
|
||||||
|
sd15_std: StableDiffusion_1,
|
||||||
|
controlnet_data_scale_decay: tuple[str, Image.Image, Image.Image, Path],
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
sd15 = sd15_std
|
||||||
|
|
||||||
|
cn_name, condition_image, expected_image, cn_weights_path = controlnet_data_scale_decay
|
||||||
|
|
||||||
|
if not cn_weights_path.is_file():
|
||||||
|
warn(f"could not find weights at {cn_weights_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
|
|
||||||
|
sd15.set_inference_steps(30)
|
||||||
|
|
||||||
|
# Use scale_decay=0.825, the value chosen by lvmin in sd-webui-controlnet (the parameter itself defaults to 1.0):
|
||||||
|
# https://github.com/Mikubill/sd-webui-controlnet/blob/8e143d3545140b8f0398dfbe1d95a0a766019283/scripts/hook.py#L472
|
||||||
|
controlnet = SD1ControlnetAdapter(
|
||||||
|
sd15.unet, name=cn_name, scale=0.5, scale_decay=0.825, weights=load_from_safetensors(cn_weights_path)
|
||||||
|
).inject()
|
||||||
|
|
||||||
|
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
|
||||||
|
for step in sd15.steps:
|
||||||
|
controlnet.set_controlnet_condition(cn_condition)
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.latents_to_image(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_controlnet_structural_copy(
|
def test_diffusion_controlnet_structural_copy(
|
||||||
sd15_std: StableDiffusion_1,
|
sd15_std: StableDiffusion_1,
|
||||||
|
|
|
@ -57,6 +57,8 @@ Special cases:
|
||||||
- `expected_controllora_PyraCanny+CPDS.png`
|
- `expected_controllora_PyraCanny+CPDS.png`
|
||||||
- `expected_controllora_disabled.png`
|
- `expected_controllora_disabled.png`
|
||||||
- `expected_style_aligned.png`
|
- `expected_style_aligned.png`
|
||||||
|
- `expected_controlnet_<name>.png` (canny|depth|lineart|normals|sam|stack)
|
||||||
|
- `expected_controlnet_<name>_scale_decay.png` (canny)
|
||||||
|
|
||||||
## Other images
|
## Other images
|
||||||
|
|
||||||
|
|
Binary file not shown.
After Width: | Height: | Size: 432 KiB |
Loading…
Reference in a new issue