/***********************************************************************************************************************
*
* Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
**********************************************************************************************************************/
package eu.stratosphere.pact.compiler;
import org.junit.Test;
import eu.stratosphere.api.common.Plan;
import eu.stratosphere.api.common.operators.util.FieldList;
import eu.stratosphere.api.java.DataSet;
import eu.stratosphere.api.java.ExecutionEnvironment;
import eu.stratosphere.api.java.functions.KeySelector;
import eu.stratosphere.api.java.functions.ReduceFunction;
import eu.stratosphere.api.java.tuple.Tuple2;
import eu.stratosphere.compiler.plan.OptimizedPlan;
import eu.stratosphere.compiler.plan.SingleInputPlanNode;
import eu.stratosphere.compiler.plan.SinkPlanNode;
import eu.stratosphere.compiler.plan.SourcePlanNode;
import eu.stratosphere.pact.runtime.task.DriverStrategy;
import static org.junit.Assert.*;
@SuppressWarnings("serial")
public class ReduceCompilationTest extends CompilerTestBase implements java.io.Serializable {
@Test
public void testAllReduceNoCombiner() {
try {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setDegreeOfParallelism(8);
DataSet<Double> data = env.fromElements(0.2, 0.3, 0.4, 0.5).name("source");
data.reduce(new ReduceFunction<Double>() {
@Override
public Double reduce(Double value1, Double value2){
return value1 + value2;
}
}).name("reducer")
.print().name("sink");
Plan p = env.createProgramPlan();
OptimizedPlan op = compileNoStats(p);
OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(op);
// the all-reduce has no combiner, when the DOP of the input is one
SourcePlanNode sourceNode = resolver.getNode("source");
SingleInputPlanNode reduceNode = resolver.getNode("reducer");
SinkPlanNode sinkNode = resolver.getNode("sink");
// check wiring
assertEquals(sourceNode, reduceNode.getInput().getSource());
assertEquals(reduceNode, sinkNode.getInput().getSource());
// check DOP
assertEquals(1, sourceNode.getDegreeOfParallelism());
assertEquals(1, reduceNode.getDegreeOfParallelism());
assertEquals(1, sinkNode.getDegreeOfParallelism());
}
catch (Exception e) {
System.err.println(e.getMessage());
e.printStackTrace();
fail(e.getClass().getSimpleName() + " in test: " + e.getMessage());
}
}
@Test
public void testAllReduceWithCombiner() {
try {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setDegreeOfParallelism(8);
DataSet<Long> data = env.generateSequence(1, 8000000).name("source");
data.reduce(new ReduceFunction<Long>() {
@Override
public Long reduce(Long value1, Long value2){
return value1 + value2;
}
}).name("reducer")
.print().name("sink");
Plan p = env.createProgramPlan();
OptimizedPlan op = compileNoStats(p);
OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(op);
// get the original nodes
SourcePlanNode sourceNode = resolver.getNode("source");
SingleInputPlanNode reduceNode = resolver.getNode("reducer");
SinkPlanNode sinkNode = resolver.getNode("sink");
// get the combiner
SingleInputPlanNode combineNode = (SingleInputPlanNode) reduceNode.getInput().getSource();
// check wiring
assertEquals(sourceNode, combineNode.getInput().getSource());
assertEquals(reduceNode, sinkNode.getInput().getSource());
// check that both reduce and combiner have the same strategy
assertEquals(DriverStrategy.ALL_REDUCE, reduceNode.getDriverStrategy());
assertEquals(DriverStrategy.ALL_REDUCE, combineNode.getDriverStrategy());
// check DOP
assertEquals(8, sourceNode.getDegreeOfParallelism());
assertEquals(8, combineNode.getDegreeOfParallelism());
assertEquals(1, reduceNode.getDegreeOfParallelism());
assertEquals(1, sinkNode.getDegreeOfParallelism());
}
catch (Exception e) {
System.err.println(e.getMessage());
e.printStackTrace();
fail(e.getClass().getSimpleName() + " in test: " + e.getMessage());
}
}
@Test
public void testGroupedReduceWithFieldPositionKey() {
try {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setDegreeOfParallelism(8);
DataSet<Tuple2<String, Double>> data = env.readCsvFile("file:///will/never/be/read").types(String.class, Double.class)
.name("source").setParallelism(6);
data
.groupBy(1)
.reduce(new ReduceFunction<Tuple2<String,Double>>() {
@Override
public Tuple2<String, Double> reduce(Tuple2<String, Double> value1, Tuple2<String, Double> value2){
return null;
}
}).name("reducer")
.print().name("sink");
Plan p = env.createProgramPlan();
OptimizedPlan op = compileNoStats(p);
OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(op);
// get the original nodes
SourcePlanNode sourceNode = resolver.getNode("source");
SingleInputPlanNode reduceNode = resolver.getNode("reducer");
SinkPlanNode sinkNode = resolver.getNode("sink");
// get the combiner
SingleInputPlanNode combineNode = (SingleInputPlanNode) reduceNode.getInput().getSource();
// check wiring
assertEquals(sourceNode, combineNode.getInput().getSource());
assertEquals(reduceNode, sinkNode.getInput().getSource());
// check that both reduce and combiner have the same strategy
assertEquals(DriverStrategy.SORTED_REDUCE, reduceNode.getDriverStrategy());
assertEquals(DriverStrategy.SORTED_PARTIAL_REDUCE, combineNode.getDriverStrategy());
// check the keys
assertEquals(new FieldList(1), reduceNode.getKeys());
assertEquals(new FieldList(1), combineNode.getKeys());
assertEquals(new FieldList(1), reduceNode.getInput().getLocalStrategyKeys());
// check DOP
assertEquals(6, sourceNode.getDegreeOfParallelism());
assertEquals(6, combineNode.getDegreeOfParallelism());
assertEquals(8, reduceNode.getDegreeOfParallelism());
assertEquals(8, sinkNode.getDegreeOfParallelism());
}
catch (Exception e) {
System.err.println(e.getMessage());
e.printStackTrace();
fail(e.getClass().getSimpleName() + " in test: " + e.getMessage());
}
}
@Test
public void testGroupedReduceWithSelectorFunctionKey() {
try {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setDegreeOfParallelism(8);
DataSet<Tuple2<String, Double>> data = env.readCsvFile("file:///will/never/be/read").types(String.class, Double.class)
.name("source").setParallelism(6);
data
.groupBy(new KeySelector<Tuple2<String,Double>, String>() {
public String getKey(Tuple2<String, Double> value) { return value.f0; }
})
.reduce(new ReduceFunction<Tuple2<String,Double>>() {
@Override
public Tuple2<String, Double> reduce(Tuple2<String, Double> value1, Tuple2<String, Double> value2){
return null;
}
}).name("reducer")
.print().name("sink");
Plan p = env.createProgramPlan();
OptimizedPlan op = compileNoStats(p);
OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(op);
// get the original nodes
SourcePlanNode sourceNode = resolver.getNode("source");
SingleInputPlanNode reduceNode = resolver.getNode("reducer");
SinkPlanNode sinkNode = resolver.getNode("sink");
// get the combiner
SingleInputPlanNode combineNode = (SingleInputPlanNode) reduceNode.getInput().getSource();
// get the key extractors and projectors
SingleInputPlanNode keyExtractor = (SingleInputPlanNode) combineNode.getInput().getSource();
SingleInputPlanNode keyProjector = (SingleInputPlanNode) sinkNode.getInput().getSource();
// check wiring
assertEquals(sourceNode, keyExtractor.getInput().getSource());
assertEquals(keyProjector, sinkNode.getInput().getSource());
// check that both reduce and combiner have the same strategy
assertEquals(DriverStrategy.SORTED_REDUCE, reduceNode.getDriverStrategy());
assertEquals(DriverStrategy.SORTED_PARTIAL_REDUCE, combineNode.getDriverStrategy());
// check the keys
assertEquals(new FieldList(0), reduceNode.getKeys());
assertEquals(new FieldList(0), combineNode.getKeys());
assertEquals(new FieldList(0), reduceNode.getInput().getLocalStrategyKeys());
// check DOP
assertEquals(6, sourceNode.getDegreeOfParallelism());
assertEquals(6, keyExtractor.getDegreeOfParallelism());
assertEquals(6, combineNode.getDegreeOfParallelism());
assertEquals(8, reduceNode.getDegreeOfParallelism());
assertEquals(8, keyProjector.getDegreeOfParallelism());
assertEquals(8, sinkNode.getDegreeOfParallelism());
}
catch (Exception e) {
System.err.println(e.getMessage());
e.printStackTrace();
fail(e.getClass().getSimpleName() + " in test: " + e.getMessage());
}
}
}