Coverage for agentlib/core/agent.py: 93%

148 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-10-30 13:39 +0000

1""" 

2Module containing only the Agent class. 

3""" 

4 

5import json 

6import threading 

7from copy import deepcopy 

8from pathlib import Path 

9from typing import Union, List, Dict, TypeVar, Optional 

10 

11from pydantic import field_validator, BaseModel, FilePath, Field 

12from pydantic_core.core_schema import FieldValidationInfo 

13 

14import agentlib 

15import agentlib.core.logging_ as agentlib_logging 

16from agentlib.core import ( 

17 Environment, 

18 LocalDataBroker, 

19 RTDataBroker, 

20 DirectCallbackDataBroker, 

21 BaseModule, 

22 DataBroker, 

23) 

24from agentlib.core.environment import CustomSimpyEnvironment 

25from agentlib.core.errors import ConfigurationError 

26from agentlib.utils import custom_injection 

27from agentlib.utils.load_config import load_config 

28 

29BaseModuleClass = TypeVar("BaseModuleClass", bound=BaseModule) 

30 

31 

def _load_module_config(module):
    """
    Resolve a single module entry to its loaded form.

    Strings and paths are interpreted as a path to a JSON file if such a
    path exists, otherwise as a JSON document. Any other object is
    returned unchanged.

    Args:
        module: A module config (e.g. dict), a path to a JSON file or a
            JSON string.
    Returns:
        The loaded module config.
    """
    if not isinstance(module, (str, Path)):
        return module
    if Path(module).exists():
        # JSON files are UTF-8 by specification; do not rely on the
        # platform's default encoding.
        with open(module, "r", encoding="utf-8") as f:
            return json.load(f)
    return json.loads(module)


class AgentConfig(BaseModel):
    """
    Class containing settings / config for an Agent.

    Besides the id and the modules, it holds the interval for the
    thread-alive check, the data-broker queue size and the switch for
    the direct callback data broker.
    """

    id: Union[str, int] = Field(
        title="id",
        description="The ID of the Agent, should be unique in "
        "the multi-agent-system the agent is living in.",
    )
    modules: Union[List[Union[Dict, FilePath]], Dict[str, Union[Dict, FilePath]]] = (
        Field(
            default_factory=list,
            description="A list or dictionary of modules. If a dictionary is provided, keys are treated as module_ids.",
        )
    )
    check_alive_interval: float = Field(
        title="check_alive_interval",
        default=1,
        ge=0,
        description="Check every other check_alive_interval second "
        "if the threads of the agent are still alive."
        "If that's not the case, exit the main thread of the "
        "agent. Updating this value at runtime will "
        "not work as all processes have already been started.",
    )
    max_queue_size: Optional[int] = Field(
        default=1000,
        ge=-1,
        description="Maximal number of waiting items in data-broker queues. "
        "Set to -1 for infinity",
    )
    use_direct_callback_databroker: bool = Field(
        default=False,
        description="If True, the `DirectCallbackDataBroker` will be used "
    )

    @field_validator("modules")
    @classmethod
    def check_modules(cls, modules: Union[List, Dict], info: FieldValidationInfo):
        """
        Validator to ensure all modules are in dict-format and include 'module_id'.

        Args:
            modules: Either a list of module configs or a dict mapping
                module_id to module config. Entries given as str / Path
                are loaded from file or parsed as JSON.
            info: Validation info; only used to name the agent in the
                error message.
        Returns:
            list: The loaded module configs.
        Raises:
            ConfigurationError: If a dict key and the 'module_id' inside
                the corresponding config disagree.
            TypeError: If modules is neither a list nor a dict.
        """
        modules_loaded = []
        if isinstance(modules, dict):
            for module_id, module in modules.items():
                module = _load_module_config(module)
                if isinstance(module, dict):
                    # Copy so the caller's dict is not mutated when the
                    # module_id is inserted below.
                    module = deepcopy(module)
                    if "module_id" in module and module["module_id"] != module_id:
                        # 'id' is missing from info.data if its own
                        # validation failed; do not mask that error
                        # with a KeyError here.
                        agent = info.data.get("id", "<unknown>")
                        raise ConfigurationError(
                            f"Provided agent {agent} has ambiguous module_id. Module "
                            f"config was declared with dict key {module_id} but "
                            f"contains different module_id {module['module_id']} "
                            f"within config."
                        )
                    module["module_id"] = module_id
                modules_loaded.append(module)
        elif isinstance(modules, list):
            for module in modules:
                modules_loaded.append(_load_module_config(module))
        else:
            raise TypeError("Modules must be a list or a dict")
        return modules_loaded

