Package de.jungblut.ner

Source Code of de.jungblut.ner.MaxEntMarkovModel

package de.jungblut.ner;

import java.util.Collections;

import com.google.common.base.Preconditions;

import de.jungblut.classification.AbstractClassifier;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.ViterbiUtils;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.DenseMatrixFolder;
import de.jungblut.math.minimize.Minimizer;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;

/**
* Maximum entropy markov model for named entity recognition (classifying labels
* in sequence learning).
*
* @author thomas.jungblut
*
*/
public final class MaxEntMarkovModel extends AbstractClassifier {

  private final Minimizer minimizer;
  private final boolean verbose;
  private final int numIterations;
  private DoubleMatrix theta;
  private int classes;

  public MaxEntMarkovModel(Minimizer minimizer, int numIterations,
      boolean verbose) {
    this.minimizer = minimizer;
    this.numIterations = numIterations;
    this.verbose = verbose;
  }

  public MaxEntMarkovModel(DenseDoubleMatrix theta, int classes) {
    this(null, -1, false);
    this.theta = theta;
    this.classes = classes;
  }

  @Override
  public void train(DoubleVector[] features, DoubleVector[] outcome) {
    Preconditions
        .checkArgument(
            features.length == outcome.length && features.length > 0,
            "There wasn't at least a single featurevector, or the two array didn't match in size.");
    this.classes = outcome[0].getDimension() == 1 ? 2 : outcome[0]
        .getDimension();
    DoubleMatrix mat = null;
    if (features[0].isSparse()) {
      mat = new SparseDoubleRowMatrix(features);
    } else {
      mat = new DenseDoubleMatrix(features);
    }
    ConditionalLikelihoodCostFunction func = new ConditionalLikelihoodCostFunction(
        mat, new DenseDoubleMatrix(outcome));
    DenseDoubleVector vx = new DenseDoubleVector(mat.getColumnCount() * classes);
    DoubleVector input = minimizer.minimize(func, vx, numIterations, verbose);
    theta = DenseMatrixFolder.unfoldMatrix(input, classes,
        (int) (input.getLength() / (double) classes));
  }

  @Override
  public DoubleVector predict(DoubleVector features) {
    Preconditions.checkArgument(
        features.getClass().equals(UnrollableDoubleVector.class),
        "Features must be an instance of the class UnrollableDoubleVector.");
    UnrollableDoubleVector unrollable = (UnrollableDoubleVector) features;
    return predict(unrollable.getMainVector(), unrollable.getSideVectors());
  }

  /**
   * @return the learned parameters.
   */
  public DoubleMatrix getTheta() {
    return this.theta;
  }

  public DoubleVector predict(DoubleVector feature,
      DoubleVector[] featuresPerState) {
    return ViterbiUtils.decode(theta,
        new SparseDoubleRowMatrix(Collections.singletonList(feature)),
        new SparseDoubleRowMatrix(featuresPerState), classes).getRowVector(0);
  }

  // matrix prediction
  public DoubleMatrix predict(DoubleMatrix features,
      DoubleMatrix featuresPerState) {
    return ViterbiUtils.decode(theta, features, featuresPerState, classes);
  }

}
TOP

Related Classes of de.jungblut.ner.MaxEntMarkovModel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.