Coverage for agentlib_flexquant/utils/config_management.py: 99%
123 statements
« prev ^ index » next coverage.py v7.4.4, created at 2026-06-17 09:09 +0000
1import importlib.util
2import inspect
3import math
4import os
5from abc import ABCMeta
6from copy import deepcopy
7from typing import TypeVar, Union, Optional
9from agentlib.core.agent import AgentConfig
10from agentlib.core.module import BaseModuleConfig
11from agentlib.modules import get_all_module_types
# Type variable for the module config models produced by ModuleHandler.get_module().
T = TypeVar("T", bound=BaseModuleConfig)

# --- Module type identifiers of the original (agentlib / agentlib_mpc) modules ---
MPC_CONFIG_TYPE: str = "agentlib_mpc.mpc"
SIMULATOR_CONFIG_TYPE: str = "simulator"

# --- Module type identifiers of the FlexQuant MPC variants ---
BASELINEMPC_CONFIG_TYPE: str = "agentlib_flexquant.baseline_mpc"
SHADOWMPC_CONFIG_TYPE: str = "agentlib_flexquant.shadow_mpc"
BASELINEMINLPMPC_CONFIG_TYPE: str = "agentlib_flexquant.baseline_minlp_mpc"
SHADOWMINLPMPC_CONFIG_TYPE: str = "agentlib_flexquant.shadow_minlp_mpc"

