Coverage for agentlib_flexquant/utils/parsing.py: 85%

151 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-08-15 15:25 +0000

1import ast 

2from typing import Union, List, Optional 

3from string import Template 

4from agentlib_mpc.data_structures.mpc_datamodels import MPCVariable 

5from agentlib_flexquant.data_structures.mpcs import ( 

6 BaseMPCData, 

7 PFMPCData, 

8 NFMPCData, 

9 BaselineMPCData, 

10) 

11from agentlib_flexquant.data_structures.globals import ( 

12 SHADOW_MPC_COST_FUNCTION, 

13 return_baseline_cost_function, 

14 full_trajectory_prefix, 

15 full_trajectory_suffix, 

16 MARKET_TIME, 

17 PREP_TIME, 

18 FLEX_EVENT_DURATION 

19) 

20 

21 

# Constants: class names used in the generated constructor calls
CASADI_INPUT = "CasadiInput"
CASADI_PARAMETER = "CasadiParameter"
CASADI_OUTPUT = "CasadiOutput"

# String templates for the constructor calls inserted into the model config.
# $name, $unit, $description and $type end up inside single quotes (string
# literals), whereas $value is inserted verbatim and therefore has to be a valid
# Python expression itself (e.g. "pd.Series([0])", 0 or False).
_VALUE_DEFINITION = (
    "$class_name(name='$name', value=$value, unit='$unit', description='$description')"
)
INPUT_TEMPLATE = Template(_VALUE_DEFINITION)
PARAMETER_TEMPLATE = Template(_VALUE_DEFINITION)
OUTPUT_TEMPLATE = Template(
    "$class_name(name='$name', unit='$unit', type='$type', value=$value, "
    "description='$description')"
)

36) 

37 

38 

def create_ast_element(template_string: str) -> ast.Call:
    """Parse a snippet of Python source and return the node of its first statement's value.

    The snippet is parsed as a module; the ``value`` of its first statement is
    returned. For the constructor-call templates of this module that is the
    ``ast.Call`` node of the call expression.

    Args:
        template_string: A Python code string, typically a filled-in template.

    Returns:
        ast.Call: The AST node of the parsed expression.

    """
    parsed_module = ast.parse(template_string)
    first_statement = parsed_module.body[0]
    return first_statement.value

50 

51 

def add_input(name: str, value: Union[bool, str, int], unit: str, description: str, type: str) -> ast.Call:
    """Create an AST node for an input definition.

    The textual fields are escaped before being placed between the single quotes
    of ``INPUT_TEMPLATE``, so descriptions containing quotes, backslashes or line
    breaks still yield valid code.

    Args:
        name: The name of the input.
        value: The default value for the input. It is inserted verbatim, so a
            string must be a valid Python expression (e.g. "pd.Series([0])").
        unit: The unit associated with the input value.
        description: A human-readable description of the input.
        type: The data type of the input (e.g., "float", "int", "pd.Series").
            NOTE: ``INPUT_TEMPLATE`` has no ``$type`` placeholder, so this value
            is currently not written into the generated call.

    Returns:
        ast.Call: An abstract syntax tree (AST) call node representing the input definition.

    """

    def _quote_safe(text) -> str:
        # escape backslashes, control/non-ASCII characters and single quotes so the
        # text can sit between the single quotes used by the template
        return str(text).encode("unicode_escape").decode("ascii").replace("'", "\\'")

    return create_ast_element(
        INPUT_TEMPLATE.substitute(
            class_name=CASADI_INPUT,
            name=_quote_safe(name),
            value=value,
            unit=_quote_safe(unit),
            description=_quote_safe(description),
            type=_quote_safe(type),
        )
    )

76 

77 

def add_parameter(name: str, value: Union[int, float], unit: str, description: str) -> ast.Call:
    """Create an AST node for a parameter definition.

    The textual fields are escaped before being placed between the single quotes
    of ``PARAMETER_TEMPLATE``, so e.g. a description containing an apostrophe
    still yields valid code.

    Args:
        name: The name of the parameter.
        value: The value of the parameter. Can be an integer or float; it is
            inserted verbatim into the generated call.
        unit: The unit associated with the parameter value.
        description: A human-readable description of the parameter.

    Returns:
        ast.Call: An abstract syntax tree (AST) call node representing the parameter definition.

    """

    def _quote_safe(text) -> str:
        # escape backslashes, control/non-ASCII characters and single quotes so the
        # text can sit between the single quotes used by the template
        return str(text).encode("unicode_escape").decode("ascii").replace("'", "\\'")

    return create_ast_element(
        PARAMETER_TEMPLATE.substitute(
            class_name=CASADI_PARAMETER,
            name=_quote_safe(name),
            value=value,
            unit=_quote_safe(unit),
            description=_quote_safe(description),
        )
    )

