// Package com.clearnlp.component.online
//
// Source Code of com.clearnlp.component.online.OnlinePOSTagger

/**
* Copyright (c) 2009/09-2012/08, Regents of the University of Colorado
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
*    list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
*    this list of conditions and the following disclaimer in the documentation
*    and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/**
* Copyright 2012/09-2013/04, 2013/11-Present, University of Massachusetts Amherst
* Copyright 2013/05-2013/10, IPSoft Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.clearnlp.component.online;

import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;

import com.clearnlp.classification.feature.FtrToken;
import com.clearnlp.classification.feature.JointFtrXml;
import com.clearnlp.classification.instance.StringInstance;
import com.clearnlp.classification.model.StringModelAD;
import com.clearnlp.classification.prediction.StringPrediction;
import com.clearnlp.classification.vector.StringFeatureVector;
import com.clearnlp.component.evaluation.POSEval;
import com.clearnlp.component.state.TagState;
import com.clearnlp.dependency.DEPLib;
import com.clearnlp.dependency.DEPNode;
import com.clearnlp.dependency.DEPTree;
import com.clearnlp.nlp.NLPProcess;
import com.clearnlp.pattern.PTPunct;
import com.clearnlp.reader.AbstractColumnReader;
import com.clearnlp.util.UTArray;
import com.clearnlp.util.UTString;
import com.clearnlp.util.map.Prob2DMap;
import com.clearnlp.util.pair.Pair;
import com.clearnlp.util.pair.StringDoublePair;
import com.google.common.collect.Sets;

/**
* @since 2.0.1
* @author Jinho D. Choi ({@code jdchoi77@gmail.com})
*/
public class OnlinePOSTagger extends AbstractOnlineStatisticalComponent<TagState>
{
  protected final int LEXICA_LOWER_SIMPLIFIED_FORMS = 0;
  protected final int LEXICA_AMBIGUITY_CLASSE_PROB  = 1;
  protected final int LEXICA_AMBIGUITY_CLASSE_MAP   = 2;
 
  protected Set<String>        s_lsfs;  // lower simplified forms
  protected Prob2DMap          p_ambi;  // ambiguity classes (for collection)
  protected Map<String,String> m_ambi;  // ambiguity classes
 
  private StringModelAD s_model;
  private JointFtrXml   f_xml;
 
//  ====================================== CONSTRUCTORS ======================================

  /** Constructs a part-of-speech tagger for collecting lexica. */
  public OnlinePOSTagger(JointFtrXml[] xmls, Set<String> sLsfs)
  {
    super(xmls);
    f_xml  = f_xmls[0];
    s_lsfs = sLsfs;
    p_ambi = new Prob2DMap();
  }
 
  /** Constructs a part-of-speech tagger for training, bootstrapping, and decoding. */
  public OnlinePOSTagger(JointFtrXml[] xmls, Object[] lexica)
  {
    super(xmls, lexica, 1);
    init();
  }
 
  /** Constructs a part-of-speech tagger from an existing object. */
  public OnlinePOSTagger(ObjectInputStream in)
  {
    super(in);
    init();
  }
 
  private void init()
  {
    s_model = s_models[0];
    f_xml   = f_xmls[0];
  }
 
//  ====================================== LEXICA ======================================

  @Override
  public Object[] getLexica()
  {
    Object[] lexica = new Object[3];
   
    lexica[LEXICA_LOWER_SIMPLIFIED_FORMS] = s_lsfs;
    lexica[LEXICA_AMBIGUITY_CLASSE_PROB] = p_ambi;
    lexica[LEXICA_AMBIGUITY_CLASSE_MAP] = (m_ambi == null) ? getAmbiguityClasses() : m_ambi;
   
    return lexica;
  }
 
  @Override @SuppressWarnings("unchecked")
  public void setLexia(Object[] lexica)
  {
    s_lsfs = (Set<String>)lexica[LEXICA_LOWER_SIMPLIFIED_FORMS];
    p_ambi = (Prob2DMap)lexica[LEXICA_AMBIGUITY_CLASSE_PROB];
    m_ambi = (Map<String,String>)lexica[LEXICA_AMBIGUITY_CLASSE_MAP];
  }
 