108 

109 

class Agent:
    """
    The base class for all reactive agent implementations.

    Args:
        config (Union[AgentConfig, FilePath, str, dict]):
            A config object to initialize the agents config
        env (Environment): The environment the agent is running in
    """

    def __init__(self, *, config, env: Environment):
        """
        Create instance of Agent
        """
        self._modules = {}
        self._threads: Dict[str, threading.Thread] = {}
        self.env = env
        self.is_alive = True
        config: AgentConfig = load_config(config, config_type=AgentConfig)
        data_broker_logger = agentlib_logging.create_logger(
            env=self.env, name=f"{config.id}/DataBroker"
        )
        # Setup logger. This has to happen before any thread is
        # registered, because register_thread may log a warning.
        self.logger = agentlib_logging.create_logger(env=self.env, name=config.id)
        if env.config.rt:
            if config.use_direct_callback_databroker:
                raise ValueError("Can not use the direct callback databroker in real-time")
            self._data_broker = RTDataBroker(
                env=env, logger=data_broker_logger, max_queue_size=config.max_queue_size
            )
            self.register_thread(thread=self._data_broker.thread)
        elif config.use_direct_callback_databroker:
            self._data_broker = DirectCallbackDataBroker(
                logger=data_broker_logger
            )
        else:
            self._data_broker = LocalDataBroker(
                env=env, logger=data_broker_logger, max_queue_size=config.max_queue_size
            )
        # Update modules (the config setter registers them)
        self.config = config

        # Register the thread monitoring if configured
        if env.config.rt:
            self.env.process(self._monitor_threads())

    @property
    def id(self) -> str:
        """
        Getter for current agent's id

        Returns:
            str: current id of agent
        """
        return self.config.id

    def __repr__(self):
        return f"Agent {self.id}"

    @property
    def config(self) -> AgentConfig:
        """
        Get the config (AgentConfig) of the agent

        Returns:
            AgentConfig: An instance of AgentConfig
        """
        return self._config

    @config.setter
    def config(self, config: Union[AgentConfig, FilePath, str, dict]):
        """
        Set the config of the agent.
        As relevant info may be updated, all modules
        are re-registered.

        Args:
            config (Union[AgentConfig, FilePath, str, dict]):
                Essentially any object which can be parsed by pydantic
        """
        self._config = load_config(config, config_type=AgentConfig)
        self._register_modules()

    @property
    def data_broker(self) -> DataBroker:
        """
        Get the data_broker of the agent

        Returns:
            DataBroker: An instance of the DataBroker class
        """
        return self._data_broker

    @property
    def env(self) -> CustomSimpyEnvironment:
        """
        Get the environment the agent is in

        Returns:
            Environment: The environment instance
        """
        return self._env

    @env.setter
    def env(self, env: Environment):
        """
        Set the environment of the agent

        Args:
            env (Environment): The environment instance
        """
        self._env = env

    @property
    def modules(self) -> List[BaseModuleClass]:
        """
        Get all modules of agent

        Returns:
            List[BaseModule]: List of all modules
        """
        return list(self._modules.values())

    def get_module(self, module_id: str) -> Optional[BaseModuleClass]:
        """
        Get the module by given module_id.
        If no such module exists, None is returned

        Args:
            module_id (str): Id of the module to return
        Returns:
            Optional[BaseModule]: Module with the given id, or None
        """
        return self._modules.get(module_id, None)

    def register_thread(self, thread: threading.Thread):
        """
        Registers the given thread to the dictionary of threads
        which need to run in order for the agent
        to work.

        Args:
            thread (threading.Thread):
                The thread object
        Raises:
            KeyError: If a thread with the same name is already registered.
        """
        name = thread.name
        if name in self._threads:
            raise KeyError(
                f"Given thread with name '{name}' is already a registered thread"
            )
        if not thread.daemon:
            self.logger.warning(
                "'%s' is not a daemon thread. "
                "If the agent raises an error, the thread will keep running.",
                name,
            )
        self._threads[name] = thread

    def _monitor_threads(self):
        """
        Process loop to monitor the threads of the agent.

        Yields a timeout of check_alive_interval between two checks.

        Raises:
            RuntimeError: As soon as one registered thread is not alive.
        """
        while True:
            for name, thread in self._threads.items():
                if not thread.is_alive():
                    msg = (
                        f"The thread {name} is not alive anymore. Exiting agent. "
                        f"Check errors above for possible reasons"
                    )
                    self.logger.critical(msg)
                    self.is_alive = False
                    raise RuntimeError(msg)
            yield self.env.timeout(self.config.check_alive_interval)

    def _register_modules(self):
        """
        Function to register all modules from the
        current config.
        The module_ids need to be unique inside the
        agents config.
        The agent object (self) is passed to the modules.
        This is the reason the function is not inside the
        validator.

        Raises:
            KeyError: If a module_id occurs more than once in the config.
        """
        seen_module_ids = set()
        for module_config in self.config.modules:
            module_cls = get_module_class(module_config=module_config)
            _module_id = module_config.get("module_id", module_cls.__name__)

            # Insert default module id if it did not exist:
            module_config.update({"module_id": _module_id})

            if _module_id in seen_module_ids:
                raise KeyError(
                    f"Module with module_id '{_module_id}' "
                    f"exists multiple times inside agent "
                    f"{self.id}. Use unique names only."
                )
            seen_module_ids.add(_module_id)

            if _module_id in self._modules:
                # Module already exists, only update its config:
                self.get_module(_module_id).config = module_config
            else:
                # Add the module to the registered modules
                self._modules[_module_id] = module_cls(
                    agent=self, config=module_config
                )

    def get_results(self, cleanup=False):
        """
        Gets the results of this agent.

        Args:
            cleanup: If true, created files are deleted.
        Returns:
            dict: Results per module id. Modules returning None or
                failing with an error are left out.
        """
        results = {}
        for module in self.modules:
            try:
                result = module.get_results()
            # Only swallow regular errors; KeyboardInterrupt and
            # SystemExit must still be able to stop the program.
            except Exception as e:
                self.logger.error(f"Error reading results of module {module.id}: {e}")
                result = None
            if result is not None:
                results[module.id] = result
        if cleanup:
            self.clean_results()
        return results

    def clean_results(self):
        """
        Calls the cleanup_results function of all modules, removing files that
        were created by them.
        """
        for module in self.modules:
            try:
                module.cleanup_results()
            # Keep cleaning the remaining modules on regular errors, but
            # do not swallow KeyboardInterrupt / SystemExit.
            except Exception as e:
                self.logger.error(
                    f"Could not cleanup results for the following module: {module.id}. "
                    f"The reason is the following exception: {e}"
                )

    def terminate(self):
        """Calls the terminate function of all modules."""
        for module in self.modules:
            module.terminate()

356 

357 

def get_module_class(module_config: dict):
    """
    Return the Module-Class object for the given config.

    The entry "type" of the config decides how the class is resolved:
    a string is looked up (case-insensitively) among the module types
    of the agentlib, a dict is handed to the custom injection.

    Args:
        module_config (dict): Config of the module to return
    Returns:
        BaseModule: Module-Class object
    Raises:
        TypeError: If "type" is missing or neither a str nor a dict.
    """
    _type = module_config.get("type")

    if isinstance(_type, str):
        # Get the module-class from the agentlib
        return agentlib.modules.get_module_type(module_type=_type.casefold())
    if isinstance(_type, dict):
        # Load the custom module class
        return custom_injection(config=_type)
    raise TypeError(
        f"Given module type is of type '{type(_type)}' "
        f"but should be str or dict."
    )