Coverage for agentlib_flexquant/utils/parsing.py: 83%

160 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2026-03-26 09:43 +0000

1import ast 

2import logging 

3from string import Template 

4from typing import Optional, Union 

5 

6from agentlib_mpc.data_structures.mpc_datamodels import MPCVariable 

7 

8from agentlib_flexquant.data_structures.globals import ( 

9 SHADOW_MPC_COST_FUNCTION, 

10 full_trajectory_suffix, 

11 return_baseline_cost_function, 

12 PROVISION_VAR_NAME, 

13 ACCEPTED_POWER_VAR_NAME, 

14 RELATIVE_EVENT_START_TIME_VAR_NAME, 

15 RELATIVE_EVENT_END_TIME_VAR_NAME 

16) 

17from agentlib_flexquant.data_structures.mpcs import ( 

18 BaselineMPCData, 

19 NFMPCData, 

20 PFMPCData, 

21) 

22 

logger = logging.getLogger(__name__)

# Constants
CASADI_INPUT = "CasadiInput"
CASADI_PARAMETER = "CasadiParameter"
CASADI_OUTPUT = "CasadiOutput"

# String templates for the generated variable definitions.
# String-valued fields are placed between single quotes, so they are escaped
# with ``_escape_single_quoted`` before substitution; ``value`` is inserted
# verbatim and is therefore interpreted as a Python expression.
INPUT_TEMPLATE = Template(
    "$class_name(name='$name', value=$value, unit='$unit', type='$type', "
    "description='$description')"
)
PARAMETER_TEMPLATE = Template(
    "$class_name(name='$name', value=$value, unit='$unit', description='$description')"
)
OUTPUT_TEMPLATE = Template(
    "$class_name(name='$name', unit='$unit', type='$type', value=$value, "
    "description='$description')"
)


def _escape_single_quoted(text: object) -> str:
    """Escape ``str(text)`` for embedding between single quotes in Python source.

    Backslashes, control characters (e.g. newlines) and non-ASCII characters are
    turned into escape sequences and single quotes are backslash-escaped, so that
    parsing ``'<result>'`` yields a string constant equal to ``str(text)``.
    Without this, a description such as "baseline's power" would make the
    generated call a SyntaxError.

    Args:
        text: The value to embed; non-strings are converted with ``str`` (as
            ``Template.substitute`` does).

    Returns:
        The escaped text, without surrounding quotes.

    """
    escaped = str(text).encode("unicode_escape").decode("ascii")
    return escaped.replace("'", "\\'")


def create_ast_element(template_string: str) -> ast.expr:
    """Convert a template string into an AST expression node.

    Args:
        template_string: Python source whose first statement is an expression,
            e.g. a call such as ``CasadiInput(name='x', ...)``.

    Returns:
        ast.expr: The expression of the first statement of the parsed source.

    """
    return ast.parse(template_string).body[0].value


def add_input(
    name: str, value: Union[bool, str, int], unit: str, description: str, type: str
) -> ast.expr:
    """Create an AST node for an input definition.

    Args:
        name: The name of the input.
        value: The default value for the input. It is inserted verbatim into the
            generated source, i.e. interpreted as a Python expression
            (e.g. ``None``, ``False``, ``0``).
        unit: The unit associated with the input value.
        description: A human-readable description of the input.
        type: The data type of the input (e.g., "float", "int", "pd.Series").

    Returns:
        ast.expr: An AST call node representing the input definition.

    """
    return create_ast_element(
        INPUT_TEMPLATE.substitute(
            class_name=CASADI_INPUT,
            name=_escape_single_quoted(name),
            value=value,
            unit=_escape_single_quoted(unit),
            description=_escape_single_quoted(description),
            type=_escape_single_quoted(type),
        )
    )


def add_parameter(
    name: str, value: Union[int, float], unit: str, description: str
) -> ast.expr:
    """Create an AST node for a parameter definition.

    Args:
        name: The name of the parameter.
        value: The value of the parameter, inserted verbatim into the generated
            source (interpreted as a Python expression).
        unit: The unit associated with the parameter value.
        description: A human-readable description of the parameter.

    Returns:
        ast.expr: An AST call node representing the parameter definition.

    """
    return create_ast_element(
        PARAMETER_TEMPLATE.substitute(
            class_name=CASADI_PARAMETER,
            name=_escape_single_quoted(name),
            value=value,
            unit=_escape_single_quoted(unit),
            description=_escape_single_quoted(description),
        )
    )


