Package ca.queensu.cs.sail.lucenelda

Source Code of ca.queensu.cs.sail.lucenelda.LDAQueryAllInDirectory

/*
####################################################################################
Stephen W. Thomas
sthomas@cs.queensu.ca
Queen's University

LDAQueryAllInDirectory.java

(Invoked from command line, or via main() method.)

This command-line class reads all queries in the given directory, and throws them against a
specified (prebuilt) index using LDA. The results are output in a given output directory.
There are two options: K, and scoringCode.
See below for the specification.

####################################################################################
*/


package ca.queensu.cs.sail.lucenelda;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.ObjectInputStream;
import java.io.PrintWriter;

import java.util.HashMap;

import org.apache.log4j.BasicConfigurator;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.payloads.AveragePayloadFunction;
import org.apache.lucene.search.payloads.PayloadTermQuery;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.NIOFSDirectory;
import org.apache.lucene.search.BooleanQuery;

import org.apache.commons.io.FileUtils;

import com.martiansoftware.jsap.FlaggedOption;
import com.martiansoftware.jsap.JSAP;
import com.martiansoftware.jsap.JSAPResult;
import com.martiansoftware.jsap.Switch;
import com.martiansoftware.jsap.UnflaggedOption;

public class LDAQueryAllInDirectory {
 
  private static int maxHits = 500;
  private static IndexReader reader = null;
  private static IndexSearcher searcher = null;
  static LDAHelper lda = null;
 
  private static final Logger logger = Logger.getRootLogger();
 
