import webbrowser
from pathlib import Path
from typing import Dict, Optional, Literal, Any, Tuple
import dash
import pandas as pd
import plotly.graph_objects as go
from dash import html, dcc
from dash.dependencies import Input, Output, State
# Keep existing imports
from agentlib_mpc.utils import TIME_CONVERSION
from agentlib_mpc.utils.plotting.basic import EBCColors
from agentlib_mpc.utils.plotting.interactive import get_port, obj_plot, solver_return
from agentlib_mpc.utils.plotting.mpc import interpolate_colors
def reduce_triple_index(df: pd.DataFrame) -> pd.DataFrame:
    """
    Reduce a triple-indexed DataFrame to a double index by keeping only the rows
    with the largest level 1 index for each unique level 0 index.

    Frames whose index has at most two levels (including a plain, single-level
    index) are returned unchanged.

    Args:
        df: DataFrame with either double or triple index

    Returns:
        DataFrame with the middle index level removed (double index for a
        triple-indexed input)
    """
    if not isinstance(df.index, pd.MultiIndex) or df.index.nlevels <= 2:
        return df
    # Largest level-1 value within each level-0 group, broadcast back to every
    # row so that the row selection is a single vectorized comparison instead
    # of one ``df.loc`` lookup per group.
    sub_idx = df.index.get_level_values(1)
    group_max = (
        pd.Series(sub_idx.to_numpy(), index=df.index)
        .groupby(level=0)
        .transform("max")
    )
    mask = sub_idx == group_max.to_numpy()
    # Keep the selected rows and drop the now redundant middle level.
    return df[mask].droplevel(1)
def is_mhe_data(series: pd.Series) -> bool:
    """
    Guess whether ``series`` holds MHE (Moving Horizon Estimator) results
    instead of MPC predictions.

    For every prediction start time (first index level) the offsets remaining
    in the index after selecting that start time are classified as negative
    (past) or non-negative. The data is treated as MHE data when at least one
    offset is negative and the total number of non-negative offsets does not
    exceed the number of start times, i.e. the horizons consist mostly of past
    values.

    Args:
        series: Series of predictions with time steps as index

    Returns:
        bool: True if the data appears to be MHE data, False otherwise
    """
    start_times = series.index.unique(level=0)
    past_total = 0
    present_or_future_total = 0
    for start in start_times:
        offsets = series.xs(start, level=0).index
        past_total += int((offsets < 0).sum())
        present_or_future_total += int((offsets >= 0).sum())
    return past_total > 0 and present_or_future_total <= len(start_times)
