Package joshua.discriminative.semiring_parsingv2.applications

Source Code of joshua.discriminative.semiring_parsingv2.applications.HypLenSquareExpectation

package joshua.discriminative.semiring_parsingv2.applications;

import java.util.ArrayList;
import java.util.logging.Logger;

import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.discriminative.semiring_parsingv2.DefaultIOParserWithXLinearCombinator;
import joshua.discriminative.semiring_parsingv2.SignedValue;
import joshua.discriminative.semiring_parsingv2.bilinear_operator.ScalarBO;
import joshua.discriminative.semiring_parsingv2.pmodule.ExpectationSemiringPM;
import joshua.discriminative.semiring_parsingv2.pmodule.ScalarPM;
import joshua.discriminative.semiring_parsingv2.semiring.ExpectationSemiring;
import joshua.discriminative.semiring_parsingv2.semiring.LogSemiring;




public class HypLenSquareExpectation  extends DefaultIOParserWithXLinearCombinator<
ExpectationSemiring<LogSemiring,ScalarPM>,
ExpectationSemiringPM<LogSemiring,ScalarPM,ScalarPM,ScalarPM,ScalarBO>
> {
 
  private static final Logger logger =
    Logger.getLogger(HypLenSquareExpectation.class.getName());
 
  double scale;
  static  ScalarBO pBilinearOperator = new ScalarBO();
 
  public HypLenSquareExpectation(double scale_){
    super();
    this.scale = scale_;
  }

  @Override
  protected ExpectationSemiringPM<LogSemiring, ScalarPM, ScalarPM, ScalarPM, ScalarBO>
  createNewXWeight() {
    ScalarPM s = new ScalarPM();
    ScalarPM t = new ScalarPM();
    return new ExpectationSemiringPM<LogSemiring, ScalarPM, ScalarPM, ScalarPM, ScalarBO>(s, t, pBilinearOperator);
  }

  @Override
  protected ExpectationSemiring<LogSemiring, ScalarPM> createNewKWeight() {
    LogSemiring p = new LogSemiring();
    ScalarPM s = new ScalarPM();
    return new ExpectationSemiring<LogSemiring, ScalarPM>(p,s);
  }

  @Override
  protected ExpectationSemiringPM<LogSemiring, ScalarPM, ScalarPM, ScalarPM,ScalarBO>
  getEdgeXWeight(HyperEdge dt, HGNode parent_item) {
    //== p
    double logProb = scale * dt.getTransitionLogP(false);
    LogSemiring p = new LogSemiring(logProb);
   
    //== r
    double val = 0;//real
    if(dt.getRule()!=null){
      val = dt.getRule().getEnglish().length-dt.getRule().getArity();//length; real semiring
    }   
    ScalarPM r =  new ScalarPM( SignedValue.createSignedValueFromRealNumber(val) );
   
    //== s
    ScalarPM s = r;
   
    //== t
    ScalarPM t = pBilinearOperator.bilinearMulti(r, s);
   
    //s = p s
    s.multiSemiring(p);
   
    //t= p t
    t.multiSemiring(p);
   
    return new ExpectationSemiringPM<LogSemiring, ScalarPM, ScalarPM, ScalarPM, ScalarBO>(s, t, pBilinearOperator);
  }
 
 

  @Override
  protected ExpectationSemiring<LogSemiring, ScalarPM>
  getEdgeKWeight(HyperEdge dt, HGNode parent_item) {
    //== p
    double logProb = scale * dt.getTransitionLogP(false);
    LogSemiring p = new LogSemiring(logProb);
   
    //== r
    double val = 0;//real
    if(dt.getRule()!=null){
      val = dt.getRule().getEnglish().length-dt.getRule().getArity();//length; real semiring
    }   
    ScalarPM r =  new ScalarPM( SignedValue.createSignedValueFromRealNumber(val) );
   
    // r= p r
    r.multiSemiring(p);
   
    return new ExpectationSemiring<LogSemiring, ScalarPM>(p,r);
  }
 

  @Override
  public void normalizeGoal() {
    ExpectationSemiring<LogSemiring,ScalarPM> goalKVal = getGoalK();
    ExpectationSemiringPM<LogSemiring, ScalarPM, ScalarPM, ScalarPM, ScalarBO> goalXVal = getGoalX();
   
    //goalKVal.printInfor();
    //goalXVal.printInfor();
   
    double normConstant = goalKVal.getP().getLogValue();//p
    goalKVal.getR().multiSemiring(-normConstant);//r
    goalXVal.getS().multiSemiring(-normConstant);//s
    goalXVal.getT().multiSemiring(-normConstant);//t
   
  }
 
 
  public double getSecondOrderExpectation(){
    return getGoalX().getT().getValue().convertToRealValue();
  }


//  ####################################################################### 
  public static void main(String[] args) {
    if(args.length<1){
      System.out.println("Wrong command: java HypLenSquareExpectation  f_nbest_prefix scale numSent ....");
      System.exit(1);
    }
   
   
    String f_dev_hg_prefix=args[0].trim();
    String f_dev_items = f_dev_hg_prefix +".items";
    String f_dev_rules = f_dev_hg_prefix +".rules";
   
    double scale = 1;
    if(args.length>=2)
      scale = new Double(args[1].trim());
   
    int num_sents =5;
    if(args.length>=3)
      num_sents = new Integer(args[2].trim());
   
 
   
    SymbolTable p_symbol = new BuildinSymbol(null);
    int baseline_lm_feat_id =0;
   
    ArrayList<HyperGraph> hyperGraphs = new  ArrayList<HyperGraph>();;
   
    HypLenSquareExpectation ds = new HypLenSquareExpectation(scale);
    DiskHyperGraph diskHG = new DiskHyperGraph(p_symbol, baseline_lm_feat_id, true, null); //have model costs stored
    diskHG.initRead(f_dev_items, f_dev_rules,null);
    for(int k=0;k<136; k++){
      for(int sent_id=0; sent_id < num_sents; sent_id ++){
        System.out.println("#Process sentence " + sent_id);
        HyperGraph hg_test;
        if(k==0){ 
          hg_test = diskHG.readHyperGraph();
          hyperGraphs.add(hg_test);
        }else
          hg_test = hyperGraphs.get(sent_id);
        ds.setHyperGraph(hg_test);
        ds.runInsideOutside();
        //ds.printTotalX();
        ds.normalizeGoal();
        double lenSecondOrderExpectation = ds.getSecondOrderExpectation();
         
       
        System.out.println("hyplensquireexpectation is " + lenSecondOrderExpectation);
        ds.clearState();
      }
      HypLenSquareExpectation.logger.info("numTimesCalled=" + k);
    }
    diskHG.closeReaders();
   
   
   
  }



 

 
}
TOP

Related Classes of joshua.discriminative.semiring_parsingv2.applications.HypLenSquareExpectation

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.