implement data_iterable (bis)

limiteinductive 2024-04-24 17:14:32 +00:00 committed by Benjamin Trom
parent de8334b6fc
commit d6c225a112
6 changed files with 76 additions and 97 deletions

View file

@@ -20,14 +20,11 @@ class ClockConfig(CallbackConfig):
 class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
     def __init__(
         self,
-        batch_size: int,
         training_duration: TimeValue,
         gradient_accumulation: Step,
         lr_scheduler_interval: TimeValue,
         verbose: bool = True,
     ) -> None:
-        assert batch_size > 0, "Batch size must be greater than 0."
-        self.batch_size = batch_size
         self.training_duration = training_duration
         self.gradient_accumulation = gradient_accumulation
         self.lr_scheduler_interval = lr_scheduler_interval

View file

@@ -22,7 +22,7 @@ ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
 class TrainingConfig(BaseModel):
     device: str = "cpu"
     dtype: str = "float32"
-    duration: TimeValue = Iteration(1)  # TimeValue(number=1, unit=TimeUnit.ITERATION)
+    duration: TimeValue = Iteration(1)
     seed: int = 0
     batch_size: int = 1
     gradient_accumulation: Step = Step(1)
@@ -144,17 +144,6 @@ class OptimizerConfig(BaseModel):
     )
 
 
-class DataloaderConfig(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-
-    num_workers: int = 0
-    pin_memory: bool = False
-    prefetch_factor: int | None = None
-    persistent_workers: bool = False
-    drop_last: bool = False
-    shuffle: bool = True
-
-
 class ModelConfig(BaseModel):
     # If None, then requires_grad will NOT be changed when loading the model
     # this can be useful if you want to train only a part of the model
@@ -176,7 +165,6 @@ class BaseConfig(BaseModel):
     optimizer: OptimizerConfig
     lr_scheduler: LRSchedulerConfig
     clock: ClockConfig = ClockConfig()
-    dataloader: DataloaderConfig = DataloaderConfig()
 
     model_config = ConfigDict(extra="forbid")

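Note: `DataloaderConfig` moves out of this module (see the new data_loader file below) and `BaseConfig` no longer ships a `dataloader` field, so a config that still wants the stock torch DataLoader declares the field itself, as `MockConfig` does in the tests. A minimal sketch, where `MyConfig` is a hypothetical user config:

from refiners.training_utils.config import BaseConfig
from refiners.training_utils.data_loader import DataloaderConfig


class MyConfig(BaseConfig):  # hypothetical; mirrors MockConfig in the tests below
    # opt in explicitly; the loader's batch_size is configured here (PositiveInt, default 1)
    data_loader: DataloaderConfig = DataloaderConfig(batch_size=4)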
View file

@@ -0,0 +1,55 @@
+from typing import Callable, TypeVar
+
+from pydantic import BaseModel, ConfigDict, PositiveInt
+from torch.utils.data import DataLoader, Dataset
+
+BatchT = TypeVar("BatchT")
+
+
+class DataloaderConfig(BaseModel):
+    batch_size: PositiveInt = 1
+    num_workers: int = 0
+    pin_memory: bool = False
+    prefetch_factor: int | None = None
+    persistent_workers: bool = False
+    drop_last: bool = False
+    shuffle: bool = True
+
+    model_config = ConfigDict(extra="forbid")
+    # TODO: Add more validation to the config
+
+
+class DatasetFromCallable(Dataset[BatchT]):
+    """
+    A wrapper around the `get_item` method to create a [`torch.utils.data.Dataset`][torch.utils.data.Dataset].
+    """
+
+    def __init__(self, get_item: Callable[[int], BatchT], length: int) -> None:
+        assert length > 0, "Dataset length must be greater than 0."
+        self.length = length
+        self.get_item = get_item
+
+    def __getitem__(self, index: int) -> BatchT:
+        return self.get_item(index)
+
+    def __len__(self) -> int:
+        return self.length
+
+
+def create_data_loader(
+    get_item: Callable[[int], BatchT],
+    length: int,
+    config: DataloaderConfig,
+    collate_fn: Callable[[list[BatchT]], BatchT] | None = None,
+) -> DataLoader[BatchT]:
+    return DataLoader(
+        DatasetFromCallable(get_item, length),
+        batch_size=config.batch_size,
+        num_workers=config.num_workers,
+        pin_memory=config.pin_memory,
+        prefetch_factor=config.prefetch_factor,
+        persistent_workers=config.persistent_workers,
+        drop_last=config.drop_last,
+        shuffle=config.shuffle,
+        collate_fn=collate_fn,
+    )

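For reference, `create_data_loader` can be exercised on its own; a minimal sketch with a toy `get_item` and `collate` (both illustrative, not part of the commit):

from refiners.training_utils.data_loader import DataloaderConfig, create_data_loader


def get_item(index: int) -> list[int]:
    # toy dataset: item i is the single-element list [i]
    return [index]


def collate(items: list[list[int]]) -> list[int]:
    # merge a list of items into one flat batch
    return [x for item in items for x in item]


config = DataloaderConfig(batch_size=2, shuffle=False)
loader = create_data_loader(get_item=get_item, length=10, config=config, collate_fn=collate)
for batch in loader:
    print(batch)  # [0, 1], then [2, 3], ...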
View file

@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from functools import cached_property, wraps
-from typing import Any, Callable, Generic, Literal, TypeVar, cast
+from typing import Any, Callable, Generic, Iterable, Literal, TypeVar, cast
 
 import torch
 from loguru import logger
@@ -21,7 +21,6 @@ from torch.optim.lr_scheduler import (
     ReduceLROnPlateau,
     StepLR,
 )
-from torch.utils.data import DataLoader, Dataset
 
 from refiners.fluxion import layers as fl
 from refiners.training_utils.callback import (
@@ -64,23 +63,6 @@ Batch = TypeVar("Batch")
 ConfigType = TypeVar("ConfigType", bound=BaseConfig)
 
 
-class _Dataset(Dataset[Batch]):
-    """
-    A wrapper around the `get_item` method to create a [`torch.utils.data.Dataset`][torch.utils.data.Dataset].
-    """
-
-    def __init__(self, get_item: Callable[[int], Batch], length: int) -> None:
-        assert length > 0, "Dataset length must be greater than 0."
-        self.length = length
-        self.get_item = get_item
-
-    def __getitem__(self, index: int) -> Batch:
-        return self.get_item(index)
-
-    def __len__(self) -> int:
-        return self.length
-
-
 @dataclass
 class ModelItem:
     name: str
@@ -151,7 +133,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     @register_callback()
     def clock(self, config: ClockConfig) -> TrainingClock:
         return TrainingClock(
-            batch_size=self.config.training.batch_size,
             training_duration=self.config.training.duration,
             gradient_accumulation=self.config.training.gradient_accumulation,
             lr_scheduler_interval=self.config.lr_scheduler.update_interval,
@@ -294,58 +275,14 @@ class Trainer(Generic[ConfigType, Batch], ABC):
         return lr_scheduler
 
     @abstractmethod
-    def get_item(self, index: int) -> Batch:
-        """
-        Returns a batch of data.
-
-        This function is used by the dataloader to fetch a batch of data.
-        """
-        ...
-
-    @property
-    @abstractmethod
-    def dataset_length(self) -> int:
-        """
-        Returns the length of the dataset.
-
-        This is used to compute the number of batches per epoch.
-        """
-        ...
-
-    @abstractmethod
-    def collate_fn(self, batch: list[Batch]) -> Batch:
-        """
-        Collate function for the dataloader.
-
-        This function is used to tell the dataloader how to combine a list of
-        batches into a single batch.
-        """
-        ...
-
-    @cached_property
-    def dataset(self) -> Dataset[Batch]:
-        """
-        Returns the dataset constructed with the `get_item` method.
-        """
-        return _Dataset(get_item=self.get_item, length=self.dataset_length)
-
-    @cached_property
-    def dataloader(self) -> DataLoader[Any]:
-        config = self.config.dataloader
-        return DataLoader(
-            dataset=self.dataset,
-            batch_size=self.config.training.batch_size,
-            collate_fn=self.collate_fn,
-            num_workers=config.num_workers,
-            prefetch_factor=config.prefetch_factor,
-            persistent_workers=config.persistent_workers,
-            pin_memory=config.pin_memory,
-            shuffle=config.shuffle,
-            drop_last=config.drop_last,
-        )
-
-    @abstractmethod
-    def compute_loss(self, batch: Batch) -> Tensor: ...
+    def compute_loss(self, batch: Batch) -> Tensor: ...
+
+    @abstractmethod
+    def create_data_iterable(self) -> Iterable[Batch]: ...
+
+    @property
+    def data_iterable(self) -> Iterable[Batch]:
+        return self.create_data_iterable()
 
     def backward(self) -> None:
         """Backward pass on the loss."""
@@ -375,7 +312,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     def epoch(self) -> None:
         """Perform a single epoch."""
-        for batch in self.dataloader:
+        for batch in self.data_iterable:
             if self.clock.done:
                 break
 
             self._call_callbacks(event_name="on_step_begin")

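The key change above: `Trainer` no longer builds a `DataLoader` itself. Subclasses implement `create_data_iterable` and may return any `Iterable[Batch]`, whether that is the `create_data_loader` helper, a plain list, or a generator. A standalone sketch of the contract (the batch type and data are illustrative, not from the commit):

from typing import Iterable, NamedTuple

import torch


class ToyBatch(NamedTuple):  # stand-in for a user-defined Batch type
    tokens: torch.Tensor


def create_data_iterable() -> Iterable[ToyBatch]:
    # in a Trainer subclass this would be a method; a plain list is a valid iterable
    data = torch.arange(12).reshape(6, 2)
    return [ToyBatch(tokens=row.unsqueeze(0)) for row in data]


# Trainer.epoch now simply does `for batch in self.data_iterable: ...`
for batch in create_data_iterable():
    print(batch.tokens.shape)  # torch.Size([1, 2])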
View file

@@ -20,6 +20,9 @@ batch_size = 4
 gradient_accumulation = "4:step"
 gradient_clipping_max_norm = 1.0
 
+[data_loader]
+batch_size = 4
+
 [optimizer]
 optimizer = "SGD"
 learning_rate = 1

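The new [data_loader] table is parsed into the `DataloaderConfig` introduced above; a sketch of the equivalent construction (unknown keys would be rejected since the model uses extra="forbid"):

from refiners.training_utils.data_loader import DataloaderConfig

config = DataloaderConfig(batch_size=4)
assert config.shuffle and config.num_workers == 0  # remaining fields keep their defaults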
View file

@@ -26,6 +26,7 @@ from refiners.training_utils.common import (
     scoped_seed,
 )
 from refiners.training_utils.config import BaseConfig, ModelConfig
+from refiners.training_utils.data_loader import DataloaderConfig, create_data_loader
 from refiners.training_utils.trainer import (
     Trainer,
     TrainingClock,
@@ -64,6 +65,7 @@ class MockConfig(BaseConfig):
     mock_model: MockModelConfig
     mock_callback: MockCallbackConfig
+    data_loader: DataloaderConfig
 
 
 class MockModel(fl.Chain):
@@ -134,6 +136,14 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
             targets=torch.cat([b.targets for b in batch]),
         )
 
+    def create_data_iterable(self):
+        return create_data_loader(
+            get_item=self.get_item,
+            length=self.dataset_length,
+            config=self.config.data_loader,
+            collate_fn=self.collate_fn,
+        )
+
     @register_callback()
     def early_callback(self, config: CallbackConfig) -> EarlyMockCallback:
         return EarlyMockCallback()
@@ -204,23 +214,12 @@ def test_human_readable_number() -> None:
 @pytest.fixture
 def training_clock() -> TrainingClock:
     return TrainingClock(
-        batch_size=10,
         training_duration=Epoch(5),
         gradient_accumulation=Step(1),
         lr_scheduler_interval=Epoch(1),
     )
 
 
-def test_zero_batch_size_error():
-    with pytest.raises(AssertionError):
-        TrainingClock(
-            batch_size=0,
-            training_duration=Epoch(5),
-            gradient_accumulation=Step(1),
-            lr_scheduler_interval=Epoch(1),
-        )
-
 
 def test_timer_functionality(training_clock: TrainingClock) -> None:
     training_clock.start_timer()
     assert training_clock.start_time is not None
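
Since a zero batch size is now rejected by `DataloaderConfig` (its `batch_size` is a `PositiveInt`) rather than asserted in `TrainingClock`, the dropped `test_zero_batch_size_error` could be replaced by a config-level check; a possible sketch, not part of this commit:

import pytest
from pydantic import ValidationError

from refiners.training_utils.data_loader import DataloaderConfig


def test_zero_batch_size_rejected() -> None:
    # PositiveInt makes pydantic refuse batch_size=0 at validation time
    with pytest.raises(ValidationError):
        DataloaderConfig(batch_size=0)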