mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 22:58:45 +00:00
change TimeValue to a dataclass
This commit is contained in:
parent
b8fae60d38
commit
6a72943ff7
|
@ -220,13 +220,14 @@ Example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulers
|
from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulers
|
||||||
|
from refiners.training_utils.common import TimeUnit, TimeValue
|
||||||
|
|
||||||
class AutoencoderConfig(BaseConfig):
|
class AutoencoderConfig(BaseConfig):
|
||||||
# Since we are using a synthetic dataset, we will use a arbitrary fixed epoch size.
|
# Since we are using a synthetic dataset, we will use a arbitrary fixed epoch size.
|
||||||
epoch_size: int = 2048
|
epoch_size: int = 2048
|
||||||
|
|
||||||
training = TrainingConfig(
|
training = TrainingConfig(
|
||||||
duration="1000:epoch",
|
duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
dtype="float32"
|
dtype="float32"
|
||||||
|
@ -335,11 +336,11 @@ We can also evaluate the model using the `compute_evaluation` method.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
training = TrainingConfig(
|
training = TrainingConfig(
|
||||||
duration="1000:epoch",
|
duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
evaluation_interval="50:epoch" # We set the evaluation to be done every 10 epochs
|
evaluation_interval=TimeValue(number=50, unit=TimeUnit.EPOCH),
|
||||||
)
|
)
|
||||||
|
|
||||||
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
||||||
|
@ -459,6 +460,8 @@ You can train this toy model using the code below:
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from refiners.fluxion import layers as fl
|
from refiners.fluxion import layers as fl
|
||||||
from refiners.fluxion.utils import image_to_tensor, tensor_to_image
|
from refiners.fluxion.utils import image_to_tensor, tensor_to_image
|
||||||
from refiners.training_utils import (
|
from refiners.training_utils import (
|
||||||
|
@ -476,7 +479,7 @@ You can train this toy model using the code below:
|
||||||
register_callback,
|
register_callback,
|
||||||
register_model,
|
register_model,
|
||||||
)
|
)
|
||||||
from torch.nn import functional as F
|
from refiners.training_utils.common import TimeUnit, TimeValue
|
||||||
|
|
||||||
|
|
||||||
class ConvBlock(fl.Chain):
|
class ConvBlock(fl.Chain):
|
||||||
|
@ -487,7 +490,7 @@ You can train this toy model using the code below:
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
padding=1,
|
padding=1,
|
||||||
groups=min(in_channels, out_channels)
|
groups=min(in_channels, out_channels),
|
||||||
),
|
),
|
||||||
fl.LayerNorm2d(out_channels),
|
fl.LayerNorm2d(out_channels),
|
||||||
fl.SiLU(),
|
fl.SiLU(),
|
||||||
|
@ -576,9 +579,7 @@ You can train this toy model using the code below:
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
rectangle = Image.new(
|
rectangle = Image.new("L", (random.randint(1, size), random.randint(1, size)), color=255)
|
||||||
"L", (random.randint(1, size), random.randint(1, size)), color=255
|
|
||||||
)
|
|
||||||
mask = Image.new("L", (size, size))
|
mask = Image.new("L", (size, size))
|
||||||
mask.paste(
|
mask.paste(
|
||||||
rectangle,
|
rectangle,
|
||||||
|
@ -627,11 +628,11 @@ You can train this toy model using the code below:
|
||||||
)
|
)
|
||||||
|
|
||||||
training = TrainingConfig(
|
training = TrainingConfig(
|
||||||
duration="1000:epoch", # type: ignore
|
duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
evaluation_interval="50:epoch", # type: ignore
|
evaluation_interval=TimeValue(number=50, unit=TimeUnit.EPOCH),
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer = OptimizerConfig(
|
optimizer = OptimizerConfig(
|
||||||
|
@ -639,9 +640,7 @@ You can train this toy model using the code below:
|
||||||
learning_rate=1e-4,
|
learning_rate=1e-4,
|
||||||
)
|
)
|
||||||
|
|
||||||
lr_scheduler = LRSchedulerConfig(
|
lr_scheduler = LRSchedulerConfig(type=LRSchedulerType.CONSTANT_LR)
|
||||||
type=LRSchedulerType.CONSTANT_LR
|
|
||||||
)
|
|
||||||
|
|
||||||
config = AutoencoderConfig(
|
config = AutoencoderConfig(
|
||||||
training=training,
|
training=training,
|
||||||
|
@ -672,9 +671,7 @@ You can train this toy model using the code below:
|
||||||
return Autoencoder()
|
return Autoencoder()
|
||||||
|
|
||||||
def compute_loss(self, batch: Batch) -> torch.Tensor:
|
def compute_loss(self, batch: Batch) -> torch.Tensor:
|
||||||
x_reconstructed = self.autoencoder.decoder(
|
x_reconstructed = self.autoencoder.decoder(self.autoencoder.encoder(batch.image))
|
||||||
self.autoencoder.encoder(batch.image)
|
|
||||||
)
|
|
||||||
return F.binary_cross_entropy(x_reconstructed, batch.image)
|
return F.binary_cross_entropy(x_reconstructed, batch.image)
|
||||||
|
|
||||||
def compute_evaluation(self) -> None:
|
def compute_evaluation(self) -> None:
|
||||||
|
|
|
@ -48,11 +48,11 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def unit_to_steps(self) -> dict[TimeUnit, int]:
|
def unit_to_steps(self) -> dict[TimeUnit, int]:
|
||||||
iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation["unit"] == TimeUnit.EPOCH else 1
|
iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation.unit == TimeUnit.EPOCH else 1
|
||||||
return {
|
return {
|
||||||
TimeUnit.STEP: 1,
|
TimeUnit.STEP: 1,
|
||||||
TimeUnit.EPOCH: self.num_batches_per_epoch,
|
TimeUnit.EPOCH: self.num_batches_per_epoch,
|
||||||
TimeUnit.ITERATION: self.gradient_accumulation["number"] * iteration_factor,
|
TimeUnit.ITERATION: self.gradient_accumulation.number * iteration_factor,
|
||||||
}
|
}
|
||||||
|
|
||||||
def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
|
def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
|
||||||
|
@ -62,7 +62,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
return steps // self.unit_to_steps[unit]
|
return steps // self.unit_to_steps[unit]
|
||||||
|
|
||||||
def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
|
def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
|
||||||
number, unit = time_value["number"], time_value["unit"]
|
number, unit = time_value.number, time_value.unit
|
||||||
steps = self.convert_time_unit_to_steps(number=number, unit=unit)
|
steps = self.convert_time_unit_to_steps(number=number, unit=unit)
|
||||||
return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
|
return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
|
||||||
|
|
||||||
|
@ -81,13 +81,13 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
@cached_property
|
@cached_property
|
||||||
def num_step_per_iteration(self) -> int:
|
def num_step_per_iteration(self) -> int:
|
||||||
return self.convert_time_unit_to_steps(
|
return self.convert_time_unit_to_steps(
|
||||||
number=self.gradient_accumulation["number"], unit=self.gradient_accumulation["unit"]
|
number=self.gradient_accumulation.number, unit=self.gradient_accumulation.unit
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def num_step_per_evaluation(self) -> int:
|
def num_step_per_evaluation(self) -> int:
|
||||||
return self.convert_time_unit_to_steps(
|
return self.convert_time_unit_to_steps(
|
||||||
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
|
number=self.evaluation_interval.number, unit=self.evaluation_interval.unit
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
@ -113,13 +113,13 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
@cached_property
|
@cached_property
|
||||||
def evaluation_interval_steps(self) -> int:
|
def evaluation_interval_steps(self) -> int:
|
||||||
return self.convert_time_unit_to_steps(
|
return self.convert_time_unit_to_steps(
|
||||||
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
|
number=self.evaluation_interval.number, unit=self.evaluation_interval.unit
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def lr_scheduler_interval_steps(self) -> int:
|
def lr_scheduler_interval_steps(self) -> int:
|
||||||
return self.convert_time_unit_to_steps(
|
return self.convert_time_unit_to_steps(
|
||||||
number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"]
|
number=self.lr_scheduler_interval.number, unit=self.lr_scheduler_interval.unit
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import random
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Callable, Iterable
|
from typing import Any, Callable, Iterable
|
||||||
|
@ -7,7 +8,6 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from torch import Tensor, cuda, nn
|
from torch import Tensor, cuda, nn
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from refiners.fluxion.utils import manual_seed
|
from refiners.fluxion.utils import manual_seed
|
||||||
|
|
||||||
|
@ -79,26 +79,32 @@ def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., C
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
class TimeUnit(Enum):
|
class TimeUnit(str, Enum):
|
||||||
STEP = "step"
|
STEP = "step"
|
||||||
EPOCH = "epoch"
|
EPOCH = "epoch"
|
||||||
ITERATION = "iteration"
|
ITERATION = "iteration"
|
||||||
DEFAULT = "step"
|
DEFAULT = "step"
|
||||||
|
|
||||||
|
|
||||||
class TimeValue(TypedDict):
|
@dataclass
|
||||||
|
class TimeValue:
|
||||||
number: int
|
number: int
|
||||||
unit: TimeUnit
|
unit: TimeUnit
|
||||||
|
|
||||||
|
|
||||||
def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValue:
|
TimeValueInput = str | int | dict[str, str | int] | TimeValue
|
||||||
|
|
||||||
|
|
||||||
|
def parse_number_unit_field(value: TimeValueInput) -> TimeValue:
|
||||||
match value:
|
match value:
|
||||||
case str(value_str):
|
case str(value_str):
|
||||||
number, unit = value_str.split(sep=":")
|
number, unit = value_str.split(sep=":")
|
||||||
return {"number": int(number.strip()), "unit": TimeUnit(value=unit.strip().lower())}
|
return TimeValue(number=int(number.strip()), unit=TimeUnit(value=unit.strip().lower()))
|
||||||
case int(number):
|
case int(number):
|
||||||
return {"number": number, "unit": TimeUnit.DEFAULT}
|
return TimeValue(number=number, unit=TimeUnit.DEFAULT)
|
||||||
case {"number": int(number), "unit": str(unit)}:
|
case {"number": int(number), "unit": str(unit)}:
|
||||||
return {"number": number, "unit": TimeUnit(value=unit.lower())}
|
return TimeValue(number=number, unit=TimeUnit(value=unit.lower()))
|
||||||
|
case TimeValue(number, unit):
|
||||||
|
return TimeValue(number=number, unit=unit)
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unsupported value format: {value}")
|
raise ValueError(f"Unsupported value format: {value}")
|
||||||
|
|
|
@ -23,11 +23,11 @@ ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
|
||||||
class TrainingConfig(BaseModel):
|
class TrainingConfig(BaseModel):
|
||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
dtype: str = "float32"
|
dtype: str = "float32"
|
||||||
duration: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
|
duration: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
|
||||||
seed: int = 0
|
seed: int = 0
|
||||||
batch_size: int = 1
|
batch_size: int = 1
|
||||||
gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP}
|
gradient_accumulation: TimeValue = TimeValue(number=1, unit=TimeUnit.STEP)
|
||||||
evaluation_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
|
evaluation_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
|
||||||
evaluation_seed: int = 0
|
evaluation_seed: int = 0
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
@ -63,8 +63,8 @@ class LRSchedulerType(str, Enum):
|
||||||
|
|
||||||
class LRSchedulerConfig(BaseModel):
|
class LRSchedulerConfig(BaseModel):
|
||||||
type: LRSchedulerType = LRSchedulerType.DEFAULT
|
type: LRSchedulerType = LRSchedulerType.DEFAULT
|
||||||
update_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
|
update_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
|
||||||
warmup: TimeValue = {"number": 0, "unit": TimeUnit.ITERATION}
|
warmup: TimeValue = TimeValue(number=0, unit=TimeUnit.ITERATION)
|
||||||
gamma: float = 0.1
|
gamma: float = 0.1
|
||||||
lr_lambda: Callable[[int], float] | None = None
|
lr_lambda: Callable[[int], float] | None = None
|
||||||
mode: Literal["min", "max"] = "min"
|
mode: Literal["min", "max"] = "min"
|
||||||
|
|
|
@ -241,7 +241,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
@cached_property
|
@cached_property
|
||||||
def lr_scheduler(self) -> LRScheduler:
|
def lr_scheduler(self) -> LRScheduler:
|
||||||
config = self.config.lr_scheduler
|
config = self.config.lr_scheduler
|
||||||
scheduler_step_size = config.update_interval["number"]
|
scheduler_step_size = config.update_interval.number
|
||||||
|
|
||||||
match config.type:
|
match config.type:
|
||||||
case LRSchedulerType.CONSTANT_LR:
|
case LRSchedulerType.CONSTANT_LR:
|
||||||
|
@ -286,7 +286,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unknown scheduler type: {config.type}")
|
raise ValueError(f"Unknown scheduler type: {config.type}")
|
||||||
|
|
||||||
warmup_scheduler_steps = self.clock.convert_time_value(config.warmup, config.update_interval["unit"])
|
warmup_scheduler_steps = self.clock.convert_time_value(config.warmup, config.update_interval.unit)
|
||||||
if warmup_scheduler_steps > 0:
|
if warmup_scheduler_steps > 0:
|
||||||
lr_scheduler = WarmupScheduler(
|
lr_scheduler = WarmupScheduler(
|
||||||
optimizer=self.optimizer,
|
optimizer=self.optimizer,
|
||||||
|
|
33
tests/training_utils/test_common.py
Normal file
33
tests/training_utils/test_common.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_value, expected_output",
|
||||||
|
[
|
||||||
|
("10: step", TimeValue(number=10, unit=TimeUnit.STEP)),
|
||||||
|
("20 :epoch", TimeValue(number=20, unit=TimeUnit.EPOCH)),
|
||||||
|
("30: Iteration", TimeValue(number=30, unit=TimeUnit.ITERATION)),
|
||||||
|
(50, TimeValue(number=50, unit=TimeUnit.DEFAULT)),
|
||||||
|
({"number": 100, "unit": "STEP"}, TimeValue(number=100, unit=TimeUnit.STEP)),
|
||||||
|
(TimeValue(number=200, unit=TimeUnit.EPOCH), TimeValue(number=200, unit=TimeUnit.EPOCH)),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: TimeValue):
|
||||||
|
result = parse_number_unit_field(input_value)
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"invalid_input",
|
||||||
|
[
|
||||||
|
"invalid:input",
|
||||||
|
{"number": "not_a_number", "unit": "step"},
|
||||||
|
{"invalid_key": 10},
|
||||||
|
None,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_parse_number_unit_field_invalid_input(invalid_input: TimeValueInput):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
parse_number_unit_field(invalid_input)
|
|
@ -10,7 +10,7 @@ from torch.optim import SGD
|
||||||
|
|
||||||
from refiners.fluxion import layers as fl
|
from refiners.fluxion import layers as fl
|
||||||
from refiners.fluxion.utils import norm
|
from refiners.fluxion.utils import norm
|
||||||
from refiners.training_utils.common import TimeUnit, count_learnable_parameters, human_readable_number
|
from refiners.training_utils.common import TimeUnit, TimeValue, count_learnable_parameters, human_readable_number
|
||||||
from refiners.training_utils.config import BaseConfig, ModelConfig
|
from refiners.training_utils.config import BaseConfig, ModelConfig
|
||||||
from refiners.training_utils.trainer import (
|
from refiners.training_utils.trainer import (
|
||||||
Trainer,
|
Trainer,
|
||||||
|
@ -96,7 +96,7 @@ def mock_trainer(mock_config: MockConfig) -> MockTrainer:
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_trainer_short(mock_config: MockConfig) -> MockTrainer:
|
def mock_trainer_short(mock_config: MockConfig) -> MockTrainer:
|
||||||
mock_config_short = mock_config.model_copy(deep=True)
|
mock_config_short = mock_config.model_copy(deep=True)
|
||||||
mock_config_short.training.duration = {"number": 3, "unit": TimeUnit.STEP}
|
mock_config_short.training.duration = TimeValue(number=3, unit=TimeUnit.STEP)
|
||||||
return MockTrainer(config=mock_config_short)
|
return MockTrainer(config=mock_config_short)
|
||||||
|
|
||||||
|
|
||||||
|
@ -130,10 +130,10 @@ def training_clock() -> TrainingClock:
|
||||||
return TrainingClock(
|
return TrainingClock(
|
||||||
dataset_length=100,
|
dataset_length=100,
|
||||||
batch_size=10,
|
batch_size=10,
|
||||||
training_duration={"number": 5, "unit": TimeUnit.EPOCH},
|
training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
|
||||||
gradient_accumulation={"number": 1, "unit": TimeUnit.EPOCH},
|
gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
||||||
evaluation_interval={"number": 1, "unit": TimeUnit.EPOCH},
|
evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
||||||
lr_scheduler_interval={"number": 1, "unit": TimeUnit.EPOCH},
|
lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -183,7 +183,7 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
|
||||||
clock = mock_trainer.clock
|
clock = mock_trainer.clock
|
||||||
config = mock_trainer.config
|
config = mock_trainer.config
|
||||||
|
|
||||||
assert clock.num_step_per_iteration == config.training.gradient_accumulation["number"]
|
assert clock.num_step_per_iteration == config.training.gradient_accumulation.number
|
||||||
assert clock.num_batches_per_epoch == mock_trainer.dataset_length // config.training.batch_size
|
assert clock.num_batches_per_epoch == mock_trainer.dataset_length // config.training.batch_size
|
||||||
|
|
||||||
assert mock_trainer.step_counter == 0
|
assert mock_trainer.step_counter == 0
|
||||||
|
@ -191,8 +191,8 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
|
||||||
|
|
||||||
mock_trainer.train()
|
mock_trainer.train()
|
||||||
|
|
||||||
assert clock.epoch == config.training.duration["number"]
|
assert clock.epoch == config.training.duration.number
|
||||||
assert clock.step == config.training.duration["number"] * clock.num_batches_per_epoch
|
assert clock.step == config.training.duration.number * clock.num_batches_per_epoch
|
||||||
|
|
||||||
assert mock_trainer.step_counter == mock_trainer.clock.step
|
assert mock_trainer.step_counter == mock_trainer.clock.step
|
||||||
|
|
||||||
|
@ -206,7 +206,7 @@ def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None:
|
||||||
|
|
||||||
mock_trainer_short.train()
|
mock_trainer_short.train()
|
||||||
|
|
||||||
assert clock.step == config.training.duration["number"]
|
assert clock.step == config.training.duration.number
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
Loading…
Reference in a new issue