package edu.stanford.nlp.optimization;
import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.Timing;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Random;
/**
 * In-place stochastic gradient descent (SGD) minimizer.
 * <ul>
 * <li> Follows the weight decay and learning rate tuning of Leon Bottou's
 * crfsgd: http://leon.bottou.org/projects/sgd
 * <li> Only supports L2 regularization (QUADRATIC)
 * <li> Requires the objective function to be an AbstractStochasticCachingDiffUpdateFunction.
 * </ul>
 * NOTE: unlike other minimizers, regularization is done in the minimizer, not the objective function.
 *
 * This class was previously called StochasticInPlaceMinimizer; the old SGDMinimizer is now InefficientSGDMinimizer.
*
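* Example usage (a minimal sketch; {@code MyObjective} stands in for a
* user-supplied subclass of AbstractStochasticCachingDiffUpdateFunction):
* <pre>{@code
* AbstractStochasticCachingDiffUpdateFunction objective = new MyObjective(data);
* SGDMinimizer<Function> minimizer = new SGDMinimizer<>(1.0, 50); // sigma = 1.0, 50 passes
* double[] weights = minimizer.minimize(objective, 1e-4, new double[objective.domainDimension()]);
* }</pre>
*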
* @author Angel Chang
*/
public class SGDMinimizer<T extends Function> implements Minimizer<T>, HasEvaluators {
protected double xscale, xnorm; // lazy scale factor folded out of x, and squared L2 norm of the scaled weights
protected double[] x;
protected int t0; // Initial stochastic iteration count
protected final double sigma; // Gaussian prior sigma; lambda = 1/(sigma * numSamples)
protected double lambda; // L2 regularization strength, set in minimize()
protected boolean quiet = false;
private static final int DEFAULT_NUM_PASSES = 50;
protected final int numPasses; // number of passes through the data
protected int bSize = 1; // NOTE: If bSize does not divide evenly into total number of samples,
// some samples may get accounted for twice in one pass
private static final int DEFAULT_TUNING_SAMPLES = 1000;
protected final int tuningSamples;
protected Random gen = new Random(1);
protected long maxTime = Long.MAX_VALUE;
private int evaluateIters = 0; // Evaluate every this many passes (0 = no evaluation)
private Evaluator[] evaluators; // separate set of evaluators to check how optimization is going
public SGDMinimizer(double sigma, int numPasses)
{
this(sigma, numPasses, -1, 1);
}
public SGDMinimizer(double sigma, int numPasses, int tuningSamples) {
this(sigma, numPasses, tuningSamples, 1);
}
public SGDMinimizer(double sigma, int numPasses, int tuningSamples, int batchSize)
{
this.bSize = batchSize;
this.sigma = sigma;
if (numPasses >= 0) {
this.numPasses = numPasses;
} else {
this.numPasses = DEFAULT_NUM_PASSES;
sayln(" SGDMinimizer: numPasses=" + numPasses + ", defaulting to " + this.numPasses);
}
if (tuningSamples > 0) {
this.tuningSamples = tuningSamples;
} else {
this.tuningSamples = DEFAULT_TUNING_SAMPLES;
sayln(" SGDMinimizer: tuneSampleSize=" + tuningSamples + ", defaulting to " + this.tuningSamples);
}
}
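// NOTE: unlike the constructor above, this one takes batchSize before tuningSamples.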
public SGDMinimizer(LogPrior prior, int numPasses, int batchSize, int tuningSamples)
{
if (LogPrior.LogPriorType.QUADRATIC == prior.getType()) {
sigma = prior.getSigma();
} else {
throw new RuntimeException("Unsupported prior type " + prior.getType());
}
if (numPasses >= 0) {
this.numPasses = numPasses;
} else {
this.numPasses = DEFAULT_NUM_PASSES;
sayln(" SGDMinimizer: numPasses=" + numPasses + ", defaulting to " + this.numPasses);
}
this.bSize = batchSize;
if (tuningSamples > 0) {
this.tuningSamples = tuningSamples;
} else {
this.tuningSamples = DEFAULT_TUNING_SAMPLES;
sayln(" SGDMinimizer: tuneSampleSize=" + tuningSamples + ", defaulting to " + this.tuningSamples);
}
}
public void shutUp() {
this.quiet = true;
}
private static final NumberFormat nf = new DecimalFormat("0.000E0");
protected String getName() {
return "SGD_InPlace_b" + bSize + "_lambda" + nf.format(lambda);
}
@Override
public void setEvaluators(int iters, Evaluator[] evaluators)
{
this.evaluateIters = iters;
this.evaluators = evaluators;
}
// Subclasses may override this to initialize state before optimization begins.
protected void init(AbstractStochasticCachingDiffUpdateFunction func) {
}
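/**
 * Returns the L2-regularized objective over the given sample:
 * the function value at (w * wscale) plus 0.5 * sample.length * lambda * ||w * wscale||^2.
 */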
public double getObjective(AbstractStochasticCachingDiffUpdateFunction function, double[] w, double wscale, int[] sample)
{
double wnorm = getNorm(w) * wscale * wscale;
double obj = function.valueAt(w, wscale, sample);
// Calculate objective with L2 regularization
return obj + 0.5 * sample.length * lambda * wnorm;
}
public double tryEta(AbstractStochasticCachingDiffUpdateFunction function, double[] initial, int[] sample, double eta)
{
int numBatches = sample.length / bSize;
double[] w = new double[initial.length];
double wscale = 1;
System.arraycopy(initial, 0, w, 0, w.length);
int[] sampleBatch = new int[bSize];
int sampleIndex = 0;
for (int batch = 0; batch < numBatches; batch++) {
for (int i = 0; i < bSize; i++) {
sampleBatch[i] = sample[(sampleIndex + i) % sample.length];
}
sampleIndex += bSize;
double gain = eta / wscale;
function.calculateStochasticUpdate(w, wscale, sampleBatch, gain);
wscale *= (1 - eta * lambda * bSize);
}
double obj = getObjective(function, w, wscale, sample);
return obj;
}
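/* tryEta above relies on the same lazy-rescaling trick used in minimize():
 * instead of multiplying every weight by the decay factor (1 - eta*lambda*bSize)
 * after each batch, the factor accumulates in the scalar wscale, and gradient
 * steps are divided by wscale to compensate. With true weights v = wscale * w,
 * each batch effectively performs (a sketch, not executable code):
 *
 *   v = (v - eta * gradient) * (1 - eta * lambda * bSize)
 *
 * i.e. a gradient step followed by L2 weight decay, where the decay costs
 * O(1) per batch instead of a full pass over the weight vector.
 */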
/**
 * Finds a good learning rate to start with:
 * eta = 1 / (lambda * (t0 + t)), so we search for a good t0.
 *
 * @param function the objective function
 * @param initial the initial weight vector
 * @param sampleSize number of samples to tune on
 * @param seta starting learning rate to try
 * @return the chosen initial learning rate
 */
public double tune(AbstractStochasticCachingDiffUpdateFunction function, double[] initial, int sampleSize, double seta)
{
Timing timer = new Timing();
int[] sample = function.getSample(sampleSize);
double sobj = getObjective(function, initial, 1, sample);
double besteta = 1;
double bestobj = sobj;
double eta = seta;
int totest = 10;
double factor = 2;
boolean phase2 = false;
while (totest > 0 || !phase2)
{
double obj = tryEta(function, initial, sample, eta);
boolean okay = (obj < sobj);
sayln(" Trying eta=" + eta + " obj=" + obj + ((okay)? "(possible)":"(too large)"));
if (okay)
{
totest -= 1;
if (obj < bestobj) {
bestobj = obj;
besteta = eta;
}
}
if (! phase2)
{
if (okay) {
eta = eta * factor;
} else {
phase2 = true;
eta = seta;
}
}
if (phase2) {
eta = eta / factor;
}
}
// take it on the safe side (implicit regularization)
besteta /= factor;
// determine t
t0 = (int) (1 / (besteta * lambda));
sayln(" Taking eta=" + besteta + " t0=" + t0);
sayln(" Tuning completed in: " + Timing.toSecondsString(timer.report()) + " s");
return besteta;
}
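/* The resulting schedule eta_t = 1/(lambda*(t0+t)) is the standard decaying
 * rate for lambda-strongly-convex SGD, as in Bottou's sgd project. An
 * illustrative calculation (example numbers, not defaults): with sigma=1 and
 * 100000 samples, lambda = 1/(sigma*N) = 1e-5; if tuning picks besteta = 1,
 * then t0 = 1/(besteta*lambda) = 100000, so the first update uses
 * eta = 1/(lambda*t0) = 1 and eta decays roughly as 1/t from there.
 */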
// Note: this is really the squared L2 norm.
private static double getNorm(double[] w)
{
double norm = 0;
for (int i = 0; i < w.length; i++) {
norm += w[i]*w[i];
}
return norm;
}
// Folds the accumulated xscale into x, so that afterwards xscale == 1.
private void rescale()
{
if (xscale == 1) return;
for (int i = 0; i < x.length; i++) {
x[i] *= xscale;
}
xscale = 1;
}
private void doEvaluation(double[] x) {
// Evaluate solution
if (evaluators == null) return;
for (Evaluator eval : evaluators) {
sayln(" Evaluating: " + eval.toString());
eval.evaluate(x);
}
}
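/**
 * Minimizes the given function, which must be an
 * AbstractStochasticCachingDiffUpdateFunction, by in-place SGD with L2
 * weight decay. Note that functionTolerance is not used as a stopping
 * criterion here: optimization runs until numPasses/maxIterations or
 * maxTime is reached.
 */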
@Override
public double[] minimize(Function function, double functionTolerance, double[] initial) {
return minimize(function, functionTolerance, initial, -1);
}
@Override
public double[] minimize(Function f, double functionTolerance, double[] initial, int maxIterations) {
if (!(f instanceof AbstractStochasticCachingDiffUpdateFunction)) {
throw new UnsupportedOperationException();
}
AbstractStochasticCachingDiffUpdateFunction function = (AbstractStochasticCachingDiffUpdateFunction) f;
int totalSamples = function.dataDimension();
int tuneSampleSize = Math.min(totalSamples, tuningSamples);
if (tuneSampleSize < tuningSamples) {
System.err.println("WARNING: Total number of samples=" + totalSamples +
" is smaller than requested tuning sample size=" + tuningSamples + "!!!");
}
lambda = 1.0/(sigma*totalSamples);
sayln("Using sigma=" + sigma + " lambda=" + lambda + " tuning sample size " + tuneSampleSize);
// Learning rate tuning is disabled; t0 is set directly assuming an initial
// learning rate of 0.1, i.e. eta_0 = 1/(lambda*t0) = 0.1.
// tune(function, initial, tuneSampleSize, 0.1);
t0 = (int) (1 / (0.1 * lambda));
x = new double[initial.length];
System.arraycopy(initial, 0, x, 0, x.length);
xscale = 1;
xnorm = getNorm(x);
int numBatches = totalSamples/ bSize;
init(function);
boolean haveMax = (maxIterations > 0 || numPasses > 0);
if (!haveMax) {
throw new UnsupportedOperationException("No maximum number of iterations has been specified.");
} else {
maxIterations = Math.max(maxIterations, numPasses) * numBatches;
}
sayln(" Batch size of: " + bSize);
sayln(" Data dimension of: " + totalSamples );
sayln(" Batches per pass through data: " + numBatches );
sayln(" Number of passes is = " + numPasses);
sayln(" Max iterations is = " + maxIterations);
// Main optimization loop
Timing total = new Timing();
Timing current = new Timing();
total.start();
current.start();
int t = t0;
int iters = 0;
for (int pass = 0; pass < numPasses; pass++) {
boolean doEval = (pass > 0 && evaluateIters > 0 && pass % evaluateIters == 0);
if (doEval) {
rescale();
doEvaluation(x);
}
double totalValue = 0;
double lastValue = 0;
say("Iter: " + iters + " pass " + pass + " batch 1 ... ");
for (int batch = 0; batch < numBatches; batch++) {
iters++;
// Compute the decaying learning rate and apply the stochastic gradient update
double eta = 1/(lambda*t);
double gain = eta/xscale;
lastValue = function.calculateStochasticUpdate(x, xscale, bSize, gain);
totalValue += lastValue;
// weight decay (for L2 regularization)
xscale *= (1 - eta * lambda * bSize);
t += bSize;
}
if (xscale < 1e-6) { // fold xscale into x before it underflows
rescale();
}
try {
ArrayMath.assertFinite(x,"x");
} catch (ArrayMath.InvalidElementException e) {
System.err.println(e.toString());
for (int i = 0; i < x.length; i++) { x[i] = Double.NaN; }
break;
}
xnorm = getNorm(x) * xscale * xscale;
// Calculate loss based on L2 regularization
double loss = totalValue + 0.5 * xnorm * lambda * totalSamples;
say(String.valueOf(numBatches));
say("[" + (total.report() / 1000.0) + " s ");
say("{" + (current.restart() / 1000.0) + " s}] ");
sayln(" " + lastValue + " " + totalValue + " " + loss);
if (iters >= maxIterations) {
sayln("Stochastic Optimization complete. Stopped after max iterations");
break;
}
if (total.report() >= maxTime){
sayln("Stochastic Optimization complete. Stopped after max time");
break;
}
}
rescale();
if (evaluateIters > 0) {
// do final evaluation
doEvaluation(x);
}
sayln("Completed in: " + Timing.toSecondsString(total.report()) + " s");
return x;
}
protected void sayln(String s) {
if (!quiet) {
System.err.println(s);
}
}
protected void say(String s) {
if (!quiet) {
System.err.print(s);
}
}
}