Coverage for agentlib/core/module.py: 95%
219 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-07 11:57 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-11-07 11:57 +0000
1"""This module contains the base AgentModule."""
3from __future__ import annotations
5import abc
6import json
7import logging
8from copy import deepcopy
9from typing import (
10 TYPE_CHECKING,
11 List,
12 Dict,
13 Union,
14 Any,
15 TypeVar,
16 Optional,
17 get_type_hints,
18 Type,
19)
21import pydantic
22from pydantic import field_validator, ConfigDict, BaseModel, Field, PrivateAttr
23from pydantic.json_schema import GenerateJsonSchema
24from pydantic_core import core_schema
26import agentlib.core.logging_ as agentlib_logging
27from agentlib.core import datamodels
28from agentlib.core.datamodels import (
29 AgentVariable,
30 Source,
31 AgentVariables,
32 AttrsToPydanticAdaptor,
33)
34from agentlib.core.environment import CustomSimpyEnvironment
35from agentlib.core.errors import ConfigurationError
36from agentlib.utils.fuzzy_matching import fuzzy_match, RAPIDFUZZ_IS_INSTALLED
37from agentlib.utils.validators import (
38 include_defaults_in_root,
39 update_default_agent_variable,
40 is_list_of_agent_variables,
41 is_valid_agent_var_config,
42)
44if TYPE_CHECKING:
45 # this avoids circular import
46 from agentlib.core import Agent
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class BaseModuleConfig(BaseModel):
    """
    Pydantic data model for basic module configuration.

    Module-specific configs subclass this model and add their own fields,
    e.g. AgentVariables. On construction, all AgentVariables found in the
    fields are merged with their field defaults, checked for unique names and
    collected (see :meth:`get_variables`). Instances are frozen after
    construction.
    """

    # The type is relevant to load the correct module class.
    type: Union[str, Dict[str, str]] = Field(
        title="Type",
        description="The type of the Module. Used to find the Python-Object "
        "from all agentlib-core and plugin Module options. If a dict is given, "
        "it must contain the keys 'file' and 'class_name'. "
        "'file' is the filepath of a python file containing the Module. "
        "'class_name' is the name of the Module class within this file.",
    )
    # A module is uniquely identified in the MAS using agent_id and module_id.
    # The module_id should be unique inside one agent.
    # This is checked inside the agent-class.
    module_id: str = Field(
        description="The unique id of the module within an agent, "
        "used only to communicate within the agent."
    )
    validate_incoming_values: Optional[bool] = Field(
        default=True,
        title="Validate Incoming Values",
        description="If true, the validator of the AgentVariable value is called when "
        "receiving a new value from the DataBroker.",
    )
    log_level: Optional[str] = Field(
        default=None,
        description="The log level for this Module. "
        "Default uses the root-loggers level. "
        "Options: DEBUG; INFO; WARNING; ERROR; CRITICAL",
    )
    shared_variable_fields: List[str] = Field(
        default=[],
        description="A list of strings with each string being a field of the Modules configs. "
        "The field must be or contain an AgentVariable. If the field is added to this list, "
        "all shared attributes of the AgentVariables will be set to True.",
        validate_default=True,
    )
    # Aggregation of all instances of an AgentVariable in this Config
    _variables: AgentVariables = PrivateAttr(default=[])

    # The config given by the user to instantiate this class.
    # Will be stored to enable a valid overwriting of the
    # default config and to better restart modules.
    # Is also useful to debug validators and the general BaseModuleConfig.
    _user_config: Optional[dict] = PrivateAttr(default=None)

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
        frozen=True,
    )

    def get_variables(self):
        """Return the private list of all AgentVariables of this config."""
        return self._variables

    @classmethod
    def model_json_schema(cls, *args, **kwargs) -> dict:
        """
        Create the JSON schema of the config.

        Compared to pydantic's default, this
        - adds the JSON schema of the custom attrs types (e.g. Source and
          AgentVariable) listed in ``datamodels.ATTRS_MODELS``, and
        - moves ``shared_variable_fields`` and ``log_level`` to the end of the
          properties, so GUIs show the relevant options of child classes first.

        Raises:
            ValueError: If a custom ``schema_generator`` is given, as this
                method needs its own generator.
        """
        if "schema_generator" in kwargs:
            raise ValueError("Custom schema_generator is not supported for BaseModule.")

        class CustomGenerateJsonSchema(GenerateJsonSchema):
            """
            This class is necessary, as default values of type
            AttrsToPydanticAdaptor (e.g. Source, AgentVariable) are
            not JSON serializable by default.
            """

            def default_schema(self, schema: core_schema.WithDefaultSchema):
                _default = schema.get("default")
                if isinstance(_default, AttrsToPydanticAdaptor):
                    # Work on a shallow copy: the given core schema belongs to
                    # the model class, and mutating it would permanently
                    # replace the real default object by a JSON string.
                    schema = {**schema, "default": _default.json()}
                return super().default_schema(schema=schema)

        kwargs["schema_generator"] = CustomGenerateJsonSchema
        schema = super().model_json_schema(*args, **kwargs)

        # Replace the definitions of attrs models by their own JSON schema, but
        # keep all other definitions so every "$ref" in the schema resolves.
        definitions = schema.get("$defs", {})
        definitions_out = {}
        for class_name, metadata in definitions.items():
            if class_name in datamodels.ATTRS_MODELS:
                class_object: AttrsToPydanticAdaptor = getattr(datamodels, class_name)
                metadata = class_object.get_json_schema()
            definitions_out[class_name] = metadata
        if definitions_out:
            schema["$defs"] = definitions_out

        # Re-insert the generic fields so they end up last, in this order.
        properties = schema["properties"]
        for field_name in ("shared_variable_fields", "log_level"):
            properties[field_name] = properties.pop(field_name)
        return schema

    @classmethod
    def check_if_variables_are_unique(cls, names):
        """
        Check that no variable name occurs more than once.

        Args:
            names: Iterable of variable names (strings). It is not modified.

        Raises:
            ValueError: If at least one name occurs more than once. Each
                duplicated name is listed once, in order of first occurrence.
        """
        counts = {}
        for name in names:
            counts[name] = counts.get(name, 0) + 1
        duplicates = [name for name, count in counts.items() if count > 1]
        if duplicates:
            raise ValueError(
                f"{cls.__name__} contains variables with the same name. The "
                f"following appear at least twice: {', '.join(duplicates)}"
            )

    @field_validator("shared_variable_fields")
    @classmethod
    def check_valid_fields(cls, shared_variables_fields):
        """
        Check that every entry of shared_variable_fields is a field of this
        config class.

        Raises:
            ConfigurationError: If any entry is not a field name.
        """
        wrong_public_fields = set(shared_variables_fields).difference(
            cls.model_fields.keys()
        )
        if wrong_public_fields:
            raise ConfigurationError(
                f"Public fields {wrong_public_fields} do not exist. Maybe you "
                f"misspelled them?"
            )
        return shared_variables_fields

    @field_validator("log_level")
    @classmethod
    def check_valid_level(cls, log_level: Optional[str]):
        """
        Check that the given log_level is known to the logging library.

        Returns:
            The upper-cased log level, or None if none was given.

        Raises:
            ValueError: If the logging library does not know the level.
        """
        if log_level is None:
            return log_level
        log_level = log_level.upper()
        # getLevelName returns an int only for registered level names.
        if not isinstance(logging.getLevelName(log_level), int):
            raise ValueError(
                f"Given log level '{log_level}' is not "
                f"supported by logging library."
            )
        return log_level

    @classmethod
    def merge_variables(
        cls,
        pre_validated_instance: BaseModuleConfig,
        user_config: dict,
        agent_id: str,
        shared_variable_fields: List[str],
    ):
        """
        Merge, rigorously check and validate the input of
        all AgentVariables into the module.

        This function:
        - collects all AgentVariables (single variables and lists of
          variables) from the fields, merges them with the field defaults and
          writes the merged variables back to the instance,
        - checks for duplicate names (they would cause errors in the get()
          function of the module),
        - resets the module_id of variables whose source is another agent.

        Args:
            pre_validated_instance: Instance already validated by pydantic.
            user_config: The keyword arguments given by the user.
            agent_id: Id of the agent the module belongs to.
            shared_variable_fields: Fields whose variables are made shared.

        Returns:
            list: All AgentVariables of the config.

        Raises:
            ValueError: If variable names are not unique.
        """
        _vars = []
        # Extract all variables from fields. The instance was validated
        # successfully, so every field of cls is set on it.
        for field_name, field in cls.model_fields.items():
            pre_merged_attr = getattr(pre_validated_instance, field_name)
            make_shared = field_name in shared_variable_fields

            if isinstance(pre_merged_attr, AgentVariable):
                var = update_default_agent_variable(
                    default_var=field.default,
                    user_data=user_config.get(field_name, {}),
                    make_shared=make_shared,
                    agent_id=agent_id,
                    field_name=field_name,
                )
                _vars.append(var)
                setattr(pre_validated_instance, field_name, var)

            elif is_list_of_agent_variables(pre_merged_attr):
                # we need the type, as plugins may subclass the AgentVariable
                type_ = pre_merged_attr[0].__class__
                update_vars_with = [
                    conf
                    for conf in user_config.get(field_name, [])
                    if is_valid_agent_var_config(conf, field_name, type_)
                ]
                variables = include_defaults_in_root(
                    variables=update_vars_with,
                    field=field,
                    type_=type_,  # subtype of AgentVariable
                    make_shared=make_shared,
                    agent_id=agent_id,
                    field_name=field_name,
                )
                _vars.extend(variables)
                setattr(pre_validated_instance, field_name, variables)

        # First check if names exist more than once
        cls.check_if_variables_are_unique(names=[var.name for var in _vars])

        for _var in _vars:
            # case the agent id is a different agent
            if _var.source.agent_id != agent_id and _var.source.module_id is not None:
                logger.warning(
                    "Setting given module_id '%s' in variable '%s' to None. "
                    "You can not specify module_ids of other agents.",
                    _var.source.module_id,
                    _var.name,
                )
                _var.source = Source(agent_id=_var.source.agent_id)

        return _vars

    @classmethod
    def default(cls, field: str):
        """Return the default value of the given field of this config class."""
        return cls.model_fields[field].get_default()

    def __init__(self, _agent_id, *args, **kwargs):
        """
        Validate the user config and merge all AgentVariables.

        Args:
            _agent_id: Id of the agent the module belongs to.
            *args, **kwargs: Forwarded to pydantic's BaseModel.__init__. The
                keyword arguments are also stored as the user config.

        Raises:
            pydantic.ValidationError: If the config is invalid. Errors for
                unknown fields are enriched with suggestions if possible.
        """
        _user_config = kwargs.copy()
        try:
            super().__init__(*args, **kwargs)
        except pydantic.ValidationError as e:
            better_error = self._improve_extra_field_error_messages(
                e, agent_id=_agent_id, module_id=_user_config.get("module_id")
            )
            raise better_error

        # The merged variables must be written back into the frozen model.
        # model_config belongs to the class and is shared by all instances,
        # so its previous state is restored even if merging fails.
        config = self.model_config
        was_frozen = config.get("frozen")
        config["frozen"] = False  # Enable mutation
        try:
            self._variables = self.__class__.merge_variables(
                pre_validated_instance=self,
                user_config=_user_config,
                agent_id=_agent_id,
                shared_variable_fields=self.shared_variable_fields,
            )
            self._user_config = _user_config
        finally:
            config["frozen"] = was_frozen  # Disable mutation again

    @classmethod
    def _improve_extra_field_error_messages(
        cls, e: pydantic.ValidationError, agent_id: str, module_id: Optional[str]
    ) -> pydantic.ValidationError:
        """
        Rebuild a validation error with a title naming the agent (and module).
        If rapidfuzz is installed, suggestions of similar field names are
        added to the errors for unknown fields.

        Args:
            e: The original validation error.
            agent_id: Id of the agent, used in the title.
            module_id: Id of the module, used in the title if given.

        Returns:
            pydantic.ValidationError: The rebuilt error.
        """
        error_list = e.errors()
        if module_id is not None:
            title = f"configuration of agent '{agent_id}' / module '{module_id}':"
        else:
            title = f"configuration of agent '{agent_id}':"
        if RAPIDFUZZ_IS_INSTALLED:
            for error in error_list:
                if error["type"] != "extra_forbidden":
                    continue
                suggestions = fuzzy_match(
                    target=error["loc"][0], choices=cls.model_fields.keys()
                )
                if not suggestions:
                    # Keep the original error: a literal_error without an
                    # "expected" context can not be rebuilt by pydantic.
                    continue
                # A literal_error is used because it carries a context, which
                # is rendered into the message ("Input should be <expected>").
                error["type"] = "literal_error"
                error["ctx"] = {
                    "expected": f"a valid Field name. Field '{error['loc'][0]}' does "
                    f"not exist. Did you mean any of {suggestions}?"
                }

        return pydantic.ValidationError.from_exception_data(
            title=title, line_errors=error_list
        )
