Package edu.stanford.nlp.international.arabic.process

Source Code of edu.stanford.nlp.international.arabic.process.ArabicSegmenter

package edu.stanford.nlp.international.arabic.process;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Serializable;
import java.io.StringReader;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import edu.stanford.nlp.ie.crf.CRFClassifier;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.Sentence;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.process.TokenizerFactory;
import edu.stanford.nlp.process.WordSegmenter;
import edu.stanford.nlp.sequences.DocumentReaderAndWriter;
import edu.stanford.nlp.sequences.SeqClassifierFlags;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.PropertiesUtils;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;

/**
* Arabic word segmentation model based on conditional random fields (CRF).
* This is a re-implementation (with extensions) of the model described in
* (Green and DeNero, 2012).
* <p>
* This package includes a JFlex-based orthographic normalization package
* that runs on the input prior to processing by the CRF-based segmentation
* model. The normalization options are configurable, but must be consistent for
* both training and test data.
*
* @author Spence Green
*/
public class ArabicSegmenter implements WordSegmenter, Serializable, ThreadsafeProcessor<String,String> {

  private static final long serialVersionUID = -4791848633597417788L;

  // SEGMENTER OPTIONS (can be set in the Properties object
  // passed to the constructor).

  // The input already been tokenized. Do not run the Arabic tokenizer.
  private static final String optTokenized = "tokenized";

  // Tokenizer options
  private static final String optTokenizer = "orthoOptions";

  // Mark segmented prefixes with this String
  private static final String optPrefix = "prefixMarker";

  // Mark segmented suffixes with this String
  private static final String optSuffix = "suffixMarker";

  // Number of decoding threads
  private static final String optThreads = "nthreads";

  // Write TedEval files
  private static final String optTedEval = "tedEval";
 
  // Use a custom feature factory
  private static final String optFeatureFactory = "featureFactory";
  private static final String defaultFeatureFactory =
      "edu.stanford.nlp.international.arabic.process.StartAndEndArabicSegmenterFeatureFactory";
  private static final String localOnlyFeatureFactory =
      "edu.stanford.nlp.international.arabic.process.ArabicSegmenterFeatureFactory";

  // Training and evaluation files have domain labels
  private static final String optWithDomains = "withDomains";
 
  // Training and evaluation text are all in the same domain (default:atb)
  private static final String optDomain = "domain";
 
  // Ignore rewrites (training only, produces a model that then can be used to do
  // no-rewrite segmentation)
  private static final String optNoRewrites = "noRewrites";
 
  // Use the original feature set which doesn't contain start-and-end "wrapper" features
  private static final String optLocalFeaturesOnly = "localFeaturesOnly";

  private transient CRFClassifier<CoreLabel> classifier;
  private final SeqClassifierFlags flags;
  private final TokenizerFactory<CoreLabel> tf;
  private final String prefixMarker;
  private final String suffixMarker;
  private final boolean isTokenized;
  private final String tokenizerOptions;
  private final String tedEvalPrefix;
  private final boolean hasDomainLabels;
  private final String domain;
  private final boolean noRewrites;

