#    Copyright (c) 2016 Idiap Research Institute, http://www.idiap.ch/
#    Written by Nikolaos Pappas <nikolaos.pappas@idiap.ch>,
#
#    test_eval.py is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#    GNU General Public License for more details.

import sys
import os
import csv
import random
import pickle
import numpy as np
from nltk.tokenize import word_tokenize
from sklearn.metrics import mean_squared_error, accuracy_score


# Rated aspects of a review. The position of a name in this list is used as
# the aspect index throughout the file; the names double as CSV column names
# and as the base names of the per-aspect prediction files.
aspects = [
	'overall',
	'performance',
	'story',
]

# Ordered answer scale for the 'exp_power' column; the position of an answer
# in this list is its numeric level (0-based).
levels = [
	'Not at all',
	'A little',
	'Moderately',
	'Rather well',
	'Very well',
]

def load_csv(path=None):
	"""Group the sentence rows of a CSV file into tokenized reviews.

	The file must provide the columns 'idx' (review id), 'sidx' (sentence
	id), 'sentence' and one column per entry of ``aspects``. Consecutive
	rows sharing the same 'idx' form one review; within a review,
	consecutive rows sharing the same 'sidx' are treated as the same
	sentence, which is lower-cased and tokenized only once.

	Args:
		path: CSV file to read. When omitted, ``sys.argv[1]`` is used; it
			is looked up at call time so that importing this module does
			not require command-line arguments.

	Returns:
		A tuple ``(reviews, ridxs, labels)`` where ``reviews[n]`` is the
		list of token lists of the n-th review, ``ridxs[n]`` is its integer
		review id and ``labels[j][n]`` is its value for ``aspects[j]``,
		kept as the raw string read from the file. An empty file yields
		empty lists.
	"""
	if path is None:
		path = sys.argv[1]

	reviews, ridxs = [], []
	labels = [[] for _ in aspects]
	sentences = []
	# Last row seen, plus its ids; None until the first row has been read.
	prev_row = None
	prev_idx = prev_sidx = None

	# 'rb' is the mode the Python 2 csv module expects for its input file.
	with open(path, 'rb') as csvfile:
		for row in csv.DictReader(csvfile):
			idx = int(row['idx'])
			sidx = int(row['sidx'])
			if prev_row is not None and idx != prev_idx:
				# A new review starts: store the finished one. Its labels
				# come from its own last row (prev_row), not from `row`,
				# which already belongs to the next review.
				reviews.append(sentences)
				ridxs.append(prev_idx)
				for j, aspect in enumerate(aspects):
					labels[j].append(prev_row[aspect])
				sentences = []
				# The first sentence of a review is always kept, even if
				# its sidx equals the last sidx of the previous review.
				prev_sidx = None
			if sidx != prev_sidx:
				sentences.append(word_tokenize(row['sentence'].lower()))
			prev_row, prev_idx, prev_sidx = row, idx, sidx

	# Store the review that was still open when the file ended.
	if prev_row is not None:
		reviews.append(sentences)
		ridxs.append(prev_idx)
		for j, aspect in enumerate(aspects):
			labels[j].append(prev_row[aspect])
	return reviews, ridxs, labels

def load_hash(path):
	"""Index the per-sentence annotations of a CSV file.

	Every row must provide 'idx', 'sidx', 'aspect', 'exp_power' and
	'confidence'. The result is a nested dictionary

		{review id: {sentence id: {aspect index: (level, confidence)}}}

	where the ids are integers, the aspect index is the position of the
	'aspect' value in ``aspects``, ``level`` is the position of the
	'exp_power' value in ``levels`` and ``confidence`` is the raw string
	from the file. A later row for the same (review, sentence, aspect)
	replaces the earlier one. ``ValueError`` is raised by ``list.index``
	when an answer or aspect is not in the corresponding list.
	"""
	table = {}
	with open(path, 'rb') as handle:
		for record in csv.DictReader(handle):
			review_id = int(record['idx'])
			sentence_id = int(record['sidx'])
			aspect_name = record['aspect']
			answer = record['exp_power']
			confidence = record['confidence']
			annotation = (levels.index(answer), confidence)
			aspect_pos = aspects.index(aspect_name)
			# Create the intermediate dictionaries on first use.
			per_sentence = table.setdefault(review_id, {})
			per_aspect = per_sentence.setdefault(sentence_id, {})
			per_aspect[aspect_pos] = annotation
	return table


