mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
update Training 101
This commit is contained in:
parent
0bec9a855d
commit
76a6ce8641
|
@ -222,12 +222,10 @@ Example:
|
||||||
from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulerType, Epoch
|
from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulerType, Epoch
|
||||||
|
|
||||||
class AutoencoderConfig(BaseConfig):
|
class AutoencoderConfig(BaseConfig):
|
||||||
# Since we are using a synthetic dataset, we will use a arbitrary fixed epoch size.
|
...
|
||||||
epoch_size: int = 2048
|
|
||||||
|
|
||||||
training = TrainingConfig(
|
training = TrainingConfig(
|
||||||
duration=Epoch(1000),
|
duration=Epoch(1000),
|
||||||
batch_size=32,
|
|
||||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
dtype="float32"
|
dtype="float32"
|
||||||
)
|
)
|
||||||
|
@ -252,30 +250,31 @@ config = AutoencoderConfig(
|
||||||
|
|
||||||
We can now define the Trainer subclass. It should inherit from `refiners.training_utils.Trainer` and implement the following methods:
|
We can now define the Trainer subclass. It should inherit from `refiners.training_utils.Trainer` and implement the following methods:
|
||||||
|
|
||||||
- `get_item`: This method should take an index and return a Batch.
|
- `create_data_iterable`: The `Trainer` will call this method to create and cache the data iterable. During training, the loop will pull batches from this iterable and pass them to the `compute_loss` method. Every time the iterable is exhausted, an epoch ends.
|
||||||
- `collate_fn`: This method should take a list of Batch and return a concatenated Batch.
|
- `compute_loss`: This method should take a Batch and return the loss tensor.
|
||||||
- `dataset_length`: We implement this property to return the length of the dataset.
|
|
||||||
- `compute_loss`: This method should take a Batch and return the loss.
|
Here is a simple implementation of the `create_data_iterable` method. For this toy example, we will generate a simple list of `Batch` objects containing random masks. Later you can replace this with `torch.utils.data.DataLoader` or any other data loader with more complex features that support shuffling, parallel loading, etc.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from refiners.training_utils import Trainer
|
from refiners.training_utils import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
class AutoencoderConfig(BaseConfig):
|
||||||
|
num_images: int = 2048
|
||||||
|
batch_size: int = 32
|
||||||
|
|
||||||
|
|
||||||
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
||||||
@cached_property
|
def create_data_iterable(self) -> list[Batch]:
|
||||||
def image_generator(self) -> Generator[torch.Tensor, None, None]:
|
dataset: list[Batch] = []
|
||||||
return generate_mask(size=64)
|
generator = generate_mask(size=64)
|
||||||
|
|
||||||
def get_item(self, index: int) -> Batch:
|
for _ in range(self.config.num_images // self.config.batch_size):
|
||||||
return Batch(image=next(self.image_generator).to(self.device, self.dtype))
|
masks = [next(generator) for _ in range(self.config.batch_size)]
|
||||||
|
dataset.append(Batch(image=torch.cat(masks, dim=0)))
|
||||||
|
|
||||||
def collate_fn(self, batch: list[Batch]) -> Batch:
|
return dataset
|
||||||
return Batch(image=torch.cat([b.image for b in batch]))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dataset_length(self) -> int:
|
|
||||||
return self.config.epoch_size
|
|
||||||
|
|
||||||
def compute_loss(self, batch: Batch) -> torch.Tensor:
|
def compute_loss(self, batch: Batch) -> torch.Tensor:
|
||||||
raise NotImplementedError("We'll implement this later")
|
raise NotImplementedError("We'll implement this later")
|
||||||
|
@ -304,16 +303,20 @@ class AutoencoderModelConfig(ModelConfig):
|
||||||
|
|
||||||
|
|
||||||
class AutoencoderConfig(BaseConfig):
|
class AutoencoderConfig(BaseConfig):
|
||||||
epoch_size: int = 2048
|
num_images: int = 2048
|
||||||
|
batch_size: int = 32
|
||||||
autoencoder: AutoencoderModelConfig
|
autoencoder: AutoencoderModelConfig
|
||||||
|
|
||||||
|
|
||||||
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
||||||
|
# ... other methods
|
||||||
|
|
||||||
@register_model()
|
@register_model()
|
||||||
def autoencoder(self, config: AutoencoderModelConfig) -> Autoencoder:
|
def autoencoder(self, config: AutoencoderModelConfig) -> Autoencoder:
|
||||||
return Autoencoder()
|
return Autoencoder()
|
||||||
|
|
||||||
def compute_loss(self, batch: Batch) -> torch.Tensor:
|
def compute_loss(self, batch: Batch) -> torch.Tensor:
|
||||||
|
batch.image = batch.image.to(self.device, self.dtype)
|
||||||
x_reconstructed = self.autoencoder.decoder(
|
x_reconstructed = self.autoencoder.decoder(
|
||||||
self.autoencoder.encoder(batch.image)
|
self.autoencoder.encoder(batch.image)
|
||||||
)
|
)
|
||||||
|
@ -329,58 +332,6 @@ trainer.train()
|
||||||
![alt text](terminal-logging.png)
|
![alt text](terminal-logging.png)
|
||||||
|
|
||||||
|
|
||||||
## Evaluation
|
|
||||||
|
|
||||||
We can also evaluate the model using the `compute_evaluation` method.
|
|
||||||
|
|
||||||
```python
|
|
||||||
training = TrainingConfig(
|
|
||||||
duration=Epoch(1000)
|
|
||||||
batch_size=32,
|
|
||||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
|
||||||
dtype="float32",
|
|
||||||
evaluation_interval=Epoch(50),
|
|
||||||
)
|
|
||||||
|
|
||||||
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
|
||||||
# ... other methods
|
|
||||||
|
|
||||||
def compute_evaluation(self) -> None:
|
|
||||||
generator = generate_mask(size=64, seed=0)
|
|
||||||
|
|
||||||
grid: list[tuple[Image.Image, Image.Image]] = []
|
|
||||||
for _ in range(4):
|
|
||||||
mask = next(generator).to(self.device, self.dtype)
|
|
||||||
x_reconstructed = self.autoencoder.decoder(
|
|
||||||
self.autoencoder.encoder(mask)
|
|
||||||
)
|
|
||||||
loss = F.mse_loss(x_reconstructed, mask)
|
|
||||||
logger.info(f"Validation loss: {loss.detach().cpu().item()}")
|
|
||||||
grid.append(
|
|
||||||
(tensor_to_image(mask), tensor_to_image((x_reconstructed>0.5).float()))
|
|
||||||
)
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
_, axes = plt.subplots(4, 2, figsize=(8, 16))
|
|
||||||
|
|
||||||
for i, (mask, reconstructed) in enumerate(grid):
|
|
||||||
axes[i, 0].imshow(mask, cmap='gray')
|
|
||||||
axes[i, 0].axis('off')
|
|
||||||
axes[i, 0].set_title('Mask')
|
|
||||||
|
|
||||||
axes[i, 1].imshow(reconstructed, cmap='gray')
|
|
||||||
axes[i, 1].axis('off')
|
|
||||||
axes[i, 1].set_title('Reconstructed')
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig(f"result_{trainer.clock.epoch}.png")
|
|
||||||
plt.close()
|
|
||||||
```
|
|
||||||
|
|
||||||
![alt text](evaluation.png)
|
|
||||||
|
|
||||||
|
|
||||||
## Logging
|
## Logging
|
||||||
|
|
||||||
Let's write a simple logging callback to log the loss and the reconstructed images during training. A callback is a class that inherits from `refiners.training_utils.Callback` and implement any of the following methods:
|
Let's write a simple logging callback to log the loss and the reconstructed images during training. A callback is a class that inherits from `refiners.training_utils.Callback` and implement any of the following methods:
|
||||||
|
@ -391,8 +342,8 @@ Let's write a simple logging callback to log the loss and the reconstructed imag
|
||||||
- `on_train_end`
|
- `on_train_end`
|
||||||
- `on_epoch_begin`
|
- `on_epoch_begin`
|
||||||
- `on_epoch_end`
|
- `on_epoch_end`
|
||||||
- `on_batch_begin`
|
- `on_step_begin`
|
||||||
- `on_batch_end`
|
- `on_step_end`
|
||||||
- `on_backward_begin`
|
- `on_backward_begin`
|
||||||
- `on_backward_end`
|
- `on_backward_end`
|
||||||
- `on_optimizer_step_begin`
|
- `on_optimizer_step_begin`
|
||||||
|
@ -430,7 +381,7 @@ Exactly like models, we need to register the callback to the Trainer. We can do
|
||||||
from refiners.training_utils import CallbackConfig, register_callback
|
from refiners.training_utils import CallbackConfig, register_callback
|
||||||
|
|
||||||
class AutoencoderConfig(BaseConfig):
|
class AutoencoderConfig(BaseConfig):
|
||||||
epoch_size: int = 2048
|
# ... other properties
|
||||||
logging: CallbackConfig = CallbackConfig()
|
logging: CallbackConfig = CallbackConfig()
|
||||||
|
|
||||||
|
|
||||||
|
@ -444,6 +395,101 @@ class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
||||||
|
|
||||||
![alt text](loss-logging.png)
|
![alt text](loss-logging.png)
|
||||||
|
|
||||||
|
## Evaluation
|
||||||
|
|
||||||
|
Let's add an evaluation step to the Trainer. We will generate a few masks and their reconstructions and save them to a file. We start by implementing a `compute_evaluation` method, then we register a callback to call this method at regular intervals.
|
||||||
|
|
||||||
|
|
||||||
|
```python
|
||||||
|
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
||||||
|
# ... other methods
|
||||||
|
|
||||||
|
def compute_evaluation(self) -> None:
|
||||||
|
generator = generate_mask(size=64, seed=0)
|
||||||
|
|
||||||
|
grid: list[tuple[Image.Image, Image.Image]] = []
|
||||||
|
for _ in range(4):
|
||||||
|
mask = next(generator).to(self.device, self.dtype)
|
||||||
|
x_reconstructed = self.autoencoder.decoder(
|
||||||
|
self.autoencoder.encoder(mask)
|
||||||
|
)
|
||||||
|
loss = F.mse_loss(x_reconstructed, mask)
|
||||||
|
logger.info(f"Validation loss: {loss.detach().cpu().item()}")
|
||||||
|
grid.append(
|
||||||
|
(tensor_to_image(mask), tensor_to_image((x_reconstructed>0.5).float()))
|
||||||
|
)
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
_, axes = plt.subplots(4, 2, figsize=(8, 16))
|
||||||
|
|
||||||
|
for i, (mask, reconstructed) in enumerate(grid):
|
||||||
|
axes[i, 0].imshow(mask, cmap='gray')
|
||||||
|
axes[i, 0].axis('off')
|
||||||
|
axes[i, 0].set_title('Mask')
|
||||||
|
|
||||||
|
axes[i, 1].imshow(reconstructed, cmap='gray')
|
||||||
|
axes[i, 1].axis('off')
|
||||||
|
axes[i, 1].set_title('Reconstructed')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(f"result_{trainer.clock.epoch}.png")
|
||||||
|
plt.close()
|
||||||
|
```
|
||||||
|
|
||||||
|
We starting by implementing an `EvaluationConfig` that controls the evaluation interval and the seed for the random generator.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from refiners.training_utils.config import TimeValueField
|
||||||
|
|
||||||
|
class EvaluationConfig(CallbackConfig):
|
||||||
|
interval: TimeValueField
|
||||||
|
seed: int
|
||||||
|
```
|
||||||
|
|
||||||
|
The `TimeValueField` is a custom field that allow Pydantic to parse a string representing a time value (e.g., "50:epochs") into a `TimeValue` object. This is useful to specify the evaluation interval in the configuration file.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from refiners.training_utils import scoped_seed, Callback
|
||||||
|
|
||||||
|
class EvaluationCallback(Callback[Any]):
|
||||||
|
def __init__(self, config: EvaluationConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def on_epoch_end(self, trainer: Trainer) -> None:
|
||||||
|
# The `is_due` method checks if the current epoch is a multiple of the interval.
|
||||||
|
if not trainer.clock.is_due(self.config.interval):
|
||||||
|
return
|
||||||
|
|
||||||
|
# The `scoped_seed` context manager encapsulates the random state for the evaluation and restores it after the
|
||||||
|
# evaluation.
|
||||||
|
with scoped_seed(self.config.seed):
|
||||||
|
trainer.compute_evaluation()
|
||||||
|
```
|
||||||
|
|
||||||
|
We can now register the callback to the Trainer.
|
||||||
|
|
||||||
|
|
||||||
|
```python
|
||||||
|
class AutoencoderConfig(BaseConfig):
|
||||||
|
# ... other properties
|
||||||
|
evaluation: EvaluationConfig
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
```python
|
||||||
|
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
||||||
|
# ... other methods
|
||||||
|
|
||||||
|
@register_callback()
|
||||||
|
def evaluation(self, config: EvaluationConfig) -> EvaluationCallback:
|
||||||
|
return EvaluationCallback(config)
|
||||||
|
```
|
||||||
|
|
||||||
|
We can now train the model and see the results in the `result_{epoch}.png` files.
|
||||||
|
|
||||||
|
![alt text](evaluation.png)
|
||||||
|
|
||||||
## Wrap up
|
## Wrap up
|
||||||
|
|
||||||
You can train this toy model using the code below:
|
You can train this toy model using the code below:
|
||||||
|
@ -453,7 +499,6 @@ You can train this toy model using the code below:
|
||||||
```py
|
```py
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import cached_property
|
|
||||||
from typing import Any, Generator
|
from typing import Any, Generator
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -468,6 +513,7 @@ You can train this toy model using the code below:
|
||||||
Callback,
|
Callback,
|
||||||
CallbackConfig,
|
CallbackConfig,
|
||||||
ClockConfig,
|
ClockConfig,
|
||||||
|
Epoch,
|
||||||
LRSchedulerConfig,
|
LRSchedulerConfig,
|
||||||
LRSchedulerType,
|
LRSchedulerType,
|
||||||
ModelConfig,
|
ModelConfig,
|
||||||
|
@ -477,8 +523,10 @@ You can train this toy model using the code below:
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
register_callback,
|
register_callback,
|
||||||
register_model,
|
register_model,
|
||||||
Epoch,
|
|
||||||
)
|
)
|
||||||
|
from refiners.training_utils.common import scoped_seed
|
||||||
|
from refiners.training_utils.config import TimeValueField
|
||||||
|
|
||||||
|
|
||||||
class ConvBlock(fl.Chain):
|
class ConvBlock(fl.Chain):
|
||||||
def __init__(self, in_channels: int, out_channels: int) -> None:
|
def __init__(self, in_channels: int, out_channels: int) -> None:
|
||||||
|
@ -615,9 +663,31 @@ You can train this toy model using the code below:
|
||||||
self.losses = []
|
self.losses = []
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationConfig(CallbackConfig):
|
||||||
|
interval: TimeValueField
|
||||||
|
seed: int
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationCallback(Callback["AutoencoderTrainer"]):
|
||||||
|
def __init__(self, config: EvaluationConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def on_epoch_end(self, trainer: "AutoencoderTrainer") -> None:
|
||||||
|
# The `is_due` method checks if the current epoch is a multiple of the interval.
|
||||||
|
if not trainer.clock.is_due(self.config.interval):
|
||||||
|
return
|
||||||
|
|
||||||
|
# The `scoped_seed` context manager encapsulates the random state for the evaluation and restores it after the
|
||||||
|
# evaluation.
|
||||||
|
with scoped_seed(self.config.seed):
|
||||||
|
trainer.compute_evaluation()
|
||||||
|
|
||||||
|
|
||||||
class AutoencoderConfig(BaseConfig):
|
class AutoencoderConfig(BaseConfig):
|
||||||
epoch_size: int = 2048
|
num_images: int = 2048
|
||||||
|
batch_size: int = 32
|
||||||
autoencoder: AutoencoderModelConfig
|
autoencoder: AutoencoderModelConfig
|
||||||
|
evaluation: EvaluationConfig
|
||||||
logging: CallbackConfig = CallbackConfig()
|
logging: CallbackConfig = CallbackConfig()
|
||||||
|
|
||||||
|
|
||||||
|
@ -626,11 +696,9 @@ You can train this toy model using the code below:
|
||||||
)
|
)
|
||||||
|
|
||||||
training = TrainingConfig(
|
training = TrainingConfig(
|
||||||
duration=Epoch(1000),
|
duration=Epoch(200),
|
||||||
batch_size=32,
|
|
||||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
evaluation_interval=Epoch(50),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer = OptimizerConfig(
|
optimizer = OptimizerConfig(
|
||||||
|
@ -645,30 +713,28 @@ You can train this toy model using the code below:
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
lr_scheduler=lr_scheduler,
|
lr_scheduler=lr_scheduler,
|
||||||
autoencoder=autoencoder_config,
|
autoencoder=autoencoder_config,
|
||||||
|
evaluation=EvaluationConfig(interval=Epoch(50), seed=0),
|
||||||
clock=ClockConfig(verbose=False), # to disable the default clock logging
|
clock=ClockConfig(verbose=False), # to disable the default clock logging
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
|
||||||
@cached_property
|
def create_data_iterable(self) -> list[Batch]:
|
||||||
def image_generator(self) -> Generator[torch.Tensor, None, None]:
|
dataset: list[Batch] = []
|
||||||
return generate_mask(size=64)
|
generator = generate_mask(size=64)
|
||||||
|
|
||||||
def get_item(self, index: int) -> Batch:
|
for _ in range(self.config.num_images // self.config.batch_size):
|
||||||
return Batch(image=next(self.image_generator).to(self.device, self.dtype))
|
masks = [next(generator).to(self.device, self.dtype) for _ in range(self.config.batch_size)]
|
||||||
|
dataset.append(Batch(image=torch.cat(masks, dim=0)))
|
||||||
|
|
||||||
def collate_fn(self, batch: list[Batch]) -> Batch:
|
return dataset
|
||||||
return Batch(image=torch.cat([b.image for b in batch]))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dataset_length(self) -> int:
|
|
||||||
return self.config.epoch_size
|
|
||||||
|
|
||||||
@register_model()
|
@register_model()
|
||||||
def autoencoder(self, config: AutoencoderModelConfig) -> Autoencoder:
|
def autoencoder(self, config: AutoencoderModelConfig) -> Autoencoder:
|
||||||
return Autoencoder()
|
return Autoencoder()
|
||||||
|
|
||||||
def compute_loss(self, batch: Batch) -> torch.Tensor:
|
def compute_loss(self, batch: Batch) -> torch.Tensor:
|
||||||
|
batch.image = batch.image.to(self.device, self.dtype)
|
||||||
x_reconstructed = self.autoencoder.decoder(self.autoencoder.encoder(batch.image))
|
x_reconstructed = self.autoencoder.decoder(self.autoencoder.encoder(batch.image))
|
||||||
return F.binary_cross_entropy(x_reconstructed, batch.image)
|
return F.binary_cross_entropy(x_reconstructed, batch.image)
|
||||||
|
|
||||||
|
@ -700,9 +766,13 @@ You can train this toy model using the code below:
|
||||||
axes[i, 1].axis("off")
|
axes[i, 1].axis("off")
|
||||||
axes[i, 1].set_title("Reconstructed")
|
axes[i, 1].set_title("Reconstructed")
|
||||||
|
|
||||||
plt.tight_layout() # type: ignore
|
plt.tight_layout() # type: ignore
|
||||||
plt.savefig(f"result_{trainer.clock.epoch}.png") # type: ignore
|
plt.savefig(f"result_{trainer.clock.epoch}.png") # type: ignore
|
||||||
plt.close() # type: ignore
|
plt.close() # type: ignore
|
||||||
|
|
||||||
|
@register_callback()
|
||||||
|
def evaluation(self, config: EvaluationConfig) -> EvaluationCallback:
|
||||||
|
return EvaluationCallback(config)
|
||||||
|
|
||||||
@register_callback()
|
@register_callback()
|
||||||
def logging(self, config: CallbackConfig) -> LoggingCallback:
|
def logging(self, config: CallbackConfig) -> LoggingCallback:
|
||||||
|
|
Loading…
Reference in a new issue