Source code for tabensemb.trainer.trainer

import os.path
import warnings
import matplotlib.figure
import matplotlib.axes
import matplotlib.legend
import numpy as np
import pandas as pd
import tabensemb
from tabensemb.utils import *
from tabensemb.config import UserConfig
from tabensemb.data import DataModule
from tabensemb.data.utils import get_imputed_dtype, fill_cat_nan
from copy import deepcopy as cp
from skopt.space import Real, Integer, Categorical
import time
from typing import *
import torch.nn as nn
import torch.cuda
import torch.utils.data as Data
import scipy.stats as st
from sklearn.utils import resample as skresample
import platform, psutil, subprocess
import shutil
import pickle

set_random_seed(tabensemb.setting["random_seed"])


[docs] class Trainer: """ The model manager that provides saving, loading, ranking, and analyzing utilities. Attributes ---------- args A :class:`tabensemb.config.UserConfig` instance. configfile The source of the configuration. If the ``config`` argument of :meth:`load_config` is a :class:`tabensemb.config.UserConfig`, it is "UserInputConfig". If the ``config`` argument is a path, it is the path. If the ``config`` argument is not given, it is the "base" argument passed to python when executing the script. datamodule A :class:`tabensemb.data.datamodule.DataModule` instance. device The device on which models are trained. "cpu", "cuda", or "cuda:X". leaderboard The ranking of all models in all model bases. Only valid after :meth:`get_leaderboard` is called. modelbases A list of :class:`tabensemb.model.AbstractModel`. modelbases_names Corresponding names (:attr:`tabensemb.model.AbstractModel.program`) of :attr:`modelbases`. project The name of the :class:`Trainer`. project_root The place where all files are stored. ``tabensemb.setting["default_output_path"]`` ``/{project}/{project_root_subfolder}/{TIME}-{config}`` where ``project`` is :attr:`project`, ``project_root_subfolder`` and ``config`` are arguments of :meth:`load_config`. sys_summary Summary of the system when :meth:`summarize_device` is called. SPACE all_feature_names cat_feature_mapping cat_feature_names chosen_params cont_feature_names derived_data derived_stacked_features df feature_data label_data label_name static_params tensors test_indices train_indices training unscaled_feature_data unscaled_label_data val_indices """
[docs] def __init__(self, device: str = "cpu", project: str = None): """ The bridge of all modules. It contains all configurations and data. It can train model bases and evaluate results (including feature importance, partial dependency, etc.). Parameters ---------- device: The device on which models are trained. Choose from "cpu", "cuda", or "cuda:X" (if available). project: The name of the :class:`Trainer`. """ self.device = "cpu" self.project = project self.modelbases = [] self.modelbases_names = [] self.set_device(device)
[docs] def set_device(self, device: str): """ Set the device on which models are trained. Parameters ---------- device "cpu", "cuda", or "cuda:X" (if available) Notes ----- Multi-GPU training and training on a machine with multiple GPUs are not tested. """ if device not in ["cpu", "cuda"] and "cuda" not in device: raise Exception( f"Device {device} is an invalid selection. Choose among {['cpu', 'cuda']}." f"Note: Multi-GPU training and training on a machine with multiple GPUs are not tested." ) self.device = device
[docs] def add_modelbases(self, models: List): """ Add a list of model bases and check whether their names conflict. Parameters ---------- models: A list of :class:`tabensemb.model.AbstractModel`. """ new_modelbases_names = self.modelbases_names + [x.program for x in models] if len(new_modelbases_names) != len(list(set(new_modelbases_names))): raise Exception(f"Conflicted model base names: {self.modelbases_names}") self.modelbases += models self.modelbases_names = new_modelbases_names
[docs] def get_modelbase(self, program: str): """ Get the selected model base by its name. Parameters ---------- program The name of the model base. Returns ------- AbstractModel A model base. """ if program not in self.modelbases_names: raise Exception(f"Model base {program} not added to the trainer.") return self.modelbases[self.modelbases_names.index(program)]
[docs] def clear_modelbase(self): """ Delete all model bases in the :class:`Trainer`. """ self.modelbases = [] self.modelbases_names = []
[docs] def detach_modelbase(self, program: str, verbose: bool = True) -> "Trainer": """ Detach the selected model base to a separate :class:`Trainer` and save it to another directory. It is much cheaper than :meth:`copy` if only one model base is needed. If any external model is required, please use :meth:``detach_model`` to detach a single model. Parameters ---------- program The selected model base. verbose Verbosity Returns ------- Trainer A :class:`Trainer` with the selected model base. See Also -------- :meth:`copy`, :meth:`detach_model`, :meth:`tabensemb.model.AbstractModel.detach_model` """ modelbase = cp(self.get_modelbase(program=program)) tmp_trainer = modelbase.trainer tmp_trainer.clear_modelbase() new_path = safe_mkdir(add_postfix(self.project_root)) tmp_trainer.set_path(new_path, verbose=False) modelbase.set_path(os.path.join(new_path, modelbase.program)) tmp_trainer.add_modelbases([modelbase]) shutil.copytree(self.get_modelbase(program=program).root, modelbase.root) save_trainer(tmp_trainer, verbose=verbose) return tmp_trainer
[docs] def detach_model( self, program: str, model_name: str, verbose: bool = True ) -> "Trainer": """ Detach the selected model of the selected model base to a separate :class:`Trainer` and save it to another directory. If external models are required, they are also detached into the separated Trainer. Parameters ---------- program The selected model base. model_name The selected model. verbose Verbosity. Returns ------- Trainer A :class:`Trainer` with the selected model in its model base. """ required_models_names = self.get_modelbase(program=program).required_models( model_name ) if required_models_names is not None and any( [x.startswith("EXTERN") for x in required_models_names] ): tmp_trainer = self.copy() tmp_modelbase = tmp_trainer.get_modelbase(program=program) detached_model = tmp_modelbase.detach_model( model_name=model_name, program=f"{program}_{model_name}" ) required_models = tmp_modelbase._get_required_models(model_name) required_modelbases = ( [ model for x, model in required_models.items() if x.startswith("EXTERN") ] if required_models is not None else [] ) else: tmp_trainer = self.detach_modelbase(program=program, verbose=False) tmp_modelbase = tmp_trainer.get_modelbase(program=program) detached_model = tmp_modelbase.detach_model( model_name=model_name, program=f"{program}_{model_name}" ) required_modelbases = [] tmp_trainer.clear_modelbase() tmp_trainer.add_modelbases([detached_model] + required_modelbases) shutil.rmtree(tmp_modelbase.root) save_trainer(tmp_trainer, verbose=verbose) return tmp_trainer
[docs] def copy(self) -> "Trainer": """ Copy the :class:`Trainer` and save it to another directory. It might be time and space-consuming because all model bases are copied once. Returns ------- trainer A :class:`Trainer` instance. See Also -------- :meth:`detach_modelbase`, :meth:`detach_model`, :meth:`tabensemb.model.AbstractModel.detach_model` """ tmp_trainer = cp(self) new_path = safe_mkdir(add_postfix(self.project_root)) tmp_trainer.set_path(new_path, verbose=True) for modelbase in tmp_trainer.modelbases: modelbase.set_path(os.path.join(new_path, modelbase.program)) shutil.copytree(self.project_root, tmp_trainer.project_root, dirs_exist_ok=True) save_trainer(tmp_trainer) return tmp_trainer
[docs] def load_config( self, config: Union[str, UserConfig] = None, manual_config: Dict = None, project_root_subfolder: str = None, ) -> None: """ Load the configuration using a :class:`tabensemb.config.UserConfig` or a file in .py or .json format. Arguments passed to python when executing the script are parsed using ``argparse`` if ``config`` is left None. All keys in :meth:`tabensemb.config.UserConfig.defaults` can be parsed, for example: For the loss function: ``--loss mse``, For the total epoch: ``--epoch 200``, For the option of bayes opt: ``--bayes_opt`` to turn on Bayesian hyperparameter optimization, ``--no-bayes_opt`` to turn it off. The loaded configuration will be saved as a .py file in the project folder. Parameters ---------- config It can be the path to the configuration file in json or python format, or a :class:`tabensemb.config.UserConfig` instance. If it is None, arguments passed to python will be parsed. If it is a path, it will be passed to :meth:`tabensemb.config.UserConfig.from_file`. manual_config Update the configuration with a dict. For example: ``manual_config={"bayes_opt": True}``. project_root_subfolder The subfolder that the project will be locate in. The folder name will be ``tabensemb.setting["default_output_path"]`` ``/{project}/{project_root_subfolder}/{TIME}-{config}`` """ input_config = config is not None if isinstance(config, str) or not input_config: # The base config is loaded using the --base argument if is_notebook() and not input_config: raise Exception( "A config file must be assigned in notebook environment." ) elif is_notebook() or input_config: parse_res = {"base": config} else: # not notebook and config is None parse_res = UserConfig.parse() self.configfile = parse_res["base"] config = UserConfig(path=self.configfile) # Then, several args can be modified using other arguments like --lr, --weight_decay # only when a config file is not given so that configs depend on input arguments. if not is_notebook() and not input_config: # If the argument is not given in the command, the item will be None and will not be merged into # `config` using the `merge` method. config.merge(parse_res) if manual_config is not None: config.merge(manual_config) self.args = config else: self.configfile = "UserInputConfig" if manual_config is not None: warnings.warn(f"manual_config is ignored when config is an UserConfig.") self.args = config self.datamodule = DataModule(self.args) self.project = self.args["database"] if self.project is None else self.project self._create_dir(project_root_subfolder=project_root_subfolder) config.to_file(os.path.join(self.project_root, "args.py"))
@property def static_params(self) -> Dict: """ The "patience" and "epoch" parameters in the configuration. """ return { "patience": self.args["patience"], "epoch": self.args["epoch"], } @property def chosen_params(self): """ The "lr", "weight_decay", and "batch_size" parameters in the configuration. """ return { "lr": self.args["lr"], "weight_decay": self.args["weight_decay"], "batch_size": self.args["batch_size"], } @property def SPACE(self): """ Search spaces for "lr", "weight_decay", and "batch_size" defined in the configuration. """ SPACE = [] for var in self.args["SPACEs"].keys(): setting = cp(self.args["SPACEs"][var]) ty = setting["type"] setting.pop("type") if ty == "Real": SPACE.append(Real(name=var, **setting)) elif ty == "Categorical": SPACE.append(Categorical(name=var, **setting)) elif ty == "Integer": SPACE.append(Integer(name=var, **setting)) else: raise Exception("Invalid type of skopt space.") return SPACE @property def feature_data(self) -> pd.DataFrame: """ :meth:`tabensemb.data.datamodule.DataModule.feature_data` """ return self.datamodule.feature_data if hasattr(self, "datamodule") else None @property def unscaled_feature_data(self): """ :meth:`tabensemb.data.datamodule.DataModule.unscaled_feature_data` """ return ( self.datamodule.unscaled_feature_data if hasattr(self, "datamodule") else None ) @property def unscaled_label_data(self): """ :meth:`tabensemb.data.datamodule.DataModule.unscaled_label_data` """ return ( self.datamodule.unscaled_label_data if hasattr(self, "datamodule") else None ) @property def label_data(self) -> pd.DataFrame: """ :meth:`tabensemb.data.datamodule.DataModule.label_data` """ return self.datamodule.label_data if hasattr(self, "datamodule") else None @property def derived_data(self): """ :attr:`tabensemb.data.datamodule.DataModule.derived_data` """ return self.datamodule.derived_data if hasattr(self, "datamodule") else None @property def cont_feature_names(self): """ :attr:`tabensemb.data.datamodule.DataModule.cont_feature_names` """ return ( self.datamodule.cont_feature_names if hasattr(self, "datamodule") else None ) @property def cat_feature_names(self): """ :attr:`tabensemb.data.datamodule.DataModule.cat_feature_names` """ return ( self.datamodule.cat_feature_names if hasattr(self, "datamodule") else None ) @property def all_feature_names(self): """ :meth:`tabensemb.data.datamodule.DataModule.all_feature_names` """ return ( self.datamodule.all_feature_names if hasattr(self, "datamodule") else None ) @property def label_name(self): """ :attr:`tabensemb.data.datamodule.DataModule.label_name` """ return self.datamodule.label_name if hasattr(self, "datamodule") else None @property def train_indices(self): """ :attr:`tabensemb.data.datamodule.DataModule.train_indices` """ return self.datamodule.train_indices if hasattr(self, "datamodule") else None @property def val_indices(self): """ :attr:`tabensemb.data.datamodule.DataModule.val_indices` """ return self.datamodule.val_indices if hasattr(self, "datamodule") else None @property def test_indices(self): """ :attr:`tabensemb.data.datamodule.DataModule.test_indices` """ return self.datamodule.test_indices if hasattr(self, "datamodule") else None @property def df(self): """ :attr:`tabensemb.data.datamodule.DataModule.df` """ return self.datamodule.df if hasattr(self, "datamodule") else None @property def tensors(self): """ :attr:`tabensemb.data.datamodule.DataModule.tensors` """ return self.datamodule.tensors if hasattr(self, "datamodule") else None @property def cat_feature_mapping(self): """ :attr:`tabensemb.data.datamodule.DataModule.cat_feature_mapping` """ return ( self.datamodule.cat_feature_mapping if hasattr(self, "datamodule") else None ) @property def derived_stacked_features(self): """ :meth:`tabensemb.data.datamodule.DataModule.derived_stacked_features` """ return ( self.datamodule.derived_stacked_features if hasattr(self, "datamodule") else None ) @property def training(self): """ :attr:`tabensemb.data.datamodule.DataModule.training` """ return self.datamodule.training if hasattr(self, "datamodule") else None
[docs] def set_status(self, training: bool): """ A wrapper of :meth:`tabensemb.data.datamodule.DataModule.set_status` """ self.datamodule.set_status(training)
[docs] def load_data(self, *args, **kwargs): """ A wrapper of :meth:`tabensemb.data.datamodule.DataModule.load_data`. The ``save_path`` argument is set to :attr:`project_root`. """ if "save_path" in kwargs.keys(): kwargs.__delitem__("save_path") self.datamodule.load_data(save_path=self.project_root, *args, **kwargs)
[docs] def set_path(self, path: Union[os.PathLike, str], verbose=False): """ Set the work directory of the :class:`Trainer`. Parameters ---------- path The work directory. """ self.project_root = path if not os.path.exists(self.project_root): os.mkdir(self.project_root) if verbose: print(f"The project will be saved to {self.project_root}")
[docs] def _create_dir(self, verbose: bool = True, project_root_subfolder: str = None): """ Create the folder for the :class:`Trainer`. Parameters ---------- verbose Whether to print the path of the :class:`Trainer`. project_root_subfolder See :meth:`load_config`. """ default_path = tabensemb.setting["default_output_path"] if not os.path.exists(default_path): os.makedirs(default_path, exist_ok=True) if project_root_subfolder is not None: if not os.path.exists(os.path.join(default_path, project_root_subfolder)): os.makedirs( os.path.join(default_path, project_root_subfolder), exist_ok=True ) subfolder = ( self.project if project_root_subfolder is None else os.path.join(project_root_subfolder, self.project) ) t = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) folder_name = t + "-0" + "_" + os.path.split(self.configfile)[-1] if not os.path.exists(os.path.join(default_path, subfolder)): os.makedirs(os.path.join(default_path, subfolder), exist_ok=True) self.set_path( safe_mkdir(os.path.join(default_path, subfolder, folder_name)), verbose=verbose, )
[docs] def summarize_setting(self): """ Print the summary of the device, the configuration, and the global setting of the package (``tabensemb.setting``). """ print("Device:") print(pretty(self.summarize_device())) print("Configurations:") print(pretty(self.args)) print(f"Global settings:") print(pretty(tabensemb.setting))
[docs] def summarize_device(self): """ Print a summary of the environment. https://www.thepythoncode.com/article/get-hardware-system-information-python """ def get_size(bytes, suffix="B"): """ Scale bytes to its proper format e.g: 1253656 => '1.20MB' 1253656678 => '1.17GB' """ factor = 1024 for unit in ["", "K", "M", "G", "T", "P"]: if bytes < factor: return f"{bytes:.2f}{unit}{suffix}" bytes /= factor def get_processor_info(): if platform.system() == "Windows": return platform.processor() elif platform.system() == "Darwin": return ( subprocess.check_output( ["/usr/sbin/sysctl", "-n", "machdep.cpu.brand_string"] ) .strip() .decode("utf-8") ) elif platform.system() == "Linux": command = "cat /proc/cpuinfo" all_info = ( subprocess.check_output(command, shell=True).strip().decode("utf-8") ) for string in all_info.split("\n"): if "model name\t: " in string: return string.split("\t: ")[1] return "" uname = platform.uname() cpufreq = psutil.cpu_freq() svmem = psutil.virtual_memory() self.sys_summary = { "System": uname.system, "Node name": uname.node, "System release": uname.release, "System version": uname.version, "Machine architecture": uname.machine, "Processor architecture": uname.processor, "Processor model": get_processor_info(), "Physical cores": psutil.cpu_count(logical=False), "Total cores": psutil.cpu_count(logical=True), "Max core frequency": f"{cpufreq.max:.2f}Mhz", "Total memory": get_size(svmem.total), "Python version": platform.python_version(), "Python implementation": platform.python_implementation(), "Python compiler": platform.python_compiler(), "Cuda availability": torch.cuda.is_available(), "GPU devices": [ torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count()) ], } return self.sys_summary
[docs] def train( self, programs: List[str] = None, verbose: bool = True, *args, **kwargs, ): """ Train all model bases (:attr:`modelbases`). Parameters ---------- programs A selected subset of model bases. verbose Verbosity. *args Arguments passed to :meth:`tabensemb.model.AbstractModel.train` **kwargs Arguments passed to :meth:`tabensemb.model.AbstractModel.train` """ if programs is None: modelbases_to_train = self.modelbases else: modelbases_to_train = [self.get_modelbase(x) for x in programs] if len(modelbases_to_train) == 0: warnings.warn( f"No modelbase is trained. Please confirm that trainer.add_modelbases is called." ) for modelbase in modelbases_to_train: modelbase.train(*args, verbose=verbose, **kwargs)
[docs] def cross_validation( self, programs: List[str], n_random: int, verbose: bool, test_data_only: bool, split_type: str = "cv", load_from_previous: bool = False, **kwargs, ) -> Dict[str, Dict[str, Dict[str, Tuple[np.ndarray, np.ndarray]]]]: """ Repeat :meth:`load_data`, train model bases, and evaluate all models for multiple times. Parameters ---------- programs A selected subset of model bases. n_random The number of repeats. verbose Verbosity. test_data_only Whether to evaluate models only on testing datasets. split_type The type of data splitting. "random" and "cv" are supported. Ignored when ``load_from_previous`` is True. load_from_previous Load the state of a previous run (mostly because of an unexpected interruption). **kwargs Arguments for :meth:`tabensemb.model.AbstractModel.train` Notes ----- The results of a continuous run and a continued run (``load_from_previous=True``) are consistent. Returns ------- dict A dict in the following format: {keys: programs, values: {keys: model names, values: {keys: ["Training", "Testing", "Validation"], values: (Predicted values, true values)}} """ programs_predictions = {} for program in programs: programs_predictions[program] = {} if load_from_previous: if not os.path.exists( os.path.join(self.project_root, "cv") ) or not os.path.isfile( os.path.join(self.project_root, "cv", "cv_state.pkl") ): raise Exception(f"No previous state to load from.") with open( os.path.join(self.project_root, "cv", "cv_state.pkl"), "rb" ) as file: current_state = pickle.load(file) start_i = current_state["i_random"] self.load_state(current_state["trainer"]) programs_predictions = current_state["programs_predictions"] reloaded_once_predictions = current_state["once_predictions"] skip_program = reloaded_once_predictions is not None if start_i >= n_random: raise Exception( f"The loaded state is incompatible with the current setting." ) print(f"Previous cross validation state is loaded.") split_type = ( "cv" if self.datamodule.datasplitter.cv_generator is not None else "random" ) else: start_i = 0 skip_program = False reloaded_once_predictions = None if split_type == "cv" and not self.datamodule.datasplitter.support_cv: warnings.warn( f"{self.datamodule.datasplitter.__class__.__name__} does not support cross validation splitting. " f"Use its original regime instead." ) split_type = "random" self.datamodule.datasplitter.reset_cv( cv=n_random if split_type == "cv" else -1 ) if n_random > 0 and not os.path.exists( os.path.join(self.project_root, "cv") ): os.mkdir(os.path.join(self.project_root, "cv")) def func_save_state(state): with open( os.path.join(self.project_root, "cv", "cv_state.pkl"), "wb" ) as file: pickle.dump(state, file) for i in range(start_i, n_random): if verbose: print( f"----------------------------{i + 1}/{n_random} {split_type}----------------------------" ) trainer_state = cp(self) if not skip_program: current_state = { "trainer": trainer_state, "i_random": i, "programs_predictions": programs_predictions, "once_predictions": None, } func_save_state(current_state) with HiddenPrints(disable_std=not verbose): set_random_seed(tabensemb.setting["random_seed"] + i) self.load_data() once_predictions = {} if not skip_program else reloaded_once_predictions for program in programs: if skip_program: if program in once_predictions.keys(): print(f"Skipping finished model base {program}") continue else: skip_program = False modelbase = self.get_modelbase(program) modelbase.train(dump_trainer=True, verbose=verbose, **kwargs) predictions = modelbase._predict_all( verbose=verbose, test_data_only=test_data_only ) once_predictions[program] = predictions for model_name, value in predictions.items(): if model_name in programs_predictions[program].keys(): # current_predictions is a reference, so modifications are directly applied to it. current_predictions = programs_predictions[program][model_name] def append_once(key): current_predictions[key] = ( np.append( current_predictions[key][0], value[key][0], axis=0 ), np.append( current_predictions[key][1], value[key][1], axis=0 ), ) append_once("Testing") if not test_data_only: append_once("Training") append_once("Validation") else: programs_predictions[program][model_name] = value # It is expected that only model bases in self is changed. datamodule is not updated because the cross # validation status should remain before load_data() is called. trainer_state.modelbases = self.modelbases current_state = { "trainer": trainer_state, "i_random": i, "programs_predictions": programs_predictions, "once_predictions": once_predictions, } func_save_state(current_state) df_once = self._cal_leaderboard( once_predictions, test_data_only=test_data_only, save=False ) df_once.to_csv( os.path.join(self.project_root, "cv", f"leaderboard_cv_{i}.csv") ) trainer_state.modelbases = self.modelbases current_state = { "trainer": trainer_state, "i_random": i + 1, "programs_predictions": programs_predictions, "once_predictions": None, } func_save_state(current_state) if verbose: print( f"--------------------------End {i + 1}/{n_random} {split_type}--------------------------" ) return programs_predictions
[docs] def get_leaderboard( self, test_data_only: bool = False, dump_trainer: bool = True, cross_validation: int = 0, verbose: bool = True, load_from_previous: bool = False, split_type: str = "cv", **kwargs, ) -> pd.DataFrame: """ Run all model bases with/without cross validation for a leaderboard. Parameters ---------- test_data_only Whether to evaluate models only on testing datasets. dump_trainer Whether to save the :class:`Trainer`. cross_validation The number of cross-validation. See :meth:`cross_validation`. 0 to evaluate current trained models on the current dataset. verbose Verbosity. load_from_previous Load the state of a previous run (mostly because of an unexpected interruption). split_type The type of data splitting. "random" and "cv" are supported. Ignored when ``load_from_previous`` is True. **kwargs Arguments for :meth:`tabensemb.model.AbstractModel.train` Returns ------- pd.DataFrame The leaderboard. """ if len(self.modelbases) == 0: raise Exception( f"No modelbase available. Run trainer.add_modelbases() first." ) if cross_validation != 0: programs_predictions = self.cross_validation( programs=self.modelbases_names, n_random=cross_validation, verbose=verbose, test_data_only=test_data_only, load_from_previous=load_from_previous, split_type=split_type, **kwargs, ) else: programs_predictions = {} for modelbase in self.modelbases: print(f"{modelbase.program} metrics") programs_predictions[modelbase.program] = modelbase._predict_all( verbose=verbose, test_data_only=test_data_only ) df_leaderboard = self._cal_leaderboard( programs_predictions, test_data_only=test_data_only ) if dump_trainer: save_trainer(self) return df_leaderboard
[docs] def get_predict_leaderboard( self, df: pd.DataFrame, *args, **kwargs ) -> pd.DataFrame: """ Get prediction leaderboard of all models on an upcoming labeled dataset. Parameters ---------- df: A new tabular dataset that has the same structure as ``self.trainer.datamodule.X_test``. args Arguments of :meth:`tabensemb.model.AbstractModel.predict`. kwargs Arguments of :meth:`tabensemb.model.AbstractModel.predict`. Returns ------- pd.DataFrame """ if len(self.modelbases) == 0: raise Exception( f"No modelbase available. Run trainer.add_modelbases() first." ) kwargs["proba"] = True programs_predictions = {} for modelbase in self.modelbases: print(f"{modelbase.program} metrics") truth: np.ndarray = df[self.label_name].values program_predictions = {} for model_name in modelbase.get_model_names(): pred: np.ndarray = modelbase.predict( df, *args, model_name=model_name, **kwargs ) program_predictions[model_name] = {"Testing": (pred, truth)} programs_predictions[modelbase.program] = program_predictions df_leaderboard = self._cal_leaderboard( programs_predictions, test_data_only=True ) return df_leaderboard
[docs] def get_approx_cv_leaderboard( self, leaderboard: pd.DataFrame, save: bool = True ) -> Tuple[pd.DataFrame, pd.DataFrame]: """ Calculate approximated averages and standard errors based on :meth:`cross_validation` results in the folder ``self.project_root/cv``. Parameters ---------- leaderboard A reference leaderboard to be filled by avg and std, and to sort the returned DataFrame. save Save returned results locally with names "leaderboard_approx_mean.csv" and "leaderboard_approx_std.csv" Returns ------- pd.DataFrame Averages in the same format as the input ``leaderboard``. There is an additional column "Rank". pd.DataFrame Standard errors in the same format as the input ``leaderboard``. There is an additional column "Rank". Notes ----- The returned results are approximations of the precise leaderboard from ``get_leaderboard``. Some metrics like RMSE may be different because data-point-wise and cross-validation-wise averaging are different. """ leaderboard_mean = leaderboard.copy() leaderboard_std = leaderboard.copy() leaderboard_mean["Rank"] = np.nan leaderboard_std["Rank"] = np.nan if not os.path.exists(os.path.join(self.project_root, "cv")): warnings.warn( f"Cross validation folder {os.path.join(self.project_root, 'cv')} not found." ) leaderboard_mean["Rank"] = leaderboard.index.values + 1 leaderboard_std.loc[ :, np.setdiff1d(leaderboard_std.columns, ["Program", "Model"]) ] = 0 return leaderboard_mean, leaderboard_std df_cvs, programs, models, metrics = self._read_cv_leaderboards() modelwise_cv = self.get_modelwise_cv_metrics() for program in programs: program_models = models[program] for model in program_models: res_cv = modelwise_cv[program][model] # If numeric_only=True, only "Rank" is calculated somehow. mean = res_cv[metrics].mean(0, numeric_only=False) std = res_cv[metrics].std(0, numeric_only=False) where_model = leaderboard_std.loc[ (leaderboard_std["Program"] == program) & (leaderboard_std["Model"] == model) ].index[0] leaderboard_mean.loc[where_model, mean.index] = mean leaderboard_std.loc[where_model, std.index] = std if save: leaderboard_mean.to_csv( os.path.join(self.project_root, "leaderboard_approx_mean.csv"), index=False, ) leaderboard_std.to_csv( os.path.join(self.project_root, "leaderboard_approx_std.csv"), index=False, ) return leaderboard_mean, leaderboard_std
[docs] def get_modelwise_cv_metrics(self) -> Dict[str, Dict[str, pd.DataFrame]]: """ Assemble cross-validation results in the folder ``self.project_root/cv`` for metrics of each model in each model base. Returns ------- dict A dict of dicts where each of them contains metrics of cross-validation of one model. """ df_cvs, programs, models, metrics = self._read_cv_leaderboards() res_cvs = {} for program in programs: res_cvs[program] = {} program_models = models[program] for model in program_models: res_cvs[program][model] = pd.DataFrame( columns=df_cvs[0].columns, index=np.arange(len(df_cvs)) ) cv_metrics = np.zeros((len(df_cvs), len(metrics))) for cv_idx, df_cv in enumerate(df_cvs): where_model = (df_cv["Program"] == program) & ( df_cv["Model"] == model ) model_metrics = df_cv.loc[where_model][metrics].values.flatten() cv_metrics[cv_idx, :] = model_metrics res_cvs[program][model][metrics] = cv_metrics res_cvs[program][model]["Program"] = program res_cvs[program][model]["Model"] = model return res_cvs
[docs] def _read_cv_leaderboards( self, ) -> Tuple[List[pd.DataFrame], List[str], Dict[str, List[str]], List[str]]: """ Read cross-validation leaderboards in the folder ``self.project_root/cv``. Returns ------- list Cross validation leaderboards list Model base names dict Model names in each model base list Metric names. """ if not os.path.exists(os.path.join(self.project_root, "cv")): raise Exception( f"Cross validation folder {os.path.join(self.project_root, 'cv')} not found." ) cvs = sorted( [ i for i in os.listdir(os.path.join(self.project_root, "cv")) if "leaderboard_cv" in i ] ) df_cvs = [ pd.read_csv(os.path.join(self.project_root, "cv", cv), index_col=0) for cv in cvs ] programs = list(np.unique(df_cvs[0]["Program"].values)) models = { a: list(df_cvs[0].loc[np.where(df_cvs[0]["Program"] == a)[0], "Model"]) for a in programs } for df_cv in df_cvs: df_cv["Rank"] = df_cv.index.values + 1 metrics = list(np.setdiff1d(df_cvs[0].columns, ["Program", "Model"])) return df_cvs, programs, models, metrics
[docs] def _cal_leaderboard( self, programs_predictions: Dict[ str, Dict[str, Dict[str, Tuple[np.ndarray, np.ndarray]]] ], metrics: List[str] = None, test_data_only: bool = False, save: bool = True, ) -> pd.DataFrame: """ Calculate the leaderboard based on results from :meth:`cross_validation` or :meth:`tabensemb.model.AbstractModel._predict_all`. Parameters ---------- programs_predictions Results from :meth:`cross_validation`, or assembled results from :meth:`tabensemb.model.AbstractModel._predict_all`. See the source code of :meth:`get_leaderboard` for details. metrics The metrics that have been implemented in :func:`tabensemb.utils.utils.metric_sklearn`. test_data_only Whether to evaluate models only on testing datasets. save Whether to save the leaderboard locally and as an attribute in the :class:`Trainer`. Returns ------- pd.DataFrame The leaderboard dataframe. """ if metrics is None: metrics = { "regression": REGRESSION_METRICS, "binary": BINARY_METRICS, "multiclass": MULTICLASS_METRICS, }[self.datamodule.task] dfs = [] for modelbase_name in self.modelbases_names: df = self._metrics( programs_predictions[modelbase_name], metrics, test_data_only=test_data_only, ) df["Program"] = modelbase_name dfs.append(df) df_leaderboard = pd.concat(dfs, axis=0, ignore_index=True) sorted_by = metrics[0].upper() df_leaderboard.sort_values( f"Testing {sorted_by}" if not test_data_only else sorted_by, inplace=True ) df_leaderboard.reset_index(drop=True, inplace=True) df_leaderboard = df_leaderboard[["Program"] + list(df_leaderboard.columns)[:-1]] if save: df_leaderboard.to_csv(os.path.join(self.project_root, "leaderboard.csv")) self.leaderboard = df_leaderboard if os.path.exists(os.path.join(self.project_root, "cv")): self.get_approx_cv_leaderboard(df_leaderboard, save=True) return df_leaderboard
[docs] def _plot_action_subplots( self, meth_name: str, ls: List[str], ls_kwarg_name: Union[str, None], tqdm_active: bool = False, with_title: bool = False, titles: List[str] = None, fontsize: float = 12, xlabel: str = None, ylabel: str = None, twin_ylabel: str = None, get_figsize_kwargs: Dict = None, figure_kwargs: Dict = None, meth_fix_kwargs: Dict = None, ): """ Iterate over a list to plot subplots in a single figure. Parameters ---------- ls The list to be iterated. ls_kwarg_name The argument name of the components in ``ls`` when the component is passed to ``meth_name`` one by one. If is None, the components in ``ls`` should be dictionaries and will be unpacked and passed to the method ``meth_name``. tqdm_active Whether to use a tqdm progress bar. meth_name The method to plot on a subplot. It has an argument named ``ax`` which indicates the subplot. with_title Whether each subplot has a title, which is the components in ``ls`` if ``titles`` is None. titles The titles of each subplot if ``with_title`` is True. fontsize ``plt.rcParams["font.size"]`` xlabel The overall xlabel. ylabel The overall ylabel. twin_ylabel The overall ylabel of the twin x-axis. get_figsize_kwargs Arguments for :func:`tabensemb.utils.utils.get_figsize`. figure_kwargs Arguments for ``plt.figure()`` meth_fix_kwargs Fixed arguments of ``meth_name`` (except for ``ax`` and ``ls_kwarg_name``). Returns ------- matplotlib.figure.Figure The figure that has plotted subplots. """ from tqdm.auto import tqdm def _iterator(iterator, *args, **kwargs): for item in iterator: yield item figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) get_figsize_kwargs_ = update_defaults_by_kwargs( dict(max_col=4, width_per_item=3, height_per_item=3, max_width=14), get_figsize_kwargs, ) figsize, width, height = get_figsize(n=len(ls), **get_figsize_kwargs_) fig = plt.figure(figsize=figsize, **figure_kwargs_) plt.rcParams["font.size"] = fontsize tqdm = tqdm if tqdm_active else _iterator for idx, name in tqdm(enumerate(ls), total=len(ls)): ax = plt.subplot(height, width, idx + 1) if with_title: ax.set_title( name if titles is None else titles[idx], {"fontsize": fontsize} ) getattr(self, meth_name)( ax=ax, **({ls_kwarg_name: name} if ls_kwarg_name is not None else name), **meth_fix_kwargs, ) ax = fig.add_subplot(111, frameon=False) ax.tick_params( labelcolor="none", which="both", top=False, bottom=False, left=False, right=False, ) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) if twin_ylabel is not None: twin_ax = ax.twinx() twin_ax.set_frame_on(False) twin_ax.tick_params( labelcolor="none", which="both", top=False, bottom=False, left=False, right=False, ) twin_ax.set_ylabel(twin_ylabel) return fig
[docs] def _plot_action_get_df( self, imputed: bool, scaled: bool, cat_transformed: bool ) -> pd.DataFrame: """ A wrapper of :meth:`tabensemb.data.datamodule.DataModule.get_df`. """ return self.datamodule.get_df( imputed=imputed, scaled=scaled, cat_transformed=cat_transformed )
[docs] def plot_subplots( self, ls: List[str], ls_kwarg_name: str, meth_name: str, with_title: bool = False, titles: List[str] = None, fontsize: float = 12, xlabel: str = None, ylabel: str = None, twin_ylabel: str = None, get_figsize_kwargs: Dict = None, figure_kwargs: Dict = None, meth_fix_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, tqdm_active: bool = False, ): """ Iterate over a list to plot subplots in a single figure. Parameters ---------- ls The list to be iterated. ls_kwarg_name The argument name of the components in ``ls`` when the component is passed to ``meth_name``. meth_name The method to plot on a subplot. It has an argument named ``ax`` which indicates the subplot. with_title Whether each subplot has a title, which is the components in ``ls`` if ``titles`` is None. titles The titles of each subplot if ``with_title`` is True. fontsize ``plt.rcParams["font.size"]`` xlabel The overall xlabel. ylabel The overall ylabel. twin_ylabel The overall ylabel of the twin x-axis. get_figsize_kwargs Arguments for :func:`tabensemb.utils.utils.get_figsize`. figure_kwargs Arguments for ``plt.figure()`` meth_fix_kwargs Fixed arguments of ``meth_name`` (except for ``ax`` and ``ls_kwarg_name``). savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure, or return the ``matplotlib.figure.Figure`` instance. tqdm_active Whether to use a tqdm progress bar. Returns ------- matplotlib.figure.Figure The figure that has plotted subplots. """ fig = self._plot_action_subplots( ls=ls, ls_kwarg_name=ls_kwarg_name, meth_name=meth_name, meth_fix_kwargs=meth_fix_kwargs, fontsize=fontsize, with_title=with_title, titles=titles, xlabel=xlabel, ylabel=ylabel, twin_ylabel=twin_ylabel, get_figsize_kwargs=get_figsize_kwargs, figure_kwargs=figure_kwargs, tqdm_active=tqdm_active, ) return self._plot_action_after_plot( disable=False, ax_or_fig=fig, fig_name=os.path.join(self.project_root, f"subplots.pdf"), tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_truth_pred_all( self, program: str, fontsize=14, get_figsize_kwargs: Dict = None, figure_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, tqdm_active: bool = False, **kwargs, ) -> Union[None, matplotlib.figure.Figure]: """ Compare ground truth and prediction for all models in a model base. Parameters ---------- program The selected model base. fontsize ``plt.rcParams["font.size"]`` get_figsize_kwargs Arguments for :func:`tabensemb.utils.utils.get_figsize`. figure_kwargs Arguments for ``plt.figure()`` savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure, or return the ``matplotlib.figure.Figure`` instance. tqdm_active Whether to use a tqdm progress bar. kwargs Arguments for :meth:`plot_truth_pred` Returns ------- matplotlib.figure.Figure The figure if ``save_show_close`` is False. """ modelbase = self.get_modelbase(program) model_names = modelbase.get_model_names() savefig_kwargs_ = update_defaults_by_kwargs( dict(fname=os.path.join(self.project_root, program, f"truth_pred.pdf")), savefig_kwargs, ) return self.plot_subplots( ls=model_names, ls_kwarg_name="model_name", meth_name="plot_truth_pred", meth_fix_kwargs=dict(program=program, **kwargs), fontsize=fontsize, with_title=True, xlabel="Ground truth", ylabel="Prediction", get_figsize_kwargs=get_figsize_kwargs, figure_kwargs=figure_kwargs, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs_, tqdm_active=tqdm_active, )
[docs] def plot_truth_pred( self, program: str, model_name: str, kde_color: bool = False, train_val_test: str = "all", log_trans: bool = True, central_line: bool = True, upper_lim=9, ax=None, clr: Iterable = None, select_by_value_kwargs: Dict = None, figure_kwargs: Dict = None, scatter_kwargs: Dict = None, legend_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Compare ground truth and prediction for one model. Parameters ---------- program The selected model base. model_name The selected model in the model base kde_color Whether the scatters are colored by their KDE density. Ignored if ``train_val_test`` is "all". train_val_test Which subset to be plotted. Choose from "Training", "Validation", "Testing", and "all". log_trans Whether the label data is in log scale. central_line Whether to plot a 45-degree diagonal line. upper_lim The upper limit of x/y-axis. ax ``matplotlib.axes.Axes`` clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). select_by_value_kwargs Arguments for :meth:`tabensemb.data.datamodule.DataModule.select_by_value`. figure_kwargs Arguments for ``plt.figure()`` scatter_kwargs Arguments for ``plt.scatter()`` legend_kwargs Arguments for ``plt.legend()`` savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes """ clr = global_palette if clr is None else clr figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) legend_kwargs_ = update_defaults_by_kwargs( dict(loc="upper left", markerscale=1.5, handlelength=0.2, handleheight=0.9), legend_kwargs, ) if select_by_value_kwargs is not None: select_by_value_kwargs_ = update_defaults_by_kwargs( dict(), select_by_value_kwargs ) df = self._plot_action_get_df( imputed=True, scaled=False, cat_transformed=True ) indices = self.datamodule.select_by_value(**select_by_value_kwargs_) df = df.loc[indices, :].reset_index(drop=True) derived_data = self.datamodule.get_derived_data_slice( derived_data=self.derived_data, indices=indices ) train_val_test = "User" prediction = { "User": ( self.get_modelbase(program)._predict( df=df, model_name=model_name, derived_data=derived_data ), df[self.label_name].values, ) } else: prediction = self.get_modelbase(program)._predict_model( model_name=model_name, test_data_only=False if train_val_test != "Testing" else True, ) ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) def plot_one(name, color, marker): pred_y, y = prediction[name] r2 = metric_sklearn(y, pred_y, "r2") loss = metric_sklearn(y, pred_y, "mse") print(f"{name} MSE Loss: {loss:.4f}, R2: {r2:.4f}") final_y = 10**y if log_trans else y final_y_pred = 10**pred_y if log_trans else pred_y if kde_color: xy = np.hstack([final_y, final_y_pred]).T z = st.gaussian_kde(xy)(xy) scatter_kwargs_ = update_defaults_by_kwargs( scatter_kwargs, dict(c=z, color=None) ) else: scatter_kwargs_ = update_defaults_by_kwargs( dict(color=color), scatter_kwargs ) scatter_kwargs_ = update_defaults_by_kwargs( dict( s=20, marker=marker, label=f"{name} dataset ($R^2$={r2:.3f})", linewidth=0.4, edgecolors="k", ), scatter_kwargs_, ) ax.scatter(final_y, final_y_pred, **scatter_kwargs_) if train_val_test == "all": plot_one("Training", clr[0], "o") plot_one("Validation", clr[1], "o") plot_one("Testing", clr[2], "o") else: plot_one(train_val_test, clr[0], "o") if log_trans: ax.set_xscale("log") ax.set_yscale("log") if central_line: ax.plot( np.linspace(0, 10**upper_lim, 100), np.linspace(0, 10**upper_lim, 100), "--", c="grey", alpha=0.2, ) locmin = matplotlib.ticker.LogLocator( base=10.0, subs=[0.1 * x for x in range(10)], numticks=20 ) # ax.set_aspect("equal", "box") ax.xaxis.set_minor_locator(locmin) ax.yaxis.set_minor_locator(locmin) ax.xaxis.set_minor_formatter(matplotlib.ticker.NullFormatter()) ax.yaxis.set_minor_formatter(matplotlib.ticker.NullFormatter()) ax.set_xlim(1, 10**upper_lim) ax.set_ylim(1, 10**upper_lim) ax.set_box_aspect(1) else: # ax.set_aspect("equal", "box") lx, rx = ax.get_xlim() ly, ry = ax.get_ylim() l = np.min([lx, ly]) r = np.max([rx, ry]) if central_line: ax.plot( np.linspace(l, r, 100), np.linspace(l, r, 100), "--", c="grey", alpha=0.2, ) ax.set_xlim(left=l, right=r) ax.set_ylim(bottom=l, top=r) ax.set_box_aspect(1) ax.legend(**legend_kwargs_) return self._plot_action_after_plot( fig_name=os.path.join( self.project_root, program, f"{model_name.replace('/', '_')}_truth_pred.pdf", ), disable=given_ax, ax_or_fig=ax, xlabel="Ground truth", ylabel="Prediction", tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def cal_feature_importance( self, program: str, model_name: str, method: str = "permutation", **kwargs ) -> Tuple[np.ndarray, List[str]]: """ Calculate feature importance using a specified model. If the model base is a :class:`tabensemb.model.TorchModel`, ``captum`` or ``shap`` is called to make permutations. If the model base is only a :class:`tabensemb.model.AbstractModel`, the calculation will be much slower. Parameters ---------- program The selected model base. model_name The selected model in the model base. method The method to calculate importance. "permutation" or "shap". kwargs kwargs for :meth:`tabensemb.model.AbstractModel.cal_feature_importance` Returns ------- attr Values of feature importance. importance_names Corresponding feature names. If the model base is a ``TorchModel``, all features including derived unstacked features will be included. Otherwise, only :meth:`all_feature_names` will be considered. See Also -------- :meth:`tabensemb.model.AbstractModel.cal_feature_importance`, :meth:`tabensemb.model.TorchModel.cal_feature_importance` """ modelbase = self.get_modelbase(program) return modelbase.cal_feature_importance( model_name=model_name, method=method, **kwargs )
[docs] def cal_shap(self, program: str, model_name: str, **kwargs) -> np.ndarray: """ Calculate SHAP values using a specified model. If the model base is a :class:`tabensemb.model.TorchModel`, the ``shap.DeepExplainer`` is used. Otherwise, ``shap.KernelExplainer`` is called, which is much slower, and shap.kmeans is called to summarize the training data to 10 samples as the background data and 10 random samples in the testing set is explained, which will bias the results. Parameters ---------- program The selected model base. model_name The selected model in the model base. kwargs kwargs for :meth:`tabensemb.model.AbstractModel.cal_shap` Returns ------- attr The SHAP values. If the model base is a `TorchModel`, all features including derived unstacked features will be included. Otherwise, only :meth:`all_feature_names` will be considered. See Also -------- :meth:`tabensemb.model.AbstractModel.cal_shap`, :meth:`tabensemb.model.TorchModel.cal_shap` """ modelbase = self.get_modelbase(program) return modelbase.cal_shap(model_name=model_name, **kwargs)
[docs] def plot_feature_importance( self, program: str, model_name: str, method: str = "permutation", importance: np.ndarray = None, feature_names: List[str] = None, clr: Iterable = None, ax=None, figure_kwargs: Dict = None, bar_kwargs: Dict = None, legend_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, **kwargs, ) -> matplotlib.axes.Axes: """ Plot feature importance of a model using :meth:`cal_feature_importance`. Parameters ---------- program The selected model base. model_name The selected model in the model base. method The method to calculate feature importance. "permutation" or "shap". importance Passing feature importance values directly instead of calling :meth:`tabensemb.model.AbstractModel.cal_feature_importance` internally in this method. feature_names Names of features assigned to each `importance` value. clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). ax ``matplotlib.axes.Axes`` figure_kwargs Arguments for ``plt.figure`` bar_kwargs Arguments for ``seaborn.barplot``. legend_kwargs Arguments for ``plt.legend`` savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. kwargs Other arguments of :meth:`tabensemb.model.AbstractModel.cal_feature_importance` Returns ------- matplotlib.axes.Axes """ attr, names = ( self.cal_feature_importance( program=program, model_name=model_name, method=method, **kwargs ) if (importance is None and feature_names is None) else (importance, feature_names) ) bar_kwargs_ = update_defaults_by_kwargs( dict(linewidth=1, edgecolor="k", orient="h", saturation=1), bar_kwargs ) figure_kwargs_ = update_defaults_by_kwargs(dict(figsize=(7, 4)), figure_kwargs) where_effective = np.abs(attr) > 1e-5 effective_names = np.array(names)[where_effective] not_effective = list(np.setdiff1d(names, effective_names)) if len(not_effective) > 0: print(f"Feature importance less than 1e-5: {not_effective}") attr = attr[where_effective] ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) df = pd.DataFrame(columns=["feature", "attr", "clr"]) df["feature"] = effective_names df["attr"] = np.abs(attr) / np.sum(np.abs(attr)) df.sort_values(by="attr", inplace=True, ascending=False) df.reset_index(drop=True, inplace=True) ax.set_axisbelow(True) x = df["feature"].values y = df["attr"].values clr = global_palette if clr is None else clr palette = self._plot_action_generate_feature_types_palette(clr=clr, features=x) # ax.set_facecolor((0.97,0.97,0.97)) # plt.grid(axis='x') plt.grid(axis="x", linewidth=0.2) # plt.barh(x,y, color= [clr_map[name] for name in x]) sns.barplot(x=y, y=x, palette=palette, ax=ax, **bar_kwargs_) ax.set_ylabel(None) ax.set_xlabel(None) # ax.set_xlim([0, 1]) legend = self._plot_action_generate_feature_types_legends( clr=clr, ax=ax, legend_kwargs=legend_kwargs ) legend.get_frame().set_alpha(None) legend.get_frame().set_facecolor([1, 1, 1, 0.4]) if method == "permutation": xlabel = "Permutation feature importance" elif method == "shap": xlabel = "SHAP feature importance" else: xlabel = "Feature importance" return self._plot_action_after_plot( fig_name=os.path.join( self.project_root, f"feature_importance_{program}_{model_name}_{method}.png", ), disable=given_ax, ax_or_fig=ax, xlabel=xlabel, ylabel=None, tight_layout=True, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_partial_dependence_all( self, program: str, model_name: str, fontsize=12, figure_kwargs: Dict = None, get_figsize_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, tqdm_active: bool = False, **kwargs, ) -> Union[None, matplotlib.figure.Figure]: """ Calculate and plot partial dependence plots with bootstrapping. Parameters ---------- program The selected model base. model_name The selected model in the model base. fontsize ``plt.rcParams["font.size"]`` figure_kwargs Arguments for ``plt.figure``. get_figsize_kwargs Arguments for :func:`tabensemb.utils.utils.get_figsize`. savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure, or return the ``matplotlib.figure.Figure`` instance. tqdm_active Whether to use a tqdm progress bar. kwargs Arguments for :meth:`plot_partial_dependence`. Returns ------- matplotlib.figure.Figure The figure if ``save_show_close`` is False. """ savefig_kwargs_ = update_defaults_by_kwargs( dict( fname=os.path.join( self.project_root, f"partial_dependence_{program}_{model_name}.pdf" ) ), savefig_kwargs, ) return self.plot_subplots( ls=self.all_feature_names, ls_kwarg_name="feature", meth_name="plot_partial_dependence", meth_fix_kwargs=dict(program=program, model_name=model_name, **kwargs), fontsize=fontsize, with_title=True, xlabel=r"Value of predictors ($10\%$-$90\%$ percentile)", ylabel="Predicted target", get_figsize_kwargs=get_figsize_kwargs, figure_kwargs=figure_kwargs, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs_, tqdm_active=tqdm_active, )
[docs] def plot_partial_dependence( self, program: str, model_name: str, feature: str, ax=None, refit: bool = True, log_trans: bool = True, lower_lim: float = 2, upper_lim: float = 7, n_bootstrap: int = 1, grid_size: int = 30, CI: float = 0.95, verbose: bool = True, figure_kwargs: Dict = None, plot_kwargs: Dict = None, fill_between_kwargs: Dict = None, bar_kwargs: Dict = None, hist_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Calculate and plot a partial dependence plot with bootstrapping for a feature. Parameters ---------- program The selected model base. model_name The selected model in the model base. feature The selected feature to calculate partial dependence. ax ``matplotlib.axes.Axes`` refit Whether to refit models on bootstrapped datasets. See :meth:`_bootstrap_fit`. log_trans Whether the label data is in log scale. lower_lim Lower limit of all pdp plots. upper_lim Upper limit of all pdp plot. n_bootstrap The number of bootstrap evaluations. It should be greater than 0. grid_size The number of steps of all pdp plot. CI The confidence interval of pdp results calculated across multiple bootstrap runs. verbose Verbosity figure_kwargs Arguments for ``plt.figure``. plot_kwargs Arguments for ``ax.plot``. fill_between_kwargs Arguments for ``ax.fill_between``. bar_kwargs Arguments for ``ax.bar`` (used for frequencies of categorical features). hist_kwargs Arguments for ``ax.hist`` (used for histograms of continuous features). savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes """ ( x_values_list, mean_pdp_list, ci_left_list, ci_right_list, ) = self.cal_partial_dependence( feature_subset=[feature], program=program, model_name=model_name, df=self.datamodule.X_train, derived_data=self.datamodule.D_train, n_bootstrap=n_bootstrap, refit=refit, grid_size=grid_size, percentile=90, CI=CI, average=True, ) x_values = x_values_list[0] mean_pdp = mean_pdp_list[0] ci_left = ci_left_list[0] ci_right = ci_right_list[0] figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) plot_kwargs_ = update_defaults_by_kwargs( dict(color="k", linewidth=0.7), plot_kwargs ) fill_between_kwargs_ = update_defaults_by_kwargs( dict(alpha=0.4, color="k", edgecolor=None), fill_between_kwargs ) ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) def transform(value): if log_trans: return 10**value else: return value if feature not in self.cat_feature_names: ax.plot(x_values, transform(mean_pdp), **plot_kwargs_) ax.fill_between( x_values, transform(ci_left), transform(ci_right), **fill_between_kwargs_, ) else: yerr = ( np.abs( np.vstack([transform(ci_left), transform(ci_right)]) - transform(mean_pdp) ) if not np.isnan(ci_left).any() else None ) ax.errorbar(x_values, transform(mean_pdp), yerr=yerr, **plot_kwargs_) # ax.set_xlim([0, 1]) if log_trans: ax.set_yscale("log") ax.set_ylim([10**lower_lim, 10**upper_lim]) locmin = matplotlib.ticker.LogLocator( base=10.0, subs=[0.1 * x for x in range(10)], numticks=20 ) # ax.xaxis.set_minor_locator(locmin) ax.yaxis.set_minor_locator(locmin) # ax.xaxis.set_minor_formatter(matplotlib.ticker.NullFormatter()) ax.yaxis.set_minor_formatter(matplotlib.ticker.NullFormatter()) if np.min(x_values) < np.max(x_values): ax2 = ax.twinx() hist_kwargs_ = update_defaults_by_kwargs( dict(bins=x_values, alpha=0.2, color="k"), hist_kwargs ) bar_kwargs_ = update_defaults_by_kwargs( dict(alpha=0.2, color="k"), bar_kwargs ) self.plot_hist( feature=feature, ax=ax2, imputed=False, x_values=x_values, hist_kwargs=hist_kwargs_, bar_kwargs=bar_kwargs_, ) ax2.set_yticks([]) else: ax2 = ax.twinx() ax2.text(0.5, 0.5, "Invalid interval", ha="center", va="center") ax2.set_xlim([0, 1]) ax2.set_ylim([0, 1]) ax2.set_yticks([]) return self._plot_action_after_plot( fig_name=os.path.join( self.project_root, f"partial_dependence_{program}_{model_name}_{feature}.pdf", ), disable=given_ax, ax_or_fig=ax, xlabel=feature + r" ($10\%$-$90\%$ percentile)", ylabel="Predicted target", tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def cal_partial_dependence( self, feature_subset: List[str] = None, **kwargs ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[np.ndarray]]: """ Calculate partial dependency. See the source code of :meth:`plot_partial_dependence` for its usage. Parameters ---------- feature_subset A subset of :meth:`all_feature_names`. kwargs Arguments for :meth:`_bootstrap_fit`. Returns ------- list x values for each feature list pdp values for each feature list lower confidence limits for each feature list upper confidence limits for each feature """ x_values_list = [] mean_pdp_list = [] ci_left_list = [] ci_right_list = [] for feature_idx, feature_name in enumerate( self.all_feature_names if feature_subset is None else feature_subset ): print("Calculate PDP: ", feature_name) x_value, model_predictions, ci_left, ci_right = self._bootstrap_fit( focus_feature=feature_name, **kwargs ) x_values_list.append(x_value) mean_pdp_list.append(model_predictions) ci_left_list.append(ci_left) ci_right_list.append(ci_right) return x_values_list, mean_pdp_list, ci_left_list, ci_right_list
[docs] def plot_partial_dependence_2way_all( self, program: str, model_name: str, x_feature: str, y_features: List[str] = None, fontsize=12, figure_kwargs: Dict = None, get_figsize_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, tqdm_active: bool = False, **kwargs, ) -> Union[None, matplotlib.figure.Figure]: """ Calculate and plot 2-way partial dependence plots with bootstrapping. One continuous feature is fixed for x-axis. The rest of the continuous features are on y-axis, respectively. Parameters ---------- program The selected model base. model_name The selected model in the model base. x_feature The continuous feature fixed for x-axis. y_features Continuous features on y-axis respectively. If None, all other continuous features are used. fontsize ``plt.rcParams["font.size"]`` figure_kwargs Arguments for ``plt.figure``. get_figsize_kwargs Arguments for :func:`tabensemb.utils.utils.get_figsize`. savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure, or return the ``matplotlib.figure.Figure`` instance. tqdm_active Whether to use a tqdm progress bar. kwargs Arguments for :meth:`plot_partial_dependence_2way`. Returns ------- matplotlib.figure.Figure The figure if ``save_show_close`` is False. """ y_features = ( y_features if y_features is not None else [x for x in self.cont_feature_names if x != x_feature] ) savefig_kwargs_ = update_defaults_by_kwargs( dict( fname=os.path.join( self.project_root, f"partial_dependence_2way_{program}_{model_name}_{x_feature}.pdf", ) ), savefig_kwargs, ) return self.plot_subplots( ls=y_features, ls_kwarg_name="y_feature", meth_name="plot_partial_dependence_2way", meth_fix_kwargs=dict( x_feature=x_feature, program=program, model_name=model_name, **kwargs ), fontsize=fontsize, with_title=True, xlabel=r"Value of the fixed predictors", ylabel="Value of other predictors", get_figsize_kwargs=get_figsize_kwargs, figure_kwargs=figure_kwargs, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs_, tqdm_active=tqdm_active, )
[docs] def plot_partial_dependence_2way( self, x_feature: str, y_feature: str, program: str, model_name: str, df: pd.DataFrame, derived_data: Dict[str, np.ndarray], ax: matplotlib.axes.Axes = None, projection: str = "3d", grid_size: int = 10, percentile: Union[int, float] = 100, figure_kwargs: Dict = None, imshow_kwargs: Dict = None, surf_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, **kwargs, ): """ Calculate and plot a 2-way partial dependence plot with bootstrapping for a pair of features. Parameters ---------- x_feature A continuous feature. y_feature A continuous feature. program The selected model base. model_name The selected model in the model base. ax ``matplotlib.axes.Axes`` projection None or "3d". Will use ``matplotlib.pyplot.imshow`` for None and ``matplotlib.pyplot.plot_surface`` for "3d". grid_size The number of sequential values. percentile The percentile of the feature used to generate sequential values. df The tabular dataset. derived_data The derived data calculated using :meth:`derive_unstacked`. kwargs Other arguments for :meth:`cal_partial_dependence_2way`. figure_kwargs Arguments for ``plt.savefig`` savefig_kwargs Arguments for ``plt.savefig`` imshow_kwargs Arguments for ``plt.imshow`` surf_kwargs Arguments for ``plt.plot_surface`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. kwargs Arguments for :meth:`cal_partial_dependence_2way`. Returns ------- matplotlib.axes.Axes """ from matplotlib import cm figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) imshow_kwargs_ = update_defaults_by_kwargs(dict(), imshow_kwargs) surf_kwargs_ = update_defaults_by_kwargs( dict(cmap=cm.coolwarm, linewidth=0, antialiased=False), surf_kwargs ) given_ax = ax is not None if not given_ax: fig = plt.figure(**figure_kwargs_) ax = plt.subplot(111, projection=projection) plt.sca(ax) ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) X, Y, Z = self.cal_partial_dependence_2way( x_feature=x_feature, y_feature=y_feature, grid_size=grid_size, percentile=percentile, program=program, model_name=model_name, derived_data=derived_data, df=df, **kwargs, ) if projection != "3d": ax.imshow(np.rot90(Z), **imshow_kwargs_) ax.set_xticks(np.arange(len(X))) ax.set_yticks(np.arange(len(Y))) ax.set_xticklabels([round(x, 2) for x in X[:, 0]]) ax.set_yticklabels([round(x, 2) for x in Y[0, ::-1]]) else: ax.xaxis.pane.fill = False ax.yaxis.pane.fill = False ax.zaxis.pane.fill = False ax.xaxis.pane.set_edgecolor("w") ax.yaxis.pane.set_edgecolor("w") ax.zaxis.pane.set_edgecolor("w") surf = ax.plot_surface(X, Y, Z, **surf_kwargs_) return self._plot_action_after_plot( fig_name=os.path.join( self.project_root, f"partial_dependence_2way_{program}_{model_name}_{x_feature}_{y_feature}.pdf", ), disable=given_ax, ax_or_fig=ax, xlabel=x_feature + r" (${}\%$-${}\%$ percentile)".format(100 - percentile, percentile), ylabel=y_feature + r" (${}\%$-${}\%$ percentile)".format(100 - percentile, percentile), tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def cal_partial_dependence_2way( self, x_feature: str, y_feature: str, grid_size: int = 10, percentile: Union[int, float] = 100, x_min: Union[int, float] = None, x_max: Union[int, float] = None, y_min: Union[int, float] = None, y_max: Union[int, float] = None, df: pd.DataFrame = None, **kwargs, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Calculate 2-way partial dependency. See the source code of :meth:`plot_partial_dependence_2way` for its usage. Parameters ---------- x_feature A continuous feature. y_feature A continuous feature. grid_size The number of sequential values. percentile The percentile of the feature used to generate sequential values. x_min The lower limit of the generated sequential values of the first feature. It will override the left percentile. x_max The upper limit of the generated sequential values of the first feature. It will override the right percentile. y_min The lower limit of the generated sequential values of the second feature. It will override the left percentile. y_max The upper limit of the generated sequential values of the second feature. It will override the right percentile. df The tabular dataset. kwargs Other arguments for :meth:`_bootstrap_fit`. The above `grid_size`, `percentile`, `y_min`, `y_max` are passed to it for the second feature. Returns ------- list The grid of the first feature list The grid of the second feature list pdp values of each first-feature value and each second-feature value in grids. """ y_values_list = [] mean_pdp_list = [] df = df if df is not None else self.df df = df.copy() x_values_list = list( self._generate_grid( feature=x_feature, grid_size=grid_size, percentile=percentile, x_min=x_min, x_max=x_max, df=df, ) ) for x_val in x_values_list: df[x_feature] = x_val x_value, model_predictions, _, _ = self._bootstrap_fit( focus_feature=y_feature, df=df, grid_size=grid_size, percentile=percentile, x_min=y_min, x_max=y_max, **kwargs, ) y_values_list.append(x_value) mean_pdp_list.append(model_predictions) return ( np.repeat(np.array(x_values_list).reshape(1, -1), grid_size, axis=0).T, np.array(y_values_list), np.array(mean_pdp_list), )
[docs] def plot_partial_err_all( self, program: str, model_name: str, fontsize=12, figure_kwargs: Dict = None, get_figsize_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, tqdm_active: bool = False, **kwargs, ) -> Union[None, matplotlib.figure.Figure]: """ Calculate prediction absolute errors on the testing dataset, and plot histograms of high-error samples and low-error samples respectively. Parameters ---------- program The selected model base. model_name The selected model in the model base. fontsize ``plt.rcParams["font.size"]`` figure_kwargs Arguments for ``plt.figure``. get_figsize_kwargs Arguments for :func:`tabensemb.utils.utils.get_figsize`. savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure, or return the ``matplotlib.figure.Figure`` instance. tqdm_active Whether to use a tqdm progress bar. kwargs Arguments for :meth:`plot_partial_err` Returns ------- matplotlib.figure.Figure The figure if ``save_show_close`` is False. """ savefig_kwargs_ = update_defaults_by_kwargs( dict( fname=os.path.join( self.project_root, f"partial_err_{program}_{model_name}.pdf" ) ), savefig_kwargs, ) return self.plot_subplots( ls=self.all_feature_names, ls_kwarg_name="feature", meth_name="plot_partial_err", meth_fix_kwargs=dict(program=program, model_name=model_name, **kwargs), fontsize=fontsize, with_title=True, xlabel="Value of predictors", ylabel="Prediction absolute error", get_figsize_kwargs=get_figsize_kwargs, figure_kwargs=figure_kwargs, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs_, tqdm_active=tqdm_active, )
[docs] def plot_partial_err( self, program: str, model_name: str, feature, thres=0.8, ax=None, clr: Iterable = None, figure_kwargs: Dict = None, scatter_kwargs: Dict = None, hist_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Calculate prediction absolute errors on the testing dataset, and plot histograms of high-error samples and low-error samples respectively for a single feature. Parameters ---------- program The selected model base. model_name The selected model in the model base. feature The selected feature. thres The absolute error threshold to identify high-error samples and low-error samples. ax ``matplotlib.axes.Axes`` clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). figure_kwargs Arguments for ``plt.figure``. scatter_kwargs Arguments for ``ax.scatter()`` hist_kwargs Arguments for ``ax.hist()`` savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes """ clr = global_palette if clr is None else clr figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) scatter_kwargs_ = update_defaults_by_kwargs(dict(s=1), scatter_kwargs) hist_kwargs_ = update_defaults_by_kwargs( dict(density=True, alpha=0.2, rwidth=0.8), hist_kwargs ) feature_data = self.df.loc[ np.array(self.test_indices), self.all_feature_names ].reset_index(drop=True) truth = self.label_data.loc[self.test_indices, :].values.flatten() modelbase = self.get_modelbase(program) pred = modelbase.predict( df=self.datamodule.X_test, derived_data=self.datamodule.D_test, model_name=model_name, ).flatten() err = np.abs(truth - pred) high_err_data = feature_data.loc[np.where(err > thres)[0], :] high_err = err[np.where(err > thres)[0]] low_err_data = feature_data.loc[np.where(err <= thres)[0], :] low_err = err[np.where(err <= thres)[0]] ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) ax.scatter( high_err_data[feature].values, high_err, color=clr[0], marker="s", **scatter_kwargs_, ) ax.scatter( low_err_data[feature].values, low_err, color=clr[1], marker="^", **scatter_kwargs_, ) ax.set_ylim([0, np.max(err) * 1.1]) ax2 = ax.twinx() ax2.hist( [ high_err_data[feature].values, low_err_data[feature].values, ], bins=np.linspace( np.min(feature_data[feature].values), np.max(feature_data[feature].values), 20, ), color=clr[:2], **hist_kwargs_, ) if feature in self.cat_feature_names: ticks = np.sort(np.unique(feature_data[feature].values)).astype(int) tick_label = [self.cat_feature_mapping[feature][x] for x in ticks] ax.set_xticks(ticks) ax.set_xticklabels(tick_label) ax.set_xlim([-0.5, len(ticks) - 0.5]) ax2.set_xlim([-0.5, len(ticks) - 0.5]) # sns.rugplot(data=chosen_data, height=0.05, ax=ax2, color='k') # ax2.set_ylim([0,1]) # ax2.set_xlim([np.min(x_values_list[idx]), np.max(x_values_list[idx])]) ax2.set_yticks([]) return self._plot_action_after_plot( fig_name=os.path.join( self.project_root, f"partial_err_{program}_{model_name}_{feature}.pdf" ), disable=given_ax, ax_or_fig=ax, xlabel=feature, ylabel="Prediction absolute error", tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_err_hist( self, program: str, model_name: str, category: str = None, metric: str = None, ax=None, legend=True, clr: Iterable = None, figure_kwargs: Dict = None, hist_kwargs: Dict = None, select_by_value_kwargs: Dict = None, legend_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Plot histograms of prediction errors. Parameters ---------- program The selected model base. model_name The selected model in the model base. category The category to classify histograms and stack them with different colors. metric The metric to be calculated. It should be supported by :func:`tabenseb.utils.utils.auto_metric_sklearn`. ax ``matplotlib.axes.Axes`` legend Show legends if ``category`` is not None. clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). figure_kwargs Arguments for ``plt.figure``. hist_kwargs Arguments for ``ax.hist`` (used for histograms of continuous features). select_by_value_kwargs Arguments for :meth:`tabensemb.data.datamodule.DataModule.select_by_value`. legend_kwargs Arguments for ``plt.legend`` if ``legend`` is True and ``category`` is not None. savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes """ clr = global_palette if clr is None else clr figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) select_by_value_kwargs_ = update_defaults_by_kwargs( dict(), select_by_value_kwargs ) hist_kwargs_ = update_defaults_by_kwargs( dict(density=True, color=clr[0], rwidth=0.95, bins=20), hist_kwargs ) legend_kwargs_ = update_defaults_by_kwargs(dict(), legend_kwargs) ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) indices = self.datamodule.select_by_value(**select_by_value_kwargs_) df = self._plot_action_get_df( imputed=True, scaled=False, cat_transformed=False ).loc[indices, :] derived_data = self.datamodule.get_derived_data_slice( self.datamodule.derived_data, indices=indices ) pred = self.get_modelbase(program=program).predict( df=df, model_name=model_name, derived_data=derived_data, proba=True ) truth = df[self.label_name].values metric = ( metric if metric is not None else ("rmse" if self.datamodule.task == "regression" else "log_loss") ) metrics = np.array( [ auto_metric_sklearn( t, p, metric=metric, task=self.datamodule.task, ) for t, p in zip(truth, pred) ] ) if category is not None: category_data, unique_values = self._plot_action_category_unique_values( df=df, category=category ) metrics = [ metrics[np.where(category_data == val)[0]] for val in unique_values ] hist_kwargs_.update( dict( color=clr[: len(unique_values)], label=unique_values.astype(str), stacked=True, ) ) ax.hist(metrics, **hist_kwargs_) if legend: ax.legend(**legend_kwargs_) return self._plot_action_after_plot( fig_name=os.path.join(self.project_root, f"err_hist.pdf"), disable=given_ax, ax_or_fig=ax, xlabel=metric.upper(), ylabel="Density" if hist_kwargs_["density"] else "Frequency", tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_corr( self, fontsize: Any = 10, imputed=False, features: List[str] = None, method: Union[str, Callable] = "pearson", include_label: bool = True, ax=None, figure_kwargs: Dict = None, imshow_kwargs: Dict = None, select_by_value_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Plot correlation coefficients among features and the target. Parameters ---------- fontsize The ``fontsize`` argument for matplotlib. imputed Whether the imputed dataset should be considered. If False, some NaN coefficients may exist for features with missing values. features A subset of continuous features to calculate correlations on. method The argument of ``pd.DataFrame.corr``. "pearson", "kendall", "spearman" or Callable. include_label If True, the target is also considered. ax ``matplotlib.axes.Axes`` figure_kwargs Arguments for ``plt.figure``. imshow_kwargs Arguments for ``plt.imshow``. select_by_value_kwargs Arguments for :meth:`tabensemb.data.datamodule.DataModule.select_by_value`. savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes """ figure_kwargs_ = update_defaults_by_kwargs( dict(figsize=(10, 10)), figure_kwargs ) imshow_kwargs_ = update_defaults_by_kwargs(dict(cmap="bwr"), imshow_kwargs) cont_feature_names = ( self.cont_feature_names if features is None else features ) + (self.label_name if include_label else []) # sns.reset_defaults() ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) plt.box(on=True) corr = ( self.datamodule.cal_corr( method=method, imputed=imputed, features_only=False, select_by_value_kwargs=select_by_value_kwargs, ) .loc[cont_feature_names, cont_feature_names] .values ) im = ax.imshow(corr, **imshow_kwargs_) ax.set_xticks(np.arange(len(cont_feature_names))) ax.set_yticks(np.arange(len(cont_feature_names))) ax.set_xticklabels(cont_feature_names, fontsize=fontsize) ax.set_yticklabels(cont_feature_names, fontsize=fontsize) plt.setp( ax.get_xticklabels(), rotation=90, va="center", ha="right", rotation_mode="anchor", ) norm_corr = corr - (np.nanmax(corr) + np.nanmin(corr)) / 2 norm_corr /= np.nanmax(norm_corr) for i in range(len(cont_feature_names)): for j in range(len(cont_feature_names)): text = ax.text( j, i, round(corr[i, j], 2), ha="center", va="center", color="w" if np.abs(norm_corr[i, j]) > 0.3 else "k", fontsize=fontsize, ) return self._plot_action_after_plot( fig_name=os.path.join( self.project_root, f"corr{'_imputed' if imputed else ''}.pdf" ), disable=given_ax, ax_or_fig=ax, tight_layout=True, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_corr_with_label( self, imputed=False, features: List[str] = None, order: str = "alphabetic", method: str = "pearson", clr=None, ax=None, figure_kwargs: Dict = None, barplot_kwargs: Dict = None, select_by_value_kwargs: Dict = None, legend_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Plot correlation coefficients between the target and each feature. Parameters ---------- imputed Whether the imputed dataset should be considered. If False, some NaN coefficients may exist for features with missing values. features A subset of continuous features to calculate correlations on. order The order of features. "alphabetic", "ascending", or "descending". method The argument of ``pd.DataFrame.corr``. "pearson", "kendall", "spearman" or Callable. clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). ax ``matplotlib.axes.Axes`` figure_kwargs Arguments for ``plt.figure``. imshow_kwargs Arguments for ``plt.imshow``. select_by_value_kwargs Arguments for :meth:`tabensemb.data.datamodule.DataModule.select_by_value`. legend_kwargs Arguments for ``plt.legend`` savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes """ figure_kwargs_ = update_defaults_by_kwargs(dict(figsize=(8, 5)), figure_kwargs) barplot_kwargs_ = update_defaults_by_kwargs( dict( orient="h", linewidth=1, edgecolor="k", saturation=1, ), barplot_kwargs, ) legend_kwargs_ = update_defaults_by_kwargs(dict(), legend_kwargs) is_horizontal = barplot_kwargs_["orient"] == "h" cont_feature_names = self.cont_feature_names if features is None else features # sns.reset_defaults() ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) plt.box(on=True) corr = ( self.datamodule.cal_corr( method=method, imputed=imputed, features_only=False, select_by_value_kwargs=select_by_value_kwargs, ) .loc[cont_feature_names, self.label_name] .values.flatten() ) df = pd.DataFrame(data={"feature": cont_feature_names, "correlation": corr}) df.sort_values( by="feature" if order == "alphabetic" else "correlation", ascending=order != "descending", inplace=True, ) clr = global_palette if clr is None else clr palette = self._plot_action_generate_feature_types_palette( clr=clr, features=df["feature"] ) sns.barplot( data=df, x="correlation" if is_horizontal else "feature", y="feature" if is_horizontal else "correlation", ax=ax, palette=palette, **barplot_kwargs_, ) ax.set_xlabel(None) ax.set_ylabel(None) legend = self._plot_action_generate_feature_types_legends( clr=clr, ax=ax, legend_kwargs=legend_kwargs_ ) return self._plot_action_after_plot( fig_name=os.path.join( self.project_root, f"corr_with_label{'_imputed' if imputed else ''}.pdf" ), disable=given_ax, ax_or_fig=ax, xlabel=f"Correlation with {self.label_name[0]}" if is_horizontal else None, ylabel=( f"Correlation with {self.label_name[0]}" if not is_horizontal else None ), tight_layout=True, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_pairplot( self, imputed: bool = False, features: List[str] = None, include_label=True, pairplot_kwargs: Dict = None, select_by_value_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, ) -> Union[None, sns.axisgrid.PairGrid]: """ Plot ``seaborn.pairplot`` among features and label. Kernel Density Estimation plots are on the diagonal. Parameters ---------- imputed Whether the imputed dataset should be considered. features A subset of continuous features to plot pairplots for. include_label If True, the target is also considered. pairplot_kwargs Arguments for ``seaborn.pairplot``. select_by_value_kwargs Arguments for :meth:`tabensemb.data.datamodule.DataModule.select_by_value`. savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure, or return the ``seaborn.axisgrid.PairGrid`` instance. """ pairplot_kwargs_ = update_defaults_by_kwargs( dict(corner=True, diag_kind="kde"), pairplot_kwargs ) select_by_value_kwargs_ = update_defaults_by_kwargs( dict(), select_by_value_kwargs ) cont_feature_names = ( self.cont_feature_names if features is None else features ) + (self.label_name if include_label else []) df_all = self._plot_action_get_df( imputed=imputed, scaled=False, cat_transformed=False )[cont_feature_names] indices = self.datamodule.select_by_value(**select_by_value_kwargs_) grid = sns.pairplot(df_all.loc[indices, :], **pairplot_kwargs_) return self._plot_action_after_plot( fig_name=os.path.join(self.project_root, "pair.jpg"), disable=False, ax_or_fig=grid, tight_layout=True, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_feature_box( self, imputed: bool = False, features: List[str] = None, ax=None, clr: Iterable = None, figure_kwargs: Dict = None, boxplot_kwargs: Dict = None, select_by_value_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Plot boxplot of the tabular data. Parameters ---------- imputed Whether the imputed dataset should be considered. ax ``matplotlib.axes.Axes`` clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). figure_kwargs Arguments for ``plt.figure`` boxplot_kwargs Arguments for ``seaborn.boxplot`` select_by_value_kwargs Arguments for :meth:`tabensemb.data.datamodule.DataModule.select_by_value`. savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes """ clr = global_palette if clr is None else clr figure_kwargs_ = update_defaults_by_kwargs(dict(figsize=(6, 6)), figure_kwargs) boxplot_kwargs_ = update_defaults_by_kwargs( dict( orient="h", linewidth=1, fliersize=2, flierprops={"marker": "o"}, color=clr[0], saturation=1, ), boxplot_kwargs, ) select_by_value_kwargs_ = update_defaults_by_kwargs( dict(), select_by_value_kwargs ) indices = self.datamodule.select_by_value(**select_by_value_kwargs_) # sns.reset_defaults() ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) data = self._plot_action_get_df( imputed=imputed, scaled=True, cat_transformed=False )[self.cont_feature_names if features is None else features] bp = sns.boxplot( data=data.loc[indices, :], ax=ax, **boxplot_kwargs_, ) ax.set_ylabel(None) ax.set_xlabel(None) boxes = [] for x in ax.get_children(): if isinstance(x, matplotlib.patches.PathPatch): boxes.append(x) for patch in boxes: patch.set_facecolor(clr[0]) plt.grid(linewidth=0.4, axis="x") ax.set_axisbelow(True) # ax.tick_params(axis='x', rotation=90) return self._plot_action_after_plot( fig_name=os.path.join( self.project_root, f"feature_box{'_imputed' if imputed else ''}.pdf" ), disable=given_ax, ax_or_fig=ax, xlabel="Values (Scaled)", ylabel=None, tight_layout=True, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_hist_all( self, imputed=False, fontsize=12, get_figsize_kwargs: Dict = None, figure_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, tqdm_active: bool = False, **kwargs, ) -> matplotlib.figure.Figure: """ Plot histograms of the tabular data. Parameters ---------- imputed Whether the imputed dataset should be considered. figure_kwargs Arguments for ``plt.figure``. fontsize ``plt.rcParams["font.size"]`` get_figsize_kwargs Arguments for :func:`tabensemb.utils.utils.get_figsize`. savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure, or return the ``matplotlib.figure.Figure`` instance. tqdm_active Whether to use a tqdm progress bar. **kwargs Arguments for :meth:`plot_hist`. Returns ------- matplotlib.figure.Figure The figure if ``save_show_close`` is False. """ savefig_kwargs_ = update_defaults_by_kwargs( dict( fname=os.path.join( self.project_root, f"hist{'_imputed' if imputed else ''}.pdf" ) ), savefig_kwargs, ) return self.plot_subplots( ls=self.all_feature_names + self.label_name, ls_kwarg_name="feature", meth_name="plot_hist", meth_fix_kwargs=dict(imputed=imputed, **kwargs), fontsize=fontsize, with_title=True, xlabel="Value of predictors", ylabel="Density", get_figsize_kwargs=get_figsize_kwargs, figure_kwargs=figure_kwargs, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs_, tqdm_active=tqdm_active, )
[docs] def plot_hist( self, feature: str, ax=None, clr: Iterable = None, imputed=False, kde=False, category: str = None, x_values=None, legend: bool = True, figure_kwargs: Dict = None, hist_kwargs: Dict = None, bar_kwargs: Dict = None, select_by_value_kwargs: Dict = None, kde_kwargs: Dict = None, legend_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Plot the histogram of a feature. Parameters ---------- feature The selected feature. ax ``matplotlib.axes.Axes`` clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). imputed Whether the imputed dataset should be considered. kde Plot the kernel density estimation along with each histogram of continuous features. category The category to classify histograms and stack them with different colors. x_values Unique values of the `feature`. If None, it will be inferred from the dataset. legend Show legends if ``category`` is not None. figure_kwargs Arguments for ``plt.figure``. bar_kwargs Arguments for ``ax.bar`` (used for frequencies of categorical features). hist_kwargs Arguments for ``ax.hist`` (used for histograms of continuous features). kde_kwargs Arguments for :meth:`plot_kde` when ``kde`` is True. select_by_value_kwargs Arguments for :meth:`tabensemb.data.datamodule.DataModule.select_by_value`. legend_kwargs Arguments for ``plt.legend`` if ``legend`` is True and ``category`` is not None. savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes """ clr = global_palette if clr is None else clr figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) select_by_value_kwargs_ = update_defaults_by_kwargs( dict(), select_by_value_kwargs ) kde_kwargs_ = update_defaults_by_kwargs( dict(imputed=imputed, select_by_value_kwargs=select_by_value_kwargs_), kde_kwargs, ) legend_kwargs_ = update_defaults_by_kwargs(dict(), legend_kwargs) ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) hist_data = self._plot_action_get_df( imputed=imputed, scaled=False, cat_transformed=True ) indices = self.datamodule.select_by_value(**select_by_value_kwargs_) hist_data = hist_data.loc[indices, :].reset_index(drop=True) bar_kwargs_ = update_defaults_by_kwargs( dict(color=clr[0], edgecolor=None), bar_kwargs ) hist_kwargs_ = update_defaults_by_kwargs( dict(density=True, color=clr[0], rwidth=0.95, stacked=True), hist_kwargs ) x_values = ( np.sort(np.unique(hist_data[feature].values.flatten())) if x_values is None else x_values ) x_values = x_values[np.isfinite(x_values)] category_data, category_unique_values = ( self._plot_action_category_unique_values(df=hist_data, category=category) if category is not None else (None, None) ) if len(x_values) > 0: values = hist_data[feature] if feature not in self.cat_feature_names: if category is not None: values = [ values[category_data == val] for val in category_unique_values ] hist_kwargs_.update( color=clr[: len(category_unique_values)], label=category_unique_values.astype(str), ) with warnings.catch_warnings(): warnings.filterwarnings( "ignore", message="All-NaN slice encountered" ) ax.hist(values, **hist_kwargs_) # sns.rugplot(data=chosen_data, height=0.05, ax=ax2, color='k') # ax2.set_ylim([0,1]) if "range" not in hist_kwargs_.keys(): ax.set_xlim([np.min(x_values), np.max(x_values)]) if kde: self.plot_kde( x_col=feature, ax=ax, **kde_kwargs_, ) else: counts = np.array( [len(np.where(values.values == x)[0]) for x in x_values] ) if category is not None: bottom = np.zeros(len(x_values)) for idx, val in enumerate(category_unique_values): category_counts = np.array( [ len( np.where(values[category_data == val].values == x)[ 0 ] ) for x in x_values ] ) bar_kwargs_.update( color=clr[idx], label=str(val), bottom=bottom ) ax.bar( x_values, category_counts, tick_label=[ self.cat_feature_mapping[feature][x] for x in x_values ], **bar_kwargs_, ) bottom += category_counts else: ax.bar( x_values, counts, tick_label=[ self.cat_feature_mapping[feature][x] for x in x_values ], **bar_kwargs_, ) if "range" not in hist_kwargs_.keys(): ax.set_xlim([np.min(x_values) - 0.5, np.max(x_values) + 0.5]) count_range = np.max(counts) - np.min(counts) ax.set_ylim( [ max([np.min(counts) - 0.2 * count_range, 0]), np.max(counts) + 0.2 * count_range, ] ) plt.setp( ax.get_xticklabels(), rotation=90, va="center", ha="right", rotation_mode="anchor", ) if category is not None and legend: ax.legend(**legend_kwargs_) else: ax.text(0.5, 0.5, "Invalid interval", ha="center", va="center") ax.set_xlim([0, 1]) ax.set_ylim([0, 1]) ax.set_yticks([]) return self._plot_action_after_plot( fig_name=os.path.join( self.project_root, f"hist{'_imputed' if imputed else ''}_{feature}.pdf" ), disable=given_ax, ax_or_fig=ax, xlabel=feature, ylabel="Density" if hist_kwargs_["density"] else "Frequency", tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_on_one_axes( self, meth_name: Union[str, List], meth_kwargs_ls: List[Dict], twin: bool = False, fontsize: float = 12, xlabel: str = None, ylabel: str = None, twin_ylabel: str = None, ax=None, meth_fix_kwargs: Dict = None, figure_kwargs: Dict = None, legend_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, legend: bool = False, ) -> matplotlib.axes.Axes: """ Plot multiple items on one ``matplotlib.axes.Axes``. Parameters ---------- meth_name The method or a list of methods to plot multiple items. The method should have an argument named `ax` which indicates the subplot. meth_kwargs_ls A list of arguments of the corresponding ``meth_name`` (except for ``ax``). twin Plot one plot on ``ax`` and the next plot on ``ax.twin()``. fontsize ``plt.rcParams["font.size"]`` xlabel The overall xlabel. ylabel The overall ylabel. twin_ylabel The overall ylabel of the twin x-axis if ``twin`` is True. ax ``matplotlib.axes.Axes`` meth_fix_kwargs Fixed arguments of ``meth_name`` (except for ``ax``, ``ls_kwarg_name``, and those given in ``meth_kwargs_ls``). figure_kwargs Arguments for ``plt.figure``. legend_kwargs Arguments for ``plt.legend()`` savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. legend Whether to show the legend. Returns ------- matplotlib.axes.Axes """ figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) meth_fix_kwargs_ = update_defaults_by_kwargs(dict(), meth_fix_kwargs) ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) plt.rcParams["font.size"] = fontsize if isinstance(meth_name, str): meth_name = [meth_name] * len(meth_kwargs_ls) current_ax = ax twin_ax = ax.twinx() if twin else ax for meth, meth_kwargs in zip(meth_name, meth_kwargs_ls): getattr(self, meth)(ax=current_ax, **meth_kwargs, **meth_fix_kwargs_) current_ax = twin_ax if current_ax == ax and twin else ax ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) handlers, labels = ax.get_legend_handles_labels() if twin: twin_ax.set_ylabel(twin_ylabel) handlers_twin, labels_twin = twin_ax.get_legend_handles_labels() handlers += handlers_twin labels += labels_twin if legend: legend_kwargs_ = update_defaults_by_kwargs( dict(handles=handlers, labels=labels), legend_kwargs ) ax.legend(**legend_kwargs_) return self._plot_action_after_plot( fig_name=os.path.join(self.project_root, "plot_on_one_axes.pdf"), disable=given_ax, ax_or_fig=ax, tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_scatter( self, x_col: str, y_col: str, category: str = None, ax=None, clr: Iterable = None, imputed: bool = False, kde_color: bool = False, figure_kwargs: Dict = None, scatter_kwargs: Dict = None, select_by_value_kwargs: Dict = None, savefig_kwargs: Dict = None, legend_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Plot one column against another. Parameters ---------- x_col The column for the x-axis. y_col The column for the y-axis. category The category to classify data points with different colors and markers. ax ``matplotlib.axes.Axes`` clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). imputed Whether the imputed dataset should be considered. kde_color Whether the scatters are colored by their KDE density. figure_kwargs Arguments for ``plt.figure``. scatter_kwargs Arguments for ``plt.scatter()`` select_by_value_kwargs Arguments for :meth:`tabensemb.data.datamodule.DataModule.select_by_value`. savefig_kwargs Arguments for ``plt.savefig`` legend_kwargs Arguments for ``plt.legend`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes """ clr = global_palette if clr is None else clr figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) scatter_kwargs_ = update_defaults_by_kwargs(dict(color=clr[0]), scatter_kwargs) select_by_value_kwargs_ = update_defaults_by_kwargs( dict(), select_by_value_kwargs ) legend_kwargs_ = update_defaults_by_kwargs(dict(), legend_kwargs) ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) df = self._plot_action_get_df( imputed=imputed, scaled=False, cat_transformed=False ) indices = self.datamodule.select_by_value(**select_by_value_kwargs_) x = df.loc[indices, x_col].values.flatten() y = df.loc[indices, y_col].values.flatten() isna = np.union1d(np.where(np.isnan(x))[0], np.where(np.isnan(y))[0]) notna = np.setdiff1d(np.arange(len(x)), isna) if kde_color: xy = np.vstack([x[notna], y[notna]]) z = st.gaussian_kde(xy)(xy) idx = z.argsort() scatter_kwargs_ = update_defaults_by_kwargs( scatter_kwargs_, dict(c=z[idx], color=None) ) ax.scatter(x[notna][idx], y[notna][idx], **scatter_kwargs_) else: if category is None: ax.scatter(x[notna], y[notna], **scatter_kwargs_) else: df = df.loc[indices, :].reset_index(drop=True) self._plot_action_categorical_scatter( x=x[notna], y=y[notna], df=df.loc[notna, :], category=category, ax=ax, clr=clr, scatter_kwargs=scatter_kwargs_, ) ax.legend(**legend_kwargs_) return self._plot_action_after_plot( fig_name=os.path.join(self.project_root, f"scatter_{x_col}_{y_col}.pdf"), disable=given_ax, ax_or_fig=ax, xlabel=x_col, ylabel=y_col, tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_pdf( self, feature: str, dist: st.rv_continuous = st.norm, ax=None, clr: Iterable = None, imputed: bool = False, figure_kwargs: Dict = None, plot_kwargs: Dict = None, select_by_value_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Plot the probability density function of a feature. Parameters ---------- feature The investigated feature. dist The distribution to fit. It should be an instance of ``scipy.stats.rv_continuous`` that has ``fit`` and ``pdf`` methods. ax ``matplotlib.axes.Axes`` clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). imputed Whether the imputed dataset should be considered. figure_kwargs Arguments for ``plt.figure``. plot_kwargs Arguments for ``plt.plot`` select_by_value_kwargs Arguments for :meth:`tabensemb.data.datamodule.DataModule.select_by_value`. savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes """ clr = global_palette if clr is None else clr figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) plot_kwargs_ = update_defaults_by_kwargs(dict(color=clr[0]), plot_kwargs) select_by_value_kwargs_ = update_defaults_by_kwargs( dict(), select_by_value_kwargs ) df = self._plot_action_get_df( imputed=imputed, scaled=False, cat_transformed=False ) indices = self.datamodule.select_by_value(**select_by_value_kwargs_) df = df.loc[indices, :] ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) values = df[feature].values.flatten() x = np.linspace(np.nanmin(values), np.nanmax(values), 200) pdf = dist.pdf(x, *dist.fit(values[np.isfinite(values)])) ax.plot(x, pdf, **plot_kwargs_) return self._plot_action_after_plot( fig_name=os.path.join(self.project_root, f"pdf_{feature}.pdf"), disable=given_ax, ax_or_fig=ax, xlabel=feature, ylabel="Probability density", tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_kde_all( self, imputed=False, fontsize=12, get_figsize_kwargs: Dict = None, figure_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, tqdm_active: bool = False, **kwargs, ) -> matplotlib.figure.Figure: """ Plot the kernel density estimation for each feature in the tabular data. Parameters ---------- imputed Whether the imputed dataset should be considered. figure_kwargs Arguments for ``plt.figure``. fontsize ``plt.rcParams["font.size"]`` get_figsize_kwargs Arguments for :func:`tabensemb.utils.utils.get_figsize`. savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure, or return the ``matplotlib.figure.Figure`` instance. tqdm_active Whether to use a tqdm progress bar. **kwargs Arguments for :meth:`plot_kde`. Returns ------- matplotlib.figure.Figure The figure if ``save_show_close`` is False. """ savefig_kwargs_ = update_defaults_by_kwargs( dict( fname=os.path.join( self.project_root, f"kdes{'_imputed' if imputed else ''}.pdf" ) ), savefig_kwargs, ) return self.plot_subplots( ls=self.cont_feature_names + self.label_name, ls_kwarg_name="x_col", meth_name="plot_kde", meth_fix_kwargs=dict(imputed=imputed, **kwargs), fontsize=fontsize, with_title=True, xlabel="Value of features", ylabel="Density", get_figsize_kwargs=get_figsize_kwargs, figure_kwargs=figure_kwargs, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs_, tqdm_active=tqdm_active, )
[docs] def plot_kde( self, x_col: str, y_col: str = None, ax=None, clr: Iterable = None, imputed: bool = False, figure_kwargs: Dict = None, kdeplot_kwargs: Dict = None, select_by_value_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Plot the kernel density estimation of a feature or two features. Parameters ---------- x_col The investigated feature. y_col If not None, a bi-variate distribution will be plotted. ax ``matplotlib.axes.Axes`` clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). imputed Whether the imputed dataset should be considered. figure_kwargs Arguments for ``plt.figure``. kdeplot_kwargs Arguments for ``seaborn.kdeplot`` select_by_value_kwargs Arguments for :meth:`tabensemb.data.datamodule.DataModule.select_by_value`. savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes """ clr = global_palette if clr is None else clr figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) kdeplot_kwargs_ = update_defaults_by_kwargs(dict(color=clr[0]), kdeplot_kwargs) select_by_value_kwargs_ = update_defaults_by_kwargs( dict(), select_by_value_kwargs ) df = self._plot_action_get_df( imputed=imputed, scaled=False, cat_transformed=False ) indices = self.datamodule.select_by_value(**select_by_value_kwargs_) df = df.loc[indices, :] ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) sns.kdeplot(data=df, x=x_col, y=y_col, ax=ax, **kdeplot_kwargs_) ax.set_ylabel(None) ax.set_xlabel(None) return self._plot_action_after_plot( fig_name=os.path.join( self.project_root, f"kde_{x_col}{'' if y_col is None else '_'+y_col}.pdf", ), disable=given_ax, ax_or_fig=ax, xlabel=x_col, ylabel="Density" if y_col is None else y_col, tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_presence_ratio( self, order="ratio", ax=None, clr: Iterable = None, figure_kwargs: Dict = None, barplot_kwargs: Dict = None, legend_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Plot the ratio of presence of each feature. Parameters ---------- order "ratio" or "type". If is "ratio", the labels will be sorted by the presence ratio. If is "type", the labels will be sorted first by their feature types defined in the configuration, and then sorted by the presence ratio. ax ``matplotlib.axes.Axes`` clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). figure_kwargs Arguments for ``plt.figure``. barplot_kwargs Arguments for ``seaborn.barplot`` legend_kwargs Arguments for ``plt.legend`` savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes """ figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) barplot_kwargs_ = update_defaults_by_kwargs( dict( hue_order=self.datamodule.unique_feature_types_with_derived(), orient="h", linewidth=1, edgecolor="k", saturation=1, ), barplot_kwargs, ) legend_kwargs_ = update_defaults_by_kwargs( dict(frameon=True, fancybox=True), legend_kwargs ) is_horizontal = barplot_kwargs_["orient"] == "h" cont_mask = self.datamodule.cont_imputed_mask cat_mask = self.datamodule.cat_imputed_mask cont_presence_ratio = np.sum(1 - cont_mask) / cont_mask.shape[0] cat_presence_ratio = np.sum(1 - cat_mask) / cat_mask.shape[0] presence_ratio = pd.concat([cont_presence_ratio, cat_presence_ratio]) presence = pd.DataFrame( { "feature": presence_ratio.index, "ratio": presence_ratio.values, "types": self.datamodule.get_feature_types( list(presence_ratio.index), allow_unknown=True ), } ) presence.sort_values( by=["types", "ratio"] if order == "type" else "ratio", inplace=True ) clr = global_palette if clr is None else clr palette = self._plot_action_generate_feature_types_palette( clr=clr, features=presence["feature"] ) ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) ax.set_axisbelow(True) ax.grid(axis="x", linewidth=0.2) sns.barplot( data=presence, x="ratio" if is_horizontal else "feature", y="feature" if is_horizontal else "ratio", ax=ax, palette=palette, **barplot_kwargs_, ) ax.set_ylabel(None) ax.set_xlabel(None) getattr(ax, "set_xlim" if is_horizontal else "set_ylim")([0, 1]) legend = self._plot_action_generate_feature_types_legends( clr=clr, ax=ax, legend_kwargs=legend_kwargs_ ) return self._plot_action_after_plot( fig_name=os.path.join(self.project_root, f"presence_ratio.pdf"), disable=given_ax, ax_or_fig=ax, xlabel="Presence ratio" if is_horizontal else "", ylabel="Presence ratio" if not is_horizontal else "", tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_fill_rating( self, ax=None, clr: Iterable = None, category: str = None, legend: bool = True, figure_kwargs: Dict = None, hist_kwargs: Dict = None, legend_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Plot the histogram of data point rating which is the percentage of filled features. Parameters ---------- ax ``matplotlib.axes.Axes`` clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). category The category to classify histograms and stack them with different colors. legend Show legends if ``category`` is not None. figure_kwargs Arguments for ``plt.figure``. hist_kwargs Arguments for ``plt.hist``. legend_kwargs Arguments for ``plt.legend`` if ``legend`` is True and ``category`` is not None. savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes References ---------- Zhang, Zian, and Zhiping Xu. “Fatigue Database of Additively Manufactured Alloys.” Scientific Data 10, no. 1 (May 2, 2023): 249. """ clr = global_palette if clr is None else clr figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) hist_kwargs_ = update_defaults_by_kwargs( dict(linewidth=1, edgecolor="k", color=clr[0], density=True), hist_kwargs, ) legend_kwargs_ = update_defaults_by_kwargs(dict(), legend_kwargs) ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) cont_mask = self.datamodule.cont_imputed_mask.values cat_mask = self.datamodule.cat_imputed_mask.values cont_presence_features = np.sum(1 - cont_mask, axis=1) cat_presence_features = np.sum(1 - cat_mask, axis=1) rating = (cont_presence_features + cat_presence_features) / len( self.all_feature_names ) if category is not None: # augmented data points should not be included. df = self._plot_action_get_df( imputed=True, scaled=False, cat_transformed=False ).loc[self.datamodule.cont_imputed_mask.index, :] category_data, unique_values = self._plot_action_category_unique_values( df=df, category=category ) rating = [rating[category_data == val] for val in unique_values] hist_kwargs_.update( dict( label=unique_values.astype(str), stacked=True, color=clr[: len(unique_values)], ) ) ax.hist(rating, **hist_kwargs_) ax.set_xlim([0, 1]) if legend and category is not None: ax.legend(**legend_kwargs_) return self._plot_action_after_plot( fig_name=os.path.join(self.project_root, f"fill_rating.pdf"), disable=given_ax, ax_or_fig=ax, xlabel="Fill rating", ylabel="Density", tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_categorical_presence_ratio( self, category: str = None, ax=None, orient="h", figure_kwargs: Dict = None, imshow_kwargs: Dict = None, cbar_kwargs: Dict = None, cbar_ax_linewidth: float = 1, cbar_ax_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Plot the ratio of presence of each feature, but is classified by a categorical variable. Parameters ---------- category The category (usually data sources) to classify data points. ax ``matplotlib.axes.Axes`` figure_kwargs Arguments for ``plt.figure``. imshow_kwargs Arguments for ``plt.imshow``. cbar_kwargs Arguments for ``plt.colorbar``. cbar_ax_linewidth Line width of bounding box of cbar. cbar_ax_kwargs Arguments for ``mpl_toolkits.axes_grid1.inset_locator.inset_axes`` savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes """ figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) imshow_kwargs_ = update_defaults_by_kwargs(dict(cmap="Blues"), imshow_kwargs) cbar_kwargs_ = update_defaults_by_kwargs(dict(), cbar_kwargs) cont_mask = self.datamodule.cont_imputed_mask cat_mask = self.datamodule.cat_imputed_mask df = self._plot_action_get_df( imputed=False, scaled=False, cat_transformed=False ).loc[cont_mask.index, :] category_data, unique_values = self._plot_action_category_unique_values( df=df, category=category ) mat = np.zeros((len(self.all_feature_names), len(unique_values))) for idx, cls in enumerate(unique_values): cls_indices = df.index[category_data == cls] cont_presence_ratio = np.sum(1 - cont_mask.loc[cls_indices, :]) / len( cls_indices ) cat_presence_ratio = np.sum(1 - cat_mask.loc[cls_indices, :]) / len( cls_indices ) presence_ratio = pd.concat([cont_presence_ratio, cat_presence_ratio]) mat[:, idx] = presence_ratio[self.all_feature_names] ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) im = ax.imshow(mat if orient == "h" else mat.T, **imshow_kwargs_) (ax.set_xticks if orient == "h" else ax.set_yticks)( np.arange(len(unique_values)) ) (ax.set_yticks if orient == "h" else ax.set_xticks)( np.arange(len(self.all_feature_names)) ) (ax.set_xticklabels if orient == "h" else ax.set_yticklabels)(unique_values) (ax.set_yticklabels if orient == "h" else ax.set_xticklabels)( self.all_feature_names ) plt.setp( ax.get_xticklabels(), rotation=45, ha="right", va="center", rotation_mode="anchor", ) from mpl_toolkits.axes_grid1.inset_locator import inset_axes cbar_ax_kwargs_ = update_defaults_by_kwargs( dict( width=f"{1/len(unique_values)*100}%", height="20%", loc="lower left", bbox_to_anchor=(1.05, 0.0, 1, 1), bbox_transform=ax.transAxes, borderpad=0, ), cbar_ax_kwargs, ) axins = inset_axes(ax, **cbar_ax_kwargs_) cbar = ax.figure.colorbar(im, cax=axins, **cbar_kwargs_) cbar.ax.set_ylabel("Presence ratio", rotation=-90, va="bottom") [i.set_linewidth(cbar_ax_linewidth) for i in cbar.ax.spines.values()] cbar.ax.xaxis.set_tick_params(width=cbar_ax_linewidth) cbar.ax.yaxis.set_tick_params(width=cbar_ax_linewidth) return self._plot_action_after_plot( fig_name=os.path.join(self.project_root, f"presence_ratio_{category}.pdf"), disable=given_ax, ax_or_fig=ax, xlabel=None, ylabel=None, tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def plot_pca_2d_visual( self, ax=None, category: str = None, clr: Iterable = None, features: List[str] = None, pca_kwargs: Dict = None, figure_kwargs: Dict = None, scatter_kwargs: Dict = None, legend_kwargs: Dict = None, savefig_kwargs: Dict = None, select_by_value_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Fit a ``sklearn.decomposition.PCA`` on a set of features, and plot its first two principal components as scatters. Parameters ---------- ax ``matplotlib.axes.Axes`` category The category to classify data points with different colors and markers. clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). features A subset of continuous features to fit the PCA. pca_kwargs Arguments for ``sklearn.decomposition.PCA.fit`` figure_kwargs Arguments for ``plt.figure``. scatter_kwargs Arguments for ``plt.scatter`` legend_kwargs Arguments for ``plt.legend`` savefig_kwargs Arguments for ``plt.savefig`` select_by_value_kwargs Arguments for :meth:`tabensemb.data.datamodule.DataModule.select_by_value`. save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes """ clr = global_palette if clr is None else clr features = self.cont_feature_names if features is None else features figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) pca_kwargs_ = update_defaults_by_kwargs(dict(random_state=0), pca_kwargs) scatter_kwargs_ = update_defaults_by_kwargs(dict(color=clr[0]), scatter_kwargs) legend_kwargs_ = update_defaults_by_kwargs(dict(title=category), legend_kwargs) select_by_value_kwargs_ = update_defaults_by_kwargs( dict(), select_by_value_kwargs ) ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) indices = self.datamodule.select_by_value(**select_by_value_kwargs_) df = ( self._plot_action_get_df(imputed=True, scaled=True, cat_transformed=False) .loc[indices, :] .reset_index(drop=True) ) pca = self.datamodule.pca( feature_names=features, indices=indices, **pca_kwargs_ ) low_dim_rep = pca.transform(df[features]) x, y = low_dim_rep[:, 0], low_dim_rep[:, 1] if category is None: ax.scatter(x, y, **scatter_kwargs_) else: self._plot_action_categorical_scatter( x=x, y=y, df=df, category=category, ax=ax, clr=clr, scatter_kwargs=scatter_kwargs_, ) ax.legend(**legend_kwargs_) return self._plot_action_after_plot( fig_name=os.path.join(self.project_root, f"pca_2d_visual_{category}.pdf"), disable=given_ax, ax_or_fig=ax, xlabel="1st principal component", ylabel="2nd principal component", tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def _plot_action_category_unique_values( self, df: pd.DataFrame, category: str ) -> Tuple[pd.Series, np.ndarray]: """ Get the category to classify data points and its unique values. Parameters ---------- df The dataframe. The returned Series has the same indices. category The category to classify data points. Returns ------- pd.Series The category np.ndarray Unique values. """ df = self.datamodule.categories_inverse_transform(df) # Same as the procedure in OrdinalEncoder. dtype = get_imputed_dtype(df.dtypes[category]) category_data = ( fill_cat_nan(df[[category]], {category: dtype})[category] if dtype == str else df[category] ) unique_values = np.sort(np.unique(category_data)) return category_data, unique_values
[docs] def _plot_action_categorical_scatter( self, x, y, df: pd.DataFrame, category: str, ax, clr: Iterable, scatter_kwargs: Dict, ): """ Plot scatters whose colors are related to their category. Parameters ---------- x x-values of the scatter plot. y y-values of the scatter plot. df The dataframe whose ``category`` column is used to classify data points. category The column to classify data points. ax ``matplotlib.axes.Axes`` clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). scatter_kwargs Arguments for ``plt.scatter`` """ df = self.datamodule.categories_inverse_transform(df).reset_index(drop=True) category_data, unique_values = self._plot_action_category_unique_values( df=df, category=category ) for idx, cat in enumerate(unique_values): colored_scatter_kwargs_ = scatter_kwargs.copy() colored_scatter_kwargs_.update( { "color": clr[idx % len(clr)], "marker": global_marker[idx % len(global_marker)], } ) cat_indices = np.array(df[category_data == cat].index) ax.scatter( x[cat_indices], y[cat_indices], label=str(cat), **colored_scatter_kwargs_, )
[docs] def plot_loss( self, program: str, model_name: str, ax=None, train_val: str = "both", restored_epoch_mark: bool = True, restored_epoch_mark_if_last: bool = False, legend: bool = True, clr: Iterable = None, plot_kwargs: Dict = None, scatter_kwargs: Dict = None, legend_kwargs: Dict = None, figure_kwargs: Dict = None, savefig_kwargs: Dict = None, save_show_close: bool = True, ) -> matplotlib.axes.Axes: """ Plot loss curves for a model. Parameters ---------- program The selected model base. model_name The selected model in the model base. ax ``matplotlib.axes.Axes`` train_val "train" to plot training loss only. "val" to plot validation loss only. "both" to plot both of them. restored_epoch_mark Plot the best epoch from where the model is restored after training. restored_epoch_mark_if_last Plot the best epoch when it is the last epoch. legend Show legends. clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). plot_kwargs Arguments for ``plt.plot`` scatter_kwargs Arguments for ``plt.scatter`` (used to plot the restored epoch). legend_kwargs Arguments for ``plt.legend``. figure_kwargs Arguments for ``plt.figure``. savefig_kwargs Arguments for ``plt.savefig`` save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. Returns ------- matplotlib.axes.Axes """ clr = global_palette if clr is None else clr figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) plot_kwargs_ = update_defaults_by_kwargs(dict(markersize=4), plot_kwargs) scatter_kwargs_ = update_defaults_by_kwargs( dict( color=clr[2], marker=global_marker[2], s=15, label="Best epoch", zorder=10, ), scatter_kwargs, ) legend_kwargs_ = update_defaults_by_kwargs(dict(), legend_kwargs) ax, given_ax = self._plot_action_init_ax(ax, figure_kwargs_) modelbase = self.get_modelbase(program=program) train_ls = modelbase.train_losses.get(model_name, None) val_ls = modelbase.val_losses.get(model_name, None) restored_epoch = modelbase.restored_epochs.get(model_name, None) if train_ls is None and val_ls is None: raise Exception( f"The model base {program} did not record losses during training in its attributes `train_losses` or " f"`val_losses` (in the `_train_single_model` method). " ) if restored_epoch is None and restored_epoch_mark: warnings.warn( f"The model base {program} did not record the best epoch from where the model is restored in its " f"attribute `restored_epochs` (in the `_train_single_model` method)" ) if train_val in ["both", "train"] and train_ls is not None: train_plot_kwargs = plot_kwargs_.copy() train_plot_kwargs.update( dict(color=clr[0], marker=global_marker[0], label="Training loss") ) ax.plot(np.arange(len(train_ls)), train_ls, **train_plot_kwargs) if train_val in ["both", "val"] and val_ls is not None: val_plot_kwargs = plot_kwargs_.copy() val_plot_kwargs.update( dict(color=clr[1], marker=global_marker[1], label="Validation loss") ) ax.plot(np.arange(len(val_ls)), val_ls, **val_plot_kwargs) if ( restored_epoch is not None and restored_epoch_mark and (restored_epoch < len(val_ls) - 1 or restored_epoch_mark_if_last) ): ax.scatter( restored_epoch, (val_ls if train_val in ["both", "val"] else train_ls)[restored_epoch], **scatter_kwargs_, ) if legend: ax.legend(**legend_kwargs_) return self._plot_action_after_plot( fig_name=os.path.join( self.project_root, f"loss_{train_val}_{program}_{model_name}.pdf" ), disable=given_ax, ax_or_fig=ax, xlabel="Epoch", ylabel=f"{self.datamodule.loss.upper()} loss", tight_layout=False, save_show_close=save_show_close, savefig_kwargs=savefig_kwargs, )
[docs] def _plot_action_generate_feature_types_palette( self, clr: Iterable, features: List[str] ) -> List: """ Generate color palette for each feature according to their types defined in the configuration. Parameters ---------- clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). features A list of features to be plotted. Returns ------- list A list of colors for each feature. It can be used as the argument ``palette`` for seaborn functions. """ type_idx = self.datamodule.get_feature_types_idx( features=features, allow_unknown=True ) palette = [clr[i] for i in type_idx] return palette
[docs] def _plot_action_generate_feature_types_legends( self, clr, ax, legend_kwargs ) -> matplotlib.legend.Legend: """ Generate the legend for feature types defined in the configuration. Parameters ---------- clr A seaborn color palette or an Iterable of colors. For example seaborn.color_palette("deep"). ax ``matplotlib.axes.Axes`` legend_kwargs Arguments for ``plt.legend`` Returns ------- matplotlib.legend.Legend """ clr_map = dict() for idx, feature_type in enumerate( self.datamodule.unique_feature_types_with_derived() ): clr_map[feature_type] = clr[idx] legend_kwargs_ = update_defaults_by_kwargs( dict( loc="lower right", handleheight=2, fancybox=False, frameon=False, ), legend_kwargs, ) legend = ax.legend( handles=[ Rectangle((0, 0), 1, 1, color=value, ec="k", label=key) for key, value in clr_map.items() ], **legend_kwargs_, ) return legend
[docs] def _plot_action_init_ax( self, ax=None, figure_kwargs: Dict = None, return_fig: bool = False ) -> Tuple[matplotlib.axes.Axes, bool]: figure_kwargs_ = update_defaults_by_kwargs(dict(), figure_kwargs) given_ax = ax is not None if not given_ax: fig = plt.figure(**figure_kwargs_) if not return_fig: ax = plt.subplot(111) if isinstance(ax, matplotlib.axes.Axes): plt.sca(ax) return (ax, given_ax) if not return_fig else (fig, given_ax)
[docs] def _plot_action_after_plot( self, fig_name, disable: bool, ax_or_fig=None, xlabel: str = None, ylabel: str = None, save_show_close: bool = True, tight_layout=False, savefig_kwargs: Dict = None, ) -> Union[matplotlib.axes.Axes, matplotlib.figure.Figure, Any]: """ Set the labels of x/y-axis, set the layout, save the current figure, show the figure if in a notebook, and close the figure. Parameters ---------- fig_name The path to save the figure. Can be updated by ``savefig_kwargs`` using the key ``fname`` ax_or_fig ``matplotlib.axes.Axes`` or ``matplotlib.figure.Figure``. If is a ``matplotlib.axes.Axes``, x/y-axis labels will be set using ``xlabel`` and ``ylabel``. disable True to disable the action. ``ax_or_fig`` is still returned. xlabel The label of the x-axis. Will be set only when ``ax_or_fig`` is a ``matplotlib.axes.Axes``. ylabel The label of the y-axis. Will be set only when ``ax_or_fig`` is a ``matplotlib.axes.Axes``. save_show_close Whether to save, show (in the notebook), and close the figure if ``ax`` is not given. tight_layout If True, ``plt.tight_layout`` is called. savefig_kwargs Arguments for ``plt.savefig``. Returns ------- matplotlib.axes.Axes or matplotlib.figure.Figure Just the input ``ax_or_fig`` """ if not disable: if ax_or_fig is not None: if isinstance(ax_or_fig, matplotlib.axes.Axes): if xlabel is not None: ax_or_fig.set_xlabel(xlabel) if ylabel is not None: ax_or_fig.set_ylabel(ylabel) if save_show_close: savefig_kwargs_ = update_defaults_by_kwargs( dict(fname=fig_name), savefig_kwargs ) if tight_layout: plt.tight_layout() os.makedirs(os.path.dirname(savefig_kwargs_["fname"]), exist_ok=True) plt.savefig(**savefig_kwargs_) if is_notebook(): plt.show() plt.close() return ax_or_fig
[docs] def _bootstrap_fit( self, program: str, df: pd.DataFrame, derived_data: Dict[str, np.ndarray], focus_feature: str, model_name: str, n_bootstrap: int = 1, grid_size: int = 30, refit: bool = True, resample: bool = True, percentile: float = 100, x_min: float = None, x_max: float = None, CI: float = 0.95, average: bool = True, inspect_attr_kwargs: Dict = None, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """ Make bootstrap resampling, fit the selected model on the resampled data, and assign sequential values to the selected feature to see how the prediction changes with respect to the feature. Cook, Thomas R., et al. Explaining Machine Learning by Bootstrapping Partial Dependence Functions and Shapley Values. No. RWP 21-12. 2021. Parameters ---------- program The selected model base. model_name The selected model in the model base. df The tabular dataset. derived_data The derived data calculated using :meth:`derive_unstacked`. focus_feature The feature to assign sequential values. n_bootstrap The number of bootstrapping, fitting, and assigning runs. grid_size The number of sequential values. refit Whether to fit the model on the bootstrap dataset (with warm_start=True). resample Whether to do bootstrap resampling. Only recommended to False when n_bootstrap=1. percentile The percentile of the feature used to generate sequential values. x_min The lower limit of the generated sequential values. It will override the left percentile. x_max The upper limit of the generated sequential values. It will override the right percentile. CI The confidence interval level to evaluate bootstrapped predictions. average If True, CI will be calculated on results ``(grid_size, n_bootstrap)``where predictions for all samples are averaged for each bootstrap run. If False, CI will be calculated on results ``(grid_size, n_bootstrap*len(df))``. Returns ------- np.ndarray The generated sequential values for the feature. np.ndarray Averaged predictions on the sequential values across multiple bootstrap runs and all samples. np.ndarray The left confidence interval. np.ndarray The right confidence interval. """ from .utils import NoBayesOpt modelbase = self.get_modelbase(program) derived_data = self.datamodule.sort_derived_data(derived_data) df = df.reset_index(drop=True) if focus_feature in self.cont_feature_names: x_value = self._generate_grid( feature=focus_feature, grid_size=grid_size, percentile=percentile, x_min=x_min, x_max=x_max, df=df, ) elif focus_feature in self.cat_feature_names: x_value = np.unique(df[focus_feature].values) else: raise Exception(f"{focus_feature} not available.") expected_value_bootstrap_replications = [] inspects = [] for i_bootstrap in range(n_bootstrap): if resample: df_bootstrap = skresample(df) else: df_bootstrap = df tmp_derived_data = self.datamodule.get_derived_data_slice( derived_data, list(df_bootstrap.index) ) df_bootstrap = df_bootstrap.reset_index(drop=True) bootstrap_model = modelbase.detach_model(model_name=model_name) if refit: with NoBayesOpt(self): bootstrap_model.fit( df_bootstrap, model_subset=[model_name], cont_feature_names=self.datamodule.dataprocessors[ 0 ].record_cont_features, cat_feature_names=self.datamodule.dataprocessors[ 0 ].record_cat_features, label_name=self.label_name, verbose=False, warm_start=True, ) i_inspect = [] bootstrap_model_predictions = [] for value in x_value: df_perm = df_bootstrap.copy() df_perm[focus_feature] = value inspect_attr_kwargs_ = update_defaults_by_kwargs( dict(attributes=[]), inspect_attr_kwargs ) inspect = bootstrap_model.inspect_attr( model_name=model_name, df=df_perm, derived_data=( tmp_derived_data if focus_feature in self.derived_stacked_features else None ), **inspect_attr_kwargs_, ) bootstrap_model_predictions.append(inspect["USER_INPUT"]["prediction"]) i_inspect.append((value, inspect["USER_INPUT"])) if average: expected_value_bootstrap_replications.append( np.mean(np.hstack(bootstrap_model_predictions), axis=0) ) else: expected_value_bootstrap_replications.append( np.hstack(bootstrap_model_predictions) ) inspects.append(i_inspect) expected_value_bootstrap_replications = np.vstack( expected_value_bootstrap_replications ) ci_left = [] ci_right = [] mean_pred = [] for col_idx in range(expected_value_bootstrap_replications.shape[1]): y_pred = expected_value_bootstrap_replications[:, col_idx] if len(y_pred) != 1 and len(np.unique(y_pred)) != 1: ci_int = st.norm.interval(CI, loc=np.mean(y_pred), scale=np.std(y_pred)) else: ci_int = (np.nan, np.nan) ci_left.append(ci_int[0]) ci_right.append(ci_int[1]) mean_pred.append(np.mean(y_pred)) return ( (x_value, np.array(mean_pred), np.array(ci_left), np.array(ci_right)) if inspect_attr_kwargs is None else ( x_value, np.array(mean_pred), np.array(ci_left), np.array(ci_right), inspects, ) )
[docs] def _generate_grid( self, feature: str, grid_size: int, percentile: Union[int, float] = 100, x_min: Union[int, float] = None, x_max: Union[int, float] = None, df: pd.DataFrame = None, ) -> np.ndarray: """ Generate a sequential (linspace) grid for a feature in the tabular dataset. Parameters ---------- feature The focused feature. grid_size The number of sequential values. percentile The percentile of the feature used to generate sequential values. x_min The lower limit of the generated sequential values. It will override the left percentile. x_max The upper limit of the generated sequential values. It will override the right percentile. df The tabular dataset. Returns ------- np.ndarray """ df = df if df is not None else self.df return np.linspace( ( np.nanpercentile(df[feature].values, (100 - percentile) / 2) if x_min is None else x_min ), ( np.nanpercentile(df[feature].values, 100 - (100 - percentile) / 2) if x_max is None else x_max ), grid_size, )
[docs] def load_state(self, trainer: "Trainer"): """ Restore a :class:`Trainer` from a deep-copied state. Parameters ---------- trainer A deep-copied status of a :class:`Trainer`. """ # https://stackoverflow.com/questions/1216356/is-it-safe-to-replace-a-self-object-by-another-object-of-the-same-type-in-a-meth current_root = cp(self.project_root) self.__dict__.update(trainer.__dict__) # The update operation does not change the location of self. However, model bases contains another trainer # that points to another location if the state is loaded from disk. for model in self.modelbases: model.trainer = self self.set_path(current_root, verbose=False) for modelbase in self.modelbases: modelbase.set_path(os.path.join(current_root, modelbase.program))
[docs] def get_best_model(self) -> Tuple[str, str]: """ Get the best model from :attr:`leaderboard`. Returns ------- str The name of a model base where the best model is. model_name The name of the best model. """ if not hasattr(self, "leaderboard"): self.get_leaderboard(test_data_only=True, dump_trainer=False) return ( self.leaderboard["Program"].values[0], self.leaderboard["Model"].values[0], )
[docs] def _metrics( self, predictions: Dict[str, Dict[str, Tuple[np.ndarray, np.ndarray]]], metrics: List[str], test_data_only: bool, ) -> pd.DataFrame: """ Calculate metrics for predictions from :meth:`tabensemb.model.AbstractModel._predict_all`. Parameters ---------- predictions Results from :meth:`tabensemb.model.AbstractModel._predict_all`. metrics The metrics that have been implemented in :func:`tabensemb.utils.utils.metric_sklearn`. test_data_only Whether to evaluate models only on testing datasets. Returns ------- pd.DataFrame A dataframe of metrics. """ df_metrics = pd.DataFrame() for model_name, model_predictions in predictions.items(): df = pd.DataFrame(index=[0]) df["Model"] = model_name for tvt, (y_pred, y_true) in model_predictions.items(): if test_data_only and tvt != "Testing": continue for metric in metrics: metric_value = auto_metric_sklearn( y_true, y_pred, metric, self.datamodule.task ) df[ ( tvt + " " + metric.upper() if not test_data_only else metric.upper() ) ] = metric_value df_metrics = pd.concat([df_metrics, df], axis=0, ignore_index=True) return df_metrics
[docs] def save_trainer( trainer: Trainer, path: Union[os.PathLike, str] = None, verbose: bool = True ): """ Pickling the :class:`Trainer` instance. Parameters ---------- trainer The :class:`Trainer` to be saved. path The folder path to save the :class:`Trainer`. verbose Verbosity. """ import pickle path = os.path.join(trainer.project_root, "trainer.pkl") if path is None else path with open(path, "wb") as outp: pickle.dump(trainer, outp, pickle.HIGHEST_PROTOCOL) if verbose: print( f"Trainer saved. To load the trainer, run trainer = load_trainer(path='{path}')" )
[docs] def load_trainer(path: Union[os.PathLike, str]) -> Trainer: """ Loading a pickled :class:`Trainer`. Paths of the :class:`Trainer` and its model bases (i.e. :attr:`project_root`, :attr:`tabensemb.model.AbstractModel.root`, :attr:`tabensemb.model.base.ModelDict.root`, and :meth:`tabensemb.model.base.ModelDict.model_path.keys`) will be changed. Parameters ---------- path Path of the :class:`Trainer`. Returns ------- trainer The loaded :class:`Trainer`. """ import pickle with open(path, "rb") as inp: trainer = pickle.load(inp) root = os.path.join(*os.path.split(path)[:-1]) trainer.set_path(root, verbose=False) for modelbase in trainer.modelbases: modelbase.set_path(os.path.join(root, modelbase.program)) modelbase.trainer = trainer trainer.datamodule.args = trainer.args return trainer