func (a *aggPushDownSolver) pushAggCrossUnion(agg *Aggregation, unionSchema expression.Schema, unionChild LogicalPlan) LogicalPlan { newAgg := &Aggregation{ AggFuncs: make([]expression.AggregationFunction, 0, len(agg.AggFuncs)), GroupByItems: make([]expression.Expression, 0, len(agg.GroupByItems)), baseLogicalPlan: newBaseLogicalPlan(Agg, a.alloc), } newAgg.SetSchema(agg.schema.Clone()) newAgg.correlated = agg.correlated newAgg.initIDAndContext(a.ctx) for _, aggFunc := range agg.AggFuncs { newAggFunc := aggFunc.Clone() newArgs := make([]expression.Expression, 0, len(newAggFunc.GetArgs())) for _, arg := range newAggFunc.GetArgs() { newArgs = append(newArgs, expression.ColumnSubstitute(arg, unionSchema, expression.Schema2Exprs(unionChild.GetSchema()))) } newAggFunc.SetArgs(newArgs) newAgg.AggFuncs = append(newAgg.AggFuncs, newAggFunc) } for _, gbyExpr := range agg.GroupByItems { newExpr := expression.ColumnSubstitute(gbyExpr, unionSchema, expression.Schema2Exprs(unionChild.GetSchema())) newAgg.GroupByItems = append(newAgg.GroupByItems, newExpr) } newAgg.collectGroupByColumns() newAgg.SetChildren(unionChild) unionChild.SetParents(newAgg) return newAgg }
// aggPushDown tries to push down aggregate functions to join paths. func (a *aggPushDownSolver) aggPushDown(p LogicalPlan) { if agg, ok := p.(*Aggregation); ok { child := agg.GetChildByIndex(0) if join, ok1 := child.(*Join); ok1 && a.checkValidJoin(join) { if valid, leftAggFuncs, rightAggFuncs, leftGbyCols, rightGbyCols := a.splitAggFuncsAndGbyCols(agg, join); valid { var lChild, rChild LogicalPlan // If there exist count or sum functions in left join path, we can't push any // aggregate function into right join path. rightInvalid := a.checkAnyCountAndSum(leftAggFuncs) leftInvalid := a.checkAnyCountAndSum(rightAggFuncs) if rightInvalid { rChild = join.GetChildByIndex(1).(LogicalPlan) } else { rChild = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1) } if leftInvalid { lChild = join.GetChildByIndex(0).(LogicalPlan) } else { lChild = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0) } join.SetChildren(lChild, rChild) lChild.SetParents(join) rChild.SetParents(join) join.SetSchema(append(lChild.GetSchema().Clone(), rChild.GetSchema().Clone()...)) } } else if proj, ok1 := child.(*Projection); ok1 { // TODO: This optimization is not always reasonable. We have not supported pushing projection to kv layer yet, // so we must do this optimization. for i, gbyItem := range agg.GroupByItems { agg.GroupByItems[i] = expression.ColumnSubstitute(gbyItem, proj.schema, proj.Exprs) } agg.collectGroupByColumns() for _, aggFunc := range agg.AggFuncs { newArgs := make([]expression.Expression, 0, len(aggFunc.GetArgs())) for _, arg := range aggFunc.GetArgs() { newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs)) } aggFunc.SetArgs(newArgs) } projChild := proj.children[0] agg.SetChildren(projChild) projChild.SetParents(agg) } else if union, ok1 := child.(*Union); ok1 { pushedAgg := a.makeNewAgg(agg.AggFuncs, agg.groupByCols) newChildren := make([]Plan, 0, len(union.children)) for _, child := range union.children { newChild := a.pushAggCrossUnion(pushedAgg, union.schema, child.(LogicalPlan)) newChildren = append(newChildren, newChild) newChild.SetParents(union) } union.SetChildren(newChildren...) union.SetSchema(pushedAgg.schema) } } for _, child := range p.GetChildren() { a.aggPushDown(child.(LogicalPlan)) } }
// PredicatePushDown implements LogicalPlan PredicatePushDown interface. func (p *Projection) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, retPlan LogicalPlan, err error) { retPlan = p var push []expression.Expression for _, cond := range predicates { canSubstitute := true extractedCols := expression.ExtractColumns(cond) for _, col := range extractedCols { id := p.GetSchema().GetIndex(col) if _, ok := p.Exprs[id].(*expression.ScalarFunction); ok { canSubstitute = false break } } if canSubstitute { push = append(push, expression.ColumnSubstitute(cond, p.GetSchema(), p.Exprs)) } else { ret = append(ret, cond) } } child := p.GetChildByIndex(0).(LogicalPlan) restConds, _, err1 := child.PredicatePushDown(push) if err1 != nil { return nil, nil, errors.Trace(err1) } if len(restConds) > 0 { err1 = addSelection(p, child, restConds, p.allocator) if err1 != nil { return nil, nil, errors.Trace(err1) } } return }
// PredicatePushDown implements LogicalPlan PredicatePushDown interface. func (p *Union) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, retPlan LogicalPlan, err error) { retPlan = p for _, proj := range p.children { newExprs := make([]expression.Expression, 0, len(predicates)) for _, cond := range predicates { newCond := expression.ColumnSubstitute(cond, p.GetSchema(), expression.Schema2Exprs(proj.GetSchema())) newExprs = append(newExprs, newCond) } retCond, _, err := proj.(LogicalPlan).PredicatePushDown(newExprs) if err != nil { return nil, nil, errors.Trace(err) } if len(retCond) != 0 { addSelection(p, proj.(LogicalPlan), retCond, p.allocator) } } return }
// PredicatePushDown implements LogicalPlan PredicatePushDown interface. func (p *Aggregation) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, retPlan LogicalPlan, err error) { retPlan = p var exprsOriginal []expression.Expression var condsToPush []expression.Expression for _, fun := range p.AggFuncs { exprsOriginal = append(exprsOriginal, fun.GetArgs()[0]) } for _, cond := range predicates { switch cond.(type) { case *expression.Constant: condsToPush = append(condsToPush, cond) // Consider SQL list "select sum(b) from t group by a having 1=0". "1=0" is a constant predicate which should be // retained and pushed down at the same time. Because we will get a wrong query result that contains one column // with value 0 rather than an empty query result. ret = append(ret, cond) case *expression.ScalarFunction: extractedCols := expression.ExtractColumns(cond) ok := true for _, col := range extractedCols { if p.getGbyColIndex(col) == -1 { ok = false break } } if ok { newFunc := expression.ColumnSubstitute(cond.Clone(), p.GetSchema(), exprsOriginal) condsToPush = append(condsToPush, newFunc) } else { ret = append(ret, cond) } default: ret = append(ret, cond) } } p.baseLogicalPlan.PredicatePushDown(condsToPush) return }