mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add basic unit test for training_utils
This commit is contained in:
parent
dba9065229
commit
7f722029be
32
tests/training_utils/mock_config.toml
Normal file
32
tests/training_utils/mock_config.toml
Normal file
|
@ -0,0 +1,32 @@
|
|||
[models.mock_model]
|
||||
train = true
|
||||
|
||||
[training]
|
||||
duration = "100:epoch"
|
||||
seed = 0
|
||||
gpu_index = 0
|
||||
batch_size = 4
|
||||
gradient_accumulation = "4:step"
|
||||
clip_grad_norm = 1.0
|
||||
evaluation_interval = "5:epoch"
|
||||
evaluation_seed = 1
|
||||
|
||||
[optimizer]
|
||||
optimizer = "SGD"
|
||||
learning_rate = 1
|
||||
momentum = 0.9
|
||||
|
||||
[scheduler]
|
||||
scheduler_type = "ConstantLR"
|
||||
update_interval = "1:step"
|
||||
warmup = "20:step"
|
||||
|
||||
[dropout]
|
||||
dropout = 0.0
|
||||
|
||||
[checkpointing]
|
||||
save_interval = "10:epoch"
|
||||
|
||||
[wandb]
|
||||
mode = "disabled"
|
||||
project = "mock_project"
|
185
tests/training_utils/test_trainer.py
Normal file
185
tests/training_utils/test_trainer.py
Normal file
|
@ -0,0 +1,185 @@
|
|||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from warnings import warn
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from refiners.fluxion import layers as fl
|
||||
from refiners.fluxion.utils import norm
|
||||
from refiners.training_utils.config import BaseConfig, TimeUnit
|
||||
from refiners.training_utils.trainer import (
|
||||
Trainer,
|
||||
TrainingClock,
|
||||
count_learnable_parameters,
|
||||
human_readable_number,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockBatch:
|
||||
inputs: torch.Tensor
|
||||
targets: torch.Tensor
|
||||
|
||||
|
||||
class MockDataset(Dataset[MockBatch]):
|
||||
def __len__(self):
|
||||
return 20
|
||||
|
||||
def __getitem__(self, _: int) -> MockBatch:
|
||||
return MockBatch(inputs=torch.randn(1, 10), targets=torch.randn(1, 10))
|
||||
|
||||
def collate_fn(self, batch: list[MockBatch]) -> MockBatch:
|
||||
return MockBatch(
|
||||
inputs=torch.cat([b.inputs for b in batch]),
|
||||
targets=torch.cat([b.targets for b in batch]),
|
||||
)
|
||||
|
||||
|
||||
class MockConfig(BaseConfig):
|
||||
pass
|
||||
|
||||
|
||||
class MockModel(fl.Chain):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
fl.Linear(10, 10),
|
||||
fl.Linear(10, 10),
|
||||
fl.Linear(10, 10),
|
||||
)
|
||||
|
||||
|
||||
class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||
step_counter: int = 0
|
||||
|
||||
@cached_property
|
||||
def mock_model(self) -> MockModel:
|
||||
return MockModel()
|
||||
|
||||
def load_dataset(self) -> Dataset[MockBatch]:
|
||||
return MockDataset()
|
||||
|
||||
def load_models(self) -> dict[str, fl.Module]:
|
||||
return {"mock_model": self.mock_model}
|
||||
|
||||
def compute_loss(self, batch: MockBatch) -> Tensor:
|
||||
self.step_counter += 1
|
||||
inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
|
||||
outputs = self.mock_model(inputs)
|
||||
return norm(outputs - targets)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(test_device: torch.device) -> MockConfig:
|
||||
if not test_device.type == "cuda":
|
||||
warn("only running on CUDA, skipping")
|
||||
pytest.skip("Skipping test because test_device is not CUDA")
|
||||
config = MockConfig.load_from_toml(Path(__file__).parent / "mock_config.toml")
|
||||
config.training.gpu_index = test_device.index
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trainer(mock_config: MockConfig) -> MockTrainer:
|
||||
return MockTrainer(config=mock_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model() -> fl.Chain:
|
||||
return MockModel()
|
||||
|
||||
|
||||
def test_count_learnable_parameters_with_params() -> None:
|
||||
params = [
|
||||
nn.Parameter(torch.randn(2, 2), requires_grad=True),
|
||||
nn.Parameter(torch.randn(5), requires_grad=False),
|
||||
nn.Parameter(torch.randn(3, 3), requires_grad=True),
|
||||
]
|
||||
assert count_learnable_parameters(params) == 13
|
||||
|
||||
|
||||
def test_count_learnable_parameters_with_model(mock_model: fl.Chain) -> None:
|
||||
assert count_learnable_parameters(mock_model.parameters()) == 330
|
||||
|
||||
|
||||
def test_human_readable_number() -> None:
|
||||
assert human_readable_number(123) == "123.0"
|
||||
assert human_readable_number(1234) == "1.2K"
|
||||
assert human_readable_number(1234567) == "1.2M"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def training_clock() -> TrainingClock:
|
||||
return TrainingClock(
|
||||
dataset_length=100,
|
||||
batch_size=10,
|
||||
training_duration={"number": 5, "unit": TimeUnit.EPOCH},
|
||||
gradient_accumulation={"number": 1, "unit": TimeUnit.EPOCH},
|
||||
evaluation_interval={"number": 1, "unit": TimeUnit.EPOCH},
|
||||
lr_scheduler_interval={"number": 1, "unit": TimeUnit.EPOCH},
|
||||
checkpointing_save_interval={"number": 1, "unit": TimeUnit.EPOCH},
|
||||
)
|
||||
|
||||
|
||||
def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None:
|
||||
assert training_clock.convert_time_unit_to_steps(1, TimeUnit.EPOCH) == 10
|
||||
assert training_clock.convert_time_unit_to_steps(2, TimeUnit.EPOCH) == 20
|
||||
assert training_clock.convert_time_unit_to_steps(1, TimeUnit.STEP) == 1
|
||||
|
||||
|
||||
def test_steps_to_time_unit_conversion(training_clock: TrainingClock) -> None:
|
||||
assert training_clock.convert_steps_to_time_unit(10, TimeUnit.EPOCH) == 1
|
||||
assert training_clock.convert_steps_to_time_unit(20, TimeUnit.EPOCH) == 2
|
||||
assert training_clock.convert_steps_to_time_unit(1, TimeUnit.STEP) == 1
|
||||
|
||||
|
||||
def test_clock_properties(training_clock: TrainingClock) -> None:
|
||||
assert training_clock.num_batches_per_epoch == 10
|
||||
assert training_clock.num_epochs == 5
|
||||
assert training_clock.num_iterations == 5
|
||||
assert training_clock.num_steps == 50
|
||||
|
||||
|
||||
def test_timer_functionality(training_clock: TrainingClock) -> None:
|
||||
training_clock.start_timer()
|
||||
assert training_clock.start_time is not None
|
||||
training_clock.stop_timer()
|
||||
assert training_clock.end_time is not None
|
||||
assert training_clock.time_elapsed >= 0
|
||||
|
||||
|
||||
def test_state_based_properties(training_clock: TrainingClock) -> None:
|
||||
training_clock.step = 5 # Halfway through the first epoch
|
||||
assert not training_clock.is_evaluation_step # Assuming evaluation every epoch
|
||||
assert not training_clock.is_checkpointing_step
|
||||
training_clock.step = 10 # End of the first epoch
|
||||
assert training_clock.is_evaluation_step
|
||||
assert training_clock.is_checkpointing_step
|
||||
|
||||
|
||||
def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: MockTrainer) -> None:
|
||||
assert mock_trainer.config == mock_config
|
||||
assert isinstance(mock_trainer, MockTrainer)
|
||||
assert mock_trainer.optimizer is not None
|
||||
assert mock_trainer.lr_scheduler is not None
|
||||
|
||||
|
||||
def test_training_cycle(mock_trainer: MockTrainer) -> None:
|
||||
clock = mock_trainer.clock
|
||||
config = mock_trainer.config
|
||||
|
||||
assert clock.num_step_per_iteration == config.training.gradient_accumulation["number"]
|
||||
assert clock.num_batches_per_epoch == mock_trainer.dataset_length // config.training.batch_size
|
||||
|
||||
assert mock_trainer.step_counter == 0
|
||||
assert mock_trainer.clock.epoch == 0
|
||||
|
||||
mock_trainer.train()
|
||||
|
||||
assert clock.epoch == config.training.duration["number"]
|
||||
assert clock.step == config.training.duration["number"] * clock.num_batches_per_epoch
|
||||
|
||||
assert mock_trainer.step_counter == mock_trainer.clock.step
|
Loading…
Reference in a new issue