Coverage for agentlib/modules/simulation/simulator.py: 87%
233 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-04-07 16:27 +0000
1"""
2Module contains the Simulator, used to simulate any model.
3"""
5import os
6from dataclasses import dataclass
7from math import inf
8from pathlib import Path
9from typing import Union, Dict, List, Optional
11import numpy as np
12import pandas as pd
13from pydantic import field_validator, Field
14from pydantic_core.core_schema import FieldValidationInfo
16from agentlib.core import (
17 BaseModule,
18 BaseModuleConfig,
19 Agent,
20 Causality,
21 AgentVariable,
22 AgentVariables,
23 ModelVariable,
24 Model,
25)
26from agentlib.core.errors import OptionalDependencyError
27from agentlib.models import get_model_type, UNINSTALLED_MODEL_TYPES
28from agentlib.utils import custom_injection
@dataclass
class SimulatorResults:
    """In-memory storage for simulator results.

    The results form a table with one row per time stamp and a
    three-level column index (causality, name, type)::

        +---------------------------------------------------+
        |   | inputs u  | outputs y | states x |
        | t | u1  | u2  | y1  | y2  | x1  | x2 |
        +---------------------------------------------------+
        | 1 | ... | ... | ... | ... | ... | ...|
        | 2 | ... | ... | ... | ... | ... | ...|
        |...| ... | ... | ... | ... | ... | ...|
    """

    index: List[float]
    columns: pd.MultiIndex
    data: List[List[float]]

    def __init__(self, variables: List[ModelVariable]):
        """Create an empty result table.

        The columns are derived from the given model variables: one
        column per variable, keyed by its causality, name and type.
        """
        causalities = [variable.causality.name for variable in variables]
        names = [variable.name for variable in variables]
        types = [variable.type for variable in variables]
        self.columns = pd.MultiIndex.from_arrays(
            arrays=np.array([causalities, names, types]),
            sortorder=0,
            names=["causality", "name", "type"],
        )
        self.index = []
        self.data = []

    def initialize(
        self,
        time: float,
    ):
        """Adds the first row to the data"""
        # NOTE(review): intentionally left as a no-op, matching the original
        # implementation; the first row appears to be appended elsewhere.

    def df(self) -> pd.DataFrame:
        """Return the results collected so far as a DataFrame.

        The last row is excluded on purpose: it is always only half
        complete, since inputs at time step k influence the results of
        time step k+1. Writing the incomplete row would break the
        csv file we append to.
        """
        return pd.DataFrame(
            data=self.data[:-1], index=self.index[:-1], columns=self.columns
        )

    def write_results(self, file: str):
        """Append the rows currently held in memory to ``file``.

        The column header is only written when the file is first
        created. Afterwards the in-memory buffer is cleared, except for
        the still unfinished last row (its inputs are missing).
        """
        first_write = not Path(file).exists()
        self.df().to_csv(file, mode="a", header=first_write)
        # keep the last row of the results, as it is not finished (inputs missing)
        unfinished_time = self.index[-1]
        unfinished_row = self.data[-1]
        self.index = [unfinished_time]
        self.data = [unfinished_row]
def read_simulator_results(file: str):
    """Load a results csv written by the simulator.

    The file carries a three-level column header (causality, name,
    type) and the time stamps in the first column, which becomes the
    index of the returned DataFrame.
    """
    frame = pd.read_csv(file, index_col=0, header=[0, 1, 2])
    return frame
