Coverage for agentlib_flexquant/utils/parsing.py: 90%

136 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-10-20 14:09 +0000

1import ast 

2from string import Template 

3from typing import Optional, Union 

4 

5from agentlib_mpc.data_structures.mpc_datamodels import MPCVariable 

6 

7from agentlib_flexquant.data_structures.globals import ( 

8 FLEX_EVENT_DURATION, 

9 MARKET_TIME, 

10 PREP_TIME, 

11 SHADOW_MPC_COST_FUNCTION, 

12 full_trajectory_prefix, 

13 full_trajectory_suffix, 

14 return_baseline_cost_function, 

15) 

16from agentlib_flexquant.data_structures.mpcs import ( 

17 BaselineMPCData, 

18 BaseMPCData, 

19 NFMPCData, 

20 PFMPCData, 

21) 

22 

# Constants
CASADI_INPUT = "CasadiInput"
CASADI_PARAMETER = "CasadiParameter"
CASADI_OUTPUT = "CasadiOutput"

# String templates. They define the shape of the generated calls. The quoted
# text fields are substituted empty by the add_* functions below and then set
# as string constants on the parsed AST, so arbitrary text (quotes, backslashes,
# line breaks) cannot break the generated code. ``$value`` is inserted verbatim
# as Python source.
INPUT_TEMPLATE = Template(
    "$class_name(name='$name', value=$value, unit='$unit', type='$type', "
    "description='$description')"
)
PARAMETER_TEMPLATE = Template(
    "$class_name(name='$name', value=$value, unit='$unit', description='$description')"
)
OUTPUT_TEMPLATE = Template(
    "$class_name(name='$name', unit='$unit', type='$type', value=$value, "
    "description='$description')"
)


def create_ast_element(template_string: str) -> ast.Call:
    """Convert a template string into an AST call node.

    Args:
        template_string: A Python code template string to parse. Only the
            expression of its first statement is returned.

    Returns:
        ast.Call: An abstract syntax tree (AST) call node parsed from the template string.

    """
    return ast.parse(template_string).body[0].value


def _set_string_keywords(call: ast.Call, **fields: object) -> ast.Call:
    """Set the given keyword arguments of ``call`` to string constants.

    The templates wrap text fields in single quotes. Substituting raw text that
    contains a quote, a backslash or a line break would produce invalid (or
    silently altered) Python. The add_* functions therefore substitute those
    fields empty and replace them here with ``ast.Constant`` nodes holding
    ``str(value)``. ``str`` matches what ``Template.substitute`` would insert.

    Args:
        call: Call node produced from one of the templates; modified in place.
        **fields: Keyword-argument name -> value to store as a string constant.

    Returns:
        ast.Call: The same ``call`` node, for convenience.

    """
    for keyword in call.keywords:
        if keyword.arg in fields:
            keyword.value = ast.copy_location(
                ast.Constant(value=str(fields[keyword.arg]), kind=None),
                keyword.value,
            )
    return call


def add_input(
    name: str, value: Union[bool, str, int], unit: str, description: str, type: str
) -> ast.Call:
    """Create an AST node for an input definition.

    Args:
        name: The name of the input.
        value: The default value for the input. It is inserted verbatim as
            Python source, so ``None``, ``False`` or ``0`` become the
            corresponding literals and a string is parsed as an expression.
        unit: The unit associated with the input value.
        description: A human-readable description of the input. Any characters,
            including quotes and line breaks, are allowed.
        type: The data type of the input (e.g., "float", "int", "pd.Series").

    Returns:
        ast.Call: An abstract syntax tree (AST) call node representing the input definition.

    """
    call = create_ast_element(
        INPUT_TEMPLATE.substitute(
            class_name=CASADI_INPUT,
            name="",
            value=value,
            unit="",
            description="",
            type="",
        )
    )
    return _set_string_keywords(
        call, name=name, unit=unit, description=description, type=type
    )


