Package edu.cmu.graphchi.walks.distributions

Source Code of edu.cmu.graphchi.walks.distributions.DrunkardCompanion$WalkSubmission

package edu.cmu.graphchi.walks.distributions;

import edu.cmu.graphchi.ChiLogger;
import edu.cmu.graphchi.walks.WalkArray;
import edu.cmu.graphchi.util.IdCount;
import edu.cmu.graphchi.util.IntegerBuffer;

import java.io.*;
import java.rmi.Naming;
import java.rmi.RemoteException;
import java.rmi.registry.LocateRegistry;
import java.rmi.server.UnicastRemoteObject;
import java.text.NumberFormat;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;

/**
* DrunkardCompanion is a remote (or local) service that receives walks from the DrunkardEngine
* and maintains a distribution of visits from each source.
* Done partially during internship at Twitter, Fall 2012
* @author Aapo Kyrola, akyrola@cs.cmu.edu
*/
public abstract class DrunkardCompanion extends UnicastRemoteObject implements RemoteDrunkardCompanion {

    protected static class WalkSubmission {
        WalkArray walks;
        int[] atVertices;

        private WalkSubmission(WalkArray walks, int[] atVertices) {
            this.walks = walks;
            this.atVertices = atVertices;
        }
    }

    protected static final int BUFFER_CAPACITY = 128;
    protected static final int BUFFER_MAX = 128;

    protected int[] sourceVertexIds;
    protected Object[] distrLocks;
    boolean isLowInMemory = false;

    protected DiscreteDistribution[] distributions;
    protected IntegerBuffer[] buffers;
    protected AtomicInteger outstanding = new AtomicInteger(0);

    protected ExecutorService parallelExecutor;
    protected long maxMemoryBytes;

    protected LinkedBlockingQueue<WalkSubmission> pendingQueue = new LinkedBlockingQueue<WalkSubmission>();

    protected static Logger logger = ChiLogger.getLogger("drunkardcompanion");
    protected Timer timer  = new Timer(true);

    private boolean closed = false;

    /**
     * Prints estimate of memory usage
     */
    private long memoryAuditReport() {
        long companionOverHeads = 0;
        companionOverHeads += sourceVertexIds.length * 4;
        companionOverHeads += distrLocks.length * 4;


        long bufferMem = 0;
        long maxMem = 0;
        for(IntegerBuffer buf : buffers) {
            long est = buf.memorySizeEst();
            bufferMem += est;
            maxMem = Math.max(maxMem, est);
        }

        long distributionMem = 0;
        long maxDistMem = 0;
        long avoidMem = 0;
        for(DiscreteDistribution dist : distributions) {
            long est = dist.memorySizeEst();
            distributionMem += est;
            maxDistMem = Math.max(est, maxDistMem);
            avoidMem += dist.avoidCount() * 6;
        }

        NumberFormat nf = NumberFormat.getInstance(Locale.US);

        logger.info("======= MEMORY REPORT ======");
        logger.info("Companion internal: " + nf.format(companionOverHeads / 1024. / 1024.) + " mb");

        logger.info("Buffer mem: " + nf.format(bufferMem / 1024. / 1024.) + " mb");
        logger.info("Avg bytes per buffer: " + nf.format(bufferMem * 1.0 / buffers.length / 1024.) + " kb");
        logger.info("Max buffer was: " + nf.format(maxMem / 1024.) + "kb");

        logger.info("Distribution mem: " + nf.format(distributionMem / 1024. / 1024.) + " mb");
        logger.info("- of which avoids: " + nf.format(avoidMem / 1024. / 1024.) + " mb");

        logger.info("Avg bytes per distribution: " + nf.format((distributionMem * 1.0 / distributions.length / 1024.)) + " kb");
        logger.info("Max distribution: " + nf.format(maxDistMem / 1024.) + " kb");

        long totalMem = companionOverHeads + bufferMem + distributionMem;
        logger.info("** Total:  " + nf.format(totalMem / 1024. / 1024. / 1024.) + " GB (low-mem limit " + Runtime.getRuntime().maxMemory() * 0.25 / 1024. / 1024. / 1024. + "GB)" );
        isLowInMemory = totalMem > maxMemoryBytes;

        if (isLowInMemory) {
            compactMemoryUsage();
        }

        return totalMem;
    }