class SimulatorConfig(BaseModuleConfig):
    """
    Pydantic data model for the simulator configuration.

    Holds the model variables (parameters, inputs, outputs, states),
    the simulation time grid (t_start, t_stop, t_sample) and all
    options controlling if and how results are stored.
    """

    parameters: AgentVariables = []
    inputs: AgentVariables = []
    outputs: AgentVariables = []
    states: AgentVariables = []
    shared_variable_fields: List[str] = ["outputs"]
    model: Dict

    t_start: Union[float, int] = Field(
        title="t_start", default=0.0, ge=0, description="Simulation start time"
    )
    t_stop: Union[float, int] = Field(
        title="t_stop", default=inf, ge=0, description="Simulation stop time"
    )
    t_sample: Union[float, int] = Field(
        title="t_sample", default=1, ge=0, description="Simulation sample time"
    )
    # Model results
    save_results: bool = Field(
        title="save_results",
        default=False,
        description="If True, results are created and stored",
    )
    overwrite_result_file: bool = Field(
        title="overwrite_result",
        default=False,
        description="If True, and the result file already exists, the file is overwritten.",
    )
    result_filename: Optional[str] = Field(
        title="result_filename",
        default=None,
        description="If not None, results are stored in that filename."
        "Needs to be a .csv file",
    )
    result_sep: str = Field(
        title="result_sep",
        default=",",
        description="Separator in the .csv file. Only relevant if "
        "result_filename is passed",
    )
    result_causalities: List[Causality] = Field(
        title="result_causalities",
        default=[Causality.input, Causality.output],
        description="List of causalities to store. Default stores "
        "only inputs and outputs",
    )
    write_results_delay: Optional[float] = Field(
        title="Write Results Delay",
        default=None,
        description="Sampling interval for which the results are written to disc in seconds.",
        validate_default=True,
        gt=0,
    )
    update_inputs_on_callback: bool = Field(
        title="update_inputs_on_callback",
        default=True,
        description="If True, model inputs are updated if they are updated in data_broker."
        "Else, the model inputs are updated before each simulation.",
    )
    measurement_uncertainty: Union[Dict[str, float], float] = Field(
        title="measurement_uncertainty",
        default=0,
        description="Either pass a float and add the percentage uncertainty "
        "to all measurements from the model."
        "Or pass a Dict and specify the model variable name as key"
        "and the associated uncertainty as a float",
    )
    validate_incoming_values: Optional[bool] = Field(
        default=False,  # we overwrite the default True in base, to be more efficient
        title="Validate Incoming Values",
        description="If true, the validator of the AgentVariable value is called when "
        "receiving a new value from the DataBroker. In the simulator, this "
        "is False by default, as we expect to receive a lot of measurements"
        " and want to be efficient.",
    )

    @field_validator("result_filename")
    @classmethod
    def check_nonexisting_csv(cls, result_filename, info: FieldValidationInfo):
        """Check that the result_filename is a .csv file and that it
        either does not exist yet or may be overwritten.

        Returns None when results are not saved at all, as the filename
        will never be used in that case.
        """
        if not info.data.get("save_results", False):
            # No need to check as filename will never be used anyways
            return None
        if result_filename is None:
            return result_filename
        if not result_filename.endswith(".csv"):
            raise TypeError(
                f"Given result_filename ends with "
                f'{result_filename.split(".")[-1]} '
                f"but should be a .csv file"
            )
        if os.path.isfile(result_filename):
            # remove result file, so a new one can be created
            if info.data["overwrite_result_file"]:
                os.remove(result_filename)
                return result_filename
            raise FileExistsError(
                f"Given result_filename at {result_filename} "
                f"already exists. We won't overwrite it automatically. "
                f"You can use the key word 'overwrite_result_file' to "
                f"activate automatic overwrite."
            )
        # Create path in case it does not exist
        fpath = os.path.dirname(result_filename)
        if fpath:
            os.makedirs(fpath, exist_ok=True)
        return result_filename

    @field_validator("t_stop")
    @classmethod
    def check_t_stop(cls, t_stop, info: FieldValidationInfo):
        """Check if stop is greater than start time"""
        # Default to 0 so that a failed t_start validation does not crash
        # this validator; pydantic then reports the original t_start error.
        t_start = info.data.get("t_start", 0)
        # Explicit raise instead of assert: asserts are stripped when Python
        # runs with -O and must not be used for config validation.
        if t_stop <= t_start:
            raise ValueError("t_stop must be greater than t_start")
        return t_stop

    @field_validator("t_sample")
    @classmethod
    def check_t_sample(cls, t_sample, info: FieldValidationInfo):
        """Check if t_sample is smaller than stop-start time"""
        t_start = info.data.get("t_start", 0)
        t_stop = info.data.get("t_stop", inf)
        if t_start + t_sample > t_stop:
            raise ValueError("t_stop-t_start must be greater than t_sample")
        return t_sample

    @field_validator("write_results_delay")
    @classmethod
    def set_default_t_sample(cls, write_results_delay, info: FieldValidationInfo):
        """Set a sensible default for the result-writing interval and
        reject intervals shorter than the simulation sample time."""
        t_sample = info.data["t_sample"]
        if write_results_delay is None:
            # 5 is an arbitrary default which should balance writing new results as
            # soon as possible to disk with saving file I/O overhead
            return 5 * t_sample
        if write_results_delay < t_sample:
            raise ValueError(
                "Saving results more frequently than you simulate makes no sense. "
                "Increase write_results_delay above t_sample."
            )
        return write_results_delay

    @field_validator("model")
    @classmethod
    def check_model(cls, model, info: FieldValidationInfo):
        """Validate the model input and instantiate the model.

        The 'type' key selects either a custom class (dict) or a
        registered model type (str); the remaining keys plus the
        configured variables are passed to the model constructor.
        """
        parameters = info.data.get("parameters")
        inputs = info.data.get("inputs")
        outputs = info.data.get("outputs")
        states = info.data.get("states")
        if "type" not in model:
            raise KeyError(
                "Given model config does not " "contain key 'type' (type of the model)."
            )
        _type = model.pop("type")
        if isinstance(_type, dict):
            custom_cls = custom_injection(config=_type)
            model = custom_cls(**model)
        elif isinstance(_type, str):
            if _type in UNINSTALLED_MODEL_TYPES:
                raise OptionalDependencyError(
                    dependency_name=_type,
                    dependency_install=UNINSTALLED_MODEL_TYPES[_type],
                    used_object=f"model {_type}",
                )
            model = get_model_type(_type)(
                **model,
                parameters=convert_agent_vars_to_list_of_dicts(parameters),
                inputs=convert_agent_vars_to_list_of_dicts(inputs),
                outputs=convert_agent_vars_to_list_of_dicts(outputs),
                states=convert_agent_vars_to_list_of_dicts(states),
            )
        # Check if model was correctly initialized (explicit raise instead of
        # assert so the check survives python -O).
        if not isinstance(model, Model):
            raise ValueError(
                f"Given model resolved to type {type(model)} "
                f"but should be an instance of Model or a valid subclass"
            )
        return model
