Source code for tabensemb.utils.utils

"""
All utilities used in the project.
"""

import os
import os.path
import sys
import warnings
import logging
import random
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker
import matplotlib.patches
from matplotlib.patches import Rectangle
import seaborn as sns
import torch
import torch.optim
from distutils.spawn import find_executable
from importlib import import_module, reload
from functools import partialmethod, partial
import itertools
from copy import deepcopy as cp
from torch.autograd.grad_mode import _DecoratorContextManager
from typing import Any
import tabensemb
from typing import Dict
from sklearn.metrics import *
from io import StringIO

# Module-level plotting configuration, executed once when this module is imported.
# Reset seaborn's rc changes before the project-wide settings below are applied.
sns.reset_defaults()
# matplotlib.use("Agg")
# Render text through LaTeX only when a ``latex`` executable can be found AND the
# global setting ``matplotlib_usetex`` is truthy.
# NOTE(review): ``find_executable`` is imported from ``distutils`` at the top of this
# file; confirm that import still works on every Python version the project supports.
if find_executable("latex") and tabensemb.setting["matplotlib_usetex"]:
    matplotlib.rc("text", usetex=True)
# Default typography and layout for figures created after import.
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = "Times New Roman"
plt.rcParams["figure.autolayout"] = True

# seaborn's "deep" palette, exposed as a module-level default.
global_sns_palette = sns.color_palette("deep")

# Project-wide list of hex colors, exposed as a module-level default.
global_palette = [
    "#166135",
    "#50b9aa",
    "#0d5a9b",
    "#6a77ac",
    "#322051",
    "#c24135",
    "#aa602d",
    "#eea86f",
    "#c56f9c",
    "#cd552e",
    "#ebde4e",
    "#96235d",
    "#2caf91",
    "#f8b68a",
    "#c0e3df",
    "#000000",
    "#662b3d",
    "#eb3882",
    "#1a8e7c",
    "#a89351",
    "#a6cf79",
    "#f6d761",
    "#50abde",
]

# Matplotlib marker codes, exposed as a module-level default.
global_marker = ["o", "v", "^", "<", ">", "s", "p", "P", "*", "h", "H", "D", "d"]


def is_notebook() -> bool:
    """
    Check whether the current environment is a notebook.

    Returns
    -------
    bool
        True if running inside a Jupyter notebook or qtconsole. False for a terminal IPython session, a standard
        Python interpreter, or when IPython is not installed.
    """
    try:
        from IPython import get_ipython
    except ImportError:
        # Without IPython installed there is no notebook kernel to run in.
        return False
    # Outside an IPython session ``get_ipython()`` returns None, whose class name ("NoneType") does not match.
    shell = get_ipython().__class__.__name__
    # "ZMQInteractiveShell" -> Jupyter notebook or qtconsole; "TerminalInteractiveShell" and anything else -> not
    # a notebook.
    return shell == "ZMQInteractiveShell"
def set_random_seed(seed=0):
    """
    Seed every random number generator used by the project: pytorch (through :func:`set_torch`, which also covers
    CUDA and dataloaders), numpy, and the built-in ``random`` module.

    Parameters
    ----------
    seed
        The random seed.
    """
    set_torch(seed)
    for seeder in (np.random.seed, random.seed):
        seeder(seed)
def seed_worker(worker_id):
    """
    Seed numpy and ``random`` from the current torch seed. Intended as the ``worker_init_fn`` argument of
    ``torch.utils.data.DataLoader``.

    Parameters
    ----------
    worker_id
        Unused; present because ``worker_init_fn`` is called with the worker index.
    """
    # numpy only accepts seeds below 2**32, so the torch seed is folded into that range.
    derived_seed = torch.initial_seed() % (2**32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)
