test all chain manipulation methods

This commit is contained in:
Pierre Chapuis 2023-08-23 17:13:47 +02:00
parent 802970e79a
commit d311f779c0

View file

@ -1,14 +1,32 @@
import pytest
import torch import torch
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
from refiners.fluxion.context import Contexts
class ContextChain(fl.Chain):
def init_context(self) -> Contexts:
return {"foo": {"bar": [42]}}
def module_keys(chain: fl.Chain) -> list[str]:
return list(chain._modules.keys()) # type: ignore[reportPrivateUsage]
def test_chain_find() -> None: def test_chain_find() -> None:
chain = fl.Chain(fl.Linear(1, 1)) chain = fl.Chain(fl.Linear(1, 1))
assert isinstance(chain.find(fl.Linear), fl.Linear) assert chain.find(fl.Linear) == chain.Linear
assert chain.find(fl.Conv2d) is None assert chain.find(fl.Conv2d) is None
def test_chain_find_parent():
chain = fl.Chain(fl.Chain(fl.Linear(1, 1)))
assert chain.find_parent(chain.Chain.Linear) == chain.Chain
assert chain.find_parent(fl.Linear(1, 1)) is None
def test_chain_slice() -> None: def test_chain_slice() -> None:
chain = fl.Chain( chain = fl.Chain(
fl.Linear(1, 1), fl.Linear(1, 1),
@ -59,21 +77,91 @@ def test_chain_layers() -> None:
assert len(list(chain.layers(fl.Chain, recurse=True))) == 4 assert len(list(chain.layers(fl.Chain, recurse=True))) == 4
def test_chain_insert() -> None:
parent = ContextChain(fl.Linear(1, 1), fl.Linear(1, 1))
child = fl.Chain()
parent.insert(1, child)
assert module_keys(parent) == ["Linear_1", "Chain", "Linear_2"]
assert child.parent == parent
assert child.provider.get_context("foo") == {"bar": [42]}
def test_chain_insert_negative() -> None:
parent = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
child = fl.Chain()
parent.insert(-2, child)
assert module_keys(parent) == ["Linear_1", "Chain", "Linear_2"]
def test_chain_insert_after_type() -> None:
child = fl.Chain()
parent_1 = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
parent_1.insert_after_type(fl.Linear, child)
assert module_keys(parent_1) == ["Linear_1", "Chain", "Linear_2"]
parent_2 = fl.Chain(fl.Conv2d(1, 1, 1), fl.Linear(1, 1))
parent_2.insert_after_type(fl.Linear, child)
assert module_keys(parent_2) == ["Conv2d", "Linear", "Chain"]
def test_chain_insert_overflow() -> None:
# This behaves as insert() in lists in Python.
child = fl.Chain()
parent_1 = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
parent_1.insert(42, child)
assert module_keys(parent_1) == ["Linear_1", "Linear_2", "Chain"]
parent_2 = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
parent_2.insert(-42, child)
assert module_keys(parent_2) == ["Chain", "Linear_1", "Linear_2"]
def test_chain_append() -> None:
child = fl.Chain()
parent = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
parent.append(child)
assert module_keys(parent) == ["Linear_1", "Linear_2", "Chain"]
def test_chain_pop() -> None:
chain = fl.Chain(fl.Linear(1, 1), fl.Conv2d(1, 1, 1), fl.Chain())
with pytest.raises(IndexError):
chain.pop(3)
with pytest.raises(IndexError):
chain.pop(-4)
assert module_keys(chain) == ["Linear", "Conv2d", "Chain"]
chain.pop(1)
assert module_keys(chain) == ["Linear", "Chain"]
chain.pop(-2)
assert module_keys(chain) == ["Chain"]
def test_chain_remove() -> None: def test_chain_remove() -> None:
chain = fl.Chain( child = fl.Linear(1, 1)
fl.Linear(1, 1),
parent = fl.Chain(
fl.Linear(1, 1), fl.Linear(1, 1),
child,
fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)), fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
) )
assert len(chain) == 3 assert child in parent
assert "Linear_1" in chain._modules assert module_keys(parent) == ["Linear_1", "Linear_2", "Chain"]
assert "Linear" not in chain._modules
chain.remove(chain.Linear_2) parent.remove(child)
assert len(chain) == 2
assert "Linear" in chain._modules assert child not in parent
assert "Linear_1" not in chain._modules assert module_keys(parent) == ["Linear", "Chain"]
def test_chain_replace() -> None: def test_chain_replace() -> None: