mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
make walk and layers not recurse by default
There is now a parameter to get the old (recursive) behavior.
This commit is contained in:
parent
2ad26a06b0
commit
dec0d64432
|
@ -270,39 +270,43 @@ class Chain(ContextModule):
|
||||||
wm = self.find(WeightedModule)
|
wm = self.find(WeightedModule)
|
||||||
return None if wm is None else wm.dtype
|
return None if wm is None else wm.dtype
|
||||||
|
|
||||||
def _walk(self, predicate: Callable[[Module, "Chain"], bool] | None = None) -> Iterator[tuple[Module, "Chain"]]:
|
def _walk(
|
||||||
|
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
||||||
|
) -> Iterator[tuple[Module, "Chain"]]:
|
||||||
if predicate is None:
|
if predicate is None:
|
||||||
predicate = lambda _m, _p: True
|
predicate = lambda _m, _p: True
|
||||||
for module in self:
|
for module in self:
|
||||||
keep_going = True
|
|
||||||
try:
|
try:
|
||||||
p = predicate(module, self)
|
p = predicate(module, self)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
p = False
|
continue
|
||||||
keep_going = False
|
|
||||||
if p:
|
if p:
|
||||||
yield (module, self)
|
yield (module, self)
|
||||||
if keep_going and isinstance(module, Chain):
|
if not recurse:
|
||||||
yield from module.walk(predicate)
|
continue
|
||||||
|
if isinstance(module, Chain):
|
||||||
|
yield from module.walk(predicate, recurse)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def walk(self, predicate: Callable[[Module, "Chain"], bool] | None = None) -> Iterator[tuple[Module, "Chain"]]:
|
def walk(
|
||||||
|
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
||||||
|
) -> Iterator[tuple[Module, "Chain"]]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def walk(self, predicate: type[T]) -> Iterator[tuple[T, "Chain"]]:
|
def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]:
|
||||||
...
|
...
|
||||||
|
|
||||||
def walk(
|
def walk(
|
||||||
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None
|
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
||||||
) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]:
|
) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]:
|
||||||
if isinstance(predicate, type):
|
if isinstance(predicate, type):
|
||||||
return self._walk(lambda m, _: isinstance(m, predicate))
|
return self._walk(lambda m, _: isinstance(m, predicate), recurse)
|
||||||
else:
|
else:
|
||||||
return self._walk(predicate)
|
return self._walk(predicate, recurse)
|
||||||
|
|
||||||
def layers(self, layer_type: type[T]) -> Iterator[T]:
|
def layers(self, layer_type: type[T], recurse: bool = False) -> Iterator[T]:
|
||||||
for module, _ in self.walk(layer_type):
|
for module, _ in self.walk(layer_type, recurse):
|
||||||
yield module
|
yield module
|
||||||
|
|
||||||
def find(self, layer_type: type[T]) -> T | None:
|
def find(self, layer_type: type[T]) -> T | None:
|
||||||
|
|
|
@ -84,3 +84,14 @@ def test_chain_slice() -> None:
|
||||||
assert len(chain) == 5
|
assert len(chain) == 5
|
||||||
assert len(sliced_chain) == 3
|
assert len(sliced_chain) == 3
|
||||||
assert chain[:-1](x).shape == (1, 1)
|
assert chain[:-1](x).shape == (1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_layers() -> None:
|
||||||
|
chain = fl.Chain(
|
||||||
|
fl.Chain(fl.Chain(fl.Chain())),
|
||||||
|
fl.Chain(),
|
||||||
|
fl.Linear(in_features=1, out_features=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(list(chain.layers(fl.Chain))) == 2
|
||||||
|
assert len(list(chain.layers(fl.Chain, recurse=True))) == 4
|
||||||
|
|
Loading…
Reference in a new issue