package jclassifier;
import static java.lang.Double.parseDouble;
import static java.lang.Math.sqrt;
import static java.util.Arrays.sort;
import static java.util.Collections.shuffle;
import static net.sourceforge.aprog.tools.MathTools.Statistics.square;
import static net.sourceforge.aprog.tools.Tools.getOrCreate;
import static net.sourceforge.aprog.tools.Tools.unchecked;
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Scanner;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import net.sourceforge.aprog.tools.CommandLineArgumentsParser;
import net.sourceforge.aprog.tools.Factory;
import net.sourceforge.aprog.tools.Factory.DefaultFactory;
import net.sourceforge.aprog.tools.IllegalInstantiationException;
import net.sourceforge.aprog.tools.MathTools.Statistics;
import net.sourceforge.aprog.tools.SystemProperties;
/**
* @author codistmonk (creation 2013-07-01)
*/
public final class Classify {
/**
 * Prevents instantiation: every member of this class is static.
 *
 * @throws IllegalInstantiationException always
 */
private Classify() {
throw new IllegalInstantiationException();
}
/**
 * Factory of {@link Statistics} instances, obtained from <code>DefaultFactory.forClass(Statistics.class)</code>.
 * <br>Passed to <code>getOrCreate</code> by {@link FeatureVectorNormalizedEuclideanDistance#add(double[])}
 * to obtain the statistics associated with a feature index.
 * <br>NOTE(review): presumably instantiates through the default constructor; confirm against DefaultFactory.
 */
public static final Factory<Statistics> STATISTICS_FACTORY = (Factory<Statistics>) DefaultFactory.forClass(Statistics.class);
/**
 * Factory of (raw) {@link ArrayList} instances, obtained from <code>DefaultFactory.forClass(ArrayList.class)</code>.
 * <br>Passed to <code>getOrCreate</code> by the random-reduction step of the command-line entry point
 * to obtain the list associated with a label.
 */
public static final Factory<ArrayList> ARRAY_LIST_FACTORY = (Factory<ArrayList>) DefaultFactory.forClass(ArrayList.class);
/**
 * Command-line entry point: builds a training base from a LibSVM file, optionally quantizes
 * and/or randomly reduces it, then classifies every entry of a testing LibSVM file with a
 * 1-nearest-neighbor search and prints the resulting accuracy.
 * <br>Arguments read through {@link CommandLineArgumentsParser}:
 * <ul>
 * <li><code>train</code>: training file path (default: empty string);
 * <li><code>test</code>: testing file path (default: empty string);
 * <li><code>distance</code>: distance name, upper-cased before use (default: "manhattan");
 * "NORMALIZED_EUCLIDEAN" selects {@link FeatureVectorNormalizedEuclideanDistance}, any other
 * name is resolved with <code>PredefinedFeatureVectorDistance.valueOf</code>;
 * <li><code>randomReduce</code>: when positive, maximum number of training entries kept per label (default: 0);
 * <li><code>q</code>: when positive, parameter given to each <code>HierarchicalAdaptiveQuantizer</code> (default: 0);
 * <li><code>threads</code>: size of the thread pool running the nearest-neighbor searches
 * (default: available processor count; values below 1 are treated as 1).
 * </ul>
 *
 * @param commandLineArguments
 * <br>Must not be null
 * @throws Exception if reading a file fails or if a classification task fails
 */
public static final void main(final String[] commandLineArguments) throws Exception {
	final CommandLineArgumentsParser arguments = new CommandLineArgumentsParser(commandLineArguments);
	final String trainingFilePath = arguments.get("train", "");
	final String testingFilePath = arguments.get("test", "");

	System.out.println("Creating training base...");

	final Map<double[], Integer> trainingBase = newTrainingBase(trainingFilePath);
	final Statistics statistics = new Statistics();
	final String distanceName = arguments.get("distance", "manhattan").toUpperCase(Locale.ENGLISH);
	final int randomReduce = arguments.get("randomReduce", 0)[0];
	final int quantization = arguments.get("q", 0)[0];

	if (0 < quantization) {
		quantizeTrainingBase(trainingBase, quantization);
	}

	if (0 < randomReduce) {
		reduceTrainingBase(trainingBase, randomReduce);
	}

	// The training base is not modified past this point, so the entries can safely be shared
	// (read-only) by the worker threads
	final Map.Entry<double[], Integer>[] entries = trainingBase.entrySet().toArray(new Map.Entry[trainingBase.size()]);
	final Distance<double[]> distance = newDistance(distanceName, entries);

	System.out.println("trainingEntryCount: " + entries.length);
	System.out.println("Testing...");

	// Executors.newFixedThreadPool rejects sizes below 1
	final int threadCount = Math.max(1, arguments.get("threads", SystemProperties.getAvailableProcessorCount())[0]);
	final ExecutorService executor = Executors.newFixedThreadPool(threadCount);

	try {
		final Collection<Future<Double>> tasks = new ArrayList<Future<Double>>();

		// One task per testing entry; each task yields 1.0 for a correct prediction, 0.0 otherwise
		readLibSVMFile(testingFilePath, new FeatureVectorProcessor() {

			@Override
			public final void process(final int label, final double[] featureVector) {
				assert 0.0 == distance.compute(featureVector, featureVector);

				tasks.add(executor.submit(new Callable<Double>() {

					@Override
					public final Double call() {
						return label == nearestNeighbor(featureVector, distance, entries).getValue() ? 1.0 : 0.0;
					}

				}));
			}

		});

		int progress = 0;

		for (final Future<Double> task : tasks) {
			System.out.print((progress++) + "/" + tasks.size() + " \r");
			statistics.addValue(task.get());
		}

		System.out.println("testingEntryCount: " + statistics.getCount());
		System.out.println("accuracy:" + statistics.getMean());
	} finally {
		// Equivalent to shutdown() when every task completed; after a failure it also
		// discards the tasks still queued instead of letting them run to completion
		executor.shutdownNow();
	}
}

/**
 * Builds one quantizer per feature index from all the training values, then replaces the
 * contents of <code>trainingBase</code> with the quantized feature vectors.
 * <br>Vectors that become equal (according to <code>FeatureVectorComparator.INSTANCE</code>)
 * after quantization are merged, the label of the last one processed being kept.
 *
 * @param trainingBase
 * <br>Must not be null
 * <br>Input-output
 * @param quantization
 * <br>Parameter given to each <code>HierarchicalAdaptiveQuantizer</code>
 */
private static void quantizeTrainingBase(final Map<double[], Integer> trainingBase, final int quantization) {
	System.out.println("Creating quantizers...");

	final Factory<Quantizer> quantizerFactory = new Factory<Quantizer>() {

		@Override
		public final Quantizer newInstance() {
			return new HierarchicalAdaptiveQuantizer(quantization);
		}

	};
	final Map<Double, Quantizer> quantizers = new TreeMap<Double, Quantizer>();
	int progress = 0;

	for (final Map.Entry<double[], Integer> entry : trainingBase.entrySet()) {
		System.out.print((progress++) + "/" + trainingBase.size() + "\r");

		final double[] featureVector = entry.getKey();
		final int n = featureVector.length;

		// Feature vectors are flat arrays of (index, value) pairs
		for (int i = 0; i < n; i += 2) {
			final Double featureIndex = featureVector[i + 0];
			final Double featureValue = featureVector[i + 1];

			getOrCreate(quantizers, featureIndex, quantizerFactory).add(featureValue, 1.0);
		}
	}

	System.out.println("Quantizing...");

	final Map<double[], Integer> quantizedBase = new TreeMap<double[], Integer>(FeatureVectorComparator.INSTANCE);

	progress = 0;

	for (final Map.Entry<double[], Integer> entry : trainingBase.entrySet()) {
		System.out.print((progress++) + "/" + trainingBase.size() + "\r");
		// quantize(...) modifies the key array in place; this is tolerable only because
		// trainingBase is cleared and rebuilt right after this loop
		quantizedBase.put(quantize(entry.getKey(), quantizers), entry.getValue());
	}

	trainingBase.clear();
	trainingBase.putAll(quantizedBase);
}

/**
 * Removes training entries so that at most <code>randomReduce</code> entries remain for each label.
 * <br>The entries to remove are chosen by shuffling each label's entries with a {@link Random}
 * seeded with the number of entries to remove, which makes the selection reproducible.
 *
 * @param trainingBase
 * <br>Must not be null
 * <br>Input-output
 * @param randomReduce
 * <br>Maximum number of entries kept per label
 */
private static void reduceTrainingBase(final Map<double[], Integer> trainingBase, final int randomReduce) {
	System.out.println("Reducing...");

	// Keys are collected instead of Map.Entry objects: the behavior of an entry obtained from
	// entrySet() is undefined once the backing map has been structurally modified, and this
	// method removes elements from trainingBase below
	final Map<Integer, List<double[]>> keysByLabel = new HashMap<Integer, List<double[]>>();

	for (final Map.Entry<double[], Integer> entry : trainingBase.entrySet()) {
		getOrCreate(keysByLabel, entry.getValue(), (Factory) ARRAY_LIST_FACTORY).add(entry.getKey());
	}

	for (final List<double[]> keys : keysByLabel.values()) {
		final int n = keys.size();

		if (randomReduce < n) {
			shuffle(keys, new Random(n - randomReduce));

			// After the shuffle, the first randomReduce keys are kept and the others are dropped
			for (int i = randomReduce; i < n; ++i) {
				trainingBase.remove(keys.get(i));
			}
		}
	}
}

/**
 * Creates the distance designated by <code>distanceName</code>.
 * <br>"NORMALIZED_EUCLIDEAN" produces a {@link FeatureVectorNormalizedEuclideanDistance} fed with
 * every training feature vector; any other name is resolved with
 * <code>PredefinedFeatureVectorDistance.valueOf</code>.
 *
 * @param distanceName
 * <br>Must not be null
 * @param entries
 * <br>Must not be null
 * @return
 * <br>Not null
 */
private static Distance<double[]> newDistance(final String distanceName, final Map.Entry<double[], Integer>[] entries) {
	if ("NORMALIZED_EUCLIDEAN".equals(distanceName)) {
		System.out.println("Initializing distance...");

		final FeatureVectorNormalizedEuclideanDistance result = new FeatureVectorNormalizedEuclideanDistance();

		for (final Map.Entry<double[], Integer> entry : entries) {
			result.add(entry.getKey());
		}

		return result;
	}

	return PredefinedFeatureVectorDistance.valueOf(distanceName);
}
/**
 * Quantizes, in place, every value of a feature vector laid out as a flat array of
 * (index, value) pairs, using the quantizer registered for each feature index.
 * <br>No copy is made: the argument itself is modified and returned.
 * <br>A feature index without a registered quantizer makes this method fail with a
 * {@link NullPointerException}.
 *
 * @param featureVector
 * <br>Must not be null
 * <br>Input-output
 * @param quantizers
 * <br>Must not be null
 * @return <code>featureVector</code>
 * <br>Not null
 */
public static final double[] quantize(final double[] featureVector, final Map<Double, Quantizer> quantizers) {
	int position = 0;

	while (position < featureVector.length) {
		// featureVector[position] is the feature index, featureVector[position + 1] its value
		final Quantizer featureQuantizer = quantizers.get(featureVector[position]);

		featureVector[position + 1] = featureQuantizer.quantize(featureVector[position + 1]);
		position += 2;
	}

	return featureVector;
}
/**
 * Performs a linear scan of <code>entries</code> and returns the one whose key is the closest
 * to <code>featureVector</code> according to <code>distance</code>.
 * <br>When several entries are equally close, the first one encountered is returned.
 *
 * @param featureVector
 * <br>Query vector
 * @param distance
 * <br>Must not be null
 * @param entries
 * <br>Must not be null
 * @return
 * <br>Null if and only if <code>entries</code> is empty
 */
public static final Map.Entry<double[], Integer> nearestNeighbor(final double[] featureVector,
		final Distance<double[]> distance, final Map.Entry<double[], Integer>[] entries) {
	Map.Entry<double[], Integer> closest = null;
	double closestDistance = Double.POSITIVE_INFINITY;

	for (int i = 0; i < entries.length; ++i) {
		final Map.Entry<double[], Integer> candidate = entries[i];
		final double candidateDistance = distance.compute(featureVector, candidate.getKey());
		// The very first candidate is always accepted, whatever its distance
		final boolean firstCandidate = closest == null;

		if (firstCandidate || candidateDistance < closestDistance) {
			closest = candidate;
			closestDistance = candidateDistance;
		}
	}

	return closest;
}
/**
 * Reads a LibSVM-formatted text file line by line and passes each parsed entry to
 * <code>processor</code>.
 * <br>Each line is split on whitespace and colons: the first token is parsed as a number and
 * truncated to an int to form the label, the remaining tokens are parsed as doubles into a flat
 * array of (index, value) pairs which is reordered with {@link #orderFeatures(double[])}.
 * <br>Blank lines are skipped and whitespace surrounding a line is ignored.
 * <br>An approximate progress indicator (characters read / file size in bytes) is printed
 * on the standard output.
 *
 * @param filePath
 * <br>Must not be null
 * @param processor
 * <br>Must not be null
 * @throws RuntimeException (as produced by <code>unchecked</code>) if the file cannot be read,
 * if a line is malformed or if <code>processor</code> fails
 */
public static final void readLibSVMFile(final String filePath, final FeatureVectorProcessor processor) {
	try {
		final File file = new File(filePath);
		final Scanner scanner = new Scanner(file);
		final long total = file.length();
		long done = 0L;

		try {
			// hasNextLine() is the test that matches nextLine(); hasNext() looks for a token
			// and lets blank lines in the middle of the file reach the parser
			while (scanner.hasNextLine()) {
				System.out.print(done + "/" + total + " \r");

				final String rawLine = scanner.nextLine();
				final String line = rawLine.trim();

				done += rawLine.length();

				if (line.isEmpty()) {
					continue;
				}

				final String[] labelAndFeatures = line.split("\\s+|:");
				final int n = labelAndFeatures.length;

				// After the label, tokens must come in (index, value) pairs
				if ((n - 1) % 2 != 0) {
					throw new IllegalArgumentException("Malformed LibSVM line (unpaired feature index): " + line);
				}

				final double[] featureVector = new double[n - 1];

				for (int i = 1; i < n; ++i) {
					featureVector[i - 1] = parseDouble(labelAndFeatures[i]);
				}

				processor.process((int) parseDouble(labelAndFeatures[0]), orderFeatures(featureVector));
			}

			// Scanner swallows I/O errors and simply stops; surface them instead of
			// silently working on a truncated data set
			final java.io.IOException readError = scanner.ioException();

			if (readError != null) {
				throw readError;
			}
		} finally {
			scanner.close();
		}
	} catch (final Exception exception) {
		throw unchecked(exception);
	}
}
/**
 * Loads a LibSVM file into a sorted map associating each feature vector with its label.
 * <br>Feature vectors are compared with <code>FeatureVectorComparator.INSTANCE</code>; when the same
 * vector appears several times, the last label read wins and a warning is printed on the
 * standard error output if it differs from the previous one.
 *
 * @param trainingFilePath
 * <br>Must not be null
 * @return
 * <br>Not null
 * <br>New
 */
public static final Map<double[], Integer> newTrainingBase(final String trainingFilePath) {
	final Map<double[], Integer> trainingBase = new TreeMap<double[], Integer>(FeatureVectorComparator.INSTANCE);
	final FeatureVectorProcessor collector = new FeatureVectorProcessor() {

		@Override
		public final void process(final int label, final double[] featureVector) {
			final Integer previousLabel = trainingBase.put(featureVector, label);

			if (previousLabel == null) {
				return;
			}

			// Same feature vector seen earlier with another label
			if (previousLabel.intValue() != label) {
				System.err.println("Warning: noninjective labeling");
			}
		}

	};

	readLibSVMFile(trainingFilePath, collector);

	return trainingBase;
}
/**
 * Returns a copy of a feature vector, laid out as a flat array of (index, value) pairs,
 * whose pairs are rearranged in the order defined by <code>FeatureIndexComparator</code>.
 * <br>The argument is not modified.
 *
 * @param featureVector
 * <br>Must not be null
 * @return
 * <br>Not null
 * <br>New
 */
public static final double[] orderFeatures(final double[] featureVector) {
	final int length = featureVector.length;
	// Offsets (0, 2, 4, ...) of the pairs in the input array
	final Integer[] pairOffsets = integerRange(0, 2, length - 1);

	sort(pairOffsets, new FeatureIndexComparator(featureVector));

	final double[] ordered = new double[length];
	int target = 0;

	for (final Integer pairOffset : pairOffsets) {
		final int source = pairOffset;

		ordered[target] = featureVector[source];
		ordered[target + 1] = featureVector[source + 1];
		target += 2;
	}

	return ordered;
}
/**
 * Creates the arithmetic sequence <code>first, first + step, first + 2 * step, ...</code>
 * made of the <code>(step + last - first) / step</code> first terms (integer division),
 * that is, the terms that do not go past <code>last</code>.
 * <br>An empty range (for instance <code>last &lt; first</code> with a positive step) always
 * produces an empty array.
 *
 * @param first
 * <br>First value of the sequence
 * @param step
 * <br>Must not be 0
 * @param last
 * <br>Inclusive bound of the sequence
 * @return
 * <br>Not null
 * <br>New
 * @throws ArithmeticException if <code>step</code> is 0
 */
public static final Integer[] integerRange(final int first, final int step, final int last) {
	// The raw count is negative when the range is empty by more than one step;
	// clamp it so that allocation does not fail with NegativeArraySizeException
	final int n = Math.max(0, (step + last - first) / step);
	final Integer[] result = new Integer[n];

	for (int i = 0; i < n; ++i) {
		result[i] = first + step * i;
	}

	return result;
}
/**
 * Distance between feature vectors in which each squared difference of feature values is
 * divided by the variance observed for that feature, the accumulated sum being finally
 * square-rooted.
 * <br>Per-feature statistics must be collected with {@link #add(double[])} before computing distances.
 * <br>NOTE(review): the way the <code>updateResult</code> methods are invoked is defined by
 * <code>FeatureVectorDistance</code>, which is not visible here; confirm against that class.
 *
 * @author codistmonk (creation 2013-07-02)
 */
public static final class FeatureVectorNormalizedEuclideanDistance extends FeatureVectorDistance {

