package net.javlov.example.landmarks;
import java.awt.Point;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.swing.Timer;
import util.MatOp;
import net.javlov.Action;
import net.javlov.DecayingLearningRate;
import net.javlov.EpisodicRewardStepStatistic;
import net.javlov.FixedLearningRate;
import net.javlov.Option;
import net.javlov.Policy;
import net.javlov.QLearningAgent;
import net.javlov.SarsaAgent;
import net.javlov.Simulator;
import net.javlov.TabularQFunction;
import net.javlov.example.ExperimentGUI;
import net.javlov.example.GridLimitedOptionsWorld;
import net.javlov.policy.EGreedyPolicy;
import net.javlov.world.AgentBody;
import net.javlov.world.Body;
import net.javlov.world.GoalBody;
import net.javlov.world.grid.GridGPSSensor;
import net.javlov.world.grid.GridMove;
import net.javlov.world.grid.GridRewardFunction;
import net.javlov.world.phys2d.Phys2DAgentBody;
import net.javlov.world.phys2d.Phys2DBody;
import net.javlov.world.ui.GridWorldView;
import net.phys2d.raw.shapes.Circle;
import net.phys2d.raw.shapes.StaticBox;
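
/**
 * Landmarks example: an SMDP agent learns in a grid world using a pool of
 * four "reach landmark" options, each built from the eight primitive grid
 * moves. Runs headless by default; set {@code gui} to true in main() to
 * watch the simulation in a Swing view.
 */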
public class Main implements Runnable {

    GridLimitedOptionsWorld world;
    GridRewardFunction rf;
    Simulator sim;
    TabularQFunction qf;
    SarsaAgent agent;
    boolean gui;
    int cellwidth, cellheight;
    public static void main(String[] args) {
        Main m = new Main();
        m.gui = false; // headless by default; set to true for the Swing view
        m.init();
        m.start();
    }
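
    /** Builds the world, the option pool, the agent and the simulator, and wires them together. */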
    public void init() {
        cellwidth = 50;
        cellheight = cellwidth;
        makeWorld(); // the world must exist before the options, which wrap grid moves in it
        List<? extends Option> optionPool = makeOptions();
        world.setOptionPool(optionPool);
        agent = makeAgent(optionPool);
        AgentBody aBody = makeAgentBody();
        sim = new Simulator();
        sim.setEnvironment(world);
        world.add(agent, aBody);
        sim.setAgent(agent);
    }
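
    /**
     * Creates the learning agent: tabular Q-function over the option pool, a fixed
     * learning rate of 0.2, an epsilon-greedy policy (0.05 exploration), SMDP mode,
     * and option interruption initially disabled.
     */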
    protected SarsaAgent makeAgent(List<? extends Option> optionPool) {
        qf = TabularQFunction.getInstance(optionPool.size());
        // QLearningAgent extends SarsaAgent; the second constructor argument is
        // presumably the discount factor (1 = undiscounted)
        SarsaAgent a = new QLearningAgent(qf, 1, optionPool);
        //a.setLearnRate(new DecayingLearningRate(1, optionPool.size(), 0.8));
        a.setLearnRate( new FixedLearningRate(0.2) );
        Policy pi = new EGreedyPolicy(qf, 0.05, optionPool);
        a.setPolicy(pi);
        a.setSMDPMode(true); // treat options as temporally extended SMDP actions
        a.setLearnStateValueFunction(false);
        a.setInterruptOptions(false); // interruption is switched on later, in run()
        return a;
    }
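
    /** Creates the agent's physical body: a circle fitted with a grid GPS sensor. */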
    protected AgentBody makeAgentBody() {
        // circular physics body carrying a GPS sensor that reports grid coordinates
        Phys2DAgentBody aBody = new Phys2DAgentBody(new Circle(20), 0.5f);
        GridGPSSensor gps = new GridGPSSensor(cellwidth, cellheight);
        gps.setBody(aBody);
        aBody.add(gps);
        return aBody;
    }
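
    /**
     * Creates the option pool: one ReachLandmarkOption per landmark, each given the
     * eight primitive grid moves and the positions of the other three landmarks.
     */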
    protected List<? extends Option> makeOptions() {
        // the eight primitive grid moves; these are only available inside the options
        List<Action> primitiveActions = new ArrayList<Action>();
        primitiveActions.add(GridMove.getNorthInstance(world));
        primitiveActions.add(GridMove.getNorthEastInstance(world));
        primitiveActions.add(GridMove.getEastInstance(world));
        primitiveActions.add(GridMove.getSouthEastInstance(world));
        primitiveActions.add(GridMove.getSouthInstance(world));
        primitiveActions.add(GridMove.getSouthWestInstance(world));
        primitiveActions.add(GridMove.getWestInstance(world));
        primitiveActions.add(GridMove.getNorthWestInstance(world));
        // one option per landmark; each option is also told where the other three landmarks are
        List<ReachLandmarkOption> optionPool = new ArrayList<ReachLandmarkOption>();
        Point p1 = new Point(4, 14),
              p2 = new Point(4, 8),
              p3 = new Point(10, 8),
              p4 = new Point(10, 2);
        ReachLandmarkOption o = new ReachLandmarkOption("L1", 8, p1, primitiveActions, new Point[]{p2, p3, p4});
        o.setID(optionPool.size());
        optionPool.add(o);
        o = new ReachLandmarkOption("L2", 8, p2, primitiveActions, new Point[]{p1, p3, p4});
        o.setID(optionPool.size());
        optionPool.add(o);
        o = new ReachLandmarkOption("L3", 8, p3, primitiveActions, new Point[]{p2, p1, p4});
        o.setID(optionPool.size());
        optionPool.add(o);
        o = new ReachLandmarkOption("L4", 8, p4, primitiveActions, new Point[]{p2, p3, p1});
        o.setID(optionPool.size());
        optionPool.add(o);
        return optionPool;
    }
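
    /**
     * Builds the 15 x 18 grid world with four static landmark bodies, the grid
     * reward function, and a goal placed at landmark L4's cell.
     */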
    protected void makeWorld() {
        world = new GridLimitedOptionsWorld(15, 18, cellwidth, cellheight);
        // four static landmark bodies, centred on the cells targeted by the options
        double[][] landmarkCells = { {4.5, 14.5}, {4.5, 8.5}, {10.5, 8.5}, {10.5, 2.5} };
        for ( double[] cell : landmarkCells ) {
            Body landmark = new Phys2DBody( new StaticBox(cellwidth, cellheight), 10, true );
            landmark.setLocation(cell[0] * cellwidth, cell[1] * cellheight);
            landmark.setType(Body.UNDEFINED);
            world.addFixedBody(landmark);
        }
        rf = new GridRewardFunction();
        world.setRewardFunction(rf);
        world.addCollisionListener(rf);
        // goal at (525, 125), the centre of cell (10,2), i.e. at landmark L4
        GoalBody goal = new GoalBody(525, 125);
        goal.setReward(0);
        world.addFixedBody(goal);
    }
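
    /** Starts the experiment, optionally with a Swing view refreshed at ~24 fps. */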
    public void start() {
        if ( gui ) {
            GridWorldView wv = new LandmarkGridWorldView(world, 8);
            // repaint the view at roughly 24 frames per second
            Timer timer = new Timer(1000 / 24, wv);
            new ExperimentGUI("Landmarks example", wv, sim);
            timer.start();
            // run the experiment off the Swing event thread
            new Thread(this).start();
        } else {
            run();
        }
    }
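
    /**
     * Runs 100 independent runs of 500 episodes each, enabling option interruption
     * after episode 400 of every run, then prints the learned value function, the
     * Q-function, and the reward statistics aggregated over all runs.
     */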
    @Override
    public void run() {
        int episodes = 500;
        int runs = 100;
        double[][] allrewards = new double[runs][episodes];
        EpisodicRewardStepStatistic stat = new EpisodicRewardStepStatistic(episodes);
        sim.addStatistic(stat);
        for ( int r = 0; r < runs; r++ ) {
            sim.init();
            agent.setInterruptOptions(false);
            //sim.suspend();
            //agent.setInterruptOptions(true);
            for ( int i = 0; i < episodes; i++ ) {
                // enable option interruption for the final 100 episodes of each run
                if ( i == 400 )
                    agent.setInterruptOptions(true);
                sim.runEpisode();
                sim.reset();
            }
            allrewards[r] = stat.getRewards();
        }
        System.out.println(agent.getValueFunction());
        System.out.println("FINISHED");
        System.out.println(qf);
        System.out.println(Arrays.toString(MatOp.mean(allrewards, 2)));
    }
}