# refiners/tests/fluxion/layers/test_chain.py
# 2023-09-25 13:54:26 +02:00
#
# 220 lines
# 5.5 KiB
# Python

import pytest
import torch
import refiners.fluxion.layers as fl
from refiners.fluxion.context import Contexts
class ContextChain(fl.Chain):
    """Chain that provides a fixed ``foo`` context for the insertion tests."""

    def init_context(self) -> Contexts:
        # One nested context entry; test_chain_insert checks an inserted child can read it.
        foo_context = {"bar": [42]}
        return {"foo": foo_context}
def module_keys(chain: fl.Chain) -> list[str]:
    """Return the names under which ``chain`` currently registers its children."""
    return list(chain._modules)  # type: ignore[reportPrivateUsage]
def test_chain_find() -> None:
    """`find` returns the layer of the requested type, or None when there is none."""
    linear_chain = fl.Chain(fl.Linear(1, 1))
    found = linear_chain.find(fl.Linear)
    assert found == linear_chain.Linear
    assert linear_chain.find(fl.Conv2d) is None
def test_chain_find_parent() -> None:
    """`find_parent` returns the chain directly holding a module.

    A module that was never added to the chain yields None.
    """
    chain = fl.Chain(fl.Chain(fl.Linear(1, 1)))
    inner = chain.Chain
    assert chain.find_parent(inner.Linear) == inner
    # A freshly built Linear is not part of `chain`.
    detached = fl.Linear(1, 1)
    assert chain.find_parent(detached) is None
def test_chain_slice() -> None:
    """Slicing yields a sub-chain and leaves the original chain's length unchanged."""
    chain = fl.Chain(
        *[fl.Linear(1, 1) for _ in range(3)],
        fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
        fl.Linear(1, 1),
    )
    x = torch.randn(1, 1)
    middle = chain[1:4]
    assert len(chain) == 5
    assert len(middle) == 3
    # Dropping the trailing layer still produces a runnable chain.
    all_but_last = chain[:-1]
    assert all_but_last(x).shape == (1, 1)
def test_chain_walk_stop_iteration() -> None:
    """A predicate that raises StopIteration keeps `walk` out of that module's subtree.

    The chain holds three Linears in total, two of them under the Sum. Once the
    predicate bails out at the Sum, only the top-level Linear is yielded.
    """
    chain = fl.Chain(
        fl.Sum(
            fl.Chain(fl.Linear(1, 1)),
            fl.Linear(1, 1),
        ),
        fl.Chain(),
        fl.Linear(1, 1),
    )

    def stop_at_sum(m: fl.Module, p: fl.Chain) -> bool:
        if not isinstance(m, fl.Sum):
            return isinstance(m, fl.Linear)
        raise StopIteration

    every_linear = list(chain.walk(fl.Linear))
    assert len(every_linear) == 3
    outside_sum = list(chain.walk(stop_at_sum))
    assert len(outside_sum) == 1
def test_chain_layers() -> None:
    """`layers` lists direct children of a type; `recurse=True` also counts nested ones."""
    nested = fl.Chain(fl.Chain(fl.Chain()))
    chain = fl.Chain(nested, fl.Chain(), fl.Linear(1, 1))

    direct = list(chain.layers(fl.Chain))
    assert len(direct) == 2
    recursive = list(chain.layers(fl.Chain, recurse=True))
    assert len(recursive) == 4
def test_chain_insert() -> None:
    """`insert` places the child at the index and sets its parent.

    The inserted child can then read the parent's context through its provider.
    """
    parent = ContextChain(fl.Linear(1, 1), fl.Linear(1, 1))
    child = fl.Chain()
    parent.insert(1, child)

    expected_keys = ["Linear_1", "Chain", "Linear_2"]
    assert module_keys(parent) == expected_keys
    assert child.parent == parent
    foo_context = child.provider.get_context("foo")
    assert foo_context == {"bar": [42]}
def test_chain_insert_negative() -> None:
    """A negative insertion index is resolved relative to the end of the chain.

    With two layers, ``insert(-2, ...)`` puts the child between them. This differs
    from ``list.insert``, where -2 on a two-element list would put it first.
    """
    parent = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
    parent.insert(-2, fl.Chain())
    assert module_keys(parent) == ["Linear_1", "Chain", "Linear_2"]
def test_chain_insert_after_type() -> None:
    """`insert_after_type` places the child right after the first layer of that type."""
    # The same child instance is reused for both parents below.
    child = fl.Chain()

    # Two Linears: the child lands after the first one.
    two_linears = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
    two_linears.insert_after_type(fl.Linear, child)
    assert module_keys(two_linears) == ["Linear_1", "Chain", "Linear_2"]

    # Conv2d then Linear: the only Linear is last, so the child ends up at the end.
    conv_then_linear = fl.Chain(fl.Conv2d(1, 1, 1), fl.Linear(1, 1))
    conv_then_linear.insert_after_type(fl.Linear, child)
    assert module_keys(conv_then_linear) == ["Conv2d", "Linear", "Chain"]