	// Statistics of the values seen for each feature index
	private final Map<Double, Statistics> statistics = new HashMap<Double, Statistics>();

	/**
	 * Accumulates the values of a feature vector, laid out as a flat array of (index, value)
	 * pairs, into the per-feature statistics.
	 *
	 * @param featureVector
	 * <br>Must not be null
	 */
	public final void add(final double[] featureVector) {
		int offset = 0;

		while (offset < featureVector.length) {
			getOrCreate(this.statistics, featureVector[offset], STATISTICS_FACTORY).addValue(featureVector[offset + 1]);
			offset += 2;
		}
	}

	/**
	 * Adds to <code>oldResult</code> the squared difference of the two values divided by the
	 * variance recorded for <code>featureIndex</code>.
	 * <br>Features without statistics or with a zero variance leave the result unchanged.
	 */
	@Override
	protected final double updateResult(final double oldResult,
			final double featureIndex, final double featureValue1, final double featureValue2) {
		final Statistics featureStatistics = this.statistics.get(featureIndex);

		if (featureStatistics == null) {
			return oldResult;
		}

		final double variance = featureStatistics.getVariance();

		if (variance == 0.0) {
			return oldResult;
		}

		return oldResult + square(featureValue2 - featureValue1) / variance;
	}

	/**
	 * Returns the square root of the accumulated sum.
	 */
	@Override
	protected final double updateResult(final double oldResult) {
		return sqrt(oldResult);
	}

}
}