class Simulator(BaseModule):
    """
    The Simulator is the interface between simulation models
    and further other implementations. It contains all interface functions for
    interacting with the standard model class.
    """

    # Validated configuration (SimulatorConfig); populated by BaseModule.
    config: SimulatorConfig

    def __init__(self, *, config: dict, agent: Agent):
        """Initialize the simulator module.

        Loads and initializes the model, creates the in-memory result
        storage and, if configured, registers data-broker callbacks for
        model inputs and parameters.

        Args:
            config: Raw module configuration, validated into SimulatorConfig.
            agent: The agent this module belongs to.
        """
        super().__init__(config=config, agent=agent)
        # Initialize instance attributes
        self._model = None
        # Assignment runs through the `model` property setter below, which
        # also calls model.initialize().
        self.model = self.config.model
        self._result: SimulatorResults = SimulatorResults(
            variables=self._get_result_model_variables()
        )
        self._save_count: int = 1  # tracks, how often results have been saved
        if self.config.update_inputs_on_callback:
            self._register_input_callbacks()
        self.logger.info("%s initialized!", self.__class__.__name__)

    def terminate(self):
        """Terminate the model"""
        self.model.terminate()
        super().terminate()

    @property
    def model(self) -> Model:
        """
        Getter for current simulation model

        Returns:
            agentlib.core.model.Model: Current simulation model
        """
        return self._model

    @model.setter
    def model(self, model: Model):
        """
        Setter for current simulation model.
        Also initializes it if needed!

        Args:
            model (agentlib.core.model.Model): model to set as current simulation model

        Raises:
            TypeError: If the given object is not an instance of Model.
        """
        if not isinstance(model, Model):
            self.logger.error(
                "You forgot to pass a valid model to the simulator module!"
            )
            raise TypeError(
                f"Given model is of type {type(model)} "
                f"but should be an instance of Model or a valid subclass"
            )
        self._model = model
        if self.config.t_start and self.env.offset:
            self.logger.warning(
                "config.t_start and env.offset are both non-zero. "
                "This may cause unexpected behavior. Ensure that this "
                "is intended and you know what you are doing."
            )
        # Model time starts at t_start shifted by the environment offset.
        self.model.initialize(
            t_start=self.config.t_start + self.env.config.offset,
            t_stop=self.config.t_stop,
        )
        self.logger.info("Model successfully loaded model: %s", self.model.name)

    def run(self, until=None):
        """
        Runs the simulator in stand-alone mode if needed
        Attention: If the environment is connected to another environment
        all scheduled process will be started in this environment.

        Args:
            until: Optional simulation end time passed to the environment.
                Defaults to t_stop - t_start from the config.
        """
        if until is None:
            self.env.run(until=self.config.t_stop - self.config.t_start)
        else:
            self.env.run(until=until)

    def register_callbacks(self):
        """No-op: input callbacks are registered in __init__ (depending on
        config.update_inputs_on_callback) via _register_input_callbacks."""
        pass

    def _register_input_callbacks(self):
        """Register input callbacks"""
        # Possible inputs are Inputs and parameters.
        # Outputs and states are always the result of the model
        # "Complicated" double for-loop to avoid boilerplate code
        for _type, model_var_names, ag_vars, callback in zip(
            ["input", "parameter"],
            [self.model.get_input_names(), self.model.get_parameter_names()],
            [self.config.inputs, self.config.parameters],
            [self._callback_update_model_input, self._callback_update_model_parameter],
        ):
            for var in ag_vars:
                if var.name in model_var_names:
                    self.logger.info(
                        "Registered callback for model %s %s ", _type, var.name
                    )
                    self.agent.data_broker.register_callback(
                        alias=var.alias,
                        source=var.source,
                        callback=callback,
                        name=var.name,
                    )
                    # Case for variable overwriting
                    if var.value is not None:
                        self.logger.debug(
                            "Updating model %s %s=%s", _type, var.name, var.value
                        )
                        self.model.set(name=var.name, value=var.value)

    def _callback_update_model_input(self, inp: AgentVariable, name: str):
        """Set given model input value to the model"""
        self.logger.debug("Updating model input %s=%s", name, inp.value)
        self.model.set_input_value(name=name, value=inp.value)

    def _callback_update_model_parameter(self, par: AgentVariable, name: str):
        """Set given model parameter value to the model"""
        self.logger.debug("Updating model parameter %s=%s", name, par.value)
        self.model.set_parameter_value(name=name, value=par.value)

    def process(self):
        """
        This function creates a endless loop for the single simulation step event.
        The do_step() function needs to return a generator.
        """
        # Record the initial outputs/states as the first result row before
        # the first simulation step.
        self._update_result_outputs(self.env.time)
        while True:
            self.do_step()
            yield self.env.timeout(self.config.t_sample)
            # Publish the updated outputs/states after time has advanced.
            self.update_module_vars()

    def do_step(self):
        """
        Generator function to perform a simulation step,
        update inputs, outputs and model results.

        In a simulation step following happens:
        1. Update inputs (only necessary if self.update_inputs_on_callback = False)
        2. Specify the end time of the simulation from the agents perspective.
        **Important note**: The agents use unix-time as a timestamp and start
        the simulation with the current datetime (represented by self.env.time),
        the model starts at 0 seconds (represented by self.env.now).
        3. Directly after the simulation we send the updated output values
        to other modules and agents by setting them the data_broker.
        Even though the environment time is not already at the end time specified above,
        we explicitly add the timestamp to the variables.
        This way other agents and communication has the maximum time possible to
        process the outputs and send input signals to the simulation.
        4. Call the timeout in the environment,
        hence actually increase the environment time.
        """
        if not self.config.update_inputs_on_callback:
            # Update inputs manually
            self.update_model_inputs()
        # Simulate
        self.model.do_step(
            t_start=(self.env.now + self.config.t_start), t_sample=self.config.t_sample
        )
        # Update the results and outputs
        self._update_results()

    def update_model_inputs(self):
        """
        Internal method to write current data_broker to simulation model.
        Only update values, not other module_types.
        """
        model_input_names = (
            self.model.get_input_names() + self.model.get_parameter_names()
        )
        for inp in self.variables:
            if inp.name in model_input_names:
                self.logger.debug("Updating model variable %s=%s", inp.name, inp.value)
                self.model.set(name=inp.name, value=inp.value)

    def update_module_vars(self):
        """
        Method to write current model output and states
        values to the module outputs and states.
        """
        # pylint: disable=logging-fstring-interpolation
        for _type, model_get, agent_vars in zip(
            ["state", "output"],
            [self.model.get_state, self.model.get_output],
            [self.config.states, self.config.outputs],
        ):
            for var in agent_vars:
                mo_var = model_get(var.name)
                if mo_var is None:
                    raise KeyError(f"Given variable {var.name} not found in model.")
                # Optionally add measurement noise before publishing the value.
                value = self._get_uncertain_value(model_variable=mo_var)
                self.logger.debug("Updating %s %s=%s", _type, var.name, value)
                self.set(name=var.name, value=value)

    def _get_uncertain_value(self, model_variable: ModelVariable) -> float:
        """Get the value with added uncertainty based on the value of the variable

        The configured uncertainty is relative: the value is scaled by a
        factor drawn uniformly from [1 - bias, 1 + bias]. A dict config
        gives a per-variable bias (default 0 for unlisted variables).
        """
        if isinstance(self.config.measurement_uncertainty, dict):
            bias = self.config.measurement_uncertainty.get(model_variable.name, 0)
        else:
            bias = self.config.measurement_uncertainty
        return model_variable.value * (1 + np.random.uniform(-bias, bias))

    def get_results(self) -> Optional[pd.DataFrame]:
        """
        Return the current results.

        Returns:
            pd.DataFrame: The results DataFrame, or None when
            config.save_results is False.
        """
        if not self.config.save_results:
            return
        file = self.config.result_filename
        if file:
            # Flush pending rows to disk and read the full history back.
            self._result.write_results(self.config.result_filename)
            df = read_simulator_results(file)
        else:
            df = self._result.df()
        # Drop the "type" (level 2) and "causality" (level 0) column levels,
        # leaving plain variable names as columns.
        df = df.droplevel(level=2, axis=1).droplevel(level=0, axis=1)
        return df

    def _update_results(self):
        """
        Adds model variables to the SimulationResult object
        at the given timestamp.
        """
        if not self.config.save_results:
            return
        timestamp = self.env.time + self.config.t_sample
        inp_values = [var.value for var in self._get_result_input_variables()]

        # add inputs in the time stamp before adding outputs, as they are active from
        # the start of this interval
        self._result.data[-1].extend(inp_values)
        # adding output results afterwards. If the order here is switched, the [-1]
        # above will point to the wrong entry
        self._update_result_outputs(timestamp)
        # Flush to disk roughly every write_results_delay seconds.
        # NOTE(review): with a large env time offset (unix timestamps) this
        # condition is true on every step — TODO confirm intended cadence.
        if (
            self.config.result_filename is not None
            and timestamp // (self.config.write_results_delay * self._save_count) > 0
        ):
            self._save_count += 1
            self._result.write_results(self.config.result_filename)

    def _update_result_outputs(self, timestamp: float):
        """Updates results with current values for states and outputs."""
        self._result.index.append(timestamp)
        out_values = [var.value for var in self._get_result_output_variables()]
        self._result.data.append(out_values)

    def _get_result_model_variables(self) -> AgentVariables:
        """
        Gets all variables to be saved in the result based
        on self.result_causalities.
        """
        # THE ORDER OF THIS CONCAT IS IMPORTANT. The _update_results function will
        # extend the outputs with the inputs
        return self._get_result_output_variables() + self._get_result_input_variables()

    def _get_result_input_variables(self) -> AgentVariables:
        """Gets all input variables to be saved in the results based on
        self.result_causalities. Input variables are added to the results at the time
        index before an interval, i.e. parameters and inputs."""
        _variables = []
        for causality in self.config.result_causalities:
            if causality == Causality.input:
                _variables.extend(self.model.inputs)
            elif causality in [Causality.parameter, Causality.calculatedParameter]:
                _variables.extend(self.model.parameters)
        return _variables

    def _get_result_output_variables(self) -> AgentVariables:
        """Gets all output variables to be saved in the results based on
        self.result_causalities. Input variables are added to the results at the time
        index after an interval, i.e. locals and outputs."""
        _variables = []
        for causality in self.config.result_causalities:
            if causality == Causality.output:
                _variables.extend(self.model.outputs)
            elif causality == Causality.local:
                _variables.extend(self.model.states)
        return _variables
def convert_agent_vars_to_list_of_dicts(var: AgentVariables) -> List[Dict]:
    """
    Convert AgentVariables into a list of dictionaries carrying the
    information needed to construct ModelVariables.

    Agent-specific fields (source, alias, shared, rdf_class) are
    stripped, as they have no meaning on the model side.
    """
    excluded_fields = {"source", "alias", "shared", "rdf_class"}
    return [agent_var.dict(exclude=excluded_fields) for agent_var in var]