Coverage for agentlib_flexquant/utils/config_management.py: 99%

123 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2026-06-17 09:09 +0000

1import importlib.util 

2import inspect 

3import math 

4import os 

5from abc import ABCMeta 

6from copy import deepcopy 

7from typing import TypeVar, Union, Optional 

8 

9from agentlib.core.agent import AgentConfig 

10from agentlib.core.module import BaseModuleConfig 

11from agentlib.modules import get_all_module_types 

12 

13 

# Type variable for the module config object returned by ModuleHandler.get_module;
# bound to agentlib's BaseModuleConfig.
T = TypeVar("T", bound=BaseModuleConfig)

15 

# ---- module "type" identifiers used in agent configs ----

# original (non-flexibility) MPC from agentlib_mpc
MPC_CONFIG_TYPE: str = "agentlib_mpc.mpc"

# FlexQuant baseline MPCs (continuous and mixed-integer variant)
BASELINEMPC_CONFIG_TYPE: str = "agentlib_flexquant.baseline_mpc"
BASELINEMINLPMPC_CONFIG_TYPE: str = "agentlib_flexquant.baseline_minlp_mpc"

# FlexQuant shadow MPCs (continuous and mixed-integer variant)
SHADOWMPC_CONFIG_TYPE: str = "agentlib_flexquant.shadow_mpc"
SHADOWMINLPMPC_CONFIG_TYPE: str = "agentlib_flexquant.shadow_minlp_mpc"

# FlexQuant post-processing modules
INDICATOR_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_indicator"
MARKET_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_market"

# agentlib simulator module
SIMULATOR_CONFIG_TYPE: str = "simulator"

24 

25 

