/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.metadata;
import com.facebook.presto.operator.Description;
import com.facebook.presto.operator.aggregation.GenericAggregationFunctionFactory;
import com.facebook.presto.operator.aggregation.InternalAggregationFunction;
import com.facebook.presto.operator.scalar.JsonPath;
import com.facebook.presto.operator.scalar.ScalarFunction;
import com.facebook.presto.operator.scalar.ScalarOperator;
import com.facebook.presto.operator.window.ReflectionWindowFunctionSupplier;
import com.facebook.presto.operator.window.WindowFunction;
import com.facebook.presto.operator.window.WindowFunctionSupplier;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.type.SqlType;
import com.google.common.base.Throwables;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.primitives.Primitives;
import io.airlift.slice.Slice;
import org.joni.Regex;
import javax.annotation.Nullable;
import java.lang.annotation.Annotation;
import java.lang.invoke.MethodHandle;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.regex.Pattern;
import static com.facebook.presto.metadata.FunctionRegistry.operatorInfo;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.type.TypeUtils.nameGetter;
import static com.facebook.presto.type.TypeUtils.resolveTypes;
import static com.google.common.base.CaseFormat.LOWER_CAMEL;
import static com.google.common.base.CaseFormat.LOWER_UNDERSCORE;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static java.lang.invoke.MethodHandles.lookup;
/**
 * Fluent builder that accumulates {@link ParametricFunction} definitions
 * (window functions, aggregations, scalar functions and scalar operators)
 * and hands them back as an immutable list via {@link #getFunctions()}.
 *
 * <p>Scalar functions and operators are discovered reflectively: every public
 * method of a class passed to {@link #scalar(Class)} that carries
 * {@link ScalarFunction} or {@link ScalarOperator} is validated and registered.
 */
public class FunctionListBuilder
{
    /** Java types on which a {@code @Nullable} parameter annotation is accepted. */
    private static final Set<Class<?>> NULLABLE_ARGUMENT_TYPES = ImmutableSet.<Class<?>>of(Boolean.class, Long.class, Double.class, Slice.class);

    /** Java types accepted as parameters of an annotated scalar method. */
    private static final Set<Class<?>> SUPPORTED_TYPES = ImmutableSet.of(
            long.class,
            Long.class,
            double.class,
            Double.class,
            Slice.class,
            boolean.class,
            Boolean.class,
            Pattern.class,
            Regex.class,
            JsonPath.class);

    /** Java types (after unwrapping boxed primitives) accepted as return types of an annotated scalar method. */
    private static final Set<Class<?>> SUPPORTED_RETURN_TYPES = ImmutableSet.of(
            long.class,
            double.class,
            Slice.class,
            boolean.class,
            int.class,
            Pattern.class,
            Regex.class,
            JsonPath.class);

    private final List<ParametricFunction> functions = new ArrayList<>();
    private final TypeManager typeManager;

    /**
     * @param typeManager used to resolve {@link SqlType} names to {@link Type} instances; must not be null
     */
    public FunctionListBuilder(TypeManager typeManager)
    {
        this.typeManager = checkNotNull(typeManager, "typeManager is null");
    }

    /**
     * Adds a window function backed by {@code functionClass}.
     *
     * @param name function name used in the signature
     * @param returnType SQL return type
     * @param argumentTypes SQL argument types, in order
     * @param functionClass implementation class handed to {@link ReflectionWindowFunctionSupplier}
     * @return this builder
     */
    public FunctionListBuilder window(String name, Type returnType, List<? extends Type> argumentTypes, Class<? extends WindowFunction> functionClass)
    {
        Signature signature = new Signature(name, returnType.getName(), Lists.transform(ImmutableList.copyOf(argumentTypes), nameGetter()));
        WindowFunctionSupplier windowFunctionSupplier = new ReflectionWindowFunctionSupplier<>(signature, functionClass);
        functions.add(new FunctionInfo(windowFunctionSupplier.getSignature(), windowFunctionSupplier.getDescription(), windowFunctionSupplier));
        return this;
    }

    /**
     * Adds each aggregation in {@code functions}, in iteration order.
     *
     * @return this builder
     */
    public FunctionListBuilder aggregate(List<InternalAggregationFunction> functions)
    {
        for (InternalAggregationFunction function : functions) {
            aggregate(function);
        }
        return this;
    }

