Package com.mapr.stats.bandit

Source Code of com.mapr.stats.bandit.ContextualBayesBandit$InverseLogisticFunction

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.mapr.stats.bandit;

import com.mapr.stats.random.BetaDistribution;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.VectorFunction;


/**
* Solves the contextual bandit problem using Bayesian sampling.
*/
public class ContextualBayesBandit {
    private final Matrix featureMap;
    private final Matrix state;
    private final int m;
    private final BetaDistribution rand;

    public ContextualBayesBandit(Matrix featureMap) {
        this(featureMap, 1, 1);
    }

    public ContextualBayesBandit(Matrix featureMap, double alpha_0, double beta_0) {
        this.featureMap = featureMap;
        m = featureMap.numCols();
        this.state = new DenseMatrix(m, 2);
        this.state.viewColumn(0).assign(alpha_0);
        this.state.viewColumn(1).assign(beta_0);
        this.rand = new BetaDistribution(1, 1);
    }

    public Vector samplePi() {
        return sampleNoLink().assign(new LogisticFunction());
    }

    public int sample() {
        final Vector pi = sampleNoLink();
        return pi.maxValueIndex();
    }

    private Vector sampleNoLink() {
        final Vector theta = state.aggregateRows(new VectorFunction() {
            final DoubleFunction inverseLink = new InverseLogisticFunction();

            @Override
            public double apply(Vector f) {
                return inverseLink.apply(rand.nextDouble(f.get(0), f.get(1)));
            }
        });
        return featureMap.times(theta);
    }

    public void train(int bandit, boolean success) {
        state.viewColumn(success ? 0 : 1).assign(featureMap.viewRow(bandit), Functions.plusMult(1.0 / m));
    }

    public class LogisticFunction implements DoubleFunction {
        @Override
        public double apply(double x) {
            return 1 / (1 + Math.exp(-x));
        }
    }

    public class InverseLogisticFunction implements DoubleFunction {
        @Override
        public double apply(double p) {
            return Math.log(p / (1 - p));
        }
    }
}
TOP

Related Classes of com.mapr.stats.bandit.ContextualBayesBandit$InverseLogisticFunction

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.