this.D = D;
}
/**
 * Runs the experiment end to end: computes an ALS factorization of the input graph,
 * runs two DrunkardMob random-walk jobs ("positive" and "negative") started from the
 * first user vertices, and prints, for a sample of those users, the average ALS
 * prediction over the top-20 vertices reported by each job next to the average ALS
 * prediction over all right-side ("movie") vertices.
 *
 * @throws Exception if any of the underlying steps fails (the method declares the checked
 *                   exception; which of the calls raise it is not visible from here)
 */
protected void execute() throws Exception {
    /* Step 1. Compute ALS */
    // NOTE(review): the trailing 5 is presumably the number of ALS iterations -- confirm
    // against ALSMatrixFactorization.computeALS.
    ALSMatrixFactorization als = ALSMatrixFactorization.computeALS(baseFilename, nShards, D, 5);
    logger.info("Computed ALS, now random walks");

    /* Initialize drunkardmob */
    DrunkardMobEngine<Integer, Float> drunkardMobEngine = new DrunkardMobEngine<Integer, Float>(baseFilename, nShards, new IntDrunkardFactory());
    // Both jobs walk over in- and out-edges; each companion is given 1/8 of the maximum heap.
    DrunkardJob positiveJob = drunkardMobEngine.addJob("positive", EdgeDirection.IN_AND_OUT_EDGES,
            new PositiveWalkUpdate(), new IntDrunkardCompanion(2, Runtime.getRuntime().maxMemory() / 8));
    DrunkardJob negativeJob = drunkardMobEngine.addJob("negative", EdgeDirection.IN_AND_OUT_EDGES,
            new NegativeWalkUpdate(), new IntDrunkardCompanion(2, Runtime.getRuntime().maxMemory() / 8));
    drunkardMobEngine.setEdataConverter(new FloatConverter());

    /* Create list of user vertices (i.e. vertices on left). But we need to find their internal ids. */
    ALSMatrixFactorization.BipartiteGraphInfo graphInfo = als.getGraphInfo();
    VertexIdTranslate vertexIdTranslate = drunkardMobEngine.getVertexIdTranslate();
    ArrayList<Integer> userVertices = new ArrayList<Integer>(graphInfo.getNumLeft());

    int numUsers = 50000; // NOTE: hard-coded upper bound on the number of walk sources
    int walksPerSource = 1000;
    // Clamp to the number of left-side vertices so that no index beyond the left side is
    // translated and registered as a walk source. The result must be assigned; merely
    // calling getNumLeft() has no effect.
    if (numUsers > graphInfo.getNumLeft()) {
        numUsers = graphInfo.getNumLeft();
    }
    logger.info("Compute predictions for first " + numUsers + " users");
    for (int i = 0; i < numUsers; i++) {
        userVertices.add(vertexIdTranslate.forward(i));
    }

    /* Configure */
    positiveJob.configureWalkSources(userVertices, walksPerSource);
    negativeJob.configureWalkSources(userVertices, walksPerSource);

    /* Run */
    // NOTE(review): 6 is presumably the number of walk iterations -- confirm against DrunkardMobEngine.run.
    drunkardMobEngine.run(6);

    /* TODO: handle results */
    // Report on at most 500 users, and never on more users than were configured as walk sources.
    for (int i = 0, reportedUsers = Math.min(500, numUsers); i < reportedUsers; i++) {
        int userId = vertexIdTranslate.forward(i);
        IdCount[] posTop = positiveJob.getCompanion().getTop(userId, 20);
        IdCount[] negTop = negativeJob.getCompanion().getTop(userId, 20);

        double sumEstimatePos = 0.0;
        double sumEstimateNeg = 0.0;
        // Compare the same number of entries from both lists. If either list is empty,
        // n is 0 and the two averages below print as NaN.
        int n = Math.min(posTop.length, negTop.length);
        for (int j = 0; j < n; j++) {
            sumEstimatePos += als.predict(userId, posTop[j].id);
            sumEstimateNeg += als.predict(userId, negTop[j].id);
        }

        // The timer covers only the all-movies baseline below, not the top-list sums above.
        long t = System.currentTimeMillis();

        // Compute all
        double allSum = 0;
        int numMovies = graphInfo.getNumRight();
        for (int m = 0; m < numMovies; m++) {
            // Right-side vertices are addressed by indices following the left-side ones.
            int movieId = vertexIdTranslate.forward(graphInfo.getNumLeft() + m);
            allSum += als.predict(userId, movieId);
        }

        System.out.println(i + " avg pos: " + sumEstimatePos / n + "; avg neg: " + sumEstimateNeg / n + "; all="
                + allSum / numMovies + " (" + (System.currentTimeMillis() - t) + " ms for " + numMovies + " movies");
    }