def add_parameter(
    name: str, value: Union[int, float], unit: str, description: str
) -> ast.Call:
    """Create an AST node for a parameter definition.

    Args:
        name: The name of the parameter.
        value: The value of the parameter. Can be an integer or float; it is
            inserted verbatim as Python source.
        unit: The unit associated with the parameter value.
        description: A human-readable description of the parameter. Any
            characters, including quotes and line breaks, are allowed.

    Returns:
        ast.Call: An abstract syntax tree (AST) call node
        representing the parameter definition.

    """
    call = create_ast_element(
        PARAMETER_TEMPLATE.substitute(
            class_name=CASADI_PARAMETER,
            name="",
            value=value,
            unit="",
            description="",
        )
    )
    return _set_string_keywords(call, name=name, unit=unit, description=description)


def add_output(
    name: str, unit: str, type: str, value: Union[str, float], description: str
) -> ast.Call:
    """Create an AST node for an output definition.

    Args:
        name: The name of the output.
        unit: The unit associated with the output value.
        type: The data type of the output (e.g., "float", "string").
        value: The value of the output. It is inserted verbatim as Python
            source, so a string is parsed as an expression.
        description: A human-readable description of the output. Any
            characters, including quotes and line breaks, are allowed.

    Returns:
        ast.Call: An abstract syntax tree (AST) call node representing the output definition.

    """
    call = create_ast_element(
        OUTPUT_TEMPLATE.substitute(
            class_name=CASADI_OUTPUT,
            name="",
            unit="",
            type="",
            value=value,
            description="",
        )
    )
    return _set_string_keywords(
        call, name=name, unit=unit, type=type, description=description
    )

