package com.enterprisemath.math.statistics.em;

import org.apache.commons.lang3.builder.ToStringBuilder;

import com.enterprisemath.math.algebra.Hypercube;
import com.enterprisemath.math.algebra.Vector;
import com.enterprisemath.math.probability.DiagonalNormalDistribution;
import com.enterprisemath.math.probability.DiagonalNormalDistributionMixture;
import com.enterprisemath.math.statistics.observation.ObservationIterator;
import com.enterprisemath.math.statistics.observation.ObservationProvider;
import com.enterprisemath.utils.ValidationUtils;

/**
 * This class is responsible for estimating the diagonal normal distribution mixture.
 * Its purpose is to let the user set high-level limits and let the algorithm do the rest.
 * The estimation starts with a single component.
 * In every iteration one component is selected and split, and then the classic EM algorithm is invoked.
 * @author radek.hecl
 */
public class DiagonalNormalDistributionMixtureEstimator {

     * Constant for truncting components.
    private static final double COMP_TRUNC = Math.log(Double.MIN_VALUE);

     * Builder object.
    public static class Builder {

         * Maximum allowed number of components.
        private Integer maxComponents;

         * The minimum weight which is allowed for a component during estimation.
         * Every component which has less than the minWeight will be removed.
         * This essentially means that group of observations statistically less important than minWeight
         * might (but not necessary are) be ignored.
         * Value must be in interval [0, 1).
        private Double minWeight;

         * Step listener. This is not mandatory field.
         * Listener allows to track each step of algorithm.
        private DiagonalNormalDistributionMixtureEstimatorStepListener stepListener;

