"""
The module contains the relevant classes
to execute and use the DataBroker.
Besides the DataBroker itself, the BrokerCallback is defined.
Internally, uses the tuple _map_tuple in the order of
(alias, source)
to match callbacks and AgentVariables.
"""
import abc
import functools
import inspect
import logging
import queue
import threading
from typing import (
    List,
    Callable,
    Dict,
    Tuple,
    Optional,
    Union,
)

from pydantic import BaseModel, field_validator, model_validator, ConfigDict

from agentlib.core.datamodels import AgentVariable, Source
from agentlib.core.environment import Environment
from agentlib.core.logging_ import CustomLogger
from agentlib.core.module import BaseModule
class NoCopyBrokerCallback(BaseModel):
    """
    Basic broker callback.

    This object does not copy the AgentVariable
    before calling the callback, which can be unsafe.
    This class checks if the given callback function
    adheres to the signature it needs to be correctly called.
    The first argument will be an AgentVariable. If a type-hint
    is specified, it must be `AgentVariable` or `"AgentVariable"`.
    Any further arguments must match the kwargs
    specified in the class and will also be the ones you
    pass to this class. Since they are passed as keyword
    arguments, their order does not matter.

    Example:
        >>> def my_callback(variable: "AgentVariable", some_static_info: str):
        >>>     print(variable, some_static_info)
        >>> NoCopyBrokerCallback(
        >>>     callback=my_callback,
        >>>     kwargs={"some_static_info": "Hello World"}
        >>> )
    """

    # pylint: disable=too-few-public-methods
    callback: Callable
    alias: Optional[str] = None
    source: Optional[Source] = None
    kwargs: dict = {}
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Id of the BaseModule the callback is bound to (None if unbound);
    # always overwritten by check_valid_callback_function.
    module_id: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check_valid_callback_function(cls, data):
        """
        Ensures the callback function signature is valid.

        The first parameter of the callback receives the AgentVariable. All
        remaining parameters must match the keys of ``kwargs`` exactly, in
        any order. Additionally, ``module_id`` is set to the id of the
        BaseModule the callback is bound to, or None.

        Raises:
            RuntimeError: If the callback takes no parameters, its first
                parameter is annotated with something other than
                AgentVariable, or its remaining parameters do not match
                ``kwargs``.
        """
        if not isinstance(data, dict) or "callback" not in data:
            # Nothing to inspect; let pydantic report the invalid/missing input.
            return data
        callback = data["callback"]
        # ``kwargs`` has a default, so it may legitimately be omitted.
        kwargs = data.get("kwargs", {})
        func_params = dict(inspect.signature(callback).parameters)
        if not func_params:
            # A callback without parameters cannot receive the AgentVariable.
            raise RuntimeError(
                "Defined callback Function does not take an "
                "AgentVariable as first parameter"
            )
        # The first parameter receives the AgentVariable; the rest are kwargs.
        par = func_params.pop(next(iter(func_params)))
        if par.annotation is not par.empty and par.annotation not in (
            "AgentVariable",
            AgentVariable,
        ):
            raise RuntimeError(
                "Defined callback Function does not take an "
                "AgentVariable as first parameter"
            )
        # Compare names as sets: the callback is invoked with ``**kwargs``,
        # so the order in which the kwargs were given is irrelevant.
        kwarg_names = set(kwargs)
        param_names = set(func_params)
        if kwarg_names != param_names:
            # Sorted for deterministic error messages.
            kwargs_not_in_function_args = sorted(kwarg_names - param_names)
            function_args_not_in_kwargs = sorted(param_names - kwarg_names)
            if function_args_not_in_kwargs:
                missing_kwargs = "Missing arguments in kwargs: " + ", ".join(
                    function_args_not_in_kwargs
                )
            else:
                missing_kwargs = ""
            if kwargs_not_in_function_args:
                missing_func_args = "Missing kwargs in function call: " + ", ".join(
                    kwargs_not_in_function_args
                )
            else:
                missing_func_args = ""
            raise RuntimeError(
                "The registered Callback secondary arguments do not match the given kwargs:\n"
                f"{missing_func_args}\n"
                f"{missing_kwargs}"
            )
        # Note from which module this callback came. Only methods bound to a
        # BaseModule have an owner module; everything else is assigned None.
        owner = getattr(callback, "__self__", None)
        if isinstance(owner, BaseModule):
            data["module_id"] = getattr(owner, "id", None)
        else:
            data["module_id"] = None
        return data

    def __eq__(self, other: object) -> bool:
        """
        Check equality to another callback using equality of all fields
        and the name of the callback function.

        Returns NotImplemented for objects that are not broker callbacks, so
        Python falls back to its default comparison instead of raising
        AttributeError.
        """
        if not isinstance(other, NoCopyBrokerCallback):
            return NotImplemented
        return (
            self.alias,
            self.source,
            self.kwargs,
            self.callback.__name__,
            self.module_id,
        ) == (
            other.alias,
            other.source,
            other.kwargs,
            other.callback.__name__,
            other.module_id,
        )
