diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py
index 91214ee..11df387 100644
--- a/src/refiners/training_utils/clock.py
+++ b/src/refiners/training_utils/clock.py
@@ -20,14 +20,11 @@ class ClockConfig(CallbackConfig):
 class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
     def __init__(
         self,
-        batch_size: int,
         training_duration: TimeValue,
         gradient_accumulation: Step,
         lr_scheduler_interval: TimeValue,
         verbose: bool = True,
     ) -> None:
-        assert batch_size > 0, "Batch size must be greater than 0."
-        self.batch_size = batch_size
         self.training_duration = training_duration
         self.gradient_accumulation = gradient_accumulation
         self.lr_scheduler_interval = lr_scheduler_interval
diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py
index a6220c5..f482968 100644
--- a/src/refiners/training_utils/config.py
+++ b/src/refiners/training_utils/config.py
@@ -22,7 +22,7 @@ ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
 class TrainingConfig(BaseModel):
     device: str = "cpu"
     dtype: str = "float32"
-    duration: TimeValue = Iteration(1)  # TimeValue(number=1, unit=TimeUnit.ITERATION)
+    duration: TimeValue = Iteration(1)
     seed: int = 0
     batch_size: int = 1
     gradient_accumulation: Step = Step(1)
@@ -144,17 +144,6 @@ class OptimizerConfig(BaseModel):
     )
 
 
-class DataloaderConfig(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-
-    num_workers: int = 0
-    pin_memory: bool = False
-    prefetch_factor: int | None = None
-    persistent_workers: bool = False
-    drop_last: bool = False
-    shuffle: bool = True
-
-
 class ModelConfig(BaseModel):
     # If None, then requires_grad will NOT be changed when loading the model
     # this can be useful if you want to train only a part of the model
@@ -176,7 +165,6 @@ class BaseConfig(BaseModel):
     optimizer: OptimizerConfig
     lr_scheduler: LRSchedulerConfig
     clock: ClockConfig = ClockConfig()
-    dataloader: DataloaderConfig = DataloaderConfig()
 
     model_config = ConfigDict(extra="forbid")
 
diff --git a/src/refiners/training_utils/data_loader.py b/src/refiners/training_utils/data_loader.py
new file mode 100644
index 0000000..9a3ac74
--- /dev/null
+++ b/src/refiners/training_utils/data_loader.py
@@ -0,0 +1,55 @@
+from typing import Callable, TypeVar
+
+from pydantic import BaseModel, ConfigDict, PositiveInt
+from torch.utils.data import DataLoader, Dataset
+
+BatchT = TypeVar("BatchT")
+
+
+class DataloaderConfig(BaseModel):
+    batch_size: PositiveInt = 1
+    num_workers: int = 0
+    pin_memory: bool = False
+    prefetch_factor: int | None = None
+    persistent_workers: bool = False
+    drop_last: bool = False
+    shuffle: bool = True
+
+    model_config = ConfigDict(extra="forbid")
+    # TODO: Add more validation to the config
+
+
+class DatasetFromCallable(Dataset[BatchT]):
+    """
+    A wrapper around the `get_item` method to create a [`torch.utils.data.Dataset`][torch.utils.data.Dataset].
+    """
+
+    def __init__(self, get_item: Callable[[int], BatchT], length: int) -> None:
+        assert length > 0, "Dataset length must be greater than 0."
+        self.length = length
+        self.get_item = get_item
+
+    def __getitem__(self, index: int) -> BatchT:
+        return self.get_item(index)
+
+    def __len__(self) -> int:
+        return self.length
+
+
+def create_data_loader(
+    get_item: Callable[[int], BatchT],
+    length: int,
+    config: DataloaderConfig,
+    collate_fn: Callable[[list[BatchT]], BatchT] | None = None,
+) -> DataLoader[BatchT]:
+    return DataLoader(
+        DatasetFromCallable(get_item, length),
+        batch_size=config.batch_size,
+        num_workers=config.num_workers,
+        pin_memory=config.pin_memory,
+        prefetch_factor=config.prefetch_factor,
+        persistent_workers=config.persistent_workers,
+        drop_last=config.drop_last,
+        shuffle=config.shuffle,
+        collate_fn=collate_fn,
+    )
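Usage note (not part of the diff): a minimal sketch of how the new `create_data_loader` helper can be used on its own, outside a `Trainer` subclass. The `get_square` function, the dataset length of 10, and the config values below are illustrative assumptions, not taken from this change.

```python
from refiners.training_utils.data_loader import DataloaderConfig, create_data_loader


def get_square(index: int) -> int:
    # `get_item` maps an index to a single sample; a toy integer sample here.
    return index * index


# Illustrative values; every other field keeps its default.
config = DataloaderConfig(batch_size=4, shuffle=False)

# With no explicit collate_fn, the default PyTorch collation stacks the
# integer samples into one tensor per batch.
loader = create_data_loader(get_item=get_square, length=10, config=config)

for batch in loader:
    print(batch)  # first batch: tensor([0, 1, 4, 9])
```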
diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py
index c8bca70..3c2afc2 100644
--- a/src/refiners/training_utils/trainer.py
+++ b/src/refiners/training_utils/trainer.py
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from functools import cached_property, wraps
-from typing import Any, Callable, Generic, Literal, TypeVar, cast
+from typing import Any, Callable, Generic, Iterable, Literal, TypeVar, cast
 
 import torch
 from loguru import logger
@@ -21,7 +21,6 @@ from torch.optim.lr_scheduler import (
     ReduceLROnPlateau,
     StepLR,
 )
-from torch.utils.data import DataLoader, Dataset
 
 from refiners.fluxion import layers as fl
 from refiners.training_utils.callback import (
@@ -64,23 +63,6 @@ Batch = TypeVar("Batch")
 ConfigType = TypeVar("ConfigType", bound=BaseConfig)
 
 
-class _Dataset(Dataset[Batch]):
-    """
-    A wrapper around the `get_item` method to create a [`torch.utils.data.Dataset`][torch.utils.data.Dataset].
-    """
-
-    def __init__(self, get_item: Callable[[int], Batch], length: int) -> None:
-        assert length > 0, "Dataset length must be greater than 0."
-        self.length = length
-        self.get_item = get_item
-
-    def __getitem__(self, index: int) -> Batch:
-        return self.get_item(index)
-
-    def __len__(self) -> int:
-        return self.length
-
-
 @dataclass
 class ModelItem:
     name: str
@@ -151,7 +133,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     @register_callback()
     def clock(self, config: ClockConfig) -> TrainingClock:
         return TrainingClock(
-            batch_size=self.config.training.batch_size,
             training_duration=self.config.training.duration,
             gradient_accumulation=self.config.training.gradient_accumulation,
             lr_scheduler_interval=self.config.lr_scheduler.update_interval,
@@ -294,58 +275,14 @@ class Trainer(Generic[ConfigType, Batch], ABC):
         return lr_scheduler
 
     @abstractmethod
-    def get_item(self, index: int) -> Batch:
-        """
-        Returns a batch of data.
+    def compute_loss(self, batch: Batch) -> Tensor: ...
 
-        This function is used by the dataloader to fetch a batch of data.
-        """
-        ...
+    @abstractmethod
+    def create_data_iterable(self) -> Iterable[Batch]: ...
 
     @property
-    @abstractmethod
-    def dataset_length(self) -> int:
-        """
-        Returns the length of the dataset.
-
-        This is used to compute the number of batches per epoch.
-        """
-        ...
-
-    @abstractmethod
-    def collate_fn(self, batch: list[Batch]) -> Batch:
-        """
-        Collate function for the dataloader.
-
-        This function is used to tell the dataloader how to combine a list of
-        batches into a single batch.
-        """
-        ...
-
-    @cached_property
-    def dataset(self) -> Dataset[Batch]:
-        """
-        Returns the dataset constructed with the `get_item` method.
-        """
-        return _Dataset(get_item=self.get_item, length=self.dataset_length)
-
-    @cached_property
-    def dataloader(self) -> DataLoader[Any]:
-        config = self.config.dataloader
-        return DataLoader(
-            dataset=self.dataset,
-            batch_size=self.config.training.batch_size,
-            collate_fn=self.collate_fn,
-            num_workers=config.num_workers,
-            prefetch_factor=config.prefetch_factor,
-            persistent_workers=config.persistent_workers,
-            pin_memory=config.pin_memory,
-            shuffle=config.shuffle,
-            drop_last=config.drop_last,
-        )
-
-    @abstractmethod
-    def compute_loss(self, batch: Batch) -> Tensor: ...
+    def data_iterable(self) -> Iterable[Batch]:
+        return self.create_data_iterable()
 
     def backward(self) -> None:
         """Backward pass on the loss."""
@@ -375,7 +312,7 @@
 
     def epoch(self) -> None:
         """Perform a single epoch."""
-        for batch in self.dataloader:
+        for batch in self.data_iterable:
             if self.clock.done:
                 break
             self._call_callbacks(event_name="on_step_begin")
diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml
index a147164..9a6f167 100644
--- a/tests/training_utils/mock_config.toml
+++ b/tests/training_utils/mock_config.toml
@@ -20,6 +20,9 @@ batch_size = 4
 gradient_accumulation = "4:step"
 gradient_clipping_max_norm = 1.0
 
+[data_loader]
+batch_size = 4
+
 [optimizer]
 optimizer = "SGD"
 learning_rate = 1
diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py
index c4cf612..56e9881 100644
--- a/tests/training_utils/test_trainer.py
+++ b/tests/training_utils/test_trainer.py
@@ -26,6 +26,7 @@ from refiners.training_utils.common import (
     scoped_seed,
 )
 from refiners.training_utils.config import BaseConfig, ModelConfig
+from refiners.training_utils.data_loader import DataloaderConfig, create_data_loader
 from refiners.training_utils.trainer import (
     Trainer,
     TrainingClock,
@@ -64,6 +65,7 @@ class MockConfig(BaseConfig):
 
     mock_model: MockModelConfig
     mock_callback: MockCallbackConfig
+    data_loader: DataloaderConfig
 
 
 class MockModel(fl.Chain):
@@ -134,6 +136,14 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
             targets=torch.cat([b.targets for b in batch]),
         )
 
+    def create_data_iterable(self):
+        return create_data_loader(
+            get_item=self.get_item,
+            length=self.dataset_length,
+            config=self.config.data_loader,
+            collate_fn=self.collate_fn,
+        )
+
     @register_callback()
     def early_callback(self, config: CallbackConfig) -> EarlyMockCallback:
         return EarlyMockCallback()
@@ -204,23 +214,12 @@ def test_human_readable_number() -> None:
 @pytest.fixture
 def training_clock() -> TrainingClock:
     return TrainingClock(
-        batch_size=10,
         training_duration=Epoch(5),
         gradient_accumulation=Step(1),
         lr_scheduler_interval=Epoch(1),
     )
 
 
-def test_zero_batch_size_error():
-    with pytest.raises(AssertionError):
-        TrainingClock(
-            batch_size=0,
-            training_duration=Epoch(5),
-            gradient_accumulation=Step(1),
-            lr_scheduler_interval=Epoch(1),
-        )
-
-
 def test_timer_functionality(training_clock: TrainingClock) -> None:
     training_clock.start_timer()
     assert training_clock.start_time is not None
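Usage note (not part of the diff): since `DataloaderConfig` no longer lives on `BaseConfig`, downstream configs declare the field themselves and the TOML section name follows that field name, as `MockConfig` and `mock_config.toml` do above with `data_loader`. A minimal sketch of that mapping, with an illustrative inline TOML snippet:

```python
import tomllib  # standard library on Python 3.11+

from refiners.training_utils.data_loader import DataloaderConfig

# Illustrative snippet; in practice this section sits in the training config file.
toml_text = """
[data_loader]
batch_size = 4
num_workers = 2
shuffle = true
"""

section = tomllib.loads(toml_text)["data_loader"]
config = DataloaderConfig(**section)
print(config.batch_size, config.num_workers)  # 4 2

# extra="forbid" on the model rejects unknown keys, so a misspelled key such as
# `workers = 2` would raise a pydantic ValidationError at load time.
```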