Package uk.ac.cam.ha293.tweetlabel.topics

Source Code of uk.ac.cam.ha293.tweetlabel.topics.LDATopicModel

/*
* Based on the algorithm and formulae presented in "Parameter Estimation for Text Analysis, Gregor Heinrich 2008"
*/

package uk.ac.cam.ha293.tweetlabel.topics;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import uk.ac.cam.ha293.tweetlabel.types.CategoryScore;
import uk.ac.cam.ha293.tweetlabel.types.Corpus;
import uk.ac.cam.ha293.tweetlabel.types.Document;
import uk.ac.cam.ha293.tweetlabel.types.WordScore;
import uk.ac.cam.ha293.tweetlabel.util.Tools;

import cc.mallet.topics.ParallelTopicModel;

public class LDATopicModel implements Serializable {

  private static final long serialVersionUID = 7094157759480964151L;
 
  private Corpus corpus;
  private boolean hasRun;
  private int[][] documents;
  private int numTopics;
  private int numIterations;
  private int numSamplingIterations;
  private int numTotalIterations;
  private int samplingLag;
  private Map<String,Integer> wordIDs;
  private ArrayList<String> idLookup;
  private ArrayList<Document> docIDLookup;
  private int numWords;
  private int numDocs;
  private int numStats;
  private double alpha;
  private double beta;
  private int[][] wordTopicAssignments;
  private int[][] docTopicAssignments;
  private int[] numWordsAssignedToTopic;
  private int[] numWordsInDocument;
  private double[][] thetaSum;
  private double[][] phiSum;
  private boolean phiSumNormalised;
  private int[][] topicAssignments; //z
  private boolean printEachIteration;
 
  public LDATopicModel(Corpus corpus, int numTopics, int numIterations, int numSamplingIterations, int samplingLag, double alpha, double beta) {
    this.corpus = corpus;
    this.numTopics = numTopics;
    this.numIterations = numIterations;
    this.numSamplingIterations = numSamplingIterations;
    this.samplingLag = samplingLag;
    this.alpha = alpha;
    this.beta = beta;
    printEachIteration = false;
    numTotalIterations = numIterations;
    if(samplingLag > 0) numTotalIterations += numSamplingIterations * samplingLag;
    else numTotalIterations += numSamplingIterations;
    init();
  }
 
  public void printEachIteration() {
    printEachIteration = true;
  }
 
  //Here, we could remove the least used words instead of giving them an ID - lengths would need consideration
  public void init() {
    hasRun = false;
    Set<Document> documentSet = corpus.getDocuments();
    numDocs = documentSet.size()
    documents = new int[numDocs][];
    numWordsInDocument = new int[numDocs];
    wordIDs = new HashMap<String,Integer>();
    idLookup = new ArrayList<String>();
    docIDLookup = new ArrayList<Document>();
    int docID = 0;
    for(Document document : documentSet) {
      docIDLookup.add(document);
      String[] tokens = document.getDocumentString().split("\\s+");
      documents[docID] = new int[tokens.length];
      numWordsInDocument[docID] = tokens.length;
      for(int i=0; i<documents[docID].length; i++) {
        //Add the token's ID to the documents array
        int wordID;
        if(wordIDs.containsKey(tokens[i])) {
          wordID = wordIDs.get(tokens[i]);
        } else {
          wordID = wordIDs.keySet().size();
          wordIDs.put(tokens[i], wordID);
          idLookup.add(tokens[i]);
        }
        documents[docID][i] = wordID;
      }
      docID++;
    }
    numWords = wordIDs.keySet().size();

    //Now for random initial topic assignment
    wordTopicAssignments = new int[numWords][numTopics];
    docTopicAssignments = new int[numDocs][numTopics];
    numWordsAssignedToTopic = new int[numTopics];
    topicAssignments = new int[numDocs][];
    for(int m=0; m<documents.length; m++) {   
      topicAssignments[m] = new int[documents[m].length];
      for(int n=0; n<documents[m].length; n++) {
        //Generate a random topic and update arrays
        int topicID = (int)(Math.random()*numTopics);
        topicAssignments[m][n] = topicID;
        wordTopicAssignments[documents[m][n]][topicID]++;
        docTopicAssignments[m][topicID]++;
        numWordsAssignedToTopic[topicID]++; 
      }
    }
   
    //Initialise topic and word distributions for later   
    thetaSum = new double[numDocs][numTopics];
    phiSum = new double[numTopics][numWords];
    phiSumNormalised = false;
    numStats = 0;
  }
 
