package joshua.discriminative.semiring_parsingv2.applications;
import java.util.ArrayList;
import java.util.logging.Logger;
import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.discriminative.semiring_parsingv2.DefaultIOParserWithXLinearCombinator;
import joshua.discriminative.semiring_parsingv2.SignedValue;
import joshua.discriminative.semiring_parsingv2.bilinear_operator.ScalarBO;
import joshua.discriminative.semiring_parsingv2.pmodule.ExpectationSemiringPM;
import joshua.discriminative.semiring_parsingv2.pmodule.ScalarPM;
import joshua.discriminative.semiring_parsingv2.semiring.ExpectationSemiring;
import joshua.discriminative.semiring_parsingv2.semiring.LogSemiring;
/**
 * Inside-outside application that attaches, to every hyperedge, a
 * log-domain probability {@code p} and a per-edge length value {@code r}
 * (the rule's English-side length minus the rule's arity), and accumulates
 * the first-order terms {@code p*r} and the second-order term {@code p*r*r}.
 * After {@link #normalizeGoal()}, {@link #getSecondOrderExpectation()}
 * reads the second-order term stored at the goal.
 *
 * <p>NOTE(review): "length minus arity" presumably counts the target-side
 * terminals of a rule, which would make the result the expected squared
 * hypothesis length; confirm against the rule representation.
 */
