mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
reordering (match chain.py order)
This commit is contained in:
parent
e05c410a86
commit
beacfe816b
|
@ -2,6 +2,63 @@ import torch
|
|||
import refiners.fluxion.layers as fl
|
||||
|
||||
|
||||
def test_chain_find() -> None:
|
||||
chain = fl.Chain(fl.Linear(1, 1))
|
||||
|
||||
assert isinstance(chain.find(fl.Linear), fl.Linear)
|
||||
assert chain.find(fl.Conv2d) is None
|
||||
|
||||
|
||||
def test_chain_slice() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Linear(1, 1),
|
||||
fl.Linear(1, 1),
|
||||
fl.Linear(1, 1),
|
||||
fl.Chain(
|
||||
fl.Linear(1, 1),
|
||||
fl.Linear(1, 1),
|
||||
),
|
||||
fl.Linear(1, 1),
|
||||
)
|
||||
|
||||
x = torch.randn(1, 1)
|
||||
sliced_chain = chain[1:4]
|
||||
|
||||
assert len(chain) == 5
|
||||
assert len(sliced_chain) == 3
|
||||
assert chain[:-1](x).shape == (1, 1)
|
||||
|
||||
|
||||
def test_chain_walk_stop_iteration() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Sum(
|
||||
fl.Chain(fl.Linear(1, 1)),
|
||||
fl.Linear(1, 1),
|
||||
),
|
||||
fl.Chain(),
|
||||
fl.Linear(1, 1),
|
||||
)
|
||||
|
||||
def predicate(m: fl.Module, p: fl.Chain) -> bool:
|
||||
if isinstance(m, fl.Sum):
|
||||
raise StopIteration
|
||||
return isinstance(m, fl.Linear)
|
||||
|
||||
assert len(list(chain.walk(fl.Linear))) == 3
|
||||
assert len(list(chain.walk(predicate))) == 1
|
||||
|
||||
|
||||
def test_chain_layers() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Chain(fl.Chain(fl.Chain())),
|
||||
fl.Chain(),
|
||||
fl.Linear(1, 1),
|
||||
)
|
||||
|
||||
assert len(list(chain.layers(fl.Chain))) == 2
|
||||
assert len(list(chain.layers(fl.Chain, recurse=True))) == 4
|
||||
|
||||
|
||||
def test_chain_remove() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Linear(1, 1),
|
||||
|
@ -60,60 +117,3 @@ def test_chain_structural_copy() -> None:
|
|||
y2 = m2(x)
|
||||
assert y2.shape == (7, 12)
|
||||
torch.equal(y2, y)
|
||||
|
||||
|
||||
def test_chain_find() -> None:
|
||||
chain = fl.Chain(fl.Linear(1, 1))
|
||||
|
||||
assert isinstance(chain.find(fl.Linear), fl.Linear)
|
||||
assert chain.find(fl.Conv2d) is None
|
||||
|
||||
|
||||
def test_chain_slice() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Linear(1, 1),
|
||||
fl.Linear(1, 1),
|
||||
fl.Linear(1, 1),
|
||||
fl.Chain(
|
||||
fl.Linear(1, 1),
|
||||
fl.Linear(1, 1),
|
||||
),
|
||||
fl.Linear(1, 1),
|
||||
)
|
||||
|
||||
x = torch.randn(1, 1)
|
||||
sliced_chain = chain[1:4]
|
||||
|
||||
assert len(chain) == 5
|
||||
assert len(sliced_chain) == 3
|
||||
assert chain[:-1](x).shape == (1, 1)
|
||||
|
||||
|
||||
def test_chain_walk_stop_iteration() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Sum(
|
||||
fl.Chain(fl.Linear(1, 1)),
|
||||
fl.Linear(1, 1),
|
||||
),
|
||||
fl.Chain(),
|
||||
fl.Linear(1, 1),
|
||||
)
|
||||
|
||||
def predicate(m: fl.Module, p: fl.Chain) -> bool:
|
||||
if isinstance(m, fl.Sum):
|
||||
raise StopIteration
|
||||
return isinstance(m, fl.Linear)
|
||||
|
||||
assert len(list(chain.walk(fl.Linear))) == 3
|
||||
assert len(list(chain.walk(predicate))) == 1
|
||||
|
||||
|
||||
def test_chain_layers() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Chain(fl.Chain(fl.Chain())),
|
||||
fl.Chain(),
|
||||
fl.Linear(1, 1),
|
||||
)
|
||||
|
||||
assert len(list(chain.layers(fl.Chain))) == 2
|
||||
assert len(list(chain.layers(fl.Chain, recurse=True))) == 4
|
||||
|
|
Loading…
Reference in a new issue