func (er *expressionRewriter) betweenToScalarFunc(v *ast.BetweenExpr) { stkLen := len(er.ctxStack) var op string var l, r *expression.ScalarFunction if v.Not { l, er.err = expression.NewFunction(ast.LT, v.Type, er.ctxStack[stkLen-3], er.ctxStack[stkLen-2]) if er.err == nil { r, er.err = expression.NewFunction(ast.GT, v.Type, er.ctxStack[stkLen-3].DeepCopy(), er.ctxStack[stkLen-1]) } op = ast.OrOr } else { l, er.err = expression.NewFunction(ast.GE, v.Type, er.ctxStack[stkLen-3], er.ctxStack[stkLen-2]) if er.err == nil { r, er.err = expression.NewFunction(ast.LE, v.Type, er.ctxStack[stkLen-3].DeepCopy(), er.ctxStack[stkLen-1]) } op = ast.AndAnd } if er.err != nil { er.err = errors.Trace(er.err) return } function, err := expression.NewFunction(op, v.Type, l, r) if err != nil { er.err = errors.Trace(err) return } er.ctxStack = er.ctxStack[:stkLen-3] er.ctxStack = append(er.ctxStack, function) }
func (er *expressionRewriter) betweenToExpression(v *ast.BetweenExpr) { stkLen := len(er.ctxStack) er.checkArgsOneColumn(er.ctxStack[stkLen-3:]...) if er.err != nil { return } var op string var l, r expression.Expression l, er.err = expression.NewFunction(ast.GE, &v.Type, er.ctxStack[stkLen-3], er.ctxStack[stkLen-2]) if er.err == nil { r, er.err = expression.NewFunction(ast.LE, &v.Type, er.ctxStack[stkLen-3].Clone(), er.ctxStack[stkLen-1]) } op = ast.AndAnd if er.err != nil { er.err = errors.Trace(er.err) return } function, err := expression.NewFunction(op, &v.Type, l, r) if err != nil { er.err = errors.Trace(err) return } if v.Not { function, err = expression.NewFunction(ast.UnaryNot, &v.Type, function) if err != nil { er.err = errors.Trace(err) return } } er.ctxStack = er.ctxStack[:stkLen-3] er.ctxStack = append(er.ctxStack, function) }
func (er *expressionRewriter) notToScalarFunc(b bool, op string, tp *types.FieldType, args ...expression.Expression) *expression.ScalarFunction { opFunc := expression.NewFunction(op, tp, args...) if !b { return opFunc } return expression.NewFunction(ast.UnaryNot, tp, opFunc) }
// handleScalarSubquery rewrites a scalar subquery and pushes its value onto
// ctxStack. The subquery plan is first wrapped by buildMaxOneRow.
//
// A correlated subquery is attached to er.p through buildApply. The pushed
// value is the last column of the Apply's schema when the subquery yields a
// single column. Otherwise it is a RowFunc over clones of the subquery's columns.
//
// An uncorrelated subquery is optimized (doOptimize) and evaluated immediately
// (EvalSubquery). Its result is pushed as a Constant, or as a RowFunc of
// Constants when there are several columns.
//
// It always returns (v, true). Failures are reported through er.err.
//
// NOTE(review): np.GetSchema() is deliberately re-read after buildApply and
// after doOptimize/EvalSubquery; those calls may change the plan's schema, so
// do not hoist the reads.
func (er *expressionRewriter) handleScalarSubquery(v *ast.SubqueryExpr) (ast.Node, bool) {
	np := er.buildSubquery(v)
	if er.err != nil {
		return v, true
	}
	np = er.b.buildMaxOneRow(np)
	if np.IsCorrelated() {
		er.p = er.b.buildApply(er.p, np, nil)
		if len(np.GetSchema()) > 1 {
			// Multi-column result: represent it as a row of the subquery's columns.
			newCols := make([]expression.Expression, 0, len(np.GetSchema()))
			for _, col := range np.GetSchema() {
				newCols = append(newCols, col.Clone())
			}
			expr, err := expression.NewFunction(ast.RowFunc, nil, newCols...)
			if err != nil {
				er.err = errors.Trace(err)
				return v, true
			}
			er.ctxStack = append(er.ctxStack, expr)
		} else {
			// Single column: reference the last column of the Apply's schema.
			er.ctxStack = append(er.ctxStack, er.p.GetSchema()[len(er.p.GetSchema())-1])
		}
		return v, true
	}
	// Uncorrelated: evaluate the subquery now and inline its result as constant(s).
	physicalPlan, err := doOptimize(np, er.b.ctx, er.b.allocator)
	if err != nil {
		er.err = errors.Trace(err)
		return v, true
	}
	d, err := EvalSubquery(physicalPlan, er.b.is, er.b.ctx)
	if err != nil {
		er.err = errors.Trace(err)
		return v, true
	}
	if len(np.GetSchema()) > 1 {
		newCols := make([]expression.Expression, 0, len(np.GetSchema()))
		for i, data := range d {
			newCols = append(newCols, &expression.Constant{
				Value:   data,
				RetType: np.GetSchema()[i].GetType()})
		}
		expr, err1 := expression.NewFunction(ast.RowFunc, nil, newCols...)
		if err1 != nil {
			er.err = errors.Trace(err1)
			return v, true
		}
		er.ctxStack = append(er.ctxStack, expr)
	} else {
		// NOTE(review): assumes EvalSubquery returns at least one datum here — confirm.
		er.ctxStack = append(er.ctxStack, &expression.Constant{
			Value:   d[0],
			RetType: np.GetSchema()[0].GetType(),
		})
	}
	return v, true
}
// caseToExpression pops the rewritten operands of a CASE expression from
// ctxStack and pushes a single ast.Case scalar function in their place.
//
// With a CASE value (`CASE value WHEN c1 THEN r1 ...`), each WHEN condition ci
// is turned into `value = ci` (value is cloned for every comparison). Without
// one, the conditions are used as they are. The WHEN/THEN/ELSE operands must
// be single columns. On failure er.err is set and ctxStack is not modified.
//
// NOTE(review): the CASE value itself is not passed to checkArgsOneColumn —
// confirm whether row values are rejected elsewhere.
func (er *expressionRewriter) caseToExpression(v *ast.CaseExpr) {
	stkLen := len(er.ctxStack)
	argsLen := 2 * len(v.WhenClauses)
	if v.ElseClause != nil {
		argsLen++
	}
	er.checkArgsOneColumn(er.ctxStack[stkLen-argsLen:]...)
	if er.err != nil {
		return
	}

	// Stack layout (top of stack on the right):
	// value -> ctxStack[stkLen-argsLen-1] (only when v.Value != nil)
	// when clauses (condition, result) pairs -> start at ctxStack[stkLen-argsLen];
	// else clause (only when present) -> ctxStack[stkLen-1]
	var args []expression.Expression
	if v.Value != nil {
		// args: eq scalar func(args: value, condition1), result1,
		//       eq scalar func(args: value, condition2), result2,
		//       ...
		//       else clause
		value := er.ctxStack[stkLen-argsLen-1]
		args = make([]expression.Expression, 0, argsLen)
		// i walks over the conditions; i+1 is the matching result. The bound
		// stkLen-1 keeps the optional else clause out of the pairs.
		for i := stkLen - argsLen; i < stkLen-1; i += 2 {
			arg, err := expression.NewFunction(ast.EQ, types.NewFieldType(mysql.TypeTiny), value.Clone(), er.ctxStack[i])
			if err != nil {
				er.err = errors.Trace(err)
				return
			}
			args = append(args, arg)
			args = append(args, er.ctxStack[i+1])
		}
		if v.ElseClause != nil {
			args = append(args, er.ctxStack[stkLen-1])
		}
		argsLen++ // for trimming the value element later
	} else {
		// args: condition1, result1,
		//       condition2, result2,
		//       ...
		//       else clause
		// NOTE(review): args aliases ctxStack; the append below overwrites
		// ctxStack[stkLen-argsLen]. This is safe only if NewFunction does not keep
		// the args slice itself — confirm.
		args = er.ctxStack[stkLen-argsLen:]
	}
	function, err := expression.NewFunction(ast.Case, &v.Type, args...)
	if err != nil {
		er.err = errors.Trace(err)
		return
	}
	er.ctxStack = er.ctxStack[:stkLen-argsLen]
	er.ctxStack = append(er.ctxStack, function)
}
func (er *expressionRewriter) funcCallToScalarFunc(v *ast.FuncCallExpr) { stackLen := len(er.ctxStack) var function *expression.ScalarFunction function, er.err = expression.NewFunction(v.FnName.L, v.Type, er.ctxStack[stackLen-len(v.Args):]...) er.ctxStack = er.ctxStack[:stackLen-len(v.Args)] er.ctxStack = append(er.ctxStack, function) }
func (er *expressionRewriter) binaryOpToExpression(v *ast.BinaryOperationExpr) { stkLen := len(er.ctxStack) var function expression.Expression switch v.Op { case opcode.EQ, opcode.NE, opcode.NullEQ: function, er.err = er.constructBinaryOpFunction(er.ctxStack[stkLen-2], er.ctxStack[stkLen-1], opcode.Ops[v.Op]) default: lLen := getRowLen(er.ctxStack[stkLen-2]) rLen := getRowLen(er.ctxStack[stkLen-1]) switch v.Op { case opcode.GT, opcode.GE, opcode.LT, opcode.LE: if lLen != rLen { er.err = ErrOperandColumns.GenByArgs(lLen) } default: if lLen != 1 || rLen != 1 { er.err = ErrOperandColumns.GenByArgs(1) } } if er.err != nil { return } function, er.err = expression.NewFunction(opcode.Ops[v.Op], &v.Type, er.ctxStack[stkLen-2:]...) } if er.err != nil { er.err = errors.Trace(er.err) return } er.ctxStack = er.ctxStack[:stkLen-2] er.ctxStack = append(er.ctxStack, function) }
func pushDownNot(expr expression.Expression, not bool) expression.Expression { if f, ok := expr.(*expression.ScalarFunction); ok { switch f.FuncName.L { case ast.UnaryNot: return pushDownNot(f.Args[0], !not) case ast.LT, ast.GE, ast.GT, ast.LE, ast.EQ, ast.NE: if not { nf, _ := expression.NewFunction(oppositeOp[f.FuncName.L], f.GetType(), f.Args...) return nf } for i, arg := range f.Args { f.Args[i] = pushDownNot(arg, false) } return f case ast.AndAnd: if not { args := f.Args for i, a := range args { args[i] = pushDownNot(a, true) } nf, _ := expression.NewFunction(ast.OrOr, f.GetType(), args...) return nf } for i, arg := range f.Args { f.Args[i] = pushDownNot(arg, false) } return f case ast.OrOr: if not { args := f.Args for i, a := range args { args[i] = pushDownNot(a, true) } nf, _ := expression.NewFunction(ast.AndAnd, f.GetType(), args...) return nf } for i, arg := range f.Args { f.Args[i] = pushDownNot(arg, false) } return f } } if not { expr, _ = expression.NewFunction(ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), expr) } return expr }
func (er *expressionRewriter) betweenToScalarFunc(v *ast.BetweenExpr) { stkLen := len(er.ctxStack) var op string var l, r *expression.ScalarFunction if v.Not { l = expression.NewFunction(ast.LT, v.Type, er.ctxStack[stkLen-3], er.ctxStack[stkLen-2]) r = expression.NewFunction(ast.GT, v.Type, er.ctxStack[stkLen-3], er.ctxStack[stkLen-1]) op = ast.OrOr } else { l = expression.NewFunction(ast.GE, v.Type, er.ctxStack[stkLen-3], er.ctxStack[stkLen-2]) r = expression.NewFunction(ast.LE, v.Type, er.ctxStack[stkLen-3], er.ctxStack[stkLen-1]) op = ast.AndAnd } function := expression.NewFunction(op, v.Type, l, r) er.ctxStack = er.ctxStack[:stkLen-3] er.ctxStack = append(er.ctxStack, function) }
func (er *expressionRewriter) notToScalarFunc(hasNot bool, op string, tp *types.FieldType, args ...expression.Expression) *expression.ScalarFunction { opFunc, err := expression.NewFunction(op, tp, args...) if err != nil { er.err = errors.Trace(err) return nil } if !hasNot { return opFunc } opFunc, err = expression.NewFunction(ast.UnaryNot, tp, opFunc) if err != nil { er.err = errors.Trace(err) return nil } return opFunc }
func (er *expressionRewriter) funcCallToExpression(v *ast.FuncCallExpr) { stackLen := len(er.ctxStack) args := er.ctxStack[stackLen-len(v.Args):] er.checkArgsOneColumn(args...) if er.err != nil { return } var function expression.Expression function, er.err = expression.NewFunction(v.FnName.L, &v.Type, args...) er.ctxStack = er.ctxStack[:stackLen-len(v.Args)] er.ctxStack = append(er.ctxStack, function) }
// compose CNF items into a balance deep CNF tree, which benefits a lot for pb decoder/encoder. func composeCondition(conditions []expression.Expression) expression.Expression { length := len(conditions) if length == 0 { return nil } if length == 1 { return conditions[0] } return expression.NewFunction(model.NewCIStr(ast.AndAnd), []expression.Expression{composeCondition(conditions[0 : length/2]), composeCondition(conditions[length/2:])}) }
func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, bool) { v.Expr.Accept(er) if er.err != nil { return v, true } lexpr := er.ctxStack[len(er.ctxStack)-1] subq, ok := v.Sel.(*ast.SubqueryExpr) if !ok { er.err = errors.Errorf("Unknown compare type %T.", v.Sel) return v, true } np, outerSchema := er.buildSubquery(subq) if er.err != nil { return v, true } if getRowLen(lexpr) != len(np.GetSchema()) { er.err = errors.Errorf("Operand should contain %d column(s)", getRowLen(lexpr)) return v, true } var rexpr expression.Expression if len(np.GetSchema()) == 1 { rexpr = np.GetSchema()[0].DeepCopy() } else { args := make([]expression.Expression, 0, len(np.GetSchema())) for _, col := range np.GetSchema() { args = append(args, col.DeepCopy()) } rexpr, er.err = expression.NewFunction(ast.RowFunc, nil, args...) if er.err != nil { er.err = errors.Trace(er.err) return v, true } } // a in (subq) will be rewrited as a = any(subq). // a not in (subq) will be rewrited as a != all(subq). op, all := ast.EQ, false if v.Not { op, all = ast.NE, true } checkCondition, err := constructBinaryOpFunction(lexpr, rexpr, op) if err != nil { er.err = errors.Trace(err) return v, true } er.p = er.b.buildApply(er.p, np, outerSchema, &ApplyConditionChecker{Condition: checkCondition, All: all}) if er.p.IsCorrelated() { er.correlated = true } // The parent expression only use the last column in schema, which represents whether the condition is matched. er.ctxStack[len(er.ctxStack)-1] = er.p.GetSchema()[len(er.p.GetSchema())-1] return v, true }
func (er *expressionRewriter) rowToScalarFunc(v *ast.RowExpr) { stkLen := len(er.ctxStack) length := len(v.Values) rows := make([]expression.Expression, 0, length) for i := stkLen - length; i < stkLen; i++ { rows = append(rows, er.ctxStack[i]) } er.ctxStack = er.ctxStack[:stkLen-length] function, err := expression.NewFunction(ast.RowFunc, nil, rows...) if err != nil { er.err = errors.Trace(err) return } er.ctxStack = append(er.ctxStack, function) }
// constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2). func constructBinaryOpFunction(l expression.Expression, r expression.Expression, op string) (expression.Expression, error) { lLen, rLen := getRowLen(l), getRowLen(r) if lLen == 1 && rLen == 1 { return expression.NewFunction(op, types.NewFieldType(mysql.TypeTiny), l, r) } else if rLen != lLen { return nil, errors.Errorf("Operand should contain %d column(s)", lLen) } funcs := make([]expression.Expression, lLen) for i := 0; i < lLen; i++ { var err error funcs[i], err = constructBinaryOpFunction(getRowArg(l, i), getRowArg(r, i), op) if err != nil { return nil, err } } return expression.ComposeCNFCondition(funcs), nil }
// constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2). func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, r expression.Expression, op string) (expression.Expression, error) { lLen, rLen := getRowLen(l), getRowLen(r) if lLen == 1 && rLen == 1 { return expression.NewFunction(op, types.NewFieldType(mysql.TypeTiny), l, r) } else if rLen != lLen { return nil, ErrOperandColumns.GenByArgs(lLen) } funcs := make([]expression.Expression, lLen) for i := 0; i < lLen; i++ { var err error funcs[i], err = er.constructBinaryOpFunction(getRowArg(l, i), getRowArg(r, i), op) if err != nil { return nil, errors.Trace(err) } } return expression.ComposeCNFCondition(funcs), nil }
// constantSubstitute substitute column expression in a condition by an equivalent constant. func constantSubstitute(equalities map[string]*expression.Constant, condition expression.Expression) expression.Expression { switch expr := condition.(type) { case *expression.Column: if v, ok := equalities[string(expr.HashCode())]; ok { return v } case *expression.ScalarFunction: for i, arg := range expr.Args { expr.Args[i] = constantSubstitute(equalities, arg) } if _, ok := evaluator.Funcs[expr.FuncName.L]; ok { condition, _ = expression.NewFunction(expr.FuncName.L, expr.RetType, expr.Args...) } return condition } return condition }
func (er *expressionRewriter) binaryOpToScalarFunc(v *ast.BinaryOperationExpr) { stkLen := len(er.ctxStack) var function expression.Expression switch v.Op { case opcode.EQ, opcode.NE, opcode.NullEQ: function, er.err = constructBinaryOpFunction(er.ctxStack[stkLen-2], er.ctxStack[stkLen-1], opcode.Ops[v.Op]) default: function, er.err = expression.NewFunction(opcode.Ops[v.Op], v.Type, er.ctxStack[stkLen-2:]...) } if er.err != nil { er.err = errors.Trace(er.err) return } er.ctxStack = er.ctxStack[:stkLen-2] er.ctxStack = append(er.ctxStack, function) }
func (er *expressionRewriter) unaryOpToScalarFunc(v *ast.UnaryOperationExpr) { stkLen := len(er.ctxStack) if getRowLen(er.ctxStack[stkLen-1]) != 1 { er.err = errors.New("Operand should contain 1 column(s)") } var op string switch v.Op { case opcode.Plus: op = ast.UnaryPlus case opcode.Minus: op = ast.UnaryMinus case opcode.BitNeg: op = ast.BitNeg case opcode.Not: op = ast.UnaryNot default: er.err = errors.Errorf("Unknown Unary Op %T", v.Op) return } er.ctxStack[stkLen-1], er.err = expression.NewFunction(op, v.Type, er.ctxStack[stkLen-1]) }
func extractOnCondition(conditions []expression.Expression, left LogicalPlan, right LogicalPlan) ( eqCond []*expression.ScalarFunction, leftCond []expression.Expression, rightCond []expression.Expression, otherCond []expression.Expression) { for _, expr := range conditions { binop, ok := expr.(*expression.ScalarFunction) if ok && binop.FuncName.L == ast.EQ { ln, lOK := binop.Args[0].(*expression.Column) rn, rOK := binop.Args[1].(*expression.Column) if lOK && rOK { if left.GetSchema().GetIndex(ln) != -1 && right.GetSchema().GetIndex(rn) != -1 { eqCond = append(eqCond, binop) continue } if left.GetSchema().GetIndex(rn) != -1 && right.GetSchema().GetIndex(ln) != -1 { cond, _ := expression.NewFunction(ast.EQ, types.NewFieldType(mysql.TypeTiny), rn, ln) eqCond = append(eqCond, cond.(*expression.ScalarFunction)) continue } } } columns, _ := extractColumn(expr, nil, nil) allFromLeft, allFromRight := true, true for _, col := range columns { if left.GetSchema().GetIndex(col) == -1 { allFromLeft = false } if right.GetSchema().GetIndex(col) == -1 { allFromRight = false } } if allFromRight { rightCond = append(rightCond, expr) } else if allFromLeft { leftCond = append(leftCond, expr) } else { otherCond = append(otherCond, expr) } } return }
func extractOnCondition(conditions []expression.Expression, left Plan, right Plan) ( eqCond []*expression.ScalarFunction, leftCond []expression.Expression, rightCond []expression.Expression, otherCond []expression.Expression) { for _, expr := range conditions { binop, ok := expr.(*expression.ScalarFunction) if ok && binop.FuncName.L == ast.EQ { ln, lOK := binop.Args[0].(*expression.Column) rn, rOK := binop.Args[1].(*expression.Column) if lOK && rOK { if left.GetSchema().GetIndex(ln) != -1 && right.GetSchema().GetIndex(rn) != -1 { eqCond = append(eqCond, binop) continue } if left.GetSchema().GetIndex(rn) != -1 && right.GetSchema().GetIndex(ln) != -1 { newEq := expression.NewFunction(model.NewCIStr(ast.EQ), []expression.Expression{rn, ln}) eqCond = append(eqCond, newEq) continue } } } columns, _ := extractColumn(expr, nil, nil) allFromLeft, allFromRight := true, true for _, col := range columns { if left.GetSchema().GetIndex(col) != -1 { allFromRight = false } else { allFromLeft = false } } if allFromRight { rightCond = append(rightCond, expr) } else if allFromLeft { leftCond = append(leftCond, expr) } else { otherCond = append(otherCond, expr) } } return }
// calculateResultOfExpression set inner table columns in a expression as null and calculate the finally result of the scalar function. func calculateResultOfExpression(schema expression.Schema, expr expression.Expression) (expression.Expression, error) { switch x := expr.(type) { case *expression.ScalarFunction: var err error args := make([]expression.Expression, len(x.Args)) for i, arg := range x.Args { args[i], err = calculateResultOfExpression(schema, arg) } if err != nil { return nil, errors.Trace(err) } return expression.NewFunction(x.FuncName.L, types.NewFieldType(mysql.TypeTiny), args...) case *expression.Column: if schema.GetIndex(x) == -1 { return x, nil } constant := &expression.Constant{Value: types.Datum{}} constant.Value.SetNull() return constant, nil default: return x.DeepCopy(), nil } }
func (er *expressionRewriter) unaryOpToExpression(v *ast.UnaryOperationExpr) { stkLen := len(er.ctxStack) var op string switch v.Op { case opcode.Plus: // expression (+ a) is equal to a return case opcode.Minus: op = ast.UnaryMinus case opcode.BitNeg: op = ast.BitNeg case opcode.Not: op = ast.UnaryNot default: er.err = errors.Errorf("Unknown Unary Op %T", v.Op) return } if getRowLen(er.ctxStack[stkLen-1]) != 1 { er.err = ErrOperandColumns.GenByArgs(1) return } er.ctxStack[stkLen-1], er.err = expression.NewFunction(op, &v.Type, er.ctxStack[stkLen-1]) }
// propagateConstant propagate constant values of equality predicates and inequality predicates in a condition. func propagateConstant(conditions []expression.Expression) []expression.Expression { if len(conditions) == 0 { return conditions } // Propagate constants in equality predicates. // e.g. for condition: "a = b and b = c and c = a and a = 1"; // we propagate constant as the following step: // first: "1 = b and b = c and c = 1 and a = 1"; // next: "1 = b and 1 = c and c = 1 and a = 1"; // next: "1 = b and 1 = c and 1 = 1 and a = 1"; // next: "1 = b and 1 = c and a = 1"; // e.g for condition: "a = b and b = c and b = 2 and a = 1"; // we propagate constant as the following step: // first: "a = 2 and 2 = c and b = 2 and a = 1"; // next: "a = 2 and 2 = c and b = 2 and 2 = 1"; // next: "0" isSource := make([]bool, len(conditions)) type transitiveEqualityPredicate map[string]*expression.Constant // transitive equality predicates between one column and one constant for { equalities := make(transitiveEqualityPredicate, 0) for i, getOneEquality := 0, false; i < len(conditions) && !getOneEquality; i++ { if isSource[i] { continue } expr, ok := conditions[i].(*expression.ScalarFunction) if !ok { continue } // process the included OR conditions recursively to do the same for CNF item. switch expr.FuncName.L { case ast.OrOr: expressions := expression.SplitDNFItems(conditions[i]) newExpression := make([]expression.Expression, 0) for _, v := range expressions { newExpression = append(newExpression, propagateConstant([]expression.Expression{v})...) 
} conditions[i] = expression.ComposeDNFCondition(newExpression) isSource[i] = true case ast.AndAnd: newExpression := propagateConstant(expression.SplitCNFItems(conditions[i])) conditions[i] = expression.ComposeCNFCondition(newExpression) isSource[i] = true case ast.EQ: var ( col *expression.Column val *expression.Constant ) leftConst, leftIsConst := expr.Args[0].(*expression.Constant) rightConst, rightIsConst := expr.Args[1].(*expression.Constant) leftCol, leftIsCol := expr.Args[0].(*expression.Column) rightCol, rightIsCol := expr.Args[1].(*expression.Column) if rightIsConst && leftIsCol { col = leftCol val = rightConst } else if leftIsConst && rightIsCol { col = rightCol val = leftConst } else { continue } equalities[string(col.HashCode())] = val isSource[i] = true getOneEquality = true } } if len(equalities) == 0 { break } for i := 0; i < len(conditions); i++ { if isSource[i] { continue } if len(equalities) != 0 { conditions[i] = constantSubstitute(equalities, conditions[i]) } } } // Propagate transitive inequality predicates. // e.g for conditions "a = b and c = d and a = c and g = h and b > 0 and e != 0 and g like 'abc'", // we propagate constant as the following step: // 1. build multiple equality predicates(mep): // =(a, b, c, d), =(g, h). // 2. extract inequality predicates between one constant and one column, // and rewrite them using the root column of a multiple equality predicate: // b > 0, e != 0, g like 'abc' ==> a > 0, g like 'abc'. // ATTENTION: here column 'e' doesn't belong to any mep, so we skip "e != 0". // 3. propagate constants in these inequality predicates, and we finally get: // "a = b and c = d and a = c and e = f and g = h and e != 0 and a > 0 and b > 0 and c > 0 and d > 0 and g like 'abc' and h like 'abc' ". multipleEqualities := make(map[*expression.Column]*expression.Column, 0) for _, cond := range conditions { // build multiple equality predicates. 
expr, ok := cond.(*expression.ScalarFunction) if ok && expr.FuncName.L == ast.EQ { left, ok1 := expr.Args[0].(*expression.Column) right, ok2 := expr.Args[1].(*expression.Column) if ok1 && ok2 { UnionColumns(left, right, multipleEqualities) } } } if len(multipleEqualities) == 0 { return conditions } inequalityFuncs := map[string]string{ ast.LT: ast.LT, ast.GT: ast.GT, ast.LE: ast.LE, ast.GE: ast.GE, ast.NE: ast.NE, ast.Like: ast.Like, } type inequalityFactor struct { FuncName string Factor []*expression.Constant } type transitiveInEqualityPredicate map[string][]inequalityFactor // transitive inequality predicates between one column and one constant. inequalities := make(transitiveInEqualityPredicate, 0) for i := 0; i < len(conditions); i++ { // extract inequality predicates. var ( column *expression.Column equalCol *expression.Column // the root column corresponding to a column in a multiple equality predicate. val *expression.Constant funcName string ) expr, ok := conditions[i].(*expression.ScalarFunction) if !ok { continue } funcName, ok = inequalityFuncs[expr.FuncName.L] if !ok { continue } leftConst, leftIsConst := expr.Args[0].(*expression.Constant) rightConst, rightIsConst := expr.Args[1].(*expression.Constant) leftCol, leftIsCol := expr.Args[0].(*expression.Column) rightCol, rightIsCol := expr.Args[1].(*expression.Column) if rightIsConst && leftIsCol { column = leftCol val = rightConst } else if leftIsConst && rightIsCol { column = rightCol val = leftConst } else { continue } equalCol, ok = multipleEqualities[column] if !ok { // no need to propagate inequality predicates whose column is only equal to itself. continue } colHashCode := string(equalCol.HashCode()) if funcName == ast.Like { // func 'LIKE' need 3 input arguments, so here we handle it alone. 
inequalities[colHashCode] = append(inequalities[colHashCode], inequalityFactor{FuncName: ast.Like, Factor: []*expression.Constant{val, expr.Args[2].(*expression.Constant)}}) } else { inequalities[colHashCode] = append(inequalities[colHashCode], inequalityFactor{FuncName: funcName, Factor: []*expression.Constant{val}}) } conditions = append(conditions[:i], conditions[i+1:]...) i-- } if len(inequalities) == 0 { return conditions } for k, v := range multipleEqualities { // propagate constants in inequality predicates. for _, x := range inequalities[string(v.HashCode())] { funcName, factors := x.FuncName, x.Factor if funcName == ast.Like { for i := 0; i < len(factors); i += 2 { newFunc, _ := expression.NewFunction(funcName, types.NewFieldType(mysql.TypeTiny), k, factors[i], factors[i+1]) conditions = append(conditions, newFunc) } } else { for i := 0; i < len(factors); i++ { newFunc, _ := expression.NewFunction(funcName, types.NewFieldType(mysql.TypeTiny), k, factors[i]) conditions = append(conditions, newFunc) i++ } } } } return conditions }
// Enter implements Visitor interface.
//
// It handles nodes that the rewriter resolves before (or instead of) visiting
// their children:
//   - AggregateFuncExpr: pushes the schema column mapped in er.aggrMap, or sets
//     er.err if the aggregate is not mapped.
//   - ColumnNameExpr: pushes the schema column mapped in er.b.colMapper when
//     present. Otherwise it falls through to the default return.
//   - CompareSubqueryExpr (`lhs op ANY/ALL (subq)`), ExistsSubqueryExpr,
//     PatternInExpr with a subquery, and SubqueryExpr: build the subquery plan.
//     They either attach it to er.p through an Apply, or evaluate an
//     uncorrelated subquery immediately and push the result as constant(s).
//
// Returning true as the second value presumably tells the traversal not to
// descend into the node's children (ast.Visitor's skipChildren flag).
// Errors are reported through er.err.
func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
	switch v := inNode.(type) {
	case *ast.AggregateFuncExpr:
		index, ok := -1, false
		if er.aggrMap != nil {
			index, ok = er.aggrMap[v]
		}
		if !ok {
			er.err = errors.New("Can't appear aggrFunctions")
			return inNode, true
		}
		er.ctxStack = append(er.ctxStack, er.schema[index])
		return inNode, true
	case *ast.ColumnNameExpr:
		if index, ok := er.b.colMapper[v]; ok {
			er.ctxStack = append(er.ctxStack, er.schema[index])
			return inNode, true
		}
	case *ast.CompareSubqueryExpr:
		// Rewrite the left operand first; its result is on top of ctxStack.
		v.L.Accept(er)
		if er.err != nil {
			return inNode, true
		}
		lexpr := er.ctxStack[len(er.ctxStack)-1]
		subq, ok := v.R.(*ast.SubqueryExpr)
		if !ok {
			er.err = errors.Errorf("Unknown compare type %T.", v.R)
			return inNode, true
		}
		np, outerSchema := er.buildSubquery(subq)
		if er.err != nil {
			return inNode, true
		}
		// Only (a,b,c) = any (...) and (a,b,c) != all (...) can use row expression.
		canMultiCol := (!v.All && v.Op == opcode.EQ) || (v.All && v.Op == opcode.NE)
		if !canMultiCol && (getRowLen(lexpr) != 1 || len(np.GetSchema()) != 1) {
			er.err = errors.New("Operand should contain 1 column(s)")
			return inNode, true
		}
		if getRowLen(lexpr) != len(np.GetSchema()) {
			er.err = errors.Errorf("Operand should contain %d column(s)", getRowLen(lexpr))
			return inNode, true
		}
		var checkCondition expression.Expression
		var rexpr expression.Expression
		if len(np.GetSchema()) == 1 {
			rexpr = np.GetSchema()[0].DeepCopy()
		} else {
			args := make([]expression.Expression, 0, len(np.GetSchema()))
			for _, col := range np.GetSchema() {
				args = append(args, col.DeepCopy())
			}
			// NOTE(review): types.KindRow is a Datum kind, not a mysql field type;
			// the other RowFunc constructions in this file pass nil — confirm intent.
			rexpr, er.err = expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), args...)
			if er.err != nil {
				er.err = errors.Trace(er.err)
				return inNode, true
			}
		}
		switch v.Op {
		// Only EQ, NE and NullEQ can be composed with and.
		case opcode.EQ, opcode.NE, opcode.NullEQ:
			checkCondition, er.err = constructBinaryOpFunction(lexpr, rexpr, opcode.Ops[v.Op])
			if er.err != nil {
				er.err = errors.Trace(er.err)
				return inNode, true
			}
		// If op is not EQ, NE, NullEQ, say LT, it will remain as row(a,b) < row(c,d), and be compared as row datum.
		default:
			checkCondition, er.err = expression.NewFunction(opcode.Ops[v.Op], types.NewFieldType(mysql.TypeTiny), lexpr, rexpr)
			if er.err != nil {
				er.err = errors.Trace(er.err)
				return inNode, true
			}
		}
		er.p = er.b.buildApply(er.p, np, outerSchema, &ApplyConditionChecker{Condition: checkCondition, All: v.All})
		// The parent expression only use the last column in schema, which represents whether the condition is matched.
		er.ctxStack[len(er.ctxStack)-1] = er.p.GetSchema()[len(er.p.GetSchema())-1]
		return inNode, true
	case *ast.ExistsSubqueryExpr:
		subq, ok := v.Sel.(*ast.SubqueryExpr)
		if !ok {
			er.err = errors.Errorf("Unknown exists type %T.", v.Sel)
			return inNode, true
		}
		np, outerSchema := er.buildSubquery(subq)
		if er.err != nil {
			return inNode, true
		}
		np = er.b.buildExists(np)
		if np.IsCorrelated() {
			// Correlated: evaluate per outer row through an Apply and reference its last column.
			er.p = er.b.buildApply(er.p, np, outerSchema, nil)
			er.ctxStack = append(er.ctxStack, er.p.GetSchema()[len(er.p.GetSchema())-1])
		} else {
			// Uncorrelated: evaluate once now and push the result as a constant.
			_, err := np.PruneColumnsAndResolveIndices(np.GetSchema())
			if err != nil {
				er.err = errors.Trace(err)
				return inNode, true
			}
			d, err := EvalSubquery(np, er.b.is, er.b.ctx)
			if err != nil {
				er.err = errors.Trace(err)
				return inNode, true
			}
			// NOTE(review): assumes EvalSubquery returns at least one datum — confirm.
			er.ctxStack = append(er.ctxStack, &expression.Constant{
				Value:   d[0],
				RetType: np.GetSchema()[0].GetType()})
		}
		return inNode, true
	case *ast.PatternInExpr:
		if v.Sel != nil {
			v.Expr.Accept(er)
			if er.err != nil {
				return inNode, true
			}
			lexpr := er.ctxStack[len(er.ctxStack)-1]
			subq, ok := v.Sel.(*ast.SubqueryExpr)
			if !ok {
				er.err = errors.Errorf("Unknown compare type %T.", v.Sel)
				return inNode, true
			}
			np, outerSchema := er.buildSubquery(subq)
			if er.err != nil {
				return inNode, true
			}
			if getRowLen(lexpr) != len(np.GetSchema()) {
				er.err = errors.Errorf("Operand should contain %d column(s)", getRowLen(lexpr))
				return inNode, true
			}
			var rexpr expression.Expression
			if len(np.GetSchema()) == 1 {
				rexpr = np.GetSchema()[0].DeepCopy()
			} else {
				args := make([]expression.Expression, 0, len(np.GetSchema()))
				for _, col := range np.GetSchema() {
					args = append(args, col.DeepCopy())
				}
				rexpr, er.err = expression.NewFunction(ast.RowFunc, nil, args...)
				if er.err != nil {
					er.err = errors.Trace(er.err)
					return inNode, true
				}
			}
			// a in (subq) will be rewrited as a = any(subq).
			// a not in (subq) will be rewrited as a != all(subq).
			op, all := ast.EQ, false
			if v.Not {
				op, all = ast.NE, true
			}
			checkCondition, err := constructBinaryOpFunction(lexpr, rexpr, op)
			if err != nil {
				er.err = errors.Trace(err)
				return inNode, true
			}
			er.p = er.b.buildApply(er.p, np, outerSchema, &ApplyConditionChecker{Condition: checkCondition, All: all})
			// The parent expression only use the last column in schema, which represents whether the condition is matched.
			er.ctxStack[len(er.ctxStack)-1] = er.p.GetSchema()[len(er.p.GetSchema())-1]
			return inNode, true
		}
	case *ast.SubqueryExpr:
		// Scalar subquery: at most one row is allowed.
		np, outerSchema := er.buildSubquery(v)
		if er.err != nil {
			return inNode, true
		}
		np = er.b.buildMaxOneRow(np)
		if np.IsCorrelated() {
			er.p = er.b.buildApply(er.p, np, outerSchema, nil)
			if len(np.GetSchema()) > 1 {
				newCols := make([]expression.Expression, 0, len(np.GetSchema()))
				for _, col := range np.GetSchema() {
					newCols = append(newCols, col.DeepCopy())
				}
				expr, err := expression.NewFunction(ast.RowFunc, nil, newCols...)
				if err != nil {
					er.err = errors.Trace(err)
					return inNode, true
				}
				er.ctxStack = append(er.ctxStack, expr)
			} else {
				er.ctxStack = append(er.ctxStack, er.p.GetSchema()[len(er.p.GetSchema())-1])
			}
			return inNode, true
		}
		// Uncorrelated: evaluate once now and inline the result as constant(s).
		_, err := np.PruneColumnsAndResolveIndices(np.GetSchema())
		if err != nil {
			er.err = errors.Trace(err)
			return inNode, true
		}
		d, err := EvalSubquery(np, er.b.is, er.b.ctx)
		if err != nil {
			er.err = errors.Trace(err)
			return inNode, true
		}
		if len(np.GetSchema()) > 1 {
			newCols := make([]expression.Expression, 0, len(np.GetSchema()))
			for i, data := range d {
				newCols = append(newCols, &expression.Constant{
					Value:   data,
					RetType: np.GetSchema()[i].GetType()})
			}
			expr, err1 := expression.NewFunction(ast.RowFunc, nil, newCols...)
			if err1 != nil {
				er.err = errors.Trace(err1)
				return inNode, true
			}
			er.ctxStack = append(er.ctxStack, expr)
		} else {
			er.ctxStack = append(er.ctxStack, &expression.Constant{
				Value:   d[0],
				RetType: np.GetSchema()[0].GetType(),
			})
		}
		return inNode, true
	}
	return inNode, false
}
func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) { stkLen := len(er.ctxStack) name := strings.ToLower(v.Name) sessionVars := variable.GetSessionVars(er.b.ctx) globalVars := variable.GetGlobalVarAccessor(er.b.ctx) if !v.IsSystem { if v.Value != nil { er.ctxStack[stkLen-1], er.err = expression.NewFunction(ast.SetVar, er.ctxStack[stkLen-1].GetType(), datumToConstant(types.NewDatum(name), mysql.TypeString), er.ctxStack[stkLen-1]) return } if _, ok := sessionVars.Users[name]; ok { f, err := expression.NewFunction(ast.GetVar, // TODO: Here is wrong, the sessionVars should store a name -> Datum map. Will fix it later. types.NewFieldType(mysql.TypeString), datumToConstant(types.NewStringDatum(name), mysql.TypeString)) if err != nil { er.err = errors.Trace(err) return } er.ctxStack = append(er.ctxStack, f) } else { // select null user vars is permitted. er.ctxStack = append(er.ctxStack, &expression.Constant{RetType: types.NewFieldType(mysql.TypeNull)}) } return } sysVar, ok := variable.SysVars[name] if !ok { // select null sys vars is not permitted er.err = variable.UnknownSystemVar.Gen("Unknown system variable '%s'", name) return } if sysVar.Scope == variable.ScopeNone { er.ctxStack = append(er.ctxStack, datumToConstant(types.NewDatum(sysVar.Value), mysql.TypeString)) return } if v.IsGlobal { value, err := globalVars.GetGlobalSysVar(er.b.ctx, name) if err != nil { er.err = errors.Trace(err) return } er.ctxStack = append(er.ctxStack, datumToConstant(types.NewDatum(value), mysql.TypeString)) return } d := sessionVars.GetSystemVar(name) if d.IsNull() { if sysVar.Scope&variable.ScopeGlobal == 0 { d.SetString(sysVar.Value) } else { // Get global system variable and fill it in session. 
globalVal, err := globalVars.GetGlobalSysVar(er.b.ctx, name) if err != nil { er.err = errors.Trace(err) return } d.SetString(globalVal) err = sessionVars.SetSystemVar(name, d) if err != nil { er.err = errors.Trace(err) return } } } er.ctxStack = append(er.ctxStack, datumToConstant(d, mysql.TypeString)) return }
func (er *expressionRewriter) handleScalarSubquery(v *ast.SubqueryExpr) (ast.Node, bool) { np, outerSchema := er.buildSubquery(v) if er.err != nil { return v, true } np = er.b.buildMaxOneRow(np) if np.IsCorrelated() { er.p = er.b.buildApply(er.p, np, outerSchema, nil) if er.p.IsCorrelated() { er.correlated = true } if len(np.GetSchema()) > 1 { newCols := make([]expression.Expression, 0, len(np.GetSchema())) for _, col := range np.GetSchema() { newCols = append(newCols, col.DeepCopy()) } expr, err := expression.NewFunction(ast.RowFunc, nil, newCols...) if err != nil { er.err = errors.Trace(err) return v, true } er.ctxStack = append(er.ctxStack, expr) } else { er.ctxStack = append(er.ctxStack, er.p.GetSchema()[len(er.p.GetSchema())-1]) } return v, true } _, np, er.err = np.PredicatePushDown(nil) if er.err != nil { return v, true } _, err := np.PruneColumnsAndResolveIndices(np.GetSchema()) if err != nil { er.err = errors.Trace(err) return v, true } _, res, _, err := np.convert2PhysicalPlan(nil) if err != nil { er.err = errors.Trace(err) return v, true } phyPlan := res.p.PushLimit(nil) d, err := EvalSubquery(phyPlan, er.b.is, er.b.ctx) if err != nil { er.err = errors.Trace(err) return v, true } if len(np.GetSchema()) > 1 { newCols := make([]expression.Expression, 0, len(np.GetSchema())) for i, data := range d { newCols = append(newCols, &expression.Constant{ Value: data, RetType: np.GetSchema()[i].GetType()}) } expr, err1 := expression.NewFunction(ast.RowFunc, nil, newCols...) if err1 != nil { er.err = errors.Trace(err1) return v, true } er.ctxStack = append(er.ctxStack, expr) } else { er.ctxStack = append(er.ctxStack, &expression.Constant{ Value: d[0], RetType: np.GetSchema()[0].GetType(), }) } return v, true }
func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, bool) { asScalar := er.asScalar er.asScalar = true v.Expr.Accept(er) if er.err != nil { return v, true } lexpr := er.ctxStack[len(er.ctxStack)-1] subq, ok := v.Sel.(*ast.SubqueryExpr) if !ok { er.err = errors.Errorf("Unknown compare type %T.", v.Sel) return v, true } np := er.buildSubquery(subq) if er.err != nil { return v, true } lLen := getRowLen(lexpr) if lLen != len(np.GetSchema()) { er.err = ErrOperandColumns.GenByArgs(lLen) return v, true } var rexpr expression.Expression if len(np.GetSchema()) == 1 { rexpr = np.GetSchema()[0].Clone() } else { args := make([]expression.Expression, 0, len(np.GetSchema())) for _, col := range np.GetSchema() { args = append(args, col.Clone()) } rexpr, er.err = expression.NewFunction(ast.RowFunc, nil, args...) if er.err != nil { er.err = errors.Trace(er.err) return v, true } } // a in (subq) will be rewrited as a = any(subq). // a not in (subq) will be rewrited as a != all(subq). checkCondition, err := er.constructBinaryOpFunction(lexpr, rexpr, ast.EQ) if err != nil { er.err = errors.Trace(err) return v, true } if !np.IsCorrelated() { er.p = er.b.buildSemiJoin(er.p, np, expression.SplitCNFItems(checkCondition), asScalar, v.Not) if asScalar { col := er.p.GetSchema()[len(er.p.GetSchema())-1] er.ctxStack[len(er.ctxStack)-1] = col } else { er.ctxStack = er.ctxStack[:len(er.ctxStack)-1] } return v, true } if v.Not { checkCondition, _ = expression.NewFunction(ast.UnaryNot, &v.Type, checkCondition) } er.p = er.b.buildApply(er.p, np, &ApplyConditionChecker{Condition: checkCondition, All: v.Not}) // The parent expression only use the last column in schema, which represents whether the condition is matched. er.ctxStack[len(er.ctxStack)-1] = er.p.GetSchema()[len(er.p.GetSchema())-1] return v, true }
func (er *expressionRewriter) handleCompareSubquery(v *ast.CompareSubqueryExpr) (ast.Node, bool) { v.L.Accept(er) if er.err != nil { return v, true } lexpr := er.ctxStack[len(er.ctxStack)-1] subq, ok := v.R.(*ast.SubqueryExpr) if !ok { er.err = errors.Errorf("Unknown compare type %T.", v.R) return v, true } np, outerSchema := er.buildSubquery(subq) if er.err != nil { return v, true } // Only (a,b,c) = all (...) and (a,b,c) != any () can use row expression. canMultiCol := (!v.All && v.Op == opcode.EQ) || (v.All && v.Op == opcode.NE) if !canMultiCol && (getRowLen(lexpr) != 1 || len(np.GetSchema()) != 1) { er.err = errors.New("Operand should contain 1 column(s)") return v, true } if getRowLen(lexpr) != len(np.GetSchema()) { er.err = errors.Errorf("Operand should contain %d column(s)", getRowLen(lexpr)) return v, true } var checkCondition expression.Expression var rexpr expression.Expression if len(np.GetSchema()) == 1 { rexpr = np.GetSchema()[0].DeepCopy() } else { args := make([]expression.Expression, 0, len(np.GetSchema())) for _, col := range np.GetSchema() { args = append(args, col.DeepCopy()) } rexpr, er.err = expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), args...) if er.err != nil { er.err = errors.Trace(er.err) return v, true } } switch v.Op { // Only EQ, NE and NullEQ can be composed with and. case opcode.EQ, opcode.NE, opcode.NullEQ: checkCondition, er.err = constructBinaryOpFunction(lexpr, rexpr, opcode.Ops[v.Op]) if er.err != nil { er.err = errors.Trace(er.err) return v, true } // If op is not EQ, NE, NullEQ, say LT, it will remain as row(a,b) < row(c,d), and be compared as row datum. 
default: checkCondition, er.err = expression.NewFunction(opcode.Ops[v.Op], types.NewFieldType(mysql.TypeTiny), lexpr, rexpr) if er.err != nil { er.err = errors.Trace(er.err) return v, true } } er.p = er.b.buildApply(er.p, np, outerSchema, &ApplyConditionChecker{Condition: checkCondition, All: v.All}) if er.p.IsCorrelated() { er.correlated = true } // The parent expression only use the last column in schema, which represents whether the condition is matched. er.ctxStack[len(er.ctxStack)-1] = er.p.GetSchema()[len(er.p.GetSchema())-1] return v, true }
// Leave implements the Visitor interface. It is called for inNode after its
// children were visited. Where applicable, it pops the operand expressions
// the children left on er.ctxStack, builds the expression for inNode, and
// pushes the result.
// On success it returns (inNode, true). When er.err is set, either on entry
// or during the rewrite, it returns the zero retNode and false.
func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool) {
	if er.err != nil {
		return retNode, false
	}
	// Stack offsets below are relative to the depth on entry.
	stkLen := len(er.ctxStack)
	switch v := inNode.(type) {
	case *ast.AggregateFuncExpr:
		// NOTE(review): nothing is done for aggregates here; presumably they
		// are rewritten elsewhere — confirm.
	case *ast.RowExpr:
		// Replace the top len(v.Values) expressions with one RowFunc over
		// them, preserving their order.
		length := len(v.Values)
		rows := make([]expression.Expression, 0, length)
		for i := stkLen - length; i < stkLen; i++ {
			rows = append(rows, er.ctxStack[i])
		}
		er.ctxStack = er.ctxStack[:stkLen-length]
		er.ctxStack = append(er.ctxStack, expression.NewFunction(ast.RowFunc, nil, rows...))
	case *ast.VariableExpr:
		// NOTE(review): rewriteVariable is declared in this file without a
		// result value, so this return does not type-check as written —
		// confirm which signature is intended.
		return inNode, er.rewriteVariable(v)
	case *ast.FuncCallExpr:
		er.funcCallToScalarFunc(v)
	case *ast.PositionExpr:
		// v.N is a 1-based position into er.schema.
		if v.N > 0 && v.N <= len(er.schema) {
			er.ctxStack = append(er.ctxStack, er.schema[v.N-1])
		} else {
			er.err = errors.Errorf("Position %d is out of range", v.N)
		}
	case *ast.ColumnName:
		er.toColumn(v)
	case *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause, *ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr:
		// Nothing to do in Leave for these nodes; the stack is left as their
		// children (or earlier handlers) produced it.
	case *ast.ValueExpr:
		// Literal value: push it as a constant.
		value := &expression.Constant{Value: v.Datum, RetType: v.Type}
		er.ctxStack = append(er.ctxStack, value)
	case *ast.ParamMarkerExpr:
		// Parameter marker: push its current datum as a constant.
		value := &expression.Constant{Value: v.Datum, RetType: v.Type}
		er.ctxStack = append(er.ctxStack, value)
	case *ast.IsNullExpr:
		// expr IS [NOT] NULL: requires a single-column operand.
		if getRowLen(er.ctxStack[stkLen-1]) != 1 {
			er.err = errors.New("Operand should contain 1 column(s)")
			return retNode, false
		}
		function := er.notToScalarFunc(v.Not, ast.IsNull, v.Type, er.ctxStack[stkLen-1])
		er.ctxStack = er.ctxStack[:stkLen-1]
		er.ctxStack = append(er.ctxStack, function)
	case *ast.IsTruthExpr:
		// expr IS [NOT] TRUE/FALSE: v.True == 0 selects the FALSE variant.
		op := ast.IsTruth
		if v.True == 0 {
			op = ast.IsFalsity
		}
		function := er.notToScalarFunc(v.Not, op, v.Type, er.ctxStack[stkLen-1])
		er.ctxStack = er.ctxStack[:stkLen-1]
		er.ctxStack = append(er.ctxStack, function)
	case *ast.BinaryOperationExpr:
		// Replace the top two expressions with the binary operation over them.
		var function expression.Expression
		switch v.Op {
		case opcode.EQ, opcode.NE, opcode.NullEQ:
			// Comparisons go through constructBinaryOpFunction.
			var err error
			function, err = constructBinaryOpFunction(er.ctxStack[stkLen-2], er.ctxStack[stkLen-1], opcode.Ops[v.Op])
			if err != nil {
				er.err = errors.Trace(err)
				return retNode, false
			}
		default:
			function = expression.NewFunction(opcode.Ops[v.Op], v.Type, er.ctxStack[stkLen-2:]...)
		}
		er.ctxStack = er.ctxStack[:stkLen-2]
		er.ctxStack = append(er.ctxStack, function)
	case *ast.BetweenExpr:
		er.betweenToScalarFunc(v)
	case *ast.PatternLikeExpr:
		er.likeToScalarFunc(v)
	case *ast.PatternInExpr:
		// Only the value-list form is handled here; IN with a subquery
		// (v.Sel != nil) is handled by handleInSubquery.
		if v.Sel == nil {
			er.inToScalarFunc(v)
		}
	case *ast.UnaryOperationExpr:
		// Unary operators require a single-column operand.
		if getRowLen(er.ctxStack[stkLen-1]) != 1 {
			er.err = errors.New("Operand should contain 1 column(s)")
			return retNode, false
		}
		function := expression.NewFunction(opcode.Ops[v.Op], v.Type, er.ctxStack[stkLen-1])
		er.ctxStack = er.ctxStack[:stkLen-1]
		er.ctxStack = append(er.ctxStack, function)
	default:
		er.err = errors.Errorf("UnknownType: %T", v)
		return retNode, false
	}
	// Helpers called above report failures via er.err rather than a return value.
	if er.err != nil {
		return retNode, false
	}
	return inNode, true
}