package joshua.discriminative.variational_decoder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.TrivialInsideOutside;
import joshua.discriminative.feature_related.feature_function.EdgeTblBasedBaselineFF;
import joshua.discriminative.feature_related.feature_function.FeatureTemplateBasedFF;
import joshua.discriminative.semiring_parsing.DefaultSemiringParser;
import joshua.discriminative.semiring_parsing.ExpectationSemiring;
public class InformationMeasures {
// #######################################################################
public static void main(String[] args) {
if(args.length!=4){
System.out.println("Wrong number of parameters, it must have at least four parameters: java NbestMinRiskAnnealer use_shortest_ref f_config gain_factor f_dev_src f_nbest_prefix f_dev_ref1 f_dev_ref2....");
System.exit(1);
}
//long start_time = System.currentTimeMillis();
String testItemsFile=args[0].trim();
String testRulesFile=args[1].trim();
int num_sents=new Integer(args[2].trim());
String f_config=args[3].trim();//be careful with the weights
//set up models
VariationalDecoderConfiguration.readConfigFile(f_config);
SymbolTable symbolTbl = new BuildinSymbol(null);
List<FeatureFunction> featFunctions = new ArrayList<FeatureFunction>();
HashMap<VariationalNgramApproximator, FeatureTemplateBasedFF> approximatorMap = new HashMap<VariationalNgramApproximator, FeatureTemplateBasedFF> ();
VariationalDecoderConfiguration.initializeModels(f_config, symbolTbl, featFunctions, approximatorMap);
double insideOutsideScalingFactor = VariationalDecoderConfiguration.insideoutsideScalingFactor;
List<FeatureFunction> pFeatFunctions = new ArrayList<FeatureFunction>();
List<FeatureFunction> qFeatFunctions = new ArrayList<FeatureFunction>();
for(FeatureFunction ff : featFunctions){
if(ff instanceof EdgeTblBasedBaselineFF){
pFeatFunctions.add(ff);
System.out.println("############### add one feature in P");
}else{
qFeatFunctions.add(ff);//TODO assume all other features go to q
System.out.println("############## add one feature in q");
}
}
double scale = 1.0;
int ngramStateID = 0;
DefaultSemiringParser parserEntropyP = new CrossEntropyOnHG(1, 0, scale, pFeatFunctions, pFeatFunctions);
DefaultSemiringParser parserEntropyQ = new CrossEntropyOnHG(1, 0, scale, qFeatFunctions, qFeatFunctions);
DefaultSemiringParser parserCrossentropyPQ = new CrossEntropyOnHG(1, 0, scale, pFeatFunctions, qFeatFunctions);
DiskHyperGraph diskHG = new DiskHyperGraph(symbolTbl, ngramStateID, true, null); //have model costs stored
diskHG.initRead(testItemsFile, testRulesFile, null);
for(int sent_id=0; sent_id < num_sents; sent_id ++){
System.out.println("#Process sentence " + sent_id);
HyperGraph testHG = diskHG.readHyperGraph();
//################setup the model: including estimation of variational model
//### step-1: run inside-outside
//note, inside and outside will use the transition_cost of each hyperedge, this cost is already linearly interpolated
TrivialInsideOutside insideOutsider = new TrivialInsideOutside();
insideOutsider.runInsideOutside(testHG, 0, 1, insideOutsideScalingFactor);//ADD_MODE=0=sum; LOG_SEMIRING=1;
//### step-2: model extraction based on the definition of Q
for(Map.Entry<VariationalNgramApproximator, FeatureTemplateBasedFF> entry : approximatorMap.entrySet()){
VariationalNgramApproximator approximator = entry.getKey();
FeatureTemplateBasedFF featureFunction = entry.getValue();
HashMap<String, Double> model = approximator.estimateModel(testHG, insideOutsider);
featureFunction.setModel(model);
}
//###############semiring parsing
parserEntropyP.insideEstimationOverHG(testHG);
parserEntropyQ.insideEstimationOverHG(testHG);
parserCrossentropyPQ.insideEstimationOverHG(testHG);
ExpectationSemiring pGoalSemiring = (ExpectationSemiring) parserEntropyP.getGoalSemiringMember(testHG);
ExpectationSemiring qGoalSemiring = (ExpectationSemiring) parserEntropyQ.getGoalSemiringMember(testHG);
ExpectationSemiring pqGoalSemiring = (ExpectationSemiring) parserCrossentropyPQ.getGoalSemiringMember(testHG);
pGoalSemiring.normalizeFactors();
pGoalSemiring.printInfor();
qGoalSemiring.normalizeFactors();
qGoalSemiring.printInfor();
pqGoalSemiring.normalizeFactors();
pqGoalSemiring.printInfor();
double entropyP = pGoalSemiring.getLogProb() - pGoalSemiring.getFactor1().convertRealValue();//logZ-E(s)??????????
double entropyQ = qGoalSemiring.getLogProb() - qGoalSemiring.getFactor1().convertRealValue();//logZ-E(s)?????
double crossEntropyPQ = qGoalSemiring.getLogProb()- pqGoalSemiring.getFactor1().convertRealValue();//logZ(q)-E(s)?????????
double klPQ = -entropyP + crossEntropyPQ;
if(klPQ<0){System.out.println("kl divergence is negative, must be wrong"); System.exit(1);}
System.out.println("p_entropy=" + entropyP +"; "+"q_entropy=" + entropyQ +"; "+"pq_entropy=" + crossEntropyPQ +"; "+"pq_kl=" + klPQ);
}
}
}