mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-25 07:38:45 +00:00
471ef91d1c
PyTorch chose to make it `Any` because they expect its users' code to be "highly dynamic": https://github.com/pytorch/pytorch/pull/104321. That is not the case for us: in Refiners, untyped code goes against one of our core principles. Note that there is currently an open PR in PyTorch to return `Module | Tensor`, but in practice this is not always correct either: https://github.com/pytorch/pytorch/pull/115074. I also moved Residuals-related code from SD1 to latent_diffusion because SDXL should not depend on SD1.
81 lines
2.2 KiB
Python
81 lines
2.2 KiB
Python
import pytest
|
|
|
|
from refiners.fluxion.adapters.adapter import Adapter
|
|
from refiners.fluxion.layers import Chain, Linear
|
|
|
|
|
|
class DummyLinearAdapter(Chain, Adapter[Linear]):
    """Test-only adapter: a `Chain` that wraps one `Linear` target and adds nothing."""

    def __init__(self, target: Linear) -> None:
        # Build the wrapping Chain (around `target`) inside `setup_adapter`,
        # which is handed the same `target` being adapted.
        with self.setup_adapter(target):
            super().__init__(target)
|
|
|
|
|
|
class DummyChainAdapter(Chain, Adapter[Chain]):
    """Test-only adapter: a `Chain` that wraps one `Chain` target and adds nothing."""

    def __init__(self, target: Chain) -> None:
        # Build the wrapping Chain (around `target`) inside `setup_adapter`,
        # which is handed the same `target` being adapted.
        with self.setup_adapter(target):
            super().__init__(target)
|
|
|
|
|
|
@pytest.fixture
def chain() -> Chain:
    """Return an outer `Chain` whose only child is a `Chain` holding a `Linear(2, 2)`."""
    inner = Chain(Linear(2, 2))
    return Chain(inner)
|
|
|
|
|
|
def test_weighted_module_adapter_insertion(chain: Chain):
    """Inject a `Linear` adapter into its parent, then eject it, checking membership each time."""
    parent = chain.layer("Chain", Chain)
    linear = parent.layer("Linear", Linear)

    adapter = DummyLinearAdapter(linear).inject(parent)

    # After injection the adapter sits in `parent` in place of the bare Linear.
    assert adapter.parent == parent
    children = list(parent)
    assert adapter in children
    assert linear not in children

    adapter.eject()

    # After ejection the adapter is detached and the Linear is back in `parent`.
    assert adapter.parent is None
    children = list(parent)
    assert adapter not in children
    assert linear in children
|
|
|
|
|
|
def test_chain_adapter_insertion(chain: Chain):
    """Inject a `Chain` adapter without an explicit parent, then eject it, checking parent links."""
    inner = chain.layer("Chain", Chain)

    adapter = DummyChainAdapter(inner)
    # Merely constructing the adapter leaves the target where it was.
    assert inner.parent == chain

    adapter.inject()

    # The adapter takes the target's slot in `chain` and becomes its parent.
    assert adapter.parent == chain
    assert inner.parent == adapter
    children = list(chain)
    assert adapter in children
    assert inner not in children

    adapter.eject()

    # Ejection detaches the adapter and hands the target back to `chain`.
    assert adapter.parent is None
    assert inner.parent == chain
    children = list(chain)
    assert adapter not in children
    assert inner in children
|
|
|
|
|
|
def test_weighted_module_adapter_structural_copy(chain: Chain):
    """Check that `structural_copy` carries an injected `Linear` adapter into the clone.

    The cloned adapter must be parented by the clone's inner `Chain` and still
    target the original `Linear`.
    """
    parent = chain.layer("Chain", Chain)
    adaptee = parent.layer("Linear", Linear)

    DummyLinearAdapter(adaptee).inject(parent)

    clone = chain.structural_copy()
    cloned_adapter = clone.layer(("Chain", "DummyLinearAdapter"), DummyLinearAdapter)

    # Use the typed `layer` accessor, as the rest of this module does, instead of
    # attribute access (`clone.Chain`), which a type checker cannot narrow to `Chain`.
    assert cloned_adapter.parent == clone.layer("Chain", Chain)
    # The clone's adapter still points at the original Linear.
    assert cloned_adapter.target == adaptee
|
|
|
|
|
|
def test_chain_adapter_structural_copy(chain: Chain):
    """Check that an injected `Chain` adapter blocks `structural_copy` until it is ejected."""
    target = chain.layer("Chain", Chain)
    adapter = DummyChainAdapter(target).inject()

    # Chain adapters cannot be copied by default: the copy raises RuntimeError.
    with pytest.raises(RuntimeError):
        chain.structural_copy()

    # With the adapter ejected, copying the tree no longer raises.
    adapter.eject()
    chain.structural_copy()
|