// Package: (default package)
//
// Source code of SampleSarsaAgent

/*
* Copyright 2008 Brian Tanner
* http://rl-glue-ext.googlecode.com/
* brian@tannerpages.com
* http://brian.tannerpages.com
*
*  Licensed under the Apache License, Version 2.0 (the "License");
*  you may not use this file except in compliance with the License.
*  You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
*  Unless required by applicable law or agreed to in writing, software
*  distributed under the License is distributed on an "AS IS" BASIS,
*  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*  See the License for the specific language governing permissions and
*  limitations under the License.
*
*  $Revision: 991 $
*  $Date: 2009-02-08 19:29:20 -0500 (Sun, 08 Feb 2009) $
*  $Author: brian@tannerpages.com $
*  $HeadURL: http://rl-library.googlecode.com/svn/trunk/projects/packages/examples/mines-sarsa-java/SampleSarsaAgent.java $
*
*/

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Random;
import org.rlcommunity.rlglue.codec.AgentInterface;
import org.rlcommunity.rlglue.codec.types.Action;
import org.rlcommunity.rlglue.codec.types.Observation;
import org.rlcommunity.rlglue.codec.util.AgentLoader;
import org.rlcommunity.rlglue.codec.taskspec.TaskSpec;

/**
This is a very simple Sarsa agent for discrete-action, discrete-state
environments.  It uses epsilon-greedy exploration.

We've made a decision to store the previous action and observation in
their raw form, as structures.  This code could be simplified and you
could store them just as ints.
* @author Brian Tanner
*/
public class SampleSarsaAgent implements AgentInterface {

    private Random randGenerator = new Random();
    private Action lastAction;
    private Observation lastObservation;
    private double[][] valueFunction = null;
    private double sarsa_stepsize = 0.1;
    private double sarsa_epsilon = 0.1;
    private double sarsa_gamma = 1.0;
    private int numActions = 0;
    private int numStates = 0;
    private boolean policyFrozen = false;
    private boolean exploringFrozen = false;

    /**
     * Parse the task spec, make sure it is only 1 integer observation and
     * action, and then allocate the valueFunction.
     *
     * @param taskSpecification
     */
    public void agent_init(String taskSpecification) {
        TaskSpec theTaskSpec = new TaskSpec(taskSpecification);

        /* Lots of assertions to make sure that we can handle this problem.  */
        assert (theTaskSpec.getNumDiscreteObsDims() == 1);
        assert (theTaskSpec.getNumContinuousObsDims() == 0);
        assert (!theTaskSpec.getDiscreteObservationRange(0).hasSpecialMinStatus());
        assert (!theTaskSpec.getDiscreteObservationRange(0).hasSpecialMaxStatus());
        numStates = theTaskSpec.getDiscreteObservationRange(0).getMax() + 1;

        assert (theTaskSpec.getNumDiscreteActionDims() == 1);
        assert (theTaskSpec.getNumContinuousActionDims() == 0);
        assert (!theTaskSpec.getDiscreteActionRange(0).hasSpecialMinStatus());
        assert (!theTaskSpec.getDiscreteActionRange(0).hasSpecialMaxStatus());
        numActions = theTaskSpec.getDiscreteActionRange(0).getMax() + 1;

        sarsa_gamma=theTaskSpec.getDiscountFactor();

        valueFunction = new double[numActions][numStates];

    }

    /**
     * Choose an action e-greedily from the value function and store the action
     * and observation.
     * @param observation
     * @return
     */
    public Action agent_start(Observation observation) {
        int newActionInt = egreedy(observation.getInt(0));

        /**
         * Create a structure to hold 1 integer action
         * and set the value
         */
        Action returnAction = new Action(1, 0, 0);
        returnAction.intArray[0] = newActionInt;

        lastAction = returnAction.duplicate();
        lastObservation = observation.duplicate();

        return returnAction;
    }

    /**
     * Choose an action e-greedily from the value function and store the action
     * and observation.  Update the valueFunction entry for the last
     * state,action pair.
     * @param reward
     * @param observation
     * @return
     */
    public Action agent_step(double reward, Observation observation) {
        int newStateInt = observation.getInt(0);
        int lastStateInt = lastObservation.getInt(0);
        int lastActionInt = lastAction.getInt(0);

        int newActionInt = egreedy(newStateInt);

        double Q_sa = valueFunction[lastActionInt][lastStateInt];
        double Q_sprime_aprime = valueFunction[newActionInt][newStateInt];

        double new_Q_sa = Q_sa + sarsa_stepsize * (reward + sarsa_gamma * Q_sprime_aprime - Q_sa);
        /*  Only update the value function if the policy is not frozen */
        if (!policyFrozen) {
            valueFunction[lastActionInt][lastStateInt] = new_Q_sa;
        }

        /* Creating the action a different way to showcase variety */
        Action returnAction = new Action();
        returnAction.intArray = new int[]{newActionInt};

        lastAction = returnAction.duplicate();
        lastObservation = observation.duplicate();

        return returnAction;
    }

    /**
     * The episode is over, learn from the last reward that was received.
     * @param reward
     */
    public void agent_end(double reward) {
        int lastStateInt = lastObservation.getInt(0);
        int lastActionInt = lastAction.getInt(0);

        double Q_sa = valueFunction[lastActionInt][lastStateInt];
        double new_Q_sa = Q_sa + sarsa_stepsize * (reward - Q_sa);

        /*  Only update the value function if the policy is not frozen */
        if (!policyFrozen) {
            valueFunction[lastActionInt][lastStateInt] = new_Q_sa;
        }
        lastObservation = null;
        lastAction = null;
    }

    /**
     * Release memory that is no longer required/used.
     */
    public void agent_cleanup() {
        lastAction = null;
        lastObservation = null;
        valueFunction = null;
    }

    /**
     * This agent responds to some simple messages for freezing learning and
     * saving/loading the value function to a file.
     * @param message
     * @return
     */
    public String agent_message(String message) {

        if (message.equals("freeze learning")) {
            policyFrozen = true;
            return "message understood, policy frozen";
        }
        if (message.equals("unfreeze learning")) {
            policyFrozen = false;
            return "message understood, policy unfrozen";
        }
        if (message.equals("freeze exploring")) {
            exploringFrozen = true;
            return "message understood, exploring frozen";
        }
        if (message.equals("unfreeze exploring")) {
            exploringFrozen = false;
            return "message understood, exploring unfrozen";
        }
        if (message.startsWith("save_policy")) {
            String[] parts = message.split(" ");
            saveValueFunction(parts[1]);
            System.out.println("Saved.");
            return "message understood, saving policy";
        }
        if (message.startsWith("load_policy")) {
            String[] parts = message.split(" ");
            loadValueFunction(parts[1]);
            System.out.println("Loaded.");
            return "message understood, loading policy";
        }

        return "SampleSarsaAgent(Java) does not understand your message.";

    }

    /**
     *
     * Selects a random action with probability 1-sarsa_epsilon,
     * and the action with the highest value otherwise.  This is a
     * quick'n'dirty implementation, it does not do tie-breaking.

     * @param theState
     * @return
     */
    private int egreedy(int theState) {
        if (!exploringFrozen) {
            if (randGenerator.nextDouble() <= sarsa_epsilon) {
                return randGenerator.nextInt(numActions);
            }
        }

        /*otherwise choose the greedy action*/
        int maxIndex = 0;
        for (int a = 1; a < numActions; a++) {
            if (valueFunction[a][theState] > valueFunction[maxIndex][theState]) {
                maxIndex = a;
            }
        }
        return maxIndex;
    }

    /**
     * This is a trick we can use to make the agent easily loadable.  Using this
     * trick you can directly execute the class and it will load itself through
     * AgentLoader and connect to the rl_glue server.
     * @param args
     */
    public static void main(String[] args) {
        AgentLoader theLoader = new AgentLoader(new SampleSarsaAgent());
        theLoader.run();
    }

    /**
     * Dumps the value function to a file named theFileName.  Not fancy. Must be
     * called after init but before cleanup.
     * @param theFileName
     */
    private void saveValueFunction(String theFileName) {
        try {
            DataOutputStream DO = new DataOutputStream(new FileOutputStream(new File(theFileName)));
            for (int a = 0; a < numActions; a++) {
                for (int s = 0; s < numStates; s++) {
                    DO.writeDouble(valueFunction[a][s]);
                }
            }
            DO.close();
        } catch (FileNotFoundException ex) {
            System.err.println("Problem saving value function to file: " + theFileName + " :: " + ex);
        } catch (IOException ex) {
            System.err.println("Problem writing value function to file:: " + ex);
        }

    }

    /**
     * Loads the value function from a file named theFileName.  Must be called after
     * init but before cleanup.
     * @param theFileName
     */
    private void loadValueFunction(String theFileName) {
        try {
            DataInputStream DI = new DataInputStream(new FileInputStream(new File(theFileName)));
            for (int a = 0; a < numActions; a++) {
                for (int s = 0; s < numStates; s++) {
                    valueFunction[a][s] = DI.readDouble();
                }
            }
            DI.close();
        } catch (FileNotFoundException ex) {
            System.err.println("Problem loading value function from file: " + theFileName + " :: " + ex);
        } catch (IOException ex) {
            System.err.println("Problem reading value function from file:: " + ex);
        }
    }
}
// TOP
//
// Related Classes of SampleSarsaAgent
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.