make walk and layers not recurse by default

There is now a parameter to get the old (recursive) behavior.
This commit is contained in:
Pierre Chapuis 2023-08-22 17:42:41 +02:00
parent 2ad26a06b0
commit dec0d64432
2 changed files with 28 additions and 13 deletions

View file

@ -270,39 +270,43 @@ class Chain(ContextModule):
wm = self.find(WeightedModule)
return None if wm is None else wm.dtype
def _walk(self, predicate: Callable[[Module, "Chain"], bool] | None = None) -> Iterator[tuple[Module, "Chain"]]:
def _walk(
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
) -> Iterator[tuple[Module, "Chain"]]:
if predicate is None:
predicate = lambda _m, _p: True
for module in self:
keep_going = True
try:
p = predicate(module, self)
except StopIteration:
p = False
keep_going = False
continue
if p:
yield (module, self)
if keep_going and isinstance(module, Chain):
yield from module.walk(predicate)
if not recurse:
continue
if isinstance(module, Chain):
yield from module.walk(predicate, recurse)
@overload
def walk(self, predicate: Callable[[Module, "Chain"], bool] | None = None) -> Iterator[tuple[Module, "Chain"]]:
def walk(
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
) -> Iterator[tuple[Module, "Chain"]]:
...
@overload
def walk(self, predicate: type[T]) -> Iterator[tuple[T, "Chain"]]:
def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]:
...
def walk(
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]:
if isinstance(predicate, type):
return self._walk(lambda m, _: isinstance(m, predicate))
return self._walk(lambda m, _: isinstance(m, predicate), recurse)
else:
return self._walk(predicate)
return self._walk(predicate, recurse)
def layers(self, layer_type: type[T]) -> Iterator[T]:
for module, _ in self.walk(layer_type):
def layers(self, layer_type: type[T], recurse: bool = False) -> Iterator[T]:
for module, _ in self.walk(layer_type, recurse):
yield module
def find(self, layer_type: type[T]) -> T | None:

View file

@ -84,3 +84,14 @@ def test_chain_slice() -> None:
assert len(chain) == 5
assert len(sliced_chain) == 3
assert chain[:-1](x).shape == (1, 1)
def test_chain_layers() -> None:
chain = fl.Chain(
fl.Chain(fl.Chain(fl.Chain())),
fl.Chain(),
fl.Linear(in_features=1, out_features=1),
)
assert len(list(chain.layers(fl.Chain))) == 2
assert len(list(chain.layers(fl.Chain, recurse=True))) == 4