/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.decoder;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.ff.state_maintenance.StateComputer;
import joshua.decoder.ff.tm.GrammarFactory;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.util.FileUtility;
import joshua.util.Regex;
import joshua.util.io.LineReader;
/**
* This class implements:
* (1) parallel decoding: split the test file, initiate one
* DecoderThread per split, then wait for the threads to finish and
* merge their decoding results;
* (2) non-parallel decoding, as a special case of parallel decoding.
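*
* <p>A minimal usage sketch (for illustration only; it assumes that
* grammarFactories, featureFunctions, stateComputers, and symbolTable
* have already been initialized elsewhere, e.g. during decoder
* startup):</p>
* <pre>
* DecoderFactory factory = new DecoderFactory(
*     grammarFactories, true, // true: a language model is in use
*     featureFunctions, stateComputers, symbolTable);
* factory.decodeTestSet("input.src", "output.nbest", null); // null: no oracle file
* </pre>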
*
* @author Zhifei Li, <zhifei.work@gmail.com>
* @version $LastChangedDate: 2010-01-29 16:04:50 -0600 (Fri, 29 Jan 2010) $
*/
public class DecoderFactory {
private List<GrammarFactory> grammarFactories = null;
private List<FeatureFunction> featureFunctions = null;
private List<StateComputer> stateComputers;
private boolean hasLanguageModel = false;
/**
* Shared symbol table for source language terminals, target
* language terminals, and shared nonterminals.
*/
private SymbolTable symbolTable = null;
private DecoderThread[] parallelThreads;
private static final Logger logger =
Logger.getLogger(DecoderFactory.class.getName());
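/**
* Constructs a factory from fully initialized decoder components.
*
* @param grammarFactories the translation grammars to decode with
* @param hasLanguageModel whether a language model feature is in use
* @param featureFunctions the feature functions used to score derivations
* @param stateComputers the dynamic-programming state computers (e.g. for LM state)
* @param symbolTable shared symbol table for source and target terminals and shared nonterminals
*/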
public DecoderFactory(List<GrammarFactory> grammarFactories, boolean hasLanguageModel, List<FeatureFunction> featureFunctions,
List<StateComputer> stateComputers, SymbolTable symbolTable) {
this.grammarFactories = grammarFactories;
this.hasLanguageModel = hasLanguageModel;
this.featureFunctions = featureFunctions;
this.stateComputers = stateComputers;
this.symbolTable = symbolTable;
}
/**
* This is the public-facing method for decoding a set of
* sentences. It consults JoshuaConfiguration.num_parallel_decoders
* to decide whether to run the decoder in parallel or within the
* current thread.
*/
public void decodeTestSet(String testFile, String nbestFile, String oracleFile) {
try {
if (JoshuaConfiguration.num_parallel_decoders == 1) {
DecoderThread pdecoder = new DecoderThread(
this.grammarFactories, this.hasLanguageModel,
this.featureFunctions, this.stateComputers, this.symbolTable,
testFile, nbestFile, oracleFile, 0);
// do not call start(); invoking decodeTestFile() directly keeps the work in the current (main) thread
pdecoder.decodeTestFile();
if (JoshuaConfiguration.save_disk_hg) {
pdecoder.hypergraphSerializer.writeRulesNonParallel(
nbestFile + ".hg.rules");
}
} else {
if (JoshuaConfiguration.use_remote_lm_server) { // TODO
throw new IllegalArgumentException("You cannot run the parallel decoder together with a remote LM server");
}
if (null != oracleFile) {
logger.warning("Parallel decoding does not appear to support oracle decoding, but an oracle file was supplied; it will be ignored.");
}
runParallelDecoder(testFile, nbestFile);
}
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* Decodes a single sentence and returns its hypergraph,
* or null if an I/O error occurs.
*/
public HyperGraph getHyperGraphForSentence(String sentence) {
try {
DecoderThread pdecoder = new DecoderThread(
this.grammarFactories, this.hasLanguageModel,
this.featureFunctions, this.stateComputers, this.symbolTable,
sentence, null, null, 0);
return pdecoder.getHyperGraph(sentence);
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
// BUG: this kind of file munging isn't going to work with a generalized SegmentFileParser
private void runParallelDecoder(String testFile, String nbestFile)
throws IOException {
{
/** If a segment corresponds to a single line in the input file, we can
* use parallel decoding, since we can simply split the file. This is
* true for input files processed by the PlainSegmentParser and the
* HackishSegmentParser. But if a segment corresponds to multiple lines
* (as with XML input, where manual constraints are specified in XML
* format), we cannot do parallel decoding as we do below. This is the
* case for files handled by the SAXSegmentParser.
**/
//TODO: what about lattice input? is it a single input?
final String className = JoshuaConfiguration.segmentFileParserClass;
if (className == null || "PlainSegmentParser".equals(className)) {
// Do nothing, this one is okay.
} else if ("HackishSegmentParser".equals(className)) {
logger.warning("Using HackishSegmentParser with parallel decoding may cause sentence IDs to become garbled");
} else {
throw new IllegalArgumentException("Parallel decoding is currently not supported with SegmentFileParsers other than PlainSegmentParser or HackishSegmentParser");
}
}
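// To summarize the guard above: a segmentFileParserClass of null or
// "PlainSegmentParser" splits cleanly; "HackishSegmentParser" splits
// with a warning about garbled sentence IDs; anything else (e.g. the
// SAXSegmentParser) aborts parallel decoding.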
this.parallelThreads = new DecoderThread[JoshuaConfiguration.num_parallel_decoders];
//==== compute number of lines for each decoder
int n_lines = 0; {
LineReader testReader = new LineReader(testFile);
try {
n_lines = testReader.countLines();
} finally { testReader.close(); }
}
double num_per_thread_double = n_lines * 1.0 / JoshuaConfiguration.num_parallel_decoders;
// guard against a modulo-by-zero below when there are fewer input lines than decoders
int num_per_thread_int = Math.max(1, (int) num_per_thread_double);
if (logger.isLoggable(Level.INFO))
logger.info("num_per_thread_double: " + num_per_thread_double
+ "; num_per_thread_int: " + num_per_thread_int);
//==== Initialize all threads and their input files
int decoder_i = 1;
String cur_test_file = JoshuaConfiguration.parallel_files_prefix + ".test." + decoder_i;
String cur_nbest_file = JoshuaConfiguration.parallel_files_prefix + ".nbest." + decoder_i;
BufferedWriter t_writer_test =
FileUtility.getWriteFileStream(cur_test_file);
int sent_id = 0;
int start_sent_id = sent_id;
LineReader testReader = new LineReader(testFile);
try {
for (String cn_sent : testReader) {
sent_id++;
t_writer_test.write(cn_sent);
t_writer_test.newLine();
//finalize the symbol table before starting multiple threads, to avoid synchronization among the threads
{
String words[] = Regex.spaces.split(cn_sent);
this.symbolTable.addTerminals(words); // TODO
}
//logger.info("sent_id="+sent_id);
// any leftover lines from an uneven split are included in the last file
// prepare the current job
if (0 != sent_id
&& decoder_i < JoshuaConfiguration.num_parallel_decoders
&& sent_id % num_per_thread_int == 0
) {
t_writer_test.flush();
t_writer_test.close();
DecoderThread pdecoder = new DecoderThread(
this.grammarFactories,
this.hasLanguageModel,
this.featureFunctions,
this.stateComputers,
this.symbolTable,
cur_test_file,
cur_nbest_file,
null,
start_sent_id);
this.parallelThreads[decoder_i-1] = pdecoder;
// prepare next job
start_sent_id = sent_id;
decoder_i++;
cur_test_file = JoshuaConfiguration.parallel_files_prefix + ".test." + decoder_i;
cur_nbest_file = JoshuaConfiguration.parallel_files_prefix + ".nbest." + decoder_i;
t_writer_test = FileUtility.getWriteFileStream(cur_test_file);
}
}
} finally {
testReader.close();
//==== prepare the last job
t_writer_test.flush();
t_writer_test.close();
}
DecoderThread pdecoder = new DecoderThread(
this.grammarFactories,
this.hasLanguageModel,
this.featureFunctions,
this.stateComputers,
this.symbolTable,
cur_test_file,
cur_nbest_file,
null,
start_sent_id);
this.parallelThreads[decoder_i-1] = pdecoder;
// End initializing threads and their files
//==== run all the jobs
for (int i = 0; i < this.parallelThreads.length; i++) {
if (logger.isLoggable(Level.INFO))
logger.info("##############start thread " + i);
this.parallelThreads[i].start();
}
//==== wait for the threads to finish
for (int i = 0; i < this.parallelThreads.length; i++) {
try {
this.parallelThreads[i].join();
} catch (InterruptedException e) {
if (logger.isLoggable(Level.WARNING))
logger.warning("decoder thread " + i + " was interrupted");
}
}
//==== merge the nbest files, and remove tmp files
BufferedWriter nbestWriter = FileUtility.getWriteFileStream(nbestFile);
BufferedWriter itemsWriter = null;
if (JoshuaConfiguration.save_disk_hg) {
itemsWriter = FileUtility.getWriteFileStream(nbestFile + ".hg.items");
}
for (DecoderThread decoder : this.parallelThreads) {
//merge nbest
LineReader nbestReader = new LineReader(decoder.nbestFile);
try {
for (String sent : nbestReader) {
nbestWriter.write(sent);
nbestWriter.newLine();
}
} finally {
nbestReader.close();
}
//remove the temporary nbest file
FileUtility.deleteFile(decoder.nbestFile);
FileUtility.deleteFile(decoder.testFile);
//merge hypergraph items
if (JoshuaConfiguration.save_disk_hg) {
LineReader itemReader = new LineReader(decoder.nbestFile + ".hg.items");
try {
for (String sent : itemReader) {
itemsWriter.write(sent);
itemsWriter.newLine();
}
} finally {
itemReader.close();
decoder.hypergraphSerializer.closeItemsWriter();
}
//remove the temporary item file
FileUtility.deleteFile(decoder.nbestFile + ".hg.items");
}
}
nbestWriter.flush();
nbestWriter.close();
if (JoshuaConfiguration.save_disk_hg) {
itemsWriter.flush();
itemsWriter.close();
}
//merge the grammar rules for the disk hypergraphs
if (JoshuaConfiguration.save_disk_hg) {
HashMap<Integer,Integer> tblDone = new HashMap<Integer,Integer>();
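// tblDone presumably tracks the rule IDs that have already been written,
// so rules shared across the per-thread hypergraphs are emitted only once.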
BufferedWriter rulesWriter = FileUtility.getWriteFileStream(nbestFile + ".hg.rules");
for (DecoderThread decoder : this.parallelThreads) {
decoder.hypergraphSerializer.writeRulesParallel(rulesWriter, tblDone);
//decoder.hypergraphSerializer.closeReaders();
}
rulesWriter.flush();
rulesWriter.close();
}
}
}