remove wandb from base config

limiteinductive 2024-02-07 09:37:10 +00:00 committed by Benjamin Trom
parent 11da76f7df
commit 2ef4982e04
8 changed files with 108 additions and 286 deletions

View file

@@ -1,162 +0,0 @@
import random
from typing import Any
from loguru import logger
from pydantic import BaseModel
from torch import Tensor, randn
from torch.utils.data import Dataset
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.clip.concepts import ConceptExtender, EmbeddingExtender
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.training_utils.callback import Callback
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
from refiners.training_utils.latent_diffusion import (
FinetuneLatentDiffusionConfig,
LatentDiffusionConfig,
LatentDiffusionTrainer,
TextEmbeddingLatentsBatch,
TextEmbeddingLatentsDataset,
)
IMAGENET_TEMPLATES_SMALL = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
IMAGENET_STYLE_TEMPLATES_SMALL = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]
class TextualInversionDataset(TextEmbeddingLatentsDataset):
templates: list[str] = []
placeholder_token: str = ""
def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None:
super().__init__(trainer)
self.templates = (
IMAGENET_STYLE_TEMPLATES_SMALL if self.config.textual_inversion.style_mode else IMAGENET_TEMPLATES_SMALL
)
self.placeholder_token = self.config.textual_inversion.placeholder_token
def get_caption(self, index: int) -> str:
# Ignore the dataset caption, if any: use a template instead
return random.choice(self.templates).format(self.placeholder_token)
class TextualInversionConfig(BaseModel):
# The new token to be learned
placeholder_token: str = "*"
# The token to be used as initializer; if None, a random vector is used
initializer_token: str | None = None
style_mode: bool = False
def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> None:
adapter = ConceptExtender(target=text_encoder)
tokenizer = text_encoder.ensure_find(CLIPTokenizer)
token_encoder = text_encoder.ensure_find(TokenEncoder)
if self.initializer_token is not None:
bpe = tokenizer.byte_pair_encoding(token=self.initializer_token)
assert " " not in bpe, "This initializer_token is not a single token."
token = Tensor([tokenizer.token_to_id_mapping[bpe]]).int().to(text_encoder.device)
init_embedding = token_encoder(token).squeeze(0)
else:
token_encoder = text_encoder.ensure_find(TokenEncoder)
init_embedding = randn(token_encoder.embedding_dim)
adapter.add_concept(self.placeholder_token, init_embedding)
adapter.inject()
class TextualInversionLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
dataset: HuggingfaceDatasetConfig
latent_diffusion: LatentDiffusionConfig
textual_inversion: TextualInversionConfig
def model_post_init(self, __context: Any) -> None:
# Pydantic v2 does post init differently, so we need to override this method too.
logger.info("Freezing models to train only the new embedding.")
self.models["unet"].train = False
self.models["text_encoder"].train = False
self.models["lda"].train = False
class TextualInversionLatentDiffusionTrainer(LatentDiffusionTrainer[TextualInversionLatentDiffusionConfig]):
def __init__(
self,
config: TextualInversionLatentDiffusionConfig,
callbacks: "list[Callback[Any]] | None" = None,
) -> None:
super().__init__(config=config, callbacks=callbacks)
self.callbacks.extend((LoadTextualInversion(), SaveTextualInversion()))
def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
return TextualInversionDataset(trainer=self)
class LoadTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
def on_train_begin(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
trainer.config.textual_inversion.apply_textual_inversion_to_target(text_encoder=trainer.text_encoder)
class SaveTextualInversion(Callback[TextualInversionLatentDiffusionTrainer]):
def on_checkpoint_save(self, trainer: TextualInversionLatentDiffusionTrainer) -> None:
embedding_extender = trainer.text_encoder.ensure_find(EmbeddingExtender)
tensors = {trainer.config.textual_inversion.placeholder_token: embedding_extender.new_weight.squeeze(0)}
save_to_safetensors(
path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors", tensors=tensors
)
if __name__ == "__main__":
import sys
config_path = sys.argv[1]
config = TextualInversionLatentDiffusionConfig.load_from_toml(toml_path=config_path)
trainer = TextualInversionLatentDiffusionTrainer(config=config)
trainer.train()

View file

@@ -13,8 +13,6 @@ __all__ = [
"GradientNormClipping",
"GradientValueClipping",
"ClockCallback",
"GradientNormLogging",
"MonitorLoss",
]
@@ -153,31 +151,6 @@ class ClockCallback(Callback["Trainer[BaseConfig, Any]"]):
logger.info("Evaluation ended.")
class MonitorLoss(Callback["Trainer[BaseConfig, Any]"]):
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.epoch_losses: list[float] = []
self.iteration_losses: list[float] = []
def on_compute_loss_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
loss_value = trainer.loss.detach().cpu().item()
self.epoch_losses.append(loss_value)
self.iteration_losses.append(loss_value)
trainer.log(data={"step_loss": loss_value})
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
trainer.log(data={"average_iteration_loss": avg_iteration_loss})
self.iteration_losses = []
def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
avg_epoch_loss = sum(self.epoch_losses) / len(self.epoch_losses)
trainer.log(data={"average_epoch_loss": avg_epoch_loss, "epoch": trainer.clock.epoch})
self.epoch_losses = []
def on_lr_scheduler_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]})
class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]):
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
clip_norm = trainer.config.training.clip_grad_norm
@@ -192,8 +165,3 @@ class GradientValueClipping(Callback["Trainer[BaseConfig, Any]"]):
clip_value = trainer.config.training.clip_grad_value
if clip_value is not None:
clip_gradient_value(parameters=trainer.learnable_parameters, clip_value=clip_value)
class GradientNormLogging(Callback["Trainer[BaseConfig, Any]"]):
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
trainer.log(data={"total_grad_norm": trainer.total_gradient_norm})

View file

@@ -196,37 +196,15 @@ class DropoutConfig(BaseModel):
apply_dropout(module=model, probability=self.dropout_probability)
class WandbConfig(BaseModel):
mode: Literal["online", "offline", "disabled"] = "online"
project: str
entity: str = "finegrain"
name: str | None = None
tags: list[str] = []
group: str | None = None
job_type: str | None = None
notes: str | None = None
class CheckpointingConfig(BaseModel):
save_folder: Path | None = None
save_interval: TimeValue = {"number": 1, "unit": TimeUnit.EPOCH}
@validator("save_interval", pre=True)
def parse_field(cls, value: Any) -> TimeValue:
return parse_number_unit_field(value)
T = TypeVar("T", bound="BaseConfig")
class BaseConfig(BaseModel):
models: dict[str, ModelConfig]
wandb: WandbConfig
training: TrainingConfig
optimizer: OptimizerConfig
scheduler: SchedulerConfig
dropout: DropoutConfig
checkpointing: CheckpointingConfig
@classmethod
def load_from_toml(cls: Type[T], toml_path: Path | str) -> T:
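After this change BaseConfig no longer declares wandb (or checkpointing), so a project that still wants Weights & Biases declares the field on its own config subclass. A minimal sketch, assuming the WandbConfig added to the wandb module further down in this commit; MyConfig and the TOML path are hypothetical names used only for illustration:

from refiners.training_utils.config import BaseConfig
from refiners.training_utils.wandb import WandbConfig


class MyConfig(BaseConfig):
    # Looked up by WandbMixin.load_wandb via getattr(self.config, "wandb", None).
    wandb: WandbConfig


# The [wandb] table of the TOML file maps onto the field declared above.
config = MyConfig.load_from_toml(toml_path="configs/my_run.toml")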

View file

@@ -204,7 +204,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
)
canvas_image.paste(sd.lda.latents_to_image(x=x), box=(0, 512 * i))
images[prompt] = canvas_image
self.log(data=images)
# TODO: wandb_log images
def sample_noise(
@@ -257,4 +257,4 @@ class MonitorTimestepLoss(Callback[LatentDiffusionTrainer[Any]]):
log_data[f"average_loss_timestep_bin_{bin_index * 100}"] = avg_loss
self.timestep_bins[bin_index] = []
trainer.log(data=log_data)
# TODO: wandb_log log_data

View file

@@ -2,7 +2,6 @@ import random
import time
from abc import ABC, abstractmethod
from functools import cached_property, wraps
from pathlib import Path
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
import numpy as np
@@ -33,13 +32,10 @@ from refiners.training_utils.callback import (
Callback,
ClockCallback,
GradientNormClipping,
GradientNormLogging,
GradientValueClipping,
MonitorLoss,
)
from refiners.training_utils.config import BaseConfig, SchedulerType, TimeUnit, TimeValue
from refiners.training_utils.dropout import DropoutCallback
from refiners.training_utils.wandb import WandbLoggable, WandbLogger
__all__ = ["seed_everything", "scoped_seed", "Trainer"]
@@ -130,7 +126,6 @@ class TrainingClock:
gradient_accumulation: TimeValue,
evaluation_interval: TimeValue,
lr_scheduler_interval: TimeValue,
checkpointing_save_interval: TimeValue,
) -> None:
self.dataset_length = dataset_length
self.batch_size = batch_size
@@ -138,7 +133,6 @@
self.gradient_accumulation = gradient_accumulation
self.evaluation_interval = evaluation_interval
self.lr_scheduler_interval = lr_scheduler_interval
self.checkpointing_save_interval = checkpointing_save_interval
self.num_batches_per_epoch = dataset_length // batch_size
self.start_time = None
self.end_time = None
@@ -225,12 +219,6 @@
number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"]
)
@cached_property
def checkpointing_save_interval_steps(self) -> int:
return self.convert_time_unit_to_steps(
number=self.checkpointing_save_interval["number"], unit=self.checkpointing_save_interval["unit"]
)
@property
def is_optimizer_step(self) -> bool:
return self.num_minibatches_processed == self.num_step_per_iteration
@@ -247,10 +235,6 @@
def is_evaluation_step(self) -> bool:
return self.step % self.evaluation_interval_steps == 0
@property
def is_checkpointing_step(self) -> bool:
return self.step % self.checkpointing_save_interval_steps == 0
def compute_grad_norm(parameters: Iterable[Parameter]) -> float:
"""
@@ -276,27 +260,35 @@
evaluation_interval=config.training.evaluation_interval,
gradient_accumulation=config.training.gradient_accumulation,
lr_scheduler_interval=config.scheduler.update_interval,
checkpointing_save_interval=config.checkpointing.save_interval,
)
self.callbacks = callbacks or []
self.callbacks += self.default_callbacks()
self._call_callbacks(event_name="on_init_begin")
self.load_wandb()
self.load_models()
self.prepare_models()
self.prepare_checkpointing()
self._call_callbacks(event_name="on_init_end")
def default_callbacks(self) -> list[Callback[Any]]:
return [
callbacks: list[Callback[Any]] = [
ClockCallback(),
MonitorLoss(),
GradientNormLogging(),
GradientValueClipping(),
GradientNormClipping(),
DropoutCallback(),
]
# look for any Callback that might be a property of the Trainer
for attr_name in dir(self):
if "__" in attr_name:
continue
try:
attr = getattr(self, attr_name)
except AssertionError:
continue
if isinstance(attr, Callback):
callbacks.append(cast(Callback[Any], attr))
return callbacks
@cached_property
def device(self) -> Device:
selected_device = Device(self.config.training.device)
@@ -417,13 +409,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
for model in self.models.values():
model.eval()
def log(self, data: dict[str, WandbLoggable]) -> None:
self.wandb.log(data=data, step=self.clock.step)
def load_wandb(self) -> None:
init_config = {**self.config.wandb.model_dump(), "config": self.config.model_dump()}
self.wandb = WandbLogger(init_config=init_config)
def prepare_model(self, model_name: str) -> None:
model = self.models[model_name]
if (checkpoint := self.config.models[model_name].checkpoint) is not None:
@@ -439,18 +424,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
for model_name in self.models:
self.prepare_model(model_name=model_name)
def prepare_checkpointing(self) -> None:
if self.config.checkpointing.save_folder is not None:
assert self.config.checkpointing.save_folder.is_dir()
self.checkpoints_save_folder = (
self.config.checkpointing.save_folder / self.wandb.project_name / self.wandb.run_name
)
self.checkpoints_save_folder.mkdir(parents=True, exist_ok=False)
logger.info(f"Checkpointing enabled: {self.checkpoints_save_folder}")
else:
self.checkpoints_save_folder = None
logger.info("Checkpointing disabled: configure `save_folder` to turn it on.")
@abstractmethod
def load_models(self) -> dict[str, fl.Module]:
...
@@ -475,15 +448,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=collate_fn
)
@property
def checkpointing_enabled(self) -> bool:
return self.checkpoints_save_folder is not None
@property
def ensure_checkpoints_save_folder(self) -> Path:
assert self.checkpoints_save_folder is not None
return self.checkpoints_save_folder
@abstractmethod
def compute_loss(self, batch: Batch) -> Tensor:
...
@@ -508,8 +472,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
self._call_callbacks(event_name="on_lr_scheduler_step_end")
if self.clock.is_evaluation_step:
self.evaluate()
if self.checkpointing_enabled and self.clock.is_checkpointing_step:
self._call_callbacks(event_name="on_checkpoint_save")
def step(self, batch: Batch) -> None:
"""Perform a single training step."""

View file

@@ -1,13 +1,15 @@
from typing import Any
from abc import ABC
from functools import cached_property
from pathlib import Path
from typing import Any, Literal
import wandb
from PIL import Image
from pydantic import BaseModel
__all__ = [
"WandbLogger",
"WandbLoggable",
]
from refiners.training_utils.callback import Callback
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer
number = float | int
WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]]
@@ -60,3 +62,87 @@ class WandbLogger:
@property
def run_name(self) -> str:
return self.wandb_run.name or "" # type: ignore
class WandbConfig(BaseModel):
"""
Wandb configuration.
See https://docs.wandb.ai/ref/python/init for more details.
"""
mode: Literal["online", "offline", "disabled"] = "disabled"
project: str
entity: str | None = None
save_code: bool | None = None
name: str | None = None
tags: list[str] = []
group: str | None = None
job_type: str | None = None
notes: str | None = None
dir: Path | None = None
resume: bool | Literal["allow", "must", "never", "auto"] | None = None
reinit: bool | None = None
magic: bool | None = None
anonymous: Literal["never", "allow", "must"] | None = None
id: str | None = None
AnyTrainer = Trainer[BaseConfig, Any]
class WandbCallback(Callback["TrainerWithWandb"]):
epoch_losses: list[float]
iteration_losses: list[float]
def on_init_begin(self, trainer: "TrainerWithWandb") -> None:
trainer.load_wandb()
def on_train_begin(self, trainer: "TrainerWithWandb") -> None:
self.epoch_losses = []
self.iteration_losses = []
def on_compute_loss_end(self, trainer: "TrainerWithWandb") -> None:
loss_value = trainer.loss.detach().cpu().item()
self.epoch_losses.append(loss_value)
self.iteration_losses.append(loss_value)
trainer.wandb_log(data={"step_loss": loss_value})
def on_optimizer_step_end(self, trainer: "TrainerWithWandb") -> None:
avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
trainer.wandb_log(data={"average_iteration_loss": avg_iteration_loss})
self.iteration_losses = []
def on_epoch_end(self, trainer: "TrainerWithWandb") -> None:
avg_epoch_loss = sum(self.epoch_losses) / len(self.epoch_losses)
trainer.wandb_log(data={"average_epoch_loss": avg_epoch_loss, "epoch": trainer.clock.epoch})
self.epoch_losses = []
def on_lr_scheduler_step_end(self, trainer: "TrainerWithWandb") -> None:
trainer.wandb_log(data={"learning_rate": trainer.optimizer.param_groups[0]["lr"]})
def on_backward_end(self, trainer: "TrainerWithWandb") -> None:
trainer.wandb_log(data={"total_grad_norm": trainer.total_gradient_norm})
class WandbMixin(ABC):
config: Any
wandb_logger: WandbLogger
def load_wandb(self) -> None:
wandb_config = getattr(self.config, "wandb", None)
assert wandb_config is not None and isinstance(wandb_config, WandbConfig), "Wandb config is not set"
init_config = {**wandb_config.model_dump(), "config": self.config.model_dump()}
self.wandb_logger = WandbLogger(init_config=init_config)
def wandb_log(self, data: dict[str, WandbLoggable]) -> None:
assert isinstance(self, Trainer), "WandbMixin must be mixed with a Trainer"
self.wandb_logger.log(data=data, step=self.clock.step)
@cached_property
def wandb_callback(self) -> WandbCallback:
return WandbCallback()
class TrainerWithWandb(AnyTrainer, WandbMixin, ABC):
pass
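Putting the pieces together, opting back into W&B after this refactor looks roughly like the following. A hedged sketch, not part of this commit; MyConfig and MyTrainer are hypothetical names and the batch type is simplified to Any, while everything imported comes from the code above:

from typing import Any

from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer
from refiners.training_utils.wandb import WandbConfig, WandbMixin


class MyConfig(BaseConfig):
    wandb: WandbConfig  # required by WandbMixin.load_wandb


class MyTrainer(Trainer[MyConfig, Any], WandbMixin):
    # wandb_callback, a cached_property on WandbMixin, is auto-registered by
    # default_callbacks(); it calls load_wandb() in on_init_begin and wandb_log()
    # for step/epoch losses, the learning rate and the total gradient norm.
    ...  # load_models, load_dataset, compute_loss, etc. omitted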

View file

@@ -24,10 +24,3 @@ warmup = "20:step"
[dropout]
dropout = 0.0
[checkpointing]
save_interval = "10:epoch"
[wandb]
mode = "disabled"
project = "mock_project"

View file

@@ -119,7 +119,6 @@ def training_clock() -> TrainingClock:
gradient_accumulation={"number": 1, "unit": TimeUnit.EPOCH},
evaluation_interval={"number": 1, "unit": TimeUnit.EPOCH},
lr_scheduler_interval={"number": 1, "unit": TimeUnit.EPOCH},
checkpointing_save_interval={"number": 1, "unit": TimeUnit.EPOCH},
)
@@ -153,10 +152,8 @@ def test_timer_functionality(training_clock: TrainingClock) -> None:
def test_state_based_properties(training_clock: TrainingClock) -> None:
training_clock.step = 5 # Halfway through the first epoch
assert not training_clock.is_evaluation_step # Assuming evaluation every epoch
assert not training_clock.is_checkpointing_step
training_clock.step = 10 # End of the first epoch
assert training_clock.is_evaluation_step
assert training_clock.is_checkpointing_step
def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: MockTrainer) -> None: