Coverage for addmo/s5_insights/model_plots/carpet_plots.py: 40%

223 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-08-31 13:05 +0000

1import numpy as np 

2import math 

3import pandas as pd 

4import matplotlib.pyplot as plt 

5from addmo.util import plotting_utils as d 

6from pathlib import Path 

7 

8from addmo.util.definitions import return_results_dir_model_tuning, return_best_model 

9from addmo.s3_model_tuning.models.model_factory import ModelFactory 

10from addmo.util.load_save import load_data 

11import matplotlib.pyplot as plt 

12from itertools import product 

13from mpl_toolkits.mplot3d import Axes3D 

14from matplotlib import cm 

15from matplotlib.colors import ListedColormap 

16from addmo.util.plotting_utils import * 

17import matplotlib.colors as colors 

18 

19from matplotlib import cm 

20from matplotlib.colors import LinearSegmentedColormap 

21 

def truncate_colormap(cmap_name, min_val=0.1, max_val=0.9, n=256):
    """
    Truncate a colormap to exclude the extreme ends (e.g., near-white tips).

    Args:
        cmap_name: Name of a registered Matplotlib colormap.
        min_val: Lower end of the retained range, as a fraction of the colormap.
        max_val: Upper end of the retained range, as a fraction of the colormap.
        n: Number of colors sampled between ``min_val`` and ``max_val``.

    Returns:
        A ``LinearSegmentedColormap`` named ``"<cmap_name>_trunc"`` built from
        the sampled colors.
    """
    # ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in
    # 3.9; ``pyplot.get_cmap`` performs the same (name, lut) lookup and is still
    # available, so this works on both old and new Matplotlib releases.
    cmap = plt.get_cmap(cmap_name, n)
    new_colors = cmap(np.linspace(min_val, max_val, n))
    return LinearSegmentedColormap.from_list(f"{cmap_name}_trunc", new_colors)

29 

30 

