package joshua.discriminative.training.contrastive_estimation;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.NbestMinRiskReranker;
import joshua.decoder.ff.tm.Grammar;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.ff.tm.Trie;
import joshua.decoder.ff.tm.hiero.MemoryBasedBatchGrammar;
public class ConfusionDeriver extends ConfusionExtractor{
ArrayList<Double> featureWeights;
double scale=1.0;
public ConfusionDeriver(SymbolTable symbol_, ArrayList<Double> featureWeights_, double scale_) {
super(symbol_);
this.featureWeights = featureWeights_;
this.scale = scale_;
}
/** Logger for this class. */
//private static final Logger logger = Logger.getLogger(AbstractGrammar.class.getName());
public void deriveConfusionFromGrammar(Grammar gr) {
Trie root = gr.getTrieRoot();
if(root!=null){
deriveConfusionFromTrieNode(root);
}
}
private void deriveConfusionFromTrieNode(Trie node) {
if (node != null) {
if(node.hasRules()) {
List<Rule> rules = node.getRules().getRules();
List<Double> probs = obtainDistribution(rules);
getConfusionFromRules(rules, probs);
}
}
if(node.hasExtensions()){
for (Trie child : node.getExtensions()) {
deriveConfusionFromTrieNode(child);
}
}
}
private List<Double> obtainDistribution(List<Rule> rules){
//== get a list of log-probs
List<Double> res = new ArrayList<Double>();
for(Rule rule : rules){
double logProb = computeRuleLogProb(rule, featureWeights);
res.add(logProb);
}
//== normalize the probs into values within [0,1]
NbestMinRiskReranker.computeNormalizedProbs(res, scale);
return res;
}
private double computeRuleLogProb(Rule rl, ArrayList<Double> weights) {
double logProb = 0.0;
for(int i=0; i<weights.size(); i++){
logProb += weights.get(i) * rl.getFeatureCost(i);
}
return logProb;
}
// ======================== main method ================================
public static void main(String[] args) throws IOException{
if(args.length<2){
System.out.println("Wrong command, it should be: java ConfusionExtractor f_input_grammar f_confusion_grammar");
}
SymbolTable symbolTbl = new BuildinSymbol(null);
String fInputGrammar = args[0];
String fConfusiongrammar = args[1];
ArrayList<Double> featureWeights = new ArrayList<Double>();
for(int i=2; i<args.length; i++){
featureWeights.add( new Double(args[i]) );
}
System.out.println("Feature Weights are: ");
System.out.println(featureWeights);
/*
String f_hypergraphs="C:\\data_disk\\java_work_space\\sf_trunk\\example\\example.nbest.javalm.out.hg.items";
String f_rule_tbl="C:\\data_disk\\java_work_space\\sf_trunk\\example\\example.nbest.javalm.out.hg.rules";
String f_confusion_grammar;
if(itemSpecific)
f_confusion_grammar="C:\\Users\\zli\\Documents\\itemspecific.confusion.grammar";
else
f_confusion_grammar="C:\\Users\\zli\\Documents\\cellspecific.confusion.grammar";
*/
ConfusionDeriver confusionDeriver = new ConfusionDeriver(symbolTbl, featureWeights, 1.0);
Grammar inputGrammar = new MemoryBasedBatchGrammar(
"hiero",
fInputGrammar,
symbolTbl,
"fake",
"fake",
-1,
-1);
confusionDeriver.deriveConfusionFromGrammar(inputGrammar);
confusionDeriver.printConfusionTbl(fConfusiongrammar);
}
// ======================== end ================================
}