Coverage for agentlib/modules/communicator/mqtt.py: 0%

135 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-04-07 16:27 +0000

1import abc 

2import time 

3from functools import cached_property 

4from typing import Union, List, Optional 

5 

6from pydantic import AnyUrl, Field, ValidationError, field_validator 

7 

8from agentlib.modules.communicator.communicator import ( 

9 Communicator, 

10 SubscriptionCommunicatorConfig, 

11) 

12from agentlib.core import Agent 

13from agentlib.core.datamodels import AgentVariable 

14from agentlib.core.errors import InitializationError 

15from agentlib.utils.validators import convert_to_list 

16from agentlib.core.errors import OptionalDependencyError 

17 

18try: 

19 from paho.mqtt.client import ( 

20 Client as PahoMQTTClient, 

21 MQTTv5, 

22 MQTT_CLEAN_START_FIRST_ONLY, 

23 MQTT_LOG_ERR, 

24 MQTT_LOG_WARNING, 

25 ) 

26 from paho.mqtt.enums import CallbackAPIVersion 

27except ImportError as err: 

28 raise OptionalDependencyError( 

29 dependency_name="mqtt", 

30 dependency_install="paho-mqtt", 

31 used_object="Module type 'mqtt'", 

32 ) from err 

33 

34 

class BaseMQTTClientConfig(SubscriptionCommunicatorConfig):
    """Configuration options shared by the paho-mqtt based communicators.

    Extends ``SubscriptionCommunicatorConfig`` with the broker session
    settings (keepalive, clean start, QoS), the topic layout (prefix and
    additional subtopics), the connection timeout and the optional
    credentials / TLS settings.
    """

    keepalive: int = Field(
        default=60,
        description="Maximum period in seconds between "
        "communications with the broker. "
        "If no other messages are being "
        "exchanged, this controls the "
        "rate at which the client will "
        "send ping messages to the "
        "broker.",
    )
    # NOTE(review): the description mentions MQTT_CLEAN_START_FIRST_ONLY, but
    # the field is typed as bool. MqttClient.connect in this file passes
    # MQTT_CLEAN_START_FIRST_ONLY directly and does not read this field.
    clean_start: bool = Field(
        default=True,
        # A space was missing between the first two sentences, which made
        # the rendered description read "...FIRST_ONLY.Sets the ...".
        description="True, False or "
        "MQTT_CLEAN_START_FIRST_ONLY. "
        "Sets the MQTT v5.0 clean_start "
        "flag always, never or on the "
        "first successful connect "
        "only, respectively. "
        "MQTT session data (such as "
        "outstanding messages and "
        "subscriptions) is cleared "
        "on successful connect when "
        "the clean_start flag is set.",
    )
    # A single string is accepted as well; the validator registered at the
    # bottom of the class passes the value through ``convert_to_list``.
    subtopics: Union[List[str], str] = Field(
        default=[], description="Topics to that the agent subscribes"
    )
    prefix: str = Field(default="/agentlib", description="Prefix for MQTT-Topic")
    qos: int = Field(default=0, description="Quality of Service", ge=0, le=2)
    connection_timeout: float = Field(
        default=10,
        description="Number of seconds to wait for the initial connection "
        "until throwing an Error.",
    )
    username: Optional[str] = Field(default=None, title="Username to login")
    password: Optional[str] = Field(default=None, title="Password to login")
    use_tls: Optional[bool] = Field(
        default=None, description="Option to use TLS certificates"
    )
    tls_ca_certs: Optional[str] = Field(
        default=None,
        description="Path to the Certificate Authority certificate files. "
        "If None, windows certificate will be used.",
    )
    client_id: Optional[str] = Field(default=None, title="Client ID")

    # Add validator
    check_subtopics = field_validator("subtopics")(convert_to_list)

84 

85 

class MQTTClientConfig(BaseMQTTClientConfig):
    """Configuration of the MQTT communicator that addresses a broker by URL."""

    url: AnyUrl = Field(
        title="Host",
        description="Host is the hostname or IP address " "of the remote broker.",
    )

    @field_validator("url")
    @classmethod
    def check_url(cls, url):
        """Accept only URLs that use the ``mqtt`` or ``mqtts`` scheme.

        Args:
            url: The already parsed URL value of the ``url`` field.

        Returns:
            The unchanged ``url`` if its scheme is supported.

        Raises:
            ValueError: If the scheme is neither ``mqtt`` nor ``mqtts``.
                Pydantic reports this to the caller as a ``ValidationError``.
        """
        if url.scheme in ("mqtts", "mqtt"):
            return url
        # The former ``url.scheme is None`` fallback was dropped: a parsed
        # AnyUrl always carries a scheme and its attributes are read-only,
        # so that branch could neither be reached nor succeed.
        # Raising the bare pydantic ``ValidationError`` class does not work
        # inside a validator (it cannot be constructed without arguments);
        # a ValueError is the supported way to signal invalid input.
        raise ValueError(
            f"Unsupported URL scheme '{url.scheme}'. "
            "Expected one of: 'mqtt', 'mqtts'."
        )

