# Import libraries.
import numpy as np
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual, Layout

# Import own libraries.
import utils as utils

# Decision stump widgets: which feature the stump tests ('f0' or 'f1') and the
# threshold value used by that binary test.
dropdown_feature = widgets.Dropdown(
    description='Feature:',
    options=['f0', 'f1'],
    value='f0',
    disabled=False,
)

# Threshold slider, limited to the range [0.0001, 0.9999].
slider_threshold = widgets.FloatSlider(
    description='Threshold',
    value=0.5,
    min=0.0001,
    max=0.9999,
    step=0.01,
    readout=True,
    readout_format='.2f',
    orientation='horizontal',
    continuous_update=True,
    disabled=False,
    layout=Layout(height='50px', width='500px'),
)

# Toggle buttons: node impurity measure (Shannon entropy or Gini index).
buttons_impurity = widgets.ToggleButtons(
    description='Impurity',
    options=['Entropy', 'Gini index'],
    tooltips=['Entropy', 'Gini'],
    button_style='',  # one of 'success', 'info', 'warning', 'danger' or ''
    disabled=False,
)

# Toggle buttons: synthetic dataset scenario. Options are listed in the order
# of scenario indices 1-4.
buttons_dataset = widgets.ToggleButtons(
    description='Scenario',
    options=['2D Clusters', '2D Spirals', '2D Circle', '2D Multi-clusters'],
    tooltips=['Scenario 1', 'Scenario 2', 'Scenario 3', 'Scenario 4'],
    button_style='',  # one of 'success', 'info', 'warning', 'danger' or ''
    disabled=False,
)

# Sliders: number of samples per class. All four share the same range and
# default, so they are built by a single factory.
def _samples_slider(description):
    """ Build one per-class sample-count slider.

    Parameters:
    + description (str): Label shown next to the slider.

    Returns:
    + (widgets.IntSlider): Slider over [10, 2000] in steps of 10, default 200.

    """
    # A new Layout is created on every call so sliders never share layout state.
    return widgets.IntSlider(
        value=200,
        min=10,
        max=2000,
        step=10,
        description=description,
        disabled=False,
        continuous_update=True,
        orientation='horizontal',
        readout=True,
        readout_format='d',
        layout=Layout(height='50px', width='500px'),
    )


slider_samples_class_1 = _samples_slider('C1 samples')
slider_samples_class_2 = _samples_slider('C2 samples')
slider_samples_class_3 = _samples_slider('C3 samples')
slider_samples_class_4 = _samples_slider('C4 samples')

# Slider: noise factor for the synthetic datasets, in [0.1, 5.0].
slider_noise = widgets.FloatSlider(
    description='Noise',
    value=1.0,
    min=0.1,
    max=5.0,
    step=0.1,
    readout=True,
    readout_format='.2f',
    orientation='horizontal',
    continuous_update=True,
    disabled=False,
    layout=Layout(height='50px', width='500px'),
)

# Slider: decision tree depth, an integer in [1, 15].
slider_tree_depth = widgets.IntSlider(
    description='Depth',
    value=1,
    min=1,
    max=15,
    step=1,
    readout=True,
    readout_format='d',
    orientation='horizontal',
    continuous_update=True,
    disabled=False,
    layout=Layout(height='50px', width='500px'),
)

# Dropdowns: the two IRIS features to use. Both offer the same four feature
# names and differ only in their initial selection.
def _iris_feature_dropdown(initial):
    """ Build an IRIS feature dropdown preselected to `initial`.

    Parameters:
    + initial (str): Feature name selected when the widget is created.

    Returns:
    + (widgets.Dropdown): Dropdown over the four IRIS feature names.

    """
    # A new options list is built per call, so the widgets share no list.
    return widgets.Dropdown(
        options=['Sepal length', 'Sepal width', 'Petal length', 'Petal width'],
        value=initial,
        description='Feature:',
        disabled=False,
    )


dropdown_iris_feature_0 = _iris_feature_dropdown('Sepal length')
dropdown_iris_feature_1 = _iris_feature_dropdown('Sepal width')

# Slider: depth of each random forest tree, an integer in [1, 9].
slider_rf_depth = widgets.IntSlider(
    description='Depth',
    value=1,
    min=1,
    max=9,
    step=1,
    readout=True,
    readout_format='d',
    orientation='horizontal',
    continuous_update=True,
    disabled=False,
    layout=Layout(height='50px', width='500px'),
)

# Slider: number of trees in the forest, an integer in [1, 50].
slider_num_trees = widgets.IntSlider(
    description='Trees',
    value=1,
    min=1,
    max=50,
    step=1,
    readout=True,
    readout_format='d',
    orientation='horizontal',
    continuous_update=True,
    disabled=False,
    layout=Layout(height='50px', width='500px'),
)

# Slider: maximum number of features, an integer in [1, 100].
slider_max_features = widgets.IntSlider(
    description='Max. feat.',
    value=1,
    min=1,
    max=100,
    step=1,
    readout=True,
    readout_format='d',
    orientation='horizontal',
    continuous_update=True,
    disabled=False,
    layout=Layout(height='50px', width='500px'),
)

# Decision stump: Binary test parameters.
def stump_binary_test_parameters():
    """ This function displays the decision stump widgets and returns the
    selected binary test parameters.

    Parameters:
    + None

    Returns:
    + feature (int): Index of the feature (0 or 1) for 2D problems.
    + thr (float): Threshold value.

    Raises:
    + ValueError: If the feature dropdown holds a label other than 'f0'/'f1'.

    """

    # Decision stump widgets.
    # NOTE(review): `display` is not imported in this module; it presumably
    # resolves to IPython's builtin when run inside Jupyter - confirm.
    display(dropdown_feature)  # Feature.
    display(slider_threshold)  # Threshold.

    # Map the dropdown label to the feature index. An unexpected label fails
    # loudly here instead of leaving `feature` unbound.
    feature = {'f0': 0, 'f1': 1}.get(dropdown_feature.value)
    if feature is None:
        raise ValueError('Unknown decision stump feature: {!r}'.format(
            dropdown_feature.value))

    return feature, slider_threshold.value

# Synthetic dataset parameters.
def synth_dataset_parameters():
    """ This function displays the synthetic dataset widgets and returns the
    selected parameters.

    Parameters:
    None

    Returns:
    + scenario (int): Synthetic scenario index [1-4].
    + num_samples (int list): Number of samples per class: two entries, or
      four for scenario 4 ("2D Multi-clusters").
    + noise (float): Noise factor.

    """

    # Scenario buttons. dataset_scenario() displays them and maps the label to
    # its index, so that mapping is defined in one place only.
    scenario = dataset_scenario()

    # Per-class sample sliders: classes 1-2 always, classes 3-4 only for the
    # four-class scenario. Sliders are displayed first and read afterwards.
    # NOTE(review): `display` is not imported in this module; it presumably
    # resolves to IPython's builtin when run inside Jupyter - confirm.
    sample_sliders = [slider_samples_class_1, slider_samples_class_2]
    if scenario == 4:
        sample_sliders += [slider_samples_class_3, slider_samples_class_4]
    for slider in sample_sliders:
        display(slider)
    num_samples = [slider.value for slider in sample_sliders]

    # Noise slider: data_noise() displays it and returns its current value.
    noise = data_noise()

    return scenario, num_samples, noise

# Dataset scenario.
def dataset_scenario():
    """ This function displays the scenario buttons and returns the selected
    synthetic dataset scenario.

    Parameters:
    None

    Returns:
    + scenario (int): Synthetic scenario index [1-4].

    Raises:
    + ValueError: If the buttons hold an unknown scenario label.

    """
    # Widget: datasets.
    display(buttons_dataset)

    # Button label -> scenario index. A single lookup replaces four independent
    # comparisons, and an unknown label raises instead of leaving `scenario`
    # unbound.
    scenario = {
        '2D Clusters': 1,
        '2D Spirals': 2,
        '2D Circle': 3,
        '2D Multi-clusters': 4,
    }.get(buttons_dataset.value)
    if scenario is None:
        raise ValueError('Unknown dataset scenario: {!r}'.format(
            buttons_dataset.value))

    return scenario

