fix chain slicing with structural copy

This commit is contained in:
Benjamin Trom 2023-08-21 18:36:36 +02:00
parent e7c1db50e0
commit 8c7298f8cc
2 changed files with 35 additions and 1 deletions

View file

@ -210,9 +210,23 @@ class Chain(ContextModule):
other = Chain(*other) other = Chain(*other)
return Chain(*self, *other) return Chain(*self, *other)
@overload
def __getitem__(self, key: int) -> Module:
...
@overload
def __getitem__(self, key: str) -> Module:
...
@overload
def __getitem__(self, key: slice) -> "Chain":
...
def __getitem__(self, key: int | str | slice) -> Module: def __getitem__(self, key: int | str | slice) -> Module:
if isinstance(key, slice): if isinstance(key, slice):
return Chain(*list(self)[key]) copy = self.structural_copy()
copy._regenerate_keys(modules=list(copy)[key])
return copy
elif isinstance(key, str): elif isinstance(key, str):
return self._modules[key] return self._modules[key]
else: else:

View file

@ -64,3 +64,23 @@ def test_chain_find():
assert isinstance(chain.find(fl.Linear), fl.Linear) assert isinstance(chain.find(fl.Linear), fl.Linear)
assert chain.find(fl.Conv2d) is None assert chain.find(fl.Conv2d) is None
def test_chain_slice() -> None:
chain = fl.Chain(
fl.Linear(in_features=1, out_features=1),
fl.Linear(in_features=1, out_features=1),
fl.Linear(in_features=1, out_features=1),
fl.Chain(
fl.Linear(in_features=1, out_features=1),
fl.Linear(in_features=1, out_features=1),
),
fl.Linear(in_features=1, out_features=1),
)
x = torch.randn(1, 1)
sliced_chain = chain[1:4]
assert len(chain) == 5
assert len(sliced_chain) == 3
assert chain[:-1](x).shape == (1, 1)