Package org.apache.crunch.lib.join

Source Code of org.apache.crunch.lib.join.ShardedJoinStrategy

/**
* Licensed to the Apache Software Foundation (ASF) under one or more contributor license
* agreements. See the NOTICE file distributed with this work for additional information regarding
* copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the License. You may obtain a
* copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable
* law or agreed to in writing, software distributed under the License is distributed on an "AS IS"
* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License
* for the specific language governing permissions and limitations under the License.
*/
package org.apache.crunch.lib.join;

import java.io.Serializable;
import java.util.Random;

import org.apache.crunch.DoFn;
import org.apache.crunch.Emitter;
import org.apache.crunch.MapFn;
import org.apache.crunch.PTable;
import org.apache.crunch.Pair;
import org.apache.crunch.types.PTableType;
import org.apache.crunch.types.PTypeFamily;

/**
* JoinStrategy that splits the key space up into shards.
* <p>
* This strategy is useful when there are multiple values per key on at least one side of the join,
* and a large proportion of the values are mapped to a small number of keys.
* <p>
* Using this strategy will increase the number of keys being joined, but can increase performance
* by spreading processing of a single key over multiple reduce groups.
* <p>
* A custom {@link ShardingStrategy} can be provided so that only certain keys are sharded, or
* keys can be sharded in accordance with how many values are mapped to them.
*/
public class ShardedJoinStrategy<K, U, V> implements JoinStrategy<K, U, V> {
 
  /**
   * Determines over how many shards a key will be split in a sharded join.
   * <p>
   * It is essential that implementations of this class are deterministic.
   */
  public static interface ShardingStrategy<K> extends Serializable {

    /**
     * Retrieve the number of shards over which the given key should be split.
     * @param key key for which shards are to be determined
     * @return number of shards for the given key, must be greater than 0
     */
    int getNumShards(K key);

  }
 
 
  private JoinStrategy<Pair<K, Integer>, U, V> wrappedJoinStrategy;
  private ShardingStrategy<K> shardingStrategy;
 
  /**
   * Instantiate with a constant number of shards to use for all keys.
   *
   * @param numShards number of shards to use
   */
  public ShardedJoinStrategy(int numShards) {
    this(new ConstantShardingStrategy<K>(numShards));
  }
 
  /**
   * Instantiate with a custom sharding strategy.
   *
   * @param shardingStrategy strategy to be used for sharding
   */
  public ShardedJoinStrategy(ShardingStrategy<K> shardingStrategy) {
    this.wrappedJoinStrategy = new DefaultJoinStrategy<Pair<K, Integer>, U, V>();
    this.shardingStrategy = shardingStrategy;
  }

  @Override
  public PTable<K, Pair<U, V>> join(PTable<K, U> left, PTable<K, V> right, JoinType joinType) {
   
    if (joinType == JoinType.FULL_OUTER_JOIN || joinType == JoinType.LEFT_OUTER_JOIN) {
      throw new UnsupportedOperationException("Join type " + joinType + " not supported by ShardedJoinStrategy");
    }
   
    PTypeFamily ptf = left.getTypeFamily();
    PTableType<Pair<K, Integer>, U> shardedLeftType = ptf.tableOf(ptf.pairs(left.getKeyType(), ptf.ints()), left.getValueType());
    PTableType<Pair<K, Integer>, V> shardedRightType = ptf.tableOf(ptf.pairs(right.getKeyType(), ptf.ints()), right.getValueType());
    PTableType<K, Pair<U,V>> outputType = ptf.tableOf(left.getKeyType(), ptf.pairs(left.getValueType(), right.getValueType()));
   
    PTable<Pair<K,Integer>,U> shardedLeft = left.parallelDo("Pre-shard left", new PreShardLeftSideFn<K, U>(shardingStrategy), shardedLeftType);
    PTable<Pair<K,Integer>,V> shardedRight = right.parallelDo("Pre-shard right", new PreShardRightSideFn<K, V>(shardingStrategy), shardedRightType);

    PTable<Pair<K, Integer>, Pair<U, V>> shardedJoined = wrappedJoinStrategy.join(shardedLeft, shardedRight, joinType);
   
    return shardedJoined.parallelDo("Unshard", new UnshardFn<K, U, V>(), outputType);
  }

  private static class PreShardLeftSideFn<K, U> extends DoFn<Pair<K, U>, Pair<Pair<K, Integer>, U>> {

    private ShardingStrategy<K> shardingStrategy;

    public PreShardLeftSideFn(ShardingStrategy<K> shardingStrategy) {
      this.shardingStrategy = shardingStrategy;
    }

    @Override
    public void process(Pair<K, U> input, Emitter<Pair<Pair<K, Integer>, U>> emitter) {
      K key = input.first();
      int numShards = shardingStrategy.getNumShards(key);
      if (numShards < 1) {
        throw new IllegalArgumentException("Num shards must be > 0, got " + numShards + " for " + key);
      }
      for (int i = 0; i < numShards; i++) {
        emitter.emit(Pair.of(Pair.of(key, i), input.second()));
      }
    }

  }

  private static class PreShardRightSideFn<K, V> extends MapFn<Pair<K, V>, Pair<Pair<K, Integer>, V>> {

    private ShardingStrategy<K> shardingStrategy;
    private transient Random random;

    public PreShardRightSideFn(ShardingStrategy<K> shardingStrategy) {
      this.shardingStrategy = shardingStrategy;
    }
   
    @Override
    public void initialize() {
      random = new Random(getTaskAttemptID().getTaskID().getId());
    }

    @Override
    public Pair<Pair<K, Integer>, V> map(Pair<K, V> input) {
      K key = input.first();
      V value = input.second();
      int numShards = shardingStrategy.getNumShards(key);
      if (numShards < 1) {
        throw new IllegalArgumentException("Num shards must be > 0, got " + numShards + " for " + key);
      }
     
      return Pair.of(Pair.of(key, random.nextInt(numShards)), value);
    }

  }

  private static class UnshardFn<K, U, V> extends MapFn<Pair<Pair<K, Integer>, Pair<U, V>>, Pair<K, Pair<U, V>>> {

    @Override
    public Pair<K, Pair<U, V>> map(Pair<Pair<K, Integer>, Pair<U, V>> input) {
      return Pair.of(input.first().first(), input.second());
    }

  }
 
  /**
   * Sharding strategy that returns the same number of shards for all keys.
   */
  private static class ConstantShardingStrategy<K> implements ShardingStrategy<K> {

    private int numShards;

    public ConstantShardingStrategy(int numShards) {
      this.numShards = numShards;
    }
   
    @Override
    public int getNumShards(K key) {
      return numShards;
    }
   
  }

}
TOP

Related Classes of org.apache.crunch.lib.join.ShardedJoinStrategy

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by ORACLE Inc. Contact coftware#gmail.com.