101 

102 

class BaseMqttClient(Communicator):
    """Base class for communicators that exchange variables via paho-mqtt.

    The constructor creates the paho client, registers the callbacks,
    starts the network loop, calls :meth:`connect` and then blocks until
    the client is connected and all expected subscriptions were confirmed.
    Subclasses have to implement :meth:`connect`.
    """

    # We use the paho-mqtt module and are
    # thus required to use their function signatures and function names
    # pylint: disable=unused-argument,too-many-arguments,invalid-name
    config: BaseMQTTClientConfig
    mqttc_type = PahoMQTTClient

    def _log_all(self, client, userdata, level, buf):
        """Forward paho log messages of level error or warning to the logger.

        Args:
            client: the client instance for this callback
            userdata: the private user data as set in Client() or userdata_set()
            level: gives the severity of the message and will be one of
                MQTT_LOG_INFO, MQTT_LOG_NOTICE, MQTT_LOG_WARNING,
                MQTT_LOG_ERR, and MQTT_LOG_DEBUG.
            buf: the message itself
        """
        if level in (MQTT_LOG_ERR, MQTT_LOG_WARNING):
            self.logger.error("ERROR OR WARNING: %s", buf)

    def __init__(self, config: dict, agent: Agent):
        """Create the MQTT client, connect and wait until it is ready.

        Args:
            config: Configuration of the module.
            agent: The agent this module belongs to.

        Raises:
            InitializationError: If the client is not connected, or not all
                subscriptions were confirmed, within
                ``config.connection_timeout`` seconds.
        """
        super().__init__(config=config, agent=agent)
        # Number of subscriptions confirmed by the broker so far.
        self._subcribed_topics = 0
        self._mqttc = self.mqttc_type(
            client_id=self.config.client_id or str(self.source),
            protocol=MQTTv5,
            callback_api_version=CallbackAPIVersion.VERSION2,
        )
        if self.config.username is not None:
            self.logger.debug("Setting password and username")
            self._mqttc.username_pw_set(
                username=self.config.username, password=self.config.password
            )
            # Add TLS-Settings (default behavior)
            # NOTE(review): nesting reconstructed from flattened source;
            # presumably TLS is only enabled implicitly when credentials
            # are configured and use_tls was left unset -- confirm.
            if self.config.use_tls is None:
                self._mqttc.tls_set(ca_certs=self.config.tls_ca_certs)
        # Add TLS-Settings
        if self.config.use_tls:
            self._mqttc.tls_set(ca_certs=self.config.tls_ca_certs)

        self._mqttc.on_connect = self._connect_callback
        self._mqttc.on_disconnect = self._disconnect_callback
        self._mqttc.on_message = self._message_callback
        self._mqttc.on_subscribe = self._subscribe_callback
        self._mqttc.on_log = self._log_all
        self._mqttc.loop_start()

        self.connect()

        self.logger.info(
            "Agent %s waits for mqtt connections to be ready ...", self.agent.id
        )
        # A monotonic clock keeps the timeout independent of wall-clock
        # adjustments while waiting.
        deadline = time.monotonic() + self.config.connection_timeout
        while not (
            self._mqttc.is_connected()
            and self._subcribed_topics == self.topics_size
        ):
            if time.monotonic() > deadline:
                raise InitializationError("Could not connect to MQTT broker.")
            # Yield between polls instead of spinning at full CPU load; the
            # callbacks that update the state run in paho's loop thread.
            time.sleep(0.01)

        self.logger.info("Module is fully connected")

    @abc.abstractmethod
    def connect(self):
        """Establish the connection to the broker. Implemented by subclasses."""
        raise NotImplementedError

    def terminate(self):
        """Disconnect from the broker, then run the parent's terminate."""
        self.disconnect()
        super().terminate()

    # The callback for when the client receives a CONNACK response from the server.
    def _connect_callback(self, client, userdata, flags, reasonCode, properties):
        """Log the result of the connection attempt.

        Raises:
            ConnectionError: If ``reasonCode`` is not equal to 0.
        """
        if reasonCode != 0:
            err_msg = f"Connection failed with error code: '{reasonCode}'"
            self.logger.error(err_msg)
            raise ConnectionError(err_msg)
        self.logger.debug("Connected with result code: '%s'", reasonCode)

    def disconnect(self, reasoncode=None, properties=None):
        """Trigger the disconnect"""
        self._mqttc.disconnect(reasoncode=reasoncode, properties=properties)

    def _disconnect_callback(self, client, userdata, reasonCode, properties):
        """Log the disconnect and the resulting connection state."""
        self.logger.warning(
            "Disconnected with result code: %s | userdata: %s | properties: %s",
            reasonCode,
            userdata,
            properties,
        )
        self.logger.info("Active: %s", self._mqttc.is_connected())

    def _message_callback(self, client, userdata, msg):
        """
        The default callback for when a PUBLISH message is
        received from the server.

        The payload is parsed into an ``AgentVariable`` and handed to the
        data broker of the agent.
        """
        agent_inp = AgentVariable.from_json(msg.payload)
        self.logger.debug(
            "Received variable %s = %s from source %s",
            agent_inp.alias,
            agent_inp.value,
            agent_inp.source,
        )
        self.agent.data_broker.send_variable(agent_inp)

    def _subscribe_callback(self, client, userdata, mid, reasonCodes, properties):
        """Count confirmed subscriptions and fail on rejected ones.

        A reason code equal to the configured QoS is treated as success
        and increments the counter the constructor waits on.

        Raises:
            ConnectionError: If a reason code differs from the configured QoS.
        """
        for reason_code in reasonCodes:
            if reason_code == self.config.qos:
                self._subcribed_topics += 1
                self.logger.info(
                    "Subscribed to topic %s/%s",
                    self._subcribed_topics,
                    self.topics_size,
                )
            else:
                msg = f"{self.agent.id}'s subscription failed: {reason_code}"
                self.logger.error(msg)
                raise ConnectionError(msg)

    @property
    def topics_size(self):
        """Number of subscriptions expected: subtopics plus subscriptions."""
        return len(self.config.subtopics) + len(self.config.subscriptions)

