Package cc.mallet.topics

Source Code of cc.mallet.topics.MultinomialHMM

/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.topics;

import cc.mallet.types.*;
import cc.mallet.util.Randoms;

import java.util.Arrays;
import java.util.List;
import java.util.ArrayList;
import java.util.zip.*;

import java.io.*;
import java.text.NumberFormat;

import gnu.trove.*;

/**
* A hidden Markov model over sequences of documents, in which each
* hidden state emits a multinomial distribution over the topics of a
* previously trained LDA model.
* @author David Mimno, Andrew McCallum
*/
public class MultinomialHMM {

    int numTopics; // Number of topics to be fit
    int numStates; // Number of hidden states
    int numDocs;
    int numSequences;

    // Dirichlet(alpha,alpha,...) is the distribution over topics
    double[] alpha;
    double alphaSum;

    // Prior on per-topic multinomial distribution over words
    double beta;
    double betaSum;

    // Prior on the state-state transition distributions
    double gamma;
    double gammaSum;

    double pi;
    double sumPi;

    TIntObjectHashMap<TIntIntHashMap> documentTopics;
    int[] documentSequenceIDs;   
    int[] documentStates;

    int[][] stateTopicCounts;
    int[] stateTopicTotals;
    int[][] stateStateTransitions;
    int[] stateTransitionTotals;

    int[] initialStateCounts;

    // Keep track of the maximum number of times each
    //  topic occurs in any single document
    int[] maxTokensPerTopic;

    // The size of the largest document
    int maxDocLength;

    // Rather than calculating log gammas for every state and every topic
    //  we cache log predictive distributions for every possible state
    //  and document.
    double[][][] topicLogGammaCache;
    double[][] docLogGammaCache;

    int numIterations = 1000;
    int burninPeriod = 200;
    int saveSampleInterval = 10;
    int optimizeInterval = 0;
    int showTopicsInterval = 50;

    String[] topicKeys;

    Randoms random;

    NumberFormat formatter;
   
    public MultinomialHMM (int numberOfTopics, String topicsFilename, int numStates) throws IOException {
  formatter = NumberFormat.getInstance();
  formatter.setMaximumFractionDigits(5);
 
  System.out.println("LDA HMM: " + numberOfTopics);
 
  documentTopics = new TIntObjectHashMap<TIntIntHashMap>();

  this.numTopics = numberOfTopics;
  this.alphaSum = numberOfTopics;
  this.alpha = new double[numberOfTopics];
  Arrays.fill(alpha, alphaSum / numTopics);

  topicKeys = new String[numTopics];

  // This initializes numDocs as well
  loadTopicsFromFile(topicsFilename);

  documentStates = new int[ numDocs ];
  documentSequenceIDs = new int[ numDocs ];

  maxTokensPerTopic = new int[ numTopics ];
  maxDocLength = 0;
 
  //int[] histogram = new int[380];
  //int totalTokens = 0;

  for (int doc=0; doc < numDocs; doc++) {
      if (! documentTopics.containsKey(doc)) { continue; }
     
      TIntIntHashMap topicCounts = documentTopics.get(doc);
     
      int count = 0;
      for (int topic: topicCounts.keys()) {
    int topicCount = topicCounts.get(topic);
    //histogram[topicCount]++;
    //totalTokens += topicCount;

    if (topicCount > maxTokensPerTopic[topic]) {
        maxTokensPerTopic[topic] = topicCount;
    }
    count += topicCount;
      }
      if (count > maxDocLength) {
    maxDocLength = count;
      }
  }

  /*
  double runningTotal = 0.0;
  for (int i=337; i >= 0; i--) {
      runningTotal += i * histogram[i];
      System.out.format("%d\t%d\t%.3f\n", i, histogram[i],
            runningTotal / totalTokens);
  }
  */

  this.numStates = numStates;
  this.initialStateCounts = new int[numStates];

  topicLogGammaCache = new double[numStates][numTopics][];
  for (int state=0; state < numStates; state++) {
      for (int topic=0; topic < numTopics; topic++) {
    topicLogGammaCache[state][topic] = new double[ maxTokensPerTopic[topic] + 1 ];
    //topicLogGammaCache[state][topic] = new double[21];

      }
  }
  System.out.println("max document length: " + maxDocLength);
  docLogGammaCache = new double[numStates][ maxDocLength + 1 ];

    }

    public void setGamma(double g) {
  this.gamma = g;
    }

    public void setNumIterations (int numIterations) {
  this.numIterations = numIterations;
    }

    public void setBurninPeriod (int burninPeriod) {
  this.burninPeriod = burninPeriod;
    }

    public void setTopicDisplayInterval(int interval) {
  this.showTopicsInterval = interval;
    }

    public void setRandomSeed(int seed) {
  random = new Randoms(seed);
    }

    public void setOptimizeInterval(int interval) {
  this.optimizeInterval = interval;
    }
   
    public void initialize () {

  if (random == null) {
      random = new Randoms();
  }

  gammaSum = gamma * numStates;
 
  stateTopicCounts = new int[numStates][numTopics];
  stateTopicTotals = new int[numStates];
  stateStateTransitions = new int[numStates][numStates];
  stateTransitionTotals = new int[numStates];

  pi = 1000.0;
  sumPi = numStates * pi;

  // The code to cache topic distributions
  //  takes an int-int hashmap as a mask to only update
  //  the distributions for topics that have actually changed.
  // Here we create a dummy count hash that has all the topics.
  TIntIntHashMap allTopicsDummy = new TIntIntHashMap();
  for (int topic = 0; topic < numTopics; topic++) {
      allTopicsDummy.put(topic, 1);
  }

  for (int state=0; state < numStates; state++) {
      recacheStateTopicDistribution(state, allTopicsDummy);
  }

  for (int doc = 0; doc < numDocs; doc++) {
      sampleState(doc, random, true);
  }

    }

    private void recacheStateTopicDistribution(int state, TIntIntHashMap topicCounts) {
  int[] currentStateTopicCounts = stateTopicCounts[state];
  double[][] currentStateCache = topicLogGammaCache[state];
  double[] cache;

  for (int topic: topicCounts.keys()) {
      cache = currentStateCache[topic];
     
      cache[0] = 0.0;
      for (int i=1; i < cache.length; i++) {
          cache[i] = cache[i-1] +
              Math.log( alpha[topic] + i - 1 + currentStateTopicCounts[topic] );
      }

  }

  docLogGammaCache[state][0] = 0.0;
  for (int i=1; i < docLogGammaCache[state].length; i++) {
      docLogGammaCache[state][i] = docLogGammaCache[state][i-1] +
          Math.log( alphaSum + i - 1 + stateTopicTotals[state] );
  }
    }
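
    // The cache filled in by recacheStateTopicDistribution() relies on
    //  the identity
    //      log Gamma(a + n) - log Gamma(a) = log(a) + log(a+1) + ... + log(a+n-1),
    //  so cache[n] holds the log rising factorial of
    //  a = alpha[topic] + stateTopicCounts[state][topic], built
    //  incrementally rather than by calling a log-gamma function for
    //  every (state, topic, count) triple. The helper below is an
    //  illustrative sketch of the same computation; it is not part of
    //  the original MALLET source and is unused elsewhere in this class.
    private static double[] logRisingFactorialTable(double a, int maxN) {
        double[] table = new double[maxN + 1];
        table[0] = 0.0;  // log Gamma(a) - log Gamma(a)
        for (int n = 1; n <= maxN; n++) {
            table[n] = table[n - 1] + Math.log(a + n - 1);
        }
        return table;
    }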

    public void sample() throws IOException {

  long startTime = System.currentTimeMillis();
   
  for (int iterations = 1; iterations <= numIterations; iterations++) {
      long iterationStart = System.currentTimeMillis();

      //System.out.println (printStateTransitions());
      for (int doc = 0; doc < numDocs; doc++) {
    sampleState (doc, random, false);
   
    //if (doc % 10000 == 0) { System.out.println (printStateTransitions()); }
      }

      System.out.print((System.currentTimeMillis() - iterationStart) + " ");
     
      if (iterations % 10 == 0) {
    System.out.println ("<" + iterations + "> ");
   
    PrintWriter out =
        new PrintWriter(new BufferedWriter(new FileWriter("state_state_matrix." + iterations)));
    out.print(stateTransitionMatrix());
    out.close();

    out = new PrintWriter(new BufferedWriter(new FileWriter("state_topics." + iterations)));
    out.print(stateTopics());
    out.close();
   
    out = new PrintWriter(new BufferedWriter(new FileWriter("states." + iterations)));

    for (int doc = 0; doc < documentStates.length; doc++) {
        out.println(documentStates[doc]);
    }

    out.close();
      }
      System.out.flush();
  }
 
  long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
  long minutes = seconds / 60;  seconds %= 60;
  long hours = minutes / 60;  minutes %= 60;
  long days = hours / 24;  hours %= 24;
  System.out.print ("\nTotal time: ");
  if (days != 0) { System.out.print(days); System.out.print(" days "); }
  if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
  if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
  System.out.print(seconds); System.out.println(" seconds");
 
    }
   
    public void loadTopicsFromFile(String stateFilename) throws IOException {
  BufferedReader in;
  if (stateFilename.endsWith(".gz")) {
      in = new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(stateFilename))));
  }
  else {
      in = new BufferedReader(new FileReader(new File(stateFilename)));
  }

  numDocs = 0;

  String line = null;
  while ((line = in.readLine()) != null) {
      if (line.startsWith("#")) {
    continue;
      }
     
      String[] fields = line.split(" ");
      int doc = Integer.parseInt(fields[0]);
      int token = Integer.parseInt(fields[1]);
      int type = Integer.parseInt(fields[2]);
      int topic = Integer.parseInt(fields[4]);

      // Now add the new topic

      if (! documentTopics.containsKey(doc)) {
    documentTopics.put(doc, new TIntIntHashMap());
      }

      if (documentTopics.get(doc).containsKey(topic)) {
    documentTopics.get(doc).increment(topic);
      }
      else {
    documentTopics.get(doc).put(topic, 1);
      }

      if (doc >= numDocs) { numDocs = doc + 1; }
  }
  in.close();

  System.out.println("loaded topics, " + numDocs + " documents");
    }
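
    // Note: each non-comment line of the state file is expected to
    //  contain at least five space-separated fields, of which only
    //  fields[0] (document index) and fields[4] (topic) are used;
    //  fields[1] (token) and fields[2] (type) are parsed but unused.
    //  A hypothetical line: "3 12 408 the 7" -- document 3, token
    //  position 12, type 408, word "the", topic 7.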

    public void loadAlphaFromFile(String alphaFilename) throws IOException {

  // Now restore the saved alpha parameters
  alphaSum = 0.0;
 
  BufferedReader in = new BufferedReader(new FileReader(new File(alphaFilename)));
  String line = null;
  while ((line = in.readLine()) != null) {
      if (line.equals("")) { continue; }

            String[] fields = line.split("\\s+");

      int topic = Integer.parseInt(fields[0]);
      alpha[topic] = 1.0; // Double.parseDouble(fields[1]);
      alphaSum += alpha[topic];

      StringBuffer topicKey = new StringBuffer();
      for (int i=2; i<fields.length; i++) {
    topicKey.append(fields[i] + " ");
      }
      topicKeys[topic] = topicKey.toString();
  }
  in.close();

  System.out.println("loaded alpha");
    }
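
    // Note: the keys file is read one topic per line, with fields[0]
    //  the topic id, fields[1] a weight that is currently ignored
    //  (alpha is pinned to 1.0 above), and fields[2..] the topic's
    //  key words. A hypothetical line:
    //  "7 0.04321 time year day month week".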

    /*
    public void loadStatesFromFile(String stateFilename) throws IOException {

  int doc = 0;

        int state;

        BufferedReader in = new BufferedReader(new FileReader(new File(stateFilename)));
        String line = null;
        while ((line = in.readLine()) != null) {

            // We assume that the sequences are in the instance list
            //  in order.

            state = Integer.parseInt(line);
      documentStates[doc] = state;

      // Additional bookkeeping will be performed when we load sequence IDs,
      // so states MUST be loaded before sequences.

      doc++;
  }
  in.close();

  System.out.println("loaded states");
    }
    */
   

    public void loadSequenceIDsFromFile(String sequenceFilename) throws IOException {

  int doc = 0;

  int sequenceID;
  int currentSequenceID = -1;

        BufferedReader in = new BufferedReader(new FileReader(new File(sequenceFilename)));
        String line = null;
        while ((line = in.readLine()) != null) {

      // We assume that the sequences are in the instance list
      //  in order.

      String[] fields = line.split("\\t");

      sequenceID = Integer.parseInt(fields[0]);

      documentSequenceIDs[doc] = sequenceID;

      if (sequenceID != currentSequenceID) {
    numSequences ++;
      }

      currentSequenceID = sequenceID;

      doc++;
  }
  in.close();

  if (doc != numDocs) { System.out.println("Warning: number of documents with topics (" + numDocs + ") is not equal to number of docs with sequence IDs (" + doc + ")"); }

  System.out.println("loaded sequence");
    }
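
    // Note: the sequence metadata file holds one tab-separated line per
    //  document, in document order, with the sequence ID in the first
    //  column; runs of consecutive documents that share an ID form one
    //  HMM sequence. A hypothetical line: "1790\tvolume-01", of which
    //  only the leading "1790" is used.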

    private void sampleState (int doc, Randoms r, boolean initializing) {

  /*
  if (doc % 10000 == 0) {
      if (initializing) {
    System.out.println("initializing doc " + doc);
      }
      else {
    System.out.println("sampling doc " + doc);
      }
  }
  */

  // It's possible this document contains no words,
  //  in which case it has no topics, and no entry in the
  //  documentTopics hash.
  if (! documentTopics.containsKey(doc)) { return; }

        TIntIntHashMap topicCounts = documentTopics.get(doc);

  // if we are in initializing mode, this is meaningless,
  //  but it won't hurt.
  int oldState = documentStates[doc];
  int[] currentStateTopicCounts = stateTopicCounts[oldState];

  // Look at the document features (topics).
  //  If we're not in initializing mode, reduce the topic counts
  //  of the current (old) state.
 
  int docLength = 0;
 
  for (int topic: topicCounts.keys()) {
      int topicCount = topicCounts.get(topic);
      if (! initializing) {
    currentStateTopicCounts[topic] -= topicCount;
      }
      docLength += topicCount;
  }

  if (! initializing) {
      stateTopicTotals[oldState] -= docLength;
      recacheStateTopicDistribution(oldState, topicCounts);
  }


  int previousSequenceID = -1;
  if (doc > 0) {
      previousSequenceID = documentSequenceIDs[ doc-1 ];
  }

        int sequenceID = documentSequenceIDs[ doc ];

  int nextSequenceID = -1;
  if (! initializing &&
      doc < numDocs - 1) {
      nextSequenceID = documentSequenceIDs[ doc+1 ];
  }

  double[] stateLogLikelihoods = new double[numStates];
  double[] samplingDistribution = new double[numStates];

  int nextState, previousState;

  if (initializing) {
      // Initializing the states is the same as sampling them,
      //  but we only look at the previous state and we don't decrement
      //  any counts.

      if (previousSequenceID != sequenceID) {
    // New sequence, start from scratch

    for (int state = 0; state < numStates; state++) {
        stateLogLikelihoods[state] = Math.log( (initialStateCounts[state] + pi) /
                                               (numSequences - 1 + sumPi) );
    }
      }
      else {
    // Continuation
                previousState = documentStates[ doc-1 ];

                for (int state = 0; state < numStates; state++) {
                    stateLogLikelihoods[state] = Math.log( stateStateTransitions[previousState][state] + gamma );

                    if (Double.isInfinite(stateLogLikelihoods[state])) {
                        System.out.println("infinite end");
                    }
                }
      }
  }
  else {

      // There are four cases:

      if (previousSequenceID != sequenceID && sequenceID != nextSequenceID) {
    // 1. This is a singleton document
   
    initialStateCounts[oldState]--;
   
    for (int state = 0; state < numStates; state++) {
        stateLogLikelihoods[state] = Math.log( (initialStateCounts[state] + pi) /
                 (numSequences - 1 + sumPi) );
    }
      }     
      else if (previousSequenceID != sequenceID) {
    // 2. This is the beginning of a sequence
   
    initialStateCounts[oldState]--;
   
    nextState = documentStates[doc+1];
    stateStateTransitions[oldState][nextState]--;
   
    assert(stateStateTransitions[oldState][nextState] >= 0);
   
    stateTransitionTotals[oldState]--;
   
    for (int state = 0; state < numStates; state++) {
        stateLogLikelihoods[state] = Math.log( (stateStateTransitions[state][nextState] + gamma) *
                 (initialStateCounts[state] + pi) /
                 (numSequences - 1 + sumPi) );
        if (Double.isInfinite(stateLogLikelihoods[state])) {
      System.out.println("infinite beginning");
        }
       
    }
      }
      else if (sequenceID != nextSequenceID) {
    // 3. This is the end of a sequence
   
    previousState = documentStates[doc-1];
    stateStateTransitions[previousState][oldState]--;
   
    assert(stateStateTransitions[previousState][oldState] >= 0);
   
    for (int state = 0; state < numStates; state++) {
        stateLogLikelihoods[state] = Math.log( stateStateTransitions[previousState][state] + gamma );
       
        if (Double.isInfinite(stateLogLikelihoods[state])) {
      System.out.println("infinite end");
        }
    }
      }
      else {
    // 4. This is the middle of a sequence
   
    nextState = documentStates[doc+1];
    stateStateTransitions[oldState][nextState]--;
    if (stateStateTransitions[oldState][nextState] < 0) {
        System.out.println(printStateTransitions());
        System.out.println(oldState + " -> " + nextState);
       
        System.out.println(sequenceID);
    }
    assert (stateStateTransitions[oldState][nextState] >= 0);
    stateTransitionTotals[oldState]--;
   
    previousState = documentStates[doc-1];
    stateStateTransitions[previousState][oldState]--;
    assert(stateStateTransitions[previousState][oldState] >= 0);
   
    for (int state = 0; state < numStates; state++) {
       
        if (previousState == state && state == nextState) {       
      stateLogLikelihoods[state] =
          Math.log( (stateStateTransitions[previousState][state] + gamma) *
              (stateStateTransitions[state][nextState] + 1 + gamma) /
              (stateTransitionTotals[state] + 1 + gammaSum) );
     
        }
        else if (previousState == state) {
      stateLogLikelihoods[state] =
          Math.log( (stateStateTransitions[previousState][state] + gamma) *
              (stateStateTransitions[state][nextState] + gamma) /
              (stateTransitionTotals[state] + 1 + gammaSum) );
        }
        else {
      stateLogLikelihoods[state] =
          Math.log( (stateStateTransitions[previousState][state] + gamma) *
              (stateStateTransitions[state][nextState] + gamma) /
              (stateTransitionTotals[state] + gammaSum) );
        }
       
        if (Double.isInfinite(stateLogLikelihoods[state])) {
      System.out.println("infinite middle: " + doc);
      System.out.println(previousState + " -> " +
             state + " -> " + nextState);
      System.out.println(stateStateTransitions[previousState][state] + " -> " +
             stateStateTransitions[state][nextState] + " / " +
             stateTransitionTotals[state]);
     
        }
    }
   
      }
  }

  double max = Double.NEGATIVE_INFINITY;

  for (int state = 0; state < numStates; state++) {
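      // This subtraction acts as a penalty on heavily used states;
      //  note that stateTransitionTotals[state] / 10 uses integer
      //  division, so the penalty grows in whole-unit steps.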
     
      stateLogLikelihoods[state] -= stateTransitionTotals[state] / 10;
     
      currentStateTopicCounts = stateTopicCounts[state];
      double[][] currentStateLogGammaCache = topicLogGammaCache[state];

      int totalTokens = 0;
      for (int topic: topicCounts.keys()) {
    int count = topicCounts.get(topic);

    // Cached Sampling Distribution
    stateLogLikelihoods[state] += currentStateLogGammaCache[topic][count];

   
    /*
      // Hybrid version

    if (count < currentStateLogGammaCache[topic].length) {
        stateLogLikelihoods[state] += currentStateLogGammaCache[topic][count];
    }
    else {
        int i = currentStateLogGammaCache[topic].length - 1;

        stateLogLikelihoods[state] +=
      currentStateLogGammaCache[topic][ i ];

        for (; i < count; i++) {
      stateLogLikelihoods[state] +=
          Math.log(alpha[topic] + currentStateTopicCounts[topic] + i);
        }
    }
    */

    /*
    for (int j=0; j < count; j++) {
        stateLogLikelihoods[state] +=
      Math.log( (alpha[topic] + currentStateTopicCounts[topic] + j) /
          (alphaSum + stateTopicTotals[state] + totalTokens) );

        if (Double.isNaN(stateLogLikelihoods[state])) {
      System.out.println("NaN: "  + alpha[topic] + " + " +
             currentStateTopicCounts[topic] + " + " +
             j + ") /\n" +
             "(" + alphaSum + " + " +
             stateTopicTotals[state] + " + " + totalTokens);
        }
       
        totalTokens++;
    }
    */
      }
     
      // Cached Sampling Distribution
      stateLogLikelihoods[state] -= docLogGammaCache[state][ docLength ];
   
      /*
      // Hybrid version
      if (docLength < docLogGammaCache[state].length) {
    stateLogLikelihoods[state] -= docLogGammaCache[state][docLength];
      }
      else {
    int i = docLogGammaCache[state].length - 1;
   
    stateLogLikelihoods[state] -=
        docLogGammaCache[state][ i ];
   
    for (; i < docLength; i++) {
        stateLogLikelihoods[state] -=
      Math.log(alphaSum + stateTopicTotals[state] + i);
       
    }
      }
      */

      if (stateLogLikelihoods[state] > max) {
    max = stateLogLikelihoods[state];
      }

  }
 
  double sum = 0.0;
  for (int state = 0; state < numStates; state++) {
      samplingDistribution[state] = Math.exp(stateLogLikelihoods[state] - max);
      sum += samplingDistribution[state];

      if (Double.isNaN(samplingDistribution[state])) {
    System.out.println(stateLogLikelihoods[state]);
      }
      assert(! Double.isNaN(samplingDistribution[state]));
  }

  int newState = r.nextDiscrete(samplingDistribution, sum);

  documentStates[doc] = newState;

  for (int topic = 0; topic < numTopics; topic++) {
      stateTopicCounts[newState][topic] += topicCounts.get(topic);
  }
  stateTopicTotals[newState] += docLength;
  recacheStateTopicDistribution(newState, topicCounts);


  if (initializing) {
      // If we're initializing the states, don't bother
      //  looking at the next state.
     
      if (previousSequenceID != sequenceID) {
    initialStateCounts[newState]++;
      }
      else {
    previousState = documentStates[doc-1];
                stateStateTransitions[previousState][newState]++;
    stateTransitionTotals[newState]++;
      }
  }
  else {
      if (previousSequenceID != sequenceID && sequenceID != nextSequenceID) {
    // 1. This is a singleton document
   
    initialStateCounts[newState]++;
      }     
      else if (previousSequenceID != sequenceID) {
    // 2. This is the beginning of a sequence
   
    initialStateCounts[newState]++;
   
    nextState = documentStates[doc+1];
    stateStateTransitions[newState][nextState]++;
    stateTransitionTotals[newState]++;
      }
      else if (sequenceID != nextSequenceID) {
    // 3. This is the end of a sequence
   
    previousState = documentStates[doc-1];
    stateStateTransitions[previousState][newState]++;
      }
      else {
    // 4. This is the middle of a sequence
   
    previousState = documentStates[doc-1];
    stateStateTransitions[previousState][newState]++;
   
    nextState = documentStates[doc+1];
    stateStateTransitions[newState][nextState]++;
    stateTransitionTotals[newState]++;
   
      }
  }

    }
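
    // sampleState() turns log likelihoods into a sampling distribution
    //  by subtracting the maximum before exponentiating -- the standard
    //  max-shift trick that keeps Math.exp() from underflowing to zero
    //  when all log likelihoods are large and negative. The helper below
    //  is an illustrative sketch of the same idea; it is not part of the
    //  original MALLET source and is unused elsewhere in this class.
    private static double[] logWeightsToUnnormalizedWeights(double[] logWeights) {
        double max = Double.NEGATIVE_INFINITY;
        for (double logWeight : logWeights) {
            if (logWeight > max) { max = logWeight; }
        }
        double[] weights = new double[logWeights.length];
        for (int i = 0; i < logWeights.length; i++) {
            // exp(logWeight - max) lies in (0, 1]; the largest weight is 1.
            weights[i] = Math.exp(logWeights[i] - max);
        }
        // Randoms.nextDiscrete(weights, sum) accepts unnormalized weights
        //  together with their sum, as sampleState() uses them above.
        return weights;
    }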

    public String printStateTransitions() {
  StringBuffer out = new StringBuffer();

  IDSorter[] sortedTopics = new IDSorter[numTopics];

  for (int s = 0; s < numStates; s++) {
     
      for (int topic=0; topic<numTopics; topic++) {
    sortedTopics[topic] = new IDSorter(topic, (double) stateTopicCounts[s][topic] / stateTopicTotals[s]);
      }
      Arrays.sort(sortedTopics);
     
      out.append("\n" + s + "\n");

      for (int i=0; i<4; i++) {
    int topic = sortedTopics[i].getID();
    out.append(stateTopicCounts[s][topic] + "\t" + topicKeys[topic] + "\n");
      }
     
      out.append("\n");

      out.append("[" + initialStateCounts[s] + "/" + numSequences + "] ");

      out.append("[" + stateTransitionTotals[s] + "]");
      for (int t = 0; t < numStates; t++) {
    out.append("\t");
    if (s == t) {
        out.append("[" + stateStateTransitions[s][t] + "]");
    }
    else {
        out.append(stateStateTransitions[s][t]);
    }
      }
      out.append("\n");
  }

  return out.toString();
    }

    public String stateTransitionMatrix() {
  StringBuffer out = new StringBuffer();

  for (int s = 0; s < numStates; s++) {
      for (int t = 0; t < numStates; t++) {
    out.append(stateStateTransitions[s][t]);
    out.append("\t");
      }
      out.append("\n");
  }

  return out.toString();
    }

    public String stateTopics() {
  StringBuffer out = new StringBuffer();

  for (int s = 0; s < numStates; s++) {
      for (int topic=0; topic<numTopics; topic++) {
    out.append(stateTopicCounts[s][topic] + "\t");
      }
      out.append("\n");
  }

  return out.toString();
    }

    public static void main (String[] args) throws IOException {

  if (args.length != 4) {
      System.err.println("Usage: MultinomialHMM [num topics] [lda state file] [lda keys file] [sequence metadata file]");
      System.exit(1);
  }

  int numTopics = Integer.parseInt(args[0]);

  MultinomialHMM hmm =
      new MultinomialHMM (numTopics, args[1], 150);

  hmm.setGamma(1.0);
  hmm.setRandomSeed(1);

  hmm.loadAlphaFromFile(args[2]);
  hmm.loadSequenceIDsFromFile(args[3]);

  hmm.initialize();

  hmm.sample();
    }
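
    // Example invocation (file names are hypothetical; the state and keys
    //  files correspond to the --output-state and --output-topic-keys
    //  output of a prior MALLET topic-model run, and the number of hidden
    //  states is hard-coded to 150 above):
    //
    //    java cc.mallet.topics.MultinomialHMM 100 topic-state.gz topic-keys.txt sequence-ids.txt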
   
}