test (and fix) basic_attributes

This commit is contained in:
Pierre Chapuis 2024-01-29 17:33:07 +01:00 committed by Cédric Deltheil
parent bba478abf2
commit bca50b71f2
2 changed files with 14 additions and 2 deletions

View file

@ -59,7 +59,7 @@ class Module(TorchModule):
tree = ModuleTree(module=self)
print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage]
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]:
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType | Sequence[BasicType]]:
"""Return a dictionary of basic attributes of the module.
Basic attributes are public attributes made of basic types (int, float, str, bool) or a sequence of basic types.
@ -81,7 +81,7 @@ class Module(TorchModule):
return False
return {
key: str(object=value)
key: value
for key, value in self.__dict__.items()
if is_basic_attribute(key=key, value=value)
and (not init_attrs_only or (key in init_params and value != default_values.get(key)))

View file

@ -13,3 +13,15 @@ def test_module_get_path() -> None:
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
assert chain.Sum_1.get_path() == "Chain.Sum_1"
def test_module_basic_attributes() -> None:
class MyModule(fl.Module):
def __init__(self, spam: int = 0, foo: list[str | int] = ["bar", "qux", 42]) -> None:
self.spam = spam
self.foo = foo
self.chunky = "bacon"
m = MyModule(spam=3995)
assert str(m) == "MyModule(spam=3995)"
assert m.basic_attributes() == {"chunky": "bacon", "foo": ["bar", "qux", 42], "spam": 3995}