// Package LM
//
// Source Code of LM.LM

// levenberg-marquardt in java
//
// To use this, implement the functions in the LMfunc interface.
//
// This library uses simple matrix routines from the JAMA java matrix package,
// which is in the public domain.  Reference:
//    http://math.nist.gov/javanumerics/jama/
// (JAMA has a matrix object class.  An earlier library JNL, which is no longer
// available, represented matrices as low-level arrays.  Several years
// ago the performance of JNL matrix code was better than that of JAMA,
// though improvements in java compilers may have fixed this by now.)
//
// One further recommendation would be to use an inverse based
// on Choleski decomposition, which is easy to implement and
// suitable for the symmetric inverse required here.  There is a choleski
// routine at idiom.com/~zilla.
//
// If you make an improved version, please consider adding your
// name to it ("modified by ...") and send it back to me
// (and put it on the web).
//
// ----------------------------------------------------------------
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Library General Public
// License as published by the Free Software Foundation; either
// version 2 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Library General Public License for more details.
//
// You should have received a copy of the GNU Library General Public
// License along with this library; if not, write to the
// Free Software Foundation, Inc., 59 Temple Place - Suite 330,
// Boston, MA  02111-1307, USA.
//
// initial author contact info: 
// jplewis  www.idiom.com/~zilla  zilla # computer.org,   #=at
//
// Improvements by:
// dscherba  www.ncsa.uiuc.edu/~dscherba 
// Jonathan Jackson   j.jackson # ucl.ac.uk


package LM;

// see comment above
import Jama.Matrix;

/**
* Levenberg-Marquardt, implemented from the general description
* in Numerical Recipes (NR), then tweaked slightly to mostly
* match the results of their code.
* Use for nonlinear least squares assuming Gaussian errors.
*
* TODO this holds some parameters fixed by simply not updating them.
* This may be ok if the number of fixed parameters is small,
* but if many parameters are held fixed it would
* be more efficient to build a smaller hessian involving only
* the varying parameters.
*
* The NR code assumes a statistical context, e.g. returns
* covariance of parameter errors; we do not do this.
*/
public final class LM
{

  /**
   * calculate the current sum-squared-error
   * (Chi-squared is the distribution of squared Gaussian errors,
   * thus the name)
   */
  static double chiSquared(double[][] x, double[] a, double[] y, double[] s,
         LMfunc f)
  {
    int npts = y.length;
    double sum = 0.;

    for( int i = 0; i < npts; i++ ) {
      double d = y[i] - f.val(x[i], a);
      d = d / s[i];
      sum = sum + (d*d);
    }

    return sum;
  } //chiSquared


