"""
Use labelled images from the compaq_labels directory to construct a
prior over the class (skin / hair / clothes / background) of each
pixel in an area around the face. The facial area is defined as the
bounding square that is twice the size of the bounding square returned
by the Viola-Jones face detector.
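Requires: Compaq skin database, labels in compaq_labels/
Creates:  storage/pim_prior.data
Usage:    --input_list FILE --label_list FILE [--quiet]
          Each FILE lists one image path per line. The two lists must
          have the same number of lines; line i of the input list is
          paired with line i of the label list.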

Copyright (c) 2011 Idiap Research Institute, http://www.idiap.ch/
Written by Carl Scheffler <carl.scheffler@gmail.com>

This file is part of FaceColorModel.

FaceColorModel is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License version 3 as
published by the Free Software Foundation.

FaceColorModel is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
General Public License for more details.

You should have received a copy of the GNU General Public License
along with FaceColorModel. If not, see <http://www.gnu.org/licenses/>.
"""
from __future__ import division
from scipy import *
import os, sys
import opencv
from opencv import highgui
from colorlib import load_image
from viola_jones_opencv import viola_jones_opencv

# The size in pixels of a face image centered on the Viola-Jones
# bounding box and twice its size.
from FaceColorModelWrapper import FCM_SCALED_SIZE as scaledSize
                  
# Parameters
# Pixel classes, in the order used to index them.
classes = ['skin', 'hair', 'clothes', 'background']
# Maps each class name to its position in `classes`.
classIndex = dict((name, position) for position, name in enumerate(classes))
# Color that marks each class in a label image. Label pixels are compared
# against these triples after being put in RGB channel order.
labelColors = dict(
    skin=(255, 255, 0),
    hair=(255, 0, 0),
    background=(0, 0, 255),
    clothes=(0, 255, 0)
)

# Read command line arguments
try:                                         # Input file list
    index = sys.argv.index("--input_list")
    inputListFilename = sys.argv[index+1]
except ValueError:
    print "ERROR: input list not specified"
    sys.exit()
try:                                         # Label file list
    index = sys.argv.index("--label_list")
    labelListFilename = sys.argv[index+1]
except ValueError:
    print "ERROR: label list not specified"
    sys.exit()
verbose = ("--quiet" not in sys.argv[1:])    # Verbosity

# Read training lists
# Each file holds one image path per line. splitlines() is used instead of
# split('\n') so that files with '\r\n' line endings do not leave a stray
# '\r' on every path, and so that an empty file yields an empty list rather
# than a list holding one empty path.
with open(inputListFilename, 'rt') as fp:
    inputList = fp.read().strip().splitlines()
with open(labelListFilename, 'rt') as fp:
    labelList = fp.read().strip().splitlines()
# Entries are paired by position, so the lists must be the same length.
if len(inputList) != len(labelList):
    print "Error: input and label lists are not of equal length"
    sys.exit(1)  # non-zero status so callers can detect the failure

# Buffers reused for every image when rescaling the class masks.
# Destination image for a warped mask: scaledSize x scaledSize, 8 bits
# per pixel, one channel.
scaledDims = opencv.cvSize(scaledSize, scaledSize)
resizedMask = opencv.cvCreateImage(scaledDims, opencv.IPL_DEPTH_8U, 1)
# 2x3 matrix of 32-bit floats holding the affine map; its entries are
# filled in per image.
affineMap = opencv.cvCreateMat(2, 3, opencv.CV_32F)

# Build prior
# Every accumulator starts from a Dirichlet hyper-prior pseudo-count, so
# that no class ends up with zero mass at any pixel.
classAlpha = 1/len(classes)
priorShape = (scaledSize, scaledSize)
# Per-class running sum of the rescaled masks.
maskSums = dict((clas, classAlpha * ones(priorShape, dtype=float))
                for clas in classes)
# Running sum over all classes; used later as the per-pixel normalizer.
allMaskSums = (classAlpha * len(classes)) * ones(priorShape, dtype=float)
# NOTE(review): this counter is not updated anywhere in the visible code.
imageCount = 0
for imageIndex in range(len(inputList)):
    # Read mask image
    # Read mask image and extract hair mask
    labelImage = load_image(labelList[imageIndex])
    if labelImage is None:
        if verbose:
            print "Warning: Could not read label image:", labelList[imageIndex]
        continue
        
    # Read color input image
    inputImage = load_image(inputList[imageIndex], highgui.CV_LOAD_IMAGE_GRAYSCALE)
    if inputImage is None:
        if verbose:
            print "Warning: Could not read input image:", inputList[imageIndex]
        continue

    assert (inputImage.width == labelImage.width) and\
           (inputImage.height == labelImage.height),\
           "Size mismatch in %s."%inputList[imageIndex]

    face = viola_jones_opencv(inputImage)
    if face is not None:
        scale = scaledSize/(2*(face[2]-face[0]))
        affineMap[0,0] = scale
        affineMap[0,1] = 0
        affineMap[0,2] = scaledSize/2 - scale*(face[1]+face[3])/2
        affineMap[1,0] = 0
        affineMap[1,1] = scale
        affineMap[1,2] = scaledSize/2 - scale*(face[0]+face[2])/2

        # Resize each class and accumulate
        labelArray = array(labelImage, dtype=uint8)[:,:,::-1] # Reverse color channels: BGR -> RGB
        for clas in classes:
            classMask = opencv.adaptors.NumPy2Ipl(
                array((labelArray == array(labelColors[clas], dtype=uint8)).all(axis=2) * 255, dtype=uint8))
            opencv.cvWarpAffine(classMask, resizedMask, affineMap,
                                opencv.CV_INTER_AREA+opencv.CV_WARP_FILL_OUTLIERS, 0)
            resizedArray = array(resizedMask) / 255
            maskSums[clas] += resizedArray
            allMaskSums += resizedArray
            
# Convert the accumulated sums into a per-pixel distribution over classes,
# average each class map with its left-right mirror image, and store one
# flattened map per row of classPrior (row order given by classIndex).
classPrior = empty((len(classes), scaledSize*scaledSize), dtype=float)
for clas in classes:
    normalized = maskSums[clas] / allMaskSums
    mirrored = normalized[:, ::-1]
    symmetric = (normalized + mirrored) / 2
    classPrior[classIndex[clas]] = symmetric.ravel()

# Write the raw contents of the array buffer to the prior file
with open('storage/pim_prior.data', 'wb') as priorFile:
    priorFile.write(classPrior.data)
