Package org.apache.tez.runtime.library.common.shuffle.impl

Source Code of org.apache.tez.runtime.library.common.shuffle.impl.ShuffleScheduler$Penalty

/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.runtime.library.common.shuffle.impl;

import java.io.IOException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.DelayQueue;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.tez.common.TezJobConfig;
import org.apache.tez.common.counters.TezCounter;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.TezInputContext;
import org.apache.tez.runtime.api.events.InputReadErrorEvent;
import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
import org.apache.tez.runtime.library.common.TezRuntimeUtils;

import com.google.common.collect.Lists;

class ShuffleScheduler {
  static ThreadLocal<Long> shuffleStart = new ThreadLocal<Long>() {
    protected Long initialValue() {
      return 0L;
    }
  };

  private static final Log LOG = LogFactory.getLog(ShuffleScheduler.class);
  private static final int MAX_MAPS_AT_ONCE = 20;
  private static final long INITIAL_PENALTY = 10000;
  private static final float PENALTY_GROWTH_RATE = 1.3f;
 
  // TODO NEWTEZ May need to be a string if attempting to fetch from multiple inputs.
  private boolean[] finishedMaps;
  private final int numInputs;
  private int remainingMaps;
  private Map<String, MapHost> mapLocations = new HashMap<String, MapHost>();
  //TODO NEWTEZ Clean this and other maps at some point
  private ConcurrentMap<String, InputAttemptIdentifier> pathToIdentifierMap = new ConcurrentHashMap<String, InputAttemptIdentifier>();
  private Set<MapHost> pendingHosts = new HashSet<MapHost>();
  private Set<InputAttemptIdentifier> obsoleteMaps = new HashSet<InputAttemptIdentifier>();
 
  private final Random random = new Random(System.currentTimeMillis());
  private final DelayQueue<Penalty> penalties = new DelayQueue<Penalty>();
  private final Referee referee = new Referee();
  private final Map<InputAttemptIdentifier, IntWritable> failureCounts =
    new HashMap<InputAttemptIdentifier,IntWritable>();
  private final Map<String,IntWritable> hostFailures =
    new HashMap<String,IntWritable>();
  private final TezInputContext inputContext;
  private final Shuffle shuffle;
  private final int abortFailureLimit;
  private final TezCounter shuffledMapsCounter;
  private final TezCounter reduceShuffleBytes;
  private final TezCounter failedShuffleCounter;
 
  private final long startTime;
  private long lastProgressTime;
 
  private int maxMapRuntime = 0;
  private int maxFailedUniqueFetches = 5;
  private int maxFetchFailuresBeforeReporting;
 
  private long totalBytesShuffledTillNow = 0;
  private DecimalFormat  mbpsFormat = new DecimalFormat("0.00");

  private boolean reportReadErrorImmediately = true;
 
  public ShuffleScheduler(TezInputContext inputContext,
                          Configuration conf,
                          int numberOfInputs,
                          Shuffle shuffle,
                          TezCounter shuffledMapsCounter,
                          TezCounter reduceShuffleBytes,
                          TezCounter failedShuffleCounter) {
    this.inputContext = inputContext;
    this.numInputs = numberOfInputs;
    abortFailureLimit = Math.max(30, numberOfInputs / 10);
    remainingMaps = numberOfInputs;
    finishedMaps = new boolean[remainingMaps]; // default init to false
    this.shuffle = shuffle;
    this.shuffledMapsCounter = shuffledMapsCounter;
    this.reduceShuffleBytes = reduceShuffleBytes;
    this.failedShuffleCounter = failedShuffleCounter;
    this.startTime = System.currentTimeMillis();
    this.lastProgressTime = startTime;
    referee.start();
    this.maxFailedUniqueFetches = Math.min(numberOfInputs,
        this.maxFailedUniqueFetches);
    this.maxFetchFailuresBeforeReporting =
        conf.getInt(
            TezJobConfig.TEZ_RUNTIME_SHUFFLE_FETCH_FAILURES,
            TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_FETCH_FAILURES_LIMIT);
    this.reportReadErrorImmediately =
        conf.getBoolean(
            TezJobConfig.TEZ_RUNTIME_SHUFFLE_NOTIFY_READERROR,
            TezJobConfig.DEFAULT_TEZ_RUNTIME_SHUFFLE_NOTIFY_READERROR);
  }

