package edu.cmu.graphchi.walks.distributions;
import edu.cmu.graphchi.ChiLogger;
import edu.cmu.graphchi.walks.WalkArray;
import edu.cmu.graphchi.util.IdCount;
import edu.cmu.graphchi.util.IntegerBuffer;
import java.io.*;
import java.rmi.Naming;
import java.rmi.RemoteException;
import java.rmi.registry.LocateRegistry;
import java.rmi.server.UnicastRemoteObject;
import java.text.NumberFormat;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;
/**
* DrunkardCompanion is a remote (or local) service that receives walks from the DrunkardEngine
* and maintains a distribution of visits from each source.
* Done partially during internship at Twitter, Fall 2012
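*
* <p>A minimal usage sketch (illustrative only: the vertex ids, memory budget and output path
* are made up, exception handling is omitted, and an engine that feeds walks via
* {@code processWalks} is assumed to run elsewhere):
* <pre>{@code
* DrunkardCompanion companion =
*         new IntDrunkardCompanion(4, (long) (Runtime.getRuntime().maxMemory() * 0.75));
* companion.setSources(new int[] {51, 2418, 8737});     // made-up source vertex ids, ascending
* // ... the walk engine repeatedly calls companion.processWalks(walks, atVertices) ...
* IdCount[] top = companion.getTop(8737, 10);           // top-10 visited vertices for source 8737
* companion.outputDistributions("top-visits.bin", 10);  // hypothetical output path
* }</pre>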
* @author Aapo Kyrola, akyrola@cs.cmu.edu
*/
public abstract class DrunkardCompanion extends UnicastRemoteObject implements RemoteDrunkardCompanion {
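/** Simple container pairing a batch of walks with the vertices the walks are currently at. */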
protected static class WalkSubmission {
WalkArray walks;
int[] atVertices;
private WalkSubmission(WalkArray walks, int[] atVertices) {
this.walks = walks;
this.atVertices = atVertices;
}
}
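// Capacity of a fresh per-source visit buffer, and the fill level at which a buffer is
// drained into its source's distribution.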
protected static final int BUFFER_CAPACITY = 128;
protected static final int BUFFER_MAX = 128;
protected int[] sourceVertexIds;
protected Object[] distrLocks;
boolean isLowInMemory = false;
protected DiscreteDistribution[] distributions;
protected IntegerBuffer[] buffers;
protected AtomicInteger outstanding = new AtomicInteger(0);
protected ExecutorService parallelExecutor;
protected long maxMemoryBytes;
protected LinkedBlockingQueue<WalkSubmission> pendingQueue = new LinkedBlockingQueue<WalkSubmission>();
protected static Logger logger = ChiLogger.getLogger("drunkardcompanion");
protected Timer timer = new Timer(true);
private volatile boolean closed = false; // read by the worker threads, set from close()
/**
* Logs an estimate of the current memory usage, triggers compaction of the distributions
* if the estimate exceeds the memory budget, and returns the estimated total in bytes.
*/
private long memoryAuditReport() {
long companionOverHeads = 0;
companionOverHeads += sourceVertexIds.length * 4;
companionOverHeads += distrLocks.length * 4;
long bufferMem = 0;
long maxMem = 0;
for(IntegerBuffer buf : buffers) {
long est = buf.memorySizeEst();
bufferMem += est;
maxMem = Math.max(maxMem, est);
}
long distributionMem = 0;
long maxDistMem = 0;
long avoidMem = 0;
for(DiscreteDistribution dist : distributions) {
long est = dist.memorySizeEst();
distributionMem += est;
maxDistMem = Math.max(est, maxDistMem);
avoidMem += dist.avoidCount() * 6;
}
NumberFormat nf = NumberFormat.getInstance(Locale.US);
logger.info("======= MEMORY REPORT ======");
logger.info("Companion internal: " + nf.format(companionOverHeads / 1024. / 1024.) + " mb");
logger.info("Buffer mem: " + nf.format(bufferMem / 1024. / 1024.) + " mb");
logger.info("Avg bytes per buffer: " + nf.format(bufferMem * 1.0 / buffers.length / 1024.) + " kb");
logger.info("Max buffer was: " + nf.format(maxMem / 1024.) + "kb");
logger.info("Distribution mem: " + nf.format(distributionMem / 1024. / 1024.) + " mb");
logger.info("- of which avoids: " + nf.format(avoidMem / 1024. / 1024.) + " mb");
logger.info("Avg bytes per distribution: " + nf.format((distributionMem * 1.0 / distributions.length / 1024.)) + " kb");
logger.info("Max distribution: " + nf.format(maxDistMem / 1024.) + " kb");
long totalMem = companionOverHeads + bufferMem + distributionMem;
logger.info("** Total: " + nf.format(totalMem / 1024. / 1024. / 1024.) + " GB (low-mem limit " + Runtime.getRuntime().maxMemory() * 0.25 / 1024. / 1024. / 1024. + "GB)" );
isLowInMemory = totalMem > maxMemoryBytes;
if (isLowInMemory) {
compactMemoryUsage();
}
return totalMem;
}
/**
* Removes tails from distributions to save memory
*/
private void compactMemoryUsage() {
long before=0;
long after=0;
for(int i=0; i < distributions.length; i++) {
DiscreteDistribution prevDist, newDist;
synchronized (distrLocks[i]) {
prevDist = distributions[i];
newDist = prevDist.filteredAndShift((short) 2);
distributions[i] = newDist;
}
before += prevDist.memorySizeEst();
after += newDist.memorySizeEst();
}
logger.info("** Compacted: " + (before / 1024. / 1024. / 1024.) + " GB --> " + (after / 1024. / 1024. / 1024.) + " GB");
}
/**
* Creates the DrunkardCompanion object
* @param numThreads number of worker threads (4 is common)
* @param maxMemoryBytes maximum amount of memory to use for storing the distributions
* @throws RemoteException
*/
public DrunkardCompanion(final int numThreads, final long maxMemoryBytes) throws RemoteException {
this.maxMemoryBytes = maxMemoryBytes;
parallelExecutor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
for(int threadId=0; threadId < numThreads; threadId++) {
final int _threadId = threadId;
Thread processingThread = new Thread(new Runnable() {
@Override
public void run() {
try {
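// Walks processed since the last drain pass over the per-source buffers.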
long unpurgedWalks = 0;
while(!closed) {
WalkSubmission subm = pendingQueue.poll(2000, TimeUnit.MILLISECONDS);
if (subm != null) {
_processWalks(subm.walks, subm.atVertices);
unpurgedWalks += subm.walks.size();
}
if (sourceVertexIds != null) {
if (unpurgedWalks > sourceVertexIds.length * 10 || (subm == null && unpurgedWalks > 100000)) {
logger.fine("Purge:" + unpurgedWalks);
unpurgedWalks = 0;
// Loop to see what to drain. Every thread looks for
// different buffers.
for(int i=_threadId; i < sourceVertexIds.length; i+=numThreads) {
if (buffers[i].size() >= BUFFER_MAX || closed) {
// Drain asynchronously
outstanding.incrementAndGet();
final IntegerBuffer toDrain = buffers[i];
final int drainIdx = i;
synchronized (buffers[i]) {
buffers[i] = new IntegerBuffer(BUFFER_CAPACITY);
}
parallelExecutor.submit(new Runnable() { public void run() {
try {
int[] d = toDrain.toIntArray();
Arrays.sort(d);
DiscreteDistribution dist = new DiscreteDistribution(d);
mergeWith(drainIdx, dist);
} catch (Exception err ) {
err.printStackTrace();
} finally {
outstanding.decrementAndGet();
}
}});
}
}
}
}
}
} catch (Exception err) {
err.printStackTrace();
}
}
});
processingThread.setDaemon(true);
processingThread.start();
}
}
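/**
* Merges the given distribution into the source's current distribution,
* holding the per-source lock while the two are combined.
*/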
private void mergeWith(int sourceIdx, DiscreteDistribution distr) {
synchronized (distrLocks[sourceIdx]) {
distributions[sourceIdx] = DiscreteDistribution.merge(distributions[sourceIdx], distr);
/* if (pruneFraction > 0.0 && isLowInMemory) {
int sz = distributions[sourceIdx].sizeExcludingAvoids();
if (sz > 200) {
int mx = distributions[sourceIdx].max();
int pruneLimit = 2 + (int) (mx * pruneFraction);
DiscreteDistribution filtered = distributions[sourceIdx].filteredAndShift((short)pruneLimit);
if (filtered.sizeExcludingAvoids() > 25) { // ad-hoc...
distributions[sourceIdx] = filtered;
int prunedSize = distributions[sourceIdx].sizeExcludingAvoids();
if (sourceIdx % 10000 == 0) {
logger.info("Pruned: " + sz + " => " + prunedSize + " max: " + mx + ", limit=" + pruneLimit);
}
} else {
// logger.info("Filtering would have deleted almost everything...");
// Try pruning ones
filtered = distributions[sourceIdx].filteredAndShift((short)2);
if (filtered.sizeExcludingAvoids() > 25) {
distributions[sourceIdx] = filtered;
} else {
distributions[sourceIdx] = distributions[sourceIdx].filteredAndShift((short)1);
}
}
}
} */
}
}
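/**
* Registers vertices that should be excluded from the results of the given source;
* they are merged in as an avoidance distribution.
*/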
@Override
public void setAvoidList(int sourceIdx, int[] avoidList) throws RemoteException {
Arrays.sort(avoidList);
DiscreteDistribution avoidDistr = DiscreteDistribution.createAvoidanceDistribution(avoidList);
mergeWith(sourceIdx, avoidDistr);
}
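/**
* Initializes the tracked source vertices and allocates a buffer, a lock and an initial
* distribution (avoiding the source itself) for each of them. The sources are expected in
* ascending order, since lookups use binary search. Also (re)starts the periodic memory audit.
*/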
@Override
public void setSources(int[] sources) throws RemoteException {
// Restart timer
timer.cancel();
timer = new Timer(true);
logger.info("Initializing sources...");
buffers = new IntegerBuffer[sources.length];
sourceVertexIds = new int[sources.length];
distrLocks = new Object[sources.length];
distributions = new DiscreteDistribution[sources.length];
for(int i=0; i < sources.length; i++) {
distrLocks[i] = new Object();
sourceVertexIds[i] = sources[i];
buffers[i] = new IntegerBuffer(BUFFER_CAPACITY);
distributions[i] = DiscreteDistribution.createAvoidanceDistribution(new int[]{sources[i]}); // Add the vertex itself to avoids
}
logger.info("Done...");
timer.schedule(new TimerTask() {
@Override
public void run() {
memoryAuditReport();
}
}, 5000, 60000);
}
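/**
* Subclass hook that records a batch of walks, one visit per walk at the vertex the walk
* is currently at, into the per-source buffers.
*/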
protected abstract void _processWalks(WalkArray walkArray, int[] atVertices);
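/**
* Returns the most visited vertices for the given source vertex, draining the source's
* buffer first so that pending visits are included.
*/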
@Override
public IdCount[] getTop(int vertexId, int nTop) throws RemoteException {
int sourceIdx = (sourceVertexIds == null ? -1 : Arrays.binarySearch(sourceVertexIds, vertexId));
if (sourceIdx >= 0) {
drainBuffer(sourceIdx);
return distributions[sourceIdx].getTop(nTop);
} else {
throw new IllegalArgumentException("Vertex not found from memory. ");
}
}
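/**
* Converts the source's visit buffer into a distribution, merges it into the source's
* current distribution and installs a fresh, empty buffer.
*/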
protected void drainBuffer(int sourceIdx) {
synchronized (buffers[sourceIdx]) {
int[] arr = buffers[sourceIdx].toIntArray();
buffers[sourceIdx] = new IntegerBuffer(BUFFER_CAPACITY);
Arrays.sort(arr);
DiscreteDistribution dist = new DiscreteDistribution(arr);
mergeWith(sourceIdx, dist);
}
}
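/**
* Queues a batch of walks, and the vertices they are currently at, for asynchronous
* processing by the worker threads.
*/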
@Override
public void processWalks(final WalkArray walks, final int[] atVertices) throws RemoteException {
try {
pendingQueue.put(new WalkSubmission(walks, atVertices));
int pending = pendingQueue.size();
if (pending > 50 && pending % 20 == 0) {
logger.info("Warning, pending queue size: " + pending);
}
} catch (Exception err) {
err.printStackTrace();
}
}
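/**
* Writes the top-10 visit counts of every source to a binary file.
* @see #outputDistributions(String, int)
*/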
public void outputDistributions(String outputFile) throws RemoteException {
outputDistributions(outputFile, 10);
}
/**
* Waits for outstanding processing to finish, then writes the top visit counts of each
* source vertex to a binary file. Each record is the source vertex id (int) followed by
* nTop (vertex id, count) int pairs, padded with (-1, -1) pairs when fewer than nTop
* vertices have been visited.
* @param outputFile path of the output file
* @param nTop number of top visited vertices to write per source
*/
public void outputDistributions(String outputFile, int nTop) throws RemoteException {
logger.info("Waiting for processing to finish");
while(outstanding.get() > 0) {
logger.info("...");
try {
Thread.sleep(500);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
logger.info("Write output...");
try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(
new File(outputFile))))) {
for(int i=0; i<sourceVertexIds.length; i++) {
int sourceVertex = sourceVertexIds[i];
drainBuffer(i);
DiscreteDistribution distr = distributions[i];
IdCount[] topVertices = distr.getTop(nTop);
dos.writeInt(sourceVertex);
int written = 0;
for(IdCount vc : topVertices) {
dos.writeInt(vc.id);
dos.writeInt(vc.count);
written++;
}
while(written < nTop) {
written++;
dos.writeInt(-1);
dos.writeInt(-1);
}
}
} catch (Exception err) {
err.printStackTrace();
}
}
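/* A minimal sketch (not part of this class; the file name is a placeholder and nTop must
match the value used when writing) of how the fixed-width records written by
outputDistributions() could be read back:

DataInputStream dis = new DataInputStream(new BufferedInputStream(
        new FileInputStream("top-visits.bin")));
while (dis.available() > 0) {
    int sourceVertex = dis.readInt();
    for (int j = 0; j < nTop; j++) {
        int id = dis.readInt();
        int count = dis.readInt();
        if (id >= 0) {
            // (sourceVertex, id, count) is one top-visit entry; (-1, -1) pairs are padding
        }
    }
}
dis.close();
*/
/**
* Signals the worker threads to stop and cancels the memory-audit timer.
*/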
public void close() {
closed = true;
timer.cancel();
}
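/**
* Standalone launcher, mainly for testing: args[0] is a prune fraction (currently only
* logged) and args[1] is the RMI bind address. Creates an RMI registry on port 1099 if one
* does not already exist and binds an IntDrunkardCompanion using 75% of the maximum heap.
*/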
public static void main(String[] args) throws Exception {
double pruneFraction = Double.parseDouble(args[0]);
String bindAddress = args[1];
try {
LocateRegistry.createRegistry(1099);
} catch (Exception err) {
logger.info("Registry already created?");
}
// TODO? Not sure what the main class is used for; just for testing? This may need to be
// put into the subclass.
Naming.rebind(bindAddress, new IntDrunkardCompanion(4, (long) (Runtime.getRuntime().maxMemory() * 0.75)));
logger.info("Prune fraction: " + pruneFraction);
}
}