Package cc.mallet.topics

Source Code of cc.mallet.topics.LDA (which defines the inner class cc.mallet.topics.WordProb)

/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.topics;

import java.util.Arrays;
import java.io.*;

import cc.mallet.types.*;
import cc.mallet.util.ArrayUtils;
import cc.mallet.util.Randoms;

/**
* Latent Dirichlet Allocation.
* @author Andrew McCallum
* @deprecated Use ParallelTopicModel instead.
*/

// Think about support for incrementally adding more documents...
// (I think this means we might want to use FeatureSequence directly).
// We will also need to support a growing vocabulary!

public class LDA implements Serializable {

  int numTopics; // Number of topics to be fit
  double alpha;  // Dirichlet(alpha,alpha,...) is the distribution over topics
  double beta;   // Prior on per-topic multinomial distribution over words
  double tAlpha;
  double vBeta;
  InstanceList ilist;  // the data field of the instances is expected to hold a FeatureSequence
  int[][] topics; // indexed by <document index, sequence index>
  int numTypes;
  int numTokens;
  int[][] docTopicCounts; // indexed by <document index, topic index>
  int[][] typeTopicCounts; // indexed by <feature index, topic index>
  int[] tokensPerTopic; // indexed by <topic index>

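  /** Uses default hyperparameters: alphaSum = 50.0 (a symmetric Dirichlet,
      so the per-topic alpha is 50.0/numberOfTopics) and beta = 0.01. */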
  public LDA (int numberOfTopics)
  {
    this (numberOfTopics, 50.0, 0.01);
  }

  public LDA (int numberOfTopics, double alphaSum, double beta)
  {
    this.numTopics = numberOfTopics;
    this.alpha = alphaSum / numTopics;
    this.beta = beta;
  }

  public void estimate (InstanceList documents, int numIterations, int showTopicsInterval,
                        int outputModelInterval, String outputModelFilename,
                        Randoms r)
  {
    ilist = documents.shallowClone();
    numTypes = ilist.getDataAlphabet().size ();
    int numDocs = ilist.size();
    topics = new int[numDocs][];
    docTopicCounts = new int[numDocs][numTopics];
    typeTopicCounts = new int[numTypes][numTopics];
    tokensPerTopic = new int[numTopics];
    tAlpha = alpha * numTopics;
    vBeta = beta * numTypes;

    long startTime = System.currentTimeMillis();

    // Initialize with random assignments of tokens to topics
    // and finish allocating this.topics and this.tokens
    int topic, seqLen;
    FeatureSequence fs;
    for (int di = 0; di < numDocs; di++) {
      try {
        fs = (FeatureSequence) ilist.get(di).getData();
      } catch (ClassCastException e) {
        System.err.println ("LDA and other topic models expect FeatureSequence data, not FeatureVector data.  "
                            +"With text2vectors, you can obtain such data with --keep-sequence or --keep-bisequence.");
        throw e;
      }
      seqLen = fs.getLength();
      numTokens += seqLen;
      topics[di] = new int[seqLen];
      // Randomly assign tokens to topics
      for (int si = 0; si < seqLen; si++) {
        topic = r.nextInt(numTopics);
        topics[di][si] = topic;
        docTopicCounts[di][topic]++;
        typeTopicCounts[fs.getIndexAtPosition(si)][topic]++;
        tokensPerTopic[topic]++;
      }
    }
   
    this.estimate(0, numDocs, numIterations, showTopicsInterval, outputModelInterval, outputModelFilename, r);
    // 124.5 seconds
    // 144.8 seconds after using FeatureSequence instead of tokens[][] array
    // 121.6 seconds after putting "final" on FeatureSequence.getIndexAtPosition()
    // 106.3 seconds after avoiding array lookup in inner loop with a temporary variable

  }
 
  public void addDocuments(InstanceList additionalDocuments,
                           int numIterations, int showTopicsInterval,
                           int outputModelInterval, String outputModelFilename,
                           Randoms r)
  {
    if (ilist == null) throw new IllegalStateException ("Must already have some documents first.");
    for (Instance inst : additionalDocuments)
      ilist.add(inst);
    assert (ilist.getDataAlphabet() == additionalDocuments.getDataAlphabet());
    assert (additionalDocuments.getDataAlphabet().size() >= numTypes);
    numTypes = additionalDocuments.getDataAlphabet().size();
    int numNewDocs = additionalDocuments.size();
    int numOldDocs = topics.length;
    int numDocs = numOldDocs + numNewDocs;
    // Expand various arrays to make space for the new data.
    int[][] newTopics = new int[numDocs][];
    for (int i = 0; i < topics.length; i++)
      newTopics[i] = topics[i];

    topics = newTopics; // The rest of this array will be initialized below.
    int[][] newDocTopicCounts = new int[numDocs][numTopics];
    for (int i = 0; i < docTopicCounts.length; i++)
      newDocTopicCounts[i] = docTopicCounts[i];
    docTopicCounts = newDocTopicCounts; // The rest of this array will be initialized below.
    int[][] newTypeTopicCounts = new int[numTypes][numTopics];
    for (int i = 0; i < typeTopicCounts.length; i++)
      for (int j = 0; j < numTopics; j++)
        newTypeTopicCounts[i][j] = typeTopicCounts[i][j]; // This array further populated below
    typeTopicCounts = newTypeTopicCounts; // Adopt the expanded array so new word types can be counted.
   
    FeatureSequence fs;
    for (int di = numOldDocs; di < numDocs; di++) {
      try {
        fs = (FeatureSequence) additionalDocuments.get(di-numOldDocs).getData();
      } catch (ClassCastException e) {
        System.err.println ("LDA and other topic models expect FeatureSequence data, not FeatureVector data.  "
                            +"With text2vectors, you can obtain such data with --keep-sequence or --keep-bisequence.");
        throw e;
      }
      int seqLen = fs.getLength();
      numTokens += seqLen;
      topics[di] = new int[seqLen];
      // Randomly assign tokens to topics
      for (int si = 0; si < seqLen; si++) {
        int topic = r.nextInt(numTopics);
        topics[di][si] = topic;
        docTopicCounts[di][topic]++;
        typeTopicCounts[fs.getIndexAtPosition(si)][topic]++;
        tokensPerTopic[topic]++;
      }
    }
  }
 
  /* Perform several rounds of Gibbs sampling on the documents in the given range. */
  public void estimate (int docIndexStart, int docIndexLength,
                        int numIterations, int showTopicsInterval,
                        int outputModelInterval, String outputModelFilename,
                        Randoms r)
  {
    long startTime = System.currentTimeMillis();
    for (int iterations = 0; iterations < numIterations; iterations++) {
      if (iterations % 10 == 0) System.out.print (iterations);
      else System.out.print (".");
      System.out.flush();
      if (showTopicsInterval != 0 && iterations % showTopicsInterval == 0 && iterations > 0) {
        System.out.println ();
        printTopWords (5, false);
      }
      if (outputModelInterval != 0 && iterations % outputModelInterval == 0 && iterations > 0) {
        this.write (new File(outputModelFilename+'.'+iterations));
      }
      sampleTopicsForDocs(docIndexStart, docIndexLength, r);
    }

    long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
    long minutes = seconds / 60;  seconds %= 60;
    long hours = minutes / 60;  minutes %= 60;
    long days = hours / 24;  hours %= 24;
    System.out.print ("\nTotal time: ");
    if (days != 0) { System.out.print(days); System.out.print(" days "); }
    if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
    if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
    System.out.print(seconds); System.out.println(" seconds");
  }

  /* One iteration of Gibbs sampling, across all documents. */
  public void sampleTopicsForAllDocs (Randoms r)
  {
    double[] topicWeights = new double[numTopics];
    // Loop over every document in the corpus
    for (int di = 0; di < topics.length; di++) {
      sampleTopicsForOneDoc ((FeatureSequence)ilist.get(di).getData(),
                             topics[di], docTopicCounts[di], topicWeights, r);
    }
  }

  /* One iteration of Gibbs sampling, across the documents in the given range. */
  public void sampleTopicsForDocs (int start, int length, Randoms r)
  {
    assert (start+length <= docTopicCounts.length);
    double[] topicWeights = new double[numTopics];
    // Loop over every document in the given range
    for (int di = start; di < start+length; di++) {
      sampleTopicsForOneDoc ((FeatureSequence)ilist.get(di).getData(),
                             topics[di], docTopicCounts[di], topicWeights, r);
    }
  }

  /*
  public double[] assignTopics (int[] testTokens, Random r)
  {
    int[] testTopics = new int[testTokens.length];
    int[] testTopicCounts = new int[numTopics];
    int numTokens = MatrixOps.sum(testTokens);
    double[] topicWeights = new double[numTopics];
    // Randomly assign topics to the words and
    // incorporate this document in the global counts
    int topic;
    for (int si = 0; si < testTokens.length; si++) {
      topic = r.nextInt (numTopics);
      testTopics[si] = topic; // analogous to this.topics
      testTopicCounts[topic]++; // analogous to this.docTopicCounts
      typeTopicCounts[testTokens[si]][topic]++;
      tokensPerTopic[topic]++;
    }
    // Repeatedly sample topic assignments for the words in this document
    for (int iterations = 0; iterations < numTokens*2; iterations++)
      sampleTopicsForOneDoc (testTokens, testTopics, testTopicCounts, topicWeights, r);
    // Remove this document from the global counts
    // and also fill topicWeights with an unnormalized distribution over topics for whole doc
    Arrays.fill (topicWeights, 0.0);
    for (int si = 0; si < testTokens.length; si++) {
      topic = testTopics[si];
      typeTopicCounts[testTokens[si]][topic]--;
      tokensPerTopic[topic]--;
      topicWeights[topic]++;
    }
    // Normalize the distribution over topics for whole doc
    for (int ti = 0; ti < numTopics; ti++)
      topicWeights[ti] /= testTokens.length;
    return topicWeights;
  }
*/

  private void sampleTopicsForOneDoc (FeatureSequence oneDocTokens, int[] oneDocTopics, // indexed by seq position
                                      int[] oneDocTopicCounts, // indexed by topic index
                                      double[] topicWeights, Randoms r)
  {
    int[] currentTypeTopicCounts;
    int type, oldTopic, newTopic;
    double topicWeightsSum;
    int docLen = oneDocTokens.getLength();
    double tw;
    // Iterate over the positions (words) in the document
    for (int si = 0; si < docLen; si++) {
      type = oneDocTokens.getIndexAtPosition(si);
      oldTopic = oneDocTopics[si];
      // Remove this token from all counts
      oneDocTopicCounts[oldTopic]--;
      typeTopicCounts[type][oldTopic]--;
      tokensPerTopic[oldTopic]--;
      // Build a distribution over topics for this token
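      // Collapsed Gibbs update: with this token's counts removed, sample from
      //   P(z_i = t | z_-i, w_i = w)  proportional to
      //   (typeTopicCounts[w][t] + beta) / (tokensPerTopic[t] + vBeta) * (oneDocTopicCounts[t] + alpha)
      // The per-document denominator (docLen - 1 + tAlpha) is constant across topics and is omitted.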
      Arrays.fill (topicWeights, 0.0);
      topicWeightsSum = 0;
      currentTypeTopicCounts = typeTopicCounts[type];
      for (int ti = 0; ti < numTopics; ti++) {
        tw = ((currentTypeTopicCounts[ti] + beta) / (tokensPerTopic[ti] + vBeta))
              * ((oneDocTopicCounts[ti] + alpha)); // (/docLen-1+tAlpha); is constant across all topics
        topicWeightsSum += tw;
        topicWeights[ti] = tw;
      }
      // Sample a topic assignment from this distribution
      newTopic = r.nextDiscrete (topicWeights, topicWeightsSum);

      // Put that new topic into the counts
      oneDocTopics[si] = newTopic;
      oneDocTopicCounts[newTopic]++;
      typeTopicCounts[type][newTopic]++;
      tokensPerTopic[newTopic]++;
    }
  }
 
  public int[][] getDocTopicCounts(){
    return docTopicCounts;
  }
 
  public int[][] getTypeTopicCounts(){
    return typeTopicCounts;
  }

  public int[] getTokensPerTopic(){
    return tokensPerTopic;
  }

  public void printTopWords (int numWords, boolean useNewLines)
  {
    class WordProb implements Comparable {
      int wi;
      double p;
      public WordProb (int wi, double p) { this.wi = wi; this.p = p; }
      public final int compareTo (Object o2) {
        if (p > ((WordProb)o2).p)
          return -1;
        else if (p == ((WordProb)o2).p)
          return 0;
        else return 1;
      }
    }

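    // For each topic t, rank word types by the unsmoothed estimate
    // p(w|t) = typeTopicCounts[w][t] / tokensPerTopic[t].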
    WordProb[] wp = new WordProb[numTypes];
    for (int ti = 0; ti < numTopics; ti++) {
      for (int wi = 0; wi < numTypes; wi++)
        wp[wi] = new WordProb (wi, ((double)typeTopicCounts[wi][ti]) / tokensPerTopic[ti]);
      Arrays.sort (wp);
      if (useNewLines) {
        System.out.println ("\nTopic "+ti);
        for (int i = 0; i < numWords; i++)
          System.out.println (ilist.getDataAlphabet().lookupObject(wp[i].wi).toString() + " " + wp[i].p);
      } else {
        System.out.print ("Topic "+ti+": ");
        for (int i = 0; i < numWords; i++)
          System.out.print (ilist.getDataAlphabet().lookupObject(wp[i].wi).toString() + " ");
        System.out.println();
      }
    }
  }

  public void printDocumentTopics (File f) throws IOException
  {
    printDocumentTopics (new PrintWriter (new FileWriter (f)));
  }

  public void printDocumentTopics (PrintWriter pw) {
    printDocumentTopics (pw, 0.0, -1);
  }

  public void printDocumentTopics (PrintWriter pw, double threshold, int max)
  {
    pw.println ("#doc source topic proportion ...");
    int docLen;
    double[] topicDist = new double[numTopics];
    for (int di = 0; di < topics.length; di++) {
      pw.print (di); pw.print (' ');
      if (ilist.get(di).getSource() != null){
        pw.print (ilist.get(di).getSource().toString());
      }
      else {
        pw.print("null-source");
      }
      pw.print (' ');
      docLen = topics[di].length;
      for (int ti = 0; ti < numTopics; ti++)
        topicDist[ti] = (((float)docTopicCounts[di][ti])/docLen);
      if (max < 0) max = numTopics;
      for (int tp = 0; tp < max; tp++) {
        double maxvalue = 0;
        int maxindex = -1;
        for (int ti = 0; ti < numTopics; ti++)
          if (topicDist[ti] > maxvalue) {
            maxvalue = topicDist[ti];
            maxindex = ti;
          }
        if (maxindex == -1 || topicDist[maxindex] < threshold)
          break;
        pw.print (maxindex+" "+topicDist[maxindex]+" ");
        topicDist[maxindex] = 0;
      }
      pw.println (' ');
    }
  }



  public void printState (File f) throws IOException
  {
    PrintWriter writer = new PrintWriter (new FileWriter(f));
    printState (writer);
    writer.close();
  }


  public void printState (PrintWriter pw)
  {
    Alphabet a = ilist.getDataAlphabet();
    pw.println ("#doc pos typeindex type topic");
    for (int di = 0; di < topics.length; di++) {
      FeatureSequence fs = (FeatureSequence) ilist.get(di).getData();
      for (int si = 0; si < topics[di].length; si++) {
        int type = fs.getIndexAtPosition(si);
        pw.print(di); pw.print(' ');
        pw.print(si); pw.print(' ');
        pw.print(type); pw.print(' ');
        pw.print(a.lookupObject(type)); pw.print(' ');
        pw.print(topics[di][si]); pw.println();
      }
    }
  }

  public void write (File f) {
    try {
      ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream(f));
      oos.writeObject(this);
      oos.close();
    }
    catch (IOException e) {
      System.err.println("Exception writing file " + f + ": " + e);
    }
  }


  // Serialization

  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 0;
  private static final int NULL_INTEGER = -1;

  private void writeObject (ObjectOutputStream out) throws IOException {
    out.writeInt (CURRENT_SERIAL_VERSION);
    out.writeObject (ilist);
    out.writeInt (numTopics);
    out.writeDouble (alpha);
    out.writeDouble (beta);
    out.writeDouble (tAlpha);
    out.writeDouble (vBeta);
    for (int di = 0; di < topics.length; di ++)
      for (int si = 0; si < topics[di].length; si++)
        out.writeInt (topics[di][si]);
    for (int di = 0; di < topics.length; di ++)
      for (int ti = 0; ti < numTopics; ti++)
        out.writeInt (docTopicCounts[di][ti]);
    for (int fi = 0; fi < numTypes; fi++)
      for (int ti = 0; ti < numTopics; ti++)
        out.writeInt (typeTopicCounts[fi][ti]);
    for (int ti = 0; ti < numTopics; ti++)
      out.writeInt (tokensPerTopic[ti]);
  }

  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
    int version = in.readInt (); // serial format version (currently unused)
    ilist = (InstanceList) in.readObject ();
    numTopics = in.readInt();
    alpha = in.readDouble();
    beta = in.readDouble();
    tAlpha = in.readDouble();
    vBeta = in.readDouble();
    int numDocs = ilist.size();
    topics = new int[numDocs][];
    for (int di = 0; di < ilist.size(); di++) {
      int docLen = ((FeatureSequence)ilist.get(di).getData()).getLength();
      topics[di] = new int[docLen];
      for (int si = 0; si < docLen; si++)
        topics[di][si] = in.readInt();
    }
    docTopicCounts = new int[numDocs][numTopics];
    for (int di = 0; di < ilist.size(); di++)
      for (int ti = 0; ti < numTopics; ti++)
        docTopicCounts[di][ti] = in.readInt();
    numTypes = ilist.getDataAlphabet().size();
    typeTopicCounts = new int[numTypes][numTopics];
    for (int fi = 0; fi < numTypes; fi++)
      for (int ti = 0; ti < numTopics; ti++)
        typeTopicCounts[fi][ti] = in.readInt();
    tokensPerTopic = new int[numTopics];
    for (int ti = 0; ti < numTopics; ti++)
      tokensPerTopic[ti] = in.readInt();
  }

  public InstanceList getInstanceList ()
  {
    return ilist;
  }

  // Recommended to use mallet/bin/vectors2topics instead.
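  // Usage: java cc.mallet.topics.LDA <instance-list file> [numIterations] [numTopWords]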
  public static void main (String[] args) throws IOException
  {
    InstanceList ilist = InstanceList.load (new File(args[0]));
    int numIterations = args.length > 1 ? Integer.parseInt(args[1]) : 1000;
    int numTopWords = args.length > 2 ? Integer.parseInt(args[2]) : 20;
    System.out.println ("Data loaded.");
    LDA lda = new LDA (10);
    lda.estimate (ilist, numIterations, 50, 0, null, new Randoms());  // should be 1100
    lda.printTopWords (numTopWords, true);
    lda.printDocumentTopics (new File(args[0]+".lda"));
  }

}
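
Since this class is deprecated in favor of ParallelTopicModel, here is a minimal
sketch of the equivalent workflow using the replacement class. The method names
below are as found in recent MALLET releases, and the class name LdaExample is
just illustrative; verify the API against your MALLET version.

import java.io.File;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.InstanceList;

public class LdaExample {
  public static void main (String[] args) throws Exception {
    // Instances must carry FeatureSequence data (import with --keep-sequence).
    InstanceList instances = InstanceList.load (new File(args[0]));
    // 10 topics, alphaSum = 50.0, beta = 0.01 -- the same defaults as LDA above.
    ParallelTopicModel model = new ParallelTopicModel (10, 50.0, 0.01);
    model.addInstances (instances);
    model.setNumThreads (2);        // Gibbs sampling distributed over worker threads
    model.setNumIterations (1000);
    model.estimate ();
    // Analogous to LDA.printTopWords above.
    model.printTopWords (System.out, 20, false);
  }
}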