boolean isCrossJoin = (node.getType() == JoinNode.Type.CROSS);
// See if we can rewrite outer joins in terms of a plain inner join
node = tryNormalizeToInnerJoin(node, inheritedPredicate);
Expression leftEffectivePredicate = EffectivePredicateExtractor.extract(node.getLeft());
Expression rightEffectivePredicate = EffectivePredicateExtractor.extract(node.getRight());
Expression joinPredicate = extractJoinPredicate(node);
Expression leftPredicate;
Expression rightPredicate;
Expression postJoinPredicate;
Expression newJoinPredicate;
switch (node.getType()) {
case INNER:
InnerJoinPushDownResult innerJoinPushDownResult = processInnerJoin(inheritedPredicate,
leftEffectivePredicate,
rightEffectivePredicate,
joinPredicate,
node.getLeft().getOutputSymbols());
leftPredicate = innerJoinPushDownResult.getLeftPredicate();
rightPredicate = innerJoinPushDownResult.getRightPredicate();
postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate();
newJoinPredicate = innerJoinPushDownResult.getJoinPredicate();
break;
case LEFT:
OuterJoinPushDownResult leftOuterJoinPushDownResult = processOuterJoin(inheritedPredicate,
leftEffectivePredicate,
rightEffectivePredicate,
joinPredicate,
node.getLeft().getOutputSymbols());
leftPredicate = leftOuterJoinPushDownResult.getOuterJoinPredicate();
rightPredicate = leftOuterJoinPushDownResult.getInnerJoinPredicate();
postJoinPredicate = leftOuterJoinPushDownResult.getPostJoinPredicate();
newJoinPredicate = joinPredicate; // Use the same as the original
break;
case RIGHT:
OuterJoinPushDownResult rightOuterJoinPushDownResult = processOuterJoin(inheritedPredicate,
rightEffectivePredicate,
leftEffectivePredicate,
joinPredicate,
node.getRight().getOutputSymbols());
leftPredicate = rightOuterJoinPushDownResult.getInnerJoinPredicate();
rightPredicate = rightOuterJoinPushDownResult.getOuterJoinPredicate();
postJoinPredicate = rightOuterJoinPushDownResult.getPostJoinPredicate();
newJoinPredicate = joinPredicate; // Use the same as the original
break;
default:
throw new UnsupportedOperationException("Unsupported join type: " + node.getType());
}
PlanNode leftSource = planRewriter.rewrite(node.getLeft(), leftPredicate);
PlanNode rightSource = planRewriter.rewrite(node.getRight(), rightPredicate);
PlanNode output = node;
if (leftSource != node.getLeft() || rightSource != node.getRight() || !newJoinPredicate.equals(joinPredicate)) {
List<JoinNode.EquiJoinClause> criteria = node.getCriteria();
// Rewrite criteria and add projections if there is a new join predicate
if (!newJoinPredicate.equals(joinPredicate) || isCrossJoin) {
// Create identity projections for all existing symbols
ImmutableMap.Builder<Symbol, Expression> leftProjections = ImmutableMap.builder();
leftProjections.putAll(IterableTransformer.<Symbol>on(node.getLeft().getOutputSymbols())
.toMap(symbolToQualifiedNameReference())
.map());
ImmutableMap.Builder<Symbol, Expression> rightProjections = ImmutableMap.builder();
rightProjections.putAll(IterableTransformer.<Symbol>on(node.getRight().getOutputSymbols())
.toMap(symbolToQualifiedNameReference())
.map());
// HACK! we don't support cross joins right now, so put in a simple fake join predicate instead if all of the join clauses got simplified out
// TODO: remove this code when cross join support is added
Iterable<Expression> simplifiedJoinConjuncts = transform(extractConjuncts(newJoinPredicate), simplifyExpressions());
simplifiedJoinConjuncts = filter(simplifiedJoinConjuncts, not(Predicates.<Expression>equalTo(BooleanLiteral.TRUE_LITERAL)));
if (Iterables.isEmpty(simplifiedJoinConjuncts)) {
simplifiedJoinConjuncts = ImmutableList.<Expression>of(new ComparisonExpression(ComparisonExpression.Type.EQUAL, new LongLiteral("0"), new LongLiteral("0")));
}
// Create new projections for the new join clauses
ImmutableList.Builder<JoinNode.EquiJoinClause> builder = ImmutableList.builder();
for (Expression conjunct : simplifiedJoinConjuncts) {
checkState(joinEqualityExpression(node.getLeft().getOutputSymbols()).apply(conjunct), "Expected join predicate to be a valid join equality");
ComparisonExpression equality = (ComparisonExpression) conjunct;
boolean alignedComparison = Iterables.all(DependencyExtractor.extractUnique(equality.getLeft()), in(node.getLeft().getOutputSymbols()));
Expression leftExpression = (alignedComparison) ? equality.getLeft() : equality.getRight();
Expression rightExpression = (alignedComparison) ? equality.getRight() : equality.getLeft();
Symbol leftSymbol = symbolAllocator.newSymbol(leftExpression, extractType(leftExpression));
leftProjections.put(leftSymbol, leftExpression);
Symbol rightSymbol = symbolAllocator.newSymbol(rightExpression, extractType(rightExpression));
rightProjections.put(rightSymbol, rightExpression);