Package joshua.decoder.hypergraph

Source Code of joshua.decoder.hypergraph.KBestExtractor$VirtualNode

/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.decoder.hypergraph;


import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.logging.Logger;

import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.chart_parser.ComputeNodeResult;
import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.ff.tm.Rule;
import joshua.util.CoIterator;
import joshua.util.Regex;
import joshua.util.io.UncheckedIOException;

/**
* This class implements lazy k-best extraction on a hyper-graph.
* To seed the kbest extraction, it only needs that each hyperedge
* should have the best_cost properly set, and it does not require
* any list being sorted.
* Instead, the priority queue heap_cands will do the internal sorting.
* In fact, the really crucial cost is the transition cost at each
* hyperedge.
* We store the best cost instead of the transition cost because that
* makes it easy to do pruning and to find the one-best. Moreover, the
* transition cost can be recovered by get_transition_cost(), though
* that is somewhat expensive.
*
* To recover the model cost for each individual model, we should
* either have access to the model, or store the model cost in the
* hyperedge. (For example, in the case of disk-hypergraph, we need
* to store all these model cost at each hyperedge.)
*
* @author Zhifei Li, <zhifei.work@gmail.com>
* @version $LastChangedDate: 2010-02-08 13:03:13 -0600 (Mon, 08 Feb 2010) $
*/
public class KBestExtractor {
 
  /** Logger for this class. */
  private static final Logger logger =
    Logger.getLogger(KBestExtractor.class.getName());
 
  /**
   * Per-extraction cache mapping each hypergraph node to the VirtualNode that
   * holds its k-best state. Filled on demand by addVirtualNode() and emptied
   * by resetState().
   */
  private final HashMap<HGNode,VirtualNode> virtualNodesTbl = new HashMap<HGNode,VirtualNode>();

 
  /** Symbol table used to turn the numeric ids of a hypothesis back into words. */
  private final SymbolTable symbolTable;
 
 
  /** Label used for the root of a tree-formatted hypothesis. */
  static String rootSym = "ROOT";
  // Id of rootSym. NOTE: this is static but is (re)assigned by every constructor
  // call from that instance's symbol table, so the last constructed extractor wins.
  static int rootID;//TODO: bug
 
  //configuration options
  // when true, a node only keeps derivations whose yield string has not been seen before
  private boolean extractUniqueNbest  = true;
  // when true, hypotheses are produced in bracketed tree format
  private boolean extractNbestTree = false;
  // when true (and in tree format), each tree tag is followed by "{i-j}" taken from the node
  private boolean includeAlign=false;
  // when true, the combined cost is appended as the last field of each output line
  private boolean addCombinedScore=true;
  // when true, the yield is read from the rule's French side instead of its English side
  private boolean isMonolingual = false;
  // when true, the weighted sum of the model costs is checked against the derivation cost
  private boolean performSanityCheck = true;
 
//  public KBestExtractor(SymbolTable symbolTable) {
//    this.p_symbolTable = symbolTable;
//  }
 
  // sentence id of the extraction in progress; set by getKthHyp() and
  // lazyKBestExtractOnHG(), read when computing per-model costs
  private int sentID;
 
  public KBestExtractor(SymbolTable symbolTable, boolean extractUniqueNbest, boolean   extractNbestTree,  boolean includeAlign,
      boolean addCombinedScore,  boolean isMonolingual,  boolean performSanityCheck){
    this.symbolTable = symbolTable;
    rootID = symbolTable.addNonterminal(rootSym);
   
    this.extractUniqueNbest = extractUniqueNbest;
    this.extractNbestTree = extractNbestTree;
    this.includeAlign = includeAlign;
    this.addCombinedScore = addCombinedScore;
    this.isMonolingual = isMonolingual;
    this.performSanityCheck = performSanityCheck;
    //System.out.println("===============sanitycheck="+performSanityCheck);
  }
 
 
  //k start from 1
  //***************** you may need to reset_state() before you call this function for the first time
  public String getKthHyp(HGNode it, int k,  int sentID, List<FeatureFunction> models, int[] numNodesAndEdges) {
   
    this.sentID = sentID;
    VirtualNode virtualNode = addVirtualNode(it);
   
    //==== setup the kbest at each hgnode
    DerivationState cur = virtualNode.lazyKBestExtractOnNode(symbolTable, this, k);
    if( cur==null)
      return null;
    else{   
      //==== read the kbest from each hgnode and convert to output format
      double[] modelCost = null;
      if(models!=null)
        modelCost = new double[models.size()];   
      String strHypNumeric = cur.getHypothesis(symbolTable, this, extractNbestTree, modelCost, models, numNodesAndEdges)
      //for(int k=0; k<model_cost.length; k++) System.out.println(model_cost[k]);
      String strHypStr = convertHyp2String(sentID, cur, models, strHypNumeric, modelCost);
      return strHypStr;
    }
  }
 
