/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation; either version 2.1 of the License, or
* (at your option) any later version.
*
* This library is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
* License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this library; if not, write to the Free Software Foundation,
* Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*/
package joshua.prefix_tree;
import java.io.File;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.PrintStream;
import java.io.UnsupportedEncodingException;
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;
import java.util.logging.Level;
import java.util.logging.Logger;
import joshua.corpus.Corpus;
import joshua.corpus.alignment.AlignmentGrids;
import joshua.corpus.alignment.Alignments;
import joshua.corpus.alignment.mm.MemoryMappedAlignmentGrids;
import joshua.corpus.mm.MemoryMappedCorpusArray;
import joshua.corpus.suffix_array.FrequentPhrases;
import joshua.corpus.suffix_array.ParallelCorpusGrammarFactory;
import joshua.corpus.suffix_array.SuffixArrayFactory;
import joshua.corpus.suffix_array.Suffixes;
import joshua.corpus.suffix_array.mm.MemoryMappedSuffixArray;
import joshua.corpus.vocab.SymbolTable;
import joshua.corpus.vocab.Vocabulary;
import joshua.decoder.JoshuaConfiguration;
import joshua.util.Cache;
import joshua.util.io.BinaryIn;
/**
* Main program to extract hierarchical phrase-based statistical
* translation rules from an aligned parallel corpus using the
* suffix array techniques of Lopez (2008).
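 *
 * <p>A minimal programmatic usage sketch (the path and file names
 * below are hypothetical):</p>
 *
 * <pre>
 * ExtractRules extractRules = new ExtractRules();
 * extractRules.setJoshDir("example/corpus.josh"); // hypothetical directory
 * extractRules.setOutputFile("rules.txt");        // hypothetical file
 * extractRules.setTestFile("test.source");        // hypothetical file
 * extractRules.execute();
 * </pre>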
*
* @author Lane Schwartz
* @version $LastChangedDate:2008-11-13 13:13:31 -0600 (Thu, 13 Nov 2008) $
* @see "Lopez (2008)"
*/
public class ExtractRules {
/** Logger for this class. */
private static final Logger logger =
Logger.getLogger(ExtractRules.class.getName());
private String encoding = "UTF-8";
private String outputFile = "";
private String sourceFileName = "";
private String sourceSuffixesFileName = "";
private String targetFileName = "";
private String targetSuffixesFileName = "";
private String alignmentsFileName = "";
private String commonVocabFileName = "";
private String lexCountsFileName = "";
private String testFileName = "";
private String frequentPhrasesFileName = "";
private int cacheSize = Cache.DEFAULT_CAPACITY;
private int maxPhraseSpan = 10;
private int maxPhraseLength = 10;
private int maxNonterminals = 2;
private int minNonterminalSpan = 2;
private boolean sentenceInitialX = true;
private boolean sentenceFinalX = true;
private boolean edgeXViolates = true;
private boolean requireTightSpans = true;
private boolean binaryCorpus = false;
private String alignmentsType = "AlignmentGrids";
private boolean keepTree = true;
private int ruleSampleSize = 300;
private boolean printPrefixTree = false;
private int maxTestSentences = Integer.MAX_VALUE;
private int startingSentence = 1;
private boolean usePrecomputedFrequentPhrases = true;
public ExtractRules() {
}
public void setUsePrecomputedFrequentPhrases(boolean usePrecomputedFrequentPhrases) {
this.usePrecomputedFrequentPhrases = usePrecomputedFrequentPhrases;
}
public void setSourceFileName(String sourceFileName) {
this.sourceFileName = sourceFileName;
}
public void setTargetFileName(String targetFileName) {
this.targetFileName = targetFileName;
}
public void setAlignmentsFileName(String alignmentsFileName) {
this.alignmentsFileName = alignmentsFileName;
}
public void setLexCountsFileName(String lexCountsFileName) {
this.lexCountsFileName = lexCountsFileName;
}
public void setStartingSentence(int startingSentence) {
this.startingSentence = startingSentence;
}
public void setMaxPhraseSpan(int maxPhraseSpan) {
this.maxPhraseSpan = maxPhraseSpan;
}
public void setMaxPhraseLength(int maxPhraseLength) {
this.maxPhraseLength = maxPhraseLength;
}
public void setMaxNonterminals(int maxNonterminals) {
this.maxNonterminals = maxNonterminals;
}
public void setMinNonterminalSpan(int minNonterminalSpan) {
this.minNonterminalSpan = minNonterminalSpan;
}
public void setCacheSize(int cacheSize) {
this.cacheSize = cacheSize;
}
public void setMaxTestSentences(int maxTestSentences) {
this.maxTestSentences = maxTestSentences;
}
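/**
 * Configures this extractor to read a directory of binary,
 * memory-mappable corpus files. The directory is expected to contain
 * source.corpus, target.corpus, common.vocab, lexicon.counts,
 * source.suffixes, target.suffixes, alignment.grids, and
 * frequentPhrases.
 *
 * @param joshDir directory containing the binary corpus files
 */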
public void setJoshDir(String joshDir) {
this.sourceFileName = joshDir + File.separator + "source.corpus";
this.targetFileName = joshDir + File.separator + "target.corpus";
this.commonVocabFileName = joshDir + File.separator + "common.vocab";
this.lexCountsFileName = joshDir + File.separator + "lexicon.counts";
this.sourceSuffixesFileName = joshDir + File.separator + "source.suffixes";
this.targetSuffixesFileName = joshDir + File.separator + "target.suffixes";
this.alignmentsFileName = joshDir + File.separator + "alignment.grids";
this.alignmentsType = "MemoryMappedAlignmentGrids";
this.frequentPhrasesFileName = joshDir + File.separator + "frequentPhrases";
this.binaryCorpus = true;
}
public void setTestFile(String testFileName) {
this.testFileName = testFileName;
}
public void setOutputFile(String outputFile) {
this.outputFile = outputFile;
}
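/**
 * Sets the character encoding used for console output, the rules
 * output file, and reading the test file. Defaults to UTF-8.
 *
 * @param encoding a charset name, such as "UTF-8"
 */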
public void setEncoding(String encoding) {
this.encoding = encoding;
}
public void setSentenceInitialX(boolean sentenceInitialX) {
this.sentenceInitialX = sentenceInitialX;
}
public void setSentenceFinalX(boolean sentenceFinalX) {
this.sentenceFinalX = sentenceFinalX;
}
public void setEdgeXViolates(boolean edgeXViolates) {
this.edgeXViolates = edgeXViolates;
}
public void setRequireTightSpans(boolean requireTightSpans) {
this.requireTightSpans = requireTightSpans;
}
public void setKeepTree(boolean keepTree) {
this.keepTree = keepTree;
}
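/**
 * Sets how many occurrences of a source phrase are sampled when
 * extracting rules for that phrase. Defaults to 300.
 *
 * @param ruleSampleSize maximum number of occurrences to sample per phrase
 */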
public void setRuleSampleSize(int ruleSampleSize) {
this.ruleSampleSize = ruleSampleSize;
}
public void setPrintPrefixTree(boolean printPrefixTree) {
this.printPrefixTree = printPrefixTree;
}
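/**
 * Constructs the grammar factory for rule extraction. Depending on
 * whether binary corpus files are in use, the common vocabulary, the
 * source and target corpus arrays, their suffix arrays, and the word
 * alignments are either memory-mapped from disk or built from plain
 * text, and are then wrapped in a ParallelCorpusGrammarFactory.
 *
 * @return a grammar factory backed by the configured parallel corpus
 * @throws IOException if a corpus, suffix array, or alignment file cannot be read
 * @throws ClassNotFoundException if a serialized vocabulary cannot be loaded
 */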
public ParallelCorpusGrammarFactory getGrammarFactory() throws IOException, ClassNotFoundException {
////////////////////////////////
// Common vocabulary //
////////////////////////////////
if (logger.isLoggable(Level.INFO)) logger.info("Constructing empty common vocabulary");
Vocabulary commonVocab = new Vocabulary();
int numSourceWords, numSourceSentences;
int numTargetWords, numTargetSentences;
String binaryCommonVocabFileName = this.commonVocabFileName;
if (binaryCorpus) {
if (logger.isLoggable(Level.INFO)) logger.info("Initializing common vocabulary from binary file " + binaryCommonVocabFileName);
ObjectInput in = BinaryIn.vocabulary(binaryCommonVocabFileName);
commonVocab.readExternal(in);
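// Word and sentence counts are only needed when constructing corpus
// arrays from plain text; in the binary case, mark them as unused.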
numSourceWords = Integer.MIN_VALUE;
numSourceSentences = Integer.MIN_VALUE;
numTargetWords = Integer.MIN_VALUE;
numTargetSentences = Integer.MIN_VALUE;
} else {
if (logger.isLoggable(Level.INFO)) logger.info("Initializing common vocabulary with source corpus " + sourceFileName);
int[] sourceWordsSentences = Vocabulary.initializeVocabulary(sourceFileName, commonVocab, true);
numSourceWords = sourceWordsSentences[0];
numSourceSentences = sourceWordsSentences[1];
if (logger.isLoggable(Level.INFO)) logger.info("Initializing common vocabulary with target corpus " + sourceFileName);
int[] targetWordsSentences = Vocabulary.initializeVocabulary(targetFileName, commonVocab, true);
numTargetWords = targetWordsSentences[0];
numTargetSentences = targetWordsSentences[1];
}
//////////////////////////////////
// Source language corpus array //
//////////////////////////////////
final Corpus sourceCorpusArray;
if (binaryCorpus) {
if (logger.isLoggable(Level.INFO)) logger.info("Constructing memory mapped source language corpus array.");
sourceCorpusArray = new MemoryMappedCorpusArray(commonVocab, sourceFileName);
} else {
if (logger.isLoggable(Level.INFO)) logger.info("Constructing source language corpus array.");
sourceCorpusArray = SuffixArrayFactory.createCorpusArray(sourceFileName, commonVocab, numSourceWords, numSourceSentences);
}
//////////////////////////////////
// Source language suffix array //
//////////////////////////////////
Suffixes sourceSuffixArray;
String binarySourceSuffixArrayFileName = sourceSuffixesFileName;
if (binaryCorpus) {
if (logger.isLoggable(Level.INFO)) logger.info("Constructing source language suffix array from binary file " + binarySourceSuffixArrayFileName);
sourceSuffixArray = new MemoryMappedSuffixArray(binarySourceSuffixArrayFileName, sourceCorpusArray, cacheSize);
} else {
if (logger.isLoggable(Level.INFO)) logger.info("Constructing source language suffix array from source corpus.");
sourceSuffixArray = SuffixArrayFactory.createSuffixArray(sourceCorpusArray, cacheSize);
}
//////////////////////////////////
// Target language corpus array //
//////////////////////////////////
final Corpus targetCorpusArray;
if (binaryCorpus) {
if (logger.isLoggable(Level.INFO)) logger.info("Constructing memory mapped target language corpus array.");
targetCorpusArray = new MemoryMappedCorpusArray(commonVocab, targetFileName);
} else {
if (logger.isLoggable(Level.INFO)) logger.info("Constructing target language corpus array.");
targetCorpusArray = SuffixArrayFactory.createCorpusArray(targetFileName, commonVocab, numTargetWords, numTargetSentences);
}
//////////////////////////////////
// Target language suffix array //
//////////////////////////////////
Suffixes targetSuffixArray;
String binaryTargetSuffixArrayFileName = targetSuffixesFileName;
if (binaryCorpus) {
if (logger.isLoggable(Level.INFO)) logger.info("Constructing target language suffix array from binary file " + binaryTargetSuffixArrayFileName);
targetSuffixArray = new MemoryMappedSuffixArray(binaryTargetSuffixArrayFileName, targetCorpusArray, cacheSize);
} else {
if (logger.isLoggable(Level.INFO)) logger.info("Constructing target language suffix array from target corpus.");
targetSuffixArray = SuffixArrayFactory.createSuffixArray(targetCorpusArray, cacheSize);
}
int trainingSize = sourceCorpusArray.getNumSentences();
if (trainingSize != targetCorpusArray.getNumSentences()) {
throw new RuntimeException("Source and target corpora have different number of sentences. This is bad.");
}
/////////////////////
// Alignment data //
/////////////////////
if (logger.isLoggable(Level.INFO)) logger.info("Reading alignment data.");
final Alignments alignments;
if ("AlignmentArray".equals(alignmentsType)) {
if (logger.isLoggable(Level.INFO)) logger.info("Using AlignmentArray");
alignments = SuffixArrayFactory.createAlignments(alignmentsFileName, sourceSuffixArray, targetSuffixArray);
} else if ("AlignmentGrids".equals(alignmentsType) || "AlignmentsGrid".equals(alignmentsType)) {
if (logger.isLoggable(Level.INFO)) logger.info("Using AlignmentGrids");
alignments = new AlignmentGrids(new Scanner(new File(alignmentsFileName)), sourceCorpusArray, targetCorpusArray, trainingSize, requireTightSpans);
} else if ("MemoryMappedAlignmentGrids".equals(alignmentsType)) {
if (logger.isLoggable(Level.INFO)) logger.info("Using MemoryMappedAlignmentGrids");
alignments = new MemoryMappedAlignmentGrids(alignmentsFileName, sourceCorpusArray, targetCorpusArray);
} else {
logger.severe("Invalid alignment type: " + alignmentsType);
throw new IllegalArgumentException("Invalid alignment type: " + alignmentsType);
}
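// Nonterminal vocabulary: maps the symbol ID of the hierarchical
// nonterminal X to its string form.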
Map<Integer,String> ntVocab = new HashMap<Integer,String>();
ntVocab.put(SymbolTable.X, SymbolTable.X_STRING);
//////////////////////
// Lexical Probs //
//////////////////////
// final LexProbs lexProbs;
String binaryLexCountsFilename = this.lexCountsFileName;
//////////////////////
// Frequent Phrases //
//////////////////////
if (usePrecomputedFrequentPhrases) {
logger.info("Reading precomputed frequent phrases from disk");
FrequentPhrases frequentPhrases = new FrequentPhrases(sourceSuffixArray, frequentPhrasesFileName);
frequentPhrases.cacheInvertedIndices();
}
logger.info("Constructing grammar factory from parallel corpus");
ParallelCorpusGrammarFactory parallelCorpus;
if (binaryCorpus) {
if (logger.isLoggable(Level.INFO)) logger.info("Constructing lexical translation probabilities from binary file " + binaryLexCountsFilename);
parallelCorpus = new ParallelCorpusGrammarFactory(
sourceSuffixArray, targetSuffixArray, alignments, null,
binaryLexCountsFilename, ruleSampleSize, maxPhraseSpan,
maxPhraseLength, maxNonterminals, minNonterminalSpan,
JoshuaConfiguration.phrase_owner, JoshuaConfiguration.default_non_terminal,
JoshuaConfiguration.oovFeatureCost);
} else {
if (logger.isLoggable(Level.INFO)) logger.info("Constructing lexical translation probabilities from parallel corpus");
parallelCorpus = new ParallelCorpusGrammarFactory(
sourceSuffixArray, targetSuffixArray, alignments, null,
ruleSampleSize, maxPhraseSpan, maxPhraseLength,
maxNonterminals, minNonterminalSpan, Float.MIN_VALUE,
JoshuaConfiguration.phrase_owner, JoshuaConfiguration.default_non_terminal,
JoshuaConfiguration.oovFeatureCost);
}
return parallelCorpus;
}
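/**
 * Reads the test file one sentence at a time, adds each sentence to a
 * prefix tree over the parallel corpus, and writes the extracted rules
 * to the configured output.
 *
 * @throws IOException if the test file or the output file cannot be accessed
 * @throws ClassNotFoundException if a serialized vocabulary cannot be loaded
 */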
public void execute() throws IOException, ClassNotFoundException {
// Set System.out and System.err to use the configured character encoding
try {
System.setOut(new PrintStream(System.out, true, encoding));
System.setErr(new PrintStream(System.err, true, encoding));
} catch (UnsupportedEncodingException e1) {
System.err.println(encoding + " is not a valid encoding; using the system default encoding for System.out and System.err.");
} catch (SecurityException e2) {
System.err.println("The security manager disallows changes to System.out and System.err; using the system default encoding.");
}
PrintStream out;
if ("-".equals(this.outputFile)) {
out = System.out;
logger.info("Rules will be written to standard out");
} else {
out = new PrintStream(outputFile, encoding);
logger.info("Rules will be written to " + outputFile);
}
ParallelCorpusGrammarFactory parallelCorpus = this.getGrammarFactory();
logger.info("Getting symbol table");
SymbolTable sourceVocab = parallelCorpus.getSourceCorpus().getVocabulary();
int lineNumber = 0;
boolean oneTreePerSentence = !this.keepTree;
logger.info("Reading test sentences from " + testFileName);
Scanner testFileScanner = new Scanner(new File(testFileName), encoding);
PrefixTree prefixTree = null;
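// lineNumber counts lines read so far; the loop stops once maxTestSentences
// sentences at or after startingSentence have been processed.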
while (testFileScanner.hasNextLine() && (lineNumber-startingSentence+1)<maxTestSentences) {
String line = testFileScanner.nextLine();
lineNumber++;
if (lineNumber < startingSentence) continue;
int[] words = sourceVocab.getIDs(line);
if (oneTreePerSentence || null==prefixTree) {
if (logger.isLoggable(Level.INFO)) logger.info("Constructing new prefix tree");
Node.resetNodeCounter();
prefixTree = new PrefixTree(parallelCorpus);
prefixTree.setPrintStream(out);
prefixTree.sentenceInitialX = this.sentenceInitialX;
prefixTree.sentenceFinalX = this.sentenceFinalX;
prefixTree.edgeXMayViolatePhraseSpan = this.edgeXViolates;
}
try {
if (logger.isLoggable(Level.INFO)) logger.info("Processing source line " + lineNumber + ": " + line);
prefixTree.add(words);
} catch (OutOfMemoryError e) {
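// Last-resort recovery: discard the prefix tree and cached phrases,
// force garbage collection, then rebuild the tree and retry this sentence.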
logger.warning("Out of memory - attempting to clear cache to free space");
parallelCorpus.getSuffixArray().getCachedHierarchicalPhrases().clear();
prefixTree = null;
System.gc();
logger.info("Cleared cache and collected garbage. Now attempting to re-construct prefix tree...");
Node.resetNodeCounter();
prefixTree = new PrefixTree(parallelCorpus);
prefixTree.setPrintStream(out);
prefixTree.sentenceInitialX = this.sentenceInitialX;
prefixTree.sentenceFinalX = this.sentenceFinalX;
prefixTree.edgeXMayViolatePhraseSpan = this.edgeXViolates;
if (logger.isLoggable(Level.INFO)) logger.info("Re-processing source line " + lineNumber + ": " + line);
prefixTree.add(words);
}
if (printPrefixTree) {
System.out.println(prefixTree.toString());
}
// if (printRules) {
// if (logger.isLoggable(Level.FINE)) logger.fine("Outputting rules for source line: " + line);
//
// for (Rule rule : prefixTree.getAllRules()) {
// String ruleString = rule.toString(ntVocab, sourceVocab, targetVocab);
// if (logger.isLoggable(Level.FINEST)) logger.finest("Rule: " + ruleString);
// out.println(ruleString);
// }
// }
// if (logger.isLoggable(Level.FINEST)) logger.finest(lexProbs.toString());
}
logger.info("Done extracting rules for file " + testFileName);
}
/**
 * Command-line entry point. Accepts either three arguments
 * (joshDir outputRules testFile) to extract from a binary corpus
 * directory, or five arguments (source target alignments outputRules
 * testFile) to extract from plain-text corpora.
 *
 * @param args command-line arguments, as described above
 * @throws IOException if an input or output file cannot be accessed
 * @throws ClassNotFoundException if a serialized vocabulary cannot be loaded
 */
public static void main(String[] args) throws IOException, ClassNotFoundException {
if (args.length==3) {
ExtractRules extractRules = new ExtractRules();
extractRules.setJoshDir(args[0]);
extractRules.setOutputFile(args[1]);
extractRules.setTestFile(args[2]);
extractRules.execute();
} else if (args.length==5) {
ExtractRules extractRules = new ExtractRules();
extractRules.setSourceFileName(args[0]);
extractRules.setTargetFileName(args[1]);
extractRules.setAlignmentsFileName(args[2]);
extractRules.setOutputFile(args[3]);
extractRules.setTestFile(args[4]);
extractRules.execute();
} else {
System.err.println("Usage: joshDir outputRules testFile");
System.err.println("---------------OR------------------");
System.err.println("Usage: source.txt target.txt alignments.txt outputRules testFile");
}
}
}