public class HypLenSquareExpectation extends DefaultIOParserWithXLinearCombinator<
        ExpectationSemiring<LogSemiring,ScalarPM>,
        ExpectationSemiringPM<LogSemiring,ScalarPM,ScalarPM,ScalarPM,ScalarBO>
        > {

    private static final Logger logger =
        Logger.getLogger(HypLenSquareExpectation.class.getName());

    /** Number of passes {@link #main(String[])} makes over the sentence set. */
    private static final int NUM_PASSES = 136;

    /** Multiplier applied to each edge's transition log-probability. */
    double scale;

    /** Bilinear operator shared by every X weight created by this class. */
    static ScalarBO pBilinearOperator = new ScalarBO();

    /**
     * @param scale_ factor multiplied into each edge's transition
     *               log-probability before it is wrapped in a {@link LogSemiring}
     */
    public HypLenSquareExpectation(double scale_){
        super();
        this.scale = scale_;
    }

    /** Creates a fresh X weight from default-constructed {@code s} and {@code t}. */
    @Override
    protected ExpectationSemiringPM<LogSemiring, ScalarPM, ScalarPM, ScalarPM, ScalarBO>
    createNewXWeight() {
        ScalarPM s = new ScalarPM();
        ScalarPM t = new ScalarPM();
        return new ExpectationSemiringPM<LogSemiring, ScalarPM, ScalarPM, ScalarPM, ScalarBO>(s, t, pBilinearOperator);
    }

    /** Creates a fresh K weight from default-constructed {@code p} and {@code r}. */
    @Override
    protected ExpectationSemiring<LogSemiring, ScalarPM> createNewKWeight() {
        LogSemiring p = new LogSemiring();
        ScalarPM r = new ScalarPM();
        return new ExpectationSemiring<LogSemiring, ScalarPM>(p, r);
    }

    /**
     * Builds the X weight {@code (s, t)} of an edge, where {@code s = p*r}
     * and {@code t = p*r*r}.
     *
     * @param dt          the hyperedge being weighted
     * @param parent_item the edge's parent node (not used here)
     * @return the edge's X weight
     */
    @Override
    protected ExpectationSemiringPM<LogSemiring, ScalarPM, ScalarPM, ScalarPM,ScalarBO>
    getEdgeXWeight(HyperEdge dt, HGNode parent_item) {
        LogSemiring p = scaledEdgeLogProb(dt);

        // s starts out as the raw value r; here r and s are the same quantity.
        ScalarPM s = edgeLengthValue(dt);

        // t = r * s. This must be computed BEFORE s is scaled by p below,
        // because multiSemiring mutates s in place.
        ScalarPM t = pBilinearOperator.bilinearMulti(s, s);

        // s = p * s
        s.multiSemiring(p);
        // t = p * t
        t.multiSemiring(p);

        return new ExpectationSemiringPM<LogSemiring, ScalarPM, ScalarPM, ScalarPM, ScalarBO>(s, t, pBilinearOperator);
    }

    /**
     * Builds the K weight {@code (p, p*r)} of an edge.
     *
     * @param dt          the hyperedge being weighted
     * @param parent_item the edge's parent node (not used here)
     * @return the edge's K weight
     */
    @Override
    protected ExpectationSemiring<LogSemiring, ScalarPM>
    getEdgeKWeight(HyperEdge dt, HGNode parent_item) {
        LogSemiring p = scaledEdgeLogProb(dt);
        ScalarPM r = edgeLengthValue(dt);
        // r = p * r
        r.multiSemiring(p);
        return new ExpectationSemiring<LogSemiring, ScalarPM>(p, r);
    }

    /** Wraps {@code scale * transitionLogP} of the edge in a new {@link LogSemiring}. */
    private LogSemiring scaledEdgeLogProb(HyperEdge dt) {
        double logProb = scale * dt.getTransitionLogP(false);
        return new LogSemiring(logProb);
    }

    /**
     * Returns a new {@link ScalarPM} holding the edge's length value:
     * English-side length minus arity, or 0 when the edge has no rule.
     */
    private ScalarPM edgeLengthValue(HyperEdge dt) {
        double val = 0; // real-valued, not log-domain
        if (dt.getRule() != null) {
            val = dt.getRule().getEnglish().length - dt.getRule().getArity();
        }
        return new ScalarPM(SignedValue.createSignedValueFromRealNumber(val));
    }

    /**
     * Applies {@code multiSemiring(-Z)} to the goal's {@code r}, {@code s}
     * and {@code t}, where {@code Z} is the log value of the goal's {@code p}.
     * NOTE(review): presumably this divides each term by the total
     * probability mass in the log domain; confirm against
     * {@code ScalarPM.multiSemiring(double)}.
     */
    @Override
    public void normalizeGoal() {
        ExpectationSemiring<LogSemiring,ScalarPM> goalKVal = getGoalK();
        ExpectationSemiringPM<LogSemiring, ScalarPM, ScalarPM, ScalarPM, ScalarBO> goalXVal = getGoalX();

        double normConstant = goalKVal.getP().getLogValue(); // p

        goalKVal.getR().multiSemiring(-normConstant); // r
        goalXVal.getS().multiSemiring(-normConstant); // s
        goalXVal.getT().multiSemiring(-normConstant); // t
    }

    /**
     * @return the goal's {@code t} component converted to a real number;
     *         call {@link #normalizeGoal()} first to obtain a normalized value
     */
    public double getSecondOrderExpectation(){
        return getGoalX().getT().getValue().convertToRealValue();
    }

    // #######################################################################

    /**
     * Command-line driver.
     *
     * @param args {@code args[0]}: path prefix; {@code ".items"} and
     *             {@code ".rules"} are appended to locate the hypergraph files.
     *             {@code args[1]} (optional): scale, default 1.
     *             {@code args[2]} (optional): number of sentences, default 5.
     * @throws NumberFormatException if the scale or the sentence count cannot be parsed
     */
    public static void main(String[] args) {
        if(args.length<1){
            System.out.println("Wrong command: java HypLenSquareExpectation f_nbest_prefix scale numSent ....");
            System.exit(1);
        }

        String f_dev_hg_prefix = args[0].trim();
        String f_dev_items = f_dev_hg_prefix +".items";
        String f_dev_rules = f_dev_hg_prefix +".rules";

        double scale = 1;
        if (args.length >= 2) {
            scale = Double.parseDouble(args[1].trim());
        }

        int num_sents = 5;
        if (args.length >= 3) {
            num_sents = Integer.parseInt(args[2].trim());
        }

        SymbolTable p_symbol = new BuildinSymbol(null);
        int baseline_lm_feat_id = 0;

        ArrayList<HyperGraph> hyperGraphs = new ArrayList<HyperGraph>();
        HypLenSquareExpectation ds = new HypLenSquareExpectation(scale);

        DiskHyperGraph diskHG = new DiskHyperGraph(p_symbol, baseline_lm_feat_id, true, null); //have model costs stored
        diskHG.initRead(f_dev_items, f_dev_rules,null);

        try {
            for (int k = 0; k < NUM_PASSES; k++) {
                for (int sent_id = 0; sent_id < num_sents; sent_id++) {
                    System.out.println("#Process sentence " + sent_id);

                    // The first pass reads each hypergraph from disk and caches it;
                    // later passes reuse the cached copy.
                    HyperGraph hg_test;
                    if (k == 0) {
                        hg_test = diskHG.readHyperGraph();
                        hyperGraphs.add(hg_test);
                    } else {
                        hg_test = hyperGraphs.get(sent_id);
                    }

                    ds.setHyperGraph(hg_test);
                    ds.runInsideOutside();
                    ds.normalizeGoal();

                    double lenSecondOrderExpectation = ds.getSecondOrderExpectation();
                    System.out.println("hyplensquireexpectation is " + lenSecondOrderExpectation);

                    ds.clearState();
                }
                HypLenSquareExpectation.logger.info("numTimesCalled=" + k);
            }
        } finally {
            // Release the readers even when reading or processing a sentence fails.
            diskHG.closeReaders();
        }
    }
}