# Libraries.
import sys
import numpy as np

# Test random forest training.
def test_rf_training(train_fun):
    """ This function tests the implementation to train a random forest. """

    # Train samples and labels.
    samples = np.array([[1,2],[2,1], [1,1], [2,2], \
                        [-1,-1], [0,0], [-2,0], [-1,0]]).T
    labels = np.array([[1, 1, 1, 1, 0, 0, 0, 0]])
    dataset = {'samples':samples, 'labels':labels}

    # Train random forest.
    rfs = train_fun(dataset, 5, 2, 1, 'entropy')

    # Check decision tree training.
    sys.stdout.write("Checking random forest training\n")
    sys.stdout.write(">> Running test 1 ... ")
    x = [[2,2]]
    out = rfs.predict(x)
    assert np.isclose(out, 1), 'For the sample %s, the random forest should predict %s, but returned %s \n' % (str(x), str(1), str(out))
    sys.stdout.write(" ok\n")
    sys.stdout.write(">> Running test 2 ... ")
    x = [[0,0]]
    out = rfs.predict(x)
    assert np.isclose(out, 0), 'For the sample %s, the random forest should predict %s, but returned %s \n' % (str(x), str(0), str(out))
    sys.stdout.write(" ok\n")
    sys.stdout.write(">> Running test 3 ... ")
    x = [[5,5]]
    out = rfs.predict(x)
    assert np.isclose(out, 1), 'For the sample %s, the random forest should predict %s, but returned %s \n' % (str(x), str(1), str(out))
    sys.stdout.write(" ok\n")
    sys.stdout.write(">> Running test 4 ... ")
    x = [[-1,0]]
    out = rfs.predict(x)
    assert np.isclose(out, 0), 'For the sample %s, the random forest should predict %s, but returned %s \n' % (str(x), str(0), str(out))
    sys.stdout.write(" ok\n")

    sys.stdout.write("The exercise is correct. Well done !!!\n")