def plot_mpc_plotly(
    series: pd.Series,
    step: bool = False,
    convert_to: Literal["seconds", "minutes", "hours", "days"] = "seconds",
    y_axis_label: str = "",
    use_datetime: bool = False,
    max_predictions: int = 1000,
) -> go.Figure:
    """
    Create a plotly figure from MPC prediction series.

    Each prediction (one per first-level time step) becomes its own trace,
    colored from red (first plotted) to dark grey (last plotted). The value at
    relative time 0 of each plotted prediction is collected into a black
    "Actual Values" trace. If ``is_mhe_data`` classifies the series as MHE
    data, negative relative times are kept, traces are dashed and an
    annotation is added to the figure.

    Args:
        series: Series of MPC predictions with time steps as index
        step: Whether to display step plots (True) or continuous lines (False)
        convert_to: Unit for time conversion
        y_axis_label: Label for y-axis
        use_datetime: Whether to interpret timestamps as datetime (seconds
            since the Unix epoch, displayed in the Europe/Berlin timezone)
        max_predictions: Maximum number of predictions to show (for
            performance). Values below 1 are treated as 1; the most recent
            prediction is always shown.

    Returns:
        Plotly figure object
    """
    fig = go.Figure()
    unique_times = series.index.unique(level=0)
    number_of_predictions = len(unique_times)
    # Detect if this is MHE data
    is_mhe = is_mhe_data(series)

    # Sample predictions if there are too many. One slot is always reserved
    # for the most recent prediction; clamping to >= 1 also avoids a division
    # by zero (or a negative slice step) when computing ``step_size``.
    max_predictions = max(1, max_predictions)
    if number_of_predictions > max_predictions:
        remaining_slots = max_predictions - 1
        most_recent = unique_times[-1:]
        if remaining_slots == 0:
            selected_times = most_recent
        else:
            step_size = (number_of_predictions - 1) // remaining_slots
            evenly_spaced = unique_times[:-1:step_size][:remaining_slots]
            selected_times = evenly_spaced.append(most_recent)
        number_of_predictions = max_predictions
    else:
        selected_times = unique_times

    # For MHE data, use a different line style to visually distinguish from MPC
    line_style = "dash" if is_mhe else None
    line_width = 1.0 if is_mhe else 0.7

    # stores the first value of each prediction (only for selected times)
    actual_values: Dict[Any, float] = {}
    for i, time_seconds in enumerate(selected_times):
        prediction: pd.Series = series.xs(time_seconds, level=0).dropna()
        # For MPC, only show future values (index >= 0).
        # For MHE, show all values including past (don't filter).
        if not is_mhe:
            prediction = prediction[prediction.index >= 0]

        if use_datetime:
            base_time = pd.Timestamp(time_seconds, unit="s", tz="UTC")
            time_converted = base_time.tz_convert("Europe/Berlin")
        else:
            time_converted = time_seconds / TIME_CONVERSION[convert_to]

        # The reference point is relative time 0 for both MPC and MHE; it must
        # be looked up before the index is shifted to absolute time below.
        try:
            actual_values[time_converted] = prediction.loc[0]
        except KeyError:
            pass

        if use_datetime:
            timedeltas = pd.to_timedelta(prediction.index, unit="s")
            prediction.index = (base_time + timedeltas).tz_convert("Europe/Berlin")
        else:
            prediction.index = (prediction.index + time_seconds) / TIME_CONVERSION[
                convert_to
            ]

        prediction_color = interpolate_colors(
            progress=i / number_of_predictions,
            colors=[EBCColors.red, EBCColors.dark_grey],
        )
        fig.add_trace(
            go.Scattergl(
                x=prediction.index,
                y=prediction,
                mode="lines",
                line=dict(
                    color=f"rgb{prediction_color}",
                    width=line_width,
                    shape="hv" if step else None,
                    dash=line_style,
                ),
                name=(
                    f"{time_converted}"
                    if use_datetime
                    else f"{time_converted} {convert_to[0]}"
                ),
                legendgroup="Prediction",
                legendgrouptitle_text="Predictions",
                visible=True,
                legendrank=i + 2,
            )
        )

    actual_series = pd.Series(actual_values)
    fig.add_trace(
        go.Scattergl(
            x=actual_series.index,
            y=actual_series,
            mode="lines",
            line=dict(color="black", width=1.5, shape="hv" if step else None),
            name="Actual Values",
            legendrank=1,
        )
    )
    # Add annotation to indicate if this is MHE data
    if is_mhe:
        fig.add_annotation(
            x=0.05,
            y=0.95,
            xref="paper",
            yref="paper",
            text="MHE Data (includes past values)",
            showarrow=False,
            font=dict(color="red", size=12),
            bgcolor="rgba(255, 255, 255, 0.8)",
            bordercolor="red",
            borderwidth=1,
            borderpad=4,
        )
    x_axis_label = "Time" if use_datetime else f"Time in {convert_to}"
    fig.update_layout(
        showlegend=True,
        legend=dict(
            groupclick="toggleitem",
            itemclick="toggle",
            itemdoubleclick="toggleothers",
        ),
        xaxis_title=x_axis_label,
        yaxis_title=y_axis_label,
        uirevision="same",
    )
    return fig
def _variable_graph(
    series: pd.Series,
    name: str,
    convert_to: str,
    use_datetime: bool,
    step: bool,
) -> html.Div:
    """Wrap the prediction plot of one variable in a draggable Div."""
    return html.Div(
        [
            dcc.Graph(
                id=f"plot-{name}",
                figure=plot_mpc_plotly(
                    series,
                    step=step,
                    convert_to=convert_to,
                    y_axis_label=name,
                    use_datetime=use_datetime,
                ),
                style={
                    "min-width": "600px",
                    "min-height": "400px",
                    "max-width": "900px",
                    "max-height": "450px",
                },
            ),
        ],
        className="draggable",
    )


