Coverage for agentlib_flexquant/utils/parsing.py: 83%
160 statements
« prev ^ index » next coverage.py v7.4.4, created at 2026-03-26 09:43 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2026-03-26 09:43 +0000
1import ast
2import logging
3from string import Template
4from typing import Optional, Union
6from agentlib_mpc.data_structures.mpc_datamodels import MPCVariable
8from agentlib_flexquant.data_structures.globals import (
9 SHADOW_MPC_COST_FUNCTION,
10 full_trajectory_suffix,
11 return_baseline_cost_function,
12 PROVISION_VAR_NAME,
13 ACCEPTED_POWER_VAR_NAME,
14 RELATIVE_EVENT_START_TIME_VAR_NAME,
15 RELATIVE_EVENT_END_TIME_VAR_NAME
16)
17from agentlib_flexquant.data_structures.mpcs import (
18 BaselineMPCData,
19 NFMPCData,
20 PFMPCData,
21)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Names substituted for ``$class_name`` in the templates below.
CASADI_INPUT = "CasadiInput"
CASADI_PARAMETER = "CasadiParameter"
CASADI_OUTPUT = "CasadiOutput"

# Source-code templates for the constructor calls that are parsed into AST
# nodes. ``$value`` is inserted verbatim (as an expression); every other
# placeholder sits between single quotes and therefore becomes a string literal.
INPUT_TEMPLATE = Template(
    "$class_name("
    "name='$name', value=$value, unit='$unit', type='$type', "
    "description='$description'"
    ")"
)
PARAMETER_TEMPLATE = Template(
    "$class_name("
    "name='$name', value=$value, unit='$unit', description='$description'"
    ")"
)
OUTPUT_TEMPLATE = Template(
    "$class_name("
    "name='$name', unit='$unit', type='$type', value=$value, "
    "description='$description'"
    ")"
)
def create_ast_element(template_string: str) -> ast.expr:
    """Parse a snippet of Python source and return its first expression.

    Args:
        template_string: Python source whose first statement is an expression
            statement, e.g. a rendered constructor call.

    Returns:
        ast.expr: The expression node wrapped by the first statement of the
            parsed source. Any further statements are ignored.

    """
    parsed_module = ast.parse(template_string)
    first_statement = parsed_module.body[0]
    return first_statement.value
def add_input(
    name: str, value: Union[bool, str, int], unit: str, description: str, type: str
) -> ast.expr:
    """Create an AST node for an input definition.

    The node is obtained by rendering ``INPUT_TEMPLATE`` and parsing the
    resulting source text.

    Args:
        name: The name of the input.
        value: The default value for the input. Its ``str()`` form is inserted
            verbatim into the generated source, i.e. it is parsed as a Python
            expression rather than as a quoted string.
        unit: The unit associated with the input value.
        description: A human-readable description of the input.
        type: The data type of the input (e.g., "float", "int", "string").

    Returns:
        ast.expr: The call node representing the input definition.

    """

    def _escape(text) -> str:
        # The template wraps these fields in single quotes. Escape backslashes,
        # control/non-ASCII characters and single quotes so that arbitrary text
        # (e.g. "user's setpoint") parses back to exactly the same string
        # instead of breaking or altering the generated call.
        return str(text).encode("unicode_escape").decode("ascii").replace("'", "\\'")

    return create_ast_element(
        INPUT_TEMPLATE.substitute(
            class_name=CASADI_INPUT,
            name=_escape(name),
            value=value,
            unit=_escape(unit),
            description=_escape(description),
            type=_escape(type),
        )
    )
def add_parameter(
    name: str, value: Union[int, float], unit: str, description: str
) -> ast.expr:
    """Create an AST node for a parameter definition.

    The node is obtained by rendering ``PARAMETER_TEMPLATE`` and parsing the
    resulting source text.

    Args:
        name: The name of the parameter.
        value: The value of the parameter. Its ``str()`` form is inserted
            verbatim into the generated source, i.e. it is parsed as a Python
            expression.
        unit: The unit associated with the parameter value.
        description: A human-readable description of the parameter.

    Returns:
        ast.expr: The call node representing the parameter definition.

    """

    def _escape(text) -> str:
        # The template wraps these fields in single quotes. Escape backslashes,
        # control/non-ASCII characters and single quotes so that arbitrary text
        # parses back to exactly the same string instead of breaking or
        # altering the generated call.
        return str(text).encode("unicode_escape").decode("ascii").replace("'", "\\'")

    return create_ast_element(
        PARAMETER_TEMPLATE.substitute(
            class_name=CASADI_PARAMETER,
            name=_escape(name),
            value=value,
            unit=_escape(unit),
            description=_escape(description),
        )
    )