class SetupSystemModifier(ast.NodeTransformer):
    """A custom AST transformer for modifying the MPC model file.

    This class traverses the AST of the input file. It identifies the config
    class (a subclass of ``CasadiModelConfig``) and the model class (a subclass
    of ``CasadiModel``) and rewrites them for the selected MPC type:

    * shadow MPCs (``PFMPCData`` / ``NFMPCData``): all top-level imports are
      removed, the config gets the full baseline control trajectories and the
      flexibility parameters, and ``setup_system`` is changed to constrain the
      controls before the market time and to return
      ``SHADOW_MPC_COST_FUNCTION``.
    * baseline MPC (``BaselineMPCData``): ``pandas``/``casadi`` imports are
      added, the config gets the flexibility inputs and parameters, and
      ``setup_system`` returns ``return_baseline_cost_function(...)``.

    """

    def __init__(
        self,
        mpc_data: BaseMPCData,
        controls: list[MPCVariable],
        binary_controls: Optional[list[MPCVariable]],
    ):
        """Store the MPC data and select the modification strategy.

        Args:
            mpc_data: Data of the MPC to generate. Its type selects the shadow
                or baseline modifications.
            controls: Control variables. For shadow MPCs, their full baseline
                trajectories are added as config inputs and used as bounds in
                ``setup_system``.
            binary_controls: Optional binary controls. For shadow MPCs, their
                full baseline trajectories are added as config inputs.

        """
        self.mpc_data = mpc_data
        self.controls = controls
        self.binary_controls = binary_controls
        # AST nodes of the config class and the model class, set while visiting
        self.config_obj: Optional[ast.ClassDef] = None
        self.model_obj: Optional[ast.ClassDef] = None
        # Select the modification of the config and setup_system by mpc type.
        # NOTE(review): for any other mpc_data type neither method is bound, so
        # visiting a config/model class would raise AttributeError.
        if isinstance(mpc_data, (PFMPCData, NFMPCData)):
            self.modify_config_class = self.modify_config_class_shadow
            self.modify_setup_system = self.modify_setup_system_shadow
        if isinstance(mpc_data, BaselineMPCData):
            self.modify_config_class = self.modify_config_class_baseline
            self.modify_setup_system = self.modify_setup_system_baseline

    def visit_Module(self, module: ast.Module) -> ast.Module:
        """Visit a module definition in the AST.

        Append or delete the import statements at the top of the module.

        Args:
            module: The module definition node in the AST.

        Returns:
            The possibly modified module definition node.

        """
        # append imports for baseline
        if isinstance(self.mpc_data, BaselineMPCData):
            module = add_import_to_tree(name="pandas", alias="pd", tree=module)
            module = add_import_to_tree(name="casadi", alias="ca", tree=module)
        # delete imports for shadow MPCs
        if isinstance(self.mpc_data, (NFMPCData, PFMPCData)):
            module = remove_all_imports_from_tree(module)
        # trigger the next visit method (ClassDef)
        self.generic_visit(module)
        return module

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        """Visit a class definition in the AST.

        Classes deriving from ``CasadiModelConfig`` are modified as the config
        and renamed to ``<class_name>Config``. For classes deriving from
        ``CasadiModel``, ``setup_system`` is modified, the ``config``
        annotation is pointed to the renamed config, and the class is renamed
        to ``<class_name>``.

        Args:
            node: The class definition node in the AST.

        Returns:
            The possibly modified class definition node.

        """
        for base in node.bases:
            if not isinstance(base, ast.Name):
                continue
            if base.id == "CasadiModelConfig":
                # get ast object and trigger modification
                self.config_obj = node
                self.modify_config_class(node)
                # change class name
                node.name = self.mpc_data.class_name + "Config"
            if base.id == "CasadiModel":
                # get ast object and trigger modification
                self.model_obj = node
                for item in node.body:
                    if (
                        isinstance(item, ast.FunctionDef)
                        and item.name == "setup_system"
                    ):
                        self.modify_setup_system(item)
                    # point the config annotation to the renamed config class
                    if (
                        isinstance(item, ast.AnnAssign)
                        and isinstance(item.target, ast.Name)
                        and item.target.id == "config"
                    ):
                        item.annotation = (
                            ast.parse(self.mpc_data.class_name + "Config").body[0].value
                        )
                # change class name
                node.name = self.mpc_data.class_name

        return node

    def get_leftmost_list(
        self, node: Union[ast.Tuple, ast.BinOp, ast.List, None]
    ) -> Optional[ast.List]:
        """Recursively traverse binary operations to get the leftmost list.

        Args:
            node: An AST node (a BinOp, a Tuple, or directly a List).

        Returns:
            The leftmost List node found, or None if there is none.

        """
        if isinstance(node, ast.List):
            return node
        if isinstance(node, ast.BinOp):
            # e.g. ``[...] + other``: new entries go into the left operand
            return self.get_leftmost_list(node.left)
        if isinstance(node, ast.Tuple) and node.elts:
            return self.get_leftmost_list(node.elts[0])
        # no list literal found
        return None

    @staticmethod
    def _field_name(statement: ast.stmt) -> Optional[str]:
        """Return the field name assigned by a class-body statement, if any.

        Handles annotated (``inputs: list = [...]``) and plain
        (``inputs = [...]``) assignments to a simple name. Returns None for
        every other statement (docstrings, methods, ``pass``, ...).

        """
        if isinstance(statement, ast.AnnAssign):
            target = statement.target
        elif isinstance(statement, ast.Assign) and len(statement.targets) == 1:
            target = statement.targets[0]
        else:
            return None
        return target.id if isinstance(target, ast.Name) else None

    def _field_list(self, statement: ast.stmt, field: str) -> ast.List:
        """Return the list literal of a config field that entries are appended to.

        Args:
            statement: The assignment statement of the config field.
            field: Name of the field, used in the error message.

        Raises:
            ValueError: If the field's value contains no list literal.

        """
        value_list = self.get_leftmost_list(getattr(statement, "value", None))
        if value_list is None:
            raise ValueError(
                f"Could not find a list literal in the '{field}' field of the "
                f"model config; cannot append the flexibility variables."
            )
        return value_list

    def modify_config_class_shadow(self, node: ast.ClassDef):
        """Modify the config class of the shadow mpc.

        Adds the full baseline control trajectories (including binary
        controls) and the provision flag to ``inputs``. Adds the flexibility
        time parameters and the objective weights to ``parameters``.

        Args:
            node: The class definition node of the config.

        Raises:
            ValueError: If ``inputs`` or ``parameters`` contains no list literal.

        """
        for statement in node.body:
            field = self._field_name(statement)
            if field == "inputs":
                inputs = self._field_list(statement, field)
                # full control trajectories of the baseline, binary ones included
                for control in [*self.controls, *(self.binary_controls or [])]:
                    inputs.elts.append(
                        add_input(
                            f"{control.name}{full_trajectory_suffix}",
                            None,
                            control.unit,
                            "full control trajectory output of baseline mpc",
                            "pd.Series",
                        )
                    )
                inputs.elts.append(
                    add_input("in_provision", False, "-", "provision flag", "bool")
                )
            elif field == "parameters":
                parameters = self._field_list(statement, field)
                # add the flex variables and the weights
                for param_name in [PREP_TIME, FLEX_EVENT_DURATION, MARKET_TIME]:
                    parameters.elts.append(
                        add_parameter(param_name, 0, "s", "time to switch objective")
                    )
                for weight in self.mpc_data.weights:
                    parameters.elts.append(
                        add_parameter(
                            weight.name,
                            weight.value,
                            "-",
                            "Weight for P in objective function",
                        )
                    )

    def modify_config_class_baseline(self, node: ast.ClassDef):
        """Modify the config class of the baseline mpc.

        Adds the flexibility inputs (external power profile, provision flag,
        relative start/end of the event) to ``inputs``. Adds the entries of
        ``mpc_data.config_parameters_appendix`` to ``parameters``. The
        ``outputs`` field is not modified.

        Args:
            node: The class definition node of the config.

        Raises:
            ValueError: If ``inputs`` or ``parameters`` contains no list literal.

        """
        for statement in node.body:
            field = self._field_name(statement)
            if field == "inputs":
                inputs = self._field_list(statement, field)
                inputs.elts.append(
                    add_input(
                        "_P_external",
                        0,
                        "W",
                        "External power profile to be provided",
                        "pd.Series",
                    )
                )
                inputs.elts.append(
                    add_input(
                        "in_provision",
                        False,
                        "-",
                        "Flag signaling if the flexibility is in provision",
                        "bool",
                    )
                )
                inputs.elts.append(
                    add_input(
                        "rel_start",
                        0,
                        "s",
                        "relative start time of the flexibility event",
                        "int",
                    )
                )
                inputs.elts.append(
                    add_input(
                        "rel_end",
                        0,
                        "s",
                        "relative end time of the flexibility event",
                        "int",
                    )
                )
            elif field == "parameters":
                parameters = self._field_list(statement, field)
                for parameter in self.mpc_data.config_parameters_appendix:
                    parameters.elts.append(
                        add_parameter(parameter.name, 0, "-", parameter.description)
                    )

    def modify_setup_system_shadow(self, node: ast.FunctionDef):
        """Modify the setup_system method of the shadow mpc model class.

        If the first ``self.constraints = [...]`` assignment is a list literal,
        each control is bounded by its baseline trajectory while
        ``self.time < self.market_time``. The bound definitions are inserted
        at the top of the function (after a docstring, if any). The return
        statement is then replaced: the original objective becomes
        ``obj_std``, the flex objective becomes ``obj_flex``, and
        ``SHADOW_MPC_COST_FUNCTION`` is returned.

        Args:
            node: The function definition node of setup_system.

        """
        constraints = next(
            (
                stmt
                for stmt in node.body
                if isinstance(stmt, ast.Assign)
                and isinstance(stmt.targets[0], ast.Attribute)
                and stmt.targets[0].attr == "constraints"
            ),
            None,
        )
        if constraints is not None and isinstance(constraints.value, ast.List):
            bound_definitions = []
            for control in self.controls:
                # before the market time, both bounds equal the baseline trajectory
                for bound, limit in (("lower", "lb"), ("upper", "ub")):
                    bound_definitions.append(
                        ast.parse(
                            f"{control.name}_{bound} = ca.if_else(self.time < "
                            f"self.market_time.sym, "
                            f"self.{control.name}{full_trajectory_suffix}.sym, "
                            f"self.{control.name}.{limit})"
                        ).body[0]
                    )
                constraints.value.elts.append(
                    ast.parse(
                        f"({control.name}_lower, self.{control.name}, {control.name}_upper)",
                        mode="eval",
                    ).body
                )
            # keep a docstring as the first statement of the function
            start = 1 if ast.get_docstring(node) is not None else 0
            node.body[start:start] = bound_definitions

        # replace the first top-level return statement (and anything after it)
        for i, stmt in enumerate(node.body):
            if isinstance(stmt, ast.Return):
                node.body[i:] = [
                    # the original objective becomes the standard objective
                    ast.Assign(
                        targets=[ast.Name(id="obj_std", ctx=ast.Store())],
                        value=stmt.value,
                    ),
                    # create flex objective variable
                    ast.Assign(
                        targets=[ast.Name(id="obj_flex", ctx=ast.Store())],
                        value=ast.parse(
                            self.mpc_data.flex_cost_function, mode="eval"
                        ).body,
                    ),
                    # return the custom shadow cost function
                    ast.Return(value=ast.parse(SHADOW_MPC_COST_FUNCTION).body[0].value),
                ]
                break

    def modify_setup_system_baseline(self, node: ast.FunctionDef):
        """Modify the setup_system method of the baseline mpc model class.

        The original objective is stored in ``obj_std``, and the function
        returns ``return_baseline_cost_function(...)`` built from the power and
        comfort variables of ``mpc_data``.

        Args:
            node: The function definition node of setup_system.

        """
        # replace the first top-level return statement (and anything after it)
        for i, stmt in enumerate(node.body):
            if isinstance(stmt, ast.Return):
                node.body[i:] = [
                    # the original objective becomes the standard objective
                    ast.Assign(
                        targets=[ast.Name(id="obj_std", ctx=ast.Store())],
                        value=stmt.value,
                    ),
                    # return the custom baseline cost function
                    ast.Return(
                        value=ast.parse(
                            return_baseline_cost_function(
                                power_variable=self.mpc_data.power_variable,
                                comfort_variable=self.mpc_data.comfort_variable,
                            )
                        )
                        .body[0]
                        .value
                    ),
                ]
                break

