mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
test (and fix) basic_attributes
This commit is contained in:
parent
bba478abf2
commit
bca50b71f2
|
@ -59,7 +59,7 @@ class Module(TorchModule):
|
||||||
tree = ModuleTree(module=self)
|
tree = ModuleTree(module=self)
|
||||||
print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage]
|
print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage]
|
||||||
|
|
||||||
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]:
|
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType | Sequence[BasicType]]:
|
||||||
"""Return a dictionary of basic attributes of the module.
|
"""Return a dictionary of basic attributes of the module.
|
||||||
|
|
||||||
Basic attributes are public attributes made of basic types (int, float, str, bool) or a sequence of basic types.
|
Basic attributes are public attributes made of basic types (int, float, str, bool) or a sequence of basic types.
|
||||||
|
@ -81,7 +81,7 @@ class Module(TorchModule):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return {
|
return {
|
||||||
key: str(object=value)
|
key: value
|
||||||
for key, value in self.__dict__.items()
|
for key, value in self.__dict__.items()
|
||||||
if is_basic_attribute(key=key, value=value)
|
if is_basic_attribute(key=key, value=value)
|
||||||
and (not init_attrs_only or (key in init_params and value != default_values.get(key)))
|
and (not init_attrs_only or (key in init_params and value != default_values.get(key)))
|
||||||
|
|
|
@ -13,3 +13,15 @@ def test_module_get_path() -> None:
|
||||||
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
|
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
|
||||||
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
|
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
|
||||||
assert chain.Sum_1.get_path() == "Chain.Sum_1"
|
assert chain.Sum_1.get_path() == "Chain.Sum_1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_module_basic_attributes() -> None:
|
||||||
|
class MyModule(fl.Module):
|
||||||
|
def __init__(self, spam: int = 0, foo: list[str | int] = ["bar", "qux", 42]) -> None:
|
||||||
|
self.spam = spam
|
||||||
|
self.foo = foo
|
||||||
|
self.chunky = "bacon"
|
||||||
|
|
||||||
|
m = MyModule(spam=3995)
|
||||||
|
assert str(m) == "MyModule(spam=3995)"
|
||||||
|
assert m.basic_attributes() == {"chunky": "bacon", "foo": ["bar", "qux", 42], "spam": 3995}
|
||||||
|
|
Loading…
Reference in a new issue