Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-23 22:58:45 +00:00)

Commit d6c225a112: implement data_iterable (bis)
Parent: de8334b6fc

@@ -20,14 +20,11 @@ class ClockConfig(CallbackConfig):
 class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
     def __init__(
         self,
-        batch_size: int,
         training_duration: TimeValue,
         gradient_accumulation: Step,
         lr_scheduler_interval: TimeValue,
         verbose: bool = True,
     ) -> None:
-        assert batch_size > 0, "Batch size must be greater than 0."
-        self.batch_size = batch_size
         self.training_duration = training_duration
         self.gradient_accumulation = gradient_accumulation
         self.lr_scheduler_interval = lr_scheduler_interval

@@ -22,7 +22,7 @@ ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
 class TrainingConfig(BaseModel):
     device: str = "cpu"
     dtype: str = "float32"
-    duration: TimeValue = Iteration(1)  # TimeValue(number=1, unit=TimeUnit.ITERATION)
+    duration: TimeValue = Iteration(1)
     seed: int = 0
     batch_size: int = 1
     gradient_accumulation: Step = Step(1)

@@ -144,17 +144,6 @@ class OptimizerConfig(BaseModel):
     )


-class DataloaderConfig(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-
-    num_workers: int = 0
-    pin_memory: bool = False
-    prefetch_factor: int | None = None
-    persistent_workers: bool = False
-    drop_last: bool = False
-    shuffle: bool = True
-
-
 class ModelConfig(BaseModel):
     # If None, then requires_grad will NOT be changed when loading the model
     # this can be useful if you want to train only a part of the model

@@ -176,7 +165,6 @@ class BaseConfig(BaseModel):
     optimizer: OptimizerConfig
     lr_scheduler: LRSchedulerConfig
     clock: ClockConfig = ClockConfig()
-    dataloader: DataloaderConfig = DataloaderConfig()

     model_config = ConfigDict(extra="forbid")

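Since `dataloader` is no longer a field of `BaseConfig`, downstream configs declare the data-loader section themselves, as the `MockConfig` change later in this commit shows. A minimal sketch, assuming a hypothetical downstream config named `MyConfig`:

# Sketch only: "MyConfig" is hypothetical and not part of this commit.
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.data_loader import DataloaderConfig


class MyConfig(BaseConfig):
    # Declare the data loader section explicitly; BaseConfig no longer provides it.
    data_loader: DataloaderConfig = DataloaderConfig()
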
src/refiners/training_utils/data_loader.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+from typing import Callable, TypeVar
+
+from pydantic import BaseModel, ConfigDict, PositiveInt
+from torch.utils.data import DataLoader, Dataset
+
+BatchT = TypeVar("BatchT")
+
+
+class DataloaderConfig(BaseModel):
+    batch_size: PositiveInt = 1
+    num_workers: int = 0
+    pin_memory: bool = False
+    prefetch_factor: int | None = None
+    persistent_workers: bool = False
+    drop_last: bool = False
+    shuffle: bool = True
+
+    model_config = ConfigDict(extra="forbid")
+    # TODO: Add more validation to the config
+
+
+class DatasetFromCallable(Dataset[BatchT]):
+    """
+    A wrapper around the `get_item` method to create a [`torch.utils.data.Dataset`][torch.utils.data.Dataset].
+    """
+
+    def __init__(self, get_item: Callable[[int], BatchT], length: int) -> None:
+        assert length > 0, "Dataset length must be greater than 0."
+        self.length = length
+        self.get_item = get_item
+
+    def __getitem__(self, index: int) -> BatchT:
+        return self.get_item(index)
+
+    def __len__(self) -> int:
+        return self.length
+
+
+def create_data_loader(
+    get_item: Callable[[int], BatchT],
+    length: int,
+    config: DataloaderConfig,
+    collate_fn: Callable[[list[BatchT]], BatchT] | None = None,
+) -> DataLoader[BatchT]:
+    return DataLoader(
+        DatasetFromCallable(get_item, length),
+        batch_size=config.batch_size,
+        num_workers=config.num_workers,
+        pin_memory=config.pin_memory,
+        prefetch_factor=config.prefetch_factor,
+        persistent_workers=config.persistent_workers,
+        drop_last=config.drop_last,
+        shuffle=config.shuffle,
+        collate_fn=collate_fn,
+    )

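For context, a minimal usage sketch of the new helper; the `ExampleBatch` type, the `get_item` body, and the `collate` function are hypothetical, only `DataloaderConfig` and `create_data_loader` come from this commit:

# Hypothetical usage sketch of create_data_loader (not part of this commit).
from dataclasses import dataclass

import torch

from refiners.training_utils.data_loader import DataloaderConfig, create_data_loader


@dataclass
class ExampleBatch:  # hypothetical batch type
    inputs: torch.Tensor
    targets: torch.Tensor


def get_item(index: int) -> ExampleBatch:
    # Build a single-sample batch; a real implementation would read from storage.
    return ExampleBatch(inputs=torch.randn(1, 8), targets=torch.randn(1, 1))


def collate(batches: list[ExampleBatch]) -> ExampleBatch:
    # Merge per-sample batches into one batch, mirroring the MockTrainer collate_fn below.
    return ExampleBatch(
        inputs=torch.cat([b.inputs for b in batches]),
        targets=torch.cat([b.targets for b in batches]),
    )


config = DataloaderConfig(batch_size=4, shuffle=True)
data_loader = create_data_loader(get_item=get_item, length=100, config=config, collate_fn=collate)

for batch in data_loader:
    pass  # a training step would consume `batch` here
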
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from functools import cached_property, wraps
-from typing import Any, Callable, Generic, Literal, TypeVar, cast
+from typing import Any, Callable, Generic, Iterable, Literal, TypeVar, cast

 import torch
 from loguru import logger

@@ -21,7 +21,6 @@ from torch.optim.lr_scheduler import (
     ReduceLROnPlateau,
     StepLR,
 )
-from torch.utils.data import DataLoader, Dataset

 from refiners.fluxion import layers as fl
 from refiners.training_utils.callback import (

@@ -64,23 +63,6 @@ Batch = TypeVar("Batch")
 ConfigType = TypeVar("ConfigType", bound=BaseConfig)


-class _Dataset(Dataset[Batch]):
-    """
-    A wrapper around the `get_item` method to create a [`torch.utils.data.Dataset`][torch.utils.data.Dataset].
-    """
-
-    def __init__(self, get_item: Callable[[int], Batch], length: int) -> None:
-        assert length > 0, "Dataset length must be greater than 0."
-        self.length = length
-        self.get_item = get_item
-
-    def __getitem__(self, index: int) -> Batch:
-        return self.get_item(index)
-
-    def __len__(self) -> int:
-        return self.length
-
-
 @dataclass
 class ModelItem:
     name: str

@@ -151,7 +133,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     @register_callback()
     def clock(self, config: ClockConfig) -> TrainingClock:
         return TrainingClock(
-            batch_size=self.config.training.batch_size,
             training_duration=self.config.training.duration,
             gradient_accumulation=self.config.training.gradient_accumulation,
             lr_scheduler_interval=self.config.lr_scheduler.update_interval,

@@ -294,58 +275,14 @@ class Trainer(Generic[ConfigType, Batch], ABC):
         return lr_scheduler

     @abstractmethod
-    def get_item(self, index: int) -> Batch:
-        """
-        Returns a batch of data.
-
-        This function is used by the dataloader to fetch a batch of data.
-        """
-        ...
+    def compute_loss(self, batch: Batch) -> Tensor: ...
+
+    @abstractmethod
+    def create_data_iterable(self) -> Iterable[Batch]: ...

     @property
-    @abstractmethod
-    def dataset_length(self) -> int:
-        """
-        Returns the length of the dataset.
-
-        This is used to compute the number of batches per epoch.
-        """
-        ...
-
-    @abstractmethod
-    def collate_fn(self, batch: list[Batch]) -> Batch:
-        """
-        Collate function for the dataloader.
-
-        This function is used to tell the dataloader how to combine a list of
-        batches into a single batch.
-        """
-        ...
-
-    @cached_property
-    def dataset(self) -> Dataset[Batch]:
-        """
-        Returns the dataset constructed with the `get_item` method.
-        """
-        return _Dataset(get_item=self.get_item, length=self.dataset_length)
-
-    @cached_property
-    def dataloader(self) -> DataLoader[Any]:
-        config = self.config.dataloader
-        return DataLoader(
-            dataset=self.dataset,
-            batch_size=self.config.training.batch_size,
-            collate_fn=self.collate_fn,
-            num_workers=config.num_workers,
-            prefetch_factor=config.prefetch_factor,
-            persistent_workers=config.persistent_workers,
-            pin_memory=config.pin_memory,
-            shuffle=config.shuffle,
-            drop_last=config.drop_last,
-        )
-
-    @abstractmethod
-    def compute_loss(self, batch: Batch) -> Tensor: ...
+    def data_iterable(self) -> Iterable[Batch]:
+        return self.create_data_iterable()

     def backward(self) -> None:
         """Backward pass on the loss."""

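With this change, the data pipeline is reduced to one abstract hook: `create_data_iterable` returns any `Iterable[Batch]` (a list, a generator, or a torch DataLoader built with the new `create_data_loader` helper), and `data_iterable` simply calls it. A minimal sketch of a subclass under that contract; the `ListTrainer` name and the toy loss are hypothetical, and actually running it would still need a full `BaseConfig` and registered models, which this hunk does not show:

# Sketch of the new Trainer contract (hypothetical subclass, not part of this commit).
from typing import Iterable

import torch
from torch import Tensor

from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer


class ListTrainer(Trainer[BaseConfig, Tensor]):
    def create_data_iterable(self) -> Iterable[Tensor]:
        # Any iterable of batches is accepted; it does not have to be a DataLoader.
        return [torch.randn(4, 8) for _ in range(10)]

    def compute_loss(self, batch: Tensor) -> Tensor:
        # Toy loss for illustration only.
        return batch.pow(2).mean()
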
@@ -375,7 +312,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):

     def epoch(self) -> None:
         """Perform a single epoch."""
-        for batch in self.dataloader:
+        for batch in self.data_iterable:
             if self.clock.done:
                 break
             self._call_callbacks(event_name="on_step_begin")

@@ -20,6 +20,9 @@ batch_size = 4
 gradient_accumulation = "4:step"
 gradient_clipping_max_norm = 1.0

+[data_loader]
+batch_size = 4
+
 [optimizer]
 optimizer = "SGD"
 learning_rate = 1

@@ -26,6 +26,7 @@ from refiners.training_utils.common import (
     scoped_seed,
 )
 from refiners.training_utils.config import BaseConfig, ModelConfig
+from refiners.training_utils.data_loader import DataloaderConfig, create_data_loader
 from refiners.training_utils.trainer import (
     Trainer,
     TrainingClock,

@@ -64,6 +65,7 @@ class MockConfig(BaseConfig):

     mock_model: MockModelConfig
     mock_callback: MockCallbackConfig
+    data_loader: DataloaderConfig


 class MockModel(fl.Chain):

@@ -134,6 +136,14 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
             targets=torch.cat([b.targets for b in batch]),
         )

+    def create_data_iterable(self):
+        return create_data_loader(
+            get_item=self.get_item,
+            length=self.dataset_length,
+            config=self.config.data_loader,
+            collate_fn=self.collate_fn,
+        )
+
     @register_callback()
     def early_callback(self, config: CallbackConfig) -> EarlyMockCallback:
         return EarlyMockCallback()

@@ -204,23 +214,12 @@ def test_human_readable_number() -> None:
 @pytest.fixture
 def training_clock() -> TrainingClock:
     return TrainingClock(
-        batch_size=10,
         training_duration=Epoch(5),
         gradient_accumulation=Step(1),
         lr_scheduler_interval=Epoch(1),
     )


-def test_zero_batch_size_error():
-    with pytest.raises(AssertionError):
-        TrainingClock(
-            batch_size=0,
-            training_duration=Epoch(5),
-            gradient_accumulation=Step(1),
-            lr_scheduler_interval=Epoch(1),
-        )
-
-
 def test_timer_functionality(training_clock: TrainingClock) -> None:
     training_clock.start_timer()
     assert training_clock.start_time is not None