Coverage for agentlib_flexquant/utils/config_management.py: 90%
96 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-01 15:10 +0000
1from agentlib.core.agent import AgentConfig
2from agentlib.core.module import BaseModuleConfig
3import agentlib_flexquant.data_structures.globals as glbs
4from copy import deepcopy
5from typing import TypeVar
6import math
7from agentlib.modules import get_all_module_types
8import inspect
9import os
10import importlib.util
T = TypeVar('T', bound=BaseModuleConfig)

# Module types that are dropped before their classes are imported below:
# - ML trainers/simulators and the set point generator, since import takes ages
# - clonemap, since it is not used
_EXCLUDED_MODULE_TYPES = (
    "agentlib_mpc.ann_trainer",
    "agentlib_mpc.gpr_trainer",
    "agentlib_mpc.linreg_trainer",
    "agentlib_mpc.ann_simulator",
    "agentlib_mpc.set_point_generator",
    "clonemap",
)

all_module_types = get_all_module_types(["agentlib_mpc", "agentlib_flexquant"])
for _excluded_type in _EXCLUDED_MODULE_TYPES:
    # pop with a default: a type that the installed plugin versions do not
    # register must not make importing this module fail with a KeyError
    all_module_types.pop(_excluded_type, None)

# Maps each remaining module type name to the class annotated as "config" on
# its module class (importing every remaining module class once).
MODULE_TYPE_DICT = {name: inspect.get_annotations(class_type.import_class())["config"]
                    for name, class_type in all_module_types.items()}

MPC_CONFIG_TYPE: str = "agentlib_mpc.mpc"
BASELINEMPC_CONFIG_TYPE: str = "agentlib_flexquant.baseline_mpc"
SHADOWMPC_CONFIG_TYPE: str = "agentlib_flexquant.shadow_mpc"
INDICATOR_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_indicator"
MARKET_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_market"
SIMULATOR_CONFIG_TYPE: str = "simulator"
def get_module_type_matching_dict(dictionary: dict):
    """Pair agentlib_mpc module types with their flexquant counterparts.

    Only keys prefixed with ``agentlib_mpc.`` or ``agentlib_flexquant.`` are
    considered. Keys that map to the same value in *dictionary* (here the
    MODULE_TYPE_DICT, i.e. the same config class) form a group. Within every
    group that has keys from both packages, each agentlib_mpc key is mapped to
    the flexquant keys of that group containing ``baseline`` or ``shadow``.

    Args:
        dictionary: Mapping of module type name -> config class.

    Returns:
        Tuple ``(baseline_matches, shadow_matches)`` of dicts mapping an
        agentlib_mpc module type to a flexquant baseline / shadow module type.
    """
    # Group the relevant keys by their value, split per originating package
    groups = {}
    for type_name, config_cls in dictionary.items():
        if type_name.startswith('agentlib_mpc.'):
            side = 'agentlib'
        elif type_name.startswith('agentlib_flexquant.'):
            side = 'flex'
        else:
            continue
        groups.setdefault(config_cls, {'agentlib': [], 'flex': []})[side].append(type_name)

    baseline_matches = {}
    shadow_matches = {}
    for members in groups.values():
        # A match needs keys from both packages sharing the same value
        if not (members['agentlib'] and members['flex']):
            continue
        for mpc_type in members['agentlib']:
            for flex_type in members['flex']:
                if 'baseline' in flex_type:
                    baseline_matches[mpc_type] = flex_type
                elif 'shadow' in flex_type:
                    shadow_matches[mpc_type] = flex_type

    return baseline_matches, shadow_matches
# Lookup tables from agentlib_mpc module types to the flexquant baseline and
# shadow module types that share the same config class
BASELINE_MODULE_TYPE_DICT, SHADOW_MODULE_TYPE_DICT = get_module_type_matching_dict(
    MODULE_TYPE_DICT
)
def get_orig_module_type(config: AgentConfig):
    """Return the type of the original (agentlib_mpc) MPC module of an agent.

    Scans ``config.modules`` in order and returns the ``"type"`` entry of the
    first module whose type starts with ``"agentlib_mpc"``; returns ``None``
    when there is no such module.
    """
    return next(
        (
            entry["type"]
            for entry in config.modules
            if entry["type"].startswith("agentlib_mpc")
        ),
        None,
    )
