// Package: joshua.discriminative.syntax_reorder
// Source code of joshua.discriminative.syntax_reorder.HieroExtractor
// (site-listing header converted to a comment so the file compiles)
package joshua.discriminative.syntax_reorder;

import java.io.BufferedReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.List;

import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.syntax_reorder.HashtableBasedHieroGrammarScorer.Rule;



/* Zhifei Li, <zhifei.work@gmail.com>
* Johns Hopkins University
*/
//TODO: (1) may ignore flat phrases in extract_rules; (2) the accummulation of the first feat;
//(3) the accumulation of other feats; (4) weight calculation; (5) ignore phrase whose lexical weght is zero

public class HieroExtractor {
 
 
  public static int INVALID_POS=-1;
  public static int INVALID_WRD_ID=-1;//TODO must be different from terminal and non-terminal symbols
 
  public static String NULL_ALIGN_WRD_SYM="NULL";
  public static int NULL_ALIGN_WRD_SYM_ID=0;
 
  public static String NON_TERMINAL_TAG_SYM = "PHRASE";//tag for [PHRASE] or [X]
  public static int NON_TERMINAL_TAG_SYM_ID = 0;//tag for [PHRASE] or [X]
 
  public static int maxInitPhraseSize = 10;
  public static int max_final_phrase_size = 5;
  public static int min_sub_phrase_size = 2;
  public static int max_num_non_terminals = 2;
 
  public static String file_align="";
  public static String file_zh="";
  public static String file_en="";
  public static String dir_grammar_out="";
 
  public static String file_f2e_lexical_weights="";
  public static String file_e2f_lexical_weights="";
 
  public static Boolean allow_non_lexicial_rules=false;
  public static Boolean forbid_adjacent_nonterminals=true;//in french
  public static Boolean require_aligned_terminal=true;
 
  public static Boolean use_tight_phrase=true;
 
  public static Boolean remove_overlap_phrases=false;
  public static Boolean keep_alignment_infor=false;
 
  public static HashMap fweights_table;
  public static HashMap eweights_table;
  public static float[] fweights; //sentence-specific dynamical array
  public static float[] eweights; //sentence-specific dynamical array
  public static float[] fratios; //sentence-specific dynamical array
 
  public static int g_num_init_phrases=0;
  public static int g_num_rules_and_phrases=0;
 
  public static SymbolTable symbolTable = null; //TODO
 
 
 
  private static  HashMap readWeightFile(String file){
    //BufferedReader t_reader_tree = FileUtility.getReadFileStream("C:\\data_disk\\java_work_space\\SyntaxMT\\phraseExtraction\\parse.sync.berkeley1","UTF8");
    BufferedReader t_reader = FileUtilityOld.getReadFileStream(file,"UTF8");
    HashMap res = new HashMap();
    String line;
    int n=0;
    while((line=FileUtilityOld.readLineLzf(t_reader))!=null){
      n++;
      if(n%500000==0) System.out.println("reading lines " + n);
      String[] fds = line.split("\\s+");//format: wrd1 wrd2 weight
      int id1 = symbolTable.addTerminal(fds[0]);
      int id2 = symbolTable.addTerminal(fds[1]);     
      res.put(form_weight_key(id1,id2), new Double(fds[2]));
    }
    FileUtilityOld.closeReadFile(t_reader);
    return res;
  }
  private static String form_weight_key(int id1, int id2){
    StringBuffer res = new StringBuffer();
    res.append(id1);
    res.append("-");
    res.append(id2);
    return res.toString();
  }
  private static double get_weight_from_matrix(HashMap weights_maxtrix, int f, int e){
    String key = form_weight_key(f,e);
    return weights_maxtrix.containsKey(key) ? (Double)weights_maxtrix.get(key) : 0;
  }
 
