# LION/third_party/yacs_config.py
# Snapshot: 2023-01-23 00:14:49 -05:00 (587 lines, 21 KiB, Python)

# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# This file has been modified from a file in the following repo
# (released under the Apache License 2.0).
#
# Source:
# https://github.com/rbgirshick/yacs/blob/master/yacs/config.py
#
# The license for the original version of this file can be
# found in
# http://www.apache.org/licenses/LICENSE-2.0
# The modifications
# to this file are subject to the NVIDIA Source Code License for
# LION located at the root directory.
# ---------------------------------------------------------------
"""YACS -- Yet Another Configuration System is designed to be a simple
configuration management system for academic and industrial research
projects.
See README.md for usage and examples.
"""
# this code is modified from https://github.com/rbgirshick/yacs/blob/master/yacs/config.py
import copy
import io
# import logging
from loguru import logger
import os
import sys
from ast import literal_eval
import yaml
# Flag for py2/py3 compatibility, used where separate code paths are necessary.
# When _PY2 is False, Python 3 is assumed.
_PY2 = sys.version_info[0] == 2

# Filename extensions recognised when loading configs from files
_YAML_EXTS = {"", ".yaml", ".yml"}
_PY_EXTS = {".py"}

# File-object types accepted by CfgNode.load_cfg; py2 also has the builtin `file`.
_FILE_TYPES = (file, io.IOBase) if _PY2 else (io.IOBase, )  # noqa: F821

# Leaf value types a CfgNode may hold; py2 additionally allows `unicode` strings.
_VALID_TYPES = {tuple, list, str, int, float, bool}
if _PY2:
    _VALID_TYPES = _VALID_TYPES | {unicode}  # noqa: F821
