Coverage for agentlib/core/data_broker.py: 93%
177 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-10-30 13:39 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-10-30 13:39 +0000
1"""
2The module contains the relevant classes
3to execute and use the DataBroker.
4Besides the DataBroker itself, the BrokerCallback is defined.
6Internally, uses the tuple _map_tuple in the order of
8(alias, source)
10to match callbacks and AgentVariables.
12"""
14import abc
15import inspect
16import logging
17import queue
18import threading
19from typing import (
20 List,
21 Callable,
22 Dict,
23 Tuple,
24 Optional,
25 Union,
26)
28from pydantic import BaseModel, field_validator, model_validator, ConfigDict
30from agentlib.core.datamodels import AgentVariable, Source
31from agentlib.core.environment import Environment
32from agentlib.core.logging_ import CustomLogger
33from agentlib.core.module import BaseModule
class NoCopyBrokerCallback(BaseModel):
    """
    Basic broker callback.
    This object does not copy the AgentVariable
    before calling the callback, which can be unsafe.

    This class checks if the given callback function
    adheres to the signature it needs to be correctly called.
    The first argument will be an AgentVariable. If a type-hint
    is specified, it must be `AgentVariable` or `"AgentVariable"`.
    Any further arguments must match the kwargs
    specified in the class and will also be the ones you
    pass to this class. The order of the kwargs is irrelevant,
    since they are passed to the callback by keyword.

    Example:
    >>> def my_callback(variable: "AgentVariable", some_static_info: str):
    >>>     print(variable, some_static_info)
    >>> NoCopyBrokerCallback(
    >>>     callback=my_callback,
    >>>     kwargs={"some_static_info": "Hello World"}
    >>> )

    """

    # pylint: disable=too-few-public-methods
    callback: Callable
    alias: Optional[str] = None
    source: Optional[Source] = None
    kwargs: dict = {}
    model_config = ConfigDict(arbitrary_types_allowed=True)
    module_id: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check_valid_callback_function(cls, data):
        """
        Ensures the callback function signature is valid.

        Args:
            data: The raw input mapping handed to the model. It must contain
                "callback"; "kwargs" is optional and treated as empty
                if absent.

        Returns:
            The input mapping, extended by the key "module_id".

        Raises:
            RuntimeError: If the callback takes no parameters, if its first
                parameter is annotated with something other than
                AgentVariable, or if the names of its remaining parameters
                differ from the names in kwargs.
        """
        callback = data["callback"]
        # kwargs has a default on the model, so it may be missing from the
        # raw input. Do not require it here.
        given_kwargs = data.get("kwargs", {})

        func_params = dict(inspect.signature(callback).parameters)
        if not func_params:
            # Without this guard, next(iter(...)) below would leak a bare
            # StopIteration for callbacks that take no arguments at all.
            raise RuntimeError(
                "Defined callback Function does not take an "
                "AgentVariable as first parameter"
            )
        par = func_params.pop(next(iter(func_params)))
        if par.annotation is not par.empty and par.annotation not in (
            "AgentVariable",
            AgentVariable,
        ):
            raise RuntimeError(
                "Defined callback Function does not take an "
                "AgentVariable as first parameter"
            )

        # Compare names as sets: the callback is invoked with **kwargs, so
        # the order in which the kwargs were supplied must not matter.
        given_names = set(given_kwargs)
        expected_names = set(func_params)
        if given_names != expected_names:
            kwargs_not_in_function_args = given_names - expected_names
            function_args_not_in_kwargs = expected_names - given_names
            if function_args_not_in_kwargs:
                missing_kwargs = "Missing arguments in kwargs: " + ", ".join(
                    sorted(function_args_not_in_kwargs)
                )
            else:
                missing_kwargs = ""
            if kwargs_not_in_function_args:
                missing_func_args = "Missing kwargs in function call: " + ", ".join(
                    sorted(kwargs_not_in_function_args)
                )
            else:
                missing_func_args = ""
            raise RuntimeError(
                "The registered Callback secondary arguments do not match the given kwargs:\n"
                f"{missing_func_args}\n"
                f"{missing_kwargs}"
            )

        # note from which module this callback came. If it is not a bound method, we
        # assign it to none
        try:
            if isinstance(callback.__self__, BaseModule):
                module_id = callback.__self__.id
            else:
                module_id = None
        except AttributeError:
            module_id = None
        data["module_id"] = module_id
        return data

    def __eq__(self, other: object):
        """
        Check equality to another callback using equality of all fields
        and the name of the callback function.

        Objects that are not broker callbacks are never equal; for those,
        NotImplemented is returned so Python can fall back to its default
        comparison instead of raising an AttributeError.
        """
        if not isinstance(other, NoCopyBrokerCallback):
            return NotImplemented
        return (
            self.alias,
            self.source,
            self.kwargs,
            self.callback.__name__,
            self.module_id,
        ) == (
            other.alias,
            other.source,
            other.kwargs,
            other.callback.__name__,
            other.module_id,
        )
class BrokerCallback(NoCopyBrokerCallback):
    """
    Broker callback that hands a deep copy of the AgentVariable
    to the callback function on every call.

    This is the safer variant: the receiving module only gets the
    values and cannot alter the AgentVariable seen by other modules.
    """

    @field_validator("callback")
    @classmethod
    def auto_copy(cls, callback_func: Callable):
        """
        Wrap the given callback so that it always receives a deep copy.

        Args:
            callback_func: The function to call with the copied variable.

        Returns:
            A wrapper that carries the ``__name__`` of ``callback_func``.
        """

        def callback_copy(variable: AgentVariable, **static_kwargs):
            isolated_variable = variable.copy(deep=True)
            callback_func(isolated_variable, **static_kwargs)

        # The name is used when two callbacks are compared for equality,
        # so the wrapper has to look like the wrapped function.
        callback_copy.__name__ = callback_func.__name__
        return callback_copy
class DataBroker(abc.ABC):
    """
    Handles communication and callback triggers within an agent.

    Variables are written to the broker with ``send_variable()``.
    Each variable sent triggers the callbacks registered for its
    alias and source. Commonly, this is used to provide other
    modules with the variable.

    Callbacks are added and removed with ``register_callback``
    and ``deregister_callback``.
    """

    def __init__(self, logger: CustomLogger):
        """
        Store the logger and create the empty callback containers.

        Args:
            logger CustomLogger: Logger used by the broker.
        """
        self.logger = logger
        # Callbacks with a fully specified (alias, source) are looked up by key,
        # all others have to be filtered on every variable.
        self._mapped_callbacks: Dict[Tuple[str, Source], List[BrokerCallback]] = {}
        self._unmapped_callbacks: List[BrokerCallback] = []

    def send_variable(self, variable: AgentVariable, copy: bool = True):
        """
        Send variable to data_broker. Evokes callbacks associated with this variable.

        Args:
            variable AgentVariable:
                The variable to set.
            copy boolean:
                Whether to copy the variable before sending.
                Default is True.
        """
        outgoing = variable.copy(deep=True) if copy else variable
        self._send_variable_to_modules(variable=outgoing)

    @abc.abstractmethod
    def _send_variable_to_modules(self, variable: AgentVariable):
        """
        Pass the AgentVariable on so that the relevant callbacks get executed.
        How this happens is defined by the subclasses.

        Args:
            variable AgentVariable: The variable to distribute.
        """
        raise NotImplementedError

    def _get_variable_callbacks(self, variable: AgentVariable):
        """
        Helper function to get all callbacks associated with a given variable.
        Unmapped callbacks come first, followed by the mapped ones.
        """
        lookup_key = (variable.alias, variable.source)
        matching = self._filter_unmapped_callbacks(map_tuple=lookup_key)
        # A single lookup instead of "check, then access", so that a
        # deregistration between the two steps cannot cause an error.
        matching.extend(self._mapped_callbacks.get(lookup_key, ()))
        return matching

    def _filter_unmapped_callbacks(self, map_tuple: tuple) -> List[BrokerCallback]:
        """
        Filter the unmapped callbacks according to the given
        tuple of variable information.

        Args:
            map_tuple tuple:
                The tuple of alias and source in that order

        Returns:
            List[BrokerCallback]: The filtered list

        """
        alias, source = map_tuple[0], map_tuple[1]
        # A callback field that is None acts as a wildcard.
        return [
            cb
            for cb in self._unmapped_callbacks
            if (cb.source is None or cb.source.matches(source))
            and (cb.alias is None or cb.alias == alias)
        ]

    def register_callback(
        self,
        callback: Callable,
        alias: str = None,
        source: Source = None,
        _unsafe_no_copy: bool = False,
        **kwargs,
    ) -> Union[BrokerCallback, NoCopyBrokerCallback]:
        """
        Register a callback to the data_broker.

        Args:
            callback callable: The function of the callback
            alias str: The alias of variables to trigger callback
            source Source: The Source of variables to trigger callback
            kwargs dict: Kwargs to be passed to the callback function
            _unsafe_no_copy: If True, the callback will not be passed a copy, but the
                original AgentVariable. When using this option, the user promises to not
                modify the AgentVariable, as doing so could lead to
                wrong and difficult to debug behaviour in other modules (default False)

        Returns:
            The callback object that was stored in the broker.
        """
        callback_type = NoCopyBrokerCallback if _unsafe_no_copy else BrokerCallback
        registered = callback_type(
            alias=alias, source=source, callback=callback, kwargs=kwargs
        )
        if self.any_is_none(alias=alias, source=source):
            self._unmapped_callbacks.append(registered)
        else:
            self._mapped_callbacks.setdefault((alias, source), []).append(registered)
        return registered

    def deregister_callback(
        self, callback: Callable, alias: str = None, source: Source = None, **kwargs
    ):
        """
        Deregister the given callback based on given
        alias and source. Nothing happens if no such callback is registered.

        Args:
            callback callable: The function of the callback
            alias str: The alias of variables to trigger callback
            source Source: The Source of variables to trigger callback
            kwargs dict: Kwargs of the callback function
        """
        try:
            # Build an equivalent callback object; removal relies on its equality.
            target = BrokerCallback(
                alias=alias, source=source, callback=callback, kwargs=kwargs
            )
            if self.any_is_none(alias=alias, source=source):
                bucket = self._unmapped_callbacks
            else:
                bucket = self._mapped_callbacks.get((alias, source))
                if bucket is None:
                    return  # No delete necessary
            bucket.remove(target)
            self.logger.debug("Callback de-registered: %s", target)
        except ValueError:
            # Raised by list.remove if the callback is not in the bucket.
            pass

    @staticmethod
    def any_is_none(alias: str, source: Source) -> bool:
        """
        Return True if any of alias or source are None.
        A source without agent_id or without module_id counts as None, too.

        Args:
            alias str:
                The alias of the callback
            source Source:
                The Source of the callback
        """
        if alias is None or source is None:
            return True
        return source.agent_id is None or source.module_id is None

    @staticmethod
    def _run_callbacks(callbacks: List[BrokerCallback], variable: AgentVariable):
        """Runs the callbacks on a single AgentVariable. Not implemented here."""
        raise NotImplementedError