  private static List<Rule> processASentence(String line_align, String line_fr, String line_en){ 

    //==== create alignment datastructure
    Alignment align = new Alignment(line_fr, line_en, line_align);
   
    //==== compute weights
    if(fweights_table!=null)
      fweights = computeLexicalWeights(align, fweights_table, false, false);
    if(eweights_table!=null)
      eweights = computeLexicalWeights(align, eweights_table, true, false);//transpose
   
   
    //==== exphrase phrases
    int[] actualMaxInitPhraseSize = new int[1];
    ArrayList initPhrases = extractPhrases(align, maxInitPhraseSize, actualMaxInitPhraseSize);//extract regular flat phrases;
    if(initPhrases.size()==0){
      System.out.println("warning: no init phrases are extracted");
      return null;
    }
   
    //==== test-set specific filtering, the extract_phrases can use suffix-array architecture
    if(use_tight_phrase==false
      loosenPhrases();
   
    if(remove_overlap_phrases==true)
      removeOverlapPhrases();
   
    //==== create index ??           
    //==== add label to the flat phrases
    //done in extract_phrases, if we want to implement loosen_phrases() and remove_overlap_phrases(), we should put it here
   
    //==== extract rules: the Rule is a list of phrases from a specific training sentence-pair
    List<Rule> rulesAndPhrases = extractRules(align, initPhrases, max_final_phrase_size, min_sub_phrase_size, max_num_non_terminals, actualMaxInitPhraseSize[0]);     
   
    return rulesAndPhrases;
  }
 
 
  //sentence-specific
  //for each word, calculate the lexicalized weight based on the alighment and the lexical weight tables
  private static float[] computeLexicalWeights(Alignment align, HashMap weights_maxtrix, Boolean transpose, Boolean swap){
    int[] fwords, ewords, faligned;
    if(transpose==false){
      fwords = align.french_wrds;
      ewords = align.english_wrds;
      faligned = align.num_alignments_infor_for_french;
    }else{
      fwords = align.english_wrds;
      ewords = align.french_wrds;
      faligned = align.num_alignments_infor_for_english;
    }
    float[] results = new float[fwords.length];
    for(int i=0; i<fwords.length; i++){
      float total = 0;
      int n=0;
      if(faligned[i]>0){
        for(int j=0; j<ewords.length; j++){
          int flag;
          if(transpose == false)
            flag = align.alignment_matrix[i][j];
          else
            flag = align.alignment_matrix[j][i];
          if(flag==1){//aligned
            if(swap == false)
              total += get_weight_from_matrix(weights_maxtrix, fwords[i],ewords[j]);
            else
              total += get_weight_from_matrix(weights_maxtrix, ewords[j],fwords[i]);           
            n++;
          }           
        }
      }else{//unaligned
        if(swap == false)
          total += get_weight_from_matrix(weights_maxtrix, fwords[i],NULL_ALIGN_WRD_SYM_ID);
        else         
          total += get_weight_from_matrix(weights_maxtrix,NULL_ALIGN_WRD_SYM_ID, fwords[i]);
        n++;
      }
      results[i] = total/n;
    }
   
    System.out.println("weights are ");
    for(int t=0; t<results.length; t++)
      System.out.print(results[t]);
    System.out.print("\n");
   
    return results;   
  }
 
 
  //extract flat phrases: i1,i2,j1,j2; who are the positions in the french and english sentences
  private static ArrayList extractPhrases(Alignment align, int maxInitPhraseSize, int[] actualMaxInitPhraseSize){
    actualMaxInitPhraseSize[0]=0;
    ArrayList l_phrases = new ArrayList();
    int i1,i2,j1,j2;//french span: [i1,i2]; eng span: [j1,j2]
    for(i1=0; i1< align.french_wrds.length; i1++){
      if(align.num_alignments_infor_for_french[i1]<=0) continue;//skip unaligned wrd        
      j1 = align.english_wrds.length;
      j2 = -1;
      for(i2=i1; i2<Alignment.min(i1+maxInitPhraseSize, align.french_wrds.length); i2++){
        if(align.num_alignments_infor_for_french[i2]<=0) continue;//skip unaligned wrd          
       
        //j1 and j2: [j1,j2] is the "maximum" (thoug may be in-consistent) eng span for the french span  [i1,i2]
        j1 = Alignment.min(j1, align.min_pos_infor_for_french[i2]);
        j2 = Alignment.max(j2, align.max_pos_infor_for_french[i2]);
       
        if(j1>j2) continue;//empty english span
       
        if(j2-j1+1>maxInitPhraseSize) break; //go to next i1, since adding more wrds in french will increase the eng span
       
        int flag=0;
        for(int j=j1; j<=j2; j++){//for each english wrd in [j1,j2]
          if(align.min_pos_infor_for_english[j]<i1){//must extend i1 to solve inconsistence
            flag = 1; //next i1
            break;
          }
          if(align.max_pos_infor_for_english[j]>i2){//fix i1, but extending i2 may solve this inconsistence
            flag = 2; //next i2
            break;
          }
        }
        if(flag==1break; //next i1
        if(flag==2continue;//next i2
       
        //add the phrase
        l_phrases.add( new int[] {i1,i2,j1,j2, NON_TERMINAL_TAG_SYM_ID});
        actualMaxInitPhraseSize[0] = Alignment.max(actualMaxInitPhraseSize[0], i2-i1+1);
      }
    }
    g_num_init_phrases += l_phrases.size();
    return l_phrases;   
  }
 
  private static void loosenPhrases(){ 
    System.out.println("Error: un-implemented function");
    System.exit(0);
  }
 
  private static void removeOverlapPhrases(){
    System.out.println("Error: un-implemented function");
    System.exit(0);
  }
 
 
  //extract hiearchical rules
  private static List<Rule> extractRules(Alignment align, ArrayList l_init_phrases, int max_final_phrase_size, int min_sub_phrase_size, int max_num_non_terminals, int actual_max_init_phrase_size){
    if(l_init_phrases.size()==0){
      System.out.println("warning: no flat phrases are extracted");
      return null;
    }
    int n= align.english_wrds.length;
    ArrayList[][][] bins = new ArrayList[n+1][n+1][max_num_non_terminals+1];//bins[i][j] spans the french [i,j-1], each bin is a list of items
    ArrayList[] i2index = new  ArrayList[n];//each is a list of init-phrases ending with i2
    Hashtable i1s = new Hashtable();

    //get i2index and i1s   
    for(int t=0; t < l_init_phrases.size(); t++){
      int[] phrase = (int[])l_init_phrases.get(t);
      printPhrase(align,phrase);
      int t_i1 =phrase[0], t_i2 =phrase[1];
      if(i2index[t_i2]==nulli2index[t_i2] = new ArrayList();
      i2index[t_i2].add(phrase);
      if(i1s.containsKey(t_i1)==false){
        i1s.put(t_i1, 1);       
        //chart seeding
        bins[t_i1][t_i1][0] = new ArrayList();
        bins[t_i1][t_i1][0].add(new ArrayList());//add empty item
      }
    }   
    System.out.println("num of init phrases is " + l_init_phrases.size() + "; i1s len: " + i1s.size() + "; french len " + align.french_wrds.length + " maxabslen; " + actual_max_init_phrase_size);
    //chart parsing: each item is an arraylist
    int loop1=0;
    int loop2=0;
    for(int k=1; k<=Alignment.min(n, actual_max_init_phrase_size); k++){
      loop1++;
      loop2=0;
      for(int i1=0; i1+k<=n; i1++){
        if(i1s.containsKey(i1)==false) continue;//because the phrases never start from this index; bug: this may skip the flat phrase
        loop2++;
        int i2=i1+k-1;             
        //extend the dot by a subphrase       
        int tem1=0, tem2=0;
        if(i2index[i2]!=null){
          for(int t=0; t<i2index[i2].size(); t++){//for all sub-phrases ending at i2
            int[] sub_phrase = (int[]) i2index[i2].get(t);
            if(sub_phrase[1]-sub_phrase[0]+1>=min_sub_phrase_size){
              for(int n_nts=0; n_nts<max_num_non_terminals; n_nts++){
                if(bins[i1][sub_phrase[0]][n_nts]!=null){//no ant-items
                  for(int t2=0; t2<bins[i1][sub_phrase[0]][n_nts].size(); t2++){//for all ant-items
                    ArrayList item = (ArrayList) bins[i1][sub_phrase[0]][n_nts].get(t2);
                    if(item.size()<max_final_phrase_size &&
                       !(forbid_adjacent_nonterminals && item.size()>0 && !(item.get(item.size()-1) instanceof Integer))
                    ){
                      ArrayList new_item = new ArrayList(item);
                      new_item.add(sub_phrase);
                      if(bins[i1][i2+1][n_nts+1]==null) bins[i1][i2+1][n_nts+1] = new ArrayList();                 
                      bins[i1][i2+1][n_nts+1].add(new_item);
                      tem1++;
                    }
                  }
                }
              }
            }
          }
        }
        //extend the dot by a wrd 
        for(int n_nts=0; n_nts<=max_num_non_terminals; n_nts++){
          if(bins[i1][i2][n_nts]!=null){
            for(int t2=0; t2<bins[i1][i2][n_nts].size(); t2++){
              ArrayList item = (ArrayList) bins[i1][i2][n_nts].get(t2);
              if(item.size()<max_final_phrase_size ){
                ArrayList new_item = new ArrayList(item);
                new_item.add(i2);
                if(bins[i1][i2+1][n_nts]==null) bins[i1][i2+1][n_nts] = new ArrayList();                 
                bins[i1][i2+1][n_nts].add(new_item);
                tem2++;
              }
            }
          }
        }
        //tem3++;
        //System.out.println(loop1 + " "+ loop2 + " number of subphrases textend is " + tem1 + " and " + tem2);
      }
    }
   
    //extract rules from the chart
    ArrayList l_phrases_and_rules = new ArrayList();
    for(int t=0; t <l_init_phrases.size(); t++){
      ArrayList local_results = new ArrayList();
      int[] phrase = (int[])l_init_phrases.get(t);
      for(int n_nts=0; n_nts<=max_num_non_terminals; n_nts++){
        if(bins[phrase[0]][phrase[1]+1][n_nts]!=null){
          for(int t2=0; t2<bins[phrase[0]][phrase[1]+1][n_nts].size(); t2++){
            ArrayList item = (ArrayList) bins[phrase[0]][phrase[1]+1][n_nts].get(t2);
            Rule rule = makeAndScoreRule(align, phrase, item);
            if(rule!=null){
              local_results.add(rule);
            }
          }
        }
      }     
      //distribute the count, normalization
      for(int k=0; k < local_results.size(); k++){
        Rule rl = (Rule)local_results.get(k);
        for(int f=0; f<rl.feat_scores.length; f++)
          rl.feat_scores[f] /= local_results.size();
        rl.print_info(symbolTable);
      }
      l_phrases_and_rules.addAll(local_results);
      //System.out.println("local size " + local_results.size() + " all size " + l_phrases_and_rules.size());
    }
    System.out.println("num of rules and phrases is "  + l_phrases_and_rules.size());
    g_num_rules_and_phrases += l_phrases_and_rules.size();
    return   l_phrases_and_rules;
  }
 
  private static Rule makeAndScoreRule(Alignment align, int[] phrase, ArrayList item){
    //ignore ??
    if(item.size()==1 && !(item.get(0) instanceof Integer)) 
      return null; //bug: this may skip the flat phrase
   
    int nt_index=1;
    boolean have_alignment=false;
   
    int original_en_len = phrase[3]-phrase[2]+1;
    //System.out.println("original en len " + original_en_len);
    //System.out.println("item is " + item.toString());
    int[] original_en_wrds = new int[original_en_len] ;
    for(int t=0; t<original_en_len; t++)
      original_en_wrds[t]= align.english_wrds[t+phrase[2]];
 
    //get french words in the rule
    int[] rule_fwords = new int[item.size()];
    int[] fpos = new int[item.size()];//remember the original position in the sentence
    for(int t=0; t<item.size(); t++){
      //System.out.println("size is " + item.size());
      if(item.get(t) instanceof Integer){//terminal
        //System.out.println("terminal");
        fpos[t]=(Integer)item.get(t);
        if(align.num_alignments_infor_for_french[ fpos[t] ]>0){
          have_alignment=true;
        }
        rule_fwords[t]= align.french_wrds[ fpos[t] ];
      }else{//non-terminal
        //System.out.println("non-terminal");
        fpos[t]=INVALID_POS;
        int[] sub_phrase = (int[])item.get(t);
        original_en_len -= sub_phrase[3] - sub_phrase[2];//reserved one slot for the NT symbol
        //System.out.println("e span: " + sub_phrase[0] +" - " + sub_phrase[1] +" ; len " + original_en_len);
        int nt = symbolTable.addNonterminal(NON_TERMINAL_TAG_SYM+","+nt_index);//get [PHRASE,nt_index]
        rule_fwords[t] = nt;
        original_en_wrds[sub_phrase[2]-phrase[2]]=nt;
        for(int k=sub_phrase[2]-phrase[2]+1; k<=sub_phrase[3]-phrase[2]; k++)
          original_en_wrds[k]= INVALID_WRD_ID;       
        nt_index++;
      }
    }
   
    if(require_aligned_terminal && have_alignment==false)
      return null;
   
    //get English words in the rule 
    int[] rule_ewords = new int[original_en_len];//original_en_len is changed to actual en len
    int[] epos = new int[original_en_len];
    for(int t=0, k=0; t<original_en_wrds.length; t++){
      if(original_en_wrds[t]!=INVALID_WRD_ID){
        rule_ewords[k] = original_en_wrds[t];
        if(symbolTable.isNonterminal(original_en_wrds[t])==true)
          epos[k]=INVALID_POS;
        else
          epos[k]=phrase[2]+t;
        k++;
      }
    }
   
    //create rule
    Rule rl = new Rule(phrase[4], rule_fwords , rule_ewords);
   
    //add alignment infor
    if(keep_alignment_infor==true){
      ArrayList align_info =new ArrayList();
      for(int i=0; i<fpos.length; i++)
        if(fpos[i]!=INVALID_POS)
          for(int j=0; j<epos.length; j++)
            if(epos[j]!=INVALID_POS)
              if(align.alignment_matrix[fpos[i]][epos[j]]==1)
                align_info.add(i+"-"+j);//add "i-j"
      rl.alignments=align_info;
    }   
   
    scoreRule(align, rl, fpos, epos);
    return rl;
  }
 
  //compute and add feat scores
  private static Rule scoreRule(Alignment align, Rule r_in, int[] fpos, int[] epos){
    int funaligned=0, eunaligned=0;
    float fweight=1, eweight=1, fratio=0;
   
   
    //P_lex(eng|fr)
    for(int t=0; t<r_in.french.length; t++){
      if( symbolTable.isNonterminal(r_in.french[t])==false ){
        if(align.num_alignments_infor_for_french[fpos[t]]<=0)
          funaligned++;
        if(fweights!=null)
          fweight *= fweights[fpos[t]];
        if(fratios!=null)
          fratio += fratios[fpos[t]];
      }
    }
   
    //P_lex(fr|eng)
    for(int t=0; t<r_in.english.length; t++){
      if( symbolTable.isNonterminal(r_in.english[t])==false ){
        if(align.num_alignments_infor_for_english[epos[t]]<=0)
          eunaligned++;
        if(eweights!=null)
          eweight *= eweights[epos[t]];     
      }
    }
   
    //add the feat scores
    int num_feats=1;
    if(fweights!=null) num_feats++;
    if(eweights!=null) num_feats++;
    if(fratios!=null) num_feats++;
    float[] scores = new float[num_feats];
    int t_id=0;
    scores[t_id++]=1.0f;
    if(fweights!=null) scores[t_id++]=fweight;
    if(eweights!=null) scores[t_id++]=eweight;;
    if(fratios!=null) scores[t_id++]=fratio;;
    r_in.feat_scores=scores;
    return r_in;
  }
 
  private static void printPhrase(Alignment align, int[] phrase){
    String str="zh: " + phrase[0] + "-" + phrase[1];
    for(int t=phrase[0]; t<=phrase[1]; t++)
      str += " " + symbolTable.getWord(align.french_wrds[t]);
    str += " en: " + phrase[2] + "-" + phrase[3];
    for(int t=phrase[2]; t<=phrase[3]; t++)
      str += " " + symbolTable.getWord(align.english_wrds[t]);
    str += " nt: " + symbolTable.getWord(phrase[4]);
    System.out.println("phrase is = " +str);
  }
 
 
  public static void main(String[] args) {
   
    SymbolTable symbolTable = new BuildinSymbol();
    //init symbol
   
    NULL_ALIGN_WRD_SYM_ID = symbolTable.addTerminal(NULL_ALIGN_WRD_SYM);
    NON_TERMINAL_TAG_SYM_ID = symbolTable.addNonterminal(NON_TERMINAL_TAG_SYM);
   
    //read weights files
    eweights_table = readWeightFile("C:\\data_disk\\java_work_space\\SyntaxMT\\phraseExtraction\\lex.f2e.gz");
    fweights_table = readWeightFile("C:\\data_disk\\java_work_space\\SyntaxMT\\phraseExtraction\\lex.e2f.gz");
   
    HashtableBasedHieroGrammarScorer grammar;
    if(fweights_table==null)
      grammar= new HashtableBasedHieroGrammarScorer(2);
    else
      grammar= new HashtableBasedHieroGrammarScorer(4);
   
    BufferedReader t_reader_tree = FileUtilityOld.getReadFileStream("C:\\data_disk\\java_work_space\\SyntaxMT\\phraseExtraction\\parse.sync.berkeley1","UTF8");
    //BufferedReader t_reader_tree = FileUtility.getReadFileStream(args[1].trim(),"UTF8");
   
    BufferedReader t_reader_align = FileUtilityOld.getReadFileStream("C:\\data_disk\\java_work_space\\SyntaxMT\\phraseExtraction\\aligned.ibm","UTF8");
    //BufferedReader t_reader_tree = FileUtility.getReadFileStream(args[1].trim(),"UTF8");
   
    BufferedReader t_reader_zh = FileUtilityOld.getReadFileStream("C:\\data_disk\\java_work_space\\SyntaxMT\\phraseExtraction\\aligned.zh","UTF8");
    //BufferedReader t_reader_tree = FileUtility.getReadFileStream(args[1].trim(),"UTF8");
   
    BufferedReader t_reader_en = FileUtilityOld.getReadFileStream("C:\\data_disk\\java_work_space\\SyntaxMT\\phraseExtraction\\aligned.en.tmp1","UTF8");
    //BufferedReader t_reader_tree = FileUtility.getReadFileStream(args[1].trim(),"UTF8");
    String alignLine, frLine, enLine;
    int nLine=1;
    while((alignLine=FileUtilityOld.readLineLzf(t_reader_align))!=null){
      frLine = FileUtilityOld.readLineLzf(t_reader_zh);
      enLine = FileUtilityOld.readLineLzf(t_reader_en);
      //if(n_line++==1)continue;
      List<Rule> rulesAndPhrases = processASentence( alignLine, frLine, enLine);
      if(rulesAndPhrases!=null){//write into the grammar
        for(int t=0; t< rulesAndPhrases.size(); t++)
          grammar.addRawRule(rulesAndPhrases.get(t));
      }
      if(nLine>=50)
        break;
      nLine++;
    }
    System.out.println("total lines: " + nLine + "; total ini phrases  " + g_num_init_phrases + "; total rules and phrases  " + g_num_rules_and_phrases );
    grammar.score_grammar();
  }
 
 
}


// Listing retrieved from www.massapi.com (related classes of
// joshua.discriminative.syntax_reorder.HieroExtractor). All source code is the
// property of its respective owners; Java is a trademark of Sun Microsystems,
// Inc., owned by Oracle Inc. (site footer converted to a comment so the file compiles)