import argparse
import copy
import importlib.machinery
import json
import os.path
import re
import ssl
import types
import urllib.parse
import urllib.request
import warnings
import zipfile
from typing import Dict, Union, List

import numpy as np

import tabensemb
from tabensemb.utils import pretty, str_to_dataframe
from .default import cfg as default_cfg
# (Sphinx "[docs]" source-link marker left over from HTML extraction; not part of the program.)
class UserConfig(dict):
    """
    The configuration holder for :class:`~tabensemb.data.datamodule.DataModule` and :class:`~tabensemb.trainer.Trainer`.

    It is a ``dict`` pre-populated with the default configuration. Reading the keys ``"feature_types"`` and
    ``"unique_feature_types"`` is handled specially, see :meth:`__getitem__`.
    """

    def __init__(self, path: str = None):
        """
        Parameters
        ----------
        path
            Path to the configuration file. See :meth:`from_file`.
        """
        super(UserConfig, self).__init__()
        # Deep copies are used because ``__getitem__("feature_types")`` fills the stored dictionary in place;
        # with a shallow copy that write would leak into the module-level defaults and into other instances.
        self.update(copy.deepcopy(default_cfg))
        self._defaults = copy.deepcopy(default_cfg)
        if path is not None:
            self.merge(self.from_file(path))

    def defaults(self):
        """
        The default values in ``tabensemb.config.default.py``

        Returns
        -------
        dict
            A dictionary of default values. It is an independent copy, so modifying it does not affect the
            configuration.
        """
        return copy.deepcopy(self._defaults)

    def merge(self, d: Dict):
        """
        Similar to :meth:`dict.update`, but will ignore values that are None.

        Parameters
        ----------
        d
            The dictionary used to update the configuration. It is not modified.
        """
        # Build a filtered dictionary instead of deleting keys from the caller's object.
        not_none = {key: val for key, val in d.items() if val is not None}
        super(UserConfig, self).update(not_none)

    @staticmethod
    def parse() -> Dict:
        """
        Try to parse the configuration using ``argparse``.

        One command line option is registered for every default entry whose value is a ``str``, ``int``, ``float``,
        ``list`` or ``bool``. Boolean entries get a ``--key``/``--no-key`` pair. ``--base`` is always required.
        Unknown arguments are ignored.

        Returns
        -------
        dict
            The parsed configuration dictionary.
        """
        base_config = UserConfig()
        parser = argparse.ArgumentParser()
        parser.add_argument("--base", required=True)
        for key in base_config.keys():
            value = base_config[key]
            # Exact type comparison on purpose: ``bool`` is a subclass of ``int`` and must not hit the first branch.
            value_type = type(value)
            if value_type in (str, int, float):
                parser.add_argument(f"--{key}", type=value_type, required=False)
            elif value_type is list:
                parser.add_argument(
                    f"--{key}",
                    nargs="+",
                    # The element type is inferred from the first default element when there is one.
                    type=type(value[0]) if len(value) > 0 else None,
                    required=False,
                )
            elif value_type is bool:
                parser.add_argument(f"--{key}", dest=key, action="store_true")
                parser.add_argument(f"--no-{key}", dest=key, action="store_false")
                parser.set_defaults(**{key: value})
        parse_res = parser.parse_known_args()[0].__dict__
        return parse_res

    @staticmethod
    def from_parser() -> Dict:
        """
        Try to parse the configuration using ``argparse`` and merge it into defaults.

        Returns
        -------
        dict
            The parsed configuration dictionary.
        """
        d = UserConfig.parse()
        return UserConfig.from_dict(d)

    @staticmethod
    def from_dict(cfg: Dict) -> "UserConfig":
        """
        Merge the input dictionary into defaults.

        Parameters
        ----------
        cfg
            The dictionary used to update the default configuration.

        Returns
        -------
        UserConfig
            The combined configuration.
        """
        tmp_cfg = UserConfig()
        tmp_cfg.merge(cfg)
        return tmp_cfg

    @staticmethod
    def from_file(path: str) -> "UserConfig":
        """
        Merge the .py or .json file into defaults. If no suffix is given, it will search the current directory and
        ``tabensemb.setting["default_config_path"]`` for a matched file. In a legal .py file, there should be a
        dictionary named "cfg".

        Parameters
        ----------
        path
            The path to the configuration file to update the default configuration with or without a suffix
            (.py or .json).

        Returns
        -------
        UserConfig
            The combined configuration.

        Raises
        ------
        Exception
            If the file does not exist, or if no suffix is given and both a .json and a .py candidate exist.
        """
        # A path that contains "/" or names an existing file is used as given; anything else is looked up in
        # the default configuration directory.
        if "/" in path or os.path.isfile(path):
            file_path = path
        else:
            file_path = os.path.join(tabensemb.setting["default_config_path"], path)

        ty = UserConfig.file_type(file_path)
        if ty is None:
            json_path = file_path + ".json"
            py_path = file_path + ".py"
            is_json = os.path.isfile(json_path)
            is_py = os.path.isfile(py_path)
            if is_json and is_py:
                raise Exception(
                    f"Both {json_path} and {py_path} exist. Specify the full name of the file."
                )
            elif not is_json and not is_py:
                raise Exception(f"{file_path} does not exist.")
            else:
                file_path = json_path if is_json else py_path
                ty = UserConfig.file_type(file_path)
        else:
            if not os.path.isfile(file_path):
                raise Exception(f"{file_path} does not exist.")

        if ty == "json":
            with open(file_path, "r") as file:
                cfg = json.load(file)
        else:
            # NOTE(security): the .py file is executed as Python code. Only load configuration files from
            # trusted sources.
            loader = importlib.machinery.SourceFileLoader("cfg", file_path)
            mod = types.ModuleType(loader.name)
            loader.exec_module(mod)
            cfg = mod.cfg
        return UserConfig.from_dict(cfg)

    @staticmethod
    def from_uci(
        name: str,
        datafile_name: str = None,
        column_names: List[str] = None,
        save_zip: bool = False,
        max_retries=3,
        timeout=20,
        sep=",",
    ) -> Union["UserConfig", None]:
        """
        Search, download, and configure a dataset from https://archive.ics.uci.edu/. The dataset will be extracted and
        saved into a .csv file, and a corresponding UserConfig is returned. This function supports tabular datasets for
        "Classification" and "Regression". Integer features are treated as continuous features.

        Parameters
        ----------
        name
            The name of the dataset like "Heart Disease", "Iris", etc. The name will be searched on the website and be
            configured if there is a matched dataset.
        datafile_name
            The name of ".data" file in the downloaded .zip file. If is None and there exists more than one file with
            the suffix ".data" in a single dataset, the function will print available names.
        column_names
            Labels of columns in the ".data" file in the downloaded .zip file. If not given, names recorded on the
            website will be used. However, these names can be in a wrong order, of which "Auto MPG" is a typical
            example. So a warning will be logged, and `save_zip` will be set to True to let the user check the ".name"
            file in the .zip file for the correct order.
        save_zip
            Whether the downloaded .zip file should be stored.
        max_retries
            The maximum number of tries of ``urllib.request.urlopen``. At least one try is always made.
        timeout
            Waiting time of ``urllib.request.urlopen``.
        sep
            The delimiter of ``pd.read_csv``.

        Returns
        -------
        UserConfig
            The configuration of the dataset. If the dataset can not be automatically configured, None will be returned
            and the reason will be printed.

        Raises
        ------
        Exception
            If the website can not be reached or parsed, the dataset is not found, ``column_names`` contains unknown
            columns, or no label column remains.
        """
        # Extract information
        url = (
            f"https://archive.ics.uci.edu/datasets?skip=0&take=1&sort=desc&orderBy=Relevance&"
            f"search={urllib.parse.quote_plus(name)}"
        )
        # NOTE(security): certificate verification is disabled for the search request, so the response is not
        # authenticated. Kept as is to preserve the existing behavior; review before relying on it.
        ctx = ssl.create_default_context()
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        retries = 0
        while True:
            try:
                # The context manager closes the HTTP response once it has been read.
                with urllib.request.urlopen(url, context=ctx, timeout=timeout) as uh:
                    html = uh.read()
                break
            except Exception as e:
                retries += 1
                # ``>=`` (not ``==``) so that ``max_retries <= 0`` can not loop forever.
                if retries >= max_retries:
                    raise Exception(
                        f"max_retries reached. Check whether {url} is accessible."
                    ) from e
        # The search result is embedded in the page as an escaped JSON string inside a <script> tag.
        bodies = re.findall(r"\"body\":\"\[(.*?)\]\"\}</script>", html.decode())
        if len(bodies) == 0:
            raise Exception(
                f"Can not extract the search result from {url}. The layout of the website may have changed."
            )
        datasets: Dict = json.loads(bodies[0].encode().decode("unicode-escape"))[
            "result"
        ]["data"]["json"]["datasets"]
        if len(datasets) == 0:
            raise Exception(f"Dataset {name} not found.")
        dataset = datasets[0]
        if dataset["Name"] != name:
            raise Exception(f"Dataset {name} not found. Do you mean {dataset['Name']}?")

        # Download the dataset
        dataset_id = dataset["ID"]
        slug = dataset["slug"]
        link = f"https://archive.ics.uci.edu/static/public/{dataset_id}/{slug}.zip"
        zip_save_to = os.path.join(
            tabensemb.setting["default_data_path"], f"{name}.zip"
        )
        print(f"Downloading {link} to {zip_save_to}")
        os.makedirs(tabensemb.setting["default_data_path"], exist_ok=True)
        urllib.request.urlretrieve(link, filename=zip_save_to)

        # Check task and tabular
        _saved_to_suffix = f" The downloaded file is saved to {zip_save_to}."
        task = dataset["Task"]
        is_tabular = dataset["isTabular"]
        if task not in ["Regression", "Classification"]:
            print(f"Task {task} is not supported.{_saved_to_suffix}")
            return None
        if not is_tabular:
            print(f"The dataset {name} is not tabular.{_saved_to_suffix}")
            return None

        # Check contents. The archive is closed on every exit path, including the early returns below; the
        # .zip file itself is intentionally left on disk in those cases.
        data_suffix = ".data"
        with zipfile.ZipFile(zip_save_to, "r") as zipf:
            files = zipf.namelist()
            datafiles = [
                member[: -len(data_suffix)]
                for member in files
                if member.endswith(data_suffix)
            ]
            if len(datafiles) == 0:
                print(f"No file with suffix `.data` is found.{_saved_to_suffix}")
                return None
            if len(datafiles) > 1 and (
                datafile_name is None or datafile_name not in datafiles
            ):
                print(
                    f"Found multiple data files {datafiles}, but `datafile_name` is {datafile_name}.{_saved_to_suffix}"
                )
                return None
            test_datafiles = [member for member in files if "test" in member]
            if len(test_datafiles) > 0:
                warnings.warn(
                    f"There exists .test file(s) {test_datafiles} which should be used for final metrics. The .zip file is "
                    f"left for the user to process."
                )
                save_zip = True
            if datafile_name is None:
                datafile_name = datafiles[0]
            datafile = zipf.read(datafile_name + data_suffix)

        # Extract feature information.
        all_features = []
        cont_feature_names = []
        cat_feature_names = []
        label_name = []
        for attr in dataset["variables"]:
            attr_name = attr["name"]
            if attr["role"] == "Feature":
                if attr["type"] == "Continuous":
                    cont_feature_names.append(attr_name)
                elif attr["type"] == "Integer":
                    cont_feature_names.append(attr_name)
                    print(
                        f"{attr_name} is Integer and will be treated as a continuous feature."
                    )
                elif attr["type"] in ["Categorical", "Binary"]:
                    cat_feature_names.append(attr_name)
            elif attr["role"] == "Target":
                label_name.append(attr_name)
            all_features.append(attr_name)

        # Load and save as .csv
        if column_names is None:
            warnings.warn(
                "`column_names` is not given. The order of columns will be loaded from the website. It is highly "
                "recommended to manually set column names. The downloaded .zip is saved. Please check its .name file "
                "for the correct order."
            )
            save_zip = True
            column_names = all_features
        column_names_not_all_features = [
            x for x in column_names if x not in all_features
        ]
        if len(column_names_not_all_features) > 0:
            raise Exception(
                f"Available column names are {all_features}, but `column_names` has columns not available: "
                f"{column_names_not_all_features}."
            )
        all_features_not_column_names = [
            x for x in all_features if x not in column_names
        ]
        if len(all_features_not_column_names) > 0:
            warnings.warn(
                f"Available column names are {all_features}, but `column_names` does not have "
                f"{all_features_not_column_names}."
            )
            # Columns that the user left out can not be used as features or labels.
            cont_feature_names = [
                x for x in cont_feature_names if x not in all_features_not_column_names
            ]
            cat_feature_names = [
                x for x in cat_feature_names if x not in all_features_not_column_names
            ]
            original_label_names = label_name.copy()
            label_name = [
                x for x in label_name if x not in all_features_not_column_names
            ]
            if len(label_name) == 0:
                raise Exception(
                    f"No label is found. Did you miss the label names {original_label_names} in `column_names`?"
                )
        try:
            df = str_to_dataframe(
                datafile.decode(),
                sep=sep,
                names=column_names,
                check_nan_on=cont_feature_names,
            )
        except Exception as e:
            print(e)
            print(_saved_to_suffix)
            return None

        # Save csv. ``datafile_name`` is never None here because it was resolved while reading the archive.
        csv_name = datafile_name
        df.to_csv(
            os.path.join(tabensemb.setting["default_data_path"], f"{csv_name}.csv"),
            index=False,
        )
        if not save_zip:
            os.remove(zip_save_to)

        # Configurations
        if task == "Regression":
            inferred_task = "regression"
        elif len(np.unique(df[label_name].values)) <= 2:
            inferred_task = "binary"
        else:
            inferred_task = "multiclass"
        feature_types = {
            feature: "Continuous" if feature in cont_feature_names else "Categorical"
            for feature in cont_feature_names + cat_feature_names
        }
        cfg = UserConfig()
        cfg.merge(
            {
                "database": csv_name,
                "task": inferred_task,
                "feature_types": feature_types,
                "categorical_feature_names": cat_feature_names,
                "continuous_feature_names": cont_feature_names,
                "label_name": label_name,
            }
        )
        return cfg

    def to_file(self, path: str):
        """
        Save the configuration to a ``.py`` or ``.json`` file.

        Parameters
        ----------
        path
            The path to save the configuration. If no suffix is given, ``.py`` is added as the suffix.
        """
        if path.endswith(".json"):
            with open(path, "w") as f:
                json.dump(self, f, indent=4)
        else:
            if not path.endswith(".py"):
                path += ".py"
            s = "cfg = " + pretty(self, htchar=" " * 4, indent=0)
            # Formatting with ``black`` is best-effort: if it is missing or fails, the unformatted text is
            # written. ``Exception`` (not a bare ``except``) lets KeyboardInterrupt/SystemExit propagate.
            try:
                import black

                s = black.format_str(s, mode=black.Mode())
            except Exception:
                pass
            with open(path, "w") as f:
                f.write(s)

    @staticmethod
    def file_type(path: str) -> Union[str, None]:
        """
        Check the suffix of the path (json, py, or None).

        Parameters
        ----------
        path
            The path to check.

        Returns
        -------
        str or None
            ``"json"`` or ``"py"`` according to the suffix, or None if the path has neither suffix.
        """
        if path.endswith(".json"):
            return "json"
        elif path.endswith(".py"):
            return "py"
        else:
            return None

    def __getitem__(self, item):
        """
        Get a configuration value.

        ``"feature_types"`` returns the stored dictionary after adding, in place, every name in
        ``"continuous_feature_names"`` and ``"categorical_feature_names"`` that is missing from it (as
        ``"Continuous"`` and ``"Categorical"`` respectively). ``"unique_feature_types"`` is computed as the sorted
        list of distinct values of ``"feature_types"``. Any other key behaves as in ``dict``.
        """
        if item == "feature_types":
            val = super(UserConfig, self).__getitem__(item)
            # Existing entries win; only missing features get their default type.
            for cont in self["continuous_feature_names"]:
                val.setdefault(cont, "Continuous")
            for cat in self["categorical_feature_names"]:
                val.setdefault(cat, "Categorical")
            return val
        elif item == "unique_feature_types":
            return list(sorted(set(self["feature_types"].values())))
        else:
            return super(UserConfig, self).__getitem__(item)