Package org.fnlp.nlp.parser.dep.train

Source Code of org.fnlp.nlp.parser.dep.train.JointParerTrainer

/**
*  This file is part of FNLP (formerly FudanNLP).
*  FNLP is free software: you can redistribute it and/or modify
*  it under the terms of the GNU Lesser General Public License as published by
*  the Free Software Foundation, either version 3 of the License, or
*  (at your option) any later version.
*  FNLP is distributed in the hope that it will be useful,
*  but WITHOUT ANY WARRANTY; without even the implied warranty of
*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
*  GNU Lesser General Public License for more details.
*  You should have received a copy of the GNU General Public License
*  along with FudanNLP.  If not, see <http://www.gnu.org/licenses/>.
*  Copyright 2009-2014 www.fnlp.org. All rights reserved.
*/

package org.fnlp.nlp.parser.dep.train;

import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.ArrayList;

import org.apache.commons.cli.BasicParser;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.fnlp.ml.classifier.linear.Linear;
import org.fnlp.ml.classifier.linear.OnlineTrainer;
import org.fnlp.ml.classifier.linear.inf.LinearMax;
import org.fnlp.ml.classifier.linear.update.LinearMaxPAUpdate;
import org.fnlp.ml.classifier.linear.update.Update;
import org.fnlp.ml.feature.SFGenerator;
import org.fnlp.ml.loss.ZeroOneLoss;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.InstanceSet;
import org.fnlp.ml.types.alphabet.AlphabetFactory;
import org.fnlp.ml.types.alphabet.IFeatureAlphabet;
import org.fnlp.ml.types.alphabet.LabelAlphabet;
import org.fnlp.nlp.parser.Sentence;
import org.fnlp.nlp.parser.Target;
import org.fnlp.nlp.parser.dep.JointParser;
import org.fnlp.nlp.parser.dep.JointParsingState;
import org.fnlp.nlp.parser.dep.reader.FNLPReader;

/**
* 句法分析器训练类
*
*/
public class JointParerTrainer{
  String modelfile;
  Charset charset;
  File fp;
  AlphabetFactory factory;

  /**
   * 构造函数
   * @param data
   *            训练文件的目录
   * @throws Exception
   */
  public JointParerTrainer(String data) {
    this(data, "UTF-8");
    factory = AlphabetFactory.buildFactory();
  }

  /**
   * 构造函数
   *
   * @param dataPath
   *            训练文件的目录
   * @param charset
   *            文件编码
   * @throws Exception
   */
  public JointParerTrainer(String dataPath, String charset) {
    this.modelfile = dataPath;
    this.charset = Charset.forName(charset);
  }

  /**
   * 生成训练实例
   *
   * 以Yamada分析算法,从单个文件中生成特征以及训练样本
   *
   * @param file
   *            单个训练文件
   * @return
   * @throws Exception
   */
  private InstanceSet buildInstanceList(String file) throws IOException {

    System.out.print("生成训练数据 ...");

    FNLPReader reader = new FNLPReader(file);
    FNLPReader preReader = new FNLPReader(file);
    InstanceSet instset = new InstanceSet();
   
    LabelAlphabet la = factory.DefaultLabelAlphabet();
    IFeatureAlphabet fa = factory.DefaultFeatureAlphabet();
    int count = 0;
   
    //preReader为了把ysize定下来
    la.lookupIndex("S");
    while(preReader.hasNext()){
      Sentence sent = (Sentence) preReader.next();
      Target targets = (Target)sent.getTarget();
      for(int i=0; i<sent.length(); i++){
        String label;
        if(targets.getHead(i) != -1){
          if(targets.getHead(i) < i){
            label = "L" + targets.getDepClass(i);
          }
          //else if(targets.getHead(i) > i){
          else{
            label = "R" + targets.getDepClass(i);
          }
          la.lookupIndex(label);
        }
      }
    }
    int ysize = la.size();
    la.setStopIncrement(true);
       
    while (reader.hasNext()) {
      Sentence sent = (Sentence) reader.next();
      //  int[] heads = (int[]) instance.getTarget();
      String depClass = null;
      Target targets = (Target)sent.getTarget();
      JointParsingState state = new JointParsingState(sent);
     
      while (!state.isFinalState()) {
        // 左右焦点词在句子中的位置
        int[] lr = state.getFocusIndices();

        ArrayList<String> features = state.getFeatures();
        JointParsingState.Action action = getAction(lr[0], lr[1],
            targets);
        switch (action) {
        case LEFT:
          depClass = targets.getDepClass(lr[1]);
          break;
        case RIGHT:
          depClass = targets.getDepClass(lr[0]);
          break;
        default:

        }
        state.next(action,depClass);
        if (action == JointParsingState.Action.LEFT)
          targets.setHeads(lr[1],-1);
        if (action == JointParsingState.Action.RIGHT)
          targets.setHeads(lr[0],-1);
        String label = "";
        switch (action) {
        case LEFT:
          label += "L"+sent.getDepClass(lr[1]);   
          break;
        case RIGHT:
          label+="R"+sent.getDepClass(lr[0]);
          break;
        default:
          label = "S";         
        }
        int id = la.lookupIndex(label);       
        Instance inst = new Instance();
        inst.setTarget(id);
        int[] idx = JointParser.addFeature(fa, features, ysize);
        inst.setData(idx);
        instset.add(inst);
      }
      count++;
//      System.out.println(count);
    }
   
    instset.setAlphabetFactory(factory);
    System.out.printf("共生成实例:%d个\n", count);
    return instset;
  }


