Coverage for agentlib/core/module.py: 95%

219 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-11-07 11:57 +0000

1"""This module contains the base AgentModule.""" 

2 

3from __future__ import annotations 

4 

5import abc 

6import json 

7import logging 

8from copy import deepcopy 

9from typing import ( 

10 TYPE_CHECKING, 

11 List, 

12 Dict, 

13 Union, 

14 Any, 

15 TypeVar, 

16 Optional, 

17 get_type_hints, 

18 Type, 

19) 

20 

21import pydantic 

22from pydantic import field_validator, ConfigDict, BaseModel, Field, PrivateAttr 

23from pydantic.json_schema import GenerateJsonSchema 

24from pydantic_core import core_schema 

25 

26import agentlib.core.logging_ as agentlib_logging 

27from agentlib.core import datamodels 

28from agentlib.core.datamodels import ( 

29 AgentVariable, 

30 Source, 

31 AgentVariables, 

32 AttrsToPydanticAdaptor, 

33) 

34from agentlib.core.environment import CustomSimpyEnvironment 

35from agentlib.core.errors import ConfigurationError 

36from agentlib.utils.fuzzy_matching import fuzzy_match, RAPIDFUZZ_IS_INSTALLED 

37from agentlib.utils.validators import ( 

38 include_defaults_in_root, 

39 update_default_agent_variable, 

40 is_list_of_agent_variables, 

41 is_valid_agent_var_config, 

42) 

43 

44if TYPE_CHECKING: 

45 # this avoids circular import 

46 from agentlib.core import Agent 

47 

48 

# Module-level logger named after this module; merge_variables uses it to
# warn when a module_id given for a variable of another agent is dropped.
49logger = logging.getLogger(__name__) 

50 

51 

