// Copyright (c) 2007 David Grangier
// Copyright (c) 2007 Samy Bengio
// 
// All rights reserved.
// 
// Redistribution and use in source and binary forms, with or without 
// modification, are permitted provided that the following conditions are 
// met: Redistributions of source code must retain the above copyright 
// notice, this list of conditions and the following disclaimer.
// Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the 
// documentation and/or other materials provided with the distribution.
// The name of the author may not be used to endorse or promote products
// derived from this software without specific prior written permission.
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, 
// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 
// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 
// THE POSSIBILITY OF SUCH DAMAGE.


#include "SMatrixSplitter.h"

#include <stdio.h>

#include "Random.h"
#include "SparseMatrixRW.h"
#include "SparseMatrixROM.h"
#include "string_utils.h"
#include "DiskXFile.h"

namespace Torch {

// Creates an empty splitter: no matrix is attached and all dimensions are
// zero. The members are filled in later by split() or join().
SMatrixSplitter::SMatrixSplitter()
{
	// Nothing has been loaded or built yet.
	full_mat = NULL;
	sub_mat = index_mat = NULL;

	// Number of parts, of lines and of columns.
	n = nl = nc = 0;
}

void SMatrixSplitter::init(int n_, int nl_, int nc_)
{
	n = n_;
	nl = nl_;
	nc = nc_;

  sub_mat = (SparseMatrix**) allocator->alloc(sizeof(SparseMatrix*) * n);
  index_mat = (SparseMatrix**) allocator->alloc(sizeof(SparseMatrix*) * n);	
}

void SMatrixSplitter::split(char *full_mat_file, int n_)
{
	// load data
	full_mat = new(allocator) SparseMatrixROM(full_mat_file);
	init(n_, full_mat->nl, full_mat->nc);

	// shuffle
	int *shuffle = (int*) allocator->alloc(sizeof(int) * nl);
	Random::getShuffledIndices(shuffle, nl);

	// split
	int *indices = shuffle;
	for (int i = 0; i < n; i++)
	{
		int size = (nl / n) + ((i < (nl % n)) ? 1 : 0);
		sub_mat[i] = new(allocator) SparseMatrixRW(size, nc);
		index_mat[i] = new(allocator) SparseMatrixRW(size, nl);

		for (int j = 0; j < size; j++)
		{
			sub_mat[i]->copy(j, full_mat->lines + indices[j]);
			index_mat[i]->resize(j, 1);
			index_mat[i]->lines[j].frame[0].index = indices[j];
			index_mat[i]->lines[j].frame[0].value = 1.0;
		}

		indices+=size;
	}
}

void SMatrixSplitter::join(char *sub_mat_file, int n_)
{
	init(n_, 0, 0);

	// load data
	for (int i = 0; i < n; i++)
	{
		char *sub_mat_name = filename(sub_mat_file, ".mat.", i);
		char *index_name = filename(sub_mat_file, ".index.", i);

		sub_mat[i] = new(allocator) SparseMatrixROM(sub_mat_name);
		index_mat[i] = new(allocator) SparseMatrixROM(index_name);
    
		free(sub_mat_name);
		free(index_name);
	}

	// join
	nl = index_mat[0]->nc;
	nc = sub_mat[0]->nc;
	full_mat = new(allocator) SparseMatrixRW(nl, nc);
  for (int i = 0; i < n; i++)
  	for (int j= 0; j < index_mat[i]->nl; j++)
		{
			int index = index_mat[i]->lines[j].frame[0].index;
			full_mat->copy(index, sub_mat[i]->lines + j);
		}
}

void SMatrixSplitter::writeSplit(char *sub_mat_file)
{
	for (int i = 0; i < n; i++)
  {
    char *sub_mat_name = filename(sub_mat_file, ".mat.", i);
		XFile *file = new (allocator) DiskXFile(sub_mat_name, "w");
    sub_mat[i]->writeXFile(file);
		allocator->free(file);
		free(sub_mat_name);

		char *index_name = filename(sub_mat_file, ".index.", i);
    file = new (allocator) DiskXFile(index_name, "w");
    index_mat[i]->writeXFile(file);
    allocator->free(file);
		free(index_name);
  }	
}

// Writes the full matrix (loaded by split() or rebuilt by join()) to
// full_mat_file.
//
// full_mat_file: path of the file to create.
//
// If no full matrix is available yet, the problem is reported on stderr and
// no file is created.
void SMatrixSplitter::writeJoin(char *full_mat_file)
{
	// full_mat stays NULL until split() or join() has run; check before
	// opening the output so that a premature call neither crashes nor leaves
	// an empty file behind.
	if (full_mat == NULL)
	{
		fprintf(stderr, "SMatrixSplitter::writeJoin: no matrix to write, call split() or join() first\n");
		return;
	}

	XFile *file = new (allocator) DiskXFile(full_mat_file, "w");
	full_mat->writeXFile(file);
	allocator->free(file);
}

// Builds the name of a part file: file, followed by ext, followed by i
// printed as a zero-padded number of at least three digits ("%03d").
//
// file: prefix of the name.
// ext:  text inserted between the prefix and the number.
// i:    part number.
//
// Returns the string produced by strConcat(). The callers in this file
// release it with free().
char* SMatrixSplitter::filename(char *file, char *ext, int i)
{
	// A decimal int needs at most 11 characters plus the terminator, so a
	// small stack buffer is enough and no allocation can fail here.
	char no[32];
	snprintf(no, sizeof(no), "%03d", i);

	return strConcat(3, file, ext, no);
}

// Nothing is released explicitly here: every matrix and table of this class
// is obtained through the object's allocator.
SMatrixSplitter::~SMatrixSplitter() {}

}