class ModuleHandler:
    """Discovers AgentLib module types and looks up their configuration models.

    The handler builds registries of the available modules from a set of plugin
    packages, optionally excluding module types that are slow to import (the ML
    modules of agentlib_mpc) or unused by FlexQuant (clonemap).

    Registries built by ``generate_module_dicts``:
        - module_type_dict: module type string -> module config class, taken from
          the ``config`` annotation of the module class.
        - module_name_dict: module type string -> module import entry, as returned
          by ``agentlib.modules.get_all_module_types``.
        - baseline_module_type_dict: agentlib_mpc module type -> FlexQuant
          baseline module type.
        - shadow_module_type_dict: agentlib_mpc module type -> FlexQuant
          shadow-MPC module type.
    """

    # agentlib_mpc module types removed when ``exclude_ml_plugins`` is True,
    # because importing them takes a long time.
    _ML_MODULE_TYPES: tuple[str, ...] = (
        "agentlib_mpc.ann_trainer",
        "agentlib_mpc.gpr_trainer",
        "agentlib_mpc.linreg_trainer",
        "agentlib_mpc.ml_simulator",
        "agentlib_mpc.set_point_generator",
    )
    # module type removed when ``exclude_clonemap_plugin`` is True (not used by FlexQuant)
    _CLONEMAP_MODULE_TYPE: str = "clonemap"

    def __init__(
        self,
        extra_plugins: Optional[list[str]] = None,
        exclude_ml_plugins: bool = True,
        exclude_clonemap_plugin: bool = True,
    ):
        """Create the handler and immediately build its module registries.

        Args:
            extra_plugins: Optional list of additional plugin package names to
                include in module discovery, in addition to the default plugins
                "agentlib_mpc" and "agentlib_flexquant".
            exclude_ml_plugins: If True, excludes the ML-related agentlib_mpc module
                types that are expensive/slow to import.
            exclude_clonemap_plugin: If True, excludes the "clonemap" module type
                (not used by FlexQuant).
        """
        default_plugins = ["agentlib_mpc", "agentlib_flexquant"]
        # keep the given order, drop duplicates (dicts preserve insertion order)
        self.plugin_modules = list(
            dict.fromkeys(default_plugins + (extra_plugins or []))
        )

        self.exclude_ml_plugins = exclude_ml_plugins
        self.exclude_clonemap_plugin = exclude_clonemap_plugin

        self.module_type_dict = {}
        self.module_name_dict = {}
        self.baseline_module_type_dict = {}
        self.shadow_module_type_dict = {}

        self.generate_module_dicts()

    def generate_module_dicts(self):
        """(Re)build all module registries from ``self.plugin_modules``."""
        all_module_types = get_all_module_types(self.plugin_modules)

        excluded_types = []
        if self.exclude_ml_plugins:
            excluded_types.extend(self._ML_MODULE_TYPES)
        if self.exclude_clonemap_plugin:
            excluded_types.append(self._CLONEMAP_MODULE_TYPE)
        for module_type in excluded_types:
            all_module_types.pop(module_type, None)

        # module type -> module config class (ModelMetaclass)
        self.module_type_dict = {
            name: self._get_config_type(module_import.import_class())
            for name, module_import in all_module_types.items()
        }
        # module type -> module import entry (ModuleImport)
        self.module_name_dict = all_module_types

        # agentlib_mpc type -> FlexQuant baseline / shadow type
        self.baseline_module_type_dict, self.shadow_module_type_dict = (
            get_module_type_matching_dict(self.module_name_dict)
        )

    @staticmethod
    def _get_config_type(module_class: type):
        """Return the class in the ``config`` annotation of a module class.

        The MRO is searched so that a module class that inherits its ``config``
        annotation from a base class (without redeclaring it) is supported.

        Raises:
            KeyError: if no class in the MRO annotates ``config``.
        """
        for klass in module_class.__mro__:
            annotations = inspect.get_annotations(klass)
            if "config" in annotations:
                return annotations["config"]
        raise KeyError(
            f"Module class {module_class.__name__} has no 'config' annotation."
        )

    def get_module(self, config: AgentConfig, module_type: str) -> T:
        """Build the config of the first module of a given type in an agent config.

        Args:
            config: Agent config; its ``modules`` entries are indexed like dicts
                (``module["type"]``).
            module_type: Module type string to look for.

        Returns:
            A new module config instance. The agent config is left unchanged.

        Raises:
            ModuleNotFoundError: if ``config`` contains no module of ``module_type``.
        """
        for module in config.modules:
            if module["type"] != module_type:
                continue
            # deepcopy -> avoid changing the original config when editing the module.
            # Deepcopy the constructor arguments instead of the module object,
            # because the simulator module exceeds the recursion limit.
            config_id = deepcopy(config.id)
            mod = deepcopy(module)
            return self.module_type_dict[mod["type"]](**mod, _agent_id=config_id)
        raise ModuleNotFoundError(
            f"Module type {module_type} not found in the modules of agent "
            f"{config.id}."
        )

    def get_flex_mpc_module_config(
        self,
        agent_config: AgentConfig,
        module_config: BaseModuleConfig,
        module_type: str,
    ):
        """Convert an agentlib module config into the config of another module type.

        The module config is dumped, its "type" is replaced by ``module_type``, and
        the result is validated with the config class registered for that type.

        Args:
            agent_config: Agent the module belongs to (its id is passed on).
            module_config: Original (e.g. agentlib_mpc) module config.
            module_type: Target (FlexQuant) module type.

        Returns:
            The new module config instance.
        """
        config_dict = module_config.model_dump()
        config_dict["type"] = module_type
        flex_config = self.module_type_dict[module_type](
            **config_dict, _agent_id=agent_config.id
        )
        # HOTFIX due to AgentLib-MPC bug. Needs to be adapted after Objectives
        # in AgentLib-MPC are fixed.
        if flex_config.r_del_u is None:
            flex_config = flex_config.model_copy(update={"r_del_u": {}})
        return flex_config

131 

132 

def get_module_type_matching_dict(dictionary: dict) -> tuple[dict, dict]:
    """Map agentlib_mpc module types to their FlexQuant baseline/shadow counterparts.

    A FlexQuant module (type prefixed "agentlib_flexquant.") is matched to the
    agentlib_mpc module whose class is the first direct base class of the
    FlexQuant module class. Matched FlexQuant types containing "baseline" go into
    the first result, those containing "shadow" into the second.

    Args:
        dictionary: Module type string -> module import entry exposing
            ``import_class()`` (e.g. ``ModuleHandler.module_name_dict``).

    Returns:
        ``(baseline_matches, shadow_matches)``, each mapping an agentlib_mpc module
        type to a FlexQuant module type.
    """
    # Group agentlib_mpc type names by their import entry (several type names may
    # point to the same module). This runs as a separate first pass so that the
    # matching below does not depend on the order of ``dictionary``.
    value_to_keys = {}
    for key, module_import in dictionary.items():
        if key.startswith("agentlib_mpc."):
            group = value_to_keys.setdefault(
                module_import, {"agentlib": [], "flex": []}
            )
            group["agentlib"].append(key)

    # attach every FlexQuant module to the agentlib_mpc module it derives from
    for key, module_import in dictionary.items():
        if not key.startswith("agentlib_flexquant."):
            continue
        parent_class = module_import.import_class().__bases__[0]
        for agentlib_import, group in value_to_keys.items():
            if agentlib_import.import_class() is parent_class:
                group["flex"].append(key)
                break

    baseline_matches = {}
    shadow_matches = {}
    for group in value_to_keys.values():
        # map each agentlib key to every FlexQuant key of the same group
        for agent_key in group["agentlib"]:
            for flex_key in group["flex"]:
                if "baseline" in flex_key:
                    baseline_matches[agent_key] = flex_key
                elif "shadow" in flex_key:
                    shadow_matches[agent_key] = flex_key

    return baseline_matches, shadow_matches

