// Copyright (c) 2007 David Grangier
// Copyright (c) 2007 Samy Bengio
// 
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without 
// modification, are permitted provided that the following conditions are 
// met: Redistributions of source code must retain the above copyright 
// notice, this list of conditions and the following disclaimer.
// Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the 
// documentation and/or other materials provided with the distribution.
// The name of the author may not be used to endorse or promote products
// derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, 
// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 
// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 
// THE POSSIBILITY OF SUCH DAMAGE.


/*
This program trains a PAMIR model.
*/
#include "PAMachineImageIR.h" // machine
#include "PATrainer.h"        // trainer
#include "ImgLoss.h"          // loss
#include "ImgConstraintPolicies.h"
#include "PASparseSupportSetLK.h" // suport sets
#include "PAIndexSupportSet.h"
#include "PAKernelSparseLinear.h" // kernels
#include "PACachedKernel.h"

#include "PASparseDataset.h"  // datasets
#include "PAIndexDataset.h"
#include "MatDataSet.h"
#include "ImgConstraintDataset.h"
#include "SparseMatrixROM.h"

#include "ImgRtrvMeasurerFast.h" // measurers
#include "PASupportVectorSetSizeMeasurer.h"

#include "DiskXFile.h" // misc.
#include "CmdLine.h"
#include "Random.h"     

using namespace Torch;

