Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-21 21:58:47 +00:00)

remove wandb from base config

parent 11da76f7df
commit 2ef4982e04

@@ -1,162 +0,0 @@
-import random
-from typing import Any
-
-from loguru import logger
-from pydantic import BaseModel
-from torch import Tensor, randn
-from torch.utils.data import Dataset
-
-from refiners.fluxion.utils import save_to_safetensors
-from refiners.foundationals.clip.concepts import ConceptExtender, EmbeddingExtender
-from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
-from refiners.foundationals.clip.tokenizer import CLIPTokenizer
-from refiners.training_utils.callback import Callback
-from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
-from refiners.training_utils.latent_diffusion import (
-    FinetuneLatentDiffusionConfig,
-    LatentDiffusionConfig,
-    LatentDiffusionTrainer,
-    TextEmbeddingLatentsBatch,
-    TextEmbeddingLatentsDataset,
-)
-
-IMAGENET_TEMPLATES_SMALL = [
-    "a photo of a {}",
-    "a rendering of a {}",
-    "a cropped photo of the {}",
-    "the photo of a {}",
-    "a photo of a clean {}",
-    "a photo of a dirty {}",
-    "a dark photo of the {}",
-    "a photo of my {}",
-    "a photo of the cool {}",
-    "a close-up photo of a {}",
-    "a bright photo of the {}",
-    "a cropped photo of a {}",
-    "a photo of the {}",
-    "a good photo of the {}",
-    "a photo of one {}",
-    "a close-up photo of the {}",
-    "a rendition of the {}",
-    "a photo of the clean {}",
-    "a rendition of a {}",
-    "a photo of a nice {}",
-    "a good photo of a {}",
-    "a photo of the nice {}",
-    "a photo of the small {}",
-    "a photo of the weird {}",
-    "a photo of the large {}",
-    "a photo of a cool {}",
-    "a photo of a small {}",
-]
-
-IMAGENET_STYLE_TEMPLATES_SMALL = [
-    "a painting in the style of {}",
-    "a rendering in the style of {}",
-    "a cropped painting in the style of {}",
-    "the painting in the style of {}",
-    "a clean painting in the style of {}",
-    "a dirty painting in the style of {}",
-    "a dark painting in the style of {}",
-    "a picture in the style of {}",
-    "a cool painting in the style of {}",
-    "a close-up painting in the style of {}",
-    "a bright painting in the style of {}",
-    "a cropped painting in the style of {}",
-    "a good painting in the style of {}",
-    "a close-up painting in the style of {}",
-    "a rendition in the style of {}",
-    "a nice painting in the style of {}",
-    "a small painting in the style of {}",
-    "a weird painting in the style of {}",
-    "a large painting in the style of {}",
-]
-
-
-class TextualInversionDataset(TextEmbeddingLatentsDataset):
-    templates: list[str] = []
-    placeholder_token: str = ""
-
-    def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None:
-        super().__init__(trainer)
-        self.templates = (
-            IMAGENET_STYLE_TEMPLATES_SMALL if self.config.textual_inversion.style_mode else IMAGENET_TEMPLATES_SMALL
-        )
-        self.placeholder_token = self.config.textual_inversion.placeholder_token
-
-    def get_caption(self, index: int) -> str:
-        # Ignore the dataset caption, if any: use a template instead
-        return random.choice(self.templates).format(self.placeholder_token)
-
-
-class TextualInversionConfig(BaseModel):
-    # The new token to be learned
-    placeholder_token: str = "*"
-    # The token to be used as initializer; if None, a random vector is used
-    initializer_token: str | None = None
-    style_mode: bool = False
-
-    def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> None:
-        adapter = ConceptExtender(target=text_encoder)
-        tokenizer = text_encoder.ensure_find(CLIPTokenizer)
-        token_encoder = text_encoder.ensure_find(TokenEncoder)
-        if self.initializer_token is not None:
-            bpe = tokenizer.byte_pair_encoding(token=self.initializer_token)
-            assert " " not in bpe, "This initializer_token is not a single token."
-            token = Tensor([tokenizer.token_to_id_mapping[bpe]]).int().to(text_encoder.device)
-            init_embedding = token_encoder(token).squeeze(0)
-        else:
-            token_encoder = text_encoder.ensure_find(TokenEncoder)
-            init_embedding = randn(token_encoder.embedding_dim)
-        adapter.add_concept(self.placeholder_token, init_embedding)
-        adapter.inject()
-
-
-class TextualInversionLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
-    dataset: HuggingfaceDatasetConfig
-    latent_diffusion: LatentDiffusionConfig
-    textual_inversion: TextualInversionConfig
-
-    def model_post_init(self, __context: Any) -> None:
-        # Pydantic v2 does post init differently, so we need to override this method too.
-        logger.info("Freezing models to train only the new embedding.")
-        self.models["unet"].train = False
-        self.models["text_encoder"].train = False
-        self.models["lda"].train = False
-
-
-class TextualInversionLatentDiffusionTrainer(LatentDiffusionTrainer[TextualInversionLatentDiffusionConfig]):
-    def __init__(
-        self,
-        config: TextualInversionLatentDiffusionConfig,
-        callbacks: "list[Callback[Any]] | None" = None,
-    ) -> None:
-        super().__init__(config=config, callbacks=callbacks)
-        self.callbacks.extend((LoadTextualInversion(), SaveTextualInversion()))
-
-    def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
-        return TextualInversionDataset(trainer=self)
-
-
-class LoadTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
-    def on_train_begin(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
-        trainer.config.textual_inversion.apply_textual_inversion_to_target(text_encoder=trainer.text_encoder)
-
-
-class SaveTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
-    def on_checkpoint_save(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
-        embedding_extender = trainer.text_encoder.ensure_find(EmbeddingExtender)
-        tensors = {trainer.config.textual_inversion.placeholder_token: embedding_extender.new_weight.squeeze(0)}
-
-        save_to_safetensors(
-            path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors", tensors=tensors
-        )
-
-
-if __name__ == "__main__":
-    import sys
-
-    config_path = sys.argv[1]
-    config = TextualInversionLatentDiffusionConfig.load_from_toml(toml_path=config_path)
-    trainer = TextualInversionLatentDiffusionTrainer(config=config)
-    trainer.train()
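
The deleted script builds its training captions from fixed prompt templates rather than from the dataset. A minimal self-contained sketch of that mechanism (standalone, with an abridged template list; nothing here is part of refiners):

import random

# abridged copy of IMAGENET_TEMPLATES_SMALL from the deleted script
TEMPLATES = ["a photo of a {}", "a rendering of a {}", "a cropped photo of the {}"]


def get_caption(placeholder_token: str = "*") -> str:
    # ignore any dataset caption and build one from a random template instead
    return random.choice(TEMPLATES).format(placeholder_token)


print(get_caption())  # e.g. "a photo of a *"
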
@@ -13,8 +13,6 @@ __all__ = [
     "GradientNormClipping",
     "GradientValueClipping",
     "ClockCallback",
-    "GradientNormLogging",
-    "MonitorLoss",
 ]
 
 

@@ -153,31 +151,6 @@ class ClockCallback(Callback["Trainer[BaseConfig, Any]"]):
         logger.info("Evaluation ended.")
 
 
-class MonitorLoss(Callback["Trainer[BaseConfig, Any]"]):
-    def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        self.epoch_losses: list[float] = []
-        self.iteration_losses: list[float] = []
-
-    def on_compute_loss_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        loss_value = trainer.loss.detach().cpu().item()
-        self.epoch_losses.append(loss_value)
-        self.iteration_losses.append(loss_value)
-        trainer.log(data={"step_loss": loss_value})
-
-    def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
-        trainer.log(data={"average_iteration_loss": avg_iteration_loss})
-        self.iteration_losses = []
-
-    def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        avg_epoch_loss = sum(self.epoch_losses) / len(self.epoch_losses)
-        trainer.log(data={"average_epoch_loss": avg_epoch_loss, "epoch": trainer.clock.epoch})
-        self.epoch_losses = []
-
-    def on_lr_scheduler_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        trainer.log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]})
-
-
 class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]):
     def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
         clip_norm = trainer.config.training.clip_grad_norm

@@ -192,8 +165,3 @@ class GradientValueClipping(Callback["Trainer[BaseConfig, Any]"]):
         clip_value = trainer.config.training.clip_grad_value
         if clip_value is not None:
             clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value)
-
-
-class GradientNormLogging(Callback["Trainer[BaseConfig, Any]"]):
-    def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
-        trainer.log(data={"total_grad_norm": trainer.total_gradient_norm})

@@ -196,37 +196,15 @@ class DropoutConfig(BaseModel):
         apply_dropout(module=model, probability=self.dropout_probability)
 
 
-class WandbConfig(BaseModel):
-    mode: Literal["online", "offline", "disabled"] = "online"
-    project: str
-    entity: str = "finegrain"
-    name: str | None = None
-    tags: list[str] = []
-    group: str | None = None
-    job_type: str | None = None
-    notes: str | None = None
-
-
-class CheckpointingConfig(BaseModel):
-    save_folder: Path | None = None
-    save_interval: TimeValue = {"number": 1, "unit": TimeUnit.EPOCH}
-
-    @validator("save_interval", pre=True)
-    def parse_field(cls, value: Any) -> TimeValue:
-        return parse_number_unit_field(value)
-
-
 T = TypeVar("T", bound="BaseConfig")
 
 
 class BaseConfig(BaseModel):
     models: dict[str, ModelConfig]
-    wandb: WandbConfig
     training: TrainingConfig
     optimizer: OptimizerConfig
     scheduler: SchedulerConfig
     dropout: DropoutConfig
-    checkpointing: CheckpointingConfig
 
     @classmethod
     def load_from_toml(cls: Type[T], toml_path: Path | str) -> T:
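
With wandb and checkpointing gone from BaseConfig, a downstream config that still wants a wandb section presumably has to declare the field itself, using the WandbConfig this commit adds to the wandb module (see the hunk further down). A minimal sketch; MyFinetuneConfig is an illustrative name, not part of the diff:

from refiners.training_utils.config import BaseConfig
from refiners.training_utils.wandb import WandbConfig  # added by this commit, see the wandb hunk below


class MyFinetuneConfig(BaseConfig):  # hypothetical downstream config
    wandb: WandbConfig  # WandbMixin.load_wandb() looks this field up via getattr(self.config, "wandb", None)
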
@@ -204,7 +204,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
                 )
                 canvas_image.paste(sd.lda.latents_to_image(x=x), box=(0, 512 * i))
             images[prompt] = canvas_image
-        self.log(data=images)
+        # TODO: wandb_log images
 
 
 def sample_noise(

@@ -257,4 +257,4 @@ class MonitorTimestepLoss(Callback[LatentDiffusionTrainer[Any]]):
                 log_data[f"average_loss_timestep_bin_{bin_index * 100}"] = avg_loss
                 self.timestep_bins[bin_index] = []
 
-        trainer.log(data=log_data)
+        # TODO: wandb_log log_data

@@ -2,7 +2,6 @@ import random
 import time
 from abc import ABC, abstractmethod
 from functools import cached_property, wraps
-from pathlib import Path
 from typing import Any, Callable, Generic, Iterable, TypeVar, cast
 
 import numpy as np

@@ -33,13 +32,10 @@ from refiners.training_utils.callback import (
     Callback,
     ClockCallback,
     GradientNormClipping,
-    GradientNormLogging,
     GradientValueClipping,
-    MonitorLoss,
 )
 from refiners.training_utils.config import BaseConfig, SchedulerType, TimeUnit, TimeValue
 from refiners.training_utils.dropout import DropoutCallback
-from refiners.training_utils.wandb import WandbLoggable, WandbLogger
 
 __all__ = ["seed_everything", "scoped_seed", "Trainer"]
 

@@ -130,7 +126,6 @@ class TrainingClock:
         gradient_accumulation: TimeValue,
         evaluation_interval: TimeValue,
         lr_scheduler_interval: TimeValue,
-        checkpointing_save_interval: TimeValue,
     ) -> None:
         self.dataset_length = dataset_length
         self.batch_size = batch_size

@@ -138,7 +133,6 @@ class TrainingClock:
         self.gradient_accumulation = gradient_accumulation
         self.evaluation_interval = evaluation_interval
         self.lr_scheduler_interval = lr_scheduler_interval
-        self.checkpointing_save_interval = checkpointing_save_interval
         self.num_batches_per_epoch = dataset_length // batch_size
         self.start_time = None
         self.end_time = None

@@ -225,12 +219,6 @@ class TrainingClock:
             number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"]
         )
 
-    @cached_property
-    def checkpointing_save_interval_steps(self) -> int:
-        return self.convert_time_unit_to_steps(
-            number=self.checkpointing_save_interval["number"], unit=self.checkpointing_save_interval["unit"]
-        )
-
     @property
     def is_optimizer_step(self) -> bool:
         return self.num_minibatches_processed == self.num_step_per_iteration

@@ -247,10 +235,6 @@ class TrainingClock:
     def is_evaluation_step(self) -> bool:
         return self.step % self.evaluation_interval_steps == 0
 
-    @property
-    def is_checkpointing_step(self) -> bool:
-        return self.step % self.checkpointing_save_interval_steps == 0
-
 
 def compute_grad_norm(parameters: Iterable[Parameter]) -> float:
     """

@@ -276,27 +260,35 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             evaluation_interval=config.training.evaluation_interval,
             gradient_accumulation=config.training.gradient_accumulation,
             lr_scheduler_interval=config.scheduler.update_interval,
-            checkpointing_save_interval=config.checkpointing.save_interval,
         )
         self.callbacks = callbacks or []
         self.callbacks += self.default_callbacks()
         self._call_callbacks(event_name="on_init_begin")
-        self.load_wandb()
         self.load_models()
         self.prepare_models()
-        self.prepare_checkpointing()
         self._call_callbacks(event_name="on_init_end")
 
     def default_callbacks(self) -> list[Callback[Any]]:
-        return [
+        callbacks: list[Callback[Any]] = [
             ClockCallback(),
-            MonitorLoss(),
-            GradientNormLogging(),
             GradientValueClipping(),
             GradientNormClipping(),
             DropoutCallback(),
         ]
 
+        # look for any Callback that might be a property of the Trainer
+        for attr_name in dir(self):
+            if "__" in attr_name:
+                continue
+
+            try:
+                attr = getattr(self, attr_name)
+            except AssertionError:
+                continue
+            if isinstance(attr, Callback):
+                callbacks.append(cast(Callback[Any], attr))
+        return callbacks
+
     @cached_property
     def device(self) -> Device:
         selected_device = Device(self.config.training.device)
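
The reworked default_callbacks() no longer just returns a fixed list: it also scans the trainer's attributes and registers any Callback it finds, which is what lets a callback exposed as a cached_property (such as the wandb_callback added further down) attach itself without being passed explicitly. A minimal self-contained sketch of that discovery behaviour, using stand-in classes rather than the real refiners types:

class Callback:  # stand-in for refiners.training_utils.callback.Callback
    pass


class LossPrinter(Callback):  # hypothetical callback exposed as a trainer attribute
    pass


class FakeTrainer:  # stand-in trainer: only the discovery logic is reproduced
    def __init__(self) -> None:
        self.loss_printer = LossPrinter()

    def default_callbacks(self) -> list[Callback]:
        callbacks: list[Callback] = []  # the explicit defaults (ClockCallback, ...) would go here
        for attr_name in dir(self):
            if "__" in attr_name:
                continue
            try:
                attr = getattr(self, attr_name)
            except AssertionError:
                continue  # attributes guarded by assertions are skipped
            if isinstance(attr, Callback):
                callbacks.append(attr)
        return callbacks


assert isinstance(FakeTrainer().default_callbacks()[0], LossPrinter)
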
@@ -417,13 +409,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
         for model in self.models.values():
             model.eval()
 
-    def log(self, data: dict[str, WandbLoggable]) -> None:
-        self.wandb.log(data=data, step=self.clock.step)
-
-    def load_wandb(self) -> None:
-        init_config = {**self.config.wandb.model_dump(), "config": self.config.model_dump()}
-        self.wandb = WandbLogger(init_config=init_config)
-
     def prepare_model(self, model_name: str) -> None:
         model = self.models[model_name]
         if (checkpoint := self.config.models[model_name].checkpoint) is not None:

@@ -439,18 +424,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
         for model_name in self.models:
             self.prepare_model(model_name=model_name)
 
-    def prepare_checkpointing(self) -> None:
-        if self.config.checkpointing.save_folder is not None:
-            assert self.config.checkpointing.save_folder.is_dir()
-            self.checkpoints_save_folder = (
-                self.config.checkpointing.save_folder / self.wandb.project_name / self.wandb.run_name
-            )
-            self.checkpoints_save_folder.mkdir(parents=True, exist_ok=False)
-            logger.info(f"Checkpointing enabled: {self.checkpoints_save_folder}")
-        else:
-            self.checkpoints_save_folder = None
-            logger.info("Checkpointing disabled: configure `save_folder` to turn it on.")
-
     @abstractmethod
     def load_models(self) -> dict[str, fl.Module]:
         ...

@@ -475,15 +448,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=collate_fn
         )
 
-    @property
-    def checkpointing_enabled(self) -> bool:
-        return self.checkpoints_save_folder is not None
-
-    @property
-    def ensure_checkpoints_save_folder(self) -> Path:
-        assert self.checkpoints_save_folder is not None
-        return self.checkpoints_save_folder
-
     @abstractmethod
     def compute_loss(self, batch: Batch) -> Tensor:
         ...

|
@ -508,8 +472,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
self._call_callbacks(event_name="on_lr_scheduler_step_end")
|
self._call_callbacks(event_name="on_lr_scheduler_step_end")
|
||||||
if self.clock.is_evaluation_step:
|
if self.clock.is_evaluation_step:
|
||||||
self.evaluate()
|
self.evaluate()
|
||||||
if self.checkpointing_enabled and self.clock.is_checkpointing_step:
|
|
||||||
self._call_callbacks(event_name="on_checkpoint_save")
|
|
||||||
|
|
||||||
def step(self, batch: Batch) -> None:
|
def step(self, batch: Batch) -> None:
|
||||||
"""Perform a single training step."""
|
"""Perform a single training step."""
|
||||||
|
|
|
@@ -1,13 +1,15 @@
-from typing import Any
+from abc import ABC
+from functools import cached_property
+from pathlib import Path
+from typing import Any, Literal
 
 import wandb
 from PIL import Image
+from pydantic import BaseModel
 
-__all__ = [
-    "WandbLogger",
-    "WandbLoggable",
-]
-
+from refiners.training_utils.callback import Callback
+from refiners.training_utils.config import BaseConfig
+from refiners.training_utils.trainer import Trainer
 
 number = float | int
 WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]]

@@ -60,3 +62,87 @@ class WandbLogger:
     @property
     def run_name(self) -> str:
         return self.wandb_run.name or ""  # type: ignore
+
+
+class WandbConfig(BaseModel):
+    """
+    Wandb configuration.
+
+    See https://docs.wandb.ai/ref/python/init for more details.
+    """
+
+    mode: Literal["online", "offline", "disabled"] = "disabled"
+    project: str
+    entity: str | None = None
+    save_code: bool | None = None
+    name: str | None = None
+    tags: list[str] = []
+    group: str | None = None
+    job_type: str | None = None
+    notes: str | None = None
+    dir: Path | None = None
+    resume: bool | Literal["allow", "must", "never", "auto"] | None = None
+    reinit: bool | None = None
+    magic: bool | None = None
+    anonymous: Literal["never", "allow", "must"] | None = None
+    id: str | None = None
+
+
+AnyTrainer = Trainer[BaseConfig, Any]
+
+
+class WandbCallback(Callback["TrainerWithWandb"]):
+    epoch_losses: list[float]
+    iteration_losses: list[float]
+
+    def on_init_begin(self, trainer: "TrainerWithWandb") -> None:
+        trainer.load_wandb()
+
+    def on_train_begin(self, trainer: "TrainerWithWandb") -> None:
+        self.epoch_losses = []
+        self.iteration_losses = []
+
+    def on_compute_loss_end(self, trainer: "TrainerWithWandb") -> None:
+        loss_value = trainer.loss.detach().cpu().item()
+        self.epoch_losses.append(loss_value)
+        self.iteration_losses.append(loss_value)
+        trainer.wandb_log(data={"step_loss": loss_value})
+
+    def on_optimizer_step_end(self, trainer: "TrainerWithWandb") -> None:
+        avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
+        trainer.wandb_log(data={"average_iteration_loss": avg_iteration_loss})
+        self.iteration_losses = []
+
+    def on_epoch_end(self, trainer: "TrainerWithWandb") -> None:
+        avg_epoch_loss = sum(self.epoch_losses) / len(self.epoch_losses)
+        trainer.wandb_log(data={"average_epoch_loss": avg_epoch_loss, "epoch": trainer.clock.epoch})
+        self.epoch_losses = []
+
+    def on_lr_scheduler_step_end(self, trainer: "TrainerWithWandb") -> None:
+        trainer.wandb_log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]})
+
+    def on_backward_end(self, trainer: "TrainerWithWandb") -> None:
+        trainer.wandb_log(data={"total_grad_norm": trainer.total_gradient_norm})
+
+
+class WandbMixin(ABC):
+    config: Any
+    wandb_logger: WandbLogger
+
+    def load_wandb(self) -> None:
+        wandb_config = getattr(self.config, "wandb", None)
+        assert wandb_config is not None and isinstance(wandb_config, WandbConfig), "Wandb config is not set"
+        init_config = {**wandb_config.model_dump(), "config": self.config.model_dump()}
+        self.wandb_logger = WandbLogger(init_config=init_config)
+
+    def wandb_log(self, data: dict[str, WandbLoggable]) -> None:
+        assert isinstance(self, Trainer), "WandbMixin must be mixed with a Trainer"
+        self.wandb_logger.log(data=data, step=self.clock.step)
+
+    @cached_property
+    def wandb_callback(self) -> WandbCallback:
+        return WandbCallback()
+
+
+class TrainerWithWandb(AnyTrainer, WandbMixin, ABC):
+    pass
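
To make the opt-in path concrete, here is a hedged sketch of how a project might wire Weights & Biases back into a trainer after this commit. MyConfig and MyTrainer are illustrative names, only the class declarations are shown, and a real subclass still has to implement the abstract Trainer methods (e.g. load_models and compute_loss):

from typing import Any

from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer
from refiners.training_utils.wandb import WandbConfig, WandbMixin


class MyConfig(BaseConfig):  # hypothetical config: opts in by declaring the field
    wandb: WandbConfig


class MyTrainer(Trainer[MyConfig, Any], WandbMixin):  # hypothetical trainer, abstract methods omitted
    ...

With this shape, default_callbacks() discovers the wandb_callback cached_property, WandbCallback.on_init_begin calls load_wandb(), and metrics go through wandb_log() rather than the Trainer.log() method this commit removes.
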
@@ -24,10 +24,3 @@ warmup = "20:step"
 
 [dropout]
 dropout = 0.0
-
-[checkpointing]
-save_interval = "10:epoch"
-
-[wandb]
-mode = "disabled"
-project = "mock_project"

@@ -119,7 +119,6 @@ def training_clock() -> TrainingClock:
         gradient_accumulation={"number": 1, "unit": TimeUnit.EPOCH},
         evaluation_interval={"number": 1, "unit": TimeUnit.EPOCH},
         lr_scheduler_interval={"number": 1, "unit": TimeUnit.EPOCH},
-        checkpointing_save_interval={"number": 1, "unit": TimeUnit.EPOCH},
     )
 
 

@@ -153,10 +152,8 @@ def test_timer_functionality(training_clock: TrainingClock) -> None:
 def test_state_based_properties(training_clock: TrainingClock) -> None:
     training_clock.step = 5  # Halfway through the first epoch
     assert not training_clock.is_evaluation_step  # Assuming evaluation every epoch
-    assert not training_clock.is_checkpointing_step
     training_clock.step = 10  # End of the first epoch
     assert training_clock.is_evaluation_step
-    assert training_clock.is_checkpointing_step
 
 
 def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: MockTrainer) -> None: