Coverage for agentlib/utils/multi_processing_broker.py: 81%

86 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-04-07 16:27 +0000

1""" 

2Module containing a MultiProcessingBroker that 

3enables communication across different processes. 

4""" 

5 

6import json 

7import multiprocessing 

8from ipaddress import IPv4Address 

9from multiprocessing.managers import SyncManager 

10import threading 

11import time 

12from collections import namedtuple 

13from typing import Union 

14import logging 

15 

16from pathlib import Path 

17from pydantic import BaseModel, Field, FilePath 

18 

19from .broker import Broker 

20 

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Signup record a client places on the broker's signup queue: the client's
# agent id plus two pipe connection ends. The broker calls ``read.recv()`` to
# receive from the client and ``write.send()`` to send to it.
MPClient = namedtuple("MPClient", "agent_id read write")

# Envelope received from a client: the sending agent's id and the payload the
# broker forwards to the other clients.
Message = namedtuple("Message", "agent_id payload")

26 

27 

class BrokerManager(SyncManager):
    """``SyncManager`` subclass that exposes a ``get_queue`` proxy method.

    ``get_queue`` is registered by name only (no callable). Per the
    ``multiprocessing.managers`` docs, that is the form used by a manager that
    connects to an existing server rather than serving the object itself.
    Presumably client processes use it to reach the signup queue served by the
    broker's server process. TODO confirm against the callers.
    """


BrokerManager.register("get_queue")

33 

34 

class MultiProcessingBrokerConfig(BaseModel):
    """Settings for the MultiProcessingBroker's manager server.

    The broker's server listens on ``(ipv4, port)``. The server and every
    connecting party must present the same ``authkey``.
    """

    # NOTE(review): the default below is a plain str. Pydantic does not validate
    # defaults unless asked to, so only explicitly supplied values become
    # ``IPv4Address`` objects.
    ipv4: IPv4Address = Field(
        description="IP Address for the communication server. Defaults to localhost.",
        default="127.0.0.1",
    )
    port: int = Field(
        description="Port for setting up the connection with the server.",
        default=50000,
    )
    # NOTE(review): a fixed, publicly visible default key; override it if the
    # server is ever reachable from outside the local machine.
    authkey: bytes = Field(
        description="Authorization key for the connection with the broker.",
        default=b"useTheAgentlib",
    )


# Inputs accepted for a broker configuration: a ready config object, a dict of
# its fields, or a path (str or FilePath) to a JSON file holding such a dict.
ConfigTypes = Union[MultiProcessingBrokerConfig, dict, str, FilePath]

52 

53 

class MultiProcessingBroker(Broker):
    """
    Singleton which acts as a broker for distributed simulations among multiple
    local processes.

    A separate daemon process runs a ``multiprocessing`` manager server that
    owns a signup queue. Local clients put an ``MPClient`` on that queue: their
    agent id plus the two ends of a ``multiprocessing.Pipe()`` connection. The
    broker receives from the client via ``read`` and sends to it via ``write``.
    For each connected client, a thread waits for incoming messages and
    forwards them to every other client.

    ``self.lock`` and ``self._clients`` are not defined here. Presumably they
    are set up by ``Broker.__init__``; ``_clients`` is used like a set.
    """

    def __init__(self, config: ConfigTypes = None):
        super().__init__()
        self.config = MultiProcessingBrokerConfig() if config is None else config

        # The manager server (which owns the signup queue) lives in its own
        # daemon process. This process connects to it from a daemon thread.
        server = multiprocessing.Process(
            target=self._server, name="Broker_Server", args=(self.config,), daemon=True
        )
        server.start()

        signup_handler = threading.Thread(
            target=self._signup_handler, daemon=True, name="Broker_SignUp"
        )
        signup_handler.start()

    @property
    def config(self) -> MultiProcessingBrokerConfig:
        """Return the config of the broker."""
        return self._config

    @config.setter
    def config(self, config: ConfigTypes):
        """Set the config of the broker.

        Accepts a ``MultiProcessingBrokerConfig``, a dict of its fields, or a
        path (str / Path) to an existing JSON file containing such a dict.
        Any other value is passed to ``model_validate``, which rejects it if
        it is invalid.
        """
        if isinstance(config, MultiProcessingBrokerConfig):
            self._config = config
            return
        if isinstance(config, (str, Path)) and Path(config).exists():
            # JSON is UTF-8 by spec; don't depend on the platform default.
            with open(config, "r", encoding="utf-8") as f:
                config = json.load(f)
        self._config = MultiProcessingBrokerConfig.model_validate(config)

    @staticmethod
    def _address(config: MultiProcessingBrokerConfig) -> tuple:
        """Return the ``(host, port)`` socket address described by ``config``.

        ``config.ipv4`` is an ``IPv4Address`` whenever pydantic validated it,
        e.g. for dict/JSON configs. It is a plain str only for the unvalidated
        field default. Sockets accept only str/bytes hosts and raise TypeError
        for an ``IPv4Address``, so the host is normalised to str here.
        """
        return str(config.ipv4), config.port

    @staticmethod
    def _server(config: MultiProcessingBrokerConfig):
        """Create the Manager that owns the signup queue and serve it forever."""
        from multiprocessing.managers import BaseManager
        from queue import Queue

        queue = Queue()

        class QueueManager(BaseManager):
            pass

        QueueManager.register("get_queue", callable=lambda: queue)
        m = QueueManager(
            address=MultiProcessingBroker._address(config), authkey=config.authkey
        )

        s = m.get_server()
        s.serve_forever()

    def _signup_handler(self):
        """Connect to the manager queue and process signup requests.

        Each accepted client is acknowledged with ``1`` on its ``write`` end,
        and a child thread is started that listens for its messages.

        Raises:
            RuntimeError: if the manager server is not reachable within 10 s.
                Note that this is raised inside the signup thread.
        """
        from multiprocessing.managers import BaseManager

        class QueueManager(BaseManager):
            pass

        QueueManager.register("get_queue")
        m = QueueManager(
            address=self._address(self.config), authkey=self.config.authkey
        )
        # The server process may not be listening yet, so retry briefly.
        # A monotonic clock keeps the timeout immune to wall-clock changes.
        started_wait = time.monotonic()
        while True:
            try:
                m.connect()
                break
            except ConnectionRefusedError as err:
                time.sleep(0.01)
                if time.monotonic() - started_wait > 10:
                    raise RuntimeError("Could not connect to server.") from err

        signup_queue = m.get_queue()

        while True:
            try:
                client = signup_queue.get()
            except (ConnectionError, EOFError):
                logger.info("Multiprocessing Broker disconnected.")
                break

            with self.lock:
                self._clients.add(client)

            # send the client an ack its messages are now being received
            try:
                client.write.send(1)
            except OSError:
                # A client that vanished before its ack must not take down the
                # signup thread, and with it every later signup.
                logger.warning(
                    "Could not acknowledge signup of agent %s; dropping it.",
                    client.agent_id,
                )
                with self.lock:
                    if client in self._clients:
                        self._clients.remove(client)
                continue

            threading.Thread(
                target=self._client_loop,
                args=(client,),
                name=f"MPBroker_{client.agent_id}",
                daemon=True,
            ).start()

    def _client_loop(self, client: MPClient):
        """Receive messages from ``client`` and redistribute them.

        Ends when the client's pipe is closed (EOF) or fails (e.g. a reset
        connection). In either case the client is unregistered.
        """
        while True:
            try:
                msg: Message = client.read.recv()
            except (EOFError, OSError):
                with self.lock:
                    # Guarded so an already-removed client cannot raise here.
                    if client in self._clients:
                        self._clients.remove(client)
                logger.debug("Client %s disconnected.", client.agent_id)
                break
            self.send(message=msg.payload, source=msg.agent_id)

    def send(self, source, message):
        """
        Send the given message to every client except the one whose agent_id
        equals ``source``, so a message is not echoed back to its sender.

        Args:
            source: agent_id of the sender; that client is skipped
            message: The message to send
        """
        # lock is required so the clients loop does not change size during
        # iteration if clients are added or removed
        with self.lock:
            for client in list(self._clients):
                if client.agent_id == source:
                    continue
                try:
                    client.write.send(message)
                except BrokenPipeError:
                    # Best effort: presumably the dead client's own
                    # _client_loop sees EOF on its read end and removes it.
                    pass