Package com.facebook.presto.operator.aggregation

Source Code of com.facebook.presto.operator.aggregation.AccumulatorCompiler

/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.byteCode.Block;
import com.facebook.presto.byteCode.ClassDefinition;
import com.facebook.presto.byteCode.CompilerContext;
import com.facebook.presto.byteCode.DynamicClassLoader;
import com.facebook.presto.byteCode.FieldDefinition;
import com.facebook.presto.byteCode.MethodDefinition;
import com.facebook.presto.byteCode.NamedParameterDefinition;
import com.facebook.presto.byteCode.Variable;
import com.facebook.presto.byteCode.control.ForLoop;
import com.facebook.presto.operator.GroupByIdBlock;
import com.facebook.presto.spi.Page;
import com.facebook.presto.operator.aggregation.state.AccumulatorStateFactory;
import com.facebook.presto.operator.aggregation.state.AccumulatorStateSerializer;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.gen.CallSiteBinder;
import com.facebook.presto.sql.gen.CompilerOperations;
import com.facebook.presto.sql.gen.SqlTypeByteCodeExpression;
import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;

import javax.annotation.Nullable;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

import static com.facebook.presto.byteCode.Access.FINAL;
import static com.facebook.presto.byteCode.Access.PRIVATE;
import static com.facebook.presto.byteCode.Access.PUBLIC;
import static com.facebook.presto.byteCode.Access.a;
import static com.facebook.presto.byteCode.NamedParameterDefinition.arg;
import static com.facebook.presto.byteCode.OpCode.NOP;
import static com.facebook.presto.byteCode.ParameterizedType.type;
import static com.facebook.presto.byteCode.control.IfStatement.IfStatementBuilder;
import static com.facebook.presto.byteCode.control.IfStatement.ifStatementBuilder;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.NULLABLE_INPUT_CHANNEL;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.countInputChannels;
import static com.facebook.presto.sql.gen.Bootstrap.BOOTSTRAP_METHOD;
import static com.facebook.presto.sql.gen.CompilerUtils.defineClass;
import static com.facebook.presto.sql.gen.CompilerUtils.makeClassName;
import static com.facebook.presto.sql.gen.SqlTypeByteCodeExpression.constantType;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;

public class AccumulatorCompiler
{
    public GenericAccumulatorFactoryBinder generateAccumulatorFactoryBinder(AggregationMetadata metadata, DynamicClassLoader classLoader)
    {
        Class<? extends Accumulator> accumulatorClass = generateAccumulatorClass(
                Accumulator.class,
                metadata,
                classLoader);

        Class<? extends GroupedAccumulator> groupedAccumulatorClass = generateAccumulatorClass(
                GroupedAccumulator.class,
                metadata,
                classLoader);

        return new GenericAccumulatorFactoryBinder(
                metadata.getStateSerializer(),
                metadata.getStateFactory(),
                accumulatorClass,
                groupedAccumulatorClass,
                metadata.isApproximate());
    }

