mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
change param name of abstract get_item method
This commit is contained in:
parent
6d599d53fd
commit
41508e0865
|
@ -446,7 +446,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_item(self, _: int) -> Batch:
|
def get_item(self, index: int) -> Batch:
|
||||||
"""
|
"""
|
||||||
Returns a batch of data.
|
Returns a batch of data.
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,7 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||||
def dataset_length(self) -> int:
|
def dataset_length(self) -> int:
|
||||||
return 20
|
return 20
|
||||||
|
|
||||||
def get_item(self, _: int) -> MockBatch:
|
def get_item(self, index: int) -> MockBatch:
|
||||||
return MockBatch(inputs=torch.randn(1, 10), targets=torch.randn(1, 10))
|
return MockBatch(inputs=torch.randn(1, 10), targets=torch.randn(1, 10))
|
||||||
|
|
||||||
def collate_fn(self, batch: list[MockBatch]) -> MockBatch:
|
def collate_fn(self, batch: list[MockBatch]) -> MockBatch:
|
||||||
|
|
Loading…
Reference in a new issue