  public synchronized void copySucceeded(InputAttemptIdentifier srcAttemptIdentifier,
                                         MapHost host,
                                         long bytes,
                                         long milis,
                                         MapOutput output
                                         ) throws IOException {
    String taskIdentifier = TezRuntimeUtils.getTaskAttemptIdentifier(srcAttemptIdentifier.getInputIdentifier().getSrcTaskIndex(), srcAttemptIdentifier.getAttemptNumber());
    failureCounts.remove(taskIdentifier);
    hostFailures.remove(host.getHostName());
   
    if (!isInputFinished(srcAttemptIdentifier.getInputIdentifier().getSrcTaskIndex())) {
      output.commit();
      setInputFinished(srcAttemptIdentifier.getInputIdentifier().getSrcTaskIndex());
      shuffledMapsCounter.increment(1);
      if (--remainingMaps == 0) {
        notifyAll();
      }

      // update the status
      lastProgressTime = System.currentTimeMillis();
      totalBytesShuffledTillNow += bytes;
      logProgress();
      reduceShuffleBytes.increment(bytes);
      if (LOG.isDebugEnabled()) {
        LOG.debug("src task: "
            + TezRuntimeUtils.getTaskAttemptIdentifier(
                inputContext.getSourceVertexName(), srcAttemptIdentifier.getInputIdentifier().getSrcTaskIndex(),
                srcAttemptIdentifier.getAttemptNumber()) + " done");
      }
    }
    // TODO NEWTEZ Should this be releasing the output, if not committed ? Possible memory leak in case of speculation.
  }

  private void logProgress() {
    float mbs = (float) totalBytesShuffledTillNow / (1024 * 1024);
    int mapsDone = numInputs - remainingMaps;
    long secsSinceStart = (System.currentTimeMillis() - startTime) / 1000 + 1;

    float transferRate = mbs / secsSinceStart;
    LOG.info("copy(" + mapsDone + " of " + numInputs + " at "
        + mbpsFormat.format(transferRate) + " MB/s)");
  }

  public synchronized void copyFailed(InputAttemptIdentifier srcAttempt,
                                      MapHost host,
                                      boolean readError) {
    host.penalize();
    int failures = 1;
    if (failureCounts.containsKey(srcAttempt)) {
      IntWritable x = failureCounts.get(srcAttempt);
      x.set(x.get() + 1);
      failures = x.get();
    } else {
      failureCounts.put(srcAttempt, new IntWritable(1));     
    }
    String hostname = host.getHostName();
    if (hostFailures.containsKey(hostname)) {
      IntWritable x = hostFailures.get(hostname);
      x.set(x.get() + 1);
    } else {
      hostFailures.put(hostname, new IntWritable(1));
    }
    if (failures >= abortFailureLimit) {
      IOException ioe = new IOException(failures
            + " failures downloading "
            + TezRuntimeUtils.getTaskAttemptIdentifier(
                inputContext.getSourceVertexName(), srcAttempt.getInputIdentifier().getSrcTaskIndex(),
                srcAttempt.getAttemptNumber()));
      ioe.fillInStackTrace();
      shuffle.reportException(ioe);
    }
   
    checkAndInformJobTracker(failures, srcAttempt, readError);

    checkReducerHealth();
   
    long delay = (long) (INITIAL_PENALTY *
        Math.pow(PENALTY_GROWTH_RATE, failures));
   
    penalties.add(new Penalty(host, delay));
   
    failedShuffleCounter.increment(1);
  }
 
  // Notify the JobTracker 
  // after every read error, if 'reportReadErrorImmediately' is true or
  // after every 'maxFetchFailuresBeforeReporting' failures
  private void checkAndInformJobTracker(
      int failures, InputAttemptIdentifier srcAttempt, boolean readError) {
    if ((reportReadErrorImmediately && readError)
        || ((failures % maxFetchFailuresBeforeReporting) == 0)) {
      LOG.info("Reporting fetch failure for "
          + TezRuntimeUtils.getTaskAttemptIdentifier(
              inputContext.getSourceVertexName(), srcAttempt.getInputIdentifier().getSrcTaskIndex(),
              srcAttempt.getAttemptNumber()) + " to jobtracker.");

      List<Event> failedEvents = Lists.newArrayListWithCapacity(1);
      failedEvents.add(new InputReadErrorEvent("Fetch failure for "
          + TezRuntimeUtils.getTaskAttemptIdentifier(
              inputContext.getSourceVertexName(), srcAttempt.getInputIdentifier().getSrcTaskIndex(),
              srcAttempt.getAttemptNumber()) + " to jobtracker.", srcAttempt.getInputIdentifier()
          .getSrcTaskIndex(), srcAttempt.getAttemptNumber()));

      inputContext.sendEvents(failedEvents);     
      //status.addFailedDependency(mapId);
    }
  }
   
