  }

  Dictionary newsGroups = new Dictionary();
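
  // On-line learner: 20 newsgroup categories, FEATURES-dimensional hashed feature
  // vectors, an L1 prior, and two hash probes per encoded feature to soften collisions.
  // AdaptiveLogisticRegression trains a pool of candidate learners with different
  // hyper-parameter settings and keeps track of the best one.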
  encoder.setProbes(2);
  AdaptiveLogisticRegression learningAlgorithm =
      new AdaptiveLogisticRegression(20, FEATURES, new L1());
  learningAlgorithm.setInterval(800);
  learningAlgorithm.setAveragingWindow(500);
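
  // Gather the training files: each subdirectory of the corpus is one newsgroup,
  // and the directory name doubles as the class label.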
  List<File> files = Lists.newArrayList();
  File[] directories = base.listFiles();
  Arrays.sort(directories, Ordering.usingToString());
  for (File newsgroup : directories) {
    if (newsgroup.isDirectory()) {
      newsGroups.intern(newsgroup.getName());
      files.addAll(Arrays.asList(newsgroup.listFiles()));
    }
  }
  Collections.shuffle(files);
  System.out.printf("%d training files\n", files.size());
  System.out.printf("%s\n", Arrays.asList(directories));
  double averageLL = 0;
  double averageCorrect = 0;

  int k = 0;
  double step = 0;
  int[] bumps = {1, 2, 5};
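
  // Train on one message at a time; the parent directory gives the true newsgroup.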
  for (File file : files) {
    String ng = file.getParentFile().getName();
    int actual = newsGroups.intern(ng);

    Vector v = encodeFeatureVector(file);
    learningAlgorithm.train(actual, v);

    k++;
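
    // Current checkpoint spacing: bump * scale steps through the 1-2-5 pattern.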
    int bump = bumps[(int) Math.floor(step) % bumps.length];
    int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
    State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best =
        learningAlgorithm.getBest();
    double maxBeta;
    double nonZeros;
    double positive;
    double norm;
    double lambda = 0;
    double mu = 0;
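
    // getBest() can still be null early in training, so diagnostics are only
    // collected once a best candidate learner is available.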
    if (best != null) {
      CrossFoldLearner state = best.getPayload().getLearner();
      averageCorrect = state.percentCorrect();
      averageLL = state.logLikelihood();

      OnlineLogisticRegression model = state.getModels().get(0);
      // finish off pending regularization
      model.close();
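
      // Summarize the coefficient matrix: largest absolute weight, count of
      // (effectively) non-zero weights, count of positive weights, and the L1 norm.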
      Matrix beta = model.getBeta();
      maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
      nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
        @Override
        public double apply(double v) {
          return Math.abs(v) > 1.0e-6 ? 1 : 0;
        }
      });
      positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
        @Override
        public double apply(double v) {
          return v > 0 ? 1 : 0;
        }
      });
      norm = beta.aggregate(Functions.PLUS, Functions.ABS);
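
      // The first two mapped hyper-parameters of the winning candidate; they are
      // printed below, alongside the accuracy figures, as lambda and mu.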
      lambda = learningAlgorithm.getBest().getMappedParams()[0];
      mu = learningAlgorithm.getBest().getMappedParams()[1];
    } else {
      maxBeta = 0;
      nonZeros = 0;
      positive = 0;
      norm = 0;
    }
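
    // At each checkpoint, snapshot the current best sub-model to /tmp and print
    // one row of progress statistics.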
    if (k % (bump * scale) == 0) {
      if (learningAlgorithm.getBest() != null) {
        ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
            learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
      }

      step += 0.25;
      System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t",
          maxBeta, nonZeros, positive, norm, lambda, mu);
      System.out.printf("%d\t%.3f\t%.2f\t%s\n",
          k, averageLL, averageCorrect * 100, LEAK_LABELS[leakType % 3]);
    }
  }
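
  // Shut down the learner, dissect the final model, and save the best sub-model
  // for later classification runs.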
  learningAlgorithm.close();
  dissect(newsGroups, learningAlgorithm, files);
  System.out.println("exiting main");

  ModelSerializer.writeBinary("/tmp/news-group.model",
      learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
}