Package org.apache.mahout.cf.taste.impl.recommender.svd

Source Code of org.apache.mahout.cf.taste.impl.recommender.svd.ImplicitLinearRegressionFactorizer

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.cf.taste.impl.recommender.svd;

import java.util.Collection;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import com.google.common.collect.Lists;
import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DiagonalMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.QRDecomposition;
import org.apache.mahout.math.SparseMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Deprecated
public final class ImplicitLinearRegressionFactorizer extends AbstractFactorizer {

  private static final Logger log = LoggerFactory.getLogger(ImplicitLinearRegressionFactorizer.class);
  private final double preventOverfitting;
  /** number of features used to compute this factorization */
  private final int numFeatures;
  /** number of iterations */
  private final int numIterations;
  private final DataModel dataModel;
  /** User singular vector. */
  private double[][] userMatrix;
  /** Item singular vector. */
  private double[][] itemMatrix;
  private Matrix userTransUser;
  private Matrix itemTransItem;
  private Collection<Callable<Void>> fVectorCallables;
  private boolean recomputeUserFeatures;
  private RunningAverage avrChange;

  public ImplicitLinearRegressionFactorizer(DataModel dataModel) throws TasteException {
    this(dataModel, 64, 10, 0.1);
  }

  public ImplicitLinearRegressionFactorizer(DataModel dataModel, int numFeatures, int numIterations,
                                            double preventOverfitting) throws TasteException {

    super(dataModel);
    this.dataModel = dataModel;
    this.numFeatures = numFeatures;
    this.numIterations = numIterations;
    this.preventOverfitting = preventOverfitting;
    fVectorCallables = Lists.newArrayList();
    avrChange = new FullRunningAverage();
  }