def plot_carpets(variables, measurements_data, regressor_func, system_func=None, bounds=None, combinations=None, defaults_dict=None):
    """
    Create 3D surface plots ("carpet plots") of a prediction function.

    One subplot is drawn per pair of variables. The two variables of the pair
    are varied over a 150-point grid each, while every other variable is held
    at its default value.

    Args:
        variables: Sequence of variable names; each is passed as a keyword
            argument to ``regressor_func`` / ``system_func``.
        measurements_data: DataFrame used to derive ``bounds`` (column min/max)
            and ``defaults_dict`` (column mean) when these are not given.
        regressor_func: Callable taking one keyword argument per variable
            (arrays shaped like the mesh grid) and returning the predicted
            surface.
        system_func: Optional callable with the same call signature as
            ``regressor_func``; when given, its surface is drawn together with
            the regressor surface.
        bounds: Optional mapping ``variable -> [lower, upper]``.
        combinations: Optional list of ``(x_variable, y_variable)`` pairs.
            Defaults to all unordered pairs of ``variables``.
        defaults_dict: Optional mapping ``variable -> value`` used for the
            variables that are not on the axes of a subplot.

    Returns:
        The created matplotlib figure, or ``None`` when no valid combination
        remains after filtering.
    """
    # Define bounds from the measured min/max of each available column
    if bounds is None:
        bounds = {
            var: [measurements_data[var].min(), measurements_data[var].max()]
            for var in variables
            if var in measurements_data.columns
        }

    # Use the mean value as default for the variables that are held constant
    if defaults_dict is None:
        defaults_dict = {var: measurements_data[var].mean() for var in variables}

    # Create all unordered pairs of variables
    if combinations is None:
        combinations = [
            (v1, v2) for i, v1 in enumerate(variables) for v2 in variables[i + 1:]
        ]

    # Create a grid for each variable
    grids = {var: np.linspace(bounds[var][0], bounds[var][1], 150) for var in variables}

    def is_all_zero(var):
        # A variable whose lower and upper bound are both 0 carries no information
        return bounds[var][0] == 0 and bounds[var][1] == 0

    # Keep only the combinations in which neither feature is all-zero
    valid_combinations = [
        (x_label, y_label) for x_label, y_label in combinations
        if not is_all_zero(x_label) and not is_all_zero(y_label)
    ]
    removed_items = [var for var in variables if is_all_zero(var)]
    if removed_items:
        # Only report when something was actually removed
        print('The following combinations are removed because the column only consists of zero values: {}'.format(removed_items))

    # Handle case where all combinations are invalid
    if not valid_combinations:
        print("No valid subplots to display. Skipping plot creation.")
        return None

    num_plots = len(valid_combinations)
    num_cols = 2
    num_rows = math.ceil(num_plots / num_cols)

    fig_height = max(5, num_plots * 3.5)
    fig_size = (d.cm2inch(16), d.cm2inch(fig_height))
    fig = plt.figure(figsize=fig_size)
    plt.subplots_adjust(left=-0.05, right=0.88, bottom=0.02, top=1, wspace=-0.1, hspace=0.05)

    surf1_cmap = "winter"
    surf2_cmap = "autumn"
    surf2 = None

    for i, (x_label, y_label) in enumerate(valid_combinations, 1):
        ax = fig.add_subplot(num_rows, num_cols, i, projection="3d")
        X, Y = np.meshgrid(grids[x_label], grids[y_label])

        # Create input arrays for the prediction functions: the two axis
        # variables take the mesh, all others are constant at their default
        inputs = {}
        for var in variables:
            if var == x_label:
                inputs[var] = X
            elif var == y_label:
                inputs[var] = Y
            else:
                inputs[var] = np.full_like(X, defaults_dict[var])

        Z1 = regressor_func(**inputs)
        if system_func is None:
            surf1 = ax.plot_surface(X, Y, Z1, cmap=surf1_cmap, alpha=0.5)
        else:
            Z2 = system_func(**inputs)

            # Separate normalization per surface so each uses its full colormap
            norm1 = colors.Normalize(vmin=np.nanmin(Z1), vmax=np.nanmax(Z1))
            norm2 = colors.Normalize(vmin=np.nanmin(Z2), vmax=np.nanmax(Z2))

            # Split each surface into the part lying above and the part lying
            # below the other surface (the rest is masked with NaN)
            Z1_greater = np.where(Z1 >= Z2, Z1, np.nan)
            Z1_smaller = np.where(Z1 <= Z2, Z1, np.nan)
            Z2_greater = np.where(Z2 >= Z1, Z2, np.nan)
            Z2_smaller = np.where(Z2 <= Z1, Z2, np.nan)

            # Invisible full surfaces serve as mappables for the colorbars
            surf1 = ax.plot_surface(X, Y, Z1, cmap=surf1_cmap, visible=False, norm=norm1)
            surf2 = ax.plot_surface(X, Y, Z2, cmap=surf2_cmap, visible=False, norm=norm2)
            # Draw the lower parts first and the upper parts last
            ax.plot_surface(X, Y, Z2_smaller, cmap=surf2_cmap, alpha=0.5, norm=norm2)
            ax.plot_surface(X, Y, Z1_smaller, cmap=surf1_cmap, alpha=0.5, norm=norm1)
            ax.plot_surface(X, Y, Z2_greater, cmap=surf2_cmap, alpha=0.5, norm=norm2)
            ax.plot_surface(X, Y, Z1_greater, cmap=surf1_cmap, alpha=0.5, norm=norm1)

        # Flatten the box along z and tighten labels/ticks for the small subplots
        ax.set_box_aspect([1, 1, 0.6])
        ax.margins(x=0, y=0)
        ax.set_xlabel(x_label.replace('__', '\n'), fontsize=7, labelpad=-6)
        ax.set_ylabel(y_label.replace('__', '\n'), fontsize=7, labelpad=-6)
        ax.set_zlabel("Prediction", fontsize=7, labelpad=-7)
        ax.tick_params(axis="x", which="major", pad=-5)
        ax.view_init(elev=30, azim=120)
        ax.tick_params(axis="y", pad=-3)
        ax.tick_params(axis="z", pad=-3)
        plt.setp(ax.get_yticklabels(), fontsize=7)
        plt.setp(ax.get_xticklabels(), fontsize=7)
        plt.setp(ax.get_zticklabels(), fontsize=7)

    # Add colorbars and label them. They sit at fixed figure coordinates and
    # only act as a color legend (no ticks), so they are added once per figure
    # using the surfaces of the last subplot.
    colorbar_specs = [(surf1, [0.9, 0.35, 0.02, 0.3], "Regressor")]
    if system_func is not None:
        colorbar_specs.append((surf2, [0.9, 0.05, 0.02, 0.3], "System"))
    for mappable, rect, label in colorbar_specs:
        cbar = fig.colorbar(mappable, cax=fig.add_axes(rect))
        cbar.set_label(label)
        cbar.set_ticks([])
        cbar.set_ticklabels([])

    return fig