170 

def get_orig_module_type(config: AgentConfig) -> Optional[str]:
    """Return the type of the first agentlib_mpc module in an agent config.

    Args:
        config: Agent config; its ``modules`` entries are indexed like dicts
            (``module["type"]``).

    Returns:
        The module type string (prefixed "agentlib_mpc."), or None if the agent
        has no agentlib_mpc module.
    """
    return next(
        (
            module["type"]
            for module in config.modules
            # include the dot so that other packages merely starting with
            # "agentlib_mpc" (e.g. "agentlib_mpc_ext") are not matched
            if module["type"].startswith("agentlib_mpc.")
        ),
        None,
    )

176 

def to_dict_and_remove_unnecessary_fields(module: BaseModuleConfig) -> dict:
    """Dump a module config to a dict, removing fields not needed in the created JSON.

    Top-level fields equal to their defaults are dropped. Every variable in the
    variable-list fields is dumped without the metadata fields in
    ``excluded_fields``, and without ``lb``/``ub`` when they are -inf/+inf.
    For ``full_controls`` and ``vars_to_communicate`` the ``shared`` field is kept.

    Args:
        module: The module config to serialize.

    Returns:
        The reduced dictionary representation of ``module``.
    """
    excluded_fields = (
        "rdf_class",
        "source",
        "type",
        "timestamp",
        "description",
        "unit",
        "clip",
        "shared",
        "interpolation_method",
        "allowed_values",
    )
    # variable-list fields whose entries are dumped without "shared"
    variable_fields = (
        "parameters",
        "inputs",
        "outputs",
        "controls",
        "binary_controls",
        "states",
    )
    # variable-list fields whose entries keep "shared"
    shared_variable_fields = ("full_controls", "vars_to_communicate")

    def excluded_for(variable, keep_shared: bool) -> set:
        """Return the set of field names to exclude when dumping ``variable``."""
        excluded = set(excluded_fields)
        if keep_shared:
            excluded.discard("shared")
        # unbounded limits are the defaults and only clutter the JSON
        if variable.lb == -math.inf:
            excluded.add("lb")
        if variable.ub == math.inf:
            excluded.add("ub")
        return excluded

    parent_dict = module.model_dump(exclude_defaults=True)
    for field_names, keep_shared in (
        (variable_fields, False),
        (shared_variable_fields, True),
    ):
        for field_name in field_names:
            # only fields that differ from their default survived exclude_defaults
            if field_name in parent_dict:
                parent_dict[field_name] = [
                    variable.model_dump(exclude=excluded_for(variable, keep_shared))
                    for variable in getattr(module, field_name)
                ]

    return parent_dict

244 

def get_class_from_file(file_path: str, class_name: str) -> ABCMeta:
    """Execute a Python source file as a module and return one of its classes.

    The module is named after the file's base name (without extension). It is
    created with ``importlib.util.module_from_spec`` and executed; this function
    does not add it to ``sys.modules``.

    Args:
        file_path: Path to the Python file (relative or absolute).
        class_name: Name of the class to fetch from the executed module.

    Returns:
        The attribute ``class_name`` of the executed module.

    Raises:
        ImportError: if no import spec/loader can be created for ``file_path``
            (e.g. the file does not have a Python source suffix).
        AttributeError: if the module defines no attribute ``class_name``.
    """
    abs_path = os.path.abspath(file_path)
    module_name = os.path.splitext(os.path.basename(file_path))[0]

    spec = importlib.util.spec_from_file_location(module_name, abs_path)
    # spec_from_file_location returns None for unsupported file suffixes;
    # fail with a clear ImportError instead of an AttributeError on None
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot load a Python module from file '{abs_path}'.",
            path=abs_path,
        )

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    return getattr(module, class_name)

263 

264 return target_class