//------------------------------------------------------------------------------
// Tasting families of features for image classification.
// 
// Copyright (c) 2011 Idiap Research Institute, http://www.idiap.ch/
// Written by Charles Dubout <charles.dubout@idiap.ch>
// 
// This file is part of Tasting.
// 
// Tasting is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License version 2 as
// published by the Free Software Foundation.
// 
// Tasting is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// 
// You should have received a copy of the GNU General Public License
// along with Tasting. If not, see <http://www.gnu.org/licenses/>.
//------------------------------------------------------------------------------

#include "TastingBoostMH.h"
#include "Utils.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <limits>

using namespace ML;
using namespace std;

// Builds an untrained classifier.
//
// nbRounds   - number of boosting rounds performed by train()
// nbFeatures - number of features sampled per round
// nbInitial  - number of features "tasted" (sampled up-front) per heuristic
// single     - selects the single-heuristic sampling strategy in train()
// testSet    - optional; when non-null, train() reports the error on it after
//              every round
TastingBoostMH::TastingBoostMH(unsigned int nbRounds,
							   unsigned int nbFeatures,
							   unsigned int nbInitial,
							   bool single,
							   InputSet* testSet)
: nbRounds_(nbRounds),
  nbFeatures_(nbFeatures),
  nbInitial_(nbInitial),
  single_(single),
  testSet_(testSet) {
	// All three counts have to be non-zero; train() relies on it (for
	// instance it iterates down from nbInitial_ - 1)
	assert(nbRounds > 0);
	assert(nbFeatures > 0);
	assert(nbInitial > 0);
}

// Polymorphic copy: returns a newly allocated copy of this classifier
// (presumably owned and deleted by the caller).
Classifier* TastingBoostMH::clone() const {
	TastingBoostMH* const copy = new TastingBoostMH(*this);
	return copy;
}

