Coverage for agentlib/core/data_broker.py: 93%

177 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-10-30 13:39 +0000

1""" 

2The module contains the relevant classes 

3to execute and use the DataBroker. 

4Besides the DataBroker itself, the BrokerCallback is defined. 

5 

6Internally, uses the tuple _map_tuple in the order of 

7 

8(alias, source) 

9 

10to match callbacks and AgentVariables. 

11 

12""" 

13 

14import abc 

15import inspect 

16import logging 

17import queue 

18import threading 

19from typing import ( 

20 List, 

21 Callable, 

22 Dict, 

23 Tuple, 

24 Optional, 

25 Union, 

26) 

27 

28from pydantic import BaseModel, field_validator, model_validator, ConfigDict 

29 

30from agentlib.core.datamodels import AgentVariable, Source 

31from agentlib.core.environment import Environment 

32from agentlib.core.logging_ import CustomLogger 

33from agentlib.core.module import BaseModule 

34 

35 

class NoCopyBrokerCallback(BaseModel):
    """
    Basic broker callback.
    This object does not copy the AgentVariable
    before calling the callback, which can be unsafe.

    This class checks if the given callback function
    adheres to the signature it needs to be correctly called.
    The first argument will be an AgentVariable. If a type-hint
    is specified, it must be `AgentVariable` or `"AgentVariable"`.
    Any further arguments must match the kwargs
    specified in the class (in any order) and will also be the ones you
    pass to this class.

    Example:
        >>> def my_callback(variable: "AgentVariable", some_static_info: str):
        ...     print(variable, some_static_info)
        >>> NoCopyBrokerCallback(
        ...     callback=my_callback,
        ...     kwargs={"some_static_info": "Hello World"}
        ... )

    """

    # pylint: disable=too-few-public-methods
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Function invoked with the AgentVariable as first argument.
    callback: Callable
    # Alias of variables that trigger the callback; None matches every alias.
    alias: Optional[str] = None
    # Source of variables that trigger the callback; None matches every source.
    source: Optional[Source] = None
    # Extra keyword arguments passed to ``callback`` on every call.
    kwargs: dict = {}
    # Id of the BaseModule the callback is bound to; set by the validator.
    module_id: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check_valid_callback_function(cls, data):
        """
        Ensure the callback function signature is valid and record its module.

        The first parameter of the callback receives the AgentVariable. If it
        is annotated, the annotation must be ``AgentVariable`` or
        ``"AgentVariable"``. The names of all remaining parameters must equal
        the keys of ``kwargs``. Order does not matter, since they are passed
        by keyword.

        Args:
            data: The raw input dict of the model. Must contain ``callback``.
                ``kwargs`` is optional and defaults to an empty dict, like the
                field itself.

        Returns:
            The input dict with ``module_id`` set to the ``id`` of the
            BaseModule the callback is bound to, or None otherwise.

        Raises:
            RuntimeError: If the callback takes no parameters, its first
                parameter has an invalid type hint, or its remaining
                parameters do not match ``kwargs``.
        """
        callback = data["callback"]
        # ``kwargs`` has a default on the model, so it may be absent here.
        kwargs = data.get("kwargs", {})
        func_params = dict(inspect.signature(callback).parameters)
        if not func_params:
            # Without this check, ``next(iter(...))`` below would leak a bare
            # StopIteration.
            raise RuntimeError(
                "Defined callback Function does not take an "
                "AgentVariable as first parameter"
            )
        first_param = func_params.pop(next(iter(func_params)))
        if first_param.annotation is not first_param.empty and (
            first_param.annotation not in ("AgentVariable", AgentVariable)
        ):
            raise RuntimeError(
                "Defined callback Function does not take an "
                "AgentVariable as first parameter"
            )

        # The secondary arguments are passed by keyword, so only the set of
        # names has to match, not their order. Sorted for stable messages.
        kwargs_not_in_function_args = sorted(set(kwargs) - set(func_params))
        function_args_not_in_kwargs = sorted(set(func_params) - set(kwargs))
        if kwargs_not_in_function_args or function_args_not_in_kwargs:
            if function_args_not_in_kwargs:
                missing_kwargs = "Missing arguments in kwargs: " + ", ".join(
                    function_args_not_in_kwargs
                )
            else:
                missing_kwargs = ""
            if kwargs_not_in_function_args:
                missing_func_args = "Missing kwargs in function call: " + ", ".join(
                    kwargs_not_in_function_args
                )
            else:
                missing_func_args = ""
            raise RuntimeError(
                "The registered Callback secondary arguments do not match the given kwargs:\n"
                f"{missing_func_args}\n"
                f"{missing_kwargs}"
            )

        # Note from which module this callback came. Callbacks that are not
        # methods bound to a BaseModule get no module id.
        owner = getattr(callback, "__self__", None)
        data["module_id"] = owner.id if isinstance(owner, BaseModule) else None
        return data

    def __eq__(self, other: object):
        """
        Check equality to another callback using alias, source, kwargs,
        module_id and the name of the callback function.

        Objects that are not broker callbacks get NotImplemented, so e.g.
        ``callback == None`` evaluates to False instead of raising
        AttributeError.
        """
        if not isinstance(other, NoCopyBrokerCallback):
            return NotImplemented
        return (
            self.alias,
            self.source,
            self.kwargs,
            self.callback.__name__,
            self.module_id,
        ) == (
            other.alias,
            other.source,
            other.kwargs,
            other.callback.__name__,
            other.module_id,
        )

