Coverage for addmo/s5_insights/model_plots/scatter_plot.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-08-31 13:05 +0000

1import matplotlib.pyplot as plt 

2from addmo.util import plotting_utils as d 

3 

4 

def scatter(train_data, predictions, target_name, rmse):
    """Draw a predicted-vs-actual scatter plot on a new pyplot figure.

    Each ``(train_data[i], predictions[i])`` pair becomes one point. A dashed
    diagonal line runs from the smallest to the largest value found in either
    input, marking where the y-value equals the x-value. The RMSE is shown in a
    framed text box near the lower-right corner of the axes.

    Args:
        train_data: x-values of the points. Must provide ``.min()`` and
            ``.max()`` (presumably a NumPy array or pandas Series -- TODO
            confirm against callers).
        predictions: y-values of the points; same requirements as
            ``train_data``.
        target_name: Text used as the plot title.
        rmse: Numeric value rendered in the annotation with two decimals.

    Returns:
        The ``matplotlib.pyplot`` module, with the newly created figure left
        as the current figure (presumably so the caller can save or show it).
    """
    # Square 15.5 x 15.5 figure; d.cm2inch presumably converts cm to inches.
    figure_size = (d.cm2inch(15.5), d.cm2inch(15.5))
    plt.figure(figsize=figure_size)
    plt.subplots_adjust(left=0.12, right=0.97, bottom=0.08, top=0.95)

    # Kept as the pyplot wrapper (not Axes.scatter) because it also registers
    # the collection as the current mappable via plt.sci.
    plt.scatter(train_data, predictions, color=d.blue, label='Predictions')
    axes = plt.gca()

    # Dashed y = x reference line covering the joint range of both inputs.
    lower = min(train_data.min(), predictions.min())
    upper = max(train_data.max(), predictions.max())
    axes.plot([lower, upper], [lower, upper], color=d.red, linestyle='--', label="Target Value")

    # RMSE badge positioned in axes-fraction coordinates (lower right).
    badge_style = dict(facecolor='white', alpha=0.7, edgecolor=d.black)
    axes.text(
        0.85,
        0.1,
        f'RMSE: {rmse:.2f}',
        color=d.black,
        ha='center',
        va='center',
        transform=axes.transAxes,
        bbox=badge_style,
    )

    axes.set_xlabel("Training Data")
    axes.set_ylabel("Predicted Values")
    axes.set_title(target_name)
    axes.legend()
    axes.grid(True)
    axes.axis('tight')
    return plt