/*
* Copyright 2008-2011 Grant Ingersoll, Thomas Morton and Drew Farris
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* -------------------
* To purchase or learn more about Taming Text, by Grant Ingersoll, Thomas Morton and Drew Farris, visit
* http://www.manning.com/ingersoll
*/
package com.tamingtext.classifier.bayes;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.TermFreqVector;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.math.map.OpenObjectIntHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** A utility to extract training data from a Lucene index, using either stored field values or document term
 *  vectors to recreate the list of terms found in each document. Writes output in the Mahout Bayes classifier
 *  input format. */
public class ExtractTrainingData {
private static final Logger log = LoggerFactory.getLogger(ExtractTrainingData.class);
static final Map<String, PrintWriter> trainingWriters = new HashMap<String, PrintWriter>();
public static void main(String[] args) {
log.info("Command-line arguments: " + Arrays.toString(args));
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
Option inputOpt = obuilder.withLongName("dir")
.withRequired(true)
.withArgument(
abuilder.withName("dir")
.withMinimum(1)
.withMaximum(1).create())
.withDescription("Lucene index directory containing input data")
.withShortName("d").create();
Option categoryOpt = obuilder.withLongName("categories")
.withRequired(true)
.withArgument(
abuilder.withName("file")
.withMinimum(1)
.withMaximum(1).create())
.withDescription("File containing a list of categories")
.withShortName("c").create();
Option outputOpt = obuilder.withLongName("output")
      .withRequired(true)
.withArgument(
abuilder.withName("output")
.withMinimum(1)
.withMaximum(1).create())
.withDescription("Output directory")
.withShortName("o").create();
Option categoryFieldsOpt =
obuilder.withLongName("category-fields")
.withRequired(true)
.withArgument(
abuilder.withName("fields")
.withMinimum(1)
.withMaximum(1)
.create())
.withDescription("Fields to match categories against (comma-delimited)")
.withShortName("cf").create();
Option textFieldsOpt =
obuilder.withLongName("text-fields")
.withRequired(true)
.withArgument(
abuilder.withName("fields")
.withMinimum(1)
.withMaximum(1)
.create())
.withDescription("Fields from which to extract training text (comma-delimited)")
.withShortName("tf").create();
Option useTermVectorsOpt = obuilder.withLongName("use-term-vectors")
.withDescription("Extract term vectors containing preprocessed data " +
"instead of unprocessed, stored text values")
.withShortName("tv").create();
Option helpOpt = obuilder.withLongName("help")
.withDescription("Print out help")
.withShortName("h").create();
    Group group = gbuilder.withName("Options")
      .withOption(inputOpt)
      .withOption(categoryOpt)
      .withOption(outputOpt)
      .withOption(categoryFieldsOpt)
      .withOption(textFieldsOpt)
      .withOption(useTermVectorsOpt)
      .withOption(helpOpt) // include the help option so -h/--help is recognized by the parser
      .create();
try {
Parser parser = new Parser();
parser.setGroup(group);
CommandLine cmdLine = parser.parse(args);
if (cmdLine.hasOption(helpOpt)) {
CommandLineUtil.printHelp(group);
return;
}
File inputDir = new File(cmdLine.getValue(inputOpt).toString());
if (!inputDir.isDirectory()) {
throw new IllegalArgumentException(inputDir + " does not exist or is not a directory");
}
File categoryFile = new File(cmdLine.getValue(categoryOpt).toString());
if (!categoryFile.isFile()) {
        throw new IllegalArgumentException(categoryFile + " does not exist or is not a file");
}
File outputDir = new File(cmdLine.getValue(outputOpt).toString());
outputDir.mkdirs();
if (!outputDir.isDirectory()) {
throw new IllegalArgumentException(outputDir + " is not a directory or could not be created");
}
Collection<String> categoryFields = stringToList(cmdLine.getValue(categoryFieldsOpt).toString());
if (categoryFields.size() < 1) {
throw new IllegalArgumentException("At least one category field must be spcified.");
}
Collection<String> textFields = stringToList(cmdLine.getValue(textFieldsOpt).toString());
      if (textFields.size() < 1) {
        throw new IllegalArgumentException("At least one text field must be specified.");
}
boolean useTermVectors = cmdLine.hasOption(useTermVectorsOpt);
      extractTrainingData(inputDir, categoryFile, categoryFields, textFields, outputDir, useTermVectors);
} catch (OptionException e) {
log.error("Exception", e);
CommandLineUtil.printHelp(group);
} catch (IOException e) {
log.error("IOException", e);
} finally {
closeWriters();
}
}
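  // A sketch of a typical invocation, using the options defined above
  // (all paths and field names here are hypothetical examples, not values from this codebase):
  //
  //   java com.tamingtext.classifier.bayes.ExtractTrainingData \
  //     --dir /path/to/lucene-index \
  //     --categories /path/to/categories.txt \
  //     --output /path/to/training-data \
  //     --category-fields category,subject \
  //     --text-fields title,body \
  //     --use-term-vectors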
  /**
   * Extract training data from a Lucene index.
   * <p>
   * Iterates over the documents in the Lucene index; the values of the categoryFields are inspected, and if they
   * contain any of the strings found in the category file, a training data item is emitted, assigned to the
   * matching category and containing the terms found in the fields listed in textFields. Output is written to
   * the output directory with one file per category.
   * <p>
   * The category file contains one line per category; each line contains a number of whitespace-delimited strings.
   * The first string on each line is the category name, and every string on the line (including the first) is used
   * to identify documents that belong in that category.
   * <p>
   * For example, the line 'Technology Computers Macintosh' causes documents that contain 'Technology', 'Computers',
   * or 'Macintosh' in one of their categoryFields to be assigned to the 'Technology' category.
   *
   * @param indexDir
   *          directory of the Lucene index to extract from
   * @param categoryFile
   *          file containing the category strings to extract
   * @param categoryFields
   *          list of fields to match against the category data
   * @param textFields
   *          list of fields containing the terms to extract
   * @param outputDir
   *          directory to write the output to
   * @param useTermVectors
   *          if true, extract preprocessed terms from term vectors; otherwise use unprocessed stored field text
   * @throws IOException
   */
  public static void extractTrainingData(File indexDir, File categoryFile,
Collection<String> categoryFields, Collection<String> textFields, File outputDir, boolean useTermVectors) throws IOException {
log.info("Index dir: " + indexDir);
log.info("Category file: " + categoryFile);
log.info("Output dir: " + outputDir);
log.info("Category fields: " + categoryFields.toString());
log.info("Text fields: " + textFields.toString());
log.info("Use Term Vectors?: " + useTermVectors);
OpenObjectIntHashMap<String> categoryCounts = new OpenObjectIntHashMap<String>();
Map<String, List<String>> categories = readCategoryFile(categoryFile);
Directory dir = FSDirectory.open(indexDir);
IndexReader reader = IndexReader.open(dir, true);
int max = reader.maxDoc();
StringBuilder buf = new StringBuilder();
for (int i=0; i < max; i++) {
if (!reader.isDeleted(i)) {
Document d = reader.document(i);
String category = null;
// determine whether any of the fields in this document contain a
// category in the category list
fields: for (String field: categoryFields) {
for (Field f: d.getFields(field)) {
if (f.isStored() && !f.isBinary()) {
String fieldValue = f.stringValue().toLowerCase();
for (String cat: categories.keySet()) {
List<String> cats = categories.get(cat);
for (String c: cats) {
if (fieldValue.contains(c)) {
category = cat;
break fields;
}
}
}
}
}
}
if (category == null) continue;
// append the terms from each of the textFields to the training data for this document.
buf.setLength(0);
for (String field: textFields) {
if (useTermVectors) {
appendVectorTerms(buf, reader.getTermFreqVector(i, field));
}
else {
appendFieldText(buf, d.getField(field));
}
}
getWriterForCategory(outputDir, category).printf("%s\t%s\n", category, buf.toString());
categoryCounts.adjustOrPutValue(category, 1, 1);
}
    }
    reader.close();
if (log.isInfoEnabled()) {
StringBuilder b = new StringBuilder();
b.append("\nCatagory document counts:\n");
LinkedList<String> keyList = new LinkedList<String>();
categoryCounts.keysSortedByValue(keyList);
String key;
while (!keyList.isEmpty()) {
key = keyList.removeLast();
b.append(categoryCounts.get(key)).append('\t').append(key).append('\n');
}
log.info(b.toString());
}
}
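  // Each emitted training item is a single line of the form "<category>\t<terms>", per the printf above.
  // For example (hypothetical data), a document matched to the 'technology' category whose text fields
  // yield the terms "fast new laptop" is appended to the file <outputDir>/technology as:
  //
  //   technology<TAB>fast new laptop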
  /** Read the category file from disk; see {@link #extractTrainingData(File, File, Collection, Collection, File, boolean)}
   * for a description of the format.
   *
   * @param categoryFile
   *          file containing one category definition per line
   * @return a map from each category name to the strings used to match documents to that category
   * @throws IOException
   */
public static Map<String,List<String>> readCategoryFile(File categoryFile) throws IOException {
Map<String,List<String>> categoryMap = new HashMap<String, List<String>>();
BufferedReader rin = new BufferedReader(new InputStreamReader(new FileInputStream(categoryFile), "UTF-8"));
String line;
    while ((line = rin.readLine()) != null) {
      line = line.trim().toLowerCase();
      if (line.length() == 0) continue; // skip blank lines, which would otherwise create an empty category
      String[] parts = line.split("\\s+");
      String key = parts[0];
      List<String> entries = categoryMap.get(key);
      if (entries == null) {
        entries = new LinkedList<String>();
        categoryMap.put(key, entries);
      }
      // every token on the line, including the category name itself, is used as a match string
      for (String e: parts) {
        entries.add(e);
      }
    }
rin.close();
return categoryMap;
}
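  // For illustration, a category file containing the single line:
  //
  //   Technology Computers Macintosh
  //
  // produces the lowercased mapping below; note that the category name itself is included as a match string:
  //
  //   { "technology" -> ["technology", "computers", "macintosh"] }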
  /** Obtain a writer for the training data assigned to the specified category.
   * <p>
   * Maintains an internal hash of writers keyed by category, which must be closed by {@link #closeWriters()}.
   *
   * @param outputDir
   *          directory in which the per-category output file is created
   * @param category
   *          category whose writer should be obtained
   * @return a writer for the file named after the category
   * @throws IOException
   */
protected static PrintWriter getWriterForCategory(File outputDir, String category) throws IOException {
PrintWriter out = trainingWriters.get(category);
if (out == null) {
      out = new PrintWriter(new OutputStreamWriter(new FileOutputStream(new File(outputDir, category)), "UTF-8"));
trainingWriters.put(category, out);
}
return out;
}
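  // Writers are cached per category and reused across documents; main() closes them all via
  // closeWriters() in its finally block, so any other caller should do the same.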
/** Close writers opened by {@link #getWriterForCategory(File, String)} */
protected static void closeWriters() {
for (PrintWriter p: trainingWriters.values()) {
p.close();
}
}
  /** Append the contents of the specified term vector to a buffer containing a list of terms,
   * repeating each term once per occurrence.
   *
   * @param buf
   *          buffer to append to
   * @param tv
   *          term frequency vector to read terms from; may be null, in which case nothing is appended
   */
protected static void appendVectorTerms(StringBuilder buf, TermFreqVector tv) {
if (tv == null) return;
String[] terms = tv.getTerms();
int[] frequencies = tv.getTermFrequencies();
for (int j=0; j < terms.length; j++) {
int freq = frequencies[j];
String term = terms[j];
for (int k=0; k < freq; k++) {
buf.append(term).append(' ');
}
}
}
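  // For example (hypothetical vector): terms ["cat", "dog"] with frequencies [2, 1] append
  // "cat cat dog " to the buffer, repeating each term once per occurrence so the original
  // term frequencies are preserved in the flat text output.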
  /** Append the contents of the specified field to a buffer containing text,
   * normalizing whitespace in the process.
   *
   * @param buf
   *          buffer to append to
   * @param f
   *          field whose stored, non-binary value is appended; may be null
   */
protected static void appendFieldText(StringBuilder buf, Field f) {
if (f == null) return;
if (f.isBinary()) return;
if (!f.isStored()) return;
if (buf.length() > 0) buf.append(' ');
String s = f.stringValue();
s = s.replaceAll("\\s+", " "); // normalize whitespace.
buf.append(s);
}
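  // For example (hypothetical field value): the stored text "fast\tnew\n\nlaptop" is appended
  // as "fast new laptop"; any run of whitespace collapses to a single space.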
  /** Split a comma-delimited string into a list of its parts.
   *
   * @param input
   *          comma-delimited string; may be null or empty
   * @return the list of parts with whitespace around commas removed, or an empty list if input is null or empty
   */
private static Collection<String> stringToList(String input) {
if (input == null || input.equals("")) return Collections.emptyList();
String[] parts = input.split("\\s*,\\s*");
return Arrays.asList(parts);
}
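  // For example, stringToList("title, body ,subject") returns ["title", "body", "subject"];
  // whitespace immediately around each comma is discarded by the split pattern.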
}