    private static <T> Class<? extends T> generateAccumulatorClass(
            Class<T> accumulatorInterface,
            AggregationMetadata metadata,
            DynamicClassLoader classLoader)
    {
        boolean grouped = accumulatorInterface == GroupedAccumulator.class;
        boolean approximate = metadata.isApproximate();

        ClassDefinition definition = new ClassDefinition(
                a(PUBLIC, FINAL),
                makeClassName(metadata.getName() + accumulatorInterface.getSimpleName()),
                type(Object.class),
                type(accumulatorInterface));

        CallSiteBinder callSiteBinder = new CallSiteBinder();

        AccumulatorStateSerializer<?> stateSerializer = metadata.getStateSerializer();
        AccumulatorStateFactory<?> stateFactory = metadata.getStateFactory();

        FieldDefinition stateSerializerField = definition.declareField(a(PRIVATE, FINAL), "stateSerializer", AccumulatorStateSerializer.class);
        FieldDefinition stateFactoryField = definition.declareField(a(PRIVATE, FINAL), "stateFactory", AccumulatorStateFactory.class);
        FieldDefinition inputChannelsField = definition.declareField(a(PRIVATE, FINAL), "inputChannels", type(List.class, Integer.class));
        FieldDefinition maskChannelField = definition.declareField(a(PRIVATE, FINAL), "maskChannel", type(Optional.class, Integer.class));
        FieldDefinition sampleWeightChannelField = null;
        FieldDefinition confidenceField = null;
        if (approximate) {
            sampleWeightChannelField = definition.declareField(a(PRIVATE, FINAL), "sampleWeightChannel", type(Optional.class, Integer.class));
            confidenceField = definition.declareField(a(PRIVATE, FINAL), "confidence", double.class);
        }
        FieldDefinition stateField = definition.declareField(a(PRIVATE, FINAL), "state", grouped ? stateFactory.getGroupedStateClass() : stateFactory.getSingleStateClass());

        // Generate constructor
        generateConstructor(
                definition,
                stateSerializerField,
                stateFactoryField,
                inputChannelsField,
                maskChannelField,
                sampleWeightChannelField,
                confidenceField,
                stateField,
                grouped);

        // Generate methods
        generateAddInput(definition, stateField, inputChannelsField, maskChannelField, sampleWeightChannelField, metadata.getInputMetadata(), metadata.getInputFunction(), callSiteBinder, grouped);
        generateGetEstimatedSize(definition, stateField);
        MethodDefinition getIntermediateType = generateGetIntermediateType(definition, callSiteBinder, stateSerializer.getSerializedType());
        MethodDefinition getFinalType = generateGetFinalType(definition, callSiteBinder, metadata.getOutputType());

        if (metadata.getIntermediateInputFunction() == null) {
            generateAddIntermediateAsCombine(definition, stateField, stateSerializerField, stateFactoryField, metadata.getCombineFunction(), stateFactory.getSingleStateClass(), grouped);
        }
        else {
            generateAddIntermediateAsIntermediateInput(definition, stateField, metadata.getIntermediateInputMetadata(), metadata.getIntermediateInputFunction(), callSiteBinder, grouped);
        }

        if (grouped) {
            generateGroupedEvaluateIntermediate(definition, stateSerializerField, stateField);
        }
        else {
            generateEvaluateIntermediate(definition, getIntermediateType, stateSerializerField, stateField);
        }

        if (grouped) {
            generateGroupedEvaluateFinal(definition, confidenceField, stateSerializerField, stateField, metadata.getOutputFunction(), metadata.isApproximate());
        }
        else {
            generateEvaluateFinal(definition, getFinalType, confidenceField, stateSerializerField, stateField, metadata.getOutputFunction(), metadata.isApproximate());
        }

        return defineClass(definition, accumulatorInterface, callSiteBinder.getBindings(), classLoader);
    }

    private static MethodDefinition generateGetIntermediateType(ClassDefinition definition, CallSiteBinder callSiteBinder, Type type)
    {
        MethodDefinition methodDefinition = definition.declareMethod(a(PUBLIC), "getIntermediateType", type(Type.class));

        methodDefinition.getBody()
                .append(constantType(new CompilerContext(BOOTSTRAP_METHOD), callSiteBinder, type))
                .retObject();

        return methodDefinition;
    }

    private static MethodDefinition generateGetFinalType(ClassDefinition definition, CallSiteBinder callSiteBinder, Type type)
    {
        MethodDefinition methodDefinition = definition.declareMethod(a(PUBLIC), "getFinalType", type(Type.class));

        methodDefinition.getBody()
                .append(constantType(new CompilerContext(BOOTSTRAP_METHOD), callSiteBinder, type))
                .retObject();

        return methodDefinition;
    }

    private static void generateGetEstimatedSize(ClassDefinition definition, FieldDefinition stateField)
    {
        definition.declareMethod(a(PUBLIC), "getEstimatedSize", type(long.class))
                .getBody()
                .pushThis()
                .getField(stateField)
                .invokeVirtual(stateField.getType(), "getEstimatedSize", type(long.class))
                .retLong();
    }

