package joshua.discriminative.monolingual_parser;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.ff.tm.GrammarFactory;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.TrivialInsideOutside;
/**
 * Decoder thread that runs the E step of EM on every hypergraph it is handed:
 * it runs inside-outside over the hypergraph, reports the resulting log
 * normalization constant to the factory that created it, and reports one
 * posterior value per rule-bearing hyperedge reachable from the goal node.
 *
 * <p>All per-hypergraph scratch state (the visited-node table and the
 * inside-outside state) is cleared after each hypergraph, including when
 * processing fails with an unchecked exception.
 */
public class EMDecoderThread extends MonolingualDecoderThread {

    /** Nodes already visited while traversing the current hypergraph; only the keys matter. */
    private HashMap<HGNode, Boolean> processedItemsTbl = new HashMap<HGNode, Boolean>();

    TrivialInsideOutside insideOutsider = new TrivialInsideOutside();

    double ioScalingFactor = 1.0;//TODO

    /** The factory that created this thread; receives all accumulated statistics. */
    EMDecoderFactory parentFactory = null;

    /**
     * Creates a thread bound to the given factory. Every argument except
     * {@code parentFactory} is forwarded unchanged to the superclass constructor.
     *
     * @param parentFactory factory that receives the collected statistics
     * @throws IOException declared because the superclass constructor is invoked
     */
    public EMDecoderThread(EMDecoderFactory parentFactory, GrammarFactory[] grammarFactories, boolean haveLMModel,
            List<FeatureFunction> featFunctions, List<Integer> defaultNonterminals,
            SymbolTable symbolTable, String testFile, int startSentID) throws IOException {
        super(grammarFactories, haveLMModel, featFunctions,
                defaultNonterminals, symbolTable, testFile,
                startSentID);
        this.parentFactory = parentFactory;
    }

    /**
     * Runs the E step for one hypergraph.
     *
     * @param hyperGraph hypergraph to collect posterior statistics from
     * @param sentenceID not used by this implementation
     */
    @Override
    public void postProcessHypergraph(HyperGraph hyperGraph, int sentenceID) throws IOException {
        //=== run E step here for each hypergraph
        collectPosteriorCount(hyperGraph);
    }

    /**
     * Non-recursive driver: runs inside-outside on {@code hg}, passes the log
     * normalization constant to the parent factory, then walks the hypergraph
     * from its goal node to report per-edge posteriors.
     */
    private void collectPosteriorCount(HyperGraph hg) {
        try {
            insideOutsider.runInsideOutside(hg, 0, 1, ioScalingFactor);//ADD_MODE=0=sum; LOG_SEMIRING=1;
            parentFactory.accumulateDataLogProb(insideOutsider.getLogNormalizationConstant());
            collectHGNodePosteriorCount(hg.goalNode);
        } finally {
            // Always reset, so a failure on this hypergraph cannot leave stale
            // visited-node entries or inside-outside state for the next one.
            clearState();
        }
    }

    /** Resets the inside-outside state and empties the visited-node table. */
    private void clearState() {
        insideOutsider.clearState();
        processedItemsTbl.clear();
    }

    /**
     * Recursive (mutually, with {@link #collectHyperEdgePosteriorCount}): visits
     * {@code it} at most once and processes each of its incoming hyperedges.
     */
    private void collectHGNodePosteriorCount(HGNode it) {
        // put() returns the previous value; values are never null, so a
        // non-null result means this node was already visited.
        if (processedItemsTbl.put(it, true) != null) {
            return;
        }

        //### recursive call on each deduction
        for (HyperEdge dt : it.hyperedges) {
            collectHyperEdgePosteriorCount(it, dt);//deduction-specific feature
        }
    }

    /**
     * Recursive (mutually, with {@link #collectHGNodePosteriorCount}): first
     * descends into the antecedent nodes of {@code dt}, then, if the edge has a
     * rule, reports the edge posterior for that rule to the parent factory.
     *
     * @param parentNode the node whose hyperedge list contains {@code dt}
     * @param dt         the hyperedge to process
     */
    private void collectHyperEdgePosteriorCount(HGNode parentNode, HyperEdge dt) {
        //### recursive call on each ant item
        if (dt.getAntNodes() != null) {
            for (HGNode antNode : dt.getAntNodes()) {
                collectHGNodePosteriorCount(antNode);
            }
        }

        //### deduction-specific operation
        Rule rl = dt.getRule();
        if (rl != null) {
            //TODO: underflow problem
            //TODO: what about OOV rule
            //TODO: synchronization problem
            parentFactory.incrementRulePosteriorProb(rl, insideOutsider.getEdgePosteriorProb(dt, parentNode));
        }
    }

    /** No end-of-run work is needed; statistics are pushed to the factory per hypergraph. */
    @Override
    public void postProcess() throws IOException {
        //do nothing
    }
}