Coverage for agentlib_flexquant/utils/config_management.py: 99%

86 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-08-15 15:25 +0000

1import math 

2import inspect 

3import os 

4import importlib.util 

5from copy import deepcopy 

6from typing import TypeVar 

7from abc import ABCMeta 

8from agentlib.modules import get_all_module_types 

9from agentlib.core.agent import AgentConfig 

10from agentlib.core.module import BaseModuleConfig 

11 

12 

# Generic placeholder for a module config type; the bound restricts any
# substituted type to BaseModuleConfig and its subclasses.
T = TypeVar("T", bound=BaseModuleConfig)

14 

# Registered module types that this package never needs. The agentlib_mpc
# trainers/simulator/set point generator are dropped because importing their
# classes (done below when building MODULE_TYPE_DICT) is slow; "clonemap" is
# not used.
_UNUSED_MODULE_TYPES = (
    "agentlib_mpc.ann_trainer",
    "agentlib_mpc.gpr_trainer",
    "agentlib_mpc.linreg_trainer",
    "agentlib_mpc.ann_simulator",
    "agentlib_mpc.set_point_generator",
    "clonemap",
)

all_module_types = get_all_module_types(["agentlib_mpc", "agentlib_flexquant"])
for _unused_type in _UNUSED_MODULE_TYPES:
    # pop with a default: a type that the installed agentlib version does not
    # register must not make importing this file fail with a KeyError.
    all_module_types.pop(_unused_type, None)

# Map every remaining module type name to the class annotated as "config" on
# its module class. Note: this imports each module class.
MODULE_TYPE_DICT = {
    name: inspect.get_annotations(class_type.import_class())["config"]
    for name, class_type in all_module_types.items()
}

26 

# Module type names, written in the same "<package>.<module>" form as the
# keys of MODULE_TYPE_DICT (types without a package prefix, like "simulator",
# presumably come from agentlib core -- verify if relevant).

# The original MPC module from agentlib_mpc.
MPC_CONFIG_TYPE: str = "agentlib_mpc.mpc"

# MPC variants contributed by agentlib_flexquant.
BASELINEMPC_CONFIG_TYPE: str = "agentlib_flexquant.baseline_mpc"
SHADOWMPC_CONFIG_TYPE: str = "agentlib_flexquant.shadow_mpc"

# Further agentlib_flexquant modules.
INDICATOR_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_indicator"
MARKET_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_market"

# Simulator module type.
SIMULATOR_CONFIG_TYPE: str = "simulator"

33 

34 

def get_module_type_matching_dict(dictionary: dict) -> tuple[dict, dict]:
    """Match agentlib_mpc module types to their agentlib_flexquant counterparts.

    Two module types are counterparts when they map to the same value in
    ``dictionary``. This file calls it with MODULE_TYPE_DICT, so the shared
    value is the module's config class. Among the ``agentlib_flexquant.*``
    types sharing a value, names containing ``"baseline"`` become baseline
    matches. Names containing ``"shadow"`` become shadow matches. All other
    names are ignored. Keys outside both packages are ignored as well.

    Args:
        dictionary: Mapping of module type name to a hashable value.

    Returns:
        A tuple ``(baseline_matches, shadow_matches)``. Each element maps an
        ``agentlib_mpc.*`` type name to an ``agentlib_flexquant.*`` type name.
    """
    # Group the agentlib_mpc / agentlib_flexquant type names by shared value.
    value_to_keys = {}
    for key, value in dictionary.items():
        if key.startswith("agentlib_mpc."):
            side = "agentlib"
        elif key.startswith("agentlib_flexquant."):
            side = "flex"
        else:
            continue
        group = value_to_keys.setdefault(value, {"agentlib": [], "flex": []})
        group[side].append(key)

    baseline_matches = {}
    shadow_matches = {}
    for keys in value_to_keys.values():
        # Only values shared by both packages can produce a match.
        if not (keys["agentlib"] and keys["flex"]):
            continue
        for agent_key in keys["agentlib"]:
            for flex_key in keys["flex"]:
                # "baseline" wins if a name contains both words.
                if "baseline" in flex_key:
                    baseline_matches[agent_key] = flex_key
                elif "shadow" in flex_key:
                    shadow_matches[agent_key] = flex_key

    return baseline_matches, shadow_matches

68 

69 

# Lookup tables: agentlib_mpc module type -> matching agentlib_flexquant
# baseline / shadow module type (matched via a shared config class).
(BASELINE_MODULE_TYPE_DICT,
 SHADOW_MODULE_TYPE_DICT) = get_module_type_matching_dict(MODULE_TYPE_DICT)

