mirror of https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
implement data_iterable
This commit is contained in:
parent 05a63ef44e
commit 38bddc49bd
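Summary (inferred from the hunks below): the trainer's mandatory data loader becomes a generic data iterable. create_data_iterable is now the abstract hook that subclasses implement, data_iterable caches its result, batch_size moves out of TrainingConfig, DataloaderConfig is renamed DataLoaderConfig, and a new test covers training directly from a plain list of batches.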
@@ -24,7 +24,6 @@ class TrainingConfig(BaseModel):
     dtype: str = "float32"
     duration: TimeValue = Iteration(1)
     seed: int = 0
-    batch_size: int = 1
     gradient_accumulation: Step = Step(1)
     gradient_clipping_max_norm: float | None = None
 
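With this change, batch_size is no longer a training-level setting. Given the data_loader: DataLoaderConfig field on the test MockConfig further down, the batch size presumably now belongs to the data-loader configuration rather than TrainingConfig.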
@@ -39,7 +39,7 @@ class DatasetFromCallable(Dataset[BatchT]):
 def create_data_loader(
     get_item: Callable[[int], BatchT],
     length: int,
-    config: DataloaderConfig,
+    config: DataLoaderConfig,
     collate_fn: Callable[[list[BatchT]], BatchT] | None = None,
 ) -> DataLoader[BatchT]:
     return DataLoader(
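For context, a callable-backed dataset like the DatasetFromCallable named in this hunk's header can be sketched as follows. This is a hedged illustration of the get_item/length idea, not the library's code; CallableDataset and its parameters are invented names.

from typing import Callable

from torch.utils.data import DataLoader, Dataset


class CallableDataset(Dataset[int]):
    """Wraps a get_item callable and a length into a map-style Dataset."""

    def __init__(self, get_item: Callable[[int], int], length: int) -> None:
        self.get_item = get_item
        self.length = length

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int) -> int:
        return self.get_item(index)


# Default collation batches the ints into tensors:
# tensor([0, 1, 4, 9]) then tensor([16, 25, 36, 49]).
loader = DataLoader(CallableDataset(lambda i: i * i, length=8), batch_size=4)
for batch in loader:
    print(batch)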
@@ -280,7 +280,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     @abstractmethod
     def create_data_iterable(self) -> Iterable[Batch]: ...
 
-    @property
+    @cached_property
     def data_iterable(self) -> Iterable[Batch]:
         return self.create_data_iterable()
 
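The switch from @property to @cached_property means create_data_iterable runs once and its result is reused on every later access. A minimal self-contained sketch of the pattern (SketchTrainer and ListTrainer are illustrative names, not the diff's classes):

from abc import ABC, abstractmethod
from functools import cached_property
from typing import Iterable


class SketchTrainer(ABC):
    @abstractmethod
    def create_data_iterable(self) -> Iterable[int]: ...

    @cached_property
    def data_iterable(self) -> Iterable[int]:
        # Evaluated once, then memoized on the instance; a plain
        # @property would rebuild the iterable on every access.
        return self.create_data_iterable()


class ListTrainer(SketchTrainer):
    def create_data_iterable(self) -> list[int]:
        print("building data")  # printed a single time
        return [1, 2, 3]


trainer = ListTrainer()
assert trainer.data_iterable is trainer.data_iterable  # same cached object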
@@ -16,7 +16,6 @@ duration = "100:epoch"
 seed = 0
 device = "cpu"
 dtype = "float32"
-batch_size = 4
 gradient_accumulation = "4:step"
 gradient_clipping_max_norm = 1.0
 
@@ -11,7 +11,6 @@ verbose = false
 [training]
 duration = "100:epoch"
 seed = 0
-batch_size = 4
 gradient_accumulation = "4:step"
 gradient_clipping_max_norm = 1.0
 
@@ -26,7 +26,7 @@ from refiners.training_utils.common import (
     scoped_seed,
 )
 from refiners.training_utils.config import BaseConfig, ModelConfig
-from refiners.training_utils.data_loader import DataloaderConfig, create_data_loader
+from refiners.training_utils.data_loader import DataLoaderConfig, create_data_loader
 from refiners.training_utils.trainer import (
     Trainer,
     TrainingClock,
@@ -65,7 +65,7 @@ class MockConfig(BaseConfig):
 
     mock_model: MockModelConfig
     mock_callback: MockCallbackConfig
-    data_loader: DataloaderConfig
+    data_loader: DataLoaderConfig
 
 
 class MockModel(fl.Chain):
@@ -334,3 +334,15 @@ def mock_trainer_2_models(mock_config_2_models: MockConfig) -> MockTrainerWith2M
 def test_optimizer_parameters(mock_trainer_2_models: MockTrainerWith2Models) -> None:
     assert len(mock_trainer_2_models.optimizer.param_groups) == 2
     assert mock_trainer_2_models.optimizer.param_groups[0]["lr"] == 1e-5
+
+
+class MockTrainerNoDataLoader(MockTrainer):
+    def create_data_iterable(self) -> list[MockBatch]:  # type: ignore
+        return [MockBatch(inputs=torch.randn(4, 10), targets=torch.randn(4, 10)) for _ in range(5)]
+
+
+def test_trainer_no_data_loader(mock_config: MockConfig) -> None:
+    trainer = MockTrainerNoDataLoader(config=mock_config)
+    trainer.train()
+    assert trainer.step_counter == 500
+    assert trainer.clock.epoch == 100
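The assertions follow from the numbers above: the iterable yields 5 batches per epoch and the config trains for "100:epoch", so the trainer performs 5 × 100 = 500 steps and finishes at epoch 100.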