/*
 * Package net.javlov.example.landmarks
 *
 * Source code of net.javlov.example.landmarks.Main
 */

package net.javlov.example.landmarks;

import java.awt.Point;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import javax.swing.Timer;

import util.MatOp;

import net.javlov.Action;
import net.javlov.Agent;
import net.javlov.DecayingLearningRate;
import net.javlov.EpisodicRewardStepStatistic;
import net.javlov.FixedLearningRate;
import net.javlov.LearningRate;
import net.javlov.Option;
import net.javlov.Policy;
import net.javlov.QLearningAgent;
import net.javlov.SarsaAgent;
import net.javlov.Simulator;
import net.javlov.TabularQFunction;
import net.javlov.TabularValueFunction;
import net.javlov.example.ExperimentGUI;
import net.javlov.example.GridLimitedOptionsWorld;
import net.javlov.example.rooms.ReachHallOption;
import net.javlov.policy.EGreedyPolicy;
import net.javlov.world.AgentBody;
import net.javlov.world.Body;
import net.javlov.world.GoalBody;
import net.javlov.world.grid.GridGPSSensor;
import net.javlov.world.grid.GridMove;
import net.javlov.world.grid.GridRewardFunction;
import net.javlov.world.phys2d.Phys2DAgentBody;
import net.javlov.world.phys2d.Phys2DBody;
import net.javlov.world.ui.GridWorldView;
import net.phys2d.raw.shapes.Circle;
import net.phys2d.raw.shapes.StaticBox;

public class Main implements Runnable {
 
  GridLimitedOptionsWorld world;
  GridRewardFunction rf;
  Simulator sim;
  TabularQFunction qf;
  SarsaAgent agent;
  boolean gui;
  int cellwidth, cellheight;
 
  public static void main(String[] args) {
    Main m = new Main();
    m.gui = false;
    m.init();
    m.start();
  }
 
  public void init() {
    cellwidth = 50; cellheight = cellwidth;

    makeWorld();   
   
    List<? extends Option> optionPool = makeOptions();
    world.setOptionPool(optionPool);
   
    agent = makeAgent(optionPool);
    AgentBody aBody = makeAgentBody();
   
    sim = new Simulator();
    sim.setEnvironment(world);
   
    world.add(agent, aBody);
    sim.setAgent(agent);
  }
 
  protected SarsaAgent makeAgent(List<? extends Option> optionPool) {
    qf = TabularQFunction.getInstance(optionPool.size());
    SarsaAgent a = new QLearningAgent(qf, 1, optionPool);
    //a.setLearnRate(new DecayingLearningRate(1, optionPool.size(), 0.8));
    a.setLearnRate( new FixedLearningRate(0.2) );
    Policy pi = new EGreedyPolicy(qf, 0.05, optionPool);
    a.setPolicy(pi);
    a.setSMDPMode(true);
    a.setLearnStateValueFunction(false);
    a.setInterruptOptions(false);
    return a;
  }
 
  protected AgentBody makeAgentBody() {
    Phys2DAgentBody aBody = new Phys2DAgentBody(new Circle(20), 0.5f);
    GridGPSSensor gps = new GridGPSSensor(cellwidth, cellheight);
    gps.setBody(aBody);
    aBody.add(gps);
    return aBody;
  }
 
  protected List<? extends Option> makeOptions() {
    List<Action> primitiveActions = new ArrayList<Action>();
    primitiveActions.add(GridMove.getNorthInstance(world));
    primitiveActions.add(GridMove.getNorthEastInstance(world));
    primitiveActions.add(GridMove.getEastInstance(world));
    primitiveActions.add(GridMove.getSouthEastInstance(world));
    primitiveActions.add(GridMove.getSouthInstance(world));
    primitiveActions.add(GridMove.getSouthWestInstance(world));
    primitiveActions.add(GridMove.getWestInstance(world));
    primitiveActions.add(GridMove.getNorthWestInstance(world));
   
    List<ReachLandmarkOption> optionPool = new ArrayList<ReachLandmarkOption>();
    Point p1 = new Point(4,14),
        p2 = new Point(4,8),
        p3 = new Point(10,8),
        p4 = new Point(10,2);
   
    ReachLandmarkOption o = new ReachLandmarkOption("L1", 8, p1, primitiveActions, new Point[]{p2,p3,p4});
    o.setID(optionPool.size());
    optionPool.add( o );
   
    o = new ReachLandmarkOption("L2", 8, p2, primitiveActions, new Point[]{p1,p3,p4});
    o.setID(optionPool.size());
    optionPool.add( o );
   
    o = new ReachLandmarkOption("L3", 8, p3, primitiveActions, new Point[]{p2,p1,p4});
    o.setID(optionPool.size());
    optionPool.add( o );
   
    o = new ReachLandmarkOption("L4", 8, p4, primitiveActions, new Point[]{p2,p3,p1});
    o.setID(optionPool.size());
    optionPool.add( o );

    return optionPool;
  }
 
  protected void makeWorld() {
    world = new GridLimitedOptionsWorld(15, 18, cellwidth, cellheight);
   
    Body landmark = new Phys2DBody( new StaticBox(cellwidth, cellheight), 10, true );
    landmark.setLocation(4.5*cellwidth, 14.5*cellheight);
    landmark.setType(Body.UNDEFINED);
    world.addFixedBody(landmark);
   
    landmark = new Phys2DBody( new StaticBox(cellwidth, cellheight), 10, true );
    landmark.setLocation(4.5*cellwidth, 8.5*cellheight);
    landmark.setType(Body.UNDEFINED);
    world.addFixedBody(landmark);
   
    landmark = new Phys2DBody( new StaticBox(cellwidth, cellheight), 10, true );
    landmark.setLocation(10.5*cellwidth, 8.5*cellheight);
    landmark.setType(Body.UNDEFINED);
    world.addFixedBody(landmark);
   
    landmark = new Phys2DBody( new StaticBox(cellwidth, cellheight), 10, true );
    landmark.setLocation(10.5*cellwidth, 2.5*cellheight);
    landmark.setType(Body.UNDEFINED);
    world.addFixedBody(landmark);
   
    rf = new GridRewardFunction();
    world.setRewardFunction(rf);
    world.addCollisionListener(rf);
   
    GoalBody goal = new GoalBody(525, 125);
    goal.setReward(0);
    world.addFixedBody(goal);
  }
 
  public void start() {
    if ( gui ) {
      GridWorldView wv = new LandmarkGridWorldView(world, 8);
      Timer timer = new Timer(1000/24, wv);
      ExperimentGUI g = new ExperimentGUI("Rooms example", wv, sim);
      timer.start();
      new Thread(this).start();
    } else
      run();
  }

  @Override
  public void run() {
    int episodes = 500;
    int runs = 100;
    double[][] allrewards = new double[runs][episodes];
    EpisodicRewardStepStatistic stat = new EpisodicRewardStepStatistic(episodes);
    sim.addStatistic(stat);
   
    for ( int r = 0; r < runs; r++ ) {
      sim.init();
      agent.setInterruptOptions(false);
      //sim.suspend();
      //agent.setInterruptOptions(true);
      for ( int i = 0; i < episodes; i++ ) {
        if ( i == 400 )
          agent.setInterruptOptions(true);
        sim.runEpisode();
        sim.reset();
      }
      allrewards[r] = stat.getRewards();
    }
    System.out.println(agent.getValueFunction());
    System.out.println("FINISHED");   
    System.out.println(qf);

    System.out.println(Arrays.toString(MatOp.mean(allrewards,2)));
  }

}

/*
 * TOP
 *
 * Related Classes of net.javlov.example.landmarks.Main
 *
 * TOP
 * Copyright (c) 2018 www.massapi.com. All rights reserved.
 * All source code are property of their respective owners. Java is a trademark
 * of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.
 */