// Package joshua.prefix_tree
//
// Source Code of joshua.prefix_tree.ExtractRules

/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation; either version 2.1 of the License, or
* (at your option) any later version.
*
* This library is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
* License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this library; if not, write to the Free Software Foundation,
* Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*/
package joshua.prefix_tree;

import java.io.File;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.PrintStream;
import java.io.UnsupportedEncodingException;
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;
import java.util.logging.Level;
import java.util.logging.Logger;

import joshua.corpus.Corpus;
import joshua.corpus.alignment.AlignmentGrids;
import joshua.corpus.alignment.Alignments;
import joshua.corpus.alignment.mm.MemoryMappedAlignmentGrids;
import joshua.corpus.mm.MemoryMappedCorpusArray;
import joshua.corpus.suffix_array.FrequentPhrases;
import joshua.corpus.suffix_array.ParallelCorpusGrammarFactory;
import joshua.corpus.suffix_array.SuffixArrayFactory;
import joshua.corpus.suffix_array.Suffixes;
import joshua.corpus.suffix_array.mm.MemoryMappedSuffixArray;
import joshua.corpus.vocab.SymbolTable;
import joshua.corpus.vocab.Vocabulary;
import joshua.decoder.JoshuaConfiguration;
import joshua.util.Cache;
import joshua.util.io.BinaryIn;


/**
* Main program to extract hierarchical phrase-based statistical
* translation rules from an aligned parallel corpus using the
* suffix array techniques of Lopez (2008).
*
* @author Lane Schwartz
* @version $LastChangedDate:2008-11-13 13:13:31 -0600 (Thu, 13 Nov 2008) $
* @see "Lopez (2008)"
*/
public class ExtractRules {

  /** Logger for this class. */
  private static final Logger logger =
    Logger.getLogger(ExtractRules.class.getName());

  private String encoding = "UTF-8";

  private String outputFile = "";
 
  private String sourceFileName = "";
  private String sourceSuffixesFileName = "";
 
  private String targetFileName = "";
  private String targetSuffixesFileName = "";
 
  private String alignmentsFileName = "";
  private String commonVocabFileName = "";
 
  private String lexCountsFileName = "";
 
  private String testFileName = "";
  private String frequentPhrasesFileName = "";
 
  private int cacheSize = Cache.DEFAULT_CAPACITY;
 
  private int maxPhraseSpan = 10;
  private int maxPhraseLength = 10;
  private int maxNonterminals = 2;
  private int minNonterminalSpan = 2;
 
  private boolean sentenceInitialX = true;
  private boolean sentenceFinalX = true;
  private boolean edgeXViolates = true;
 
  private boolean requireTightSpans = true;
 
  private boolean binaryCorpus = false;
 
  private String alignmentsType = "AlignmentGrids";
 
  private boolean keepTree = true;
  private int ruleSampleSize = 300;
  private boolean printPrefixTree = false;
 
  private int maxTestSentences = Integer.MAX_VALUE;
  private int startingSentence = 1;
 
  private boolean usePrecomputedFrequentPhrases = true;
 
  public ExtractRules() {
  }
 
  public void setUsePrecomputedFrequentPhrases(boolean usePrecomputedFrequentPhrases) {
    this.usePrecomputedFrequentPhrases = usePrecomputedFrequentPhrases;
  }
 
  public void setSourceFileName(String sourceFileName) {
    this.sourceFileName = sourceFileName;
  }
 
  public void setTargetFileName(String targetFileName) {
    this.targetFileName = targetFileName;
  }
 
  public void setAlignmentsFileName(String alignmentsFileName) {
    this.alignmentsFileName = alignmentsFileName;
  }
 
  public void setLexCountsFileName(String lexCountsFileName) {
    this.lexCountsFileName = lexCountsFileName;
  }
 
  public void setStartingSentence(int startingSentence) {
    this.startingSentence = startingSentence;
  }
 
  public void setMaxPhraseSpan(int maxPhraseSpan) {
    this.maxPhraseSpan = maxPhraseSpan;
  }
 
  public void setMaxPhraseLength(int maxPhraseLength) {
    this.maxPhraseLength = maxPhraseLength;
  }
 