def add_output(
    name: str, unit: str, type: str, value: Union[str, float], description: str
) -> ast.expr:
    """Create an AST node for an output definition.

    The node is obtained by rendering ``OUTPUT_TEMPLATE`` and parsing the
    resulting source text.

    Args:
        name: The name of the output.
        unit: The unit associated with the output value.
        type: The data type of the output (e.g., "float", "string").
        value: The value of the output. Its ``str()`` form is inserted verbatim
            into the generated source, i.e. it is parsed as a Python
            expression rather than as a quoted string.
        description: A human-readable description of the output.

    Returns:
        ast.expr: The call node representing the output definition.

    """

    def _escape(text) -> str:
        # The template wraps these fields in single quotes. Escape backslashes,
        # control/non-ASCII characters and single quotes so that arbitrary text
        # parses back to exactly the same string instead of breaking or
        # altering the generated call.
        return str(text).encode("unicode_escape").decode("ascii").replace("'", "\\'")

    return create_ast_element(
        OUTPUT_TEMPLATE.substitute(
            class_name=CASADI_OUTPUT,
            name=_escape(name),
            unit=_escape(unit),
            type=_escape(type),
            value=value,
            description=_escape(description),
        )
    )
def _get_assignment_name(node: ast.stmt) -> Optional[str]:
    """Return the name bound by a simple assignment statement.

    Annotated assignments (``x: T = ...``) and plain assignments with exactly
    one target (``x = ...``) are recognised, provided the target is a bare
    name.

    Args:
        node: An AST statement node.

    Returns:
        The assigned variable name, or None if the node is not such an
        assignment (e.g. chained targets, attribute targets, other statements).

    """
    if isinstance(node, ast.AnnAssign):
        target = node.target
    elif isinstance(node, ast.Assign) and len(node.targets) == 1:
        target = node.targets[0]
    else:
        return None
    return target.id if isinstance(target, ast.Name) else None
