Package com.wcohen.ss.abbvGapsHmm

Source Code of com.wcohen.ss.abbvGapsHmm.AlignmentPredictionModel

package com.wcohen.ss.abbvGapsHmm;


import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;


/**
* @author Dana Movshovitz-Attias
*/
public class AlignmentPredictionModel {
 
  public static final String SEPARATOR = "#_#";
 
  public static String _trainingDataDir;
  public static String _trueLabelsFile;
  public static String _trainingCorpusFile;

  private AbbvGapsHMM _abbvHmm = null;

  public AlignmentPredictionModel() throws IOException{
    _abbvHmm = new AbbvGapsHMM();
    setTrainingDataDir("train/");
  }
 
  public void setTrainingDataDir(String trainDir) {
    _trainingDataDir = trainDir;
    _trueLabelsFile = _trainingDataDir+"abbvAlign_pairs.txt";
    _trainingCorpusFile = _trainingDataDir+"abbvAlign_corpus.txt";
  }
 
  public void setTfIdfData(String dataFile) throws IOException{
    _abbvHmm.setTfIdfData(dataFile);
  }
 
  public void setModelParamsFile(String paramFilename){
    _abbvHmm.setParamFile(paramFilename);
  }
  public void setModelParamsFile(){
    setModelParamsFile("hmmModelParams.txt");
  }
 
  public static ArrayList<Map<String, String>> loadLabels(String labelsFile) {
    if(labelsFile == null)
      return null;
   
    ArrayList<Map<String, String>> labels = null;
   
    try{
      BufferedReader fi = new BufferedReader(new FileReader(labelsFile));
     
      labels = new ArrayList<Map<String, String>>();

      String docLine;
      while( (docLine = fi.readLine()) != null){

        Map<String, String> docAcronymMap = new HashMap<String, String>();

        String acronyms[] = docLine.split(SEPARATOR);

        for (int i = 0; i < acronyms.length; i++) {
          String singleAcronym = acronyms[i];

          if(singleAcronym.isEmpty())
            continue;

          String parts[] = singleAcronym.split("\t");
          if(parts.length != 2){
            System.out.println("BAD FORMAT in "+labelsFile+": "+singleAcronym);
          }
          else{
            docAcronymMap.put(parts[0].trim(), parts[1].trim());
          }
        }

        labels.add(docAcronymMap);
      }
      fi.close();
    }
    catch(Exception e){
      e.printStackTrace();
      System.exit(1);
    }

    return labels;
  }
 
  public static List<String> loadTrainingCorpus(String corpusFile){
   
    List<String> trainingCorpus = null;
    try{
      trainingCorpus = new ArrayList<String>();

      BufferedReader fi = new BufferedReader(new FileReader(corpusFile));
      String line;
      while( (line = fi.readLine()) != null){
        trainingCorpus.add(line);
      }
      fi.close();
    }catch (Exception e) {
      e.printStackTrace();
      System.exit(1);
    }
   
    return trainingCorpus;
  }
 
  // Trains on full corpus
  public boolean trainOnAll(){
   
    List<Map<String, String>> trueLabels = loadLabels(_trueLabelsFile);
    List<String> corpus = loadTrainingCorpus(_trainingCorpusFile);
   
    List<List<Acronym>> trainingExtractedCandidates = new ArrayList<List<Acronym>>();
    List<Map<String, String>> trueLabelsForTraining = new ArrayList<Map<String,String>>();

    for(Integer docID = 0; docID < corpus.size(); ++docID){
      // Adds to training examples, all the extracted pairs from the current document.
      trainingExtractedCandidates.add( extractCandidatePairs(corpus.get(docID)) );
      trueLabelsForTraining.add( trueLabels.get(docID) );
    }

    return _abbvHmm.train(trainingExtractedCandidates, trueLabelsForTraining, true);
  }
 
  // Trains on candidate pairs extracted from the corpus
  public boolean trainOnCandidates(){
   
    List<String> corpus = loadTrainingCorpus(_trainingCorpusFile);
   
    List<List<Acronym>> trainingExtractedCandidates = new ArrayList<List<Acronym>>();

    for(Integer docID = 0; docID < corpus.size(); ++docID){
      // Adds to training examples, all the extracted pairs from the current document.
      trainingExtractedCandidates.add( extractCandidatePairs(corpus.get(docID)) );
    }

    return _abbvHmm.train(trainingExtractedCandidates, null, true);
  }


