remove wandb from base config

This commit is contained in:
limiteinductive 2024-02-07 09:37:10 +00:00 committed by Benjamin Trom
parent 11da76f7df
commit 2ef4982e04
8 changed files with 108 additions and 286 deletions

View file

@ -1,162 +0,0 @@
import random
from typing import Any
from loguru import logger
from pydantic import BaseModel
from torch import Tensor, randn
from torch.utils.data import Dataset
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.clip.concepts import ConceptExtender, EmbeddingExtender
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.training_utils.callback import Callback
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
from refiners.training_utils.latent_diffusion import (
FinetuneLatentDiffusionConfig,
LatentDiffusionConfig,
LatentDiffusionTrainer,
TextEmbeddingLatentsBatch,
TextEmbeddingLatentsDataset,
)
# Prompt templates for object concepts. Each entry holds exactly one "{}"
# slot that is filled through `str.format`.
IMAGENET_TEMPLATES_SMALL = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]
# Prompt templates for style concepts. Each entry holds exactly one "{}" slot
# that is filled through `str.format`. Two entries appear twice; they are kept
# as-is because removing them would change how often each prompt is sampled.
IMAGENET_STYLE_TEMPLATES_SMALL = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]
class TextualInversionDataset(TextEmbeddingLatentsDataset):
    """Latents dataset whose captions come from prompt templates.

    Any caption stored with a sample is ignored: `get_caption` fills a randomly
    chosen template with the configured placeholder token instead.
    """

    templates: list[str] = []
    placeholder_token: str = ""

    def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None:
        super().__init__(trainer)
        textual_inversion = self.config.textual_inversion
        # Style mode switches from the object templates to the style templates.
        if textual_inversion.style_mode:
            self.templates = IMAGENET_STYLE_TEMPLATES_SMALL
        else:
            self.templates = IMAGENET_TEMPLATES_SMALL
        self.placeholder_token = textual_inversion.placeholder_token

    def get_caption(self, index: int) -> str:
        """Return a random template formatted with the placeholder token.

        `index` is accepted for interface compatibility but not used.
        """
        template = random.choice(self.templates)
        return template.format(self.placeholder_token)
class TextualInversionConfig(BaseModel):
    """Settings for learning a new token embedding (textual inversion)."""

    # The new token to be learned
    placeholder_token: str = "*"
    # The token to be used as initializer; if None, a random vector is used
    initializer_token: str | None = None
    # When True, the dataset uses the style templates instead of the object ones
    style_mode: bool = False

    def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> None:
        """Add the placeholder token to `text_encoder` and inject the adapter.

        The new concept's embedding starts either from the embedding of
        `initializer_token` or, when that is None, from a `randn` vector of
        size `token_encoder.embedding_dim`.

        Raises:
            AssertionError: if the byte-pair encoding of `initializer_token`
                contains a space (i.e. it is not a single token).
        """
        adapter = ConceptExtender(target=text_encoder)
        tokenizer = text_encoder.ensure_find(CLIPTokenizer)
        # Looked up once and shared by both branches below.
        token_encoder = text_encoder.ensure_find(TokenEncoder)
        if self.initializer_token is not None:
            bpe = tokenizer.byte_pair_encoding(token=self.initializer_token)
            # presumably sub-tokens are space-separated in `bpe` — confirm against CLIPTokenizer
            assert " " not in bpe, "This initializer_token is not a single token."
            token = Tensor([tokenizer.token_to_id_mapping[bpe]]).int().to(text_encoder.device)
            init_embedding = token_encoder(token).squeeze(0)
        else:
            init_embedding = randn(token_encoder.embedding_dim)
        adapter.add_concept(self.placeholder_token, init_embedding)
        adapter.inject()
class TextualInversionLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
    """Fine-tuning config extended with a `textual_inversion` section.

    After initialization the `unet`, `text_encoder` and `lda` model entries
    have their `train` flag set to False.
    """

    dataset: HuggingfaceDatasetConfig
    latent_diffusion: LatentDiffusionConfig
    textual_inversion: TextualInversionConfig

    def model_post_init(self, __context: Any) -> None:
        # Pydantic v2 does post init differently, so we need to override this method too.
        logger.info("Freezing models to train only the new embedding.")
        for model_name in ("unet", "text_encoder", "lda"):
            self.models[model_name].train = False
