public Object answer(InvocationOnMock invocation) {
when(mockContext.getVertexNumTasks(mockManagedVertexId)).thenReturn(2);
newEdgeManagers.clear();
for (Entry<String, EdgeManagerDescriptor> entry :
((Map<String, EdgeManagerDescriptor>)invocation.getArguments()[2]).entrySet()) {
EdgeManager edgeManager = RuntimeUtils.createClazzInstance(
entry.getValue().getClassName());
final byte[] userPayload = entry.getValue().getUserPayload();
edgeManager.initialize(new EdgeManagerContext() {
@Override
public byte[] getUserPayload() {
return userPayload;
}
@Override
public String getSrcVertexName() {
return null;
}
@Override
public String getDestVertexName() {
return null;
}
});
newEdgeManagers.put(entry.getKey(), edgeManager);
}
return null;
}}).when(mockContext).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap());
// bipartite source vertices have 0 tasks, so all managed tasks start immediately
when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(0);
when(mockContext.getVertexNumTasks(mockSrcVertexId2)).thenReturn(0);
when(mockContext.getVertexNumTasks(mockSrcVertexId3)).thenReturn(1);
manager.onVertexStarted(null);
Assert.assertTrue(manager.pendingTasks.isEmpty());
Assert.assertTrue(scheduledTasks.size() == 4); // all tasks scheduled
scheduledTasks.clear();
when(mockContext.getVertexNumTasks(mockSrcVertexId1)).thenReturn(2);
when(mockContext.getVertexNumTasks(mockSrcVertexId2)).thenReturn(2);
byte[] payload =
VertexManagerEventPayloadProto.newBuilder().setOutputSize(5000L).build().toByteArray();
VertexManagerEvent vmEvent = new VertexManagerEvent("Vertex", payload);
// parallelism not change due to large data size
manager = createManager(conf, mockContext, 0.1f, 0.1f);
manager.onVertexStarted(null);
Assert.assertTrue(manager.pendingTasks.size() == 4); // no tasks scheduled
Assert.assertTrue(manager.numSourceTasks == 4);
manager.onVertexManagerEventReceived(vmEvent);
manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
// parallelism unchanged: setVertexParallelism must not have been called
verify(mockContext, times(0)).setVertexParallelism(anyInt(), any(VertexLocationHint.class), anyMap());
Assert.assertEquals(0, manager.pendingTasks.size()); // all tasks scheduled
Assert.assertEquals(4, scheduledTasks.size());
Assert.assertEquals(1, manager.numSourceTasksCompleted);
Assert.assertEquals(5000L, manager.completedSourceTasksOutputSize);
// parallelism changed due to small data size
scheduledTasks.clear();
payload =
VertexManagerEventPayloadProto.newBuilder().setOutputSize(500L).build().toByteArray();
vmEvent = new VertexManagerEvent("Vertex", payload);
manager = createManager(conf, mockContext, 0.5f, 0.5f);
manager.onVertexStarted(null);
Assert.assertEquals(4, manager.pendingTasks.size()); // no tasks scheduled
Assert.assertEquals(4, manager.numSourceTasks);
// task completion from non-bipartite stage does nothing
manager.onSourceTaskCompleted(mockSrcVertexId3, new Integer(0));
Assert.assertEquals(4, manager.pendingTasks.size()); // no tasks scheduled
Assert.assertEquals(4, manager.numSourceTasks);
Assert.assertEquals(0, manager.numSourceTasksCompleted);
manager.onVertexManagerEventReceived(vmEvent);
manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
Assert.assertEquals(4, manager.pendingTasks.size());
Assert.assertEquals(0, scheduledTasks.size()); // no tasks scheduled
Assert.assertEquals(1, manager.numSourceTasksCompleted);
Assert.assertEquals(1, manager.numVertexManagerEventsReceived);
Assert.assertEquals(500L, manager.completedSourceTasksOutputSize);
// ignore duplicate completion
manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(0));
Assert.assertEquals(4, manager.pendingTasks.size());
Assert.assertEquals(0, scheduledTasks.size()); // no tasks scheduled
Assert.assertEquals(1, manager.numSourceTasksCompleted);
Assert.assertEquals(500L, manager.completedSourceTasksOutputSize);
manager.onVertexManagerEventReceived(vmEvent);
manager.onSourceTaskCompleted(mockSrcVertexId1, new Integer(1));
// managedVertex tasks reduced
verify(mockContext).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap());
Assert.assertEquals(2, newEdgeManagers.size());
// TODO improve tests for parallelism
Assert.assertEquals(0, manager.pendingTasks.size()); // all tasks scheduled
Assert.assertEquals(2, scheduledTasks.size());
Assert.assertTrue(scheduledTasks.contains(new Integer(0)));
Assert.assertTrue(scheduledTasks.contains(new Integer(1)));
Assert.assertEquals(2, manager.numSourceTasksCompleted);
Assert.assertEquals(2, manager.numVertexManagerEventsReceived);
Assert.assertEquals(1000L, manager.completedSourceTasksOutputSize);
// more completions don't cause recalculation of parallelism
manager.onSourceTaskCompleted(mockSrcVertexId2, new Integer(0));
verify(mockContext).setVertexParallelism(eq(2), any(VertexLocationHint.class), anyMap());
Assert.assertEquals(2, newEdgeManagers.size());
EdgeManager edgeManager = newEdgeManagers.values().iterator().next();
Map<Integer, List<Integer>> targets = Maps.newHashMap();
DataMovementEvent dmEvent = new DataMovementEvent(1, new byte[0]);
edgeManager.routeDataMovementEventToDestination(dmEvent, 1, 2, targets);
Assert.assertEquals(1, targets.size());
Map.Entry<Integer, List<Integer>> e = targets.entrySet().iterator().next();
Assert.assertEquals(3, e.getKey().intValue());
Assert.assertEquals(1, e.getValue().size());
Assert.assertEquals(0, e.getValue().get(0).intValue());
targets.clear();
dmEvent = new DataMovementEvent(2, new byte[0]);
edgeManager.routeDataMovementEventToDestination(dmEvent, 0, 2, targets);
Assert.assertEquals(1, targets.size());
e = targets.entrySet().iterator().next();
Assert.assertEquals(0, e.getKey().intValue());
Assert.assertEquals(1, e.getValue().size());
Assert.assertEquals(1, e.getValue().get(0).intValue());
targets.clear();
edgeManager.routeInputSourceTaskFailedEventToDestination(2, 2, targets);
Assert.assertEquals(2, targets.size());
for (Map.Entry<Integer, List<Integer>> entry : targets.entrySet()) {
Assert.assertTrue(entry.getKey().intValue() == 4 || entry.getKey().intValue() == 5);
Assert.assertEquals(2, entry.getValue().size());
Assert.assertEquals(0, entry.getValue().get(0).intValue());