Package stallone.hmm.pmm

Source Code of stallone.hmm.pmm.NinjaEstimator

/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package stallone.hmm.pmm;

import static stallone.api.API.*;

import java.util.ArrayList;
import java.util.List;
import stallone.api.datasequence.IDataSequence;
import stallone.api.doubles.IDoubleArray;
import stallone.api.ints.IIntArray;
import stallone.api.hmm.IHMMParameters;
import stallone.api.hmm.ParameterEstimationException;
import stallone.api.hmm.IExpectationMaximization;
import stallone.api.hmm.IHMM;

/**
*
* @author noe
*/
public class NinjaEstimator implements IExpectationMaximization
{
    // discrete trajectory

    private List<IIntArray> dtraj;
    // estimation parameters
    private int tau = -1;
    // time shift between time grids
    private int timeshift = -1;
    // number of hidden states
    private int nhidden = -1;
    // max number of hmm iterations
    private double maxHMMLInc = -0.01;
    private int nIterHMMMax = 1000;
    // user initialization
    private IDoubleArray initChi = null;
    private IDoubleArray initTC = null;
    // current MSM estimation data
    private IDoubleArray msmC, msmT, msmpi, msmPi, msmCorr;
    private IDoubleArray msmChi, msmTC, msmTimescales;
    // current HMM estimates
    private IHMM hmmEst;
    private IDoubleArray hmmChi, hmmTC, hmmpiC, hmmTimescales;
    private double[] hmmLikelihoodHistory;
    // error estimates

    public NinjaEstimator(List<IIntArray> _dtraj)
    {
        this.dtraj = _dtraj;
    }

    public NinjaEstimator(int _tau, int _timeshift, int _nhidden)
    {
        this.tau = _tau;
        this.timeshift = _timeshift;
        this.nhidden = _nhidden;
    }

    public void setData(List<IIntArray> _dtraj)
    {
        this.dtraj = _dtraj;
    }

    public void setTau(int _tau)
    {
        this.tau = _tau;
    }

    public void setTimeshift(int _timeshift)
    {
        this.timeshift = _timeshift;
    }

    public void setNHiddenStates(int _nhidden)
    {
        this.nhidden = _nhidden;
    }

    public void setDiscreteTrajectory(List<IIntArray> _dtraj)
    {
        this.dtraj = _dtraj;
    }

    public void setNIterHMMMax(int _nIterHMMMax)
    {
        this.nIterHMMMax = _nIterHMMMax;
    }

    public void setHMMLikelihoodMaxIncrease(double _maxHMMLInc)
    {
        this.maxHMMLInc = _maxHMMLInc;
    }

    public void setInit(IDoubleArray TCInit, IDoubleArray ChiInit)
    {
        this.initTC = TCInit;
        this.initChi = ChiInit;
    }

    public IDoubleArray getMSMTransitionMatrix()
    {
        return msmT;
    }

    public IDoubleArray getMSMStationaryDistribution()
    {
        return msmpi;
    }

    public IDoubleArray getMSMTimescales()
    {
        return msmTimescales;
    }

    public IDoubleArray getPCCATransitionMatrix()
    {
        return msmTC;
    }
   
    public IDoubleArray getPCCAOutputProbabilities()
    {
        return msmChi;
    }
   
    public IDoubleArray getHMMTransitionMatrix()
    {
        return hmmTC;
    }

    public IDoubleArray getHMMStationaryDistribution()
    {
        return hmmpiC;
    }

    public IDoubleArray getHMMOutputProbabilities()
    {
        return hmmChi;
    }

    public IDoubleArray getHMMTimescales()
    {
        return hmmTimescales;
    }

    public double[] getHMMLikelihoodHistory()
    {
        return hmmLikelihoodHistory;
    }


