package edu.cmu.graphchi.walks.distributions;
import edu.cmu.graphchi.ChiLogger;
import edu.cmu.graphchi.walks.WalkArray;
import edu.cmu.graphchi.walks.LongWalkArray;
import edu.cmu.graphchi.util.IdCount;
import edu.cmu.graphchi.util.IntegerBuffer;
import java.rmi.RemoteException;
import java.rmi.server.UnicastRemoteObject;
import java.text.NumberFormat;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;
/**
* A DrunkardCompanion object that has two keys to get to a DiscreteDistribution, instead of one.
* Where DrunkardCompanion represents a matrix of values (one key to get to a DiscreteDistribution
* vector), this represents a rank 3 tensor (two keys to get to a DiscreteDistribution). This is
* suitable for collecting more complicated statistics than DrunkardCompanion, though the current
* implementation is perhaps a little slower than it could be, using nested hash maps instead of a
* more efficient data structure.
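*
* <p>Concrete subclasses supply the key and value extractors. A minimal sketch (the bit
* layout here is hypothetical; a real subclass decodes the packed walk long however its
* walk manager encoded it):
*
* <pre>{@code
* class PathCompanion extends TwoKeyCompanion {
*     PathCompanion(int numThreads, long maxMemoryBytes) throws RemoteException {
*         super(numThreads, maxMemoryBytes);
*     }
*     // Hypothetical layout: low 32 bits = source vertex, next 16 bits = path type.
*     protected int getFirstKey(long walk, int atVertex) {
*         return (int) (walk & 0xffffffffL);
*     }
*     protected int getSecondKey(long walk, int atVertex) {
*         return (int) ((walk >>> 32) & 0xffff);
*     }
*     protected int getValue(long walk, int atVertex) {
*         return atVertex;
*     }
*     public void outputDistributions(String outputFile) throws RemoteException {
*         // Walk distributions.get(firstKey).get(secondKey) and write each one out.
*     }
* }
* }</pre>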
*/
public abstract class TwoKeyCompanion extends UnicastRemoteObject
implements RemoteDrunkardCompanion {
protected static class WalkSubmission {
WalkArray walks;
int[] atVertices;
private WalkSubmission(WalkArray walks, int[] atVertices) {
this.walks = walks;
this.atVertices = atVertices;
}
}
/** Initial capacity of each per-key IntegerBuffer. */
protected static final int BUFFER_CAPACITY = 128;
protected static final int BUFFER_MAX = 128;
// Set by the memory audit when usage exceeds maxMemoryBytes; volatile because the audit
// runs on a timer thread.
protected volatile boolean isLowInMemory = false;
// Using hash maps of hash maps isn't the most efficient thing to do here, but it'll do for
// now.
protected ConcurrentHashMap<Integer,
ConcurrentHashMap<Integer, DiscreteDistribution>> distributions;
protected ConcurrentHashMap<Integer, ConcurrentHashMap<Integer, IntegerBuffer>> buffers;
protected ConcurrentHashMap<Integer, ConcurrentHashMap<Integer, Object>> distrLocks;
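// Number of asynchronous buffer-drain tasks still in flight.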
protected AtomicInteger outstanding = new AtomicInteger(0);
protected ExecutorService parallelExecutor;
protected long maxMemoryBytes;
protected LinkedBlockingQueue<WalkSubmission> pendingQueue = new LinkedBlockingQueue<WalkSubmission>();
protected static Logger logger = ChiLogger.getLogger("pathcompanion");
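// Daemon timer that runs the periodic memory audit (scheduled in setSources).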
protected Timer timer = new Timer(true);
/**
* Logs an estimate of memory usage, triggers compaction when the estimate exceeds
* maxMemoryBytes, and returns the estimated total.
*/
private long memoryAuditReport() {
long companionOverHeads = 0;
long bufferMem = 0;
long maxMem = 0;
int bufferCount = 0;
for (ConcurrentHashMap<Integer, IntegerBuffer> map : buffers.values()) {
companionOverHeads += 4;
for(IntegerBuffer buf : map.values()) {
bufferCount += 1;
companionOverHeads += 4;
long est = buf.memorySizeEst();
bufferMem += est;
maxMem = Math.max(maxMem, est);
}
}
long distributionMem = 0;
long maxDistMem = 0;
long avoidMem = 0;
int distCount = 0;
for (ConcurrentHashMap<Integer, DiscreteDistribution> map : distributions.values()) {
companionOverHeads += 4;
for(DiscreteDistribution dist : map.values()) {
distCount += 1;
companionOverHeads += 4;
long est = dist.memorySizeEst();
distributionMem += est;
maxDistMem = Math.max(est, maxDistMem);
avoidMem += dist.avoidCount() * 6;
}
}
NumberFormat nf = NumberFormat.getInstance(Locale.US);
logger.info("======= MEMORY REPORT ======");
logger.info("Companion internal: " + nf.format(companionOverHeads / 1024. / 1024.) + " mb");
logger.info("Buffer mem: " + nf.format(bufferMem / 1024. / 1024.) + " mb");
logger.info("Avg bytes per buffer: " +
nf.format(bufferMem * 1.0 / bufferCount / 1024.) + " kb");
logger.info("Max buffer was: " + nf.format(maxMem / 1024.) + "kb");
logger.info("Distribution mem: " + nf.format(distributionMem / 1024. / 1024.) + " mb");
logger.info("- of which avoids: " + nf.format(avoidMem / 1024. / 1024.) + " mb");
logger.info("Avg bytes per distribution: " +
nf.format((distributionMem * 1.0 / distCount / 1024.)) + " kb");
logger.info("Max distribution: " + nf.format(maxDistMem / 1024.) + " kb");
long totalMem = companionOverHeads + bufferMem + distributionMem;
logger.info("** Total: " + nf.format(totalMem / 1024. / 1024. / 1024.) +
" GB (low-mem limit " +
Runtime.getRuntime().maxMemory() * 0.75 / 1024. / 1024. / 1024. + "GB)" );
isLowInMemory = totalMem > maxMemoryBytes;
if (isLowInMemory) {
compactMemoryUsage();
}
return totalMem;
}
/**
* Prunes low-count tails from every stored distribution to reclaim memory.
*/
private void compactMemoryUsage() {
long before = 0;
long after = 0;
for (Integer firstKey : distributions.keySet()) {
ConcurrentHashMap<Integer, DiscreteDistribution> map = distributions.get(firstKey);
for (Integer secondKey : map.keySet()) {
DiscreteDistribution prevDist, newDist;
synchronized (distrLocks.get(firstKey).get(secondKey)) {
prevDist = map.get(secondKey);
newDist = prevDist.filteredAndShift(2);
map.put(secondKey, newDist);
}
before += prevDist.memorySizeEst();
after += newDist.memorySizeEst();
}
}
logger.info("** Compacted: " + (before / 1024. / 1024. / 1024.) + " GB --> " +
(after / 1024. / 1024. / 1024.) + " GB");
}
/**
* Creates the TwoKeyCompanion object
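*
* <p>A typical instantiation, using the hypothetical PathCompanion subclass sketched in
* the class comment and three quarters of the JVM heap as the memory budget:
*
* <pre>{@code
* long budget = (long) (Runtime.getRuntime().maxMemory() * 0.75);
* TwoKeyCompanion companion = new PathCompanion(4, budget);
* }</pre>
*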
* @param numThreads number of walk-processing threads to start (4 is a common choice)
* @param maxMemoryBytes maximum amount of memory to use for storing the distributions
*/
public TwoKeyCompanion(int numThreads, long maxMemoryBytes) throws RemoteException {
this.maxMemoryBytes = maxMemoryBytes;
// Drain tasks run on a pool sized to the machine; numThreads only controls the
// walk-processing threads started below.
parallelExecutor =
Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
buffers = new ConcurrentHashMap<Integer, ConcurrentHashMap<Integer, IntegerBuffer>>();
distrLocks = new ConcurrentHashMap<Integer, ConcurrentHashMap<Integer, Object>>();
distributions = new ConcurrentHashMap<Integer,
ConcurrentHashMap<Integer, DiscreteDistribution>>();
for(int threadId=0; threadId < numThreads; threadId++) {
Thread processingThread = new Thread(new ProcessingThread(threadId, numThreads));
processingThread.setDaemon(true);
processingThread.start();
}
}
private class ProcessingThread implements Runnable {
private int id;
private int numThreads;
public ProcessingThread(int id, int numThreads) {
this.id = id;
this.numThreads = numThreads;
}
@Override
public void run() {
try {
long unpurgedWalks = 0;
while(true) {
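// Poll with a timeout so idle threads still get a chance to purge their buffers.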
WalkSubmission subm = pendingQueue.poll(2000, TimeUnit.MILLISECONDS);
if (subm != null) {
_processWalks(subm.walks, subm.atVertices);
unpurgedWalks += subm.walks.size();
}
if (distributions != null) {
if (unpurgedWalks > distributions.size() * 10 ||
(subm == null && unpurgedWalks > 100000)) {
logger.fine("Purge:" + unpurgedWalks);
unpurgedWalks = 0;
// Loop to see what to drain. Every thread looks for
// different buffers.
for (Integer firstKey : buffers.keySet()) {
ConcurrentHashMap<Integer, IntegerBuffer> map =
buffers.get(firstKey);
for (Integer secondKey : map.keySet()) {
if ((firstKey + secondKey) % numThreads != id) {
continue;
}
// Drain asynchronously
outstanding.incrementAndGet();
final IntegerBuffer toDrain = map.get(secondKey);
final int first = firstKey;
final int second = secondKey;
synchronized (toDrain) {
map.put(secondKey, new IntegerBuffer(BUFFER_CAPACITY));
}
parallelExecutor.submit(new Runnable() {
@Override
public void run() {
try {
// Synchronize on the retired buffer: a writer that fetched it
// before the swap may still be appending to it.
int[] d;
synchronized (toDrain) {
d = toDrain.toIntArray();
}
Arrays.sort(d);
DiscreteDistribution dist = new DiscreteDistribution(d);
mergeWith(first, second, dist);
} catch (Exception err) {
err.printStackTrace();
} finally {
outstanding.decrementAndGet();
}
}
});
}
}
}
}
}
} catch (Exception err) {
err.printStackTrace();
}
}
}
protected void ensureExists(int firstKey, int secondKey) {
ConcurrentHashMap<Integer, Object> map = distrLocks.get(firstKey);
if (map == null) {
ConcurrentHashMap<Integer, Object> new_map = new ConcurrentHashMap<Integer, Object>();
map = distrLocks.putIfAbsent(firstKey, new_map);
if (map == null) {
map = new_map;
}
}
Object lock = map.get(secondKey);
if (lock == null) {
Object new_lock = new Object();
lock = map.putIfAbsent(secondKey, new_lock);
if (lock == null) {
synchronized(new_lock) {
// Use putIfAbsent for the inner maps too: another thread with the same
// firstKey but a different secondKey may be creating them concurrently,
// and a plain put could clobber its entries.
ConcurrentHashMap<Integer, DiscreteDistribution> dmap =
distributions.get(firstKey);
if (dmap == null) {
ConcurrentHashMap<Integer, DiscreteDistribution> new_dmap =
new ConcurrentHashMap<Integer, DiscreteDistribution>();
dmap = distributions.putIfAbsent(firstKey, new_dmap);
if (dmap == null) {
dmap = new_dmap;
}
}
dmap.put(secondKey, new DiscreteDistribution());
ConcurrentHashMap<Integer, IntegerBuffer> bmap = buffers.get(firstKey);
if (bmap == null) {
ConcurrentHashMap<Integer, IntegerBuffer> new_bmap =
new ConcurrentHashMap<Integer, IntegerBuffer>();
bmap = buffers.putIfAbsent(firstKey, new_bmap);
if (bmap == null) {
bmap = new_bmap;
}
}
bmap.put(secondKey, new IntegerBuffer(BUFFER_CAPACITY));
}
} else {
synchronized(lock) {
// Another thread won the putIfAbsent race and is initializing the
// distribution and buffer; block until it releases the lock so that
// callers can safely fetch them afterwards.
}
}
}
}
private void mergeWith(int firstKey, int secondKey, DiscreteDistribution distr) {
ensureExists(firstKey, secondKey);
synchronized (distrLocks.get(firstKey).get(secondKey)) {
DiscreteDistribution mergeInto = distributions.get(firstKey).get(secondKey);
DiscreteDistribution merged = DiscreteDistribution.merge(mergeInto, distr);
distributions.get(firstKey).put(secondKey, merged);
}
}
@Override
public void setAvoidList(int sourceIdx, int[] avoidList) throws RemoteException {
// We don't need this, so this is a no-op
}
@Override
public IdCount[] getTop(int vertexId, int nTop) throws RemoteException {
// A top list keyed by a single vertex is not meaningful for a two-key companion.
return null;
}
@Override
public void setSources(int[] sources) throws RemoteException {
// We don't use the array of source indices; this call just marks the start of a
// run, so we use it to (re)start the periodic memory audit.
timer.cancel();
timer = new Timer(true);
timer.schedule(new TimerTask() {
@Override
public void run() {
memoryAuditReport();
}
}, 5000, 60000); // first audit after 5 seconds, then every 60 seconds
}
protected void _processWalks(WalkArray walkArray, int[] atVertices) {
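// Each walk is packed into a long; the abstract key/value extractors decode it.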
long[] walks = ((LongWalkArray)walkArray).getArray();
long t1 = System.currentTimeMillis();
for(int i=0; i < walks.length; i++) {
long w = walks[i];
if (ignoreWalk(w)) {
continue;
}
int atVertex = atVertices[i];
int firstKey = getFirstKey(w, atVertex);
int secondKey = getSecondKey(w, atVertex);
int value = getValue(w, atVertex);
ensureExists(firstKey, secondKey);
IntegerBuffer buffer = buffers.get(firstKey).get(secondKey);
synchronized (buffer) {
buffer.add(value);
}
}
long tt = (System.currentTimeMillis() - t1);
if (tt > 1000) {
logger.info("Processing " + walks.length + " took " + tt + " ms.");
}
}
protected boolean ignoreWalk(long walk) {
return walk == 0;
}
/** Extracts the first tensor key from a packed walk. */
protected abstract int getFirstKey(long walk, int atVertex);
/** Extracts the second tensor key from a packed walk. */
protected abstract int getSecondKey(long walk, int atVertex);
/** Extracts the value to record under (firstKey, secondKey). */
protected abstract int getValue(long walk, int atVertex);
protected void drainBuffer(int firstKey, int secondKey) {
IntegerBuffer buffer = buffers.get(firstKey).get(secondKey);
int[] arr;
synchronized (buffer) {
arr = buffer.toIntArray();
buffers.get(firstKey).put(secondKey, new IntegerBuffer(BUFFER_CAPACITY));
}
Arrays.sort(arr);
DiscreteDistribution dist = new DiscreteDistribution(arr);
mergeWith(firstKey, secondKey, dist);
}
@Override
public void processWalks(final WalkArray walks, final int[] atVertices) throws RemoteException {
try {
pendingQueue.put(new WalkSubmission(walks, atVertices));
int pending = pendingQueue.size();
if (pending > 50 && pending % 20 == 0) {
logger.info("Warning, pending queue size: " + pending);
}
} catch (Exception err) {
err.printStackTrace();
}
}
protected void waitForFinish() {
logger.info("Waiting for processing to finish");
// Wait until both the submission queue and the asynchronous drain tasks are done.
while (pendingQueue.size() > 0 || outstanding.get() > 0) {
logger.info("...");
try {
Thread.sleep(500);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
@Override
public abstract void outputDistributions(String outputFile) throws RemoteException;
@Override
public void outputDistributions(String outputFile, int nTop) throws RemoteException {
// nTop is ignored here; subclasses decide how much of each distribution to write.
outputDistributions(outputFile);
}
public void close() {
parallelExecutor.shutdown();
timer.cancel();
clearMemory();
}
protected void clearMemory() {
distributions.clear();
buffers.clear();
distrLocks.clear();
}
}