Source code for tabensemb.data.datasplitter

import numpy as np
from tabensemb.utils import *
from tabensemb.data import AbstractSplitter
import inspect
from sklearn.model_selection import train_test_split
from typing import Type, List, Tuple


class RandomSplitter(AbstractSplitter):
    """
    Randomly split the dataset.

    Row positions are shuffled and partitioned with
    ``sklearn.model_selection.train_test_split`` according to the fractions in
    ``self.train_val_test`` (indexed as train, validation, test). No
    ``random_state`` is passed, so reproducibility presumably relies on the
    global random state being seeded elsewhere.
    """

    def _split(self, df, cont_feature_names, cat_feature_names, label_name):
        """
        Randomly assign every row position of ``df`` to train, validation or test.

        The test partition is drawn from all rows first. The validation
        partition is then drawn from the remaining rows, using the validation
        fraction divided by the combined train + validation fraction. Assuming
        the three fractions sum to one, validation therefore covers roughly
        ``train_val_test[1]`` of the whole dataset.

        Parameters
        ----------
        df
            The dataset to split; only ``len(df)`` is used here.
        cont_feature_names, cat_feature_names, label_name
            Accepted for interface compatibility; not used by this splitter.

        Returns
        -------
        tuple of np.ndarray
            Positional indices of the training, validation, and testing rows.
        """
        fractions = self.train_val_test
        all_positions = np.arange(len(df))
        non_test_positions, test_indices = train_test_split(
            all_positions, test_size=fractions[2], shuffle=True
        )
        # The validation share must be expressed relative to the rows left
        # after removing the test set, not relative to the whole dataset.
        val_share_of_remainder = fractions[1] / np.sum(fractions[0:2])
        train_indices, val_indices = train_test_split(
            non_test_positions, test_size=val_share_of_remainder, shuffle=True
        )
        return train_indices, val_indices, test_indices

    @property
    def support_cv(self):
        """Whether this splitter supports cross-validation (always ``True``)."""
        return True

    def _next_cv(
        self,
        df: pd.DataFrame,
        cont_feature_names: List[str],
        cat_feature_names: List[str],
        label_name: List[str],
        cv: int,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Produce the next cross-validation split.

        All row positions of ``df`` are handed to ``self._sklearn_k_fold``
        together with ``cv``, and that helper's result is returned unchanged.
        ``_sklearn_k_fold`` is not defined in this class; presumably it is
        inherited from ``AbstractSplitter``.
        """
        positions = np.arange(len(df))
        return self._sklearn_k_fold(positions, cv)
# Registry mapping class name -> class for every ``AbstractSplitter`` subclass
# visible in this module's namespace at this point. That includes the classes
# defined above, ``AbstractSplitter`` itself (imported at the top), and any
# subclass brought in by the imports. It is built with a comprehension so that
# loop variables and the intermediate member list do not leak into the module
# namespace and get re-exported by ``from ... import *``.
splitter_mapping = {
    member_name: member
    for member_name, member in inspect.getmembers(
        sys.modules[__name__], inspect.isclass
    )
    if issubclass(member, AbstractSplitter)
}
def get_data_splitter(name: str) -> Type[AbstractSplitter]:
    """
    Look up a registered data splitter class by its class name.

    Parameters
    ----------
    name
        The class name of the splitter, e.g. ``"RandomSplitter"``.

    Returns
    -------
    Type[AbstractSplitter]
        The splitter class itself (not an instance).

    Raises
    ------
    ValueError
        If ``name`` is not a key of ``splitter_mapping``.
    TypeError
        If the registered entry is not a subclass of ``AbstractSplitter``.
    """
    try:
        splitter_cls = splitter_mapping[name]
    except KeyError:
        raise ValueError(f"Data splitter {name} not implemented.") from None
    # ``splitter_mapping`` is a mutable module global, so re-validate the entry
    # in case a non-splitter (or non-class) object was registered after import;
    # ``inspect.isclass`` keeps ``issubclass`` from raising on non-classes.
    if not (
        inspect.isclass(splitter_cls) and issubclass(splitter_cls, AbstractSplitter)
    ):
        raise TypeError(f"{name} is not the subclass of AbstractSplitter.")
    return splitter_cls