class SetupSystemModifier(ast.NodeTransformer):
    """A custom AST transformer for modifying the MPC model file.

    The transformer walks the module, adjusts its imports, and rewrites the
    classes deriving from ``CasadiModelConfig`` (the config class) and
    ``CasadiModel`` (the model class): config fields are extended, the
    ``setup_system`` method is rewritten and both classes are renamed after
    ``mpc_data.class_name``. Which modifications are applied depends on whether
    ``mpc_data`` describes the baseline MPC or a shadow MPC.

    """

    def __init__(
        self,
        mpc_data: Union[BaselineMPCData, NFMPCData, PFMPCData],
        controls: list[MPCVariable],
        binary_controls: Optional[list[MPCVariable]],
    ):
        """Store the MPC description and select the matching modifications.

        Args:
            mpc_data: Data of the MPC the model file is generated for.
            controls: The control variables of the MPC.
            binary_controls: The binary control variables of the MPC, if any.

        """
        self.mpc_data = mpc_data
        self.controls = controls
        self.binary_controls = binary_controls
        # AST nodes of the config and the model class, set while visiting
        self.config_obj: Union[None, ast.expr] = None
        self.model_obj: Union[None, ast.expr] = None
        # select the modifications based on the mpc type
        if isinstance(mpc_data, (PFMPCData, NFMPCData)):
            self.modify_config_class = self.modify_config_class_shadow
            self.modify_setup_system = self.modify_setup_system_shadow
        if isinstance(mpc_data, BaselineMPCData):
            self.modify_config_class = self.modify_config_class_baseline
            self.modify_setup_system = self.modify_setup_system_baseline

    def visit_Module(self, module: ast.Module) -> ast.Module:
        """Visit a module definition in the AST.

        Append or delete the import statements at the top of the module, then
        continue with the child nodes.

        Args:
            module: The module definition node in the AST.

        Returns:
            The possibly modified module definition node.

        """
        # append imports for baseline
        if isinstance(self.mpc_data, BaselineMPCData):
            module = add_import_to_tree(name="pandas", alias="pd", tree=module)
            module = add_import_to_tree(name="casadi", alias="ca", tree=module)
        # delete imports for shadow MPCs
        if isinstance(self.mpc_data, (NFMPCData, PFMPCData)):
            module = remove_all_imports_from_tree(module)
        # trigger the next visit method (ClassDef)
        self.generic_visit(module)
        return module

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        """Visit a class definition in the AST.

        Classes listing ``CasadiModelConfig`` or ``CasadiModel`` as a base are
        modified and renamed; all other classes are returned unchanged.

        Args:
            node: The class definition node in the AST.

        Returns:
            The possibly modified class definition node.

        """
        for base in node.bases:
            if not isinstance(base, ast.Name):
                continue
            if base.id == "CasadiModelConfig":
                # get ast object and trigger modification
                self.config_obj = node
                self.modify_config_class(node)
                # change class name
                node.name = self.mpc_data.class_name + "Config"
            if base.id == "CasadiModel":
                # get ast object and trigger modification
                self.model_obj = node
                for item in node.body:
                    if (
                        isinstance(item, ast.FunctionDef)
                        and item.name == "setup_system"
                    ):
                        self.modify_setup_system(item)
                    # point the ``config`` annotation to the renamed config
                    # class; only bare-name targets carry an ``id``
                    if (
                        isinstance(item, ast.AnnAssign)
                        and isinstance(item.target, ast.Name)
                        and item.target.id == "config"
                    ):
                        item.annotation = (
                            ast.parse(self.mpc_data.class_name + "Config").body[0].value
                        )
                # change class name
                node.name = self.mpc_data.class_name

        return node

    def get_leftmost_list(
        self, node: Union[ast.Tuple, ast.BinOp, ast.List]
    ) -> Optional[ast.List]:
        """Recursively traverse binary operations to get the leftmost list.

        Args:
            node: An AST node (could be a BinOp, a Tuple or directly a List)

        Returns:
            The leftmost List node found, or None if there is none.

        """
        if isinstance(node, ast.List):
            return node
        if isinstance(node, ast.BinOp):
            # If it's a binary operation, recurse to the left
            return self.get_leftmost_list(node.left)
        if isinstance(node, ast.Tuple) and node.elts:
            # If it's a tuple with elements, check the first element
            return self.get_leftmost_list(node.elts[0])
        # If we get here, we couldn't find a list
        return None

    def _annotated_fields(
        self, node: ast.ClassDef
    ) -> list[tuple[str, ast.AnnAssign]]:
        """Collect the annotated class variables of a config class.

        Non-annotated assignments are reported with a warning. Every other
        statement (methods, docstrings, ``pass``, ...) is skipped silently.

        Args:
            node: The class definition node of the config.

        Returns:
            Pairs of variable name and annotated assignment node, in source order.

        """
        fields = []
        for body in node.body:
            if isinstance(body, ast.AnnAssign):
                # targets other than bare names have no ``id`` to match against
                if isinstance(body.target, ast.Name):
                    fields.append((body.target.id, body))
            elif isinstance(body, ast.Assign):
                var_name = _get_assignment_name(body)
                logger.warning(
                    "Skipping non-annotated class variable '%s' in config class '%s'. "
                    "Only type-annotated variables (e.g., 'var: Type = value') can be "
                    "modified by the AST transformer. If this variable should be "
                    "included in the MPC configuration, please add a type annotation.",
                    var_name or "<unknown>",
                    node.name,
                )
        return fields

    def _extend_field(
        self, field: ast.AnnAssign, new_elements: list, class_name: str
    ) -> None:
        """Append elements to the list literal assigned to a config field.

        Args:
            field: The annotated assignment of the config field.
            new_elements: The AST nodes to append. Nothing happens if empty.
            class_name: Name of the config class, used for the error message.

        Raises:
            ValueError: If elements have to be added but the field value holds
                no list literal they could be appended to.

        """
        if not new_elements:
            return
        target_list = self.get_leftmost_list(field.value)
        if target_list is None:
            raise ValueError(
                f"Cannot extend field '{field.target.id}' of config class "
                f"'{class_name}': its value must be a list literal, optionally "
                f"concatenated with further lists."
            )
        target_list.elts.extend(new_elements)

    def modify_config_class_shadow(self, node: ast.ClassDef):
        """Modify the config class of the shadow mpc.

        Args:
            node: The class definition node of the config.

        Raises:
            ValueError: If a field that has to be extended holds no list literal.

        """
        for field_name, field in self._annotated_fields(node):
            if field_name == "inputs":
                # add the full baseline control trajectories as inputs,
                # binary controls included
                all_controls = [*self.controls, *(self.binary_controls or [])]
                new_inputs = [
                    add_input(
                        f"{control.name}{full_trajectory_suffix}",
                        None,
                        control.unit,
                        "full control trajectory output of baseline mpc",
                        "pd.Series",
                    )
                    for control in all_controls
                ]
                new_inputs.extend(
                    add_input(var.name, var.value, var.unit, var.description, var.type)
                    for var in self.mpc_data.config_inputs_appendix
                )
                self._extend_field(field, new_inputs, node.name)
            elif field_name == "parameters":
                # add the flex variables and the weights
                new_parameters = [
                    add_parameter(
                        parameter.name,
                        parameter.value,
                        parameter.unit,
                        parameter.description,
                    )
                    for parameter in self.mpc_data.config_parameters_appendix
                ]
                self._extend_field(field, new_parameters, node.name)

    def modify_config_class_baseline(self, node: ast.ClassDef):
        """Modify the config class of the baseline mpc.

        Only the ``inputs`` and ``parameters`` fields are extended.
        NOTE(review): the ``outputs`` field used to be inspected here without
        being changed; confirm that no outputs have to be added.

        Args:
            node: The class definition node of the config.

        Raises:
            ValueError: If a field that has to be extended holds no list literal.

        """
        for field_name, field in self._annotated_fields(node):
            if field_name == "inputs":
                # add the flexibility inputs
                flex_inputs = [
                    add_input(
                        ACCEPTED_POWER_VAR_NAME,
                        0,
                        "W",
                        "External power profile to be provided",
                        "pd.Series",
                    ),
                    add_input(
                        PROVISION_VAR_NAME,
                        False,
                        "-",
                        "Flag signaling if the flexibility is in provision",
                        "bool",
                    ),
                    add_input(
                        RELATIVE_EVENT_START_TIME_VAR_NAME,
                        0,
                        "s",
                        "relative start time of the flexibility event",
                        "int",
                    ),
                    add_input(
                        RELATIVE_EVENT_END_TIME_VAR_NAME,
                        0,
                        "s",
                        "relative end time of the flexibility event",
                        "int",
                    ),
                ]
                self._extend_field(field, flex_inputs, node.name)
            elif field_name == "parameters":
                # add the flex variables and the weights
                new_parameters = [
                    add_parameter(parameter.name, 0, "-", parameter.description)
                    for parameter in self.mpc_data.config_parameters_appendix
                ]
                self._extend_field(field, new_parameters, node.name)

    def modify_setup_system_shadow(self, node: ast.FunctionDef):
        """Modify the setup_system method of the shadow mpc model class.

        For the first ``<obj>.constraints = [...]`` assignment with a list
        literal, bound definitions are inserted at the start of the function
        and one constraint per control is appended. Afterwards the first
        top-level return statement and everything following it is replaced by
        the assignments of ``obj_std`` and ``obj_flex`` and a new return
        statement.

        Args:
            node: The function definition node of setup_system.

        """
        # constrain the control trajectories for t < market_time
        for item in node.body:
            is_constraints_assignment = (
                isinstance(item, ast.Assign)
                and isinstance(item.targets[0], ast.Attribute)
                and item.targets[0].attr == "constraints"
            )
            if not (is_constraints_assignment and isinstance(item.value, ast.List)):
                continue
            for control in self.controls:
                # insert control boundaries at beginning of function
                node.body.insert(
                    0,
                    ast.parse(
                        f"{control.name}_upper = ca.if_else(self.time < "
                        f"self.market_time.sym, "
                        f"self.{control.name}{full_trajectory_suffix}.sym, "
                        f"self.{control.name}.ub)"
                    ).body[0],
                )
                node.body.insert(
                    0,
                    ast.parse(
                        f"{control.name}_lower = ca.if_else(self.time < "
                        f"self.market_time.sym, "
                        f"self.{control.name}{full_trajectory_suffix}.sym, "
                        f"self.{control.name}.lb)"
                    ).body[0],
                )
                # append to constraints
                new_element = (
                    ast.parse(
                        f"({control.name}_lower, self.{control.name}, {control.name}_upper)"
                    )
                    .body[0]
                    .value
                )
                item.value.elts.append(new_element)
            # node.body was changed above, so stop iterating over it
            break

        # loop through setup_system function to find return statement
        for i, stmt in enumerate(node.body):
            if not isinstance(stmt, ast.Return):
                continue
            # store current return statement
            original_return = stmt.value

            if self.mpc_data.flex_cost_function_appendix:
                # Parse the appendix string into an AST expression
                appendix_ast = ast.parse(
                    self.mpc_data.flex_cost_function_appendix, mode="eval"
                ).body
                # Create a BinOp node representing: original_return + appendix
                combined_value = ast.BinOp(
                    left=original_return, op=ast.Add(), right=appendix_ast
                )
            else:
                combined_value = original_return

            new_body = [
                # create new standard objective variable
                ast.Assign(
                    targets=[ast.Name(id="obj_std", ctx=ast.Store())],
                    value=combined_value,
                ),
                # create flex objective variable
                ast.Assign(
                    targets=[ast.Name(id="obj_flex", ctx=ast.Store())],
                    value=ast.parse(
                        self.mpc_data.flex_cost_function, mode="eval"
                    ).body,
                ),
                # overwrite return statement with custom function
                ast.Return(value=ast.parse(SHADOW_MPC_COST_FUNCTION).body[0].value),
            ]
            # replace the return statement and everything after it
            node.body[i:] = new_body
            break

    def modify_setup_system_baseline(self, node: ast.FunctionDef):
        """Modify the setup_system method of the baseline mpc model class.

        The first top-level return statement and everything following it is
        replaced by an assignment of the original return value to ``obj_std``
        and a new return statement built by ``return_baseline_cost_function``.

        Args:
            node: The function definition node of setup_system.

        """
        # loop through setup_system function to find return statement
        for i, stmt in enumerate(node.body):
            if not isinstance(stmt, ast.Return):
                continue
            # store current return statement
            original_return = stmt.value
            new_body = [
                # create new standard objective variable
                ast.Assign(
                    targets=[ast.Name(id="obj_std", ctx=ast.Store())],
                    value=original_return,
                ),
                # overwrite return statement with custom function
                ast.Return(
                    value=ast.parse(
                        return_baseline_cost_function(
                            power_variable=self.mpc_data.power_variable,
                            comfort_variable=self.mpc_data.comfort_variable,
                        )
                    )
                    .body[0]
                    .value
                ),
            ]
            # replace the return statement and everything after it
            node.body[i:] = new_body
            break
