mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
fix chain slicing with structural copy
This commit is contained in:
parent
e7c1db50e0
commit
8c7298f8cc
|
@ -210,9 +210,23 @@ class Chain(ContextModule):
|
|||
other = Chain(*other)
|
||||
return Chain(*self, *other)
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: int) -> Module:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: str) -> Module:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: slice) -> "Chain":
|
||||
...
|
||||
|
||||
def __getitem__(self, key: int | str | slice) -> Module:
|
||||
if isinstance(key, slice):
|
||||
return Chain(*list(self)[key])
|
||||
copy = self.structural_copy()
|
||||
copy._regenerate_keys(modules=list(copy)[key])
|
||||
return copy
|
||||
elif isinstance(key, str):
|
||||
return self._modules[key]
|
||||
else:
|
||||
|
|
|
@ -64,3 +64,23 @@ def test_chain_find():
|
|||
|
||||
assert isinstance(chain.find(fl.Linear), fl.Linear)
|
||||
assert chain.find(fl.Conv2d) is None
|
||||
|
||||
|
||||
def test_chain_slice() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Chain(
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
)
|
||||
|
||||
x = torch.randn(1, 1)
|
||||
sliced_chain = chain[1:4]
|
||||
|
||||
assert len(chain) == 5
|
||||
assert len(sliced_chain) == 3
|
||||
assert chain[:-1](x).shape == (1, 1)
|
||||
|
|
Loading…
Reference in a new issue