def add_output(
    name: str, unit: str, type: str, value: Union[str, float], description: str
) -> ast.expr:
    """Create an AST node for an output definition.

    Args:
        name: The name of the output.
        unit: The unit associated with the output value.
        type: The data type of the output (e.g., "float", "pd.Series").
        value: The value of the output, inserted verbatim into the generated
            source (interpreted as a Python expression).
        description: A human-readable description of the output.

    Returns:
        ast.expr: An AST call node representing the output definition.

    """
    return create_ast_element(
        OUTPUT_TEMPLATE.substitute(
            class_name=CASADI_OUTPUT,
            name=_escape_single_quoted(name),
            unit=_escape_single_quoted(unit),
            type=_escape_single_quoted(type),
            value=value,
            description=_escape_single_quoted(description),
        )
    )

139 

140 

def _get_assignment_name(node: ast.stmt) -> Optional[str]:
    """Return the target variable name of a simple assignment statement.

    Both annotated assignments (``x: int = 1``, ast.AnnAssign) and plain
    single-target assignments (``x = 1``, ast.Assign) are recognised.

    Args:
        node: An AST statement node.

    Returns:
        The assigned name when the statement assigns to a single bare name,
        otherwise None (attribute/subscript/tuple targets, chained
        assignments, or non-assignment statements).

    """
    if isinstance(node, ast.AnnAssign):
        target = node.target
    elif isinstance(node, ast.Assign) and len(node.targets) == 1:
        target = node.targets[0]
    else:
        return None
    return target.id if isinstance(target, ast.Name) else None

160 

161 