  private void checkReducerHealth() {
    final float MAX_ALLOWED_FAILED_FETCH_ATTEMPT_PERCENT = 0.5f;
    final float MIN_REQUIRED_PROGRESS_PERCENT = 0.5f;
    final float MAX_ALLOWED_STALL_TIME_PERCENT = 0.5f;

    long totalFailures = failedShuffleCounter.getValue();
    int doneMaps = numInputs - remainingMaps;
   
    boolean reducerHealthy =
      (((float)totalFailures / (totalFailures + doneMaps))
          < MAX_ALLOWED_FAILED_FETCH_ATTEMPT_PERCENT);
   
    // check if the reducer has progressed enough
    boolean reducerProgressedEnough =
      (((float)doneMaps / numInputs)
          >= MIN_REQUIRED_PROGRESS_PERCENT);

    // check if the reducer is stalled for a long time
    // duration for which the reducer is stalled
    int stallDuration =
      (int)(System.currentTimeMillis() - lastProgressTime);
   
    // duration for which the reducer ran with progress
    int shuffleProgressDuration =
      (int)(lastProgressTime - startTime);

    // min time the reducer should run without getting killed
    int minShuffleRunDuration =
      (shuffleProgressDuration > maxMapRuntime)
      ? shuffleProgressDuration
          : maxMapRuntime;
   
    boolean reducerStalled =
      (((float)stallDuration / minShuffleRunDuration)
          >= MAX_ALLOWED_STALL_TIME_PERCENT);

    // kill if not healthy and has insufficient progress
    if ((failureCounts.size() >= maxFailedUniqueFetches ||
        failureCounts.size() == (numInputs - doneMaps))
        && !reducerHealthy
        && (!reducerProgressedEnough || reducerStalled)) {
      LOG.fatal("Shuffle failed with too many fetch failures " +
      "and insufficient progress!");
      String errorMsg = "Exceeded MAX_FAILED_UNIQUE_FETCHES; bailing-out.";
      shuffle.reportException(new IOException(errorMsg));
    }

  }
 
  public synchronized void addKnownMapOutput(String hostName,
                                             int partitionId,
                                             String hostUrl,
                                             InputAttemptIdentifier srcAttempt) {
    String identifier = MapHost.createIdentifier(hostName, partitionId);
    MapHost host = mapLocations.get(identifier);
    if (host == null) {
      host = new MapHost(partitionId, hostName, hostUrl);
      assert identifier.equals(host.getIdentifier());
      mapLocations.put(identifier, host);
    }
    host.addKnownMap(srcAttempt);
    pathToIdentifierMap.put(
        getIdentifierFromPathAndReduceId(srcAttempt.getPathComponent(), partitionId), srcAttempt);

    // Mark the host as pending
    if (host.getState() == MapHost.State.PENDING) {
      pendingHosts.add(host);
      notifyAll();
    }
  }
 
  public synchronized void obsoleteMapOutput(InputAttemptIdentifier srcAttempt) {
    // The incoming srcAttempt does not contain a path component.
    obsoleteMaps.add(srcAttempt);
  }
 
  public synchronized void putBackKnownMapOutput(MapHost host,
                                                 InputAttemptIdentifier srcAttempt) {
    host.addKnownMap(srcAttempt);
  }

  public synchronized MapHost getHost() throws InterruptedException {
      while(pendingHosts.isEmpty()) {
        wait();
      }
     
      MapHost host = null;
      Iterator<MapHost> iter = pendingHosts.iterator();
      int numToPick = random.nextInt(pendingHosts.size());
      for (int i=0; i <= numToPick; ++i) {
        host = iter.next();
      }
     
      pendingHosts.remove(host);    
      host.markBusy();
     
      LOG.info("Assigning " + host + " with " + host.getNumKnownMapOutputs() +
               " to " + Thread.currentThread().getName());
      shuffleStart.set(System.currentTimeMillis());
     
      return host;
  }
 
  public InputAttemptIdentifier getIdentifierForFetchedOutput(
      String path, int reduceId) {
    return pathToIdentifierMap.get(getIdentifierFromPathAndReduceId(path, reduceId));
  }
 
