/*******************************************************************************
* Copyright (c) 2013, Salesforce.com, Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
* Neither the name of Salesforce.com nor the names of its contributors may
* be used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************************************************************/
package com.salesforce.phoenix.parse;
import java.lang.reflect.Constructor;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.hbase.filter.CompareFilter.CompareOp;
import org.apache.hadoop.hbase.util.Pair;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Maps;
import com.salesforce.phoenix.exception.UnknownFunctionException;
import com.salesforce.phoenix.expression.Expression;
import com.salesforce.phoenix.expression.ExpressionType;
import com.salesforce.phoenix.expression.function.AvgAggregateFunction;
import com.salesforce.phoenix.expression.function.CountAggregateFunction;
import com.salesforce.phoenix.expression.function.CurrentDateFunction;
import com.salesforce.phoenix.expression.function.CurrentTimeFunction;
import com.salesforce.phoenix.expression.function.DistinctCountAggregateFunction;
import com.salesforce.phoenix.expression.function.FunctionExpression;
import com.salesforce.phoenix.parse.FunctionParseNode.BuiltInFunction;
import com.salesforce.phoenix.parse.FunctionParseNode.BuiltInFunctionInfo;
import com.salesforce.phoenix.parse.JoinTableNode.JoinType;
import com.salesforce.phoenix.schema.ColumnModifier;
import com.salesforce.phoenix.schema.PDataType;
import com.salesforce.phoenix.schema.PIndexState;
import com.salesforce.phoenix.schema.PTableType;
import com.salesforce.phoenix.schema.TypeMismatchException;
import com.salesforce.phoenix.util.SchemaUtil;
/**
 * Factory used by the parser to construct the object model while parsing a SQL statement.
 * <p>
 * Besides the registry of built-in functions (see {@link #get(String, List)}), the methods of this
 * class are thin wrappers that construct the corresponding parse node or statement object.
 *
 * @author jtaylor
 * @since 0.1
 */
public class ParseNodeFactory {
    /** Name of the built-in function used for array element references; see {@link #arrayElemRef(List)}. */
    private static final String ARRAY_ELEM = "ARRAY_ELEM";

    // TODO: Use Google's Reflection library instead to find aggregate functions
    /**
     * Function classes registered before the ones discovered through {@link ExpressionType}
     * when the built-in function map is initialized.
     */
    @SuppressWarnings("unchecked")
    private static final List<Class<? extends FunctionExpression>> CLIENT_SIDE_BUILT_IN_FUNCTIONS = Arrays.<Class<? extends FunctionExpression>>asList(
        CurrentDateFunction.class,
        CurrentTimeFunction.class,
        AvgAggregateFunction.class
    );

    /**
     * Registry of built-in functions keyed by normalized name and argument count. It is written only
     * from {@code initBuiltInFunctionMap()}, which is {@code static synchronized}.
     */
    private static final Map<BuiltInFunctionKey, BuiltInFunctionInfo> BUILT_IN_FUNCTION_MAP = Maps.newHashMap();

    /**
     * Key used to look up a built-in function using the combination of its normalized name and
     * the number of arguments. This disambiguates the aggregate {@code MAX(col)} from the
     * non-aggregate {@code MAX(col1, col2)}.
     *
     * @author jtaylor
     * @since 0.1
     */
    private static final class BuiltInFunctionKey {
        private final String upperName;
        private final int argCount;

        private BuiltInFunctionKey(String normalizedName, int argCount) {
            this.upperName = normalizedName;
            this.argCount = argCount;
        }

        @Override
        public String toString() {
            return upperName;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + argCount;
            result = prime * result + ((upperName == null) ? 0 : upperName.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) return true;
            if (obj == null) return false;
            if (getClass() != obj.getClass()) return false;
            BuiltInFunctionKey other = (BuiltInFunctionKey)obj;
            if (argCount != other.argCount) return false;
            // Null-safe, to stay consistent with hashCode(), which tolerates a null name.
            return upperName == null ? other.upperName == null : upperName.equals(other.upperName);
        }
    }