# --- Module type identifiers of the other FlexQuant modules ---
INDICATOR_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_indicator"
MARKET_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_market"
class ModuleHandler:
    """Discovers AgentLib module types and looks up their configuration models.

    The registries are built once on construction (see ``generate_module_dicts``)
    from the configured plugin packages.
    """

    # agentlib_mpc module types skipped when ``exclude_ml_plugins`` is True,
    # because importing them is slow. Subclasses may override this tuple.
    ML_MODULE_TYPES: tuple[str, ...] = (
        "agentlib_mpc.ann_trainer",
        "agentlib_mpc.gpr_trainer",
        "agentlib_mpc.linreg_trainer",
        "agentlib_mpc.ml_simulator",
        "agentlib_mpc.set_point_generator",
    )
    # Module types skipped when ``exclude_clonemap_plugin`` is True (not used by FlexQuant).
    CLONEMAP_MODULE_TYPES: tuple[str, ...] = ("clonemap",)

    def __init__(
        self,
        extra_plugins: Optional[list[str]] = None,
        exclude_ml_plugins: bool = True,
        exclude_clonemap_plugin: bool = True,
    ):
        """
        Manages discovery and lookup of AgentLib module types and their configuration models.

        The handler builds registries of available modules from a set of plugin packages,
        optionally excluding slow-to-import modules (e.g., ML trainers).
        The get_module() function provided allows to get the corresponding module from a
        config based on its name.

        Args:
            extra_plugins: Optional list of additional plugin package names to include in
                module discovery (in addition to the default plugins "agentlib_mpc" and
                "agentlib_flexquant"). Duplicates are ignored.
            exclude_ml_plugins: If True, excludes the module types in ``ML_MODULE_TYPES``
                (ML-related agentlib_mpc modules that are expensive/slow to import).
            exclude_clonemap_plugin: If True, excludes the module types in
                ``CLONEMAP_MODULE_TYPES`` (not used by FlexQuant).

        ModuleHandler registries/mappings:
            - module_type_dict: module type string -> config model class of that module
              (the ``config`` annotation of the module class).
            - module_name_dict: module type string -> module import object as returned by
              ``get_all_module_types``.
            - baseline_module_type_dict: agentlib_mpc type -> FlexQuant baseline type mapping
            - shadow_module_type_dict: agentlib_mpc type -> FlexQuant shadow-MPC type mapping
        """
        default_plugins = ["agentlib_mpc", "agentlib_flexquant"]
        # dict.fromkeys removes duplicates while keeping first-seen order.
        self.plugin_modules = list(
            dict.fromkeys(default_plugins + list(extra_plugins or []))
        )

        self.exclude_ml_plugins = exclude_ml_plugins
        self.exclude_clonemap_plugin = exclude_clonemap_plugin

        self.module_type_dict = {}
        self.module_name_dict = {}
        self.baseline_module_type_dict = {}
        self.shadow_module_type_dict = {}

        self.generate_module_dicts()

    def generate_module_dicts(self):
        """(Re)build all registries from ``self.plugin_modules``.

        Imports every remaining module class to read its ``config`` annotation.
        """
        all_module_types = get_all_module_types(self.plugin_modules)

        excluded_types = []
        if self.exclude_ml_plugins:
            # skip ML modules, since their import takes ages
            excluded_types.extend(self.ML_MODULE_TYPES)
        if self.exclude_clonemap_plugin:
            excluded_types.extend(self.CLONEMAP_MODULE_TYPES)
        for module_type in excluded_types:
            all_module_types.pop(module_type, None)

        # module type -> module config model (read from the class's "config" annotation)
        self.module_type_dict = {
            name: inspect.get_annotations(module_import.import_class())["config"]
            for name, module_import in all_module_types.items()
        }
        # module type -> module import object
        self.module_name_dict = all_module_types

        # agentlib_mpc type -> FlexQuant baseline / shadow type
        self.baseline_module_type_dict, self.shadow_module_type_dict = (
            get_module_type_matching_dict(self.module_name_dict)
        )

    def get_module(self, config: AgentConfig, module_type: str) -> T:
        """Build the config model of the first module in ``config`` of type ``module_type``.

        Args:
            config: Agent config whose ``modules`` entries are dict-like with a "type" key.
            module_type: Module type string to search for.

        Returns:
            A new config model instance. It is built from deep copies, so editing it does
            not modify ``config``.

        Raises:
            ModuleNotFoundError: If no module in ``config`` has type ``module_type``.
            KeyError: If a module matches but its type is not in ``module_type_dict``
                (e.g. because it was excluded).
        """
        for module in config.modules:
            if module["type"] != module_type:
                continue
            # deepcopy -> avoid changing the original config, when editing the module
            # deepcopy the args of the constructor instead of the module object,
            # because the simulator module exceeds the recursion limit
            config_id = deepcopy(config.id)
            mod = deepcopy(module)
            return self.module_type_dict[mod["type"]](**mod, _agent_id=config_id)

        # Reached only after every module has been checked without a match.
        raise ModuleNotFoundError(
            f"Module type {module_type} not found in "
            f"agentlib and its plug ins."
        )

    def get_flex_mpc_module_config(
        self,
        agent_config: AgentConfig,
        module_config: BaseModuleConfig,
        module_type: str,
    ):
        """Get a flexquant module config from an original agentlib module config.

        Args:
            agent_config: Agent config providing the agent id.
            module_config: Original module config whose dumped fields are reused.
            module_type: Target (FlexQuant) module type; must be a key of
                ``module_type_dict``.

        Returns:
            The new config model instance of type ``module_type``.
        """
        config_dict = module_config.model_dump()
        config_dict["type"] = module_type
        flex_config = self.module_type_dict[module_type](
            **config_dict, _agent_id=agent_config.id
        )
        # HOTFIX due to AgentLib-MPC bug. Needs to be adapted after Objectives
        # in AgentLib-MPC are fixed.
        if flex_config.r_del_u is None:
            flex_config = flex_config.model_copy(update={"r_del_u": {}})
        return flex_config
def get_module_type_matching_dict(dictionary: dict) -> tuple[dict, dict]:
    """Map agentlib_mpc module types to their FlexQuant baseline/shadow counterparts.

    A FlexQuant entry (key prefixed ``"agentlib_flexquant."``) is matched to the
    agentlib_mpc entry (key prefixed ``"agentlib_mpc."``) whose module class is the
    first base class of the FlexQuant module class. The FlexQuant key is classified as
    baseline if it contains ``"baseline"``, otherwise as shadow if it contains
    ``"shadow"``. All other keys are ignored. Matching does not depend on the order of
    entries in ``dictionary``.

    Args:
        dictionary: Mapping module type string -> module import object exposing
            ``import_class()`` (e.g. ``ModuleHandler.module_name_dict``).

    Returns:
        Tuple ``(baseline_matches, shadow_matches)``, each mapping an agentlib_mpc
        module type to the corresponding FlexQuant module type.
    """
    # Pass 1: group agentlib_mpc keys by their module import object.
    value_to_keys = {}
    for key, module_import in dictionary.items():
        if key.startswith("agentlib_mpc."):
            if module_import not in value_to_keys:
                value_to_keys[module_import] = {"agentlib": [], "flex": []}
            value_to_keys[module_import]["agentlib"].append(key)

    # Pass 2: attach each flexquant key to the group of its parent class. Running this
    # after pass 1 means flexquant keys listed before their parent are still matched.
    for key, module_import in dictionary.items():
        if not key.startswith("agentlib_flexquant."):
            continue
        parent_class = module_import.import_class().__bases__[0]
        for mpc_import, keys in value_to_keys.items():
            if mpc_import.import_class() is parent_class:
                keys["flex"].append(key)
                break

    baseline_matches = {}
    shadow_matches = {}
    for keys in value_to_keys.values():
        # Map each agentlib key to every matching flexquant key of its group.
        for agent_key in keys["agentlib"]:
            for flex_key in keys["flex"]:
                if "baseline" in flex_key:
                    baseline_matches[agent_key] = flex_key
                elif "shadow" in flex_key:
                    shadow_matches[agent_key] = flex_key

    return baseline_matches, shadow_matches
