# Libraries.
import sys
import numpy as np

# Test decision tree training.
def test_tree_train(train_fun):
    """ This function tests the implementation to train a decision tree. """

    # Train samples and labels.
    samples = np.array([[1,2],[2,1], [1,1], [2,2], \
                        [-1,-1], [0,0], [-2,0], [-1,0]]).T
    labels = np.array([1, 1, 1, 1, 0, 0, 0, 0])
    dataset = {'samples':samples, 'labels':labels}

    # Train tree.
    dtree = train_fun(dataset, 2, 'entropy')

    # Check decision tree training.
    sys.stdout.write("Checking decision tree training\n")
    sys.stdout.write(">> Running test 1 ... ")
    x = [[2,2]]
    out = dtree.predict(x)
    assert np.isclose(out, 1), 'For the sample %s, the decision tree should predict %s, but returned %s \n' % (str(x), str(1), str(out))
    sys.stdout.write(" ok\n")
    sys.stdout.write(">> Running test 2 ... ")
    x = [[0,0]]
    out = dtree.predict(x)
    assert np.isclose(out, 0), 'For the sample %s, the decision tree should predict %s, but returned %s \n' % (str(x), str(0), str(out))
    sys.stdout.write(" ok\n")
    sys.stdout.write(">> Running test 3 ... ")
    x = [[5,5]]
    out = dtree.predict(x)
    assert np.isclose(out, 1), 'For the sample %s, the decision tree should predict %s, but returned %s \n' % (str(x), str(1), str(out))
    sys.stdout.write(" ok\n")
    sys.stdout.write(">> Running test 4 ... ")
    x = [[-1,0]]
    out = dtree.predict(x)
    assert np.isclose(out, 0), 'For the sample %s, the decision tree should predict %s, but returned %s \n' % (str(x), str(0), str(out))
    sys.stdout.write(" ok\n")
    
    sys.stdout.write("The exercise is correct. Well done !!!\n")


# Test classification accuracies.
def test_accuracies(train_acc, test_acc, error=0.005):
    ''' This function tests the obtained train and test classification
    accuracies. '''

    # Fix random seed.
    np.random.seed(33)
    a1 = np.random.randn()*train_acc
    a2 = np.random.randn()*test_acc
    
    # Checking accuracy scores.
    sys.stdout.write("Checking classification accuracies\n")
    sys.stdout.write(">> Checking training accuracy ... ")
    assert np.isclose(a1, -0.311838, error), 'The accuracy value is not correct, please try again'
    sys.stdout.write(" ok\n")
    sys.stdout.write(">> Checking test accuracy ... ")
    assert np.isclose(a2, -1.490771, error), 'The accuracy value is not correct, please try again'
    sys.stdout.write(" ok\n")
    
    sys.stdout.write("The exercise is correct. Well done !!!\n")

