# Libraries.
import time
import numpy as np
import pickle as pk
import matplotlib.pyplot as plt
from sklearn import tree
from IPython.display import SVG
#from graphviz import Source
from IPython.display import display
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels
from sklearn.ensemble import RandomForestClassifier

# Import own libraries.
import synthetic_dataset_2d as synth_dataset

# Split labels.
def split_labels(samples, feature, thr, binary_test_fun):
    """ This function applies the binary test to every sample and returns the
    resulting split labels.

    Parameters:
    + samples (Numpy array): 2xM array with the samples
        (2D points), being M the number of samples. Every column is handed
        to the binary test as one sample.
    + feature (int): Binary test feature.
    + thr (float): Binary test threshold.
    + binary_test_fun (fun): Binary test function, called as
        binary_test_fun(sample, feature, thr).

    Returns:
    + split (Numpy array): 1xM array containing the split labels for input
        samples. Labels are binary numbers, 0 for left child node and 1 for
        right child node.

    """

    # One output slot per sample (column of the input array).
    split = np.zeros((1, samples.shape[1]))

    # Rows of the transposed array are the columns of the original one, so
    # each `sample` below is a single point. Multiplying by 1 converts the
    # Boolean test outcome (false/true) into an integer (0/1).
    for col, sample in enumerate(samples.T):
        split[0, col] = 1*binary_test_fun(sample, feature, thr)

    return split

# Show binary test.
def show_binary_test(dataset, feature, thr, binary_test_fun):
    """ This function plots the dataset samples coloured by the child node
    they fall in after the binary test, together with the split line.

    Parameters:
    + dataset (dict): Dictionary containing the dataset. Only the following
        entry is read here:
        * samples (Numpy array): 2xM array with the samples
            (2D points), being M the number of samples.
    + feature (int): Binary test feature (0 draws a vertical split line,
        1 a horizontal one).
    + thr (float): Binary test threshold.
    + binary_test_fun (fun): Binary test function.

    Returns:
    None

    """

    points = dataset['samples']

    # Split labels: 0 for the left child node, 1 for the right child node.
    side = split_labels(points, feature, thr, binary_test_fun)

    # There must be exactly one split label per sample.
    assert points.shape[1]==side.shape[1]

    # Colours indexed by child node: left (gray), right (cyan).
    node_colors = ((0.5,0.5,0.5), (0,0.7,0.7))

    # Plot the samples of each child node with its own colour.
    plt.figure(figsize=(5, 5))
    for node, node_color in enumerate(node_colors):
        mask = side[0,:]==node
        plt.plot(points[0, mask], points[1, mask], marker='o',
                 linestyle='None', color=node_color)

    # Split line; it is drawn between 0 and 1 on the other axis.
    if feature == 0:
        plt.plot([thr, thr], [0, 1], 'k-', linewidth=5)
    elif feature == 1:
        plt.plot([0, 1], [thr, thr], 'k-', linewidth=5)

    # Figure decorations.
    plt.axis('equal')
    plt.xlabel('f0', fontsize=18)
    plt.ylabel('f1', fontsize=18)
    plt.title('2D Samples: left node (gray), right node (cyan)', fontsize=18)
    plt.grid()
    plt.show()

# Class distribution.
def class_distribution(labels, num_classes):
    """ This function computes the class distribution according
        to the input labels. It is computed as a normalized histogram.

    Parameters:
    + labels (Numpy array): 1xM array with the class labels. They
        are numbers in the range [1, N], being N the number of classes.
        M may be zero (e.g. an empty child node).
    + num_classes (int): Number of dataset classes.

    Returns:
    + distr (Numpy array): 1xN array with the class
        probability distribution. It is all zeros when there are no labels.

    """

    labels = np.asarray(labels)

    # Empty labels. The element count is checked rather than len(), because
    # len() of a 1x0 array is 1 (the number of rows) and the histogram below
    # would then divide by zero and return NaN values.
    if labels.size == 0:
        return np.zeros((1, num_classes))

    # Compute class distribution -normalized histogram-. Bins have unit
    # width, so the density values are the class probabilities.
    distr, _ = np.histogram(labels[0, :], bins=np.arange(1, num_classes+2),
                            density=True)
    distr = np.reshape(distr, (-1, num_classes))

    return distr

