func (b *planBuilder) buildSelection(p LogicalPlan, where ast.ExprNode, AggMapper map[*ast.AggregateFuncExpr]int) LogicalPlan { conditions := splitWhere(where) expressions := make([]expression.Expression, 0, len(conditions)) selection := &Selection{baseLogicalPlan: newBaseLogicalPlan(Sel, b.allocator)} selection.initID() selection.correlated = p.IsCorrelated() for _, cond := range conditions { expr, np, correlated, err := b.rewrite(cond, p, AggMapper, false) if err != nil { b.err = err return nil } p = np selection.correlated = selection.correlated || correlated if expr != nil { expressions = append(expressions, expression.SplitCNFItems(expr)...) } } if len(expressions) == 0 { return p } selection.Conditions = expressions selection.SetSchema(p.GetSchema().DeepCopy()) addChild(selection, p) return selection }
func (b *planBuilder) buildJoin(join *ast.Join) LogicalPlan { if join.Right == nil { return b.buildResultSetNode(join.Left) } leftPlan := b.buildResultSetNode(join.Left) rightPlan := b.buildResultSetNode(join.Right) newSchema := append(leftPlan.GetSchema().Clone(), rightPlan.GetSchema().Clone()...) joinPlan := &Join{baseLogicalPlan: newBaseLogicalPlan(Jn, b.allocator)} joinPlan.self = joinPlan joinPlan.initIDAndContext(b.ctx) joinPlan.SetSchema(newSchema) joinPlan.correlated = leftPlan.IsCorrelated() || rightPlan.IsCorrelated() if join.On != nil { onExpr, _, err := b.rewrite(join.On.Expr, joinPlan, nil, false) if err != nil { b.err = err return nil } if onExpr.IsCorrelated() { b.err = errors.New("ON condition doesn't support subqueries yet") } onCondition := expression.SplitCNFItems(onExpr) eqCond, leftCond, rightCond, otherCond := extractOnCondition(onCondition, leftPlan, rightPlan) joinPlan.EqualConditions = eqCond joinPlan.LeftConditions = leftCond joinPlan.RightConditions = rightCond joinPlan.OtherConditions = otherCond } else if joinPlan.JoinType == InnerJoin { joinPlan.cartesianJoin = true } if join.Tp == ast.LeftJoin { joinPlan.JoinType = LeftOuterJoin joinPlan.DefaultValues = make([]types.Datum, len(rightPlan.GetSchema())) } else if join.Tp == ast.RightJoin { joinPlan.JoinType = RightOuterJoin joinPlan.DefaultValues = make([]types.Datum, len(leftPlan.GetSchema())) } else { joinPlan.JoinType = InnerJoin } addChild(joinPlan, leftPlan) addChild(joinPlan, rightPlan) return joinPlan }
func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, bool) { asScalar := er.asScalar er.asScalar = true v.Expr.Accept(er) if er.err != nil { return v, true } lexpr := er.ctxStack[len(er.ctxStack)-1] subq, ok := v.Sel.(*ast.SubqueryExpr) if !ok { er.err = errors.Errorf("Unknown compare type %T.", v.Sel) return v, true } np, outerSchema := er.buildSubquery(subq) if er.err != nil { return v, true } if getRowLen(lexpr) != len(np.GetSchema()) { er.err = errors.Errorf("Operand should contain %d column(s)", getRowLen(lexpr)) return v, true } var rexpr expression.Expression if len(np.GetSchema()) == 1 { rexpr = np.GetSchema()[0].DeepCopy() } else { args := make([]expression.Expression, 0, len(np.GetSchema())) for _, col := range np.GetSchema() { args = append(args, col.DeepCopy()) } rexpr, er.err = expression.NewFunction(ast.RowFunc, nil, args...) if er.err != nil { er.err = errors.Trace(er.err) return v, true } } // a in (subq) will be rewrited as a = any(subq). // a not in (subq) will be rewrited as a != all(subq). checkCondition, err := constructBinaryOpFunction(lexpr, rexpr, ast.EQ) if !np.IsCorrelated() { er.p = er.b.buildSemiJoin(er.p, np, expression.SplitCNFItems(checkCondition), asScalar, v.Not) if asScalar { col := er.p.GetSchema()[len(er.p.GetSchema())-1] er.ctxStack[len(er.ctxStack)-1] = col } else { er.ctxStack = er.ctxStack[:len(er.ctxStack)-1] } return v, true } if v.Not { checkCondition, _ = expression.NewFunction(ast.UnaryNot, &v.Type, checkCondition) } if err != nil { er.err = errors.Trace(err) return v, true } er.p = er.b.buildApply(er.p, np, outerSchema, &ApplyConditionChecker{Condition: checkCondition, All: v.Not}) if er.p.IsCorrelated() { er.correlated = true } // The parent expression only use the last column in schema, which represents whether the condition is matched. er.ctxStack[len(er.ctxStack)-1] = er.p.GetSchema()[len(er.p.GetSchema())-1] return v, true }
// propagateConstant propagate constant values of equality predicates and inequality predicates in a condition. func propagateConstant(conditions []expression.Expression) []expression.Expression { if len(conditions) == 0 { return conditions } // Propagate constants in equality predicates. // e.g. for condition: "a = b and b = c and c = a and a = 1"; // we propagate constant as the following step: // first: "1 = b and b = c and c = 1 and a = 1"; // next: "1 = b and 1 = c and c = 1 and a = 1"; // next: "1 = b and 1 = c and 1 = 1 and a = 1"; // next: "1 = b and 1 = c and a = 1"; // e.g for condition: "a = b and b = c and b = 2 and a = 1"; // we propagate constant as the following step: // first: "a = 2 and 2 = c and b = 2 and a = 1"; // next: "a = 2 and 2 = c and b = 2 and 2 = 1"; // next: "0" isSource := make([]bool, len(conditions)) type transitiveEqualityPredicate map[string]*expression.Constant // transitive equality predicates between one column and one constant for { equalities := make(transitiveEqualityPredicate, 0) for i, getOneEquality := 0, false; i < len(conditions) && !getOneEquality; i++ { if isSource[i] { continue } expr, ok := conditions[i].(*expression.ScalarFunction) if !ok { continue } // process the included OR conditions recursively to do the same for CNF item. switch expr.FuncName.L { case ast.OrOr: expressions := expression.SplitDNFItems(conditions[i]) newExpression := make([]expression.Expression, 0) for _, v := range expressions { newExpression = append(newExpression, propagateConstant([]expression.Expression{v})...) } conditions[i] = expression.ComposeDNFCondition(newExpression) isSource[i] = true case ast.AndAnd: newExpression := propagateConstant(expression.SplitCNFItems(conditions[i])) conditions[i] = expression.ComposeCNFCondition(newExpression) isSource[i] = true case ast.EQ: var ( col *expression.Column val *expression.Constant ) leftConst, leftIsConst := expr.Args[0].(*expression.Constant) rightConst, rightIsConst := expr.Args[1].(*expression.Constant) leftCol, leftIsCol := expr.Args[0].(*expression.Column) rightCol, rightIsCol := expr.Args[1].(*expression.Column) if rightIsConst && leftIsCol { col = leftCol val = rightConst } else if leftIsConst && rightIsCol { col = rightCol val = leftConst } else { continue } equalities[string(col.HashCode())] = val isSource[i] = true getOneEquality = true } } if len(equalities) == 0 { break } for i := 0; i < len(conditions); i++ { if isSource[i] { continue } if len(equalities) != 0 { conditions[i] = constantSubstitute(equalities, conditions[i]) } } } // Propagate transitive inequality predicates. // e.g for conditions "a = b and c = d and a = c and g = h and b > 0 and e != 0 and g like 'abc'", // we propagate constant as the following step: // 1. build multiple equality predicates(mep): // =(a, b, c, d), =(g, h). // 2. extract inequality predicates between one constant and one column, // and rewrite them using the root column of a multiple equality predicate: // b > 0, e != 0, g like 'abc' ==> a > 0, g like 'abc'. // ATTENTION: here column 'e' doesn't belong to any mep, so we skip "e != 0". // 3. propagate constants in these inequality predicates, and we finally get: // "a = b and c = d and a = c and e = f and g = h and e != 0 and a > 0 and b > 0 and c > 0 and d > 0 and g like 'abc' and h like 'abc' ". multipleEqualities := make(map[*expression.Column]*expression.Column, 0) for _, cond := range conditions { // build multiple equality predicates. expr, ok := cond.(*expression.ScalarFunction) if ok && expr.FuncName.L == ast.EQ { left, ok1 := expr.Args[0].(*expression.Column) right, ok2 := expr.Args[1].(*expression.Column) if ok1 && ok2 { UnionColumns(left, right, multipleEqualities) } } } if len(multipleEqualities) == 0 { return conditions } inequalityFuncs := map[string]string{ ast.LT: ast.LT, ast.GT: ast.GT, ast.LE: ast.LE, ast.GE: ast.GE, ast.NE: ast.NE, ast.Like: ast.Like, } type inequalityFactor struct { FuncName string Factor []*expression.Constant } type transitiveInEqualityPredicate map[string][]inequalityFactor // transitive inequality predicates between one column and one constant. inequalities := make(transitiveInEqualityPredicate, 0) for i := 0; i < len(conditions); i++ { // extract inequality predicates. var ( column *expression.Column equalCol *expression.Column // the root column corresponding to a column in a multiple equality predicate. val *expression.Constant funcName string ) expr, ok := conditions[i].(*expression.ScalarFunction) if !ok { continue } funcName, ok = inequalityFuncs[expr.FuncName.L] if !ok { continue } leftConst, leftIsConst := expr.Args[0].(*expression.Constant) rightConst, rightIsConst := expr.Args[1].(*expression.Constant) leftCol, leftIsCol := expr.Args[0].(*expression.Column) rightCol, rightIsCol := expr.Args[1].(*expression.Column) if rightIsConst && leftIsCol { column = leftCol val = rightConst } else if leftIsConst && rightIsCol { column = rightCol val = leftConst } else { continue } equalCol, ok = multipleEqualities[column] if !ok { // no need to propagate inequality predicates whose column is only equal to itself. continue } colHashCode := string(equalCol.HashCode()) if funcName == ast.Like { // func 'LIKE' need 3 input arguments, so here we handle it alone. inequalities[colHashCode] = append(inequalities[colHashCode], inequalityFactor{FuncName: ast.Like, Factor: []*expression.Constant{val, expr.Args[2].(*expression.Constant)}}) } else { inequalities[colHashCode] = append(inequalities[colHashCode], inequalityFactor{FuncName: funcName, Factor: []*expression.Constant{val}}) } conditions = append(conditions[:i], conditions[i+1:]...) i-- } if len(inequalities) == 0 { return conditions } for k, v := range multipleEqualities { // propagate constants in inequality predicates. for _, x := range inequalities[string(v.HashCode())] { funcName, factors := x.FuncName, x.Factor if funcName == ast.Like { for i := 0; i < len(factors); i += 2 { newFunc, _ := expression.NewFunction(funcName, types.NewFieldType(mysql.TypeTiny), k, factors[i], factors[i+1]) conditions = append(conditions, newFunc) } } else { for i := 0; i < len(factors); i++ { newFunc, _ := expression.NewFunction(funcName, types.NewFieldType(mysql.TypeTiny), k, factors[i]) conditions = append(conditions, newFunc) i++ } } } } return conditions }