    /*
     * private List<IDataSequence> compressTrajectory(IIntArray dtraj) { int
     * nout = ((dtraj.size()-1) / tau) + 1; int nstates = ints.max(dtraj) + 1;
     *
     * double[][] out = new double[nout][nstates];
     *
     * for (int i = 0; i < dtraj.size(); i++) { // closest out index: int iout =
     * i / tau;
     *
     * // are we within averageWindow/2? double timeOfGrid = iout * tau + tau /
     * 2.0; double dist = Math.abs(timeOfGrid - i); if (dist <= (averageWindow /
     * 2)) { out[iout][dtraj.get(i)] += 1.0; } } IDataList seq =
     * dataNew.createDatalist();
     *
     * for (int i = 0; i < out.length; i++) {
     * seq.add(doublesNew.arrayFrom(out[i])); }
     *
     * ArrayList<IDataSequence> res = new ArrayList(); res.add(seq);
     *
     * return res;
    }
     */
    /**
     * Creates a series of subsamples of the given trajectory: (0, tau, 2*tau,
     * ...) (dt, tau+dt, 2*tau+dt, ...) ...
     *
     * @param dtraj discrete trajectory
     * @param dt time shift between two subsequent sampling series
     * @return
     */
    private List<IIntArray> subsamples(IIntArray dtraj, int dt)
    {
        int N = dtraj.size();
        List<IIntArray> res = new ArrayList();
        for (int s = 0; s < tau; s += dt)
        {
            IIntArray I = intsNew.arrayRange(s, N, tau);
            res.add(ints.subToNew(dtraj, I));
        }
        return res;
    }

    private List<IIntArray> subsamples(List<IIntArray> dtraj, int dt)
    {
        List<IIntArray> res = new ArrayList();
        for (int i = 0; i < dtraj.size(); i++)
        {
            res.addAll(subsamples(dtraj.get(i), dt));
        }
        return res;
    }

    private List<IDataSequence> toObservation(List<IIntArray> dTrajectories, int nstates)
    {
        List<IDataSequence> obs = new ArrayList();
        for (int i = 0; i < dTrajectories.size(); i++)
        {
            //DiscreteTrajectoryDataSequence seq = new DiscreteTrajectoryDataSequence(dTrajectories.get(i), nstates);
            DiscreteTrajectorySimpleDataSequence seq = new DiscreteTrajectorySimpleDataSequence(dTrajectories.get(i));
            obs.add(seq);
        }
        return obs;
    }

    private IHMMParameters hmm(List<IDataSequence> observations, IDoubleArray TCInit, IDoubleArray ChiInit)
            throws ParameterEstimationException
    {
        System.out.println("HMM initialized with timescales: "+msm.timescales(TCInit, tau));
       
        // initial parameters
        IHMMParameters par0 = hmmNew.parameters(nhidden, true, true);
        par0.setTransitionMatrix(TCInit.copy());
        for (int i = 0; i < nhidden; i++)
        {
            par0.setOutputParameters(i, ChiInit.viewColumn(i));
        }

        // Estimate
        int nObservableStates = ChiInit.rows();
        double[] uniformPrior = new double[nObservableStates];
        java.util.Arrays.fill(uniformPrior, 1.0 / (double) uniformPrior.length);
        IExpectationMaximization EM = hmmNew.emDiscrete(observations, par0, uniformPrior);
        EM.setMaximumNumberOfStep(nIterHMMMax);
        EM.setLikelihoodDecreaseTolerance(maxHMMLInc);

        System.out.println(" running hmm on " + observations.size() + " x " + observations.get(0).size() + " observations with maxIter " + nIterHMMMax);
        EM.run();
        hmmLikelihoodHistory = EM.getLogLikelihoodHistory();
        System.out.println(" hmm iterations: " + hmmLikelihoodHistory.length);
        System.out.println(" likelihood history: " + doubleArrays.toString(hmmLikelihoodHistory, "\n"));

        // HMM
        this.hmmEst = EM.getHMM();
        IHMMParameters parEst = hmmEst.getParameters();

        return parEst;
    }

    private IDoubleArray getChi(IHMMParameters par)
    {
        int nObservableStates = par.getOutputParameters(0).size();
        IDoubleArray res = doublesNew.array(nObservableStates, nhidden);
        for (int i = 0; i < nhidden; i++)
        {
            IDoubleArray pout = par.getOutputParameters(i);
            for (int k = 0; k < nObservableStates; k++)
            {
                res.set(k, i, pout.get(k));
            }
        }
        return res;
    }

