Coverage for agentlib/core/module.py: 95%

224 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-04-07 16:27 +0000

1"""This module contains the base AgentModule.""" 

2 

3from __future__ import annotations 

4 

5import abc 

6import json 

7import logging 

8from copy import deepcopy 

9from typing import ( 

10 TYPE_CHECKING, 

11 List, 

12 Dict, 

13 Union, 

14 Any, 

15 TypeVar, 

16 Optional, 

17 get_type_hints, 

18 Type, 

19) 

20 

21import pydantic 

22from pydantic import field_validator, ConfigDict, BaseModel, Field, PrivateAttr 

23from pydantic.json_schema import GenerateJsonSchema 

24from pydantic_core import core_schema 

25 

26import agentlib.core.logging_ as agentlib_logging 

27from agentlib.core import datamodels 

28from agentlib.core.datamodels import ( 

29 AgentVariable, 

30 Source, 

31 AgentVariables, 

32 AttrsToPydanticAdaptor, 

33) 

34from agentlib.core.environment import CustomSimpyEnvironment 

35from agentlib.core.errors import ConfigurationError 

36from agentlib.utils.fuzzy_matching import fuzzy_match, RAPIDFUZZ_IS_INSTALLED 

37from agentlib.utils.validators import ( 

38 include_defaults_in_root, 

39 update_default_agent_variable, 

40 is_list_of_agent_variables, 

41 is_valid_agent_var_config, 

42) 

43 

44if TYPE_CHECKING: 

45 # this avoids circular import 

46 from agentlib.core import Agent 

47 

48 

# Module-level logger used by the config classes (e.g. for warnings while
# merging variables). Module instances create their own logger in
# BaseModule.__init__ via agentlib_logging.create_logger.
logger = logging.getLogger(__name__)

50 

51 

