package joshua.discriminative.semiring_parsingv2.applications.min_risk_da;
import java.util.HashMap;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.discriminative.semiring_parsingv2.DefaultIOParserWithXLinearCombinator;
import joshua.discriminative.semiring_parsingv2.SignedValue;
import joshua.discriminative.semiring_parsingv2.pmodule.ExpectationSemiringPM;
import joshua.discriminative.semiring_parsingv2.pmodule.ListPM;
import joshua.discriminative.semiring_parsingv2.pmodule.SparseMap;
import joshua.discriminative.semiring_parsingv2.semiring.ExpectationSemiring;
import joshua.discriminative.semiring_parsingv2.semiring.LogSemiring;
import joshua.discriminative.training.risk_annealer.hypergraph.FeatureForest;
/** P is a LogSemiring.
* R is a RiskAndEntropyPM, containing entropy and risk.
* S is a map (a ListPM over a SparseMap).
* T is a map (a ListPM over a SparseMap).
* */
/** This class implements the method described in Sec 6.2.
* It requires a hypergraph, which provides the topology
* and the four quantities P_e, L_e, log P_e, and (P_e)'.
* These are provided by the feature forest through three functions:
* getEdgeLogTransitionProb, getEdgeRisk, and getGradientSparseMap.
*
* */
/**
 * Computes the function value and the gradient of the annealed minimum-risk
 * objective {@code risk - temperature * entropy} over a hypergraph.
 *
 * <p>The K weight of an edge is an expectation-semiring pair (p, r): p is the edge
 * probability held in a {@link LogSemiring}, and r is a {@link RiskAndEntropyPM}
 * built from three per-edge values (mixed, entropy, risk) and then multiplied by p.
 * The X weight of an edge is a pair (s, t) of {@link ListPM}s: s wraps the gradient
 * map supplied by the feature forest and t is s scaled by a per-edge scalar.
 *
 * <p>The hypergraph given to this parser must be a {@link FeatureForest}; it is
 * cast to that type every time an edge quantity is looked up.
 */
