tabensemb.trainer.Trainer
- class tabensemb.trainer.Trainer(device: str = 'cpu', project: str | None = None)
Bases: object
The model manager that provides saving, loading, ranking, and analyzing utilities.
- Attributes:
- args
A tabensemb.config.UserConfig instance.
- configfile
The source of the configuration. If the config argument of load_config() is a tabensemb.config.UserConfig, it is “UserInputConfig”. If the config argument is a path, it is the path. If the config argument is not given, it is the “base” argument passed to python when executing the script.
- datamodule
A tabensemb.data.datamodule.DataModule instance.
- device
The device on which models are trained. “cpu”, “cuda”, or “cuda:X”.
- leaderboard
The ranking of all models in all model bases. Only valid after get_leaderboard() is called.
- modelbases
A list of tabensemb.model.AbstractModel.
- modelbases_names
Corresponding names (tabensemb.model.AbstractModel.program) of modelbases.
- project
The name of the Trainer.
- project_root
The place where all files are stored. tabensemb.setting["default_output_path"]/{project}/{project_root_subfolder}/{TIME}-{config} where project is project, project_root_subfolder and config are arguments of load_config().
- sys_summary
Summary of the system when summarize_device() is called.
- SPACE
Search spaces for “lr”, “weight_decay”, and “batch_size” defined in the configuration.
- all_feature_names
- cat_feature_mapping
tabensemb.data.datamodule.DataModule.cat_feature_mapping
- cat_feature_names
tabensemb.data.datamodule.DataModule.cat_feature_names
- chosen_params
The “lr”, “weight_decay”, and “batch_size” parameters in the configuration.
- cont_feature_names
tabensemb.data.datamodule.DataModule.cont_feature_names
- derived_data
tabensemb.data.datamodule.DataModule.derived_data
- derived_stacked_features
tabensemb.data.datamodule.DataModule.derived_stacked_features()
- df
tabensemb.data.datamodule.DataModule.df
- feature_data
- label_data
- label_name
tabensemb.data.datamodule.DataModule.label_name
- static_params
The “patience” and “epoch” parameters in the configuration.
- tensors
tabensemb.data.datamodule.DataModule.tensors
- test_indices
tabensemb.data.datamodule.DataModule.test_indices
- train_indices
tabensemb.data.datamodule.DataModule.train_indices
- training
tabensemb.data.datamodule.DataModule.training
- unscaled_feature_data
tabensemb.data.datamodule.DataModule.unscaled_feature_data()
- unscaled_label_data
- val_indices
tabensemb.data.datamodule.DataModule.val_indices
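The project_root template above can be illustrated with a small standalone sketch. It does not call tabensemb; the helper name build_project_root and the timestamp format used for {TIME} are assumptions made for illustration, and the library's actual format may differ.

```python
from datetime import datetime
from pathlib import Path


def build_project_root(default_output_path, project, project_root_subfolder, config, now=None):
    """Compose a path that follows the documented template (hypothetical helper)."""
    now = now or datetime.now()
    # Assumption: {TIME} is a sortable timestamp; the real format is not stated in this page.
    time_tag = now.strftime("%Y-%m-%d-%H-%M-%S")
    return Path(default_output_path) / project / project_root_subfolder / f"{time_tag}-{config}"


# A fixed timestamp keeps the example reproducible.
root = build_project_root(
    "output", "demo", "run1", "sample", now=datetime(2024, 1, 2, 3, 4, 5)
)
print(root.as_posix())
```

Here "output" stands in for tabensemb.setting["default_output_path"], and "run1" and "sample" stand in for the project_root_subfolder and config arguments of load_config().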
Methods
- __init__(device: str = 'cpu', project: str | None = None)
The bridge of all modules. It contains all configurations and data. It can train model bases and evaluate results (including feature importance, partial dependency, etc.).
- Parameters:
- device:
The device on which models are trained. Choose from “cpu”, “cuda”, or “cuda:X” (if available).
- project:
The name of the Trainer.
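The accepted forms of the device argument (“cpu”, “cuda”, or “cuda:X”) can be checked with a short pattern. This is a minimal sketch of the string format only: is_valid_device is a hypothetical helper, not part of tabensemb, and it does not test whether a CUDA device is actually available on the machine.

```python
import re

# Matches "cpu", "cuda", or "cuda:<non-negative integer>".
_DEVICE_PATTERN = re.compile(r"^(cpu|cuda(:\d+)?)$")


def is_valid_device(device: str) -> bool:
    """Return True if the string has one of the documented device forms."""
    return bool(_DEVICE_PATTERN.match(device))


for candidate in ["cpu", "cuda", "cuda:1", "gpu", "cuda:"]:
    print(candidate, is_valid_device(candidate))
```

A string that passes this check can still fail at training time if the requested GPU index does not exist, so the check complements rather than replaces an availability test.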
- add_modelbases(models): Add a list of model bases and check whether their names conflict.
- cal_feature_importance(program, model_name): Calculate feature importance using a specified model.
- cal_partial_dependence([feature_subset]): Calculate partial dependency.
- cal_partial_dependence_2way(x_feature, y_feature): Calculate 2-way partial dependency.
- cal_shap(program, model_name, **kwargs): Calculate SHAP values using a specified model.
- Delete all model bases in the Trainer.
- copy(): Copy the Trainer and save it to another directory.
- cross_validation(programs, n_random, ...[, ...]): Repeat load_data(), train model bases, and evaluate all models multiple times.
- detach_model(program, model_name[, verbose]): Detach the selected model of the selected model base to a separate Trainer and save it to another directory.
- detach_modelbase(program[, verbose]): Detach the selected model base to a separate Trainer and save it to another directory.
- get_approx_cv_leaderboard(leaderboard[, save]): Calculate approximated averages and standard errors based on cross_validation() results in the folder self.project_root/cv.
- Get the best model from leaderboard.
- get_leaderboard([test_data_only, ...]): Run all model bases with/without cross validation for a leaderboard.
- get_modelbase(program): Get the selected model base by its name.
- Assemble cross-validation results in the folder self.project_root/cv for metrics of each model in each model base.
- get_predict_leaderboard(df, *args, **kwargs): Get prediction leaderboard of all models on an upcoming labeled dataset.
- load_config([config, manual_config, ...]): Load the configuration using a tabensemb.config.UserConfig or a file in .py or .json format.
- load_data(*args, **kwargs): A wrapper of tabensemb.data.datamodule.DataModule.load_data().
- load_state(trainer): Restore a Trainer from a deep-copied state.
- plot_categorical_presence_ratio([category, ...]): Plot the ratio of presence of each feature, classified by a categorical variable.
- plot_corr([fontsize, imputed, features, ...]): Plot correlation coefficients among features and the target.
- plot_corr_with_label([imputed, features, ...]): Plot correlation coefficients between the target and each feature.
- plot_err_hist(program, model_name[, ...]): Plot histograms of prediction errors.
- plot_feature_box([imputed, features, ax, ...]): Plot boxplot of the tabular data.
- plot_feature_importance(program, model_name): Plot feature importance of a model using cal_feature_importance().
- plot_fill_rating([ax, clr, category, ...]): Plot the histogram of data point rating which is the percentage of filled features.
- plot_hist(feature[, ax, clr, imputed, kde, ...]): Plot the histogram of a feature.
- plot_hist_all([imputed, fontsize, ...]): Plot histograms of the tabular data.
- plot_kde(x_col[, y_col, ax, clr, imputed, ...]): Plot the kernel density estimation of a feature or two features.
- plot_kde_all([imputed, fontsize, ...]): Plot the kernel density estimation for each feature in the tabular data.
- plot_loss(program, model_name[, ax, ...]): Plot loss curves for a model.
- plot_on_one_axes(meth_name, meth_kwargs_ls): Plot multiple items on one matplotlib.axes.Axes.
- plot_pairplot([imputed, features, ...]): Plot seaborn.pairplot among features and label.
- plot_partial_dependence(program, model_name, ...): Calculate and plot a partial dependence plot with bootstrapping for a feature.
- plot_partial_dependence_2way(x_feature, ...): Calculate and plot a 2-way partial dependence plot with bootstrapping for a pair of features.
- plot_partial_dependence_2way_all(program, ...): Calculate and plot 2-way partial dependence plots with bootstrapping.
- plot_partial_dependence_all(program, model_name): Calculate and plot partial dependence plots with bootstrapping.
- plot_partial_err(program, model_name, feature): Calculate prediction absolute errors on the testing dataset, and plot histograms of high-error samples and low-error samples respectively for a single feature.
- plot_partial_err_all(program, model_name[, ...]): Calculate prediction absolute errors on the testing dataset, and plot histograms of high-error samples and low-error samples respectively.
- plot_pca_2d_visual([ax, category, clr, ...]): Fit a sklearn.decomposition.PCA on a set of features, and plot its first two principal components as scatters.
- plot_pdf(feature[, dist, ax, clr, imputed, ...]): Plot the probability density function of a feature.
- plot_presence_ratio([order, ax, clr, ...]): Plot the ratio of presence of each feature.
- plot_scatter(x_col, y_col[, category, ax, ...]): Plot one column against another.
- plot_subplots(ls, ls_kwarg_name, meth_name): Iterate over a list to plot subplots in a single figure.
- plot_truth_pred(program, model_name[, ...]): Compare ground truth and prediction for one model.
- plot_truth_pred_all(program[, fontsize, ...]): Compare ground truth and prediction for all models in a model base.
- set_device(device): Set the device on which models are trained.
- set_path(path[, verbose]): Set the work directory of the Trainer.
- set_status(training): A wrapper of tabensemb.data.datamodule.DataModule.set_status().
- Print a summary of the environment.
- Print the summary of the device, the configuration, and the global setting of the package (tabensemb.setting).
- train([programs, verbose]): Train all model bases (modelbases).
- _bootstrap_fit(program, df, derived_data, ...): Perform bootstrap resampling, fit the selected model on the resampled data, and assign sequential values to the selected feature to see how the prediction changes with respect to the feature.
- _cal_leaderboard(programs_predictions[, ...]): Calculate the leaderboard based on results from cross_validation() or tabensemb.model.AbstractModel._predict_all().
- _create_dir([verbose, project_root_subfolder]): Create the folder for the Trainer.
- _generate_grid(feature, grid_size[, ...]): Generate a sequential (linspace) grid for a feature in the tabular dataset.
- _metrics(predictions, metrics, test_data_only): Calculate metrics for predictions from tabensemb.model.AbstractModel._predict_all().
- _plot_action_after_plot(fig_name, disable[, ...]): Set the labels of x/y-axis, set the layout, save the current figure, show the figure if in a notebook, and close the figure.
- _plot_action_categorical_scatter(x, y, df, ...): Plot scatters whose colors are related to their category.
- _plot_action_category_unique_values(df, category): Get the category to classify data points and its unique values.
- Generate the legend for feature types defined in the configuration.
- Generate color palette for each feature according to their types defined in the configuration.
- _plot_action_get_df(imputed, scaled, ...): A wrapper of tabensemb.data.datamodule.DataModule.get_df().
- _plot_action_init_ax([ax, figure_kwargs, ...])
- _plot_action_subplots(meth_name, ls, ...[, ...]): Iterate over a list to plot subplots in a single figure.
- Read cross-validation leaderboards in the folder self.project_root/cv.
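The idea of summarizing cross-validation results as averages and standard errors, as described for get_approx_cv_leaderboard(), can be sketched in plain Python. This is an illustration of the general statistic only, not the library's implementation: summarize_cv and the sample scores are hypothetical, and how tabensemb defines its "approximated" values is not stated on this page.

```python
import math
import statistics


def summarize_cv(scores_by_model):
    """Return (name, mean, standard error) rows sorted by mean, lowest first."""
    rows = []
    for name, scores in scores_by_model.items():
        mean = statistics.mean(scores)
        # Standard error of the mean: sample standard deviation / sqrt(n).
        if len(scores) > 1:
            sem = statistics.stdev(scores) / math.sqrt(len(scores))
        else:
            sem = 0.0  # A single fold carries no spread information.
        rows.append((name, mean, sem))
    # Assumption: the metric is an error (lower is better), so sort ascending.
    return sorted(rows, key=lambda row: row[1])


# Hypothetical per-fold error values for two models.
cv_rmse = {"model_a": [1.0, 2.0, 3.0], "model_b": [0.5, 0.5, 0.5]}
board = summarize_cv(cv_rmse)
for name, mean, sem in board:
    print(f"{name}: mean={mean:.3f}, sem={sem:.3f}")
```

For a metric where higher is better, the sort order would need to be reversed.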