def make_components(
    data: pd.DataFrame,
    convert_to: str,
    stats: Optional[pd.DataFrame] = None,
    use_datetime: bool = False,
    step: bool = False,
) -> html.Div:
    """
    Create dashboard components from MPC data and stats.

    If stats are given, the solver-return and objective plots (when the
    helpers produce them) come first. After them there is one graph per
    variable:
    - MultiIndex columns: every ``("variable", <name>)`` column.
    - Flat columns: every string column starting with ``variable_``. The
      prefix is stripped only from the start of the name.

    Args:
        data: DataFrame with MPC data
        convert_to: Time unit for plotting
        stats: Optional DataFrame with MPC statistics
        use_datetime: Whether to interpret timestamps as datetime
        step: Whether to use step plots

    Returns:
        Dash HTML Div containing all components
    """
    components = []
    # Add statistics components if available
    if stats is not None:
        solver_plot = solver_return(stats, convert_to)
        if solver_plot is not None:
            components.append(html.Div([solver_plot]))
        obj_value_plot = obj_plot(stats, convert_to)
        if obj_value_plot is not None:
            components.append(html.Div([obj_value_plot]))

    # Create one component for each variable
    if isinstance(data.columns, pd.MultiIndex):
        for var_type, column in data.columns:
            if var_type == "variable":
                components.append(
                    _variable_graph(
                        data[var_type][column], column, convert_to, use_datetime, step
                    )
                )
    else:
        # Flattened column names such as "variable_T"; non-string column
        # labels cannot carry the prefix and are skipped.
        prefix = "variable_"
        for column in data.columns:
            if isinstance(column, str) and column.startswith(prefix):
                column_name = column[len(prefix):]
                components.append(
                    _variable_graph(
                        data[column], column_name, convert_to, use_datetime, step
                    )
                )

    return html.Div(
        components,
        style={
            "display": "grid",
            "grid-template-columns": "repeat(auto-fit, minmax(600px, 1fr))",
            "grid-gap": "20px",
            "padding": "20px",
            "min-width": "600px",
            "min-height": "200px",
        },
        id="plot-container",
    )
def detect_index_type(data: pd.DataFrame) -> Tuple[bool, bool]:
    """
    Detect the type of index in the DataFrame.

    Only the first index level is inspected for the datetime check. It is
    considered datetime-like if it has a datetime dtype, or if it is numeric
    and its largest value exceeds 1e9 (simple heuristic for Unix timestamps
    in seconds).

    Args:
        data: DataFrame to check

    Returns:
        Tuple of (is_multi_index, is_datetime) as plain Python bools
    """
    is_multi_index = isinstance(data.index, pd.MultiIndex)
    # ``unique(level=0)`` works for plain and MultiIndex alike and, unlike
    # ``MultiIndex.levels``, only contains values that are actually present
    # (levels keep stale entries after slicing or filtering).
    first_level = data.index.unique(level=0)
    if pd.api.types.is_datetime64_any_dtype(first_level):
        is_datetime = True
    elif pd.api.types.is_numeric_dtype(first_level):
        # An empty level yields NaN here, which compares as False.
        is_datetime = bool(first_level.max() > 1e9)
    else:
        is_datetime = False
    return is_multi_index, is_datetime
def _dashboard_module_ids(agent_data: Any) -> list:
    """Return the ids of one agent's DataFrame modules that should be selectable.

    ``<module>_stats`` entries are skipped when ``<module>`` is also a
    DataFrame module of the agent, because those stats are displayed together
    with ``<module>``. Non-dict agent data yields an empty list.
    """
    if not isinstance(agent_data, dict):
        return []
    frame_ids = [
        module_id
        for module_id, module_data in agent_data.items()
        if isinstance(module_data, pd.DataFrame)
    ]
    frame_id_set = set(frame_ids)
    suffix = "_stats"
    return [
        module_id
        for module_id in frame_ids
        if not (
            isinstance(module_id, str)
            and module_id.endswith(suffix)
            and module_id[: -len(suffix)] in frame_id_set
        )
    ]


