Coverage for agentlib_flexquant/utils/parsing.py: 85%

151 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-08-01 15:10 +0000

1import ast 

2from typing import Union, List, Optional 

3from agentlib_flexquant.data_structures.mpcs import ( 

4 BaseMPCData, 

5 PFMPCData, 

6 NFMPCData, 

7 BaselineMPCData, 

8) 

9from agentlib_flexquant.data_structures.globals import ( 

10 SHADOW_MPC_COST_FUNCTION, 

11 return_baseline_cost_function, 

12 full_trajectory_prefix, 

13 full_trajectory_suffix, 

14 PROFILE_DEVIATION_WEIGHT, 

15 MARKET_TIME, 

16 PREP_TIME, 

17 FLEX_EVENT_DURATION 

18) 

19from agentlib_mpc.data_structures.mpc_datamodels import MPCVariable 

20from string import Template 

21 

# Class names that are written into the generated variable declarations
CASADI_INPUT, CASADI_PARAMETER, CASADI_OUTPUT = (
    "CasadiInput",
    "CasadiParameter",
    "CasadiOutput",
)

# Source-code templates for one variable declaration each. ``$value`` is
# inserted without quotes, so it is rendered as a Python expression, whereas
# all other placeholders are wrapped in single quotes and become string
# literals in the generated code.
INPUT_TEMPLATE = Template(
    "$class_name("
    "name='$name', "
    "value=$value, "
    "unit='$unit', "
    "description='$description'"
    ")"
)
PARAMETER_TEMPLATE = Template(
    "$class_name("
    "name='$name', "
    "value=$value, "
    "unit='$unit', "
    "description='$description'"
    ")"
)
OUTPUT_TEMPLATE = Template(
    "$class_name("
    "name='$name', "
    "unit='$unit', "
    "type='$type', "
    "value=$value, "
    "description='$description'"
    ")"
)

37 

38 

def create_ast_element(template_string):
    """Parse a source snippet and return the value node of its first statement.

    Args:
        template_string (str): Python source code. Its first statement must
            have a ``value`` attribute (e.g. an expression statement or an
            assignment).

    Returns:
        ast.AST: The ``value`` node of the first parsed statement.

    """
    parsed_module = ast.parse(template_string)
    first_statement = parsed_module.body[0]
    return first_statement.value

41 

42 

def add_input(name, value, unit, description, type):
    """Build the AST node of a ``CasadiInput(...)`` declaration.

    Args:
        name: Name of the input; rendered as a string literal.
        value: Default value; rendered verbatim as a Python expression.
        unit: Unit of the input; rendered as a string literal.
        description: Description of the input; rendered as a string literal.
        type: Handed to the template substitution. INPUT_TEMPLATE contains no
            ``$type`` placeholder, so this value does not appear in the
            generated call.

    Returns:
        The expression node produced by parsing the filled-in template.

    """
    substitutions = {
        "class_name": CASADI_INPUT,
        "name": name,
        "value": value,
        "unit": unit,
        "description": description,
        "type": type,
    }
    declaration_source = INPUT_TEMPLATE.substitute(substitutions)
    return create_ast_element(declaration_source)

54 

55 

def add_parameter(name, value, unit, description):
    """Build the AST node of a ``CasadiParameter(...)`` declaration.

    Args:
        name: Name of the parameter; rendered as a string literal.
        value: Default value; rendered verbatim as a Python expression.
        unit: Unit of the parameter; rendered as a string literal.
        description: Description of the parameter; rendered as a string literal.

    Returns:
        The expression node produced by parsing the filled-in template.

    """
    substitutions = {
        "class_name": CASADI_PARAMETER,
        "name": name,
        "value": value,
        "unit": unit,
        "description": description,
    }
    declaration_source = PARAMETER_TEMPLATE.substitute(substitutions)
    return create_ast_element(declaration_source)

66 

67 