167 

168 

def prediction_func_4_regressor(regressor, rename_dict: dict = None):
    """
    Create a prediction function for a regressor as the regressor takes a DataFrame as input.

    Args:
        regressor: Object exposing ``metadata["features_ordered"]`` (the
            feature names in the order expected by the model) and a
            ``predict(DataFrame)`` method.
        rename_dict: Optional mapping from the regressor's feature names to
            the keyword names used when calling the returned function.

    Returns:
        A function ``pred_func(**kwargs)`` that takes one keyword argument per
        feature, flattens the inputs into a DataFrame, calls
        ``regressor.predict`` and returns the prediction as a NumPy array
        reshaped to the shape of the inputs.
    """

    def pred_func(**kwargs):
        features = regressor.metadata["features_ordered"]
        if rename_dict is not None:
            features = [rename_dict[feature] for feature in features]

        # Fail with a descriptive message instead of a bare KeyError on lookup
        missing = [feature for feature in features if feature not in kwargs]
        if missing:
            raise KeyError(f"Missing input(s) for feature(s): {missing}")

        # Determine the shape of the output from the first ndarray input
        shape = next(
            (arr.shape for arr in kwargs.values() if isinstance(arr, np.ndarray)),
            None,
        )
        if shape is None:
            # No ndarray was passed (plain scalars or sequences): fall back to
            # the shape of the first feature value instead of leaking
            # StopIteration out of ``next``
            shape = np.shape(kwargs[features[0]]) if features else ()

        # Prepare the input data: one flattened column per feature, in model order
        input_data = pd.DataFrame(
            {feature: np.ravel(kwargs[feature]) for feature in features}
        )

        # Make prediction
        prediction = regressor.predict(input_data)

        # Reshape the prediction to match the input shape; ``np.asarray`` also
        # accepts regressors that return a pandas object or a list
        return np.asarray(prediction).reshape(shape)

    return pred_func

196 

