From f22f969d659648f0deaf4aefce2cc6cae3651019 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Mon, 4 Dec 2023 10:49:26 +0100 Subject: [PATCH] remove Black preview mode also fix multiline logs in training --- pyproject.toml | 1 - src/refiners/fluxion/layers/chain.py | 15 ++-- .../foundationals/latent_diffusion/model.py | 12 ++- .../latent_diffusion/multi_diffusion.py | 3 +- .../latent_diffusion/schedulers/ddpm.py | 4 +- src/refiners/training_utils/callback.py | 79 ++++++++++++------- .../training_utils/huggingface_datasets.py | 6 +- src/refiners/training_utils/trainer.py | 6 +- .../test_sdxl_double_encoder.py | 15 ++-- tests/foundationals/segment_anything/utils.py | 12 ++- 10 files changed, 98 insertions(+), 55 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1af9219..aa29035 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,6 @@ build-backend = "poetry.core.masonry.api" [tool.black] line-length = 120 -preview = true [tool.ruff] ignore = [ diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py index bf99be1..d67cd6e 100644 --- a/src/refiners/fluxion/layers/chain.py +++ b/src/refiners/fluxion/layers/chain.py @@ -292,13 +292,16 @@ class Chain(ContextModule): return Chain(*self, *other) @overload - def __getitem__(self, key: int) -> Module: ... + def __getitem__(self, key: int) -> Module: + ... @overload - def __getitem__(self, key: str) -> Module: ... + def __getitem__(self, key: str) -> Module: + ... @overload - def __getitem__(self, key: slice) -> "Chain": ... + def __getitem__(self, key: slice) -> "Chain": + ... def __getitem__(self, key: int | str | slice) -> Module: if isinstance(key, slice): @@ -346,10 +349,12 @@ class Chain(ContextModule): @overload def walk( self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False - ) -> Iterator[tuple[Module, "Chain"]]: ... + ) -> Iterator[tuple[Module, "Chain"]]: + ... @overload - def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]: ... + def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]: + ... def walk( self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index 7882dba..7f5f5da 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -65,18 +65,22 @@ class LatentDiffusionModel(fl.Module, ABC): return self.scheduler.steps @abstractmethod - def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: ... + def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: + ... @abstractmethod - def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None: ... + def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None: + ... @abstractmethod - def has_self_attention_guidance(self) -> bool: ... + def has_self_attention_guidance(self) -> bool: + ... @abstractmethod def compute_self_attention_guidance( self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor - ) -> Tensor: ... + ) -> Tensor: + ... def forward( self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor diff --git a/src/refiners/foundationals/latent_diffusion/multi_diffusion.py b/src/refiners/foundationals/latent_diffusion/multi_diffusion.py index cb83411..19a0a08 100644 --- a/src/refiners/foundationals/latent_diffusion/multi_diffusion.py +++ b/src/refiners/foundationals/latent_diffusion/multi_diffusion.py @@ -69,7 +69,8 @@ class MultiDiffusion(Generic[T, D], ABC): return torch.where(condition=num_updates > 0, input=cumulative_values / num_updates, other=x) @abstractmethod - def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor: ... + def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor: + ... @property def steps(self) -> list[int]: diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py index db004fa..528d395 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py @@ -56,7 +56,9 @@ class DDPM(Scheduler): else tensor(1, device=self.device) ) current_factor = current_cumulative_factor / previous_cumulative_scale_factor - estimated_denoised_data = (x - (1 - current_cumulative_factor) ** 0.5 * noise) / current_cumulative_factor**0.5 + estimated_denoised_data = ( + x - (1 - current_cumulative_factor) ** 0.5 * noise + ) / current_cumulative_factor**0.5 estimated_denoised_data = estimated_denoised_data.clamp(-1, 1) original_data_coeff = (previous_cumulative_scale_factor**0.5 * (1 - current_factor)) / ( 1 - current_cumulative_factor diff --git a/src/refiners/training_utils/callback.py b/src/refiners/training_utils/callback.py index 8c8a715..b2ddd4d 100644 --- a/src/refiners/training_utils/callback.py +++ b/src/refiners/training_utils/callback.py @@ -42,59 +42,82 @@ T = TypeVar("T") class Callback(Generic[T]): - def on_train_begin(self, trainer: T) -> None: ... + def on_train_begin(self, trainer: T) -> None: + ... - def on_train_end(self, trainer: T) -> None: ... + def on_train_end(self, trainer: T) -> None: + ... - def on_epoch_begin(self, trainer: T) -> None: ... + def on_epoch_begin(self, trainer: T) -> None: + ... - def on_epoch_end(self, trainer: T) -> None: ... + def on_epoch_end(self, trainer: T) -> None: + ... - def on_batch_begin(self, trainer: T) -> None: ... + def on_batch_begin(self, trainer: T) -> None: + ... - def on_batch_end(self, trainer: T) -> None: ... + def on_batch_end(self, trainer: T) -> None: + ... - def on_backward_begin(self, trainer: T) -> None: ... + def on_backward_begin(self, trainer: T) -> None: + ... - def on_backward_end(self, trainer: T) -> None: ... + def on_backward_end(self, trainer: T) -> None: + ... - def on_optimizer_step_begin(self, trainer: T) -> None: ... + def on_optimizer_step_begin(self, trainer: T) -> None: + ... - def on_optimizer_step_end(self, trainer: T) -> None: ... + def on_optimizer_step_end(self, trainer: T) -> None: + ... - def on_compute_loss_begin(self, trainer: T) -> None: ... + def on_compute_loss_begin(self, trainer: T) -> None: + ... - def on_compute_loss_end(self, trainer: T) -> None: ... + def on_compute_loss_end(self, trainer: T) -> None: + ... - def on_evaluate_begin(self, trainer: T) -> None: ... + def on_evaluate_begin(self, trainer: T) -> None: + ... - def on_evaluate_end(self, trainer: T) -> None: ... + def on_evaluate_end(self, trainer: T) -> None: + ... - def on_lr_scheduler_step_begin(self, trainer: T) -> None: ... + def on_lr_scheduler_step_begin(self, trainer: T) -> None: + ... - def on_lr_scheduler_step_end(self, trainer: T) -> None: ... + def on_lr_scheduler_step_end(self, trainer: T) -> None: + ... - def on_checkpoint_save(self, trainer: T) -> None: ... + def on_checkpoint_save(self, trainer: T) -> None: + ... class ClockCallback(Callback["Trainer[BaseConfig, Any]"]): def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: trainer.clock.reset() - logger.info(f"""Starting training for a total of: - {trainer.clock.num_steps} steps. - {trainer.clock.num_epochs} epochs. - {trainer.clock.num_iterations} iterations. - """) + logger.info( + ( + "Starting training for a total of: " + f"{trainer.clock.num_steps} steps, " + f"{trainer.clock.num_epochs} epochs, " + f"{trainer.clock.num_iterations} iterations." + ) + ) trainer.clock.start_timer() def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: trainer.clock.stop_timer() - logger.info(f"""Training took: - {trainer.clock.time_elapsed} seconds. - {trainer.clock.iteration} iterations. - {trainer.clock.epoch} epochs. - {trainer.clock.step} steps. - """) + logger.info( + ( + "Training took: " + f"{trainer.clock.time_elapsed} seconds, " + f"{trainer.clock.iteration} iterations, " + f"{trainer.clock.epoch} epochs, " + f"{trainer.clock.step} steps." + ) + ) def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: logger.info(f"Epoch {trainer.clock.epoch} started.") diff --git a/src/refiners/training_utils/huggingface_datasets.py b/src/refiners/training_utils/huggingface_datasets.py index fdf6986..956ad0f 100644 --- a/src/refiners/training_utils/huggingface_datasets.py +++ b/src/refiners/training_utils/huggingface_datasets.py @@ -8,9 +8,11 @@ T = TypeVar("T", covariant=True) class HuggingfaceDataset(Generic[T], Protocol): - def __getitem__(self, index: int) -> T: ... + def __getitem__(self, index: int) -> T: + ... - def __len__(self) -> int: ... + def __len__(self) -> int: + ... def load_hf_dataset( diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 70990c0..ea079d0 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -147,13 +147,11 @@ class TrainingClock: @cached_property def unit_to_steps(self) -> dict[TimeUnit, int]: + iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation["unit"] == TimeUnit.EPOCH else 1 return { TimeUnit.STEP: 1, TimeUnit.EPOCH: self.num_batches_per_epoch, - TimeUnit.ITERATION: self.gradient_accumulation["number"] * { - TimeUnit.STEP: 1, - TimeUnit.EPOCH: self.num_batches_per_epoch, - }.get(self.gradient_accumulation["unit"], 1), + TimeUnit.ITERATION: self.gradient_accumulation["number"] * iteration_factor, } def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int: diff --git a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py index 29420c32..1533a9d 100644 --- a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py +++ b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py @@ -18,7 +18,8 @@ class DiffusersSDXL(Protocol): tokenizer_2: fl.Module vae: fl.Module - def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: ... + def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: + ... def encode_prompt( self, @@ -26,7 +27,8 @@ class DiffusersSDXL(Protocol): prompt_2: str | None = None, negative_prompt: str | None = None, negative_prompt_2: str | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ... + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ... @pytest.fixture(scope="module") @@ -67,9 +69,12 @@ def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: manual_seed(seed=0) prompt = "A photo of a pizza." - prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = ( - diffusers_sdxl.encode_prompt(prompt=prompt, negative_prompt="") - ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = diffusers_sdxl.encode_prompt(prompt=prompt, negative_prompt="") double_embedding, pooled_embedding = double_text_encoder(prompt) diff --git a/tests/foundationals/segment_anything/utils.py b/tests/foundationals/segment_anything/utils.py index 37085e2..274726c 100644 --- a/tests/foundationals/segment_anything/utils.py +++ b/tests/foundationals/segment_anything/utils.py @@ -32,16 +32,19 @@ class FacebookSAM(nn.Module): prompt_encoder: nn.Module mask_decoder: nn.Module - def __call__(self, batched_input: list[SAMInput], multimask_output: bool) -> list[SAMOutput]: ... + def __call__(self, batched_input: list[SAMInput], multimask_output: bool) -> list[SAMOutput]: + ... @property - def device(self) -> Any: ... + def device(self) -> Any: + ... class FacebookSAMPredictor: model: FacebookSAM - def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ... + def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: + ... def predict( self, @@ -51,7 +54,8 @@ class FacebookSAMPredictor: mask_input: NDArray | None = None, multimask_output: bool = True, return_logits: bool = False, - ) -> tuple[NDArray, NDArray, NDArray]: ... + ) -> tuple[NDArray, NDArray, NDArray]: + ... @dataclass