Coverage for agentlib/core/data_broker.py: 93%
162 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-04-07 16:27 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-04-07 16:27 +0000
1"""
2The module contains the relevant classes
3to execute and use the DataBroker.
4Besides the DataBroker itself, the BrokerCallback is defined.
6Internally, uses the tuple _map_tuple in the order of
8(alias, source)
10to match callbacks and AgentVariables.
12"""
14import abc
15import inspect
16import logging
17import queue
18import threading
19from typing import (
20 List,
21 Callable,
22 Dict,
23 Tuple,
24 Optional,
25 Union,
26)
28from pydantic import BaseModel, field_validator, model_validator, ConfigDict
30from agentlib.core.datamodels import AgentVariable, Source
31from agentlib.core.environment import Environment
32from agentlib.core.logging_ import CustomLogger
33from agentlib.core.module import BaseModule
class NoCopyBrokerCallback(BaseModel):
    """
    Basic broker callback.
    This object does not copy the AgentVariable
    before calling the callback, which can be unsafe.

    This class checks if the given callback function
    adheres to the signature it needs to be correctly called.
    The first argument will be an AgentVariable. If a type-hint
    is specified, it must be `AgentVariable` or `"AgentVariable"`.
    Any further arguments must match the kwargs
    specified in the class and will also be the ones you
    pass to this class. Their order does not matter, since they
    are passed as keyword arguments.

    Example:
        >>> def my_callback(variable: "AgentVariable", some_static_info: str):
        >>>     print(variable, some_static_info)
        >>> NoCopyBrokerCallback(
        >>>     callback=my_callback,
        >>>     kwargs={"some_static_info": "Hello World"}
        >>> )

    """

    # pylint: disable=too-few-public-methods
    callback: Callable
    alias: Optional[str] = None
    source: Optional[Source] = None
    kwargs: dict = {}
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Id of the BaseModule the callback is bound to; always set by the
    # validator below (None for functions not bound to a BaseModule).
    module_id: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check_valid_callback_function(cls, data):
        """Ensures the callback function signature is valid.

        The first parameter of the callback receives the AgentVariable; all
        remaining parameters must be exactly the keys of ``kwargs``.
        Additionally stores the id of the owning module in ``module_id``.

        Raises:
            RuntimeError: If the callback takes no parameters, its first
                parameter is annotated with something other than
                AgentVariable, or its remaining parameters do not match the
                given kwargs.
        """
        callback = data["callback"]
        # ``kwargs`` has a default, so it may legitimately be absent here.
        kwarg_names = list(data.get("kwargs", {}))

        func_params = dict(inspect.signature(callback).parameters)
        if not func_params:
            # Without this check, next(iter(...)) below would leak StopIteration.
            raise RuntimeError(
                "Defined callback Function does not take an "
                "AgentVariable as first parameter"
            )
        par = func_params.pop(next(iter(func_params)))
        if par.annotation is not par.empty and par.annotation not in (
            "AgentVariable",
            AgentVariable,
        ):
            raise RuntimeError(
                "Defined callback Function does not take an "
                "AgentVariable as first parameter"
            )

        # The secondary arguments are passed as keyword arguments, so only the
        # set of names has to agree - their order is irrelevant.
        kwargs_not_in_function_args = set(kwarg_names).difference(func_params)
        function_args_not_in_kwargs = set(func_params).difference(kwarg_names)
        if kwargs_not_in_function_args or function_args_not_in_kwargs:
            if function_args_not_in_kwargs:
                missing_kwargs = "Missing arguments in kwargs: " + ", ".join(
                    sorted(function_args_not_in_kwargs)
                )
            else:
                missing_kwargs = ""
            if kwargs_not_in_function_args:
                missing_func_args = "Missing kwargs in function call: " + ", ".join(
                    sorted(kwargs_not_in_function_args)
                )
            else:
                missing_func_args = ""
            raise RuntimeError(
                "The registered Callback secondary arguments do not match the given kwargs:\n"
                f"{missing_func_args}\n"
                f"{missing_kwargs}"
            )

        # note from which module this callback came. If it is not a bound method, we
        # assign it to none
        try:
            if isinstance(callback.__self__, BaseModule):
                module_id = callback.__self__.id
            else:
                module_id = None
        except AttributeError:
            module_id = None
        data["module_id"] = module_id
        return data

    def __eq__(self, other: object):
        """
        Check equality to another callback using equality of all fields
        and the name of the callback function.

        Returns NotImplemented for objects that are not broker callbacks, so
        comparisons with unrelated objects evaluate to False instead of
        raising AttributeError.
        """
        if not isinstance(other, NoCopyBrokerCallback):
            return NotImplemented
        return (
            self.alias,
            self.source,
            self.kwargs,
            self.callback.__name__,
            self.module_id,
        ) == (
            other.alias,
            other.source,
            other.kwargs,
            other.callback.__name__,
            other.module_id,
        )
class BrokerCallback(NoCopyBrokerCallback):
    """
    Broker callback that hands the receiver a deep copy of the AgentVariable.

    Because every invocation works on its own copy, a receiving module can
    only read the values; it cannot alter the AgentVariable that other
    modules receive. This makes it the safer choice compared to
    ``NoCopyBrokerCallback``.
    """

    @field_validator("callback")
    @classmethod
    def auto_copy(cls, callback_func: Callable):
        """Wrap ``callback_func`` so it always receives a deep copy.

        The wrapper takes over the original function's ``__name__`` because
        callback equality is based on that name.
        """
        wrapped = callback_func

        def callback_copy(variable: AgentVariable, **kwargs):
            duplicate = variable.copy(deep=True)
            wrapped(duplicate, **kwargs)

        original_name = wrapped.__name__
        callback_copy.__name__ = original_name
        return callback_copy
class DataBroker(abc.ABC):
    """
    Handles communication and Callback triggers within an agent.
    Write variables to the broker with ``send_variable()``.
    Variables sent to the broker will trigger callbacks
    based on the alias and the source of the variable.
    Commonly, this is used to provide other
    modules with the variable.

    Register and de-register Callbacks to the DataBroker
    with ``register_callback`` and ``deregister_callback``.

    Subclasses define how the matched callbacks are executed by
    implementing ``_run_callbacks``.
    """

    def __init__(self, logger: CustomLogger, max_queue_size: int = 1000):
        """
        Initialize the callback registries and the variable queue.

        Args:
            logger: Logger used for queue status and de-registration messages.
            max_queue_size: Capacity of the variable queue. Values < 1 make
                the ``queue.Queue`` unbounded.
        """
        self.logger = logger
        self._max_queue_size = max_queue_size
        # Callbacks with fully specified alias and source, keyed by
        # (alias, source) for direct lookup.
        self._mapped_callbacks: Dict[Tuple[str, Source], List[BrokerCallback]] = {}
        # Callbacks with a missing alias/source part; filtered per variable.
        self._unmapped_callbacks: List[BrokerCallback] = []
        self._variable_queue = queue.Queue(maxsize=max_queue_size)

    def send_variable(self, variable: AgentVariable, copy: bool = True):
        """
        Send variable to data_broker. Evokes callbacks associated with this variable.

        Args:
            variable AgentVariable:
                The variable to set.
            copy boolean:
                Whether to copy the variable before sending.
                Default is True.
        """
        if copy:
            self._send_variable_to_modules(variable=variable.copy(deep=True))
        else:
            self._send_variable_to_modules(variable=variable)

    def _send_variable_to_modules(self, variable: AgentVariable):
        """
        Enqueue AgentVariable in local queue for executing relevant callbacks.

        Args:
            variable AgentVariable: The variable to append to the local queue.
        """
        self._variable_queue.put(variable)

    def _execute_callbacks(self):
        """
        Run relevant callbacks for the next AgentVariable in the local queue.
        Blocks until a variable is available.
        """
        variable = self._variable_queue.get(block=True)
        log_queue_status(
            logger=self.logger,
            queue_name="Callback-Distribution",
            queue_object=self._variable_queue,
            max_queue_size=self._max_queue_size,
        )
        _map_tuple = (variable.alias, variable.source)
        # First the unmapped cbs (this returns a new list, safe to extend)
        callbacks = self._filter_unmapped_callbacks(map_tuple=_map_tuple)
        # Then the mapped ones.
        # Use try-except to avoid possible deregister during check and execution
        try:
            callbacks.extend(self._mapped_callbacks[_map_tuple])
        except KeyError:
            pass

        # Then run the callbacks
        self._run_callbacks(callbacks, variable)

    def _filter_unmapped_callbacks(self, map_tuple: tuple) -> List[BrokerCallback]:
        """
        Filter the unmapped callbacks according to the given
        tuple of variable information.

        Args:
            map_tuple tuple:
                The tuple of alias and source in that order

        Returns:
            List[BrokerCallback]: A new list with all unmapped callbacks
            whose source is None or matches, and whose alias is None or equal.
        """
        alias, source = map_tuple[0], map_tuple[1]
        # A None entry acts as a wildcard for that part of the tuple.
        return [
            cb
            for cb in self._unmapped_callbacks
            if ((cb.source is None) or (cb.source.matches(source)))
            and ((cb.alias is None) or (cb.alias == alias))
        ]

    def register_callback(
        self,
        callback: Callable,
        alias: Optional[str] = None,
        source: Optional[Source] = None,
        _unsafe_no_copy: bool = False,
        **kwargs,
    ) -> Union[BrokerCallback, NoCopyBrokerCallback]:
        """
        Register a callback to the data_broker.

        Args:
            callback callable: The function of the callback
            alias str: The alias of variables to trigger callback
            source Source: The Source of variables to trigger callback
            kwargs dict: Kwargs to be passed to the callback function
            _unsafe_no_copy: If True, the callback will not be passed a copy, but the
                original AgentVariable. When using this option, the user promises to not
                modify the AgentVariable, as doing so could lead to
                wrong and difficult to debug behaviour in other modules (default False)

        Returns:
            The created (NoCopy)BrokerCallback object.
        """
        if _unsafe_no_copy:
            callback_ = NoCopyBrokerCallback(
                alias=alias, source=source, callback=callback, kwargs=kwargs
            )
        else:
            callback_ = BrokerCallback(
                alias=alias, source=source, callback=callback, kwargs=kwargs
            )
        _map_tuple = (alias, source)
        if self.any_is_none(alias=alias, source=source):
            self._unmapped_callbacks.append(callback_)
        elif _map_tuple in self._mapped_callbacks:
            self._mapped_callbacks[_map_tuple].append(callback_)
        else:
            self._mapped_callbacks[_map_tuple] = [callback_]
        return callback_

    def deregister_callback(
        self,
        callback: Callable,
        alias: Optional[str] = None,
        source: Optional[Source] = None,
        **kwargs,
    ):
        """
        Deregister the given callback based on given
        alias and source. Unknown callbacks are ignored.

        Args:
            callback callable: The function of the callback
            alias str: The alias of variables to trigger callback
            source Source: The Source of variables to trigger callback
            kwargs dict: Kwargs of the callback function
        """
        try:
            # Equality of callbacks is field based (see NoCopyBrokerCallback
            # __eq__), so this also matches registered no-copy callbacks.
            callback = BrokerCallback(
                alias=alias, source=source, callback=callback, kwargs=kwargs
            )
            _map_tuple = (alias, source)
            if self.any_is_none(alias=alias, source=source):
                self._unmapped_callbacks.remove(callback)
            elif _map_tuple in self._mapped_callbacks:
                self._mapped_callbacks[_map_tuple].remove(callback)
            else:
                return  # No delete necessary
            self.logger.debug("Callback de-registered: %s", callback)
        except ValueError:
            # list.remove: callback was not registered
            pass

    @staticmethod
    def any_is_none(alias: Optional[str], source: Optional[Source]) -> bool:
        """
        Return True if any of alias or source are None.

        Args:
            alias str:
                The alias of the callback
            source Source:
                The Source of the callback

        Returns:
            bool: True if alias, source, source.agent_id or source.module_id
            is None, i.e. the callback cannot be stored as a mapped callback.
        """
        return (
            (alias is None)
            or (source is None)
            or (source.agent_id is None)
            or (source.module_id is None)
        )

    def _run_callbacks(self, callbacks: List[BrokerCallback], variable: AgentVariable):
        """Runs the callbacks on a single AgentVariable.

        Hook for subclasses; it is called as ``self._run_callbacks`` and both
        concrete brokers implement it as an instance method.
        """
        raise NotImplementedError
