Source code for agentlib.core.module

"""This module contains the base AgentModule."""

from __future__ import annotations

import abc
import json
import logging
from copy import deepcopy
from typing import (
    TYPE_CHECKING,
    List,
    Dict,
    Union,
    Any,
    TypeVar,
    Optional,
    get_type_hints,
    Type,
)

import pydantic
from pydantic import field_validator, ConfigDict, BaseModel, Field, PrivateAttr
from pydantic.json_schema import GenerateJsonSchema
from pydantic_core import core_schema

import agentlib.core.logging_ as agentlib_logging
from agentlib.core import datamodels
from agentlib.core.datamodels import (
    AgentVariable,
    Source,
    AgentVariables,
    AttrsToPydanticAdaptor,
)
from agentlib.core.environment import CustomSimpyEnvironment
from agentlib.core.errors import ConfigurationError
from agentlib.utils.fuzzy_matching import fuzzy_match, RAPIDFUZZ_IS_INSTALLED
from agentlib.utils.validators import (
    include_defaults_in_root,
    update_default_agent_variable,
    is_list_of_agent_variables,
    is_valid_agent_var_config,
)

if TYPE_CHECKING:
    # this avoids circular import
    from agentlib.core import Agent


logger = logging.getLogger(__name__)


class BaseModuleConfig(BaseModel):
    """
    Pydantic data model for basic module configuration
    """

    # The type is relevant to load the correct module class.
    type: Union[str, Dict[str, str]] = Field(
        title="Type",
        description="The type of the Module. Used to find the Python-Object "
        "from all agentlib-core and plugin Module options. If a dict is given, "
        "it must contain the keys 'file' and 'class_name'. "
        "'file' is the filepath of a python file containing the Module. "
        "'class_name' is the name of the Module class within this file.",
    )
    # A module is uniquely identified in the MAS using agent_id and module_id.
    # The module_id should be unique inside one agent.
    # This is checked inside the agent-class.
    module_id: str = Field(
        description="The unique id of the module within an agent, "
        "used only to communicate within the agent."
    )
    validate_incoming_values: Optional[bool] = Field(
        default=True,
        title="Validate Incoming Values",
        description="If true, the validator of the AgentVariable value is called when "
        "receiving a new value from the DataBroker.",
    )
    log_level: Optional[str] = Field(
        default=None,
        description="The log level for this Module. "
        "Default uses the root logger's level. "
        "Options: DEBUG; INFO; WARNING; ERROR; CRITICAL",
    )
    shared_variable_fields: List[str] = Field(
        default=[],
        description="A list of strings with each string being a field of the module's config. "
        "The field must be or contain an AgentVariable. If the field is added to this list, "
        "all shared attributes of the AgentVariables will be set to True.",
        validate_default=True,
    )
    # Aggregation of all instances of an AgentVariable in this Config
    _variables: AgentVariables = PrivateAttr(default=[])
    # The config given by the user to instantiate this class.
    # Will be stored to enable a valid overwriting of the
    # default config and to better restart modules.
    # Is also useful to debug validators and the general BaseModuleConfig.
    _user_config: dict = PrivateAttr(default=None)

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
        frozen=True,
    )

    def get_variables(self):
        """Return the private attribute with all AgentVariables"""
        return self._variables

    @classmethod
    def model_json_schema(cls, *args, **kwargs) -> dict:
        """
        Custom schema method to
        - Add JSON Schema for custom attrs types Source and AgentVariable
        - put log_level last, as it is the only optional field of the module config.
        Used to better display relevant options of children classes in GUIs.
        """
        if "schema_generator" in kwargs:
            raise ValueError("Custom schema_generator is not supported for BaseModule.")

        class CustomGenerateJsonSchema(GenerateJsonSchema):
            """
            This class is necessary, as objects of the type AttrsToPydanticAdaptor
            (e.g. Source, AgentVariable) are not json serializable by default.
            """

            def default_schema(self, schema: core_schema.WithDefaultSchema):
                if "default" in schema:
                    _default = schema["default"]
                    if isinstance(_default, AttrsToPydanticAdaptor):
                        schema["default"] = _default.json()
                return super().default_schema(schema=schema)

        kwargs["schema_generator"] = CustomGenerateJsonSchema
        schema = super().model_json_schema(*args, **kwargs)
        definitions = schema.get("$defs", {})
        definitions_out = {}
        for class_name, metadata in definitions.items():
            if class_name in datamodels.ATTRS_MODELS:
                class_object: AttrsToPydanticAdaptor = getattr(datamodels, class_name)
                metadata = class_object.get_json_schema()
            definitions_out[class_name] = metadata
        if definitions_out:
            schema["$defs"] = definitions_out
        log_level = schema["properties"].pop("log_level")
        shared_variable_fields = schema["properties"].pop("shared_variable_fields")
        schema["properties"]["shared_variable_fields"] = shared_variable_fields
        schema["properties"]["log_level"] = log_level
        return schema

    @classmethod
    def check_if_variables_are_unique(cls, names):
        """Check that all names in the given iterable of AgentVariable names are unique."""
        if len(names) != len(set(names)):
            for name in set(names.copy()):
                names.remove(name)
            raise ValueError(
                f"{cls.__name__} contains variables with the same name. The "
                f"following appear at least twice: {', '.join(names)}"
            )
[docs] @field_validator("shared_variable_fields") @classmethod def check_valid_fields(cls, shared_variables_fields): """ Check if the shared_variables_fields are valid fields. """ wrong_public_fields = set(shared_variables_fields).difference( cls.model_fields.keys() ) if wrong_public_fields: raise ConfigurationError( f"Public fields {wrong_public_fields} do not exist. Maybe you " f"misspelled them?" ) return shared_variables_fields
[docs] @field_validator("log_level") @classmethod def check_valid_level(cls, log_level: str): """ Check if the given log_level is valid """ if log_level is None: return log_level log_level = log_level.upper() if not isinstance(logging.getLevelName(log_level), int): raise ValueError( f"Given log level '{log_level}' is not " f"supported by logging library." ) return log_level

    @classmethod
    def merge_variables(
        cls,
        pre_validated_instance: BaseModuleConfig,
        user_config: dict,
        agent_id: str,
        shared_variable_fields: List[str],
    ):
        """
        Merge, rigorously check and validate the input of all
        AgentVariables into the module.

        This function:
        - Collects all variables
        - Checks for duplicate names (these would cause errors in the get() function)
        """
        _vars = []
        # Extract all variables from fields
        for field_name, field in cls.model_fields.items():
            # If field is missing in values, validation of field was not
            # successful. Continue and pydantic will later raise the ValidationError
            if field_name not in pre_validated_instance.model_fields:
                continue
            pre_merged_attr = pre_validated_instance.__getattribute__(field_name)
            # we need the type if plugins subclass the AgentVariable
            if isinstance(pre_merged_attr, AgentVariable):
                update_var_with = user_config.get(field_name, {})
                make_shared = field_name in shared_variable_fields
                var = update_default_agent_variable(
                    default_var=field.default,
                    user_data=update_var_with,
                    make_shared=make_shared,
                    agent_id=agent_id,
                    field_name=field_name,
                )
                _vars.append(var)
                pre_validated_instance.__setattr__(field_name, var)
            elif is_list_of_agent_variables(pre_merged_attr):
                user_config_var_dicts = user_config.get(field_name, [])
                type_ = pre_merged_attr[0].__class__
                update_vars_with = [
                    conf
                    for conf in user_config_var_dicts
                    if is_valid_agent_var_config(conf, field_name, type_)
                ]
                make_shared = field_name in shared_variable_fields
                variables = include_defaults_in_root(
                    variables=update_vars_with,
                    field=field,
                    type_=type_,  # subtype of AgentVariable
                    make_shared=make_shared,
                    agent_id=agent_id,
                    field_name=field_name,
                )
                _vars.extend(variables)
                pre_validated_instance.__setattr__(field_name, variables)
        # Extract names
        variable_names = [var.name for var in _vars]
        # First check if names exist more than once
        cls.check_if_variables_are_unique(names=variable_names)

        for _var in _vars:
            # case: the variable's source points to a different agent
            if (_var.source.agent_id != agent_id) and (
                _var.source.module_id is not None
            ):
                logger.warning(
                    "Setting given module_id '%s' in variable '%s' to None. "
                    "You can not specify module_ids of other agents.",
                    _var.source.module_id,
                    _var.name,
                )
                _var.source = Source(agent_id=_var.source.agent_id)
        return _vars
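
    # Illustrative sketch (not part of the library): merge_variables combines the
    # field defaults declared on a config subclass with the user-supplied config.
    # Assuming a hypothetical config with an AgentVariable field "setpoint" and a
    # list-of-AgentVariables field "inputs", a user config handled by this method
    # could look like:
    #     {
    #         "module_id": "controller",
    #         "type": "my_plugin.pid",
    #         "shared_variable_fields": ["inputs"],
    #         "setpoint": {"value": 21.0, "alias": "room_setpoint"},
    #         "inputs": [{"name": "T_room", "value": 290.15}],
    #     }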

    @classmethod
    def default(cls, field: str):
        """Return the default value of the given config field."""
        return cls.model_fields[field].get_default()

    def __init__(self, _agent_id, *args, **kwargs):
        _user_config = kwargs.copy()
        try:
            super().__init__(*args, **kwargs)
        except pydantic.ValidationError as e:
            better_error = self._improve_extra_field_error_messages(e)
            raise better_error
        # Enable mutation
        self.model_config["frozen"] = False
        self._variables = self.__class__.merge_variables(
            pre_validated_instance=self,
            user_config=_user_config,
            agent_id=_agent_id,
            shared_variable_fields=self.shared_variable_fields,
        )
        self._user_config = _user_config
        # Disable mutation
        self.model_config["frozen"] = True

    @classmethod
    def _improve_extra_field_error_messages(
        cls, e: pydantic.ValidationError
    ) -> pydantic.ValidationError:
        """Checks the validation errors for invalid fields and adds suggestions
        for correct field names to the error message."""
        if not RAPIDFUZZ_IS_INSTALLED:
            return e

        error_list = e.errors()
        for error in error_list:
            if not error["type"] == "extra_forbidden":
                continue
            # change error type to literal because it allows for context
            error["type"] = "literal_error"
            # pydantic automatically prints the __dict__ of an error, so it is
            # sufficient to just assign the suggestions to an arbitrary attribute of
            # the error
            suggestions = fuzzy_match(
                target=error["loc"][0], choices=cls.model_fields.keys()
            )
            if suggestions:
                error["ctx"] = {
                    "expected": f"a valid Field name. Field '{error['loc'][0]}' does "
                    f"not exist. Did you mean any of {suggestions}?"
                }
        return pydantic.ValidationError.from_exception_data(
            title=e.title, line_errors=error_list
        )

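
# Minimal sketch of a custom module config, assuming the field names below are
# chosen by the user (they are not part of agentlib itself). Defaults given as
# AgentVariable instances are collected into ``_variables`` and can be
# overridden per field in the user config, as handled by ``merge_variables``.
class ExampleModuleConfig(BaseModuleConfig):
    """Hypothetical config for a module with one measured input and one plain parameter."""

    t_room: AgentVariable = AgentVariable(
        name="t_room", value=293.15, description="Measured room temperature"
    )
    gain: float = Field(default=1.0, description="Plain (non-AgentVariable) parameter")
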
BaseModuleConfigClass = TypeVar("BaseModuleConfigClass", bound=BaseModuleConfig)

class BaseModule(abc.ABC):
    """
    Basic module used by any agent.
    Besides a common configuration, where ids and variables are defined,
    this class manages the setting and getting of variables and
    relevant attributes.
    """

    # pylint: disable=too-many-public-methods

    def __init__(self, *, config: dict, agent: Agent):
        self._agent = agent
        self.logger = agentlib_logging.create_logger(
            env=self.env, name=f"{self.agent.id}/{config['module_id']}"
        )
        self.config = config  # evokes the config setter
        # Add process to environment
        self.env.process(self.process())
        self.register_callbacks()

    ############################################################################
    # Methods to inherit by subclasses
    ############################################################################

    @classmethod
    def get_config_type(cls) -> Type[BaseModuleConfigClass]:
        """Return the config class annotated on the module's 'config' attribute."""
        if hasattr(cls, "config_type"):
            raise AttributeError(
                "The 'config_type' attribute is deprecated and has been removed. "
                "Please use the following syntax to assign the config of your custom "
                f"module '{cls.__name__}': \n"
                "class MyModule(agentlib.BaseModule):\n"
                "    config: MyConfigClass\n"
            )
        return get_type_hints(cls).get("config")

    @abc.abstractmethod
    def register_callbacks(self):
        """Register all callbacks of this module, e.g. to the agent's data broker."""
        raise NotImplementedError("Needs to be implemented by derived modules")

    @abc.abstractmethod
    def process(self):
        """This abstract method must be implemented in order to sync the module
        with the other processes of the agent and the whole MAS."""
        raise NotImplementedError("Needs to be implemented by derived modules")

    def terminate(self):
        """
        Terminate all relevant processes of the module.
        This is necessary to correctly terminate an agent at runtime.
        Not all modules may need this, hence it is not an abstract method.
        """
        self.logger.info(
            "Successfully terminated module %s in agent %s", self.id, self.agent.id
        )

    ############################################################################
    # Properties
    ############################################################################

    @property
    def agent(self) -> Agent:
        """Get the agent this module is located in."""
        return self._agent

    @property
    def config(self) -> BaseModuleConfigClass:
        """
        The module config.

        Returns:
            BaseModuleConfigClass: Config of type self.config_type
        """
        return self._config

    @config.setter
    def config(self, config: Union[BaseModuleConfig, dict, str]):
        """Set a new config"""
        if self.get_config_type() is None:
            raise ConfigurationError(
                "The module has no valid config. Please make sure you "
                "specify the class attribute 'config' when writing your module."
            )
        if isinstance(config, str):
            config = json.loads(config)
        self._config = self.get_config_type()(_agent_id=self.agent.id, **config)
        # Update variables:
        self._variables_dict: Dict[str, AgentVariable] = self._copy_list_to_dict(
            self.config.get_variables()
        )
        # Now de- and re-register all callbacks:
        self._register_variable_callbacks()
        # Set log-level
        if self.config.log_level is not None:
            if not logging.getLogger().hasHandlers():
                _root_lvl_int = logging.getLogger().level
                _log_lvl_int = logging.getLevelName(self.config.log_level)
                if _log_lvl_int < _root_lvl_int:
                    self.logger.error(
                        "Log level '%s' is below the root logger's level '%s'. "
                        "Without calling logging.basicConfig, "
                        "logs will not be printed.",
                        self.config.log_level,
                        logging.getLevelName(_root_lvl_int),
                    )
            self.logger.setLevel(self.config.log_level)
        # Call the after config update:
        self._after_config_update()

    def _after_config_update(self):
        """
        This function is called after the config of the module is updated.
        Overwrite this function to enable custom behaviour after your config
        is updated. For instance, a simulator may re-initialize its model, or
        a coordinator in an ADMM-MAS may send new settings to the participants.

        Returns nothing, the config is immutable.
        """

    def _register_variable_callbacks(self):
        """
        This function de-registers and then re-registers callbacks for all
        variables of the module to update their specific values.
        """
        # Keep everything in THAT order!!
        for name, var in self._variables_dict.items():
            self.agent.data_broker.deregister_callback(
                alias=var.alias,
                source=var.source,
                callback=self._callback_config_vars,
                name=name,
            )
        for name, var in self._variables_dict.items():
            self.agent.data_broker.register_callback(
                alias=var.alias,
                source=var.source,
                callback=self._callback_config_vars,
                name=name,
                _unsafe_no_copy=True,
            )

    @property
    def env(self) -> CustomSimpyEnvironment:
        """Get the environment of the agent."""
        return self.agent.env

    @property
    def id(self) -> str:
        """Get the module's id"""
        return self.config.module_id

    @property
    def source(self) -> Source:
        """Get the source of the module, containing the agent and module id"""
        return Source(agent_id=self.agent.id, module_id=self.id)

    @property
    def variables(self) -> List[AgentVariable]:
        """Return copies of all variables as a list."""
        return [v.copy() for v in self._variables_dict.values()]

    ############################################################################
    # Get, set and updaters
    ############################################################################

    def get(self, name: str) -> AgentVariable:
        """
        Get any variable matching the given name:

        Args:
            name (str): The item to get by name of Variable.
                Hence, item=AgentVariable.name

        Returns:
            var (AgentVariable): The matching variable

        Raises:
            KeyError: If the item was not found in the variables of the module.
        """
        try:
            return self._variables_dict[name].copy()
        except KeyError as err:
            raise KeyError(
                f"'{self.__class__.__name__}' has "
                f"no AgentVariable with the name '{name}' "
                f"in the config's variables."
            ) from err

    def get_value(self, name: str) -> Any:
        """
        Get the value of the variable matching the given name:

        Args:
            name (str): The item to get by name of Variable.
                Hence, item=AgentVariable.name

        Returns:
            var (Any): The matching value

        Raises:
            KeyError: If the item was not found in the variables of the module.
        """
        try:
            return deepcopy(self._variables_dict[name].value)
        except KeyError as err:
            raise KeyError(
                f"'{self.__class__.__name__}' has "
                f"no AgentVariable with the name '{name}' "
                f"in the config's variables."
            ) from err

    def set(self, name: str, value: Any, timestamp: float = None):
        """
        Set any variable by using the name:

        Args:
            name (str): The item to set by name of Variable.
                Hence, item=AgentVariable.name
            value (Any): Any value to set to the Variable
            timestamp (float): The timestamp associated with the variable.
                If None, the current environment time is used.

        Raises:
            KeyError: If the item was not found in the variables of the module.
        """
        # var = self.get(name)
        var = self._variables_dict[name]
        var = self._update_relevant_values(
            variable=var, value=value, timestamp=timestamp
        )
        self.agent.data_broker.send_variable(
            variable=var.copy(update={"source": self.source}),
            copy=False,
        )

    def update_variables(self, variables: List[AgentVariable], timestamp: float = None):
        """
        Updates the given list of variables in the current data_broker.
        If a given Variable is not in the config of the module, an error is raised.

        TODO: check if this is needed, we currently don't use it anywhere

        Args:
            variables: List with agent_variables.
            timestamp: The timestamp associated with the variables.
                If None, the current environment time is used.
        """
        if timestamp is None:
            timestamp = self.env.time
        for v in variables:
            if v.name not in self._variables_dict:
                raise ValueError(
                    f"'{self.__class__.__name__}' has "
                    f"no AgentVariable with the name '{v.name}' "
                    f"in the config."
                )
            self.set(name=v.name, value=v.value, timestamp=timestamp)

    ############################################################################
    # Private and/or static class methods
    ############################################################################

    def _update_relevant_values(
        self, variable: AgentVariable, value: Any, timestamp: float = None
    ):
        """
        Update the given variable's fields with the given value
        (and possibly timestamp).

        Args:
            variable (AgentVariable): The variable to be updated.
            value (Any): Any value to set to the Variable
            timestamp (float): The timestamp associated with the variable.
                If None, the current environment time is used.

        Returns:
            AgentVariable: The updated variable
        """
        # Update value
        variable.value = value
        # Update timestamp
        if timestamp is None:
            timestamp = self.env.time
        variable.timestamp = timestamp
        # Return updated variable
        return variable

    def _callback_config_vars(self, variable: AgentVariable, name: str):
        """
        Callback to update the AgentVariables of the module defined in the config.

        Args:
            variable: Variable sent by the data broker
            name: Name of the variable in the module's own config
        """
        own_var = self._variables_dict[name]
        value = deepcopy(variable.value)
        own_var.set_value(value=value, validate=self.config.validate_incoming_values)
        own_var.timestamp = variable.timestamp

    @staticmethod
    def _copy_list_to_dict(ls: List[AgentVariable]):
        # pylint: disable=invalid-name
        return {var.name: var for var in ls.copy()}

    def get_results(self):
        """
        Returns results of this module's run.

        Override this method, if your module creates data that you would like
        to obtain after the run.

        Returns:
            Some form of results data, often in the form of a pandas DataFrame.
        """

    def cleanup_results(self):
        """
        Deletes all files this module created.

        Override this method, if your module creates e.g. results files etc.
        """
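

# Minimal usage sketch, not part of agentlib: a hypothetical module built on
# BaseModule, using the illustrative ExampleModuleConfig above. It shows the
# two abstract methods every module must implement and the get/set/get_value
# helpers; the callback signature follows the pattern used elsewhere in this
# file and is an assumption, not a guaranteed API.
class ExampleModule(BaseModule):
    """Hypothetical module that logs and re-sends the room temperature periodically."""

    config: ExampleModuleConfig

    def register_callbacks(self):
        # React to new values of "t_room" arriving via the agent's data broker.
        var = self.get("t_room")
        self.agent.data_broker.register_callback(
            alias=var.alias, source=var.source, callback=self._on_t_room
        )

    def _on_t_room(self, variable: AgentVariable):
        self.logger.info("Received %s=%s", variable.name, variable.value)

    def process(self):
        # simpy-style generator process: periodically read the value and send it
        # back into the MAS through the data broker.
        while True:
            self.set("t_room", self.get_value("t_room"))
            yield self.env.timeout(10)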