//------------------------------------------------------------------------------
// Tasting families of features for image classification.
// 
// Copyright (c) 2011 Idiap Research Institute, http://www.idiap.ch/
// Written by Charles Dubout <charles.dubout@idiap.ch>
// 
// This file is part of Tasting.
// 
// Tasting is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License version 2 as
// published by the Free Software Foundation.
// 
// Tasting is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// 
// You should have received a copy of the GNU General Public License
// along with Tasting. If not, see <http://www.gnu.org/licenses/>.
//------------------------------------------------------------------------------

#ifndef ML_CLASSIFIERS_UCBBOOSTMH_H
#define ML_CLASSIFIERS_UCBBOOSTMH_H

#include <vector>

#include "Classifier.h"

namespace ML {
	//--------------------------------------------------------------------------
	/// @brief	AdaBoost.MH.UCB with stumps
	///
	/// Boosted classifier whose weak learners are decision stumps (one
	/// feature, one split point, one sign per label). At every boosting round
	/// only nbFeatures features are sampled. The learned model is the list of
	/// stumps together with their weights.
	///
	/// NOTE(review): the "UCB" part presumably selects the features to sample
	/// with an upper-confidence-bound bandit policy whose rewards are scaled
	/// by `scale` — confirm against the implementation.
	//--------------------------------------------------------------------------
	class UCBBoostMH : public Classifier {
		//_____ Construction / Destruction and Copy _______
	public:
		//----------------------------------------------------------------------
		/// @brief Constructor
		///
		/// Declared explicit because every parameter has a default value,
		/// which would otherwise make this an implicit conversion from an
		/// integer to a classifier.
		///
		/// @param	nbRounds	The number of desired boosting rounds (can be
		///						less in case the error of the weak learner is
		///						zero, default 100)
		/// @param	nbFeatures	The number of features to sample at every round
		///						(default 10)
		/// @param	scale		Scale of the rewards (default 1.0)
		/// @param	testSet		Optional input set (default 0, i.e. none).
		///						This class declares no destructor, so it never
		///						deletes the pointer: the caller keeps ownership.
		///						NOTE(review): presumably used to monitor the
		///						test error during training, and must then stay
		///						alive while train() runs — confirm.
		//----------------------------------------------------------------------
		explicit UCBBoostMH(unsigned int nbRounds = 100,
							unsigned int nbFeatures = 10,
							double scale = 1.0,
							InputSet* testSet = 0);

		//_____ Methods inherited from Classifier _______
	public:
		//----------------------------------------------------------------------
		/// @brief	Clone method
		///
		/// @return	A deep copy of the classifier
		//----------------------------------------------------------------------
		virtual Classifier* clone() const;

		//----------------------------------------------------------------------
		/// @brief	Trains the classifier
		///
		/// @param	inputSet	The input set over which the classifier should
		///						be trained
		//----------------------------------------------------------------------
		virtual void train(InputSet& inputSet);

		//----------------------------------------------------------------------
		/// @brief	Predicts the class memberships for a given sample
		///
		/// @param	inputSet	The input set containing the sample to classify
		/// @param	sample 		The index of the sample to classify
		/// @retval	distr		The predicted class memberships
		//----------------------------------------------------------------------
		virtual void distribution(InputSet& inputSet,
								  unsigned int sample,
								  scalar_t* distr) const;

		//----------------------------------------------------------------------
		/// @brief	Reports the features used by the classifier
		///
		/// @retval	features	The vector of features to populate
		//----------------------------------------------------------------------
		virtual void report(std::vector<unsigned int>& features) const;

		//_____ Custom types _______
	private:
		//----------------------------------------------------------------------
		/// @brief	Simpler stump than a C45Tree
		///
		/// NOTE(review): the members have no default initializers (no
		/// constructor is declared); all three must be set before use.
		//----------------------------------------------------------------------
		struct Stump {
			unsigned int feature_;		///< The feature used for classification
			std::vector<bool> signs_;	///< The signs for every label
			scalar_t split_;			///< The split point along the feature
		};

		//_____ Internal methods _______
	private:
		//----------------------------------------------------------------------
		/// @brief	Trains a stump given an ordered input set
		///
		/// Static overload of train(): it works only on its arguments and
		/// does not touch any classifier state.
		///
		/// @param	inputSet	The input set to train on
		/// @param	indices		Ordered sample indices. NOTE(review):
		///						presumably one sorted index vector per sampled
		///						feature — confirm.
		/// @param	weights		Boosting weights. NOTE(review): presumably
		///						indexed by sample then label — confirm.
		/// @param	edges		NOTE(review): presumably the per-label edges
		///						of the weighted data — confirm. Taken by value,
		///						so the vector is copied on every call; prefer
		///						`const std::vector<double>&`, changed together
		///						with the definition in the implementation file.
		/// @retval	stump		The trained stump
		/// @return	NOTE(review): presumably the edge (quality) of the trained
		///			stump — confirm.
		//----------------------------------------------------------------------
		static double train(InputSet& inputSet,
							const std::vector<std::vector<unsigned int> >& indices,
							const std::vector<std::vector<double> >& weights,
							const std::vector<double> edges,
							Stump& stump);

		//_____ Attributes _______
	private:
		unsigned int nbRounds_;			///< The number of rounds
		unsigned int nbFeatures_;		///< Features per round
		double scale_;					///< Scale of the rewards
		InputSet* testSet_;				///< Optional test set given to the
										///< constructor (never deleted by
										///< this class)

		std::vector<Stump> stumps_;		///< The weak learners
		std::vector<scalar_t> alphas_;	///< The weights of the learners
	};
} // namespace ML

#endif // ML_CLASSIFIERS_UCBBOOSTMH_H