def set_torch(seed=0):
    """
    Seed pytorch and CUDA, export the related environment variables, and make every
    ``torch.utils.data.DataLoader`` use :func:`seed_worker` as its ``worker_init_fn``.

    Parameters
    ----------
    seed
        The random seed.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    if torch.cuda.is_available():
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    loader_cls = reload_module("torch.utils.data").DataLoader
    # An ``__init__`` named "_method" is taken as the sign that it was already wrapped by ``partialmethod``
    # (presumably the name functools gives to the bound wrapper), so it is not wrapped a second time.
    if loader_cls.__init__.__name__ == "_method":
        return
    # Actually, setting generator improves reproducibility, but torch._C.Generator does not support pickling.
    # https://pytorch.org/docs/stable/notes/randomness.html
    # https://github.com/pytorch/pytorch/issues/43672
    loader_cls.__init__ = partialmethod(loader_cls.__init__, worker_init_fn=seed_worker)
def metric_sklearn(y_true: np.ndarray, y_pred: np.ndarray, metric: str) -> float:
    """
    Calculate metrics using ``sklearn`` APIs. The format of ``y_true`` and ``y_pred`` should follow the requirement
    of ``metric`` (See https://scikit-learn.org/stable/modules/model_evaluation.html), so we recommend using
    :func:`auto_metric_sklearn` to automatically deal with different metrics.

    Parameters
    ----------
    y_true
        An array of ground truth values. A 2d array whose second dimension is 1 is flattened.
    y_pred
        An array of predictions. A 2d array whose second dimension is 1 is flattened.
    metric
        Use ``tabensemb.utils.utils.REGRESSION_METRICS``, ``tabensemb.utils.utils.BINARY_METRICS``, and
        ``tabensemb.utils.utils.MULTICLASS_METRICS`` to check all available metrics for regression, binary, and
        multiclass tasks respectively.

    Returns
    -------
    float
        The metric. If ``y_pred`` contains non-finite values and the global setting ``warn_nan_metric`` is True,
        100 is returned instead.

    Raises
    ------
    Exception
        If ``y_pred`` contains non-finite values and ``warn_nan_metric`` is False, or if ``metric`` is unknown.

    See Also
    --------
    :func:`auto_metric_sklearn`
    """
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    if len(y_true.shape) == 2 and y_true.shape[-1] == 1:
        y_true = y_true.flatten()
    if len(y_pred.shape) == 2 and y_pred.shape[-1] == 1:
        # Flatten the prediction itself (the ground truth was previously flattened here by mistake).
        y_pred = y_pred.flatten()
    if not np.all(np.isfinite(y_pred)):
        if tabensemb.setting["warn_nan_metric"]:
            warnings.warn(
                "NaNs exist in the tested prediction. A large value (100) is returned instead. "
                "To disable this and raise an Exception, turn the global setting `warn_nan_metric` to False."
            )
            return 100
        else:
            raise Exception(
                "NaNs exist in the tested prediction. To ignore this and return a large value (100) instead, turn "
                "the global setting `warn_nan_metric` to True"
            )
    # Metric name -> callable taking (y_true, y_pred).
    mapping = {
        "mse": mean_squared_error,
        "mae": mean_absolute_error,
        "mape": mean_absolute_percentage_error,
        "r2": r2_score,
        "rmse": lambda y_true, y_pred: np.sqrt(mean_squared_error(y_true, y_pred)),
        "r2_score": r2_score,
        "median_absolute_error": median_absolute_error,
        "max_error": max_error,
        "mean_absolute_error": mean_absolute_error,
        "mean_squared_error": mean_squared_error,
        "mean_squared_log_error": mean_squared_log_error,
        "mean_poisson_deviance": mean_poisson_deviance,
        "mean_gamma_deviance": mean_gamma_deviance,
        "mean_pinball_loss": mean_pinball_loss,
        "accuracy_score": accuracy_score,
        "top_k_accuracy_score": top_k_accuracy_score,
        "f1_score": f1_score,
        "roc_auc_score": roc_auc_score,
        "average_precision_score": average_precision_score,
        "precision_score": partial(precision_score, zero_division=0),
        "recall_score": partial(recall_score, zero_division=0),
        "log_loss": partial(log_loss),
        "balanced_accuracy_score": balanced_accuracy_score,
        "explained_variance_score": explained_variance_score,
        "brier_score_loss": brier_score_loss,
        "jaccard_score": jaccard_score,
        "mean_absolute_percentage_error": mean_absolute_percentage_error,
        "cohen_kappa_score": cohen_kappa_score,
        "hamming_loss": hamming_loss,
        "matthews_corrcoef": matthews_corrcoef,
        "zero_one_loss": zero_one_loss,
        "precision_score_macro": partial(
            precision_score, average="macro", zero_division=0
        ),
        "precision_score_micro": partial(
            precision_score, average="micro", zero_division=0
        ),
        "precision_score_weighted": partial(
            precision_score, average="weighted", zero_division=0
        ),
        "recall_score_macro": partial(recall_score, average="macro", zero_division=0),
        "recall_score_micro": partial(recall_score, average="micro", zero_division=0),
        "recall_score_weighted": partial(
            recall_score, average="weighted", zero_division=0
        ),
        "f1_score_macro": partial(f1_score, average="macro"),
        "f1_score_micro": partial(f1_score, average="micro"),
        "f1_score_weighted": partial(f1_score, average="weighted"),
        "jaccard_score_macro": partial(jaccard_score, average="macro"),
        "jaccard_score_micro": partial(jaccard_score, average="micro"),
        "jaccard_score_weighted": partial(jaccard_score, average="weighted"),
        "roc_auc_score_ovr_macro": partial(
            roc_auc_score, average="macro", multi_class="ovr"
        ),
        "roc_auc_score_ovr_weighted": partial(
            roc_auc_score, average="weighted", multi_class="ovr"
        ),
        "roc_auc_score_ovo": partial(
            roc_auc_score, average="weighted", multi_class="ovo"
        ),
    }
    if metric in mapping:
        return mapping[metric](y_true, y_pred)
    elif metric == "rmse_conserv":
        # Only samples whose prediction exceeds the ground truth contribute to this metric.
        # NOTE(review): despite the name, the mean squared error (no square root) is returned - confirm intent.
        y_pred = np.array(cp(y_pred)).reshape(-1, 1)
        y_true = np.array(cp(y_true)).reshape(-1, 1)
        where_not_conserv = y_pred > y_true
        if np.any(where_not_conserv):
            return mean_squared_error(
                y_true[where_not_conserv], y_pred[where_not_conserv]
            )
        else:
            return 0.0
    else:
        raise Exception(f"Metric {metric} not implemented.")
def convert_proba_to_target(y_pred: np.ndarray, task) -> np.ndarray:
    """
    Turn predicted class probabilities into predicted classes.

    Parameters
    ----------
    y_pred
        An array of predicted probabilities. For binary, it should be the probability of the positive class.
    task
        "multiclass" or "binary".

    Returns
    -------
    np.ndarray
        The class of each sample. 2d array (the second dimension is 1) for multiclass tasks. 0-1 array (1d or 2d
        depending on the input ``y_pred``) for binary tasks.

    Raises
    ------
    Exception
        For "regression" or any unrecognized ``task``.
    """
    if task == "multiclass":
        # The most probable class along the last axis, returned as a column vector.
        winners = np.argmax(y_pred, axis=-1)
        return winners.reshape(-1, 1)
    if task == "binary":
        # Positive class when its probability is strictly above 0.5.
        is_positive = y_pred > 0.5
        return is_positive.astype(int)
    if task == "regression":
        raise Exception("Not supported for regressions tasks.")
    raise Exception(f"Unrecognized task {task}.")
def convert_target_to_indicator(y_pred: np.ndarray, n_classes: int) -> np.ndarray:
    """
    Convert the class of each sample to class indicator.

    Parameters
    ----------
    y_pred
        The class of each sample (not probabilities). It should be a 1d array or a 2d array whose second dimension
        is 1. Integer-valued labels stored with a non-integer dtype (e.g. ``1.0``) are accepted.
    n_classes
        The number of classes.

    Returns
    -------
    np.ndarray
        An array of (n_samples, n_classes) where, at each entry, 1 indicates that the sample belongs to this class.

    Raises
    ------
    ValueError
        If a label is not integer-valued.
    """
    classes = np.asarray(y_pred)
    n_samples = classes.shape[0]
    labels = classes.reshape(-1)
    if not np.issubdtype(labels.dtype, np.integer):
        # Float labels cannot be used as an index directly; cast them, but refuse to truncate silently.
        as_int = labels.astype(int)
        if not np.array_equal(as_int, labels):
            raise ValueError("Class labels must be integer-valued.")
        labels = as_int
    indicator = np.zeros((n_samples, n_classes))
    indicator[np.arange(n_samples), labels] = 1
    return indicator
# Metric names accepted by ``auto_metric_sklearn``, grouped by task and by the form of the input each one is
# given there (class targets, probabilities, or class indicators).

# Regression metrics. The commented-out names are kept from the original and are currently disabled.
REGRESSION_METRICS = [
    "rmse",
    "mse",
    "mae",
    "mape",
    "r2",
    # "mean_squared_log_error",
    "median_absolute_error",
    # "max_error",
    "explained_variance_score",
    # "mean_poisson_deviance",
    # "mean_gamma_deviance",
]

# Binary metrics evaluated on predicted classes (probabilities are converted to 0/1 first).
_BINARY_USE_TARGET_METRICS = [
    "f1_score",
    "precision_score",
    "recall_score",
    "jaccard_score",
    "accuracy_score",
    "balanced_accuracy_score",
    "cohen_kappa_score",
    "hamming_loss",
    "matthews_corrcoef",
    "zero_one_loss",
]

# Binary metrics evaluated directly on the probability of the positive class.
_BINARY_USE_PROB_METRICS = [
    "roc_auc_score",
    "log_loss",
    "brier_score_loss",
]

# Binary metrics evaluated on a class-indicator ground truth and two-column probabilities.
_BINARY_USE_INDICATOR_METRICS = ["average_precision_score"]

# All binary metrics.
BINARY_METRICS = (
    _BINARY_USE_TARGET_METRICS
    + _BINARY_USE_PROB_METRICS
    + _BINARY_USE_INDICATOR_METRICS
)

# Multiclass metrics evaluated on predicted classes (argmax of the probabilities).
_MULTICLASS_USE_TARGET_METRICS = [
    "accuracy_score",
    "balanced_accuracy_score",
    "cohen_kappa_score",
    "hamming_loss",
    "matthews_corrcoef",
    "zero_one_loss",
    "precision_score_macro",
    "precision_score_micro",
    "precision_score_weighted",
    "recall_score_macro",
    "recall_score_micro",
    "recall_score_weighted",
    "f1_score_macro",
    "f1_score_micro",
    "f1_score_weighted",
    "jaccard_score_macro",
    "jaccard_score_micro",
    "jaccard_score_weighted",
]

# Multiclass metrics evaluated directly on class probabilities.
_MULTICLASS_USE_PROB_METRICS = [
    "top_k_accuracy_score",
    "log_loss",
    "roc_auc_score_ovr_macro",
    "roc_auc_score_ovr_weighted",
    "roc_auc_score_ovo",
]

# All multiclass metrics.
MULTICLASS_METRICS = _MULTICLASS_USE_TARGET_METRICS + _MULTICLASS_USE_PROB_METRICS
def auto_metric_sklearn(
    y_true: np.ndarray, y_pred: np.ndarray, metric: str, task: str
) -> float:
    """
    Calculate metrics using ``sklearn`` APIs. It automatically deals with different requirements of input shapes for
    different metrics.

    Parameters
    ----------
    y_true
        An array of ground truth values. For classification, it should be the class of each sample. It can be 1d or
        2d (the second dimension is 1) for classification tasks.
    y_pred
        An array of predictions. For classification, it should be the probabilities of classes. It can be 1d or 2d
        (the second dimension is 1) for binary classification tasks.
    metric
        Use ``tabensemb.utils.utils.REGRESSION_METRICS``, ``tabensemb.utils.utils.BINARY_METRICS``, and
        ``tabensemb.utils.utils.MULTICLASS_METRICS`` to check all available metrics for regression, binary, and
        multiclass tasks respectively.
    task
        "regression", "multiclass", or "binary".

    Returns
    -------
    float
        The metric.

    Raises
    ------
    Exception
        If ``task`` is not supported or the shape of ``y_true``/``y_pred`` does not match the task.
    NotImplementedError
        If ``metric`` is not available for a classification ``task``.
    """
    if task not in ["binary", "multiclass", "regression"]:
        raise Exception(f"Task {task} does not support auto metrics.")
    # Accept array-likes (lists, etc.) as well as arrays; the shape checks below need ``.shape``.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if task in ["multiclass", "binary"] and not (
        len(y_true.shape) == 1 or (len(y_true.shape) == 2 and y_true.shape[1] == 1)
    ):
        raise Exception(
            f"Expecting a 1d or 2d (the second dimension is 1) y_true, but got y_true with shape {y_true.shape}."
        )
    if task == "binary" and not (
        len(y_pred.shape) == 1 or (len(y_pred.shape) == 2 and y_pred.shape[1] == 1)
    ):
        raise Exception(
            f"Expecting the probability of the positive class, but got y_pred with shape {y_pred.shape}."
        )
    if task == "binary":
        y_pred = y_pred.flatten()
    # For classification tasks, y_pred is proba, y_true is an integer array
    if task == "regression":
        return metric_sklearn(y_true, y_pred, metric)
    elif task == "binary":
        if metric in _BINARY_USE_TARGET_METRICS:
            return metric_sklearn(
                y_true, convert_proba_to_target(y_pred, "binary"), metric
            )
        elif metric in _BINARY_USE_PROB_METRICS:
            return metric_sklearn(y_true, y_pred, metric)
        elif metric in _BINARY_USE_INDICATOR_METRICS:
            # Build (n_samples, 2) probabilities: column 0 is the negative class, column 1 the positive class.
            y_pred_extend = y_pred.reshape(-1, 1)
            y_pred_2d = np.concatenate([1 - y_pred_extend, y_pred_extend], axis=-1)
            # A binary task always has two classes, even when ``y_true`` happens to contain only one of them.
            # Counting unique values would give a one-column indicator that does not match ``y_pred_2d``.
            n_classes = 2
            y_true_indicator = convert_target_to_indicator(y_true, n_classes)
            return metric_sklearn(y_true_indicator, y_pred_2d, metric)
        else:
            raise NotImplementedError(
                f"Metric {metric} is not available for binary tasks."
            )
    elif task == "multiclass":
        if metric in _MULTICLASS_USE_TARGET_METRICS:
            return metric_sklearn(
                y_true, convert_proba_to_target(y_pred, task="multiclass"), metric
            )
        elif metric in _MULTICLASS_USE_PROB_METRICS:
            return metric_sklearn(y_true, y_pred, metric)
        else:
            raise NotImplementedError(
                f"Metric {metric} is not available for multiclass tasks."
            )
def str_to_dataframe(s, sep=",", names=None, check_nan_on=None) -> pd.DataFrame:
    """
    Convert a .csv type of string to a dataframe.

    Parameters
    ----------
    s
        A .csv type of string.
    sep
        The delimiter.
    names
        Column labels.
    check_nan_on
        Numerical column labels to detect invalid values and replace them with ``np.nan``. Only columns that were
        read as text are inspected; an entry is invalid if, after removing ".", it is not made of digits only.
        Valid entries are left untouched.

    Returns
    -------
    pd.DataFrame
        The converted dataframe.

    Raises
    ------
    Exception
        If ``names`` is given and the parsed result suggests that ``sep`` does not match the string.
    """

    def _is_text_dtype(dtype) -> bool:
        # Text columns are ``object`` in older pandas and may be a dedicated string dtype in newer versions.
        return pd.api.types.is_object_dtype(dtype) or pd.api.types.is_string_dtype(
            dtype
        )

    def _is_invalid_number(value) -> bool:
        if isinstance(value, str):
            return not value.replace(".", "").isnumeric()
        # Non-string entries (e.g. the float NaN that read_csv produces for an empty cell) are invalid only if
        # they are missing.
        return bool(pd.isna(value))

    df = pd.read_csv(StringIO(s), names=names, sep=sep)
    if names is not None:
        # With a wrong delimiter, every line ends up in the first column as text and the remaining columns are
        # entirely NaN. This heuristic needs more than one requested column to be meaningful.
        if len(df.columns) != len(names) or (
            len(names) > 1
            and _is_text_dtype(df.dtypes[names[0]])
            and pd.isna(df[names[1:]]).all().all()
        ):
            raise Exception(
                f"pd.read_csv can not handle the delimiters. Consider specifying `sep`."
            )
    if check_nan_on is not None:
        object_features = [
            feature
            for feature, dtype in df[check_nan_on].dtypes.items()
            if _is_text_dtype(dtype)
        ]
        if len(object_features) > 0:
            print(
                f"Unknown values are detected in {list(object_features)}. They will be treated as np.nan."
            )
        for feature in object_features:
            is_nan = np.array(
                [_is_invalid_number(value) for value in df[feature]], dtype=bool
            )
            df.loc[is_nan, feature] = np.nan
    return df
def get_figsize(n, max_col, width_per_item, height_per_item, max_width):
    """
    Work out the ``figsize`` and the grid layout of a ``matplotlib`` figure made of ``n`` subplots.

    Parameters
    ----------
    n
        The number of subplots.
    max_col
        The maximum number of columns.
    width_per_item
        The width of each column if only one row is needed.
    height_per_item
        The height of each row.
    max_width
        The width of the figure if multiple rows are needed.

    Returns
    -------
    tuple
        The ``figsize`` argument of ``matplotlib``
    int
        The number of columns of the figure
    int
        The number of rows of the figure
    """
    if n <= max_col:
        # Everything fits in a single row.
        return (width_per_item * n, height_per_item), n, 1
    # Several rows are needed; a partially filled last row still counts as a row.
    full_rows, leftover = divmod(n, max_col)
    n_rows = full_rows if leftover == 0 else full_rows + 1
    return (max_width, height_per_item * n_rows), max_col, n_rows
def check_stream():
    """
    A utility of :func:`HiddenPrints`.

    Returns
    -------
    bool
        True only if both ``sys.stdout`` and ``sys.stderr`` are instances of ``tabensemb.Stream``.
    """
    return all(
        isinstance(stream, tabensemb.Stream) for stream in (sys.stdout, sys.stderr)
    )
class HiddenPrints:
    """
    A context manager that can temporarily hide all ``sys.stdout`` outputs and ``logging`` outputs. It works
    better when ``sys.stdout`` is not changed after ``tabensemb`` is imported.
    """

    def __init__(self, disable_logging: bool = True, disable_std: bool = True):
        """
        Parameters
        ----------
        disable_logging
            Hide ``logging`` outputs
        disable_std
            Hide ``sys.stdout`` outputs
        """
        self.disable_logging = disable_logging
        self.disable_std = disable_std
        # Which redirection was applied on enter, and the devnull handle that this instance owns.
        self._use_stream = False
        self._null_stream = None

    def __enter__(self):
        if self.disable_std:
            # Remember the branch taken here so that ``__exit__`` undoes exactly what was done, even if
            # ``sys.stdout``/``sys.stderr`` are replaced inside the ``with`` block.
            self._use_stream = check_stream()
            self._null_stream = open(os.devnull, "w")
            if self._use_stream:
                self._stream = tabensemb.stdout_stream.stream
                tabensemb.stdout_stream.set_stream(self._null_stream)
                self._path = tabensemb.stdout_stream.path
                tabensemb.stdout_stream.set_path(None)
            else:
                self._original_stdout = sys.stdout
                sys.stdout = self._null_stream
        if self.disable_logging:
            self.logging_state = logging.root.manager.disable
            logging.disable(logging.CRITICAL)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.disable_std:
            if self._use_stream:
                tabensemb.stdout_stream.set_stream(self._stream)
                tabensemb.stdout_stream.set_path(self._path)
            else:
                sys.stdout = self._original_stdout
            # Close the handle opened in ``__enter__`` rather than whatever ``sys.stdout`` currently is.
            self._null_stream.close()
            self._null_stream = None
        if self.disable_logging:
            logging.disable(self.logging_state)
class PlainText:
    """
    A context manager that can temporarily redirect all ``sys.stderr`` outputs to ``sys.stdout``. It works
    better when ``sys.stdout`` and ``sys.stderr`` are not changed after ``tabensemb`` is imported.
    """

    def __init__(self, disable=False):
        """
        Parameters
        ----------
        disable
            If True, entering and leaving the context does nothing.
        """
        self.disable = disable

    def __enter__(self):
        if self.disable:
            return
        if not check_stream():
            # Plain interpreter streams: swap ``sys.stderr`` directly.
            self._original_stderr = sys.stderr
            sys.stderr = sys.stdout
            return
        # tabensemb-managed streams: ask the stderr stream to write to "stdout".
        self._stream = tabensemb.stderr_stream.stream
        tabensemb.stderr_stream.set_stream("stdout")

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.disable:
            return
        if not check_stream():
            sys.stderr = self._original_stderr
            return
        tabensemb.stderr_stream.set_stream(self._stream)
class global_setting:
    """
    A context manager that temporarily changes the global setting ``tabensemb.setting``.
    """

    def __init__(self, setting: Dict):
        """
        Parameters
        ----------
        setting
            The entries to apply to ``tabensemb.setting`` while the context is active.
        """
        self.setting = setting
        self.original = None

    def __enter__(self):
        # Shallow snapshot of the setting before the temporary entries are applied.
        self.original = tabensemb.setting.copy()
        tabensemb.setting.update(self.setting)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Keys introduced by this context that did not exist before must be removed; ``update`` alone would
        # leave them behind.
        for key in self.setting:
            if key not in self.original:
                tabensemb.setting.pop(key, None)
        tabensemb.setting.update(self.original)
class HiddenPltShow:
    """
    A context manager that temporarily turns ``matplotlib.pyplot.show()`` into a no-op.
    """

    def __init__(self):
        pass

    def __enter__(self):
        # Keep the real function so that it can be restored on exit.
        self.original = plt.show

        def nullfunc(*args, **kwargs):
            return None

        plt.show = nullfunc

    def __exit__(self, exc_type, exc_val, exc_tb):
        plt.show = self.original
def reload_module(name):
    """
    Import the module, or reload it if it has already been imported.

    Parameters
    ----------
    name
        The name of the module

    Returns
    -------
    module
        The imported or reloaded module.
    """
    if name in sys.modules:
        return reload(sys.modules.get(name))
    return import_module(name)
class TqdmController:
    """
    A controller of ``tqdm`` progress bars, including ``tqdm.tqdm``, ``tqdm.notebook.tqdm``, and
    ``tqdm.auto.tqdm``.
    """

    # Modules whose ``tqdm`` class is patched.
    _MODULES = ("tqdm", "tqdm.notebook", "tqdm.auto")

    def __init__(self):
        # Module name -> the ``__init__`` that was in place before it was patched.
        self.original_init = {}
        self.disabled = False

    def disable_tqdm(self):
        """
        Patch ``__init__`` of every controlled ``tqdm`` class so that bars are created with ``disable=True``.
        Calling it again while already disabled does nothing.
        """
        if self.disabled:
            # Patching twice would store the patched ``__init__`` as the "original", so that
            # :meth:`enable_tqdm` could no longer restore the real one.
            return
        for name in self._MODULES:
            tq = reload_module(name).tqdm
            self.original_init[name] = tq.__init__
            tq.__init__ = partialmethod(tq.__init__, disable=True)
        self.disabled = True

    def enable_tqdm(self):
        """
        Restore the ``__init__`` saved by :meth:`disable_tqdm`. Does nothing if progress bars are not disabled.
        """
        if not self.disabled:
            return
        for name in self._MODULES:
            tq = reload_module(name).tqdm
            tq.__init__ = self.original_init[name]
        self.disabled = False
def debugger_is_active() -> bool:
    """
    Tell whether a trace function is installed, which is taken as the sign of an active debugger.

    Returns
    -------
    bool
        True if ``sys.gettrace`` exists and returns something other than None.
    """
    gettrace = getattr(sys, "gettrace", None)
    if gettrace is None:
        return False
    return gettrace() is not None
def gini(x: np.ndarray, w: np.ndarray = None) -> float:
    """
    Calculate the gini index of a feature. Non-finite values of ``x`` (and their weights) are ignored.
    https://stackoverflow.com/questions/48999542/more-efficient-weighted-gini-coefficient-in-python

    Parameters
    ----------
    x
        The values of a feature (array-like).
    w
        The weights of samples (array-like with the same length as ``x``), or None for equal weights.

    Returns
    -------
    float
        The gini index of the feature. ``np.nan`` if fewer than two distinct finite values remain.
    """
    x = np.asarray(x)
    finite = np.isfinite(x)
    if w is not None:
        # Convert before masking so that plain lists of weights are accepted too.
        w = np.asarray(w)[finite]
    x = x[finite]
    # A constant (or empty) feature has no defined gini index.
    if len(np.unique(x)) <= 1:
        return np.nan
    if w is not None:
        sorted_indices = np.argsort(x)
        sorted_x = x[sorted_indices]
        sorted_w = w[sorted_indices]
        # Force float dtype to avoid overflows
        cumw = np.cumsum(sorted_w, dtype=float)
        cumxw = np.cumsum(sorted_x * sorted_w, dtype=float)
        return np.sum(cumxw[1:] * cumw[:-1] - cumxw[:-1] * cumw[1:]) / (
            cumxw[-1] * cumw[-1]
        )
    else:
        sorted_x = np.sort(x)
        n = len(x)
        cumx = np.cumsum(sorted_x, dtype=float)
        # The above formula, with all weights equal to 1 simplifies to:
        return (n + 1 - 2 * np.sum(cumx) / cumx[-1]) / n
def pretty(value, htchar="\t", lfchar="\n", indent=0):
    """
    Format a (possibly nested) dictionary, list, or tuple as a multi-line string. Anything else is rendered
    with ``repr``.
    https://stackoverflow.com/questions/3229419/how-to-pretty-print-nested-dictionaries

    Parameters
    ----------
    value
        A dictionary, a list, or a tuple to be formatted.
    htchar
        The string for indents.
    lfchar
        The string between two lines.
    indent
        The number of indents.

    Returns
    -------
    str
        The formatted representation of ``value``.
    """
    deeper = indent + 1
    if isinstance(value, dict):
        opener, closer = "{", "}"
        pieces = [
            repr(key) + ": " + pretty(value[key], htchar, lfchar, deeper)
            for key in value
        ]
    elif isinstance(value, list):
        opener, closer = "[", "]"
        pieces = [pretty(item, htchar, lfchar, deeper) for item in value]
    elif isinstance(value, tuple):
        opener, closer = "(", ")"
        pieces = [pretty(item, htchar, lfchar, deeper) for item in value]
    else:
        return repr(value)
    # Every child starts on its own line, one level deeper than the container.
    child_prefix = lfchar + htchar * deeper
    body = ",".join(child_prefix + piece for piece in pieces)
    # The closing bracket goes on a new line at the container's own level.
    return opener + body + lfchar + htchar * indent + closer
def update_defaults_by_kwargs(defaults: Dict = None, kwargs: Dict = None):
    """
    Overwrite default values by user-specified ones.

    Parameters
    ----------
    defaults
        The default values. This dictionary is updated in place. An empty dictionary is used if None.
    kwargs
        The values that take precedence over ``defaults``. Ignored if None.

    Returns
    -------
    dict
        ``defaults`` (the same object) after the update.
    """
    if defaults is None:
        defaults = {}
    if kwargs is not None:
        defaults.update(kwargs)
    return defaults
class Logger:
    """
    A stream wrapper that forwards every message to the wrapped stream and also appends it to a log file.
    It works as a utility of :class:`Logging`.
    https://stackoverflow.com/questions/4675728/redirect-stdout-to-a-file-in-python
    """

    def __init__(self, path, stream):
        """
        Parameters
        ----------
        path
            The path to the log file.
        stream
            The stream that still receives every message.
        """
        self.path = path
        self.terminal = stream

    def write(self, message):
        """
        Write ``message`` to the wrapped stream, then append it (UTF-8 encoded) to the log file.
        """
        self.terminal.write(message)
        encoded = message.encode("utf-8")
        # The file is reopened for each message, so nothing is kept open between writes.
        with open(self.path, "ab") as log_file:
            log_file.write(encoded)

    def __getattr__(self, attr):
        # Anything not defined here (``flush``, ``isatty``, ...) is delegated to the wrapped stream.
        return getattr(self.terminal, attr)
class Logging:
    """
    Capture all outputs to a log file while still printing it.
    """

    def enter(self, path):
        """
        Start copying ``sys.stdout`` and ``sys.stderr`` outputs to the file at ``path``.
        """
        if check_stream():
            # tabensemb-managed streams log to a file by themselves once a path is set.
            for stream in (tabensemb.stdout_stream, tabensemb.stderr_stream):
                stream.set_path(path)
            return
        # Otherwise wrap the current streams and keep the originals for :meth:`exit`.
        self._stdout, self._stderr = sys.stdout, sys.stderr
        self.out_logger = Logger(path, self._stdout)
        self.err_logger = Logger(path, self._stderr)
        sys.stdout = self.out_logger
        sys.stderr = self.err_logger

    def exit(self):
        """
        Stop copying outputs to the log file.
        """
        if check_stream():
            for stream in (tabensemb.stdout_stream, tabensemb.stderr_stream):
                stream.set_path(None)
            return
        sys.stdout = self._stdout
        sys.stderr = self._stderr
def add_postfix(path):
    """
    Find a path that does not exist yet by appending a postfix ``f"-I{n}"`` to the input path. If the candidate
    already ends with the previous postfix, that postfix is replaced by the next one instead of stacking.

    Parameters
    ----------
    path
        A path to a folder or a file that will be created. A path without an extension is treated as a folder.

    Returns
    -------
    str
        A path that can be created without conflict.
    """
    counter = itertools.count()
    candidate = cp(path)
    stem, suffix = os.path.splitext(candidate)
    treat_as_folder = len(suffix) == 0
    current = next(counter)

    def _is_taken(target):
        # Folders conflict with anything that exists; files only conflict with existing files.
        return os.path.exists(target) if treat_as_folder else os.path.isfile(target)

    while _is_taken(candidate):
        head, tail = os.path.split(stem)
        old_tag = f"-I{current}"
        current = next(counter)
        new_tag = f"-I{current}"
        if tail.endswith(old_tag):
            # Swap the trailing postfix for the next one.
            tail = tail[: -len(old_tag)] + new_tag
        else:
            tail += new_tag
        candidate = os.path.join(head, tail) + suffix
        stem, suffix = os.path.splitext(candidate)
    return candidate
def safe_mkdir(path: os.PathLike):
    """
    Create a new directory, resolving name conflicts by retrying with the path given by :func:`add_postfix`.
    When multiple tasks are executed simultaneously, this is extremely useful even when ``os.path.exists`` is
    used, because the directory may appear between the check and the creation.

    Parameters
    ----------
    path
        The intended path

    Returns
    -------
    str
        The actual made path
    """
    created = False
    while not created:
        try:
            os.mkdir(path)
        except FileExistsError:
            # Someone else owns this name; move on to the next free candidate.
            path = add_postfix(path)
        else:
            created = True
    return path
class torch_with_grad(_DecoratorContextManager):
    """
    A context manager that enables gradient calculation inside its block and restores the previous mode on
    exit. This is an inverse version of torch.no_grad
    """

    def __init__(self) -> None:
        scripting = torch._jit_internal.is_scripting()
        if not scripting:
            super().__init__()
        # Overwritten with the actual grad mode when the context is entered.
        self.prev = False

    def __enter__(self) -> None:
        was_enabled = torch.is_grad_enabled()
        self.prev = was_enabled
        torch.set_grad_enabled(True)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)
class PickleAbleGenerator:
    """
    Turn a generator (not pickle-able) into a pickle-able object by extracting all items in the generator to a
    list.
    """

    def __init__(self, generator, max_generate=10000, inf=False):
        """
        Parameters
        ----------
        generator
            The generator to be exhausted. Exceptions other than ``StopIteration`` raised by it are propagated.
        max_generate
            The maximum number of items to extract.
        inf
            If True, silently keep only the first ``max_generate`` items of a longer generator.

        Raises
        ------
        Exception
            If ``inf`` is False and the generator yields more than ``max_generate`` items.
        """
        self.ls = []
        self.state = 0
        for _ in range(max_generate):
            try:
                self.ls.append(next(generator))
            except StopIteration:
                break
        else:
            if not inf:
                # The limit was reached. Probe one more item to tell "exactly ``max_generate`` items" (fine)
                # from "more than ``max_generate`` items" (an error).
                try:
                    next(generator)
                except StopIteration:
                    pass
                else:
                    raise Exception(
                        f"The generator {generator} generates more than {max_generate} values. Set inf=True if you "
                        f"accept that only {max_generate} can be pickled."
                    )

    def __iter__(self):
        # Iterator protocol, so that the object can be used in ``for`` loops like the generator it replaces.
        return self

    def __next__(self):
        if self.state >= len(self.ls):
            raise StopIteration
        val = self.ls[self.state]
        self.state += 1
        return val

    def __getstate__(self):
        return {"state": self.state, "ls": self.ls}

    def __setstate__(self, state):
        self.state = state["state"]
        self.ls = state["ls"]