chain: add insert_before_type

This commit is contained in:
Cédric Deltheil 2023-09-24 21:13:37 +02:00 committed by Cédric Deltheil
parent 4352e78483
commit d72e1d3478
2 changed files with 19 additions and 0 deletions

View file

@ -310,6 +310,13 @@ class Chain(ContextModule):
module._set_parent(self) module._set_parent(self)
self._register_provider() self._register_provider()
def insert_before_type(self, module_type: type[Module], new_module: Module) -> None:
for i, module in enumerate(self):
if isinstance(module, module_type):
self.insert(i, new_module)
return
raise ValueError(f"No module of type {module_type.__name__} found in the chain.")
def insert_after_type(self, module_type: type[Module], new_module: Module) -> None: def insert_after_type(self, module_type: type[Module], new_module: Module) -> None:
for i, module in enumerate(self): for i, module in enumerate(self):
if isinstance(module, module_type): if isinstance(module, module_type):

View file

@ -107,6 +107,18 @@ def test_chain_insert_after_type() -> None:
assert module_keys(parent_2) == ["Conv2d", "Linear", "Chain"] assert module_keys(parent_2) == ["Conv2d", "Linear", "Chain"]
def test_chain_insert_before_type() -> None:
child = fl.Chain()
parent_1 = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
parent_1.insert_before_type(fl.Linear, child)
assert module_keys(parent_1) == ["Chain", "Linear_1", "Linear_2"]
parent_2 = fl.Chain(fl.Conv2d(1, 1, 1), fl.Linear(1, 1))
parent_2.insert_before_type(fl.Linear, child)
assert module_keys(parent_2) == ["Conv2d", "Chain", "Linear"]
def test_chain_insert_overflow() -> None: def test_chain_insert_overflow() -> None:
# This behaves as insert() in lists in Python. # This behaves as insert() in lists in Python.