# Decision stump.
def decision_stump(dataset, feature, thr, binary_test_fun):
    """ This function builds a decision stump: it splits the dataset with one
    binary test and describes the root node and both child nodes.

    Parameters:
    + dataset (dict): Dictionary containing the dataset with the
        following data:
        * samples (Numpy array): 2xM array with the samples
            (2D points), being M the number of samples.
        * labels (Numpy array): 1xM array with the class labels
            for all samples. Labels are numbers in the range
            [1, N], being N the number of classes.
        * num_classes (int): The number of classes for the
            selected scenario.
    + feature (int): Binary test feature.
    + thr (float): Binary test threshold.
    + binary_test_fun (fun): Binary test function.

    Returns:
    + stump (dict): Python dictionary containing the decision stump data:
        * root_node (dict): Node data for all input samples.
        * left_node (dict): Node data for the samples whose split label
            is 0.
        * right_node (dict): Node data for the samples whose split label
            is 1.
        * binary_test (dict): Dictionary with the binary test parameters:
            * feature (int): Feature index.
            * thr (float): Threshold value.
            * function (fun): Binary test function.

        Every node dictionary contains:
            * samples (Numpy array): 2xK array with the samples in the node.
            * labels (Numpy array): 1xK array with their class labels.
            * distr (Numpy array): 1xN array containing the probability
                class distribution.
            * num_samples (int): Number of samples in the node (K).

    """

    # Dataset data.
    samples = dataset['samples']
    labels = dataset['labels']
    num_classes = dataset['num_classes']
    total = samples.shape[1]

    # Split label of every sample: 0 (left child) or 1 (right child).
    side = split_labels(samples, feature, thr, binary_test_fun)[0, :]
    go_left = side == 0
    go_right = side == 1

    # Samples routed to each child node.
    l_samples, r_samples = samples[:, go_left], samples[:, go_right]

    # Every sample must land in exactly one of the two child nodes.
    assert l_samples.shape[1] + r_samples.shape[1] == total

    # Labels routed to each child node.
    l_labels, r_labels = labels[:, go_left], labels[:, go_right]

    def make_node(node_samples, node_labels):
        # Bundle the samples of a node with their class distribution.
        return {'samples': node_samples, 'labels': node_labels,
                'distr': class_distribution(node_labels, num_classes),
                'num_samples': node_samples.shape[1]}

    # Decision stump: parent node -root node-, child nodes and binary test.
    return {'root_node': make_node(samples, labels),
            'left_node': make_node(l_samples, l_labels),
            'right_node': make_node(r_samples, r_labels),
            'binary_test': {'feature': feature, 'thr': thr,
                            'function': binary_test_fun}}

# Node histograms.
def node_histograms(stump, num_bins):
    """ This function plots the class histograms (number of samples per
    class) of the root, left child and right child nodes of a decision stump,
    side by side.

    Parameters:
    + stump (dict): Decision stump with the 'root_node', 'left_node' and
        'right_node' entries; the 1xM 'labels' array of each one is plotted.
    + num_bins (int): Number of histogram bins (one per class label, from 1
        to num_bins).

    Returns:
    None

    """

    # Bin edges centred on the integer class labels 1..num_bins.
    edges = np.arange(1, num_bins+2)-0.5

    # Stump entry and title of each panel, from left to right.
    panels = (('root_node', 'Root Node'),
              ('left_node', 'Left Child Node'),
              ('right_node', 'Right Child Node'))

    plt.figure(figsize=(12,5))
    for pos, (node_key, node_title) in enumerate(panels, start=1):
        plt.subplot(1,3,pos)
        _, _, bars = plt.hist(stump[node_key]['labels'][0,:], bins=edges)
        plt.xlabel('Classes', fontsize=18)
        # Only the first panel carries the y-axis label.
        if pos == 1:
            plt.ylabel('# Samples', fontsize=18)
        plt.title(node_title, fontsize=18)
        plt.xticks(range(1, num_bins+1))
        plt.grid()
        # Paint every bar with the colour of its class.
        for k in range(num_bins):
            bars[k].set_facecolor(synth_dataset.colors(k))

    plt.show()

