Package mage.player.ai

Source Code of mage.player.ai.MCTSNode

/*
*  Copyright 2011 BetaSteward_at_googlemail.com. All rights reserved.
*
*  Redistribution and use in source and binary forms, with or without modification, are
*  permitted provided that the following conditions are met:
*
*     1. Redistributions of source code must retain the above copyright notice, this list of
*        conditions and the following disclaimer.
*
*     2. Redistributions in binary form must reproduce the above copyright notice, this list
*        of conditions and the following disclaimer in the documentation and/or other materials
*        provided with the distribution.
*
*  THIS SOFTWARE IS PROVIDED BY BetaSteward_at_googlemail.com ``AS IS'' AND ANY EXPRESS OR IMPLIED
*  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
*  FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL BetaSteward_at_googlemail.com OR
*  CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
*  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
*  ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
*  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
*  ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*  The views and conclusions contained in the software and documentation are those of the
*  authors and should not be interpreted as representing official policies, either expressed
*  or implied, of BetaSteward_at_googlemail.com.
*/
package mage.player.ai;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import mage.constants.PhaseStep;
import mage.constants.Zone;
import mage.abilities.Ability;
import mage.abilities.ActivatedAbility;
import mage.abilities.PlayLandAbility;
import mage.abilities.common.PassAbility;
import mage.cards.Card;
import mage.game.Game;
import mage.game.combat.Combat;
import mage.game.combat.CombatGroup;
import mage.game.turn.Step.StepPart;
import mage.players.Player;
import org.apache.log4j.Logger;

/**
*
* @author BetaSteward_at_googlemail.com
*/
public class MCTSNode {

    private static final double selectionCoefficient = 1.0;
    private static final double passRatioTolerance = 0.0;
     private static final transient Logger logger = Logger.getLogger(MCTSNode.class);

    private int visits = 0;
    private int wins = 0;
    private MCTSNode parent;
    private final List<MCTSNode> children = new ArrayList<MCTSNode>();
    private Ability action;
    private Game game;
    private Combat combat;
    private final String stateValue;
    private UUID playerId;
    private boolean terminal = false;

    private static int nodeCount;

    public MCTSNode(Game game) {
        this.game = game;
        this.stateValue = game.getState().getValue(false, game);
        this.terminal = game.gameOver(null);
        setPlayer();
        nodeCount = 1;
    }   

    protected MCTSNode(MCTSNode parent, Game game, Ability action) {
        this.game = game;
        this.stateValue = game.getState().getValue(false, game);
        this.terminal = game.gameOver(null);
        this.parent = parent;
        this.action = action;
        setPlayer();
        nodeCount++;
    }

    protected MCTSNode(MCTSNode parent, Game game, Combat combat) {
        this.game = game;
        this.combat = combat;
        this.stateValue = game.getState().getValue(false, game);
        this.terminal = game.gameOver(null);
        this.parent = parent;
        setPlayer();
        nodeCount++;
    }

    private void setPlayer() {
        if (game.getStep().getStepPart() == StepPart.PRIORITY) {
            playerId = game.getPriorityPlayerId();
        } else {
            if (game.getStep().getType() == PhaseStep.DECLARE_BLOCKERS) {
                playerId = game.getCombat().getDefenders().iterator().next();
            } else {
                playerId = game.getActivePlayerId();
            }
        }
    }

    public MCTSNode select(UUID targetPlayerId) {
        double bestValue = Double.NEGATIVE_INFINITY;
        boolean isTarget = playerId.equals(targetPlayerId);
        MCTSNode bestChild = null;
        if (children.size() == 1) {
            return children.get(0);
        }
        for (MCTSNode node: children) {
            double uct;
            if (node.visits > 0)
                if (isTarget)
                    uct = (node.wins / (node.visits)) + (selectionCoefficient * Math.sqrt(Math.log(visits) / (node.visits)));
                else
                    uct = ((node.visits - node.wins) / (node.visits)) + (selectionCoefficient * Math.sqrt(Math.log(visits) / (node.visits)));
            else
                // ensure that a random unvisited node is played first
                uct = 10000 + 1000 * Math.random();
            if (uct > bestValue) {
                bestChild = node;
                bestValue = uct;
            }
        }
        return bestChild;
    }

    public void expand() {
        MCTSPlayer player = (MCTSPlayer) game.getPlayer(playerId);
        if (player.getNextAction() == null) {
            logger.fatal("next action is null");
        }
        switch (player.getNextAction()) {
            case PRIORITY:
//                logger.info("Priority for player:" + player.getName() + " turn: " + game.getTurnNum() + " phase: " + game.getPhase().getType() + " step: " + game.getStep().getType());
                List<Ability> abilities = player.getPlayableOptions(game);
                for (Ability ability: abilities) {
                    Game sim = game.copy();
//                    logger.info("expand " + ability.toString());
                    MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
                    simPlayer.activateAbility((ActivatedAbility)ability, sim);
                    sim.resume();
                    children.add(new MCTSNode(this, sim, ability));
                }
                break;
            case SELECT_ATTACKERS:
//                logger.info("Select attackers:" + player.getName());
                List<List<UUID>> attacks = player.getAttacks(game);
                UUID defenderId = game.getOpponents(player.getId()).iterator().next();
                for (List<UUID> attack: attacks) {
                    Game sim = game.copy();
                    MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
                    for (UUID attackerId: attack) {
                        simPlayer.declareAttacker(attackerId, defenderId, sim, false);
                    }
                    sim.resume();
                    children.add(new MCTSNode(this, sim, sim.getCombat()));
                }
                break;
            case SELECT_BLOCKERS:
//                logger.info("Select blockers:" + player.getName());
                List<List<List<UUID>>> blocks = player.getBlocks(game);
                for (List<List<UUID>> block: blocks) {
                    Game sim = game.copy();
                    MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
                    List<CombatGroup> groups = sim.getCombat().getGroups();
                    for (int i = 0; i < groups.size(); i++) {
                        if (i < block.size()) {
                            for (UUID blockerId: block.get(i)) {
                                simPlayer.declareBlocker(simPlayer.getId(), blockerId, groups.get(i).getAttackers().get(0), sim);
                            }
                        }
                    }
                    sim.resume();
                    children.add(new MCTSNode(this, sim, sim.getCombat()));
                }
                break;
        }
        game = null;
    }

    public int simulate(UUID playerId) {
//        long startTime = System.nanoTime();
        Game sim = createSimulation(game, playerId);
        sim.resume();
//        long duration = System.nanoTime() - startTime;
        int retVal = -1//anything other than a win is a loss
        for (Player simPlayer: sim.getPlayers().values()) {
//            logger.info(simPlayer.getName() + " calculated " + ((SimulatedPlayerMCTS)simPlayer).getActionCount() + " actions in " + duration/1000000000.0 + "s");
            if (simPlayer.getId().equals(playerId) && simPlayer.hasWon()) {
//                logger.info("AI won the simulation");
                retVal = 1;
            }
        }
        return retVal;
    }

    public void backpropagate(int result) {
        if (result == 0)
            return;
        if (result == 1)
            wins++;
        visits++;
        if (parent != null)
            parent.backpropagate(result);
    }

    public boolean isLeaf() {
        return children.isEmpty();
    }

    public MCTSNode bestChild() {
        if (children.size() == 1)
            return children.get(0);
        double bestCount = -1;
        double bestRatio = 0;
        boolean bestIsPass = false;
        MCTSNode bestChild = null;
        for (MCTSNode node: children) {
            //favour passing vs any other action except for playing land if ratio is close
            if (node.visits > bestCount) {
                if (bestIsPass) {
                    double ratio = node.wins/(node.visits * 1.0);
                    if (ratio < bestRatio + passRatioTolerance)
                        continue;
                }
                bestChild = node;
                bestCount = node.visits;
                bestRatio = node.wins/(node.visits * 1.0);
                bestIsPass = false;
            }
            else if (node.action instanceof PassAbility && node.visits > 10 && !(bestChild.action instanceof PlayLandAbility)) {
                //favour passing vs any other action if ratio is close
                double ratio = node.wins/(node.visits * 1.0);
                if (ratio > bestRatio - passRatioTolerance) {
                    logger.info("choosing pass over " + bestChild.getAction());
                    bestChild = node;
                    bestCount = node.visits;
                    bestRatio = ratio;
                    bestIsPass = true;
                }
            }
        }
        return bestChild;
    }

    public void emancipate() {
        if (parent != null) {
            this.parent.children.remove(this);
            this.parent = null;
        }
    }

    public Ability getAction() {
        return action;
    }

    public int getNumChildren() {
        return children.size();
    }

    public MCTSNode getParent() {
        return parent;
    }

    public Combat getCombat() {
        return combat;
    }

    public int getNodeCount() {
        return nodeCount;
    }

    public String getStateValue() {
        return stateValue;
    }

    public double getWinRatio() {
        if (visits > 0)
            return wins/(visits * 1.0);
        return -1.0;
    }

    public int getVisits() {
        return visits;
    }

