---
icon: material/family-tree
---
# Chain
When we say models are implemented in a declarative way in Refiners, what this means in practice is that they are implemented as Chains. [`Chain`][refiners.fluxion.layers.Chain] is a Python class used to implement trees of modules. It is a subclass of Refiners' [`Module`][refiners.fluxion.layers.Module], which is in turn a subclass of PyTorch's `Module`. All inner nodes of a Chain are subclasses of `Chain`, while its leaf nodes are subclasses of Refiners' `Module`.
## A first example
To give you an idea of how it looks, let us take as an example a simple convolutional network to classify MNIST. First, let us define a few variables.
```py
img_res = 28
channels = 128
kernel_size = 3
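# flattened feature size: 28x28 input -> 26x26 after the 3x3 conv -> 13x13 after 2x2 max pooling, times 128 channels = 21632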
hidden_layer_in = (((img_res - kernel_size + 1) // 2) ** 2) * channels
hidden_layer_out = 200
output_size = 10
```
Now, here is the model in PyTorch:
```py
class BasicModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = nn.Conv2d(1, channels, kernel_size)
        self.linear_1 = nn.Linear(hidden_layer_in, hidden_layer_out)
        self.maxpool = nn.MaxPool2d(2)
        self.linear_2 = nn.Linear(hidden_layer_out, output_size)

    def forward(self, x):
        x = self.conv(x)
        x = nn.functional.relu(x)
        x = self.maxpool(x)
        x = x.flatten(start_dim=1)
        x = self.linear_1(x)
        x = nn.functional.relu(x)
        x = self.linear_2(x)
        return nn.functional.softmax(x, dim=-1)
```
And here is how we could implement the same model in Refiners:
```py
class BasicModel(fl.Chain):
    def __init__(self):
        super().__init__(
            fl.Conv2d(1, channels, kernel_size),
            fl.ReLU(),
            fl.MaxPool2d(2),
            fl.Flatten(start_dim=1),
            fl.Linear(hidden_layer_in, hidden_layer_out),
            fl.ReLU(),
            fl.Linear(hidden_layer_out, output_size),
            fl.Lambda(lambda x: torch.nn.functional.softmax(x, dim=-1)),
        )
```
!!! note
    We often use the namespace `fl`, short for `fluxion`, which is the part of Refiners that implements basic layers.
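If you want to run the snippets above, the imports typically look something like this (the exact aliases are a convention, not a requirement):

```py
import torch
import torch.nn as nn  # only needed for the PyTorch version

import refiners.fluxion.layers as fl
```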
As of writing, Refiners does not include a `Softmax` layer by default, but as you can see, you can easily call arbitrary code using [`fl.Lambda`][refiners.fluxion.layers.Lambda]. Alternatively, if you just wanted to write `Softmax()`, you could implement it like this:
```py
class Softmax(fl.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.softmax(x, dim=-1)
```

You could then use `Softmax()` as the last module of the chain, in place of the `fl.Lambda`; that is what the tree representations below show.
!!! note
    Notice the type hints here. All of Refiners' codebase is typed, which makes it a pleasure to use if your downstream code is typed too.
## Inspecting and manipulating
Let us instantiate the `BasicModel` we just defined and inspect its representation in a Python REPL:
```
>>> m = BasicModel()
>>> m
(CHAIN) BasicModel()
├── Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
├── ReLU() #1
├── MaxPool2d(kernel_size=2, stride=2)
├── Flatten(start_dim=1)
├── Linear(in_features=21632, out_features=200, device=cpu, dtype=float32) #1
├── ReLU() #2
├── Linear(in_features=200, out_features=10, device=cpu, dtype=float32) #2
└── Softmax()
```
The children of a `Chain` are stored in a dictionary and can be accessed by name or index. When layers of the same type appear in the Chain, distinct suffixed keys are automatically generated.
```
>>> m[0]
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
>>> m.Conv2d
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
>>> m[6]
Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
>>> m.Linear_2
Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
```
The `Chain` class includes several helpers to manipulate the tree. For instance, imagine I want to organize my model by grouping the modules of each logical layer of the convnet into a subchain. Here is how I could do it:
```py
class ConvLayer(fl.Chain):
    pass

class HiddenLayer(fl.Chain):
    pass

class OutputLayer(fl.Chain):
    pass

# Each `pop` removes a module and shifts the indices of the remaining ones,
# which is why the same index is popped several times in a row.
m.insert(0, ConvLayer(m.pop(0), m.pop(0), m.pop(0)))
m.insert_after_type(ConvLayer, HiddenLayer(m.pop(1), m.pop(1), m.pop(1)))
m.append(OutputLayer(m.pop(2), m.pop(2)))
```
Did it work? Let's see:
```
>>> m
(CHAIN) BasicModel()
├── (CHAIN) ConvLayer()
│ ├── Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
│ ├── ReLU()
│ └── MaxPool2d(kernel_size=2, stride=2)
├── (CHAIN) HiddenLayer()
│ ├── Flatten(start_dim=1)
│ ├── Linear(in_features=21632, out_features=200, device=cpu, dtype=float32)
│ └── ReLU()
└── (CHAIN) OutputLayer()
├── Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
└── Softmax()
```
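Since each subchain is itself a `Chain`, nested modules should still be reachable with the same attribute syntax as before, e.g. something like:
```
>>> m.ConvLayer.Conv2d
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
```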
!!! note
    Organizing models like this is actually a good idea: it makes them easier to understand and adapt.
## Accessing and iterating
There are also many ways to access or iterate over nodes, even when they are deep in the tree. Most of them are implemented using a powerful iterator named [`walk`][refiners.fluxion.layers.Chain.walk]. However, most of the time, you can use simpler helpers. For instance, to iterate over all the modules in the tree that hold weights (the `Conv2d` and the `Linear`s), we can just do:
```py
for x in m.layers(fl.WeightedModule):
print(x)
```
It prints:
```
Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), device=cpu, dtype=float32)
Linear(in_features=21632, out_features=200, device=cpu, dtype=float32)
Linear(in_features=200, out_features=10, device=cpu, dtype=float32)
```
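As a sketch of what using `walk` directly could look like, assuming it accepts a layer type (or a predicate) and yields `(module, parent)` pairs:
```py
for module, parent in m.walk(fl.WeightedModule):
    print(f"{type(module).__name__} is a child of {type(parent).__name__}")
```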