Package com.facebook.presto.operator

Source Code of com.facebook.presto.operator.TopNOperator$TopNMemoryManager

/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator;

import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.block.SortOrder;
import com.facebook.presto.spi.type.Type;
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Ordering;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.DataSize;

import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;

/**
* Returns the top N rows from the source sorted according to the specified ordering in the keyChannelIndex channel.
*/
public class TopNOperator
        implements Operator
{
    public static class TopNOperatorFactory
            implements OperatorFactory
    {
        private final int operatorId;
        private final List<Type> sourceTypes;
        private final int n;
        private final List<Integer> sortChannels;
        private final List<SortOrder> sortOrders;
        private final Optional<Integer> sampleWeight;
        private final boolean partial;
        private boolean closed;

        public TopNOperatorFactory(
                int operatorId,
                List<? extends Type> types,
                int n,
                List<Integer> sortChannels,
                List<SortOrder> sortOrders,
                Optional<Integer> sampleWeight,
                boolean partial)
        {
            this.operatorId = operatorId;
            this.sourceTypes = ImmutableList.copyOf(checkNotNull(types, "types is null"));
            this.n = n;
            this.sortChannels = ImmutableList.copyOf(checkNotNull(sortChannels, "sortChannels is null"));
            this.sortOrders = ImmutableList.copyOf(checkNotNull(sortOrders, "sortOrders is null"));
            this.partial = partial;
            this.sampleWeight = sampleWeight;
        }

        @Override
        public List<Type> getTypes()
        {
            return sourceTypes;
        }

        @Override
        public Operator createOperator(DriverContext driverContext)
        {
            checkState(!closed, "Factory is already closed");
            OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, TopNOperator.class.getSimpleName());
            return new TopNOperator(
                    operatorContext,
                    sourceTypes,
                    n,
                    sortChannels,
                    sortOrders,
                    sampleWeight,
                    partial);
        }

        @Override
        public void close()
        {
            closed = true;
        }
    }

    private static final int MAX_INITIAL_PRIORITY_QUEUE_SIZE = 10000;
    private static final DataSize OVERHEAD_PER_VALUE = new DataSize(100, DataSize.Unit.BYTE); // for estimating in-memory size. This is a completely arbitrary number

    private final OperatorContext operatorContext;
    private final List<Type> types;
    private final int n;
    private final List<Integer> sortChannels;
    private final List<SortOrder> sortOrders;
    private final TopNMemoryManager memoryManager;
    private final boolean partial;
    private final Optional<Integer> sampleWeight;

    private final PageBuilder pageBuilder;

    private TopNBuilder topNBuilder;
    private boolean finishing;

    private Iterator<Block[]> outputIterator;

    public TopNOperator(
            OperatorContext operatorContext,
            List<Type> types,
            int n,
            List<Integer> sortChannels,
            List<SortOrder> sortOrders,
            Optional<Integer> sampleWeight,
            boolean partial)
    {
        this.operatorContext = checkNotNull(operatorContext, "operatorContext is null");
        this.types = checkNotNull(types, "types is null");

        checkArgument(n > 0, "n must be greater than zero");
        this.n = n;

        this.sortChannels = checkNotNull(sortChannels, "sortChannels is null");
        this.sortOrders = checkNotNull(sortOrders, "sortOrders is null");

        this.partial = partial;

        this.memoryManager = new TopNMemoryManager(checkNotNull(operatorContext, "operatorContext is null"));

        this.pageBuilder = new PageBuilder(types);

        this.sampleWeight = sampleWeight;
    }

    @Override
    public OperatorContext getOperatorContext()
    {
        return operatorContext;
    }

    @Override
    public List<Type> getTypes()
    {
        return types;
    }

    @Override
    public void finish()
    {
        finishing = true;
    }

    @Override
    public boolean isFinished()
    {
        return finishing && topNBuilder == null && (outputIterator == null || !outputIterator.hasNext());
    }

    @Override
    public ListenableFuture<?> isBlocked()
    {
        return NOT_BLOCKED;
    }

    @Override
    public boolean needsInput()
    {
        return !finishing && outputIterator == null && (topNBuilder == null || !topNBuilder.isFull());
    }

    @Override
    public void addInput(Page page)
    {
        checkState(!finishing, "Operator is already finishing");
        checkNotNull(page, "page is null");
        if (topNBuilder == null) {
            topNBuilder = new TopNBuilder(
                    n,
                    sortChannels,
                    sortOrders,
                    sampleWeight,
                    memoryManager);
        }

        checkState(!topNBuilder.isFull(), "Aggregation buffer is full");
        topNBuilder.processPage(page);
    }

    @Override
    public Page getOutput()
    {
        if (outputIterator == null || !outputIterator.hasNext()) {
            // no data
            if (topNBuilder == null) {
                return null;
            }

            // only flush if we are finishing or the aggregation builder is full
            if (!finishing && !topNBuilder.isFull()) {
                return null;
            }

            // Only partial aggregation can flush early. Also, check that we are not flushing tiny bits at a time
            checkState(finishing || partial, "Task exceeded max memory size of %s", memoryManager.getMaxMemorySize());

            outputIterator = topNBuilder.build();
            topNBuilder = null;
        }

        pageBuilder.reset();
        while (!pageBuilder.isFull() && outputIterator.hasNext()) {
            Block[] next = outputIterator.next();
            for (int i = 0; i < next.length; i++) {
                next[i].appendTo(0, pageBuilder.getBlockBuilder(i));
            }
        }

        Page page = pageBuilder.build();
        return page;
    }

    private static class TopNBuilder
    {
        private final int n;
        private final List<Integer> sortChannels;
        private final List<SortOrder> sortOrders;
        private final TopNMemoryManager memoryManager;
        private final PriorityQueue<Block[]> globalCandidates;
        private final Optional<Integer> sampleWeightChannel;

        private long memorySize;

        private TopNBuilder(int n, List<Integer> sortChannels, List<SortOrder> sortOrders, Optional<Integer> sampleWeightChannel, TopNMemoryManager memoryManager)
        {
            this.n = n;

            this.sortChannels = sortChannels;
            this.sortOrders = sortOrders;

            this.memoryManager = memoryManager;
            this.sampleWeightChannel = sampleWeightChannel;

            Ordering<Block[]> comparator = Ordering.from(new RowComparator(sortChannels, sortOrders)).reverse();
            this.globalCandidates = new PriorityQueue<>(Math.min(n, MAX_INITIAL_PRIORITY_QUEUE_SIZE), comparator);
        }

        public void processPage(Page page)
        {
            long sizeDelta = mergeWithGlobalCandidates(page);
            memorySize += sizeDelta;
        }

        private long mergeWithGlobalCandidates(Page page)
        {
            long sizeDelta = 0;

            Block[] blocks = page.getBlocks();
            for (int position = 0; position < page.getPositionCount(); position++) {
                if (globalCandidates.size() < n) {
                    sizeDelta += addRow(position, blocks);
                }
                else if (compare(position, blocks, globalCandidates.peek()) < 0) {
                    sizeDelta += addRow(position, blocks);
                }
            }

            return sizeDelta;
        }

        private int compare(int position, Block[] blocks, Block[] currentMax)
        {
            for (int i = 0; i < sortChannels.size(); i++) {
                int sortChannel = sortChannels.get(i);
                SortOrder sortOrder = sortOrders.get(i);

                Block block = blocks[sortChannel];
                Block currentMaxValue = currentMax[sortChannel];

                // compare the right value to the left block but negate the result since we are evaluating in the opposite order
                int compare = -currentMaxValue.compareTo(sortOrder, 0, block, position);
                if (compare != 0) {
                    return compare;
                }
            }
            return 0;
        }

        private long addRow(int position, Block[] blocks)
        {
            long sizeDelta = 0;
            Block[] row = getValues(position, blocks);
            long sampleWeight = 1;
            if (sampleWeightChannel.isPresent()) {
                sampleWeight = row[sampleWeightChannel.get()].getLong(0);
                // Set the weight to one, since we're going to insert it multiple times in the priority queue
                row[sampleWeightChannel.get()] = createBigintBlock(1);
            }

            // Count the column sizes only once, because we insert the same object reference multiple times for sampled rows
            sizeDelta += sizeOfRow(row);
            globalCandidates.add(row);
            sizeDelta += (sampleWeight - 1) * OVERHEAD_PER_VALUE.toBytes();
            for (int i = 1; i < sampleWeight; i++) {
                globalCandidates.add(row);
            }

            while (globalCandidates.size() > n) {
                Block[] previous = globalCandidates.remove();
                // We insert sampled rows multiple times, so use reference equality when checking if this row is still in the queue
                if (previous != globalCandidates.peek()) {
                    sizeDelta -= sizeOfRow(previous);
                }
                else {
                    sizeDelta -= OVERHEAD_PER_VALUE.toBytes();
                }
            }
            return sizeDelta;
        }

        private static long sizeOfRow(Block[] row)
        {
            long size = OVERHEAD_PER_VALUE.toBytes();
            for (Block value : row) {
                size += value.getSizeInBytes();
            }
            return size;
        }

        private static Block[] getValues(int position, Block[] blocks)
        {
            Block[] row = new Block[blocks.length];
            for (int i = 0; i < blocks.length; i++) {
                row[i] = blocks[i].getSingleValueBlock(position);
            }
            return row;
        }

        private boolean isFull()
        {
            return memoryManager.canUse(memorySize);
        }

        public Iterator<Block[]> build()
        {
            ImmutableList.Builder<Block[]> minSortedGlobalCandidates = ImmutableList.builder();
            long sampleWeight = 1;
            while (!globalCandidates.isEmpty()) {
                Block[] row = globalCandidates.remove();
                if (sampleWeightChannel.isPresent()) {
                    // sampled rows are inserted multiple times (we can use identity comparison here)
                    // we could also test for equality to "pack" results further, but that would require another equality function
                    if (globalCandidates.peek() != null && row == globalCandidates.peek()) {
                        sampleWeight++;
                    }
                    else {
                        row[sampleWeightChannel.get()] = createBigintBlock(sampleWeight);
                        minSortedGlobalCandidates.add(row);
                        sampleWeight = 1;
                    }
                }
                else {
                    minSortedGlobalCandidates.add(row);
                }
            }
            return minSortedGlobalCandidates.build().reverse().iterator();
        }

        private static Block createBigintBlock(long value)
        {
            return BIGINT.createBlockBuilder(new BlockBuilderStatus())
                    .appendLong(value)
                    .build();
        }
    }

    public static class TopNMemoryManager
    {
        private final OperatorContext operatorContext;
        private long currentMemoryReservation;

        public TopNMemoryManager(OperatorContext operatorContext)
        {
            this.operatorContext = operatorContext;
        }

        public boolean canUse(long memorySize)
        {
            // remove the pre-allocated memory from this size
            memorySize -= operatorContext.getOperatorPreAllocatedMemory().toBytes();

            long delta = memorySize - currentMemoryReservation;
            if (delta <= 0) {
                return false;
            }

            if (!operatorContext.reserveMemory(delta)) {
                return true;
            }

            // reservation worked, record the reservation
            currentMemoryReservation = Math.max(currentMemoryReservation, memorySize);
            return false;
        }

        public DataSize getMaxMemorySize()
        {
            return operatorContext.getMaxMemorySize();
        }
    }

    private static class RowComparator
            implements Comparator<Block[]>
    {
        private final List<Integer> sortChannels;
        private final List<SortOrder> sortOrders;

        public RowComparator(List<Integer> sortChannels, List<SortOrder> sortOrders)
        {
            checkNotNull(sortChannels, "sortChannels is null");
            checkNotNull(sortOrders, "sortOrders is null");
            checkArgument(sortChannels.size() == sortOrders.size(), "sortFields size (%s) doesn't match sortOrders size (%s)", sortChannels.size(), sortOrders.size());

            this.sortChannels = ImmutableList.copyOf(sortChannels);
            this.sortOrders = ImmutableList.copyOf(sortOrders);
        }

        @Override
        public int compare(Block[] leftRow, Block[] rightRow)
        {
            for (int index = 0; index < sortChannels.size(); index++) {
                int channel = sortChannels.get(index);
                SortOrder sortOrder = sortOrders.get(index);

                Block left = leftRow[channel];
                Block right = rightRow[channel];

                int comparison = left.compareTo(sortOrder, 0, right, 0);
                if (comparison != 0) {
                    return comparison;
                }
            }
            return 0;
        }
    }
}
TOP

Related Classes of com.facebook.presto.operator.TopNOperator$TopNMemoryManager

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.