Coverage for agentlib_flexquant/utils/config_management.py: 99%

105 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2026-03-26 09:43 +0000

1import importlib.util 

2import inspect 

3import math 

4import os 

5from abc import ABCMeta 

6from copy import deepcopy 

7from typing import TypeVar 

8 

9from agentlib.core.agent import AgentConfig 

10from agentlib.core.module import BaseModuleConfig 

11from agentlib.modules import get_all_module_types 

12 

T = TypeVar("T", bound=BaseModuleConfig)

# Module types excluded from the lookup tables below:
# - ML models (and the set point generator), since their import takes ages
# - clonemap, since it is not used
_EXCLUDED_MODULE_TYPES = (
    "agentlib_mpc.ann_trainer",
    "agentlib_mpc.gpr_trainer",
    "agentlib_mpc.linreg_trainer",
    "agentlib_mpc.ml_simulator",
    "agentlib_mpc.set_point_generator",
    "clonemap",
)

all_module_types = get_all_module_types(["agentlib_mpc", "agentlib_flexquant"])
for _excluded_type in _EXCLUDED_MODULE_TYPES:
    # pop with a default: a module type that is not registered (e.g. in another
    # agentlib_mpc version) must not make importing this module fail
    all_module_types.pop(_excluded_type, None)

# dictionary mapping the module name to the module config (ModelMetaclass),
# read from the "config" annotation of each module class
MODULE_TYPE_DICT = {
    name: inspect.get_annotations(class_type.import_class())["config"]
    for name, class_type in all_module_types.items()
}
# dictionary mapping the module name to the module (ModuleImport)
MODULE_NAME_DICT = all_module_types

MPC_CONFIG_TYPE: str = "agentlib_mpc.mpc"
BASELINEMPC_CONFIG_TYPE: str = "agentlib_flexquant.baseline_mpc"
SHADOWMPC_CONFIG_TYPE: str = "agentlib_flexquant.shadow_mpc"
BASELINEMINLPMPC_CONFIG_TYPE: str = "agentlib_flexquant.baseline_minlp_mpc"
SHADOWMINLPMPC_CONFIG_TYPE: str = "agentlib_flexquant.shadow_minlp_mpc"
INDICATOR_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_indicator"
MARKET_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_market"
SIMULATOR_CONFIG_TYPE: str = "simulator"

41 

42 

def get_module_type_matching_dict(dictionary: dict) -> tuple[dict, dict]:
    """Map the module types of agentlib_mpc modules to those of flexquant modules.

    A flexquant module (key starting with ``"agentlib_flexquant."``) is matched
    to an agentlib_mpc module (key starting with ``"agentlib_mpc."``) when the
    class of the agentlib_mpc module is the first direct base class of the
    flexquant module's class. The flexquant key ends up in the baseline result
    if it contains ``"baseline"``, otherwise in the shadow result if it contains
    ``"shadow"``. Flexquant modules without a matching parent are ignored.

    The matching does not depend on the order of ``dictionary``.

    Args:
        dictionary: Mapping of module type names to module imports, i.e.
            objects offering ``import_class()`` (e.g. MODULE_NAME_DICT).

    Returns:
        Tuple ``(baseline_matches, shadow_matches)``, each mapping an
        agentlib_mpc module type to the corresponding flexquant module type.
    """
    # First pass: group the agentlib_mpc keys by their module import. This is
    # done before any matching, so a flexquant module listed before its parent
    # is not silently dropped.
    value_to_keys = {}
    for key, module_import in dictionary.items():
        if key.startswith("agentlib_mpc."):
            if module_import not in value_to_keys:
                value_to_keys[module_import] = {"agentlib": [], "flex": []}
            value_to_keys[module_import]["agentlib"].append(key)

    # Second pass: attach every flexquant key to the agentlib_mpc module its
    # class directly derives from.
    for key, module_import in dictionary.items():
        if not key.startswith("agentlib_flexquant."):
            continue
        parent_class = module_import.import_class().__bases__[0]
        for mpc_import, keys in value_to_keys.items():
            if mpc_import.import_class() is parent_class:
                keys["flex"].append(key)
                break

    baseline_matches = {}
    shadow_matches = {}
    for keys in value_to_keys.values():
        # only groups with both an agentlib and a flexquant key produce matches
        if not (keys["agentlib"] and keys["flex"]):
            continue
        for agent_key in keys["agentlib"]:
            for flex_key in keys["flex"]:
                if "baseline" in flex_key:
                    baseline_matches[agent_key] = flex_key
                elif "shadow" in flex_key:
                    shadow_matches[agent_key] = flex_key

    return baseline_matches, shadow_matches

80 

81 

# lookup tables: agentlib_mpc module type -> flexquant baseline / shadow module type
BASELINE_MODULE_TYPE_DICT, SHADOW_MODULE_TYPE_DICT = get_module_type_matching_dict(
    MODULE_NAME_DICT
)

84 

85 

def get_orig_module_type(config: AgentConfig) -> str | None:
    """Return the config type of the original MPC.

    Args:
        config: Agent config whose ``modules`` entries are searched.

    Returns:
        The ``"type"`` of the first module whose type starts with
        ``"agentlib_mpc"``, or ``None`` if ``config`` has no such module.
    """
    return next(
        (
            module["type"]
            for module in config.modules
            if module["type"].startswith("agentlib_mpc")
        ),
        None,
    )

91 

92 

