package joshua.discriminative.syntax_reorder;
import java.io.BufferedReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.List;
import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.syntax_reorder.HashtableBasedHieroGrammarScorer.Rule;
/* Zhifei Li, <zhifei.work@gmail.com>
* Johns Hopkins University
*/
//TODO: (1) may ignore flat phrases in extract_rules; (2) the accummulation of the first feat;
//(3) the accumulation of other feats; (4) weight calculation; (5) ignore phrase whose lexical weght is zero
public class HieroExtractor {
public static int INVALID_POS=-1;
public static int INVALID_WRD_ID=-1;//TODO must be different from terminal and non-terminal symbols
public static String NULL_ALIGN_WRD_SYM="NULL";
public static int NULL_ALIGN_WRD_SYM_ID=0;
public static String NON_TERMINAL_TAG_SYM = "PHRASE";//tag for [PHRASE] or [X]
public static int NON_TERMINAL_TAG_SYM_ID = 0;//tag for [PHRASE] or [X]
public static int maxInitPhraseSize = 10;
public static int max_final_phrase_size = 5;
public static int min_sub_phrase_size = 2;
public static int max_num_non_terminals = 2;
public static String file_align="";
public static String file_zh="";
public static String file_en="";
public static String dir_grammar_out="";
public static String file_f2e_lexical_weights="";
public static String file_e2f_lexical_weights="";
public static Boolean allow_non_lexicial_rules=false;
public static Boolean forbid_adjacent_nonterminals=true;//in french
public static Boolean require_aligned_terminal=true;
public static Boolean use_tight_phrase=true;
public static Boolean remove_overlap_phrases=false;
public static Boolean keep_alignment_infor=false;
public static HashMap fweights_table;
public static HashMap eweights_table;
public static float[] fweights; //sentence-specific dynamical array
public static float[] eweights; //sentence-specific dynamical array
public static float[] fratios; //sentence-specific dynamical array
public static int g_num_init_phrases=0;
public static int g_num_rules_and_phrases=0;
public static SymbolTable symbolTable = null; //TODO
private static HashMap readWeightFile(String file){
//BufferedReader t_reader_tree = FileUtility.getReadFileStream("C:\\data_disk\\java_work_space\\SyntaxMT\\phraseExtraction\\parse.sync.berkeley1","UTF8");
BufferedReader t_reader = FileUtilityOld.getReadFileStream(file,"UTF8");
HashMap res = new HashMap();
String line;
int n=0;
while((line=FileUtilityOld.readLineLzf(t_reader))!=null){
n++;
if(n%500000==0) System.out.println("reading lines " + n);
String[] fds = line.split("\\s+");//format: wrd1 wrd2 weight
int id1 = symbolTable.addTerminal(fds[0]);
int id2 = symbolTable.addTerminal(fds[1]);
res.put(form_weight_key(id1,id2), new Double(fds[2]));
}
FileUtilityOld.closeReadFile(t_reader);
return res;
}
private static String form_weight_key(int id1, int id2){
StringBuffer res = new StringBuffer();
res.append(id1);
res.append("-");
res.append(id2);
return res.toString();
}
private static double get_weight_from_matrix(HashMap weights_maxtrix, int f, int e){
String key = form_weight_key(f,e);
return weights_maxtrix.containsKey(key) ? (Double)weights_maxtrix.get(key) : 0;
}
private static List<Rule> processASentence(String line_align, String line_fr, String line_en){
//==== create alignment datastructure
Alignment align = new Alignment(line_fr, line_en, line_align);
//==== compute weights
if(fweights_table!=null)
fweights = computeLexicalWeights(align, fweights_table, false, false);
if(eweights_table!=null)
eweights = computeLexicalWeights(align, eweights_table, true, false);//transpose
//==== exphrase phrases
int[] actualMaxInitPhraseSize = new int[1];
ArrayList initPhrases = extractPhrases(align, maxInitPhraseSize, actualMaxInitPhraseSize);//extract regular flat phrases;
if(initPhrases.size()==0){
System.out.println("warning: no init phrases are extracted");
return null;
}
//==== test-set specific filtering, the extract_phrases can use suffix-array architecture
if(use_tight_phrase==false)
loosenPhrases();
if(remove_overlap_phrases==true)
removeOverlapPhrases();
//==== create index ??
//==== add label to the flat phrases
//done in extract_phrases, if we want to implement loosen_phrases() and remove_overlap_phrases(), we should put it here
//==== extract rules: the Rule is a list of phrases from a specific training sentence-pair
List<Rule> rulesAndPhrases = extractRules(align, initPhrases, max_final_phrase_size, min_sub_phrase_size, max_num_non_terminals, actualMaxInitPhraseSize[0]);
return rulesAndPhrases;
}
//sentence-specific
//for each word, calculate the lexicalized weight based on the alighment and the lexical weight tables
private static float[] computeLexicalWeights(Alignment align, HashMap weights_maxtrix, Boolean transpose, Boolean swap){
int[] fwords, ewords, faligned;
if(transpose==false){
fwords = align.french_wrds;
ewords = align.english_wrds;
faligned = align.num_alignments_infor_for_french;
}else{
fwords = align.english_wrds;
ewords = align.french_wrds;
faligned = align.num_alignments_infor_for_english;
}
float[] results = new float[fwords.length];
for(int i=0; i<fwords.length; i++){
float total = 0;
int n=0;
if(faligned[i]>0){
for(int j=0; j<ewords.length; j++){
int flag;
if(transpose == false)
flag = align.alignment_matrix[i][j];
else
flag = align.alignment_matrix[j][i];
if(flag==1){//aligned
if(swap == false)
total += get_weight_from_matrix(weights_maxtrix, fwords[i],ewords[j]);
else
total += get_weight_from_matrix(weights_maxtrix, ewords[j],fwords[i]);
n++;
}
}
}else{//unaligned
if(swap == false)
total += get_weight_from_matrix(weights_maxtrix, fwords[i],NULL_ALIGN_WRD_SYM_ID);
else
total += get_weight_from_matrix(weights_maxtrix,NULL_ALIGN_WRD_SYM_ID, fwords[i]);
n++;
}
results[i] = total/n;
}
System.out.println("weights are ");
for(int t=0; t<results.length; t++)
System.out.print(results[t]);
System.out.print("\n");
return results;
}
//extract flat phrases: i1,i2,j1,j2; who are the positions in the french and english sentences
private static ArrayList extractPhrases(Alignment align, int maxInitPhraseSize, int[] actualMaxInitPhraseSize){
actualMaxInitPhraseSize[0]=0;
ArrayList l_phrases = new ArrayList();
int i1,i2,j1,j2;//french span: [i1,i2]; eng span: [j1,j2]
for(i1=0; i1< align.french_wrds.length; i1++){
if(align.num_alignments_infor_for_french[i1]<=0) continue;//skip unaligned wrd
j1 = align.english_wrds.length;
j2 = -1;
for(i2=i1; i2<Alignment.min(i1+maxInitPhraseSize, align.french_wrds.length); i2++){
if(align.num_alignments_infor_for_french[i2]<=0) continue;//skip unaligned wrd
//j1 and j2: [j1,j2] is the "maximum" (thoug may be in-consistent) eng span for the french span [i1,i2]
j1 = Alignment.min(j1, align.min_pos_infor_for_french[i2]);
j2 = Alignment.max(j2, align.max_pos_infor_for_french[i2]);
if(j1>j2) continue;//empty english span
if(j2-j1+1>maxInitPhraseSize) break; //go to next i1, since adding more wrds in french will increase the eng span
int flag=0;
for(int j=j1; j<=j2; j++){//for each english wrd in [j1,j2]
if(align.min_pos_infor_for_english[j]<i1){//must extend i1 to solve inconsistence
flag = 1; //next i1
break;
}
if(align.max_pos_infor_for_english[j]>i2){//fix i1, but extending i2 may solve this inconsistence
flag = 2; //next i2
break;
}
}
if(flag==1) break; //next i1
if(flag==2) continue;//next i2
//add the phrase
l_phrases.add( new int[] {i1,i2,j1,j2, NON_TERMINAL_TAG_SYM_ID});
actualMaxInitPhraseSize[0] = Alignment.max(actualMaxInitPhraseSize[0], i2-i1+1);
}
}
g_num_init_phrases += l_phrases.size();
return l_phrases;
}
private static void loosenPhrases(){
System.out.println("Error: un-implemented function");
System.exit(0);
}
private static void removeOverlapPhrases(){
System.out.println("Error: un-implemented function");
System.exit(0);
}
//extract hiearchical rules
private static List<Rule> extractRules(Alignment align, ArrayList l_init_phrases, int max_final_phrase_size, int min_sub_phrase_size, int max_num_non_terminals, int actual_max_init_phrase_size){
if(l_init_phrases.size()==0){
System.out.println("warning: no flat phrases are extracted");
return null;
}
int n= align.english_wrds.length;
ArrayList[][][] bins = new ArrayList[n+1][n+1][max_num_non_terminals+1];//bins[i][j] spans the french [i,j-1], each bin is a list of items
ArrayList[] i2index = new ArrayList[n];//each is a list of init-phrases ending with i2
Hashtable i1s = new Hashtable();
//get i2index and i1s
for(int t=0; t < l_init_phrases.size(); t++){
int[] phrase = (int[])l_init_phrases.get(t);
printPhrase(align,phrase);
int t_i1 =phrase[0], t_i2 =phrase[1];
if(i2index[t_i2]==null) i2index[t_i2] = new ArrayList();
i2index[t_i2].add(phrase);
if(i1s.containsKey(t_i1)==false){
i1s.put(t_i1, 1);
//chart seeding
bins[t_i1][t_i1][0] = new ArrayList();
bins[t_i1][t_i1][0].add(new ArrayList());//add empty item
}
}
System.out.println("num of init phrases is " + l_init_phrases.size() + "; i1s len: " + i1s.size() + "; french len " + align.french_wrds.length + " maxabslen; " + actual_max_init_phrase_size);
//chart parsing: each item is an arraylist
int loop1=0;
int loop2=0;
for(int k=1; k<=Alignment.min(n, actual_max_init_phrase_size); k++){
loop1++;
loop2=0;
for(int i1=0; i1+k<=n; i1++){
if(i1s.containsKey(i1)==false) continue;//because the phrases never start from this index; bug: this may skip the flat phrase
loop2++;
int i2=i1+k-1;
//extend the dot by a subphrase
int tem1=0, tem2=0;
if(i2index[i2]!=null){
for(int t=0; t<i2index[i2].size(); t++){//for all sub-phrases ending at i2
int[] sub_phrase = (int[]) i2index[i2].get(t);
if(sub_phrase[1]-sub_phrase[0]+1>=min_sub_phrase_size){
for(int n_nts=0; n_nts<max_num_non_terminals; n_nts++){
if(bins[i1][sub_phrase[0]][n_nts]!=null){//no ant-items
for(int t2=0; t2<bins[i1][sub_phrase[0]][n_nts].size(); t2++){//for all ant-items
ArrayList item = (ArrayList) bins[i1][sub_phrase[0]][n_nts].get(t2);
if(item.size()<max_final_phrase_size &&
!(forbid_adjacent_nonterminals && item.size()>0 && !(item.get(item.size()-1) instanceof Integer))
){
ArrayList new_item = new ArrayList(item);
new_item.add(sub_phrase);
if(bins[i1][i2+1][n_nts+1]==null) bins[i1][i2+1][n_nts+1] = new ArrayList();
bins[i1][i2+1][n_nts+1].add(new_item);
tem1++;
}
}
}
}
}
}
}
//extend the dot by a wrd
for(int n_nts=0; n_nts<=max_num_non_terminals; n_nts++){
if(bins[i1][i2][n_nts]!=null){
for(int t2=0; t2<bins[i1][i2][n_nts].size(); t2++){
ArrayList item = (ArrayList) bins[i1][i2][n_nts].get(t2);
if(item.size()<max_final_phrase_size ){
ArrayList new_item = new ArrayList(item);
new_item.add(i2);
if(bins[i1][i2+1][n_nts]==null) bins[i1][i2+1][n_nts] = new ArrayList();
bins[i1][i2+1][n_nts].add(new_item);
tem2++;
}
}
}
}
//tem3++;
//System.out.println(loop1 + " "+ loop2 + " number of subphrases textend is " + tem1 + " and " + tem2);
}
}
//extract rules from the chart
ArrayList l_phrases_and_rules = new ArrayList();
for(int t=0; t <l_init_phrases.size(); t++){
ArrayList local_results = new ArrayList();
int[] phrase = (int[])l_init_phrases.get(t);
for(int n_nts=0; n_nts<=max_num_non_terminals; n_nts++){
if(bins[phrase[0]][phrase[1]+1][n_nts]!=null){
for(int t2=0; t2<bins[phrase[0]][phrase[1]+1][n_nts].size(); t2++){
ArrayList item = (ArrayList) bins[phrase[0]][phrase[1]+1][n_nts].get(t2);
Rule rule = makeAndScoreRule(align, phrase, item);
if(rule!=null){
local_results.add(rule);
}
}
}
}
//distribute the count, normalization
for(int k=0; k < local_results.size(); k++){
Rule rl = (Rule)local_results.get(k);
for(int f=0; f<rl.feat_scores.length; f++)
rl.feat_scores[f] /= local_results.size();
rl.print_info(symbolTable);
}
l_phrases_and_rules.addAll(local_results);
//System.out.println("local size " + local_results.size() + " all size " + l_phrases_and_rules.size());
}
System.out.println("num of rules and phrases is " + l_phrases_and_rules.size());
g_num_rules_and_phrases += l_phrases_and_rules.size();
return l_phrases_and_rules;
}
private static Rule makeAndScoreRule(Alignment align, int[] phrase, ArrayList item){
//ignore ??
if(item.size()==1 && !(item.get(0) instanceof Integer))
return null; //bug: this may skip the flat phrase
int nt_index=1;
boolean have_alignment=false;
int original_en_len = phrase[3]-phrase[2]+1;
//System.out.println("original en len " + original_en_len);
//System.out.println("item is " + item.toString());
int[] original_en_wrds = new int[original_en_len] ;
for(int t=0; t<original_en_len; t++)
original_en_wrds[t]= align.english_wrds[t+phrase[2]];
//get french words in the rule
int[] rule_fwords = new int[item.size()];
int[] fpos = new int[item.size()];//remember the original position in the sentence
for(int t=0; t<item.size(); t++){
//System.out.println("size is " + item.size());
if(item.get(t) instanceof Integer){//terminal
//System.out.println("terminal");
fpos[t]=(Integer)item.get(t);
if(align.num_alignments_infor_for_french[ fpos[t] ]>0){
have_alignment=true;
}
rule_fwords[t]= align.french_wrds[ fpos[t] ];
}else{//non-terminal
//System.out.println("non-terminal");
fpos[t]=INVALID_POS;
int[] sub_phrase = (int[])item.get(t);
original_en_len -= sub_phrase[3] - sub_phrase[2];//reserved one slot for the NT symbol
//System.out.println("e span: " + sub_phrase[0] +" - " + sub_phrase[1] +" ; len " + original_en_len);
int nt = symbolTable.addNonterminal(NON_TERMINAL_TAG_SYM+","+nt_index);//get [PHRASE,nt_index]
rule_fwords[t] = nt;
original_en_wrds[sub_phrase[2]-phrase[2]]=nt;
for(int k=sub_phrase[2]-phrase[2]+1; k<=sub_phrase[3]-phrase[2]; k++)
original_en_wrds[k]= INVALID_WRD_ID;
nt_index++;
}
}
if(require_aligned_terminal && have_alignment==false)
return null;
//get English words in the rule
int[] rule_ewords = new int[original_en_len];//original_en_len is changed to actual en len
int[] epos = new int[original_en_len];
for(int t=0, k=0; t<original_en_wrds.length; t++){
if(original_en_wrds[t]!=INVALID_WRD_ID){
rule_ewords[k] = original_en_wrds[t];
if(symbolTable.isNonterminal(original_en_wrds[t])==true)
epos[k]=INVALID_POS;
else
epos[k]=phrase[2]+t;
k++;
}
}
//create rule
Rule rl = new Rule(phrase[4], rule_fwords , rule_ewords);
//add alignment infor
if(keep_alignment_infor==true){
ArrayList align_info =new ArrayList();
for(int i=0; i<fpos.length; i++)
if(fpos[i]!=INVALID_POS)
for(int j=0; j<epos.length; j++)
if(epos[j]!=INVALID_POS)
if(align.alignment_matrix[fpos[i]][epos[j]]==1)
align_info.add(i+"-"+j);//add "i-j"
rl.alignments=align_info;
}
scoreRule(align, rl, fpos, epos);
return rl;
}
//compute and add feat scores
private static Rule scoreRule(Alignment align, Rule r_in, int[] fpos, int[] epos){
int funaligned=0, eunaligned=0;
float fweight=1, eweight=1, fratio=0;
//P_lex(eng|fr)
for(int t=0; t<r_in.french.length; t++){
if( symbolTable.isNonterminal(r_in.french[t])==false ){
if(align.num_alignments_infor_for_french[fpos[t]]<=0)
funaligned++;
if(fweights!=null)
fweight *= fweights[fpos[t]];
if(fratios!=null)
fratio += fratios[fpos[t]];
}
}
//P_lex(fr|eng)
for(int t=0; t<r_in.english.length; t++){
if( symbolTable.isNonterminal(r_in.english[t])==false ){
if(align.num_alignments_infor_for_english[epos[t]]<=0)
eunaligned++;
if(eweights!=null)
eweight *= eweights[epos[t]];
}
}
//add the feat scores
int num_feats=1;
if(fweights!=null) num_feats++;
if(eweights!=null) num_feats++;
if(fratios!=null) num_feats++;
float[] scores = new float[num_feats];
int t_id=0;
scores[t_id++]=1.0f;
if(fweights!=null) scores[t_id++]=fweight;
if(eweights!=null) scores[t_id++]=eweight;;
if(fratios!=null) scores[t_id++]=fratio;;
r_in.feat_scores=scores;
return r_in;
}
private static void printPhrase(Alignment align, int[] phrase){
String str="zh: " + phrase[0] + "-" + phrase[1];
for(int t=phrase[0]; t<=phrase[1]; t++)
str += " " + symbolTable.getWord(align.french_wrds[t]);
str += " en: " + phrase[2] + "-" + phrase[3];
for(int t=phrase[2]; t<=phrase[3]; t++)
str += " " + symbolTable.getWord(align.english_wrds[t]);
str += " nt: " + symbolTable.getWord(phrase[4]);
System.out.println("phrase is = " +str);
}
public static void main(String[] args) {
SymbolTable symbolTable = new BuildinSymbol();
//init symbol
NULL_ALIGN_WRD_SYM_ID = symbolTable.addTerminal(NULL_ALIGN_WRD_SYM);
NON_TERMINAL_TAG_SYM_ID = symbolTable.addNonterminal(NON_TERMINAL_TAG_SYM);
//read weights files
eweights_table = readWeightFile("C:\\data_disk\\java_work_space\\SyntaxMT\\phraseExtraction\\lex.f2e.gz");
fweights_table = readWeightFile("C:\\data_disk\\java_work_space\\SyntaxMT\\phraseExtraction\\lex.e2f.gz");
HashtableBasedHieroGrammarScorer grammar;
if(fweights_table==null)
grammar= new HashtableBasedHieroGrammarScorer(2);
else
grammar= new HashtableBasedHieroGrammarScorer(4);
BufferedReader t_reader_tree = FileUtilityOld.getReadFileStream("C:\\data_disk\\java_work_space\\SyntaxMT\\phraseExtraction\\parse.sync.berkeley1","UTF8");
//BufferedReader t_reader_tree = FileUtility.getReadFileStream(args[1].trim(),"UTF8");
BufferedReader t_reader_align = FileUtilityOld.getReadFileStream("C:\\data_disk\\java_work_space\\SyntaxMT\\phraseExtraction\\aligned.ibm","UTF8");
//BufferedReader t_reader_tree = FileUtility.getReadFileStream(args[1].trim(),"UTF8");
BufferedReader t_reader_zh = FileUtilityOld.getReadFileStream("C:\\data_disk\\java_work_space\\SyntaxMT\\phraseExtraction\\aligned.zh","UTF8");
//BufferedReader t_reader_tree = FileUtility.getReadFileStream(args[1].trim(),"UTF8");
BufferedReader t_reader_en = FileUtilityOld.getReadFileStream("C:\\data_disk\\java_work_space\\SyntaxMT\\phraseExtraction\\aligned.en.tmp1","UTF8");
//BufferedReader t_reader_tree = FileUtility.getReadFileStream(args[1].trim(),"UTF8");
String alignLine, frLine, enLine;
int nLine=1;
while((alignLine=FileUtilityOld.readLineLzf(t_reader_align))!=null){
frLine = FileUtilityOld.readLineLzf(t_reader_zh);
enLine = FileUtilityOld.readLineLzf(t_reader_en);
//if(n_line++==1)continue;
List<Rule> rulesAndPhrases = processASentence( alignLine, frLine, enLine);
if(rulesAndPhrases!=null){//write into the grammar
for(int t=0; t< rulesAndPhrases.size(); t++)
grammar.addRawRule(rulesAndPhrases.get(t));
}
if(nLine>=50)
break;
nLine++;
}
System.out.println("total lines: " + nLine + "; total ini phrases " + g_num_init_phrases + "; total rules and phrases " + g_num_rules_and_phrases );
grammar.score_grammar();
}
}