# Utilities for importing modules from file paths
if _PY2:
# imp is available in both py2 and py3 for now, but is deprecated in py3
import imp
else:
import importlib.util
# logger = logging.getLogger(__name__)
class CfgNode(dict):
    """
    CfgNode represents an internal node in the configuration tree. It's a simple
    dict-like container that allows for attribute-based access to keys.

    Internal bookkeeping (frozen flag, deprecated/renamed key registries and the
    new-key policy) is stored in ``self.__dict__`` under the class-level constant
    names below, separately from the configuration keys held in the dict itself.
    """

    IMMUTABLE = "__immutable__"
    DEPRECATED_KEYS = "__deprecated_keys__"
    RENAMED_KEYS = "__renamed_keys__"
    NEW_ALLOWED = "__new_allowed__"

    def __init__(self, init_dict=None, key_list=None, new_allowed=False):
        """
        Args:
            init_dict (dict): the possibly-nested dictionary to initialize the CfgNode.
            key_list (list[str]): a list of names which index this CfgNode from the root.
                Currently only used for logging purposes.
            new_allowed (bool): whether adding new key is allowed when merging with
                other configs.
        """
        # Recursively convert nested dictionaries in init_dict into CfgNodes
        init_dict = {} if init_dict is None else init_dict
        key_list = [] if key_list is None else key_list
        init_dict = self._create_config_tree_from_dict(init_dict, key_list)
        super(CfgNode, self).__init__(init_dict)
        # Manage if the CfgNode is frozen or not
        self.__dict__[CfgNode.IMMUTABLE] = False
        # Deprecated options
        # If an option is removed from the code and you don't want to break existing
        # yaml configs, you can add the full config key as a string to the set below.
        self.__dict__[CfgNode.DEPRECATED_KEYS] = set()
        # Renamed options
        # If you rename a config option, record the mapping from the old name to the new
        # name in the dictionary below. Optionally, if the type also changed, you can
        # make the value a tuple that specifies first the renamed key and then
        # instructions for how to edit the config file.
        self.__dict__[CfgNode.RENAMED_KEYS] = {
            # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY',  # Dummy example to follow
            # 'EXAMPLE.OLD.KEY': (                   # A more complex example to follow
            #     'EXAMPLE.NEW.KEY',
            #     "Also convert to a tuple, e.g., 'foo' -> ('foo',) or "
            #     + "'foo:bar' -> ('foo', 'bar')"
            # ),
        }
        # Allow new attributes after initialisation
        self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed

    @classmethod
    def _create_config_tree_from_dict(cls, dic, key_list):
        """
        Create a configuration tree using the given dict.
        Any dict-like objects inside dict will be treated as a new CfgNode.

        Args:
            dic (dict): the (possibly nested) source dictionary; it is deep-copied,
                so the caller's object is never modified.
            key_list (list[str]): a list of names which index this CfgNode from the root.
                Currently only used for logging purposes.

        Returns:
            dict: a copy of ``dic`` whose nested dicts are replaced by CfgNodes.

        Raises:
            AssertionError: if a leaf value is not one of ``_VALID_TYPES``.
        """
        dic = copy.deepcopy(dic)
        for k, v in dic.items():
            if isinstance(v, dict):
                # Convert dict to CfgNode
                dic[k] = cls(v, key_list=key_list + [k])
            else:
                # Check for valid leaf type or nested CfgNode
                _assert_with_logging(
                    _valid_type(v, allow_cfg_node=False),
                    "Key {} with value {} is not a valid type; valid types: {}"
                    .format(".".join(key_list + [str(k)]), type(v),
                            _VALID_TYPES),
                )
        return dic

    def __getattr__(self, name):
        """Expose configuration keys as attributes; missing keys raise AttributeError."""
        if name in self:
            return self[name]
        else:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        """Set configuration key ``name`` via attribute syntax.

        Raises:
            AttributeError: if this node is frozen.
            AssertionError: if ``name`` is internal state or ``value`` has an
                invalid type.
        """
        if self.is_frozen():
            raise AttributeError(
                "Attempted to set {} to {}, but CfgNode is immutable".format(
                    name, value))
        _assert_with_logging(
            name not in self.__dict__,
            "Invalid attempt to modify internal CfgNode state: {}".format(
                name),
        )
        _assert_with_logging(
            _valid_type(value, allow_cfg_node=True),
            "Invalid type {} for key {}; valid types = {}".format(
                type(value), name, _VALID_TYPES),
        )
        self[name] = value

    def __str__(self):
        """Render the tree as indented ``key: value`` lines, sorted by key."""
        def _indent(s_, num_spaces):
            # Indent every line except the first by `num_spaces` spaces.
            s = s_.split("\n")
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        s = []
        for k, v in sorted(self.items()):
            # Nested nodes start on a new line; leaves stay on the key's line.
            separator = "\n" if isinstance(v, CfgNode) else " "
            attr_str = "{}:{}{}".format(str(k), separator, str(v))
            attr_str = _indent(attr_str, 2)
            s.append(attr_str)
        return "\n".join(s)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__,
                               super(CfgNode, self).__repr__())

    def to_dict(self, **kwargs):
        """Recursively convert this CfgNode (and nested CfgNodes) to plain dicts.

        Args:
            **kwargs: accepted for backward compatibility; unused.

        Returns:
            dict: a nested ``dict`` holding the same keys and leaf values.

        Raises:
            AssertionError: if a leaf value is not one of ``_VALID_TYPES``.
        """
        def convert_to_dict(cfg_node, key_list):
            if not isinstance(cfg_node, CfgNode):
                _assert_with_logging(
                    _valid_type(cfg_node),
                    "Key {} with value {} is not a valid type; valid types: {}"
                    .format(".".join(key_list), type(cfg_node), _VALID_TYPES),
                )
                return cfg_node
            # str(k) keeps the dotted key path joinable for non-string keys.
            return {
                k: convert_to_dict(v, key_list + [str(k)])
                for k, v in cfg_node.items()
            }

        return convert_to_dict(self, [])

    def dump(self, **kwargs):
        """Dump this CfgNode to a YAML string.

        Args:
            **kwargs: forwarded unchanged to ``yaml.safe_dump``.

        Returns:
            str: the YAML serialization of ``self.to_dict()``.
        """
        return yaml.safe_dump(self.to_dict(), **kwargs)

    def merge_from_file(self, cfg_filename):
        """Load a yaml config file and merge it into this CfgNode."""
        with open(cfg_filename, "r") as f:
            cfg = self.load_cfg(f)
        self.merge_from_other_cfg(cfg)

    def merge_from_other_cfg(self, cfg_other):
        """Merge `cfg_other` into this CfgNode."""
        _merge_a_into_b(cfg_other, self, self, [])

    def merge_from_list(self, cfg_list):
        """Merge config (keys, values) in a list (e.g., from command line) into
        this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`.

        Raises:
            AssertionError: if the list has odd length or a key does not exist.
            KeyError: if a key has been registered as renamed.
            ValueError: if a value cannot be coerced to the existing type.
        """
        _assert_with_logging(
            len(cfg_list) % 2 == 0,
            "Override list has odd length: {}; it must be a list of pairs".
            format(cfg_list),
        )
        root = self
        for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
            if root.key_is_deprecated(full_key):
                continue
            if root.key_is_renamed(full_key):
                root.raise_key_rename_error(full_key)
            key_list = full_key.split(".")
            d = self
            # Walk down to the node that owns the last path component.
            for subkey in key_list[:-1]:
                _assert_with_logging(subkey in d,
                                     "Non-existent key: {}".format(full_key))
                d = d[subkey]
            subkey = key_list[-1]
            _assert_with_logging(subkey in d,
                                 "Non-existent key: {}".format(full_key))
            value = self._decode_cfg_value(v)
            value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey,
                                                     full_key)
            d[subkey] = value

    def freeze(self):
        """Make this CfgNode and all of its children immutable."""
        self._immutable(True)

    def defrost(self):
        """Make this CfgNode and all of its children mutable."""
        self._immutable(False)

    def is_frozen(self):
        """Return mutability."""
        return self.__dict__[CfgNode.IMMUTABLE]

    def _immutable(self, is_immutable):
        """Set immutability to is_immutable and recursively apply the setting
        to all nested CfgNodes.
        """
        self.__dict__[CfgNode.IMMUTABLE] = is_immutable
        # Recursively set immutable state
        for v in self.__dict__.values():
            if isinstance(v, CfgNode):
                v._immutable(is_immutable)
        for v in self.values():
            if isinstance(v, CfgNode):
                v._immutable(is_immutable)

    def clone(self):
        """Recursively copy this CfgNode."""
        return copy.deepcopy(self)

    def register_deprecated_key(self, key):
        """Register key (e.g. `FOO.BAR`) a deprecated option. When merging deprecated
        keys a warning is generated and the key is ignored.
        """
        _assert_with_logging(
            key not in self.__dict__[CfgNode.DEPRECATED_KEYS],
            "key {} is already registered as a deprecated key".format(key),
        )
        self.__dict__[CfgNode.DEPRECATED_KEYS].add(key)

    def register_renamed_key(self, old_name, new_name, message=None):
        """Register a key as having been renamed from `old_name` to `new_name`.
        When merging a renamed key, an exception is thrown alerting to user to
        the fact that the key has been renamed.
        """
        _assert_with_logging(
            old_name not in self.__dict__[CfgNode.RENAMED_KEYS],
            "key {} is already registered as a renamed cfg key".format(
                old_name),
        )
        value = new_name
        if message:
            value = (new_name, message)
        self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value

    def key_is_deprecated(self, full_key):
        """Test if a key is deprecated (logs a warning when it is)."""
        if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]:
            logger.warning(
                "Deprecated config key (ignoring): {}".format(full_key))
            return True
        return False

    def key_is_renamed(self, full_key):
        """Test if a key is renamed."""
        return full_key in self.__dict__[CfgNode.RENAMED_KEYS]

    def raise_key_rename_error(self, full_key):
        """Raise KeyError naming the new key registered for `full_key`."""
        new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key]
        if isinstance(new_key, tuple):
            msg = " Note: " + new_key[1]
            new_key = new_key[0]
        else:
            msg = ""
        raise KeyError(
            "Key {} was renamed to {}; please update your config.{}".format(
                full_key, new_key, msg))

    def is_new_allowed(self):
        """Return whether merging may add keys that this node does not have."""
        return self.__dict__[CfgNode.NEW_ALLOWED]

    @classmethod
    def load_cfg(cls, cfg_file_obj_or_str):
        """
        Load a cfg.

        Args:
            cfg_file_obj_or_str (str or file):
                Supports loading from:
                - A file object backed by a YAML file
                - A file object backed by a Python source file that exports an attribute
                  "cfg" that is either a dict or a CfgNode
                - A string that can be parsed as valid YAML
        """
        _assert_with_logging(
            isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str, )),
            "Expected first argument to be of type {} or {}, but it was {}".
            format(_FILE_TYPES, str, type(cfg_file_obj_or_str)),
        )
        if isinstance(cfg_file_obj_or_str, str):
            return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str)
        elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
            return cls._load_cfg_from_file(cfg_file_obj_or_str)
        else:
            raise NotImplementedError(
                "Impossible to reach here (unless there's a bug)")

    @classmethod
    def _load_cfg_from_file(cls, file_obj):
        """Load a config from a YAML file or a Python source file.

        The format is chosen from the extension of ``file_obj.name``.
        """
        _, file_extension = os.path.splitext(file_obj.name)
        if file_extension in _YAML_EXTS:
            return cls._load_cfg_from_yaml_str(file_obj.read())
        elif file_extension in _PY_EXTS:
            return cls._load_cfg_py_source(file_obj.name)
        else:
            raise Exception(
                "Attempt to load from an unsupported file type {}; "
                "only {} are supported".format(file_obj,
                                               _YAML_EXTS.union(_PY_EXTS)))

    @classmethod
    def _load_cfg_from_yaml_str(cls, str_obj):
        """Load a config from a YAML string encoding (parsed with yaml.safe_load)."""
        cfg_as_dict = yaml.safe_load(str_obj)
        return cls(cfg_as_dict)

    @classmethod
    def _load_cfg_py_source(cls, filename):
        """Load a config from a Python source file exporting a `cfg` attribute.

        Note: this executes the file as Python code.
        """
        module = _load_module_from_file("yacs.config.override", filename)
        _assert_with_logging(
            hasattr(module, "cfg"),
            "Python module from file {} must have 'cfg' attr".format(filename),
        )
        VALID_ATTR_TYPES = {dict, CfgNode}
        # isinstance (rather than an exact type match) also accepts subclasses
        # such as OrderedDict or CfgNode subclasses.
        _assert_with_logging(
            isinstance(module.cfg, tuple(VALID_ATTR_TYPES)),
            "Imported module 'cfg' attr must be in {} but is {} instead".
            format(VALID_ATTR_TYPES, type(module.cfg)),
        )
        return cls(module.cfg)

    @classmethod
    def _decode_cfg_value(cls, value):
        """
        Decodes a raw config value (e.g., from a yaml config files or command
        line argument) into a Python object.

        If the value is a dict, it will be interpreted as a new CfgNode.
        If the value is a str, it will be evaluated as literals.
        Otherwise it is returned as-is.
        """
        # Configs parsed from raw yaml will contain dictionary keys that need to be
        # converted to CfgNode objects
        if isinstance(value, dict):
            return cls(value)
        # All remaining processing is only applied to strings
        if not isinstance(value, str):
            return value
        # Try to interpret `value` as a:
        #   string, number, tuple, list, dict, boolean, or None
        try:
            value = literal_eval(value)
        # The following two excepts allow v to pass through when it represents a
        # string.
        #
        # Longer explanation:
        # The type of v is always a string (before calling literal_eval), but
        # sometimes it *represents* a string and other times a data structure, like
        # a list. In the case that v represents a string, what we got back from the
        # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
        # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
        # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
        # will raise a SyntaxError.
        except ValueError:
            pass
        except SyntaxError:
            pass
        return value