  /**
   * Make an Arabic Segmenter.
   *
   *  @param props Options for how to tokenize. See the main method of {@see ArabicTokenizer} for details
   */
  public ArabicSegmenter(Properties props) {
    isTokenized = props.containsKey(optTokenized);
    tokenizerOptions = props.getProperty(optTokenizer, null);
    tedEvalPrefix = props.getProperty(optTedEval, null);
    hasDomainLabels = props.containsKey(optWithDomains);
    domain = props.getProperty(optDomain, "atb");
    noRewrites = props.containsKey(optNoRewrites);
    tf = getTokenizerFactory();

    prefixMarker = props.getProperty(optPrefix, "");
    suffixMarker = props.getProperty(optSuffix, "");

    if (props.containsKey(optLocalFeaturesOnly)) {
      if (props.containsKey(optFeatureFactory))
        throw new RuntimeException("Cannot use custom feature factory with localFeaturesOnly flag--" +
            "have your custom feature factory extend ArabicSegmenterFeatureFactory instead of " +
            "StartAndEndArabicSegmenterFeatureFactory and remove the localFeaturesOnly flag.");
     
      props.put(optFeatureFactory, localOnlyFeatureFactory);
    }
    if (!props.containsKey(optFeatureFactory))
      props.put(optFeatureFactory, defaultFeatureFactory);
   
    // Remove all command-line properties that are specific to ArabicSegmenter
    props.remove(optTokenizer);
    props.remove(optTokenized);
    props.remove(optPrefix);
    props.remove(optSuffix);
    props.remove(optThreads);
    props.remove(optTedEval);
    props.remove(optWithDomains);
    props.remove(optDomain);
    props.remove(optNoRewrites);
    props.remove(optLocalFeaturesOnly);

    flags = new SeqClassifierFlags(props);
    classifier = new CRFClassifier<CoreLabel>(flags);
  }

  /**
   * Copy constructor.
   *
   * @param other
   */
  public ArabicSegmenter(ArabicSegmenter other) {
    isTokenized = other.isTokenized;
    tokenizerOptions = other.tokenizerOptions;
    prefixMarker = other.prefixMarker;
    suffixMarker = other.suffixMarker;
    tedEvalPrefix = other.tedEvalPrefix;
    hasDomainLabels = other.hasDomainLabels;
    domain = other.domain;
    noRewrites = other.noRewrites;
    flags = other.flags;

    // ArabicTokenizerFactory is *not* threadsafe. Make a new copy.
    tf = getTokenizerFactory();

    // CRFClassifier is threadsafe, so return a reference.
    classifier = other.classifier;
  }

  /**
   * Creates an ArabicTokenizer. The default tokenizer
   * is ArabicTokenizer.atbFactory(), which produces the
   * same orthographic normalization as Green and Manning (2010).
   *
   * @return A TokenizerFactory that produces each Arabic token as a CoreLabel
   */
  private TokenizerFactory<CoreLabel> getTokenizerFactory() {
    TokenizerFactory<CoreLabel> tokFactory = null;
    if ( ! isTokenized) {
      if (tokenizerOptions == null) {
        tokFactory = ArabicTokenizer.atbFactory();
        String atbVocOptions = "removeProMarker,removeMorphMarker,removeLengthening";
        tokFactory.setOptions(atbVocOptions);
      } else {
        if (tokenizerOptions.contains("removeSegMarker")) {
          throw new RuntimeException("Option 'removeSegMarker' cannot be used with ArabicSegmenter");
        }
        tokFactory = ArabicTokenizer.factory();
        tokFactory.setOptions(tokenizerOptions);
      }
      System.err.println("Loaded ArabicTokenizer with options: " + tokenizerOptions);
    }
    return tokFactory;
  }

  @Override
  public void initializeTraining(double numTrees) {
    throw new UnsupportedOperationException("Training is not supported!");
  }

  @Override
  public void train(Collection<Tree> trees) {
    throw new UnsupportedOperationException("Training is not supported!");
  }

  @Override
  public void train(Tree tree) {
    throw new UnsupportedOperationException("Training is not supported!");
  }

  @Override
  public void train(List<TaggedWord> sentence) {
    throw new UnsupportedOperationException("Training is not supported!");
  }

  @Override
  public void finishTraining() {
    throw new UnsupportedOperationException("Training is not supported!");
  }

  @Override
  public String process(String nextInput) {
    return segmentString(nextInput);
  }

  @Override
  public ThreadsafeProcessor<String, String> newInstance() {
    return new ArabicSegmenter(this);
  }

  @Override
  public List<HasWord> segment(String line) {
    String segmentedString = segmentString(line);
    return Sentence.toWordList(segmentedString.split("\\s+"));
  }