// Trains the classifier with nbRounds_ rounds of AdaBoost.MH over decision
// stumps. Features are grouped by heuristic, and every heuristic is "tasted"
// once up-front by sampling nbInitial_ of its features. At every round, the
// best edges of those tasted features under the current weights are used to
// decide which heuristic the round's nbFeatures_ candidate features are drawn
// from:
//  - single_:  the heuristic with the largest expected maximum edge over
//              nbFeatures_ draws (with replacement) is selected, nbFeatures_
//              of its features are drawn and searched for the best stump;
//  - otherwise: nbFeatures_ features are drawn one at a time, each from the
//              heuristic maximizing the expected best edge after that draw.
// Progress (and the test error, if a test set was given) is printed to cout.
void TastingBoostMH::train(InputSet& inputSet) {
	// Clear the previous classifier
	stumps_.clear();
	alphas_.clear();

	// Get the number of samples, features, labels, and heuristics
	const unsigned int nbSamples = inputSet.nbSamples();
	const unsigned int nbFeatures = inputSet.nbFeatures();
	const unsigned int nbLabels = inputSet.nbLabels();
	const unsigned int nbHeuristics = inputSet.nbHeuristics();

	// Without samples, labels, or heuristics there is nothing to learn, and
	// the code below would index empty vectors (&v[0]), divide by nbSamples,
	// or draw features from a non-existent heuristic. Leave the classifier
	// empty instead.
	if(!nbSamples || !nbLabels || !nbHeuristics)
		return;

	// Get the labels associated to every sample
	const unsigned int* labels = inputSet.labels();

	// Set the distribution of weights uniformly (weights[s][l] is the weight
	// of sample s for label l) and reset the accumulated scores
	// (hypotheses[l][s] is the score of label l for sample s)
	vector<vector<double> > weights(nbSamples);
	vector<vector<double> > hypotheses(nbLabels);
	vector<vector<double> > testHypotheses(nbLabels);

	for(unsigned int s = 0; s < nbSamples; ++s)
		weights[s].resize(nbLabels, 1.0 / (double(nbSamples) * nbLabels));

	for(unsigned int l = 0; l < nbLabels; ++l) {
		hypotheses[l].resize(nbSamples, 0.0);

		if(testSet_)
			testHypotheses[l].resize(testSet_->nbSamples(), 0.0);
	}

	// Separate the features by heuristic
	vector<vector<unsigned int> > heuristics(nbHeuristics);

	for(unsigned int f = 0; f < nbFeatures; ++f)
		heuristics[inputSet.heuristic(f)].push_back(f);

	// Taste every heuristic: sample nbInitial_ of its features at random (or
	// take all of them if it has no more than that), cache their values in
	// matrices[h], and for every tasted feature f keep in orderings[h][f] the
	// sample indices sorted by increasing feature value
	vector<vector<scalar_t> > matrices(nbHeuristics);
	vector<vector<vector<unsigned int> > > orderings(nbHeuristics);

	for(unsigned int h = 0; h < nbHeuristics; ++h) {
		// This should be taken care of by Mash
		assert(!heuristics[h].empty());

		if(heuristics[h].size() <= nbInitial_) {
			inputSet.pushFeatures(heuristics[h]);
		}
		else {
			// Partial Fisher-Yates shuffle: the first nbInitial_ entries end
			// up being a uniform sample without replacement (this replaces
			// std::random_shuffle, which was removed in C++17)
			vector<unsigned int> indices(heuristics[h]);
			const unsigned int size = indices.size();

			for(unsigned int i = 0; i < nbInitial_; ++i) {
				const unsigned int j = i + rand() % (size - i);
				swap(indices[i], indices[j]);
			}

			indices.resize(nbInitial_);
			inputSet.pushFeatures(indices);
		}

		orderings[h].resize(inputSet.nbFeatures());

		for(unsigned int f = 0; f < orderings[h].size(); ++f) {
			orderings[h][f].resize(nbSamples);

			for(unsigned int s = 0; s < nbSamples; ++s)
				orderings[h][f][s] = s;

			Utils::sort(&orderings[h][f][0], inputSet.samples(f), nbSamples);
		}

		inputSet.swapSamples(matrices[h]);
		inputSet.popFeatures();
	}

	// Look-up table used in single_ mode to compute the expected maximum edge
	// of K = nbFeatures_ features drawn with replacement among n tasted ones:
	// with the tasted edges sorted in increasing order, the k-th (0-based) is
	// the maximum with probability ((k + 1)^K - k^K) / n^K
	vector<double> lookup;

	if(single_) {
		lookup.resize(nbInitial_);

		for(unsigned int i = 0; i < nbInitial_; ++i)
			lookup[i] = pow(i + 1.0, double(nbFeatures_));

		for(unsigned int i = nbInitial_ - 1; i > 0; --i)
			lookup[i] -= lookup[i - 1];
	}

	// log10 of the product over the rounds of sqrt(1 - edge^2)
	double logLoss = 0.0;

	// Fixed-point output with 4 decimals for all the progress messages
	cout.precision(4);
	cout.setf(ios::fixed, ios::floatfield);

	// Do nbRounds rounds of boosting
	for(unsigned int r = 0; r < nbRounds_; ++r) {
		// Compute the edges of the constant (split-less) stump for every label
		vector<double> edges(nbLabels, 0.0);

		for(unsigned int s = 0; s < nbSamples; ++s)
			for(unsigned int l = 0; l < nbLabels; ++l)
				edges[l] += (labels[s] == l) ? weights[s][l] : -weights[s][l];

		// Compute the best edge of every tasted feature under the current
		// weights (edgesDistr[h][f] for the f-th tasted feature of heuristic h)
		vector<vector<double> > edgesDistr(nbHeuristics);

		for(unsigned int h = 0; h < nbHeuristics; ++h) {
			// Push as many (placeholder) features as were tasted and swap
			// the cached values of the tasted features in
			vector<unsigned int> indices(matrices[h].size() / nbSamples);
			inputSet.pushFeatures(indices);
			inputSet.swapSamples(matrices[h], false);

			Stump dummy;
			train(inputSet, orderings[h], weights, edges, dummy, edgesDistr[h]);

			inputSet.swapSamples(matrices[h], false);
			inputSet.popFeatures();
		}

		Stump stump;
		double edge = 0.0;
		unsigned int heuristic = 0;
		const scalar_t* samples = 0;

		if(single_) {
			double maximum = 0.0;

			cout << "Expectations:";

			// Select the heuristic with the largest expected maximum edge
			for(unsigned int h = 0; h < nbHeuristics; ++h) {
				sort(edgesDistr[h].begin(), edgesDistr[h].end());

				double expectation = 0.0;

				for(unsigned int k = 0; k < edgesDistr[h].size(); ++k)
					expectation += lookup[k] * edgesDistr[h][k];

				expectation /= pow(double(edgesDistr[h].size()),
								   double(nbFeatures_));

				cout << setw(7) << expectation;

				if(expectation > maximum) {
					maximum = expectation;
					heuristic = h;
				}
			}

			cout << '.' << endl;

			// Sample nbFeatures_ features of the selected heuristic (with
			// replacement)
			vector<unsigned int> indices(nbFeatures_);

			for(unsigned int f = 0; f < nbFeatures_; ++f)
				indices[f] = heuristics[heuristic][rand() % heuristics[heuristic].size()];

			sort(indices.begin(), indices.end());

			// Push those features on the input set
			inputSet.pushFeatures(indices);

			vector<vector<unsigned int> > indexes(nbFeatures_);

			for(unsigned int f = 0; f < nbFeatures_; ++f) {
				indexes[f].resize(nbSamples);

				for(unsigned int s = 0; s < nbSamples; ++s)
					indexes[f][s] = s;

				Utils::sort(&indexes[f][0], inputSet.samples(f), nbSamples);
			}

			vector<double> dummies;
			edge = train(inputSet, indexes, weights, edges, stump, dummies);

			// The stump's feature is relative to the pushed features: keep
			// its values for the weight update, then translate it to a
			// global feature index
			samples = inputSet.samples(stump.feature_);
			stump.feature_ = indices[stump.feature_];
		}
		else {
			bool changed = true;
			vector<double> expecteds(nbHeuristics);
			unsigned int argmax = 0;

			for(unsigned int k = 0; k < nbFeatures_; ++k) {
				double max = 0.0;

				cout << "K: " << setw(2) << k << ", expected errors:";

				// Expected best edge after one more draw from heuristic h,
				// estimated from its tasted edges; it only depends on the
				// best edge so far, so it is only recomputed when that changed
				if(changed) {
					for(unsigned int h = 0; h < nbHeuristics; ++h) {
						double expected = 0.0;

						for(unsigned int f = 0; f < edgesDistr[h].size(); ++f)
							expected += std::max(edge, edgesDistr[h][f]);

						expected /= edgesDistr[h].size();

						expecteds[h] = expected;

						cout << setw(7) << expected;

						if(expected > max) {
							max = expected;
							argmax = h;
						}
					}
				}
				else {
					for(unsigned int h = 0; h < nbHeuristics; ++h)
						cout << setw(7) << expecteds[h];
				}

				// Sample one feature from argmax
				unsigned int f = heuristics[argmax][rand() %
													heuristics[argmax].size()];

				// Push that feature on the input set
				vector<unsigned int> index(1, f);
				inputSet.pushFeatures(index);

				vector<vector<unsigned int> > indexes(1);
				indexes[0].resize(nbSamples);

				for(unsigned int s = 0; s < nbSamples; ++s)
					indexes[0][s] = s;

				Utils::sort(&indexes[0][0], inputSet.samples(), nbSamples);

				Stump s;
				vector<double> dummies;
				double e = train(inputSet, indexes, weights, edges, s, dummies);

				cout << ", argmax: " << setw(2) << argmax << ", feature: " << setw(6) << f
					 << ", edge: " << setw(6) << e << '.' << endl;

				// Always keep the first candidate so that stump is a trained
				// stump (with its signs set) even if no candidate reaches a
				// strictly positive edge
				if(k == 0 || e > edge) {
					edge = e;
					heuristic = argmax;
					stump = s;
					stump.feature_ = f;
					changed = true;
				}
				else {
					changed = false;
				}

				inputSet.popFeatures();
			}

			// Push the best feature on the input set
			vector<unsigned int> index(1, stump.feature_);
			inputSet.pushFeatures(index);
			samples = inputSet.samples();
		}

		// A weak learner with an edge of 1 (or slightly above, through
		// rounding) would get an infinite weight and turn every sample weight
		// into NaN; keep the edge strictly below 1
		const double maxEdge = 1.0 - numeric_limits<double>::epsilon();

		if(edge > maxEdge)
			edge = maxEdge;

		// Update the loss
		logLoss += 0.5 * log10(1.0 - edge * edge);

		// Compute the weight to give to the weak learner and update the weights
		// of the samples
		double expAlpha = sqrt((1.0 + edge) / (1.0 - edge));
		double invExpAlpha = 1.0 / expAlpha;
		double alpha = log(expAlpha);
		double norm = 0.0;
		unsigned int nbErrors = 0;

		for(unsigned int s = 0; s < nbSamples; ++s) {
			// The label with the maximum hypothesis (initialized so that it is
			// never read uninitialized, e.g. if every hypothesis is NaN)
			double max = -numeric_limits<double>::infinity();
			unsigned int label = 0;

			bool phi = samples[s] >= stump.split_;

			for(unsigned int l = 0; l < nbLabels; ++l) {
				bool sign = phi ^ stump.signs_[l];

				// Correctly classified (sample, label) pairs are down-weighted,
				// misclassified ones up-weighted
				weights[s][l] *= (sign ^ (l == labels[s])) ? invExpAlpha :
															 expAlpha;

				norm += weights[s][l];

				hypotheses[l][s] += sign ? -alpha : alpha;

				if(hypotheses[l][s] > max) {
					max = hypotheses[l][s];
					label = l;
				}
			}

			if(label != labels[s])
				++nbErrors;
		}

		cout << "[TastingBoostMH::train] round: " << setw(4) << r
			 << ", log10(loss): " << setw(7) << logLoss
			 << ", edge: " << setw(6) << edge
			 << ", heuristic: " << setw(2) << heuristic
			 << ", feature: " << setw(6) << stump.feature_
			 << ", training error: " << setw(6) << (float(nbErrors) / nbSamples);

		if(testSet_) {
			unsigned int nbTestSamples = testSet_->nbSamples();
			vector<unsigned int> index(1, stump.feature_);
			testSet_->pushFeatures(index);
			const scalar_t* testSamples = testSet_->samples(0);
			unsigned int nbTestErrors = 0;

			for(unsigned int s = 0; s < nbTestSamples; ++s) {
				// The label with the maximum hypothesis
				double max = -numeric_limits<double>::infinity();
				unsigned int label = 0;

				bool phi = testSamples[s] >= stump.split_;

				for(unsigned int l = 0; l < nbLabels; ++l) {
					bool sign = phi ^ stump.signs_[l];

					testHypotheses[l][s] += sign ? -alpha : alpha;

					if(testHypotheses[l][s] > max) {
						max = testHypotheses[l][s];
						label = l;
					}
				}

				if(label != testSet_->label(s))
					++nbTestErrors;
			}

			testSet_->popFeatures();

			cout << ", test error: " << setw(6) << (float(nbTestErrors) / nbTestSamples);
		}

		cout << '.' << endl;

		// Normalize the weights of the samples so that they sum to one again
		// (plain loop instead of std::bind2nd, which was removed in C++17)
		for(unsigned int s = 0; s < nbSamples; ++s)
			for(unsigned int l = 0; l < nbLabels; ++l)
				weights[s][l] /= norm;

		// Add the weak learner and its weight
		stumps_.push_back(stump);
		alphas_.push_back(alpha);

		// Pop the pushed feature(s)
		inputSet.popFeatures();
	}
}