137 

138 

class BrokerCallback(NoCopyBrokerCallback):
    """
    Broker callback that hands the callback function a deep copy of the
    AgentVariable instead of the original object.

    This is the safer choice: the receiving module only gets the values
    and cannot alter the AgentVariable seen by other modules.
    """

    @field_validator("callback")
    @classmethod
    def auto_copy(cls, callback_func: Callable):
        """Wrap the callback so that it always receives a deep copy of the variable."""

        def callback_copy(variable: AgentVariable, **kwargs):
            copied = variable.copy(deep=True)
            callback_func(copied, **kwargs)

        # Keep the wrapped function's name: NoCopyBrokerCallback.__eq__
        # compares callbacks by ``callback.__name__``.
        callback_copy.__name__ = callback_func.__name__
        return callback_copy

158 

159 

class DataBroker(abc.ABC):
    """
    Dispatches AgentVariables to registered callbacks within an agent.

    Variables are written to the broker with ``send_variable()``. Each
    variable triggers the callbacks whose alias and source match it. This is
    usually how the other modules of the agent receive the variable.

    Callbacks are added with ``register_callback`` and removed with
    ``deregister_callback``. A callback with a concrete alias and a fully
    specified source is stored in a dict keyed by ``(alias, source)``. All
    other callbacks are kept in a list that is filtered on every lookup.
    """

    def __init__(self, logger: CustomLogger):
        """
        Store the logger and create empty callback registries.
        """
        self.logger = logger
        # Callbacks with a concrete alias and a fully specified source.
        self._mapped_callbacks: Dict[Tuple[str, Source], List[BrokerCallback]] = {}
        # Callbacks whose alias and/or (parts of) source are None.
        self._unmapped_callbacks: List[BrokerCallback] = []

    def send_variable(self, variable: AgentVariable, copy: bool = True):
        """
        Send a variable to the broker, which triggers the callbacks
        associated with it.

        Args:
            variable AgentVariable:
                The variable to send.
            copy boolean:
                If True (default), a deep copy of the variable is sent
                instead of the object itself.
        """
        payload = variable.copy(deep=True) if copy else variable
        self._send_variable_to_modules(variable=payload)

    @abc.abstractmethod
    def _send_variable_to_modules(self, variable: AgentVariable):
        """
        Hand the variable over to the callback execution of the subclass.

        Args:
            variable AgentVariable: The variable to distribute.
        """
        raise NotImplementedError

    def _get_variable_callbacks(self, variable: AgentVariable):
        """
        Collect every callback matching the variable's alias and source:
        first the matching wildcard callbacks, then the mapped ones.
        """
        key = (variable.alias, variable.source)
        matched = self._filter_unmapped_callbacks(map_tuple=key)
        # A missing key is not an error: the entry may never have existed, or a
        # callback may be deregistered between registration and lookup.
        matched.extend(self._mapped_callbacks.get(key, ()))
        return matched

    def _filter_unmapped_callbacks(self, map_tuple: tuple) -> List[BrokerCallback]:
        """
        Select the unmapped callbacks that match the given variable information.

        A callback with ``source`` None matches every source, otherwise
        ``Source.matches`` decides. A callback with ``alias`` None matches
        every alias, otherwise the aliases must be equal.

        Args:
            map_tuple tuple:
                The tuple of alias and source in that order

        Returns:
            List[BrokerCallback]: A new list with the matching callbacks

        """
        source = map_tuple[1]
        alias = map_tuple[0]
        return [
            cb
            for cb in self._unmapped_callbacks
            if (cb.source is None or cb.source.matches(source))
            and (cb.alias is None or cb.alias == alias)
        ]

    def register_callback(
        self,
        callback: Callable,
        alias: str = None,
        source: Source = None,
        _unsafe_no_copy: bool = False,
        **kwargs,
    ) -> Union[BrokerCallback, NoCopyBrokerCallback]:
        """
        Register a callback to the data_broker.

        Args:
            callback callable: The function of the callback
            alias str: The alias of variables to trigger callback
            source Source: The Source of variables to trigger callback
            kwargs dict: Kwargs to be passed to the callback function
            _unsafe_no_copy: If True, the callback will not be passed a copy, but the
                original AgentVariable. When using this option, the user promises to not
                modify the AgentVariable, as doing so could lead to
                wrong and difficult to debug behaviour in other modules (default False)

        Returns:
            The created callback object (NoCopyBrokerCallback if
            ``_unsafe_no_copy`` is True, else BrokerCallback).
        """
        callback_cls = NoCopyBrokerCallback if _unsafe_no_copy else BrokerCallback
        new_callback = callback_cls(
            alias=alias, source=source, callback=callback, kwargs=kwargs
        )
        if self.any_is_none(alias=alias, source=source):
            self._unmapped_callbacks.append(new_callback)
        else:
            self._mapped_callbacks.setdefault((alias, source), []).append(new_callback)
        return new_callback

    def deregister_callback(
        self, callback: Callable, alias: str = None, source: Source = None, **kwargs
    ):
        """
        Remove a previously registered callback identified by function,
        alias, source and kwargs. Unknown callbacks are silently ignored.

        Args:
            callback callable: The function of the callback
            alias str: The alias of variables to trigger callback
            source Source: The Source of variables to trigger callback
            kwargs dict: Kwargs of the callback function
        """
        try:
            target = BrokerCallback(
                alias=alias, source=source, callback=callback, kwargs=kwargs
            )
            if self.any_is_none(alias=alias, source=source):
                registry = self._unmapped_callbacks
            else:
                registry = self._mapped_callbacks.get((alias, source))
                if registry is None:
                    return  # No delete necessary
            # list.remove raises ValueError when no equal callback is stored.
            registry.remove(target)
            self.logger.debug("Callback de-registered: %s", target)
        except ValueError:
            pass

    @staticmethod
    def any_is_none(alias: str, source: Source) -> bool:
        """
        Tell whether a callback with this alias and source has to be stored
        as unmapped, i.e. whether alias, source, ``source.agent_id`` or
        ``source.module_id`` is None.

        Args:
            alias str:
                The alias of the callback
            source Source:
                The Source of the callback
        """
        if alias is None or source is None:
            return True
        return source.agent_id is None or source.module_id is None

    @staticmethod
    def _run_callbacks(callbacks: List[BrokerCallback], variable: AgentVariable):
        """Run the callbacks on a single AgentVariable; provided by subclasses."""
        raise NotImplementedError

