Coverage for agentlib/core/data_broker.py: 93%

162 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-04-07 16:27 +0000

1""" 

2The module contains the relevant classes 

3to execute and use the DataBroker. 

4Besides the DataBroker itself, the BrokerCallback is defined. 

5 

6Internally, uses the tuple _map_tuple in the order of 

7 

8(alias, source) 

9 

10to match callbacks and AgentVariables. 

11 

12""" 

13 

14import abc 

15import inspect 

16import logging 

17import queue 

18import threading 

19from typing import ( 

20 List, 

21 Callable, 

22 Dict, 

23 Tuple, 

24 Optional, 

25 Union, 

26) 

27 

28from pydantic import BaseModel, field_validator, model_validator, ConfigDict 

29 

30from agentlib.core.datamodels import AgentVariable, Source 

31from agentlib.core.environment import Environment 

32from agentlib.core.logging_ import CustomLogger 

33from agentlib.core.module import BaseModule 

34 

35 

class NoCopyBrokerCallback(BaseModel):
    """
    Basic broker callback.
    This object does not copy the AgentVariable
    before calling the callback, which can be unsafe.

    This class checks if the given callback function
    adheres to the signature it needs to be correctly called.
    The first argument will be an AgentVariable. If a type-hint
    is specified, it must be `AgentVariable` or `"AgentVariable"`.
    Any further arguments must match the kwargs
    specified in the class and will also be the ones you
    pass to this class. Only the names have to match; the order of the
    kwargs is irrelevant, since they are passed to the callback by keyword.

    Example:
    >>> def my_callback(variable: "AgentVariable", some_static_info: str):
    >>>     print(variable, some_static_info)
    >>> NoCopyBrokerCallback(
    >>>     callback=my_callback,
    >>>     kwargs={"some_static_info": "Hello World"}
    >>> )

    """

    # pylint: disable=too-few-public-methods
    callback: Callable
    alias: Optional[str] = None
    source: Optional[Source] = None
    kwargs: dict = {}
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Id of the BaseModule the callback is bound to (None for plain functions);
    # always filled in by check_valid_callback_function.
    module_id: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check_valid_callback_function(cls, data):
        """
        Ensures the callback function signature is valid and records the
        module the callback is bound to in ``module_id``.

        Raises:
            RuntimeError: If the callback takes no parameter at all, if its
                first parameter is annotated with anything other than
                AgentVariable, or if the names of its remaining parameters
                do not match the keys of ``kwargs``.
        """
        if not isinstance(data, dict) or "callback" not in data:
            # Nothing to check here; let pydantic report invalid/missing input.
            return data
        callback = data["callback"]
        # "kwargs" has a default, so it may legitimately be absent.
        kwargs = data.get("kwargs") or {}

        func_params = dict(inspect.signature(callback).parameters)
        if not func_params:
            # Without this guard, next(iter(...)) would raise StopIteration.
            raise RuntimeError(
                "Defined callback Function does not take an "
                "AgentVariable as first parameter"
            )
        par = func_params.pop(next(iter(func_params)))
        if par.annotation is not par.empty and par.annotation not in (
            "AgentVariable",
            AgentVariable,
        ):
            raise RuntimeError(
                "Defined callback Function does not take an "
                "AgentVariable as first parameter"
            )

        # The kwargs are passed by keyword, so only the set of names matters.
        kwargs_not_in_function_args = set(kwargs).difference(func_params)
        function_args_not_in_kwargs = set(func_params).difference(kwargs)
        if kwargs_not_in_function_args or function_args_not_in_kwargs:
            if function_args_not_in_kwargs:
                missing_kwargs = "Missing arguments in kwargs: " + ", ".join(
                    sorted(function_args_not_in_kwargs)
                )
            else:
                missing_kwargs = ""
            if kwargs_not_in_function_args:
                missing_func_args = "Missing kwargs in function call: " + ", ".join(
                    sorted(kwargs_not_in_function_args)
                )
            else:
                missing_func_args = ""
            raise RuntimeError(
                "The registered Callback secondary arguments do not match the given kwargs:\n"
                f"{missing_func_args}\n"
                f"{missing_kwargs}"
            )

        # Note from which module this callback came. If it is not a method bound
        # to a BaseModule, we assign None.
        try:
            owner = callback.__self__
            module_id = owner.id if isinstance(owner, BaseModule) else None
        except AttributeError:
            module_id = None
        data["module_id"] = module_id
        return data

    def __eq__(self, other: object):
        """
        Check equality to another callback using equality of all fields
        and the name of the callback function.

        Returns NotImplemented for objects that are not broker callbacks, so
        comparisons with unrelated types evaluate to False instead of raising.
        """
        if not isinstance(other, NoCopyBrokerCallback):
            return NotImplemented
        return (
            self.alias,
            self.source,
            self.kwargs,
            self.callback.__name__,
            self.module_id,
        ) == (
            other.alias,
            other.source,
            other.kwargs,
            other.callback.__name__,
            other.module_id,
        )

137 

138 

class BrokerCallback(NoCopyBrokerCallback):
    """
    Broker callback that hands a deep copy of every AgentVariable to the
    wrapped function.

    This is the safer default: the receiving module only gets the values and
    cannot alter the AgentVariable that other modules receive.
    """

    @field_validator("callback")
    @classmethod
    def auto_copy(cls, callback_func: Callable):
        """Replace the callback with a wrapper that passes a deep copy."""

        def _call_with_copy(variable: AgentVariable, **kwargs):
            private_variable = variable.copy(deep=True)
            callback_func(private_variable, **kwargs)

        # Keep the original name; callback equality compares function names.
        _call_with_copy.__name__ = callback_func.__name__
        return _call_with_copy

158 

159 

class DataBroker(abc.ABC):
    """
    Handles communication and Callback triggers within an agent.
    Write variables to the broker with ``send_variable()``.
    Variables sent to the broker will trigger callbacks
    based on the alias and the source of the variable.
    Commonly, this is used to provide other
    modules with the variable.

    Register and de-register Callbacks to the DataBroker
    with ``register_callback`` and ``deregister_callback``.

    Subclasses implement ``_run_callbacks`` and decide when
    ``_execute_callbacks`` is called to drain the variable queue.
    """

    def __init__(self, logger: CustomLogger, max_queue_size: int = 1000):
        """
        Initialize the logger, the callback registries and the variable queue.

        Args:
            logger CustomLogger: Logger used for debug and queue-status messages.
            max_queue_size int: Maximum number of queued variables. Values < 1
                make the queue unbounded (queue.Queue semantics).
        """
        self.logger = logger
        self._max_queue_size = max_queue_size
        # Callbacks whose alias and source are fully specified, keyed by the
        # (alias, source) tuple of the variables that trigger them.
        self._mapped_callbacks: Dict[
            Tuple[str, Source], List[Union[BrokerCallback, NoCopyBrokerCallback]]
        ] = {}
        # Callbacks with a None alias, None source, or a source with a None
        # agent_id/module_id. These are checked against every variable.
        self._unmapped_callbacks: List[
            Union[BrokerCallback, NoCopyBrokerCallback]
        ] = []
        self._variable_queue = queue.Queue(maxsize=max_queue_size)

    def send_variable(self, variable: AgentVariable, copy: bool = True):
        """
        Send variable to data_broker. Evokes callbacks associated with this variable.

        Args:
            variable AgentVariable:
                The variable to set.
            copy boolean:
                Whether to copy the variable before sending.
                Default is True.
        """
        if copy:
            self._send_variable_to_modules(variable=variable.copy(deep=True))
        else:
            self._send_variable_to_modules(variable=variable)

    def _send_variable_to_modules(self, variable: AgentVariable):
        """
        Enqueue AgentVariable in local queue for executing relevant callbacks.

        Args:
            variable AgentVariable: The variable to append to the local queue.
        """
        self._variable_queue.put(variable)

    def _execute_callbacks(self):
        """
        Take the next AgentVariable from the local queue and run its callbacks.

        Blocks until a variable is available.
        """
        variable = self._variable_queue.get(block=True)
        log_queue_status(
            logger=self.logger,
            queue_name="Callback-Distribution",
            queue_object=self._variable_queue,
            max_queue_size=self._max_queue_size,
        )
        _map_tuple = (variable.alias, variable.source)
        # First the unmapped callbacks (a fresh list, safe to extend) ...
        callbacks = self._filter_unmapped_callbacks(map_tuple=_map_tuple)
        # ... then the mapped ones. A single lookup avoids a KeyError if the key
        # is removed between a membership check and the access.
        callbacks.extend(self._mapped_callbacks.get(_map_tuple, ()))

        # Then run the callbacks
        self._run_callbacks(callbacks, variable)

    def _filter_unmapped_callbacks(
        self, map_tuple: tuple
    ) -> List[Union[BrokerCallback, NoCopyBrokerCallback]]:
        """
        Filter the unmapped callbacks according to the given
        tuple of variable information.

        A callback matches if its source is None or matches the variable's
        source, and its alias is None or equals the variable's alias.

        Args:
            map_tuple tuple:
                The tuple of alias and source in that order

        Returns:
            List[BrokerCallback]: The filtered list (a new list)
        """
        alias, source = map_tuple[0], map_tuple[1]
        return [
            cb
            for cb in self._unmapped_callbacks
            if ((cb.source is None) or cb.source.matches(source))
            and ((cb.alias is None) or (cb.alias == alias))
        ]

    def register_callback(
        self,
        callback: Callable,
        alias: Optional[str] = None,
        source: Optional[Source] = None,
        _unsafe_no_copy: bool = False,
        **kwargs,
    ) -> Union[BrokerCallback, NoCopyBrokerCallback]:
        """
        Register a callback to the data_broker.

        Args:
            callback callable: The function of the callback
            alias str: The alias of variables to trigger callback
            source Source: The Source of variables to trigger callback
            kwargs dict: Kwargs to be passed to the callback function
            _unsafe_no_copy: If True, the callback will not be passed a copy, but the
                original AgentVariable. When using this option, the user promises to not
                modify the AgentVariable, as doing so could lead to
                wrong and difficult to debug behaviour in other modules (default False)

        Returns:
            The created (validated) callback object.
        """
        callback_cls = NoCopyBrokerCallback if _unsafe_no_copy else BrokerCallback
        callback_ = callback_cls(
            alias=alias, source=source, callback=callback, kwargs=kwargs
        )
        if self.any_is_none(alias=alias, source=source):
            self._unmapped_callbacks.append(callback_)
        else:
            self._mapped_callbacks.setdefault((alias, source), []).append(callback_)
        return callback_

    def deregister_callback(
        self,
        callback: Callable,
        alias: Optional[str] = None,
        source: Optional[Source] = None,
        **kwargs,
    ):
        """
        Deregister the given callback based on given
        alias and source.

        Args:
            callback callable: The function of the callback
            alias str: The alias of variables to trigger callback
            source Source: The Source of variables to trigger callback
            kwargs dict: Kwargs of the callback function
        """
        try:
            # Equality of broker callbacks uses the function name, so a
            # BrokerCallback also matches a registered NoCopyBrokerCallback.
            callback = BrokerCallback(
                alias=alias, source=source, callback=callback, kwargs=kwargs
            )
            _map_tuple = (alias, source)
            if self.any_is_none(alias=alias, source=source):
                self._unmapped_callbacks.remove(callback)
            elif _map_tuple in self._mapped_callbacks:
                self._mapped_callbacks[_map_tuple].remove(callback)
            else:
                return  # No delete necessary
            self.logger.debug("Callback de-registered: %s", callback)
        except ValueError:
            # list.remove raises ValueError for callbacks that were never
            # registered; de-registration is best-effort.
            # NOTE(review): this may also swallow validation errors raised while
            # building the BrokerCallback, if they derive from ValueError.
            pass

    @staticmethod
    def any_is_none(alias: Optional[str], source: Optional[Source]) -> bool:
        """
        Return True if any of alias or source are None.

        A source also counts as None if its agent_id or module_id is None.

        Args:
            alias str:
                The alias of the callback
            source Source:
                The Source of the callback
        """
        return (
            (alias is None)
            or (source is None)
            or (source.agent_id is None)
            or (source.module_id is None)
        )

    def _run_callbacks(
        self,
        callbacks: List[Union[BrokerCallback, NoCopyBrokerCallback]],
        variable: AgentVariable,
    ):
        """
        Runs the callbacks on a single AgentVariable.

        Subclasses must override this; they decide how the callbacks are
        executed (e.g. in sequence or distributed to threads).
        """
        raise NotImplementedError

