Package stallone.mc.estimator

Source Code of stallone.mc.estimator.TransitionMatrixEstimatorRevFixPi

/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package stallone.mc.estimator;

import stallone.api.doubles.IDoubleList;
import stallone.api.doubles.Doubles;
import stallone.api.doubles.IDoubleIterator;
import stallone.api.doubles.IDoubleArray;
import stallone.api.algebra.*;
import stallone.api.mc.*;

/**
*
* @author noe
*/
public final class TransitionMatrixEstimatorRevFixPi implements ITransitionMatrixEstimator
{
    // convergence

    //private int nIterMax = 10000;
    //private int nIterPer1 = 100;
    private int nIterMax = 1000000;
    private int nIterPer1 = 1000;
    private IDoubleList logliks = Doubles.create.list(nIterMax);

    private IDoubleArray C;

    private IDoubleArray pi = null;
    private IDoubleArray X;
    private IDoubleIterator itX;

    private boolean verbose = false;

    public TransitionMatrixEstimatorRevFixPi(IDoubleArray _C, IDoubleArray _pi)
    {
        this.pi = _pi;
        this.setCounts(_C);
    }

    public TransitionMatrixEstimatorRevFixPi(IDoubleArray _pi)
    {
        this.pi = _pi;
    }

    private void initX()
    {
        // initial T(tau)
        IDoubleArray T = C.copy();
        for (int i = 0; i < T.rows(); i++)
        {
            IDoubleArray r = T.viewRow(i);
            Algebra.util.scale(Doubles.util.sum(r), r);
            /*double s = DoubleArrays.sum(T[i]);
            for (int j = 0; j < T.length; j++)
            {
                if (T[i][j] != 0)
                {
                    T[i][j] /= s;
                }
                else if (i == j)
                {
                    T[i][j] = 1;
                }
            }*/
        }

        // initial off-diagonals
        this.X = C.create(C.rows(),C.columns());
        for (IDoubleIterator it = T.nonzeroIterator(); it.hasNext(); it.advance())
        {
            int i = it.row();
            int j = it.column();

                if (i != j)
                {
                    X.set(i,j, 0.5 * (pi.get(i) * T.get(i,j) + pi.get(j) * T.get(j,i)));
                }

                if (Double.isNaN(X.get(i,j)))
                {
                    System.out.println("NaN: " + i + " " + j);
                }
        }

        // scale all off-diagonals if necessary to realize stationary dist.
        double o = 0;
        for (int i = 0; i < X.rows(); i++)
        {
            o = Math.max(o, Doubles.util.sum(X.viewRow(i)) / pi.get(i));
        }

        if (o > 0.9)
        {
            Algebra.util.scale((0.9 / o), X);
        }

        // enforce stationary distribution
        for (int i = 0; i < X.rows(); i++)
        {
            X.set(i,i, pi.get(i) - Doubles.util.sum(X.viewRow(i)));
        }

        this.itX = X.nonzeroIterator();

        // TEST
        /*System.out.println("Testing X: ");
        IDoubleArray T2 = X.copy();
        alg.normalizeRows(T2, 1);
        IDoubleArray pitest = msm.stationaryDistribution(T2);
        for (int i=0; i<pitest.size(); i++)
            System.out.println(pi.get(i)+"\t"+pitest.get(i));*/
    }

    private double logL()
    {
        // compute likelihood using the matrix iterator
        double ll = 0;

        itX.reset();
        int i,j;

        while(itX.hasNext())
        {
            i = itX.row();
            j = itX.column();

            if (X.get(i,j) > 0)
            {
                ll += C.get(i,j) * Math.log(X.get(i,j) / pi.get(i));
            }

            itX.advance();
        }

        return (ll);
    }

    private boolean isConverged()
    {
        // check for number of iterations
        if (this.logliks.size() >= this.nIterMax)
        {
            return (true);
        }

        // check for good convergence
        if (this.logliks.size() <= this.nIterPer1)
        {
            return (false);
        }

        int i2 = this.logliks.size() - 1;
        int i1 = i2 - this.nIterPer1;
        double dL = this.logliks.get(i2) - this.logliks.get(i1);
        return (dL <= 1.0);
    }


    private double dLL(int i, int j, double d)
    {
        double dll = C.get(i,i)*Math.log(X.get(i,i) - d)
                    +C.get(j,j)*Math.log(X.get(j,j) - d)
                    +(C.get(i,j)+C.get(j,i))*Math.log(X.get(i,j) + d);
        return(dll);
    }

