Coverage for agentlib_flexquant/utils/parsing.py: 85%
151 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-15 15:25 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-15 15:25 +0000
1import ast
2from typing import Union, List, Optional
3from string import Template
4from agentlib_mpc.data_structures.mpc_datamodels import MPCVariable
5from agentlib_flexquant.data_structures.mpcs import (
6 BaseMPCData,
7 PFMPCData,
8 NFMPCData,
9 BaselineMPCData,
10)
11from agentlib_flexquant.data_structures.globals import (
12 SHADOW_MPC_COST_FUNCTION,
13 return_baseline_cost_function,
14 full_trajectory_prefix,
15 full_trajectory_suffix,
16 MARKET_TIME,
17 PREP_TIME,
18 FLEX_EVENT_DURATION
19)
# Class names of the agentlib_mpc variable types written into the generated code
CASADI_INPUT = "CasadiInput"
CASADI_PARAMETER = "CasadiParameter"
CASADI_OUTPUT = "CasadiOutput"

# String templates for the generated variable declarations.
# Inputs and parameters share the same layout. NOTE: neither of them contains a
# ``$type`` placeholder, so a ``type`` passed to ``substitute`` is not emitted.
_VALUE_DECLARATION = (
    "$class_name(name='$name', value=$value, unit='$unit', "
    "description='$description')"
)
INPUT_TEMPLATE = Template(_VALUE_DECLARATION)
PARAMETER_TEMPLATE = Template(_VALUE_DECLARATION)
# Outputs additionally carry their type.
OUTPUT_TEMPLATE = Template(
    "$class_name(name='$name', unit='$unit', type='$type', value=$value, "
    "description='$description')"
)
def create_ast_element(template_string: str) -> ast.Call:
    """Parse a code snippet and return the expression of its first statement.

    Args:
        template_string: Python source whose first statement is an expression.
            In this module it is always a generated constructor call such as
            ``CasadiInput(...)``.

    Returns:
        ast.Call: The expression node of the first parsed statement.

    """
    parsed_module = ast.parse(template_string)
    first_statement = parsed_module.body[0]
    return first_statement.value
def add_input(name: str, value: Union[bool, str, int], unit: str, description: str, type: str) -> ast.Call:
    """Build the AST node of a ``CasadiInput(...)`` declaration.

    Args:
        name: The name of the input.
        value: The default value of the input. It is inserted into the source
            text unquoted, so a string such as ``"pd.Series([0])"`` becomes an
            expression rather than a string literal.
        unit: The unit associated with the input value.
        description: A human-readable description of the input.
        type: The data type of the input (e.g. "bool", "pd.Series").
            NOTE: ``INPUT_TEMPLATE`` has no ``$type`` placeholder, so this value
            does not appear in the generated call.

    Returns:
        ast.Call: The call node representing the input declaration.

    """
    declaration_fields = {
        "class_name": CASADI_INPUT,
        "name": name,
        "value": value,
        "unit": unit,
        "description": description,
        "type": type,
    }
    declaration_source = INPUT_TEMPLATE.substitute(declaration_fields)
    return create_ast_element(declaration_source)
def add_parameter(name: str, value: Union[int, float], unit: str, description: str) -> ast.Call:
    """Build the AST node of a ``CasadiParameter(...)`` declaration.

    Args:
        name: The name of the parameter.
        value: The value of the parameter; inserted into the source text unquoted.
        unit: The unit associated with the parameter value.
        description: A human-readable description of the parameter.

    Returns:
        ast.Call: The call node representing the parameter declaration.

    """
    declaration_source = PARAMETER_TEMPLATE.substitute(
        {
            "class_name": CASADI_PARAMETER,
            "name": name,
            "value": value,
            "unit": unit,
            "description": description,
        }
    )
    return create_ast_element(declaration_source)
def add_output(name: str, unit: str, type: str, value: Union[str, float], description: str) -> ast.Call:
    """Build the AST node of a ``CasadiOutput(...)`` declaration.

    Args:
        name: The name of the output.
        unit: The unit associated with the output value.
        type: The data type of the output (e.g. "float", "pd.Series"); emitted
            as a quoted string.
        value: The value of the output; inserted into the source text unquoted,
            so ``"pd.Series([0])"`` becomes an expression.
        description: A human-readable description of the output.

    Returns:
        ast.Call: The call node representing the output declaration.

    """
    declaration_fields = dict(
        class_name=CASADI_OUTPUT,
        name=name,
        unit=unit,
        type=type,
        value=value,
        description=description,
    )
    declaration_source = OUTPUT_TEMPLATE.substitute(declaration_fields)
    return create_ast_element(declaration_source)
