Coverage for agentlib/core/model.py: 94%

219 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-04-07 16:27 +0000

1"""This module contains just the basic Model.""" 

2 

3import abc 

4import os 

5import json 

6import logging 

7from copy import deepcopy 

8from itertools import chain 

9from typing import Union, List, Dict, Any, Optional, get_type_hints, Type 

10from pydantic import ConfigDict, BaseModel, Field, field_validator 

11import numpy as np 

12from pydantic.fields import PrivateAttr 

13from pydantic_core.core_schema import FieldValidationInfo 

14 

15from agentlib.core.datamodels import ( 

16 ModelVariable, 

17 ModelInputs, 

18 ModelStates, 

19 ModelOutputs, 

20 ModelParameters, 

21 ModelState, 

22 ModelParameter, 

23 ModelOutput, 

24 ModelInput, 

25) 

26 

# Module-level logger; Model.__setter uses it to warn when a variable
# would be overwritten with None.
logger = logging.getLogger(__name__)

28 

29 

class ModelConfig(BaseModel):
    """
    Pydantic data model for the configuration of a Model.

    Holds general settings (name, description, simulation time, time
    increment) and the model variables grouped into inputs, outputs,
    states and parameters. The raw keyword arguments given by the user
    are kept in ``user_config``.
    """

    user_config: dict = Field(
        default=None,
        description="The config given by the user to instantiate this class."
        "Will be stored to enable a valid overwriting of the "
        "default config and to better restart modules."
        "Is also useful to debug validators and the general BaseModuleConfig.",
    )
    name: Optional[str] = Field(default=None, validate_default=True)
    description: str = Field(default="You forgot to document your model!")
    sim_time: float = Field(default=0, title="Current simulation time")
    dt: Union[float, int] = Field(default=1, title="time increment")
    validate_variables: bool = Field(
        default=True,
        title="Validate Variables",
        description="If true, the validator of a variables value is called whenever a "
        "new value is set. Enabled by default; disable it for performance reasons.",
    )

    inputs: ModelInputs = Field(default=list())
    outputs: ModelOutputs = Field(default=list())
    states: ModelStates = Field(default=list())
    parameters: ModelParameters = Field(default=list())

    # Maps each variable-group field name to the class used to build its
    # entries in include_default_model_variables.
    _types: Dict[str, type] = PrivateAttr(
        default={
            "inputs": ModelInput,
            "outputs": ModelOutput,
            "states": ModelState,
            "parameters": ModelParameter,
        }
    )
    model_config = ConfigDict(
        validate_assignment=True, arbitrary_types_allowed=True, extra="forbid"
    )

    @field_validator("name")
    @classmethod
    def check_name(cls, name):
        """
        Check if name of model is given. If not, derive it from the name of
        the config class with "Config" removed
        (e.g. ``MyModelConfig`` -> ``MyModel``).
        """
        if name is None:
            # Use the bare class name; str(cls) would give "<class '...'>".
            name = cls.__name__.replace("Config", "")
        return name

    @field_validator("parameters", "inputs", "outputs", "states", mode="after")
    @classmethod
    def include_default_model_variables(
        cls, _: List[ModelVariable], info: FieldValidationInfo
    ):
        """
        Validator building block to merge default variables with config variables in a standard validator.
        Updates default variables when a variable with the same name is present in the config.
        Then returns the union of the default variables and the external config variables.

        This validator ensures default variables are kept
        when the config provides new variables.

        The user-supplied entries are read from ``user_config`` (they are
        expected to be dicts with a "name" key) and are never mutated.

        Returns:
            List[ModelVariable]: default variables (updated where the user
            gave a variable of the same name) followed by the remaining
            user-defined variables.
        """
        default = cls.model_fields[info.field_name].get_default()
        # user_config may be absent/None if the model was validated without
        # going through __init__ (e.g. model_validate).
        user_config = info.data.get("user_config") or {}
        # Copy the list: it is shared with the stored user_config and, via the
        # shallow kwargs copy in __init__, with the caller's own list.
        remaining = list(user_config.get(info.field_name, []))
        variable_type = cls._types.get_default()[info.field_name]
        variables: List[ModelVariable] = deepcopy(default)
        user_variables_dict = {d["name"]: d for d in remaining}

        for i, var in enumerate(variables):
            if var.name in user_variables_dict:
                var_to_update_with = user_variables_dict[var.name]
                remaining.remove(var_to_update_with)
                var_dict = var.dict()
                var_dict.update(var_to_update_with)
                variables[i] = variable_type(**var_dict)
        # Whatever was not merged into a default is a new, user-only variable.
        variables.extend(variable_type(**var) for var in remaining)
        return variables

    def get_variable_names(self):
        """
        Returns the names of every variable as list
        (inputs, outputs, states, parameters - in that order).
        """
        return [
            var.name
            for var in self.inputs + self.outputs + self.states + self.parameters
        ]

    def __init__(self, **kwargs):
        # Keep a (shallow) copy of exactly what the user passed, so validators
        # can distinguish user-given variables from class defaults.
        kwargs["user_config"] = kwargs.copy()
        super().__init__(**kwargs)

