Package joshua.decoder

Source Code of joshua.decoder.DecoderFactory

/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.decoder;

import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.ff.state_maintenance.StateComputer;
import joshua.decoder.ff.tm.GrammarFactory;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.discriminative.FileUtilityOld;
import joshua.util.FileUtility;
import joshua.util.Regex;
import joshua.util.io.LineReader;

/**
* this class implements:
* (1) parallel decoding: split the test file, initiate DecoderThread,
*     wait and merge the decoding results
* (2) non-parallel decoding is a special case of parallel decoding
*
* @author Zhifei Li, <zhifei.work@gmail.com>
* @version $LastChangedDate: 2010-01-29 16:04:50 -0600 (Fri, 29 Jan 2010) $
*/
public class DecoderFactory {
  private List<GrammarFactory>  grammarFactories = null;
  private List<FeatureFunction> featureFunctions = null;
  private List<StateComputer> stateComputers;
  private boolean                    hasLanguageModel = false;
 
  /**
   * Shared symbol table for source language terminals, target
   * language terminals, and shared nonterminals.
   */
  private SymbolTable symbolTable = null;
 
  private DecoderThread[] parallelThreads;
 
  private static final Logger logger =
    Logger.getLogger(DecoderFactory.class.getName());
 
 
  public DecoderFactory(List<GrammarFactory> grammarFactories, boolean hasLanguageModel, List<FeatureFunction> featureFunctions,
      List<StateComputer> stateComputers, SymbolTable symbolTable) {
    this.grammarFactories = grammarFactories;
    this.hasLanguageModel = hasLanguageModel;
    this.featureFunctions = featureFunctions;
    this.stateComputers = stateComputers;
    this.symbolTable      = symbolTable;
  }
 
 
  /**
   * This is the public-facing method to decode a set of
   * sentences. This automatically detects whether we should
   * run the decoder in parallel or not.
   */
  public void decodeTestSet(String testFile, String nbestFile, String oracleFile) {
    try {
      if (JoshuaConfiguration.num_parallel_decoders == 1) {
        DecoderThread pdecoder = new DecoderThread(
          this.grammarFactories, this.hasLanguageModel,
          this.featureFunctions, this.stateComputers, this.symbolTable,
          testFile, nbestFile, oracleFile, 0);
       
        // do not call *start*; this way we stay in the current main thread
        pdecoder.decodeTestFile();
       
        if (JoshuaConfiguration.save_disk_hg) {
          pdecoder.hypergraphSerializer.writeRulesNonParallel(
            nbestFile + ".hg.rules");
        }
      } else {
        if (JoshuaConfiguration.use_remote_lm_server) { // TODO
          throw new IllegalArgumentException("You cannot run parallel decoder and remote LM server together");
        }
        if (null != oracleFile) {
          logger.warning("Parallel decoding appears not to support oracle decoding, but you passed an oracle file in.");
        }
        runParallelDecoder(testFile, nbestFile);
      }
    } catch (IOException e) {
      e.printStackTrace();
    }
  }
 
