// buildSelection builds a Selection plan on top of p from the WHERE clause.
//
// The clause is split by splitWhere and every piece is rewritten against the
// current plan; rewriting may replace the plan, so p is updated after each
// piece. The resulting expressions are flattened into CNF items. If rewriting
// produces no expression at all, the (possibly replaced) child plan is
// returned without a Selection on top. On a rewrite error, b.err is set and
// nil is returned.
func (b *planBuilder) buildSelection(p LogicalPlan, where ast.ExprNode, AggMapper map[*ast.AggregateFuncExpr]int) LogicalPlan {
	whereItems := splitWhere(where)
	sel := &Selection{baseLogicalPlan: newBaseLogicalPlan(Sel, b.allocator)}
	sel.initID()
	// The selection starts out as correlated as its input plan.
	sel.correlated = p.IsCorrelated()

	filters := make([]expression.Expression, 0, len(whereItems))
	for _, item := range whereItems {
		rewritten, child, isCorrelated, err := b.rewrite(item, p, AggMapper, false)
		if err != nil {
			b.err = err
			return nil
		}
		p = child
		if isCorrelated {
			sel.correlated = true
		}
		if rewritten == nil {
			continue
		}
		filters = append(filters, expression.SplitCNFItems(rewritten)...)
	}

	// Nothing left to filter on: the child plan is the result.
	if len(filters) == 0 {
		return p
	}

	sel.Conditions = filters
	sel.SetSchema(p.GetSchema().DeepCopy())
	addChild(sel, p)
	return sel
}
// Example #2
// 0