  public void setMaxNonterminals(int maxNonterminals) {
    this.maxNonterminals = maxNonterminals;
  }
 
  public void setMinNonterminalSpan(int minNonterminalSpan) {
    this.minNonterminalSpan = minNonterminalSpan;
  }
 
  public void setCacheSize(int cacheSize) {
    this.cacheSize = cacheSize;
  }
 
  public void setMaxTestSentences(int maxTestSentences) {
    this.maxTestSentences = maxTestSentences;
  }
 
  public void setJoshDir(String joshDir) {

    this.sourceFileName = joshDir + File.separator + "source.corpus";
    this.targetFileName = joshDir + File.separator + "target.corpus";
   
    this.commonVocabFileName = joshDir + File.separator + "common.vocab";

    this.lexCountsFileName = joshDir + File.separator + "lexicon.counts";

    this.sourceSuffixesFileName = joshDir + File.separator + "source.suffixes";
    this.targetSuffixesFileName = joshDir + File.separator + "target.suffixes";
   
    this.alignmentsFileName = joshDir + File.separator + "alignment.grids";
    this.alignmentsType = "MemoryMappedAlignmentGrids";
   
    this.frequentPhrasesFileName = joshDir + File.separator + "frequentPhrases";
   
    this.binaryCorpus = true;
  }
 
  public void setTestFile(String testFileName) {
    this.testFileName = testFileName;
  }
 
  public void setOutputFile(String outputFile) {
    this.outputFile = outputFile;
  }
 
  public void setEncoding(String encoding) {
    this.encoding = encoding;
  }
 
  public void setSentenceInitialX(boolean sentenceInitialX) {
    this.sentenceInitialX = sentenceInitialX;
  }
 
  public void setSentenceFinalX(boolean sentenceFinalX) {
    this.sentenceFinalX = sentenceFinalX;
  }
 
  public void setEdgeXViolates(boolean edgeXViolates) {
    this.edgeXViolates = edgeXViolates;
  }
 
  public void setRequireTightSpans(boolean requireTightSpans) {
    this.requireTightSpans = requireTightSpans;
  }
   
  public void setKeepTree(boolean keepTree) {
    this.keepTree = keepTree;
  }
 
  public void setRuleSampleSize(int ruleSampleSize) {
    this.ruleSampleSize = ruleSampleSize;
  }
 
  public void setPrintPrefixTree(boolean printPrefixTree) {
    this.printPrefixTree = printPrefixTree;
  }
 
 
 
