package brickhouse.udf.sketch;
/**
* Copyright 2012 Klout, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
**/
import java.util.List;
import java.util.Map;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.IntWritable;
import org.apache.log4j.Logger;
import brickhouse.analytics.uniques.SketchSet;
/**
* UDAF to merge multiple sketch sets into a single sketch
* set representing the union of all the input sketches.
*
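* Example usage (table and column names here are illustrative):
* <pre>
*   SELECT campaign_id,
*          union_sketch(daily_sketch)
*   FROM daily_uniques
*   GROUP BY campaign_id;
* </pre>
* The resulting array of strings can then be passed to a
* reach-estimating UDF (e.g. Brickhouse's estimated_reach) to
* approximate the number of distinct values across the merged sketches.
*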
*/
@Description(name = "union_sketch",
value = "_FUNC_(x) - Merges multiple sketch sets into a single sketch set, for estimating the reach of very large sets"
)
public class UnionSketchSetUDAF extends AbstractGenericUDAFResolver {
private static final Logger LOG = Logger.getLogger(UnionSketchSetUDAF.class);
//// Default number of minimum hash values retained in the sketch.
public static final int DEFAULT_SKETCH_SET_SIZE = 5000;
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
throws SemanticException {
return new MergeSketchSetUDAFEvaluator();
}
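/**
* Evaluator for union_sketch. The intermediate (partial) result is a
* map of hash values to item strings; a reserved entry keyed by
* SketchSetUDAF.SKETCH_SIZE_STR carries the sketch set size, so that
* merge-side evaluators can size their buffers correctly.
*/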
public static class MergeSketchSetUDAFEvaluator extends GenericUDAFEvaluator {
// For PARTIAL1 and COMPLETE: ObjectInspectors for original data
private StringObjectInspector inputStrOI;
private ListObjectInspector inputListOI;
private MapObjectInspector partialMapOI;
private LongObjectInspector partialMapHashOI;
private StringObjectInspector partialMapStrOI;
private int sketchSetSize = -1;
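/**
* Chooses ObjectInspectors according to the aggregation mode:
* PARTIAL1/COMPLETE see the original array&lt;string&gt; argument
* (plus an optional constant sketch size), while PARTIAL2/FINAL
* see the partial map&lt;bigint,string&gt; produced by terminatePartial().
*/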
@Override
public ObjectInspector init(Mode m, ObjectInspector[] parameters)
throws HiveException {
super.init(m, parameters);
if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
//// iterate() gets called; a list of strings is passed in.
if (parameters[0].getCategory() == Category.LIST) {
this.inputListOI = (ListObjectInspector) parameters[0];
if (((PrimitiveObjectInspector) this.inputListOI.getListElementObjectInspector()).getPrimitiveCategory() == PrimitiveCategory.STRING) {
this.inputStrOI = (StringObjectInspector) this.inputListOI.getListElementObjectInspector();
} else {
throw new HiveException("union_sketch expecting a list of strings as arg");
}
} else {
throw new HiveException("union_sketch expecting a list of strings as arg");
}
if (parameters.length > 1) {
//// Get the sketch set size from the second parameter;
//// it must be a constant so all tasks agree on the size.
if (!(parameters[1] instanceof ConstantObjectInspector)) {
throw new HiveException("Sketch set size must be a constant");
}
ConstantObjectInspector sizeOI = (ConstantObjectInspector) parameters[1];
this.sketchSetSize = ((IntWritable) sizeOI.getWritableConstantValue()).get();
} else {
sketchSetSize = DEFAULT_SKETCH_SET_SIZE;
}
} else { /// Mode m == Mode.PARTIAL2 || m == Mode.FINAL
/// merge() gets called ... map is passed in ..
this.partialMapOI = (MapObjectInspector) parameters[0];
this.partialMapHashOI = (LongObjectInspector) partialMapOI.getMapKeyObjectInspector();
this.partialMapStrOI = (StringObjectInspector) partialMapOI.getMapValueObjectInspector();
}
/// The intermediate result is a map of hashes and strings,
/// The final result is an array of strings
if( m == Mode.FINAL || m == Mode.COMPLETE) {
/// for final result
return ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.javaStringObjectInspector );
} else { /// m == Mode.PARTIAL1 || m == Mode.PARTIAL2
return ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.javaLongObjectInspector,
PrimitiveObjectInspectorFactory.javaStringObjectInspector
);
}
}
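/**
* Note that in PARTIAL2/FINAL mode the sketch size is not known yet
* (it is still -1), so the buffer is re-initialized in merge() once
* the size is recovered from the partial map.
*/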
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
SketchSetBuffer buff = new SketchSetBuffer();
buff.init(sketchSetSize);
return buff;
}
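/**
* Called once per row on the map side (PARTIAL1/COMPLETE); adds every
* string in the input array to the sketch buffer.
*/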
@Override
public void iterate(AggregationBuffer agg, Object[] parameters)
throws HiveException {
Object listStrObj = parameters[0];
if (listStrObj != null) {
SketchSetBuffer myagg = (SketchSetBuffer) agg;
List<?> listStr = inputListOI.getList(listStrObj);
for (Object strObj : listStr) {
String str = inputStrOI.getPrimitiveJavaObject(strObj);
myagg.addItem(str);
}
}
}
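/**
* Merges a partial map of (hash, item) pairs into the buffer. The
* reserved SKETCH_SIZE_STR entry is consumed to size the buffer and
* is excluded from the merged hashes.
*/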
@Override
public void merge(AggregationBuffer agg, Object partial)
throws HiveException {
/// The partial is a map of hash values to item strings.
SketchSetBuffer myagg = (SketchSetBuffer) agg;
if (partial != null) {
Map<?, ?> partialResult = this.partialMapOI.getMap(partial);
if (partialResult != null) {
//// Recover the sketch size from the reserved SKETCH_SIZE_STR entry,
//// since merge-side evaluators don't see the original size argument.
if (myagg.getSize() == -1) {
for (Map.Entry<?, ?> entry : partialResult.entrySet()) {
long hash = this.partialMapHashOI.get(entry.getKey());
String item = partialMapStrOI.getPrimitiveJavaObject(entry.getValue());
if (item.equals(SketchSetUDAF.SKETCH_SIZE_STR)) {
this.sketchSetSize = (int) hash;
myagg.init(sketchSetSize);
break;
}
}
}
//// Merge all regular (hash, item) pairs into the buffer.
for (Map.Entry<?, ?> entry : partialResult.entrySet()) {
long hash = this.partialMapHashOI.get(entry.getKey());
String item = partialMapStrOI.getPrimitiveJavaObject(entry.getValue());
if (!item.equals(SketchSetUDAF.SKETCH_SIZE_STR)) {
myagg.addHash(hash, item);
}
}
}
}
}
@Override
public void reset(AggregationBuffer buff) throws HiveException {
SketchSetBuffer sketchBuff = (SketchSetBuffer) buff;
sketchBuff.reset();
}
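/**
* Returns the final result: the sketch items as an array of strings.
*/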
@Override
public Object terminate(AggregationBuffer agg) throws HiveException {
SketchSetBuffer myagg = (SketchSetBuffer) agg;
return myagg.getSketchItems();
}
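/**
* Returns the intermediate result: a map of hash values to items,
* including the reserved sketch-size entry.
*/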
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
SketchSetBuffer myagg = (SketchSetBuffer)agg;
return myagg.getPartialMap();
}
}
}