Coverage for agentlib_flexquant/utils/parsing.py: 85%
151 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-01 15:10 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-01 15:10 +0000
1import ast
2from typing import Union, List, Optional
3from agentlib_flexquant.data_structures.mpcs import (
4 BaseMPCData,
5 PFMPCData,
6 NFMPCData,
7 BaselineMPCData,
8)
9from agentlib_flexquant.data_structures.globals import (
10 SHADOW_MPC_COST_FUNCTION,
11 return_baseline_cost_function,
12 full_trajectory_prefix,
13 full_trajectory_suffix,
14 PROFILE_DEVIATION_WEIGHT,
15 MARKET_TIME,
16 PREP_TIME,
17 FLEX_EVENT_DURATION
18)
19from agentlib_mpc.data_structures.mpc_datamodels import MPCVariable
20from string import Template
# Class names of the CasADi model variables emitted into the generated model code.
CASADI_INPUT = "CasadiInput"
CASADI_PARAMETER = "CasadiParameter"
CASADI_OUTPUT = "CasadiOutput"

# Source-code templates rendered by the add_* helpers and then parsed into AST
# nodes. Inputs and parameters use the same field layout; outputs additionally
# carry a type.
# NOTE(review): the input layout has no ``$type`` placeholder, so a ``type``
# passed when substituting INPUT_TEMPLATE is silently ignored — confirm intent.
_VALUE_VARIABLE_PATTERN = (
    "$class_name(name='$name', value=$value, unit='$unit', description='$description')"
)
INPUT_TEMPLATE = Template(_VALUE_VARIABLE_PATTERN)
PARAMETER_TEMPLATE = Template(_VALUE_VARIABLE_PATTERN)
OUTPUT_TEMPLATE = Template(
    "$class_name(name='$name', unit='$unit', type='$type', "
    "value=$value, description='$description')"
)
def create_ast_element(template_string):
    """Parse Python source and return the ``value`` node of its first statement.

    Args:
        template_string (str): Source code whose first statement has a
            ``value`` (e.g. a bare expression or an assignment), typically a
            rendered template.

    Returns:
        ast.AST: The ``value`` attribute of the first parsed statement.
    """
    parsed_module = ast.parse(template_string)
    leading_statement = parsed_module.body[0]
    return leading_statement.value
def add_input(name, value, unit, description, type):
    """Build the AST node of a ``CasadiInput(...)`` constructor call.

    Args:
        name: Variable name, rendered inside single quotes.
        value: Initial value, rendered verbatim as Python source (a string such
            as ``"pd.Series([0])"`` becomes an expression, ``0`` a literal).
        unit: Unit, rendered inside single quotes.
        description: Description, rendered inside single quotes.
        type: Intended variable type. NOTE(review): INPUT_TEMPLATE has no
            ``$type`` placeholder, so this value is not rendered into the
            generated call — confirm whether that is intended.

    Returns:
        ast.expr: The parsed ``CasadiInput(...)`` call expression.
    """
    fields = {
        "class_name": CASADI_INPUT,
        "name": name,
        "value": value,
        "unit": unit,
        "description": description,
        "type": type,
    }
    source = INPUT_TEMPLATE.substitute(fields)
    return create_ast_element(source)
def add_parameter(name, value, unit, description):
    """Build the AST node of a ``CasadiParameter(...)`` constructor call.

    Args:
        name: Parameter name, rendered inside single quotes.
        value: Default value, rendered verbatim as Python source.
        unit: Unit, rendered inside single quotes.
        description: Description, rendered inside single quotes.

    Returns:
        ast.expr: The parsed ``CasadiParameter(...)`` call expression.
    """
    rendered_call = PARAMETER_TEMPLATE.substitute(
        dict(
            class_name=CASADI_PARAMETER,
            name=name,
            value=value,
            unit=unit,
            description=description,
        )
    )
    return create_ast_element(rendered_call)
