/*
* Javlov - a Java toolkit for reinforcement learning with multi-agent support.
*
* Copyright (c) 2009-2011 Matthijs Snel
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package net.javlov.policy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import net.javlov.Option;
import net.javlov.Policy;
import net.javlov.QFunction;
import net.javlov.State;
import net.javlov.util.ArrayUtil;
/**
* Epsilon-greedy (e-greedy) policy. Chooses action with maximum Q-value with probability
* 1 - e, and a (uniform) random action with probability e. This means that the greedy action
* will be chosen with probability 1 - e + e/(nr of actions).
*
* @author Matthijs Snel
*
*/
public class EGreedyPolicy implements Policy {

    /**
     * The Q-value function.
     */
    private QFunction q;

    /**
     * Epsilon: the probability of picking a uniformly random option instead of a greedy one.
     */
    protected double e;

    /**
     * List of allowed options. Index of options in the list should correspond to their
     * ID.
     */
    protected List<? extends Option> optionPool;

    /**
     * Random number generator used for the explore/exploit draw, for sampling random
     * options, and for breaking ties between greedy options.
     */
    private final Random rng;

    /**
     * Creates an epsilon-greedy policy with the specified Q-function, epsilon, and pool
     * of allowed options, using a freshly created {@link Random} as random number generator.
     *
     * @param q the Q-value function.
     * @param epsilon option with maximum Q-value is chosen with probability
     * 1 - e, and a (uniform) random option with probability e.
     * @param options list of allowed options. Index of options in the list should correspond to their
     * ID.
     */
    public EGreedyPolicy(QFunction q, double epsilon, List<? extends Option> options) {
        this(q, epsilon, options, new Random());
    }

    /**
     * Creates an epsilon-greedy policy with the specified Q-function, epsilon, pool
     * of allowed options, and random number generator.
     *
     * @param q the Q-value function.
     * @param epsilon option with maximum Q-value is chosen with probability
     * 1 - e, and a (uniform) random option with probability e.
     * @param options list of allowed options. Index of options in the list should correspond to their
     * ID.
     * @param rng random number generator used to pick options. It is not checked here; a
     * {@code null} generator fails with a {@code NullPointerException} as soon as an option
     * is sampled.
     */
    public EGreedyPolicy(QFunction q, double epsilon, List<? extends Option> options, Random rng) {
        setQFunction(q);
        setEpsilon(epsilon);
        optionPool = options;
        this.rng = rng;
    }

    /**
     * Sets the Q-value function that {@link #getOption(State)} queries.
     *
     * @param q the Q-value function.
     */
    public void setQFunction(QFunction q) {
        this.q = q;
    }

    /**
     * @return the Q-value function currently used by this policy.
     */
    public QFunction getQFunction() {
        return q;
    }

    /**
     * @return the current value of epsilon.
     */
    public double getEpsilon() {
        return e;
    }

    /**
     * Sets epsilon. No range check is performed; the probabilities returned by
     * {@link #getOptionProbabilities(State, double[])} are only meaningful for values
     * in {@code [0, 1]}.
     *
     * @param epsilon the probability of choosing a uniformly random option.
     */
    public void setEpsilon(double epsilon) {
        e = epsilon;
    }

    /**
     * Chooses option with maximum Q-value (greedy option) with probability
     * 1 - e, and a (uniformly distributed) random option with probability e.
     * This means that the greedy option will be chosen with probability
     * {@code 1 - e + e/(nr of options)}. If there is more than one greedy option, ties
     * are broken randomly.
     *
     * @param s the state based on which to choose the option. The Q-value function will
     * be queried first to determine the Q-values of the options for this state, after which
     * the option will be determined by calling {@link #getOption(State, double[])}.
     * @return an {@code Option} chosen according to the rule as specified above.
     */
    @Override
    public <T> Option getOption(State<T> s) {
        return getOption(s, q.getValues(s));
    }

    /**
     * Collects the options from the pool that report themselves eligible in the given state.
     *
     * @param s the state to test eligibility against.
     * @return a new list holding the eligible options, in pool order.
     * @throws RuntimeException if no option in the pool is eligible for {@code s}.
     */
    //TODO inefficient implementation of determining this
    protected <T> List<Option> getStateOptionSet(State<T> s) {
        List<Option> stateOptionSet = new ArrayList<Option>(optionPool.size());
        for (Option o : optionPool) {
            if (o.isEligible(s)) {
                stateOptionSet.add(o);
            }
        }
        if (stateOptionSet.isEmpty()) {
            throw new RuntimeException("No eligible options for state: " + s);
        }
        return stateOptionSet;
    }

