// Copyright (c) 2007 David Grangier
// Copyright (c) 2007 Samy Bengio
// 
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without 
// modification, are permitted provided that the following conditions are 
// met: Redistributions of source code must retain the above copyright 
// notice, this list of conditions and the following disclaimer.
// Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the 
// documentation and/or other materials provided with the distribution.
// The name of the author may not be used to endorse or promote products
// derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, 
// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 
// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 
// THE POSSIBILITY OF SUCH DAMAGE.


#include "PATrainer.h"
#include "PAExample.h"

namespace Torch {

// Builds a trainer around a machine, a training set and a loss.
// The three pointers are stored as-is (no copy is made); the trainer starts
// with an empty list of measurers, which addMeasurer() extends later.
// NOTE(review): ownership/lifetime of machine_, train_set_ and loss_ is not
// visible here; presumably the caller keeps them alive while training.
PATrainer::PATrainer(PAMachine *machine_, PADataset *train_set_, PALoss *loss_)
{
	// Start with no measurers registered.
	measurers = NULL;
	n_measurers = 0;

	// Remember the collaborators used by train() / trainIter().
	loss = loss_;
	train_set = train_set_;
	machine = machine_;
}

// Registers a measurer whose measureIter() is invoked after every training
// iteration in train().
// The measurer table is grown by one slot through this object's allocator,
// and the new pointer is stored in the last slot. The pointer itself is kept,
// not copied.
void PATrainer::addMeasurer(PAMeasurer *measurer_)
{
	++n_measurers;
	measurers = static_cast<PAMeasurer**>(
		allocator->realloc(measurers, sizeof(PAMeasurer*) * n_measurers));

	const int last = n_measurers - 1;
	measurers[last] = measurer_;
}

void PATrainer::train(int n_iter)
{
	machine->init();
	train_set->init();

	for (int i = 0; i < n_iter; i++)
	{	
		trainIter();
		for (int m = 0; m < n_measurers; m++)
			measurers[m]->measureIter();
	}
}

void PATrainer::trainIter()
{	
	PAExample *x = train_set->nextExample();
  real l = loss->getLoss(machine, x);
	if (l != 0)
	{
		loss->getSupport(machine, x, l);						
		machine->addSupport(loss->alpha, loss->sv);
	}
}

// Destructor: releases nothing explicitly.
// NOTE(review): the measurer table is obtained via allocator->realloc in
// addMeasurer(); presumably the allocator owns and frees that memory when
// this object is destroyed — confirm against the Allocator/Object contract.
PATrainer::~PATrainer()
{
}

}