def add_import_to_tree(name: str, alias: str, tree: ast.Module) -> ast.Module:
    """Add ``import name as alias`` to the top of the module unless already present.

    The statement is inserted at position 0 of ``tree.body``. It is skipped
    when a top-level ``import`` statement already binds ``name`` under
    ``alias``, so repeated calls do not create duplicate imports.

    Args:
        name: Name of the module to be imported.
        alias: Alias of the module.
        tree: The module tree to modify (in place).

    Returns:
        The tree updated with the import statement.

    """
    # The previous for/else loop inserted the import on every path; check all
    # top-level imports for the exact (name, alias) pair instead.
    already_imported = any(
        isinstance(node, ast.Import)
        and any(
            imported.name == name and imported.asname == alias
            for imported in node.names
        )
        for node in tree.body
    )
    if not already_imported:
        tree.body.insert(0, ast.Import(names=[ast.alias(name=name, asname=alias)]))
    return tree

def remove_all_imports_from_tree(tree: ast.Module) -> ast.Module:
    """Strip every top-level import statement from the module.

    Both ``import x`` and ``from x import y`` statements that sit directly in
    ``tree.body`` are dropped. Imports nested inside functions or classes are
    left untouched. ``tree.body`` is replaced by a new list, and the same tree
    object is returned.

    Args:
        tree: The module tree to clean.

    Returns:
        The tree without top-level import statements.

    """

    def keep(statement: ast.stmt) -> bool:
        # anything that is not an import statement survives
        return not isinstance(statement, (ast.Import, ast.ImportFrom))

    tree.body = list(filter(keep, tree.body))
    return tree