class TextualInversionLatentDiffusionTrainer(LatentDiffusionTrainer[TextualInversionLatentDiffusionConfig]):
    """Latent diffusion trainer that always runs the textual-inversion callbacks."""

    def __init__(
        self,
        config: TextualInversionLatentDiffusionConfig,
        callbacks: "list[Callback[Any]] | None" = None,
    ) -> None:
        super().__init__(config=config, callbacks=callbacks)
        # Appended after the parent set up its own callbacks: load first, then save.
        textual_inversion_callbacks = [LoadTextualInversion(), SaveTextualInversion()]
        self.callbacks.extend(textual_inversion_callbacks)

    def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
        """Build the template-caption dataset bound to this trainer."""
        dataset = TextualInversionDataset(trainer=self)
        return dataset
class LoadTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
    """Applies the textual-inversion config to the text encoder when training begins."""

    def on_train_begin(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
        textual_inversion = trainer.config.textual_inversion
        textual_inversion.apply_textual_inversion_to_target(text_encoder=trainer.text_encoder)
class SaveTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
    """Writes the learned embedding to a safetensors file on each checkpoint save."""

    def on_checkpoint_save(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
        extender = trainer.text_encoder.ensure_find(EmbeddingExtender)
        # The file holds a single tensor keyed by the placeholder token.
        token = trainer.config.textual_inversion.placeholder_token
        learned_embedding = extender.new_weight.squeeze(0)
        destination = trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors"
        save_to_safetensors(path=destination, tensors={token: learned_embedding})
if __name__ == "__main__":
    import sys

    # The first command-line argument is the path to the TOML training config.
    config = TextualInversionLatentDiffusionConfig.load_from_toml(toml_path=sys.argv[1])
    trainer = TextualInversionLatentDiffusionTrainer(config=config)
    trainer.train()

View file

@ -13,8 +13,6 @@ __all__ = [
"GradientNormClipping",
"GradientValueClipping",
"ClockCallback",
"GradientNormLogging",
"MonitorLoss",
]
@ -153,31 +151,6 @@ class ClockCallback(Callback["Trainer[BaseConfig, Any]"]):
logger.info("Evaluation ended.")
class MonitorLoss(Callback["Trainer[BaseConfig, Any]"]):
    """Logs the step loss, its per-iteration and per-epoch averages, and the learning rate."""

    def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        # Two independent histories: one reset per optimizer step, one per epoch.
        self.iteration_losses: list[float] = []
        self.epoch_losses: list[float] = []

    def on_compute_loss_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        current_loss = trainer.loss.detach().cpu().item()
        for history in (self.epoch_losses, self.iteration_losses):
            history.append(current_loss)
        trainer.log(data={"step_loss": current_loss})

    def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        losses = self.iteration_losses
        trainer.log(data={"average_iteration_loss": sum(losses) / len(losses)})
        self.iteration_losses = []

    def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        losses = self.epoch_losses
        epoch_summary = {"average_epoch_loss": sum(losses) / len(losses), "epoch": trainer.clock.epoch}
        trainer.log(data=epoch_summary)
        self.epoch_losses = []

    def on_lr_scheduler_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        # Only the first parameter group's learning rate is reported.
        first_group = trainer.optimizer.param_groups[0]
        trainer.log(data={"learning_rate": first_group["lr"]})
class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]):
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
clip_norm = trainer.config.training.clip_grad_norm
@ -192,8 +165,3 @@ class GradientValueClipping(Callback["Trainer[BaseConfig, Any]"]):
clip_value = trainer.config.training.clip_grad_value
if clip_value is not None:
clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value)
class GradientNormLogging(Callback["Trainer[BaseConfig, Any]"]):
    """Logs the trainer's `total_gradient_norm` at the end of each backward pass."""

    def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        grad_norm = trainer.total_gradient_norm
        trainer.log(data={"total_grad_norm": grad_norm})

View file

@ -196,37 +196,15 @@ class DropoutConfig(BaseModel):
apply_dropout(module=model, probability=self.dropout_probability)
class WandbConfig(BaseModel):
    """Weights & Biases run settings; only `project` is required."""

    # Logging mode, defaulting to "online"
    mode: Literal["online", "offline", "disabled"] = "online"
    project: str
    entity: str = "finegrain"

    # Optional run metadata
    name: str | None = None
    tags: list[str] = []
    group: str | None = None
    job_type: str | None = None
    notes: str | None = None
class CheckpointingConfig(BaseModel):
    """Where and how often checkpoints are saved.

    `save_folder` defaults to None; `save_interval` defaults to one epoch.
    """

    save_folder: Path | None = None
    save_interval: TimeValue = {"number": 1, "unit": TimeUnit.EPOCH}

    @validator("save_interval", pre=True)
    def parse_field(cls, value: Any) -> TimeValue:
        # Runs before standard validation (pre=True) on the raw `save_interval` value.
        parsed_interval = parse_number_unit_field(value)
        return parsed_interval
