Package com.facebook.presto.operator.aggregation

Source Code of com.facebook.presto.operator.aggregation.AggregationCompiler

/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.byteCode.DynamicClassLoader;
import com.facebook.presto.operator.aggregation.state.AccumulatorState;
import com.facebook.presto.operator.aggregation.state.AccumulatorStateFactory;
import com.facebook.presto.operator.aggregation.state.AccumulatorStateSerializer;
import com.facebook.presto.operator.aggregation.state.StateCompiler;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.type.SqlType;
import com.google.common.base.Predicate;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import javax.annotation.Nullable;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;

import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.fromAnnotation;
import static com.facebook.presto.operator.aggregation.AggregationUtils.getTypeInstance;
import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;

public class AggregationCompiler
{
    private static List<Method> findPublicStaticMethodsWithAnnotation(Class<?> clazz, Class<?> annotationClass)
    {
        ImmutableList.Builder<Method> methods = ImmutableList.builder();
        for (Method method : clazz.getMethods()) {
            for (Annotation annotation : method.getAnnotations()) {
                if (annotationClass.isInstance(annotation)) {
                    checkArgument(Modifier.isStatic(method.getModifiers()) && Modifier.isPublic(method.getModifiers()), "%s annotated with %s must be static and public", method.getName(), annotationClass.getSimpleName());
                    methods.add(method);
                }
            }
        }
        return methods.build();
    }

    public InternalAggregationFunction generateAggregationFunction(Class<?> clazz)
    {
        List<InternalAggregationFunction> aggregations = generateAggregationFunctions(clazz);
        checkArgument(aggregations.size() == 1, "More than one aggregation function found");
        return aggregations.get(0);
    }

    public InternalAggregationFunction generateAggregationFunction(Class<?> clazz, Type returnType, List<Type> argumentTypes)
    {
        checkNotNull(returnType, "returnType is null");
        checkNotNull(argumentTypes, "argumentTypes is null");
        for (InternalAggregationFunction aggregation : generateAggregationFunctions(clazz)) {
            if (aggregation.getFinalType() == returnType && aggregation.getParameterTypes().equals(argumentTypes)) {
                return aggregation;
            }
        }
        throw new IllegalArgumentException(String.format("No method with return type %s and arguments %s", returnType, argumentTypes));
    }

    public List<InternalAggregationFunction> generateAggregationFunctions(Class<?> clazz)
    {
        AggregationFunction aggregationAnnotation = clazz.getAnnotation(AggregationFunction.class);
        checkNotNull(aggregationAnnotation, "aggregationAnnotation is null");

        DynamicClassLoader classLoader = new DynamicClassLoader(clazz.getClassLoader());

        ImmutableList.Builder<InternalAggregationFunction> builder = ImmutableList.builder();
        for (Class<?> stateClass : getStateClasses(clazz)) {
            AccumulatorStateSerializer<?> stateSerializer = new StateCompiler().generateStateSerializer(stateClass, classLoader);
            Type intermediateType = stateSerializer.getSerializedType();
            Method intermediateInputFunction = getIntermediateInputFunction(clazz, stateClass);
            Method combineFunction = getCombineFunction(clazz, stateClass);
            AccumulatorStateFactory<?> stateFactory = new StateCompiler().generateStateFactory(stateClass, classLoader);

            for (Method outputFunction : getOutputFunctions(clazz, stateClass)) {
                for (Method inputFunction : getInputFunctions(clazz, stateClass)) {
                    for (String name : getNames(outputFunction, aggregationAnnotation)) {
                        List<Type> inputTypes = getInputTypes(inputFunction);
                        Type outputType = AggregationUtils.getOutputType(outputFunction, stateSerializer);

                        AggregationMetadata metadata = new AggregationMetadata(
                                generateAggregationName(name, outputType, inputTypes),
                                getParameterMetadata(inputFunction, aggregationAnnotation.approximate()),
                                inputFunction,
                                getParameterMetadata(intermediateInputFunction, false),
                                intermediateInputFunction,
                                combineFunction,
                                outputFunction,
                                stateClass,
                                stateSerializer,
                                stateFactory,
                                outputType,
                                aggregationAnnotation.approximate(),
                                false);

                        AccumulatorFactory factory = new AccumulatorCompiler().generateAccumulatorFactory(metadata, classLoader);
                        // TODO: support un-decomposable aggregations
                        builder.add(new GenericAggregationFunction(name, inputTypes, intermediateType, outputType, true, aggregationAnnotation.approximate(), factory));
                    }
                }
            }
        }

        return builder.build();
    }