100 

101 

def add_output(name: str, unit: str, type: str, value: Union[str, float], description: str) -> ast.Call:
    """Create an AST node for an output definition.

    The textual fields are escaped before being placed between the single quotes
    of ``OUTPUT_TEMPLATE``, so descriptions containing quotes, backslashes or line
    breaks still yield valid code.

    Args:
        name: The name of the output.
        unit: The unit associated with the output value.
        type: The data type of the output (e.g., "float", "pd.Series").
        value: The value of the output. It is inserted verbatim, so a string must
            be a valid Python expression (e.g. "pd.Series([0])").
        description: A human-readable description of the output.

    Returns:
        ast.Call: An abstract syntax tree (AST) call node representing the output definition.

    """

    def _quote_safe(text) -> str:
        # escape backslashes, control/non-ASCII characters and single quotes so the
        # text can sit between the single quotes used by the template
        return str(text).encode("unicode_escape").decode("ascii").replace("'", "\\'")

    return create_ast_element(
        OUTPUT_TEMPLATE.substitute(
            class_name=CASADI_OUTPUT,
            name=_quote_safe(name),
            unit=_quote_safe(unit),
            type=_quote_safe(type),
            value=value,
            description=_quote_safe(description),
        )
    )

126 

127 

class SetupSystemModifier(ast.NodeTransformer):
    """A custom AST transformer for modifying the MPC model file.

    This class traverses the AST of the input file, identifies the model config
    class (a subclass of ``CasadiModelConfig``) and the model class (a subclass of
    ``CasadiModel``), renames them and performs the modifications required for the
    selected MPC type:

    * baseline MPC (``BaselineMPCData``): the config gets full control trajectory
      outputs, flexibility inputs and extra parameters; ``setup_system`` assigns
      the full trajectories and returns a new baseline cost function.
    * shadow MPCs (``PFMPCData`` / ``NFMPCData``): the config gets full control
      trajectory inputs, the flexibility timing parameters and the weights;
      ``setup_system`` bounds every control by its received full trajectory while
      ``self.time < self.market_time`` and returns a new cost function.

    """

    def __init__(
        self,
        mpc_data: BaseMPCData,
        controls: List[MPCVariable],
        binary_controls: Optional[List[MPCVariable]],
    ):
        """Store the inputs and select the modification routines for the mpc type.

        Args:
            mpc_data: Data of the MPC to generate. ``PFMPCData``/``NFMPCData``
                select the shadow modifications, ``BaselineMPCData`` the baseline
                modifications.
            controls: Continuous control variables of the original MPC.
            binary_controls: Binary control variables of the original MPC, if any.

        """
        self.mpc_data = mpc_data
        self.controls = controls
        self.binary_controls = binary_controls
        # ast nodes of the config class and the model class, set while visiting
        self.config_obj: Optional[ast.ClassDef] = None
        self.model_obj: Optional[ast.ClassDef] = None
        # select modification of config class and setup_system based on mpc type.
        # NOTE(review): for any other mpc_data type these attributes stay unset and
        # visiting a model class raises AttributeError.
        if isinstance(mpc_data, (PFMPCData, NFMPCData)):
            self.modify_config_class = self.modify_config_class_shadow
            self.modify_setup_system = self.modify_setup_system_shadow
        if isinstance(mpc_data, BaselineMPCData):
            self.modify_config_class = self.modify_config_class_baseline
            self.modify_setup_system = self.modify_setup_system_baseline

    def visit_Module(self, module: ast.Module) -> ast.Module:
        """Visit a module definition in the AST.

        Append (baseline) or delete (shadow MPCs) the import statements at the top
        of the module, then visit the module's children.

        Args:
            module: The module definition node in the AST.

        Returns:
            The possibly modified module definition node.

        """
        # append imports for baseline
        if isinstance(self.mpc_data, BaselineMPCData):
            module = add_import_to_tree(name="pandas", alias="pd", tree=module)
            module = add_import_to_tree(name="casadi", alias="ca", tree=module)
        # delete imports for shadow MPCs
        if isinstance(self.mpc_data, (NFMPCData, PFMPCData)):
            module = remove_all_imports_from_tree(module)
        # trigger the next visit method (ClassDef)
        self.generic_visit(module)
        return module

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        """Visit a class definition in the AST.

        A class deriving from ``CasadiModelConfig`` is modified as model config and
        renamed to ``<class_name>Config``. A class deriving from ``CasadiModel`` has
        its ``setup_system`` method modified and its ``config`` annotation pointed
        to the renamed config class, and is renamed to ``<class_name>``
        (``class_name`` taken from the mpc data). Nested classes are not visited.

        Args:
            node: The class definition node in the AST.

        Returns:
            The possibly modified class definition node.

        """
        for base in node.bases:
            if not isinstance(base, ast.Name):
                continue
            if base.id == "CasadiModelConfig":
                # get ast object and trigger modification
                self.config_obj = node
                self.modify_config_class(node)
                # change class name
                node.name = self.mpc_data.class_name + "Config"
            elif base.id == "CasadiModel":
                # get ast object and trigger modification
                self.model_obj = node
                for item in node.body:
                    if isinstance(item, ast.FunctionDef) and item.name == "setup_system":
                        self.modify_setup_system(item)
                    # point the config annotation to the renamed config class
                    if (
                        isinstance(item, ast.AnnAssign)
                        and isinstance(item.target, ast.Name)
                        and item.target.id == "config"
                    ):
                        item.annotation = (
                            ast.parse(self.mpc_data.class_name + "Config").body[0].value
                        )
                # change class name
                node.name = self.mpc_data.class_name
        return node

    def get_leftmost_list(self, node: Union[ast.Tuple, ast.BinOp, ast.List]) -> Optional[ast.List]:
        """Recursively traverse binary operations to get the leftmost list.

        Args:
            node: An AST node (a List, a BinOp such as a list concatenation, or a
                Tuple whose first element is searched).

        Returns:
            The leftmost List node found, or None if there is none.

        """
        if isinstance(node, ast.List):
            return node
        if isinstance(node, ast.BinOp):
            # for a binary operation, recurse to the left operand
            return self.get_leftmost_list(node.left)
        if isinstance(node, ast.Tuple) and node.elts:
            # for a non-empty tuple, search its first element
            return self.get_leftmost_list(node.elts[0])
        # no list could be found
        return None

    def _all_controls(self) -> List[MPCVariable]:
        """Return the continuous controls followed by the binary controls (if any)."""
        return [*self.controls, *(self.binary_controls or [])]

    @staticmethod
    def _config_fields(node: ast.ClassDef):
        """Yield ``(field_name, AnnAssign node)`` for each annotated field of a class.

        Other class body items (docstrings, methods, plain assignments) are skipped.
        """
        for item in node.body:
            if isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name):
                yield item.target.id, item

    def _get_field_list(self, field: ast.AnnAssign) -> ast.List:
        """Return the list literal to which new elements of a config field are appended.

        Raises:
            ValueError: If the field value is neither a list literal, a
                concatenation starting with a list literal nor such a tuple.

        """
        value_list = self.get_leftmost_list(field.value)
        if value_list is None:
            raise ValueError(
                f"Cannot modify config field '{field.target.id}': expected a list "
                f"literal or a concatenation of lists, got "
                f"{type(field.value).__name__}."
            )
        return value_list

    def modify_config_class_shadow(self, node: ast.ClassDef):
        """Modify the config class of the shadow mpc.

        Adds a full trajectory input per (binary) control plus ``in_provision`` to
        ``inputs``, and the flexibility timing parameters plus the weights to
        ``parameters``.

        Args:
            node: The class definition node of the config.

        Raises:
            ValueError: If ``inputs`` or ``parameters`` is not list-like.

        """
        for field_name, field in self._config_fields(node):
            # add the full control trajectory inputs and the provision flag
            if field_name == "inputs":
                value_list = self._get_field_list(field)
                for control in self._all_controls():
                    value_list.elts.append(
                        add_input(
                            name=f"{full_trajectory_prefix}{control.name}"
                            f"{full_trajectory_suffix}",
                            value="pd.Series([0])",
                            unit="W",
                            description="full control output",
                            type="pd.Series",
                        )
                    )
                value_list.elts.append(
                    add_input(
                        name="in_provision",
                        value=False,
                        unit="-",
                        description="provision flag",
                        type="bool",
                    )
                )
            # add the flex variables and the weights
            elif field_name == "parameters":
                value_list = self._get_field_list(field)
                for param_name in [PREP_TIME, FLEX_EVENT_DURATION, MARKET_TIME]:
                    value_list.elts.append(
                        add_parameter(param_name, 0, "s", "time to switch objective")
                    )
                for weight in self.mpc_data.weights:
                    value_list.elts.append(
                        add_parameter(
                            weight.name,
                            weight.value,
                            "-",
                            "Weight for P in objective function",
                        )
                    )

    def modify_config_class_baseline(self, node: ast.ClassDef):
        """Modify the config class of the baseline mpc.

        Adds a full trajectory output per (binary) control to ``outputs``, the
        flexibility inputs to ``inputs`` and the parameters of
        ``config_parameters_appendix`` to ``parameters``.

        Args:
            node: The class definition node of the config.

        Raises:
            ValueError: If ``outputs``, ``inputs`` or ``parameters`` is not list-like.

        """
        for field_name, field in self._config_fields(node):
            # add the full control trajectories to the baseline config class
            if field_name == "outputs":
                value_list = self._get_field_list(field)
                for control in self._all_controls():
                    value_list.elts.append(
                        add_output(
                            name=f"{full_trajectory_prefix}{control.name}"
                            f"{full_trajectory_suffix}",
                            unit="W",
                            type="pd.Series",
                            value="pd.Series([0])",
                            description="full control output",
                        )
                    )
            # add the flexibility inputs
            elif field_name == "inputs":
                value_list = self._get_field_list(field)
                value_list.elts.extend(
                    [
                        add_input(
                            name="_P_external",
                            value=0,
                            unit="W",
                            description="External power profile to be provided",
                            type="pd.Series",
                        ),
                        add_input(
                            name="in_provision",
                            value=False,
                            unit="-",
                            description="Flag signaling if the flexibility is in provision",
                            type="bool",
                        ),
                        add_input(
                            name="rel_start",
                            value=0,
                            unit="s",
                            description="relative start time of the flexibility event",
                            type="int",
                        ),
                        add_input(
                            name="rel_end",
                            value=0,
                            unit="s",
                            description="relative end time of the flexibility event",
                            type="int",
                        ),
                    ]
                )
            # add the flex variables
            elif field_name == "parameters":
                value_list = self._get_field_list(field)
                for parameter in self.mpc_data.config_parameters_appendix:
                    value_list.elts.append(
                        add_parameter(parameter.name, 0, "-", parameter.description)
                    )

    def _find_constraints_list(self, node: ast.FunctionDef) -> Optional[ast.List]:
        """Return the list literal of the first list-like ``<obj>.constraints = ...``
        assignment in the top level of ``node``, or None if there is none."""
        for item in node.body:
            if (
                isinstance(item, ast.Assign)
                and isinstance(item.targets[0], ast.Attribute)
                and item.targets[0].attr == "constraints"
            ):
                constraints = self.get_leftmost_list(item.value)
                if constraints is not None:
                    return constraints
        return None

    def modify_setup_system_shadow(self, node: ast.FunctionDef):
        """Modify the setup_system method of the shadow mpc model class.

        For every (binary) control, ``<name>_lower``/``<name>_upper`` bounds are
        defined at the top of the method (after a docstring, if present) that equal
        the received full trajectory while ``self.time < self.market_time.sym`` and
        the variable's lb/ub otherwise; ``(lower, variable, upper)`` is appended to
        the constraints list. The first top-level return statement is replaced by
        ``obj_std``/``obj_flex`` assignments and ``SHADOW_MPC_COST_FUNCTION``.

        Args:
            node: The function definition node of setup_system.

        """
        # constrain the control trajectories for t < market_time
        constraints = self._find_constraints_list(node)
        if constraints is not None:
            # keep a docstring (if any) as first statement of the function
            insert_at = 1 if ast.get_docstring(node) is not None else 0
            for control in self._all_controls():
                full_trajectory = (
                    f"self.{full_trajectory_prefix}{control.name}"
                    f"{full_trajectory_suffix}.sym"
                )
                # insert control boundaries at the beginning of the function;
                # upper is inserted first, so lower ends up in front of it
                for bound, attribute in (("upper", "ub"), ("lower", "lb")):
                    node.body.insert(
                        insert_at,
                        ast.parse(
                            f"{control.name}_{bound} = ca.if_else("
                            f"self.time < self.market_time.sym, "
                            f"{full_trajectory}, self.{control.name}.{attribute})"
                        ).body[0],
                    )
                # append (lower, variable, upper) to the constraints
                constraints.elts.append(
                    ast.parse(
                        f"({control.name}_lower, self.{control.name}, "
                        f"{control.name}_upper)",
                        mode="eval",
                    ).body
                )
        # replace the first top-level return statement by the flex objective
        for i, stmt in enumerate(node.body):
            if isinstance(stmt, ast.Return):
                node.body[i:] = [
                    # keep the original objective as standard objective variable
                    ast.Assign(
                        targets=[ast.Name(id="obj_std", ctx=ast.Store())],
                        value=stmt.value,
                    ),
                    # create flex objective variable
                    ast.Assign(
                        targets=[ast.Name(id="obj_flex", ctx=ast.Store())],
                        value=ast.parse(
                            self.mpc_data.flex_cost_function, mode="eval"
                        ).body,
                    ),
                    # overwrite return statement with custom function
                    ast.Return(value=ast.parse(SHADOW_MPC_COST_FUNCTION).body[0].value),
                ]
                break

    def modify_setup_system_baseline(self, node: ast.FunctionDef):
        """Modify the setup_system method of the baseline mpc model class.

        The first top-level return statement is replaced by
        ``self.<full trajectory>.alg = self.<control>`` for every (binary) control,
        an ``obj_std`` assignment holding the original objective, and a return of
        the baseline cost function.

        Args:
            node: The function definition node of setup_system.

        """
        # set the control trajectories with the respective variables; the target
        # is a proper nested attribute (self.<trajectory>).alg
        full_traj_list = [
            ast.Assign(
                targets=[
                    ast.Attribute(
                        value=ast.Attribute(
                            value=ast.Name(id="self", ctx=ast.Load()),
                            attr=f"{full_trajectory_prefix}{control.name}"
                            f"{full_trajectory_suffix}",
                            ctx=ast.Load(),
                        ),
                        attr="alg",
                        ctx=ast.Store(),
                    )
                ],
                value=ast.Attribute(
                    value=ast.Name(id="self", ctx=ast.Load()),
                    attr=control.name,
                    ctx=ast.Load(),
                ),
            )
            for control in self._all_controls()
        ]
        # loop through setup_system function to find return statement
        for i, stmt in enumerate(node.body):
            if isinstance(stmt, ast.Return):
                new_body = [
                    # keep the original objective as standard objective variable
                    ast.Assign(
                        targets=[ast.Name(id="obj_std", ctx=ast.Store())],
                        value=stmt.value,
                    ),
                    # overwrite return statement with custom function
                    ast.Return(
                        value=ast.parse(
                            return_baseline_cost_function(
                                power_variable=self.mpc_data.power_variable,
                                comfort_variable=self.mpc_data.comfort_variable,
                            )
                        )
                        .body[0]
                        .value
                    ),
                ]
                node.body[i:] = full_traj_list + new_body
                break

