import re
import webbrowser
from pathlib import Path
from typing import Dict, Union, Optional, Literal, Any, List, Tuple
import dash
import h5py
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from dash import html, dcc
from dash.dependencies import Input, Output, State
from agentlib_mpc.utils import TIME_CONVERSION
from agentlib_mpc.utils.analysis import load_mpc, load_mpc_stats
from agentlib_mpc.utils.plotting.basic import EBCColors
from agentlib_mpc.utils.plotting.interactive import get_port, obj_plot, solver_return
from agentlib_mpc.utils.plotting.mpc import interpolate_colors

def reduce_triple_index(df: pd.DataFrame) -> pd.DataFrame:
    """
    Reduce a triple-indexed DataFrame to a double index by keeping only the rows
    with the largest level 1 index for each unique level 0 index.
    Args:
        df: DataFrame with either double or triple index
    Returns:
        DataFrame with double index
    """
    if len(df.index.levels) == 2:
        return df
    # For each unique level-0 (time) value, keep only the row with the
    # largest level-1 index
    unique_times = df.index.unique(level=0)
    # Create a mask for the rows we want to keep
    mask = pd.Series(False, index=df.index)
    for time in unique_times:
        max_sub_idx = df.loc[time].index.get_level_values(0).max()
        mask.loc[(time, max_sub_idx)] = True
    # Apply the mask and drop the middle level
    return df[mask].droplevel(1) 
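
# Usage sketch with hypothetical toy data (not part of the module): for each
# outer time step, only the rows of the last inner iteration survive.
#
#     idx = pd.MultiIndex.from_product([[0.0, 900.0], [0, 1], [0.0, 300.0]])
#     df = pd.DataFrame({"x": range(8)}, index=idx)
#     reduced = reduce_triple_index(df)   # keeps only the iteration == 1 rows
#     assert reduced.index.nlevels == 2 and len(reduced) == 4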

def is_mhe_data(series: pd.Series) -> bool:
    """
    Detect if the data represents MHE (Moving Horizon Estimator) results
    rather than MPC predictions.
    Args:
        series: Series of predictions with time steps as index
    Returns:
        bool: True if the data appears to be MHE data, False otherwise
    """
    # Get the unique prediction time points
    unique_time_points = series.index.unique(level=0)
    # For each time point, check the distribution of indices
    negative_indices_count = 0
    positive_indices_count = 0
    for time_point in unique_time_points:
        prediction = series.xs(time_point, level=0)
        # Count negative and non-negative indices
        negative_indices_count += sum(prediction.index < 0)
        positive_indices_count += sum(prediction.index >= 0)
    # If negative indices dominate, with at most one non-negative index per
    # time point on average, the data is likely MHE output (which mainly
    # contains past states)
    return (
        negative_indices_count > 0
        and positive_indices_count <= unique_time_points.size
    )
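
# Usage sketch with hypothetical toy data: an MHE series carries mostly negative
# relative offsets (past states), while an MPC series carries non-negative ones.
#
#     mhe_idx = pd.MultiIndex.from_product(
#         [[0.0, 900.0], [-900.0, -600.0, -300.0, 0.0]]
#     )
#     mhe_series = pd.Series(range(8), index=mhe_idx)
#     is_mhe_data(mhe_series)   # True: negative offsets dominate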

def plot_mpc_plotly(
    series: pd.Series,
    step: bool = False,
    convert_to: Literal["seconds", "minutes", "hours", "days"] = "seconds",
    y_axis_label: str = "",
    use_datetime: bool = False,
    max_predictions: int = 1000,
) -> go.Figure:
    """
    Create a plotly figure from MPC prediction series.
    Args:
        series: Series of MPC predictions with time steps as index
        step: Whether to display step plots (True) or continuous lines (False)
        convert_to: Unit for time conversion
        y_axis_label: Label for y-axis
        use_datetime: Whether to interpret timestamps as datetime
        max_predictions: Maximum number of predictions to show (for performance)
    Returns:
        Plotly figure object
    """
    fig = go.Figure()
    predictions_grouped = series.groupby(level=0)
    number_of_predictions = predictions_grouped.ngroups
    # Detect if this is MHE data
    is_mhe = is_mhe_data(series)
    # Sample predictions if there are too many
    if number_of_predictions > max_predictions:
        # Always include the most recent prediction
        most_recent_time = series.index.unique(level=0)[-1]
        # Calculate step size for the remaining predictions
        remaining_slots = max_predictions - 1
        step_size = (number_of_predictions - 1) // remaining_slots
        # Select evenly spaced predictions and combine with most recent
        selected_times = series.index.unique(level=0)[:-1:step_size][:remaining_slots]
        selected_times = pd.Index(list(selected_times) + [most_recent_time])
        predictions_iterator = ((t, series.xs(t, level=0)) for t in selected_times)
        number_of_predictions = max_predictions
    else:
        selected_times = series.index.unique(level=0)
        predictions_iterator = ((t, series.xs(t, level=0)) for t in selected_times)
    # First value (relative index 0) of each selected prediction, keyed by the
    # converted start time (float, or pd.Timestamp if use_datetime is set)
    actual_values: Dict[Any, float] = {}
    for i, (time_seconds, prediction) in enumerate(predictions_iterator):
        prediction: pd.Series = prediction.dropna()
        # For MPC, only show future values (index >= 0)
        # For MHE, show all values including past (don't filter)
        if not is_mhe:
            prediction = prediction[prediction.index >= 0]
        if use_datetime:
            time_converted = pd.Timestamp(time_seconds, unit="s", tz="UTC").tz_convert(
                "Europe/Berlin"
            )
            relative_times = prediction.index
            try:
                # Both MPC and MHE use relative index 0 as the current time step
                actual_values[time_converted] = prediction.loc[0]
            except KeyError:
                pass
            timedeltas = pd.to_timedelta(relative_times, unit="s")
            base_time = pd.Timestamp(time_seconds, unit="s", tz="UTC")
            prediction.index = base_time + timedeltas
            prediction.index = prediction.index.tz_convert("Europe/Berlin")
        else:
            time_converted = time_seconds / TIME_CONVERSION[convert_to]
            try:
                actual_values[time_converted] = prediction.loc[0]
            except KeyError:
                pass
            prediction.index = (prediction.index + time_seconds) / TIME_CONVERSION[
                convert_to
            ]
        progress = i / number_of_predictions
        prediction_color = interpolate_colors(
            progress=progress,
            colors=[EBCColors.red, EBCColors.dark_grey],
        )
        # For MHE data, use a different line style to visually distinguish from MPC
        line_style = "dash" if is_mhe else None
        line_width = 1.0 if is_mhe else 0.7
        trace_kwargs = dict(
            x=prediction.index,
            y=prediction,
            mode="lines",
            line=dict(
                color=f"rgb{prediction_color}",
                width=line_width,
                shape="hv" if step else None,
                dash=line_style,
            ),
            name=(
                f"{time_converted}"
                if use_datetime
                else f"{time_converted} {convert_to[0]}"
            ),
            legendgroup="Prediction",
            legendgrouptitle_text="Predictions",
            visible=True,
            legendrank=i + 2,
        )
        fig.add_trace(go.Scattergl(**trace_kwargs))
    actual_series = pd.Series(actual_values)
    fig.add_trace(
        go.Scattergl(
            x=actual_series.index,
            y=actual_series,
            mode="lines",
            line=dict(color="black", width=1.5, shape="hv" if step else None),
            name="Actual Values",
            legendrank=1,
        )
    )
    # Add annotation to indicate if this is MHE data
    if is_mhe:
        fig.add_annotation(
            x=0.05,
            y=0.95,
            xref="paper",
            yref="paper",
            text="MHE Data (includes past values)",
            showarrow=False,
            font=dict(color="red", size=12),
            bgcolor="rgba(255, 255, 255, 0.8)",
            bordercolor="red",
            borderwidth=1,
            borderpad=4,
        )
    x_axis_label = "Time" if use_datetime else f"Time in {convert_to}"
    fig.update_layout(
        showlegend=True,
        legend=dict(
            groupclick="toggleitem",
            itemclick="toggle",
            itemdoubleclick="toggleothers",
        ),
        xaxis_title=x_axis_label,
        yaxis_title=y_axis_label,
        uirevision="same",
    )
    return fig 
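
# Usage sketch with hypothetical toy data: the outer index level is the
# prediction start time in seconds, the inner level the relative horizon offset.
#
#     idx = pd.MultiIndex.from_product(
#         [[0.0, 900.0, 1800.0], [0.0, 300.0, 600.0, 900.0]]
#     )
#     preds = pd.Series(np.random.rand(12), index=idx)
#     fig = plot_mpc_plotly(preds, step=True, convert_to="minutes",
#                           y_axis_label="T_room")
#     fig.show()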

def make_components(
    data: pd.DataFrame,
    convert_to: str,
    stats: Optional[pd.DataFrame] = None,
    use_datetime: bool = False,
    step: bool = False,
) -> html.Div:
    """
    Create dashboard components from MPC data and stats.
    Args:
        data: DataFrame with MPC data
        convert_to: Time unit for plotting
        stats: Optional DataFrame with MPC statistics
        use_datetime: Whether to interpret timestamps as datetime
        step: Whether to use step plots
    Returns:
        Dash HTML Div containing all components
    """
    components = []
    # Add statistics components if available
    if stats is not None:
        # Add solver iterations plot
        solver_plot = solver_return(stats, convert_to)
        if solver_plot is not None:
            components.insert(0, html.Div([solver_plot]))
        # Add objective plot if available
        obj_value_plot = obj_plot(stats, convert_to)
        if obj_value_plot is not None:
            components.insert(1, html.Div([obj_value_plot]))
    # Create one plot component per variable; errors are allowed to propagate
    # so that malformed data surfaces immediately
    if isinstance(data.columns, pd.MultiIndex):
        for var_type, column in data.columns:
            if var_type == "variable":
                components.append(
                    html.Div(
                        [
                            dcc.Graph(
                                id=f"plot-{column}",
                                figure=plot_mpc_plotly(
                                    data[var_type][column],
                                    step=step,
                                    convert_to=convert_to,
                                    y_axis_label=column,
                                    use_datetime=use_datetime,
                                ),
                                style={
                                    "min-width": "600px",
                                    "min-height": "400px",
                                    "max-width": "900px",
                                    "max-height": "450px",
                                },
                            ),
                        ],
                        className="draggable",
                    )
                )
    # Flat column index: variables are identified by a "variable_" prefix
    elif isinstance(data.columns, pd.Index):
        for column in data.columns:
            if column.startswith("variable_"):
                column_name = column.replace("variable_", "")
                components.append(
                    html.Div(
                        [
                            dcc.Graph(
                                id=f"plot-{column_name}",
                                figure=plot_mpc_plotly(
                                    data[column],
                                    step=step,
                                    convert_to=convert_to,
                                    y_axis_label=column_name,
                                    use_datetime=use_datetime,
                                ),
                                style={
                                    "min-width": "600px",
                                    "min-height": "400px",
                                    "max-width": "900px",
                                    "max-height": "450px",
                                },
                            ),
                        ],
                        className="draggable",
                    )
                )
    return html.Div(
        components,
        style={
            "display": "grid",
            "grid-template-columns": "repeat(auto-fit, minmax(600px, 1fr))",
            "grid-gap": "20px",
            "padding": "20px",
            "min-width": "600px",
            "min-height": "200px",
        },
        id="plot-container",
    ) 
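
# Usage sketch (hypothetical frame using the ("variable", name) column layout
# handled by the MultiIndex branch above); the returned html.Div can serve as a
# Dash app layout:
#
#     cols = pd.MultiIndex.from_tuples([("variable", "T_room")])
#     idx = pd.MultiIndex.from_product([[0.0, 900.0], [0.0, 300.0, 600.0]])
#     df = pd.DataFrame(np.random.rand(6, 1), index=idx, columns=cols)
#     layout = make_components(df, convert_to="minutes", step=True)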

def detect_index_type(data: pd.DataFrame) -> Tuple[bool, bool]:
    """
    Detect the type of index in the DataFrame.
    Args:
        data: DataFrame to check
    Returns:
        Tuple of (is_multi_index, is_datetime)
    """
    is_multi_index = isinstance(data.index, pd.MultiIndex)
    # Check if it's a datetime index (or the first level is datetime)
    if is_multi_index:
        first_level = data.index.levels[0]
        is_datetime = pd.api.types.is_datetime64_any_dtype(first_level)
        if not is_datetime:
            # Check if it might be a Unix timestamp (large integer values)
            if pd.api.types.is_numeric_dtype(first_level):
                is_datetime = (
                    first_level.max() > 1e9
                )  # Simple heuristic for Unix timestamp
    else:
        is_datetime = pd.api.types.is_datetime64_any_dtype(data.index)
        if not is_datetime and pd.api.types.is_numeric_dtype(data.index):
            is_datetime = data.index.max() > 1e9
    return is_multi_index, is_datetime 
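
# Usage sketch (hypothetical frames): relative float times vs. Unix timestamps,
# distinguished by the simple 1e9 heuristic above.
#
#     rel = pd.DataFrame(
#         {"x": [1.0, 2.0]},
#         index=pd.MultiIndex.from_tuples([(0.0, 0.0), (0.0, 900.0)]),
#     )
#     detect_index_type(rel)  # (True, False)
#     ts = rel.set_axis(pd.MultiIndex.from_tuples([(1.7e9, 0.0), (1.7e9, 900.0)]))
#     detect_index_type(ts)   # (True, True)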

def show_multi_room_dashboard(
    results: Dict[str, Dict[str, Any]], scale: str = "hours", step: bool = False
):
    """
    Show a dashboard with dropdown selection for different agents/rooms.
    Args:
        results: Dictionary with agent results from mas.get_results()
        scale: Time scale for plotting ("seconds", "minutes", "hours", "days")
        step: Whether to use step plots
    """
    app = dash.Dash(__name__, title="Multi-Agent MPC Results")
    # Get all agents
    agent_ids = list(results.keys())
    if not agent_ids:
        raise ValueError("No agents found in results dictionary")
    # Find first valid MPC data to determine index type
    first_agent_id = None
    first_module_id = None
    for agent_id in agent_ids:
        for module_id, module_data in results[agent_id].items():
            if isinstance(module_data, pd.DataFrame):
                first_agent_id = agent_id
                first_module_id = module_id
                break
        if first_agent_id:
            break
    if not first_agent_id:
        raise ValueError("No valid MPC data found in results")
    first_data = results[first_agent_id][first_module_id]
    is_multi_index, use_datetime = detect_index_type(first_data)
    # Create agent and module selector dropdowns
    app.layout = html.Div(
        [
            html.H1("Multi-Agent MPC Results"),
            html.Div(
                [
                    html.Div(
                        [
                            html.Label("Select Agent:"),
                            dcc.Dropdown(
                                id="agent-selector",
                                options=[
                                    {"label": agent_id, "value": agent_id}
                                    for agent_id in agent_ids
                                ],
                                value=first_agent_id,
                            ),
                        ],
                        style={
                            "width": "300px",
                            "margin": "10px",
                            "display": "inline-block",
                        },
                    ),
                    html.Div(
                        [
                            html.Label("Select Module:"),
                            dcc.Dropdown(
                                id="module-selector",
                                # Options will be set by callback
                            ),
                        ],
                        style={
                            "width": "300px",
                            "margin": "10px",
                            "display": "inline-block",
                        },
                    ),
                ],
            ),
            html.Div(
                html.Button(
                    "Toggle Step Plot", id="toggle-step", style={"margin": "10px"}
                )
            ),
            html.Div(id="agent-dashboard"),
            dcc.Store(id="step-state", data=step),
        ]
    )
    @app.callback(
        [Output("module-selector", "options"), Output("module-selector", "value")],
        [Input("agent-selector", "value")],
    )
    def update_module_options(selected_agent):
        if not selected_agent:
            return [], None
        module_options = []
        first_module = None
        for module_id, module_data in results[selected_agent].items():
            if isinstance(module_data, pd.DataFrame):
                module_options.append({"label": module_id, "value": module_id})
                if first_module is None:
                    first_module = module_id
        return module_options, first_module
    @app.callback(
        Output("step-state", "data"),
        [Input("toggle-step", "n_clicks")],
        [State("step-state", "data")],
    )
    def toggle_step_plot(n_clicks, current_step):
        if n_clicks:
            return not current_step
        return current_step
    @app.callback(
        Output("agent-dashboard", "children"),
        [
            Input("agent-selector", "value"),
            Input("module-selector", "value"),
            Input("step-state", "data"),
        ],
    )
    def update_dashboard(selected_agent, selected_module, step_state):
        if not selected_agent or not selected_module:
            return html.Div("Please select both an agent and a module")
        # Errors are allowed to propagate so data problems are visible immediately
        data = results[selected_agent][selected_module]
        if not isinstance(data, pd.DataFrame):
            return html.Div(f"Selected module does not contain valid MPC data")
        # Reduce triple index to double index if needed
        if isinstance(data.index, pd.MultiIndex) and len(data.index.levels) > 2:
            data = reduce_triple_index(data)
        # Check if data needs time normalization
        if is_multi_index and not use_datetime:
            # Normalize relative times so the first prediction starts at zero
            first_time = data.index.levels[0][0]
            data.index = data.index.set_levels(
                data.index.levels[0] - first_time, level=0
            )
        # Get stats data if available
        stats = None
        if f"{selected_module}_stats" in results[selected_agent]:
            stats = results[selected_agent][f"{selected_module}_stats"]
        # Create the dashboard components
        return make_components(
            data=data,
            convert_to=scale,
            stats=stats,
            use_datetime=use_datetime,
            step=step_state,
        )
    # Launch the dashboard
    port = get_port()
    webbrowser.open_new_tab(f"http://localhost:{port}")
    app.run(debug=False, port=port) 

def launch_dashboard_from_results(
    results: Dict[str, Dict[str, Any]], scale: str = "hours", step: bool = False
) -> bool:
    """
    Launch the multi-agent dashboard from results dictionary returned by mas.get_results().
    Args:
        results: Dictionary with agent results from mas.get_results()
        scale: Time scale for plotting ("seconds", "minutes", "hours", "days")
        step: Whether to use step plots
    Returns:
        bool: True if dashboard was launched, False otherwise
    """
    if not results or not isinstance(results, dict):
        raise ValueError("Invalid results: Expected non-empty dictionary")
    # Validate results structure
    valid_data_found = False
    for agent_id, agent_data in results.items():
        if not isinstance(agent_data, dict):
            continue
        for module_id, module_data in agent_data.items():
            if not isinstance(module_data, pd.DataFrame):
                continue
            # Check if this DataFrame has the expected structure for MPC data
            if isinstance(module_data.index, pd.MultiIndex):
                if len(module_data.index.levels) > 1:
                    # This looks like MPC data with multi-level index
                    valid_data_found = True
                    break
            else:
                # Single level index might still be valid for some data
                valid_data_found = module_data.shape[0] > 0
                break
        if valid_data_found:
            break
    if not valid_data_found:
        raise ValueError("No valid MPC data found in results")
    # Launch the dashboard without catching exceptions
    print(f"Launching dashboard with scale={scale}")
    show_multi_room_dashboard(results, scale=scale, step=step)
    return True 
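
# Usage sketch, assuming a finished LocalMASAgency run (the variable name `mas`
# and the agent/module layout in the comment are illustrative):
#
#     results = mas.get_results()  # {"agent_1": {"mpc": df, "mpc_stats": stats_df}, ...}
#     launch_dashboard_from_results(results, scale="hours", step=True)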

def process_mas_results(
    results: Dict[str, Dict[str, Any]],
) -> Dict[str, Dict[str, Any]]:
    """
    Process results from LocalMASAgency to prepare them for visualization.
    Args:
        results: Raw results from mas.get_results()
    Returns:
        Processed results ready for dashboard visualization
    """
    processed_results = {}
    for agent_id, agent_data in results.items():
        processed_results[agent_id] = {}
        # Find all DataFrame modules that could be MPC data
        for module_id, module_data in agent_data.items():
            if not isinstance(module_data, pd.DataFrame):
                continue
            # Check if this looks like MPC data; errors propagate directly
            if isinstance(module_data.index, pd.MultiIndex):
                if isinstance(module_data.columns, pd.MultiIndex):
                    # This is likely MPC data with variables, parameters, etc.
                    processed_results[agent_id][module_id] = module_data
                elif any(
                    col.startswith(("variable_", "parameter_"))
                    for col in module_data.columns
                ):
                    # This might be MPC data with flattened column names
                    processed_results[agent_id][module_id] = module_data
                # Check for stats data with matching prefix
                stats_module_id = f"{module_id}_stats"
                if stats_module_id in agent_data and isinstance(
                    agent_data[stats_module_id], pd.DataFrame
                ):
                    processed_results[agent_id][stats_module_id] = agent_data[
                        stats_module_id
                    ]
    return processed_results 
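
# Usage sketch: pre-filter raw results down to MPC-like frames before launching
# the dashboard (the `mas` variable and names are illustrative):
#
#     raw = mas.get_results()
#     mpc_results = process_mas_results(raw)
#     launch_dashboard_from_results(mpc_results, scale="hours")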
if __name__ == "__main__":
    # Example usage
    import sys
    if len(sys.argv) > 1:
        # If a path is provided as an argument, try to load from files
        path = Path(sys.argv[1])
        if path.exists() and path.is_dir():
            print(f"Loading data from directory: {path}")
            # show_multi_room_dashboard_from_files is referenced here but not
            # defined in this module, so the call stays commented out:
            # show_multi_room_dashboard_from_files(path, scale="hours")
        else:
            raise FileNotFoundError(f"Directory not found: {path}")
    else:
        print("No directory specified. Please provide a directory path.")