tabensemb.trainer.Trainer

class tabensemb.trainer.Trainer(device: str = 'cpu', project: str | None = None)

Bases: object

The model manager that provides saving, loading, ranking, and analyzing utilities.

Attributes:
args

A tabensemb.config.UserConfig instance.

configfile

The source of the configuration. If the config argument of load_config() is a tabensemb.config.UserConfig, it is “UserInputConfig”. If the config argument is a path, it is that path. If the config argument is not given, it is the value of the “base” command-line argument passed to Python when executing the script.
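The three cases above can be sketched as a small resolution function. This is a minimal sketch, not tabensemb's implementation: UserConfig here is a hypothetical dict-like stand-in for tabensemb.config.UserConfig, resolve_configfile is a hypothetical helper, and the "--base" spelling of the command-line option is an assumption.

```python
import argparse
import sys
from typing import List, Optional, Union


class UserConfig(dict):
    """Hypothetical stand-in for tabensemb.config.UserConfig (assumed dict-like)."""


def resolve_configfile(
    config: Optional[Union[UserConfig, str]], argv: Optional[List[str]] = None
) -> Optional[str]:
    # Case 1: a UserConfig instance is recorded with a fixed marker string.
    if isinstance(config, UserConfig):
        return "UserInputConfig"
    # Case 2: a path is recorded as-is.
    if isinstance(config, str):
        return config
    # Case 3: fall back to a "--base" command-line option (assumed spelling).
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--base", default=None)
    known, _ = parser.parse_known_args(sys.argv[1:] if argv is None else argv)
    return known.base
```

For example, running a script as `python main.py --base sample` would, under this sketch, yield "sample".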

datamodule

A tabensemb.data.datamodule.DataModule instance.

device

The device on which models are trained. “cpu”, “cuda”, or “cuda:X”.

leaderboard

The ranking of all models in all model bases. Only valid after get_leaderboard() is called.

modelbases

A list of tabensemb.model.AbstractModel instances.

modelbases_names

The names (tabensemb.model.AbstractModel.program) of the model bases in modelbases, in the same order.

project

The name of the Trainer.

project_root

The directory where all files are stored: tabensemb.setting["default_output_path"]/{project}/{project_root_subfolder}/{TIME}-{config}, where project is the project attribute, and project_root_subfolder and config are arguments of load_config().
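The path layout above can be illustrated with a minimal sketch. build_project_root is a hypothetical helper, the timestamp format is an assumption (tabensemb's actual {TIME} format may differ), and skipping the subfolder level when it is empty is also an assumption.

```python
import os
import time
from typing import Optional


def build_project_root(
    default_output_path: str,
    project: str,
    project_root_subfolder: Optional[str],
    config: str,
    timestamp: Optional[str] = None,
) -> str:
    # Assumed timestamp format; the real {TIME} format may differ.
    if timestamp is None:
        timestamp = time.strftime("%Y-%m-%d-%H%M%S")
    parts = [default_output_path, project]
    # Hypothetical: omit the subfolder level when none is given.
    if project_root_subfolder:
        parts.append(project_root_subfolder)
    # The last level combines the time and the configuration name.
    parts.append(f"{timestamp}-{config}")
    return os.path.join(*parts)
```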

sys_summary

A summary of the system, recorded when summarize_device() is called.

SPACE

Search spaces for “lr”, “weight_decay”, and “batch_size” defined in the configuration.

all_feature_names

tabensemb.data.datamodule.DataModule.all_feature_names()

cat_feature_mapping

tabensemb.data.datamodule.DataModule.cat_feature_mapping

cat_feature_names

tabensemb.data.datamodule.DataModule.cat_feature_names

chosen_params

The “lr”, “weight_decay”, and “batch_size” parameters in the configuration.

cont_feature_names

tabensemb.data.datamodule.DataModule.cont_feature_names

derived_data

tabensemb.data.datamodule.DataModule.derived_data

derived_stacked_features

tabensemb.data.datamodule.DataModule.derived_stacked_features()

df

tabensemb.data.datamodule.DataModule.df

feature_data

tabensemb.data.datamodule.DataModule.feature_data()

label_data

tabensemb.data.datamodule.DataModule.label_data()

label_name

tabensemb.data.datamodule.DataModule.label_name

static_params

The “patience” and “epoch” parameters in the configuration.

tensors

tabensemb.data.datamodule.DataModule.tensors

test_indices

tabensemb.data.datamodule.DataModule.test_indices

train_indices

tabensemb.data.datamodule.DataModule.train_indices

training

tabensemb.data.datamodule.DataModule.training

unscaled_feature_data

tabensemb.data.datamodule.DataModule.unscaled_feature_data()

unscaled_label_data

tabensemb.data.datamodule.DataModule.unscaled_label_data()

val_indices

tabensemb.data.datamodule.DataModule.val_indices

Methods

__init__(device: str = 'cpu', project: str | None = None)

The bridge of all modules. It contains all configurations and data. It can train model bases and evaluate results (including feature importance, partial dependence, etc.).

Parameters:
device:

The device on which models are trained. Choose from “cpu”, “cuda”, or “cuda:X” (if available).

project:

The name of the Trainer.

add_modelbases(models)

Add a list of model bases and check whether their names conflict.
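The name check described above can be sketched as follows. This is a minimal sketch: FakeModelBase is a hypothetical stand-in that only carries the program attribute, and raising ValueError on a conflict is an assumption about how conflicts are reported.

```python
from collections import Counter


class FakeModelBase:
    """Hypothetical stand-in for tabensemb.model.AbstractModel; only `program` is used."""

    def __init__(self, program):
        self.program = program


def add_modelbases(existing, models):
    # Collect names of already registered model bases plus the new ones.
    names = [m.program for m in existing] + [m.program for m in models]
    duplicates = sorted(name for name, count in Counter(names).items() if count > 1)
    if duplicates:
        # Assumed behavior: refuse to register conflicting names.
        raise ValueError(f"Conflicting model base names: {duplicates}")
    return existing + list(models)
```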

cal_feature_importance(program, model_name)

Calculate feature importance using a specified model.

cal_partial_dependence([feature_subset])

Calculate partial dependence.

cal_partial_dependence_2way(x_feature, y_feature)

Calculate 2-way partial dependence.

cal_shap(program, model_name, **kwargs)

Calculate SHAP values using a specified model.

clear_modelbase()

Delete all model bases in the Trainer.

copy()

Copy the Trainer and save it to another directory.

cross_validation(programs, n_random, ...[, ...])

Repeat load_data(), train the model bases, and evaluate all models multiple times.

detach_model(program, model_name[, verbose])

Detach the selected model of the selected model base to a separate Trainer and save it to another directory.

detach_modelbase(program[, verbose])

Detach the selected model base to a separate Trainer and save it to another directory.

get_approx_cv_leaderboard(leaderboard[, save])

Calculate approximate averages and standard errors based on cross_validation() results in the folder self.project_root/cv.
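A common way to summarize repeated runs is the mean and the standard error of the mean (sample standard deviation divided by the square root of the number of runs). The sketch below assumes that definition; how tabensemb approximates these values from the saved leaderboards may differ, and approx_cv_summary is a hypothetical helper.

```python
import math
import statistics


def approx_cv_summary(scores_per_run):
    """Return {model: (mean, standard_error)} from {model: [score of each CV run]}."""
    summary = {}
    for model, scores in scores_per_run.items():
        mean = statistics.fmean(scores)
        # Standard error of the mean; undefined for a single run.
        if len(scores) > 1:
            se = statistics.stdev(scores) / math.sqrt(len(scores))
        else:
            se = float("nan")
        summary[model] = (mean, se)
    return summary
```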

get_best_model()

Get the best model from the leaderboard.

get_leaderboard([test_data_only, ...])

Run all model bases, with or without cross-validation, to build a leaderboard.
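Conceptually, a leaderboard scores every model's predictions with a metric and sorts the rows so that the best model comes first. The sketch below assumes a regression task ranked by RMSE; build_leaderboard and the column names are hypothetical, and the metrics and columns tabensemb actually reports may differ.

```python
import math


def rmse(y_true, y_pred):
    # Root mean squared error.
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def build_leaderboard(predictions, y_true):
    """predictions: {(program, model_name): predictions on the testing set}."""
    rows = [
        {"Program": program, "Model": model, "Testing RMSE": rmse(y_true, y_pred)}
        for (program, model), y_pred in predictions.items()
    ]
    # Lower RMSE ranks higher, so the first row is the best model.
    return sorted(rows, key=lambda row: row["Testing RMSE"])
```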

get_modelbase(program)

Get the selected model base by its name.

get_modelwise_cv_metrics()

Assemble cross-validation results in the folder self.project_root/cv for metrics of each model in each model base.

get_predict_leaderboard(df, *args, **kwargs)

Get a prediction leaderboard of all models on a new labeled dataset.

load_config([config, manual_config, ...])

Load the configuration using a tabensemb.config.UserConfig or a file in .py or .json format.

load_data(*args, **kwargs)

A wrapper of tabensemb.data.datamodule.DataModule.load_data().

load_state(trainer)

Restore a Trainer from a deep-copied state.

plot_categorical_presence_ratio([category, ...])

Plot the presence ratio of each feature, grouped by a categorical variable.
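The quantity being plotted can be computed as the fraction of non-missing values of each feature within each category. This is a minimal sketch using plain dictionaries; treating None as missing is an assumption (tabular data in tabensemb is typically held in pandas, where NaN marks missing values), and presence_ratio_by_category is hypothetical.

```python
from collections import defaultdict


def presence_ratio_by_category(rows, features, category):
    """Return {category value: {feature: fraction of non-missing values}}."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[category]].append(row)
    # None counts as missing (assumption).
    return {
        cat: {f: sum(r.get(f) is not None for r in members) / len(members) for f in features}
        for cat, members in groups.items()
    }
```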

plot_corr([fontsize, imputed, features, ...])

Plot correlation coefficients among features and the target.

plot_corr_with_label([imputed, features, ...])

Plot correlation coefficients between the target and each feature.

plot_err_hist(program, model_name[, ...])

Plot histograms of prediction errors.

plot_feature_box([imputed, features, ax, ...])

Plot boxplots of the tabular data.

plot_feature_importance(program, model_name)

Plot feature importance of a model using cal_feature_importance().

plot_fill_rating([ax, clr, category, ...])

Plot a histogram of data point ratings, where the rating of a data point is the percentage of its features that are filled (not missing).
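The rating defined above, and the counts a histogram would show, can be sketched as follows. This is a minimal sketch: None is assumed to mark a missing value, and both fill_ratings and rating_histogram (with its 10-point bins) are hypothetical.

```python
from collections import Counter


def fill_ratings(rows, features):
    """Percentage of features that are filled (not None) in each data point."""
    return [100.0 * sum(row.get(f) is not None for f in features) / len(features) for row in rows]


def rating_histogram(ratings, bin_width=10):
    # Map each rating to the lower edge of its bin; 100% falls in its own bin.
    return Counter(int(r // bin_width) * bin_width for r in ratings)
```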

plot_hist(feature[, ax, clr, imputed, kde, ...])

Plot the histogram of a feature.

plot_hist_all([imputed, fontsize, ...])

Plot histograms of the tabular data.

plot_kde(x_col[, y_col, ax, clr, imputed, ...])

Plot the kernel density estimation of a feature or two features.

plot_kde_all([imputed, fontsize, ...])

Plot the kernel density estimation for each feature in the tabular data.

plot_loss(program, model_name[, ax, ...])

Plot loss curves for a model.

plot_on_one_axes(meth_name, meth_kwargs_ls)

Plot multiple items on one matplotlib.axes.Axes.

plot_pairplot([imputed, features, ...])

Plot a seaborn.pairplot of the features and the label.

plot_partial_dependence(program, model_name, ...)

Calculate and plot a partial dependence plot with bootstrapping for a feature.

plot_partial_dependence_2way(x_feature, ...)

Calculate and plot a 2-way partial dependence plot with bootstrapping for a pair of features.

plot_partial_dependence_2way_all(program, ...)

Calculate and plot 2-way partial dependence plots with bootstrapping.

plot_partial_dependence_all(program, model_name)

Calculate and plot partial dependence plots with bootstrapping.

plot_partial_err(program, model_name, feature)

Calculate prediction absolute errors on the testing dataset, and plot histograms of high-error samples and low-error samples respectively for a single feature.

plot_partial_err_all(program, model_name[, ...])

Calculate prediction absolute errors on the testing dataset, and plot histograms of high-error samples and low-error samples respectively.
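The split into high-error and low-error samples can be sketched as below. The rule used here (the top 20% of samples by absolute error form the high-error group) is a hypothetical threshold chosen for illustration; the criterion tabensemb applies is not stated above, and split_by_abs_error is hypothetical.

```python
def split_by_abs_error(y_true, y_pred, feature_values, high_fraction=0.2):
    """Split one feature's values into (high-error, low-error) groups."""
    errors = [abs(t - p) for t, p in zip(y_true, y_pred)]
    # Sample indices sorted from largest to smallest absolute error.
    order = sorted(range(len(errors)), key=lambda i: errors[i], reverse=True)
    n_high = max(1, round(high_fraction * len(errors)))
    high_idx = set(order[:n_high])
    high = [feature_values[i] for i in range(len(errors)) if i in high_idx]
    low = [feature_values[i] for i in range(len(errors)) if i not in high_idx]
    return high, low
```

Each of the two returned lists would then be drawn as a histogram for that feature.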

plot_pca_2d_visual([ax, category, clr, ...])

Fit a sklearn.decomposition.PCA on a set of features, and plot its first two principal components as a scatter plot.

plot_pdf(feature[, dist, ax, clr, imputed, ...])

Plot the probability density function of a feature.

plot_presence_ratio([order, ax, clr, ...])

Plot the ratio of presence of each feature.

plot_scatter(x_col, y_col[, category, ax, ...])

Plot one column against another.

plot_subplots(ls, ls_kwarg_name, meth_name)

Iterate over a list to plot subplots in a single figure.
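The roles of ls, ls_kwarg_name, and meth_name can be illustrated with a dispatch sketch: the method named meth_name is called once per item of ls, and each item is passed as the keyword argument named ls_kwarg_name. Plotter, its plot_hist stand-in, and the meth_fix_kwargs parameter are hypothetical; the real method draws each call on its own axes of one figure instead of returning strings.

```python
class Plotter:
    """Hypothetical stand-in for a Trainer with one plotting method."""

    def plot_hist(self, feature, bins=10):
        return f"hist({feature}, bins={bins})"

    def plot_subplots(self, ls, ls_kwarg_name, meth_name, meth_fix_kwargs=None):
        # Look up the plotting method by name.
        meth = getattr(self, meth_name)
        fixed = meth_fix_kwargs or {}
        # One call per item; the item is passed under the keyword `ls_kwarg_name`.
        return [meth(**{ls_kwarg_name: item}, **fixed) for item in ls]
```

For example, `Plotter().plot_subplots(["a", "b"], "feature", "plot_hist")` calls `plot_hist(feature="a")` and then `plot_hist(feature="b")`.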

plot_truth_pred(program, model_name[, ...])

Compare ground truth and prediction for one model.

plot_truth_pred_all(program[, fontsize, ...])

Compare ground truth and prediction for all models in a model base.

set_device(device)

Set the device on which models are trained.

set_path(path[, verbose])

Set the work directory of the Trainer.

set_status(training)

A wrapper of tabensemb.data.datamodule.DataModule.set_status().

summarize_device()

Print a summary of the environment.

summarize_setting()

Print the summary of the device, the configuration, and the global setting of the package (tabensemb.setting).

train([programs, verbose])

Train all model bases (modelbases).

_bootstrap_fit(program, df, derived_data, ...)

Perform bootstrap resampling, fit the selected model on the resampled data, and assign a sequence of values to the selected feature to see how the prediction changes with respect to that feature.
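The procedure above can be sketched in plain Python: resample the rows with replacement, refit, overwrite the selected feature with each grid value, and average the predictions. This is a minimal sketch, not tabensemb's implementation; bootstrap_partial_dependence is hypothetical, and fit is an assumed callable that trains on a list of row dicts and returns a predict function.

```python
import random
import statistics


def bootstrap_partial_dependence(rows, feature, grid, fit, n_bootstrap=20, seed=0):
    """Return [(mean, std) of the averaged prediction] for each grid value."""
    rng = random.Random(seed)
    curves = []
    for _ in range(n_bootstrap):
        # Resample the data with replacement and refit the model on it.
        sample = [rng.choice(rows) for _ in rows]
        predict = fit(sample)
        curve = []
        for value in grid:
            # Set the selected feature to the grid value for every resampled row.
            modified = [{**row, feature: value} for row in sample]
            curve.append(statistics.fmean(predict(modified)))
        curves.append(curve)
    # Aggregate each grid point across the bootstrap curves.
    return [
        (statistics.fmean(c[i] for c in curves), statistics.pstdev([c[i] for c in curves]))
        for i in range(len(grid))
    ]
```

The spread across bootstrap curves is what a bootstrapped partial dependence plot typically shows as a confidence band.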

_cal_leaderboard(programs_predictions[, ...])

Calculate the leaderboard based on results from cross_validation() or tabensemb.model.AbstractModel._predict_all().

_create_dir([verbose, project_root_subfolder])

Create the folder for the Trainer.

_generate_grid(feature, grid_size[, ...])

Generate a sequential (linspace) grid for a feature in the tabular dataset.
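A linspace grid spans the observed range of the feature with evenly spaced points, including both endpoints. The sketch below is a pure-Python equivalent of numpy.linspace over the feature's minimum and maximum; generate_grid is hypothetical, and whether tabensemb clips the range (for example, to percentiles) is not stated above.

```python
def generate_grid(values, grid_size):
    """Evenly spaced grid from min(values) to max(values), inclusive."""
    lo, hi = min(values), max(values)
    if grid_size == 1:
        return [float(lo)]
    step = (hi - lo) / (grid_size - 1)
    return [lo + i * step for i in range(grid_size)]
```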

_metrics(predictions, metrics, test_data_only)

Calculate metrics for predictions from tabensemb.model.AbstractModel._predict_all().

_plot_action_after_plot(fig_name, disable[, ...])

Set the labels of x/y-axis, set the layout, save the current figure, show the figure if in a notebook, and close the figure.

_plot_action_categorical_scatter(x, y, df, ...)

Plot scatter points colored according to their category.

_plot_action_category_unique_values(df, category)

Get the category to classify data points and its unique values.

_plot_action_generate_feature_types_legends(...)

Generate the legend for feature types defined in the configuration.

_plot_action_generate_feature_types_palette(...)

Generate a color palette for the features according to their types defined in the configuration.

_plot_action_get_df(imputed, scaled, ...)

A wrapper of tabensemb.data.datamodule.DataModule.get_df().

_plot_action_init_ax([ax, figure_kwargs, ...])

_plot_action_subplots(meth_name, ls, ...[, ...])

Iterate over a list to plot subplots in a single figure.

_read_cv_leaderboards()

Read cross-validation leaderboards in the folder self.project_root/cv.