Package com.mapr.bandit

Source Code of com.mapr.bandit.BanditRanking

package com.mapr.bandit;

import com.google.common.collect.Lists;
import com.mapr.stats.bandit.BanditFactory;
import com.mapr.stats.bandit.BayesianBandit;
import com.mapr.stats.bandit.BetaBayesFactory;
import com.mapr.stats.bandit.GammaNormalBayesFactory;
import com.mapr.stats.random.AbstractBayesianDistribution;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.jet.random.Uniform;

import java.io.FileNotFoundException;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Random;

/**
* Implements a Bandit ranking.
*/
public class BanditRanking {
    private static Random gen = RandomUtils.getRandom();

    public static void main(String[] args) throws FileNotFoundException {
        int keyItems = 10, pageSize = 20, totalItems = 200;

        if (args.length > 0) {
            keyItems = Integer.parseInt(args[0]);
        }

        if (args.length > 1) {
            pageSize = Integer.parseInt(args[1]);
        }

        if (args.length > 2) {
            totalItems = Integer.parseInt(args[2]);
        }

        BanditFactory bf = new BetaBayesFactory();
        if (args.length > 3) {
            if (args[3].startsWith("beta")) {
                bf = new BetaBayesFactory();
            } else if (args[3].startsWith("gamma")) {
                bf = new GammaNormalBayesFactory();
            } else {
                throw new IllegalArgumentException("Wanted beta or gamma to specify distribution");
            }
        }

        List<BayesianBandit> bandit = Lists.newArrayList();
        for (int m = 0; m < 50; m++) {
            bandit.add(bf.createBandit(totalItems, gen));
        }

        double[] prob = new double[totalItems];
        Uniform u = new Uniform(gen);
        for (int j = 0; j < totalItems; j++) {
            prob[j] = u.nextDouble();
        }
        Arrays.sort(prob);
        for (int j = 0; j < totalItems; j++) {
            prob[j] = 1 - prob[j];
        }

        double cumulativeRegret = 0;
        PrintWriter quality = new PrintWriter("quality.csv");
        quality.printf("Trials,Precision,Regret,CumulativeRegret\n");
        for (int i = 0; i < 1000; i++) {

            double precision = 0;
            double regret = 0;
            for (int m = 0; m < 50; m++) {
                List<Integer> page = bandit.get(m).rank(pageSize);
                for (Integer item : page) {
                    if (item < keyItems) {
                        precision++;
                    }
                }

                for (int j = 0; j < pageSize; j++) {
                    int k = page.get(j);
                    regret += prob[j] - prob[k];
                }

                for (int j = 0; j < pageSize; j++) {
                    int k = page.get(j);
                    int reward = u.nextDouble() < prob[k] ? 1 : 0;
                    bandit.get(m).train(k, reward);
                }
            }
            precision /= keyItems * 50.0;
            regret /= 50;
            cumulativeRegret += regret;

            quality.printf("%d,%.1f,%.3f,%.3f\n", i + 1, precision * 100, regret, cumulativeRegret);
        }
        quality.close();

        // display samples per rank
        PrintWriter samples = new PrintWriter("samples.csv");
        samples.printf("Rank,Samples\n");
        for (int m = 0; m < 10; m++) {
            int i = 0;
            for (AbstractBayesianDistribution distribution : bandit.get(m)) {
                samples.printf("%d,%.1f\n", i++, distribution.getSamples());
            }
        }
        samples.close();
    }
}
TOP

Related Classes of com.mapr.bandit.BanditRanking

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., which is owned by ORACLE Inc. Contact coftware#gmail.com.