class LocalDataBroker(DataBroker):
    """Local variation of the DataBroker written for fast-as-possible
    simulation within a single non-realtime Environment.

    Every sent variable schedules one execution of the callbacks by
    triggering an environment event; the callbacks then run in sequence
    when the environment processes that event.
    """

    def __init__(
        self, env: Environment, logger: CustomLogger, max_queue_size: int = 1000
    ):
        """
        Initialize env

        Args:
            env: Environment whose events drive the callback execution.
            logger: Logger passed on to the DataBroker.
            max_queue_size: Capacity of the variable queue (< 1 means unbounded).
        """
        self.env = env
        super().__init__(logger=logger, max_queue_size=max_queue_size)
        # Event used to schedule callback execution; it is replaced by a fresh
        # event every time it is triggered.
        self._callbacks_available = self.env.event()

    def _send_variable_to_modules(self, variable: AgentVariable):
        """
        Enqueue AgentVariable in local queue and schedule the execution of
        its callbacks.

        Args:
            variable AgentVariable: The variable to append to the local queue.

        Raises:
            RuntimeError: If the queue already holds ``max_queue_size``
                variables that have not been processed yet.
        """
        try:
            # Do not block: the queue is only drained when the environment
            # processes the scheduled event. In a single non-realtime
            # simulation nothing drains it while the sender waits, so a
            # blocking put on a full queue would hang forever.
            self._variable_queue.put_nowait(variable)
        except queue.Full as err:
            raise RuntimeError(
                f"The DataBroker queue is full ({self._max_queue_size} variables "
                f"are waiting for their callbacks). Increase 'max_queue_size' or "
                f"send fewer variables before the environment processes them."
            ) from err
        self._callbacks_available.callbacks.append(self._execute_callback_simpy)
        self._callbacks_available.succeed()
        self._callbacks_available = self.env.event()

    def _execute_callback_simpy(self, ignored):
        """
        Run relevant callbacks for AgentVariable's from local queue.
        To be appended to the callback of the callbacks available event.

        Args:
            ignored: The triggered event, passed by the environment; unused.
        """
        self._execute_callbacks()

    def _run_callbacks(self, callbacks: List[BrokerCallback], variable: AgentVariable):
        """Runs callbacks of an agent on a single AgentVariable in sequence.
        Used in fast-as-possible execution mode."""
        for cb in callbacks:
            cb.callback(variable, **cb.kwargs)
