mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
update Training 101
This commit is contained in:
parent
0bec9a855d
commit
76a6ce8641
|
@ -222,12 +222,10 @@ Example:
|
|||
from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulerType, Epoch
|
||||
|
||||
class AutoencoderConfig(BaseConfig):
|
||||
# Since we are using a synthetic dataset, we will use a arbitrary fixed epoch size.
|
||||
epoch_size: int = 2048
|
||||
...
|
||||
|
||||
training = TrainingConfig(
|
||||
duration=Epoch(1000),
|
||||
batch_size=32,
|
||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||
dtype="float32"
|
||||
)
|
||||
|
@ -252,30 +250,31 @@ config = AutoencoderConfig(
|
|||
|
||||
We can now define the Trainer subclass. It should inherit from `refiners.training_utils.Trainer` and implement the following methods:
|
||||
|
||||
- `get_item`: This method should take an index and return a Batch.
|
||||
- `collate_fn`: This method should take a list of Batch and return a concatenated Batch.
|
||||
- `dataset_length`: We implement this property to return the length of the dataset.
|
||||
- `compute_loss`: This method should take a Batch and return the loss.
|
||||
- `create_data_iterable`: The `Trainer` will call this method to create and cache the data iterable. During training, the loop will pull batches from this iterable and pass them to the `compute_loss` method. Every time the iterable is exhausted, an epoch ends.
|
||||
- `compute_loss`: This method should take a Batch and return the loss tensor.
|
||||
|
||||
Here is a simple implementation of the `create_data_iterable` method. For this toy example, we will generate a simple list of `Batch` objects containing random masks. Later you can replace this with `torch.utils.data.DataLoader` or any other data loader with more complex features that support shuffling, parallel loading, etc.
|
||||
|
||||
```python
|
||||
from functools import cached_property
|
||||
from refiners.training_utils import Trainer
|
||||
|
||||
|
||||
class AutoencoderConfig(BaseConfig):
|
||||
num_images: int = 2048
|
||||
batch_size: int = 32
|
||||
|
||||
|
||||
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
||||
@cached_property
|
||||
def image_generator(self) -> Generator[torch.Tensor, None, None]:
|
||||
return generate_mask(size=64)
|
||||
def create_data_iterable(self) -> list[Batch]:
|
||||
dataset: list[Batch] = []
|
||||
generator = generate_mask(size=64)
|
||||
|
||||
def get_item(self, index: int) -> Batch:
|
||||
return Batch(image=next(self.image_generator).to(self.device, self.dtype))
|
||||
|
||||
def collate_fn(self, batch: list[Batch]) -> Batch:
|
||||
return Batch(image=torch.cat([b.image for b in batch]))
|
||||
|
||||
@property
|
||||
def dataset_length(self) -> int:
|
||||
return self.config.epoch_size
|
||||
for _ in range(self.config.num_images // self.config.batch_size):
|
||||
masks = [next(generator) for _ in range(self.config.batch_size)]
|
||||
dataset.append(Batch(image=torch.cat(masks, dim=0)))
|
||||
|
||||
return dataset
|
||||
|
||||
def compute_loss(self, batch: Batch) -> torch.Tensor:
|
||||
raise NotImplementedError("We'll implement this later")
|
||||
|
@ -304,16 +303,20 @@ class AutoencoderModelConfig(ModelConfig):
|
|||
|
||||
|
||||
class AutoencoderConfig(BaseConfig):
|
||||
epoch_size: int = 2048
|
||||
num_images: int = 2048
|
||||
batch_size: int = 32
|
||||
autoencoder: AutoencoderModelConfig
|
||||
|
||||
|
||||
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
||||
# ... other methods
|
||||
|
||||
@register_model()
|
||||
def autoencoder(self, config: AutoencoderModelConfig) -> Autoencoder:
|
||||
return Autoencoder()
|
||||
|
||||
def compute_loss(self, batch: Batch) -> torch.Tensor:
|
||||
batch.image = batch.image.to(self.device, self.dtype)
|
||||
x_reconstructed = self.autoencoder.decoder(
|
||||
self.autoencoder.encoder(batch.image)
|
||||
)
|
||||
|
@ -329,58 +332,6 @@ trainer.train()
|
|||
![alt text](terminal-logging.png)
|
||||
|
||||
|
||||
## Evaluation
|
||||
|
||||
We can also evaluate the model using the `compute_evaluation` method.
|
||||
|
||||
```python
|
||||
training = TrainingConfig(
|
||||
duration=Epoch(1000)
|
||||
batch_size=32,
|
||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||
dtype="float32",
|
||||
evaluation_interval=Epoch(50),
|
||||
)
|
||||
|
||||
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
||||
# ... other methods
|
||||
|
||||
def compute_evaluation(self) -> None:
|
||||
generator = generate_mask(size=64, seed=0)
|
||||
|
||||
grid: list[tuple[Image.Image, Image.Image]] = []
|
||||
for _ in range(4):
|
||||
mask = next(generator).to(self.device, self.dtype)
|
||||
x_reconstructed = self.autoencoder.decoder(
|
||||
self.autoencoder.encoder(mask)
|
||||
)
|
||||
loss = F.mse_loss(x_reconstructed, mask)
|
||||
logger.info(f"Validation loss: {loss.detach().cpu().item()}")
|
||||
grid.append(
|
||||
(tensor_to_image(mask), tensor_to_image((x_reconstructed>0.5).float()))
|
||||
)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
_, axes = plt.subplots(4, 2, figsize=(8, 16))
|
||||
|
||||
for i, (mask, reconstructed) in enumerate(grid):
|
||||
axes[i, 0].imshow(mask, cmap='gray')
|
||||
axes[i, 0].axis('off')
|
||||
axes[i, 0].set_title('Mask')
|
||||
|
||||
axes[i, 1].imshow(reconstructed, cmap='gray')
|
||||
axes[i, 1].axis('off')
|
||||
axes[i, 1].set_title('Reconstructed')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"result_{trainer.clock.epoch}.png")
|
||||
plt.close()
|
||||
```
|
||||
|
||||
![alt text](evaluation.png)
|
||||
|
||||
|
||||
## Logging
|
||||
|
||||
Let's write a simple logging callback to log the loss and the reconstructed images during training. A callback is a class that inherits from `refiners.training_utils.Callback` and implement any of the following methods:
|
||||
|
@ -391,8 +342,8 @@ Let's write a simple logging callback to log the loss and the reconstructed imag
|
|||
- `on_train_end`
|
||||
- `on_epoch_begin`
|
||||
- `on_epoch_end`
|
||||
- `on_batch_begin`
|
||||
- `on_batch_end`
|
||||
- `on_step_begin`
|
||||
- `on_step_end`
|
||||
- `on_backward_begin`
|
||||
- `on_backward_end`
|
||||
- `on_optimizer_step_begin`
|
||||
|
@ -430,7 +381,7 @@ Exactly like models, we need to register the callback to the Trainer. We can do
|
|||
from refiners.training_utils import CallbackConfig, register_callback
|
||||
|
||||
class AutoencoderConfig(BaseConfig):
|
||||
epoch_size: int = 2048
|
||||
# ... other properties
|
||||
logging: CallbackConfig = CallbackConfig()
|
||||
|
||||
|
||||
|
@ -444,6 +395,101 @@ class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
|||
|
||||
![alt text](loss-logging.png)
|
||||
|
||||
## Evaluation
|
||||
|
||||
Let's add an evaluation step to the Trainer. We will generate a few masks and their reconstructions and save them to a file. We start by implementing a `compute_evaluation` method, then we register a callback to call this method at regular intervals.
|
||||
|
||||
|
||||
```python
|
||||
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
||||
# ... other methods
|
||||
|
||||
def compute_evaluation(self) -> None:
|
||||
generator = generate_mask(size=64, seed=0)
|
||||
|
||||
grid: list[tuple[Image.Image, Image.Image]] = []
|
||||
for _ in range(4):
|
||||
mask = next(generator).to(self.device, self.dtype)
|
||||
x_reconstructed = self.autoencoder.decoder(
|
||||
self.autoencoder.encoder(mask)
|
||||
)
|
||||
loss = F.mse_loss(x_reconstructed, mask)
|
||||
logger.info(f"Validation loss: {loss.detach().cpu().item()}")
|
||||
grid.append(
|
||||
(tensor_to_image(mask), tensor_to_image((x_reconstructed>0.5).float()))
|
||||
)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
_, axes = plt.subplots(4, 2, figsize=(8, 16))
|
||||
|
||||
for i, (mask, reconstructed) in enumerate(grid):
|
||||
axes[i, 0].imshow(mask, cmap='gray')
|
||||
axes[i, 0].axis('off')
|
||||
axes[i, 0].set_title('Mask')
|
||||
|
||||
axes[i, 1].imshow(reconstructed, cmap='gray')
|
||||
axes[i, 1].axis('off')
|
||||
axes[i, 1].set_title('Reconstructed')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"result_{trainer.clock.epoch}.png")
|
||||
plt.close()
|
||||
```
|
||||
|
||||
We starting by implementing an `EvaluationConfig` that controls the evaluation interval and the seed for the random generator.
|
||||
|
||||
```python
|
||||
from refiners.training_utils.config import TimeValueField
|
||||
|
||||
class EvaluationConfig(CallbackConfig):
|
||||
interval: TimeValueField
|
||||
seed: int
|
||||
```
|
||||
|
||||
The `TimeValueField` is a custom field that allow Pydantic to parse a string representing a time value (e.g., "50:epochs") into a `TimeValue` object. This is useful to specify the evaluation interval in the configuration file.
|
||||
|
||||
```python
|
||||
from refiners.training_utils import scoped_seed, Callback
|
||||
|
||||
class EvaluationCallback(Callback[Any]):
|
||||
def __init__(self, config: EvaluationConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
def on_epoch_end(self, trainer: Trainer) -> None:
|
||||
# The `is_due` method checks if the current epoch is a multiple of the interval.
|
||||
if not trainer.clock.is_due(self.config.interval):
|
||||
return
|
||||
|
||||
# The `scoped_seed` context manager encapsulates the random state for the evaluation and restores it after the
|
||||
# evaluation.
|
||||
with scoped_seed(self.config.seed):
|
||||
trainer.compute_evaluation()
|
||||
```
|
||||
|
||||
We can now register the callback to the Trainer.
|
||||
|
||||
|
||||
```python
|
||||
class AutoencoderConfig(BaseConfig):
|
||||
# ... other properties
|
||||
evaluation: EvaluationConfig
|
||||
```
|
||||
|
||||
|
||||
```python
|
||||
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
||||
# ... other methods
|
||||
|
||||
@register_callback()
|
||||
def evaluation(self, config: EvaluationConfig) -> EvaluationCallback:
|
||||
return EvaluationCallback(config)
|
||||
```
|
||||
|
||||
We can now train the model and see the results in the `result_{epoch}.png` files.
|
||||
|
||||
![alt text](evaluation.png)
|
||||
|
||||
## Wrap up
|
||||
|
||||
You can train this toy model using the code below:
|
||||
|
@ -453,7 +499,6 @@ You can train this toy model using the code below:
|
|||
```py
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import Any, Generator
|
||||
|
||||
import torch
|
||||
|
@ -468,6 +513,7 @@ You can train this toy model using the code below:
|
|||
Callback,
|
||||
CallbackConfig,
|
||||
ClockConfig,
|
||||
Epoch,
|
||||
LRSchedulerConfig,
|
||||
LRSchedulerType,
|
||||
ModelConfig,
|
||||
|
@ -477,8 +523,10 @@ You can train this toy model using the code below:
|
|||
TrainingConfig,
|
||||
register_callback,
|
||||
register_model,
|
||||
Epoch,
|
||||
)
|
||||
from refiners.training_utils.common import scoped_seed
|
||||
from refiners.training_utils.config import TimeValueField
|
||||
|
||||
|
||||
class ConvBlock(fl.Chain):
|
||||
def __init__(self, in_channels: int, out_channels: int) -> None:
|
||||
|
@ -615,9 +663,31 @@ You can train this toy model using the code below:
|
|||
self.losses = []
|
||||
|
||||
|
||||
class EvaluationConfig(CallbackConfig):
|
||||
interval: TimeValueField
|
||||
seed: int
|
||||
|
||||
|
||||
class EvaluationCallback(Callback["AutoencoderTrainer"]):
|
||||
def __init__(self, config: EvaluationConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
def on_epoch_end(self, trainer: "AutoencoderTrainer") -> None:
|
||||
# The `is_due` method checks if the current epoch is a multiple of the interval.
|
||||
if not trainer.clock.is_due(self.config.interval):
|
||||
return
|
||||
|
||||
# The `scoped_seed` context manager encapsulates the random state for the evaluation and restores it after the
|
||||
# evaluation.
|
||||
with scoped_seed(self.config.seed):
|
||||
trainer.compute_evaluation()
|
||||
|
||||
|
||||
class AutoencoderConfig(BaseConfig):
|
||||
epoch_size: int = 2048
|
||||
num_images: int = 2048
|
||||
batch_size: int = 32
|
||||
autoencoder: AutoencoderModelConfig
|
||||
evaluation: EvaluationConfig
|
||||
logging: CallbackConfig = CallbackConfig()
|
||||
|
||||
|
||||
|
@ -626,11 +696,9 @@ You can train this toy model using the code below:
|
|||
)
|
||||
|
||||
training = TrainingConfig(
|
||||
duration=Epoch(1000),
|
||||
batch_size=32,
|
||||
duration=Epoch(200),
|
||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||
dtype="float32",
|
||||
evaluation_interval=Epoch(50),
|
||||
)
|
||||
|
||||
optimizer = OptimizerConfig(
|
||||
|
@ -645,30 +713,28 @@ You can train this toy model using the code below:
|
|||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
autoencoder=autoencoder_config,
|
||||
evaluation=EvaluationConfig(interval=Epoch(50), seed=0),
|
||||
clock=ClockConfig(verbose=False), # to disable the default clock logging
|
||||
)
|
||||
|
||||
|
||||
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
||||
@cached_property
|
||||
def image_generator(self) -> Generator[torch.Tensor, None, None]:
|
||||
return generate_mask(size=64)
|
||||
def create_data_iterable(self) -> list[Batch]:
|
||||
dataset: list[Batch] = []
|
||||
generator = generate_mask(size=64)
|
||||
|
||||
def get_item(self, index: int) -> Batch:
|
||||
return Batch(image=next(self.image_generator).to(self.device, self.dtype))
|
||||
for _ in range(self.config.num_images // self.config.batch_size):
|
||||
masks = [next(generator).to(self.device, self.dtype) for _ in range(self.config.batch_size)]
|
||||
dataset.append(Batch(image=torch.cat(masks, dim=0)))
|
||||
|
||||
def collate_fn(self, batch: list[Batch]) -> Batch:
|
||||
return Batch(image=torch.cat([b.image for b in batch]))
|
||||
|
||||
@property
|
||||
def dataset_length(self) -> int:
|
||||
return self.config.epoch_size
|
||||
return dataset
|
||||
|
||||
@register_model()
|
||||
def autoencoder(self, config: AutoencoderModelConfig) -> Autoencoder:
|
||||
return Autoencoder()
|
||||
|
||||
def compute_loss(self, batch: Batch) -> torch.Tensor:
|
||||
batch.image = batch.image.to(self.device, self.dtype)
|
||||
x_reconstructed = self.autoencoder.decoder(self.autoencoder.encoder(batch.image))
|
||||
return F.binary_cross_entropy(x_reconstructed, batch.image)
|
||||
|
||||
|
@ -700,9 +766,13 @@ You can train this toy model using the code below:
|
|||
axes[i, 1].axis("off")
|
||||
axes[i, 1].set_title("Reconstructed")
|
||||
|
||||
plt.tight_layout() # type: ignore
|
||||
plt.tight_layout() # type: ignore
|
||||
plt.savefig(f"result_{trainer.clock.epoch}.png") # type: ignore
|
||||
plt.close() # type: ignore
|
||||
plt.close() # type: ignore
|
||||
|
||||
@register_callback()
|
||||
def evaluation(self, config: EvaluationConfig) -> EvaluationCallback:
|
||||
return EvaluationCallback(config)
|
||||
|
||||
@register_callback()
|
||||
def logging(self, config: CallbackConfig) -> LoggingCallback:
|
||||
|
|
Loading…
Reference in a new issue