Coverage for agentlib/modules/simulation/simulator.py: 89%

315 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2026-02-26 17:34 +0000

1""" 

2Module contains the Simulator, used to simulate any model. 

3""" 

4 

5import os 

6import warnings 

7from dataclasses import dataclass, field 

8from math import inf 

9from pathlib import Path 

10from typing import Union, Dict, List, Optional 

11 

12import numpy as np 

13import pandas as pd 

14from pydantic import field_validator, Field 

15from pydantic_core.core_schema import FieldValidationInfo 

16 

17from agentlib.core import ( 

18 BaseModule, 

19 BaseModuleConfig, 

20 Agent, 

21 Causality, 

22 AgentVariable, 

23 AgentVariables, 

24 ModelVariable, 

25 Model, 

26) 

27from agentlib.core.errors import OptionalDependencyError 

28from agentlib.models import get_model_type, UNINSTALLED_MODEL_TYPES 

29from agentlib.utils import custom_injection, create_time_samples 

30 

31 

@dataclass
class SimulatorResults:
    """Class to organize in-memory simulator results.

    Rows are collected in plain python lists (cheap appends). ``df()``
    materializes them as a DataFrame on demand and ``write_results``
    flushes the buffers to ``filename`` and clears them to keep the
    memory footprint low. Every row has the layout
    ``[out_1, ..., out_m, in_1, ..., in_n]``.
    """

    # Configuration
    filename: Optional[str] = None  # target .csv file; None keeps results in memory only
    header_written: bool = False  # True once the csv header has been written to disk

    # Data Buffers
    index: List[float] = field(default_factory=list)  # time stamps, one per row in data
    data: List[List[float]] = field(default_factory=list)  # rows: [outputs..., inputs...]

    # State tracking
    _current_inputs: List[float] = field(default_factory=list)
    _current_outputs: List[float] = field(default_factory=list)
    _columns: Optional[pd.MultiIndex] = None  # column header, built in setup()
    _input_count: int = 0
    _output_count: int = 0

    def setup(self, input_vars: List["ModelVariable"], output_vars: List["ModelVariable"]):
        """
        Initializes the results object for the model's input variables
        "u", outputs "y" and internal state variables "x".
        The resulting table has the layout:
        +---------------------------------------------------+
        |   | inputs u  | outputs y | states x              |
        | t | u1 | u2 | y1 | y2 | x1 | x2 |
        +---------------------------------------------------+
        | 1 | ... | ... | ... | ... | ... | ... |
        | 2 | ... | ... | ... | ... | ... | ... |
        |...| ... | ... | ... | ... | ... | ... |
        Also initializes the internal buffers.

        Args:
            input_vars: variables logged at the start of an interval
                (inputs and parameters).
            output_vars: variables logged at the end of an interval
                (outputs and local states).
        """
        # Outputs come first in each row, matching _output_count slicing below
        variables = output_vars + input_vars
        self._input_count = len(input_vars)
        self._output_count = len(output_vars)

        # Initialize current inputs with the variables' current values
        self._current_inputs = [var.value for var in input_vars]

        self._columns = pd.MultiIndex.from_arrays(
            arrays=np.array(
                [
                    [_var.causality.name for _var in variables],
                    [_var.name for _var in variables],
                    [_var.type for _var in variables],
                ]
            ),
            sortorder=0,
            names=["causality", "name", "type"],
        )

    def update_inputs(self, values: List[float], time: float, capture_all_inputs: bool):
        """
        Updates the result with the inputs creating a full result row (output + input).
        If capture_all_inputs is True and ``time`` is a new timestamp, a row with
        NaN outputs is appended immediately.
        """
        self._current_inputs = values
        # Results can already hold the input (at t_sample_communication created by
        # the output writing) or the input time is new (created by an input callback)
        if not self.index or time != self.index[-1]:
            # For capture_all_inputs, append the inputs created by an input callback
            if capture_all_inputs:
                self.index.append(time)
                # index is not in data, if results have been written to disk
                # Create row: [NaN, NaN, ..., In1, In2, ...]
                row = [None] * self._output_count + self._current_inputs
                # If timestamp is new, this needs to be appended
                self.data.append(row)
        else:
            # Create row: [Out1, Out2, ..., In1, In2, ...]
            row = self.data[-1][:self._output_count] + self._current_inputs
            # Update the existing timestamp's row with the new inputs
            self.data[-1] = row

    def update_outputs(self, values: List[float], time: float):
        """
        Stores a result row at the end of a simulation step.
        Combines provided output values with None for input values, as these are
        updated in the next time step.
        """
        # Create row: [Out1, Out2, ..., None, None, ...]
        row = values + [None] * self._input_count
        self.index.append(time)
        self.data.append(row)

    def update_current_outputs(self, values: List[float]):
        """
        Stores the current output values of intermediate simulation steps.
        """
        self._current_outputs = values

    def initialize_outputs(self, time):
        """
        Appends an initial row at ``time`` with None outputs and the
        currently known input values.
        """
        self.index.append(time)
        # Create row: [None, None, ..., In1, In2, ...]
        self.data.append([None] * self._output_count + self._current_inputs)

    def initialize_inputs(self, values: List[float]):
        """
        Stores the initial input values. No row is created yet; the values
        are picked up by the next ``initialize_outputs`` / output row.
        """
        self._current_inputs = values

    def write_results(self):
        """
        Dumps results which are currently in memory to the file.
        Clears memory after writing to keep footprint low.
        """
        # Nothing to do without a target file or without buffered rows
        if not self.filename or not self.data:
            return

        df = pd.DataFrame(self.data, index=self.index, columns=self._columns)

        # Write header only once (first write to a not-yet-existing file)
        header = not self.header_written and not Path(self.filename).exists()
        df.to_csv(self.filename, mode="a", header=header)

        self.header_written = True

        # Clear buffers
        self.index.clear()
        self.data.clear()

    def df(self) -> pd.DataFrame:
        """Returns the current in-memory results as a dataframe."""
        return pd.DataFrame(self.data, index=self.index, columns=self._columns)