  //================= extract kbest into trivial kbest hypergraphs 
  public HyperGraph extractKbestIntoHyperGraph(HyperGraph inHG, int topN){
    if ( inHG.goalNode == null ){
      logger.severe("Goal node is null");
      System.exit(1);
    }
    resetState();
   
    List<HyperGraph> hgs = new ArrayList<HyperGraph>();
    int nextN = 0;
    while (true) {
      int[] numNodesAndEdges = new int[2];
      HGNode newGoalNode = getKthHyp(inHG.goalNode, ++nextN, numNodesAndEdges);
     
      if(newGoalNode!=null){
        hgs.add( new HyperGraph(newGoalNode, numNodesAndEdges[0], numNodesAndEdges[1], inHG.sentID, inHG.sentLen) );
      }
     
      if (null == newGoalNode || nextN >= topN)
        break;
    }
    //System.out.println("numHgs=" + hgs.size() + " for topn="+topN);
    return HyperGraph.mergeHyperGraphs(hgs);
  }
 
  //k start from 1
  public HGNode  getKthHyp(HGNode it, int k, int[] numNodesAndEdges) {
    //==== setup the kbest at each hgnode
    VirtualNode virtualNode = addVirtualNode(it);
    DerivationState cur = virtualNode.lazyKBestExtractOnNode(symbolTable, this, k);
    if( cur==null)
      return null;
    else{   
      //==== recursive setup hgnodes
      return cur.getHypothesis(this, numNodesAndEdges);
    }
  }
 
  /*
  public void  getNumNodesAndEdges(HGNode it, int k, int[] numNodesAndEdges) {
    //==== setup the kbest at each hgnode
    VirtualNode virtualNode = addVirtualNode(it);
    DerivationState cur = virtualNode.lazyKBestExtractOnNode(symbolTable, this, k);
    if( cur==null){
      numNodesAndEdges[0]=0;
      numNodesAndEdges[1]=0;
    }else{   
      cur.getNumNodesAndEdges(this, numNodesAndEdges);
    }
  }
  */
 