123 

124 

class Model(abc.ABC):
    """
    Base class for simulation models. To implement your
    own model, inherit from this class.

    Subclasses implement ``do_step`` and ``initialize`` and may annotate
    ``config`` with their own ModelConfig subclass; that annotation is what
    ``get_config_type`` resolves.
    """

    config: ModelConfig

    # pylint: disable=too-many-public-methods

    def __init__(self, **kwargs):
        """
        Initializes model class.

        Args:
            **kwargs: Passed to the config type returned by
                ``get_config_type`` to build the model config.
        """
        self._inputs = {}
        self._outputs = {}
        self._states = {}
        self._parameters = {}

        # Goes through the ``config`` property setter, which also rebuilds
        # the name -> variable lookup dicts above.
        self.config = self.get_config_type()(**kwargs)

    @classmethod
    def get_config_type(cls) -> Type[ModelConfig]:
        """Return the config class given by the ``config`` type annotation."""
        return get_type_hints(cls)["config"]

    @abc.abstractmethod
    def do_step(self, *, t_start: float, t_sample: float):
        """
        Performing one simulation step
        Args:
            t_start: start time for integration
            t_sample: increment of solver integration
        Returns:
            Defined by the subclass.
        """
        raise NotImplementedError(
            "The Model class does not implement this "
            "because it is individual to the subclasses"
        )

    @abc.abstractmethod
    def initialize(self, **kwargs):
        """
        Abstract method to define what to
        do in order to initialize the model in use.
        """
        raise NotImplementedError(
            "The Model class does not implement this "
            "because it is individual to the subclasses"
        )

    def terminate(self):
        """Terminate the model if applicable by subclass."""

    def __getattr__(self, item):
        """
        Give attribute access to model variables by name.

        Only called when normal attribute lookup fails. Looks in inputs,
        outputs, parameters and states (in that order).

        Raises:
            AttributeError: If no variable with that name exists.
        """
        # Read the lookup dicts through __dict__: on a partially constructed
        # instance (e.g. during copy.deepcopy or unpickling) they may not exist
        # yet, and accessing them as attributes would re-enter __getattr__ and
        # recurse without end.
        attributes = self.__dict__
        for group in ("_inputs", "_outputs", "_parameters", "_states"):
            variables = attributes.get(group)
            if variables is not None and item in variables:
                return variables[item]
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{item}'"
        )

    def generate_variables_config(self, filename: str = None, **kwargs) -> str:
        """
        Generate a config file (.json) to enable a user-friendly
        configuration of the model.

        Args:
            filename (str): Optional path where to store the config.
                If None, current model name and workdir are used.
            kwargs: Kwargs directly passed to the json.dump method.
        Returns:
            filepath (str): Filepath where the json is stored
        """
        if filename is None:
            filename = os.path.join(os.getcwd(), f"{self.__class__.__name__}.json")
        model_config = {
            "inputs": [inp.dict() for inp in self.inputs],
            "outputs": [out.dict() for out in self.outputs],
            "states": [sta.dict() for sta in self.states],
            "parameters": [par.dict() for par in self.parameters],
        }
        # Explicit encoding so the result does not depend on the platform
        # locale (relevant if kwargs contain ensure_ascii=False).
        with open(filename, "w", encoding="utf-8") as file:
            json.dump(obj=model_config, fp=file, **kwargs)
        return filename

    @property
    def config(self) -> ModelConfig:
        """Get the current config, which is
        a ModelConfig object."""
        return self._config

    @config.setter
    def config(self, config: Union[dict, ModelConfig]):
        """
        Set a new config and rebuild the name -> variable lookup dicts.

        Args:
            config (dict, ModelConfig): The config dict or ModelConfig object.
        """
        # Instantiate the ModelConfig.
        if isinstance(config, self.get_config_type()):
            self._config = config
        else:
            self._config = self.get_config_type()(**config)
        # Update model variables. The dicts reference the same variable
        # objects as the config, so value changes are reflected in both.
        self._inputs = {var.name: var for var in self.config.inputs.copy()}
        self._outputs = {var.name: var for var in self.config.outputs.copy()}
        self._states = {var.name: var for var in self.config.states.copy()}
        self._parameters = {var.name: var for var in self.config.parameters.copy()}

    @property
    def description(self):
        """Get model description"""
        return self.config.description

    @description.setter
    def description(self, description: str):
        """Set model description"""
        self.config.description = description

    @description.deleter
    def description(self):
        """Delete model description. Default is then used."""
        # todo fwu do we have a use for this, or should we just get rid of deleters, and these properties altogether?
        self.config.description = (
            self.get_config_type().model_fields["description"].default
        )

    @property
    def name(self):
        """Get model name"""
        return self.config.name

    @name.setter
    def name(self, name: str):
        """
        Set the model name
        Args:
            name (str): Name of the model
        """
        self.config.name = name

    @name.deleter
    def name(self):
        """Delete the model name"""
        self.config.name = self.get_config_type().model_fields["name"].default

    @property
    def sim_time(self):
        """Get the current simulation time"""
        return self.config.sim_time

    @sim_time.setter
    def sim_time(self, sim_time: float):
        """Set the current simulation time"""
        self.config.sim_time = sim_time

    @sim_time.deleter
    def sim_time(self):
        """Reset the current simulation time to the default value"""
        self.config.sim_time = self.get_config_type().model_fields["sim_time"].default

    @property
    def dt(self):
        """Get time increment of simulation"""
        return self.config.dt

    @property
    def variables(self):
        """Get all model variables as a list (inputs, outputs, parameters, states)"""
        return list(
            chain.from_iterable(
                [self.inputs, self.outputs, self.parameters, self.states]
            )
        )

    @property
    def inputs(self) -> ModelInputs:
        """Get all model inputs as a list"""
        return list(self._inputs.values())

    @property
    def outputs(self) -> ModelOutputs:
        """Get all model outputs as a list"""
        return list(self._outputs.values())

    @property
    def states(self) -> ModelStates:
        """Get all model states as a list"""
        return list(self._states.values())

    @property
    def parameters(self) -> ModelParameters:
        """Get all model parameters as a list"""
        return list(self._parameters.values())

    def _create_time_samples(self, t_sample):
        """
        Generate an array of time samples from 0 to ``t_sample``
        using the current self.dt as step.
        Note that, if self.dt is not a true divider of t_sample,
        the last interval is shorter, so the output array is not
        equally sampled.

        Args:
            t_sample (float): End of the sampled interval.

        Returns:
            np.ndarray: Time points spaced by dt (except possibly the last
                interval), ending exactly at ``t_sample``. For
                ``t_sample <= 0`` only ``[t_sample]`` is returned.
        """
        samples = np.arange(0, t_sample, self.dt)
        # With float steps np.arange can return a last value equal to, or a
        # few ulps above, t_sample. Drop such points so t_sample is appended
        # exactly once and the result stays strictly increasing. This also
        # avoids indexing an empty array when t_sample <= 0.
        keep = (samples < t_sample) & ~np.isclose(
            samples, t_sample, rtol=1e-10, atol=0.0
        )
        return np.append(samples[keep], t_sample)

    ##########################################################################################
    # Getter and setter function using names for easier access
    ##########################################################################################
    def get_outputs(self, names: List[str]):
        """Get model outputs based on given names (unknown names are skipped)."""
        assert isinstance(names, list), "Given names are not a list"
        return [self._outputs[name] for name in names if name in self._outputs]

    def get_inputs(self, names: List[str]):
        """Get model inputs based on given names (unknown names are skipped)."""
        assert isinstance(names, list), "Given names are not a list"
        return [self._inputs[name] for name in names if name in self._inputs]

    def get_parameters(self, names: List[str]):
        """Get model parameters based on given names (unknown names are skipped)."""
        assert isinstance(names, list), "Given names are not a list"
        return [self._parameters[name] for name in names if name in self._parameters]

    def get_states(self, names: List[str]):
        """Get model states based on given names (unknown names are skipped)."""
        assert isinstance(names, list), "Given names are not a list"
        return [self._states[name] for name in names if name in self._states]

    def get_output(self, name: str):
        """Get model output based on given name, or None if unknown."""
        return self._outputs.get(name, None)

    def get_input(self, name: str):
        """Get model input based on given name, or None if unknown."""
        return self._inputs.get(name, None)

    def get_state(self, name: str):
        """Get model state based on given name, or None if unknown."""
        return self._states.get(name, None)

    def get_parameter(self, name: str):
        """Get model parameter based on given name, or None if unknown."""
        return self._parameters.get(name, None)

    def set_input_value(self, name: str, value: Union[float, int, bool]):
        """Just used from external modules like simulator to set new input values"""
        self.set_input_values(names=[name], values=[value])

    def set_input_values(self, names: List[str], values: List[Union[float, int, bool]]):
        """Just used from external modules like simulator to set new input values"""
        self.__setter(variables=self._inputs, values=values, names=names)

    def _set_output_value(self, name: str, value: Union[float, int, bool]):
        """Just used internally to write output values"""
        self._set_output_values(names=[name], values=[value])

    def _set_output_values(
        self, names: List[str], values: List[Union[float, int, bool]]
    ):
        """Just used internally to write output values"""
        self.__setter(variables=self._outputs, values=values, names=names)

    def _set_state_value(self, name: str, value: Union[float, int, bool]):
        """Just used internally to write state values"""
        self._set_state_values(names=[name], values=[value])

    def _set_state_values(
        self, names: List[str], values: List[Union[float, int, bool]]
    ):
        """Just used internally to write state values"""
        self.__setter(variables=self._states, values=values, names=names)

    def set_parameter_value(self, name: str, value: Union[float, int, bool]):
        """Used externally to write new parameter values from e.g. a calibration process"""
        self.set_parameter_values(names=[name], values=[value])

    def set_parameter_values(
        self, names: List[str], values: List[Union[float, int, bool]]
    ):
        """Used externally to write new parameter values from e.g. a calibration process"""
        self.__setter(variables=self._parameters, values=values, names=names)

    def __setter(
        self,
        variables: Dict[str, ModelVariable],
        values: List[Union[float, int, bool]],
        names: List[str],
    ):
        """
        General setter of model values.

        None values are skipped with a warning; every other value is written
        via ``set_value``, validated if ``config.validate_variables`` is set.
        Unknown names raise KeyError.
        """
        assert len(names) == len(
            values
        ), "Length of names has to equal length of values"
        for name, value in zip(names, values):
            if value is None:
                logger.warning(
                    "Tried to override variable '%s' in model '%s' "
                    "with None. Keeping the previous value of %s",
                    name,
                    self.name,
                    variables[name].value,
                )
                continue
            variables[name].set_value(
                value=value, validate=self.config.validate_variables
            )

    def get(self, name: str) -> ModelVariable:
        """
        Get any variable by name.

        Args:
            name (str): The item to get from config by name of Variable.
                Hence, item=ModelVariable.name
        Returns:
            var (ModelVariable): The matching variable
        Raises:
            ValueError: If the item was not found in the variables of the
                model.
        """
        if name in self._inputs:
            return self._inputs[name]
        if name in self._outputs:
            return self._outputs[name]
        if name in self._parameters:
            return self._parameters[name]
        if name in self._states:
            return self._states[name]
        raise ValueError(
            f"'{self.__class__.__name__}' has "
            f"no ModelVariable with the name '{name}' "
            f"in the config."
        )

    def set(self, name: str, value: Any):
        """
        Set any variable by name.

        Args:
            name (str): The name of the ModelVariable to set.
                Hence, item=ModelVariable.name
            value (Any): Any value to set to the Variable
        Raises:
            ValueError: If the item was not found in the variables of the
                model.
        """
        if name in self._inputs:
            self.set_input_value(name=name, value=value)
        elif name in self._outputs:
            self._set_output_value(name=name, value=value)
        elif name in self._parameters:
            self.set_parameter_value(name=name, value=value)
        elif name in self._states:
            self._set_state_value(name=name, value=value)
        else:
            raise ValueError(
                f"'{self.__class__.__name__}' has "
                f"no ModelVariable with the name '{name}' "
                f"in the config."
            )

    def get_input_names(self):
        """
        Returns:
            names (list): A list containing all input names
        """
        return list(self._inputs.keys())

    def get_output_names(self):
        """
        Returns:
            names (list): A list containing all output names
        """
        return list(self._outputs.keys())

    def get_state_names(self):
        """
        Returns:
            names (list): A list containing all state names
        """
        return list(self._states.keys())

    def get_parameter_names(self):
        """
        Returns:
            names (list): A list containing all parameter names
        """
        return list(self._parameters.keys())