Package org.apache.hadoop.hive.ql.exec.tez

Source Code of org.apache.hadoop.hive.ql.exec.tez.CustomPartitionVertex

/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.hadoop.hive.ql.exec.tez;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeMap;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.io.HiveInputFormat;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.serializer.SerializationFactory;
import org.apache.hadoop.mapred.FileSplit;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.split.TezGroupedSplitsInputFormat;
import org.apache.hadoop.mapred.split.TezMapredSplitsGrouper;
import org.apache.tez.dag.api.EdgeManagerDescriptor;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.EdgeProperty.DataMovementType;
import org.apache.tez.dag.api.VertexLocationHint.TaskLocationHint;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.VertexLocationHint;
import org.apache.tez.dag.api.VertexManagerPlugin;
import org.apache.tez.dag.api.VertexManagerPluginContext;
import org.apache.tez.mapreduce.hadoop.MRHelpers;
import org.apache.tez.mapreduce.protos.MRRuntimeProtos.MRInputUserPayloadProto;
import org.apache.tez.mapreduce.protos.MRRuntimeProtos.MRSplitProto;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.events.RootInputConfigureVertexTasksEvent;
import org.apache.tez.runtime.api.events.RootInputDataInformationEvent;
import org.apache.tez.runtime.api.events.RootInputUpdatePayloadEvent;
import org.apache.tez.runtime.api.events.VertexManagerEvent;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;


/*
* Only works with old mapred API
* Will only work with a single MRInput for now.
*/
public class CustomPartitionVertex implements VertexManagerPlugin {

  private static final Log LOG = LogFactory.getLog(CustomPartitionVertex.class.getName());


  VertexManagerPluginContext context;

  private Multimap<Integer, Integer> bucketToTaskMap = HashMultimap.<Integer, Integer>create();
  private Multimap<Integer, InputSplit> bucketToInitialSplitMap =
      ArrayListMultimap.<Integer, InputSplit>create();

  private RootInputConfigureVertexTasksEvent configureVertexTaskEvent;
  private List<RootInputDataInformationEvent> dataInformationEvents;
  private Map<Path, List<FileSplit>> pathFileSplitsMap = new TreeMap<Path, List<FileSplit>>();
  private int numBuckets = -1;
  private Configuration conf = null;
  private boolean rootVertexInitialized = false;
  Multimap<Integer, InputSplit> bucketToGroupedSplitMap;


  private Map<Integer, Integer> bucketToNumTaskMap = new HashMap<Integer, Integer>();

  public CustomPartitionVertex() {
  }

  @Override
  public void initialize(VertexManagerPluginContext context) {
    this.context = context;
    ByteBuffer byteBuf = ByteBuffer.wrap(context.getUserPayload());
    this.numBuckets = byteBuf.getInt();
  }

  @Override
  public void onVertexStarted(Map<String, List<Integer>> completions) {
    int numTasks = context.getVertexNumTasks(context.getVertexName());
    List<Integer> scheduledTasks = new ArrayList<Integer>(numTasks);
    for (int i=0; i<numTasks; ++i) {
      scheduledTasks.add(new Integer(i));
    }
    context.scheduleVertexTasks(scheduledTasks);
  }

  @Override
  public void onSourceTaskCompleted(String srcVertexName, Integer attemptId) {
  }

