Source Code of org.fnlp.nlp.parser.dep.train.ParserTrainer

/**
*  This file is part of FNLP (formerly FudanNLP).
*  FNLP is free software: you can redistribute it and/or modify
*  it under the terms of the GNU Lesser General Public License as published by
*  the Free Software Foundation, either version 3 of the License, or
*  (at your option) any later version.
*  FNLP is distributed in the hope that it will be useful,
*  but WITHOUT ANY WARRANTY; without even the implied warranty of
*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
*  GNU Lesser General Public License for more details.
*  You should have received a copy of the GNU General Public License
*  along with FudanNLP.  If not, see <http://www.gnu.org/licenses/>.
*  Copyright 2009-2014 www.fnlp.org. All rights reserved.
*/

package org.fnlp.nlp.parser.dep.train;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.List;
import java.util.zip.GZIPOutputStream;

import org.apache.commons.cli.BasicParser;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.fnlp.ml.classifier.linear.Linear;
import org.fnlp.ml.classifier.linear.OnlineTrainer;
import org.fnlp.ml.classifier.linear.inf.LinearMax;
import org.fnlp.ml.classifier.linear.update.LinearMaxPAUpdate;
import org.fnlp.ml.classifier.linear.update.Update;
import org.fnlp.ml.feature.SFGenerator;
import org.fnlp.ml.loss.ZeroOneLoss;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.InstanceSet;
import org.fnlp.ml.types.alphabet.AlphabetFactory;
import org.fnlp.ml.types.alphabet.IFeatureAlphabet;
import org.fnlp.ml.types.alphabet.LabelAlphabet;
import org.fnlp.ml.types.alphabet.StringFeatureAlphabet;
import org.fnlp.ml.types.alphabet.AlphabetFactory.Type;
import org.fnlp.ml.types.sv.HashSparseVector;
import org.fnlp.nlp.parser.Sentence;
import org.fnlp.nlp.parser.dep.ParsingState;
import org.fnlp.nlp.parser.dep.ParsingState.Action;
import org.fnlp.nlp.parser.dep.reader.CoNLLReader;

import gnu.trove.iterator.TObjectIntIterator;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TFloatArrayList;
import gnu.trove.map.hash.TIntObjectHashMap;

/**
 * Trainer for the dependency parser.
 *
 * @version Feb 16, 2009
 */
public class ParserTrainer {

  String modelfile;
  Charset charset;
  File fp;
  AlphabetFactory factory;

  /**
   * Constructor
   *
   * @param data
   *            path of the model file to be written
   */
  public ParserTrainer(String data) {
    this(data, "UTF-8");
  }

  /**
   * Constructor
   *
   * @param dataPath
   *            path of the model file to be written
   * @param charset
   *            file encoding
   */
  public ParserTrainer(String dataPath, String charset) {
    this.modelfile = dataPath;
    this.charset = Charset.forName(charset);
    factory = AlphabetFactory.buildFactory();
  }

  /**
   * Generate training instances.
   *
   * Using the Yamada parsing algorithm, generates features and training samples from a single file.
   *
   * @param file
   *            a single training file
   * @throws IOException
   */
  private void buildInstanceList(String file) throws IOException {

    System.out.print("generating training instances ...");

    CoNLLReader reader = new CoNLLReader(file);

    BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(
        new FileOutputStream(fp), charset));
   
    LabelAlphabet postagAlphabet = factory.buildLabelAlphabet("postag");

    int count = 0;
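    // Each parsing action is written as one line of the temporary feature file:
    // <POS of the left focus word> <action: L/R/S> <sorted feature indices>
    // Sentences are separated by a blank line.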
    while (reader.hasNext()) {

      Sentence instance = (Sentence) reader.next();
      int[] heads = (int[]) instance.getTarget();
      ParsingState state = new ParsingState(instance,factory);
      while (!state.isFinalState()) {
        // positions of the left and right focus words in the sentence
        int[] lr = state.getFocusIndices();

        HashSparseVector features = state.getFeatures();
        ParsingState.Action action = getAction(lr[0], lr[1],
            heads);
        state.next(action);
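        // The word attached by a LEFT/RIGHT reduction gets its head set to -1 so that
        // modifierNumOf() no longer counts it as a pending modifier in getAction().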
        if (action == ParsingState.Action.LEFT)
          heads[lr[1]] = -1;
        if (action == ParsingState.Action.RIGHT)
          heads[lr[0]] = -1;

        // writer.write(String.valueOf(instance.postags[lr[0]]));
        String pos = instance.getTagAt(lr[0]);
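        // register the POS tag of the left focus word; train() later builds one model per tag seen here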
        postagAlphabet.lookupIndex(pos);
        writer.write(pos);
        writer.write(" ");
        switch (action) {
        case LEFT:
          writer.write("L");
          break;
        case RIGHT:
          writer.write("R");
          break;
        default:
          writer.write("S");
        }
        writer.write(" ");
        int[] idx = features.indices();
        Arrays.sort(idx);
        for (int i = 0; i < idx.length; i++) {
          writer.write(String.valueOf(idx[i]));
          writer.write(" ");
        }
        writer.newLine();

      }
      writer.write('\n');
      writer.flush();
      count++;
    }
    writer.close();

    System.out.println(" ... finished");
    System.out.printf("%d instances have been loaded.\n\n", count);
  }

  /**
   * Train the model.
   *
   * @param dataFile
   *            training file
   * @param maxite
   *            maximum number of iterations
   * @param c
   *            parameter c passed to the online trainer
   * @throws IOException
   */
  public void train(String dataFile, int maxite, float c) throws IOException {

    // make sure the temporary directory exists before creating the feature file in it
    File tmpDir = new File("./tmp/");
    tmpDir.mkdirs();
    fp = File.createTempFile("train-features", null, tmpDir);

    buildInstanceList(dataFile);

    LabelAlphabet postagAlphabet = factory.buildLabelAlphabet("postag");

    IFeatureAlphabet features = factory.DefaultFeatureAlphabet();

    SFGenerator generator = new SFGenerator();
    Linear[] models = new Linear[postagAlphabet.size()];
    int fsize = features.size();
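    // One linear classifier is trained per POS tag of the left focus word, using only the
    // lines of the temporary feature file that start with that tag.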

    for (int i = 0; i < postagAlphabet.size(); i++) {
      String pos = postagAlphabet.lookupString(i);
      InstanceSet instset = readInstanceSet(pos);
      LabelAlphabet alphabet = factory.buildLabelAlphabet(pos);
      int ysize = alphabet.size();
      System.out.printf("Training with data: %s\n", pos);
      System.out.printf("Number of labels: %d\n", ysize);
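      // online passive-aggressive training (LinearMaxPAUpdate) with zero-one loss over the action labels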
      LinearMax solver = new LinearMax(generator, ysize);
      ZeroOneLoss loss = new ZeroOneLoss();
      Update update = new LinearMaxPAUpdate(loss);
      OnlineTrainer trainer = new OnlineTrainer(solver, update, loss,
          fsize, maxite, c);
      models[i] = trainer.train(instset, null);
      instset = null;
      solver = null;
      loss = null;
      trainer = null;
      System.out.println();
    }
    factory.setStopIncrement(true);
    saveModels(modelfile, models,factory);

    fp.delete();
    fp = null;
  }

 

  /**
   * Read training instances.
   *
   * Reads from the temporary feature file the instances whose left focus word has the given POS tag.
   *
   * @param pos
   *            POS tag
   * @return instance set
   * @throws IOException
   */
  private InstanceSet readInstanceSet(String pos) throws IOException {

    InstanceSet instset = new InstanceSet();

    LabelAlphabet labelAlphabet = factory.buildLabelAlphabet(pos);

    BufferedReader in = new BufferedReader(new InputStreamReader(
        new FileInputStream(fp), charset));

    String line = null;
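    // Each relevant line has the form: <POS> <action label> <feature indices...>;
    // only lines whose first token equals the requested POS become instances.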
    while ((line = in.readLine()) != null) {
      line = line.trim();
      if (line.isEmpty())
        continue;
      if (line.startsWith(pos + " ")) {
        List<String> tokens = Arrays.asList(line.split("\\s+"));

        int[] data = new int[tokens.size() - 2];
        for (int i = 0; i < data.length; i++) {
          data[i] = Integer.parseInt(tokens.get(i + 2));
        }
        Instance inst = new Instance(data);
        inst.setTarget(labelAlphabet.lookupIndex(tokens.get(1)));

        instset.add(inst);
      }
    }

    in.close();

    labelAlphabet.setStopIncrement(true);
    instset.setAlphabetFactory(factory);

    return instset;
  }

  /**
   * Save the models.
   *
   * Serializes the alphabet factory and the model parameters to a gzipped file.
   *
   * @param modelfile
   *            path of the model file
   * @param models
   *            model parameters
   * @param factory
   *            alphabet factory
   * @throws IOException
   */
  public static void saveModels(String modelfile, Linear[] models, AlphabetFactory factory) throws IOException {

    ObjectOutputStream outstream = new ObjectOutputStream(
        new GZIPOutputStream(new FileOutputStream(modelfile)));
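    // the factory is written first, then the model array; a loader must read them back in the same order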
    outstream.writeObject(factory);
    outstream.writeObject(models);
    outstream.close();
  }

  /**
   * Determines, from the gold dependency relations, the action to take for the current pair of focus words.
   *
   * @param l
   *            index of the left focus word in the sentence
   * @param r
   *            index of the right focus word in the sentence
   * @param heads
   *            head index of each word
   * @return the parsing action
   */
  private ParsingState.Action getAction(int l, int r, int[] heads) {
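    // RIGHT when the left word's head is the right word and the left word has no pending modifiers;
    // LEFT in the symmetric case for the right word; SHIFT otherwise.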
    if (heads[l] == r && modifierNumOf(l, heads) == 0)
      return ParsingState.Action.RIGHT;
    if (heads[r] == l && modifierNumOf(r, heads) == 0)
      return ParsingState.Action.LEFT;
    return ParsingState.Action.SHIFT;
  }

  private int modifierNumOf(int h, int[] heads) {
    int n = 0;
    for (int i = 0; i < heads.length; i++)
      if (heads[i] == h)
        n++;
    return n;
  }

  /**
   * Main entry point.
   *
   * Usage: ParserTrainer [-iter n] [-c value] train_file model_file
   *
   * @param args
   *            training file; model file; number of iterations
   * @throws Exception
   */
  public static void main(String[] args) throws Exception {
    // Example arguments: ./tmp/CoNLL2009-ST-Chinese-train.txt ./tmp/modelConll.gz
    Options opt = new Options();

    opt.addOption("h", false, "Print help for this application");
    opt.addOption("iter", true, "iterative num, default 50");
    opt.addOption("c", true, "parameters 1, default 1");

    BasicParser parser = new BasicParser();
    CommandLine cl;
    try {
      cl = parser.parse(opt, args);
    } catch (Exception e) {
      System.err.println("Parameters format error");
      return;
    }
   
    if (cl.getArgs().length < 2 || cl.hasOption('h')) {
      HelpFormatter f = new HelpFormatter();
      f.printHelp(
          "ParserTrainer:\n"
              + "ParserTrainer [options] train_file model_file\n",
          opt);
      return;
    }
    args = cl.getArgs();
    String datafile = args[0];
    String modelfile = args[1];
    int maxite = Integer.parseInt(cl.getOptionValue("iter", "50"));
    float c = Float.parseFloat(cl.getOptionValue("c", "1"));
   
    ParserTrainer trainer = new ParserTrainer(modelfile);
    trainer.train(datafile, maxite, c);
  }

}