class BrokerCallback(NoCopyBrokerCallback):
    """
    This broker callback always creates a deep-copy of the
    AgentVariable it is going to send.
    It is considered the safer option, as the receiving module
    only gets the values and is not able to alter
    the AgentVariable for other modules.
    """

    @field_validator("callback")
    @classmethod
    def auto_copy(cls, callback_func: Callable):
        """
        Automatically supply the callback function with a copy.

        Wraps ``callback_func`` so it receives ``variable.copy(deep=True)``
        instead of the original AgentVariable. ``functools.wraps`` carries
        over ``__name__`` (used by ``__eq__`` to identify callbacks) as well
        as ``__qualname__`` and ``__doc__``, so reprs and logs show the
        user's function instead of the internal wrapper.
        """

        @functools.wraps(callback_func)
        def callback_copy(variable: AgentVariable, **kwargs):
            return callback_func(variable.copy(deep=True), **kwargs)

        return callback_copy
class DataBroker(abc.ABC):
    """
    Handles communication and Callback triggers within an agent.

    Write variables to the broker with ``send_variable()``.
    Variables sent to the broker will trigger callbacks
    based on the alias and the source of the variable.
    Commonly, this is used to provide other
    modules with the variable.
    Register and de-register Callbacks to the DataBroker
    with ``register_callback`` and ``deregister_callback``.
    Subclasses decide how callbacks are executed by implementing
    ``_run_callbacks``.
    """

    def __init__(self, logger: CustomLogger, max_queue_size: int = 1000):
        """
        Initialize the logger, the callback registries and the variable queue.

        Args:
            logger CustomLogger: Logger used for debug and queue-status output.
            max_queue_size int: Maximal number of variables waiting in the
                queue. Values below 1 make the queue unbounded
                (``queue.Queue`` semantics).
        """
        self.logger = logger
        self._max_queue_size = max_queue_size
        # Callbacks with fully specified alias and source, keyed by (alias, source).
        self._mapped_callbacks: Dict[Tuple[str, Source], List[BrokerCallback]] = {}
        # Callbacks with a partially specified alias/source; filtered per variable.
        self._unmapped_callbacks: List[BrokerCallback] = []
        self._variable_queue = queue.Queue(maxsize=max_queue_size)

    def send_variable(self, variable: AgentVariable, copy: bool = True):
        """
        Send variable to data_broker. Evokes callbacks associated with this variable.

        Args:
            variable AgentVariable:
                The variable to set.
            copy boolean:
                Whether to copy the variable before sending.
                Default is True.
        """
        if copy:
            self._send_variable_to_modules(variable=variable.copy(deep=True))
        else:
            self._send_variable_to_modules(variable=variable)

    def _send_variable_to_modules(self, variable: AgentVariable):
        """
        Enqueue AgentVariable in local queue for executing relevant callbacks.

        Blocks while the queue is full.

        Args:
            variable AgentVariable: The variable to append to the local queue.
        """
        self._variable_queue.put(variable)

    def _execute_callbacks(self):
        """
        Run relevant callbacks for the next AgentVariable from the local queue.

        Blocks until a variable is available.
        """
        variable = self._variable_queue.get(block=True)
        log_queue_status(
            logger=self.logger,
            queue_name="Callback-Distribution",
            queue_object=self._variable_queue,
            max_queue_size=self._max_queue_size,
        )
        _map_tuple = (variable.alias, variable.source)
        # First the unmapped callbacks (a fresh list, so extending it is safe).
        callbacks = self._filter_unmapped_callbacks(map_tuple=_map_tuple)
        # Then the mapped ones. ``get`` tolerates a deregistration that
        # happens during check and execution.
        callbacks.extend(self._mapped_callbacks.get(_map_tuple, ()))
        # Then run the callbacks
        self._run_callbacks(callbacks, variable)

    def _filter_unmapped_callbacks(self, map_tuple: tuple) -> List[BrokerCallback]:
        """
        Filter the unmapped callbacks according to the given
        tuple of variable information.

        A callback matches if its source is None or matches the variable's
        source, and its alias is None or equals the variable's alias.

        Args:
            map_tuple tuple:
                The tuple of alias and source in that order
        Returns:
            List[BrokerCallback]: The filtered list
        """
        callbacks = self._unmapped_callbacks
        # First filter source
        source = map_tuple[1]
        callbacks = [
            cb for cb in callbacks if (cb.source is None) or (cb.source.matches(source))
        ]
        # Now alias
        callbacks = [
            cb for cb in callbacks if (cb.alias is None) or (cb.alias == map_tuple[0])
        ]
        return callbacks

    def register_callback(
        self,
        callback: Callable,
        alias: Optional[str] = None,
        source: Optional[Source] = None,
        _unsafe_no_copy: bool = False,
        **kwargs,
    ) -> Union[BrokerCallback, NoCopyBrokerCallback]:
        """
        Register a callback to the data_broker.

        Args:
            callback callable: The function of the callback
            alias str: The alias of variables to trigger callback
            source Source: The Source of variables to trigger callback
            kwargs dict: Kwargs to be passed to the callback function
            _unsafe_no_copy: If True, the callback will not be passed a copy, but the
                original AgentVariable. When using this option, the user promises to not
                modify the AgentVariable, as doing so could lead to
                wrong and difficult to debug behaviour in other modules (default False)
        Returns:
            The created callback object.
        """
        callback_cls = NoCopyBrokerCallback if _unsafe_no_copy else BrokerCallback
        callback_ = callback_cls(
            alias=alias, source=source, callback=callback, kwargs=kwargs
        )
        if self.any_is_none(alias=alias, source=source):
            self._unmapped_callbacks.append(callback_)
        else:
            self._mapped_callbacks.setdefault((alias, source), []).append(callback_)
        return callback_

    def deregister_callback(
        self,
        callback: Callable,
        alias: Optional[str] = None,
        source: Optional[Source] = None,
        **kwargs,
    ):
        """
        Deregister the given callback based on given
        alias and source.

        Callbacks are matched by alias, source, kwargs, callback name and
        module id, so the arguments must be the same as at registration.
        Deregistering a callback that is not registered is a no-op.

        Args:
            callback callable: The function of the callback
            alias str: The alias of variables to trigger callback
            source Source: The Source of variables to trigger callback
            kwargs dict: Kwargs of the callback function
        """
        # Built outside the try: pydantic's ValidationError subclasses
        # ValueError and must not be mistaken for "not registered".
        callback_ = BrokerCallback(
            alias=alias, source=source, callback=callback, kwargs=kwargs
        )
        if self.any_is_none(alias=alias, source=source):
            registered = self._unmapped_callbacks
        else:
            registered = self._mapped_callbacks.get((alias, source))
            if registered is None:
                return  # No delete necessary
        try:
            registered.remove(callback_)
        except ValueError:
            return  # Callback was not registered; nothing to remove
        self.logger.debug("Callback de-registered: %s", callback_)

    @staticmethod
    def any_is_none(alias: str, source: Source) -> bool:
        """
        Return True if any of alias or source are None.

        Also True if the source lacks an agent_id or a module_id, so only
        fully specified callbacks are stored in the mapped registry.

        Args:
            alias str:
                The alias of the callback
            source Source:
                The Source of the callback
        """
        return (
            (alias is None)
            or (source is None)
            or (source.agent_id is None)
            or (source.module_id is None)
        )

    def _run_callbacks(self, callbacks: List[BrokerCallback], variable: AgentVariable):
        """
        Runs the callbacks on a single AgentVariable.

        Must be implemented by subclasses.
        """
        raise NotImplementedError
