better example model for Chain

Pierre Chapuis 2024-02-02 17:02:06 +01:00 committed by Cédric Deltheil
parent 17b085341a
commit c02a7b378f


@@ -9,35 +9,54 @@ When we say models are implemented in a declarative way in Refiners, what this m
## A first example
To give you an idea of how it looks, let us take as an example a simple convolutional network that classifies MNIST digits. First, let us define a few variables.
```py
img_res = 28
channels = 128
kernel_size = 3
hidden_layer_in = (((img_res - kernel_size + 1) // 2) ** 2) * channels
hidden_layer_out = 200
output_size = 10
```
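The `hidden_layer_in` formula is just the flattened feature count after the convolution and pooling stages. As a quick check of the arithmetic (this assert is not part of the original example):

```py
# 3x3 convolution: 28 - 3 + 1 = 26 per side; 2x2 max-pooling: 26 // 2 = 13;
# flattening 128 channels of 13x13 features gives 13 * 13 * 128 = 21632.
assert hidden_layer_in == 13 * 13 * 128 == 21632
```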
Now, here is the model in PyTorch:
```py
class BasicModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, channels, kernel_size)
        self.linear_1 = nn.Linear(hidden_layer_in, hidden_layer_out)
        self.maxpool = nn.MaxPool2d(2)
        self.linear_2 = nn.Linear(hidden_layer_out, output_size)

    def forward(self, x):
        x = self.conv(x)
        x = nn.functional.relu(x)
        x = self.maxpool(x)
        x = x.flatten(start_dim=1)
        x = self.linear_1(x)
        x = nn.functional.relu(x)
        x = self.linear_2(x)
        return nn.functional.softmax(x, dim=-1)
```
And here is how we could implement the same model in Refiners:
```py
class BasicModel(fl.Chain):
    def __init__(self):
        super().__init__(
            fl.Conv2d(1, channels, kernel_size),
            fl.ReLU(),
            fl.MaxPool2d(2),
            fl.Flatten(start_dim=1),
            fl.Linear(hidden_layer_in, hidden_layer_out),
            fl.ReLU(),
            fl.Linear(hidden_layer_out, output_size),
            fl.Lambda(lambda x: torch.nn.functional.softmax(x, dim=-1)),
        )
```
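Both definitions describe the same network. As a quick smoke test (a sketch, not part of the guide; it assumes the imports used throughout this page, i.e. `torch` and `refiners.fluxion.layers as fl`):

```py
import torch

m = BasicModel()  # the fl.Chain version defined above
x = torch.randn(2, 1, img_res, img_res)  # two fake 28x28 grayscale images
y = m(x)
assert y.shape == (2, output_size)  # one probability distribution per image
```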
@@ -49,7 +68,7 @@ As of writing, Refiners does not include a `Softmax` layer by default, but as yo
```py
class Softmax(fl.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.softmax(x, dim=-1)
```
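Since the representation printed further down shows a `Softmax()` child rather than a `Lambda`, the model being inspected presumably uses this custom layer in place of the `fl.Lambda` call. A sketch of that variant (not shown verbatim in this diff):

```py
class BasicModel(fl.Chain):
    def __init__(self):
        super().__init__(
            fl.Conv2d(1, channels, kernel_size),
            fl.ReLU(),
            fl.MaxPool2d(2),
            fl.Flatten(start_dim=1),
            fl.Linear(hidden_layer_in, hidden_layer_out),
            fl.ReLU(),
            fl.Linear(hidden_layer_out, output_size),
            Softmax(),  # the custom layer defined just above
        )
```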
!!! note
@@ -64,9 +83,12 @@ Let us instantiate the `BasicModel` we just defined and inspect its representati
>>> m
(CHAIN) BasicModel()
├── Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
├── ReLU() #1
├── MaxPool2d(kernel_size=2, stride=2)
├── Flatten(start_dim=1)
├── Linear(in_features=21632, out_features=200, device=cpu, dtype=float32) #1
├── ReLU() #2
├── Linear(in_features=200, out_features=10, device=cpu, dtype=float32) #2
└── Softmax()
```
@@ -78,17 +100,28 @@ The children of a `Chain` are stored in a dictionary and can be accessed by name
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
>>> m.Conv2d
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
>>> m[6]
Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
>>> m.Linear_2
Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
```
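Numbered suffixes distinguish children of the same type, matching the `#1` / `#2` markers in the representation above. A hypothetical check (not from the guide; it assumes indexing and attribute access return the very same child object, which is how the dictionary-backed storage is described):

```
>>> m[6] is m.Linear_2
True
```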
The `Chain` class includes several helpers to manipulate the tree. For instance, imagine I want to organize my model by wrapping each layer of the convnet in a subchain. Here is how I could do it:
```py
class ConvLayer(fl.Chain):
    pass

class HiddenLayer(fl.Chain):
    pass

class OutputLayer(fl.Chain):
    pass

# Arguments are evaluated first: each call pops children off the chain, wraps
# them in a subchain, then puts that subchain back at the right position.
m.insert(0, ConvLayer(m.pop(0), m.pop(0), m.pop(0)))  # Conv2d, ReLU, MaxPool2d
m.insert_after_type(ConvLayer, HiddenLayer(m.pop(1), m.pop(1), m.pop(1)))  # Flatten, Linear, ReLU
m.append(OutputLayer(m.pop(2), m.pop(2)))  # Linear, Softmax
```
Did it work? Let's see:
@@ -96,14 +129,22 @@ Did it work? Let's see:
```
>>> m
(CHAIN) BasicModel()
├── (CHAIN) ConvLayer()
│   ├── Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
│   ├── ReLU()
│   └── MaxPool2d(kernel_size=2, stride=2)
├── (CHAIN) HiddenLayer()
│   ├── Flatten(start_dim=1)
│   ├── Linear(in_features=21632, out_features=200, device=cpu, dtype=float32)
│   └── ReLU()
└── (CHAIN) OutputLayer()
    ├── Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
    └── Softmax()
```
!!! note
    Organizing models like this is actually a good idea: it makes them easier to understand and adapt.
## Accessing and iterating
There are also many ways to access or iterate nodes even if they are deep in the tree. Most of them are implemented using a powerful iterator named [`walk`][refiners.fluxion.layers.Chain.walk]. However, most of the time, you can use simpler helpers. For instance, to iterate all the modules in the tree that hold weights (the `Conv2d` and the `Linear`s), we can just do:
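The helper call that produces the listing below falls outside this diff hunk, so it is not reproduced here. As a stand-in only (not the guide's own helper), something close to the same listing can be obtained with plain PyTorch iteration, since Fluxion layers are also `torch.nn.Module`s:

```py
# Stand-in, not the helper used by the guide: walk the module tree and keep
# the leaf layers that hold weights.
for module in m.modules():
    if isinstance(module, (fl.Conv2d, fl.Linear)):
        print(module)
```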
@@ -117,8 +158,6 @@ It prints:
```
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
Linear(in_features=21632, out_features=200, device=cpu, dtype=float32)
Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
```