mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
use same scale setter / getter interface for all controls
This commit is contained in:
parent
5e7986ef08
commit
1eb71077aa
|
@ -424,10 +424,6 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
for cross_attn in self.sub_adapters:
|
||||
cross_attn.scale = value
|
||||
|
||||
def set_scale(self, scale: float) -> None:
|
||||
for cross_attn in self.sub_adapters:
|
||||
cross_attn.scale = scale
|
||||
|
||||
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
|
||||
"""Set the CLIP image embedding context.
|
||||
|
||||
|
|
|
@ -174,8 +174,13 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
|
|||
def init_context(self) -> Contexts:
|
||||
return {"controlnet": {f"condition_{self.name}": None}}
|
||||
|
||||
def set_scale(self, scale: float) -> None:
|
||||
self._controlnet[0].scale = scale
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
return self._controlnet[0].scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, value: float) -> None:
|
||||
self._controlnet[0].scale = value
|
||||
|
||||
def set_controlnet_condition(self, condition: Tensor) -> None:
|
||||
self.set_context("controlnet", {f"condition_{self.name}": condition})
|
||||
|
|
|
@ -204,9 +204,14 @@ class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
def set_condition_features(self, features: tuple[Tensor, ...]) -> None:
|
||||
self.set_context("t2iadapter", {f"condition_features_{self.name}": features})
|
||||
|
||||
def set_scale(self, scale: float) -> None:
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
return self._features[0].scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, value: float) -> None:
|
||||
for f in self._features:
|
||||
f.scale = scale
|
||||
f.scale = value
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {"t2iadapter": {f"condition_features_{self.name}": None}}
|
||||
|
|
|
@ -2109,7 +2109,7 @@ def test_t2i_adapter_xl_canny(
|
|||
sdxl.set_inference_steps(30)
|
||||
|
||||
t2i_adapter = SDXLT2IAdapter(target=sdxl.unet, name=name, weights=load_from_safetensors(weights_path)).inject()
|
||||
t2i_adapter.set_scale(0.8)
|
||||
t2i_adapter.scale = 0.8
|
||||
|
||||
condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
|
||||
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
|
||||
|
@ -2238,8 +2238,8 @@ def test_hello_world(
|
|||
condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
|
||||
t2i_adapter.set_condition_features(features=t2i_adapter.compute_condition_features(condition))
|
||||
|
||||
ip_adapter.set_scale(0.85)
|
||||
t2i_adapter.set_scale(0.8)
|
||||
ip_adapter.scale = 0.85
|
||||
t2i_adapter.scale = 0.8
|
||||
sdxl.set_inference_steps(50, first_step=1)
|
||||
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
||||
|
||||
|
|
Loading…
Reference in a new issue