  public void runGibbsSampling() { 
    hasRun = true;
    System.out.println("Starting Gibbs sampling");
    System.out.println("Documents: "+numDocs+" docs");
    System.out.println("Words: "+numWords+" unique words");
    System.out.println("Topics: "+numTopics+" topics");
    System.out.println("Iterations: "+numTotalIterations);
   
    for(int i=0; i<numTotalIterations; i++) {
      //if(i % 50 == 0)
      System.out.println("Starting iteration "+i);
      for(int m=0; m<topicAssignments.length; m++) { //m is document index
        for(int n=0; n<topicAssignments[m].length; n++) { //n is word index
          //Get an updated topic sample for this word
          topicAssignments[m][n] = sampleTopic(m, n);
        }
      }
         
      if(i >= numIterations && (samplingLag == 0 || i % samplingLag == 0)) {
        updatePhiThetaSums();
       
        if(printEachIteration) {
          if(i % 20 == 0iterationPrintTopicsNew(5);
        }
       
        //try{System.in.read();}catch(IOException e){}
      }
    }
  }
 
  //Sample a new topic from the multinomial topic distribution
  private int sampleTopic(int m, int n) {
    //Removes this iteration's topicAssignment from the counting variables
    int topicID = topicAssignments[m][n]; //Get the currently assigned topic
    wordTopicAssignments[documents[m][n]][topicID]--; //Decrement the topic count for the current word
    docTopicAssignments[m][topicID]--; //Decrement the topic count for the current document
    numWordsAssignedToTopic[topicID]--; //Decrement the number of words the current topic has
    numWordsInDocument[m]--;
   
    //Cumulative multinomial sampling
    double[] p = new double[numTopics];
    for(int k=0; k<numTopics; k++) {
      p[k] = (wordTopicAssignments[documents[m][n]][k] + beta) / (numWordsAssignedToTopic[k] + numWords * beta) * (docTopicAssignments[m][k] + alpha) / (numWordsInDocument[m] + numTopics * alpha);
    }
   
    //Cumulative part - could be combined into above part cleverly
    for(int k=1; k<numTopics; k++) {
      p[k] += p[k-1];
    }
   
    //Sampling part - scaled because we haven't normalised
    double topicThreshold = Math.random() * p[numTopics-1];
    int sampledTopicID = 0;
    for(sampledTopicID=0; sampledTopicID<numTopics; sampledTopicID++) {
      if(topicThreshold < p[sampledTopicID]) {
        break;
      }
    }
   
    //Maybe a fix needed by faulty double arithmetic
    if(sampledTopicID >= numTopics) {
      sampledTopicID = numTopics-1;
    }
   
    //Finally, increment the relevant count variables
    wordTopicAssignments[documents[m][n]][sampledTopicID]++;
    docTopicAssignments[m][sampledTopicID]++;
    numWordsAssignedToTopic[sampledTopicID]++;
    numWordsInDocument[m]++;
   
    return sampledTopicID;
  }
 
  //By eqs 82 and 83 of paper mentioned in topmost comment
  private void updatePhiThetaSums() {
   
    for(int docID=0; docID<numDocs; docID++) {
      for(int topicID=0; topicID<numTopics; topicID++) {
        thetaSum[docID][topicID] += (docTopicAssignments[docID][topicID] + alpha) / (numWordsInDocument[docID] + numTopics * alpha);
      }
    }
   
    for(int topicID=0; topicID<numTopics; topicID++) {
      for(int wordID=0; wordID<numWords; wordID++) {
        phiSum[topicID][wordID] += (wordTopicAssignments[wordID][topicID] + beta) / (numWordsAssignedToTopic[topicID] + numWords * beta) ;
      }
    }
    numStats++; //Used if we want to take the mean of many stats
  }
 
  private double[][] getTheta() {
    double[][] theta = new double[numDocs][numTopics];
    for(int docID=0; docID<numDocs; docID++) {
      for(int topicID=0; topicID<numTopics; topicID++) {
        if(numStats > 0) theta[docID][topicID] = thetaSum[docID][topicID] / numStats;
        else theta[docID][topicID] = (docTopicAssignments[docID][topicID] + alpha) / (numWordsInDocument[docID] + numTopics * alpha);
      }
    }
    return theta;
  }
 
  private List<Map<Integer,Double>> getThetaMap() {
    List<Map<Integer,Double>> docList = new LinkedList<Map<Integer,Double>>();
    for(int docID=0; docID<numDocs; docID++) {
      Map<Integer,Double> topicMap = new HashMap<Integer,Double>();
      for(int topicID=0; topicID<numTopics; topicID++) {
        topicMap.put(topicID, thetaSum[docID][topicID] / numStats)
      }
      docList.add(Tools.sortMapByValue(topicMap));
    }
    return docList;
  }
 
  private double[][] getPhi() {
    double[][] phi = new double[numTopics][numWords];
    for(int topicID=0; topicID<numTopics; topicID++) {
      for(int wordID=0; wordID<numWords; wordID++) {
        if(numStats > 0) phi[topicID][wordID] = phiSum[topicID][wordID] / numStats;
        else phi[topicID][wordID] += (wordTopicAssignments[wordID][topicID] + beta) / (numWordsAssignedToTopic[topicID] + numWords * beta) ;
      }
    }
    return phi;
  }
 
  private List<Map<Integer,Double>> getPhiMap() {
    List<Map<Integer,Double>> topicList = new LinkedList<Map<Integer,Double>>();
    for(int topicID=0; topicID<numTopics; topicID++) {
      Map<Integer,Double> wordMap = new HashMap<Integer,Double>();
      for(int wordID=0; wordID<numWords; wordID++) {
        wordMap.put(wordID, phiSum[topicID][wordID] / numStats)
      }
      topicList.add(Tools.sortMapByValue(wordMap));
    }
    return topicList;
  }
 
  private double[][] normalisePhiSum() {
    if(phiSumNormalised) return phiSum;
    phiSumNormalised = true;
    for(int topicID=0; topicID<numTopics; topicID++) {
      for(int wordID=0; wordID<numWords; wordID++) {
        phiSum[topicID][wordID] /= numStats;
      }
    }
    return phiSum;
  }
 
  public List<List<WordScore>> getTopics() {
    if(!hasRun) {
      System.err.println("Gibbs sampler has not yet run");
      return null;
    }
   
    List<List<WordScore>> topics = new LinkedList<List<WordScore>>();
    double[][] phi = getPhi();
    for(int topicID=0; topicID<numTopics; topicID++) {
      double[] wordProbs = phi[topicID];
      List<WordScore> wordScores = new LinkedList<WordScore>();
      for(int wordID=0; wordID<numWords; wordID++) {
        wordScores.add(new WordScore(idLookup.get(wordID), wordProbs[wordID]));
      }
      Collections.sort(wordScores);
      Collections.reverse(wordScores);
      topics.add(wordScores);
    }
   
    return topics;
  }
   
  public List<Map<Integer,Double>> getTopicsNew() {
    if(!hasRun) {
      System.err.println("Gibbs sampler has not yet run");
      return null;
    }
   
    List<Map<Integer,Double>> topics = new ArrayList<Map<Integer,Double>>();
    double[][] phi = normalisePhiSum();
    for(int topicID=0; topicID<numTopics; topicID++) {
      double[] wordProbs = phi[topicID];
      Map<Integer,Double> wordScores = new HashMap<Integer,Double>();
      for(int wordID=0; wordID<numWords; wordID++) {
        wordScores.put(wordID, wordProbs[wordID]);
      }
      Tools.sortMapByValueDesc(wordScores);
      topics.add(wordScores);
    }
    return topics;
  }
 
  public double[][] getTopicsUnsorted() {
    if(!hasRun) {
      System.err.println("Gibbs sampler has not yet run");
      return null;
    }
   
    double[][] phi = normalisePhiSum();
    return phi; //Hmmm...
  }
 
  public List<List<WordScore>> getDocuments() {
    if(!hasRun) {
      System.err.println("Gibbs sampler has not yet run");
      return null;
    }
   
    List<List<WordScore>> topics = new LinkedList<List<WordScore>>();
    double[][] theta = getTheta();
    for(int docID=0; docID<numDocs; docID++) {
      double[] topicProbs = theta[docID];
      List<WordScore> wordScores = new LinkedList<WordScore>();
      for(int topicID=0; topicID<numTopics; topicID++) {
        wordScores.add(new WordScore("Topic"+topicID, topicProbs[topicID]));
      }
      Collections.sort(wordScores);
      Collections.reverse(wordScores);
      topics.add(wordScores);
    }
   
    return topics;
  }
 
  public void printTopics(int topWords) {
    if(!hasRun) {
      System.err.println("Gibbs sampler has not yet run");
      return;
    }
   
    List<List<WordScore>> topics = getTopics();
    for(int k=0; k<numTopics; k++) {
      System.out.print("Topic "+k+":");
      List<WordScore> words = topics.get(k);
      for(int n=0; n<topWords; n++) {
        if(n == numWords) break; //incase topWords is huge or Corpus is tiny...
        //System.out.println(words.get(n).getWord()+" = "+words.get(n).getScore());
        System.out.print(words.get(n).getWord()+" ");
      }
      System.out.println();
    }
  }
 
  //Note - this is destructive as it calls normalisePhiSum
  public void printTopicsNew(int topWords) {
    if(!hasRun) {
      System.err.println("Gibbs sampler has not yet run");
      return;
    }
   
    double[][] phi;
    if(phiSumNormalised) phi = phiSum;
    else phi = normalisePhiSum();
    for(int topicID=0; topicID<numTopics; topicID++) {
      double[] wordProbs = phi[topicID];
     
      List<WordScore> wordScores = new LinkedList<WordScore>();
      for(int wordID=0; wordID<numWords; wordID++) {
        wordScores.add(new WordScore(idLookup.get(wordID), wordProbs[wordID]));
      }
      Collections.sort(wordScores);
      Collections.reverse(wordScores);
     
      System.out.print("Topic "+topicID+": ");
      int count=0;
      for(WordScore score : wordScores) {
        if(count==topWords) break;
        System.out.print(score.getWord()+" ");
        count++;
      }
      System.out.println();
    }
  }
 
  public boolean isPhiSumNormalised() {
    return phiSumNormalised;
  }
 
  private void iterationPrintTopics(int topWords) {
    List<List<WordScore>> topics = getTopics();
    for(int k=0; k<numTopics; k++) {
      System.out.print("Topic "+k+": ");
      List<WordScore> words = topics.get(k);
      for(int n=0; n<topWords; n++) {
        if(n == numWords) break; //incase topWords is huge or Corpus is tiny...
        System.out.print(words.get(n).getWord()+" ");
      }
      System.out.println();
    }
  }
 
  private void iterationPrintTopicsNew(int topWords) {
    List<Map<Integer,Double>> phiMap = getPhiMap();
    int topicID = 0;
    for(Map<Integer,Double> words : phiMap) {
      System.out.print("Topic "+topicID+": ");
      int count = 0;
      for(int wordID : words.keySet()) {
        if(count == topWords) break;
        System.out.print(idLookup.get(wordID)+" ");
        count++;
      }
      System.out.println();
      topicID++;
    }
  }
 
  public void printDocuments(int topTopics) {
    if(!hasRun) {
      System.err.println("Gibbs sampler has not yet run");
      return;
    }
   
    List<List<WordScore>> docs = getDocuments();
    for(int d=0; d<numDocs; d++) {
      System.out.println("Document "+d+":");
      List<WordScore> topics = docs.get(d);
      for(int n=0; n<topTopics; n++) {
        if(n == numWords) break;
        System.out.println(topics.get(n).getWord()+" = "+topics.get(n).getScore());
      }
      try {System.in.read();} catch (IOException e) {}
    }
  }
 
  public void printDocumentsVerbose(int topTopics) {
    if(!hasRun) {
      System.err.println("Gibbs sampler has not yet run");
      return;
    }
   
    List<List<WordScore>> docs = getDocuments();
    for(int d=0; d<numDocs; d++) {
      List<WordScore> topics = docs.get(d);
      System.out.print("Document "+d+": ");
      for(int n=0; n<documents[d].length; n++) {
        System.out.print(idLookup.get(documents[d][n])+"["+topicAssignments[d][n]+"] ");
      }
      System.out.println();
      for(int n=0; n<topTopics; n++) {
        if(n == numWords) break;
        System.out.println(topics.get(n).getWord()+" = "+topics.get(n).getScore());
      }
      try {System.in.read();} catch (IOException e) {}
    }
  }
 
  public void print() {
    //TODO
  }

    public void save(String name) {
    try {
      String filename = "models/lda/"+name+".model";
      FileOutputStream fileOut = new FileOutputStream(filename);
      ObjectOutputStream objectOut = new ObjectOutputStream(fileOut);
      objectOut.writeObject(this);
      objectOut.close();
      System.out.println("Saved LDA topic model "+name);
    } catch (FileNotFoundException e) {
      System.out.println("Couldn't save LDA topic model "+name);
      e.printStackTrace();
    } catch (IOException e) {
      System.out.println("Couldn't save LDA topic model "+name);
      e.printStackTrace();     
    }     
    }
   
    public static LDATopicModel load(String name) {
    try {
      String filename = "models/lda/"+name+".model";
      FileInputStream fileIn = new FileInputStream(filename);
      ObjectInputStream objectIn = new ObjectInputStream(fileIn);
      LDATopicModel model = (LDATopicModel)objectIn.readObject();
      objectIn.close();
      System.out.println("Loaded LDA topic model "+name);
      return model;
    } catch (FileNotFoundException e) {
      System.out.println("Couldn't load LDA topic model "+name);
      e.printStackTrace();
    } catch (IOException e) {
      System.out.println("Couldn't load LDA topic model "+name);
      e.printStackTrace();     
    } catch (ClassNotFoundException e) {
      System.out.println("Couldn't load LDA topic model "+name);
      e.printStackTrace();     
    }
    return null;
    }
   
    public long getDocIDFromIndex(int m) {
      return docIDLookup.get(m).getId();
    }
   
    public Map<String,Integer> getVocab() {
      return wordIDs;
    }
 
}
TOP

Related Classes of uk.ac.cam.ha293.tweetlabel.topics.LDATopicModel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.