Package com.mapr.stats.bandit

Source Code of com.mapr.stats.bandit.BanditTrainer

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.mapr.stats.bandit;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.mapr.stats.random.BinomialDistributionSampler;
import com.mapr.stats.random.DistributionGenerator;
import com.mapr.stats.random.DistributionWithMean;
import com.mapr.stats.random.NormalDistributionSampler;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.stats.OnlineSummarizer;

import java.io.FileNotFoundException;
import java.io.PrintWriter;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
* Simulate a two-armed bandit playing against a beta-Bayesian model.
* <p/>
* The output indicates the quantiles of the distribution for regret relative to the optimal pick.
* The regret distribution is estimated by picking two random conversion probabilities and then
* running the beta-Bayesian model for a number of steps.  The regret is computed by taking the
* expectation for the optimal choice and subtracting from the actual percentage of conversion
* achieved.  On average, this should be somewhat negative since the model has to spend some effort
* examining the sub-optimal choice.  The median, 25 and 75%-ile marks all scale downward fairly
* precisely with the square root of the number of trials which is to be expected from theoretical
* considerations.
* <p/>
* The beta-Bayesian model works by keeping an estimate of the posterior distribution for the
* conversion probability for each of the bandits.  We take a uniform distribution as the prior so
* the posterior is a beta distribution.  The model samples probabilities from the two posterior
* distributions and chooses the model whose sample is larger. As data is collected for the two
* bandits, the better of the bandits will quickly have a pretty narrow posterior distribution and
* the lesser bandit will rarely have a sampled probability higher than the better bandit.  This
* means that we will stop getting data from the less bandit, but only when there is essentially no
* chance that it is better.
*/
public class BanditTrainer {
    private static final int BUCKET_SIZE = 1;

    public static void main(String[] args) throws FileNotFoundException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, InterruptedException {
        int threads = 16;

        if (args.length > 0) {
            threads = Integer.parseInt(args[0]);
        }

        System.out.printf("regret\n");
        ExecutorService ex = Executors.newFixedThreadPool(threads);

        List<Callable<Integer>> tasks = ImmutableList.of(
                new Callable<Integer>() {
                    @Override
                    public Integer call() {
                        try {
                            totalRegret("regret-normal-0.1.tsv", "local-normal-0.1.tsv", 1000, 2, 10000, new GammaNormalBayesFactory(), new NormalDistributionSampler(0.1, new Random()));
                            System.out.printf("2\n");
                        } catch (FileNotFoundException e) {
                            e.printStackTrace();
                        }
                        return null;
                    }
                },
                new Callable<Integer>() {
                    @Override
                    public Integer call() {
                        try {
                            totalRegret("regret-epsilon-normal-1.tsv", "local-epsilon-normal-1.tsv", 1000, 2, 10000, new EpsilonGreedyFactory(0.05), new NormalDistributionSampler(1, new Random()));
                            System.out.printf("2e\n");
                        } catch (FileNotFoundException e) {
                            e.printStackTrace();
                        }
                        return null;
                    }
                },


                new Callable<Integer>() {
                    @Override
                    public Integer call() {
                        try {
                            totalRegret("regret-normal-1.tsv", "local-normal-1.tsv", 300, 2, 200000, new GammaNormalBayesFactory(), new NormalDistributionSampler(1, new Random()));
                            System.out.printf("normal 1\n");
                        } catch (FileNotFoundException e) {
                            e.printStackTrace();
                        }
                        return null;
                    }
                },

                new Callable<Integer>() {
                    @Override
                    public Integer call() {
                        try {
                            totalRegret("regret-normal-10x0.1.tsv", "local-normal-10x0.1.tsv", 1000, 10, 1000, new GammaNormalBayesFactory(), new NormalDistributionSampler(0.1, new Random()));
                            System.out.printf("10\n");
                        } catch (FileNotFoundException e) {
                            e.printStackTrace();
                        }
                        return null;
                    }
                },

                new Callable<Integer>() {
                    @Override
                    public Integer call() {
                        try {
                            totalRegret("regret-normal-100x0.1.tsv", "local-normal-100x0.1.tsv", 1000, 100, 1000, new GammaNormalBayesFactory(), new NormalDistributionSampler(.1, new Random()));
                            System.out.printf("100\n");
                        } catch (FileNotFoundException e) {
                            e.printStackTrace();
                        }
                        return null;
                    }
                },

                new Callable<Integer>() {
                    @Override
                    public Integer call() {
                        try {
                            totalRegret("regret.tsv", "local.tsv", 1000, 2, 1000, new BetaBayesFactory(), new BinomialDistributionSampler(1, 1, new Random()));
                            System.out.printf("2\n");
                        } catch (FileNotFoundException e) {
                            e.printStackTrace();
                        }
                        return null;
                    }
                },

                new Callable<Integer>() {
                    @Override
                    public Integer call() {
                        try {
                            totalRegret("regret-100.tsv", "local-100.tsv", 1000, 100, 1000, new BetaBayesFactory(), new BinomialDistributionSampler(1, 1, new Random()));
                            System.out.printf("100\n");
                        } catch (FileNotFoundException e) {
                            e.printStackTrace();
                        }
                        return null;
                    }
                },

                new Callable<Integer>() {
                    @Override
                    public Integer call() {
                        try {
                            totalRegret("regret-20.tsv", "local-20.tsv", 1000, 20, 1000, new BetaBayesFactory(), new BinomialDistributionSampler(1, 1, new Random()));
                            System.out.printf("20\n");
                        } catch (FileNotFoundException e) {
                            e.printStackTrace();
                        }
                        return null;
                    }
                }
        );
        ex.invokeAll(tasks);
        ex.shutdown();
        System.out.printf("All done");

//    System.out.printf("error rates\n");
//    errorRate("errors.tsv");
//    System.out.printf("commit time\n");
//    commitTime("commit.tsv", 3000, 0.1, 0.12, 2000);
//    System.out.printf("done\n");
    }

