func (a *aggPushDownSolver) pushAggCrossUnion(agg *Aggregation, unionSchema expression.Schema, unionChild LogicalPlan) LogicalPlan { newAgg := &Aggregation{ AggFuncs: make([]expression.AggregationFunction, 0, len(agg.AggFuncs)), GroupByItems: make([]expression.Expression, 0, len(agg.GroupByItems)), baseLogicalPlan: newBaseLogicalPlan(Agg, a.alloc), } newAgg.SetSchema(agg.schema.Clone()) newAgg.correlated = agg.correlated newAgg.initIDAndContext(a.ctx) for _, aggFunc := range agg.AggFuncs { newAggFunc := aggFunc.Clone() newArgs := make([]expression.Expression, 0, len(newAggFunc.GetArgs())) for _, arg := range newAggFunc.GetArgs() { newArgs = append(newArgs, expression.ColumnSubstitute(arg, unionSchema, expression.Schema2Exprs(unionChild.GetSchema()))) } newAggFunc.SetArgs(newArgs) newAgg.AggFuncs = append(newAgg.AggFuncs, newAggFunc) } for _, gbyExpr := range agg.GroupByItems { newExpr := expression.ColumnSubstitute(gbyExpr, unionSchema, expression.Schema2Exprs(unionChild.GetSchema())) newAgg.GroupByItems = append(newAgg.GroupByItems, newExpr) } newAgg.collectGroupByColumns() newAgg.SetChildren(unionChild) unionChild.SetParents(newAgg) return newAgg }
func (a *aggPushDownSolver) makeNewAgg(aggFuncs []expression.AggregationFunction, gbyCols []*expression.Column) *Aggregation { agg := &Aggregation{ GroupByItems: expression.Schema2Exprs(gbyCols), baseLogicalPlan: newBaseLogicalPlan(Agg, a.alloc), groupByCols: gbyCols, } agg.initIDAndContext(a.ctx) var newAggFuncs []expression.AggregationFunction schema := make(expression.Schema, 0, len(aggFuncs)) for _, aggFunc := range aggFuncs { var newFuncs []expression.AggregationFunction newFuncs, schema = a.decompose(aggFunc, schema, agg.GetID()) newAggFuncs = append(newAggFuncs, newFuncs...) for _, arg := range aggFunc.GetArgs() { agg.correlated = agg.correlated || arg.IsCorrelated() } } for _, gbyCol := range gbyCols { firstRow := expression.NewAggFunction(ast.AggFuncFirstRow, []expression.Expression{gbyCol.Clone()}, false) newAggFuncs = append(newAggFuncs, firstRow) schema = append(schema, gbyCol.Clone().(*expression.Column)) } agg.AggFuncs = newAggFuncs agg.SetSchema(schema) return agg }
// decompose splits an aggregate function to two parts: a final mode function and a partial mode function. Currently // there are no differences between partial mode and complete mode, so we can confuse them. func (a *aggPushDownSolver) decompose(aggFunc expression.AggregationFunction, schema expression.Schema, id string) ([]expression.AggregationFunction, expression.Schema) { // Result is a slice because avg should be decomposed to sum and count. Currently we don't process this case. result := []expression.AggregationFunction{aggFunc.Clone()} for _, aggFunc := range result { schema = append(schema, &expression.Column{ ColName: model.NewCIStr(fmt.Sprintf("join_agg_%d", len(schema))), // useless but for debug FromID: id, Position: len(schema), RetType: aggFunc.GetType(), }) } aggFunc.SetArgs(expression.Schema2Exprs(schema[len(schema)-len(result):])) aggFunc.SetMode(expression.FinalMode) return result, schema }
// PredicatePushDown implements LogicalPlan PredicatePushDown interface. func (p *NewUnion) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, err error) { for _, proj := range p.Selects { newExprs := make([]expression.Expression, 0, len(predicates)) for _, cond := range predicates { newCond := columnSubstitute(cond.DeepCopy(), p.GetSchema(), expression.Schema2Exprs(proj.GetSchema())) newExprs = append(newExprs, newCond) } retCond, err := proj.PredicatePushDown(newExprs) if err != nil { return nil, errors.Trace(err) } if len(retCond) != 0 { addSelection(p, proj, retCond, p.allocator) } } return }
// PredicatePushDown implements LogicalPlan PredicatePushDown interface. func (p *Union) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, retPlan LogicalPlan, err error) { retPlan = p for _, proj := range p.children { newExprs := make([]expression.Expression, 0, len(predicates)) for _, cond := range predicates { newCond := expression.ColumnSubstitute(cond, p.GetSchema(), expression.Schema2Exprs(proj.GetSchema())) newExprs = append(newExprs, newCond) } retCond, _, err := proj.(LogicalPlan).PredicatePushDown(newExprs) if err != nil { return nil, nil, errors.Trace(err) } if len(retCond) != 0 { addSelection(p, proj.(LogicalPlan), retCond, p.allocator) } } return }
// predicatePushDown applies predicate push down to all kinds of plans, except aggregation and union. func (b *planBuilder) predicatePushDown(p Plan, predicates []expression.Expression) (ret []expression.Expression, err error) { switch v := p.(type) { case *NewTableScan: return predicates, nil case *Selection: conditions := v.Conditions retConditions, err1 := b.predicatePushDown(p.GetChildByIndex(0), append(conditions, predicates...)) if err1 != nil { return nil, errors.Trace(err1) } if len(retConditions) > 0 { v.Conditions = retConditions } else { if len(p.GetParents()) == 0 { return ret, nil } err1 = RemovePlan(p) if err1 != nil { return nil, errors.Trace(err1) } } return case *Join: //TODO: add null rejecter. var leftCond, rightCond []expression.Expression leftPlan := v.GetChildByIndex(0) rightPlan := v.GetChildByIndex(1) equalCond, leftPushCond, rightPushCond, otherCond := extractOnCondition(predicates, leftPlan, rightPlan) if v.JoinType == LeftOuterJoin { rightCond = v.RightConditions leftCond = leftPushCond ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) ret = append(ret, rightPushCond...) } else if v.JoinType == RightOuterJoin { leftCond = v.LeftConditions rightCond = rightPushCond ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) ret = append(ret, leftPushCond...) } else { leftCond = append(v.LeftConditions, leftPushCond...) rightCond = append(v.RightConditions, rightPushCond...) } leftRet, err1 := b.predicatePushDown(leftPlan, leftCond) if err1 != nil { return nil, errors.Trace(err1) } rightRet, err2 := b.predicatePushDown(rightPlan, rightCond) if err2 != nil { return nil, errors.Trace(err2) } if len(leftRet) > 0 { err2 = b.addSelection(p, leftPlan, leftRet) if err2 != nil { return nil, errors.Trace(err2) } } if len(rightRet) > 0 { err2 = b.addSelection(p, rightPlan, rightRet) if err2 != nil { return nil, errors.Trace(err2) } } if v.JoinType == InnerJoin { v.EqualConditions = append(v.EqualConditions, equalCond...) v.OtherConditions = append(v.OtherConditions, otherCond...) } return case *Projection: if len(v.GetChildren()) == 0 { return predicates, nil } var push []expression.Expression for _, cond := range predicates { canSubstitute := true extractedCols, _ := extractColumn(cond, nil, nil) for _, col := range extractedCols { id := v.GetSchema().GetIndex(col) if _, ok := v.Exprs[id].(*expression.ScalarFunction); ok { canSubstitute = false break } } if canSubstitute { push = append(push, columnSubstitute(cond, v.GetSchema(), v.Exprs)) } else { ret = append(ret, cond) } } restConds, err1 := b.predicatePushDown(v.GetChildByIndex(0), push) if err1 != nil { return nil, errors.Trace(err1) } if len(restConds) > 0 { err1 = b.addSelection(v, v.GetChildByIndex(0), restConds) if err1 != nil { return nil, errors.Trace(err1) } } return case *NewSort, *Limit, *Distinct: rest, err1 := b.predicatePushDown(p.GetChildByIndex(0), predicates) if err1 != nil { return nil, errors.Trace(err1) } if len(rest) > 0 { err1 = b.addSelection(p, p.GetChildByIndex(0), rest) if err1 != nil { return nil, errors.Trace(err1) } } return case *Union: for _, proj := range v.Selects { newExprs := make([]expression.Expression, 0, len(predicates)) for _, cond := range predicates { newCond := columnSubstitute(cond.DeepCopy(), v.GetSchema(), expression.Schema2Exprs(proj.GetSchema())) newExprs = append(newExprs, newCond) } retCond, err := b.predicatePushDown(proj, newExprs) if err != nil { return nil, errors.Trace(err) } if len(retCond) != 0 { b.addSelection(v, proj, retCond) } } return //TODO: support aggregation, apply. case *Aggregation, *Simple, *Apply: return predicates, nil default: log.Warnf("Unknown Type %T in Predicate Pushdown", v) return predicates, nil } }