Package cc.mallet.topics.tui

Source Code of cc.mallet.topics.tui.TopicTrain

package cc.mallet.topics.tui;

import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;

import cc.mallet.topics.LDAStream;
import cc.mallet.types.InstanceList;
import cc.mallet.util.CommandOption;

// This TUI first trains a topic model and then saves the model to modelFile.
public class TopicTrain {

  /**
   * @param args
   */
  //inputFile for training sequence, testFile for testing sequence, the two should share the same Alphabets
  static CommandOption.String inputFile = new CommandOption.String
  (TopicTrain.class, "input", "FILENAME", true, null,
   "The filename from which to read the list of training instances.  Use - for stdin.  " +
   "The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null);

  static CommandOption.String testFile = new CommandOption.String
  (TopicTrain.class, "test", "FILENAME", true, null,
   "The filename from which to read the list of training instances.  Use - for stdin.  " +
   "The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null);

    static CommandOption.String outputModelFilename = new CommandOption.String
  (TopicTrain.class, "output-model", "FILENAME", true, null,
   "The filename in which to write the binary topic model at the end of the iterations.  " +
   "By default this is null, indicating that no file will be written.", null);

    static CommandOption.String inputModelFilename = new CommandOption.String
  (TopicTrain.class, "input-model", "FILENAME", true, null,
   "The filename from which to read the binary topic model to which the --input will be appended, " +
   "allowing incremental training.  " +
   "By default this is null, indicating that no file will be read.", null);

    static CommandOption.String stateFile = new CommandOption.String
  (TopicTrain.class, "output-state", "FILENAME", true, null,
   "The filename in which to write the Gibbs sampling state after at the end of the iterations.  " +
   "By default this is null, indicating that no file will be written.", null);

    static CommandOption.String stateTestFile = new CommandOption.String
  (TopicTrain.class, "output-state-test", "FILENAME", true, null,
   "The filename in which to write the Gibbs sampling state for test after at the end of the iterations.  " +
   "By default this is null, indicating that no file will be written.", null);

    static CommandOption.String topicKeysFile = new CommandOption.String
  (TopicTrain.class, "output-topic-keys", "FILENAME", true, null,
     "The filename in which to write the top words for each topic and any Dirichlet parameters.  " +
   "By default this is null, indicating that no file will be written.", null);

    static CommandOption.String topicTypesFile = new CommandOption.String
  (TopicTrain.class, "output-type-topics", "FILENAME", true, null,
     "The filename in which to write the matrix of phi, the probability of each type/word belonging to each topic " +
   "By default this is null, indicating that no file will be written.", null);

    static CommandOption.String docTopicsFile = new CommandOption.String
  (TopicTrain.class, "output-doc-topics", "FILENAME", true, null,
   "The filename in which to write the topic proportions per document, at the end of the iterations.  " +
   "By default this is null, indicating that no file will be written.", null);

    static CommandOption.String testDocTopicsFile = new CommandOption.String
  (TopicTrain.class, "output-testdoc-topics", "inf.theta", true, null,
   "The filename in which to write the topic proportions per test document, at the end of the iterations.  " +
   "By default this is null, indicating that no file will be written.", null);

    static CommandOption.String serialTestDocTopicsFile = new CommandOption.String
  (TopicTrain.class, "output-testdoc-serial", "inf.theta", true, null,
   "The filename in which to write the topic proportions per test document, at the end of the iterations.  " +
   "By default this is null, indicating that no file will be written.", null);

    static CommandOption.Double docTopicsThreshold = new CommandOption.Double
  (TopicTrain.class, "doc-topics-threshold", "DECIMAL", true, 0.0,
   "When writing topic proportions per document with --output-doc-topics, " +
   "do not print topics with proportions less than this threshold value.", null);

    static CommandOption.Integer docTopicsMax = new CommandOption.Integer
  (TopicTrain.class, "doc-topics-max", "INTEGER", true, -1,
   "When writing topic proportions per document with --output-doc-topics, " +
   "do not print more than INTEGER number of topics.  "+
   "A negative value indicates that all topics should be printed.", null);

    static CommandOption.Integer numTopics = new CommandOption.Integer
  (TopicTrain.class, "num-topics", "INTEGER", true, 10,
   "The number of topics to fit.", null);

    static CommandOption.Integer numIterations = new CommandOption.Integer
  (TopicTrain.class, "num-iterations", "INTEGER", true, 1000,
   "The number of iterations of Gibbs sampling.", null);

    static CommandOption.Integer infIterations = new CommandOption.Integer
  (TopicTrain.class, "inf-iterations", "INTEGER", true, 1000,
   "The number of iterations of Gibbs sampling.", null);

    //you should design this argument to call inferenceOneByOne
    static CommandOption.Integer infOneByOne = new CommandOption.Integer
  (TopicTrain.class, "inf-onebyone", "INTEGER", true, 0,
   "The number of iterations of Gibbs sampling.", null);

    //you must design this argument to do inferenceAll, usually it is not needed
    static CommandOption.Integer infAllIterations = new CommandOption.Integer
  (TopicTrain.class, "inf-all-iterations", "INTEGER", true, 0,
   "The number of inference iterations of Gibbs sampling.", null);


    static CommandOption.Integer randomSeed = new CommandOption.Integer
  (TopicTrain.class, "random-seed", "INTEGER", true, 0,
   "The random seed for the Gibbs sampler.  Default is 0, which will use the clock.", null);

    static CommandOption.Integer topWords = new CommandOption.Integer
  (TopicTrain.class, "num-top-words", "INTEGER", true, 20,
   "The number of most probable words to print for each topic after model estimation.", null);

    static CommandOption.Integer showTopicsInterval = new CommandOption.Integer
  (TopicTrain.class, "show-topics-interval", "INTEGER", true, 50,
   "The number of iterations between printing a brief summary of the topics so far.", null);

    static CommandOption.Integer outputModelInterval = new CommandOption.Integer
  (TopicTrain.class, "output-model-interval", "INTEGER", true, 0,
   "The number of iterations between writing the model (and its Gibbs sampling state) to a binary file.  " +
   "You must also set the --output-model to use this option, whose argument will be the prefix of the filenames.", null);

    static CommandOption.Integer outputStateInterval = new CommandOption.Integer
    (TopicTrain.class, "output-state-interval", "INTEGER", true, 0,
     "The number of iterations between writing the sampling state to a text file.  " +
     "You must also set the --output-state to use this option, whose argument will be the prefix of the filenames.", null);

    static CommandOption.Double alpha = new CommandOption.Double
  (TopicTrain.class, "alpha", "DECIMAL", true, 50.0,
   "Alpha parameter: smoothing over topic distribution.",null);

    static CommandOption.Double beta = new CommandOption.Double
  (TopicTrain.class, "beta", "DECIMAL", true, 0.01,
   "Beta parameter: smoothing over unigram distribution.",null);


  public static void main(String[] args) throws java.io.IOException
  {
    // TODO Auto-generated method stub
    // Process the command-line options
    CommandOption.setSummary (TopicTrain.class,
                  "A tool for training and test streamline topic model.");
    CommandOption.process (TopicTrain.class, args);

    LDAStream lda = null;

    if (inputFile.value != null) {
      InstanceList instances = InstanceList.load (new File(inputFile.value));
      System.out.println ("Training Data loaded.");
      lda=new LDAStream(numTopics.value, alpha.value, beta.value);
      lda.addInstances(instances);
    }
    if(testFile.value != null) {
      InstanceList testing = InstanceList.load(new File(testFile.value));
      lda.setTestingInstances(testing);
   
   
     lda.setTopicDisplay(showTopicsInterval.value, topWords.value);

    if (outputModelInterval.value != 0) {
      lda.setModelOutput(outputModelInterval.value, outputModelFilename.value);
    }

    lda.setNumIterations(numIterations.value);

        if (randomSeed.value != 0) {
            lda.setRandomSeed(randomSeed.value);
        }

        if (outputStateInterval.value != 0) {
            lda.setSaveState(outputStateInterval.value, stateFile.value);
        }

    lda.estimate();
    //save the model, we need typeTopicCounts and tokensPerTopic for empirical likelihood
    lda.write(new File (inputFile.value + ".model"));

    if (topicKeysFile.value != null) {
      lda.printTopWords(new File(topicKeysFile.value), topWords.value, false);
    }
    if (topicKeysFile.value != null) {
      lda.printTopWords(new File(topicKeysFile.value), topWords.value, false);
    }
    if (topicTypesFile.value != null) {
      lda.printPhi(new File(topicTypesFile.value), 1e-4);
    }
    if (stateFile.value != null) {
      lda.printState (lda.getData(), new File(stateFile.value));
    }
    if (docTopicsFile.value != null) {
      lda.printDocumentTopics(lda.getData(), new PrintWriter (new FileWriter ((new File(docTopicsFile.value)))),
                  docTopicsThreshold.value, docTopicsMax.value);

    }

 
   }

}
TOP

Related Classes of cc.mallet.topics.tui.TopicTrain

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.