#!/usr/bin/env python
#$ -cwd
#$ -S /usr/bin/python

########################### usual importation of modules #####################

# import of standard module
import sys
import os
from optparse import OptionParser

# import the user configuration from ~/.pythonrc.py
import user

######################### importation of pyVerif modules #####################

# import necessary classes from pyVerif
from pyVerif.Executor import Executor
from pyVerifseb.MyExperiments import DiscriminantExperiment

##################### the "main" program ##################

# Banner printed at start-up.  The whitespace (including the tab characters)
# is spelled out with escapes so that the emitted text does not depend on how
# an editor renders or converts the indentation of this file.
BANNER = (
    "\n"
    "          \n"
    "\t  pyVerif program to perform a face verification experiment based on MLP.\n"
    "\n"
    "\t  Authors: Sebastien Marcel, Johnny Mariethoz and Olivier Bornet\n"
    "\n"
    "          Version: 1.0\n"
    "\t  \n"
    "\t  Date: August 2006\n"
    "\t  \n"
    "          "
)

# Root under which the per-database directories (selected with --base-dir)
# are looked up.
EXPERIMENTS_ROOT = os.path.join('/home', 'vision', 'marcel', 'work',
                                'experiments', 'faceverification', 'pyverif')


def build_option_parser():
    """Return the OptionParser describing the command line of this script."""
    parser = OptionParser("usage: %prog [options]")

    parser.set_description("pyVerif program to perform a face verification experiment")

    parser.add_option("-t", "--test", dest="test",
                      help="Perform also test set",
                      action="store_true", default=False)

    # "none" is the sentinel meaning "not provided"; main() rejects it
    parser.add_option("-d", "--features-dir", dest="features_dir",
                      help="Directory containing the features files",
                      type="string", default="none")

    parser.add_option("-o", "--output-dir", dest="output_dir",
                      help="Directory containing the results files",
                      type="string", default="./noname")

    parser.add_option("-p", "--protocol-dir", dest="protocol_dir",
                      help="Directory containing the protocol files",
                      type="string", default="lp")

    parser.add_option("-b", "--base-dir", dest="base_dir",
                      help="Directory containing protocols and features "
                           "(default xm2vts)",
                      type="string", default="xm2vts")

    parser.add_option("-v", "--variant", dest="variant",
                      help="Variant of the protocol (default lp1)",
                      type="string", default="lp1")

    parser.add_option("-l", "--log-on-screen", dest="log_on_screen",
                      help="Output log on screen",
                      action="store_true", default=False)

    parser.add_option("-u", "--number-of-hidden-units",
                      dest="number_of_hidden_units",
                      help="Number of Hidden Units (default 25)",
                      type="int", default=25)

    parser.add_option("-n", "--number-of-inputs",
                      dest="number_of_inputs",
                      help="Number of Inputs (default 5120)",
                      type="int", default=5120)

    # the following options only have a long form
    parser.add_option("--number-of-iterations",
                      dest="number_of_iterations",
                      help="Number of Iterations (default 25)",
                      type="int", default=25)

    parser.add_option("--learning-rate",
                      dest="learning_rate",
                      help="Learning rate (default 0.01)",
                      type="float", default=0.01)

    parser.add_option("--end-accuracy",
                      dest="end_accuracy",
                      help="End accuracy (default 1e-05)",
                      type="float", default=1e-05)

    parser.add_option("--criterion", dest="criterion",
                      help="Training criterion (default NLL)",
                      type="string", default="nll")

    parser.add_option("-s", "--seed", dest="seed",
                      help="seed (default -1)",
                      type="int", default=-1)

    parser.add_option("--perf", dest="compute_perf",
                      help="compute the performance",
                      action="store_true", default=False)

    return parser


def ensure_directory(path):
    """Create the directory `path` (and its parents) if it does not exist."""
    if not os.path.isdir(path):
        os.makedirs(path)


def build_train_options(options):
    """Return the argument string used to train a model with the `mlp` program.

    `options` is the object returned by the option parser.  The %TMPDIR,
    %MODEL, %FILES and %FILES_IMPOSTORS tokens are left verbatim in the
    result; they are not expanded here.
    """
    parts = [
        "-verbose -%s " % options.criterion,
        "-seed %d " % options.seed,
        "-nhu %d " % options.number_of_hidden_units,
        "-iter %d " % options.number_of_iterations,
        "-lr %g " % options.learning_rate,
        "-e %g " % options.end_accuracy,
        "-dir %TMPDIR ",
        "-save %MODEL %FILES %FILES_IMPOSTORS ",
        "%d " % options.number_of_inputs,
        " > %TMPDIR/train_model.stdout",
    ]
    return "".join(parts)