    /**
     * Adds a single aggregation function. The function name is lower-cased, and the
     * description is read from a {@link Description} annotation on the function's class, if any.
     *
     * @return this builder
     */
    public FunctionListBuilder aggregate(InternalAggregationFunction function)
    {
        // Lower-case with a fixed locale so the registered name does not depend on the JVM default locale
        String name = function.name().toLowerCase(Locale.ENGLISH);
        String description = getDescription(function.getClass());
        Signature signature = new Signature(name, function.getFinalType().getName(), Lists.transform(ImmutableList.copyOf(function.getParameterTypes()), nameGetter()));
        functions.add(new FunctionInfo(signature, description, function.getIntermediateType().getName(), function, function.isApproximate()));
        return this;
    }

    /**
     * Adds every function that {@link GenericAggregationFunctionFactory} derives from
     * the given aggregation definition class.
     *
     * @return this builder
     */
    public FunctionListBuilder aggregate(Class<?> aggregationDefinition)
    {
        functions.addAll(GenericAggregationFunctionFactory.fromAggregationDefinition(aggregationDefinition, typeManager).listFunctions());
        return this;
    }

    /**
     * Adds a scalar function from an already-resolved signature and method handle.
     * No validation is performed here; the arguments are passed straight to {@link FunctionInfo}.
     *
     * @return this builder
     */
    public FunctionListBuilder scalar(Signature signature, MethodHandle function, boolean deterministic, String description, boolean hidden, boolean nullable, List<Boolean> nullableArguments)
    {
        functions.add(new FunctionInfo(signature, description, hidden, function, deterministic, nullable, nullableArguments));
        return this;
    }

    /**
     * Adds an operator implementation built via {@link FunctionRegistry#operatorInfo}.
     *
     * @return this builder
     */
    private FunctionListBuilder operator(OperatorType operatorType, Type returnType, List<Type> parameterTypes, MethodHandle function, boolean nullable, List<Boolean> nullableArguments)
    {
        FunctionInfo operatorInfo = operatorInfo(operatorType, returnType.getName(), Lists.transform(parameterTypes, nameGetter()), function, nullable, nullableArguments);
        functions.add(operatorInfo);
        return this;
    }

    /**
     * Scans the public methods of {@code clazz} and registers every method annotated with
     * {@link ScalarFunction} and/or {@link ScalarOperator}. A method carrying both
     * annotations is registered once as a function and once as an operator.
     *
     * @return this builder
     * @throws IllegalArgumentException if no annotated method is found, or an annotated method is invalid
     */
    public FunctionListBuilder scalar(Class<?> clazz)
    {
        try {
            boolean foundOne = false;
            for (Method method : clazz.getMethods()) {
                // Both processors must run for every method, so the call goes on the left of ||
                foundOne = processScalarFunction(method) || foundOne;
                foundOne = processScalarOperator(method) || foundOne;
            }
            checkArgument(foundOne, "Expected class %s to contain at least one method annotated with @%s", clazz.getName(), ScalarFunction.class.getSimpleName());
        }
        catch (IllegalAccessException e) {
            throw Throwables.propagate(e);
        }
        return this;
    }

    /**
     * Adds a pre-built parametric function as-is.
     *
     * @return this builder
     * @throws NullPointerException if {@code parametricFunction} is null
     */
    public FunctionListBuilder parametricScalar(ParametricFunction parametricFunction)
    {
        checkNotNull(parametricFunction, "parametricFunction is null");
        functions.add(parametricFunction);
        return this;
    }

    /**
     * Registers {@code method} as a scalar function if it is annotated with {@link ScalarFunction}.
     * The function name comes from the annotation value, or from the snake_case form of the
     * method name when the value is empty; each alias is registered with the same handle.
     *
     * @return true if the method was annotated and registered, false if it carries no {@link ScalarFunction}
     */
    private boolean processScalarFunction(Method method)
            throws IllegalAccessException
    {
        ScalarFunction scalarFunction = method.getAnnotation(ScalarFunction.class);
        if (scalarFunction == null) {
            return false;
        }
        checkValidMethod(method);
        MethodHandle methodHandle = lookup().unreflect(method);

        String name = scalarFunction.value();
        if (name.isEmpty()) {
            name = camelToSnake(method.getName());
        }

        SqlType returnTypeAnnotation = method.getAnnotation(SqlType.class);
        checkArgument(returnTypeAnnotation != null, "Method %s return type does not have a @SqlType annotation", method);
        Type returnType = type(typeManager, returnTypeAnnotation);

        // Fixed locale: function names must not vary with the JVM default locale
        Signature signature = new Signature(name.toLowerCase(Locale.ENGLISH), returnType.getName(), Lists.transform(parameterTypes(typeManager, method), nameGetter()));
        verifyMethodSignature(method, signature.getReturnType(), signature.getArgumentTypes(), typeManager);

        List<Boolean> nullableArguments = getNullableArguments(method);

        // These are identical for the primary name and every alias, so compute them once
        String description = getDescription(method);
        boolean deterministic = scalarFunction.deterministic();
        boolean hidden = scalarFunction.hidden();
        boolean nullable = method.isAnnotationPresent(Nullable.class);

        scalar(signature, methodHandle, deterministic, description, hidden, nullable, nullableArguments);
        for (String alias : scalarFunction.alias()) {
            scalar(signature.withAlias(alias.toLowerCase(Locale.ENGLISH)), methodHandle, deterministic, description, hidden, nullable, nullableArguments);
        }
        return true;
    }