234 

235 

class MqttClient(BaseMqttClient):
    """MQTT communicator that connects to the broker given by ``config.url``.

    Variables are published below ``<prefix>/<agent_id>/<alias>``; for every
    configured subscription the client subscribes to
    ``<prefix>/<subscription>/#`` plus all configured subtopics.
    """

    config: MQTTClientConfig

    @cached_property
    def pubtopic(self):
        """Base topic this agent publishes to (computed once, then cached)."""
        return self.generate_topic(agent_id=self.agent.id, subscription=False)

    @property
    def topics_size(self):
        """Number of distinct topics the client subscribes to."""
        return len(self._get_all_topics())

    def generate_topic(self, agent_id: str, subscription: bool = True):
        """
        Generate the topic with the given agent_id and
        configs prefix

        Args:
            agent_id: Id of the agent the topic belongs to.
            subscription: If True, the multi-level wildcard ``#`` is
                appended so that all topics below the agent are matched.

        Returns:
            The topic string with every ``//`` replaced by ``/``.
        """
        parts = [self.config.prefix, agent_id]
        if subscription:
            parts.append("#")
        topic = "/".join(parts)
        # str.replace returns a new string; the result was previously
        # discarded, so double slashes were never collapsed.
        return topic.replace("//", "/")

    def connect(self):
        """Connect to the broker; port 1883 is used if the URL has no port."""
        port = self.config.url.port
        port = 1883 if port is None else int(port)
        self._mqttc.connect(
            host=self.config.url.host,
            port=port,
            keepalive=self.config.keepalive,
            bind_address="",
            bind_port=0,
            # NOTE(review): config.clean_start is not used here -- confirm
            # whether the fixed value is intended.
            clean_start=MQTT_CLEAN_START_FIRST_ONLY,
            properties=None,
        )

    def _get_all_topics(self):
        """
        Helper function to return all topics the client
        should listen to.

        Returns:
            A set with one wildcard topic per subscription plus all
            configured subtopics.
        """
        topics = {
            self.generate_topic(agent_id=subscription)
            for subscription in self.config.subscriptions
        }
        topics.update(self.config.subtopics)
        return topics

    def _connect_callback(self, client, userdata, flags, reasonCode, properties):
        """Run the parent's connect handling, then subscribe to all topics."""
        super()._connect_callback(
            client=client,
            userdata=userdata,
            flags=flags,
            reasonCode=reasonCode,
            properties=properties,
        )
        # Subscribing in on_connect() means that if we lose the connection and
        # reconnect then subscriptions will be renewed.
        self._subcribed_topics = 0  # Reset counter as well

        for topic in self._get_all_topics():
            client.subscribe(topic=topic, qos=self.config.qos)
            self.logger.info("Subscribes to: '%s'", topic)

    def _send(self, payload: dict):
        """Publish the given output below ``pubtopic`` using its alias."""
        topic = "/".join([self.pubtopic, payload["alias"]])
        self._mqttc.publish(
            topic=topic,
            payload=self.to_json(payload),
            qos=self.config.qos,
            retain=False,
            properties=None,
        )

311 )