Package plm.util.logit.fruehwirth

Source Code of plm.util.logit.fruehwirth.FruehwirthLogitRunner

package plm.util.logit.fruehwirth;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.signals.LinearDynamicalSystem;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.bayesian.KalmanFilter;
import gov.sandia.cognition.statistics.distribution.LogisticDistribution;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.Pair;

import java.util.List;
import java.util.Map.Entry;
import java.util.Random;

import plm.logit.fruehwirth.LogitFSWFFilter;
import plm.logit.fruehwirth.LogitMixParticle;

import com.google.common.collect.Lists;
import com.statslibextensions.statistics.bayesian.DlmUtils;
import com.statslibextensions.util.ObservedValue;
import com.statslibextensions.util.ObservedValue.SimObservedValue;

/**
* @author bwillard
*
*/
public class FruehwirthLogitRunner {
 
  public static void main(String[] args) {


    final Random rng = new Random(829351983l);
   
    /*
     * Sample from a logit model.
     */
    final KalmanFilter initialFilter = new KalmanFilter(
          new LinearDynamicalSystem(
              MatrixFactory.getDefault().copyArray(new double[][] {
                  {1d, 0d},
                  {0d, 1d}}),
              MatrixFactory.getDefault().copyArray(new double[][] {
                  {0d, 0d},
                  {0d, 0d}}),
              MatrixFactory.getDefault().copyArray(new double[][] {
                  {2d, 1d}})),
          MatrixFactory.getDefault().copyArray(new double[][] {
              {0d, 0d},
              {0d, 0d}}),
          MatrixFactory.getDefault().copyArray(new double[][] {{0d}})   
        );
   
    final MultivariateGaussian trueInitialPrior = new MultivariateGaussian(
        VectorFactory.getDefault().copyValues(0.5d, 3d),
        MatrixFactory.getDefault().copyArray(new double[][] {
            {1d, 0d},
            {0d, 1d}}));
    final int N = 100;
    final List<ObservedValue<Vector, Matrix>> observations = Lists.newArrayList();
    final LogisticDistribution ev1Dist = new LogisticDistribution(0d, 1d);

    List<SimObservedValue<Vector, Matrix, Vector>> dlmSamples = DlmUtils.sampleDlm(
        rng, N, trueInitialPrior, initialFilter);
    for (SimObservedValue<Vector, Matrix, Vector> samplePair : dlmSamples) {
      final double ev1Upper = -Math.log(-Math.log(rng.nextDouble()));
      final double upperUtility = samplePair.getObservedValue().getElement(0) +
          ev1Upper;
      final double lowerUtility = -Math.log(-Math.log(rng.nextDouble()));
      final double obs = (upperUtility > lowerUtility) ? 1d : 0d;
      observations.add(
          SimObservedValue.<Vector, Matrix, LogitTrueState>create(
              VectorFactory.getDefault().copyValues(obs),
              samplePair.getObservedData(),
              new LogitTrueState(samplePair.getTrueState(), upperUtility, ev1Upper, lowerUtility)));
    }

    /*
     * Create and initialize the PL filter
     */
    final MultivariateGaussian initialPrior = new MultivariateGaussian(
        VectorFactory.getDefault().copyValues(0d, 0d),
        MatrixFactory.getDefault().copyArray(new double[][] {
            {10d, 0d},
            {0d, 10d}}));
    final Matrix F = MatrixFactory.getDefault().copyArray(new double[][] {
        {1d, 1d}});
    final Matrix G = MatrixFactory.getDefault().copyArray(new double[][] {
        {1d, 0d},
        {0d, 1d}});
    final Matrix modelCovariance = MatrixFactory.getDefault().copyArray(new double[][] {
        {0d, 0d},
        {0d, 0d}});
    final LogitFSWFFilter plFilter =
        new LogitFSWFFilter(initialPrior, F, G, modelCovariance, rng);
    plFilter.setNumParticles(10000);

    final DataDistribution<LogitMixParticle> currentMixtureDistribution =
        plFilter.createInitialLearnedObject();
    for (int i = 0; i < N; i++) {
      final ObservedValue<Vector, Matrix> observation = observations.get(i);
      System.out.println("obs:" + observation);
      plFilter.update(currentMixtureDistribution, observation);

      /*
       * Compute some summary stats. TODO We need to compute something informative for this
       * situation.
       */
      double rmseMean = 0d;

      for (Entry<LogitMixParticle, ? extends Number> particleEntry :
        currentMixtureDistribution.asMap().entrySet()) {
        final Vector trueState = dlmSamples.get(i).getTrueState();
        final Vector particleState = particleEntry.getKey().getLinearState().getMean();
        final double rse = Math.sqrt(trueState.minus(particleState).norm2());
//        System.out.println("true: " + trueState + ", particle:" + particleState);

        rmseMean += rse * Math.exp(particleEntry.getValue().doubleValue());
      }
      System.out.println("posterior RMSE mean: " + rmseMean);
//          + " (" + rmseVar + ")");
    }

    System.out.println("finished simulation");

  }

}
TOP

Related Classes of plm.util.logit.fruehwirth.FruehwirthLogitRunner

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.