Package plm.util.gaussian

Source Code of plm.util.gaussian.GaussianArHpEvaluator

package plm.util.gaussian;

import gov.sandia.cognition.math.MutableDouble;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;

import java.util.List;

import plm.gaussian.GaussianArHpWfParticle;
import plm.hmm.GenericHMM;
import plm.hmm.GenericHMM.SimHmmObservedValue;
import plm.hmm.gaussian.GaussianArHpTransitionState;
import plm.hmm.HmmTransitionState;
import au.com.bytecode.opencsv.CSVWriter;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.statslibextensions.util.ObservedValue;
import com.statslibextensions.util.ObservedValue.SimObservedValue;

public class GaussianArHpEvaluator {

  public static final String evaluatorType = "psi and sigma2 learning rmse";
  protected RingAccumulator<MutableDouble> runningStateRmse =
      new RingAccumulator<MutableDouble>();
  protected RingAccumulator<MutableDouble> runningPsiRmse =
      new RingAccumulator<MutableDouble>();
  protected RingAccumulator<MutableDouble> runningSigma2Rmse =
      new RingAccumulator<MutableDouble>();

  protected RingAccumulator<MutableDouble> runningStateErrRate =
      new RingAccumulator<MutableDouble>();
  protected RingAccumulator<MutableDouble> runningPsiErrRate =
      new RingAccumulator<MutableDouble>();
  protected RingAccumulator<MutableDouble> runningSigma2ErrRate =
      new RingAccumulator<MutableDouble>();

  protected final String modelId;
  protected final CSVWriter writer;
  protected final Vector truePsi;
  protected final double trueSigma2;
  private double stateLastRmse;
  private double psiLastRmse;
  private double sigma2LastRmse;
  private Vector stateLastMean;
  private Vector psiLastMean;
  private double sigma2LastMean;
  private double stateLastErrRate;
  private double psiLastErrRate;
  private double sigma2LastErrRate;
 
  public GaussianArHpEvaluator(String modelId, Vector truePsi, double trueSigma2,
    CSVWriter writer) {
    this.modelId = modelId;
    this.writer = writer;
    this.truePsi = truePsi;
    this.trueSigma2 = trueSigma2;
  }

  public <N, T extends GaussianArHpWfParticle> void evaluate(
      int replication, SimObservedValue<Vector, ?, Vector> obs, DataDistribution<T> distribution) {

    RingAccumulator<Vector> stateAvg = new RingAccumulator<Vector>();
    RingAccumulator<Vector> psiAvg = new RingAccumulator<Vector>();
    RingAccumulator<MutableDouble> sigma2Avg = new RingAccumulator<MutableDouble>();
    for (T particle : distribution.getDomain()) {
      final double particleWeight = distribution.getFraction(particle);
      // stupid hack for a java bug
      final Object pobj = particle;
      Preconditions.checkState(pobj instanceof GaussianArHpWfParticle);
      GaussianArHpWfParticle gParticle = (GaussianArHpWfParticle) pobj;

      MultivariateGaussian state = gParticle.getState();
      stateAvg.accumulate(state.getMean().scale(particleWeight));

      MultivariateGaussian psi = gParticle.getPsiSS();
      psiAvg.accumulate(psi.getMean().scale(particleWeight));
     
      double sigma2Mean = gParticle.getSigma2SS().getMean();
      sigma2Avg.accumulate(new MutableDouble(sigma2Mean * particleWeight));

    }
   
    this.stateLastMean = stateAvg.getSum();
    this.psiLastMean = psiAvg.getSum();
    this.sigma2LastMean = sigma2Avg.getSum().doubleValue();

    this.stateLastRmse = stateLastMean.minus(obs.getTrueState()).norm2();
    this.psiLastRmse = psiLastMean.minus(this.truePsi).norm2();
    this.sigma2LastRmse = Math.abs(this.trueSigma2 - sigma2LastMean);
    runningStateRmse.accumulate(new MutableDouble(stateLastRmse));
    runningPsiRmse.accumulate(new MutableDouble(psiLastRmse));
    runningSigma2Rmse.accumulate(new MutableDouble(sigma2LastRmse));

    this.stateLastErrRate = stateLastRmse/obs.getTrueState().norm2();
    this.psiLastErrRate = psiLastRmse/this.truePsi.norm2();
    this.sigma2LastErrRate = sigma2LastRmse/this.trueSigma2;
    runningStateErrRate.accumulate(new MutableDouble(stateLastErrRate));
    runningPsiErrRate.accumulate(new MutableDouble(psiLastErrRate));
    runningSigma2ErrRate.accumulate(new MutableDouble(sigma2LastErrRate));

    if (writer != null) {
      String[] line = {
          Integer.toString(replication),
          Long.toString(obs.getTime()),
          this.modelId,
          "psi",
          distribution.getMaxValueKey().getResampleType().toString(),
          Double.toString(psiLastRmse)};
      writer.writeNext(line);

      String[] line2 = {
          Integer.toString(replication),
          Long.toString(obs.getTime()),
          this.modelId,
          "sigma2",
          distribution.getMaxValueKey().getResampleType().toString(),
          Double.toString(sigma2LastRmse)};
      writer.writeNext(line2);

      String[] line3 = {
          Integer.toString(replication),
          Long.toString(obs.getTime()),
          this.modelId,
          "state",
          distribution.getMaxValueKey().getResampleType().toString(),
          Double.toString(stateLastRmse)};
      writer.writeNext(line3);

      String[] line11 = {
          Integer.toString(replication),
          Long.toString(obs.getTime()),
          this.modelId,
          "psi",
          distribution.getMaxValueKey().getResampleType().toString(),
          Double.toString(psiLastErrRate)};
      writer.writeNext(line11);

      String[] line22 = {
          Integer.toString(replication),
          Long.toString(obs.getTime()),
          this.modelId,
          "sigma2",
          distribution.getMaxValueKey().getResampleType().toString(),
          Double.toString(sigma2LastErrRate)};
      writer.writeNext(line22);

      String[] line33 = {
          Integer.toString(replication),
          Long.toString(obs.getTime()),
          this.modelId,
          "state",
          distribution.getMaxValueKey().getResampleType().toString(),
          Double.toString(stateLastErrRate)};
      writer.writeNext(line33);
    }
   
  }

