/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tajo.engine.eval;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.tajo.catalog.Column;
import org.apache.tajo.catalog.Schema;
import org.apache.tajo.common.TajoDataTypes.DataType;
import org.apache.tajo.engine.planner.Target;
import org.apache.tajo.exception.InternalException;
import java.util.*;
public class EvalTreeUtil {
/**
 * Renames column references inside an expression tree.
 *
 * Builds a {@link ChangeColumnRefVisitor} for the given names and hands it to
 * {@code node.postOrder(...)}; the visitor mutates matching FIELD nodes in place.
 *
 * @param node    root of the expression tree to rewrite
 * @param oldName column name (or full name) to look for
 * @param newName name that matching references are switched to
 */
public static void changeColumnRef(EvalNode node, String oldName,
    String newName) {
  final ChangeColumnRefVisitor renamer = new ChangeColumnRefVisitor(oldName, newName);
  node.postOrder(renamer);
}
/**
 * Replaces occurrences of {@code targetExpr} inside {@code expr} with {@code tobeReplaced}.
 *
 * The work is delegated to an {@link EvalReplaceVisitor}, started on {@code expr} with a
 * {@code null} context and an empty ancestor stack.
 *
 * @param expr         root of the expression tree to modify
 * @param targetExpr   sub-expression to search for
 * @param tobeReplaced expression put in place of each match
 */
public static void replace(EvalNode expr, EvalNode targetExpr, EvalNode tobeReplaced) {
  final Stack<EvalNode> ancestors = new Stack<EvalNode>();
  new EvalReplaceVisitor(targetExpr, tobeReplaced).visitChild(null, expr, ancestors);
}
/**
 * Visitor that swaps a target sub-expression for a replacement expression.
 *
 * After the superclass has processed a node, the node is compared with the target; on a
 * match, the node on top of the stack is treated as its parent and whichever of the
 * parent's left/right children equals the node is overwritten with the replacement.
 */
public static class EvalReplaceVisitor extends BasicEvalNodeVisitor<EvalNode, EvalNode> {
  private final EvalNode target;
  private final EvalNode tobeReplaced;

  /**
   * @param target       the expression to search for
   * @param tobeReplaced the expression that takes the place of each match
   */
  public EvalReplaceVisitor(EvalNode target, EvalNode tobeReplaced) {
    this.target = target;
    this.tobeReplaced = tobeReplaced;
  }

