diff --git a/scripts/training/finetune-ldm-textual-inversion.py b/scripts/training/finetune-ldm-textual-inversion.py deleted file mode 100644 index 18142b3..0000000 --- a/scripts/training/finetune-ldm-textual-inversion.py +++ /dev/null @@ -1,162 +0,0 @@ -import random -from typing import Any - -from loguru import logger -from pydantic import BaseModel -from torch import Tensor, randn -from torch.utils.data import Dataset - -from refiners.fluxion.utils import save_to_safetensors -from refiners.foundationals.clip.concepts import ConceptExtender, EmbeddingExtender -from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder -from refiners.foundationals.clip.tokenizer import CLIPTokenizer -from refiners.training_utils.callback import Callback -from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig -from refiners.training_utils.latent_diffusion import ( - FinetuneLatentDiffusionConfig, - LatentDiffusionConfig, - LatentDiffusionTrainer, - TextEmbeddingLatentsBatch, - TextEmbeddingLatentsDataset, -) - -IMAGENET_TEMPLATES_SMALL = [ - "a photo of a {}", - "a rendering of a {}", - "a cropped photo of the {}", - "the photo of a {}", - "a photo of a clean {}", - "a photo of a dirty {}", - "a dark photo of the {}", - "a photo of my {}", - "a photo of the cool {}", - "a close-up photo of a {}", - "a bright photo of the {}", - "a cropped photo of a {}", - "a photo of the {}", - "a good photo of the {}", - "a photo of one {}", - "a close-up photo of the {}", - "a rendition of the {}", - "a photo of the clean {}", - "a rendition of a {}", - "a photo of a nice {}", - "a good photo of a {}", - "a photo of the nice {}", - "a photo of the small {}", - "a photo of the weird {}", - "a photo of the large {}", - "a photo of a cool {}", - "a photo of a small {}", -] - -IMAGENET_STYLE_TEMPLATES_SMALL = [ - "a painting in the style of {}", - "a rendering in the style of {}", - "a cropped painting in the style of {}", - "the painting in the style of {}", - "a clean painting in the style of {}", - "a dirty painting in the style of {}", - "a dark painting in the style of {}", - "a picture in the style of {}", - "a cool painting in the style of {}", - "a close-up painting in the style of {}", - "a bright painting in the style of {}", - "a cropped painting in the style of {}", - "a good painting in the style of {}", - "a close-up painting in the style of {}", - "a rendition in the style of {}", - "a nice painting in the style of {}", - "a small painting in the style of {}", - "a weird painting in the style of {}", - "a large painting in the style of {}", -] - - -class TextualInversionDataset(TextEmbeddingLatentsDataset): - templates: list[str] = [] - placeholder_token: str = "" - - def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None: - super().__init__(trainer) - self.templates = ( - IMAGENET_STYLE_TEMPLATES_SMALL if self.config.textual_inversion.style_mode else IMAGENET_TEMPLATES_SMALL - ) - self.placeholder_token = self.config.textual_inversion.placeholder_token - - def get_caption(self, index: int) -> str: - # Ignore the dataset caption, if any: use a template instead - return random.choice(self.templates).format(self.placeholder_token) - - -class TextualInversionConfig(BaseModel): - # The new token to be learned - placeholder_token: str = "*" - # The token to be used as initializer; if None, a random vector is used - initializer_token: str | None = None - style_mode: bool = False - - def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> None: - adapter = ConceptExtender(target=text_encoder) - tokenizer = text_encoder.ensure_find(CLIPTokenizer) - token_encoder = text_encoder.ensure_find(TokenEncoder) - if self.initializer_token is not None: - bpe = tokenizer.byte_pair_encoding(token=self.initializer_token) - assert " " not in bpe, "This initializer_token is not a single token." - token = Tensor([tokenizer.token_to_id_mapping[bpe]]).int().to(text_encoder.device) - init_embedding = token_encoder(token).squeeze(0) - else: - token_encoder = text_encoder.ensure_find(TokenEncoder) - init_embedding = randn(token_encoder.embedding_dim) - adapter.add_concept(self.placeholder_token, init_embedding) - adapter.inject() - - -class TextualInversionLatentDiffusionConfig(FinetuneLatentDiffusionConfig): - dataset: HuggingfaceDatasetConfig - latent_diffusion: LatentDiffusionConfig - textual_inversion: TextualInversionConfig - - def model_post_init(self, __context: Any) -> None: - # Pydantic v2 does post init differently, so we need to override this method too. - logger.info("Freezing models to train only the new embedding.") - self.models["unet"].train = False - self.models["text_encoder"].train = False - self.models["lda"].train = False - - -class TextualInversionLatentDiffusionTrainer(LatentDiffusionTrainer[TextualInversionLatentDiffusionConfig]): - def __init__( - self, - config: TextualInversionLatentDiffusionConfig, - callbacks: "list[Callback[Any]] | None" = None, - ) -> None: - super().__init__(config=config, callbacks=callbacks) - self.callbacks.extend((LoadTextualInversion(), SaveTextualInversion())) - - def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]: - return TextualInversionDataset(trainer=self) - - -class LoadTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]): - def on_train_begin(self, trainer: TextualInversionLatentDiffusionTrainer) -> None: - trainer.config.textual_inversion.apply_textual_inversion_to_target(text_encoder=trainer.text_encoder) - - -class SaveTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]): - def on_checkpoint_save(self, trainer: TextualInversionLatentDiffusionTrainer) -> None: - embedding_extender = trainer.text_encoder.ensure_find(EmbeddingExtender) - tensors = {trainer.config.textual_inversion.placeholder_token: embedding_extender.new_weight.squeeze(0)} - - save_to_safetensors( - path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors", tensors=tensors - ) - - -if __name__ == "__main__": - import sys - - config_path = sys.argv[1] - config = TextualInversionLatentDiffusionConfig.load_from_toml(toml_path=config_path) - trainer = TextualInversionLatentDiffusionTrainer(config=config) - trainer.train() diff --git a/src/refiners/training_utils/callback.py b/src/refiners/training_utils/callback.py index f1471ec..2d7c0b2 100644 --- a/src/refiners/training_utils/callback.py +++ b/src/refiners/training_utils/callback.py @@ -13,8 +13,6 @@ __all__ = [ "GradientNormClipping", "GradientValueClipping", "ClockCallback", - "GradientNormLogging", - "MonitorLoss", ] @@ -153,31 +151,6 @@ class ClockCallback(Callback["Trainer[BaseConfig, Any]"]): logger.info("Evaluation ended.") -class MonitorLoss(Callback["Trainer[BaseConfig, Any]"]): - def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: - self.epoch_losses: list[float] = [] - self.iteration_losses: list[float] = [] - - def on_compute_loss_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - loss_value = trainer.loss.detach().cpu().item() - self.epoch_losses.append(loss_value) - self.iteration_losses.append(loss_value) - trainer.log(data={"step_loss": loss_value}) - - def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses) - trainer.log(data={"average_iteration_loss": avg_iteration_loss}) - self.iteration_losses = [] - - def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - avg_epoch_loss = sum(self.epoch_losses) / len(self.epoch_losses) - trainer.log(data={"average_epoch_loss": avg_epoch_loss, "epoch": trainer.clock.epoch}) - self.epoch_losses = [] - - def on_lr_scheduler_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - trainer.log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]}) - - class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]): def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: clip_norm = trainer.config.training.clip_grad_norm @@ -192,8 +165,3 @@ class GradientValueClipping(Callback["Trainer[BaseConfig, Any]"]): clip_value = trainer.config.training.clip_grad_value if clip_value is not None: clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value) - - -class GradientNormLogging(Callback["Trainer[BaseConfig, Any]"]): - def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - trainer.log(data={"total_grad_norm": trainer.total_gradient_norm}) diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index c484b6a..0901f5d 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -196,37 +196,15 @@ class DropoutConfig(BaseModel): apply_dropout(module=model, probability=self.dropout_probability) -class WandbConfig(BaseModel): - mode: Literal["online", "offline", "disabled"] = "online" - project: str - entity: str = "finegrain" - name: str | None = None - tags: list[str] = [] - group: str | None = None - job_type: str | None = None - notes: str | None = None - - -class CheckpointingConfig(BaseModel): - save_folder: Path | None = None - save_interval: TimeValue = {"number": 1, "unit": TimeUnit.EPOCH} - - @validator("save_interval", pre=True) - def parse_field(cls, value: Any) -> TimeValue: - return parse_number_unit_field(value) - - T = TypeVar("T", bound="BaseConfig") class BaseConfig(BaseModel): models: dict[str, ModelConfig] - wandb: WandbConfig training: TrainingConfig optimizer: OptimizerConfig scheduler: SchedulerConfig dropout: DropoutConfig - checkpointing: CheckpointingConfig @classmethod def load_from_toml(cls: Type[T], toml_path: Path | str) -> T: diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index 28aef44..9600ce6 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -204,7 +204,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]): ) canvas_image.paste(sd.lda.latents_to_image(x=x), box=(0, 512 * i)) images[prompt] = canvas_image - self.log(data=images) + # TODO: wandb_log images def sample_noise( @@ -257,4 +257,4 @@ class MonitorTimestepLoss(Callback[LatentDiffusionTrainer[Any]]): log_data[f"average_loss_timestep_bin_{bin_index * 100}"] = avg_loss self.timestep_bins[bin_index] = [] - trainer.log(data=log_data) + # TODO: wandb_log log_data diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 8344a99..0d0e72d 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -2,7 +2,6 @@ import random import time from abc import ABC, abstractmethod from functools import cached_property, wraps -from pathlib import Path from typing import Any, Callable, Generic, Iterable, TypeVar, cast import numpy as np @@ -33,13 +32,10 @@ from refiners.training_utils.callback import ( Callback, ClockCallback, GradientNormClipping, - GradientNormLogging, GradientValueClipping, - MonitorLoss, ) from refiners.training_utils.config import BaseConfig, SchedulerType, TimeUnit, TimeValue from refiners.training_utils.dropout import DropoutCallback -from refiners.training_utils.wandb import WandbLoggable, WandbLogger __all__ = ["seed_everything", "scoped_seed", "Trainer"] @@ -130,7 +126,6 @@ class TrainingClock: gradient_accumulation: TimeValue, evaluation_interval: TimeValue, lr_scheduler_interval: TimeValue, - checkpointing_save_interval: TimeValue, ) -> None: self.dataset_length = dataset_length self.batch_size = batch_size @@ -138,7 +133,6 @@ class TrainingClock: self.gradient_accumulation = gradient_accumulation self.evaluation_interval = evaluation_interval self.lr_scheduler_interval = lr_scheduler_interval - self.checkpointing_save_interval = checkpointing_save_interval self.num_batches_per_epoch = dataset_length // batch_size self.start_time = None self.end_time = None @@ -225,12 +219,6 @@ class TrainingClock: number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"] ) - @cached_property - def checkpointing_save_interval_steps(self) -> int: - return self.convert_time_unit_to_steps( - number=self.checkpointing_save_interval["number"], unit=self.checkpointing_save_interval["unit"] - ) - @property def is_optimizer_step(self) -> bool: return self.num_minibatches_processed == self.num_step_per_iteration @@ -247,10 +235,6 @@ class TrainingClock: def is_evaluation_step(self) -> bool: return self.step % self.evaluation_interval_steps == 0 - @property - def is_checkpointing_step(self) -> bool: - return self.step % self.checkpointing_save_interval_steps == 0 - def compute_grad_norm(parameters: Iterable[Parameter]) -> float: """ @@ -276,27 +260,35 @@ class Trainer(Generic[ConfigType, Batch], ABC): evaluation_interval=config.training.evaluation_interval, gradient_accumulation=config.training.gradient_accumulation, lr_scheduler_interval=config.scheduler.update_interval, - checkpointing_save_interval=config.checkpointing.save_interval, ) self.callbacks = callbacks or [] self.callbacks += self.default_callbacks() self._call_callbacks(event_name="on_init_begin") - self.load_wandb() self.load_models() self.prepare_models() - self.prepare_checkpointing() self._call_callbacks(event_name="on_init_end") def default_callbacks(self) -> list[Callback[Any]]: - return [ + callbacks: list[Callback[Any]] = [ ClockCallback(), - MonitorLoss(), - GradientNormLogging(), GradientValueClipping(), GradientNormClipping(), DropoutCallback(), ] + # look for any Callback that might be a property of the Trainer + for attr_name in dir(self): + if "__" in attr_name: + continue + + try: + attr = getattr(self, attr_name) + except AssertionError: + continue + if isinstance(attr, Callback): + callbacks.append(cast(Callback[Any], attr)) + return callbacks + @cached_property def device(self) -> Device: selected_device = Device(self.config.training.device) @@ -417,13 +409,6 @@ class Trainer(Generic[ConfigType, Batch], ABC): for model in self.models.values(): model.eval() - def log(self, data: dict[str, WandbLoggable]) -> None: - self.wandb.log(data=data, step=self.clock.step) - - def load_wandb(self) -> None: - init_config = {**self.config.wandb.model_dump(), "config": self.config.model_dump()} - self.wandb = WandbLogger(init_config=init_config) - def prepare_model(self, model_name: str) -> None: model = self.models[model_name] if (checkpoint := self.config.models[model_name].checkpoint) is not None: @@ -439,18 +424,6 @@ class Trainer(Generic[ConfigType, Batch], ABC): for model_name in self.models: self.prepare_model(model_name=model_name) - def prepare_checkpointing(self) -> None: - if self.config.checkpointing.save_folder is not None: - assert self.config.checkpointing.save_folder.is_dir() - self.checkpoints_save_folder = ( - self.config.checkpointing.save_folder / self.wandb.project_name / self.wandb.run_name - ) - self.checkpoints_save_folder.mkdir(parents=True, exist_ok=False) - logger.info(f"Checkpointing enabled: {self.checkpoints_save_folder}") - else: - self.checkpoints_save_folder = None - logger.info("Checkpointing disabled: configure `save_folder` to turn it on.") - @abstractmethod def load_models(self) -> dict[str, fl.Module]: ... @@ -475,15 +448,6 @@ class Trainer(Generic[ConfigType, Batch], ABC): dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=collate_fn ) - @property - def checkpointing_enabled(self) -> bool: - return self.checkpoints_save_folder is not None - - @property - def ensure_checkpoints_save_folder(self) -> Path: - assert self.checkpoints_save_folder is not None - return self.checkpoints_save_folder - @abstractmethod def compute_loss(self, batch: Batch) -> Tensor: ... @@ -508,8 +472,6 @@ class Trainer(Generic[ConfigType, Batch], ABC): self._call_callbacks(event_name="on_lr_scheduler_step_end") if self.clock.is_evaluation_step: self.evaluate() - if self.checkpointing_enabled and self.clock.is_checkpointing_step: - self._call_callbacks(event_name="on_checkpoint_save") def step(self, batch: Batch) -> None: """Perform a single training step.""" diff --git a/src/refiners/training_utils/wandb.py b/src/refiners/training_utils/wandb.py index 2683f89..efe2b4e 100644 --- a/src/refiners/training_utils/wandb.py +++ b/src/refiners/training_utils/wandb.py @@ -1,13 +1,15 @@ -from typing import Any +from abc import ABC +from functools import cached_property +from pathlib import Path +from typing import Any, Literal import wandb from PIL import Image +from pydantic import BaseModel -__all__ = [ - "WandbLogger", - "WandbLoggable", -] - +from refiners.training_utils.callback import Callback +from refiners.training_utils.config import BaseConfig +from refiners.training_utils.trainer import Trainer number = float | int WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]] @@ -60,3 +62,87 @@ class WandbLogger: @property def run_name(self) -> str: return self.wandb_run.name or "" # type: ignore + + +class WandbConfig(BaseModel): + """ + Wandb configuration. + + See https://docs.wandb.ai/ref/python/init for more details. + """ + + mode: Literal["online", "offline", "disabled"] = "disabled" + project: str + entity: str | None = None + save_code: bool | None = None + name: str | None = None + tags: list[str] = [] + group: str | None = None + job_type: str | None = None + notes: str | None = None + dir: Path | None = None + resume: bool | Literal["allow", "must", "never", "auto"] | None = None + reinit: bool | None = None + magic: bool | None = None + anonymous: Literal["never", "allow", "must"] | None = None + id: str | None = None + + +AnyTrainer = Trainer[BaseConfig, Any] + + +class WandbCallback(Callback["TrainerWithWandb"]): + epoch_losses: list[float] + iteration_losses: list[float] + + def on_init_begin(self, trainer: "TrainerWithWandb") -> None: + trainer.load_wandb() + + def on_train_begin(self, trainer: "TrainerWithWandb") -> None: + self.epoch_losses = [] + self.iteration_losses = [] + + def on_compute_loss_end(self, trainer: "TrainerWithWandb") -> None: + loss_value = trainer.loss.detach().cpu().item() + self.epoch_losses.append(loss_value) + self.iteration_losses.append(loss_value) + trainer.wandb_log(data={"step_loss": loss_value}) + + def on_optimizer_step_end(self, trainer: "TrainerWithWandb") -> None: + avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses) + trainer.wandb_log(data={"average_iteration_loss": avg_iteration_loss}) + self.iteration_losses = [] + + def on_epoch_end(self, trainer: "TrainerWithWandb") -> None: + avg_epoch_loss = sum(self.epoch_losses) / len(self.epoch_losses) + trainer.wandb_log(data={"average_epoch_loss": avg_epoch_loss, "epoch": trainer.clock.epoch}) + self.epoch_losses = [] + + def on_lr_scheduler_step_end(self, trainer: "TrainerWithWandb") -> None: + trainer.wandb_log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]}) + + def on_backward_end(self, trainer: "TrainerWithWandb") -> None: + trainer.wandb_log(data={"total_grad_norm": trainer.total_gradient_norm}) + + +class WandbMixin(ABC): + config: Any + wandb_logger: WandbLogger + + def load_wandb(self) -> None: + wandb_config = getattr(self.config, "wandb", None) + assert wandb_config is not None and isinstance(wandb_config, WandbConfig), "Wandb config is not set" + init_config = {**wandb_config.model_dump(), "config": self.config.model_dump()} + self.wandb_logger = WandbLogger(init_config=init_config) + + def wandb_log(self, data: dict[str, WandbLoggable]) -> None: + assert isinstance(self, Trainer), "WandbMixin must be mixed with a Trainer" + self.wandb_logger.log(data=data, step=self.clock.step) + + @cached_property + def wandb_callback(self) -> WandbCallback: + return WandbCallback() + + +class TrainerWithWandb(AnyTrainer, WandbMixin, ABC): + pass diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml index d96a820..f260876 100644 --- a/tests/training_utils/mock_config.toml +++ b/tests/training_utils/mock_config.toml @@ -24,10 +24,3 @@ warmup = "20:step" [dropout] dropout = 0.0 - -[checkpointing] -save_interval = "10:epoch" - -[wandb] -mode = "disabled" -project = "mock_project" diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index ae88eef..505c0df 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -119,7 +119,6 @@ def training_clock() -> TrainingClock: gradient_accumulation={"number": 1, "unit": TimeUnit.EPOCH}, evaluation_interval={"number": 1, "unit": TimeUnit.EPOCH}, lr_scheduler_interval={"number": 1, "unit": TimeUnit.EPOCH}, - checkpointing_save_interval={"number": 1, "unit": TimeUnit.EPOCH}, ) @@ -153,10 +152,8 @@ def test_timer_functionality(training_clock: TrainingClock) -> None: def test_state_based_properties(training_clock: TrainingClock) -> None: training_clock.step = 5 # Halfway through the first epoch assert not training_clock.is_evaluation_step # Assuming evaluation every epoch - assert not training_clock.is_checkpointing_step training_clock.step = 10 # End of the first epoch assert training_clock.is_evaluation_step - assert training_clock.is_checkpointing_step def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: MockTrainer) -> None: