Coverage for agentlib/utils/plotting/simulator_dashboard.py: 46%

112 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-04-07 16:27 +0000

1import io 

2import socket 

3import webbrowser 

4from collections import defaultdict 

5from pathlib import Path 

6from typing import List, Union 

7 

8from agentlib.core.errors import OptionalDependencyError 

9 

10try: 

11 import dash 

12 from dash import dcc, html, callback_context 

13 from dash.dependencies import Input, Output, State, ALL, ClientsideFunction 

14 import dash_bootstrap_components as dbc 

15 import plotly.graph_objs as go 

16except ImportError: 

17 raise OptionalDependencyError("simulator_dashboard", "interactive") 

18import pandas as pd 

19 

20from agentlib.core import datamodels 

21 

# Read position reached so far in each result file, keyed by Path; a file that
# has not been read yet starts at position 0.
file_positions = defaultdict(int)
# Cache of the data loaded so far for each result file, keyed by str(path).
data = dict()

25 

26 

def get_port():
    """Find a TCP port on localhost that nothing is listening on.

    Ports are probed upwards starting at 8050. The first port where a
    connection attempt fails is returned.
    """
    candidate = 8050
    while True:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            occupied = probe.connect_ex(("localhost", candidate)) == 0
        if not occupied:
            return candidate
        candidate += 1

36 

37 

def load_new_data(file_path: Path) -> pd.DataFrame:
    """Read the rows appended to ``file_path`` since the previous call.

    The read position per file is kept in the module-level ``file_positions``.
    On the first read (position 0), the three header rows are parsed as column
    levels and the third level is dropped, leaving two-level columns. Later
    reads parse headerless rows, so their columns are positional integers.

    The file may still be written to while it is polled. Two safeguards follow
    from that:

    - Only complete, newline-terminated lines are consumed. A trailing line the
      writer has not finished is left for the next call.
    - The header is only consumed once at least one data row follows it.
      Otherwise the column labels would be lost for all later reads.

    Args:
        file_path: Result CSV file to read.

    Returns:
        The newly parsed rows. The result is an empty DataFrame if there is
        nothing new yet or the file could not be read.
    """
    start = file_positions[file_path]
    try:
        with file_path.open("rb") as f:
            f.seek(start)
            chunk = f.read()
    except IOError as e:
        print(f"Error reading file {file_path}: {e}")
        return pd.DataFrame()

    # Drop a partially written last line; it is picked up on the next call.
    last_newline = chunk.rfind(b"\n")
    if last_newline == -1:
        return pd.DataFrame()
    chunk = chunk[: last_newline + 1]
    # Decode exactly like ``open(..., "r")`` would (default encoding,
    # universal newlines).
    text = io.TextIOWrapper(io.BytesIO(chunk)).read()

    if start == 0:
        try:
            df = pd.read_csv(io.StringIO(text), index_col=0, header=[0, 1, 2])
        except (pd.errors.ParserError, pd.errors.EmptyDataError):
            # Header not (completely) written yet: retry from the start later.
            return pd.DataFrame()
        if df.empty:
            # Header without data rows: do not consume it yet.
            return pd.DataFrame()
        df.columns = df.columns.droplevel(2)
    else:
        df = pd.read_csv(io.StringIO(text), index_col=0, header=None)

    # Commit the position only after the chunk was parsed successfully.
    file_positions[file_path] = start + last_newline + 1
    return df

56 

57 

def update_data(existing_data: pd.DataFrame, new_data: pd.DataFrame) -> pd.DataFrame:
    """Append freshly read rows to the data already cached for a file.

    Chunks read after the first one carry positional column labels. When there
    is existing data, ``new_data`` is therefore relabelled with its columns
    (the column counts must match). The caller's ``new_data`` is left
    untouched.

    Rows that repeat an earlier row exactly are dropped. A repeat means the
    same index value and the same values.

    Args:
        existing_data: Data cached so far (may be empty).
        new_data: Newly parsed rows.

    Returns:
        The combined frame.
    """
    if existing_data.empty:
        frames = [new_data]
    else:
        frames = [existing_data, new_data.set_axis(existing_data.columns, axis=1)]
    combined = pd.concat(frames, axis=0)

    # ``drop_duplicates`` ignores the index, which would discard distinct time
    # steps that merely share the same values. Compare (index, values) rows.
    row_keys = combined.set_axis(range(combined.shape[1]), axis=1)
    row_keys.insert(0, "__index__", combined.index.to_numpy())
    return combined.loc[~row_keys.duplicated().to_numpy()]

62 

63 

def format_time_axis(seconds):
    """Pick a readable unit for a time span given in seconds.

    Spans under 5 minutes stay in seconds and are returned unchanged. Longer
    spans are scaled to minutes, hours, days, weeks or (30-day) months.

    Returns:
        Tuple ``(scaled_value, unit_label, tick_format)``.
    """
    if seconds < 60 * 5:
        return seconds, "s", "{:.0f}"
    # (exclusive upper bound in seconds, seconds per unit, unit label)
    thresholds = (
        (3600 * 4, 60, "min"),
        (86400 * 3, 3600, "h"),
        (604800 * 2, 86400, "d"),
        (2592000 * 2, 604800, "w"),
    )
    for upper_bound, unit_seconds, label in thresholds:
        if seconds < upper_bound:
            return seconds / unit_seconds, label, "{:.1f}"
    return seconds / 2592000, "mo", "{:.1f}"

79 

80 

def create_plot(df: pd.Series, title: str, plot_id: str) -> html.Div:
    """Wrap a line plot of one variable's time series in an ``html.Div``.

    The index is interpreted as time in seconds. The x axis is rescaled to
    the unit chosen by ``format_time_axis`` for the covered time span.

    Args:
        df: Values of one variable, indexed by time. A non-float64 index is
            converted with ``pd.to_numeric``; ``df`` itself is not modified.
        title: Plot title and trace name.
        plot_id: Index of the graph's pattern-matching id. It is also used as
            ``uirevision`` to help keep the zoom state across re-renders.

    Returns:
        A ``html.Div`` containing a single ``dcc.Graph``.
    """
    # Convert a local copy of the index instead of overwriting ``df.index``.
    time_index = df.index
    if time_index.dtype != "float64":
        time_index = pd.to_numeric(time_index)

    # Determine the appropriate time unit
    time_range = time_index.max() - time_index.min()
    scaled_time, time_unit, tick_format = format_time_axis(time_range)

    # With a single sample time_range is 0, and 0/0 would turn every x value
    # into NaN (blank plot). Spans that short are shown in seconds, so 1 is
    # the right factor.
    scale_factor = time_range / scaled_time if scaled_time else 1
    x_values = time_index / scale_factor

    return html.Div(
        [
            dcc.Graph(
                id={"type": "plot", "index": plot_id},
                figure={
                    "data": [
                        go.Scatter(x=x_values, y=df.values, mode="lines", name=title)
                    ],
                    "layout": go.Layout(
                        title=title,
                        xaxis={
                            "title": f"Time ({time_unit})",
                            "tickformat": tick_format,
                            "hoverformat": ".2f",
                        },
                        yaxis={"title": "Value"},
                        margin=dict(l=40, r=20, t=40, b=30),
                        height=250,
                        uirevision=plot_id,  # This helps maintain zoom state
                    ),
                },
                config={"displayModeBar": False},
                style={"height": "100%", "width": "100%"},
            )
        ]
    )

120 

121 

def create_layout(file_names: List[Union[str, Path]]) -> html.Div:
    """Build the static page layout of the dashboard.

    The layout has three parts:

    - One tab per file (labelled with the file stem), with the first file
      preselected.
    - A row with the plot area and the variable checkbox column.
    - An interval component that fires every 2.5 s to trigger refreshes.
    """
    paths = list(map(Path, file_names))
    default_tab = str(paths[0]) if paths else None

    tabs = dcc.Tabs(
        id="agent-tabs",
        children=[dcc.Tab(label=path.stem, value=str(path)) for path in paths],
        value=default_tab,
    )
    plot_column = dbc.Col(
        html.Div(id="tab-content"), width=12, lg=9, className="pr-lg-0"
    )
    checkbox_column = dbc.Col(
        html.Div(id="variable-checkboxes", className="mt-3 mt-lg-0"),
        width=12,
        lg=3,
        className="pl-lg-0",
    )
    content_row = dbc.Row([plot_column, checkbox_column], className="mt-3")
    refresh_timer = dcc.Interval(
        id="interval-component",
        interval=2.5 * 1000,
        n_intervals=0,
    )
    return html.Div([tabs, content_row, refresh_timer])

155 

156 

# HTML page template for the Dash app; it is assigned to ``app.index_string``.
# The ``{%...%}`` placeholders are Dash's template slots. The extra
# ``.checkbox-scroll`` CSS class is applied to the variable checkbox list. It
# limits the list's height and makes it scroll on its own, and on viewports at
# least 992px wide it sticks 20px from the top while the plots scroll.
index_string = """
<!DOCTYPE html>
<html>
 <head>
 {%metas%}
 <title>{%title%}</title>
 {%favicon%}
 {%css%}
 <style>
 .checkbox-scroll {
 max-height: calc(100vh - 100px);
 overflow-y: auto;
 padding-right: 15px;
 }
 @media (min-width: 992px) {
 .checkbox-scroll {
 position: sticky;
 top: 20px;
 }
 }
 </style>
 </head>
 <body>
 {%app_entry%}
 <footer>
 {%config%}
 {%scripts%}
 {%renderer%}
 </footer>
 </body>
</html>
"""

189 

190 

def _iter_causality_data(file_data: pd.DataFrame, causalities):
    """Yield ``(causality, file_data[causality])`` for each causality present.

    Causalities whose lookup raises ``KeyError`` are skipped. The order of
    ``causalities`` is preserved.
    """
    for causality in causalities:
        try:
            causality_data = file_data[causality]
        except KeyError:
            continue
        yield causality, causality_data


def simulator_dashboard(*file_names: Union[str, Path]):
    """Serve a live Dash dashboard for one or more simulator result files.

    Each file gets its own tab. While a tab is selected, its file is re-read
    incrementally on every tick of the layout's interval component. One plot
    per variable is shown, grouped by causality. A checkbox list beside the
    plots hides or shows individual variables.

    The dashboard is opened in a new browser tab, and this call blocks while
    the server runs.

    Args:
        *file_names: Paths of the result CSV files, one tab per file.
    """
    app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
    app.layout = create_layout(file_names)
    app.index_string = index_string

    # Order in which the causality sections are stacked on the plot page.
    plot_order = [
        datamodels.Causality.output,
        datamodels.Causality.input,
        datamodels.Causality.local,
        datamodels.Causality.parameter,
    ]

    def get_file_data(selected_tab):
        """Return the cached data for a tab, registering an empty frame if new."""
        return data.setdefault(str(Path(selected_tab)), pd.DataFrame())

    @app.callback(
        Output("variable-checkboxes", "children"),
        Input("agent-tabs", "value"),
        Input("interval-component", "n_intervals"),
        State({"type": "variable-checkbox", "index": ALL}, "value"),
        State({"type": "variable-checkbox", "index": ALL}, "id"),
    )
    def update_checkboxes(selected_tab, n_intervals, checkbox_values, checkbox_ids):
        """Render one checkbox per variable of the selected file.

        This callback also listens to the interval. Data is only loaded by
        ``update_tab_content``, so on first page load the cache is still
        empty. Without the interval trigger, the initial tab and any variables
        that appear later would never get checkboxes.
        """
        if not selected_tab:
            return html.Div("Please select a tab to view variables.")

        file_data = get_file_data(selected_tab)
        groups = [
            (causality, list(causality_data.columns))
            for causality, causality_data in _iter_causality_data(
                file_data, datamodels.Causality
            )
        ]
        keys = [
            f"{causality.name}-{column}"
            for causality, columns in groups
            for column in columns
        ]
        existing_keys = [checkbox_id["index"] for checkbox_id in checkbox_ids]
        tab_changed = any(
            trigger["prop_id"].startswith("agent-tabs.")
            for trigger in callback_context.triggered
        )
        if not tab_changed and keys == existing_keys:
            # Same variables as displayed; re-rendering would reset the ticks.
            return dash.no_update

        # A refresh keeps the user's choices; a tab switch starts all checked.
        previous = {} if tab_changed else dict(zip(existing_keys, checkbox_values))
        checkbox_groups = [
            html.Div(
                [
                    html.H5(causality.name.capitalize()),
                    html.Div(
                        [
                            dbc.Checkbox(
                                id={
                                    "type": "variable-checkbox",
                                    "index": f"{causality.name}-{column}",
                                },
                                label=column,
                                value=previous.get(f"{causality.name}-{column}", True),
                            )
                            for column in columns
                        ]
                    ),
                ]
            )
            for causality, columns in groups
        ]
        return html.Div(checkbox_groups, className="checkbox-scroll")

    @app.callback(
        Output("tab-content", "children"),
        Input("agent-tabs", "value"),
        Input("interval-component", "n_intervals"),
        Input({"type": "variable-checkbox", "index": ALL}, "value"),
        State({"type": "variable-checkbox", "index": ALL}, "id"),
    )
    def update_tab_content(selected_tab, n_intervals, checkbox_values, checkbox_ids):
        """Load new rows for the selected file and render its plots."""
        if not selected_tab:
            return html.Div(
                "Please select a tab to view data.", style={"padding": "20px"}
            )

        file_path = Path(selected_tab)
        file_data = get_file_data(selected_tab)
        new_data = load_new_data(file_path)
        if not new_data.empty:
            file_data = update_data(file_data, new_data)
            data[str(file_path)] = file_data

        # Variables without a checkbox yet (e.g. just appeared) are shown.
        selected_variables = {
            checkbox_id["index"]: value
            for checkbox_id, value in zip(checkbox_ids, checkbox_values)
        }

        sections = []
        for causality, causality_data in _iter_causality_data(file_data, plot_order):
            plots = [
                html.Div(
                    create_plot(
                        causality_data[column], column, f"{causality.name}-{column}"
                    ),
                    style={
                        "width": "33%",
                        "display": "inline-block",
                        "padding": "10px",
                    },
                )
                for column in causality_data.columns
                if selected_variables.get(f"{causality.name}-{column}", True)
            ]
            if not plots:
                continue
            sections.append(
                html.Div(
                    [
                        html.H3(
                            causality.name.capitalize(),
                            style={"padding-left": "10px"},
                        ),
                        html.Div(
                            plots, style={"display": "flex", "flexWrap": "wrap"}
                        ),
                    ]
                )
            )

        return html.Div(sections)

    port = get_port()
    webbrowser.open_new_tab(f"http://localhost:{port}")
    # ``run_server`` was removed in Dash 3 in favour of ``run``; fall back to it
    # only on old Dash versions that do not provide ``run``.
    run = app.run if hasattr(app, "run") else app.run_server
    run(debug=False, port=port)