Source Code of org.apache.hadoop.hive.ql.exec.GroupByOperator

/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.hadoop.hive.ql.exec;

import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.plan.aggregationDesc;
import org.apache.hadoop.hive.ql.plan.exprNodeDesc;
import org.apache.hadoop.hive.ql.plan.groupByDesc;
import org.apache.hadoop.hive.serde2.objectinspector.InspectableObject;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.conf.Configuration;

/**
* GroupBy operator implementation.
*/
public class GroupByOperator extends Operator<groupByDesc> implements Serializable {

  private static final long serialVersionUID = 1L;
  transient protected ExprNodeEvaluator[] keyFields;
  transient protected ExprNodeEvaluator[][] aggregationParameterFields;
  // In the future, we may allow both count(DISTINCT a) and sum(DISTINCT a) in the same SQL clause,
  // so aggregationIsDistinct is a boolean array instead of a single number.
  transient protected boolean[] aggregationIsDistinct;

  transient Class<? extends UDAF>[] aggregationClasses;
  transient protected Method[] aggregationsAggregateMethods;
  transient protected Method[] aggregationsEvaluateMethods;

  transient protected ArrayList<ObjectInspector> objectInspectors;
  transient protected ObjectInspector outputObjectInspector;

  // Used by sort-based GroupBy: Mode = COMPLETE, PARTIAL1, PARTIAL2
  transient protected ArrayList<Object> currentKeys;
  transient protected UDAF[] aggregations;
  transient protected Object[][] aggregationsParametersLastInvoke;

  // Used by hash-based GroupBy: Mode = HASH
  transient protected HashMap<ArrayList<Object>, UDAF[]> hashAggregations;
 
  transient boolean firstRow;
 
  @SuppressWarnings("unchecked")
  public void initialize(Configuration hconf) throws HiveException {
    super.initialize(hconf);
    try {
      // init keyFields
      keyFields = new ExprNodeEvaluator[conf.getKeys().size()];
      for (int i = 0; i < keyFields.length; i++) {
        keyFields[i] = ExprNodeEvaluatorFactory.get(conf.getKeys().get(i));
      }
   
      // init aggregationParameterFields
      aggregationParameterFields = new ExprNodeEvaluator[conf.getAggregators().size()][];
      for (int i = 0; i < aggregationParameterFields.length; i++) {
        ArrayList<exprNodeDesc> parameters = conf.getAggregators().get(i).getParameters();
        aggregationParameterFields[i] = new ExprNodeEvaluator[parameters.size()];
        for (int j = 0; j < parameters.size(); j++) {
          aggregationParameterFields[i][j] = ExprNodeEvaluatorFactory.get(parameters.get(j));
        }
      }
      // init aggregationIsDistinct
      aggregationIsDistinct = new boolean[conf.getAggregators().size()];
      for(int i=0; i<aggregationIsDistinct.length; i++) {
        aggregationIsDistinct[i] = conf.getAggregators().get(i).getDistinct();
      }

      // init aggregationClasses 
      aggregationClasses = (Class<? extends UDAF>[]) new Class[conf.getAggregators().size()];
      for (int i = 0; i < conf.getAggregators().size(); i++) {
        aggregationDesc agg = conf.getAggregators().get(i);
        aggregationClasses[i] = agg.getAggregationClass();
      }

      // init aggregations, aggregationsAggregateMethods,
      // aggregationsEvaluateMethods
      aggregationsAggregateMethods = new Method[aggregationClasses.length];
      aggregationsEvaluateMethods = new Method[aggregationClasses.length];
      String aggregateMethodName = (conf.getMode() == groupByDesc.Mode.PARTIAL2
         ? "aggregatePartial" : "aggregate");
      String evaluateMethodName = ((conf.getMode() == groupByDesc.Mode.PARTIAL1 || conf.getMode() == groupByDesc.Mode.HASH)
         ? "evaluatePartial" : "evaluate");
      for(int i=0; i<aggregationClasses.length; i++) {
        // aggregationsAggregateMethods
        for( Method m : aggregationClasses[i].getMethods() ){
          if( m.getName().equals( aggregateMethodName )
              && m.getParameterTypes().length == aggregationParameterFields[i].length) {             
            aggregationsAggregateMethods[i] = m;
            break;
          }
        }
        if (null == aggregationsAggregateMethods[i]) {
          throw new RuntimeException("Cannot find " + aggregateMethodName + " method of UDAF class "
                                   + aggregationClasses[i].getName() + " that accepts "
                                   + aggregationParameterFields[i].length + " parameters!");
        }
        // aggregationsEvaluateMethods: getMethod throws NoSuchMethodException
        // rather than returning null, so wrap it to keep the descriptive error.
        try {
          aggregationsEvaluateMethods[i] = aggregationClasses[i].getMethod(evaluateMethodName);
        } catch (NoSuchMethodException e) {
          throw new RuntimeException("Cannot find " + evaluateMethodName + " method of UDAF class "
                                   + aggregationClasses[i].getName() + "!", e);
        }
      }

      if (conf.getMode() != groupByDesc.Mode.HASH) {
        aggregationsParametersLastInvoke = new Object[conf.getAggregators().size()][];
        aggregations = newAggregations();
      } else {
        hashAggregations = new HashMap<ArrayList<Object>, UDAF[]>();
      }
      // init objectInspectors
      int totalFields = keyFields.length + aggregationClasses.length;
      objectInspectors = new ArrayList<ObjectInspector>(totalFields);
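      // Key ObjectInspectors are unknown until the first row is evaluated,
      // so reserve their slots here and fill them in process().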
      for(int i=0; i<keyFields.length; i++) {
        objectInspectors.add(null);
      }
      for(int i=0; i<aggregationClasses.length; i++) {
        objectInspectors.add(ObjectInspectorFactory.getStandardPrimitiveObjectInspector(
            aggregationsEvaluateMethods[i].getReturnType()));
      }
     
      firstRow = true;
    } catch (Exception e) {
      e.printStackTrace();
      throw new HiveException(e);
    }
  }

  protected UDAF[] newAggregations() throws Exception {     
    UDAF[] aggs = new UDAF[aggregationClasses.length];
    for(int i=0; i<aggregationClasses.length; i++) {
      aggs[i] = aggregationClasses[i].newInstance();
      aggs[i].init();
    }
    return aggs;
  }

  InspectableObject tempInspectableObject = new InspectableObject();
 
  /**
   * Evaluate the aggregation parameters on one row and feed them to each
   * UDAF through the reflected aggregate method. For DISTINCT aggregations,
   * lastInvoke caches the parameters of the previous invocation: input rows
   * arrive sorted, so a row whose parameters equal the previous ones can be
   * skipped without changing the DISTINCT result.
   */
  protected void updateAggregations(UDAF[] aggs, Object row, ObjectInspector rowInspector, Object[][] lastInvoke) throws Exception {
    for(int ai=0; ai<aggs.length; ai++) {
      // Calculate the parameters
      Object[] o = new Object[aggregationParameterFields[ai].length];
      for(int pi=0; pi<aggregationParameterFields[ai].length; pi++) {
        aggregationParameterFields[ai][pi].evaluate(row, rowInspector, tempInspectableObject);
        o[pi] = tempInspectableObject.o;
      }
      // Update the aggregations.
      if (aggregationIsDistinct[ai] && lastInvoke != null) {
        // Are the parameters different from the last invocation?
        boolean differentParameters = (lastInvoke[ai] == null);
        if (!differentParameters) {
          for(int pi=0; pi<o.length; pi++) {
            if (!o[pi].equals(lastInvoke[ai][pi])) {
              differentParameters = true;
              break;
            }
          }
        } 
        if (differentParameters) {
          aggregationsAggregateMethods[ai].invoke(aggs[ai], o);
          lastInvoke[ai] = o;
        }
      } else {
        aggregationsAggregateMethods[ai].invoke(aggs[ai], o);
      }
    }
  }
 
  /**
   * Process one input row: evaluate the grouping keys, then either update
   * the single current aggregation set (sort-based modes, forwarding the
   * previous group when the keys change), or look up or create the
   * aggregation set for these keys in the hash table (HASH mode).
   */
  public void process(Object row, ObjectInspector rowInspector) throws HiveException {
   
    try {
      // Compute the keys
      ArrayList<Object> newKeys = new ArrayList<Object>(keyFields.length);
      for (int i = 0; i < keyFields.length; i++) {
        keyFields[i].evaluate(row, rowInspector, tempInspectableObject);
        newKeys.add(tempInspectableObject.o);
        if (firstRow) {
          objectInspectors.set(i, tempInspectableObject.oi);
        }
      }
      if (firstRow) {
        firstRow = false;
        ArrayList<String> fieldNames = new ArrayList<String>(objectInspectors.size());
        for(int i=0; i<objectInspectors.size(); i++) {
          fieldNames.add(String.valueOf(i));
        }
        outputObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector(
          fieldNames, objectInspectors);
      }
      // Prepare aggs for updating
      UDAF[] aggs = null;
      Object[][] lastInvoke = null;
      if (aggregations != null) {
        // sort-based aggregation
        // Need to forward?
        boolean keysAreEqual = newKeys.equals(currentKeys);
        if (currentKeys != null && !keysAreEqual) {
          forward(currentKeys, aggregations);
        }
        // Need to update the keys?
        if (currentKeys == null || !keysAreEqual) {
          currentKeys = newKeys;
          // init aggregations
          for(UDAF aggregation: aggregations) {
            aggregation.init();
          }
          // clear parameters in last-invoke
          for(int i=0; i<aggregationsParametersLastInvoke.length; i++) {
            aggregationsParametersLastInvoke[i] = null;
          }
        }
        aggs = aggregations;
        lastInvoke = aggregationsParametersLastInvoke;
      } else {
        // hash-based aggregations
        aggs = hashAggregations.get(newKeys);
        if (aggs == null) {
          aggs = newAggregations();
          hashAggregations.put(newKeys, aggs);
          // TODO: hash-based aggregation does not support DISTINCT yet
          lastInvoke = null;
        }
      }

      // Update the aggs
      updateAggregations(aggs, row, rowInspector, lastInvoke);

    } catch (Exception e) {
      e.printStackTrace();
      throw new HiveException(e);
    }
  }
 
  /**
   * Forward a record of keys and aggregation results.
   *
   * @param keys
   *          The keys in the record
   * @param aggs
   *          The aggregations to evaluate for the record
   * @throws Exception
   */
  protected void forward(ArrayList<Object> keys, UDAF[] aggs) throws Exception {
    int totalFields = keys.size() + aggs.length;
    List<Object> a = new ArrayList<Object>(totalFields);
    for(int i=0; i<keys.size(); i++) {
      a.add(keys.get(i));
    }
    for(int i=0; i<aggs.length; i++) {
      a.add(aggregationsEvaluateMethods[i].invoke(aggs[i]));
    }
    forward(a, outputObjectInspector);
  }
 
  /**
   * Forward all remaining aggregations to the children operators when the
   * operator is closed.
   */
  public void close(boolean abort) throws HiveException {
    if (!abort) {
      try {
        if (aggregations != null) {
          // sort-based aggregations
          if (currentKeys != null) {
            forward(currentKeys, aggregations);
          }
        } else if (hashAggregations != null) {
          // hash-based aggregations
          for (Map.Entry<ArrayList<Object>, UDAF[]> entry : hashAggregations.entrySet()) {
            forward(entry.getKey(), entry.getValue());
          }
        } else {
          // The GroupByOperator is not initialized, which means there is no data
          // (since we initialize the operators when we see the first record).
          // Just do nothing here.
        }
      } catch (Exception e) {
        e.printStackTrace();
        throw new HiveException(e);
      }
    }
    super.close(abort);
  }

}
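
For reference, the reflection lookup in initialize() expects UDAF implementations shaped roughly like the sketch below. This is a minimal, hypothetical example (the class name UDAFExampleCount is illustrative, not part of Hive); it matches the method names and arities that GroupByOperator resolves above.

// Hypothetical count-style UDAF; illustrative only.
public class UDAFExampleCount extends UDAF {
  private long count;

  // Called from newAggregations() and whenever a new group starts.
  public void init() {
    count = 0;
  }

  // Resolved in COMPLETE, PARTIAL1 and HASH modes; the arity must match
  // the number of aggregation parameters in the query.
  public void aggregate(Object value) {
    if (value != null) {
      count++;
    }
  }

  // Resolved in PARTIAL2 mode to merge partial results.
  public void aggregatePartial(Long partial) {
    if (partial != null) {
      count += partial.longValue();
    }
  }

  // Resolved in PARTIAL1 and HASH modes; the return type drives the
  // ObjectInspector chosen in initialize().
  public Long evaluatePartial() {
    return Long.valueOf(count);
  }

  // Resolved in COMPLETE and PARTIAL2 modes to produce the final result.
  public Long evaluate() {
    return Long.valueOf(count);
  }
}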