From 76a6ce86413b6a9be90d155e803240a9979c59f8 Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Mon, 29 Apr 2024 08:04:10 +0000 Subject: [PATCH] update Training 101 --- docs/guides/training_101/index.md | 258 +++++++++++++++++++----------- 1 file changed, 164 insertions(+), 94 deletions(-) diff --git a/docs/guides/training_101/index.md b/docs/guides/training_101/index.md index 0b62806..64bf80a 100644 --- a/docs/guides/training_101/index.md +++ b/docs/guides/training_101/index.md @@ -222,12 +222,10 @@ Example: from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulerType, Epoch class AutoencoderConfig(BaseConfig): - # Since we are using a synthetic dataset, we will use a arbitrary fixed epoch size. - epoch_size: int = 2048 + ... training = TrainingConfig( duration=Epoch(1000), - batch_size=32, device="cuda" if torch.cuda.is_available() else "cpu", dtype="float32" ) @@ -252,30 +250,31 @@ config = AutoencoderConfig( We can now define the Trainer subclass. It should inherit from `refiners.training_utils.Trainer` and implement the following methods: - - `get_item`: This method should take an index and return a Batch. - - `collate_fn`: This method should take a list of Batch and return a concatenated Batch. - - `dataset_length`: We implement this property to return the length of the dataset. - - `compute_loss`: This method should take a Batch and return the loss. + - `create_data_iterable`: The `Trainer` will call this method to create and cache the data iterable. During training, the loop will pull batches from this iterable and pass them to the `compute_loss` method. Every time the iterable is exhausted, an epoch ends. + - `compute_loss`: This method should take a Batch and return the loss tensor. + +Here is a simple implementation of the `create_data_iterable` method. For this toy example, we will generate a simple list of `Batch` objects containing random masks. Later you can replace this with `torch.utils.data.DataLoader` or any other data loader with more complex features that support shuffling, parallel loading, etc. ```python from functools import cached_property from refiners.training_utils import Trainer +class AutoencoderConfig(BaseConfig): + num_images: int = 2048 + batch_size: int = 32 + + class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]): - @cached_property - def image_generator(self) -> Generator[torch.Tensor, None, None]: - return generate_mask(size=64) + def create_data_iterable(self) -> list[Batch]: + dataset: list[Batch] = [] + generator = generate_mask(size=64) - def get_item(self, index: int) -> Batch: - return Batch(image=next(self.image_generator).to(self.device, self.dtype)) - - def collate_fn(self, batch: list[Batch]) -> Batch: - return Batch(image=torch.cat([b.image for b in batch])) - - @property - def dataset_length(self) -> int: - return self.config.epoch_size + for _ in range(self.config.num_images // self.config.batch_size): + masks = [next(generator) for _ in range(self.config.batch_size)] + dataset.append(Batch(image=torch.cat(masks, dim=0))) + + return dataset def compute_loss(self, batch: Batch) -> torch.Tensor: raise NotImplementedError("We'll implement this later") @@ -304,16 +303,20 @@ class AutoencoderModelConfig(ModelConfig): class AutoencoderConfig(BaseConfig): - epoch_size: int = 2048 + num_images: int = 2048 + batch_size: int = 32 autoencoder: AutoencoderModelConfig class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]): + # ... other methods + @register_model() def autoencoder(self, config: AutoencoderModelConfig) -> Autoencoder: return Autoencoder() def compute_loss(self, batch: Batch) -> torch.Tensor: + batch.image = batch.image.to(self.device, self.dtype) x_reconstructed = self.autoencoder.decoder( self.autoencoder.encoder(batch.image) ) @@ -329,58 +332,6 @@ trainer.train() ![alt text](terminal-logging.png) -## Evaluation - -We can also evaluate the model using the `compute_evaluation` method. - -```python -training = TrainingConfig( - duration=Epoch(1000) - batch_size=32, - device="cuda" if torch.cuda.is_available() else "cpu", - dtype="float32", - evaluation_interval=Epoch(50), -) - -class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]): - # ... other methods - - def compute_evaluation(self) -> None: - generator = generate_mask(size=64, seed=0) - - grid: list[tuple[Image.Image, Image.Image]] = [] - for _ in range(4): - mask = next(generator).to(self.device, self.dtype) - x_reconstructed = self.autoencoder.decoder( - self.autoencoder.encoder(mask) - ) - loss = F.mse_loss(x_reconstructed, mask) - logger.info(f"Validation loss: {loss.detach().cpu().item()}") - grid.append( - (tensor_to_image(mask), tensor_to_image((x_reconstructed>0.5).float())) - ) - - import matplotlib.pyplot as plt - - _, axes = plt.subplots(4, 2, figsize=(8, 16)) - - for i, (mask, reconstructed) in enumerate(grid): - axes[i, 0].imshow(mask, cmap='gray') - axes[i, 0].axis('off') - axes[i, 0].set_title('Mask') - - axes[i, 1].imshow(reconstructed, cmap='gray') - axes[i, 1].axis('off') - axes[i, 1].set_title('Reconstructed') - - plt.tight_layout() - plt.savefig(f"result_{trainer.clock.epoch}.png") - plt.close() -``` - -![alt text](evaluation.png) - - ## Logging Let's write a simple logging callback to log the loss and the reconstructed images during training. A callback is a class that inherits from `refiners.training_utils.Callback` and implement any of the following methods: @@ -391,8 +342,8 @@ Let's write a simple logging callback to log the loss and the reconstructed imag - `on_train_end` - `on_epoch_begin` - `on_epoch_end` -- `on_batch_begin` -- `on_batch_end` +- `on_step_begin` +- `on_step_end` - `on_backward_begin` - `on_backward_end` - `on_optimizer_step_begin` @@ -430,7 +381,7 @@ Exactly like models, we need to register the callback to the Trainer. We can do from refiners.training_utils import CallbackConfig, register_callback class AutoencoderConfig(BaseConfig): - epoch_size: int = 2048 + # ... other properties logging: CallbackConfig = CallbackConfig() @@ -444,6 +395,101 @@ class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]): ![alt text](loss-logging.png) +## Evaluation + +Let's add an evaluation step to the Trainer. We will generate a few masks and their reconstructions and save them to a file. We start by implementing a `compute_evaluation` method, then we register a callback to call this method at regular intervals. + + +```python +class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]): + # ... other methods + + def compute_evaluation(self) -> None: + generator = generate_mask(size=64, seed=0) + + grid: list[tuple[Image.Image, Image.Image]] = [] + for _ in range(4): + mask = next(generator).to(self.device, self.dtype) + x_reconstructed = self.autoencoder.decoder( + self.autoencoder.encoder(mask) + ) + loss = F.mse_loss(x_reconstructed, mask) + logger.info(f"Validation loss: {loss.detach().cpu().item()}") + grid.append( + (tensor_to_image(mask), tensor_to_image((x_reconstructed>0.5).float())) + ) + + import matplotlib.pyplot as plt + + _, axes = plt.subplots(4, 2, figsize=(8, 16)) + + for i, (mask, reconstructed) in enumerate(grid): + axes[i, 0].imshow(mask, cmap='gray') + axes[i, 0].axis('off') + axes[i, 0].set_title('Mask') + + axes[i, 1].imshow(reconstructed, cmap='gray') + axes[i, 1].axis('off') + axes[i, 1].set_title('Reconstructed') + + plt.tight_layout() + plt.savefig(f"result_{trainer.clock.epoch}.png") + plt.close() +``` + +We starting by implementing an `EvaluationConfig` that controls the evaluation interval and the seed for the random generator. + +```python +from refiners.training_utils.config import TimeValueField + +class EvaluationConfig(CallbackConfig): + interval: TimeValueField + seed: int +``` + +The `TimeValueField` is a custom field that allow Pydantic to parse a string representing a time value (e.g., "50:epochs") into a `TimeValue` object. This is useful to specify the evaluation interval in the configuration file. + +```python +from refiners.training_utils import scoped_seed, Callback + +class EvaluationCallback(Callback[Any]): + def __init__(self, config: EvaluationConfig) -> None: + self.config = config + + def on_epoch_end(self, trainer: Trainer) -> None: + # The `is_due` method checks if the current epoch is a multiple of the interval. + if not trainer.clock.is_due(self.config.interval): + return + + # The `scoped_seed` context manager encapsulates the random state for the evaluation and restores it after the + # evaluation. + with scoped_seed(self.config.seed): + trainer.compute_evaluation() +``` + +We can now register the callback to the Trainer. + + +```python +class AutoencoderConfig(BaseConfig): + # ... other properties + evaluation: EvaluationConfig +``` + + +```python +class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]): + # ... other methods + + @register_callback() + def evaluation(self, config: EvaluationConfig) -> EvaluationCallback: + return EvaluationCallback(config) +``` + +We can now train the model and see the results in the `result_{epoch}.png` files. + +![alt text](evaluation.png) + ## Wrap up You can train this toy model using the code below: @@ -453,7 +499,6 @@ You can train this toy model using the code below: ```py import random from dataclasses import dataclass - from functools import cached_property from typing import Any, Generator import torch @@ -468,6 +513,7 @@ You can train this toy model using the code below: Callback, CallbackConfig, ClockConfig, + Epoch, LRSchedulerConfig, LRSchedulerType, ModelConfig, @@ -477,8 +523,10 @@ You can train this toy model using the code below: TrainingConfig, register_callback, register_model, - Epoch, ) + from refiners.training_utils.common import scoped_seed + from refiners.training_utils.config import TimeValueField + class ConvBlock(fl.Chain): def __init__(self, in_channels: int, out_channels: int) -> None: @@ -615,9 +663,31 @@ You can train this toy model using the code below: self.losses = [] + class EvaluationConfig(CallbackConfig): + interval: TimeValueField + seed: int + + + class EvaluationCallback(Callback["AutoencoderTrainer"]): + def __init__(self, config: EvaluationConfig) -> None: + self.config = config + + def on_epoch_end(self, trainer: "AutoencoderTrainer") -> None: + # The `is_due` method checks if the current epoch is a multiple of the interval. + if not trainer.clock.is_due(self.config.interval): + return + + # The `scoped_seed` context manager encapsulates the random state for the evaluation and restores it after the + # evaluation. + with scoped_seed(self.config.seed): + trainer.compute_evaluation() + + class AutoencoderConfig(BaseConfig): - epoch_size: int = 2048 + num_images: int = 2048 + batch_size: int = 32 autoencoder: AutoencoderModelConfig + evaluation: EvaluationConfig logging: CallbackConfig = CallbackConfig() @@ -626,11 +696,9 @@ You can train this toy model using the code below: ) training = TrainingConfig( - duration=Epoch(1000), - batch_size=32, + duration=Epoch(200), device="cuda" if torch.cuda.is_available() else "cpu", dtype="float32", - evaluation_interval=Epoch(50), ) optimizer = OptimizerConfig( @@ -645,30 +713,28 @@ You can train this toy model using the code below: optimizer=optimizer, lr_scheduler=lr_scheduler, autoencoder=autoencoder_config, + evaluation=EvaluationConfig(interval=Epoch(50), seed=0), clock=ClockConfig(verbose=False), # to disable the default clock logging ) class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]): - @cached_property - def image_generator(self) -> Generator[torch.Tensor, None, None]: - return generate_mask(size=64) + def create_data_iterable(self) -> list[Batch]: + dataset: list[Batch] = [] + generator = generate_mask(size=64) - def get_item(self, index: int) -> Batch: - return Batch(image=next(self.image_generator).to(self.device, self.dtype)) + for _ in range(self.config.num_images // self.config.batch_size): + masks = [next(generator).to(self.device, self.dtype) for _ in range(self.config.batch_size)] + dataset.append(Batch(image=torch.cat(masks, dim=0))) - def collate_fn(self, batch: list[Batch]) -> Batch: - return Batch(image=torch.cat([b.image for b in batch])) - - @property - def dataset_length(self) -> int: - return self.config.epoch_size + return dataset @register_model() def autoencoder(self, config: AutoencoderModelConfig) -> Autoencoder: return Autoencoder() def compute_loss(self, batch: Batch) -> torch.Tensor: + batch.image = batch.image.to(self.device, self.dtype) x_reconstructed = self.autoencoder.decoder(self.autoencoder.encoder(batch.image)) return F.binary_cross_entropy(x_reconstructed, batch.image) @@ -700,9 +766,13 @@ You can train this toy model using the code below: axes[i, 1].axis("off") axes[i, 1].set_title("Reconstructed") - plt.tight_layout() # type: ignore + plt.tight_layout() # type: ignore plt.savefig(f"result_{trainer.clock.epoch}.png") # type: ignore - plt.close() # type: ignore + plt.close() # type: ignore + + @register_callback() + def evaluation(self, config: EvaluationConfig) -> EvaluationCallback: + return EvaluationCallback(config) @register_callback() def logging(self, config: CallbackConfig) -> LoggingCallback: