/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.topics;

import java.io.*;
import java.text.NumberFormat;
import java.util.*;

import cc.mallet.types.*;
import cc.mallet.util.CommandOption;
import cc.mallet.util.Randoms;

/**
 * Latent Dirichlet Allocation for loosely parallel corpora in arbitrary languages.
 *
 * @author David Mimno, Andrew McCallum
 */

public class PolylingualTopicModel implements Serializable {

  // ---- Command-line options: input corpora and serialized model / inferencer / evaluator files ----
  // All options in this file are registered against PolylingualTopicModel.class.
  // NOTE(review): the fifth constructor argument looks like the default value (the help
  //  strings below describe null as the default) -- confirm against CommandOption.

  // --language-inputs: one instance file per language; the help text states the alignment rules.
  static CommandOption.SpacedStrings languageInputFiles = new CommandOption.SpacedStrings
    (PolylingualTopicModel.class, "language-inputs", "FILENAME [FILENAME ...]", true, null,
      "Filenames for polylingual topic model. Each language should have its own file, " +
      "with the same number of instances in each file. If a document is missing in " +
     "one language, there should be an empty instance.", null);

  // --output-model: destination for the binary model written after the last iteration.
  static CommandOption.String outputModelFilename = new CommandOption.String
    (PolylingualTopicModel.class, "output-model", "FILENAME", true, null,
      "The filename in which to write the binary topic model at the end of the iterations.  " +
     "By default this is null, indicating that no file will be written.", null);

  // --input-model: previously saved binary model to continue training from.
  static CommandOption.String inputModelFilename = new CommandOption.String
    (PolylingualTopicModel.class, "input-model", "FILENAME", true, null,
      "The filename from which to read the binary topic model to which the --input will be appended, " +
      "allowing incremental training.  " +
     "By default this is null, indicating that no file will be read.", null);

  // --inferencer-filename: destination for a topic inferencer built from the trained model.
  static CommandOption.String inferencerFilename = new CommandOption.String
    (PolylingualTopicModel.class, "inferencer-filename", "FILENAME", true, null,
      "A topic inferencer applies a previously trained topic model to new documents.  " +
     "By default this is null, indicating that no file will be written.", null);

  // --evaluator-filename: destination for a held-out likelihood evaluator.
  static CommandOption.String evaluatorFilename = new CommandOption.String
    (PolylingualTopicModel.class, "evaluator-filename", "FILENAME", true, null,
      "A held-out likelihood evaluator for new documents.  " +
     "By default this is null, indicating that no file will be written.", null);

  // ---- Command-line options: text outputs and periodic checkpointing ----

  // --output-state: destination for the Gibbs sampling state.
  static CommandOption.String stateFile = new CommandOption.String
    (PolylingualTopicModel.class, "output-state", "FILENAME", true, null,
      "The filename in which to write the Gibbs sampling state after at the end of the iterations.  " +
     "By default this is null, indicating that no file will be written.", null);

  // --output-topic-keys: destination for the per-topic top-word listing.
  static CommandOption.String topicKeysFile = new CommandOption.String
    (PolylingualTopicModel.class, "output-topic-keys", "FILENAME", true, null,
      "The filename in which to write the top words for each topic and any Dirichlet parameters.  " +
     "By default this is null, indicating that no file will be written.", null);

  // --output-doc-topics: destination for per-document topic proportions.
  static CommandOption.String docTopicsFile = new CommandOption.String
    (PolylingualTopicModel.class, "output-doc-topics", "FILENAME", true, null,
      "The filename in which to write the topic proportions per document, at the end of the iterations.  " +
     "By default this is null, indicating that no file will be written.", null);

  // --doc-topics-threshold: minimum proportion a topic needs to be listed in the doc-topics output.
  static CommandOption.Double docTopicsThreshold = new CommandOption.Double
    (PolylingualTopicModel.class, "doc-topics-threshold", "DECIMAL", true, 0.0,
      "When writing topic proportions per document with --output-doc-topics, " +
     "do not print topics with proportions less than this threshold value.", null);

  // --doc-topics-max: cap on topics listed per document; negative means no cap (per the help text).
  static CommandOption.Integer docTopicsMax = new CommandOption.Integer
    (PolylingualTopicModel.class, "doc-topics-max", "INTEGER", true, -1,
      "When writing topic proportions per document with --output-doc-topics, " +
      "do not print more than INTEGER number of topics.  "+
     "A negative value indicates that all topics should be printed.", null);

  // --output-model-interval: iterations between binary model checkpoints; requires --output-model.
  static CommandOption.Integer outputModelIntervalOption = new CommandOption.Integer
    (PolylingualTopicModel.class, "output-model-interval", "INTEGER", true, 0,
      "The number of iterations between writing the model (and its Gibbs sampling state) to a binary file.  " +
     "You must also set the --output-model to use this option, whose argument will be the prefix of the filenames.", null);

  // --output-state-interval: iterations between text state checkpoints; requires --output-state.
  static CommandOption.Integer outputStateIntervalOption = new CommandOption.Integer
    (PolylingualTopicModel.class, "output-state-interval", "INTEGER", true, 0,
      "The number of iterations between writing the sampling state to a text file.  " +
     "You must also set the --output-state to use this option, whose argument will be the prefix of the filenames.", null);

  // ---- Command-line options: sampler configuration and hyperparameters ----
  // NOTE(review): these option values differ from the instance-field initializers declared
  //  further down (e.g. show-topics-interval 50 here vs. showTopicsInterval = 10 on the field);
  //  how the options are copied onto a model instance is not visible in this part of the file.

  // --num-topics
  static CommandOption.Integer numTopicsOption = new CommandOption.Integer
    (PolylingualTopicModel.class, "num-topics", "INTEGER", true, 10,
     "The number of topics to fit.", null);

  // --num-iterations
  static CommandOption.Integer numIterationsOption = new CommandOption.Integer
    (PolylingualTopicModel.class, "num-iterations", "INTEGER", true, 1000,
     "The number of iterations of Gibbs sampling.", null);

  // --random-seed: per the help text, 0 means "seed from the clock".
  static CommandOption.Integer randomSeedOption = new CommandOption.Integer
    (PolylingualTopicModel.class, "random-seed", "INTEGER", true, 0,
     "The random seed for the Gibbs sampler.  Default is 0, which will use the clock.", null);

  // --num-top-words
  static CommandOption.Integer topWordsOption = new CommandOption.Integer
    (PolylingualTopicModel.class, "num-top-words", "INTEGER", true, 20,
     "The number of most probable words to print for each topic after model estimation.", null);

  // --show-topics-interval
  static CommandOption.Integer showTopicsIntervalOption = new CommandOption.Integer
    (PolylingualTopicModel.class, "show-topics-interval", "INTEGER", true, 50,
     "The number of iterations between printing a brief summary of the topics so far.", null);

  // --optimize-interval
  static CommandOption.Integer optimizeIntervalOption = new CommandOption.Integer
    (PolylingualTopicModel.class, "optimize-interval", "INTEGER", true, 0,
     "The number of iterations between reestimating dirichlet hyperparameters.", null);

  // --optimize-burn-in
  static CommandOption.Integer optimizeBurnInOption = new CommandOption.Integer
    (PolylingualTopicModel.class, "optimize-burn-in", "INTEGER", true, 200,
     "The number of iterations to run before first estimating dirichlet hyperparameters.", null);

  // --alpha
  static CommandOption.Double alphaOption = new CommandOption.Double
    (PolylingualTopicModel.class, "alpha", "DECIMAL", true, 50.0,
     "Alpha parameter: smoothing over topic distribution.",null);

  // --beta
  static CommandOption.Double betaOption = new CommandOption.Double
    (PolylingualTopicModel.class, "beta", "DECIMAL", true, 0.01,
     "Beta parameter: smoothing over unigram distribution.",null);
  public class TopicAssignment implements Serializable {
    public Instance[] instances;
    public LabelSequence[] topicSequences;
    public Labeling topicDistribution;
    public TopicAssignment (Instance[] instances, LabelSequence[] topicSequences) {
      this.instances = instances;
      this.topicSequences = topicSequences;

  // Number of parallel languages; overwritten with training.length by addInstances.
  int numLanguages = 1;

  protected ArrayList<TopicAssignment> data;  // one entry per aligned document: its per-language instances and topic sequences
  protected LabelAlphabet topicAlphabet;  // the alphabet for the topics; its size determines numTopics

  protected int numStopwords = 0; // only referenced by the automatic stoplist code in addInstances, which is marked as disabled
  protected int numTopics; // Number of topics to be fit

  // Instance names read by loadTestingIDs; addInstances tests instance names against this set.
  // Stays null unless loadTestingIDs is called.
  HashSet<String> testingIDs = null;

  // These values are used to encode type/topic counts as
  //  count/topic pairs in a single int: the topic occupies the low
  //  topicBits bits (extracted with "& topicMask") and the count the
  //  remaining high bits (extracted with ">> topicBits").
  protected int topicMask;
  protected int topicBits;

  protected Alphabet[] alphabets; // data alphabet of each language, indexed by <language index>
  protected int[] vocabularySizes; // size of each language's alphabet, indexed by <language index>

  protected double[] alpha;   // Dirichlet(alpha,alpha,...) is the distribution over topics; indexed by <topic index>
  protected double alphaSum;
  protected double[] betas;   // Prior on per-topic multinomial distribution over words, one value per language
  protected double[] betaSums; // betas[language] * vocabularySizes[language], indexed by <language index>

  protected int[] languageMaxTypeCounts; // largest corpus-wide count of any single word type, indexed by <language index>

  public static final double DEFAULT_BETA = 0.01; // initial value of every betas[language] in addInstances
  protected double[] languageSmoothingOnlyMasses; // filled by cacheValues, indexed by <language index>
  protected double[][] languageCachedCoefficients; // filled by cacheValues, indexed by <language index, topic index>
  // Counters that estimate() resets to 0 at the start of each iteration.
  int topicTermCount = 0;
  int betaTopicCount = 0;
  int smoothingOnlyCount = 0;

  // An array to put the topic counts for the current document.
  // Initialized locally below.  Defined here to avoid
  // garbage collection overhead.
  // NOTE(review): declared one-dimensional, so it can only be indexed by <topic index>;
  //  no use of this field is visible in this part of the file.
  protected int[] oneDocTopicCounts;

  // indexed by <language index, feature (type) index, slot>; each slot is a packed
  //  count/topic int (see topicMask/topicBits), kept in descending order
  protected int[][][] languageTypeTopicCounts;
  protected int[][] languageTokensPerTopic; // indexed by <language index, topic index>

  // for dirichlet estimation; both are allocated by initializeHistograms
  protected int[] docLengthCounts; // histogram of document sizes, summed over languages
  protected int[][] topicDocCounts; // histogram of document/topic counts, indexed by <topic index, sequence position index>

  protected int iterationsSoFar = 1; // the loop counter of estimate(int); persists across calls
  public int numIterations = 1000; // iteration count used by the no-argument estimate()
  public int burninPeriod = 5;
  public int saveSampleInterval = 5; // was 10;
  public int optimizeInterval = 10; // 0 disables hyperparameter optimization in estimate(int)
  public int showTopicsInterval = 10; // was 50; 0 disables the periodic top-word printout
  public int wordsPerTopic = 7; // number of words passed to printTopWords by estimate(int)

  protected int saveModelInterval = 0; // 0 disables
  protected String modelFilename;

  protected int saveStateInterval = 0; // 0 disables
  protected String stateFilename = null;
  protected Randoms random;
  protected NumberFormat formatter;
  protected boolean printLogLikelihood = false;
  public PolylingualTopicModel (int numberOfTopics) {
    this (numberOfTopics, numberOfTopics);
  public PolylingualTopicModel (int numberOfTopics, double alphaSum) {
    this (numberOfTopics, alphaSum, new Randoms());
  /**
   * Creates the LabelAlphabet that the (int, double, Randoms) constructor passes to
   * the main constructor.
   *
   * @param numTopics upper bound of the loop below
   * @return the newly created alphabet
   */
  private static LabelAlphabet newLabelAlphabet (int numTopics) {
    LabelAlphabet ret = new LabelAlphabet();
    // NOTE(review): no statement that adds entries to ret is visible here; as written the
    //  loop header is followed directly by the return. Presumably one label per topic
    //  index is meant to be registered -- confirm against the upstream source.
    for (int i = 0; i < numTopics; i++)
    return ret;
  public PolylingualTopicModel (int numberOfTopics, double alphaSum, Randoms random) {
    this (newLabelAlphabet (numberOfTopics), alphaSum, random);
  /**
   * Main constructor; the other constructors delegate here.
   * Takes the topic count from the alphabet's size, derives the bit layout used to
   * pack a (count, topic) pair into one int, and spreads alphaSum evenly over the topics.
   *
   * @param topicAlphabet alphabet of topics; its size becomes numTopics
   * @param alphaSum      sum of the per-topic alpha values; each topic gets alphaSum / numTopics
   * @param random        random source stored for the sampler
   */
  public PolylingualTopicModel (LabelAlphabet topicAlphabet, double alphaSum, Randoms random)
  // NOTE(review): the assignment target on the next line is missing; `data` is the only
  //  ArrayList<TopicAssignment> field of this class, so presumably `this.data` -- confirm.
  { = new ArrayList<TopicAssignment>();
    this.topicAlphabet = topicAlphabet;
    this.numTopics = topicAlphabet.size();

    // topicMask is an all-ones mask wide enough to hold any topic index;
    //  topicBits is the number of bits in that mask.
    if (Integer.bitCount(numTopics) == 1) {
      // exact power of 2
      topicMask = numTopics - 1;
      topicBits = Integer.bitCount(topicMask);
    else {
      // otherwise add an extra bit
      topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
      topicBits = Integer.bitCount(topicMask);

    this.alphaSum = alphaSum;
    this.alpha = new double[numTopics];
    // Symmetric starting point: every topic receives the same share of alphaSum.
    Arrays.fill(alpha, alphaSum / numTopics);
    this.random = random;
    formatter = NumberFormat.getInstance();

    System.err.println("Polylingual LDA: " + numTopics + " topics, " + topicBits + " topic bits, " +
               Integer.toBinaryString(topicMask) + " topic mask");

    /**
     * Reads testingIDFile one line at a time and resets the testingIDs set.
     * addInstances later tests instance names against that set.
     *
     * @param testingIDFile text file that is read line by line
     * @throws IOException if opening or reading the file fails
     */
    public void loadTestingIDs(File testingIDFile) throws IOException {
        // Replaces any previously loaded set (raw HashSet; the field is declared HashSet<String>).
        testingIDs = new HashSet();

        BufferedReader in = new BufferedReader(new FileReader(testingIDFile));
        String id = null;
        // NOTE(review): neither the loop body nor a close() of the reader is visible here;
        //  presumably each line is added to testingIDs -- confirm against the upstream source.
        while ((id = in.readLine()) != null) {
  public LabelAlphabet getTopicAlphabet() { return topicAlphabet; }
  public int getNumTopics() { return numTopics; }
  public ArrayList<TopicAssignment> getData() { return data; }
  public void setNumIterations (int numIterations) {
    this.numIterations = numIterations;

  public void setBurninPeriod (int burninPeriod) {
    this.burninPeriod = burninPeriod;

  public void setTopicDisplay(int interval, int n) {
    this.showTopicsInterval = interval;
    this.wordsPerTopic = n;

  public void setRandomSeed(int seed) {
    random = new Randoms(seed);

  public void setOptimizeInterval(int interval) {
    this.optimizeInterval = interval;

  public void setModelOutput(int interval, String filename) {
    this.saveModelInterval = interval;
    this.modelFilename = filename;
  /** Define how often and where to save the state
   * @param interval Save a copy of the state every <code>interval</code> iterations.
   * @param filename Save the state to this file, with the iteration number as a suffix
  public void setSaveState(int interval, String filename) {
    this.saveStateInterval = interval;
    this.stateFilename = filename;
  /**
   * Loads one InstanceList per language and initializes the sampler state from them.
   * Per language it records the alphabet, vocabulary size and beta values, counts how
   * often each word type occurs, and allocates that type's packed count/topic array.
   * Per document it then draws a topic for every token with random.nextInt(numTopics)
   * and records it in the packed type/topic counts.
   *
   * Language 0 fixes the expected number of instances; a language with a different
   * count only triggers a warning on System.err.
   *
   * NOTE(review): several statements of this method are not visible in this part of
   * the file (closing braces, loop increments, the bodies of the testingIDs checks);
   * the comments below describe only what can be read here.
   *
   * @param training one instance list per language, aligned by document index;
   *                 instance data is cast to FeatureSequence
   */
  public void addInstances (InstanceList[] training) {

    numLanguages = training.length;

    languageTokensPerTopic = new int[numLanguages][numTopics];
    alphabets = new Alphabet[ numLanguages ];
    vocabularySizes = new int[ numLanguages ];
    betas = new double[ numLanguages ];
    betaSums = new double[ numLanguages ];
    languageMaxTypeCounts = new int[ numLanguages ];
    languageTypeTopicCounts = new int[ numLanguages ][][];
    int numInstances = training[0].size();

    // Only referenced by the stoplist code below, which is marked as disabled.
    HashSet[] stoplists = new HashSet[ numLanguages ];

    for (int language = 0; language < numLanguages; language++) {

      if (training[language].size() != numInstances) {
        System.err.println("Warning: language " + language + " has " +
                   training[language].size() + " instances, lang 0 has " +

      alphabets[ language ] = training[ language ].getDataAlphabet();
      vocabularySizes[ language ] = alphabets[ language ].size();
      betas[language] = DEFAULT_BETA;
      betaSums[language] = betas[language] * vocabularySizes[ language ];
      languageTypeTopicCounts[language] = new int[ vocabularySizes[language] ][];

      int[][] typeTopicCounts = languageTypeTopicCounts[language];

      // Get the total number of occurrences of each word type
      int[] typeTotals = new int[ vocabularySizes[language] ];
      for (Instance instance : training[language]) {
        if (testingIDs != null &&
          testingIDs.contains(instance.getName())) {

        FeatureSequence tokens = (FeatureSequence) instance.getData();
        for (int position = 0; position < tokens.getLength(); position++) {
          int type = tokens.getIndexAtPosition(position);
          typeTotals[ type ]++;

      /* Automatic stoplist creation, currently disabled
      TreeSet<IDSorter> sortedWords = new TreeSet<IDSorter>();
      for (int type = 0; type < vocabularySizes[language]; type++) {
        sortedWords.add(new IDSorter(type, typeTotals[type]));

      stoplists[language] = new HashSet<Integer>();
      Iterator<IDSorter> typeIterator = sortedWords.iterator();
      int totalStopwords = 0;

      while (typeIterator.hasNext() && totalStopwords < numStopwords) {
      // Allocate enough space so that we never have to worry about
      //  overflows: either the number of topics or the number of times
      //  the type occurs.
      for (int type = 0; type < vocabularySizes[language]; type++) {
        if (typeTotals[type] > languageMaxTypeCounts[language]) {
          languageMaxTypeCounts[language] = typeTotals[type];
        typeTopicCounts[type] = new int[ Math.min(numTopics, typeTotals[type]) ];
    for (int doc = 0; doc < numInstances; doc++) {

      if (testingIDs != null &&
        testingIDs.contains(training[0].get(doc).getName())) {

      Instance[] instances = new Instance[ numLanguages ];
      LabelSequence[] topicSequences = new LabelSequence[ numLanguages ];

      for (int language = 0; language < numLanguages; language++) {
        int[][] typeTopicCounts = languageTypeTopicCounts[language];
        // NOTE(review): tokensPerTopic is fetched here but no update to it is visible below.
        int[] tokensPerTopic = languageTokensPerTopic[language];

        instances[language] = training[language].get(doc);
        FeatureSequence tokens = (FeatureSequence) instances[language].getData();
        topicSequences[language] =
          new LabelSequence(topicAlphabet, new int[ tokens.size() ]);
        int[] topics = topicSequences[language].getFeatures();
        for (int position = 0; position < tokens.size(); position++) {
          int type = tokens.getIndexAtPosition(position);
          int[] currentTypeTopicCounts = typeTopicCounts[ type ];
          int topic = random.nextInt(numTopics);

          // If the word is one of the [numStopwords] most
          //  frequent words, put it in a non-sampled topic.
          //if (stoplists[language].contains(type)) {
          //  topic = -1;

          topics[position] = topic;
          // The format for these arrays is
          //  the topic in the rightmost bits
          //  the count in the remaining (left) bits.
          // Since the count is in the high bits, sorting (desc)
          //  by the numeric value of the int guarantees that
          //  higher counts will be before the lower counts.
          // Start by assuming that the array is either empty
          //  or is in sorted (descending) order.
          // Here we are only adding counts, so if we find
          //  an existing location with the topic, we only need
          //  to ensure that it is not larger than its left neighbor.
          int index = 0;
          int currentTopic = currentTypeTopicCounts[index] & topicMask;
          int currentValue;
          // Scan for the slot that already holds this topic, or the first empty (0) slot.
          while (currentTypeTopicCounts[index] > 0 && currentTopic != topic) {
              // Debugging output...
             if (index >= currentTypeTopicCounts.length) {
              for (int i=0; i < currentTypeTopicCounts.length; i++) {
                System.out.println((currentTypeTopicCounts[i] & topicMask) + ":" +
                           (currentTypeTopicCounts[i] >> topicBits) + " ");
              System.out.println(type + " " + typeTotals[type]);
            currentTopic = currentTypeTopicCounts[index] & topicMask;
          currentValue = currentTypeTopicCounts[index] >> topicBits;
          if (currentValue == 0) {
            // new value is 1, so we don't have to worry about sorting
            //  (except by topic suffix, which doesn't matter)
            currentTypeTopicCounts[index] =
              (1 << topicBits) + topic;
          else {
            currentTypeTopicCounts[index] =
              ((currentValue + 1) << topicBits) + topic;
            // Now ensure that the array is still sorted by
            //  bubbling this value up.
            while (index > 0 &&
                 currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
              int temp = currentTypeTopicCounts[index];
              currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
              currentTypeTopicCounts[index - 1] = temp;

      TopicAssignment t = new TopicAssignment (instances, topicSequences);
      data.add (t);


    // Allocated here; cacheValues() fills them in.
    languageSmoothingOnlyMasses = new double[ numLanguages ];
    languageCachedCoefficients = new double[ numLanguages ][ numTopics ];


  /**
   *  Gather statistics on the size of documents
   *  and create histograms for use in Dirichlet hyperparameter
   *  optimization.
   */
  private void initializeHistograms() {

    int maxTokens = 0;
    int totalTokens = 0;

    for (int doc = 0; doc < data.size(); doc++) {
      int length = 0;
      for (LabelSequence sequence : data.get(doc).topicSequences) {
        length += sequence.getLength();

      if (length > maxTokens) {
        maxTokens = length;

      totalTokens += length;

    System.err.println("max tokens: " + maxTokens);
    System.err.println("total tokens: " + totalTokens);

    docLengthCounts = new int[maxTokens + 1];
    topicDocCounts = new int[numTopics][maxTokens + 1];

  private void cacheValues() {

    for (int language = 0; language < numLanguages; language++) {
      languageSmoothingOnlyMasses[language] = 0.0;
      for (int topic=0; topic < numTopics; topic++) {
        languageSmoothingOnlyMasses[language] +=
          alpha[topic] * betas[language] /
          (languageTokensPerTopic[language][topic] + betaSums[language]);
        languageCachedCoefficients[language][topic] =
          alpha[topic] / (languageTokensPerTopic[language][topic] + betaSums[language]);
  private void clearHistograms() {
    Arrays.fill(docLengthCounts, 0);
    for (int topic = 0; topic < topicDocCounts.length; topic++)
      Arrays.fill(topicDocCounts[topic], 0);

  public void estimate () throws IOException {
    estimate (numIterations);
  /**
   * Runs the sampling loop. Each iteration optionally prints the top words, writes the
   * state to a file, re-estimates alpha, and then calls sampleTopicsForOneDoc for every
   * document; the total wall-clock time is printed at the end.
   *
   * The loop counter is the iterationsSoFar field, so it carries over between calls.
   * The loop bound is inclusive: it runs while
   * iterationsSoFar <= (starting iterationsSoFar + iterationsThisRound).
   *
   * NOTE(review): several statements of this method are not visible in this part of
   * the file (closing braces, the end of one println call); the comments below describe
   * only what can be read here.
   *
   * @param iterationsThisRound added to the current iterationsSoFar to form the loop bound
   * @throws IOException declared by the signature; presumably raised by the file-writing
   *                     calls (printState / write) -- confirm
   */
  public void estimate (int iterationsThisRound) throws IOException {

    long startTime = System.currentTimeMillis();
    int maxIteration = iterationsSoFar + iterationsThisRound;

    long totalTime = 0;
    for ( ; iterationsSoFar <= maxIteration; iterationsSoFar++) {
      long iterationStart = System.currentTimeMillis();
      // Periodic top-word summary; an interval of 0 disables it.
      if (showTopicsInterval != 0 && iterationsSoFar != 0 && iterationsSoFar % showTopicsInterval == 0) {
        printTopWords (System.out, wordsPerTopic, false);


      // Periodic state dump to "<stateFilename>.<iteration>"; an interval of 0 disables it.
      if (saveStateInterval != 0 && iterationsSoFar % saveStateInterval == 0) {
        this.printState(new File(stateFilename + '.' + iterationsSoFar));

        // NOTE(review): `iterations` is not declared anywhere in the visible code (the loop
        //  counter is iterationsSoFar); check whether this model-saving block is meant to
        //  be active at all.
        if (saveModelInterval != 0 && iterations % saveModelInterval == 0) {
        this.write (new File(modelFilename+'.'+iterations));

      // TODO this condition should also check that we have more than one sample to work with here
      // (The number of samples actually obtained is not yet tracked.)
      if (iterationsSoFar > burninPeriod && optimizeInterval != 0 &&
        iterationsSoFar % optimizeInterval == 0) {

        // Re-estimates alpha in place from the two histograms; the return value becomes alphaSum.
        alphaSum = Dirichlet.learnParameters(alpha, topicDocCounts, docLengthCounts);

      // Loop over every document in the corpus
      topicTermCount = betaTopicCount = smoothingOnlyCount = 0;

      for (int doc = 0; doc < data.size(); doc++) {

        // The second argument (shouldSaveState) is true only at or after burn-in,
        //  on iterations that are a multiple of saveSampleInterval.
        sampleTopicsForOneDoc (data.get(doc),
                     (iterationsSoFar >= burninPeriod &&
                    iterationsSoFar % saveSampleInterval == 0));
            long elapsedMillis = System.currentTimeMillis() - iterationStart;
            totalTime += elapsedMillis;

      // On iterations where (iterationsSoFar + 1) is a multiple of 10 the log-likelihood is
      //  computed and a full line is printed; on the others only the elapsed milliseconds are.
      if ((iterationsSoFar + 1) % 10 == 0) {
        double ll = modelLogLikelihood();
        System.out.println(elapsedMillis + "\t" + totalTime + "\t" +
      else {
        System.out.print(elapsedMillis + " ");

    // Break the total elapsed time down into days / hours / minutes / seconds for the report.
    long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
    long minutes = seconds / 60;  seconds %= 60;
    long hours = minutes / 60;  minutes %= 60;
    long days = hours / 24;  hours %= 24;
    System.out.print ("\nTotal time: ");
    if (days != 0) { System.out.print(days); System.out.print(" days "); }
    if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
    if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
    System.out.print(seconds); System.out.println(" seconds");
  /**
   * Re-estimates the beta hyperparameter of every language. For each language it
   * prepares a histogram over type/topic counts and a histogram over topic sizes,
   * stores the result of Dirichlet.learnSymmetricConcentration in betaSums[language],
   * and sets betas[language] to that sum divided by the vocabulary size.
   *
   * NOTE(review): several statements of this method are not visible in this part of
   * the file (the body that fills countHistogram, closing braces, the remaining
   * arguments of learnSymmetricConcentration); the comments below describe only what
   * can be read here.
   */
  public void optimizeBetas() {
    for (int language = 0; language < numLanguages; language++) {
      // The histogram starts at count 0, so if all of the
      //  tokens of the most frequent type were assigned to one topic,
      //  we would need to store a maxTypeCount + 1 count.
      int[] countHistogram = new int[languageMaxTypeCounts[language] + 1];
      // Now count the number of type/topic pairs that have
      //  each number of tokens.
      int[][] typeTopicCounts = languageTypeTopicCounts[language];
      int[] tokensPerTopic = languageTokensPerTopic[language];

      int index;
      for (int type = 0; type < vocabularySizes[language]; type++) {
        int[] counts = typeTopicCounts[type];
        index = 0;
        // Walk the occupied slots of this type's packed array; a 0 entry marks the end.
        while (index < counts.length &&
             counts[index] > 0) {
          // The count is stored in the bits above the topic bits.
          int count = counts[index] >> topicBits;
      // Figure out how large we need to make the "observation lengths"
      //  histogram.
      int maxTopicSize = 0;
      for (int topic = 0; topic < numTopics; topic++) {
        if (tokensPerTopic[topic] > maxTopicSize) {
          maxTopicSize = tokensPerTopic[topic];
      // Now allocate it and populate it.
      int[] topicSizeHistogram = new int[maxTopicSize + 1];
      for (int topic = 0; topic < numTopics; topic++) {
        topicSizeHistogram[ tokensPerTopic[topic] ]++;
      betaSums[language] = Dirichlet.learnSymmetricConcentration(countHistogram,
                                     vocabularySizes[ language ],
      // Beta is kept symmetric: the learned sum is spread evenly over the vocabulary.
      betas[language] = betaSums[language] / vocabularySizes[ language ];

  protected void sampleTopicsForOneDoc (TopicAssignment topicAssignment,
                      boolean shouldSaveState) {

    int[] currentTypeTopicCounts;
    int type, oldTopic, newTopic;
    double topicWeightsSum;

    int[] localTopicCounts = new int[numTopics];
    int[] localTopicIndex = new int[numTopics];

    for (int language = 0; language < numLanguages; language++) {

      int[] oneDocTopics =
      int docLength =
      //    populate topic counts
      for (int position = 0; position < docLength; position++) {

    // Build an array that densely lists the topics that
    //  have non-zero counts.
    int denseIndex = 0;
    for (int topic = 0; topic < numTopics; topic++) {
      if (localTopicCounts[topic] != 0) {
        localTopicIndex[denseIndex] = topic;

    // Record the total number of non-zero topics
    int nonZeroTopics = denseIndex;

    for (int language = 0; language < numLanguages; language++) {

            int[] oneDocTopics =
            int docLength =
      FeatureSequence tokenSequence =
        (FeatureSequence) topicAssignment.instances[language].getData();

      int[][] typeTopicCounts = languageTypeTopicCounts[language];
      int[] tokensPerTopic = languageTokensPerTopic[language];
      double beta = betas[language];
      double betaSum = betaSums[language];

      // Initialize the smoothing-only sampling bucket
      double smoothingOnlyMass = languageSmoothingOnlyMasses[language];
      //for (int topic = 0; topic < numTopics; topic++)
      //smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);
      // Initialize the cached coefficients, using only smoothing.
      //cachedCoefficients = new double[ numTopics ];
      //for (int topic=0; topic < numTopics; topic++)
      //  cachedCoefficients[topic] =  alpha[topic] / (tokensPerTopic[topic] + betaSum);
      double[] cachedCoefficients =

      //    Initialize the topic count/beta sampling bucket
      double topicBetaMass = 0.0;
      // Initialize cached coefficients and the topic/beta
      //  normalizing constant.
      for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
        int topic = localTopicIndex[denseIndex];
        int n = localTopicCounts[topic];
        //  initialize the normalization constant for the (B * n_{t|d}) term
        topicBetaMass += beta * n /  (tokensPerTopic[topic] + betaSum)
        //  update the coefficients for the non-zero topics
        cachedCoefficients[topic] (alpha[topic] + n) / (tokensPerTopic[topic] + betaSum);

      double topicTermMass = 0.0;

      double[] topicTermScores = new double[numTopics];
      int[] topicTermIndices;
      int[] topicTermValues;
      int i;
      double score;

      //  Iterate over the positions (words) in the document
      for (int position = 0; position < docLength; position++) {
        type = tokenSequence.getIndexAtPosition(position);
        oldTopic = oneDocTopics[position];
        if (oldTopic == -1) { continue; }

        currentTypeTopicCounts = typeTopicCounts[type];
        //  Remove this token from all counts.
        // Remove this topic's contribution to the
        //  normalizing constants
        smoothingOnlyMass -= alpha[oldTopic] * beta /
          (tokensPerTopic[oldTopic] + betaSum);
        topicBetaMass -= beta * localTopicCounts[oldTopic] /
          (tokensPerTopic[oldTopic] + betaSum);
        // Decrement the local doc/topic counts
        // Maintain the dense index, if we are deleting
        //  the old topic
        if (localTopicCounts[oldTopic] == 0) {
          // First get to the dense location associated with
          //  the old topic.
          denseIndex = 0;
          // We know it's in there somewhere, so we don't
          //  need bounds checking.
          while (localTopicIndex[denseIndex] != oldTopic) {
          // shift all remaining dense indices to the left.
          while (denseIndex < nonZeroTopics) {
            if (denseIndex < localTopicIndex.length - 1) {
              localTopicIndex[denseIndex] =
                localTopicIndex[denseIndex + 1];
          nonZeroTopics --;
        // Decrement the global topic count totals
        //assert(tokensPerTopic[oldTopic] >= 0) : "old Topic " + oldTopic + " below 0";
        // Add the old topic's contribution back into the
        //  normalizing constants.
        smoothingOnlyMass += alpha[oldTopic] * beta /
          (tokensPerTopic[oldTopic] + betaSum);
        topicBetaMass += beta * localTopicCounts[oldTopic] /
          (tokensPerTopic[oldTopic] + betaSum);
        // Reset the cached coefficient for this topic
        cachedCoefficients[oldTopic] =
          (alpha[oldTopic] + localTopicCounts[oldTopic]) /
          (tokensPerTopic[oldTopic] + betaSum);
        // Now go over the type/topic counts, decrementing
        //  where appropriate, and calculating the score
        //  for each topic at the same time.
        int index = 0;
        int currentTopic, currentValue;
        boolean alreadyDecremented = false;
        topicTermMass = 0.0;
        while (index < currentTypeTopicCounts.length &&
             currentTypeTopicCounts[index] > 0) {
          currentTopic = currentTypeTopicCounts[index] & topicMask;
          currentValue = currentTypeTopicCounts[index] >> topicBits;
          if (! alreadyDecremented &&
            currentTopic == oldTopic) {
            // We're decrementing and adding up the
            //  sampling weights at the same time, but
            //  decrementing may require us to reorder
            //  the topics, so after we're done here,
            //  look at this cell in the array again.
            currentValue --;
            if (currentValue == 0) {
              currentTypeTopicCounts[index] = 0;
            else {
              currentTypeTopicCounts[index] =
                (currentValue << topicBits) + oldTopic;
            // Shift the reduced value to the right, if necessary.
            int subIndex = index;
            while (subIndex < currentTypeTopicCounts.length - 1 &&
                 currentTypeTopicCounts[subIndex] < currentTypeTopicCounts[subIndex + 1]) {
              int temp = currentTypeTopicCounts[subIndex];
              currentTypeTopicCounts[subIndex] = currentTypeTopicCounts[subIndex + 1];
              currentTypeTopicCounts[subIndex + 1] = temp;
            alreadyDecremented = true;
          else {
            score =
              cachedCoefficients[currentTopic] * currentValue;
            topicTermMass += score;
            topicTermScores[index] = score;
        double sample = random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
        double origSample = sample;
        //  Make sure it actually gets set
        newTopic = -1;
        if (sample < topicTermMass) {
          i = -1;
          while (sample > 0) {
            sample -= topicTermScores[i];
          newTopic = currentTypeTopicCounts[i] & topicMask;
          currentValue = currentTypeTopicCounts[i] >> topicBits;
          currentTypeTopicCounts[i] = ((currentValue + 1) << topicBits) + newTopic;
          // Bubble the new value up, if necessary
          while (i > 0 &&
               currentTypeTopicCounts[i] > currentTypeTopicCounts[i - 1]) {
            int temp = currentTypeTopicCounts[i];
            currentTypeTopicCounts[i] = currentTypeTopicCounts[i - 1];
            currentTypeTopicCounts[i - 1] = temp;
        else {
          sample -= topicTermMass;
          if (sample < topicBetaMass) {
            sample /= beta;
            for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
              int topic = localTopicIndex[denseIndex];
              sample -= localTopicCounts[topic] /
                (tokensPerTopic[topic] + betaSum);

              if (sample <= 0.0) {
                newTopic = topic;
          else {
            sample -= topicBetaMass;
            sample /= beta;
            newTopic = 0;
            sample -= alpha[newTopic] /
              (tokensPerTopic[newTopic] + betaSum);
            while (sample > 0.0) {
              sample -= alpha[newTopic] /
                (tokensPerTopic[newTopic] + betaSum);
          // Move to the position for the new topic,
          //  which may be the first empty position if this
          //  is a new topic for this word.
          index = 0;
          while (currentTypeTopicCounts[index] > 0 &&
               (currentTypeTopicCounts[index] & topicMask) != newTopic) {
          // index should now be set to the position of the new topic,
          //  which may be an empty cell at the end of the list.
          if (currentTypeTopicCounts[index] == 0) {
            // inserting a new topic, guaranteed to be in
            //  order w.r.t. count, if not topic.
            currentTypeTopicCounts[index] = (1 << topicBits) + newTopic;
          else {
            currentValue = currentTypeTopicCounts[index] >> topicBits;
            currentTypeTopicCounts[index] = ((currentValue + 1) << topicBits) + newTopic;
            // Bubble the increased value left, if necessary
            while (index > 0 &&
                 currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
              int temp = currentTypeTopicCounts[index];
              currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
              currentTypeTopicCounts[index - 1] = temp;
        if (newTopic == -1) {
          System.err.println("PolylingualTopicModel sampling error: "+ origSample + " " + sample + " " + smoothingOnlyMass + " " +
                     topicBetaMass + " " + topicTermMass);
          newTopic = numTopics-1; // TODO is this appropriate
          //throw new IllegalStateException ("PolylingualTopicModel: New topic not sampled.");
        //assert(newTopic != -1);
        //      Put that new topic into the counts
        oneDocTopics[position] = newTopic;
        smoothingOnlyMass -= alpha[newTopic] * beta /
          (tokensPerTopic[newTopic] + betaSum);
        topicBetaMass -= beta * localTopicCounts[newTopic] /
          (tokensPerTopic[newTopic] + betaSum);
        // If this is a new topic for this document,
        //  add the topic to the dense index.
        if (localTopicCounts[newTopic] == 1) {
          // First find the point where we
          //  should insert the new topic by going to
          //  the end (which is the only reason we're keeping
          //  track of the number of non-zero
          //  topics) and working backwards
          denseIndex = nonZeroTopics;
          while (denseIndex > 0 &&
               localTopicIndex[denseIndex - 1] > newTopic) {
            localTopicIndex[denseIndex] =
              localTopicIndex[denseIndex - 1];
          localTopicIndex[denseIndex] = newTopic;
        //  update the coefficients for the non-zero topics
        cachedCoefficients[newTopic] =
          (alpha[newTopic] + localTopicCounts[newTopic]) /
          (tokensPerTopic[newTopic] + betaSum);
        smoothingOnlyMass += alpha[newTopic] * beta /
          (tokensPerTopic[newTopic] + betaSum);
        topicBetaMass += beta * localTopicCounts[newTopic] /
          (tokensPerTopic[newTopic] + betaSum);
        // Save the smoothing-only mass to the global cache
        languageSmoothingOnlyMasses[language] = smoothingOnlyMass;


    if (shouldSaveState) {
      // Update the document-topic count histogram,
      //  for dirichlet estimation

      int totalLength = 0;

      for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
        int topic = localTopicIndex[denseIndex];
        topicDocCounts[topic][ localTopicCounts[topic] ]++;
        totalLength += localTopicCounts[topic];

      docLengthCounts[ totalLength ]++;



  public void printTopWords (File file, int numWords, boolean useNewLines) throws IOException {
    PrintStream out = new PrintStream (file);
    printTopWords(out, numWords, useNewLines);
    public void printTopWords (PrintStream out, int numWords, boolean usingNewLines) {

    TreeSet[][] languageTopicSortedWords = new TreeSet[numLanguages][numTopics];

    for (int language = 0; language < numLanguages; language++) {
      TreeSet[] topicSortedWords = languageTopicSortedWords[language];
      int[][] typeTopicCounts = languageTypeTopicCounts[language];

      for (int topic = 0; topic < numTopics; topic++) {
        topicSortedWords[topic] = new TreeSet<IDSorter>();

      for (int type = 0; type < vocabularySizes[language]; type++) {
        int[] topicCounts = typeTopicCounts[type];
        int index = 0;
        while (index < topicCounts.length &&
             topicCounts[index] > 0) {
          int topic = topicCounts[index] & topicMask;
          int count = topicCounts[index] >> topicBits;
          topicSortedWords[topic].add(new IDSorter(type, count));


        for (int topic = 0; topic < numTopics; topic++) {

      out.println (topic + "\t" + formatter.format(alpha[topic]));
      for (int language = 0; language < numLanguages; language++) {
        out.print(" " + language + "\t" + languageTokensPerTopic[language][topic] + "\t" + betas[language] + "\t");

        TreeSet<IDSorter> sortedWords = languageTopicSortedWords[language][topic];
        Alphabet alphabet = alphabets[language];

        int word = 1;
        Iterator<IDSorter> iterator = sortedWords.iterator();
        while (iterator.hasNext() && word < numWords) {
          IDSorter info =;
          out.print(alphabet.lookupObject(info.getID()) + " ");

  public void printDocumentTopics (File f) throws IOException {
    printDocumentTopics (new PrintWriter (f, "UTF-8") );

  public void printDocumentTopics (PrintWriter pw) {
    printDocumentTopics (pw, 0.0, -1);

  /**
   *  Prints one line per document containing topic/proportion pairs.
   *
   *  @param pw          A print writer
   *  @param threshold   Only print topics with proportion greater than this number
   *  @param max         Print no more than this many topics
   */
  public void printDocumentTopics (PrintWriter pw, double threshold, int max)  {
    pw.print ("#doc source topic proportion ...\n");
    int docLength;
    int[] topicCounts = new int[ numTopics ];

    IDSorter[] sortedTopics = new IDSorter[ numTopics ];
    for (int topic = 0; topic < numTopics; topic++) {
      // Initialize the sorters with dummy values
      sortedTopics[topic] = new IDSorter(topic, topic);

    if (max < 0 || max > numTopics) {
      max = numTopics;

    for (int di = 0; di < data.size(); di++) {

      pw.print (di); pw.print (' ');

      int totalLength = 0;

      for (int language = 0; language < numLanguages; language++) {
        LabelSequence topicSequence = (LabelSequence) data.get(di).topicSequences[language];
        int[] currentDocTopics = topicSequence.getFeatures();
        docLength = topicSequence.getLength();
        totalLength += docLength;
        // Count up the tokens
        for (int token=0; token < docLength; token++) {
          topicCounts[ currentDocTopics[token] ]++;
      // And normalize
      for (int topic = 0; topic < numTopics; topic++) {
        sortedTopics[topic].set(topic, (float) topicCounts[topic] / totalLength);

      for (int i = 0; i < max; i++) {
        if (sortedTopics[i].getWeight() < threshold) { break; }
        pw.print (sortedTopics[i].getID() + " " +
              sortedTopics[i].getWeight() + " ");
      pw.print (" \n");

      Arrays.fill(topicCounts, 0);
  public void printState (File f) throws IOException {
    PrintStream out =
      new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))),
              false, "UTF-8");
  public void printState (PrintStream out) {

    out.println ("#doc lang pos typeindex type topic");

    for (int doc = 0; doc < data.size(); doc++) {
      for (int language =0; language < numLanguages; language++) {
        FeatureSequence tokenSequence =  (FeatureSequence) data.get(doc).instances[language].getData();
        LabelSequence topicSequence =  (LabelSequence) data.get(doc).topicSequences[language];
        for (int pi = 0; pi < topicSequence.getLength(); pi++) {
          int type = tokenSequence.getIndexAtPosition(pi);
          int topic = topicSequence.getIndexAtPosition(pi);
          out.print(doc); out.print(' ');
          out.print(language); out.print(' ');
          out.print(pi); out.print(' ');
          out.print(type); out.print(' ');
          out.print(alphabets[language].lookupObject(type)); out.print(' ');
          out.print(topic); out.println();

  public double modelLogLikelihood() {
    double logLikelihood = 0.0;
    int nonZeroTopics;

    // The likelihood of the model is a combination of a
    // Dirichlet-multinomial for the words in each topic
    // and a Dirichlet-multinomial for the topics in each
    // document.

    // The likelihood function of a dirichlet multinomial is
    //   Gamma( sum_i alpha_i )   prod_i Gamma( alpha_i + N_i )
    //  prod_i Gamma( alpha_i )    Gamma( sum_i (alpha_i + N_i) )

    // So the log likelihood is
    //  logGamma ( sum_i alpha_i ) - logGamma ( sum_i (alpha_i + N_i) ) +
    //   sum_i [ logGamma( alpha_i + N_i) - logGamma( alpha_i ) ]

    // Do the documents first

    int[] topicCounts = new int[numTopics];
    double[] topicLogGammas = new double[numTopics];
    int[] docTopics;

    for (int topic=0; topic < numTopics; topic++) {
      topicLogGammas[ topic ] = Dirichlet.logGammaStirling( alpha[topic] );
    for (int doc=0; doc < data.size(); doc++) {

      int totalLength = 0;

            for (int language = 0; language < numLanguages; language++) {

                LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequences[language];
                int[] currentDocTopics = topicSequence.getFeatures();

        totalLength += topicSequence.getLength();

                // Count up the tokens
                for (int token=0; token < topicSequence.getLength(); token++) {
                    topicCounts[ currentDocTopics[token] ]++;

      for (int topic=0; topic < numTopics; topic++) {
        if (topicCounts[topic] > 0) {
          logLikelihood += (Dirichlet.logGammaStirling(alpha[topic] + topicCounts[topic]) -
                    topicLogGammas[ topic ]);

      // subtract the (count + parameter) sum term
      logLikelihood -= Dirichlet.logGammaStirling(alphaSum + totalLength);

      Arrays.fill(topicCounts, 0);
    // add the parameter sum term
    logLikelihood += data.size() * Dirichlet.logGammaStirling(alphaSum);

    // And the topics

    for (int language = 0; language < numLanguages; language++) {
      int[][] typeTopicCounts = languageTypeTopicCounts[language];
      int[] tokensPerTopic = languageTokensPerTopic[language];
      double beta = betas[language];

      // Count the number of type-topic pairs
      int nonZeroTypeTopics = 0;
      for (int type=0; type < vocabularySizes[language]; type++) {
        // reuse this array as a pointer
        topicCounts = typeTopicCounts[type];
        int index = 0;
        while (index < topicCounts.length &&
             topicCounts[index] > 0) {
          int topic = topicCounts[index] & topicMask;
          int count = topicCounts[index] >> topicBits;
          logLikelihood += Dirichlet.logGammaStirling(beta + count);
          if (Double.isNaN(logLikelihood)) {
      for (int topic=0; topic < numTopics; topic++) {
        logLikelihood -=
          Dirichlet.logGammaStirling( (beta * numTopics) +
                        tokensPerTopic[ topic ] );
        if (Double.isNaN(logLikelihood)) {
          System.out.println("after topic " + topic + " " + tokensPerTopic[ topic ]);
      logLikelihood +=
        (Dirichlet.logGammaStirling(beta * numTopics)) -
        (Dirichlet.logGammaStirling(beta) * nonZeroTypeTopics);

    if (Double.isNaN(logLikelihood)) {
      System.out.println("at the end");

    return logLikelihood;
    /** Return a tool for estimating topic distributions for new documents */
    public TopicInferencer getInferencer(int language) {
    return new TopicInferencer(languageTypeTopicCounts[language], languageTokensPerTopic[language],
                   alpha, betas[language], betaSums[language]);

  // Serialization

  // Class version identifier used by Java serialization.
  private static final long serialVersionUID = 1;
  // Format version: the first int written by writeObject and read by readObject.
  private static final int CURRENT_SERIAL_VERSION = 0;
  // NOTE(review): not referenced in the visible part of this file; presumably a
  //  sentinel for "no value" in serialized data -- confirm before removing.
  private static final int NULL_INTEGER = -1;

  private void writeObject (ObjectOutputStream out) throws IOException {
    out.writeInt (CURRENT_SERIAL_VERSION);















  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
    int version = in.readInt();

    numLanguages = in.readInt();
    data = (ArrayList<TopicAssignment>) in.readObject ();
    topicAlphabet = (LabelAlphabet) in.readObject();
    numTopics = in.readInt();
    testingIDs = (HashSet<String>) in.readObject();

    topicMask = in.readInt();
    topicBits = in.readInt();
    alphabets = (Alphabet[]) in.readObject();
    vocabularySizes = (int[]) in.readObject();
    alpha = (double[]) in.readObject();
    alphaSum = in.readDouble();
    betas = (double[]) in.readObject();
    betaSums = (double[]) in.readObject();

    languageMaxTypeCounts = (int[]) in.readObject();
    languageTypeTopicCounts = (int[][][]) in.readObject();
    languageTokensPerTopic = (int[][]) in.readObject();
    languageSmoothingOnlyMasses = (double[]) in.readObject();
    languageCachedCoefficients = (double[][]) in.readObject();

    docLengthCounts = (int[]) in.readObject();
    topicDocCounts = (int[][]) in.readObject();
    numIterations = in.readInt();
    burninPeriod = in.readInt();
    saveSampleInterval = in.readInt();
    optimizeInterval = in.readInt();
    showTopicsInterval = in.readInt();
    wordsPerTopic = in.readInt();

    saveStateInterval = in.readInt();
    stateFilename = (String) in.readObject();
    saveModelInterval = in.readInt();
    modelFilename = (String) in.readObject();
    random = (Randoms) in.readObject();
    formatter = (NumberFormat) in.readObject();
    printLogLikelihood = in.readBoolean();


  public void write (File serializedModelFile) {
    try {
      ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream(serializedModelFile));
    } catch (IOException e) {
      System.err.println("Problem serializing PolylingualTopicModel to file " +
                 serializedModelFile + ": " + e);

  public static PolylingualTopicModel read (File f) throws Exception {

    PolylingualTopicModel topicModel = null;

    ObjectInputStream ois = new ObjectInputStream (new FileInputStream(f));
    topicModel = (PolylingualTopicModel) ois.readObject();


    return topicModel;

  public static void main (String[] args) throws IOException {

    CommandOption.setSummary (PolylingualTopicModel.class,
                  "A tool for estimating, saving and printing diagnostics for topic models over comparable corpora.");
    CommandOption.process (PolylingualTopicModel.class, args);

    PolylingualTopicModel topicModel = null;

    if (inputModelFilename.value != null) {

      try {
        topicModel = File(inputModelFilename.value));
      } catch (Exception e) {
        System.err.println("Unable to restore saved topic model " +
                   inputModelFilename.value + ": " + e);
    else {

      int numLanguages = languageInputFiles.value.length;
      InstanceList[] training = new InstanceList[ numLanguages ];
      for (int i=0; i < training.length; i++) {
        training[i] = InstanceList.load(new File(languageInputFiles.value[i]));
        if (training[i] != null) { System.out.println(i + " is not null"); }
        else { System.out.println(i + " is null"); }

      System.out.println ("Data loaded.");
      // For historical reasons we currently only support FeatureSequence data,
      //  not the FeatureVector, which is the default for the input functions.
      //  Provide a warning to avoid ClassCastExceptions.
      if (training[0].size() > 0 &&
        training[0].get(0) != null) {
        Object data = training[0].get(0).getData();
        if (! (data instanceof FeatureSequence)) {
          System.err.println("Topic modeling currently only supports feature sequences: use --keep-sequence option when importing data.");
      topicModel = new PolylingualTopicModel (numTopicsOption.value, alphaOption.value);
      if (randomSeedOption.value != 0) {

    topicModel.setTopicDisplay(showTopicsIntervalOption.value, topWordsOption.value);


    if (outputStateIntervalOption.value != 0) {
      topicModel.setSaveState(outputStateIntervalOption.value, stateFile.value);

    if (outputModelIntervalOption.value != 0) {
      topicModel.setModelOutput(outputModelIntervalOption.value, outputModelFilename.value);


    if (topicKeysFile.value != null) {
      topicModel.printTopWords(new File(topicKeysFile.value), topWordsOption.value, false);

    if (stateFile.value != null) {
      topicModel.printState (new File(stateFile.value));

    if (docTopicsFile.value != null) {
      PrintWriter out = new PrintWriter (new FileWriter ((new File(docTopicsFile.value))));
      topicModel.printDocumentTopics(out, docTopicsThreshold.value, docTopicsMax.value);

    if (inferencerFilename.value != null) {
      try {
        for (int language = 0; language < topicModel.numLanguages; language++) {

          ObjectOutputStream oos =
            new ObjectOutputStream(new FileOutputStream(inferencerFilename.value + "." + language));

      } catch (Exception e) {



    if (outputModelFilename.value != null) {
      assert (topicModel != null);
      topicModel.write(new File(outputModelFilename.value));