def add_import_to_tree(name: str, alias: str, tree: ast.Module) -> ast.Module:
    """Add import to the module.

    The statement 'import name as alias' will be added unless exactly this
    import is already present among the top-level import statements. The new
    statement is placed at the top of the module, but after a module docstring
    and any ``from __future__`` imports, which must stay first.

    Args:
        name: name of the module to be imported
        alias: alias of the module
        tree: the tree to be imported

    Returns:
        The tree updated with the import statement

    """
    # check all top-level import statements, not only the first one
    already_imported = any(
        imported.name == name and imported.asname == alias
        for node in tree.body
        if isinstance(node, ast.Import)
        for imported in node.names
    )
    if already_imported:
        return tree

    # find the first position at which a regular import is allowed
    position = 0
    if tree.body:
        first = tree.body[0]
        if (
            isinstance(first, ast.Expr)
            and isinstance(first.value, ast.Constant)
            and isinstance(first.value.value, str)
        ):
            # keep the module docstring as the first statement
            position = 1
    while (
        position < len(tree.body)
        and isinstance(tree.body[position], ast.ImportFrom)
        and tree.body[position].module == "__future__"
    ):
        position += 1

    import_statement = ast.Import(names=[ast.alias(name=name, asname=alias)])
    tree.body.insert(position, import_statement)
    return tree
def remove_all_imports_from_tree(tree: ast.Module) -> ast.Module:
    """Drop every top-level import statement from the module.

    Both ``import x`` and ``from x import y`` statements are removed. Imports
    nested inside functions or classes are not touched.

    Args:
        tree: The module whose body is filtered.

    Returns:
        The same tree object, with its body replaced by a new list that
        contains the remaining statements in their original order.

    """
    import_node_types = (ast.Import, ast.ImportFrom)
    kept_statements = []
    for statement in tree.body:
        if isinstance(statement, import_node_types):
            continue
        kept_statements.append(statement)
    tree.body = kept_statements
    return tree