mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
Refactor TimeValue
This commit is contained in:
parent
17246708b9
commit
446796da57
|
@ -219,15 +219,14 @@ We will now define the configuration for the autoencoder. It holds the configura
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulers
|
from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulers, Epoch
|
||||||
from refiners.training_utils.common import TimeUnit, TimeValue
|
|
||||||
|
|
||||||
class AutoencoderConfig(BaseConfig):
|
class AutoencoderConfig(BaseConfig):
|
||||||
# Since we are using a synthetic dataset, we will use a arbitrary fixed epoch size.
|
# Since we are using a synthetic dataset, we will use a arbitrary fixed epoch size.
|
||||||
epoch_size: int = 2048
|
epoch_size: int = 2048
|
||||||
|
|
||||||
training = TrainingConfig(
|
training = TrainingConfig(
|
||||||
duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
|
duration=Epoch(1000),
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
dtype="float32"
|
dtype="float32"
|
||||||
|
@ -336,11 +335,11 @@ We can also evaluate the model using the `compute_evaluation` method.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
training = TrainingConfig(
|
training = TrainingConfig(
|
||||||
duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
|
duration=Epoch(1000)
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
evaluation_interval=TimeValue(number=50, unit=TimeUnit.EPOCH),
|
evaluation_interval=Epoch(50),
|
||||||
)
|
)
|
||||||
|
|
||||||
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
||||||
|
@ -478,9 +477,8 @@ You can train this toy model using the code below:
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
register_callback,
|
register_callback,
|
||||||
register_model,
|
register_model,
|
||||||
|
Epoch,
|
||||||
)
|
)
|
||||||
from refiners.training_utils.common import TimeUnit, TimeValue
|
|
||||||
|
|
||||||
|
|
||||||
class ConvBlock(fl.Chain):
|
class ConvBlock(fl.Chain):
|
||||||
def __init__(self, in_channels: int, out_channels: int) -> None:
|
def __init__(self, in_channels: int, out_channels: int) -> None:
|
||||||
|
@ -628,11 +626,11 @@ You can train this toy model using the code below:
|
||||||
)
|
)
|
||||||
|
|
||||||
training = TrainingConfig(
|
training = TrainingConfig(
|
||||||
duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
|
duration=Epoch(1000),
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
evaluation_interval=TimeValue(number=50, unit=TimeUnit.EPOCH),
|
evaluation_interval=Epoch(50),
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer = OptimizerConfig(
|
optimizer = OptimizerConfig(
|
||||||
|
@ -702,9 +700,9 @@ You can train this toy model using the code below:
|
||||||
axes[i, 1].axis("off")
|
axes[i, 1].axis("off")
|
||||||
axes[i, 1].set_title("Reconstructed")
|
axes[i, 1].set_title("Reconstructed")
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout() # type: ignore
|
||||||
plt.savefig(f"result_{trainer.clock.epoch}.png") # type: ignore
|
plt.savefig(f"result_{trainer.clock.epoch}.png") # type: ignore
|
||||||
plt.close()
|
plt.close() # type: ignore
|
||||||
|
|
||||||
@register_callback()
|
@register_callback()
|
||||||
def logging(self, config: CallbackConfig) -> LoggingCallback:
|
def logging(self, config: CallbackConfig) -> LoggingCallback:
|
||||||
|
|
|
@ -31,6 +31,7 @@ for dep in refiners_requires:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
from refiners.training_utils.callback import Callback, CallbackConfig
|
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||||
from refiners.training_utils.clock import ClockConfig
|
from refiners.training_utils.clock import ClockConfig
|
||||||
|
from refiners.training_utils.common import Epoch, Iteration, Step, TimeUnit, TimeValue
|
||||||
from refiners.training_utils.config import (
|
from refiners.training_utils.config import (
|
||||||
BaseConfig,
|
BaseConfig,
|
||||||
LRSchedulerConfig,
|
LRSchedulerConfig,
|
||||||
|
@ -59,4 +60,9 @@ __all__ = [
|
||||||
"ClockConfig",
|
"ClockConfig",
|
||||||
"Optimizers",
|
"Optimizers",
|
||||||
"LRSchedulerType",
|
"LRSchedulerType",
|
||||||
|
"TimeValue",
|
||||||
|
"TimeUnit",
|
||||||
|
"Epoch",
|
||||||
|
"Iteration",
|
||||||
|
"Step",
|
||||||
]
|
]
|
||||||
|
|
|
@ -3,7 +3,7 @@ from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from refiners.training_utils.callback import Callback, CallbackConfig
|
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||||
from refiners.training_utils.common import TimeUnit, TimeValue
|
from refiners.training_utils.common import Epoch, Iteration, Step, TimeUnit, TimeValue
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from refiners.training_utils.config import BaseConfig
|
from refiners.training_utils.config import BaseConfig
|
||||||
|
@ -52,47 +52,42 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def unit_to_steps(self) -> dict[TimeUnit, int]:
|
def unit_to_steps(self) -> dict[TimeUnit, int]:
|
||||||
iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation.unit == TimeUnit.EPOCH else 1
|
iteration_factor = self.num_batches_per_epoch if isinstance(self.gradient_accumulation, Epoch) else 1
|
||||||
return {
|
return {
|
||||||
TimeUnit.STEP: 1,
|
Step: 1,
|
||||||
TimeUnit.EPOCH: self.num_batches_per_epoch,
|
Epoch: self.num_batches_per_epoch,
|
||||||
TimeUnit.ITERATION: self.gradient_accumulation.number * iteration_factor,
|
Iteration: self.gradient_accumulation.number * iteration_factor,
|
||||||
}
|
}
|
||||||
|
|
||||||
def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
|
def convert_time_value_to_steps(self, time_value: TimeValue) -> int:
|
||||||
return number * self.unit_to_steps[unit]
|
return time_value.number * self.unit_to_steps[time_value.unit]
|
||||||
|
|
||||||
def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int:
|
def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int:
|
||||||
return steps // self.unit_to_steps[unit]
|
return steps // self.unit_to_steps[unit]
|
||||||
|
|
||||||
def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
|
def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
|
||||||
number, unit = time_value.number, time_value.unit
|
steps = self.convert_time_value_to_steps(time_value=time_value)
|
||||||
steps = self.convert_time_unit_to_steps(number=number, unit=unit)
|
|
||||||
return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
|
return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def num_epochs(self) -> int:
|
def num_epochs(self) -> int:
|
||||||
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.EPOCH)
|
return self.convert_time_value(time_value=self.training_duration, target_unit=Epoch)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def num_iterations(self) -> int:
|
def num_iterations(self) -> int:
|
||||||
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.ITERATION)
|
return self.convert_time_value(time_value=self.training_duration, target_unit=Iteration)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def num_steps(self) -> int:
|
def num_steps(self) -> int:
|
||||||
return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.STEP)
|
return self.convert_time_value(time_value=self.training_duration, target_unit=Step)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def num_step_per_iteration(self) -> int:
|
def num_step_per_iteration(self) -> int:
|
||||||
return self.convert_time_unit_to_steps(
|
return self.convert_time_value_to_steps(self.gradient_accumulation)
|
||||||
number=self.gradient_accumulation.number, unit=self.gradient_accumulation.unit
|
|
||||||
)
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def num_step_per_evaluation(self) -> int:
|
def num_step_per_evaluation(self) -> int:
|
||||||
return self.convert_time_unit_to_steps(
|
return self.convert_time_value_to_steps(self.evaluation_interval)
|
||||||
number=self.evaluation_interval.number, unit=self.evaluation_interval.unit
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.start_time = None
|
self.start_time = None
|
||||||
|
@ -116,15 +111,11 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def evaluation_interval_steps(self) -> int:
|
def evaluation_interval_steps(self) -> int:
|
||||||
return self.convert_time_unit_to_steps(
|
return self.convert_time_value_to_steps(self.evaluation_interval)
|
||||||
number=self.evaluation_interval.number, unit=self.evaluation_interval.unit
|
|
||||||
)
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def lr_scheduler_interval_steps(self) -> int:
|
def lr_scheduler_interval_steps(self) -> int:
|
||||||
return self.convert_time_unit_to_steps(
|
return self.convert_time_value_to_steps(self.lr_scheduler_interval)
|
||||||
number=self.lr_scheduler_interval.number, unit=self.lr_scheduler_interval.unit
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_optimizer_step(self) -> bool:
|
def is_optimizer_step(self) -> bool:
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from typing import Any, Callable, Iterable, Protocol, runtime_checkable
|
||||||
from typing import Any, Callable, Iterable
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -83,32 +82,67 @@ class scoped_seed:
|
||||||
cuda.set_rng_state(self.cuda_torch_state)
|
cuda.set_rng_state(self.cuda_torch_state)
|
||||||
|
|
||||||
|
|
||||||
class TimeUnit(str, Enum):
|
@dataclass
|
||||||
STEP = "step"
|
@runtime_checkable
|
||||||
EPOCH = "epoch"
|
class TimeValue(Protocol):
|
||||||
ITERATION = "iteration"
|
number: int
|
||||||
DEFAULT = "step"
|
|
||||||
|
@property
|
||||||
|
def unit(self) -> "TimeUnit":
|
||||||
|
match self.__class__.__name__:
|
||||||
|
case "Step":
|
||||||
|
return Step
|
||||||
|
case "Epoch":
|
||||||
|
return Epoch
|
||||||
|
case "Iteration":
|
||||||
|
return Iteration
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unsupported time unit: {self.__class__.__name__}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_str(cls, value: str) -> "TimeValue":
|
||||||
|
match cls.extract_number_unit(value):
|
||||||
|
case number, "step":
|
||||||
|
return Step(number)
|
||||||
|
case number, "epoch":
|
||||||
|
return Epoch(number)
|
||||||
|
case number, "iteration":
|
||||||
|
return Iteration(number)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Incorrect time value format: {value}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract_number_unit(value: str) -> tuple[int, str]:
|
||||||
|
number, unit = value.lower().split(":")
|
||||||
|
return int(number.strip()), unit.strip()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TimeValue:
|
class Step(TimeValue):
|
||||||
number: int
|
number: int
|
||||||
unit: TimeUnit
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Epoch(TimeValue):
|
||||||
|
number: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Iteration(TimeValue):
|
||||||
|
number: int
|
||||||
|
|
||||||
|
|
||||||
|
TimeUnit = type[Step] | type[Epoch] | type[Iteration]
|
||||||
TimeValueInput = str | int | dict[str, str | int] | TimeValue
|
TimeValueInput = str | int | dict[str, str | int] | TimeValue
|
||||||
|
|
||||||
|
|
||||||
def parse_number_unit_field(value: TimeValueInput) -> TimeValue:
|
def parse_number_unit_field(value: TimeValueInput) -> TimeValue:
|
||||||
match value:
|
match value:
|
||||||
case str(value_str):
|
case str(value_str):
|
||||||
number, unit = value_str.split(sep=":")
|
return TimeValue.from_str(value_str)
|
||||||
return TimeValue(number=int(number.strip()), unit=TimeUnit(value=unit.strip().lower()))
|
|
||||||
case int(number):
|
case int(number):
|
||||||
return TimeValue(number=number, unit=TimeUnit.DEFAULT)
|
return Step(number=number)
|
||||||
case {"number": int(number), "unit": str(unit)}:
|
case TimeValue(number):
|
||||||
return TimeValue(number=number, unit=TimeUnit(value=unit.lower()))
|
return value
|
||||||
case TimeValue(number, unit):
|
|
||||||
return TimeValue(number=number, unit=unit)
|
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unsupported value format: {value}")
|
raise ValueError(f"Unsupported value format: {value}")
|
||||||
|
|
|
@ -11,7 +11,7 @@ from torch import Tensor
|
||||||
from torch.optim import SGD, Adam, AdamW, Optimizer
|
from torch.optim import SGD, Adam, AdamW, Optimizer
|
||||||
|
|
||||||
from refiners.training_utils.clock import ClockConfig
|
from refiners.training_utils.clock import ClockConfig
|
||||||
from refiners.training_utils.common import TimeUnit, TimeValue, parse_number_unit_field
|
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, TimeValueInput, parse_number_unit_field
|
||||||
|
|
||||||
# PyTorch optimizer parameters type
|
# PyTorch optimizer parameters type
|
||||||
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
|
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
|
||||||
|
@ -22,18 +22,18 @@ ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
|
||||||
class TrainingConfig(BaseModel):
|
class TrainingConfig(BaseModel):
|
||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
dtype: str = "float32"
|
dtype: str = "float32"
|
||||||
duration: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
|
duration: TimeValue = Iteration(1) # TimeValue(number=1, unit=TimeUnit.ITERATION)
|
||||||
seed: int = 0
|
seed: int = 0
|
||||||
batch_size: int = 1
|
batch_size: int = 1
|
||||||
gradient_accumulation: TimeValue = TimeValue(number=1, unit=TimeUnit.STEP)
|
gradient_accumulation: Step | Epoch = Step(1)
|
||||||
evaluation_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
|
evaluation_interval: Iteration | Epoch = Iteration(1)
|
||||||
gradient_clipping_max_norm: float | None = None
|
gradient_clipping_max_norm: float | None = None
|
||||||
evaluation_seed: int = 0
|
evaluation_seed: int = 0
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
@field_validator("duration", "gradient_accumulation", "evaluation_interval", mode="before")
|
@field_validator("duration", "gradient_accumulation", "evaluation_interval", mode="before")
|
||||||
def parse_field(cls, value: Any) -> TimeValue:
|
def parse_field(cls, value: TimeValueInput) -> TimeValue:
|
||||||
return parse_number_unit_field(value)
|
return parse_number_unit_field(value)
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,8 +63,8 @@ class LRSchedulerType(str, Enum):
|
||||||
|
|
||||||
class LRSchedulerConfig(BaseModel):
|
class LRSchedulerConfig(BaseModel):
|
||||||
type: LRSchedulerType = LRSchedulerType.DEFAULT
|
type: LRSchedulerType = LRSchedulerType.DEFAULT
|
||||||
update_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
|
update_interval: Iteration | Epoch = Iteration(1)
|
||||||
warmup: TimeValue = TimeValue(number=0, unit=TimeUnit.ITERATION)
|
warmup: TimeValue = Iteration(0)
|
||||||
gamma: float = 0.1
|
gamma: float = 0.1
|
||||||
lr_lambda: Callable[[int], float] | None = None
|
lr_lambda: Callable[[int], float] | None = None
|
||||||
mode: Literal["min", "max"] = "min"
|
mode: Literal["min", "max"] = "min"
|
||||||
|
|
|
@ -22,5 +22,5 @@ learning_rate = 1
|
||||||
|
|
||||||
[lr_scheduler]
|
[lr_scheduler]
|
||||||
type = "ConstantLR"
|
type = "ConstantLR"
|
||||||
update_interval = "1:step"
|
update_interval = "1:iteration"
|
||||||
warmup = "20:step"
|
warmup = "20:iteration"
|
||||||
|
|
|
@ -23,4 +23,4 @@ learning_rate = 1
|
||||||
|
|
||||||
[lr_scheduler]
|
[lr_scheduler]
|
||||||
type = "ConstantLR"
|
type = "ConstantLR"
|
||||||
update_interval = "1:step"
|
update_interval = "1:iteration"
|
||||||
|
|
|
@ -3,21 +3,41 @@ import random
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field, scoped_seed
|
from refiners.training_utils.common import (
|
||||||
|
Epoch,
|
||||||
|
Iteration,
|
||||||
|
Step,
|
||||||
|
TimeValue,
|
||||||
|
TimeValueInput,
|
||||||
|
parse_number_unit_field,
|
||||||
|
scoped_seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"input_value, expected_output",
|
"input_value, expected_output",
|
||||||
[
|
[
|
||||||
("10: step", TimeValue(number=10, unit=TimeUnit.STEP)),
|
("3 : steP", Step(3)),
|
||||||
("20 :epoch", TimeValue(number=20, unit=TimeUnit.EPOCH)),
|
("5: epoch", Epoch(5)),
|
||||||
("30: Iteration", TimeValue(number=30, unit=TimeUnit.ITERATION)),
|
(" 7:Iteration", Iteration(7)),
|
||||||
(50, TimeValue(number=50, unit=TimeUnit.DEFAULT)),
|
|
||||||
({"number": 100, "unit": "STEP"}, TimeValue(number=100, unit=TimeUnit.STEP)),
|
|
||||||
(TimeValue(number=200, unit=TimeUnit.EPOCH), TimeValue(number=200, unit=TimeUnit.EPOCH)),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: TimeValue):
|
def test_time_value_from_str(input_value: str, expected_output: TimeValue) -> None:
|
||||||
|
result = TimeValue.from_str(input_value)
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_value, expected_output",
|
||||||
|
[
|
||||||
|
("10: step", Step(10)),
|
||||||
|
("20 :epoch", Epoch(20)),
|
||||||
|
("30: Iteration", Iteration(30)),
|
||||||
|
(50, Step(50)),
|
||||||
|
(Epoch(200), Epoch(200)),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: TimeValue) -> None:
|
||||||
result = parse_number_unit_field(input_value)
|
result = parse_number_unit_field(input_value)
|
||||||
assert result == expected_output
|
assert result == expected_output
|
||||||
|
|
||||||
|
@ -26,8 +46,8 @@ def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: T
|
||||||
"invalid_input",
|
"invalid_input",
|
||||||
[
|
[
|
||||||
"invalid:input",
|
"invalid:input",
|
||||||
{"number": "not_a_number", "unit": "step"},
|
"10: invalid",
|
||||||
{"invalid_key": 10},
|
"10",
|
||||||
None,
|
None,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,7 +10,13 @@ from torch.optim import SGD
|
||||||
|
|
||||||
from refiners.fluxion import layers as fl
|
from refiners.fluxion import layers as fl
|
||||||
from refiners.fluxion.utils import norm
|
from refiners.fluxion.utils import norm
|
||||||
from refiners.training_utils.common import TimeUnit, TimeValue, count_learnable_parameters, human_readable_number
|
from refiners.training_utils.common import (
|
||||||
|
Epoch,
|
||||||
|
Iteration,
|
||||||
|
Step,
|
||||||
|
count_learnable_parameters,
|
||||||
|
human_readable_number,
|
||||||
|
)
|
||||||
from refiners.training_utils.config import BaseConfig, ModelConfig
|
from refiners.training_utils.config import BaseConfig, ModelConfig
|
||||||
from refiners.training_utils.trainer import (
|
from refiners.training_utils.trainer import (
|
||||||
Trainer,
|
Trainer,
|
||||||
|
@ -96,7 +102,7 @@ def mock_trainer(mock_config: MockConfig) -> MockTrainer:
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_trainer_short(mock_config: MockConfig) -> MockTrainer:
|
def mock_trainer_short(mock_config: MockConfig) -> MockTrainer:
|
||||||
mock_config_short = mock_config.model_copy(deep=True)
|
mock_config_short = mock_config.model_copy(deep=True)
|
||||||
mock_config_short.training.duration = TimeValue(number=3, unit=TimeUnit.STEP)
|
mock_config_short.training.duration = Step(3)
|
||||||
return MockTrainer(config=mock_config_short)
|
return MockTrainer(config=mock_config_short)
|
||||||
|
|
||||||
|
|
||||||
|
@ -130,10 +136,10 @@ def training_clock() -> TrainingClock:
|
||||||
return TrainingClock(
|
return TrainingClock(
|
||||||
dataset_length=100,
|
dataset_length=100,
|
||||||
batch_size=10,
|
batch_size=10,
|
||||||
training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
|
training_duration=Epoch(5),
|
||||||
gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
gradient_accumulation=Epoch(1),
|
||||||
evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
evaluation_interval=Epoch(1),
|
||||||
lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
lr_scheduler_interval=Epoch(1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -142,10 +148,10 @@ def test_small_dataset_error():
|
||||||
TrainingClock(
|
TrainingClock(
|
||||||
dataset_length=3,
|
dataset_length=3,
|
||||||
batch_size=10,
|
batch_size=10,
|
||||||
training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
|
training_duration=Epoch(5),
|
||||||
gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
gradient_accumulation=Epoch(1),
|
||||||
evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
evaluation_interval=Epoch(1),
|
||||||
lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
lr_scheduler_interval=Epoch(1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -154,23 +160,25 @@ def test_zero_batch_size_error():
|
||||||
TrainingClock(
|
TrainingClock(
|
||||||
dataset_length=3,
|
dataset_length=3,
|
||||||
batch_size=0,
|
batch_size=0,
|
||||||
training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
|
training_duration=Epoch(5),
|
||||||
gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
gradient_accumulation=Epoch(1),
|
||||||
evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
evaluation_interval=Epoch(1),
|
||||||
lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
lr_scheduler_interval=Epoch(1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None:
|
def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None:
|
||||||
assert training_clock.convert_time_unit_to_steps(1, TimeUnit.EPOCH) == 10
|
assert training_clock.convert_time_value_to_steps(Epoch(1)) == 10
|
||||||
assert training_clock.convert_time_unit_to_steps(2, TimeUnit.EPOCH) == 20
|
assert training_clock.convert_time_value_to_steps(Epoch(2)) == 20
|
||||||
assert training_clock.convert_time_unit_to_steps(1, TimeUnit.STEP) == 1
|
assert training_clock.convert_time_value_to_steps(Step(1)) == 1
|
||||||
|
assert training_clock.convert_time_value_to_steps(Iteration(1)) == 10
|
||||||
|
|
||||||
|
|
||||||
def test_steps_to_time_unit_conversion(training_clock: TrainingClock) -> None:
|
def test_steps_to_time_unit_conversion(training_clock: TrainingClock) -> None:
|
||||||
assert training_clock.convert_steps_to_time_unit(10, TimeUnit.EPOCH) == 1
|
assert training_clock.convert_steps_to_time_unit(10, Epoch) == 1
|
||||||
assert training_clock.convert_steps_to_time_unit(20, TimeUnit.EPOCH) == 2
|
assert training_clock.convert_steps_to_time_unit(20, Epoch) == 2
|
||||||
assert training_clock.convert_steps_to_time_unit(1, TimeUnit.STEP) == 1
|
assert training_clock.convert_steps_to_time_unit(1, Step) == 1
|
||||||
|
assert training_clock.convert_steps_to_time_unit(10, Iteration) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_clock_properties(training_clock: TrainingClock) -> None:
|
def test_clock_properties(training_clock: TrainingClock) -> None:
|
||||||
|
|
Loading…
Reference in a new issue