    private static void generateAddInput(
            ClassDefinition definition,
            FieldDefinition stateField,
            FieldDefinition inputChannelsField,
            FieldDefinition maskChannelField,
            @Nullable FieldDefinition sampleWeightChannelField,
            List<ParameterMetadata> parameterMetadatas,
            Method inputFunction,
            CallSiteBinder callSiteBinder,
            boolean grouped)
    {
        CompilerContext context = new CompilerContext();

        ImmutableList.Builder<NamedParameterDefinition> parameters = ImmutableList.builder();
        if (grouped) {
            parameters.add(arg("groupIdsBlock", GroupByIdBlock.class));
        }
        parameters.add(arg("page", Page.class));

        Block body = definition.declareMethod(context, a(PUBLIC), "addInput", type(void.class), parameters.build())
                .getBody();

        if (grouped) {
            generateEnsureCapacity(stateField, body);
        }

        List<Variable> parameterVariables = new ArrayList<>();
        for (int i = 0; i < countInputChannels(parameterMetadatas); i++) {
            parameterVariables.add(context.declareVariable(com.facebook.presto.spi.block.Block.class, "block" + i));
        }
        Variable masksBlock = context.declareVariable(com.facebook.presto.spi.block.Block.class, "masksBlock");
        Variable sampleWeightsBlock = null;
        if (sampleWeightChannelField != null) {
            sampleWeightsBlock = context.declareVariable(com.facebook.presto.spi.block.Block.class, "sampleWeightsBlock");
        }

        body.comment("masksBlock = maskChannel.transform(page.blockGetter()).orNull();")
                .pushThis()
                .getField(maskChannelField)
                .getVariable("page")
                .invokeStatic(type(AggregationUtils.class), "pageBlockGetter", type(Function.class, Integer.class, com.facebook.presto.spi.block.Block.class), type(Page.class))
                .invokeVirtual(Optional.class, "transform", Optional.class, Function.class)
                .invokeVirtual(Optional.class, "orNull", Object.class)
                .checkCast(com.facebook.presto.spi.block.Block.class)
                .putVariable(masksBlock);

        if (sampleWeightChannelField != null) {
            body.comment("sampleWeightsBlock = sampleWeightChannel.transform(page.blockGetter()).get();")
                    .pushThis()
                    .getField(sampleWeightChannelField)
                    .getVariable("page")
                    .invokeStatic(type(AggregationUtils.class), "pageBlockGetter", type(Function.class, Integer.class, com.facebook.presto.spi.block.Block.class), type(Page.class))
                    .invokeVirtual(Optional.class, "transform", Optional.class, Function.class)
                    .invokeVirtual(Optional.class, "get", Object.class)
                    .checkCast(com.facebook.presto.spi.block.Block.class)
                    .putVariable(sampleWeightsBlock);
        }

        // Get all parameter blocks
        for (int i = 0; i < countInputChannels(parameterMetadatas); i++) {
            body.comment("%s = page.getBlock(inputChannels.get(%d));", parameterVariables.get(i).getName(), i)
                    .getVariable("page")
                    .pushThis()
                    .getField(inputChannelsField)
                    .push(i)
                    .invokeInterface(List.class, "get", Object.class, int.class)
                    .checkCast(Integer.class)
                    .invokeVirtual(Integer.class, "intValue", int.class)
                    .invokeVirtual(Page.class, "getBlock", com.facebook.presto.spi.block.Block.class, int.class)
                    .putVariable(parameterVariables.get(i));
        }
        Block block = generateInputForLoop(stateField, parameterMetadatas, inputFunction, context, parameterVariables, masksBlock, sampleWeightsBlock, callSiteBinder, grouped);

        body.append(block)
                .ret();
    }

    private static Block generateInputForLoop(
            FieldDefinition stateField,
            List<ParameterMetadata> parameterMetadatas,
            Method inputFunction,
            CompilerContext context,
            List<Variable> parameterVariables,
            Variable masksBlock,
            @Nullable Variable sampleWeightsBlock,
            CallSiteBinder callSiteBinder,
            boolean grouped)
    {
        // For-loop over rows
        Variable positionVariable = context.declareVariable(int.class, "position");
        Variable sampleWeightVariable = null;
        if (sampleWeightsBlock != null) {
            sampleWeightVariable = context.declareVariable(long.class, "sampleWeight");
        }
        Variable rowsVariable = context.declareVariable(int.class, "rows");

        Block block = new Block(context)
                .getVariable("page")
                .invokeVirtual(Page.class, "getPositionCount", int.class)
                .putVariable(rowsVariable)
                .initializeVariable(positionVariable);
        if (sampleWeightVariable != null) {
            block.initializeVariable(sampleWeightVariable);
        }

        Block loopBody = generateInvokeInputFunction(context, stateField, positionVariable, sampleWeightVariable, parameterVariables, parameterMetadatas, inputFunction, callSiteBinder, grouped);

        //  Wrap with null checks
        List<Boolean> nullable = new ArrayList<>();
        for (ParameterMetadata metadata : parameterMetadatas) {
            if (metadata.getParameterType() == INPUT_CHANNEL) {
                nullable.add(false);
            }
            else if (metadata.getParameterType() == NULLABLE_INPUT_CHANNEL) {
                nullable.add(true);
            }
        }
        checkState(nullable.size() == parameterVariables.size(), "Number of parameters does not match");
        for (int i = 0; i < parameterVariables.size(); i++) {
            if (!nullable.get(i)) {
                IfStatementBuilder builder = ifStatementBuilder(context);
                Variable variableDefinition = parameterVariables.get(i);
                builder.comment("if(!%s.isNull(position))", variableDefinition.getName())
                        .condition(new Block(context)
                                .getVariable(variableDefinition)
                                .getVariable(positionVariable)
                                .invokeInterface(com.facebook.presto.spi.block.Block.class, "isNull", boolean.class, int.class))
                        .ifTrue(NOP)
                        .ifFalse(loopBody);
                loopBody = new Block(context).append(builder.build());
            }
        }

        // Check that sample weight is > 0 (also checks the mask)
        if (sampleWeightVariable != null) {
            loopBody = generateComputeSampleWeightAndCheckGreaterThanZero(context, loopBody, sampleWeightVariable, masksBlock, sampleWeightsBlock, positionVariable);
        }
        // Otherwise just check the mask
        else {
            IfStatementBuilder builder = ifStatementBuilder(context);
            builder.comment("if(testMask(%s, position))", masksBlock.getName())
                    .condition(new Block(context)
                            .getVariable(masksBlock)
                            .getVariable(positionVariable)
                            .invokeStatic(CompilerOperations.class, "testMask", boolean.class, com.facebook.presto.spi.block.Block.class, int.class))
                    .ifTrue(loopBody)
                    .ifFalse(NOP);
            loopBody = new Block(context).append(builder.build());
        }

        block.append(new ForLoop.ForLoopBuilder(context)
                .initialize(new Block(context).putVariable(positionVariable, 0))
                .condition(new Block(context)
                        .getVariable(positionVariable)
                        .getVariable(rowsVariable)
                        .invokeStatic(CompilerOperations.class, "lessThan", boolean.class, int.class, int.class))
                .update(new Block(context).incrementVariable(positionVariable, (byte) 1))
                .body(loopBody)
                .build());

        return block;
    }

    private static Block generateComputeSampleWeightAndCheckGreaterThanZero(CompilerContext context, Block body, Variable sampleWeight, Variable masks, Variable sampleWeights, Variable position)
    {
        Block block = new Block(context)
                .comment("sampleWeight = computeSampleWeight(masks, sampleWeights, position);")
                .getVariable(masks)
                .getVariable(sampleWeights)
                .getVariable(position)
                .invokeStatic(ApproximateUtils.class, "computeSampleWeight", long.class, com.facebook.presto.spi.block.Block.class, com.facebook.presto.spi.block.Block.class, int.class)
                .putVariable(sampleWeight);

        IfStatementBuilder builder = ifStatementBuilder(context);
        builder.comment("if(sampleWeight > 0)")
                .condition(new Block(context)
                        .getVariable(sampleWeight)
                        .invokeStatic(CompilerOperations.class, "longGreaterThanZero", boolean.class, long.class))
                .ifTrue(body)
                .ifFalse(NOP);

        return block.append(builder.build());
    }

    private static Block generateInvokeInputFunction(
            CompilerContext context,
            FieldDefinition stateField,
            Variable position,
            @Nullable Variable sampleWeight,
            List<Variable> parameterVariables,
            List<ParameterMetadata> parameterMetadatas,
            Method inputFunction,
            CallSiteBinder callSiteBinder,
            boolean grouped)
    {
        Block block = new Block(context);

        if (grouped) {
            generateSetGroupIdFromGroupIdsBlock(stateField, position, block);
        }

        block.comment("Call input function with unpacked Block arguments");

        Class<?>[] parameters = inputFunction.getParameterTypes();
        int inputChannel = 0;
        for (int i = 0; i < parameters.length; i++) {
            ParameterMetadata parameterMetadata = parameterMetadatas.get(i);
            switch (parameterMetadata.getParameterType()) {
                case STATE:
                    block.pushThis().getField(stateField);
                    break;
                case BLOCK_INDEX:
                    block.getVariable(position);
                    break;
                case SAMPLE_WEIGHT:
                    checkNotNull(sampleWeight, "sampleWeight is null");
                    block.getVariable(sampleWeight);
                    break;
                case NULLABLE_INPUT_CHANNEL:
                    block.getVariable(parameterVariables.get(inputChannel));
                    inputChannel++;
                    break;
                case INPUT_CHANNEL:
                    Block getBlockByteCode = new Block(context)
                            .getVariable(parameterVariables.get(inputChannel));
                    pushStackType(block, parameterMetadata.getSqlType(), getBlockByteCode, parameters[i], callSiteBinder);
                    inputChannel++;
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported parameter type: " + parameterMetadata.getParameterType());
            }
        }

        block.invokeStatic(inputFunction);
        return block;
    }

    // Assumes that there is a variable named 'position' in the block, which is the current index
    private static void pushStackType(Block block, Type sqlType, Block getBlockByteCode, Class<?> parameter, CallSiteBinder callSiteBinder)
    {
        if (parameter == com.facebook.presto.spi.block.Block.class) {
            block.append(getBlockByteCode);
        }
        else if (parameter == long.class) {
            block.comment("%s.getLong(block, position)", sqlType.getName())
                    .append(SqlTypeByteCodeExpression.constantType(new CompilerContext(BOOTSTRAP_METHOD), callSiteBinder, sqlType))
                    .append(getBlockByteCode)
                    .getVariable("position")
                    .invokeInterface(Type.class, "getLong", long.class, com.facebook.presto.spi.block.Block.class, int.class);
        }
        else if (parameter == double.class) {
            block.comment("%s.getDouble(block, position)", sqlType.getName())
                    .append(SqlTypeByteCodeExpression.constantType(new CompilerContext(BOOTSTRAP_METHOD), callSiteBinder, sqlType))
                    .append(getBlockByteCode)
                    .getVariable("position")
                    .invokeInterface(Type.class, "getDouble", double.class, com.facebook.presto.spi.block.Block.class, int.class);
        }
        else if (parameter == boolean.class) {
            block.comment("%s.getBoolean(block, position)", sqlType.getName())
                    .append(SqlTypeByteCodeExpression.constantType(new CompilerContext(BOOTSTRAP_METHOD), callSiteBinder, sqlType))
                    .append(getBlockByteCode)
                    .getVariable("position")
                    .invokeInterface(Type.class, "getBoolean", boolean.class, com.facebook.presto.spi.block.Block.class, int.class);
        }
        else if (parameter == Slice.class) {
            block.comment("%s.getBoolean(block, position)", sqlType.getName())
                    .append(SqlTypeByteCodeExpression.constantType(new CompilerContext(BOOTSTRAP_METHOD), callSiteBinder, sqlType))
                    .append(getBlockByteCode)
                    .getVariable("position")
                    .invokeInterface(Type.class, "getSlice", Slice.class, com.facebook.presto.spi.block.Block.class, int.class);
        }
        else {
            throw new IllegalArgumentException("Unsupported parameter type: " + parameter.getSimpleName());
        }
    }

    private static void generateAddIntermediateAsCombine(
            ClassDefinition definition,
            FieldDefinition stateField,
            FieldDefinition stateSerializerField,
            FieldDefinition stateFactoryField,
            Method combineFunction,
            Class<?> singleStateClass,
            boolean grouped)
    {
        CompilerContext context = new CompilerContext();

        Block body = declareAddIntermediate(definition, grouped, context);

        Variable scratchStateVariable = context.declareVariable(singleStateClass, "scratchState");
        Variable positionVariable = context.declareVariable(int.class, "position");

        body.comment("scratchState = stateFactory.createSingleState();")
                .pushThis()
                .getField(stateFactoryField)
                .invokeInterface(AccumulatorStateFactory.class, "createSingleState", Object.class)
                .checkCast(scratchStateVariable.getType())
                .putVariable(scratchStateVariable);

        if (grouped) {
            generateEnsureCapacity(stateField, body);
        }

        Block loopBody = new Block(context);

        if (grouped) {
            generateSetGroupIdFromGroupIdsBlock(stateField, positionVariable, loopBody);
        }

        loopBody.comment("stateSerializer.deserialize(block, position, scratchState)")
                .pushThis()
                .getField(stateSerializerField)
                .getVariable("block")
                .getVariable(positionVariable)
                .getVariable(scratchStateVariable)
                .invokeInterface(AccumulatorStateSerializer.class, "deserialize", void.class, com.facebook.presto.spi.block.Block.class, int.class, Object.class);

        loopBody.comment("combine(state, scratchState)")
                .pushThis()
                .getField(stateField)
                .getVariable("scratchState")
                .invokeStatic(combineFunction);

        body.append(generateBlockNonNullPositionForLoop(context, positionVariable, loopBody))
                .ret();
    }

