tabensemb.trainer.Trainer.plot_truth_pred

method

Trainer.plot_truth_pred(program: str, model_name: str, kde_color: bool = False, train_val_test: str = 'all', log_trans: bool = True, central_line: bool = True, upper_lim=9, ax=None, clr: Iterable | None = None, select_by_value_kwargs: Dict | None = None, figure_kwargs: Dict | None = None, scatter_kwargs: Dict | None = None, legend_kwargs: Dict | None = None, savefig_kwargs: Dict | None = None, save_show_close: bool = True) -> Axes

Compare the ground truth with the predictions of a single model.
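A comparison like this is usually drawn as a parity plot. Each sample becomes a point at (ground truth, prediction), and a perfect model places every point on the 45-degree line that `central_line` draws. The sketch below illustrates the idea with NumPy and matplotlib. It is not tabensemb's implementation. `parity_plot` is a hypothetical helper, and three choices are assumptions of this sketch: ground truth goes on the x-axis, the lower axis limit comes from the data, and `upper_lim` caps both axes.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def parity_plot(y_true, y_pred, ax=None, upper_lim=9, central_line=True, **scatter_kwargs):
    """Hypothetical helper: scatter predictions against ground truth.

    Assumes ground truth on the x-axis and predictions on the y-axis.
    """
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same length")
    if ax is None:
        # Create a new figure when no Axes is passed in
        _, ax = plt.subplots()
    ax.scatter(y_true, y_pred, **scatter_kwargs)
    # Lower limit taken from the data; this choice is an assumption of the sketch
    lower = min(y_true.min(), y_pred.min())
    if central_line:
        # Points on the line y = x are perfect predictions
        ax.plot([lower, upper_lim], [lower, upper_lim], "k--", linewidth=1)
    ax.set_xlim(lower, upper_lim)
    ax.set_ylim(lower, upper_lim)
    ax.set_xlabel("Ground truth")
    ax.set_ylabel("Prediction")
    return ax


rng = np.random.default_rng(0)
truth = rng.uniform(1, 8, size=200)
pred = truth + rng.normal(0, 0.3, size=200)
ax = parity_plot(truth, pred, s=10)
```

Extra keyword arguments are forwarded to `ax.scatter`, in the same spirit as `scatter_kwargs` above. Returning the Axes lets the caller keep drawing on the same plot.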

Parameters:
program

The name of the selected model base.

model_name

The name of the selected model within the model base.

kde_color

Whether to color the scatter points by their KDE density. Ignored if train_val_test is "all".

train_val_test

Which subset to plot. Choose from "Training", "Validation", "Testing", or "all".

log_trans

Whether the label data is on a log scale.

central_line

Whether to plot a 45-degree diagonal line.

upper_lim

The upper limit of the x- and y-axes.

ax

The matplotlib.axes.Axes to draw on. If not given, a new figure is created.

clr

A seaborn color palette or an iterable of colors, for example seaborn.color_palette("deep").

select_by_value_kwargs

Keyword arguments for tabensemb.data.datamodule.DataModule.select_by_value().

figure_kwargs

Keyword arguments for plt.figure().

scatter_kwargs

Keyword arguments for plt.scatter().

legend_kwargs

Keyword arguments for plt.legend().

savefig_kwargs

Keyword arguments for plt.savefig().

save_show_close

Whether to save, show (in the notebook), and close the figure if ax is not given.

Returns:
matplotlib.axes.Axes
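With `kde_color=True`, each point is colored by the estimated density of the points around it. This keeps dense regions readable when many points overlap. The sketch below shows one way to compute such colors. It assumes a Gaussian kernel with a per-axis Scott's-rule bandwidth, written directly in NumPy. SciPy's `scipy.stats.gaussian_kde` is the usual library choice and uses a full covariance matrix instead. `kde_density` and `kde_colored_scatter` are hypothetical helpers, not tabensemb's implementation. One plausible reason the option is ignored for "all" is that color already distinguishes the subsets in that mode.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts and tests
import matplotlib.pyplot as plt
import numpy as np


def kde_density(x, y):
    """Gaussian KDE evaluated at each (x, y) sample.

    Uses a diagonal bandwidth from Scott's rule for 2-D data:
    h_j = std_j * n ** (-1/6).
    """
    pts = np.vstack([np.asarray(x, float), np.asarray(y, float)])  # shape (2, n)
    n = pts.shape[1]
    h = pts.std(axis=1, ddof=1) * n ** (-1.0 / 6.0)  # shape (2,)
    # Pairwise bandwidth-scaled differences between all samples: shape (2, n, n)
    diff = (pts[:, :, None] - pts[:, None, :]) / h[:, None, None]
    sq_dist = (diff ** 2).sum(axis=0)  # shape (n, n)
    # Average of product Gaussian kernels, normalised to integrate to 1
    return np.exp(-0.5 * sq_dist).sum(axis=1) / (n * 2.0 * np.pi * h.prod())


def kde_colored_scatter(y_true, y_pred, ax=None, cmap="viridis"):
    """Hypothetical helper: scatter colored by local point density."""
    y_true = np.asarray(y_true, float)
    y_pred = np.asarray(y_pred, float)
    if ax is None:
        _, ax = plt.subplots()
    dens = kde_density(y_true, y_pred)
    order = np.argsort(dens)  # draw the densest points last so they stay visible
    ax.scatter(y_true[order], y_pred[order], c=dens[order], cmap=cmap, s=10)
    return ax, dens


rng = np.random.default_rng(0)
truth = rng.normal(5, 1, size=300)
pred = truth + rng.normal(0, 0.3, size=300)
ax, dens = kde_colored_scatter(truth, pred)
```

The pairwise computation uses O(n²) memory, which is fine for a sketch. For large test sets, a binned estimate or SciPy's implementation would be the more practical choice.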