    /**
     * Copies game and replaces all players in copy with simulated players
     * Shuffles each players library so that there is no knowledge of its order
     *
     * @param game
     * @return a new game object with simulated players
     */
    protected Game createSimulation(Game game, UUID playerId) {
        Game sim = game.copy();

        for (Player copyPlayer: sim.getState().getPlayers().values()) {
            Player origPlayer = game.getState().getPlayers().get(copyPlayer.getId()).copy();
            SimulatedPlayerMCTS newPlayer = new SimulatedPlayerMCTS(copyPlayer.getId(), true);
            newPlayer.restore(origPlayer);
            sim.getState().getPlayers().put(copyPlayer.getId(), newPlayer);
        }
        randomizePlayers(sim, playerId);
        sim.setSimulation(true);
        return sim;
    }

    /*
     * Shuffles each players library so that there is no knowledge of its order
     * Swaps all other players hands with random cards from the library so that
     * there is no knowledge of what cards are in opponents hands
     */
    protected void randomizePlayers(Game game, UUID playerId) {
        for (Player player: game.getState().getPlayers().values()) {
            if (!player.getId().equals(playerId)) {
                int handSize = player.getHand().size();
                player.getLibrary().addAll(player.getHand().getCards(game), game);
                player.getHand().clear();
                player.getLibrary().shuffle();
                for (int i = 0; i < handSize; i++) {
                    Card card = player.getLibrary().removeFromTop(game);
                    game.setZone(card.getId(), Zone.HAND);
                    player.getHand().add(card);
                }
            }
            else {
                player.getLibrary().shuffle();               
            }
        }
    }

    public boolean isTerminal() {
        return terminal;
    }

    public boolean isWinner(UUID playerId) {
        if (game != null) {
            Player player = game.getPlayer(playerId);
            if (player != null && player.hasWon())
                return true;
        }
        return false;
    }

    /**
     *
     * performs a breadth first search for a matching game state
     *
     * @param state - the game state that we are looking for
     * @param nextAction - the next action that will be performed
     * @return the matching state or null if no match is found
     */
    public MCTSNode getMatchingState(String state) {
        ArrayDeque<MCTSNode> queue = new ArrayDeque<MCTSNode>();
        queue.add(this);

        while (!queue.isEmpty()) {
            MCTSNode current = queue.remove();
            if (current.stateValue.equals(state))
                return current;
            for (MCTSNode child: current.children) {
                queue.add(child);
            }
        }
        return null;
    }

    public void merge(MCTSNode merge) {
        if (!stateValue.equals(merge.stateValue)) {
            logger.info("mismatched merge states");
            return;
        }

        this.visits += merge.visits;
        this.wins += merge.wins;

        List<MCTSNode> mergeChildren = new ArrayList<MCTSNode>();
        for (MCTSNode child: merge.children) {
            mergeChildren.add(child);
        }

        for (MCTSNode child: children) {
            for (MCTSNode mergeChild: mergeChildren) {
                if (mergeChild.action != null && child.action != null) {
                    if (mergeChild.action.toString().equals(child.action.toString())) {
                        if (!mergeChild.stateValue.equals(child.stateValue)) {
                            logger.info("mismatched merge states");
                            mergeChildren.remove(mergeChild);
                        }
                        else {
                            child.merge(mergeChild);
                            mergeChildren.remove(mergeChild);
                        }
                        break;
                    }
                }
                else {
                    if (mergeChild.combat.getValue().equals(child.combat.getValue())) {
                        if (!mergeChild.stateValue.equals(child.stateValue)) {
                            logger.info("mismatched merge states");
                            mergeChildren.remove(mergeChild);
                        }
                        else {
                            child.merge(mergeChild);
                            mergeChildren.remove(mergeChild);
                        }
                        break;
                    }
                }
            }
        }
        if (!mergeChildren.isEmpty()) {
            for (MCTSNode child: mergeChildren) {
                child.parent = this;
                children.add(child);
            }
        }
    }

//    public void print(int depth) {
//        String indent = String.format("%1$-" + depth + "s", "");
//        StringBuilder sb = new StringBuilder();
//        MCTSPlayer player = (MCTSPlayer) game.getPlayer(playerId);
//        sb.append(indent).append(player.getName()).append(" ").append(visits).append(":").append(wins).append(" - ");
//        if (action != null)
//            sb.append(action.toString());
//        System.out.println(sb.toString());
//        for (MCTSNode child: children) {
//            child.print(depth + 1);
//        }
//    }

    public int size() {
        int num = 1;
        for (MCTSNode child: children) {
            num += child.size();
        }
        return num;
    }

}
TOP

Related Classes of mage.player.ai.MCTSNode

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.