class BaseModuleConfig(BaseModel):
    """
    Pydantic data model for basic module configuration.

    Configs of concrete modules subclass this model and add their own fields.
    On instantiation, every field holding an AgentVariable (or a list of
    AgentVariables) is merged with the user-given config (see
    ``merge_variables``) and collected in the private attribute
    ``_variables``. Instances are frozen once construction has finished.
    """

    # The type is relevant to load the correct module class.
    type: Union[str, Dict[str, str]] = Field(
        title="Type",
        description="The type of the Module. Used to find the Python-Object "
        "from all agentlib-core and plugin Module options. If a dict is given, "
        "it must contain the keys 'file' and 'class_name'. "
        "'file' is the filepath of a python file containing the Module. "
        "'class_name' is the name of the Module class within this file.",
    )
    # A module is uniquely identified in the MAS using agent_id and module_id.
    # The module_id should be unique inside one agent.
    # This is checked inside the agent-class.
    module_id: str = Field(
        description="The unique id of the module within an agent, "
        "used only to communicate within the agent."
    )
    validate_incoming_values: Optional[bool] = Field(
        default=True,
        title="Validate Incoming Values",
        description="If true, the validator of the AgentVariable value is called when "
        "receiving a new value from the DataBroker.",
    )
    log_level: Optional[str] = Field(
        default=None,
        description="The log level for this Module. "
        "Default uses the root-loggers level. "
        "Options: DEBUG; INFO; WARNING; ERROR; CRITICAL",
    )
    shared_variable_fields: List[str] = Field(
        default=[],
        description="A list of strings with each string being a field of the Modules configs. "
        "The field must be or contain an AgentVariable. If the field is added to this list, "
        "all shared attributes of the AgentVariables will be set to True.",
        validate_default=True,
    )
    # Aggregation of all instances of an AgentVariable in this Config.
    # A factory avoids sharing one list object between instances.
    _variables: AgentVariables = PrivateAttr(default_factory=list)

    # The config given by the user to instantiate this class.
    # Will be stored to enable a valid overwriting of the
    # default config and to better restart modules.
    # Is also useful to debug validators and the general BaseModuleConfig.
    _user_config: Optional[dict] = PrivateAttr(default=None)

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
        frozen=True,
    )

    def get_variables(self):
        """Return the private attribute holding all merged AgentVariables."""
        return self._variables

    @classmethod
    def model_json_schema(cls, *args, **kwargs) -> dict:
        """
        Custom schema method to

        - add the JSON Schema of the custom attrs types (e.g. Source and
          AgentVariable) listed in ``datamodels.ATTRS_MODELS``,
        - move ``shared_variable_fields`` and ``log_level`` to the end of the
          properties, as they are generic options shared by all modules.
          Used to better display relevant options of children classes in GUIs.

        Raises:
            ValueError: If a custom ``schema_generator`` is passed.
        """
        if "schema_generator" in kwargs:
            raise ValueError("Custom schema_generator is not supported for BaseModule.")

        class CustomGenerateJsonSchema(GenerateJsonSchema):
            """
            This class is necessary, as defaults of the type
            AttrsToPydanticAdaptor (e.g. Source, AgentVariable) are
            not json serializable by default.
            """

            def default_schema(self, schema: core_schema.WithDefaultSchema):
                if "default" in schema:
                    _default = schema["default"]
                    if isinstance(_default, AttrsToPydanticAdaptor):
                        schema["default"] = _default.json()
                return super().default_schema(schema=schema)

        kwargs["schema_generator"] = CustomGenerateJsonSchema
        schema = super().model_json_schema(*args, **kwargs)
        definitions = schema.get("$defs", {})
        definitions_out = {}
        for class_name, metadata in definitions.items():
            if class_name in datamodels.ATTRS_MODELS:
                class_object: AttrsToPydanticAdaptor = getattr(datamodels, class_name)
                metadata = class_object.get_json_schema()
            definitions_out[class_name] = metadata
        if definitions_out:
            schema["$defs"] = definitions_out

        # Re-insert both generic options so they appear last (dicts keep order).
        log_level = schema["properties"].pop("log_level")
        shared_variable_fields = schema["properties"].pop("shared_variable_fields")
        schema["properties"]["shared_variable_fields"] = shared_variable_fields
        schema["properties"]["log_level"] = log_level
        return schema

    @classmethod
    def check_if_variables_are_unique(cls, names):
        """Check that the given AgentVariable names are unique.

        The given collection is not modified.

        Args:
            names: Iterable of AgentVariable names (strings).

        Raises:
            ValueError: If at least one name appears more than once. Every
                duplicated name is listed exactly once in the message.
        """
        # Local import: keeps the module-level import block untouched.
        from collections import Counter

        duplicates = [name for name, count in Counter(names).items() if count > 1]
        if duplicates:
            raise ValueError(
                f"{cls.__name__} contains variables with the same name. The "
                f"following appear at least twice: {', '.join(duplicates)}"
            )

    @field_validator("shared_variable_fields")
    @classmethod
    def check_valid_fields(cls, shared_variables_fields):
        """
        Check that every entry of shared_variables_fields is a field of this
        config.

        Raises:
            ConfigurationError: If unknown field names are given.
        """
        wrong_public_fields = set(shared_variables_fields).difference(
            cls.model_fields.keys()
        )
        if wrong_public_fields:
            raise ConfigurationError(
                f"Public fields {wrong_public_fields} do not exist. Maybe you "
                f"misspelled them?"
            )
        return shared_variables_fields

    @field_validator("log_level")
    @classmethod
    def check_valid_level(cls, log_level: Optional[str]):
        """
        Check that the given log_level is known to the logging library.

        Returns:
            The upper-cased log level, or None if none was given.

        Raises:
            ValueError: If the level is not known to the logging library.
        """
        if log_level is None:
            return log_level
        log_level = log_level.upper()
        # getLevelName returns an int only for registered level names.
        if not isinstance(logging.getLevelName(log_level), int):
            raise ValueError(
                f"Given log level '{log_level}' is not "
                f"supported by logging library."
            )
        return log_level

    @classmethod
    def merge_variables(
        cls,
        pre_validated_instance: BaseModuleConfig,
        user_config: dict,
        agent_id: str,
        shared_variable_fields: List[str],
    ):
        """
        Merge, rigorously check and validate the input of
        all AgentVariables into the module.

        This function:

        - Collects all variables of fields holding an AgentVariable or a list
          of AgentVariables. It merges them with the user config and writes
          the result back into the field of ``pre_validated_instance``.
        - Checks for duplicate names (these would cause errors in the get()
          function of the module).
        - Resets the module_id of variables whose source points to another
          agent, as module_ids of other agents can not be specified.

        Args:
            pre_validated_instance: Instance whose fields already passed
                pydantic validation. It must be mutable when this is called,
                as merged variables are assigned back to its fields.
            user_config: The raw keyword arguments given by the user.
            agent_id: Id of the agent owning the module.
            shared_variable_fields: Names of fields whose variables are made
                shared.

        Returns:
            List of all merged AgentVariables.

        Raises:
            ValueError: If two variables share the same name.
        """
        _vars = []
        # Access model_fields on the class; instance access is deprecated in pydantic.
        instance_fields = type(pre_validated_instance).model_fields
        # Extract all variables from fields
        for field_name, field in cls.model_fields.items():
            # Defensive: skip fields unknown to the given instance.
            if field_name not in instance_fields:
                continue

            pre_merged_attr = getattr(pre_validated_instance, field_name)

            if isinstance(pre_merged_attr, AgentVariable):
                update_var_with = user_config.get(field_name, {})

                make_shared = field_name in shared_variable_fields

                var = update_default_agent_variable(
                    default_var=field.default,
                    user_data=update_var_with,
                    make_shared=make_shared,
                    agent_id=agent_id,
                    field_name=field_name,
                )
                _vars.append(var)
                pre_validated_instance.__setattr__(field_name, var)

            elif is_list_of_agent_variables(pre_merged_attr):
                user_config_var_dicts = user_config.get(field_name, [])
                # we need the type if plugins subclass the AgentVariable
                type_ = pre_merged_attr[0].__class__
                update_vars_with = [
                    conf
                    for conf in user_config_var_dicts
                    if is_valid_agent_var_config(conf, field_name, type_)
                ]

                make_shared = field_name in shared_variable_fields
                variables = include_defaults_in_root(
                    variables=update_vars_with,
                    field=field,
                    type_=type_,  # subtype of AgentVariable
                    make_shared=make_shared,
                    agent_id=agent_id,
                    field_name=field_name,
                )

                _vars.extend(variables)
                pre_validated_instance.__setattr__(field_name, variables)

        # Names must be unique, as modules look variables up by name.
        cls.check_if_variables_are_unique(names=[var.name for var in _vars])

        for _var in _vars:
            # case the agent id is a different agent
            if (_var.source.agent_id != agent_id) and (
                _var.source.module_id is not None
            ):
                logger.warning(
                    "Setting given module_id '%s' in variable '%s' to None. "
                    "You can not specify module_ids of other agents.",
                    _var.source.module_id,
                    _var.name,
                )
                _var.source = Source(agent_id=_var.source.agent_id)

        return _vars

    @classmethod
    def default(cls, field: str):
        """Return the default value of the given field of this config class."""
        return cls.model_fields[field].get_default()

    def __init__(self, _agent_id, *args, **kwargs):
        """
        Validate the config and merge all AgentVariables.

        Args:
            _agent_id: Id of the agent owning the module; used when merging
                the AgentVariables.
            *args, **kwargs: Passed on to pydantic's BaseModel.__init__.

        Raises:
            pydantic.ValidationError: If validation fails. For unknown fields,
                suggestions for correct names may be added.
        """
        _user_config = kwargs.copy()
        try:
            super().__init__(*args, **kwargs)
        except pydantic.ValidationError as e:
            better_error = self._improve_extra_field_error_messages(e)
            raise better_error
        # model_config is shared by the whole class. Temporarily enable
        # mutation so merge_variables can write the merged variables back.
        # Always restore the previous setting, even if merging raises.
        # Otherwise the class would stay unfrozen.
        previous_frozen = self.model_config.get("frozen")
        self.model_config["frozen"] = False
        try:
            self._variables = self.__class__.merge_variables(
                pre_validated_instance=self,
                user_config=_user_config,
                agent_id=_agent_id,
                shared_variable_fields=self.shared_variable_fields,
            )
            self._user_config = _user_config
        finally:
            self.model_config["frozen"] = previous_frozen

    @classmethod
    def _improve_extra_field_error_messages(
        cls, e: pydantic.ValidationError
    ) -> pydantic.ValidationError:
        """Add suggestions for correct field names to "extra_forbidden" errors.

        If rapidfuzz is not installed, the given error is returned unchanged.
        Errors for which no suggestion is found keep their original type.

        Returns:
            A ValidationError with the (possibly) improved error details.
        """
        if not RAPIDFUZZ_IS_INSTALLED:
            return e

        error_list = e.errors()
        for error in error_list:
            if error["type"] != "extra_forbidden":
                continue

            wrong_field = error["loc"][0]
            suggestions = fuzzy_match(target=wrong_field, choices=cls.model_fields.keys())
            if not suggestions:
                # "literal_error" requires an "expected" context entry. Without
                # suggestions, keep the original extra_forbidden error untouched.
                continue

            # Change the error type to literal_error because it allows for a
            # context, which pydantic includes in the rendered message.
            error["type"] = "literal_error"
            error["ctx"] = {
                "expected": f"a valid Field name. Field '{wrong_field}' does "
                f"not exist. Did you mean any of {suggestions}?"
            }

        return pydantic.ValidationError.from_exception_data(
            title=e.title, line_errors=error_list
        )