    /**
     * Records which bandit was chosen for many runs of the same scenario.  This output is kind of big
     * an hard to digest visually.  As such, it is probably better to reduce this using a mean.  In R,
     * this can be done like this:
     * <pre>
     *    plot(tapply(z$k, floor(z$i/10), mean), type='l')
     * </pre>
     *
     * @param outputFile Where to write results
     * @param n          How many steps to follow
     * @param p1         First probability of reward
     * @param p2         Second probability of reward
     * @param cutoff     Only keep results after this many steps
     * @return Average number of correct choices.
     * @throws java.io.FileNotFoundException If the directory holding the output directory doesn't exist.
     */
    public static double commitTime(String outputFile, int n, double p1, double p2, int cutoff) throws FileNotFoundException {
        try (PrintWriter out = new PrintWriter(outputFile)) {
            Random gen = new Random();
            out.printf("i\tk\n");
            int impressions = 0;
            int correct = 0;
            for (int j = 0; j < 1000; j++) {
                // pick probabilities at random
                double[] p = {
                        p1, p2
                };
                Arrays.sort(p);
                BetaBayesModel s = new BetaBayesModel();
                for (int i = 0; i < n; i++) {
                    int k = s.sample();
                    out.printf("%d\t%d\n", i, k);

                    if (i > cutoff) {
                        impressions++;
                        correct += k;
                    }

                    final double u = gen.nextDouble();
                    boolean r = u <= p[k];
                    s.train(k, r ? 1 : 0);
                }
            }
            return (double) correct / impressions;
        }
    }