  public String segmentString(String line) {
    List<CoreLabel> tokenList;
    if (tf == null) {
      // Whitespace tokenization.
      tokenList = IOBUtils.StringToIOB(line);
    } else {
      List<CoreLabel> tokens = tf.getTokenizer(new StringReader(line)).tokenize();
      tokenList = IOBUtils.StringToIOB(tokens, null, false);
    }
    IOBUtils.labelDomain(tokenList, domain);
    tokenList = classifier.classify(tokenList);
    String segmentedString = IOBUtils.IOBToString(tokenList, prefixMarker, suffixMarker);
    return segmentedString;
  }

  /**
   * Segment all strings from an input.
   *
   * @param br -- input stream to segment
   * @param pwOut -- output stream to write the segmenter text
   * @return number of input characters segmented
   */
  public long segment(BufferedReader br, PrintWriter pwOut) {
    long nSegmented = 0;
    try {
      for (String line; (line = br.readLine()) != null;) {
        nSegmented += line.length(); // Measure this quantity since it is quick to compute
        String segmentedLine = segmentString(line);
        pwOut.println(segmentedLine);
      }
    } catch (IOException e) {
      e.printStackTrace();
    }
    return nSegmented;
  }

  /**
   * Train a segmenter from raw text. Gold segmentation markers are required.
   */
  public void train() {
    boolean hasSegmentationMarkers = true;
    boolean hasTags = true;
    DocumentReaderAndWriter<CoreLabel> docReader = new ArabicDocumentReaderAndWriter(hasSegmentationMarkers,
                                                                                     hasTags,
                                                                                     hasDomainLabels,
                                                                                     domain,
                                                                                     noRewrites,
                                                                                     tf);
    ObjectBank<List<CoreLabel>> lines =
      classifier.makeObjectBankFromFile(flags.trainFile, docReader);

    classifier.train(lines, docReader);
    System.err.println("Finished training.");
  }

