Package plm.util.hmm.gaussian

Source Code of plm.util.hmm.gaussian.GaussianArHmmRmseEvaluator

package plm.util.hmm.gaussian;

import gov.sandia.cognition.math.MutableDouble;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.DataDistribution;
import plm.hmm.GenericHMM;
import plm.hmm.GenericHMM.SimHmmObservedValue;
import plm.hmm.gaussian.GaussianArHpTransitionState;
import plm.hmm.gaussian.GaussianArTransitionState;
import plm.hmm.HmmTransitionState;
import au.com.bytecode.opencsv.CSVWriter;

public class GaussianArHmmRmseEvaluator {

  public static final String evaluatorType = "rmse";
  protected RingAccumulator<MutableDouble> runningRate =
      new RingAccumulator<MutableDouble>();
  protected final String modelId;
  protected final CSVWriter writer;
 
  public GaussianArHmmRmseEvaluator(String modelId) {
    this.modelId = modelId;
    this.writer = null;
  }

  public GaussianArHmmRmseEvaluator(String modelId, CSVWriter writer) {
    this.modelId = modelId;
    this.writer = writer;
  }

  public <N, H extends GenericHMM<N,?,?>, T extends HmmTransitionState<N, H>> void evaluate(
      int replication, SimHmmObservedValue<Vector, Vector> obs,
      DataDistribution<T> distribution) {

    final Vector trueState = obs.getState();
    RingAccumulator<Vector> stateMean = new RingAccumulator<Vector>();
    // rediculous hack to get around java bug
    for (T particle : distribution.getDomain()) {
      final double particleWeight = distribution.getFraction(particle);
      Object tmpParticle = particle;
      if (tmpParticle instanceof GaussianArTransitionState) {
        GaussianArTransitionState gParticle = (GaussianArTransitionState) tmpParticle;
        stateMean.accumulate(VectorFactory.getDefault().copyValues(
            gParticle.getSuffStat().getMean() * particleWeight));
      } else if (tmpParticle instanceof GaussianArHpTransitionState) {
        GaussianArHpTransitionState gParticle = (GaussianArHpTransitionState) tmpParticle;
        stateMean.accumulate(gParticle.getState().getMean().scale(particleWeight));
      }
    }

    final double rmse = stateMean.getSum().minus(trueState).norm2();
    runningRate.accumulate(new MutableDouble(rmse));

    if (writer != null) {
      String[] line = {
          Integer.toString(replication),
          Long.toString(obs.getTime()),
          this.modelId,
          evaluatorType,
          distribution.getMaxValueKey().getResampleType().toString(),
          Double.toString(rmse)};
      writer.writeNext(line);
    }
  }

  public RingAccumulator<MutableDouble> getTotalRate() {
    return runningRate;
  }

  public void setWfRunningClassRate(
      RingAccumulator<MutableDouble> rate) {
    this.runningRate = rate;
  }

  public String getModelId() {
    return modelId;
  }

  public CSVWriter getWriter() {
    return writer;
  }

}
TOP

Related Classes of plm.util.hmm.gaussian.GaussianArHmmRmseEvaluator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.