mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 14:18:46 +00:00
471ef91d1c
PyTorch chose to make it `Any` because they expect its users' code to be "highly dynamic" (https://github.com/pytorch/pytorch/pull/104321). That is not the case for us: in Refiners, untyped code goes against one of our core principles. Note that there is currently an open PR in PyTorch to return `Module | Tensor`, but in practice this is not always correct either (https://github.com/pytorch/pytorch/pull/115074). I also moved Residuals-related code from SD1 to latent_diffusion because SDXL should not depend on SD1.
308 lines
8 KiB
Python
308 lines
8 KiB
Python
import pytest
|
|
import torch
|
|
|
|
import refiners.fluxion.layers as fl
|
|
from refiners.fluxion.context import Contexts
|
|
|
|
|
|
class ContextChain(fl.Chain):
    """Chain whose ``init_context`` returns a fixed ``foo`` context.

    Used to check that a child inserted into this chain can read the parent's context.
    """

    def init_context(self) -> Contexts:
        # Build a fresh nested payload on every call so callers never share the list.
        bar_payload = [42]
        return {"foo": {"bar": bar_payload}}
|
|
|
|
|
|
def module_keys(chain: fl.Chain) -> list[str]:
    """Return the child-module names registered on ``chain``, in the order stored in ``_modules``."""
    # Iterating the (private) ``_modules`` mapping directly yields its keys.
    return list(chain._modules)  # type: ignore[reportPrivateUsage]
|
|
|
|
|
|
def test_chain_find() -> None:
    """``Chain.find`` returns the child of the requested type, or None when there is none."""
    linear_chain = fl.Chain(fl.Linear(1, 1))

    found = linear_chain.find(fl.Linear)
    assert found == linear_chain.Linear

    missing = linear_chain.find(fl.Conv2d)
    assert missing is None
|
|
|
|
|
|
def test_chain_getitem_accessor() -> None:
    """String indexing resolves generated child names and raises KeyError for unknown names."""
    chain = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))

    second_linear = chain.Linear_2
    assert chain["Linear_2"] == second_linear

    # Only two Linear children exist, so there is no "Linear_3".
    with pytest.raises(KeyError):
        chain["Linear_3"]
|
|
|
|
|
|
def test_chain_find_parent() -> None:
    """``find_parent`` returns the chain directly holding a module, or None for a detached module."""
    chain = fl.Chain(fl.Chain(fl.Linear(1, 1)))
    subchain = chain.layer("Chain", fl.Chain)
    nested_linear = subchain.layer("Linear", fl.Linear)

    # The Linear lives inside the inner Chain, so that inner Chain is reported as its parent.
    assert chain.find_parent(nested_linear) == subchain
    # A freshly built module was never attached to this tree.
    assert chain.find_parent(fl.Linear(1, 1)) is None
|
|
|
|
|
|
def test_chain_slice() -> None:
    """Slicing a chain yields a shorter chain that can still be executed; the source chain is unchanged."""
    chain = fl.Chain(
        fl.Linear(1, 1),
        fl.Linear(1, 1),
        fl.Linear(1, 1),
        fl.Chain(
            fl.Linear(1, 1),
            fl.Linear(1, 1),
        ),
        fl.Linear(1, 1),
    )

    x = torch.randn(1, 1)
    sliced_chain = chain[1:4]

    assert len(chain) == 5
    assert len(sliced_chain) == 3
    # Every layer maps 1 -> 1 features, so any contiguous slice preserves the (1, 1) shape.
    # Run the slice under test itself, not only a different slice of the same chain.
    assert sliced_chain(x).shape == (1, 1)
    assert chain[:-1](x).shape == (1, 1)
