Coverage for agentlib/core/module.py: 95%
224 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-04-07 16:27 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-04-07 16:27 +0000
1"""This module contains the base AgentModule."""
3from __future__ import annotations
5import abc
6import json
7import logging
8from copy import deepcopy
9from typing import (
10 TYPE_CHECKING,
11 List,
12 Dict,
13 Union,
14 Any,
15 TypeVar,
16 Optional,
17 get_type_hints,
18 Type,
19)
21import pydantic
22from pydantic import field_validator, ConfigDict, BaseModel, Field, PrivateAttr
23from pydantic.json_schema import GenerateJsonSchema
24from pydantic_core import core_schema
26import agentlib.core.logging_ as agentlib_logging
27from agentlib.core import datamodels
28from agentlib.core.datamodels import (
29 AgentVariable,
30 Source,
31 AgentVariables,
32 AttrsToPydanticAdaptor,
33)
34from agentlib.core.environment import CustomSimpyEnvironment
35from agentlib.core.errors import ConfigurationError
36from agentlib.utils.fuzzy_matching import fuzzy_match, RAPIDFUZZ_IS_INSTALLED
37from agentlib.utils.validators import (
38 include_defaults_in_root,
39 update_default_agent_variable,
40 is_list_of_agent_variables,
41 is_valid_agent_var_config,
42)
44if TYPE_CHECKING:
45 # this avoids circular import
46 from agentlib.core import Agent
# Module-level logger, used for messages emitted while validating module
# configs (see BaseModuleConfig.merge_variables). Module instances log through
# their own per-instance ``self.logger`` created in BaseModule.__init__.
logger = logging.getLogger(name=__name__)
class BaseModuleConfig(BaseModel):
    """
    Pydantic data model for basic module configuration.

    Concrete modules subclass this and add their own fields. After the regular
    pydantic validation, ``__init__`` runs :meth:`merge_variables`, which
    processes every field holding an ``AgentVariable`` (or a list of them) and
    stores all resulting variables in a private list returned by
    :meth:`get_variables`. Instances are frozen once construction finishes.
    """

    # The type is relevant to load the correct module class.
    type: Union[str, Dict[str, str]] = Field(
        title="Type",
        description="The type of the Module. Used to find the Python-Object "
        "from all agentlib-core and plugin Module options. If a dict is given, "
        "it must contain the keys 'file' and 'class_name'. "
        "'file' is the filepath of a python file containing the Module. "
        "'class_name' is the name of the Module class within this file.",
    )
    # A module is uniquely identified in the MAS using agent_id and module_id.
    # The module_id should be unique inside one agent.
    # This is checked inside the agent-class.
    module_id: str = Field(
        description="The unique id of the module within an agent, "
        "used only to communicate within the agent."
    )
    validate_incoming_values: Optional[bool] = Field(
        default=True,
        title="Validate Incoming Values",
        description="If true, the validator of the AgentVariable value is called when "
        "receiving a new value from the DataBroker.",
    )
    log_level: Optional[str] = Field(
        default=None,
        description="The log level for this Module. "
        "Default uses the root-loggers level. "
        "Options: DEBUG; INFO; WARNING; ERROR; CRITICAL",
    )
    shared_variable_fields: List[str] = Field(
        default=[],
        description="A list of strings with each string being a field of the Modules configs. "
        "The field must be or contain an AgentVariable. If the field is added to this list, "
        "all shared attributes of the AgentVariables will be set to True.",
        validate_default=True,
    )
    # Aggregation of all instances of an AgentVariable in this Config
    _variables: AgentVariables = PrivateAttr(default=[])

    # The config given by the user to instantiate this class.
    # Will be stored to enable a valid overwriting of the
    # default config and to better restart modules.
    # Is also useful to debug validators and the general BaseModuleConfig.
    _user_config: dict = PrivateAttr(default=None)

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
        frozen=True,
    )

    def get_variables(self):
        """Return the private attribute with all AgentVariables."""
        return self._variables

    @classmethod
    def model_json_schema(cls, *args, **kwargs) -> dict:
        """
        Custom schema method to
        - add the JSON Schema for the custom attrs types (e.g. Source and
          AgentVariable) listed in ``datamodels.ATTRS_MODELS``,
        - put ``shared_variable_fields`` and ``log_level`` last, as they are
          generic options every module has. This better displays the relevant
          options of children classes in GUIs.

        Raises:
            ValueError: If a custom ``schema_generator`` is passed.
        """
        if "schema_generator" in kwargs:
            raise ValueError("Custom schema_generator is not supported for BaseModule.")

        class CustomGenerateJsonSchema(GenerateJsonSchema):
            """
            Serializes defaults of type AttrsToPydanticAdaptor (e.g. Source,
            AgentVariable), which are not JSON serializable by default.
            """

            def default_schema(self, schema: core_schema.WithDefaultSchema):
                _default = schema.get("default")
                if isinstance(_default, AttrsToPydanticAdaptor):
                    # Work on a shallow copy: ``schema`` is part of the model's
                    # cached core schema and must not be mutated in place.
                    schema = {**schema, "default": _default.json()}
                return super().default_schema(schema=schema)

        kwargs["schema_generator"] = CustomGenerateJsonSchema
        schema = super().model_json_schema(*args, **kwargs)
        definitions = schema.get("$defs", {})
        definitions_out = {}
        for class_name, metadata in definitions.items():
            if class_name in datamodels.ATTRS_MODELS:
                # Replace pydantic's schema with the one the attrs type provides.
                class_object: AttrsToPydanticAdaptor = getattr(datamodels, class_name)
                metadata = class_object.get_json_schema()
            definitions_out[class_name] = metadata
        if definitions_out:
            schema["$defs"] = definitions_out

        # Re-insert the generic options so they end up last in the (ordered)
        # properties dict.
        log_level = schema["properties"].pop("log_level")
        shared_variable_fields = schema["properties"].pop("shared_variable_fields")
        schema["properties"]["shared_variable_fields"] = shared_variable_fields
        schema["properties"]["log_level"] = log_level
        return schema

    @classmethod
    def check_if_variables_are_unique(cls, names):
        """Check that the given iterable of variable names has no duplicates.

        The iterable is not modified.

        Raises:
            ValueError: Listing every name that occurs more than once, each
                exactly once, in order of first repetition.
        """
        seen = set()
        duplicates = []
        for name in names:
            if name in seen and name not in duplicates:
                duplicates.append(name)
            seen.add(name)
        if duplicates:
            raise ValueError(
                f"{cls.__name__} contains variables with the same name. The "
                f"following appear at least twice: {', '.join(duplicates)}"
            )

    @field_validator("shared_variable_fields")
    @classmethod
    def check_valid_fields(cls, shared_variables_fields):
        """
        Check that every entry of shared_variable_fields is a field of this
        config.

        Raises:
            ConfigurationError: If an entry is not a field name.
        """
        wrong_public_fields = set(shared_variables_fields).difference(
            cls.model_fields.keys()
        )
        if wrong_public_fields:
            raise ConfigurationError(
                f"Public fields {wrong_public_fields} do not exist. Maybe you "
                f"misspelled them?"
            )
        return shared_variables_fields

    @field_validator("log_level")
    @classmethod
    def check_valid_level(cls, log_level: Optional[str]):
        """
        Check if the given log_level is valid and return it upper-cased.

        Raises:
            ValueError: If the logging library does not know the level name.
        """
        if log_level is None:
            return log_level
        log_level = log_level.upper()
        # getLevelName returns an int only for registered level names.
        if not isinstance(logging.getLevelName(log_level), int):
            raise ValueError(
                f"Given log level '{log_level}' is not "
                f"supported by logging library."
            )
        return log_level

    @classmethod
    def merge_variables(
        cls,
        pre_validated_instance: BaseModuleConfig,
        user_config: dict,
        agent_id: str,
        shared_variable_fields: List[str],
    ):
        """
        Merge, rigorously check and validate the input of
        all AgentVariables into the module.
        This function:

        - Collects all variables from fields holding an AgentVariable or a list
          of AgentVariables, combining the field default with the user's input
          and writing the result back onto ``pre_validated_instance``.
        - Checks for duplicate names (they would cause errors in the get()
          function).
        - Resets the module_id of variables whose source is another agent.

        Returns:
            list: All collected AgentVariables.

        Raises:
            ValueError: If two variables share a name.
        """
        _vars = []
        instance_fields = type(pre_validated_instance).model_fields
        # Extract all variables from fields
        for field_name, field in cls.model_fields.items():
            # If field is missing in values, validation of field was not
            # successful. Continue and pydantic will later raise the ValidationError
            if field_name not in instance_fields:
                continue

            pre_merged_attr = getattr(pre_validated_instance, field_name)
            make_shared = field_name in shared_variable_fields

            if isinstance(pre_merged_attr, AgentVariable):
                var = update_default_agent_variable(
                    default_var=field.default,
                    user_data=user_config.get(field_name, {}),
                    make_shared=make_shared,
                    agent_id=agent_id,
                    field_name=field_name,
                )
                _vars.append(var)
                setattr(pre_validated_instance, field_name, var)

            elif is_list_of_agent_variables(pre_merged_attr):
                # we need the type if plugins subclass the AgentVariable
                type_ = pre_merged_attr[0].__class__
                update_vars_with = [
                    conf
                    for conf in user_config.get(field_name, [])
                    if is_valid_agent_var_config(conf, field_name, type_)
                ]
                variables = include_defaults_in_root(
                    variables=update_vars_with,
                    field=field,
                    type_=type_,  # subtype of AgentVariable
                    make_shared=make_shared,
                    agent_id=agent_id,
                    field_name=field_name,
                )
                _vars.extend(variables)
                setattr(pre_validated_instance, field_name, variables)

        # First check if names exists more than once
        cls.check_if_variables_are_unique(names=[var.name for var in _vars])

        for _var in _vars:
            # case the agent id is a different agent
            if (_var.source.agent_id != agent_id) and (
                _var.source.module_id is not None
            ):
                logger.warning(
                    "Setting given module_id '%s' in variable '%s' to None. "
                    "You can not specify module_ids of other agents.",
                    _var.source.module_id,
                    _var.name,
                )
                _var.source = Source(agent_id=_var.source.agent_id)

        return _vars

    @classmethod
    def default(cls, field: str):
        """Return the default value of the given field of this config."""
        return cls.model_fields[field].get_default()

    def __init__(self, _agent_id, *args, **kwargs):
        """
        Validate the config, then merge all AgentVariables into it.

        Args:
            _agent_id: Id of the agent owning the module; passed on to
                :meth:`merge_variables`.

        Raises:
            pydantic.ValidationError: On invalid input (with field-name
                suggestions for unknown fields when rapidfuzz is installed).
            ValueError: If two variables share a name.
        """
        _user_config = kwargs.copy()
        try:
            super().__init__(*args, **kwargs)
        except pydantic.ValidationError as e:
            better_error = self._improve_extra_field_error_messages(e)
            raise better_error
        # Enable mutation. NOTE: model_config is the class-level dict, so this
        # affects all instances of the class; it must always be restored.
        self.model_config["frozen"] = False
        try:
            self._variables = self.__class__.merge_variables(
                pre_validated_instance=self,
                user_config=_user_config,
                agent_id=_agent_id,
                shared_variable_fields=self.shared_variable_fields,
            )
            self._user_config = _user_config
        finally:
            # Disable mutation, even if merging failed.
            self.model_config["frozen"] = True

    @classmethod
    def _improve_extra_field_error_messages(
        cls, e: pydantic.ValidationError
    ) -> pydantic.ValidationError:
        """Checks the validation errors for invalid fields and adds suggestions for
        correct field names to the error message.

        Returns ``e`` unchanged if rapidfuzz is not installed or no suggestion
        could be made; otherwise a new ValidationError with the same errors.
        """
        if not RAPIDFUZZ_IS_INSTALLED:
            return e

        error_list = e.errors()
        improved_any = False
        for error in error_list:
            if error["type"] != "extra_forbidden":
                continue

            suggestions = fuzzy_match(
                target=error["loc"][0], choices=cls.model_fields.keys()
            )
            if not suggestions:
                # "literal_error" requires an "expected" context entry, so
                # only convert the error when there is something to suggest.
                continue
            # change error type to literal because it allows for context
            error["type"] = "literal_error"
            # pydantic renders the ctx into the error message
            error["ctx"] = {
                "expected": f"a valid Field name. Field '{error['loc'][0]}' does "
                f"not exist. Did you mean any of {suggestions}?"
            }
            improved_any = True

        if not improved_any:
            return e
        return pydantic.ValidationError.from_exception_data(
            title=e.title, line_errors=error_list
        )
