package joshua.discriminative.variational_decoder;
import java.util.List;
import joshua.decoder.chart_parser.ComputeNodeResult;
import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.discriminative.semiring_parsing.AtomicSemiring;
import joshua.discriminative.semiring_parsing.DefaultSemiringParser;
import joshua.discriminative.semiring_parsing.ExpectationSemiring;
import joshua.discriminative.semiring_parsing.SignedValue;
/* Computes D(p||q) over a hypergraph. */
/* Note that entropy itself is a special case (presumably when p and q use the same feature functions -- TODO confirm). */
/**
 * Semiring parser that, for every hyperedge, combines a scaled score computed
 * from the "p" feature functions with a scaled score computed from the "q"
 * feature functions and packs both into an {@link ExpectationSemiring} member.
 *
 * <p>Only the log atomic semiring is handled; any other atomic semiring
 * prints a message and terminates the JVM.
 */
public class CrossEntropyOnHG extends DefaultSemiringParser {

    /** Feature functions used to score a hyperedge under model p. */
    List<FeatureFunction> pFeatFunctions;

    /** Feature functions used to score a hyperedge under model q. */
    List<FeatureFunction> qFeatFunctions;

    /**
     * Creates the parser.
     *
     * @param semiring       forwarded unchanged to the superclass constructor
     * @param add_mode       forwarded unchanged to the superclass constructor
     * @param scale          forwarded unchanged to the superclass constructor
     * @param pFeatFunctions feature functions defining model p
     * @param qFeatFunctions feature functions defining model q
     */
    public CrossEntropyOnHG(int semiring, int add_mode, double scale, List<FeatureFunction> pFeatFunctions, List<FeatureFunction> qFeatFunctions) {
        // TODO: use different scale for p and q?
        super(semiring, add_mode, scale);
        this.pFeatFunctions = pFeatFunctions;
        this.qFeatFunctions = qFeatFunctions;
    }

    /**
     * Returns a fresh {@link ExpectationSemiring} built with its no-argument
     * constructor.
     */
    protected ExpectationSemiring createNewSemiringMember() {
        return new ExpectationSemiring();
    }

    /**
     * Builds the semiring weight of a single hyperedge.
     *
     * @param dt                the hyperedge being weighted
     * @param parent_item       the node the hyperedge belongs to
     * @param scale             multiplier applied (negated) to both transition costs
     * @param p_atomic_semiring atomic semiring in use; must be the log semiring
     * @return the hyperedge weight; if the atomic semiring is not the log
     *         semiring, a message is printed and the JVM exits with status 1
     */
    protected ExpectationSemiring getHyperedgeSemiringWeight(HyperEdge dt, HGNode parent_item, double scale, AtomicSemiring p_atomic_semiring) {
        // Guard: anything other than the log semiring is not supported.
        if (p_atomic_semiring.ATOMIC_SEMIRING != AtomicSemiring.LOG_SEMIRING) {
            System.out.println("un-implemented atomic-semiring");
            System.exit(1);
            return null;
        }

        // Score of the edge under p.
        double pLogProb = -scale * computeTransitionCost(parent_item, dt, pFeatFunctions);

        // Score of the edge under q: s(x,y), whose expectation E(s(x,y)) is wanted
        // (real semiring).
        double qValue = -scale * computeTransitionCost(parent_item, dt, qFeatFunctions);

        // Real-semiring equivalent noted by the original author:
        //   Math.exp(logProbP) * valQ
        // NOTE(review): SignedValue.multi presumably treats its first argument as
        // a log-domain value -- confirm against SignedValue.
        SignedValue qAsSigned = SignedValue.createSignedValue(qValue);
        SignedValue weightedQ = SignedValue.multi(pLogProb, qAsSigned);

        return new ExpectationSemiring(pLogProb, weightedQ);
    }

    /**
     * Computes the weighted sum of the per-feature transition values of a
     * hyperedge: each value returned by
     * {@code ComputeNodeResult.computeModelTransitionLogPs} is multiplied by the
     * weight of the feature function at the same position, and the products are
     * summed.
     *
     * @param parentNode    node whose {@code i} and {@code j} fields are passed
     *                      on to the transition computation
     * @param dt            the hyperedge to score
     * @param featFunctions feature functions, in the order matching the
     *                      returned value array
     * @return the weighted sum
     */
    static private double computeTransitionCost(HGNode parentNode, HyperEdge dt, List<FeatureFunction> featFunctions) {
        // NOTE(review): the meaning of the trailing -1 argument is not visible
        // here -- confirm against ComputeNodeResult.
        double[] perFeatureValues = ComputeNodeResult.computeModelTransitionLogPs(
                featFunctions, dt, parentNode.i, parentNode.j, -1);

        double weightedSum = 0;
        for (int k = 0; k < featFunctions.size(); k++) {
            FeatureFunction feature = featFunctions.get(k);
            weightedSum += perFeatureValues[k] * feature.getWeight();
        }
        return weightedSum;
    }
}