Package edu.cmu.graphchi.walks.distributions

Source Code of edu.cmu.graphchi.walks.distributions.TwoKeyCompanion$ProcessingThread

package edu.cmu.graphchi.walks.distributions;

import edu.cmu.graphchi.ChiLogger;
import edu.cmu.graphchi.walks.WalkArray;
import edu.cmu.graphchi.walks.LongWalkArray;
import edu.cmu.graphchi.walks.distributions.DiscreteDistribution;
import edu.cmu.graphchi.walks.distributions.RemoteDrunkardCompanion;
import edu.cmu.graphchi.util.IdCount;
import edu.cmu.graphchi.util.IntegerBuffer;

import java.io.*;
import java.rmi.Naming;
import java.rmi.RemoteException;
import java.rmi.registry.LocateRegistry;
import java.rmi.server.UnicastRemoteObject;
import java.text.NumberFormat;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;


/**
* A DrunkardCompanion object that has two keys to get to a DiscreteDistribution, instead of one.
* Where DrunkardCompanion represents a matrix of values (one key to get to a DiscreteDistribution
* vector), this represents a rank 3 tensor (two keys to get to a DiscreteDistribution).  This is
* suitable for collecting more complicated statistics than DrunkardCompanion, though the current
* implementation is perhaps a little slower than it could be, using nested hash maps instead of a
* more efficient data structure.
*/
public abstract class TwoKeyCompanion extends UnicastRemoteObject
        implements RemoteDrunkardCompanion {

    protected static class WalkSubmission {
        WalkArray walks;
        int[] atVertices;

        private WalkSubmission(WalkArray walks, int[] atVertices) {
            this.walks = walks;
            this.atVertices = atVertices;
        }
    }

    protected static final int BUFFER_CAPACITY = 128;
    protected static final int BUFFER_MAX = 128;

    boolean isLowInMemory = false;

    // Using hash maps of hash maps isn't the most efficient thing to do here, but it'll do for
    // now.
    protected ConcurrentHashMap<Integer,
              ConcurrentHashMap<Integer, DiscreteDistribution>> distributions;
    protected ConcurrentHashMap<Integer, ConcurrentHashMap<Integer, IntegerBuffer>> buffers;
    protected ConcurrentHashMap<Integer, ConcurrentHashMap<Integer, Object>> distrLocks;
    protected AtomicInteger outstanding = new AtomicInteger(0);

    protected ExecutorService parallelExecutor;
    protected long maxMemoryBytes;

    protected LinkedBlockingQueue<WalkSubmission> pendingQueue = new LinkedBlockingQueue<WalkSubmission>();

    protected static Logger logger = ChiLogger.getLogger("pathcompanion");
    protected Timer timer  = new Timer(true);

    /**
     * Prints estimate of memory usage
     */
    private long memoryAuditReport() {
        long companionOverHeads = 0;

        long bufferMem = 0;
        long maxMem = 0;
        int bufferCount = 0;
        for (ConcurrentHashMap<Integer, IntegerBuffer> map : buffers.values()) {
            companionOverHeads += 4;
            for(IntegerBuffer buf : map.values()) {
                bufferCount += 1;
                companionOverHeads += 4;
                long est = buf.memorySizeEst();
                bufferMem += est;
                maxMem = Math.max(maxMem, est);
            }
        }

        long distributionMem = 0;
        long maxDistMem = 0;
        long avoidMem = 0;
        int distCount = 0;
        for (ConcurrentHashMap<Integer, DiscreteDistribution> map : distributions.values()) {
            companionOverHeads += 4;
            for(DiscreteDistribution dist : map.values()) {
                distCount += 1;
                companionOverHeads += 4;
                long est = dist.memorySizeEst();
                distributionMem += est;
                maxDistMem = Math.max(est, maxDistMem);
                avoidMem += dist.avoidCount() * 6;
            }
        }

        NumberFormat nf = NumberFormat.getInstance(Locale.US);

        logger.info("======= MEMORY REPORT ======");
        logger.info("Companion internal: " + nf.format(companionOverHeads / 1024. / 1024.) + " mb");

        logger.info("Buffer mem: " + nf.format(bufferMem / 1024. / 1024.) + " mb");
        logger.info("Avg bytes per buffer: " +
                nf.format(bufferMem * 1.0 / bufferCount / 1024.) + " kb");
        logger.info("Max buffer was: " + nf.format(maxMem / 1024.) + "kb");

        logger.info("Distribution mem: " + nf.format(distributionMem / 1024. / 1024.) + " mb");
        logger.info("- of which avoids: " + nf.format(avoidMem / 1024. / 1024.) + " mb");

        logger.info("Avg bytes per distribution: " +
                nf.format((distributionMem * 1.0 / distCount / 1024.)) + " kb");
        logger.info("Max distribution: " + nf.format(maxDistMem / 1024.) + " kb");

        long totalMem = companionOverHeads + bufferMem + distributionMem;
        logger.info("** Total:  " + nf.format(totalMem / 1024. / 1024. / 1024.) +
                " GB (low-mem limit " +
                Runtime.getRuntime().maxMemory() * 0.75 / 1024. / 1024. / 1024. + "GB)" );
        isLowInMemory = totalMem > maxMemoryBytes;

        if (isLowInMemory) {
            compactMemoryUsage();
        }

        return totalMem;
    }

    /**
     * Removes tails from distributions to save memory
     */
    private void compactMemoryUsage() {
        long before=0;
        long after=0;

        for (Integer firstKey : distributions.keySet()) {
            ConcurrentHashMap<Integer, DiscreteDistribution> map = distributions.get(firstKey);
            for (Integer secondKey : map.keySet()) {
                DiscreteDistribution prevDist, newDist;
                synchronized (distrLocks.get(firstKey).get(secondKey)) {
                    prevDist = map.get(secondKey);
                    newDist =  prevDist.filteredAndShift(2);
                    map.put(secondKey, newDist);
                }
                before += prevDist.memorySizeEst();
                after += newDist.memorySizeEst();
            }
        }

        logger.info("** Compacted: " + (before / 1024. / 1024. / 1024.) + " GB --> " +
                (after / 1024. / 1024. / 1024.) + " GB");
    }


    /**
     * Creates the TwoKeyCompanion object
     * @param numThreads number of worker threads (4 is common)
     * @param maxMemoryBytes maximum amount of memory to use for storing the distributions
     */
    public TwoKeyCompanion(int numThreads, long maxMemoryBytes) throws RemoteException {
        this.maxMemoryBytes = maxMemoryBytes;
        parallelExecutor =
            Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());

        buffers = new ConcurrentHashMap<Integer, ConcurrentHashMap<Integer, IntegerBuffer>>();
        distrLocks = new ConcurrentHashMap<Integer, ConcurrentHashMap<Integer, Object>>();
        distributions = new ConcurrentHashMap<Integer,
                      ConcurrentHashMap<Integer, DiscreteDistribution>>();


        for(int threadId=0; threadId < numThreads; threadId++) {
            Thread processingThread = new Thread(new ProcessingThread(threadId, numThreads));
            processingThread.setDaemon(true);
            processingThread.start();
        }
    }

    private class ProcessingThread implements Runnable {
        private int id;
        private int numThreads;

        public ProcessingThread(int id, int numThreads) {
            this.id = id;
            this.numThreads = numThreads;
        }
        @Override
        public void run() {
            try {
                long unpurgedWalks = 0;
                while(true) {
                    WalkSubmission subm = pendingQueue.poll(2000, TimeUnit.MILLISECONDS);
                    if (subm != null) {
                        _processWalks(subm.walks, subm.atVertices);
                        unpurgedWalks += subm.walks.size();
                    }
                    if (distributions != null) {
                        if (unpurgedWalks > distributions.size() * 10 ||
                                (subm == null && unpurgedWalks > 100000)) {
                            logger.fine("Purge:" + unpurgedWalks);
                            unpurgedWalks = 0;

                            // Loop to see what to drain. Every thread looks for
                            // different buffers.
                            for (Integer firstKey : buffers.keySet()) {
                                ConcurrentHashMap<Integer, IntegerBuffer> map =
                                    buffers.get(firstKey);
                                for (Integer secondKey : map.keySet()) {
                                    if ((firstKey + secondKey) % numThreads != id) {
                                        continue;
                                    }
                                    // Drain asynchronously
                                    outstanding.incrementAndGet();
                                    final IntegerBuffer toDrain = map.get(secondKey);
                                    final int first = firstKey;
                                    final int second = secondKey;

                                    synchronized (toDrain) {
                                        map.put(secondKey, new IntegerBuffer(BUFFER_CAPACITY));
                                    }
                                    parallelExecutor.submit(new Runnable() { public void run() {
                                        try {
                                            int[] d = toDrain.toIntArray();
                                            Arrays.sort(d);
                                            DiscreteDistribution dist = new DiscreteDistribution(d);
                                            mergeWith(first, second, dist);
                                        } catch (Exception err ) {
                                            err.printStackTrace();
                                        } finally {
                                            outstanding.decrementAndGet();
                                        }
                                    }});
                                }
                            }
                        }
                    }
                }
            } catch (Exception err) {
                err.printStackTrace();
            }
        }
    }

    protected void ensureExists(int firstKey, int secondKey) {
        ConcurrentHashMap<Integer, Object> map = distrLocks.get(firstKey);
        if (map == null) {
            ConcurrentHashMap<Integer, Object> new_map = new ConcurrentHashMap<Integer, Object>();
            map = distrLocks.putIfAbsent(firstKey, new_map);
            if (map == null) {
                map = new_map;
            }
        }
        Object lock = map.get(secondKey);
        if (lock == null) {
            Object new_lock = new Object();
            lock = map.putIfAbsent(secondKey, new_lock);
            if (lock == null) {
                synchronized(new_lock) {
                    ConcurrentHashMap<Integer, DiscreteDistribution> dmap =
                        distributions.get(firstKey);
                    if (dmap == null) {
                        dmap = new ConcurrentHashMap<Integer, DiscreteDistribution>();
                        distributions.put(firstKey, dmap);
                    }
                    dmap.put(secondKey, new DiscreteDistribution());
                    ConcurrentHashMap<Integer, IntegerBuffer> bmap = buffers.get(firstKey);
                    if (bmap == null) {
                        bmap = new ConcurrentHashMap<Integer, IntegerBuffer>();
                        buffers.put(firstKey, bmap);
                    }
                    bmap.put(secondKey, new IntegerBuffer(BUFFER_CAPACITY));
                }
            } else {
                synchronized(lock) {
                    // We're just waiting for the other thread to release the lock, so that we can
                    // get the buffer without crashing later.  Another thread actually added it,
                    // but we have to wait for them.
                }
            }
        }
    }

    private void mergeWith(int firstKey, int secondKey, DiscreteDistribution distr) {
        ensureExists(firstKey, secondKey);
        synchronized (distrLocks.get(firstKey).get(secondKey)) {
            DiscreteDistribution mergeInto = distributions.get(firstKey).get(secondKey);
            DiscreteDistribution merged = DiscreteDistribution.merge(mergeInto, distr);
            distributions.get(firstKey).put(secondKey, merged);
        }
    }

    @Override
    public void setAvoidList(int sourceIdx, int[] avoidList) throws RemoteException {
        // We don't need this, so this is a no-op
    }

    @Override
    public IdCount[] getTop(int vertexId, int nTop) throws RemoteException {
        // Not really useful for us
        return null;
    }

    @Override
    public void setSources(int[] sources) throws RemoteException {
        // We don't use an array of source indices, so we just take the opportunity to initialize
        // our objects.

        // Restart timer
        timer.cancel();
        timer = new Timer(true);

        timer.schedule(new TimerTask() {
            @Override
            public void run() {
                memoryAuditReport();
            }
        }, 5000, 60000);
    }

    protected void _processWalks(WalkArray walkArray, int[] atVertices) {
        long[] walks = ((LongWalkArray)walkArray).getArray();
        long t1 = System.currentTimeMillis();
        for(int i=0; i < walks.length; i++) {
            long w = walks[i];
            if (ignoreWalk(w)) {
                continue;
            }
            int atVertex = atVertices[i];
            int firstKey = getFirstKey(w, atVertex);
            int secondKey = getSecondKey(w, atVertex);
            int value = getValue(w, atVertex);

            ensureExists(firstKey, secondKey);
            IntegerBuffer buffer = buffers.get(firstKey).get(secondKey);
            synchronized (buffer) {
                buffer.add(value);
            }
        }

        long tt = (System.currentTimeMillis() - t1);
        if (tt > 1000) {
            logger.info("Processing " + walks.length + " took " + tt + " ms.");
        }
    }

    protected boolean ignoreWalk(long walk) {
        if (walk == 0) {
            return true;
        }
        return false;
    }

    protected abstract int getFirstKey(long walk, int atVertex);

    protected abstract int getSecondKey(long walk, int atVertex);

    protected abstract int getValue(long walk, int atVertex);

    protected void drainBuffer(int firstKey, int secondKey) {
        IntegerBuffer buffer = buffers.get(firstKey).get(secondKey);
        int[] arr;
        synchronized (buffer) {
            arr = buffer.toIntArray();
            buffers.get(firstKey).put(secondKey, new IntegerBuffer(BUFFER_CAPACITY));
        }
        Arrays.sort(arr);
        DiscreteDistribution dist = new DiscreteDistribution(arr);
        mergeWith(firstKey, secondKey, dist);
    }

    @Override
    public void processWalks(final WalkArray walks, final int[] atVertices) throws RemoteException {
        try {
            pendingQueue.put(new WalkSubmission(walks, atVertices));
            int pending = pendingQueue.size();
            if (pending > 50 && pending % 20 == 0) {
                logger.info("Warning, pending queue size: " + pending);
            }
        } catch (Exception err) {
            err.printStackTrace();
        }
    }

    protected void waitForFinish() {
        logger.info("Waiting for processing to finish");
        while (pendingQueue.size() > 0) {
            logger.info("...");
            try {
                Thread.sleep(500);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
        while(outstanding.get() > 0) {
            logger.info("...");
            try {
                Thread.sleep(500);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }

    @Override
    public abstract void outputDistributions(String outputFile) throws RemoteException;

    @Override
    public void outputDistributions(String outputFile, int nTop) throws RemoteException {
        outputDistributions(outputFile);
    }

    public void close() {
        parallelExecutor.shutdown();
        timer.cancel();
        clearMemory();
    }

    protected void clearMemory() {
        distributions.clear();
        buffers.clear();
        distrLocks.clear();
    }
}
TOP

Related Classes of edu.cmu.graphchi.walks.distributions.TwoKeyCompanion$ProcessingThread

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.