  /**Filter out those hypotheses that are already in uniqueHyps
   * */
  public void filterKbestHypergraph(HyperGraph hg, Set<String> uniqueHyps){
    resetState();
   
    for(int i=0; i<hg.goalNode.hyperedges.size(); i++){
      HyperEdge goalEdge = hg.goalNode.hyperedges.get(i)
      if(goalEdge.getAntNodes().size()!=1){
        logger.severe("gaol edge does not have exactly one child, must be wrong");
        System.exit(1);
      }
      HGNode childNode = goalEdge.getAntNodes().get(0);
   

      //get the yield and number of nodes and edges in the tree
      int[] numNodesAndEdges = new int[2];
      String hypStr = getKthHyp(childNode, 1, hg.sentID, null, numNodesAndEdges);
      if(hypStr==null || (numNodesAndEdges[0]==0 && numNodesAndEdges[1]==0 )){
        logger.severe("hypStr==null or numNodesAndEdges==0, must be wrong");
        System.exit(1);
      }
     
      if(uniqueHyps.contains(hypStr)){//skip this edge
        hg.numNodes -= numNodesAndEdges[0];
        hg.numEdges -= (numNodesAndEdges[1]+1);
        hg.goalNode.hyperedges.remove(i);
        i--;
      }else{
        uniqueHyps.add(hypStr);       
      }
    }
  }
  //=========================== end kbestHypergraph
 
 
  public void lazyKBestExtractOnHG(
      HyperGraph hg, List<FeatureFunction> models,
      int topN, int sentID, final List<String> out) {
   
    CoIterator<String> coIt = new CoIterator<String>() {
     
      public void coNext(String hypStr) {
        out.add(hypStr);
      }
     
      public void finish() {         
      }
    };
   
    this.lazyKBestExtractOnHG(hg, models, topN, sentID,  coIt);
  }
 
 
  public void lazyKBestExtractOnHG(
      HyperGraph hg, List<FeatureFunction> models,
      int topN,
      int sentID,
      BufferedWriter out
    ) throws IOException {
     
      final BufferedWriter writer;
      if (null == out) {
        writer = new BufferedWriter(new OutputStreamWriter(System.out));
      } else {
        writer = out;
      }
     
      try {
       
        CoIterator<String> coIt = new CoIterator<String>() {
          public void coNext(String hypStr) {
            try {
              writer.write(hypStr);
              writer.write("\n");
            } catch (IOException e) {
              throw new UncheckedIOException(e);
            }
          }
         
          public void finish() {
          }
        };
       
        this.lazyKBestExtractOnHG(hg, models, topN, sentID, coIt);
      } catch (UncheckedIOException e) {
        e.throwCheckedException();
      }
    }
   