def eval_ratpred(pred_folder, conf_thresh=0.5, path='../aggregated_test.csv'):
	"""Compute the mean squared error of predicted ratings, per aspect.

	For each aspect the predictions are read from the file
	``pred_folder + '<aspect>.txt'`` (one number per line; note that
	``pred_folder`` is used as a plain prefix, so a directory needs its
	trailing separator). They are compared with the review labels returned
	by ``load_csv(path)``, divided by 5.

	Args:
		pred_folder: Prefix of the per-aspect prediction files.
		conf_thresh: Unused here; kept so the signature matches
			``eval_summary``.
		path: CSV file holding the reference labels.

	Returns:
		A dictionary mapping each aspect name to its mean squared error.
		The dictionary is also printed.
	"""
	reviews, ridxs, labels = load_csv(path)
	all_cts = {}
	for k, aspect in enumerate(aspects):
		# Close the file deterministically and skip blank lines, so that a
		# trailing newline does not end up as float('').
		with open(pred_folder + '%s.txt' % aspect) as pred_file:
			rat_preds = [float(line) for line in pred_file if line.strip()]
		# NOTE(review): the division by 5 presumably maps 1-5 ratings onto
		# the scale of the predictions - confirm against the model output.
		rat_actual = np.array([float(v) for v in labels[k]]) / 5.
		# Reference values first, as in sklearn's (y_true, y_pred)
		# signature; the error itself is symmetric in its arguments.
		all_cts[aspect] = mean_squared_error(rat_actual, rat_preds)
	# The parenthesised form prints the same text under Python 2's print
	# statement and is also valid as a function call.
	print(all_cts)
	return all_cts

def eval_summary(pred_folder, conf_thresh=0.5, path='../aggregated_test.csv'):
	"""Score per-sentence predictions against human annotations.

	For each aspect the predictions are read from the file
	``pred_folder + '<aspect>.txt'`` (one number per line; ``pred_folder``
	is used as a plain prefix, so a directory needs its trailing
	separator). Each annotated sentence is paired with the next prediction
	in the file. A prediction counts as correct when it differs from the
	human level (position in ``levels`` plus one) by at most
	``conf_thresh``.

	The accuracy is computed eleven times per aspect, keeping only the
	annotations whose confidence is at least 0.0, 0.1, ..., 1.0.

	Args:
		pred_folder: Prefix of the per-aspect prediction files.
		conf_thresh: Maximum absolute difference between a prediction and
			the human level for the prediction to count as correct.
		path: CSV file holding the annotations.

	Returns:
		A dictionary mapping each aspect name to a list of eleven
		accuracies, one per confidence threshold in increasing order. An
		entry is NaN when no annotation reaches the threshold. The
		dictionary is also printed.
	"""
	_, ridxs, _ = load_csv(path)
	exp_hash = load_hash(path)
	# i / 10.0 instead of 0.1 * i: the product gives values such as
	# 0.30000000000000004, which would wrongly exclude an annotation whose
	# confidence is exactly 0.3.
	thresholds = [i / 10.0 for i in range(11)]
	all_cts = {}
	for k, aspect in enumerate(aspects):
		# Read each prediction file once (not once per threshold), close it
		# deterministically and skip blank lines such as a trailing newline.
		with open(pred_folder + '%s.txt' % aspect) as pred_file:
			att_preds = [float(line) for line in pred_file if line.strip()]

		# Pair every annotation with its prediction a single time; the
		# thresholds below only filter this list.
		# NOTE(review): predictions are consumed in the iteration order of
		# the per-review dictionaries - confirm that this matches the order
		# in which the prediction files were written.
		scored = []
		count = 0
		for idx in ridxs:
			for sidx in exp_hash[idx]:
				level, confidence = exp_hash[idx][sidx][k]
				scored.append((level + 1, float(confidence), att_preds[count]))
				count += 1

		accuracies = []
		for c_t in thresholds:
			kept = [(hum, pred) for hum, conf, pred in scored if conf >= c_t]
			if kept:
				hits = sum(1 for hum, pred in kept
						   if abs(hum - pred) <= conf_thresh)
				accuracies.append(float(hits) / len(kept))
			else:
				# Nothing to score at this threshold: report NaN instead of
				# dividing by zero.
				accuracies.append(float('nan'))
		all_cts[aspect] = accuracies
	# The parenthesised form prints the same text under Python 2's print
	# statement and is also valid as a function call.
	print(all_cts)
	return all_cts


if __name__ == "__main__":
	# Command line: <pred_folder> <mode>, where <mode> is "--sum" for the
	# per-sentence evaluation or "--rat" for the rating evaluation.
	# <pred_folder> is used as a plain prefix of the prediction file names,
	# so a directory must be given with its trailing separator.
	usage = "usage: %s <pred_folder> (--sum | --rat)\n" % sys.argv[0]
	if len(sys.argv) < 3:
		sys.stderr.write(usage)
		sys.exit(2)
	file_path = sys.argv[1]
	mode = sys.argv[2]
	if mode == "--sum":
		all_acc = eval_summary(file_path)
	elif mode == "--rat":
		all_rat = eval_ratpred(file_path)
	else:
		# Report an unrecognised mode instead of silently doing nothing.
		sys.stderr.write("unknown mode: %s\n" % mode)
		sys.stderr.write(usage)
		sys.exit(2)
