mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
update test_trainer.py
This commit is contained in:
parent
603c8abb1e
commit
1db0845db2
|
@ -13,6 +13,7 @@ from torch.optim import SGD
|
||||||
from refiners.fluxion import layers as fl
|
from refiners.fluxion import layers as fl
|
||||||
from refiners.fluxion.utils import norm
|
from refiners.fluxion.utils import norm
|
||||||
from refiners.training_utils.callback import Callback, CallbackConfig
|
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||||
|
from refiners.training_utils.clock import ClockConfig
|
||||||
from refiners.training_utils.common import (
|
from refiners.training_utils.common import (
|
||||||
Epoch,
|
Epoch,
|
||||||
Iteration,
|
Iteration,
|
||||||
|
@ -57,6 +58,10 @@ class MockCallbackConfig(CallbackConfig):
|
||||||
|
|
||||||
|
|
||||||
class MockConfig(BaseConfig):
|
class MockConfig(BaseConfig):
|
||||||
|
# we register the `early_callback` before the `clock` callback to test the callback ordering
|
||||||
|
early_callback: CallbackConfig = CallbackConfig()
|
||||||
|
clock: ClockConfig = ClockConfig()
|
||||||
|
|
||||||
mock_model: MockModelConfig
|
mock_model: MockModelConfig
|
||||||
mock_callback: MockCallbackConfig
|
mock_callback: MockCallbackConfig
|
||||||
|
|
||||||
|
@ -95,7 +100,7 @@ class MockCallback(Callback["MockTrainer"]):
|
||||||
if not trainer.clock.is_due(self.config.on_batch_end_interval):
|
if not trainer.clock.is_due(self.config.on_batch_end_interval):
|
||||||
return
|
return
|
||||||
|
|
||||||
# We verify that the callback is always called before the clock is updated
|
# We verify that this callback is always called before the clock is updated (see `_call_callbacks` in trainer.py)
|
||||||
assert trainer.clock.step // 3 <= self.step_end_count
|
assert trainer.clock.step // 3 <= self.step_end_count
|
||||||
|
|
||||||
self.step_end_count += 1
|
self.step_end_count += 1
|
||||||
|
@ -103,6 +108,15 @@ class MockCallback(Callback["MockTrainer"]):
|
||||||
self.step_end_random_int = random.randint(0, 100)
|
self.step_end_random_int = random.randint(0, 100)
|
||||||
|
|
||||||
|
|
||||||
|
class EarlyMockCallback(Callback["MockTrainer"]):
|
||||||
|
"""
|
||||||
|
A callback that will be registered before the Clock callback to test the callback ordering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def on_train_begin(self, trainer: "MockTrainer") -> None:
|
||||||
|
assert trainer.clock.start_time is not None, "Clock callback should have been called before this callback."
|
||||||
|
|
||||||
|
|
||||||
class MockTrainer(Trainer[MockConfig, MockBatch]):
|
class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||||
step_counter: int = 0
|
step_counter: int = 0
|
||||||
model_registration_counter: int = 0
|
model_registration_counter: int = 0
|
||||||
|
@ -120,6 +134,10 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||||
targets=torch.cat([b.targets for b in batch]),
|
targets=torch.cat([b.targets for b in batch]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@register_callback()
|
||||||
|
def early_callback(self, config: CallbackConfig) -> EarlyMockCallback:
|
||||||
|
return EarlyMockCallback()
|
||||||
|
|
||||||
@register_callback()
|
@register_callback()
|
||||||
def mock_callback(self, config: MockCallbackConfig) -> MockCallback:
|
def mock_callback(self, config: MockCallbackConfig) -> MockCallback:
|
||||||
return MockCallback(config)
|
return MockCallback(config)
|
||||||
|
@ -276,7 +294,7 @@ def test_callback_registration(mock_trainer: MockTrainer) -> None:
|
||||||
|
|
||||||
# Check that the callback skips every other iteration
|
# Check that the callback skips every other iteration
|
||||||
assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2
|
assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2
|
||||||
assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3
|
assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3 + 1
|
||||||
|
|
||||||
# Check that the random seed was set
|
# Check that the random seed was set
|
||||||
assert mock_trainer.mock_callback.optimizer_step_random_int == 93
|
assert mock_trainer.mock_callback.optimizer_step_random_int == 93
|
||||||
|
|
Loading…
Reference in a new issue