347 

348 

class LocalDataBroker(DataBroker):
    """
    DataBroker variant for fast-as-possible simulation inside a single,
    non-realtime Environment. Each sent variable triggers an environment
    event, and the callbacks then run one after another in the calling context.
    """

    def __init__(
        self, env: Environment, logger: CustomLogger, max_queue_size: int = 1000
    ):
        """
        Store the environment, set up the base broker and create the first
        'callbacks available' event.
        """
        self.env = env
        super().__init__(logger=logger, max_queue_size=max_queue_size)
        self._callbacks_available = self.env.event()

    def _send_variable_to_modules(self, variable: AgentVariable):
        """
        Queue the variable and trigger the pending environment event. Callback
        execution is attached to that event.

        Args:
            variable AgentVariable: The variable to append to the local queue.
        """
        super()._send_variable_to_modules(variable)
        pending_event = self._callbacks_available
        pending_event.callbacks.append(self._execute_callback_simpy)
        pending_event.succeed()
        # The triggered event is replaced by a fresh one for the next variable.
        self._callbacks_available = self.env.event()

    def _execute_callback_simpy(self, ignored):
        """
        Event callback that processes one variable from the local queue.
        The event argument passed in by the environment is not used.
        """
        self._execute_callbacks()

    def _run_callbacks(self, callbacks: List[BrokerCallback], variable: AgentVariable):
        """Sequentially call every matching callback with the variable.
        This is used in fast-as-possible execution mode."""
        for registered in callbacks:
            func, extra_kwargs = registered.callback, registered.kwargs
            func(variable, **extra_kwargs)