    /**
     * Removes tails from distributions to save memory
     */
    private void compactMemoryUsage() {
        long before=0;
        long after=0;

        for(int i=0; i < distributions.length; i++) {
            DiscreteDistribution prevDist, newDist;
            synchronized (distrLocks[i]) {

                prevDist = distributions[i];
                newDist =  prevDist.filteredAndShift(2);
                distributions[i] = newDist;
            }
            before += prevDist.memorySizeEst();
            after += newDist.memorySizeEst();
        }

        logger.info("** Compacted: " + (before / 1024. / 1024. / 1024.) + " GB --> " + (after / 1024. / 1024. / 1024.) + " GB");
    }


    /**
     * Creates the DrunkardCompanion object
     * @param numThreads number of worker threads (4 is common)
     * @param maxMemoryBytes maximum amount of memory to use for storing the distributions
     * @throws RemoteException
     */
    public DrunkardCompanion( final int numThreads, final long maxMemoryBytes) throws RemoteException {
        this.maxMemoryBytes = maxMemoryBytes;
        parallelExecutor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());


        for(int threadId=0; threadId < numThreads; threadId++) {
            final int _threadId = threadId;
            Thread processingThread = new Thread(new Runnable() {
                @Override
                public void run() {
                    try {

                        long unpurgedWalks = 0;
                        while(!closed) {

                            WalkSubmission subm = pendingQueue.poll(2000, TimeUnit.MILLISECONDS);
                            if (subm != null) {
                                _processWalks(subm.walks, subm.atVertices);
                                unpurgedWalks += subm.walks.size();
                            }
                            if (sourceVertexIds != null) {
                                if (unpurgedWalks > sourceVertexIds.length * 10 || (subm == null && unpurgedWalks > 100000)) {
                                    logger.fine("Purge:" + unpurgedWalks);
                                    unpurgedWalks = 0;

                                    // Loop to see what to drain. Every thread looks for
                                    // different buffers.
                                    for(int i=_threadId; i < sourceVertexIds.length; i+=numThreads) {
                                        if (buffers[i].size() >= BUFFER_MAX || closed) {
                                            // Drain asynchronously
                                            outstanding.incrementAndGet();
                                            final IntegerBuffer toDrain = buffers[i];
                                            final int drainIdx = i;

                                            synchronized (buffers[i]) {
                                                buffers[i] = new IntegerBuffer(BUFFER_CAPACITY);
                                            }
                                            parallelExecutor.submit(new Runnable() { public void run() {
                                                try {
                                                    int[] d = toDrain.toIntArray();
                                                    Arrays.sort(d);
                                                    DiscreteDistribution dist = new DiscreteDistribution(d);
                                                    mergeWith(drainIdx, dist);
                                                } catch (Exception err ) {
                                                    err.printStackTrace();
                                                } finally {
                                                    outstanding.decrementAndGet();
                                                }
                                            }});
                                        }
                                    }
                                }
                            }

                        }
                    } catch (Exception err) {
                        err.printStackTrace();
                    }
                }
            });
            processingThread.setDaemon(true);
            processingThread.start();
        }
    }

    private void mergeWith(int sourceIdx, DiscreteDistribution distr) {
        synchronized (distrLocks[sourceIdx]) {
            distributions[sourceIdx] = DiscreteDistribution.merge(distributions[sourceIdx], distr);

    /*        if (pruneFraction > 0.0 && isLowInMemory) {
                int sz = distributions[sourceIdx].sizeExcludingAvoids();
                if (sz > 200) {
                    int mx = distributions[sourceIdx].max();
                    int pruneLimit = 2 + (int) (mx * pruneFraction);
                    DiscreteDistribution filtered =  distributions[sourceIdx].filteredAndShift((short)pruneLimit);
                    if (filtered.sizeExcludingAvoids() > 25) { // ad-hoc...
                        distributions[sourceIdx] = filtered;
                        int prunedSize = distributions[sourceIdx].sizeExcludingAvoids();
                        if (sourceIdx % 10000 == 0) {
                            logger.info("Pruned: " + sz + " => " + prunedSize + " max: " + mx + ", limit=" + pruneLimit);
                        }
                    } else {
                        //  logger.info("Filtering would have deleted almost everything...");
                        // Try pruning ones
                        filtered = distributions[sourceIdx].filteredAndShift((short)2);
                        if (filtered.sizeExcludingAvoids() > 25) {
                            distributions[sourceIdx] = filtered;
                        }  else {
                            distributions[sourceIdx] = distributions[sourceIdx].filteredAndShift((short)1);
                        }
                    }
                }
            }               */
        }
    }

    @Override
    public void setAvoidList(int sourceIdx, int[] avoidList) throws RemoteException {
        Arrays.sort(avoidList);
        DiscreteDistribution avoidDistr = DiscreteDistribution.createAvoidanceDistribution(avoidList);
        mergeWith(sourceIdx, avoidDistr);
    }

    @Override
    public void setSources(int[] sources) throws RemoteException {
        // Restart timer
        timer.cancel();
        timer = new Timer(true);

        logger.info("Initializing sources...");
        buffers = new IntegerBuffer[sources.length];
        sourceVertexIds = new int[sources.length];
        distrLocks = new Object[sources.length];
        distributions = new DiscreteDistribution[sources.length];
        for(int i=0; i < sources.length; i++) {
            distrLocks[i] = new Object();
            sourceVertexIds[i] = sources[i];
            buffers[i] = new IntegerBuffer(BUFFER_CAPACITY);
            distributions[i] = DiscreteDistribution.createAvoidanceDistribution(new int[]{sources[i]}); // Add the vertex itself to avoids
        }
        logger.info("Done...");

        timer.schedule(new TimerTask() {
            @Override
            public void run() {
                memoryAuditReport();
            }
        }, 5000, 60000);
    }

    /**
     * Handles one queued submission. Called by the worker threads started in the
     * constructor with the two values that were passed to {@code processWalks}.
     *
     * @param walkArray  the submitted walks
     * @param atVertices the array submitted together with the walks
     */
    protected abstract void _processWalks(WalkArray walkArray, int[] atVertices);

    @Override
    public IdCount[] getTop(int vertexId, int nTop) throws RemoteException {
        int sourceIdx = (sourceVertexIds == null ? -1 : Arrays.binarySearch(sourceVertexIds, vertexId));
        if (sourceIdx >= 0) {
            int[] arr = buffers[sourceIdx].toIntArray();
            drainBuffer(sourceIdx);
            return distributions[sourceIdx].getTop(nTop);
        } else {
           throw new IllegalArgumentException("Vertex not found from memory. ");
        }
    }


    protected void drainBuffer(int sourceIdx) {
        synchronized (buffers[sourceIdx]) {
            int[] arr = buffers[sourceIdx].toIntArray();
            buffers[sourceIdx] = new IntegerBuffer(BUFFER_CAPACITY);
            Arrays.sort(arr);
            DiscreteDistribution dist = new DiscreteDistribution(arr);
            mergeWith(sourceIdx, dist);
        }
    }

    @Override
    public void processWalks(final WalkArray walks, final int[] atVertices) throws RemoteException {
        try {
            pendingQueue.put(new WalkSubmission(walks, atVertices));
            int pending = pendingQueue.size();
            if (pending > 50 && pending % 20 == 0) {
                logger.info("Warning, pending queue size: " + pending);
            }
        } catch (Exception err) {
            err.printStackTrace();
        }
    }

    public void outputDistributions(String outputFile) throws RemoteException {
        outputDistributions(outputFile, 10);
    }

    /*
      Writes the top visit counts to a binary file.
     */
    public void outputDistributions(String outputFile, int nTop) throws RemoteException {
        logger.info("Waiting for processing to finish");
        while(outstanding.get() > 0) {
            logger.info("...");
            try {
                Thread.sleep(500);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
        logger.info("Write output...");
        try {
            DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(
                    new File(outputFile))));

            for(int i=0; i<sourceVertexIds.length; i++) {
                int sourceVertex = sourceVertexIds[i];
                drainBuffer(i);
                DiscreteDistribution distr = distributions[i];
                IdCount[] topVertices = distr.getTop(nTop);
                dos.writeInt(sourceVertex);
                int written = 0;
                for(IdCount vc : topVertices) {
                    dos.writeInt(vc.id);
                    dos.writeInt(vc.count);
                    written++;
                }
                while(written < nTop) {
                    written++;
                    dos.writeInt(-1);
                    dos.writeInt(-1);
                }
            }
            dos.close();

        } catch (Exception err) {
            err.printStackTrace();
        }
    }

    public void close() {
        closed = true;
        timer.cancel();
    }

    public static void main(String[] args) throws Exception {
        Double pruneFraction = Double.parseDouble(args[0]);
        String bindAddress = args[1];
        try {
            LocateRegistry.createRegistry(1099);
        } catch (Exception err) {
            logger.info("Registry already created?");
        }
        // TODO? Not sure what the main class is used for; just for testing?  This may need to be
        // put into the subclass.
        Naming.rebind(bindAddress, new IntDrunkardCompanion(4, (long) (Runtime.getRuntime().maxMemory() * 0.75)));
        logger.info("Prune fraction: " + pruneFraction);
    }

}
TOP

Related Classes of edu.cmu.graphchi.walks.distributions.DrunkardCompanion$WalkSubmission

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.