Coverage for addmo/s5_insights/model_plots/parallel_plots.py: 13%
63 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-31 13:05 +0000
1import numpy as np
2import pandas as pd
3import plotly.express as px
4import matplotlib.pyplot as plt
5from addmo.util import plotting_utils as d
6from addmo.util.load_save import load_data
def parallel_plots(target, data, regressor):
    """Draw a static parallel-coordinates plot of inputs, target and prediction.

    The regressor is evaluated on ``data`` without the ``target`` column. The
    measured target and the prediction (column ``'y_pred'``) are appended.
    Every column with at least two distinct values becomes one vertical axis
    with its own y-scale. Each row is drawn as one polyline across the axes.

    Args:
        target: Name of the target column in ``data``.
        data: DataFrame holding the regressor inputs and the ``target`` column.
        regressor: Object exposing ``predict(features)``; it is called with
            ``data`` minus the target column.

    Returns:
        The matplotlib figure containing the plot.
    """
    xy_grid = data.drop(target, axis=1)
    y_pred = pd.Series(regressor.predict(xy_grid), index=xy_grid.index)
    xy_grid[target] = data[target]
    xy_grid['y_pred'] = y_pred

    # Keep only columns with at least two distinct non-NaN values. A constant
    # or all-NaN column has no usable range for its own axis.
    cols = [col for col in xy_grid.columns if xy_grid[col].nunique() > 1]
    xy_grid = xy_grid[cols]

    # Work in float. Otherwise an all-integer frame gives an int array, and
    # the in-place padding below fails while the rescaled values get truncated.
    ys_grid = xy_grid.to_numpy(dtype=float)
    n_axes = ys_grid.shape[1]

    # Per-axis limits with 5 % padding on each side. The nan-aware reductions
    # match the pandas-based filter above, which also ignores NaN.
    ymins_grid = np.nanmin(ys_grid, axis=0)
    ymax_grid = np.nanmax(ys_grid, axis=0)
    padding = (ymax_grid - ymins_grid) * 0.05
    ymins_grid -= padding
    ymax_grid += padding
    dys_grid = ymax_grid - ymins_grid

    # Lines are drawn on the host axis only. Therefore map every column
    # linearly from its own (padded) limits onto the host's limits.
    # Column 0 already lives in host units.
    zs_grid = np.empty_like(ys_grid)
    zs_grid[:, 0] = ys_grid[:, 0]
    zs_grid[:, 1:] = (
        (ys_grid[:, 1:] - ymins_grid[1:]) / dys_grid[1:] * dys_grid[0]
        + ymins_grid[0]
    )

    num_vars = len(cols) + 1
    figure_width = max(5, num_vars * 2.5)
    fig_size = (d.cm2inch(figure_width), d.cm2inch(8))
    fig, host = plt.subplots(figsize=fig_size)
    fig.subplots_adjust(left=0.05, right=0.92, bottom=0.08, top=0.8)

    # One twin y-axis per extra column; its spine is placed at fraction
    # i / (n_axes - 1) of the x-range so the axes are evenly spaced.
    axes = [host] + [host.twinx() for _ in range(n_axes - 1)]
    for i, ax in enumerate(axes):
        ax.set_ylim(ymins_grid[i], ymax_grid[i])
        ax.spines['top'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        if ax is not host:
            ax.spines['left'].set_visible(False)
            ax.yaxis.set_ticks_position('right')
            ax.spines['right'].set_position(("axes", i / (n_axes - 1)))

    host.set_xlim(0, n_axes - 1)
    host.set_xticks(range(n_axes))
    host.set_xticklabels(
        [col.replace(' ', '\n').replace('__', '\n') for col in cols]
    )
    host.tick_params(axis='x', which='major', pad=7, labelsize=9)
    host.spines['right'].set_visible(False)
    host.xaxis.tick_top()

    # A 2-D y argument draws one line per column, so zs_grid.T plots every
    # row of the data as its own polyline in a single call.
    host.plot(np.arange(n_axes), zs_grid.T, color=d.red, linewidth=0.5, alpha=0.7)

    return fig
def parallel_plots_interactive(target, data, regressor):
    """Build an interactive plotly parallel-coordinates plot coloured by the prediction.

    The regressor is evaluated on ``data`` without the ``target`` column. The
    measured target and the prediction (column ``'y_pred'``) are appended.
    Every column with more than one distinct value is drawn as an axis, and
    the lines are coloured by ``'y_pred'``.

    Args:
        target: Name of the target column in ``data``.
        data: DataFrame holding the regressor inputs and the ``target`` column.
        regressor: Object exposing ``predict(features)``; it is called with
            ``data`` minus the target column.

    Returns:
        The figure returned by ``px.parallel_coordinates``.
    """
    xy_grid = data.drop(target, axis=1)
    y_pred = pd.Series(regressor.predict(xy_grid), index=xy_grid.index)
    xy_grid[target] = data[target]
    xy_grid['y_pred'] = y_pred

    # Only columns with a range become axes; constant columns carry no information.
    variable_cols = [col for col in xy_grid.columns if xy_grid[col].nunique() > 1]

    # 'y_pred' must stay in the frame because it drives the colour scale. If it
    # is constant, dropping it would make plotly reject color='y_pred'. In that
    # case it is used for colouring only and is not drawn as an axis.
    if 'y_pred' in variable_cols:
        plot_cols = variable_cols
    else:
        plot_cols = variable_cols + ['y_pred']
    plot_data = xy_grid[plot_cols]

    fig = px.parallel_coordinates(
        plot_data,
        dimensions=variable_cols,
        color='y_pred',
        # Plotly renders '\n' in SVG text as a space; '<br>' is its line break.
        labels={col: col.replace('__', '<br>') for col in plot_data.columns},
        color_continuous_scale=px.colors.sequential.Viridis,
        title="Interactive Parallel Coordinates Plot",
    )

    return fig