Package org.cspoker.ai.bots.bot.gametree.rollout

Source Code of org.cspoker.ai.bots.bot.gametree.rollout.BucketRollOut$WinDistribution

/**
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
*  This program is distributed in the hope that it will be useful,
*  but WITHOUT ANY WARRANTY; without even the implied warranty of
*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
*  GNU General Public License for more details.
*  You should have received a copy of the GNU General Public License
*  along with this program; if not, write to the Free Software
*  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*/
package org.cspoker.ai.bots.bot.gametree.rollout;

import java.util.Arrays;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;

import org.apache.log4j.Logger;
import org.cspoker.ai.opponentmodels.OpponentModel;
import org.cspoker.client.common.gamestate.GameState;
import org.cspoker.client.common.playerstate.PlayerState;
import org.cspoker.common.elements.cards.Card;
import org.cspoker.common.elements.player.PlayerId;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMap.Builder;
import com.google.common.collect.Multiset;
import com.google.common.collect.TreeMultiset;

public class BucketRollOut extends RollOutStrategy {

  private final static Logger logger = Logger.getLogger(BucketRollOut.class);

  private final OpponentModel model;

  private final Map<PlayerId, double[]> bucketProbs;

  private final static int nbBuckets = 6;
  private final static int nbSamplesPerBucket = 6;
 

  public BucketRollOut(GameState gameState, PlayerId botId, OpponentModel model) {
    super(gameState, botId);
    this.model = model;
    Builder<PlayerId, double[]> builder = new ImmutableMap.Builder<PlayerId, double[]>();
    for (PlayerState opponentThatCanWin : activeOpponents) {
      PlayerId playerId = opponentThatCanWin.getPlayerId();
      double[] bucketProbs = model.getShowdownProbabilities(gameState, playerId);
      builder.put(playerId, bucketProbs);
    }
    bucketProbs = builder.build();
  }

