package net.javlov.example;
import java.util.ArrayList;
import java.util.List;
import net.javlov.Agent;
import net.javlov.Option;
import net.javlov.State;
import net.javlov.VectorState;
import net.javlov.world.AgentBody;
import net.javlov.world.grid.GridWorld;
/**
 * A {@link GridWorld} in which each grid cell offers only a precomputed subset of
 * options. Observations are tagged with the option set of the cell they refer to,
 * and the (single) agent is moved to grid position (0, height - 1) on init and reset.
 */
public class GridLimitedOptionsWorld extends GridWorld {

    /**
     * Per-cell eligible options, indexed as {@code [i][j]}, where {@code i} ranges over
     * {@code width} and {@code j} over {@code height}. Entries are {@code null} until
     * {@link #setOptionPool(List)} has been called.
     */
    List<Option>[][] stateOptions;

    /**
     * Creates the world and allocates an empty option table with one slot per cell.
     *
     * @param width      number of cells along the first grid axis
     * @param height     number of cells along the second grid axis
     * @param cellWidth  width of a single cell (passed through to {@link GridWorld})
     * @param cellHeight height of a single cell (passed through to {@link GridWorld})
     */
    public GridLimitedOptionsWorld(int width, int height, double cellWidth, double cellHeight) {
        super(width, height, cellWidth, cellHeight);
        // Java cannot create a generic array directly. This is safe because only
        // List<Option> instances are ever stored into the table (see setOptionPool).
        @SuppressWarnings({"unchecked", "rawtypes"})
        List<Option>[][] table = new ArrayList[width][height];
        stateOptions = table;
    }

    /**
     * Recomputes, for every cell (i, j), the options from {@code options} whose
     * {@link Option#isEligible(State)} returns true for the state vector {@code [i, j]}.
     * Previously computed option sets are replaced.
     *
     * @param options the pool of candidate options to filter per cell
     */
    public void setOptionPool(List<? extends Option> options) {
        for (int i = 0; i < stateOptions.length; i++) {
            for (int j = 0; j < stateOptions[i].length; j++) {
                VectorState state = new VectorState(new double[]{i, j});
                List<Option> eligibleOptions = new ArrayList<Option>();
                for (Option o : options) {
                    if (o.isEligible(state)) {
                        eligibleOptions.add(o);
                    }
                }
                stateOptions[i][j] = eligibleOptions;
            }
        }
    }

    /**
     * Builds the superclass observation and attaches the option set of the cell it
     * refers to.
     *
     * <p>NOTE(review): assumes the observation data is a {@code double[]} whose first two
     * components are the cell indices, matching the {@code [i, j]} vectors used in
     * {@link #setOptionPool(List)}. Confirm against {@code GridWorld.constructObservation}.
     */
    @Override
    protected State constructObservation(Agent a) {
        State s = super.constructObservation(a);
        double[] data = (double[]) s.getData();
        s.setOptionSet(stateOptions[(int) data[0]][(int) data[1]]);
        return s;
    }

    /** Resets the world, then moves the agent to the start cell. */
    @Override
    public void reset() {
        super.reset();
        placeAgentAtStartCell();
    }

    /** Initializes the world, then moves the agent to the start cell. */
    @Override
    public void init() {
        super.init();
        placeAgentAtStartCell();
    }

    /**
     * Moves the agent body to grid position (0, height - 1) and centres it in that cell.
     * Shared by {@link #init()} and {@link #reset()} so both always agree on the start state.
     */
    private void placeAgentAtStartCell() {
        // Takes the first body in the map. This assumes the world holds exactly one agent;
        // the lookup throws if the map is empty.
        AgentBody b = agentBodyMap.values().iterator().next();
        int startCol = 0;
        int startRow = grid.getHeight() - 1;
        removeFromCells(b);
        grid.getCellAtGridPosition(startCol, startRow).addBody(b);
        b.setLocation((startCol + 0.5) * grid.getCellWidth(),
                      (startRow + 0.5) * grid.getCellHeight());
    }
}