class LocalDataBroker(DataBroker):
    """
    DataBroker for fast-as-possible simulation inside a single
    non-realtime Environment.

    No extra threads are involved. Every sent variable succeeds an
    environment event carrying a callback that processes one queued variable.

    NOTE(review): the base class enqueues with a blocking ``put`` on a queue
    bounded by ``max_queue_size``. This broker has no consumer thread, so
    sending more variables than that before the environment processes the
    pending events would presumably block forever. Confirm.
    """

    def __init__(
        self, env: Environment, logger: CustomLogger, max_queue_size: int = 1000
    ):
        """
        Keep the environment, set up the base broker and create the event
        used to trigger callback execution.
        """
        self.env = env
        super().__init__(logger=logger, max_queue_size=max_queue_size)
        self._callbacks_available = self.env.event()

    def _send_variable_to_modules(self, variable: AgentVariable):
        """
        Queue the variable and trigger the event that will process it.

        Args:
            variable AgentVariable: The variable to append to the local queue.
        """
        super()._send_variable_to_modules(variable)
        trigger = self._callbacks_available
        trigger.callbacks.append(self._execute_callback_simpy)
        trigger.succeed()
        # The current event has been triggered; the next variable needs a
        # fresh, untriggered one.
        self._callbacks_available = self.env.event()

    def _execute_callback_simpy(self, ignored):
        """
        Event callback that processes one variable from the local queue.

        ``ignored`` is whatever the environment passes to event callbacks
        (presumably the triggering event); it is not used.
        """
        self._execute_callbacks()

    def _run_callbacks(self, callbacks: List[BrokerCallback], variable: AgentVariable):
        """Call each callback with the variable, one after another, in the
        calling thread (fast-as-possible execution mode)."""
        for broker_callback in callbacks:
            func, extra_kwargs = broker_callback.callback, broker_callback.kwargs
            func(variable, **extra_kwargs)
