Package org.fnlp.ml.classifier.hier

Source Code of org.fnlp.ml.classifier.hier.ModelAnalysis

/**
*  This file is part of FNLP (formerly FudanNLP).
*  FNLP is free software: you can redistribute it and/or modify
*  it under the terms of the GNU Lesser General Public License as published by
*  the Free Software Foundation, either version 3 of the License, or
*  (at your option) any later version.
*  FNLP is distributed in the hope that it will be useful,
*  but WITHOUT ANY WARRANTY; without even the implied warranty of
*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
*  GNU Lesser General Public License for more details.
*  You should have received a copy of the GNU General Public License
*  along with FudanNLP.  If not, see <http://www.gnu.org/licenses/>.
*  Copyright 2009-2014 www.fnlp.org. All rights reserved.
*/

package org.fnlp.ml.classifier.hier;

import java.io.BufferedWriter;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;

import org.fnlp.ml.types.alphabet.AlphabetFactory;
import org.fnlp.ml.types.alphabet.IFeatureAlphabet;
import org.fnlp.ml.types.alphabet.LabelAlphabet;
import org.fnlp.ml.types.alphabet.StringFeatureAlphabet;
import org.fnlp.ml.types.sv.HashSparseVector;
import org.fnlp.util.MyHashSparseArrays;

import gnu.trove.iterator.TIntFloatIterator;
import gnu.trove.iterator.TObjectIntIterator;
import gnu.trove.map.hash.TIntObjectHashMap;

/**
* 优化模型文件,去掉无用的特征
* 权重向量为HashSparseVector[]
* @since FudanNLP 1.0
* @author xpqiu
*
*/
public class ModelAnalysis {

  private Linear cl;
  public AlphabetFactory factory;
  private float thresh = 0;
  HashSparseVector[] weights;
  private IFeatureAlphabet feature;
  private LabelAlphabet label;

  public ModelAnalysis(Linear cl) {
    this.cl = cl;
    this.factory = cl.factory;
    feature = factory.DefaultFeatureAlphabet();
    label = factory.DefaultLabelAlphabet();
    this.weights = cl.weights;
  }

  /**
   * 统计信息,计算删除非0特征后,权重的长度
   *
   * @throws IOException
   */
  public void removeZero() {
    boolean freeze = false;
    if (feature.isStopIncrement()) {
      feature.setStopIncrement(false);
      freeze = true;
    }
    TIntObjectHashMap<String> index = (TIntObjectHashMap<String>) feature.toInverseIndexMap();
   
    System.out.println("原字典大小"+index.size());
    System.out.println("原字典大小"+feature.size());
    StringFeatureAlphabet newfeat = new StringFeatureAlphabet();
    cl.factory.setDefaultFeatureAlphabet(newfeat);
    for(int i=0;i<weights.length;i++){
        TIntFloatIterator itt = weights[i].data.iterator();
        HashSparseVector ww = new HashSparseVector();
        while(itt.hasNext()){
          itt.advance();
          float v = itt.value();
          if(Math.abs(v)<1e-3f)
            continue;
          String fea = index.get(itt.key());
          int newidx = newfeat.lookupIndex(fea);
          ww.put(newidx, v);       
      }
      weights[i] = ww; 
    }
   
    newfeat.setStopIncrement(freeze);
    System.out.println("新字典大小"+newfeat.size());   
    System.out.println("新字典大小"+feature.size());   
    index.clear();   
  }
 
  /**
   * @param args
   * @throws Exception
   */
  public static void main(String[] args) throws Exception {
   
    String file = "./tmp/model/tree_model.gz";
    Linear cl = Linear.loadFrom(file);

   
    ModelAnalysis ma = new ModelAnalysis(cl);
    ma.getSalientFeatures("./tmp/model/tree_model",100);

//    ma.removeZero();
//    cl.saveTo(file+1);
    System.out.print("Done");
  }

  private void getSalientFeatures(String string, int topn) throws IOException {
    PrintWriter pw = new PrintWriter(new OutputStreamWriter(
        new FileOutputStream(string), "UTF-8"));
    TIntObjectHashMap<String> index = (TIntObjectHashMap<String>) feature.toInverseIndexMap();
    for(int i=0;i<weights.length;i++){
      int[] idx = MyHashSparseArrays.sort(weights[i].data);
      pw.println(label.lookupString(i));
      for(int j=0;j<topn;j++){
        pw.print(index.get(idx[j]));
        pw.print("\t");
        pw.println(weights[i].get(idx[j]));
      }
      pw.println();
    }
    pw.close();
  }

}
TOP

Related Classes of org.fnlp.ml.classifier.hier.ModelAnalysis

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.