162 

163 

def read_simulator_results(file: str):
    """Load a results file written by the simulator.

    Restores the three-level column header (causality, name, type) and
    uses the first csv column as the time index.
    """
    frame = pd.read_csv(file, header=[0, 1, 2], index_col=0)
    return frame

167 

168 

class SimulatorConfig(BaseModuleConfig):
    """
    Pydantic data model for simulator configuration parser.

    NOTE: field declaration order matters — pydantic runs field validators
    in declaration order and several validators below read previously
    validated fields via ``info.data`` (e.g. ``check_t_stop`` reads
    ``t_start``, ``check_model`` reads the variable lists and
    ``t_sample_simulation``).
    """

    # Agent-side variables mapped onto model variables of the same name
    parameters: AgentVariables = []
    inputs: AgentVariables = []
    outputs: AgentVariables = []
    states: AgentVariables = []
    # Variable fields shared with other agents by default
    shared_variable_fields: List[str] = ["outputs"]

    t_start: Union[float, int] = Field(
        title="t_start", default=0.0, ge=0, description="Simulation start time"
    )
    t_stop: Union[float, int] = Field(
        title="t_stop", default=inf, ge=0, description="Simulation stop time"
    )
    t_sample: Union[float, int] = Field(
        title="t_sample", default=1, ge=0, description="Deprecated option."
    )
    t_sample_communication: Union[float, int] = Field(
        title="t_sample",
        default=1,
        validate_default=True,
        ge=0,
        description="Sample time of a full simulation step relevant for communication, including:"
        "- Perform simulation with t_sample_simulation"
        "- Update model results and send output values to other Agents or Modules."
    )
    t_sample_simulation: Union[float, int] = Field(
        title="t_sample_simulation",
        default=1,
        validate_default=True,
        ge=0,
        description="Sample time of the simulation itself. "
        "The inputs of the models may be updated every other t_sample_simulation, "
        "as long as the model supports this. Used to override dt of the model."
    )
    # Raw model config dict; replaced by a Model instance in check_model below
    model: Dict
    # Model results
    save_results: bool = Field(
        title="save_results",
        default=False,
        description="If True, results are created and stored",
    )
    overwrite_result_file: bool = Field(
        title="overwrite_result",
        default=False,
        description="If True, and the result file already exists, the file is overwritten.",
    )
    result_filename: Optional[str] = Field(
        title="result_filename",
        default=None,
        description="If not None, results are stored in that filename."
        "Needs to be a .csv file",
    )
    result_sep: str = Field(
        title="result_sep",
        default=",",
        description="Separator in the .csv file. Only relevant if "
        "result_filename is passed",
    )
    result_causalities: List[Causality] = Field(
        title="result_causalities",
        default=[Causality.input, Causality.output],
        description="List of causalities to store. Default stores "
        "only inputs and outputs",
    )
    capture_all_inputs: bool = Field(
        title="capture_all_inputs",
        default=False,
        description="If True, results are stored immediately when "
        "inputs change, even during simulation steps.",
    )
    write_results_delay: Optional[float] = Field(
        title="Write Results Delay",
        default=None,
        description="Sampling interval for which the results are written to disc in seconds.",
        validate_default=True,
        gt=0,
    )
    update_inputs_on_callback: bool = Field(
        title="update_inputs_on_callback",
        default=True,
        description="Deprecated! Will be removed in future versions."
        "If True, model inputs are updated if they are updated in data_broker."
        "Else, the model inputs are updated before each simulation.",
    )
    measurement_uncertainty: Union[Dict[str, float], float] = Field(
        title="measurement_uncertainty",
        default=0,
        description="Either pass a float and add the percentage uncertainty "
        "to all measurements from the model."
        "Or pass a Dict and specify the model variable name as key"
        "and the associated uncertainty as a float",
    )
    validate_incoming_values: Optional[bool] = Field(
        default=False,  # we overwrite the default True in base, to be more efficient
        title="Validate Incoming Values",
        description="If true, the validator of the AgentVariable value is called when "
        "receiving a new value from the DataBroker. In the simulator, this "
        "is False by default, as we expect to receive a lot of measurements"
        " and want to be efficient.",
    )

    @field_validator("result_filename")
    @classmethod
    def check_nonexisting_csv(cls, result_filename, info: FieldValidationInfo):
        """Check that the result_filename is a .csv file and that it does
        not already exist (unless overwrite_result_file allows removal)."""
        if not info.data.get("save_results", False):
            # No need to check as filename will never be used anyways
            return None
        if result_filename is None:
            return result_filename
        if not result_filename.endswith(".csv"):
            raise TypeError(
                f"Given result_filename ends with "
                f'{result_filename.split(".")[-1]} '
                f"but should be a .csv file"
            )
        if os.path.isfile(result_filename):
            # remove result file, so a new one can be created
            if info.data["overwrite_result_file"]:
                os.remove(result_filename)
                return result_filename
            raise FileExistsError(
                f"Given result_filename at {result_filename} "
                f"already exists. We won't overwrite it automatically. "
                f"You can use the key word 'overwrite_result_file' to "
                f"activate automatic overwrite."
            )
        # Create path in case it does not exist
        fpath = os.path.dirname(result_filename)
        if fpath:
            os.makedirs(fpath, exist_ok=True)
        return result_filename

    @field_validator("t_stop")
    @classmethod
    def check_t_stop(cls, t_stop, info: FieldValidationInfo):
        """Check if stop is greater than start time"""
        t_start = info.data.get("t_start")
        assert t_stop > t_start, "t_stop must be greater than t_start"
        return t_stop

    @field_validator("t_sample_communication", "t_sample_simulation")
    @classmethod
    def check_t_sample(cls, t_sample, info: FieldValidationInfo):
        """Check if t_sample is smaller than stop-start time.

        Also maps the deprecated ``t_sample`` option onto the new fields:
        a user-supplied t_sample becomes t_sample_communication, while
        t_sample_simulation falls back to 1 (the model.dt default).
        """
        t_start = info.data.get("t_start")
        t_stop = info.data.get("t_stop")
        t_sample_old = info.data.get("t_sample")

        # Handle legacy t_sample logic (1 is the untouched default)
        if t_sample_old != 1:
            if info.field_name == "t_sample_simulation":
                t_sample = 1
            else:
                t_sample = t_sample_old
        assert (
            t_start + t_sample <= t_stop
        ), "t_stop-t_start must be greater than t_sample"
        return t_sample

    @field_validator("t_sample_communication")
    @classmethod
    def check_t_comm_against_sim(cls, t_sample_communication,
                                 info: FieldValidationInfo):
        """Warn if the communication sample time is smaller than the
        simulation sample time (communication cannot be finer than the
        simulation steps that produce the values)."""
        t_sample_simulation = info.data.get("t_sample_simulation")
        if t_sample_simulation is not None:
            if t_sample_simulation > t_sample_communication:
                warnings.warn(
                    f"{t_sample_communication=} is smaller than {t_sample_simulation=}",
                    category=UserWarning
                )
        return t_sample_communication

    @field_validator("update_inputs_on_callback")
    @classmethod
    def deprecate_update_inputs_on_callback(cls, update_inputs_on_callback,
                                            info: FieldValidationInfo):
        """Deprecate update_inputs_on_callback; the behavior is now always
        callback-based, so True is returned regardless of the given value."""
        warnings.warn(
            "update_inputs_on_callback is deprecated, remove it from your config. "
            "Will use update_inputs_on_callback=True",
            category=DeprecationWarning
        )
        return True

    @field_validator("t_sample")
    @classmethod
    def deprecate_t_sample(cls, t_sample, info: FieldValidationInfo):
        """Deprecates the t_sample field in favor of t_sample_communication
        and t_sample_simulation."""
        # Runs only for user-supplied values (no validate_default on t_sample)
        warnings.warn(
            "t_sample is deprecated, use t_sample_communication for storing outputs "
            "and t_sample_simulation for the actual simulation step. "
            "Will use the given t_sample for t_sample_communication and "
            "t_sample_simulation=1 s, the `model.dt` default.",
        )
        return t_sample

    @field_validator("write_results_delay")
    @classmethod
    def set_default_t_sample(cls, write_results_delay, info: FieldValidationInfo):
        """Default write_results_delay to five communication steps and ensure
        it is not shorter than one communication step.

        NOTE: the method name is historical; it validates write_results_delay.
        """
        t_comm = info.data.get("t_sample_communication", 1)

        if write_results_delay is None:
            # Default to writing every 5 communication steps to balance I/O
            return t_comm * 5

        if write_results_delay < t_comm:
            raise ValueError("write_results_delay should be >= t_sample_communication")
        return write_results_delay

    @field_validator("model")
    @classmethod
    def check_model(cls, model, info: FieldValidationInfo):
        """Validate the model input and instantiate the Model object.

        Accepts either a custom-injection dict or a registered model type
        string under the "type" key; the remaining keys plus the agent
        variable lists become the model's constructor arguments.
        """
        parameters = info.data.get("parameters")
        inputs = info.data.get("inputs")
        outputs = info.data.get("outputs")
        states = info.data.get("states")
        dt = info.data.get("t_sample_simulation")
        if "dt" in model and dt != model["dt"]:
            warnings.warn(
                f"Given model {model['dt']=} differs from {dt=} of simulator. "
                f"Using models dt, consider switching to t_sample_simulation."
            )
        else:
            # No explicit dt in the model config: inherit the simulator's
            model["dt"] = dt
        if "type" not in model:
            raise KeyError(
                "Given model config does not " "contain key 'type' (type of the model)."
            )
        _type = model.pop("type")
        if isinstance(_type, dict):
            # Custom injection: load the class from the given file/module
            custom_cls = custom_injection(config=_type)
            model = custom_cls(**model)
        elif isinstance(_type, str):
            if _type in UNINSTALLED_MODEL_TYPES:
                raise OptionalDependencyError(
                    dependency_name=_type,
                    dependency_install=UNINSTALLED_MODEL_TYPES[_type],
                    used_object=f"model {_type}",
                )
            model = get_model_type(_type)(
                **model,
                parameters=convert_agent_vars_to_list_of_dicts(parameters),
                inputs=convert_agent_vars_to_list_of_dicts(inputs),
                outputs=convert_agent_vars_to_list_of_dicts(outputs),
                states=convert_agent_vars_to_list_of_dicts(states),
            )
        # Check if model was correctly initialized
        assert isinstance(model, Model)
        return model