    public void estimateMSM()
    {
        // and estimate reversible MSM on tau
        msmC = msm.estimateC(dtraj, tau);

        //System.out.println("C(tau) = " + msmC);

        msmT = msm.estimateTrev(msmC);
        msmpi = msm.stationaryDistribution(msmT);
        msmPi = doublesNew.diag(msmpi);
        msmCorr = alg.product(msmPi, msmT);
        msmTimescales = msm.timescales(msmT, tau);
        System.out.println("MSM timescales: \n" + doubles.subToNew(msmTimescales, 0, Math.min(5, msmTimescales.size())));

        // get chi and coarse-grained T
        IDoubleArray[] cg = msm.coarseGrain(msmT, nhidden);
        msmTC = cg[0];
        msmChi = cg[1];

        System.out.println("PCCA chi: ");
        System.out.println(msmChi);
        System.out.println("PCCA TC: ");
        System.out.println(msmTC);
        System.out.println("PCCA TC timescales: ");
        System.out.println(msm.timescales(msmTC, tau));
    }

    public void estimateHMM()
            throws ParameterEstimationException
    {
        // coarse grain trajectory to tau and run HMM
        //List<IDataSequence> dtrajCompressed = compressTrajectory(dtraj);
        List<IIntArray> dTrajectories = subsamples(dtraj, timeshift);
        int nstates = intseq.max(dTrajectories) + 1;
        List<IDataSequence> dtrajCompressed = toObservation(dTrajectories, nstates);

        IHMMParameters hmmParameters = hmm(dtrajCompressed, initTC, initChi);
        hmmTC = hmmParameters.getTransitionMatrix();
        hmmpiC = msm.stationaryDistribution(hmmTC);
        hmmChi = getChi(hmmParameters);

        // estimate and output timescales
        hmmTimescales = msm.timescales(hmmTC, tau);
    }

    public void estimate()
            throws ParameterEstimationException
    {
        // test input
        if (nhidden < 0)
        {
            throw new IllegalArgumentException("ABORTING NINJA: number of hidden states not yet set.");
        }
        if (tau < 0)
        {
            throw new IllegalArgumentException("ABORTING NINJA: tau not yet set.");
        }
        if (timeshift < 0)
        {
            throw new IllegalArgumentException("ABORTING NINJA: time shift not yet set.");
        }

        System.out.println("====================================================");
        System.out.println(" Running Ninja estimation with tau = "+tau);
        System.out.println("====================================================");
        System.out.println();
        System.out.println("----------------------------------------------------");
        System.out.println(" MSM reference estimation");
        estimateMSM();

        if (initChi == null && initTC == null)
        {
            initChi = msmChi;
            initTC = msmTC;
            System.out.println("HMM initialization: Using PCCA results");
        }
        else
        {
            System.out.println("HMM initialization: Using user-defined matrices:");
            System.out.println(" init chi: ");
            System.out.println(initChi);
            System.out.println(" init TC: ");
            System.out.println(initTC);
        }
        System.out.println("----------------------------------------------------");
        System.out.println();

        System.out.println("----------------------------------------------------");
        System.out.println(" HMM estimation");

        estimateHMM();

        // print chi:
        System.out.println("chi: \n" + hmmChi);

        // print TC:
        System.out.println("TC: \n" + hmmTC);

        System.out.println("HMM timescales: \n" + hmmTimescales);

        System.out.println("----------------------------------------------------");
        System.out.println();
    }

    ////////////////////////////////////////////////////////////////////////////
    // EM methods
    ////////////////////////////////////////////////////////////////////////////
   
    @Override
    public void setInitialParameters(IHMMParameters par)
    {
        throw new UnsupportedOperationException("Not supported. Initialization is done by MSM");
    }

    @Override
    public void setInitialPaths(List<IIntArray> paths)
    {
        throw new UnsupportedOperationException("Not supported. Initialization is done by MSM");
    }

    @Override
    public void setMaximumNumberOfStep(int nsteps)
    {
        this.nIterHMMMax = nsteps;
    }

    @Override
    public void setLikelihoodDecreaseTolerance(double _dectol)
    {
        this.maxHMMLInc = _dectol;
    }

    @Override
    public double[] getLogLikelihoodHistory()
    {
        return this.hmmLikelihoodHistory;
    }

    @Override
    public void run() throws ParameterEstimationException
    {
        this.estimate();
    }

    @Override
    public IHMM getHMM()
    {
        return this.hmmEst;
    }
}
TOP

Related Classes of stallone.hmm.pmm.NinjaEstimator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.