diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py index 0d50ae3..8243d9d 100644 --- a/scripts/prepare_test_weights.py +++ b/scripts/prepare_test_weights.py @@ -4,6 +4,7 @@ Download and convert weights for testing To see what weights will be downloaded and converted, run: DRY_RUN=1 python scripts/prepare_test_weights.py """ + import hashlib import os import subprocess diff --git a/src/refiners/fluxion/adapters/lora.py b/src/refiners/fluxion/adapters/lora.py index 802ee61..2868159 100644 --- a/src/refiners/fluxion/adapters/lora.py +++ b/src/refiners/fluxion/adapters/lora.py @@ -129,8 +129,7 @@ class Lora(Generic[T], fl.Chain, ABC): return loras @abstractmethod - def is_compatible(self, layer: fl.WeightedModule, /) -> bool: - ... + def is_compatible(self, layer: fl.WeightedModule, /) -> bool: ... def auto_attach( self, diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py index 767c6e1..0c5c2f9 100644 --- a/src/refiners/fluxion/layers/chain.py +++ b/src/refiners/fluxion/layers/chain.py @@ -256,16 +256,13 @@ class Chain(ContextModule): self._modules = generate_unique_names(tuple(modules)) # type: ignore @overload - def __getitem__(self, key: int) -> Module: - ... + def __getitem__(self, key: int) -> Module: ... @overload - def __getitem__(self, key: str) -> Module: - ... + def __getitem__(self, key: str) -> Module: ... @overload - def __getitem__(self, key: slice) -> "Chain": - ... + def __getitem__(self, key: slice) -> "Chain": ... def __getitem__(self, key: int | str | slice) -> Module: if isinstance(key, slice): @@ -324,16 +321,14 @@ class Chain(ContextModule): self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False, - ) -> Iterator[tuple[Module, "Chain"]]: - ... + ) -> Iterator[tuple[Module, "Chain"]]: ... @overload def walk( self, predicate: type[T], recurse: bool = False, - ) -> Iterator[tuple[T, "Chain"]]: - ... + ) -> Iterator[tuple[T, "Chain"]]: ... def walk( self, diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py index 75fc5a0..415b676 100644 --- a/src/refiners/foundationals/latent_diffusion/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py @@ -440,18 +440,15 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): self.set_context("ip_adapter", {"clip_image_embedding": image_embedding}) @overload - def compute_clip_image_embedding(self, image_prompt: Tensor, weights: list[float] | None = None) -> Tensor: - ... + def compute_clip_image_embedding(self, image_prompt: Tensor, weights: list[float] | None = None) -> Tensor: ... @overload - def compute_clip_image_embedding(self, image_prompt: Image.Image) -> Tensor: - ... + def compute_clip_image_embedding(self, image_prompt: Image.Image) -> Tensor: ... @overload def compute_clip_image_embedding( self, image_prompt: list[Image.Image], weights: list[float] | None = None - ) -> Tensor: - ... + ) -> Tensor: ... def compute_clip_image_embedding( self, diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index 79e3272..497b1bd 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -66,22 +66,18 @@ class LatentDiffusionModel(fl.Module, ABC): return self.solver.inference_steps @abstractmethod - def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: - ... + def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: ... @abstractmethod - def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None: - ... + def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None: ... @abstractmethod - def has_self_attention_guidance(self) -> bool: - ... + def has_self_attention_guidance(self) -> bool: ... @abstractmethod def compute_self_attention_guidance( self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor - ) -> Tensor: - ... + ) -> Tensor: ... def forward( self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor diff --git a/src/refiners/foundationals/latent_diffusion/multi_diffusion.py b/src/refiners/foundationals/latent_diffusion/multi_diffusion.py index 1f9c482..131357f 100644 --- a/src/refiners/foundationals/latent_diffusion/multi_diffusion.py +++ b/src/refiners/foundationals/latent_diffusion/multi_diffusion.py @@ -68,8 +68,7 @@ class MultiDiffusion(Generic[T, D], ABC): return torch.where(condition=num_updates > 0, input=cumulative_values / num_updates, other=x) @abstractmethod - def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor: - ... + def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor: ... @property def steps(self) -> list[int]: diff --git a/src/refiners/training_utils/callback.py b/src/refiners/training_utils/callback.py index cbc1472..644fd83 100644 --- a/src/refiners/training_utils/callback.py +++ b/src/refiners/training_utils/callback.py @@ -20,56 +20,38 @@ class CallbackConfig(BaseModel): class Callback(Generic[T]): - def on_init_begin(self, trainer: T) -> None: - ... + def on_init_begin(self, trainer: T) -> None: ... - def on_init_end(self, trainer: T) -> None: - ... + def on_init_end(self, trainer: T) -> None: ... - def on_train_begin(self, trainer: T) -> None: - ... + def on_train_begin(self, trainer: T) -> None: ... - def on_train_end(self, trainer: T) -> None: - ... + def on_train_end(self, trainer: T) -> None: ... - def on_epoch_begin(self, trainer: T) -> None: - ... + def on_epoch_begin(self, trainer: T) -> None: ... - def on_epoch_end(self, trainer: T) -> None: - ... + def on_epoch_end(self, trainer: T) -> None: ... - def on_batch_begin(self, trainer: T) -> None: - ... + def on_batch_begin(self, trainer: T) -> None: ... - def on_batch_end(self, trainer: T) -> None: - ... + def on_batch_end(self, trainer: T) -> None: ... - def on_backward_begin(self, trainer: T) -> None: - ... + def on_backward_begin(self, trainer: T) -> None: ... - def on_backward_end(self, trainer: T) -> None: - ... + def on_backward_end(self, trainer: T) -> None: ... - def on_optimizer_step_begin(self, trainer: T) -> None: - ... + def on_optimizer_step_begin(self, trainer: T) -> None: ... - def on_optimizer_step_end(self, trainer: T) -> None: - ... + def on_optimizer_step_end(self, trainer: T) -> None: ... - def on_compute_loss_begin(self, trainer: T) -> None: - ... + def on_compute_loss_begin(self, trainer: T) -> None: ... - def on_compute_loss_end(self, trainer: T) -> None: - ... + def on_compute_loss_end(self, trainer: T) -> None: ... - def on_evaluate_begin(self, trainer: T) -> None: - ... + def on_evaluate_begin(self, trainer: T) -> None: ... - def on_evaluate_end(self, trainer: T) -> None: - ... + def on_evaluate_end(self, trainer: T) -> None: ... - def on_lr_scheduler_step_begin(self, trainer: T) -> None: - ... + def on_lr_scheduler_step_begin(self, trainer: T) -> None: ... - def on_lr_scheduler_step_end(self, trainer: T) -> None: - ... + def on_lr_scheduler_step_end(self, trainer: T) -> None: ... diff --git a/src/refiners/training_utils/huggingface_datasets.py b/src/refiners/training_utils/huggingface_datasets.py index 63d3826..715b2da 100644 --- a/src/refiners/training_utils/huggingface_datasets.py +++ b/src/refiners/training_utils/huggingface_datasets.py @@ -10,11 +10,9 @@ T = TypeVar("T", covariant=True) class HuggingfaceDataset(Generic[T], Protocol): - def __getitem__(self, index: int) -> T: - ... + def __getitem__(self, index: int) -> T: ... - def __len__(self) -> int: - ... + def __len__(self) -> int: ... def load_hf_dataset( diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index c68a08e..8653286 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -335,8 +335,7 @@ class Trainer(Generic[ConfigType, Batch], ABC): ) @abstractmethod - def compute_loss(self, batch: Batch) -> Tensor: - ... + def compute_loss(self, batch: Batch) -> Tensor: ... def compute_evaluation(self) -> None: pass diff --git a/tests/adapters/test_ip_adapter.py b/tests/adapters/test_ip_adapter.py index 8e18ece..570a46d 100644 --- a/tests/adapters/test_ip_adapter.py +++ b/tests/adapters/test_ip_adapter.py @@ -10,13 +10,11 @@ from refiners.foundationals.latent_diffusion.image_prompt import ImageCrossAtten @overload -def new_adapter(target: SD1UNet) -> SD1IPAdapter: - ... +def new_adapter(target: SD1UNet) -> SD1IPAdapter: ... @overload -def new_adapter(target: SDXLUNet) -> SDXLIPAdapter: - ... +def new_adapter(target: SDXLUNet) -> SDXLIPAdapter: ... def new_adapter(target: SD1UNet | SDXLUNet) -> SD1IPAdapter | SDXLIPAdapter: diff --git a/tests/adapters/test_t2i_adapter.py b/tests/adapters/test_t2i_adapter.py index 3cb836e..227e57b 100644 --- a/tests/adapters/test_t2i_adapter.py +++ b/tests/adapters/test_t2i_adapter.py @@ -9,13 +9,11 @@ from refiners.foundationals.latent_diffusion.t2i_adapter import T2IFeatures @overload -def new_adapter(target: SD1UNet, name: str) -> SD1T2IAdapter: - ... +def new_adapter(target: SD1UNet, name: str) -> SD1T2IAdapter: ... @overload -def new_adapter(target: SDXLUNet, name: str) -> SDXLT2IAdapter: - ... +def new_adapter(target: SDXLUNet, name: str) -> SDXLT2IAdapter: ... def new_adapter(target: SD1UNet | SDXLUNet, name: str) -> SD1T2IAdapter | SDXLT2IAdapter: diff --git a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py index 4563c15..d738ff2 100644 --- a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py +++ b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py @@ -19,8 +19,7 @@ class DiffusersSDXL(Protocol): tokenizer_2: fl.Module vae: fl.Module - def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: - ... + def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: ... def encode_prompt( self, @@ -28,8 +27,7 @@ class DiffusersSDXL(Protocol): prompt_2: str | None = None, negative_prompt: str | None = None, negative_prompt_2: str | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - ... + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... @pytest.fixture(scope="module") diff --git a/tests/foundationals/segment_anything/utils.py b/tests/foundationals/segment_anything/utils.py index fa18e88..ca3f239 100644 --- a/tests/foundationals/segment_anything/utils.py +++ b/tests/foundationals/segment_anything/utils.py @@ -32,19 +32,16 @@ class FacebookSAM(nn.Module): prompt_encoder: nn.Module mask_decoder: nn.Module - def __call__(self, batched_input: list[SAMInput], multimask_output: bool) -> list[SAMOutput]: - ... + def __call__(self, batched_input: list[SAMInput], multimask_output: bool) -> list[SAMOutput]: ... @property - def device(self) -> Any: - ... + def device(self) -> Any: ... class FacebookSAMPredictor: model: FacebookSAM - def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: - ... + def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ... def predict( self, @@ -54,8 +51,7 @@ class FacebookSAMPredictor: mask_input: NDArray | None = None, multimask_output: bool = True, return_logits: bool = False, - ) -> tuple[NDArray, NDArray, NDArray]: - ... + ) -> tuple[NDArray, NDArray, NDArray]: ... @dataclass