Package org.apache.crunch.lib.join

Source Code of org.apache.crunch.lib.join.BloomFilterJoinStrategy

/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.crunch.lib.join;

import java.io.ByteArrayOutputStream;
import java.io.IOException;

import org.apache.avro.io.BinaryEncoder;
import org.apache.avro.io.EncoderFactory;
import org.apache.avro.reflect.ReflectDatumWriter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.crunch.CrunchRuntimeException;
import org.apache.crunch.DoFn;
import org.apache.crunch.Emitter;
import org.apache.crunch.FilterFn;
import org.apache.crunch.MapFn;
import org.apache.crunch.PCollection;
import org.apache.crunch.PTable;
import org.apache.crunch.Pair;
import org.apache.crunch.ParallelDoOptions;
import org.apache.crunch.SourceTarget;
import org.apache.crunch.impl.mr.MRPipeline;
import org.apache.crunch.io.ReadableSourceTarget;
import org.apache.crunch.materialize.MaterializableIterable;
import org.apache.crunch.types.PType;
import org.apache.crunch.types.PTypeFamily;
import org.apache.crunch.types.avro.AvroType;
import org.apache.crunch.types.avro.AvroTypeFamily;
import org.apache.crunch.types.avro.Avros;
import org.apache.crunch.types.writable.WritableType;
import org.apache.crunch.types.writable.WritableTypeFamily;
import org.apache.crunch.types.writable.Writables;
import org.apache.crunch.util.DistCache;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.util.bloom.BloomFilter;
import org.apache.hadoop.util.bloom.Key;
import org.apache.hadoop.util.hash.Hash;

/**
* Join strategy that uses a <a href="http://en.wikipedia.org/wiki/Bloom_filter">Bloom filter</a>
* that is trained on the keys of the left-side table to filter the key/value pairs of the right-side
* table before sending through the shuffle and reduce phase.
* <p>
* This strategy is useful in cases where the right-side table contains many keys that are not
* present in the left-side table. In this case, the use of the Bloom filter avoids a
* potentially costly shuffle phase for data that would never be joined to the left side.
*/
public class BloomFilterJoinStrategy<K, U, V> implements JoinStrategy<K, U, V> {
 
  private static final Log LOG = LogFactory.getLog(BloomFilterJoinStrategy.class);

  private int vectorSize;
  private int nbHash;
  private JoinStrategy<K, U, V> delegateJoinStrategy;

  /**
   * Instantiate with the expected number of unique keys in the left table.
   * <p>
   * The {@link DefaultJoinStrategy} will be used to perform the actual join after filtering.
   *
   * @param numElements expected number of unique keys
   */
  public BloomFilterJoinStrategy(int numElements) {
    this(numElements, 0.05f);
  }
 
  /**
   * Instantiate with the expected number of unique keys in the left table, and the acceptable
   * false positive rate for the Bloom filter.
   * <p>
   * Delegates to the three-argument constructor, supplying a new {@link DefaultJoinStrategy}
   * to perform the actual join after filtering.
   *
   * @param numElements expected number of unique keys
   * @param falsePositiveRate acceptable false positive rate for Bloom Filter; presumably a
   *        fraction strictly between 0 and 1 (not validated here)
   */
  public BloomFilterJoinStrategy(int numElements, float falsePositiveRate) {
    this(numElements, falsePositiveRate, new DefaultJoinStrategy<K, U, V>());
  }
 
  /**
   * Instantiate with the expected number of unique keys in the left table, and the acceptable
   * false positive rate for the Bloom filter, and an underlying join strategy to delegate to.
   *
   * @param numElements expected number of unique keys
   * @param falsePositiveRate acceptable false positive rate for Bloom Filter
   * @param delegateJoinStrategy join strategy to delegate to after filtering
   */
  public BloomFilterJoinStrategy(int numElements, float falsePositiveRate, JoinStrategy<K,U,V> delegateJoinStrategy) {
    this.vectorSize = getOptimalVectorSize(numElements, falsePositiveRate);
    this.nbHash = getOptimalNumHash(numElements, vectorSize);
    this.delegateJoinStrategy = delegateJoinStrategy;
  }
 
  /**
   * Calculates the optimal vector size for a given number of elements and acceptable false
   * positive rate.
   */
  private static int getOptimalVectorSize(int numElements, float falsePositiveRate) {
    return (int) (-numElements * (float)Math.log(falsePositiveRate) / Math.pow(Math.log(2), 2));
  }
 