def add_output(name, unit, type, value, description):
    """Build the AST node of a ``CasadiOutput(...)`` constructor call.

    Args:
        name: Output name, rendered inside single quotes.
        unit: Unit, rendered inside single quotes.
        type: Variable type, rendered inside single quotes.
        value: Initial value, rendered verbatim as Python source.
        description: Description, rendered inside single quotes.

    Returns:
        ast.expr: The parsed ``CasadiOutput(...)`` call expression.
    """
    call_source = OUTPUT_TEMPLATE.substitute(
        description=description,
        value=value,
        type=type,
        unit=unit,
        name=name,
        class_name=CASADI_OUTPUT,
    )
    output_node = create_ast_element(call_source)
    return output_node
class SetupSystemModifier(ast.NodeTransformer):
    """A custom AST transformer for modifying the MPC model file.

    The transformer visits the module of a CasADi model file. It finds the
    config class (a class with base ``CasadiModelConfig``) and the model class
    (a class with base ``CasadiModel``). It then rewrites both into either the
    shadow MPC variant (``PFMPCData``/``NFMPCData``) or the baseline MPC
    variant (``BaselineMPCData``), depending on the type of ``mpc_data``.

    Attributes:
        mpc_data (BaseMPCData): Settings of the MPC being generated. Its type
            selects which ``modify_*`` methods are used. Its ``class_name``
            becomes the name of the generated model class, and
            ``<class_name>Config`` the name of the generated config class.
        controls (List[MPCVariable]): Control variables of the MPC.
        binary_controls (Optional[List[MPCVariable]]): Additional binary
            control variables, or ``None``.
        config_obj (Optional[ast.ClassDef]): The visited config class node.
        model_obj (Optional[ast.ClassDef]): The visited model class node.
    """

    def __init__(
        self,
        mpc_data: BaseMPCData,
        controls: List[MPCVariable],
        binary_controls: Optional[List[MPCVariable]],
    ):
        self.mpc_data = mpc_data
        self.controls = controls
        self.binary_controls = binary_controls
        # AST nodes of the config and the model class, filled in while visiting
        self.config_obj: Optional[ast.ClassDef] = None
        self.model_obj: Optional[ast.ClassDef] = None
        # select the modification strategy based on the mpc type
        if isinstance(mpc_data, (PFMPCData, NFMPCData)):
            self.modify_config_class = self.modify_config_class_shadow
            self.modify_setup_system = self.modify_setup_system_shadow
        if isinstance(mpc_data, BaselineMPCData):
            self.modify_config_class = self.modify_config_class_baseline
            self.modify_setup_system = self.modify_setup_system_baseline

    # ------------------------------------------------------------------ helpers

    def _all_controls(self):
        """Return the controls followed by the binary controls (if any)."""
        return list(self.controls) + list(self.binary_controls or [])

    @staticmethod
    def _full_trajectory_name(control):
        """Return the name of the full-trajectory variable of ``control``."""
        return f"{full_trajectory_prefix}{control.name}{full_trajectory_suffix}"

    @staticmethod
    def _field_name(stmt):
        """Return the name a class-body statement assigns to, or None.

        Handles annotated (``x: T = ...``) and plain single-target (``x = ...``)
        assignments. Every other statement (docstrings, methods, ...) yields
        None instead of raising.
        """
        if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
            return stmt.target.id
        if (
            isinstance(stmt, ast.Assign)
            and len(stmt.targets) == 1
            and isinstance(stmt.targets[0], ast.Name)
        ):
            return stmt.targets[0].id
        return None

    def _get_value_list(self, stmt):
        """Return the list literal of ``stmt`` that new fields are appended to.

        Supports plain lists as well as concatenations/tuples, via
        get_leftmost_list.

        Raises:
            ValueError: If the assigned value contains no reachable list literal.
        """
        value_list = self.get_leftmost_list(stmt.value)
        if value_list is None:
            raise ValueError(
                f"Could not find a list literal in the definition of "
                f"'{self._field_name(stmt)}' of the model config class."
            )
        return value_list

    @staticmethod
    def _find_return_index(body):
        """Return the index of the first top-level ``return`` in ``body``, or None."""
        for index, stmt in enumerate(body):
            if isinstance(stmt, ast.Return):
                return index
        return None

    def _find_constraints_list(self, body):
        """Return the list literal of the first ``<obj>.constraints = ...`` in ``body``.

        Returns None if no such assignment with a reachable list literal exists.
        """
        for stmt in body:
            if (
                isinstance(stmt, ast.Assign)
                and isinstance(stmt.targets[0], ast.Attribute)
                and stmt.targets[0].attr == "constraints"
            ):
                constraints = self.get_leftmost_list(stmt.value)
                if constraints is not None:
                    return constraints
        return None

    # ----------------------------------------------------------------- visitors

    def visit_Module(self, module):
        """Visit a module definition in the AST.

        Adds ``import pandas as pd`` and ``import casadi as ca`` for the
        baseline MPC, or removes all top-level imports for the shadow MPCs.
        Then continues with the class definitions.

        Args:
            module (ast.Module): The module definition node in the AST.

        Returns:
            ast.Module: The possibly modified module definition node.
        """
        if isinstance(self.mpc_data, BaselineMPCData):
            module = add_import_to_tree(name="pandas", alias="pd", tree=module)
            module = add_import_to_tree(name="casadi", alias="ca", tree=module)
        if isinstance(self.mpc_data, (NFMPCData, PFMPCData)):
            module = remove_all_imports_from_tree(module)
        # trigger the next visit method (ClassDef)
        self.generic_visit(module)
        return module

    def visit_ClassDef(self, node):
        """Visit a class definition in the AST.

        A class with base ``CasadiModelConfig`` has its fields modified and is
        renamed to ``<class_name>Config``. A class with base ``CasadiModel``
        has its ``setup_system`` method modified. Its ``config`` annotation is
        pointed at the renamed config class, and the class itself is renamed
        to ``<class_name>``.

        Args:
            node (ast.ClassDef): The class definition node in the AST.

        Returns:
            ast.ClassDef: The possibly modified class definition node.
        """
        for base in node.bases:
            if not isinstance(base, ast.Name):
                continue
            if base.id == "CasadiModelConfig":
                self.config_obj = node
                self.modify_config_class(node)
                node.name = self.mpc_data.class_name + "Config"
            if base.id == "CasadiModel":
                self.model_obj = node
                for item in node.body:
                    if isinstance(item, ast.FunctionDef) and item.name == "setup_system":
                        self.modify_setup_system(item)
                    # re-point the ``config`` annotation to the renamed config class
                    if (
                        isinstance(item, ast.AnnAssign)
                        and isinstance(item.target, ast.Name)
                        and item.target.id == "config"
                    ):
                        item.annotation = (
                            ast.parse(self.mpc_data.class_name + "Config").body[0].value
                        )
                node.name = self.mpc_data.class_name
        return node

    def get_leftmost_list(self, node):
        """Recursively traverse binary operations to get the leftmost list.

        Args:
            node: An AST node (e.g. an ``ast.List``, an ``ast.BinOp`` such as a
                list concatenation, or an ``ast.Tuple``), or None.

        Returns:
            The leftmost ``ast.List`` node found, or None if there is none.
        """
        if isinstance(node, ast.List):
            return node
        if isinstance(node, ast.BinOp):
            # for a concatenation, descend into the left operand
            return self.get_leftmost_list(node.left)
        if isinstance(node, ast.Tuple) and node.elts:
            # for a tuple, descend into its first element
            return self.get_leftmost_list(node.elts[0])
        return None

    # ------------------------------------------------------- config modifiers

    def modify_config_class_shadow(self, node):
        """Modify the config class of the shadow mpc.

        Appends a full-trajectory input per (binary) control plus the
        ``in_provision`` input to ``inputs``. Appends the flexibility timing
        parameters and the objective weights of ``mpc_data`` to ``parameters``.

        Args:
            node (ast.ClassDef): The class definition node of the config.

        Raises:
            ValueError: If ``inputs``/``parameters`` has no list literal.
        """
        for stmt in node.body:
            field = self._field_name(stmt)
            if field == "inputs":
                value_list = self._get_value_list(stmt)
                for control in self._all_controls():
                    # positional order: name, value, unit, description, type
                    value_list.elts.append(
                        add_input(
                            self._full_trajectory_name(control),
                            "pd.Series([0])",
                            "W",
                            "full control output",
                            "pd.Series",
                        )
                    )
                value_list.elts.append(
                    add_input("in_provision", False, "-", "provision flag", "bool")
                )
            elif field == "parameters":
                value_list = self._get_value_list(stmt)
                for param_name in (PREP_TIME, FLEX_EVENT_DURATION, MARKET_TIME):
                    value_list.elts.append(
                        add_parameter(param_name, 0, "s", "time to switch objective")
                    )
                for weight in self.mpc_data.weights:
                    value_list.elts.append(
                        add_parameter(
                            weight.name,
                            weight.value,
                            "-",
                            "Weight for P in objective function",
                        )
                    )

    def modify_config_class_baseline(self, node):
        """Modify the config class of the baseline mpc.

        Appends a full-trajectory output per (binary) control to ``outputs``.
        Appends the flexibility inputs (external power profile, provision flag,
        event start/end) to ``inputs``. Appends the
        ``config_parameters_appendix`` of ``mpc_data`` to ``parameters``.

        Args:
            node (ast.ClassDef): The class definition node of the config.

        Raises:
            ValueError: If ``outputs``/``inputs``/``parameters`` has no list literal.
        """
        # (name, value, unit, description, type) of the added flexibility inputs
        flex_inputs = (
            ("_P_external", 0, "W", "External power profile to be provided", "pd.Series"),
            ("in_provision", False, "-", "Flag signaling if the flexibility is in provision", "bool"),
            ("rel_start", 0, "s", "relative start time of the flexibility event", "int"),
            ("rel_end", 0, "s", "relative end time of the flexibility event", "int"),
        )
        for stmt in node.body:
            field = self._field_name(stmt)
            if field == "outputs":
                value_list = self._get_value_list(stmt)
                for control in self._all_controls():
                    value_list.elts.append(
                        add_output(
                            self._full_trajectory_name(control),
                            "W",
                            "pd.Series",
                            "pd.Series([0])",
                            "full control output",
                        )
                    )
            elif field == "inputs":
                value_list = self._get_value_list(stmt)
                for input_args in flex_inputs:
                    value_list.elts.append(add_input(*input_args))
            elif field == "parameters":
                value_list = self._get_value_list(stmt)
                for parameter in self.mpc_data.config_parameters_appendix:
                    value_list.elts.append(
                        add_parameter(parameter.name, 0, "-", parameter.description)
                    )

    # ------------------------------------------------- setup_system modifiers

    def modify_setup_system_shadow(self, node):
        """Modify the setup_system method of the shadow mpc model class.

        For every (binary) control, ``<name>_lower``/``<name>_upper`` bounds are
        defined at the top of the method. While ``self.time < self.market_time``,
        each bound equals the control's full-trajectory input; otherwise it is
        the control's own ``lb``/``ub``. A constraint
        ``(<name>_lower, self.<name>, <name>_upper)`` is appended to the first
        ``self.constraints`` list. The return statement is then replaced by
        ``obj_std``/``obj_flex`` assignments and SHADOW_MPC_COST_FUNCTION.

        Args:
            node (ast.FunctionDef): The function definition node of setup_system.
        """
        constraints = self._find_constraints_list(node.body)
        if constraints is not None:
            bound_statements = []
            for control in self._all_controls():
                full_name = self._full_trajectory_name(control)
                upper = ast.parse(
                    f"{control.name}_upper = ca.if_else(self.time < self.market_time.sym, "
                    f"self.{full_name}.sym, "
                    f"self.{control.name}.ub)"
                ).body[0]
                lower = ast.parse(
                    f"{control.name}_lower = ca.if_else(self.time < self.market_time.sym, "
                    f"self.{full_name}.sym, "
                    f"self.{control.name}.lb)"
                ).body[0]
                # prepend each pair so the generated statement order matches the
                # previous implementation (last processed control first)
                bound_statements[:0] = [lower, upper]
                constraints.elts.append(
                    ast.parse(
                        f"({control.name}_lower, self.{control.name}, {control.name}_upper)"
                    )
                    .body[0]
                    .value
                )
            # insert after iterating, never while iterating node.body
            node.body[:0] = bound_statements

        return_index = self._find_return_index(node.body)
        if return_index is None:
            return
        original_return = node.body[return_index].value
        # replace the return statement (and anything after it)
        node.body[return_index:] = [
            # standard objective: the original return expression
            ast.Assign(
                targets=[ast.Name(id="obj_std", ctx=ast.Store())],
                value=original_return,
            ),
            # flexibility objective
            ast.Assign(
                targets=[ast.Name(id="obj_flex", ctx=ast.Store())],
                value=ast.parse(self.mpc_data.flex_cost_function, mode="eval").body,
            ),
            ast.Return(value=ast.parse(SHADOW_MPC_COST_FUNCTION).body[0].value),
        ]

    def modify_setup_system_baseline(self, node):
        """Modify the setup_system method of the baseline mpc model class.

        Replaces the return statement (and anything after it) with three parts:
        ``self.<full trajectory>.alg = self.<control>`` for every (binary)
        control, ``obj_std = <original return expression>``, and a return of
        the cost function built by ``return_baseline_cost_function``.

        Args:
            node (ast.FunctionDef): The function definition node of setup_system.
        """
        full_trajectory_assignments = [
            ast.Assign(
                # nested attribute ``self.<full name>.alg``; an attr containing a
                # dot would only work after unparsing, not when compiled directly
                targets=[
                    ast.Attribute(
                        value=ast.Attribute(
                            value=ast.Name(id="self", ctx=ast.Load()),
                            attr=self._full_trajectory_name(control),
                            ctx=ast.Load(),
                        ),
                        attr="alg",
                        ctx=ast.Store(),
                    )
                ],
                value=ast.Attribute(
                    value=ast.Name(id="self", ctx=ast.Load()),
                    attr=control.name,
                    ctx=ast.Load(),
                ),
            )
            for control in self._all_controls()
        ]

        return_index = self._find_return_index(node.body)
        if return_index is None:
            return
        original_return = node.body[return_index].value
        node.body[return_index:] = full_trajectory_assignments + [
            # standard objective: the original return expression
            ast.Assign(
                targets=[ast.Name(id="obj_std", ctx=ast.Store())],
                value=original_return,
            ),
            ast.Return(
                value=ast.parse(
                    return_baseline_cost_function(
                        power_variable=self.mpc_data.power_variable,
                        comfort_variable=self.mpc_data.comfort_variable,
                    )
                )
                .body[0]
                .value
            ),
        ]
