/**
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*/
package org.cspoker.ai.bots.bot.gametree.rollout;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;
import org.apache.log4j.Logger;
import org.cspoker.ai.opponentmodels.OpponentModel;
import org.cspoker.client.common.gamestate.GameState;
import org.cspoker.client.common.playerstate.PlayerState;
import org.cspoker.common.elements.cards.Card;
import org.cspoker.common.elements.player.PlayerId;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Multiset;
import com.google.common.collect.TreeMultiset;
import com.google.common.collect.ImmutableMap.Builder;
public class BucketRollOut extends RollOutStrategy {
private final static Logger logger = Logger.getLogger(BucketRollOut.class);
private final OpponentModel model;
private final Map<PlayerId, double[]> bucketProbs;
private final static int nbBuckets = 6;
private final static int nbSamplesPerBucket = 6;
public BucketRollOut(GameState gameState, PlayerId botId, OpponentModel model) {
super(gameState, botId);
this.model = model;
Builder<PlayerId, double[]> builder = new ImmutableMap.Builder<PlayerId, double[]>();
for (PlayerState opponentThatCanWin : activeOpponents) {
PlayerId playerId = opponentThatCanWin.getPlayerId();
double[] bucketProbs = model.getShowdownProbabilities(gameState, playerId);
builder.put(playerId, bucketProbs);
}
bucketProbs = builder.build();
}
//TODO optimize
public double doRollOut(int nbCommunitySamples) {
boolean traceEnabled = logger.isTraceEnabled();
double totalEV = 0;
model.assumeTemporarily(gameState);
for (int i = 0; i < nbCommunitySamples; i++) {
int communitySampleRank = fixedRank;
EnumSet<Card> usedCommunityAndBotCards = EnumSet.copyOf(usedFixedCommunityAndBotCards);
EnumSet<Card> usedCommunityCards = EnumSet.copyOf(usedFixedCommunityCards);
for (int j = 0; j < nbMissingCommunityCards; j++) {
Card communityCard = drawNewCard(usedCommunityAndBotCards);
if (traceEnabled) {
logger.trace("Evaluating sampled community card "
+ communityCard);
}
usedCommunityCards.add(communityCard);
communitySampleRank = updateIntermediateRank(
communitySampleRank, communityCard);
}
if (traceEnabled) {
logger.trace("Evaluating bot cards " + botCard1 + " "
+ botCard2);
}
int botRank = getFinalRank(communitySampleRank, botCard1, botCard2);
// int minSampleRank = Integer.MAX_VALUE;
// int maxSampleRank = Integer.MIN_VALUE;
// int sum = 0;
Multiset<Integer> ranks = new TreeMultiset<Integer>();
Multiset<Integer> deadRanks = new TreeMultiset<Integer>();
int n = 100;
for (int j = 0; j < n; j++) {
EnumSet<Card> handCards = EnumSet.copyOf(usedCommunityCards);
Card sampleCard1 = drawNewCard(handCards);
Card sampleCard2 = drawNewCard(handCards);
int sampleRank = getFinalRank(communitySampleRank, sampleCard1, sampleCard2);
ranks.add(sampleRank);
if(botCard1.equals(sampleCard1) || botCard1.equals(sampleCard2)
|| botCard2.equals(sampleCard1) || botCard2.equals(sampleCard2)){
deadRanks.add(sampleRank);
}
// if(sampleRank<minSampleRank){
// minSampleRank = sampleRank;
// }
// if(sampleRank>maxSampleRank){
// maxSampleRank = sampleRank;
// }
// sum += sampleRank;
}
// double mean = ((double)sum)/n;
// double var = calcVariance(ranks, mean);
// int averageSampleRank = (int) Math.round(mean);
// int sigmaSampleRank = (int) Math.round(Math.sqrt(var));
WinDistribution[] winProbs = calcWinDistributions(botRank, ranks, deadRanks);
double[] deadCardWeights = calcDeadCardWeights(ranks, deadRanks);
TreeMap<PlayerState, WinDistribution> winDistributions = calcOpponentWinDistributionMap(
winProbs, deadCardWeights);
int maxDistributed = 0;
int botInvestment = botState.getTotalInvestment();
double sampleEV = 0;
for (Iterator<PlayerState> iter = winDistributions.keySet().iterator(); iter.hasNext();) {
PlayerState opponent = iter.next();
int toDistribute = Math.min(botInvestment, opponent.getTotalInvestment())-maxDistributed;
if(toDistribute>0){
double pWin = 1;
double pNotLose = 1;
for (WinDistribution distribution : winDistributions.values()) {
//you win when you win from every opponent
pWin *= distribution.pWin;
//you don't lose when you don't lose from every opponent
pNotLose *= distribution.pWin+distribution.pDraw;
}
sampleEV += toDistribute*pWin;
//you draw when you don't lose but don't win everything either;
double pDraw = pNotLose - pWin;
// assume worst case, with winDistributions.size()+1 drawers
//TODO do this better, use rollout or statistics!
sampleEV += pDraw * toDistribute / (winDistributions.size()+1.0);
maxDistributed += toDistribute;
}
iter.remove();
}
//get back uncalled investment
sampleEV += botInvestment-maxDistributed;
totalEV += sampleEV;
}
model.forgetLastAssumption();
return (1-gameState.getTableConfiguration().getRake())*(totalEV/nbCommunitySamples);
}
private TreeMap<PlayerState, WinDistribution> calcOpponentWinDistributionMap(
WinDistribution[] winProbs, double[] deadCardWeights) {
TreeMap<PlayerState,WinDistribution> winDistributions= new TreeMap<PlayerState,WinDistribution>(playerComparatorByInvestment);
for (PlayerState opponentThatCanWin : activeOpponents) {
double[] bucketProb = bucketProbs.get(opponentThatCanWin.getPlayerId());
bucketProb = normalize(multiply(deadCardWeights,bucketProb));
winDistributions.put(opponentThatCanWin, calcOpponentWinDistr(winProbs, bucketProb));
}
return winDistributions;
}
private double[] multiply(double[] a, double[] b) {
double[] c = new double[a.length];
for(int i=0;i<a.length;i++) c[i] = a[i]* b[i];
return c;
}
private double[] normalize(double[] a){
double[] c = new double[a.length];
double sum = 0;
for(int i=0;i<a.length;i++) sum += a[i];
if(Double.isNaN(sum) || sum==0 || Double.isInfinite(sum)){
throw new IllegalStateException("Bad probabilities:"+sum+" = "+a);
}
double invSum = 1/sum;
for (int i = 0; i < a.length; i++) {
c[i] = a[i]*invSum;
}
return c;
}
private WinDistribution calcOpponentWinDistr(WinDistribution[] winProbs,
double[] bucketProbs) {
WinDistribution winDistr;
double pWin=0, pDraw=0, pLose=0;
for (int j = 0; j < bucketProbs.length; j++) {
pWin += winProbs[j].pWin*bucketProbs[j];
pDraw += winProbs[j].pDraw*bucketProbs[j];
pLose += winProbs[j].pLose*bucketProbs[j];
}
winDistr = new WinDistribution(pWin,pDraw,pLose);
return winDistr;
}
private WinDistribution[] calcWinDistributions(int botRank,
Multiset<Integer> ranks, Multiset<Integer> deadRanks) {
Iterator<Integer> iter = ranks.iterator();
WinDistribution[] winProbs = new WinDistribution[10];
for(int bucket=0;bucket<nbBuckets;bucket++){
double winWeight = 0;
double drawWeight = 0;
double loseWeight = 0;
for (int j = 0; j < nbSamplesPerBucket; j++) {
int rank = iter.next();
double weight = 1-deadRanks.count(rank)/ranks.count(rank);
if(rank<botRank){
winWeight+=weight;
}else if(rank>botRank){
loseWeight+=weight;
}else{
drawWeight+=weight;
}
}
double nbSamples = winWeight+drawWeight+loseWeight;
if(nbSamples==0) nbSamples = 1;
winProbs[bucket] = new WinDistribution(winWeight/nbSamples, drawWeight/nbSamples, loseWeight/nbSamples);
}
return winProbs;
}
public static class WinDistribution{
//from the perspective of the bot
public final double pWin, pDraw, pLose;
public WinDistribution(double pWin, double pDraw, double pLose) {
this.pWin = pWin;
this.pDraw = pDraw;
this.pLose = pLose;
}
@Override
public String toString() {
return pWin+"/"+pDraw+"/"+pLose;
}
}
private double[] calcDeadCardWeights(Multiset<Integer> ranks,
Multiset<Integer> deadRanks) {
Iterator<Integer> iter = ranks.iterator();
double[] deadCardWeights = new double[nbBuckets];
for(int bucket=0;bucket<nbBuckets;bucket++){
double nbDead = 0;
for (int j = 0; j < nbSamplesPerBucket; j++) {
int rank = iter.next();
double count = ranks.count(rank);
double deadCount = deadRanks.count(rank);
nbDead += deadCount/count;
}
deadCardWeights[bucket] = ((nbSamplesPerBucket-nbDead)/nbSamplesPerBucket);
}
return deadCardWeights;
}
// private double calcVariance(Multiset<Integer> ranks, double mean) {
// double var = 0;
// for (Multiset.Entry<Integer> entry : ranks.entrySet()) {
// double diff = mean - entry.getElement();
// var += diff * diff * entry.getCount();
// }
// var /= (ranks.size()-1);
// return var;
// }
}