import abc
import time
from functools import cached_property
from typing import Union, List, Optional
from pydantic import AnyUrl, Field, ValidationError, field_validator
from agentlib.modules.communicator.communicator import (
Communicator,
SubscriptionCommunicatorConfig,
)
from agentlib.core import Agent
from agentlib.core.datamodels import AgentVariable
from agentlib.core.errors import InitializationError
from agentlib.utils.validators import convert_to_list
from agentlib.core.errors import OptionalDependencyError
try:
from paho.mqtt.client import (
Client as PahoMQTTClient,
MQTTv5,
MQTT_CLEAN_START_FIRST_ONLY,
MQTT_LOG_ERR,
MQTT_LOG_WARNING,
)
from paho.mqtt.enums import CallbackAPIVersion
except ImportError as err:
raise OptionalDependencyError(
dependency_name="mqtt",
dependency_install="paho-mqtt",
used_object="Module type 'mqtt'",
) from err
class BaseMQTTClientConfig(SubscriptionCommunicatorConfig):
    """Configuration shared by the MQTT communicators of this module."""

    keepalive: int = Field(
        default=60,
        description="Maximum period in seconds between "
        "communications with the broker. "
        "If no other messages are being "
        "exchanged, this controls the "
        "rate at which the client will "
        "send ping messages to the "
        "broker.",
    )
    # NOTE(review): MqttClient.connect in this module passes
    # MQTT_CLEAN_START_FIRST_ONLY to paho instead of reading this field, and
    # the ``bool`` type cannot hold MQTT_CLEAN_START_FIRST_ONLY anyway.
    clean_start: bool = Field(
        default=True,
        description="True, False or "
        "MQTT_CLEAN_START_FIRST_ONLY. "
        "Sets the MQTT v5.0 clean_start "
        "flag always, never or on the "
        "first successful connect "
        "only, respectively. "
        "MQTT session data (such as "
        "outstanding messages and "
        "subscriptions) is cleared "
        "on successful connect when "
        "the clean_start flag is set.",
    )
    # A single string is accepted as well; ``check_subtopics`` below turns it
    # into a list.
    subtopics: Union[List[str], str] = Field(
        default=[], description="Topics to which the agent subscribes"
    )
    prefix: str = Field(default="/agentlib", description="Prefix for MQTT-Topic")
    qos: int = Field(default=0, description="Quality of Service", ge=0, le=2)
    connection_timeout: float = Field(
        default=10,
        description="Number of seconds to wait for the initial connection "
        "until throwing an Error.",
    )
    username: Optional[str] = Field(default=None, title="Username to login")
    password: Optional[str] = Field(default=None, title="Password to login")
    # BaseMqttClient enables TLS both for None (the default) and True; only an
    # explicit False disables it.
    use_tls: Optional[bool] = Field(
        default=None, description="Option to use TLS certificates"
    )
    tls_ca_certs: Optional[str] = Field(
        default=None,
        description="Path to the Certificate Authority certificate files. "
        "If None, the system's default CA certificates will be used.",
    )
    # When None, BaseMqttClient falls back to ``str(self.source)``.
    client_id: Optional[str] = Field(default=None, title="Client ID")

    # Normalise ``subtopics`` to a list.
    check_subtopics = field_validator("subtopics")(convert_to_list)
class MQTTClientConfig(BaseMQTTClientConfig):
    """Configuration of an MQTT communicator that connects to a broker URL."""

    url: AnyUrl = Field(
        title="Host",
        description="Host is the hostname or IP address of the remote broker.",
    )

    @field_validator("url")
    @classmethod
    def check_url(cls, url):
        """Accept only broker URLs with the ``mqtt`` or ``mqtts`` scheme.

        Args:
            url: The already-parsed broker URL.

        Returns:
            The unchanged ``url``.

        Raises:
            ValueError: If the scheme is neither ``mqtt`` nor ``mqtts``.
                pydantic wraps this into a ``ValidationError`` for the caller.
        """
        # In pydantic v2 an AnyUrl always carries a scheme and is immutable,
        # so the scheme can only be checked here, not defaulted or rewritten.
        if url.scheme in ("mqtts", "mqtt"):
            return url
        raise ValueError(
            f"Unsupported URL scheme '{url.scheme}' for MQTT broker; "
            "expected 'mqtt' or 'mqtts'."
        )
