/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.decoder.ff.lm;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.chart_parser.SourcePath;
import joshua.decoder.ff.DefaultStatefulFF;
import joshua.decoder.ff.state_maintenance.DPState;
import joshua.decoder.ff.state_maintenance.NgramDPState;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;
/**
* This class performs the following:
* <ol>
* <li> Gets the additional LM score due to combinations of small
* items into larger ones by using rules
* <li> Gets the LM state
* <li> Gets the left-side LM state estimation score
* </ol>
*
*
* @author Zhifei Li, <zhifei.work@gmail.com>
* @version $LastChangedDate: 2010-01-14 19:15:28 -0600 (Thu, 14 Jan 2010) $
*/
public class LanguageModelFF extends DefaultStatefulFF {

	/** Logger for this class. */
	private static final Logger logger = Logger.getLogger(LanguageModelFF.class.getName());

	/** Sentence-start boundary symbol; registered as a terminal in the constructor. */
	private static final String START_SYM = "<s>";
	private final int START_SYM_ID;

	/** Sentence-end boundary symbol; registered as a terminal in the constructor. */
	private static final String STOP_SYM = "</s>";
	private final int STOP_SYM_ID;

	/*
	 * These must be static (for now) for LMGrammar, but they shouldn't be!
	 * NOTE: every constructor call overwrites the *_ID fields, so multiple
	 * LM features sharing different symbol tables would clobber each other.
	 */
	static String BACKOFF_LEFT_LM_STATE_SYM = "<lzfbo>";
	/** Symbol id of the left-state backoff marker; used for equivalent state. */
	static public int BACKOFF_LEFT_LM_STATE_SYM_ID;
	static String NULL_RIGHT_LM_STATE_SYM = "<lzfrnull>";
	/** Symbol id of the null right-state marker; used for equivalent state. */
	static public int NULL_RIGHT_LM_STATE_SYM_ID;

	/** Whether {@code <s>} and {@code </s>} are added when scoring a complete sentence. */
	private final boolean addStartAndEndSymbol = true;

	/**
	 * N-gram language model. We assume the language model is in ARPA
	 * format for equivalent state:
	 *
	 * <ol>
	 * <li>We assume it is a backoff LM, and a high-order ngram implies
	 *     the corresponding low-order ngram; equivalently, absence of a
	 *     low-order ngram implies absence of the high-order ngram.</li>
	 * <li>For an ngram, existence of a backoff weight implies existence
	 *     of a probability.</li>
	 * <li>Two ways of dealing with low counts:
	 *   <ul>
	 *   <li>SRILM: don't multiply zeros in for unknown words</li>
	 *   <li>Pharaoh: cap at a minimum score exp(-10), including unknown
	 *       words</li>
	 *   </ul>
	 * </li>
	 * </ol>
	 */
	private final NGramLanguageModel lmGrammar;

	/**
	 * We always use this order of ngram, though the LMGrammar may
	 * provide higher order probability.
	 */
	private final int ngramOrder;

	/** Symbol table that maps between Strings and integers. */
	private final SymbolTable symbolTable;

	/**
	 * Creates a language-model feature.
	 *
	 * The constructor registers {@code <s>}, {@code </s>} and the two
	 * LM-state marker symbols as terminals in {@code psymbol}. It also
	 * overwrites the static marker ids shared by all instances.
	 *
	 * @param stateID    id of the DP state this feature maintains (any integer except -1)
	 * @param featID     feature id
	 * @param ngramOrder order of n-grams scored by this feature
	 * @param psymbol    symbol table used for all word ids
	 * @param lmGrammar  underlying n-gram language model
	 * @param weight     feature weight
	 */
	public LanguageModelFF(int stateID, int featID, int ngramOrder, SymbolTable psymbol, NGramLanguageModel lmGrammar, double weight) {
		super(stateID, weight, featID);
		this.ngramOrder = ngramOrder;
		this.lmGrammar = lmGrammar;
		this.symbolTable = psymbol;
		this.START_SYM_ID = psymbol.addTerminal(START_SYM);
		this.STOP_SYM_ID = psymbol.addTerminal(STOP_SYM);
		LanguageModelFF.BACKOFF_LEFT_LM_STATE_SYM_ID = symbolTable.addTerminal(BACKOFF_LEFT_LM_STATE_SYM);
		LanguageModelFF.NULL_RIGHT_LM_STATE_SYM_ID = symbolTable.addTerminal(NULL_RIGHT_LM_STATE_SYM);
		logger.info("LM feature, with an order=" + ngramOrder);
	}

	/**
	 * Returns the LM log-probability added by applying {@code rule} over
	 * {@code antNodes}. Only the n-grams that become complete in the
	 * combined target string are scored.
	 * The span, source path and sentence id are not used by this feature.
	 */
	public double transitionLogP(Rule rule, List<HGNode> antNodes, int spanStart, int spanEnd, SourcePath srcPath, int sentID) {
		return computeTransition(rule.getEnglish(), antNodes);
	}