class BaseModuleConfig(BaseModel):
    """
    Pydantic data model for basic module configuration.

    Besides the public fields declared below, every instance keeps two
    private attributes that are filled in ``__init__``:

    - ``_variables``: all AgentVariables found in the fields of the config,
      as returned by ``merge_variables``.
    - ``_user_config``: a copy of the keyword arguments the instance was
      created with.
    """

    # The type is relevant to load the correct module class.
    type: Union[str, Dict[str, str]] = Field(
        title="Type",
        description="The type of the Module. Used to find the Python-Object "
        "from all agentlib-core and plugin Module options. If a dict is given, "
        "it must contain the keys 'file' and 'class_name'. "
        "'file' is the filepath of a python file containing the Module. "
        "'class_name' is the name of the Module class within this file.",
    )
    # A module is uniquely identified in the MAS using agent_id and module_id.
    # The module_id should be unique inside one agent.
    # This is checked inside the agent-class.
    module_id: str = Field(
        description="The unique id of the module within an agent, "
        "used only to communicate within the agent."
    )
    validate_incoming_values: Optional[bool] = Field(
        default=True,
        title="Validate Incoming Values",
        description="If true, the validator of the AgentVariable value is called when "
        "receiving a new value from the DataBroker.",
    )
    log_level: Optional[str] = Field(
        default=None,
        description="The log level for this Module. "
        "Default uses the root-loggers level. "
        "Options: DEBUG; INFO; WARNING; ERROR; CRITICAL",
    )
    shared_variable_fields: List[str] = Field(
        default=[],
        description="A list of strings with each string being a field of the Modules configs. "
        "The field must be or contain an AgentVariable. If the field is added to this list, "
        "all shared attributes of the AgentVariables will be set to True.",
        validate_default=True,
    )
    # Aggregation of all instances of an AgentVariable in this Config
    _variables: AgentVariables = PrivateAttr(default=[])

    # The config given by the user to instantiate this class.
    # Will be stored to enable a valid overwriting of the
    # default config and to better restart modules.
    # Is also useful to debug validators and the general BaseModuleConfig.
    _user_config: dict = PrivateAttr(default=None)

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
        frozen=True,
    )

    def get_variables(self):
        """Return the private attribute with all AgentVariables"""
        return self._variables

    @classmethod
    def model_json_schema(cls, *args, **kwargs) -> dict:
        """
        Custom schema method to
        - Add JSON Schema for custom attrs types Source and AgentVariable
        - put shared_variable_fields and log_level last in the properties.
            Used to better display relevant options of children classes in GUIs.

        Raises:
            ValueError: If a ``schema_generator`` keyword argument is passed,
                since this method always installs its own generator.
        """
        if "schema_generator" in kwargs:
            raise ValueError("Custom schema_generator is not supported for BaseModule.")

        class CustomGenerateJsonSchema(GenerateJsonSchema):
            """
            This class is necessary, as the default object type
            AttrsToPydanticAdaptor (e.g. Source, AgentVariable) are
            not json serializable by default.
            """

            def default_schema(self, schema: core_schema.WithDefaultSchema):
                # Replace adaptor defaults by the result of their json()
                # method before the regular default handling runs.
                if "default" in schema:
                    _default = schema["default"]
                    if isinstance(_default, AttrsToPydanticAdaptor):
                        schema["default"] = _default.json()
                return super().default_schema(schema=schema)

        kwargs["schema_generator"] = CustomGenerateJsonSchema
        schema = super().model_json_schema(*args, **kwargs)

        # Swap the generated definitions of the attrs based models for the
        # schema those models provide themselves.
        definitions = schema.get("$defs", {})
        definitions_out = {}
        for class_name, metadata in definitions.items():
            if class_name in datamodels.ATTRS_MODELS:
                class_object: AttrsToPydanticAdaptor = getattr(datamodels, class_name)
                metadata = class_object.get_json_schema()
            definitions_out[class_name] = metadata
        if definitions_out:
            schema["$defs"] = definitions_out

        # Popping and re-inserting moves both keys to the end of the dict,
        # with log_level being the very last entry.
        log_level = schema["properties"].pop("log_level")
        shared_variable_fields = schema["properties"].pop("shared_variable_fields")
        schema["properties"]["shared_variable_fields"] = shared_variable_fields
        schema["properties"]["log_level"] = log_level
        return schema

    @classmethod
    def check_if_variables_are_unique(cls, names):
        """Check that the given variable names are unique.

        The passed collection is not modified.

        Args:
            names: Sized iterable of variable names.

        Raises:
            ValueError: If at least one name occurs more than once. The
                message lists every duplicated name once.
        """
        if len(names) == len(set(names)):
            return
        seen = set()
        duplicates = []
        for name in names:
            if name in seen and name not in duplicates:
                duplicates.append(name)
            seen.add(name)
        raise ValueError(
            f"{cls.__name__} contains variables with the same name. The "
            f"following appear at least twice: {', '.join(duplicates)}"
        )

    @field_validator("shared_variable_fields")
    @classmethod
    def check_valid_fields(cls, shared_variables_fields):
        """
        Check if the shared_variables_fields are valid
        fields.

        Raises:
            ConfigurationError: If an entry is not a field name of this config.
        """
        wrong_public_fields = set(shared_variables_fields).difference(
            cls.model_fields.keys()
        )
        if wrong_public_fields:
            raise ConfigurationError(
                f"Public fields {wrong_public_fields} do not exist. Maybe you "
                f"misspelled them?"
            )
        return shared_variables_fields

    @field_validator("log_level")
    @classmethod
    def check_valid_level(cls, log_level: str):
        """
        Check if the given log_level is valid.

        Returns:
            The upper-cased log level, or None if None was given.

        Raises:
            ValueError: If the logging library does not map the upper-cased
                name to an integer level.
        """
        if log_level is None:
            return log_level
        log_level = log_level.upper()
        if not isinstance(logging.getLevelName(log_level), int):
            raise ValueError(
                f"Given log level '{log_level}' is not "
                f"supported by logging library."
            )
        return log_level

    @classmethod
    def merge_variables(
        cls,
        pre_validated_instance: BaseModuleConfig,
        user_config: dict,
        agent_id: str,
        shared_variable_fields: List[str],
    ):
        """
        Merge, rigorously check and validate the input of
        all AgentVariables into the module.
        This function:

        - Collects all variables
        - Checks if duplicate names (will cause errors in the get() function.

        Args:
            pre_validated_instance: Instance whose AgentVariable fields are
                replaced by the merged variables (set via setattr).
            user_config: The raw user configuration of the module.
            agent_id: Id of the agent the module belongs to.
            shared_variable_fields: Names of the fields whose variables are
                to be made shared.

        Returns:
            List of all merged AgentVariables.

        Raises:
            ValueError: If two variables have the same name.
        """
        _vars = []
        # Read the fields from the class; accessing model_fields on an
        # instance is deprecated in recent pydantic versions.
        instance_fields = type(pre_validated_instance).model_fields
        # Extract all variables from fields
        for field_name, field in cls.model_fields.items():
            # If field is missing in values, validation of field was not
            # successful. Continue and pydantic will later raise the ValidationError
            if field_name not in instance_fields:
                continue

            pre_merged_attr = getattr(pre_validated_instance, field_name)
            make_shared = field_name in shared_variable_fields

            if isinstance(pre_merged_attr, AgentVariable):
                update_var_with = user_config.get(field_name, {})

                var = update_default_agent_variable(
                    default_var=field.default,
                    user_data=update_var_with,
                    make_shared=make_shared,
                    agent_id=agent_id,
                    field_name=field_name,
                )
                _vars.append(var)
                setattr(pre_validated_instance, field_name, var)

            elif is_list_of_agent_variables(pre_merged_attr):
                user_config_var_dicts = user_config.get(field_name, [])
                # we need the type if plugins subclass the AgentVariable
                # NOTE(review): indexing [0] assumes is_list_of_agent_variables
                # only accepts non-empty lists - confirm against the validator.
                type_ = pre_merged_attr[0].__class__
                update_vars_with = [
                    conf
                    for conf in user_config_var_dicts
                    if is_valid_agent_var_config(conf, field_name, type_)
                ]

                variables = include_defaults_in_root(
                    variables=update_vars_with,
                    field=field,
                    type_=type_,  # subtype of AgentVariable
                    make_shared=make_shared,
                    agent_id=agent_id,
                    field_name=field_name,
                )

                _vars.extend(variables)
                setattr(pre_validated_instance, field_name, variables)

        # Extract names
        variable_names = [var.name for var in _vars]

        # First check if names exists more than once
        cls.check_if_variables_are_unique(names=variable_names)

        for _var in _vars:
            # case the agent id is a different agent
            if (_var.source.agent_id != agent_id) and (
                _var.source.module_id is not None
            ):
                logger.warning(
                    "Setting given module_id '%s' in variable '%s' to None. "
                    "You can not specify module_ids of other agents.",
                    _var.source.module_id,
                    _var.name,
                )
                _var.source = Source(agent_id=_var.source.agent_id)

        return _vars

    @classmethod
    def default(cls, field: str):
        """Return the default of the config field with the given name."""
        return cls.model_fields[field].get_default()

    def __init__(self, _agent_id, *args, **kwargs):
        """
        Validate the user config and merge all AgentVariables.

        Args:
            _agent_id: Id of the agent this config belongs to.
            *args: Forwarded to the pydantic BaseModel constructor.
            **kwargs: The user configuration. A copy is stored in
                ``_user_config``.

        Raises:
            pydantic.ValidationError: If the configuration is invalid. The
                error is rebuilt with a title naming the agent and module.
        """
        _user_config = kwargs.copy()
        try:
            super().__init__(*args, **kwargs)
        except pydantic.ValidationError as e:
            better_error = self._improve_extra_field_error_messages(
                e, agent_id=_agent_id, module_id=_user_config.get("module_id")
            )
            raise better_error from e

        # Enable mutation. model_config is a class attribute, so this
        # affects the whole class while the variables are merged.
        self.model_config["frozen"] = False
        try:
            self._variables = self.__class__.merge_variables(
                pre_validated_instance=self,
                user_config=_user_config,
                agent_id=_agent_id,
                shared_variable_fields=self.shared_variable_fields,
            )
            self._user_config = _user_config
        finally:
            # Disable mutation again, also if merging failed. Otherwise the
            # config class would stay mutable after an error.
            self.model_config["frozen"] = True

    @classmethod
    def _improve_extra_field_error_messages(
        cls, e: pydantic.ValidationError, agent_id: str, module_id: str
    ) -> pydantic.ValidationError:
        """Checks the validation errors for invalid fields and adds suggestions for
        correct field names to the error message.

        Args:
            e: The original validation error.
            agent_id: Id of the agent, used in the title of the new error.
            module_id: Id of the module or None, used in the title if given.

        Returns:
            A new ValidationError built from the (possibly amended) errors.
        """

        error_list = e.errors()
        if module_id is not None:
            title = f"configuration of agent '{agent_id}' / module '{module_id}':"
        else:
            title = f"configuration of agent '{agent_id}':"
        if RAPIDFUZZ_IS_INSTALLED:
            for error in error_list:
                if error["type"] != "extra_forbidden":
                    continue

                suggestions = fuzzy_match(
                    target=error["loc"][0], choices=cls.model_fields.keys()
                )
                if not suggestions:
                    # Keep the original error untouched. A literal_error
                    # without its 'expected' context could not be rebuilt.
                    continue

                # change error type to literal because it allows for context
                error["type"] = "literal_error"
                error["ctx"] = {
                    "expected": f"a valid Field name. Field '{error['loc'][0]}' does "
                    f"not exist. Did you mean any of {suggestions}?"
                }

        return pydantic.ValidationError.from_exception_data(
            title=title, line_errors=error_list
        )