T = TypeVar("T", bound="BaseConfig")
class BaseConfig(BaseModel):
models: dict[str, ModelConfig]
wandb: WandbConfig
training: TrainingConfig
optimizer: OptimizerConfig
scheduler: SchedulerConfig
dropout: DropoutConfig
checkpointing: CheckpointingConfig
@classmethod
def load_from_toml(cls: Type[T], toml_path: Path | str) -> T:

View file

@ -204,7 +204,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
)
canvas_image.paste(sd.lda.latents_to_image(x=x), box=(0, 512 * i))
images[prompt] = canvas_image
self.log(data=images)
# TODO: wandb_log images
def sample_noise(
@ -257,4 +257,4 @@ class MonitorTimestepLoss(Callback[LatentDiffusionTrainer[Any]]):
log_data[f"average_loss_timestep_bin_{bin_index * 100}"] = avg_loss
self.timestep_bins[bin_index] = []
trainer.log(data=log_data)
# TODO: wandb_log log_data

View file

@ -2,7 +2,6 @@ import random
import time
from abc import ABC, abstractmethod
from functools import cached_property, wraps
from pathlib import Path
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
import numpy as np
@ -33,13 +32,10 @@ from refiners.training_utils.callback import (
Callback,
ClockCallback,
GradientNormClipping,
GradientNormLogging,
GradientValueClipping,
MonitorLoss,
)
from refiners.training_utils.config import BaseConfig, SchedulerType, TimeUnit, TimeValue
from refiners.training_utils.dropout import DropoutCallback
from refiners.training_utils.wandb import WandbLoggable, WandbLogger
__all__ = ["seed_everything", "scoped_seed", "Trainer"]
@ -130,7 +126,6 @@ class TrainingClock:
gradient_accumulation: TimeValue,
evaluation_interval: TimeValue,
lr_scheduler_interval: TimeValue,
checkpointing_save_interval: TimeValue,
) -> None:
self.dataset_length = dataset_length
self.batch_size = batch_size
@ -138,7 +133,6 @@ class TrainingClock:
self.gradient_accumulation = gradient_accumulation
self.evaluation_interval = evaluation_interval
self.lr_scheduler_interval = lr_scheduler_interval
self.checkpointing_save_interval = checkpointing_save_interval
self.num_batches_per_epoch = dataset_length // batch_size
self.start_time = None
self.end_time = None
@ -225,12 +219,6 @@ class TrainingClock:
number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"]
)
@cached_property
def checkpointing_save_interval_steps(self) -> int:
    """Checkpoint save interval passed through `convert_time_unit_to_steps` (cached)."""
    interval = self.checkpointing_save_interval
    return self.convert_time_unit_to_steps(number=interval["number"], unit=interval["unit"])
@property
def is_optimizer_step(self) -> bool:
return self.num_minibatches_processed == self.num_step_per_iteration
@ -247,10 +235,6 @@ class TrainingClock:
def is_evaluation_step(self) -> bool:
return self.step % self.evaluation_interval_steps == 0
@property
def is_checkpointing_step(self) -> bool:
    """True when `step` is an exact multiple of `checkpointing_save_interval_steps`."""
    remainder = self.step % self.checkpointing_save_interval_steps
    return remainder == 0
def compute_grad_norm(parameters: Iterable[Parameter]) -> float:
"""
@ -276,27 +260,35 @@ class Trainer(Generic[ConfigType, Batch], ABC):
evaluation_interval=config.training.evaluation_interval,
gradient_accumulation=config.training.gradient_accumulation,
lr_scheduler_interval=config.scheduler.update_interval,
checkpointing_save_interval=config.checkpointing.save_interval,
)
self.callbacks = callbacks or []
self.callbacks += self.default_callbacks()
self._call_callbacks(event_name="on_init_begin")
self.load_wandb()
self.load_models()
self.prepare_models()
self.prepare_checkpointing()
self._call_callbacks(event_name="on_init_end")
# Builds the list of callbacks registered on the trainer by default: a fixed
# set of built-in callbacks, plus any Callback instance reachable as an
# attribute of the trainer itself.
# NOTE(review): this listing appears to show removed and added diff lines
# together, which is why both `return [` and `callbacks: ... = [` open the
# same list literal — confirm against the actual file before editing.
def default_callbacks(self) -> list[Callback[Any]]:
return [
callbacks: list[Callback[Any]] = [
ClockCallback(),
MonitorLoss(),
GradientNormLogging(),
GradientValueClipping(),
GradientNormClipping(),
DropoutCallback(),
]
# look for any Callback that might be a property of the Trainer
for attr_name in dir(self):
# Any name containing a double underscore is skipped.
if "__" in attr_name:
continue
try:
# getattr may evaluate a property; attributes whose access raises
# AssertionError are ignored rather than aborting the scan.
attr = getattr(self, attr_name)
except AssertionError:
continue
if isinstance(attr, Callback):
# cast() only narrows the static type; it does nothing at runtime.
callbacks.append(cast(Callback[Any], attr))
return callbacks
@cached_property
def device(self) -> Device:
selected_device = Device(self.config.training.device)
@ -417,13 +409,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
for model in self.models.values():
model.eval()
def log(self, data: dict[str, WandbLoggable]) -> None:
    """Send `data` to the wandb logger, tagged with the clock's current step."""
    logger_backend = self.wandb
    logger_backend.log(data=data, step=self.clock.step)