    /**
     * Registers function class {@code f} under every argument count it accepts. Classes without a
     * {@link BuiltInFunction} annotation are ignored.
     * <p>
     * One entry is added for the full argument list. Further entries are added for each shorter list
     * obtained by dropping trailing arguments that declare a non-empty default value.
     *
     * @param f the function expression class to register
     * @throws IllegalStateException if a function with the same name and argument count is already
     *         registered, or if a non-trailing argument declares a default value
     * @throws Exception declared so that failures while building the {@link BuiltInFunctionInfo}
     *         propagate to the caller
     */
    private static void addBuiltInFunction(Class<? extends FunctionExpression> f) throws Exception {
        BuiltInFunction d = f.getAnnotation(BuiltInFunction.class);
        if (d == null) {
            return;
        }
        int nArgs = d.args().length;
        BuiltInFunctionInfo value = new BuiltInFunctionInfo(f, d);
        do {
            // Add function to function map, throwing if conflicts found.
            // Add an entry for each possible version of the function based on arguments that are not
            // required to be present (i.e. args with a default value).
            BuiltInFunctionKey key = new BuiltInFunctionKey(value.getName(), nArgs);
            if (BUILT_IN_FUNCTION_MAP.put(key, value) != null) {
                throw new IllegalStateException("Multiple " + value.getName() + " functions with " + nArgs + " arguments");
            }
        } while (--nArgs >= 0 && d.args()[nArgs].defaultValue().length() > 0);
        // nArgs now indexes the last argument without a default value (or is negative).
        // Any default value before that position is non-trailing, which is not supported.
        while (--nArgs >= 0) {
            if (d.args()[nArgs].defaultValue().length() > 0) {
                throw new IllegalStateException("Function " + value.getName() + " has non trailing default value of '" + d.args()[nArgs].defaultValue() + "'. Only trailing arguments may have default values");
            }
        }
    }

    /**
     * Reflects over the known function classes and populates the built-in function map. This is
     * not done in a static initializer because of a circular dependency. It does nothing if the
     * map is already populated.
     * <p>
     * If registration fails for any reason, the map is cleared before the failure propagates. The
     * next call then retries, rather than treating a partially filled map as complete and hiding
     * the original error.
     *
     * @throws RuntimeException wrapping any exception raised while registering a function class
     */
    private synchronized static void initBuiltInFunctionMap() {
        if (!BUILT_IN_FUNCTION_MAP.isEmpty()) {
            return;
        }
        // Tracks the class being registered so a failure can be attributed to it.
        Class<? extends FunctionExpression> f = null;
        boolean initialized = false;
        try {
            // Reflection based parsing which yields direct explicit function evaluation at runtime
            for (Class<? extends FunctionExpression> clientSideFunction : CLIENT_SIDE_BUILT_IN_FUNCTIONS) {
                f = clientSideFunction;
                addBuiltInFunction(f);
            }
            for (ExpressionType et : ExpressionType.values()) {
                Class<? extends Expression> ec = et.getExpressionClass();
                if (FunctionExpression.class.isAssignableFrom(ec)) {
                    @SuppressWarnings("unchecked")
                    Class<? extends FunctionExpression> c = (Class<? extends FunctionExpression>)ec;
                    f = c;
                    addBuiltInFunction(f);
                }
            }
            initialized = true;
        } catch (Exception e) {
            throw new RuntimeException("Failed initialization of built-in functions at class '" + f + "'", e);
        } finally {
            // Covers Errors as well as Exceptions: never publish a partially built registry.
            if (!initialized) {
                BUILT_IN_FUNCTION_MAP.clear();
            }
        }
    }

    /**
     * Looks up a built-in function after normalizing {@code name} with
     * {@link SchemaUtil#normalizeIdentifier(String)}.
     */
    private static BuiltInFunctionInfo getInfo(String name, List<ParseNode> children) {
        return get(SchemaUtil.normalizeIdentifier(name), children);
    }

