From 0e0c39b4b5146bdf3aaf08f04f6580d3048b540d Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Wed, 13 Sep 2023 16:27:08 +0200 Subject: [PATCH] black --- src/refiners/fluxion/layers/chain.py | 15 ++---- .../foundationals/latent_diffusion/model.py | 3 +- src/refiners/training_utils/callback.py | 51 +++++++------------ .../training_utils/huggingface_datasets.py | 6 +-- .../test_sdxl_double_encoder.py | 6 +-- 5 files changed, 27 insertions(+), 54 deletions(-) diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py index b9853a3..374cf18 100644 --- a/src/refiners/fluxion/layers/chain.py +++ b/src/refiners/fluxion/layers/chain.py @@ -213,16 +213,13 @@ class Chain(ContextModule): return Chain(*self, *other) @overload - def __getitem__(self, key: int) -> Module: - ... + def __getitem__(self, key: int) -> Module: ... @overload - def __getitem__(self, key: str) -> Module: - ... + def __getitem__(self, key: str) -> Module: ... @overload - def __getitem__(self, key: slice) -> "Chain": - ... + def __getitem__(self, key: slice) -> "Chain": ... def __getitem__(self, key: int | str | slice) -> Module: if isinstance(key, slice): @@ -270,12 +267,10 @@ class Chain(ContextModule): @overload def walk( self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False - ) -> Iterator[tuple[Module, "Chain"]]: - ... + ) -> Iterator[tuple[Module, "Chain"]]: ... @overload - def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]: - ... + def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]: ... def walk( self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index fcdd298..952fa90 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -66,8 +66,7 @@ class LatentDiffusionModel(fl.Module, ABC): return self.scheduler.steps @abstractmethod - def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: - ... + def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: ... def forward( self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor diff --git a/src/refiners/training_utils/callback.py b/src/refiners/training_utils/callback.py index d1d04f0..8c8a715 100644 --- a/src/refiners/training_utils/callback.py +++ b/src/refiners/training_utils/callback.py @@ -42,56 +42,39 @@ T = TypeVar("T") class Callback(Generic[T]): - def on_train_begin(self, trainer: T) -> None: - ... + def on_train_begin(self, trainer: T) -> None: ... - def on_train_end(self, trainer: T) -> None: - ... + def on_train_end(self, trainer: T) -> None: ... - def on_epoch_begin(self, trainer: T) -> None: - ... + def on_epoch_begin(self, trainer: T) -> None: ... - def on_epoch_end(self, trainer: T) -> None: - ... + def on_epoch_end(self, trainer: T) -> None: ... - def on_batch_begin(self, trainer: T) -> None: - ... + def on_batch_begin(self, trainer: T) -> None: ... - def on_batch_end(self, trainer: T) -> None: - ... + def on_batch_end(self, trainer: T) -> None: ... - def on_backward_begin(self, trainer: T) -> None: - ... + def on_backward_begin(self, trainer: T) -> None: ... - def on_backward_end(self, trainer: T) -> None: - ... + def on_backward_end(self, trainer: T) -> None: ... - def on_optimizer_step_begin(self, trainer: T) -> None: - ... + def on_optimizer_step_begin(self, trainer: T) -> None: ... - def on_optimizer_step_end(self, trainer: T) -> None: - ... + def on_optimizer_step_end(self, trainer: T) -> None: ... - def on_compute_loss_begin(self, trainer: T) -> None: - ... + def on_compute_loss_begin(self, trainer: T) -> None: ... - def on_compute_loss_end(self, trainer: T) -> None: - ... + def on_compute_loss_end(self, trainer: T) -> None: ... - def on_evaluate_begin(self, trainer: T) -> None: - ... + def on_evaluate_begin(self, trainer: T) -> None: ... - def on_evaluate_end(self, trainer: T) -> None: - ... + def on_evaluate_end(self, trainer: T) -> None: ... - def on_lr_scheduler_step_begin(self, trainer: T) -> None: - ... + def on_lr_scheduler_step_begin(self, trainer: T) -> None: ... - def on_lr_scheduler_step_end(self, trainer: T) -> None: - ... + def on_lr_scheduler_step_end(self, trainer: T) -> None: ... - def on_checkpoint_save(self, trainer: T) -> None: - ... + def on_checkpoint_save(self, trainer: T) -> None: ... class ClockCallback(Callback["Trainer[BaseConfig, Any]"]): diff --git a/src/refiners/training_utils/huggingface_datasets.py b/src/refiners/training_utils/huggingface_datasets.py index 956ad0f..fdf6986 100644 --- a/src/refiners/training_utils/huggingface_datasets.py +++ b/src/refiners/training_utils/huggingface_datasets.py @@ -8,11 +8,9 @@ T = TypeVar("T", covariant=True) class HuggingfaceDataset(Generic[T], Protocol): - def __getitem__(self, index: int) -> T: - ... + def __getitem__(self, index: int) -> T: ... - def __len__(self) -> int: - ... + def __len__(self) -> int: ... def load_hf_dataset( diff --git a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py index 172820f..29420c32 100644 --- a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py +++ b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py @@ -18,8 +18,7 @@ class DiffusersSDXL(Protocol): tokenizer_2: fl.Module vae: fl.Module - def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: - ... + def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: ... def encode_prompt( self, @@ -27,8 +26,7 @@ class DiffusersSDXL(Protocol): prompt_2: str | None = None, negative_prompt: str | None = None, negative_prompt_2: str | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - ... + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... @pytest.fixture(scope="module")