}
@Override
public void onMatch(RelOptRuleCall call) {
final DrillAggregateRel agg = (DrillAggregateRel) call.rel(0);
final DrillScanRel scan = (DrillScanRel) call.rel(call.rels.length -1);
final DrillProjectRel proj = call.rels.length == 3 ? (DrillProjectRel) call.rel(1) : null;
GroupScan oldGrpScan = scan.getGroupScan();
// Only apply the rule when :
// 1) scan knows the exact row count in getSize() call,
// 2) No GroupBY key,
// 3) only one agg function (Check if it's count(*) below).
// 4) No distinct agg call.
if (! (oldGrpScan.getScanStats().getGroupScanProperty().hasExactRowCount()
&& agg.getGroupCount() == 0
&& agg.getAggCallList().size() == 1
&& !agg.containsDistinctCall())) {
return;
}
AggregateCall aggCall = agg.getAggCallList().get(0);
if (aggCall.getAggregation().getName().equals("COUNT") ) {
long cnt = 0;
// count(*) == > empty arg ==> rowCount
// count(Not-null-input) ==> rowCount
if (aggCall.getArgList().isEmpty() ||
(aggCall.getArgList().size() == 1 &&
! agg.getChild().getRowType().getFieldList().get(aggCall.getArgList().get(0).intValue()).getType().isNullable())) {
cnt = (long) oldGrpScan.getScanStats().getRecordCount();
} else if (aggCall.getArgList().size() == 1) {
// count(columnName) ==> Agg ( Scan )) ==> columnValueCount
int index = aggCall.getArgList().get(0);
String columnName = scan.getRowType().getFieldNames().get(index).toLowerCase();
cnt = oldGrpScan.getColumnValueCount(SchemaPath.getSimplePath(columnName));
} else {
return; // do nothing.
}
RelDataType scanRowType = getCountDirectScanRowType(agg.getCluster().getTypeFactory());
final ScanPrel newScan = ScanPrel.create(scan,
scan.getTraitSet().plus(Prel.DRILL_PHYSICAL).plus(DrillDistributionTrait.SINGLETON), getCountDirectScan(cnt),
scanRowType);
List<RexNode> exprs = Lists.newArrayList();
exprs.add(RexInputRef.of(0, scanRowType));