public class MinRiskDADenseFeaturesSemiringParser
        extends DefaultIOParserWithXLinearCombinator<
                ExpectationSemiring<LogSemiring, RiskAndEntropyPM>,
                ExpectationSemiringPM<LogSemiring, RiskAndEntropyPM, ListPM, ListPM, MinRiskDABO>
        > {

    /** Bilinear operator shared by every X weight created by this parser. */
    MinRiskDABO pBilinearOperator = new MinRiskDABO();

    // annealing parameter
    private double temperature;

    // values produced by the latest call to computeGradientForTheta()
    private double entropy;
    private double risk;
    private double functionValue;

    // TODO: these two flags are not consulted anywhere in this class yet
    boolean computeEntropy = true;
    boolean computeRisk = true;

    /**
     * Creates a parser for the given annealing temperature.
     *
     * @param temperature weight of the entropy term in {@code risk - temperature * entropy}
     */
    public MinRiskDADenseFeaturesSemiringParser(double temperature) {
        super();
        this.temperature = temperature;
    }

    /**
     * Creates a fresh K weight from a default-constructed {@link LogSemiring} and a
     * default-constructed {@link RiskAndEntropyPM}.
     *
     * @return the new (p, r) pair
     */
    @Override
    protected ExpectationSemiring<LogSemiring, RiskAndEntropyPM>
    createNewKWeight() {
        LogSemiring p = new LogSemiring();
        RiskAndEntropyPM r = new RiskAndEntropyPM();
        return new ExpectationSemiring<LogSemiring, RiskAndEntropyPM>(p, r);
    }

    /**
     * Creates a fresh X weight whose s and t components each wrap a new, empty
     * {@link SparseMap}.
     *
     * @return the new (s, t) pair, bound to this parser's bilinear operator
     */
    @Override
    protected ExpectationSemiringPM<LogSemiring, RiskAndEntropyPM, ListPM, ListPM, MinRiskDABO>
    createNewXWeight() {
        ListPM s = new ListPM(new SparseMap());
        ListPM t = new ListPM(new SparseMap());
        return new ExpectationSemiringPM<LogSemiring, RiskAndEntropyPM, ListPM, ListPM, MinRiskDABO>(
                s, t, pBilinearOperator);
    }

    /**
     * Builds the K weight (p, r) of one hyperedge.
     *
     * @param dt the hyperedge
     * @param parentItem the node the hyperedge belongs to
     * @return the pair (p_e, p_e * r_e), where r_e holds the mixed, entropy and risk values
     */
    @Override
    protected ExpectationSemiring<LogSemiring, RiskAndEntropyPM>
    getEdgeKWeight(HyperEdge dt, HGNode parentItem) {
        //== p
        double logProb = getFeatureForest().getEdgeLogTransitionProb(dt, parentItem);
        LogSemiring p = new LogSemiring(logProb);

        //== r
        double rRisk = getFeatureForest().getEdgeRisk(dt);
        double rEntropy = logProb; // log(p_e)
        // The objective is risk - T * entropy; with entropy = -sum p log p the
        // per-edge contribution becomes L_e + T * log(p_e).
        double rMixed = rRisk + this.getTemperature() * rEntropy;

        RiskAndEntropyPM r = new RiskAndEntropyPM(
                SignedValue.createSignedValueFromRealNumber(rMixed),
                SignedValue.createSignedValueFromRealNumber(rEntropy),
                SignedValue.createSignedValueFromRealNumber(rRisk)
        );

        // r = p * r
        r.multiSemiring(p);

        return new ExpectationSemiring<LogSemiring, RiskAndEntropyPM>(p, r);
    }

    /**
     * Builds the X weight (s, t) of one hyperedge.
     *
     * <p>s wraps the gradient map returned by the feature forest, i.e. (P_e)'.
     * t is s scaled by {@code L_e + temperature * (1 + log P_e)}.
     *
     * @param dt the hyperedge
     * @param parentItem the node the hyperedge belongs to
     * @return the pair (s, t), bound to this parser's bilinear operator
     */
    @Override
    protected ExpectationSemiringPM<LogSemiring, RiskAndEntropyPM, ListPM, ListPM, MinRiskDABO>
    getEdgeXWeight(HyperEdge dt, HGNode parentItem) {
        //TODO: p and r are also computed in getEdgeKWeight, consider sharing them to speed up

        //== p
        double logProb = getFeatureForest().getEdgeLogTransitionProb(dt, parentItem);

        //== scalar factor of t: L_e + temperature * (log P_e + 1)
        double rRisk = getFeatureForest().getEdgeRisk(dt);
        double rMixed = rRisk + this.getTemperature() * (logProb + 1);

        //== s = (P_e)'
        SparseMap gradientsMap = getFeatureForest().getGradientSparseMap(parentItem, dt, logProb);
        ListPM s = new ListPM(gradientsMap);

        //== t = L_e * (P_e)' + temperature * (1 + log P_e) * (P_e)'
        //     = (P_e)' * ( L_e + temperature * (1 + log P_e) )
        ListPM t = pBilinearOperator.bilinearMulti(
                SignedValue.createSignedValueFromRealNumber(rMixed), s);

        return new ExpectationSemiringPM<LogSemiring, RiskAndEntropyPM, ListPM, ListPM, MinRiskDABO>(
                s, t, pBilinearOperator);
    }

    /**
     * Does nothing in this implementation; {@link #computeGradientForTheta()}
     * applies the 1/Z scaling to the goal quantities itself.
     */
    @Override
    public void normalizeGoal() {
        // TODO Auto-generated method stub
    }

    //============== additional functions ========

    /**
     * Sets the annealing temperature used by subsequent edge-weight computations.
     *
     * @param temperature_ the new temperature
     */
    public final void setTemperature(double temperature_) {
        this.temperature = temperature_;
    }

    /** @return the current annealing temperature */
    protected final double getTemperature() {
        return this.temperature;
    }

    /** @return the entropy stored by the latest {@link #computeGradientForTheta()} call */
    public double getEntropy() {
        return this.entropy;
    }

    /** @return the risk stored by the latest {@link #computeGradientForTheta()} call */
    public double getRisk() {
        return this.risk;
    }

    /**
     * @return {@code risk - temperature * entropy} as stored by the latest
     *         {@link #computeGradientForTheta()} call
     */
    public double getFuncVal() {
        return functionValue;
    }

    /**
     * Runs inside and outside estimation over the hypergraph, refreshes the stored
     * entropy, risk and function value, and returns the gradient of
     * {@code risk - temperature * entropy}.
     *
     * <p>With Z the goal probability mass, r the goal "mixed" value and (s, t) the
     * goal X weight, each gradient component is {@code t/Z - s*r/Z^2 - T*s/Z}.
     *
     * <p>If any component evaluates to NaN, a message is printed to standard output
     * and the JVM is terminated with exit status 1.
     *
     * @return map from feature id to gradient value, one entry per id present in goal t
     */
    public HashMap<Integer, Double> computeGradientForTheta() {
        this.clearState();
        this.insideEstimationOverHG();
        this.outsideEstimationOverHG();

        ExpectationSemiring<LogSemiring, RiskAndEntropyPM> goalK = this.getGoalK();
        ExpectationSemiringPM<LogSemiring, RiskAndEntropyPM, ListPM, ListPM, MinRiskDABO> goalX =
                this.getGoalX();
        double logZ = goalK.getP().getLogValue();

        //=== entropy = logZ - \bar{r}/Z
        SignedValue entropyFactor = goalK.getR().getEntropy().duplicate();
        entropyFactor.multiLogNumber(-logZ); // normalize by Z
        this.entropy = logZ - entropyFactor.convertToRealValue();

        //=== risk = \bar{r}/Z
        SignedValue riskFactor = goalK.getR().getRisk().duplicate();
        riskFactor.multiLogNumber(-logZ); // normalize by Z
        this.risk = riskFactor.convertToRealValue();

        this.functionValue = this.risk - this.temperature * this.entropy;

        //=== gradients
        // log(T) does not depend on the feature, so compute it once outside the loop
        double logTemperature = Math.log(this.getTemperature());

        HashMap<Integer, Double> gradient = new HashMap<Integer, Double>();
        for (Integer featID : goalX.getT().getValue().getIds()) {
            // delta(r)*Z/Z^2 = delta(r)/Z
            SignedValue resT = goalX.getT().getValue().getValueAt(featID).duplicate();
            resT.multiLogNumber(-logZ);

            // -delta(Z)*r/Z^2
            SignedValue resRS = SignedValue.multi(
                    goalX.getS().getValue().getValueAt(featID),
                    goalK.getR().getValue()
            );
            resRS.negate();
            resRS.multiLogNumber(-2*logZ);

            // -T*delta(Z)/Z
            SignedValue resS = goalX.getS().getValue().getValueAt(featID).duplicate();
            resS.multiLogNumber(logTemperature);
            resS.negate();
            resS.multiLogNumber(-logZ);

            // add the three terms together
            resT.add(resRS);
            resT.add(resS);

            double finalVal = resT.convertToRealValue();
            if (Double.isNaN(finalVal)) {
                // A NaN gradient makes further optimization meaningless; abort the whole process.
                System.out.println("gradient value for theta is NaN");
                System.exit(1);
            }
            gradient.put(featID, finalVal);
        }

        return gradient;
    }

    //@todo: parameterize the HG
    /** @return the hypergraph of this parser, cast to {@link FeatureForest} */
    private final FeatureForest getFeatureForest() {
        return (FeatureForest) hg;
    }
}