    private double opt(int i, int j, double dmin, double dmax)
    {
        double x_ii = X.get(i,i);
        double x_jj = X.get(j,j);
        double x_ij = X.get(i,j);
        double c_ii = C.get(i,i);
        double c_jj = C.get(j,j);
        double c_ij = C.get(i,j);
        double c_ji = C.get(j,i);

        double E = c_ij*x_ii + c_ji*x_ii + c_jj*x_ii - c_ii*x_ij - c_jj*x_ij + c_ii*x_jj + c_ij*x_jj + c_ji*x_jj;
        double A = Math.pow(-c_ij*x_ii - c_ji*x_ii - c_jj*x_ii + c_ii*x_ij + c_jj*x_ij - c_ii*x_jj - c_ij*x_jj - c_ji*x_jj, 2);
        double B = 4*(c_ii + c_ij + c_ji + c_jj)*(-c_jj*x_ii*x_ij + c_ij*x_ii*x_jj + c_ji*x_ii*x_jj - c_ii*x_ij*x_jj);
        double D = 2*(c_ii + c_ij + c_ji + c_jj);

        double d1 = (E - Math.sqrt(A - B))/D;
        double d2 = (E + Math.sqrt(A - B))/D;

        double lbest = dLL(i,j,0);
        double dbest = 0;

        double l = dLL(i,j,dmin);
        if (l>lbest)
        {
            lbest = l;
            dbest = dmin;
        }

        l = dLL(i,j,dmax);
        if (l>lbest)
        {
            lbest = l;
            dbest = dmax;
        }

        // test d1
        if (d1 >= dmin && d1 <= dmax)
        {
            l = dLL(i,j,d1);
            if (l>lbest)
            {
                lbest = l;
                dbest = d1;
            }
        }

        // test d2
        if (d2 >= dmin && d2 <= dmax)
        {
            l = dLL(i,j,d2);
            if (l>lbest)
            {
                lbest = l;
                dbest = d2;
            }
        }

        return(dbest);
    }

    private void optimizeElement(int i, int j)
    {
        double dmin = -X.get(i,j);
        double dmax = Math.min(X.get(i,i), X.get(j,j));

        //System.out.println("X before: " + X[i][i] + "\t" + X[i][j] + "\t" + X[j][i] + "\t" + X[j][j]);

        double d = opt(i,j,dmin,dmax);
        // Newton step:
        /*double d = 0;
        for (int k = 0; k < 30; k++)
        {
            d = newtonStep(i, j, d);
        }

        if (d < dmin)
        {
            d = dmin;
        }
        if (d > dmax)
        {
            d = dmax;
        }*/

        //System.out.println("mod d " + d);
        X.set(i,i, X.get(i,i)-d);
        X.set(i,j, X.get(i,j)+d);
        X.set(j,i, X.get(j,i)+d);
        X.set(j,j, X.get(j,j)-d);
        //System.out.println("X after: " + X[i][i] + "\t" + X[i][j] + "\t" + X[j][i] + "\t" + X[j][j]);
    }

    private void step()
    {
        itX.reset();
        int i,j;

        while(itX.hasNext())
        {
            i = itX.row();
            j = itX.column();

            if (i<j)
            {
                optimizeElement(i,j);
            }
            itX.advance();
        }

        double ll = logL();

        if (verbose)
        {
            //System.out.println(DoubleArrays.toString(X,"\t","\t"));
            System.out.println((logliks.size() + 1) + "\t" + ll);
        }

        this.logliks.append(ll);
    }

    ////////////////////////////////////////////////////////////////
    //
    //  INTERFACE FUNCTIONS
    //
    ////////////////////////////////////////////////////////////////
    @Override
    public void setMaxIter(int nmax)
    {
        this.nIterMax = nmax;
    }

    /**
     * Sets the convergence criterion. Convergence accepted when the likelihood has not changed more than 1
     * for nIterPer1 Consecutive steps.
     * @param nIterPer1
     */
    @Override
    public void setConvergence(int niter)
    {
        this.nIterPer1 = niter;
    }

    @Override
    public void setCounts(IDoubleArray _C)
    {
        this.C = _C;

        initX();

        double ll = logL();

        this.logliks = Doubles.create.list(nIterMax);
        this.logliks.append(ll);
    }

    @Override
    public void estimate()
    {
        while (!isConverged())
        {
            step();
        }

    }

    @Override
    public IDoubleArray getTransitionMatrix()
    {
        // output likelihood list
        //for (int i=0; i<this.logliks.size(); i++)
        //    System.err.println("# "+i+"\t"+this.logliks.get(i));

        IDoubleArray T = X.create(X.rows(),X.columns());
        for (IDoubleIterator it = X.nonzeroIterator(); it.hasNext(); it.advance())
        {
            int i = it.row();
            int j = it.column();
            T.set(i, j, X.get(i,j) / pi.get(i));
            T.set(j, i, X.get(i,j) / pi.get(j));
        }

        return (T);
    }

    @Override
    public double[] getLikelihoodHistory()
    {
        return(this.logliks.getArray());
    }

    public int getIterations()
    {
        return(this.logliks.size());
    }

    public void setVerbose(boolean _verbose)
    {
        this.verbose = _verbose;
    }
}
TOP

Related Classes of stallone.mc.estimator.TransitionMatrixEstimatorRevFixPi

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., which is owned by ORACLE Inc. Contact coftware#gmail.com.