# Node distributions.
def node_distributions(stump, num_bins):
    """ This function plots the class probability distributions -normalized
    histograms- of the root, left child and right child nodes of a decision
    stump, side by side.

    Parameters:
    + stump (dict): Decision stump with the 'root_node', 'left_node' and
        'right_node' entries; the 'distr' array of each one is plotted.
    + num_bins (int): Number of bars (one per class label, from 1 to
        num_bins).

    Returns:
    None

    """

    # Class labels on the x axis.
    class_ids = np.arange(1, num_bins+1)

    # Stump entry and title of each panel, from left to right.
    panels = (('root_node', 'Root Node'),
              ('left_node', 'Left Child Node'),
              ('right_node', 'Right Child Node'))

    plt.figure(figsize=(12,5))
    for pos, (node_key, node_title) in enumerate(panels, start=1):
        plt.subplot(1,3,pos)
        plt.bar(class_ids, height=stump[node_key]['distr'][0,:])
        plt.xlabel('Classes', fontsize=18)
        # Only the first panel carries the y-axis label.
        if pos == 1:
            plt.ylabel('Probability', fontsize=18)
        plt.title(node_title, fontsize=18)
        plt.xticks(class_ids)
        plt.grid()
    plt.show()

# Class prediction
def prediction(dataset, stump):
    """ This function computes the class predictions (class labels) for every
    sample in the dataset using a decision stump.

    Parameters:
    + dataset (dict): Dictionary with the 'samples' entry (array with one
        sample per column).
    + stump (dict): Decision stump with the 'binary_test', 'left_node' and
        'right_node' entries.

    Returns:
    + pred (Numpy array): 1xM array with the predicted class label of every
        sample: the most probable class (1-based) of the child node the
        sample falls in.

    """

    # Run the stump's own binary test: 0 means left child, anything else is
    # treated as right child.
    test = stump['binary_test']
    side = split_labels(dataset['samples'], test['feature'], test['thr'],
                        test['function'])

    # One prediction per sample.
    pred = np.zeros((1, dataset['samples'].shape[1]))

    # The predicted class is the arg-max of the child node distribution; +1
    # because class labels start at 1 while array indexes start at 0.
    for n in range(pred.shape[1]):
        node_key = 'left_node' if side[0,n] == 0 else 'right_node'
        pred[0,n] = np.argmax(stump[node_key]['distr'])+1

    return pred

# Accuracy.
def accuracy(labels, pred, show=False):
    """ This function computes the classification accuracy: the fraction of
    predictions that match the ground-truth labels.

    Parameters:
    + labels (Numpy array): 1xM array with the ground-truth class labels.
    + pred (Numpy array): Array with the predicted class labels, compared
        element-wise against the labels.
    + show (bool): If true, the accuracy is also printed (default: False).

    Returns:
    + acc (float): Number of matches divided by the number of label columns.

    """

    # Element-wise matches between predictions and ground truth.
    matches = pred==labels
    hits = np.sum(matches)

    # Fraction of correctly classified samples.
    acc = hits/labels.shape[1]

    # Message.
    if show:
        print ('Accuracy: {0:.3f}'.format(acc))

    return acc

# Show classification.
def show_classification(samples, labels, pred, num_classes):
    """ This function shows the classification results: the samples coloured
    by their ground-truth labels (left) and by their predicted labels (right).

    Parameters:
    + samples (Numpy array): Array with one sample per column; rows 0 and 1
        are used as plot coordinates.
    + labels (Numpy array): 1xM array with the ground-truth class labels
        (numbers starting at 1).
    + pred (Numpy array): 1xM array with the predicted class labels.
    + num_classes (int): Number of classes.

    Returns:
    None

    """

    # Check samples and labels size.
    assert samples.shape[1]==labels.shape[1]

    # Panels from left to right: class ids used for colouring, title and
    # whether equal axis scaling is applied (only the left panel has it).
    panels = ((labels, '2D Samples', True),
              (pred, 'Classification Output', False))

    plt.figure(figsize=(15, 5))
    for pos, (class_ids, panel_title, equal_axes) in enumerate(panels, start=1):
        plt.subplot(1,2,pos)
        for c in range(num_classes):
            # Class labels start at 1, colour indexes at 0.
            mask = class_ids[0,:]==c+1
            plt.plot(samples[0, mask], samples[1, mask], marker='o',
                     linestyle='None', color=synth_dataset.colors(c))
        if equal_axes:
            plt.axis('equal')
        plt.xlabel('f0', fontsize=18)
        plt.ylabel('f1', fontsize=18)
        plt.title(panel_title, fontsize=18)
        plt.grid()
    plt.show()

