package partition;
import java.io.IOException;
import kmer.ProteinKmerBitFeatureVector;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.log4j.Logger;
/**
 * Hadoop program that takes as input a directory of (query id, kmer feature vector)
 * records and a directory of (cluster id, cluster kmer feature vector) records.
 * The output is (query id, list of cluster ids whose kmer feature vectors overlap
 * the query's in at least MIN_MATCHES bytes).
 *
 * @author cmhill
 */
public class FindClustersForSequences extends Configured implements Tool {
private static final String USAGE = "FindClustersForSequences INPUT_FEATURE_VECTOR INPUT_CLUSTER_DIR OUTPUT_BASE_DIR [KMER_LENGTH] [NUM_CLUSTERS] [NUM_TASKS] [MIN_MATCHES]";
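// Example invocation (hypothetical jar name and paths):
//   hadoop jar protein-clusters.jar partition.FindClustersForSequences \
//       /data/query_vectors /data/cluster_vectors /data/out 3 100 50 2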
public static final String ALPHABET_SIZE = "ALPHABET_SIZE";
public static final String DEBUG_SET = "DEBUG";
public static final String KMER_LENGTH = "KMER_LENGTH";
public static final String LOG_DELIM = ",";
public static final String MIN_MATCHES = "MIN_MATCHES";
public static final String NUM_CLUSTERS = "NUM_CLUSTERS";
private static boolean USING_COUNT_VECTOR = false;
private static boolean DEBUG = true;
private static final int MAX_REDUCES = 200;
private static final int MAX_MAPS = 200;
private static int NUMBER_CLUSTERS = 57;
private static final Logger LOG = Logger.getLogger(FindClustersForSequences.class);
/**
 * This mapper takes as input the (query id, kmer feature vector) records, and
 * outputs a (query id, cluster id) pair for every cluster whose feature vector
 * shares at least minMatches kmer bytes with the query's.
 */
public static class Map extends Mapper<LongWritable, BytesWritable, LongWritable, LongWritable> {
// Store the database cluster kmer feature vectors.
private ProteinKmerBitFeatureVector[] clusterFeatureVectors;
private int alphabetSize;
private int kmerLength;
private int minMatches = 1;
/**
* Load the cluster kmer feature vectors into memory.
*/
@Override
protected void setup(Context context) throws IOException, InterruptedException {
Configuration conf = context.getConfiguration();
FileSystem fs = FileSystem.get(conf);
kmerLength = conf.getInt(KMER_LENGTH, 3);
int numCenters = conf.getInt(NUM_CLUSTERS, 100);
alphabetSize = conf.getInt(ALPHABET_SIZE, 21);
minMatches = conf.getInt(MIN_MATCHES, 5);
DEBUG = conf.getBoolean(DEBUG_SET, false);
// Initialize the cluster feature vector array.
clusterFeatureVectors = new ProteinKmerBitFeatureVector[numCenters];
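// Every map task loads the complete set of cluster vectors into this array,
// so per-mapper memory grows with NUM_CLUSTERS times the per-vector size.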
String input = context.getConfiguration().get("CLUSTER_INPUT_PATH");
if (DEBUG)
LOG.info("Loading cluster centers from: " + input);
Path clusterCentersPath = new Path(input);
// Go through each part-r-* file and add the cluster center kmer feature vector.
for (FileStatus srcFileStatus : fs.listStatus(clusterCentersPath)) {
if (srcFileStatus.isDir())
continue;
SequenceFile.Reader reader = new SequenceFile.Reader(fs, srcFileStatus.getPath(), conf);
LongWritable key = new LongWritable();
BytesWritable value = new BytesWritable();
while (reader.next(key, value)) {
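// The bit vector reserves one bit per possible kmer, i.e.
// ceil(alphabetSize^kmerLength / BITS_IN_BYTE) bytes. Assuming BITS_IN_BYTE is 8,
// the defaults (alphabet 21, k = 3) give ceil(9261 / 8) = 1158 bytes per vector.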
byte[] valueVector = new byte[(int) Math.ceil(Math.pow(alphabetSize, kmerLength) / ProteinKmerBitFeatureVector.BITS_IN_BYTE)];
System.arraycopy(value.getBytes(), 0, valueVector, 0, valueVector.length);
// Store the cluster's kmer feature vector at its cluster id index.
clusterFeatureVectors[(int) key.get()] = new ProteinKmerBitFeatureVector(valueVector, kmerLength, alphabetSize);
}
reader.close();
}
if (DEBUG) {
for (int i = 0; i < clusterFeatureVectors.length; i++) {
// Guard against cluster ids that had no vector in the input.
if (clusterFeatureVectors[i] != null)
LOG.info("Cluster " + i + ": " + clusterFeatureVectors[i].printKmers());
}
}
}
/*
 * Check if any of the bytes overlap.
 * TODO(cmhill): This doesn't count multiple overlapping kmers within a single byte.
 */
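// Example: with minMatches = 1, a = {0b0101} and b = {0b0110} share a set bit in
// byte 0 (0b0101 & 0b0110 = 0b0100), so overlap() returns true after one match.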
public boolean overlap(byte[] a, byte[] b) {
int numMatches = 0;
if (DEBUG)
LOG.info("New run");
for (int i = 0; i < b.length; i++) {
if ((a[i] & b[i]) != 0) {
if (DEBUG)
LOG.info("i: " + i + ", a = " + a[i] + ", b = " + b[i]);
++numMatches;
if (numMatches >= minMatches)
return true;
}
}
return false;
}
@Override
public void map(LongWritable key, BytesWritable value, Context context)
throws IOException, InterruptedException {
if (DEBUG) {
LOG.info("Value length: " + value.getLength());
byte[] featureVector = new byte[(int) Math.ceil(Math.pow(alphabetSize, kmerLength)
/ ProteinKmerBitFeatureVector.BITS_IN_BYTE)];
// Very slow.
// System.arraycopy(value.getBytes(), 0, featureVector, 0, featureVector.length);
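// NOTE: with the arraycopy above disabled, featureVector stays all zeros, so
// the printKmers() call below logs an empty kmer set for the query.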
ProteinKmerBitFeatureVector vector = new ProteinKmerBitFeatureVector(
featureVector, context.getConfiguration().getInt(KMER_LENGTH, 3));
LOG.info(key.get() + ": " + vector.printKmers());
}
byte[] valueBytes = value.getBytes();
for (int i = 0; i < clusterFeatureVectors.length; i++) {
// Skip cluster ids that had no vector in the input.
if (clusterFeatureVectors[i] == null)
continue;
// Compare against getLength(): getBytes() returns the padded backing array,
// so its length can exceed the number of valid bytes.
if (value.getLength() != clusterFeatureVectors[i].getFeatureVector().length)
continue;
if (overlap(valueBytes, clusterFeatureVectors[i].getFeatureVector())) {
context.write(key, new LongWritable(i));
}
}
}
}
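/**
 * Concatenates all matching cluster ids for a query id into a single
 * tab-separated line. For example, a query id 42 that matched clusters 3 and 7
 * produces the output line "42\t3\t7\t" (note the trailing tab).
 */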
public static class Reduce extends Reducer<LongWritable, LongWritable, LongWritable, Text> {
@Override
public void reduce(LongWritable key, Iterable<LongWritable> values, Context context)
throws IOException, InterruptedException {
// ArrayList<Long> clusters = new ArrayList<Long>();
StringBuilder clusters = new StringBuilder("");
for (LongWritable value : values) {
clusters.append(value.get() + "\t");
}
context.write(key, new Text(clusters.toString()));
}
}
public static void main(String[] args) {
int result = 1;
try {
result = ToolRunner.run(new FindClustersForSequences(), args);
} catch (Exception e) {
e.printStackTrace();
System.err.println("Job failed.");
}
System.exit(result);
}
@Override
public int run(String[] args) throws Exception {
if(args.length < 4) {
System.out.println(USAGE);
return -1;
}
String sequenceInputPath = args[0];
String clusterInputPath = args[1];
String baseOutputPath = args[2];
Job job = new Job(getConf(), "FindClustersForSequences");
job.setJarByClass(FindClustersForSequences.class);
int mapTasks = MAX_MAPS;
int reduceTasks = MAX_REDUCES;
int kmerLength = 3;
int numClusters = 100;
int minMatches = 1;
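// Note: run() always writes MIN_MATCHES into the job configuration, so the
// mapper-side default of 5 in setup() only applies if the job is launched
// some other way.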
if (args.length > 3) {
kmerLength = Integer.parseInt(args[3]);
job.getConfiguration().setInt(KMER_LENGTH, kmerLength);
if(args.length > 4) {
numClusters = Integer.parseInt(args[4]);
if(args.length > 5) {
int numTasks = Integer.parseInt(args[5]);
mapTasks = numTasks;
reduceTasks = numTasks;
if(args.length > 6) {
minMatches = Integer.parseInt(args[6]);
}
}
}
}
// Delete the output directory if it exists already.
FileSystem.get(job.getConfiguration()).delete(new Path(baseOutputPath), true);
job.setNumReduceTasks(reduceTasks);
LOG.info("Tool name: FindClustersForSequences");
LOG.info(" - inputFeatureVector: " + sequenceInputPath);
LOG.info(" - inputClusterFeatureVector: " + clusterInputPath);
LOG.info(" - outputDir: " + baseOutputPath);
LOG.info(" - kmerLength: " + kmerLength);
LOG.info(" - numClusters: " + numClusters);
LOG.info(" - minKmerMatches: " + minMatches);
job.setOutputKeyClass(LongWritable.class);
job.setOutputValueClass(Text.class);
job.setMapOutputKeyClass(LongWritable.class);
job.setMapOutputValueClass(LongWritable.class);
job.setMapperClass(FindClustersForSequences.Map.class);
job.setReducerClass(FindClustersForSequences.Reduce.class);
job.setInputFormatClass(SequenceFileInputFormat.class);
job.setOutputFormatClass(TextOutputFormat.class);
FileInputFormat.addInputPath(job, new Path(sequenceInputPath));
FileOutputFormat.setOutputPath(job, new Path(baseOutputPath));
/* Setup the key value pairs */
job.getConfiguration().set("BASE_OUTPUT_DIR", baseOutputPath);
job.getConfiguration().set("CLUSTER_INPUT_PATH", clusterInputPath);
job.getConfiguration().setInt(NUM_CLUSTERS, numClusters);
job.getConfiguration().setInt(MIN_MATCHES, minMatches);
long startTime = System.currentTimeMillis();
boolean result = job.waitForCompletion(true);
LOG.info((System.currentTimeMillis() - startTime) + LOG_DELIM
+ mapTasks + LOG_DELIM + reduceTasks);
return result ? 0 : 1;
}
}