    /**
     * Computes error rate (the rate at which the sub-optimal choice is made as a function of the two
     * probabilities and the number of trials.  The output report contains p1, p2, number-of-trials,
     * total-correct, total-correct-in-last-half.
     * <p/>
     * The commitTime output is probably more interesting.
     *
     * @param outputFile Where to write the data.
     * @throws java.io.FileNotFoundException If we can't open our output
     */
    @Deprecated
    private static void errorRate(String outputFile) throws FileNotFoundException {
        try (PrintWriter out = new PrintWriter(outputFile)) {
            out.printf("p1\tp2\tn\twins\tlate\n");
            Random gen = new Random();
            for (int n : new int[]{20, 50, 100, 200, 500, 1000, 2000, 5000}) {
                System.out.printf("%d\n", n);
                for (int j = 0; j < 1000 * (n < 500 ? 10 : 1); j++) {
                    // pick probabilities at random
                    double[] p = {
                            gen.nextDouble(), gen.nextDouble()
                    };
                    // order them to make error interpretation easier
                    Arrays.sort(p);
                    BetaBayesModel s = new BetaBayesModel();
                    int wins = 0;
                    int lateWins = 0;
                    for (int i = 0; i < n; i++) {
                        int k = s.sample();
                        final double u = gen.nextDouble();
                        boolean r = u <= p[k];
                        wins += r ? 1 : 0;
                        if (i > n / 2) {
                            lateWins += r ? 1 : 0;
                        }
                        s.train(k, r ? 1 : 0);
                    }
                    out.printf("%.3f\t%.3f\t%d\t%d\t%d\n", p[0], p[1], n, wins, lateWins);
                }
            }
        }
    }

    /**
     * Computes average regret relative to perfect knowledge given uniform random probabilities. The
     * output contains the quartiles for different numbers of trials.  The quartiles are computed by
     * running many experiments for each specified number of trials.
     * <p/>
     * This can be plotted pretty much directly in R
     * <pre>
     * > x=read.delim(file='~/Apache/storm-aggregator/regret.tsv')
     * > bxp(list(com.mapr.stats=t(as.matrix(x[,2:6])), n=rep(1000,times=8),names=x$n))
     * </pre>
     *
     * @param outputFile   Where to put the output
     * @param sizes        The different size experiments to use
     * @param replications Number of times to repeat the experiment
     * @param bandits      How many bandits to simulate
     * @return Returns the average regret per trial
     * @throws java.io.FileNotFoundException If the output file can't be opened due to a missing directory.
     */
    public static double averageRegret(String outputFile, int[] sizes, int replications, int bandits) throws FileNotFoundException {

        try (PrintWriter out = new PrintWriter(outputFile)) {
            double finalMedianRegret = 0;
            Random gen = new Random();
            out.printf("n\tq0\tq1\tq2\tq3\tq4\n");
            // for each horizon time span of interest
            for (int n : sizes) {
                System.out.printf("%d\n", n);
                OnlineSummarizer summary = new OnlineSummarizer();
                // replicate the test many times
                for (int j = 0; j < replications; j++) {
                    // pick probabilities at random

                    double[] p = new double[bandits];
                    for (int k = 0; k < bandits; k++) {
                        p[k] = gen.nextDouble();
                    }

                    // order them to make error interpretation easier
                    Arrays.sort(p);
                    BetaBayesModel s = new BetaBayesModel(bandits, RandomUtils.getRandom());
                    int wins = 0;
                    for (int i = 0; i < n; i++) {
                        int k = s.sample();
                        final double u = gen.nextDouble();
                        boolean r = u <= p[k];
                        wins += r ? 1 : 0;
                        s.train(k, r ? 1 : 0);
                    }
                    summary.add((double) wins / n - p[bandits - 1]);
                }
                out.printf("%d\t", n);
                for (int quartile = 0; quartile <= 4; quartile++) {
                    out.printf("%.3f%s", summary.getQuartile(quartile), quartile < 4 ? "\t" : "\n");
                }
                out.flush();
                finalMedianRegret = summary.getMedian();

                //      System.out.printf("%.3f\n", summary.getMean());
            }
            return finalMedianRegret;
        }
    }