# Type variable standing for BaseModuleConfig or any subclass of it; used in
# the annotations of BaseModule.get_config_type and BaseModule.config below.
BaseModuleConfigClass = TypeVar("BaseModuleConfigClass", bound=BaseModuleConfig)
class BaseModule(abc.ABC):
    """
    Basic module used by any agent.

    Besides a common configuration, where ids and variables are defined,
    this class manages the setting and getting of variables and relevant
    attributes. Derived modules declare their config class through the class
    annotation ``config: MyConfigClass`` and implement ``register_callbacks``
    and ``process``.
    """

    # pylint: disable=too-many-public-methods

    def __init__(self, *, config: dict, agent: Agent):
        """
        Set up the module inside the given agent.

        Args:
            config: The module config; must contain the key 'module_id'.
            agent: The agent this module is located in.
        """
        self._agent = agent
        self.logger = agentlib_logging.create_logger(
            env=self.env, name=f"{self.agent.id}/{config['module_id']}"
        )
        self.config = config  # evokes the config setter
        # Add process to environment
        self.env.process(self.process())
        self.register_callbacks()

    ############################################################################
    # Methods to inherit by subclasses
    ############################################################################

    @classmethod
    def get_config_type(cls) -> Optional[Type[BaseModuleConfigClass]]:
        """
        Return the config class of the module.

        The class is read from the ``config`` type annotation of the module
        class or its parents.

        Returns:
            The annotated config class, or None if there is no annotation.

        Raises:
            AttributeError: If the removed ``config_type`` attribute is used.
        """
        if hasattr(cls, "config_type"):
            raise AttributeError(
                "The 'config_type' attribute is deprecated and has been removed. "
                "Please use the following syntax to assign the config of your custom "
                f"module '{cls.__name__}': \n"
                "class MyModule(agentlib.BaseModule):\n"
                "    config: MyConfigClass\n"
            )
        return get_type_hints(cls).get("config")

    @abc.abstractmethod
    def register_callbacks(self):
        """Must be implemented by derived modules. Called once at the end of
        __init__, after the config was set and the process was added."""
        raise NotImplementedError("Needs to be implemented by derived modules")

    @abc.abstractmethod
    def process(self):
        """This abstract method must be implemented in order to sync the module
        with the other processes of the agent and the whole MAS.
        The returned object is passed to ``env.process`` in __init__."""
        raise NotImplementedError("Needs to be implemented by derived modules")

    def terminate(self):
        """
        Terminate all relevant processes of the module.
        This is necessary to correctly terminate an agent
        at runtime. Not all modules may need this, hence it is
        not an abstract method. The default only logs the termination.
        """
        self.logger.info(
            "Successfully terminated module %s in agent %s", self.id, self.agent.id
        )

    ############################################################################
    # Properties
    ############################################################################

    @property
    def agent(self) -> Agent:
        """Get the agent this module is located in."""
        return self._agent

    @property
    def config(self) -> BaseModuleConfigClass:
        """
        The module config.

        Returns:
            BaseModuleConfigClass: Config of the type returned by get_config_type()
        """
        return self._config

    @config.setter
    def config(self, config: Union[BaseModuleConfig, dict, str]):
        """
        Set a new config.

        The config may be a dict, a JSON string or an existing
        BaseModuleConfig, which is re-validated from its stored user config.
        Afterwards the module variables are replaced, the data broker
        callbacks are moved from the old to the new variables, the log level
        is applied and ``_after_config_update`` is called.

        Raises:
            ConfigurationError: If the module class has no ``config``
                annotation.
        """
        config_type = self.get_config_type()
        if config_type is None:
            raise ConfigurationError(
                "The module has no valid config. Please make sure you "
                "specify the class attribute 'config' when writing your module."
            )
        if isinstance(config, str):
            config = json.loads(config)
        elif isinstance(config, BaseModuleConfig):
            # Models can not be unpacked with **, use the stored user input.
            config = config._user_config
        self._config = config_type(_agent_id=self.agent.id, **config)

        # Remove the callbacks of the previous variables (none on first call),
        # as aliases or sources may differ in the new config.
        self._deregister_config_var_callbacks(getattr(self, "_variables_dict", {}))
        # Update variables:
        self._variables_dict: Dict[str, AgentVariable] = self._copy_list_to_dict(
            self.config.get_variables()
        )
        # Now de-and re-register all callbacks of the new variables:
        self._register_variable_callbacks()

        # Set log-level
        if self.config.log_level is not None:
            if not logging.getLogger().hasHandlers():
                _root_lvl_int = logging.getLogger().level
                _log_lvl_int = logging.getLevelName(self.config.log_level)
                if _log_lvl_int < _root_lvl_int:
                    self.logger.error(
                        "Log level '%s' is below root loggers level '%s'. "
                        "Without calling logging.basicConfig, "
                        "logs will not be printed.",
                        self.config.log_level,
                        logging.getLevelName(_root_lvl_int),
                    )
            self.logger.setLevel(self.config.log_level)

        # Call the after config update:
        self._after_config_update()

    def _after_config_update(self):
        """
        This function is called after the config of
        the module is updated.

        Overwrite this function to enable custom behaviour
        after your config is updated.
        For instance, a simulator may re-initialize its model,
        or a coordinator in an ADMM-MAS send new settings to
        the participants.

        Returns nothing, the config is immutable
        """

    def _deregister_config_var_callbacks(self, variables: Dict[str, AgentVariable]):
        """Deregister ``_callback_config_vars`` for the given variables
        (mapping of config name to variable) from the data broker."""
        for name, var in variables.items():
            self.agent.data_broker.deregister_callback(
                alias=var.alias,
                source=var.source,
                callback=self._callback_config_vars,
                name=name,
            )

    def _register_variable_callbacks(self):
        """
        This function de-registers and then re-registers
        callbacks for all variables of the module to
        update their specific values.
        """
        # Keep everything in THAT order!! Deregistering first presumably keeps
        # a callback from being registered twice for the same variable.
        self._deregister_config_var_callbacks(self._variables_dict)
        for name, var in self._variables_dict.items():
            self.agent.data_broker.register_callback(
                alias=var.alias,
                source=var.source,
                callback=self._callback_config_vars,
                name=name,
                # _callback_config_vars deep-copies the received value itself
                _unsafe_no_copy=True,
            )

    @property
    def env(self) -> CustomSimpyEnvironment:
        """Get the environment of the agent."""
        return self.agent.env

    @property
    def id(self) -> str:
        """Get the module's id"""
        return self.config.module_id

    @property
    def source(self) -> Source:
        """Get the source of the module,
        containing the agent and module id"""
        return Source(agent_id=self.agent.id, module_id=self.id)

    @property
    def variables(self) -> List[AgentVariable]:
        """Return copies of all AgentVariables of the module as a list."""
        return [v.copy() for v in self._variables_dict.values()]

    ############################################################################
    # Get, set and updaters
    ############################################################################
    def get(self, name: str) -> AgentVariable:
        """
        Get a copy of the variable matching the given name.

        Args:
            name (str): The item to get by name of Variable.
                Hence, item=AgentVariable.name
        Returns:
            var (AgentVariable): The matching variable
        Raises:
            KeyError: If the item was not found in the variables of the
                module.
        """
        try:
            return self._variables_dict[name].copy()
        except KeyError as err:
            raise self._unknown_variable_error(name) from err

    def get_value(self, name: str) -> Any:
        """
        Get a deep copy of the value of the variable matching the given name.

        Args:
            name (str): The item to get by name of Variable.
                Hence, item=AgentVariable.name
        Returns:
            var (Any): The matching value
        Raises:
            KeyError: If the item was not found in the variables of the
                module.
        """
        try:
            return deepcopy(self._variables_dict[name].value)
        except KeyError as err:
            raise self._unknown_variable_error(name) from err

    def set(self, name: str, value: Any, timestamp: Optional[float] = None):
        """
        Set the value of a variable by name and send it to the data broker.

        The module's own variable is updated in place; a copy with the module
        as source is sent.

        Args:
            name (str): The item to set by name of Variable.
                Hence, item=AgentVariable.name
            value (Any): Any value to set to the Variable
            timestamp (float): The timestamp associated with the variable.
                If None, current environment time is used.

        Raises:
            KeyError: If the item was not found in the variables of the
                module.
        """
        try:
            var = self._variables_dict[name]
        except KeyError as err:
            raise self._unknown_variable_error(name) from err
        var = self._update_relevant_values(
            variable=var, value=value, timestamp=timestamp
        )
        self.agent.data_broker.send_variable(
            variable=var.copy(update={"source": self.source}),
            copy=False,
        )

    ############################################################################
    # Private and or static class methods
    ############################################################################

    def _unknown_variable_error(self, name: str) -> KeyError:
        """Build the KeyError raised when no variable with ``name`` exists."""
        return KeyError(
            f"'{self.__class__.__name__}' has "
            f"no AgentVariable with the name '{name}' "
            f"in the configs variables."
        )

    def _update_relevant_values(
        self, variable: AgentVariable, value: Any, timestamp: Optional[float] = None
    ):
        """
        Update the value and timestamp of the given variable in place.

        Args:
            variable (AgentVariable): The variable to be updated.
            value (Any): Any value to set to the Variable
            timestamp (float): The timestamp associated with the variable.
                If None, current environment time is used.

        Returns:
            AgentVariable: The updated variable (the same object)
        """
        variable.value = value
        if timestamp is None:
            timestamp = self.env.time
        variable.timestamp = timestamp
        return variable

    def _callback_config_vars(self, variable: AgentVariable, name: str):
        """
        Callback to update the AgentVariables of the module defined in the
        config. The received value is deep-copied and validated if
        ``config.validate_incoming_values`` is set.

        Args:
            variable: Variable sent by data broker
            name: Name of the variable in own config
        """
        own_var = self._variables_dict[name]
        value = deepcopy(variable.value)
        own_var.set_value(value=value, validate=self.config.validate_incoming_values)
        own_var.timestamp = variable.timestamp

    @staticmethod
    def _copy_list_to_dict(ls: List[AgentVariable]):
        """Map variable names to the (not copied) variables of the list."""
        # pylint: disable=invalid-name
        return {var.name: var for var in ls}

    def get_results(self):
        """
        Returns results of this modules run.

        Override this method, if your module creates data that you would like to obtain
        after the run. The default returns None.

        Returns:
            Some form of results data, often in the form of a pandas DataFrame.
        """

    def cleanup_results(self):
        """
        Deletes all files this module created.

        Override this method, if your module creates e.g. results files etc.
        The default does nothing.
        """