def get_module(config: AgentConfig, module_type: str) -> T:
    """Extract a module config from an agent config based on its type.

    Args:
        config: Agent config whose ``modules`` entries are searched.
        module_type: Value of the ``"type"`` entry of the wanted module.

    Returns:
        A new module config, built with the config class registered for
        ``module_type`` in MODULE_TYPE_DICT from a deep copy of the first
        matching module entry. Editing it does not change ``config``.

    Raises:
        ModuleNotFoundError: If no module in ``config`` has type ``module_type``.
    """
    for module in config.modules:
        if module["type"] == module_type:
            # deepcopy -> avoid changing the original config, when editing the
            # module. Deepcopy the args of the constructor instead of the
            # module object, because the simulator module exceeds the
            # recursion limit.
            config_id = deepcopy(config.id)
            mod = deepcopy(module)
            return MODULE_TYPE_DICT[mod["type"]](**mod, _agent_id=config_id)
    # Report the requested type: the loop variable would name the last
    # inspected module, and is undefined if config.modules is empty.
    raise ModuleNotFoundError(
        f"Module type {module_type} not found in the modules of agent "
        f"config '{config.id}'."
    )

107 

108 

def get_flex_mpc_module_config(
    agent_config: AgentConfig, module_config: BaseModuleConfig, module_type: str
):
    """Get a flexquant module config from an original agentlib module config.

    The fields of ``module_config`` are dumped, their ``type`` entry is set to
    ``module_type``, and the result is passed to the config class registered
    for ``module_type`` in MODULE_TYPE_DICT, together with the agent id.

    Args:
        agent_config: Agent config providing the agent id.
        module_config: Original module config to convert.
        module_type: Flexquant module type of the new config.

    Returns:
        The new flexquant module config. A ``r_del_u`` of ``None`` is
        replaced by an empty dict.
    """
    constructor_args = {**module_config.model_dump(), "type": module_type}
    flex_config_class = MODULE_TYPE_DICT[module_type]
    flex_config = flex_config_class(**constructor_args, _agent_id=agent_config.id)
    # HOTFIX due to AgentLib-MPC bug. Needs to be adapted after Objectives
    # in AgentLib-MPC are fixed.
    if flex_config.r_del_u is not None:
        return flex_config
    return flex_config.model_copy(update={"r_del_u": {}})

122 

123 

def to_dict_and_remove_unnecessary_fields(module: BaseModuleConfig) -> dict:
    """Remove unnecessary fields from the module to keep the created json simple.

    The module is dumped with ``exclude_defaults=True``. Every variable-list
    field present in that dump is then replaced by its variables, each dumped
    without the fields in ``excluded_fields``. ``lb`` and ``ub`` are also
    dropped when they are unbounded (``-inf`` / ``inf``). The variables of
    ``full_controls`` and ``vars_to_communicate`` keep their ``shared`` field.

    Args:
        module: The module config to dump.

    Returns:
        The cleaned-up dict representation of ``module``.
    """
    excluded_fields = [
        "rdf_class",
        "source",
        "type",
        "timestamp",
        "description",
        "unit",
        "clip",
        "shared",
        "interpolation_method",
        "allowed_values",
    ]
    # variable-list fields to clean up, in processing order, with a flag
    # telling whether their variables keep the "shared" field
    variable_fields = (
        ("parameters", False),
        ("inputs", False),
        ("outputs", False),
        ("controls", False),
        ("binary_controls", False),
        ("states", False),
        ("full_controls", True),
        ("vars_to_communicate", True),
    )

    def fields_to_exclude(variable, keep_shared: bool = False) -> list:
        """Return the fields to leave out when dumping ``variable``."""
        delete_list = excluded_fields.copy()
        if keep_shared:
            delete_list.remove("shared")
        if variable.lb == -math.inf:
            delete_list.append("lb")
        if variable.ub == math.inf:
            delete_list.append("ub")
        return delete_list

    parent_dict = module.model_dump(exclude_defaults=True)
    # update every variable with a dict excluding the defined fields
    for field_name, keep_shared in variable_fields:
        if field_name in parent_dict:
            # NOTE(review): `.dict()` is the pydantic v1-style API while the
            # rest of this file uses `model_dump`; kept as is until confirmed
            # equivalent for agentlib variables.
            parent_dict[field_name] = [
                variable.dict(exclude=fields_to_exclude(variable, keep_shared))
                for variable in getattr(module, field_name)
            ]

    return parent_dict

191 

192 

def get_class_from_file(file_path: str, class_name: str) -> ABCMeta:
    """Import a Python source file as a module and return one of its classes.

    The file is executed as a fresh module named after the file's base name
    (without extension). The module is not registered in ``sys.modules``.

    Args:
        file_path: Path to the Python source file (absolute, or relative to the
            current working directory).
        class_name: Name of the class defined in that file.

    Returns:
        The attribute ``class_name`` of the loaded module.

    Raises:
        ImportError: If no module spec or loader can be created for
            ``file_path`` (e.g. the file has no Python source suffix).
        AttributeError: If the loaded module has no attribute ``class_name``.
    """
    abs_path = os.path.abspath(file_path)
    module_name = os.path.splitext(os.path.basename(file_path))[0]

    spec = importlib.util.spec_from_file_location(module_name, abs_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None when no loader handles the
        # file's suffix; fail with a clear message instead of an AttributeError
        raise ImportError(f"Cannot load a Python module from file '{abs_path}'.")

    module = importlib.util.module_from_spec(spec)
    # runs the top-level code of the file
    spec.loader.exec_module(module)

    return getattr(module, class_name)