# Show performance.
def show_performance(thrs, vec_acc, vec_gain, show=True):
    """ This function shows the performance of a decision stump in terms of
    classification accuracy and information gain for different thresholds and
    selected features, and returns the parameters with the highest
    information gain.

    Parameters:
    + thrs (Numpy array): Evaluated thresholds (K values).
    + vec_acc (Numpy array): 2xK array with the accuracy for each feature
        (row) and threshold (column).
    + vec_gain (Numpy array): 2xK array with the information gain for each
        feature and threshold. It may contain NaN values; they are treated
        as zero when looking for the best parameters. The input array is
        not modified.
    + show (bool): If true, the best parameters are printed (default: True).

    Returns:
    + b_gain (float): Maximum information gain.
    + b_feat (int): Feature index attaining the maximum information gain.
    + b_thr (float): Threshold attaining the maximum information gain.

    """

    # Left panel: accuracy. Right panel: information gain. One curve per
    # feature in both of them (two features are assumed).
    panels = ((vec_acc, 'Accuracy'), (vec_gain, 'Info. Gain'))

    plt.figure(figsize=(15, 5))
    for pos, (scores, y_label) in enumerate(panels, start=1):
        plt.subplot(1,2,pos)
        for feature in range(2):
            plt.plot(thrs, scores[feature, :], linewidth=4,
                     color=synth_dataset.colors(feature),
                     label='f{}'.format(feature))
        plt.xlabel('Threshold', fontsize=18)
        plt.ylabel(y_label, fontsize=18)
        plt.legend(fontsize=18)
        plt.grid()
    plt.show()

    # Best parameters. NaN values are converted to zero on a copy, so the
    # caller's array is left untouched.
    gain = np.where(np.isnan(vec_gain), 0, vec_gain)

    # First maximum in row-major order: (feature, threshold index).
    b_feat, b_indx = np.unravel_index(np.argmax(gain), gain.shape)
    b_thr = thrs[b_indx]
    b_gain = gain[b_feat, b_indx]

    # Message. Note that the reported accuracy is the one obtained with the
    # best information-gain parameters.
    if show:
        print ('Best decision stump parameters:')
        print ('+ Feature: f{0} '.format(b_feat))
        print ('+ Threshold: {0:.2f} '.format(b_thr))
        print ('Best scores:')
        print ('+ Max. information gain: {0:.3f} '.format(b_gain))
        print ('+ Max. accuracy: {0:.3f} '.format(vec_acc[b_feat, b_indx]))

    return b_gain, b_feat, b_thr

