// Copyright (c) 2007 David Grangier
// Copyright (c) 2007 Samy Bengio
// 
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without 
// modification, are permitted provided that the following conditions are 
// met: Redistributions of source code must retain the above copyright 
// notice, this list of conditions and the following disclaimer.
// Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the 
// documentation and/or other materials provided with the distribution.
// The name of the author may not be used to endorse or promote products
// derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, 
// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 
// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 
// THE POSSIBILITY OF SUCH DAMAGE.


/*
This program tests a trained PAMIR model on a set of test queries and images.
*/
#include "PAMachineImageIR.h" // machine
#include "PATrainer.h"        // trainer
#include "ImgLoss.h"          // loss
#include "ImgConstraintPolicies.h"
#include "PASparseSupportSetLK.h" // suport sets
#include "PAIndexSupportSet.h"
#include "PAKernelSparseLinear.h" // kernels
#include "PACachedKernel.h"

#include "PASparseDataset.h"  // datasets
#include "PAIndexDataset.h"
#include "MatDataSet.h"
#include "ImgConstraintDataset.h"
#include "SparseMatrixROM.h"

#include "ImgRtrvMeasurerFast.h" // measurers
#include "PASupportVectorSetSizeMeasurer.h"

#include "DiskXFile.h" // misc.
#include "CmdLine.h"
#include "Random.h"     

using namespace Torch;

int main(int argc, char **argv)
{
  // ======= Constants =======
  int train_index = 0;
  int test_index = 1;

	// ======= Variables =======

	// test data
  char *test_image_f;
	char *test_kernel_mat_f;
  char *test_relevance_f;
	char *test_query_f;
	char *measure_file;

	// the model
	char *model_file;

	// avg measurement
	bool avg;

	// allocator
  Allocator *allocator = new Allocator;

	// === Read command line ====
	
  CmdLine cmd;

	// test data
	cmd.addText("\n Query Data");
	cmd.addSCmdArg("test_queries", &test_query_f, "test queries");
	cmd.addSCmdArg("relevance", &test_relevance_f, "test relevance");
	cmd.addSCmdArg("model", &model_file, "model file");
	cmd.addSCmdArg("measure", &measure_file, "measure file");

	// test data
	cmd.addText("\n Image Data");	
	cmd.addSCmdOption("-timg", &test_image_f, "", "test images (for linear kernel)");
	cmd.addSCmdOption("-tkernel", &test_kernel_mat_f, "", "train-test kernel values (for other kernels)");

	// measurement option 
  cmd.addText("\n Measurement Option");
  cmd.addBCmdOption("-avg", &avg, false, "average over query set");

	cmd.read(argc, argv);

	// === test mode ===
	bool linear_kernel = (strcmp(test_image_f, "") != 0);

	// ===== test data ====
	message("Loading the test data...");

	// -- queries --
	SparseMatrix *test_queries = new (allocator) SparseMatrixROM(test_query_f);

	// -- relevance --
	SparseMatrix *test_relevance = new (allocator) SparseMatrixROM(test_relevance_f);

	// -- images or cached kernel values --
	PADataset *test_images = NULL;
	DataSet *test_kernel_mat = NULL;
	int n_terms = test_queries->nc;
	int n_test_img = test_relevance->nc;
	int n_visterms = 0;
	if (linear_kernel) // image visterms
	{	
		SparseMatrix *image_matrix = new (allocator) SparseMatrixROM(test_image_f);
		n_visterms = image_matrix->nc;
		test_images = new (allocator) PASparseDataset(image_matrix);
	}
	else               // cached kernel matrix
	{
		test_kernel_mat = new (allocator) MatDataSet(test_kernel_mat_f, -1, 0, true);
		test_kernel_mat->setExample(0);
		test_images = new (allocator) PAIndexDataset(test_index, n_test_img);
	}

	// === Build the model and the measurers ===

	// kernel
	PAKernel *kernel = NULL;
	if (linear_kernel)
		kernel = new (allocator) PAKernelSparseLinear();
	else
	{
		PACachedKernel *kernel_ = new (allocator) PACachedKernel();
	  kernel_->addMatrix(train_index, test_index, test_kernel_mat->inputs);
		kernel = kernel_;
	}

	// term machines
	PAMachine **term_machines = (PAMachine **) allocator->alloc(sizeof(PAMachine*) * n_terms);
	for (int t = 0; t < n_terms; t++)
	{
		PASupportSet *support_set = NULL;
		if (linear_kernel)
			support_set = new (allocator) PASparseSupportSetLK(n_visterms);
		else
			support_set = new (allocator) PAIndexSupportSet(test_index, n_test_img);
		term_machines[t] = new (allocator) PAMachine(kernel, support_set);
	}

	// global machine
	PAMachine *machine = new (allocator) PAMachineImageIR(n_terms, term_machines);

	// load machine
  XFile *file = new (allocator) DiskXFile(model_file, "r");
  machine->loadXFile(file);
  allocator->free(file);

	// measure
	ImgRtrvMeasurer *measurer= NULL;
	XFile *out = new (allocator) DiskXFile(measure_file, "w");
	measurer = new (allocator) ImgRtrvMeasurerFast(machine, test_queries, test_images, test_relevance, out);
	measurer->setAveraging(avg);
	measurer->measure();

	// === Free mem ===
	delete allocator;
}