class SetupSystemModifier(ast.NodeTransformer):
    """A custom AST transformer for modifying the MPC model file.

    The transformer traverses the AST of a model file and finds the config class
    (a class with base ``CasadiModelConfig``) and the model class (a class with
    base ``CasadiModel``). It renames both after ``mpc_data.class_name`` and
    rewrites the config fields and the ``setup_system`` method. Which rewrite is
    applied depends on whether ``mpc_data`` describes the baseline MPC or one of
    the shadow MPCs.

    """

    def __init__(
        self,
        mpc_data: Union[BaselineMPCData, NFMPCData, PFMPCData],
        controls: list[MPCVariable],
        binary_controls: Optional[list[MPCVariable]],
    ):
        """Store the MPC description and select the matching modifications.

        Args:
            mpc_data: Description of the MPC to generate; its type selects the
                baseline or the shadow modifications.
            controls: Control variables of the MPC.
            binary_controls: Binary control variables of the MPC, if any.

        """
        self.mpc_data = mpc_data
        self.controls = controls
        self.binary_controls = binary_controls
        # AST class nodes of the config and the model, set while visiting
        self.config_obj: Optional[ast.ClassDef] = None
        self.model_obj: Optional[ast.ClassDef] = None
        # select modification of the config class and setup_system by mpc type
        if isinstance(mpc_data, (PFMPCData, NFMPCData)):
            self.modify_config_class = self.modify_config_class_shadow
            self.modify_setup_system = self.modify_setup_system_shadow
        if isinstance(mpc_data, BaselineMPCData):
            self.modify_config_class = self.modify_config_class_baseline
            self.modify_setup_system = self.modify_setup_system_baseline

    def visit_Module(self, module: ast.Module) -> ast.Module:
        """Visit the module node of the AST.

        For the baseline MPC, ``import pandas as pd`` and ``import casadi as ca``
        are added; for the shadow MPCs, all module-level imports are removed.

        Args:
            module: The module definition node in the AST.

        Returns:
            The possibly modified module definition node.

        """
        if isinstance(self.mpc_data, BaselineMPCData):
            module = add_import_to_tree(name="pandas", alias="pd", tree=module)
            module = add_import_to_tree(name="casadi", alias="ca", tree=module)
        if isinstance(self.mpc_data, (NFMPCData, PFMPCData)):
            module = remove_all_imports_from_tree(module)
        # continue with the class definitions (visit_ClassDef)
        self.generic_visit(module)
        return module

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        """Visit a class definition in the AST.

        A class with base ``CasadiModelConfig`` is modified with
        ``modify_config_class`` and renamed to ``<class_name>Config``.

        A class with base ``CasadiModel`` has its ``setup_system`` method
        modified with ``modify_setup_system``. Its ``config`` annotation is
        pointed at the renamed config class, and it is renamed to
        ``<class_name>``.

        Args:
            node: The class definition node in the AST.

        Returns:
            The possibly modified class definition node.

        """
        for base in node.bases:
            if not isinstance(base, ast.Name):
                continue
            if base.id == "CasadiModelConfig":
                self.config_obj = node
                self.modify_config_class(node)
                node.name = self.mpc_data.class_name + "Config"
            if base.id == "CasadiModel":
                self.model_obj = node
                for item in node.body:
                    if isinstance(item, ast.FunctionDef) and item.name == "setup_system":
                        self.modify_setup_system(item)
                    # point the ``config`` annotation to the renamed config class
                    if (
                        isinstance(item, ast.AnnAssign)
                        and isinstance(item.target, ast.Name)
                        and item.target.id == "config"
                    ):
                        item.annotation = (
                            ast.parse(self.mpc_data.class_name + "Config").body[0].value
                        )
                node.name = self.mpc_data.class_name
        return node

    def get_leftmost_list(self, node: Optional[ast.expr]) -> Optional[ast.List]:
        """Recursively traverse binary operations to get the leftmost list.

        Args:
            node: An AST node, e.g. a list literal, a concatenation such as
                ``[...] + other`` (BinOp) or a tuple.

        Returns:
            The leftmost List node found, or None if there is none.

        """
        if isinstance(node, ast.List):
            return node
        if isinstance(node, ast.BinOp):
            return self.get_leftmost_list(node.left)
        if isinstance(node, ast.Tuple) and node.elts:
            return self.get_leftmost_list(node.elts[0])
        return None

    @staticmethod
    def _get_config_field_name(stmt: ast.stmt, class_name: str) -> Optional[str]:
        """Return the name of an annotated config field, or None for other statements.

        Plain (non-annotated) assignments are skipped with a warning because they
        cannot be modified; methods, docstrings and other statements are skipped
        silently.

        Args:
            stmt: A statement of the config class body.
            class_name: Name of the config class, used in the warning.

        Returns:
            The field name for ``name: Type = value`` statements, else None.

        """
        if isinstance(stmt, ast.Assign):
            logger.warning(
                "Skipping non-annotated class variable '%s' in config class '%s'. "
                "Only type-annotated variables (e.g., 'var: Type = value') can be "
                "modified by the AST transformer. If this variable should be "
                "included in the MPC configuration, please add a type annotation.",
                _get_assignment_name(stmt) or "<unknown>",
                class_name,
            )
            return None
        if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
            return stmt.target.id
        return None

    def _get_field_list(self, field: ast.AnnAssign, class_name: str) -> ast.expr:
        """Return the literal that new entries of a config field are appended to.

        Args:
            field: Annotated assignment of a config field with a Name target.
            class_name: Name of the config class, used in the error message.

        Returns:
            The leftmost list literal of the field value (a list, a concatenation
            starting with a list, or a tuple containing one), or the tuple
            literal itself if it contains no list.

        Raises:
            ValueError: If no list or tuple literal can be found.

        """
        target_list = self.get_leftmost_list(field.value)
        if target_list is None and isinstance(field.value, ast.Tuple):
            # tuple of elements, e.g. ``inputs: tuple = (CasadiInput(...), ...)``
            target_list = field.value
        if target_list is None:
            raise ValueError(
                f"Cannot extend field '{field.target.id}' of config class "
                f"'{class_name}': its value must be a list literal, a "
                f"concatenation starting with a list literal, or a tuple literal."
            )
        return target_list

    def modify_config_class_shadow(self, node: ast.ClassDef):
        """Modify the config class of the shadow mpc.

        Appends new entries to two config fields:

        - ``inputs``: the full control trajectories of the baseline MPC (for
          controls and binary controls), plus the inputs of
          ``mpc_data.config_inputs_appendix``.
        - ``parameters``: the parameters of
          ``mpc_data.config_parameters_appendix``.

        Args:
            node: The class definition node of the config.

        Raises:
            ValueError: If ``inputs`` or ``parameters`` has no list or tuple
                literal the new entries can be appended to.

        """
        for stmt in node.body:
            field_name = self._get_config_field_name(stmt, node.name)
            if field_name == "inputs":
                value_list = self._get_field_list(stmt, node.name)
                # full baseline trajectories of all controls, incl. binary ones
                for control in (*self.controls, *(self.binary_controls or ())):
                    value_list.elts.append(
                        add_input(
                            f"{control.name}{full_trajectory_suffix}",
                            None,
                            control.unit,
                            "full control trajectory output of baseline mpc",
                            "pd.Series",
                        )
                    )
                for var in self.mpc_data.config_inputs_appendix:
                    value_list.elts.append(
                        add_input(var.name, var.value, var.unit, var.description, var.type)
                    )
            elif field_name == "parameters":
                # add the flex variables and the weights
                value_list = self._get_field_list(stmt, node.name)
                for parameter in self.mpc_data.config_parameters_appendix:
                    value_list.elts.append(
                        add_parameter(
                            parameter.name,
                            parameter.value,
                            parameter.unit,
                            parameter.description,
                        )
                    )

    def modify_config_class_baseline(self, node: ast.ClassDef):
        """Modify the config class of the baseline mpc.

        Appends the flexibility inputs (accepted power profile, provision flag,
        relative event start and end time) to ``inputs``. Appends the parameters
        of ``mpc_data.config_parameters_appendix``, with value 0 and unit "-",
        to ``parameters``.

        Args:
            node: The class definition node of the config.

        Raises:
            ValueError: If ``inputs`` or ``parameters`` has no list or tuple
                literal the new entries can be appended to.

        """
        # NOTE(review): the original code located the list of the 'outputs' field
        # here but never appended to it; that dead lookup was removed. Full
        # control trajectory outputs are presumably added elsewhere — confirm.
        for stmt in node.body:
            field_name = self._get_config_field_name(stmt, node.name)
            if field_name == "inputs":
                value_list = self._get_field_list(stmt, node.name)
                value_list.elts.extend(
                    [
                        add_input(
                            ACCEPTED_POWER_VAR_NAME,
                            0,
                            "W",
                            "External power profile to be provided",
                            "pd.Series",
                        ),
                        add_input(
                            PROVISION_VAR_NAME,
                            False,
                            "-",
                            "Flag signaling if the flexibility is in provision",
                            "bool",
                        ),
                        add_input(
                            RELATIVE_EVENT_START_TIME_VAR_NAME,
                            0,
                            "s",
                            "relative start time of the flexibility event",
                            "int",
                        ),
                        add_input(
                            RELATIVE_EVENT_END_TIME_VAR_NAME,
                            0,
                            "s",
                            "relative end time of the flexibility event",
                            "int",
                        ),
                    ]
                )
            elif field_name == "parameters":
                # add the flex variables and the weights
                value_list = self._get_field_list(stmt, node.name)
                for parameter in self.mpc_data.config_parameters_appendix:
                    value_list.elts.append(
                        add_parameter(parameter.name, 0, "-", parameter.description)
                    )

    def _add_market_time_control_bounds(self, node: ast.FunctionDef):
        """Fix the controls of the shadow MPC to the baseline before the market time.

        For every control, ``<control>_lower`` / ``<control>_upper`` are defined
        at the beginning of ``setup_system``. They take the baseline full control
        trajectory while ``self.time < self.market_time`` and the control's own
        bounds otherwise. A matching box constraint is appended to the list
        assigned to ``self.constraints``. If no such list is found, a warning is
        logged and nothing is changed.

        Args:
            node: The function definition node of setup_system.

        """
        if not self.controls:
            return
        constraints_list = None
        for stmt in node.body:
            if (
                isinstance(stmt, ast.Assign)
                and isinstance(stmt.targets[0], ast.Attribute)
                and stmt.targets[0].attr == "constraints"
            ):
                constraints_list = self.get_leftmost_list(stmt.value)
                if constraints_list is not None:
                    break
        if constraints_list is None:
            logger.warning(
                "No list assigned to 'self.constraints' found in setup_system of "
                "'%s'; the controls are not fixed to the baseline trajectory "
                "before the market time.",
                self.mpc_data.class_name,
            )
            return
        # the body is only mutated after the search loop has finished
        for control in self.controls:
            name = control.name
            full_trajectory = f"self.{name}{full_trajectory_suffix}.sym"
            # both are inserted at index 0, so the lower bound ends up first
            node.body.insert(
                0,
                ast.parse(
                    f"{name}_upper = ca.if_else(self.time < self.market_time.sym, "
                    f"{full_trajectory}, self.{name}.ub)"
                ).body[0],
            )
            node.body.insert(
                0,
                ast.parse(
                    f"{name}_lower = ca.if_else(self.time < self.market_time.sym, "
                    f"{full_trajectory}, self.{name}.lb)"
                ).body[0],
            )
            constraints_list.elts.append(
                ast.parse(f"({name}_lower, self.{name}, {name}_upper)").body[0].value
            )

    def modify_setup_system_shadow(self, node: ast.FunctionDef):
        """Modify the setup_system method of the shadow mpc model class.

        Adds the market-time bounds of the controls. The first top-level return
        statement and everything after it is then replaced by three statements:

        - ``obj_std`` = the original objective, plus
          ``mpc_data.flex_cost_function_appendix`` if given;
        - ``obj_flex`` = ``mpc_data.flex_cost_function``;
        - a return of ``SHADOW_MPC_COST_FUNCTION``.

        Args:
            node: The function definition node of setup_system.

        """
        self._add_market_time_control_bounds(node)
        for i, stmt in enumerate(node.body):
            if not isinstance(stmt, ast.Return):
                continue
            std_objective = stmt.value
            if self.mpc_data.flex_cost_function_appendix:
                appendix_ast = ast.parse(
                    self.mpc_data.flex_cost_function_appendix, mode="eval"
                ).body
                std_objective = ast.BinOp(
                    left=std_objective, op=ast.Add(), right=appendix_ast
                )
            node.body[i:] = [
                ast.Assign(
                    targets=[ast.Name(id="obj_std", ctx=ast.Store())],
                    value=std_objective,
                ),
                # flex objective variable
                ast.Assign(
                    targets=[ast.Name(id="obj_flex", ctx=ast.Store())],
                    value=ast.parse(self.mpc_data.flex_cost_function, mode="eval").body,
                ),
                # overwrite return statement with custom function
                ast.Return(value=ast.parse(SHADOW_MPC_COST_FUNCTION).body[0].value),
            ]
            break

    def modify_setup_system_baseline(self, node: ast.FunctionDef):
        """Modify the setup_system method of the baseline mpc model class.

        The first top-level return statement and everything after it is replaced
        by two statements:

        - ``obj_std`` = the original objective;
        - a return of the cost built by ``return_baseline_cost_function`` from
          ``mpc_data.power_variable`` and ``mpc_data.comfort_variable``.

        Args:
            node: The function definition node of setup_system.

        """
        for i, stmt in enumerate(node.body):
            if not isinstance(stmt, ast.Return):
                continue
            baseline_cost = return_baseline_cost_function(
                power_variable=self.mpc_data.power_variable,
                comfort_variable=self.mpc_data.comfort_variable,
            )
            node.body[i:] = [
                ast.Assign(
                    targets=[ast.Name(id="obj_std", ctx=ast.Store())],
                    value=stmt.value,
                ),
                ast.Return(value=ast.parse(baseline_cost).body[0].value),
            ]
            break