  public ParallelCorpusGrammarFactory getGrammarFactory() throws IOException, ClassNotFoundException {
   
    ////////////////////////////////
    // Common vocabulary          //
    ////////////////////////////////
    if (logger.isLoggable(Level.INFO)) logger.info("Constructing empty common vocabulary");
    Vocabulary commonVocab = new Vocabulary();
    int numSourceWords, numSourceSentences;
    int numTargetWords, numTargetSentences;
    String binaryCommonVocabFileName = this.commonVocabFileName;
    if (binaryCorpus) {
      if (logger.isLoggable(Level.INFO)) logger.info("Initializing common vocabulary from binary file " + binaryCommonVocabFileName);
      ObjectInput in = BinaryIn.vocabulary(binaryCommonVocabFileName);
      commonVocab.readExternal(in);
     
      numSourceWords = Integer.MIN_VALUE;
      numSourceSentences = Integer.MIN_VALUE;
     
      numTargetWords = Integer.MIN_VALUE;
      numTargetSentences = Integer.MIN_VALUE;
    } else {
      if (logger.isLoggable(Level.INFO)) logger.info("Initializing common vocabulary with source corpus " + sourceFileName);
      int[] sourceWordsSentences = Vocabulary.initializeVocabulary(sourceFileName, commonVocab, true);
      numSourceWords = sourceWordsSentences[0];
      numSourceSentences = sourceWordsSentences[1];
     
      if (logger.isLoggable(Level.INFO)) logger.info("Initializing common vocabulary with target corpus " + sourceFileName);     
      int[] targetWordsSentences = Vocabulary.initializeVocabulary(targetFileName, commonVocab, true);
      numTargetWords = targetWordsSentences[0];
      numTargetSentences = targetWordsSentences[1];
    }
 
   
   
    //////////////////////////////////
    // Source language corpus array //
    //////////////////////////////////
    final Corpus sourceCorpusArray;
    if (binaryCorpus) {
      if (logger.isLoggable(Level.INFO)) logger.info("Constructing memory mapped source language corpus array.");
      sourceCorpusArray = new MemoryMappedCorpusArray(commonVocab, sourceFileName);
    } else {
      if (logger.isLoggable(Level.INFO)) logger.info("Constructing source language corpus array.");
      sourceCorpusArray = SuffixArrayFactory.createCorpusArray(sourceFileName, commonVocab, numSourceWords, numSourceSentences);
    }

    //////////////////////////////////
    // Source language suffix array //
    //////////////////////////////////
    Suffixes sourceSuffixArray;
    String binarySourceSuffixArrayFileName = sourceSuffixesFileName;
    if (binaryCorpus) {
      if (logger.isLoggable(Level.INFO)) logger.info("Constructing source language suffix array from binary file " + binarySourceSuffixArrayFileName);
      sourceSuffixArray = new MemoryMappedSuffixArray(binarySourceSuffixArrayFileName, sourceCorpusArray, cacheSize);
    } else {
      if (logger.isLoggable(Level.INFO)) logger.info("Constructing source language suffix array from source corpus.");
      sourceSuffixArray = SuffixArrayFactory.createSuffixArray(sourceCorpusArray, cacheSize);
    }
   
   

       
    //////////////////////////////////
    // Target language corpus array //
    //////////////////////////////////
    final Corpus targetCorpusArray;
    if (binaryCorpus) {
      if (logger.isLoggable(Level.INFO)) logger.info("Constructing memory mapped target language corpus array.");
      targetCorpusArray = new MemoryMappedCorpusArray(commonVocab, targetFileName);
    } else {
      if (logger.isLoggable(Level.INFO)) logger.info("Constructing target language corpus array.");
      targetCorpusArray = SuffixArrayFactory.createCorpusArray(targetFileName, commonVocab, numTargetWords, numTargetSentences);
    }
   

    //////////////////////////////////
    // Target language suffix array //
    //////////////////////////////////
    Suffixes targetSuffixArray;
    String binaryTargetSuffixArrayFileName = targetSuffixesFileName;
    if (binaryCorpus) {
      if (logger.isLoggable(Level.INFO)) logger.info("Constructing target language suffix array from binary file " + binaryTargetSuffixArrayFileName);
      targetSuffixArray = new MemoryMappedSuffixArray(binaryTargetSuffixArrayFileName, targetCorpusArray, cacheSize);
    } else {
      if (logger.isLoggable(Level.INFO)) logger.info("Constructing target language suffix array from target corpus.");
      targetSuffixArray = SuffixArrayFactory.createSuffixArray(targetCorpusArray, cacheSize);
    }

    int trainingSize = sourceCorpusArray.getNumSentences();
    if (trainingSize != targetCorpusArray.getNumSentences()) {
      throw new RuntimeException("Source and target corpora have different number of sentences. This is bad.");
    }
   
   
    /////////////////////
    // Alignment data  //
    /////////////////////
    if (logger.isLoggable(Level.INFO)) logger.info("Reading alignment data.");
    final Alignments alignments;
    if ("AlignmentArray".equals(alignmentsType)) {
      if (logger.isLoggable(Level.INFO)) logger.info("Using AlignmentArray");
      alignments = SuffixArrayFactory.createAlignments(alignmentsFileName, sourceSuffixArray, targetSuffixArray);
    } else if ("AlignmentGrids".equals(alignmentsType) || "AlignmentsGrid".equals(alignmentsType)) {
      if (logger.isLoggable(Level.INFO)) logger.info("Using AlignmentGrids");
      alignments = new AlignmentGrids(new Scanner(new File(alignmentsFileName)), sourceCorpusArray, targetCorpusArray, trainingSize, requireTightSpans);
    } else if ("MemoryMappedAlignmentGrids".equals(alignmentsType)) {
      if (logger.isLoggable(Level.INFO)) logger.info("Using MemoryMappedAlignmentGrids");
      alignments = new MemoryMappedAlignmentGrids(alignmentsFileName, sourceCorpusArray, targetCorpusArray);
    } else {
      alignments = null;
      logger.severe("Invalid alignment type: " + alignmentsType);
      System.exit(-1);
    }
   
    Map<Integer,String> ntVocab = new HashMap<Integer,String>();
    ntVocab.put(SymbolTable.X, SymbolTable.X_STRING);
   
    //////////////////////
    // Lexical Probs    //
    //////////////////////   

//    final LexProbs lexProbs;
    String binaryLexCountsFilename = this.lexCountsFileName;
   
    //////////////////////
    // Frequent Phrases //
    //////////////////////
    if (usePrecomputedFrequentPhrases) {
      logger.info("Reading precomputed frequent phrases from disk");
      FrequentPhrases frequentPhrases = new FrequentPhrases(sourceSuffixArray, frequentPhrasesFileName);
      frequentPhrases.cacheInvertedIndices();
    }


    logger.info("Constructing grammar factory from parallel corpus");
    ParallelCorpusGrammarFactory parallelCorpus;
    if (binaryCorpus) {
      if (logger.isLoggable(Level.INFO)) logger.info("Constructing lexical translation probabilities from binary file " + binaryLexCountsFilename);
      parallelCorpus = new ParallelCorpusGrammarFactory(sourceSuffixArray, targetSuffixArray, alignments, null, binaryLexCountsFilename, ruleSampleSize, maxPhraseSpan, maxPhraseLength, maxNonterminals, minNonterminalSpan, JoshuaConfiguration.phrase_owner, JoshuaConfiguration.default_non_terminal, JoshuaConfiguration.oovFeatureCost);
    } else {
      if (logger.isLoggable(Level.INFO)) logger.info("Constructing lexical translation probabilities from parallel corpus");
      parallelCorpus = new ParallelCorpusGrammarFactory(sourceSuffixArray, targetSuffixArray, alignments, null, ruleSampleSize, maxPhraseSpan, maxPhraseLength, maxNonterminals, minNonterminalSpan, Float.MIN_VALUE, JoshuaConfiguration.phrase_owner, JoshuaConfiguration.default_non_terminal, JoshuaConfiguration.oovFeatureCost);
    }
    return parallelCorpus;
  }


