mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
better example model for Chain
This commit is contained in:
parent
17b085341a
commit
c02a7b378f
|
@ -9,35 +9,54 @@ When we say models are implemented in a declarative way in Refiners, what this m
|
|||
|
||||
## A first example
|
||||
|
||||
To give you an idea of how it looks, let us take an example similar to the one from the PyTorch paper[^1]:
|
||||
To give you an idea of how it looks, let us take a simple convolution network to classify MNIST as an example. First, let us define a few variables.
|
||||
|
||||
```py
|
||||
img_res = 28
|
||||
channels = 128
|
||||
kernel_size = 3
|
||||
hidden_layer_in = (((img_res - kernel_size + 1) // 2) ** 2) * channels
|
||||
hidden_layer_out = 200
|
||||
output_size = 10
|
||||
```
|
||||
|
||||
Now, here is the model in PyTorch:
|
||||
|
||||
|
||||
```py
|
||||
class BasicModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(1, 128, 3)
|
||||
self.linear_1 = nn.Linear(128, 40)
|
||||
self.linear_2 = nn.Linear(40, 10)
|
||||
self.conv = nn.Conv2d(1, channels, kernel_size)
|
||||
self.linear_1 = nn.Linear(hidden_layer_in, hidden_layer_out)
|
||||
self.maxpool = nn.MaxPool2d(2)
|
||||
self.linear_2 = nn.Linear(hidden_layer_out, output_size)
|
||||
|
||||
def forward(self, x):
|
||||
t1 = self.conv(x)
|
||||
t2 = nn.functional.relu(t1)
|
||||
t3 = self.linear_1(t2)
|
||||
t4 = self.linear_2(t3)
|
||||
return nn.functional.softmax(t4)
|
||||
x = self.conv(x)
|
||||
x = nn.functional.relu(x)
|
||||
x = self.maxpool(x)
|
||||
x = x.flatten(start_dim=1)
|
||||
x = self.linear_1(x)
|
||||
x = nn.functional.relu(x)
|
||||
x = self.linear_2(x)
|
||||
return nn.functional.softmax(x, dim=0)
|
||||
```
|
||||
|
||||
Here is how we could implement the same model in Refiners:
|
||||
And here is how we could implement the same model in Refiners:
|
||||
|
||||
```py
|
||||
class BasicModel(fl.Chain):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
fl.Conv2d(1, 128, 3),
|
||||
fl.Conv2d(1, channels, kernel_size),
|
||||
fl.ReLU(),
|
||||
fl.Linear(128, 40),
|
||||
fl.Linear(40, 10),
|
||||
fl.Lambda(torch.nn.functional.softmax),
|
||||
fl.MaxPool2d(2),
|
||||
fl.Flatten(start_dim=1),
|
||||
fl.Linear(hidden_layer_in, hidden_layer_out),
|
||||
fl.ReLU(),
|
||||
fl.Linear(hidden_layer_out, output_size),
|
||||
fl.Lambda(lambda x: torch.nn.functional.softmax(x, dim=0)),
|
||||
)
|
||||
```
|
||||
|
||||
|
@ -49,7 +68,7 @@ As of writing, Refiners does not include a `Softmax` layer by default, but as yo
|
|||
```py
|
||||
class Softmax(fl.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.nn.functional.softmax(x)
|
||||
return torch.nn.functional.softmax(x, dim=0)
|
||||
```
|
||||
|
||||
!!! note
|
||||
|
@ -64,9 +83,12 @@ Let us instantiate the `BasicModel` we just defined and inspect its representati
|
|||
>>> m
|
||||
(CHAIN) BasicModel()
|
||||
├── Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
|
||||
├── ReLU()
|
||||
├── Linear(in_features=128, out_features=40, device=cpu, dtype=float32) #1
|
||||
├── Linear(in_features=40, out_features=10, device=cpu, dtype=float32) #2
|
||||
├── ReLU() #1
|
||||
├── MaxPool2d(kernel_size=2, stride=2)
|
||||
├── Flatten(start_dim=1)
|
||||
├── Linear(in_features=21632, out_features=200, device=cpu, dtype=float32) #1
|
||||
├── ReLU() #2
|
||||
├── Linear(in_features=200, out_features=10, device=cpu, dtype=float32) #2
|
||||
└── Softmax()
|
||||
```
|
||||
|
||||
|
@ -78,17 +100,28 @@ The children of a `Chain` are stored in a dictionary and can be accessed by name
|
|||
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
|
||||
>>> m.Conv2d
|
||||
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
|
||||
>>> m[3]
|
||||
Linear(in_features=40, out_features=10, device=cpu, dtype=float32)
|
||||
>>> m[6]
|
||||
Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
|
||||
>>> m.Linear_2
|
||||
Linear(in_features=40, out_features=10, device=cpu, dtype=float32)
|
||||
Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
|
||||
```
|
||||
|
||||
The Chain class includes several helpers to manipulate the tree. For instance, imagine I want to wrap the two `Linear`s in a subchain. Here is how I could do it:
|
||||
The Chain class includes several helpers to manipulate the tree. For instance, imagine I want to organize my model by wrapping each layer of the convnet in a subchain. Here is how I could do it:
|
||||
|
||||
|
||||
```py
|
||||
m.insert_after_type(fl.ReLU, fl.Chain(m.pop(2), m.pop(2)))
|
||||
class ConvLayer(fl.Chain):
|
||||
pass
|
||||
|
||||
class HiddenLayer(fl.Chain):
|
||||
pass
|
||||
|
||||
class OutputLayer(fl.Chain):
|
||||
pass
|
||||
|
||||
m.insert(0, ConvLayer(m.pop(0), m.pop(0), m.pop(0)))
|
||||
m.insert_after_type(ConvLayer, HiddenLayer(m.pop(1), m.pop(1), m.pop(1)))
|
||||
m.append(OutputLayer(m.pop(2), m.pop(2)))
|
||||
```
|
||||
|
||||
Did it work? Let's see:
|
||||
|
@ -96,14 +129,22 @@ Did it work? Let's see:
|
|||
```
|
||||
>>> m
|
||||
(CHAIN) BasicModel()
|
||||
├── Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
|
||||
├── ReLU()
|
||||
├── (CHAIN)
|
||||
│ ├── Linear(in_features=128, out_features=40, device=cpu, dtype=float32) #1
|
||||
│ └── Linear(in_features=40, out_features=10, device=cpu, dtype=float32) #2
|
||||
├── (CHAIN) ConvLayer()
|
||||
│ ├── Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
|
||||
│ ├── ReLU()
|
||||
│ └── MaxPool2d(kernel_size=2, stride=2)
|
||||
├── (CHAIN) HiddenLayer()
|
||||
│ ├── Flatten(start_dim=1)
|
||||
│ ├── Linear(in_features=21632, out_features=200, device=cpu, dtype=float32)
|
||||
│ └── ReLU()
|
||||
└── (CHAIN) OutputLayer()
|
||||
├── Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
|
||||
└── Softmax()
|
||||
```
|
||||
|
||||
!!! note
|
||||
Organizing models like this is actually a good idea, it makes them easier to understand and adapt.
|
||||
|
||||
## Accessing and iterating
|
||||
|
||||
There are also many ways to access or iterate nodes even if they are deep in the tree. Most of them are implemented using a powerful iterator named [`walk`][refiners.fluxion.layers.Chain.walk]. However, most of the time, you can use simpler helpers. For instance, to iterate all the modules in the tree that hold weights (the `Conv2d` and the `Linear`s), we can just do:
|
||||
|
@ -117,8 +158,6 @@ It prints:
|
|||
|
||||
```
|
||||
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
|
||||
Linear(in_features=128, out_features=40, device=cpu, dtype=float32)
|
||||
Linear(in_features=40, out_features=10, device=cpu, dtype=float32
|
||||
Linear(in_features=21632, out_features=200, device=cpu, dtype=float32)
|
||||
Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
|
||||
```
|
||||
|
||||
[^1]: Paszke et al., 2019. PyTorch: An Imperative Style, High-Performance Deep Learning Library.
|
||||
|
|
Loading…
Reference in a new issue