Coverage for agentlib_flexquant/utils/config_management.py: 99%

100 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-10-20 14:09 +0000

1import importlib.util 

2import inspect 

3import math 

4import os 

5from abc import ABCMeta 

6from copy import deepcopy 

7from typing import TypeVar 

8 

9from agentlib.core.agent import AgentConfig 

10from agentlib.core.module import BaseModuleConfig 

11from agentlib.modules import get_all_module_types 

12 

# Type variable for functions that return some concrete module config.
T = TypeVar("T", bound=BaseModuleConfig)

all_module_types = get_all_module_types(["agentlib_mpc", "agentlib_flexquant"])

# Drop entries from the registry, in this order:
#   - the ML models, since their import takes ages
#   - clonemap, since it is not used
for _dropped_type in (
    "agentlib_mpc.ann_trainer",
    "agentlib_mpc.gpr_trainer",
    "agentlib_mpc.linreg_trainer",
    "agentlib_mpc.ann_simulator",
    "agentlib_mpc.set_point_generator",
    "clonemap",
):
    # pop without a default: a missing entry raises KeyError
    all_module_types.pop(_dropped_type)
del _dropped_type  # keep the module namespace free of the loop variable

# module name -> module config class (ModelMetaclass), read from the "config"
# annotation of the imported module class
MODULE_TYPE_DICT = {
    module_name: inspect.get_annotations(module_import.import_class())["config"]
    for module_name, module_import in all_module_types.items()
}
# module name -> module (ModuleImport); same object as all_module_types
MODULE_NAME_DICT = all_module_types

# Module type identifiers used in agent configs
MPC_CONFIG_TYPE: str = "agentlib_mpc.mpc"
BASELINEMPC_CONFIG_TYPE: str = "agentlib_flexquant.baseline_mpc"
SHADOWMPC_CONFIG_TYPE: str = "agentlib_flexquant.shadow_mpc"
BASELINEMINLPMPC_CONFIG_TYPE: str = "agentlib_flexquant.baseline_minlp_mpc"
SHADOWMINLPMPC_CONFIG_TYPE: str = "agentlib_flexquant.shadow_minlp_mpc"
INDICATOR_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_indicator"
MARKET_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_market"
SIMULATOR_CONFIG_TYPE: str = "simulator"

41 

42 

def get_module_type_matching_dict(dictionary: dict) -> tuple[dict, dict]:
    """Map agentlib_mpc module types to their flexquant counterparts.

    A flexquant module is matched to the agentlib_mpc module whose imported
    class is the *first* base class of the imported flexquant class.

    Args:
        dictionary: Maps module type names to objects that provide an
            ``import_class()`` method and are hashable (e.g. MODULE_NAME_DICT).

    Returns:
        A tuple ``(baseline_matches, shadow_matches)``. Both map an
        ``agentlib_mpc.*`` type name to the ``agentlib_flexquant.*`` type name
        derived from it; the flexquant name contains ``"baseline"`` for the
        first dict and ``"shadow"`` for the second.

    """
    # First pass: register every agentlib_mpc module. Doing this before looking
    # at any flexquant module makes the result independent of the iteration
    # order of ``dictionary``.
    value_to_keys = {}
    for key, value in dictionary.items():
        if key.startswith("agentlib_mpc."):
            group = value_to_keys.setdefault(value, {"agentlib": [], "flex": []})
            group["agentlib"].append(key)

    # Second pass: attach each flexquant module to its agentlib_mpc parent.
    for key, value in dictionary.items():
        if not key.startswith("agentlib_flexquant."):
            continue
        # the direct parent does not change while scanning the candidates
        direct_parent = value.import_class().__bases__[0]
        for candidate, group in value_to_keys.items():
            if candidate.import_class() is direct_parent:
                group["flex"].append(key)
                break

    baseline_matches = {}
    shadow_matches = {}
    for group in value_to_keys.values():
        # with either list empty the nested loops below simply do not run
        for agent_key in group["agentlib"]:
            for flex_key in group["flex"]:
                if "baseline" in flex_key:
                    baseline_matches[agent_key] = flex_key
                elif "shadow" in flex_key:
                    shadow_matches[agent_key] = flex_key

    return baseline_matches, shadow_matches

80 

81 

# agentlib_mpc module type -> matching flexquant baseline / shadow module type,
# derived once at import time from the module registry
(
    BASELINE_MODULE_TYPE_DICT,
    SHADOW_MODULE_TYPE_DICT,
) = get_module_type_matching_dict(MODULE_NAME_DICT)

83 

84 

def get_orig_module_type(config: AgentConfig) -> str | None:
    """Return the config type of the original MPC.

    Args:
        config: Agent config whose ``modules`` entries are indexable by
            ``"type"``.

    Returns:
        The ``"type"`` of the first module whose type starts with
        ``"agentlib_mpc"``, or ``None`` if the config holds no such module.

    """
    for module in config.modules:
        module_type = module["type"]
        if module_type.startswith("agentlib_mpc"):
            return module_type
    # no agentlib_mpc module in this config
    return None