337 

338 

class DirectCallbackDataBroker(DataBroker):
    """
    DataBroker that runs all matching callbacks immediately, inside the call
    that sends the variable.

    Because callbacks run synchronously, two callbacks that trigger each
    other can lead to infinite recursion. In exchange, a variable can be
    "followed" directly from one module to other modules or agents.
    """

    def _send_variable_to_modules(self, variable: AgentVariable):
        """
        Run every callback registered for the variable, in order, in the
        calling thread.

        Args:
            variable AgentVariable: The variable passed to the callbacks.
        """
        for broker_cb in self._get_variable_callbacks(variable):
            broker_cb.callback(variable, **broker_cb.kwargs)

358 

359 

class QueuedCallbackDataBroker(DataBroker):
    """
    DataBroker that buffers sent variables in a bounded queue.

    Sending a variable only enqueues it. ``_execute_callbacks`` takes one
    variable from the queue and passes its matching callbacks to
    ``_run_callbacks``, which subclasses implement.
    """

    def __init__(self, logger: CustomLogger, max_queue_size: int = 1000):
        """
        Create the callback registries and the variable queue.

        Args:
            logger: Logger used by the broker.
            max_queue_size: Capacity of the variable queue. It is passed to
                ``queue.Queue(maxsize=...)``, where values below 1 mean
                unbounded.
        """
        super().__init__(logger=logger)
        self._max_queue_size = max_queue_size
        self._variable_queue = queue.Queue(maxsize=max_queue_size)

    def _send_variable_to_modules(self, variable: AgentVariable):
        """
        Put the AgentVariable into the local queue. This blocks while the
        queue is full.

        Args:
            variable AgentVariable: The variable to append to the local queue.
        """
        self._variable_queue.put(variable)

    def _execute_callbacks(self):
        """
        Take the next AgentVariable from the local queue (waiting if it is
        empty), log the queue load and run the matching callbacks.
        """
        next_variable = self._variable_queue.get(block=True)
        log_queue_status(
            logger=self.logger,
            queue_name="Callback-Distribution",
            queue_object=self._variable_queue,
            max_queue_size=self._max_queue_size,
        )
        self._run_callbacks(self._get_variable_callbacks(next_variable), next_variable)