  private void lazyKBestExtractOnHG(
    HyperGraph hg,
    List<FeatureFunction> featureFunctions,
    int topN,
    int sentID,
    CoIterator<String> coit) {
 
    this.sentID = sentID;
    resetState();
   
    if (null == hg.goalNode)
      return;
   
    //VirtualItem virtual_goal_item = add_virtual_item(hg.goal_item);
    try {
      int nextN = 0;
      while (true) {
        String hypStr = getKthHyp(hg.goalNode, ++nextN, sentID, featureFunctions, null);
        if(hypStr !=null)
          coit.coNext(hypStr);
       
        if (null == hypStr || nextN >= topN)
          break;
      }
      //g_time_kbest_extract += System.currentTimeMillis()-start;
    } finally {
      coit.finish();
    }
  }
 
 
  public void resetState() {
    virtualNodesTbl.clear();
  }
 
 
  /* non-recursive function
   * format: sent_id ||| hyp ||| individual model cost ||| combined cost
   * sent_id<0: do not add sent_id
   * l_models==null: do not add model cost
   * add_combined_score==f: do not add combined model cost
   * */
  private String convertHyp2String(int sentID, DerivationState cur, List<FeatureFunction> models, String strHypNumeric, double[] modelCost){
    String[] tem = Regex.spaces.split(strHypNumeric);
    StringBuffer strHyp = new StringBuffer();
   
    //####sent id
    if (sentID >= 0) { // valid sent id must be >=0
      strHyp.append(sentID);
      strHyp.append(" ||| ");
    }
   
    //TODO: consider_start_sym
    //####hyp words
    for (int t = 0; t < tem.length; t++) {
      tem[t] = tem[t].trim();
      if (extractNbestTree && ( tem[t].startsWith("(") || tem[t].endsWith(")"))) { // tree tag
        if (tem[t].startsWith("(")) {
          if (includeAlign) {
            // we must account for the {i-j} substring
            int ijStrIndex = tem[t].indexOf('{');
            String tag = this.symbolTable.getWord(Integer.parseInt(tem[t].substring(1,ijStrIndex)));
            strHyp.append('(');
            strHyp.append(tag);
            strHyp.append(tem[t].substring(ijStrIndex)); // append {i-j}
          } else {
            String tag = this.symbolTable.getWord(Integer.parseInt(tem[t].substring(1)));
            strHyp.append('(');
            strHyp.append(tag);
          }
        } else {
          //note: it may have more than two ")", e.g., "3499))"
          int firstBracketPos = tem[t].indexOf(')');//TODO: assume the tag/terminal does not have ')'
          String tag = this.symbolTable.getWord(Integer.parseInt(tem[t].substring(0,firstBracketPos)));
          strHyp.append(tag);
          strHyp.append(tem[t].substring(firstBracketPos));
        }
      } else { // terminal symbol
        strHyp.append(this.symbolTable.getWord(Integer.parseInt(tem[t])));
      }
      if (t < tem.length-1) {
        strHyp.append(' ');
      }
    }
   
    //####individual model cost, and final transition cost
    if (null != modelCost) {
      strHyp.append(" |||");
      double temSum = 0.0;
      for (int k = 0; k < modelCost.length; k++) {
        strHyp.append(String.format(" %.3f", - modelCost[k]));
        temSum += modelCost[k]*models.get(k).getWeight();
       
//        System.err.println("tem_sum: " + tem_sum + " += " + model_cost[k] + " * " + l_models.get(k).getWeight());
      }
     
      int x = 0;
      x++;
     
      //sanity check
//      if (false) {
      if (performSanityCheck) {
        if (Math.abs(cur.cost - temSum) > 1e-2) {
          StringBuilder error = new StringBuilder();
          error.append("\nIn nbest extraction, Cost does not match; cur.cost: " + cur.cost + "; temsum: " +temSum + "\n");
          //System.out.println("In nbest extraction, Cost does not match; cur.cost: " + cur.cost + "; temsum: " +tem_sum);
          for (int k = 0; k < modelCost.length; k++) {
            error.append("model weight: " + models.get(k).getWeight() + "; cost: " +modelCost[k]+ "\n");
            //System.out.println("model weight: " + l_models.get(k).getWeight() + "; cost: " +model_cost[k]);
          }
          throw new RuntimeException(error.toString());
        }
      }
    }
   
    //####combined model cost
    if (addCombinedScore) {
      strHyp.append(String.format(" ||| %.3f",-cur.cost));
    }
   
//    System.err.println("Writing hyp");
   
    return strHyp.toString();
  }
   
 
  private VirtualNode addVirtualNode(HGNode it) {
    VirtualNode res = virtualNodesTbl.get(it);
    if (null == res) {
      res = new VirtualNode(it);
      virtualNodesTbl.put(it, res);
    }
    return res;
  }

 
//=========================== class VirtualNode ===========================
  /*to seed the kbest extraction, it only needs that each hyperedge should have the best_cost properly set, and it does not require any list being sorted
    *instead, the priority queue heap_cands will do internal sorting*/

  private  class VirtualNode {
   
    // derivations found so far, in the order they were taken from candHeap
    // (i.e. sorted by cost); in the paper this is D(^)[v]
    public List<DerivationState> nbests = new ArrayList<DerivationState>();
    // frontier states, best first; in the paper it is called cand[v].
    // Stays null until getCandidates() has seeded this node.
    private PriorityQueue<DerivationState> candHeap = null;
    // remembers which DerivationState signatures have already been put on the heap;
    // duplicates arise because, e.g., 1 2 + 1 0 == 2 1 + 0 1
    private HashMap<String, Integer>  derivationTbl = null;
    // remembers the unique yield *strings* at this node, used for unique-nbest-string
    // extraction; only created when extractUniqueNbest is set
    private HashMap<String, Integer> nbestStrTbl = null;
    // the hypergraph node this virtual node stands for
    HGNode pNode = null;
   
    /** Creates the (still unseeded) k-best state for hypergraph node {@code it}. */
    public VirtualNode(HGNode it) {
      this.pNode = it;
    }
   
