Coverage for agentlib/core/model.py: 94%
219 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-04-07 16:27 +0000
"""This module contains just the basic Model: the pydantic ``ModelConfig``
and the abstract ``Model`` base class for simulation models."""
3import abc
4import os
5import json
6import logging
7from copy import deepcopy
8from itertools import chain
9from typing import Union, List, Dict, Any, Optional, get_type_hints, Type
10from pydantic import ConfigDict, BaseModel, Field, field_validator
11import numpy as np
12from pydantic.fields import PrivateAttr
13from pydantic_core.core_schema import FieldValidationInfo
15from agentlib.core.datamodels import (
16 ModelVariable,
17 ModelInputs,
18 ModelStates,
19 ModelOutputs,
20 ModelParameters,
21 ModelState,
22 ModelParameter,
23 ModelOutput,
24 ModelInput,
25)
# Module-level logger; Model.__setter uses it to warn when a variable would be
# overwritten with None.
logger = logging.getLogger(__name__)
class ModelConfig(BaseModel):
    """
    Pydantic data model for the configuration of a simulation ``Model``.

    Holds model metadata (name, description, simulation time, time increment)
    and the four variable groups (inputs, outputs, states, parameters).
    Variables declared as field defaults in a subclass are merged with the
    variables supplied by the user (see ``include_default_model_variables``).
    """

    user_config: dict = Field(
        default=None,
        description="The config given by the user to instantiate this class. "
        "Will be stored to enable a valid overwriting of the "
        "default config and to better restart modules. "
        "Is also useful to debug validators and the general BaseModuleConfig.",
    )
    name: Optional[str] = Field(default=None, validate_default=True)
    description: str = Field(default="You forgot to document your model!")
    sim_time: float = Field(default=0, title="Current simulation time")
    dt: Union[float, int] = Field(default=1, title="time increment")
    validate_variables: bool = Field(
        default=True,
        title="Validate Variables",
        description="If true, the validator of a variables value is called whenever a "
        "new value is set. Enabled by default; disable it for performance reasons.",
    )

    inputs: ModelInputs = Field(default=list())
    outputs: ModelOutputs = Field(default=list())
    states: ModelStates = Field(default=list())
    parameters: ModelParameters = Field(default=list())

    # Maps each variable-group field name to the variable class used to
    # (re)build its entries in include_default_model_variables.
    _types: Dict[str, type] = PrivateAttr(
        default={
            "inputs": ModelInput,
            "outputs": ModelOutput,
            "states": ModelState,
            "parameters": ModelParameter,
        }
    )
    model_config = ConfigDict(
        validate_assignment=True, arbitrary_types_allowed=True, extra="forbid"
    )

    @field_validator("name")
    @classmethod
    def check_name(cls, name):
        """
        Check if name of model is given. If not, derive it from the config
        class name with a trailing ``Config`` removed
        (e.g. ``MyModelConfig`` -> ``MyModel``).
        """
        if name is None:
            # Use the bare class name: str(cls) would give the repr
            # "<class 'package.module.MyModelConfig'>" instead.
            name = cls.__name__
            if name.endswith("Config"):
                name = name[: -len("Config")]
        return name

    @field_validator("parameters", "inputs", "outputs", "states", mode="after")
    @classmethod
    def include_default_model_variables(
        cls, _: List[ModelVariable], info: FieldValidationInfo
    ):
        """
        Validator building block to merge default variables with config variables in a standard validator.
        Updates default variables when a variable with the same name is present in the config.
        Then returns the union of the default variables and the external config variables.

        This validator ensures default variables are kept
        when the config provides new variables.

        The validated value itself is ignored; the variables are rebuilt from
        the field default and the raw entries recorded in ``user_config``.
        """
        default = cls.model_fields[info.field_name].get_default()
        var_type = cls._types.get_default()[info.field_name]
        # Copy the list: the stored user_config is a shallow copy of the
        # caller's kwargs, so mutating it would alter the caller's list and
        # the recorded user config (breaking re-validation on assignment).
        user_variables = list(info.data["user_config"].get(info.field_name, []))
        user_variables_dict = {d["name"]: d for d in user_variables}

        variables: List[ModelVariable] = deepcopy(default)
        overridden = set()
        for i, var in enumerate(variables):
            if var.name in user_variables_dict:
                var_dict = var.dict()
                var_dict.update(user_variables_dict[var.name])
                variables[i] = var_type(**var_dict)
                overridden.add(var.name)
        # User variables that did not override a default are appended as new.
        variables.extend(
            var_type(**var)
            for var in user_variables
            if var["name"] not in overridden
        )
        return variables

    def get_variable_names(self):
        """
        Returns the names of every variable as list
        (inputs, outputs, states, parameters in that order).
        """
        return [
            var.name
            for var in self.inputs + self.outputs + self.states + self.parameters
        ]

    def __init__(self, **kwargs):
        # Record the raw kwargs before validation so that the variable
        # validators can see exactly which variables the user supplied.
        kwargs["user_config"] = kwargs.copy()
        super().__init__(**kwargs)
