# Libraries.
import sys
import numpy as np
from numpy import random as rnd

# Test load dataset implementation.
def test_load_dataset(dataset):
    """ This function tests the implementation to load the synthetic dataset.

    Parameters:
    + dataset (dict): Dictionary containing the dataset with the 
        following data:
        * samples (Numpy ndarray): 2xM array with the samples 
            (2D points), being M the number of samples.
        * labels (Numpy ndarray): 1xM array with the class labels
            for all samples. Labels are numbers in the range 
            [1, N], being N the number of classes.
        * num_classes (int): The number of classes for the
            selected scenario.

    Returns:
        None
    """

    # Dataset.
    samples = dataset['samples']
    labels = dataset['labels']
    num_classes = dataset['num_classes']

    # Check number of classes.
    sys.stdout.write("Checking exercise\n")
    sys.stdout.write(">> Checking the number of classes ... ")
    assert num_classes ==4, 'Num. classes should be %s, but returned %s \n'\
            % (4, num_classes)
    sys.stdout.write(" ok\n")

    # Check the number of samples.
    num_c1 = len(labels[labels==1])
    num_c2 = len(labels[labels==2])
    num_c3 = len(labels[labels==3])
    num_c4 = len(labels[labels==4])
    sys.stdout.write(">> Checking the number of samples for class 1 ... ")
    assert num_c1 ==300, 'Num. samples should return %s, but returned %s\n'\
            % (300, num_c1)
    sys.stdout.write(" ok\n")
    sys.stdout.write(">> Checking the number of samples for class 2 ... ")
    assert num_c2 == 50, 'Num. samples should return %s, but returned %s\n'\
            % (50, num_c2)
    sys.stdout.write(" ok\n")
    sys.stdout.write(">> Checking the number of samples for class 3 ... ")
    assert num_c3 == 100, 'Num. samples should return %s, but returned %s\n'\
            % (100, num_c3)
    sys.stdout.write(" ok\n")
    sys.stdout.write(">> Checking the number of samples for class 4 ... ")
    assert num_c4 == 500, 'Num. samples should return %s, but returned %s\n'\
            % (500, num_c4)
    sys.stdout.write(" ok\n")

    # Standard deviation.
    std_1 = np.std(samples[:,labels[0,:]==1],axis=1)
    std_2 = np.std(samples[:,labels[0,:]==2],axis=1)
    std_3 = np.std(samples[:,labels[0,:]==3],axis=1)
    std_4 = np.std(samples[:,labels[0,:]==4],axis=1)

    # Check noise.
    sys.stdout.write(">> Checking noise for class 1 ... ")
    ref_1 = [0.04625352, 0.04268269]
    assert np.isclose(std_1, ref_1, 0.001).all(), \
            'It should return %s, but returned %s\n' % (str(ref_1), str(std_1))
    sys.stdout.write(" ok\n")
    sys.stdout.write(">> Checking noise for class 2 ... ")
    ref_2 = [0.04617169, 0.04489254]
    assert np.isclose(std_2, ref_2, 0.001).all(), \
            'It should return %s, but returned %s\n' % (str(ref_2), str(std_2))
    sys.stdout.write(" ok\n")
    sys.stdout.write(">> Checking noise for class 3 ... ")
    ref_3 = [0.07756618, 0.03676369]
    assert np.isclose(std_3, ref_3, 0.001).all(), \
            'It should return %s, but returned %s\n' % (str(ref_3), str(std_3))
    sys.stdout.write(" ok\n")
    sys.stdout.write(">> Checking noise for class 4 ... ")
    ref_4 = [0.04611338, 0.04503732]
    assert np.isclose(std_4, ref_4, 0.001).all(), \
            'It should return %s, but returned %s\n' % (str(ref_4), str(std_4))
    sys.stdout.write(" ok\n")

    sys.stdout.write("The exercise is correct. Well done !!!\n")
    
# Test mean and std implementation.
def test_mean_std(class_mean_std_fun, dataset):
    """ This function tests the implementation of mean and standard deviation
    using the input dataset.

    Parameters:
    + dataset (dict): Dictionary containing the dataset with the 
        following data:
        * samples (Numpy ndarray): 2xM array with the samples 
            (2D points), being M the number of samples.
        * labels (Numpy ndarray): 1xM array with the class labels
            for all samples. Labels are numbers in the range 
            [1, N], being N the number of classes.
        * num_classes (int): The number of classes for the
            selected scenario.

    Returns:
        None
    """

    # Dataset.
    samples = dataset['samples']
    labels = dataset['labels']
    num_classes = dataset['num_classes']

    # Compute mean and standard matrices for input dataset.
    list_mean = []
    list_std = []
    for c in range(num_classes):
        mean, std = class_mean_std_fun(samples, labels, c+1)
        list_mean.append(mean)
        list_std.append(std)
    mat_mean = np.array(list_mean)
    mat_std = np.array(list_std)

    # Check mean and std values.
    sys.stdout.write("Checking implementation\n")
    sys.stdout.write(">> Checking mean values ... ")
    ref_mean = np.array([[0.25265312, 0.30313714], [0.74553564, 0.74900793], \
                         [0.39830664, 0.58051154], [0.70228882, 0.25130855]])
    assert np.isclose(mat_mean, ref_mean, 0.001).all(), \
            'It should return %s, but returned %s\n' % (str(ref_mean), str(mat_mean))
    sys.stdout.write(" ok\n")
    sys.stdout.write(">> Checking standard deviation values ... ")
    ref_std = np.array([[0.04625352, 0.04268269],  [0.04617169, 0.04489254], \
                        [0.07756618, 0.03676369], [0.04611338, 0.04503732]])
    assert np.isclose(mat_std, ref_std, 0.001).all(), \
            'It should return %s, but returned %s\n' % (str(ref_std), str(mat_std))
    sys.stdout.write(" ok\n")

    sys.stdout.write("The exercise is correct. Well done !!!\n")
    