make walk and layers not recurse by default

There is now a parameter to get the old (recursive) behavior.
This commit is contained in:
Pierre Chapuis 2023-08-22 17:42:41 +02:00
parent 2ad26a06b0
commit dec0d64432
2 changed files with 28 additions and 13 deletions

View file

@ -270,39 +270,43 @@ class Chain(ContextModule):
wm = self.find(WeightedModule) wm = self.find(WeightedModule)
return None if wm is None else wm.dtype return None if wm is None else wm.dtype
def _walk(self, predicate: Callable[[Module, "Chain"], bool] | None = None) -> Iterator[tuple[Module, "Chain"]]: def _walk(
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
) -> Iterator[tuple[Module, "Chain"]]:
if predicate is None: if predicate is None:
predicate = lambda _m, _p: True predicate = lambda _m, _p: True
for module in self: for module in self:
keep_going = True
try: try:
p = predicate(module, self) p = predicate(module, self)
except StopIteration: except StopIteration:
p = False continue
keep_going = False
if p: if p:
yield (module, self) yield (module, self)
if keep_going and isinstance(module, Chain): if not recurse:
yield from module.walk(predicate) continue
if isinstance(module, Chain):
yield from module.walk(predicate, recurse)
@overload @overload
def walk(self, predicate: Callable[[Module, "Chain"], bool] | None = None) -> Iterator[tuple[Module, "Chain"]]: def walk(
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
) -> Iterator[tuple[Module, "Chain"]]:
... ...
@overload @overload
def walk(self, predicate: type[T]) -> Iterator[tuple[T, "Chain"]]: def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]:
... ...
def walk( def walk(
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]: ) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]:
if isinstance(predicate, type): if isinstance(predicate, type):
return self._walk(lambda m, _: isinstance(m, predicate)) return self._walk(lambda m, _: isinstance(m, predicate), recurse)
else: else:
return self._walk(predicate) return self._walk(predicate, recurse)
def layers(self, layer_type: type[T]) -> Iterator[T]: def layers(self, layer_type: type[T], recurse: bool = False) -> Iterator[T]:
for module, _ in self.walk(layer_type): for module, _ in self.walk(layer_type, recurse):
yield module yield module
def find(self, layer_type: type[T]) -> T | None: def find(self, layer_type: type[T]) -> T | None:

View file

@ -84,3 +84,14 @@ def test_chain_slice() -> None:
assert len(chain) == 5 assert len(chain) == 5
assert len(sliced_chain) == 3 assert len(sliced_chain) == 3
assert chain[:-1](x).shape == (1, 1) assert chain[:-1](x).shape == (1, 1)
def test_chain_layers() -> None:
chain = fl.Chain(
fl.Chain(fl.Chain(fl.Chain())),
fl.Chain(),
fl.Linear(in_features=1, out_features=1),
)
assert len(list(chain.layers(fl.Chain))) == 2
assert len(list(chain.layers(fl.Chain, recurse=True))) == 4