  /**
   * Visits {@code evalNode} through the superclass and then, if it equals the target,
   * rewires the matching child slot(s) of the node on top of {@code stack}.
   *
   * @param context  passed through to the superclass unchanged
   * @param evalNode the node being visited
   * @param stack    the stack handed to the superclass; its top element is used as the parent
   * @return {@code evalNode}, unchanged
   */
  @Override
  public EvalNode visitChild(EvalNode context, EvalNode evalNode, Stack<EvalNode> stack) {
    super.visitChild(context, evalNode, stack);

    // With an empty stack there is no parent slot to rewrite (e.g. the root itself
    // matches); peek() would throw EmptyStackException, so such a match is skipped.
    if (evalNode.equals(target) && !stack.isEmpty()) {
      EvalNode parent = stack.peek();

      // A parent may lack one of its children, so each slot is null-checked
      // before it is compared. NOTE(review): confirm which node types omit a child.
      EvalNode left = parent.getLeftExpr();
      if (left != null && left.equals(evalNode)) {
        parent.setLeftExpr(tobeReplaced);
      }

      EvalNode right = parent.getRightExpr();
      if (right != null && right.equals(evalNode)) {
        parent.setRightExpr(tobeReplaced);
      }
    }

    return evalNode;
  }
}
/**
 * Collects the set of columns referenced by FIELD nodes of the given expression.
 *
 * @param node root of the expression to inspect
 * @return the set gathered by a {@link DistinctColumnRefFinder} passed to {@code node.postOrder}
 */
public static Set<Column> findDistinctRefColumns(EvalNode node) {
  final DistinctColumnRefFinder collector = new DistinctColumnRefFinder();
  node.postOrder(collector);
  final Set<Column> distinctColumns = collector.getColumnRefs();
  return distinctColumns;
}
/**
 * Collects every column reference found in FIELD nodes of the given expression,
 * keeping duplicates.
 *
 * @param node root of the expression to inspect
 * @return the list gathered by an {@link AllColumnRefFinder} passed to {@code node.postOrder}
 */
public static List<Column> findAllColumnRefs(EvalNode node) {
  final AllColumnRefFinder collector = new AllColumnRefFinder();
  node.postOrder(collector);
  final List<Column> allColumns = collector.getColumnRefs();
  return allColumns;
}
/**
 * Merges a list of conjuncts into a single expression.
 *
 * A single input is returned as is. Two or more inputs are chained into nested
 * AND nodes, each holding one input on the left and the rest on the right.
 *
 * @param cnfExprs the expressions to combine; at least one is required
 * @return the only input, or the AND expression combining all inputs
 * @throws IllegalArgumentException if no expression is given
 */
public static EvalNode transformCNF2Singleton(EvalNode...cnfExprs) {
  // An empty array used to fail deep inside the helper with an
  // ArrayIndexOutOfBoundsException; reject it here with a clear message instead.
  Preconditions.checkArgument(cnfExprs.length > 0,
      "At least one expression is required");

  if (cnfExprs.length == 1) {
    return cnfExprs[0];
  }

  return transformCNF2Singleton_(cnfExprs, 0);
}
/**
 * Recursively builds a right-nested AND chain from {@code evalNode[idx]} onwards.
 *
 * @param evalNode the expressions to combine
 * @param idx      index of the expression that becomes the left operand at this level
 * @return an AND node whose left side is {@code evalNode[idx]} and whose right side
 *         covers the remaining expressions
 */
private static EvalNode transformCNF2Singleton_(EvalNode [] evalNode, int idx) {
  final EvalNode head = evalNode[idx];
  // When only two expressions remain, the second one is the right operand itself;
  // otherwise the right operand is the AND chain built from the rest.
  final boolean onlyPairLeft = (idx == evalNode.length - 2);
  final EvalNode tail = onlyPairLeft
      ? evalNode[idx + 1]
      : transformCNF2Singleton_(evalNode, idx + 1);
  return new BinaryEval(EvalType.AND, head, tail);
}
/**
 * Splits an expression into the operands of its top-level AND nodes.
 *
 * @param expr the expression to break apart
 * @return an array holding every non-AND sub-expression reached by descending
 *         through AND nodes, in left-to-right order
 */
public static EvalNode [] getConjNormalForm(EvalNode expr) {
  final List<EvalNode> conjuncts = new ArrayList<EvalNode>();
  getConjNormalForm(expr, conjuncts);
  final EvalNode[] result = new EvalNode[conjuncts.size()];
  return conjuncts.toArray(result);
}
/**
 * Appends the non-AND sub-expressions of {@code node} to {@code found}.
 *
 * @param node  expression to descend into
 * @param found output list; the left subtree's results are added before the right's
 */
private static void getConjNormalForm(EvalNode node, List<EvalNode> found) {
  if (node.getType() != EvalType.AND) {
    // Anything that is not an AND is treated as one indivisible conjunct.
    found.add(node);
    return;
  }

  getConjNormalForm(node.getLeftExpr(), found);
  getConjNormalForm(node.getRightExpr(), found);
}
/**
 * Builds a schema with one column per target.
 *
 * Each column is named after the target's alias when it has one, otherwise after
 * the name of the target's expression; its type comes from {@link #getDomainByExpr}.
 *
 * @param inputSchema schema used to resolve the types of field references
 * @param targets     the targets to describe
 * @return a new schema whose columns follow the order of {@code targets}
 * @throws InternalException if a target's expression type is not supported
 */
public static Schema getSchemaByTargets(Schema inputSchema, Target [] targets)
    throws InternalException {
  final Schema result = new Schema();

  for (int i = 0; i < targets.length; i++) {
    final Target current = targets[i];

    final String columnName;
    if (current.hasAlias()) {
      columnName = current.getAlias();
    } else {
      columnName = current.getEvalTree().getName();
    }

    final DataType columnType = getDomainByExpr(inputSchema, current.getEvalTree());
    result.addColumn(columnName, columnType);
  }

  return result;
}
/**
 * Determines the data type produced by an expression.
 *
 * FIELD expressions are resolved by looking their name up in {@code inputSchema};
 * logical, comparison, arithmetic, constant and function expressions report
 * their own value type.
 *
 * @param inputSchema schema consulted for FIELD expressions
 * @param expr        the expression whose type is wanted
 * @return the data type of {@code expr}
 * @throws InternalException if the expression type is none of the handled ones
 */
public static DataType getDomainByExpr(Schema inputSchema, EvalNode expr)
    throws InternalException {
  final EvalType type = expr.getType();

  if (type == EvalType.FIELD) {
    final FieldEval fieldRef = (FieldEval) expr;
    final Column resolved = inputSchema.getColumnByFQN(fieldRef.getName());
    return resolved.getDataType();
  }

  switch (type) {
    case AND:
    case OR:
    case EQUAL:
    case NOT_EQUAL:
    case LTH:
    case LEQ:
    case GTH:
    case GEQ:
    case PLUS:
    case MINUS:
    case MULTIPLY:
    case DIVIDE:
    case CONST:
    case FUNCTION:
      return expr.getValueType();

    default:
      throw new InternalException("Unknown expr type: "
          + type.toString());
  }
}
/**
 * Returns the given expression if it is a comparison that refers to the target column.
 *
 * Only {@code expr} itself is examined by the private overload this delegates to,
 * so the result holds at most that one expression.
 *
 * @param expr   the expression to examine
 * @param target the column to look for
 * @return a collection containing {@code expr} when it qualifies, otherwise an empty one
 */
public static Collection<EvalNode> getContainExpr(EvalNode expr, Column target) {
  final Set<EvalNode> matches = Sets.newHashSet();
  getContainExpr(expr, target, matches);
  return matches;
}
/**
 * Counts how many nodes of each tracked expression type occur in an expression.
 *
 * @param expr root of the expression to inspect
 * @return the map maintained by a {@link VariableCounter} after it has been passed
 *         to {@code expr.postOrder}; keys are expression types, values are counts
 */
public static Map<EvalType, Integer> getExprCounters(EvalNode expr) {
  final VariableCounter tally = new VariableCounter();
  expr.postOrder(tally);
  final Map<EvalType, Integer> counts = tally.getCounter();
  return counts;
}
/**
 * Adds {@code expr} to {@code exprSet} when it is one of the comparison types
 * EQUAL, LTH, LEQ, GTH, GEQ or NOT_EQUAL and contains a reference to {@code target}.
 * Expressions of any other type are ignored.
 *
 * @param expr    the expression to examine
 * @param target  the column to look for
 * @param exprSet output set receiving {@code expr} on a match
 */
private static void getContainExpr(EvalNode expr, Column target, Set<EvalNode> exprSet) {
  final boolean isComparison;
  switch (expr.getType()) {
    case EQUAL:
    case LTH:
    case LEQ:
    case GTH:
    case GEQ:
    case NOT_EQUAL:
      isComparison = true;
      break;
    default:
      isComparison = false;
      break;
  }

  if (isComparison && containColumnRef(expr, target)) {
    exprSet.add(expr);
  }
}
/**
 * Tells whether an expression contains a reference to the target column.
 *
 * @param expr   the expression to search
 * @param target the column to look for
 * @return {@code true} if at least one matching field reference was found
 */
public static boolean containColumnRef(EvalNode expr, Column target) {
  final Set<EvalNode> hits = Sets.newHashSet();
  _containColumnRef(expr, target, hits);
  return !hits.isEmpty();
}
/**
 * Collects the FIELD nodes of {@code expr} whose column name equals the column
 * name of {@code target}.
 *
 * FIELD nodes are compared directly, CONST nodes are skipped, and every other
 * node is searched through its left and right children.
 *
 * @param expr    the expression to search; {@code null} is treated as "nothing to search"
 * @param target  the column whose name is looked for
 * @param exprSet output set receiving each matching field node
 */
private static void _containColumnRef(EvalNode expr, Column target,
    Set<EvalNode> exprSet) {
  // The default branch descends into both children unconditionally. A node
  // without a left or right child would otherwise cause a NullPointerException
  // on expr.getType(). NOTE(review): confirm which node types omit a child.
  if (expr == null) {
    return;
  }

  switch (expr.getType()) {
    case FIELD:
      FieldEval field = (FieldEval) expr;
      if (field.getColumnName().equals(target.getColumnName())) {
        exprSet.add(field);
      }
      break;
    case CONST:
      return;
    default:
      _containColumnRef(expr.getLeftExpr(), target, exprSet);
      _containColumnRef(expr.getRightExpr(), target, exprSet);
  }
}
/**
 * Tells whether an expression is one of the comparison types EQUAL, LEQ, LTH,
 * GEQ or GTH. NOT_EQUAL is not counted.
 *
 * @param expr the expression to classify
 * @return {@code true} for the five comparison types listed above
 */
public static boolean isComparisonOperator(EvalNode expr) {
  final EvalType type = expr.getType();

  if (type == EvalType.EQUAL) {
    return true;
  }
  if (type == EvalType.LEQ) {
    return true;
  }
  if (type == EvalType.LTH) {
    return true;
  }
  if (type == EvalType.GEQ) {
    return true;
  }
  return type == EvalType.GTH;
}
/**
 * Tells whether an expression is a comparison between two field references.
 *
 * @param expr the expression to classify
 * @return {@code true} if {@link #isComparisonOperator} accepts {@code expr} and
 *         both its left and right operands are FIELD nodes
 */
public static boolean isJoinQual(EvalNode expr) {
  if (!isComparisonOperator(expr)) {
    return false;
  }

  // The operands are inspected left first; the right one is only looked at
  // when the left one is a field.
  if (expr.getLeftExpr().getType() != EvalType.FIELD) {
    return false;
  }

  return expr.getRightExpr().getType() == EvalType.FIELD;
}
/**
 * Visitor that renames column references.
 *
 * Every FIELD node whose column name or full name equals the searched name has
 * its column reference replaced with the new name.
 */
public static class ChangeColumnRefVisitor implements EvalNodeVisitor {
  private final String findColumn;
  private final String toBeChanged;

