Coverage for addmo/s5_insights/model_plots/parallel_plots.py: 13%

63 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-08-31 13:05 +0000

1import numpy as np 

2import pandas as pd 

3import plotly.express as px 

4import matplotlib.pyplot as plt 

5from addmo.util import plotting_utils as d 

6from addmo.util.load_save import load_data 

7 

8 

def parallel_plots(target, data, regressor):
    """Draw a static parallel-coordinates plot of inputs, target and prediction.

    The regressor is evaluated on ``data`` without the target column; the
    target and the prediction (column ``'y_pred'``) are appended afterwards.
    Columns whose minimum equals their maximum are left out. Every remaining
    column gets its own vertical axis with 5 % padding, and each row of the
    data is drawn as one line across all axes.

    Args:
        target: Name of the target column in ``data``.
        data: pandas DataFrame containing the input columns and ``target``.
        regressor: Object with a ``predict`` method that accepts the frame
            without the target column.

    Returns:
        The matplotlib figure holding the plot.
    """
    xy_grid = data.drop(target, axis=1)
    y_pred = pd.Series(regressor.predict(xy_grid), index=xy_grid.index)
    xy_grid[target] = data[target]
    xy_grid['y_pred'] = y_pred

    # Keep only columns that actually span a range; a constant column would
    # give a zero-height axis and a division by zero in the rescaling below.
    cols = [
        var for var in xy_grid.columns
        if xy_grid[var].min() != xy_grid[var].max()
    ]
    xy_grid = xy_grid[cols]

    # Force a float array: with integer-only columns the in-place padding
    # below would fail to cast and the rescaled values would be truncated.
    ys_grid = xy_grid.to_numpy(dtype=float)
    n_axes = ys_grid.shape[1]

    ymins_grid = ys_grid.min(axis=0)
    ymax_grid = ys_grid.max(axis=0)
    dys_grid = ymax_grid - ymins_grid
    ymins_grid -= dys_grid * 0.05  # Add padding
    ymax_grid += dys_grid * 0.05

    # Rescale every column onto the data range of the first (host) axis so
    # that all lines can be drawn on the host. This combines the padded
    # minima with the unpadded ranges; the 5 % offsets cancel, so the result
    # matches a rescaling done with padded minima and padded ranges.
    zs_grid = np.zeros_like(ys_grid)
    zs_grid[:, 0] = ys_grid[:, 0]
    zs_grid[:, 1:] = (
        (ys_grid[:, 1:] - ymins_grid[1:]) / dys_grid[1:] * dys_grid[0]
        + ymins_grid[0]
    )

    # Figure width grows with the number of plotted columns.
    # d.cm2inch presumably converts centimetres to inches -- confirm in
    # addmo.util.plotting_utils.
    num_vars = len(xy_grid.columns) + 1
    figure_width = max(5, num_vars * 2.5)
    fig_size = (d.cm2inch(figure_width), d.cm2inch(8))  # Adjusted figure size
    fig, host = plt.subplots(figsize=fig_size)
    fig.subplots_adjust(left=0.05, right=0.92, bottom=0.08, top=0.8)

    # One twin axis per additional column, each showing its own value range.
    axes = [host] + [host.twinx() for _ in range(n_axes - 1)]
    for i, ax in enumerate(axes):
        ax.set_ylim(ymins_grid[i], ymax_grid[i])
        ax.spines['top'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        if ax is not host:
            ax.spines['left'].set_visible(False)
            ax.yaxis.set_ticks_position('right')
            # Place the axis line at the x position of its column.
            ax.spines['right'].set_position(("axes", i / (n_axes - 1)))

    host.set_xlim(0, n_axes - 1)
    host.set_xticks(range(n_axes))
    host.set_xticklabels(
        [str(col).replace(' ', '\n').replace('__', '\n') for col in xy_grid.columns]
    )
    host.tick_params(axis='x', which='major', pad=7, labelsize=9)
    host.spines['right'].set_visible(False)
    host.xaxis.tick_top()

    # A 2-D y is plotted column by column, so transposing yields one line
    # per data row.
    host.plot(np.arange(n_axes), zs_grid.T, color=d.red, linewidth=0.5, alpha=0.7)

    return fig

69 

70 

71 

72 

73def parallel_plots_interactive(target, data, regressor): 

74 

75 

76 xy_grid = data.drop(target, axis=1) 

77 y_pred = pd.Series(regressor.predict(xy_grid), index=xy_grid.index) 

78 xy_grid[target] = data[target] 

79 xy_grid['y_pred'] = y_pred 

80 

81 # Drop constant columns (no range) 

82 variable_cols = [col for col in xy_grid.columns if xy_grid[col].nunique() > 1] 

83 plot_data = xy_grid[variable_cols] 

84 

85 norm_data = plot_data 

86 fig = px.parallel_coordinates( 

87 norm_data, 

88 color='y_pred', 

89 labels={col: col.replace('__', '\n') for col in norm_data.columns}, 

90 color_continuous_scale=px.colors.sequential.Viridis, 

91 title="Interactive Parallel Coordinates Plot" 

92 ) 

93 

94 return fig