# Libraries.
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

# Load USPS digits dataset.
def load_dataset(path='../data/USPS.mat', digit=None):
    """Load the USPS digits dataset, optionally restricted to one digit.

    Args:
        path: Location of the Matlab file. It must contain the variables
            'X' (images) and 'Y' (labels).
        digit: Digit to keep, or None to return every sample. Samples are
            kept where the first label column equals ``digit + 1``.

    Returns:
        Tuple ``(X, y)`` with the images and their labels (filtered when
        ``digit`` is given).
    """

    # Load dataset.
    dataset = load_mat(path)

    # Get images and their labels.
    X = dataset['X']  # Images.
    y = dataset['Y']  # Labels.

    # Dataset subset for a particular digit. The test is against None on
    # purpose: 0 is a valid digit but is falsy, so a plain truthiness check
    # would silently return the whole dataset for digit=0.
    if digit is not None:
        indx = np.where(y[:, 0] == digit + 1)[0]  # Indexes for the chosen digit.
        X = X[indx, :]  # Images for the chosen digit.
        y = y[indx]  # Labels for the chosen digit.

    return X, y

# Load Yale faces dataset.
def load_yale_face_dataset(subset='set1'):
    """Load one subset of the Yale faces dataset.

    Args:
        subset: One of 'set1', 'set2', 'set3' or 'same'.

    Returns:
        Tuple ``(X, y)``. For 'set1'/'set2'/'set3', ``X`` and ``y`` come from
        the file's 'X' and 'Y' variables. For 'same', ``X`` comes from the
        'images' variable and ``y`` is the placeholder ``np.array([0])``.
        In every case each row of ``X`` is viewed as a 50x50 image and
        transposed before being flattened back into a row.

    Raises:
        AssertionError: If ``subset`` is not one of the accepted names.
    """

    # File backing each subset.
    paths = {
        'set1': '../data/Subset1YaleFaces.mat',
        'set2': '../data/Subset2YaleFaces.mat',
        'set3': '../data/Subset3YaleFaces.mat',
        'same': '../data/SameFace.mat',
    }

    # Check subset. Raised explicitly (rather than via an `assert` statement)
    # so the check still runs under `python -O`; the exception type and
    # message are the same as before.
    if subset not in paths:
        raise AssertionError("Incorrect subset")

    # Load dataset.
    dataset = load_mat(paths[subset])

    # Get data.
    if subset == 'same':
        X = dataset['images']  # Images.
        y = np.array([0])  # Placeholder: this file provides no labels here.
    else:
        X = dataset['X']  # Images.
        y = dataset['Y']  # Labels.

    # Rotate images: view every row as a 50x50 image, swap its two image
    # axes (a transpose), and flatten back to the original layout. The final
    # reshape of the transposed view yields a new array, so the loaded data
    # is not modified in place.
    N = X.shape[0]
    X = np.reshape(X, (N, 50, 50)).transpose(0, 2, 1).reshape(X.shape)

    return X, y

# Load mat file.
def load_mat(path):
    """ Read a Matlab (mat) file and return its variables as a python
    dictionary. Entries whose name contains a double underscore (such as the
    bookkeeping keys added by the loader) are left out. """
    # Read everything stored in the file.
    contents = sio.loadmat(path)
    # Keep only the entries without '__' in their name.
    return {name: value for name, value in contents.items() if '__' not in name}

# Show images.
def show_images(images, title=None, maxi=10):
    """ Display the first images of the input side by side in one row.

    Each row of `images` is reshaped into a square picture whose side is
    int(sqrt(number of columns)). At most `maxi` pictures are drawn, and
    `title` (when given) is set as the plot title. """

    # How many pictures are available and how wide each square picture is.
    total = images.shape[0]
    side = int(np.sqrt(images.shape[1]))

    # Never draw more than `maxi` pictures.
    shown = np.minimum(maxi, total)

    # Figure.
    fig = plt.figure(figsize=(24, 3))
    if title:
        plt.title(title, fontsize=28)
    plt.axis('off')

    # One subplot per picture. add_subplot makes the new axes current, so
    # the plt.imshow call that follows draws into it.
    for k in range(shown):
        fig.add_subplot(1, shown, k + 1)
        picture = images[k, :].reshape((side, side))
        plt.imshow(picture, cmap='gray')

    plt.show()

# Plot eigenvalues.
def plot_eigenvalues(lambdas):
    """ Plot the eigenvalues and their cumulative distribution.

    `lambdas` is the set of all eigenvalues of the data covariance matrix.
    It is expected to be ranked in decreasing order; this function does not
    sort it. The top panel shows each eigenvalue against its 1-based
    component index, and the bottom panel shows the running sum of the
    eigenvalues divided by their total. """

    # Num. lambda elements and the 1-based component axis shared by both plots.
    L = len(lambdas)
    components = np.arange(1, L + 1)

    # Figure.
    plt.figure(figsize=(15, 5))
    plt.subplot(2, 1, 1)
    plt.plot(components, lambdas, 'b-', linewidth=4)
    plt.xlabel('Component', fontsize=18)
    plt.ylabel('Eigenvalue', fontsize=18)
    plt.grid(True)

    # Cumulative. A single running sum replaces re-summing a growing prefix
    # on every iteration, which was quadratic in the number of eigenvalues.
    total_variance = np.sum(lambdas)
    cumul = np.cumsum(lambdas) / total_variance
    plt.subplot(2, 1, 2)
    plt.plot(components, cumul, 'r-', linewidth=4)
    plt.xlabel('Component', fontsize=18)
    plt.ylabel('Cumulative', fontsize=18)
    plt.grid(True)
    plt.show()