  public void execute() throws IOException, ClassNotFoundException  {

    // Set System.out and System.err to use the provided character encoding
    try {
      System.setOut(new PrintStream(System.out, true, "UTF-8"));
      System.setErr(new PrintStream(System.err, true, "UTF-8"));
    } catch (UnsupportedEncodingException e1) {
      System.err.println("UTF-8 is not a valid encoding; using system default encoding for System.out and System.err.");
    } catch (SecurityException e2) {
      System.err.println("Security manager is configured to disallow changes to System.out or System.err; using system default encoding.");
    }
   
    PrintStream out;
    if ("-".equals(this.outputFile)) {
      out = System.out;
      logger.info("Rules will be written to standard out");
    } else {
      out = new PrintStream(outputFile,"UTF-8");
      logger.info("Rules will be written to " + outputFile);
    }
   
    ParallelCorpusGrammarFactory parallelCorpus = this.getGrammarFactory();
   
    logger.info("Getting symbol table");
    SymbolTable sourceVocab = parallelCorpus.getSourceCorpus().getVocabulary();
   
    int lineNumber = 0;
    boolean oneTreePerSentence = ! this.keepTree;
   
    logger.info("Will read test sentences from " + testFileName);
    Scanner testFileScanner = new Scanner(new File(testFileName), encoding);
   
    logger.info("Read test sentences from " + testFileName);
    PrefixTree prefixTree = null;
    while (testFileScanner.hasNextLine() && (lineNumber-startingSentence+1)<maxTestSentences) {

      String line = testFileScanner.nextLine();
      lineNumber++;
      if (lineNumber < startingSentence) continue;
     
      int[] words = sourceVocab.getIDs(line);
     
      if (oneTreePerSentence || null==prefixTree)
      {
//        prefixTree = new PrefixTree(sourceSuffixArray, targetCorpusArray, alignments, sourceSuffixArray.getVocabulary(), lexProbs, ruleExtractor, maxPhraseSpan, maxPhraseLength, maxNonterminals, minNonterminalSpan);
        if (logger.isLoggable(Level.INFO)) logger.info("Constructing new prefix tree");
        Node.resetNodeCounter();
        prefixTree = new PrefixTree(parallelCorpus);
        prefixTree.setPrintStream(out);
        prefixTree.sentenceInitialX = this.sentenceInitialX;
        prefixTree.sentenceFinalX   = this.sentenceFinalX;
        prefixTree.edgeXMayViolatePhraseSpan = this.edgeXViolates;
      }
      try {
        if (logger.isLoggable(Level.INFO)) logger.info("Processing source line " + lineNumber + ": " + line);
        prefixTree.add(words);
      } catch (OutOfMemoryError e) {
        logger.warning("Out of memory - attempting to clear cache to free space");
        parallelCorpus.getSuffixArray().getCachedHierarchicalPhrases().clear();
//        targetSuffixArray.getCachedHierarchicalPhrases().clear();
        prefixTree = null;
        System.gc();
        logger.info("Cleared cache and collected garbage. Now attempting to re-construct prefix tree...");
//        prefixTree = new PrefixTree(sourceSuffixArray, targetCorpusArray, alignments, sourceSuffixArray.getVocabulary(), lexProbs, ruleExtractor, maxPhraseSpan, maxPhraseLength, maxNonterminals, minNonterminalSpan);
        Node.resetNodeCounter();
        prefixTree = new PrefixTree(parallelCorpus);
        prefixTree.setPrintStream(out);
        prefixTree.sentenceInitialX = this.sentenceInitialX;
        prefixTree.sentenceFinalX   = this.sentenceFinalX;
        prefixTree.edgeXMayViolatePhraseSpan = this.edgeXViolates;
        if (logger.isLoggable(Level.INFO)) logger.info("Re-processing source line " + lineNumber + ": " + line);
        prefixTree.add(words);
      }
     
      if (printPrefixTree) {
        System.out.println(prefixTree.toString());
      }
   
//      if (printRules) {
//        if (logger.isLoggable(Level.FINE)) logger.fine("Outputting rules for source line: " + line);
//
//        for (Rule rule : prefixTree.getAllRules()) {
//          String ruleString = rule.toString(ntVocab, sourceVocab, targetVocab);
//          if (logger.isLoggable(Level.FINEST)) logger.finest("Rule: " + ruleString);
//          out.println(ruleString);
//        }
//      }
     
//      if (logger.isLoggable(Level.FINEST)) logger.finest(lexProbs.toString());
     
   
    }
   
    logger.info("Done extracting rules for file " + testFileName);
   
  }
 

  /**
   * @param args
   * @throws IOException
   * @throws ClassNotFoundException
   */
  public static void main(String[] args) throws IOException, ClassNotFoundException {
   
    if (args.length==3) {
      ExtractRules extractRules = new ExtractRules();
      extractRules.setJoshDir(args[0]);
      extractRules.setOutputFile(args[1]);
      extractRules.setTestFile(args[2]);
      extractRules.execute();
    } else if (args.length==5) {
      ExtractRules extractRules = new ExtractRules();
      extractRules.setSourceFileName(args[0]);
      extractRules.setTargetFileName(args[1]);
      extractRules.setAlignmentsFileName(args[2]);
      extractRules.setOutputFile(args[3]);
      extractRules.setTestFile(args[4]);
      extractRules.execute();
    } else {
      System.err.println("Usage: joshDir outputRules testFile");
      System.err.println("---------------OR------------------");
      System.err.println("Usage: source.txt target.txt alignments.txt outputRules testFile");
    }
   
  }

}
// TOP
//
// Related Classes of joshua.prefix_tree.ExtractRules
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.