# Libraries.
import sys
import numpy as np

# Test linear svm.
def test_linear_svm(linear_svm):
    """Check a linear SVM implementation on two small 2-D point sets.

    The classifier is trained once on six points (three labelled +1, three
    labelled -1) and then evaluated twice: on the training points and on
    seven further points. Each check compares the *sum* of the predictions
    with a known value, so it checks the balance of predicted labels rather
    than every label individually. Progress messages go to ``sys.stdout``.

    Parameters
    ----------
    linear_svm : callable
        Factory called as ``linear_svm(X, y, 1)``; the returned object must
        provide ``predict(X)``. The third argument is presumably the
        regularisation constant C -- confirm against the exercise statement.

    Raises
    ------
    AssertionError
        If a prediction sum is not close to its expected value.
    """

    def check_sum(p, expected):
        """Write " ok" if ``sum(p)`` is close to ``expected``, else raise."""
        val = np.sum(p)
        # Explicit raise instead of ``assert``: asserts are stripped when
        # Python runs with -O, which would let any submission "pass".
        if not np.isclose(val, expected):
            raise AssertionError(
                'The sum of element in the prediction vector should return %s, '
                'but returned %s \n' % (str(expected), str(val))
            )
        sys.stdout.write(" ok\n")

    # Training data; the classes are split by x0 + x1 around 6.
    X = np.array([[10, 2], [9, 3], [5, 5], [0, 0], [1, 1], [-2, 0]])
    y = np.array([1, 1, 1, -1, -1, -1])

    # Train once and reuse the model for both checks.
    svm = linear_svm(X, y, 1)
    p = svm.predict(X)

    # Test 1: training points. With +1/-1 predictions, a sum of 0 means as
    # many positive as negative predictions.
    sys.stdout.write("Checking implementation\n")
    sys.stdout.write(">> Running test 1 ... ")
    check_sum(p, 0)

    # Test 2: seven new points. With +1/-1 predictions, a sum of -3 means
    # two are predicted +1 and five are predicted -1.
    X = np.array([[7, 3], [4, 4], [0, 0], [1, 1], [1, -2], [-2, 0], [-10, -2]])
    p = svm.predict(X)

    sys.stdout.write(">> Running test 2 ... ")
    check_sum(p, -3)

    sys.stdout.write("The exercise is correct. Well done !!!\n")


# This is a grading helper that is called directly with the student's
# factory, not a pytest test. Without this flag, pytest would collect it
# (the name starts with ``test_``) and fail looking for a
# ``linear_svm`` fixture.
test_linear_svm.__test__ = False

# Test rbf svm.
def test_rbf_svm(rbf_svm):
    """Check an RBF-kernel SVM implementation on two small 2-D point sets.

    The classifier is trained once on eight points (four labelled +1, four
    labelled -1) and then evaluated twice: on the training points and on
    seven further points. Each check compares the *sum* of the predictions
    with a known value, so it checks the balance of predicted labels rather
    than every label individually. Progress messages go to ``sys.stdout``.

    Parameters
    ----------
    rbf_svm : callable
        Factory called as ``rbf_svm(X, y, 1, 10)``; the returned object must
        provide ``predict(X)``. The two numeric arguments are presumably the
        regularisation constant C and a kernel width/gamma -- confirm
        against the exercise statement.

    Raises
    ------
    AssertionError
        If a prediction sum is not close to its expected value.
    """

    def check_sum(p, expected):
        """Write " ok" if ``sum(p)`` is close to ``expected``, else raise."""
        val = np.sum(p)
        # Explicit raise instead of ``assert``: asserts are stripped when
        # Python runs with -O, which would let any submission "pass".
        if not np.isclose(val, expected):
            raise AssertionError(
                'The sum of element in the prediction vector should return %s, '
                'but returned %s \n' % (str(expected), str(val))
            )
        sys.stdout.write(" ok\n")

    # Training data: four +1 and four -1 labels.
    X = np.array([[1, 2], [0, 3], [4, 5], [2, 5], [3, 1], [-2, 0], [5, 1], [4, 3]])
    y = np.array([1, 1, 1, -1, -1, -1, 1, -1])

    # Train once and reuse the model for both checks.
    svm = rbf_svm(X, y, 1, 10)
    p = svm.predict(X)

    # Test 1: training points. With +1/-1 predictions, a sum of 0 means as
    # many positive as negative predictions.
    sys.stdout.write("Checking implementation\n")
    sys.stdout.write(">> Running test 1 ... ")
    check_sum(p, 0)

    # Test 2: seven new points. With +1/-1 predictions, a sum of -5 means
    # one is predicted +1 and six are predicted -1.
    X = np.array([[7, 3], [4, 4], [0, 0], [1, 1], [1, -2], [-2, 0], [-10, -2]])
    p = svm.predict(X)

    sys.stdout.write(">> Running test 2 ... ")
    check_sum(p, -5)

    sys.stdout.write("The exercise is correct. Well done !!!\n")


# This is a grading helper that is called directly with the student's
# factory, not a pytest test. Without this flag, pytest would collect it
# (the name starts with ``test_``) and fail looking for an ``rbf_svm``
# fixture.
test_rbf_svm.__test__ = False

