Package de.jungblut.nlp

Source Code of de.jungblut.nlp.MarkovChain

package de.jungblut.nlp;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

import org.apache.commons.math3.util.FastMath;

import com.google.common.base.Optional;

import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.DoubleVector.DoubleVectorElement;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;

/**
* Markov chain, that can "learn" the state transition probabilities by a given
* input and returns the probability for a given sequence of states.
*
* @author thomas.jungblut
*
*/
public final class MarkovChain {

  /*
   * log normalized, so we can sum the probabilities instead of multiplying and
   * running into numerical problems. the probability from state i to j can be
   * retrieved by matrix.get(i,j).
   */
  private final DoubleMatrix transitionProbabilities;
  private final int numStates;

  private MarkovChain(int numStates) {
    this(numStates, new SparseDoubleRowMatrix(numStates, numStates));
  }

  private MarkovChain(int numStates, DoubleMatrix mat) {
    this.numStates = numStates;
    this.transitionProbabilities = mat;
  }

  /**
   * Trains the transition probabilities of the markov chain. <br/>
   * Each list element contains a set of states. The values of the
   * element-states are nominal and should be lower than the number of provided
   * states (each nominal will be a index, so it's from 0 to numStates-1). So
   * each element can be arbitrary sized, because in markov chains we are
   * considering the transition between two states, thus it will measure the
   * occurrence of each following two state pairs. e.G. [ 1, 2, 3, 4 ] will
   * measure the probabilities of [1,2],[2,3],[3,4].
   */
  public void train(List<int[]> states) {
    // loop over all state sets and set the count of the co-occurrence in the
    // transition probability matrix
    for (int[] array : states) {
      for (int i = 0; i < array.length - 1; i++) {
        int count = (int) transitionProbabilities.get(array[i], array[i + 1]);
        transitionProbabilities.set(array[i], array[i + 1], ++count);
      }
    }

    final int[] rowEntries = transitionProbabilities.rowIndices();
    for (int rowIndex : rowEntries) {
      DoubleVector rowVector = transitionProbabilities.getRowVector(rowIndex);
      double sum = rowVector.sum();
      Iterator<DoubleVectorElement> iterateNonZero = rowVector.iterateNonZero();
      // loop over all counts and take the log of the probability
      while (iterateNonZero.hasNext()) {
        DoubleVectorElement columnElement = iterateNonZero.next();
        int columnIndex = columnElement.getIndex();
        double probability = FastMath.log(columnElement.getValue())
            - FastMath.log(sum);
        transitionProbabilities.set(rowIndex, columnIndex, probability);
      }
    }

  }

  /**
   * Calculates the probability that the given sequence occurs.
   *
   * @return value between 0d and 1d, where 1d is very likely that the sequence
   *         is happening.
   */
  public double getProbabilityForSequence(int[] stateSequence) {
    DoubleVector distribution = getTransitionProbabilities(stateSequence);
    // normalize it by the maximum of the log probabilities
    double max = distribution.max();
    double probabilitySum = 0.0d;
    for (int i = 0; i < distribution.getDimension(); i++) {
      double probability = distribution.get(i);
      double normalizedProbability = probability - max;
      // add up the log probabilities
      probabilitySum += normalizedProbability;
    }
    // no we can exp them to get the real probability
    return FastMath.exp(probabilitySum);
  }

  /**
   * @return the average transition probability of the given sequence.
   */
  public double averageTransitionProbability(int[] sequence) {
    DoubleVector distribution = getTransitionProbabilities(sequence);
    return FastMath.exp(distribution.sum()
        / Math.max(1d, distribution.getLength()));
  }

  /**
   * @return the transition probabilities for the states.
   */
  public DoubleVector getTransitionProbabilities(int[] stateSequence) {
    DenseDoubleVector distribution = new DenseDoubleVector(
        stateSequence.length - 1);
    for (int i = 0; i < distribution.getDimension(); i++) {
      distribution.set(i,
          transitionProbabilities.get(stateSequence[i], stateSequence[i + 1]));
    }
    return distribution;
  }

  /**
   * Completes the given state sequence by picking the best next state on the
   * transition probabilities (so a transition with a high probability is picked
   * more often). If the optional random is not provided, it picks the next
   * state by the highest transition probability between the states (it predicts
   * the next state based on the previous).
   */
  public int[] completeStateSequence(Optional<Random> optionalRandom,
      int[] stateSequence, int... unsuppliedStateIndices) {
    // sort them first, then work through the array
    Arrays.sort(unsuppliedStateIndices);
    for (int index : unsuppliedStateIndices) {
      if (index == 0) {
        // special case because there is no previous state, pick the next
        if (index + 1 < stateSequence.length) {
          if (optionalRandom.isPresent()) {
            stateSequence[index] = chooseState(optionalRandom.get(),
                transitionProbabilities
                    .getColumnVector(stateSequence[index + 1]));
          } else {
            stateSequence[index] = transitionProbabilities.getColumnVector(
                stateSequence[index + 1]).maxIndex();
          }
        } else {
          throw new IllegalArgumentException("Can't guess state " + index
              + " in " + Arrays.toString(stateSequence));
        }
      } else {
        if (optionalRandom.isPresent()) {
          stateSequence[index] = chooseState(optionalRandom.get(),
              transitionProbabilities.getRowVector(stateSequence[index - 1]));
        } else {
          stateSequence[index] = transitionProbabilities.getRowVector(
              stateSequence[index - 1]).maxIndex();
        }
      }
    }
    return stateSequence;
  }

  /**
   * Chooses the state by a uniformly distributed random number, so higher
   * probable states are more likely to happen.
   *
   * @return the index of the next state.
   */
  private static int chooseState(Random random, DoubleVector probabilities) {
    final double r = random.nextDouble();
    Iterator<DoubleVectorElement> iterateNonZero = probabilities
        .iterateNonZero();
    DoubleVectorElement next = null;
    while (iterateNonZero.hasNext()) {
      next = iterateNonZero.next();
      if (r <= Math.exp(next.getValue())) {
        return next.getIndex();
      }
    }
    // return the max if we haven't escaped yet
    return probabilities.maxIndex();
  }

  /**
   * @return the state transition probability matrix to export/serialize.
   */
  public DoubleMatrix getTransitionProbabilities() {
    return transitionProbabilities;
  }

  /**
   * @return how many states were defined?
   */
  public int getNumStates() {
    return numStates;
  }

  /**
   * Creates a new markov chain with the supplied number of states.
   */
  public static MarkovChain create(int numStates) {
    return new MarkovChain(numStates);
  }

  /**
   * Creates a new markov chain with the supplied number of states and its
   * predefined transition matrix.
   */
  public static MarkovChain create(int numStates, DoubleMatrix mat) {
    return new MarkovChain(numStates, mat);
  }

}
TOP

Related Classes of de.jungblut.nlp.MarkovChain

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.