Coverage for addmo/s5_insights/model_plots/scatter_plot.py: 100%
17 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-31 13:05 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-08-31 13:05 +0000
1import matplotlib.pyplot as plt
2from addmo.util import plotting_utils as d
def scatter(train_data, predictions, target_name, rmse):
    """Draw a scatter plot of ``predictions`` against ``train_data``.

    A new pyplot figure is opened, the points are plotted, and a dashed
    diagonal spanning the combined value range of both inputs is added as a
    reference line. The RMSE is written into a semi-transparent box placed
    in axes coordinates near the lower right corner.

    Args:
        train_data: Values for the x-axis. Must provide ``.min()`` and
            ``.max()`` (presumably a NumPy array or pandas Series --
            confirm against callers).
        predictions: Values for the y-axis; same requirements as
            ``train_data``.
        target_name: Text used as the plot title.
        rmse: Number shown in the annotation box, formatted to two decimals.

    Returns:
        The ``matplotlib.pyplot`` module itself (not a Figure object), so
        callers can continue with e.g. ``show()`` or ``savefig()``.
    """
    # Square canvas; cm2inch presumably converts centimetres to inches.
    width, height = d.cm2inch(15.5), d.cm2inch(15.5)
    plt.figure(figsize=(width, height))
    plt.subplots_adjust(left=0.12, right=0.97, bottom=0.08, top=0.95)

    plt.scatter(train_data, predictions, color=d.blue, label='Predictions')

    # Reference diagonal (y == x) covering the full range of both inputs.
    lower = min(train_data.min(), predictions.min())
    upper = max(train_data.max(), predictions.max())
    diagonal = [lower, upper]
    plt.plot(diagonal, diagonal, color=d.red, linestyle='--', label="Target Value")

    # RMSE annotation, positioned relative to the axes rather than the data.
    axes = plt.gca()
    box_style = dict(facecolor='white', alpha=0.7, edgecolor=d.black)
    axes.text(
        0.85,
        0.1,
        f'RMSE: {rmse:.2f}',
        color=d.black,
        ha='center',
        va='center',
        transform=axes.transAxes,
        bbox=box_style,
    )

    plt.xlabel("Training Data")
    plt.ylabel("Predicted Values")
    plt.title(target_name)
    plt.legend()
    plt.grid(True)
    plt.axis('tight')

    return plt