  public boolean train(List<String> corpus, List<Integer> trainingSet, List<Map<String, String>> trueLabels) {
    List<List<Acronym>> trainingExtractedCandidates = new ArrayList<List<Acronym>>();
    List<Map<String, String>> trueLabelsForTraining = new ArrayList<Map<String,String>>();

    if(trainingSet != null){
      for (Integer docID : trainingSet) {
        // Adds to training examples, all the extracted pairs from the current document.
        trainingExtractedCandidates.add( extractCandidatePairs(corpus.get(docID)) );
        trueLabelsForTraining.add( trueLabels.get(docID) );
      }
    }
    else{
      // Iterates over all documents in the corpus
      for(int docID = 0; docID < corpus.size(); ++docID){
        // Adds to training examples, all the extracted pairs from the current document.
        trainingExtractedCandidates.add( extractCandidatePairs(corpus.get(docID)) );
        trueLabelsForTraining.add( trueLabels.get(docID) );
      }
    }

    return _abbvHmm.train(trainingExtractedCandidates, trueLabelsForTraining, true);
  }
 
  public AbbreviationAlignmentContainer<AbbvGapsHMM.Emissions, AbbvGapsHMM.States> predict(String sf, String lf) {
    return predictAlignment(new Acronym(sf, lf));
  }

  public AbbreviationAlignmentContainer<AbbvGapsHMM.Emissions, AbbvGapsHMM.States> predictAlignment(Acronym candidatePair) {
    return _abbvHmm.viterbi(candidatePair);
  }
 
  public Acronym predict(Acronym candidatePair) {

    AbbreviationAlignmentContainer<AbbvGapsHMM.Emissions, AbbvGapsHMM.States> alignment = predictAlignment(candidatePair);
    Acronym currAcronym = null;

    try {
      if(alignment == null){
        // No good alignment found
        return null;
      }
     
      currAcronym = alignment.getAcronym();
      if(currAcronym != null){
        AbbreviationAlignmentContainer<AbbvGapsHMM.Emissions, AbbvGapsHMM.States> acronymAlignment = predictAlignment(currAcronym);
        currAcronym._probability = acronymAlignment.getProbability();
        currAcronym._alignment = alignment;
      }
    } catch (Exception e) {
      e.printStackTrace();
    }
    return currAcronym;
  }
 
  public Map<String, Acronym> acronymsArrayToMap(Collection<Acronym> pairs){
    Map<String, Acronym> out = new HashMap<String, Acronym>();

    for (Acronym acronymPair : pairs) {
      if(out.containsKey(acronymPair._shortForm)){
        Acronym prevAcronym = out.get(acronymPair._shortForm);
        if(acronymPair._probability != null && prevAcronym._probability != null){
          if(acronymPair._probability.compareTo(prevAcronym._probability) > 0){
            out.put(acronymPair._shortForm, acronymPair);
          }
        }
      }
      else{
        out.put(acronymPair._shortForm, acronymPair);
      }
    }

    return out;
  }
 
  public Collection<Acronym> predict(String text) {
    List<Acronym> candidates = extractCandidatePairs(text);

    List<Acronym> predictions = new ArrayList<Acronym>();
    Acronym currPrediction;

    for (Acronym candidateAcronym : candidates) {
      currPrediction = predict(candidateAcronym);
      if(currPrediction != null){
        predictions.add(currPrediction);
      }
    }
   
    return predictions;
  }
 
  public boolean trainIfNeeded() {
    if(!_abbvHmm.loadModelParams()){
      return trainOnCandidates();
    }
    return true;
  }
 
  /**** Candidates Extraction ****/
  public List<Acronym> extractCandidatePairs(String text) {
    ArrayList<Acronym> extractedPairs = new ArrayList<Acronym>();
   
    extractedPairs.addAll(extractSingleAcronyms(text));
    extractedPairs.addAll(extractPatternAcronyms(text));
   
    return extractedPairs;
  }
 
  protected List<Acronym> extractPatternAcronyms(String text){
    ArrayList<Acronym> extractedPairs = new ArrayList<Acronym>();
   
    extractedPairs.addAll(extractHeadNounPattern_2Parts(text));
    extractedPairs.addAll(extractHeadNounPattern_3Parts(text));
    extractedPairs.addAll(extractTrailingNounPattern_2Parts(text));
    extractedPairs.addAll(extractTrailingNounPattern_3Parts(text));
   
    return extractedPairs;
  }
 
  protected void addCandidatePair(List<Acronym> allPairs, String longFormCandidate, String shortFormCandidate){
    Acronym pair = parseCandidate(longFormCandidate, shortFormCandidate);
    if(pair != null && !pair._shortForm.isEmpty()){
      allPairs.add(pair);
    }
  }
 