class SetupSystemModifier(ast.NodeTransformer):
    """A custom AST transformer for modifying the MPC model file.

    This class traverses the AST of the input file, identifies the relevant classes
    and methods, and performs the necessary modifications:

    * the class deriving from ``CasadiModelConfig`` gets additional inputs, outputs
      and/or parameters and is renamed to ``<class_name>Config``;
    * the class deriving from ``CasadiModel`` gets a modified ``setup_system``, its
      ``config`` annotation retargeted and is renamed to ``<class_name>``.

    The concrete modifications depend on the type of ``mpc_data``: shadow MPCs
    (``PFMPCData`` / ``NFMPCData``) and the baseline MPC (``BaselineMPCData``) are
    modified differently.

    """

    def __init__(
        self,
        mpc_data: BaseMPCData,
        controls: List[MPCVariable],
        binary_controls: Optional[List[MPCVariable]],
    ):
        self.mpc_data = mpc_data
        self.controls = controls
        self.binary_controls = binary_controls
        # AST nodes of the config and the model class, set while visiting
        self.config_obj: Optional[ast.ClassDef] = None
        self.model_obj: Optional[ast.ClassDef] = None
        # select modification of setup_system based on mpc type; the baseline
        # check comes last, so it wins if both were to match.
        # NOTE(review): for any other mpc_data type neither method is bound and
        # visiting a config/model class raises AttributeError.
        if isinstance(mpc_data, (PFMPCData, NFMPCData)):
            self.modify_config_class = self.modify_config_class_shadow
            self.modify_setup_system = self.modify_setup_system_shadow
        if isinstance(mpc_data, BaselineMPCData):
            self.modify_config_class = self.modify_config_class_baseline
            self.modify_setup_system = self.modify_setup_system_baseline

    # ------------------------------------------------------------------ helpers

    def _all_controls(self) -> List[MPCVariable]:
        """Return the controls followed by the binary controls (if any)."""
        return list(self.controls) + list(self.binary_controls or [])

    @staticmethod
    def _full_trajectory_name(control: MPCVariable) -> str:
        """Return the name of the variable holding a control's full trajectory."""
        return f"{full_trajectory_prefix}{control.name}{full_trajectory_suffix}"

    @staticmethod
    def _assigned_name(stmt: ast.stmt) -> Optional[str]:
        """Return the plain name bound by a class-level assignment, else None.

        Handles annotated (``x: T = ...``) and single-target (``x = ...``)
        assignments; anything else (docstrings, methods, ...) yields None.
        """
        if isinstance(stmt, ast.AnnAssign):
            target = stmt.target
        elif isinstance(stmt, ast.Assign) and len(stmt.targets) == 1:
            target = stmt.targets[0]
        else:
            return None
        return target.id if isinstance(target, ast.Name) else None

    def _field_list(self, stmt: ast.stmt, field_name: str) -> ast.List:
        """Return the list literal of a config field that entries are appended to.

        Raises:
            ValueError: If no list literal can be found in the field's value.
        """
        value_list = self.get_leftmost_list(stmt.value)
        if value_list is None:
            raise ValueError(
                f"Cannot extend the '{field_name}' field of the model config: "
                "expected a list literal, a concatenation of lists or a tuple."
            )
        return value_list

    @staticmethod
    def _find_return_index(node: ast.FunctionDef) -> Optional[int]:
        """Return the index of the first top-level return statement, else None."""
        return next(
            (index for index, stmt in enumerate(node.body) if isinstance(stmt, ast.Return)),
            None,
        )

    def _find_constraints_list(self, node: ast.FunctionDef) -> Optional[ast.List]:
        """Return the list literal assigned to ``<obj>.constraints`` in a function.

        The first top-level ``<obj>.constraints = ...`` assignment whose value
        contains a list literal (see ``get_leftmost_list``) is used.
        """
        for stmt in node.body:
            if (
                isinstance(stmt, ast.Assign)
                and isinstance(stmt.targets[0], ast.Attribute)
                and stmt.targets[0].attr == "constraints"
            ):
                constraints = self.get_leftmost_list(stmt.value)
                if constraints is not None:
                    return constraints
        return None

    # ----------------------------------------------------------------- visitors

    def visit_Module(self, module: ast.Module) -> ast.Module:
        """Visit a module definition in the AST.

        Append (baseline) or delete (shadow MPCs) the import statements at the top
        of the module, then visit the module's children.

        Args:
            module: The module definition node in the AST.

        Returns:
            The possibly modified module definition node.

        """
        # append imports for baseline
        if isinstance(self.mpc_data, BaselineMPCData):
            module = add_import_to_tree(name="pandas", alias="pd", tree=module)
            module = add_import_to_tree(name="casadi", alias="ca", tree=module)
        # delete imports for shadow MPCs
        if isinstance(self.mpc_data, (NFMPCData, PFMPCData)):
            module = remove_all_imports_from_tree(module)
        # trigger the next visit method (ClassDef)
        self.generic_visit(module)
        return module

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        """Visit a class definition in the AST.

        Classes with a direct base named ``CasadiModelConfig`` are treated as the
        model config, classes with a direct base named ``CasadiModel`` as the
        model; both are modified and renamed. Nested classes are not visited.

        Args:
            node: The class definition node in the AST.

        Returns:
            The possibly modified class definition node.

        """
        for base in node.bases:
            if isinstance(base, ast.Name) and base.id == "CasadiModelConfig":
                # get ast object and trigger modification
                self.config_obj = node
                self.modify_config_class(node)
                # change class name
                node.name = self.mpc_data.class_name + "Config"
            if isinstance(base, ast.Name) and base.id == "CasadiModel":
                # get ast object and trigger modification
                self.model_obj = node
                for item in node.body:
                    if isinstance(item, ast.FunctionDef) and item.name == "setup_system":
                        self.modify_setup_system(item)
                    # point the config annotation to the renamed config class
                    if (
                        isinstance(item, ast.AnnAssign)
                        and isinstance(item.target, ast.Name)
                        and item.target.id == "config"
                    ):
                        item.annotation = (
                            ast.parse(self.mpc_data.class_name + "Config").body[0].value
                        )
                # change class name
                node.name = self.mpc_data.class_name
        return node

    def get_leftmost_list(self, node: Optional[ast.AST]) -> Optional[ast.List]:
        """Recursively traverse binary operations to get the leftmost list.

        Args:
            node: An AST node. A List is returned as is; for a BinOp its left
                operand and for a Tuple its first element is searched.

        Returns:
            The leftmost List node found, or None if there is none.

        """
        if isinstance(node, ast.List):
            return node
        if isinstance(node, ast.BinOp):
            # If it's a binary operation, recurse to the left
            return self.get_leftmost_list(node.left)
        if isinstance(node, ast.Tuple) and node.elts:
            # If it's a non-empty tuple, check the first element
            return self.get_leftmost_list(node.elts[0])
        # If we get here, we couldn't find a list
        return None

    # ------------------------------------------------------ config modification

    def modify_config_class_shadow(self, node: ast.ClassDef):
        """Modify the config class of the shadow mpc.

        Adds a full-trajectory input per control and binary control plus the
        ``in_provision`` flag to ``inputs``, and the flex time parameters and the
        objective weights to ``parameters``.

        Args:
            node: The class definition node of the config.

        Raises:
            ValueError: If a modified field holds no list literal.

        """
        for stmt in node.body:
            field_name = self._assigned_name(stmt)
            # add the time and full control trajectory inputs
            if field_name == "inputs":
                inputs = self._field_list(stmt, field_name)
                for control in self._all_controls():
                    inputs.elts.append(
                        add_input(
                            name=self._full_trajectory_name(control),
                            value="pd.Series([0])",
                            unit="W",
                            description="full control output",
                            type="pd.Series",
                        )
                    )
                inputs.elts.append(
                    add_input("in_provision", False, "-", "provision flag", "bool")
                )
            # add the flex variables and the weights
            elif field_name == "parameters":
                parameters = self._field_list(stmt, field_name)
                for param_name in [PREP_TIME, FLEX_EVENT_DURATION, MARKET_TIME]:
                    parameters.elts.append(
                        add_parameter(param_name, 0, "s", "time to switch objective")
                    )
                for weight in self.mpc_data.weights:
                    parameters.elts.append(
                        add_parameter(
                            weight.name,
                            weight.value,
                            "-",
                            "Weight for P in objective function",
                        )
                    )

    def modify_config_class_baseline(self, node: ast.ClassDef):
        """Modify the config class of the baseline mpc.

        Adds a full-trajectory output per control and binary control to
        ``outputs``, the flexibility inputs (``_P_external``, ``in_provision``,
        ``rel_start``, ``rel_end``) to ``inputs`` and the entries of
        ``mpc_data.config_parameters_appendix`` to ``parameters``.

        Args:
            node: The class definition node of the config.

        Raises:
            ValueError: If a modified field holds no list literal.

        """
        for stmt in node.body:
            field_name = self._assigned_name(stmt)
            # add the full control trajectories to the baseline config class
            if field_name == "outputs":
                outputs = self._field_list(stmt, field_name)
                for control in self._all_controls():
                    outputs.elts.append(
                        add_output(
                            name=self._full_trajectory_name(control),
                            unit="W",
                            type="pd.Series",
                            value="pd.Series([0])",
                            description="full control output",
                        )
                    )
            # add the flexibility inputs
            elif field_name == "inputs":
                inputs = self._field_list(stmt, field_name)
                inputs.elts.extend(
                    [
                        add_input(
                            "_P_external",
                            0,
                            "W",
                            "External power profile to be provided",
                            "pd.Series",
                        ),
                        add_input(
                            "in_provision",
                            False,
                            "-",
                            "Flag signaling if the flexibility is in provision",
                            "bool",
                        ),
                        add_input(
                            "rel_start",
                            0,
                            "s",
                            "relative start time of the flexibility event",
                            "int",
                        ),
                        add_input(
                            "rel_end",
                            0,
                            "s",
                            "relative end time of the flexibility event",
                            "int",
                        ),
                    ]
                )
            # add the flex variables
            elif field_name == "parameters":
                parameters = self._field_list(stmt, field_name)
                for parameter in self.mpc_data.config_parameters_appendix:
                    parameters.elts.append(
                        add_parameter(parameter.name, 0, "-", parameter.description)
                    )

    # ------------------------------------------------ setup_system modification

    def modify_setup_system_shadow(self, node: ast.FunctionDef):
        """Modify the setup_system method of the shadow mpc model class.

        For every control and binary control, ``<name>_lower``/``<name>_upper``
        are defined at the top of the function via ``ca.if_else``: they equal the
        full trajectory while ``self.time < self.market_time.sym`` and the
        control's lb/ub otherwise. ``(<name>_lower, self.<name>, <name>_upper)`` is
        appended to the constraints list. The first top-level return is replaced
        by ``obj_std = <original>``, ``obj_flex = <flex_cost_function>`` and a
        return of ``SHADOW_MPC_COST_FUNCTION``.

        Args:
            node: The function definition node of setup_system.

        """
        # constraint the control trajectories for t < market_time; the list is
        # located before any statement is inserted into node.body
        # NOTE(review): "market_time" is hard-coded here while the parameter is
        # added under MARKET_TIME — assumes both are the same name, confirm.
        constraints = self._find_constraints_list(node)
        if constraints is not None:
            for control in self._all_controls():
                full_trajectory = f"self.{self._full_trajectory_name(control)}.sym"
                # both bounds are inserted at index 0, so "_lower" ends up first
                for bound, limit in (("upper", "ub"), ("lower", "lb")):
                    node.body.insert(
                        0,
                        ast.parse(
                            f"{control.name}_{bound} = ca.if_else("
                            f"self.time < self.market_time.sym, "
                            f"{full_trajectory}, self.{control.name}.{limit})"
                        ).body[0],
                    )
                # append to constraints
                constraints.elts.append(
                    ast.parse(
                        f"({control.name}_lower, self.{control.name}, {control.name}_upper)",
                        mode="eval",
                    ).body
                )
        # replace the return statement (and anything after it)
        return_index = self._find_return_index(node)
        if return_index is not None:
            original_return = node.body[return_index].value
            node.body[return_index:] = [
                # create new standard objective variable
                ast.Assign(
                    targets=[ast.Name(id="obj_std", ctx=ast.Store())],
                    value=original_return,
                ),
                # create flex objective variable
                ast.Assign(
                    targets=[ast.Name(id="obj_flex", ctx=ast.Store())],
                    value=ast.parse(self.mpc_data.flex_cost_function, mode="eval").body,
                ),
                # overwrite return statement with custom function
                ast.Return(value=ast.parse(SHADOW_MPC_COST_FUNCTION).body[0].value),
            ]

    def modify_setup_system_baseline(self, node: ast.FunctionDef):
        """Modify the setup_system method of the baseline mpc model class.

        The first top-level return is replaced by one
        ``self.<full trajectory>.alg = self.<control>`` assignment per control and
        binary control, ``obj_std = <original>`` and a return of the baseline cost
        function built by ``return_baseline_cost_function``.

        Args:
            node: The function definition node of setup_system.

        """
        # set the control trajectories with the respective variables; the target
        # is built as a nested Attribute (self.<traj>).alg so the tree is also
        # valid for compile(), not only for unparsing
        full_traj_list = [
            ast.Assign(
                targets=[
                    ast.Attribute(
                        value=ast.Attribute(
                            value=ast.Name(id="self", ctx=ast.Load()),
                            attr=self._full_trajectory_name(control),
                            ctx=ast.Load(),
                        ),
                        attr="alg",
                        ctx=ast.Store(),
                    )
                ],
                value=ast.Attribute(
                    value=ast.Name(id="self", ctx=ast.Load()),
                    attr=control.name,
                    ctx=ast.Load(),
                ),
            )
            for control in self._all_controls()
        ]
        # replace the return statement (and anything after it)
        return_index = self._find_return_index(node)
        if return_index is not None:
            original_return = node.body[return_index].value
            node.body[return_index:] = full_traj_list + [
                # create new standard objective variable
                ast.Assign(
                    targets=[ast.Name(id="obj_std", ctx=ast.Store())],
                    value=original_return,
                ),
                # overwrite return statement with custom function
                ast.Return(
                    value=ast.parse(
                        return_baseline_cost_function(
                            power_variable=self.mpc_data.power_variable,
                            comfort_variable=self.mpc_data.comfort_variable,
                        )
                    )
                    .body[0]
                    .value
                ),
            ]
