# Libraries.
import cv2
import numpy as np
from matplotlib import pyplot as plt
import scipy.io as sio

# Load 2D dataset.
def load_2d_dataset(path):
    """Load a labelled 2D point dataset from a text file.

    Each non-blank line must contain three whitespace-separated fields,
    ``<label> <key>:<x1> <key>:<x2>``; only the text after the colon of each
    feature field is used as the feature value.

    Args:
        path: Path of the text file to read.

    Returns:
        (data, labels): ``data`` is an (N, 2) float array of feature vectors
        and ``labels`` a length-N int array of class labels. An empty file
        yields arrays of shape (0, 2) and (0,).

    Raises:
        ValueError: If a non-blank line does not have exactly three fields or
            a value cannot be converted to a number.
        IndexError: If a feature field contains no ``:``.
    """
    data = []    # 2D feature vectors (points).
    labels = []  # Class labels.

    # ``with`` guarantees the file is closed even if parsing raises.
    with open(path, 'r') as data_file:
        for line in data_file:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            label, f1, f2 = fields
            labels.append(int(label))
            data.append([float(f1.split(':')[1]), float(f2.split(':')[1])])

    # reshape keeps the (N, 2) shape even when no points were read.
    return np.array(data, dtype=float).reshape(-1, 2), np.array(labels, dtype=int)

# Read emails.
def read_emails(path, num_emails, num_words):
    """Read a sparse email/word count text file into a dense count matrix.

    Each non-blank line holds three whitespace-separated integers,
    ``<email_id> <word_id> <count>``, with 1-based ids. Counts for repeated
    (email, word) pairs are summed.

    Args:
        path: Path of the text file to read.
        num_emails: Number of rows (emails) to allocate.
        num_words: Number of columns (words) to allocate.

    Returns:
        A float array of shape (num_emails, num_words) holding the counts.

    Raises:
        ValueError: If a line cannot be parsed or an id is smaller than 1.
        IndexError: If an id exceeds the allocated matrix size.
    """
    data = np.zeros((num_emails, num_words))

    # ``with`` guarantees the file is closed even if parsing raises.
    with open(path, 'r') as data_file:
        for line in data_file:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            # Ids in the file are 1-based; numpy indices are 0-based.
            doc = int(fields[0]) - 1   # Document (email).
            word = int(fields[1]) - 1  # Word.
            if doc < 0 or word < 0:
                # A negative index would silently wrap to the last row/column.
                raise ValueError('Ids must be 1-based, got line: {0!r}'.format(line))
            # split() already removed the newline, so the whole count token is
            # parsed even when the final line has no trailing newline.
            data[doc, word] += int(fields[2])  # Occurrences.

    return data

# Email labels.
def email_labels(path, num_emails):
    """Read the email labels (spam/nonspam), one integer per line.

    Args:
        path: Path of the text file to read.
        num_emails: Length of the label array to allocate.

    Returns:
        A float array of length ``num_emails``. Entry ``n`` holds the label on
        the n-th non-blank line; entries past the end of the file stay 0.

    Raises:
        ValueError: If a non-blank line is not an integer.
        IndexError: If the file has more than ``num_emails`` labels.
    """
    data = np.zeros(num_emails)

    # ``with`` guarantees the file is closed even if parsing raises.
    with open(path, 'r') as data_file:
        # Only non-blank lines count as labels, so a trailing empty line is
        # harmless instead of crashing int().
        texts = (line.strip() for line in data_file)
        for n, text in enumerate(t for t in texts if t):
            data[n] = int(text)

    return data

# Load emails.
def load_emails(num_emails, num_words=2500, path='./data/emails/'):
    """Load the train and test splits of the emails (spam/nonspam) dataset.

    Args:
        num_emails: Size of the training subset; must be 50, 100, 400 or 700.
            It selects ``train-features-<n>.txt`` and ``train-labels-<n>.txt``.
        num_words: Number of word columns in the feature matrices.
        path: Directory containing the dataset text files (a trailing
            separator is optional).

    Returns:
        (train_data, train_labels, test_data, test_labels): count matrices of
        shape (num_emails, num_words) and (260, num_words), plus their label
        vectors.

    Raises:
        ValueError: If ``num_emails`` is not one of the supported sizes.
    """
    import os

    # Only these training-set sizes are supported.
    if num_emails not in (50, 100, 400, 700):
        raise ValueError(
            'Incorrect number of emails. It should be 50, 100, 400 or 700')

    # The test split has a fixed number of emails.
    num_test_emails = 260

    # Dataset files; os.path.join works whether or not path ends with a slash.
    train_data_path = os.path.join(path, 'train-features-{0}.txt'.format(num_emails))
    test_data_path = os.path.join(path, 'test-features.txt')
    train_labels_path = os.path.join(path, 'train-labels-{0}.txt'.format(num_emails))
    test_labels_path = os.path.join(path, 'test-labels.txt')

    # Load train and test emails data (arrays).
    train_data = read_emails(train_data_path, num_emails, num_words)
    test_data = read_emails(test_data_path, num_test_emails, num_words)

    # Load train and test emails labels (spam/nonspam).
    train_labels = email_labels(train_labels_path, num_emails)
    test_labels = email_labels(test_labels_path, num_test_emails)

    return train_data, train_labels, test_data, test_labels

# Load mat file.
def load_mat(path):
    """Load a MATLAB (.mat) file into a dict of its variables.

    ``scipy.io.loadmat`` also returns metadata entries whose keys start with a
    double underscore; those are dropped. Only a *leading* ``__`` marks
    metadata, so variables that merely contain ``__`` (e.g. ``a__b``) are kept.

    Args:
        path: Path of the .mat file.

    Returns:
        A dict mapping each variable name to the value loadmat returned.
    """
    mat = sio.loadmat(path)
    return {name: value for name, value in mat.items()
            if not name.startswith('__')}

# Load partitioned dataset.
def load_partitioned_dataset(dat_path='./data/ex2Data/V.mat',
                             lab_path='./data/ex2Data/L.mat'):
    """Load three data partitions and their labels from two .mat files.

    The data file must hold a variable ``V`` and the labels file a variable
    ``L``. Elements ``[0, 0]``, ``[0, 1]`` and ``[0, 2]`` of each are read
    (presumably 1xK MATLAB cell arrays; confirm against the data files).

    Args:
        dat_path: Path of the .mat file containing ``V``.
        lab_path: Path of the .mat file containing ``L``.

    Returns:
        ([d1, d2, d3], [l1, l2, l3]): the three data partitions, and for each
        a 1-D label array with one entry per element along the first axis of
        the matching partition.
    """
    data = load_mat(dat_path)['V']
    label = load_mat(lab_path)['L']

    sets = [data[0, i] for i in range(3)]
    # reshape (rather than ravel) raises if a label array's size does not
    # match its partition's length.
    labels = [label[0, i].reshape(len(d)) for i, d in enumerate(sets)]

    return sets, labels

