Package plm.util.hmm.gaussian

Source Code of plm.util.hmm.gaussian.GaussianArHpHmmRunner

package plm.util.hmm.gaussian;

import gov.sandia.cognition.math.MutableDouble;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.signals.LinearDynamicalSystem;
import gov.sandia.cognition.statistics.bayesian.KalmanFilter;
import gov.sandia.cognition.statistics.distribution.InverseGammaDistribution;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;

import java.io.FileWriter;
import java.io.IOException;
import java.util.List;
import java.util.Random;
import java.util.concurrent.TimeUnit;

import plm.hmm.DlmHiddenMarkovModel;
import plm.hmm.GenericHMM.SimHmmObservedValue;
import plm.hmm.gaussian.GaussianArHpHmmPLFilter;
import plm.hmm.gaussian.GaussianArHpTransitionState;
import plm.hmm.HmmPlFilter;
import plm.util.hmm.HmmResampleComparisonRunner;
import au.com.bytecode.opencsv.CSVWriter;

import com.google.common.base.Stopwatch;
import com.google.common.collect.Lists;
import com.statslibextensions.statistics.distribution.CountedDataDistribution;

public class GaussianArHpHmmRunner extends HmmResampleComparisonRunner {

  public static void main(String[] args) throws IOException {

    final long seed = new Random().nextLong();
    final Random random = new Random(seed);
    log.info("seed=" + seed);

    final double trueSigma = Math.pow(0.2d, 2);
    Matrix modelCovariance1 = MatrixFactory.getDefault().copyArray(
        new double[][] {{trueSigma}});
    Matrix modelCovariance2 = MatrixFactory.getDefault().copyArray(
        new double[][] {{trueSigma}});
    Matrix measurementCovariance = MatrixFactory.getDefault().copyArray(
        new double[][] {{trueSigma}});

    List<Vector> truePsis = Lists.newArrayList(
        VectorFactory.getDefault().copyValues(3d, 0.2d),
        VectorFactory.getDefault().copyValues(1d, 0.9d));

    LinearDynamicalSystem model1 = new LinearDynamicalSystem(
        MatrixFactory.getDefault().copyArray(new double[][] {{truePsis.get(0).getElement(1)}}),
        MatrixFactory.getDefault().copyArray(new double[][] {{1d}}),
        MatrixFactory.getDefault().copyArray(new double[][] {{1d}})
      );
    LinearDynamicalSystem model2 = new LinearDynamicalSystem(
        MatrixFactory.getDefault().copyArray(new double[][] {{truePsis.get(1).getElement(1)}}),
        MatrixFactory.getDefault().copyArray(new double[][] {{1d}}),
        MatrixFactory.getDefault().copyArray(new double[][] {{1d}})
      );
    KalmanFilter trueKf1 = new KalmanFilter(model1, modelCovariance1, measurementCovariance);
    trueKf1.setCurrentInput(VectorFactory.getDefault().copyValues(truePsis.get(0).getElement(0)));
    KalmanFilter trueKf2 = new KalmanFilter(model2, modelCovariance2, measurementCovariance);
    trueKf2.setCurrentInput(VectorFactory.getDefault().copyValues(truePsis.get(1).getElement(0)));
   
    Vector initialClassProbs = VectorFactory.getDefault()
            .copyArray(new double[] { 0.7d, 0.3d });
    Matrix classTransProbs = MatrixFactory.getDefault().copyArray(
                new double[][] { { 0.7d, 0.7d },
                    { 0.3d, 0.3d } });
   
    DlmHiddenMarkovModel trueHmm1 = new DlmHiddenMarkovModel(
        Lists.newArrayList(trueKf1, trueKf2),
        initialClassProbs, classTransProbs);

    final double sigmaPriorMean = Math.pow(0.4, 2);
    final double sigmaPriorShape = 2d;
    final double sigmaPriorScale = sigmaPriorMean*(sigmaPriorShape + 1d);
    final InverseGammaDistribution sigmaPrior = new InverseGammaDistribution(sigmaPriorShape,
        sigmaPriorScale);
   
    final Vector phiMean1 = VectorFactory.getDefault().copyArray(new double[] {
        0d, 0.8d
    });
    final Matrix phiCov1 = MatrixFactory.getDefault().copyArray(new double[][] {
        {2d + 4d * sigmaPriorMean, 0d},
        { 0d, 4d * sigmaPriorMean}
    });
    final MultivariateGaussian phiPrior1 = new MultivariateGaussian(phiMean1, phiCov1);

    final Vector phiMean2 = VectorFactory.getDefault().copyArray(new double[] {
        0d, 0.1d
    });
    final Matrix phiCov2 = MatrixFactory.getDefault().copyArray(new double[][] {
        { 1d + 4d * sigmaPriorMean, 0d},
        { 0d, 4d * sigmaPriorMean}
    });
    final MultivariateGaussian phiPrior2 = new MultivariateGaussian(phiMean2, phiCov2);
   
    List<MultivariateGaussian> priorPhis = Lists.newArrayList(phiPrior1, phiPrior2);

    final HmmPlFilter<DlmHiddenMarkovModel, GaussianArHpTransitionState, Vector> wfFilter =
        new GaussianArHpHmmPLFilter(trueHmm1, sigmaPrior, priorPhis, random, true);


    final String path;
    if (args.length == 0)
      path = ".";
    else
      path = args[0];
    String outputFilename = path + "/hmm-nar-wf-rs-10000-class-errors-m1.csv";

    final int K = 5;
    final int T = 700;
    final int N = 1000;

    /*
     * Note: replications are over the same set of simulated observations.
     */
    List<SimHmmObservedValue<Vector, Vector>> simulation = trueHmm1.sample(random, T);

    wfFilter.setNumParticles(N);
    wfFilter.setResampleOnly(false);

    CSVWriter writer = new CSVWriter(new FileWriter(outputFilename), ',');
    String[] header = "rep,t,filter.type,measurement.type,resample.type,measurement".split(",");
    writer.writeNext(header);

    GaussianArHmmClassEvaluator wfClassEvaluator = new GaussianArHmmClassEvaluator("wf-pl",
        writer);
    GaussianArHmmRmseEvaluator wfRmseEvaluator = new GaussianArHmmRmseEvaluator("wf-pl",
        writer);
    GaussianArHmmPsiLearningEvaluator wfPsiEvaluator = new GaussianArHmmPsiLearningEvaluator("wf-pl",
        truePsis, writer);

    RingAccumulator<MutableDouble> wfLatency =
        new RingAccumulator<MutableDouble>();
    Stopwatch wfWatch = new Stopwatch();


    for (int k = 0; k < K; k++) {
      log.info("Processing replication " + k);
      CountedDataDistribution<GaussianArHpTransitionState> wfDistribution =
          (CountedDataDistribution<GaussianArHpTransitionState>) wfFilter.getUpdater().createInitialParticles(N);


      final long numPreRuns = -1l;//wfDistribution.getMaxValueKey().getTime();
     
      /*
       * Recurse through the particle filter
       */
      for (int i = 0; i < T; i++) {
 
        final double x = simulation.get(i).getClassId();
        final Vector y = simulation.get(i).getObservedValue();

        if (i > numPreRuns) {

          if (i > 0) {
            wfWatch.reset();
            wfWatch.start();
            wfFilter.update(wfDistribution, simulation.get(i));
            wfWatch.stop();
            final long latency = wfWatch.elapsed(TimeUnit.MILLISECONDS);
            wfLatency.accumulate(new MutableDouble(latency));
            writer.writeNext(new String[] {
                Integer.toString(k), Integer.toString(i),
                "wf-pl", "latency", "NA",
                Long.toString(latency)
            });
          }
         
          wfClassEvaluator.evaluate(k, simulation.get(i), wfDistribution);
          wfRmseEvaluator.evaluate(k, simulation.get(i), wfDistribution);
          wfPsiEvaluator.evaluate(k, simulation.get(i), wfDistribution);
        }

        if ((i+1) % (T/4d) < 1) {
          log.info("avg. wf latency=" + wfLatency.getMean().value);
          log.info("avg. wfRmse=" + wfRmseEvaluator.getTotalRate().getMean().value);
          log.info("avg. wfClassRate=" + wfClassEvaluator.getTotalRate().getMean().value);
          log.info("avg. wfPsi=" + wfPsiEvaluator.getTotalRate());
        }
      }

    }

    writer.close();
  }
   

}
TOP

Related Classes of plm.util.hmm.gaussian.GaussianArHpHmmRunner

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.