393 

394 

class LocalDataBroker(QueuedCallbackDataBroker):
    """Local variation of the DataBroker written for fast-as-possible
    simulation within a single non-realtime Environment.

    Each sent variable is queued, and an Environment event is triggered
    whose callback takes one variable from the queue and runs its callbacks
    one after another. No extra thread is used.
    """

    def __init__(
        self, env: Environment, logger: CustomLogger, max_queue_size: int = 1000
    ):
        """
        Store the environment, set up the queue and create the first event.
        """
        self.env = env
        super().__init__(logger=logger, max_queue_size=max_queue_size)
        # Event that is triggered on the next send_variable call.
        self._callbacks_available = self.env.event()

    def _send_variable_to_modules(self, variable: AgentVariable):
        """
        Enqueue the AgentVariable and trigger an Environment event that will
        process it.

        Args:
            variable AgentVariable: The variable to append to the local queue.
        """
        super()._send_variable_to_modules(variable)
        # Attach the queue-processing callback to the current event, trigger
        # it, then replace it with a fresh event for the next variable.
        trigger = self._callbacks_available
        trigger.callbacks.append(self._execute_callback_simpy)
        trigger.succeed()
        self._callbacks_available = self.env.event()

    def _execute_callback_simpy(self, ignored):
        """
        Event callback: process one AgentVariable from the local queue.
        The event argument passed in by the Environment is not used.
        """
        self._execute_callbacks()

    def _run_callbacks(self, callbacks: List[BrokerCallback], variable: AgentVariable):
        """Run the callbacks of an agent on a single AgentVariable one after
        another. Used in fast-as-possible execution mode."""
        for broker_cb in callbacks:
            broker_cb.callback(variable, **broker_cb.kwargs)

433 

434 