  /**
   * Evaluate accuracy when the input is gold segmented text *with* segmentation
   * markers and morphological analyses. In other words, the evaluation file has the
   * same format as the training data.
   *
   * @param pwOut
   */
  private void evaluate(PrintWriter pwOut) {
    System.err.println("Starting evaluation...");
    boolean hasSegmentationMarkers = true;
    boolean hasTags = true;
    DocumentReaderAndWriter<CoreLabel> docReader = new ArabicDocumentReaderAndWriter(hasSegmentationMarkers,
                                                                                     hasTags,
                                                                                     hasDomainLabels,
                                                                                     domain,
                                                                                     tf);
    ObjectBank<List<CoreLabel>> lines =
      classifier.makeObjectBankFromFile(flags.testFile, docReader);
   
    PrintWriter tedEvalGoldTree = null, tedEvalParseTree = null;
    PrintWriter tedEvalGoldSeg = null, tedEvalParseSeg = null;
    if (tedEvalPrefix != null) {
      try {
        tedEvalGoldTree = new PrintWriter(tedEvalPrefix + "_gold.ftree");
        tedEvalGoldSeg = new PrintWriter(tedEvalPrefix + "_gold.segmentation");
        tedEvalParseTree = new PrintWriter(tedEvalPrefix + "_parse.ftree");
        tedEvalParseSeg = new PrintWriter(tedEvalPrefix + "_parse.segmentation");
      } catch (FileNotFoundException e) {
        System.err.printf("%s: %s%n", ArabicSegmenter.class.getName(), e.getMessage());
      }
    }

    Counter<String> labelTotal = new ClassicCounter<String>();
    Counter<String> labelCorrect = new ClassicCounter<String>();
    int total = 0;
    int correct = 0;
    for (List<CoreLabel> line : lines) {
      final String[] inputTokens = tedEvalSanitize(IOBUtils.IOBToString(line).replaceAll(":", "#pm#")).split(" ");
      final String[] goldTokens = tedEvalSanitize(IOBUtils.IOBToString(line, ":")).split(" ");
      line = classifier.classify(line);
      final String[] parseTokens = tedEvalSanitize(IOBUtils.IOBToString(line, ":")).split(" ");
      for (CoreLabel label : line) {
        // Do not evaluate labeling of whitespace
        String observation = label.get(CoreAnnotations.CharAnnotation.class);
        if ( ! observation.equals(IOBUtils.getBoundaryCharacter())) {
          total++;
          String hypothesis = label.get(CoreAnnotations.AnswerAnnotation.class);
          String reference = label.get(CoreAnnotations.GoldAnswerAnnotation.class);
          labelTotal.incrementCount(reference);
          if (hypothesis.equals(reference)) {
            correct++;
            labelCorrect.incrementCount(reference);
          }
        }
      }
      if (tedEvalParseSeg != null) {
        tedEvalGoldTree.printf("(root");
        tedEvalParseTree.printf("(root");
        int safeLength = inputTokens.length;
        if (inputTokens.length != goldTokens.length) {
          System.err.println("In generating TEDEval files: Input and gold do not have the same number of tokens");
          System.err.println("    (ignoring any extras)");
          System.err.println("  input: " + Arrays.toString(inputTokens));
          System.err.println("  gold: " + Arrays.toString(goldTokens));
          safeLength = Math.min(inputTokens.length, goldTokens.length);
        }
        if (inputTokens.length != parseTokens.length) {
          System.err.println("In generating TEDEval files: Input and parse do not have the same number of tokens");
          System.err.println("    (ignoring any extras)");
          System.err.println("  input: " + Arrays.toString(inputTokens));
          System.err.println("  parse: " + Arrays.toString(parseTokens));
          safeLength = Math.min(inputTokens.length, parseTokens.length);
        }
        for (int i = 0; i < safeLength; i++) {
          for (String segment : goldTokens[i].split(":"))
            tedEvalGoldTree.printf(" (seg %s)", segment);
          tedEvalGoldSeg.printf("%s\t%s%n", inputTokens[i], goldTokens[i]);
          for (String segment : parseTokens[i].split(":"))
            tedEvalParseTree.printf(" (seg %s)", segment);
          tedEvalParseSeg.printf("%s\t%s%n", inputTokens[i], parseTokens[i]);
        }
        tedEvalGoldTree.printf(")%n");
        tedEvalGoldSeg.println();
        tedEvalParseTree.printf(")%n");
        tedEvalParseSeg.println();
      }
    }

    double accuracy = ((double) correct) / ((double) total);
    accuracy *= 100.0;

    pwOut.println("EVALUATION RESULTS");
    pwOut.printf("#datums:\t%d%n", total);
    pwOut.printf("#correct:\t%d%n", correct);
    pwOut.printf("accuracy:\t%.2f%n", accuracy);
    pwOut.println("==================");

    // Output the per label accuracies
    pwOut.println("PER LABEL ACCURACIES");
    for (String refLabel : labelTotal.keySet()) {
      double nTotal = labelTotal.getCount(refLabel);
      double nCorrect = labelCorrect.getCount(refLabel);
      double acc = (nCorrect / nTotal) * 100.0;
      pwOut.printf(" %s\t%.2f%n", refLabel, acc);
    }
   
    if (tedEvalParseSeg != null) {
      tedEvalGoldTree.close();
      tedEvalGoldSeg.close();
      tedEvalParseTree.close();
      tedEvalParseSeg.close();
    }
  }

  private String tedEvalSanitize(String str) {
    return str.replaceAll("\\(", "#lp#").replaceAll("\\)", "#rp#");
  }

  /**
   * Evaluate P/R/F1 when the input is raw text
   */
  private void evaluateRawText(PrintWriter pwOut) {
    // TODO(spenceg): Evaluate raw input w.r.t. a reference that might have different numbers
    // of characters per sentence. Need to implement a monotonic sequence alignment algorithm
    // to align the two character strings.
    //    String gold = flags.answerFile;
    //    String rawFile = flags.testFile;
    throw new RuntimeException("Not yet implemented!");
  }

