Package cc.mallet.topics

Source Code of cc.mallet.topics.DMRTopicModel

package cc.mallet.topics;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.OptimizationException;

import cc.mallet.types.*;
import cc.mallet.classify.MaxEnt;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.Noop;

import gnu.trove.TIntIntHashMap;

import java.io.IOException;
import java.io.PrintStream;
import java.io.File;

public class DMRTopicModel extends LDAHyper {

  MaxEnt dmrParameters = null;
    int numFeatures;
    int defaultFeatureIndex;

    Pipe parameterPipe = null;
 
  double[][] alphaCache;
    double[] alphaSumCache;

  public DMRTopicModel(int numberOfTopics) {
    super(numberOfTopics);
  }

  public void estimate (int iterationsThisRound) throws IOException {

        numFeatures = data.get(0).instance.getTargetAlphabet().size() + 1;
        defaultFeatureIndex = numFeatures - 1;

    int numDocs = data.size(); // TODO consider beginning by sub-sampling?

    alphaCache = new double[numDocs][numTopics];
        alphaSumCache = new double[numDocs];

    long startTime = System.currentTimeMillis();
        int maxIteration = iterationsSoFar + iterationsThisRound;

        for ( ; iterationsSoFar <= maxIteration; iterationsSoFar++) {
            long iterationStart = System.currentTimeMillis();

            if (showTopicsInterval != 0 && iterationsSoFar != 0 && iterationsSoFar % showTopicsInterval == 0) {
                System.out.println();
                printTopWords (System.out, wordsPerTopic, false);
      }

      if (saveStateInterval != 0 && iterationsSoFar % saveStateInterval == 0) {
                this.printState(new File(stateFilename + '.' + iterationsSoFar + ".gz"));
            }

      if (iterationsSoFar > burninPeriod && optimizeInterval != 0 &&
        iterationsSoFar % optimizeInterval == 0) {

        // Train regression parameters
        learnParameters();
      }

      // Loop over every document in the corpus

            for (int doc = 0; doc < numDocs; doc++) {
                FeatureSequence tokenSequence = (FeatureSequence) data.get(doc).instance.getData();
                LabelSequence topicSequence = (LabelSequence) data.get(doc).topicSequence;

        if (dmrParameters != null) {
          // set appropriate Alpha parameters
          setAlphas(data.get(doc).instance);
        }

                sampleTopicsForOneDoc (tokenSequence, topicSequence,
                     false, false);
            }

      long ms = System.currentTimeMillis() - iterationStart;
      if (ms > 1000) {
        System.out.print(Math.round(ms / 1000) + "s ");
      }
      else {
        System.out.print(ms + "ms ");
      }

            if (iterationsSoFar % 10 == 0) {
                System.out.println ("<" + iterationsSoFar + "> ");
                if (printLogLikelihood) System.out.println (modelLogLikelihood());
            }
            System.out.flush();
    }

    long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
        long minutes = seconds / 60;    seconds %= 60;
        long hours = minutes / 60;  minutes %= 60;
        long days = hours / 24; hours %= 24;
        System.out.print ("\nTotal time: ");
        if (days != 0) { System.out.print(days); System.out.print(" days "); }
        if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
        if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
        System.out.print(seconds); System.out.println(" seconds");
  }

  /**
   *  Use only the default features to set the topic prior (use no document features)
    */
  public void setAlphas() {

        double[] parameters = dmrParameters.getParameters();
       
        alphaSum = 0.0;
    smoothingOnlyMass = 0.0;

        // Use only the default features to set the topic prior (use no document features)
        for (int topic=0; topic < numTopics; topic++) {
            alpha[topic] = Math.exp( parameters[ (topic * numFeatures) + defaultFeatureIndex ] );
            alphaSum += alpha[topic];
     
      smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);
      cachedCoefficients[topic] =  alpha[topic] / (tokensPerTopic[topic] + betaSum);
        }

    }

    /** This method sets the alphas for a hypothetical "document" that contains
     *   a single non-default feature.
     */
    public void setAlphas(int featureIndex) {

        double[] parameters = dmrParameters.getParameters();
       
        alphaSum = 0.0;
    smoothingOnlyMass = 0.0;

        // Use only the default features to set the topic prior (use no document features)
        for (int topic=0; topic < numTopics; topic++) {
            alpha[topic] = Math.exp(parameters[ (topic * numFeatures) + featureIndex ]
                                    parameters[ (topic * numFeatures) + defaultFeatureIndex ] );
            alphaSum += alpha[topic];

      smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);
      cachedCoefficients[topic] =  alpha[topic] / (tokensPerTopic[topic] + betaSum);
        }

    }

  /**
   *  Set alpha based on features in an instance
   */
    public void setAlphas(Instance instance) {
   
        // we can't use the standard score functions from MaxEnt,
        //  since our features are currently in the Target.
        FeatureVector features = (FeatureVector) instance.getTarget();
        if (features == null) { setAlphas(); return; }
       
        double[] parameters = dmrParameters.getParameters();
       
        alphaSum = 0.0;
    smoothingOnlyMass = 0.0;
       
        for (int topic = 0; topic < numTopics; topic++) {
            alpha[topic] = parameters[topic*numFeatures + defaultFeatureIndex]
                + MatrixOps.rowDotProduct (parameters,
                                           numFeatures,
                                           topic, features,
                                           defaultFeatureIndex,
                                           null);
           
            alpha[topic] = Math.exp(alpha[topic]);
            alphaSum += alpha[topic];

      smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);
      cachedCoefficients[topic] =  alpha[topic] / (tokensPerTopic[topic] + betaSum);
        }
    }

  public void learnParameters() {

        // Create a "fake" pipe with the features in the data and
        //  a trove int-int hashmap of topic counts in the target.
       
        if (parameterPipe == null) {
            parameterPipe = new Noop();

            parameterPipe.setDataAlphabet(data.get(0).instance.getTargetAlphabet());
            parameterPipe.setTargetAlphabet(topicAlphabet);
        }

        InstanceList parameterInstances = new InstanceList(parameterPipe);

        if (dmrParameters == null) {
            dmrParameters = new MaxEnt(parameterPipe, new double[numFeatures * numTopics]);
        }
       
        for (int doc=0; doc < data.size(); doc++) {
           
            if (data.get(doc).instance.getTarget() == null) {
                continue;
            }

      FeatureCounter counter = new FeatureCounter(topicAlphabet);

      for (int topic : data.get(doc).topicSequence.getFeatures()) {
        counter.increment(topic);
            }

            // Put the real target in the data field, and the
            //  topic counts in the target field
            parameterInstances.add( new Instance(data.get(doc).instance.getTarget(), counter.toFeatureVector(), null, null) );

        }

        DMROptimizable optimizable = new DMROptimizable(parameterInstances, dmrParameters);
        optimizable.setRegularGaussianPriorVariance(0.5);
        optimizable.setInterceptGaussianPriorVariance(100.0);

    LimitedMemoryBFGS optimizer = new LimitedMemoryBFGS(optimizable);

    // Optimize once
    try {
      optimizer.optimize();
    } catch (OptimizationException e) {
      // step size too small
    }

    // Restart with a fresh initialization to improve likelihood
    try {
      optimizer.optimize();
    } catch (OptimizationException e) {
      // step size too small
    }
        dmrParameters = optimizable.getClassifier();

        for (int doc=0; doc < data.size(); doc++) {
            Instance instance = data.get(doc).instance;
            FeatureSequence tokens = (FeatureSequence) instance.getData();
            if (instance.getTarget() == null) { continue; }
            int numTokens = tokens.getLength();

            // This sets alpha[] and alphaSum
            setAlphas(instance);

            // Now cache alpha values
            for (int topic=0; topic < numTopics; topic++) {
                alphaCache[doc][topic] = alpha[topic];
            }
            alphaSumCache[doc] = alphaSum;
        }
    }

  public void printTopWords (PrintStream out, int numWords, boolean usingNewLines) {
    if (dmrParameters != null) { setAlphas(); }
    super.printTopWords(out, numWords, usingNewLines);
  }

  public void writeParameters(File parameterFile) throws IOException {
    if (dmrParameters != null) {
      PrintStream out = new PrintStream(parameterFile);
      dmrParameters.print(out);
      out.close();
    }
  }

    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

  public static void main (String[] args) throws IOException {

        InstanceList training = InstanceList.load (new File(args[0]));

        int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;

        InstanceList testing =
            args.length > 2 ? InstanceList.load (new File(args[2])) : null;

        DMRTopicModel lda = new DMRTopicModel (numTopics);
    lda.setOptimizeInterval(100);
    lda.setTopicDisplay(100, 10);
    lda.addInstances(training);
    lda.estimate();

    lda.writeParameters(new File("dmr.parameters"));
    lda.printState(new File("dmr.state.gz"));
    }
}
TOP

Related Classes of cc.mallet.topics.DMRTopicModel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.