90 

91 

def get_module(config: AgentConfig, module_type: str) -> T:
    """Extract the module config of the given type from an agent config.

    Args:
        config: Agent config to search; its ``modules`` entries are indexable
            by ``"type"``.
        module_type: Module type to look for, e.g. ``"agentlib_mpc.mpc"``.

    Returns:
        A new config object for the first module whose ``"type"`` equals
        ``module_type``, built from copies of the module entry and agent id.

    Raises:
        ModuleNotFoundError: If no module of ``module_type`` is in the config
            (including a config without any modules).

    """
    for module in config.modules:
        if module["type"] == module_type:
            # deepcopy -> avoid changing the original config, when editing the module
            # deepcopy the args of the constructor instead of the module object,
            # because the simulator module exceeds the recursion limit
            config_id = deepcopy(config.id)
            mod = deepcopy(module)
            return MODULE_TYPE_DICT[mod["type"]](**mod, _agent_id=config_id)

    # Report the type that was asked for; the loop variable would name the last
    # module inspected and is unbound when there are no modules at all.
    raise ModuleNotFoundError(
        f"Module type {module_type} not found in agentlib and its plug ins."
    )

106 

107 

def get_flex_mpc_module_config(
    agent_config: AgentConfig, module_config: BaseModuleConfig, module_type: str
):
    """Create a flexquant module config out of an original agentlib module config.

    The dumped fields of ``module_config`` are reused, with ``"type"`` set to
    ``module_type``. The config class registered for ``module_type`` in
    MODULE_TYPE_DICT is instantiated with these fields and the id of
    ``agent_config``.
    """
    flex_fields = {**module_config.model_dump(), "type": module_type}
    flex_config_class = MODULE_TYPE_DICT[module_type]
    return flex_config_class(**flex_fields, _agent_id=agent_config.id)

115 

116 

def to_dict_and_remove_unnecessary_fields(module: BaseModuleConfig) -> dict:
    """Remove unnecessary fields from the module to keep the created json simple.

    Args:
        module: Module config to convert.

    Returns:
        ``module.model_dump(exclude_defaults=True)`` in which every variable
        list present in the dump is rebuilt from the variables on ``module``,
        each dumped without the excluded fields. ``lb`` and ``ub`` are also
        left out when they are infinite. Variables in ``full_controls`` keep
        their ``shared`` field.

    """
    excluded_fields = [
        "rdf_class",
        "source",
        "type",
        "timestamp",
        "description",
        "unit",
        "clip",
        "shared",
        "interpolation_method",
        "allowed_values",
    ]
    # variable lists that all get the same treatment
    variable_fields = (
        "parameters",
        "inputs",
        "outputs",
        "controls",
        "binary_controls",
        "states",
    )

    def fields_to_exclude(variable, keep=()):
        """List the fields to drop for one variable, sparing those in ``keep``."""
        exclude = [name for name in excluded_fields if name not in keep]
        # unbounded limits carry no information
        if variable.lb == -math.inf:
            exclude.append("lb")
        if variable.ub == math.inf:
            exclude.append("ub")
        return exclude

    parent_dict = module.model_dump(exclude_defaults=True)
    # update every variable with a dict excluding the defined fields
    for field_name in variable_fields:
        if field_name in parent_dict:
            parent_dict[field_name] = [
                variable.dict(exclude=fields_to_exclude(variable))
                for variable in getattr(module, field_name)
            ]
    if "full_controls" in parent_dict:
        parent_dict["full_controls"] = [
            full_control.dict(exclude=fields_to_exclude(full_control, keep=("shared",)))
            for full_control in module.full_controls
        ]

    return parent_dict

172 

173 

def get_class_from_file(file_path: str, class_name: str) -> ABCMeta:
    """Load a Python file as a module and return a class defined in it.

    The module is named after the file (basename without extension) and is
    executed on every call.

    Args:
        file_path: Path to the Python file, absolute or relative to the
            current working directory.
        class_name: Name of the attribute to fetch from the loaded module.

    Returns:
        The object bound to ``class_name`` in the loaded module.

    Raises:
        ImportError: If no import spec or loader can be created for
            ``file_path`` (e.g. unsupported file suffix).
        AttributeError: If the module does not define ``class_name``.

    """
    # Get the absolute path if needed
    abs_path = os.path.abspath(file_path)

    # Get the module name from the file path
    module_name = os.path.splitext(os.path.basename(file_path))[0]

    # Load the module specification; this is None when no loader handles the file
    spec = importlib.util.spec_from_file_location(module_name, abs_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for file {abs_path}")

    # Create and execute the module
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Get the class from the module
    return getattr(module, class_name)

192 

193 return target_class