Coverage for agentlib_flexquant/utils/config_management.py: 90%

96 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-08-01 15:10 +0000

1from agentlib.core.agent import AgentConfig 

2from agentlib.core.module import BaseModuleConfig 

3import agentlib_flexquant.data_structures.globals as glbs 

4from copy import deepcopy 

5from typing import TypeVar 

6import math 

7from agentlib.modules import get_all_module_types 

8import inspect 

9import os 

10import importlib.util 

11 

12 

T = TypeVar('T', bound=BaseModuleConfig)


# All module types registered by agentlib_mpc and agentlib_flexquant
all_module_types = get_all_module_types(["agentlib_mpc", "agentlib_flexquant"])

# Drop the ML models (their import takes ages) and clonemap (not used).
# A missing entry raises KeyError, exactly like the individual pop calls would.
for _unused_module_type in (
    "agentlib_mpc.ann_trainer",
    "agentlib_mpc.gpr_trainer",
    "agentlib_mpc.linreg_trainer",
    "agentlib_mpc.ann_simulator",
    "agentlib_mpc.set_point_generator",
    "clonemap",
):
    all_module_types.pop(_unused_module_type)
# do not leave the loop variable behind in the module namespace
del _unused_module_type

# Maps every remaining module type name to the annotation of the ``config``
# attribute of its imported module class
MODULE_TYPE_DICT = {
    module_name: inspect.get_annotations(module_entry.import_class())["config"]
    for module_name, module_entry in all_module_types.items()
}

# Module type identifiers used to look up modules in agent configs
MPC_CONFIG_TYPE: str = "agentlib_mpc.mpc"
BASELINEMPC_CONFIG_TYPE: str = "agentlib_flexquant.baseline_mpc"
SHADOWMPC_CONFIG_TYPE: str = "agentlib_flexquant.shadow_mpc"
INDICATOR_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_indicator"
MARKET_CONFIG_TYPE: str = "agentlib_flexquant.flexibility_market"
SIMULATOR_CONFIG_TYPE: str = "simulator"

34 

35 

def get_module_type_matching_dict(dictionary: dict):
    """Match agentlib_mpc module types with their flexquant counterparts.

    Keys of ``dictionary`` that start with ``agentlib_mpc.`` or
    ``agentlib_flexquant.`` are grouped by their value. Within each group,
    every agentlib_mpc key is mapped to the flexquant keys of the same group:
    flexquant keys containing ``baseline`` go to the first result, keys
    containing ``shadow`` (but not ``baseline``) go to the second one.
    Keys with any other prefix are ignored.

    Returns:
        Tuple ``(baseline_matches, shadow_matches)`` of two dictionaries
        mapping agentlib_mpc keys to flexquant keys.

    """
    # Group the relevant keys by the value they share
    grouped = {}
    for module_type, config_class in dictionary.items():
        if module_type.startswith('agentlib_mpc.'):
            origin = 'agentlib'
        elif module_type.startswith('agentlib_flexquant.'):
            origin = 'flex'
        else:
            continue
        buckets = grouped.setdefault(config_class, {'agentlib': [], 'flex': []})
        buckets[origin].append(module_type)

    baseline_matches = {}
    shadow_matches = {}

    # A group lacking either side yields no pairs, since one of the loops is empty
    for buckets in grouped.values():
        for agent_key in buckets['agentlib']:
            for flex_key in buckets['flex']:
                if 'baseline' in flex_key:
                    baseline_matches[agent_key] = flex_key
                elif 'shadow' in flex_key:
                    shadow_matches[agent_key] = flex_key

    return baseline_matches, shadow_matches

68 

69 

# Lookup tables from agentlib_mpc module types to the flexquant baseline and
# shadow module types that share the same entry in MODULE_TYPE_DICT
BASELINE_MODULE_TYPE_DICT, SHADOW_MODULE_TYPE_DICT = get_module_type_matching_dict(
    MODULE_TYPE_DICT
)

72 

73 

def get_orig_module_type(config: AgentConfig):
    """Return the module type of the original MPC.

    This is the ``"type"`` of the first entry in ``config.modules`` whose type
    starts with ``agentlib_mpc``. ``None`` is returned if there is no such
    entry.

    """
    mpc_types = (
        entry["type"]
        for entry in config.modules
        if entry["type"].startswith("agentlib_mpc")
    )
    return next(mpc_types, None)

81 

82 

def get_module(config: AgentConfig, module_type: str) -> T:
    """Extract a module from an agent config based on its type.

    The first entry of ``config.modules`` whose ``"type"`` equals
    ``module_type`` is copied and used to build an instance of the config
    class stored for that type in ``MODULE_TYPE_DICT``.

    Args:
        config: agent config; its ``modules`` entries are indexed by ``"type"``.
        module_type: the module type to look for.

    Returns:
        A new module config instance created from a copy of the matching entry.

    Raises:
        ModuleNotFoundError: if ``config.modules`` has no entry of
            ``module_type`` (this includes an empty module list).

    """
    for module in config.modules:
        if module["type"] != module_type:
            continue
        # deepcopy -> avoid changing the original config, when editing the module
        # deepcopy the args of the constructor instead of the module object,
        # because the simulator module exceeds the recursion limit
        config_id = deepcopy(config.id)
        mod = deepcopy(module)
        return MODULE_TYPE_DICT[mod["type"]](**mod, _agent_id=config_id)

    # Report the type that was asked for. The loop variable would name the last
    # module that was inspected and is undefined for an empty module list.
    raise ModuleNotFoundError(f"Module type {module_type} not found in "
                              f"agentlib and its plug ins.")