	/**
	 * Returns the sentence-boundary LM log-probability for a goal item
	 * whose LM state is held by {@code antNode}. The LM state is scored
	 * with {@code <s>} prepended and {@code </s>} appended.
	 */
	public double finalTransitionLogP(HGNode antNode, int spanStart, int spanEnd, SourcePath srcPath, int sentID) {
		return computeFinalTransitionLogP((NgramDPState) antNode.getDPState(this.getStateID()));
	}

	/**
	 * Estimates the LM log-probability of a rule's target side on its own.
	 * All complete ngrams are counted. Incomplete ngrams are also counted
	 * when something will later fill in on their left side.
	 */
	public double estimateLogP(Rule rule, int sentID) {
		return estimateRuleLogProb(rule.getEnglish());
	}

	/**
	 * Estimates the future LM log-probability of a DP state's left context.
	 * Boundary symbols are not added.
	 */
	public double estimateFutureLogP(Rule rule, DPState curDPState, int sentID) {
		// TODO: do not consider <s> and </s>
		boolean addStart = false;
		boolean addEnd = false;
		return estimateStateLogProb((NgramDPState) curDPState, addStart, addEnd);
	}

	/**
	 * Scores the n-grams completed by concatenating the rule's target
	 * words with the LM states of its antecedents.
	 *
	 * When a {@code <bo>} (left-state backoff) marker is seen, additional
	 * backoff weight is added, starting from the non-state words.
	 * The right context of an antecedent is never scored here. It only
	 * replaces the trailing history, because its n-grams are either
	 * already scored or partial.
	 */
	private double computeTransition(int[] enWords, List<HGNode> antNodes) {
		List<Integer> currentNgram = new ArrayList<Integer>();
		double transitionLogP = 0.0;

		for (int c = 0; c < enWords.length; c++) {
			int curID = enWords[c];

			if (symbolTable.isNonterminal(curID)) {
				int index = symbolTable.getTargetNonterminalIndex(curID);
				NgramDPState state = (NgramDPState) antNodes.get(index).getDPState(this.getStateID());
				List<Integer> leftContext = state.getLeftLMStateWords();
				List<Integer> rightContext = state.getRightLMStateWords();
				if (leftContext.size() != rightContext.size()) {
					throw new RuntimeException("computeTransition: left and right contexts have unequal lengths");
				}

				// ---- left context: completes n-grams that straddle the boundary
				for (int i = 0; i < leftContext.size(); i++) {
					int t = leftContext.get(i);
					currentNgram.add(t);

					if (t == BACKOFF_LEFT_LM_STATE_SYM_ID) {
						// Always score <bo>: it contributes additional backoff weight.
						int numAdditionalBackoffWeight = currentNgram.size() - (i + 1); // number of non-state words
						transitionLogP += this.lmGrammar.logProbOfBackoffState(currentNgram, currentNgram.size(), numAdditionalBackoffWeight);
						if (currentNgram.size() == this.ngramOrder) {
							currentNgram.remove(0);
						}
					} else if (currentNgram.size() == this.ngramOrder) {
						// A full n-gram: score it, then slide the window.
						transitionLogP += this.lmGrammar.ngramLogProbability(currentNgram, this.ngramOrder);
						currentNgram.remove(0);
					}
				}

				// ---- right context: becomes the history for following words.
				// It never contributes left-state words, because those are
				// either duplicates or out of range.
				int tSize = currentNgram.size();
				for (int i = 0; i < rightContext.size(); i++) {
					currentNgram.set(tSize - rightContext.size() + i, rightContext.get(i));
				}
			} else { // terminal word
				currentNgram.add(curID);
				if (currentNgram.size() == this.ngramOrder) {
					// A full n-gram: score it, then slide the window.
					transitionLogP += this.lmGrammar.ngramLogProbability(currentNgram, this.ngramOrder);
					currentNgram.remove(0);
				}
			}
		}
		return transitionLogP;
	}