class Model(abc.ABC):
    """
    Base class for simulation models. To implement your
    own model, inherit from this class.

    Subclasses annotate ``config`` with their own ``ModelConfig`` subclass
    (resolved by ``get_config_type``) and implement ``do_step`` and
    ``initialize``. Model variables can also be read as attributes by
    their name (see ``__getattr__``).
    """

    config: ModelConfig

    # pylint: disable=too-many-public-methods

    def __init__(self, **kwargs):
        """
        Initializes model class.

        Args:
            kwargs: Passed to the config type returned by ``get_config_type``
                to build this model's config.
        """
        self._inputs = {}
        self._outputs = {}
        self._states = {}
        self._parameters = {}

        self.config = self.get_config_type()(**kwargs)

    @classmethod
    def get_config_type(cls) -> Type[ModelConfig]:
        """Return the config class declared by the ``config`` annotation."""
        return get_type_hints(cls)["config"]

    @abc.abstractmethod
    def do_step(self, *, t_start: float, t_sample: float):
        """
        Performing one simulation step
        Args:
            t_start: start time for integration
            t_sample: increment of solver integration
        """
        raise NotImplementedError(
            "The Model class does not implement this "
            "because it is individual to the subclasses"
        )

    @abc.abstractmethod
    def initialize(self, **kwargs):
        """
        Abstract method to define what to
        do in order to initialize the model in use.
        """
        raise NotImplementedError(
            "The Model class does not implement this "
            "because it is individual to the subclasses"
        )

    def terminate(self):
        """Terminate the model if applicable by subclass."""

    def __getattr__(self, item):
        """
        Resolve unknown attributes as model variables, searching inputs,
        outputs, parameters and states in that order.

        Raises:
            AttributeError: If no variable with that name exists.
        """
        # Read the variable dicts via __dict__: on an instance whose __init__
        # has not run (e.g. during copy/pickle reconstruction), accessing
        # self._inputs would re-enter __getattr__ and recurse forever.
        instance_dict = self.__dict__
        for group in ("_inputs", "_outputs", "_parameters", "_states"):
            variables = instance_dict.get(group)
            if variables is not None and item in variables:
                return variables.get(item)
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{item}'"
        )

    def generate_variables_config(
        self, filename: Optional[str] = None, **kwargs
    ) -> str:
        """
        Generate a config file (.json) to enable an user friendly
        configuration of the model.

        Args:
            filename (str): Optional path where to store the config.
                If None, ``<ClassName>.json`` in the current working
                directory is used.
            kwargs: Kwargs directly passed to the json.dump method.
        Returns:
            filepath (str): Filepath where the json is stored
        """
        if filename is None:
            filename = os.path.join(os.getcwd(), f"{self.__class__.__name__}.json")
        model_config = {
            "inputs": [inp.dict() for inp in self.inputs],
            "outputs": [out.dict() for out in self.outputs],
            "states": [sta.dict() for sta in self.states],
            "parameters": [par.dict() for par in self.parameters],
        }
        # Explicit encoding so the output does not depend on the platform
        # default (relevant when ensure_ascii=False is passed via kwargs).
        with open(filename, "w", encoding="utf-8") as file:
            json.dump(obj=model_config, fp=file, **kwargs)
        return filename

    @property
    def config(self) -> ModelConfig:
        """Get the current config, which is
        a ModelConfig object."""
        return self._config

    @config.setter
    def config(self, config: Union[dict, ModelConfig]):
        """
        Set a new config.

        Args:
            config (dict, ModelConfig): The config dict or ModelConfig object.
        """
        # Instantiate the ModelConfig.
        if isinstance(config, self.get_config_type()):
            self._config = config
        else:
            self._config = self.get_config_type()(**config)
        # Update model variables. The dicts hold the same variable objects as
        # the config lists, so setting values here is reflected in the config.
        self._inputs = {var.name: var for var in self.config.inputs.copy()}
        self._outputs = {var.name: var for var in self.config.outputs.copy()}
        self._states = {var.name: var for var in self.config.states.copy()}
        self._parameters = {var.name: var for var in self.config.parameters.copy()}

    @property
    def description(self):
        """Get model description"""
        return self.config.description

    @description.setter
    def description(self, description: str):
        """Set model description"""
        self.config.description = description

    @description.deleter
    def description(self):
        """Delete model description. Default is then used."""
        # todo fwu do we have a use for this, or should we just get rid of deleters, and these properties alltogether?
        self.config.description = (
            self.get_config_type().model_fields["description"].default
        )

    @property
    def name(self):
        """Get model name"""
        return self.config.name

    @name.setter
    def name(self, name: str):
        """
        Set the model name
        Args:
            name (str): Name of the model
        """
        self.config.name = name

    @name.deleter
    def name(self):
        """Delete the model name"""
        self.config.name = self.get_config_type().model_fields["name"].default

    @property
    def sim_time(self):
        """Get the current simulation time"""
        return self.config.sim_time

    @sim_time.setter
    def sim_time(self, sim_time: float):
        """Set the current simulation time"""
        self.config.sim_time = sim_time

    @sim_time.deleter
    def sim_time(self):
        """Reset the current simulation time to the default value"""
        self.config.sim_time = self.get_config_type().model_fields["sim_time"].default

    @property
    def dt(self):
        """Get time increment of simulation"""
        return self.config.dt

    @property
    def variables(self):
        """Get all model variables as a list"""
        return list(
            chain.from_iterable(
                [self.inputs, self.outputs, self.parameters, self.states]
            )
        )

    @property
    def inputs(self) -> ModelInputs:
        """Get all model inputs as a list"""
        return list(self._inputs.values())

    @property
    def outputs(self) -> ModelOutputs:
        """Get all model outputs as a list"""
        return list(self._outputs.values())

    @property
    def states(self) -> ModelStates:
        """Get all model states as a list"""
        return list(self._states.values())

    @property
    def parameters(self) -> ModelParameters:
        """Get all model parameters as a list"""
        return list(self._parameters.values())

    def _create_time_samples(self, t_sample):
        """
        Function to generate an array of time samples
        using the current self.dt object.
        Note that, if self.dt is not a true divider of t_sample,
        the output array is not equally spaced (the last step is shorter).

        Args:
            t_sample (float): Sample time; the returned samples run from 0
                up to and including t_sample.

        Returns:
            numpy.ndarray: The time samples.
        """
        samples = np.arange(0, t_sample, self.dt)
        # arange excludes the stop value, but float rounding can make the last
        # sample land exactly on t_sample; only append it when missing. The
        # size check avoids an IndexError for an empty range (t_sample <= 0).
        if samples.size and samples[-1] == t_sample:
            return samples
        return np.append(samples, t_sample)

    ##########################################################################################
    # Getter and setter function using names for easier access
    ##########################################################################################
    def get_outputs(self, names: List[str]):
        """Get model outputs based on given names. Unknown names are skipped."""
        assert isinstance(names, list), "Given names are not a list"
        return [self._outputs[name] for name in names if name in self._outputs]

    def get_inputs(self, names: List[str]):
        """Get model inputs based on given names. Unknown names are skipped."""
        assert isinstance(names, list), "Given names are not a list"
        return [self._inputs[name] for name in names if name in self._inputs]

    def get_parameters(self, names: List[str]):
        """Get model parameters based on given names. Unknown names are skipped."""
        assert isinstance(names, list), "Given names are not a list"
        return [self._parameters[name] for name in names if name in self._parameters]

    def get_states(self, names: List[str]):
        """Get model states based on given names. Unknown names are skipped."""
        assert isinstance(names, list), "Given names are not a list"
        return [self._states[name] for name in names if name in self._states]

    def get_output(self, name: str):
        """Get model output based on given name, or None if unknown."""
        return self._outputs.get(name, None)

    def get_input(self, name: str):
        """Get model input based on given name, or None if unknown."""
        return self._inputs.get(name, None)

    def get_state(self, name: str):
        """Get model state based on given name, or None if unknown."""
        return self._states.get(name, None)

    def get_parameter(self, name: str):
        """Get model parameter based on given name, or None if unknown."""
        return self._parameters.get(name, None)

    def set_input_value(self, name: str, value: Union[float, int, bool]):
        """Just used from external modules like simulator to set new input values"""
        self.set_input_values(names=[name], values=[value])

    def set_input_values(self, names: List[str], values: List[Union[float, int, bool]]):
        """Just used from external modules like simulator to set new input values"""
        self.__setter(variables=self._inputs, values=values, names=names)

    def _set_output_value(self, name: str, value: Union[float, int, bool]):
        """Just used internally to write output values"""
        self._set_output_values(names=[name], values=[value])

    def _set_output_values(
        self, names: List[str], values: List[Union[float, int, bool]]
    ):
        """Just used internally to write output values"""
        self.__setter(variables=self._outputs, values=values, names=names)

    def _set_state_value(self, name: str, value: Union[float, int, bool]):
        """Just used internally to write state values"""
        self._set_state_values(names=[name], values=[value])

    def _set_state_values(
        self, names: List[str], values: List[Union[float, int, bool]]
    ):
        """Just used internally to write state values"""
        self.__setter(variables=self._states, values=values, names=names)

    def set_parameter_value(self, name: str, value: Union[float, int, bool]):
        """Used externally to write new parameter values from e.g. a calibration process"""
        self.set_parameter_values(names=[name], values=[value])

    def set_parameter_values(
        self, names: List[str], values: List[Union[float, int, bool]]
    ):
        """Used externally to write new parameter values from e.g. a calibration process"""
        self.__setter(variables=self._parameters, values=values, names=names)

    def __setter(
        self,
        variables: Dict[str, ModelVariable],
        values: List[Union[float, int, bool]],
        names: List[str],
    ):
        """
        General setter of model values.

        None values are skipped with a warning so the previous value is kept.
        Validation on set follows ``config.validate_variables``.

        Raises:
            KeyError: If a name is not contained in ``variables``.
        """
        assert len(names) == len(
            values
        ), "Length of names has to equal length of values"
        for name, value in zip(names, values):
            if value is None:
                logger.warning(
                    "Tried to override variable '%s' in model '%s' "
                    "with None. Keeping the previous value of %s",
                    name,
                    self.name,
                    variables[name].value,
                )
                continue
            variables[name].set_value(
                value=value, validate=self.config.validate_variables
            )

    def get(self, name: str) -> ModelVariable:
        """
        Get any variable from using name:

        Args:
            name (str): The item to get from config by name of Variable.
                Hence, item=ModelVariable.name
        Returns:
            var (ModelVariable): The matching variable
        Raises:
            ValueError: If the item was not found in the variables of the
                model.
        """
        if name in self._inputs:
            return self._inputs[name]
        if name in self._outputs:
            return self._outputs[name]
        if name in self._parameters:
            return self._parameters[name]
        if name in self._states:
            return self._states[name]
        raise ValueError(
            f"'{self.__class__.__name__}' has "
            f"no ModelVariable with the name '{name}' "
            f"in the config."
        )

    def set(self, name: str, value: Any):
        """
        Set any variable from using name:

        Args:
            name (str): The name of the ModelVariable to set.
                Hence, name=ModelVariable.name
            value (Any): Any value to set to the Variable
        Raises:
            ValueError: If the item was not found in the variables of the
                model.
        """
        if name in self._inputs:
            self.set_input_value(name=name, value=value)
        elif name in self._outputs:
            self._set_output_value(name=name, value=value)
        elif name in self._parameters:
            self.set_parameter_value(name=name, value=value)
        elif name in self._states:
            self._set_state_value(name=name, value=value)
        else:
            raise ValueError(
                f"'{self.__class__.__name__}' has "
                f"no ModelVariable with the name '{name}' "
                f"in the config."
            )

    def get_input_names(self):
        """
        Returns:
            names (list): A list containing all input names
        """
        return list(self._inputs.keys())

    def get_output_names(self):
        """
        Returns:
            names (list): A list containing all output names
        """
        return list(self._outputs.keys())

    def get_state_names(self):
        """
        Returns:
            names (list): A list containing all state names
        """
        return list(self._states.keys())

    def get_parameter_names(self):
        """
        Returns:
            names (list): A list containing all parameter names
        """
        return list(self._parameters.keys())