// Copyright (c) 2007 David Grangier
// Copyright (c) 2007 Samy Bengio
// 
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without 
// modification, are permitted provided that the following conditions are 
// met: Redistributions of source code must retain the above copyright 
// notice, this list of conditions and the following disclaimer.
// Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the 
// documentation and/or other materials provided with the distribution.
// The name of the author may not be used to endorse or promote products
// derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, 
// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 
// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 
// THE POSSIBILITY OF SUCH DAMAGE.


#include "ImgRtrvMeasurer.h"
#include "ImgConstraint.h"

namespace Torch {

/**
 * Builds a retrieval measurer.
 *
 * @param machine_   model whose forward() scores each (query, image) pair
 * @param queries_   sparse matrix with one line per query
 * @param images_    image collection that is scored for every query
 * @param relevance_ sparse matrix indexed by query number in measure()
 * @param out_       destination of the printed measures
 *
 * The pointers are stored as-is and are not freed by this class.
 * Averaging over queries is enabled by default (see setAveraging).
 */
ImgRtrvMeasurer::ImgRtrvMeasurer(PAMachine *machine_, SparseMatrix *queries_,
                                 PADataset *images_, SparseMatrix *relevance_,
                                 XFile *out_)
{
  // problem sizes: one query per line of queries_, one score per image
  n_que = queries_->nl;
  n_img = images_->getNExamples();

  // collaborators
  machine = machine_;
  images = images_;
  queries = queries_;
  relevance = relevance_;
  out = out_;

  averaging = true;

  // working storage obtained from this object's allocator:
  // a score slot per image, plus the per-query evaluator
  cur_que_scores = (real*) allocator->alloc(sizeof(real) * n_img);
  rtrv_eval = new (allocator) RetrievalEvaluation();
}


void ImgRtrvMeasurer::setAveraging(bool avg)
{
	averaging = avg;
}

/**
 * Evaluates every query against the whole image collection and writes the
 * results to `out` through printOut().
 *
 * For each query, four measures are read from rtrv_eval, in this order:
 *   0: errorRate()
 *   1: avgp()
 *   2: bep()
 *   3: ptop(10)   (precision in the top 10 results)
 *
 * Output depends on `averaging`:
 *   - when it is set, a single line with the mean of each measure over all
 *     queries is printed (all zeros when there are no queries);
 *   - otherwise, one line is printed per query.
 */
void ImgRtrvMeasurer::measure()
{
	initMeasure();

	// The number of measures is fixed, so plain local arrays are enough;
	// no need to go through the allocator for scratch space.
	const int n_measures = 4;
	double measures[n_measures];
	double avg_measures[n_measures];
	for (int m = 0; m < n_measures; m++)
		avg_measures[m] = 0.0;

	for (int q = 0; q < n_que; q++)
	{
		measureQuery(queries->lines + q, relevance->lines + q);
		measures[0] = static_cast<double>(rtrv_eval->errorRate());
		measures[1] = static_cast<double>(rtrv_eval->avgp());
		measures[2] = static_cast<double>(rtrv_eval->bep());
		measures[3] = static_cast<double>(rtrv_eval->ptop(10));

		if (averaging)
		{
			for (int m = 0; m < n_measures; m++)
				avg_measures[m] += measures[m];
		}
		else
		{
			printOut(measures, n_measures);
		}
	}

	if (averaging)
	{
		// With an empty query set the sums stay at zero; dividing by
		// n_que == 0 would otherwise print NaN for every measure.
		if (n_que > 0)
		{
			for (int m = 0; m < n_measures; m++)
				avg_measures[m] /= static_cast<double>(n_que);
		}
		printOut(avg_measures, n_measures);
	}
}

/**
 * Writes one result line to `out`, then flushes it.
 *
 * The line holds the `n_iter` member (presumably an iteration counter;
 * TODO confirm), followed by the first `n_measures` entries of `measures`,
 * each preceded by a single space.
 */
void ImgRtrvMeasurer::printOut(double *measures, int n_measures)
{
	out->printf("%d", n_iter);

	int m = 0;
	while (m < n_measures)
	{
		out->printf(" %lf", measures[m]);
		++m;
	}

	out->printf("\n");
	out->flush();
}

/**
 * Runs one query.
 *
 * Scores every image for `query` (the scores go into cur_que_scores), then
 * hands those scores and the query's relevance line `rel` to rtrv_eval, so
 * the per-query measures can be read from it afterwards.
 */
void ImgRtrvMeasurer::measureQuery(svector *query, svector *rel)
{
	this->getQueryScores(query);
	this->rtrv_eval->set(n_img, cur_que_scores, rel);
}

/**
 * Hook invoked at the start of measure(); this implementation does nothing.
 */
void ImgRtrvMeasurer::initMeasure()
{
	// nothing to prepare
}

/**
 * Fills cur_que_scores[i] with machine->forward() for the pair
 * (query, image i), for every image in the collection.
 *
 * One ImgConstraint is allocated per call, reused for all images, and
 * returned to the allocator at the end.
 */
void ImgRtrvMeasurer::getQueryScores(svector *query)
{
  ImgConstraint *pair = new (allocator) ImgConstraint();

  int i = 0;
  while (i < n_img)
  {
    pair->set(query, images->getExample(i), NULL, 0);
    cur_que_scores[i] = machine->forward(pair);
    ++i;
  }

  allocator->free(pair);
}

/**
 * Destructor. Performs no explicit cleanup: cur_que_scores and rtrv_eval
 * were obtained from `allocator`, which presumably reclaims them when this
 * object is destroyed (TODO confirm against the Allocator implementation).
 */
ImgRtrvMeasurer::~ImgRtrvMeasurer()
{
	// intentionally empty
}

}