|
|
|
|
|
|
def test_chain_walk() -> None:
    """Check the (module, parent) pairs yielded by ``Chain.walk`` for the default, type-filtered and recursive modes."""
    chain = fl.Chain(
        fl.Sum(
            fl.Chain(fl.Linear(1, 1)),
            fl.Linear(1, 1),
        ),
        fl.Chain(),
    )

    sum_ = chain.layer("Sum", fl.Sum)
    sum_chain = sum_.layer("Chain", fl.Chain)
    inner_linear = sum_chain.Linear
    sum_linear = sum_.Linear
    empty_chain = chain.Chain

    # No filter: only the direct children of the outer chain.
    expected_direct = [(sum_, chain), (empty_chain, chain)]
    assert list(chain.walk()) == expected_direct

    # Type filter: nested Linear layers are reached too, in child order.
    expected_linears = [(inner_linear, sum_chain), (sum_linear, sum_)]
    assert list(chain.walk(fl.Linear)) == expected_linears

    # recurse=True: every module of the tree, each parent listed before its children.
    expected_all = [
        (sum_, chain),
        (sum_chain, sum_),
        (inner_linear, sum_chain),
        (sum_linear, sum_),
        (empty_chain, chain),
    ]
    assert list(chain.walk(recurse=True)) == expected_all
|
|
|
|
|
|
def test_chain_walk_stop_iteration() -> None:
    """A predicate raising StopIteration on a module prunes that module's subtree from the walk.

    Of the three Linear layers, only the top-level one (outside the Sum) is found.
    """
    chain = fl.Chain(
        fl.Sum(
            fl.Chain(fl.Linear(1, 1)),
            fl.Linear(1, 1),
        ),
        fl.Chain(),
        fl.Linear(1, 1),
    )

    def predicate(m: fl.Module, p: fl.Chain) -> bool:
        # Keep Linear layers, but refuse to look anywhere inside the Sum.
        if not isinstance(m, fl.Sum):
            return isinstance(m, fl.Linear)
        raise StopIteration

    linear_count = sum(1 for _ in chain.walk(fl.Linear))
    assert linear_count == 3
    assert len(list(chain.walk(predicate))) == 1
|
|
|
|
|
|
def test_chain_layer() -> None:
    """``Chain.layer`` resolves int, str and tuple paths and type-checks what it finds."""
    chain = fl.Chain(
        fl.Sum(fl.Chain(), fl.Chain()),
    )

    sum_ = chain.layer(0, fl.Sum)
    assert chain.layer("Sum", fl.Sum) == sum_
    # A Sum also passes a Chain type check.
    assert chain.layer("Sum", fl.Chain) == sum_

    second_chain = chain.layer((0, 1), fl.Chain)
    # Equivalent paths: untyped, index + name, and name + name.
    assert chain.layer((0, 1)) == second_chain
    assert chain.layer((0, "Chain_2"), fl.Chain) == second_chain
    assert chain.layer(("Sum", "Chain_2"), fl.Chain) == second_chain

    # The empty path designates the chain itself.
    assert chain.layer((), fl.Chain) == chain

    # A type mismatch is reported as AssertionError, for nested and empty paths alike.
    with pytest.raises(AssertionError):
        chain.layer((0, 1), fl.Sum)
    with pytest.raises(AssertionError):
        chain.layer((), fl.Sum)
|
|
|
|
|
|
def test_chain_layers() -> None:
    """``layers`` returns direct children of a type, or every matching descendant with recurse=True."""
    chain = fl.Chain(
        fl.Chain(fl.Chain(fl.Chain())),
        fl.Chain(),
        fl.Linear(1, 1),
    )

    direct_chains = list(chain.layers(fl.Chain))
    assert len(direct_chains) == 2

    # The two direct Chains plus the two nested inside the first one.
    all_chains = list(chain.layers(fl.Chain, recurse=True))
    assert len(all_chains) == 4
|
|
|
|
|
|
def test_chain_insert() -> None:
    """Inserting a child places it at the index, sets its parent and exposes the parent's context."""
    parent = ContextChain(fl.Linear(1, 1), fl.Linear(1, 1))
    child = fl.Chain()

    parent.insert(1, child)

    expected_keys = ["Linear_1", "Chain", "Linear_2"]
    assert module_keys(parent) == expected_keys
    assert child.parent == parent

    inherited = child.provider.get_context("foo")
    assert inherited == {"bar": [42]}
