Coverage for agentlib/modules/communicator/communicator.py: 68%
119 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-04-30 13:00 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-04-30 13:00 +0000
"""
Basic communicator modules: the abstract ``Communicator`` base class, the
communicator configuration models and the ``LocalCommunicator`` base class
for broker-based local communication.
"""
5import abc
6import json
7import queue
8import threading
9from typing import Union, List, TypedDict, Any
11import pandas as pd
12from pydantic import Field, field_validator
14from agentlib.core import Agent, BaseModule, BaseModuleConfig
15from agentlib.core.datamodels import AgentVariable
16from agentlib.core.errors import OptionalDependencyError
17from agentlib.utils.broker import Broker
18from agentlib.utils.validators import convert_to_list
class CommunicationDict(TypedDict):
    """Compact, serializable representation of a shared variable.

    Built by ``Communicator.short_dict`` and handed to the communicator's
    ``_send`` implementation.
    """

    # alias of the AgentVariable being communicated
    alias: str
    # the variable's value (a JSON string if it was a pandas Series that was
    # converted by ``short_dict``)
    value: Any
    # timestamp of the variable
    timestamp: float
    # type field of the variable
    type: str
    # id of the sending agent
    source: str
class CommunicatorConfig(BaseModuleConfig):
    # Configuration options shared by all communicators.
    use_orjson: bool = Field(
        title="Use orjson",
        default=False,
        # Fixed: the implicitly concatenated literal previously read
        # "serialization deserialization" (missing conjunction).
        description="If true, the faster orjson library will be used for serialization "
        "and deserialization. Requires the optional dependency.",
    )
class SubscriptionCommunicatorConfig(CommunicatorConfig):
    # Configuration for communicators that subscribe to other agents.

    # Either a list of agent ids or a single agent id string; the validator
    # below normalizes the value via ``convert_to_list`` (presumably wrapping
    # a lone string into a one-element list — see agentlib.utils.validators).
    subscriptions: Union[List[str], str] = Field(
        default=[],
        title="Subscriptions",
        description="List of agent-id strings to subscribe to",
    )
    check_subscriptions = field_validator("subscriptions")(convert_to_list)
class Communicator(BaseModule):
    """
    Base class for all communicators.

    A communicator listens to the agent's data broker and forwards shared
    variables that originate from this agent. Subclasses implement ``_send``
    to transport the payload produced by :meth:`short_dict`.
    """

    config: CommunicatorConfig

    def __init__(self, *, config: dict, agent: Agent):
        super().__init__(config=config, agent=agent)

        # Bind the JSON serializer once at construction time, so callers of
        # ``to_json`` never have to re-check the configuration.
        if self.config.use_orjson:
            try:
                import orjson
            except ImportError as err:
                # Chain explicitly so the traceback shows the ImportError as
                # the direct cause of the missing optional dependency.
                raise OptionalDependencyError(
                    dependency_name="orjson",
                    dependency_install="orjson",
                    used_object="Communicator with 'use_orjson=True'",
                ) from err

            def _to_orjson(payload: CommunicationDict) -> bytes:
                return orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)

            self.to_json = _to_orjson
        else:

            def _to_json_builtin(payload: CommunicationDict) -> str:
                return json.dumps(payload)

            self.to_json = _to_json_builtin

    def register_callbacks(self):
        """Register :meth:`_send_only_shared_variables` on the agent's data broker.

        No filter arguments are passed, so every variable on the data broker
        reaches the callback, which then decides whether to send it.
        ``_unsafe_no_copy=True`` presumably makes the broker pass the variable
        without copying it, so the callback must not mutate it (TODO confirm).
        """
        self.agent.data_broker.register_callback(
            callback=self._send_only_shared_variables, _unsafe_no_copy=True
        )

    def process(self):
        """Default process: yields a fresh event that this class never triggers."""
        yield self.env.event()

    def _send_only_shared_variables(self, variable: AgentVariable):
        """Send only variables with field ``shared=True``.

        Variables are additionally required to originate from this agent (see
        :meth:`_variable_can_be_send`).
        """
        if not self._variable_can_be_send(variable):
            return

        payload = self.short_dict(variable)
        self.logger.debug("Sending variable %s=%s", variable.alias, variable.value)
        self._send(payload=payload)

    def _variable_can_be_send(self, variable):
        """Return whether ``variable`` should be sent by this communicator.

        True if the variable is shared and its source has no agent id or the
        id of this agent. Variables whose source is another agent are not
        forwarded.
        """
        return variable.shared and (
            (variable.source.agent_id is None)
            or (variable.source.agent_id == self.agent.id)
        )

    @abc.abstractmethod
    def _send(self, payload: CommunicationDict):
        """Transport ``payload`` to other agents; implemented by each communicator."""
        raise NotImplementedError(
            "This method needs to be implemented " "individually for each communicator"
        )

    def short_dict(self, variable: AgentVariable, parse_json: bool = True) -> CommunicationDict:
        """Create a short dict serialization of the variable.

        Only contains the attributes of the AgentVariable that are relevant
        for other modules or agents. For performance and privacy reasons,
        communicators should use this instead of a full serialization.

        Args:
            variable: The variable to serialize.
            parse_json: If True, a ``pd.Series`` value is converted to a JSON
                string via ``Series.to_json``; otherwise the value is kept
                as is.

        Returns:
            A :class:`CommunicationDict` whose ``source`` is this agent's id.
        """
        if isinstance(variable.value, pd.Series) and parse_json:
            value = variable.value.to_json()
        else:
            value = variable.value
        return CommunicationDict(
            alias=variable.alias,
            value=value,
            timestamp=variable.timestamp,
            type=variable.type,
            source=self.agent.id,
        )

    def to_json(self, payload: CommunicationDict) -> Union[bytes, str]:
        """Serialize ``payload`` to JSON.

        Placeholder only: ``__init__`` replaces this per instance. When
        ``config.use_orjson`` is true, ``orjson`` is used (with numpy
        serialization enabled) and ``bytes`` are returned. If orjson is not
        installed, ``__init__`` raises ``OptionalDependencyError`` instead.
        Otherwise the builtin ``json`` module is used and a ``str`` is
        returned.
        """
class LocalCommunicatorConfig(CommunicatorConfig):
    # Configuration of local (broker-based) communicators.
    parse_json: bool = Field(
        title="Parse JSON",
        default=False,
        # Previously this long text was the field title, and the concatenated
        # literal was missing a space ("stageswhich").
        description="Indicates whether variables are converted to json before "
        "sending. This increases computing time, but makes the MAS closer to "
        "later stages which use MQTT or similar.",
    )
    queue_size: int = Field(
        title="Size of the queue",
        default=10000,
        description="Maximum number of received messages buffered before they "
        "are forwarded to the agent's data broker. Values <= 0 mean an "
        "unbounded queue (see queue.Queue).",
    )
class LocalCommunicator(Communicator):
    """
    Base class for local communicators.

    Messages from other clients arrive through ``receive``. This is bound in
    ``__init__`` depending on ``agent.env.config.rt``:

    * realtime: messages are put into an internal queue, which a daemon
      thread drains into the agent's data broker;
    * simpy: each message is queued and a simpy event is triggered, whose
      callback forwards one queued variable to the data broker.

    Subclasses must implement :meth:`setup_broker` and ``_send``.
    """

    config: LocalCommunicatorConfig

    def __init__(self, config: dict, agent: Agent):
        # Assign methods to receive messages either in realtime or in the
        # simpy process. This has to be done before calling super().__init__()
        # because that already calls the process method.
        if agent.env.config.rt:
            self.process = self._process_realtime
            self.receive = self._receive_realtime
            self._loop = None
        else:
            self._received_variable = agent.env.event()
            self.process = self._process
            self.receive = self._receive

        super().__init__(config=config, agent=agent)
        self.broker = self.setup_broker()
        # Create the inbox before registering, so the queue exists before the
        # broker can deliver anything to this client.
        self._msg_q_in = queue.Queue(self.config.queue_size)
        self.broker.register_client(client=self)

    @property
    def broker(self) -> Broker:
        """Broker used by LocalCommunicator"""
        return self._broker

    @broker.setter
    def broker(self, broker):
        """Set the broker of the LocalCommunicator and log which one is used."""
        self._broker = broker
        self.logger.info("%s uses broker %s", self.__class__.__name__, self.broker)

    @abc.abstractmethod
    def setup_broker(self):
        """Create and return the broker object this communicator registers with.

        Called once from ``__init__``; must return a valid broker option.
        """
        raise NotImplementedError(
            "This method needs to be implemented " "individually for each communicator"
        )

    def _send_only_shared_variables(self, variable: AgentVariable):
        """Send only variables with field ``shared=True``.

        Unlike the base implementation, a ``pd.Series`` value is only
        converted to JSON when ``config.parse_json`` is enabled.
        """
        if not self._variable_can_be_send(variable):
            return

        payload = self.short_dict(variable, parse_json=self.config.parse_json)
        self.logger.debug("Sending variable %s=%s", variable.alias, variable.value)
        self._send(payload=payload)

    def _process(self):
        """Keep the simpy process alive.

        Yields a fresh event that nothing triggers. Received messages are
        forwarded by the event callbacks registered in :meth:`_receive`,
        not here.
        """
        yield self.env.event()

    def _process_realtime(self):
        """Start the message handler thread once the env is running."""
        self._loop = threading.Thread(
            target=self._message_handler, name=str(self.source)
        )
        self._loop.daemon = True  # Necessary to enable terminations of scripts
        self._loop.start()
        self.agent.register_thread(thread=self._loop)
        yield self.env.event()

    def _send_simpy(self, ignored):
        """Forward one queued variable to the agent's data broker.

        Used as a simpy event callback. The event argument is not needed.
        Each call to :meth:`_receive` queues exactly one variable and
        registers exactly one of these callbacks, so ``get_nowait`` is
        expected to find an item.
        """
        variable = self._msg_q_in.get_nowait()
        self.agent.data_broker.send_variable(variable)

    def _decode_message(self, msg_obj):
        """Turn a message from the broker into the variable to forward.

        With ``config.parse_json`` the message is parsed with
        ``AgentVariable.from_json``. Otherwise it is assumed to already be the
        variable object and is returned unchanged.
        """
        if self.config.parse_json:
            return AgentVariable.from_json(msg_obj)
        return msg_obj

    def _receive(self, msg_obj):
        """Queue a received message and trigger the pending simpy event.

        Raises ``queue.Full`` if the inbox (``config.queue_size``) is full,
        because the put is non-blocking.
        """
        variable = self._decode_message(msg_obj)
        self._msg_q_in.put(variable, block=False)
        # The callback runs once simpy processes the triggered event. Then
        # arm a fresh event for the next message.
        self._received_variable.callbacks.append(self._send_simpy)
        self._received_variable.succeed()
        self._received_variable = self.env.event()

    def _receive_realtime(self, msg_obj):
        """Queue a received message for the realtime message handler thread.

        No event setting is required for realtime. The put blocks while the
        inbox is full.
        """
        variable = self._decode_message(msg_obj)
        self._msg_q_in.put(variable)

    def _message_handler(self):
        """Forward queued messages to the agent's data broker, forever.

        Runs in the daemon thread started by :meth:`_process_realtime` and
        blocks on the queue while it is empty.
        """
        while True:
            variable = self._msg_q_in.get()
            self.agent.data_broker.send_variable(variable)

    def terminate(self):
        # Terminating is important when running multiple
        # simulations/environments, otherwise the broker will keep spamming all
        # agents from the previous simulation, potentially filling their queues.
        self.broker.delete_client(self)
        super().terminate()