/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.decoder;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import joshua.corpus.suffix_array.Pattern;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.chart_parser.Chart;
import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.ff.lm.LanguageModelFF;
import joshua.decoder.ff.state_maintenance.StateComputer;
import joshua.decoder.ff.tm.Grammar;
import joshua.decoder.ff.tm.GrammarFactory;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;
import joshua.decoder.segment_file.HackishSegmentParser;
import joshua.decoder.segment_file.PlainSegmentParser;
import joshua.decoder.segment_file.Segment;
import joshua.decoder.segment_file.SegmentFileParser;
import joshua.decoder.segment_file.sax_parser.SAXSegmentParser;
import joshua.lattice.Lattice;
import joshua.oracle.OracleExtractor;
import joshua.ui.hypergraph_visualizer.HyperGraphViewer;
import joshua.util.CoIterator;
import joshua.util.FileUtility;
import joshua.util.io.LineReader;
import joshua.util.io.NullReader;
import joshua.util.io.Reader;
import joshua.util.io.UncheckedIOException;
/**
* This class interacts with the chart-parsing functions to perform
* the actual decoding: it reads a test file, translates each segment,
* and writes n-best output (and, optionally, serialized hypergraphs).
*
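* <p>
* A minimal usage sketch (the grammar factories, feature functions,
* state computers, and symbol table are placeholders that in practice
* come from the decoder's initialization code, e.g. DecoderFactory;
* exception handling is omitted):
* <pre>{@code
* DecoderThread thread = new DecoderThread(
*     grammarFactories, true, featureFunctions, stateComputers,
*     symbolTable, "test.src", "test.nbest", null, 0);
* thread.start(); // runs decodeTestFile() in its own thread
* thread.join();  // wait until the n-best output has been written
* }</pre>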
* @author Zhifei Li, <zhifei.work@gmail.com>
* @version $LastChangedDate: 2010-02-08 13:03:13 -0600 (Mon, 08 Feb 2010) $
*/
// BUG: known synchronization problem: LM cache; srilm call;
public class DecoderThread extends Thread {
/* these variables may be the same across all threads (e.g.,
* just copy from DecoderFactory), or differ from thread
* to thread */
private final List<GrammarFactory> grammarFactories;
private final boolean hasLanguageModel;
private final List<FeatureFunction> featureFunctions;
private final List<StateComputer> stateComputers;
/**
* Shared symbol table for source language terminals, target
* language terminals, and shared nonterminals.
* <p>
* It may be that separate tables should be maintained for
* the source and target languages.
* <p>
* This class explicitly uses the symbol table to get integer
* IDs for the source language sentence.
*/
private final SymbolTable symbolTable;
// Test-set-specific fields
final String testFile;
private final String oracleFile;
final String nbestFile; // package-private for DecoderFactory
private BufferedWriter nbestWriter; // set in decodeTestFile
private final int startSentenceID;
private final KBestExtractor kbestExtractor;
DiskHyperGraph hypergraphSerializer; // package-private for DecoderFactory
private static final Logger logger =
Logger.getLogger(DecoderThread.class.getName());
//===============================================================
// Constructor
//===============================================================
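/**
 * Constructs a decoder thread for translating one test file.
 * If JoshuaConfiguration.save_disk_hg is set, a DiskHyperGraph
 * serializer is also opened, writing to nbestFile + ".hg.items".
 *
 * @param oracleFile file of reference translations used for oracle
 *        extraction, or null to disable oracle decoding
 * @param startSentenceID sentence ID assigned to the first segment
 *        (used by HackishSegmentParser)
 */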
public DecoderThread(
List<GrammarFactory> grammarFactories,
boolean hasLanguageModel,
List<FeatureFunction> featureFunctions,
List<StateComputer> stateComputers,
SymbolTable symbolTable,
String testFile, String nbestFile, String oracleFile,
int startSentenceID
) throws IOException {
this.grammarFactories = grammarFactories;
this.hasLanguageModel = hasLanguageModel;
this.featureFunctions = featureFunctions;
this.stateComputers = stateComputers;
this.symbolTable = symbolTable;
this.testFile = testFile;
this.nbestFile = nbestFile;
this.oracleFile = oracleFile;
this.startSentenceID = startSentenceID;
this.kbestExtractor = new KBestExtractor(
this.symbolTable,
JoshuaConfiguration.use_unique_nbest,
JoshuaConfiguration.use_tree_nbest,
JoshuaConfiguration.include_align_index,
JoshuaConfiguration.add_combined_cost,
false, (oracleFile==null));
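// When save_disk_hg is enabled, locate the language model feature
// (DiskHyperGraph needs its feature ID) and open the serializer.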
if (JoshuaConfiguration.save_disk_hg) {
FeatureFunction languageModel = null;
for (FeatureFunction ff : this.featureFunctions) {
if (ff instanceof LanguageModelFF) {
languageModel = ff;
//break;
}
}
int lmFeatID = -1;
if (null == languageModel) {
logger.warning("No language model feature function found, but save disk hg");
}else{
lmFeatID = languageModel.getFeatureID();
}
this.hypergraphSerializer = new DiskHyperGraph(
this.symbolTable,
lmFeatID,
true, // always store model cost
this.featureFunctions);
this.hypergraphSerializer.initWrite(
this.nbestFile + ".hg.items",
JoshuaConfiguration.forest_pruning,
JoshuaConfiguration.forest_pruning_threshold);
}
}
//===============================================================
// Methods
//===============================================================
// An override of Thread.run() cannot declare checked exceptions, so handle them here
public void run() {
try {
this.decodeTestFile();
//this.hypergraphSerializer.closeReaders();
} catch (IOException e) {
e.printStackTrace();
System.exit(1);
}
}
// BUG: log file is not properly handled for parallel decoding
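/**
 * Makes two passes over the test file: a first pass that only parses
 * the segments, so that formatting errors surface before any decoding
 * work is done, and a second pass that translates each segment and
 * writes n-best output to nbestFile.
 */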
void decodeTestFile() throws IOException {
SegmentFileParser segmentParser;
// BUG: As written, this will need duplicating in DecoderFactory
// TODO: Fix JoshuaConfiguration so we can make this less gross.
//
// TODO: maybe using real reflection would be cleaner. If it weren't
// for the startSentenceID argument to HackishSegmentParser, we could
// do all of this in the JoshuaConfiguration class instead.
final String className = JoshuaConfiguration.segmentFileParserClass;
if (null == className) {
// Use old behavior by default
segmentParser = new HackishSegmentParser(this.startSentenceID);
} else if ("PlainSegmentParser".equals(className)) {
segmentParser = new PlainSegmentParser();
} else if ("HackishSegmentParser".equals(className)) {
segmentParser = new HackishSegmentParser(this.startSentenceID);
} else if ("SAXSegmentParser".equals(className)) {
segmentParser = new SAXSegmentParser();
} else {
throw new IllegalArgumentException(
"Unknown SegmentFileParser class: " + className);
}
// TODO: we need to run the segmentParser over the file once in order
// to catch any formatting errors before we do the actual translation.
// Getting formatting errors asynchronously after a long time is a Bad
// Thing(tm). Some errors may be recoverable (e.g. by skipping the
// invalid sentence), but we treat all exceptions as errors for now.
//
// TODO: we should unwrap SAXExceptions and give good error messages
segmentParser.parseSegmentFile(
LineReader.getInputStream(this.testFile),
new CoIterator<Segment>() {
public void coNext(Segment seg) {
// Consume Segment and do nothing (for now)
}
public void finish() {
// Nothing to clean up
}
});
// TODO: we should also have the CoIterator<Segment> test compatibility
// with a given grammar, e.g. that the count of grammatical feature
// functions matches, that nonterminals match, ...
// TODO: we may also want to validate that all segments have different ids
//=== Translate the test file
this.nbestWriter = FileUtility.getWriteFileStream(this.nbestFile);
try {
try {
// This call parses the input file into segments and then translates
// them one by one.
segmentParser.parseSegmentFile(
LineReader.getInputStream(this.testFile),
new TranslateCoiterator(
null == this.oracleFile
? new NullReader<String>()
: new LineReader(this.oracleFile)
)
);
} catch (UncheckedIOException e) {
e.throwCheckedException();
}
} finally {
this.nbestWriter.flush();
this.nbestWriter.close();
}
}
/**
* This coiterator is for calling the DecoderThread.translate
* method on each Segment to be translated. All interface
* methods can throw {@link UncheckedIOException}, which
* should be converted back into an {@link IOException} as soon
* as possible.
*/
private class TranslateCoiterator implements CoIterator<Segment> {
// TODO: it would be nice if we could somehow push this into the
// parseSegmentFile call and use a coiterator over some subclass of
// Segment which has another method for returning the oracular
// sentence. That may take some work though, since Java hates mixins
// so much.
private Reader<String> oracleReader;
public TranslateCoiterator(Reader<String> oracleReader) {
this.oracleReader = oracleReader;
}
public void coNext(Segment segment) {
try {
if (logger.isLoggable(Level.FINE))
logger.fine("Segment id: " + segment.id());
DecoderThread.this.translate(
segment, this.oracleReader.readLine());
} catch (IOException ioe) {
throw new UncheckedIOException(ioe);
}
}
public void finish() {
try {
this.oracleReader.close();
} catch (IOException ioe) {
throw new UncheckedIOException(ioe);
}
}
} // End inner class TranslateCoiterator
/**
* Translate a sentence.
*
* @param segment The sentence to be translated.
* @param oracleSentence The reference translation used for oracle
*        extraction, or null to skip oracle extraction.
*/
private void translate(Segment segment, String oracleSentence)
throws IOException {
long startTime = 0;
if (logger.isLoggable(Level.FINER)) {
startTime = System.currentTimeMillis();
}
if (logger.isLoggable(Level.FINE))
logger.fine("now translating\n" + segment.sentence());
Chart chart; {
// TODO: we should not use "(((" to decide whether it is a lattice input
final boolean looksLikeLattice = segment.sentence().startsWith("(((");
Lattice<Integer> inputLattice = null;
Pattern sentence = null;
if (looksLikeLattice) {
inputLattice = Lattice.createFromString(segment.sentence(),
this.symbolTable);
sentence = null; // TODO SA needs to accept lattices!
} else {
int[] intSentence = this.symbolTable.getIDs(segment.sentence());
if (logger.isLoggable(Level.FINEST))
logger.finest("Converted \"" + segment.sentence() + "\" into " + Arrays.toString(intSentence));
inputLattice = Lattice.createLattice(intSentence);
sentence = new Pattern(this.symbolTable, intSentence);
}
if (logger.isLoggable(Level.FINEST))
logger.finest("Translating input lattice:\n" + inputLattice.toString());
Grammar[] grammars = new Grammar[grammarFactories.size()];
int i = 0;
for (GrammarFactory factory : this.grammarFactories) {
grammars[i] = factory.getGrammarForSentence(sentence);
// For batch grammar, we do not want to sort it every time
if (! grammars[i].isSorted()) {
System.out.println("!!!!!!!!!!!! called again");
// TODO Check to see if this is ever called here. It probably is not
grammars[i].sortGrammar(this.featureFunctions);
}
i++;
}
/* Seeding: the chart only sees the grammars, not the factories */
chart = new Chart(
inputLattice,
this.featureFunctions,
this.stateComputers,
this.symbolTable,
Integer.parseInt(segment.id()),
grammars,
this.hasLanguageModel,
JoshuaConfiguration.goal_symbol,
segment.constraints());
if (logger.isLoggable(Level.FINER))
logger.finer("after seed, time: "
+ ((double)(System.currentTimeMillis() - startTime) / 1000.0)
+ " seconds");
}
/* Parsing */
HyperGraph hypergraph = chart.expand();
if (JoshuaConfiguration.visualize_hypergraph) {
HyperGraphViewer.visualizeHypergraphInFrame(hypergraph, symbolTable);
}
if (logger.isLoggable(Level.FINER))
logger.finer("after expand, time: "
+ ((double)(System.currentTimeMillis() - startTime) / 1000.0)
+ " seconds");
if (oracleSentence != null) {
logger.fine("Creating oracle extractor");
OracleExtractor extractor = new OracleExtractor(this.symbolTable);
logger.finer("Extracting oracle hypergraph...");
HyperGraph oracle = extractor.getOracle(hypergraph, 3, oracleSentence);
logger.finer("... Done Extracting. Getting k-best...");
this.kbestExtractor.lazyKBestExtractOnHG(
oracle, this.featureFunctions,
JoshuaConfiguration.topN,
Integer.parseInt(segment.id()), this.nbestWriter);
logger.finer("... Done getting k-best");
} else {
/* k-best extraction */
this.kbestExtractor.lazyKBestExtractOnHG(
hypergraph, this.featureFunctions,
JoshuaConfiguration.topN,
Integer.parseInt(segment.id()), this.nbestWriter);
if (logger.isLoggable(Level.FINER))
logger.finer("after k-best, time: "
+ ((double)(System.currentTimeMillis() - startTime) / 1000.0)
+ " seconds");
}
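// Optionally serialize the hypergraph (or its k-best restriction) to disk.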
if (null != this.hypergraphSerializer) {
if (JoshuaConfiguration.use_kbest_hg) {
HyperGraph kbestHG = this.kbestExtractor.extractKbestIntoHyperGraph(hypergraph, JoshuaConfiguration.topN);
this.hypergraphSerializer.saveHyperGraph(kbestHG);
} else {
this.hypergraphSerializer.saveHyperGraph(hypergraph);
}
}
/* //debug
if (JoshuaConfiguration.use_variational_decoding) {
ConstituentVariationalDecoder vd = new ConstituentVariationalDecoder();
vd.decoding(hypergraph);
System.out.println("#### new 1best is #####\n" + HyperGraph.extract_best_string(p_main_controller.p_symbol, hypergraph.goal_item));
}
// end */
//debug
//g_con.get_confusion_in_hyper_graph_cell_specific(hypergraph, hypergraph.sent_len);
}
/** Decodes a sentence and returns the resulting hypergraph. */
public HyperGraph getHyperGraph(String sentence) {
Chart chart;
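// NOTE: this largely duplicates the chart setup in translate(), but for
// a plain string input, with no oracle, n-best output, or serialization.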
int[] intSentence = this.symbolTable.getIDs(sentence);
Lattice<Integer> inputLattice = Lattice.createLattice(intSentence);
Grammar[] grammars = new Grammar[grammarFactories.size()];
int i = 0;
for (GrammarFactory factory : this.grammarFactories) {
grammars[i] = factory.getGrammarForSentence(
new Pattern(this.symbolTable, intSentence));
// For batch grammar, we do not want to sort it every time
if (! grammars[i].isSorted()) {
grammars[i].sortGrammar(this.featureFunctions);
}
i++;
}
chart = new Chart(
inputLattice,
this.featureFunctions,
this.stateComputers,
this.symbolTable,
0,
grammars,
this.hasLanguageModel,
JoshuaConfiguration.goal_symbol,
null);
return chart.expand();
}
}