Examples of ALSMatrixFactorization


Examples of edu.cmu.graphchi.apps.ALSMatrixFactorization

        this.D = D;
    }

    protected void execute() throws Exception {
        /* Step 1. Compute ALS */
        ALSMatrixFactorization als = ALSMatrixFactorization.computeALS(baseFilename, nShards, D, 5);

        logger.info("Computed ALS, now random walks");

        /* Initialize drunkardmob */
        DrunkardMobEngine<Integer, Float> drunkardMobEngine = new DrunkardMobEngine<Integer, Float>(baseFilename, nShards, new IntDrunkardFactory());
        DrunkardJob positiveJob = drunkardMobEngine.addJob("positive", EdgeDirection.IN_AND_OUT_EDGES,
                new PositiveWalkUpdate(), new IntDrunkardCompanion(2, Runtime.getRuntime().maxMemory() / 8));
        DrunkardJob negativeJob = drunkardMobEngine.addJob("negative", EdgeDirection.IN_AND_OUT_EDGES,
                new NegativeWalkUpdate(), new IntDrunkardCompanion(2, Runtime.getRuntime().maxMemory() / 8));

        drunkardMobEngine.setEdataConverter(new FloatConverter());

        /* Create list of user vertices (i.e vertices on left). But we need to find their internal ids. */
        ALSMatrixFactorization.BipartiteGraphInfo graphInfo = als.getGraphInfo();
        VertexIdTranslate vertexIdTranslate = drunkardMobEngine.getVertexIdTranslate();
        ArrayList<Integer> userVertices = new ArrayList<Integer>(graphInfo.getNumLeft());

        int numUsers = 50000; // NOTE: hard-coded
        int walksPerSource = 1000;

        if (numUsers > graphInfo.getNumLeft())  graphInfo.getNumLeft();
        logger.info("Compute predictions for first " + numUsers + " users");

        for(int i=0; i< numUsers; i++) {
            userVertices.add(vertexIdTranslate.forward(i));
        }

        /* Configure */
        positiveJob.configureWalkSources(userVertices, walksPerSource);
        negativeJob.configureWalkSources(userVertices, walksPerSource);

       /* Run */
        drunkardMobEngine.run(6);


        /* TODO: handle results */
        for(int i=0; i< 500; i++) {
            int userId = vertexIdTranslate.forward(i);
            IdCount[] posTop = positiveJob.getCompanion().getTop(userId, 20);
            IdCount[] negTop =  negativeJob.getCompanion().getTop(userId, 20);

            double sumEstimatePos = 0.0;
            double sumEstimateNeg = 0.0;

            int n = Math.min(posTop.length, negTop.length);
            for(int j=0; j<n; j++) {
                sumEstimatePos += als.predict(userId, posTop[j].id);
                sumEstimateNeg += als.predict(userId, negTop[j].id);
            }


            long t = System.currentTimeMillis();
            // Compute all
            double allSum = 0;
            int numMovies = graphInfo.getNumRight();
            for(int m=0; m < numMovies; m++) {
                int movieId = vertexIdTranslate.forward(graphInfo.getNumLeft() + m);
                allSum += als.predict(userId, movieId);
            }

            System.out.println(i + " avg pos: " + sumEstimatePos / n + "; avg neg: " + sumEstimateNeg / n + "; all="
                    + allSum / graphInfo.getNumRight() + " (" + (System.currentTimeMillis() - t) + " ms for " + numMovies + " movies");
        }
View Full Code Here
TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., which is owned by Oracle Inc. Contact coftware#gmail.com.