/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.decoder.hypergraph;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.chart_parser.ComputeNodeResult;
import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.ff.state_maintenance.DPState;
import joshua.decoder.ff.state_maintenance.NgramDPState;
import joshua.decoder.ff.tm.BilingualRule;
import joshua.decoder.ff.tm.Grammar;
import joshua.decoder.ff.tm.GrammarReader;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.ff.tm.hiero.DiskHyperGraphFormatReader;
import joshua.decoder.ff.tm.hiero.MemoryBasedBatchGrammar;
import joshua.util.FileUtility;
import joshua.util.Regex;
/**
* this class implements functions of writting/reading hypergraph
* on disk. Limitations of this version
* (1) cannot recover each individual feature, notably the LM feature
* (2) assume we only have one stateful featuure, which must be a
* LM feature
*
* @author Zhifei Li, <zhifei.work@gmail.com>
* @author wren ng thornton <wren@users.sourceforge.net>
* @version $LastChangedDate: 2010-02-10 09:46:05 -0600 (Wed, 10 Feb 2010) $
*/
//Bottom-up
//line: SENTENCE_TAG, sent_id, sent_len, numNodes, numEdges (in average, numEdges is about 10 times larger than the numNodes, which is in average about 4000)
//line: ITEM_TAG, item id, i, j, lhs, numEdges, tbl_state;
//line: bestLogP, numNodes, item_ids, rule id, OOV-Non-Terminal (optional), OOV (optional), \newline feature scores
public class DiskHyperGraph {
//===============================================================
// Fields
//===============================================================
private int LMFeatureID = 0;
private SymbolTable symbolTable;
//when saving the hg, we simply compute all the model logP on the fly and store them on the disk
/*TODO: when reading the hg, we read thm into a WithModelCostsHyperEdge;
*now, we let a program outside this class to figure out which model logP corresponds which feature function, we should avoid this in the future*/
private List<FeatureFunction> featureFunctions;
// Whether to store the logPs at each HyperEdge
private boolean storeModelLogP = false;
// This will be set if the previous sentence is skipped
private String startLine;
private HashMap<HGNode,Integer> itemToID
= new HashMap<HGNode,Integer>(); // for saving hypergraph
private HashMap<Integer,HGNode> idToItem
= new HashMap<Integer,HGNode>(); // for reading hypergraph
private int currentItemID = 1;
private int qtyDeductions = 0;
// Shared by many hypergraphs, via the initialization functions
private HashMap<Integer,Rule> associatedGrammar = new HashMap<Integer, Rule>();
private BufferedWriter itemsWriter;
private BufferedReader itemsReader;
private HyperGraphPruning pruner;
// TODO: this is not pretty, but avoids re-allocation in writeRule()
private GrammarReader<BilingualRule> ruleReader;
// Set in init_read(...), used in read_hyper_graph()
private HashMap<Integer,?> selectedSentences;
private int sentID;
//===============================================================
// Static Fields
//===============================================================
private static final String SENTENCE_TAG = "#SENT: ";
private static final String ITEM_TAG = "#I";
private static final String ITEM_STATE_TAG = " ST ";
private static final String NULL_ITEM_STATE = "nullstate";
/* three kinds of rule:
* (>0) regular rule
* (0) oov rule
* (-1) null rule
*/
private static int NULL_RULE_ID = -1;
//FIXME: this is a hack for us to create OOVRule, and OOVRuleID
/**
* This is wrong as the default LHS and owner are not
* properly set. For this reason, the creation of OOV rule
* may cause bugs
*/
private static Grammar pGrammar = new MemoryBasedBatchGrammar();
private static final Logger logger =
Logger.getLogger(DiskHyperGraph.class.getName());
//===============================================================
// Constructors
//===============================================================
/**
* For saving purpose, one needs to specify the featureFunctions.
* For reading purpose, one does not need to provide the
* list.
*/
public DiskHyperGraph(SymbolTable symbolTable, int LMFeatureID,
boolean storeModelCosts, List<FeatureFunction> featureFunctions)
{
this.symbolTable = symbolTable;
this.LMFeatureID = LMFeatureID;
this.storeModelLogP = storeModelCosts;
this.featureFunctions = featureFunctions;
}
//===============================================================
// Initialization Methods
//===============================================================
/*
* for writting hyper-graph:
* (1) saving each hyper-graph;
* (2) remember each regualar rule used;
* (3) dump the rule jointly (in case parallel decoding)
*/
public void initWrite(String itemsFile, boolean useForestPruning,
double threshold) throws IOException
{
this.itemsWriter =
(null == itemsFile)
? new BufferedWriter(new OutputStreamWriter(System.out))
: FileUtility.getWriteFileStream(itemsFile);
if (ruleReader == null)
ruleReader = new DiskHyperGraphFormatReader(null, this.symbolTable);
if (useForestPruning) {
this.pruner = new HyperGraphPruning(this.symbolTable, true, threshold, threshold);
}
}
public void initRead(String hypergraphsFile, String rulesFile, HashMap<Integer,?> selectedSentences) {
try {
this.itemsReader = FileUtility.getReadFileStream(hypergraphsFile);
} catch (IOException e) {
logger.severe("Error opening hypergraph file: " + hypergraphsFile);
}
this.selectedSentences = selectedSentences;
/* Reload the rule table */
if (logger.isLoggable(Level.FINE))
logger.fine("Reading rules from file " + rulesFile);
this.associatedGrammar.clear();
this.ruleReader =
new DiskHyperGraphFormatReader(rulesFile, this.symbolTable);
if (ruleReader != null) {
ruleReader.initialize();
for (Rule rule : ruleReader) {
this.associatedGrammar.put(rule.getRuleID(), rule);
}
ruleReader.close();
}
}
public HashMap<Integer,Rule> getAssocatedGrammar(){
return associatedGrammar;
}
private void resetStates() {
this.itemToID.clear();
this.idToItem.clear();
this.currentItemID = 1;
this.qtyDeductions = 0;
}
public void closeReaders(){
try {
if(this.itemsReader!=null)
this.itemsReader.close();
if(this.ruleReader!=null)
this.ruleReader.close();
} catch (IOException e) {
e.printStackTrace();
}
}
public void closeItemsWriter(){
try {
if(this.itemsWriter!=null)
this.itemsWriter.close();
} catch (IOException e) {
e.printStackTrace();
}
}
//===============================================================
// Methods
//===============================================================
public void saveHyperGraph(HyperGraph hg) throws IOException {
resetStates();
if (null != this.pruner) this.pruner.pruningHG(hg);
constructItemTables(hg);
if (logger.isLoggable(Level.FINE))
logger.fine("Number of Items is: " + this.itemToID.size());
this.itemsWriter.write(
SENTENCE_TAG + hg.sentID
+ " " + hg.sentLen
+ " " + this.itemToID.size()
+ " " + this.qtyDeductions
+ "\n" );
this.sentID = hg.sentID;
// we save the hypergraph in a bottom-up way: so that reading is easy
if (this.idToItem.size() != this.itemToID.size()) {
throw new RuntimeException("Number of Items is not equal");
}
for (int i = 1; i <= this.idToItem.size(); i++) {
writeItem(this.idToItem.get(i));
}
if (null != this.pruner)
this.pruner.clearState();
}
/**
* Assign IDs to all HGNodes in the hypergraph. We do a
* depth-first traversal starting at the goal item, and
* assign IDs from the bottom up. BUG: this code could stack
* overflow for deep trees.
*/
private void constructItemTables(HyperGraph hg) {
resetStates();
constructItemTables(hg.goalNode);
}
/**
* This method is <i>really</i> private, and should only
* be called by constructItemTables(HyperGraph).
*/
private void constructItemTables(HGNode item) {
if (this.itemToID.containsKey(item)) return;
// first: assign IDs to all my antecedents
for (HyperEdge hyperEdge : item.hyperedges) {
this.qtyDeductions++;
if (null != hyperEdge.getAntNodes()) {
for (HGNode antecedentItem : hyperEdge.getAntNodes()) {
constructItemTables(antecedentItem);
}
}
}
// second: assign ID to "myself"
this.idToItem.put(this.currentItemID, item);
this.itemToID.put(item, this.currentItemID);
this.currentItemID++;
}
private void writeItem(HGNode item) throws IOException {
this.itemsWriter.write(
new StringBuffer()
.append(ITEM_TAG)
.append(" ")
.append(this.itemToID.get(item))
.append(" ")
.append(item.i)
.append(" ")
.append(item.j)
.append(" ")
.append(this.symbolTable.getWord(item.lhs))
.append(" ")
.append(
null == item.hyperedges
? 0
: item.hyperedges.size() )
.append(ITEM_STATE_TAG)
.append(
// Assume LM is the only stateful feature
null != item.getDPStates()
? item.getDPStates()
.get(this.LMFeatureID)
.getSignature(this.symbolTable, true)
: NULL_ITEM_STATE )
.append("\n")
.toString()
);
if (null != item.hyperedges) {
for (HyperEdge hyperEdge : item.hyperedges) {
writeHyperedge(item, hyperEdge);
}
}
this.itemsWriter.flush();
}
private final boolean isOutOfVocabularyRule(Rule rl) {
return (rl.getRuleID() == MemoryBasedBatchGrammar.OOV_RULE_ID);//pGrammar.getOOVRuleID());
}
private void writeHyperedge(HGNode node, HyperEdge edge)
throws IOException {
//get rule id
int ruleID = NULL_RULE_ID;
final Rule edgeRule = edge.getRule();
if (null != edgeRule) {
ruleID = edgeRule.getRuleID();
if (! isOutOfVocabularyRule(edgeRule)) {
this.associatedGrammar.put(ruleID, edgeRule); //remember used regular rule
}
}
StringBuffer s = new StringBuffer();
//line: bestLogP, numNodes, item_ids, rule id, OOV-Non-Terminal (optional), OOV (optional),
s.append(String.format("%.4f ", edge.bestDerivationLogP));
//s.append(" ").append(cur_d.bestDerivationLogP).append(" ");//this 1.2 faster than the previous statement
//s.append(String.format("%.4f ", cur_d.get_transition_logP(false)));
//s.append(cur_d.get_transition_logP(false)).append(" ");//this 1.2 faster than the previous statement, but cost 1.4 larger disk space
if (null == edge.getAntNodes()) {
s.append(0);
} else {
final int qtyItems = edge.getAntNodes().size();
s.append(qtyItems);
for (int i = 0; i < qtyItems; i++) {
s.append(' ')
.append(this.itemToID.get(
edge.getAntNodes().get(i) ));
}
}
s.append(' ')
.append(ruleID);
//if (ruleID == MemoryBasedBatchGrammar.OOV_RULE_ID) {//pGrammar.getOOVRuleID()) {
if (edgeRule != null) {
//System.out.println("lhs id: " + deduction_rule.getLHS());
//System.out.println("rule words: " + deduction_rule.getEnglish());
s.append(' ')
.append(this.symbolTable.getWord(edgeRule.getLHS()))
.append(' ')
.append(this.symbolTable.getWords(edgeRule.getEnglish()));
}
s.append('\n');
// save model logPs as a seprate line; optional
if (this.storeModelLogP) {
s.append( createModelLogPLine(node, edge) );
}
this.itemsWriter.write(s.toString());
}
/**
* Do not remove this function as it gives freedom for an
* extended class to override it
*/
public String createModelLogPLine(HGNode parentNode, HyperEdge edge){
StringBuffer line = new StringBuffer();
double[] transitionLogPs = null;
if(this.featureFunctions!=null){
transitionLogPs = ComputeNodeResult.computeModelTransitionLogPs(
this.featureFunctions, edge, parentNode.i, parentNode.j, this.sentID);
}else{
transitionLogPs = ((WithModelLogPsHyperEdge) edge).modeLogPs;
}
for (int k = 0; k < transitionLogPs.length; k++) {
line.append(String.format("%.4f", transitionLogPs[k]))
.append(
k < transitionLogPs.length - 1
? " "
: "\n");
}
return line.toString();
}
// End save_hyper_graph()
//===============================================================
public HyperGraph readHyperGraph() {
resetStates();
//read first line: SENTENCE_TAG, sent_id, sent_len, numNodes, num_deduct
String line = null;
if (null != this.startLine) { // the previous sentence is skipped
line = this.startLine;
this.startLine = null;
} else {
line = FileUtility.read_line_lzf(this.itemsReader);
}
if (! line.startsWith(SENTENCE_TAG)) {
throw new RuntimeException("wrong sent tag line: " + line);
}
// Test if we should skip this sentence
if (null != this.selectedSentences
&& (! this.selectedSentences.containsKey(
Integer.parseInt(Regex.spaces.split(line)[1]) ))
) {
while ((line = FileUtility.read_line_lzf(this.itemsReader)) != null) {
if (line.startsWith(SENTENCE_TAG)) break;
}
this.startLine = line;
System.out.println("sentence is skipped");
return null;
} else {
String[] fds = Regex.spaces.split(line);
int sentenceID = Integer.parseInt(fds[1]);
int sentenceLength = Integer.parseInt(fds[2]);
int qtyItems = Integer.parseInt(fds[3]);
int qtyDeductions = Integer.parseInt(fds[4]);
//System.out.println("numNodes: "+ qtyItems + "; num_deducts: " + qtyDeductions);
for (int i = 0; i < qtyItems; i++)
readNode();
//TODO check if the file reaches EOF, or if the num_deducts matches
//create hyper graph
HGNode goalItem = this.idToItem.get(qtyItems);
if (null == goalItem) {
throw new RuntimeException("no goal item");
}
return new HyperGraph(goalItem, qtyItems, qtyDeductions, sentenceID, sentenceLength);
}
}
private HGNode readNode() {
//line: ITEM_TAG itemID i j lhs qtyDeductions ITEM_STATE_TAG item_state
String line = FileUtility.read_line_lzf(this.itemsReader);
String[] fds = line.split(ITEM_STATE_TAG); // TODO: use joshua.util.Regex
if (fds.length != 2) {
throw new RuntimeException("wrong item line");
}
String[] words = Regex.spaces.split(fds[0]);
int itemID = Integer.parseInt(words[1]);
int i = Integer.parseInt(words[2]);
int j = Integer.parseInt(words[3]);
int lhs = this.symbolTable.addNonterminal(words[4]);
int qtyDeductions = Integer.parseInt(words[5]);
//item state: signature (created from HashMap tbl_states)
HashMap<Integer,DPState> dpStates = null;
if (fds[1].compareTo(NULL_ITEM_STATE) != 0) {
// Assume the only stateful feature is lm feature
dpStates = new HashMap<Integer,DPState>();
dpStates.put(this.LMFeatureID, new NgramDPState(this.symbolTable, fds[1]));
}
List<HyperEdge> edges = null;
HyperEdge bestEdge = null;
double bestLogP = Double.NEGATIVE_INFINITY;
if (qtyDeductions > 0) {
edges = new ArrayList<HyperEdge>();
for (int t = 0; t < qtyDeductions; t++) {
HyperEdge edge = readHyperedge();
edges.add(edge);
if (edge.bestDerivationLogP > bestLogP) {//semiring plus
bestLogP = edge.bestDerivationLogP;
bestEdge = edge;
}
}
}
HGNode item = new HGNode(i, j, lhs, edges, bestEdge, dpStates);
this.idToItem.put(itemID, item);
return item;
}
// Assumption: has this.associatedGrammar and this.idToItem
private HyperEdge readHyperedge() {
//line: bestLogP, numNodes, item_ids, rule id, OOV-Non-Terminal (optional), OOV (optional),
String line = FileUtility.read_line_lzf(this.itemsReader);
String[] fds = Regex.spaces.split(line);
//bestLogP numNodes item_ids
double bestLogP = Double.parseDouble(fds[0]);
ArrayList<HGNode> antecedentItems = null;
final int qtyAntecedents = Integer.parseInt(fds[1]);
if (qtyAntecedents > 0) {
antecedentItems = new ArrayList<HGNode>();
for (int t = 0; t < qtyAntecedents; t++) {
final int itemID = Integer.parseInt(fds[2+t]);
HGNode item = this.idToItem.get(itemID);
if (null == item) {
throw new RuntimeException("item is null for id: " + itemID);
}
antecedentItems.add(item);
}
}
//rule_id
Rule rule = null;
final int ruleID = Integer.parseInt(fds[2+qtyAntecedents]);
if (ruleID != NULL_RULE_ID) {
if (ruleID != MemoryBasedBatchGrammar.OOV_RULE_ID) {//pGrammar.getOOVRuleID()) {
rule = this.associatedGrammar.get(ruleID);
if (null == rule) {
throw new RuntimeException("rule is null but id is " + ruleID);
}
} else {
rule = pGrammar.constructOOVRule(1, this.symbolTable.addTerminal(fds[4+qtyAntecedents]), this.symbolTable.addTerminal(fds[4+qtyAntecedents]), false);
/**This is a hack. as the pGrammar does not set defaultLHS properly*/
int lhs = this.symbolTable.addNonterminal(fds[3+qtyAntecedents]);
rule.setLHS(lhs);
}
} else {
// Do nothing: goal item has null rule
}
HyperEdge hyperEdge;
if (this.storeModelLogP) {
String[] logPString =
Regex.spaces.split(FileUtility.read_line_lzf(this.itemsReader));
double[] logPs = new double[logPString.length];
for (int i = 0; i < logPString.length; i++) {
logPs[i] = Double.parseDouble(logPString[i]);
}
hyperEdge = new WithModelLogPsHyperEdge(rule, bestLogP, null, antecedentItems, logPs, null);
} else {
hyperEdge = new HyperEdge(rule, bestLogP, null, antecedentItems, null);
}
hyperEdge.getTransitionLogP(true); // to set the transition logP
return hyperEdge;
}
// end readHyperGraph()
//===============================================================
static public Map<String,Integer> obtainRuleStringToIDTable(String rulesFile) {
SymbolTable symbolTable = new BuildinSymbol(null);
GrammarReader<BilingualRule> ruleReader = new DiskHyperGraphFormatReader(rulesFile, symbolTable);
Map<String,Integer> rulesIDTable = new HashMap<String,Integer>();
ruleReader.initialize();
for (Rule rule : ruleReader) {
rulesIDTable.put(rule.toStringWithoutFeatScores(symbolTable), rule.getRuleID());
}
ruleReader.close();
return rulesIDTable;
}
static public int mergeDiskHyperGraphs(int ngramStateID, boolean saveModelCosts, int totalNumSent,
boolean useUniqueNbest, boolean useTreeNbest,
String filePrefix1, String filePrefix2, String filePrefixOut, boolean removeDuplicate) throws IOException{
SymbolTable symbolTbl = new BuildinSymbol();
DiskHyperGraph diskHG1 = new DiskHyperGraph(symbolTbl, ngramStateID, saveModelCosts, null);
diskHG1.initRead(filePrefix1+".hg.items", filePrefix1+".hg.rules", null);
DiskHyperGraph diskHG2 = new DiskHyperGraph(symbolTbl, ngramStateID, saveModelCosts, null);
diskHG2.initRead(filePrefix2+".hg.items", filePrefix2+".hg.rules", null);
DiskHyperGraph diskHGOut = new DiskHyperGraph(symbolTbl, ngramStateID, saveModelCosts, null);
//TODO
boolean forestPruning = false;
double forestPruningThreshold = -1;
diskHGOut.initWrite(filePrefixOut + ".hg.items", forestPruning, forestPruningThreshold);
KBestExtractor kbestExtrator = new KBestExtractor(symbolTbl, useUniqueNbest, useTreeNbest,
false, false, false, false);
int totalNumHyp = 0;
for(int sentID=0; sentID < totalNumSent; sentID ++){
//System.out.println("#Process sentence " + sentID);
HyperGraph hg1 = diskHG1.readHyperGraph();
HyperGraph hg2 = diskHG2.readHyperGraph();
//filter hypergraphs by removing duplicate
if(removeDuplicate){
Set<String> uniqueHyps = new HashSet<String>();
kbestExtrator.filterKbestHypergraph(hg1, uniqueHyps);
kbestExtrator.filterKbestHypergraph(hg2, uniqueHyps);
}
HyperGraph mergedHG = HyperGraph.mergeTwoHyperGraphs(hg1, hg2);
diskHGOut.saveHyperGraph(mergedHG);
/*System.out.println("size1=" + hg1.goalNode.hyperedges.size() +
"; size2=" + hg2.goalNode.hyperedges.size() +
"; mergedsize=" + mergedHG.goalNode.hyperedges.size());
*/
totalNumHyp += mergedHG.goalNode.hyperedges.size();
}
diskHGOut.writeRulesNonParallel(filePrefixOut + ".hg.rules");
System.out.println("totalMergeSize="+totalNumHyp);
diskHG1.closeReaders();
diskHG2.closeReaders();
diskHGOut.closeReaders();
diskHGOut.closeItemsWriter();
return totalNumHyp;
}
public void writeRulesNonParallel(String rulesFile)
throws IOException {
BufferedWriter out =
(null == rulesFile)
? new BufferedWriter(new OutputStreamWriter(System.out))
: FileUtility.getWriteFileStream(rulesFile) ;
logger.info("writing rules");
for (int ruleID : this.associatedGrammar.keySet()) {
writeRule(out, this.associatedGrammar.get(ruleID), ruleID);
}
out.flush();
out.close();
this.ruleReader.close();
}
// writtenRules: remember what kind of rules have already been saved
public void writeRulesParallel(BufferedWriter out,
HashMap<Integer,Integer> writtenRules) throws IOException
{
logger.info("writing rules in a partition");
for (int ruleID : this.associatedGrammar.keySet()) {
if (! writtenRules.containsKey(ruleID)) {//not been written on disk yet
writtenRules.put(ruleID, 1);
writeRule(out, this.associatedGrammar.get(ruleID), ruleID);
}
}
out.flush();
}
private void writeRule(BufferedWriter out, Rule rule,
int ruleID) throws IOException
{
// HACK: this is VERY wrong, but avoiding it seems to require major architectural changes
out.write(this.ruleReader.toWords( (BilingualRule) rule));
out.write("\n");
}
}