// We use a quasi-Newton (QN) minimizer on the model's cost function.
// To set up that minimization, we flatten all of the matrices in the
// DVModel into one big parameter vector, Theta, which is the set of
// variables the QN minimizer optimizes.
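// (The flattening itself happens inside the batch step: roughly,
// double[] theta = dvModel.paramsToVector(); run the minimizer on
// theta; then dvModel.vectorToParams(theta) to write the result back.
// Treat this as a sketch of the flow, not the exact call sequence.)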
Timing timing = new Timing();
long maxTrainTimeMillis = op.trainOptions.maxTrainTimeSeconds * 1000L; // long literal avoids int overflow
int batchCount = 0;
int debugCycle = 0;
double bestLabelF1 = 0.0;
if (op.trainOptions.useContextWords) {
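// Context-word features index words by sentence position, so each
// tree node needs a CoreLabel and its (start, end) word span.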
for (Tree tree : sentences) {
Trees.convertToCoreLabels(tree);
tree.setSpans();
}
}
// AdaGrad state: one accumulated squared-gradient entry per model parameter
double[] sumGradSquare = new double[dvModel.totalParamSize()];
Arrays.fill(sumGradSquare, 1.0);
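// (AdaGrad update, sketched: for each parameter i,
//    sumGradSquare[i] += grad[i] * grad[i];
//    theta[i] -= learningRate * grad[i] / Math.sqrt(sumGradSquare[i]);
// starting the accumulator at 1.0 instead of 0.0 keeps the divisor
// away from zero on the first update.)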
// Floor division: leftover trees at the end of the list are folded
// into the final batch (see the endTree adjustment below) rather than
// given an extra batch of their own; max(1, ...) covers the case of
// fewer trees than a single batch.
int numBatches = Math.max(1, sentences.size() / op.trainOptions.batchSize);
System.err.println("Training on " + sentences.size() + " trees in " + numBatches + " batches");
System.err.println("Times through each training batch: " + op.trainOptions.trainingIterations);
System.err.println("QN iterations per batch: " + op.trainOptions.qnIterationsPerBatch);
for (int iter = 0; iter < op.trainOptions.trainingIterations; ++iter) {
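// Reshuffle a copy of the training trees each pass; reusing the
// model's own Random keeps runs reproducible for a fixed seed.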
List<Tree> shuffledSentences = new ArrayList<Tree>(sentences);
Collections.shuffle(shuffledSentences, dvModel.rand);
for (int batch = 0; batch < numBatches; ++batch) {
++batchCount;
// This did not help performance
//System.err.println("Setting AdaGrad's sum of squares to 1...");
//Arrays.fill(sumGradSquare, 1.0);
System.err.println("======================================");
System.err.println("Iteration " + iter + " batch " + batch);
// Each batch will be of the specified batch size, except the
// last batch will include any leftover trees at the end of
// the list
int startTree = batch * op.trainOptions.batchSize;
int endTree = (batch + 1) * op.trainOptions.batchSize;
if (endTree + op.trainOptions.batchSize > shuffledSentences.size()) {
endTree = shuffledSentences.size();
}
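// One optimization pass over this batch; sumGradSquare threads the
// AdaGrad accumulator through every batch and iteration.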
executeOneTrainingBatch(shuffledSentences.subList(startTree, endTree), compressedParses, sumGradSquare);
long totalElapsed = timing.report();
System.err.println("Finished iteration " + iter + " batch " + batch + "; total training time " + totalElapsed + " ms");
if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
// Time limit hit: skip debug output and break out of the batch
// loop; the matching check after the loop stops the iterations.
break;
}
if (op.trainOptions.debugOutputFrequency > 0 && batchCount % op.trainOptions.debugOutputFrequency == 0) {
System.err.println("Finished " + batchCount + " total batches, running evaluation cycle");
// Time for debugging output!
double tagF1 = 0.0;
double labelF1 = 0.0;
if (testTreebank != null) {
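// Wrap the current DVModel in a full LexicalizedParser and score
// it on the held-out treebank.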
EvaluateTreebank evaluator = new EvaluateTreebank(attachModelToLexicalizedParser());
evaluator.testOnTreebank(testTreebank);
labelF1 = evaluator.getLBScore();
tagF1 = evaluator.getTagScore();
if (labelF1 > bestLabelF1) {
bestLabelF1 = labelF1;
}
System.err.println("Best label f1 on dev set so far: " + NF.format(bestLabelF1));
}
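// Checkpoint the model, tagging the filename with the debug cycle
// and the current label F1 so checkpoints can be told apart.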
String tempName = null;
if (modelPath != null) {
tempName = modelPath;
if (modelPath.endsWith(".ser.gz")) {
tempName = modelPath.substring(0, modelPath.length() - 7) + "-" + FILENAME.format(debugCycle) + "-" + NF.format(labelF1) + ".ser.gz";
}
saveModel(tempName);
}
String statusLine = ("CHECKPOINT:" +
" iteration " + iter +
" batch " + batch +
" labelF1 " + NF.format(labelF1) +
" tagF1 " + NF.format(tagF1) +
" bestLabelF1 " + NF.format(bestLabelF1) +
" model " + tempName +
" " + op.trainOptions +
" word vectors: " + op.lexOptions.wordVectorFile +
" numHid: " + op.lexOptions.numHid);
System.err.println(statusLine);
if (resultsRecordPath != null) {
// Try-with-resources closes the record file even if a write throws
try (FileWriter fout = new FileWriter(resultsRecordPath, true)) { // append
fout.write(statusLine);
fout.write("\n");
}
}
++debugCycle;
}
}
long totalElapsed = timing.report();
if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
// Mirrors the time check inside the batch loop: that break only
// exits the inner loop, so check again here to end training.
System.err.println("Max training time exceeded, exiting");
break;