Coverage for agentlib_flexquant/utils/parsing.py: 90%
136 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-10-20 14:09 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-10-20 14:09 +0000
1import ast
2from string import Template
3from typing import Optional, Union
5from agentlib_mpc.data_structures.mpc_datamodels import MPCVariable
7from agentlib_flexquant.data_structures.globals import (
8 FLEX_EVENT_DURATION,
9 MARKET_TIME,
10 PREP_TIME,
11 SHADOW_MPC_COST_FUNCTION,
12 full_trajectory_prefix,
13 full_trajectory_suffix,
14 return_baseline_cost_function,
15)
16from agentlib_flexquant.data_structures.mpcs import (
17 BaselineMPCData,
18 BaseMPCData,
19 NFMPCData,
20 PFMPCData,
21)
# Constants
CASADI_INPUT = "CasadiInput"
CASADI_PARAMETER = "CasadiParameter"
CASADI_OUTPUT = "CasadiOutput"

# String templates
# The single-quoted fields ('$name', '$unit', '$type', '$description') become string
# literals in the generated source and are escaped before substitution, whereas
# $value is inserted verbatim as Python source (e.g. None, False, 0, 1.5).
INPUT_TEMPLATE = Template(
    "$class_name(name='$name', value=$value, unit='$unit', type='$type', "
    "description='$description')"
)
PARAMETER_TEMPLATE = Template(
    "$class_name(name='$name', value=$value, unit='$unit', description='$description')"
)
OUTPUT_TEMPLATE = Template(
    "$class_name(name='$name', unit='$unit', type='$type', value=$value, "
    "description='$description')"
)


def _escape_for_single_quotes(text: object) -> str:
    """Escape ``text`` so it can be placed between single quotes in Python source.

    The ``unicode_escape`` codec escapes backslashes, control characters (such as
    newlines) and non-ASCII characters; it leaves single quotes untouched, so those
    are escaped afterwards. Plain printable ASCII without quotes or backslashes is
    returned unchanged, so typical names, units and descriptions render exactly as
    before.

    Args:
        text: The value to escape; non-strings are converted with ``str`` first,
            matching how ``Template.substitute`` renders them.

    Returns:
        The escaped text, without surrounding quotes.
    """
    return str(text).encode("unicode_escape").decode("ascii").replace("'", "\\'")


def create_ast_element(template_string: str) -> ast.Call:
    """Convert a template string into an AST call node.

    Args:
        template_string: A Python code template string to parse.

    Returns:
        ast.Call: The expression of the first statement parsed from the template
        string (a call node for the templates defined in this module).

    """
    return ast.parse(template_string).body[0].value


def add_input(
    name: str,
    value: Union[bool, str, int, float, None],
    unit: str,
    description: str,
    type: str,
) -> ast.Call:
    """Create an AST node for an input definition.

    Args:
        name: The name of the input.
        value: The default value for the input. It is inserted verbatim as Python
            source (e.g. ``None``, ``False``, ``0``), so a string value is NOT quoted.
        unit: The unit associated with the input value.
        description: A human-readable description of the input. Quotes, backslashes
            and newlines are escaped.
        type: The data type of the input (e.g., "float", "int", "pd.Series"). The
            parameter name shadows the builtin but is kept for keyword callers.

    Returns:
        ast.Call: An abstract syntax tree (AST) call node representing the input definition.

    """
    return create_ast_element(
        INPUT_TEMPLATE.substitute(
            class_name=CASADI_INPUT,
            name=_escape_for_single_quotes(name),
            value=value,
            unit=_escape_for_single_quotes(unit),
            description=_escape_for_single_quotes(description),
            type=_escape_for_single_quotes(type),
        )
    )


def add_parameter(
    name: str, value: Union[int, float], unit: str, description: str
) -> ast.Call:
    """Create an AST node for a parameter definition.

    Args:
        name: The name of the parameter.
        value: The value of the parameter, inserted verbatim as Python source.
        unit: The unit associated with the parameter value.
        description: A human-readable description of the parameter. Quotes,
            backslashes and newlines are escaped.

    Returns:
        ast.Call: An abstract syntax tree (AST) call node
            representing the parameter definition.

    """
    return create_ast_element(
        PARAMETER_TEMPLATE.substitute(
            class_name=CASADI_PARAMETER,
            name=_escape_for_single_quotes(name),
            value=value,
            unit=_escape_for_single_quotes(unit),
            description=_escape_for_single_quotes(description),
        )
    )


