"""
Copyright (c) 2011 Idiap Research Institute, http://www.idiap.ch/
Written by Carl Scheffler <carl.scheffler@gmail.com>

This file is part of FaceColorModel.

FaceColorModel is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License version 3 as
published by the Free Software Foundation.

FaceColorModel is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
General Public License for more details.

You should have received a copy of the GNU General Public License
along with FaceColorModel. If not, see <http://www.gnu.org/licenses/>.
"""
from __future__ import division
import errno
import os
import subprocess
import sys

import command_line

classes = ['skin', 'hair', 'clothes', 'background']

commandLineSpec = (
    ("inputListFilename", "--input_list", str, "input_list.txt", """
      A text file listing the relative path to each input image
      in the training set."""),
    ("labelListFilename", "--label_list", str, "label_list.txt", """
      A text file listing the relative path to each label image
      in the training set. Each line in the label list should
      correspond to a line in the input list."""),
    ("folds", "--folds", int, 5, """
      The number of folds in the experimental run. The full data
      set will be divided into this many parts and testing on each
      of the folds while training on the other (n-1) folds."""),
    ("kMrf", "--kmrf", float, 2, """
      The strength constant of the Markov random field."""),
    ("skinType", "--skin", str, "cont", """
      Whether to use the discrete (disc) or continuous (cont) color
      model for the skin palette."""),
    ("hairType", "--hair", str, "cont", """
      Whether to use the discrete (disc) or continuous (cont) color
      model for the hair palette."""),
    ("quiet", "--quiet", """
      Tell the script *not* to be verbose."""),
)

# Calling for help: print the usage text generated from commandLineSpec and
# stop (a successful, requested exit, so status 0).
if "--help" in sys.argv[1:]:
    print(command_line.usage(commandLineSpec, sys.argv))
    sys.exit()

# Parse command line arguments. PARAMS is indexed below by the spec names
# (first element of each commandLineSpec entry); the second return value is
# not used by this script.
try:
    PARAMS, _ = command_line.parse(commandLineSpec, sys.argv)
except ValueError as msg:
    print(msg)
    print(command_line.usage(commandLineSpec, sys.argv))
    # Non-zero status so that calling scripts can detect the bad invocation.
    sys.exit(1)

# Read input and label lists. Entry i of the input list is paired with entry
# i of the label list, so both lists must have the same length.
fileLists = {}
for typ in ['input', 'label']:
    with open(PARAMS[typ + "ListFilename"], 'rt') as fp:
        # splitlines() (unlike split('\n')) also removes the '\r' of Windows
        # line endings, and yields [] rather than [''] for an empty file.
        fileLists[typ] = fp.read().strip().splitlines()
if len(fileLists['input']) != len(fileLists['label']):
    print("Error: input and label file lists have different numbers of entries")
    sys.exit(1)
fileListLength = len(fileLists['input'])

# With zero folds nothing would be evaluated (and the final ratios would
# divide by zero); with more folds than images some folds would be empty.
if not 1 <= PARAMS['folds'] <= fileListLength:
    print("Error: --folds must be between 1 and the number of listed images (%i)"
          % fileListLength)
    sys.exit(1)

# Per-class running totals of the confusion counts, summed over every fold.
# Each dict maps a class name from `classes` to a count starting at 0.
true_positives = dict.fromkeys(classes, 0)
false_positives = dict.fromkeys(classes, 0)
true_negatives = dict.fromkeys(classes, 0)
false_negatives = dict.fromkeys(classes, 0)

# Cross-validation: for each fold, retrain on the other folds, test on this
# fold, and add the test script's 'Overall:' counts to the running totals.
# On any failure the script exits immediately; the temporary __*.txt files
# are deliberately left in place to help debugging.

def _write_file_list(path, entries):
    """Write *entries* to *path*, one per line, without any blank lines."""
    with open(path, 'wt') as fp:
        fp.write(''.join(entry + '\n' for entry in entries))


def _run_or_exit(args, stdout=None):
    """Run *args* as a child process and abort the experiment if it fails.

    Arguments are passed as a list (no shell), so user-supplied values such
    as --skin/--hair are never interpreted by a shell.
    """
    status = subprocess.call(args, stdout=stdout)
    if status != 0:
        print('Error: command failed with exit status %i: %s'
              % (status, ' '.join(args)))
        sys.exit(1)


def _parse_overall_counts(text):
    """Return {class: [tp, fp, tn, fn]} from the 'Overall:' section of *text*.

    For each class in `classes` (searched in order after 'Overall:'), the
    text following the next '-- ' up to the end of the line is split on ', '
    and the first four entries are read after dropping a 3-character label
    prefix (assumed to be e.g. 'tp ' -- TODO confirm against the output of
    test_image_segmentation.py).

    Raises ValueError if the expected structure is missing or a count is not
    a number.
    """
    pos = text.find('Overall:')
    if pos == -1:
        raise ValueError("no 'Overall:' section found")
    counts = {}
    for clas in classes:
        pos = text.find(clas, pos)
        if pos == -1:
            raise ValueError("no 'Overall:' results for class %r" % clas)
        pos = text.find('--', pos)
        if pos == -1:
            raise ValueError("malformed 'Overall:' line for class %r" % clas)
        pos += 3  # skip the '-- ' separator
        end = text.find('\n', pos)
        if end == -1:
            # Last line without a trailing newline: keep its final character.
            end = len(text)
        entries = text[pos:end].split(', ')
        if len(entries) < 4:
            raise ValueError("expected 4 counts for class %r, got %r"
                             % (clas, text[pos:end]))
        counts[clas] = [float(entry[3:]) for entry in entries[:4]]
    return counts


