/***********************************************************************************************************************
* Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
**********************************************************************************************************************/
package eu.stratosphere.test.accumulators;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.io.Serializable;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import eu.stratosphere.test.util.RecordAPITestBase;
import org.junit.Assert;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import eu.stratosphere.api.common.JobExecutionResult;
import eu.stratosphere.api.common.Plan;
import eu.stratosphere.api.common.accumulators.Accumulator;
import eu.stratosphere.api.common.accumulators.AccumulatorHelper;
import eu.stratosphere.api.common.accumulators.DoubleCounter;
import eu.stratosphere.api.common.accumulators.Histogram;
import eu.stratosphere.api.common.accumulators.IntCounter;
import eu.stratosphere.api.java.record.operators.FileDataSink;
import eu.stratosphere.api.java.record.operators.FileDataSource;
import eu.stratosphere.api.java.record.functions.MapFunction;
import eu.stratosphere.api.java.record.functions.ReduceFunction;
import eu.stratosphere.api.java.record.functions.FunctionAnnotation.ConstantFields;
import eu.stratosphere.api.java.record.io.CsvOutputFormat;
import eu.stratosphere.api.java.record.io.TextInputFormat;
import eu.stratosphere.api.java.record.operators.MapOperator;
import eu.stratosphere.api.java.record.operators.ReduceOperator;
import eu.stratosphere.api.java.record.operators.ReduceOperator.Combinable;
import eu.stratosphere.configuration.Configuration;
import eu.stratosphere.core.io.IOReadableWritable;
import eu.stratosphere.core.io.StringRecord;
import eu.stratosphere.nephele.util.SerializableHashSet;
import eu.stratosphere.types.IntValue;
import eu.stratosphere.types.Record;
import eu.stratosphere.types.StringValue;
import eu.stratosphere.util.SimpleStringUtils;
import eu.stratosphere.util.Collector;
/**
* Test for the basic functionality of accumulators. We cannot test all different
* kinds of plans here (iterative, etc.).
*
* TODO Test conflict when different UDFs write to accumulator with same name
* but with different type. The conflict will occur in JobManager while merging.
*/
@RunWith(Parameterized.class)
public class AccumulatorITCase extends RecordAPITestBase {
private static final String INPUT = "one\n" + "two two\n" + "three three three\n";
private static final String EXPECTED = "one 1\ntwo 2\nthree 3\n";
private static final int NUM_SUBTASKS = 2;
protected String dataPath;
protected String resultPath;
public AccumulatorITCase(Configuration config) {
super(config);
}
@Override
protected void preSubmit() throws Exception {
dataPath = createTempFile("datapoints.txt", INPUT);
resultPath = getTempFilePath("result");
}
@Override
protected void postSubmit() throws Exception {
compareResultsByLinesInMemory(EXPECTED, resultPath);
// Test accumulator results
System.out.println("Accumulator results:");
JobExecutionResult res = getJobExecutionResult();
System.out.println(AccumulatorHelper.getResultsFormated(res.getAllAccumulatorResults()));
Assert.assertEquals(new Integer(3), (Integer) res.getAccumulatorResult("num-lines"));
Assert.assertEquals(new Double(NUM_SUBTASKS), (Double)res.getAccumulatorResult("open-close-counter"));
// Test histogram (words per line distribution)
Map<Integer, Integer> dist = Maps.newHashMap();
dist.put(1, 1); dist.put(2, 2); dist.put(3, 3);
Assert.assertEquals(dist, res.getAccumulatorResult("words-per-line"));
// Test distinct words (custom accumulator)
Set<StringRecord> distinctWords = Sets.newHashSet();
distinctWords.add(new StringRecord("one"));
distinctWords.add(new StringRecord("two"));
distinctWords.add(new StringRecord("three"));
Assert.assertEquals(distinctWords, res.getAccumulatorResult("distinct-words"));
}
@Override
protected Plan getTestJob() {
Plan plan = getTestPlanPlan(config.getInteger("IterationAllReducer#NoSubtasks", 1), dataPath, resultPath);
return plan;
}
@Parameters
public static Collection<Object[]> getConfigurations() {
Configuration config1 = new Configuration();
config1.setInteger("IterationAllReducer#NoSubtasks", NUM_SUBTASKS);
return toParameterList(config1);
}
static Plan getTestPlanPlan(int numSubTasks, String input, String output) {
FileDataSource source = new FileDataSource(new TextInputFormat(), input, "Input Lines");
source.setParameter(TextInputFormat.CHARSET_NAME, "ASCII");
MapOperator mapper = MapOperator.builder(new TokenizeLine())
.input(source)
.name("Tokenize Lines")
.build();
ReduceOperator reducer = ReduceOperator.builder(CountWords.class, StringValue.class, 0)
.input(mapper)
.name("Count Words")
.build();
@SuppressWarnings("unchecked")
FileDataSink out = new FileDataSink(new CsvOutputFormat("\n"," ", StringValue.class, IntValue.class), output, reducer, "Word Counts");
Plan plan = new Plan(out, "WordCount Example");
plan.setDefaultParallelism(numSubTasks);
return plan;
}
public static class TokenizeLine extends MapFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final Record outputRecord = new Record();
private StringValue word;
private final IntValue one = new IntValue(1);
private final SimpleStringUtils.WhitespaceTokenizer tokenizer = new SimpleStringUtils.WhitespaceTokenizer();
// Needs to be instantiated later since the runtime context is not yet
// initialized at this place
IntCounter cntNumLines = null;
Histogram wordsPerLineDistribution = null;
// This counter will be added without convenience functions
DoubleCounter openCloseCounter = new DoubleCounter();
private SetAccumulator<StringRecord> distinctWords = null;
@Override
public void open(Configuration parameters) throws Exception {
// Add counters using convenience functions
this.cntNumLines = getRuntimeContext().getIntCounter("num-lines");
this.wordsPerLineDistribution = getRuntimeContext().getHistogram("words-per-line");
// Add built-in accumulator without convenience function
getRuntimeContext().addAccumulator("open-close-counter", this.openCloseCounter);
// Add custom counter. Didn't find a way to do this with
// getAccumulator()
this.distinctWords = new SetAccumulator<StringRecord>();
this.getRuntimeContext().addAccumulator("distinct-words", distinctWords);
// Create counter and test increment
IntCounter simpleCounter = getRuntimeContext().getIntCounter("simple-counter");
simpleCounter.add(1);
Assert.assertEquals(simpleCounter.getLocalValue().intValue(), 1);
// Test if we get the same counter
IntCounter simpleCounter2 = getRuntimeContext().getIntCounter("simple-counter");
Assert.assertEquals(simpleCounter.getLocalValue(), simpleCounter2.getLocalValue());
// Should fail if we request it with different type
try {
@SuppressWarnings("unused")
DoubleCounter simpleCounter3 = getRuntimeContext().getDoubleCounter("simple-counter");
// DoubleSumAggregator longAggregator3 = (DoubleSumAggregator)
// getRuntimeContext().getAggregator("custom",
// DoubleSumAggregator.class);
Assert.fail("Should not be able to obtain previously created counter with different type");
} catch (UnsupportedOperationException ex) {
}
// Test counter used in open() and closed()
this.openCloseCounter.add(0.5);
}
@Override
public void map(Record record, Collector<Record> collector) {
this.cntNumLines.add(1);
StringValue line = record.getField(0, StringValue.class);
SimpleStringUtils.replaceNonWordChars(line, ' ');
SimpleStringUtils.toLowerCase(line);
this.tokenizer.setStringToTokenize(line);
int wordsPerLine = 0;
this.word = new StringValue();
while (tokenizer.next(this.word))
{
// Use custom counter
distinctWords.add(new StringRecord(this.word.getValue()));
this.outputRecord.setField(0, this.word);
this.outputRecord.setField(1, this.one);
collector.collect(this.outputRecord);
++ wordsPerLine;
}
wordsPerLineDistribution.add(wordsPerLine);
}
@Override
public void close() throws Exception {
// Test counter used in open and close only
this.openCloseCounter.add(0.5);
Assert.assertEquals(1, this.openCloseCounter.getLocalValue().intValue());
}
}
@Combinable
@ConstantFields(0)
public static class CountWords extends ReduceFunction implements Serializable {
private static final long serialVersionUID = 1L;
private final IntValue cnt = new IntValue();
private IntCounter reduceCalls = null;
private IntCounter combineCalls = null;
@Override
public void open(Configuration parameters) throws Exception {
this.reduceCalls = getRuntimeContext().getIntCounter("reduce-calls");
this.combineCalls = getRuntimeContext().getIntCounter("combine-calls");
}
@Override
public void reduce(Iterator<Record> records, Collector<Record> out) throws Exception {
reduceCalls.add(1);
reduceInternal(records, out);
}
@Override
public void combine(Iterator<Record> records, Collector<Record> out) throws Exception {
combineCalls.add(1);
reduceInternal(records, out);
}
private void reduceInternal(Iterator<Record> records, Collector<Record> out) {
Record element = null;
int sum = 0;
while (records.hasNext()) {
element = records.next();
IntValue i = element.getField(1, IntValue.class);
sum += i.getValue();
}
this.cnt.setValue(sum);
element.setField(1, this.cnt);
out.collect(element);
}
}
/**
* Custom accumulator
*/
public static class SetAccumulator<T extends IOReadableWritable> implements Accumulator<T, Set<T>> {
private static final long serialVersionUID = 1L;
private SerializableHashSet<T> set = new SerializableHashSet<T>();
@Override
public void add(T value) {
this.set.add(value);
}
@Override
public Set<T> getLocalValue() {
return this.set;
}
@Override
public void resetLocal() {
this.set.clear();
}
@Override
public void merge(Accumulator<T, Set<T>> other) {
// build union
this.set.addAll(((SetAccumulator<T>) other).getLocalValue());
}
@Override
public void write(DataOutput out) throws IOException {
this.set.write(out);
}
@Override
public void read(DataInput in) throws IOException {
this.set.read(in);
}
}
}