/***********************************************************************************************************************
*
* Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
**********************************************************************************************************************/
package eu.stratosphere.compiler.postpass;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import eu.stratosphere.api.common.operators.DualInputOperator;
import eu.stratosphere.api.common.operators.base.BulkIterationBase;
import eu.stratosphere.api.common.operators.base.DeltaIterationBase;
import eu.stratosphere.api.common.operators.Operator;
import eu.stratosphere.api.common.operators.SingleInputOperator;
import eu.stratosphere.api.common.operators.base.GenericDataSourceBase;
import eu.stratosphere.api.common.operators.base.GroupReduceOperatorBase;
import eu.stratosphere.api.common.operators.util.FieldList;
import eu.stratosphere.api.common.typeutils.TypeComparator;
import eu.stratosphere.api.common.typeutils.TypeComparatorFactory;
import eu.stratosphere.api.common.typeutils.TypePairComparatorFactory;
import eu.stratosphere.api.common.typeutils.TypeSerializer;
import eu.stratosphere.api.common.typeutils.TypeSerializerFactory;
import eu.stratosphere.api.java.operators.translation.PlanUnwrappingReduceGroupOperator;
import eu.stratosphere.api.java.tuple.Tuple;
import eu.stratosphere.api.java.typeutils.AtomicType;
import eu.stratosphere.api.java.typeutils.CompositeType;
import eu.stratosphere.types.TypeInformation;
import eu.stratosphere.api.java.typeutils.runtime.RuntimeComparatorFactory;
import eu.stratosphere.api.java.typeutils.runtime.RuntimePairComparatorFactory;
import eu.stratosphere.api.java.typeutils.runtime.RuntimeStatelessSerializerFactory;
import eu.stratosphere.api.java.typeutils.runtime.RuntimeStatefulSerializerFactory;
import eu.stratosphere.compiler.CompilerException;
import eu.stratosphere.compiler.CompilerPostPassException;
import eu.stratosphere.compiler.plan.BulkIterationPlanNode;
import eu.stratosphere.compiler.plan.BulkPartialSolutionPlanNode;
import eu.stratosphere.compiler.plan.Channel;
import eu.stratosphere.compiler.plan.DualInputPlanNode;
import eu.stratosphere.compiler.plan.NAryUnionPlanNode;
import eu.stratosphere.compiler.plan.OptimizedPlan;
import eu.stratosphere.compiler.plan.PlanNode;
import eu.stratosphere.compiler.plan.SingleInputPlanNode;
import eu.stratosphere.compiler.plan.SinkPlanNode;
import eu.stratosphere.compiler.plan.SolutionSetPlanNode;
import eu.stratosphere.compiler.plan.SourcePlanNode;
import eu.stratosphere.compiler.plan.WorksetIterationPlanNode;
import eu.stratosphere.compiler.plan.WorksetPlanNode;
import eu.stratosphere.compiler.util.NoOpUnaryUdfOp;
import eu.stratosphere.pact.runtime.task.DriverStrategy;
/**
 * The post-optimizer plan traversal. This traversal fills in the API specific utilities (serializers and
 * comparators).
 * <p>
 * The traversal starts at the data sinks of the plan and follows every input channel to the node that
 * produces it. For each channel it sets the serializer and, where the ship or local strategy requires one,
 * a comparator. For each node it sets the comparators required by the driver strategy.
 */
public class JavaApiPostPass implements OptimizerPostPass {

	/** Nodes that have been handled already, so that a node reachable over several paths is processed once. */
	private final Set<PlanNode> alreadyDone = new HashSet<PlanNode>();

	/**
	 * Runs the post pass over the given plan, starting one traversal per data sink.
	 *
	 * @param plan The optimized plan whose channels and nodes are to be parameterized.
	 */
	@Override
	public void postPass(OptimizedPlan plan) {
		// Start from a clean slate: without this, a second invocation on the same instance would keep
		// references to the nodes of the previous plan and would skip every node it has seen before.
		alreadyDone.clear();

		for (SinkPlanNode sink : plan.getDataSinks()) {
			traverse(sink);
		}
	}

	/**
	 * Processes a single plan node (at most once per post pass) and descends into its inputs.
	 *
	 * @param node The node to process.
	 * @throws CompilerException Thrown, if the step function of an iteration ends in a union node.
	 * @throws CompilerPostPassException Thrown, if the node is of an unknown type.
	 * @throws RuntimeException Thrown, if a node carries an operator of an unexpected type.
	 */
	protected void traverse(PlanNode node) {
		if (!alreadyDone.add(node)) {
			// already worked on that one
			return;
		}

		// Distinguish the node types. The iteration checks deliberately come before the generic
		// single-/dual-input checks; presumably the iteration plan nodes are subtypes of those, so
		// verify the class hierarchy before reordering.
		if (node instanceof SinkPlanNode) {
			// a sink has nothing to parameterize itself; just descend to the input channel
			traverseChannel(((SinkPlanNode) node).getInput());
		}
		else if (node instanceof SourcePlanNode) {
			postPassSource((SourcePlanNode) node);
		}
		else if (node instanceof BulkIterationPlanNode) {
			postPassBulkIteration((BulkIterationPlanNode) node);
		}
		else if (node instanceof WorksetIterationPlanNode) {
			postPassWorksetIteration((WorksetIterationPlanNode) node);
		}
		else if (node instanceof SingleInputPlanNode) {
			postPassSingleInput((SingleInputPlanNode) node);
		}
		else if (node instanceof DualInputPlanNode) {
			postPassDualInput((DualInputPlanNode) node);
		}
		else if (node instanceof BulkPartialSolutionPlanNode ||
				node instanceof SolutionSetPlanNode ||
				node instanceof WorksetPlanNode)
		{
			// These are the sources of the iterative step functions. Nothing is set on them here.
		}
		else if (node instanceof NAryUnionPlanNode) {
			// traverse to all child channels
			for (Iterator<Channel> channels = node.getInputs(); channels.hasNext(); ) {
				traverseChannel(channels.next());
			}
		}
		else {
			throw new CompilerPostPassException("Unknown node type encountered: " + node.getClass().getName());
		}
	}

	/** Sets the serializer of a data source from the output type of its operator. */
	private void postPassSource(SourcePlanNode source) {
		TypeInformation<?> typeInfo = getTypeInfoFromSource(source);
		source.setSerializer(createSerializer(typeInfo));
	}

	/** Sets the iteration channel serializer of a bulk iteration and traverses its input and step function. */
	private void postPassBulkIteration(BulkIterationPlanNode iterationNode) {
		if (iterationNode.getRootOfStepFunction() instanceof NAryUnionPlanNode) {
			throw new CompilerException("Optimizer cannot compile an iteration step function where next partial solution is created by a Union node.");
		}

		// The termination criterion is handled first. Only the input channel of its root node is
		// processed; the root node itself is not passed to traverse().
		if (iterationNode.getRootOfTerminationCriterion() != null) {
			SingleInputPlanNode addMapper = (SingleInputPlanNode) iterationNode.getRootOfTerminationCriterion();
			traverseChannel(addMapper.getInput());
		}

		BulkIterationBase<?> operator = (BulkIterationBase<?>) iterationNode.getPactContract();

		// set the serializer
		iterationNode.setSerializerForIterationChannel(createSerializer(operator.getOperatorInfo().getOutputType()));

		// done, we can now propagate our info down
		traverseChannel(iterationNode.getInput());
		traverse(iterationNode.getRootOfStepFunction());
	}

	/** Sets the serializers and the solution set comparator of a workset iteration and traverses its inputs and step function. */
	private void postPassWorksetIteration(WorksetIterationPlanNode iterationNode) {
		if (iterationNode.getNextWorkSetPlanNode() instanceof NAryUnionPlanNode) {
			throw new CompilerException("Optimizer cannot compile a workset iteration step function where the next workset is produced by a Union node.");
		}
		if (iterationNode.getSolutionSetDeltaPlanNode() instanceof NAryUnionPlanNode) {
			throw new CompilerException("Optimizer cannot compile a workset iteration step function where the solution set delta is produced by a Union node.");
		}

		DeltaIterationBase<?, ?> operator = (DeltaIterationBase<?, ?>) iterationNode.getPactContract();

		// The solution set uses the first input type, the workset the second input type.
		iterationNode.setSolutionSetSerializer(createSerializer(operator.getOperatorInfo().getFirstInputType()));
		iterationNode.setWorksetSerializer(createSerializer(operator.getOperatorInfo().getSecondInputType()));

		// No explicit sort orders are passed for the solution set keys, so getSortOrders() supplies an
		// all-true array.
		iterationNode.setSolutionSetComparator(createComparator(operator.getOperatorInfo().getFirstInputType(),
				iterationNode.getSolutionSetKeyFields(), getSortOrders(iterationNode.getSolutionSetKeyFields(), null)));

		// traverse the inputs
		traverseChannel(iterationNode.getInput1());
		traverseChannel(iterationNode.getInput2());

		// traverse the step function
		traverse(iterationNode.getSolutionSetDeltaPlanNode());
		traverse(iterationNode.getNextWorkSetPlanNode());
	}