class BaseMqttClient(Communicator):
    """Communicator that exchanges AgentVariables through a paho-mqtt client.

    The client speaks MQTT v5 with paho's callback API version 2.

    On construction it:

    - starts paho's background network loop;
    - calls :meth:`connect`;
    - blocks until the client is connected and ``topics_size`` subscriptions
      have been acknowledged.

    This class never subscribes to anything itself; subclasses are expected
    to issue the subscriptions.
    """

    # We use the paho-mqtt module and are
    # thus required to use their function signatures and function names
    # pylint: disable=unused-argument,too-many-arguments,invalid-name

    config: BaseMQTTClientConfig
    # Client class that __init__ instantiates.
    mqttc_type = PahoMQTTClient

    def _log_all(self, client, userdata, level, buf):
        """Forward paho-mqtt warnings and errors to the module logger.

        Args:
            client: The client instance for this callback.
            userdata: The private user data as set in Client() or
                user_data_set().
            level: Severity of the message; one of MQTT_LOG_INFO,
                MQTT_LOG_NOTICE, MQTT_LOG_WARNING, MQTT_LOG_ERR and
                MQTT_LOG_DEBUG.
            buf: The log message itself.
        """
        # Lower-severity paho messages are dropped; warnings and errors are
        # both reported at error level.
        if level in (MQTT_LOG_ERR, MQTT_LOG_WARNING):
            self.logger.error("ERROR OR WARNING: %s", buf)

    def __init__(self, config: dict, agent: Agent):
        super().__init__(config=config, agent=agent)
        # Number of subscriptions acknowledged by the broker so far; compared
        # against ``topics_size`` below to detect readiness.
        self._subcribed_topics = 0
        self._mqttc = self.mqttc_type(
            client_id=self.config.client_id or str(self.source),
            protocol=MQTTv5,
            callback_api_version=CallbackAPIVersion.VERSION2,
        )
        if self.config.username is not None:
            self.logger.debug("Setting password and username")
            self._mqttc.username_pw_set(
                username=self.config.username, password=self.config.password
            )
        # TLS is on unless explicitly disabled: both ``use_tls=None`` (the
        # default) and ``use_tls=True`` configure it.
        if self.config.use_tls is not False:
            self._mqttc.tls_set(ca_certs=self.config.tls_ca_certs)

        self._mqttc.on_connect = self._connect_callback
        self._mqttc.on_disconnect = self._disconnect_callback
        self._mqttc.on_message = self._message_callback
        self._mqttc.on_subscribe = self._subscribe_callback
        self._mqttc.on_log = self._log_all

        self._mqttc.loop_start()
        self.connect()
        self.logger.info(
            "Agent %s waits for mqtt connections to be ready ...", self.agent.id
        )
        # Use a monotonic clock so wall-clock adjustments cannot shorten or
        # extend the timeout.
        deadline = time.monotonic() + self.config.connection_timeout
        while not (
            self._mqttc.is_connected()
            and self._subcribed_topics == self.topics_size
        ):
            if time.monotonic() > deadline:
                # Do not leave the network thread from loop_start() running.
                self._mqttc.loop_stop()
                raise InitializationError("Could not connect to MQTT broker.")
            # Yield the CPU instead of busy-spinning while paho's background
            # loop processes the broker's replies.
            time.sleep(0.01)
        self.logger.info("Module is fully connected")

    @abc.abstractmethod
    def connect(self):
        """Connect the underlying paho client to the broker."""
        raise NotImplementedError

    def terminate(self):
        """Disconnect from client and join loop"""
        self.disconnect()
        # Stop and join the network thread started by loop_start() in
        # __init__.
        self._mqttc.loop_stop()
        super().terminate()

    # The callback for when the client receives a CONNACK response from the server.
    def _connect_callback(self, client, userdata, flags, reasonCode, properties):
        """Log the CONNACK result and raise if the broker refused.

        Raises:
            ConnectionError: If ``reasonCode`` is non-zero. Presumably this is
                raised inside paho's network thread; confirm how paho handles
                callback exceptions.
        """
        if reasonCode != 0:
            err_msg = f"Connection failed with error code: '{reasonCode}'"
            self.logger.error(err_msg)
            raise ConnectionError(err_msg)
        self.logger.debug("Connected with result code: '%s'", reasonCode)

    def disconnect(self, reasoncode=None, properties=None):
        """Trigger the disconnect"""
        self._mqttc.disconnect(reasoncode=reasoncode, properties=properties)

    def _disconnect_callback(
        self, client, userdata, disconnect_flags, reasonCode, properties
    ):
        """Log a disconnect reported by paho.

        The signature follows paho's CallbackAPIVersion.VERSION2, which passes
        ``disconnect_flags`` before the reason code.
        """
        self.logger.warning(
            "Disconnected with result code: %s | userdata: %s | properties: %s",
            reasonCode,
            userdata,
            properties,
        )
        self.logger.info("Active: %s", self._mqttc.is_connected())

    def _message_callback(self, client, userdata, msg):
        """
        The default callback for when a PUBLISH message is
        received from the server.

        The JSON payload is decoded into an AgentVariable and handed to the
        agent's data broker.
        """
        agent_inp = AgentVariable.from_json(msg.payload)
        self.logger.debug(
            "Received variable %s = %s from source %s",
            agent_inp.alias,
            agent_inp.value,
            agent_inp.source,
        )
        self.agent.data_broker.send_variable(agent_inp)

    def _subscribe_callback(self, client, userdata, mid, reasonCodes, properties):
        """Count acknowledged subscriptions; raise on any other reason code.

        Raises:
            ConnectionError: If a reason code differs from the configured QoS.
        """
        # NOTE(review): a broker may grant a lower QoS than requested; that
        # is treated as a failure here because only an exact match counts.
        for reason_code in reasonCodes:
            if reason_code == self.config.qos:
                self._subcribed_topics += 1
                self.logger.info(
                    "Subscribed to topic %s/%s",
                    self._subcribed_topics,
                    self.topics_size,
                )
            else:
                msg = f"{self.agent.id}'s subscription failed: {reason_code}"
                self.logger.error(msg)
                raise ConnectionError(msg)

    @property
    def topics_size(self):
        """Number of acknowledged subscriptions that __init__ waits for."""
        return len(self.config.subtopics) + len(self.config.subscriptions)
class MqttClient(BaseMqttClient):
    """MQTT communicator whose topics are derived from ``prefix`` and agent ids.

    Variables are published to ``<prefix>/<agent_id>/<alias>``. The client
    subscribes to ``<prefix>/<source_id>/#`` for every configured
    subscription, plus every raw topic in ``subtopics``.
    """

    config: MQTTClientConfig

    @cached_property
    def pubtopic(self):
        """Topic prefix under which this agent publishes its variables."""
        return self.generate_topic(agent_id=self.agent.id, subscription=False)

    @property
    def topics_size(self):
        """Number of distinct topics this client subscribes to."""
        return len(self._get_all_topics())

    def generate_topic(self, agent_id: str, subscription: bool = True):
        """
        Generate the topic with the given agent_id and
        configs prefix.

        Args:
            agent_id: Id of the agent the topic belongs to.
            subscription: If True, append the multi-level wildcard ``#`` so
                that every sub-topic of the agent is matched.

        Returns:
            str: The topic, with runs of ``/`` collapsed to a single one.
        """
        parts = [self.config.prefix, agent_id]
        if subscription:
            parts.append("#")
        topic = "/".join(parts)
        # Collapse duplicate separators, e.g. from a prefix ending in "/".
        # str.replace returns a new string, so its result must be reassigned;
        # loop so that runs like "///" are fully collapsed.
        while "//" in topic:
            topic = topic.replace("//", "/")
        return topic

    def connect(self):
        """Connect the paho client to the broker given by ``config.url``.

        When the URL has no port, the standard port for the scheme is used:
        8883 for ``mqtts`` and 1883 otherwise.
        """
        url = self.config.url
        port = url.port
        if port is None:
            port = 8883 if url.scheme == "mqtts" else 1883
        self._mqttc.connect(
            host=url.host,
            port=int(port),
            keepalive=self.config.keepalive,
            bind_address="",
            bind_port=0,
            # NOTE(review): config.clean_start is not consulted here; the
            # clean_start flag is always MQTT_CLEAN_START_FIRST_ONLY.
            clean_start=MQTT_CLEAN_START_FIRST_ONLY,
            properties=None,
        )

    def _get_all_topics(self):
        """
        Helper function to return all topics the client
        should listen to.

        Returns:
            set: Wildcard topics for every subscribed agent plus the raw
            ``subtopics``, de-duplicated.
        """
        topics = {
            self.generate_topic(agent_id=subscription)
            for subscription in self.config.subscriptions
        }
        topics.update(self.config.subtopics)
        return topics

    def _connect_callback(self, client, userdata, flags, reasonCode, properties):
        """Check the CONNACK, then (re-)subscribe to all topics."""
        super()._connect_callback(
            client=client,
            userdata=userdata,
            flags=flags,
            reasonCode=reasonCode,
            properties=properties,
        )
        # Subscribing in on_connect() means that if we lose the connection and
        # reconnect then subscriptions will be renewed.
        self._subcribed_topics = 0  # Reset counter as well
        for topic in self._get_all_topics():
            client.subscribe(topic=topic, qos=self.config.qos)
            self.logger.info("Subscribes to: '%s'", topic)

    def _send(self, payload: dict):
        """Publish the given output to ``<pubtopic>/<alias>``."""
        topic = "/".join([self.pubtopic, payload["alias"]])
        self._mqttc.publish(
            topic=topic,
            payload=self.to_json(payload),
            qos=self.config.qos,
            retain=False,
            properties=None,
        )