int main(int argc, char **argv)
{
  // ======= Constants =======
  int train_index = 0;
  int valid_index = 1;

	// ======= Variables =======

	// training data
	char *train_image_f;
	char *train_kernel_mat_f;
	char *train_relevance_f;
	char *train_query_f;

	// criterion
	bool bep;
  char *train_caption_f;

  // hyperparameters
  int n_iter;
  real C;
	real min_epsilon;

	// valid data
  char *valid_image_f;
	char *valid_kernel_mat_f;
  char *valid_relevance_f;
	char *valid_query_f;
	char *measure_file;
	char *sv_set_size_f;
	int measure_freq;

	// save the model
	char *model_file;
	int save_freq;

	// allocator
  Allocator *allocator = new Allocator;

	// random generator seed
	int seed;

	// === Read command line ====
	
  CmdLine cmd;

	// training data
	cmd.addText("\n Query Data");
	cmd.addSCmdArg("train_queries", &train_query_f, "training queries");
	cmd.addSCmdArg("relevance", &train_relevance_f, "training relevance");

	// hyperparameters
	cmd.addText("\n Hyperparameters");
	cmd.addICmdArg("n_iter", &n_iter, "number of training iterations");
	cmd.addRCmdArg("C", &C, "aggressiveness");

	// training data
	cmd.addText("\n Image Data");	
	cmd.addSCmdOption("-timg", &train_image_f, "", "training images (for linear kernel)");
	cmd.addSCmdOption("-tkernel", &train_kernel_mat_f, "", "train-train kernel values (for other kernels)");

  // option for Text Epsilon training
  cmd.addText("\n Criterion");
	cmd.addBCmdOption("-bep", &bep, false, "BEP loss");
  cmd.addSCmdOption("-cap", &train_caption_f,"", "training captions for text epsilon loss");
  cmd.addRCmdOption("-min_epsilon", &min_epsilon, 0.001, "min_epsilon for text epsilon loss");

	// valid data	
	cmd.addText("\n Validation Measurements");	
	cmd.addSCmdOption("-mfile", &measure_file, "", "validation measure file");
	cmd.addICmdOption("-mfreq", &measure_freq, -1, "validation measure frequency");
  cmd.addSCmdOption("-vimg", &valid_image_f, "", "validation images (for linear kernel)");
	cmd.addSCmdOption("-vkernel", &valid_kernel_mat_f, "", "train-valid kernel values (for other kernels)");
  cmd.addSCmdOption("-vque", &valid_query_f, "", "validation queries");
  cmd.addSCmdOption("-vrel", &valid_relevance_f, "", "validation relevance");

	// measurements of support vector set size
	cmd.addText("\n Support Set Measurements");
  cmd.addSCmdOption("-svset_size", &sv_set_size_f, "", "support vector set size");

	// saving the model
	cmd.addText("\n Save the model");
	cmd.addSCmdOption("-save", &model_file, "", "model file");	
	cmd.addICmdOption("-sfreq", &save_freq, -1, "frequency of periodic save");

  // random number generator seed 
  cmd.addText("\n Random number generator seed");
	cmd.addICmdOption("-seed", &seed, -1, "the seed");
	cmd.read(argc, argv);

	// === Init random number generator ===    
	if(seed == -1)
      Random::seed();
    else
      Random::manualSeed((long)seed);
	
	// === training mode ===
	bool linear_kernel = (strcmp(train_image_f, "") != 0);
	bool text_epsilon = (strcmp(train_caption_f, "") != 0);
	bool measure_valid = (strcmp(measure_file, "") != 0);
	bool measure_svset = (strcmp(sv_set_size_f, "") != 0);

	// ===== training data ====
	message("Loading the training data...");

	// -- queries --
	SparseMatrix *train_queries = new (allocator) SparseMatrixROM(train_query_f);

	// -- relevance --
	SparseMatrix *train_relevance = new (allocator) SparseMatrixROM(train_relevance_f);

	// -- images or cached kernel values --
	PADataset *train_images = NULL;
	DataSet *train_kernel_mat = NULL;
	int n_terms = train_queries->nc;
	int n_train_img = train_relevance->nc;
	int n_visterms = 0;
	if (linear_kernel) // image visterms
	{	
		SparseMatrix *image_matrix = new (allocator) SparseMatrixROM(train_image_f);
		n_visterms = image_matrix->nc;
		train_images = new (allocator) PASparseDataset(image_matrix);
	}
	else               // cached kernel matrix
	{
		train_kernel_mat = new (allocator) MatDataSet(train_kernel_mat_f, -1, 0, true);
		train_kernel_mat->setExample(0);
		train_images = new (allocator) PAIndexDataset(train_index, n_train_img);
	}

	// -- build the training constraint set from loaded data --
	MarginPolicy *margin = NULL;
	WeightPolicy *weight = NULL;
	if (text_epsilon) // training captions for TE
	{
		SparseMatrix *train_captions = new (allocator) SparseMatrixROM(train_caption_f);
		margin = new (allocator) TextEpsilon(min_epsilon, train_captions);
	}
	if (bep)
		weight = new (allocator) BEPWeighting();
	PADataset *train_set = new (allocator) ImgConstraintDataset(train_queries, train_relevance, 
                                          train_images, margin, weight);
	
	// ===== validation data ====
	message("Loading the validation data...");
	SparseMatrix *valid_queries = NULL;
	SparseMatrix *valid_relevance = NULL;
	PADataset *valid_images = NULL;	
	DataSet *valid_kernel_mat = NULL;
	int n_valid_img = 0;
	if (measure_valid) 
	{
		// -- queries --
		valid_queries = new (allocator) SparseMatrixROM(valid_query_f);	

		// -- relevance --
		valid_relevance = new (allocator) SparseMatrixROM(valid_relevance_f);

		// -- images or cached kernel values --
		n_valid_img = valid_relevance->nc;
	  if (linear_kernel) // image visterms
	  {
  	  SparseMatrix *image_matrix = new (allocator) SparseMatrixROM(valid_image_f);
    	valid_images = new (allocator) PASparseDataset(image_matrix);
  	}
  	else               // cached kernel matrix
  	{
    	valid_kernel_mat = new (allocator) MatDataSet(valid_kernel_mat_f, -1, 0, true);
	    valid_kernel_mat->setExample(0);
	    valid_images = new (allocator) PAIndexDataset(valid_index, n_valid_img);
  	}
	}

	// === Build the model and the measurers ===

	// kernel
	PAKernel *kernel = NULL;
	if (linear_kernel)
		kernel = new (allocator) PAKernelSparseLinear();
	else
	{
		PACachedKernel *kernel_ = new (allocator) PACachedKernel();
	  kernel_->addMatrix(train_index, train_index, train_kernel_mat->inputs);
   	if (measure_valid)
     kernel_->addMatrix(train_index, valid_index, valid_kernel_mat->inputs);	
		kernel = kernel_;
	}

	// term machines
	PAMachine **term_machines = (PAMachine **) allocator->alloc(sizeof(PAMachine*) * n_terms);
	for (int t = 0; t < n_terms; t++)
	{
		PASupportSet *support_set = NULL;
		if (linear_kernel)
			support_set = new (allocator) PASparseSupportSetLK(n_visterms);
		else
			support_set = new (allocator) PAIndexSupportSet(train_index, n_train_img);
		term_machines[t] = new (allocator) PAMachine(kernel, support_set);
	}

	// global machine
	PAMachine *machine = new (allocator) PAMachineImageIR(n_terms, term_machines);

	// measurers
	PAMeasurer *measurer_valid = NULL;
	if (measure_valid)
	{
		XFile *out = new (allocator) DiskXFile(measure_file, "w");
		measurer_valid = new (allocator) ImgRtrvMeasurerFast(machine, valid_queries, 
																		valid_images, valid_relevance, out);
		measurer_valid->setIterFrequency((unsigned int)measure_freq);
	}

	PAMeasurer *measurer_svset = NULL;
	if (measure_svset)
  {
		XFile *out = new (allocator) DiskXFile(sv_set_size_f, "w");
    measurer_svset = new (allocator) PASupportVectorSetSizeMeasurer(n_terms, term_machines, out);
    measurer_svset->setIterFrequency((unsigned int)measure_freq);
	}

	// criterion
	PALoss *loss = new (allocator) ImgLoss(C);

	// trainer
	PATrainer *trainer = new (allocator) PATrainer(machine, train_set, loss);	
	if (measure_valid) 
		trainer->addMeasurer(measurer_valid);	
	if (measure_svset)
		trainer->addMeasurer(measurer_svset);

	// === Training ===
  message("Start training...");
	trainer->train(n_iter);

	// === Save the model if required
	if (strcmp(model_file, "") != 0)
	{
		XFile *file = new (allocator) DiskXFile(model_file, "w");
		machine->saveXFile(file);
		allocator->free(file);
	}

	// === Free mem ===
	delete allocator;
}