def plot_carpets_with_buckets(variables, measurements_data, target_values, regressor_func, bounds=None, combinations=None, defaults_dict=None, num_buckets=4):
    """
    Create 3D carpet plots of a prediction function together with measured points.

    One subplot is drawn per pair of variables. The regressor surface is
    evaluated on a 150x150 grid of the two axis variables while all other
    variables are held at their default value. Measured points are shown only
    if, for every non-axis variable, the measurement lies inside a bucket
    centred on that variable's default value. Points are colored by their
    residual (measured target minus prediction), with separate colormaps for
    negative and non-negative residuals.

    Args:
        variables: Sequence of variable names; each is passed as a keyword
            argument to ``regressor_func``.
        measurements_data: DataFrame holding the measured values of the
            variables.
        target_values: pandas object with the measured target, indexed like
            ``measurements_data``.
        regressor_func: Callable taking one keyword argument per variable and
            returning predictions shaped like its inputs.
        bounds: Optional mapping ``variable -> [lower, upper]``. Defaults to
            the min/max of each column.
        combinations: Optional list of ``(x_variable, y_variable)`` pairs.
            Defaults to all unordered pairs of ``variables``.
        defaults_dict: Optional mapping ``variable -> value``. Defaults to the
            mode for columns that only contain 0/1 and to the mean otherwise.
        num_buckets: The bucket width of a variable is its measured range
            divided by ``num_buckets``.

    Returns:
        The created matplotlib figure, or ``None`` when no valid combination
        remains after filtering.
    """
    # Define bounds from the measured min/max of each available column
    if bounds is None:
        bounds = {
            var: [measurements_data[var].min(), measurements_data[var].max()]
            for var in variables
            if var in measurements_data.columns
        }

    # Define default values
    if defaults_dict is None:
        defaults_dict = {}
        for var in variables:
            column = measurements_data[var]
            unique_vals = column.dropna().unique()
            is_binary = len(unique_vals) <= 3 and all(val in [0, 1] for val in unique_vals)
            # A mean of a 0/1 column is not a value the feature can take, so
            # binary features use the mode instead
            defaults_dict[var] = column.mode().iloc[0] if is_binary else column.mean()

    # Create one bucket per variable, centred on its default value
    bucket = {}
    for var in variables:
        bucket_size = (measurements_data[var].max() - measurements_data[var].min()) / num_buckets
        bucket[var] = (defaults_dict[var] - (bucket_size / 2), defaults_dict[var] + (bucket_size / 2))

    # Create all unordered pairs of variables
    if combinations is None:
        combinations = [
            (v1, v2) for i, v1 in enumerate(variables) for v2 in variables[i + 1:]
        ]

    # Create a grid for each variable
    grids = {var: np.linspace(bounds[var][0], bounds[var][1], 150) for var in variables}

    def is_all_zero(var):
        # A variable whose lower and upper bound are both 0 carries no information
        return bounds[var][0] == 0 and bounds[var][1] == 0

    # Keep only the combinations in which neither feature is all-zero
    valid_combinations = [
        (x_label, y_label) for x_label, y_label in combinations
        if not is_all_zero(x_label) and not is_all_zero(y_label)
    ]
    removed_items = [var for var in variables if is_all_zero(var)]
    if removed_items:
        # Only report when something was actually removed
        print('The following combinations are removed because the column only consists of zero values: {}'.format(removed_items))

    # Handle case where all combinations are invalid
    if not valid_combinations:
        print("No valid subplots to display. Skipping plot creation.")
        return None

    num_plots = len(valid_combinations)
    num_cols = 2
    num_rows = math.ceil(num_plots / num_cols)

    fig_height = max(5, num_plots * 3.5)
    fig_size = (d.cm2inch(16), d.cm2inch(fig_height))
    fig = plt.figure(figsize=fig_size)
    plt.subplots_adjust(left=-0.05, right=0.88, bottom=0.02, top=1, wspace=-0.1, hspace=0.05)

    surf1_cmap = "winter"
    cmap_below = truncate_colormap("YlGn_r", min_val=0, max_val=0.9)
    cmap_above = truncate_colormap("YlOrBr", min_val=0.15, max_val=1)
    # TODO: use this colormap in case of no truncation and set the background to darker grey
    # cmap_below="YlGn_r"
    # cmap_above="YlOrBr"
    # for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    #     axis.pane.set_facecolor(light_grey)

    for i, (x_label, y_label) in enumerate(valid_combinations, 1):
        # computed_zorder=False: artists are drawn in the order they are added
        ax = fig.add_subplot(num_rows, num_cols, i, projection="3d", computed_zorder=False)
        X, Y = np.meshgrid(grids[x_label], grids[y_label])

        # Prepare inputs for the surface prediction: the two axis variables
        # take the mesh, all others are constant at their default
        inputs_surface = {}
        for var in variables:
            if var == x_label:
                inputs_surface[var] = X
            elif var == y_label:
                inputs_surface[var] = Y
            else:
                inputs_surface[var] = np.full_like(X, defaults_dict[var])

        Z1 = regressor_func(**inputs_surface)
        norm1 = colors.Normalize(vmin=np.nanmin(Z1), vmax=np.nanmax(Z1))

        # Keep only the measured rows whose remaining features all fall inside
        # the bucket around their default value
        other_features = [f for f in variables if f not in (x_label, y_label)]
        mask = pd.Series(True, index=measurements_data.index)
        for f in other_features:
            lower, upper = bucket[f]
            mask &= measurements_data[f].between(lower, upper)

        # Get filtered real data
        real_x = measurements_data.loc[mask, x_label].to_numpy()
        real_y = measurements_data.loc[mask, y_label].to_numpy()
        real_target = target_values.loc[mask].values.flatten()

        # Predict at the measured (x, y) positions with the other features at
        # their defaults
        inputs_meas = {}
        for var in variables:
            if var == x_label:
                inputs_meas[var] = real_x
            elif var == y_label:
                inputs_meas[var] = real_y
            else:
                # np.full derives the dtype from the default itself; full_like
                # would inherit the dtype of the x column and truncate a float
                # default whenever that column is integer-typed
                inputs_meas[var] = np.full(real_x.shape, defaults_dict[var])
        if real_x.size:
            pred_at_pts = regressor_func(**inputs_meas)
        else:
            # No measurement falls inside the buckets: skip the prediction
            # call and continue with an empty residual
            pred_at_pts = np.empty(0)

        residual = real_target - pred_at_pts
        below_mask = residual < 0
        above_mask = residual >= 0

        norm_below = colors.Normalize(vmin=residual[below_mask].min() if below_mask.any() else -0.01, vmax=0)
        norm_above = colors.Normalize(vmin=0, vmax=residual[above_mask].max() if above_mask.any() else 0.01)

        # Drawing order matters here: points below the surface, then the
        # surface, then points above it
        scatter1 = ax.scatter(
            real_x[below_mask], real_y[below_mask], real_target[below_mask],
            c=residual[below_mask],
            cmap=cmap_below,
            norm=norm_below,
            alpha=1, s=2, depthshade=False
        )

        surf1 = ax.plot_surface(X, Y, Z1, cmap=surf1_cmap, alpha=0.4, norm=norm1)

        scatter2 = ax.scatter(
            real_x[above_mask], real_y[above_mask], real_target[above_mask],
            c=residual[above_mask],
            cmap=cmap_above,
            norm=norm_above,
            alpha=1, s=2, depthshade=False
        )

        # Flatten the box along z and tighten labels/ticks for the small subplots
        ax.set_box_aspect([1, 1, 0.6])
        ax.margins(x=0, y=0)
        ax.set_xlabel(x_label.replace('__', '\n'), fontsize=7, labelpad=-6)
        ax.set_ylabel(y_label.replace('__', '\n'), fontsize=7, labelpad=-6)
        ax.set_zlabel("Prediction", fontsize=7, labelpad=-7)
        ax.tick_params(axis="x", which="major", pad=-5)
        ax.view_init(elev=30, azim=120)
        ax.tick_params(axis="y", pad=-3)
        ax.tick_params(axis="z", pad=-3)
        plt.setp(ax.get_yticklabels(), fontsize=7)
        plt.setp(ax.get_xticklabels(), fontsize=7)
        plt.setp(ax.get_zticklabels(), fontsize=7)

    # Colorbars sit at fixed figure coordinates and only act as a color legend
    # (no ticks), so they are added once per figure using the artists of the
    # last subplot.
    colorbar_specs = [
        (surf1, [0.92, 0.525, 0.02, 0.4], "Regressor"),
        (scatter1, [0.92, 0.05, 0.02, 0.2], "Negative Prediction Error"),
        (scatter2, [0.92, 0.25, 0.02, 0.2], "Positive Prediction Error"),
    ]
    for mappable, rect, label in colorbar_specs:
        cbar = fig.colorbar(mappable, cax=fig.add_axes(rect))
        cbar.set_label(label, fontsize=7)
        cbar.set_ticks([])
        cbar.set_ticklabels([])
    fig.text(0.972, 0.25, "Measurement Data", rotation=90, va="center", ha="left", fontsize=7)

    return fig