  public double getStateLastErrRate() {
    return stateLastErrRate;
  }

  public double getPsiLastErrRate() {
    return psiLastErrRate;
  }

  public double getSigma2LastErrRate() {
    return sigma2LastErrRate;
  }

  public Vector getStateLastMean() {
    return stateLastMean;
  }

  public Vector getPsiLastMean() {
    return psiLastMean;
  }

  public double getSigma2LastMean() {
    return sigma2LastMean;
  }

  public double getRunningStateErrRate() {
    return runningStateErrRate.getMean().doubleValue();
  }

  @Override
  public String toString() {
    StringBuilder builder = new StringBuilder();
    builder.append("GaussianArHpEvaluator [modelId=");
    builder.append(getModelId());
    builder.append(",\n truePsi=");
    builder.append(getTruePsi());
    builder.append(",\n trueSigma2=");
    builder.append(getTrueSigma2());
    builder.append(",\n stateLastMean=");
    builder.append(getStateLastMean());
    builder.append(",\n psiLastMean=");
    builder.append(getPsiLastMean());
    builder.append(",\n sigma2LastMean=");
    builder.append(getSigma2LastMean());
    builder.append(",\n stateLastRmse=");
    builder.append(getStateLastRmse());
    builder.append(",\n psiLastRmse=");
    builder.append(getPsiLastRmse());
    builder.append(",\n sigma2LastRmse=");
    builder.append(getSigma2LastRmse());
    builder.append(",\n runningStateErrRate=");
    builder.append(getRunningStateErrRate());
    builder.append(",\n runningPsiErrRate=");
    builder.append(getRunningPsiErrRate());
    builder.append(",\n runningSigma2ErrRate=");
    builder.append(getRunningSigma2ErrRate());
    builder.append(",\n runningStateRmse=");
    builder.append(getRunningStateRmse());
    builder.append(",\n runningSigma2Rmse=");
    builder.append(getRunningSigma2Rmse());
    builder.append(",\n runningPsiRmse=");
    builder.append(getRunningPsiRmse());
    builder.append("]");
    return builder.toString();
  }

  public double getRunningPsiErrRate() {
    return runningPsiErrRate.getMean().doubleValue();
  }

  public double getRunningSigma2ErrRate() {
    return runningSigma2ErrRate.getMean().doubleValue();
  }

  public static String getEvaluatortype() {
    return evaluatorType;
  }

  public Vector getTruePsi() {
    return truePsi;
  }

  public double getTrueSigma2() {
    return trueSigma2;
  }

  public double getStateLastRmse() {
    return stateLastRmse;
  }

  public double getPsiLastRmse() {
    return psiLastRmse;
  }

  public double getSigma2LastRmse() {
    return sigma2LastRmse;
  }

  public double getRunningStateRmse() {
    return this.runningStateRmse.getMean().doubleValue();
  }

  public double getRunningSigma2Rmse() {
    return this.runningSigma2Rmse.getMean().doubleValue();
  }

  public double getRunningPsiRmse() {
    return this.runningPsiRmse.getMean().doubleValue();
  }

  public String getModelId() {
    return modelId;
  }

  public CSVWriter getWriter() {
    return writer;
  }

}
TOP

Related Classes of plm.util.gaussian.GaussianArHpEvaluator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.