mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
remove wandb from base config
This commit is contained in:
parent
11da76f7df
commit
2ef4982e04
|
@ -1,162 +0,0 @@
|
|||
import random
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from torch import Tensor, randn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender, EmbeddingExtender
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
from refiners.training_utils.callback import Callback
|
||||
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
|
||||
from refiners.training_utils.latent_diffusion import (
|
||||
FinetuneLatentDiffusionConfig,
|
||||
LatentDiffusionConfig,
|
||||
LatentDiffusionTrainer,
|
||||
TextEmbeddingLatentsBatch,
|
||||
TextEmbeddingLatentsDataset,
|
||||
)
|
||||
|
||||
IMAGENET_TEMPLATES_SMALL = [
|
||||
"a photo of a {}",
|
||||
"a rendering of a {}",
|
||||
"a cropped photo of the {}",
|
||||
"the photo of a {}",
|
||||
"a photo of a clean {}",
|
||||
"a photo of a dirty {}",
|
||||
"a dark photo of the {}",
|
||||
"a photo of my {}",
|
||||
"a photo of the cool {}",
|
||||
"a close-up photo of a {}",
|
||||
"a bright photo of the {}",
|
||||
"a cropped photo of a {}",
|
||||
"a photo of the {}",
|
||||
"a good photo of the {}",
|
||||
"a photo of one {}",
|
||||
"a close-up photo of the {}",
|
||||
"a rendition of the {}",
|
||||
"a photo of the clean {}",
|
||||
"a rendition of a {}",
|
||||
"a photo of a nice {}",
|
||||
"a good photo of a {}",
|
||||
"a photo of the nice {}",
|
||||
"a photo of the small {}",
|
||||
"a photo of the weird {}",
|
||||
"a photo of the large {}",
|
||||
"a photo of a cool {}",
|
||||
"a photo of a small {}",
|
||||
]
|
||||
|
||||
IMAGENET_STYLE_TEMPLATES_SMALL = [
|
||||
"a painting in the style of {}",
|
||||
"a rendering in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"the painting in the style of {}",
|
||||
"a clean painting in the style of {}",
|
||||
"a dirty painting in the style of {}",
|
||||
"a dark painting in the style of {}",
|
||||
"a picture in the style of {}",
|
||||
"a cool painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a bright painting in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"a good painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a rendition in the style of {}",
|
||||
"a nice painting in the style of {}",
|
||||
"a small painting in the style of {}",
|
||||
"a weird painting in the style of {}",
|
||||
"a large painting in the style of {}",
|
||||
]
|
||||
|
||||
|
||||
class TextualInversionDataset(TextEmbeddingLatentsDataset):
|
||||
templates: list[str] = []
|
||||
placeholder_token: str = ""
|
||||
|
||||
def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None:
|
||||
super().__init__(trainer)
|
||||
self.templates = (
|
||||
IMAGENET_STYLE_TEMPLATES_SMALL if self.config.textual_inversion.style_mode else IMAGENET_TEMPLATES_SMALL
|
||||
)
|
||||
self.placeholder_token = self.config.textual_inversion.placeholder_token
|
||||
|
||||
def get_caption(self, index: int) -> str:
|
||||
# Ignore the dataset caption, if any: use a template instead
|
||||
return random.choice(self.templates).format(self.placeholder_token)
|
||||
|
||||
|
||||
class TextualInversionConfig(BaseModel):
|
||||
# The new token to be learned
|
||||
placeholder_token: str = "*"
|
||||
# The token to be used as initializer; if None, a random vector is used
|
||||
initializer_token: str | None = None
|
||||
style_mode: bool = False
|
||||
|
||||
def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> None:
|
||||
adapter = ConceptExtender(target=text_encoder)
|
||||
tokenizer = text_encoder.ensure_find(CLIPTokenizer)
|
||||
token_encoder = text_encoder.ensure_find(TokenEncoder)
|
||||
if self.initializer_token is not None:
|
||||
bpe = tokenizer.byte_pair_encoding(token=self.initializer_token)
|
||||
assert " " not in bpe, "This initializer_token is not a single token."
|
||||
token = Tensor([tokenizer.token_to_id_mapping[bpe]]).int().to(text_encoder.device)
|
||||
init_embedding = token_encoder(token).squeeze(0)
|
||||
else:
|
||||
token_encoder = text_encoder.ensure_find(TokenEncoder)
|
||||
init_embedding = randn(token_encoder.embedding_dim)
|
||||
adapter.add_concept(self.placeholder_token, init_embedding)
|
||||
adapter.inject()
|
||||
|
||||
|
||||
class TextualInversionLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
|
||||
dataset: HuggingfaceDatasetConfig
|
||||
latent_diffusion: LatentDiffusionConfig
|
||||
textual_inversion: TextualInversionConfig
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
# Pydantic v2 does post init differently, so we need to override this method too.
|
||||
logger.info("Freezing models to train only the new embedding.")
|
||||
self.models["unet"].train = False
|
||||
self.models["text_encoder"].train = False
|
||||
self.models["lda"].train = False
|
||||
|
||||
|
||||
class TextualInversionLatentDiffusionTrainer(LatentDiffusionTrainer[TextualInversionLatentDiffusionConfig]):
|
||||
def __init__(
|
||||
self,
|
||||
config: TextualInversionLatentDiffusionConfig,
|
||||
callbacks: "list[Callback[Any]] | None" = None,
|
||||
) -> None:
|
||||
super().__init__(config=config, callbacks=callbacks)
|
||||
self.callbacks.extend((LoadTextualInversion(), SaveTextualInversion()))
|
||||
|
||||
def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
|
||||
return TextualInversionDataset(trainer=self)
|
||||
|
||||
|
||||
class LoadTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
|
||||
def on_train_begin(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
|
||||
trainer.config.textual_inversion.apply_textual_inversion_to_target(text_encoder=trainer.text_encoder)
|
||||
|
||||
|
||||
class SaveTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
|
||||
def on_checkpoint_save(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
|
||||
embedding_extender = trainer.text_encoder.ensure_find(EmbeddingExtender)
|
||||
tensors = {trainer.config.textual_inversion.placeholder_token: embedding_extender.new_weight.squeeze(0)}
|
||||
|
||||
save_to_safetensors(
|
||||
path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors", tensors=tensors
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
config_path = sys.argv[1]
|
||||
config = TextualInversionLatentDiffusionConfig.load_from_toml(toml_path=config_path)
|
||||
trainer = TextualInversionLatentDiffusionTrainer(config=config)
|
||||
trainer.train()
|
|
@ -13,8 +13,6 @@ __all__ = [
|
|||
"GradientNormClipping",
|
||||
"GradientValueClipping",
|
||||
"ClockCallback",
|
||||
"GradientNormLogging",
|
||||
"MonitorLoss",
|
||||
]
|
||||
|
||||
|
||||
|
@ -153,31 +151,6 @@ class ClockCallback(Callback["Trainer[BaseConfig, Any]"]):
|
|||
logger.info("Evaluation ended.")
|
||||
|
||||
|
||||
class MonitorLoss(Callback["Trainer[BaseConfig, Any]"]):
|
||||
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
self.epoch_losses: list[float] = []
|
||||
self.iteration_losses: list[float] = []
|
||||
|
||||
def on_compute_loss_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
loss_value = trainer.loss.detach().cpu().item()
|
||||
self.epoch_losses.append(loss_value)
|
||||
self.iteration_losses.append(loss_value)
|
||||
trainer.log(data={"step_loss": loss_value})
|
||||
|
||||
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
|
||||
trainer.log(data={"average_iteration_loss": avg_iteration_loss})
|
||||
self.iteration_losses = []
|
||||
|
||||
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
avg_epoch_loss = sum(self.epoch_losses) / len(self.epoch_losses)
|
||||
trainer.log(data={"average_epoch_loss": avg_epoch_loss, "epoch": trainer.clock.epoch})
|
||||
self.epoch_losses = []
|
||||
|
||||
def on_lr_scheduler_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
trainer.log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]})
|
||||
|
||||
|
||||
class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]):
|
||||
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
clip_norm = trainer.config.training.clip_grad_norm
|
||||
|
@ -192,8 +165,3 @@ class GradientValueClipping(Callback["Trainer[BaseConfig, Any]"]):
|
|||
clip_value = trainer.config.training.clip_grad_value
|
||||
if clip_value is not None:
|
||||
clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value)
|
||||
|
||||
|
||||
class GradientNormLogging(Callback["Trainer[BaseConfig, Any]"]):
|
||||
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
trainer.log(data={"total_grad_norm": trainer.total_gradient_norm})
|
||||
|
|
|
@ -196,37 +196,15 @@ class DropoutConfig(BaseModel):
|
|||
apply_dropout(module=model, probability=self.dropout_probability)
|
||||
|
||||
|
||||
class WandbConfig(BaseModel):
|
||||
mode: Literal["online", "offline", "disabled"] = "online"
|
||||
project: str
|
||||
entity: str = "finegrain"
|
||||
name: str | None = None
|
||||
tags: list[str] = []
|
||||
group: str | None = None
|
||||
job_type: str | None = None
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
class CheckpointingConfig(BaseModel):
|
||||
save_folder: Path | None = None
|
||||
save_interval: TimeValue = {"number": 1, "unit": TimeUnit.EPOCH}
|
||||
|
||||
@validator("save_interval", pre=True)
|
||||
def parse_field(cls, value: Any) -> TimeValue:
|
||||
return parse_number_unit_field(value)
|
||||
|
||||
|
||||
T = TypeVar("T", bound="BaseConfig")
|
||||
|
||||
|
||||
class BaseConfig(BaseModel):
|
||||
models: dict[str, ModelConfig]
|
||||
wandb: WandbConfig
|
||||
training: TrainingConfig
|
||||
optimizer: OptimizerConfig
|
||||
scheduler: SchedulerConfig
|
||||
dropout: DropoutConfig
|
||||
checkpointing: CheckpointingConfig
|
||||
|
||||
@classmethod
|
||||
def load_from_toml(cls: Type[T], toml_path: Path | str) -> T:
|
||||
|
|
|
@ -204,7 +204,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
|||
)
|
||||
canvas_image.paste(sd.lda.latents_to_image(x=x), box=(0, 512 * i))
|
||||
images[prompt] = canvas_image
|
||||
self.log(data=images)
|
||||
# TODO: wandb_log images
|
||||
|
||||
|
||||
def sample_noise(
|
||||
|
@ -257,4 +257,4 @@ class MonitorTimestepLoss(Callback[LatentDiffusionTrainer[Any]]):
|
|||
log_data[f"average_loss_timestep_bin_{bin_index * 100}"] = avg_loss
|
||||
self.timestep_bins[bin_index] = []
|
||||
|
||||
trainer.log(data=log_data)
|
||||
# TODO: wandb_log log_data
|
||||
|
|
|
@ -2,7 +2,6 @@ import random
|
|||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property, wraps
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
|
@ -33,13 +32,10 @@ from refiners.training_utils.callback import (
|
|||
Callback,
|
||||
ClockCallback,
|
||||
GradientNormClipping,
|
||||
GradientNormLogging,
|
||||
GradientValueClipping,
|
||||
MonitorLoss,
|
||||
)
|
||||
from refiners.training_utils.config import BaseConfig, SchedulerType, TimeUnit, TimeValue
|
||||
from refiners.training_utils.dropout import DropoutCallback
|
||||
from refiners.training_utils.wandb import WandbLoggable, WandbLogger
|
||||
|
||||
__all__ = ["seed_everything", "scoped_seed", "Trainer"]
|
||||
|
||||
|
@ -130,7 +126,6 @@ class TrainingClock:
|
|||
gradient_accumulation: TimeValue,
|
||||
evaluation_interval: TimeValue,
|
||||
lr_scheduler_interval: TimeValue,
|
||||
checkpointing_save_interval: TimeValue,
|
||||
) -> None:
|
||||
self.dataset_length = dataset_length
|
||||
self.batch_size = batch_size
|
||||
|
@ -138,7 +133,6 @@ class TrainingClock:
|
|||
self.gradient_accumulation = gradient_accumulation
|
||||
self.evaluation_interval = evaluation_interval
|
||||
self.lr_scheduler_interval = lr_scheduler_interval
|
||||
self.checkpointing_save_interval = checkpointing_save_interval
|
||||
self.num_batches_per_epoch = dataset_length // batch_size
|
||||
self.start_time = None
|
||||
self.end_time = None
|
||||
|
@ -225,12 +219,6 @@ class TrainingClock:
|
|||
number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"]
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def checkpointing_save_interval_steps(self) -> int:
|
||||
return self.convert_time_unit_to_steps(
|
||||
number=self.checkpointing_save_interval["number"], unit=self.checkpointing_save_interval["unit"]
|
||||
)
|
||||
|
||||
@property
|
||||
def is_optimizer_step(self) -> bool:
|
||||
return self.num_minibatches_processed == self.num_step_per_iteration
|
||||
|
@ -247,10 +235,6 @@ class TrainingClock:
|
|||
def is_evaluation_step(self) -> bool:
|
||||
return self.step % self.evaluation_interval_steps == 0
|
||||
|
||||
@property
|
||||
def is_checkpointing_step(self) -> bool:
|
||||
return self.step % self.checkpointing_save_interval_steps == 0
|
||||
|
||||
|
||||
def compute_grad_norm(parameters: Iterable[Parameter]) -> float:
|
||||
"""
|
||||
|
@ -276,27 +260,35 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
evaluation_interval=config.training.evaluation_interval,
|
||||
gradient_accumulation=config.training.gradient_accumulation,
|
||||
lr_scheduler_interval=config.scheduler.update_interval,
|
||||
checkpointing_save_interval=config.checkpointing.save_interval,
|
||||
)
|
||||
self.callbacks = callbacks or []
|
||||
self.callbacks += self.default_callbacks()
|
||||
self._call_callbacks(event_name="on_init_begin")
|
||||
self.load_wandb()
|
||||
self.load_models()
|
||||
self.prepare_models()
|
||||
self.prepare_checkpointing()
|
||||
self._call_callbacks(event_name="on_init_end")
|
||||
|
||||
def default_callbacks(self) -> list[Callback[Any]]:
|
||||
return [
|
||||
callbacks: list[Callback[Any]] = [
|
||||
ClockCallback(),
|
||||
MonitorLoss(),
|
||||
GradientNormLogging(),
|
||||
GradientValueClipping(),
|
||||
GradientNormClipping(),
|
||||
DropoutCallback(),
|
||||
]
|
||||
|
||||
# look for any Callback that might be a property of the Trainer
|
||||
for attr_name in dir(self):
|
||||
if "__" in attr_name:
|
||||
continue
|
||||
|
||||
try:
|
||||
attr = getattr(self, attr_name)
|
||||
except AssertionError:
|
||||
continue
|
||||
if isinstance(attr, Callback):
|
||||
callbacks.append(cast(Callback[Any], attr))
|
||||
return callbacks
|
||||
|
||||
@cached_property
|
||||
def device(self) -> Device:
|
||||
selected_device = Device(self.config.training.device)
|
||||
|
@ -417,13 +409,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
for model in self.models.values():
|
||||
model.eval()
|
||||
|
||||
def log(self, data: dict[str, WandbLoggable]) -> None:
|
||||
self.wandb.log(data=data, step=self.clock.step)
|
||||
|
||||
def load_wandb(self) -> None:
|
||||
init_config = {**self.config.wandb.model_dump(), "config": self.config.model_dump()}
|
||||
self.wandb = WandbLogger(init_config=init_config)
|
||||
|
||||
def prepare_model(self, model_name: str) -> None:
|
||||
model = self.models[model_name]
|
||||
if (checkpoint := self.config.models[model_name].checkpoint) is not None:
|
||||
|
@ -439,18 +424,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
for model_name in self.models:
|
||||
self.prepare_model(model_name=model_name)
|
||||
|
||||
def prepare_checkpointing(self) -> None:
|
||||
if self.config.checkpointing.save_folder is not None:
|
||||
assert self.config.checkpointing.save_folder.is_dir()
|
||||
self.checkpoints_save_folder = (
|
||||
self.config.checkpointing.save_folder / self.wandb.project_name / self.wandb.run_name
|
||||
)
|
||||
self.checkpoints_save_folder.mkdir(parents=True, exist_ok=False)
|
||||
logger.info(f"Checkpointing enabled: {self.checkpoints_save_folder}")
|
||||
else:
|
||||
self.checkpoints_save_folder = None
|
||||
logger.info("Checkpointing disabled: configure `save_folder` to turn it on.")
|
||||
|
||||
@abstractmethod
|
||||
def load_models(self) -> dict[str, fl.Module]:
|
||||
...
|
||||
|
@ -475,15 +448,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=collate_fn
|
||||
)
|
||||
|
||||
@property
|
||||
def checkpointing_enabled(self) -> bool:
|
||||
return self.checkpoints_save_folder is not None
|
||||
|
||||
@property
|
||||
def ensure_checkpoints_save_folder(self) -> Path:
|
||||
assert self.checkpoints_save_folder is not None
|
||||
return self.checkpoints_save_folder
|
||||
|
||||
@abstractmethod
|
||||
def compute_loss(self, batch: Batch) -> Tensor:
|
||||
...
|
||||
|
@ -508,8 +472,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
self._call_callbacks(event_name="on_lr_scheduler_step_end")
|
||||
if self.clock.is_evaluation_step:
|
||||
self.evaluate()
|
||||
if self.checkpointing_enabled and self.clock.is_checkpointing_step:
|
||||
self._call_callbacks(event_name="on_checkpoint_save")
|
||||
|
||||
def step(self, batch: Batch) -> None:
|
||||
"""Perform a single training step."""
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
from typing import Any
|
||||
from abc import ABC
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
import wandb
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
__all__ = [
|
||||
"WandbLogger",
|
||||
"WandbLoggable",
|
||||
]
|
||||
|
||||
from refiners.training_utils.callback import Callback
|
||||
from refiners.training_utils.config import BaseConfig
|
||||
from refiners.training_utils.trainer import Trainer
|
||||
|
||||
number = float | int
|
||||
WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]]
|
||||
|
@ -60,3 +62,87 @@ class WandbLogger:
|
|||
@property
|
||||
def run_name(self) -> str:
|
||||
return self.wandb_run.name or "" # type: ignore
|
||||
|
||||
|
||||
class WandbConfig(BaseModel):
|
||||
"""
|
||||
Wandb configuration.
|
||||
|
||||
See https://docs.wandb.ai/ref/python/init for more details.
|
||||
"""
|
||||
|
||||
mode: Literal["online", "offline", "disabled"] = "disabled"
|
||||
project: str
|
||||
entity: str | None = None
|
||||
save_code: bool | None = None
|
||||
name: str | None = None
|
||||
tags: list[str] = []
|
||||
group: str | None = None
|
||||
job_type: str | None = None
|
||||
notes: str | None = None
|
||||
dir: Path | None = None
|
||||
resume: bool | Literal["allow", "must", "never", "auto"] | None = None
|
||||
reinit: bool | None = None
|
||||
magic: bool | None = None
|
||||
anonymous: Literal["never", "allow", "must"] | None = None
|
||||
id: str | None = None
|
||||
|
||||
|
||||
AnyTrainer = Trainer[BaseConfig, Any]
|
||||
|
||||
|
||||
class WandbCallback(Callback["TrainerWithWandb"]):
|
||||
epoch_losses: list[float]
|
||||
iteration_losses: list[float]
|
||||
|
||||
def on_init_begin(self, trainer: "TrainerWithWandb") -> None:
|
||||
trainer.load_wandb()
|
||||
|
||||
def on_train_begin(self, trainer: "TrainerWithWandb") -> None:
|
||||
self.epoch_losses = []
|
||||
self.iteration_losses = []
|
||||
|
||||
def on_compute_loss_end(self, trainer: "TrainerWithWandb") -> None:
|
||||
loss_value = trainer.loss.detach().cpu().item()
|
||||
self.epoch_losses.append(loss_value)
|
||||
self.iteration_losses.append(loss_value)
|
||||
trainer.wandb_log(data={"step_loss": loss_value})
|
||||
|
||||
def on_optimizer_step_end(self, trainer: "TrainerWithWandb") -> None:
|
||||
avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
|
||||
trainer.wandb_log(data={"average_iteration_loss": avg_iteration_loss})
|
||||
self.iteration_losses = []
|
||||
|
||||
def on_epoch_end(self, trainer: "TrainerWithWandb") -> None:
|
||||
avg_epoch_loss = sum(self.epoch_losses) / len(self.epoch_losses)
|
||||
trainer.wandb_log(data={"average_epoch_loss": avg_epoch_loss, "epoch": trainer.clock.epoch})
|
||||
self.epoch_losses = []
|
||||
|
||||
def on_lr_scheduler_step_end(self, trainer: "TrainerWithWandb") -> None:
|
||||
trainer.wandb_log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]})
|
||||
|
||||
def on_backward_end(self, trainer: "TrainerWithWandb") -> None:
|
||||
trainer.wandb_log(data={"total_grad_norm": trainer.total_gradient_norm})
|
||||
|
||||
|
||||
class WandbMixin(ABC):
|
||||
config: Any
|
||||
wandb_logger: WandbLogger
|
||||
|
||||
def load_wandb(self) -> None:
|
||||
wandb_config = getattr(self.config, "wandb", None)
|
||||
assert wandb_config is not None and isinstance(wandb_config, WandbConfig), "Wandb config is not set"
|
||||
init_config = {**wandb_config.model_dump(), "config": self.config.model_dump()}
|
||||
self.wandb_logger = WandbLogger(init_config=init_config)
|
||||
|
||||
def wandb_log(self, data: dict[str, WandbLoggable]) -> None:
|
||||
assert isinstance(self, Trainer), "WandbMixin must be mixed with a Trainer"
|
||||
self.wandb_logger.log(data=data, step=self.clock.step)
|
||||
|
||||
@cached_property
|
||||
def wandb_callback(self) -> WandbCallback:
|
||||
return WandbCallback()
|
||||
|
||||
|
||||
class TrainerWithWandb(AnyTrainer, WandbMixin, ABC):
|
||||
pass
|
||||
|
|
|
@ -24,10 +24,3 @@ warmup = "20:step"
|
|||
|
||||
[dropout]
|
||||
dropout = 0.0
|
||||
|
||||
[checkpointing]
|
||||
save_interval = "10:epoch"
|
||||
|
||||
[wandb]
|
||||
mode = "disabled"
|
||||
project = "mock_project"
|
||||
|
|
|
@ -119,7 +119,6 @@ def training_clock() -> TrainingClock:
|
|||
gradient_accumulation={"number": 1, "unit": TimeUnit.EPOCH},
|
||||
evaluation_interval={"number": 1, "unit": TimeUnit.EPOCH},
|
||||
lr_scheduler_interval={"number": 1, "unit": TimeUnit.EPOCH},
|
||||
checkpointing_save_interval={"number": 1, "unit": TimeUnit.EPOCH},
|
||||
)
|
||||
|
||||
|
||||
|
@ -153,10 +152,8 @@ def test_timer_functionality(training_clock: TrainingClock) -> None:
|
|||
def test_state_based_properties(training_clock: TrainingClock) -> None:
|
||||
training_clock.step = 5 # Halfway through the first epoch
|
||||
assert not training_clock.is_evaluation_step # Assuming evaluation every epoch
|
||||
assert not training_clock.is_checkpointing_step
|
||||
training_clock.step = 10 # End of the first epoch
|
||||
assert training_clock.is_evaluation_step
|
||||
assert training_clock.is_checkpointing_step
|
||||
|
||||
|
||||
def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: MockTrainer) -> None:
|
||||
|
|
Loading…
Reference in a new issue