Package joshua.discriminative.semiring_parsingv2.applications.min_risk_da

Source Code of joshua.discriminative.semiring_parsingv2.applications.min_risk_da.MinRiskDADenseFeaturesSemiringParser

package joshua.discriminative.semiring_parsingv2.applications.min_risk_da;

import java.util.HashMap;

import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.discriminative.semiring_parsingv2.DefaultIOParserWithXLinearCombinator;
import joshua.discriminative.semiring_parsingv2.SignedValue;
import joshua.discriminative.semiring_parsingv2.pmodule.ExpectationSemiringPM;
import joshua.discriminative.semiring_parsingv2.pmodule.ListPM;
import joshua.discriminative.semiring_parsingv2.pmodule.SparseMap;
import joshua.discriminative.semiring_parsingv2.semiring.ExpectationSemiring;
import joshua.discriminative.semiring_parsingv2.semiring.LogSemiring;
import joshua.discriminative.training.risk_annealer.hypergraph.FeatureForest;

/** P is a LogSemiring
*  R is a RiskAndEntropyPM, containing entropy and risk
*  S is a map (a ListPM wrapping a SparseMap)
*  T is a map (a ListPM wrapping a SparseMap)
* */

/** This class implements the method described in Sec 6.2.
* It requires a hypergraph, which provides the topology
* and four quantities: P_e, L_e, log P_e, and (P_e)'.
* These are provided by the feature forest, through three functions:
* getEdgeLogTransitionProb, getEdgeRisk, and getGradientSparseMap.
*
* */

/** Computes the function value and the gradient of
*  risk - temperature * entropy. */