class DirectCallbackDataBroker(DataBroker):
    """
    DataBroker that executes all callbacks immediately when a variable is sent.

    This may lead to infinite recursion loops, for example if two
    callbacks trigger each other. In return, a variable can be
    "followed" directly from one module to other modules or agents.
    """

    def _send_variable_to_modules(self, variable: AgentVariable):
        """
        Directly execute all callbacks for the given variable.

        Args:
            variable AgentVariable: The variable handed to each callback.
        """
        for matching_cb in self._get_variable_callbacks(variable):
            matching_cb.callback(variable, **matching_cb.kwargs)
class QueuedCallbackDataBroker(DataBroker):
    """
    DataBroker that buffers sent variables in a queue.

    ``_execute_callbacks`` takes one variable from the queue, collects its
    callbacks and hands both to ``_run_callbacks``.
    """

    def __init__(self, logger: CustomLogger, max_queue_size: int = 1000):
        """
        Initialize the callback containers and the variable queue.

        Args:
            logger CustomLogger: Logger used by the broker.
            max_queue_size int: Passed as ``maxsize`` to the variable queue.
        """
        super().__init__(logger=logger)
        self._max_queue_size = max_queue_size
        self._variable_queue = queue.Queue(maxsize=max_queue_size)

    def _send_variable_to_modules(self, variable: AgentVariable):
        """
        Enqueue AgentVariable in local queue for executing relevant callbacks.

        Args:
            variable AgentVariable: The variable to append to the local queue.
        """
        self._variable_queue.put(variable)

    def _execute_callbacks(self):
        """
        Run relevant callbacks for the next AgentVariable from the local queue.
        Blocks until a variable is available.
        """
        next_variable = self._variable_queue.get(block=True)
        log_queue_status(
            logger=self.logger,
            queue_name="Callback-Distribution",
            queue_object=self._variable_queue,
            max_queue_size=self._max_queue_size,
        )
        self._run_callbacks(
            self._get_variable_callbacks(next_variable), next_variable
        )
