package com.twitter.pers.bipartite;
import edu.cmu.graphchi.ChiLogger;
import edu.cmu.graphchi.ChiVertex;
import edu.cmu.graphchi.GraphChiContext;
import edu.cmu.graphchi.GraphChiProgram;
import edu.cmu.graphchi.vertexdata.ForeachCallback;
import edu.cmu.graphchi.vertexdata.VertexAggregator;
import edu.cmu.graphchi.datablocks.FloatConverter;
import edu.cmu.graphchi.datablocks.FloatPair;
import edu.cmu.graphchi.datablocks.FloatPairConverter;
import edu.cmu.graphchi.engine.GraphChiEngine;
import edu.cmu.graphchi.engine.VertexInterval;
import edu.cmu.graphchi.hadoop.PigGraphChiBase;
import edu.cmu.graphchi.preprocessing.EdgeProcessor;
import edu.cmu.graphchi.preprocessing.FastSharder;
import edu.cmu.graphchi.preprocessing.VertexProcessor;
import edu.cmu.graphchi.util.IdFloat;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.data.Tuple;
import org.apache.pig.data.TupleFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.logging.Logger;
/**
* Version of SALSA that uses just a little memory (values propagated
* via edges), and can be run under Pig.
*
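* On each iteration, either the left side or the right side is computed:
* even iterations compute the right side and odd iterations the left side.
* Each vertex can represent both sides. The left side uses out-edges,
* the right side in-edges.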
* Left side = authorities (users)
* Right side = hubs
*
* The algorithm starts with the right side, and the edges have initial
* values for the left side vertices (authorities).
*
* @author Aapo Kyrola, akyrola@cs.cmu.edu
* @copyright Twitter (done during internship, Fall 2012)
*/
public class SALSASmallMem extends PigGraphChiBase implements GraphChiProgram<FloatPair, Float> {

    /** Side processed on even iterations (iteration % 2 == 0); the computation starts here. */
    private final static int RIGHTSIDE = 0;
    /** Side processed on odd iterations (iteration % 2 == 1). */
    private final static int LEFTSIDE = 1;

    /** Number of engine iterations used by {@link #run(String, int)}. */
    private final static int DEFAULT_NUM_ITERATIONS = 8;
    /** Value passed to {@code GraphChiEngine.setMaxWindow} when running the engine. */
    private final static int MAX_WINDOW = 20000000;

    /** Name of the graph used by the most recent {@link #run(String, int, int)} call. */
    private String graphName;
    private final static Logger logger = ChiLogger.getLogger("salsa-smallmem");

    /** Shard count reported to Pig; overwritten by {@link #createSharder(String, int)}. */
    int numShards = 20;
    /** Engine of the current run; {@link #runGraphChi()} sets it to null once results are collected. */
    GraphChiEngine<FloatPair, Float> engine;

    public SALSASmallMem() {
        super();
    }

    /**
     * Performs one half-step of SALSA for a single vertex. The iteration parity selects the side:
     * even iterations process the right side (in-edges), odd iterations the left side (out-edges).
     * Vertices without any edges are not touched.
     *
     * Left side (vertex has out-edges): sums the out-edge values, stores the sum as the first
     * component of the vertex value, and writes sum / out-degree to every out-edge.
     *
     * Right side (vertex has in-edges): sums the in-edge values, scales the sum by
     * in-degree / totalEdges, and writes the scaled value divided by the in-degree to every in-edge.
     * Here totalEdges is the second component of the vertex value; when it is zero, a warning is
     * logged and the in-degree is used instead. The vertex value itself is left unchanged on this
     * side.
     *
     * @param vertex  the vertex being updated
     * @param context engine context, used for the iteration number and id translation
     */
    @Override
    public void update(ChiVertex<FloatPair, Float> vertex, GraphChiContext context) {
        int side = context.getIteration() % 2;
        if (vertex.numEdges() == 0) {
            return;
        }

        // Gather: sum the values on the edges belonging to the side processed this iteration.
        float nbrSum = 0.0f;
        if (side == LEFTSIDE) {
            for (int i = 0; i < vertex.numOutEdges(); i++) {
                nbrSum += vertex.outEdge(i).getValue();
            }
        } else {
            for (int i = 0; i < vertex.numInEdges(); i++) {
                nbrSum += vertex.inEdge(i).getValue();
            }
        }

        float newValue = nbrSum;
        FloatPair curValue = vertex.getValue();
        if (side == LEFTSIDE && vertex.numOutEdges() > 0) {
            curValue = new FloatPair(newValue, curValue.second);
            // Scatter: split the new value evenly over the out-edges.
            float broadcastValue = newValue / vertex.numOutEdges();
            for (int i = 0; i < vertex.numOutEdges(); i++) {
                vertex.outEdge(i).setValue(broadcastValue);
            }
        } else if (side == RIGHTSIDE && vertex.numInEdges() > 0) {
            // Renormalization: the second component holds the vertex's total edge count,
            // which may exceed the number of in-edges actually present in this graph.
            int numRelevantEdges = vertex.numInEdges();
            int totalEdges = (int) curValue.second;
            if (totalEdges == 0) {
                logger.warning("Normalization factor cannot be zero! Id:" + context.getVertexIdTranslate().backward(vertex.getId()));
                totalEdges = numRelevantEdges;
            }
            newValue *= numRelevantEdges * 1.0f / (float) totalEdges;
            // Scatter: split the renormalized value evenly over the in-edges.
            float broadcastValue = newValue / vertex.numInEdges();
            for (int i = 0; i < vertex.numInEdges(); i++) {
                vertex.inEdge(i).setValue(broadcastValue);
            }
        }
        vertex.setValue(curValue);
    }

    @Override
    public void beginIteration(GraphChiContext ctx) {
    }

    @Override
    public void beginInterval(GraphChiContext ctx, VertexInterval interval) {
    }

    @Override
    public void endInterval(GraphChiContext ctx, VertexInterval interval) {
    }

    @Override
    public void endIteration(GraphChiContext ctx) {
    }

    @Override
    public void beginSubInterval(GraphChiContext ctx, VertexInterval interval) {
    }

    @Override
    public void endSubInterval(GraphChiContext ctx, VertexInterval interval) {
    }

    /**
     * Runs SALSA on an already-sharded graph for the default number of iterations (8).
     *
     * @param graphName base name of the sharded graph
     * @param numShards number of shards the graph was split into
     * @throws Exception if the engine fails
     */
    public void run(String graphName, int numShards) throws Exception {
        run(graphName, numShards, DEFAULT_NUM_ITERATIONS);
    }

    /**
     * Runs SALSA on an already-sharded graph.
     *
     * Iterations alternate between the right side (even iterations) and the left side (odd
     * iterations). An even count therefore ends on a left-side update.
     *
     * @param graphName     base name of the sharded graph
     * @param numShards     number of shards the graph was split into
     * @param numIterations number of engine iterations; must be positive
     * @throws IllegalArgumentException if {@code numIterations} is not positive
     * @throws Exception if the engine fails
     */
    public void run(String graphName, int numShards, int numIterations) throws Exception {
        if (numIterations <= 0) {
            throw new IllegalArgumentException("numIterations must be positive, got " + numIterations);
        }
        this.graphName = graphName;
        engine = new GraphChiEngine<FloatPair, Float>(graphName, numShards);
        engine.setEnableScheduler(false);
        engine.setSkipZeroDegreeVertices(true);
        engine.setEdataConverter(new FloatConverter());
        engine.setVertexDataConverter(new FloatPairConverter());
        engine.setMaxWindow(MAX_WINDOW);
        engine.run(this, numIterations);
    }

    /**
     * Prints one line per vertex with a positive score to standard output. Each line contains
     * the original vertex id, a tab, and the first component of the vertex value.
     *
     * @param graphName base name of the graph whose vertex data is read
     * @throws IOException if the vertex data cannot be read
     */
    private void outputResults(String graphName) throws IOException {
        VertexAggregator.foreach(engine.numVertices(), graphName, new FloatPairConverter(), new ForeachCallback<FloatPair>() {
            @Override
            public void callback(int vertexId, FloatPair vertexValue) {
                if (vertexValue.first > 0) {
                    System.out.println(engine.getVertexIdTranslate().backward(vertexId) + "\t" + vertexValue.first);
                }
            }
        });
    }

    /**
     * Command-line entry point.
     *
     * @param args either {@code [graph-name, num-shards]} to run on an existing sharded graph, or
     *             {@code [num-shards]} to shard data read from standard input under the graph name
     *             "pipein" and then run on it
     * @throws IllegalArgumentException if the number of arguments is not 1 or 2
     * @throws Exception if sharding or the computation fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1 || args.length > 2) {
            throw new IllegalArgumentException(
                    "Usage: SALSASmallMem [graph-name] num-shards (got " + args.length + " arguments)");
        }
        String graphName = (args.length == 2) ? args[0] : null;
        int nShards = Integer.parseInt(args[args.length - 1]);

        SALSASmallMem hits = new SALSASmallMem();
        if (graphName == null) {
            graphName = "pipein";
            FastSharder sharder = hits.createSharder(graphName, nShards);
            sharder.shard(System.in);
        }
        hits.run(graphName, nShards);
        hits.outputResults(graphName);
    }

    // PIG support

    /** Output schema: the score first, then the original vertex id (matches getNextResult). */
    @Override
    protected String getSchemaString() {
        return "(weight:float, vertex:int)";
    }

    @Override
    protected int getNumShards() {
        return numShards;
    }

    /** Buffered (vertex id, score) results produced by runGraphChi. */
    private ArrayList<IdFloat> results;
    /** Cursor over {@link #results}; consumed by getNextResult. */
    private Iterator<IdFloat> resultIter;

    /**
     * Pig entry point: runs SALSA on the graph named by {@code getGraphName()}. It then buffers
     * every vertex with a positive score in memory, where getNextResult can stream them.
     *
     * @throws Exception if the computation or result collection fails
     */
    @Override
    protected void runGraphChi() throws Exception {
        run(getGraphName(), getNumShards());
        results = new ArrayList<IdFloat>(100000);

        // Collect results - into memory ... This may consume a lot of memory.
        // It would be better to have an iterator for the vertex data.
        VertexAggregator.foreach(engine.numVertices(), graphName, new FloatPairConverter(), new ForeachCallback<FloatPair>() {
            @Override
            public void callback(int vertexId, FloatPair vertexValue) {
                if (vertexValue.first > 0) {
                    results.add(new IdFloat(engine.getVertexIdTranslate().backward(vertexId), vertexValue.first));
                }
            }
        });
        engine = null;
        resultIter = results.iterator();
    }

    /**
     * Creates the sharder used to preprocess the input graph.
     *
     * Each vertex-value token is parsed as a float and stored as the second component of the
     * vertex value, with the first component set to 0. update() reads that second component as
     * the vertex's total edge count. Each edge token is parsed as the edge's initial float value.
     *
     * @param graphName base name of the graph to create
     * @param numShards number of shards; also remembered for getNumShards
     * @return a sharder configured for this program's vertex and edge types
     * @throws IOException if the sharder cannot be created
     */
    @Override
    protected FastSharder createSharder(String graphName, int numShards) throws IOException {
        this.numShards = numShards;
        return new FastSharder<FloatPair, Float>(graphName, numShards, new VertexProcessor<FloatPair>() {
            /* For lists (hubs), the vertex value will encode the total number of edges */
            @Override
            public FloatPair receiveVertexValue(int vertexId, String token) {
                return new FloatPair(0.0f, Float.parseFloat(token));
            }
        }, new EdgeProcessor<Float>() {
            @Override
            public Float receiveEdge(int from, int to, String token) {
                return Float.parseFloat(token);
            }
        }, new FloatPairConverter(), new FloatConverter());
    }

    /**
     * Returns the next buffered result as a (score, vertex id) tuple, or null when exhausted.
     * Expects {@link #runGraphChi()} to have been called first.
     */
    @Override
    protected Tuple getNextResult(TupleFactory tupleFactory) throws ExecException {
        if (resultIter.hasNext()) {
            IdFloat res = resultIter.next();
            Tuple t = tupleFactory.newTuple(2);
            t.set(0, res.getValue());
            t.set(1, res.getVertexId());
            return t;
        } else {
            return null;
        }
    }

    /**
     * Describes progress as "<interval> iteration: i/n". Returns "Initializing" when no engine
     * or context is available.
     */
    @Override
    protected String getStatusString() {
        // runGraphChi() resets the engine field to null. Read it exactly once so the null check
        // and the dereference use the same reference. NOTE(review): presumably this method is
        // polled by a status thread while runGraphChi() runs — confirm in PigGraphChiBase.
        GraphChiEngine<FloatPair, Float> currentEngine = engine;
        if (currentEngine != null) {
            GraphChiContext ctx = currentEngine.getContext();
            if (ctx != null) {
                return ctx.getCurInterval() + " iteration: " + ctx.getIteration() + "/" + ctx.getNumIterations();
            }
        }
        return "Initializing";
    }
}