# Impurity.
def impurity(entropy_fun, gini_fun):
    """ This function displays the impurity buttons and returns the impurity
    function that matches the selection: Shannon entropy or Gini index.

    Parameters:
    + entropy_fun: Returned when 'Entropy' is selected.
    + gini_fun: Returned when 'Gini index' is selected.

    Returns:
    + impurity_fun: Either `entropy_fun` or `gini_fun`, returned unchanged.

    Raises:
    + ValueError: If the buttons hold an unknown impurity label.

    """

    # Buttons for node impurity.
    display(buttons_impurity)

    # Select impurity function: entropy or gini index. Membership is tested
    # with `in` (not `.get`) so that a None argument is still returned as-is.
    choices = {'Entropy': entropy_fun, 'Gini index': gini_fun}
    if buttons_impurity.value not in choices:
        raise ValueError('Unknown impurity method: {!r}'.format(
            buttons_impurity.value))

    return choices[buttons_impurity.value]

# Decision tree depth.
def decision_tree_depth():
    """ This function displays the tree depth slider and returns its value.

    Parameters:
    None

    Returns:
    + depth (int): Current value of the tree depth slider.

    """
    # Show the slider first, then read its current value.
    display(slider_tree_depth)
    depth = slider_tree_depth.value
    return depth

# Decision tree impurity.
def decision_tree_impurity():
    """ This function displays the impurity buttons and returns the name of
    the selected impurity method: Shannon entropy or Gini index.

    Parameters:
    None

    Returns:
    + impurity_method (str): 'entropy' or 'gini'.

    Raises:
    + ValueError: If the buttons hold an unknown impurity label.

    """

    # Buttons for node impurity.
    display(buttons_impurity)

    # Button label -> method name. The local is deliberately not named
    # `impurity`, which would shadow the module-level impurity() function.
    impurity_method = {
        'Entropy': 'entropy',
        'Gini index': 'gini',
    }.get(buttons_impurity.value)
    if impurity_method is None:
        raise ValueError('Unknown impurity method: {!r}'.format(
            buttons_impurity.value))

    return impurity_method

# Data noise.
def data_noise():
    """ This function displays the noise slider and returns its value.

    Parameters:
    None

    Returns:
    + noise (float): Current noise factor.

    """
    # Show the slider first, then read its current value.
    display(slider_noise)
    noise = slider_noise.value
    return noise

# IRIS dataset features.
def iris_features():
    """ This function displays the two IRIS feature dropdowns and returns the
    indices of the selected features.

    Parameters:
    None

    Returns:
    + f0 (int): Index [0-3] of the feature chosen in the first dropdown.
    + f1 (int): Index [0-3] of the feature chosen in the second dropdown.

    Raises:
    + ValueError: If a dropdown holds an unknown feature name.

    """

    # Features.
    display(dropdown_iris_feature_0)  # Feature 0.
    display(dropdown_iris_feature_1)  # Feature 1.

    # Feature name -> index, following the order of the dropdown options.
    # NOTE(review): callers presumably use these as IRIS data column indices;
    # confirm the column order matches.
    feature_index = {
        'Sepal length': 0,
        'Sepal width': 1,
        'Petal length': 2,
        'Petal width': 3,
    }

    indices = []
    for dropdown in (dropdown_iris_feature_0, dropdown_iris_feature_1):
        if dropdown.value not in feature_index:
            raise ValueError('Unknown IRIS feature: {!r}'.format(dropdown.value))
        indices.append(feature_index[dropdown.value])
    f0, f1 = indices

    return f0, f1

# Random forest parameters.
def random_forest_parameters():
    """ This function displays the random forest widgets and returns their
    current values.

    Parameters:
    None

    Returns:
    + num_trees (int): Number of trees.
    + depth (int): Tree depth.
    + max_features (int): Maximum number of features.
    + impurity_method (str): 'entropy' or 'gini'.

    """

    # Tree depth.
    display(slider_rf_depth)

    # Number of trees.
    display(slider_num_trees)

    # Max. features.
    display(slider_max_features)

    # Buttons for node impurity. decision_tree_impurity() displays them and
    # maps the label to 'entropy'/'gini', so that mapping is not duplicated.
    impurity_method = decision_tree_impurity()

    return (slider_num_trees.value, slider_rf_depth.value,
            slider_max_features.value, impurity_method)

