model.vectorToParams(theta);
}
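/**
 * Trains the sentiment model with minibatch AdaGrad.  Runs for
 * trainOptions.epochs passes over the shuffled training trees, or until
 * trainOptions.maxTrainTimeSeconds is exceeded.  If devTrees is non-null the
 * model is periodically scored on it, and if modelPath is non-null
 * intermediate models are saved to filenames derived from it.
 */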
public static void train(SentimentModel model, String modelPath, List<Tree> trainingTrees, List<Tree> devTrees) {
Timing timing = new Timing();
long maxTrainTimeMillis = model.op.trainOptions.maxTrainTimeSeconds * 1000;
int debugCycle = 0;
double bestAccuracy = 0.0;
// train using AdaGrad (seemed to work best during the dvparser project)
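// sumGradSquare accumulates the squared gradient for every parameter; AdaGrad
// scales each step by the inverse square root of this accumulator, so a
// nonzero initialAdagradWeight avoids dividing by a tiny value on the first updates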
double[] sumGradSquare = new double[model.totalParamSize()];
Arrays.fill(sumGradSquare, model.op.trainOptions.initialAdagradWeight);
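// integer division rounds down, so the +1 leaves room for a final partial
// batch holding any leftover trees (skipped below if the division is exact)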
int numBatches = trainingTrees.size() / model.op.trainOptions.batchSize + 1;
System.err.println("Training on " + trainingTrees.size() + " trees in " + numBatches + " batches");
System.err.println("Times through each training batch: " + model.op.trainOptions.epochs);
for (int epoch = 0; epoch < model.op.trainOptions.epochs; ++epoch) {
System.err.println("======================================");
System.err.println("Starting epoch " + epoch);
if (epoch > 0 && model.op.trainOptions.adagradResetFrequency > 0 &&
(epoch % model.op.trainOptions.adagradResetFrequency == 0)) {
System.err.println("Resetting adagrad weights to " + model.op.trainOptions.initialAdagradWeight);
Arrays.fill(sumGradSquare, model.op.trainOptions.initialAdagradWeight);
}
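// Shuffle a copy of the training trees each epoch so batch composition
// changes between passes; using the model's own Random keeps the shuffle
// reproducible for a given seed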
List<Tree> shuffledSentences = Generics.newArrayList(trainingTrees);
Collections.shuffle(shuffledSentences, model.rand);
for (int batch = 0; batch < numBatches; ++batch) {
System.err.println("======================================");
System.err.println("Epoch " + epoch + " batch " + batch);
// Each batch will be of the specified batch size, except the
// last batch will include any leftover trees at the end of
// the list
int startTree = batch * model.op.trainOptions.batchSize;
int endTree = (batch + 1) * model.op.trainOptions.batchSize;
if (endTree > shuffledSentences.size()) {
endTree = shuffledSentences.size();
}
if (startTree >= endTree) {
// the final batch is empty when the tree count is an exact multiple of the batch size
continue;
}
executeOneTrainingBatch(model, shuffledSentences.subList(startTree, endTree), sumGradSquare);
long totalElapsed = timing.report();
System.err.println("Finished epoch " + epoch + " batch " + batch + "; total training time " + totalElapsed + " ms");
if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
// no need to debug output, we're done now
break;
}
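// Periodic progress check: at the start of every debugOutputEpochs-th epoch,
// score the current model on the dev trees (exact per-node label accuracy,
// as a percentage) and save an intermediate copy of the model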
if (batch == 0 && epoch > 0 && epoch % model.op.trainOptions.debugOutputEpochs == 0) {
double score = 0.0;
if (devTrees != null) {
Evaluate eval = new Evaluate(model);
eval.eval(devTrees);
eval.printSummary();
score = eval.exactNodeAccuracy() * 100.0;
}
// output an intermediate model
if (modelPath != null) {
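// Splice the debug cycle number and the dev score into the filename,
// keeping the original .ser.gz or .gz extension if there is one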
String tempPath = modelPath;
if (modelPath.endsWith(".ser.gz")) {
tempPath = modelPath.substring(0, modelPath.length() - 7) + "-" + FILENAME.format(debugCycle) + "-" + NF.format(score) + ".ser.gz";
} else if (modelPath.endsWith(".gz")) {
tempPath = modelPath.substring(0, modelPath.length() - 3) + "-" + FILENAME.format(debugCycle) + "-" + NF.format(score) + ".gz";
} else {
tempPath = modelPath + "-" + FILENAME.format(debugCycle) + "-" + NF.format(score);
}
model.saveSerialized(tempPath);
}
++debugCycle;
}
}
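// The break above only leaves the batch loop, so re-check the time limit
// here to stop the epoch loop as well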
long totalElapsed = timing.report();
if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
System.err.println("Max training time exceeded, exiting");
break;
}