mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
test all chain manipulation methods
This commit is contained in:
parent
802970e79a
commit
d311f779c0
|
@ -1,14 +1,32 @@
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.fluxion.context import Contexts
|
||||||
|
|
||||||
|
|
||||||
|
class ContextChain(fl.Chain):
|
||||||
|
def init_context(self) -> Contexts:
|
||||||
|
return {"foo": {"bar": [42]}}
|
||||||
|
|
||||||
|
|
||||||
|
def module_keys(chain: fl.Chain) -> list[str]:
|
||||||
|
return list(chain._modules.keys()) # type: ignore[reportPrivateUsage]
|
||||||
|
|
||||||
|
|
||||||
def test_chain_find() -> None:
|
def test_chain_find() -> None:
|
||||||
chain = fl.Chain(fl.Linear(1, 1))
|
chain = fl.Chain(fl.Linear(1, 1))
|
||||||
|
|
||||||
assert isinstance(chain.find(fl.Linear), fl.Linear)
|
assert chain.find(fl.Linear) == chain.Linear
|
||||||
assert chain.find(fl.Conv2d) is None
|
assert chain.find(fl.Conv2d) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_find_parent():
|
||||||
|
chain = fl.Chain(fl.Chain(fl.Linear(1, 1)))
|
||||||
|
|
||||||
|
assert chain.find_parent(chain.Chain.Linear) == chain.Chain
|
||||||
|
assert chain.find_parent(fl.Linear(1, 1)) is None
|
||||||
|
|
||||||
|
|
||||||
def test_chain_slice() -> None:
|
def test_chain_slice() -> None:
|
||||||
chain = fl.Chain(
|
chain = fl.Chain(
|
||||||
fl.Linear(1, 1),
|
fl.Linear(1, 1),
|
||||||
|
@ -59,21 +77,91 @@ def test_chain_layers() -> None:
|
||||||
assert len(list(chain.layers(fl.Chain, recurse=True))) == 4
|
assert len(list(chain.layers(fl.Chain, recurse=True))) == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_insert() -> None:
|
||||||
|
parent = ContextChain(fl.Linear(1, 1), fl.Linear(1, 1))
|
||||||
|
child = fl.Chain()
|
||||||
|
parent.insert(1, child)
|
||||||
|
|
||||||
|
assert module_keys(parent) == ["Linear_1", "Chain", "Linear_2"]
|
||||||
|
assert child.parent == parent
|
||||||
|
assert child.provider.get_context("foo") == {"bar": [42]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_insert_negative() -> None:
|
||||||
|
parent = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
|
||||||
|
child = fl.Chain()
|
||||||
|
parent.insert(-2, child)
|
||||||
|
|
||||||
|
assert module_keys(parent) == ["Linear_1", "Chain", "Linear_2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_insert_after_type() -> None:
|
||||||
|
child = fl.Chain()
|
||||||
|
|
||||||
|
parent_1 = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
|
||||||
|
parent_1.insert_after_type(fl.Linear, child)
|
||||||
|
assert module_keys(parent_1) == ["Linear_1", "Chain", "Linear_2"]
|
||||||
|
|
||||||
|
parent_2 = fl.Chain(fl.Conv2d(1, 1, 1), fl.Linear(1, 1))
|
||||||
|
parent_2.insert_after_type(fl.Linear, child)
|
||||||
|
assert module_keys(parent_2) == ["Conv2d", "Linear", "Chain"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_insert_overflow() -> None:
|
||||||
|
# This behaves as insert() in lists in Python.
|
||||||
|
|
||||||
|
child = fl.Chain()
|
||||||
|
|
||||||
|
parent_1 = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
|
||||||
|
parent_1.insert(42, child)
|
||||||
|
assert module_keys(parent_1) == ["Linear_1", "Linear_2", "Chain"]
|
||||||
|
|
||||||
|
parent_2 = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
|
||||||
|
parent_2.insert(-42, child)
|
||||||
|
assert module_keys(parent_2) == ["Chain", "Linear_1", "Linear_2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_append() -> None:
|
||||||
|
child = fl.Chain()
|
||||||
|
|
||||||
|
parent = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
|
||||||
|
parent.append(child)
|
||||||
|
assert module_keys(parent) == ["Linear_1", "Linear_2", "Chain"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_pop() -> None:
|
||||||
|
chain = fl.Chain(fl.Linear(1, 1), fl.Conv2d(1, 1, 1), fl.Chain())
|
||||||
|
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
chain.pop(3)
|
||||||
|
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
chain.pop(-4)
|
||||||
|
|
||||||
|
assert module_keys(chain) == ["Linear", "Conv2d", "Chain"]
|
||||||
|
chain.pop(1)
|
||||||
|
assert module_keys(chain) == ["Linear", "Chain"]
|
||||||
|
|
||||||
|
chain.pop(-2)
|
||||||
|
assert module_keys(chain) == ["Chain"]
|
||||||
|
|
||||||
|
|
||||||
def test_chain_remove() -> None:
|
def test_chain_remove() -> None:
|
||||||
chain = fl.Chain(
|
child = fl.Linear(1, 1)
|
||||||
fl.Linear(1, 1),
|
|
||||||
|
parent = fl.Chain(
|
||||||
fl.Linear(1, 1),
|
fl.Linear(1, 1),
|
||||||
|
child,
|
||||||
fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
|
fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(chain) == 3
|
assert child in parent
|
||||||
assert "Linear_1" in chain._modules
|
assert module_keys(parent) == ["Linear_1", "Linear_2", "Chain"]
|
||||||
assert "Linear" not in chain._modules
|
|
||||||
|
|
||||||
chain.remove(chain.Linear_2)
|
parent.remove(child)
|
||||||
assert len(chain) == 2
|
|
||||||
assert "Linear" in chain._modules
|
assert child not in parent
|
||||||
assert "Linear_1" not in chain._modules
|
assert module_keys(parent) == ["Linear", "Chain"]
|
||||||
|
|
||||||
|
|
||||||
def test_chain_replace() -> None:
|
def test_chain_replace() -> None:
|
||||||
|
|
Loading…
Reference in a new issue