|
|
|
|
|
|
def test_chain_insert_negative() -> None:
    """A negative index is resolved from the end: -2 on a two-child chain lands between the children."""
    chain = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
    new_child = fl.Chain()

    chain.insert(-2, new_child)

    assert module_keys(chain) == ["Linear_1", "Chain", "Linear_2"]
|
|
|
|
|
|
def test_chain_insert_after_type() -> None:
    """``insert_after_type`` places the child right after the first module of the given type."""
    # NOTE: the same child instance is inserted into both chains below.
    child = fl.Chain()

    # Both children are Linear: the insertion follows the first one.
    two_linears = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
    two_linears.insert_after_type(fl.Linear, child)
    assert module_keys(two_linears) == ["Linear_1", "Chain", "Linear_2"]

    # The only Linear is last, so the child ends up at the end.
    conv_then_linear = fl.Chain(fl.Conv2d(1, 1, 1), fl.Linear(1, 1))
    conv_then_linear.insert_after_type(fl.Linear, child)
    assert module_keys(conv_then_linear) == ["Conv2d", "Linear", "Chain"]
|
|
|
|
|
|
def test_chain_insert_before_type() -> None:
    """``insert_before_type`` places the child right before the first module of the given type."""
    # NOTE: the same child instance is inserted into both chains below.
    child = fl.Chain()

    # Both children are Linear: the insertion precedes the first one.
    two_linears = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
    two_linears.insert_before_type(fl.Linear, child)
    assert module_keys(two_linears) == ["Chain", "Linear_1", "Linear_2"]

    # The Conv2d is skipped; the child goes just before the Linear.
    conv_then_linear = fl.Chain(fl.Conv2d(1, 1, 1), fl.Linear(1, 1))
    conv_then_linear.insert_before_type(fl.Linear, child)
    assert module_keys(conv_then_linear) == ["Conv2d", "Chain", "Linear"]
|
|
|
|
|
|
def test_chain_insert_overflow() -> None:
    """Out-of-range indices are clamped to the ends, as with Python's ``list.insert``."""
    child = fl.Chain()

    for index, expected_keys in (
        (42, ["Linear_1", "Linear_2", "Chain"]),
        (-42, ["Chain", "Linear_1", "Linear_2"]),
    ):
        parent = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
        parent.insert(index, child)
        assert module_keys(parent) == expected_keys
|
|
|
|
|
|
def test_chain_append() -> None:
    """``append`` adds the child after all existing children."""
    parent = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
    child = fl.Chain()

    parent.append(child)

    assert module_keys(parent) == ["Linear_1", "Linear_2", "Chain"]
|
|
|
|
|
|
def test_chain_pop() -> None:
    """``pop`` removes a child by (possibly negative) index and rejects out-of-range indices."""
    chain = fl.Chain(fl.Linear(1, 1), fl.Conv2d(1, 1, 1), fl.Chain())

    # 3 and -4 are just outside the range of a three-child chain.
    for bad_index in (3, -4):
        with pytest.raises(IndexError):
            chain.pop(bad_index)

    # Failed pops leave the chain untouched.
    assert module_keys(chain) == ["Linear", "Conv2d", "Chain"]

    chain.pop(1)
    assert module_keys(chain) == ["Linear", "Chain"]

    chain.pop(-2)
    assert module_keys(chain) == ["Chain"]
|
|
|
|
|
|
def test_chain_remove() -> None:
    """Removing a child detaches it, and the remaining keys are regenerated."""
    child = fl.Linear(1, 1)
    parent = fl.Chain(fl.Linear(1, 1), child, fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)))

    assert child in parent
    # Two Linear siblings are told apart by numeric suffixes...
    assert module_keys(parent) == ["Linear_1", "Linear_2", "Chain"]

    parent.remove(child)

    assert child not in parent
    # ...which disappear once a single Linear remains.
    assert module_keys(parent) == ["Linear", "Chain"]