def add_output(
    name: str,
    unit: str,
    type: str,
    value: Union[str, float, None],
    description: str,
) -> ast.Call:
    """Create an AST node for an output definition.

    Args:
        name: The name of the output.
        unit: The unit associated with the output value.
        type: The data type of the output (e.g., "float", "string").
        value: The value of the output, inserted verbatim as Python source (a
            string value is NOT quoted).
        description: A human-readable description of the output. Quotes,
            backslashes and newlines are escaped.

    Returns:
        ast.Call: An abstract syntax tree (AST) call node representing the output definition.

    """
    return create_ast_element(
        OUTPUT_TEMPLATE.substitute(
            class_name=CASADI_OUTPUT,
            name=_escape_for_single_quotes(name),
            unit=_escape_for_single_quotes(unit),
            type=_escape_for_single_quotes(type),
            value=value,
            description=_escape_for_single_quotes(description),
        )
    )
class SetupSystemModifier(ast.NodeTransformer):
    """A custom AST transformer for modifying the MPC model file.

    This class traverses the AST of the input file, identifies the relevant classes and
    methods, and performs the necessary modifications. The config class (a subclass of
    ``CasadiModelConfig``) and the ``setup_system`` method of the model class (a
    subclass of ``CasadiModel``) are rewritten differently for shadow MPCs
    (``PFMPCData`` / ``NFMPCData``) and for the baseline MPC (``BaselineMPCData``).

    """

    def __init__(
        self,
        mpc_data: BaseMPCData,
        controls: list[MPCVariable],
        binary_controls: Optional[list[MPCVariable]],
    ):
        self.mpc_data = mpc_data
        self.controls = controls
        self.binary_controls = binary_controls
        # AST nodes of the config class and the model class, set by visit_ClassDef
        self.config_obj: Optional[ast.ClassDef] = None
        self.model_obj: Optional[ast.ClassDef] = None
        # select modification of config class and setup_system based on mpc type.
        # NOTE(review): for any other mpc_data type neither attribute is set, so
        # visiting a config/model class raises AttributeError.
        if isinstance(mpc_data, (PFMPCData, NFMPCData)):
            self.modify_config_class = self.modify_config_class_shadow
            self.modify_setup_system = self.modify_setup_system_shadow
        if isinstance(mpc_data, BaselineMPCData):
            self.modify_config_class = self.modify_config_class_baseline
            self.modify_setup_system = self.modify_setup_system_baseline

    def visit_Module(self, module: ast.Module) -> ast.Module:
        """Visit a module definition in the AST.

        Append or delete the import statements at the top of the module: the baseline
        MPC gets ``pandas as pd`` and ``casadi as ca`` imports, shadow MPCs get all
        top-level imports removed.

        Args:
            module: The module definition node in the AST.

        Returns:
            The possibly modified module definition node.

        """
        # append imports for baseline
        if isinstance(self.mpc_data, BaselineMPCData):
            module = add_import_to_tree(name="pandas", alias="pd", tree=module)
            module = add_import_to_tree(name="casadi", alias="ca", tree=module)
        # delete imports for shadow MPCs
        if isinstance(self.mpc_data, (NFMPCData, PFMPCData)):
            module = remove_all_imports_from_tree(module)
        # trigger the next visit method (ClassDef)
        self.generic_visit(module)
        return module

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        """Visit a class definition in the AST.

        Classes deriving directly from ``CasadiModelConfig`` are treated as the config
        class; classes deriving directly from ``CasadiModel`` are treated as the model
        class. Both are modified and renamed after ``mpc_data.class_name``.

        Args:
            node: The class definition node in the AST.

        Returns:
            The possibly modified class definition node.

        """
        for base in node.bases:
            if not isinstance(base, ast.Name):
                continue
            if base.id == "CasadiModelConfig":
                self.config_obj = node
                self.modify_config_class(node)
                node.name = self.mpc_data.class_name + "Config"
            if base.id == "CasadiModel":
                self.model_obj = node
                self._modify_model_class(node)
                node.name = self.mpc_data.class_name
        return node

    def _modify_model_class(self, node: ast.ClassDef):
        """Rewrite setup_system and point the ``config`` annotation to the new config class.

        Args:
            node: The class definition node of the model.

        """
        for item in node.body:
            if isinstance(item, ast.FunctionDef) and item.name == "setup_system":
                self.modify_setup_system(item)
            # only simple ``config: SomeConfig`` annotations are retargeted; attribute
            # targets (``self.config: ...``) have no ``.id``
            if (
                isinstance(item, ast.AnnAssign)
                and isinstance(item.target, ast.Name)
                and item.target.id == "config"
            ):
                item.annotation = ast.parse(
                    self.mpc_data.class_name + "Config", mode="eval"
                ).body

    def get_leftmost_list(
        self, node: Union[ast.Tuple, ast.BinOp, ast.List]
    ) -> Optional[ast.List]:
        """Recursively traverse binary operations to get the leftmost list.

        Args:
            node: An AST node (could be a BinOp, a Tuple or directly a List)

        Returns:
            The leftmost List node found, or None if there is none

        """
        if isinstance(node, ast.List):
            return node
        if isinstance(node, ast.BinOp):
            # for ``[...] + other`` the list literal is on the left
            return self.get_leftmost_list(node.left)
        if isinstance(node, ast.Tuple) and node.elts:
            return self.get_leftmost_list(node.elts[0])
        return None

    @staticmethod
    def _config_field_name(stmt: ast.stmt) -> Optional[str]:
        """Return the field name assigned by a config-class body statement, if any.

        Handles ``name: T = value`` and ``name = value``; returns None for every other
        statement (docstrings, methods, attribute or tuple targets, ...).

        """
        if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
            return stmt.target.id
        if (
            isinstance(stmt, ast.Assign)
            and len(stmt.targets) == 1
            and isinstance(stmt.targets[0], ast.Name)
        ):
            return stmt.targets[0].id
        return None

    def _config_field_entries(
        self, field: Union[ast.AnnAssign, ast.Assign], field_name: str
    ) -> Union[ast.List, ast.Tuple]:
        """Return the literal node new entries of a config field are appended to.

        Accepts a list literal, list concatenations such as ``[...] + other`` (the
        leftmost list literal is used) and tuple literals of entries.

        Args:
            field: The assignment statement of the field.
            field_name: The field's name, used in the error message.

        Returns:
            The ``ast.List`` (or ``ast.Tuple``) whose ``elts`` can be extended.

        Raises:
            ValueError: If no list or tuple literal can be found in the field value.

        """
        value = field.value
        entries = self.get_leftmost_list(value)
        if entries is None and isinstance(value, ast.Tuple):
            # a tuple literal of entries, e.g. ``(CasadiInput(...), ...)``
            entries = value
        if entries is None:
            raise ValueError(
                f"Cannot modify the '{field_name}' field of the config class: "
                f"expected a list literal, got {type(value).__name__}."
            )
        return entries

    def modify_config_class_shadow(self, node: ast.ClassDef):
        """Modify the config class of the shadow mpc.

        Adds the full baseline control trajectories (including binary controls) and the
        provision flag as inputs, and the flexibility timing parameters and the
        objective weights as parameters.

        Args:
            node: The class definition node of the config.

        """
        for field in node.body:
            field_name = self._config_field_name(field)
            # add the time and full baseline control trajectory as inputs
            if field_name == "inputs":
                entries = self._config_field_entries(field, field_name)
                # also include binary controls
                for control in [*self.controls, *(self.binary_controls or [])]:
                    entries.elts.append(
                        add_input(
                            f"{control.name}{full_trajectory_suffix}",
                            None,
                            control.unit,
                            "full control trajectory output of baseline mpc",
                            "pd.Series",
                        )
                    )
                entries.elts.append(
                    add_input("in_provision", False, "-", "provision flag", "bool")
                )
            # add the flex variables and the weights
            elif field_name == "parameters":
                entries = self._config_field_entries(field, field_name)
                for param_name in [PREP_TIME, FLEX_EVENT_DURATION, MARKET_TIME]:
                    entries.elts.append(
                        add_parameter(param_name, 0, "s", "time to switch objective")
                    )
                for weight in self.mpc_data.weights:
                    entries.elts.append(
                        add_parameter(
                            weight.name,
                            weight.value,
                            "-",
                            "Weight for P in objective function",
                        )
                    )

    def modify_config_class_baseline(self, node: ast.ClassDef):
        """Modify the config class of the baseline mpc.

        Adds the flexibility inputs (external power profile, provision flag, relative
        event start/end) and the parameters from
        ``mpc_data.config_parameters_appendix``. The ``outputs`` field is left
        untouched.

        Args:
            node: The class definition node of the config.

        """
        for field in node.body:
            field_name = self._config_field_name(field)
            # add the flexibility inputs
            if field_name == "inputs":
                entries = self._config_field_entries(field, field_name)
                entries.elts.extend(
                    [
                        add_input(
                            "_P_external",
                            0,
                            "W",
                            "External power profile to be provided",
                            "pd.Series",
                        ),
                        add_input(
                            "in_provision",
                            False,
                            "-",
                            "Flag signaling if the flexibility is in provision",
                            "bool",
                        ),
                        add_input(
                            "rel_start",
                            0,
                            "s",
                            "relative start time of the flexibility event",
                            "int",
                        ),
                        add_input(
                            "rel_end",
                            0,
                            "s",
                            "relative end time of the flexibility event",
                            "int",
                        ),
                    ]
                )
            # add the flex variables and the weights
            elif field_name == "parameters":
                entries = self._config_field_entries(field, field_name)
                for parameter in self.mpc_data.config_parameters_appendix:
                    entries.elts.append(
                        add_parameter(parameter.name, 0, "-", parameter.description)
                    )

    @staticmethod
    def _find_constraints_assignment(node: ast.FunctionDef) -> Optional[ast.Assign]:
        """Return the first top-level ``<obj>.constraints = [...]`` statement, if any."""
        for stmt in node.body:
            if (
                isinstance(stmt, ast.Assign)
                and isinstance(stmt.targets[0], ast.Attribute)
                and stmt.targets[0].attr == "constraints"
                and isinstance(stmt.value, ast.List)
            ):
                return stmt
        return None

    @staticmethod
    def _control_bound_statement(
        control: MPCVariable, side: str, bound_attr: str
    ) -> ast.stmt:
        """Build ``<name>_<side> = ca.if_else(...)`` fixing a control before market time.

        Before ``market_time`` the bound equals the full baseline trajectory; afterwards
        it is the control's own ``bound_attr`` (``lb`` or ``ub``).

        """
        return ast.parse(
            f"{control.name}_{side} = ca.if_else(self.time < "
            f"self.market_time.sym, "
            f"self.{control.name}{full_trajectory_suffix}.sym, "
            f"self.{control.name}.{bound_attr})"
        ).body[0]

    @staticmethod
    def _find_return_index(node: ast.FunctionDef) -> Optional[int]:
        """Return the index of the first top-level return statement, or None."""
        return next(
            (index for index, stmt in enumerate(node.body) if isinstance(stmt, ast.Return)),
            None,
        )

    @staticmethod
    def _standard_objective_assignment(value: Optional[ast.expr]) -> ast.Assign:
        """Build ``obj_std = <value>`` for the original objective expression."""
        return ast.Assign(
            targets=[ast.Name(id="obj_std", ctx=ast.Store())],
            value=value,
        )

    def modify_setup_system_shadow(self, node: ast.FunctionDef):
        """Modify the setup_system method of the shadow mpc model class.

        For every control, ``<name>_lower``/``<name>_upper`` bounds are inserted at the
        top of the function and ``(<name>_lower, self.<name>, <name>_upper)`` is
        appended to the constraints list. Then the return statement is replaced by
        ``obj_std``/``obj_flex`` assignments and the shadow cost function.

        Args:
            node: The function definition node of setup_system.

        """
        # constrain the control trajectories for t < market_time
        constraints = self._find_constraints_assignment(node)
        if constraints is not None:
            # insertion order of the generated statements: controls in reverse order,
            # each as lower then upper bound
            bound_statements: list[ast.stmt] = []
            for control in reversed(self.controls):
                bound_statements.append(
                    self._control_bound_statement(control, "lower", "lb")
                )
                bound_statements.append(
                    self._control_bound_statement(control, "upper", "ub")
                )
            node.body[0:0] = bound_statements
            for control in self.controls:
                constraints.value.elts.append(
                    ast.parse(
                        f"({control.name}_lower, self.{control.name}, {control.name}_upper)",
                        mode="eval",
                    ).body
                )

        return_index = self._find_return_index(node)
        if return_index is None:
            return
        original_return = node.body[return_index].value
        # replace the return statement (and anything after it) with the new objective
        node.body[return_index:] = [
            self._standard_objective_assignment(original_return),
            ast.Assign(
                targets=[ast.Name(id="obj_flex", ctx=ast.Store())],
                value=ast.parse(self.mpc_data.flex_cost_function, mode="eval").body,
            ),
            ast.Return(value=ast.parse(SHADOW_MPC_COST_FUNCTION).body[0].value),
        ]

    def modify_setup_system_baseline(self, node: ast.FunctionDef):
        """Modify the setup_system method of the baseline mpc model class.

        The original return expression is stored in ``obj_std`` and the function
        returns the baseline cost function built from ``mpc_data.power_variable`` and
        ``mpc_data.comfort_variable`` instead.

        Args:
            node: The function definition node of setup_system.

        """
        return_index = self._find_return_index(node)
        if return_index is None:
            return
        original_return = node.body[return_index].value
        # replace the return statement (and anything after it) with the new objective
        node.body[return_index:] = [
            self._standard_objective_assignment(original_return),
            ast.Return(
                value=ast.parse(
                    return_baseline_cost_function(
                        power_variable=self.mpc_data.power_variable,
                        comfort_variable=self.mpc_data.comfort_variable,
                    )
                )
                .body[0]
                .value
            ),
        ]