    //return: the k-th hyp or null; k is started from one
    private DerivationState lazyKBestExtractOnNode(SymbolTable symbolTbl, KBestExtractor kbestExtator, int k) {
      if (nbests.size() >= k) { // no need to continue
        return nbests.get(k-1);
      }
     
      //### we need to fill in the l_nest in order to get k-th hyp
      DerivationState res = null;
      if (null == candHeap) {
        getCandidates(symbolTbl, kbestExtator);
      }
      int tAdded = 0; //sanity check
      while (nbests.size() < k) {
        if (candHeap.size() > 0) {
          res = candHeap.poll();
          //derivation_tbl.remove(res.get_signature());//TODO: should remove? note that two state may be tied because the cost is the same
          if (extractUniqueNbest) {
            boolean useTreeFormat=false;
            String res_str = res.getHypothesis(symbolTbl, kbestExtator, useTreeFormat, null, null,  null);
            // We pass false for extract_nbest_tree because we want;
            // to check that the hypothesis *strings* are unique,
            // not the trees.
            //@todo zhifei: this causes trouble to monolingual grammar as there is only one *string*, need to fix it
            if (! nbestStrTbl.containsKey(res_str)) {
              nbests.add(res);
              nbestStrTbl.put(res_str,1);
            }
          } else {
            nbests.add(res);
          }
          lazyNext(symbolTbl, kbestExtator, res);//always extend the last, add all new hyp into heap_cands
         
          //debug: sanity check
          tAdded++;
          if (!extractUniqueNbest && tAdded > 1) { // this is possible only when extracting unique nbest
            throw new RuntimeException("In lazyKBestExtractOnNode, add more than one time, k is " + k);
          }
        } else {
          break;
        }
      }
      if (nbests.size() < k) {
        res = null;//in case we do not get to the depth of k
      }
      //debug: sanity check
      //if (l_nbest.size() >= k && l_nbest.get(k-1) != res) {
      //throw new RuntimeException("In lazy_k_best_extract, ranking is not correct ");
      //}
     
      return res;
    }
   
    //last: the last item that has been selected, we need to extend it
    //get the next hyp at the "last" hyperedge
    private void lazyNext(SymbolTable symbolTbl, KBestExtractor kbestExtator, DerivationState last) {
      if (null == last.edge.getAntNodes()) {
        return;
      }
      for (int i = 0; i < last.edge.getAntNodes().size(); i++) { // slide the ant item
        HGNode it = (HGNode) last.edge.getAntNodes().get(i);
        VirtualNode virtualIT = kbestExtator.addVirtualNode(it);
        int[] newRanks = new int[last.ranks.length];
        for (int c = 0; c < newRanks.length;c++) {
          newRanks[c] = last.ranks[c];
        }
       
        newRanks[i] = last.ranks[i] + 1;
        String newSig = getDerivationStateSignature(last.edge, newRanks, last.edgePos);
       
        //why duplicate, e.g., 1 2 + 1 0 == 2 1 + 0 1
        if (derivationTbl.containsKey(newSig)) {
          continue;
        }
        virtualIT.lazyKBestExtractOnNode(symbolTbl, kbestExtator, newRanks[i]);
        if (newRanks[i] <= virtualIT.nbests.size() // exist the new_ranks[i] derivation
          /*&& "t" is not in heap_cands*/) { // already checked before, check this condition
          double cost = last.cost - virtualIT.nbests.get(last.ranks[i]-1).cost + virtualIT.nbests.get(newRanks[i]-1).cost;
          DerivationState t = new DerivationState(last.parentNode, last.edge, newRanks, cost, last.edgePos);
          candHeap.add(t);
          derivationTbl.put(newSig,1);
        }
      }
    }

