Package joshua.decoder.chart_parser

Source Code of joshua.decoder.chart_parser.ComputeNodeResult

package joshua.decoder.chart_parser;

import java.util.HashMap;
import java.util.List;

import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.ff.state_maintenance.DPState;
import joshua.decoder.ff.state_maintenance.StateComputer;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;


/**
*
* @author Zhifei Li, <zhifei.work@gmail.com>
* @version $LastChangedDate: 2009-12-22 14:00:36 -0500 (星期二, 22 十二月 2009) $
*/

public class ComputeNodeResult {
 
  private double expectedTotalLogP;
  private double finalizedTotalLogP;
  private double transitionTotalLogP;
 
  // the key is state id;
  private HashMap<Integer,DPState> dpStates;
 
 
 
 
  /**
   * Compute logPs and the states of thE node
   */
  public ComputeNodeResult(List<FeatureFunction> featureFunctions, Rule rule,
      List<HGNode> antNodes, int i, int j, SourcePath srcPath,
      List<StateComputer> stateComputers, int sentID){
   
    double finalizedTotalLogP = 0.0;
   
    if (null != antNodes) {
      for (HGNode item : antNodes) {
        finalizedTotalLogP += item.bestHyperedge.bestDerivationLogP; //semiring times
      }
    }
   
   
    HashMap<Integer,DPState> allDPStates = null;
   
    if(stateComputers!=null){
      for(StateComputer stateComputer : stateComputers){
        DPState dpState = stateComputer.computeState(rule, antNodes, i, j, srcPath);         
       
        if(allDPStates==null)
          allDPStates = new HashMap<Integer,DPState>();
        allDPStates.put(stateComputer.getStateID(), dpState);
      }
    }
   
    //=== compute feature logPs
    double transitionLogPSum    = 0.0;
    double futureLogPEstimation = 0.0;
   
    for (FeatureFunction ff : featureFunctions) {   
      transitionLogPSum +=
        ff.getWeight() * ff.transitionLogP(rule, antNodes, i, j, srcPath, sentID);
     
      DPState dpState = null;
      if(allDPStates!=null)
        dpState = allDPStates.get(ff.getStateID());
      futureLogPEstimation +=
        ff.getWeight() * ff.estimateFutureLogP(rule, dpState, sentID);
     
    }
   
    /* if we use this one (instead of compute transition
     * logP on the fly, we will rely on the correctness
     * of rule.statelesscost. This will cause a nasty
     * bug for MERT. Specifically, even we change the
     * weight vector for features along the iteration,
     * the HG cost does not reflect that as the Grammar
     * is not reestimated!!! Of course, compute it on
     * the fly will slow down the decoding (e.g., from
     * 5 seconds to 6 seconds, for the example test
     * set)
     */
    //transitionCostSum += rule.getStatelessCost();
    //System.out.println(futureLogPEstimation);
   
    finalizedTotalLogP += transitionLogPSum;
    double expectedTotalLogP = finalizedTotalLogP + futureLogPEstimation;
   
   
    //== set the final results
    this.expectedTotalLogP = expectedTotalLogP;
    this.finalizedTotalLogP = finalizedTotalLogP;
    this.transitionTotalLogP = transitionLogPSum;
    this.dpStates =  allDPStates;
   
    //System.out.println(rule.toString());
    //printInfo();
  }
 
  public static double computeCombinedTransitionLogP(List<FeatureFunction> featureFunctions, HyperEdge edge,
      int i, int j, int sentID){
    double res = 0;
    for(FeatureFunction ff : featureFunctions) {       
      if(edge.getRule()!=null)
        res += ff.getWeight() * ff.transitionLogP(edge, i,  j, sentID);
      else
        res += ff.getWeight() * ff.finalTransitionLogP(edge, i, j, sentID);   
    }
    return res;
  }
 
  public static double computeCombinedTransitionLogP(List<FeatureFunction> featureFunctions, Rule rule,
      List<HGNode> antNodes,  int i, int j,  SourcePath srcPath, int sentID){
    double res = 0;
    for(FeatureFunction ff : featureFunctions) {       
      if(rule!=null)
        res += ff.getWeight() * ff.transitionLogP(rule, antNodes, i,  j, srcPath, sentID);
      else
        res += ff.getWeight() * ff.finalTransitionLogP(antNodes.get(0), i,  j, srcPath, sentID);   
    }
    return res;
  }
 
  public static double[] computeModelTransitionLogPs(List<FeatureFunction> featureFunctions, HyperEdge edge,
      int i, int j,  int sentID){

      double[] res = new double[featureFunctions.size()];
     
      //=== compute feature logPs
      int k=0;
      for(FeatureFunction ff : featureFunctions) {       
        if(edge.getRule()!=null)
          res[k] = ff.transitionLogP(edge, i, j, sentID);
        else
          res[k] = ff.finalTransitionLogP(edge,  i, j, sentID);   
        k++;
      }
     
      return res;   
  }

  public static double[] computeModelTransitionLogPs(List<FeatureFunction> featureFunctions, Rule rule,
          List<HGNode> antNodes, int i, int j, SourcePath srcPath, int sentID){
   
    double[] res = new double[featureFunctions.size()];
   
    //=== compute feature logPs
    int k=0;
    for(FeatureFunction ff : featureFunctions) {       
      if(rule!=null)
        res[k] = ff.transitionLogP(rule, antNodes, i, j, srcPath, sentID);
      else
        res[k] = ff.finalTransitionLogP(antNodes.get(0),  i, j, srcPath, sentID);   
      k++;
    }
   
    return res;   
  }
   
 
 
  void setExpectedTotalLogP(double logP) {
    this.expectedTotalLogP = logP;
  }
 
  public double getExpectedTotalLogP() {
    return this.expectedTotalLogP;
  }
 
  void setFinalizedTotalLogP(double logP) {
    this.finalizedTotalLogP = logP;
  }
 
  double getFinalizedTotalLogP() {
    return this.finalizedTotalLogP;
  }
 
  void setTransitionTotalLogP(double logP) {
    this.transitionTotalLogP = logP;
  }
 
  double getTransitionTotalLogP() {
    return this.transitionTotalLogP;
  }
 
  void setDPStates(HashMap<Integer,DPState> states) {
    this.dpStates = states;
  }
 
  HashMap<Integer,DPState> getDPStates() {
    return this.dpStates;
  }
 
  public void printInfo(){
    System.out.println("scores: "+ transitionTotalLogP + "; " + finalizedTotalLogP + "; " +  expectedTotalLogP);
  }
}
TOP

Related Classes of joshua.decoder.chart_parser.ComputeNodeResult

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.