def load_wandb(self) -> None:
    """Create the `WandbLogger` and store it on `self.wandb`."""
    # The wandb section supplies the init arguments; the dump of the whole
    # training config is attached under the "config" key.
    init_config = dict(self.config.wandb.model_dump())
    init_config["config"] = self.config.model_dump()
    self.wandb = WandbLogger(init_config=init_config)
def prepare_model(self, model_name: str) -> None:
model = self.models[model_name]
if (checkpoint := self.config.models[model_name].checkpoint) is not None:
@ -439,18 +424,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
for model_name in self.models:
self.prepare_model(model_name=model_name)
def prepare_checkpointing(self) -> None:
    """Set `self.checkpoints_save_folder`, creating the run's folder when enabled.

    With no `save_folder` configured, the attribute is set to None. Otherwise
    the folder is `<save_folder>/<wandb project name>/<wandb run name>`; it is
    created with `exist_ok=False`, so a pre-existing folder raises.
    """
    base_folder = self.config.checkpointing.save_folder
    if base_folder is None:
        self.checkpoints_save_folder = None
        logger.info("Checkpointing disabled: configure `save_folder` to turn it on.")
        return

    assert base_folder.is_dir()
    self.checkpoints_save_folder = base_folder / self.wandb.project_name / self.wandb.run_name
    self.checkpoints_save_folder.mkdir(parents=True, exist_ok=False)
    logger.info(f"Checkpointing enabled: {self.checkpoints_save_folder}")
@abstractmethod
def load_models(self) -> dict[str, fl.Module]:
...
@ -475,15 +448,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=collate_fn
)
@property
def checkpointing_enabled(self) -> bool:
    """True when a checkpoints save folder has been set."""
    folder = self.checkpoints_save_folder
    return folder is not None
@property
def ensure_checkpoints_save_folder(self) -> Path:
    """Return the checkpoints save folder, asserting that it is set."""
    folder = self.checkpoints_save_folder
    assert folder is not None
    return folder
@abstractmethod
def compute_loss(self, batch: Batch) -> Tensor:
...
@ -508,8 +472,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
self._call_callbacks(event_name="on_lr_scheduler_step_end")
if self.clock.is_evaluation_step:
self.evaluate()
if self.checkpointing_enabled and self.clock.is_checkpointing_step:
self._call_callbacks(event_name="on_checkpoint_save")
def step(self, batch: Batch) -> None:
"""Perform a single training step."""

View file

@ -1,13 +1,15 @@
from typing import Any
from abc import ABC
from functools import cached_property
from pathlib import Path
from typing import Any, Literal
import wandb
from PIL import Image
from pydantic import BaseModel
__all__ = [
"WandbLogger",
"WandbLoggable",
]
from refiners.training_utils.callback import Callback
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer
number = float | int
WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]]
@ -60,3 +62,87 @@ class WandbLogger:
@property
def run_name(self) -> str:
return self.wandb_run.name or "" # type: ignore
class WandbConfig(BaseModel):
    """
    Settings for a Weights & Biases run; only `project` is required.

    See https://docs.wandb.ai/ref/python/init for more details.
    """

    # Logging is off unless explicitly enabled.
    mode: Literal["online", "offline", "disabled"] = "disabled"
    project: str
    entity: str | None = None
    save_code: bool | None = None

    # Run metadata
    name: str | None = None
    tags: list[str] = []
    group: str | None = None
    job_type: str | None = None
    notes: str | None = None

    # Remaining options, all left unset by default
    dir: Path | None = None
    resume: bool | Literal["allow", "must", "never", "auto"] | None = None
    reinit: bool | None = None
    magic: bool | None = None
    anonymous: Literal["never", "allow", "must"] | None = None
    id: str | None = None