class RTDataBroker(DataBroker):
    """DataBroker written for Realtime operation regardless of Environment.

    A broker thread takes variables from the shared queue and distributes
    every matching callback to a per-module queue. Each module's queue is
    consumed by its own thread, which executes that module's callbacks in
    order.
    """

    def __init__(
        self, env: Environment, logger: CustomLogger, max_queue_size: int = 1000
    ):
        """
        Initialize env.
        Adds the function to start callback execution to the environment as a process.
        Since the databroker is initialized before the modules, this will always be
        the first triggered event, so no other process starts before the broker is
        ready
        """
        super().__init__(logger=logger, max_queue_size=max_queue_size)
        # Module threads report (exception, module_id) here when a callback fails.
        self._stop_queue = queue.SimpleQueue()
        self.thread = threading.Thread(
            target=self._callback_thread, daemon=True, name="DataBroker"
        )
        # One queue (and consumer thread) per module id; None collects
        # callbacks that are not bound to a module.
        self._module_queues: dict[Union[str, None], queue.Queue] = {}

        env.process(self._start_executing_callbacks(env))

    def _start_executing_callbacks(self, env: Environment):
        """
        Starts the callback thread.
        Thread is started after it is registered by the agent. Should be fine, since
        the monitor process is started after the process in this function
        """
        self.thread.start()
        yield env.event()

    def _callback_thread(self):
        """Thread to check and process the callback queue in Realtime
        applications.

        Raises:
            RuntimeError: Once a module thread has reported a failed callback;
                the check happens before processing the next variable.
        """
        while True:
            if not self._stop_queue.empty():
                err, module_id = self._stop_queue.get()
                raise RuntimeError(
                    f"A callback failed in the module {module_id}."
                ) from err
            self._execute_callbacks()

    def register_callback(
        self,
        callback: Callable,
        alias: str = None,
        source: Source = None,
        _unsafe_no_copy: bool = False,
        **kwargs,
    ) -> Union[NoCopyBrokerCallback, BrokerCallback]:
        """
        Register a callback (see ``DataBroker.register_callback``) and make
        sure a consumer thread exists for the module the callback belongs to.

        Returns:
            The created (NoCopy)BrokerCallback object.
        """
        callback = super().register_callback(
            callback=callback,
            alias=alias,
            source=source,
            _unsafe_no_copy=_unsafe_no_copy,
            **kwargs,
        )
        # module_id was derived from the object the callable is bound to
        if callback.module_id not in self._module_queues:
            self._start_module_thread(callback.module_id)
        return callback

    def _start_module_thread(self, module_id: str):
        """Starts a consumer thread for callbacks registered from a module."""
        module_queue = queue.Queue(maxsize=self._max_queue_size)
        threading.Thread(
            target=self._execute_callbacks_of_module,
            daemon=True,
            name=f"DataBroker/{module_id}",
            kwargs={"queue": module_queue, "module_id": module_id},
        ).start()
        self._module_queues[module_id] = module_queue

    def _execute_callbacks_of_module(self, queue: queue.Queue, module_id: str):
        """Executes the callbacks associated with a specific module.

        Runs until a callback raises; the exception is then reported to the
        broker thread via the stop queue and re-raised.
        """
        try:
            while True:
                cb, variable = queue.get(block=True)
                # Pass the variable positionally (as LocalDataBroker does). The
                # signature check only fixes the position/type of the first
                # parameter, not its name, so ``variable=`` would fail for
                # no-copy callbacks whose first parameter has another name.
                cb.callback(variable, **cb.kwargs)
        except Exception as e:
            self._stop_queue.put((e, module_id))
            raise

    def _run_callbacks(self, callbacks: List[BrokerCallback], variable: AgentVariable):
        """Distributes callbacks to the threads running for each module."""
        for cb in callbacks:
            self._module_queues[cb.module_id].put_nowait((cb, variable))
            log_queue_status(
                logger=self.logger,
                queue_name=cb.module_id,
                queue_object=self._module_queues[cb.module_id],
                max_queue_size=self._max_queue_size,
            )
def log_queue_status(
    logger: logging.Logger,
    queue_object: queue.Queue,
    max_queue_size: int,
    queue_name: str,
):
    """
    Report how full a queue is, as a percentage of its capacity.

    Nothing is logged below 10 percent; from 10 up to (excluding) 80 percent
    a debug message is emitted, and from 80 percent on a warning.

    Args:
        logger (logging.Logger): Logger that receives the message
        queue_object (queue.Queue): Queue whose current size is inspected
        max_queue_size (int): Capacity of the queue; values below 1 disable
            the report (an unbounded queue has no meaningful fill level)
        queue_name (str): Name used to identify the queue in the message
    """
    if max_queue_size < 1:
        return
    item_count = queue_object.qsize()
    fill_percent = round(item_count / max_queue_size * 100, 2)
    if fill_percent < 10:
        return
    emit = logger.warning if fill_percent >= 80 else logger.debug
    emit(
        "Queue '%s' fullness is %s percent (%s items).",
        queue_name,
        fill_percent,
        item_count,
    )