from os.path import exists
import lightgbm
import pandas as pd
import numpy as np
import xgboost
import yaml
from sklearn.base import RegressorMixin
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer, KNNImputer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
from polynomial import Polynomial
from sklearn.utils import all_estimators
import sys
try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader


def get_card_split(df, cols, n=10):
    """
    Divide categorical columns into low- and high-cardinality groups.

    A column is "high cardinality" when its number of distinct non-null values
    (as reported by ``DataFrame.nunique``) is strictly greater than ``n``.

    :param df: DataFrame containing the columns to inspect
    :param cols: column labels to split (a pandas Index or any list-like)
    :param n: cardinality threshold; columns with more than ``n`` unique values are "high"
    :return: tuple ``(card_low, card_high)`` of pandas Index objects, in the order of ``cols``
    """
    is_high = df[cols].nunique() > n
    # Index through the Series' own index (the selected column labels) so plain
    # lists work as well as pd.Index; boolean-indexing a list would raise.
    labels = is_high.index
    mask = is_high.to_numpy(dtype=bool)
    card_high = labels[mask]
    card_low = labels[~mask]
    return card_low, card_high


# Container for user-supplied data and run configuration
class InputData:
    """
    Container for all user input data and parameters needed to run MLTool.

    On construction it loads the YAML parameter file, keeps only the first
    active validation method, reads the data file, builds the train/test split
    and collects the list of candidate regressors.
    """
    # Declared here for documentation; populated by __init__ / set_train_test.
    X: np.ndarray
    y: np.ndarray
    weight: np.ndarray
    X_train: np.ndarray
    X_test: np.ndarray
    y_train: np.ndarray
    y_test: np.ndarray
    # True until a user-defined (manual) split has been loaded from disk.
    split_flag = True
    user_param = None
    regress = None

    def read_file(self, file_path):
        """
        Read a CSV data file into feature, target and (optional) weight arrays.

        The last column is the target. When ``sample_weight`` is enabled in the
        parameters, the last column is the sample weight and the second-to-last
        column is the target.

        :param file_path: path (or file-like object) accepted by ``pandas.read_csv``
        :return: tuple ``(X, y, weight)``; ``weight`` is None when sample weights are disabled
        """
        data = pd.read_csv(file_path).to_numpy()
        if self.user_param.get('sample_weight'):
            # Features, target (second-to-last column), weight (last column).
            return data[:, :-2], data[:, -2], data[:, -1]
        return data[:, :-1], data[:, -1], None

    def __init__(self, parameters_yml: str):
        """
        Read the parameters YAML file, then load the data and prepare the split.

        :param parameters_yml: path to the parameters YAML file
        :raises ValueError: if no validation method is marked ``active``
        """
        with open(parameters_yml) as yml:
            # NOTE(review): yaml.load with the full (C)Loader can construct
            # arbitrary Python objects; only use it with trusted parameter files.
            self.user_param = yaml.load(yml, Loader=Loader)

        # Keep only the first validation method flagged as active.
        active_validation = {}
        for meth, para in self.user_param['validation'].items():
            if para and para.get('active'):
                active_validation[meth] = para
                break
        if not active_validation:
            raise ValueError("No active validation method found in " + str(parameters_yml))
        self.user_param['validation'] = active_validation

        self.X, self.y, self.weight = self.read_file(self.user_param['datafile'])
        self.set_train_test()

        self.regress = self.set_regress()

    def set_train_test(self):
        """
        Split the data into train and test sets using the active validation method.

        * ``user_defined``: read separate train/test CSV files (only once; later
          calls keep the already-loaded split).
        * ``standard``: random hold-out split via ``train_test_split`` with
          ``test_size`` taken from the ``split`` parameter.
        * any other method: train and test are both the full data set.

        :return: None
        :raises FileNotFoundError: if a user-defined train/test file does not exist
        """
        validation = self.user_param['validation']
        val_method = next(iter(validation))
        print("Validation Method: ", val_method)

        if 'user_defined' in validation:
            if not self.split_flag:
                return  # manual split already loaded on a previous call
            train_path = validation['user_defined']['train']
            test_path = validation['user_defined']['test']
            missing = [path for path in (train_path, test_path) if not exists(path)]
            if missing:
                raise FileNotFoundError(
                    "User-defined split file(s) not found: " + ", ".join(map(str, missing)))
            print("Split: ", "Manual split")
            self.X_train, self.y_train, self.weight = self.read_file(train_path)
            self.X_test, self.y_test, _ = self.read_file(test_path)
            self.split_flag = False
        elif 'standard' in validation:
            test_size = validation['standard']['split']
            print("Split: ", test_size)
            if self.user_param.get('sample_weight'):
                # Split the weights together with X/y so rows stay aligned;
                # only the training-set weights are kept on self.weight.
                (self.X_train, self.X_test, self.y_train, self.y_test,
                 self.weight, _) = train_test_split(self.X, self.y, self.weight,
                                                    test_size=test_size, random_state=42)
            else:
                self.X_train, self.X_test, self.y_train, self.y_test \
                    = train_test_split(self.X, self.y, test_size=test_size, random_state=42)
        else:
            settings = list(validation[val_method].keys())
            if len(settings) > 1:
                # The first key is expected to be 'active'; report the next setting.
                key_name = settings[1]
                print(key_name + ": ", validation[val_method][key_name])
            self.X_train = self.X
            self.X_test = self.X
            self.y_train = self.y
            self.y_test = self.y

    def set_regress(self):
        """
        Build the list of candidate regressors to fit.

        Starts from every scikit-learn estimator subclassing ``RegressorMixin``,
        adds the XGBoost and LightGBM regressors, drops any name listed in
        ``excluded_estimators``, then appends one ``Polynomial`` entry per key of
        the ``models`` section whose name starts with "Polynomial".

        :return: list of ``(name, estimator_class)`` tuples
        """
        removed_regressors = self.user_param.get('excluded_estimators') or []

        # Reading list of sklearn estimators and append extra regressors
        candidates = [est for est in all_estimators() if issubclass(est[1], RegressorMixin)]
        candidates.append(("XGBRegressor", xgboost.XGBRegressor))
        candidates.append(("LGBMRegressor", lightgbm.LGBMRegressor))
        regress = [est for est in candidates if est[0] not in removed_regressors]

        # Adds Polynomial regressions to list of estimators
        models = self.user_param.get('models') or {}
        regress.extend((name, Polynomial) for name in models if name.startswith("Polynomial"))
        return regress

    @staticmethod
    def _one_hot_encoder():
        """
        Build a dense-output OneHotEncoder that ignores unseen categories.

        scikit-learn 1.2 renamed ``sparse`` to ``sparse_output`` and 1.4 removed
        ``sparse``; try the new keyword first and fall back for older releases.
        """
        try:
            return OneHotEncoder(handle_unknown="ignore", sparse_output=False)
        except TypeError:
            return OneHotEncoder(handle_unknown="ignore", sparse=False)

    def get_preprocessor(self):
        """
        Build the preprocessing ColumnTransformer from the user configuration.

        Columns of ``X_train`` are grouped into numeric, low-cardinality
        categorical and high-cardinality categorical columns (see
        ``get_card_split``). Numeric columns are imputed then standardised,
        low-cardinality categoricals are imputed and one-hot encoded, and
        high-cardinality categoricals are imputed and ordinal encoded.

        :return: unfitted ``ColumnTransformer``; columns are referenced by their
                 labels in ``pd.DataFrame(X_train)`` (integer positions for an ndarray)
        """
        # read_file turns a mixed CSV into an object ndarray, which would make every
        # column look categorical; infer_objects restores numeric/bool dtypes.
        x_train = pd.DataFrame(self.X_train).infer_objects()
        numeric_features = x_train.select_dtypes(include=[np.number]).columns
        categorical_features = x_train.select_dtypes(include=["object", "bool"]).columns
        # Split categorical columns into 2 lists based on cardinality (i.e. # of unique values)
        categorical_low, categorical_high = get_card_split(x_train, categorical_features)

        # Create the imputers and encoders for categorical values
        cat_strategy = self.user_param['categorical_imputer']
        cat_imputer_kwargs = {"strategy": cat_strategy}
        if cat_strategy == "constant":
            cat_imputer_kwargs["fill_value"] = "missing"

        categorical_transformer_low = Pipeline(
            steps=[("imputer", SimpleImputer(**cat_imputer_kwargs)),
                   ("encoding", self._one_hot_encoder()), ])
        categorical_transformer_high = Pipeline(
            steps=[("imputer", SimpleImputer(**cat_imputer_kwargs)),
                   ("encoding", OrdinalEncoder()), ])

        # Create the imputer and the scaler for numerical values
        if self.user_param['numerical_imputer'] == "KNN":
            imputer_meth = KNNImputer(n_neighbors=5)
        else:
            imputer_meth = SimpleImputer(strategy=self.user_param['numerical_imputer'])

        numeric_transformer = Pipeline(
            steps=[("imputer", imputer_meth), ("scaler", StandardScaler())])

        return ColumnTransformer(
            transformers=[("numeric", numeric_transformer, numeric_features),
                          ("categorical_low", categorical_transformer_low, categorical_low),
                          ("categorical_high", categorical_transformer_high, categorical_high), ])

    # Get the model parameters from the loaded YAML configuration
    def get_parameters(self, model_name, model_parameters):
        """
        Collect the configuration values that apply to a given model.

        Global parameters (everything except the ``models`` section) are merged
        with the model-specific entry from ``models`` (model-specific values win),
        ``sample_weights`` is set to the current training weights, and the result
        is filtered down to the keys present in ``model_parameters``.

        :param model_name: key of the model in the ``models`` section
        :param model_parameters: iterable of parameter names the model accepts
        :return: dict of parameter name -> value
        """
        models = self.user_param.get('models') or {}
        # A model listed in YAML with no settings loads as None; treat it as empty.
        model_specific_param = models.get(model_name) or {}

        global_param = {key: value for key, value in self.user_param.items() if key != 'models'}
        global_param['sample_weights'] = self.weight

        parameters = {**global_param, **model_specific_param}

        accepted = set(model_parameters)
        return {key: value for key, value in parameters.items() if key in accepted}
