// Read the run configuration from the command line.
// NOTE(review): the second argument of get() is presumably the fallback used when the option is absent -- confirm.
final CommandLineArgumentsParser arguments = new CommandLineArgumentsParser(commandLineArguments);
final String trainingFilePath = arguments.get("train", "");
final String testingFilePath = arguments.get("test", "");
System.out.println("Creating training base...");
// Training feature vectors mapped to an Integer; that Integer is later compared with the label of each test vector.
final Map<double[], Integer> trainingBase = newTrainingBase(trainingFilePath);
// Receives one value per test vector (1.0 or 0.0); its count and mean are printed at the end.
final Statistics statistics = new Statistics();
// Upper-cased so it can be compared with "NORMALIZED_EUCLIDEAN" and handed to valueOf() further down.
final String distanceName = arguments.get("distance", "manhattan").toUpperCase(Locale.ENGLISH);
// Assigned exactly once further down, depending on distanceName.
final Distance<double[]> distance;
// For int options the result of get() is indexed, and only its first element is used here.
// A positive randomReduce enables the per-label reduction step; a positive q enables quantization.
final int randomReduce = arguments.get("randomReduce", 0)[0];
final int quantization = arguments.get("q", 0)[0];
// Optional quantization pass: runs only when a positive "q" value was supplied.
if (quantization > 0) {
    System.out.println("Creating quantizers...");
    // Builds the quantizer for a feature index the first time that index is encountered.
    final Factory<Quantizer> perFeatureFactory = new Factory<Quantizer>() {
        @Override
        public final Quantizer newInstance() {
            return new HierarchicalAdaptiveQuantizer(quantization);
        }
    };
    final Map<Double, Quantizer> quantizerByFeature = new TreeMap<Double, Quantizer>();
    // The training base is not modified until both loops are done, so its size is read once.
    final int baseSize = trainingBase.size();
    int processed = 0;
    // First pass: feed every feature value to the quantizer of its feature index.
    for (final double[] vector : trainingBase.keySet()) {
        System.out.print(processed + "/" + baseSize + "\r");
        ++processed;
        // A vector is walked as consecutive (feature index, feature value) pairs.
        int offset = 0;
        while (offset < vector.length) {
            final Double pairIndex = vector[offset];
            final Double pairValue = vector[offset + 1];
            getOrCreate(quantizerByFeature, pairIndex, perFeatureFactory).add(pairValue, 1.0);
            offset += 2;
        }
    }
    System.out.println("Quantizing...");
    // Second pass: rebuild the base with quantized vectors as keys, ordered by FeatureVectorComparator.
    final Map<double[], Integer> quantizedBase =
            new TreeMap<double[], Integer>(FeatureVectorComparator.INSTANCE);
    processed = 0;
    for (final Map.Entry<double[], Integer> sample : trainingBase.entrySet()) {
        System.out.print(processed + "/" + baseSize + "\r");
        ++processed;
        final double[] quantizedVector = quantize(sample.getKey(), quantizerByFeature);
        quantizedBase.put(quantizedVector, sample.getValue());
    }
    // Replace the content of the original map in place so later code keeps using trainingBase.
    trainingBase.clear();
    trainingBase.putAll(quantizedBase);
}
// Optional reduction: keep at most randomReduce training entries for each label.
if (randomReduce > 0) {
    System.out.println("Reducing...");
    // Group the training entries by their label.
    final Map<Integer, List<Map.Entry<double[], Integer>>> entriesByLabel =
            new HashMap<Integer, List<Map.Entry<double[], Integer>>>();
    for (final Map.Entry<double[], Integer> sample : trainingBase.entrySet()) {
        getOrCreate(entriesByLabel, sample.getValue(), (Factory) ARRAY_LIST_FACTORY).add(sample);
    }
    for (final List<Map.Entry<double[], Integer>> bucket : entriesByLabel.values()) {
        final int bucketSize = bucket.size();
        if (bucketSize <= randomReduce) {
            // Nothing to drop for this label.
            continue;
        }
        // The generator is seeded with the number of surplus entries, not with the clock.
        shuffle(bucket, new Random(bucketSize - randomReduce));
        // Everything after the first randomReduce shuffled entries is removed from the base.
        for (final Map.Entry<double[], Integer> surplus : bucket.subList(randomReduce, bucketSize)) {
            trainingBase.remove(surplus.getKey());
        }
    }
}
// Copy the training entries into an array; the nearest-neighbour search below works on this array.
final Map.Entry<double[], Integer>[] entries =
        trainingBase.entrySet().toArray(new Map.Entry[trainingBase.size()]);
if (!"NORMALIZED_EUCLIDEAN".equals(distanceName)) {
    // Every other name is resolved through PredefinedFeatureVectorDistance.valueOf.
    distance = PredefinedFeatureVectorDistance.valueOf(distanceName);
} else {
    System.out.println("Initializing distance...");
    // This distance is given every training vector before it is used.
    final FeatureVectorNormalizedEuclideanDistance normalizedEuclidean =
            new FeatureVectorNormalizedEuclideanDistance();
    for (final Map.Entry<double[], Integer> sample : entries) {
        normalizedEuclidean.add(sample.getKey());
    }
    distance = normalizedEuclidean;
}
System.out.println("trainingEntryCount: " + entries.length);
System.out.println("Testing...");
// Pool size comes from the "threads" option; the available processor count is passed as its second argument.
final int workerCount = arguments.get("threads", SystemProperties.getAvailableProcessorCount())[0];
final ExecutorService pool = Executors.newFixedThreadPool(workerCount);
try {
    final Collection<Future<Double>> outcomes = new ArrayList<Future<Double>>();
    // One classification task is submitted for each vector handed over by readLibSVMFile.
    readLibSVMFile(testingFilePath, new FeatureVectorProcessor() {
        @Override
        public final void process(final int label, final double[] featureVector) {
            // Sanity check (only active with -ea): a vector must be at distance zero from itself.
            assert 0.0 == distance.compute(featureVector, featureVector);
            final Callable<Double> classification = new Callable<Double>() {
                @Override
                public final Double call() {
                    // Scores 1.0 when the nearest training entry carries the expected label, else 0.0.
                    final boolean correct = label == nearestNeighbor(featureVector, distance, entries).getValue();
                    return correct ? 1.0 : 0.0;
                }
            };
            outcomes.add(pool.submit(classification));
        }
    });
    // Collect the scores in submission order; get() blocks until the corresponding task is done.
    int collected = 0;
    for (final Future<Double> outcome : outcomes) {
        System.out.print(collected + "/" + outcomes.size() + " \r");
        ++collected;
        statistics.addValue(outcome.get());
    }
    System.out.println("testingEntryCount: " + statistics.getCount());
    System.out.println("accuracy:" + statistics.getMean());
} finally {
    // Always stop accepting new work, even if reading or collecting failed.
    pool.shutdown();
}
}