mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
1575 lines
481 KiB
Plaintext
1575 lines
481 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Refiners Demo\n",
|
||
|
"\n",
|
||
|
"This notebook aims to demonstrate the basics of using the [Refiners](https://github.com/finegrain-ai/refiners) micro-framework.\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# to run you need to have `Refiners` installed (uncomment the line below)\n",
|
||
|
"# %pip install git+https://github.com/finegrain-ai/refiners.git"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import torch\n",
|
||
|
"from refiners.fluxion import layers as fl, manual_seed\n",
|
||
|
"from torch import nn\n",
|
||
|
"\n",
|
||
|
"torch.set_grad_enabled(mode=False)\n",
|
||
|
"manual_seed(82570858)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Basics"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"The core idea of Refiners is to improve on the `Sequential` API of PyTorch.\n",
|
||
|
"\n",
|
||
|
"A `Sequential` is defined by:\n",
|
||
|
"\n",
|
||
|
"`Sequential(layer1, layer2, layer3)(x) = layer3(layer2(layer1(x)))`"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"Sequential(\n",
|
||
|
" (0): Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n",
|
||
|
" (1): ReLU()\n",
|
||
|
" (2): Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n",
|
||
|
" (3): ReLU()\n",
|
||
|
" (4): Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n",
|
||
|
" (5): ReLU()\n",
|
||
|
" (6): Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n",
|
||
|
")"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 3,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Native PyTorch sequential\n",
|
||
|
"sequential = nn.Sequential(\n",
|
||
|
" fl.Conv2d(3, 32, 3, padding=1),\n",
|
||
|
" nn.ReLU(),\n",
|
||
|
" fl.Conv2d(32, 32, 3, padding=1),\n",
|
||
|
" nn.ReLU(),\n",
|
||
|
" fl.Conv2d(32, 32, 3, padding=1),\n",
|
||
|
" nn.ReLU(),\n",
|
||
|
" fl.Conv2d(32, 32, 3, padding=1),\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"sequential"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(CHAIN)\n",
|
||
|
" ├── Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #1\n",
|
||
|
" ├── ReLU() #1\n",
|
||
|
" ├── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #2\n",
|
||
|
" ├── ReLU() #2\n",
|
||
|
" ├── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #3\n",
|
||
|
" ├── ReLU() #3\n",
|
||
|
" └── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #4"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Same as above, but with a Fluxion Chain\n",
|
||
|
"chain = fl.Chain(\n",
|
||
|
" fl.Conv2d(3, 32, 3, padding=1),\n",
|
||
|
" fl.ReLU(),\n",
|
||
|
" fl.Conv2d(32, 32, 3, padding=1),\n",
|
||
|
" fl.ReLU(),\n",
|
||
|
" fl.Conv2d(32, 32, 3, padding=1),\n",
|
||
|
" fl.ReLU(),\n",
|
||
|
" fl.Conv2d(32, 32, 3, padding=1),\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"chain"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Note that the keys of a `Chain` are derived from the layer class names (e.g. `Conv2d_1`), whereas in the PyTorch `Sequential` API the keys are the layer indices.\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Sequential keys:\n",
|
||
|
"0\n",
|
||
|
"1\n",
|
||
|
"2\n",
|
||
|
"3\n",
|
||
|
"4\n",
|
||
|
"5\n",
|
||
|
"6\n",
|
||
|
"\n",
|
||
|
"Chain keys:\n",
|
||
|
"Conv2d_1\n",
|
||
|
"ReLU_1\n",
|
||
|
"Conv2d_2\n",
|
||
|
"ReLU_2\n",
|
||
|
"Conv2d_3\n",
|
||
|
"ReLU_3\n",
|
||
|
"Conv2d_4\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"print(\"Sequential keys:\")\n",
|
||
|
"for key, _ in sequential.named_children():\n",
|
||
|
" print(key)\n",
|
||
|
"\n",
|
||
|
"print(\"\\nChain keys:\")\n",
|
||
|
"for key, _ in chain.named_children():\n",
|
||
|
" print(key)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"This choice was made because layer indices are easy to keep track of in a simple model but hard to remember in a complex one, whereas names stay meaningful.\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"We also improved error messages to show exactly where an error is coming from.\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"x = torch.randn(1, 4, 32, 32)\n",
|
||
|
"# uncomment to run\n",
|
||
|
"# sequential(x)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"attachments": {
|
||
|
"image.png": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABb4AAAFdCAYAAADMhHtHAAAAAXNSR0IArs4c6QAAAERlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAA6ABAAMAAAABAAEAAKACAAQAAAABAAAFvqADAAQAAAABAAABXQAAAABum17MAABAAElEQVR4Aex9B5wcxbF+3e3u7d5eDso5SygAIhshMiIamYxxwMbxOTw/Z//tZ/s545ywjQEbG7CxCcaYnIQQIgskISFAEsrxct69vbv919e7PTczO7O36e72pKrfb3d6ejrNNzPd1dXVVQVzFhwTpRxQJBwib5GfCgpQmPrLQalShCAgCAgCgoAgIAgIAoKAICAICAKCgCAgCAgCgoAgIAgIAoKAIJAeAoXpJZfUgoAgIAgIAoKAICAICAKCgCAgCAgCgoAgIAgIAoKAICAICAKCQH4jIILv/H4+0jpBQBAQBAQBQUAQEAQEAUFAEBAEBAFBQBAQBAQBQUAQEAQEgTQREMF3moBJckFAEBAEBAFBQBAQBAQBQUAQEAQEAUFAEBAEBAFBQBAQBASB/EZABN/5/XykdYKAICAICAKCgCAgCAgCgoAgIAgIAoKAICAICAKCgCAgCAgCaSIggu80AZPkgoAgIAgIAoKAICAICAKCgCAgCAgCgoAgIAgIAoKAICAICAL5jYAIvvP7+UjrBAFBQBAQBAQBQUAQEAQEAUFAEBAEBAFBQBAQBAQBQUAQEATSREAE32kCJskFAUFAEBAEBAFBQBAQBAQBQUAQEAQEAUFAEBAEBAFBQBAQBPIbARF85/fzkdYJAoKAICAICAKCgCAgCAgCgoAgIAgIAoKAICAICAKCgCAgCKSJgAi+0wRMkgsCgoAgIAgIAoKAICAICAKCgCAgCAgCgoAgIAgIAoKAICAI5DcCIvjO7+cjrRMEBAFBQBAQBAQBQUAQEAQEAUFAEBAEBAFBQBAQBAQBQUAQSBMBEXynCZgkFwQEAUFAEBAEBAFBQBAQBAQBQUAQEAQEAUFAEBAEBAFBQBDIbwRE8J3fz0daJwgIAoKAICAICAKCgCAgCAgCgoAgIAgIAoKAICAICAKCgCCQJgIi+E4TMEkuCAgCgoAgIAgIAoKAICAICAKCgCAgCAgCgoAgIAgIAoKAIJDfCIjgO7+fj7ROEBAEBAFBQBAQBAQBQUAQEAQEAUFAEBAEBAFBQBAQBAQBQSBNBETwnSZgklwQEAQEAUFAEBAEBAFBQBAQBAQBQUAQEAQEAUFAEBAEBAFBIL8REMF3fj8faZ0gIAgIAoKAICAICAKCgCAgCAgCgoAgIAgIAoKAICAICAKCQJoIiOA7TcAkuSAgCAgCgoAgIAgIAoKAICAICAKCgCAgCAgCgoAgIAgIAoJAfiMggu/8fj7SOkFAEBAEBAFBQBAQBAQBQUAQEAQEAUFAEBAEBAFBQBAQBASBNBEQwXeagElyQUAQEAQEAUFAEBAEBAFBQBAQBAQBQUAQEAQEAUFAEBAEBIH8RkAE3/n9fKR1goAgIAgIAoKAICAICAKCgCAgCAgCgoAgIAgIAoKAICAICAJpIiCC7zQBk+SCgCAgCAgCgoAgIAgIAoKAICAICAKCgCAgCAgCgoAgIAgIAvmNgAi+8/v5SOsEAUFAEBAEBAFBQBAQBAQBQUAQEAQEAUFAEBAEBAFBQBAQBNJEQATfaQImyQUBQUAQEAQEAUFAEBAEBAFBQBAQBAQBQUAQEAQEAUFAEBAE8hsBEXzn9/OR1gkCgoAgIAgIAoKAICAICAKCgCAgCAgCgoAgIAgIAoKAICAIpImACL7TBEySCwKCgCAgCAgCgoAgIAgIAoKAICAICAKCgCAgCAgCgoAgIAjkNwIi+M7v5yOtEwQEAUFAEBAEBAFBQBAQBAQBQUAQEAQEAUFAEBAEBAFBQBBIEwERfKcJmCQXBAQBQUAQEAQEAUFAEBAEBAFBQBAQBAQBQU
AQEAQEAUFAEMhvBETwnd/PR1onCAgCgoAgIAgIAoKAICAICAKCgCAgCAgCgoAgIAgIAoKAIJAmAt4004+o5GPHjacir0+1ua7+IHV1dY2o9ktjBQFBQBAQBBiBQi8FqhcYUITq13O4zziXgCAgCBD5q+dRQaFfQdHd9g71hVsFFkFAEMgTBLylk8gbqFGt6emqp56O3XnSMmmGICAICAKCgCAgCAgChzYCh7Tg+6qrP0hF/tgk8PlnV9KqVSsSnmZhoYcmTZ6cEI+I3kgP7d6zy/Ha4RBZ6K8iX8V0ivLkubtl8+Fwy8N+j4K58yMoqphFBf5y6us8SJH2kfVNFo0+kgrIOyLb7vw0Mo/1FI8i/6gjyVs+jeXW3dTd9DZ1N77JArqmpIV6/DVUfeavjTR1j3yYIi1bjHMJDIyAr3o+rx8EqKfpLR7b2gfOICmyRsAVc17I8dce6Vx+b4TCDVjYSZ+qT/sFFfiCKmP7xr9Q64ZbXAvxjz5GXYvw2J5LATl4Kl/tUVTQ10uh+rWu9Y/ECxBceoKjhScaiQ/Poc3essmEMam3Yw8Lofc7pMhtVOUxX6CisYtVoREe9+oe/1huK5DSskLAVzGDCv2VRD2dFG7clFVZkjmGgOsYeBgDpMdIQJBrfkyXPRLH36Huj3P5ChYwb13EPDYo0rCB+nrDOSneX8NlegKOZfU0bqTenpDjtZEQKX3DSHhKh14b80rwPWnSZAr4i1NCed/+vdTe3pZSWpWooMAxbVlpGV3JAnI3uvGGX1JLa7Pb5UM6vuLYL1LxxFMo2tVE++6/OOFefeVTyVMyMSEempi97ft4MrGT+njyK5Q6AgNhnnpJh1bKyqU/Im+QJ6jte+nAg1eNqJurOe1XVFBQSOHdq6hh9ddHVNtz1Vgw45XHfZ38U85QWNjLbd94GwvpbrJHu54X8FKCUOoIYEFt1Nm/VxlaXvg+dex4NPXMkjIjBJJh7i2upZrTf+Fa7sEHrsiBIC65JTtdf+fWB6j5lR+7tiXdC8HJ51D5CV9T2fbfd2FOherptiXX6SsWf5b8406ggmgv7fnn6bkuftDKw7uoJ+UDVdJ94PlB5dswFhSNOZ6b4eEfL47se5GPqe/egXDEWxpTVulp30k9bTsHuiXX6zVn/pY8LOgM717JY/P/uqaTC4cHApXv+jb5yqdQb7iZDtz37sPjpgfxLpONgYNYbf4X7Ss1xv+OTX+jlvV/yFmbR/L4O5L7Y1/JJOOZtrzyC+rY+q+cPNPq035OBV5nuVjb2hup7a07clLPUBcifcNQIz5y64PiY2FwjLqBntbtrjvlPLybzlc1b8AbzSvB9yWXXU3+gPMHbr+T19a8RI8/9pA9OufnBYXJJ485rzCPCizmCR6oa8cTjq0qO+KDFJhypuM1HQmtlpYXf0jdrdt0lByTIDAQ5kmyDumlyiM/Q97KqdT2+s2iGTOkyI/MyiD0Dkw9y2h8NMqCDtaI0AxdAWvACg0eAqWzLlGFQ2DXsevJwatISjYQyA7zofwecruIVDxrucKgtyW/Ta34q+ZS2aKPUk/rDmp+rX83ifEAD6FAcNyJxmLEQLeFxQoaJBM5/rHHU/WJ3+DdW6xVG6cD919OvV0H9KnrERp1lYu/RMXTzjbSdO9/lepXfs44Tyfgq5iphN7I07b57nSyDnHaQqpZer1aMK5f+YUhrluqyycEsDuB7b4RMf/UG6rLp6Y5tiW9MfBwfc8Pz/HX/sKMnP7Y3nKn89w+U6caEBctHJp63OrPJj69viGbmiTvSEYA/ULNspsMhbnOzfdS86u/dLyl0jlXUcncKx2vmSOHcnZlrndIwiueeoyCxbFtv++887ZjndDm/tVPf2C5dsLJS+nEk5ZY4g63k+Lxp7BCTpG67fYtd2V8+77quTTq3D/RgYfeSz2sBS7kjkCuMHevIXdXgrN4BwC/H55trDkqW0JzB+yhWBLMOrCmNwiC18aVX6SuA6/xGQu/C32siT
hP7RBRCVz++iIt1Pp6v0Z4b+fAghKXog7L6OIp56j77jqwlmHvOSwxGOqbToY5zCocuHuZpUklvJBcesR7LXHpnrSu
|
||
|
}
|
||
|
},
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"![image.png](attachment:image.png)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# uncomment to run\n",
|
||
|
"# chain(x)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"attachments": {
|
||
|
"image.png": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABaMAAAGbCAYAAADZZYy2AAAAAXNSR0IArs4c6QAAAERlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAA6ABAAMAAAABAAEAAKACAAQAAAABAAAFo6ADAAQAAAABAAABmwAAAAANIXnYAABAAElEQVR4Aey9B3QVR7rvWwIJIYFAEgIkRM4552wMJphsjMF5POPxpHPmnJl7zn1nrXvXe+u+t869c2bOmfGMZ2yPc8I5YGObnHPG5JyzQASBJITQ+/4lVat3795Je29pS/y/taTurq6urvp1766qr776Kq5T934lSqSosEDF10lUcXE40v+wQyEBEiABEiABEiABEiABEiABEiABEiABEiABEiABEiCBsAnUCjsFJkACJEACJEACJEACJEACJEACJEACJEACJEACJEACJEACAQhQGR0AEE+TAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmET4DK6PAZMgUSIAESIAESIAESIAESIAESIAESIAESIAESIAESIIEABKiMDgCIp0mABEiABEiABEiABEiABEiABEiABEiABEiABEiABMInQGV0+AyZAgmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQQAACVEYHAMTTJEACJEACJEACJEACJEACJEACJEACJEACJEACJEAC4ROgMjp8hkyBBEiABEiABEiABEiABEiABEiABEiABEiABEiABEggAAEqowMA4mkSIAESIAESIAESIAESIAESIAESIAESIAESIAESIIHwCVAZHT5DpkACJEACJEACJEACJEACJEACJEACJEACJEACJEACJBCAAJXRAQDxNAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQQPgEqIwOnyFTIAESIAESIAESIAESIAESIAESIAESIAESIAESIAESCECAyugAgHiaBEiABEiABEiABEiABEiABEiABEiABEiABEiABEggfAJURofPkCmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAkEIEBldABAPE0CJEACJEACJEACJEACJEACJEACJEACJEACJEACJBA+ASqjw2fIFEiABEiABEiABEiABEiABEiABEiABEiABEiABEiABAIQoDI6ACCeJgESIAESIAESIAESIAESIAESIAESIAESIAESIAESCJ8AldHhM2QKJEACJEACJEACJEACJEACJEACJEACJEACJEACJEACAQhQGR0AEE+TAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmQAAmET4DK6PAZMgUSIAESIAESIAESIAESIAESIAESIAESIAESIAESIIEABOIDnK/S05lZzVSd+ASdh8s5l1R+fn6V5oc3JwESiA0CaWnpKqV+is7MjZs31bVrV2MjY8xFzBFISq6nGjfK0PkqvFOgLl68GHN5ZIZIgARIgARIwBcB9od8kWE4CZAACZAACZBAdSUQ08roOXOfUXUSEzXbDWtXqTVrVlRXzsw3CZBABAk8NGGKatW6jU7x/Pmz6r23X4tg6uVJ1a5dSzVv0UoHXDx/QRUU3t8DYvVlAKCF8Mho3EQV3y1WFy6eVefPn1f5t2+VQ4uxvR7de6rRD47Xubp79676r9//fzGWw8hnp6U8ozh5d53Cd9hJhMfVkUC6DC6lpJQORtrzf+3aNXX9Wq49iPskEDUCzbNbqNoJ8Son57K6lZcXtfsgYfaHQsPLOjA0Xox9/xCoXbu29Gtaehf4Xok6eeqEd/h9HsJ+4P31AsCAqUmTJqqk+J46dfrk/VX4CpS2Vq3aqkVLl++JpFVcdFedOXs6YKpRVUbDerFV67aqsSgubuXdUlevXlZHjx5WRUVFATPmFSEuziuIASRQ3QmkpTdSGemlVps5V3NU7tUr1b1INSr/dRLrqsdkUAyyecNatXLl0hpVvmALg8bYxIenqy
5de6g4l2/xhnWr1ZrVy4NNjvGiTOCxJ551fU5rV61Q69ev8rp7o4zGKj013SvcGYDBmNOnTzmDreOGqWkKCprU9HSVLA26a6IYvCKzmk4cP67u3Su24sXKTqTKXdnlqVWrlmrXtkNQtz134WzUFWVBZaQsUrZ0gpMTkwJegtkul0XJ5yYTJk517UxfkfhvvPZXt0sYRgIRJzAX31lR7Bw9ckh9/um8iKfvM0GXOthn3Pv0RKh1oBNTUnKyat6shQ7Oz78dVIfamUZVHzdt2lQ1SEnV2cC3NNZmEEaiLkDhMGugadMslSbtj9rx8eqK9KUuX7igzp47U9WPwPX+VV3uBg0aWv0aZwY/ev8tKuAcUNgPdAAJ4rA6t1E7d+muxj00UZfyvXdfU+fPng2ixPdvFMxSN3oSNwqv/vVP6vqNa26nrLCoKKN79+mvxo2bqBtp1p1sO9dzc9Wrr7xoC+FuNAnAmrFWXG11r6RY5eXdjOatmHaQBBISEhSse7uJ1aaRkyeOq48/fMcccluJBJKSklRCfOksjLxbN90VZ7XCGxAbM3a8atSoiShtl6kL589VYunCvxUU0V27lb+rJSUl6m7RHZVQp5RZ7fja4d+EKUSdAOoANxkx8kHVsVNnt1MeYYUF+erFP/7OIwwHvXr3U0OGjVINGjTwOoeAInlX1q9ZpTZtWud6vqoCwy13VeU7TpTRMx6dG9TtFy9aoHZu3xpU3MqINGPmY3qgItC9KqTgC+8THShLPF+FBDIzm6kRox5UV69cVsuWLqzCnPDWlUGgZ68+qlPn7urgwb3qh53bI3ZLX3Wg/QYDBg1Vo+Vdw0AD5E5hofrTf/1ve5SY328iiuinf/Qza1B6+9ZNaumS72Mq3+HWBePGP6y6SbvUzKB2Fu7GjRtq0ffz1fFjR52nqvQ4lsvtZmwSDVjVtj8UZj8wGixjMc3q3Ea186xVwkalnUdF9vEuBJKIKqNhPTdlyiOqY5dufu/bMC3N73lzcsXyxSo5KVkfHjt2yARzGwIBjIr/7Ff/pK+4IVNYX3n5TyFczajRINBFRt0mTJpiKfKicY+anuaWzevUqRPHdDFhDReuPP/zX6u6YgVdUnJPvfzSH6MyaNO7zwAVL1Ybe3fvrFbKaEzBgUU0BHw+/eg9derUSVHY31OY7pfZLFthgDFW5djRI5LP0qqu4M6dWM1mVPJ1+OB+9d03X1ppF4pSOBrSs3dfL0U0BixMxyYhoY4aNWacfn82b94QjSwwzfuIwMcfvavi48obuC/84p9V3eTA1tb3EaIaV9QOHTurNm3bqcYZTe5LZfT91h/qJUZNWVnZqqCgICxldCh1ICyJp8+cozDDpzoL2myzZz9t1b/VuSz+8t6jZx/dpjZx7G0OhGFwfNbsJ9W7b71So9YKCbfcublX1Yt/+HeDTXUVjsYS1AqM8k517Q9FGQuTJ4FqRwBWz/bvCQowaNhINXjI8KDLElFl9MBBwy1FdEFhgVry/QKZ7nFCTxGFUrR5ixaqu3z0WrUs9cEaKJe7dm4LFIXnAxDoN3CQjoFK+tNPPwgQ2/dpWPIWi/+cWJxq7TvXsXdmyrRHRbFXPlgDS8PEuuxEh/qkoGDEXyQEbgWgiIYsWfxdVBTRmPIJRXR1FPiCMkrF/fv2qBNiwW+kuLhYnfXjtsHEq8ptjljS5WwIf8CiKstQ0XuXyIBBMAro7xd8qZYvcf8OTZo6XbVs2Vpn4eqVqz6zUnSnUG3ftkXt379bXHNc1vVFqrj+gAVM+w6d9HWjxjyktm7dHDP1SCTK7RNIFE8Ui+/zV15yH1jWCpVZc/RvFvX+5cuXopiT0JN+47W/qYTaCV4XJiYmqLlPPGcplS9ePO8VxwSg/HYb/xJVYk5xW0MJNGmaWUNLFlyx7rf+ULq4eYyEBFsHYnbP+IlTrFti4F2pOKvtY52oBjsTH56mkuvXi/mcRqIugJHVJhngPi4uQOGGBO4B0KafMGmquAtrpJ
/ftOmz1d9f/UvM8IiFctvbhcV3o2Ok4At4de4P+SoTwz0JVOc2qmdJeBQMAfv3BPHvhuiOOWLaEVhFDx4yTOf5tviH
|
||
|
}
|
||
|
},
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"![image.png](attachment:image.png)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"`Sequential` is excellent for building basic and straightforward models, but most models don't have such a simple linear structure.\n",
|
||
|
"\n",
|
||
|
"Let's say you want to add a skip connection to the ConvNet."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"torch.Size([1, 32, 32, 32])\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"ConvNet(\n",
|
||
|
" (sequential): Sequential(\n",
|
||
|
" (0): Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n",
|
||
|
" (1): ReLU()\n",
|
||
|
" (2): Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n",
|
||
|
" (3): ReLU()\n",
|
||
|
" (4): Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n",
|
||
|
" (5): ReLU()\n",
|
||
|
" (6): Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n",
|
||
|
" )\n",
|
||
|
" (skip): Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n",
|
||
|
")"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# ConvNet with a residual connection in PyTorch\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"class ConvNet(nn.Module):\n",
|
||
|
" def __init__(self) -> None:\n",
|
||
|
" super().__init__()\n",
|
||
|
" self.sequential = nn.Sequential(\n",
|
||
|
" fl.Conv2d(3, 32, 3, padding=1),\n",
|
||
|
" nn.ReLU(),\n",
|
||
|
" fl.Conv2d(32, 32, 3, padding=1),\n",
|
||
|
" nn.ReLU(),\n",
|
||
|
" fl.Conv2d(32, 32, 3, padding=1),\n",
|
||
|
" nn.ReLU(),\n",
|
||
|
" fl.Conv2d(32, 32, 3, padding=1),\n",
|
||
|
" )\n",
|
||
|
" self.skip = fl.Conv2d(3, 32, 3, padding=1)\n",
|
||
|
"\n",
|
||
|
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
|
||
|
" return self.sequential(x) + self.skip(x)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"convnet = ConvNet()\n",
|
||
|
"x = torch.randn(1, 3, 32, 32)\n",
|
||
|
"print(convnet(x).shape)\n",
|
||
|
"convnet"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
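"The `repr` of this PyTorch model is not declarative anymore: it does not show how the submodules are combined in `forward` (here, that the output of `skip` is added to the output of `sequential`).\n",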
|
||
|
"\n",
|
||
|
"You can use Refiners' predefined `Chain` subclasses to handle such cases and build more complex models. \n",
|
||
|
"\n",
|
||
|
"Let's start with the `Sum` class.\n",
|
||
|
"\n",
|
||
|
"`fl.Sum(layer1, layer2, layer3)(x) = layer1(x) + layer2(x) + layer3(x)`"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"torch.Size([1, 32, 32, 32])\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(SUM)\n",
|
||
|
" ├── (CHAIN)\n",
|
||
|
" │ ├── Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #1\n",
|
||
|
" │ ├── ReLU() #1\n",
|
||
|
" │ ├── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #2\n",
|
||
|
" │ ├── ReLU() #2\n",
|
||
|
" │ ├── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #3\n",
|
||
|
" │ ├── ReLU() #3\n",
|
||
|
" │ └── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #4\n",
|
||
|
" └── Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# ConvNet with a residual connection in Refiners\n",
|
||
|
"convnet = fl.Sum(\n",
|
||
|
" fl.Chain(\n",
|
||
|
" fl.Conv2d(3, 32, 3, padding=1),\n",
|
||
|
" fl.ReLU(),\n",
|
||
|
" fl.Conv2d(32, 32, 3, padding=1),\n",
|
||
|
" fl.ReLU(),\n",
|
||
|
" fl.Conv2d(32, 32, 3, padding=1),\n",
|
||
|
" fl.ReLU(),\n",
|
||
|
" fl.Conv2d(32, 32, 3, padding=1),\n",
|
||
|
" ),\n",
|
||
|
" fl.Conv2d(3, 32, 3, padding=1),\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"x = torch.randn(1, 3, 32, 32)\n",
|
||
|
"print(convnet(x).shape)\n",
|
||
|
"convnet"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"You can subclass the basic `Chain` classes to give them a name and make the model more declarative. The `repr` will still tell you which kind of `Chain` it is, along with the name of your subclass."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(SUM) ResidualNet()\n",
|
||
|
" ├── (CHAIN) ConvNet()\n",
|
||
|
" │ ├── Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #1\n",
|
||
|
" │ ├── ReLU() #1\n",
|
||
|
" │ ├── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #2\n",
|
||
|
" │ ├── ReLU() #2\n",
|
||
|
" │ ├── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #3\n",
|
||
|
" │ ├── ReLU() #3\n",
|
||
|
" │ └── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #4\n",
|
||
|
" └── Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"class ConvNet(fl.Chain):\n",
|
||
|
" def __init__(self) -> None:\n",
|
||
|
" super().__init__(\n",
|
||
|
" fl.Conv2d(3, 32, 3, padding=1),\n",
|
||
|
" fl.ReLU(),\n",
|
||
|
" fl.Conv2d(32, 32, 3, padding=1),\n",
|
||
|
" fl.ReLU(),\n",
|
||
|
" fl.Conv2d(32, 32, 3, padding=1),\n",
|
||
|
" fl.ReLU(),\n",
|
||
|
" fl.Conv2d(32, 32, 3, padding=1),\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"class ResidualNet(fl.Sum):\n",
|
||
|
" def __init__(self) -> None:\n",
|
||
|
" super().__init__(\n",
|
||
|
" ConvNet(),\n",
|
||
|
" # the skip branch receives the same 3-channel input as ConvNet\n",
|
||
|
" fl.Conv2d(3, 32, 3, padding=1),\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"ResidualNet()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Here are some examples of `Chain` subclasses:"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(tensor([[-0.4723, -0.2809]]), tensor([[-0.5384, 0.6123, 0.2659, 0.0916]]))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Run layers in parallel to output a tuple\n",
|
||
|
"par = fl.Parallel(\n",
|
||
|
" fl.Linear(2, 2),\n",
|
||
|
" fl.Linear(2, 4),\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"x = torch.randn(1, 2)\n",
|
||
|
"par(x) # (Linear_1(x), Linear_2(x))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"tensor([[-1.0949, 0.0749, 0.2607, 0.4013]])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 10,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Run layers in parallel and then concatenate the outputs\n",
|
||
|
"cat = fl.Concatenate(\n",
|
||
|
" fl.Linear(2, 2),\n",
|
||
|
" fl.Linear(2, 2),\n",
|
||
|
" dim=-1,\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"x = torch.randn(1, 2)\n",
|
||
|
"cat(x) # Concatenate((Linear_1(x), Linear_2(x)), dim=-1)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"tensor([[-0.1487, -1.1180]])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 11,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Run layers sequentially and then add the input\n",
|
||
|
"residual = fl.Residual(\n",
|
||
|
" fl.Linear(2, 2),\n",
|
||
|
" fl.Linear(2, 2),\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"x = torch.randn(1, 2)\n",
|
||
|
"residual(x) # Linear_2(Linear_1(x)) + x"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Let's now build something more complex, such as a Vision Transformer.\n",
|
||
|
"\n",
|
||
|
"Let's start with the heart of a transformer layer: the Multi-Head Attention."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"torch.Size([1, 197, 128])\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(RES) Attention()\n",
|
||
|
" ├── (PAR)\n",
|
||
|
" │ └── Linear(in_features=128, out_features=128) (x3)\n",
|
||
|
" ├── ScaledDotProductAttention(num_heads=8)\n",
|
||
|
" └── Linear(in_features=128, out_features=128)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 12,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from refiners.fluxion.layers.attentions import ScaledDotProductAttention\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"class Attention(fl.Residual):\n",
|
||
|
" def __init__(self, dim: int = 128, num_heads: int = 8) -> None:\n",
|
||
|
" self.dim = dim\n",
|
||
|
" self.num_heads = num_heads\n",
|
||
|
" super().__init__(\n",
|
||
|
" fl.Parallel(\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" ),\n",
|
||
|
" ScaledDotProductAttention(num_heads=num_heads),\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"x = torch.randn(1, 197, 128)\n",
|
||
|
"attention = Attention()\n",
|
||
|
"print(attention(x).shape)\n",
|
||
|
"attention"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(RES) FeedForward()\n",
|
||
|
" ├── Linear(in_features=128, out_features=512) #1\n",
|
||
|
" ├── SiLU()\n",
|
||
|
" └── Linear(in_features=512, out_features=128) #2"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 13,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"class FeedForward(fl.Residual):\n",
|
||
|
" def __init__(self, dim: int = 128, inner_dim: int = 512) -> None:\n",
|
||
|
" self.dim = dim\n",
|
||
|
" self.inner_dim = inner_dim\n",
|
||
|
" super().__init__(\n",
|
||
|
" fl.Linear(dim, inner_dim),\n",
|
||
|
" fl.SiLU(),\n",
|
||
|
" fl.Linear(inner_dim, dim),\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"FeedForward()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 14,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(CHAIN) TranformerLayer()\n",
|
||
|
" ├── LayerNorm(normalized_shape=(128,)) #1\n",
|
||
|
" ├── (RES) Attention()\n",
|
||
|
" │ ├── (PAR)\n",
|
||
|
" │ │ └── Linear(in_features=128, out_features=128) (x3)\n",
|
||
|
" │ ├── ScaledDotProductAttention(num_heads=8)\n",
|
||
|
" │ └── Linear(in_features=128, out_features=128)\n",
|
||
|
" ├── LayerNorm(normalized_shape=(128,)) #2\n",
|
||
|
" └── (RES) FeedForward()\n",
|
||
|
" ├── Linear(in_features=128, out_features=512) #1\n",
|
||
|
" ├── SiLU()\n",
|
||
|
" └── Linear(in_features=512, out_features=128) #2"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 14,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"class TranformerLayer(fl.Chain):\n",
|
||
|
" def __init__(\n",
|
||
|
" self, dim: int = 128, num_heads: int = 8, inner_dim: int = 512\n",
|
||
|
" ) -> None:\n",
|
||
|
" self.dim = dim\n",
|
||
|
" self.num_heads = num_heads\n",
|
||
|
" self.inner_dim = inner_dim\n",
|
||
|
" super().__init__(\n",
|
||
|
" fl.LayerNorm(dim),\n",
|
||
|
" Attention(dim, num_heads),\n",
|
||
|
" fl.LayerNorm(dim),\n",
|
||
|
" FeedForward(dim, inner_dim),\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"TranformerLayer()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 15,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"torch.Size([1, 196, 128])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 15,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"class PatchEncoder(fl.Chain):\n",
|
||
|
" def __init__(\n",
|
||
|
" self, in_channels: int = 3, dim: int = 128, patch_size: int = 16\n",
|
||
|
" ) -> None:\n",
|
||
|
" self.in_channels = in_channels\n",
|
||
|
" self.dim = dim\n",
|
||
|
" self.patch_size = patch_size\n",
|
||
|
" super().__init__(\n",
|
||
|
" fl.Conv2d(\n",
|
||
|
" in_channels=in_channels,\n",
|
||
|
" out_channels=dim,\n",
|
||
|
" kernel_size=patch_size,\n",
|
||
|
" stride=patch_size,\n",
|
||
|
" ),\n",
|
||
|
" fl.Reshape(-1, dim), # Reshape always preserves the batch dimension\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"x = torch.randn(1, 3, 224, 224)\n",
|
||
|
"PatchEncoder()(x).shape"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 16,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class PositionalToken(fl.Residual):\n",
|
||
|
" def __init__(self, num_patches: int = 196, dim: int = 128) -> None:\n",
|
||
|
" self.num_patches = num_patches\n",
|
||
|
" self.dim = dim\n",
|
||
|
" super().__init__(fl.Parameter(num_patches, dim))\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"class ClassToken(fl.Chain):\n",
|
||
|
" def __init__(self, dim: int = 128) -> None:\n",
|
||
|
" self.dim = dim\n",
|
||
|
" super().__init__(fl.Parameter(1, dim))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"We now have all the pieces needed to build a full Vision Transformer."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 17,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"(CHAIN) ViT()\n",
|
||
|
" ├── (CAT)\n",
|
||
|
" │ ├── (CHAIN) PatchEncoder()\n",
|
||
|
" │ │ ├── Conv2d(in_channels=3, out_channels=128, kernel_size=(16, 16), stride=(16, 16))\n",
|
||
|
" │ │ └── Reshape(shape=(-1, 128))\n",
|
||
|
" │ └── (CHAIN) ClassToken()\n",
|
||
|
" │ └── Parameter(dims=(1, 128))\n",
|
||
|
" ├── (RES) PositionalToken(num_patches=197)\n",
|
||
|
" │ └── Parameter(dims=(197, 128))\n",
|
||
|
" └── (CHAIN) Transformer()\n",
|
||
|
" └── (CHAIN) TranformerLayer() (x4)\n",
|
||
|
" ├── LayerNorm(normalized_shape=(128,)) #1\n",
|
||
|
" ├── (RES) Attention()\n",
|
||
|
" │ ├── (PAR)\n",
|
||
|
" │ │ └── Linear(in_features=128, out_features=128) (x3)\n",
|
||
|
" │ ├── ScaledDotProductAttention(num_heads=8)\n",
|
||
|
" │ └── Linear(in_features=128, out_features=128)\n",
|
||
|
" ├── LayerNorm(normalized_shape=(128,)) #2\n",
|
||
|
" └── (RES) FeedForward()\n",
|
||
|
" ├── Linear(in_features=128, out_features=512) #1\n",
|
||
|
" ├── SiLU()\n",
|
||
|
" └── Linear(in_features=512, out_features=128) #2\n",
|
||
|
"torch.Size([1, 197, 128])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"class Transformer(fl.Chain):\n",
|
||
|
" pass\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"class ViT(fl.Chain):\n",
|
||
|
" def __init__(\n",
|
||
|
" self,\n",
|
||
|
" dim: int = 128,\n",
|
||
|
" patch_size: int = 16,\n",
|
||
|
" image_size: int = 224,\n",
|
||
|
" num_layers: int = 4,\n",
|
||
|
" ) -> None:\n",
|
||
|
" self.dim = dim\n",
|
||
|
" self.patch_size = patch_size\n",
|
||
|
" self.image_size = image_size\n",
|
||
|
" self.num_layers = num_layers\n",
|
||
|
" self.num_patches = (image_size // patch_size) ** 2 + 1\n",
|
||
|
" super().__init__(\n",
|
||
|
" fl.Concatenate(\n",
|
||
|
" PatchEncoder(in_channels=3, dim=dim, patch_size=patch_size),\n",
|
||
|
" ClassToken(dim=dim),\n",
|
||
|
" dim=1,\n",
|
||
|
" ),\n",
|
||
|
" PositionalToken(num_patches=self.num_patches),\n",
|
||
|
" Transformer(TranformerLayer(dim=dim) for _ in range(num_layers)),\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"x = torch.randn(1, 3, 224, 224)\n",
|
||
|
"vit = ViT()\n",
|
||
|
"print(repr(vit))\n",
|
||
|
"print(vit(x).shape)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"attachments": {
|
||
|
"image.png": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0oAAAOSCAYAAABKtyCvAAAAAXNSR0IArs4c6QAAAERlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAA6ABAAMAAAABAAEAAKACAAQAAAABAAADSqADAAQAAAABAAADkgAAAADrWh4cAABAAElEQVR4AezdB7wsV1048AkmQBJEugISiCAI4gd4Vx4xEHpRMCgEUJGmITEmQihSEvkL0iEgoIjxPRCJiAgI0hGQIsXkxXsDEpoUlS4CAhJKAsk/v4WzzJ23u3fLzOycme98Pvfu7uzMKd8zO3t+Z8oecNHFU2EiQIAAAQIECBAgQIAAgbHAJcbPPCFAgAABAgQIECBAgACBkYBAyYZAgAABAgQIECBAgACBioBAqQLiJQECBAgQIECAAAECBARKtgECBAgQIECAAAECBAhUBARKFRAvCRAgQIAAAQIECBAgIFCyDRAgQIAAAQIECBAgQKAiIFCqgHhJgAABAgQIECBAgAABgZJtgAABAgQIECBAgAABAhUBgVIFxEsCBAgQIECAAAECBAgIlGwDBAgQIECAAAECBAgQqAgIlCogXhIgQIAAAQIECBAgQECgZBsgQIAAAQIECBAgQIBARUCgVAHxkgABAgQIECBAgAABAgIl2wABAgQIECBAgAABAgQqAgKlCoiXBAgQIECAAAECBAgQECjZBggQIECAAAECBAgQIFAREChVQLwkQIAAAQIECBAgQICAQMk2QIAAAQIECBAgQIAAgYqAQKkC4iUBAgQIECBAgAABAgQESrYBAgQIECBAgAABAgQIVAQEShUQLwkQIECAAAECBAgQICBQsg0QIECAAAECBAgQIECgIiBQqoB4SYAAAQIECBAgQIAAAYGSbYAAAQIECBAgQIAAAQIVAYFSBcRLAgQIECBAgAABAgQICJRsAwQIECBAgAABAgQIEKgICJQqIF4SIECAAAECBAgQIEBAoGQbIECAAAECBAgQIECAQEVAoFQB8ZIAAQIECBAgQIAAAQICJdsAAQIECBAgQIAAAQIEKgICpQqIlwQIECBAgAABAgQIEBAo2QYIECBAgAABAgQIECBQERAoVUC8JECAAAECBAgQIECAgEDJNkCAAAECBAgQIECAAIGKgECpAuIlAQIECBAgQIAAAQIEBEq2AQIECBAgQIAAAQIECFQEBEoVEC8JECBAgAABAgQIECAgULINECBAgAABAgQIECBAoCIgUKqAeEmAAAECBAgQIECAAAGBkm2AAAECBAgQIECAAAECFQGBUgXESwIECBAgQIAAAQIECAiUbAMECBAgQIAAAQIECBCoCAiUKiBeEiBAgAABAgQIECBAQKBkGyBAgAABAgQIECBAgEBFQKBUAfGSAAECBAgQIECAAAECAiXbAAECBAgQIECAAAECBCoCAqUKiJcECBAgQIAAAQIECBAQKNkGCBAgQIAAAQIECBAgUBEQKFVAvCRAgAABAgQIECBAgIBAyTZAgAABAgQIECBAgACBioBAqQLiJQECBAgQIECAAAECBA5EQIAAAQIECBAgsL/A1tbW/jPnmLO5uTnHUvsvsux6+6fU3zkbGxuNVW7VtHft2tVY2SS8HoEDLrp4Wk/WciVAgAABAgQItCuwd+/eIv66NulkL94iywayi+fUzBrHHXdcEX+m7goIlC5um1kftJ1Gd3Z6f96mn1WGedNYZLl17JBXHamZp35N5rEOs3nqbBkCBAgQmF9g9+7dRezP0/fFrO/xJr+bV/1OSeWfv+b5LDmrTeapRZPtNiv/eds02i7qGOXct2/frCS9t2aBwQVKXR1JWvN2IPs1CBhJWgN6DVlW9yHzfjGmrHPo3KzaSUl17cLjujpMbdTdPmQ55QiU6rCbtW3N+xmad7lpNZ1VhmnrdH3+ovvUan0W3cfOu/yq5aqWM32XnH766aPAvfq+190QGFygVB1JaqoZVt35rV
KuPu44V/GYtG7dO7xJecS8aTvg2D6inYwkTZPr7vzyPqT6OffZ6267rVqyde8zquW3D6mKzP86fYajg2oisC4BgdK65BfLd7A3c2j6nNCm01+smS3dRQGd6i62ynxligB4kRHpaltXA6yU67T56f1Jj9W0Jy3T1rxVg4lpAwvTyr/o8imdVcuZ0ln3Y5faft0W8idAgEATAoMNlJrAlCaBRQWio9OXTtuidc95+QhoFhkMqbZx9XWyWCTNtI5HAgQIECBAoBkBv6PUjKtUCRAgQIAAAQIECBDIWECglHHjKToBAgQIECBAgEC+Asuccp1vbfMruUApvzZT4h4I2DH2oBFVgQCB7ARc15VdkykwgbUKDDJQsqNc6zYncwIECBAgQIAAAQKdFxhkoNT5VlFAAgQIECBAoDEBA6aN0UqYQK8EBhUo2TH2attVGQJrE7AvWRu9jAkQINArAafid7s5BxUodbsplI4AgRwEpt3aO4eyKyMBAgQIECAwv8BgAyUjwvNvJJasX8D2V7+pFAkQIECAAAECdQoMNlCqE1FaBJYVcMh9WTnrESBAgAABAgSaFRAoNesrdQIECBAg0JiAo9ON0UqYQCsCPsOtMC+diUBpaTorEiBAgAABArkK6KDm2nLKTaA9AYFSe9ZyIkCAAAECBNYo4HTnNeLLmkCGAoMNlOwsM9xaFZkAAQIECBAgQIBASwKDDZRa8pUNAQIECBAgQIAAAQIZCgiUMmw0RSZAYP0Crm9YfxsoAQECBAgQaFJAoNSkrrQJECBAgAABAgQIEMhSYFCBUrouadeuXUV6nmWrKXTWAuUjEbbDrJtS4QkQIECAAIEeCwwqUOpxO6oaAQIECBAgsICAgaoFsCzaqEB5ALXRjCS+sIBAaWEyKxAgQIAAAQIECBAg0HcBgVLfW1j9CBAgQIAAAQIECBBYWECgtDCZFQgQIECAAIFcBeI6ZRMBAgTmERhkoLSxsVE4H3SezcMyBAgQIECgPwKuS+pPW6oJgTYEBhkotQErDwIECBAgQKCbAu5+2812USoCXRMQKHWtRZSHAAECBAgQIECAAIG1CwiU1t4ECjBUAefJD7Xl1ZsAAQIEhi4Ql4GYui8gUOp+GylhzwScI9+zBlUdAgQIECCwgoB+wQp4Da8qUGoYWPIECBAgQIAAAQIECOQnMLhAKU53crgzvw1ViQkQIECAAAECBAi0KTCoQMmhzTY3LXntJCBg30nI+wQITBOw/5gmYz4BAgTqExhUoFRl81tKVRGv2xawDbYtLj8CBAgQIECAwHwCgw6U5iOyFAECBAgQINAXAUfj+tKS6kGgeQGBUvPGciBAgAABAgQ6IOAofgcaQREIZCQgUMqosRSVAAECBAgQIECAAIF2BARK7TjLhQABAgQIECBAgACBjAQEShk1lqISIECAAAECBAgQINCOgECpHWe5ECBAgAABAgQIECCQkYBAKaPGUlQCBAgQIECAAAECBNoRECi14ywXAgQIECBAgAABAgQyEhAoZdRYikqAAAECBAgQIECAQDsCgwyUdu3a1Y6uXAgQIECAAAECBAgQyFJgkIFSaqnNzc301CMBAgQIECBAgAABAgTGAoMOlMYKnhAgQIAAAQIECBAgQKAkIFAqYXhKgAABAgQIECBAgACBEBAo2Q4IECBAgAABAgQIECBQERAoVUC8JECAAAECBAgQIECAgEDJNkBgDQLuvLgGdFkSIECAAAECBBYQGFygtLGxsQCPRQkQIECAAAECBAgQGKLA4AKlITayOhMgQIAAAQLfFzBgaksgQGBegUEFSltbW/O6WI4AAQIECBAgQIAAgQELDCpQGnA7qzoBAgQIEBi0gMHSQTe/yhNYSkCgtBSblQgsL7C5ubn8ytYkQIAAAQIECBBoRUCg1AqzTAgQIECAAAECBAgQyElAoJRTaykrAQIECBAgQIAAAQKtCAiUWmGWCQECBAgQIECAAAECOQkIlHJqLWUlQIAAAQ
IEahFwm/BaGCVCoNcCAqVeN6/KESBAgEAfBdwUZrVWdQe81fysTWAoAgKlobS0enZGwChmZ5pCQQgQIECAAAECUwUE
|
||
|
}
|
||
|
},
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Advanced - Context API\n",
|
||
|
"\n",
|
||
|
"This ViT is still rudimentary and linear: we have one input in and one output out. But often, you want to use multiple inputs/modalities in the flow of your model.\n",
|
||
|
"\n",
|
||
|
"Let's take, for example, the `MaskDecoder` of the Segment Anything model by Meta; it's a Transformer that takes as input an image and a prompt and outputs a segmentation mask. You can prompt it with points to guide the segmentation. [Here is a link](https://github.com/finegrain-ai/refiners/blob/main/src/refiners/foundationals/segment_anything/mask_decoder.py) to the complete implementation in Refiners.\n",
|
||
|
"\n",
|
||
|
"So the inputs are:\n",
|
||
|
"\n",
|
||
|
" - an image of shape (3, 224, 224)\n",
|
||
|
" - N points, given as a tensor of shape (N, 2)\n",
|
||
|
"\n",
|
||
|
"One way to take the points into account is to add a `CrossAttention` layer in which the image attends to the points. Cross-attention is a standard `Attention` layer, except that the key and value come from a different source than the query. In our case, the query comes from the image, while the key and value are the point embeddings.\n",
|
||
|
"\n",
|
||
|
"![image.png](attachment:image.png)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"Let's start by building a point encoder (to keep things simple, we assume that all points have the same \"meaning\")."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 18,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"torch.Size([1, 5, 128])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 18,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"class PointEncoder(fl.Chain):\n",
|
||
|
" # Encodes a set of 2D points with a small Linear/SiLU MLP applied to each point,\n",
|
||
|
" # e.g. (5, 2) -> (1, 5, 128) in the demo below.\n",
|
||
|
" def __init__(self, dim: int = 128) -> None:\n",
|
||
|
" self.dim = dim\n",
|
||
|
" super().__init__(\n",
|
||
|
" fl.Linear(2, dim),\n",
|
||
|
" fl.SiLU(),\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" fl.SiLU(),\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" # add a leading batch dimension: (N, dim) -> (1, N, dim)\n",
|
||
|
" fl.Unsqueeze(0),\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"points = torch.randn(5, 2)\n",
|
||
|
"PointEncoder()(points).shape"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 19,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"torch.Size([1, 197, 128])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 19,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Chains can handle multiple inputs\n",
|
||
|
"# (here: the patch embedding as first argument, the points embedding as second argument)\n",
|
||
|
"class CrossAttention(fl.Chain):\n",
|
||
|
" def __init__(self, dim: int = 128, num_heads: int = 8) -> None:\n",
|
||
|
" self.dim = dim\n",
|
||
|
" self.num_heads = num_heads\n",
|
||
|
" super().__init__(\n",
|
||
|
" # (query, key, value) = (arg 0, arg 1, arg 1): queries come from the first input,\n",
|
||
|
" # keys and values both come from the second one\n",
|
||
|
" fl.Parallel(\n",
|
||
|
" fl.GetArg(0),\n",
|
||
|
" fl.GetArg(1),\n",
|
||
|
" fl.GetArg(1),\n",
|
||
|
" ),\n",
|
||
|
" # Q, K and V projections, presumably applied positionally (i-th Linear to i-th input)\n",
|
||
|
" fl.Distribute(\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" ),\n",
|
||
|
" ScaledDotProductAttention(num_heads=num_heads),\n",
|
||
|
" # output projection\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"points_embedding = torch.randn(1, 5, 128)\n",
|
||
|
"patch_embedding = torch.randn(1, 197, 128)\n",
|
||
|
"CrossAttention()(patch_embedding, points_embedding).shape"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Now ideally, I would like to insert this `CrossAttention` layer in the middle of the `Transformer` like this:\n",
|
||
|
"\n",
|
||
|
"```python\n",
|
||
|
"class TranformerLayer(fl.Chain):\n",
|
||
|
" def __init__(\n",
|
||
|
" self, dim: int = 128, num_heads: int = 8, inner_dim: int = 512\n",
|
||
|
" ) -> None:\n",
|
||
|
" self.dim = dim\n",
|
||
|
" self.num_heads = num_heads\n",
|
||
|
" self.inner_dim = inner_dim\n",
|
||
|
" super().__init__(\n",
|
||
|
" fl.LayerNorm(dim),\n",
|
||
|
" Attention(dim, num_heads),\n",
|
||
|
" fl.LayerNorm(dim),\n",
|
||
|
" CrossAttention(dim, num_heads),\n",
|
||
|
" fl.LayerNorm(dim),\n",
|
||
|
" FeedForward(dim, inner_dim),\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"```\n",
|
||
|
"\n",
|
||
|
"But how does `points_embedding` get into the `CrossAttention` layer? That's where the `Context` API comes into play."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"chain = fl.Chain(\n",
|
||
|
" fl.Linear(2, 2),\n",
|
||
|
" fl.Concatenate(\n",
|
||
|
" fl.Linear(2, 2),\n",
|
||
|
" # second part of the concatenation: the value stored under the key value of the embedding context\n",
|
||
|
" fl.UseContext(\"embedding\", \"value\"),\n",
|
||
|
" # (1, 2) and (1, 3) are concatenated into (1, 5), hence the Linear(5, 2) below\n",
|
||
|
" dim=-1,\n",
|
||
|
" ),\n",
|
||
|
" fl.Linear(5, 2),\n",
|
||
|
")\n",
|
||
|
"\n",
|
||
|
"chain.set_context(\"embedding\", {\"value\": torch.randn(1, 3)})\n",
|
||
|
"print(f\"Current embedding context: {chain.use_context('embedding')}\")\n",
|
||
|
"\n",
|
||
|
"x = torch.randn(1, 2)\n",
|
||
|
"chain(x)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Note that the context is recursive, so you can access the context of an outer `Chain` from an inner `Chain`.\n",
|
||
|
"\n",
|
||
|
"We can rewrite the `CrossAttention` layer using the context instead of passing `points_embedding` as an argument."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 20,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"torch.Size([1, 197, 128])\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(CHAIN) PointsCrossAttention()\n",
|
||
|
" ├── (PAR)\n",
|
||
|
" │ ├── Identity()\n",
|
||
|
" │ └── UseContext(context=vit, key=points_embedding) (x2)\n",
|
||
|
" ├── (DISTR)\n",
|
||
|
" │ └── Linear(in_features=128, out_features=128) (x3)\n",
|
||
|
" ├── ScaledDotProductAttention(num_heads=8)\n",
|
||
|
" └── Linear(in_features=128, out_features=128)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 20,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"class PointsCrossAttention(fl.Chain):\n",
|
||
|
" # Same as CrossAttention, but the keys and values are read from the vit context\n",
|
||
|
" # (key points_embedding) instead of being passed as a second argument.\n",
|
||
|
" def __init__(self, dim: int = 128, num_heads: int = 8) -> None:\n",
|
||
|
" self.dim = dim\n",
|
||
|
" self.num_heads = num_heads\n",
|
||
|
" super().__init__(\n",
|
||
|
" # query: the layer input; key and value: the points embedding from the context\n",
|
||
|
" fl.Parallel(\n",
|
||
|
" fl.Identity(),\n",
|
||
|
" fl.UseContext(\"vit\", \"points_embedding\"),\n",
|
||
|
" fl.UseContext(\"vit\", \"points_embedding\"),\n",
|
||
|
" ),\n",
|
||
|
" fl.Distribute(\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" ),\n",
|
||
|
" ScaledDotProductAttention(num_heads=num_heads),\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"points_cross_attention = PointsCrossAttention()\n",
|
||
|
"\n",
|
||
|
"# If the context is not set, the layer will raise an error\n",
|
||
|
"points_embedding = torch.randn(1, 5, 128)\n",
|
||
|
"points_cross_attention.set_context(\"vit\", {\"points_embedding\": points_embedding})\n",
|
||
|
"\n",
|
||
|
"x = torch.randn(1, 197, 128)\n",
|
||
|
"\n",
|
||
|
"print(points_cross_attention(x).shape)\n",
|
||
|
"points_cross_attention"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Now let's rewrite the `TransformerLayer` using the `PointsCrossAttention` layer.\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 21,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"torch.Size([1, 197, 128])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"class TranformerLayer(fl.Chain):\n",
|
||
|
" # Transformer block: self-attention, cross-attention to the points, then feed-forward,\n",
|
||
|
" # each preceded by a LayerNorm. The class name is spelled TranformerLayer throughout the notebook.\n",
|
||
|
" def __init__(\n",
|
||
|
" self, dim: int = 128, num_heads: int = 8, inner_dim: int = 512\n",
|
||
|
" ) -> None:\n",
|
||
|
" self.dim = dim\n",
|
||
|
" self.num_heads = num_heads\n",
|
||
|
" self.inner_dim = inner_dim\n",
|
||
|
" super().__init__(\n",
|
||
|
" fl.LayerNorm(dim),\n",
|
||
|
" Attention(dim, num_heads),\n",
|
||
|
" fl.LayerNorm(dim),\n",
|
||
|
" # NOTE(review): unlike Attention and FeedForward, which print as (RES) residual blocks,\n",
|
||
|
" # PointsCrossAttention is a plain Chain, so its output replaces the hidden states\n",
|
||
|
" # instead of being added to them; confirm this is intended.\n",
|
||
|
" PointsCrossAttention(dim, num_heads),\n",
|
||
|
" fl.LayerNorm(dim),\n",
|
||
|
" FeedForward(dim, inner_dim),\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"layer = TranformerLayer()\n",
|
||
|
"x = torch.randn(1, 197, 128)\n",
|
||
|
"points_embedding = torch.randn(1, 5, 128)\n",
|
||
|
"# PointsCrossAttention reads the points embedding from the vit context, so it must be set before the call\n",
|
||
|
"layer.set_context(\"vit\", {\"points_embedding\": points_embedding})\n",
|
||
|
"print(layer(x).shape)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"The ViT is still valid as is, but we might want to add the `PointEncoder` directly into the model, so that we don't have to deal with multiple models separately. To do that, we can wrap the `PointEncoder` into a `Passthrough` layer that lets the main arguments pass through, while also adding `points_embedding` to the context using a `SetContext` layer."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 22,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"torch.Size([1, 197, 128])\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(CHAIN) ViT()\n",
|
||
|
" ├── (PASS) PointEncoder()\n",
|
||
|
" │ ├── UseContext(context=vit, key=points_tensor)\n",
|
||
|
" │ ├── Linear(in_features=2, out_features=128) #1\n",
|
||
|
" │ ├── SiLU() #1\n",
|
||
|
" │ ├── Linear(in_features=128, out_features=128) #2\n",
|
||
|
" │ ├── SiLU() #2\n",
|
||
|
" │ ├── Linear(in_features=128, out_features=128) #3\n",
|
||
|
" │ ├── Unsqueeze(dim=0)\n",
|
||
|
" │ └── SetContext(context=vit, key=points_embedding)\n",
|
||
|
" ├── (CAT)\n",
|
||
|
" │ ├── (CHAIN) PatchEncoder()\n",
|
||
|
" │ │ ├── Conv2d(in_channels=3, out_channels=128, kernel_size=(16, 16), stride=(16, 16))\n",
|
||
|
" │ │ └── Reshape(shape=(-1, 128))\n",
|
||
|
" │ └── (CHAIN) ClassToken()\n",
|
||
|
" │ └── Parameter(dims=(1, 128))\n",
|
||
|
" ├── (RES) PositionalToken(num_patches=197)\n",
|
||
|
" │ └── Parameter(dims=(197, 128))\n",
|
||
|
" └── (CHAIN) Transformer()\n",
|
||
|
" └── (CHAIN) TranformerLayer() (x4)\n",
|
||
|
" ├── LayerNorm(normalized_shape=(128,)) #1\n",
|
||
|
" ├── (RES) Attention()\n",
|
||
|
" │ ├── (PAR)\n",
|
||
|
" │ │ └── Linear(in_features=128, out_features=128) (x3)\n",
|
||
|
" │ ├── ScaledDotProductAttention(num_heads=8)\n",
|
||
|
" │ └── Linear(in_features=128, out_features=128)\n",
|
||
|
" ├── LayerNorm(normalized_shape=(128,)) #2\n",
|
||
|
" ├── (CHAIN) PointsCrossAttention()\n",
|
||
|
" │ ├── (PAR)\n",
|
||
|
" │ │ ├── Identity()\n",
|
||
|
" │ │ └── UseContext(context=vit, key=points_embedding) (x2)\n",
|
||
|
" │ ├── (DISTR)\n",
|
||
|
" │ │ └── Linear(in_features=128, out_features=128) (x3)\n",
|
||
|
" │ ├── ScaledDotProductAttention(num_heads=8)\n",
|
||
|
" │ └── Linear(in_features=128, out_features=128)\n",
|
||
|
" ├── LayerNorm(normalized_shape=(128,)) #3\n",
|
||
|
" └── (RES) FeedForward()\n",
|
||
|
" ├── Linear(in_features=128, out_features=512) #1\n",
|
||
|
" ├── SiLU()\n",
|
||
|
" └── Linear(in_features=512, out_features=128) #2"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 22,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"class PointEncoder(fl.Passthrough):\n",
|
||
|
" # Passthrough: the encoded points are only written to the context; the main input\n",
|
||
|
" # (the image) is passed on unchanged to the next layer.\n",
|
||
|
" def __init__(self, dim: int = 128) -> None:\n",
|
||
|
" self.dim = dim\n",
|
||
|
" super().__init__(\n",
|
||
|
" # read the raw (N, 2) points from the vit context instead of the layer input\n",
|
||
|
" fl.UseContext(\"vit\", \"points_tensor\"),\n",
|
||
|
" fl.Linear(2, dim),\n",
|
||
|
" fl.SiLU(),\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" fl.SiLU(),\n",
|
||
|
" fl.Linear(dim, dim),\n",
|
||
|
" fl.Unsqueeze(0),\n",
|
||
|
" # store the (1, N, dim) result under the key PointsCrossAttention reads from\n",
|
||
|
" fl.SetContext(\"vit\", \"points_embedding\"),\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"class ViT(fl.Chain):\n",
|
||
|
" # Vision Transformer whose layers cross-attend to an embedding of the prompt points.\n",
|
||
|
" def __init__(\n",
|
||
|
" self,\n",
|
||
|
" dim: int = 128,\n",
|
||
|
" patch_size: int = 16,\n",
|
||
|
" image_size: int = 224,\n",
|
||
|
" num_layers: int = 4,\n",
|
||
|
" ) -> None:\n",
|
||
|
" self.dim = dim\n",
|
||
|
" self.patch_size = patch_size\n",
|
||
|
" self.image_size = image_size\n",
|
||
|
" self.num_layers = num_layers\n",
|
||
|
" # one token per patch plus the class token (14 * 14 + 1 = 197 with the defaults)\n",
|
||
|
" self.num_patches = (image_size // patch_size) ** 2 + 1\n",
|
||
|
" super().__init__(\n",
|
||
|
" PointEncoder(dim=dim),\n",
|
||
|
" # patch tokens and the class token are concatenated along dim=1 (196 + 1 = 197 tokens)\n",
|
||
|
" fl.Concatenate(\n",
|
||
|
" PatchEncoder(in_channels=3, dim=dim, patch_size=patch_size),\n",
|
||
|
" ClassToken(dim=dim),\n",
|
||
|
" dim=1,\n",
|
||
|
" ),\n",
|
||
|
" # NOTE(review): dim is not forwarded to PositionalToken; check its signature before using dim != 128\n",
|
||
|
" PositionalToken(num_patches=self.num_patches),\n",
|
||
|
" # num_heads and inner_dim use the TranformerLayer defaults\n",
|
||
|
" Transformer(TranformerLayer(dim=dim) for _ in range(num_layers)),\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"vit = ViT()\n",
|
||
|
"x = torch.randn(1, 3, 224, 224)\n",
|
||
|
"points = torch.randn(5, 2)\n",
|
||
|
"# the raw points go through the context; PointEncoder turns them into points_embedding\n",
|
||
|
"vit.set_context(\"vit\", {\"points_tensor\": points})\n",
|
||
|
"print(vit(x).shape)\n",
|
||
|
"vit"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"attachments": {
|
||
|
"image.png": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA3YAAAOiCAYAAAA8AkW0AAAAAXNSR0IArs4c6QAAAERlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAA6ABAAMAAAABAAEAAKACAAQAAAABAAADdqADAAQAAAABAAADogAAAADmkWVMAABAAElEQVR4AezdB9wUxfnA8YeiNAVBROlVrKjwImIvscReokZNorGhxpioiRqjiSVGjSb6jzGKGE1ib4nB2LuxgrwgFhRRQEQ6SO/K/56FWeb2yntl725n9zefz3l7W2ZnvnMv7nOzO9NoTSoJCQEEEEAAAQQQQAABBBBAwFmBxs6WnIIjgAACCCCAAAIIIIAAAgh4AgR2fBEQQAABBBBAAAEEEEAAAccFCOwcb0CKjwACCCCAAAIIIIAAAggQ2PEdQAABBBBAAAEEEEAAAQQcFyCwc7wBKT4CCCCAAAIIIIAAAgggQGDHdwABBBBAAAEEEEAAAQQQcFyAwM7xBqT4CCCAAAIIIIAAAggggACBHd8BBBBAAAEEEEAAAQQQQMBxAQI7xxuQ4iOAAAIIIIAAAggggAACBHZ8BxBAAAEEEEAAAQQQQAABxwUI7BxvQIqPAAIIIIAAAggggAACCBDY8R1AAAEEEEAAAQQQQAABBBwXILBzvAEpPgIIIIAAAggggAACCCBAYMd3AAEEEEAAAQQQQAABBBBwXIDAzvEGpPgIIIAAAggggAACCCCAAIEd3wEEEEAAAQQQQAABBBBAwHEBAjvHG5DiI4AAAggggAACCCCAAAIEdnwHEEAAAQQQQAABBBBAAAHHBQjsHG9Aio8AAggggAACCCCAAAIIENjxHUAAAQQQQAABBBBAAAEEHBcgsHO8ASk+AggggAACCCCAAAIIIEBgx3cAAQQQQAABBBBAAAEEEHBcgMDO8Qak+AgggAACCCCAAAIIIIAAgR3fAQQQQAABBBBAAAEEEEDAcQECO8cbkOIjgAACCCCAAAIIIIAAAgR2fAcQQAABBBBAAAEEEEAAAccFCOwcb0CKjwACCCCAAAIIIIAAAggQ2PEdQAABBBBAAAEEEEAAAQQcFyCwc7wBKT4CCCCAAAIIIIAAAgggQGDHdwABBBBAAAEEEEAAAQQQcFyAwM7xBqT4CCCAAAIIIIAAAggggACBHd8BBBBAAAEEEEAAAQQQQMBxAQI7xxuQ4iOAAAIIIIAAAggggAACBHZ8BxBAAAEEEEAAAQQQQAABxwUI7BxvQIqPAAIIIIAAAggggAACCBDY8R1AAAEEEEAAAQQQQAABBBwXILBzvAEpPgIIIIAAAggggAACCCBAYMd3AAEEEEAAAQQQQAABBBBwXIDAzvEGpPgIIIAAAggggAACCCCAAIEd3wEEEEAAAQQQQAABBBBAwHEBAjvHG5DiI4AAAggggAACCCCAAAIEdnwHEEAAAQQQQAABBBBAAAHHBQjsHG9Aio8AAggggAACCCCAAAIIENjxHUAAAQQQQAABBBBAAAEEHBcgsHO8ASk+AggggAACCCCAAAIIIEBgx3cAAQQQQAABBBBAAAEEEHBcgMDO8Qak+AgggAACCCCAAAIIIIAAgR3fAQQQQAABBBBAAAEEEEDAcQECO8cbkOIjgAACCCCAAAIIIIAAAgR2fAcQQAABBBBAAAEEEEAAAccFCOwcb0CKjwACCCCAAAIIIIAAAggQ2PEdQAABBBBAAAEEEEAAAQQcFyCwc7wBKT4CCCCAAAIIIIAAAgggQGDHdwABBBBAAAEEEEAAAQQQcFyAwM7xBqT4CCCAAAIIIIAAAggggACBHd8BBBBAAAEEEEAAAQQQQMBxAQI7xxuQ4iOAAAIIIIAAAggggAACBHZ8BxBAAAEEEEAAAQQQQAABxwUI7BxvQIqPAAIIIIAAAggggAACCBDY8R1AAAEEEEAAAQQQQAABBBwXILBzvAEpPgIIIIAAAggggAACCCDQFAIEEE
AAgegJjB49Omeh6uvrc27LtaHYY+rq6nJlVdb6MPIdMGBAWWXgYAQQQAABBOIo0GhNKsWxYtQJAQQQcEXgzjvvFH2Rai9w5plnir5ICCCAAAIIuCZAYOdai1FeBBCIncCgQYNEe6HC6M1qCKfYnrt8+eXrVcx3XDW3FdK7Z9yNzdChQ6tZRM6FAAIIIIBAKALcihkKI5kggAAC5QlocFGNnqJqnKM8idodbQK72pWAMyOAAAIIIFC6AIOnlG7HkQgggEDZAi70epVdSTJAAAEEEEAAgYoLENhVnJgTIIAAArkF6CXKbVPtLQTZ1RbnfAgggAACYQoQ2IWpSV4IIIAAAk4LENw53XwUHgEEEEi0AIFdopufyiOAAAIIIIAAAggggEAcBAjs4tCK1AEBBJwV4FZMZ5uOgiOAAAIIIBApAQK7SDUHhUEAAQQQQAABBBBAAAEEihcgsCvejCMQQACB0AR4pis0SjJCAAEEEEAg0QIEdolufiqPAAIIIBAUINgOivAZAQQQQMAFAQI7F1qJMiKAAAIIIIAAAggggAACeQQI7PLgsAkBBBBAAAEEEEAAAQQQcEGAwM6FVqKMCCCAAAIIIIAAAggggEAeAQK7PDhsQgABBCopwLNcldQlbwQQQAABBJIlQGCXrPamtggggAACCCCAAAIIIBBDAQK7GDYqVUIAAQQQKE6A3tPivNgbAQQQQCB6AgR20WsTSoQAAggggAACCCCAAAIIFCVAYFcUFzsjgAACCCCAAAIIIIAAAtETILCLXptQIgQQQAABBBBAAAEEEECgKAECu6K42BkBBBAIT6C+vj68zMgJAQQQQAABBBItQGCX6Oan8ggggAACKjBgwAAgEEAAAQQQcFqAwM7p5qPwCCCAAAIIIIAAAggggIAIgR3fAgQQQAABBBBAAAEEEEDAcQECO8cbkOIjgAACCCCAAAIIIIAAAgR2fAcQQAABBBBAAAEEEEAAAccFCOwcb0CKjwACCCAQrgADqYTrSW4IIIAAAtURILCrjjNnQQABBBBAAAEEEEAAAQQqJkBgVzFaMkYAAQQQQAABBBBAAAEEqiNAYFcdZ86CAAIIIIAAAggggAACCFRMgMCuYrRkjAACCCCAAAIIIIAAAghUR4DArjrOnAUBBBBAAAEEEEAAAQQQqJgAgV3FaMkYAQQQyC9QV1eXfwe2IoAAAggggAACBQo0LXA/dkMAAQQQQCDWAkxzEOvmpXIIIIBA7AXosYt9E1NBBBBAAAEEEEAAAQQQiLsAgV3cW5j6IYAAAggggAACCCCAQOwFuBUz9k1MBRFAAIHKCrz22msya9Ysady4sRx99NHee2XPSO4IIIAAAgggEBQgsAuK8BkBBBBAoCiBe++9V8aOHesdc+SRRxLYFaXHzggggAACCIQjwK2Y4TiSCwIIIIAAAggggAACCCBQMwECu5rRc2IEEEAAAQQQQAABBBBAIBwBArtwHMkFAQQQQMBxgdGjRzteA4qPAAIIIJBkAQK7JLc+dUcAAQQQQAABBBBAAIFYCDB4SiyakUoggAAClReYN2+erFmzJuNEK1as8NfNnTtXmjbN/F9LmzZtsq73D2QBAQQQQAABBMoSyPy/b1nZcTACCCCAQBwFFixYIAcddFCDVTvssMOy7nPTTTfJnnvumXUbKxFAAAEEEECgfAFuxSzfkBwQQACBkgQGDBjgH1dXV+cvR3EhW09dFMtJmRBAAAEEEEiqAD12SW156o0AAggUIdC8eXPR3rhvv/024yidoHzJkiXeeu3Va9KkScY+HTt2zFgXxRUMoBLFVqFMCCCAAAKFCBDYFaLEPggggEDCBTSwu+KKK7IqnHHGGf4E5VdeeSXP0mVVYiUCCCCAAAKVFeBWzMr6kjsCCCCAgAMC9NQ50EgUEQEEEEAgrwCBXV4eNiKAAAIIIIAAAggggAAC0RcgsIt+G1FCBBBAAAEEEEAAAQQQQCCvAIFdXh42IoAAAggggAACCC
CAAALRFyCwi34bUUIEEIixgD3lgavVzDYhuat1odwIIIAAAgi4KsComK62HOVGAAEEIiIwdOjQiJSEYiCAAAIIIJBc
|
||
|
}
|
||
|
},
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Adaptation\n",
|
||
|
"\n",
|
||
|
"Having a very explicit and declarative model, as showcased here, is an interesting property in itself, but where it shines the most is when you want to adapt a model to a new task.\n",
|
||
|
"\n",
|
||
|
"Let's demonstrate, with a simple example, how to create a LoRA adaptation of the ViT without having to rewrite the whole model.\n",
|
||
|
"\n",
|
||
|
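"Low-rank adaptation (LoRA) adds lightweight layers alongside the existing ones: each adapted `Linear` layer gets a bypass that projects its input down to `rank` features and then back up to the output size, and the bypass output is added to the original output. The rank is thus the inner dimension of the new layers. The up-projection is zero-initialized so that the model's output is unchanged before training.\n",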
|
||
|
"![image.png](attachment:image.png)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 23,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"torch.Size([1, 1, 128])\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(CHAIN) Lora(in_features=128, out_features=128)\n",
|
||
|
" ├── Linear(in_features=128, out_features=16) #1\n",
|
||
|
" └── Linear(in_features=16, out_features=128) #2"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 23,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"class Lora(fl.Chain):\n",
|
||
|
" def __init__(\n",
|
||
|
" self,\n",
|
||
|
" in_features: int,\n",
|
||
|
" out_features: int,\n",
|
||
|
" rank: int = 16,\n",
|
||
|
" ) -> None:\n",
|
||
|
" self.in_features = in_features\n",
|
||
|
" self.out_features = out_features\n",
|
||
|
" self.rank = rank\n",
|
||
|
" self.scale: float = 1.0\n",
|
||
|
"\n",
|
||
|
" super().__init__(\n",
|
||
|
" fl.Linear(in_features=in_features, out_features=rank, bias=False),\n",
|
||
|
" fl.Linear(in_features=rank, out_features=out_features),\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
" nn.init.normal_(tensor=self.Linear_1.weight, std=1 / self.rank)\n",
|
||
|
" nn.init.zeros_(tensor=self.Linear_2.weight)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"lora = Lora(128, 128)\n",
|
||
|
"x = torch.randn(1, 1, 128)\n",
|
||
|
"print(lora(x).shape)\n",
|
||
|
"lora"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Now we want to be able to insert this into any `Linear` layer of the model. To do that, we can use the `Adapter` class."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 24,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n",
|
||
|
"Adapter:\n",
|
||
|
"(SUM) LoraAdapter()\n",
|
||
|
" ├── Linear(in_features=128, out_features=128)\n",
|
||
|
" └── (CHAIN) Lora(in_features=128, out_features=128)\n",
|
||
|
" ├── Linear(in_features=128, out_features=16) #1\n",
|
||
|
" └── Linear(in_features=16, out_features=128) #2\n",
|
||
|
"\n",
|
||
|
"Note that the original attention is not modified:\n",
|
||
|
"(RES) Attention()\n",
|
||
|
" ├── (PAR)\n",
|
||
|
" │ └── Linear(in_features=128, out_features=128) (x3)\n",
|
||
|
" ├── ScaledDotProductAttention(num_heads=8)\n",
|
||
|
" └── Linear(in_features=128, out_features=128) \n",
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from refiners.fluxion.adapters import Adapter\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"class LoraAdapter(fl.Sum, Adapter[fl.Linear]):\n",
|
||
|
" # Wraps a Linear: the (SUM) node adds the output of the original Linear and of a Lora bypass\n",
|
||
|
" # with matching in/out features.\n",
|
||
|
" def __init__(\n",
|
||
|
" self,\n",
|
||
|
" target: fl.Linear,\n",
|
||
|
" rank: int = 16,\n",
|
||
|
" ) -> None:\n",
|
||
|
" self.in_features = target.in_features\n",
|
||
|
" self.out_features = target.out_features\n",
|
||
|
" self.rank = rank\n",
|
||
|
" # the setup_adapter method is used to remove boilerplate code\n",
|
||
|
" with self.setup_adapter(target):\n",
|
||
|
" super().__init__(\n",
|
||
|
" target,\n",
|
||
|
" Lora(\n",
|
||
|
" in_features=target.in_features,\n",
|
||
|
" out_features=target.out_features,\n",
|
||
|
" rank=rank,\n",
|
||
|
" ),\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"attention = Attention()\n",
|
||
|
"# the first Linear found in the attention block\n",
|
||
|
"linear = attention.ensure_find(fl.Linear)\n",
|
||
|
"# creating the adapter leaves attention untouched, as printed below\n",
|
||
|
"adapter = LoraAdapter(linear)\n",
|
||
|
"print(\n",
|
||
|
" f\"\"\"\n",
|
||
|
"Adapter:\n",
|
||
|
"{repr(adapter)}\n",
|
||
|
"\n",
|
||
|
"Note that the original attention is not modified:\n",
|
||
|
"{repr(attention)} \n",
|
||
|
"\"\"\"\n",
|
||
|
")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Let's now `inject` the `Adapter` into the `Attention` layer we just created. One subtlety is that the `Linear` layer is considered a `WeightedModule` and, as such, can belong to multiple `Chain`s at the same time. So we need to specify which `Chain` we want to inject the `Adapter` into: here, the `Parallel` chain of the attention block."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 25,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"(RES) Attention()\n",
|
||
|
" ├── (PAR)\n",
|
||
|
" │ ├── (SUM) LoraAdapter()\n",
|
||
|
" │ │ ├── Linear(in_features=128, out_features=128)\n",
|
||
|
" │ │ └── (CHAIN) Lora(in_features=128, out_features=128)\n",
|
||
|
" │ │ ├── Linear(in_features=128, out_features=16) #1\n",
|
||
|
" │ │ └── Linear(in_features=16, out_features=128) #2\n",
|
||
|
" │ └── Linear(in_features=128, out_features=128) (x2)\n",
|
||
|
" ├── ScaledDotProductAttention(num_heads=8)\n",
|
||
|
" └── Linear(in_features=128, out_features=128)\n",
|
||
|
"(RES) Attention()\n",
|
||
|
" ├── (PAR)\n",
|
||
|
" │ └── Linear(in_features=128, out_features=128) (x3)\n",
|
||
|
" ├── ScaledDotProductAttention(num_heads=8)\n",
|
||
|
" └── Linear(in_features=128, out_features=128)\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# a Linear does not track a unique parent, so the Chain to inject into (the attention's Parallel) is given explicitly\n",
|
||
|
"adapter.inject(parent=attention.Parallel)\n",
|
||
|
"print(repr(attention))\n",
|
||
|
"\n",
|
||
|
"# we can also `eject` the adapter to get back to normal\n",
|
||
|
"adapter.eject()\n",
|
||
|
"print(repr(attention))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Finally, let's write a top-level adapter that injects a `LoraAdapter` into every `Linear` layer of the `ViT`."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 26,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"(CHAIN) ViT()\n",
|
||
|
" ├── (PASS) PointEncoder()\n",
|
||
|
" │ ├── UseContext(context=vit, key=points_tensor)\n",
|
||
|
" │ ├── (SUM) LoraAdapter() #1\n",
|
||
|
" │ │ ├── Linear(in_features=2, out_features=128)\n",
|
||
|
" │ │ └── (CHAIN) Lora(in_features=2, out_features=128)\n",
|
||
|
" │ │ ├── Linear(in_features=2, out_features=16) #1\n",
|
||
|
" │ │ └── Linear(in_features=16, out_features=128) #2\n",
|
||
|
" │ ├── SiLU() #1\n",
|
||
|
" │ ├── (SUM) LoraAdapter() #2\n",
|
||
|
" │ │ ├── Linear(in_features=128, out_features=128)\n",
|
||
|
" │ │ └── (CHAIN) Lora(in_features=128, out_features=128)\n",
|
||
|
" │ │ ├── Linear(in_features=128, out_features=16) #1\n",
|
||
|
" │ │ └── Linear(in_features=16, out_features=128) #2\n",
|
||
|
" │ ├── SiLU() #2\n",
|
||
|
" │ ├── (SUM) LoraAdapter() #3\n",
|
||
|
" │ │ ├── Linear(in_features=128, out_features=128)\n",
|
||
|
" │ │ └── (CHAIN) Lora(in_features=128, out_features=128)\n",
|
||
|
" │ │ ├── Linear(in_features=128, out_features=16) #1\n",
|
||
|
" │ │ └── Linear(in_features=16, out_features=128) #2\n",
|
||
|
" │ ├── Unsqueeze(dim=0)\n",
|
||
|
" │ └── SetContext(context=vit, key=points_embedding)\n",
|
||
|
" ├── (CAT)\n",
|
||
|
" │ ├── (CHAIN) PatchEncoder()\n",
|
||
|
" │ │ ├── Conv2d(in_channels=3, out_channels=128, kernel_size=(16, 16), stride=(16, 16))\n",
|
||
|
" │ │ └── Reshape(shape=(-1, 128))\n",
|
||
|
" │ └── (CHAIN) ClassToken()\n",
|
||
|
" │ └── Parameter(dims=(1, 128))\n",
|
||
|
" ├── (RES) PositionalToken(num_patches=197)\n",
|
||
|
" │ └── Parameter(dims=(197, 128))\n",
|
||
|
" └── (CHAIN) Transformer()\n",
|
||
|
" └── (CHAIN) TranformerLayer() (x4)\n",
|
||
|
" ├── LayerNorm(normalized_shape=(128,)) #1\n",
|
||
|
" ├── (RES) Attention()\n",
|
||
|
" │ ├── (PAR)\n",
|
||
|
" │ │ └── (SUM) LoraAdapter() (x3)\n",
|
||
|
" │ │ ├── Linear(in_features=128, out_features=128)\n",
|
||
|
" │ │ └── (CHAIN) Lora(in_features=128, out_features=128)\n",
|
||
|
" │ │ ├── Linear(in_features=128, out_features=16) #1\n",
|
||
|
" │ │ └── Linear(in_features=16, out_features=128) #2\n",
|
||
|
" │ ├── ScaledDotProductAttention(num_heads=8)\n",
|
||
|
" │ └── (SUM) LoraAdapter()\n",
|
||
|
" │ ├── Linear(in_features=128, out_features=128)\n",
|
||
|
" │ └── (CHAIN) Lora(in_features=128, out_features=128)\n",
|
||
|
" │ ├── Linear(in_features=128, out_features=16) #1\n",
|
||
|
" │ └── Linear(in_features=16, out_features=128) #2\n",
|
||
|
" ├── LayerNorm(normalized_shape=(128,)) #2\n",
|
||
|
" ├── (CHAIN) PointsCrossAttention()\n",
|
||
|
" │ ├── (PAR)\n",
|
||
|
" │ │ ├── Identity()\n",
|
||
|
" │ │ └── UseContext(context=vit, key=points_embedding) (x2)\n",
|
||
|
" │ ├── (DISTR)\n",
|
||
|
" │ │ └── (SUM) LoraAdapter() (x3)\n",
|
||
|
" │ │ ├── Linear(in_features=128, out_features=128)\n",
|
||
|
" │ │ └── (CHAIN) Lora(in_features=128, out_features=128)\n",
|
||
|
" │ │ ├── Linear(in_features=128, out_features=16) #1\n",
|
||
|
" │ │ └── Linear(in_features=16, out_features=128) #2\n",
|
||
|
" │ ├── ScaledDotProductAttention(num_heads=8)\n",
|
||
|
" │ └── (SUM) LoraAdapter()\n",
|
||
|
" │ ├── Linear(in_features=128, out_features=128)\n",
|
||
|
" │ └── (CHAIN) Lora(in_features=128, out_features=128)\n",
|
||
|
" │ ├── Linear(in_features=128, out_features=16) #1\n",
|
||
|
" │ └── Linear(in_features=16, out_features=128) #2\n",
|
||
|
" ├── LayerNorm(normalized_shape=(128,)) #3\n",
|
||
|
" └── (RES) FeedForward()\n",
|
||
|
" ├── (SUM) LoraAdapter() #1\n",
|
||
|
" │ ├── Linear(in_features=128, out_features=512)\n",
|
||
|
" │ └── (CHAIN) Lora(in_features=128, out_features=512)\n",
|
||
|
" │ ├── Linear(in_features=128, out_features=16) #1\n",
|
||
|
" │ └── Linear(in_features=16, out_features=512) #2\n",
|
||
|
" ├── SiLU()\n",
|
||
|
" └── (SUM) LoraAdapter() #2\n",
|
||
|
" ├── Linear(in_features=512, out_features=128)\n",
|
||
|
" └── (CHAIN) Lora(in_features=512, out_features=128)\n",
|
||
|
" ├── Linear(in_features=512, out_features=16) #1\n",
|
||
|
" └── Linear(in_features=16, out_features=128) #2\n",
|
||
|
"torch.Size([1, 197, 128])\n",
|
||
|
"(CHAIN) ViT()\n",
|
||
|
" ├── (PASS) PointEncoder()\n",
|
||
|
" │ ├── UseContext(context=vit, key=points_tensor)\n",
|
||
|
" │ ├── Linear(in_features=2, out_features=128) #1\n",
|
||
|
" │ ├── SiLU() #1\n",
|
||
|
" │ ├── Linear(in_features=128, out_features=128) #2\n",
|
||
|
" │ ├── SiLU() #2\n",
|
||
|
" │ ├── Linear(in_features=128, out_features=128) #3\n",
|
||
|
" │ ├── Unsqueeze(dim=0)\n",
|
||
|
" │ └── SetContext(context=vit, key=points_embedding)\n",
|
||
|
" ├── (CAT)\n",
|
||
|
" │ ├── (CHAIN) PatchEncoder()\n",
|
||
|
" │ │ ├── Conv2d(in_channels=3, out_channels=128, kernel_size=(16, 16), stride=(16, 16))\n",
|
||
|
" │ │ └── Reshape(shape=(-1, 128))\n",
|
||
|
" │ └── (CHAIN) ClassToken()\n",
|
||
|
" │ └── Parameter(dims=(1, 128))\n",
|
||
|
" ├── (RES) PositionalToken(num_patches=197)\n",
|
||
|
" │ └── Parameter(dims=(197, 128))\n",
|
||
|
" └── (CHAIN) Transformer()\n",
|
||
|
" └── (CHAIN) TranformerLayer() (x4)\n",
|
||
|
" ├── LayerNorm(normalized_shape=(128,)) #1\n",
|
||
|
" ├── (RES) Attention()\n",
|
||
|
" │ ├── (PAR)\n",
|
||
|
" │ │ └── Linear(in_features=128, out_features=128) (x3)\n",
|
||
|
" │ ├── ScaledDotProductAttention(num_heads=8)\n",
|
||
|
" │ └── Linear(in_features=128, out_features=128)\n",
|
||
|
" ├── LayerNorm(normalized_shape=(128,)) #2\n",
|
||
|
" ├── (CHAIN) PointsCrossAttention()\n",
|
||
|
" │ ├── (PAR)\n",
|
||
|
" │ │ ├── Identity()\n",
|
||
|
" │ │ └── UseContext(context=vit, key=points_embedding) (x2)\n",
|
||
|
" │ ├── (DISTR)\n",
|
||
|
" │ │ └── Linear(in_features=128, out_features=128) (x3)\n",
|
||
|
" │ ├── ScaledDotProductAttention(num_heads=8)\n",
|
||
|
" │ └── Linear(in_features=128, out_features=128)\n",
|
||
|
" ├── LayerNorm(normalized_shape=(128,)) #3\n",
|
||
|
" └── (RES) FeedForward()\n",
|
||
|
" ├── Linear(in_features=128, out_features=512) #1\n",
|
||
|
" ├── SiLU()\n",
|
||
|
" └── Linear(in_features=512, out_features=128) #2\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from typing import Self\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"class ViTLoraAdapter(fl.Chain, Adapter[ViT]):\n",
|
||
|
"    \"\"\"Adapter that gives every `fl.Linear` of a `ViT` its own `LoraAdapter`.\n",
|
||
|
"\n",
|
||
|
"    The per-layer `LoraAdapter`s are created in `__init__`, but they are only\n",
|
||
|
"    inserted into the model by `inject` and taken out again by `eject`.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"\n",
|
||
|
"    def __init__(\n",
|
||
|
"        self,\n",
|
||
|
"        target: ViT,\n",
|
||
|
"        rank: int = 16,\n",
|
||
|
"    ) -> None:\n",
|
||
|
"        # kept on the instance; the same value is forwarded to every LoraAdapter below\n",
|
||
|
"        self.rank = rank\n",
|
||
|
"        with self.setup_adapter(target):\n",
|
||
|
"            super().__init__(target)\n",
|
||
|
"\n",
|
||
|
"        # pair every Linear layer of the target with the chain that holds it, so that\n",
|
||
|
"        # `inject` knows where each layer's LoraAdapter has to be inserted\n",
|
||
|
"        self.sub_adapters: list[tuple[LoraAdapter, fl.Chain]] = [\n",
|
||
|
"            (LoraAdapter(target=layer, rank=rank), holder)\n",
|
||
|
"            for layer, holder in self.target.walk(fl.Linear)\n",
|
||
|
"        ]\n",
|
||
|
"\n",
|
||
|
"    def inject(self, parent: fl.Chain | None = None) -> Self:\n",
|
||
|
"        \"\"\"Inject each per-layer adapter into its own chain, then this adapter into `parent`.\"\"\"\n",
|
||
|
"        for lora, holder in self.sub_adapters:\n",
|
||
|
"            lora.inject(holder)\n",
|
||
|
"        return super().inject(parent)\n",
|
||
|
"\n",
|
||
|
"    def eject(self) -> None:\n",
|
||
|
"        \"\"\"Eject each per-layer adapter first, then this adapter itself.\"\"\"\n",
|
||
|
"        for lora, _ in self.sub_adapters:\n",
|
||
|
"            lora.eject()\n",
|
||
|
"        super().eject()\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"vit = ViT()\n",
|
||
|
"image = torch.randn(1, 3, 224, 224)\n",
|
||
|
"point_coords = torch.randn(5, 2)\n",
|
||
|
"vit.set_context(\"vit\", {\"points_tensor\": point_coords})\n",
|
||
|
"\n",
|
||
|
"adapter = ViTLoraAdapter(vit)\n",
|
||
|
"adapter.inject()  # `vit` is a root module (no parent), so `inject` takes no argument here\n",
|
||
|
"print(repr(vit))\n",
|
||
|
"print(vit(image).shape)\n",
|
||
|
"\n",
|
||
|
"# `eject` takes the LoRA layers out again, leaving the plain ViT\n",
|
||
|
"adapter.eject()\n",
|
||
|
"print(repr(vit))"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.11.5"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 2
|
||
|
}
|