mirror of https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 22:58:45 +00:00
implement data_iterable (bis)
This commit is contained in:
parent de8334b6fc
commit d6c225a112
@@ -20,14 +20,11 @@ class ClockConfig(CallbackConfig):
 class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
     def __init__(
         self,
-        batch_size: int,
         training_duration: TimeValue,
         gradient_accumulation: Step,
         lr_scheduler_interval: TimeValue,
         verbose: bool = True,
     ) -> None:
-        assert batch_size > 0, "Batch size must be greater than 0."
-        self.batch_size = batch_size
         self.training_duration = training_duration
         self.gradient_accumulation = gradient_accumulation
         self.lr_scheduler_interval = lr_scheduler_interval
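For reference, a clock constructed under the new signature only takes time-related arguments; a minimal sketch (argument values are illustrative, and the import locations are assumed from the test suite touched later in this commit):

from refiners.training_utils.common import Epoch, Step
from refiners.training_utils.trainer import TrainingClock

# batch_size is no longer a clock concern; batching now lives entirely in the
# data iterable, so the clock only tracks durations and intervals.
clock = TrainingClock(
    training_duration=Epoch(5),
    gradient_accumulation=Step(1),
    lr_scheduler_interval=Epoch(1),
)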
@@ -22,7 +22,7 @@ ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
 class TrainingConfig(BaseModel):
     device: str = "cpu"
     dtype: str = "float32"
-    duration: TimeValue = Iteration(1)  # TimeValue(number=1, unit=TimeUnit.ITERATION)
+    duration: TimeValue = Iteration(1)
     seed: int = 0
     batch_size: int = 1
     gradient_accumulation: Step = Step(1)
@@ -144,17 +144,6 @@ class OptimizerConfig(BaseModel):
     )


-class DataloaderConfig(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-
-    num_workers: int = 0
-    pin_memory: bool = False
-    prefetch_factor: int | None = None
-    persistent_workers: bool = False
-    drop_last: bool = False
-    shuffle: bool = True
-
-
 class ModelConfig(BaseModel):
     # If None, then requires_grad will NOT be changed when loading the model
     # this can be useful if you want to train only a part of the model
@@ -176,7 +165,6 @@ class BaseConfig(BaseModel):
     optimizer: OptimizerConfig
     lr_scheduler: LRSchedulerConfig
     clock: ClockConfig = ClockConfig()
-    dataloader: DataloaderConfig = DataloaderConfig()

     model_config = ConfigDict(extra="forbid")
src/refiners/training_utils/data_loader.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+from typing import Callable, TypeVar
+
+from pydantic import BaseModel, ConfigDict, PositiveInt
+from torch.utils.data import DataLoader, Dataset
+
+BatchT = TypeVar("BatchT")
+
+
+class DataloaderConfig(BaseModel):
+    batch_size: PositiveInt = 1
+    num_workers: int = 0
+    pin_memory: bool = False
+    prefetch_factor: int | None = None
+    persistent_workers: bool = False
+    drop_last: bool = False
+    shuffle: bool = True
+
+    model_config = ConfigDict(extra="forbid")
+    # TODO: Add more validation to the config
+
+
+class DatasetFromCallable(Dataset[BatchT]):
+    """
+    A wrapper around the `get_item` method to create a [`torch.utils.data.Dataset`][torch.utils.data.Dataset].
+    """
+
+    def __init__(self, get_item: Callable[[int], BatchT], length: int) -> None:
+        assert length > 0, "Dataset length must be greater than 0."
+        self.length = length
+        self.get_item = get_item
+
+    def __getitem__(self, index: int) -> BatchT:
+        return self.get_item(index)
+
+    def __len__(self) -> int:
+        return self.length
+
+
+def create_data_loader(
+    get_item: Callable[[int], BatchT],
+    length: int,
+    config: DataloaderConfig,
+    collate_fn: Callable[[list[BatchT]], BatchT] | None = None,
+) -> DataLoader[BatchT]:
+    return DataLoader(
+        DatasetFromCallable(get_item, length),
+        batch_size=config.batch_size,
+        num_workers=config.num_workers,
+        pin_memory=config.pin_memory,
+        prefetch_factor=config.prefetch_factor,
+        persistent_workers=config.persistent_workers,
+        drop_last=config.drop_last,
+        shuffle=config.shuffle,
+        collate_fn=collate_fn,
+    )
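As a usage sketch (not part of the commit), `create_data_loader` only needs a `get_item` callable, a length, and a config; the tensor-per-index dataset below is purely illustrative:

import torch

from refiners.training_utils.data_loader import DataloaderConfig, create_data_loader

# Hypothetical get_item: one single-element tensor per index.
def get_item(index: int) -> torch.Tensor:
    return torch.tensor([index])

config = DataloaderConfig(batch_size=4, shuffle=False)

# DatasetFromCallable wraps get_item; the default collate_fn stacks the tensors.
loader = create_data_loader(get_item=get_item, length=8, config=config)

for batch in loader:
    print(batch.shape)  # torch.Size([4, 1])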
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from functools import cached_property, wraps
-from typing import Any, Callable, Generic, Literal, TypeVar, cast
+from typing import Any, Callable, Generic, Iterable, Literal, TypeVar, cast

 import torch
 from loguru import logger
@@ -21,7 +21,6 @@ from torch.optim.lr_scheduler import (
     ReduceLROnPlateau,
     StepLR,
 )
-from torch.utils.data import DataLoader, Dataset

 from refiners.fluxion import layers as fl
 from refiners.training_utils.callback import (
@@ -64,23 +63,6 @@ Batch = TypeVar("Batch")
 ConfigType = TypeVar("ConfigType", bound=BaseConfig)


-class _Dataset(Dataset[Batch]):
-    """
-    A wrapper around the `get_item` method to create a [`torch.utils.data.Dataset`][torch.utils.data.Dataset].
-    """
-
-    def __init__(self, get_item: Callable[[int], Batch], length: int) -> None:
-        assert length > 0, "Dataset length must be greater than 0."
-        self.length = length
-        self.get_item = get_item
-
-    def __getitem__(self, index: int) -> Batch:
-        return self.get_item(index)
-
-    def __len__(self) -> int:
-        return self.length
-
-
 @dataclass
 class ModelItem:
     name: str
@@ -151,7 +133,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     @register_callback()
     def clock(self, config: ClockConfig) -> TrainingClock:
         return TrainingClock(
-            batch_size=self.config.training.batch_size,
             training_duration=self.config.training.duration,
             gradient_accumulation=self.config.training.gradient_accumulation,
             lr_scheduler_interval=self.config.lr_scheduler.update_interval,
@@ -294,58 +275,14 @@ class Trainer(Generic[ConfigType, Batch], ABC):
         return lr_scheduler

     @abstractmethod
-    def get_item(self, index: int) -> Batch:
-        """
-        Returns a batch of data.
-
-        This function is used by the dataloader to fetch a batch of data.
-        """
-        ...
-
-    @property
-    @abstractmethod
-    def dataset_length(self) -> int:
-        """
-        Returns the length of the dataset.
-
-        This is used to compute the number of batches per epoch.
-        """
-        ...
-
-    @abstractmethod
-    def collate_fn(self, batch: list[Batch]) -> Batch:
-        """
-        Collate function for the dataloader.
-
-        This function is used to tell the dataloader how to combine a list of
-        batches into a single batch.
-        """
-        ...
-
-    @cached_property
-    def dataset(self) -> Dataset[Batch]:
-        """
-        Returns the dataset constructed with the `get_item` method.
-        """
-        return _Dataset(get_item=self.get_item, length=self.dataset_length)
-
-    @cached_property
-    def dataloader(self) -> DataLoader[Any]:
-        config = self.config.dataloader
-        return DataLoader(
-            dataset=self.dataset,
-            batch_size=self.config.training.batch_size,
-            collate_fn=self.collate_fn,
-            num_workers=config.num_workers,
-            prefetch_factor=config.prefetch_factor,
-            persistent_workers=config.persistent_workers,
-            pin_memory=config.pin_memory,
-            shuffle=config.shuffle,
-            drop_last=config.drop_last,
-        )
-
-    @abstractmethod
-    def compute_loss(self, batch: Batch) -> Tensor: ...
+    def compute_loss(self, batch: Batch) -> Tensor: ...
+
+    @abstractmethod
+    def create_data_iterable(self) -> Iterable[Batch]: ...
+
+    @property
+    def data_iterable(self) -> Iterable[Batch]:
+        return self.create_data_iterable()

     def backward(self) -> None:
         """Backward pass on the loss."""
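To make the new contract concrete, here is a minimal sketch (not from the commit) of a subclass that satisfies the two remaining abstract methods; `Batch`, `MyConfig`, the feature size, and the dataset length are hypothetical:

from typing import Iterable

import torch
from torch import Tensor

from refiners.training_utils.config import BaseConfig
from refiners.training_utils.data_loader import DataloaderConfig, create_data_loader
from refiners.training_utils.trainer import Trainer

# Hypothetical batch type: a single tensor of inputs.
Batch = Tensor


class MyConfig(BaseConfig):
    # Mirrors the test config below: dataloader settings now live in a dedicated field.
    data_loader: DataloaderConfig = DataloaderConfig(batch_size=4)


class MyTrainer(Trainer[MyConfig, Batch]):
    def get_item(self, index: int) -> Batch:
        # get_item/collate_fn are now plain helpers of the subclass,
        # not abstract methods of Trainer.
        return torch.randn(8)

    def collate_fn(self, batch: list[Batch]) -> Batch:
        return torch.stack(batch)

    def create_data_iterable(self) -> Iterable[Batch]:
        # The only data-related requirement left on Trainer: return any iterable
        # of batches. A DataLoader built with the new helper is one option.
        return create_data_loader(
            get_item=self.get_item,
            length=100,  # hypothetical dataset length
            config=self.config.data_loader,
            collate_fn=self.collate_fn,
        )

    def compute_loss(self, batch: Batch) -> Tensor:
        return batch.mean()  # placeholder loss

During `epoch()`, the trainer simply iterates `self.data_iterable`, which delegates to `create_data_iterable()`.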
@@ -375,7 +312,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):

     def epoch(self) -> None:
         """Perform a single epoch."""
-        for batch in self.dataloader:
+        for batch in self.data_iterable:
             if self.clock.done:
                 break
             self._call_callbacks(event_name="on_step_begin")
@@ -20,6 +20,9 @@ batch_size = 4
 gradient_accumulation = "4:step"
 gradient_clipping_max_norm = 1.0

+[data_loader]
+batch_size = 4
+
 [optimizer]
 optimizer = "SGD"
 learning_rate = 1
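The new `[data_loader]` table maps onto the `DataloaderConfig` model introduced above. A small sketch of that mapping using the standard-library `tomllib` (rather than whatever loader the test suite uses); the inline TOML string is illustrative:

import tomllib

from refiners.training_utils.data_loader import DataloaderConfig

toml_text = """
[data_loader]
batch_size = 4
"""

data = tomllib.loads(toml_text)

# extra="forbid" on the model means unknown keys under [data_loader] are rejected.
config = DataloaderConfig(**data["data_loader"])
assert config.batch_size == 4 and config.shuffle is True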
@@ -26,6 +26,7 @@ from refiners.training_utils.common import (
     scoped_seed,
 )
 from refiners.training_utils.config import BaseConfig, ModelConfig
+from refiners.training_utils.data_loader import DataloaderConfig, create_data_loader
 from refiners.training_utils.trainer import (
     Trainer,
     TrainingClock,
@@ -64,6 +65,7 @@ class MockConfig(BaseConfig):

     mock_model: MockModelConfig
     mock_callback: MockCallbackConfig
+    data_loader: DataloaderConfig


 class MockModel(fl.Chain):
@@ -134,6 +136,14 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
             targets=torch.cat([b.targets for b in batch]),
         )

+    def create_data_iterable(self):
+        return create_data_loader(
+            get_item=self.get_item,
+            length=self.dataset_length,
+            config=self.config.data_loader,
+            collate_fn=self.collate_fn,
+        )
+
     @register_callback()
     def early_callback(self, config: CallbackConfig) -> EarlyMockCallback:
         return EarlyMockCallback()
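Because the contract is now just `Iterable[Batch]`, an implementation does not have to go through torch's `DataLoader` at all. A hypothetical variant (not in the commit) that yields synthetic batches from a plain generator; the batch structure and sizes are made up for illustration:

from typing import Iterable, NamedTuple

import torch


class Batch(NamedTuple):
    # Hypothetical batch layout, loosely mirroring the inputs/targets of MockBatch.
    inputs: torch.Tensor
    targets: torch.Tensor


def synthetic_batches(batch_size: int, num_batches: int) -> Iterable[Batch]:
    # A plain generator is a perfectly valid Iterable[Batch]: no Dataset, no DataLoader.
    for _ in range(num_batches):
        yield Batch(
            inputs=torch.randn(batch_size, 10),
            targets=torch.randn(batch_size, 10),
        )


# A trainer could then implement the abstract method as:
#     def create_data_iterable(self) -> Iterable[Batch]:
#         return synthetic_batches(self.config.data_loader.batch_size, num_batches=10)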
@@ -204,23 +214,12 @@ def test_human_readable_number() -> None:
 @pytest.fixture
 def training_clock() -> TrainingClock:
     return TrainingClock(
-        batch_size=10,
         training_duration=Epoch(5),
         gradient_accumulation=Step(1),
         lr_scheduler_interval=Epoch(1),
     )


-def test_zero_batch_size_error():
-    with pytest.raises(AssertionError):
-        TrainingClock(
-            batch_size=0,
-            training_duration=Epoch(5),
-            gradient_accumulation=Step(1),
-            lr_scheduler_interval=Epoch(1),
-        )
-
-
 def test_timer_functionality(training_clock: TrainingClock) -> None:
     training_clock.start_timer()
     assert training_clock.start_time is not None