package edu.brown.optimizer.optimizations;
import java.util.Collection;
import java.util.List;
import org.apache.log4j.Logger;
import org.voltdb.VoltType;
import org.voltdb.expressions.AbstractExpression;
import org.voltdb.expressions.TupleValueExpression;
import org.voltdb.planner.PlanAssembler;
import org.voltdb.planner.PlanColumn;
import org.voltdb.plannodes.AbstractPlanNode;
import org.voltdb.plannodes.AbstractScanPlanNode;
import org.voltdb.plannodes.DistinctPlanNode;
import org.voltdb.plannodes.HashAggregatePlanNode;
import org.voltdb.plannodes.ReceivePlanNode;
import org.voltdb.plannodes.SendPlanNode;
import org.voltdb.types.ExpressionType;
import org.voltdb.utils.Pair;
import edu.brown.expressions.ExpressionUtil;
import edu.brown.logging.LoggerUtil.LoggerBoolean;
import edu.brown.optimizer.PlanOptimizerState;
import edu.brown.plannodes.PlanNodeUtil;
import edu.brown.utils.CollectionUtil;
/**
 * Optimization that pushes an aggregate computation below the SEND/RECEIVE
 * boundary of a distributed query plan. A clone of the single
 * HashAggregatePlanNode is attached beneath the SendPlanNode so that each
 * partition computes a partial aggregate in parallel, and the original
 * (top-of-tree) aggregate is rewritten to combine those partial results
 * (e.g. SUM of partial COUNTs, weighted AVG of partial AVGs).
 */
public class AggregatePushdownOptimization extends AbstractOptimization {
    private static final Logger LOG = Logger.getLogger(AggregatePushdownOptimization.class);
    private static final LoggerBoolean debug = new LoggerBoolean();
    private static final LoggerBoolean trace = new LoggerBoolean();

    public AggregatePushdownOptimization(PlanOptimizerState state) {
        super(state);
    }