	/**
	 * Scores a complete hypothesis's LM state as a full sentence.
	 * {@code <s>} is prepended before the left context. The partial
	 * n-grams starting from bigrams are scored, and {@code <bo>} markers
	 * contribute backoff weight. Finally {@code </s>} is appended after
	 * the right context and that n-gram is scored.
	 */
	private double computeFinalTransitionLogP(NgramDPState state) {
		double res = 0.0;
		List<Integer> currentNgram = new ArrayList<Integer>();
		List<Integer> leftContext = state.getLeftLMStateWords();
		List<Integer> rightContext = state.getRightLMStateWords();
		if (leftContext.size() != rightContext.size()) {
			throw new RuntimeException(
				"LMModel.compute_equiv_state_final_transition: left and right contexts have unequal lengths");
		}

		// ---- left context
		if (addStartAndEndSymbol) {
			currentNgram.add(START_SYM_ID);
		}
		for (int i = 0; i < leftContext.size(); i++) {
			int t = leftContext.get(i);
			currentNgram.add(t);

			if (t == BACKOFF_LEFT_LM_STATE_SYM_ID) {
				// <bo> contributes additional backoff weight.
				// TODO: may not work when addStartAndEndSymbol is false
				int additionalBackoffWeight = currentNgram.size() - (i + 1);
				res += this.lmGrammar.logProbOfBackoffState(
					currentNgram, currentNgram.size(), additionalBackoffWeight);
			} else if (currentNgram.size() >= 2) {
				// Partial n-gram: score the current word, starting from bigrams.
				res += this.lmGrammar.ngramLogProbability(
					currentNgram, currentNgram.size());
			}

			if (currentNgram.size() == this.ngramOrder) {
				currentNgram.remove(0);
			}
		}

		// ---- right context: switch history. The right context itself is
		// never scored here, because its n-grams are either duplicates or
		// partial. Only the final </s> n-gram is scored.
		if (addStartAndEndSymbol) {
			int tSize = currentNgram.size();
			for (int i = 0; i < rightContext.size(); i++) {
				currentNgram.set(tSize - rightContext.size() + i, rightContext.get(i));
			}
			currentNgram.add(STOP_SYM_ID);
			res += this.lmGrammar.ngramLogProbability(currentNgram, currentNgram.size());
		}
		return res;
	}

	/**
	 * Estimates the LM log-probability of a rule's target words.
	 *
	 * Terminal runs between nonterminals are scored as independent chunks,
	 * including incomplete n-grams, since something will fit on their left.
	 * If the rule starts with {@code <s>}, the first chunk skips scoring
	 * the start symbol itself.
	 *
	 * @param enWords target-side symbol ids of the rule; may be empty
	 * @return the estimated log-probability; 0.0 for an empty target side
	 */
	private double estimateRuleLogProb(int[] enWords) {
		double estimate = 0.0;
		final boolean considerIncompleteNgrams = true;
		List<Integer> words = new ArrayList<Integer>();
		// Guard against an empty target side, which previously threw
		// ArrayIndexOutOfBoundsException.
		boolean skipStart = (enWords.length > 0 && enWords[0] == START_SYM_ID);

		for (int curWrd : enWords) {
			if (symbolTable.isNonterminal(curWrd)) {
				// A nonterminal closes the current terminal chunk.
				estimate += scoreChunkLogP(words, considerIncompleteNgrams, skipStart);
				words.clear();
				skipStart = false;
			} else {
				words.add(curWrd);
			}
		}
		estimate += scoreChunkLogP(words, considerIncompleteNgrams, skipStart);
		return estimate;
	}

	/**
	 * Estimates the LM log-probability of a DP state.
	 *
	 * The left context is scored as a chunk, including incomplete n-grams.
	 * If {@code addEnd} is set, the right context followed by {@code </s>}
	 * is also scored, counting only complete n-grams.
	 *
	 * TODO: this does not work when addStart == true or addEnd == true.
	 */
	private double estimateStateLogProb(NgramDPState state, boolean addStart, boolean addEnd) {
		double res = 0.0;

		List<Integer> leftContext = state.getLeftLMStateWords();
		if (leftContext != null) {
			List<Integer> words = new ArrayList<Integer>();
			if (addStart) {
				words.add(START_SYM_ID);
			}
			words.addAll(leftContext);

			boolean considerIncompleteNgrams = true;
			// Skip scoring the first word only if it is <s> (or there are no words).
			boolean skipStart = words.isEmpty() || words.get(0) == START_SYM_ID;
			res += scoreChunkLogP(words, considerIncompleteNgrams, skipStart);
		}

		// Only when addEnd is true do we get a complete ngram; otherwise all
		// ngrams in the right state are incomplete and nothing is scored.
		if (addEnd) {
			List<Integer> rightContext = state.getRightLMStateWords();
			List<Integer> list = new ArrayList<Integer>(rightContext);
			list.add(STOP_SYM_ID);
			res += scoreChunkLogP(list, false, false);
		}
		return res;
	}

	/**
	 * Scores a contiguous run of words with the underlying LM.
	 *
	 * The scoring start position passed to the LM is chosen as follows:
	 * {@code ngramOrder} when incomplete n-grams are ignored, 2 when the
	 * first word ({@code <s>}) is skipped, and 1 otherwise.
	 *
	 * @return the chunk log-probability, or 0.0 for an empty chunk
	 */
	private double scoreChunkLogP(List<Integer> words, boolean considerIncompleteNgrams, boolean skipStart) {
		if (words.size() <= 0) {
			return 0.0;
		}

		int startIndex;
		if (!considerIncompleteNgrams) {
			startIndex = this.ngramOrder;
		} else if (skipStart) {
			startIndex = 2;
		} else {
			startIndex = 1;
		}
		return this.lmGrammar.sentenceLogProbability(words, this.ngramOrder, startIndex);
	}
}