// Fabricate the output of step 1. Each partition contributes the number of
// trees that Step1Mapper.nbTrees() assigns to it. Tree ids are numbered
// consecutively across all partitions via treeIndex. Every tree is a Leaf
// labelled with the index of the partition that "built" it, so the origin
// of each tree can be tracked in the step-2 outputs.
int treeIndex = 0;
for (int partition = 0; partition < nbMappers; partition++) {
  final int partitionTreeCount = Step1Mapper.nbTrees(nbMappers, nbTrees, partition);
  int localTree = 0;
  while (localTree < partitionTreeCount) {
    // the key pairs the owning partition with the global tree index
    keys[treeIndex] = new TreeID(partition, treeIndex);
    trees[treeIndex] = new Leaf(partition);
    localTree++;
    treeIndex++;
  }
  // record how many instances this partition's split holds
  sizes[partition] = splits[partition].length;
}

// Persist the fabricated step-1 results (keys, trees, per-partition sizes)
// to a file on the local file system, from which they are later reloaded
// partition by partition through InterResults.load.
FileSystem fs = FileSystem.getLocal(new Configuration());
Path forestPath = new Path("testdata/Step2MapperTest.forest");
InterResults.store(fs, forestPath, keys, trees, sizes);

// Single (key, value) holder instances; they are re-set for every input
// line that is handed to the mapper.
LongWritable key = new LongWritable();
Text value = new Text();
for (int partition = 0; partition < nbMappers; partition++) {
String[] split = splits[partition];
// number of trees that will be handled by the mapper
int nbConcerned = Step2Mapper.nbConcerned(nbMappers, nbTrees, partition);
PartialOutputCollector output = new PartialOutputCollector(nbConcerned);
// load the current mapper's (key, tree) pairs
TreeID[] curKeys = new TreeID[nbConcerned];
Node[] curTrees = new Node[nbConcerned];
InterResults.load(fs, forestPath, nbMappers, nbTrees, partition, curKeys, curTrees);
// simulate the job
MockStep2Mapper mapper = new MockStep2Mapper(partition, dataset, curKeys, curTrees, split.length);
for (int index = 0; index < split.length; index++) {
key.set(index);
value.set(split[index]);
mapper.map(key, value, output, Reporter.NULL);
}
mapper.close();
// make sure the mapper did not return its own trees
assertEquals(nbConcerned, output.nbOutputs());
// check the returned results
int current = 0;
for (int index = 0; index < nbTrees; index++) {
if (keys[index].partition() == partition) {
// should not be part of the results
continue;
}
TreeID k = output.getKeys()[current];
// the tree should receive the partition's index
assertEquals(partition, k.partition());
// make sure all the trees of the other partitions are handled in the
// correct order
assertEquals(index, k.treeId());
int[] predictions = output.getValues()[current].getPredictions();
// all the instances of the partition should be classified
assertEquals(split.length, predictions.length);