# Module-level alias of CfgNode.load_cfg, kept for backward compatibility.
load_cfg = CfgNode.load_cfg
def _valid_type(value, allow_cfg_node=False):
    """Return whether `value` may be stored in a CfgNode.

    Leaf values must have an exact type in ``_VALID_TYPES``; when
    `allow_cfg_node` is truthy, CfgNode instances are accepted as well.
    """
    if type(value) in _VALID_TYPES:
        return True
    return allow_cfg_node and isinstance(value, CfgNode)
def _merge_a_into_b(a, b, root, key_list):
    """Merge config dictionary `a` into config dictionary `b`, clobbering the
    options in `b` whenever they are also specified in `a`.

    Args:
        a (CfgNode): source config whose values take precedence.
        b (CfgNode): destination config; modified in place.
        root (CfgNode): root of the destination tree, consulted for deprecated
            and renamed keys.
        key_list (list[str]): path from the root to `a`/`b`, used to build the
            dotted full key for messages and deprecated/renamed lookups.

    Raises:
        AssertionError: if `a` or `b` is not a CfgNode.
        ValueError: if a value's type does not match (and cannot be coerced to)
            the type of the value it replaces.
        KeyError: if a key is missing from `b`, `b` does not allow new keys and
            the key is not registered as deprecated; or if the key is renamed.
    """
    _assert_with_logging(
        isinstance(a, CfgNode),
        "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode),
    )
    _assert_with_logging(
        isinstance(b, CfgNode),
        "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode),
    )
    for k, v_ in a.items():
        # str(k) so non-string keys (e.g. YAML integer keys) can be joined.
        sub_key_list = key_list + [str(k)]
        full_key = ".".join(sub_key_list)
        # Copy before decoding so `a` is never aliased into `b`.
        v = b._decode_cfg_value(copy.deepcopy(v_))
        if k in b:
            v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
            if isinstance(v, CfgNode):
                # Recursively merge nested nodes instead of replacing them.
                _merge_a_into_b(v, b[k], root, sub_key_list)
            else:
                b[k] = v
        elif b.is_new_allowed():
            b[k] = v
        elif root.key_is_deprecated(full_key):
            continue
        elif root.key_is_renamed(full_key):
            root.raise_key_rename_error(full_key)
        else:
            raise KeyError("Non-existent config key: {}".format(full_key))