// buildJoin builds a Join plan from an AST join node.
//
// When the join has no right side, the plan of the left side is returned
// directly. Otherwise both sides are built, their schemas are concatenated,
// and the ON condition (if any) is rewritten and split into equal, left-only,
// right-only and other conditions by extractOnCondition.
//
// On failure b.err is set and nil is returned. This covers errors from
// building either child, errors from rewriting the ON expression, and a
// correlated ON expression, which is not supported.
func (b *planBuilder) buildJoin(join *ast.Join) LogicalPlan {
	if join.Right == nil {
		return b.buildResultSetNode(join.Left)
	}
	// Stop as soon as a child fails to build, so GetSchema is never called
	// on a plan that was not produced.
	leftPlan := b.buildResultSetNode(join.Left)
	if b.err != nil {
		return nil
	}
	rightPlan := b.buildResultSetNode(join.Right)
	if b.err != nil {
		return nil
	}

	newSchema := append(leftPlan.GetSchema().Clone(), rightPlan.GetSchema().Clone()...)
	joinPlan := &Join{baseLogicalPlan: newBaseLogicalPlan(Jn, b.allocator)}
	joinPlan.self = joinPlan
	joinPlan.initIDAndContext(b.ctx)
	joinPlan.SetSchema(newSchema)
	joinPlan.correlated = leftPlan.IsCorrelated() || rightPlan.IsCorrelated()

	// The join type is decided before the ON clause is handled, because the
	// cartesian-join check below reads joinPlan.JoinType.
	switch join.Tp {
	case ast.LeftJoin:
		joinPlan.JoinType = LeftOuterJoin
		// One default value per column of the inner (right) side.
		joinPlan.DefaultValues = make([]types.Datum, len(rightPlan.GetSchema()))
	case ast.RightJoin:
		joinPlan.JoinType = RightOuterJoin
		// One default value per column of the inner (left) side.
		joinPlan.DefaultValues = make([]types.Datum, len(leftPlan.GetSchema()))
	default:
		joinPlan.JoinType = InnerJoin
	}

	if join.On != nil {
		onExpr, _, err := b.rewrite(join.On.Expr, joinPlan, nil, false)
		if err != nil {
			b.err = err
			return nil
		}
		if onExpr.IsCorrelated() {
			b.err = errors.New("ON condition doesn't support subqueries yet")
			return nil
		}
		onCondition := expression.SplitCNFItems(onExpr)
		eqCond, leftCond, rightCond, otherCond := extractOnCondition(onCondition, leftPlan, rightPlan)
		joinPlan.EqualConditions = eqCond
		joinPlan.LeftConditions = leftCond
		joinPlan.RightConditions = rightCond
		joinPlan.OtherConditions = otherCond
	} else if joinPlan.JoinType == InnerJoin {
		// An inner join without any ON condition is a cartesian product.
		joinPlan.cartesianJoin = true
	}

	addChild(joinPlan, leftPlan)
	addChild(joinPlan, rightPlan)
	return joinPlan
}
// handleInSubquery rewrites "expr [NOT] IN (subquery)".
//
// The left operand is visited first and taken from the top of er.ctxStack.
// The subquery is then built, and an equality between the left operand and
// the subquery's output (a single column, or a row function over all its
// columns) is constructed:
//
//   - a in (subq) is rewritten as a = any(subq).
//   - a not in (subq) is rewritten as a != all(subq).
//
// A non-correlated subquery becomes a semi join; a correlated one becomes an
// apply whose last schema column tells whether the condition matched.
//
// The returned bool is always true. Errors are reported through er.err.
func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, bool) {
	asScalar := er.asScalar
	er.asScalar = true
	v.Expr.Accept(er)
	if er.err != nil {
		return v, true
	}
	lexpr := er.ctxStack[len(er.ctxStack)-1]
	subq, ok := v.Sel.(*ast.SubqueryExpr)
	if !ok {
		er.err = errors.Errorf("Unknown compare type %T.", v.Sel)
		return v, true
	}
	np, outerSchema := er.buildSubquery(subq)
	if er.err != nil {
		return v, true
	}
	// Both sides of IN must have the same number of columns.
	if getRowLen(lexpr) != len(np.GetSchema()) {
		er.err = errors.Errorf("Operand should contain %d column(s)", getRowLen(lexpr))
		return v, true
	}
	var rexpr expression.Expression
	if len(np.GetSchema()) == 1 {
		rexpr = np.GetSchema()[0].DeepCopy()
	} else {
		// Multi-column subquery: compare against a row built from all columns.
		args := make([]expression.Expression, 0, len(np.GetSchema()))
		for _, col := range np.GetSchema() {
			args = append(args, col.DeepCopy())
		}
		rexpr, er.err = expression.NewFunction(ast.RowFunc, nil, args...)
		if er.err != nil {
			er.err = errors.Trace(er.err)
			return v, true
		}
	}
	checkCondition, err := constructBinaryOpFunction(lexpr, rexpr, ast.EQ)
	// Check the error before checkCondition is used by either branch below.
	if err != nil {
		er.err = errors.Trace(err)
		return v, true
	}
	if !np.IsCorrelated() {
		er.p = er.b.buildSemiJoin(er.p, np, expression.SplitCNFItems(checkCondition), asScalar, v.Not)
		if asScalar {
			// The result is the last column of the new plan's schema.
			col := er.p.GetSchema()[len(er.p.GetSchema())-1]
			er.ctxStack[len(er.ctxStack)-1] = col
		} else {
			// Not used as a value: drop the left operand from the stack.
			er.ctxStack = er.ctxStack[:len(er.ctxStack)-1]
		}
		return v, true
	}
	if v.Not {
		checkCondition, err = expression.NewFunction(ast.UnaryNot, &v.Type, checkCondition)
		if err != nil {
			er.err = errors.Trace(err)
			return v, true
		}
	}
	er.p = er.b.buildApply(er.p, np, outerSchema, &ApplyConditionChecker{Condition: checkCondition, All: v.Not})
	if er.p.IsCorrelated() {
		er.correlated = true
	}
	// The parent expression only uses the last column in the schema, which represents whether the condition is matched.
	er.ctxStack[len(er.ctxStack)-1] = er.p.GetSchema()[len(er.p.GetSchema())-1]
	return v, true
}
// propagateConstant propagates constant values through the equality and
// inequality predicates of conditions.
//
// It runs in two phases: constants from "column = constant" predicates are
// substituted into the other conditions, then inequality predicates between a
// column and a constant are copied to every column known to be equal to that
// column.
//
// The input slice is modified in place; the returned slice must be used
// instead of the argument. Predicates produced by the second phase are
// appended while ranging over a map, so their relative order is unspecified.
func propagateConstant(conditions []expression.Expression) []expression.Expression {
	if len(conditions) == 0 {
		return conditions
	}
	// Propagate constants in equality predicates.
	// e.g. for condition: "a = b and b = c and c = a and a = 1";
	// we propagate constant as the following step:
	// first: "1 = b and b = c and c = 1 and a = 1";
	// next:  "1 = b and 1 = c and c = 1 and a = 1";
	// next:  "1 = b and 1 = c and 1 = 1 and a = 1";
	// next:  "1 = b and 1 = c and a = 1";

	// e.g. for condition: "a = b and b = c and b = 2 and a = 1";
	// we propagate constant as the following step:
	// first: "a = 2 and 2 = c and b = 2 and a = 1";
	// next:  "a = 2 and 2 = c and b = 2 and 2 = 1";
	// next:  "0"
	isSource := make([]bool, len(conditions)) // conditions already processed or used as a constant source
	type transitiveEqualityPredicate map[string]*expression.Constant // transitive equality predicates between one column and one constant
	for {
		equalities := make(transitiveEqualityPredicate)
		// Each round picks at most one new "column = constant" predicate.
		for i, getOneEquality := 0, false; i < len(conditions) && !getOneEquality; i++ {
			if isSource[i] {
				continue
			}
			expr, ok := conditions[i].(*expression.ScalarFunction)
			if !ok {
				continue
			}
			// Process nested OR/AND conditions recursively, item by item.
			switch expr.FuncName.L {
			case ast.OrOr:
				dnfItems := expression.SplitDNFItems(conditions[i])
				newItems := make([]expression.Expression, 0, len(dnfItems))
				for _, item := range dnfItems {
					newItems = append(newItems, propagateConstant([]expression.Expression{item})...)
				}
				conditions[i] = expression.ComposeDNFCondition(newItems)
				isSource[i] = true
			case ast.AndAnd:
				newItems := propagateConstant(expression.SplitCNFItems(conditions[i]))
				conditions[i] = expression.ComposeCNFCondition(newItems)
				isSource[i] = true
			case ast.EQ:
				var (
					col *expression.Column
					val *expression.Constant
				)
				leftConst, leftIsConst := expr.Args[0].(*expression.Constant)
				rightConst, rightIsConst := expr.Args[1].(*expression.Constant)
				leftCol, leftIsCol := expr.Args[0].(*expression.Column)
				rightCol, rightIsCol := expr.Args[1].(*expression.Column)
				if rightIsConst && leftIsCol {
					col = leftCol
					val = rightConst
				} else if leftIsConst && rightIsCol {
					col = rightCol
					val = leftConst
				} else {
					continue
				}
				equalities[string(col.HashCode())] = val
				isSource[i] = true
				getOneEquality = true
			}
		}
		if len(equalities) == 0 {
			break
		}
		for i := range conditions {
			if !isSource[i] {
				conditions[i] = constantSubstitute(equalities, conditions[i])
			}
		}
	}

	// Propagate transitive inequality predicates.
	// e.g. for conditions "a = b and c = d and a = c and g = h and b > 0 and e != 0 and g like 'abc'",
	//     we propagate constant as the following step:
	// 1. build multiple equality predicates (mep):
	//    =(a, b, c, d), =(g, h).
	// 2. extract inequality predicates between one constant and one column,
	//    and file them under the root column of a multiple equality predicate:
	//    b > 0, e != 0, g like 'abc' ==> a > 0, g like 'abc'.
	//    ATTENTION: here column 'e' doesn't belong to any mep, so we skip "e != 0".
	// 3. propagate constants in these inequality predicates, and we finally get:
	//    "a = b and c = d and a = c and g = h and e != 0 and a > 0 and b > 0 and c > 0 and d > 0 and g like 'abc' and h like 'abc'".
	multipleEqualities := make(map[*expression.Column]*expression.Column)
	for _, cond := range conditions { // build multiple equality predicates.
		expr, ok := cond.(*expression.ScalarFunction)
		if !ok || expr.FuncName.L != ast.EQ {
			continue
		}
		left, ok1 := expr.Args[0].(*expression.Column)
		right, ok2 := expr.Args[1].(*expression.Column)
		if ok1 && ok2 {
			UnionColumns(left, right, multipleEqualities)
		}
	}
	if len(multipleEqualities) == 0 {
		return conditions
	}

	// inequalityFactor is one extracted "column op constant" (or
	// "constant op column") predicate, kept so it can be rebuilt for every
	// equal column.
	type inequalityFactor struct {
		funcName    string
		source      *expression.Column    // column the predicate was written on
		origin      expression.Expression // the predicate exactly as it was found
		val         *expression.Constant
		escape      *expression.Constant // third argument of LIKE, nil otherwise
		constOnLeft bool                 // the constant was the left operand
	}
	// Transitive inequality predicates, keyed by the hash code of the column
	// that multipleEqualities maps the source column to.
	inequalities := make(map[string][]inequalityFactor)
	for i := 0; i < len(conditions); i++ { // extract inequality predicates.
		expr, ok := conditions[i].(*expression.ScalarFunction)
		if !ok {
			continue
		}
		switch expr.FuncName.L {
		case ast.LT, ast.GT, ast.LE, ast.GE, ast.NE, ast.Like:
		default:
			continue
		}
		if len(expr.Args) < 2 {
			continue
		}
		factor := inequalityFactor{funcName: expr.FuncName.L, origin: conditions[i]}
		leftConst, leftIsConst := expr.Args[0].(*expression.Constant)
		rightConst, rightIsConst := expr.Args[1].(*expression.Constant)
		leftCol, leftIsCol := expr.Args[0].(*expression.Column)
		rightCol, rightIsCol := expr.Args[1].(*expression.Column)
		switch {
		case rightIsConst && leftIsCol:
			factor.source, factor.val = leftCol, rightConst
		case leftIsConst && rightIsCol:
			// Remember the operand order: "0 < b" must not become "b < 0".
			factor.source, factor.val, factor.constOnLeft = rightCol, leftConst, true
		default:
			continue
		}
		equalCol, ok := multipleEqualities[factor.source]
		if !ok { // no need to propagate inequality predicates whose column is only equal to itself.
			continue
		}
		if factor.funcName == ast.Like {
			// LIKE takes a third argument; leave the predicate untouched
			// unless that argument is present and constant.
			if len(expr.Args) < 3 {
				continue
			}
			escape, isConst := expr.Args[2].(*expression.Constant)
			if !isConst {
				continue
			}
			factor.escape = escape
		}
		key := string(equalCol.HashCode())
		inequalities[key] = append(inequalities[key], factor)
		// The predicate is removed here and re-created for every equal
		// column (including its own) in the loop below.
		conditions = append(conditions[:i], conditions[i+1:]...)
		i--
	}
	if len(inequalities) == 0 {
		return conditions
	}

	for k, v := range multipleEqualities { // propagate constants in inequality predicates.
		for _, f := range inequalities[string(v.HashCode())] {
			args := make([]expression.Expression, 0, 3)
			if f.constOnLeft {
				args = append(args, f.val, k)
			} else {
				args = append(args, k, f.val)
			}
			if f.escape != nil {
				args = append(args, f.escape)
			}
			newFunc, err := expression.NewFunction(f.funcName, types.NewFieldType(mysql.TypeTiny), args...)
			if err != nil {
				// A derived predicate is only an optimization and may be
				// skipped, but the predicate on its own column was removed
				// above and must not be lost: put the original back.
				if k == f.source {
					conditions = append(conditions, f.origin)
				}
				continue
			}
			conditions = append(conditions, newFunc)
		}
	}
	return conditions
}