  @Override
  public Factorization factorize() throws TasteException {
    Random random = RandomUtils.getRandom();
    userMatrix = new double[dataModel.getNumUsers()][numFeatures];
    itemMatrix = new double[dataModel.getNumItems()][numFeatures];

    /* start with the user side */
    recomputeUserFeatures = true;

    double average = getAveragePreference();

    double prefInterval = dataModel.getMaxPreference() - dataModel.getMinPreference();
    double defaultValue = Math.sqrt((average - prefInterval * 0.1) / numFeatures);
    double interval = prefInterval * 0.1 / numFeatures;

    for (int feature = 0; feature < numFeatures; feature++) {
      for (int userIndex = 0; userIndex < dataModel.getNumUsers(); userIndex++) {
        userMatrix[userIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * random.nextDouble();
      }
      for (int itemIndex = 0; itemIndex < dataModel.getNumItems(); itemIndex++) {
        itemMatrix[itemIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * random.nextDouble();
      }
    }
    train();
    return createFactorization(userMatrix, itemMatrix);
  }

  public void train() throws TasteException {
    for (int i = 0; i < numIterations; i++) {
      if (recomputeUserFeatures) {
        LongPrimitiveIterator userIds = dataModel.getUserIDs();
        /* start with calculating X^TX or Y^TX */
        log.info("Calculating Y^TY");
        reCalculateTrans(recomputeUserFeatures);
        log.info("Building callables for users.");
        while (userIds.hasNext()) {
          long userId = userIds.nextLong();
          int useridx = userIndex(userId);
          buildCallables(buildConfidenceMatrixForUser(userId), buildPreferenceVectorForUser(userId), useridx);
        }
        finishProcessing();
      } else {
        LongPrimitiveIterator itemIds = dataModel.getItemIDs();
        /* start with calculating X^TX or Y^TX */
        log.info("Calculating X^TX");
        reCalculateTrans(recomputeUserFeatures);
        log.info("Building callables for items.");
        while (itemIds.hasNext()) {
          long itemId = itemIds.nextLong();
          int itemidx = itemIndex(itemId);
          buildCallables(buildConfidenceMatrixForItem(itemId), buildPreferenceVectorForItem(itemId), itemidx);
        }
        finishProcessing();
      }
    }
  }

  public Matrix buildPreferenceVectorForUser(long realId) throws TasteException {
    Matrix ids = new SparseMatrix(1, dataModel.getNumItems());
    for (Preference pref : dataModel.getPreferencesFromUser(realId)) {
      int itemidx = itemIndex(pref.getItemID());
      ids.setQuick(0, itemidx, pref.getValue());
    }
    return ids;
  }

  private Matrix buildConfidenceMatrixForItem(long itemId) throws TasteException {
    PreferenceArray prefs = dataModel.getPreferencesForItem(itemId);
    Matrix confidenceMatrix = new SparseMatrix(dataModel.getNumUsers(), dataModel.getNumUsers());
    for (Preference pref : prefs) {
      long userId = pref.getUserID();
      int userIdx = userIndex(userId);
      confidenceMatrix.setQuick(userIdx, userIdx, 1);
    }
    return new DiagonalMatrix(confidenceMatrix);
  }

  private Matrix buildConfidenceMatrixForUser(long userId) throws TasteException {
    PreferenceArray prefs = dataModel.getPreferencesFromUser(userId);
    Matrix confidenceMatrix = new SparseMatrix(dataModel.getNumItems(), dataModel.getNumItems());
    for (Preference pref : prefs) {
      long itemId = pref.getItemID();
      int itemIdx = itemIndex(itemId);
      confidenceMatrix.setQuick(itemIdx, itemIdx, 1);
    }
    return new DiagonalMatrix(confidenceMatrix);
  }

  private Matrix buildPreferenceVectorForItem(long realId) throws TasteException {
    Matrix ids = new SparseMatrix(1, dataModel.getNumUsers());
    for (Preference pref : dataModel.getPreferencesForItem(realId)) {
      int useridx = userIndex(pref.getUserID());
      ids.setQuick(0, useridx, pref.getValue());
    }
    return ids;
  }

  private static Matrix ones(int size) {
    double[] vector = new double[size];
    for (int i = 0; i < size; i++) {
      vector[i] = 1;
    }
    return new DiagonalMatrix(vector);
  }

  private double getAveragePreference() throws TasteException {
    RunningAverage average = new FullRunningAverage();
    LongPrimitiveIterator it = dataModel.getUserIDs();
    while (it.hasNext()) {
      int count = 0;
      try {
        PreferenceArray prefs = dataModel.getPreferencesFromUser(it.nextLong());
        for (Preference pref : prefs) {
          average.addDatum(pref.getValue());
          count++;
        }
      } catch (NoSuchUserException ex) {
        continue;
      }
      /* add the remaining zeros */
      for (int i = 0; i < (dataModel.getNumItems() - count); i++) {
        average.addDatum(0);
      }
    }
    return average.getAverage();
  }

  /**
   * Recalculating Y^TY or X^TX which is needed for further calculations
   * @param recomputeUserFeatures
   */
  public void reCalculateTrans(boolean recomputeUserFeatures) {
    if (recomputeUserFeatures) {
      Matrix iMatrix = new DenseMatrix(itemMatrix);
      itemTransItem = iMatrix.transpose().times(iMatrix);
    } else {
      Matrix uMatrix = new DenseMatrix(userMatrix);
      userTransUser = uMatrix.transpose().times(uMatrix);
    }
  }

  private synchronized void updateMatrix(int id, Matrix m) {
    double normA = 0;
    double normB = 0;
    double aTb = 0;
    for (int feature = 0; feature < numFeatures; feature++) {
      if (recomputeUserFeatures) {
        normA += userMatrix[id][feature] * userMatrix[id][feature];
        normB += m.get(feature, 0) * m.get(feature, 0);
        aTb += userMatrix[id][feature] * m.get(feature, 0);
        userMatrix[id][feature] = m.get(feature, 0);
      } else {
        normA += itemMatrix[id][feature] * itemMatrix[id][feature];
        normB += m.get(feature, 0) * m.get(feature, 0);
        aTb += itemMatrix[id][feature] * m.get(feature, 0);
        itemMatrix[id][feature] = m.get(feature, 0);
      }
    }
    /* calculating cosine similarity to determine when to stop the algorithm,
    this could be used to detect convergence */
    double cosine = aTb / (Math.sqrt(normA) * Math.sqrt(normB));
    if (Double.isNaN(cosine)) {
      log.info("Cosine similarity is NaN, recomputeUserFeatures={} id={}", recomputeUserFeatures, id);
    } else {
      avrChange.addDatum(cosine);
    }
  }

  public void resetCallables() {
    fVectorCallables = Lists.newArrayList();
  }

  private void resetAvrChange() {
    log.info("Avr Change: {}", avrChange.getAverage());
    avrChange = new FullRunningAverage();
  }

  public void buildCallables(Matrix C, Matrix prefVector, int id) {
    fVectorCallables.add(new FeatureVectorCallable(C, prefVector, id));
    if (fVectorCallables.size() % (200 * Runtime.getRuntime().availableProcessors()) == 0) {
      execute(fVectorCallables);
      resetCallables();
    }
  }

  public void finishProcessing() {
    /* run the remaining part */
    if (fVectorCallables != null) {
      execute(fVectorCallables);
    }
    resetCallables();
    if ((recomputeUserFeatures && avrChange.getCount() != userMatrix.length)
        || (!recomputeUserFeatures && avrChange.getCount() != itemMatrix.length)) {
      log.info("Matrix length is not equal to count");
    }
    resetAvrChange();
    recomputeUserFeatures = !recomputeUserFeatures;
  }

  public static Matrix identityV(int size) {
    return ones(size);
  }

  static void execute(Collection<Callable<Void>> callables) {
    callables = wrapWithStatsCallables(callables);
    int numProcessors = Runtime.getRuntime().availableProcessors();
    ExecutorService executor = Executors.newFixedThreadPool(numProcessors);
    log.info("Starting timing of {} tasks in {} threads", callables.size(), numProcessors);
    try {
      List<Future<Void>> futures = executor.invokeAll(callables);
      //TODO go look for exceptions here, really
      for (Future<Void> future : futures) {
        future.get();
      }
    } catch (InterruptedException ie) {
      log.warn("error in factorization", ie);
    } catch (ExecutionException ee) {
      log.warn("error in factorization", ee);
    }
    executor.shutdown();
  }

  private static Collection<Callable<Void>> wrapWithStatsCallables(Collection<Callable<Void>> callables) {
    int size = callables.size();
    Collection<Callable<Void>> wrapped = Lists.newArrayListWithExpectedSize(size);
    int count = 1;
    RunningAverageAndStdDev timing = new FullRunningAverageAndStdDev();
    for (Callable<Void> callable : callables) {
      boolean logStats = count++ % 1000 == 0;
      wrapped.add(new StatsCallable(callable, logStats, timing));
    }
    return wrapped;
  }

  private final class FeatureVectorCallable implements Callable<Void> {

    private final Matrix C;
    private final Matrix prefVector;
    private final int id;

    private FeatureVectorCallable(Matrix C, Matrix prefVector, int id) {
      this.C = C;
      this.prefVector = prefVector;
      this.id = id;
    }

    @Override
    public Void call() throws Exception {
      Matrix XTCX;
      if (recomputeUserFeatures) {
        Matrix I = identityV(dataModel.getNumItems());
        Matrix I2 = identityV(numFeatures);
        Matrix iTi = itemTransItem.clone();
        Matrix itemM = new DenseMatrix(itemMatrix);
        XTCX = iTi.plus(itemM.transpose().times(C.minus(I)).times(itemM));

        Matrix diag = solve(XTCX.plus(I2.times(preventOverfitting)), I2);
        Matrix results = diag.times(itemM.transpose().times(C)).times(prefVector.transpose());
        updateMatrix(id, results);
      } else {
        Matrix I = identityV(dataModel.getNumUsers());
        Matrix I2 = identityV(numFeatures);
        Matrix uTu = userTransUser.clone();
        Matrix userM = new DenseMatrix(userMatrix);
        XTCX = uTu.plus(userM.transpose().times(C.minus(I)).times(userM));

        Matrix diag = solve(XTCX.plus(I2.times(preventOverfitting)), I2);
        Matrix results = diag.times(userM.transpose().times(C)).times(prefVector.transpose());
        updateMatrix(id, results);
      }
      return null;
    }
  }

  private static Matrix solve(Matrix A, Matrix y) {
    return new QRDecomposition(A).solve(y);
  }

  private static final class StatsCallable implements Callable<Void> {

    private final Callable<Void> delegate;
    private final boolean logStats;
    private final RunningAverageAndStdDev timing;

    private StatsCallable(Callable<Void> delegate, boolean logStats, RunningAverageAndStdDev timing) {
      this.delegate = delegate;
      this.logStats = logStats;
      this.timing = timing;
    }

    @Override
    public Void call() throws Exception {
      long start = System.currentTimeMillis();
      delegate.call();
      long end = System.currentTimeMillis();
      timing.addDatum(end - start);
      if (logStats) {
        Runtime runtime = Runtime.getRuntime();
        int average = (int) timing.getAverage();
        log.info("Average time per task: {}ms", average);
        long totalMemory = runtime.totalMemory();
        long memory = totalMemory - runtime.freeMemory();
        log.info("Approximate memory used: {}MB / {}MB", memory / 1000000L, totalMemory / 1000000L);
      }
      return null;
    }
  }
}
TOP

Related Classes of org.apache.mahout.cf.taste.impl.recommender.svd.ImplicitLinearRegressionFactorizer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.