update Training 101

This commit is contained in:
limiteinductive 2024-04-29 08:04:10 +00:00 committed by Benjamin Trom
parent 0bec9a855d
commit 76a6ce8641

View file

@ -222,12 +222,10 @@ Example:
from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulerType, Epoch
class AutoencoderConfig(BaseConfig):
# Since we are using a synthetic dataset, we will use a arbitrary fixed epoch size.
epoch_size: int = 2048
...
training = TrainingConfig(
duration=Epoch(1000),
batch_size=32,
device="cuda" if torch.cuda.is_available() else "cpu",
dtype="float32"
)
@ -252,30 +250,31 @@ config = AutoencoderConfig(
We can now define the Trainer subclass. It should inherit from `refiners.training_utils.Trainer` and implement the following methods:
- `get_item`: This method should take an index and return a Batch.
- `collate_fn`: This method should take a list of Batch and return a concatenated Batch.
- `dataset_length`: We implement this property to return the length of the dataset.
- `compute_loss`: This method should take a Batch and return the loss.
- `create_data_iterable`: The `Trainer` will call this method to create and cache the data iterable. During training, the loop will pull batches from this iterable and pass them to the `compute_loss` method. Every time the iterable is exhausted, an epoch ends.
- `compute_loss`: This method should take a Batch and return the loss tensor.
Here is a simple implementation of the `create_data_iterable` method. For this toy example, we will generate a simple list of `Batch` objects containing random masks. Later you can replace this with `torch.utils.data.DataLoader` or any other data loader with more complex features that support shuffling, parallel loading, etc.
```python
from functools import cached_property
from refiners.training_utils import Trainer
class AutoencoderConfig(BaseConfig):
num_images: int = 2048
batch_size: int = 32
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
@cached_property
def image_generator(self) -> Generator[torch.Tensor, None, None]:
return generate_mask(size=64)
def create_data_iterable(self) -> list[Batch]:
dataset: list[Batch] = []
generator = generate_mask(size=64)
def get_item(self, index: int) -> Batch:
return Batch(image=next(self.image_generator).to(self.device, self.dtype))
def collate_fn(self, batch: list[Batch]) -> Batch:
return Batch(image=torch.cat([b.image for b in batch]))
@property
def dataset_length(self) -> int:
return self.config.epoch_size
for _ in range(self.config.num_images // self.config.batch_size):
masks = [next(generator) for _ in range(self.config.batch_size)]
dataset.append(Batch(image=torch.cat(masks, dim=0)))
return dataset
def compute_loss(self, batch: Batch) -> torch.Tensor:
raise NotImplementedError("We'll implement this later")
@ -304,16 +303,20 @@ class AutoencoderModelConfig(ModelConfig):
class AutoencoderConfig(BaseConfig):
epoch_size: int = 2048
num_images: int = 2048
batch_size: int = 32
autoencoder: AutoencoderModelConfig
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
# ... other methods
@register_model()
def autoencoder(self, config: AutoencoderModelConfig) -> Autoencoder:
return Autoencoder()
def compute_loss(self, batch: Batch) -> torch.Tensor:
batch.image = batch.image.to(self.device, self.dtype)
x_reconstructed = self.autoencoder.decoder(
self.autoencoder.encoder(batch.image)
)
@ -329,58 +332,6 @@ trainer.train()
![alt text](terminal-logging.png)
## Evaluation
We can also evaluate the model using the `compute_evaluation` method.
```python
training = TrainingConfig(
duration=Epoch(1000)
batch_size=32,
device="cuda" if torch.cuda.is_available() else "cpu",
dtype="float32",
evaluation_interval=Epoch(50),
)
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
# ... other methods
def compute_evaluation(self) -> None:
generator = generate_mask(size=64, seed=0)
grid: list[tuple[Image.Image, Image.Image]] = []
for _ in range(4):
mask = next(generator).to(self.device, self.dtype)
x_reconstructed = self.autoencoder.decoder(
self.autoencoder.encoder(mask)
)
loss = F.mse_loss(x_reconstructed, mask)
logger.info(f"Validation loss: {loss.detach().cpu().item()}")
grid.append(
(tensor_to_image(mask), tensor_to_image((x_reconstructed>0.5).float()))
)
import matplotlib.pyplot as plt
_, axes = plt.subplots(4, 2, figsize=(8, 16))
for i, (mask, reconstructed) in enumerate(grid):
axes[i, 0].imshow(mask, cmap='gray')
axes[i, 0].axis('off')
axes[i, 0].set_title('Mask')
axes[i, 1].imshow(reconstructed, cmap='gray')
axes[i, 1].axis('off')
axes[i, 1].set_title('Reconstructed')
plt.tight_layout()
plt.savefig(f"result_{trainer.clock.epoch}.png")
plt.close()
```
![alt text](evaluation.png)
## Logging
Let's write a simple logging callback to log the loss and the reconstructed images during training. A callback is a class that inherits from `refiners.training_utils.Callback` and implement any of the following methods:
@ -391,8 +342,8 @@ Let's write a simple logging callback to log the loss and the reconstructed imag
- `on_train_end`
- `on_epoch_begin`
- `on_epoch_end`
- `on_batch_begin`
- `on_batch_end`
- `on_step_begin`
- `on_step_end`
- `on_backward_begin`
- `on_backward_end`
- `on_optimizer_step_begin`
@ -430,7 +381,7 @@ Exactly like models, we need to register the callback to the Trainer. We can do
from refiners.training_utils import CallbackConfig, register_callback
class AutoencoderConfig(BaseConfig):
epoch_size: int = 2048
# ... other properties
logging: CallbackConfig = CallbackConfig()
@ -444,6 +395,101 @@ class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
![alt text](loss-logging.png)
## Evaluation
Let's add an evaluation step to the Trainer. We will generate a few masks and their reconstructions and save them to a file. We start by implementing a `compute_evaluation` method, then we register a callback to call this method at regular intervals.
```python
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
# ... other methods
def compute_evaluation(self) -> None:
generator = generate_mask(size=64, seed=0)
grid: list[tuple[Image.Image, Image.Image]] = []
for _ in range(4):
mask = next(generator).to(self.device, self.dtype)
x_reconstructed = self.autoencoder.decoder(
self.autoencoder.encoder(mask)
)
loss = F.mse_loss(x_reconstructed, mask)
logger.info(f"Validation loss: {loss.detach().cpu().item()}")
grid.append(
(tensor_to_image(mask), tensor_to_image((x_reconstructed>0.5).float()))
)
import matplotlib.pyplot as plt
_, axes = plt.subplots(4, 2, figsize=(8, 16))
for i, (mask, reconstructed) in enumerate(grid):
axes[i, 0].imshow(mask, cmap='gray')
axes[i, 0].axis('off')
axes[i, 0].set_title('Mask')
axes[i, 1].imshow(reconstructed, cmap='gray')
axes[i, 1].axis('off')
axes[i, 1].set_title('Reconstructed')
plt.tight_layout()
plt.savefig(f"result_{trainer.clock.epoch}.png")
plt.close()
```
We starting by implementing an `EvaluationConfig` that controls the evaluation interval and the seed for the random generator.
```python
from refiners.training_utils.config import TimeValueField
class EvaluationConfig(CallbackConfig):
interval: TimeValueField
seed: int
```
The `TimeValueField` is a custom field that allow Pydantic to parse a string representing a time value (e.g., "50:epochs") into a `TimeValue` object. This is useful to specify the evaluation interval in the configuration file.
```python
from refiners.training_utils import scoped_seed, Callback
class EvaluationCallback(Callback[Any]):
def __init__(self, config: EvaluationConfig) -> None:
self.config = config
def on_epoch_end(self, trainer: Trainer) -> None:
# The `is_due` method checks if the current epoch is a multiple of the interval.
if not trainer.clock.is_due(self.config.interval):
return
# The `scoped_seed` context manager encapsulates the random state for the evaluation and restores it after the
# evaluation.
with scoped_seed(self.config.seed):
trainer.compute_evaluation()
```
We can now register the callback to the Trainer.
```python
class AutoencoderConfig(BaseConfig):
# ... other properties
evaluation: EvaluationConfig
```
```python
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
# ... other methods
@register_callback()
def evaluation(self, config: EvaluationConfig) -> EvaluationCallback:
return EvaluationCallback(config)
```
We can now train the model and see the results in the `result_{epoch}.png` files.
![alt text](evaluation.png)
## Wrap up
You can train this toy model using the code below:
@ -453,7 +499,6 @@ You can train this toy model using the code below:
```py
import random
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Generator
import torch
@ -468,6 +513,7 @@ You can train this toy model using the code below:
Callback,
CallbackConfig,
ClockConfig,
Epoch,
LRSchedulerConfig,
LRSchedulerType,
ModelConfig,
@ -477,8 +523,10 @@ You can train this toy model using the code below:
TrainingConfig,
register_callback,
register_model,
Epoch,
)
from refiners.training_utils.common import scoped_seed
from refiners.training_utils.config import TimeValueField
class ConvBlock(fl.Chain):
def __init__(self, in_channels: int, out_channels: int) -> None:
@ -615,9 +663,31 @@ You can train this toy model using the code below:
self.losses = []
class EvaluationConfig(CallbackConfig):
interval: TimeValueField
seed: int
class EvaluationCallback(Callback["AutoencoderTrainer"]):
def __init__(self, config: EvaluationConfig) -> None:
self.config = config
def on_epoch_end(self, trainer: "AutoencoderTrainer") -> None:
# The `is_due` method checks if the current epoch is a multiple of the interval.
if not trainer.clock.is_due(self.config.interval):
return
# The `scoped_seed` context manager encapsulates the random state for the evaluation and restores it after the
# evaluation.
with scoped_seed(self.config.seed):
trainer.compute_evaluation()
class AutoencoderConfig(BaseConfig):
epoch_size: int = 2048
num_images: int = 2048
batch_size: int = 32
autoencoder: AutoencoderModelConfig
evaluation: EvaluationConfig
logging: CallbackConfig = CallbackConfig()
@ -626,11 +696,9 @@ You can train this toy model using the code below:
)
training = TrainingConfig(
duration=Epoch(1000),
batch_size=32,
duration=Epoch(200),
device="cuda" if torch.cuda.is_available() else "cpu",
dtype="float32",
evaluation_interval=Epoch(50),
)
optimizer = OptimizerConfig(
@ -645,30 +713,28 @@ You can train this toy model using the code below:
optimizer=optimizer,
lr_scheduler=lr_scheduler,
autoencoder=autoencoder_config,
evaluation=EvaluationConfig(interval=Epoch(50), seed=0),
clock=ClockConfig(verbose=False), # to disable the default clock logging
)
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
@cached_property
def image_generator(self) -> Generator[torch.Tensor, None, None]:
return generate_mask(size=64)
def create_data_iterable(self) -> list[Batch]:
dataset: list[Batch] = []
generator = generate_mask(size=64)
def get_item(self, index: int) -> Batch:
return Batch(image=next(self.image_generator).to(self.device, self.dtype))
for _ in range(self.config.num_images // self.config.batch_size):
masks = [next(generator).to(self.device, self.dtype) for _ in range(self.config.batch_size)]
dataset.append(Batch(image=torch.cat(masks, dim=0)))
def collate_fn(self, batch: list[Batch]) -> Batch:
return Batch(image=torch.cat([b.image for b in batch]))
@property
def dataset_length(self) -> int:
return self.config.epoch_size
return dataset
@register_model()
def autoencoder(self, config: AutoencoderModelConfig) -> Autoencoder:
return Autoencoder()
def compute_loss(self, batch: Batch) -> torch.Tensor:
batch.image = batch.image.to(self.device, self.dtype)
x_reconstructed = self.autoencoder.decoder(self.autoencoder.encoder(batch.image))
return F.binary_cross_entropy(x_reconstructed, batch.image)
@ -700,9 +766,13 @@ You can train this toy model using the code below:
axes[i, 1].axis("off")
axes[i, 1].set_title("Reconstructed")
plt.tight_layout() # type: ignore
plt.tight_layout() # type: ignore
plt.savefig(f"result_{trainer.clock.epoch}.png") # type: ignore
plt.close() # type: ignore
plt.close() # type: ignore
@register_callback()
def evaluation(self, config: EvaluationConfig) -> EvaluationCallback:
return EvaluationCallback(config)
@register_callback()
def logging(self, config: CallbackConfig) -> LoggingCallback: