{
    // Read both output counts of the head up front: the index of the sync output
    // is the sum of the step-function outputs and the final-result outputs.
    final int stepFunctionOutputs = headConfig.getNumOutputs();
    final int finalResultOutputs = headFinalOutputConfig.getNumOutputs();

    // A workset iteration whose step function never touches the workset cannot be built.
    if (stepFunctionOutputs == 0) {
        throw new CompilerException("The workset iteration has no operation on the workset inside the step function.");
    }

    headConfig.setIterationHeadFinalOutputConfig(headFinalOutputConfig);
    headConfig.setIterationHeadIndexOfSyncOutput(stepFunctionOutputs + finalResultOutputs);

    // The memory assigned to the iteration node must be positive; it is split
    // evenly between the back channel and the solution set.
    final double iterationMemory = iterNode.getRelativeMemoryPerSubTask();
    if (iterationMemory <= 0) {
        throw new CompilerException("Bug: No memory has been assigned to the workset iteration.");
    }
    final double halfOfMemory = iterationMemory / 2;

    headConfig.setIsWorksetIteration();
    headConfig.setRelativeBackChannelMemory(halfOfMemory);
    headConfig.setRelativeSolutionSetMemory(halfOfMemory);

    // Hand the head the solution set's serializer and comparator.
    headConfig.setSolutionSetSerializer(iterNode.getSolutionSetSerializer());
    headConfig.setSolutionSetComparator(iterNode.getSolutionSetComparator());
}
// --------------------------- create the sync task ---------------------------
final TaskConfig syncConfig;
{
    // A single-instance auxiliary vertex that runs the iteration synchronization task.
    final AbstractJobVertex syncVertex = new AbstractJobVertex("Sync (" + iterNode.getNodeName() + ")");
    syncVertex.setInvokableClass(IterationSynchronizationSinkTask.class);
    syncVertex.setParallelism(1);
    this.auxVertices.add(syncVertex);

    syncConfig = new TaskConfig(syncVertex.getConfiguration());
    // Gate 0 is iterative; the event count until interrupt equals the head's parallelism
    // (presumably one event per head subtask per superstep -- confirm in TaskConfig).
    final int headParallelism = headVertex.getParallelism();
    syncConfig.setGateIterativeWithNumberOfEventsUntilInterrupt(0, headParallelism);

    // Workset iterations require an explicit, positive upper bound on the iteration count.
    final int iterationLimit = iterNode.getIterationNode().getIterationContract().getMaximumNumberOfIterations();
    if (iterationLimit < 1) {
        throw new CompilerException("Cannot create workset iteration with unspecified maximum number of iterations.");
    }
    syncConfig.setNumberOfIterations(iterationLimit);

    // Feed the head vertex's output into the sync vertex.
    syncVertex.connectNewDataSetAsInput(headVertex, DistributionPattern.POINTWISE);
}
// ----------------------------- create the iteration tails -----------------------------
// ----------------------- for next workset and solution set delta -----------------------
{
// we have three possible cases:
// 1) Two tails, one for workset update, one for solution set update
//    -> hasWorksetTail && hasSolutionSetTail (solution set updates are not immediate)
// 2) One tail for workset update, solution set update happens in an intermediate task
//    -> hasWorksetTail && !hasSolutionSetTail (only possible with immediate solution set updates)
// 3) One tail for solution set update, workset update happens in an intermediate task
//    -> !hasWorksetTail, which always forces hasSolutionSetTail
final PlanNode nextWorksetNode = iterNode.getNextWorkSetPlanNode();
final PlanNode solutionDeltaNode = iterNode.getSolutionSetDeltaPlanNode();
// The next-workset node is a tail only if nothing inside the step function consumes its output.
final boolean hasWorksetTail = nextWorksetNode.getOutgoingChannels().isEmpty();
// A dedicated solution-set tail is required unless updates are immediate AND the workset node is already the tail.
final boolean hasSolutionSetTail = (!iterNode.isImmediateSolutionSetUpdate()) || (!hasWorksetTail);
{
// get the vertex for the workset update
final TaskConfig worksetTailConfig;
// The node is either its own job vertex or a task chained into another vertex.
AbstractJobVertex nextWorksetVertex = (AbstractJobVertex) this.vertices.get(nextWorksetNode);
if (nextWorksetVertex == null) {
// nextWorksetVertex is chained
// use the containing vertex and the chained task's own config
TaskInChain taskInChain = this.chainedTasks.get(nextWorksetNode);
if (taskInChain == null) {
throw new CompilerException("Bug: Next workset node not found as vertex or chained task.");
}
nextWorksetVertex = (AbstractJobVertex) taskInChain.getContainingVertex();
worksetTailConfig = taskInChain.getTaskConfig();
} else {
worksetTailConfig = new TaskConfig(nextWorksetVertex.getConfiguration());
}
// mark the node to perform workset updates
// (done in every case, whether or not the node is the tail)
worksetTailConfig.setIsWorksetIteration();
worksetTailConfig.setIsWorksetUpdate();
if (hasWorksetTail) {
// Only a real tail runs as an iteration tail task and needs the workset serializer for its output.
// NOTE(review): in the chained case this replaces the invokable of the containing vertex.
nextWorksetVertex.setInvokableClass(IterationTailPactTask.class);
worksetTailConfig.setOutputSerializer(iterNode.getWorksetSerializer());
}
}
{
// get the vertex for the solution set delta (same vertex-or-chained lookup as above)
final TaskConfig solutionDeltaConfig;
AbstractJobVertex solutionDeltaVertex = (AbstractJobVertex) this.vertices.get(solutionDeltaNode);
if (solutionDeltaVertex == null) {
// last op is chained
TaskInChain taskInChain = this.chainedTasks.get(solutionDeltaNode);
if (taskInChain == null) {
throw new CompilerException("Bug: Solution Set Delta not found as vertex or chained task.");
}
solutionDeltaVertex = (AbstractJobVertex) taskInChain.getContainingVertex();
solutionDeltaConfig = taskInChain.getTaskConfig();
} else {
solutionDeltaConfig = new TaskConfig(solutionDeltaVertex.getConfiguration());
}
// mark the node to perform solution set updates
solutionDeltaConfig.setIsWorksetIteration();
solutionDeltaConfig.setIsSolutionSetUpdate();
if (hasSolutionSetTail) {
// cases 1 and 3: the solution set delta node is a tail and emits with the solution set serializer
solutionDeltaVertex.setInvokableClass(IterationTailPactTask.class);
solutionDeltaConfig.setOutputSerializer(iterNode.getSolutionSetSerializer());
// tell the head that it needs to wait for the solution set updates
headConfig.setWaitForSolutionSetUpdate();
}
else {
// no tail, intermediate update. must be immediate update
// (case 2). By the definition of hasSolutionSetTail, this branch is only reached when
// isImmediateSolutionSetUpdate() returned true, so the check below is a defensive guard
// that should never fire (assuming the getter returns the same value on both calls).
if (!iterNode.isImmediateSolutionSetUpdate()) {
throw new CompilerException("A solution set update without dedicated tail is not set to perform immediate updates.");
}
solutionDeltaConfig.setIsSolutionSetUpdateWithoutReprobe();
}
}
}
// ------------------- register the aggregators -------------------
// Collect all user-registered aggregators of the iteration contract.
AggregatorRegistry aggs = iterNode.getIterationNode().getIterationContract().getAggregators();
Collection<AggregatorWithName<?>> allAggregators = aggs.getAllRegisteredAggregators();

// The built-in empty-workset termination check owns a reserved aggregator name;
// reject any user aggregator that reuses it.
for (AggregatorWithName<?> userAggregator : allAggregators) {
    final String userAggregatorName = userAggregator.getName();
    if (userAggregatorName.equals(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME)) {
        throw new CompilerException("User defined aggregator used the same name as built-in workset termination check aggregator: "
            + WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME);
    }
}

// Both the head and the sync task receive the user aggregators.
headConfig.addIterationAggregators(allAggregators);
syncConfig.addIterationAggregators(allAggregators);

// Workset iterations terminate implicitly when the workset is empty, so any
// user-supplied convergence criterion (or its aggregator name) is rejected.
String convAggName = aggs.getConvergenceCriterionAggregatorName();
ConvergenceCriterion<?> convCriterion = aggs.getConvergenceCriterion();
if (convAggName != null || convCriterion != null) {
    throw new CompilerException("Error: Cannot use custom convergence criterion with workset iteration. Workset iterations have implicit convergence criterion where workset is empty.");
}

// Install the built-in workset-emptiness aggregator (a separate instance per config)
// and make the sync task use it as its convergence criterion.
headConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new LongSumAggregator());
syncConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new LongSumAggregator());
syncConfig.setConvergenceCriterion(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new WorksetEmptyConvergenceCriterion());