class RTDataBroker(QueuedCallbackDataBroker):
    """DataBroker written for Realtime operation regardless of Environment.

    A dispatcher thread takes variables from the broker queue and forwards
    each matching callback to a queue owned by the module that registered
    it. Every module id (including None, for callbacks not bound to a
    module) has its own consumer thread that executes its callbacks.
    """

    def __init__(
        self, env: Environment, logger: CustomLogger, max_queue_size: int = 1000
    ):
        """
        Initialize env.
        Adds the function to start callback execution to the environment as a process.
        Since the databroker is initialized before the modules, this will always be
        the first triggered event, so no other process starts before the broker is
        ready
        """
        super().__init__(logger=logger, max_queue_size=max_queue_size)
        # Consumer threads report (exception, module_id) here; the dispatcher
        # thread re-raises it.
        self._stop_queue = queue.SimpleQueue()
        self.thread = threading.Thread(
            target=self._callback_thread, daemon=True, name="DataBroker"
        )
        # One queue per module id. _run_callbacks fills it; that module's
        # consumer thread drains it.
        self._module_queues: Dict[Optional[str], queue.Queue] = {}

        env.process(self._start_executing_callbacks(env))

    def _start_executing_callbacks(self, env: Environment):
        """
        Starts the callback thread.
        Thread is started after it is registered by the agent. Should be fine, since
        the monitor process is started after the process in this function
        """
        self.thread.start()
        yield env.event()

    def _callback_thread(self):
        """Dispatcher loop for Realtime applications.

        Before handling the next variable, any failure reported by a module
        consumer thread is re-raised as RuntimeError, which ends this thread.
        """
        while True:
            if not self._stop_queue.empty():
                err, module_id = self._stop_queue.get()
                raise RuntimeError(
                    f"A callback failed in the module {module_id}."
                ) from err
            self._execute_callbacks()

    def register_callback(
        self,
        callback: Callable,
        alias: str = None,
        source: Source = None,
        _unsafe_no_copy: bool = False,
        **kwargs,
    ) -> Union[NoCopyBrokerCallback, BrokerCallback]:
        """
        Register a callback like the base class does. If the callback's
        module (determined during callback validation) has no consumer
        thread yet, start one.
        """
        callback = super().register_callback(
            callback=callback,
            alias=alias,
            source=source,
            _unsafe_no_copy=_unsafe_no_copy,
            **kwargs,
        )
        if callback.module_id not in self._module_queues:
            self._start_module_thread(callback.module_id)
        return callback

    def _start_module_thread(self, module_id: Optional[str]):
        """Starts a consumer thread for callbacks registered from a module."""
        module_queue = queue.Queue(maxsize=self._max_queue_size)
        # Publish the queue before starting its consumer. The callback is
        # already registered at this point, so the dispatcher may look the
        # queue up right away. Items put early just wait for the consumer.
        self._module_queues[module_id] = module_queue
        threading.Thread(
            target=self._execute_callbacks_of_module,
            daemon=True,
            name=f"DataBroker/{module_id}",
            kwargs={"queue": module_queue, "module_id": module_id},
        ).start()

    def _execute_callbacks_of_module(
        self, queue: queue.Queue, module_id: Optional[str]
    ):
        """
        Consumer loop of one module: run its queued callbacks in order.

        If a callback raises, the exception is reported to the dispatcher via
        the stop queue and re-raised, which ends this consumer thread.
        """
        try:
            while True:
                cb, variable = queue.get(block=True)
                # Pass the variable positionally, like the other brokers do.
                # The validator does not check the first parameter's name, so
                # a keyword call would break no-copy callbacks whose first
                # parameter is not named "variable".
                cb.callback(variable, **cb.kwargs)
        except Exception as e:
            self._stop_queue.put((e, module_id))
            raise

    def _run_callbacks(self, callbacks: List[BrokerCallback], variable: AgentVariable):
        """Distributes callbacks to the threads running for each module.

        Uses ``put_nowait``, so a full module queue raises ``queue.Full`` in
        the dispatcher thread.
        """
        for cb in callbacks:
            module_queue = self._module_queues[cb.module_id]
            module_queue.put_nowait((cb, variable))
            log_queue_status(
                logger=self.logger,
                queue_name=cb.module_id,
                queue_object=module_queue,
                max_queue_size=self._max_queue_size,
            )

528 

529 

def log_queue_status(
    logger: logging.Logger,
    queue_object: queue.Queue,
    max_queue_size: int,
    queue_name: str,
):
    """
    Log how full the given queue is, in percent of its maximum size.

    Below 10 % nothing is logged. From 10 % up to (not including) 80 % the
    message is logged at DEBUG level, and from 80 % on at WARNING level.

    Args:
        logger (logging.Logger): Logger that receives the message
        queue_object (queue.Queue): The queue to inspect
        max_queue_size (int): Maximal queue size; below 1 nothing is logged
            (``queue.Queue`` treats such sizes as unbounded)
        queue_name (str): Name of the queue used in the message
    """
    if max_queue_size < 1:
        return
    item_count = queue_object.qsize()
    fill_percent = round(item_count / max_queue_size * 100, 2)
    if fill_percent >= 80:
        emit = logger.warning
    elif fill_percent >= 10:
        emit = logger.debug
    else:
        return
    emit(
        "Queue '%s' fullness is %s percent (%s items).",
        queue_name,
        fill_percent,
        item_count,
    )