# Type variable standing for "BaseModuleConfig or any subclass of it"; used to
# annotate BaseModule.get_config_type and the BaseModule.config property.
BaseModuleConfigClass = TypeVar(
    "BaseModuleConfigClass",
    bound=BaseModuleConfig,
)
class BaseModule(abc.ABC):
    """
    Basic module used by any agent.
    Besides a common configuration, where ids
    and variables are defined, this class manages
    the setting and getting of variables and relevant
    attributes.
    """

    # pylint: disable=too-many-public-methods

    def __init__(self, *, config: dict, agent: Agent):
        """
        Args:
            config: The module config as a dict. ``config['module_id']`` is
                read before validation to name the module's logger.
            agent: The agent this module belongs to.
        """
        self._agent = agent
        self.logger = agentlib_logging.create_logger(
            env=self.env, name=f"{self.agent.id}/{config['module_id']}"
        )
        self.config = config  # evokes the config setter
        # Add process to environment
        self.env.process(self.process())
        self.register_callbacks()

    ############################################################################
    # Methods to inherit by subclasses
    ############################################################################

    @classmethod
    def get_config_type(cls) -> Type[BaseModuleConfigClass]:
        """Return the class annotated as ``config`` on the module.

        Returns None if the module has no ``config`` annotation.

        Raises:
            AttributeError: If the removed ``config_type`` attribute is used.
        """
        if hasattr(cls, "config_type"):
            raise AttributeError(
                "The 'config_type' attribute is deprecated and has been removed. "
                "Please use the following syntax to assign the config of your custom "
                f"module '{cls.__name__}': \n"
                "class MyModule(agentlib.BaseModule):\n"
                "    config: MyConfigClass\n"
            )

        return get_type_hints(cls).get("config")

    @abc.abstractmethod
    def register_callbacks(self):
        """Register the module's callbacks; called once at the end of __init__."""
        raise NotImplementedError("Needs to be implemented by derived modules")

    @abc.abstractmethod
    def process(self):
        """This abstract method must be implemented in order to sync the module
        with the other processes of the agent and the whole MAS."""
        raise NotImplementedError("Needs to be implemented by derived modules")

    def terminate(self):
        """
        Terminate all relevant processes of the module.
        This is necessary to correctly terminate an agent
        at runtime. Not all modules may need this, hence it is
        not an abstract method.
        """
        self.logger.info(
            "Successfully terminated module %s in agent %s", self.id, self.agent.id
        )

    ############################################################################
    # Properties
    ############################################################################

    @property
    def agent(self) -> Agent:
        """Get the agent this module is located in."""
        return self._agent

    @property
    def config(self) -> BaseModuleConfigClass:
        """
        The module config.

        Returns:
            BaseModuleConfigClass: Config of the type returned by get_config_type()
        """
        return self._config

    @config.setter
    def config(self, config: Union[BaseModuleConfig, dict, str]):
        """Validate and set a new config, then refresh variables and callbacks.

        ``config`` may be a dict or a JSON string decoding to a dict.
        NOTE(review): a BaseModuleConfig instance is not unpackable with
        ``**`` and would fail here despite the type hint — confirm intent.

        Raises:
            ConfigurationError: If the module has no ``config`` annotation.
        """
        config_type = self.get_config_type()
        if config_type is None:
            raise ConfigurationError(
                "The module has no valid config. Please make sure you "
                "specify the class attribute 'config' when writing your module."
            )
        if isinstance(config, str):
            config = json.loads(config)
        new_config = config_type(_agent_id=self.agent.id, **config)

        # Callbacks are keyed by each variable's alias and source. Remove the
        # ones registered for the previous config before its variables are
        # replaced; _register_variable_callbacks only sees the new variables,
        # so a changed alias/source would otherwise leave a stale callback.
        self._deregister_variable_callbacks()
        self._config = new_config

        # Update variables:
        self._variables_dict: Dict[str, AgentVariable] = self._copy_list_to_dict(
            self.config.get_variables()
        )
        # Now de-and re-register all callbacks:
        self._register_variable_callbacks()

        # Set log-level
        if self.config.log_level is not None:
            if not logging.getLogger().hasHandlers():
                _root_lvl_int = logging.getLogger().level
                _log_lvl_int = logging.getLevelName(self.config.log_level)
                if _log_lvl_int < _root_lvl_int:
                    self.logger.error(
                        "Log level '%s' is below root loggers level '%s'. "
                        "Without calling logging.basicConfig, "
                        "logs will not be printed.",
                        self.config.log_level,
                        logging.getLevelName(_root_lvl_int),
                    )
            self.logger.setLevel(self.config.log_level)

        # Call the after config update:
        self._after_config_update()

    def _after_config_update(self):
        """
        This function is called after the config of
        the module is updated.

        Overwrite this function to enable custom behaviour
        after your config is updated.
        For instance, a simulator may re-initialize it's model,
        or a coordinator in an ADMM-MAS send new settings to
        the participants.

        Returns nothing, the config is immutable
        """

    def _deregister_variable_callbacks(self):
        """
        De-register the config-variable callbacks of all variables currently
        stored in ``_variables_dict``. Does nothing before the first config
        has been set.
        """
        for name, var in getattr(self, "_variables_dict", {}).items():
            self.agent.data_broker.deregister_callback(
                alias=var.alias,
                source=var.source,
                callback=self._callback_config_vars,
                name=name,
            )

    def _register_variable_callbacks(self):
        """
        This functions de-registers and then re-registers
        callbacks for all variables of the module to
        update their specific values.
        """
        # Keep everything in THAT order!! All de-registrations happen before
        # any registration.
        self._deregister_variable_callbacks()
        for name, var in self._variables_dict.items():
            self.agent.data_broker.register_callback(
                alias=var.alias,
                source=var.source,
                callback=self._callback_config_vars,
                name=name,
                _unsafe_no_copy=True,
            )

    @property
    def env(self) -> CustomSimpyEnvironment:
        """Get the environment of the agent."""
        return self.agent.env

    @property
    def id(self) -> str:
        """Get the module's id"""
        return self.config.module_id

    @property
    def source(self) -> Source:
        """Get the source of the module,
        containing the agent and module id"""
        return Source(agent_id=self.agent.id, module_id=self.id)

    @property
    def variables(self) -> List[AgentVariable]:
        """Return copies of all variables as a list."""
        return [v.copy() for v in self._variables_dict.values()]

    ############################################################################
    # Get, set and updaters
    ############################################################################
    def _unknown_variable_error(self, name: str) -> KeyError:
        """Build the KeyError raised when ``name`` is not a config variable."""
        return KeyError(
            f"'{self.__class__.__name__}' has "
            f"no AgentVariable with the name '{name}' "
            f"in the configs variables."
        )

    def get(self, name: str) -> AgentVariable:
        """
        Get any variable matching the given name:

        Args:
            name (str): The item to get by name of Variable.
                Hence, item=AgentVariable.name
        Returns:
            var (AgentVariable): A copy of the matching variable
        Raises:
            KeyError: If the item was not found in the variables of the
                module.
        """
        try:
            return self._variables_dict[name].copy()
        except KeyError as err:
            raise self._unknown_variable_error(name) from err

    def get_value(self, name: str) -> Any:
        """
        Get the value of the variable matching the given name:

        Args:
            name (str): The item to get by name of Variable.
                Hence, item=AgentVariable.name
        Returns:
            var (Any): A deep copy of the matching value
        Raises:
            KeyError: If the item was not found in the variables of the
                module.
        """
        try:
            return deepcopy(self._variables_dict[name].value)
        except KeyError as err:
            raise self._unknown_variable_error(name) from err

    def set(self, name: str, value: Any, timestamp: float = None):
        """
        Set any variable by using the name and send it to the data broker
        with this module as source:

        Args:
            name (str): The item to get by name of Variable.
                Hence, item=AgentVariable.name
            value (Any): Any value to set to the Variable
            timestamp (float): The timestamp associated with the variable.
                If None, current environment time is used.

        Raises:
            KeyError: If the item was not found in the variables of the
                module.
        """
        try:
            var = self._variables_dict[name]
        except KeyError as err:
            raise self._unknown_variable_error(name) from err
        # The stored variable is updated in place, then a copy carrying this
        # module's source is sent.
        var = self._update_relevant_values(
            variable=var, value=value, timestamp=timestamp
        )
        self.agent.data_broker.send_variable(
            variable=var.copy(update={"source": self.source}),
            copy=False,
        )

    def update_variables(self, variables: List[AgentVariable], timestamp: float = None):
        """
        Updates the given list of variables in the current data_broker.
        If a given Variable is not in the config of the module, an
        error is raised.
        TODO: check if this is needed, we currently don't use it anywhere

        Args:
            variables: List with agent_variables.
            timestamp: The timestamp associated with the variable.
                If None, current environment time is used.

        Raises:
            ValueError: If a variable is not in the config of the module.
        """
        if timestamp is None:
            timestamp = self.env.time

        for v in variables:
            if v.name not in self._variables_dict:
                raise ValueError(
                    f"'{self.__class__.__name__}' has "
                    f"no AgentVariable with the name '{v.name}' "
                    f"in the config."
                )
            self.set(name=v.name, value=v.value, timestamp=timestamp)

    ############################################################################
    # Private and or static class methods
    ############################################################################

    def _update_relevant_values(
        self, variable: AgentVariable, value: Any, timestamp: float = None
    ):
        """
        Update the given variables fields in place
        with the given value (and possibly timestamp)
        Args:
            variable (AgentVariable): The variable to be updated.
            value (Any): Any value to set to the Variable
            timestamp (float): The timestamp associated with the variable.
                If None, current environment time is used.

        Returns:
            AgentVariable: The updated variable
        """
        # Update value
        variable.value = value
        # Update timestamp
        if timestamp is None:
            timestamp = self.env.time
        variable.timestamp = timestamp
        # Return updated variable
        return variable

    def _callback_config_vars(self, variable: AgentVariable, name: str):
        """
        Callback to update the AgentVariables of the module defined in the
        config.

        Args:
            variable: Variable sent by data broker
            name: Name of the variable in own config
        """
        own_var = self._variables_dict[name]
        # Deep-copy because callbacks are registered with _unsafe_no_copy=True.
        value = deepcopy(variable.value)
        own_var.set_value(value=value, validate=self.config.validate_incoming_values)
        own_var.timestamp = variable.timestamp

    @staticmethod
    def _copy_list_to_dict(ls: List[AgentVariable]):
        """Map each variable's name to the variable itself (no copies made)."""
        # pylint: disable=invalid-name
        return {var.name: var for var in ls}

    def get_results(self):
        """
        Returns results of this modules run.

        Override this method, if your module creates data that you would like to obtain
        after the run.

        Returns:
            Some form of results data, often in the form of a pandas DataFrame.
        """

    def cleanup_results(self):
        """
        Deletes all files this module created.

        Override this method, if your module creates e.g. results files etc.
        """