package ivory.lsh.eval;
import ivory.core.data.document.WeightedIntDocVector;
import ivory.core.util.CLIRUtils;
import ivory.lsh.data.Signature;
import java.io.IOException;
import java.net.URI;
import java.text.NumberFormat;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
import org.apache.hadoop.mapred.Mapper;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapred.SequenceFileInputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import edu.umd.cloud9.io.SequenceFileUtils;
import edu.umd.cloud9.io.map.HMapSFW;
import edu.umd.cloud9.io.pair.PairOfFloatInt;
import edu.umd.cloud9.io.pair.PairOfInts;
import edu.umd.cloud9.io.pair.PairOfWritables;
/**
 * Extracts, for each sample document, a list of similar documents, either by computing cosine
 * similarity between doc vectors or Hamming distance between signatures.
 *
 * @author ferhanture
 */
public class BruteForcePwsim extends Configured implements Tool {
private static final Logger LOG = Logger.getLogger(BruteForcePwsim.class);
static enum Pairs {
Total, Emitted, DEBUG, DEBUG2, Total2
};
static enum Sample {
Size
};
/**
 * For every document (weighted int doc vector) in the sample, find all other docs whose cosine
 * similarity is at or above a given threshold.
 *
 * @author ferhanture
 */
public static class MyMapperWeightedIntDocVectors extends MapReduceBase implements
Mapper<IntWritable, WeightedIntDocVector, IntWritable, PairOfFloatInt> {
private List<PairOfWritables<IntWritable, WeightedIntDocVector>> vectors;
private float threshold;
public void configure(JobConf job) {
threshold = job.getFloat("Ivory.CosineThreshold", -1);
LOG.info("Threshold = " + threshold);
String sampleFile = job.get("Ivory.SampleFile");
LOG.info("Reading signature from : " + sampleFile);
try {
Path[] localFiles = DistributedCache.getLocalCacheFiles(job);
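// The sample file was shipped via the distributed cache in run(); locate the local copy by
// matching on the file name.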
for (Path localFile : localFiles) {
if (localFile.toString().contains(sampleFile)) {
vectors = SequenceFileUtils.readFile(localFile, FileSystem.getLocal(job));
}
}
if (vectors == null) throw new RuntimeException("Sample file not found at " + sampleFile);
} catch (Exception e) {
throw new RuntimeException("Error reading doc vectors from " + sampleFile, e);
}
LOG.info("Read " + vectors.size() + " sample weighted int doc vectors");
}
public void map(IntWritable docno, WeightedIntDocVector docvector,
OutputCollector<IntWritable, PairOfFloatInt> output, Reporter reporter) throws IOException {
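// Brute force: compare the incoming doc vector against every sample vector, emitting
// (sample docno, (similarity, docno)) for each pair at or above the threshold.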
for (int i = 0; i < vectors.size(); i++) {
IntWritable sampleDocno = vectors.get(i).getLeftElement();
WeightedIntDocVector fromSample = vectors.get(i).getRightElement();
float cs = CLIRUtils.cosine(docvector.getWeightedTerms(), fromSample.getWeightedTerms());
if (cs >= threshold) {
output.collect(new IntWritable(sampleDocno.get()), new PairOfFloatInt(cs, docno.get()));
}
}
}
}
/**
 * For every document (term doc vector) in the sample, find all other docs whose cosine similarity
 * is at or above a given threshold.
 *
 * @author ferhanture
 */
public static class MyMapperTermDocVectors extends MapReduceBase implements
Mapper<IntWritable, HMapSFW, IntWritable, PairOfFloatInt> {
private List<PairOfWritables<IntWritable, HMapSFW>> vectors;
private float threshold;
public void configure(JobConf job) {
LOG.setLevel(Level.INFO);
threshold = job.getFloat("Ivory.CosineThreshold", -1);
LOG.info("Threshold = " + threshold);
String sampleFile = job.get("Ivory.SampleFile");
LOG.info("Reading signature from " + sampleFile);
try {
Path[] localFiles = DistributedCache.getLocalCacheFiles(job);
for (Path localFile : localFiles) {
if (localFile.toString().contains(sampleFile)) {
vectors = SequenceFileUtils.readFile(localFile, FileSystem.getLocal(job));
}
}
if (vectors == null) throw new RuntimeException("Sample file not found at " + sampleFile);
} catch (Exception e) {
throw new RuntimeException("Error reading doc vectors from " + sampleFile, e);
}
LOG.info("Read " + vectors.size() + " sample doc vectors");
}
public void map(IntWritable docno, HMapSFW docvector,
OutputCollector<IntWritable, PairOfFloatInt> output, Reporter reporter) throws IOException {
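// Same brute-force scan as above, over term doc vectors; counters track pairs considered
// versus pairs emitted.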
for (int i = 0; i < vectors.size(); i++) {
reporter.incrCounter(Pairs.Total, 1);
IntWritable sampleDocno = vectors.get(i).getLeftElement();
HMapSFW fromSample = vectors.get(i).getRightElement();
float cs = CLIRUtils.cosine(docvector, fromSample);
if (cs >= threshold) {
LOG.debug(sampleDocno + "," + fromSample + "\n" + fromSample.length());
LOG.debug(docno + "," + docvector + "\n" + docvector.length());
LOG.debug(cs);
reporter.incrCounter(Pairs.Emitted, 1);
output.collect(new IntWritable(sampleDocno.get()), new PairOfFloatInt(cs, docno.get()));
}
}
}
}
/**
 * For every document (signature) in the sample, find all other docs that are within a given
 * Hamming distance.
 *
 * @author ferhanture
 */
public static class MyMapperSignature extends MapReduceBase implements
Mapper<IntWritable, Signature, IntWritable, PairOfFloatInt> {
private List<PairOfWritables<IntWritable, Signature>> signatures;
private int maxDist;
public void configure(JobConf job) {
maxDist = (int) job.getFloat("Ivory.MaxHammingDistance", -1);
LOG.info("Threshold = " + maxDist);
String sampleFile = job.get("Ivory.SampleFile");
LOG.info("Reading signature from " + sampleFile);
// read doc ids of sample into vectors
try {
Path[] localFiles = DistributedCache.getLocalCacheFiles(job);
for (Path localFile : localFiles) {
if (localFile.toString().contains(sampleFile)) {
signatures = SequenceFileUtils.readFile(localFile, FileSystem.getLocal(job));
}
}
if (signatures == null) throw new RuntimeException("Sample file not found at " + sampleFile);
} catch (Exception e) {
throw new RuntimeException("Error reading sample signatures from " + sampleFile, e);
}
LOG.info("Read " + signatures.size() + " sample signatures");
}
public void map(IntWritable docno, Signature signature,
OutputCollector<IntWritable, PairOfFloatInt> output, Reporter reporter) throws IOException {
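// Compare the incoming signature against every sample signature, emitting pairs whose
// Hamming distance does not exceed maxDist.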
for (int i = 0; i < signatures.size(); i++) {
reporter.incrCounter(Pairs.Total, 1);
IntWritable sampleDocno = signatures.get(i).getLeftElement();
Signature fromSample = signatures.get(i).getRightElement();
int dist = signature.hammingDistance(fromSample, maxDist);
if (dist <= maxDist) {
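// Negate the distance so that, as with cosine scores, larger values mean more similar;
// this lets the reducer rank both kinds of output the same way.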
output.collect(new IntWritable(sampleDocno.get()), new PairOfFloatInt(-dist, docno.get()));
reporter.incrCounter(Pairs.Emitted, 1);
}
}
}
}
/**
 * This reducer keeps, for each sample document, only the top-ranked pairs, up to a given number
 * (Ivory.NumResults).
 *
 * @author ferhanture
 */
public static class MyReducer extends MapReduceBase implements
Reducer<IntWritable, PairOfFloatInt, PairOfInts, Text> {
int numResults;
TreeSet<PairOfFloatInt> list = new TreeSet<PairOfFloatInt>();
PairOfInts keyOut = new PairOfInts();
Text valOut = new Text();
NumberFormat nf;
public void configure(JobConf conf) {
LOG.setLevel(Level.INFO);
numResults = conf.getInt("Ivory.NumResults", Integer.MAX_VALUE);
nf = NumberFormat.getInstance();
nf.setMaximumFractionDigits(3);
nf.setMinimumFractionDigits(3);
}
public void reduce(IntWritable key, Iterator<PairOfFloatInt> values,
OutputCollector<PairOfInts, Text> output, Reporter reporter) throws IOException {
list.clear();
while (values.hasNext()) {
PairOfFloatInt p = values.next();
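// Copy the pair before adding it: Hadoop reuses the object returned by values.next(), and
// the TreeSet keeps the pairs sorted (and deduplicated) by similarity score.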
if (!list.add(new PairOfFloatInt(p.getLeftElement(), p.getRightElement()))) {
LOG.debug("Not added: " + p);
} else {
LOG.debug("Added: " + p);
}
reporter.incrCounter(Pairs.Total, 1);
}
LOG.debug(list.size());
int cntr = 0;
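// Emit up to numResults pairs in decreasing order of similarity; pollLast() removes and
// returns the largest remaining element.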
while (!list.isEmpty() && cntr < numResults) {
PairOfFloatInt pair = list.pollLast();
LOG.debug("output " + cntr + "=" + pair);
keyOut.set(pair.getRightElement(), key.get()); // first the English docno, then the foreign-language docno
valOut.set(nf.format(pair.getLeftElement()));
output.collect(keyOut, valOut);
cntr++;
}
}
}
public int run(String[] args) throws Exception {
if ( parseArgs(args) < 0 ) {
return printUsage();
}
JobConf job = new JobConf(getConf(), BruteForcePwsim.class);
FileSystem fs = FileSystem.get(job);
fs.delete(new Path(outputPath), true);
int numMappers = 100;
int numReducers = 1;
FileInputFormat.setInputPaths(job, new Path(inputPath));
FileOutputFormat.setOutputPath(job, new Path(outputPath));
FileOutputFormat.setCompressOutput(job, false);
job.set("mapred.child.java.opts", "-Xmx2048m");
job.setInt("mapred.map.max.attempts", 10);
job.setInt("mapred.reduce.max.attempts", 10);
job.setInt("mapred.task.timeout", 6000000);
job.setNumMapTasks(numMappers);
job.setNumReduceTasks(numReducers);
job.setInputFormat(SequenceFileInputFormat.class);
job.setMapOutputKeyClass(IntWritable.class);
job.setMapOutputValueClass(PairOfFloatInt.class);
job.setOutputKeyClass(PairOfInts.class);
job.setOutputValueClass(Text.class); // the reducer emits formatted scores as Text
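// Pass only the file name of the sample file; mappers match it against the paths of the
// local distributed-cache copies.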
job.set("Ivory.SampleFile", sampleFile.substring(sampleFile.lastIndexOf("/") + 1));
DistributedCache.addCacheFile(new URI(sampleFile), job);
if (inputType.contains("signature")) {
job.setMapperClass(MyMapperSignature.class);
job.setFloat("Ivory.MaxHammingDistance", threshold);
} else if (inputType.contains("vector")) {
if (inputType.contains("term")) {
job.setMapperClass(MyMapperTermDocVectors.class);
} else {
job.setMapperClass(MyMapperWeightedIntDocVectors.class);
}
job.setFloat("Ivory.CosineThreshold", threshold);
}
job.setJobName("BruteForcePwsim_type=" + inputType + "_cosine=" + threshold + "_top=" + (numResults > 0 ? numResults : "all"));
if (numResults > 0) {
job.setInt("Ivory.NumResults", numResults);
}
job.setReducerClass(MyReducer.class);
LOG.info("Running job " + job.getJobName());
JobClient.runJob(job);
return 0;
}
private static final String INPUT_PATH_OPTION = "input";
private static final String OUTPUT_PATH_OPTION = "output";
private static final String INPTYPE_OPTION = "type";
private static final String THRESHOLD_OPTION = "cosineT";
private static final String SAMPLE_OPTION = "sample";
private static final String TOPN_OPTION = "topN";
private static final String LIBJARS_OPTION = "libjars";
private Options options;
private float threshold;
private int numResults;
private String sampleFile, inputPath, outputPath, inputType;
private int printUsage() {
HelpFormatter formatter = new HelpFormatter();
formatter.printHelp(this.getClass().getCanonicalName(), options);
return -1;
}
@SuppressWarnings("static-access")
private int parseArgs(String[] args) {
options = new Options();
options.addOption(OptionBuilder.withDescription("path to input doc vectors or signatures").withArgName("path").hasArg().isRequired().create(INPUT_PATH_OPTION));
options.addOption(OptionBuilder.withDescription("path to output directory").withArgName("path").hasArg().isRequired().create(OUTPUT_PATH_OPTION));
options.addOption(OptionBuilder.withDescription("cosine similarity threshold when type=*docvector, hamming distance threshold when type=signature").withArgName("threshold").hasArg().isRequired().create(THRESHOLD_OPTION));
options.addOption(OptionBuilder.withDescription("path to file with sample doc vectors or signatures").withArgName("path").hasArg().isRequired().create(SAMPLE_OPTION));
options.addOption(OptionBuilder.withDescription("type of input").withArgName("signature|intdocvector|termdocvector").hasArg().isRequired().create(INPTYPE_OPTION));
options.addOption(OptionBuilder.withDescription("keep only N results for each source document").withArgName("N").hasArg().create(TOPN_OPTION));
options.addOption(OptionBuilder.withDescription("Hadoop option to load external jars").withArgName("jar packages").hasArg().create(LIBJARS_OPTION));
CommandLine cmdline;
CommandLineParser parser = new GnuParser();
try {
cmdline = parser.parse(options, args);
} catch (ParseException exp) {
System.err.println("Error parsing command line: " + exp.getMessage());
return -1;
}
inputPath = cmdline.getOptionValue(INPUT_PATH_OPTION);
outputPath = cmdline.getOptionValue(OUTPUT_PATH_OPTION);
threshold = Float.parseFloat(cmdline.getOptionValue(THRESHOLD_OPTION));
sampleFile = cmdline.getOptionValue(SAMPLE_OPTION);
inputType = cmdline.getOptionValue(INPTYPE_OPTION);
numResults = cmdline.hasOption(TOPN_OPTION) ? Integer.parseInt(cmdline.getOptionValue(TOPN_OPTION)) : -1;
return 0;
}
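/*
 * Sample invocation (a sketch; the jar name and paths are placeholders, and the flag names come
 * from the options defined in parseArgs above):
 *
 *   hadoop jar ivory.jar ivory.lsh.eval.BruteForcePwsim \
 *     -input /pwsim/docvectors -output /pwsim/output \
 *     -type intdocvector -cosineT 0.9 \
 *     -sample /pwsim/sample-docvectors/part-00000 -topN 100
 */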
public static void main(String[] args) throws Exception {
ToolRunner.run(new BruteForcePwsim(), args);
}
}