98 

99 

100 

def to_dict_and_remove_unnecessary_fields(module: BaseModuleConfig):
    """Convert a module config to a dict without unnecessary fields.

    The module is dumped with ``module.dict(exclude_defaults=True)``. Each of
    the variable lists ``parameters``, ``inputs``, ``outputs``, ``controls``
    and ``states`` that is present in that dump is replaced by a list of
    per-variable dicts that leave out the fields in ``excluded_fields``, and
    additionally ``lb`` / ``ub`` when they are infinite. This keeps the
    created json simple.

    Args:
        module: the module config to convert.

    Returns:
        The dictionary representation of the module.

    """
    excluded_fields = ["rdf_class", "source", "type", "timestamp", "description", "unit", "clip",
                       "shared", "interpolation_method", "allowed_values"]
    variable_groups = ("parameters", "inputs", "outputs", "controls", "states")

    def check_bounds(parameter):
        """Return the fields to exclude for one variable."""
        delete_list = excluded_fields.copy()
        # unbounded limits carry no information, so drop them as well
        if parameter.lb == -math.inf:
            delete_list.append("lb")
        if parameter.ub == math.inf:
            delete_list.append("ub")
        return delete_list

    parent_dict = module.dict(exclude_defaults=True)
    # update every variable with a dict excluding the defined fields;
    # groups missing from the dump are left out of the result
    for group in variable_groups:
        if group in parent_dict:
            parent_dict[group] = [
                variable.dict(exclude=check_bounds(variable))
                for variable in getattr(module, group)
            ]

    return parent_dict

130 

131 

def subtract_relative_path(absolute_path, relative_path):
    """Return the part of ``absolute_path`` that precedes ``relative_path``.

    Both paths are normalized with ``os.path.normpath`` and compared component
    by component, so a directory name is never matched inside a longer name
    (e.g. ``mpc`` inside ``mpc_model.py``).

    The lookup is done in two steps:

    1. If ``absolute_path`` ends with all components of ``relative_path``,
       everything before that suffix is returned.
    2. Otherwise the last component of ``absolute_path`` that equals the first
       component of ``relative_path`` is searched and everything before it is
       returned.

    Args:
        absolute_path: the path to shorten.
        relative_path: the path whose start is searched in ``absolute_path``.

    Returns:
        The leading part of the normalized ``absolute_path`` without a trailing
        separator, or the normalized ``absolute_path`` itself if the first
        component of ``relative_path`` is not one of its components.

    """
    # Normalize paths (convert slashes to the correct system format)
    absolute_path = os.path.normpath(absolute_path)
    relative_path = os.path.normpath(relative_path)

    abs_parts = absolute_path.split(os.sep)
    rel_parts = relative_path.split(os.sep)
    first_rel_component = rel_parts[0]

    # An empty first component means relative_path starts with a separator;
    # there is nothing meaningful to search for.
    if not first_rel_component:
        return absolute_path

    # Preferred case: the relative path is the tail of the absolute path
    n_rel = len(rel_parts)
    if n_rel <= len(abs_parts) and abs_parts[-n_rel:] == rel_parts:
        return os.sep.join(abs_parts[:-n_rel]).rstrip(os.sep)

    # Fallback: cut at the last whole component equal to the first relative one
    for index in range(len(abs_parts) - 1, -1, -1):
        if abs_parts[index] == first_rel_component:
            return os.sep.join(abs_parts[:index]).rstrip(os.sep)

    # If the relative path component wasn't found, return the original path
    return absolute_path

152 

153 

def get_class_from_file(file_path, class_name):
    """Load a Python file as a module and return a class defined in it.

    The module is named after the file (base name without extension) and is
    executed on every call.

    Args:
        file_path: path to the Python file, absolute or relative to the
            current working directory.
        class_name: name of the attribute to fetch from the loaded module.

    Returns:
        The attribute ``class_name`` of the loaded module.

    Raises:
        ImportError: if no import spec or loader can be determined for
            ``file_path``, e.g. because it has no Python file suffix.
        AttributeError: if the module has no attribute ``class_name``.

    """
    # Get the absolute path if needed
    abs_path = os.path.abspath(file_path)

    # Get the module name from the file path
    module_name = os.path.splitext(os.path.basename(file_path))[0]

    # Load the module specification; this is None when no loader fits the file
    spec = importlib.util.spec_from_file_location(module_name, abs_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from file '{abs_path}'.")

    # Create and execute the module
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Get the class from the module
    return getattr(module, class_name)

172 

173 return target_class