mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
test (and fix) basic_attributes
This commit is contained in:
parent
bba478abf2
commit
bca50b71f2
|
@ -59,7 +59,7 @@ class Module(TorchModule):
|
|||
tree = ModuleTree(module=self)
|
||||
print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage]
|
||||
|
||||
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]:
|
||||
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType | Sequence[BasicType]]:
|
||||
"""Return a dictionary of basic attributes of the module.
|
||||
|
||||
Basic attributes are public attributes made of basic types (int, float, str, bool) or a sequence of basic types.
|
||||
|
@ -81,7 +81,7 @@ class Module(TorchModule):
|
|||
return False
|
||||
|
||||
return {
|
||||
key: str(object=value)
|
||||
key: value
|
||||
for key, value in self.__dict__.items()
|
||||
if is_basic_attribute(key=key, value=value)
|
||||
and (not init_attrs_only or (key in init_params and value != default_values.get(key)))
|
||||
|
|
|
@ -13,3 +13,15 @@ def test_module_get_path() -> None:
|
|||
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
|
||||
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
|
||||
assert chain.Sum_1.get_path() == "Chain.Sum_1"
|
||||
|
||||
|
||||
def test_module_basic_attributes() -> None:
|
||||
class MyModule(fl.Module):
|
||||
def __init__(self, spam: int = 0, foo: list[str | int] = ["bar", "qux", 42]) -> None:
|
||||
self.spam = spam
|
||||
self.foo = foo
|
||||
self.chunky = "bacon"
|
||||
|
||||
m = MyModule(spam=3995)
|
||||
assert str(m) == "MyModule(spam=3995)"
|
||||
assert m.basic_attributes() == {"chunky": "bacon", "foo": ["bar", "qux", 42], "spam": 3995}
|
||||
|
|
Loading…
Reference in a new issue