  public void serializeSegmenter(String filename) {
    classifier.serializeClassifier(filename);
  }

  public void loadSegmenter(String filename, Properties p) {
    classifier = new CRFClassifier<CoreLabel>(p);
    try {
      classifier.loadClassifier(new File(filename), p);
    } catch (ClassCastException e) {
      e.printStackTrace();
    } catch (IOException e) {
      e.printStackTrace();
    } catch (ClassNotFoundException e) {
      e.printStackTrace();
    }
  }

  @Override
  public void loadSegmenter(String filename) {
    loadSegmenter(filename, new Properties());
  }


  private static String usage() {
    String nl = System.getProperty("line.separator");
    StringBuilder sb = new StringBuilder();
    sb.append("Usage: java ").append(ArabicSegmenter.class.getName()).append(" OPTS < file_to_segment").append(nl);
    sb.append(nl).append(" Options:").append(nl);
    sb.append("  -help                : Print this message.").append(nl);
    sb.append("  -orthoOptions str    : Comma-separated list of orthographic normalization options to pass to ArabicTokenizer.").append(nl);
    sb.append("  -tokenized           : Text is already tokenized. Do not run internal tokenizer.").append(nl);
    sb.append("  -trainFile file      : Gold segmented IOB training file.").append(nl);
    sb.append("  -testFile  file      : Gold segmented IOB evaluation file.").append(nl);
    sb.append("  -textFile  file      : Raw input file to be segmented.").append(nl);
    sb.append("  -loadClassifier file : Load serialized classifier from file.").append(nl);
    sb.append("  -prefixMarker char   : Mark segmented prefixes with specified character.").append(nl);
    sb.append("  -suffixMarker char   : Mark segmented suffixes with specified character.").append(nl);
    sb.append("  -nthreads num        : Number of threads  (default: 1)").append(nl);
    sb.append("  -tedEval prefix      : Output TedEval-compliant gold and parse files.").append(nl);
    sb.append("  -featureFactory cls  : Name of feature factory class  (default: ").append(defaultFeatureFactory);
    sb.append(")").append(nl);
    sb.append("  -withDomains         : Train file (if given) and eval file have domain labels.").append(nl);
    sb.append("  -domain dom          : Assume one domain for all data (default: 123)").append(nl);
    sb.append(nl).append(" Otherwise, all flags correspond to those present in SeqClassifierFlags.java.").append(nl);
    return sb.toString();
  }

  private static Map<String,Integer> optionArgDefs() {
    Map<String,Integer> optionArgDefs = Generics.newHashMap();
    optionArgDefs.put("help", 0);
    optionArgDefs.put("orthoOptions", 1);
    optionArgDefs.put("tokenized", 0);
    optionArgDefs.put("trainFile", 1);
    optionArgDefs.put("testFile", 1);
    optionArgDefs.put("textFile", 1);
    optionArgDefs.put("loadClassifier", 1);
    optionArgDefs.put("prefixMarker", 1);
    optionArgDefs.put("suffixMarker", 1);
    optionArgDefs.put("nthreads", 1);
    optionArgDefs.put("tedEval", 1);
    optionArgDefs.put("featureFactory", 1);
    optionArgDefs.put("withDomains", 0);
    optionArgDefs.put("domain", 1);
    return optionArgDefs;
  }