72 

73 

def get_orig_module_type(config: AgentConfig) -> str:
    """Return the type of the agent's original MPC module.

    This is the ``"type"`` of the first module in ``config.modules`` whose
    type starts with ``"agentlib_mpc"``. If no module matches, ``None`` is
    returned.
    """
    return next(
        (
            module["type"]
            for module in config.modules
            if module["type"].startswith("agentlib_mpc")
        ),
        None,
    )

79 

80 

def get_module(config: AgentConfig, module_type: str) -> T:
    """Build the config object of the module with the given type.

    Args:
        config: Agent config whose ``modules`` entries are mappings with a
            ``"type"`` key.
        module_type: Module type to look for, e.g. ``"agentlib_mpc.mpc"``.

    Returns:
        An instance of the config class that MODULE_TYPE_DICT registers for
        ``module_type``. It is built from a deep copy of the first matching
        module entry, with the agent id passed as ``_agent_id``.

    Raises:
        ModuleNotFoundError: If no module in ``config`` has type ``module_type``.
        KeyError: If ``module_type`` is in the config but not in MODULE_TYPE_DICT.
    """
    for module in config.modules:
        if module["type"] != module_type:
            continue
        # deepcopy -> avoid changing the original config when editing the module.
        # Deep-copy the constructor arguments instead of the module object,
        # because the simulator module exceeds the recursion limit.
        config_id = deepcopy(config.id)
        mod = deepcopy(module)
        return MODULE_TYPE_DICT[mod["type"]](**mod, _agent_id=config_id)

    # Report the *requested* type. The loop variable would name the last module
    # in the config, and it is unbound when the config has no modules.
    raise ModuleNotFoundError(
        f"Module type {module_type} not found in the modules of agent "
        f"{config.id}."
    )

94 

95 

def to_dict_and_remove_unnecessary_fields(module: BaseModuleConfig) -> dict:
    """Serialize a module config to a dict without unnecessary fields.

    Top-level fields equal to their defaults are omitted. For each of the
    variable lists ``parameters``, ``inputs``, ``outputs``, ``controls`` and
    ``states`` that is still present after that filtering, every variable is
    re-serialized with two kinds of fields removed:

    - the metadata fields listed below;
    - ``lb`` / ``ub`` when they are infinite.

    This keeps the created json simple.
    """
    excluded_fields = {
        "rdf_class", "source", "type", "timestamp", "description", "unit",
        "clip", "shared", "interpolation_method", "allowed_values",
    }
    variable_fields = ("parameters", "inputs", "outputs", "controls", "states")

    def _fields_to_exclude(variable) -> set:
        """Return the excluded fields plus any bound of ``variable`` that is infinite."""
        exclude = set(excluded_fields)
        if variable.lb == -math.inf:
            exclude.add("lb")
        if variable.ub == math.inf:
            exclude.add("ub")
        return exclude

    parent_dict = module.dict(exclude_defaults=True)
    # Replace each surviving variable list with trimmed per-variable dicts.
    # Note: these are serialized without exclude_defaults.
    for field in variable_fields:
        if field in parent_dict:
            parent_dict[field] = [
                variable.dict(exclude=_fields_to_exclude(variable))
                for variable in getattr(module, field)
            ]

    return parent_dict

123 

124 

def get_class_from_file(file_path: str, class_name: str) -> ABCMeta:
    """Import a Python source file and return the class ``class_name`` from it.

    The file is executed as a fresh module named after its base name (without
    extension). The module is not registered in ``sys.modules``.

    Args:
        file_path: Path to the ``.py`` file. Relative paths are resolved
            against the current working directory.
        class_name: Name of the class to fetch from the executed module.

    Returns:
        The requested class object.

    Raises:
        ImportError: If no import spec/loader can be created for the file,
            e.g. because its extension is not a Python source suffix.
        AttributeError: If the module defines no attribute ``class_name``.
    """
    abs_path = os.path.abspath(file_path)
    module_name = os.path.splitext(os.path.basename(file_path))[0]

    spec = importlib.util.spec_from_file_location(module_name, abs_path)
    # spec_from_file_location returns None for unrecognised file suffixes.
    # Fail with a clear message instead of an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot import a Python module from file {abs_path}",
            name=module_name,
            path=abs_path,
        )

    module = importlib.util.module_from_spec(spec)
    # Executes the file's top-level code.
    spec.loader.exec_module(module)

    try:
        return getattr(module, class_name)
    except AttributeError as err:
        raise AttributeError(
            f"Class {class_name} not found in file {abs_path}"
        ) from err