// Copyright (c) 2007 David Grangier
// Copyright (c) 2007 Samy Bengio
// 
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without 
// modification, are permitted provided that the following conditions are 
// met: Redistributions of source code must retain the above copyright 
// notice, this list of conditions and the following disclaimer.
// Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the 
// documentation and/or other materials provided with the distribution.
// The name of the author may not be used to endorse or promote products
// derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, 
// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 
// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 
// THE POSSIBILITY OF SUCH DAMAGE.


#include "SVectorCalculator.h"

namespace Torch
{

// Build a calculator whose accumulated result starts out empty.
//
// dim_ is stored in dim; it is not otherwise used by the constructor.
// res is obtained from this object's allocator and zero-filled, so it starts
// with size 0 and a NULL frame (assuming svector is a plain size/frame struct
// as used by the methods below -- see SVectorCalculator.h).
SVectorCalculator::SVectorCalculator(int dim_)
{
	svector *empty_result=(svector*)allocator->alloc(sizeof(svector));
	memset(empty_result,0,sizeof(svector));
	res=empty_result;

	dim=dim_;
}

// Discard the accumulated result and return res to the empty state
// (size 0, frame NULL) -- the same state the constructor produces.
//
// res->frame is already NULL on a freshly built object or after a previous
// reset(), so only a non-NULL frame is handed back to the allocator.
// NOTE(review): the original passed NULL straight to allocator->free; guard
// kept in case Torch's Allocator rejects pointers it does not own.
void SVectorCalculator::reset()
{
	if (res->frame)
		allocator->free(res->frame);
	memset(res,0,sizeof(svector));
}

// Accumulate r*v into the running sparse result: res += r*v.
//
// The merge below assumes both res and v keep their entries sorted by
// increasing index (TODO confirm with callers; nothing here checks it).
// Entries whose value ends up exactly 0 are dropped in both code paths, so
// res never stores explicit zeros. An empty result is always represented as
// size 0 with a NULL frame -- the same state as after construction/reset().
//
// r : scale factor applied to v before it is added.
// v : sparse vector to add; only read, never modified or retained.
//
// After updating res, every stored value is checked and a message is
// printed for any NaN or infinite entry (debug aid).
void SVectorCalculator::add(real r, svector *v)
{
	// 2 cases : res = 0 or res !=0
	if (res->size==0)
	{
		// res is empty: the result is simply r*v.
		// An earlier add() whose merge cancelled every entry could have left
		// an empty res still owning a frame; release it instead of leaking it.
		if (res->frame)
			allocator->free(res->frame);
		res->frame=NULL;

		int n_kept=0;
		if (v->size>0)
		{
			res->frame=(sreal*)allocator->alloc(v->size*sizeof(sreal));
			for (int i=0;i<v->size;i++)
			{
				real value=v->frame[i].value;
				if (r!=1.0)
					value*=r;
				// strictly enforce sparsity, exactly like the merge branch
				if (value!=0)
				{
					res->frame[n_kept]=v->frame[i];
					res->frame[n_kept].value=value;
					n_kept++;
				}
			}
		}
		res->size=n_kept;
		if ((n_kept==0)&&(res->frame))
		{
			// every entry of r*v was zero: keep the canonical empty state
			allocator->free(res->frame);
			res->frame=NULL;
		}
	}
	else
	{
		// Merge res and r*v into a fresh frame sized for the worst case (no
		// index shared between them). It is zero-filled so that a slot filled
		// only from v starts at value 0 before the "+=" below.
		int max_size=res->size+v->size;
		sreal *merged=(sreal*)allocator->alloc(max_size*sizeof(sreal));
		memset(merged,0,max_size*sizeof(sreal));

		int v_index=0;
		int res_index=0;
		int n_kept=0;
		while ((v_index<v->size)||(res_index<res->size))
		{
			bool copy_res=false;
			bool copy_v=false;

			// what do we put in merged ?
			if (v_index>=v->size) // no more v, copy res
				copy_res=true;
			else if (res_index>=res->size) // no more res, copy v
				copy_v=true;
			else // still res and v, copy smallest index (both when equal)
			{
				if (res->frame[res_index].index<=v->frame[v_index].index)
					copy_res=true;
				if (v->frame[v_index].index<=res->frame[res_index].index)
					copy_v=true;
			}

			// put it !
			if (copy_res)
			{
				memcpy(merged+n_kept,res->frame+res_index,sizeof(sreal));
				res_index++;
			}

			if (copy_v)
			{
				merged[n_kept].index=v->frame[v_index].index;
				merged[n_kept].value+=r*v->frame[v_index].value;
				v_index++;
			}

			// strictly enforce sparsity: advance only past non-zero values.
			// A rejected slot still holds value 0, so it can be safely reused
			// by the next iteration (overwritten by memcpy or added to).
			if (merged[n_kept].value!=0)
				n_kept++;
		}

		// Swap in the merged frame, shrunk to the kept entries. The res
		// struct itself is reused, so the res pointer stays valid.
		allocator->free(res->frame);
		res->size=n_kept;
		if (n_kept==0)
		{
			// everything cancelled: canonical empty state, no frame
			allocator->free(merged);
			res->frame=NULL;
		}
		else
			res->frame=(sreal*)allocator->realloc(merged,sizeof(sreal)*n_kept);
	}

	//DBG: detect overflows
	for (int i=0;i<res->size;i++)
	{
		int index=res->frame[i].index;
		real value=res->frame[i].value;
		if (isnan(value))
			print("ISNAN t=%d",index);
		if (isinf(value))
			print("ISINF t=%d",index);
	}
}

// Dot product of two sparse vectors.
//
// Walks both entry lists in lock-step, advancing whichever side has the
// smaller index and multiplying values only where the indices match. This
// presumes both a and b are sorted by increasing index (TODO confirm with
// callers). Products are summed in double precision, then the total is
// converted back to real.
real SVectorCalculator::inner(svector *a,svector *b)
{
	double sum=0;
	int ia=0;
	int ib=0;

	while (ia<a->size && ib<b->size)
	{
		if (a->frame[ia].index<b->frame[ib].index)
		{
			ia++;
			continue;
		}
		if (b->frame[ib].index<a->frame[ia].index)
		{
			ib++;
			continue;
		}
		// same index on both sides: contributes to the sum
		sum+=(double)(a->frame[ia].value*b->frame[ib].value);
		ia++;
		ib++;
	}

	return (real)sum;
}

// Destructor. res and its frame are not released explicitly here; both were
// obtained from allocator, which presumably reclaims its allocations when
// this object is destroyed -- NOTE(review): confirm against Torch's Allocator.
SVectorCalculator::~SVectorCalculator()
{
}

}