  /**
   * Minimize E = sum {(y[k] - f(x[k],a)) / s[k]}^2
   * The individual errors are optionally scaled by s[k].
   * Note that LMfunc implements the value and gradient of f(x,a),
   * NOT the value and gradient of E with respect to a!
   *
   * @param x array of domain points, each may be multidimensional
   * @param y corresponding array of values
   * @param a the parameters/state of the model
   * @param vary false to indicate the corresponding a[k] is to be held fixed
   * @param s2 sigma^2 for point i
   * @param lambda blend between steepest descent (lambda high) and
   *  jump to bottom of quadratic (lambda zero).
   *   Start with 0.001.
   * @param termepsilon termination accuracy (0.01)
   * @param maxiter  stop and return after this many iterations if not done
   * @param verbose  set to zero (no prints), 1, 2
   *
   * @return the new lambda for future iterations.
   *  Can use this and maxiter to interleave the LM descent with some other
   *  task, setting maxiter to something small.
   */
  public static double solve(double[][] x, double[] a, double[] y, double[] s,
           boolean[] vary, LMfunc f,
           double lambda, double termepsilon, int maxiter,
           int verbose)
    throws Exception
  {
    int npts = y.length;
    int nparm = a.length;
    assert s.length == npts;
    assert x.length == npts;
    if (verbose > 0) {
      System.out.print("solve x["+x.length+"]["+x[0].length+"]" );
      System.out.print(" a["+a.length+"]");
      System.out.println(" y["+y.length+"]");
    }

    double e0 = chiSquared(x, a, y, s, f);
    //double lambda = 0.001;
    boolean done = false;

    // g = gradient, H = hessian, d = step to minimum
    // H d = -g, solve for d
    double[][] H = new double[nparm][nparm];
    double[] g = new double[nparm];
    //double[] d = new double[nparm];

    double[] oos2 = new double[s.length];
    for( int i = 0; i < npts; i++ oos2[i] = 1./(s[i]*s[i]);

    int iter = 0;
    int term = 0// termination count test

    do {
      ++iter;

      // hessian approximation
      for( int r = 0; r < nparm; r++ ) {
  for( int c = 0; c < nparm; c++ ) {
    for( int i = 0; i < npts; i++ ) {
      if (i == 0) H[r][c] = 0.;
      double[] xi = x[i];
      H[r][c] += (oos2[i] * f.grad(xi, a, r) * f.grad(xi, a, c));
    //npts
  } //c
      } //r

      // boost diagonal towards gradient descent
      for( int r = 0; r < nparm; r++ )
  H[r][r] *= (1. + lambda);

      // gradient
      for( int r = 0; r < nparm; r++ ) {
  for( int i = 0; i < npts; i++ ) {
    if (i == 0) g[r] = 0.;
    double[] xi = x[i];
    g[r] += (oos2[i] * (y[i]-f.val(xi,a)) * f.grad(xi, a, r));
  }
      } //npts

      // scale (for consistency with NR, not necessary)
      if (false) {
  for( int r = 0; r < nparm; r++ ) {
    g[r] = -0.5 * g[r];
    for( int c = 0; c < nparm; c++ ) {
      H[r][c] *= 0.5;
    }
  }
      }

      // solve H d = -g, evaluate error at new location
      //double[] d = DoubleMatrix.solve(H, g);
      double[] d = (new Matrix(H)).lu().solve(new Matrix(g, nparm)).getRowPackedCopy();
      //double[] na = DoubleVector.add(a, d);
      double[] na = (new Matrix(a, nparm)).plus(new Matrix(d, nparm)).getRowPackedCopy();
      double e1 = chiSquared(x, na, y, s, f);

      if (verbose > 0) {
  System.out.println("\n\niteration "+iter+" lambda = "+lambda);
  System.out.print("a = ");
        (new Matrix(a, nparm)).print(10, 2);
  if (verbose > 1) {
          System.out.print("H = ");
          (new Matrix(H)).print(10, 2);
          System.out.print("g = ");
          (new Matrix(g, nparm)).print(10, 2);
          System.out.print("d = ");
          (new Matrix(d, nparm)).print(10, 2);
  }
  System.out.print("e0 = " + e0 + ": ");
  System.out.print("moved from ");
        (new Matrix(a, nparm)).print(10, 2);
  System.out.print("e1 = " + e1 + ": ");
  if (e1 < e0) {
    System.out.print("to ");
          (new Matrix(na, nparm)).print(10, 2);
  }
  else {
    System.out.println("move rejected");
  }
      }

      // termination test (slightly different than NR)
      if (Math.abs(e1-e0) > termepsilon) {
  term = 0;
      }
      else {
  term++;
  if (term == 4) {
    System.out.println("terminating after " + iter + " iterations");
    done = true;
  }
      }
      if (iter >= maxiter) done = true;

      // in the C++ version, found that changing this to e1 >= e0
      // was not a good idea.  See comment there.
      //
      if (e1 > e0 || Double.isNaN(e1)) { // new location worse than before
  lambda *= 10.;
      }
      else {    // new location better, accept new parameters
  lambda *= 0.1;
  e0 = e1;
  // simply assigning a = na will not get results copied back to caller
  for( int i = 0; i < nparm; i++ ) {
    if (vary[i]) a[i] = na[i];
  }
      }

    } while(!done);

    return lambda;
  } //solve

  //----------------------------------------------------------------

  /**
   * solve for phase, amplitude and frequency of a sinusoid
   */
  static class LMSineTest implements LMfunc
  {
    static final int  PHASE = 0;
    static final int  AMP = 1;
    static final int  FREQ = 2;

    @Override
    public double[] initial()
    {
      double[] a = new double[3];
      a[PHASE] = 0.;
      a[AMP] = 1.;
      a[FREQ] = 1.;
      return a;
    } //initial

    @Override
    public double val(double[] x, double[] a)
    {
      return a[AMP] * Math.sin(a[FREQ]*x[0] + a[PHASE]);
    } //val

    @Override
    public double grad(double[] x, double[] a, int a_k)
    {
      if (a_k == AMP)
  return Math.sin(a[FREQ]*x[0] + a[PHASE]);

      else if (a_k == FREQ)
  return a[AMP] * Math.cos(a[FREQ]*x[0] + a[PHASE]) * x[0];

      else if (a_k == PHASE)
  return a[AMP] * Math.cos(a[FREQ]*x[0] + a[PHASE]);

      else {
  assert false;
  return 0.;
      }
    } //grad


    @Override
    public Object[] testdata() {
      double[] a = new double[3];
      a[PHASE] = 0.111;
      a[AMP] = 1.222;
      a[FREQ] = 1.333;

      int npts = 10;
      double[][] x = new double[npts][1];
      double[] y = new double[npts];
      double[] s = new double[npts];
      for( int i = 0; i < npts; i++ ) {
  x[i][0] = (double)i / npts;
  y[i] = val(x[i], a);
  s[i] = 1.;
      }

      Object[] o = new Object[4];
      o[0] = x;
      o[1] = a;
      o[2] = y;
      o[3] = s;

      return o;
    } //test

  } //SineTest

  //----------------------------------------------------------------

  /**
   * quadratic (p-o)'S'S(p-o)
   * solve for o, S
   * S is a single scale factor
   */
  static class LMQuadTest implements LMfunc
  {

    @Override
    public double val(double[] x, double[] a)
    {
      assert a.length == 3;
      assert x.length == 2;

      double ox = a[0];
      double oy = a[1];
      double s  = a[2];

      double sdx = s*(x[0] - ox);
      double sdy = s*(x[1] - oy);

      return sdx*sdx + sdy*sdy;
    } //val


    /**
     * z = (p-o)'S'S(p-o)
     * dz/dp = 2S'S(p-o)
     *
     * z = (s*(px-ox))^2 + (s*(py-oy))^2
     * dz/dox = -2(s*(px-ox))*s
     * dz/ds = 2*s*[(px-ox)^2 + (py-oy)^2]

     * z = (s*dx)^2 + (s*dy)^2
     * dz/ds = 2(s*dx)*dx + 2(s*dy)*dy
     */
    @Override
    public double grad(double[] x, double[] a, int a_k)
    {
      assert a.length == 3;
      assert x.length == 2;
      assert a_k < 3: "a_k="+a_k;

      double ox = a[0];
      double oy = a[1];
      double s  = a[2];

      double dx = (x[0] - ox);
      double dy = (x[1] - oy);

      if (a_k == 0
  return -2.*s*s*dx;

      else if (a_k == 1)
  return -2.*s*s*dy;

      else
  return 2.*s*(dx*dx + dy*dy);
    } //grad


    @Override
    public double[] initial()
    {
      double[] a = new double[3];
      a[0] = 0.05;
      a[1] = 0.1;
      a[2] = 1.0;
      return a;
    } //initial


    @Override
    public Object[] testdata()
    {
      Object[] o = new Object[4];
      int npts = 25;
      double[][] x = new double[npts][2];
      double[] y = new double[npts];
      double[] s = new double[npts];
      double[] a = new double[3];

      a[0] = 0.;
      a[1] = 0.;
      a[2] = 0.9;

      int i = 0;
      for( int r = -2; r <= 2; r++ ) {
  for( int c = -2; c <= 2; c++ ) {
    x[i][0] = c;
    x[i][1] = r;
    y[i] = val(x[i], a);
    System.out.println("Quad "+c+","+r+" -> "+y[i]);
    s[i] = 1.;
    i++;
  }
      }
      System.out.print("quad x= ");
      (new Matrix(x)).print(10, 2);

      System.out.print("quad y= ");
      (new Matrix(y,npts)).print(10, 2);


      o[0] = x;
      o[1] = a;
      o[2] = y;
      o[3] = s;
      return o;
    } //testdata

  } //LMQuadTest

  //----------------------------------------------------------------

  /**
   * Replicate the example in NR, fit a sum of Gaussians to data.
   * y(x) = \sum B_k exp(-((x - E_k) / G_k)^2)
   * minimize chisq = \sum { y[j] - \sum B_k exp(-((x_j - E_k) / G_k)^2) }^2
   *
   * B_k, E_k, G_k are stored in that order
   *
   * Works, results are close to those from the NR example code.
   */
  static class LMGaussTest implements LMfunc
  {
    static double SPREAD = 0.001;   // noise variance

    @Override
    public double val(double[] x, double[] a)
    {
      assert x.length == 1;
      assert (a.length%3) == 0;

      int K = a.length / 3;
      int i = 0;

      double y = 0.;
      for( int j = 0; j < K; j++ ) {
  double arg = (x[0] - a[i+1]) / a[i+2];
  double ex = Math.exp(- arg*arg);
  y += (a[i] * ex);
  i += 3;
      }

      return y;
    } //val


    /**
     * <pre>
     * y(x) = \sum B_k exp(-((x - E_k) / G_k)^2)
     * arg  =  (x-E_k)/G_k
     * ex   =  exp(-arg*arg)
     * fac =   B_k * ex * 2 * arg
     *
     * d/dB_k = exp(-((x - E_k) / G_k)^2)
     *
     * d/dE_k = B_k exp(-((x - E_k) / G_k)^2) . -2((x - E_k) / G_k) . -1/G_k
     *        = 2 * B_k * ex * arg / G_k
     *   d/E_k[-((x - E_k) / G_k)^2] = -2((x - E_k) / G_k) d/dE_k[(x-E_k)/G_k]
     *   d/dE_k[(x-E_k)/G_k] = -1/G_k
     *
     * d/G_k = B_k exp(-((x - E_k) / G_k)^2) . -2((x - E_k) / G_k) . -(x-E_k)/G_k^2
     *       = B_k ex -2 arg -arg / G_k
     *       = fac arg / G_k
     *   d/dx[1/x] = d/dx[x^-1] = -x[x^-2]
     */
    @Override
    public double grad(double[] x, double[] a, int a_k)
    {
      assert x.length == 1;

      // i - index one of the K Gaussians
      int i = 3*(a_k / 3);

      double arg = (x[0] - a[i+1]) / a[i+2];
      double ex = Math.exp(- arg*arg);
      double fac = a[i] * ex * 2. * arg;

      if (a_k == i)
  return ex;

      else if (a_k == (i+1)) {
  return fac / a[i+2];
      }

      else if (a_k == (i+2)) {
  return fac * arg / a[i+2];
      }

      else {
  System.err.println("bad a_k");
  return 1.;
      }

    } //grad


    @Override
    public double[] initial()
    {
      double[] a = new double[6];
      a[0] = 4.5;
      a[1] = 2.2;
      a[2] = 2.8;

      a[3] = 2.5;
      a[4] = 4.9;
      a[5] = 2.8;
      return a;
    } //initial


    @Override
    public Object[] testdata()
    {
      Object[] o = new Object[4];
      int npts = 100;
      double[][] x = new double[npts][1];
      double[] y = new double[npts];
      double[] s = new double[npts];
      double[] a = new double[6];

      a[0] = 5.0// values returned by initial
      a[1] = 2.0// should be fairly close to these
      a[2] = 3.0;
      a[3] = 2.0;
      a[4] = 5.0;
      a[5] = 3.0;

      for( int i = 0; i < npts; i++ ) {
  x[i][0] = 0.1*(i+1)// NR always counts from 1
  y[i] = val(x[i], a);
  s[i] = SPREAD * y[i];
  System.out.println(i+": x,y= "+x[i][0]+", "+y[i]);
      }

      o[0] = x;
      o[1] = a;
      o[2] = y;
      o[3] = s;

      return o;
    } //testdata

  } //LMGaussTest

  //----------------------------------------------------------------

  // test program
  public static void main(String[] cmdline)
  {

    LMfunc f = new LMQuadTest();
    //LMfunc f = new LMSineTest();  // works
    //LMfunc f = new LMGaussTest();  // works

    double[] aguess = f.initial();
    Object[] test = f.testdata();
    double[][] x = (double[][])test[0];
    double[] areal = (double[])test[1];
    double[] y = (double[])test[2];
    double[] s = (double[])test[3];
    boolean[] vary = new boolean[aguess.length];
    for( int i = 0; i < aguess.length; i++ ) vary[i] = true;
    assert aguess.length == areal.length;

    try {
      solve( x, aguess, y, s, vary, f, 0.001, 0.01, 100, 2);
    }
    catch(Exception ex) {
      System.err.println("Exception caught: " + ex.getMessage());
      System.exit(1);
    }

    System.out.print("desired solution ");
    (new Matrix(areal, areal.length)).print(10, 2);

    System.exit(0);
  } //main

} //LM
// TOP
//
// Related Classes of LM.LM
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.