  /**
   * @param oldName the name to search for
   * @param newName the name that matching references are switched to
   */
  public ChangeColumnRefVisitor(String oldName, String newName) {
    this.findColumn = oldName;
    this.toBeChanged = newName;
  }

  /**
   * Renames the reference held by {@code node} if it is a matching FIELD node;
   * other nodes are left alone.
   */
  @Override
  public void visit(EvalNode node) {
    if (node.type != EvalType.FIELD) {
      return;
    }

    final FieldEval fieldRef = (FieldEval) node;
    final boolean matches = fieldRef.getColumnName().equals(findColumn)
        || fieldRef.getName().equals(findColumn);
    if (matches) {
      fieldRef.replaceColumnRef(toBeChanged);
    }
  }
}
/**
 * Visitor that records the column reference of every FIELD node it visits,
 * keeping duplicates and visit order.
 */
public static class AllColumnRefFinder implements EvalNodeVisitor {
  private final List<Column> colList = new ArrayList<Column>();

  /**
   * Appends the column reference of {@code node} when it is a FIELD node;
   * other nodes are ignored.
   */
  @Override
  public void visit(EvalNode node) {
    if (node.getType() == EvalType.FIELD) {
      // The cast result is only needed here, so it is a local rather than an
      // instance field that would keep the last visited node reachable.
      FieldEval field = (FieldEval) node;
      colList.add(field.getColumnRef());
    }
  }