|
|
|
|
|
|
def test_chain_replace() -> None:
    """``replace`` swaps one child for another in place, without changing the chain's length."""
    chain = fl.Chain(
        fl.Linear(1, 1),
        fl.Linear(1, 1),
        fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
    )
    subchain = chain.layer("Chain", fl.Chain)

    old_module = subchain[1]
    assert isinstance(old_module, fl.Linear)

    subchain.replace(old_module, fl.Conv2d(1, 1, 1))

    # The outer chain is untouched...
    assert len(chain) == 3
    # ...and the replacement happened inside the sub-chain: same length, new module at the same
    # position, old module gone.
    assert len(subchain) == 2
    assert isinstance(subchain[1], fl.Conv2d)
    assert old_module not in subchain
|
|
|
|
|
|
def test_chain_structural_copy() -> None:
    """``structural_copy`` creates new containers that share the original leaf layers.

    Because the Linear layers (and thus the weights) are shared, the copy must produce exactly
    the same output as the original for the same input.
    """
    m = fl.Chain(
        fl.Sum(
            fl.Linear(4, 8),
            fl.Linear(4, 8),
        ),
        fl.Linear(8, 12),
    )

    x = torch.randn(7, 4)
    y = m(x)
    assert y.shape == (7, 12)

    m2 = m.structural_copy()

    m_sum = m.layer("Sum", fl.Sum)
    m2_sum = m2.layer("Sum", fl.Sum)

    # Leaf layers are the same objects in both trees...
    assert m.Linear == m2.Linear
    assert m_sum.Linear_1 == m2_sum.Linear_1
    assert m_sum.Linear_2 == m2_sum.Linear_2

    # ...while the containers are distinct objects.
    assert m_sum != m2_sum
    assert m != m2

    # Each container's parent lies within its own tree.
    assert m_sum.parent == m
    assert m2_sum.parent == m2

    y2 = m2(x)
    assert y2.shape == (7, 12)
    # Previously the result of torch.equal was discarded; actually check it.
    assert torch.equal(y2, y)
|
|
|
|
|
|
def test_setattr_dont_register() -> None:
    """Assigning a module as a plain attribute raises ValueError and leaves the children unchanged."""
    chain = fl.Chain(fl.Linear(in_features=1, out_features=1), fl.Linear(in_features=1, out_features=1))
    extra_layer = fl.Linear(in_features=1, out_features=1)

    with pytest.raises(expected_exception=ValueError):
        chain.foo = extra_layer

    # The rejected assignment did not register a third child.
    assert module_keys(chain=chain) == ["Linear_1", "Linear_2"]
|
|
|
|
|
|
# Expected text from ``Chain._show_error_in_tree("Chain.Linear_2")``, compared against in
# ``test_debug_print``. The two top-level Linear(1, 1) siblings appear as one "(x2)" line, and the
# two differing Linear layers of the nested chain are numbered "#1" / "#2".
# NOTE(review): spacing inside this literal must match the tree renderer byte-for-byte.
EXPECTED_TREE = (
    "(CHAIN)\n ├── Linear(in_features=1, out_features=1, device=cpu, dtype=float32) (x2)\n └── (CHAIN)\n ├── Linear(in_features=1,"
    " out_features=1, device=cpu, dtype=float32) #1\n └── Linear(in_features=2, out_features=1, device=cpu, dtype=float32) #2"
)
|
|
|
|
|
|
def test_debug_print() -> None:
    """``_show_error_in_tree`` renders the chain as the text stored in ``EXPECTED_TREE``."""
    chain = fl.Chain(
        fl.Linear(1, 1),
        fl.Linear(1, 1),
        fl.Chain(fl.Linear(1, 1), fl.Linear(2, 1)),
    )

    rendered = chain._show_error_in_tree("Chain.Linear_2")  # type: ignore[reportPrivateUsage]
    assert rendered == EXPECTED_TREE
|