def show_multi_room_dashboard(
    results: Dict[str, Dict[str, Any]], scale: str = "hours", step: bool = False
):
    """
    Show a dashboard with dropdown selection for different agents/rooms.

    The index layout (MultiIndex / datetime) is detected separately for each
    selected module, so agents with differently shaped results can be mixed.
    Numeric (non-datetime) MultiIndex data is shifted so that its first
    prediction time is 0. This is done on a copy; ``results`` itself is not
    modified. Finally the Dash server is started on the port returned by
    ``get_port()``, and a browser tab is opened for it.

    Args:
        results: Dictionary with agent results from mas.get_results()
        scale: Time scale for plotting ("seconds", "minutes", "hours", "days")
        step: Whether to use step plots

    Raises:
        ValueError: If ``results`` is empty or no agent holds a DataFrame.
    """
    app = dash.Dash(__name__, title="Multi-Agent MPC Results")
    # Get all agents
    agent_ids = list(results.keys())
    if not agent_ids:
        raise ValueError("No agents found in results dictionary")
    # The first agent with at least one DataFrame module is preselected.
    first_agent_id = next(
        (
            agent_id
            for agent_id in agent_ids
            if _dashboard_module_ids(results[agent_id])
        ),
        None,
    )
    if first_agent_id is None:
        raise ValueError("No valid MPC data found in results")

    # Create agent and module selector dropdowns
    app.layout = html.Div(
        [
            html.H1("Multi-Agent MPC Results"),
            html.Div(
                [
                    html.Div(
                        [
                            html.Label("Select Agent:"),
                            dcc.Dropdown(
                                id="agent-selector",
                                options=[
                                    {"label": agent_id, "value": agent_id}
                                    for agent_id in agent_ids
                                ],
                                value=first_agent_id,
                            ),
                        ],
                        style={
                            "width": "300px",
                            "margin": "10px",
                            "display": "inline-block",
                        },
                    ),
                    html.Div(
                        [
                            html.Label("Select Module:"),
                            dcc.Dropdown(
                                id="module-selector",
                                # Options will be set by callback
                            ),
                        ],
                        style={
                            "width": "300px",
                            "margin": "10px",
                            "display": "inline-block",
                        },
                    ),
                ],
            ),
            html.Div(
                html.Button(
                    "Toggle Step Plot", id="toggle-step", style={"margin": "10px"}
                )
            ),
            html.Div(id="agent-dashboard"),
            dcc.Store(id="step-state", data=step),
        ]
    )

    @app.callback(
        [Output("module-selector", "options"), Output("module-selector", "value")],
        [Input("agent-selector", "value")],
    )
    def update_module_options(selected_agent):
        if selected_agent not in results:
            return [], None
        module_ids = _dashboard_module_ids(results[selected_agent])
        module_options = [
            {"label": module_id, "value": module_id} for module_id in module_ids
        ]
        return module_options, (module_ids[0] if module_ids else None)

    @app.callback(
        Output("step-state", "data"),
        [Input("toggle-step", "n_clicks")],
        [State("step-state", "data")],
    )
    def toggle_step_plot(n_clicks, current_step):
        if n_clicks:
            return not current_step
        return current_step

    @app.callback(
        Output("agent-dashboard", "children"),
        [
            Input("agent-selector", "value"),
            Input("module-selector", "value"),
            Input("step-state", "data"),
        ],
    )
    def update_dashboard(selected_agent, selected_module, step_state):
        if not selected_agent or not selected_module:
            return html.Div("Please select both an agent and a module")
        agent_data = results.get(selected_agent)
        data = (
            agent_data.get(selected_module) if isinstance(agent_data, dict) else None
        )
        if not isinstance(data, pd.DataFrame):
            return html.Div("Selected module does not contain valid MPC data")

        # Detect the index layout for this particular module.
        is_multi_index, use_datetime = detect_index_type(data)

        # Reduce triple index to double index if needed
        if is_multi_index and data.index.nlevels > 2:
            data = reduce_triple_index(data)

        # Shift numeric prediction times so the first one starts at 0. Work on
        # a shallow copy so the DataFrame stored in ``results`` is untouched.
        if is_multi_index and not use_datetime:
            first_time = data.index.get_level_values(0).min()
            data = data.copy(deep=False)
            data.index = data.index.set_levels(
                data.index.levels[0] - first_time, level=0
            )

        # Get stats data if available
        stats = agent_data.get(f"{selected_module}_stats")
        if not isinstance(stats, pd.DataFrame):
            stats = None

        # Create the dashboard components
        return make_components(
            data=data,
            convert_to=scale,
            stats=stats,
            use_datetime=use_datetime,
            step=step_state,
        )

    # Launch the dashboard
    port = get_port()
    webbrowser.open_new_tab(f"http://localhost:{port}")
    app.run(debug=False, port=port)