	/** Sets the driver comparator of a single-input node (if required) and traverses its regular and broadcast inputs. */
	private void postPassSingleInput(SingleInputPlanNode sn) {
		if (!(sn.getOptimizerNode().getPactContract() instanceof SingleInputOperator)) {
			// Special case for delta iterations: a no-op operator only forwards its input, so only
			// the input channel is traversed (broadcast inputs are not visited on this path).
			if (sn.getOptimizerNode().getPactContract() instanceof NoOpUnaryUdfOp) {
				traverseChannel(sn.getInput());
				return;
			}
			throw new RuntimeException("Wrong operator type found in post pass.");
		}

		SingleInputOperator<?, ?, ?> singleInputOperator = (SingleInputOperator<?, ?, ?>) sn.getOptimizerNode().getPactContract();

		// parameterize the node's driver strategy
		if (sn.getDriverStrategy().requiresComparator()) {
			sn.setComparator(createComparator(singleInputOperator.getOperatorInfo().getInputType(), sn.getKeys(),
					getSortOrders(sn.getKeys(), sn.getSortOrders())));
		}

		// done, we can now propagate our info down
		traverseChannel(sn.getInput());

		// don't forget the broadcast inputs
		for (Channel c : sn.getBroadcastInputs()) {
			traverseChannel(c);
		}
	}

	/** Sets the driver comparators of a dual-input node (if required) and traverses its regular and broadcast inputs. */
	private void postPassDualInput(DualInputPlanNode dn) {
		if (!(dn.getOptimizerNode().getPactContract() instanceof DualInputOperator)) {
			throw new RuntimeException("Wrong operator type found in post pass.");
		}

		DualInputOperator<?, ?, ?, ?> dualInputOperator = (DualInputOperator<?, ?, ?, ?>) dn.getOptimizerNode().getPactContract();

		// parameterize the node's driver strategy: one comparator per input plus the pair comparator
		if (dn.getDriverStrategy().requiresComparator()) {
			dn.setComparator1(createComparator(dualInputOperator.getOperatorInfo().getFirstInputType(), dn.getKeysForInput1(),
					getSortOrders(dn.getKeysForInput1(), dn.getSortOrders())));
			dn.setComparator2(createComparator(dualInputOperator.getOperatorInfo().getSecondInputType(), dn.getKeysForInput2(),
					getSortOrders(dn.getKeysForInput2(), dn.getSortOrders())));

			dn.setPairComparator(createPairComparator(dualInputOperator.getOperatorInfo().getFirstInputType(),
					dualInputOperator.getOperatorInfo().getSecondInputType()));
		}

		traverseChannel(dn.getInput1());
		traverseChannel(dn.getInput2());

		// don't forget the broadcast inputs
		for (Channel c : dn.getBroadcastInputs()) {
			traverseChannel(c);
		}
	}

	/**
	 * Sets the serializer and the required comparators on the given channel and then descends to the
	 * node that feeds the channel.
	 *
	 * @param channel The channel to parameterize.
	 */
	private void traverseChannel(Channel channel) {
		PlanNode source = channel.getSource();
		Operator<?> javaOp = source.getPactContract();

		// By default, the channel transports the output type of the operator that feeds it.
		TypeInformation<?> type = javaOp.getOperatorInfo().getOutputType();

		// When the feeding node runs a group reduce with a combine driver strategy, the output type of the
		// operator's own input is used instead.
		// NOTE(review): presumably because a combiner emits records of its input type - confirm.
		if (javaOp instanceof GroupReduceOperatorBase &&
				(source.getDriverStrategy() == DriverStrategy.SORTED_GROUP_COMBINE || source.getDriverStrategy() == DriverStrategy.ALL_GROUP_COMBINE)) {
			GroupReduceOperatorBase<?, ?, ?> groupNode = (GroupReduceOperatorBase<?, ?, ?>) javaOp;
			type = groupNode.getInput().getOperatorInfo().getOutputType();
		}
		// NOTE(review): if PlanUnwrappingReduceGroupOperator is a GroupReduceOperatorBase, this branch can
		// never be reached because the branch above already covers SORTED_GROUP_COMBINE - confirm.
		else if (javaOp instanceof PlanUnwrappingReduceGroupOperator &&
				source.getDriverStrategy().equals(DriverStrategy.SORTED_GROUP_COMBINE)) {
			PlanUnwrappingReduceGroupOperator<?, ?, ?> groupNode = (PlanUnwrappingReduceGroupOperator<?, ?, ?>) javaOp;
			type = groupNode.getInput().getOperatorInfo().getOutputType();
		}

		// the serializer always exists
		channel.setSerializer(createSerializer(type));

		// parameterize the ship strategy
		if (channel.getShipStrategy().requiresComparator()) {
			channel.setShipStrategyComparator(createComparator(type, channel.getShipStrategyKeys(),
					getSortOrders(channel.getShipStrategyKeys(), channel.getShipStrategySortOrder())));
		}

		// parameterize the local strategy
		if (channel.getLocalStrategy().requiresComparator()) {
			channel.setLocalStrategyComparator(createComparator(type, channel.getLocalStrategyKeys(),
					getSortOrders(channel.getLocalStrategyKeys(), channel.getLocalStrategySortOrder())));
		}

		// descend to the channel's source
		traverse(source);
	}

	/**
	 * Gets the output type information of the operator behind a data source node.
	 *
	 * @param node The data source node.
	 * @return The output type of the source's operator.
	 * @throws RuntimeException Thrown, if the operator is not a {@link GenericDataSourceBase}.
	 */
	@SuppressWarnings("unchecked")
	private static <T> TypeInformation<T> getTypeInfoFromSource(SourcePlanNode node) {
		Operator<?> op = node.getOptimizerNode().getPactContract();

		if (op instanceof GenericDataSourceBase) {
			return ((GenericDataSourceBase<T, ?>) op).getOperatorInfo().getOutputType();
		} else {
			throw new RuntimeException("Wrong operator type found in post pass.");
		}
	}

	/**
	 * Creates the serializer factory for the given type. The stateful or the stateless runtime factory is
	 * chosen depending on what the created serializer reports via {@code isStateful()}.
	 *
	 * @param typeInfo The type to create the serializer factory for.
	 * @return The factory wrapping the serializer created by the type information.
	 */
	private static <T> TypeSerializerFactory<?> createSerializer(TypeInformation<T> typeInfo) {
		TypeSerializer<T> serializer = typeInfo.createSerializer();

		if (serializer.isStateful()) {
			return new RuntimeStatefulSerializerFactory<T>(serializer, typeInfo.getTypeClass());
		} else {
			return new RuntimeStatelessSerializerFactory<T>(serializer, typeInfo.getTypeClass());
		}
	}

	/**
	 * Creates the comparator factory for the given type, key fields and sort orders.
	 *
	 * @param typeInfo The type of the records to compare. Only composite types are supported.
	 * @param keys The key fields to compare on.
	 * @param sortOrder The sort order flag for each key field.
	 * @return The factory wrapping the comparator created by the composite type.
	 * @throws UnsupportedOperationException Thrown, if the type is an atomic type.
	 * @throws RuntimeException Thrown, if the type is neither a composite nor an atomic type.
	 */
	@SuppressWarnings("unchecked")
	private static <T> TypeComparatorFactory<?> createComparator(TypeInformation<T> typeInfo, FieldList keys, boolean[] sortOrder) {
		TypeComparator<T> comparator;

		if (typeInfo instanceof CompositeType) {
			comparator = ((CompositeType<T>) typeInfo).createComparator(keys.toArray(), sortOrder);
		}
		else if (typeInfo instanceof AtomicType) {
			// handle grouping of atomic types
			throw new UnsupportedOperationException("Grouping on atomic types is currently not implemented. " + typeInfo);
		}
		else {
			throw new RuntimeException("Unrecognized type: " + typeInfo);
		}

		return new RuntimeComparatorFactory<T>(comparator);
	}

	/**
	 * Creates the pair comparator factory for a binary operation. The type information is only used to
	 * verify that both sides are tuple types.
	 *
	 * @param typeInfo1 The type of the first input.
	 * @param typeInfo2 The type of the second input.
	 * @return A new runtime pair comparator factory.
	 * @throws RuntimeException Thrown, if one of the two types is not a tuple type.
	 */
	private static <T1 extends Tuple, T2 extends Tuple> TypePairComparatorFactory<T1,T2> createPairComparator(TypeInformation<?> typeInfo1, TypeInformation<?> typeInfo2) {
		if (!(typeInfo1.isTupleType() && typeInfo2.isTupleType())) {
			throw new RuntimeException("The runtime currently supports only keyed binary operations on tuples.");
		}

		return new RuntimePairComparatorFactory<T1,T2>();
	}

	/**
	 * Returns the given sort orders, or a default if none are given.
	 *
	 * @param keys The key fields; their number determines the length of the default array.
	 * @param orders The sort orders, or {@code null}.
	 * @return {@code orders} itself if it is not {@code null}, otherwise a new array with one {@code true}
	 *         entry per key field.
	 */
	private static boolean[] getSortOrders(FieldList keys, boolean[] orders) {
		if (orders == null) {
			orders = new boolean[keys.size()];
			Arrays.fill(orders, true);
		}
		return orders;
	}
}