  /**
   *
   * @param args
   */
  public static void main(String[] args) {
    // Strips off hyphens
    Properties options = StringUtils.argsToProperties(args, optionArgDefs());
    if (options.containsKey("help") || args.length == 0) {
      System.err.println(usage());
      System.exit(-1);
    }

    int nThreads = PropertiesUtils.getInt(options, "nthreads", 1);
    ArabicSegmenter segmenter = getSegmenter(options);

    // Decode either an evaluation file or raw text
    try {
      PrintWriter pwOut;
      if (segmenter.flags.outputEncoding != null) {
        OutputStreamWriter out = new OutputStreamWriter(System.out, segmenter.flags.outputEncoding);
        pwOut = new PrintWriter(out, true);
      } else if (segmenter.flags.inputEncoding != null) {
        OutputStreamWriter out = new OutputStreamWriter(System.out, segmenter.flags.inputEncoding);
        pwOut = new PrintWriter(out, true);
      } else {
        pwOut = new PrintWriter(System.out, true);
      }
      if (segmenter.flags.testFile != null) {
        if (segmenter.flags.answerFile == null) {
          segmenter.evaluate(pwOut);
        } else {
          segmenter.evaluateRawText(pwOut);
        }

      } else {
        BufferedReader br = (segmenter.flags.textFile == null) ?
            new BufferedReader(new InputStreamReader(System.in)) :
              new BufferedReader(new InputStreamReader(new FileInputStream(segmenter.flags.textFile),
                  segmenter.flags.inputEncoding));

        double charsPerSec = decode(segmenter, br, pwOut, nThreads);
        IOUtils.closeIgnoringExceptions(br);
        System.err.printf("Done! Processed input text at %.2f input characters/second%n", charsPerSec);
      }

    } catch (UnsupportedEncodingException e) {
      e.printStackTrace();
    } catch (FileNotFoundException e) {
      System.err.printf("%s: Could not open %s%n", ArabicSegmenter.class.getName(), segmenter.flags.textFile);
    }
  }

  /**
   * Segment input and write to output stream.
   *
   * @param segmenter
   * @param br
   * @param pwOut
   * @param nThreads
   * @return input characters processed per second
   */
  private static double decode(ArabicSegmenter segmenter, BufferedReader br,
                               PrintWriter pwOut, int nThreads) {
    assert nThreads > 0;
    long nChars = 0;
    final long startTime = System.nanoTime();
    if (nThreads > 1) {
      MulticoreWrapper<String,String> wrapper = new MulticoreWrapper<String,String>(nThreads, segmenter);
      try {
        for (String line; (line = br.readLine()) != null;) {
          nChars += line.length();
          wrapper.put(line);
          while (wrapper.peek()) {
            pwOut.println(wrapper.poll());
          }
        }

        wrapper.join();
        while (wrapper.peek()) {
          pwOut.println(wrapper.poll());
        }

      } catch (IOException e) {
        e.printStackTrace();
      }

    } else {
      nChars = segmenter.segment(br, pwOut);
    }
    long duration = System.nanoTime() - startTime;
    double charsPerSec = (double) nChars / (duration / 1000000000.0);
    return charsPerSec;
  }

  /**
   * Train a new segmenter or load an trained model from file.  First
   * checks to see if there is a "model" or "loadClassifier" flag to
   * load from, and if not tries to run training using the given
   * options.
   *
   * @param options
   * @return the trained or loaded model
   */
  private static ArabicSegmenter getSegmenter(Properties options) {
    ArabicSegmenter segmenter = new ArabicSegmenter(options);
    if (segmenter.flags.inputEncoding == null) {
      segmenter.flags.inputEncoding = System.getProperty("file.encoding");
    }

    // Load or train the classifier
    if (segmenter.flags.loadClassifier != null) {
      segmenter.loadSegmenter(segmenter.flags.loadClassifier, options);
    } else if (segmenter.flags.trainFile != null){
      segmenter.train();

      if(segmenter.flags.serializeTo != null) {
        segmenter.serializeSegmenter(segmenter.flags.serializeTo);
        System.err.println("Serialized segmenter to: " + segmenter.flags.serializeTo);
      }
    } else {
      System.err.println("No training file or trained model specified!");
      System.err.println(usage());
      System.exit(-1);
    }
    return segmenter;
  }

}
TOP

Related Classes of edu.stanford.nlp.international.arabic.process.ArabicSegmenter

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.