def add_import_to_tree(name: str, alias: str, tree: ast.Module) -> ast.Module:
    """Add ``import name as alias`` to the module unless it is already present.

    The import is skipped when a top-level ``import`` statement already binds
    ``alias`` to ``name`` (also inside ``import a as b, name as alias``). Otherwise it
    is inserted at the top of the module, after a leading module docstring and any
    ``from __future__`` imports, which must remain first in the file.

    Args:
        name: name of the module to be imported
        alias: alias of the module
        tree: the module tree the import is added to; modified in place

    Returns:
        The tree updated with the import statement

    """
    already_imported = any(
        imported.name == name and imported.asname == alias
        for stmt in tree.body
        if isinstance(stmt, ast.Import)
        for imported in stmt.names
    )
    if already_imported:
        return tree

    # skip past the module docstring and __future__ imports
    insert_at = 0
    for index, stmt in enumerate(tree.body):
        is_docstring = (
            index == 0
            and isinstance(stmt, ast.Expr)
            and isinstance(stmt.value, ast.Constant)
            and isinstance(stmt.value.value, str)
        )
        is_future_import = (
            isinstance(stmt, ast.ImportFrom) and stmt.module == "__future__"
        )
        if not (is_docstring or is_future_import):
            break
        insert_at = index + 1

    tree.body.insert(
        insert_at, ast.Import(names=[ast.alias(name=name, asname=alias)])
    )
    return tree
def remove_all_imports_from_tree(tree: ast.Module) -> ast.Module:
    """Strip every top-level import statement from a module.

    Both ``import x`` and ``from x import y`` statements that sit directly in the
    module body are dropped; imports nested inside functions or classes are kept.

    Args:
        tree: The module whose body is filtered; it is modified in place.

    Returns:
        The same module object with its filtered body.

    """
    import_kinds = (ast.Import, ast.ImportFrom)
    kept_statements = []
    for statement in tree.body:
        if isinstance(statement, import_kinds):
            continue
        kept_statements.append(statement)
    # assign a fresh list rather than mutating the old one
    tree.body = kept_statements
    return tree