package joshua.discriminative.ranker;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.JoshuaConfiguration;
import joshua.decoder.chart_parser.ComputeNodeResult;
import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;
import joshua.decoder.hypergraph.ViterbiExtractor;
import joshua.discriminative.DiscriminativeSupport;
import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.feature_related.feature_function.EdgeTblBasedBaselineFF;
/**This class implements functions to rank HG based on a bunch of feature functions
*It does not change the topology of the HG, but it changes the
*bestHyperedge and bestLogP in hypergraph.
**/
public class HGRanker {
private HashSet<HGNode> processedNodesTbl = new HashSet<HGNode>();
private List<FeatureFunction> featFunctions;
private int numChangedBestHyperedge = 0;
private static Logger logger = Logger.getLogger(HGRanker.class.getName());
public HGRanker(List<FeatureFunction> featFunctions){
this.featFunctions = featFunctions;
}
/**Change the input hypergraph based on featFunctions
*/
public void rankHG(HyperGraph hg){
resetState();
rankHGNode(hg.goalNode);
//logger.info("number of nodes whose best hyperedge changes is " + numChangedBestHyperedge
// + " among total number of nodes " + processedNodesTbl.size() );
resetState();
}
//get 1best HG
public HyperGraph rerankHGAndGet1best(HyperGraph hg){
rankHG(hg);
return ViterbiExtractor.getViterbiTreeHG(hg);
}
public void resetState(){
processedNodesTbl.clear();
numChangedBestHyperedge = 0;
}
private void rankHGNode(HGNode it ){
if(processedNodesTbl.contains(it))
return;
processedNodesTbl.add(it);
//==== recursively call my children deductions, change pointer for bestHyperedge
HyperEdge oldBestHyperedge = it.bestHyperedge;
it.bestHyperedge=null;
for(HyperEdge dt : it.hyperedges){
rankHyperEdge(it, dt );
it.semiringPlus(dt);
}
/**Due to diskHG precision, the behavior may not be precise
**/
if(it.bestHyperedge!=oldBestHyperedge){
numChangedBestHyperedge++;
}
}
private void rankHyperEdge(HGNode parentNode, HyperEdge dt){
dt.bestDerivationLogP = 0;
if(dt.getAntNodes()!=null){
for(HGNode antNode : dt.getAntNodes()){
rankHGNode(antNode);
dt.bestDerivationLogP += antNode.bestHyperedge.bestDerivationLogP;//semiring times
}
}
double transLogP = getTransitionLogP(parentNode, dt);
dt.setTransitionLogP(transLogP);
dt.bestDerivationLogP += transLogP;
}
private double getTransitionLogP(HGNode parentNode, HyperEdge dt ){
return ComputeNodeResult.computeCombinedTransitionLogP(
this.featFunctions, dt, parentNode.i, parentNode.j, -1);
}
// ================== example main funciton======================
public static void main(String[] args) throws IOException{
if(args.length<10){
System.out.println("wrong command, correct command");
System.out.println("num of args is "+ args.length);
for(int i=0; i <args.length; i++)
System.out.println("arg is: " + args[i]);
System.exit(0);
}
long startTime = System.currentTimeMillis();
String testNodesFile=args[0].trim();
String testRulesFile=args[1].trim();
int numSent = new Integer(args[2].trim());
String modelFile = args[3].trim();
double baselineWeight = new Double(args[4].trim());
boolean useTMFeat = new Boolean(args[5].trim());
boolean useLMFeat = new Boolean(args[6].trim());
boolean useEdgeNgramOnly = new Boolean(args[7].trim());
boolean useTMTargetFeat = new Boolean(args[8].trim());
String reranked1bestFile = args[9].trim();
boolean saveModelCosts = true;
String featureFile = null;
if(args.length>9)
featureFile = args[9].trim();
SymbolTable symbolTbl = new BuildinSymbol(null);
List<FeatureFunction> features = new ArrayList<FeatureFunction>();
//=== baseline feature ====
//TODO: ????????????????????????????????????????????????????
int baselineFeatID = 99;
//??????????????????????????????????????
EdgeTblBasedBaselineFF baselineFeature = new EdgeTblBasedBaselineFF(baselineFeatID, baselineWeight);
features.add(baselineFeature);
//=== reranking feature ===
//TODO: ??????????????
int ngramStateID = 0;
int baselineLMOrder = 5;
int startNgramOrder = 1;
int endNgramOrder = 2;
int featID = 100;
double weight = 1.0;
//????????
Map<String,Integer> rulesIDTable = null; //TODO??
//TODO
FeatureFunction rerankFF = DiscriminativeSupport.setupRerankingFeature(featID, weight, symbolTbl, useTMFeat, useLMFeat, useEdgeNgramOnly, useTMTargetFeat,
JoshuaConfiguration.useMicroTMFeat, JoshuaConfiguration.wordMapFile,
ngramStateID,
baselineLMOrder, startNgramOrder, endNgramOrder, featureFile, modelFile, rulesIDTable);
features.add(rerankFF);
//=== reranker using the feature functions
HGRanker reranker = new HGRanker(features);
BufferedWriter out1best = FileUtilityOld.getWriteFileStream(reranked1bestFile);
int topN=3;
boolean useUniqueNbest =true;
boolean useTreeNbest = false;
boolean addCombinedCost = true;
KBestExtractor kbestExtractor = new KBestExtractor(symbolTbl, useUniqueNbest, useTreeNbest, false, addCombinedCost, false, true);
DiskHyperGraph diskHG = new DiskHyperGraph(symbolTbl, ngramStateID, saveModelCosts, null);
diskHG.initRead(testNodesFile, testRulesFile, null);
for(int sentID=0; sentID < numSent; sentID ++){
System.out.println("#Process sentence " + sentID);
HyperGraph testHG = diskHG.readHyperGraph();
baselineFeature.collectTransitionLogPs(testHG);
reranker.rankHG(testHG);
try{
kbestExtractor.lazyKBestExtractOnHG(testHG, features, topN, sentID, out1best);
} catch (IOException e) {
e.printStackTrace();
}
}
FileUtilityOld.closeWriteFile(out1best);
System.out.println("Time cost: " + ((System.currentTimeMillis()-startTime)/1000));
}
}