class RTDataBroker(DataBroker):
    """
    DataBroker written for Realtime operation regardless of Environment.

    A central thread distributes queued variables to one consumer thread per
    module id. Callbacks of different modules run concurrently, while the
    callbacks of one module run in order.
    """

    def __init__(
        self, env: Environment, logger: CustomLogger, max_queue_size: int = 1000
    ):
        """
        Initialize env.

        Adds the function to start callback execution to the environment as a process.
        Since the databroker is initialized before the modules, this will always be
        the first triggered event, so no other process starts before the broker is
        ready.
        """
        super().__init__(logger=logger, max_queue_size=max_queue_size)
        # Module threads report (exception, module_id) here when a callback fails.
        self._stop_queue = queue.SimpleQueue()
        self.thread = threading.Thread(
            target=self._callback_thread, daemon=True, name="DataBroker"
        )
        # One queue (and consumer thread) per module id; None collects
        # callbacks that are not bound to a module.
        self._module_queues: dict[Union[str, None], queue.Queue] = {}
        env.process(self._start_executing_callbacks(env))

    def _start_executing_callbacks(self, env: Environment):
        """
        Starts the callback thread.

        Thread is started after it is registered by the agent. Should be fine, since
        the monitor process is started after the process in this function.
        """
        self.thread.start()
        # Wait on a new event that is not triggered here (presumably keeps
        # this environment process alive without doing anything).
        yield env.event()

    def _callback_thread(self):
        """
        Thread to check and process the callback queue in Realtime applications.

        Before distributing the next variable, re-raises the first error
        reported by a module thread, which ends this thread.
        """
        while True:
            if not self._stop_queue.empty():
                err, module_id = self._stop_queue.get()
                raise RuntimeError(
                    f"A callback failed in the module {module_id}."
                ) from err
            self._execute_callbacks()

    def register_callback(
        self,
        callback: Callable,
        alias: Optional[str] = None,
        source: Optional[Source] = None,
        _unsafe_no_copy: bool = False,
        **kwargs,
    ) -> Union[NoCopyBrokerCallback, BrokerCallback]:
        """
        Register a callback and make sure a consumer thread exists for the
        module the callback is bound to (module id None for unbound callbacks).

        Arguments and return value are the same as for
        ``DataBroker.register_callback``.
        """
        callback_ = super().register_callback(
            callback=callback,
            alias=alias,
            source=source,
            _unsafe_no_copy=_unsafe_no_copy,
            **kwargs,
        )
        # NOTE(review): the callback is already visible to the broker thread
        # here; if a matching variable is dispatched before the module queue
        # below exists, ``_run_callbacks`` raises KeyError. Presumably
        # callbacks are registered before variables flow. Confirm.
        if callback_.module_id not in self._module_queues:
            self._start_module_thread(callback_.module_id)
        return callback_

    def _start_module_thread(self, module_id: Optional[str]):
        """Starts a consumer thread for callbacks registered from a module."""
        module_queue = queue.Queue(maxsize=self._max_queue_size)
        self._module_queues[module_id] = module_queue
        threading.Thread(
            target=self._execute_callbacks_of_module,
            daemon=True,
            name=f"DataBroker/{module_id}",
            kwargs={"module_queue": module_queue, "module_id": module_id},
        ).start()

    def _execute_callbacks_of_module(
        self, module_queue: queue.Queue, module_id: Optional[str]
    ):
        """
        Executes the callbacks associated with a specific module.

        Loops forever in the module's thread. If a callback raises, the error
        is forwarded to the broker thread via ``_stop_queue`` and re-raised,
        which ends this thread.
        """
        try:
            while True:
                cb, variable = module_queue.get(block=True)
                # Pass the variable positionally: the callback's first
                # parameter is not required to be named ``variable``.
                cb.callback(variable, **cb.kwargs)
        except Exception as err:
            self._stop_queue.put((err, module_id))
            raise

    def _run_callbacks(self, callbacks: List[BrokerCallback], variable: AgentVariable):
        """Distributes callbacks to the threads running for each module."""
        for cb in callbacks:
            module_queue = self._module_queues[cb.module_id]
            # NOTE(review): put_nowait raises queue.Full when a module queue is
            # full, which propagates out of _callback_thread and ends the
            # broker thread. Confirm this is intended.
            module_queue.put_nowait((cb, variable))
            log_queue_status(
                logger=self.logger,
                queue_name=cb.module_id,
                queue_object=module_queue,
                max_queue_size=self._max_queue_size,
            )
def log_queue_status(
    logger: logging.Logger,
    queue_object: queue.Queue,
    max_queue_size: int,
    queue_name: str,
):
    """
    Log how full the given queue is, in percent of its maximal size.

    Nothing is logged below 10 percent. From 10 up to (excluding) 80 percent
    a debug message is emitted; from 80 percent on, a warning.

    Args:
        logger (logging.Logger): A logger instance
        queue_object (queue.Queue): The queue object
        max_queue_size (int): Maximal queue size; values below 1 (which
            ``queue.Queue`` treats as unbounded) disable the check
        queue_name (str): Name associated with the queue
    """
    if max_queue_size < 1:
        return
    item_count = queue_object.qsize()
    fill_percent = round(item_count / max_queue_size * 100, 2)
    if fill_percent >= 80:
        emit = logger.warning
    elif fill_percent >= 10:
        emit = logger.debug
    else:
        return
    emit(
        "Queue '%s' fullness is %s percent (%s items).",
        queue_name,
        fill_percent,
        item_count,
    )