    /**
     * Computes average regret relative to perfect knowledge given uniform random probabilities. The
     * output contains the quartiles for different numbers of trials.  The quartiles are computed by
     * running many experiments for each specified number of trials.
     * <p/>
     * This can be plotted pretty much directly in R
     * <pre>
     * > x=read.delim(file='~/Apache/storm-aggregator/regret.tsv')
     * > bxp(list(com.mapr.stats=t(as.matrix(x[,2:6])), n=rep(1000,times=8),names=x$n))
     * </pre>
     *
     * @param cumulativeOutput Where to write the cumulative regret results
     * @param perTurnOutput    Where to write the per step regret results
     * @param replications     How many times to replicate the experiment
     * @param bandits          How many bandits to emulate
     * @param maxSteps         Maximum number of trials to run per experiment
     * @param modelFactory     How to construct the solver.
     * @param refSampler       How to get reward distributions for bandits
     * @return An estimate of the average final cumulative regret
     * @throws java.io.FileNotFoundException If the output file can't be opened due to
     *                                       a missing directory.
     */
    public static double totalRegret(String cumulativeOutput, String perTurnOutput, int replications, int bandits, int maxSteps, BanditFactory modelFactory, DistributionGenerator refSampler) throws FileNotFoundException {
        List<OnlineSummarizer> cumulativeRegret = Lists.newArrayList();
        List<OnlineSummarizer> localRegret = Lists.newArrayList();
        List<Integer> steps = Lists.newArrayList();
        List<Integer> localSteps = Lists.newArrayList();

        Random gen = new Random();

        // for each horizon time span of interest
        for (int j = 0; j < replications; j++) {
            BayesianBandit s = modelFactory.createBandit(bandits, gen);

            List<DistributionWithMean> refs = Lists.newArrayList();
            for (int k = 0; k < bandits; k++) {
                refs.add(refSampler.nextDistribution());
            }

            Collections.sort(refs);

            double wins = 0;
            int k = 0;
            int delta = 1;
            double totalRegret = 0;
            for (int i = 0; i < maxSteps; i++) {
                if (i > 50 * delta) {
                    delta = bump(delta);
                }
                int choice = s.sample();
                double r = refs.get(choice).nextDouble();

                totalRegret += refs.get(bandits - 1).getMean() - refs.get(choice).getMean();
                if ((i + 1) % delta == 0) {
                    if (cumulativeRegret.size() <= k) {
                        cumulativeRegret.add(new OnlineSummarizer());
                        steps.add(i + 1);
                    }
                    cumulativeRegret.get(k).add(totalRegret);
                    k++;
                }
                if (localRegret.size() <= i / BUCKET_SIZE) {
                    localRegret.add(new OnlineSummarizer());
                    localSteps.add(i);
                }
                double thisTrialRegret = refs.get(bandits - 1).getMean() - refs.get(choice).getMean();
                localRegret.get(i / BUCKET_SIZE).add(thisTrialRegret);
                wins += r;
                s.train(choice, r);
            }
        }

        printRegret(cumulativeOutput, cumulativeRegret, steps);
        printRegret(perTurnOutput, localRegret, localSteps);
        return cumulativeRegret.get(cumulativeRegret.size() - 1).getMedian();
    }

    private static void printRegret(String outputFile, List<OnlineSummarizer> cumulativeRegret, List<Integer> steps) throws FileNotFoundException {
        try (PrintWriter out = new PrintWriter(outputFile)) {
            out.printf("n\tmean\n");
            int k = 0;
            for (OnlineSummarizer summary : cumulativeRegret) {
                out.printf("%d\t%.4f\n", steps.get(k++), summary.getMean());
//        for (int quartile = 0; quartile <= 4; quartile++) {
//          out.printf("%.3f%s", summary.getQuartile(quartile), quartile < 4 ? "\t" : "\n");
//        }
            }
            out.flush();
        }
    }

    private static int bump(int delta) {
        int multiplier = 1;
        while (delta >= 10) {
            multiplier *= 10;
            delta /= 10;
        }
        // steps each of 1,2,5 up to next level
        delta = (int) (4 * delta - delta * delta / 3 - 1.5);
        return delta * multiplier;
    }
}
TOP

Related Classes of com.mapr.stats.bandit.BanditTrainer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.