// Copyright (c) 2007 David Grangier
// Copyright (c) 2007 Samy Bengio
// 
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without 
// modification, are permitted provided that the following conditions are 
// met: Redistributions of source code must retain the above copyright 
// notice, this list of conditions and the following disclaimer.
// Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the 
// documentation and/or other materials provided with the distribution.
// The name of the author may not be used to endorse or promote products
// derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, 
// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 
// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 
// THE POSSIBILITY OF SUCH DAMAGE.


#include "FastNearestNeighbors.h"
#include "SVectorCalculator.h"

namespace Torch
{

// Builds a q->nl by td->nc sparse matrix holding, for every line of q, the
// best-scoring columns of td.
//
// For each query line iq, a dense score vector of td->nc entries is
// accumulated: every entry (it, wq) of q->lines[iq] adds wq * wd to
// scores[id] for each entry (id, wd) of td->lines[it]. The non-zero scores
// are then sorted with sreal_comp_values (decreasing value, presumably --
// the comparator is defined elsewhere), truncated to at most n entries,
// re-sorted with sreal_comp_index and stored as line iq through copy().
//
// q  : queries, one per line; each entry is (query term id, term weight).
// td : indexed by query term id; each entry of a line is
//      (document id, term weight in that document).
// n  : maximum number of entries kept per query. Negative values are
//      treated as 0 (nothing kept).
FastNearestNeighbors::FastNearestNeighbors(SparseMatrix *q, SparseMatrix *td, int n)
:SparseMatrixRW(q->nl,td->nc)
{
	// A negative n would end up as a negative size below; qsort() takes its
	// element count as a size_t, so that would be read as a huge count.
	// Clamp it so that it behaves exactly like n == 0.
	const int max_kept = (n < 0) ? 0 : n;

	// scores (non-sparse): one accumulator per column of td
	real *scores = (real*) allocator->alloc(sizeof(real) * td->nc);
	// scores (sparse): sized for the worst case where every score is non-zero
	svector *sparse_score = (svector*) allocator->alloc(sizeof(svector));
	sparse_score->frame = (sreal*) allocator->alloc(sizeof(sreal) * td->nc);

	// for each query
	for (int iq = 0; iq < q->nl; iq++)
	{
		// the dense buffer is reused across queries, so reset it first
		memset(scores, 0, sizeof(real) * td->nc);

		// compute all scores
		for (int t = 0; t < q->lines[iq].size; t++)
		{
			int it = q->lines[iq].frame[t].index;  // query term id
			real wq = q->lines[iq].frame[t].value; // query term weight
			for (int d = 0; d < td->lines[it].size; d++)
			{
				int id = td->lines[it].frame[d].index;  // document id
				real wd = td->lines[it].frame[d].value; // term weight in doc id
				scores[id] += wd * wq;
			}
		}

		// sparsify scores: keep only the non-zero ones, with their column
		int i = 0;
		for (int d = 0; d < td->nc; d++)
		{
			if (scores[d] != 0)
			{
				sparse_score->frame[i].value = scores[d];
				sparse_score->frame[i].index = d;
				i++;
			}
		}
		sparse_score->size = i;

		// sort them by (decreasing) value so that truncation keeps the top ones
		qsort(sparse_score->frame, sparse_score->size, sizeof(sreal), sreal_comp_values);

		// keep at most max_kept entries
		if (sparse_score->size > max_kept)
			sparse_score->size = max_kept;

		// put sparse score in the matrix, ordered by index again
		qsort(sparse_score->frame, sparse_score->size, sizeof(sreal), sreal_comp_index);
		copy(iq, sparse_score);
	}

	// the scratch buffers are only needed during construction
	allocator->free(scores);
	allocator->free(sparse_score->frame);
	allocator->free(sparse_score);
}

// Counts, for each of the td->nc document ids, how many entries of q point
// to a line of td that contains that document.
//
// q       : query vector; only the index (query term id) of each entry is read.
// td      : indexed by query term id; the entries of a line carry document ids.
// n_match : output array of td->nc ints; zeroed here, then n_match[id] is
//           incremented once per (query entry, line entry) pair naming id.
void FastNearestNeighbors::countMatch(svector *q, SparseMatrix *td, int *n_match)
{
	// Start every counter from zero.
	memset(n_match, 0, sizeof(int)*td->nc);

	for (int pos = 0; pos < q->size; ++pos)
	{
		const int term = q->frame[pos].index; // query term id

		// Each entry on the line of this term names one document to credit.
		int entry = 0;
		while (entry < td->lines[term].size)
		{
			const int doc = td->lines[term].frame[entry].index; // document id
			n_match[doc] += 1;
			++entry;
		}
	}
}

// Empty body: this destructor performs no cleanup of its own.
FastNearestNeighbors::~FastNearestNeighbors()
{
}

}