  @Override
  public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) {
  }

  // One call per root Input - and for now only one is handled.
  @Override
  public void onRootVertexInitialized(String inputName, InputDescriptor inputDescriptor,
      List<Event> events) {

    // Ideally, since there's only 1 Input expected at the moment -
    // ensure this method is called only once. Tez will call it once per Root Input.
    Preconditions.checkState(rootVertexInitialized == false);
    rootVertexInitialized = true;
    try {
      // This is using the payload from the RootVertexInitializer corresponding
      // to InputName. Ideally it should be using it's own configuration class - but that
      // means serializing another instance.
      MRInputUserPayloadProto protoPayload =
          MRHelpers.parseMRInputPayload(inputDescriptor.getUserPayload());
      this.conf = MRHelpers.createConfFromByteString(protoPayload.getConfigurationBytes());

      /*
       * Currently in tez, the flow of events is thus: "Generate Splits -> Initialize Vertex"
       * (with parallelism info obtained from the generate splits phase). The generate splits
       * phase groups splits using the TezGroupedSplitsInputFormat. However, for bucket map joins
       * the grouping done by this input format results in incorrect results as the grouper has no
       * knowledge of buckets. So, we initially set the input format to be HiveInputFormat
       * (in DagUtils) for the case of bucket map joins so as to obtain un-grouped splits.
       * We then group the splits corresponding to buckets using the tez grouper which returns
       * TezGroupedSplits.
       */

      // This assumes that Grouping will always be used.
      // Changing the InputFormat - so that the correct one is initialized in MRInput.
      this.conf.set("mapred.input.format.class", TezGroupedSplitsInputFormat.class.getName());
      MRInputUserPayloadProto updatedPayload = MRInputUserPayloadProto
          .newBuilder(protoPayload)
          .setConfigurationBytes(MRHelpers.createByteStringFromConf(conf))
          .build();
      inputDescriptor.setUserPayload(updatedPayload.toByteArray());
    } catch (IOException e) {
      e.printStackTrace();
      throw new RuntimeException(e);
    }
    boolean dataInformationEventSeen = false;
    for (Event event : events) {
      if (event instanceof RootInputConfigureVertexTasksEvent) {
        // No tasks should have been started yet. Checked by initial state check.
        Preconditions.checkState(dataInformationEventSeen == false);
        Preconditions
        .checkState(
            context.getVertexNumTasks(context.getVertexName()) == -1,
            "Parallelism for the vertex should be set to -1 if the InputInitializer is setting parallelism");
        RootInputConfigureVertexTasksEvent cEvent = (RootInputConfigureVertexTasksEvent) event;

        // The vertex cannot be configured until all DataEvents are seen - to build the routing table.
        configureVertexTaskEvent = cEvent;
        dataInformationEvents = Lists.newArrayListWithCapacity(configureVertexTaskEvent.getNumTasks());
      }
      if (event instanceof RootInputUpdatePayloadEvent) {
        // this event can never occur. If it does, fail.
        Preconditions.checkState(false);
      } else if (event instanceof RootInputDataInformationEvent) {
        dataInformationEventSeen = true;
        RootInputDataInformationEvent diEvent = (RootInputDataInformationEvent) event;
        dataInformationEvents.add(diEvent);
        FileSplit fileSplit;
        try {
          fileSplit = getFileSplitFromEvent(diEvent);
        } catch (IOException e) {
          throw new RuntimeException("Failed to get file split for event: " + diEvent);
        }
        List<FileSplit> fsList = pathFileSplitsMap.get(fileSplit.getPath());
        if (fsList == null) {
          fsList = new ArrayList<FileSplit>();
          pathFileSplitsMap.put(fileSplit.getPath(), fsList);
        }
        fsList.add(fileSplit);
      }
    }

    setBucketNumForPath(pathFileSplitsMap);
    try {
      groupSplits();
      processAllEvents(inputName);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  private void processAllEvents(String inputName) throws IOException {

    List<InputSplit> finalSplits = Lists.newLinkedList();
    int taskCount = 0;
    for (Entry<Integer, Collection<InputSplit>> entry : bucketToGroupedSplitMap.asMap().entrySet()) {
      int bucketNum = entry.getKey();
      Collection<InputSplit> initialSplits = entry.getValue();
      finalSplits.addAll(initialSplits);
      for (int i = 0; i < initialSplits.size(); i++) {
        bucketToTaskMap.put(bucketNum, taskCount);
        taskCount++;
      }
    }

    // Construct the EdgeManager descriptor to be used by all edges which need the routing table.
    EdgeManagerDescriptor hiveEdgeManagerDesc = new EdgeManagerDescriptor(
        CustomPartitionEdge.class.getName());   
    byte[] payload = getBytePayload(bucketToTaskMap);
    hiveEdgeManagerDesc.setUserPayload(payload);

    Map<String, EdgeManagerDescriptor> emMap = Maps.newHashMap();

    // Replace the edge manager for all vertices which have routing type custom.
    for (Entry<String, EdgeProperty> edgeEntry : context.getInputVertexEdgeProperties().entrySet()) {
      if (edgeEntry.getValue().getDataMovementType() == DataMovementType.CUSTOM
          && edgeEntry.getValue().getEdgeManagerDescriptor().getClassName()
              .equals(CustomPartitionEdge.class.getName())) {
        emMap.put(edgeEntry.getKey(), hiveEdgeManagerDesc);
      }
    }

    LOG.info("Task count is " + taskCount);

    List<RootInputDataInformationEvent> taskEvents = Lists.newArrayListWithCapacity(finalSplits.size());
    // Re-serialize the splits after grouping.
    int count = 0;
    for (InputSplit inputSplit : finalSplits) {
      MRSplitProto serializedSplit = MRHelpers.createSplitProto(inputSplit);
      RootInputDataInformationEvent diEvent = new RootInputDataInformationEvent(
          count, serializedSplit.toByteArray());
      diEvent.setTargetIndex(count);
      count++;
      taskEvents.add(diEvent);
    }

    // Replace the Edge Managers
    context.setVertexParallelism(
        taskCount,
        new VertexLocationHint(createTaskLocationHintsFromSplits(finalSplits
            .toArray(new InputSplit[finalSplits.size()]))), emMap);

    // Set the actual events for the tasks.
    context.addRootInputEvents(inputName, taskEvents);
  }

  private byte[] getBytePayload(Multimap<Integer, Integer> routingTable) throws IOException {
    CustomEdgeConfiguration edgeConf =
        new CustomEdgeConfiguration(routingTable.keySet().size(), routingTable);
    DataOutputBuffer dob = new DataOutputBuffer();
    edgeConf.write(dob);
    byte[] serialized = dob.getData();

    return serialized;
  }

  private FileSplit getFileSplitFromEvent(RootInputDataInformationEvent event)
      throws IOException {
    InputSplit inputSplit = null;
    if (event.getDeserializedUserPayload() != null) {
      inputSplit = (InputSplit) event.getDeserializedUserPayload();
    } else {
      MRSplitProto splitProto = MRSplitProto.parseFrom(event.getUserPayload());
      SerializationFactory serializationFactory = new SerializationFactory(
          new Configuration());
      inputSplit = MRHelpers.createOldFormatSplitFromUserPayload(splitProto,
          serializationFactory);
    }

    if (!(inputSplit instanceof FileSplit)) {
      throw new UnsupportedOperationException(
          "Cannot handle splits other than FileSplit for the moment");
    }
    return (FileSplit) inputSplit;
  }

  /*
   * This method generates the map of bucket to file splits.
   */
  private void setBucketNumForPath(Map<Path, List<FileSplit>> pathFileSplitsMap) {
    int bucketNum = 0;
    int fsCount = 0;
    for (Map.Entry<Path, List<FileSplit>> entry : pathFileSplitsMap.entrySet()) {
      int bucketId = bucketNum % numBuckets;
      for (FileSplit fsplit : entry.getValue()) {
        fsCount++;
        bucketToInitialSplitMap.put(bucketId, fsplit);
      }
      bucketNum++;
    }

    LOG.info("Total number of splits counted: " + fsCount + " and total files encountered: "
        + pathFileSplitsMap.size());
  }

  private void groupSplits () throws IOException {
    estimateBucketSizes();
    bucketToGroupedSplitMap =
        ArrayListMultimap.<Integer, InputSplit>create(bucketToInitialSplitMap);
   
    Map<Integer, Collection<InputSplit>> bucketSplitMap = bucketToInitialSplitMap.asMap();
    for (int bucketId : bucketSplitMap.keySet()) {
      Collection<InputSplit>inputSplitCollection = bucketSplitMap.get(bucketId);
      TezMapredSplitsGrouper grouper = new TezMapredSplitsGrouper();

      InputSplit[] groupedSplits = grouper.getGroupedSplits(conf,
          inputSplitCollection.toArray(new InputSplit[0]), bucketToNumTaskMap.get(bucketId),
          HiveInputFormat.class.getName());
      LOG.info("Original split size is " +
          inputSplitCollection.toArray(new InputSplit[0]).length +
          " grouped split size is " + groupedSplits.length);
      bucketToGroupedSplitMap.removeAll(bucketId);
      for (InputSplit inSplit : groupedSplits) {
        bucketToGroupedSplitMap.put(bucketId, inSplit);
      }
    }
  }

  private void estimateBucketSizes() {
    Map<Integer, Long>bucketSizeMap = new HashMap<Integer, Long>();
    Map<Integer, Collection<InputSplit>> bucketSplitMap = bucketToInitialSplitMap.asMap();
    long totalSize = 0;
    for (int bucketId : bucketSplitMap.keySet()) {
      Long size = 0L;
      Collection<InputSplit>inputSplitCollection = bucketSplitMap.get(bucketId);
      Iterator<InputSplit> iter = inputSplitCollection.iterator();
      while (iter.hasNext()) {
        FileSplit fsplit = (FileSplit)iter.next();
        size += fsplit.getLength();
        totalSize += fsplit.getLength();
      }
      bucketSizeMap.put(bucketId, size);
    }

    int totalResource = context.getTotalAVailableResource().getMemory();
    int taskResource = context.getVertexTaskResource().getMemory();
    float waves = conf.getFloat(
        TezConfiguration.TEZ_AM_GROUPING_SPLIT_WAVES,
        TezConfiguration.TEZ_AM_GROUPING_SPLIT_WAVES_DEFAULT);

    int numTasks = (int)((totalResource*waves)/taskResource);
    LOG.info("Total resource: " + totalResource + " Task Resource: " + taskResource
        + " waves: " + waves + " total size of splits: " + totalSize +
        " total number of tasks: " + numTasks);

    for (int bucketId : bucketSizeMap.keySet()) {
      int numEstimatedTasks = 0;
      if (totalSize != 0) {
        numEstimatedTasks = (int)(numTasks * bucketSizeMap.get(bucketId) / totalSize);
      }
      LOG.info("Estimated number of tasks: " + numEstimatedTasks + " for bucket " + bucketId);
      if (numEstimatedTasks == 0) {
        numEstimatedTasks = 1;
      }
      bucketToNumTaskMap.put(bucketId, numEstimatedTasks);
    }
  }

  private static List<TaskLocationHint> createTaskLocationHintsFromSplits(
      org.apache.hadoop.mapred.InputSplit[] oldFormatSplits) {
    Iterable<TaskLocationHint> iterable = Iterables.transform(Arrays.asList(oldFormatSplits),
        new Function<org.apache.hadoop.mapred.InputSplit, TaskLocationHint>() {
      @Override
      public TaskLocationHint apply(org.apache.hadoop.mapred.InputSplit input) {
        try {
          if (input.getLocations() != null) {
            return new TaskLocationHint(new HashSet<String>(Arrays.asList(input.getLocations())),
                null);
          } else {
            LOG.info("NULL Location: returning an empty location hint");
            return new TaskLocationHint(null,null);
          }
        } catch (IOException e) {
          throw new RuntimeException(e);
        }
      }
    });

    return Lists.newArrayList(iterable);
  }
}
TOP

Related Classes of org.apache.hadoop.hive.ql.exec.tez.CustomPartitionVertex

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.