  //TODO optimize
  public double doRollOut(int nbCommunitySamples) {
    boolean traceEnabled = logger.isTraceEnabled();
    double totalEV = 0;
    model.assumeTemporarily(gameState);
    for (int i = 0; i < nbCommunitySamples; i++) {
      int communitySampleRank = fixedRank;
      EnumSet<Card> usedCommunityAndBotCards = EnumSet.copyOf(usedFixedCommunityAndBotCards);
      EnumSet<Card> usedCommunityCards = EnumSet.copyOf(usedFixedCommunityCards);
      for (int j = 0; j < nbMissingCommunityCards; j++) {
        Card communityCard = drawNewCard(usedCommunityAndBotCards);
        if (traceEnabled) {
          logger.trace("Evaluating sampled community card "
              + communityCard);
        }
        usedCommunityCards.add(communityCard);
        communitySampleRank = updateIntermediateRank(
            communitySampleRank, communityCard);
      }
      if (traceEnabled) {
        logger.trace("Evaluating bot cards " + botCard1 + " "
            + botCard2);
      }
      int botRank = getFinalRank(communitySampleRank, botCard1, botCard2);

      //      int minSampleRank = Integer.MAX_VALUE;
      //      int maxSampleRank = Integer.MIN_VALUE;
      //      int sum = 0;
      Multiset<Integer> ranks = new TreeMultiset<Integer>();
      Multiset<Integer> deadRanks = new TreeMultiset<Integer>();
      int n = 100;
      for (int j = 0; j < n; j++) {
        EnumSet<Card> handCards = EnumSet.copyOf(usedCommunityCards);
        Card sampleCard1 = drawNewCard(handCards);
        Card sampleCard2 = drawNewCard(handCards);
        int sampleRank = getFinalRank(communitySampleRank, sampleCard1, sampleCard2);
        ranks.add(sampleRank);
        if(botCard1.equals(sampleCard1) || botCard1.equals(sampleCard2)
            || botCard2.equals(sampleCard1) || botCard2.equals(sampleCard2)){
          deadRanks.add(sampleRank);
       
        //        if(sampleRank<minSampleRank){
        //          minSampleRank = sampleRank;
        //        }
        //        if(sampleRank>maxSampleRank){
        //          maxSampleRank = sampleRank;
        //        }
        //        sum += sampleRank;
      }
      //      double mean = ((double)sum)/n;
      //      double var = calcVariance(ranks, mean);
      //      int averageSampleRank = (int) Math.round(mean);
      //      int sigmaSampleRank = (int) Math.round(Math.sqrt(var));

      WinDistribution[] winProbs = calcWinDistributions(botRank, ranks, deadRanks);
      double[] deadCardWeights = calcDeadCardWeights(ranks, deadRanks);

      TreeMap<PlayerState, WinDistribution> winDistributions = calcOpponentWinDistributionMap(
          winProbs, deadCardWeights);

      int maxDistributed = 0;
      int botInvestment = botState.getTotalInvestment();
      double sampleEV = 0;
      for (Iterator<PlayerState> iter = winDistributions.keySet().iterator(); iter.hasNext();) {
        PlayerState opponent = iter.next();
        int toDistribute = Math.min(botInvestment, opponent.getTotalInvestment())-maxDistributed;
        if(toDistribute>0){
          double pWin = 1;
          double pNotLose = 1;
          for (WinDistribution distribution : winDistributions.values()) {
            //you win when you win from every opponent
            pWin *= distribution.pWin;
            //you don't lose when you don't lose from every opponent
            pNotLose *= distribution.pWin+distribution.pDraw;
          }
          sampleEV += toDistribute*pWin;
          //you draw when you don't lose but don't win everything either;
          double pDraw = pNotLose - pWin;
          // assume worst case, with winDistributions.size()+1 drawers
          //TODO do this better, use rollout or statistics!
          sampleEV += pDraw * toDistribute / (winDistributions.size()+1.0);
          maxDistributed += toDistribute;
        }
        iter.remove();
      }
      //get back uncalled investment
      sampleEV += botInvestment-maxDistributed;
      totalEV += sampleEV;
    }
    model.forgetLastAssumption();
    return (1-gameState.getTableConfiguration().getRake())*(totalEV/nbCommunitySamples);
  }

  private TreeMap<PlayerState, WinDistribution> calcOpponentWinDistributionMap(
      WinDistribution[] winProbs, double[] deadCardWeights) {
    TreeMap<PlayerState,WinDistribution> winDistributions= new TreeMap<PlayerState,WinDistribution>(playerComparatorByInvestment);
    for (PlayerState opponentThatCanWin : activeOpponents) {
      double[] bucketProb = bucketProbs.get(opponentThatCanWin.getPlayerId());
      bucketProb = normalize(multiply(deadCardWeights,bucketProb));
      winDistributions.put(opponentThatCanWin, calcOpponentWinDistr(winProbs, bucketProb));
    }
    return winDistributions;
  }

  private double[] multiply(double[] a, double[] b) {
    double[] c = new double[a.length];
    for(int i=0;i<a.length;i++) c[i] = a[i]* b[i];
    return c;
  }

  private double[] normalize(double[] a){
    double[] c = new double[a.length];
    double sum = 0;
    for(int i=0;i<a.length;i++) sum += a[i];
    if(Double.isNaN(sum) || sum==0 || Double.isInfinite(sum)){
      throw new IllegalStateException("Bad probabilities:"+sum+" = "+a);
    }
    double invSum = 1/sum;
    for (int i = 0; i < a.length; i++) {
      c[i] =  a[i]*invSum;
    }
    return c;
  }

  private WinDistribution calcOpponentWinDistr(WinDistribution[] winProbs,
      double[] bucketProbs) {
    WinDistribution winDistr;
    double pWin=0, pDraw=0, pLose=0;
    for (int j = 0; j < bucketProbs.length; j++) {
      pWin += winProbs[j].pWin*bucketProbs[j];
      pDraw += winProbs[j].pDraw*bucketProbs[j];
      pLose += winProbs[j].pLose*bucketProbs[j];
    }
    winDistr = new WinDistribution(pWin,pDraw,pLose);
    return winDistr;
  }

  private WinDistribution[] calcWinDistributions(int botRank,
      Multiset<Integer> ranks, Multiset<Integer> deadRanks) {
    Iterator<Integer> iter = ranks.iterator();
    WinDistribution[] winProbs = new WinDistribution[10];
    for(int bucket=0;bucket<nbBuckets;bucket++){
      double winWeight = 0;
      double drawWeight = 0;
      double loseWeight = 0;
      for (int j = 0; j < nbSamplesPerBucket; j++) {
        int rank = iter.next();
        double weight = 1-deadRanks.count(rank)/ranks.count(rank);
        if(rank<botRank){
          winWeight+=weight;
        }else if(rank>botRank){
          loseWeight+=weight;
        }else{
          drawWeight+=weight;
        }
      }
      double nbSamples = winWeight+drawWeight+loseWeight;
      if(nbSamples==0) nbSamples = 1;
      winProbs[bucket] = new WinDistribution(winWeight/nbSamples, drawWeight/nbSamples, loseWeight/nbSamples);
    }
    return winProbs;
  }

  public static class WinDistribution{

    //from the perspective of the bot
    public final double pWin, pDraw, pLose;

    public WinDistribution(double pWin, double pDraw, double pLose) {
      this.pWin = pWin;
      this.pDraw = pDraw;
      this.pLose = pLose;
    }

    @Override
    public String toString() {
      return pWin+"/"+pDraw+"/"+pLose;
    }

  }

  private double[] calcDeadCardWeights(Multiset<Integer> ranks,
      Multiset<Integer> deadRanks) {
    Iterator<Integer> iter = ranks.iterator();
    double[] deadCardWeights = new double[nbBuckets];
    for(int bucket=0;bucket<nbBuckets;bucket++){
      double nbDead = 0;
      for (int j = 0; j < nbSamplesPerBucket; j++) {
        int rank = iter.next();
        double count = ranks.count(rank);
        double deadCount = deadRanks.count(rank);
        nbDead += deadCount/count;
      }
      deadCardWeights[bucket] = ((nbSamplesPerBucket-nbDead)/nbSamplesPerBucket);
    }
    return deadCardWeights;
  }

//  private double calcVariance(Multiset<Integer> ranks, double mean) {
//    double var = 0;
//    for (Multiset.Entry<Integer> entry : ranks.entrySet()) {
//      double diff = mean - entry.getElement();
//      var += diff * diff * entry.getCount();
//    }
//    var /= (ranks.size()-1);
//    return var;
//  }

}
TOP

Related Classes of org.cspoker.ai.bots.bot.gametree.rollout.BucketRollOut$WinDistribution

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.