def add_output(name, unit, type, value, description):
    """Build the AST node of a ``CasadiOutput(...)`` declaration.

    Args:
        name: Name of the output; rendered as a string literal.
        unit: Unit of the output; rendered as a string literal.
        type: Type of the output; rendered as a string literal.
        value: Default value; rendered verbatim as a Python expression.
        description: Description of the output; rendered as a string literal.

    Returns:
        The expression node produced by parsing the filled-in template.

    """
    substitutions = {
        "class_name": CASADI_OUTPUT,
        "name": name,
        "unit": unit,
        "type": type,
        "value": value,
        "description": description,
    }
    declaration_source = OUTPUT_TEMPLATE.substitute(substitutions)
    return create_ast_element(declaration_source)

79 

80 

class SetupSystemModifier(ast.NodeTransformer):
    """A custom AST transformer for modifying the MPC model file.

    The transformer visits the module, looks for the class deriving from
    ``CasadiModelConfig`` and the class deriving from ``CasadiModel``, renames
    them according to ``mpc_data.class_name`` and extends their variable lists
    and the ``setup_system`` method. Which modifications are applied depends on
    the type of ``mpc_data``: ``BaselineMPCData`` selects the baseline
    variants, ``PFMPCData``/``NFMPCData`` select the shadow variants.

    Attributes:
        mpc_data (BaseMPCData): Data of the MPC for which the file is generated.
        controls (List[MPCVariable]): Controls whose full trajectories are
            added to the model.
        binary_controls (Optional[List[MPCVariable]]): Binary controls that
            are treated like ``controls``; may be None or empty.
        config_obj (Optional[ast.ClassDef]): The config class node found
            during the visit, None until then.
        model_obj (Optional[ast.ClassDef]): The model class node found during
            the visit, None until then.

    """

    def __init__(
        self,
        mpc_data: BaseMPCData,
        controls: List[MPCVariable],
        binary_controls: Optional[List[MPCVariable]],
    ):
        self.mpc_data = mpc_data
        self.controls = controls
        self.binary_controls = binary_controls
        # class nodes of the config and the model, set while visiting
        self.config_obj: Optional[ast.ClassDef] = None
        self.model_obj: Optional[ast.ClassDef] = None
        # select the modification routines based on the mpc type
        if isinstance(mpc_data, (PFMPCData, NFMPCData)):
            self.modify_config_class = self.modify_config_class_shadow
            self.modify_setup_system = self.modify_setup_system_shadow
        if isinstance(mpc_data, BaselineMPCData):
            self.modify_config_class = self.modify_config_class_baseline
            self.modify_setup_system = self.modify_setup_system_baseline

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _field_name(statement):
        """Return the name of an annotated class field, else None.

        Only statements of the form ``name: annotation [= value]`` yield a
        name. Docstrings, methods and any other statement yield None, so
        callers can skip them instead of failing on a missing ``target``.
        """
        if isinstance(statement, ast.AnnAssign) and isinstance(
            statement.target, ast.Name
        ):
            return statement.target.id
        return None

    def _value_list(self, field):
        """Return the list literal of a config field that should be extended.

        Raises:
            ValueError: If the value of the field contains no list literal
                that ``get_leftmost_list`` can reach.
        """
        value_list = self.get_leftmost_list(field.value)
        if value_list is None:
            raise ValueError(
                f"Cannot extend config field '{self._field_name(field)}': "
                f"its value is not a list literal or a concatenation "
                f"starting with one."
            )
        return value_list

    def _all_controls(self):
        """Return the controls followed by the binary controls (if any)."""
        return list(self.controls) + list(self.binary_controls or [])

    @staticmethod
    def _full_trajectory_name(control):
        """Return the name of the full trajectory variable of a control."""
        return f"{full_trajectory_prefix}{control.name}{full_trajectory_suffix}"

    def _bound_assignment(self, control, bound, limit):
        """Create ``<control>_<bound> = ca.if_else(...)`` as a statement node.

        Args:
            control: The control the bound belongs to.
            bound (str): Suffix of the assigned variable ("upper"/"lower").
            limit (str): Attribute of the control used as the fallback
                ("ub"/"lb").
        """
        trajectory = self._full_trajectory_name(control)
        return ast.parse(
            f"{control.name}_{bound} = ca.if_else(self.time < self.market_time.sym, "
            f"self.{trajectory}.sym, "
            f"self.{control.name}.{limit})"
        ).body[0]

    @staticmethod
    def _find_constraints_list(node):
        """Return the list literal assigned to ``<obj>.constraints`` or None.

        Only top-level statements of the function are inspected. The first
        assignment whose value is a list literal is returned.
        """
        for item in node.body:
            if (
                isinstance(item, ast.Assign)
                and isinstance(item.targets[0], ast.Attribute)
                and item.targets[0].attr == "constraints"
                and isinstance(item.value, ast.List)
            ):
                return item.value
        return None

    @staticmethod
    def _find_return_index(node):
        """Return the index of the first top-level return statement or None."""
        for index, stmt in enumerate(node.body):
            if isinstance(stmt, ast.Return):
                return index
        return None

    # ------------------------------------------------------------------
    # visitor methods
    # ------------------------------------------------------------------
    def visit_Module(self, module):
        """Visit a module definition in the AST.

        Appends or deletes the import statements at the top of the module.

        Args:
            module (ast.Module): The module definition node in the AST.

        Returns:
            ast.Module: The possibly modified module definition node.

        """
        # append imports for baseline
        if isinstance(self.mpc_data, BaselineMPCData):
            module = add_import_to_tree(name="pandas", alias="pd", tree=module)
            module = add_import_to_tree(name="casadi", alias="ca", tree=module)
        # delete imports for shadow MPCs
        if isinstance(self.mpc_data, (NFMPCData, PFMPCData)):
            module = remove_all_imports_from_tree(module)
        # trigger the next visit method (ClassDef)
        self.generic_visit(module)
        return module

    def visit_ClassDef(self, node):
        """Visit a class definition in the AST.

        Classes with a base named ``CasadiModelConfig`` are passed to the
        selected config modification and renamed to
        ``<mpc_data.class_name>Config``. Classes with a base named
        ``CasadiModel`` get their ``setup_system`` method modified, their
        ``config`` annotation pointed at the renamed config class and are
        renamed to ``mpc_data.class_name``. Other classes are left untouched.

        Args:
            node (ast.ClassDef): The class definition node in the AST.

        Returns:
            ast.ClassDef: The possibly modified class definition node.

        """
        for base in node.bases:
            if not isinstance(base, ast.Name):
                continue
            if base.id == "CasadiModelConfig":
                # get ast object and trigger modification
                self.config_obj = node
                self.modify_config_class(node)
                # change class name
                node.name = self.mpc_data.class_name + "Config"
            elif base.id == "CasadiModel":
                # get ast object and trigger modification
                self.model_obj = node
                for item in node.body:
                    if (
                        isinstance(item, ast.FunctionDef)
                        and item.name == "setup_system"
                    ):
                        self.modify_setup_system(item)
                    # let the config annotation point at the renamed config
                    if self._field_name(item) == "config":
                        item.annotation = (
                            ast.parse(self.mpc_data.class_name + "Config")
                            .body[0]
                            .value
                        )
                # change class name
                node.name = self.mpc_data.class_name

        return node

    def get_leftmost_list(self, node):
        """Recursively traverse binary operations to get the leftmost list.

        Args:
            node: An AST node (could be a BinOp, a Tuple or directly a List).

        Returns:
            The leftmost List node found, or None if there is none.

        """
        if isinstance(node, ast.List):
            return node
        if isinstance(node, ast.BinOp):
            # binary operation, e.g. a concatenation: follow the left operand
            return self.get_leftmost_list(node.left)
        if isinstance(node, ast.Tuple) and node.elts:
            # tuple with elements: check the first element
            return self.get_leftmost_list(node.elts[0])
        # no list could be found
        return None

    # ------------------------------------------------------------------
    # config class modifications
    # ------------------------------------------------------------------
    def modify_config_class_shadow(self, node):
        """Modify the config class of the shadow mpc.

        Extends the ``inputs`` field with one full trajectory input per
        control (binary controls included) and the ``in_provision`` flag, and
        the ``parameters`` field with the time parameters and the weights of
        ``mpc_data``. Class members that are not annotated fields are skipped.

        Args:
            node (ast.ClassDef): The class definition node of the config.

        Raises:
            ValueError: If a field to extend holds no list literal.

        """
        for field in node.body:
            field_name = self._field_name(field)
            # add the full control trajectory inputs and the provision flag
            if field_name == "inputs":
                inputs = self._value_list(field)
                for control in self._all_controls():
                    inputs.elts.append(
                        add_input(
                            self._full_trajectory_name(control),
                            "pd.Series([0])",
                            "W",
                            # description first, then type, as declared by
                            # the signature of add_input
                            "full control output",
                            "pd.Series",
                        )
                    )
                inputs.elts.append(
                    add_input("in_provision", False, "-", "provision flag", "bool")
                )
            # add the flex variables and the weights
            elif field_name == "parameters":
                parameters = self._value_list(field)
                for param_name in [PREP_TIME, FLEX_EVENT_DURATION, MARKET_TIME]:
                    parameters.elts.append(
                        add_parameter(param_name, 0, "s", "time to switch objective")
                    )
                for weight in self.mpc_data.weights:
                    parameters.elts.append(
                        add_parameter(
                            weight.name,
                            weight.value,
                            "-",
                            "Weight for P in objective function",
                        )
                    )

    def modify_config_class_baseline(self, node):
        """Modify the config class of the baseline mpc.

        Extends the ``outputs`` field with one full trajectory output per
        control (binary controls included), the ``inputs`` field with the
        flexibility inputs and the ``parameters`` field with the entries of
        ``mpc_data.config_parameters_appendix``. If a field is defined as a
        concatenation of lists, the leftmost list literal is extended. Class
        members that are not annotated fields are skipped.

        Args:
            node (ast.ClassDef): The class definition node of the config.

        Raises:
            ValueError: If a field to extend holds no list literal.

        """
        for field in node.body:
            field_name = self._field_name(field)
            # add the full control trajectories to the baseline config class
            if field_name == "outputs":
                outputs = self._value_list(field)
                for control in self._all_controls():
                    outputs.elts.append(
                        add_output(
                            self._full_trajectory_name(control),
                            "W",
                            "pd.Series",
                            "pd.Series([0])",
                            "full control output",
                        )
                    )
            # add the flexibility inputs
            elif field_name == "inputs":
                inputs = self._value_list(field)
                flexibility_inputs = [
                    (
                        "_P_external",
                        0,
                        "W",
                        "External power profile to be provided",
                        "pd.Series",
                    ),
                    (
                        "in_provision",
                        False,
                        "-",
                        "Flag signaling if the flexibility is in provision",
                        "bool",
                    ),
                    (
                        "rel_start",
                        0,
                        "s",
                        "relative start time of the flexibility event",
                        "int",
                    ),
                    (
                        "rel_end",
                        0,
                        "s",
                        "relative end time of the flexibility event",
                        "int",
                    ),
                ]
                for input_args in flexibility_inputs:
                    inputs.elts.append(add_input(*input_args))
            # add the flex variables and the weights
            elif field_name == "parameters":
                parameters = self._value_list(field)
                for parameter in self.mpc_data.config_parameters_appendix:
                    parameters.elts.append(
                        add_parameter(parameter.name, 0, "-", parameter.description)
                    )

    # ------------------------------------------------------------------
    # setup_system modifications
    # ------------------------------------------------------------------
    def modify_setup_system_shadow(self, node):
        """Modify the setup_system method of the shadow mpc model class.

        For every control (binary controls included) two bound variables are
        inserted at the beginning of the function and a matching
        ``(lower, control, upper)`` tuple is appended to the constraints list.
        Afterwards the first top-level return statement and everything
        following it is replaced by the assignments of ``obj_std`` (the
        original return value) and ``obj_flex`` and by a return of
        SHADOW_MPC_COST_FUNCTION.

        Args:
            node (ast.FunctionDef): The function definition node of setup_system.

        """
        # constrain the control trajectories for t < market_time; the list is
        # looked up before any statement is inserted so that the function
        # body is not modified while it is being searched
        constraints = self._find_constraints_list(node)
        if constraints is not None:
            for control in self._all_controls():
                # insert control boundaries at beginning of function
                node.body.insert(0, self._bound_assignment(control, "upper", "ub"))
                node.body.insert(0, self._bound_assignment(control, "lower", "lb"))
                # append to constraints
                constraints.elts.append(
                    ast.parse(
                        f"({control.name}_lower, self.{control.name}, {control.name}_upper)"
                    )
                    .body[0]
                    .value
                )

        return_index = self._find_return_index(node)
        if return_index is None:
            return
        # store current return value
        original_return = node.body[return_index].value
        # replace the return statement and everything after it
        node.body[return_index:] = [
            # create new standard objective variable
            ast.Assign(
                targets=[ast.Name(id="obj_std", ctx=ast.Store())],
                value=original_return,
            ),
            # create flex objective variable
            ast.Assign(
                targets=[ast.Name(id="obj_flex", ctx=ast.Store())],
                value=ast.parse(self.mpc_data.flex_cost_function, mode="eval").body,
            ),
            # overwrite return statement with custom function
            ast.Return(value=ast.parse(SHADOW_MPC_COST_FUNCTION).body[0].value),
        ]

    def modify_setup_system_baseline(self, node):
        """Modify the setup_system method of the baseline mpc model class.

        The first top-level return statement and everything following it is
        replaced by: one assignment per control (binary controls included)
        that sets the full trajectory variable, the assignment of ``obj_std``
        (the original return value) and a return of the expression created by
        ``return_baseline_cost_function``.

        Args:
            node (ast.FunctionDef): The function definition node of setup_system.

        """
        self_name = "self"
        # set the control trajectories with the respective variables
        full_traj_list = [
            ast.Assign(
                targets=[
                    ast.Attribute(
                        value=ast.Name(id=self_name, ctx=ast.Load()),
                        # the attribute name carries the ".alg" suffix so the
                        # node reads as an assignment to self.<name>.alg when
                        # the tree is converted back to source
                        attr=f"{self._full_trajectory_name(control)}.alg",
                        ctx=ast.Store(),
                    )
                ],
                value=ast.Attribute(
                    value=ast.Name(id=self_name, ctx=ast.Load()),
                    attr=control.name,
                    ctx=ast.Load(),
                ),
            )
            for control in self._all_controls()
        ]

        return_index = self._find_return_index(node)
        if return_index is None:
            return
        # store current return value
        original_return = node.body[return_index].value
        new_body = [
            # create new standard objective variable
            ast.Assign(
                targets=[ast.Name(id="obj_std", ctx=ast.Store())],
                value=original_return,
            ),
            # overwrite return statement with custom function
            ast.Return(
                value=ast.parse(
                    return_baseline_cost_function(
                        power_variable=self.mpc_data.power_variable,
                        comfort_variable=self.mpc_data.comfort_variable,
                    )
                )
                .body[0]
                .value
            ),
        ]
        # replace the return statement and everything after it
        node.body[return_index:] = full_traj_list + new_body

