def _unwrap_prediction(value):
    """Return ``value[0]`` for list/ndarray predictions, otherwise ``value`` as-is.

    A prediction arrives as a list or array when it is, for example, a row of
    a 2-D prediction array; only its first entry is plotted, to match the
    ``[:, 0]`` column taken from the true values.
    """
    return value[0] if isinstance(value, (list, np.ndarray)) else value


def plot_predictions(td: TrainingDataGeneric) -> go.Figure:
    """Plot true values and per-split predictions, ordered by true value.

    Train, validation (if present) and test samples are pooled and sorted by
    their true value (stable sort, so ties keep Train -> Val -> Test order).
    The x axis is each sample's rank in that ordering. True values are drawn
    as black markers; predictions are coloured by split
    (Train=green, Val=blue, Test=red).

    Args:
        td: Training data container. Reads ``y_train_single``,
            ``y_val_single`` and ``y_test_single`` (only column 0 is plotted)
            together with the matching ``*_pred_single`` predictions. The
            validation split is skipped when ``y_val_single`` is None.

    Returns:
        A plotly figure holding a "True values" trace followed by one
        prediction trace per split present in the data. If there are no
        samples at all, an empty figure with the same axis titles is returned.
    """
    splits = [('Train', td.y_train_single[:, 0], td.y_train_pred_single)]
    if td.y_val_single is not None:
        splits.append(('Val', td.y_val_single[:, 0], td.y_val_pred_single))
    splits.append(('Test', td.y_test_single[:, 0], td.y_test_pred_single))

    # One (true value, prediction, split label) row per sample.
    rows = []
    for label, y_true, y_pred in splits:
        rows.extend(
            (true_value, _unwrap_prediction(pred_value), label)
            for true_value, pred_value in zip(y_true, y_pred)
        )
    # list.sort is stable: samples with equal true values keep split order.
    rows.sort(key=lambda row: row[0])

    traces = []
    # Guard: with no samples there is nothing to draw (and unpacking an empty
    # row list would otherwise raise).
    if rows:
        df_plotly = pd.DataFrame(rows, columns=['y_sorted', 'y_pred_sorted', 'labels'])
        positions = list(range(len(df_plotly)))

        true_trace = go.Scatter(
            x=positions,
            y=df_plotly['y_sorted'],
            mode='markers',
            name='True values',
            marker=dict(color='black'),
        )
        color_map = {
            'Train': 'green',
            'Val': 'blue',
            'Test': 'red',
        }
        # plotly express is used only to split predictions into one coloured
        # trace per split; its layout is intentionally discarded.
        prediction_fig = px.scatter(
            df_plotly,
            x=positions,
            y='y_pred_sorted',
            color='labels',
            color_discrete_map=color_map,
        )
        traces = [true_trace, *prediction_fig.data]

    fig = go.Figure(data=traces)
    fig.update_layout(
        xaxis_title="Sample",
        yaxis_title="Value",
    )
    return fig