    /**
     * Chooses an option for state {@code s} given precomputed Q-values: a greedy option
     * with probability 1 - e (ties broken randomly), otherwise a uniformly random one.
     * Only options eligible in {@code s} are considered.
     *
     * @param s the state based on which to choose the option.
     * @param qvalues the Q-values, indexed by option ID.
     * @return the chosen option.
     * @throws RuntimeException if no maximum-valued option can be determined from
     * {@code qvalues}, or if no option is eligible for {@code s}.
     */
    @Override
    public <T> Option getOption(State<T> s, double[] qvalues) {
        List<Option> stateOptionSet = getStateOptionSet(s);

        // nextDouble() lies in [0, 1) and may be exactly 0, so ">=" (rather than ">") is
        // needed for e == 0 to be strictly greedy; e == 1 still never takes the greedy path.
        boolean greedy = rng.nextDouble() >= e;

        // A null set (possible from an overriding subclass) or a set as large as the
        // Q-value array is treated as "every option is eligible": pool index == option ID.
        if (stateOptionSet == null || stateOptionSet.size() == qvalues.length) {
            if (!greedy) {
                return optionPool.get(rng.nextInt(qvalues.length));
            }
            // indices of options with max Q-value (expected to be 1 or more)
            int[] maxIdx = ArrayUtil.multimaxIndex(qvalues);
            if (maxIdx.length < 1) {
                String setSize = (stateOptionSet == null ? "null" : String.valueOf(stateOptionSet.size()));
                throw new RuntimeException("Impossible: " + setSize + "," + Arrays.toString(qvalues));
            }
            return optionPool.get(maxIdx[rng.nextInt(maxIdx.length)]);
        }

        // Only a subset of the pool is eligible: work on that subset, looking Q-values up by ID.
        if (!greedy) {
            return stateOptionSet.get(rng.nextInt(stateOptionSet.size()));
        }
        List<Option> maxOpts = getMaxOpts(stateOptionSet, qvalues);
        if (maxOpts.isEmpty()) {
            // Same exception type Random.nextInt(0) would raise, but with a useful message.
            throw new IllegalArgumentException("No maximum-valued option among " + stateOptionSet
                    + " for Q-values " + Arrays.toString(qvalues));
        }
        // No random draw is made when the greedy option is unique.
        return (maxOpts.size() == 1 ? maxOpts.get(0) : maxOpts.get(rng.nextInt(maxOpts.size())));
    }

    /**
     * Computes the probability with which each option would be selected in state {@code s}
     * under this policy. Each eligible option receives {@code e / (nr of eligible options)};
     * options with maximum Q-value additionally share {@code 1 - e} equally. Entries of
     * options that are not eligible in {@code s} are left at 0.
     *
     * @param s the state for which to compute the selection probabilities.
     * @param qvalues the Q-values, indexed by option ID.
     * @return an array of the same length as {@code qvalues}, indexed by option ID.
     */
    @Override
    public <T> double[] getOptionProbabilities(State<T> s, double[] qvalues) {
        List<? extends Option> stateOptionSet = getStateOptionSet(s);
        double[] probs = new double[qvalues.length];

        if (stateOptionSet == null || stateOptionSet.size() == qvalues.length) {
            int[] maxIdx = ArrayUtil.multimaxIndex(qvalues);
            double otherProb = e / qvalues.length;
            double maxProb = (1 - e) / maxIdx.length + otherProb;
            // Baseline for everyone first, then overwrite the maxima.
            Arrays.fill(probs, otherProb);
            for (int idx : maxIdx) {
                probs[idx] = maxProb;
            }
        } else {
            List<Option> maxOpts = getMaxOpts(stateOptionSet, qvalues);
            double otherProb = e / stateOptionSet.size();
            double maxProb = (1 - e) / maxOpts.size() + otherProb;
            // Two passes instead of removeAll(): the eligible list is left untouched.
            for (Option opt : stateOptionSet) {
                probs[opt.getID()] = otherProb;
            }
            for (Option opt : maxOpts) {
                probs[opt.getID()] = maxProb;
            }
        }
        return probs;
    }

    /**
     * Finds the options in {@code stateOptionSet} whose Q-value is maximal.
     *
     * @param stateOptionSet the options to consider.
     * @param qvalues the Q-values, indexed by option ID.
     * @return a new list with all options that share the maximum Q-value, in the order in
     * which they appear in {@code stateOptionSet}. The list is empty if no value compares
     * greater than or equal to negative infinity (e.g. all considered Q-values are NaN).
     */
    protected List<Option> getMaxOpts(List<? extends Option> stateOptionSet, double[] qvalues) {
        List<Option> maxOpts = new ArrayList<Option>();
        double maxVal = Double.NEGATIVE_INFINITY;
        for (Option o : stateOptionSet) {
            double val = qvalues[o.getID()];
            if (val > maxVal) {
                // strictly better: discard the previous candidates
                maxVal = val;
                maxOpts.clear();
                maxOpts.add(o);
            } else if (val == maxVal) {
                maxOpts.add(o);
            }
        }
        return maxOpts;
    }

    /**
     * Calls {@code init()} on every option in the pool.
     */
    @Override
    public void init() {
        for (Option o : optionPool) {
            o.init();
        }
    }

    /**
     * Calls {@code reset()} on every option in the pool.
     */
    @Override
    public void reset() {
        for (Option o : optionPool) {
            o.reset();
        }
    }
}