387 

388 

class RTDataBroker(DataBroker):
    """DataBroker written for Realtime operation regardless of Environment.

    A thread named "DataBroker" takes variables from the variable queue.
    It puts every matching callback on the queue of a per-module consumer
    thread, keyed by the callback's ``module_id``. If a callback raises in a
    module thread, the error is put on a stop queue. The broker thread raises
    a RuntimeError the next time it checks that queue, which ends the broker
    thread.
    """

    def __init__(
        self, env: Environment, logger: CustomLogger, max_queue_size: int = 1000
    ):
        """
        Initialize env.
        Adds the function to start callback execution to the environment as a process.
        Since the databroker is initialized before the modules, this will always be
        the first triggered event, so no other process starts before the broker is
        ready
        """
        super().__init__(logger=logger, max_queue_size=max_queue_size)
        # Carries (exception, module_id) from failed module threads.
        self._stop_queue = queue.SimpleQueue()
        self.thread = threading.Thread(
            target=self._callback_thread, daemon=True, name="DataBroker"
        )
        # One callback queue (and consumer thread) per module id; None collects
        # callbacks that are not bound to a module.
        self._module_queues: dict[Union[str, None], queue.Queue] = {}

        env.process(self._start_executing_callbacks(env))

    def _start_executing_callbacks(self, env: Environment):
        """
        Starts the callback thread.
        Thread is started after it is registered by the agent. Should be fine, since
        the monitor process is started after the process in this function
        """
        self.thread.start()
        yield env.event()

    def _callback_thread(self):
        """Thread to check and process the callback queue in Realtime
        applications.

        Raises:
            RuntimeError: If a module thread reported a failed callback.
        """
        while True:
            if not self._stop_queue.empty():
                err, module_id = self._stop_queue.get()
                raise RuntimeError(
                    f"A callback failed in the module {module_id}."
                ) from err
            self._execute_callbacks()

    def register_callback(
        self,
        callback: Callable,
        alias: Optional[str] = None,
        source: Optional[Source] = None,
        _unsafe_no_copy: bool = False,
        **kwargs,
    ) -> Union[NoCopyBrokerCallback, BrokerCallback]:
        """
        Register a callback and make sure a consumer thread exists for the
        module the callback is bound to.

        The module queue is created before the callback is added to the
        callback lists. Otherwise the broker thread could match a variable to
        the new callback before its module queue exists. It would then die with
        a KeyError in ``_run_callbacks``.
        """
        # Determine the module the callable is bound to: bound methods of a
        # BaseModule map to the module's id, everything else maps to None.
        try:
            owner = callback.__self__
            module_id = owner.id if isinstance(owner, BaseModule) else None
        except AttributeError:
            module_id = None
        if module_id not in self._module_queues:
            self._start_module_thread(module_id)

        callback_ = super().register_callback(
            callback=callback,
            alias=alias,
            source=source,
            _unsafe_no_copy=_unsafe_no_copy,
            **kwargs,
        )
        # Safety net in case the validated module id differs from the one above.
        if callback_.module_id not in self._module_queues:
            self._start_module_thread(callback_.module_id)
        return callback_

    def _start_module_thread(self, module_id: Optional[str]):
        """Starts a consumer thread for callbacks registered from a module."""
        module_queue = queue.Queue(maxsize=self._max_queue_size)
        self._module_queues[module_id] = module_queue
        threading.Thread(
            target=self._execute_callbacks_of_module,
            daemon=True,
            name=f"DataBroker/{module_id}",
            kwargs={"queue": module_queue, "module_id": module_id},
        ).start()

    def _execute_callbacks_of_module(
        self, queue: queue.Queue, module_id: Optional[str]
    ):
        """
        Executes the callbacks associated with a specific module.

        Runs until a callback raises. The error is then reported to the broker
        thread through the stop queue and re-raised, which ends this thread.
        """
        # Note: the parameter name shadows the ``queue`` module in this scope;
        # it is kept for compatibility with keyword callers.
        try:
            while True:
                cb, variable = queue.get(block=True)
                cb.callback(variable=variable, **cb.kwargs)
        except Exception as e:
            self._stop_queue.put((e, module_id))
            raise

    def _run_callbacks(self, callbacks: List[BrokerCallback], variable: AgentVariable):
        """Distributes callbacks to the threads running for each module.

        Raises queue.Full if a module queue is full (``put_nowait``).
        """
        for cb in callbacks:
            module_queue = self._module_queues[cb.module_id]
            module_queue.put_nowait((cb, variable))
            log_queue_status(
                logger=self.logger,
                queue_name=cb.module_id,
                queue_object=module_queue,
                max_queue_size=self._max_queue_size,
            )

482 

483 

def log_queue_status(
    logger: logging.Logger,
    queue_object: queue.Queue,
    max_queue_size: int,
    queue_name: str,
):
    """
    Report how full a queue is, as a percentage of ``max_queue_size``.

    Nothing is logged below 10 percent or when ``max_queue_size`` is smaller
    than 1. Between 10 and 80 percent the level is logged at debug level;
    from 80 percent on it is logged as a warning.

    Args:
        logger (logging.Logger): A logger instance
        queue_object (queue.Queue): The queue object
        max_queue_size (int): Maximal queue size
        queue_name (str): Name associated with the queue
    """
    if not max_queue_size >= 1:
        return
    item_count = queue_object.qsize()
    fill_percent = round(item_count / max_queue_size * 100, 2)
    if fill_percent < 10:
        return
    emit = logger.debug if fill_percent < 80 else logger.warning
    emit(
        "Queue '%s' fullness is %s percent (%s items).",
        queue_name,
        fill_percent,
        item_count,
    )