public class MinRiskDADenseFeaturesSemiringParser
extends  DefaultIOParserWithXLinearCombinator<
  ExpectationSemiring<LogSemiring, RiskAndEntropyPM>,
  ExpectationSemiringPM<LogSemiring, RiskAndEntropyPM, ListPM, ListPM, MinRiskDABO>
  > {
 
  MinRiskDABO pBilinearOperator = new MinRiskDABO();
 
  //annealing parameters
  private double temperature;

  //latest value
  private double entropy;
  private double risk;
  private double functionValue;
 
  //TODO
  boolean computeEntropy = true;
  boolean computeRisk = true;
  
 
  public MinRiskDADenseFeaturesSemiringParser(double temperature) {
   
    super();            
    this.temperature = temperature;
  }

  @Override
  protected ExpectationSemiring<LogSemiring, RiskAndEntropyPM>
 
  createNewKWeight() {   
    LogSemiring p = new LogSemiring();
    RiskAndEntropyPM r = new RiskAndEntropyPM();
    return new ExpectationSemiring<LogSemiring, RiskAndEntropyPM>( p, r );
  }


  @Override
  protected ExpectationSemiringPM<LogSemiring, RiskAndEntropyPM, ListPM, ListPM, MinRiskDABO>
 
  createNewXWeight() {   
    ListPM s = new ListPM( new SparseMap() );   
    ListPM t = new ListPM( new SparseMap() );   
    return new ExpectationSemiringPM<LogSemiring, RiskAndEntropyPM, ListPM, ListPM,  MinRiskDABO>( s, t,  pBilinearOperator);
  }

 
  @Override
  protected ExpectationSemiring<LogSemiring, RiskAndEntropyPM>
 
  getEdgeKWeight(HyperEdge dt, HGNode parentItem) {   
   
    //== p
    double logProb = getFeatureForest().getEdgeLogTransitionProb(dt, parentItem);
    LogSemiring p = new LogSemiring(logProb);
   
    //== r
    double rRisk = getFeatureForest().getEdgeRisk(dt);
    double rEntropy =  logProb;//log(p_e)
    double rMixed = rRisk + this.getTemperature() * rEntropy; //the objective is risk - T * entropy
   
    RiskAndEntropyPM r =  new RiskAndEntropyPM(
        SignedValue.createSignedValueFromRealNumber(rMixed),
        SignedValue.createSignedValueFromRealNumber(rEntropy),
        SignedValue.createSignedValueFromRealNumber(rRisk)
    );
   
    //r= p r
    r.multiSemiring(p);
    return new ExpectationSemiring<LogSemiring, RiskAndEntropyPM>(p,r);
  }
 
 
 
 
  @Override
  protected ExpectationSemiringPM<LogSemiring, RiskAndEntropyPM, ListPM, ListPM, MinRiskDABO>
 
  getEdgeXWeight(HyperEdge dt, HGNode parentItem) {
    //TODO: p and r has been computed twice, consider speed up   
    //== p
    double logProb = getFeatureForest().getEdgeLogTransitionProb(dt, parentItem);
   
    //== r
    double rRisk = getFeatureForest().getEdgeRisk(dt);
    double rEntropy = logProb;//log(p_e);
   
    double rMixed = rRisk + this.getTemperature()*(rEntropy + 1); //( L_e + temperature * (logP_e + 1) )
   
    /*
    RiskAndEntropyPM r =  new RiskAndEntropyPM(
        SignedValue.createSignedValueFromRealNumber(rMixed),
        SignedValue.createSignedValueFromRealNumber(rEntropy),
        SignedValue.createSignedValueFromRealNumber(rRisk)
    );*/
       
     
    SparseMap gradientsMap = getFeatureForest().getGradientSparseMap(parentItem, dt, logProb);
    ListPM s = new ListPM( gradientsMap );
   
 
    //== t = L_e * (P_e)' - temperature * (1+logP_e) (P_e)' = (P_e)' * ( L_e - temperature * (1+logP_e) )
    //ListPM t = pBilinearOperator.bilinearMulti(r, s);
    ListPM t = pBilinearOperator.bilinearMulti(SignedValue.createSignedValueFromRealNumber(rMixed), s);
   
   
    return new ExpectationSemiringPM<LogSemiring, RiskAndEntropyPM, ListPM, ListPM, MinRiskDABO>(s, t, pBilinearOperator);
  }

 
  @Override
  public void normalizeGoal() {
    // TODO Auto-generated method stub
   
  }
 
 
 
  //============== additional functions ========
 
  public final void setTemperature(double temperature_){
    this.temperature = temperature_;
  }
 
  protected final double getTemperature(){
    return this.temperature;
  }
 
 

  //return the entropy
  public double getEntropy(){
    return this.entropy;
  }
 
 
  //return the entropy
  public double getRisk(){
    return this.risk;
   
  }
 
  public double getFuncVal(){
    return functionValue;//this.risk - this.temperature*this.entropy;   
  }
 
 
 
  public HashMap<Integer, Double> computeGradientForTheta(){
    this.clearState();
    this.insideEstimationOverHG();
    this.outsideEstimationOverHG();
   
    ExpectationSemiring<LogSemiring, RiskAndEntropyPM> goalK = this.getGoalK();
    ExpectationSemiringPM<LogSemiring, RiskAndEntropyPM, ListPM, ListPM, MinRiskDABO> goalX = this.getGoalX();
    double logZ = goalK.getP().getLogValue();
    //goalK.printInfor();
    //goalX.printInfor();
 
   
    //=== entropy
    //--normalize
    SignedValue entropyFactor = goalK.getR().getEntropy().duplicate();
    entropyFactor.multiLogNumber(-logZ);
   
    this.entropy = logZ - entropyFactor.convertToRealValue();//logZ - \bar{r}/Z
    //System.out.print("Entropy is" + entropy);
   
    //=== risk
    //--normalize
    SignedValue riskFactor = goalK.getR().getRisk().duplicate();
    riskFactor.multiLogNumber(-logZ);
    this.risk =  riskFactor.convertToRealValue();
    //System.out.print("Risk is" + risk);
   
    this.functionValue = this.risk - this.temperature*this.entropy;
   
   
    //=== gradients
    //System.out.print("Gradients are: ");
    HashMap<Integer, Double>  gradient = new HashMap<Integer, Double>();
    for(Integer featID : goalX.getT().getValue().getIds()){
      //delta(r)*Z/Z^2=delta(r)/Z
      //--normalize
      SignedValue resT = goalX.getT().getValue().getValueAt(featID).duplicate();
      resT.multiLogNumber(-logZ);
       
      //-delta(Z)*r/Z^2
      SignedValue resRS = SignedValue.multi(
          goalX.getS().getValue().getValueAt(featID),
          goalK.getR().getValue()
        );
      resRS.negate();
      resRS.multiLogNumber(-2*logZ);
     
      //-T*delta(Z)/Z
      SignedValue resS = goalX.getS().getValue().getValueAt(featID).duplicate();
      resS.multiLogNumber(Math.log(this.getTemperature()));
      resS.negate();     
      resS.multiLogNumber(-logZ);
     
      //add them together
      resT.add(resRS);
      resT.add(resS);
      double finalVal = resT.convertToRealValue();
     
      gradient.put(featID, finalVal);
      if(Double.isNaN(finalVal)){
        System.out.println("gradient value for theta is NaN");
        System.exit(1);
      }
      //System.out.print( gradient[i]+" " );
    }
    /*
    System.out.println("Risk is : " + risk);
    System.out.println("Entropy is : " + entropy);
    System.out.println("Function value is : " + functionValue);
    */
    return gradient;
  }
 
 
  //@todo: parameterize the HG
  private final FeatureForest getFeatureForest(){
    return (FeatureForest) hg;
  }
 
}
TOP

Related Classes of joshua.discriminative.semiring_parsingv2.applications.min_risk_da.MinRiskDADenseFeaturesSemiringParser

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.