427 

428 

class Simulator(BaseModule):
    """
    The Simulator is the interface between simulation models
    and further other implementations. It contains all interface functions for
    interacting with the standard model class.
    """

    config: SimulatorConfig

    def __init__(self, *, config: dict, agent: Agent):
        """Set up the model, result handling and input callbacks."""
        super().__init__(config=config, agent=agent)

        self._model = None
        # Assignment triggers the `model` setter below, which initializes it
        self.model = self.config.model
        self._inputs_changed_since_last_results_saving = False

        # Caching variables for performance (avoid list comprehensions in loop)
        self._input_vars = self._get_result_input_variables()
        self._output_vars = self._get_result_output_variables()

        # Initialize Result Handler
        self._result = SimulatorResults(filename=self.config.result_filename)
        if self.config.save_results:
            self._result.setup(input_vars=self._input_vars,
                               output_vars=self._output_vars)

        # Initialize local time trackers
        self._last_write_time = 0.0
        self._last_communication_time = self.env.time

        self._register_input_callbacks()
        self.logger.info("%s initialized!", self.__class__.__name__)

    def terminate(self):
        """Terminate the model, then the module itself."""
        self.model.terminate()
        super().terminate()

    @property
    def model(self) -> Model:
        """
        Getter for current simulation model

        Returns:
            agentlib.core.model.Model: Current simulation model
        """
        return self._model

    @model.setter
    def model(self, model: Model):
        """
        Setter for current simulation model.
        Also initializes it if needed!

        Args:
            model (agentlib.core.model.Model): model to set as current simulation model
        """
        if not isinstance(model, Model):
            self.logger.error(
                "You forgot to pass a valid model to the simulator module!"
            )
            raise TypeError(
                f"Given model is of type {type(model)} "
                f"but should be an instance of Model or a valid subclass"
            )
        self._model = model
        # Both offsets shift the model start time; warn if they stack up
        if self.config.t_start and self.env.offset:
            self.logger.warning(
                "config.t_start and env.offset are both non-zero. "
                "This may cause unexpected behavior. Ensure that this "
                "is intended and you know what you are doing."
            )
        self.model.initialize(
            t_start=self.config.t_start + self.env.config.offset,
            t_stop=self.config.t_stop,
        )
        self.logger.info("Model successfully loaded model: %s", self.model.name)

    def run(self, until=None):
        """
        Runs the simulator in stand-alone mode if needed
        Attention: If the environment is connected to another environment
        all scheduled process will be started in this environment.
        """
        if until is None:
            self.env.run(until=self.config.t_stop - self.config.t_start)
        else:
            self.env.run(until=until)

    def register_callbacks(self):
        # Input callbacks are registered explicitly in __init__ via
        # _register_input_callbacks; nothing to do here.
        pass

    def _register_input_callbacks(self):
        """Register data-broker callbacks for model inputs and parameters."""
        # Possible inputs are Inputs and parameters.
        # Outputs and states are always the result of the model
        # "Complicated" double for-loop to avoid boilerplate code
        for _type, model_var_names, ag_vars, callback in zip(
            ["input", "parameter"],
            [self.model.get_input_names(), self.model.get_parameter_names()],
            [self.config.inputs, self.config.parameters],
            [self._callback_update_model_input, self._callback_update_model_parameter],
        ):
            for var in ag_vars:
                if var.name in model_var_names:
                    self.logger.info(
                        "Registered callback for model %s %s ", _type, var.name
                    )
                    self.agent.data_broker.register_callback(
                        alias=var.alias,
                        source=var.source,
                        callback=callback,
                        name=var.name,
                    )
                # Case for variable overwriting
                if var.value is not None:
                    self.logger.debug(
                        "Updating model %s %s=%s", _type, var.name, var.value
                    )
                    self.model.set(name=var.name, value=var.value)
                    self._inputs_changed_since_last_results_saving = True

    def _callback_update_model_input(self, inp: AgentVariable, name: str):
        """Set given model input value to the model"""
        self.logger.debug("Updating model input %s=%s", name, inp.value)
        self.model.set_input_value(name=name, value=inp.value)
        self._inputs_changed_since_last_results_saving = True

    def _callback_update_model_parameter(self, par: AgentVariable, name: str):
        """Set given model parameter value to the model"""
        self.logger.debug("Updating model parameter %s=%s", name, par.value)
        self.model.set_parameter_value(name=name, value=par.value)
        self._inputs_changed_since_last_results_saving = True

    def process(self):
        """
        Main simulation loop (simpy process / generator).
        Handles simulation stepping, result logging, and synchronization.
        """
        # 1. Log Initial State (t=0)
        if self.config.save_results:
            # Ensure the result buffer has the correct initial inputs
            in_values = [var.value for var in self._input_vars]
            self._result.initialize_inputs(in_values)
            self._result.initialize_outputs(self.env.time)
            # Prevent false positive "input change" log at t=0 due to initialization callbacks
            self._inputs_changed_since_last_results_saving = False
        while True:
            # Determine the time points for the next communication step
            t_samples = create_time_samples(
                t_end=self.config.t_sample_communication,
                dt=self.config.t_sample_simulation
            )

            # Iterate through simulation sub-steps
            for i in range(len(t_samples) - 1):
                dt_sim = float(t_samples[i + 1] - t_samples[i])

                # 2. Check for Input Changes (Pre-Step)
                # If inputs changed since the last step (or during the yield),
                # we log them now.
                # This ensures the new inputs are recorded at the current timestamp,
                # separate from the outputs of the *previous* step (which were logged at
                # the end of the last loop).

                # Track if this is the first simulation sub-step within the current
                # communication interval. At communication boundaries, the inputs are
                # always saved to ensure they are associated with the correct outputs.
                full_comm_step = (i == 0)
                if self._inputs_changed_since_last_results_saving or full_comm_step:
                    if self.config.save_results:
                        # Create row: [t=Current, Out=NaN, In=New]
                        self._log_inputs(self.env.time,
                                         capture_all_inputs=self.config.capture_all_inputs)
                    self._inputs_changed_since_last_results_saving = False

                # 3. Perform Simulation Step
                self.model.do_step(
                    t_start=self.config.t_start + self.env.now,
                    t_sample=dt_sim
                )

                # 4. Store intermediate outputs
                if self.config.save_results:
                    out_values = [var.value for var in self._output_vars]
                    self._result.update_current_outputs(out_values)

                # 5. Write results
                if self.config.save_results:
                    # Since simulation has been performed, the model and its results are
                    # already a time step ahead
                    current_time = self.env.time + self.config.t_sample_simulation
                    if ((current_time - self._last_communication_time) >=
                            self.config.t_sample_communication):
                        # Update time tracker for communication (snap to the grid)
                        self._last_communication_time = ((current_time //
                                                          self.config.t_sample_communication) *
                                                         self.config.t_sample_communication)
                        # Check if we need to write to disk, do this before storing
                        # outputs, to initialize the new row after dumping the results
                        self._check_and_write_to_disk(self.env.time +
                                                      self.config.t_sample_simulation)

                        # Log the outputs resulting from the step we just finished.
                        # These will be paired with the inputs active for the next
                        # simulation step.
                        self._log_outputs(self._last_communication_time)

                # 6. Wait for the environment
                yield self.env.timeout(dt_sim)

            # 7. End of Communication Step (Post-Step)
            # Communicate outputs/states to other modules and agents
            self.update_module_vars()

    def _log_inputs(self, time: float, capture_all_inputs: bool):
        """
        Update the result object with current inputs.
        If capture_all_inputs is True, a row is added immediately.
        """
        values = [var.value for var in self._input_vars]
        self._result.update_inputs(values, time, capture_all_inputs=capture_all_inputs)

    def _log_outputs(self, time: float):
        """
        Add a full result row (Outputs + Last Inputs).
        """
        values = [var.value for var in self._output_vars]
        self._result.update_outputs(values, time)

    def _check_and_write_to_disk(self, time):
        """Check if write delay has passed and dump buffered results to disk."""
        if not self.config.result_filename:
            return

        # Inputs are written in the next time step, therefore results
        # are behind actual env time
        current_result_time = time - self.config.t_sample_communication
        # NOTE(review): the comparison uses current_result_time but the tracker
        # is updated with `time` — presumably intentional to keep the delay
        # aligned with env time; confirm before changing.
        if (current_result_time - self._last_write_time) >= self.config.write_results_delay:
            self._result.write_results()
            self._last_write_time = time

    def update_module_vars(self):
        """
        Method to write current model output and states
        values to the module outputs and states.
        """
        # pylint: disable=logging-fstring-interpolation
        for _type, model_get, agent_vars in zip(
            ["state", "output"],
            [self.model.get_state, self.model.get_output],
            [self.config.states, self.config.outputs],
        ):
            for var in agent_vars:
                mo_var = model_get(var.name)
                if mo_var is None:
                    raise KeyError(f"Given variable {var.name} not found in model.")
                value = self._get_uncertain_value(model_variable=mo_var)
                self.logger.debug("Updating %s %s=%s", _type, var.name, value)
                self.set(name=var.name, value=value)

    def _get_uncertain_value(self, model_variable: ModelVariable) -> float:
        """Get the value with added uncertainty based on the value of the variable"""
        if isinstance(self.config.measurement_uncertainty, dict):
            # Per-variable uncertainty; variables not listed get no noise
            bias = self.config.measurement_uncertainty.get(model_variable.name, 0)
        else:
            bias = self.config.measurement_uncertainty
        # Relative uniform noise in [-bias, +bias]
        return model_variable.value * (1 + np.random.uniform(-bias, bias))

    def get_results(self) -> Optional[pd.DataFrame]:
        """
        Return the current results.

        Returns:
            pd.DataFrame: The results DataFrame, or None if
            save_results is disabled.
        """
        if not self.config.save_results:
            return
        file = self.config.result_filename
        if file:
            # Flush remaining buffers, then read the full file back
            self._result.write_results()
            df = read_simulator_results(file)
        else:
            df = self._result.df()
        # Drop the "type" and "causality" header levels, keep variable names
        df = df.droplevel(level=2, axis=1).droplevel(level=0, axis=1)
        return df

    def cleanup_results(self):
        """Remove the result file, if one was configured."""
        if not self.config.save_results or not self.config.result_filename:
            return
        os.remove(self.config.result_filename)

    def _get_result_input_variables(self) -> List[ModelVariable]:
        """Gets all input variables to be saved in the results based on
        self.result_causalities. Input variables are added to the results at the time
        index before an interval, i.e. parameters and inputs."""
        _variables = []
        for causality in self.config.result_causalities:
            if causality == Causality.input:
                _variables.extend(self.model.inputs)
            elif causality in [Causality.parameter, Causality.calculatedParameter]:
                _variables.extend(self.model.parameters)
        return _variables

    def _get_result_output_variables(self) -> List[ModelVariable]:
        """Gets all output variables to be saved in the results based on
        self.result_causalities. Output variables are added to the results at the time
        index after an interval, i.e. locals and outputs."""
        _variables = []
        for causality in self.config.result_causalities:
            if causality == Causality.output:
                _variables.extend(self.model.outputs)
            elif causality == Causality.local:
                _variables.extend(self.model.states)
        return _variables

742 

743 

def convert_agent_vars_to_list_of_dicts(var: "AgentVariables") -> List[Dict]:
    """
    Convert AgentVariables into dictionaries with the information needed
    for ModelVariables.

    The purely agent-side fields (source, alias, shared, rdf_class) are
    stripped from every variable.
    """
    agent_only_fields = {"source", "alias", "shared", "rdf_class"}
    return [agent_var.dict(exclude=agent_only_fields) for agent_var in var]