  public synchronized List<InputAttemptIdentifier> getMapsForHost(MapHost host) {
    List<InputAttemptIdentifier> list = host.getAndClearKnownMaps();
    Iterator<InputAttemptIdentifier> itr = list.iterator();
    List<InputAttemptIdentifier> result = new ArrayList<InputAttemptIdentifier>();
    int includedMaps = 0;
    int totalSize = list.size();
    // find the maps that we still need, up to the limit
    while (itr.hasNext()) {
      InputAttemptIdentifier id = itr.next();
      if (!obsoleteMaps.contains(id) && !isInputFinished(id.getInputIdentifier().getSrcTaskIndex())) {
        result.add(id);
        if (++includedMaps >= MAX_MAPS_AT_ONCE) {
          break;
        }
      }
    }
    // put back the maps left after the limit
    while (itr.hasNext()) {
      InputAttemptIdentifier id = itr.next();
      if (!obsoleteMaps.contains(id) && !isInputFinished(id.getInputIdentifier().getSrcTaskIndex())) {
        host.addKnownMap(id);
      }
    }
    LOG.info("assigned " + includedMaps + " of " + totalSize + " to " +
             host + " to " + Thread.currentThread().getName());
    return result;
  }

  public synchronized void freeHost(MapHost host) {
    if (host.getState() != MapHost.State.PENALIZED) {
      if (host.markAvailable() == MapHost.State.PENDING) {
        pendingHosts.add(host);
        notifyAll();
      }
    }
    LOG.info(host + " freed by " + Thread.currentThread().getName() + " in " +
             (System.currentTimeMillis()-shuffleStart.get()) + "s");
  }
   
  public synchronized void resetKnownMaps() {
    mapLocations.clear();
    obsoleteMaps.clear();
    pendingHosts.clear();
    pathToIdentifierMap.clear();
  }

  /**
   * Utility method to check if the Shuffle data fetch is complete.
   * @return
   */
  public synchronized boolean isDone() {
    return remainingMaps == 0;
  }

  /**
   * Wait until the shuffle finishes or until the timeout.
   * @param millis maximum wait time
   * @return true if the shuffle is done
   * @throws InterruptedException
   */
  public synchronized boolean waitUntilDone(int millis
                                            ) throws InterruptedException {
    if (remainingMaps > 0) {
      wait(millis);
      return remainingMaps == 0;
    }
    return true;
  }
 
  /**
   * A structure that records the penalty for a host.
   */
  private static class Penalty implements Delayed {
    MapHost host;
    private long endTime;
   
    Penalty(MapHost host, long delay) {
      this.host = host;
      this.endTime = System.currentTimeMillis() + delay;
    }

    public long getDelay(TimeUnit unit) {
      long remainingTime = endTime - System.currentTimeMillis();
      return unit.convert(remainingTime, TimeUnit.MILLISECONDS);
    }

    public int compareTo(Delayed o) {
      long other = ((Penalty) o).endTime;
      return endTime == other ? 0 : (endTime < other ? -1 : 1);
    }
   
  }
 
  private String getIdentifierFromPathAndReduceId(String path, int reduceId) {
    return path + "_" + reduceId;
  }
 
  /**
   * A thread that takes hosts off of the penalty list when the timer expires.
   */
  private class Referee extends Thread {
    public Referee() {
      setName("ShufflePenaltyReferee");
      setDaemon(true);
    }

    public void run() {
      try {
        while (true) {
          // take the first host that has an expired penalty
          MapHost host = penalties.take().host;
          synchronized (ShuffleScheduler.this) {
            if (host.markAvailable() == MapHost.State.PENDING) {
              pendingHosts.add(host);
              ShuffleScheduler.this.notifyAll();
            }
          }
        }
      } catch (InterruptedException ie) {
        return;
      } catch (Throwable t) {
        shuffle.reportException(t);
      }
    }
  }
 
  public void close() throws InterruptedException {
    referee.interrupt();
    referee.join();
  }

  public synchronized void informMaxMapRunTime(int duration) {
    if (duration > maxMapRuntime) {
      maxMapRuntime = duration;
    }
  }
 
  void setInputFinished(int inputIndex) {
    synchronized(finishedMaps) {
      finishedMaps[inputIndex] = true;
    }
  }
 
  boolean isInputFinished(int inputIndex) {
    synchronized (finishedMaps) {
      return finishedMaps[inputIndex];     
    }
  }
}
TOP

Related Classes of org.apache.tez.runtime.library.common.shuffle.impl.ShuffleScheduler$Penalty

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.