def get_module(config: AgentConfig, module_type: str) -> T:
    """Build the config object of the module with the given type from an agent config.

    Searches ``config.modules`` for the first entry whose ``"type"`` equals
    *module_type* and instantiates the matching config class from
    ``MODULE_TYPE_DICT`` with that entry's fields and the agent's id.

    Args:
        config: Agent config whose ``modules`` are searched.
        module_type: Module type to look for, e.g. ``"agentlib_mpc.mpc"``.

    Returns:
        A new module config instance built from deep copies, so editing it does
        not change *config*.

    Raises:
        ModuleNotFoundError: If *config* contains no module of type *module_type*.
    """
    for module in config.modules:
        if module["type"] == module_type:
            # deepcopy -> avoid changing the original config, when editing the module
            # deepcopy the args of the constructor instead of the module object,
            # because the simulator module exceeds the recursion limit
            config_id = deepcopy(config.id)
            mod = deepcopy(module)
            return MODULE_TYPE_DICT[mod["type"]](**mod, _agent_id=config_id)
    # Only reached when no module matched. Report the requested type: the loop
    # variable names an unrelated module (and is unbound for an empty list).
    raise ModuleNotFoundError(f"Module type {module_type} not found in "
                              f"agentlib and its plug ins.")
def to_dict_and_remove_unnecessary_fields(module: BaseModuleConfig):
    """Serialize a module config to a dict, dropping fields that clutter the JSON.

    Top-level fields equal to their defaults are omitted. Every variable list
    among ``parameters``, ``inputs``, ``outputs``, ``controls`` and ``states``
    that is still present after that is re-serialized per variable, without the
    fields in ``excluded_fields`` and without ``lb``/``ub`` when they are
    ``-inf``/``inf``.

    Args:
        module: Module config to serialize.

    Returns:
        The reduced dict representation of *module*.
    """
    excluded_fields = ["rdf_class", "source", "type", "timestamp", "description", "unit", "clip",
                       "shared", "interpolation_method", "allowed_values"]

    def check_bounds(variable):
        """Return the fields to exclude for *variable*, adding unbounded lb/ub."""
        delete_list = excluded_fields.copy()
        if variable.lb == -math.inf:
            delete_list.append("lb")
        if variable.ub == math.inf:
            delete_list.append("ub")
        return delete_list

    parent_dict = module.dict(exclude_defaults=True)
    # update every variable group kept above with per-variable dicts that
    # exclude the defined fields
    for group in ("parameters", "inputs", "outputs", "controls", "states"):
        if group in parent_dict:
            parent_dict[group] = [variable.dict(exclude=check_bounds(variable))
                                  for variable in getattr(module, group)]

    return parent_dict
def subtract_relative_path(absolute_path, relative_path):
    """Cut *absolute_path* at the last occurrence of *relative_path*'s first component.

    Both paths are normalized first. The first component of *relative_path* is
    looked up as a whole path component of *absolute_path*. If it is found, the
    part of *absolute_path* before its last occurrence is returned, without a
    trailing separator. Otherwise, or if *relative_path* has no leading
    component, the normalized *absolute_path* is returned.

    Example (POSIX): ``("/home/user/project/examples/building",
    "examples/building/model.py")`` -> ``"/home/user/project"``.
    """
    # Normalize paths (convert slashes to the correct system format)
    absolute_path = os.path.normpath(absolute_path)
    relative_path = os.path.normpath(relative_path)

    first_rel_component = relative_path.split(os.sep)[0]
    abs_parts = absolute_path.split(os.sep)

    # Compare whole components, not substrings: "data" must not match inside
    # "mydata" or "database". An empty first component (relative_path is
    # actually absolute) never matches.
    if first_rel_component and first_rel_component in abs_parts:
        # index of the last occurrence of the component
        index = len(abs_parts) - 1 - abs_parts[::-1].index(first_rel_component)
        return os.sep.join(abs_parts[:index])

    # If the relative path component wasn't found, return the original path
    return absolute_path
def get_class_from_file(file_path, class_name):
    """Import the Python source file at *file_path* and return its attribute *class_name*.

    The file is loaded as a new module named after the file's base name
    (without extension) and executed. The module is not registered in
    ``sys.modules``.

    Args:
        file_path: Path of the source file, absolute or relative to the
            current working directory.
        class_name: Name of the class to fetch from the loaded module.

    Returns:
        The object bound to *class_name* in the loaded module.

    Raises:
        ImportError: If no import spec/loader can be created for *file_path*
            (e.g. its extension is not a recognized Python source suffix).
        AttributeError: If the loaded module has no attribute *class_name*.
        Any exception raised while executing the file propagates unchanged.
    """
    abs_path = os.path.abspath(file_path)
    module_name = os.path.splitext(os.path.basename(file_path))[0]

    spec = importlib.util.spec_from_file_location(module_name, abs_path)
    # spec_from_file_location returns None when no loader handles the file;
    # fail with a clear message instead of an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from file {abs_path!r}.")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    return getattr(module, class_name)