// Writes into distr[0 .. nbLabels) the score of every label for the given
// sample: the sum over all weak learners of +alpha or -alpha, depending on
// which side of the stump's split the sample's feature value falls and on
// the stump's sign for that label. Labels beyond a stump's signs_ receive
// no contribution from it.
void TastingBoostMH::distribution(InputSet& inputSet,
								 unsigned int sample,
								 scalar_t* distr) const {
	const unsigned int nbLabels = inputSet.nbLabels();

	// Start from an all-zero score for every label
	for(unsigned int l = 0; l < nbLabels; ++l)
		distr[l] = 0;

	for(unsigned int w = 0; w < stumps_.size(); ++w) {
		const Stump& stump = stumps_[w];
		assert(stump.signs_.size() <= nbLabels);

		// Expose only the feature this weak learner was trained on
		vector<unsigned int> featureIndex(1, stump.feature_);
		inputSet.pushFeatures(featureIndex);

		// Which side of the split the sample's single feature falls on
		const scalar_t value = *inputSet.features(sample);
		const bool above = value >= stump.split_;

		for(unsigned int l = 0; l < stump.signs_.size(); ++l) {
			if(above != stump.signs_[l])
				distr[l] -= alphas_[w];
			else
				distr[l] += alphas_[w];
		}

		// Restore the input set
		inputSet.popFeatures();
	}
}

