# This Python file uses the character encoding: utf-8
"""
Module for computing the maximum likelihood parameter estimate of a
high-dimensional Pólya distribution. The code will work for
low-dimensional distributions, but is perhaps not as efficient as it
could be. It was made to be robust for very large numbers of
dimensions and tested at approximately 4000 dimensions.

Copyright (c) 2011 Idiap Research Institute, http://www.idiap.ch/
Written by Carl Scheffler <carl.scheffler@gmail.com>

This file is part of FaceColorModel.

FaceColorModel is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License version 3 as
published by the Free Software Foundation.

FaceColorModel is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
General Public License for more details.

You should have received a copy of the GNU General Public License
along with FaceColorModel. If not, see <http://www.gnu.org/licenses/>.
"""
from __future__ import division
import numpy as np
import scipy as S

def _polya_log_likelihood(alpha, counts):
    """
    Return the log likelihood of the Pólya parameter vector alpha given
    the N x D array counts (one sample per row).

    The multinomial coefficient of each sample is omitted, since it does
    not depend on alpha; the result is therefore only defined up to an
    additive constant, which is all a maximizer over alpha needs.
    """
    from scipy.special import gammaln
    alpha = np.asarray(alpha, dtype=float)
    counts = np.asarray(counts, dtype=float)
    n_samples = counts.shape[0]

    # Summing in ascending order reduces floating-point round-off when
    # alpha has very many entries of widely differing magnitude.
    s = sum(np.sort(alpha))

    return (n_samples * gammaln(s)
            - np.sum(gammaln(s + counts.sum(axis=1)))
            + np.sum(gammaln(counts + alpha))
            - n_samples * np.sum(gammaln(alpha)))

# Cache for _polya_precision_opt. It is only valid for the exact `m` and
# `counts` objects stored in it (compared by identity, not value).
# 'log_s' holds the evaluated log-precisions in ascending order and
# 'result' the objective values at the same positions.
_polya_precision_opt_memoized = dict.fromkeys(['log_s', 'm', 'counts', 'result'])
def _polya_precision_opt(log_s, m, counts):
    """
    Objective for the precision search: the negative Pólya log
    likelihood of alpha = exp(log_s) * m given counts.

    log_s is the log of sum(alpha). optimize.fmin passes it as a
    length-1 array, but a plain scalar is accepted too. Results are
    memoized per log_s for as long as the same m and counts objects are
    passed in; passing a different object resets the cache.
    """
    memo = _polya_precision_opt_memoized
    # One scalar key for lookup, comparison and insertion, so the cache
    # works no matter what shape log_s arrives in.
    key = float(np.ravel(log_s)[0])

    if (m is memo['m']) and (counts is memo['counts']):
        # 'log_s' stays sorted because new entries are inserted at the
        # searchsorted position, so a binary search finds cached values.
        index = int(np.searchsorted(memo['log_s'], key))
        if index < len(memo['log_s']) and memo['log_s'][index] == key:
            return memo['result'][index]
    else:
        # New (m, counts) pair: previous cached values no longer apply.
        index = 0
        memo['log_s'] = np.array([], dtype=float)
        memo['result'] = np.array([], dtype=float)
        memo['m'] = m
        memo['counts'] = counts

    result = -_polya_log_likelihood(np.exp(key) * m, counts)

    memo['log_s'] = np.insert(memo['log_s'], index, key)
    memo['result'] = np.insert(memo['result'], index, result)

    return result

def compute_polya_parameter(counts, verbose=False, max_iter=10, tol=1e-3):
    """
    Compute the maximum likelihood parameter estimate of a
    multivariate Pólya distribution.

    counts is an N x D array of independent samples from the Pólya
    distribution, where N is the number of samples and D is the number
    of dimensions. Returns a length-D vector, which is the estimated
    parameter (alpha) of the distribution.

    The estimate alternates between two steps:
    - optimizing the precision sum(alpha) with the mean
      alpha/sum(alpha) held fixed;
    - a fixed-point update of the mean with the precision held fixed.

    verbose  -- print the precision estimate after every iteration.
    max_iter -- maximum number of alternating iterations.
    tol      -- stop once the relative change in sum(alpha) between
                consecutive iterations is below this value.
    """
    from scipy import optimize
    from scipy.special import digamma

    counts = np.asarray(counts, dtype=float)
    # Regularize counts: a pseudo-sample with a tiny count in every
    # dimension keeps every column total strictly positive.
    counts = np.vstack((counts, 1e-4 * np.ones(counts.shape[1])))

    # Compute initial (moment-matching) estimate of alpha. Rows with no
    # counts contribute nothing to the likelihood but would divide by
    # zero here, so they are left out of this step only. The
    # regularization row guarantees at least one row remains.
    totals = counts.sum(axis=1)
    keep = totals > 0
    probs = counts[keep] / totals[keep][:, np.newaxis]
    m = np.mean(probs, axis=0)
    v = np.var(probs, axis=0)
    # Zero-variance columns legitimately give an infinite estimate; the
    # median below is robust to a few of those.
    with np.errstate(divide='ignore', invalid='ignore'):
        sum_alphas = m*(1-m)/v - 1
    sum_alpha = np.median(sum_alphas)
    if not (np.isfinite(sum_alpha) and sum_alpha > 0):
        # Degenerate data gives no usable moment estimate; the log-space
        # search below needs a finite positive starting point.
        sum_alpha = 1.0
    mean_alpha = m

    if verbose:
        print('0: sum(alpha) = %s' % sum_alpha)

    for i in range(max_iter):
        # Re-estimate sum(alpha) given the *current* mean estimate.
        old_sum_alpha = sum_alpha
        sum_alpha = np.exp(optimize.fmin(
            _polya_precision_opt, np.log(sum_alpha),
            args=(mean_alpha, counts), xtol=1e-2, disp=0)[0])
        if verbose:
            print('%i: sum(alpha) = %s' % (i + 1, sum_alpha))

        # Fixed-point re-estimate of alpha/sum(alpha) given sum(alpha).
        a = sum_alpha*mean_alpha
        a = np.sum(a*(digamma(counts + a) - digamma(a)), axis=0)
        mean_alpha = a/np.sum(a)

        # Check for convergence
        if abs(sum_alpha - old_sum_alpha)/old_sum_alpha < tol:
            if verbose:
                print('Converged.')
            break

    return sum_alpha*mean_alpha
