implement data_iterable (bis)

Authored by limiteinductive on 2024-04-24 17:14:32 +00:00, committed by Benjamin Trom
parent de8334b6fc
commit d6c225a112
6 changed files with 76 additions and 97 deletions

View file

@@ -20,14 +20,11 @@ class ClockConfig(CallbackConfig):
class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
def __init__(
self,
batch_size: int,
training_duration: TimeValue,
gradient_accumulation: Step,
lr_scheduler_interval: TimeValue,
verbose: bool = True,
) -> None:
assert batch_size > 0, "Batch size must be greater than 0."
self.batch_size = batch_size
self.training_duration = training_duration
self.gradient_accumulation = gradient_accumulation
self.lr_scheduler_interval = lr_scheduler_interval
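With batch_size gone from the clock, constructing a TrainingClock now only takes duration-related arguments. A small sketch (the import path for Epoch and Step is assumed, since it is not shown in this diff):

from refiners.training_utils.common import Epoch, Step  # assumed import path
from refiners.training_utils.trainer import TrainingClock

clock = TrainingClock(
    training_duration=Epoch(5),
    gradient_accumulation=Step(1),
    lr_scheduler_interval=Epoch(1),
)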

View file

@@ -22,7 +22,7 @@ ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
class TrainingConfig(BaseModel):
device: str = "cpu"
dtype: str = "float32"
duration: TimeValue = Iteration(1) # TimeValue(number=1, unit=TimeUnit.ITERATION)
duration: TimeValue = Iteration(1)
seed: int = 0
batch_size: int = 1
gradient_accumulation: Step = Step(1)
@@ -144,17 +144,6 @@ class OptimizerConfig(BaseModel):
)
class DataloaderConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
num_workers: int = 0
pin_memory: bool = False
prefetch_factor: int | None = None
persistent_workers: bool = False
drop_last: bool = False
shuffle: bool = True
class ModelConfig(BaseModel):
# If None, then requires_grad will NOT be changed when loading the model
# this can be useful if you want to train only a part of the model
@@ -176,7 +165,6 @@ class BaseConfig(BaseModel):
optimizer: OptimizerConfig
lr_scheduler: LRSchedulerConfig
clock: ClockConfig = ClockConfig()
dataloader: DataloaderConfig = DataloaderConfig()
model_config = ConfigDict(extra="forbid")

View file

@@ -0,0 +1,55 @@
from typing import Callable, TypeVar

from pydantic import BaseModel, ConfigDict, PositiveInt
from torch.utils.data import DataLoader, Dataset

BatchT = TypeVar("BatchT")


class DataloaderConfig(BaseModel):
    batch_size: PositiveInt = 1
    num_workers: int = 0
    pin_memory: bool = False
    prefetch_factor: int | None = None
    persistent_workers: bool = False
    drop_last: bool = False
    shuffle: bool = True

    model_config = ConfigDict(extra="forbid")

    # TODO: Add more validation to the config


class DatasetFromCallable(Dataset[BatchT]):
    """
    A wrapper around the `get_item` method to create a [`torch.utils.data.Dataset`][torch.utils.data.Dataset].
    """

    def __init__(self, get_item: Callable[[int], BatchT], length: int) -> None:
        assert length > 0, "Dataset length must be greater than 0."
        self.length = length
        self.get_item = get_item

    def __getitem__(self, index: int) -> BatchT:
        return self.get_item(index)

    def __len__(self) -> int:
        return self.length


def create_data_loader(
    get_item: Callable[[int], BatchT],
    length: int,
    config: DataloaderConfig,
    collate_fn: Callable[[list[BatchT]], BatchT] | None = None,
) -> DataLoader[BatchT]:
    return DataLoader(
        DatasetFromCallable(get_item, length),
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        prefetch_factor=config.prefetch_factor,
        persistent_workers=config.persistent_workers,
        drop_last=config.drop_last,
        shuffle=config.shuffle,
        collate_fn=collate_fn,
    )
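For reference, a minimal usage sketch of this new helper (the import path is the one used in the test file below; the toy get_item and config values are illustrative, not part of the commit):

from refiners.training_utils.data_loader import DataloaderConfig, create_data_loader

# Toy item getter: each item is simply twice its index.
def get_item(index: int) -> int:
    return index * 2

loader = create_data_loader(
    get_item=get_item,
    length=10,
    config=DataloaderConfig(batch_size=2, shuffle=False),
)

for batch in loader:
    # With the default collate_fn, each batch is a tensor of 2 items.
    print(batch)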

View file

@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import cached_property, wraps
from typing import Any, Callable, Generic, Literal, TypeVar, cast
from typing import Any, Callable, Generic, Iterable, Literal, TypeVar, cast
import torch
from loguru import logger
@@ -21,7 +21,6 @@ from torch.optim.lr_scheduler import (
ReduceLROnPlateau,
StepLR,
)
from torch.utils.data import DataLoader, Dataset
from refiners.fluxion import layers as fl
from refiners.training_utils.callback import (
@@ -64,23 +63,6 @@ Batch = TypeVar("Batch")
ConfigType = TypeVar("ConfigType", bound=BaseConfig)
class _Dataset(Dataset[Batch]):
"""
A wrapper around the `get_item` method to create a [`torch.utils.data.Dataset`][torch.utils.data.Dataset].
"""
def __init__(self, get_item: Callable[[int], Batch], length: int) -> None:
assert length > 0, "Dataset length must be greater than 0."
self.length = length
self.get_item = get_item
def __getitem__(self, index: int) -> Batch:
return self.get_item(index)
def __len__(self) -> int:
return self.length
@dataclass
class ModelItem:
name: str
@@ -151,7 +133,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
@register_callback()
def clock(self, config: ClockConfig) -> TrainingClock:
return TrainingClock(
batch_size=self.config.training.batch_size,
training_duration=self.config.training.duration,
gradient_accumulation=self.config.training.gradient_accumulation,
lr_scheduler_interval=self.config.lr_scheduler.update_interval,
@@ -294,58 +275,14 @@ class Trainer(Generic[ConfigType, Batch], ABC):
return lr_scheduler
@abstractmethod
def get_item(self, index: int) -> Batch:
"""
Returns a batch of data.
def compute_loss(self, batch: Batch) -> Tensor: ...
This function is used by the dataloader to fetch a batch of data.
"""
...
@abstractmethod
def create_data_iterable(self) -> Iterable[Batch]: ...
@property
@abstractmethod
def dataset_length(self) -> int:
"""
Returns the length of the dataset.
This is used to compute the number of batches per epoch.
"""
...
@abstractmethod
def collate_fn(self, batch: list[Batch]) -> Batch:
"""
Collate function for the dataloader.
This function is used to tell the dataloader how to combine a list of
batches into a single batch.
"""
...
@cached_property
def dataset(self) -> Dataset[Batch]:
"""
Returns the dataset constructed with the `get_item` method.
"""
return _Dataset(get_item=self.get_item, length=self.dataset_length)
@cached_property
def dataloader(self) -> DataLoader[Any]:
config = self.config.dataloader
return DataLoader(
dataset=self.dataset,
batch_size=self.config.training.batch_size,
collate_fn=self.collate_fn,
num_workers=config.num_workers,
prefetch_factor=config.prefetch_factor,
persistent_workers=config.persistent_workers,
pin_memory=config.pin_memory,
shuffle=config.shuffle,
drop_last=config.drop_last,
)
@abstractmethod
def compute_loss(self, batch: Batch) -> Tensor: ...
def data_iterable(self) -> Iterable[Batch]:
return self.create_data_iterable()
def backward(self) -> None:
"""Backward pass on the loss."""
@@ -375,7 +312,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
def epoch(self) -> None:
"""Perform a single epoch."""
for batch in self.dataloader:
for batch in self.data_iterable:
if self.clock.done:
break
self._call_callbacks(event_name="on_step_begin")
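The net effect in this file: the trainer no longer builds a DataLoader itself; subclasses implement create_data_iterable and may return any Iterable[Batch] (a DataLoader built via create_data_loader being one option, as the test below shows). A rough sketch under that reading, with hypothetical names and the rest of the required trainer setup (model registration, config wiring) omitted:

import torch
from torch import Tensor

from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer

class ToyConfig(BaseConfig):  # hypothetical config subclass
    pass

class ToyTrainer(Trainer[ToyConfig, Tensor]):  # other hooks omitted for brevity
    def create_data_iterable(self) -> list[Tensor]:
        # Any iterable of batches works; no torch DataLoader is required.
        return [torch.randn(4, 3) for _ in range(8)]

    def compute_loss(self, batch: Tensor) -> Tensor:
        return batch.pow(2).mean()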

View file

@@ -20,6 +20,9 @@ batch_size = 4
gradient_accumulation = "4:step"
gradient_clipping_max_norm = 1.0
[data_loader]
batch_size = 4
[optimizer]
optimizer = "SGD"
learning_rate = 1

View file

@@ -26,6 +26,7 @@ from refiners.training_utils.common import (
scoped_seed,
)
from refiners.training_utils.config import BaseConfig, ModelConfig
from refiners.training_utils.data_loader import DataloaderConfig, create_data_loader
from refiners.training_utils.trainer import (
Trainer,
TrainingClock,
@@ -64,6 +65,7 @@ class MockConfig(BaseConfig):
mock_model: MockModelConfig
mock_callback: MockCallbackConfig
data_loader: DataloaderConfig
class MockModel(fl.Chain):
@@ -134,6 +136,14 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
targets=torch.cat([b.targets for b in batch]),
)
def create_data_iterable(self):
return create_data_loader(
get_item=self.get_item,
length=self.dataset_length,
config=self.config.data_loader,
collate_fn=self.collate_fn,
)
@register_callback()
def early_callback(self, config: CallbackConfig) -> EarlyMockCallback:
return EarlyMockCallback()
@@ -204,17 +214,6 @@ def test_human_readable_number() -> None:
@pytest.fixture
def training_clock() -> TrainingClock:
return TrainingClock(
batch_size=10,
training_duration=Epoch(5),
gradient_accumulation=Step(1),
lr_scheduler_interval=Epoch(1),
)
def test_zero_batch_size_error():
with pytest.raises(AssertionError):
TrainingClock(
batch_size=0,
training_duration=Epoch(5),
gradient_accumulation=Step(1),
lr_scheduler_interval=Epoch(1),