def add_import_to_tree(name: str, alias: str, tree: ast.Module):
    """Add ``import <name> as <alias>`` near the top of a module AST.

    The statement is skipped when the module already has a top-level
    ``import`` that binds exactly ``<name>`` as ``<alias>``. The previous
    implementation inserted it unconditionally, creating duplicates. A plain
    ``import <name>`` does not count, since it does not bind ``<alias>``.

    The import is placed after a leading module docstring and after any
    ``from __future__`` imports. Future imports must stay the first statements
    of a module, so inserting before them would produce a SyntaxError.

    Args:
        name: Module to import, e.g. ``"pandas"``.
        alias: Name to bind the module to, e.g. ``"pd"``.
        tree: Module AST; modified in place.

    Returns:
        ast.Module: The same ``tree`` object.
    """
    already_imported = any(
        existing.name == name and existing.asname == alias
        for stmt in tree.body
        if isinstance(stmt, ast.Import)
        for existing in stmt.names
    )
    if already_imported:
        return tree

    body = tree.body
    insert_at = 0
    # keep a leading module docstring in first position
    if (
        body
        and isinstance(body[0], ast.Expr)
        and isinstance(body[0].value, ast.Constant)
        and isinstance(body[0].value.value, str)
    ):
        insert_at = 1
    # ``from __future__`` imports must precede all other statements
    while (
        insert_at < len(body)
        and isinstance(body[insert_at], ast.ImportFrom)
        and body[insert_at].module == "__future__"
    ):
        insert_at += 1
    body.insert(insert_at, ast.Import(names=[ast.alias(name=name, asname=alias)]))
    return tree
def remove_all_imports_from_tree(tree: ast.Module):
    """Strip every top-level ``import`` / ``from ... import`` from a module AST.

    Only statements directly in ``tree.body`` are removed; imports nested in
    functions or classes are kept. ``tree.body`` is rebound to a new list.

    Args:
        tree: Module AST to modify.

    Returns:
        ast.Module: The same ``tree`` object.
    """
    import_node_types = (ast.Import, ast.ImportFrom)
    kept_statements = []
    for statement in tree.body:
        if isinstance(statement, import_node_types):
            continue
        kept_statements.append(statement)
    tree.body = kept_statements
    return tree