Package net.myrrix.online.eval

Source Code of net.myrrix.online.eval.AUCEvaluator

/*
* Copyright Myrrix Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package net.myrrix.online.eval;

import java.io.File;
import java.util.Collection;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;

import com.google.common.collect.Multimap;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.mahout.cf.taste.common.NoSuchItemException;
import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.myrrix.common.MyrrixRecommender;
import net.myrrix.common.collection.FastByIDMap;
import net.myrrix.common.collection.FastIDSet;
import net.myrrix.common.parallel.Paralleler;
import net.myrrix.common.parallel.Processor;
import net.myrrix.common.random.RandomManager;
import net.myrrix.common.random.RandomUtils;
import net.myrrix.online.RescorerProvider;

/**
* <p>This implementation calculates Area under curve (AUC), which may be understood as the probability
* that a random "good" recommendation is ranked higher than a random "bad" recommendation.</p>
*
* <p>This class can be run as a Java program; the single argument is a directory containing test data.
* The {@link EvaluationResult} is printed to standard out.</p>
*
* @author Sean Owen
* @since 1.0
*/
public final class AUCEvaluator extends AbstractEvaluator {

  private static final Logger log = LoggerFactory.getLogger(AUCEvaluator.class);

  @Override
  protected boolean isSplitTestByPrefValue() {
    return true;
  }

  @Override
  public EvaluationResult evaluate(MyrrixRecommender recommender,
                                   RescorerProvider provider, // ignored
                                   Multimap<Long,RecommendedItem> testData) throws TasteException {
    FastByIDMap<FastIDSet> converted = new FastByIDMap<FastIDSet>(testData.size());
    for (long userID : testData.keySet()) {
      Collection<RecommendedItem> userTestData = testData.get(userID);
      FastIDSet itemIDs = new FastIDSet(userTestData.size());
      converted.put(userID, itemIDs);
      for (RecommendedItem datum : userTestData) {
        itemIDs.add(datum.getItemID());
      }
    }
    return evaluate(recommender, converted);
  }

  public EvaluationResult evaluate(final MyrrixRecommender recommender,
                                   final FastByIDMap<FastIDSet> testData) throws TasteException {

    final AtomicInteger underCurve = new AtomicInteger(0);
    final AtomicInteger total = new AtomicInteger(0);
   
    final long[] allItemIDs = recommender.getAllItemIDs().toArray();
   
    Processor<Long> processor = new Processor<Long>() {
      private final RandomGenerator random = RandomManager.getRandom();     
      @Override
      public void process(Long userID, long count) throws ExecutionException {
        FastIDSet testItemIDs = testData.get(userID);
        int numTest = testItemIDs.size()
        for (int i = 0; i < numTest; i++) {
 
          long randomTestItemID;
          long randomTrainingItemID;
          synchronized (random) {
            randomTestItemID = RandomUtils.randomFrom(testItemIDs, random);
            do {
              randomTrainingItemID = allItemIDs[random.nextInt(allItemIDs.length)];
            } while (testItemIDs.contains(randomTrainingItemID));
          }
 
          float relevantEstimate;
          float nonRelevantEstimate;
          try {
            relevantEstimate = recommender.estimatePreference(userID, randomTestItemID);
            nonRelevantEstimate = recommender.estimatePreference(userID, randomTrainingItemID);           
          } catch (NoSuchItemException nsie) {
            // OK; it's possible item only showed up in test split
            continue;
          } catch (NoSuchUserException nsie) {
            // OK; it's possible user only showed up in test split
            continue;
          } catch (TasteException te) {
            throw new ExecutionException(te);
          }


          if (relevantEstimate > nonRelevantEstimate) {
            underCurve.incrementAndGet();
          }
          total.incrementAndGet();
         
          if (count % 100000 == 0) {
            log.info("AUC: {}", (double) underCurve.get() / total.get());
          }
        }
      }
    };
   
    try {
      new Paralleler<Long>(testData.keySetIterator(), processor, "AUCEval").runInParallel();
    } catch (InterruptedException ie) {
      throw new TasteException(ie);
    } catch (ExecutionException e) {
      throw new TasteException(e.getCause());
    }

    double score = (double) underCurve.get() / total.get();
    log.info("AUC: {}", score);
    return new EvaluationResultImpl(score);
  }

  public static void main(String[] args) throws Exception {
    EvaluationResult result = new AUCEvaluator().evaluate(new File(args[0]));
    log.info(String.valueOf(result));
  }

}
TOP

Related Classes of net.myrrix.online.eval.AUCEvaluator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.