# Decision stump: synthetic scenarios.
def play_stump_synthetic_scenarios(scenario, impurity_fun, binary_test_fun,\
                                   info_gain_fun):
    """ This function finds the best binary test parameters (feature and
    threshold) for different synthetic scenarios, by exhaustive search over
    both features and a grid of thresholds.

    Parameters:
    + scenario: Scenario identifier passed to synth_dataset.load_dataset.
    + impurity_fun (fun): Node impurity function, called with a class
        distribution.
    + binary_test_fun (fun): Binary test function.
    + info_gain_fun (fun): Information gain function, called as
        info_gain_fun(parent_impurity, left_impurity, right_impurity,
        num_left_samples, num_right_samples).

    Returns:
    None

    """

    # Train samples.
    dataset = synth_dataset.load_dataset(scenario=scenario)

    # Threshold grid.
    step = 0.01
    thrs = np.arange(0.1, 0.9, step)

    # Scores for every (feature, threshold) pair.
    vec_acc = np.zeros((2, len(thrs)))
    vec_gain = np.zeros_like(vec_acc)

    # Nodes whose impurity is needed, in the order expected by info_gain_fun.
    node_keys = ('root_node', 'left_node', 'right_node')

    for feature in range(2):
        for k, thr in enumerate(thrs):

            # Decision stump for the current parameters.
            stump = decision_stump(dataset, feature, thr, binary_test_fun)

            # Classification accuracy of its predictions.
            vec_acc[feature, k] = accuracy(dataset['labels'],
                                           prediction(dataset, stump))

            # Impurity of the parent -root- and both child nodes.
            impurities = [impurity_fun(stump[key]['distr'])
                          for key in node_keys]

            # Information gain of the split.
            vec_gain[feature, k] = info_gain_fun(
                impurities[0], impurities[1], impurities[2],
                stump['left_node']['num_samples'],
                stump['right_node']['num_samples'])

    # Show the performance curves and pick the best parameters.
    _, b_feat, b_thr = show_performance(thrs, vec_acc, vec_gain)

    # Re-run the best decision stump and show its classification output.
    best_stump = decision_stump(dataset, b_feat, b_thr, binary_test_fun)
    best_pred = prediction(dataset, best_stump)
    show_classification(dataset['samples'], dataset['labels'], best_pred,
                        dataset['num_classes'])

# Show decision tree.
def show_tree(d_tree, num_classes):
    """ This function is meant to show the decision tree. The graphviz
    rendering is currently disabled, so it only builds the class names and
    displays nothing.

    Parameters:
    + d_tree: Decision tree (unused while the rendering is disabled).
    + num_classes (int): Number of classes.

    Returns:
    None

    """

    # Class names: '0', '1', ..., one per class.
    class_names = list(map(str, range(num_classes)))

    # Create tree graph (disabled: it requires the graphviz package).
    #graph = Source(tree.export_graphviz(d_tree, out_file=None, \
        #feature_names=['f0', 'f1'], class_names=class_names, filled = True))
    #display(SVG(graph.pipe(format='svg')))

# Show tree depth performance.
def show_tree_depth_performance(depths, vec_acc, vec_depth):
    """ This function shows the classification performance of a decision tree
    according to the maximum allowed tree depth.

    Parameters:
    + depths: Evaluated maximum depths (x axis of both plots).
    + vec_acc (Numpy array): Array whose row 0 holds the train accuracy and
        row 1 the test accuracy for every depth.
    + vec_depth: Depth attained by the tree for every maximum depth.

    Returns:
    None

    """

    # Best accuracy.
    print ('Results:')
    print ('+ Max. train accuracy: {0:.3f}'.format(np.amax(vec_acc[0,:])))
    print ('+ Max. test accuracy: {0:.3f}'.format(np.amax(vec_acc[1,:])))

    # Left panel: train (row 0) and test (row 1) accuracy curves.
    plt.figure(figsize=(15, 4))
    plt.subplot(1,2,1)
    for row, curve_name in enumerate(('Train', 'Test')):
        plt.plot(depths, vec_acc[row,:], marker='o',
                 color=synth_dataset.colors(row), label=curve_name)
    plt.xlabel('Max. Depth', fontsize=18)
    plt.ylabel('Accuracy', fontsize=18)
    plt.legend(fontsize=18)
    plt.grid()

    # Right panel: attained depth.
    plt.subplot(1,2,2)
    plt.plot(depths, vec_depth, marker='o', color=synth_dataset.colors(2))
    plt.xlabel('Max. Depth', fontsize=18)
    plt.ylabel('Attained Depth', fontsize=18)
    plt.grid()
    plt.show()