class LocalDataBroker(QueuedCallbackDataBroker):
    """Local variation of the DataBroker written for fast-as-possible
    simulation within a single non-realtime Environment."""

    def __init__(
        self, env: Environment, logger: CustomLogger, max_queue_size: int = 1000
    ):
        """
        Initialize env and the event used to trigger callback execution.

        Args:
            env Environment: The environment that provides the events.
            logger CustomLogger: Logger used by the broker.
            max_queue_size int: Maximal size of the variable queue.
        """
        self.env = env
        super().__init__(logger=logger, max_queue_size=max_queue_size)
        self._callbacks_available = self.env.event()

    def _send_variable_to_modules(self, variable: AgentVariable):
        """
        Enqueue AgentVariable in local queue and trigger the event
        that executes the relevant callbacks.

        Args:
            variable AgentVariable: The variable to append to the local queue.
        """
        super()._send_variable_to_modules(variable)
        pending_event = self._callbacks_available
        pending_event.callbacks.append(self._execute_callback_simpy)
        pending_event.succeed()
        # Each event is triggered only once here, so prepare a fresh
        # one for the next variable.
        self._callbacks_available = self.env.event()

    def _execute_callback_simpy(self, ignored):
        """
        Run relevant callbacks for AgentVariable's from local queue.
        To be appended to the callback of the callbacks available event.
        The argument passed by the event is not used.
        """
        self._execute_callbacks()

    def _run_callbacks(self, callbacks: List[BrokerCallback], variable: AgentVariable):
        """Runs callbacks of an agent on a single AgentVariable in sequence.
        Used in fast-as-possible execution mode."""
        for current_cb in callbacks:
            current_cb.callback(variable, **current_cb.kwargs)