def get_orig_module_type(config: AgentConfig) -> Optional[str]:
    """Return the config type of the original MPC.

    Args:
        config: Agent config whose ``modules`` entries are dict-like with a "type" key.

    Returns:
        The type of the first module whose type starts with ``"agentlib_mpc"``, or
        ``None`` if no module matches.
    """
    return next(
        (
            module["type"]
            for module in config.modules
            if module["type"].startswith("agentlib_mpc")
        ),
        None,
    )
def to_dict_and_remove_unnecessary_fields(module: BaseModuleConfig) -> dict:
    """Remove unnecessary fields from the module to keep the created json simple.

    The module is dumped with ``exclude_defaults=True``. Every variable group listed in
    ``variable_groups`` that is present in that dump is replaced by a list of the
    module's variables, each serialized without the fields in ``excluded_fields``.
    ``lb``/``ub`` are additionally dropped when they equal -inf/+inf. For the groups in
    ``keep_shared_groups`` the ``shared`` field is retained.

    Args:
        module: Module config to serialize.

    Returns:
        The reduced dictionary representation of ``module``.
    """
    excluded_fields = frozenset(
        {
            "rdf_class",
            "source",
            "type",
            "timestamp",
            "description",
            "unit",
            "clip",
            "shared",
            "interpolation_method",
            "allowed_values",
        }
    )
    variable_groups = (
        "parameters",
        "inputs",
        "outputs",
        "controls",
        "binary_controls",
        "states",
        "full_controls",
        "vars_to_communicate",
    )
    # Groups for which the "shared" field is kept in the output.
    keep_shared_groups = frozenset({"full_controls", "vars_to_communicate"})

    def fields_to_exclude(variable, keep=frozenset()) -> set:
        """Return the field names to drop from ``variable``, minus those in ``keep``."""
        exclude = set(excluded_fields - keep)
        if variable.lb == -math.inf:
            exclude.add("lb")
        if variable.ub == math.inf:
            exclude.add("ub")
        return exclude

    parent_dict = module.model_dump(exclude_defaults=True)
    for group in variable_groups:
        # A group left at its default value is absent from the dump and stays absent.
        if group not in parent_dict:
            continue
        keep = frozenset({"shared"}) if group in keep_shared_groups else frozenset()
        parent_dict[group] = [
            variable.dict(exclude=fields_to_exclude(variable, keep))
            for variable in getattr(module, group)
        ]

    return parent_dict
def get_class_from_file(file_path: str, class_name: str) -> ABCMeta:
    """Execute the Python source file at ``file_path`` and return its ``class_name``.

    The file is loaded as a fresh module named after the file's base name (without
    extension). The module is not registered in ``sys.modules``.

    Args:
        file_path: Path to the Python source file; relative paths are resolved
            against the current working directory.
        class_name: Name of the class to fetch from the loaded module.

    Returns:
        The object bound to ``class_name`` in the loaded module.

    Raises:
        ImportError: If no import spec/loader can be created for ``file_path``
            (e.g. the file suffix is not a recognized Python source suffix).
        AttributeError: If the loaded module has no attribute ``class_name``.
    """
    abs_path = os.path.abspath(file_path)
    module_name = os.path.splitext(os.path.basename(file_path))[0]

    spec = importlib.util.spec_from_file_location(module_name, abs_path)
    # spec_from_file_location returns None when no loader handles the file suffix.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from file '{abs_path}'.")

    module = importlib.util.module_from_spec(spec)
    # Runs the file's top-level code.
    spec.loader.exec_module(module)

    return getattr(module, class_name)