# Load dataset.
def load_iris_dataset(perc_test=0.2):
    """ This function loads the iris dataset using sklearn and splits it
    randomly into a training set and a test set.

    Note: the global Numpy random seed is fixed to 1, so the split is the
    same on every call.

    Parameters:
    + perc_test (float): Fraction of samples used as test data
        (default: 0.2).

    Returns:
    + train_data (dict): Dictionary containing the training data:
        * samples (Numpy array): Array with one sample per column, being M
            the number of training samples.
        * labels (Numpy array): 1xM array with the class labels, as
            provided by sklearn.
        * num_classes (int): Number of classes (three classes).
        * feat_names: Feature names provided by sklearn.
        * target_names: Class names provided by sklearn.
    + test_data (dict): Dictionary containing the test data, with the same
        entries as train_data.

    """

    # Fix random seed.
    np.random.seed(1)

    # Load dataset using sklearn.
    iris = load_iris()

    # Features with one sample per column, and labels as a single row.
    features = iris.data.T
    targets = iris.target.reshape(1, -1)

    # Number of samples and classes.
    num_samples = features.shape[1]
    num_classes = 3  # Three flowers.

    # Shuffle the sample indexes: the first ones go to the test set and the
    # remaining ones to the training set.
    order = np.random.permutation(num_samples)
    num_test = int(perc_test*num_samples)
    test_idx, train_idx = order[:num_test], order[num_test:]

    # Test set: samples and targets.
    test_X, test_y = features[:, test_idx], targets[:, test_idx]

    # Training set: samples and targets.
    train_X, train_y = features[:, train_idx], targets[:, train_idx]

    # Check number of features.
    assert test_X.shape[0] == train_X.shape[0]

    # Check number of samples.
    assert test_X.shape[1] == test_y.shape[1]
    assert train_X.shape[1] == train_y.shape[1]

    def pack(set_samples, set_labels):
        # Dataset dictionary for one of the two sets.
        return {'samples': set_samples, 'labels': set_labels,
                'num_classes': num_classes,
                'feat_names': iris.feature_names,
                'target_names': iris.target_names}

    # Train and test data.
    return pack(train_X, train_y), pack(test_X, test_y)

# Show iris samples.
def show_iris_samples(dataset, feats):
    """ This function shows the iris dataset samples in the 2D feature space
    using the input feature indexes.

    Parameters:
    + dataset (dict): Dictionary containing the dataset data:
        * samples (Numpy array): 4xM array containing the 4D feature samples,
            being M the number of samples.
        * labels (Numpy array): 1xM array with the class labels (numbers
            starting at 0).
        * num_classes (int): Number of classes (three classes).
        * feat_names: Feature names, indexed by feature.
    + feats (list): The two feature indexes used as x and y axes.

    Returns:
    None

    """

    # Check features: only two features allowed.
    assert len(feats) == 2

    # Defined colors, indexed by class label: red, blue, green, ...
    colors = [(1, 0, 0), (0, 0, 1), (0.0, 0.5, 0), (0.5, 0, 0.5), \
              (0.7, 0.5, 0), (0, 0.5, 0.7), (0.7, 0.4, 0.3), (0.3, 0.2, 0.8)]

    # Dataset.
    samples = dataset['samples']
    labels = dataset['labels']
    num_classes = dataset['num_classes']
    feat_names = dataset['feat_names']

    # Check samples and labels size.
    assert samples.shape[1] == labels.shape[1]

    # Plot samples, one colour per class.
    plt.figure(figsize=(8, 5))
    for c in range(num_classes):
        indx = labels[0,:]==c
        plt.scatter(samples[feats[0], indx], samples[feats[1], indx], \
                    color=colors[c])
    plt.grid()
    plt.xlabel('f{0} - {1}'.format(feats[0],feat_names[feats[0]]), fontsize=18)
    plt.ylabel('f{0} - {1}'.format(feats[1],feat_names[feats[1]]), fontsize=18)
    # The title must follow the colour order above: class 0 is drawn in red,
    # class 1 in blue and class 2 in green.
    plt.title('Samples: Setosa (red), Versicolour (blue), Virginica (green)',\
              fontsize=18)
    plt.show()

# Show decision tree for iris dataset.
def show_iris_tree(d_tree, num_classes):
    """ This function is meant to show the decision tree for the iris flowers
    dataset. The graphviz rendering is currently disabled, so it only builds
    the class names and displays nothing.

    Parameters:
    + d_tree: Decision tree (unused while the rendering is disabled).
    + num_classes (int): Number of classes.

    Returns:
    None

    """

    # Class names: '0', '1', ..., one per class.
    class_names = list(map(str, range(num_classes)))

    # Create tree graph (disabled: it requires the graphviz package).
    #graph = Source(tree.export_graphviz(d_tree, out_file=None, \
        #feature_names=['f0', 'f1', 'f2', 'f3'], class_names=class_names, filled = True))
    #display(SVG(graph.pipe(format='svg')))

