Package brickhouse.udf.collect

Source Code of brickhouse.udf.collect.UnionMaxUDAF

package brickhouse.udf.collect;
/**
* Copyright 2012 Klout, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
**/

import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.TreeSet;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.log4j.Logger;


/**
*  UDAF to merge a union of maps,
*    but only hold on the keys with the top 20 values
*/

@Description(name="union_max",
    value = "_FUNC_(x,  n) - Returns an map of the union of maps of max N elements in the aggregation group "
)
public class UnionMaxUDAF extends AbstractGenericUDAFResolver {
  private static final Logger LOG = Logger.getLogger(UnionMaxUDAF.class);
  public static int DEFAULT_MAX_VALUES = 20;


  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
      throws SemanticException {
    return new MapCollectMaxUDAFEvaluator();
  }


  public static class MapCollectMaxUDAFEvaluator extends GenericUDAFEvaluator {
    // For PARTIAL1 and COMPLETE: ObjectInspectors for original data
    private PrimitiveObjectInspector  inputKeyOI;
    private PrimitiveObjectInspector inputValOI; /// XXX Support nested values instead of just primitives as values
    private MapObjectInspector inputMapOI;
    private IntObjectInspector nvOI;
    // For PARTIAL2 and FINAL: ObjectInspectors for partial aggregations (list
    // of objs)
    private StandardMapObjectInspector moi;
    private StandardMapObjectInspector internalMergeOI;


    public static class SortedKeyValue implements Comparable {
      private String key;
      private Double value;


      public SortedKeyValue(String key, Double val) {
        this.key = key;
        this.value = val;
      }

      @Override
      public boolean equals(Object other) {
        if(!(other instanceof SortedKeyValue )) {
          return false;
        }
        SortedKeyValue otherKV= (SortedKeyValue)other;
        if( key.equals(otherKV.key)) {
          return true;
        } else {
          return false;
        }

      }

      public String getKey() { return key; }
      public Double getValue() { return value; }

      @Override
      public int compareTo(Object arg1) {
        SortedKeyValue kv0 = (SortedKeyValue) this;
        SortedKeyValue kv1 = (SortedKeyValue) arg1;

        if( kv0.value != kv1.value ) {
          if(kv0.value > kv1.value)  {
            return -1;
          } else  {
            if( kv0.value < kv1.value ) {
              return 1;
            }
          }
          return kv0.key.compareTo(kv1.key);
        } else {
          return kv0.key.compareTo(kv1.key);
        }
      }
    }

    static class MapAggBuffer implements AggregationBuffer {
      private TreeSet<SortedKeyValue> sortedValues = new TreeSet<SortedKeyValue>();
      private int numValues = DEFAULT_MAX_VALUES;
      public void setNumValues(int nv) { numValues = nv; }

      public void addValue(String key, Double value) {
        if( sortedValues.size() < numValues) {
          sortedValues.add( new SortedKeyValue( key, value ));
        } else {
          SortedKeyValue minValue = sortedValues.last();
          if( value > minValue.getValue() ) {
            sortedValues.remove( minValue);
            sortedValues.add( new SortedKeyValue( key, value));
          }
        }
      }

      public void fromMap( Map<Object,Object> fromMap) {
        for( Object kObj : fromMap.keySet() ) {
          Object val = fromMap.get(kObj);
          addValue( (String)kObj, (Double)val);
        }
      }

      public Map<String, Double> getValueMap() {
        LinkedHashMap<String, Double> reverseOrderMap = new LinkedHashMap<String,Double>();
        for( SortedKeyValue kv : sortedValues ) {
          reverseOrderMap.put( kv.key, kv.value);
        }
        return reverseOrderMap;
      }

      public void reset() {
        sortedValues.clear();
      }
    }

    public ObjectInspector init(Mode m, ObjectInspector[] parameters)
        throws HiveException {
      super.init(m, parameters);
      LOG.info(" UnionMaxUDAF.init() - Mode= " + m.name() );
      for(int i=0; i<parameters.length; ++i) {
        LOG.info(" ObjectInspector[ "+ i + " ] = " + parameters[0]);
      }
      if(parameters.length > 1) {
        nvOI = (IntObjectInspector) parameters[1];
      }

      // init output object inspectors
      // The output of a partial aggregation is a map
      if (m == Mode.PARTIAL1 ||  m == Mode.COMPLETE) {
        inputMapOI = (MapObjectInspector)parameters[ 0];
        inputKeyOI = (PrimitiveObjectInspector) inputMapOI.getMapKeyObjectInspector();
        inputValOI = (PrimitiveObjectInspector) inputMapOI.getMapValueObjectInspector();

        /**
         return ObjectInspectorFactory.getStandardMapObjectInspector(
         ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI),
         ObjectInspectorUtils.getStandardObjectInspector(inputValOI) );
         **/
      } else {
        if (!(parameters[0] instanceof StandardMapObjectInspector)) {
          LOG.info(" Not a standard map OjbectInspector " );
          inputKeyOI = (PrimitiveObjectInspectorObjectInspectorUtils
              .getStandardObjectInspector(parameters[0]);
          inputValOI = (PrimitiveObjectInspectorObjectInspectorUtils
              .getStandardObjectInspector(parameters[1]);
          /**
           return (StandardMapObjectInspector) ObjectInspectorFactory
           .getStandardMapObjectInspector(inputKeyOI, inputValOI);
           **/
        } else {
          internalMergeOI = (StandardMapObjectInspector) parameters[0];
          inputKeyOI = (PrimitiveObjectInspector) internalMergeOI.getMapKeyObjectInspector();
          inputValOI = (PrimitiveObjectInspector) internalMergeOI.getMapValueObjectInspector();
          /**
           moi =  (StandardMapObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(internalMergeOI);
           return moi;
           **/
        }
      }
      return ObjectInspectorFactory.getStandardMapObjectInspector(
          PrimitiveObjectInspectorFactory.javaStringObjectInspector,
          PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
    }

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      MapAggBuffer buff= new MapAggBuffer();
      reset(buff);
      return buff;
    }

    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters)
        throws HiveException {
      Map inMap = inputMapOI.getMap( parameters[0]);
      for( Object k:  inMap.keySet()  ) {
        Object v =  inMap.get( k);
        if (k == null || v == null) {
          throw new HiveException("Kay or value is null.  k = " + k + " , v = " + v);
        }

        if (k != null) {
          MapAggBuffer myagg = (MapAggBuffer) agg;

          if( parameters.length > 1 ) {
            Object numValsObj = parameters[1];
            int nv = nvOI.get( numValsObj);
            myagg.setNumValues( nv);
           
          }
          putIntoSet(k, v, myagg);
        }
      }
    }

    @Override
    public void merge(AggregationBuffer agg, Object partial)
        throws HiveException {
      MapAggBuffer myagg = (MapAggBuffer) agg;
      Map<Object,Object> partialResult = (Map<Object,Object>internalMergeOI.getMap(partial);
      for(Object i : partialResult.keySet()) {
        putIntoSet(i, partialResult.get(i), myagg);
      }
    }

    @Override
    public void reset(AggregationBuffer buff) throws HiveException {
      MapAggBuffer arrayBuff = (MapAggBuffer) buff;
      arrayBuff.reset();
    }

    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      MapAggBuffer myagg = (MapAggBuffer) agg;
      return myagg.getValueMap();

    }

    private void putIntoSet(Object key, Object val, MapAggBuffer myagg) {
      StringObjectInspector strInsp = (StringObjectInspector) this.inputKeyOI;
      DoubleObjectInspector dblInsp = (DoubleObjectInspector) this.inputValOI;

      String keyCopy = strInsp.getPrimitiveJavaObject(key);
      Double valCopy = dblInsp.get(val);

      myagg.addValue(keyCopy, valCopy);
    }

    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {

      MapAggBuffer myagg = (MapAggBuffer) agg;
      Map<String, Double> vals =  myagg.getValueMap();
      return vals;
    }
  }


}
TOP

Related Classes of brickhouse.udf.collect.UnionMaxUDAF

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.