Coverage for agentlib/modules/simulation/simulator.py: 88%
279 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-12-23 08:15 +0000
1"""
2Module contains the Simulator, used to simulate any model.
3"""
5import os
6import warnings
7from dataclasses import dataclass
8from math import inf
9from pathlib import Path
10from typing import Union, Dict, List, Optional
12import numpy as np
13import pandas as pd
14from pydantic import field_validator, Field
15from pydantic_core.core_schema import FieldValidationInfo
17from agentlib.core import (
18 BaseModule,
19 BaseModuleConfig,
20 Agent,
21 Causality,
22 AgentVariable,
23 AgentVariables,
24 ModelVariable,
25 Model,
26)
27from agentlib.core.errors import OptionalDependencyError
28from agentlib.models import get_model_type, UNINSTALLED_MODEL_TYPES
29from agentlib.utils import custom_injection, create_time_samples
@dataclass
class SimulatorResults:
    """Class to organize in-memory simulator results."""

    index: List[float]
    columns: pd.MultiIndex
    data: List[List[float]]

    def __init__(self, variables: List["ModelVariable"]):
        """
        Initializes the results object for the given model variables,
        e.g. inputs "u", outputs "y" and internal state variables "x".
        The columns form a three-level MultiIndex (causality, name, type),
        in the order in which the variables are given.

        +-------------------------------------------+
        |   | inputs u  | outputs y | states x  |
        | t | u1  | u2  | y1  | y2  | x1  | x2  |
        +-------------------------------------------+
        | 1 | ... | ... | ... | ... | ... | ... |
        | 2 | ... | ... | ... | ... | ... | ... |
        |...| ... | ... | ... | ... | ... | ... |

        Args:
            variables: Variables to store. Each one must provide
                ``causality.name``, ``name`` and ``type``.
        """
        self.columns = pd.MultiIndex.from_arrays(
            arrays=np.array(
                [
                    [_var.causality.name for _var in variables],
                    [_var.name for _var in variables],
                    [_var.type for _var in variables],
                ]
            ),
            sortorder=0,
            names=["causality", "name", "type"],
        )
        self.index = []
        self.data = []

    def initialize(
        self,
        time: float,
    ):
        """Currently a no-op.

        NOTE(review): previously documented as "Adds the first row to the
        data", but no row is added here — confirm whether this is still needed.
        """

    def df(self) -> pd.DataFrame:
        """Returns the current results as a dataframe."""
        return pd.DataFrame(self.data, index=self.index, columns=self.columns)

    def write_results(self, file: str):
        """
        Appends the results which are currently in memory to ``file`` (CSV).

        The header rows are written only when the file is created. Afterwards
        only the last row is kept in memory, because the simulator may still
        merge inputs into the row of the latest time stamp.

        NOTE(review): the retained row is written again on the next call, so
        consecutive flushes share one duplicated time stamp in the file.

        Args:
            file: Path of the CSV file to append to.
        """
        header = not Path(file).exists()
        self.df().to_csv(file, mode="a", header=header)
        if not self.index:
            # Nothing buffered yet (e.g. results requested before the first
            # step): the file now holds the header, there is no row to keep.
            return
        # keep the last row of the results
        self.index = [self.index[-1]]
        self.data = [self.data[-1]]
def read_simulator_results(file: str):
    """
    Load a result CSV as written by ``SimulatorResults.write_results``.

    The file has three header rows (causality, name, type) and the time
    stamps in its first column; both are restored as column MultiIndex and
    row index, respectively.
    """
    header_rows = [0, 1, 2]
    time_column = 0
    return pd.read_csv(file, index_col=time_column, header=header_rows)