  protected List<Acronym> extractHeadNounPattern_3Parts(String text){
    ArrayList<Acronym> extractedPairs = new ArrayList<Acronym>();
   
    String nounExp = "([a-zA-Z0-9\\-]{1,20})";
    String shortFormExp = "\\(([^\\(]*?)\\)";
   
    Matcher matcher = Pattern.compile (nounExp+" "+nounExp+" "+shortFormExp+",? "+nounExp+" "+shortFormExp+",? and "+nounExp+" "+shortFormExp).matcher(text);
    int startPos = 0;
    while (startPos < text.length() && matcher.find(startPos))
    {
      String mainNoun = matcher.group(1);
     
      String part1 = matcher.group(2);
      String part1_short = matcher.group(3);
      String part2 = matcher.group(4);
      String part2_short = matcher.group(5);
      String part3 = matcher.group(6);
      String part3_short = matcher.group(7);
     
      startPos = matcher.regionEnd() + 1;
     
      addCandidatePair(extractedPairs, mainNoun+" "+part1, part1_short);
      addCandidatePair(extractedPairs, mainNoun+" "+part2, part2_short);
      addCandidatePair(extractedPairs, mainNoun+" "+part3, part3_short);
    }
   
    return extractedPairs;
  }
 
  protected List<Acronym> extractHeadNounPattern_2Parts(String text){
    ArrayList<Acronym> extractedPairs = new ArrayList<Acronym>();
   
    String nounExp = "([a-zA-Z0-9\\-]{1,20})";
    String shortFormExp = "\\(([^\\(]*?)\\)";
   
    Matcher matcher = Pattern.compile (nounExp+" "+nounExp+" "+shortFormExp+",? and "+nounExp+" "+shortFormExp).matcher(text);
    int startPos = 0;
    while (startPos < text.length() && matcher.find(startPos))
    {
      String mainNoun = matcher.group(1);
     
      String part1 = matcher.group(2);
      String part1_short = matcher.group(3);
      String part2 = matcher.group(4);
      String part2_short = matcher.group(5);
     
      startPos = matcher.regionEnd() + 1;
     
      addCandidatePair(extractedPairs, mainNoun+" "+part1, part1_short);
      addCandidatePair(extractedPairs, mainNoun+" "+part2, part2_short);
    }
   
    return extractedPairs;
  }
 
  protected List<Acronym> extractTrailingNounPattern_3Parts(String text){
    ArrayList<Acronym> extractedPairs = new ArrayList<Acronym>();
   
    String finalNounExp = "([a-zA-Z0-9\\-]{1,20})";
    String nounExp = "(.{1,20}?)";
    String shortFormExp = "\\(([^\\(]*?)\\)";
   
    Matcher matcher = Pattern.compile (nounExp+" "+shortFormExp+",? "+nounExp+" "+shortFormExp+",? and "+nounExp+" "+shortFormExp+" "+finalNounExp).matcher(text);
    int startPos = 0;
    while (startPos < text.length() && matcher.find(startPos))
    {
      String part1 = matcher.group(1);
      String part1_short = matcher.group(2);
      String part2 = matcher.group(3);
      String part2_short = matcher.group(4);
      String part3 = matcher.group(5);
      String part3_short = matcher.group(6);
      String mainNoun = matcher.group(7);
     
      startPos = matcher.regionEnd() + 1;
     
      addCandidatePair(extractedPairs, part1 + " " + mainNoun, part1_short);
      addCandidatePair(extractedPairs, part2 + " " + mainNoun, part2_short);
      addCandidatePair(extractedPairs, part3 + " " + mainNoun, part3_short);
    }
   
    return extractedPairs;
  }
 
  protected List<Acronym> extractTrailingNounPattern_2Parts(String text){
    ArrayList<Acronym> extractedPairs = new ArrayList<Acronym>();
   
    String finalNounExp = "([a-zA-Z0-9\\-]{1,20})";
    String nounExp = "(.{1,20}?)";
    String shortFormExp = "\\(([^\\(]*?)\\)";
   
    Matcher matcher = Pattern.compile (nounExp+" "+shortFormExp+",? and "+nounExp+" "+shortFormExp+" "+finalNounExp).matcher(text);
    int startPos = 0;
    while (startPos < text.length() && matcher.find(startPos))
    {
      String part1 = matcher.group(1);
      String part1_short = matcher.group(2);
      String part2 = matcher.group(3);
      String part2_short = matcher.group(4);
      String mainNoun = matcher.group(5);
     
      startPos = matcher.regionEnd() + 1;
     
      addCandidatePair(extractedPairs, part1 + " " + mainNoun, part1_short);
      addCandidatePair(extractedPairs, part2 + " " + mainNoun, part2_short);
    }
   
    return extractedPairs;
  }