  /**
   * Calculates the optimal number of hash functions to be used.
   */
  private static int getOptimalNumHash(int numElements, float vectorSize) {
    return (int)Math.round(vectorSize * Math.log(2) / numElements);
  }
 
  @Override
  public PTable<K, Pair<U, V>> join(PTable<K, U> left, PTable<K, V> right, JoinType joinType) {
   
    if (joinType != JoinType.INNER_JOIN && joinType != JoinType.LEFT_OUTER_JOIN) {
      throw new IllegalStateException("JoinType " + joinType + " is not supported for BloomFilter joins");
    }
   
    PTable<K,V> filteredRightSide = null;
    if (left.getPipeline() instanceof MRPipeline) {
      PType<BloomFilter> bloomFilterType = getBloomFilterType(left.getTypeFamily());
      PCollection<BloomFilter> bloomFilters = left.keys().parallelDo(
                                                    "Create bloom filters",
                                                    new CreateBloomFilterFn(vectorSize, nbHash, left.getKeyType()),
                                                    bloomFilterType);
     
      MaterializableIterable<BloomFilter> materializableIterable = (MaterializableIterable<BloomFilter>) bloomFilters.materialize();
      FilterKeysWithBloomFilterFn<K, V> filterKeysFn = new FilterKeysWithBloomFilterFn<K, V>(
                                                              materializableIterable.getPath().toString(),
                                                              vectorSize, nbHash,
                                                              left.getKeyType(), bloomFilterType);
     
      ParallelDoOptions.Builder optionsBuilder = ParallelDoOptions.builder();
      if (materializableIterable.isSourceTarget()) {
        optionsBuilder.sourceTargets((SourceTarget) materializableIterable.getSource());
      }
     
      filteredRightSide = right.parallelDo("Filter right side with BloomFilters",
                                                filterKeysFn, right.getPTableType(), optionsBuilder.build());
 
      // TODO This shouldn't be necessary due to the ParallelDoOptions, but it seems to be needed somehow
      left.getPipeline().run();
    } else {
      LOG.warn("Not using Bloom filters outside of MapReduce context");
      filteredRightSide = right;
    }
   
    return delegateJoinStrategy.join(left, filteredRightSide, joinType);
  }
 
  /**
   * Creates Bloom filter(s) for filtering of right-side keys.
   */
  private static class CreateBloomFilterFn<K> extends DoFn<K, BloomFilter> {

    private int vectorSize;
    private int nbHash;
    private transient BloomFilter bloomFilter;
    private transient MapFn<K,byte[]> keyToBytesFn;
    private PType<K> ptype;
   
    public CreateBloomFilterFn(int vectorSize, int nbHash, PType<K> ptype) {
      this.vectorSize = vectorSize;
      this.nbHash = nbHash;
      this.ptype = ptype;
    }
   
    @Override
    public void initialize() {
      super.initialize();
      bloomFilter = new BloomFilter(vectorSize, nbHash, Hash.MURMUR_HASH);
      ptype.initialize(getConfiguration());
      keyToBytesFn = getKeyToBytesMapFn(ptype, getConfiguration());
    }
   
    @Override
    public void process(K input, Emitter<BloomFilter> emitter) {
      bloomFilter.add(new Key(keyToBytesFn.map(input)));
    }
   
    @Override
    public void cleanup(Emitter<BloomFilter> emitter) {
      emitter.emit(bloomFilter);
    }
   
  }
 
  /**
   * Filters right-side keys with a Bloom filter before passing them off to the delegate join strategy.
   */
  private static class FilterKeysWithBloomFilterFn<K,V> extends FilterFn<Pair<K, V>> {
   
    private String inputPath;
    private int vectorSize;
    private int nbHash;
    private PType<K> keyType;
    private PType<BloomFilter> bloomFilterPType;
    private BloomFilter bloomFilter;
    private transient MapFn<K,byte[]> keyToBytesFn;
   
    public FilterKeysWithBloomFilterFn(String inputPath, int vectorSize, int nbHash, PType<K> keyType, PType<BloomFilter> bloomFilterPtype) {
      this.inputPath = inputPath;
      this.vectorSize = vectorSize;
      this.nbHash = nbHash;
      this.keyType = keyType;
      this.bloomFilterPType = bloomFilterPtype;
    }
   
   
    private Path getCacheFilePath() {
      Path local = DistCache.getPathToCacheFile(new Path(inputPath), getConfiguration());
      if (local == null) {
        throw new CrunchRuntimeException("Can't find local cache file for '" + inputPath + "'");
      }
      return local;
    }