    private static void generateSetGroupIdFromGroupIdsBlock(FieldDefinition stateField, Variable positionVariable, Block block)
    {
        block.comment("state.setGroupId(groupIdsBlock.getGroupId(position))")
                .pushThis()
                .getField(stateField)
                .getVariable("groupIdsBlock")
                .getVariable(positionVariable)
                .invokeVirtual(GroupByIdBlock.class, "getGroupId", long.class, int.class)
                .invokeVirtual(stateField.getType(), "setGroupId", type(void.class), type(long.class));
    }

    private static void generateEnsureCapacity(FieldDefinition stateField, Block block)
    {
        block.comment("state.ensureCapacity(groupIdsBlock.getGroupCount())")
                .pushThis()
                .getField(stateField)
                .getVariable("groupIdsBlock")
                .invokeVirtual(GroupByIdBlock.class, "getGroupCount", long.class)
                .invokeVirtual(stateField.getType(), "ensureCapacity", type(void.class), type(long.class));
    }

    private static Block declareAddIntermediate(ClassDefinition definition, boolean grouped, CompilerContext context)
    {
        ImmutableList.Builder<NamedParameterDefinition> parameters = ImmutableList.builder();
        if (grouped) {
            parameters.add(arg("groupIdsBlock", GroupByIdBlock.class));
        }
        parameters.add(arg("block", com.facebook.presto.spi.block.Block.class));

        return definition.declareMethod(
                context,
                a(PUBLIC),
                "addIntermediate",
                type(void.class),
                parameters.build())
                .getBody();
    }

    private static void generateAddIntermediateAsIntermediateInput(
            ClassDefinition definition,
            FieldDefinition stateField,
            List<ParameterMetadata> parameterMetadatas,
            Method intermediateInputFunction,
            CallSiteBinder callSiteBinder,
            boolean grouped)
    {
        CompilerContext context = new CompilerContext();

        Block body = declareAddIntermediate(definition, grouped, context);

        if (grouped) {
            generateEnsureCapacity(stateField, body);
        }

        Variable positionVariable = context.declareVariable(int.class, "position");

        Block loopBody = generateInvokeInputFunction(context, stateField, positionVariable, null, ImmutableList.of(context.getVariable("block")), parameterMetadatas, intermediateInputFunction, callSiteBinder, grouped);

        body.append(generateBlockNonNullPositionForLoop(context, positionVariable, loopBody))
                .ret();
    }

    // Generates a for-loop with a local variable named "position" defined, with the current position in the block,
    // loopBody will only be executed for non-null positions in the Block
    private static Block generateBlockNonNullPositionForLoop(CompilerContext context, Variable positionVariable, Block loopBody)
    {
        Variable rowsVariable = context.declareVariable(int.class, "rows");

        Block block = new Block(context)
                .getVariable("block")
                .invokeInterface(com.facebook.presto.spi.block.Block.class, "getPositionCount", int.class)
                .putVariable(rowsVariable);

        IfStatementBuilder builder = ifStatementBuilder(context);
        builder.comment("if(!block.isNull(position))")
                .condition(new Block(context)
                        .getVariable("block")
                        .getVariable(positionVariable)
                        .invokeInterface(com.facebook.presto.spi.block.Block.class, "isNull", boolean.class, int.class))
                .ifTrue(NOP)
                .ifFalse(loopBody);

        block.append(new ForLoop.ForLoopBuilder(context)
                .initialize(new Block(context).putVariable(positionVariable, 0))
                .condition(new Block(context)
                        .getVariable(positionVariable)
                        .getVariable(rowsVariable)
                        .invokeStatic(CompilerOperations.class, "lessThan", boolean.class, int.class, int.class))
                .update(new Block(context).incrementVariable(positionVariable, (byte) 1))
                .body(builder.build())
                .build());

        return block;
    }

