From c02a7b378f624b2de47cc1ac81beb7a91c65a318 Mon Sep 17 00:00:00 2001
From: Pierre Chapuis
Date: Fri, 2 Feb 2024 17:02:06 +0100
Subject: [PATCH] better example model for Chain

---
 docs/concepts/chain.md | 105 ++++++++++++++++++++++++++++-------------
 1 file changed, 72 insertions(+), 33 deletions(-)

diff --git a/docs/concepts/chain.md b/docs/concepts/chain.md
index 98a3aa1..9e33f94 100644
--- a/docs/concepts/chain.md
+++ b/docs/concepts/chain.md
@@ -9,35 +9,54 @@ When we say models are implemented in a declarative way in Refiners, what this m
 
 ## A first example
 
-To give you an idea of how it looks, let us take an example similar to the one from the PyTorch paper[^1]:
+To give you an idea of how it looks, let us take a simple convolutional network to classify MNIST as an example. First, let us define a few variables.
+
+```py
+img_res = 28
+channels = 128
+kernel_size = 3
+hidden_layer_in = (((img_res - kernel_size + 1) // 2) ** 2) * channels
+hidden_layer_out = 200
+output_size = 10
+```
+
+Now, here is the model in PyTorch:
+
 ```py
 class BasicModel(nn.Module):
     def __init__(self):
         super().__init__()
-        self.conv = nn.Conv2d(1, 128, 3)
-        self.linear_1 = nn.Linear(128, 40)
-        self.linear_2 = nn.Linear(40, 10)
+        self.conv = nn.Conv2d(1, channels, kernel_size)
+        self.linear_1 = nn.Linear(hidden_layer_in, hidden_layer_out)
+        self.maxpool = nn.MaxPool2d(2)
+        self.linear_2 = nn.Linear(hidden_layer_out, output_size)
 
     def forward(self, x):
-        t1 = self.conv(x)
-        t2 = nn.functional.relu(t1)
-        t3 = self.linear_1(t2)
-        t4 = self.linear_2(t3)
-        return nn.functional.softmax(t4)
+        x = self.conv(x)
+        x = nn.functional.relu(x)
+        x = self.maxpool(x)
+        x = x.flatten(start_dim=1)
+        x = self.linear_1(x)
+        x = nn.functional.relu(x)
+        x = self.linear_2(x)
+        return nn.functional.softmax(x, dim=1)
 ```
 
-Here is how we could implement the same model in Refiners:
+And here is how we could implement the same model in Refiners:
 
 ```py
 class BasicModel(fl.Chain):
     def __init__(self):
         super().__init__(
-            fl.Conv2d(1, 128, 3),
+            fl.Conv2d(1, channels, kernel_size),
             fl.ReLU(),
-            fl.Linear(128, 40),
-            fl.Linear(40, 10),
-            fl.Lambda(torch.nn.functional.softmax),
+            fl.MaxPool2d(2),
+            fl.Flatten(start_dim=1),
+            fl.Linear(hidden_layer_in, hidden_layer_out),
+            fl.ReLU(),
+            fl.Linear(hidden_layer_out, output_size),
+            fl.Lambda(lambda x: torch.nn.functional.softmax(x, dim=1)),
        )
 ```
 
@@ -49,7 +68,7 @@ As of writing, Refiners does not include a `Softmax` layer by default, but as yo
 ```py
 class Softmax(fl.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return torch.nn.functional.softmax(x)
+        return torch.nn.functional.softmax(x, dim=1)
 ```
 
 !!! note
@@ -64,9 +83,12 @@ Let us instantiate the `BasicModel` we just defined and inspect its representati
 >>> m
 (CHAIN) BasicModel()
     ├── Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
-    ├── ReLU()
-    ├── Linear(in_features=128, out_features=40, device=cpu, dtype=float32) #1
-    ├── Linear(in_features=40, out_features=10, device=cpu, dtype=float32) #2
+    ├── ReLU() #1
+    ├── MaxPool2d(kernel_size=2, stride=2)
+    ├── Flatten(start_dim=1)
+    ├── Linear(in_features=21632, out_features=200, device=cpu, dtype=float32) #1
+    ├── ReLU() #2
+    ├── Linear(in_features=200, out_features=10, device=cpu, dtype=float32) #2
     └── Softmax()
 ```
@@ -78,17 +100,28 @@ The children of a `Chain` are stored in a dictionary and can be accessed by name
 Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
 >>> m.Conv2d
 Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
->>> m[3]
-Linear(in_features=40, out_features=10, device=cpu, dtype=float32)
+>>> m[6]
+Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
 >>> m.Linear_2
-Linear(in_features=40, out_features=10, device=cpu, dtype=float32)
+Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
 ```
 
-The Chain class includes several helpers to manipulate the tree. For instance, imagine I want to wrap the two `Linear`s in a subchain. Here is how I could do it:
+The `Chain` class includes several helpers to manipulate the tree. For instance, imagine I want to organize my model by wrapping each layer of the convnet in a subchain. Here is how I could do it:
 
 ```py
-m.insert_after_type(fl.ReLU, fl.Chain(m.pop(2), m.pop(2)))
+class ConvLayer(fl.Chain):
+    pass
+
+class HiddenLayer(fl.Chain):
+    pass
+
+class OutputLayer(fl.Chain):
+    pass
+
+m.insert(0, ConvLayer(m.pop(0), m.pop(0), m.pop(0)))
+m.insert_after_type(ConvLayer, HiddenLayer(m.pop(1), m.pop(1), m.pop(1)))
+m.append(OutputLayer(m.pop(2), m.pop(2)))
 ```
 
 Did it work? Let's see:
@@ -96,14 +129,22 @@
 >>> m
 (CHAIN) BasicModel()
-    ├── Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
-    ├── ReLU()
-    ├── (CHAIN)
-    │   ├── Linear(in_features=128, out_features=40, device=cpu, dtype=float32) #1
-    │   └── Linear(in_features=40, out_features=10, device=cpu, dtype=float32) #2
-    └── Softmax()
+    ├── (CHAIN) ConvLayer()
+    │   ├── Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
+    │   ├── ReLU()
+    │   └── MaxPool2d(kernel_size=2, stride=2)
+    ├── (CHAIN) HiddenLayer()
+    │   ├── Flatten(start_dim=1)
+    │   ├── Linear(in_features=21632, out_features=200, device=cpu, dtype=float32)
+    │   └── ReLU()
+    └── (CHAIN) OutputLayer()
+        ├── Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
+        └── Softmax()
 ```
 
+!!! note
+    Organizing models like this is actually a good idea, as it makes them easier to understand and adapt.
+
 ## Accessing and iterating
 
 There are also many ways to access or iterate nodes even if they are deep in the tree. Most of them are implemented using a powerful iterator named [`walk`][refiners.fluxion.layers.Chain.walk]. However, most of the time, you can use simpler helpers.
 
 For instance, to iterate all the modules in the tree that hold weights (the `Conv2d` and the `Linear`s), we can just do:
@@ -117,8 +158,6 @@ It prints:
 ```
 Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
-Linear(in_features=128, out_features=40, device=cpu, dtype=float32)
-Linear(in_features=40, out_features=10, device=cpu, dtype=float32
+Linear(in_features=21632, out_features=200, device=cpu, dtype=float32)
+Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
 ```
-
-[^1]: Paszke et al., 2019. PyTorch: An Imperative Style, High-Performance Deep Learning Library.
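As a quick sanity check of the new example (a sketch, not part of the patch above), the Refiners version of the model can be instantiated and run on a dummy MNIST-sized batch. This assumes the `refiners.fluxion.layers` import path and the layer names used in the diff (`fl.Chain`, `fl.Conv2d`, `fl.MaxPool2d`, `fl.Flatten`, `fl.Linear`, `fl.Lambda`):

```py
import torch
import refiners.fluxion.layers as fl

# Variables from the example in the patch.
img_res = 28
channels = 128
kernel_size = 3
hidden_layer_in = (((img_res - kernel_size + 1) // 2) ** 2) * channels
hidden_layer_out = 200
output_size = 10


class BasicModel(fl.Chain):
    def __init__(self):
        super().__init__(
            fl.Conv2d(1, channels, kernel_size),
            fl.ReLU(),
            fl.MaxPool2d(2),
            fl.Flatten(start_dim=1),
            fl.Linear(hidden_layer_in, hidden_layer_out),
            fl.ReLU(),
            fl.Linear(hidden_layer_out, output_size),
            fl.Lambda(lambda x: torch.nn.functional.softmax(x, dim=1)),
        )


m = BasicModel()
x = torch.randn(4, 1, img_res, img_res)  # dummy batch of four 28x28 grayscale images
y = m(x)
print(y.shape)       # expected: torch.Size([4, 10])
print(y.sum(dim=1))  # softmax is over dim=1, so each row should sum to ~1
```

If the flattened feature size (`hidden_layer_in = 21632`) matches the first `Linear` layer, this runs without shape errors and returns one probability distribution over the ten classes per input image.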