    /**
     * Returns the built-in function registered under {@code normalizedName} that accepts
     * {@code children.size()} arguments. The registry is initialized on first use.
     *
     * @param normalizedName the function name, already normalized
     * @param children the argument nodes; only their count is used for the lookup
     * @return the matching function info, never {@code null}
     * @throws UnknownFunctionException if no function matches the name and argument count
     */
    public static BuiltInFunctionInfo get(String normalizedName, List<ParseNode> children) {
        initBuiltInFunctionMap();
        BuiltInFunctionInfo info = BUILT_IN_FUNCTION_MAP.get(new BuiltInFunctionKey(normalizedName, children.size()));
        if (info == null) {
            throw new UnknownFunctionException(normalizedName);
        }
        return info;
    }

    /** Creates a new factory. */
    public ParseNodeFactory() {
    }

    /** Wraps {@code statement} in an {@link ExplainStatement}. */
    public ExplainStatement explain(BindableStatement statement) {
        return new ExplainStatement(statement);
    }

    /** Creates an {@link AliasedNode} pairing {@code expression} with {@code alias}. */
    public AliasedNode aliasedNode(String alias, ParseNode expression) {
        return new AliasedNode(alias, expression);
    }

    /** Creates an {@link AddParseNode} over {@code children}. */
    public AddParseNode add(List<ParseNode> children) {
        return new AddParseNode(children);
    }

    /** Creates a {@link SubtractParseNode} over {@code children}. */
    public SubtractParseNode subtract(List<ParseNode> children) {
        return new SubtractParseNode(children);
    }

    /** Creates a {@link MultiplyParseNode} over {@code children}. */
    public MultiplyParseNode multiply(List<ParseNode> children) {
        return new MultiplyParseNode(children);
    }

    /** Creates an {@link AndParseNode} over {@code children}. */
    public AndParseNode and(List<ParseNode> children) {
        return new AndParseNode(children);
    }

    /** Creates a {@link FamilyWildcardParseNode} for {@code familyName}, passing {@code false} as its flag. */
    public FamilyWildcardParseNode family(String familyName) {
        return new FamilyWildcardParseNode(familyName, false);
    }

    /** Returns the shared {@link WildcardParseNode#INSTANCE}. */
    public WildcardParseNode wildcard() {
        return WildcardParseNode.INSTANCE;
    }

    /** Creates a {@link BetweenParseNode} for {@code l} between {@code r1} and {@code r2}, optionally negated. */
    public BetweenParseNode between(ParseNode l, ParseNode r1, ParseNode r2, boolean negate) {
        return new BetweenParseNode(l, r1, r2, negate);
    }

    /** Creates a {@link BindParseNode} for bind parameter {@code bind}. */
    public BindParseNode bind(String bind) {
        return new BindParseNode(bind);
    }

    /** Creates a {@link StringConcatParseNode} over {@code children}. */
    public StringConcatParseNode concat(List<ParseNode> children) {
        return new StringConcatParseNode(children);
    }

    /** Creates a {@link ColumnParseNode} referencing column {@code name} of {@code tableName}. */
    public ColumnParseNode column(TableName tableName, String name, String alias) {
        return new ColumnParseNode(tableName, name, alias);
    }

    /** Creates a {@link ColumnName} without a family. */
    public ColumnName columnName(String columnName) {
        return new ColumnName(columnName);
    }

    /** Creates a {@link ColumnName} qualified by {@code familyName}. */
    public ColumnName columnName(String familyName, String columnName) {
        return new ColumnName(familyName, columnName);
    }

    /** Creates a {@link PropertyName} without a family. */
    public PropertyName propertyName(String propertyName) {
        return new PropertyName(propertyName);
    }

    /** Creates a {@link PropertyName} qualified by {@code familyName}. */
    public PropertyName propertyName(String familyName, String propertyName) {
        return new PropertyName(familyName, propertyName);
    }

