Package opennlp.ccg.perceptron

Source Code of opennlp.ccg.perceptron.Model$FeatureFilter

///////////////////////////////////////////////////////////////////////////////
// Copyright (C) 2008 Michael White
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
//////////////////////////////////////////////////////////////////////////////

package opennlp.ccg.perceptron;

import java.util.*;
import java.io.*;

import opennlp.ccg.perceptron.Alphabet.Feature;
import opennlp.ccg.util.Pair;

/**
* A model is a vector of weights for an alphabet.
*
* A model can be read from a file, which starts with the number
* of features on one line, followed by one line per feature pairing
* the feature name with its weight.
*
* A new model with all zero weights can also be created from an alphabet.
*
* The main routine tests the model on an event file.
*
* @author Michael White
* @version     $Revision: 1.7 $, $Date: 2009/11/09 18:54:30 $
*/
public class Model {

  /** Filter interface for adjusting feature weights when loading a model. */
  public interface FeatureFilter {
    /** Returns the modified feature weight for the given feature. */
    public double adjustedWeight(String name, double weight);
  }

  /** Flag for whether to print debugging info to System.err. */
  public boolean debug = false;
 
  // weight vector
  private double[] weights;
 
  // alphabet
  private Alphabet alphabet;

 
  /** Constructor with alphabet, for a new model with all zero weights. */
  public Model(Alphabet alphabet) {
    this.alphabet = alphabet;
    weights = new double[alphabet.size()];
    zero();
  }
 
  /** Constructor to load a model from a file. The alphabet is set to closed. */
  public Model(String filename) throws IOException {
      this(filename, null);
  }

  /**
   * Constructor to load a model from a file using a feature filter. The
   * alphabet is set to closed.
   */
  public Model(String filename, FeatureFilter filter) throws IOException {
    Reader reader = EventFile.openReader(new File(filename));
    StreamTokenizer tokenizer = EventFile.initTokenizer(reader);
    tokenizer.nextToken();
    int size = Integer.parseInt(tokenizer.sval);
    alphabet = new Alphabet(size);
    weights = new double[size];
    for (int i = 0; i < size; i++) {
      tokenizer.nextToken();
      String name = tokenizer.sval;
      tokenizer.nextToken();
      double weight = Double.parseDouble(tokenizer.sval);
      if (filter != null)
        weight = filter.adjustedWeight(name, weight);
      alphabet.add(name);
      weights[i] = weight;
    }
    reader.close();
    alphabet.setClosed(true);
  }
 
  /** Returns the size of the model. */
  public int size() { return weights.length; }
 
  /** Returns the alphabet. */
  public Alphabet getAlphabet() { return alphabet; }
 
 
  /** Returns the weight for the given index. */
  public double getWeight(int index) { return weights[index]; }
 
  /** Returns the weight for the given feature. */
  public double getWeight(String feat) { return weights[alphabet.index(feat).getIndex()]; }
 
  /** Returns the weight for the given feature. */
  public double getWeight(Alphabet.Feature f) { return weights[f.getIndex()]; }
 
  /** Sets the weight for the given index. */
  public void setWeight(int index, double weight) { weights[index] = weight; }
 
  /** Sets the weight for the given feature. */
  public void setWeight(String feat, double weight) { weights[alphabet.index(feat).getIndex()] = weight; }

  /** Sets the weight for the given feature. */
  public void setWeight(Alphabet.Feature f, double weight) { weights[f.getIndex()] = weight; }

 
  /** Returns the dot product of the weights and features. */
  public double score(FeatureVector fv) {
    double retval = 0.0;
    for (FeatureVector.Iterator it = fv.iterator(); it.hasNext(); ) {
      Feature feat = it.nextFeature();
      Float value = it.nextValue();
      Integer index = feat.getIndex();
      if (index == null) continue;
      retval += weights[index] * value;
    }
    if (debug) System.err.println("score: " + retval + " " + fv);
    return retval;
  }
 
  /** Adds the feature vector values to the weights. */
  public void add(FeatureVector fv) {
    for (FeatureVector.Iterator it = fv.iterator(); it.hasNext(); ) {
      Feature feat = it.nextFeature();
      Float value = it.nextValue();
      Integer index = feat.getIndex();
      if (index == null) continue;
      weights[index] += value;
    }
  }
 
  /** Subtracts the feature vector values from the weights. */
  public void subtract(FeatureVector fv) {
    for (FeatureVector.Iterator it = fv.iterator(); it.hasNext(); ) {
      Feature feat = it.nextFeature();
      Float value = it.nextValue();
      Integer index = feat.getIndex();
      if (index == null) continue;
      weights[index] -= value;
    }
  }

 
  /** Adds the given model's weights to this model.  The models are assumed to share the same alphabet. */
  public void add(Model model) {
    for (int i=0; i < weights.length; i++) {
      weights[i] += model.weights[i];
    }
  }
 