    /**
     * Try to apply the aggregate pushdown to the given plan tree.
     * The tree is mutated in place when the optimization applies.
     *
     * @param rootNode root of the query plan tree
     * @return Pair of (whether the plan was modified, the root plan node)
     */
    @Override
    public Pair<Boolean, AbstractPlanNode> optimize(AbstractPlanNode rootNode) {
        // Only plans with exactly one hash aggregate are handled
        Collection<HashAggregatePlanNode> nodes = PlanNodeUtil.getPlanNodes(rootNode, HashAggregatePlanNode.class);
        if (nodes.size() != 1) {
            if (debug.val) LOG.debug("SKIP - Not an aggregate query plan");
            return Pair.of(false, rootNode);
        }
        final HashAggregatePlanNode node = CollectionUtil.first(nodes);

        // Skip single-partition query plans (nothing to parallelize)
        if (PlanNodeUtil.isDistributedQuery(rootNode) == false) {
            if (debug.val) LOG.debug("SKIP - Not a distributed query plan");
            return (Pair.of(false, rootNode));
        }

        // // Right now, Can't do averages
        // for (ExpressionType et: node.getAggregateTypes()) {
        // if (et.equals(ExpressionType.AGGREGATE_AVG)) {
        // if (debug.val) LOG.debug("SKIP - Right now can't optimize AVG()");
        // return (Pair.of(false, rootNode));
        // }
        // }

        // Get the AbstractScanPlanNode that is directly below us.
        // Multiple scans (joins) are not supported by this optimization.
        Collection<AbstractScanPlanNode> scans = PlanNodeUtil.getPlanNodes(node, AbstractScanPlanNode.class);
        if (debug.val) LOG.debug("<ScanPlanNodes>: "+ scans);
        if (scans.size() != 1) {
            if (debug.val)
                LOG.debug("SKIP - Multiple scans!");
            return (Pair.of(false, rootNode));
        }
        if (debug.val) LOG.debug("Trying to apply Aggregate pushdown optimization!");
        AbstractScanPlanNode scan_node = CollectionUtil.first(scans);
        assert (scan_node != null);

        // // For some reason we have to do this??
        // for (int col = 0, cnt = scan_node.getOutputColumnGUIDs().size(); col < cnt; col++) {
        // int col_guid = scan_node.getOutputColumnGUIDs().get(col);
        // assert (state.plannerContext.get(col_guid) != null) : "Failed [" + col_guid + "]";
        // // PlanColumn retval = new PlanColumn(guid, expression, columnName,
        // // sortOrder, storage);
        // } // FOR

        // Skip if we're already directly after the scan (meaning no network traffic)
        if (scan_node.getParent(0).equals(node)) {
            if (debug.val)
                LOG.debug("SKIP - Aggregate does not need to be distributed");
            return (Pair.of(false, rootNode));
        }

        // Check whether this is a COUNT(DISTINCT) query.
        // If it is then we can only push down the DISTINCT portion.
        AbstractPlanNode clone_node = null;
        if (node.getAggregateTypes().contains(ExpressionType.AGGREGATE_COUNT)) {
            for (AbstractPlanNode child : node.getChildren()) {
                if (child.getClass().equals(DistinctPlanNode.class)) {
                    try {
                        clone_node = (AbstractPlanNode) child.clone(false, true);
                    } catch (CloneNotSupportedException ex) {
                        throw new RuntimeException(ex);
                    }
                    state.markDirty(clone_node);
                    break;
                }
            } // FOR
        }

        // Note that we don't actually want to move the existing aggregate. We just
        // want to clone it and then attach it down below the SEND/RECEIVE so
        // that we calculate the aggregates in parallel
        if (clone_node == null) {
            clone_node = this.cloneAggregatePlanNode(node);
        }
        assert (clone_node != null);

        // But this means we have to also update the RECEIVE to only expect the
        // columns that the AggregateNode will be sending along.
        // NOTE(review): these casts assume a fixed RECEIVE position under the
        // aggregate (one level deeper when a DISTINCT was pushed down) — a
        // ClassCastException here would indicate an unexpected tree shape.
        ReceivePlanNode recv_node = null;
        if (clone_node instanceof DistinctPlanNode) {
            recv_node = (ReceivePlanNode) node.getChild(0).getChild(0);
        } else {
            recv_node = (ReceivePlanNode) node.getChild(0);
        }
        recv_node.getOutputColumnGUIDs().clear();
        recv_node.getOutputColumnGUIDs().addAll(clone_node.getOutputColumnGUIDs());
        state.markDirty(recv_node);

        // Splice the cloned aggregate in between the SEND and its child so the
        // partial aggregation executes on the remote partitions
        assert (recv_node.getChild(0) instanceof SendPlanNode);
        SendPlanNode send_node = (SendPlanNode) recv_node.getChild(0);
        send_node.getOutputColumnGUIDs().clear();
        send_node.getOutputColumnGUIDs().addAll(clone_node.getOutputColumnGUIDs());
        send_node.addIntermediary(clone_node);
        state.markDirty(send_node);

        // 2011-12-08: We now need to correct the aggregate columns for the
        // original plan node
        if ((clone_node instanceof DistinctPlanNode) == false) {
            // If we have an AGGREGATE_WEIGHTED_AVG in our node, then we know that
            // we can skip the last column because that's the COUNT from the remote partition
            boolean has_weightedAvg = node.getAggregateTypes().contains(ExpressionType.AGGREGATE_WEIGHTED_AVG);
            node.getAggregateColumnGuids().clear();
            int num_cols = clone_node.getOutputColumnGUIDCount() - (has_weightedAvg ? 1 : 0);
            for (int i = 0; i < num_cols; i++) {
                Integer aggOutput = clone_node.getOutputColumnGUID(i);
                PlanColumn planCol = state.plannerContext.get(aggOutput);
                assert (planCol != null);
                AbstractExpression exp = planCol.getExpression();
                assert (exp != null);
                // Only columns sourced from the pushed-down aggregate's temp
                // table become inputs of the top-level combining aggregate
                Collection<String> refTables = ExpressionUtil.getReferencedTableNames(exp);
                assert (refTables != null);
                if (refTables.size() == 1 && refTables.contains(PlanAssembler.AGGREGATE_TEMP_TABLE)) {
                    node.getAggregateColumnGuids().add(planCol.guid());
                }
            } // FOR
        }
        if (debug.val) {
            LOG.debug("Successfully applied optimization! Eat that John Hugg!");
            if (trace.val)
                LOG.trace("\n" + PlanNodeUtil.debug(rootNode));
        }
        return Pair.of(true, rootNode);
    }