    /** Creates a {@link ColumnDef}. */
    public ColumnDef columnDef(ColumnName columnDefName, String sqlTypeName, boolean isNull, Integer maxLength, Integer scale, boolean isPK, ColumnModifier columnModifier) {
        return new ColumnDef(columnDefName, sqlTypeName, isNull, maxLength, scale, isPK, columnModifier);
    }

    /** Creates a {@link ColumnDef} that also carries the {@code isArray} flag and {@code arrSize}. */
    public ColumnDef columnDef(ColumnName columnDefName, String sqlTypeName, boolean isArray, Integer arrSize, boolean isNull, Integer maxLength, Integer scale, boolean isPK,
            ColumnModifier columnModifier) {
        return new ColumnDef(columnDefName, sqlTypeName, isArray, arrSize, isNull, maxLength, scale, isPK, columnModifier);
    }

    /** Creates a {@link PrimaryKeyConstraint} named {@code name} over the given columns and their modifiers. */
    public PrimaryKeyConstraint primaryKey(String name, List<Pair<ColumnName, ColumnModifier>> columnNameAndModifier) {
        return new PrimaryKeyConstraint(name, columnNameAndModifier);
    }

    /** Creates a {@link CreateTableStatement}. */
    public CreateTableStatement createTable(TableName tableName, ListMultimap<String,Pair<String,Object>> props, List<ColumnDef> columns, PrimaryKeyConstraint pkConstraint, List<ParseNode> splits, PTableType tableType, boolean ifNotExists, TableName baseTableName, ParseNode tableTypeIdNode, int bindCount) {
        return new CreateTableStatement(tableName, props, columns, pkConstraint, splits, tableType, ifNotExists, baseTableName, tableTypeIdNode, bindCount);
    }

    /** Creates a {@link CreateIndexStatement}. */
    public CreateIndexStatement createIndex(NamedNode indexName, NamedTableNode dataTable, PrimaryKeyConstraint pkConstraint, List<ColumnName> includeColumns, List<ParseNode> splits, ListMultimap<String,Pair<String,Object>> props, boolean ifNotExists, int bindCount) {
        return new CreateIndexStatement(indexName, dataTable, pkConstraint, includeColumns, splits, props, ifNotExists, bindCount);
    }

    /** Creates a {@link CreateSequenceStatement}. */
    public CreateSequenceStatement createSequence(TableName tableName, ParseNode startsWith, ParseNode incrementBy, ParseNode cacheSize, boolean ifNotExists, int bindCount) {
        return new CreateSequenceStatement(tableName, startsWith, incrementBy, cacheSize, ifNotExists, bindCount);
    }

    /** Creates a {@link DropSequenceStatement}. */
    public DropSequenceStatement dropSequence(TableName tableName, boolean ifExists, int bindCount) {
        return new DropSequenceStatement(tableName, ifExists, bindCount);
    }

    /** Creates a {@link SequenceValueParseNode} with operation {@code CURRENT_VALUE}. */
    public SequenceValueParseNode currentValueFor(TableName tableName) {
        return new SequenceValueParseNode(tableName, SequenceValueParseNode.Op.CURRENT_VALUE);
    }

    /** Creates a {@link SequenceValueParseNode} with operation {@code NEXT_VALUE}. */
    public SequenceValueParseNode nextValueFor(TableName tableName) {
        return new SequenceValueParseNode(tableName, SequenceValueParseNode.Op.NEXT_VALUE);
    }

    /** Creates an {@link AddColumnStatement}. */
    public AddColumnStatement addColumn(NamedTableNode table, PTableType tableType, List<ColumnDef> columnDefs, boolean ifNotExists, Map<String,Object> props) {
        return new AddColumnStatement(table, tableType, columnDefs, ifNotExists, props);
    }

    /** Creates a {@link DropColumnStatement}. */
    public DropColumnStatement dropColumn(NamedTableNode table, PTableType tableType, List<ColumnName> columnNodes, boolean ifExists) {
        return new DropColumnStatement(table, tableType, columnNodes, ifExists);
    }