def test_chain_insert_before_type() -> None:
    """`insert_before_type` places the child right before the first layer of that type."""
    # The same child instance is reused for both parents below.
    child = fl.Chain()

    # Two Linears: the child lands before the first one, i.e. at the front.
    two_linears = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
    two_linears.insert_before_type(fl.Linear, child)
    assert module_keys(two_linears) == ["Chain", "Linear_1", "Linear_2"]

    # Conv2d then Linear: the child is slotted between them.
    conv_then_linear = fl.Chain(fl.Conv2d(1, 1, 1), fl.Linear(1, 1))
    conv_then_linear.insert_before_type(fl.Linear, child)
    assert module_keys(conv_then_linear) == ["Conv2d", "Chain", "Linear"]
def test_chain_insert_overflow() -> None:
    """Out-of-range insertion indices are clamped, mirroring Python's ``list.insert``."""
    child = fl.Chain()

    # Index far past the end: the child is appended.
    past_end = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
    past_end.insert(42, child)
    assert module_keys(past_end) == ["Linear_1", "Linear_2", "Chain"]

    # Index far before the start: the child is prepended.
    before_start = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
    before_start.insert(-42, child)
    assert module_keys(before_start) == ["Chain", "Linear_1", "Linear_2"]
def test_chain_append() -> None:
    """`append` adds the child after every existing layer."""
    child = fl.Chain()
    parent = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
    parent.append(child)
    expected_keys = ["Linear_1", "Linear_2", "Chain"]
    assert module_keys(parent) == expected_keys
def test_chain_pop() -> None:
    """`pop` rejects out-of-range indices and removes layers by positive or negative index."""
    chain = fl.Chain(fl.Linear(1, 1), fl.Conv2d(1, 1, 1), fl.Chain())

    # 3 and -4 fall just outside a 3-layer chain.
    for bad_index in (3, -4):
        with pytest.raises(IndexError):
            chain.pop(bad_index)
    # The failed pops left the chain untouched.
    assert module_keys(chain) == ["Linear", "Conv2d", "Chain"]

    chain.pop(1)
    assert module_keys(chain) == ["Linear", "Chain"]
    chain.pop(-2)
    assert module_keys(chain) == ["Chain"]
def test_chain_remove() -> None:
    """`remove` detaches a layer and renames the remaining same-typed sibling."""
    target = fl.Linear(1, 1)
    parent = fl.Chain(
        fl.Linear(1, 1),
        target,
        fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
    )

    assert target in parent
    assert module_keys(parent) == ["Linear_1", "Linear_2", "Chain"]

    parent.remove(target)

    assert target not in parent
    # A single Linear remains, so it loses its numeric suffix.
    assert module_keys(parent) == ["Linear", "Chain"]
def test_chain_replace() -> None:
    """`replace` swaps a nested layer in place without changing the outer chain's length."""
    chain = fl.Chain(
        fl.Linear(1, 1),
        fl.Linear(1, 1),
        fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
    )
    inner = chain.Chain
    assert isinstance(inner[1], fl.Linear)

    inner.replace(inner[1], fl.Conv2d(1, 1, 1))

    assert len(chain) == 3
    assert isinstance(chain.Chain[1], fl.Conv2d)
def test_chain_structural_copy() -> None:
    """`structural_copy` duplicates container chains but shares the leaf layers.

    Leaf layers (the Linears) compare equal between the original and the copy.
    The containers (the Sum and the root Chain) are distinct and parented to
    their own root. Because the weights are shared, the copy must reproduce the
    original's output.
    """
    m = fl.Chain(
        fl.Sum(
            fl.Linear(4, 8),
            fl.Linear(4, 8),
        ),
        fl.Linear(8, 12),
    )

    x = torch.randn(7, 4)
    y = m(x)
    assert y.shape == (7, 12)

    m2 = m.structural_copy()

    # Leaf layers are shared between the original and the copy...
    assert m.Linear == m2.Linear
    assert m.Sum.Linear_1 == m2.Sum.Linear_1
    assert m.Sum.Linear_2 == m2.Sum.Linear_2

    # ...while containers are distinct and each is parented to its own root.
    assert m.Sum != m2.Sum
    assert m != m2
    assert m.Sum.parent == m
    assert m2.Sum.parent == m2

    y2 = m2(x)
    assert y2.shape == (7, 12)
    # Shared weights applied to the same input must give bit-identical output.
    # torch.equal only returns a bool, so it has to be asserted to have any effect.
    assert torch.equal(y2, y)