// Package joshua.discriminative.variational_decoder.nbest
//
// Source Code of joshua.discriminative.variational_decoder.nbest.EstimateEntropyGapOnNbest

package joshua.discriminative.variational_decoder.nbest;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;
import joshua.discriminative.semiring_parsing.AtomicSemiring;
import joshua.discriminative.training.risk_annealer.nbest.NbestRiskGradientComputer;



public class EstimateEntropyGapOnNbest {
 
  int topN=300;

  boolean useUniqueNbest =false;
  boolean useTreeNbest = false;
  boolean addCombinedCost = true
 
 
  SymbolTable symbolTbl;
  KBestExtractor kbestExtractor;
 
  double scalingFactor= 1.0;

  AtomicSemiring atomicSemirng = new AtomicSemiring(1,0);
 
  public EstimateEntropyGapOnNbest(SymbolTable symbol_, double scalingFactor, int topn){
    this.scalingFactor = scalingFactor;
    this.topN = topn;
   
    this.symbolTbl = symbol_;
    this.kbestExtractor = new KBestExtractor(this.symbolTbl, this.useUniqueNbest, this.useTreeNbest, false, this.addCombinedCost,  false, true);
  }
   

  //if return_disorder_nbest=true; then a disorder nbest with sum log_prob (sent_id ||| hyp ||| empty_feature_scores ||| sum_log_prob)
  //else return 1best (just the hypotheses itself)
  public double processOneSent(HyperGraph hg, int sentenceID, boolean returnDisorderNbest){
    //### step-2: get a nbest derivations
    List<String> nbestNonUniqueStrings = new ArrayList<String>();
    kbestExtractor.lazyKBestExtractOnHG(hg, null, this.topN, sentenceID, nbestNonUniqueStrings);
   
    //### step-3: get the sum for each of the unique strings
    HashMap<String, Double> uniqueStringsSumProbTbl = new HashMap<String, Double>();
    HashMap<String, List<Double>> uniqueStringsListProbTbl = new HashMap<String, List<Double>>();
    double globalLogNorm = Double.NEGATIVE_INFINITY;
    for(String derivation_string : nbestNonUniqueStrings){
      String[] fds = derivation_string.split("\\s+\\|{3}\\s+");
      String hyp_string = fds[1];
      double log_prob = new Double(fds[fds.length-1])*scalingFactor;//normalized log prob //TODO: use inside_outside_scaling_factor here
     
     
      //sum probablity
      Double old_sum =  (Double)uniqueStringsSumProbTbl.get(hyp_string);
      if(old_sum==null){
        old_sum= Double.NEGATIVE_INFINITY;//zero prob
      }
      uniqueStringsSumProbTbl.put(hyp_string, atomicSemirng.add_in_atomic_semiring(old_sum, log_prob));
      globalLogNorm = atomicSemirng.add_in_atomic_semiring(globalLogNorm, log_prob);
      
      //list of probabilities
      List<Double> oldList  =  uniqueStringsListProbTbl.get(hyp_string);
      if(oldList==null){
        oldList= new ArrayList<Double>();
        uniqueStringsListProbTbl.put(hyp_string, oldList);
      }
      oldList.add(log_prob);
    }
   
    //### step-4: find the nbest or find the translation string having the best sum-probablity
   
    ArrayList<Double> listSumProbs = new ArrayList<Double>();
    double tGlobalSum = 0;
    double gap = 0;
    for(String hyp : uniqueStringsSumProbTbl.keySet()){//each unique string
      double sumProb   = Math.exp(uniqueStringsSumProbTbl.get(hyp) - globalLogNorm);
      tGlobalSum += sumProb;
      listSumProbs.add(sumProb);
     
      //compute H(d|y)
      double tLocalSum = 0;
      double sumLogProb = uniqueStringsSumProbTbl.get(hyp);
      List<Double> derivationProbs = uniqueStringsListProbTbl.get(hyp);
      for(int i=0; i<derivationProbs.size(); i++){
        double tProb = Math.exp( derivationProbs.get(i)-sumLogProb );
        tLocalSum += tProb;
        derivationProbs.set(i, tProb);
      }
      if(Math.abs(tLocalSum-1.0)>1e-4){System.out.println("local P is not sum to one, must be wrong; " +tLocalSum);  System.exit(1);}
      double entropyDGivenY = NbestRiskGradientComputer.computeEntropy(derivationProbs);
      gap += entropyDGivenY*sumProb;
    }
    double stringEntropy = NbestRiskGradientComputer.computeEntropy(listSumProbs);
    if(Math.abs(tGlobalSum-1.0)>1e-4){System.out.println("global P is not sum to one, must be wrong");  System.exit(1);}
 
    System.out.println("stringEntropy " + stringEntropy + "gap " + gap);
    return gap;
  }
 
 
 
  public static void main(String[] args) throws InterruptedException, IOException {
    if(args.length!=6){
      System.out.println("Wrong number of parameters, it must be  5");
      System.exit(1);
    }   
    //long start_time = System.currentTimeMillis();   
    String f_test_items=args[0].trim();
    String f_test_rules=args[1].trim();
    int num_sents=new Integer(args[2].trim());
    //String f_1best=args[3].trim();//output
    int topN = new Integer(args[4].trim());
    double inside_outside_scaling_factor = new Double(args[5].trim());
 
    int baseline_lm_feat_id = 0;
    SymbolTable p_symbol = new BuildinSymbol(null);
   
    EstimateEntropyGapOnNbest cruncher = new EstimateEntropyGapOnNbest(p_symbol, inside_outside_scaling_factor, topN);   
 
    //BufferedWriter t_writer_1best =  FileUtilityOld.getWriteFileStream(f_1best); 
   
    System.out.println("############Process file  " + f_test_items);
    DiskHyperGraph dhg_test = new DiskHyperGraph(p_symbol, baseline_lm_feat_id, true, null); //have model costs stored
    dhg_test.initRead(f_test_items, f_test_rules,null);   
    double sumGap = 0;
    for(int sent_id=0; sent_id < num_sents; sent_id ++){
      System.out.println("#Process sentence " + sent_id);
      HyperGraph hg_test = dhg_test.readHyperGraph();     
      double gap = cruncher.processOneSent(hg_test, sent_id, false);//produce the reranked onebest
      sumGap += gap;
    }
    sumGap *= 1.44;
    System.out.println("sum of bits in gap is: " +  sumGap);
    //FileUtilityOld.close_write_file(t_writer_1best);       
  }
}
// TOP
//
// Related Classes of joshua.discriminative.variational_decoder.nbest.EstimateEntropyGapOnNbest
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.