547 

548 

def add_import_to_tree(name: str, alias: str, tree: ast.Module) -> ast.Module:
    """Add the statement ``import name as alias`` to the module.

    Nothing is added when a top-level ``import`` statement of the module already
    binds ``alias`` to ``name`` (e.g. ``import pandas as pd``). Otherwise the new
    import is inserted after the module docstring and any
    ``from __future__ import ...`` statements: future imports must come first in
    a module, and a statement placed before the docstring would turn it into a
    plain expression.

    Args:
        name: name of the module to be imported
        alias: alias of the module
        tree: the module tree the import is added to (modified in place)

    Returns:
        The tree updated with the import statement

    """
    body = tree.body
    # skip if any top-level import already provides exactly this binding
    for statement in body:
        if isinstance(statement, ast.Import) and any(
            imported.name == name and imported.asname == alias
            for imported in statement.names
        ):
            return tree

    # first position after the module docstring and the __future__ imports
    position = 0
    if (
        body
        and isinstance(body[0], ast.Expr)
        and isinstance(body[0].value, ast.Constant)
        and isinstance(body[0].value.value, str)
    ):
        position = 1
    while (
        position < len(body)
        and isinstance(body[position], ast.ImportFrom)
        and body[position].module == "__future__"
    ):
        position += 1

    body.insert(position, ast.Import(names=[ast.alias(name=name, asname=alias)]))
    return tree

577 

578 

def remove_all_imports_from_tree(tree: ast.Module) -> ast.Module:
    """Strip every top-level import statement from a module.

    Both ``import x`` and ``from x import y`` statements directly in the module
    body are dropped; imports nested inside functions or classes are kept. The
    module's ``body`` attribute is replaced by a new list.

    Args:
        tree: The module to clean up. It is modified in place.

    Returns:
        The same module object, now without top-level imports.

    """
    import_node_types = (ast.Import, ast.ImportFrom)
    kept_statements = list(
        filter(lambda statement: not isinstance(statement, import_node_types), tree.body)
    )
    tree.body = kept_statements
    return tree