    @Override
    public void configure(Configuration conf) {
      DistCache.addCacheFile(new Path(inputPath), conf);
    }
   
    @Override
    public void initialize() {
      super.initialize();
     
      bloomFilterPType.initialize(getConfiguration());
      keyType.initialize(getConfiguration());
     
      keyToBytesFn = getKeyToBytesMapFn(keyType, getConfiguration());

      ReadableSourceTarget<BloomFilter> sourceTarget = bloomFilterPType.getDefaultFileSource(
          getCacheFilePath());
      Iterable<BloomFilter> iterable = null;
      try {
        iterable = sourceTarget.read(getConfiguration());
      } catch (IOException e) {
        throw new CrunchRuntimeException("Error reading right-side of map side join: ", e);
      }

      bloomFilter = new BloomFilter(vectorSize, nbHash, Hash.MURMUR_HASH);
      for (BloomFilter subFilter : iterable) {
        bloomFilter.or(subFilter);
      }
    }

    @Override
    public boolean accept(Pair<K, V> input) {
      Key key = new Key(keyToBytesFn.map(input.first()));
      return bloomFilter.membershipTest(key);
    }
  }
 
  /**
   * Returns the appropriate MapFn for converting the key type into byte arrays.
   */
  private static <K> MapFn<K,byte[]> getKeyToBytesMapFn(PType<K> ptype, Configuration conf) {
    if (ptype instanceof AvroType) {
      return new AvroToBytesFn<K>((AvroType)ptype, conf);
    } else if (ptype instanceof WritableType) {
      return new WritableToBytesFn<K>((WritableType)ptype, conf);
    } else {
      throw new IllegalStateException("Unrecognized PType: " + ptype);
    }
  }
 
  /**
   * Returns the appropriate PType for serializing BloomFilters using the same
   * type family as is used for the input collections.
   */
  private static PType<BloomFilter> getBloomFilterType(PTypeFamily typeFamily) {
    if (typeFamily.equals(AvroTypeFamily.getInstance())) {
      return Avros.writables(BloomFilter.class);
    } else if (typeFamily.equals(WritableTypeFamily.getInstance())) {
      return Writables.writables(BloomFilter.class);
    } else {
      throw new IllegalStateException("Unrecognized PTypeFamily: " + typeFamily);
    }
  }
 
  /**
   * Converts a Writable into a byte array so that it can be added to a BloomFilter.
   */
  private static class WritableToBytesFn<T> extends MapFn<T,byte[]>{
   
    private WritableType<T,?> ptype;
    private DataOutputBuffer dataOutputBuffer;
   
    public WritableToBytesFn(WritableType<T,?> ptype, Configuration conf) {
      this.ptype = ptype;
      dataOutputBuffer = new DataOutputBuffer();
    }

    @Override
    public byte[] map(T input) {
      dataOutputBuffer.reset();
      Writable writable = (Writable) ptype.getOutputMapFn().map(input);
      try {
        writable.write(dataOutputBuffer);
      } catch (IOException e) {
        throw new CrunchRuntimeException(e);
      }
      byte[] output = new byte[dataOutputBuffer.getLength()];
      System.arraycopy(dataOutputBuffer.getData(), 0, output, 0, dataOutputBuffer.getLength());
      return output;
    }
   
  }
 
  /**
   * Converts an Avro value into a byte array so that it can be added to a Bloom filter.
   */
  private static class AvroToBytesFn<T> extends MapFn<T,byte[]> {
   
    private AvroType<T> ptype;
    private BinaryEncoder encoder;
    private ReflectDatumWriter datumWriter;
   
    public AvroToBytesFn(AvroType<T> ptype, Configuration conf) {
      this.ptype = ptype;
      datumWriter = Avros.getReflectDataFactory(conf).getWriter(ptype.getSchema());
    }

    @Override
    public byte[] map(T input) {
      Object datum = ptype.getOutputMapFn().map(input);
      ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
      encoder = EncoderFactory.get().binaryEncoder(byteArrayOutputStream, encoder);
      try {
        datumWriter.write(datum, encoder);
        encoder.flush();
      } catch (IOException e) {
        throw new CrunchRuntimeException(e);
      }
      return byteArrayOutputStream.toByteArray();
    }
   
   
   
  }

}
TOP

Related Classes of org.apache.crunch.lib.join.BloomFilterJoinStrategy

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.