
######################### importation of pyVerif modules #####################

import os
from pyVerif.Experiment import Experiment
from pyVerif.Model import ModelErrors
from pyVerif.Score import Score

class ExperimentError:
    """Namespace grouping the exceptions specific to the experiments."""

    class _ExperimentError(Exception):
        """Common base class of every exception declared in this namespace."""

    class InvalidNumberOfWorldModel(_ExperimentError):
        """Signals an unsupported number of world models."""

######################### the DiscriminantExperiment class #####################
class DiscriminantExperiment(Experiment):
    """Class for a discriminant (MLP, SVM, ...) biometric verification experiment

    Each client model is trained from the client's own samples together
    with a list of impostor samples, then scored; the per-client scores can
    finally be merged into a single result file.
    """

    def __init__(self, features_dir, bin_dir, norm_file=None,
                 tmp_dir="tmp", out_dir="./", features_ext=".mat",
                 log_on_screen=False, log_file=None, executor=None):
        """initialize this instance

            All the arguments are forwarded unchanged to Experiment.__init__.

            Keyword arguments:

                features_dir  --  string: directory containing the features data
                bin_dir  --  string: directory containing the Torch executable
                norm_file  --  file containing the world model file list
                               (default None)
                tmp_dir  --  string: name of the temporary directory (default "tmp")
                out_dir  --  string: output directory to put models, scores, results (default "./")
                features_ext  --  string: features files extension (default ".mat")
                log_on_screen  --  boolean: also print the log messages on screen (default False)
                log_file  --  string: file to put the log messages (default None)
                executor  --  pyVerif.Executor instance: object to run external command (default None)
        """
        Experiment.__init__(self, features_dir, bin_dir, norm_file, tmp_dir,
                            out_dir, features_ext, log_on_screen, log_file,
                            executor)

    def train_clients_models(self, for_dev):
        """train all the clients models

            A model that already exists is not trained again: the
            ModelErrors.TrainExistingModel raised by the model is caught
            and reported as a warning through the executor log.

            Keyword arguments:
                for_dev  --  boolean: enroll the models for the development set
                             (True) or for the test set (False)
        """

        # set the directories to use
        # NOTE(review): the scores directory is the dev one even when
        # for_dev is False -- presumably irrelevant while training; confirm.
        self.set_dirs(
            models_dir=self.clients_models_dir,
            scores_dir=self.dev_scores_dir)

        if for_dev:
            protocol, txt = self.protocol_dev, 'dev'
        else:
            protocol, txt = self.protocol_test, 'test'

        # train the models of all the clients of the selected protocol
        for client_id in protocol.clients():

            # set the identity of the current model to train
            self.model.set_identity(client_id)

            # train it !
            self.log('train %s for client %s' % (txt, client_id))
            true_accesses = protocol.train_list(client_id)
            impost_accesses = protocol.impostors_list(client_id)
            try:
                self.model.train(samples=true_accesses,
                                 impostors_samples=impost_accesses,
                                 overwrite=False,
                                 files_must_exist=False)
            except ModelErrors.TrainExistingModel:
                self.executor.log("Warning model: %s exist!" % client_id)

    def build_clients_scores(self, for_dev):
        """
        Compute the clients scores for all the models, for either dev
        or test.

        Scores that already exist are not computed again: the
        ModelErrors.ComputeExistingScore raised by the model is caught and
        reported as a warning through the executor log.

        Keyword arguments:
            for_dev  --  boolean: True to compute the scores of the dev set,
                         False to compute the scores of the test set
        """

        # Trace written directly on stdout; the parenthesised form prints
        # the same text under Python 2 and Python 3.
        print("DiscriminantExperiment::build_clients_scores() ...")

        # select the wanted protocol and the matching scores directory
        if for_dev:
            protocol, txt = self.protocol_dev, 'dev'
            scores_dir = self.dev_scores_dir
        else:
            protocol, txt = self.protocol_test, 'test'
            scores_dir = self.test_scores_dir
        self.set_dirs(models_dir=self.clients_models_dir,
                      scores_dir=scores_dir)

        # No build-scores command is configured on the model here (the
        # calls to set_build_scores_cmd used to be commented out).

        # compute the score for all the models
        for client_id in protocol.clients():

            # set the identity of the model used to compute the scores
            self.model.set_identity(client_id)

            # build scores !
            self.log('build %s scores %s' % (txt, client_id))
            try:
                self.model.build_scores(protocol.score_list(client_id),
                                        overwrite=False,
                                        files_must_exist=False)
            except ModelErrors.ComputeExistingScore:
                self.executor.log("Warning score for client: %s exist!" % client_id)

    def merge_score(self, out_file, for_dev):
        """merge the computed scores into a result file.

             Merge the scores for either dev or test, putting the result in
             out_file.

             Keyword arguments:
                 out_file  --  destination handed to Score.save
                 for_dev  --  boolean: True to merge the scores of the dev
                              set, False to merge the scores of the test set
        """

        # select the wanted protocol and the matching scores directory
        if for_dev:
            protocol, txt = self.protocol_dev, 'dev'
            scores_dir = self.dev_scores_dir
        else:
            protocol, txt = self.protocol_test, 'test'
            scores_dir = self.test_scores_dir
        self.set_dirs(models_dir=self.clients_models_dir,
                      scores_dir=scores_dir)

        self.log('merge scores for %s set' % txt)

        # create a score object
        # NOTE(review): id_of is taken from the dev protocol even when the
        # test set is merged -- confirm this is intended.
        score = Score(self.protocol_dev.id_of)

        # load all the scores computed with the regular models
        score.load_all(scores_dir, protocol.clients())

        # save the merged scores
        score.save(out_file)


