mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
make walk and layers not recurse by default
There is now a parameter to get the old (recursive) behavior.
This commit is contained in:
parent
2ad26a06b0
commit
dec0d64432
|
@ -270,39 +270,43 @@ class Chain(ContextModule):
|
|||
wm = self.find(WeightedModule)
|
||||
return None if wm is None else wm.dtype
|
||||
|
||||
def _walk(self, predicate: Callable[[Module, "Chain"], bool] | None = None) -> Iterator[tuple[Module, "Chain"]]:
|
||||
def _walk(
|
||||
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
||||
) -> Iterator[tuple[Module, "Chain"]]:
|
||||
if predicate is None:
|
||||
predicate = lambda _m, _p: True
|
||||
for module in self:
|
||||
keep_going = True
|
||||
try:
|
||||
p = predicate(module, self)
|
||||
except StopIteration:
|
||||
p = False
|
||||
keep_going = False
|
||||
continue
|
||||
if p:
|
||||
yield (module, self)
|
||||
if keep_going and isinstance(module, Chain):
|
||||
yield from module.walk(predicate)
|
||||
if not recurse:
|
||||
continue
|
||||
if isinstance(module, Chain):
|
||||
yield from module.walk(predicate, recurse)
|
||||
|
||||
@overload
|
||||
def walk(self, predicate: Callable[[Module, "Chain"], bool] | None = None) -> Iterator[tuple[Module, "Chain"]]:
|
||||
def walk(
|
||||
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
||||
) -> Iterator[tuple[Module, "Chain"]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def walk(self, predicate: type[T]) -> Iterator[tuple[T, "Chain"]]:
|
||||
def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]:
|
||||
...
|
||||
|
||||
def walk(
|
||||
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None
|
||||
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
||||
) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]:
|
||||
if isinstance(predicate, type):
|
||||
return self._walk(lambda m, _: isinstance(m, predicate))
|
||||
return self._walk(lambda m, _: isinstance(m, predicate), recurse)
|
||||
else:
|
||||
return self._walk(predicate)
|
||||
return self._walk(predicate, recurse)
|
||||
|
||||
def layers(self, layer_type: type[T]) -> Iterator[T]:
|
||||
for module, _ in self.walk(layer_type):
|
||||
def layers(self, layer_type: type[T], recurse: bool = False) -> Iterator[T]:
|
||||
for module, _ in self.walk(layer_type, recurse):
|
||||
yield module
|
||||
|
||||
def find(self, layer_type: type[T]) -> T | None:
|
||||
|
|
|
@ -84,3 +84,14 @@ def test_chain_slice() -> None:
|
|||
assert len(chain) == 5
|
||||
assert len(sliced_chain) == 3
|
||||
assert chain[:-1](x).shape == (1, 1)
|
||||
|
||||
|
||||
def test_chain_layers() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Chain(fl.Chain(fl.Chain())),
|
||||
fl.Chain(),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
)
|
||||
|
||||
assert len(list(chain.layers(fl.Chain))) == 2
|
||||
assert len(list(chain.layers(fl.Chain, recurse=True))) == 4
|
||||
|
|
Loading…
Reference in a new issue