# Run the children with the same interpreter as this script, rather than
# whatever 'python' happens to be first on PATH.
pythonExecutable = sys.executable or 'python'
# Same order as the four counts returned by _parse_overall_counts.
counters = (true_positives, false_positives, true_negatives, false_negatives)

foldStop = 0
for fold in range(PARAMS['folds']):
    if not PARAMS['quiet']:
        print('Fold %i / %i' % (fold+1, PARAMS['folds']))
    # Fold boundaries are rounded evenly spaced split points; the last fold
    # ends exactly at fileListLength.
    foldStart = foldStop
    foldStop = int(round(fileListLength*(fold+1)/PARAMS['folds']))

    # Training lists: every entry outside [foldStart, foldStop).
    for typ in ['input', 'label']:
        _write_file_list("__%s_list.txt" % typ,
                         fileLists[typ][:foldStart] + fileLists[typ][foldStop:])

    # Retrain models quietly
    _run_or_exit([pythonExecutable, 'train.py',
                  '--input_list', '__input_list.txt',
                  '--label_list', '__label_list.txt',
                  '--quiet'])

    # Test lists: the entries of this fold only.
    for typ in ['input', 'label']:
        _write_file_list("__%s_list.txt" % typ,
                         fileLists[typ][foldStart:foldStop])

    # Test using test_image_segmentation.py, capturing its stdout in a file.
    testArgs = [pythonExecutable, 'test_image_segmentation.py',
                '--input_list', '__input_list.txt',
                '--label_list', '__label_list.txt',
                '--kmrf', str(PARAMS['kMrf']),
                '--skin', PARAMS['skinType'],
                '--hair', PARAMS['hairType']]
    if PARAMS['quiet']:
        testArgs.append('--quiet')
    with open('__test_image_segmentation.txt', 'wt') as out:
        _run_or_exit(testArgs, stdout=out)

    # Aggregate results
    with open('__test_image_segmentation.txt', 'rt') as fp:
        text = fp.read()
    try:
        foldCounts = _parse_overall_counts(text)
    except ValueError as err:
        print('Error: could not parse __test_image_segmentation.txt (fold %i): %s'
              % (fold+1, err))
        sys.exit(1)
    for clas in classes:
        for counter, value in zip(counters, foldCounts[clas]):
            counter[clas] += value
# Legend for the abbreviations used in the per-class results table; it is
# shown only in verbose mode.
RESULTS_KEY = """
Key:
  f  -- F-score
  p  -- precision
  r  -- recall
  tp -- true positives
  fp -- false positives
  tn -- true negatives
  fn -- false negatives
"""
verbose = not PARAMS['quiet']
if verbose:
    print(RESULTS_KEY)

# Produce results for the experiment from the per-class counts summed over
# all folds.

def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero.

    A zero denominator occurs e.g. for a class that is never predicted
    (tp + fp == 0) or never present (tp + fn == 0). Reporting NaN avoids a
    ZeroDivisionError and makes the undefined value visible in the table.
    """
    if denominator == 0:
        return float('nan')
    return numerator / denominator


# Overall accuracy = (sum of tp over classes) / (sum of tp + fn over classes).
accuracy_numerator = 0
accuracy_denominator = 0
for clas in classes:
    tpCount = true_positives[clas]
    fpCount = false_positives[clas]
    fnCount = false_negatives[clas]
    accuracy_numerator += tpCount
    accuracy_denominator += tpCount + fnCount
    precision = _ratio(tpCount, tpCount + fpCount)
    recall = _ratio(tpCount, tpCount + fnCount)
    # Same value as 2*p*r/(p+r) whenever that is defined, but also defined
    # (as 0) when tp == 0 while fp + fn > 0.
    fScore = _ratio(2*tpCount, 2*tpCount + fpCount + fnCount)
    print('%-10s: f %.3f, p %.3f, r %.3f, tp %.2e, fp %.2e, tn %.2e, fn %.2e' %
          (clas, fScore, precision, recall, tpCount,
           fpCount, true_negatives[clas], fnCount))
print('Overall accuracy: %.3f' % _ratio(accuracy_numerator, accuracy_denominator))

# Delete the temporary list and output files. A file that was never created
# (e.g. because no fold ran) is not an error; any other OSError is re-raised.
for tempFilename in ('__input_list.txt', '__label_list.txt',
                     '__test_image_segmentation.txt'):
    try:
        os.unlink(tempFilename)
    except OSError as err:
        if err.errno != errno.ENOENT:
            raise