    private static void generateGroupedEvaluateIntermediate(ClassDefinition definition, FieldDefinition stateSerializerField, FieldDefinition stateField)
    {
        definition.declareMethod(
                a(PUBLIC),
                "evaluateIntermediate",
                type(void.class),
                arg("groupId", int.class),
                arg("out", BlockBuilder.class))
                .getBody()
                .comment("state.setGroupId(groupId)")
                .pushThis()
                .getField(stateField)
                .getVariable("groupId")
                .intToLong()
                .invokeVirtual(stateField.getType(), "setGroupId", type(void.class), type(long.class))

                .comment("stateSerializer.serialize(state, out)")
                .pushThis()
                .getField(stateSerializerField)
                .pushThis()
                .getField(stateField)
                .getVariable("out")
                .invokeInterface(AccumulatorStateSerializer.class, "serialize", void.class, Object.class, BlockBuilder.class)
                .ret();
    }

    private static void generateEvaluateIntermediate(ClassDefinition definition, MethodDefinition getIntermediateType, FieldDefinition stateSerializerField, FieldDefinition stateField)
    {
        CompilerContext context = new CompilerContext();
        Block body = definition.declareMethod(
                context,
                a(PUBLIC),
                "evaluateIntermediate",
                type(com.facebook.presto.spi.block.Block.class))
                .getBody();

        context.declareVariable(BlockBuilder.class, "out");

        body.comment("out = getIntermediateType().createBlockBuilder(new BlockBuilderStatus());")
                .pushThis()
                .invokeVirtual(getIntermediateType)
                .newObject(BlockBuilderStatus.class)
                .dup()
                .invokeConstructor(BlockBuilderStatus.class)
                .invokeInterface(Type.class, "createBlockBuilder", BlockBuilder.class, BlockBuilderStatus.class)
                .putVariable("out");

        body.comment("stateSerializer.serialize(state, out)")
                .pushThis()
                .getField(stateSerializerField)
                .pushThis()
                .getField(stateField)
                .getVariable("out")
                .invokeInterface(AccumulatorStateSerializer.class, "serialize", void.class, Object.class, BlockBuilder.class);

        body.comment("return out.build();")
                .getVariable("out")
                .invokeInterface(BlockBuilder.class, "build", com.facebook.presto.spi.block.Block.class)
                .retObject();
    }

    private static void generateGroupedEvaluateFinal(
            ClassDefinition definition,
            FieldDefinition confidenceField,
            FieldDefinition stateSerializerField,
            FieldDefinition stateField,
            @Nullable Method outputFunction,
            boolean approximate)
    {
        Block body = definition.declareMethod(
                a(PUBLIC),
                "evaluateFinal",
                type(void.class),
                arg("groupId", int.class),
                arg("out", BlockBuilder.class))
                .getBody()
                .comment("state.setGroupId(groupId)")
                .pushThis()
                .getField(stateField)
                .getVariable("groupId")
                .intToLong()
                .invokeVirtual(stateField.getType(), "setGroupId", type(void.class), type(long.class));

        if (outputFunction != null) {
            body.comment("output(state, out)")
                    .pushThis()
                    .getField(stateField);
            if (approximate) {
                checkNotNull(confidenceField, "confidenceField is null");
                body.pushThis().getField(confidenceField);
            }
            body.getVariable("out")
                    .invokeStatic(outputFunction);
        }
        else {
            checkArgument(!approximate, "Approximate aggregations must specify an output function");
            body.comment("stateSerializer.serialize(state, out)")
                    .pushThis()
                    .getField(stateSerializerField)
                    .pushThis()
                    .getField(stateField)
                    .getVariable("out")
                    .invokeInterface(AccumulatorStateSerializer.class, "serialize", void.class, Object.class, BlockBuilder.class);
        }
        body.ret();
    }