339 

340 

# Type variable bound to BaseModuleConfig. Used to annotate the config type
# of BaseModule subclasses (see BaseModule.get_config_type / BaseModule.config).
BaseModuleConfigClass = TypeVar("BaseModuleConfigClass", bound=BaseModuleConfig)

342 

343 

class BaseModule(abc.ABC):
    """
    Basic module used by any agent.
    Besides a common configuration, where ids
    and variables are defined, this class manages
    the setting and getting of variables and relevant
    attributes.
    """

    # pylint: disable=too-many-public-methods

    def __init__(self, *, config: dict, agent: Agent):
        """
        Create the module, validate its config and register it in the agent.

        Args:
            config: Module config as a dict; must contain the key 'module_id'.
            agent: The agent this module belongs to.
        """
        self._agent = agent
        self.logger = agentlib_logging.create_logger(
            env=self.env, name=f"{self.agent.id}/{config['module_id']}"
        )
        self.config = config  # evokes the config setter
        # Add process to environment
        self.env.process(self.process())
        self.register_callbacks()

    ############################################################################
    # Methods to inherit by subclasses
    ############################################################################

    @classmethod
    def get_config_type(cls) -> Type[BaseModuleConfigClass]:
        """
        Return the config class declared through the ``config`` class
        annotation, or None if no annotation is present.

        Raises:
            AttributeError: If the removed ``config_type`` attribute is used.
        """
        if hasattr(cls, "config_type"):
            raise AttributeError(
                "The 'config_type' attribute is deprecated and has been removed. "
                "Please use the following syntax to assign the config of your custom "
                f"module '{cls.__name__}': \n"
                "class MyModule(agentlib.BaseModule):\n"
                "    config: MyConfigClass\n"
            )

        return get_type_hints(cls).get("config")

    @abc.abstractmethod
    def register_callbacks(self):
        """Register the module specific callbacks; called once in __init__."""
        raise NotImplementedError("Needs to be implemented by derived modules")

    @abc.abstractmethod
    def process(self):
        """This abstract method must be implemented in order to sync the module
        with the other processes of the agent and the whole MAS."""
        raise NotImplementedError("Needs to be implemented by derived modules")

    def terminate(self):
        """
        Terminate all relevant processes of the module.
        This is necessary to correctly terminate an agent
        at runtime. Not all modules may need this, hence it is
        not an abstract method.
        """
        self.logger.info(
            "Successfully terminated module %s in agent %s", self.id, self.agent.id
        )

    ############################################################################
    # Properties
    ############################################################################

    @property
    def agent(self) -> Agent:
        """Get the agent this module is located in."""
        return self._agent

    @property
    def config(self) -> BaseModuleConfigClass:
        """
        The module config.

        Returns:
            BaseModuleConfigClass: Config instance of the type returned by
                get_config_type().
        """
        return self._config

    @config.setter
    def config(self, config: Union[BaseModuleConfig, dict, str]):
        """
        Set a new config.

        The config is validated, the module's variables are replaced, and the
        variable callbacks in the data broker are re-registered.

        Args:
            config: Config as a dict or as a JSON string.

        Raises:
            ConfigurationError: If the module declares no config class.
        """
        config_type = self.get_config_type()
        if config_type is None:
            raise ConfigurationError(
                "The module has no valid config. Please make sure you "
                "specify the class attribute 'config' when writing your module."
            )
        if isinstance(config, str):
            config = json.loads(config)
        self._config = config_type(_agent_id=self.agent.id, **config)

        # Remove the callbacks of the previous variables before they are
        # replaced. Their alias or source may differ from the new ones, so
        # de-registering with the new variables alone would leave stale
        # callbacks in the data broker.
        previous_variables = getattr(self, "_variables_dict", None)
        if previous_variables:
            self._deregister_variable_callbacks(previous_variables)

        # Update variables:
        self._variables_dict: Dict[str, AgentVariable] = self._copy_list_to_dict(
            self.config.get_variables()
        )
        # Now de-and re-register all callbacks:
        self._register_variable_callbacks()

        # Set log-level
        if self.config.log_level is not None:
            if not logging.getLogger().hasHandlers():
                _root_lvl_int = logging.getLogger().level
                _log_lvl_int = logging.getLevelName(self.config.log_level)
                if _log_lvl_int < _root_lvl_int:
                    self.logger.error(
                        "Log level '%s' is below root loggers level '%s'. "
                        "Without calling logging.basicConfig, "
                        "logs will not be printed.",
                        self.config.log_level,
                        logging.getLevelName(_root_lvl_int),
                    )
            self.logger.setLevel(self.config.log_level)

        # Call the after config update:
        self._after_config_update()

    def _after_config_update(self):
        """
        This function is called after the config of
        the module is updated.

        Overwrite this function to enable custom behaviour
        after your config is updated.
        For instance, a simulator may re-initialize its model,
        or a coordinator in an ADMM-MAS send new settings to
        the participants.

        Returns nothing, the config is immutable
        """

    def _deregister_variable_callbacks(self, variables: Dict[str, AgentVariable]):
        """De-register the config-variable callback for each of the given
        variables (name -> AgentVariable) from the data broker."""
        for name, var in variables.items():
            self.agent.data_broker.deregister_callback(
                alias=var.alias,
                source=var.source,
                callback=self._callback_config_vars,
                name=name,
            )

    def _register_variable_callbacks(self):
        """
        This functions de-registers and then re-registers
        callbacks for all variables of the module to
        update their specific values.
        """
        # Keep everything in THAT order!! De-registering first prevents
        # registering the same callback twice.
        self._deregister_variable_callbacks(self._variables_dict)
        for name, var in self._variables_dict.items():
            self.agent.data_broker.register_callback(
                alias=var.alias,
                source=var.source,
                callback=self._callback_config_vars,
                name=name,
                _unsafe_no_copy=True,
            )

    @property
    def env(self) -> CustomSimpyEnvironment:
        """Get the environment of the agent."""
        return self.agent.env

    @property
    def id(self) -> str:
        """Get the module's id"""
        return self.config.module_id

    @property
    def source(self) -> Source:
        """Get the source of the module,
        containing the agent and module id"""
        return Source(agent_id=self.agent.id, module_id=self.id)

    @property
    def variables(self) -> List[AgentVariable]:
        """Return copies of all AgentVariables of the module as a list."""
        return [v.copy() for v in self._variables_dict.values()]

    ############################################################################
    # Get, set and updaters
    ############################################################################

    def _missing_variable_error(self, name: str) -> KeyError:
        """Build the KeyError raised when no variable with ``name`` exists."""
        return KeyError(
            f"'{self.__class__.__name__}' has "
            f"no AgentVariable with the name '{name}' "
            f"in the configs variables."
        )

    def get(self, name: str) -> AgentVariable:
        """
        Get any variable matching the given name:

        Args:
            name (str): The item to get by name of Variable.
                Hence, item=AgentVariable.name
        Returns:
            var (AgentVariable): A copy of the matching variable
        Raises:
            KeyError: If the item was not found in the variables of the
                module.
        """
        try:
            return self._variables_dict[name].copy()
        except KeyError as err:
            raise self._missing_variable_error(name) from err

    def get_value(self, name: str) -> Any:
        """
        Get the value of the variable matching the given name:

        Args:
            name (str): The item to get by name of Variable.
                Hence, item=AgentVariable.name
        Returns:
            var (Any): A deep copy of the matching value
        Raises:
            KeyError: If the item was not found in the variables of the
                module.
        """
        try:
            return deepcopy(self._variables_dict[name].value)
        except KeyError as err:
            raise self._missing_variable_error(name) from err

    def set(self, name: str, value: Any, timestamp: Optional[float] = None):
        """
        Set any variable by using the name and send it to the data broker.

        The module's own variable is updated in place; a copy with the
        module's source is sent.

        Args:
            name (str): The item to get by name of Variable.
                Hence, item=AgentVariable.name
            value (Any): Any value to set to the Variable
            timestamp (float): The timestamp associated with the variable.
                If None, current environment time is used.

        Raises:
            KeyError: If the item was not found in the variables of the
                module.
        """
        try:
            var = self._variables_dict[name]
        except KeyError as err:
            raise self._missing_variable_error(name) from err
        var = self._update_relevant_values(
            variable=var, value=value, timestamp=timestamp
        )
        self.agent.data_broker.send_variable(
            variable=var.copy(update={"source": self.source}),
            copy=False,
        )

    def update_variables(
        self, variables: List[AgentVariable], timestamp: Optional[float] = None
    ):
        """
        Updates the given list of variables in the current data_broker.
        If a given Variable is not in the config of the module, an
        error is raised.
        NOTE: currently not used anywhere inside this file.

        Args:
            variables: List with agent_variables.
            timestamp: The timestamp associated with the variable.
                If None, current environment time is used.

        Raises:
            ValueError: If a variable is not part of the module's config.
        """
        if timestamp is None:
            timestamp = self.env.time

        for v in variables:
            if v.name not in self._variables_dict:
                raise ValueError(
                    f"'{self.__class__.__name__}' has "
                    f"no AgentVariable with the name '{v.name}' "
                    f"in the config."
                )
            self.set(name=v.name, value=v.value, timestamp=timestamp)

    ############################################################################
    # Private and or static class methods
    ############################################################################

    def _update_relevant_values(
        self, variable: AgentVariable, value: Any, timestamp: Optional[float] = None
    ):
        """
        Update the given variables fields
        with the given value (and possibly timestamp)
        Args:
            variable (AgentVariable): The variable to be updated (in place).
            value (Any): Any value to set to the Variable
            timestamp (float): The timestamp associated with the variable.
                If None, current environment time is used.

        Returns:
            AgentVariable: The updated variable
        """
        variable.value = value
        if timestamp is None:
            timestamp = self.env.time
        variable.timestamp = timestamp
        return variable

    def _callback_config_vars(self, variable: AgentVariable, name: str):
        """
        Callback to update the AgentVariables of the module defined in the
        config.

        Args:
            variable: Variable sent by data broker
            name: Name of the variable in own config
        """
        own_var = self._variables_dict[name]
        # Callbacks are registered with _unsafe_no_copy=True, so copy the value
        # here to avoid sharing it with other receivers.
        value = deepcopy(variable.value)
        own_var.set_value(value=value, validate=self.config.validate_incoming_values)
        own_var.timestamp = variable.timestamp

    @staticmethod
    def _copy_list_to_dict(ls: List[AgentVariable]):
        """Map each variable's name to the variable (no copies of variables)."""
        # pylint: disable=invalid-name
        return {var.name: var for var in ls}

    def get_results(self):
        """
        Returns results of this modules run.

        Override this method, if your module creates data that you would like to obtain
        after the run.

        Returns:
            Some form of results data, often in the form of a pandas DataFrame.
        """

    def cleanup_results(self):
        """
        Deletes all files this module created.

        Override this method, if your module creates e.g. results files etc.
        """