         * Sets maximum allowed number of components.
         * @param maxComponents maximum of number of components in the result
         * @return this instance
        public Builder setMaxComponents(int maxComponents) {
            this.maxComponents = maxComponents;
            return this;

         * Sets minimal weight for the components.
         * Every component with weight less than minWeight will be removed from the result.
         * This essentially means that group of observations statistically less important than minWeight
         * might (but not necessary are) be ignored.
         * Value must be in interval [0, 1).
         * @param minWeight component minimal weight value
         * @return this instance
        public Builder setMinWeight(double minWeight) {
            this.minWeight = minWeight;
            return this;

         * Sets step listener. This is not mandatory field.
         * Listener allows to track each step of algorithm.
         * @param stepListener step listener
         * @return this instance
        public Builder setStepListener(DiagonalNormalDistributionMixtureEstimatorStepListener stepListener) {
            this.stepListener = stepListener;
            return this;

         * Builds the result object.
         * @return created object
        public DiagonalNormalDistributionMixtureEstimator build() {
            return new DiagonalNormalDistributionMixtureEstimator(this);

     * Maximum allowed number of components.
    private Integer maxComponents;

     * The minimum weight which is allowed for a component during estimation.
     * Every component which has less than the minWeight will be removed.
     * This essentially means that group of observations statistically less important than minWeight
     * might (but not necessary are) be ignored.
     * Value must be in interval [0, 1).
    private Double minWeight;

     * Step listener. This is not mandatory field.
     * Listener allows to track each step of algorithm.
    private DiagonalNormalDistributionMixtureEstimatorStepListener stepListener;

     * Creates new instance.
     * @param builder builder object
    public DiagonalNormalDistributionMixtureEstimator(Builder builder) {
        minWeight = builder.minWeight;
        maxComponents = builder.maxComponents;
        stepListener = builder.stepListener;

     * Guards this object to be consistent. Throws exception if this is not the case.
    private void guardInvariants() {
        ValidationUtils.guardPositiveInt(maxComponents, "maxComponents must be positive");
        ValidationUtils.guardNotNegativeDouble(minWeight, "minWeight cannot be negative");
        ValidationUtils.guardGreaterDouble(1, minWeight, "minWeight must be less than 1");

     * Finds the optimal normal distribution mixture for a given set of observations.
     * @param observations observation for which the mixture will be estimated
     * @return estimated mixture
    public DiagonalNormalDistributionMixture estimate(ObservationProvider<Vector> observations) {
        Hypercube minMax = extractHypercube(observations);
        for (int i = 0; i < minMax.getDimension(); ++i) {
            ValidationUtils.guardGreaterOrEqualDouble(minMax.getMin().getComponent(i), -1000000,
                    "observation is out of range for calcualtion");
            ValidationUtils.guardGreaterOrEqualDouble(1000000, minMax.getMax().getComponent(i),
                    "observation is out of range for calcualtion");

        DiagonalNormalDistributionMixture res = initializeOneCompoenent(observations, minMax.getDimension());
        double resL = Double.NEGATIVE_INFINITY;
        double newL = countLnL(observations, res);

        int iteration = 0;
        while (newL - resL > 0.01 && res.getNumComponents() < maxComponents && iteration < 100) {
            resL = newL;
            // find the maximum weight
            double splitValue = 0;
            int splitCompIdx = 0;
            int splitDimIdx = 0;
            for (int i = 0; i < res.getNumComponents(); ++i) {
                DiagonalNormalDistribution comp = res.getComponents().get(i);
                for (int j = 0; j < comp.getDimension(); ++j) {
                    if (res.getWeights().get(i) * res.getComponents().get(i).getSigma().getComponent(j) > splitValue) {
                        splitValue = res.getWeights().get(i) * res.getComponents().get(i).getSigma().getComponent(j);
                        splitCompIdx = i;
                        splitDimIdx = j;
            // split the component with highest weight
            DiagonalNormalDistributionMixture.Builder builder = new DiagonalNormalDistributionMixture.Builder();
            for (int i = 0; i < res.getNumComponents(); ++i) {
                if (i == splitCompIdx) {
                    DiagonalNormalDistribution comp = res.getComponents().get(i);
                    double[] mi1 = new double[res.getDimension()];
                    double[] mi2 = new double[res.getDimension()];
                    for (int j = 0; j < comp.getDimension(); ++j) {
                        if (j == splitDimIdx) {
                            mi1[j] = comp.getMi().getComponent(j) + comp.getSigma().getComponent(j) / 2;
                            mi2[j] = comp.getMi().getComponent(j) - comp.getSigma().getComponent(j) / 2;
                        else {
                            mi1[j] = comp.getMi().getComponent(j);
                            mi2[j] = comp.getMi().getComponent(j);
                    builder.addComponent(res.getWeights().get(i) / 2, Vector.create(mi1), comp.getSigma());
                    builder.addComponent(res.getWeights().get(i) / 2, Vector.create(mi2), comp.getSigma());
                else {
                    if (res.getWeights().get(i) >= minWeight) {
                        builder.addComponent(res.getWeights().get(i), res.getComponents().get(i));
            DiagonalNormalDistributionMixture newMixture =;
            // iterations for the new mixture

            double help = Double.NEGATIVE_INFINITY;
            int emiteration = 0;
            while (emiteration < 5 || newL - help > 0.01) {
                help = newL;
                newMixture = nextIteration(observations, newMixture);
                newL = countLnL(observations, newMixture);
                if (getMinWeigth(newMixture) < minWeight) {
                    newMixture = newMixture.createSignificantComponentMixture(minWeight);
                    newL = countLnL(observations, newMixture);
                // assign new L value if possible
                if (newL > resL) {
                    res = newMixture;
        return res;

     * Extracts hypercube from the specified observations.
     * @param observations observations
     * @return extracted interval
    // NOTE(review): this method is incomplete in this copy of the file. The javadoc delimiters,
    // the body of the while loop, the closing braces and the final return statement are missing.
    // estimate(...) reads getDimension(), getMin() and getMax() from the returned Hypercube, so the
    // result is expected to hold per-dimension bounds of all observations. The exact Hypercube.Builder
    // calls are not visible here - TODO restore them from version control instead of guessing.
    private Hypercube extractHypercube(ObservationProvider<Vector> observations) {
        // single pass over all observations; presumably each one is fed into the builder - TODO confirm
        ObservationIterator<Vector> iterator = observations.getIterator();
        Hypercube.Builder res = new Hypercube.Builder();
        while (iterator.isNextAvailable()) {

     * Invokes step listener with a given mixture.
     * @param mixture mixture
    // NOTE(review): this method is incomplete in this copy of the file. The javadoc delimiters, the
    // listener callback inside the null check and the closing braces are missing. The callback name of
    // DiagonalNormalDistributionMixtureEstimatorStepListener is not visible here - TODO restore it.
    private void invokeStepListener(DiagonalNormalDistributionMixture mixture) {
        // the listener is optional (may be null), so it is only invoked when present
        if (stepListener != null) {

     * Makes the one component initialization for the EM algorithm.
     * @param observations observation for which the initialization should be calculated
     * @param dimension dimension of the observations
     * @return mixture with initial parameters
    private DiagonalNormalDistributionMixture initializeOneCompoenent(ObservationProvider<Vector> observations, int dimension) {
        // calculates first and second central momentum
        double[] m1 = new double[dimension];
        double[] m2 = new double[dimension];
        ObservationIterator<Vector> iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            Vector x = iterator.getNext();
            for (int i = 0; i < dimension; ++i) {
                m1[i] += x.getComponent(i);
                m2[i] += x.getComponent(i) * x.getComponent(i);

        for (int i = 0; i < dimension; ++i) {
            m1[i] /= iterator.getNumIterated();
            m2[i] /= iterator.getNumIterated();
            m2[i] -= m1[i] * m1[i];
            m2[i] = Math.sqrt(m2[i]);

        // determine the components
        return new DiagonalNormalDistributionMixture.Builder().
                addComponent(1, DiagonalNormalDistribution.create(Vector.create(m1), Vector.create(m2))).

     * Makes one iteration of the EM algorithm. Returns the mixture after the iteration.
     * @param observations observation for which the iteration should be calculated
     * @param start starting position
     * @return mixture after the iteration
    private DiagonalNormalDistributionMixture nextIteration(ObservationProvider<Vector> observations, DiagonalNormalDistributionMixture start) {
        int numComponents = start.getNumComponents();
        int dim = start.getDimension();
        double[] newW = new double[numComponents];
        double[][] newMi = new double[numComponents][dim];
        double[][] newSigma = new double[numComponents][dim];

        double[] c = new double[numComponents];
        double c0 = -Double.MAX_VALUE;
        double h = 0;
        double qmx = 0;
        //double L = 0;
        double[] mi = null;
        double[] sigma = null;

        // adding values
        ObservationIterator<Vector> iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            Vector x = iterator.getNext();
            c0 = -Double.MAX_VALUE;

            for (int j = 0; j < numComponents; ++j) {
                c[j] = Math.log(start.getWeights().get(j)) + start.getComponents().get(j).getLnValue(x);
                if (c[j] > c0)
                    c0 = c[j];

            h = 0;
            for (int j = 0; j < numComponents; ++j) {
                c[j] -= c0;
                if (c[j] > COMP_TRUNC) {
                    c[j] = Math.exp(c[j]);
                    h += c[j];
                else {
                    c[j] = 0;
            //L += Math.log(h) + c0;

            for (int j = 0; j < numComponents; ++j) {
                if (c[j] == 0)
                mi = newMi[j];
                sigma = newSigma[j];
                qmx = c[j] / h;
                newW[j] += qmx;

                for (int k = 0; k < dim; ++k) {
                    mi[k] += x.getComponent(k) * qmx;
                    sigma[k] += x.getComponent(k) * x.getComponent(k) * qmx;

        // finishing
        for (int i = 0; i < numComponents; ++i) {
            mi = newMi[i];
            sigma = newSigma[i];
            for (int j = 0; j < dim; ++j) {
                mi[j] /= newW[i];
                sigma[j] = Math.max(0.01, Math.sqrt(sigma[j] / newW[i] - mi[j] * mi[j]));
                if (Double.valueOf(sigma[j]).equals(Double.NaN)) {
                    sigma[j] = 0.01;
            newW[i] /= iterator.getNumIterated();

        //System.out.println("L = " + (L / obs.size()));

        // creating new instance
        DiagonalNormalDistributionMixture.Builder builder = new DiagonalNormalDistributionMixture.Builder();
        for (int i = 0; i < numComponents; ++i) {
            builder.addComponent(newW[i], new DiagonalNormalDistribution.Builder().


     * Calculates the ln(L) value.
     * Where L = prod_x( sum_i(w(i|x)P(i|x)) ) and ln(L) = sum_x( ln(sum_i(w(i|x)P(i|x))) ).
     * @param observations observations
     * @param mixture mixture
     * @return ln(L) value
    private double countLnL(ObservationProvider<Vector> observations, DiagonalNormalDistributionMixture mixture) {
        double[] c = new double[mixture.getNumComponents()];
        double c0 = -Double.MAX_VALUE;
        double h = 0;
        double L = 0;

        // adding values
        ObservationIterator<Vector> iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            Vector x = iterator.getNext();
            c0 = -Double.MAX_VALUE;

            for (int j = 0; j < mixture.getNumComponents(); ++j) {
                c[j] = Math.log(mixture.getWeights().get(j)) + mixture.getComponents().get(j).getLnValue(x);
                if (c[j] > c0)
                    c0 = c[j];

            h = 0;
            for (int j = 0; j < mixture.getNumComponents(); ++j) {
                c[j] -= c0;
                if (c[j] > COMP_TRUNC) {
                    c[j] = Math.exp(c[j]);
                    h += c[j];
                else {
                    c[j] = 0;
            L += Math.log(h) + c0;

        // not divide to fall back to the case what was before
        //L = L / iterator.getNumIterated();
        //System.out.println("L = " + L);
        return L;


     * Returns minimal weight.
     * @param mixture mixture
     * @return minimal weight
    private double getMinWeigth(DiagonalNormalDistributionMixture mixture) {
        double res = 1;
        for (double w : mixture.getWeights()) {
            if (w < res) {
                res = w;
        return res;

    public String toString() {
        return ToStringBuilder.reflectionToString(this);