  /**
   * 模型训练函数
   *
   * @param dataFile
   *            训练文件
   * @param maxite
   *            最大迭代次数
   * @throws IOException
   * @throws Exception
   */
  public void train(String dataFile, int maxite, float c) throws IOException {
   
    InstanceSet instset =  buildInstanceList(dataFile);
    IFeatureAlphabet features = factory.DefaultFeatureAlphabet();

    SFGenerator generator = new SFGenerator();
    int fsize = features.size();
   
    LabelAlphabet la = factory.DefaultLabelAlphabet();
    int ysize = la.size();
    System.out.printf("开始训练");
    LinearMax solver = new LinearMax(generator, ysize);
    ZeroOneLoss loss = new ZeroOneLoss();
    Update update = new LinearMaxPAUpdate(loss);
    OnlineTrainer trainer = new OnlineTrainer(solver, update, loss,
        fsize, maxite, c);
    Linear models = trainer.train(instset, null);
    instset = null;
    solver = null;
    loss = null;
    trainer = null;
    System.out.println();
    factory.setStopIncrement(true);
    models.saveTo(modelfile);

  }


  /**
   * 根据已有的依赖关系,得到焦点词之间的应采取的动作
   *
   *
   * @param l
   *            左焦点词在句子中是第l个词
   * @param r
   *            右焦点词在句子中是第r个词
   * @param heads
   *            中心评词
   * @return 动作
   */
  private JointParsingState.Action getAction(int l, int r, Target targets) {
    if (targets.getHead(l) == r && modifierNumOf(l, targets) == 0)
      return JointParsingState.Action.RIGHT;
    if (targets.getHead(r) == l && modifierNumOf(r, targets) == 0)
      return JointParsingState.Action.LEFT;
    return JointParsingState.Action.SHIFT;
  }

  private int modifierNumOf(int h, Target target) {
    int n = 0;
    for (int i = 0; i < target.size(); i++)
      if (target.getHead(i) == h)
        n++;
    return n;
  }

  /**
   * 主文件
   *
   * @param args
   *            : 训练文件;模型文件;循环次数
   * @throws Exception
   */
  public static void main(String[] args) throws Exception {
    //    args = new String[2];
    //    args[0] = "./tmp/malt.train";
    //    args[1] = "./tmp/Malt2Model.gz";
    Options opt = new Options();

    opt.addOption("h", false, "Print help for this application");
    opt.addOption("iter", true, "iterative num, default 50");
    opt.addOption("c", true, "parameters 1, default 1");

    BasicParser parser = new BasicParser();
    CommandLine cl;
    try {
      cl = parser.parse(opt, args);
    } catch (Exception e) {
      System.err.println("Parameters format error");
      return;
    }

    if (args.length == 0 || cl.hasOption('h')) {
      HelpFormatter f = new HelpFormatter();
      f.printHelp(
          "Tagger:\n"
              + "ParserTrainer [option] train_file model_file;\n",
              opt);
      return;
    }
    args = cl.getArgs();
    String datafile = args[0];
    String modelfile = args[1];
    int maxite = Integer.parseInt(cl.getOptionValue("iter", "50"));
    float c = Float.parseFloat(cl.getOptionValue("c", "1"));

    JointParerTrainer trainer = new JointParerTrainer(modelfile);
    trainer.train(datafile, maxite, c);
  }

}
TOP

Related Classes of org.fnlp.nlp.parser.dep.train.JointParerTrainer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.