  public static void main(String[] args) throws Exception {
   
    // Set up the Apache log4j logger, only if we need to (another class or test case or ant
    // may have already set up the logger.)
    if (!logger.getAllAppenders().hasMoreElements()) {
      BasicConfigurator.configure();
      logger.setLevel(Level.INFO);
    }

   
    // Use the JSAP library to intelligently set up and parse our command
    // line options
    JSAP jsap = new JSAP();
   
    UnflaggedOption opt0 = new UnflaggedOption("indexDir").setStringParser(
        JSAP.STRING_PARSER).setRequired(true);
    opt0.setHelp("The directory containing the pre-build Lucene index.");
   
    UnflaggedOption opt0a = new UnflaggedOption("LDAIndexDir").setStringParser(
        JSAP.STRING_PARSER).setRequired(true);
    opt0a.setHelp("The directory containing the pre-build LDA index.");

    UnflaggedOption opt1 = new UnflaggedOption("queryDir").setStringParser(
        JSAP.STRING_PARSER).setRequired(true);
    opt1.setHelp("The input directory containing queries to run against the specified index.");

    UnflaggedOption opt2 = new UnflaggedOption("resultsDir").setStringParser(
        JSAP.STRING_PARSER).setRequired(true);
    opt2.setHelp("The output directory for the results of each query: one file per original query in queryDirName.");

    FlaggedOption opt3 = new FlaggedOption("K")
        .setStringParser(JSAP.INTEGER_PARSER).setRequired(false)
        .setLongFlag("K").setDefault("0");
    opt3.setHelp("If multiple LDA configuration were run (i.e., multiple Ks), then specify which one to use."
        + "Default: the lowest K.");
   
    FlaggedOption opt4 = new FlaggedOption("scoringCode")
    .setStringParser(JSAP.INTEGER_PARSER).setRequired(false)
    .setLongFlag("scoringCode").setDefault("1");
    opt4.setHelp("An integer code that specifies the scoring metric that should be used. "
    + "1=conditional probability.");

    Switch sw0 = new Switch("help").setDefault("false").setLongFlag("help");
    sw0.setHelp("Prints this message.");

    jsap.registerParameter(sw0);
    jsap.registerParameter(opt0);
    jsap.registerParameter(opt0a);
    jsap.registerParameter(opt1);
    jsap.registerParameter(opt2);
    jsap.registerParameter(opt3);
    jsap.registerParameter(opt4);

    // check whether the command line was valid, and if it wasn't,
    // display usage information and exit.
    JSAPResult config = jsap.parse(args);
    if (!config.success()) {
      for (java.util.Iterator errs = config.getErrorMessageIterator(); errs
          .hasNext();) {
        logger.error("Error: " + errs.next());
      }
      displayHelp(config, jsap);
      return;
    }

    if (config.getBoolean("help")) {
      displayHelp(config, jsap);
      return;
    }
   
    String indexDirName    = config.getString("indexDir");
    String LDAIndexName    = config.getString("LDAIndexDir");
    String queryDirName    = config.getString("queryDir");
    String resultsDirName   = config.getString("resultsDir");
    int K                 = config.getInt("K");
    int scoringCode       = config.getInt("scoringCode");

        // Make sure output directory exists
        File outDirF = new File(resultsDirName);
        if (!outDirF.exists()){
            outDirF.mkdirs();
        }
   
    // Open the serialized LDAHelper object
    FileInputStream fis  = null;
    ObjectInputStream in = null;
    try{
      fis = new FileInputStream(LDAIndexName);
      in = new ObjectInputStream(fis);
      lda = (LDAHelper)in.readObject();
      in.close();
    }
    catch(Exception ex){
      ex.printStackTrace();
    }
   
    int idx = lda.which(K);
   
    // Open the index
    File indexDir = new File(indexDirName);
    Directory dir = NIOFSDirectory.open(indexDir);
   
    reader   = IndexReader.open(dir, true);
    searcher = new IndexSearcher(reader);

        BooleanQuery.setMaxClauseCount(8092);
   
    // Run every query in the directory!
    File queryDir = new File(queryDirName);
    File[] files  = queryDir.listFiles();
    for (int i = 0; i < files.length; ++i){
      File f = files[i];
      if (f.isDirectory() || f.isHidden() || !f.exists() || !f.canRead()){
        continue;
      }
      String query = FileUtils.readFileToString(f);

            // tmp hack: make sure query doesn't have numbers or punctuation
            query = query.replaceAll("\\^.", " ");
            query = query.replaceAll("[1234567890\\p{Punct}\\n]", " ");
     
      // Make sure the query isn't blank
      if (!query.matches("^\\s*$")){
       
        // Open the output file for results
        String outFile     = resultsDirName + "/" + f.getName();
        FileWriter fwriter   = new FileWriter(outFile);
        PrintWriter out   = new PrintWriter(fwriter);
       
        logger.info("Executing query for " + f.toString() + "; results will be placed in " + outFile.toString());
       
       
        // First, we need to find all the topics in the query: for each term, find out all topics that contain this term;
        // Then, take the union of all in the topics of all the terms
        // Then, build the Boolean query below with multiple PayloadTermQuery()s
        float queryOpt[]    = new float[lda.scens.get(idx).K];
        BooleanQuery bquery = new BooleanQuery();
             
        // For each topic, sum up the scores of the words in the query
        String[] querySplitParts = query.split("\\s+");
        for (int k = 0; k < lda.scens.get(idx).K; ++k){
          float score = 0.0f;
          for (int j = 0; j<querySplitParts.length; ++j){
            if (lda.scens.get(idx).termMap.containsKey(querySplitParts[j])){ // Make sure the key exists in the map (it might not, due to vocab mismatch)
              int idx2 = lda.scens.get(idx).termMap.get(querySplitParts[j]);
              score = score + lda.scens.get(idx).phi[k][idx2];
            }
          }
         
          // Now, add this topic to the list, if the score is nonzero
          if (score > 0.01f){
            //System.out.printf("Adding topic %d with score %f\n", k, score);
            PayloadTermQuery fsq1 = new PayloadTermQuery(new Term("topicspayload"+lda.scens.get(idx).K, "p"+k), new AveragePayloadFunction(), false);
            bquery.add(fsq1, Occur.SHOULD);           
            queryOpt[k] = score;
          }
        }
       
        //logger.info("LDA query: " + bquery);

        // Actually execute the query
        TopDocs hits = searcher.search(bquery, maxHits);
        ScoreDoc[] scoreDocs = hits.scoreDocs;
       
        logger.info("Found " +hits.totalHits + " hits\n");
       
        // Save a filename -> docID mapping
        HashMap<String, Integer> hm = new HashMap<String, Integer>();
        for (int n = 0; n < scoreDocs.length; ++n) {
          ScoreDoc sd = scoreDocs[n];
          int docId = sd.doc;
          Document d = searcher.doc(docId);
          String fileName = d.get("file");
          hm.put(fileName,  docId);
        }
       
        // Rerank the results, based on the custom LDA scoring scheme
        HashMap<String, Float> sorted = lda.reRank(searcher, hits, queryOpt, lda.scens.get(idx).K);
     
       
        // Print our results
        int counter = 0;
                int numToOutput = Math.min(maxHits, sorted.size());
        for (String fileName : sorted.keySet()){
                    if (counter >= numToOutput){
                        break;
                    }
          float score      = sorted.get(fileName);
          int docId      = hm.get(fileName);
          Document d      = searcher.doc(docId);

          out.printf("%s,%4.3f\n", fileName, score);
          ++counter;
        }
        out.close();
      }
    }
   
    // Close the index to save memory
    reader.close();
  }
 
  /* Use JSAP to display command-line usage information */
  private static void displayHelp(JSAPResult config, JSAP jsap) {
    System.err.println();
    System.err.println("Usage: java " + LDAQueryAllInDirectory.class.getName());
    System.err.println("                " + jsap.getUsage());
    System.err.println();
    System.err.println(jsap.getHelp());
    System.err.println();
  }
}
TOP

Related Classes of ca.queensu.cs.sail.lucenelda.LDAQueryAllInDirectory

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.