def add_import_to_tree(name: str, alias: str, tree: ast.Module) -> ast.Module:
    """Add ``import name as alias`` to the module unless it is already present.

    Every top-level ``import`` statement of the module is checked; the import is
    skipped only if one of them already binds ``alias`` to ``name`` (e.g. an
    existing ``import pandas as pd``). Otherwise the statement is inserted at the
    top of the module, after a module docstring and any ``from __future__``
    imports (which must precede all other statements).

    Args:
        name: name of the module to be imported
        alias: alias of the module
        tree: the module tree the import is added to (modified in place)

    Returns:
        The tree updated with the import statement

    """
    for stmt in tree.body:
        if isinstance(stmt, ast.Import) and any(
            imported.name == name and (imported.asname or imported.name) == alias
            for imported in stmt.names
        ):
            return tree

    body = tree.body
    insert_at = 0
    # keep a module docstring first so it stays the module's docstring
    if (
        body
        and isinstance(body[0], ast.Expr)
        and isinstance(body[0].value, ast.Constant)
        and isinstance(body[0].value.value, str)
    ):
        insert_at = 1
    # `from __future__` imports must come before any other statement
    while (
        insert_at < len(body)
        and isinstance(body[insert_at], ast.ImportFrom)
        and body[insert_at].module == "__future__"
    ):
        insert_at += 1
    body.insert(insert_at, ast.Import(names=[ast.alias(name=name, asname=alias)]))
    return tree
def remove_all_imports_from_tree(tree: ast.Module) -> ast.Module:
    """Strip every top-level import statement from a module tree.

    Both ``import x`` and ``from x import y`` statements directly in the module
    body are dropped; imports nested inside functions, classes or other
    statements are left untouched.

    Args:
        tree: The module tree to clean. Its ``body`` is replaced by a new list.

    Returns:
        The same tree object, now without top-level imports.

    """
    import_types = (ast.Import, ast.ImportFrom)
    kept_statements = []
    for statement in tree.body:
        if isinstance(statement, import_types):
            continue
        kept_statements.append(statement)
    tree.body = kept_statements
    return tree