/*
* Javlov - a Java toolkit for reinforcement learning with multi-agent support.
*
* Copyright (c) 2009 Matthijs Snel
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package net.javlov.world.grid;
import java.awt.Point;
import java.awt.Shape;
import java.awt.geom.Point2D;
import java.awt.geom.Rectangle2D;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import net.javlov.Action;
import net.javlov.Agent;
import net.javlov.Environment;
import net.javlov.RewardFunction;
import net.javlov.State;
import net.javlov.world.AgentBody;
import net.javlov.world.Body;
import net.javlov.world.CollisionEvent;
import net.javlov.world.CollisionListener;
import net.javlov.world.World;
/**
 * Grid world implementation that assumes all moving bodies fit within one grid cell.
 * <p>
 * Every body is registered in each {@link GridCell} its figure intersects. Movement is
 * performed cell by cell (see {@link #translateBody(Body, Direction, int)}); collisions
 * detected while moving are queued and dispatched to the registered
 * {@link CollisionListener}s after the action has been executed.
 * <p>
 * NOTE: observations are cached in a single field ({@link #lastState}) as a single-agent
 * speed shortcut; with several agents, {@link #getObservation(Agent)} returns whichever
 * observation was constructed last.
 *
 * @author matthijs
 */
public class GridWorld implements World.Discrete, IGridWorld {
    /**
     * Maps agents to their bodies. Done this way because an agent is not allowed to have access
     * to its own environment/body.
     */
    protected Map<Agent, AgentBody> agentBodyMap;

    /**
     * All non-fixed bodies in the world, agent bodies included. These are taken off the grid
     * and repositioned at random by {@link #init()}.
     */
    protected List<Body> bodies;

    /**
     * Bodies added through {@link #addFixedBody(Body)}. They occupy grid cells like any other
     * body but are never repositioned by {@link #init()} or {@link #reset()}.
     */
    protected List<Body> fixedBodies;

    /**
     * The grid.
     */
    protected Grid grid;

    /**
     * Set by the reward function (through {@link RewardBroker#endEpisode()}); copied into each
     * constructed observation as its terminal flag and cleared by {@link #init()}/{@link #reset()}.
     */
    protected boolean episodeEnd;

    /** Listeners notified of every queued collision after an action has been executed. */
    private List<CollisionListener> listeners;

    /** Reward function; must be set via {@link #setRewardFunction} before actions are executed. */
    protected GridRewardFunction reward;

    /** Collisions recorded during the current action, dispatched by {@link #processCollisionEvents()}. */
    private Queue<CollisionEvent> collisionEvents;

    /** Last constructed observation. Single-agent hack for speed: shared by all agents. */
    protected State lastState;

    /**
     * Creates an empty grid world.
     *
     * @param width number of cells in the horizontal direction
     * @param height number of cells in the vertical direction
     * @param cellWidth width of a single cell
     * @param cellHeight height of a single cell
     */
    public GridWorld(int width, int height, double cellWidth, double cellHeight) {
        agentBodyMap = new HashMap<Agent, AgentBody>();
        bodies = new ArrayList<Body>();
        fixedBodies = new ArrayList<Body>();
        grid = new Grid(width, height, cellWidth, cellHeight);
        listeners = new ArrayList<CollisionListener>();
        initCollisionQueue();
    }

    /**
     * Creates the queue that buffers collision events between an action and their dispatch.
     * Subclasses may override this to supply a different queue implementation.
     */
    protected void initCollisionQueue() {
        collisionEvents = new ArrayDeque<CollisionEvent>();
    }

    /**
     * Adds an agent together with its body. The agent is only registered if the body could be
     * added.
     *
     * @return {@code true} if the body was added and the agent registered
     */
    @Override
    public boolean add(Agent a, AgentBody b) {
        if ( addBody(b) ) {
            agentBodyMap.put(a, b);
            return true;
        }
        return false;
    }

    /**
     * Adds a (non-fixed) body and registers it in every grid cell its figure intersects.
     */
    @Override
    public boolean addBody(Body b) {
        if ( bodies.add(b) ) {
            addToCells(b);
            return true;
        }
        return false;
    }

    /**
     * Adds a body that is never repositioned by {@link #init()} or {@link #reset()}, and
     * registers it in every grid cell its figure intersects.
     */
    public boolean addFixedBody(Body b) {
        if ( fixedBodies.add(b) ) {
            addToCells(b);
            return true;
        }
        return false;
    }

    /** Registers a listener that is notified of every collision after each executed action. */
    public void addCollisionListener(CollisionListener listener) {
        listeners.add(listener);
    }

    /** Returns the body of agent {@code a}, or {@code null} if the agent is unknown. */
    @Override
    public Body getAgentBody(Agent a) {
        return agentBodyMap.get(a);
    }

    /** Returns the height of the world: number of rows times the cell height. */
    @Override
    public double getHeight() {
        return grid.getHeight()*grid.getCellHeight();
    }

    /**
     * {@inheritDoc}
     * <p>
     * Candidate bodies are the occupiers of the grid cells overlapped by {@code s}; a candidate
     * is reported if {@code s} intersects the bounding box of its figure.
     */
    @Override
    public List<Body> getIntersectingObjects(Shape s) {
        List<Body> objects = new ArrayList<Body>();
        for ( Body b : getOccupiers( cellsOverlapping(s) ) ) {
            if ( s.intersects(b.getFigure().getBounds2D()) ) {
                objects.add(b);
            }
        }
        return objects;
    }

    /** Returns a new list holding all non-fixed bodies followed by all fixed bodies. */
    @Override
    public List<Body> getObjects() {
        List<Body> allBodies = new ArrayList<Body>(bodies.size() + fixedBodies.size());
        allBodies.addAll(bodies);
        allBodies.addAll(fixedBodies);
        return allBodies;
    }

    /** Not implemented for this discrete world: always returns 0. */
    @Override
    public double getTimeStep() {
        return 0;
    }

    /** Returns the width of the world: number of columns times the cell width. */
    @Override
    public double getWidth() {
        return grid.getWidth()*grid.getCellWidth();
    }

    /**
     * {@inheritDoc}
     * <p>
     * Uses the same cell lookup as {@link #getIntersectingObjects(Shape)}, so shapes larger
     * than a cell are handled correctly too.
     */
    @Override
    public boolean intersectsObject(Shape s) {
        for ( Body b : getOccupiers( cellsOverlapping(s) ) ) {
            if ( s.intersects(b.getFigure().getBounds2D()) ) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the cells overlapped by {@code s}. The cheap neighbour-based lookup is used when
     * the shape's bounds fit within one cell; otherwise the exhaustive scan is used.
     */
    private List<GridCell> cellsOverlapping(Shape s) {
        Rectangle2D bounds = s.getBounds2D();
        if ( bounds.getWidth() > grid.getCellWidth() || bounds.getHeight() > grid.getCellHeight() ) {
            return getIntersectingCellsLarge(s);
        }
        return getIntersectingCells(s);
    }

    /**
     * Returns the cells intersected by a shape that is SMALLER than a single cell.
     * <p>
     * The result contains the cell holding the centre of the shape's bounds. If the bounds are
     * not fully contained in that cell, it also contains every cell returned by
     * {@code getQuadrantNeightbours} (presumably the neighbours on the side of the centre's
     * quadrant — confirm in GridCell) that intersects the bounds.
     * <p>
     * IMPORTANT: results are incomplete for shapes wider or taller than a cell; use
     * {@link #getIntersectingCellsLarge(Shape)} for those.
     */
    protected List<GridCell> getIntersectingCells(Shape s) {
        Rectangle2D bounds = s.getBounds2D();
        double cx = bounds.getCenterX(),
               cy = bounds.getCenterY();
        List<GridCell> intersectingCells = new ArrayList<GridCell>(5);
        GridCell cell = grid.getCell(cx, cy);
        intersectingCells.add(cell);
        // the shape spills over into neighbouring cells: check which ones
        if ( !cell.contains(bounds) ) {
            GridCell[] neighbours = cell.getQuadrantNeightbours(cx, cy);
            for ( int j = 0; j < neighbours.length; j++ ) {
                if ( neighbours[j].intersects(bounds) ) {
                    intersectingCells.add(neighbours[j]);
                }
            }
        }
        return intersectingCells;
    }

    /**
     * Returns the cells intersected by a shape of any size. Every cell under the shape's
     * bounding box (clipped to the grid) is visited, and the cells that {@code s} actually
     * intersects are kept.
     */
    protected List<GridCell> getIntersectingCellsLarge(Shape s) {
        Rectangle2D bounds = s.getBounds();
        List<GridCell> boundsCells = new ArrayList<GridCell>();
        // extend the max by one cell so that a partially covered last cell is still visited
        // when minx/miny are not aligned to the cell grid
        double minx = Math.max(bounds.getMinX(), 0),
               maxx = Math.min(bounds.getMaxX()+grid.getCellWidth(), grid.getWidth()*grid.getCellWidth()),
               miny = Math.max(bounds.getMinY(), 0),
               maxy = Math.min(bounds.getMaxY()+grid.getCellHeight(), grid.getHeight()*grid.getCellHeight());
        for ( double x = minx; x < maxx; x += grid.getCellWidth() ) {
            for ( double y = miny; y < maxy; y += grid.getCellHeight() ) {
                boundsCells.add( grid.getCell(x, y) );
            }
        }
        List<GridCell> intersectingCells = new ArrayList<GridCell>(boundsCells.size());
        for ( GridCell cell : boundsCells ) {
            if ( s.intersects(cell) ) {
                intersectingCells.add(cell);
            }
        }
        return intersectingCells;
    }

    /** Returns the union (without duplicates) of the occupiers of all given cells. */
    protected Set<Body> getOccupiers(Collection<? extends GridCell> cells) {
        Set<Body> occupiers = new HashSet<Body>( (int) Math.ceil(cells.size() / 0.75) );
        for ( GridCell c : cells ) {
            occupiers.addAll(c.getOccupiers());
        }
        return occupiers;
    }

    /**
     * Removes an agent and its body.
     *
     * @return {@code true} if both the body and the agent mapping were removed
     */
    @Override
    public boolean remove(Agent a) {
        Body b = agentBodyMap.get(a);
        if ( removeBody(b) ) {
            return agentBodyMap.remove(a) != null;
        }
        return false;
    }

    /** Removes a non-fixed body from the world and from every cell its figure intersects. */
    @Override
    public boolean removeBody(Body b) {
        if ( bodies.remove(b) ) {
            removeFromCells(b);
            return true;
        }
        return false;
    }

    /** Removes a fixed body from the world and from every cell its figure intersects. */
    public boolean removeFixedBody(Body b) {
        if ( fixedBodies.remove(b) ) {
            removeFromCells(b);
            return true;
        }
        return false;
    }

    /**
     * Executes {@code act} for agent {@code a} and returns the resulting reward.
     * <p>
     * The reward function is first told about the action together with the cached previous
     * observation, which may be {@code null} if {@link #getObservation(Agent)} has not been
     * called yet. Then the action runs, queued collisions are dispatched to the listeners, a new
     * observation is constructed and cached, and the reward is computed from it.
     */
    @Override
    public double executeAction(Action act, Agent a) {
        reward.preAction(act, a, lastState);
        act.execute(a);
        processCollisionEvents();
        lastState = constructObservation(a);
        return reward.calculateReward(lastState);
    }

    /**
     * Builds a fresh observation from the agent's body and marks it terminal if the episode has
     * been ended.
     */
    protected State constructObservation(Agent a) {
        State s = agentBodyMap.get(a).getObservation(a);
        s.setTerminal(episodeEnd);
        return s;
    }

    /**
     * Returns the state by calling {@link AgentBody#getObservation(Agent)} on the agent's body.
     * The result is cached in {@link #lastState} and returned unchanged until the next action,
     * {@link #init()} or {@link #reset()} (single-agent shortcut).
     */
    @Override
    public State getObservation(Agent a) {
        if ( lastState == null ) {
            lastState = constructObservation(a);
        }
        return lastState;
    }

    /**
     * Returns the state dim as indicated by {@link AgentBody#getObservationDim()} of an
     * arbitrary registered agent body. Throws {@link java.util.NoSuchElementException} if no
     * agent has been added.
     */
    @Override
    public int getObservationDim() {
        Iterator<AgentBody> it = agentBodyMap.values().iterator();
        return it.next().getObservationDim();
    }

    /**
     * Starts from scratch: clears the episode flag and cached observation, takes every non-fixed
     * body off the grid, places them all at random free positions, and initialises every agent
     * body.
     */
    @Override
    public void init() {
        episodeEnd = false;
        lastState = null;
        for ( Body b : bodies ) {
            removeFromCells(b);
        }
        randomlyPositionAll();
        for ( Environment env : agentBodyMap.values() ) {
            env.init();
        }
    }

    /**
     * Starts a new episode. Only agent bodies are moved: they are taken off the grid, each is
     * placed at a random unoccupied position, and each agent body is then reset. All other
     * bodies stay where they are.
     */
    @Override
    public void reset() {
        episodeEnd = false;
        lastState = null;
        // remove all agents first so they can be reallocated to each other's old cells
        for ( Body b : agentBodyMap.values() ) {
            removeFromCells(b);
        }
        for ( Body b : agentBodyMap.values() ) {
            setRandomPosition(b);
        }
        for ( Environment env : agentBodyMap.values() ) {
            env.reset();
        }
    }

    /**
     * Rotates the body without checking for collisions (since body is assumed to fit in
     * one cell). Always returns true.
     */
    @Override
    public boolean rotateBody(Body b, double angle) {
        b.setBearing(b.getBearing()+angle);
        return true;
    }

    /**
     * Doesn't do anything. After initialisation of the world, the grid cannot be changed.
     */
    @Override
    public void setHeight(int height) {}

    /**
     * Doesn't do anything. After initialisation of the world, the grid cannot be changed.
     */
    @Override
    public void setWidth(int width) {}

    /**
     * Moves a body by ({@code dx}, {@code dy}) cells, stopping early at the first cell it cannot
     * enter. The move must be along a straight line or an exact diagonal ({@code |dx| == |dy|}).
     * Negative offsets are supported. A zero move does nothing and returns {@code false}.
     *
     * @return {@code true} if the body moved at least one cell
     * @throws IllegalArgumentException if the move is neither straight nor an exact diagonal
     */
    @Override
    public boolean translateBody(Body b, int dx, int dy) {
        int absx = Math.abs(dx),
            absy = Math.abs(dy);
        if ( absx > 0 && absy > 0 && absx != absy ) {
            throw new IllegalArgumentException("GridWorld: can only move in straight "
                    + "line, or in diagonal such that dx=dy.");
        }
        if ( absx == 0 && absy == 0 ) {
            return false;
        }
        // move in unit steps so that anything in the body's path is detected; the number of
        // steps is |dx| for horizontal and diagonal moves and |dy| for vertical ones
        int unitdx = Integer.signum(dx),
            unitdy = Integer.signum(dy),
            steps = Math.max(absx, absy);
        return translateBody(b, Direction.get(unitdx, unitdy), steps);
    }

    /**
     * Moves a body up to {@code speed} cells in direction {@code d}. Movement stops before a
     * border cell or before a cell the body may not enter (see {@link #move}). The body is
     * re-registered in its final cell and centred in it.
     *
     * @return {@code true} if the body moved at least one cell
     */
    public boolean translateBody(Body b, Direction d, int speed) {
        GridCell origCell = grid.getCell(b.getX(), b.getY()),
                 currCell = origCell;
        int i;
        for ( i = 0; i < speed; i++ ) {
            GridCell targetCell = currCell.go(d);
            if ( targetCell.isBorder() || !move(b, d, targetCell) ) {
                break;
            }
            currCell = targetCell;
        }
        if ( i > 0 ) {
            origCell.removeBody(b);
            currCell.addBody(b);
            b.setLocation(currCell.getCenterX(), currCell.getCenterY());
            return true;
        }
        return false;
    }

    /**
     * Decides whether {@code b} may enter {@code targetCell} when moving in direction {@code d},
     * and queues a collision event for every occupier it hits.
     * <ul>
     * <li>An obstacle or agent in the target cell blocks the move (one event, returns false).</li>
     * <li>Movable bodies are pushed one cell in {@code d}; if any push fails, the move fails.</li>
     * <li>All remaining occupiers just produce a collision event.</li>
     * </ul>
     */
    protected boolean move( Body b, Direction d, GridCell targetCell ) {
        // Snapshot: pushing a movable body (translateBody below) removes it from the cell it
        // occupies, which is normally targetCell itself. If getOccupiers() returns the cell's
        // live list, iterating it during that removal would throw a
        // ConcurrentModificationException.
        List<Body> occupiers = new ArrayList<Body>(targetCell.getOccupiers());
        for ( Body targetBody : occupiers ) {
            if ( targetBody.getType() == Body.OBSTACLE || targetBody.getType() == Body.AGENT ) {
                addCollisionEvent(b, targetBody, new Point2D.Double(d.x(), d.y()));
                return false;
            }
        }
        for ( Body targetBody : occupiers ) {
            if ( targetBody.getType() == Body.MOVABLE ) {
                addCollisionEvent(b, targetBody, new Point2D.Double(d.x(), d.y()));
                if ( !translateBody(targetBody, d, 1) ) {
                    return false;
                }
            }
        }
        // just create collision events for the other body types
        for ( Body targetBody : occupiers ) {
            if ( targetBody.getType() != Body.OBSTACLE && targetBody.getType() != Body.MOVABLE ) {
                addCollisionEvent(b, targetBody, new Point2D.Double(d.x(), d.y()));
            }
        }
        return true;
    }

    @Override
    public Grid getGrid() {
        return grid;
    }

    public RewardFunction getRewardFunction() {
        return reward;
    }

    /**
     * Sets the reward function and hands it a {@link RewardBroker} through which it can read
     * world internals and end the episode.
     */
    public void setRewardFunction(GridRewardFunction reward) {
        this.reward = reward;
        reward.setRewardBroker(new RewardBrokerImpl());
    }

    /** Registers {@code b} in every cell its figure currently intersects. */
    protected void addToCells(Body b) {
        for ( GridCell c : getIntersectingCellsLarge(b.getFigure()) ) {
            c.addBody(b);
        }
    }

    /** Unregisters {@code b} from every cell its figure currently intersects. */
    protected void removeFromCells(Body b) {
        for ( GridCell c : getIntersectingCellsLarge(b.getFigure()) ) {
            c.removeBody(b);
        }
    }

    /** Queues a collision of {@code b1} into {@code b2}, recorded at {@code b2}'s location. */
    protected void addCollisionEvent(Body b1, Body b2, Point2D.Double speed) {
        CollisionEvent event = new CollisionEvent(b1, b2, speed, (Point2D.Double)b2.getLocation());
        collisionEvents.add(event);
    }

    /** Drains the collision queue in FIFO order, notifying every listener of each event. */
    protected void processCollisionEvents() {
        while ( !collisionEvents.isEmpty() ) {
            CollisionEvent event = collisionEvents.remove();
            for ( CollisionListener listener : listeners ) {
                listener.collisionOccurred(event);
            }
        }
    }

    /**
     * Places every body in {@link #bodies} at a random position. Cells claimed by one body are
     * removed from the candidate list before the next body is placed.
     */
    protected void randomlyPositionAll() {
        List<Point> freePositions = new ArrayList<Point>(grid.getWidth()*grid.getHeight());
        for ( int x = 0; x < grid.getWidth(); x++ ) {
            for ( int y = 0; y < grid.getHeight(); y++ ) {
                freePositions.add( new Point(x,y) );
            }
        }
        for ( Body b : bodies ) {
            setRandomPositionFromList(b, freePositions);
        }
    }

    /**
     * Debugging aid: registers {@code b} in two hard-coded cells, chosen by whether the body is
     * wider than one cell. It prints the result of {@link #closedLoop} for that placement to
     * stdout and moves the body next to those cells.
     */
    protected void testClosedLoop(Body b) {
        double cellwidth = grid.getCellWidth(),
               cellheight = grid.getCellHeight();
        Rectangle2D bounds = b.getFigure().getBounds2D();
        int bodywidth = (int)Math.ceil((bounds.getWidth()-1) / cellwidth);
        int x, y;
        List<GridCell> cells = new ArrayList<GridCell>();
        if ( bodywidth > 1 ) {
            x = 0; y = 7;
            System.out.println("==================== Horizontal body");
            cells.add(grid.getCell(0, 700));
            cells.add(grid.getCell(100, 700));
        } else {
            x = 2; y = 8;
            System.out.println("==================== Vertical body");
            cells.add(grid.getCell(200, 800));
            cells.add(grid.getCell(200, 900));
        }
        for ( GridCell c : cells ) {
            c.addBody(b);
        }
        System.out.println("Closed loop: " + closedLoop(b, cells));
        b.setLocation( (x + 0.01)*cellwidth + 0.5*bounds.getWidth(),
                       (y + 0.01)*cellheight + 0.5*bounds.getHeight());
    }

    /**
     * Places {@code b}, which may span several cells, at a random position drawn from
     * {@code freePositions} (cell coordinates).
     * <p>
     * Attempts are repeated until every cell of the body's footprint is unoccupied and, for
     * obstacles, the placement does not close a loop (see {@link #closedLoop}). The claimed
     * cells are then removed from {@code freePositions}.
     * NOTE(review): this may loop forever if no valid position exists.
     *
     * @throws IllegalArgumentException if the body's footprint is larger than the grid
     */
    protected void setRandomPositionFromList(Body b, List<Point> freePositions) {
        double cellwidth = grid.getCellWidth(),
               cellheight = grid.getCellHeight();
        Rectangle2D bounds = b.getFigure().getBounds();
        // Footprint in cells. The -1 presumably compensates for getBounds() rounding outward to
        // integer units. The footprint is at least one cell, so that tiny bodies are still
        // registered in the grid.
        int bodywidth = Math.max(1, (int)Math.ceil((bounds.getWidth()-1) / cellwidth)),
            bodyheight = Math.max(1, (int)Math.ceil((bounds.getHeight()-1) / cellheight)),
            gridwidth = grid.getWidth(),
            gridheight = grid.getHeight();
        if ( bodywidth > gridwidth || bodyheight > gridheight ) {
            throw new IllegalArgumentException("GridWorld: body does not fit in the grid.");
        }
        Point pick;
        List<GridCell> cells = new ArrayList<GridCell>();
        List<Point> points = new ArrayList<Point>();
        boolean occupied;
        do {
            // undo the partial registration made by the previous, rejected attempt
            for ( GridCell c : cells ) {
                c.removeBody(b);
            }
            cells.clear();
            points.clear();
            occupied = false;
            do {
                pick = freePositions.get( (int)(Math.random()*freePositions.size()) );
            } while ( pick.x + bodywidth > gridwidth || pick.y + bodyheight > gridheight );
            claim:
            for ( int x = 0; x < bodywidth; x++ ) {
                for ( int y = 0; y < bodyheight; y++ ) {
                    // +1 keeps the probe point strictly inside the cell, away from its border
                    GridCell cell = grid.getCell((pick.x + x)*cellwidth+1, (pick.y + y)*cellheight+1);
                    if ( !cell.getOccupiers().isEmpty() ) {
                        occupied = true;
                        break claim;
                    }
                    cell.addBody(b);
                    cells.add(cell);
                    points.add( new Point(pick.x + x, pick.y + y) );
                }
            }
        } while ( occupied || closedLoop(b, cells) );
        // small offset so that floating-point error does not put the body's bounds exactly on a
        // cell border
        b.setLocation( (pick.x + 0.01)*cellwidth + 0.5*bounds.getWidth(),
                       (pick.y + 0.01)*cellheight + 0.5*bounds.getHeight());
        freePositions.removeAll(points);
    }

    /**
     * Places {@code b} at the centre of a random cell and registers it in the cells its figure
     * then intersects. Up to 100 random positions are tried. A position is accepted when none
     * of its cells is occupied and, for obstacles, it does not close a loop.
     *
     * @throws IllegalStateException if no acceptable position was found in 100 tries
     */
    protected void setRandomPosition(Body b) {
        final int maxTries = 100;
        int width = grid.getWidth(),
            height = grid.getHeight();
        double cellwidth = grid.getCellWidth(),
               cellheight = grid.getCellHeight();
        Rectangle2D bounds = b.getFigure().getBounds();
        for ( int attempt = 0; attempt < maxTries; attempt++ ) {
            double x = (int)(Math.random()*(width - (bounds.getWidth()-1) / cellwidth))*cellwidth,
                   y = (int)(Math.random()*(height - (bounds.getHeight()-1) / cellheight))*cellheight;
            // Place at the cell centre rather than flush with the cell border: if the body's
            // bounds lay exactly on a cell border, floating-point error could make the grid think
            // the body is in the wrong cell.
            b.setLocation( x + 0.5*cellwidth, y + 0.5*cellheight);
            List<GridCell> cells = getIntersectingCellsLarge( b.getFigure() );
            if ( !isAnyOccupied(cells) && !closedLoop(b, cells) ) {
                for ( GridCell c : cells ) {
                    c.addBody(b);
                }
                return;
            }
        }
        throw new IllegalStateException("Could not reposition body after 100 tries.");
    }

    /** Returns {@code true} if at least one of the given cells has an occupier. */
    private static boolean isAnyOccupied(List<GridCell> cells) {
        for ( GridCell c : cells ) {
            if ( !c.getOccupiers().isEmpty() ) {
                return true;
            }
        }
        return false;
    }

    /**
     * Heuristically checks whether placing obstacle {@code b} in {@code cells} would close a
     * loop of obstacles. Non-obstacle bodies always yield {@code false}.
     * <p>
     * For every cell, each neighbour containing an obstacle (other than the previously visited
     * cell of {@code cells}) starts a chain that {@link #closedLoopSub} follows.
     */
    protected boolean closedLoop(Body b, List<GridCell> cells) {
        if ( b.getType() != Body.OBSTACLE ) {
            return false;
        }
        List<GridCell> checkedCells = new ArrayList<GridCell>(20);
        GridCell prevCell = null;
        for ( GridCell cell : cells ) {
            GridCell[] neighbours = cell.getNeighbours();
            for ( int i = 0; i < neighbours.length; i++ ) {
                for ( Body occupier : neighbours[i].getOccupiers() ) {
                    if ( !neighbours[i].equals(prevCell) && occupier.getType() == Body.OBSTACLE ) {
                        checkedCells.add(cell);
                        if ( closedLoopSub(checkedCells, cell, neighbours[i]) ) {
                            return true;
                        }
                        break;
                    }
                }
            }
            prevCell = cell;
        }
        return false;
    }

    /**
     * Recursive step of {@link #closedLoop}: follows a chain of obstacle cells from
     * {@code prevCell} into {@code currCell}.
     * <p>
     * Returns {@code true} in two cases:
     * <ul>
     * <li>The chain reaches a cell already in {@code checkedCells}.</li>
     * <li>The chain reaches the grid rim while the visited cells include one at the rim and one
     * not at the same rim side.</li>
     * </ul>
     * {@code checkedCells} is extended as cells are visited and may contain duplicates.
     */
    protected boolean closedLoopSub(List<GridCell> checkedCells, GridCell prevCell, GridCell currCell) {
        // The rim check below also rejects some rare placements that are not actually a
        // closed loop; that is accepted as harmless.
        if ( currCell.isAtRim() ) {
            // first rim side of currCell among directions 0, 2, 4, 6 (presumably the four
            // straight directions — confirm in GridCell)
            int dir;
            for ( dir = 0; dir < 8; dir += 2 ) {
                if ( currCell.isAtRim(dir) ) {
                    break;
                }
            }
            boolean oneNotAtRim = false, oneAtRim = false;
            for ( GridCell cell : checkedCells ) {
                if ( cell.isAtRim() ) {
                    oneAtRim = true;
                }
                if ( !cell.isAtRim(dir) ) {
                    oneNotAtRim = true;
                }
                if ( oneNotAtRim && oneAtRim ) {
                    return true;
                }
            }
            checkedCells.add(currCell);
            return false;
        }
        GridCell[] neighbours = currCell.getNeighbours();
        for ( int i = 0; i < neighbours.length; i++ ) {
            // skip where we came from and cells adjacent to it, to avoid trivial "loops"
            if ( !neighbours[i].equals(prevCell) && !neighbours[i].isNeighbour(prevCell) ) {
                if ( checkedCells.contains(neighbours[i]) ) {
                    return true;
                }
                for ( Body occupier : neighbours[i].getOccupiers() ) {
                    if ( occupier.getType() == Body.OBSTACLE ) {
                        checkedCells.add(currCell); // will lead to doubles
                        if ( closedLoopSub(checkedCells, currCell, neighbours[i]) ) {
                            return true;
                        }
                        break;
                    }
                }
            }
        }
        return false;
    }

    /**
     * Restricted view of the world handed to the reward function (see
     * {@link #setRewardFunction}).
     */
    public interface RewardBroker {
        /** Returns the live agent-to-body map of the world. */
        Map<Agent, AgentBody> getAgentBodyMap();

        /** Returns the live list of non-fixed bodies of the world. */
        List<Body> getBodies();

        /** Returns the world's grid. */
        Grid getGrid();

        /** Marks the current episode as ended; subsequent observations are terminal. */
        void endEpisode();
    }

    /** Broker backed directly by this world's fields (no copies are made). */
    protected class RewardBrokerImpl implements RewardBroker {
        @Override
        public Map<Agent, AgentBody> getAgentBodyMap() {
            return agentBodyMap;
        }

        @Override
        public List<Body> getBodies() {
            return bodies;
        }

        @Override
        public Grid getGrid() {
            return grid;
        }

        @Override
        public void endEpisode() {
            episodeEnd = true;
        }
    }
}