    /**
     * Resolves the type named by a {@link SqlType} annotation.
     *
     * @throws NullPointerException if the type manager does not know the type name
     */
    private static Type type(TypeManager typeManager, SqlType explicitType)
    {
        Type type = typeManager.getType(explicitType.value());
        checkNotNull(type, "No type found for '%s'", explicitType.value());
        return type;
    }

    /**
     * Resolves the SQL types of the parameters of {@code method} from their {@link SqlType}
     * annotations. {@link ConnectorSession} parameters are skipped.
     *
     * @throws IllegalArgumentException if a non-session parameter has no {@link SqlType} annotation
     */
    private static List<Type> parameterTypes(TypeManager typeManager, Method method)
    {
        // getParameterTypes() returns a fresh array on each call, so fetch it once
        Class<?>[] javaTypes = method.getParameterTypes();
        Annotation[][] parameterAnnotations = method.getParameterAnnotations();

        ImmutableList.Builder<Type> types = ImmutableList.builder();
        for (int i = 0; i < javaTypes.length; i++) {
            // skip session parameters
            if (javaTypes[i] == ConnectorSession.class) {
                continue;
            }

            // find the explicit type annotation if present
            SqlType explicitType = null;
            for (Annotation annotation : parameterAnnotations[i]) {
                if (annotation instanceof SqlType) {
                    explicitType = (SqlType) annotation;
                    break;
                }
            }
            checkArgument(explicitType != null, "Method %s argument %s does not have a @SqlType annotation", method, i);
            types.add(type(typeManager, explicitType));
        }
        return types.build();
    }

    /**
     * Checks that the Java signature of {@code method} matches the declared SQL signature:
     * the (unwrapped) Java return and parameter types must equal the Java types of the
     * corresponding SQL types, boxed parameters must be {@code @Nullable}, and
     * {@code @Nullable} is only allowed on the types in {@link #NULLABLE_ARGUMENT_TYPES}.
     * A leading {@link ConnectorSession} parameter is ignored.
     *
     * @throws IllegalArgumentException if any of the checks fails
     */
    private static void verifyMethodSignature(Method method, String returnTypeName, List<String> argumentTypeNames, TypeManager typeManager)
    {
        Type returnType = typeManager.getType(returnTypeName);
        checkNotNull(returnType, "returnType is null");
        List<Type> argumentTypes = resolveTypes(argumentTypeNames, typeManager);

        checkArgument(Primitives.unwrap(method.getReturnType()) == returnType.getJavaType(),
                "Expected method %s return type to be %s (%s)",
                method,
                returnType.getJavaType().getName(),
                returnType);

        // skip Session argument
        Class<?>[] parameterTypes = method.getParameterTypes();
        Annotation[][] annotations = method.getParameterAnnotations();
        if (parameterTypes.length > 0 && parameterTypes[0] == ConnectorSession.class) {
            parameterTypes = Arrays.copyOfRange(parameterTypes, 1, parameterTypes.length);
            annotations = Arrays.copyOfRange(annotations, 1, annotations.length);
        }

        for (int i = 0; i < parameterTypes.length; i++) {
            Class<?> actualType = parameterTypes[i];
            Type expectedType = argumentTypes.get(i);
            boolean nullable = !FluentIterable.from(Arrays.asList(annotations[i])).filter(Nullable.class).isEmpty();

            // Only allow boxing for functions that need to see nulls
            if (Primitives.isWrapperType(actualType)) {
                checkArgument(nullable, "Method %s has parameter with type %s that is missing @Nullable", method, actualType);
            }
            if (nullable) {
                checkArgument(NULLABLE_ARGUMENT_TYPES.contains(actualType), "Method %s has parameter type %s, but @Nullable is not supported on this type", method, actualType);
            }

            checkArgument(Primitives.unwrap(actualType) == expectedType.getJavaType(),
                    "Expected method %s parameter %s type to be %s (%s)",
                    method,
                    i,
                    expectedType.getJavaType().getName(),
                    expectedType);
        }
    }