# Confusion matrix.
def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, title=None,
                          cmap=plt.cm.Blues):
    """ This function plots the confusion matrix. Normalization can be
    applied by setting `normalize=True`.

    Parameters:
    + y_true (Numpy array): Ground-truth labels (flattened internally).
    + y_pred (Numpy array): Predicted labels (flattened internally).
    + classes (Numpy array): Class names, indexed by class label.
    + normalize (bool): If true, every row of the matrix is divided by its
        sum, so cells show the fraction of samples of each true class
        (default: False).
    + title (str): Figure title. A default title is used when empty.
    + cmap: Matplotlib colormap used for the matrix image.

    Returns:
    None

    """

    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    # Flatten labels.
    y_true = y_true.flatten()
    y_pred = y_pred.flatten()

    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)

    # Row-wise normalization. Rows without samples (labels that only appear
    # in the predictions) are kept at zero instead of dividing by zero.
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)

    # Only use the labels that appear in the data
    classes = classes[unique_labels(y_true, y_pred)]

    fig, ax = plt.subplots(figsize=(5,5))
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations. Cells above
    # half of the maximum value get white text to stay readable.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    plt.show()

# Load mnist digit dataset.
def load_mnist_dataset(path='../data/mnist.pkl'):
    """ This function loads the mnist digit dataset.

    Note: the global Numpy random seed is fixed to 1 as a side effect.

    Parameters:
    + path (str): String with the path to the dataset file (pickle file).
        The file must unpickle to three (samples, labels) pairs; the first
        one is used as training set and the last one as test set, the
        middle one is ignored. Samples are stored with one sample per row.

    Returns:
    + train_dataset (dict): Dictionary containing the training data:
        * samples (Numpy array): 784xM array containing the 784-dimensional
            feature samples, being M the number of samples.
        * labels (Numpy array): 1xM array with the class labels (ten classes).
        * num_classes (int): Number of classes (ten classes).
        * num_features (int): Number of features per sample.
    + test_dataset (dict): Dictionary containing the testing data, with the
        same entries as train_dataset.

    """

    # Fix random seed.
    np.random.seed(1)

    # Load dataset. The context manager closes the file even on errors.
    # NOTE(security): unpickling can execute arbitrary code; only load
    # dataset files coming from a trusted source.
    with open(path, 'rb') as dataset_file:
        train_set, _, test_set = pk.load(dataset_file, encoding='latin1')

    # Extract train and test samples, transposed to one sample per column.
    train_samples = train_set[0].T
    test_samples = test_set[0].T

    # Labels reshaped as a single row (1xM).
    train_labels = np.reshape(train_set[1], (1, -1))
    test_labels = np.reshape(test_set[1], (1, -1))

    # Variables. After the transpose, features run along the rows.
    num_features = np.shape(train_samples)[0]
    num_classes = 10

    # Train and test datasets.
    train_dataset = {'samples':train_samples, 'labels':train_labels, \
                     'num_classes':num_classes, 'num_features': num_features}
    test_dataset = {'samples':test_samples, 'labels':test_labels, \
                    'num_classes':num_classes, 'num_features':num_features}

    return train_dataset, test_dataset

# Show mnist digits.
def show_digits(samples, num_max=8):
    """ This function shows MNIST digit images picked at random (with
    replacement) from the input samples.

    Parameters:
    + samples (Numpy array): Array with one sample per column; every sample
        must reshape into a 28x28 image.
    + num_max (int): Number of images to show (default: 8).

    Returns:
    None

    """

    # One image per row.
    images = samples.T
    num_images = images.shape[0]

    # Random indexes of the images to show.
    picks = np.random.randint(num_images, size=num_max)

    # Show the selected images in a single row of subplots.
    fig = plt.figure(figsize=(16, 10))
    for slot, pick in enumerate(picks, start=1):
        digit = images[pick].reshape((28,28))
        ax = fig.add_subplot(1, num_max, slot)
        # Pixel values are scaled by 255 before the conversion to bytes.
        ax.imshow(np.uint8(255*digit), cmap='gray')
    plt.show()