    //this is the seeding function, for example, it will get down to the leaf, and sort the terminals
    //get a 1best from each hyperedge, and add them into the heap_cands
    private void getCandidates(SymbolTable symbolTbl, KBestExtractor kbestExtator) {
      candHeap = new PriorityQueue<DerivationState>();
      derivationTbl = new HashMap<String,Integer>();
      if (extractUniqueNbest) {
        nbestStrTbl = new HashMap<String,Integer>();
      }
      //sanity check
      if (null == pNode.hyperedges) {
        throw new RuntimeException("l_hyperedges is null in get_candidates, must be wrong");
      }
      int pos = 0;
      for (HyperEdge edge : pNode.hyperedges) {
        DerivationState t = getBestDerivation(symbolTbl, kbestExtator, pNode, edge, pos);
//        why duplicate, e.g., 1 2 + 1 0 == 2 1 + 0 1 , but here we should not get duplicate
        if (!derivationTbl.containsKey(t.getSignature())) {
          candHeap.add(t);
          derivationTbl.put(t.getSignature(),1);
        } else { // sanity check
          throw new RuntimeException(
            "get duplicate derivation in get_candidates, this should not happen"
            + "\nsignature is " + t.getSignature()
            + "\nl_hyperedge size is " + pNode.hyperedges.size()
            );
        }
        pos++;
     
     
//      TODO: if tem.size is too large, this may cause unnecessary computation, we comment the segment to accormodate the unique nbest extraction     
      /*if(tem.size()>global_n){
        heap_cands=new PriorityQueue<DerivationState>();
        for(int i=1; i<=global_n; i++)
          heap_cands.add(tem.poll());
      }else
        heap_cands=tem;
      */ 
    }
   
    //get my best derivation, and recursively add 1best for all my children, used by get_candidates only
    private DerivationState getBestDerivation(SymbolTable symbolTbl, KBestExtractor kbestExtator, HGNode parentNode, HyperEdge hyperEdge, int edgePos){
      int[] ranks;
      double cost=0;
      if(hyperEdge.getAntNodes()==null){//axiom
        ranks=null;
        cost= - hyperEdge.bestDerivationLogP;//seeding: this hyperedge only have one single translation for the terminal symbol
      }else{//best combination         
        ranks = new int[hyperEdge.getAntNodes().size()];         
        for(int i=0; i < hyperEdge.getAntNodes().size();i++){//make sure the 1best at my children is ready
          ranks[i]=1;//rank start from one                 
          HGNode child_it = (HGNode) hyperEdge.getAntNodes().get(i);//add the 1best for my children
          VirtualNode virtual_child_it = kbestExtator.addVirtualNode(child_it);
          virtual_child_it.lazyKBestExtractOnNode(symbolTbl, kbestExtator,  ranks[i]);
        }
        cost = - hyperEdge.bestDerivationLogP;//seeding
      }       
      DerivationState t = new DerivationState(parentNode, hyperEdge, ranks, cost, edgePos );
      return t;
    }
  };
 

//===============================================
//  class DerivationState
//===============================================
  /*each Node will maintain a list of this, each of which corresponds to a hyperedge and its children's ranks.
   * remember the ranks of a hyperedge node
   * used for kbest extraction*/
 
  //each DerivationState roughly correponds to a hypothesis
  private class DerivationState implements Comparable<DerivationState>
  {
    // the node that owns the edge
    HGNode parentNode;//the parentNode of the edge
    // the hyperedge this derivation starts with
    HyperEdge edge;//in the paper, it is "e"   
    //**lesson: once this was defined as a static variable, which caused big trouble
    // index of the edge in the hyperedge list of its parent node, used for signature calculation
    int edgePos; //this is my position in my parent's Item.l_hyperedges, used for signature calculation
    // one rank (starting from one) per antecedent of the edge; null for an edge without antecedent list
    int[] ranks;//in the paper, it is "j", which is a ArrayList of size |e|
    // total cost of this derivation; lower is better (see compareTo)
    double cost;//the cost of this hypothesis
   
    /**
     * @param pa  node that owns the edge
     * @param e   the hyperedge
     * @param r   ranks chosen at the antecedents of the edge, may be null
     * @param c   cost of the derivation
     * @param pos position of the edge among the hyperedges of {@code pa}
     */
    public DerivationState(HGNode pa, HyperEdge e, int[] r, double c ,int pos){
      parentNode = pa;
      edge =e ;
      ranks = r;
      cost=c;
      edgePos=pos;
    }
   
    private String getSignature() {
      StringBuffer res = new StringBuffer();
      //res.apend(p_edge2.toString());//Wrong: this may not be unique to identify a hyperedge (as it represent the class name and hashcode which my be equal for different objects)
      res.append(edgePos);
      if (null != ranks) {
        for (int i = 0; i < ranks.length;i++) {
          res.append(' ');
          res.append(ranks[i]);
        }
      }
      return res.toString();
    }
   
   
   
