// Register a one-line usage summary for this tool, then parse the
// command-line arguments into the static CommandOption fields.
CommandOption.setSummary (Vectors2Topics.class,
"A tool for estimating, saving and printing diagnostics for topic models, such as LDA.");
CommandOption.process (Vectors2Topics.class, args);
if (usePAM.value) {
    // Pachinko Allocation Model branch: train a PAM4L model on the input data.
    InstanceList ilist = InstanceList.load (new File(inputFile.value));
    System.out.println ("Data loaded.");
    if (inputModelFilename.value != null)
        throw new IllegalArgumentException ("--input-model not supported with --use-pam.");
    PAM4L pam = new PAM4L(pamNumSupertopics.value, pamNumSubtopics.value);
    pam.estimate (ilist, numIterations.value, /*optimizeModelInterval*/50,
                  showTopicsInterval.value,
                  outputModelInterval.value, outputModelFilename.value,
                  randomSeed.value == 0 ? new Randoms() : new Randoms(randomSeed.value));
    pam.printTopWords(topWords.value, true);
    if (stateFile.value != null)
        pam.printState (new File(stateFile.value));
    if (docTopicsFile.value != null) {
        // try-with-resources guarantees the writer is closed (and flushed)
        // even if printDocumentTopics throws.
        try (PrintWriter out = new PrintWriter (new FileWriter (new File(docTopicsFile.value)))) {
            pam.printDocumentTopics (out, docTopicsThreshold.value, docTopicsMax.value);
        }
    }
    if (outputModelFilename.value != null) {
        // Serialize the trained model; close the stream even on failure.
        try (ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream (outputModelFilename.value))) {
            oos.writeObject (pam);
        } catch (Exception e) {
            // Chain the original failure as the cause instead of discarding it.
            throw new IllegalArgumentException ("Couldn't write topic model to filename "+outputModelFilename.value, e);
        }
    }
}
else if (useNgrams.value) {
    // Topical n-grams branch: train a TopicalNGrams model on the input data.
    InstanceList ilist = InstanceList.load (new File(inputFile.value));
    System.out.println ("Data loaded.");
    if (inputModelFilename.value != null)
        throw new IllegalArgumentException ("--input-model not supported with --use-ngrams.");
    TopicalNGrams tng = new TopicalNGrams(numTopics.value,
                                          alpha.value,
                                          beta.value,
                                          gamma.value,
                                          delta.value,
                                          delta1.value,
                                          delta2.value);
    tng.estimate (ilist, numIterations.value, showTopicsInterval.value,
                  outputModelInterval.value, outputModelFilename.value,
                  randomSeed.value == 0 ? new Randoms() : new Randoms(randomSeed.value));
    tng.printTopWords(topWords.value, true);
    if (stateFile.value != null)
        tng.printState (new File(stateFile.value));
    if (docTopicsFile.value != null) {
        // try-with-resources guarantees the writer is closed (and flushed)
        // even if printDocumentTopics throws.
        try (PrintWriter out = new PrintWriter (new FileWriter (new File(docTopicsFile.value)))) {
            tng.printDocumentTopics (out, docTopicsThreshold.value, docTopicsMax.value);
        }
    }
    if (outputModelFilename.value != null) {
        // Serialize the trained model; close the stream even on failure.
        try (ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream (outputModelFilename.value))) {
            oos.writeObject (tng);
        } catch (Exception e) {
            // Chain the original failure as the cause instead of discarding it.
            throw new IllegalArgumentException ("Couldn't write topic model to filename "+outputModelFilename.value, e);
        }
    }
}
else if (languageInputFiles.value != null) {
    // Start a new polylingual topic model: one InstanceList per language.
    int numLanguages = languageInputFiles.value.length;
    InstanceList[] training = new InstanceList[numLanguages];
    for (int i = 0; i < training.length; i++) {
        training[i] = InstanceList.load(new File(languageInputFiles.value[i]));
        if (training[i] != null) { System.out.println(i + " is not null"); }
        else { System.out.println(i + " is null"); }
    }
    System.out.println ("Data loaded.");

    // For historical reasons we currently only support FeatureSequence data,
    // not the FeatureVector, which is the default for the input functions.
    // Provide a warning to avoid ClassCastExceptions.
    if (training[0].size() > 0 &&
        training[0].get(0) != null) {
        Object data = training[0].get(0).getData();
        if (! (data instanceof FeatureSequence)) {
            System.err.println("Topic modeling currently only supports feature sequences: use --keep-sequence option when importing data.");
            System.exit(1);
        }
    }

    PolylingualTopicModel topicModel = new PolylingualTopicModel (numTopics.value, alpha.value);
    if (randomSeed.value != 0) {
        topicModel.setRandomSeed(randomSeed.value);
    }
    topicModel.addInstances(training);
    topicModel.setTopicDisplay(showTopicsInterval.value, topWords.value);
    topicModel.setNumIterations(numIterations.value);
    topicModel.setOptimizeInterval(optimizeInterval.value);
    topicModel.setBurninPeriod(optimizeBurnIn.value);
    if (outputStateInterval.value != 0) {
        topicModel.setSaveState(outputStateInterval.value, stateFile.value);
    }
    if (outputModelInterval.value != 0) {
        topicModel.setModelOutput(outputModelInterval.value, outputModelFilename.value);
    }

    topicModel.estimate();

    if (topicKeysFile.value != null) {
        topicModel.printTopWords(new File(topicKeysFile.value), topWords.value, false);
    }
    if (stateFile.value != null) {
        topicModel.printState (new File(stateFile.value));
    }
    if (docTopicsFile.value != null) {
        // try-with-resources guarantees the writer is closed (and flushed)
        // even if printDocumentTopics throws.
        try (PrintWriter out = new PrintWriter (new FileWriter (new File(docTopicsFile.value)))) {
            topicModel.printDocumentTopics(out, docTopicsThreshold.value, docTopicsMax.value);
        }
    }
    if (inferencerFilename.value != null) {
        // Save one serialized inferencer per language ("<name>.<language index>").
        try {
            for (int language = 0; language < numLanguages; language++) {
                try (ObjectOutputStream oos =
                         new ObjectOutputStream(new FileOutputStream(inferencerFilename.value + "." + language))) {
                    oos.writeObject(topicModel.getInferencer(language));
                }
            }
        } catch (Exception e) {
            // Report the full stack trace; e.getMessage() alone can be null.
            e.printStackTrace();
        }
    }
    if (outputModelFilename.value != null) {
        // Serialize the trained model; close the stream even on failure.
        try (ObjectOutputStream oos =
                 new ObjectOutputStream (new FileOutputStream (outputModelFilename.value))) {
            oos.writeObject (topicModel);
        } catch (Exception e) {
            // Chain the original failure as the cause instead of discarding it.
            throw new IllegalArgumentException ("Couldn't write topic model to filename "+outputModelFilename.value, e);
        }
    }
}
else {
// Default branch: estimate a standard (Parallel)LDA topic model.
// Start a new LDA topic model
ParallelTopicModel topicModel = null;
if (inputModelFilename.value != null) {
// Resume from a previously serialized model rather than training from scratch.
try {
topicModel = ParallelTopicModel.read(new File(inputModelFilename.value));
} catch (Exception e) {
System.err.println("Unable to restore saved topic model " +
inputModelFilename.value + ": " + e);
System.exit(1);
}
/*
// Loading new data is optional if we are restoring a saved state.
if (inputFile.value != null) {
InstanceList instances = InstanceList.load (new File(inputFile.value));
System.out.println ("Data loaded.");
lda.addInstances(instances);
}
*/
}
else {
// Fresh model: load training instances, either from a database
// (inputs of the form "db:<name>") or from a serialized InstanceList file.
InstanceList training = null;
try {
if (inputFile.value.startsWith("db:")) {
training = DBInstanceIterator.getInstances(inputFile.value.substring(3));
}
else {
training = InstanceList.load (new File(inputFile.value));
}
} catch (Exception e) {
System.err.println("Unable to restore instance list " +
inputFile.value + ": " + e);
System.exit(1);
}
System.out.println ("Data loaded.");
// Guard against FeatureVector data (the import default): topic modeling
// requires FeatureSequence instances, so fail fast with a clear message
// instead of a ClassCastException later.
if (training.size() > 0 &&
training.get(0) != null) {
Object data = training.get(0).getData();
if (! (data instanceof FeatureSequence)) {
System.err.println("Topic modeling currently only supports feature sequences: use --keep-sequence option when importing data.");
System.exit(1);
}
}