def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
    """Checks that `replacement`, which is intended to replace `original` is of
    the right type. The type is correct if it matches exactly or is one of a few
    cases in which the type can be easily coerced.

    Allowed coercions: tuple -> list, list -> tuple, bool -> int (and, on
    py2, str -> unicode). Each coercion is logged as a warning.

    Args:
        replacement: the new value.
        original: the value being replaced; its type is the target type.
        key: the last component of the config key (unused, kept for the API).
        full_key (str): dotted config key, used in log and error messages.

    Returns:
        `replacement`, possibly converted to ``type(original)``.

    Raises:
        ValueError: if the types differ and no coercion applies.
    """
    original_type = type(original)
    replacement_type = type(replacement)

    # The types must match (with some exceptions)
    if replacement_type == original_type:
        return replacement

    # Cast replacement from from_type to to_type if the replacement and original
    # types match from_type and to_type
    def conditional_cast(from_type, to_type):
        if replacement_type == from_type and original_type == to_type:
            logger.warning('cast {} to {} for config key: {}', from_type,
                           to_type, full_key)
            return True, to_type(replacement)
        return False, None

    # Conditional casts: list <-> tuple, bool -> int
    casts = [(tuple, list), (list, tuple), (bool, int)]
    # For py2: allow converting from str (bytes) to a unicode string
    try:
        casts.append((str, unicode))  # noqa: F821
    except NameError:
        # `unicode` does not exist on py3.
        pass

    for from_type, to_type in casts:
        converted, converted_value = conditional_cast(from_type, to_type)
        if converted:
            return converted_value

    raise ValueError(
        "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
        "key: {}".format(original_type, replacement_type, original,
                         replacement, full_key))
