Package stallone.hmm

Source Code of stallone.hmm.HMMForwardModel

/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/

package stallone.hmm;

import java.util.List;
import stallone.api.datasequence.IDataSequence;
import stallone.api.doubles.IDoubleArray;
import stallone.api.function.IParametricFunction;
import stallone.api.hmm.IHMMParameters;
import stallone.api.mc.ITransitionMatrixEstimator;
import stallone.api.mc.MarkovModel;

/**
* This is a basic implementation of IHMMParameters that may be used. However,
* the user can also choose to implement his own IHMMParameters.
* This implementation contains transition matrix estimation and stationary
* distribution estimation.
* @author noe
*/
public class HMMForwardModel implements IHMMForwardModel
{
    private List<IDataSequence> obs;

    private int nstates;
    private boolean eventBased = false;
    private IHMMParameters par;
    private IParametricFunction[] fOut;

    private ITransitionMatrixEstimator Testimator;
    private MatrixPowerCache matrixPower;

    public HMMForwardModel(List<IDataSequence> _obs, boolean _eventBased, int _nstates, boolean reversible, IParametricFunction _fOut)
    {
        obs = _obs;
        this.nstates = _nstates;
        eventBased = _eventBased;
        fOut = new IParametricFunction[nstates];
        for (int i=0; i<fOut.length; i++)
        {
            fOut[i] = _fOut.copy();
        }

        // find out what is the maximum time step and initialized matrix power cache.
        if (eventBased)
        {
            int dtmax = 1;
            for (IDataSequence seq : obs)
                for (int i=0; i<seq.size()-1; i++)
                {
                    int dt = (int)Math.round(seq.getTime(i+1)-seq.getTime(i));
                    if (dt > dtmax)
                        dtmax = dt;
                }
            if (dtmax > 1000)
                dtmax = 1000;
            matrixPower = new MatrixPowerCache(dtmax);
        }

        // construct transition matrix estimator
        if (reversible)
        {
            Testimator = MarkovModel.create.createTransitionMatrixEstimatorRev();
        }
        else
        {
            Testimator = MarkovModel.create.createTransitionMatrixEstimatorNonrev();
        }
    }

    /**
     * Sets the underlying parameter set.
     * @param _par
     */
    public void setParameters(IHMMParameters _par)
    {
        par = _par;
        for (int i=0; i<fOut.length; i++)
        {
            fOut[i].setParameters(_par.getOutputParameters(i));
        }
    }

    /**
     * Creates deep copy
     * @return
     */
    public HMMForwardModel copy()
    {
        HMMForwardModel res = new HMMForwardModel(obs, eventBased, fOut.length, par.isReversible(), fOut[0]);
        res.setParameters(par);
        return res;
    }

    //@Override
    public int getNStates()
    {
        return nstates;
    }

    //@Override
    public int getNObs()
    {
        return obs.size();
    }

    //@Override
    public int getNObs(int traj)
    {
        return obs.get(traj).size();
    }

    public boolean isEventBased()
    {
        return eventBased;
    }

    //@Override
    public double getP0(int traj, int state)
    {
        return par.getInitialDistribution().get(state);
    }

    //@Override
    public double getPtrans(int traj, int timeindex1, int state1, int state2)
    {
        if (!eventBased)
            return(par.getTransitionMatrix().get(state1,state2));
        else
        {
            IDataSequence seq = obs.get(traj);
            int dt = (int)(seq.getTime(timeindex1+1)-seq.getTime(timeindex1));
            return matrixPower.getPowerElement(par.getTransitionMatrix(), dt, state1, state2);
        }
    }

    //@Override
    public double getPout(int traj, int timeindex, int state)
    {
        IDoubleArray x = obs.get(traj).get(timeindex);
        return fOut[state].f(x);
    }

    public IDoubleArray getObs(int traj, int timeindex)
    {
        IDoubleArray x = obs.get(traj).get(timeindex);
        return x;
    }

    //@Override
    public void setTransitionCounts(IDoubleArray C)
    {
        Testimator.setCounts(C);
        Testimator.estimate();
        IDoubleArray T = Testimator.getTransitionMatrix();
        par.setTransitionMatrix(T);
    }

    //@Override
    public void setOutputParameters(int state, IDoubleArray parOut)
    {
        par.setOutputParameters(state, parOut);
        fOut[state].setParameters(parOut);
    }

    //@Override
    public IHMMParameters getParameters()
    {
        return par;
    }

}
TOP

Related Classes of stallone.hmm.HMMForwardModel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.