Package io.prediction.examples.java.recommendations.tutorial3

Source Code of io.prediction.examples.java.recommendations.tutorial3.DataSource

package io.prediction.examples.java.recommendations.tutorial3;

import io.prediction.examples.java.recommendations.tutorial1.TrainingData;
import io.prediction.examples.java.recommendations.tutorial1.Query;
import io.prediction.examples.java.recommendations.tutorial1.DataSourceParams;

import io.prediction.controller.java.LJavaDataSource;
import scala.Tuple2;
import scala.Tuple3;
import java.io.File;
import java.io.FileNotFoundException;
import java.lang.Iterable;
import java.util.List;
import java.util.ArrayList;
import java.util.Scanner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Random;
import java.util.Collections;

public class DataSource extends LJavaDataSource<
  DataSourceParams, Object, TrainingData, Query, Float> {

  final static Logger logger = LoggerFactory.getLogger(DataSource.class);

  DataSourceParams params;

  public DataSource(DataSourceParams params) {
    this.params = params;
  }

  @Override
  public Iterable<Tuple3<Object, TrainingData, Iterable<Tuple2<Query, Float>>>> read() {

    File ratingFile = new File(params.filePath);
    Scanner sc = null;

    try {
      sc = new Scanner(ratingFile);
    } catch (FileNotFoundException e) {
      logger.error("Caught FileNotFoundException " + e.getMessage());
      System.exit(1);
    }

    List<TrainingData.Rating> ratings = new ArrayList<TrainingData.Rating>();

    while (sc.hasNext()) {
      String line = sc.nextLine();
      String[] tokens = line.split("[\t,]");
      try {
        TrainingData.Rating rating = new TrainingData.Rating(
          Integer.parseInt(tokens[0]),
          Integer.parseInt(tokens[1]),
          Float.parseFloat(tokens[2]));
        ratings.add(rating);
      } catch (Exception e) {
        logger.error("Can't parse rating file. Caught Exception: " + e.getMessage());
        System.exit(1);
      }
    }

    int size = ratings.size();
    float trainingPercentage = 0.8f;
    float testPercentage = 1 - trainingPercentage;
    int iterations = 3;

    // cap by original size
    int trainingEndIndex = Math.min(size,
      (int) (ratings.size() * trainingPercentage));
    int testEndIndex = Math.min(size,
      trainingEndIndex + (int) (ratings.size() * testPercentage));
      // trainingEndIndex + 10);

    Random rand = new Random(0); // seed

    List<Tuple3<Object, TrainingData, Iterable<Tuple2<Query, Float>>>> data = new
      ArrayList<Tuple3<Object, TrainingData, Iterable<Tuple2<Query, Float>>>>();

    for (int i = 0; i < iterations; i++) {
      Collections.shuffle(ratings, new Random(rand.nextInt()));

      // create a new ArrayList because subList() returns view and not serialzable
      List<TrainingData.Rating> trainingRatings =
        new ArrayList<TrainingData.Rating>(ratings.subList(0, trainingEndIndex));
      List<TrainingData.Rating> testRatings = ratings.subList(trainingEndIndex, testEndIndex);
      TrainingData td = new TrainingData(trainingRatings);
      List<Tuple2<Query, Float>> qaList = prepareValidation(testRatings);

      data.add(new Tuple3<Object, TrainingData, Iterable<Tuple2<Query, Float>>>(
        null, td, qaList));
    }

    return data;
  }

  private List<Tuple2<Query, Float>> prepareValidation(List<TrainingData.Rating> testRatings) {
    List<Tuple2<Query, Float>> validationList = new ArrayList<Tuple2<Query, Float>>();

    for (TrainingData.Rating r : testRatings) {
      validationList.add(new Tuple2<Query, Float>(
        new Query(r.uid, r.iid),
        r.rating));
    }

    return validationList;
  }

}
TOP

Related Classes of io.prediction.examples.java.recommendations.tutorial3.DataSource

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by ORACLE Inc. Contact coftware#gmail.com.