  /**
   * @return the internal list of collected column references (not a copy)
   */
  public List<Column> getColumnRefs() {
    return this.colList;
  }
}
/**
 * Visitor that gathers the column references of the FIELD nodes it visits
 * into a {@link HashSet}, so equal columns are kept once.
 */
public static class DistinctColumnRefFinder implements EvalNodeVisitor {
  private final Set<Column> colList = new HashSet<Column>();

  /**
   * Adds the column reference of {@code node} when it is a FIELD node;
   * other nodes are ignored.
   */
  @Override
  public void visit(EvalNode node) {
    if (node.getType() == EvalType.FIELD) {
      // The cast result is only needed here, so it is a local rather than an
      // instance field that would keep the last visited node reachable.
      FieldEval field = (FieldEval) node;
      colList.add(field.getColumnRef());
    }
  }

  /**
   * @return the internal set of collected column references (not a copy)
   */
  public Set<Column> getColumnRefs() {
    return this.colList;
  }
}
/**
 * Visitor that counts visited nodes per expression type.
 *
 * Only the types registered in the constructor (FUNCTION and FIELD) are
 * tracked; nodes of any other type do not affect the counts.
 */
public static class VariableCounter implements EvalNodeVisitor {
  private final Map<EvalType, Integer> counter;

  /** Creates a counter with FUNCTION and FIELD both starting at zero. */
  public VariableCounter() {
    counter = Maps.newHashMap();
    final EvalType[] tracked = { EvalType.FUNCTION, EvalType.FIELD };
    for (EvalType type : tracked) {
      counter.put(type, 0);
    }
  }