503 

504 

def add_import_to_tree(name: str, alias: str, tree: ast.Module):
    """Prepend ``import <name> as <alias>`` to a module unless already present.

    Only top-level ``import`` statements of the module are inspected. The
    import counts as present if one of them binds ``name`` to exactly
    ``alias``. An import of ``name`` under a different alias or without an
    alias does not count, because the requested alias would not be available.

    Args:
        name: Name of the module to import.
        alias: Alias under which the module should be available.
        tree: The module node to modify in place.

    Returns:
        ast.Module: The same module node that was passed in.

    """
    for node in tree.body:
        if not isinstance(node, ast.Import):
            continue
        for imported in node.names:
            if imported.name == name and imported.asname == alias:
                # requested import exists already, nothing to add
                return tree
    tree.body.insert(0, ast.Import(names=[ast.alias(name=name, asname=alias)]))
    return tree

520 

521 

def remove_all_imports_from_tree(tree: ast.Module):
    """Drop every top-level import statement from a module.

    Both ``import x`` and ``from x import y`` statements are removed. Imports
    nested inside functions or classes are not touched, since only the direct
    children of the module are inspected.

    Args:
        tree: The module node to modify in place.

    Returns:
        ast.Module: The same module node with a new body list.

    """
    import_node_types = (ast.Import, ast.ImportFrom)
    remaining_statements = []
    for statement in tree.body:
        if isinstance(statement, import_node_types):
            continue
        remaining_statements.append(statement)
    # replace the body with the filtered statements
    tree.body = remaining_statements
    return tree

529 return tree