553 

554 

def add_import_to_tree(name: str, alias: str, tree: ast.Module) -> ast.Module:
    """Add the statement ``import <name> as <alias>`` to the module.

    The import is only added if the module does not already contain exactly this
    import at top level. It is placed after the module docstring and any
    ``from __future__`` imports, because a ``__future__`` import placed after
    another statement is a SyntaxError.

    Args:
        name: Name of the module to be imported.
        alias: Alias under which the module is imported.
        tree: The module AST to update; it is modified in place.

    Returns:
        The same tree, updated with the import statement if it was missing.

    """
    already_imported = any(
        isinstance(node, ast.Import)
        and any(
            imported.name == name and imported.asname == alias
            for imported in node.names
        )
        for node in tree.body
    )
    if already_imported:
        return tree

    body = tree.body
    insert_at = 0
    # keep the module docstring as the first statement
    if (
        body
        and isinstance(body[0], ast.Expr)
        and isinstance(body[0].value, ast.Constant)
        and isinstance(body[0].value.value, str)
    ):
        insert_at = 1
    # `from __future__` imports must precede every other statement
    while (
        insert_at < len(body)
        and isinstance(body[insert_at], ast.ImportFrom)
        and body[insert_at].module == "__future__"
    ):
        insert_at += 1
    body.insert(insert_at, ast.Import(names=[ast.alias(name=name, asname=alias)]))
    return tree

583 

584 

def remove_all_imports_from_tree(tree: ast.Module) -> ast.Module:
    """Drop every top-level ``import`` / ``from ... import`` statement.

    Imports nested inside functions or classes are left untouched. The module's
    ``body`` attribute is replaced by a new list.

    Args:
        tree: The module AST to strip; it is modified in place.

    Returns:
        The same tree, without top-level import statements.

    """
    import_node_types = (ast.Import, ast.ImportFrom)
    kept_statements = []
    for statement in tree.body:
        if isinstance(statement, import_node_types):
            continue
        kept_statements.append(statement)
    tree.body = kept_statements
    return tree