# Source code for pyabsa.framework.configuration_class.configuration_template

# -*- coding: utf-8 -*-
# file: checkpoint_template.py
# time: 02/11/2022 15:44
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# GScholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# ResearchGate: https://www.researchgate.net/profile/Heng-Yang-17/research
# Copyright (C) 2022. All Rights Reserved.

from argparse import Namespace
from pyabsa.framework.configuration_class.config_verification import config_check
from pyabsa.utils.pyabsa_utils import fprint


class ConfigManager(Namespace):
    """
    A dict-backed ``argparse.Namespace`` that counts how often each parameter is read.

    Parameters are stored in ``self.args`` and can be accessed as attributes
    (``config.name``) or as items (``config["name"]``). Attribute reads and
    :meth:`get` increment ``self.args_call_count[name]``; item reads do not.
    Names found in the parameter dict take precedence over ordinary attributes
    (including methods) on attribute access.
    """

    def __init__(self, args=None, **kwargs):
        """
        The ConfigManager is a subclass of argparse.Namespace and based on a parameter dict.
        It also counts the call-frequency of each parameter.

        :param args: A parameter dict, an ``argparse.Namespace`` (its attributes are used),
            or another ConfigManager (its parameter dict is shared). The given dict is used
            directly, not copied, so changes made through this object are visible to it.
        :param kwargs: Same params as Namespace; they are stored in the parameter dict.
        """
        if not args:
            args = {}
        if isinstance(args, ConfigManager):
            # vars() of a ConfigManager would yield {"args": ..., "args_call_count": ...},
            # not its parameters, so share its parameter dict instead.
            self.args = args.args
        elif isinstance(args, Namespace):
            self.args = vars(args)
        else:
            self.args = args
        self.args_call_count = {arg: 0 for arg in self.args}
        # Initialise the Namespace only after the parameter dict exists so that keyword
        # arguments go through __setattr__ into self.args (and get a call counter)
        # instead of becoming plain instance attributes.
        super().__init__(**kwargs)

    def __getattribute__(self, arg_name):
        """
        Get the value of an argument and increment its call count.

        :param arg_name: The name of the argument.
        :return: The value of the argument, or the ordinary attribute when ``arg_name``
            is not a parameter.
        :raises AttributeError: If ``arg_name`` is neither a parameter nor an attribute.
        """
        if arg_name == "args" or arg_name == "args_call_count":
            return super().__getattribute__(arg_name)
        try:
            value = super().__getattribute__("args")[arg_name]
        except (AttributeError, KeyError, TypeError):
            # Not a parameter, or the parameter dict does not exist yet (e.g. while the
            # object is being copied/unpickled): fall back to normal attribute lookup.
            return super().__getattribute__(arg_name)
        args_call_count = super().__getattribute__("args_call_count")
        # .get() keeps keys that were added to self.args directly (never registered in
        # the counter) readable instead of raising.
        args_call_count[arg_name] = args_call_count.get(arg_name, 0) + 1
        return value

    def __setattr__(self, arg_name, value):
        """
        Set the value of an argument and add it to the argument dict and call count dict.

        An existing parameter keeps its call count; a new parameter starts at 0.

        :param arg_name: The name of the argument.
        :param value: The value of the argument.
        """
        if arg_name == "args" or arg_name == "args_call_count":
            super().__setattr__(arg_name, value)
            return
        try:
            args = super().__getattribute__("args")
            args_call_count = super().__getattribute__("args_call_count")
            args[arg_name] = value
        except (AttributeError, TypeError):
            # No usable parameter dict: store it as a plain instance attribute.
            super().__setattr__(arg_name, value)
            return
        args_call_count.setdefault(arg_name, 0)

    def get(self, key, default=None):
        """
        Get the value of a key from the parameter dict.
        If the key is found, increment its call frequency.

        :param key: The key to look for in the parameter dict.
        :param default: The default value to return if the key is not found.
        :return: The value of the key in the parameter dict, or the default value if the key is not found.
        """
        if key in self.args:
            self.args_call_count[key] = self.args_call_count.get(key, 0) + 1
        return self.args.get(key, default)

    def update(self, *args, **kwargs):
        """
        Update the parameter dict with the given arguments and keyword arguments,
        and check if the updated configuration is valid.

        Newly added keys are registered in the call count dict with a count of 0.

        :param args: Positional arguments to update the parameter dict.
        :param kwargs: Keyword arguments to update the parameter dict.
        """
        self.args.update(*args, **kwargs)
        for key in self.args:
            self.args_call_count.setdefault(key, 0)
        config_check(self.args)

    def pop(self, *args):
        """
        Pop a value from the parameter dict.

        :param args: Arguments to pop from the parameter dict (key, optional default).
        :return: The value popped from the parameter dict.
        """
        return self.args.pop(*args)

    def keys(self):
        """
        Get a view of all keys in the parameter dict.

        :return: The keys of the parameter dict.
        """
        return self.args.keys()

    def values(self):
        """
        Get a view of all values in the parameter dict.

        :return: The values of the parameter dict.
        """
        return self.args.values()

    def items(self):
        """
        Get a view of all key-value pairs in the parameter dict.

        :return: The key-value pairs of the parameter dict.
        """
        return self.args.items()

    def __str__(self):
        """
        Get a string representation of the parameter dict.

        :return: A string representation of the parameter dict.
        """
        return str(self.args)

    def __repr__(self):
        """
        Return the string representation of the parameter dict.
        """
        return repr(self.args)

    def __len__(self):
        """
        Return the number of items in the parameter dict.
        """
        return len(self.args)

    def __iter__(self):
        """
        Return an iterator over the keys of the parameter dict.
        """
        return iter(self.args)

    def __contains__(self, item):
        """
        Check if the given item is in the parameter dict.

        :param item: The item to check.
        :return: True if the item is in the parameter dict, False otherwise.
        """
        return item in self.args

    def __getitem__(self, item):
        """
        Get the value of a key from the parameter dict (does not change its call count).

        :param item: The key to look for in the parameter dict.
        :return: The value of the key in the parameter dict.
        """
        return self.args[item]

    def __setitem__(self, key, value):
        """
        Set the value of a key in the parameter dict.
        Also set the call frequency of the key to 0 and check if the updated configuration is valid.

        :param key: The key to set the value for in the parameter dict.
        :param value: The value to set for the key in the parameter dict.
        """
        self.args[key] = value
        self.args_call_count[key] = 0
        config_check(self.args)

    def __delitem__(self, key):
        """
        Delete a key-value pair from the parameter dict and check if the updated configuration is valid.

        :param key: The key to delete from the parameter dict.
        """
        del self.args[key]
        config_check(self.args)

    def __eq__(self, other):
        """
        Check if the parameter dict is equal to another object.

        :param other: The other object to compare with the parameter dict.
        :return: True if the parameter dict is equal to the other object, False otherwise.
        """
        return self.args == other

    def __ne__(self, other):
        """
        Check if the parameter dict is not equal to another object.

        :param other: The other object to compare with the parameter dict.
        :return: True if the parameter dict is not equal to the other object, False otherwise.
        """
        return self.args != other
if __name__ == "__main__":
    # Smoke test: overwrite two existing parameters, add a new one, then print
    # each value followed by the per-parameter read counts.
    config = ConfigManager({"a": 1, "b": 2})
    for _name, _value in (("a", 2), ("b", 3), ("c", 4)):
        setattr(config, _name, _value)
    for _name in ("a", "b", "c"):
        fprint(getattr(config, _name))
    fprint(config.args_call_count)