Coverage for agentlib_flexquant/utils/config_management.py: 99%
105 statements
« prev ^ index » next coverage.py v7.4.4, created at 2026-03-26 09:43 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2026-03-26 09:43 +0000
1import importlib.util
2import inspect
3import math
4import os
5from abc import ABCMeta
6from copy import deepcopy
7from typing import TypeVar
9from agentlib.core.agent import AgentConfig
10from agentlib.core.module import BaseModuleConfig
11from agentlib.modules import get_all_module_types
T = TypeVar("T", bound=BaseModuleConfig)

all_module_types = get_all_module_types(["agentlib_mpc", "agentlib_flexquant"])
# Drop module types that are not needed here: the ML related modules
# (their import takes very long) and clonemap (not used).
for _dropped_type in (
    "agentlib_mpc.ann_trainer",
    "agentlib_mpc.gpr_trainer",
    "agentlib_mpc.linreg_trainer",
    "agentlib_mpc.ml_simulator",
    "agentlib_mpc.set_point_generator",
    "clonemap",
):
    all_module_types.pop(_dropped_type)
# do not leave the loop variable behind in the module namespace
del _dropped_type

# Maps the module type name to the class annotated as "config" on the
# imported module class.
MODULE_TYPE_DICT = {
    type_name: inspect.get_annotations(module_import.import_class())["config"]
    for type_name, module_import in all_module_types.items()
}
# Maps the module type name to the entry returned by get_all_module_types
# (same object as all_module_types).
MODULE_NAME_DICT = all_module_types

# Module type names used throughout the package.
MPC_CONFIG_TYPE: str = "agentlib_mpc.mpc"
BASELINEMPC_CONFIG_TYPE: str = "agentlib_flexquant.baseline_mpc"
SHADOWMPC_CONFIG_TYPE: str = "agentlib_flexquant.shadow_mpc"
BASELINEMINLPMPC_CONFIG_TYPE: str = "agentlib_flexquant.baseline_minlp_mpc"
SHADOWMINLPMPC_CONFIG_TYPE: str = "agentlib_flexquant.shadow_minlp_mpc"
INDICATOR_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_indicator"
MARKET_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_market"
SIMULATOR_CONFIG_TYPE: str = "simulator"
def get_module_type_matching_dict(dictionary: dict) -> tuple[dict, dict]:
    """Map agentlib_mpc module types to their flexquant counterparts.

    A flexquant module is matched to an agentlib_mpc module when the first
    base class of the flexquant module class is the agentlib_mpc module class.

    Args:
        dictionary: Maps module type names to objects providing
            ``import_class()`` (e.g. ``MODULE_NAME_DICT``). Only keys starting
            with ``"agentlib_mpc."`` or ``"agentlib_flexquant."`` are used.
            The values must be hashable.

    Returns:
        Two dictionaries ``(baseline_matches, shadow_matches)``. Each maps an
        agentlib_mpc type name to the matching flexquant type name whose name
        contains ``"baseline"`` or (otherwise) ``"shadow"``.

    """
    # First pass: register every agentlib_mpc entry. Doing this before any
    # flexquant entry is looked at makes the result independent of the
    # ordering of ``dictionary``.
    value_to_keys = {}
    for key, value in dictionary.items():
        if key.startswith("agentlib_mpc."):
            if value not in value_to_keys:
                value_to_keys[value] = {"agentlib": [], "flex": []}
            value_to_keys[value]["agentlib"].append(key)

    # Import each agentlib_mpc class once instead of once per comparison.
    parent_classes = [(value, value.import_class()) for value in value_to_keys]

    # Second pass: attach every flexquant entry to the agentlib_mpc entry
    # whose class is its first base class (first match wins).
    for key, value in dictionary.items():
        if not key.startswith("agentlib_flexquant."):
            continue
        direct_parent = value.import_class().__bases__[0]
        for candidate, candidate_class in parent_classes:
            if candidate_class is direct_parent:
                value_to_keys[candidate]["flex"].append(key)
                break

    # Create result dictionaries
    baseline_matches = {}
    shadow_matches = {}
    for keys in value_to_keys.values():
        # Entries without flexquant counterparts have an empty "flex" list
        # and therefore contribute nothing.
        for agent_key in keys["agentlib"]:
            for flex_key in keys["flex"]:
                if "baseline" in flex_key:
                    baseline_matches[agent_key] = flex_key
                elif "shadow" in flex_key:
                    shadow_matches[agent_key] = flex_key

    return baseline_matches, shadow_matches
# agentlib_mpc type name -> flexquant baseline / shadow type name
BASELINE_MODULE_TYPE_DICT, SHADOW_MODULE_TYPE_DICT = get_module_type_matching_dict(
    MODULE_NAME_DICT
)
def get_orig_module_type(config: AgentConfig) -> str:
    """Return the type of the first module whose type starts with "agentlib_mpc".

    ``None`` is returned when no module of ``config`` has such a type.
    """
    mpc_types = (
        entry["type"]
        for entry in config.modules
        if entry["type"].startswith("agentlib_mpc")
    )
    return next(mpc_types, None)