def build_test_options(options):
    """Return the argument string used to compute scores with the `mlp` program.

    `options` is the object returned by the option parser.  The %TMPDIR,
    %MODEL, %FILES and %SCORES tokens are left verbatim in the result; they
    are not expanded here.
    """
    parts = [
        "--test -verbose -dir %TMPDIR ",
        "%MODEL %FILES %SCORES ",
        "%d " % options.number_of_inputs,
        " > %TMPDIR/test_model.stdout",
    ]
    return "".join(parts)


def run_performance(experiment, results_dir):
    """Run pyerror, pydet and pyepc on the score files found in `results_dir`.

    The commands are logged on screen and in the log file of `experiment`.

    NOTE(review): all three commands are given the 'scores-test' path, which
    main() only passes to merge_score when --test is set -- confirm that
    --perf without --test is a supported combination.
    """
    scores_dev = os.path.join(results_dir, 'scores-dev')
    scores_test = os.path.join(results_dir, 'scores-test')
    scores = scores_dev + " " + scores_test

    executor = Executor()
    executor.enable_log(on_screen=True,
                        filename=experiment.executor.log_filename)

    # compute the performance
    executor.run("pyerror", "-b --dev-set=" + scores)

    # generate DET/EPC curves
    executor.run("pydet",
                 "-b -s " + os.path.join(results_dir, 'det.eps') + " " + scores)
    executor.run("pyepc",
                 "-b -n 10 -s " + os.path.join(results_dir, 'epc.eps') + " " + scores)


def main(argv=None):
    """Run the MLP face verification experiment described by the command line.

    `argv` is the list of command line arguments without the program name;
    when None, optparse reads them from sys.argv.  Invalid command lines are
    reported through parser.error(), which exits the program.
    """
    # a single parenthesised string prints identically as a Python 2 print
    # statement and as a Python 3 print() call
    print(BANNER)

    parser = build_option_parser()
    (options, args) = parser.parse_args(argv)

    # no positional argument is accepted
    if args:
        parser.error("Error: incorrect number of arguments, try --help")

    if options.features_dir == "none":
        parser.error("Error: no feature dir provided")

    #
    # prepare the directories
    #

    # directory of the selected database (for instance .../pyverif/xm2vts)
    base_dir = os.path.join(EXPERIMENTS_ROOT, options.base_dir)

    # directory containing the features
    features_dir = os.path.join(base_dir, 'features', options.features_dir)

    # directory containing the binary programs
    bin_dir = os.path.join(base_dir, 'verification', 'bin')

    # directory for the output files
    out_dir = os.path.join(options.output_dir, options.variant)

    # base dir containing the protocol files
    base_dir_protocols = os.path.join(base_dir, 'protocols',
                                      options.protocol_dir)

    # directories for the results and the temporary files
    # (created if they don't exist)
    results_dir = os.path.join(out_dir, 'results')
    ensure_directory(results_dir)

    tmp_dir = os.path.join(out_dir, 'tmp')
    ensure_directory(tmp_dir)

    #
    # prepare the experiment
    #

    experiment = DiscriminantExperiment(
        out_dir=out_dir,
        features_dir=features_dir,
        features_ext='.bindata',
        bin_dir=bin_dir,
        tmp_dir=tmp_dir,
        log_on_screen=options.log_on_screen,
        log_file=os.path.join(out_dir, 'mlp-' + options.variant + '.log'))

    # log the used options
    experiment.log_options(options.__dict__)

    # load the specified protocol
    experiment.load_protocols(base_dir_protocols, options.variant)

    # the same program is used both to train and to test an MLP
    cmd = "mlp"
    experiment.set_cmds(cmd, build_train_options(options),
                        cmd, build_test_options(options))

    #
    # run the experiment
    #

    # development set: train the client models, compute their scores and
    # merge the scores into one file
    experiment.train_clients_models(for_dev=True)
    experiment.build_clients_scores(for_dev=True)
    experiment.merge_score(os.path.join(results_dir, 'scores-dev'),
                           for_dev=True)

    if options.test:
        # Warning: no client models to train in test with XM2VTS
        if options.base_dir != "xm2vts":
            experiment.train_clients_models(for_dev=False)

        # test set: compute the scores and merge them into one file
        experiment.build_clients_scores(for_dev=False)
        experiment.merge_score(os.path.join(results_dir, 'scores-test'),
                               for_dev=False)

    if options.compute_perf:
        run_performance(experiment, results_dir)


if __name__ == '__main__':
    main()



