Package plm.hmm

Source Code of plm.hmm.DlmHiddenMarkovModel

package plm.hmm;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.DenseMatrix;
import gov.sandia.cognition.math.matrix.mtj.decomposition.CholeskyDecompositionMTJ;
import gov.sandia.cognition.statistics.DiscreteSamplingUtil;
import gov.sandia.cognition.statistics.bayesian.KalmanFilter;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;

import java.util.List;
import java.util.Random;

import com.google.common.collect.Lists;

public class DlmHiddenMarkovModel extends AbstractCloneableSerializable implements GenericHMM<Vector, Vector, MultivariateGaussian> {

  protected List<KalmanFilter> stateFilters;
  protected Vector marginalClassProbs;
  protected Matrix classTransProbs;
 
  public DlmHiddenMarkovModel(List<KalmanFilter> stateFilters, Vector marginalClassProbs, Matrix classTransProbs) {
    this.stateFilters = stateFilters;
    this.marginalClassProbs = marginalClassProbs;
    this.classTransProbs = classTransProbs;
  }

  public List<KalmanFilter> getStateFilters() {
    return stateFilters;
  }

  public void setStateFilters(List<KalmanFilter> stateFilters) {
    this.stateFilters = stateFilters;
  }
 
  @Override
  public DlmHiddenMarkovModel clone() {
    DlmHiddenMarkovModel clone = (DlmHiddenMarkovModel) super.clone();
    clone.classTransProbs = this.classTransProbs.clone();
    clone.marginalClassProbs = this.marginalClassProbs.clone();
    clone.stateFilters = ObjectUtil.cloneSmartElementsAsArrayList(this.stateFilters);
    return clone;
  }
 

  /**
   * Sample a trajectory up to time T.
   * Note: this method will add the filter's current input to the
   * initial (and all other) states, so if you set the model input, the
   * offset/input will be added twice to the initial state.
   *
   * @param random
   * @param numSamples
   * @return
   */
  @Override
  public List<SimHmmObservedValue<Vector, Vector>> sample(Random random, int T) {
    List<SimHmmObservedValue<Vector, Vector>> results = Lists.newArrayList();

    int currentClass = DiscreteSamplingUtil.sampleIndexFromProbabilities(random, this.marginalClassProbs);
    Vector currentState = null;
    for (int i = 0; i < T; i++) {

      Vector classProbs = (currentState == null) ? this.marginalClassProbs :
        this.classTransProbs.getColumn(currentClass);

      currentClass = DiscreteSamplingUtil.sampleIndexFromProbabilities(random, classProbs);

      KalmanFilter currentFilter = this.stateFilters.get(currentClass);
      if (currentState == null) {
        currentState = currentFilter.createInitialLearnedObject().sample(random);
      } else {
        final Matrix G = currentFilter.getModel().getA();
        Matrix modelCovSqrt = CholeskyDecompositionMTJ.create(
            (DenseMatrix) currentFilter.getModelCovariance()).getR();
        currentState = MultivariateGaussian.sample(G.times(currentState), modelCovSqrt, random);
      }
      currentState.plusEquals(currentFilter.getModel().getB().times(
          currentFilter.getCurrentInput()));

      final Matrix F = currentFilter.getModel().getC();
      Vector observationMean = F.times(currentState);
      Matrix measurementCovSqrt = CholeskyDecompositionMTJ.create(
          (DenseMatrix) currentFilter.getMeasurementCovariance()).getR();
      Vector observation = MultivariateGaussian.sample(observationMean,
          measurementCovSqrt, random);
 
      results.add(new SimHmmObservedValue<Vector, Vector>(i, currentClass,
          currentState.clone(), observation));
    }
   
    return results;
  }

  @Override
  public int getNumStates() {
    return this.marginalClassProbs.getDimensionality();
  }

  @Override
  public MultivariateGaussian getEmissionFunction(MultivariateGaussian state, int classId) {
    KalmanFilter kf = this.stateFilters.get(classId);
    final Vector mean = kf.getModel().getC().times(state.getMean());
    final Matrix cov = kf.getModel().getC().times(state.getCovariance())
        .times(kf.getModel().getC().transpose())
        .plus(kf.getMeasurementCovariance());
    final MultivariateGaussian likelihood = new MultivariateGaussian(
        mean, cov);
    return likelihood;
  }

  @Override
  public Matrix getTransitionProbability() {
    return this.classTransProbs;
  }

  @Override
  public Vector getClassMarginalProbabilities() {
    return this.marginalClassProbs;
  }


}
TOP

Related Classes of plm.hmm.DlmHiddenMarkovModel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.