    /** Creates a {@link DropTableStatement}. */
    public DropTableStatement dropTable(TableName tableName, PTableType tableType, boolean ifExists) {
        return new DropTableStatement(tableName, tableType, ifExists);
    }

    /** Creates a {@link DropIndexStatement}. */
    public DropIndexStatement dropIndex(NamedNode indexName, TableName tableName, boolean ifExists) {
        return new DropIndexStatement(indexName, tableName, ifExists);
    }

    /** Creates an {@link AlterIndexStatement} that sets the index to {@code state}. */
    public AlterIndexStatement alterIndex(NamedTableNode indexTableNode, String dataTableName, boolean ifExists, PIndexState state) {
        return new AlterIndexStatement(indexTableNode, dataTableName, ifExists, state);
    }

    /** Creates a {@link TableName} via {@link TableName#createNormalized(String, String)}. */
    public TableName table(String schemaName, String tableName) {
        return TableName.createNormalized(schemaName, tableName);
    }

    /** Creates a {@link NamedNode} for an index name. */
    public NamedNode indexName(String name) {
        return new NamedNode(name);
    }

    /** Creates a {@link NamedTableNode} for table {@code name} with the given alias. */
    public NamedTableNode namedTable(String alias, TableName name) {
        return new NamedTableNode(alias, name);
    }

    /** Creates a {@link NamedTableNode} for table {@code name} that also carries dynamic column definitions. */
    public NamedTableNode namedTable(String alias, TableName name, List<ColumnDef> dynColumns) {
        return new NamedTableNode(alias, name, dynColumns);
    }

    /** Creates a {@link BindTableNode} for table {@code name} with the given alias. */
    public BindTableNode bindTable(String alias, TableName name) {
        return new BindTableNode(alias, name);
    }

    /** Creates a {@link CaseParseNode} over {@code children}. */
    public CaseParseNode caseWhen(List<ParseNode> children) {
        return new CaseParseNode(children);
    }

    /** Creates a {@link DivideParseNode} over {@code children}. */
    public DivideParseNode divide(List<ParseNode> children) {
        return new DivideParseNode(children);
    }

    /**
     * Creates a function node for a {@code DISTINCT} aggregate call. Only names that normalize to
     * {@code CountAggregateFunction.NAME} are supported. They are resolved against
     * {@link DistinctCountAggregateFunction}.
     *
     * @throws UnsupportedOperationException for any other function name
     * @throws UnknownFunctionException if no distinct-count function accepts {@code args.size()} arguments
     */
    public FunctionParseNode functionDistinct(String name, List<ParseNode> args) {
        if (CountAggregateFunction.NAME.equals(SchemaUtil.normalizeIdentifier(name))) {
            BuiltInFunctionInfo info = getInfo(
                    SchemaUtil.normalizeIdentifier(DistinctCountAggregateFunction.NAME), args);
            return new DistinctCountParseNode(name, args, info);
        } else {
            throw new UnsupportedOperationException("DISTINCT not supported with " + name);
        }
    }

    /** Creates a call to the {@code ARRAY_ELEM} built-in function with {@code args}. */
    public FunctionParseNode arrayElemRef(List<ParseNode> args) {
        return function(ARRAY_ELEM, args);
    }