  /**decode a single sentence, and returns a hypergraph
   * */
  public HyperGraph getHyperGraphForSentence(String sentence)
  {
    try {
      DecoderThread pdecoder = new DecoderThread(
        this.grammarFactories, this.hasLanguageModel,
        this.featureFunctions, this.stateComputers, this.symbolTable,
        sentence, null, null, 0);
      return pdecoder.getHyperGraph(sentence);
    }
    catch (IOException e) {
      e.printStackTrace();
    }
    return null;
  }
 
 
  // BUG: this kind of file munging isn't going to work with generalized SegmentFileParser
  private void runParallelDecoder(String testFile, String nbestFile)
  throws IOException {
    {
      /**if a segment corresponds to a single line in the input file, we can use parallel decoding as we can simply split the file.
       * This is true for the input files that will be processed by PlainSegmentParser and PlainSegmentParser.
       * But, if a segment corresponds to multiple lines (as in the case of xml file input, which has
       * manual constraints specified in the xml format), we cannot do parallel decoding as we do below.
       * This is true for the file that will be handled by the SAXSegmentParser
       **/
      //TODO: what about lattice input? is it a single input?
     
      final String className = JoshuaConfiguration.segmentFileParserClass;

      if (className==null || "PlainSegmentParser".equals(className)) {
        // Do nothing, this one is okay.
      } else if ("PlainSegmentParser".equals(className)) {
        logger.warning("Using HackishSegmentParser with parallel decoding may cause sentence IDs to become garbled");
      } else {
        throw new IllegalArgumentException("Parallel decoding is currently not supported with SegmentFileParsers other than PlainSegmentParser or HackishSegmentParser");
      }
    }
   
   
   
    this.parallelThreads = new DecoderThread[JoshuaConfiguration.num_parallel_decoders];
   
    //==== compute number of lines for each decoder
    int n_lines = 0; {
      LineReader testReader = new LineReader(testFile);
      try {
        n_lines = testReader.countLines();
      } finally { testReader.close(); }
    }
   
    double num_per_thread_double = n_lines * 1.0 / JoshuaConfiguration.num_parallel_decoders;
    int    num_per_thread_int    = (int) num_per_thread_double;
   
    if (logger.isLoggable(Level.INFO))
      logger.info("num_per_file_double: " + num_per_thread_double
        + "; num_per_file_int: " + num_per_thread_int);
   
   
    //==== Initialize all threads and their input files
    int decoder_i = 1;
    String cur_test_file  = JoshuaConfiguration.parallel_files_prefix + ".test." + decoder_i;
    String cur_nbest_file = JoshuaConfiguration.parallel_files_prefix + ".nbest." + decoder_i;
    BufferedWriter t_writer_test = 
      FileUtility.getWriteFileStream(cur_test_file);
    int sent_id       = 0;
    int start_sent_id = sent_id;
   
    LineReader testReader = new LineReader(testFile);
    try {
      for (String cn_sent : testReader) {
        sent_id++;
        t_writer_test.write(cn_sent);
        t_writer_test.newLine();
       
        //make the Symbol table finalized before running multiple threads, this is to avoid synchronization among threads
        {
          String words[] = Regex.spaces.split(cn_sent);
          this.symbolTable.addTerminals(words); // TODO
        }
        //logger.info("sent_id="+sent_id);
        // we will include all additional lines into last file
        //prepare current job
        if (0 != sent_id
        && decoder_i < JoshuaConfiguration.num_parallel_decoders
        && sent_id % num_per_thread_int == 0
        ) {
          t_writer_test.flush();
          t_writer_test.close();
         
          DecoderThread pdecoder = new DecoderThread(
            this.grammarFactories,
            this.hasLanguageModel,
            this.featureFunctions,
            this.stateComputers,
            this.symbolTable,
            cur_test_file,
            cur_nbest_file,
            null,
            start_sent_id);
          this.parallelThreads[decoder_i-1] = pdecoder;
         
          // prepare next job
          start_sent_id  = sent_id;
          decoder_i++;
          cur_test_file  = JoshuaConfiguration.parallel_files_prefix + ".test." + decoder_i;
          cur_nbest_file = JoshuaConfiguration.parallel_files_prefix + ".nbest." + decoder_i;
          t_writer_test  = FileUtility.getWriteFileStream(cur_test_file);
        }
      }
    }finally {
      testReader.close();
     
      //==== prepare the the last job
      t_writer_test.flush();
      t_writer_test.close();
    }
   
    DecoderThread pdecoder = new DecoderThread(
      this.grammarFactories,
      this.hasLanguageModel,
      this.featureFunctions,
      this.stateComputers,
      this.symbolTable,
      cur_test_file,
      cur_nbest_file,
      null,
      start_sent_id);
    this.parallelThreads[decoder_i-1] = pdecoder;
   
    // End initializing threads and their files
     
   
    //==== run all the jobs
    for (int i = 0; i < this.parallelThreads.length; i++) {
      if (logger.isLoggable(Level.INFO))
        logger.info("##############start thread " + i);
      this.parallelThreads[i].start();
    }
   
    //==== wait for the threads finish
    for (int i = 0; i < this.parallelThreads.length; i++) {
      try {
        this.parallelThreads[i].join();
      } catch (InterruptedException e) {
        if (logger.isLoggable(Level.WARNING))
          logger.warning("thread is interupted for server " + i);
      }
    }
   
    //==== merge the nbest files, and remove tmp files
    BufferedWriter nbestWriter =  FileUtility.getWriteFileStream(nbestFile);
    BufferedWriter itemsWriter = null;
    if (JoshuaConfiguration.save_disk_hg) {
      itemsWriter = FileUtility.getWriteFileStream(nbestFile + ".hg.items");
    }
    for (DecoderThread decoder : this.parallelThreads) {
      //merge nbest
      LineReader nbestReader = new LineReader(decoder.nbestFile);
      try {
        for (String sent : nbestReader) {
          nbestWriter.write(sent);
          nbestWriter.newLine();
        }
      } finally {
        nbestReader.close();
      }

      //remove the tem nbest file
      FileUtility.deleteFile(decoder.nbestFile);
      FileUtility.deleteFile(decoder.testFile);
     
      //merge hypergrpah items
      if (JoshuaConfiguration.save_disk_hg) {
        LineReader itemReader = new LineReader(decoder.nbestFile + ".hg.items");
        try {
          for (String sent : itemReader) {         
            itemsWriter.write(sent);
            itemsWriter.newLine();
          }
        } finally {
          itemReader.close();
          decoder.hypergraphSerializer.closeItemsWriter();
        }
        //remove the tem item file
        FileUtility.deleteFile(decoder.nbestFile + ".hg.items");
      }
    }
    nbestWriter.flush();
    nbestWriter.close();
   
    if (JoshuaConfiguration.save_disk_hg) {
      itemsWriter.flush();
      itemsWriter.close();
    }
   
    //merge the grammar rules for disk hyper-graphs
    if (JoshuaConfiguration.save_disk_hg) {
      HashMap<Integer,Integer> tblDone = new HashMap<Integer,Integer>();
      BufferedWriter rulesWriter = FileUtility.getWriteFileStream(nbestFile + ".hg.rules");
      for (DecoderThread decoder : this.parallelThreads) {
        decoder.hypergraphSerializer.writeRulesParallel(rulesWriter, tblDone);
        //decoder.hypergraphSerializer.closeReaders();
      }
      rulesWriter.flush();
      rulesWriter.close();
    }
  }
}
TOP

Related Classes of joshua.decoder.DecoderFactory

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.