// Package plm.hmm
//
// Source Code of plm.hmm.StandardHMM$ExposedHmm

package plm.hmm;

import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.hmm.HiddenMarkovModel;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.WeightedValue;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;

import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;

/**
* HMM without state dependencies.
* @author bwillar0
*
* @param <ResponseType>
*/
public class StandardHMM<ResponseType> extends AbstractCloneableSerializable implements GenericHMM<ResponseType, Integer, DataDistribution<Integer>> {
 
  protected ExposedHmm<ResponseType> hmm;

  public StandardHMM(Vector initialProbs, Matrix transProbs, List<? extends ComputableDistribution<ResponseType>> emissions) {
    hmm = new ExposedHmm<ResponseType>(new HiddenMarkovModel<ResponseType>(initialProbs, transProbs, emissions));
  }

  @Override
  public List<SimHmmObservedValue<ResponseType, Integer>> sample(
      Random random, int numSamples) {

    List<SimHmmObservedValue<ResponseType, Integer>> results = Lists.newArrayList();
    Vector p = hmm.getInitialProbability();
    int state = -1;
    for (int n = 0; n < numSamples; n++) {
      double value = random.nextDouble();
      state = -1;
      while (value > 0.0) {
        state++;
        value -= p.getElement(state);
      }
 
      ResponseType sample =
          CollectionUtil.getElement(hmm.getEmissionFunctions(),
              state).sample(random);
      results.add(new SimHmmObservedValue<ResponseType, Integer>(n, state, state, sample));
      p = hmm.getTransitionProbability().getColumn(state);
    }
 
    return results;
  }

  @Override
  public int getNumStates() {
    return hmm.getNumStates();
  }

  @Override
  public Matrix getTransitionProbability() {
    return hmm.getTransitionProbability();
  }

  public ArrayList<Integer> viterbi(
      Collection<? extends ResponseType> observations) {
    return hmm.viterbi(observations);
  }

  public ArrayList<Vector> stateBeliefs(
      Collection<? extends ResponseType> observations) {
    return hmm.stateBeliefs(observations);
  }

  protected static class ExposedHmm<ObservationType> extends HiddenMarkovModel<ObservationType> {
 
    private static final long serialVersionUID = 7707497613044304004L;
 
    public ExposedHmm(HiddenMarkovModel<ObservationType> hmm) {
      this.emissionFunctions = hmm.getEmissionFunctions();
      this.initialProbability = hmm.getInitialProbability();
      this.transitionProbability = hmm.getTransitionProbability();
    }
 
    @Override
    protected WeightedValue<Vector> computeForwardProbabilities(
      Vector alpha, Vector b, boolean normalize) {
      return super.computeForwardProbabilities(alpha, b, normalize);
    }
 
    @Override
    protected void computeObservationLikelihoods(
      ObservationType observation, Vector b) {
      super.computeObservationLikelihoods(observation, b);
    }
 
    @Override
    protected WeightedValue<Vector> computeBackwardProbabilities(
      Vector beta, Vector b, double weight) {
      return super.computeBackwardProbabilities(beta, b, weight);
    }
 
    @Override
    public ArrayList<Vector> computeStateObservationLikelihood(
      ArrayList<WeightedValue<Vector>> alphas,
      ArrayList<WeightedValue<Vector>> betas, double scaleFactor) {
      return super.computeStateObservationLikelihood(alphas, betas,
          scaleFactor);
    }
 
    @Override
    protected Matrix computeTransitions(
      ArrayList<WeightedValue<Vector>> alphas,
      ArrayList<WeightedValue<Vector>> betas, ArrayList<Vector> b) {
      return super.computeTransitions(alphas, betas, b);
    }
 
    @Override
    protected WeightedValue<Integer> findMostLikelyState(
      int destinationState, Vector delta) {
      return super.findMostLikelyState(destinationState, delta);
    }
 
    @Override
    protected Pair<Vector, int[]> computeViterbiRecursion(Vector delta,
      Vector bn) {
      return super.computeViterbiRecursion(delta, bn);
    }
 
    @Override
    public 
        double
        computeMultipleObservationLogLikelihood(
          Collection<? extends Collection<? extends ObservationType>> sequences) {
      return super.computeMultipleObservationLogLikelihood(sequences);
    }
 
    @Override
    public ArrayList<WeightedValue<Vector>>
        computeForwardProbabilities(ArrayList<Vector> b,
          boolean normalize) {
      return super.computeForwardProbabilities(b, normalize);
    }
 
    @Override
    public ArrayList<Vector> computeObservationLikelihoods(
      Collection<? extends ObservationType> observations) {
      return super.computeObservationLikelihoods(observations);
    }
 
    @Override
    public ArrayList<WeightedValue<Vector>>
        computeBackwardProbabilities(ArrayList<Vector> b,
          ArrayList<WeightedValue<Vector>> alphas) {
      return super.computeBackwardProbabilities(b, alphas);
    }
  }

  public ArrayList<WeightedValue<Vector>> computeForwardProbabilities(
      ArrayList<Vector> b, boolean normalize) {
    return hmm.computeForwardProbabilities(b, normalize);
  }

  public ArrayList<Vector> computeObservationLikelihoods(
      Collection<? extends ResponseType> observations) {
    return hmm.computeObservationLikelihoods(observations);
  }

  public ArrayList<WeightedValue<Vector>> computeBackwardProbabilities(
      ArrayList<Vector> b, ArrayList<WeightedValue<Vector>> alphas) {
    return hmm.computeBackwardProbabilities(b, alphas);
  }

  public ArrayList<Vector> computeStateObservationLikelihood(
      ArrayList<WeightedValue<Vector>> alphas,
      ArrayList<WeightedValue<Vector>> betas, double scaleFactor) {
    return hmm.computeStateObservationLikelihood(alphas, betas, scaleFactor);
  }

  @Override
  public DataDistribution<Integer> getEmissionFunction(DataDistribution<Integer> o, int classId) {
    return (DataDistribution<Integer>) Iterables.get(this.hmm.getEmissionFunctions(), classId);
  }

  @Override
  public Vector getClassMarginalProbabilities() {
    return this.hmm.getInitialProbability();
  }

  public static <T> StandardHMM<T> create(
      HiddenMarkovModel<T> hmm) {
    return new StandardHMM(hmm.getInitialProbability(),
        hmm.getTransitionProbability(), (List) hmm.getEmissionFunctions());
  }

  @Override
  public StandardHMM<ResponseType> clone() {
    StandardHMM<ResponseType> clone = (StandardHMM<ResponseType>) super.clone();
    clone.hmm = this.hmm;
    return clone;
  }
}
// TOP
//
// Related Classes of plm.hmm.StandardHMM$ExposedHmm
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.