Package jclassifier

Source Code of jclassifier.Classify$FeatureVectorNormalizedEuclideanDistance

package jclassifier;

import static java.lang.Double.parseDouble;
import static java.lang.Math.sqrt;
import static java.util.Arrays.sort;
import static java.util.Collections.shuffle;
import static net.sourceforge.aprog.tools.MathTools.Statistics.square;
import static net.sourceforge.aprog.tools.Tools.getOrCreate;
import static net.sourceforge.aprog.tools.Tools.unchecked;

import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Scanner;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.regex.Pattern;

import net.sourceforge.aprog.tools.CommandLineArgumentsParser;
import net.sourceforge.aprog.tools.Factory;
import net.sourceforge.aprog.tools.Factory.DefaultFactory;
import net.sourceforge.aprog.tools.IllegalInstantiationException;
import net.sourceforge.aprog.tools.MathTools.Statistics;
import net.sourceforge.aprog.tools.SystemProperties;

/**
* @author codistmonk (creation 2013-07-01)
*/
public final class Classify {
 
  private Classify() {
    throw new IllegalInstantiationException();
  }
 
  public static final Factory<Statistics> STATISTICS_FACTORY = (Factory<Statistics>) DefaultFactory.forClass(Statistics.class);
 
  public static final Factory<ArrayList> ARRAY_LIST_FACTORY = (Factory<ArrayList>) DefaultFactory.forClass(ArrayList.class);
 
  /**
   * @param commandLineArguments
   * <br>Must not be null
   */
  public static final void main(final String[] commandLineArguments) throws Exception {
    final CommandLineArgumentsParser arguments = new CommandLineArgumentsParser(commandLineArguments);
    final String trainingFilePath = arguments.get("train", "");
    final String testingFilePath = arguments.get("test", "");
    System.out.println("Creating training base...");
    final Map<double[], Integer> trainingBase = newTrainingBase(trainingFilePath);
    final Statistics statistics = new Statistics();
    final String distanceName = arguments.get("distance", "manhattan").toUpperCase(Locale.ENGLISH);
    final Distance<double[]> distance;
    final int randomReduce = arguments.get("randomReduce", 0)[0];
    final int quantization = arguments.get("q", 0)[0];
   
    if (0 < quantization) {
      System.out.println("Creating quantizers...");
     
      final Factory<Quantizer> quantizerFactory = new Factory<Quantizer>() {
       
        @Override
        public final Quantizer newInstance() {
          return new HierarchicalAdaptiveQuantizer(quantization);
        }
       
      };
      final Map<Double, Quantizer> quantizers = new TreeMap<Double, Quantizer>();
      int progress = 0;
     
      for (final Map.Entry<double[], Integer> entry : trainingBase.entrySet()) {
        System.out.print((progress++) + "/" + trainingBase.size() + "\r");
       
        final double[] featureVector = entry.getKey();
        final int n = featureVector.length;
       
        for (int i = 0; i < n; i += 2) {
          final Double featureIndex = featureVector[i + 0];
          final Double featureValue = featureVector[i + 1];
         
          getOrCreate(quantizers, featureIndex, quantizerFactory).add(featureValue, 1.0);
        }
      }
     
      System.out.println("Quantizing...");
     
      final Map<double[], Integer> newTrainingBase = new TreeMap<double[], Integer>(FeatureVectorComparator.INSTANCE);
      progress = 0;
     
      for (final Map.Entry<double[], Integer> entry : trainingBase.entrySet()) {
        System.out.print((progress++) + "/" + trainingBase.size() + "\r");
        newTrainingBase.put(quantize(entry.getKey(), quantizers), entry.getValue());
      }
     
      trainingBase.clear();
      trainingBase.putAll(newTrainingBase);
    }
   
    if (0 < randomReduce) {
      System.out.println("Reducing...");
     
      final Map<Integer, List<Map.Entry<double[], Integer>>> data = new HashMap<Integer, List<Map.Entry<double[], Integer>>>();
     
      for (final Map.Entry<double[], Integer> entry : trainingBase.entrySet()) {
        getOrCreate(data, entry.getValue(), (Factory) ARRAY_LIST_FACTORY).add(entry);
      }
     
      for (final Map.Entry<Integer, List<Map.Entry<double[], Integer>>> entry : data.entrySet()) {
        final List<Entry<double[], Integer>> list = entry.getValue();
        final int n = list.size();
       
        if (randomReduce < n) {
          shuffle(list, new Random(n - randomReduce));
         
          for (int i = randomReduce; i < n; ++i) {
            trainingBase.remove(list.get(i).getKey());
          }
        }
      }
    }
   
    final Map.Entry<double[], Integer>[] entries = trainingBase.entrySet().toArray(new Map.Entry[trainingBase.size()]);
   
    if ("NORMALIZED_EUCLIDEAN".equals(distanceName)) {
      System.out.println("Initializing distance...");
      distance = new FeatureVectorNormalizedEuclideanDistance();
     
      for (final Map.Entry<double[], Integer> entry : entries) {
        ((FeatureVectorNormalizedEuclideanDistance) distance).add(entry.getKey());
      }
    } else {
      distance = PredefinedFeatureVectorDistance.valueOf(distanceName);
    }
   
    System.out.println("trainingEntryCount: " + entries.length);
   
    System.out.println("Testing...");
   
    final int threadCount = arguments.get("threads", SystemProperties.getAvailableProcessorCount())[0];
    final ExecutorService executor = Executors.newFixedThreadPool(threadCount);
   
    try {
      final Collection<Future<Double>> tasks = new ArrayList<Future<Double>>();
     
      readLibSVMFile(testingFilePath, new FeatureVectorProcessor() {
       
        @Override
        public final void process(final int label, final double[] featureVector) {
          assert 0.0 == distance.compute(featureVector, featureVector);
         
          tasks.add(executor.submit(new Callable<Double>() {
           
            @Override
            public final Double call() {
              return label == nearestNeighbor(featureVector, distance, entries).getValue() ? 1.0 : 0.0;
            }
           
          }));
        }
       
      });
     
      int progress = 0;
     
      for (final Future<Double> task : tasks) {
        System.out.print((progress++) + "/" + tasks.size() + " \r");
        statistics.addValue(task.get());
      }
     
      System.out.println("testingEntryCount: " + statistics.getCount());
      System.out.println("accuracy:" + statistics.getMean());
    } finally {
      executor.shutdown();
    }
  }
 
  public static final double[] quantize(final double[] featureVector, final Map<Double, Quantizer> quantizers) {
    final int n = featureVector.length;
   
    for (int i = 0; i < n; i += 2) {
      final Quantizer quantizer = quantizers.get(featureVector[i + 0]);
      featureVector[i + 1] = quantizer.quantize(featureVector[i + 1]);
    }
   
    return featureVector;
  }
 
  public static final Map.Entry<double[], Integer> nearestNeighbor(final double[] featureVector,
      final Distance<double[]> distance, final Map.Entry<double[], Integer>[] entries) {
    Map.Entry<double[], Integer> result = null;
    double nearestDistance = Double.POSITIVE_INFINITY;
   
    for (final Map.Entry<double[], Integer> entry : entries) {
      final double d = distance.compute(featureVector, entry.getKey());
     
      if (result == null || d < nearestDistance) {
        result = entry;
        nearestDistance = d;
      }
    }
   
    return result;
  }
 
  public static final void readLibSVMFile(final String filePath, final FeatureVectorProcessor processor) {
    try {
      final File file = new File(filePath);
      final Scanner scanner = new Scanner(file);
      final long total = file.length();
      long done = 0L;
     
      try {
        while (scanner.hasNext()) {
          System.out.print(done + "/" + total + " \r");
          final String line = scanner.nextLine();
          final String[] labelAndFeatures = line.split("\\s+|:");
          final int n = labelAndFeatures.length;
          final double[] featureVector = new double[n - 1];
          done += line.length();
         
          for (int i = 1; i < n; ++i) {
            featureVector[i - 1] = parseDouble(labelAndFeatures[i]);
          }
         
          processor.process((int) parseDouble(labelAndFeatures[0]), orderFeatures(featureVector));
        }
      } finally {
        scanner.close();
      }
    } catch (final Exception exception) {
      throw unchecked(exception);
    }
  }
 
  public static final Map<double[], Integer> newTrainingBase(final String trainingFilePath) {
    final Map<double[], Integer> result = new TreeMap<double[], Integer>(FeatureVectorComparator.INSTANCE);
   
    readLibSVMFile(trainingFilePath, new FeatureVectorProcessor() {
     
      @Override
      public final void process(final int label, final double[] featureVector) {
        final Integer oldLabel = result.put(featureVector, label);
       
        if (oldLabel != null && oldLabel.intValue() != label) {
          System.err.println("Warning: noninjective labeling");
        }
      }
     
    });
   
    return result;
  }
 
  public static final double[] orderFeatures(final double[] featureVector) {
    final int n = featureVector.length;
    final Integer[] indices = integerRange(0, 2, n - 1);
   
    sort(indices, new FeatureIndexComparator(featureVector));
   
    final double[] result = new double[n];
   
    for (int outIndex = 0, indexIndex = 0; outIndex < n; outIndex += 2, ++indexIndex) {
      final int inIndex = indices[indexIndex];
      result[outIndex] = featureVector[inIndex];
      result[outIndex + 1] = featureVector[inIndex + 1];
    }
   
    return result;
  }
 
  public static final Integer[] integerRange(final int first, final int step, final int last) {
    final int n = (step + last - first) / step;
    final Integer[] result = new Integer[n];
   
    for (int i = 0; i < n; ++i) {
      result[i] = first + step * i;
    }
   
    return result;
  }
 
  /**
   * @author codistmonk (creation 2013-07-02)
   */
  public static final class FeatureVectorNormalizedEuclideanDistance extends FeatureVectorDistance {
   
    private final Map<Double, Statistics> statistics = new HashMap<Double, Statistics>();
   
    public final void add(final double[] featureVector) {
      final int n = featureVector.length;
     
      for (int i = 0; i < n; i += 2) {
        getOrCreate(this.statistics, featureVector[i], STATISTICS_FACTORY).addValue(featureVector[i + 1]);
      }
    }

    @Override
    protected final double updateResult(final double oldResult,
        final double featureIndex, final double featureValue1, final double featureValue2) {
      final Statistics statistics = this.statistics.get(featureIndex);
      final double variance = statistics == null ? 0.0 : statistics.getVariance();
     
      return variance == 0.0 ? oldResult : oldResult + square(featureValue2 - featureValue1) / variance;
    }
   
    @Override
    protected final double updateResult(final double oldResult) {
      return sqrt(oldResult);
    }
   
  }
 
}
TOP

Related Classes of jclassifier.Classify$FeatureVectorNormalizedEuclideanDistance

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.