mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
fix chain slicing with structural copy
This commit is contained in:
parent
e7c1db50e0
commit
8c7298f8cc
|
@ -210,9 +210,23 @@ class Chain(ContextModule):
|
||||||
other = Chain(*other)
|
other = Chain(*other)
|
||||||
return Chain(*self, *other)
|
return Chain(*self, *other)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, key: int) -> Module:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, key: str) -> Module:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, key: slice) -> "Chain":
|
||||||
|
...
|
||||||
|
|
||||||
def __getitem__(self, key: int | str | slice) -> Module:
|
def __getitem__(self, key: int | str | slice) -> Module:
|
||||||
if isinstance(key, slice):
|
if isinstance(key, slice):
|
||||||
return Chain(*list(self)[key])
|
copy = self.structural_copy()
|
||||||
|
copy._regenerate_keys(modules=list(copy)[key])
|
||||||
|
return copy
|
||||||
elif isinstance(key, str):
|
elif isinstance(key, str):
|
||||||
return self._modules[key]
|
return self._modules[key]
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -64,3 +64,23 @@ def test_chain_find():
|
||||||
|
|
||||||
assert isinstance(chain.find(fl.Linear), fl.Linear)
|
assert isinstance(chain.find(fl.Linear), fl.Linear)
|
||||||
assert chain.find(fl.Conv2d) is None
|
assert chain.find(fl.Conv2d) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_slice() -> None:
|
||||||
|
chain = fl.Chain(
|
||||||
|
fl.Linear(in_features=1, out_features=1),
|
||||||
|
fl.Linear(in_features=1, out_features=1),
|
||||||
|
fl.Linear(in_features=1, out_features=1),
|
||||||
|
fl.Chain(
|
||||||
|
fl.Linear(in_features=1, out_features=1),
|
||||||
|
fl.Linear(in_features=1, out_features=1),
|
||||||
|
),
|
||||||
|
fl.Linear(in_features=1, out_features=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
x = torch.randn(1, 1)
|
||||||
|
sliced_chain = chain[1:4]
|
||||||
|
|
||||||
|
assert len(chain) == 5
|
||||||
|
assert len(sliced_chain) == 3
|
||||||
|
assert chain[:-1](x).shape == (1, 1)
|
||||||
|
|
Loading…
Reference in a new issue