This guide will walk you through training a model using Refiners. We built the `training_utils` module to provide a simple, flexible, statically type-safe interface.
We will use a small model and a toy dataset for demonstration purposes. The model is a simple [autoencoder](https://en.wikipedia.org/wiki/Autoencoder), and the dataset is a synthetic set of rectangles of different sizes.
## Pre-requisites
We recommend installing Refiners targeting a specific commit hash to avoid unexpected changes in the API. You also
get the benefit of having a perfectly reproducible environment.
We will now create a Trainer class to handle the training loop. This class will manage the model, the optimizer, the loss function, and the dataset. It will also orchestrate the training loop and the evaluation loop.
But first, we need to define two things: the batch type, which represents a batch for the forward and backward passes, and the configuration associated with the trainer.
### Batch
Our batches are composed of a single tensor representing the images. We will define a simple `Batch` type to implement this.
```python
from dataclasses import dataclass

import torch

@dataclass
class Batch:
    image: torch.Tensor  # the images of the batch, stacked in a single tensor
```
### Config
We will now define the configuration for the autoencoder. It holds the configuration for the training loop, the optimizer, and the learning rate scheduler. It should inherit from `refiners.training_utils.BaseConfig` and has the following mandatory attributes:
- `training` (`TrainingConfig`): The configuration for the training loop, including the duration of the training, the batch size, device, dtype, etc.
- `optimizer` (`OptimizerConfig`): The configuration for the optimizer, including the learning rate, weight decay, etc.
- `lr_scheduler` (`LRSchedulerConfig`): The configuration for the learning rate scheduler, including the scheduler type, parameters, etc.
Example:
```python
import torch

from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulers

class AutoencoderConfig(BaseConfig):
    # Since we are using a synthetic dataset, we will use an arbitrary fixed epoch size.
    epoch_size: int = 2048

training = TrainingConfig(
    duration="1000:epoch",
    batch_size=32,
    device="cuda" if torch.cuda.is_available() else "cpu",
    dtype="float32",
)

optimizer = OptimizerConfig(
    optimizer=Optimizers.AdamW,
    learning_rate=1e-4,
)

lr_scheduler = LRSchedulerConfig(
    type=LRSchedulers.ConstantLR,
)

config = AutoencoderConfig(
    training=training,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
)
```
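To get a feel for what these settings imply, the sketch below works out the length of the run. It is plain Python and not part of Refiners: the `parse_duration` helper is hypothetical, and the sketch assumes that a duration string has the form `"<number>:<unit>"` and that each epoch is split into batches of `batch_size` samples.

```python
import math


def parse_duration(duration: str) -> tuple[int, str]:
    """Split a '<number>:<unit>' string (hypothetical helper, not the Refiners parser)."""
    number, unit = duration.split(":")
    return int(number), unit


epoch_size = 2048  # samples per epoch, as in AutoencoderConfig
batch_size = 32  # as in TrainingConfig
num_epochs, unit = parse_duration("1000:epoch")

# Number of batches needed to see every sample of an epoch once.
batches_per_epoch = math.ceil(epoch_size / batch_size)
total_batches = num_epochs * batches_per_epoch

print(unit, batches_per_epoch, total_batches)
```

With the values above, each epoch is 64 batches, so the full run processes 64,000 batches.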
### Subclass
We can now define the Trainer subclass. It should inherit from `refiners.training_utils.Trainer` and implement the following methods:
- `get_item`: This method should take an index and return a `Batch`.
- `collate_fn`: This method should take a list of `Batch` objects and return a concatenated `Batch`.
- `dataset_length`: We implement this property to return the length of the dataset.
- `compute_loss`: This method should take a `Batch` and return the loss.
```python
from functools import cached_property

from refiners.training_utils import Trainer

class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
    # ... get_item, collate_fn and dataset_length here

    def compute_loss(self, batch: Batch) -> torch.Tensor:
        raise NotImplementedError("We'll implement this later")

trainer = AutoencoderTrainer(config)
```
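How these four pieces fit together can be sketched without Refiners or PyTorch. The loop below is an assumption about the general pattern (fetch items, collate them, compute a loss), not the actual `Trainer` implementation. `ToyBatch`, `ToyTrainer` and `run_epoch` are hypothetical names, and plain lists of floats stand in for tensors.

```python
from dataclasses import dataclass


@dataclass
class ToyBatch:
    # A list of flat "images"; stands in for a tensor with a leading batch dimension.
    image: list[list[float]]


class ToyTrainer:
    """Hypothetical stand-in showing the role of each method, not the Refiners Trainer."""

    epoch_size = 8
    batch_size = 4

    def get_item(self, index: int) -> ToyBatch:
        # One sample wrapped as a batch of size 1 (a 4-pixel "image").
        return ToyBatch(image=[[float(index)] * 4])

    def collate_fn(self, batch: list[ToyBatch]) -> ToyBatch:
        # Concatenate the single-item batches along the batch dimension.
        return ToyBatch(image=[img for b in batch for img in b.image])

    @property
    def dataset_length(self) -> int:
        return self.epoch_size

    def compute_loss(self, batch: ToyBatch) -> float:
        # Dummy loss: mean squared pixel value over the batch.
        pixels = [p for img in batch.image for p in img]
        return sum(p * p for p in pixels) / len(pixels)

    def run_epoch(self) -> list[float]:
        losses = []
        for start in range(0, self.dataset_length, self.batch_size):
            stop = min(start + self.batch_size, self.dataset_length)
            items = [self.get_item(i) for i in range(start, stop)]
            losses.append(self.compute_loss(self.collate_fn(items)))
        return losses


losses = ToyTrainer().run_epoch()
print(losses)
```

The point of the split is that `get_item` only knows about a single sample, `collate_fn` only knows how to merge samples, and `compute_loss` only sees a ready-made batch.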
### Model registration
For the Trainer to be able to handle the model, we need to register it.
We need two things to do so:
- Add a `refiners.training_utils.ModelConfig` attribute named `autoencoder` to the Config.
- Add a method to the Trainer subclass, decorated with `@register_model`, that returns the model. This method should take the `ModelConfig` as an argument. The Trainer's `__init__` will register the models and add any parameters that have `requires_grad` enabled to the optimizer.
After registering the model, the `self.autoencoder` attribute will be available in the Trainer.
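The registration mechanism can be pictured with a small pure-Python sketch. This is an assumption about the general decorator pattern, not the Refiners implementation: `register`, `ToyTrainer`, `ToyModel` and the dictionary-based config are all hypothetical. The decorator tags a method, and `__init__` calls every tagged method with the config entry of the same name, then exposes the result as an attribute.

```python
from typing import Any, Callable


def register(method: Callable[..., Any]) -> Callable[..., Any]:
    """Mark a method as a model factory (hypothetical decorator)."""
    method._is_registered = True  # type: ignore[attr-defined]
    return method


class ToyModel:
    def __init__(self, hidden_size: int) -> None:
        self.hidden_size = hidden_size


class ToyTrainer:
    def __init__(self, config: dict[str, dict[str, int]]) -> None:
        self.models: dict[str, Any] = {}
        # Find every method tagged by @register on the class.
        for name, attr in list(vars(type(self)).items()):
            if getattr(attr, "_is_registered", False):
                # Build the model from the config entry with the same name,
                # then shadow the factory method with the built instance.
                model = attr(self, config[name])
                self.models[name] = model
                setattr(self, name, model)

    @register
    def autoencoder(self, config: dict[str, int]) -> ToyModel:
        return ToyModel(hidden_size=config["hidden_size"])


trainer = ToyTrainer(config={"autoencoder": {"hidden_size": 16}})
print(type(trainer.autoencoder).__name__, trainer.autoencoder.hidden_size)
```

The sketch leaves out the optimizer step: it only shows why the method name, the config attribute name and the resulting `self.autoencoder` attribute all match.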
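```python
from refiners.training_utils import ModelConfig, register_model

class AutoencoderModelConfig(ModelConfig):
    pass

class AutoencoderConfig(BaseConfig):
    epoch_size: int = 2048
    autoencoder: AutoencoderModelConfig

class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
    # ... other methods here

    @register_model()
    def autoencoder(self, config: AutoencoderModelConfig) -> Autoencoder:
        return Autoencoder()  # assumes an `Autoencoder` model class defined elsewhere
```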
Let's write a simple logging callback to log the loss and the reconstructed images during training. A callback is a class that inherits from `refiners.training_utils.Callback` and implements any of the following methods:
- `on_init_begin`
- `on_init_end`
- `on_train_begin`
- `on_train_end`
- `on_epoch_begin`
- `on_epoch_end`
- `on_batch_begin`
- `on_batch_end`
- `on_backward_begin`
- `on_backward_end`
- `on_optimizer_step_begin`
- `on_optimizer_step_end`
- `on_compute_loss_begin`
- `on_compute_loss_end`
- `on_evaluate_begin`
- `on_evaluate_end`
- `on_lr_scheduler_step_begin`
- `on_lr_scheduler_step_end`
We will implement the `on_epoch_end` method to log the loss and the reconstructed images and the `on_compute_loss_end` method to store the loss in a list.
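The interplay between these two hooks can be illustrated with a pure-Python sketch. It does not use the Refiners `Callback` class: `ToyCallback`, `ToyTrainer` and the way hooks receive the trainer are assumptions made for illustration, and image logging is left out. The callback collects one loss per batch and summarizes them at the end of each epoch.

```python
class ToyCallback:
    """Hypothetical base class: every hook is a no-op unless overridden."""

    def on_compute_loss_end(self, trainer: "ToyTrainer") -> None: ...

    def on_epoch_end(self, trainer: "ToyTrainer") -> None: ...


class LoggingCallback(ToyCallback):
    def __init__(self) -> None:
        self.losses: list[float] = []
        self.log: list[str] = []

    def on_compute_loss_end(self, trainer: "ToyTrainer") -> None:
        # Store the latest loss so it can be averaged at the end of the epoch.
        self.losses.append(trainer.loss)

    def on_epoch_end(self, trainer: "ToyTrainer") -> None:
        mean_loss = sum(self.losses) / len(self.losses)
        self.log.append(f"epoch {trainer.epoch}: mean loss {mean_loss:.2f}")
        self.losses = []  # start the next epoch with an empty list


class ToyTrainer:
    def __init__(self, callbacks: list[ToyCallback]) -> None:
        self.callbacks = callbacks
        self.epoch = 0
        self.loss = 0.0

    def _call(self, hook: str) -> None:
        # Dispatch a hook by name to every callback.
        for callback in self.callbacks:
            getattr(callback, hook)(self)

    def train(self, losses_per_epoch: list[list[float]]) -> None:
        for epoch, batch_losses in enumerate(losses_per_epoch):
            self.epoch = epoch
            for loss in batch_losses:  # pretend each value is a freshly computed loss
                self.loss = loss
                self._call("on_compute_loss_end")
            self._call("on_epoch_end")


logger = LoggingCallback()
ToyTrainer(callbacks=[logger]).train([[4.0, 2.0], [1.0, 1.0]])
print(logger.log)
```

Running it produces one log line per epoch, with mean losses of 3.00 and 1.00 respectively.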
Exactly like models, callbacks need to be registered with the Trainer. We can do so by adding a `CallbackConfig` attribute named `logging` to the Config, and adding a method to the Trainer class, decorated with `@register_callback`, that returns the callback.
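```python
from refiners.training_utils import CallbackConfig, register_callback

class AutoencoderConfig(BaseConfig):
    epoch_size: int = 2048
    logging: CallbackConfig = CallbackConfig()

class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
    @register_callback()
    def logging(self, config: CallbackConfig) -> LoggingCallback:
        return LoggingCallback()  # assumes the `LoggingCallback` class described above
```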