def get_module(config: AgentConfig, module_type: str) -> T:
    """Extract the first module of the given type from an agent config.

    Args:
        config: Agent config whose ``modules`` entries are mappings with a
            ``"type"`` key.
        module_type: Module type name to look for.

    Returns:
        A new module config built via ``MODULE_TYPE_DICT`` from a deep copy of
        the matching entry.

    Raises:
        ModuleNotFoundError: If no module of ``config`` has the requested
            type (including the case of an empty module list).

    """
    for module in config.modules:
        if module["type"] != module_type:
            continue
        # deepcopy -> avoid changing the original config, when editing the module
        # deepcopy the args of the constructor instead of the module object,
        # because the simulator module exceeds the recursion limit
        config_id = deepcopy(config.id)
        mod = deepcopy(module)
        return MODULE_TYPE_DICT[mod["type"]](**mod, _agent_id=config_id)
    # Report the requested type; the loop variable would name the last module
    # inspected and is not even defined when config.modules is empty.
    raise ModuleNotFoundError(
        f"Module type {module_type} not found in agentlib and its plug ins."
    )
def get_flex_mpc_module_config(
    agent_config: AgentConfig, module_config: BaseModuleConfig, module_type: str
):
    """Build a flexquant module config from an original agentlib module config.

    The fields of ``module_config`` are dumped, the ``"type"`` entry is
    replaced by ``module_type`` and the config class registered for
    ``module_type`` in ``MODULE_TYPE_DICT`` is instantiated with them, using
    the id of ``agent_config`` as ``_agent_id``.
    """
    fields = module_config.model_dump()
    fields["type"] = module_type
    config_class = MODULE_TYPE_DICT[module_type]
    flex_config = config_class(**fields, _agent_id=agent_config.id)
    # HOTFIX due to AgentLib-MPC bug. Needs to be adapted after Objectives
    # in AgentLib-MPC are fixed.
    if flex_config.r_del_u is not None:
        return flex_config
    return flex_config.model_copy(update={"r_del_u": {}})
def to_dict_and_remove_unnecessary_fields(module: BaseModuleConfig) -> dict:
    """Dump a module config to a dict with slimmed-down variable entries.

    The module is dumped without default values. Every variable list present
    in that dump is then replaced by the variables' own ``dict()`` output,
    leaving out a fixed set of fields as well as infinite bounds, to keep the
    created json simple.
    """
    always_dropped = (
        "rdf_class",
        "source",
        "type",
        "timestamp",
        "description",
        "unit",
        "clip",
        "shared",
        "interpolation_method",
        "allowed_values",
    )
    # (attribute holding the variables, whether the "shared" field is kept)
    variable_groups = (
        ("parameters", False),
        ("inputs", False),
        ("outputs", False),
        ("controls", False),
        ("binary_controls", False),
        ("states", False),
        ("full_controls", True),
        ("vars_to_communicate", True),
    )

    def fields_to_drop(variable, keep_shared):
        """List the fields to exclude for one variable."""
        drop = [
            field
            for field in always_dropped
            if not (keep_shared and field == "shared")
        ]
        # unbounded limits carry no information
        if variable.lb == -math.inf:
            drop.append("lb")
        if variable.ub == math.inf:
            drop.append("ub")
        return drop

    result = module.model_dump(exclude_defaults=True)
    for group, keep_shared in variable_groups:
        if group not in result:
            continue
        # overwrite the dumped entry with a dict per variable
        result[group] = [
            variable.dict(exclude=fields_to_drop(variable, keep_shared))
            for variable in getattr(module, group)
        ]
    return result
def get_class_from_file(file_path: str, class_name: str) -> ABCMeta:
    """Load a Python file as a module and return one of its attributes.

    The module is named after the file (base name without extension) and is
    executed on every call; it is not registered in ``sys.modules``.

    Args:
        file_path: Path to the Python source file; relative paths are
            resolved against the current working directory.
        class_name: Name of the attribute (usually a class) to fetch from the
            loaded module.

    Returns:
        The object bound to ``class_name`` in the loaded module.

    Raises:
        ImportError: If no import spec or loader can be created for the file
            (e.g. unsupported file extension).
        AttributeError: If the module does not define ``class_name``.

    """
    # Get the absolute path if needed
    abs_path = os.path.abspath(file_path)

    # Get the module name from the file path
    module_name = os.path.splitext(os.path.basename(file_path))[0]

    # Load the module specification; this yields None when no loader is
    # available for the file
    spec = importlib.util.spec_from_file_location(module_name, abs_path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot create an import spec for '{abs_path}'.", path=abs_path
        )

    # Create and execute the module
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Get the class from the module
    return getattr(module, class_name)