//------------------------------------------------------------------------------
// Tasting families of features for image classification.
// 
// Copyright (c) 2011 Idiap Research Institute, http://www.idiap.ch/
// Written by Charles Dubout <charles.dubout@idiap.ch>
// 
// This file is part of Tasting.
// 
// Tasting is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License version 2 as
// published by the Free Software Foundation.
// 
// Tasting is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// 
// You should have received a copy of the GNU General Public License
// along with Tasting. If not, see <http://www.gnu.org/licenses/>.
//------------------------------------------------------------------------------

#include "MLInputSet/FileInputSet.h"
#include "MLClassifiers/LazyBoostMH.h"
#include "MLClassifiers/UCBBoostMH.h"
#include "MLClassifiers/Exp3pBoostMH.h"
#include "MLClassifiers/EGreedyBoostMH.h"
#include "MLClassifiers/TastingBoostMH.h"

#include <algorithm>
#include <cerrno>
#include <climits>
#include <cstdlib>
#include <ctime>
#include <iostream>

using namespace ML;
using namespace std;

// Parses 'text' as a base-10 integer in [0, UINT_MAX] and stores it in 'value'.
// Returns false (leaving 'value' untouched) if the text is empty, has trailing junk,
// is negative or does not fit. Plain atoi() would map garbage to 0 and, once stored in
// an unsigned int, turn a negative number into a huge count.
static bool parseUnsigned(const char* text, unsigned int& value) {
	errno = 0;
	char* end = 0;
	const long parsed = strtol(text, &end, 10);

	if(end == text || *end != '\0' || errno == ERANGE || parsed < 0 ||
	   static_cast<unsigned long>(parsed) > UINT_MAX)
		return false;

	value = static_cast<unsigned int>(parsed);
	return true;
}

// Parses 'text' as a floating point number and stores it in 'value'.
// Returns false (leaving 'value' untouched) if the text is empty, has trailing junk,
// is out of range or is NaN (NaN would defeat the "negative means default" test below).
static bool parseFloat(const char* text, float& value) {
	errno = 0;
	char* end = 0;
	const double parsed = strtod(text, &end);

	if(end == text || *end != '\0' || errno == ERANGE || parsed != parsed)
		return false;

	value = static_cast<float>(parsed);
	return true;
}

// Trains a boosting classifier on the training set while evaluating on the test set.
//
// Command line:
//   argv[1] train_t.raw        training data file
//   argv[2] train_labels.txt   training labels file
//   argv[3] test_t.raw         test data file
//   argv[4] test_labels.txt    test labels file
//   argv[5] heuristics.txt     heuristics file (passed to both input sets)
//   argv[6] boosting type, selected by its first character:
//           N = LazyBoostMH(..., true, ...) after zeroing the training set heuristics
//           L = LazyBoostMH(..., true, ...)     l = LazyBoostMH(..., false, ...)
//           U = UCBBoostMH                      X = Exp3pBoostMH
//           P = EGreedyBoostMH
//           T = TastingBoostMH(..., true, ...)  t = TastingBoostMH(..., false, ...)
//   argv[7] optional T (default 100), passed as the first classifier argument
//   argv[8] optional Q (default 10), passed as the second classifier argument
//   argv[9] optional reward scale; a negative value (the default) selects 1 for U/X/P
//           and 100 for T/t
//
// Returns -1 on a usage error or an invalid argument.
int main(int argc, char* const argv[]) {
	// Check for correct usage of the command line
	if(argc < 7) {
		cerr << "Usage: " << argv[0] << " train_t.raw train_labels.txt test_t.raw test_labels.txt heuristics.txt"
				" N/l/L/U/X/P/t/T [T] [Q] [R/scale of the rewards]" << endl;
		return -1;
	}

	int seed = time(0);
	cout << "Seed: " << seed << '.' << endl;
	srand(seed);
	srand48(seed);

	// Parse and validate the optional numeric arguments before loading the data files,
	// so that a typo is reported immediately instead of silently becoming 0 or wrapping
	// around to a huge unsigned value.
	unsigned int T = 100;
	unsigned int Q = 10;
	float Rscale = -1.0f;

	if(argc >= 8 && !parseUnsigned(argv[7], T)) {
		cerr << "Invalid number of rounds T: " << argv[7] << endl;
		return -1;
	}

	if(argc >= 9 && !parseUnsigned(argv[8], Q)) {
		cerr << "Invalid value of Q: " << argv[8] << endl;
		return -1;
	}

	if(argc >= 10 && !parseFloat(argv[9], Rscale)) {
		cerr << "Invalid scale of the rewards: " << argv[9] << endl;
		return -1;
	}

	// Create the input sets
	FileInputSet trainSet(argv[1], argv[2], argv[5]);
	FileInputSet testSet(argv[3], argv[4], argv[5]);

	Classifier* classifier = 0;

	if(argv[6][0] == 'N') {
		// Replace the training heuristics with all zeros (one entry per feature)
		vector<unsigned int> data(trainSet.nbFeatures(), 0);
		trainSet.swapHeuristics(data);
		classifier = new LazyBoostMH(T, Q, true, &testSet);
	}
	else if(argv[6][0] == 'L') {
		classifier = new LazyBoostMH(T, Q, true, &testSet);
	}
	else if(argv[6][0] == 'l') {
		classifier = new LazyBoostMH(T, Q, false, &testSet);
	}
	else if(argv[6][0] == 'U') {
		classifier = new UCBBoostMH(T, Q, (Rscale < 0) ? 1 : Rscale, &testSet);
	}
	else if(argv[6][0] == 'X') {
		// NOTE(review): 0.3 and 0.15 are hard-coded Exp3.P parameters; their meaning is
		// defined by Exp3pBoostMH's constructor — confirm there before changing them.
		classifier = new Exp3pBoostMH(T, Q, (Rscale < 0) ? 1 : Rscale, 0.3, 0.15, &testSet);
	}
	else if(argv[6][0] == 'P') {
		classifier = new EGreedyBoostMH(T, Q, (Rscale < 0) ? 1 : Rscale, &testSet);
	}
	else if(argv[6][0] == 'T') {
		classifier = new TastingBoostMH(T, Q, (Rscale < 0) ? 100 : Rscale, true, &testSet);
	}
	else if(argv[6][0] == 't') {
		classifier = new TastingBoostMH(T, Q, (Rscale < 0) ? 100 : Rscale, false, &testSet);
	}
	else {
		cerr << "Invalid Boosting type." << endl;
		return -1;
	}

	classifier->train(trainSet);

	delete classifier;
}