    /**
     * Returns, for each parameter of {@code method} that carries a {@link SqlType}
     * annotation, whether it is also annotated with {@link Nullable}. Parameters
     * without {@link SqlType} contribute no entry.
     */
    private static List<Boolean> getNullableArguments(Method method)
    {
        List<Boolean> nullableArguments = new ArrayList<>();
        for (Annotation[] annotations : method.getParameterAnnotations()) {
            boolean nullable = false;
            boolean foundSqlType = false;
            for (Annotation annotation : annotations) {
                if (annotation instanceof Nullable) {
                    nullable = true;
                }
                if (annotation instanceof SqlType) {
                    foundSqlType = true;
                }
            }
            // Check that this is a real argument. For example, some functions take ConnectorSession which isn't a SqlType
            if (foundSqlType) {
                nullableArguments.add(nullable);
            }
        }
        return nullableArguments;
    }

    /**
     * Registers {@code method} as an operator implementation if it is annotated with
     * {@link ScalarOperator}.
     *
     * @return true if the method was annotated and registered, false if it carries no {@link ScalarOperator}
     */
    private boolean processScalarOperator(Method method)
            throws IllegalAccessException
    {
        ScalarOperator scalarOperator = method.getAnnotation(ScalarOperator.class);
        if (scalarOperator == null) {
            return false;
        }
        checkValidMethod(method);
        MethodHandle methodHandle = lookup().unreflect(method);
        OperatorType operatorType = scalarOperator.value();

        List<Type> parameterTypes = parameterTypes(typeManager, method);

        Type returnType;
        if (operatorType == OperatorType.HASH_CODE) {
            // todo hack for hashCode... should be int
            // Note: the return type is forced to BIGINT and the Java signature is not verified on this path
            returnType = BIGINT;
        }
        else {
            SqlType explicitType = method.getAnnotation(SqlType.class);
            checkArgument(explicitType != null, "Method %s return type does not have a @SqlType annotation", method);
            returnType = type(typeManager, explicitType);

            verifyMethodSignature(method, returnType.getName(), Lists.transform(parameterTypes, nameGetter()), typeManager);
        }

        List<Boolean> nullableArguments = getNullableArguments(method);

        operator(operatorType, returnType, parameterTypes, methodHandle, method.isAnnotationPresent(Nullable.class), nullableArguments);
        return true;
    }

    /**
     * Returns the value of the {@link Description} annotation on the element, or null when absent.
     */
    private static String getDescription(AnnotatedElement annotatedElement)
    {
        Description description = annotatedElement.getAnnotation(Description.class);
        return (description == null) ? null : description.value();
    }

    /** Converts a lowerCamelCase name to lower_underscore form. */
    private static String camelToSnake(String name)
    {
        return LOWER_CAMEL.to(LOWER_UNDERSCORE, name);
    }

    /**
     * Validates the Java shape of an annotated method: it must be static, have a supported
     * return type, agree with {@code @Nullable} on boxed vs. primitive return, and use only
     * supported parameter types (a leading {@link ConnectorSession} is ignored).
     *
     * @throws IllegalArgumentException if any of the checks fails
     */
    private static void checkValidMethod(Method method)
    {
        String message = "@ScalarFunction method %s is not valid: ";

        checkArgument(Modifier.isStatic(method.getModifiers()), message + "must be static", method);

        checkArgument(SUPPORTED_RETURN_TYPES.contains(Primitives.unwrap(method.getReturnType())), message + "return type not supported", method);
        if (method.getAnnotation(Nullable.class) != null) {
            checkArgument(!method.getReturnType().isPrimitive(), message + "annotated with @Nullable but has primitive return type", method);
        }
        else {
            // The shared prefix carries the %s placeholder; without it the method would not be substituted into the text
            checkArgument(!Primitives.isWrapperType(method.getReturnType()), message + "not annotated with @Nullable but has boxed primitive return type", method);
        }

        for (Class<?> type : getParameterTypes(method.getParameterTypes())) {
            checkArgument(SUPPORTED_TYPES.contains(type), message + "parameter type [%s] not supported", method, type.getName());
        }
    }

    /**
     * Returns the given parameter types as a list, dropping a leading {@link ConnectorSession} if present.
     */
    private static List<Class<?>> getParameterTypes(Class<?>... types)
    {
        ImmutableList<Class<?>> parameterTypes = ImmutableList.copyOf(types);
        if (!parameterTypes.isEmpty() && parameterTypes.get(0) == ConnectorSession.class) {
            parameterTypes = parameterTypes.subList(1, parameterTypes.size());
        }
        return parameterTypes;
    }

    /**
     * Returns an immutable snapshot of the functions added so far, in insertion order.
     */
    public List<ParametricFunction> getFunctions()
    {
        return ImmutableList.copyOf(functions);
    }
}