Package InfoCollection

Source Code of InfoCollection.Simulator

package InfoCollection;

import java.util.Random;

import InfoCollection.util.MeanVariance;

/*
* Eventually we might want to make IndependentNormal and CorrelatedNormal
* members of an "Offline" or "RankingSelection" library, and make this
* Simulator part of that library.  It depends heavily on the "take a
* collection of measurements, and collect loss at the end."
*/

/*
* Simulates one or several measurement policies on a single measurement
* problem.  Each measurement policy is simulated many times and aggregate
* values are reported.
*/
public class Simulator {
  /*
   * Verbosity of the diagnostic output that the LossWhenStopped methods print
   * to System.out.  At 0 (the initial value) nothing is printed.  Values > 0
   * report each time a stopping rule stops, and when a run reaches its
   * measurement budget.  Values > 1 additionally report every alternative
   * measured.
   */
  static int debug_level = 0;

  /*
   * Run a policy a single time through N iterations with the truth fixed.
   * Returns the loss at each iteration n.
   */
  public static double[] LossThroughTime(SamplingRule policy, Truth truth,
      int N) throws Exception {
    double[] loss = new double[N];
    Random rnd = new Random();
    policy.Start(truth.M());
    for (int n = 0; n < N; n++) {
      int x = policy.GetMeasurementDecision();
      double y = truth.Sample(x, rnd);
      policy.RecordMeasurement(x, y);
      int i = policy.GetImplementationDecision();
      loss[n] = truth.Loss(i);
    }
    return loss;
  }

  /*
   * This function runs a policy on a fixed truth for a fixed number of
   * measurements N. We have two versions. One takes a number of samples, and
   * repeats the experiment, returning a sample average. The other runs the
   * experiment once and returns the loss realized.
   */
  public static double LossAtEnd(SamplingRule policy, Truth truth, int N,
      int nruns) throws Exception {
    MeanVariance lossEstimator = new MeanVariance();
    Random rnd = new Random();
    for (int r = 0; r < nruns; r++) {
      policy.Start(truth.M());
      for (int n = 0; n < N; n++) {
        int x = policy.GetMeasurementDecision();
        double y = truth.Sample(x, rnd);
        policy.RecordMeasurement(x, y);
      }
      int i = policy.GetImplementationDecision();
      lossEstimator.AddSample(truth.Loss(i));
    }
    return lossEstimator.SampleMean();
  }

  public static double LossAtEnd(SamplingRule policy, Truth truth, int N)
      throws Exception {
    return LossAtEnd(policy, truth, N, 1);
  }

  /*
   * Runs a policy for a fixed number of measurements, generating the truth
   * from the belief b. Also note that it returns a MeanVariance, rather than
   * just the sample average.
   */
  public static MeanVariance LossAtEnd(SamplingRule policy, Belief b, int N,
      int nruns) throws Exception {
    MeanVariance lossEstimator = new MeanVariance();
    Random rnd = new Random();
    for (int r = 0; r < nruns; r++) {
      Truth truth = b.GenerateTruth(rnd);
      policy.Start(truth.M());
      for (int n = 0; n < N; n++) {
        int x = policy.GetMeasurementDecision();
        double y = truth.Sample(x, rnd);
        policy.RecordMeasurement(x, y);
      }
      int i = policy.GetImplementationDecision();
      lossEstimator.AddSample(truth.Loss(i));
    }
    return lossEstimator;
  }

  /*
   * This function takes a single policy, a single set of true values, an
   * array of stopping rules, and a maximum number of samples to take per run.
   * It also takes an optional argument n0 which is a number of times to
   * sample each alternative in a first stage of measurements.  If this
   * argument is omitted, it is set to 0.
   *
   * For each run, the function runs the policy until it finds that stop[0]
   * would stop. It then records the opportunity cost that would be received
   * for stopping then, and continues on until stop[1] would stop, recording
   * the opportunity cost there, etc. If the number of samples taken in a
   * particular run would exceed maxN, then that run is terminated.
   *
   * It returns a two dimensional array. For stopping rule stop[i],
   * return[i][0] contains the average opportunity cost; return[i][1] contains
   * the average number of samples taken; return[i][2] contains the total
   * number of runs terminated; return[i][3] and return[i][4] contain the
   * standard error in the estimates of OC and num samples respectively.
   */
  public static double[][] LossWhenStopped(SamplingRule policy, Truth truth,
      StoppingRuleSeq stop, int nruns, int maxN) throws Exception {
    return LossWhenStopped(policy, truth, stop, nruns, maxN, 0); // run with n0=0
  }
  public static double[][] LossWhenStopped(SamplingRule policy, Truth truth,
      StoppingRuleSeq stop, int nruns, int maxN, int n0) throws Exception {

    int n, r, s;
    int timeouts[] = new int[stop.Length()];
    MeanVariance[] loss = new MeanVariance[stop.Length()];
    MeanVariance[] N = new MeanVariance[stop.Length()];
    double[][] result = new double[stop.Length()][5];
    for (s = 0; s < stop.Length(); s++) {
      timeouts[s] = 0;
      loss[s] = new MeanVariance();
      N[s] = new MeanVariance();
    }

    Random rnd = new Random();
    for (r = 0; r < nruns; r++) {
      policy.Start(truth.M());
      stop.Start(truth.M());
      n = s = 0;

      /* Sample each alternative n0 times. */
      for (int i=0; i<n0; i++) {
        for (int x=0; x<truth.M(); x++) {
          /* Sample alternative x. */
          double y = truth.Sample(x, rnd);
          if (debug_level > 1)
            System.out.format("Measuring alternative %d\n", x);
          policy.RecordMeasurement(x, y);
          stop.RecordMeasurement(x, y);
          n++;
        }
      }


      /* Do a second stage where we pass control to the allocation/stopping rule. */
      while (n < maxN && s < stop.Length()) {
        if (stop.ShouldStop()) {
          int i = policy.GetImplementationDecision();
          if (debug_level > 0)
            System.out.format(
                "Stopping at n=%d with i=%d and loss=%f\n", n,
                i, truth.Loss(i));
          loss[s].AddSample(truth.Loss(i));
          assert truth.Loss(i) >= 0;
          N[s].AddSample(n);
          s++;
          stop.Next();
          continue;
        }

        int x = policy.GetMeasurementDecision();
        double y = truth.Sample(x, rnd);
        if (debug_level > 1)
          System.out.format("Measuring alternative %d\n", x);
        policy.RecordMeasurement(x, y);
        stop.RecordMeasurement(x, y);
        n++;
      }
      if (n == maxN) {
        if (debug_level > 0)
          System.out
              .println("Maximum number of measurements exceeded.  Stopping all stopping rules.\n");
        int i = policy.GetImplementationDecision();
        for (; s < stop.Length(); s++) {
          timeouts[s]++;
          loss[s].AddSample(truth.Loss(i));
          assert truth.Loss(i) >= 0;
          N[s].AddSample(n);
        }
      }
    }

    for (s = 0; s < stop.Length(); s++) {
      result[s][0] = loss[s].SampleMean();
      result[s][1] = N[s].SampleMean();
      result[s][2] = timeouts[s];
      result[s][3] = loss[s].SampleMeanDeviation();
      result[s][4] = N[s].SampleMeanDeviation();
    }
    return result;
  }

  /*
   * Like the version of LossWhenStopped above, except on each new run the
   * truth is sampled randomly from the prior.
   */
  public static double[][] LossWhenStopped(SamplingRule policy, Belief prior,
      StoppingRuleSeq stop, int nruns, int maxN) throws Exception {
    if (!prior.IsInformative())
      throw (new Exception("prior is not informative"));

    int n, r, s;
    int timeouts[] = new int[stop.Length()];
    MeanVariance[] loss = new MeanVariance[stop.Length()];
    MeanVariance[] N = new MeanVariance[stop.Length()];
    double[][] result = new double[stop.Length()][5];
    for (s = 0; s < stop.Length(); s++) {
      timeouts[s] = 0;
      loss[s] = new MeanVariance();
      N[s] = new MeanVariance();
    }

    Random rnd = new Random();
    for (r = 0; r < nruns; r++) {
      Truth truth = prior.GenerateTruth(rnd);
      policy.Start(truth.M());
      stop.Start(truth.M());
      n = s = 0;
      while (n < maxN && s < stop.Length()) {
        if (stop.ShouldStop()) {
          int i = policy.GetImplementationDecision();
          if (debug_level > 0)
            System.out.format(
                "Stopping at n=%d with i=%d and loss=%f\n", n,
                i, truth.Loss(i));
          loss[s].AddSample(truth.Loss(i));
          assert truth.Loss(i) >= 0;
          N[s].AddSample(n);
          s++;
          stop.Next();
          continue;
        }

        int x = policy.GetMeasurementDecision();
        if (debug_level > 1)
          System.out.format("Measuring alternative %d\n", x);
        double y = truth.Sample(x, rnd);
        policy.RecordMeasurement(x, y);
        stop.RecordMeasurement(x, y);
        n++;
      }
      if (n == maxN) {
        if (debug_level > 0)
          System.out
              .println("Maximum number of measurements exceeded.  Stopping all stopping rules.\n");
        int i = policy.GetImplementationDecision();
        for (; s < stop.Length(); s++) {
          timeouts[s]++;
          loss[s].AddSample(truth.Loss(i));
          assert truth.Loss(i) >= 0;
          N[s].AddSample(n);
        }
      }
    }

    for (s = 0; s < stop.Length(); s++) {
      result[s][0] = loss[s].SampleMean();
      result[s][1] = N[s].SampleMean();
      result[s][2] = timeouts[s];
      result[s][3] = loss[s].SampleMeanDeviation();
      result[s][4] = N[s].SampleMeanDeviation();
    }
    return result;
  }

  /*
   * Version of LossWhenStopped for a single stopping rule. Returns an array,
   * return[0] contains the average opportunity cost.
   * return[1] contains the average number of samples taken.
   * return[2] contains the total number of runs terminated.
   * return[3] contains the sample mean deviation for opportunity cost.
   * return[4] contains the sample mean deviation for average number of samples taken.
   */
  public static double[] LossWhenStopped(SamplingRule policy, Truth truth,
      StoppingRule stop, int nruns, int maxN) throws Exception {
    StoppingRuleSeq stopSeq = new OneStoppingRule(stop);
    double[][] resultSeq = LossWhenStopped(policy, truth, stopSeq, nruns, maxN);
    double[] result = new double[5];
    result[0] = resultSeq[0][0];
    result[1] = resultSeq[0][1];
    result[2] = resultSeq[0][2];
    result[3] = resultSeq[0][3];
    result[4] = resultSeq[0][4];
    return result;
  }

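  /*
   * Disabled sketch of a single-rule overload for prior-generated truths,
   * analogous to the Truth-based single-rule LossWhenStopped.  It would not
   * compile as written:
   *   - it never allocates `result` (it needs `new double[5]`);
   *   - it has no `return result;`.
   * It also builds the sequence with `new StoppingRuleSeq(stop)`, whereas the
   * Truth-based version wraps the rule with `new OneStoppingRule(stop)`.
   *
   * public static double[] LossWhenStopped(SamplingRule policy, Belief prior,
   * StoppingRule stop, int nruns, int maxN) throws Exception {
   * StoppingRuleSeq stopSeq = new StoppingRuleSeq(stop); double[][] resultSeq
   * = LossWhenStopped(policy, prior, stopSeq, nruns, maxN); double[] result;
   * result[0] = resultSeq[0][0]; result[1] = resultSeq[0][1]; result[2] =
   * resultSeq[0][2]; result[3] = resultSeq[0][3]; result[4] =
   * resultSeq[0][4]; }
   */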
TOP

Related Classes of InfoCollection.Simulator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.