// pushAggCrossUnion builds a copy of agg that is placed directly on top of unionChild, one input
// of a union whose output schema is unionSchema. The aggregate functions are cloned, and in both
// their arguments and the group-by items every column of unionSchema is substituted with the
// column at the same position in unionChild's schema. The new aggregation gets a clone of agg's
// schema and agg's correlated flag, and is linked as the parent of unionChild.
func (a *aggPushDownSolver) pushAggCrossUnion(agg *Aggregation, unionSchema expression.Schema, unionChild LogicalPlan) LogicalPlan {
	newAgg := &Aggregation{
		AggFuncs:        make([]expression.AggregationFunction, 0, len(agg.AggFuncs)),
		GroupByItems:    make([]expression.Expression, 0, len(agg.GroupByItems)),
		baseLogicalPlan: newBaseLogicalPlan(Agg, a.alloc),
	}
	newAgg.SetSchema(agg.schema.Clone())
	newAgg.correlated = agg.correlated
	newAgg.initIDAndContext(a.ctx)
	// The substitution targets depend only on unionChild, so build them once rather than once
	// per aggregate argument and per group-by item.
	childExprs := expression.Schema2Exprs(unionChild.GetSchema())
	for _, aggFunc := range agg.AggFuncs {
		newAggFunc := aggFunc.Clone()
		oldArgs := newAggFunc.GetArgs()
		newArgs := make([]expression.Expression, 0, len(oldArgs))
		for _, arg := range oldArgs {
			newArgs = append(newArgs, expression.ColumnSubstitute(arg, unionSchema, childExprs))
		}
		newAggFunc.SetArgs(newArgs)
		newAgg.AggFuncs = append(newAgg.AggFuncs, newAggFunc)
	}
	for _, gbyExpr := range agg.GroupByItems {
		newExpr := expression.ColumnSubstitute(gbyExpr, unionSchema, childExprs)
		newAgg.GroupByItems = append(newAgg.GroupByItems, newExpr)
	}
	newAgg.collectGroupByColumns()
	newAgg.SetChildren(unionChild)
	unionChild.SetParents(newAgg)
	return newAgg
}
// makeNewAgg builds a new Aggregation grouped by gbyCols that evaluates the partial part of every
// function in aggFuncs, as produced by decompose. Note that decompose rewrites each aggFunc passed
// in here in place: its arguments become this aggregation's output columns and its mode becomes
// final mode. For every group-by column a firstrow function is appended, so the column also
// appears in the output schema.
func (a *aggPushDownSolver) makeNewAgg(aggFuncs []expression.AggregationFunction, gbyCols []*expression.Column) *Aggregation {
	agg := &Aggregation{
		GroupByItems:    expression.Schema2Exprs(gbyCols),
		baseLogicalPlan: newBaseLogicalPlan(Agg, a.alloc),
		groupByCols:     gbyCols,
	}
	agg.initIDAndContext(a.ctx)
	var newAggFuncs []expression.AggregationFunction
	// The schema receives one column per decomposed function plus one per group-by column.
	schema := make(expression.Schema, 0, len(aggFuncs)+len(gbyCols))
	for _, aggFunc := range aggFuncs {
		// Inspect the original arguments BEFORE decompose: decompose replaces aggFunc's
		// arguments with freshly created output columns of this aggregation, so checking
		// afterwards would never see the real (possibly correlated) argument expressions.
		for _, arg := range aggFunc.GetArgs() {
			agg.correlated = agg.correlated || arg.IsCorrelated()
		}
		var newFuncs []expression.AggregationFunction
		newFuncs, schema = a.decompose(aggFunc, schema, agg.GetID())
		newAggFuncs = append(newAggFuncs, newFuncs...)
	}
	for _, gbyCol := range gbyCols {
		firstRow := expression.NewAggFunction(ast.AggFuncFirstRow, []expression.Expression{gbyCol.Clone()}, false)
		newAggFuncs = append(newAggFuncs, firstRow)
		schema = append(schema, gbyCol.Clone().(*expression.Column))
	}
	agg.AggFuncs = newAggFuncs
	agg.SetSchema(schema)
	return agg
}
// decompose splits aggFunc into a partial part and a final part. The partial part is a clone of
// aggFunc and is returned to the caller. The final part is produced by rewriting aggFunc itself
// in place. Partial mode and complete mode currently behave the same, so the returned clones are
// not given an explicit mode.
//
// For each returned clone a new column is appended to schema. The column is named
// "join_agg_<position>" (the name only aids debugging), carries id as FromID and uses the clone's
// return type. aggFunc's arguments are then replaced by these new columns and its mode is set to
// expression.FinalMode. The extended schema is returned together with the clones.
func (a *aggPushDownSolver) decompose(aggFunc expression.AggregationFunction, schema expression.Schema, id string) ([]expression.AggregationFunction, expression.Schema) {
	// A slice because a function such as avg would need several partial functions (sum and
	// count); that case is not handled yet, so there is always exactly one.
	partials := []expression.AggregationFunction{aggFunc.Clone()}
	firstNew := len(schema)
	for _, partial := range partials {
		pos := len(schema)
		col := &expression.Column{
			ColName:  model.NewCIStr(fmt.Sprintf("join_agg_%d", pos)), // useless but for debug
			FromID:   id,
			Position: pos,
			RetType:  partial.GetType(),
		}
		schema = append(schema, col)
	}
	finalArgs := expression.Schema2Exprs(schema[firstNew:])
	aggFunc.SetArgs(finalArgs)
	aggFunc.SetMode(expression.FinalMode)
	return partials, schema
}
// PredicatePushDown implements LogicalPlan PredicatePushDown interface.
//
// Each predicate is rewritten for every select of the union by replacing the union's output
// columns with that select's output columns. The substitution is applied to cond.DeepCopy(), not
// to cond itself. The rewritten predicates are then pushed into the select. Any conditions the
// select hands back are passed to addSelection between the union and that select. No predicate is
// returned to the caller, so ret is always nil.
func (p *NewUnion) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, err error) {
	for _, sel := range p.Selects {
		unionSchema := p.GetSchema()
		selExprs := expression.Schema2Exprs(sel.GetSchema())
		pushed := make([]expression.Expression, 0, len(predicates))
		for _, cond := range predicates {
			pushed = append(pushed, columnSubstitute(cond.DeepCopy(), unionSchema, selExprs))
		}
		remaining, pushErr := sel.PredicatePushDown(pushed)
		if pushErr != nil {
			return nil, errors.Trace(pushErr)
		}
		if len(remaining) > 0 {
			addSelection(p, sel, remaining, p.allocator)
		}
	}
	return nil, nil
}
// Example #5
// PredicatePushDown implements LogicalPlan PredicatePushDown interface.
//
// Each predicate is rewritten for every child of the union with expression.ColumnSubstitute,
// mapping the union's output columns to the child's output columns. The rewritten predicates are
// then pushed into the child. Conditions a child hands back are passed to addSelection between
// the union and that child. On success no predicate is returned (ret is nil) and the union itself
// is the resulting plan.
func (p *Union) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, retPlan LogicalPlan, err error) {
	for _, child := range p.children {
		childExprs := expression.Schema2Exprs(child.GetSchema())
		rewritten := make([]expression.Expression, 0, len(predicates))
		for _, cond := range predicates {
			rewritten = append(rewritten, expression.ColumnSubstitute(cond, p.GetSchema(), childExprs))
		}
		childPlan := child.(LogicalPlan)
		// NOTE(review): the plan returned by the child is discarded and the union keeps its
		// current child; confirm children never replace themselves during push down.
		remaining, _, pushErr := childPlan.PredicatePushDown(rewritten)
		if pushErr != nil {
			return nil, nil, errors.Trace(pushErr)
		}
		if len(remaining) != 0 {
			addSelection(p, childPlan, remaining, p.allocator)
		}
	}
	return nil, p, nil
}
// Example #6
// predicatePushDown pushes predicates coming from p's parent as far down the plan tree rooted at p
// as it can. It returns the predicates that could not be pushed into p and must still be evaluated
// above it. Conditions handed back by a child are attached via b.addSelection between p and that
// child. Aggregation, Simple and Apply plans, and unknown plan types, are not handled yet: they
// absorb nothing and return every predicate unchanged.
func (b *planBuilder) predicatePushDown(p Plan, predicates []expression.Expression) (ret []expression.Expression, err error) {
	switch v := p.(type) {
	case *NewTableScan:
		return predicates, nil
	case *Selection:
		// Push the selection's own conditions together with the incoming predicates.
		conditions := v.Conditions
		retConditions, err1 := b.predicatePushDown(p.GetChildByIndex(0), append(conditions, predicates...))
		if err1 != nil {
			return nil, errors.Trace(err1)
		}
		if len(retConditions) > 0 {
			v.Conditions = retConditions
		} else {
			// Everything was pushed below, so the selection is no longer needed. A selection
			// without parents is left in place.
			if len(p.GetParents()) == 0 {
				return ret, nil
			}
			err1 = RemovePlan(p)
			if err1 != nil {
				return nil, errors.Trace(err1)
			}
		}
		return
	case *Join:
		//TODO: add null rejecter.
		var leftCond, rightCond []expression.Expression
		leftPlan := v.GetChildByIndex(0)
		rightPlan := v.GetChildByIndex(1)
		equalCond, leftPushCond, rightPushCond, otherCond := extractOnCondition(predicates, leftPlan, rightPlan)
		switch v.JoinType {
		case LeftOuterJoin:
			// Only predicates on the left child go below; the join's own right-side
			// conditions are pushed into the right child, everything else is returned.
			rightCond = v.RightConditions
			leftCond = leftPushCond
			ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...)
			ret = append(ret, rightPushCond...)
		case RightOuterJoin:
			// Mirror image of the left outer join case.
			leftCond = v.LeftConditions
			rightCond = rightPushCond
			ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...)
			ret = append(ret, leftPushCond...)
		default:
			leftCond = append(v.LeftConditions, leftPushCond...)
			rightCond = append(v.RightConditions, rightPushCond...)
		}
		leftRet, err1 := b.predicatePushDown(leftPlan, leftCond)
		if err1 != nil {
			return nil, errors.Trace(err1)
		}
		rightRet, err2 := b.predicatePushDown(rightPlan, rightCond)
		if err2 != nil {
			return nil, errors.Trace(err2)
		}
		if len(leftRet) > 0 {
			err2 = b.addSelection(p, leftPlan, leftRet)
			if err2 != nil {
				return nil, errors.Trace(err2)
			}
		}
		if len(rightRet) > 0 {
			err2 = b.addSelection(p, rightPlan, rightRet)
			if err2 != nil {
				return nil, errors.Trace(err2)
			}
		}
		if v.JoinType == InnerJoin {
			// For inner joins the equal/other conditions are absorbed by the join itself.
			v.EqualConditions = append(v.EqualConditions, equalCond...)
			v.OtherConditions = append(v.OtherConditions, otherCond...)
		}
		return
	case *Projection:
		if len(v.GetChildren()) == 0 {
			return predicates, nil
		}
		var push []expression.Expression
		for _, cond := range predicates {
			canSubstitute := true
			extractedCols, _ := extractColumn(cond, nil, nil)
			for _, col := range extractedCols {
				id := v.GetSchema().GetIndex(col)
				// A column this projection does not output (GetIndex presumably reports -1,
				// e.g. for an outer/correlated column) cannot be substituted, and indexing
				// v.Exprs with it would panic: keep the predicate above the projection.
				if id < 0 || id >= len(v.Exprs) {
					canSubstitute = false
					break
				}
				// Predicates over computed (scalar function) outputs stay above the projection.
				if _, ok := v.Exprs[id].(*expression.ScalarFunction); ok {
					canSubstitute = false
					break
				}
			}
			if canSubstitute {
				push = append(push, columnSubstitute(cond, v.GetSchema(), v.Exprs))
			} else {
				ret = append(ret, cond)
			}
		}
		restConds, err1 := b.predicatePushDown(v.GetChildByIndex(0), push)
		if err1 != nil {
			return nil, errors.Trace(err1)
		}
		if len(restConds) > 0 {
			err1 = b.addSelection(v, v.GetChildByIndex(0), restConds)
			if err1 != nil {
				return nil, errors.Trace(err1)
			}
		}
		return
	case *NewSort, *Limit, *Distinct:
		rest, err1 := b.predicatePushDown(p.GetChildByIndex(0), predicates)
		if err1 != nil {
			return nil, errors.Trace(err1)
		}
		if len(rest) > 0 {
			err1 = b.addSelection(p, p.GetChildByIndex(0), rest)
			if err1 != nil {
				return nil, errors.Trace(err1)
			}
		}
		return
	case *Union:
		for _, proj := range v.Selects {
			projExprs := expression.Schema2Exprs(proj.GetSchema())
			newExprs := make([]expression.Expression, 0, len(predicates))
			for _, cond := range predicates {
				// Substitute on a DeepCopy so each select receives its own copy of cond.
				newCond := columnSubstitute(cond.DeepCopy(), v.GetSchema(), projExprs)
				newExprs = append(newExprs, newCond)
			}
			retCond, err1 := b.predicatePushDown(proj, newExprs)
			if err1 != nil {
				return nil, errors.Trace(err1)
			}
			if len(retCond) != 0 {
				// addSelection can fail; the other cases propagate that error, so do it here too.
				if err1 = b.addSelection(v, proj, retCond); err1 != nil {
					return nil, errors.Trace(err1)
				}
			}
		}
		return
	//TODO: support aggregation, apply.
	case *Aggregation, *Simple, *Apply:
		return predicates, nil
	default:
		log.Warnf("Unknown Type %T in Predicate Pushdown", v)
		return predicates, nil
	}
}