def launch_dashboard_from_results(
    results: Dict[str, Dict[str, Any]], scale: str = "hours", step: bool = False
) -> bool:
    """
    Launch the multi-agent dashboard from results dictionary returned by mas.get_results().

    Before launching, every module of every (dict-valued) agent is scanned for
    usable data. A module is usable if it is a DataFrame and either:
    - has a MultiIndex with more than one level (typical MPC data), or
    - has a single-level index and at least one row.

    Args:
        results: Dictionary with agent results from mas.get_results()
        scale: Time scale for plotting ("seconds", "minutes", "hours", "days")
        step: Whether to use step plots

    Returns:
        bool: True once ``show_multi_room_dashboard`` has returned. Invalid
        input raises ``ValueError`` instead of returning False.

    Raises:
        ValueError: If ``results`` is not a non-empty dict, or if no agent
            holds usable data.
    """
    # Check the type first: truth-testing e.g. a DataFrame would itself raise.
    if not isinstance(results, dict) or not results:
        raise ValueError("Invalid results: Expected non-empty dictionary")

    def _is_usable(module_data: Any) -> bool:
        if not isinstance(module_data, pd.DataFrame):
            return False
        if isinstance(module_data.index, pd.MultiIndex):
            # This looks like MPC data with multi-level index
            return module_data.index.nlevels > 1
        # Single level index might still be valid for some data
        return module_data.shape[0] > 0

    # Every module is checked; an unusable frame does not end the scan of
    # its agent early.
    valid_data_found = any(
        _is_usable(module_data)
        for agent_data in results.values()
        if isinstance(agent_data, dict)
        for module_data in agent_data.values()
    )
    if not valid_data_found:
        raise ValueError("No valid MPC data found in results")
    # Launch the dashboard without catching exceptions
    print(f"Launching dashboard with scale={scale}")
    show_multi_room_dashboard(results, scale=scale, step=step)
    return True
def process_mas_results(
    results: Dict[str, Dict[str, Any]],
) -> Dict[str, Dict[str, Any]]:
    """
    Process results from LocalMASAgency to prepare them for visualization.

    For every agent, only DataFrames that look like MPC data are kept. That
    means a MultiIndex row index combined with either MultiIndex columns or
    flattened string column names starting with ``variable_`` or
    ``parameter_``.

    For every kept module ``<id>``, a DataFrame stored under ``<id>_stats`` is
    kept as well. Every agent of ``results`` appears in the output, with an
    empty dict if nothing qualifies or its data is not a dict.

    Args:
        results: Raw results from mas.get_results()

    Returns:
        Processed results ready for dashboard visualization
    """
    flat_prefixes = ("variable_", "parameter_")

    def _looks_like_mpc(frame: pd.DataFrame) -> bool:
        if not isinstance(frame.index, pd.MultiIndex):
            return False
        if isinstance(frame.columns, pd.MultiIndex):
            # This is likely MPC data with variables, parameters, etc.
            return True
        # This might be MPC data with flattened column names
        return any(
            isinstance(col, str) and col.startswith(flat_prefixes)
            for col in frame.columns
        )

    processed_results: Dict[str, Dict[str, Any]] = {}
    for agent_id, agent_data in results.items():
        kept: Dict[str, Any] = {}
        processed_results[agent_id] = kept
        if not isinstance(agent_data, dict):
            continue
        for module_id, module_data in agent_data.items():
            if not isinstance(module_data, pd.DataFrame):
                continue
            if not _looks_like_mpc(module_data):
                continue
            kept[module_id] = module_data
            # Stats belong to an accepted MPC module and are kept alongside it.
            stats_module_id = f"{module_id}_stats"
            stats = agent_data.get(stats_module_id)
            if isinstance(stats, pd.DataFrame):
                kept[stats_module_id] = stats
    return processed_results
if __name__ == "__main__":
    # Example usage
    import sys

    # Optional first CLI argument: a directory with stored results.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("No directory specified. Please provide a directory path.")
    else:
        path = Path(cli_args[0])
        if not (path.exists() and path.is_dir()):
            raise FileNotFoundError(f"Directory not found: {path}")
        print(f"Loading data from directory: {path}")
        # TODO: loading from files is not implemented in this module; the
        # intended helper show_multi_room_dashboard_from_files(path,
        # scale="hours") is not defined here.