Package JKernelMachines.fr.lip6.classifier

Source Code of JKernelMachines.fr.lip6.classifier.SOGradPKL

package JKernelMachines.fr.lip6.classifier;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Hashtable;
import java.util.List;

import JKernelMachines.fr.lip6.kernel.Kernel;
import JKernelMachines.fr.lip6.kernel.SimpleCacheKernel;
import JKernelMachines.fr.lip6.kernel.adaptative.ThreadedProductKernel;
import JKernelMachines.fr.lip6.threading.ThreadedMatrixOperator;
import JKernelMachines.fr.lip6.type.TrainingSample;

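/**
 * Multiple-kernel classifier that learns the weights of a product of kernels
 * by an approximate second-order gradient descent on the SVM dual objective.
 * An SMOSVM is used as the inner solver, and the kernel weights are
 * renormalized to the unit Lp sphere (L1 by default) after each step.
 */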
public class SOGradPKL<T> implements Classifier<T> {

  List<TrainingSample<T>> listOfExamples;
  List<Double> listOfExampleWeights;
  List<Kernel<T>> listOfKernels;
  List<Double> listOfKernelWeights;
 
  SMOSVM<T> svm;
 
  double stopGap = 1e-5;
  double num_cleaning = 1e-8;
  double p_norm = 1;
 
  double C = 1e0;
  boolean traceNorm = false;
 
  double d_lambda;
  double [][] lambda_matrix = null;
  double oldObjective;
 
  boolean cache = true;

  public SOGradPKL()
  {
    listOfKernels = new ArrayList<Kernel<T>>();
    listOfKernelWeights = new ArrayList<Double>();
    listOfExamples = new ArrayList<TrainingSample<T>>();
    listOfExampleWeights = new ArrayList<Double>();
  }
 
  public void addKernel(Kernel<T> k)
  {
    listOfKernels.add(k);
    listOfKernelWeights.add(1.0);
  }
 

  public void train(TrainingSample<T> t) {
    if(listOfExamples == null)
      listOfExamples = new ArrayList<TrainingSample<T>>();
    if(!listOfExamples.contains(t))
      listOfExamples.add(t);
    train(listOfExamples);
  }


  public void train(List<TrainingSample<T>> l) {

    long tim = System.currentTimeMillis();
    eprintln(2, "training on "+listOfKernels.size()+" kernels and "+l.size()+" examples");
   
    //1. init kernels
    ArrayList<Kernel<T>> kernels = new ArrayList<Kernel<T>>();
    ArrayList<Double> weights = new ArrayList<Double>();
   
    //optionally normalize kernels to constant trace, init weights to (1/N)^(1/p)
    for(int i = 0 ; i < listOfKernels.size(); i++)
    {
      if(cache)
      {
        eprintln(3, "+ cache is set, computing cache");
        SimpleCacheKernel<T> sck = new SimpleCacheKernel<T>(listOfKernels.get(i), l);
        sck.setName(listOfKernels.get(i).toString());
        double[][] matrix = sck.getKernelMatrix(l);
        if(traceNorm)
        {
          //compute trace
          double trace = 0.;
          for(int x = 0 ; x < matrix.length; x++)
          {
            trace += matrix[x][x];
          }
          //divide by trace
          for(int x = 0 ; x < matrix.length; x++)
            for(int y = x ; y < matrix.length; y++)
            {
              matrix[x][y] *= matrix.length/(double)trace;
              matrix[y][x] = matrix[x][y];
            }
        }
        kernels.add(sck);
      }
      else
      {
        eprintln(3, "+ cache is not set, skipping cache");
        kernels.add(listOfKernels.get(i));
      }
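      //init weights on the unit Lp sphere: N * ((1/N)^(1/p))^p = 1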
      weights.add(Math.pow(1/(double)listOfKernels.size(), 1/(double)p_norm));
      eprintln(3, "+ kernel : "+kernels.get(i)+" weight : "+weights.get(i));
    }
   
   
    //2. train first svm
    ThreadedProductKernel<T> tpk = new ThreadedProductKernel<T>();
    for(int i = 0 ; i < kernels.size(); i++)
      tpk.addKernel(kernels.get(i), weights.get(i));
    svm = new SMOSVM<T>(tpk);
    svm.setC(C);
    svm.setVerbosityLevel(VERBOSITY_LEVEL-1);
    eprintln(3, "+ training svm");
    svm.train(l);
    double[] a = svm.getAlphas();
    //update lambda matrix
    updateLambdaMatrix(a, tpk, l);
    //compute old value of objective function
    oldObjective = computeObj(a, tpk, l);
   
    eprintln(3, "+ initial weights : "+weights);
   
    //3. main loop
    double gap = 0;
    do
    {           
      //perform one step
      double objEvol = performPKLStep(kernels, weights, l);
     
      if(objEvol < 0)
      {
        eprintln(1, "Error, performPKLStep returned a wrong value");
        System.exit(1);
      }
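      //objEvol is the ratio new/old objective, so gap is the relative decrease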
      gap = 1 - objEvol;
     
      eprintln(1, "+ objective_gap : "+(float)gap);
      eprintln(1, "+");
     
    }
    while(gap >= stopGap);
   
    //4. save weights
    listOfKernelWeights.clear();
    listOfKernelWeights.addAll(weights);
   
    //5. retrain svm
    //compute the product kernel with the learned weights
    tpk = new ThreadedProductKernel<T>();
    for(int i = 0 ; i < kernels.size(); i++)
      tpk.addKernel(listOfKernels.get(i), listOfKernelWeights.get(i));
    //train svm
    svm.setKernel(tpk);
    eprintln(3, "+ retraining svm");
    svm.retrain();
   
    //6. save example weights
    if(listOfExamples != l)
    {
      listOfExamples.clear();
      listOfExamples.addAll(l);
    }
    listOfExampleWeights.clear();
    for(double d : svm.getAlphas())
      listOfExampleWeights.add(d);

    eprintln(1, "PKL trained in "+(System.currentTimeMillis()-tim)+" milis.");
  }
 
  /**
   * Performs one approximate second-order gradient descent step on the kernel weights.
   * @param kernels the (possibly cached) kernels
   * @param weights the current kernel weights, updated in place
   * @param l the list of training samples
   * @return the ratio between the new and the old objective value, or -1 on error
   */
  private double performPKLStep(ArrayList<Kernel<T>> kernels, ArrayList<Double> weights, List<TrainingSample<T>> l)
  {

    //store new as old for the loop
    double objective = oldObjective;
   
    eprintln(3, "+++ old weights : "+weights);
    eprintln(3, "+++ oldObjective : "+oldObjective+" sumAlpha : "+computeSumAlpha());
   
   
    //compute grad
    double [] grad = gradBeta(kernels, weights, l);
    double [] sgrad = secondGradBeta(kernels, weights, l);
   
   
    double newBeta[] = new double[grad.length];

    //update weights in this direction until objective decreases
    d_lambda = 1.0; //reset learning rate
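    //multiplicative quasi-Newton update on each weight:
    //  beta_i <- beta_i * (1 - d_lambda * grad_i / sgrad_i)
    //shrinking d_lambda until the new objective no longer exceeds the old one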
    do
    {
      for(int i = 0 ; i < grad.length; i++)
      {
        //second order update
        if(sgrad[i] != 0.)
          newBeta[i] = weights.get(i) * (1 - d_lambda*grad[i]/sgrad[i]);
        if(newBeta[i] < num_cleaning)
          newBeta[i] = 0;

      }

      //normalize
      double norm = 0;
      if(p_norm == 1) // L1-norm
      {
        for(int i = 0 ; i < newBeta.length; i++)
          norm += Math.abs(newBeta[i]);
      }
      else //Lp-norm
      {
        for(int i = 0 ; i < newBeta.length; i++)
          norm += Math.pow(newBeta[i], p_norm);
        norm = Math.pow(norm, 1/(double)p_norm); //||newBeta||_p
      }
      if(norm < 0)
      {
        eprintln(1, "Error normalization, norm < 0");
        return -1;
      }
      eprintln(3, "+++ norm : "+norm);
      //normalize
      for(int i = 0 ; i < newBeta.length; i++)
        newBeta[i] /= (double)(norm);
     


      //compute new objective function
      ThreadedProductKernel<T> pk = new ThreadedProductKernel<T>();
      for( int i = 0 ; i < kernels.size(); i++)
        pk.addKernel(kernels.get(i), newBeta[i]);
      //train svm
      svm.setKernel(pk);
      eprintln(3, "+ retraining svm");
      svm.retrain();
      double[] a = svm.getAlphas();
      //update lambda
      updateLambdaMatrix(a, pk, l);
      //new objective
      objective = computeObj(a, pk, l);

      if(objective < oldObjective + num_cleaning) //accept if the objective decreased (up to numerical tolerance)
      {
        //store new weights
        for(int i = 0 ; i < weights.size(); i++)
          weights.set(i, newBeta[i]);
        eprintln(3, "+++ new weights : "+weights);
      }
      else //if not, reduce learning rate exponentially
      {
        if(d_lambda > num_cleaning)
          d_lambda /= 8.;
        else
        {
          d_lambda = 0.;
          eprint(3, "+++ d_lambda is zero, stopping.");
          eprintln(2, "");
          break;
        }
        eprint(2, "+");
        eprintln(3, "++ new objective ("+(float)objective+") did not decrease ("+(float)oldObjective+"), reducing step : "+d_lambda);
      }
     
    }
    while(oldObjective + num_cleaning < objective);
   
    eprintln(2, "+ objective : "+(float)objective+"\t+\t sumAlpha : "+(float)computeSumAlpha());

    double gap = objective/oldObjective;
   
    //store objective as oldObjective
    oldObjective = objective;

    //return objective evolution
    return gap;
  }
     
 
  /** compute the gradient with respect to each beta */
  private double [] gradBeta(ArrayList<Kernel<T>> kernels, ArrayList<Double> weights, List<TrainingSample<T>> l)
  {
    double grad[] = new double[kernels.size()];

    //grad[i] = -<lambda, log K_i>, accumulated over the upper triangle
    for(int i = 0 ; i < kernels.size(); i++)
    {
      double matrix[][] = kernels.get(i).getKernelMatrix(l);
      for(int x = 0 ; x < matrix.length; x++)
      {
        for(int y = x ; y < matrix.length; y++)
        {
          if(matrix[x][y] == 0)
            continue;
          grad[i] += - Math.log(matrix[x][y]) * lambda_matrix[x][y];
        }
      }
    }
   
    eprintln(4, "++++++ gradDir : "+Arrays.toString(grad));
   
    return grad;
  }
 
   
 
  /** compute the second-order gradient with respect to each beta */
  private double [] secondGradBeta(ArrayList<Kernel<T>> kernels, ArrayList<Double> weights, List<TrainingSample<T>> l)
  {
    double grad[] = new double[kernels.size()];
   
    //grad[i] = <lambda, (log K_i).^2>, accumulated over the upper triangle
    for(int i = 0 ; i < kernels.size(); i++)
    {
      double matrix[][] = kernels.get(i).getKernelMatrix(l);
      for(int x = 0 ; x < matrix.length; x++)
      {
        for(int y = x ; y < matrix.length; y++)
        {
          if(matrix[x][y] == 0)
            continue;
          double d = Math.log(matrix[x][y]);
          grad[i] += d * d * lambda_matrix[x][y];
        }
      }
    }
   
    for(int i = 0 ; i < grad.length; i++)
      if(grad[i] < num_cleaning)
        grad[i] = 0.0;
   
    eprintln(4, "++++++ secondGradDir : "+Arrays.toString(grad));
   
    return grad;
  }
 
  /** compute the sum of the absolute values of the example weights (alphas) */
  private double computeSumAlpha()
  {
    double sum = 0;
    double[] a = svm.getAlphas();
    for(double d : a)
      sum += Math.abs(d);
    return sum;
  }
 
  /** compute the SVM dual objective: sum_i alpha_i + sum_{x,y} lambda[x][y] */
  private double computeObj(double[] a, Kernel<T> kernel, List<TrainingSample<T>> l)
  {
    double obj = 0;
   
    //sum of alpha
    for(double aa : a)
      obj += aa;
   
   
    for(int x = 0 ; x < lambda_matrix.length; x++)
    {
      for(int y = x ; y < lambda_matrix.length; y++)
      {
        if(lambda_matrix[x][y] == 0)
          continue;
        if(x != y)
          obj += 2*lambda_matrix[x][y];
        else
          obj += lambda_matrix[x][y];
      }
    }
    return obj;
  }
 
  /** compute the lambda matrix: lambda[x][y] = -0.5 * y_x * y_y * alpha_x * alpha_y * K(x,y) */
  private void updateLambdaMatrix(final double[] a, Kernel<T> kernel, final List<TrainingSample<T>> l)
  {
    final double [][] matrix = kernel.getKernelMatrix(l);
    lambda_matrix = new double[matrix.length][matrix.length];
   
    eprintln(3, "+ update lambda");
    ThreadedMatrixOperator factory = new ThreadedMatrixOperator()
    {
      @Override
      public void doLine(int index, double[] line) {
        int l1 = l.get(index).label;
        double al1 = -0.5 * a[index]*l1;
        for(int j = line.length-1 ; j >= 0 ; j--) //include column 0
        {
          int l2 = l.get(j).label;
          line[j] = al1 * l2 * a[j] * matrix[index][j];
        }
      }

     
    };
   
    lambda_matrix = factory.getMatrix(lambda_matrix);
  }


  public double valueOf(T e) {
    return svm.valueOf(e);
  }
 

  private int VERBOSITY_LEVEL = 0;
 
  /**
   * set how verbose SOGradPKL shall be. <br />
   * Everything is printed to stderr. <br />
   * none : 0 (default), few : 1, more : 2, all : 3
   * @param l the verbosity level
   */
  public void setVerbosityLevel(int l)
  {
    VERBOSITY_LEVEL = l;
  }
 
  public boolean isCache() {
    return cache;
  }

  public void setCache(boolean cache) {
    this.cache = cache;
  }

  public void eprint(int level, String s)
  {
    if(VERBOSITY_LEVEL >= level)
      System.err.print(s);
  }
 
  public void eprintln(int level, String s)
  {
    if(VERBOSITY_LEVEL >= level)
      System.err.println(s);
  }

  public double getC() {
    return C;
  }

  public void setC(double c) {
    C = c;
  }

  public void setMKLNorm(double p)
  {
    p_norm = p;
  }
 
  public void setStopGap(double w)
  {
    stopGap = w;
  }
 
  public void setTraceNorm(boolean traceNorm) {
    this.traceNorm = traceNorm;
  }

  public double getNum_cleaning() {
    return num_cleaning;
  }

  public void setNum_cleaning(double num_cleaning) {
    this.num_cleaning = num_cleaning;
  }

  public List<Double> getExampleWeights() {
    return listOfExampleWeights;
  }
 
  public List<Double> getKernelWeights()
  {
    return listOfKernelWeights;
  }
 
  public Hashtable<Kernel<T>, Double> getWeights()
  {
    Hashtable<Kernel<T>, Double> map = new Hashtable<Kernel<T>, Double>();
    for(int i = 0 ; i < listOfKernels.size(); i++)
      map.put(listOfKernels.get(i), listOfKernelWeights.get(i));
    return map;
  }

}
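
Below is a minimal usage sketch, not part of the library source. It assumes a TrainingSample(T, int) constructor matching the public label field used above, and that Kernel<T> is an abstract class whose only abstract methods are valueOf(T,T) and valueOf(T) (the similarity of a sample with itself), as elsewhere in JKernelMachines; the gauss helper is a hypothetical toy kernel written for this example.

import java.util.ArrayList;
import java.util.List;

import JKernelMachines.fr.lip6.classifier.SOGradPKL;
import JKernelMachines.fr.lip6.kernel.Kernel;
import JKernelMachines.fr.lip6.type.TrainingSample;

public class SOGradPKLExample {

  //toy Gaussian kernel on double[], with bandwidth gamma
  static Kernel<double[]> gauss(final double gamma) {
    return new Kernel<double[]>() {
      public double valueOf(double[] t1, double[] t2) {
        double d = 0.;
        for(int i = 0 ; i < t1.length; i++)
          d += (t1[i]-t2[i])*(t1[i]-t2[i]);
        return Math.exp(-gamma*d);
      }
      public double valueOf(double[] t1) {
        return 1.0; //k(x,x) = 1 for a Gaussian kernel
      }
    };
  }

  public static void main(String[] args) {
    //toy training set with labels in {-1,+1}
    List<TrainingSample<double[]>> train = new ArrayList<TrainingSample<double[]>>();
    train.add(new TrainingSample<double[]>(new double[]{0., 0.}, -1));
    train.add(new TrainingSample<double[]>(new double[]{0., 1.}, -1));
    train.add(new TrainingSample<double[]>(new double[]{1., 0.}, 1));
    train.add(new TrainingSample<double[]>(new double[]{1., 1.}, 1));

    SOGradPKL<double[]> pkl = new SOGradPKL<double[]>();
    //candidate kernels at several bandwidths; training learns their weights
    for(double gamma : new double[]{0.1, 1., 10.})
      pkl.addKernel(gauss(gamma));
    pkl.setC(10.);
    pkl.train(train);

    System.out.println("kernel weights : "+pkl.getKernelWeights());
    System.out.println("value of (1,1) : "+pkl.valueOf(new double[]{1., 1.}));
  }
}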