    /**
     * Clone the given HashAggregatePlanNode for execution at the remote
     * partitions and rewrite the original node so that it combines the
     * partial results (COUNT/SUM -> SUM, MIN/MAX unchanged,
     * AVG -> WEIGHTED_AVG with an extra COUNT column appended to the clone).
     *
     * NOTE: this method mutates the passed-in {@code node} (its aggregate
     * types are replaced with the combining types) in addition to returning
     * the clone.
     *
     * @param node the top-of-tree aggregate node (mutated in place)
     * @return the cloned aggregate to be pushed below the SEND/RECEIVE
     */
    protected HashAggregatePlanNode cloneAggregatePlanNode(final HashAggregatePlanNode node) {
        HashAggregatePlanNode clone_agg = null;
        try {
            clone_agg = (HashAggregatePlanNode) node.clone(false, true);
        } catch (CloneNotSupportedException ex) {
            throw new RuntimeException(ex);
        }
        state.markDirty(clone_agg);

        // Update the cloned AggregateNode to handle distributed averages.
        // NOTE: this is the clone's live list — appending to it below mutates
        // the clone directly, which is why orig_cnt is captured first.
        List<ExpressionType> clone_types = clone_agg.getAggregateTypes();

        // For now we'll always put a COUNT at the end of the AggregatePlanNode
        // This makes it easier for us to find it in the EE
        boolean has_count = false;
        // boolean has_count = (clone_types.contains(ExpressionType.AGGREGATE_COUNT) ||
        // clone_types.contains(ExpressionType.AGGREGATE_COUNT_STAR));
        int orig_cnt = clone_types.size();
        for (int i = 0; i < orig_cnt; i++) {
            ExpressionType cloneType = clone_types.get(i);
            // Ok, strap on your helmets boys, here's what we got going on here...
            // In order to do a distributed average, we need to send the average
            // AND the count (in order to compute the weighted average at the base partition).
            // We need to check whether we already have a count in our list.
            // If not, then we'll want to insert it here.
            if (cloneType == ExpressionType.AGGREGATE_AVG) {
                if (has_count == false) {
                    // But now because we add a new output column that we're going to use internally,
                    // we need to make sure that our output columns reflect this.
                    clone_types.add(ExpressionType.AGGREGATE_COUNT_STAR);
                    has_count = true;

                    // Aggregate Input Column
                    // We just need to do it against the first column in the child's output
                    // Picking the column that we want to use doesn't matter even if there is a GROUP BY
                    clone_agg.getAggregateColumnGuids().add(node.getChild(0).getOutputColumnGUID(0));

                    // Aggregate Output Column: a synthetic BIGINT "_DTXN_COUNT"
                    // column on the aggregate temp table carries the partial count
                    TupleValueExpression exp = new TupleValueExpression();
                    exp.setValueType(VoltType.BIGINT);
                    exp.setValueSize(VoltType.BIGINT.getLengthInBytesForFixedTypes());
                    exp.setTableName(PlanAssembler.AGGREGATE_TEMP_TABLE);
                    exp.setColumnName("");
                    exp.setColumnAlias("_DTXN_COUNT");
                    exp.setColumnIndex(clone_agg.getOutputColumnGUIDCount());
                    PlanColumn new_pc = state.plannerContext.getPlanColumn(exp, exp.getColumnAlias());
                    clone_agg.getAggregateOutputColumns().add(clone_agg.getOutputColumnGUIDCount());
                    clone_agg.getAggregateColumnNames().add(new_pc.getDisplayName());
                    clone_agg.getOutputColumnGUIDs().add(new_pc.guid());
                }
            }
        } // FOR

        // Now go through the original AggregateNode (the one at the top of the tree)
        // and change the ExpressionTypes for the aggregates to handle what we're
        // doing down below in the distributed query
        List<ExpressionType> exp_types = node.getAggregateTypes();
        exp_types.clear();
        for (int i = 0; i < orig_cnt; i++) {
            ExpressionType cloneType = clone_types.get(i);
            switch (cloneType) {
                case AGGREGATE_COUNT:
                case AGGREGATE_COUNT_STAR:
                case AGGREGATE_SUM:
                    // Partial COUNTs/SUMs are combined with a SUM at the base partition
                    exp_types.add(ExpressionType.AGGREGATE_SUM);
                    break;
                case AGGREGATE_MAX:
                case AGGREGATE_MIN:
                    // MIN/MAX of partial MIN/MAX is the global MIN/MAX
                    exp_types.add(cloneType);
                    break;
                case AGGREGATE_AVG:
                    // This is a special internal marker that allows us to compute
                    // a weighted average from the count
                    exp_types.add(ExpressionType.AGGREGATE_WEIGHTED_AVG);
                    break;
                default:
                    throw new RuntimeException("Unexpected ExpressionType " + cloneType);
            } // SWITCH
        } // FOR

        // IMPORTANT: If we have GROUP BY columns, then we need to make sure
        // that those columns are always passed up the query tree at the pushed
        // down node, even if the final answer doesn't need it
        if (node.getGroupByColumnGuids().isEmpty() == false) {
            for (Integer guid : clone_agg.getGroupByColumnGuids()) {
                if (clone_agg.getOutputColumnGUIDs().contains(guid) == false) {
                    clone_agg.getOutputColumnGUIDs().add(guid);
                }
            } // FOR
        }
        assert(clone_agg.getGroupByColumnOffsets().size() == node.getGroupByColumnOffsets().size());
        assert(clone_agg.getGroupByColumnNames().size() == node.getGroupByColumnNames().size());
        assert(clone_agg.getGroupByColumnGuids().size() == node.getGroupByColumnGuids().size()) : clone_agg.getGroupByColumnGuids().size() + " not equal " + node.getGroupByColumnGuids().size();
        return (clone_agg);
    }
}