AnyTrainer = Trainer[BaseConfig, Any]
class WandbCallback(Callback["TrainerWithWandb"]):
    """Initializes wandb on the trainer and logs losses, learning rate and gradient norm."""

    epoch_losses: list[float]
    iteration_losses: list[float]

    def on_init_begin(self, trainer: "TrainerWithWandb") -> None:
        trainer.load_wandb()

    def on_train_begin(self, trainer: "TrainerWithWandb") -> None:
        # Two independent histories: one reset per optimizer step, one per epoch.
        self.iteration_losses = []
        self.epoch_losses = []

    def on_compute_loss_end(self, trainer: "TrainerWithWandb") -> None:
        current_loss = trainer.loss.detach().cpu().item()
        for history in (self.epoch_losses, self.iteration_losses):
            history.append(current_loss)
        trainer.wandb_log(data={"step_loss": current_loss})

    def on_optimizer_step_end(self, trainer: "TrainerWithWandb") -> None:
        losses = self.iteration_losses
        trainer.wandb_log(data={"average_iteration_loss": sum(losses) / len(losses)})
        self.iteration_losses = []

    def on_epoch_end(self, trainer: "TrainerWithWandb") -> None:
        losses = self.epoch_losses
        epoch_summary = {"average_epoch_loss": sum(losses) / len(losses), "epoch": trainer.clock.epoch}
        trainer.wandb_log(data=epoch_summary)
        self.epoch_losses = []

    def on_lr_scheduler_step_end(self, trainer: "TrainerWithWandb") -> None:
        # Only the first parameter group's learning rate is reported.
        first_group = trainer.optimizer.param_groups[0]
        trainer.wandb_log(data={"learning_rate": first_group["lr"]})

    def on_backward_end(self, trainer: "TrainerWithWandb") -> None:
        grad_norm = trainer.total_gradient_norm
        trainer.wandb_log(data={"total_grad_norm": grad_norm})
class WandbMixin(ABC):
    """Adds wandb setup, logging and a `WandbCallback` to a Trainer subclass."""

    config: Any
    wandb_logger: WandbLogger

    def load_wandb(self) -> None:
        """Create `self.wandb_logger` from the config's `wandb` section."""
        wandb_config = getattr(self.config, "wandb", None)
        # isinstance() is already False for None, so one check covers both cases.
        assert isinstance(wandb_config, WandbConfig), "Wandb config is not set"
        init_config = dict(wandb_config.model_dump())
        init_config["config"] = self.config.model_dump()
        self.wandb_logger = WandbLogger(init_config=init_config)

    def wandb_log(self, data: dict[str, WandbLoggable]) -> None:
        """Log `data` at the trainer clock's current step."""
        assert isinstance(self, Trainer), "WandbMixin must be mixed with a Trainer"
        current_step = self.clock.step
        self.wandb_logger.log(data=data, step=current_step)

    @cached_property
    def wandb_callback(self) -> WandbCallback:
        # Cached, so every access returns the same callback instance.
        callback = WandbCallback()
        return callback
class TrainerWithWandb(AnyTrainer, WandbMixin, ABC):
    """Abstract Trainer combined with `WandbMixin`; adds no members of its own."""

View file

@ -24,10 +24,3 @@ warmup = "20:step"
[dropout]
dropout = 0.0
[checkpointing]
save_interval = "10:epoch"
[wandb]
mode = "disabled"
project = "mock_project"

View file

@ -119,7 +119,6 @@ def training_clock() -> TrainingClock:
gradient_accumulation={"number": 1, "unit": TimeUnit.EPOCH},
evaluation_interval={"number": 1, "unit": TimeUnit.EPOCH},
lr_scheduler_interval={"number": 1, "unit": TimeUnit.EPOCH},
checkpointing_save_interval={"number": 1, "unit": TimeUnit.EPOCH},
)
@ -153,10 +152,8 @@ def test_timer_functionality(training_clock: TrainingClock) -> None:
def test_state_based_properties(training_clock: TrainingClock) -> None:
    """Evaluation and checkpointing flags are off mid-epoch and on at the epoch's end."""
    # Halfway through the first epoch (evaluation is assumed to run every epoch).
    training_clock.step = 5
    mid_epoch_flags = (training_clock.is_evaluation_step, training_clock.is_checkpointing_step)
    assert not mid_epoch_flags[0]
    assert not mid_epoch_flags[1]

    # End of the first epoch.
    training_clock.step = 10
    end_of_epoch_flags = (training_clock.is_evaluation_step, training_clock.is_checkpointing_step)
    assert end_of_epoch_flags[0]
    assert end_of_epoch_flags[1]
def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: MockTrainer) -> None: