Package brickhouse.udf.collect

Source Code of brickhouse.udf.collect.CollectMaxUDAF$MapCollectMaxUDAFEvaluator

package brickhouse.udf.collect;
/**
* Copyright 2012 Klout, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
**/
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantIntObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.log4j.Logger;

import java.util.*;

import org.apache.hadoop.hive.ql.exec.Description;

@Description(name="collect_max",
    value = "_FUNC_(x, val, n) - Returns an map of the max N numeric values in the aggregation group "
)
public class CollectMaxUDAF extends AbstractGenericUDAFResolver {
  public static final Logger LOG = Logger.getLogger(CollectMaxUDAF.class);
  public static int DEFAULT_MAX_VALUES = 20;

  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
      throws SemanticException {
    return new MapCollectMaxUDAFEvaluator();
  }

  public static class MapCollectMaxUDAFEvaluator extends GenericUDAFEvaluator {
    // For PARTIAL1 and COMPLETE: ObjectInspectors for original data
    private PrimitiveObjectInspector  inputKeyOI;
    private PrimitiveObjectInspector inputValOI;
    // For PARTIAL2 and FINAL: ObjectInspectors for partial aggregations (list
    // of objs)
    private StandardMapObjectInspector internalMergeOI;
    private boolean descending = true;
    private int numValues = DEFAULT_MAX_VALUES;


    public MapCollectMaxUDAFEvaluator() {
      this(true);
    }

    public MapCollectMaxUDAFEvaluator(boolean desc) {
       this.descending = desc; 
    }

    public class SortedKeyValue implements Comparable {
      private Object keyObj;
      private Object valObj;


      public SortedKeyValue(Object keyObj, Object valObj ) {
      this.keyObj = keyObj;
        this.valObj = valObj;
      }

      @Override
      public boolean equals(Object other) {
        if(!(other instanceof SortedKeyValue )) {
          return false;
        }
        SortedKeyValue otherKV= (SortedKeyValue)other;
        if( getKey().equals(otherKV.getKey())) {
          return true;
        } else {
          return false;
        }
      }

      public Object getKey() { return (keyObj == null) ? null : inputKeyOI.getPrimitiveJavaObject(keyObj); }
      public Double getValue() { return (valObj == null ? null : ((Number)inputValOI.getPrimitiveJavaObject(valObj)).doubleValue()); }
     
      public String toString() { return getKey() + ":=" + getValue(); }

      @Override
      public int compareTo(Object arg1) {
        SortedKeyValue otherKV = (SortedKeyValue) arg1;
        int cmp = compareKV( otherKV);
        return ( descending ? cmp : -1*cmp);
      }
     
      public int compareKV(SortedKeyValue otherKV) {
       
        double thisNumber = getValue();
        double otherNumber = otherKV.getValue();
        int sigNum  = (int) Math.signum(otherNumber - thisNumber);
        if(sigNum !=0 )  {
            return sigNum;
        } else {
          return ((Comparable)getKey()).compareTo(otherKV.getKey() );
        }
      }
    }

    class MapAggBuffer implements AggregationBuffer {
      private TreeSet<SortedKeyValue> sortedValues = new TreeSet<SortedKeyValue>();


      public void addValue(Object keyObj, Object valObj) {
        if( sortedValues.size() < numValues) {
          SortedKeyValue newValue =  new SortedKeyValue(inputKeyOI.copyObject(keyObj), inputValOI.copyObject(valObj) );
          sortedValues.add( newValue);
        } else {
          SortedKeyValue minValue = sortedValues.last();
          SortedKeyValue biggerValue =  new SortedKeyValue(inputKeyOI.copyObject(keyObj), inputValOI.copyObject(valObj));
          int cmp =  biggerValue.compareTo(minValue);
          if( cmp < 0 ) {
            sortedValues.remove( minValue);
            sortedValues.add(biggerValue);
          }
        }
      }


      public Map getValueMap() {
        LinkedHashMap<Object, Object> reverseOrderMap = new LinkedHashMap<Object,Object>();
        for( SortedKeyValue kv : sortedValues ) {
          reverseOrderMap.put( kv.keyObj, kv.valObj);
        }
        return reverseOrderMap;
      }

      public void reset() {
        sortedValues.clear();
      }
    }

    public ObjectInspector init(Mode m, ObjectInspector[] parameters)
        throws HiveException {
      super.init(m, parameters);
      LOG.error(" CollectMaxUDAF.init() - Mode= " + m.name() );
      for(int i=0; i<parameters.length; ++i) {
        LOG.error(" ObjectInspector[ "+ i + " ] = " + parameters[0]);
      }
      if(parameters.length > 2) {
        if( parameters[2] instanceof WritableConstantIntObjectInspector ) {
          WritableConstantIntObjectInspector nvOI = (WritableConstantIntObjectInspector) parameters[2];
          numValues = nvOI.getWritableConstantValue().get();
          LOG.info(" Setting number of values to " + numValues);
        } else {
          throw new HiveException("Number of values must be a constant int.");
        }
      }

      // init output object inspectors
      // The output of a partial aggregation is a map
      if (m == Mode.PARTIAL1) {
        inputKeyOI = (PrimitiveObjectInspector) parameters[0];
        inputValOI = (PrimitiveObjectInspector) parameters[1];
        return ObjectInspectorFactory.getStandardMapObjectInspector(
            inputKeyOI,
            inputValOI );
      } else {
        if (!(parameters[0] instanceof StandardMapObjectInspector)) {
          inputKeyOI = (PrimitiveObjectInspector) parameters[0];
          inputValOI = (PrimitiveObjectInspector) parameters[1];
        } else {
          internalMergeOI = (StandardMapObjectInspector) parameters[0];
          inputKeyOI = (PrimitiveObjectInspector) internalMergeOI.getMapKeyObjectInspector();
          inputValOI = (PrimitiveObjectInspector) internalMergeOI.getMapValueObjectInspector();
        }
      }
      return ObjectInspectorFactory.getStandardMapObjectInspector(
          inputKeyOI,
          inputValOI );
    }

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      MapAggBuffer buff= new MapAggBuffer();
      reset(buff);
      return buff;
    }

    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters)
        throws HiveException {
      Object k = parameters[0];
      Object v = parameters[1];
      if (k == null || v == null) {
        throw new HiveException("Key or value is null.  k = " + k + " , v = " + v);
      }

      if (k != null) {
        MapAggBuffer myagg = (MapAggBuffer) agg;

        putIntoSet(k, v, myagg);
      }
    }

    @Override
    public void merge(AggregationBuffer agg, Object partial)
        throws HiveException {
      MapAggBuffer myagg = (MapAggBuffer) agg;
      Map<Object,Object> partialResult = (Map<Object,Object>internalMergeOI.getMap(partial);
      for(Object i : partialResult.keySet()) {
        putIntoSet(i, partialResult.get(i), myagg);
      }
    }

    @Override
    public void reset(AggregationBuffer buff) throws HiveException {
      MapAggBuffer arrayBuff = (MapAggBuffer) buff;
      arrayBuff.reset();
    }

    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      MapAggBuffer myagg = (MapAggBuffer) agg;
      return myagg.getValueMap();

    }

    private void putIntoSet(Object key, Object val, MapAggBuffer myagg) {
      myagg.addValue(key, val);
    }

    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {

      MapAggBuffer myagg = (MapAggBuffer) agg;
      Map<Object, Object> vals =  myagg.getValueMap();
      return vals;
    }
  }


}
TOP

Related Classes of brickhouse.udf.collect.CollectMaxUDAF$MapCollectMaxUDAFEvaluator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.