/*
* Copyright Myrrix Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.myrrix.online.eval;
import java.io.File;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutionException;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.recommender.IDRescorer;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import net.myrrix.common.MyrrixRecommender;
import net.myrrix.common.parallel.Paralleler;
import net.myrrix.common.parallel.Processor;
import net.myrrix.online.RescorerProvider;
/**
 * <p>A simple evaluation framework for a recommender, which calculates precision, recall, F1,
 * normalized discounted cumulative gain (nDCG), mean average precision, and other basic
 * statistics.</p>
 *
 * <p>This class can be run as a Java program; the single argument is a directory containing test data.
 * The resulting {@link EvaluationResult} is logged at INFO level via SLF4J.</p>
 *
 * @author Sean Owen
 * @since 1.0
 */
public final class PrecisionRecallEvaluator extends AbstractEvaluator {

  private static final Logger log = LoggerFactory.getLogger(PrecisionRecallEvaluator.class);

  /** Natural log of 2; dividing by it converts {@code Math.log} into a base-2 logarithm. */
  private static final double LN2 = Math.log(2.0);

  /**
   * @return always {@code true}; presumably tells {@link AbstractEvaluator} to split test data by
   *  preference value -- NOTE(review): confirm against AbstractEvaluator
   */
  @Override
  protected boolean isSplitTestByPrefValue() {
    return true;
  }

  /**
   * Evaluates {@code recommender} against held-out test data. For each user in {@code testData},
   * requests as many recommendations as that user has test items, then scores the returned list by
   * precision, recall, nDCG and average precision. Each metric is averaged over all scored users.
   * Users with no test items, or for whom {@code recommend} throws, are skipped (the latter with a
   * warning). Running averages are logged whenever the processor's {@code count} is a multiple of
   * 10000. Users are processed in parallel unless system property {@code eval.parallel} is
   * {@code false}.
   *
   * @param recommender recommender under evaluation
   * @param provider optional source of a per-user recommend rescorer; may be {@code null}
   * @param testData held-out items keyed by user ID; every item listed for a user counts as relevant
   * @return the averaged statistics, or {@code null} if no user could be scored
   * @throws TasteException if evaluation is interrupted (the interrupt flag is restored first) or a
   *  worker fails with an unexpected exception
   */
  @Override
  public EvaluationResult evaluate(final MyrrixRecommender recommender,
                                   final RescorerProvider provider,
                                   final Multimap<Long,RecommendedItem> testData) throws TasteException {
    final Mean precision = new Mean();
    final Mean recall = new Mean();
    final Mean ndcg = new Mean();
    final Mean meanAveragePrecision = new Mean();

    Processor<Long> processor = new Processor<Long>() {
      @Override
      public void process(Long userID, long count) {
        Collection<RecommendedItem> values = testData.get(userID);
        int numValues = values.size();
        if (numValues == 0) {
          return;
        }

        IDRescorer rescorer =
            provider == null ? null : provider.getRecommendRescorer(new long[]{userID}, recommender);

        List<RecommendedItem> recs;
        try {
          // Request exactly as many recommendations as there are held-out items for this user
          recs = recommender.recommend(userID, numValues, rescorer);
        } catch (NoSuchUserException nsue) {
          // Probably OK, just removed all data for this user from training
          log.warn("User only in test data: {}", userID);
          return;
        } catch (TasteException te) {
          log.warn("Unexpected exception", te);
          return;
        }
        int numRecs = recs.size();

        Collection<Long> valueIDs = Sets.newHashSetWithExpectedSize(numValues);
        for (RecommendedItem rec : values) {
          valueIDs.add(rec.getItemID());
        }

        int intersectionSize = 0;
        double score = 0.0;
        double maxScore = 0.0;
        double averagePrecision = 0.0;
        for (int i = 0; i < numRecs; i++) {
          RecommendedItem rec = recs.get(i);
          double value = LN2 / Math.log(2.0 + i); // 1 / log_2(1 + (i+1))
          if (valueIDs.contains(rec.getItemID())) {
            intersectionSize++;
            score += value;
            // Precision at cutoff i+1, accumulated only at ranks that hold a relevant item
            averagePrecision += (double) intersectionSize / (i + 1);
          }
          maxScore += value;
        }
        averagePrecision /= numValues;

        // All four accumulators are shared across worker threads; the lock on 'precision' guards them all
        synchronized (precision) {
          precision.increment(numRecs == 0 ? 0.0 : (double) intersectionSize / numRecs);
          recall.increment((double) intersectionSize / numValues);
          ndcg.increment(maxScore == 0.0 ? 0.0 : score / maxScore);
          meanAveragePrecision.increment(averagePrecision);
          if (count % 10000 == 0) {
            log.info(new IRStatisticsImpl(precision.getResult(),
                                          recall.getResult(),
                                          ndcg.getResult(),
                                          meanAveragePrecision.getResult()).toString());
          }
        }
      }
    };

    Paralleler<Long> paralleler = new Paralleler<Long>(testData.keySet().iterator(), processor, "PREval");
    try {
      if (Boolean.parseBoolean(System.getProperty("eval.parallel", "true"))) {
        paralleler.runInParallel();
      } else {
        paralleler.runInSerial();
      }
    } catch (InterruptedException ie) {
      // Restore the interrupt status so code further up the stack can still observe it
      Thread.currentThread().interrupt();
      throw new TasteException(ie);
    } catch (ExecutionException e) {
      throw new TasteException(e.getCause());
    }

    EvaluationResult result;
    if (precision.getN() > 0) {
      result = new IRStatisticsImpl(precision.getResult(),
                                    recall.getResult(),
                                    ndcg.getResult(),
                                    meanAveragePrecision.getResult());
    } else {
      result = null;
    }
    log.info(String.valueOf(result));
    return result;
  }

  /**
   * Command-line entry point: evaluates against the test data directory given as the single
   * argument and logs the result.
   *
   * @param args {@code args[0]} is the directory containing test data
   * @throws IllegalArgumentException if no directory argument is supplied
   * @throws Exception if evaluation fails
   */
  public static void main(String[] args) throws Exception {
    if (args.length < 1) {
      throw new IllegalArgumentException("Usage: PrecisionRecallEvaluator [test data directory]");
    }
    EvaluationResult result = new PrecisionRecallEvaluator().evaluate(new File(args[0]));
    log.info(String.valueOf(result));
  }

}