  /** Increments the count for the type of {@code node} if that type is tracked. */
  @Override
  public void visit(EvalNode node) {
    if (!counter.containsKey(node.getType())) {
      return;
    }

    final int updated = counter.get(node.getType()) + 1;
    counter.put(node.getType(), updated);
  }

  /**
   * @return the internal map from expression type to count (not a copy)
   */
  public Map<EvalType, Integer> getCounter() {
    return counter;
  }
}
/**
 * Lists the aggregation function calls found in an expression, each kept once.
 *
 * @param expr root of the expression to inspect
 * @return a new list copied from the set gathered by an {@link AllAggFunctionFinder}
 *         passed to {@code expr.postOrder}
 */
public static List<AggregationFunctionCallEval> findDistinctAggFunction(EvalNode expr) {
  final AllAggFunctionFinder collector = new AllAggFunctionFinder();
  expr.postOrder(collector);
  final Set<AggregationFunctionCallEval> found = collector.getAggregationFunction();
  return Lists.newArrayList(found);
}
/**
 * Visitor that gathers every AGG_FUNCTION node it visits into a set, so equal
 * aggregation calls are kept once.
 */
public static class AllAggFunctionFinder implements EvalNodeVisitor {
  private final Set<AggregationFunctionCallEval> aggFunctions = Sets.newHashSet();

  /**
   * Adds {@code node} to the result when it is an AGG_FUNCTION node;
   * other nodes are ignored.
   */
  @Override
  public void visit(EvalNode node) {
    if (node.getType() == EvalType.AGG_FUNCTION) {
      // The cast result is only needed here, so it is a local rather than an
      // instance field that would keep the last visited node reachable.
      AggregationFunctionCallEval aggCall = (AggregationFunctionCallEval) node;
      aggFunctions.add(aggCall);
    }
  }

  /**
   * @return the internal set of collected aggregation function calls (not a copy)
   */
  public Set<AggregationFunctionCallEval> getAggregationFunction() {
    return this.aggFunctions;
  }
}
}