class SimulatorConfig(BaseModuleConfig):
    """
    Pydantic data model for simulator configuration parser.

    NOTE: The field order matters. Pydantic validates fields in declaration
    order and ``info.data`` in a validator only contains fields declared (and
    successfully validated) before the current one.
    """

    parameters: AgentVariables = []
    inputs: AgentVariables = []
    outputs: AgentVariables = []
    states: AgentVariables = []
    shared_variable_fields: List[str] = ["outputs"]
    t_start: Union[float, int] = Field(
        title="t_start", default=0.0, ge=0, description="Simulation start time"
    )
    t_stop: Union[float, int] = Field(
        title="t_stop", default=inf, ge=0, description="Simulation stop time"
    )
    t_sample: Union[float, int] = Field(
        title="t_sample",
        default=1,
        ge=0,
        description="Deprecated option."
    )
    t_sample_communication: Union[float, int] = Field(
        title="t_sample_communication",
        default=1,
        validate_default=True,
        ge=0,
        description="Sample time of a full simulation step relevant for communication, including: "
        "- Perform simulation with t_sample_simulation "
        "- Update model results and send output values to other Agents or Modules."
    )
    t_sample_simulation: Union[float, int] = Field(
        title="t_sample_simulation",
        default=1,
        validate_default=True,
        ge=0,
        description="Sample time of the simulation itself. "
        "The inputs of the models may be updated every other t_sample_simulation, "
        "as long as the model supports this. Used to override dt of the model."
    )
    model: Dict

    # Model results
    save_results: bool = Field(
        title="save_results",
        default=False,
        description="If True, results are created and stored",
    )
    overwrite_result_file: bool = Field(
        title="overwrite_result_file",
        default=False,
        description="If True, and the result file already exists, the file is overwritten.",
    )
    result_filename: Optional[str] = Field(
        title="result_filename",
        default=None,
        description="If not None, results are stored in that filename. "
        "Needs to be a .csv file",
    )
    # NOTE(review): result_sep is not passed to to_csv/read_csv anywhere in
    # this module, so it currently has no effect.
    result_sep: str = Field(
        title="result_sep",
        default=",",
        description="Separator in the .csv file. Only relevant if "
        "result_filename is passed",
    )
    result_causalities: List[Causality] = Field(
        title="result_causalities",
        default=[Causality.input, Causality.output],
        description="List of causalities to store. Default stores "
        "only inputs and outputs",
    )
    write_results_delay: Optional[float] = Field(
        title="Write Results Delay",
        default=None,
        description="Sampling interval for which the results are written to disc in seconds.",
        validate_default=True,
        gt=0,
    )
    measurement_uncertainty: Union[Dict[str, float], float] = Field(
        title="measurement_uncertainty",
        default=0,
        description="Either pass a float and add this relative uncertainty "
        "(as a fraction, e.g. 0.05 for +/- 5 %) to all measurements from the model. "
        "Or pass a Dict and specify the model variable name as key "
        "and the associated uncertainty as a float",
    )
    validate_incoming_values: Optional[bool] = Field(
        default=False,  # we overwrite the default True in base, to be more efficient
        title="Validate Incoming Values",
        description="If true, the validator of the AgentVariable value is called when "
        "receiving a new value from the DataBroker. In the simulator, this "
        "is False by default, as we expect to receive a lot of measurements"
        " and want to be efficient.",
    )
    update_inputs_on_callback: bool = Field(
        title="update_inputs_on_callback",
        default=True,
        description="Deprecated! Will be removed in future versions. "
        "If True, model inputs are updated if they are updated in data_broker. "
        "Else, the model inputs are updated before each simulation.",
    )

    @field_validator("result_filename")
    @classmethod
    def check_nonexisting_csv(cls, result_filename, info: FieldValidationInfo):
        """Check that result_filename is a .csv file and handle existing files.

        Returns None if ``save_results`` is False, as the filename is never
        used then. An existing file is removed if ``overwrite_result_file`` is
        True. Missing parent directories are created.

        Raises:
            TypeError: If the filename does not end with ".csv".
            FileExistsError: If the file exists and overwriting is disabled.
        """
        if not info.data.get("save_results", False):
            # No need to check as filename will never be used anyways
            return None
        if result_filename is None:
            return result_filename
        if not result_filename.endswith(".csv"):
            raise TypeError(
                f"Given result_filename ends with "
                f'{result_filename.split(".")[-1]} '
                f"but should be a .csv file"
            )
        if os.path.isfile(result_filename):
            # remove result file, so a new one can be created
            if info.data.get("overwrite_result_file", False):
                os.remove(result_filename)
                return result_filename
            raise FileExistsError(
                f"Given result_filename at {result_filename} "
                f"already exists. We won't overwrite it automatically. "
                f"You can use the key word 'overwrite_result_file' to "
                f"activate automatic overwrite."
            )
        # Create path in case it does not exist
        fpath = os.path.dirname(result_filename)
        if fpath:
            os.makedirs(fpath, exist_ok=True)
        return result_filename

    @field_validator("t_stop")
    @classmethod
    def check_t_stop(cls, t_stop, info: FieldValidationInfo):
        """Check that the stop time is greater than the start time.

        Raises:
            ValueError: If t_stop <= t_start.
        """
        t_start = info.data.get("t_start")
        if t_start is None:
            # t_start failed its own validation, which is reported already
            return t_stop
        # Explicit raise instead of assert, which is stripped under `python -O`
        if t_stop <= t_start:
            raise ValueError("t_stop must be greater than t_start")
        return t_stop

    @field_validator("t_sample_communication", "t_sample_simulation")
    @classmethod
    def check_t_sample(cls, t_sample, info: FieldValidationInfo):
        """Apply the deprecated t_sample and check the sample time fits into the horizon.

        If the deprecated ``t_sample`` differs from its default of 1, it is used
        as ``t_sample_communication`` and ``t_sample_simulation`` is set to 1.

        Raises:
            ValueError: If t_start + t_sample exceeds t_stop.
        """
        t_start = info.data.get("t_start")
        t_stop = info.data.get("t_stop")
        t_sample_old = info.data.get("t_sample")
        # A change in the default shows t_sample is still in the config of the user
        if t_sample_old is not None and t_sample_old != 1:
            if info.field_name == "t_sample_simulation":
                t_sample = 1
            else:
                t_sample = t_sample_old
        # Skip the range check if t_start / t_stop failed their own validation
        if t_start is not None and t_stop is not None and t_start + t_sample > t_stop:
            raise ValueError("t_stop-t_start must be greater than t_sample")
        return t_sample

    @field_validator("t_sample_simulation")
    @classmethod
    def check_t_comm_against_sim(cls, t_sample_simulation, info: FieldValidationInfo):
        """Warn if the communication step is smaller than the simulation step.

        Registered on t_sample_simulation: it is declared after
        t_sample_communication, so only here is the latter available in
        ``info.data``. A validator on t_sample_communication cannot see
        t_sample_simulation.
        """
        t_sample_communication = info.data.get("t_sample_communication")
        if t_sample_communication is not None and t_sample_simulation > t_sample_communication:
            warnings.warn(
                f"{t_sample_communication=} is smaller than {t_sample_simulation=}",
                category=UserWarning
            )
        return t_sample_simulation

    @field_validator("update_inputs_on_callback")
    @classmethod
    def deprecate_update_inputs_on_callback(cls, update_inputs_on_callback, info: FieldValidationInfo):
        """Warn that update_inputs_on_callback is deprecated and always use True."""
        warnings.warn(
            "update_inputs_on_callback is deprecated, remove it from your config. "
            "Will use update_inputs_on_callback=True",
            category=DeprecationWarning
        )
        return True

    @field_validator("t_sample")
    @classmethod
    def deprecate_t_sample(cls, t_sample, info: FieldValidationInfo):
        """Deprecates the t_sample field in favor of t_sample_communication and t_sample_simulation."""
        warnings.warn(
            "t_sample is deprecated, use t_sample_communication, "
            "t_sample_simulation for a concise separation of the two. "
            "Will use the given t_sample for t_sample_communication and t_sample_simulation=1 s, "
            "the `model.dt` default.",
        )
        return t_sample

    @field_validator("write_results_delay")
    @classmethod
    def set_default_t_sample(cls, write_results_delay, info: FieldValidationInfo):
        """Default and check the interval in which results are written to disk.

        The interval is related to t_sample_communication, which already
        includes a deprecated t_sample (see check_t_sample). Falls back to
        t_sample if t_sample_communication failed validation.

        Raises:
            ValueError: If write_results_delay is smaller than the communication step.
        """
        t_sample = info.data.get("t_sample_communication")
        if t_sample is None:
            t_sample = info.data.get("t_sample")
        if t_sample is None:
            # Sample times failed their own validation, nothing to compare with
            return write_results_delay
        if write_results_delay is None:
            # 5 is an arbitrary default which should balance writing new results as
            # soon as possible to disk with saving file I/O overhead
            return 5 * t_sample
        if write_results_delay < t_sample:
            raise ValueError(
                "Saving results more frequently than you simulate makes no sense. "
                "Increase write_results_delay above t_sample_communication."
            )
        return write_results_delay

    @field_validator("model")
    @classmethod
    def check_model(cls, model, info: FieldValidationInfo):
        """Instantiate the simulation model from its config dict.

        The dict must contain "type": either a dict for custom injection or the
        name of a registered model type. If "dt" is missing it is set to
        t_sample_simulation.

        Raises:
            KeyError: If "type" is missing.
            OptionalDependencyError: If the model type needs an uninstalled dependency.
            ValueError: If no Model instance could be created.
        """
        # Work on a shallow copy: "dt" is set and "type" popped below, which must
        # not leak into a config dict the caller may reuse.
        model = dict(model)
        parameters = info.data.get("parameters")
        inputs = info.data.get("inputs")
        outputs = info.data.get("outputs")
        states = info.data.get("states")
        dt = info.data.get("t_sample_simulation")
        if "dt" in model and dt != model["dt"]:
            warnings.warn(
                f"Given model {model['dt']=} differs from {dt=} of simulator. "
                f"Using models dt, consider switching to t_sample_simulation."
            )
        else:
            model["dt"] = dt
        if "type" not in model:
            raise KeyError(
                "Given model config does not " "contain key 'type' (type of the model)."
            )
        _type = model.pop("type")
        if isinstance(_type, dict):
            # NOTE(review): unlike the str branch, the simulator's agent variables
            # are not forwarded to custom models — confirm this is intended.
            custom_cls = custom_injection(config=_type)
            model = custom_cls(**model)
        elif isinstance(_type, str):
            if _type in UNINSTALLED_MODEL_TYPES:
                raise OptionalDependencyError(
                    dependency_name=_type,
                    dependency_install=UNINSTALLED_MODEL_TYPES[_type],
                    used_object=f"model {_type}",
                )
            model = get_model_type(_type)(
                **model,
                parameters=convert_agent_vars_to_list_of_dicts(parameters),
                inputs=convert_agent_vars_to_list_of_dicts(inputs),
                outputs=convert_agent_vars_to_list_of_dicts(outputs),
                states=convert_agent_vars_to_list_of_dicts(states),
            )
        # Check if model was correctly initialized. ValueError (like the former
        # assert) is reported by pydantic as a ValidationError, but survives -O.
        if not isinstance(model, Model):
            raise ValueError(
                f"Could not create a Model from model type {_type!r}. The type must be "
                f"a dict (custom injection) or the name of a registered model type."
            )
        return model