class RTDataBroker(QueuedCallbackDataBroker):
    """
    DataBroker written for Realtime operation regardless of Environment.

    One thread takes variables from the variable queue and distributes the
    matching callbacks to per-module queues. Each of those queues is
    consumed by its own thread, which executes the callbacks.
    """

    def __init__(
        self, env: Environment, logger: CustomLogger, max_queue_size: int = 1000
    ):
        """
        Initialize env.
        Adds the function to start callback execution to the environment as a process.
        Since the databroker is initialized before the modules, this will always be
        the first triggered event, so no other process starts before the broker is
        ready

        Args:
            env Environment: The environment the start process is added to.
            logger CustomLogger: Logger used by the broker.
            max_queue_size int: Maximal size of the variable queue and
                of each module queue.
        """
        super().__init__(logger=logger, max_queue_size=max_queue_size)
        # Module threads report a failed callback here as (exception, module_id).
        self._stop_queue = queue.SimpleQueue()
        self.thread = threading.Thread(
            target=self._callback_thread, daemon=True, name="DataBroker"
        )
        # One queue per module id; None collects callbacks not bound to a module.
        self._module_queues: Dict[Optional[str], queue.Queue] = {}

        env.process(self._start_executing_callbacks(env))

    def _start_executing_callbacks(self, env: Environment):
        """
        Starts the callback thread.
        Thread is started after it is registered by the agent. Should be fine, since
        the monitor process is started after the process in this function
        """
        self.thread.start()
        yield env.event()

    def _callback_thread(self):
        """
        Thread to check and process the callback queue in Realtime
        applications.

        Raises:
            RuntimeError: If a module thread reported a failed callback.
                The check is made before each variable is processed.
        """
        while True:
            if not self._stop_queue.empty():
                err, module_id = self._stop_queue.get()
                raise RuntimeError(
                    f"A callback failed in the module {module_id}."
                ) from err
            self._execute_callbacks()

    def register_callback(
        self,
        callback: Callable,
        alias: str = None,
        source: Source = None,
        _unsafe_no_copy: bool = False,
        **kwargs,
    ) -> Union[NoCopyBrokerCallback, BrokerCallback]:
        """
        Register a callback to the data_broker and make sure a consumer
        thread exists for the module the callback belongs to.

        Args:
            callback callable: The function of the callback
            alias str: The alias of variables to trigger callback
            source Source: The Source of variables to trigger callback
            kwargs dict: Kwargs to be passed to the callback function
            _unsafe_no_copy: If True, the callback is passed the original
                AgentVariable instead of a copy (default False)

        Returns:
            The callback object that was stored in the broker.
        """
        # The callback object notes to which module the callable is bound.
        registered = super().register_callback(
            callback=callback,
            alias=alias,
            source=source,
            _unsafe_no_copy=_unsafe_no_copy,
            **kwargs,
        )
        if registered.module_id not in self._module_queues:
            self._start_module_thread(registered.module_id)
        return registered

    def _start_module_thread(self, module_id: str):
        """Starts a consumer thread for callbacks registered from a module."""
        module_queue = queue.Queue(maxsize=self._max_queue_size)
        threading.Thread(
            target=self._execute_callbacks_of_module,
            daemon=True,
            name=f"DataBroker/{module_id}",
            kwargs={"queue": module_queue, "module_id": module_id},
        ).start()
        self._module_queues[module_id] = module_queue

    def _execute_callbacks_of_module(self, queue: "queue.Queue", module_id: str):
        """
        Executes the callbacks associated with a specific module.
        Runs until a callback raises; the error is then reported to the
        stop queue and re-raised.

        Args:
            queue: The queue of (callback, variable) pairs of this module.
            module_id str: The id of the module, used to report failures.
        """
        try:
            while True:
                cb, variable = queue.get(block=True)
                # Pass the variable positionally, like the other brokers do.
                # The first parameter of a callback registered without copy
                # is not required to be named "variable".
                cb.callback(variable, **cb.kwargs)
        except Exception as e:
            self._stop_queue.put((e, module_id))
            raise

    def _run_callbacks(self, callbacks: List[BrokerCallback], variable: AgentVariable):
        """
        Distributes callbacks to the threads running for each module.

        ``put_nowait`` is used, so ``queue.Full`` is raised instead of
        blocking if a module queue has reached its maximal size.
        """
        for cb in callbacks:
            module_queue = self._module_queues[cb.module_id]
            module_queue.put_nowait((cb, variable))
            log_queue_status(
                logger=self.logger,
                queue_name=cb.module_id,
                queue_object=module_queue,
                max_queue_size=self._max_queue_size,
            )
def log_queue_status(
    logger: logging.Logger,
    queue_object: queue.Queue,
    max_queue_size: int,
    queue_name: str,
):
    """
    Log the current load of the given queue in percent.

    Nothing is logged for a load below 10 percent. Up to 80 percent
    the message has debug level, from 80 percent on it is a warning.
    If ``max_queue_size`` is smaller than 1, the function returns
    without logging.

    Args:
        logger (logging.Logger): A logger instance
        queue_object (queue.Queue): The queue object
        max_queue_size (int): Maximal queue size
        queue_name (str): Name associated with the queue
    """
    if max_queue_size < 1:
        # No usable maximum to relate the load to.
        return

    item_count = queue_object.qsize()
    load_percent = round(item_count / max_queue_size * 100, 2)
    if load_percent < 10:
        return

    emit = logger.debug if load_percent < 80 else logger.warning
    emit(
        "Queue '%s' fullness is %s percent (%s items).",
        queue_name,
        load_percent,
        item_count,
    )