Package srmdata

Source Code of srmdata.MultiLabelClassificationNSDL

package srmdata;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.Fieldable;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermDocs;
import org.apache.lucene.index.TermEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.NumericRangeQuery;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.RAMDirectory;

import srmdata.StructuredRelevanceModel.DescendingScoreComp;
import srmdata.StructuredRelevanceModel.Score;

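/**
 * Multi-label classification over the NSDL corpus using structured
 * relevance models. For each test document, per-field language-model
 * priors (title, desc, content) computed by {@link PriorCalculator} are
 * combined into a single relevance score against every training document;
 * the field values of the top-ranked training documents are then
 * aggregated, score-weighted, to predict the subject, audience, and
 * educationLevel labels. Results are reported as precision@K and a
 * MAP-style score.
 */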
public class MultiLabelClassificationNSDL {

  public static final String TEST_INDEX_NAME = "../../multi_label_test_index";
  public static final String TRAIN_INDEX_NAME = "../../multi_label_train_index";
  public static final String PREDICTION_OUTFILE_NAME = "../../outputs/prediction_output";
  public static final String OUTPUT_FILE_NAME = "../../outputs/output";
  public static final String SCORE_FILE_NAME = "../../outputs/log_scores";
  private static int numTesting = 200;
  private static int numTraining = 80000;

  private static Set<Integer> testDocIDs;
  private static Set<Integer> trainDocIDs;

  private static Map<String,Double> time_taken;
  private static boolean generateDataAgain = false;
  private static final int MAX_RESULTS = 10;

  private static boolean shouldTokenize(String field) {
    return field.equals("subject");
  }
 
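  /**
   * Holds the actual and predicted label sets for one test document.
   * For the "subject" field, labels are additionally broken into
   * alphabetic tokens (length > 2) so that partially matching
   * predictions can receive fractional credit during evaluation.
   */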
  static class PredictionResult {
    int docID;
    String URI;
    Set<String> actualValues;
    LinkedHashSet<String> predictedValues;

    Set<String> actualTokens;
    List<List<String>> predictedTokens;
   
    PredictionResult(String URI) {
      actualValues = new HashSet<String>();
      predictedValues = new LinkedHashSet<String>();
      actualTokens = new LinkedHashSet<String>();
      predictedTokens = new ArrayList<List<String>>();
      this.URI = URI;
    }

    List<String> tokenize(String strValue, boolean shouldTokenize) {
      if (!shouldTokenize)
        return Arrays.asList(strValue);
      List<String> tokens = new ArrayList<String>();
      String[] strs = strValue.split("[^a-zA-Z]");
      for (int i = 0; i < strs.length; ++i) {
        if (strs[i].length() > 2)
          tokens.add(strs[i]);
      }
      return tokens;
    }
   
    void addActualValue(String actualValue, boolean shouldTokenize) {
      actualValue = actualValue.toLowerCase();
      actualValues.add(actualValue);
      actualTokens.addAll(tokenize(actualValue, shouldTokenize));
    }
   
    void addPredictedValue(String predictedValue, boolean shouldTokenize) {
      predictedValue = predictedValue.toLowerCase();
      predictedValues.add(predictedValue);
      predictedTokens.add(tokenize(predictedValue, shouldTokenize));
    }

    public LinkedHashSet<String> getPredictedValues() {
      return predictedValues;
    }

    public Set<String> getActualValues() {
      return actualValues;
    }

    public List<List<String>> getPredictedTokenList() {
      return predictedTokens;
    }
   
    public boolean isRelevant(String token) {
      return actualTokens.contains(token);
    }
  }
 
  static {
    testDocIDs = new HashSet<Integer>();
    trainDocIDs = new HashSet<Integer>();
    time_taken = new LinkedHashMap<String,Double>();
  }
 
  public static void main(String[] args) throws Exception {

    long t1, t2;

    if (generateDataAgain) {
      t1 = System.nanoTime();
      generateMultiLabelDataset();
      t2 = System.nanoTime();
      System.out.println("Time Taken for Generation: " + (t2-t1)/1E9);
    }

    List<PredictionResult> results_subject = new ArrayList<PredictionResult>();
    List<PredictionResult> results_audience = new ArrayList<PredictionResult>();
    List<PredictionResult> results_educationLevel = new ArrayList<PredictionResult>();

    Map<String, List<PredictionResult>> predMap = new HashMap<String, List<PredictionResult>>();
    predMap.put("subject", results_subject);
    predMap.put("audience", results_audience);
    predMap.put("educationLevel", results_educationLevel);
   
    t1 = System.nanoTime();
    predictFields(predMap);
    t2 = System.nanoTime();
    System.out.println("Time Taken for Field Prediction: " + (t2-t1)/1E9);

    BufferedWriter writer;
    writer = new BufferedWriter(new FileWriter(OUTPUT_FILE_NAME + "_" + numTesting + "_" + numTraining + "_subject"));
    dumpOutput(writer, results_subject);
    writer.flush();
    writer.close();

    writer = new BufferedWriter(new FileWriter(OUTPUT_FILE_NAME + "_" + numTesting + "_" + numTraining + "_audience"));
    dumpOutput(writer, results_audience);
    writer.flush();
    writer.close();

    writer = new BufferedWriter(new FileWriter(OUTPUT_FILE_NAME + "_" + numTesting + "_" + numTraining + "_educationLevel"));
    dumpOutput(writer, results_educationLevel);
    writer.flush();
    writer.close();

  }

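  /**
   * Writes per-field coverage statistics, the train/test split, timing
   * information, and per-document prediction results (with precision@K
   * and a MAP-style score) to the given writer.
   */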
  private static void dumpOutput(BufferedWriter writer, List<PredictionResult> results) throws IOException {

    File trainFile = new File(TRAIN_INDEX_NAME);
    File testFile = new File(TEST_INDEX_NAME);

    IndexReader trainIR = IndexReader.open(FSDirectory.open(trainFile), true);
    IndexReader testIR = IndexReader.open(FSDirectory.open(testFile), true);

    numTesting = testIR.numDocs();
    // report, for each field of interest, how many documents across the
    // train and test indexes have at least one value
    IndexSearcher searcher1 = new IndexSearcher(trainIR);
    IndexSearcher searcher2 = new IndexSearcher(testIR);

    String[] countFields = { "num_subject", "title_len", "desc_len", "content_len",
                             "num_audience", "num_educationLevel" };
    String[] countLabels = { "subject", "title", "description", "content",
                             "audience", "educationLevel" };
    for (int f = 0; f < countFields.length; ++f) {
      NumericRangeQuery<Integer> nq = NumericRangeQuery.newIntRange(countFields[f], 1, 100000, true, true);
      TopDocs t1 = searcher1.search(nq, trainIR.numDocs());
      TopDocs t2 = searcher2.search(nq, testIR.numDocs());
      double docCount = t1.scoreDocs.length + t2.scoreDocs.length;
      double percentage = docCount / (numTesting + numTraining);
      writer.append("Number of documents having " + countLabels[f] + ": " + docCount
          + "   percentage: " + percentage + "\n");
    }

    trainIR.close();
    testIR.close();
   
    writer.append("num_testing:" + numTesting + "\n");
    for (int docID : testDocIDs)
      writer.append(docID + ";");
    writer.append("\n");

    writer.append("num_training:" + numTraining + "\n");
    for (int docID : trainDocIDs)
      writer.append(docID + ";");
    writer.append("\n");

    writer.append("num_threads:" + PriorCalculator.numThreads + "\n");
   
    for (Map.Entry<String, Double> ttaken : time_taken.entrySet())
      writer.append("time_taken: " + ttaken.getKey() + " : " + ttaken.getValue() + "\n");

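    // Evaluation: each predicted value is scored by the fraction of its
    // tokens that appear among the actual tokens. precisionAtK accumulates
    // a running precision@K over all documents (the intermediate array is
    // dumped after every rank, which makes the output verbose but
    // traceable). map_score averages the fractional precision over relevant
    // predictions; note this is a fractional-match variant rather than
    // classical mean average precision, and it is divided by numTesting
    // while precisionAtK is divided by results.size().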
    double[] precisionAtK = new double[MAX_RESULTS];
    double map_score = 0;

    writer.append("prediction_results: " + numTesting + "\n");
    for (PredictionResult result : results) {

      writer.append("\n");
      writer.append("docID: " + result.docID + "    URI: " + result.URI + "\n");

      writer.append("actualValues: " + result.getActualValues().size() + "\n");
      for (String actualValue : result.getActualValues())
        writer.append(actualValue + "\n");

      writer.append("predictedValues: " + result.getPredictedValues().size() + "\n");
      for (String predictedValue : result.getPredictedValues())
        writer.append("\t\t" + predictedValue + "\n");


      List<List<String>> predictedTokenList = result.getPredictedTokenList();

      double relevance = 0;
      double fractional_correct = 0;

      int nRelevant = 0;
      // each result in the predicted field values
      for (int rank = 1; rank <= predictedTokenList.size(); ++rank) {

        List<String> tokens = predictedTokenList.get(rank-1);
        // fraction of tokens matched
        double fracToken = 0;
        for (String token : tokens) {
          if (result.isRelevant(token)) {
            fracToken++;
          }
        }
        if (tokens.size() == 0)
          continue;
        fracToken = fracToken / tokens.size();

        if (fracToken > 0.0) {
          fractional_correct += fracToken;
          nRelevant++;
        }

        relevance += fracToken;
        precisionAtK[rank-1] += ((double)relevance) / rank;
        for (int i = 0; i < precisionAtK.length; ++i) {
          writer.append(precisionAtK[i] + " ");
        }
        writer.append("\n");
      }

      if (relevance != 0)
        map_score = map_score + fractional_correct / nRelevant;
    }

    map_score = map_score / numTesting;
   
    System.out.println("MAP Score: " + map_score);
    System.out.println("Precison@K: ");
    writer.append("Precison@K: ");
    for (int i = 0; i < precisionAtK.length; ++i) {
      precisionAtK[i] = precisionAtK[i] / results.size();
      System.out.print(precisionAtK[i] + " ");
      writer.append(precisionAtK[i] + " ");
    }
    System.out.println();
    writer.append("\n");
  }

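  /**
   * Core prediction routine. Loads both indexes into RAM, obtains (or
   * reloads from disk) the per-field prior score matrices, combines them
   * into one score per (test, train) document pair, and for each test
   * document aggregates the target-field values of its top-ranked
   * training documents into a ranked list of predicted labels.
   */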
  private static void predictFields(Map<String, List<PredictionResult>> fieldsToPredict) throws Exception {

    RAMDirectory trainRAMDirectory = new RAMDirectory(FSDirectory.open(new File(TRAIN_INDEX_NAME)));
    RAMDirectory testRAMDirectory = new RAMDirectory(FSDirectory.open(new File(TEST_INDEX_NAME)));
    IndexReader trainIR = IndexReader.open(trainRAMDirectory, true);
    IndexReader testIR  = IndexReader.open(testRAMDirectory, true);

    int nTrainDocs = trainIR.numDocs();
    int nTestDocs = testIR.numDocs();

    long t1, t2;
    double ttaken;

    double[][][] scores;
    int num_fields = 3;
    if (generateDataAgain) {
      scores = new double[num_fields][][];
  //    StructuredRelevanceModel srm = new StructuredRelevanceModel();
      t1 = System.nanoTime();
  //    scores[0] = srm.computePriors(testIR, trainIR, "title");
      PriorCalculator priorCalcTitle = new PriorCalculator(testIR, trainIR, "title");
      scores[0] = priorCalcTitle.computePriors();
      priorCalcTitle = null;
      t2 = System.nanoTime();
      ttaken = ((double)(t2-t1)) / 1E9;
      System.out.println("Time Taken Priors (title): " + ttaken);
      time_taken.put("title_model", ttaken);
 
      t1 = System.nanoTime();
  //      scores[1] = srm.computePriors(testIR, trainIR, "desc");
      PriorCalculator priorCalcDesc = new PriorCalculator(testIR, trainIR, "desc");
      scores[1] = priorCalcDesc.computePriors();
      priorCalcDesc = null;
      t2 = System.nanoTime();
      ttaken = ((double)(t2-t1)) / 1E9;
      System.out.println("Time Taken Priors (desc): " + ttaken);
      time_taken.put("desc_model", ttaken);
 
      t1 = System.nanoTime();
  //      scores[2] = srm.computePriors(testIR, trainIR, "content");
      PriorCalculator priorCalcContent = new PriorCalculator(testIR, trainIR, "content");
      scores[2] = priorCalcContent.computePriors();
      priorCalcContent = null;
      t2 = System.nanoTime();
      ttaken = ((double)(t2-t1)) / 1E9;
      System.out.println("Time Taken Priors (content): " + ttaken);
      time_taken.put("content_model", ttaken);
 
      dumpScores(scores);
    }
    else {
      scores = readScores();
    }

    double avg_title_len = findAverageFieldLength(trainIR, "title");
    double avg_desc_len = findAverageFieldLength(trainIR, "desc");
    double avg_content_len = findAverageFieldLength(trainIR, "content");

    System.out.println("avg_title_len: " + avg_title_len);
    System.out.println("avg_desc_len: " + avg_desc_len);
    System.out.println("avg_content_len: " + avg_content_len);

    for (String fieldToPredict : fieldsToPredict.keySet()) {
      List<PredictionResult> results = fieldsToPredict.get(fieldToPredict);
      Score[][] combined_score = new Score[nTestDocs][nTrainDocs];
      for (int i = 0; i < nTestDocs; ++i) {
        for (int j = 0; j < nTrainDocs; ++j) {
          combined_score[i][j] = new Score();
        }
      }
 
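      // Combine the three per-field priors with fixed weights
      // (title 0.45, desc 0.45, content 0.10). The Math.pow(10, ...)
      // suggests the matrices hold log10 probabilities; the commented-out
      // alternative instead weights natural-log scores by inverse average
      // field length.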
      for (int i = 0; i < nTrainDocs; ++i) {
        for (int j = 0; j < nTestDocs; ++j) {
          combined_score[j][i].docID = i;
//          combined_score[j][i].score = (1.0/avg_title_len)*Math.log(scores[0][i][j]) +
//              (1.0/avg_desc_len)*Math.log(scores[1][i][j]) + (1.0/avg_content_len)*Math.log(scores[2][i][j]);
//          combined_score[j][i].score = Math.exp(combined_score[j][i].score);
          combined_score[j][i].score =
              scores[0][i][j] * 0.45 +
              scores[1][i][j] * 0.45 +
              scores[2][i][j] * 0.10;
          combined_score[j][i].score = Math.pow(10.0, combined_score[j][i].score);
  //        if (scores[1][i][j] == 0.0) {
  //          System.out.print("Score is Zero: " + i + " " + j + " ");
  //        }
        }
      }
 
      t1 = System.nanoTime();
      DescendingScoreComp comp = new DescendingScoreComp();
      for (int i = 0; i < nTestDocs; ++i) {
        Arrays.sort(combined_score[i], comp);
      }
 
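      // Sort each test document's scores in descending order and normalize
      // the top-N scores so they sum to 1, turning them into a weight
      // distribution over the nearest training documents.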
      int topN = Math.min(500, nTrainDocs);
      for (int i = 0; i < nTestDocs; ++i) {
        double total_score = 0.0;
        for (int j = 0; j < topN; ++j) {
          total_score += combined_score[i][j].score;
        }
 
        if (total_score == 0.0) {
          for (int j = 0; j < topN; ++j)
            combined_score[i][j].score = 0.0;
        }
        else {
          for (int j = 0; j < topN; ++j)
            combined_score[i][j].score /= total_score;
        }
      }
      t2 = System.nanoTime();
      System.out.println("Time Taken Normalization and Sorting: " + ((double)(t2-t1)) / 1E9);
 
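      // Score-weighted voting: each of the top-N training documents
      // contributes its normalized score to every value it carries in the
      // field being predicted; the highest-voted values become the
      // predictions.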
      Map<String,Double> fieldCounts = new HashMap<String,Double>();
      for (int i = 0; i < nTestDocs; ++i) {
        fieldCounts.clear();
        for (int j = 0; j < topN; ++j) {
          int docID = combined_score[i][j].docID;
          combined_score[i][j].relevance = 0.0;
          Document doc = trainIR.document(docID);
          Fieldable[] fields = doc.getFieldables(fieldToPredict);
          for (Fieldable field : fields) {
            String fieldValue = field.stringValue();
            Double count = fieldCounts.get(fieldValue);
            if (count == null)
              count = 1.0 * combined_score[i][j].score;
            else
              count = count + combined_score[i][j].score;
            fieldCounts.put(fieldValue, count);
          }
        }
 
        List<Relevance> relevanceModel = new ArrayList<Relevance>();
        for (String fieldValue : fieldCounts.keySet()) {
          Relevance relevance = new Relevance();
          relevance.fieldValue = fieldValue;
          relevance.score = fieldCounts.get(fieldValue);
          relevanceModel.add(relevance);
        }
        Collections.sort(relevanceModel, new DescendingRelevanceComp());
 
        Document testDoc = testIR.document(i);
 
        PredictionResult result = new PredictionResult(testDoc.get("URI"));
//        System.out.println("docID:" + i + "      URI:" + testDoc.get("URI") + "       ID:" + testDoc.get("id"));
        for (Fieldable fieldable : testDoc.getFieldables(fieldToPredict)) {
          String actualValue = fieldable.stringValue();
//          System.out.print(actualValue + ":");
          result.addActualValue(actualValue, shouldTokenize(fieldToPredict));
        }
//        System.out.println(":");
//        System.out.print("*");
        result.docID = i;
        int max = Math.min(relevanceModel.size(), MAX_RESULTS);
        for (int j = 0; j < max; ++j) {
          Relevance relevance = relevanceModel.get(j);
          String predictedValue = relevance.fieldValue;
//          System.out.println("    " + predictedValue + "(" + relevance.score + ");");
          result.addPredictedValue(predictedValue, shouldTokenize(fieldToPredict));
        }
 
        results.add(result);
//        System.out.println();
      }
    }

    trainIR.close();
    testIR.close();
  }

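  /**
   * Persists the prior matrices so later runs can skip recomputation.
   * File layout: a header line "numTraining numTesting", followed by one
   * whitespace-separated line of numTesting scores per training document,
   * repeated for each of the three fields (title, desc, content);
   * readScores() parses the same layout back.
   */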
  private static void dumpScores(double[][][] scores) throws Exception {
    BufferedWriter writer = new BufferedWriter(new FileWriter(new File(SCORE_FILE_NAME + "_" + numTesting + "_" + numTraining)));
    writer.write(numTraining + " " + numTesting + "\n");
    for (int i = 0; i < scores.length; ++i) {
      for (int j = 0; j < scores[i].length; ++j) {
        for (int k = 0; k < scores[i][j].length; ++k) {
          writer.write(" " + scores[i][j][k]);
        }
        writer.write("\n");
      }
    }
    writer.flush();
    writer.close();
  }

  private static double[][][] readScores() throws Exception {

    double[][][] scores;
    BufferedReader reader = new BufferedReader(new FileReader(new File(SCORE_FILE_NAME + "_" + numTesting + "_" + numTraining)));
    String line = reader.readLine();

    String[] sizes = line.split(" ");
    numTraining = Integer.parseInt(sizes[0]);
    numTesting = Integer.parseInt(sizes[1]);
    scores = new double[3][numTraining][numTesting];
   
    for (int i = 0; i < scores.length; ++i) {
      for (int j = 0; j < numTraining; ++j) {
        line = reader.readLine();
        String[] values = line.split(" ");
        int index = 0;
        for (int k = 0; k < values.length; ++k) {
          if (values[k].trim().length() == 0)
            continue;
          scores[i][j][index] = Double.parseDouble(values[k].trim());
//          System.out.print(" " + scores[i][j][index]);
          ++index;
        }
//        System.out.println();
//        System.exit(1);
      }
    }
    reader.close();

    return scores;
  }

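  /**
   * Estimates the average field length over the training index by
   * enumerating all terms of the field and incrementing a counter for
   * every document each term occurs in; this counts distinct terms per
   * document rather than total term frequency. Note docLengths is sized
   * by numTraining, so this assumes that value matches the training
   * index's document count.
   */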
  private static double findAverageFieldLength(IndexReader trainIR, String field) throws Exception {

    // find average lengths
    int[] docLengths = new int[numTraining];
   
    TermEnum terms = trainIR.terms();
    while (terms.next()) {
      Term term = terms.term();
      if (!term.field().equals(field))
        continue;
      TermDocs docs = trainIR.termDocs(term);
      while (docs.next()) {
        docLengths[docs.doc()]++;
      }
    }
   
    double avg = 0.0;
    for (int i = 0; i < docLengths.length; ++i) {
      avg = avg + docLengths[i];
    }
    avg = avg / docLengths.length;

    return avg;
  }

  static class Relevance {
    public String fieldValue;
    public double score = 0.0;
  }

  static class DescendingRelevanceComp implements Comparator<Relevance> {
    @Override
    public int compare(Relevance o1, Relevance o2) {
      return Double.compare(o2.score, o1.score);
    }
  }
 

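  /**
   * Builds the test and train indexes from the global NSDL index.
   * Documents qualifying for the test set must have at least one value
   * in every field of interest; numTesting of them are sampled at
   * random, and up to numTraining of the remaining documents form the
   * training index.
   */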
  private static void generateMultiLabelDataset() throws Exception {

    File nsdl_global_file = new File(NSDLIndex.NSDL_GLOBAL_INDEX_DIR_NAME);
    IndexReader ir = IndexReader.open(FSDirectory.open(nsdl_global_file), true);

    IndexWriter testIW = NSDLIndex.createIndexWriter(TEST_INDEX_NAME);
    IndexWriter trainIW = NSDLIndex.createIndexWriter(TRAIN_INDEX_NAME);
    testIW.deleteAll();
    trainIW.deleteAll();

    // first search for all documents having at least one value for each of
    // content, description, title, subject, audience fields
    IndexSearcher searcher = new IndexSearcher(ir);

    NumericRangeQuery<Integer> nq1 = NumericRangeQuery.newIntRange("num_subject", 1, 100000, true, true);
    NumericRangeQuery<Integer> nq2 = NumericRangeQuery.newIntRange("num_educationLevel", 1, 10, true, true);
    NumericRangeQuery<Integer> nq3 = NumericRangeQuery.newIntRange("num_audience", 1, 10, true, true);
    NumericRangeQuery<Integer> nq4 = NumericRangeQuery.newIntRange("title_len", 1, 1000000, true, true);
    NumericRangeQuery<Integer> nq5 = NumericRangeQuery.newIntRange("content_len", 1, 100000000, true, true);
    NumericRangeQuery<Integer> nq6 = NumericRangeQuery.newIntRange("desc_len", 1, 100000000, true, true);

    BooleanQuery nq = new BooleanQuery();
    nq.add(nq1, BooleanClause.Occur.MUST);
    nq.add(nq2, BooleanClause.Occur.MUST);
    nq.add(nq3, BooleanClause.Occur.MUST);
    nq.add(nq4, BooleanClause.Occur.MUST);
    nq.add(nq5, BooleanClause.Occur.MUST);
    nq.add(nq6, BooleanClause.Occur.MUST);

    TopDocs t = searcher.search(nq, ir.numDocs());
    ScoreDoc[] hits = t.scoreDocs;
    Collections.shuffle(Arrays.asList(hits));
    System.out.println("Number of documents eligible for testing set: " + hits.length);

    // select the document ids from the global index to go into the testing
    // index; hits are already shuffled, so taking the first numTesting
    // entries avoids the duplicate IDs that repeated random draws produced
    for (int i = 0; i < numTesting && i < hits.length; ++i)
      testDocIDs.add(hits[i].doc);

    int totalDocs = ir.numDocs();
    for (int docID = 0; docID < totalDocs; ++docID) {
      if (trainIW.numDocs() >= numTraining && testIW.numDocs() >= numTesting)
        break;
      Document doc = ir.document(docID);
      if (testDocIDs.contains(docID)) {
        testIW.addDocument(doc);
      }
      else {
        if (trainIW.numDocs() < numTraining) {
          trainDocIDs.add(docID);
          trainIW.addDocument(doc);
        }
      }
    }
    ir.close();

    trainIW.commit();
    trainIW.close();

    testIW.commit();
    testIW.close();
  }

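  /**
   * Alternative dataset generator (apparently unused in this class) that
   * restricts both the test and train indexes to documents matching the
   * field constraints: the first numTesting shuffled hits become the test
   * set and up to numTraining further hits the training set.
   */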
  private static void generateMultiLabelDatasetWithSubject() throws Exception {

    File nsdl_global_file = new File(NSDLIndex.NSDL_GLOBAL_INDEX_DIR_NAME);
    IndexReader ir = IndexReader.open(FSDirectory.open(nsdl_global_file), true);

    IndexWriter testIW = NSDLIndex.createIndexWriter(TEST_INDEX_NAME);
    IndexWriter trainIW = NSDLIndex.createIndexWriter(TRAIN_INDEX_NAME);
    testIW.deleteAll();
    trainIW.deleteAll();

    // first search for all documents having at least one value for each of
    // content, description, title, subject, audience fields
    IndexSearcher searcher = new IndexSearcher(ir);

    NumericRangeQuery<Integer> nq1 = NumericRangeQuery.newIntRange("num_subject", 1, 100000, true, true);
    NumericRangeQuery<Integer> nq3 = NumericRangeQuery.newIntRange("num_audience", 1, 10, true, true);
    NumericRangeQuery<Integer> nq4 = NumericRangeQuery.newIntRange("num_educationLevel", 1, 10, true, true);
    NumericRangeQuery<Integer> nq5 = NumericRangeQuery.newIntRange("title_len", 1, 1000000, true, true);
    NumericRangeQuery<Integer> nq6 = NumericRangeQuery.newIntRange("content_len", 1, 100000000, true, true);
    NumericRangeQuery<Integer> nq7 = NumericRangeQuery.newIntRange("desc_len", 1, 100000000, true, true);

    BooleanQuery nq = new BooleanQuery();
    nq.add(nq1, BooleanClause.Occur.MUST);
    nq.add(nq3, BooleanClause.Occur.MUST);
    nq.add(nq4, BooleanClause.Occur.MUST);
    nq.add(nq5, BooleanClause.Occur.MUST);
    nq.add(nq6, BooleanClause.Occur.MUST);
    nq.add(nq7, BooleanClause.Occur.MUST);

    TopDocs t = searcher.search(nq, ir.numDocs());
    ScoreDoc[] hits = t.scoreDocs;
    Collections.shuffle(Arrays.asList(hits));
    System.out.println("Number of documents eligible for testing set: " + hits.length);

    for (int i = 0; i < hits.length; ++i) {
      // fetch by the hit's document id, not the loop index
      Document doc = ir.document(hits[i].doc);
      if (i < numTesting)
        testIW.addDocument(doc);
      else if (trainIW.numDocs() < numTraining)
        trainIW.addDocument(doc);
    }
    ir.close();

    trainIW.commit();
    trainIW.close();

    testIW.commit();
    testIW.close();
  }

}