345 

346 

# Type variable bound to BaseModuleConfig, used in annotations that refer to
# the concrete config class of a module.
347BaseModuleConfigClass = TypeVar("BaseModuleConfigClass", bound=BaseModuleConfig) 

348 

349 

class BaseModule(abc.ABC):
    """
    Basic module used by any agent.
    Besides a common configuration, where ids
    and variables are defined, this class manages
    the setting and getting of variables and relevant
    attributes.
    """

    # pylint: disable=too-many-public-methods

    def __init__(self, *, config: dict, agent: Agent):
        """
        Create the module inside the given agent.

        Args:
            config: The module configuration. Must contain 'module_id',
                which is used for the name of the logger.
            agent: The agent this module is located in.
        """
        self._agent = agent
        self.logger = agentlib_logging.create_logger(
            env=self.env, name=f"{self.agent.id}/{config['module_id']}"
        )
        self.config = config  # evokes the config setter
        # Add process to environment
        self.env.process(self.process())
        self.register_callbacks()

    ############################################################################
    # Methods to inherit by subclasses
    ############################################################################

    @classmethod
    def get_config_type(cls) -> Type[BaseModuleConfigClass]:
        """
        Return the config class given by the type hint of the class
        attribute 'config', or None if there is no such hint.

        Raises:
            AttributeError: If the class still defines the removed
                'config_type' attribute.
        """
        if hasattr(cls, "config_type"):
            raise AttributeError(
                "The 'config_type' attribute is deprecated and has been removed. "
                "Please use the following syntax to assign the config of your custom "
                f"module '{cls.__name__}': \n"
                "class MyModule(agentlib.BaseModule):\n"
                "    config: MyConfigClass\n"
            )

        return get_type_hints(cls).get("config")

    @abc.abstractmethod
    def register_callbacks(self):
        """Abstract method, called at the end of ``__init__``."""
        raise NotImplementedError("Needs to be implemented by derived modules")

    @abc.abstractmethod
    def process(self):
        """This abstract method must be implemented in order to sync the module
        with the other processes of the agent and the whole MAS."""
        raise NotImplementedError("Needs to be implemented by derived modules")

    def terminate(self):
        """
        Terminate all relevant processes of the module.
        This is necessary to correctly terminate an agent
        at runtime. Not all modules may need this, hence it is
        not an abstract method.
        """
        self.logger.info(
            "Successfully terminated module %s in agent %s", self.id, self.agent.id
        )

    ############################################################################
    # Properties
    ############################################################################

    @property
    def agent(self) -> Agent:
        """Get the agent this module is located in."""
        return self._agent

    @property
    def config(self) -> BaseModuleConfigClass:
        """
        The module config.

        Returns:
            BaseModuleConfigClass: Config of the type returned by
                get_config_type()
        """
        return self._config

    @config.setter
    def config(self, config: Union[BaseModuleConfig, dict, str]):
        """
        Set a new config.

        A string is parsed as JSON first. The config is then created from
        the keyword arguments, the variables of the module are replaced,
        the variable callbacks are registered again, the log level is
        applied and finally ``_after_config_update`` is called.

        Raises:
            ConfigurationError: If the module class has no type hint for
                the attribute 'config'.
        """
        config_type = self.get_config_type()
        if config_type is None:
            raise ConfigurationError(
                "The module has no valid config. Please make sure you "
                "specify the class attribute 'config' when writing your module."
            )
        if isinstance(config, str):
            config = json.loads(config)
        # NOTE(review): the annotation also lists BaseModuleConfig, but the
        # value is unpacked with ** here, which needs a mapping - confirm
        # whether config instances are meant to be supported.
        self._config = config_type(_agent_id=self.agent.id, **config)

        # Update variables:
        self._variables_dict: Dict[str, AgentVariable] = self._copy_list_to_dict(
            self.config.get_variables()
        )
        # Now de-and re-register all callbacks:
        self._register_variable_callbacks()

        # Set log-level
        if self.config.log_level is not None:
            if not logging.getLogger().hasHandlers():
                _root_lvl_int = logging.getLogger().level
                _log_lvl_int = logging.getLevelName(self.config.log_level)
                if _log_lvl_int < _root_lvl_int:
                    self.logger.error(
                        "Log level '%s' is below root loggers level '%s'. "
                        "Without calling logging.basicConfig, "
                        "logs will not be printed.",
                        self.config.log_level,
                        logging.getLevelName(_root_lvl_int),
                    )
            self.logger.setLevel(self.config.log_level)

        # Call the after config update:
        self._after_config_update()

    def _after_config_update(self):
        """
        This function is called after the config of
        the module is updated.

        Overwrite this function to enable custom behaviour
        after your config is updated.
        For instance, a simulator may re-initialize it's model,
        or a coordinator in an ADMM-MAS send new settings to
        the participants.

        Returns nothing, the config is immutable
        """

    def _register_variable_callbacks(self):
        """
        This functions de-registers and then re-registers
        callbacks for all variables of the module to
        update their specific values.
        """
        # NOTE(review): the config setter replaces _variables_dict before
        # calling this method, so the de-registration below uses the alias
        # and source of the new variables - confirm callbacks of a previous
        # config with different alias / source do not need to be removed.
        # Keep everything in THAT order!!
        for name, var in self._variables_dict.items():
            self.agent.data_broker.deregister_callback(
                alias=var.alias,
                source=var.source,
                callback=self._callback_config_vars,
                name=name,
            )
        for name, var in self._variables_dict.items():
            self.agent.data_broker.register_callback(
                alias=var.alias,
                source=var.source,
                callback=self._callback_config_vars,
                name=name,
                _unsafe_no_copy=True,
            )

    @property
    def env(self) -> CustomSimpyEnvironment:
        """Get the environment of the agent."""
        return self.agent.env

    @property
    def id(self) -> str:
        """Get the module's id"""
        return self.config.module_id

    @property
    def source(self) -> Source:
        """Get the source of the module,
        containing the agent and module id"""
        return Source(agent_id=self.agent.id, module_id=self.id)

    @property
    def variables(self) -> List[AgentVariable]:
        """Return copies of all variables as a list."""
        return [v.copy() for v in self._variables_dict.values()]

    ############################################################################
    # Get, set and updaters
    ############################################################################
    def get(self, name: str) -> AgentVariable:
        """
        Get any variable matching the given name:

        Args:
            name (str): The item to get by name of Variable.
                Hence, item=AgentVariable.name
        Returns:
            var (AgentVariable): A copy of the matching variable
        Raises:
            KeyError: If the item was not found in the variables of the
                module.
        """
        try:
            return self._variables_dict[name].copy()
        except KeyError as err:
            raise KeyError(
                f"'{self.__class__.__name__}' has "
                f"no AgentVariable with the name '{name}' "
                f"in the configs variables."
            ) from err

    def get_value(self, name: str) -> Any:
        """
        Get the value of the variable matching the given name:

        Args:
            name (str): The item to get by name of Variable.
                Hence, item=AgentVariable.name
        Returns:
            var (Any): A deep copy of the matching value
        Raises:
            KeyError: If the item was not found in the variables of the
                module.
        """
        try:
            return deepcopy(self._variables_dict[name].value)
        except KeyError as err:
            raise KeyError(
                f"'{self.__class__.__name__}' has "
                f"no AgentVariable with the name '{name}' "
                f"in the configs variables."
            ) from err

    def set(self, name: str, value: Any, timestamp: Optional[float] = None):
        """
        Set any variable by using the name:

        The variable of the module is updated in place. A copy of it, with
        the source set to this module, is sent to the data broker.

        Args:
            name (str): The item to get by name of Variable.
                Hence, item=AgentVariable.name
            value (Any): Any value to set to the Variable
            timestamp (float): The timestamp associated with the variable.
                If None, current environment time is used.

        Raises:
            KeyError: If the item was not found in the variables of the
                module.
        """
        # The stored variable is used directly, not a copy as in get().
        try:
            var = self._variables_dict[name]
        except KeyError as err:
            # Same descriptive error as in get() and get_value().
            raise KeyError(
                f"'{self.__class__.__name__}' has "
                f"no AgentVariable with the name '{name}' "
                f"in the configs variables."
            ) from err
        var = self._update_relevant_values(
            variable=var, value=value, timestamp=timestamp
        )
        self.agent.data_broker.send_variable(
            variable=var.copy(update={"source": self.source}),
            copy=False,
        )

    ############################################################################
    # Private and or static class methods
    ############################################################################

    def _update_relevant_values(
        self, variable: AgentVariable, value: Any, timestamp: Optional[float] = None
    ):
        """
        Update the given variables fields
        with the given value (and possibly timestamp)
        Args:
            variable (AgentVariable): The variable to be updated.
            value (Any): Any value to set to the Variable
            timestamp (float): The timestamp associated with the variable.
                If None, current environment time is used.

        Returns:
            AgentVariable: The updated variable
        """
        # Update value
        variable.value = value
        # Update timestamp
        if timestamp is None:
            timestamp = self.env.time
        variable.timestamp = timestamp
        # Return updated variable
        return variable

    def _callback_config_vars(self, variable: AgentVariable, name: str):
        """
        Callback to update the AgentVariables of the module defined in the
        config.

        Args:
            variable: Variable sent by data broker
            name: Name of the variable in own config
        """
        own_var = self._variables_dict[name]
        # Deep copy, so the own variable does not share the value object
        # with the variable that was received.
        value = deepcopy(variable.value)
        own_var.set_value(value=value, validate=self.config.validate_incoming_values)
        own_var.timestamp = variable.timestamp

    @staticmethod
    def _copy_list_to_dict(ls: List[AgentVariable]):
        """Return a dict mapping the name of each variable to the variable."""
        # pylint: disable=invalid-name
        return {var.name: var for var in ls.copy()}

    def get_results(self):
        """
        Returns results of this modules run.

        Override this method, if your module creates data that you would like to obtain
        after the run.

        Returns:
            Some form of results data, often in the form of a pandas DataFrame.
        """

    def cleanup_results(self):
        """
        Deletes all files this module created.

        Override this method, if your module creates e.g. results files etc.
        """

658 """