class Simulator(BaseModule):
    """
    The Simulator is the interface between simulation models
    and further other implementations. It contains all interface functions for
    interacting with the standard model class.
    """

    config: SimulatorConfig

    def __init__(self, *, config: dict, agent: Agent):
        super().__init__(config=config, agent=agent)
        # Initialize instance attributes
        self._model = None
        # Assigning through the property validates and initializes the model
        self.model = self.config.model
        self._result: SimulatorResults = SimulatorResults(
            variables=self._get_result_model_variables()
        )
        self._save_count: int = 1  # tracks, how often results have been saved
        # Time stamp from which on the buffered results are next written to
        # config.result_filename; anchored at the first stored time stamp.
        self._next_results_write_time: Optional[float] = None
        self._inputs_changed_since_last_results_saving = False
        self._register_input_callbacks()
        self.logger.info("%s initialized!", self.__class__.__name__)

    def terminate(self):
        """Terminate the model, then the module itself."""
        self.model.terminate()
        super().terminate()

    @property
    def model(self) -> Model:
        """
        Getter for current simulation model

        Returns:
            agentlib.core.model.Model: Current simulation model
        """
        return self._model

    @model.setter
    def model(self, model: Model):
        """
        Setter for current simulation model.
        Also initializes it with t_start (shifted by the environment offset)
        and t_stop from the config.

        Args:
            model (agentlib.core.model.Model): model to set as current simulation model

        Raises:
            TypeError: If model is not an instance of Model.
        """
        if not isinstance(model, Model):
            self.logger.error(
                "You forgot to pass a valid model to the simulator module!"
            )
            raise TypeError(
                f"Given model is of type {type(model)} "
                f"but should be an instance of Model or a valid subclass"
            )
        self._model = model
        # Check the same offset that is applied to t_start below
        if self.config.t_start and self.env.config.offset:
            self.logger.warning(
                "config.t_start and env.offset are both non-zero. "
                "This may cause unexpected behavior. Ensure that this "
                "is intended and you know what you are doing."
            )
        self.model.initialize(
            t_start=self.config.t_start + self.env.config.offset,
            t_stop=self.config.t_stop,
        )
        self.logger.info("Model successfully loaded model: %s", self.model.name)

    def run(self, until=None):
        """
        Runs the simulator in stand-alone mode if needed.
        Without ``until``, the environment runs for t_stop - t_start.
        Attention: If the environment is connected to another environment
        all scheduled process will be started in this environment.
        """
        if until is None:
            self.env.run(until=self.config.t_stop - self.config.t_start)
        else:
            self.env.run(until=until)

    def register_callbacks(self):
        """No-op; input callbacks are registered in __init__ instead."""

    def _register_input_callbacks(self):
        """Register data broker callbacks for model inputs and parameters.

        Agent variables with a non-None value also overwrite the model's value.
        """
        # Possible inputs are Inputs and parameters.
        # Outputs and states are always the result of the model
        # "Complicated" double for-loop to avoid boilerplate code
        for _type, model_var_names, ag_vars, callback in zip(
            ["input", "parameter"],
            [self.model.get_input_names(), self.model.get_parameter_names()],
            [self.config.inputs, self.config.parameters],
            [self._callback_update_model_input, self._callback_update_model_parameter],
        ):
            for var in ag_vars:
                if var.name in model_var_names:
                    self.logger.info(
                        "Registered callback for model %s %s ", _type, var.name
                    )
                    self.agent.data_broker.register_callback(
                        alias=var.alias,
                        source=var.source,
                        callback=callback,
                        name=var.name,
                    )
                    # Case for variable overwriting
                    if var.value is not None:
                        self.logger.debug(
                            "Updating model %s %s=%s", _type, var.name, var.value
                        )
                        self.model.set(name=var.name, value=var.value)
                        self._inputs_changed_since_last_results_saving = True

    def _callback_update_model_input(self, inp: AgentVariable, name: str):
        """Set given model input value to the model"""
        self.logger.debug("Updating model input %s=%s", name, inp.value)
        self.model.set_input_value(name=name, value=inp.value)
        self._inputs_changed_since_last_results_saving = True

    def _callback_update_model_parameter(self, par: AgentVariable, name: str):
        """Set given model parameter value to the model"""
        self.logger.debug("Updating model parameter %s=%s", name, par.value)
        self.model.set_parameter_value(name=name, value=par.value)
        self._inputs_changed_since_last_results_saving = True

    def process(self):
        """
        This function creates an endless loop for the single simulation step event,
        updating inputs, simulating, model results and then outputs.

        In a simulation step following happens:
        1. Specify the end time of the simulation from the agents perspective.
        **Important note**: The agents use unix-time as a timestamp and start
        the simulation with the current datetime (represented by self.env.time),
        the model starts at 0 seconds (represented by self.env.now).
        2. Directly after the simulation we store the results with
        the output time and then call the timeout in the environment,
        hence actually increase the environment time.
        3. Once the environment time reached the simulation time,
        we send the updated output values to other modules and agents by setting
        them in the data_broker.

        Raises:
            ValueError: If no simulation step fits into t_sample_communication.
        """
        self._update_results(timestamp_inputs=self.env.time, timestamp_outputs=self.env.time)
        while True:
            # Simulate
            t_samples = create_time_samples(
                t_end=self.config.t_sample_communication,
                dt=self.config.t_sample_simulation
            )
            if len(t_samples) < 2:
                # Without a single step the loop below never yields, and this
                # process would spin forever without advancing the environment.
                raise ValueError(
                    f"No simulation step fits into "
                    f"t_sample_communication={self.config.t_sample_communication} "
                    f"with t_sample_simulation={self.config.t_sample_simulation}."
                )
            _t_start_simulation_loop = self.env.time
            self.logger.debug("Doing simulation steps %s", t_samples)
            for _idx, _t_sample in enumerate(t_samples[:-1]):
                # NOTE(review): the model is initialized at t_start + env offset,
                # but steps start at env.now + t_start (no offset) — confirm.
                _t_start = self.env.now + self.config.t_start
                dt_sim = t_samples[_idx + 1] - _t_sample
                self.model.do_step(t_start=_t_start, t_sample=dt_sim)
                if _idx == len(t_samples) - 2 or self._inputs_changed_since_last_results_saving:
                    if not self._inputs_changed_since_last_results_saving:
                        # Did not change during simulation step
                        timestamp_inputs = _t_start_simulation_loop
                    else:
                        # The inputs are only applied at self.env.time, not when they are received by the communicator
                        timestamp_inputs = self.env.time
                    # Update the results
                    self._update_results(
                        timestamp_outputs=self.env.time + dt_sim,
                        timestamp_inputs=timestamp_inputs
                    )
                yield self.env.timeout(dt_sim)
            # Communicate
            self.update_module_vars()

    def update_module_vars(self):
        """
        Method to write current model output and states
        values to the module outputs and states.

        Raises:
            KeyError: If a configured state/output does not exist in the model.
        """
        for _type, model_get, agent_vars in zip(
            ["state", "output"],
            [self.model.get_state, self.model.get_output],
            [self.config.states, self.config.outputs],
        ):
            for var in agent_vars:
                mo_var = model_get(var.name)
                if mo_var is None:
                    raise KeyError(f"Given variable {var.name} not found in model.")
                value = self._get_uncertain_value(model_variable=mo_var)
                self.logger.debug("Updating %s %s=%s", _type, var.name, value)
                self.set(name=var.name, value=value)

    def _get_uncertain_value(self, model_variable: ModelVariable) -> float:
        """Get the value with added uncertainty based on the value of the variable.

        The value is scaled by (1 + e), with e drawn uniformly from [-bias, bias].
        bias is taken from config.measurement_uncertainty (per variable name for
        a dict, for all variables for a float). Without uncertainty, the value is
        returned unchanged and numpy's global random state is not advanced.
        """
        if isinstance(self.config.measurement_uncertainty, dict):
            bias = self.config.measurement_uncertainty.get(model_variable.name, 0)
        else:
            bias = self.config.measurement_uncertainty
        if not bias:
            return model_variable.value
        return model_variable.value * (1 + np.random.uniform(-bias, bias))

    def get_results(self) -> Optional[pd.DataFrame]:
        """
        Return the current results.

        If a result file is configured, the buffered results are written to it
        first and the complete file is read back.

        Returns:
            pd.DataFrame: The results DataFrame with variable names as columns,
            or None if save_results is False.
        """
        if not self.config.save_results:
            return None
        file = self.config.result_filename
        if file:
            self._result.write_results(file)
            df = read_simulator_results(file)
        else:
            df = self._result.df()
        # Keep only the "name" level of the (causality, name, type) columns
        df = df.droplevel(level=2, axis=1).droplevel(level=0, axis=1)
        return df

    def cleanup_results(self):
        """Delete the result file, if results are stored in one.

        A file that was never written is ignored.
        """
        if not self.config.save_results or not self.config.result_filename:
            return
        Path(self.config.result_filename).unlink(missing_ok=True)

    def _update_results(self, timestamp_outputs, timestamp_inputs):
        """
        Adds model variables to the SimulationResult object
        at the given timestamp.

        Raises:
            ValueError: If timestamp_inputs is later than timestamp_outputs.
        """
        if not self.config.save_results:
            return

        inp_values = [var.value for var in self._get_result_input_variables()]
        self._inputs_changed_since_last_results_saving = False

        out_values = [var.value for var in self._get_result_output_variables()]
        _len_outputs = len(out_values)

        # Two cases:
        # - Either both timestamps are the same and new, add them both
        # - or the inputs are for an earlier timestamp -> first add the inputs and then append the new outputs index.
        if timestamp_inputs == timestamp_outputs:
            self._result.index.append(timestamp_outputs)
            self._result.data.append(out_values + inp_values)
        elif timestamp_inputs < timestamp_outputs:
            # add inputs in the time stamp before adding outputs, as they are active from
            # the start of this interval
            if timestamp_inputs in self._result.index:
                if timestamp_inputs == self._result.index[-1]:
                    self._result.data[-1] = self._result.data[-1][:_len_outputs] + inp_values
                # else: pass, as inputs are outdated (have been changed during simulation step)
            else:
                # This case may occur if inputs changed during simulation.
                # In this case, the inputs hold for current time - t_sample_simulation, but the outputs
                # hold for the current time. In this case, just add Nones as outputs.
                self._result.index.append(timestamp_inputs)
                self._result.data.append([None] * _len_outputs + inp_values)
            self._result.index.append(timestamp_outputs)
            self._result.data.append(out_values + [None] * len(inp_values))
        else:
            raise ValueError("Storing inputs ahead of outputs is not supported.")

        if self.config.result_filename is not None:
            self._write_results_if_due(timestamp_outputs)

    def _write_results_if_due(self, timestamp: float):
        """Write buffered results to the result file every write_results_delay.

        The schedule is anchored at the first time stamp seen, so it also works
        when the environment time is an absolute (e.g. unix) time stamp.
        """
        delay = self.config.write_results_delay
        if self._next_results_write_time is None:
            self._next_results_write_time = timestamp + delay
            return
        if timestamp < self._next_results_write_time:
            return
        self._save_count += 1
        self._result.write_results(self.config.result_filename)
        # Skip every threshold already passed, so a single write covers them all
        missed_intervals = (timestamp - self._next_results_write_time) // delay + 1
        self._next_results_write_time += missed_intervals * delay

    def _get_result_model_variables(self) -> List[ModelVariable]:
        """
        Gets all variables to be saved in the result based
        on self.result_causalities.
        """

        # THE ORDER OF THIS CONCAT IS IMPORTANT. The _update_results function will
        # extend the outputs with the inputs
        return self._get_result_output_variables() + self._get_result_input_variables()

    def _get_result_input_variables(self) -> List[ModelVariable]:
        """Gets all input variables to be saved in the results based on
        self.result_causalities. Input variables are added to the results at the time
        index before an interval, i.e. parameters and inputs."""
        _variables = []
        for causality in self.config.result_causalities:
            if causality == Causality.input:
                _variables.extend(self.model.inputs)
            elif causality in [Causality.parameter, Causality.calculatedParameter]:
                _variables.extend(self.model.parameters)
        return _variables

    def _get_result_output_variables(self) -> List[ModelVariable]:
        """Gets all output variables to be saved in the results based on
        self.result_causalities. Output variables are added to the results at the time
        index after an interval, i.e. locals (states) and outputs."""
        _variables = []
        for causality in self.config.result_causalities:
            if causality == Causality.output:
                _variables.extend(self.model.outputs)
            elif causality == Causality.local:
                _variables.extend(self.model.states)
        return _variables
def convert_agent_vars_to_list_of_dicts(var: "AgentVariables") -> List[Dict]:
    """
    Convert AgentVariables to a list of dictionaries containing information for
    ModelVariables.

    The agent-specific fields "source", "alias", "shared" and "rdf_class" are
    excluded. Uses pydantic v2 ``model_dump`` instead of the deprecated ``dict``.

    Args:
        var: The agent variables to convert.

    Returns:
        One dictionary per agent variable, in the given order.
    """
    agent_only_fields = {"source", "alias", "shared", "rdf_class"}
    return [agent_var.model_dump(exclude=agent_only_fields) for agent_var in var]