    //get the numeric sequence of the particular hypothesis
    //if want to get model cost, then have to set model_cost and l_models
    private String getHypothesis(SymbolTable symbolTbl, KBestExtractor kbestExtator, boolean useTreeFormat, double[] modelCost,
        List<FeatureFunction> models,  int[] numNodesAndEdges) {
      //### accumulate cost of p_edge into model_cost if necessary
      if (null != modelCost) {
        computeCost(parentNode, edge, modelCost, models);
      }
     
      //### get hyp string recursively
      StringBuffer res = new StringBuffer();
      Rule rl = edge.getRule();
     
      if (null == rl) { // hyperedges under "goal item" does not have rule
        if (useTreeFormat) {
          //res.append("(ROOT ");
          res.append('(');
          res.append(rootID);
          if (includeAlign) {
            // append "{i-j}"
            res.append('{');
            res.append(parentNode.i);
            res.append('-');
            res.append(parentNode.j);
            res.append('}');
          }
          res.append(' ');
        }
        for (int id = 0; id < edge.getAntNodes().size(); id++) {
          res.append( getChildDerivationState(kbestExtator, edge, id).getHypothesis(symbolTbl,kbestExtator, useTreeFormat, modelCost, models, numNodesAndEdges) );
          if (id < edge.getAntNodes().size()-1) res.append(' ');
        }
        if (useTreeFormat)
          res.append(')');
      } else {
        if (useTreeFormat) {
          res.append('(');
          res.append(rl.getLHS());
          if (includeAlign) {
            // append "{i-j}"
            res.append('{');
            res.append(parentNode.i);
            res.append('-');
            res.append(parentNode.j);
            res.append('}');
          }
          res.append(' ');
        }
        if (!isMonolingual) { // bilingual
          int[] english = rl.getEnglish();
          for (int c = 0; c < english.length; c++) {
            if (symbolTbl.isNonterminal(english[c])) {
              int id = symbolTbl.getTargetNonterminalIndex(english[c]);
              res.append( getChildDerivationState(kbestExtator, edge, id).getHypothesis(symbolTbl, kbestExtator, useTreeFormat, modelCost, models, numNodesAndEdges));
            } else {
              res.append(english[c]);
            }
            if (c < english.length-1) res.append(' ');
          }
        } else { // monolingual
          int[] french = rl.getFrench();
          int nonTerminalID = 0;//the position of the non-terminal in the rule
          for (int c = 0; c < french.length; c++) {
            if (symbolTbl.isNonterminal(french[c])) {
              res.append( getChildDerivationState(kbestExtator, edge, nonTerminalID).getHypothesis(symbolTbl,kbestExtator, useTreeFormat, modelCost, models, numNodesAndEdges));
              nonTerminalID++;
            } else {
              res.append(french[c]);
            }
            if (c < french.length-1) res.append(' ');
          }
        }
        if (useTreeFormat)
          res.append(')');
      }
      if(numNodesAndEdges!=null){
        numNodesAndEdges[0]++;
        numNodesAndEdges[1]++;
      }
      return res.toString();
    }
   
    private HGNode getHypothesis( KBestExtractor kbestExtator,  int[] numNodesAndEdges) {
     
      List<HGNode> newAntNodes = null;
      if(edge.getAntNodes()!=null){
        newAntNodes = new ArrayList<HGNode>();
        for (int id = 0; id < edge.getAntNodes().size(); id++) {
          HGNode newNode = getChildDerivationState(kbestExtator, edge, id).getHypothesis(kbestExtator, numNodesAndEdges)
          newAntNodes.add(newNode);
        }
      }
     
      HyperEdge newEdge = new HyperEdge(edge.getRule(), this.cost, edge.getTransitionLogP(false), newAntNodes, edge.getSourcePath());
      numNodesAndEdges[1]++;
     
      HGNode newNode = new HGNode(parentNode.i, parentNode.j, parentNode.lhs, parentNode.dpStates, newEdge, parentNode.getEstTotalLogP());
      numNodesAndEdges[0]++;
     
      return newNode; 
    }
   