  /** Multiplies the weights by the given number. */
  public void multiply(double num) {
    for (int i=0; i < weights.length; i++) {
      weights[i] *= num;
    }
  }

  /** Resets the weights to zero. */
  public void zero() {
    for (int i=0; i < weights.length; i++) {
      weights[i] = 0.0;
    }
  }

 
  /** Sets this model's weights to the given model's ones, where the alphabets intersect. */
  public void set(Model model) {
    zero();
    for (int i=0; i < model.weights.length; i++) {
      Alphabet.Feature f = model.alphabet.feature(i);
      Alphabet.Feature f0 = alphabet.index(f);
      if (f0 == null || f0.getIndex() == null) continue;
      weights[f0.getIndex()] = model.weights[i];
    }
  }
 
 
  /** Returns the best event (first tied if ties). */
  public FeatureVector best(List<FeatureVector> fvs) {
    FeatureVector retval = null; double max = Double.NEGATIVE_INFINITY;
    for (FeatureVector fv : fvs) {
      double score = score(fv);
      if (score > max) { retval = fv; max = score; }
    }
    return retval;
  }
 
  /** Returns the best event (first tied if ties). */
  public EventFile.Event best(EventFile.Block block) {
    EventFile.Event retval = null; double max = Double.NEGATIVE_INFINITY;
    for (EventFile.Event event : block.events) {
      double score = score(event.features);
      if (score > max) { retval = event; max = score; }
    }
    return retval;
  }
 
  /** Returns the accuracy on the event file. */
  public double accuracy(EventFile eventFile) throws IOException {
    if (alphabet != eventFile.getAlphabet()) {
      throw new RuntimeException("Model and EventFile must share the same alphabet!");
    }
    int correct = 0; int total = 0;
    EventFile.Block block;
    while ( (block = eventFile.nextBlock()) != null ) {
      total++;
      if (best(block) == block.best()) {
        correct++;
        if (debug) System.err.println("CORRECT");
      }
      else {
        if (debug) System.err.println("WRONG; best: " + block.best());
      }
    }
    if (debug) System.err.println("correct: " + correct + " total: " + total);
    return 1.0 * correct / total;
  }
 
 
  /** Saves the model to a file, filtering out zero weights. */
  public void save(String filename) throws IOException { save(filename, 0.0); }
 
  /** Saves the model to a file, filtering out weights whose absolute value does not exceed the pruning value. */
  public void save(String filename, double minPrune) throws IOException {
    File file = new File(filename);
    PrintWriter out = EventFile.openWriter(file);
    // calc pruned size
    int size = size();
    int pruned = 0;
    for (int i=0; i < size; i++) if (Math.abs(weights[i]) <= minPrune) pruned++;
    int prunedSize = size - pruned;
    // write pruned size
    out.println(Integer.toString(prunedSize));
    // collect unpruned weights
    List<Pair<Feature,Double>> featWeights = new ArrayList<Pair<Feature,Double>>(prunedSize);
    for (int i=0; i < size; i++) {
      if (Math.abs(weights[i]) <= minPrune) continue;
      featWeights.add(new Pair<Feature,Double>(alphabet.feature(i), weights[i]));
    }
    // sort weights by descending absolute value
    // (further sorting alphabetically may take too long)
    Collections.sort(
        featWeights,
        new Comparator<Pair<Feature,Double>>() {
          public int compare(Pair<Feature,Double> entry1, Pair<Feature,Double> entry2) {
            double val1 = Math.abs(entry1.b); double val2 = Math.abs(entry2.b);
            if (val1 > val2) return -1;
            if (val1 < val2) return 1;
            return 0;
            //return entry1.a.name().compareTo(entry2.a.name());
          }
        }
    );
    // write sorted weights
    for (Pair<Feature,Double> fw : featWeights) {
      out.println(fw.a.name() + " " + fw.b);
    }
    out.close();
  }
 
 
  /**
   * Loads a model from a file and tests it on the given event file.
   */
  public static void main(String[] args) throws IOException {
    // help
    if (args.length < 2) {
      System.out.println("Usage: java perceptron.Model <modelfile> <eventfile> (-debug)");
      System.exit(0);
    }
    // args
    String modelfile = args[0];
    String eventfile = args[1];
    boolean debug = Arrays.asList(args).contains("-debug");
    // load model
    System.out.println("Loading model from: " + modelfile);
    Model model = new Model(modelfile);
    model.debug = debug;
    System.out.println("model size: " + model.size());
    System.out.println("debug: " + debug);
    // compute accuracy
    System.out.println("Scoring events in: " + eventfile);
    EventFile eventFile = new EventFile(eventfile, model.alphabet);
    double accuracy = model.accuracy(eventFile);
    System.out.println("accuracy: " + accuracy);
  }
}
TOP

Related Classes of opennlp.ccg.perceptron.Model$FeatureFilter

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.