use same scale setter / getter interface for all controls

This commit is contained in:
Pierre Chapuis 2024-03-08 10:56:27 +01:00
parent 5e7986ef08
commit 1eb71077aa
4 changed files with 17 additions and 11 deletions

View file

@ -424,10 +424,6 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
for cross_attn in self.sub_adapters:
cross_attn.scale = value
def set_scale(self, scale: float) -> None:
for cross_attn in self.sub_adapters:
cross_attn.scale = scale
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
"""Set the CLIP image embedding context.

View file

@ -174,8 +174,13 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
def init_context(self) -> Contexts:
return {"controlnet": {f"condition_{self.name}": None}}
def set_scale(self, scale: float) -> None:
self._controlnet[0].scale = scale
@property
def scale(self) -> float:
return self._controlnet[0].scale
@scale.setter
def scale(self, value: float) -> None:
self._controlnet[0].scale = value
def set_controlnet_condition(self, condition: Tensor) -> None:
self.set_context("controlnet", {f"condition_{self.name}": condition})

View file

@ -204,9 +204,14 @@ class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
def set_condition_features(self, features: tuple[Tensor, ...]) -> None:
self.set_context("t2iadapter", {f"condition_features_{self.name}": features})
def set_scale(self, scale: float) -> None:
@property
def scale(self) -> float:
return self._features[0].scale
@scale.setter
def scale(self, value: float) -> None:
for f in self._features:
f.scale = scale
f.scale = value
def init_context(self) -> Contexts:
return {"t2iadapter": {f"condition_features_{self.name}": None}}

View file

@ -2109,7 +2109,7 @@ def test_t2i_adapter_xl_canny(
sdxl.set_inference_steps(30)
t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
t2i_adapter.set_scale(0.8)
t2i_adapter.scale = 0.8
condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
@ -2238,8 +2238,8 @@ def test_hello_world(
condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
ip_adapter.set_scale(0.85)
t2i_adapter.set_scale(0.8)
ip_adapter.scale = 0.85
t2i_adapter.scale = 0.8
sdxl.set_inference_steps(50, first_step=1)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)