######################### the GenerativeExperiment class #####################
class GenerativeExperiment(Experiment):
    """Class for a generative (GMM, HMM, ...) biometric verification experiment

    The commands and options used to train the world model(s), to train
    the client models and to compute the scores are configured with
    set_cmds() or with set_world_cmd() / set_client_cmd() / set_test_cmd().
    """

    def __init__(self, features_dir, bin_dir, norm_file=None,
                 tmp_dir="tmp", out_dir="./", features_ext=".mat",
                 log_on_screen=False, log_file=None, executor=None):
        """initialize this instance

            All the arguments are forwarded unchanged to Experiment.__init__.

            Keyword arguments:

                features_dir  --  string: directory containing the features data
                bin_dir  --  string: directory containing the Torch executable
                norm_file  --  file containing the world model file list
                               (default None)
                tmp_dir  --  string: name of the temporary directory (default "tmp")
                out_dir  --  string: output directory to put models, scores, results (default "./")
                features_ext  --  string: features files extension (default ".mat")
                log_on_screen  --  boolean: also print the log messages on screen (default False)
                log_file  --  string: file to put the log messages (default None)
                executor  --  pyVerif.Executor instance: object to run external command (default None)
        """
        Experiment.__init__(self, features_dir, bin_dir, norm_file, tmp_dir,
                            out_dir, features_ext, log_on_screen, log_file,
                            executor)

    def set_cmds(self, cmd, world_options, client_options, test_options):
        """
        Set one single command for the world model training, the clients
        models training and the scores computation, each with its own
        options.

        This is the only method that sets self.cmd, which is the command
        train_world_model() uses to merge two world models.
        """

        self.cmd = cmd
        self.world_cmd = cmd
        self.client_cmd = cmd
        self.test_cmd = cmd
        self.world_options = world_options
        self.client_options = client_options
        self.test_options = test_options

    def set_world_cmd(self, world_cmd, world_options):
        """
        Set the train command and its options for the world model.
        """

        self.world_cmd = world_cmd
        self.world_options = world_options

    def set_client_cmd(self, client_cmd, client_options):
        """
        Set the train command and its options for the clients models.
        """

        self.client_cmd = client_cmd
        self.client_options = client_options

    def set_test_cmd(self, test_cmd, test_options):
        """
        Set the command and its options used to compute the scores, for
        both the world model and the clients models.
        """

        self.test_cmd = test_cmd
        self.test_options = test_options

    def train_world_model(self):
        """
        Train the world models

        One model is trained for each key returned by
        protocol_dev.norm_keys(world_model_name). A model that already
        exists is not trained again (a warning is logged instead).

        When exactly two models were requested they are then merged into
        one model named world_model_name by running "<cmd> --merge" from
        bin_dir.

        Raises ExperimentError.InvalidNumberOfWorldModel when the number of
        world models is neither 1 nor 2; the check is done after the
        training loop.
        """
        # set the directories to use
        self.set_dirs(
            models_dir=self.world_models_dir,
            scores_dir=self.dev_world_scores_dir)

        # fetch the identities of the world models only once
        world_models = self.protocol_dev.norm_keys(self.world_model_name)

        # set the commands and options to use for the world model
        self.model.set_train_cmd(self.world_cmd, self.world_options)

        for w in world_models:

            # train it !
            self.log('train world model %s' % w)
            self.model.set_identity(w)
            try:
                self.model.train(
                    self.protocol_dev.norm_list(self.world_model_name, w),
                    overwrite=False, files_must_exist=False)
            except ModelErrors.TrainExistingModel:
                self.executor.log("Warning model: %s exist!" % w)

        # optional part (to remove on public release)
        nb_world_models = len(world_models)
        if nb_world_models == 2:
            # An older variant of this step called a dedicated 'gmm_merge'
            # binary; the merge now goes through "<cmd> --merge".
            # NOTE(review): self.cmd is only assigned by set_cmds(); confirm
            # that callers configuring the commands with set_world_cmd()
            # alone never reach this branch.
            self.executor.log("================= merge world models ============")
            merge_args = '"%s %s" %s' % (
                os.path.join(self.world_models_dir, world_models[0]),
                os.path.join(self.world_models_dir, world_models[1]),
                os.path.join(self.world_models_dir, self.world_model_name))
            self.executor.run(
                os.path.join(self.bin_dir, self.cmd + " --merge"),
                merge_args)
        elif nb_world_models != 1:
            raise ExperimentError.InvalidNumberOfWorldModel(
                "The number of world models is invalid (%d)" % nb_world_models)

    def train_clients_models(self, for_dev):
        """
        Train all the clients models

        A model that already exists is not trained again: the
        ModelErrors.TrainExistingModel raised by the model is caught and
        reported as a warning through the executor log.

        Keyword arguments:
            for_dev  --  boolean: True to train the models of the dev set,
                         False to train the models of the test set
        """
        # set the directories to use
        # NOTE(review): the scores directory is the dev one even when
        # for_dev is False -- presumably irrelevant while training; confirm.
        self.set_dirs(
            models_dir=self.clients_models_dir,
            scores_dir=self.dev_scores_dir)

        # set the commands and options to use for adapting the models
        self.model.set_train_cmd(self.client_cmd, self.client_options)

        if for_dev:
            protocol, txt = self.protocol_dev, 'dev'
        else:
            protocol, txt = self.protocol_test, 'test'

        # train the models of all the clients of the selected protocol
        for client_id in protocol.clients():

            # set the identity of the current model to train
            self.model.set_identity(client_id)

            # train it !
            self.log('train %s for client %s' % (txt, client_id))
            try:
                self.model.train(protocol.train_list(client_id),
                                 overwrite=False, files_must_exist=False)
            except ModelErrors.TrainExistingModel:
                self.executor.log("Warning model: %s exist!" % client_id)

    def build_world_scores(self, for_dev):
        """
        Compute the scores of the selected set (dev or test) with the
        world model.

        Scores that already exist are not computed again: the
        ModelErrors.ComputeExistingScore raised by the model is caught and
        reported as a warning through the executor log.

        Keyword arguments:
            for_dev  --  boolean: True to compute the scores of the dev set,
                         False to compute the scores of the test set
        """

        # select the wanted protocol and the matching world scores directory
        if for_dev:
            protocol, txt = self.protocol_dev, 'dev'
            scores_dir = self.dev_world_scores_dir
        else:
            protocol, txt = self.protocol_test, 'test'
            scores_dir = self.test_world_scores_dir
        self.set_dirs(models_dir=self.world_models_dir,
                      scores_dir=scores_dir)

        # set the commands and options to use for the world model
        self.model.set_build_scores_cmd(self.test_cmd, self.test_options)

        # set the identity to the world model
        self.model.set_identity(self.world_model_name)

        # build scores !
        self.log('build %s scores for %s' % (txt, self.world_model_name))
        try:
            self.model.build_scores(protocol.score_list(),
                                    overwrite=False, files_must_exist=False)
        except ModelErrors.ComputeExistingScore:
            self.executor.log("Warning score file: %s exist!" % self.world_model_name)

    def build_clients_scores(self, for_dev):
        """
        Compute the clients scores for all the models, for either dev
        or test.

        Scores that already exist are not computed again: the
        ModelErrors.ComputeExistingScore raised by the model is caught and
        reported as a warning through the executor log.

        Keyword arguments:
            for_dev  --  boolean: True to compute the scores of the dev set,
                         False to compute the scores of the test set
        """

        # select the wanted protocol and the matching scores directory
        if for_dev:
            protocol, txt = self.protocol_dev, 'dev'
            scores_dir = self.dev_scores_dir
        else:
            protocol, txt = self.protocol_test, 'test'
            scores_dir = self.test_scores_dir
        self.set_dirs(models_dir=self.clients_models_dir,
                      scores_dir=scores_dir)

        # set the commands and options to use for computing the scores
        self.model.set_build_scores_cmd(self.test_cmd, self.test_options)

        # compute the score for all the models
        for client_id in protocol.clients():

            # set the identity of the model used to compute the scores
            self.model.set_identity(client_id)

            # build scores !
            self.log('build %s scores %s' % (txt, client_id))
            try:
                self.model.build_scores(protocol.score_list(client_id),
                                        overwrite=False,
                                        files_must_exist=False)
            except ModelErrors.ComputeExistingScore:
                self.executor.log("Warning score for client: %s exist!" % client_id)

    def merge_score(self, out_file, for_dev):
        """
        Merge the score for either dev or test, putting the result in
        out_file.

        The clients scores and the world scores are loaded, a column
        holding the difference between two of the score columns is added,
        and the result is saved.

        Keyword arguments:
            out_file  --  destination handed to Score.save
            for_dev  --  boolean: True to merge the scores of the dev set,
                         False to merge the scores of the test set
        """

        # select the wanted protocol and the matching scores directories
        if for_dev:
            protocol, txt = self.protocol_dev, 'dev'
            scores_dir = self.dev_scores_dir
            world_scores_dir = self.dev_world_scores_dir
        else:
            protocol, txt = self.protocol_test, 'test'
            scores_dir = self.test_scores_dir
            world_scores_dir = self.test_world_scores_dir
        self.set_dirs(models_dir=self.clients_models_dir,
                      scores_dir=scores_dir)

        self.log('merge scores for %s set' % txt)

        # create a score object
        # NOTE(review): id_of is taken from the dev protocol even when the
        # test set is merged -- confirm this is intended.
        score = Score(self.protocol_dev.id_of)

        # load all the scores computed with the regular models
        score.load_all(scores_dir, protocol.clients())

        # load the world scores
        # NOTE(review): the file name is the literal 'world' whereas
        # build_world_scores() uses self.world_model_name as the model
        # identity -- verify both agree when the name is not 'world'.
        score.load_norm(os.path.join(world_scores_dir, 'world'))

        # compute the difference between the two scores (columns 1 and 2;
        # presumably client score minus world score -- TODO confirm)
        score.add_column(lambda x, y: x - y, (1, 2))

        # save the merged scores
        score.save(out_file)

