mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
chain: add insert_before_type
This commit is contained in:
parent
4352e78483
commit
d72e1d3478
|
@ -310,6 +310,13 @@ class Chain(ContextModule):
|
|||
module._set_parent(self)
|
||||
self._register_provider()
|
||||
|
||||
def insert_before_type(self, module_type: type[Module], new_module: Module) -> None:
|
||||
for i, module in enumerate(self):
|
||||
if isinstance(module, module_type):
|
||||
self.insert(i, new_module)
|
||||
return
|
||||
raise ValueError(f"No module of type {module_type.__name__} found in the chain.")
|
||||
|
||||
def insert_after_type(self, module_type: type[Module], new_module: Module) -> None:
|
||||
for i, module in enumerate(self):
|
||||
if isinstance(module, module_type):
|
||||
|
|
|
@ -107,6 +107,18 @@ def test_chain_insert_after_type() -> None:
|
|||
assert module_keys(parent_2) == ["Conv2d", "Linear", "Chain"]
|
||||
|
||||
|
||||
def test_chain_insert_before_type() -> None:
|
||||
child = fl.Chain()
|
||||
|
||||
parent_1 = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
|
||||
parent_1.insert_before_type(fl.Linear, child)
|
||||
assert module_keys(parent_1) == ["Chain", "Linear_1", "Linear_2"]
|
||||
|
||||
parent_2 = fl.Chain(fl.Conv2d(1, 1, 1), fl.Linear(1, 1))
|
||||
parent_2.insert_before_type(fl.Linear, child)
|
||||
assert module_keys(parent_2) == ["Conv2d", "Chain", "Linear"]
|
||||
|
||||
|
||||
def test_chain_insert_overflow() -> None:
|
||||
# This behaves as insert() in lists in Python.
|
||||
|
||||
|
|
Loading…
Reference in a new issue