    private static void generateEvaluateFinal(
            ClassDefinition definition,
            MethodDefinition getFinalType,
            FieldDefinition confidenceField,
            FieldDefinition stateSerializerField,
            FieldDefinition stateField,
            @Nullable
            Method outputFunction,
            boolean approximate)
    {
        CompilerContext context = new CompilerContext();
        Block body = definition.declareMethod(
                context,
                a(PUBLIC),
                "evaluateFinal",
                type(com.facebook.presto.spi.block.Block.class))
                .getBody();

        context.declareVariable(BlockBuilder.class, "out");

        body.pushThis()
                .invokeVirtual(getFinalType)
                .newObject(BlockBuilderStatus.class)
                .dup()
                .invokeConstructor(BlockBuilderStatus.class)
                .invokeInterface(Type.class, "createBlockBuilder", BlockBuilder.class, BlockBuilderStatus.class)
                .putVariable("out");

        if (outputFunction != null) {
            body.comment("output(state, out)")
                    .pushThis()
                    .getField(stateField);
            if (approximate) {
                checkNotNull(confidenceField, "confidenceField is null");
                body.pushThis().getField(confidenceField);
            }
            body.getVariable("out")
                    .invokeStatic(outputFunction);
        }
        else {
            checkArgument(!approximate, "Approximate aggregations must specify an output function");
            body.comment("stateSerializer.serialize(state, out)")
                    .pushThis()
                    .getField(stateSerializerField)
                    .pushThis()
                    .getField(stateField)
                    .getVariable("out")
                    .invokeInterface(AccumulatorStateSerializer.class, "serialize", void.class, Object.class, BlockBuilder.class);
        }

        body.getVariable("out")
                .invokeInterface(BlockBuilder.class, "build", com.facebook.presto.spi.block.Block.class)
                .retObject();
    }

    private static void generateConstructor(
            ClassDefinition definition,
            FieldDefinition stateSerializerField,
            FieldDefinition stateFactoryField,
            FieldDefinition inputChannelsField,
            FieldDefinition maskChannelField,
            @Nullable FieldDefinition sampleWeightChannelField,
            @Nullable FieldDefinition confidenceField,
            FieldDefinition stateField,
            boolean grouped)
    {
        Block body = definition.declareConstructor(
                a(PUBLIC),
                arg("stateSerializer", AccumulatorStateSerializer.class),
                arg("stateFactory", AccumulatorStateFactory.class),
                arg("inputChannels", type(List.class, Integer.class)),
                arg("maskChannel", type(Optional.class, Integer.class)),
                arg("sampleWeightChannel", type(Optional.class, Integer.class)),
                arg("confidence", double.class))
                .getBody()
                .comment("super();")
                .pushThis()
                .invokeConstructor(Object.class);

        generateCastCheckNotNullAndAssign(body, stateSerializerField, "stateSerializer");
        generateCastCheckNotNullAndAssign(body, stateFactoryField, "stateFactory");
        generateCastCheckNotNullAndAssign(body, inputChannelsField, "inputChannels");
        generateCastCheckNotNullAndAssign(body, maskChannelField, "maskChannel");
        if (sampleWeightChannelField != null) {
            generateCastCheckNotNullAndAssign(body, sampleWeightChannelField, "sampleWeightChannel");
        }

        String createState;
        if (grouped) {
            createState = "createGroupedState";
        }
        else {
            createState = "createSingleState";
        }

        if (confidenceField != null) {
            body.comment("this.confidence = confidence")
                    .pushThis()
                    .getVariable("confidence")
                    .putField(confidenceField);
        }

        body.comment("this.state = stateFactory.%s()", createState)
                .pushThis()
                .getVariable("stateFactory")
                .invokeInterface(AccumulatorStateFactory.class, createState, Object.class)
                .checkCast(stateField.getType())
                .putField(stateField)
                .ret();
    }

    private static void generateCastCheckNotNullAndAssign(Block block, FieldDefinition field, String variableName)
    {
        block.comment("this.%s = checkNotNull(%s, \"%s is null\"", field.getName(), variableName, variableName)
                .pushThis()
                .getVariable(variableName)
                .checkCast(field.getType())
                .push(variableName + " is null")
                .invokeStatic(Preconditions.class, "checkNotNull", Object.class, Object.class, Object.class)
                .checkCast(field.getType())
                .putField(field);
    }
}
TOP

Related Classes of com.facebook.presto.operator.aggregation.AccumulatorCompiler

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.