Ejemplo n.º 1
0
// betweenToScalarFunc rewrites `expr [NOT] BETWEEN low AND high`.
//
// The three operands are the top three entries of ctxStack (expr, low, high,
// top last). They are replaced by `expr >= low && expr <= high`, or by
// `expr < low || expr > high` when v.Not is set. On failure er.err is set and
// the stack is left as it was.
func (er *expressionRewriter) betweenToScalarFunc(v *ast.BetweenExpr) {
	top := len(er.ctxStack)
	target, low, high := er.ctxStack[top-3], er.ctxStack[top-2], er.ctxStack[top-1]

	lowOp, highOp, joinOp := ast.GE, ast.LE, ast.AndAnd
	if v.Not {
		lowOp, highOp, joinOp = ast.LT, ast.GT, ast.OrOr
	}

	var lowCmp, highCmp *expression.ScalarFunction
	lowCmp, er.err = expression.NewFunction(lowOp, v.Type, target, low)
	if er.err == nil {
		// target is referenced by both comparisons, so the second one gets its own copy.
		highCmp, er.err = expression.NewFunction(highOp, v.Type, target.DeepCopy(), high)
	}
	if er.err != nil {
		er.err = errors.Trace(er.err)
		return
	}

	combined, err := expression.NewFunction(joinOp, v.Type, lowCmp, highCmp)
	if err != nil {
		er.err = errors.Trace(err)
		return
	}
	er.ctxStack = append(er.ctxStack[:top-3], combined)
}
Ejemplo n.º 2
0
// betweenToExpression rewrites `expr [NOT] BETWEEN low AND high`.
//
// The three operands are the top three entries of ctxStack (expr, low, high,
// top last) and must each be a single column. They are replaced by
// `expr >= low && expr <= high`, wrapped in a unary NOT when v.Not is set.
// On failure er.err is set and no entries are popped.
func (er *expressionRewriter) betweenToExpression(v *ast.BetweenExpr) {
	top := len(er.ctxStack)
	operands := er.ctxStack[top-3:]
	er.checkArgsOneColumn(operands...)
	if er.err != nil {
		return
	}
	target, low, high := operands[0], operands[1], operands[2]

	var lower, upper expression.Expression
	lower, er.err = expression.NewFunction(ast.GE, &v.Type, target, low)
	if er.err == nil {
		// target is referenced by both comparisons, so the second one gets a clone.
		upper, er.err = expression.NewFunction(ast.LE, &v.Type, target.Clone(), high)
	}
	if er.err != nil {
		er.err = errors.Trace(er.err)
		return
	}

	result, err := expression.NewFunction(ast.AndAnd, &v.Type, lower, upper)
	if err == nil && v.Not {
		result, err = expression.NewFunction(ast.UnaryNot, &v.Type, result)
	}
	if err != nil {
		er.err = errors.Trace(err)
		return
	}
	er.ctxStack = append(er.ctxStack[:top-3], result)
}
Ejemplo n.º 3
0
// notToScalarFunc builds op(args...) with return type tp and, when b is true,
// wraps that function in a unary NOT of the same type.
func (er *expressionRewriter) notToScalarFunc(b bool, op string, tp *types.FieldType,
	args ...expression.Expression) *expression.ScalarFunction {
	result := expression.NewFunction(op, tp, args...)
	if b {
		result = expression.NewFunction(ast.UnaryNot, tp, result)
	}
	return result
}
Ejemplo n.º 4
0
// handleScalarSubquery rewrites a subquery used as a scalar (or row) value and
// pushes the result onto ctxStack. It always returns (v, true); failures are
// reported through er.err.
//
// The subquery plan is wrapped with buildMaxOneRow (presumably enforcing the
// single-row requirement). A correlated subquery is attached to er.p through
// an Apply and referenced by column. An uncorrelated one is optimized and
// evaluated right away, and its values are pushed as constants. A result with
// more than one column is pushed as a RowFunc.
func (er *expressionRewriter) handleScalarSubquery(v *ast.SubqueryExpr) (ast.Node, bool) {
	np := er.buildSubquery(v)
	if er.err != nil {
		return v, true
	}
	np = er.b.buildMaxOneRow(np)

	if np.IsCorrelated() {
		er.p = er.b.buildApply(er.p, np, nil)
		innerCols := np.GetSchema()
		if len(innerCols) <= 1 {
			// A single column: reference the last output column of the Apply.
			outerCols := er.p.GetSchema()
			er.ctxStack = append(er.ctxStack, outerCols[len(outerCols)-1])
			return v, true
		}
		rowArgs := make([]expression.Expression, 0, len(innerCols))
		for _, col := range innerCols {
			rowArgs = append(rowArgs, col.Clone())
		}
		row, err := expression.NewFunction(ast.RowFunc, nil, rowArgs...)
		if err != nil {
			er.err = errors.Trace(err)
			return v, true
		}
		er.ctxStack = append(er.ctxStack, row)
		return v, true
	}

	physicalPlan, err := doOptimize(np, er.b.ctx, er.b.allocator)
	if err != nil {
		er.err = errors.Trace(err)
		return v, true
	}
	datums, err := EvalSubquery(physicalPlan, er.b.is, er.b.ctx)
	if err != nil {
		er.err = errors.Trace(err)
		return v, true
	}
	schema := np.GetSchema()
	if len(schema) <= 1 {
		er.ctxStack = append(er.ctxStack, &expression.Constant{
			Value:   datums[0],
			RetType: schema[0].GetType(),
		})
		return v, true
	}
	consts := make([]expression.Expression, 0, len(schema))
	for i, datum := range datums {
		consts = append(consts, &expression.Constant{
			Value:   datum,
			RetType: schema[i].GetType(),
		})
	}
	row, err := expression.NewFunction(ast.RowFunc, nil, consts...)
	if err != nil {
		er.err = errors.Trace(err)
		return v, true
	}
	er.ctxStack = append(er.ctxStack, row)
	return v, true
}
Ejemplo n.º 5
0
// caseToExpression pops the operands of the CASE expression v from ctxStack
// and pushes the resulting ast.Case function.
//
// Stack layout, top last:
//
//	value                        only when v.Value != nil (simple CASE)
//	cond1, result1, ..., condN, resultN
//	else                         only when v.ElseClause != nil
//
// For a simple CASE each condI is turned into `value = condI`, using a fresh
// clone of value for every comparison. The WHEN/THEN/ELSE operands must be
// single columns; otherwise er.err is set and nothing is popped.
func (er *expressionRewriter) caseToExpression(v *ast.CaseExpr) {
	stkLen := len(er.ctxStack)
	hasElse := v.ElseClause != nil
	clauseCount := 2 * len(v.WhenClauses)
	if hasElse {
		clauseCount++
	}
	clauses := er.ctxStack[stkLen-clauseCount:]
	er.checkArgsOneColumn(clauses...)
	if er.err != nil {
		return
	}

	popCount := clauseCount
	args := clauses
	if v.Value != nil {
		value := er.ctxStack[stkLen-clauseCount-1]
		args = make([]expression.Expression, 0, clauseCount)
		for j := 0; j+1 < clauseCount; j += 2 {
			eq, err := expression.NewFunction(ast.EQ, types.NewFieldType(mysql.TypeTiny), value.Clone(), clauses[j])
			if err != nil {
				er.err = errors.Trace(err)
				return
			}
			args = append(args, eq, clauses[j+1])
		}
		if hasElse {
			args = append(args, clauses[clauseCount-1])
		}
		popCount++ // the value operand is popped as well
	}

	function, err := expression.NewFunction(ast.Case, &v.Type, args...)
	if err != nil {
		er.err = errors.Trace(err)
		return
	}
	er.ctxStack = append(er.ctxStack[:stkLen-popCount], function)
}
Ejemplo n.º 6
0
// funcCallToScalarFunc replaces the len(v.Args) already-rewritten arguments
// on top of ctxStack with the scalar function v.FnName(args...).
//
// If the function cannot be built, er.err is set and ctxStack is left
// unchanged. Previously the arguments were still popped and a nil function
// was pushed.
func (er *expressionRewriter) funcCallToScalarFunc(v *ast.FuncCallExpr) {
	stackLen := len(er.ctxStack)
	var function *expression.ScalarFunction
	function, er.err = expression.NewFunction(v.FnName.L, v.Type, er.ctxStack[stackLen-len(v.Args):]...)
	if er.err != nil {
		return
	}
	er.ctxStack = er.ctxStack[:stackLen-len(v.Args)]
	er.ctxStack = append(er.ctxStack, function)
}
Ejemplo n.º 7
0
// binaryOpToExpression replaces the top two entries of ctxStack (lhs, rhs)
// with the binary operation v.Op applied to them.
//
// EQ, NE and NullEQ accept row operands and are expanded element-wise by
// constructBinaryOpFunction. GT/GE/LT/LE require both sides to have the same
// row width. Every other operator requires two single columns. A width error
// sets er.err and returns without touching the stack.
func (er *expressionRewriter) binaryOpToExpression(v *ast.BinaryOperationExpr) {
	stkLen := len(er.ctxStack)
	lhs, rhs := er.ctxStack[stkLen-2], er.ctxStack[stkLen-1]
	var function expression.Expression
	switch v.Op {
	case opcode.EQ, opcode.NE, opcode.NullEQ:
		function, er.err = er.constructBinaryOpFunction(lhs, rhs, opcode.Ops[v.Op])
	default:
		lLen, rLen := getRowLen(lhs), getRowLen(rhs)
		isOrdering := v.Op == opcode.GT || v.Op == opcode.GE || v.Op == opcode.LT || v.Op == opcode.LE
		if isOrdering && lLen != rLen {
			er.err = ErrOperandColumns.GenByArgs(lLen)
		} else if !isOrdering && (lLen != 1 || rLen != 1) {
			er.err = ErrOperandColumns.GenByArgs(1)
		}
		if er.err != nil {
			return
		}
		function, er.err = expression.NewFunction(opcode.Ops[v.Op], &v.Type, er.ctxStack[stkLen-2:]...)
	}
	if er.err != nil {
		er.err = errors.Trace(er.err)
		return
	}
	er.ctxStack = append(er.ctxStack[:stkLen-2], function)
}
Ejemplo n.º 8
0
// pushDownNot returns expr with a logical NOT (applied when not is true)
// pushed as far down as possible.
//
//   - NOT NOT x collapses to x.
//   - A negated comparison is replaced by its opposite (via oppositeOp).
//   - A negated AND/OR becomes OR/AND of negated arguments (De Morgan).
//   - Anything else that still needs negating is wrapped in a unary NOT.
//
// Arguments of comparison and AND/OR functions are rewritten in place
// (f.Args is mutated). Errors from NewFunction are ignored, as before.
func pushDownNot(expr expression.Expression, not bool) expression.Expression {
	if f, isFunc := expr.(*expression.ScalarFunction); isFunc {
		switch name := f.FuncName.L; name {
		case ast.UnaryNot:
			return pushDownNot(f.Args[0], !not)
		case ast.LT, ast.GE, ast.GT, ast.LE, ast.EQ, ast.NE:
			if not {
				negated, _ := expression.NewFunction(oppositeOp[name], f.GetType(), f.Args...)
				return negated
			}
			for i := range f.Args {
				f.Args[i] = pushDownNot(f.Args[i], false)
			}
			return f
		case ast.AndAnd, ast.OrOr:
			for i := range f.Args {
				f.Args[i] = pushDownNot(f.Args[i], not)
			}
			if !not {
				return f
			}
			dual := ast.AndAnd
			if name == ast.AndAnd {
				dual = ast.OrOr
			}
			combined, _ := expression.NewFunction(dual, f.GetType(), f.Args...)
			return combined
		}
	}
	if !not {
		return expr
	}
	wrapped, _ := expression.NewFunction(ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), expr)
	return wrapped
}
Ejemplo n.º 9
0
// betweenToScalarFunc rewrites `expr [NOT] BETWEEN low AND high`.
//
// The three operands are the top three entries of ctxStack (expr, low, high,
// top last). They are replaced by `expr >= low && expr <= high`, or by
// `expr < low || expr > high` when v.Not is set.
//
// NOTE(review): the same expr node is used as an argument of both
// comparisons without being copied. Later in-place rewrites of one
// comparison would also affect the other; confirm that this is safe.
func (er *expressionRewriter) betweenToScalarFunc(v *ast.BetweenExpr) {
	top := len(er.ctxStack)
	target, low, high := er.ctxStack[top-3], er.ctxStack[top-2], er.ctxStack[top-1]

	lowOp, highOp, joinOp := ast.GE, ast.LE, ast.AndAnd
	if v.Not {
		lowOp, highOp, joinOp = ast.LT, ast.GT, ast.OrOr
	}

	lowCmp := expression.NewFunction(lowOp, v.Type, target, low)
	highCmp := expression.NewFunction(highOp, v.Type, target, high)
	combined := expression.NewFunction(joinOp, v.Type, lowCmp, highCmp)
	er.ctxStack = append(er.ctxStack[:top-3], combined)
}
Ejemplo n.º 10
0
// notToScalarFunc builds op(args...) with return type tp and, when hasNot is
// true, wraps it in a unary NOT of the same type. On failure it records the
// traced error in er.err and returns nil.
func (er *expressionRewriter) notToScalarFunc(hasNot bool, op string, tp *types.FieldType,
	args ...expression.Expression) *expression.ScalarFunction {
	result, err := expression.NewFunction(op, tp, args...)
	if err == nil && hasNot {
		result, err = expression.NewFunction(ast.UnaryNot, tp, result)
	}
	if err != nil {
		er.err = errors.Trace(err)
		return nil
	}
	return result
}
Ejemplo n.º 11
0
// funcCallToExpression replaces the len(v.Args) already-rewritten arguments
// on top of ctxStack with the function v.FnName(args...).
//
// Every argument must be a single column. If that check fails or the
// function cannot be built, er.err is set and this function neither pops nor
// pushes. Previously a build failure still popped the arguments and pushed a
// nil function.
func (er *expressionRewriter) funcCallToExpression(v *ast.FuncCallExpr) {
	stackLen := len(er.ctxStack)
	args := er.ctxStack[stackLen-len(v.Args):]
	er.checkArgsOneColumn(args...)
	if er.err != nil {
		return
	}
	var function expression.Expression
	function, er.err = expression.NewFunction(v.FnName.L, &v.Type, args...)
	if er.err != nil {
		return
	}
	er.ctxStack = er.ctxStack[:stackLen-len(v.Args)]
	er.ctxStack = append(er.ctxStack, function)
}
Ejemplo n.º 12
0
// composeCondition joins CNF items with AND into a balanced binary tree.
//
// The list is split at its midpoint recursively, so the nesting depth grows
// logarithmically with the number of items. The original author noted this
// benefits the pb decoder/encoder a lot. It returns nil for an empty list and
// the item itself for a single-item list.
func composeCondition(conditions []expression.Expression) expression.Expression {
	switch n := len(conditions); n {
	case 0:
		return nil
	case 1:
		return conditions[0]
	default:
		mid := n / 2
		left := composeCondition(conditions[:mid])
		right := composeCondition(conditions[mid:])
		return expression.NewFunction(model.NewCIStr(ast.AndAnd), []expression.Expression{left, right})
	}
}
Ejemplo n.º 13
0
// handleInSubquery rewrites `expr [NOT] IN (subquery)`.
//
// The left expression is rewritten first and left on top of ctxStack. The
// subquery is planned and joined to er.p through an Apply whose checker is
// `expr = ANY(subq)` (IN) or `expr != ALL(subq)` (NOT IN). The stack top is
// then replaced by the Apply's last output column. It always returns
// (v, true); failures are reported through er.err.
func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, bool) {
	v.Expr.Accept(er)
	if er.err != nil {
		return v, true
	}
	lexpr := er.ctxStack[len(er.ctxStack)-1]
	subq, isSubquery := v.Sel.(*ast.SubqueryExpr)
	if !isSubquery {
		er.err = errors.Errorf("Unknown compare type %T.", v.Sel)
		return v, true
	}
	np, outerSchema := er.buildSubquery(subq)
	if er.err != nil {
		return v, true
	}
	innerCols := np.GetSchema()
	if width := getRowLen(lexpr); width != len(innerCols) {
		er.err = errors.Errorf("Operand should contain %d column(s)", width)
		return v, true
	}

	var rexpr expression.Expression
	if len(innerCols) == 1 {
		rexpr = innerCols[0].DeepCopy()
	} else {
		rowArgs := make([]expression.Expression, 0, len(innerCols))
		for _, col := range innerCols {
			rowArgs = append(rowArgs, col.DeepCopy())
		}
		rexpr, er.err = expression.NewFunction(ast.RowFunc, nil, rowArgs...)
		if er.err != nil {
			er.err = errors.Trace(er.err)
			return v, true
		}
	}

	// `a IN (subq)` is checked as `a = ANY(subq)`; `a NOT IN (subq)` as `a != ALL(subq)`.
	op, all := ast.EQ, false
	if v.Not {
		op, all = ast.NE, true
	}
	checkCondition, err := constructBinaryOpFunction(lexpr, rexpr, op)
	if err != nil {
		er.err = errors.Trace(err)
		return v, true
	}
	er.p = er.b.buildApply(er.p, np, outerSchema, &ApplyConditionChecker{Condition: checkCondition, All: all})
	if er.p.IsCorrelated() {
		er.correlated = true
	}
	// The parent expression only uses the Apply's last column, which tells whether the condition matched.
	outCols := er.p.GetSchema()
	er.ctxStack[len(er.ctxStack)-1] = outCols[len(outCols)-1]
	return v, true
}
Ejemplo n.º 14
0
// rowToScalarFunc pops the len(v.Values) row elements from ctxStack and
// pushes one ast.RowFunc built from them, in their original order. The
// elements are popped before the row is built, so they stay popped if
// building fails (er.err is set in that case).
func (er *expressionRewriter) rowToScalarFunc(v *ast.RowExpr) {
	base := len(er.ctxStack) - len(v.Values)
	elems := make([]expression.Expression, len(v.Values))
	copy(elems, er.ctxStack[base:])
	er.ctxStack = er.ctxStack[:base]
	row, err := expression.NewFunction(ast.RowFunc, nil, elems...)
	if err != nil {
		er.err = errors.Trace(err)
		return
	}
	er.ctxStack = append(er.ctxStack, row)
}
Ejemplo n.º 15
0
// constructBinaryOpFunction converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2).
//
// Nested rows are expanded recursively. Two single-column operands yield one
// comparison with a TINYINT result type. Operands of different row widths
// yield an "Operand should contain N column(s)" error, where N is the width
// of l.
func constructBinaryOpFunction(l expression.Expression, r expression.Expression, op string) (expression.Expression, error) {
	lLen, rLen := getRowLen(l), getRowLen(r)
	if lLen != rLen {
		return nil, errors.Errorf("Operand should contain %d column(s)", lLen)
	}
	if lLen == 1 {
		return expression.NewFunction(op, types.NewFieldType(mysql.TypeTiny), l, r)
	}
	conds := make([]expression.Expression, 0, lLen)
	for i := 0; i < lLen; i++ {
		cond, err := constructBinaryOpFunction(getRowArg(l, i), getRowArg(r, i), op)
		if err != nil {
			return nil, err
		}
		conds = append(conds, cond)
	}
	return expression.ComposeCNFCondition(conds), nil
}
Ejemplo n.º 16
0
// constructBinaryOpFunction converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2).
//
// Nested rows are expanded recursively. Two single-column operands yield one
// comparison with a TINYINT result type. Operands of different row widths
// yield ErrOperandColumns carrying the width of l.
func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, r expression.Expression, op string) (expression.Expression, error) {
	lLen, rLen := getRowLen(l), getRowLen(r)
	if lLen != rLen {
		return nil, ErrOperandColumns.GenByArgs(lLen)
	}
	if lLen == 1 {
		return expression.NewFunction(op, types.NewFieldType(mysql.TypeTiny), l, r)
	}
	conds := make([]expression.Expression, 0, lLen)
	for i := 0; i < lLen; i++ {
		cond, err := er.constructBinaryOpFunction(getRowArg(l, i), getRowArg(r, i), op)
		if err != nil {
			return nil, errors.Trace(err)
		}
		conds = append(conds, cond)
	}
	return expression.ComposeCNFCondition(conds), nil
}
Ejemplo n.º 17
0
// constantSubstitute substitutes column expressions in a condition by their
// equivalent constants.
//
// equalities maps a column's HashCode to the constant it equals. A column
// found in equalities is replaced by its constant. Arguments of a scalar
// function are substituted recursively, in place (expr.Args is mutated).
// Functions known to evaluator.Funcs are then rebuilt through NewFunction,
// presumably so that now-constant arguments can be folded. If that rebuild
// fails, the substituted function is kept; previously the condition was
// silently replaced by nil.
func constantSubstitute(equalities map[string]*expression.Constant, condition expression.Expression) expression.Expression {
	switch expr := condition.(type) {
	case *expression.Column:
		if v, ok := equalities[string(expr.HashCode())]; ok {
			return v
		}
	case *expression.ScalarFunction:
		for i, arg := range expr.Args {
			expr.Args[i] = constantSubstitute(equalities, arg)
		}
		if _, ok := evaluator.Funcs[expr.FuncName.L]; ok {
			if rebuilt, err := expression.NewFunction(expr.FuncName.L, expr.RetType, expr.Args...); err == nil {
				return rebuilt
			}
		}
	}
	return condition
}
Ejemplo n.º 18
0
// binaryOpToScalarFunc replaces the top two entries of ctxStack with the
// binary operation v.Op applied to them. EQ, NE and NullEQ go through
// constructBinaryOpFunction so that row operands are expanded element-wise.
// On failure er.err is set and the stack is left unchanged.
func (er *expressionRewriter) binaryOpToScalarFunc(v *ast.BinaryOperationExpr) {
	stkLen := len(er.ctxStack)
	operands := er.ctxStack[stkLen-2:]
	funcName := opcode.Ops[v.Op]
	var function expression.Expression
	if v.Op == opcode.EQ || v.Op == opcode.NE || v.Op == opcode.NullEQ {
		function, er.err = constructBinaryOpFunction(operands[0], operands[1], funcName)
	} else {
		function, er.err = expression.NewFunction(funcName, v.Type, operands...)
	}
	if er.err != nil {
		er.err = errors.Trace(er.err)
		return
	}
	er.ctxStack = append(er.ctxStack[:stkLen-2], function)
}
Ejemplo n.º 19
0
// unaryOpToScalarFunc replaces the top entry of ctxStack with the unary
// operator v.Op (+, -, ~, NOT) applied to it.
//
// The operand must be a single column. Otherwise er.err is set and the
// function returns at once. Previously it fell through, and the NewFunction
// assignment below overwrote (lost) that error. An unknown operator also sets
// er.err.
func (er *expressionRewriter) unaryOpToScalarFunc(v *ast.UnaryOperationExpr) {
	stkLen := len(er.ctxStack)
	if getRowLen(er.ctxStack[stkLen-1]) != 1 {
		er.err = errors.New("Operand should contain 1 column(s)")
		return
	}
	var op string
	switch v.Op {
	case opcode.Plus:
		op = ast.UnaryPlus
	case opcode.Minus:
		op = ast.UnaryMinus
	case opcode.BitNeg:
		op = ast.BitNeg
	case opcode.Not:
		op = ast.UnaryNot
	default:
		// %v (not %T) so the message names the offending operator rather than its Go type.
		er.err = errors.Errorf("Unknown Unary Op %v", v.Op)
		return
	}
	er.ctxStack[stkLen-1], er.err = expression.NewFunction(op, v.Type, er.ctxStack[stkLen-1])
}
Ejemplo n.º 20
0
// extractOnCondition splits join conditions into four groups:
//
//   - eqCond: `col = col` predicates with one column from each side,
//     normalized so that the left-side column comes first;
//   - rightCond: predicates whose columns all belong to right (a predicate
//     with no columns also lands here);
//   - leftCond: remaining predicates whose columns all belong to left;
//   - otherCond: everything else.
func extractOnCondition(conditions []expression.Expression, left LogicalPlan, right LogicalPlan) (
	eqCond []*expression.ScalarFunction, leftCond []expression.Expression, rightCond []expression.Expression,
	otherCond []expression.Expression) {
	leftSchema, rightSchema := left.GetSchema(), right.GetSchema()
	for _, expr := range conditions {
		if binop, ok := expr.(*expression.ScalarFunction); ok && binop.FuncName.L == ast.EQ {
			ln, lOK := binop.Args[0].(*expression.Column)
			rn, rOK := binop.Args[1].(*expression.Column)
			if lOK && rOK {
				if leftSchema.GetIndex(ln) != -1 && rightSchema.GetIndex(rn) != -1 {
					eqCond = append(eqCond, binop)
					continue
				}
				if leftSchema.GetIndex(rn) != -1 && rightSchema.GetIndex(ln) != -1 {
					// Swap the operands so that the left-side column comes first.
					swapped, _ := expression.NewFunction(ast.EQ, types.NewFieldType(mysql.TypeTiny), rn, ln)
					eqCond = append(eqCond, swapped.(*expression.ScalarFunction))
					continue
				}
			}
		}
		columns, _ := extractColumn(expr, nil, nil)
		allFromLeft, allFromRight := true, true
		for _, col := range columns {
			allFromLeft = allFromLeft && leftSchema.GetIndex(col) != -1
			allFromRight = allFromRight && rightSchema.GetIndex(col) != -1
		}
		switch {
		case allFromRight:
			rightCond = append(rightCond, expr)
		case allFromLeft:
			leftCond = append(leftCond, expr)
		default:
			otherCond = append(otherCond, expr)
		}
	}
	return
}
Ejemplo n.º 21
0
// extractOnCondition splits join conditions into four groups:
//
//   - eqCond: `col = col` predicates with one column from each side,
//     normalized so that the left-side column comes first;
//   - rightCond: predicates with no column found in left's schema (every
//     column not in left is assumed to come from right; a predicate with no
//     columns also lands here);
//   - leftCond: remaining predicates whose columns are all in left's schema;
//   - otherCond: everything else.
func extractOnCondition(conditions []expression.Expression, left Plan, right Plan) (
	eqCond []*expression.ScalarFunction, leftCond []expression.Expression, rightCond []expression.Expression,
	otherCond []expression.Expression) {
	leftSchema, rightSchema := left.GetSchema(), right.GetSchema()
	for _, expr := range conditions {
		if binop, ok := expr.(*expression.ScalarFunction); ok && binop.FuncName.L == ast.EQ {
			ln, lOK := binop.Args[0].(*expression.Column)
			rn, rOK := binop.Args[1].(*expression.Column)
			if lOK && rOK {
				switch {
				case leftSchema.GetIndex(ln) != -1 && rightSchema.GetIndex(rn) != -1:
					eqCond = append(eqCond, binop)
					continue
				case leftSchema.GetIndex(rn) != -1 && rightSchema.GetIndex(ln) != -1:
					// Swap the operands so that the left-side column comes first.
					swapped := expression.NewFunction(model.NewCIStr(ast.EQ), []expression.Expression{rn, ln})
					eqCond = append(eqCond, swapped)
					continue
				}
			}
		}
		columns, _ := extractColumn(expr, nil, nil)
		allFromLeft, allFromRight := true, true
		for _, col := range columns {
			inLeft := leftSchema.GetIndex(col) != -1
			allFromRight = allFromRight && !inLeft
			allFromLeft = allFromLeft && inLeft
		}
		switch {
		case allFromRight:
			rightCond = append(rightCond, expr)
		case allFromLeft:
			leftCond = append(leftCond, expr)
		default:
			otherCond = append(otherCond, expr)
		}
	}
	return
}
Ejemplo n.º 22
0
// calculateResultOfExpression sets the columns of expr that belong to schema
// (the inner table) to NULL and rebuilds the expression, so that the final
// result of the scalar function can be calculated.
//
// Scalar functions are rebuilt recursively with a TINYINT return type. A
// column found in schema becomes a NULL constant, other columns are returned
// as is, and any other expression is deep-copied. An error from any argument
// is returned; previously only the last argument's error was checked, because
// err was overwritten on every iteration.
func calculateResultOfExpression(schema expression.Schema, expr expression.Expression) (expression.Expression, error) {
	switch x := expr.(type) {
	case *expression.ScalarFunction:
		args := make([]expression.Expression, len(x.Args))
		for i, arg := range x.Args {
			newArg, err := calculateResultOfExpression(schema, arg)
			if err != nil {
				return nil, errors.Trace(err)
			}
			args[i] = newArg
		}
		return expression.NewFunction(x.FuncName.L, types.NewFieldType(mysql.TypeTiny), args...)
	case *expression.Column:
		if schema.GetIndex(x) == -1 {
			return x, nil
		}
		constant := &expression.Constant{Value: types.Datum{}}
		constant.Value.SetNull()
		return constant, nil
	default:
		return x.DeepCopy(), nil
	}
}
Ejemplo n.º 23
0
// unaryOpToExpression replaces the top entry of ctxStack with the unary
// operator v.Op applied to it. Unary plus is a no-op, because (+ a) equals a.
//
// er.err is set if the operator is unknown or if the operand is a row with
// more than one column. The unknown-operator message now formats v.Op with
// %v; previously %T printed only its Go type name.
func (er *expressionRewriter) unaryOpToExpression(v *ast.UnaryOperationExpr) {
	stkLen := len(er.ctxStack)
	var op string
	switch v.Op {
	case opcode.Plus:
		// expression (+ a) is equal to a
		return
	case opcode.Minus:
		op = ast.UnaryMinus
	case opcode.BitNeg:
		op = ast.BitNeg
	case opcode.Not:
		op = ast.UnaryNot
	default:
		er.err = errors.Errorf("Unknown Unary Op %v", v.Op)
		return
	}
	if getRowLen(er.ctxStack[stkLen-1]) != 1 {
		er.err = ErrOperandColumns.GenByArgs(1)
		return
	}
	er.ctxStack[stkLen-1], er.err = expression.NewFunction(op, &v.Type, er.ctxStack[stkLen-1])
}
Ejemplo n.º 24
0
// propagateConstant propagates constant values of equality predicates and
// inequality predicates through a CNF list of conditions and returns the
// rewritten list.
//
// The input slice is rewritten in place: entries are substituted, and
// propagated inequality predicates are removed with append. Callers must use
// the returned slice.
func propagateConstant(conditions []expression.Expression) []expression.Expression {
	if len(conditions) == 0 {
		return conditions
	}
	// Propagate constants in equality predicates.
	// e.g. for condition: "a = b and b = c and c = a and a = 1";
	// we propagate constant as the following step:
	// first: "1 = b and b = c and c = 1 and a = 1";
	// next:  "1 = b and 1 = c and c = 1 and a = 1";
	// next:  "1 = b and 1 = c and 1 = 1 and a = 1";
	// next:  "1 = b and 1 = c and a = 1";

	// e.g for condition: "a = b and b = c and b = 2 and a = 1";
	// we propagate constant as the following step:
	// first: "a = 2 and 2 = c and b = 2 and a = 1";
	// next:  "a = 2 and 2 = c and b = 2 and 2 = 1";
	// next:  "0"
	isSource := make([]bool, len(conditions))
	type transitiveEqualityPredicate map[string]*expression.Constant // transitive equality predicates between one column and one constant
	for {
		equalities := make(transitiveEqualityPredicate, 0)
		// Collect at most one new `col = const` equality per round.
		for i, getOneEquality := 0, false; i < len(conditions) && !getOneEquality; i++ {
			if isSource[i] {
				continue
			}
			expr, ok := conditions[i].(*expression.ScalarFunction)
			if !ok {
				continue
			}
			// process the included OR conditions recursively to do the same for CNF item.
			switch expr.FuncName.L {
			case ast.OrOr:
				expressions := expression.SplitDNFItems(conditions[i])
				newExpression := make([]expression.Expression, 0)
				for _, v := range expressions {
					newExpression = append(newExpression, propagateConstant([]expression.Expression{v})...)
				}
				conditions[i] = expression.ComposeDNFCondition(newExpression)
				isSource[i] = true
			case ast.AndAnd:
				newExpression := propagateConstant(expression.SplitCNFItems(conditions[i]))
				conditions[i] = expression.ComposeCNFCondition(newExpression)
				isSource[i] = true
			case ast.EQ:
				var (
					col *expression.Column
					val *expression.Constant
				)
				leftConst, leftIsConst := expr.Args[0].(*expression.Constant)
				rightConst, rightIsConst := expr.Args[1].(*expression.Constant)
				leftCol, leftIsCol := expr.Args[0].(*expression.Column)
				rightCol, rightIsCol := expr.Args[1].(*expression.Column)
				if rightIsConst && leftIsCol {
					col = leftCol
					val = rightConst
				} else if leftIsConst && rightIsCol {
					col = rightCol
					val = leftConst
				} else {
					continue
				}
				equalities[string(col.HashCode())] = val
				isSource[i] = true
				getOneEquality = true
			}
		}
		if len(equalities) == 0 {
			break
		}
		for i := 0; i < len(conditions); i++ {
			if !isSource[i] {
				conditions[i] = constantSubstitute(equalities, conditions[i])
			}
		}
	}
	// Propagate transitive inequality predicates.
	// e.g for conditions "a = b and c = d and a = c and g = h and b > 0 and e != 0 and g like 'abc'",
	//     we propagate constant as the following step:
	// 1. build multiple equality predicates(mep):
	//    =(a, b, c, d), =(g, h).
	// 2. extract inequality predicates between one constant and one column,
	//    and rewrite them using the root column of a multiple equality predicate:
	//    b > 0, e != 0, g like 'abc' ==> a > 0, g like 'abc'.
	//    ATTENTION: here column 'e' doesn't belong to any mep, so we skip "e != 0".
	// 3. propagate constants in these inequality predicates, and we finally get:
	//    "a = b and c = d and a = c and g = h and e != 0 and a > 0 and b > 0 and c > 0 and d > 0 and g like 'abc' and h like 'abc' ".
	multipleEqualities := make(map[*expression.Column]*expression.Column, 0)
	for _, cond := range conditions { // build multiple equality predicates.
		expr, ok := cond.(*expression.ScalarFunction)
		if ok && expr.FuncName.L == ast.EQ {
			left, ok1 := expr.Args[0].(*expression.Column)
			right, ok2 := expr.Args[1].(*expression.Column)
			if ok1 && ok2 {
				UnionColumns(left, right, multipleEqualities)
			}
		}
	}
	if len(multipleEqualities) == 0 {
		return conditions
	}
	inequalityFuncs := map[string]string{
		ast.LT:   ast.LT,
		ast.GT:   ast.GT,
		ast.LE:   ast.LE,
		ast.GE:   ast.GE,
		ast.NE:   ast.NE,
		ast.Like: ast.Like,
	}
	// Propagated predicates are always rebuilt as `col op const`, so `const op col`
	// must have its operator mirrored (`0 < b` is `b > 0`). LIKE is deliberately
	// absent: its right operand is the pattern, so `'x' LIKE col` cannot be
	// rewritten as `col LIKE 'x'`.
	mirroredInequalityFuncs := map[string]string{
		ast.LT: ast.GT,
		ast.GT: ast.LT,
		ast.LE: ast.GE,
		ast.GE: ast.LE,
		ast.NE: ast.NE,
	}
	type inequalityFactor struct {
		FuncName string
		Factor   []*expression.Constant
	}
	type transitiveInEqualityPredicate map[string][]inequalityFactor // transitive inequality predicates between one column and one constant.
	inequalities := make(transitiveInEqualityPredicate, 0)
	for i := 0; i < len(conditions); i++ { // extract inequality predicates.
		var (
			column   *expression.Column
			equalCol *expression.Column // the root column corresponding to a column in a multiple equality predicate.
			val      *expression.Constant
			funcName string
		)
		expr, ok := conditions[i].(*expression.ScalarFunction)
		if !ok {
			continue
		}
		funcName, ok = inequalityFuncs[expr.FuncName.L]
		if !ok {
			continue
		}
		leftConst, leftIsConst := expr.Args[0].(*expression.Constant)
		rightConst, rightIsConst := expr.Args[1].(*expression.Constant)
		leftCol, leftIsCol := expr.Args[0].(*expression.Column)
		rightCol, rightIsCol := expr.Args[1].(*expression.Column)
		if rightIsConst && leftIsCol {
			column = leftCol
			val = rightConst
		} else if leftIsConst && rightIsCol {
			column = rightCol
			val = leftConst
			if funcName, ok = mirroredInequalityFuncs[funcName]; !ok {
				continue
			}
		} else {
			continue
		}
		equalCol, ok = multipleEqualities[column]
		if !ok { // no need to propagate inequality predicates whose column is only equal to itself.
			continue
		}
		factor := inequalityFactor{FuncName: funcName, Factor: []*expression.Constant{val}}
		if funcName == ast.Like { // func 'LIKE' need 3 input arguments, so here we handle it alone.
			escape, isConst := expr.Args[2].(*expression.Constant)
			if !isConst {
				continue
			}
			factor.Factor = append(factor.Factor, escape)
		}
		colHashCode := string(equalCol.HashCode())
		inequalities[colHashCode] = append(inequalities[colHashCode], factor)
		conditions = append(conditions[:i], conditions[i+1:]...)
		i--
	}
	if len(inequalities) == 0 {
		return conditions
	}
	// NOTE(review): NewFunction errors are ignored below, as in the original; a
	// failure would append nil and the source predicate has already been removed.
	for k, v := range multipleEqualities { // propagate constants in inequality predicates.
		for _, x := range inequalities[string(v.HashCode())] {
			funcName, factors := x.FuncName, x.Factor
			if funcName == ast.Like {
				for i := 0; i+1 < len(factors); i += 2 {
					newFunc, _ := expression.NewFunction(funcName, types.NewFieldType(mysql.TypeTiny), k, factors[i], factors[i+1])
					conditions = append(conditions, newFunc)
				}
			} else {
				for _, factor := range factors {
					newFunc, _ := expression.NewFunction(funcName, types.NewFieldType(mysql.TypeTiny), k, factor)
					conditions = append(conditions, newFunc)
				}
			}
		}
	}
	return conditions
}
Ejemplo n.º 25
0
// Enter implements Visitor interface.
//
// It intercepts the nodes the rewriter handles itself:
//   - aggregate functions and mapped column names are resolved to columns of er.schema;
//   - compare and IN subqueries are planned and attached to er.p through an Apply;
//   - EXISTS and scalar subqueries also go through an Apply when correlated, and
//     otherwise are evaluated immediately and turned into constants.
//
// Results are pushed onto (or replace the top of) er.ctxStack. Failures are
// reported through er.err. The second result is true when the node was handled
// here (presumably telling the walker not to visit its children); it is false
// for every other node.
func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
	switch v := inNode.(type) {
	case *ast.AggregateFuncExpr:
		// Aggregates must already be mapped (via er.aggrMap) to a column index of er.schema.
		index, ok := -1, false
		if er.aggrMap != nil {
			index, ok = er.aggrMap[v]
		}
		if !ok {
			er.err = errors.New("Can't appear aggrFunctions")
			return inNode, true
		}
		er.ctxStack = append(er.ctxStack, er.schema[index])
		return inNode, true
	case *ast.ColumnNameExpr:
		// Columns already resolved in er.b.colMapper are pushed directly; others
		// fall through to the final `return inNode, false`.
		if index, ok := er.b.colMapper[v]; ok {
			er.ctxStack = append(er.ctxStack, er.schema[index])
			return inNode, true
		}
	case *ast.CompareSubqueryExpr:
		// `lexpr op ANY|ALL (subq)`: rewrite the left operand first. It stays on
		// top of the stack and is replaced below by the Apply's result column.
		v.L.Accept(er)
		if er.err != nil {
			return inNode, true
		}
		lexpr := er.ctxStack[len(er.ctxStack)-1]
		subq, ok := v.R.(*ast.SubqueryExpr)
		if !ok {
			er.err = errors.Errorf("Unknown compare type %T.", v.R)
			return inNode, true
		}
		np, outerSchema := er.buildSubquery(subq)
		if er.err != nil {
			return inNode, true
		}
		// Only (a,b,c) = any (...) and (a,b,c) != all (...) can use row expression.
		canMultiCol := (!v.All && v.Op == opcode.EQ) || (v.All && v.Op == opcode.NE)
		if !canMultiCol && (getRowLen(lexpr) != 1 || len(np.GetSchema()) != 1) {
			er.err = errors.New("Operand should contain 1 column(s)")
			return inNode, true
		}
		if getRowLen(lexpr) != len(np.GetSchema()) {
			er.err = errors.Errorf("Operand should contain %d column(s)", getRowLen(lexpr))
			return inNode, true
		}
		var checkCondition expression.Expression
		var rexpr expression.Expression
		// The right operand is the subquery's single column, or a row of all its columns.
		if len(np.GetSchema()) == 1 {
			rexpr = np.GetSchema()[0].DeepCopy()
		} else {
			args := make([]expression.Expression, 0, len(np.GetSchema()))
			for _, col := range np.GetSchema() {
				args = append(args, col.DeepCopy())
			}
			rexpr, er.err = expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), args...)
			if er.err != nil {
				er.err = errors.Trace(er.err)
				return inNode, true
			}
		}
		switch v.Op {
		// Only EQ, NE and NullEQ can be composed with and.
		case opcode.EQ, opcode.NE, opcode.NullEQ:
			checkCondition, er.err = constructBinaryOpFunction(lexpr, rexpr, opcode.Ops[v.Op])
			if er.err != nil {
				er.err = errors.Trace(er.err)
				return inNode, true
			}
		// If op is not EQ, NE, NullEQ, say LT, it will remain as row(a,b) < row(c,d), and be compared as row datum.
		default:
			checkCondition, er.err = expression.NewFunction(opcode.Ops[v.Op],
				types.NewFieldType(mysql.TypeTiny), lexpr, rexpr)
			if er.err != nil {
				er.err = errors.Trace(er.err)
				return inNode, true
			}
		}
		er.p = er.b.buildApply(er.p, np, outerSchema, &ApplyConditionChecker{Condition: checkCondition, All: v.All})
		// The parent expression only use the last column in schema, which represents whether the condition is matched.
		er.ctxStack[len(er.ctxStack)-1] = er.p.GetSchema()[len(er.p.GetSchema())-1]
		return inNode, true
	case *ast.ExistsSubqueryExpr:
		// EXISTS (subq): when correlated, join through an Apply and push its last
		// output column; otherwise evaluate the plan now and push the first
		// resulting datum as a constant.
		subq, ok := v.Sel.(*ast.SubqueryExpr)
		if !ok {
			er.err = errors.Errorf("Unknown exists type %T.", v.Sel)
			return inNode, true
		}
		np, outerSchema := er.buildSubquery(subq)
		if er.err != nil {
			return inNode, true
		}
		np = er.b.buildExists(np)
		if np.IsCorrelated() {
			er.p = er.b.buildApply(er.p, np, outerSchema, nil)
			er.ctxStack = append(er.ctxStack, er.p.GetSchema()[len(er.p.GetSchema())-1])
		} else {
			_, err := np.PruneColumnsAndResolveIndices(np.GetSchema())
			if err != nil {
				er.err = errors.Trace(err)
				return inNode, true
			}
			d, err := EvalSubquery(np, er.b.is, er.b.ctx)
			if err != nil {
				er.err = errors.Trace(err)
				return inNode, true
			}
			er.ctxStack = append(er.ctxStack, &expression.Constant{
				Value:   d[0],
				RetType: np.GetSchema()[0].GetType()})
		}
		return inNode, true
	case *ast.PatternInExpr:
		// Only `expr [NOT] IN (subquery)` is handled here; when v.Sel is nil the
		// node falls through to the final `return inNode, false`.
		if v.Sel != nil {
			v.Expr.Accept(er)
			if er.err != nil {
				return inNode, true
			}
			lexpr := er.ctxStack[len(er.ctxStack)-1]
			subq, ok := v.Sel.(*ast.SubqueryExpr)
			if !ok {
				er.err = errors.Errorf("Unknown compare type %T.", v.Sel)
				return inNode, true
			}
			np, outerSchema := er.buildSubquery(subq)
			if er.err != nil {
				return inNode, true
			}
			if getRowLen(lexpr) != len(np.GetSchema()) {
				er.err = errors.Errorf("Operand should contain %d column(s)", getRowLen(lexpr))
				return inNode, true
			}
			var rexpr expression.Expression
			if len(np.GetSchema()) == 1 {
				rexpr = np.GetSchema()[0].DeepCopy()
			} else {
				args := make([]expression.Expression, 0, len(np.GetSchema()))
				for _, col := range np.GetSchema() {
					args = append(args, col.DeepCopy())
				}
				rexpr, er.err = expression.NewFunction(ast.RowFunc, nil, args...)
				if er.err != nil {
					er.err = errors.Trace(er.err)
					return inNode, true
				}
			}
			// a in (subq) will be rewrited as a = any(subq).
			// a not in (subq) will be rewrited as a != all(subq).
			op, all := ast.EQ, false
			if v.Not {
				op, all = ast.NE, true
			}
			checkCondition, err := constructBinaryOpFunction(lexpr, rexpr, op)
			if err != nil {
				er.err = errors.Trace(err)
				return inNode, true
			}
			er.p = er.b.buildApply(er.p, np, outerSchema, &ApplyConditionChecker{Condition: checkCondition, All: all})
			// The parent expression only use the last column in schema, which represents whether the condition is matched.
			er.ctxStack[len(er.ctxStack)-1] = er.p.GetSchema()[len(er.p.GetSchema())-1]
			return inNode, true
		}
	case *ast.SubqueryExpr:
		// Scalar (or row) subquery, wrapped with buildMaxOneRow. When correlated,
		// its columns are referenced through an Apply; otherwise it is evaluated
		// now and its values are pushed as constants. More than one column yields
		// a RowFunc.
		np, outerSchema := er.buildSubquery(v)
		if er.err != nil {
			return inNode, true
		}
		np = er.b.buildMaxOneRow(np)
		if np.IsCorrelated() {
			er.p = er.b.buildApply(er.p, np, outerSchema, nil)
			if len(np.GetSchema()) > 1 {
				newCols := make([]expression.Expression, 0, len(np.GetSchema()))
				for _, col := range np.GetSchema() {
					newCols = append(newCols, col.DeepCopy())
				}
				expr, err := expression.NewFunction(ast.RowFunc, nil, newCols...)
				if err != nil {
					er.err = errors.Trace(err)
					return inNode, true
				}
				er.ctxStack = append(er.ctxStack, expr)

			} else {
				er.ctxStack = append(er.ctxStack, er.p.GetSchema()[len(er.p.GetSchema())-1])
			}
			return inNode, true
		}
		_, err := np.PruneColumnsAndResolveIndices(np.GetSchema())
		if err != nil {
			er.err = errors.Trace(err)
			return inNode, true
		}
		d, err := EvalSubquery(np, er.b.is, er.b.ctx)
		if err != nil {
			er.err = errors.Trace(err)
			return inNode, true
		}
		if len(np.GetSchema()) > 1 {
			newCols := make([]expression.Expression, 0, len(np.GetSchema()))
			for i, data := range d {
				newCols = append(newCols, &expression.Constant{
					Value:   data,
					RetType: np.GetSchema()[i].GetType()})
			}
			expr, err1 := expression.NewFunction(ast.RowFunc, nil, newCols...)
			if err1 != nil {
				er.err = errors.Trace(err1)
				return inNode, true
			}
			er.ctxStack = append(er.ctxStack, expr)
		} else {
			er.ctxStack = append(er.ctxStack, &expression.Constant{
				Value:   d[0],
				RetType: np.GetSchema()[0].GetType(),
			})
		}
		return inNode, true
	}
	return inNode, false
}
Ejemplo n.º 26
0
// rewriteVariable rewrites a user or system variable reference (selected by
// v.IsSystem) and leaves the resulting expression on er.ctxStack.
//
// User variables:
//   - when v.Value is non-nil, the expression on top of the stack (presumably
//     the already-rewritten value; TODO confirm against the caller) is replaced
//     by a SetVar call taking the lower-cased name and that value;
//   - otherwise a GetVar call is pushed when the name exists in the session's
//     user variables, or a NULL constant when it does not.
//
// System variables must be present in variable.SysVars, otherwise an
// UnknownSystemVar error is recorded. ScopeNone variables push sysVar.Value;
// global references (v.IsGlobal) read the global value; all others read the
// session value. A NULL session value falls back to sysVar.Value for variables
// without global scope. For variables with global scope it falls back to the
// global value, which is then stored in the session. System values are pushed
// as string constants.
//
// Errors are recorded in er.err.
func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) {
	stkLen := len(er.ctxStack)
	name := strings.ToLower(v.Name)
	sessionVars := variable.GetSessionVars(er.b.ctx)
	globalVars := variable.GetGlobalVarAccessor(er.b.ctx)
	if !v.IsSystem {
		if v.Value != nil {
			// Build into a local first so a failed construction does not overwrite
			// the stack slot with a nil (or typed-nil) expression.
			f, err := expression.NewFunction(ast.SetVar,
				er.ctxStack[stkLen-1].GetType(),
				datumToConstant(types.NewDatum(name), mysql.TypeString),
				er.ctxStack[stkLen-1])
			if err != nil {
				er.err = errors.Trace(err)
				return
			}
			er.ctxStack[stkLen-1] = f
			return
		}
		if _, ok := sessionVars.Users[name]; ok {
			f, err := expression.NewFunction(ast.GetVar,
				// TODO: Here is wrong, the sessionVars should store a name -> Datum map. Will fix it later.
				types.NewFieldType(mysql.TypeString),
				datumToConstant(types.NewStringDatum(name), mysql.TypeString))
			if err != nil {
				er.err = errors.Trace(err)
				return
			}
			er.ctxStack = append(er.ctxStack, f)
		} else {
			// select null user vars is permitted.
			er.ctxStack = append(er.ctxStack, &expression.Constant{RetType: types.NewFieldType(mysql.TypeNull)})
		}
		return
	}

	sysVar, ok := variable.SysVars[name]
	if !ok {
		// select null sys vars is not permitted
		er.err = variable.UnknownSystemVar.Gen("Unknown system variable '%s'", name)
		return
	}
	if sysVar.Scope == variable.ScopeNone {
		er.ctxStack = append(er.ctxStack, datumToConstant(types.NewDatum(sysVar.Value), mysql.TypeString))
		return
	}

	if v.IsGlobal {
		value, err := globalVars.GetGlobalSysVar(er.b.ctx, name)
		if err != nil {
			er.err = errors.Trace(err)
			return
		}
		er.ctxStack = append(er.ctxStack, datumToConstant(types.NewDatum(value), mysql.TypeString))
		return
	}
	d := sessionVars.GetSystemVar(name)
	if d.IsNull() {
		if sysVar.Scope&variable.ScopeGlobal == 0 {
			d.SetString(sysVar.Value)
		} else {
			// Get global system variable and fill it in session.
			globalVal, err := globalVars.GetGlobalSysVar(er.b.ctx, name)
			if err != nil {
				er.err = errors.Trace(err)
				return
			}
			d.SetString(globalVal)
			if err = sessionVars.SetSystemVar(name, d); err != nil {
				er.err = errors.Trace(err)
				return
			}
		}
	}
	er.ctxStack = append(er.ctxStack, datumToConstant(d, mysql.TypeString))
}
Ejemplo n.º 27
0
// handleScalarSubquery rewrites a subquery used as a value and pushes the
// resulting expression onto er.ctxStack.
//
// The subquery plan is first wrapped with buildMaxOneRow. Then:
//   - a correlated plan is attached to er.p with buildApply, and er.correlated
//     is set when the resulting er.p is still correlated. A single-column
//     subquery is referenced through the last column of er.p's schema; a
//     multi-column one becomes a row function over deep copies of the
//     subquery's columns.
//   - an uncorrelated plan goes through predicate push-down, column pruning
//     and physical-plan conversion. It is then evaluated once with
//     EvalSubquery and replaced by constants typed after the subquery's
//     columns, wrapped in a row function when there is more than one column.
//
// Errors are recorded in er.err. The returned node is always v and the bool
// is always true.
func (er *expressionRewriter) handleScalarSubquery(v *ast.SubqueryExpr) (ast.Node, bool) {
	np, outerSchema := er.buildSubquery(v)
	if er.err != nil {
		return v, true
	}
	np = er.b.buildMaxOneRow(np)

	if np.IsCorrelated() {
		er.p = er.b.buildApply(er.p, np, outerSchema, nil)
		if er.p.IsCorrelated() {
			er.correlated = true
		}
		subCols := np.GetSchema()
		if len(subCols) <= 1 {
			applyCols := er.p.GetSchema()
			er.ctxStack = append(er.ctxStack, applyCols[len(applyCols)-1])
			return v, true
		}
		rowArgs := make([]expression.Expression, 0, len(subCols))
		for _, c := range subCols {
			rowArgs = append(rowArgs, c.DeepCopy())
		}
		rowFunc, err := expression.NewFunction(ast.RowFunc, nil, rowArgs...)
		if err != nil {
			er.err = errors.Trace(err)
			return v, true
		}
		er.ctxStack = append(er.ctxStack, rowFunc)
		return v, true
	}

	// Uncorrelated: optimize, evaluate once, and inline the result as constants.
	if _, np, er.err = np.PredicatePushDown(nil); er.err != nil {
		return v, true
	}
	if _, err := np.PruneColumnsAndResolveIndices(np.GetSchema()); err != nil {
		er.err = errors.Trace(err)
		return v, true
	}
	_, best, _, err := np.convert2PhysicalPlan(nil)
	if err != nil {
		er.err = errors.Trace(err)
		return v, true
	}
	values, err := EvalSubquery(best.p.PushLimit(nil), er.b.is, er.b.ctx)
	if err != nil {
		er.err = errors.Trace(err)
		return v, true
	}

	cols := np.GetSchema()
	if len(cols) <= 1 {
		er.ctxStack = append(er.ctxStack, &expression.Constant{
			Value:   values[0],
			RetType: cols[0].GetType(),
		})
		return v, true
	}
	consts := make([]expression.Expression, 0, len(cols))
	for i := range values {
		consts = append(consts, &expression.Constant{
			Value:   values[i],
			RetType: cols[i].GetType(),
		})
	}
	rowFunc, err := expression.NewFunction(ast.RowFunc, nil, consts...)
	if err != nil {
		er.err = errors.Trace(err)
		return v, true
	}
	er.ctxStack = append(er.ctxStack, rowFunc)
	return v, true
}
Ejemplo n.º 28
0
// handleInSubquery rewrites `expr [NOT] IN (subquery)`.
//
// The left operand is rewritten first, with er.asScalar forced to true, and its
// result is read from the top of er.ctxStack. The subquery's output columns
// form the right operand: a single column, or a row function over all columns.
// Their column counts must match. Membership is expressed as an equality
// condition between the two sides:
//   - for an uncorrelated subquery, er.p becomes a semi join. If the caller
//     needed a scalar result (the saved asScalar), the top of ctxStack is
//     replaced by the join's last output column; otherwise it is popped.
//   - for a correlated subquery, er.p becomes an Apply. Its checker uses the
//     condition, negated and with All set for NOT IN. The top of ctxStack is
//     replaced by the Apply's last output column.
//
// Errors are recorded in er.err. The returned node is always v and the bool is
// always true.
func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, bool) {
	asScalar := er.asScalar
	// NOTE(review): er.asScalar is not restored to the saved value in this
	// function; confirm that callers reset it as needed.
	er.asScalar = true
	v.Expr.Accept(er)
	if er.err != nil {
		return v, true
	}
	lexpr := er.ctxStack[len(er.ctxStack)-1]
	subq, ok := v.Sel.(*ast.SubqueryExpr)
	if !ok {
		er.err = errors.Errorf("Unknown compare type %T.", v.Sel)
		return v, true
	}
	np := er.buildSubquery(subq)
	if er.err != nil {
		return v, true
	}
	lLen := getRowLen(lexpr)
	if lLen != len(np.GetSchema()) {
		er.err = ErrOperandColumns.GenByArgs(lLen)
		return v, true
	}
	var rexpr expression.Expression
	if len(np.GetSchema()) == 1 {
		rexpr = np.GetSchema()[0].Clone()
	} else {
		args := make([]expression.Expression, 0, len(np.GetSchema()))
		for _, col := range np.GetSchema() {
			args = append(args, col.Clone())
		}
		rexpr, er.err = expression.NewFunction(ast.RowFunc, nil, args...)
		if er.err != nil {
			er.err = errors.Trace(er.err)
			return v, true
		}
	}
	// a in (subq) is rewritten as a = any(subq).
	// a not in (subq) is rewritten as a != all(subq).
	checkCondition, err := er.constructBinaryOpFunction(lexpr, rexpr, ast.EQ)
	if err != nil {
		er.err = errors.Trace(err)
		return v, true
	}
	if !np.IsCorrelated() {
		er.p = er.b.buildSemiJoin(er.p, np, expression.SplitCNFItems(checkCondition), asScalar, v.Not)
		if asScalar {
			col := er.p.GetSchema()[len(er.p.GetSchema())-1]
			er.ctxStack[len(er.ctxStack)-1] = col
		} else {
			er.ctxStack = er.ctxStack[:len(er.ctxStack)-1]
		}
		return v, true
	}
	if v.Not {
		// Propagate the error instead of silently handing a nil condition to
		// the Apply checker below.
		checkCondition, err = expression.NewFunction(ast.UnaryNot, &v.Type, checkCondition)
		if err != nil {
			er.err = errors.Trace(err)
			return v, true
		}
	}
	er.p = er.b.buildApply(er.p, np, &ApplyConditionChecker{Condition: checkCondition, All: v.Not})
	// The parent expression only use the last column in schema, which represents whether the condition is matched.
	er.ctxStack[len(er.ctxStack)-1] = er.p.GetSchema()[len(er.p.GetSchema())-1]
	return v, true
}
Ejemplo n.º 29
0
// handleCompareSubquery rewrites `expr op {ANY|ALL} (subquery)`.
//
// The left operand is rewritten first and read from the top of er.ctxStack.
// The subquery's output columns form the right operand: a single column, or a
// row function over all columns. Multi-column operands are accepted only for
// `= ANY` and `!= ALL`, and the column counts of both sides must match.
//
// For EQ, NE and NullEQ the condition is built with
// constructBinaryOpFunction. Any other operator is kept as a direct row
// comparison. er.p becomes an Apply using that condition (All taken from
// v.All). The top of ctxStack is replaced by the Apply's last output column,
// and er.correlated is set when the resulting plan is correlated.
//
// Errors are recorded in er.err. The returned node is always v and the bool is
// always true.
func (er *expressionRewriter) handleCompareSubquery(v *ast.CompareSubqueryExpr) (ast.Node, bool) {
	v.L.Accept(er)
	if er.err != nil {
		return v, true
	}
	lexpr := er.ctxStack[len(er.ctxStack)-1]
	subq, ok := v.R.(*ast.SubqueryExpr)
	if !ok {
		er.err = errors.Errorf("Unknown compare type %T.", v.R)
		return v, true
	}
	np, outerSchema := er.buildSubquery(subq)
	if er.err != nil {
		return v, true
	}
	lLen := getRowLen(lexpr)
	// Only (a,b,c) = any (...) and (a,b,c) != all (...) can use a row expression.
	canMultiCol := (!v.All && v.Op == opcode.EQ) || (v.All && v.Op == opcode.NE)
	if !canMultiCol && (lLen != 1 || len(np.GetSchema()) != 1) {
		er.err = errors.New("Operand should contain 1 column(s)")
		return v, true
	}
	if lLen != len(np.GetSchema()) {
		er.err = errors.Errorf("Operand should contain %d column(s)", lLen)
		return v, true
	}
	var rexpr expression.Expression
	if len(np.GetSchema()) == 1 {
		rexpr = np.GetSchema()[0].DeepCopy()
	} else {
		args := make([]expression.Expression, 0, len(np.GetSchema()))
		for _, col := range np.GetSchema() {
			args = append(args, col.DeepCopy())
		}
		rexpr, er.err = expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), args...)
		if er.err != nil {
			er.err = errors.Trace(er.err)
			return v, true
		}
	}
	var checkCondition expression.Expression
	switch v.Op {
	// Only EQ, NE and NullEQ can be composed with and.
	case opcode.EQ, opcode.NE, opcode.NullEQ:
		checkCondition, er.err = constructBinaryOpFunction(lexpr, rexpr, opcode.Ops[v.Op])
		if er.err != nil {
			er.err = errors.Trace(er.err)
			return v, true
		}
	// If op is not EQ, NE, NullEQ, say LT, it will remain as row(a,b) < row(c,d), and be compared as row datum.
	default:
		checkCondition, er.err = expression.NewFunction(opcode.Ops[v.Op],
			types.NewFieldType(mysql.TypeTiny), lexpr, rexpr)
		if er.err != nil {
			er.err = errors.Trace(er.err)
			return v, true
		}
	}
	er.p = er.b.buildApply(er.p, np, outerSchema, &ApplyConditionChecker{Condition: checkCondition, All: v.All})
	if er.p.IsCorrelated() {
		er.correlated = true
	}
	// The parent expression only use the last column in schema, which represents whether the condition is matched.
	er.ctxStack[len(er.ctxStack)-1] = er.p.GetSchema()[len(er.p.GetSchema())-1]
	return v, true
}
Ejemplo n.º 30
0
// Leave implements Visitor interface.
//
// Expressions already rewritten for inNode's operands are taken from the top
// of er.ctxStack. Each handled case pops the operands it consumes and pushes
// the expression built for inNode. Some node kinds (aggregates, column-name,
// parentheses, WHEN, and subquery wrappers) push nothing here; presumably their
// result is produced elsewhere (TODO confirm).
//
// When er.err is set, either on entry or by a case, it returns (retNode, false)
// with retNode left at its zero value. The VariableExpr case returns
// rewriteVariable's result directly. Otherwise it returns (inNode, true).
func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool) {
	if er.err != nil {
		return retNode, false
	}

	stkLen := len(er.ctxStack)
	switch v := inNode.(type) {
	case *ast.AggregateFuncExpr:
	case *ast.RowExpr:
		length := len(v.Values)
		// Copy the operands out before truncating: the append below reuses the
		// stack's backing array.
		rows := append(make([]expression.Expression, 0, length), er.ctxStack[stkLen-length:]...)
		er.ctxStack = er.ctxStack[:stkLen-length]
		er.ctxStack = append(er.ctxStack, expression.NewFunction(ast.RowFunc, nil, rows...))
	case *ast.VariableExpr:
		return inNode, er.rewriteVariable(v)
	case *ast.FuncCallExpr:
		er.funcCallToScalarFunc(v)
	case *ast.PositionExpr:
		if v.N > 0 && v.N <= len(er.schema) {
			er.ctxStack = append(er.ctxStack, er.schema[v.N-1])
		} else {
			er.err = errors.Errorf("Position %d is out of range", v.N)
		}
	case *ast.ColumnName:
		er.toColumn(v)
	case *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause, *ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr:
	case *ast.ValueExpr:
		value := &expression.Constant{Value: v.Datum, RetType: v.Type}
		er.ctxStack = append(er.ctxStack, value)
	case *ast.ParamMarkerExpr:
		value := &expression.Constant{Value: v.Datum, RetType: v.Type}
		er.ctxStack = append(er.ctxStack, value)
	case *ast.IsNullExpr:
		if getRowLen(er.ctxStack[stkLen-1]) != 1 {
			er.err = errors.New("Operand should contain 1 column(s)")
			return retNode, false
		}
		function := er.notToScalarFunc(v.Not, ast.IsNull, v.Type, er.ctxStack[stkLen-1])
		er.ctxStack = er.ctxStack[:stkLen-1]
		er.ctxStack = append(er.ctxStack, function)
	case *ast.IsTruthExpr:
		// Like IS NULL, IS [NOT] TRUE/FALSE only accepts a single-column operand.
		if getRowLen(er.ctxStack[stkLen-1]) != 1 {
			er.err = errors.New("Operand should contain 1 column(s)")
			return retNode, false
		}
		op := ast.IsTruth
		if v.True == 0 {
			op = ast.IsFalsity
		}
		function := er.notToScalarFunc(v.Not, op, v.Type, er.ctxStack[stkLen-1])
		er.ctxStack = er.ctxStack[:stkLen-1]
		er.ctxStack = append(er.ctxStack, function)
	case *ast.BinaryOperationExpr:
		var function expression.Expression
		switch v.Op {
		case opcode.EQ, opcode.NE, opcode.NullEQ:
			var err error
			function, err = constructBinaryOpFunction(er.ctxStack[stkLen-2], er.ctxStack[stkLen-1],
				opcode.Ops[v.Op])
			if err != nil {
				er.err = errors.Trace(err)
				return retNode, false
			}
		default:
			function = expression.NewFunction(opcode.Ops[v.Op], v.Type, er.ctxStack[stkLen-2:]...)
		}
		er.ctxStack = er.ctxStack[:stkLen-2]
		er.ctxStack = append(er.ctxStack, function)
	case *ast.BetweenExpr:
		er.betweenToScalarFunc(v)
	case *ast.PatternLikeExpr:
		er.likeToScalarFunc(v)
	case *ast.PatternInExpr:
		if v.Sel == nil {
			er.inToScalarFunc(v)
		}
	case *ast.UnaryOperationExpr:
		if getRowLen(er.ctxStack[stkLen-1]) != 1 {
			er.err = errors.New("Operand should contain 1 column(s)")
			return retNode, false
		}
		function := expression.NewFunction(opcode.Ops[v.Op], v.Type, er.ctxStack[stkLen-1])
		er.ctxStack = er.ctxStack[:stkLen-1]
		er.ctxStack = append(er.ctxStack, function)
	default:
		er.err = errors.Errorf("UnknownType: %T", v)
		return retNode, false
	}

	if er.err != nil {
		return retNode, false
	}
	return inNode, true
}