/*
* Javlov - a Java toolkit for reinforcement learning with multi-agent support.
*
* Copyright (c) 2009 Matthijs Snel
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package net.javlov.example.rooms;

import java.awt.Point;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import javax.swing.Timer;

import net.javlov.*;
import net.javlov.example.ExperimentGUI;
import net.javlov.example.GridLimitedOptionsWorld;
import net.javlov.policy.EGreedyPolicy;
import net.javlov.world.*;
import net.javlov.world.grid.*;
import net.javlov.world.phys2d.*;
import net.javlov.world.ui.GridWorldView;
import net.phys2d.raw.shapes.Circle;
import net.phys2d.raw.shapes.StaticBox;
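
/**
 * Four-rooms example: a Q-learning agent equipped with eight "reach hallway"
 * options (two per room) learns to reach a goal placed in one of the four
 * hallways of a 20x20 gridworld.
 */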
public class Main implements Runnable {

    GridLimitedOptionsWorld world;
    GridRewardFunction rf;
    Simulator sim;
    TabularQFunction qf;
    boolean gui;
    int cellwidth, cellheight;

    public static void main(String[] args) {
        Main m = new Main();
        m.gui = true;
        m.init();
        m.start();
    }
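
    /**
     * Builds the world, the option pool, the agent and its body, and wires
     * everything into a simulator.
     */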
    public void init() {
        cellwidth = 50;
        cellheight = cellwidth;
        makeWorld();

        List<? extends Option> optionPool = makeOptions();
        world.setOptionPool(optionPool);

        Agent a = makeAgent(optionPool);
        AgentBody aBody = makeAgentBody();

        sim = new Simulator();
        sim.setEnvironment(world);
        world.add(a, aBody);
        sim.setAgent(a);
    }
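
    /**
     * Creates a Q-learning agent over the option pool with a tabular
     * Q-function, a decaying learning rate and an epsilon-greedy policy
     * (epsilon = 0.1).
     */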
    protected Agent makeAgent(List<? extends Option> optionPool) {
        qf = TabularQFunction.getInstance(optionPool.size());
        // QLearningAgent extends SarsaAgent, hence the declared type
        SarsaAgent a = new QLearningAgent(qf, 1, optionPool);
        a.setLearnRate(new DecayingLearningRate(1, optionPool.size(), 0.8));
        // epsilon-greedy selection over the option pool, epsilon = 0.1
        Policy pi = new EGreedyPolicy(qf, 0.1, optionPool);
        a.setPolicy(pi);
        // update per primitive step instead of per completed option
        a.setSMDPMode(false);
        return a;
    }
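
    /**
     * Creates the agent's physical body: a circle of radius 20 carrying a GPS
     * sensor that reports the body's position in grid coordinates.
     */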
    protected AgentBody makeAgentBody() {
        // circular body of radius 20 for the phys2d physics engine
        Phys2DAgentBody aBody = new Phys2DAgentBody(new Circle(20), 0.5f);
        // the GPS sensor observes the body's location in grid coordinates
        GridGPSSensor gps = new GridGPSSensor(cellwidth, cellheight);
        gps.setBody(aBody);
        aBody.add(gps);
        return aBody;
    }
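
    /**
     * Creates the option pool. Judging by the names, option "RiHj" is
     * applicable in room i and takes the agent to hallway j: the second Point
     * passed to each constructor is the target hallway cell, the first is the
     * room's other hallway.
     */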
    protected List<? extends Option> makeOptions() {
        // the four primitive moves shared by all options
        List<Action> primitiveActions = new ArrayList<Action>();
        primitiveActions.add(GridMove.getNorthInstance(world));
        primitiveActions.add(GridMove.getEastInstance(world));
        primitiveActions.add(GridMove.getSouthInstance(world));
        primitiveActions.add(GridMove.getWestInstance(world));

        List<Option> optionPool = new ArrayList<Option>();
        Option o = new ReachHallOption("R1H1", 0, 8, 0, 9, new Point(4,10), new Point(9,2), primitiveActions);
        o.setID(optionPool.size());
        optionPool.add(o);
        o = new ReachHallOption("R1H4", 0, 8, 0, 9, new Point(9,2), new Point(4,10), primitiveActions);
        o.setID(optionPool.size());
        optionPool.add(o);
        o = new ReachHallOption("R2H1", 10, 19, 0, 7, new Point(16,8), new Point(9,2), primitiveActions);
        o.setID(optionPool.size());
        optionPool.add(o);
        o = new ReachHallOption("R2H2", 10, 19, 0, 7, new Point(9,2), new Point(16,8), primitiveActions);
        o.setID(optionPool.size());
        optionPool.add(o);
        o = new ReachHallOption("R3H2", 10, 19, 9, 19, new Point(9,15), new Point(16,8), primitiveActions);
        o.setID(optionPool.size());
        optionPool.add(o);
        o = new ReachHallOption("R3H3", 10, 19, 9, 19, new Point(16,8), new Point(9,15), primitiveActions);
        o.setID(optionPool.size());
        optionPool.add(o);
        o = new ReachHallOption("R4H3", 0, 8, 11, 19, new Point(4,10), new Point(9,15), primitiveActions);
        o.setID(optionPool.size());
        optionPool.add(o);
        o = new ReachHallOption("R4H4", 0, 8, 11, 19, new Point(9,15), new Point(4,10), primitiveActions);
        o.setID(optionPool.size());
        optionPool.add(o);
        return optionPool;
    }
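
    /**
     * Builds the 20x20 four-rooms layout: wall segments leaving gaps at the
     * four hallway cells (4,10), (9,2), (9,15) and (16,8), a reward function
     * registered as collision listener, and a goal body in hallway (16,8).
     */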
    protected void makeWorld() {
        world = new GridLimitedOptionsWorld(20, 20, cellwidth, cellheight);

        // horizontal wall between rooms 1 and 4 (row 10), gap at hallway (4,10)
        Body wall = new Phys2DBody(new StaticBox(4*cellwidth, cellheight), 10, true);
        wall.setLocation(2*cellwidth, 10*cellheight + 0.5*cellheight);
        wall.setType(Body.OBSTACLE);
        world.addFixedBody(wall);
        wall = new Phys2DBody(new StaticBox(4*cellwidth, cellheight), 10, true);
        wall.setLocation(7*cellwidth, 10*cellheight + 0.5*cellheight);
        wall.setType(Body.OBSTACLE);
        world.addFixedBody(wall);

        // vertical wall dividing the left and right rooms (column 9),
        // gaps at hallways (9,2) and (9,15)
        wall = new Phys2DBody(new StaticBox(cellwidth, 2*cellheight), 10, true);
        wall.setLocation(9*cellwidth + 0.5*cellwidth, cellheight);
        wall.setType(Body.OBSTACLE);
        world.addFixedBody(wall);
        wall = new Phys2DBody(new StaticBox(cellwidth, 12*cellheight), 10, true);
        wall.setLocation(9*cellwidth + 0.5*cellwidth, 9*cellheight);
        wall.setType(Body.OBSTACLE);
        world.addFixedBody(wall);
        wall = new Phys2DBody(new StaticBox(cellwidth, 4*cellheight), 10, true);
        wall.setLocation(9*cellwidth + 0.5*cellwidth, 18*cellheight);
        wall.setType(Body.OBSTACLE);
        world.addFixedBody(wall);

        // horizontal wall between rooms 2 and 3 (row 8), gap at hallway (16,8)
        wall = new Phys2DBody(new StaticBox(6*cellwidth, cellheight), 10, true);
        wall.setLocation(13*cellwidth, 8*cellheight + 0.5*cellheight);
        wall.setType(Body.OBSTACLE);
        world.addFixedBody(wall);
        wall = new Phys2DBody(new StaticBox(3*cellwidth, cellheight), 10, true);
        wall.setLocation(18.5*cellwidth, 8*cellheight + 0.5*cellheight);
        wall.setType(Body.OBSTACLE);
        world.addFixedBody(wall);

        // the reward function is registered as a collision listener so it can
        // respond to the agent hitting bodies such as the goal
        rf = new GridRewardFunction();
        world.setRewardFunction(rf);
        world.addCollisionListener(rf);

        // goal at the centre of hallway cell (16,8): 825 = 16.5*cellwidth, 425 = 8.5*cellheight
        GoalBody goal = new GoalBody(825, 425);
        goal.setReward(0);
        world.addFixedBody(goal);
    }
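
    /**
     * Starts the experiment, either headless or with a Swing GUI that renders
     * the world at about 24 fps while the simulation runs on its own thread.
     */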
    public void start() {
        if (gui) {
            GridWorldView wv = new GridWorldView(world);
            // repaint the view at roughly 24 frames per second
            Timer timer = new Timer(1000/24, wv);
            new ExperimentGUI("Rooms example", wv, sim);
            timer.start();
            // run the experiment on a separate thread so the GUI stays responsive
            new Thread(this).start();
        } else {
            run();
        }
    }
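
    /**
     * Runs 5000 episodes, then prints the reward obtained in each episode and
     * the learned Q-function.
     */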
    @Override
    public void run() {
        int episodes = 5000;
        // record total reward and step count for every episode
        EpisodicRewardStepStatistic stat = new EpisodicRewardStepStatistic(episodes);
        sim.addStatistic(stat);
        sim.init();
        sim.suspend();
        sim.runEpisodes(episodes);
        System.out.println(Arrays.toString(stat.getRewards()));
        System.out.println(qf);
    }
}