def _assert_with_logging(cond, msg):
    """Log `msg` at debug level and raise ``AssertionError(msg)`` if `cond` is falsy.

    An explicit ``raise`` is used instead of an ``assert`` statement so this
    config validation still runs when Python is started with ``-O``.
    """
    if not cond:
        logger.debug(msg)
        raise AssertionError(msg)
def _load_module_from_file(name, filename):
    """Execute the Python source file `filename` as a module named `name`.

    Uses ``importlib.util`` on py3 and the legacy ``imp.load_source`` on py2.

    Returns:
        module: the executed module object.
    """
    if not _PY2:
        spec = importlib.util.spec_from_file_location(name, filename)
        loaded = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(loaded)
        return loaded
    return imp.load_source(name, filename)
def same_cfg(cfg_node, cfg_other):
    """Check that two (possibly nested) configs have identical keys and values.

    Both configs are flattened to ``{path: leaf}`` maps where ``path`` is the
    tuple of keys leading to a leaf. Tuple paths cannot collide the way
    separator-joined strings can (e.g. ``{'a_b': 1}`` vs ``{'a': {'b': 1}}``).
    Empty nested dicts contribute no leaves. The top-level key ``exp_key`` is
    excluded from the value comparison.

    Returns:
        bool: True when the configs match. On a mismatch the differences are
        printed/logged and an AssertionError is raised; ``False`` is returned
        only when assertions are disabled (``python -O``).
    """
    def flatten_dict(dd, prefix=()):
        # Map every leaf to the tuple of keys leading to it.
        if not isinstance(dd, dict):
            return {prefix: dd}
        flat = {}
        for k, v in dd.items():
            flat.update(flatten_dict(v, prefix + (k, )))
        return flat

    def fmt(path):
        # Human-readable dotted form of a key path, for messages only.
        return ".".join(str(p) for p in path)

    node_s = flatten_dict(cfg_node)
    other_s = flatten_dict(cfg_other)
    # Key views compare as sets: O(n) instead of sorting mixed-type keys.
    if node_s.keys() != other_s.keys():
        print(f'[LEN]: {len(node_s)} VS {len(other_s)}')
        k_diff1 = [fmt(k) for k in node_s if k not in other_s]
        k_diff0 = [fmt(k) for k in other_s if k not in node_s]
        print(f'[DIFF] keys: {k_diff1}; {k_diff0}')
        assert False, 'Diff key'
        return False
    for k, v in node_s.items():
        if k == ('exp_key', ):
            continue
        if other_s[k] != v:
            msg = f'{fmt(k)}: {v}; {other_s[k]}'
            logger.info(msg)
            assert False, 'Diff key value ' + msg
            return False
    return True