  /** Called by {@link #getLexica()}. */
  private Map<String,String> getAmbiguityClasses()
  {
    Map<String,String> mAmbi = new HashMap<String,String>();
    double threshold = f_xml.getAmbiguityClassThreshold();
    StringDoublePair[] ps;
    StringBuilder build;
   
    for (String key : p_ambi.keySet())
    {
      build = new StringBuilder();
      ps = p_ambi.getProb1D(key);
      UTArray.sortReverseOrder(ps);
     
      for (StringDoublePair p : ps)
      {
        if (p.d <= thresholdbreak;
       
        build.append(AbstractColumnReader.BLANK_COLUMN);
        build.append(p.s);
      }
     
      if (build.length() > 0)
        mAmbi.put(key, build.substring(1));       
    }
   
    return mAmbi;
  }
 
//  ====================================== LOAD/SAVE MODELS ======================================
 
  @Override
  public void load(ObjectInputStream in) throws Exception
  {
    loadDefault(in);
    loadLexica (in);
    in.close();
  }
 
  @Override
  public void save(ObjectOutputStream out) throws Exception
  {
    saveDefault(out);
    saveLexica (out);
    out.close();
  }
 
  @SuppressWarnings("unchecked")
  protected void loadLexica(ObjectInputStream in) throws Exception
  {
    m_ambi = (Map<String,String>)in.readObject();
  }
 
  protected void saveLexica(ObjectOutputStream out) throws Exception
  {
    out.writeObject(m_ambi);
  }
 
//  ====================================== GETTERS ======================================

  @Override
  public Set<String> getLabels()
  {
    return Sets.newHashSet(s_model.getLabels());
  }
 
//  ====================================== PROCESS ======================================
 
  public void process(DEPTree tree, byte flag)
  {
    TagState state = initialize(tree, flag);
    List<StringInstance> insts = processAux(state, flag);
    finalize(state, insts, flag);
  }
 
  private List<StringInstance> processAux(TagState state, byte flag)
  {
    List<StringInstance> insts = getEmptyInstanceList(flag);
    String label = null;
   
    while (!state.isTerminate())
    {
      switch (flag)
      {
      case FLAG_COLLECT  : processCollect(state);            break;
      case FLAG_TRAIN    : label = processTrain(state, insts);    break;
      case FLAG_BOOTSTRAP: label = processBootstrap(state, insts)break;
      default            : label = processDecode(state);
      }
     
      setLabel(state.getInput(), label);
      state.moveForward();
    }
   
    return insts;
  }
 
  /** Called by {@link #process(DEPTree)}. */
  private TagState initialize(DEPTree tree, byte flag)
  {
    TagState state = new TagState(tree);
    simplifyForms(tree, flag);
   
    if (flag != FLAG_DECODE)
    {
      state.setGoldLabels(tree.getPOSTags());
     
      if (flag != FLAG_COLLECT)
        tree.clearPOSTags();
    }
   
    return state;
  }
 
  private void simplifyForms(DEPTree tree, byte flag)
  {
    NLPProcess.simplifyForms(tree);
  }
 
  private void finalize(TagState state, List<StringInstance> insts, byte flag)
  {
    if (isTrainOrBootstrap(flag))
    {
      s_model.addInstances(insts);
    }
    else if (isEvaluate(flag))
    {
      if (e_eval == null) e_eval = new POSEval();
      Object[] labels = state.getGoldLabels();
      DEPTree tree = state.getTree();
     
      e_eval.countAccuracy(tree, labels);
      tree.setPOSTags((String[])labels);
    }
  }
 
  /** Called by {@link #process(DEPTree)}. */
  private void processCollect(TagState state)
  {
    DEPNode input = state.getInput();
   
    if (s_lsfs.contains(input.lowerSimplifiedForm))
      p_ambi.add(input.simplifiedForm, input.pos);
  }
 
  /** Called by {@link #process(DEPTree)}. */
  private String processTrain(TagState state, List<StringInstance> insts)
  {
    StringFeatureVector vector = getFeatureVector(f_xml, state);
    String label = getGoldLabel(state);
    addInstance(state, insts, label, vector);
    return label;
  }
 
  /** Called by {@link #process(DEPTree)}. */
  private String processBootstrap(TagState state, List<StringInstance> insts)
  {
    StringFeatureVector vector = getFeatureVector(f_xml, state);
    String label = getAutoLabel(state, vector);
    addInstance(state, insts, getGoldLabel(state), vector);
    return label;
  }
 
  /** Called by {@link #process(DEPTree)}. */
  private String processDecode(TagState state)
  {
    StringFeatureVector vector = getFeatureVector(f_xml, state);
    String label = getAutoLabel(state, vector);
    return label;
  }
 
  private String getGoldLabel(TagState state)
  {
    return state.getGoldLabel();
  }
 
  /** Called by {@link #processBootstrap(TagState, List)} and {@link #processDecode(TagState)}. */
  private String getAutoLabel(TagState state, StringFeatureVector vector)
  {
    Pair<StringPrediction,StringPrediction> ps = s_model.predictTop2(vector);
    StringPrediction fst = ps.o1;
    StringPrediction snd = ps.o2;
   
    if (fst.score - snd.score < 1)
      state.getInput().addFeat(DEPLib.FEAT_POS2, snd.label);
   
    return fst.label;
  }
 
  private void addInstance(TagState state, List<StringInstance> insts, String goldLabel, StringFeatureVector vector)
  {
    if (!vector.isEmpty())
    {
      StringInstance instance = new StringInstance(goldLabel, vector);
      insts.add(instance);
    }
  }
 
//  /** Called by {@link #processBootstrap(TagState, List)} and {@link #processDecode(TagState)}. */
//  private StringPrediction setAutoLabel(StringFeatureVector vector, TagState state)
//  {
//    Pair<StringPrediction,StringPrediction> ps = s_model.predictTop2(vector);
//    DEPNode input = state.getInput();
//    StringPrediction fst = ps.o1;
//    StringPrediction snd = ps.o2;
//   
//    if (fst.isLabel(LABEL_DECAP) && input.simplifiedForm.equals(input.lowerSimplifiedForm))
//      return snd;
//   
//    return fst;
//   
//   
//    if (ps.o1.score - ps.o2.score >= 1) ps.o2 = null;
//   
//    DEPNode input = state.getInput();
//    setLabel(input, ps.o1.label);
//   
//    if (ps.o2 != null)
//      input.addFeat(DEPLib.FEAT_POS2, ps.o2.label);
   
//    DEPNode prev1 = state.getNode(input.id-1);
//    DEPNode prev2 = state.getNode(input.id-2);
//    boolean b = false;
//    String s;
//   
//    if (prev1 != null)
//    {
//      if ((input.isForm("Corp.") || input.isForm("Corp") || input.isForm("Inc.") || input.isForm("Inc")) && prev1.isPos("NNPS"))
//      {
//        prev1.pos = "NNP";
//        b = true;
//      }
//      else if (input.isPos("NNP") && prev1.isPos("NNPS"))
//      {
//        prev1.pos = "NNP";
//        b = true;
//      }
//      else if (input.isPos("NNPS") && prev1.isPos("NNP"))
//      {
//        ps.o1.label = input.pos = "NNP";
//        b = true;
//      }
//    }
//   
//    if (!b && prev2 != null)
//    {
//      if (input.isPos("NNPS") && prev1.isPos("CC") && prev2.isPos("NNP"))
//      {
//        ps.o1.label = input.pos = "NNP";
//        b = true;
//      }
//      else if (input.isPos("NNP") && prev1.isPos("CC") && prev2.isPos("NNPS"))
//      {
//        prev2.pos = "NNP";
//        b = true;
//      }
//    }
//  }
 
  private void setLabel(DEPNode input, String label)
  {
    input.setPOSTag(label);
  }
 
//  ====================================== FEATURE EXTRACTION ======================================

  @Override
  protected String getField(FtrToken token, TagState state)
  {
    DEPNode node = state.getNode(token);
    if (node == null) return null;
   
    switch (token.field)
    {
    case JointFtrXml.F_SIMPLIFIED_FORM:
      return containsLowerSimplifiedForm(node) ? node.simplifiedForm : null;
    case JointFtrXml.F_LOWER_SIMPLIFIED_FORM:
      return containsLowerSimplifiedForm(node) ? node.lowerSimplifiedForm : null;
    case JointFtrXml.F_POS:
      return node.pos;
    case JointFtrXml.F_POS2:
      return node.getFeat(DEPLib.FEAT_POS2);
    case JointFtrXml.F_AMBIGUITY_CLASS:
      return m_ambi.get(node.simplifiedForm);
    }
   
    Matcher m;
   
    if ((m = JointFtrXml.P_BOOLEAN.matcher(token.field)).find())
    {
      int field = Integer.parseInt(m.group(1));
      String value = token.field+token.offset;
     
      switch (field)
      {
      case  0: return UTString.isAllUpperCase(node.simplifiedForm) ? value : null;
      case  1: return UTString.isAllLowerCase(node.simplifiedForm) ? value : null;
      case  2: return UTString.beginsWithUpperCase(node.simplifiedForm) & !state.isInputFirstNode() ? value : null;
      case  3: return UTString.getNumOfCapitalsNotAtBeginning(node.simplifiedForm) == 1 ? value : null;
      case  4: return UTString.getNumOfCapitalsNotAtBeginning(node.simplifiedForm> 1 ? value : null;
      case  5: return node.simplifiedForm.contains(".") ? value : null;
      case  6: return UTString.containsDigit(node.simplifiedForm) ? value : null;
      case  7: return node.simplifiedForm.contains("-") ? value : null;
      case  8: return state.isInputLastNode() ? value : null;
      case  9: return state.isInputFirstNode() ? value : null;
      case 10: return PTPunct.containsOnlyPunctuation(node.lowerSimplifiedForm) ? value : null;
      default: throw new IllegalArgumentException("Unsupported feature: "+token.field);
      }
    }
    else if ((m = JointFtrXml.P_FEAT.matcher(token.field)).find())
      return node.getFeat(m.group(1));
    else if ((m = JointFtrXml.P_PREFIX.matcher(token.field)).find())
    {
      int n = Integer.parseInt(m.group(1)), len = node.lowerSimplifiedForm.length();
      return (n <= len) ? node.lowerSimplifiedForm.substring(0, n) : null;
    }
    else if ((m = JointFtrXml.P_SUFFIX.matcher(token.field)).find())
    {
      int n = Integer.parseInt(m.group(1)), len = node.lowerSimplifiedForm.length();
      return (n <= len) ? node.lowerSimplifiedForm.substring(len-n, len) : null;
    }
    else
      throw new IllegalArgumentException("Unsupported feature: "+token.field);
  }
 
  @Override
  protected String[] getFields(FtrToken token, TagState state)
  {
    DEPNode node = state.getNode(token);
    if (node == null) return null;
    String[] fields = null;
    Matcher m;
   
    if ((m = JointFtrXml.P_PREFIX.matcher(token.field)).find())
    {
      fields = UTString.getPrefixes(node.lowerSimplifiedForm, Integer.parseInt(m.group(1)));
    }
    else if ((m = JointFtrXml.P_SUFFIX.matcher(token.field)).find())
    {
      fields = UTString.getSuffixes(node.lowerSimplifiedForm, Integer.parseInt(m.group(1)));
    }
   
    return (fields == null) || (fields.length == 0) ? null : fields;
  }
 
  private boolean containsLowerSimplifiedForm(DEPNode node)
  {
    return s_lsfs == null || s_lsfs.contains(node.lowerSimplifiedForm);
  }
   
//  private boolean isMeta(String lowerSimplifiedForm)
//  {
//    return lowerSimplifiedForm.equals(MPLib.META_URL) ||
//         PTPunct.containsOnlyPunctuation(lowerSimplifiedForm) ||
//         PTNumber.containsOnlyDigits(lowerSimplifiedForm);
//  }
}
// TOP
//
// Related Classes of com.clearnlp.component.online.OnlinePOSTagger
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.