    /**
     * Creates a parse node for a call to the built-in function {@code name}.
     * <p>
     * If the function declares its own node constructor, that constructor is invoked reflectively
     * with {@code (name, args, info)}. Otherwise an {@link AggregateFunctionParseNode} or a plain
     * {@link FunctionParseNode} is created, depending on {@link BuiltInFunctionInfo#isAggregate()}.
     *
     * @throws UnknownFunctionException if no built-in function matches the name and argument count
     * @throws RuntimeException wrapping any failure of the reflective constructor call
     */
    public FunctionParseNode function(String name, List<ParseNode> args) {
        BuiltInFunctionInfo info = getInfo(name, args);
        Constructor<? extends FunctionParseNode> ctor = info.getNodeCtor();
        if (ctor == null) {
            return info.isAggregate()
                    ? new AggregateFunctionParseNode(name, args, info)
                    : new FunctionParseNode(name, args, info);
        } else {
            try {
                return ctor.newInstance(name, args, info);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Creates a function call whose arguments are, in order: the single column node, a Boolean
     * literal of {@code isAscending}, and the single value node.
     *
     * @throws UnsupportedOperationException unless exactly one value node and one column node are given
     */
    public FunctionParseNode function(String name, List<ParseNode> valueNodes,
            List<ParseNode> columnNodes, boolean isAscending) {
        // Right now we support PERCENT functions on only one column
        if (valueNodes.size() != 1 || columnNodes.size() != 1) {
            throw new UnsupportedOperationException(name + " not supported on multiple columns");
        }
        List<ParseNode> children = new ArrayList<ParseNode>(3);
        children.add(columnNodes.get(0));
        children.add(new LiteralParseNode(Boolean.valueOf(isAscending)));
        children.add(valueNodes.get(0));
        return function(name, children);
    }

    /** Creates a {@link HintNode} from the raw hint text. */
    public HintNode hint(String hint) {
        return new HintNode(hint);
    }

    /** Creates an {@link InListParseNode} over {@code children}, optionally negated. */
    public InListParseNode inList(List<ParseNode> children, boolean negate) {
        return new InListParseNode(children, negate);
    }

    /** Creates an {@link ExistsParseNode}, optionally negated. */
    public ExistsParseNode exists(ParseNode l, ParseNode r, boolean negate) {
        return new ExistsParseNode(l, r, negate);
    }

    /** Creates an {@link InParseNode}, optionally negated. */
    public InParseNode in(ParseNode l, ParseNode r, boolean negate) {
        return new InParseNode(l, r, negate);
    }

    /** Creates an {@link IsNullParseNode} for {@code child}, optionally negated. */
    public IsNullParseNode isNull(ParseNode child, boolean negate) {
        return new IsNullParseNode(child, negate);
    }

    /** Creates a {@link JoinTableNode} of the given join type over {@code table} with condition {@code on}. */
    public JoinTableNode join(JoinType type, ParseNode on, TableNode table) {
        return new JoinTableNode(type, on, table);
    }

    /** Creates a {@link DerivedTableNode} for {@code select} with the given alias. */
    public DerivedTableNode derivedTable(String alias, SelectStatement select) {
        return new DerivedTableNode(alias, select);
    }

    /** Creates a {@link LikeParseNode}, optionally negated. */
    public LikeParseNode like(ParseNode lhs, ParseNode rhs, boolean negate) {
        return new LikeParseNode(lhs, rhs, negate);
    }

    /** Creates a {@link LiteralParseNode} holding {@code value} as-is. */
    public LiteralParseNode literal(Object value) {
        return new LiteralParseNode(value);
    }

    /** Creates a {@link CastParseNode} to the SQL type named {@code dataType}. */
    public CastParseNode cast(ParseNode expression, String dataType) {
        return new CastParseNode(expression, dataType);
    }

    /** Creates a {@link CastParseNode} to {@code dataType}. */
    public CastParseNode cast(ParseNode expression, PDataType dataType) {
        return new CastParseNode(expression, dataType);
    }

    /** Creates a {@link RowValueConstructorParseNode} over {@code l}. */
    public ParseNode rowValueConstructor(List<ParseNode> l) {
        return new RowValueConstructorParseNode(l);
    }

    /**
     * Throws the exception built by {@code TypeMismatchException.newException} unless
     * {@code expectedType.isCoercibleTo(actualType)}.
     * NOTE(review): the check runs expected -> actual; confirm this direction is intended.
     */
    private void checkTypeMatch(PDataType expectedType, PDataType actualType) throws SQLException {
        if (!expectedType.isCoercibleTo(actualType)) {
            throw TypeMismatchException.newException(expectedType, actualType);
        }
    }

    /**
     * Creates a literal whose value is converted to {@code expectedType}.
     * <p>
     * The value's type is determined with {@link PDataType#fromLiteral(Object)}. When that type is
     * known and differs from {@code expectedType}, it is type-checked and converted with
     * {@code expectedType.toObject(value, actualType)}. Otherwise the value is kept unchanged.
     *
     * @throws SQLException if the type check fails
     */
    public LiteralParseNode literal(Object value, PDataType expectedType) throws SQLException {
        PDataType actualType = PDataType.fromLiteral(value);
        if (actualType != null && actualType != expectedType) {
            checkTypeMatch(expectedType, actualType);
            value = expectedType.toObject(value, actualType);
        }
        return new LiteralParseNode(value);
    }

    /**
     * Converts the value of {@code literalNode} to {@code expectedType}. A new node is returned only
     * when the conversion yields a different object instance. Otherwise (including when the node's
     * type is {@code null}) the original node is returned.
     *
     * @throws SQLException if the type check fails
     */
    public LiteralParseNode coerce(LiteralParseNode literalNode, PDataType expectedType) throws SQLException {
        PDataType actualType = literalNode.getType();
        if (actualType != null) {
            Object before = literalNode.getValue();
            checkTypeMatch(expectedType, actualType);
            Object after = expectedType.toObject(before, actualType);
            // Reference comparison on purpose: reuse the node when toObject returned the same instance.
            if (before != after) {
                literalNode = literal(after);
            }
        }
        return literalNode;
    }

    /**
     * Creates the comparison node corresponding to {@code op}.
     *
     * @throws IllegalArgumentException for any {@link CompareOp} without a matching comparison node
     */
    public ComparisonParseNode comparison(CompareOp op, ParseNode lhs, ParseNode rhs) {
        switch (op) {
        case LESS:
            return lt(lhs, rhs);
        case LESS_OR_EQUAL:
            return lte(lhs, rhs);
        case EQUAL:
            return equal(lhs, rhs);
        case NOT_EQUAL:
            return notEqual(lhs, rhs);
        case GREATER_OR_EQUAL:
            return gte(lhs, rhs);
        case GREATER:
            return gt(lhs, rhs);
        default:
            throw new IllegalArgumentException("Unexpected CompareOp of " + op);
        }
    }

    /** Creates a {@code lhs > rhs} node. */
    public GreaterThanParseNode gt(ParseNode lhs, ParseNode rhs) {
        return new GreaterThanParseNode(lhs, rhs);
    }

    /** Creates a {@code lhs >= rhs} node. */
    public GreaterThanOrEqualParseNode gte(ParseNode lhs, ParseNode rhs) {
        return new GreaterThanOrEqualParseNode(lhs, rhs);
    }

    /** Creates a {@code lhs < rhs} node. */
    public LessThanParseNode lt(ParseNode lhs, ParseNode rhs) {
        return new LessThanParseNode(lhs, rhs);
    }

    /** Creates a {@code lhs <= rhs} node. */
    public LessThanOrEqualParseNode lte(ParseNode lhs, ParseNode rhs) {
        return new LessThanOrEqualParseNode(lhs, rhs);
    }

    /** Creates a {@code lhs = rhs} node. */
    public EqualParseNode equal(ParseNode lhs, ParseNode rhs) {
        return new EqualParseNode(lhs, rhs);
    }

    /** Creates an {@link ArrayConstructorNode} over the given element nodes. */
    public ArrayConstructorNode upsertStmtArrayNode(List<ParseNode> upsertStmtArray) {
        return new ArrayConstructorNode(upsertStmtArray);
    }

    /** Negates {@code child} by multiplying it by the literal {@code -1}. */
    public MultiplyParseNode negate(ParseNode child) {
        return new MultiplyParseNode(Arrays.asList(child, this.literal(-1)));
    }

    /** Creates a {@code lhs != rhs} node. */
    public NotEqualParseNode notEqual(ParseNode lhs, ParseNode rhs) {
        return new NotEqualParseNode(lhs, rhs);
    }

    /** Creates a {@link NotParseNode} over {@code child}. */
    public NotParseNode not(ParseNode child) {
        return new NotParseNode(child);
    }

    /** Creates an {@link OrParseNode} over {@code children}. */
    public OrParseNode or(List<ParseNode> children) {
        return new OrParseNode(children);
    }

    /** Creates an {@link OrderByNode} for {@code expression}. */
    public OrderByNode orderBy(ParseNode expression, boolean nullsLast, boolean orderAscending) {
        return new OrderByNode(expression, nullsLast, orderAscending);
    }

    /** Creates an {@link OuterJoinParseNode} wrapping {@code node}. */
    public OuterJoinParseNode outer(ParseNode node) {
        return new OuterJoinParseNode(node);
    }

    /**
     * Creates a {@link SelectStatement}. A {@code null} {@code groupBy} or {@code orderBy} is
     * replaced with an empty list.
     */
    public SelectStatement select(List<? extends TableNode> from, HintNode hint, boolean isDistinct, List<AliasedNode> select, ParseNode where,
            List<ParseNode> groupBy, ParseNode having, List<OrderByNode> orderBy, LimitNode limit, int bindCount, boolean isAggregate) {
        List<ParseNode> groupByNodes = groupBy == null ? Collections.<ParseNode>emptyList() : groupBy;
        List<OrderByNode> orderByNodes = orderBy == null ? Collections.<OrderByNode>emptyList() : orderBy;
        return new SelectStatement(from, hint, isDistinct, select, where, groupByNodes, having, orderByNodes, limit, bindCount, isAggregate);
    }

    /** Creates an {@link UpsertStatement}. */
    public UpsertStatement upsert(NamedTableNode table, HintNode hint, List<ColumnName> columns, List<ParseNode> values, SelectStatement select, int bindCount) {
        return new UpsertStatement(table, hint, columns, values, select, bindCount);
    }

    /** Creates a {@link DeleteStatement}. */
    public DeleteStatement delete(NamedTableNode table, HintNode hint, ParseNode node, List<OrderByNode> orderBy, LimitNode limit, int bindCount) {
        return new DeleteStatement(table, hint, node, orderBy, limit, bindCount);
    }

    /** Returns a copy of {@code statement} with its WHERE and HAVING clauses replaced. */
    public SelectStatement select(SelectStatement statement, ParseNode where, ParseNode having) {
        return select(statement.getFrom(), statement.getHint(), statement.isDistinct(), statement.getSelect(), where, statement.getGroupBy(), having, statement.getOrderBy(), statement.getLimit(), statement.getBindCount(), statement.isAggregate());
    }

    /** Returns a copy of {@code statement} with its FROM tables replaced by {@code tables}. */
    public SelectStatement select(SelectStatement statement, List<? extends TableNode> tables) {
        return select(tables, statement.getHint(), statement.isDistinct(), statement.getSelect(), statement.getWhere(), statement.getGroupBy(), statement.getHaving(), statement.getOrderBy(), statement.getLimit(), statement.getBindCount(), statement.isAggregate());
    }

    /**
     * Returns a copy of {@code statement} using {@code hint}. If {@code hint} is {@code null} or
     * empty, {@code statement} itself is returned.
     */
    public SelectStatement select(SelectStatement statement, HintNode hint) {
        return hint == null || hint.isEmpty() ? statement : select(statement.getFrom(), hint, statement.isDistinct(), statement.getSelect(), statement.getWhere(), statement.getGroupBy(), statement.getHaving(), statement.getOrderBy(), statement.getLimit(), statement.getBindCount(), statement.isAggregate());
    }

    /** Creates a {@link SubqueryParseNode} wrapping {@code select}. */
    public SubqueryParseNode subquery(SelectStatement select) {
        return new SubqueryParseNode(select);
    }

    /** Creates a {@link LimitNode} whose value comes from a bind parameter. */
    public LimitNode limit(BindParseNode b) {
        return new LimitNode(b);
    }

    /** Creates a {@link LimitNode} whose value is a literal. */
    public LimitNode limit(LiteralParseNode l) {
        return new LimitNode(l);
    }
}