// Appends to features the feature index used by every weak learner, in the
// order the weak learners were added (duplicates are kept).
void TastingBoostMH::report(vector<unsigned int>& features) const {
	const unsigned int nbWeakLearners = stumps_.size();

	for(unsigned int i = 0; i != nbWeakLearners; ++i)
		features.push_back(stumps_[i].feature_);
}

// Exhaustively searches the features currently pushed on inputSet for the
// decision stump with the largest edge, defined as the sum over labels of
// the absolute value of each label's weighted edge. Returns that edge.
//
// indices    - indices[f] holds the sample indices sorted by increasing value
//              of the f-th pushed feature
// weights    - weights[s][l] is the weight of sample s for label l
// totalEdges - totalEdges[l] is the edge of the constant (split-less) stump
//              for label l
// stump      - receives the best stump; its feature_ is relative to the
//              pushed features (in [0, nbFeatures)). If no split beats the
//              constant stump, it stays the constant stump on feature 0.
// edges      - receives, for every pushed feature, the best edge reachable
//              with that feature alone (at least the constant stump's edge)
double TastingBoostMH::train(InputSet& inputSet,
							const vector<vector<unsigned int> >& indices,
							const vector<vector<double> >& weights,
							const vector<double> totalEdges,
							Stump& stump,
							vector<double>& edges) {
	// Get the number of samples, features, and labels
	const unsigned int nbSamples = inputSet.nbSamples();
	const unsigned int nbFeatures = inputSet.nbFeatures();
	const unsigned int nbLabels = inputSet.nbLabels();

	// Make sure that the stump has the correct number of signs
	stump.signs_.resize(nbLabels);

	// Make sure that the edge vector has the correct size
	edges.resize(nbFeatures);

	// One ordering per pushed feature, one weight row per sample, and one
	// total edge per label
	assert(indices.size() >= nbFeatures);
	assert(weights.size() == nbSamples);
	assert(totalEdges.size() == nbLabels);

	// Get the samples' features and labels
	const scalar_t* samples = inputSet.samples();
	const unsigned int* labels = inputSet.labels();

	// Start from the constant stump (every sample on the same side of the
	// split): the sign of each label simply follows its total edge
	double sumEdges = 0.0;
	stump.feature_ = 0;

	for(unsigned int l = 0; l < nbLabels; ++l) {
		stump.signs_[l] = totalEdges[l] >= 0.0;
		sumEdges += fabs(totalEdges[l]);
	}

	stump.split_ = -numeric_limits<scalar_t>::max();

	// Every feature can reach at least the edge of the constant stump
	fill(edges.begin(), edges.end(), sumEdges);

	// The right edges are simply totalEdges - leftEdges
	vector<double> leftEdges(nbLabels, 0.0);

	for(unsigned int f = 0; f < nbFeatures; ++f) {
		// Zero out the left edges vector
		fill(leftEdges.begin(), leftEdges.end(), 0.0);

		// Try to split in between every pair of consecutive samples. The
		// condition is written s + 1 < nbSamples (not s < nbSamples - 1) so
		// that it cannot wrap around when nbSamples is 0.
		for(unsigned int s = 0; s + 1 < nbSamples; ++s) {
			unsigned int index = indices[f][s];
			unsigned int nextIndex = indices[f][s + 1];
			scalar_t feature = samples[index];
			scalar_t nextFeature = samples[nextIndex];
			unsigned int label = labels[index];

			// Include the current sample in the left edge; the edge of the
			// split for label l is |right - left| = |total - 2 * left|
			double sum = 0.0;

			for(unsigned int l = 0; l < nbLabels; ++l) {
				leftEdges[l] += (label == l) ? weights[index][l] :
											  -weights[index][l];

				sum += fabs(totalEdges[l] - 2.0 * leftEdges[l]);
			}

			// Only place a split in between two distinct feature values
			if(feature < nextFeature) {
				if(sum > sumEdges) {
					stump.feature_ = f;

					for(unsigned int l = 0; l < nbLabels; ++l)
						stump.signs_[l] = totalEdges[l] >= 2.0 * leftEdges[l];

					stump.split_ = (feature + nextFeature) / 2;

					sumEdges = sum;
				}

				if(sum > edges[f])
					edges[f] = sum;
			}
		}

		// Move on to the values of the next pushed feature
		samples += nbSamples;
	}

	return sumEdges;
}
