import time
import numpy as np
from numpy.linalg import norm
from sklearn.base import BaseEstimator
from scipy.sparse import lil_matrix, hstack
from sklearn.metrics import mean_squared_error, mean_absolute_error


class JSWeights(BaseEstimator):
    """Squared-error model that pools the rows of each sparse sample.

    Every sample ``bi`` is a sparse matrix; a leading column of ones (the
    intercept) is prepended to it. Its rows are combined with softmax weights
    ``pi = softmax(bi . O)`` into ``x_out = bi.T . pi`` and the prediction is
    ``0.5 * W.T . x_out``. ``W`` and ``O`` are learned by minibatch gradient
    descent on ``(prediction - target) ** 2`` using AdaGrad (default) or
    Nesterov momentum.

    Hyper-parameters (names kept for backward compatibility):
        l1: learning rate (``None`` -> 0.1).
        l2: momentum coefficient, only used when ``adagrad`` is False
            (``None`` -> 0.0).
        l3: minibatch size (``None`` -> 50).
        adagrad: use AdaGrad updates instead of Nesterov momentum.
    """

    def __init__(self, l1=0.2, l2=0.0, l3=50, adagrad=True):
        self.adagrad = adagrad
        # NOTE(review): the ``None`` fallback for l1 (0.1) differs from the
        # keyword default (0.2); kept unchanged so existing callers behave
        # the same. Confirm which value is intended.
        self.l1 = 0.1 if l1 is None else l1
        self.l2 = 0.0 if l2 is None else l2
        self.l3 = 50 if l3 is None else l3

    def predict(self, B, validate=False):
        """Return one prediction per sample of ``B`` as a 1-D array.

        Args:
            B: iterable of sparse sample matrices.
            validate: when True, ``B`` is assumed to already carry the
                intercept column (as built by ``intercept_sparse``) and is
                used unchanged.
        """
        if not validate:
            B = self.intercept_sparse(B)
        return np.array([self.hypo(bi)[0][0] for bi in B])

    def predict_weights(self, B, norm=False):
        """Return per-row softmax weights, raw scores and predictions.

        Args:
            B: iterable of sparse sample matrices (the intercept is added here).
            norm: unused; kept only for backward compatibility.

        Returns:
            ``(weights, labels, pred)`` lists with one entry per sample:
            ``weights`` is the 1-D softmax weight vector over the sample's
            rows, ``labels`` is that vector multiplied back by the softmax
            normaliser (the max-shifted exponentiated scores), and ``pred``
            is the scalar prediction.
        """
        pred, weights, labels = [], [], []
        for bi in self.intercept_sparse(B):
            scores = bi.dot(self.O)
            e_x = np.exp(scores - np.max(scores))
            total = e_x.sum()
            w = e_x / total
            pred.append((0.5 * self.W.T.dot(bi.T.dot(w)))[0][0])
            weights.append(w[:, 0].copy())
            labels.append(w[:, 0] * total)
        return weights, labels, pred

    def fit(self, X, Y):
        """Learn ``W`` and ``O`` by minibatch gradient descent.

        Training stops when an epoch improves the MSE by less than 1e-4 (or
        makes it worse), or after the epoch whose index exceeds ``maxiter``
        (200).

        Args:
            X: sequence of sparse sample matrices (the intercept is added here).
            Y: sequence of targets aligned with ``X`` -- presumably one scalar
                per sample; TODO confirm.

        Returns:
            ``(total_seconds, seconds_per_epoch, epochs_run)``.

        Raises:
            ValueError: if ``X`` is empty, ``X`` and ``Y`` differ in length,
                or the minibatch size ``l3`` is smaller than 1.
        """
        B = self.intercept_sparse(X)
        if not B:
            raise ValueError("fit() needs at least one sample")
        if len(Y) != len(B):
            raise ValueError("X and Y must have the same number of samples")
        a, mom, minib = self.l1, self.l2, int(self.l3)
        if minib < 1:
            # A non-positive step never advances through B.
            raise ValueError("minibatch size (l3) must be >= 1")

        n_features = B[0].shape[1]
        self.W = np.zeros((n_features, 1), dtype=np.float64)
        self.O = np.zeros((n_features, 1), dtype=np.float64)
        use_adagrad = self.adagrad
        epoch, maxiter = 0, 200
        prev_mse = float("inf")
        converged = False
        total_sec = []

        print("alpha=%.3f" % a)
        print("momentum=%.3f" % mom)
        print("minibatch=%d" % minib)
        if use_adagrad:
            gw = 0  # running sums of squared gradients
            go = 0
            fudge = 1e-6  # keeps the AdaGrad denominator away from zero
        else:
            vw = np.zeros_like(self.W)  # momentum velocities
            vo = np.zeros_like(self.O)

        while not converged:
            start = time.time()
            for lo in range(0, len(B), minib):
                batch_x = B[lo:lo + minib]
                batch_y = Y[lo:lo + minib]
                sum_errw = 0.0
                sum_erro = 0.0
                for bi, yi in zip(batch_x, batch_y):
                    sum_errw += self.d_W(bi, yi)
                    sum_erro += self.d_O(bi, yi)
                # Average over the real batch size: the last batch may be
                # smaller than ``minib``.
                size = float(len(batch_x))
                sum_errw = sum_errw / size
                sum_erro = sum_erro / size

                if use_adagrad:
                    gw += pow(sum_errw, 2)
                    go += pow(sum_erro, 2)
                    self.W = self.W - a * (sum_errw / (fudge + np.sqrt(gw)))
                    self.O = self.O - a * (sum_erro / (fudge + np.sqrt(go)))
                else:
                    # Nesterov momentum, look-ahead form:
                    #   v = mu * v_prev - lr * g;  x += -mu * v_prev + (1 + mu) * v
                    vw_prev, vo_prev = vw, vo
                    vw = mom * vw - a * sum_errw
                    vo = mom * vo - a * sum_erro
                    self.W += -mom * vw_prev + (1 + mom) * vw
                    self.O += -mom * vo_prev + (1 + mom) * vo

            pred = self.predict(B, validate=True)
            cur_mse = mean_squared_error(Y, pred)
            elapsed = time.time() - start
            total_sec.append(elapsed)
            print("epoch -> %d / mse: %.6f (%.2f sec)" % (epoch, cur_mse, elapsed))
            if prev_mse - cur_mse < 0.0001 or epoch > maxiter:
                converged = True
            if epoch > 0 and not use_adagrad:
                a = a * 0.998  # learning-rate decay for the momentum path
            prev_mse = cur_mse
            epoch += 1

        tot = sum(total_sec)
        avg = tot / float(len(total_sec))
        print("total time: %.2f" % tot)
        print("time/epoch: %.2f" % avg)
        return tot, avg, epoch

    def intercept_sparse(self, X):
        """Return each sample of ``X`` as CSR with a leading column of ones."""
        return [hstack([np.ones((xi.shape[0], 1)), xi]).tocsr() for xi in X]

    def softmax(self, x):
        """Numerically stable softmax over all entries of ``x``."""
        e_x = np.exp(x - np.max(x))
        return e_x / e_x.sum()

    def mul(self, bi):
        """Pool the rows of ``bi`` with weights ``softmax(bi . O)``."""
        weights = self.softmax(bi.dot(self.O))
        return bi.T.dot(weights)

    def hypo(self, bi):
        """Model output ``0.5 * W.T . mul(bi)`` as a 1x1 array."""
        return 0.5 * self.W.T.dot(self.mul(bi))

    def d_W(self, bi, yi):
        """Gradient of ``(hypo(bi) - yi) ** 2`` with respect to ``W``."""
        df_dw = self.mul(bi)
        # The 2 from the square and the 0.5 in the model cancel, so the
        # gradient is (f - y) * x_out.
        dl_df = (0.5 * self.W.T.dot(df_dw) - yi).view(np.ndarray)[0][0]
        return dl_df * df_dw

    def d_O(self, bi, yi):
        """Gradient of ``(hypo(bi) - yi) ** 2`` with respect to ``O``."""
        pi = self.softmax(bi.dot(self.O))
        x_out = bi.T.dot(pi)
        dl_df = (0.5 * self.W.T.dot(x_out) - yi).view(np.ndarray)[0][0]
        # Upstream gradient w.r.t. pi (the 2 and the 0.5 cancel again).
        upstream = dl_df * bi.dot(self.W)
        # Softmax Jacobian J = diag(pi) - pi pi^T is symmetric; apply it to
        # ``upstream`` without building the n_i x n_i matrix.
        weighted = pi * upstream
        jv = weighted - pi * weighted.sum()
        return bi.T.dot(jv)

    def L(self, bi, yi, W, O):
        """Squared error on one sample evaluated at the given ``(W, O)``.

        Pure: the model's own ``W``/``O`` are not modified, so gradient
        checks leave the trained parameters intact.
        """
        x_out = bi.T.dot(self.softmax(bi.dot(O)))
        return pow(0.5 * W.T.dot(x_out) - yi, 2)

    def grad_check_w(self, bi, yi):
        """Print a central-difference check of ``d_W`` for one sample.

        "[+]" marks coordinates whose relative error is below 1e-4, "[-]"
        the others; coordinates with a zero numerical gradient are skipped.
        """
        epsilon = 0.00001
        analytic = self.d_W(bi, yi)
        for i in range(len(self.W)):
            cur_p = self.W.copy()
            cur_p[i] = cur_p[i] + epsilon
            cur_m = self.W.copy()
            cur_m[i] = cur_m[i] - epsilon
            diff = self.L(bi, yi, cur_p, self.O) - self.L(bi, yi, cur_m, self.O)
            actual = (diff / (2.0 * epsilon)).item()
            approx = analytic[i].item()
            if actual != 0:
                rel_err = abs(approx - actual) / max(abs(approx), abs(actual))
                if rel_err < 0.0001:
                    print("[+] approx: %.8f / actual: %.8f " % (approx, actual))
                else:
                    print("[-] approx: %.8f / actual: %.8f " % (approx, actual))

    def grad_check_o(self, bi, yi):
        """Print a central-difference check of ``d_O`` for one sample.

        "[+]" marks coordinates whose absolute error is below 1e-4, "[-]"
        the others; coordinates where both gradients are zero are skipped.
        """
        epsilon = 0.00001
        analytic = self.d_O(bi, yi)
        print("")
        for i in range(len(self.O)):
            cur_p = self.O.copy()
            cur_p[i] = cur_p[i] + epsilon
            cur_m = self.O.copy()
            cur_m[i] = cur_m[i] - epsilon
            diff = self.L(bi, yi, self.W, cur_p) - self.L(bi, yi, self.W, cur_m)
            actual = (diff / (2.0 * epsilon)).item()
            approx = analytic[i].item()
            if actual != 0 or approx != 0:
                if abs(approx - actual) < 0.0001:
                    print("[+] -> (%d) approx: %.10f / actual: %.10f " % (i, approx, actual))
                else:
                    print("[-] -> (%d) approx: %.10f / actual: %.10f " % (i, approx, actual))