  protected List<Acronym> extractSingleAcronyms(String text){
    ArrayList<Acronym> extractedPairs = new ArrayList<Acronym>();

    int iOpen = text.indexOf("(");
    int iClose = -1;
    String mOutOfPar = "";
    String mInPar = "";

    while(iOpen != -1){
      iClose = -1;
      int numPar = 0;
      for (int p = iOpen+1; p < text.length(); p++) {
        if(text.charAt(p) == '('){
          numPar++;
        }
        if(text.charAt(p) == ')'){
          if(numPar > 0)
            numPar--;
          else{
            iClose = p;
            break;
          }
        }
      }
      if(iClose != -1){
        mInPar = text.substring(iOpen+1, iClose);
        mOutOfPar = text.substring(0, iOpen).trim();
       
        addCandidatePair(extractedPairs, mOutOfPar, mInPar);
      }

      iOpen = text.indexOf("(", iOpen+1);
    }

    return extractedPairs;
  }

  protected Acronym parseCandidate(String outOfParenthesis, String inParenthesis) {

    if(inParenthesis.indexOf(";") != -1){
      int i = inParenthesis.indexOf(";");
      inParenthesis = inParenthesis.substring(0, i);
    }

    if(outOfParenthesis.indexOf(";") != -1){
      int i = outOfParenthesis.indexOf(";");
      outOfParenthesis = outOfParenthesis.substring(i+1);
    }

    // Default assumption: long form is outside the parenthesis
    String shortForm = inParenthesis.trim();
    String longForm = outOfParenthesis.trim();

    // Unless default was found not to be true
    if(!isShortForm(shortForm)){
      longForm = inParenthesis.trim();

      String parts[] = outOfParenthesis.trim().split(" ");
      shortForm = parts[parts.length-1];

    }
    // Is the short form valid?
    if(!isValidShortForm(shortForm)){
      return null;
    }

    if(!isValidExpression(shortForm) || ! isValidExpression(longForm)){
      return null;
    }

    // Chunk long form to correct size
    String parts[] = longForm.split(" ");
    int sfSize = shortForm.length();
    int maxLongFormLength = Math.min( sfSize+5, sfSize*2 );
    int finalLfSize = Math.min( maxLongFormLength, parts.length );


    String finalLongForm = "";
    for (int i = parts.length-1; i > parts.length-finalLfSize-1; i--) {
      finalLongForm = parts[i] + " " + finalLongForm;
    }
    finalLongForm = finalLongForm.trim();

    if(shortForm.equalsIgnoreCase(finalLongForm)){
      return null;
    }

    return new Acronym(shortForm, finalLongForm);
  }
 
  protected String chunkLongForm(String longForm, int size){
    int foundWords = 0;
    int i = longForm.length()-1;
    for (; i >= 0 && foundWords < size ; i--) {
      if(i == 0 || !Character.isLetterOrDigit(longForm.charAt(i-1))){
        foundWords++;
      }
    }
    return longForm.substring(i+1, longForm.length());
  }

  protected boolean isValidExpression(String exp){
    if(    exp == null
        ||
        exp.isEmpty()
    ){
      return false;
    }
    return true;
  }

  protected boolean isShortForm(String candidate) {
    String parts[] = candidate.split(" ");
    return parts.length <= 3;
  }

  protected boolean isValidShortForm(String candidate) {
    // length restriction
    if(candidate.length() > 15)
      return false;
    // length restriction
    if(candidate.length() < 1)
      return false;
    // first char is alpha-numeric
    if(!Pattern.matches("^[a-zA-Z0-9].*", candidate))
      return false;
    // at least one of these characters is a letter
    if(!Pattern.matches(".*[a-zA-Z].*", candidate))
      return false;

    return true;
  }
 
  public List<Double> getEmmisions(){
    return _abbvHmm.getEmmisionParams();
  }
  public List<Double> getTransitions(){
    return _abbvHmm.getTransitionParams();
  }
 
  public void setStartingParams(List<Double> emmisions, List<Double> transitions){
    _abbvHmm.setStartingParams(emmisions, transitions);
  }

}
TOP

Related Classes of com.wcohen.ss.abbvGapsHmm.AlignmentPredictionModel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.