    private static List<ParameterMetadata> getParameterMetadata(@Nullable Method method, boolean sampleWeightAllowed)
    {
        if (method == null) {
            return null;
        }

        ImmutableList.Builder<ParameterMetadata> builder = ImmutableList.builder();
        builder.add(new ParameterMetadata(STATE));

        Annotation[][] annotations = method.getParameterAnnotations();
        // Start at 1 because 0 is the STATE
        for (int i = 1; i < annotations.length; i++) {
            Annotation foundAnnotation = null;
            for (Annotation annotation : annotations[i]) {
                if (annotation instanceof SqlType || annotation instanceof BlockIndex || (sampleWeightAllowed && annotation instanceof SampleWeight)) {
                    checkState(foundAnnotation == null, "%s may only have one of @SqlType, @BlockIndex, @SampleWeight per parameter", method.getName());
                    foundAnnotation = annotation;
                }
            }
            checkNotNull(foundAnnotation, "Parameter %d of %s must have @SqlType, @BlockIndex, or @SampleWeight", i, method.getName());
            builder.add(fromAnnotation(foundAnnotation));
        }
        return builder.build();
    }

    private static List<String> getNames(@Nullable Method outputFunction, AggregationFunction aggregationAnnotation)
    {
        List<String> defaultNames = ImmutableList.<String>builder().add(aggregationAnnotation.value()).addAll(Arrays.asList(aggregationAnnotation.alias())).build();

        if (outputFunction == null) {
            return defaultNames;
        }

        AggregationFunction annotation = outputFunction.getAnnotation(AggregationFunction.class);
        if (annotation == null) {
            return defaultNames;
        }
        else {
            return ImmutableList.<String>builder().add(annotation.value()).addAll(Arrays.asList(annotation.alias())).build();
        }
    }

    private static Method getIntermediateInputFunction(Class<?> clazz, Class<?> stateClass)
    {
        for (Method method : findPublicStaticMethodsWithAnnotation(clazz, IntermediateInputFunction.class)) {
            if (method.getParameterTypes()[0] == stateClass) {
                return method;
            }
        }
        return null;
    }

    private static Method getCombineFunction(Class<?> clazz, Class<?> stateClass)
    {
        for (Method method : findPublicStaticMethodsWithAnnotation(clazz, CombineFunction.class)) {
            if (method.getParameterTypes()[0] == stateClass) {
                return method;
            }
        }
        return null;
    }

    private static List<Method> getOutputFunctions(Class<?> clazz, final Class<?> stateClass)
    {
        ImmutableList<Method> methods = FluentIterable.from(findPublicStaticMethodsWithAnnotation(clazz, OutputFunction.class))
                .filter(new Predicate<Method>()
                {
                    @Override
                    public boolean apply(Method method)
                    {
                        // Filter out methods that don't match this state class
                        return method.getParameterTypes()[0] == stateClass;
                    }
                }).toList();
        if (methods.isEmpty()) {
            List<Method> noOutputFunction = new ArrayList<>();
            noOutputFunction.add(null);
            return noOutputFunction;
        }
        return methods;
    }

    private static List<Method> getInputFunctions(Class<?> clazz, final Class<?> stateClass)
    {
        List<Method> inputFunctions = FluentIterable.from(findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class))
                .filter(new Predicate<Method>()
                {
                    @Override
                    public boolean apply(Method method)
                    {
                        // Filter out methods that don't match this state class
                        return method.getParameterTypes()[0] == stateClass;
                    }
                }).toList();
        checkArgument(!inputFunctions.isEmpty(), "Aggregation has no input functions");
        return inputFunctions;
    }

    private static List<Type> getInputTypes(Method inputFunction)
    {
        ImmutableList.Builder<Type> builder = ImmutableList.builder();
        Annotation[][] parameterAnnotations = inputFunction.getParameterAnnotations();
        for (Annotation[] annotations : parameterAnnotations) {
            for (Annotation annotation : annotations) {
                if (annotation instanceof SqlType) {
                    builder.add(getTypeInstance(((SqlType) annotation).value()));
                }
            }
        }

        ImmutableList<Type> types = builder.build();
        checkArgument(!types.isEmpty(), "Input function has no parameters");
        return types;
    }

    private static Set<Class<?>> getStateClasses(Class<?> clazz)
    {
        ImmutableSet.Builder<Class<?>> builder = ImmutableSet.builder();
        for (Method inputFunction : findPublicStaticMethodsWithAnnotation(clazz, InputFunction.class)) {
            checkArgument(inputFunction.getParameterTypes().length > 0, "Input function has no parameters");
            Class<?> stateClass = inputFunction.getParameterTypes()[0];
            checkArgument(AccumulatorState.class.isAssignableFrom(stateClass), "stateClass is not a subclass of AccumulatorState");
            builder.add(stateClass);
        }
        ImmutableSet<Class<?>> stateClasses = builder.build();
        checkArgument(!stateClasses.isEmpty(), "No input functions found");

        return stateClasses;
    }

    private static DynamicClassLoader createClassLoader()
    {
        return new DynamicClassLoader(AggregationCompiler.class.getClassLoader());
    }
}
TOP

Related Classes of com.facebook.presto.operator.aggregation.AggregationCompiler

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.