    /*
    private void getNumNodesAndEdges(KBestExtractor kbestExtator, int[] numNodesAndEdges) {     
      if(edge.getAntNodes()!=null){
        for (int id = 0; id < edge.getAntNodes().size(); id++) {
          getChildDerivationState(kbestExtator, edge, id).getNumNodesAndEdges(kbestExtator, numNodesAndEdges) ;           
        }
      }
      numNodesAndEdges[0]++;
      numNodesAndEdges[1]++;
    }
    */
   
    private DerivationState getChildDerivationState(KBestExtractor kbestExtator, HyperEdge edge, int id){
      HGNode child = edge.getAntNodes().get(id);
      VirtualNode virtualChild = kbestExtator.addVirtualNode(child);
      return virtualChild.nbests.get(ranks[id]-1);
    }
   
    /*
    //TODO: we assume at most one lm, and the LM is the only non-stateles model
    //another potential difficulty in handling multiple LMs: symbol synchronization among the LMs
    //accumulate hyperedge cost into model_cost[], used by get_hyp()
    private void compute_cost_not_used(HyperEdge dt, double[] model_cost, ArrayList l_models){
      if(model_cost==null) return;
     
      //System.out.println("Rule is: " + dt.rule.toString());
      double stateless_transition_cost =0;
      FeatureFunction lm_model =null;
      int lm_model_index = -1;
      for(int k=0; k< l_models.size(); k++){
        FeatureFunction m = (FeatureFunction) l_models.get(k); 
        double t_res =0;
        if(m.isStateful() == false){//stateless feature
          if(dt.get_rule()!=null){//hyperedges under goal item do not have rules
            FFTransitionResult tem_tbl =  m.transition(dt.get_rule(), null, -1, -1);
            t_res = tem_tbl.getTransitionCost();
          }else{//final transtion
            t_res = m.finalTransition(null);
          }
          model_cost[k] += t_res;
          stateless_transition_cost += t_res*m.getWeight();
        }else{
          lm_model = m;
          lm_model_index = k;
        }
      }
      if(lm_model_index!=-1)//have lm model
        model_cost[lm_model_index] += (dt.get_transition_cost(false)-stateless_transition_cost)/lm_model.getWeight();
    }
    */
   
   
    //accumulate cost into modelCost
    private void computeCost(HGNode parentNode, HyperEdge dt, double[] modelCost, List<FeatureFunction> models){
      if (null == modelCost)
        return;
      //System.out.println("Rule is: " + dt.rule.toString());
      //double[] transitionCosts = ComputeNodeResult.computeModelTransitionCost(models, dt.getRule(), dt.getAntNodes(), parentNode.i, parentNode.j, dt.getSourcePath(), sentID);
      double[] transitionCosts = ComputeNodeResult.computeModelTransitionLogPs(models, dt, parentNode.i, parentNode.j, sentID);
   
      for(int i=0; i<transitionCosts.length; i++){
        modelCost[i] -= transitionCosts[i];
      }
    }
   
   
    //natual order by cost
    public int compareTo(DerivationState another) {
      if (this.cost < another.cost) {
        return -1;
      } else if (this.cost == another.cost) {
        return 0;
      } else {
        return 1;
      }
    }
   
  }//end of Class DerivationState
 
  private String getDerivationStateSignature(HyperEdge edge2, int[] ranks2, int pos) {
    StringBuffer sb = new StringBuffer();
    //sb.apend(p_edge2.toString());//Wrong: this may not be unique to identify a hyperedge (as it represent the class name and hashcode which my be equal for different objects)
    sb.append(pos);
    if (null != ranks2) {
      for (int i = 0; i < ranks2.length; i++) {
        sb.append(' ');
        sb.append(ranks2[i]);
      }
    }
    return sb.toString();
  }
}
TOP

Related Classes of joshua.decoder.hypergraph.KBestExtractor$VirtualNode

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.