예제 #1
0
파일: new_builder.go 프로젝트: c4pt0r/tidb
// buildSemiJoin creates a HashSemiJoinExec for the given physical hash semi
// join plan. Child 0 of the plan is built as the big side and child 1 as the
// small side. Every equal condition contributes one hash key column per side
// and one field type merged from the two column types.
func (b *executorBuilder) buildSemiJoin(v *plan.PhysicalHashSemiJoin) Executor {
	var (
		bigKeys, smallKeys []*expression.Column
		keyTypes           []*types.FieldType
	)
	for _, cond := range v.EqualConditions {
		// The assertion results are deliberately not inspected here; both
		// arguments are dereferenced as columns right below.
		lhs, _ := cond.Args[0].(*expression.Column)
		rhs, _ := cond.Args[1].(*expression.Column)
		bigKeys = append(bigKeys, lhs)
		smallKeys = append(smallKeys, rhs)
		merged := types.MergeFieldType(lhs.GetType().Tp, rhs.GetType().Tp)
		keyTypes = append(keyTypes, types.NewFieldType(merged))
	}
	return &HashSemiJoinExec{
		schema:       v.GetSchema(),
		otherFilter:  expression.ComposeCNFCondition(v.OtherConditions),
		bigFilter:    expression.ComposeCNFCondition(v.LeftConditions),
		smallFilter:  expression.ComposeCNFCondition(v.RightConditions),
		bigExec:      b.build(v.GetChildByIndex(0)),
		smallExec:    b.build(v.GetChildByIndex(1)),
		prepared:     false,
		ctx:          b.ctx,
		bigHashKey:   bigKeys,
		smallHashKey: smallKeys,
		withAux:      v.WithAux,
		anti:         v.Anti,
		targetTypes:  keyTypes,
	}
}
예제 #2
0
파일: new_builder.go 프로젝트: c4pt0r/tidb
// buildNewTableScan builds the executor for a physical table scan.
//
// v is the table scan plan. s is an optional parent selection (nil when there
// is none); when present, its conditions are handed to toPBExpr and
// s.Conditions is replaced by the second value toPBExpr returns.
//
// When the table is not in information_schema/performance_schema and the kv
// client supports select requests, a NewXSelectTableExec is built; if the
// transaction is not read-only it is wrapped through buildNewUnionScanExec.
// Otherwise a NewTableScanExec is returned, wrapped in a ReverseExec for
// descending scans. On failure to obtain the transaction, b.err is set and
// nil is returned.
func (b *executorBuilder) buildNewTableScan(v *plan.PhysicalTableScan, s *plan.Selection) Executor {
	txn, err := b.ctx.GetTxn(false)
	if err != nil {
		b.err = errors.Trace(err)
		return nil
	}
	// NOTE(review): the second result of TableByID is ignored, so table may
	// be a zero value when the ID is unknown — confirm that cannot happen here.
	table, _ := b.is.TableByID(v.Table.ID)
	client := txn.GetClient()
	var memDB bool
	switch v.DBName.L {
	case "information_schema", "performance_schema":
		memDB = true
	}
	supportDesc := client.SupportRequestType(kv.ReqTypeSelect, kv.ReqSubTypeDesc)
	if !memDB && client.SupportRequestType(kv.ReqTypeSelect, 0) {
		var ret Executor
		st := &NewXSelectTableExec{
			tableInfo:   v.Table,
			ctx:         b.ctx,
			txn:         txn,
			supportDesc: supportDesc,
			asName:      v.TableAsName,
			table:       table,
			schema:      v.GetSchema(),
			Columns:     v.Columns,
			ranges:      v.Ranges,
			desc:        v.Desc,
			limitCount:  v.LimitCount,
			keepOrder:   v.KeepOrder,
		}
		ret = st
		if !txn.IsReadOnly() {
			if s != nil {
				// Collect the conditions in a fresh slice. Appending
				// directly to s.Conditions could write into spare capacity
				// of its backing array and corrupt any slice sharing it.
				var unionConds []expression.Expression
				unionConds = append(unionConds, s.Conditions...)
				unionConds = append(unionConds, v.AccessCondition...)
				ret = b.buildNewUnionScanExec(ret, expression.ComposeCNFCondition(unionConds))
			} else {
				ret = b.buildNewUnionScanExec(ret, expression.ComposeCNFCondition(v.AccessCondition))
			}
		}
		// This must run after the union scan condition has been composed,
		// because it replaces s.Conditions.
		if s != nil {
			st.where, s.Conditions = b.toPBExpr(s.Conditions, st.tableInfo)
		}
		return ret
	}

	ts := &NewTableScanExec{
		t:          table,
		asName:     v.TableAsName,
		ctx:        b.ctx,
		columns:    v.Columns,
		schema:     v.GetSchema(),
		seekHandle: math.MinInt64,
		ranges:     v.Ranges,
	}
	if v.Desc {
		return &ReverseExec{Src: ts}
	}
	return ts
}
예제 #3
0
파일: builder.go 프로젝트: pingcap/tidb
// buildJoin creates a HashJoinExec for the given physical hash join plan.
//
// v.SmallTable == 1 makes the right child (index 1) the small side; any other
// value makes the left child (index 0) the small side. The small child is
// always built before the big child. One hashJoinCtx, carrying its own clones
// of the big and other filters plus scratch buffers, is created for each unit
// of v.Concurrency.
func (b *executorBuilder) buildJoin(v *plan.PhysicalHashJoin) Executor {
	var (
		leftKeys, rightKeys []*expression.Column
		keyTypes            []*types.FieldType
	)
	for _, cond := range v.EqualConditions {
		lhs, _ := cond.Args[0].(*expression.Column)
		rhs, _ := cond.Args[1].(*expression.Column)
		leftKeys = append(leftKeys, lhs)
		rightKeys = append(rightKeys, rhs)
		merged := types.MergeFieldType(lhs.GetType().Tp, rhs.GetType().Tp)
		keyTypes = append(keyTypes, types.NewFieldType(merged))
	}

	join := &HashJoinExec{
		schema:        v.GetSchema(),
		otherFilter:   expression.ComposeCNFCondition(v.OtherConditions),
		prepared:      false,
		ctx:           b.ctx,
		targetTypes:   keyTypes,
		concurrency:   v.Concurrency,
		defaultValues: v.DefaultValues,
	}

	// Decide which side is the small one and wire filters and hash keys.
	join.leftSmall = v.SmallTable != 1
	if join.leftSmall {
		join.smallFilter = expression.ComposeCNFCondition(v.LeftConditions)
		join.bigFilter = expression.ComposeCNFCondition(v.RightConditions)
		join.smallHashKey, join.bigHashKey = leftKeys, rightKeys
	} else {
		join.smallFilter = expression.ComposeCNFCondition(v.RightConditions)
		join.bigFilter = expression.ComposeCNFCondition(v.LeftConditions)
		join.smallHashKey, join.bigHashKey = rightKeys, leftKeys
	}

	switch v.JoinType {
	case plan.LeftOuterJoin, plan.RightOuterJoin:
		join.outer = true
	}

	if join.leftSmall {
		join.smallExec = b.build(v.GetChildByIndex(0))
		join.bigExec = b.build(v.GetChildByIndex(1))
	} else {
		join.smallExec = b.build(v.GetChildByIndex(1))
		join.bigExec = b.build(v.GetChildByIndex(0))
	}

	// Prepare one context per unit of concurrency.
	for n := 0; n < join.concurrency; n++ {
		workerCtx := &hashJoinCtx{}
		if join.bigFilter != nil {
			workerCtx.bigFilter = join.bigFilter.Clone()
		}
		if join.otherFilter != nil {
			workerCtx.otherFilter = join.otherFilter.Clone()
		}
		workerCtx.datumBuffer = make([]types.Datum, len(join.bigHashKey))
		workerCtx.hashKeyBuffer = make([]byte, 0, 10000)
		join.hashJoinContexts = append(join.hashJoinContexts, workerCtx)
	}
	return join
}
예제 #4
0
파일: new_builder.go 프로젝트: c4pt0r/tidb
// buildSelection builds the executor for a selection plan.
//
// When the child is a table or index scan without a limit, the selection is
// passed to the scan builder, which may replace v.Conditions with whatever it
// leaves for this selection to evaluate. If nothing is left, the scan
// executor is returned directly; otherwise a SelectionExec over the remaining
// conditions is returned. In both cases v.Conditions is set back to the value
// it had on entry before returning.
func (b *executorBuilder) buildSelection(v *plan.Selection) Executor {
	child := v.GetChildByIndex(0)
	oldConditions := v.Conditions
	var src Executor
	switch x := child.(type) {
	case *plan.PhysicalTableScan:
		// A scan carrying a limit is built without the selection.
		if x.LimitCount == nil {
			src = b.buildNewTableScan(x, v)
		} else {
			src = b.buildNewTableScan(x, nil)
		}
	case *plan.PhysicalIndexScan:
		if x.LimitCount == nil {
			src = b.buildNewIndexScan(x, v)
		} else {
			src = b.buildNewIndexScan(x, nil)
		}
	default:
		src = b.build(x)
	}

	if len(v.Conditions) == 0 {
		v.Conditions = oldConditions
		return src
	}

	exec := &SelectionExec{
		Src:       src,
		Condition: expression.ComposeCNFCondition(v.Conditions),
		schema:    v.GetSchema(),
		ctx:       b.ctx,
	}
	// Restore by assignment, matching the early-return path above. The
	// previous copy(v.Conditions, oldConditions) kept the shortened length and
	// overwrote the remaining conditions with a prefix of the original list.
	v.Conditions = oldConditions
	return exec
}
예제 #5
0
파일: builder.go 프로젝트: jmptrader/tidb
// buildSelection builds a SelectionExec that evaluates the conjunction of
// v.Conditions over the executor built from the plan's first child.
func (b *executorBuilder) buildSelection(v *plan.Selection) Executor {
	src := b.build(v.GetChildByIndex(0))
	filter := expression.ComposeCNFCondition(v.Conditions)
	return &SelectionExec{
		Src:       src,
		Condition: filter,
		schema:    v.GetSchema(),
		ctx:       b.ctx,
	}
}
예제 #6
0
파일: new_builder.go 프로젝트: c4pt0r/tidb
// buildNewIndexScan builds the executor for a physical index scan.
//
// v is the index scan plan. s is an optional parent selection (nil when there
// is none); when present, its conditions are handed to toPBExpr and
// s.Conditions is replaced by the second value toPBExpr returns.
//
// A NewXSelectIndexExec is built only when the table is not in
// information_schema/performance_schema and the kv client supports index
// requests; if the transaction is not read-only it is wrapped through
// buildNewUnionScanExec. In every other case b.err is set and nil is
// returned.
func (b *executorBuilder) buildNewIndexScan(v *plan.PhysicalIndexScan, s *plan.Selection) Executor {
	txn, err := b.ctx.GetTxn(false)
	if err != nil {
		b.err = errors.Trace(err)
		return nil
	}
	// NOTE(review): the second result of TableByID is ignored — confirm the
	// table is always known at this point.
	table, _ := b.is.TableByID(v.Table.ID)
	client := txn.GetClient()
	var memDB bool
	switch v.DBName.L {
	case "information_schema", "performance_schema":
		memDB = true
	}
	supportDesc := client.SupportRequestType(kv.ReqTypeIndex, kv.ReqSubTypeDesc)
	if !memDB && client.SupportRequestType(kv.ReqTypeIndex, 0) {
		var ret Executor
		st := &NewXSelectIndexExec{
			tableInfo:   v.Table,
			ctx:         b.ctx,
			supportDesc: supportDesc,
			asName:      v.TableAsName,
			table:       table,
			indexPlan:   v,
			txn:         txn,
		}
		ret = st
		if !txn.IsReadOnly() {
			if s != nil {
				// Collect the conditions in a fresh slice. Appending
				// directly to s.Conditions could write into spare capacity
				// of its backing array and corrupt any slice sharing it.
				var unionConds []expression.Expression
				unionConds = append(unionConds, s.Conditions...)
				unionConds = append(unionConds, v.AccessCondition...)
				ret = b.buildNewUnionScanExec(ret, expression.ComposeCNFCondition(unionConds))
			} else {
				ret = b.buildNewUnionScanExec(ret, expression.ComposeCNFCondition(v.AccessCondition))
			}
		}
		// It will forbid limit and aggregation to push down.
		if s != nil {
			st.where, s.Conditions = b.toPBExpr(s.Conditions, st.tableInfo)
		}
		return ret
	}

	b.err = errors.New("Not implement yet.")
	return nil
}
예제 #7
0
// buildJoin creates a HashJoinExec for the given physical hash join plan.
// v.SmallTable == 1 makes the right child (index 1) the small side; any other
// value makes the left child (index 0) the small side. The small child is
// built before the big child.
//
// TODO: select join algorithm during cbo phase.
func (b *executorBuilder) buildJoin(v *plan.PhysicalHashJoin) Executor {
	var (
		leftKeys, rightKeys []*expression.Column
		keyTypes            []*types.FieldType
	)
	for _, cond := range v.EqualConditions {
		lhs, _ := cond.Args[0].(*expression.Column)
		rhs, _ := cond.Args[1].(*expression.Column)
		leftKeys = append(leftKeys, lhs)
		rightKeys = append(rightKeys, rhs)
		merged := types.MergeFieldType(lhs.GetType().Tp, rhs.GetType().Tp)
		keyTypes = append(keyTypes, types.NewFieldType(merged))
	}

	join := &HashJoinExec{
		schema:      v.GetSchema(),
		otherFilter: expression.ComposeCNFCondition(v.OtherConditions),
		prepared:    false,
		ctx:         b.ctx,
		targetTypes: keyTypes,
	}

	// Decide which side is the small one and wire filters and hash keys.
	join.leftSmall = v.SmallTable != 1
	if join.leftSmall {
		join.smallFilter = expression.ComposeCNFCondition(v.LeftConditions)
		join.bigFilter = expression.ComposeCNFCondition(v.RightConditions)
		join.smallHashKey, join.bigHashKey = leftKeys, rightKeys
	} else {
		join.smallFilter = expression.ComposeCNFCondition(v.RightConditions)
		join.bigFilter = expression.ComposeCNFCondition(v.LeftConditions)
		join.smallHashKey, join.bigHashKey = rightKeys, leftKeys
	}

	switch v.JoinType {
	case plan.LeftOuterJoin, plan.RightOuterJoin:
		join.outer = true
	}

	if join.leftSmall {
		join.smallExec = b.build(v.GetChildByIndex(0))
		join.bigExec = b.build(v.GetChildByIndex(1))
	} else {
		join.smallExec = b.build(v.GetChildByIndex(1))
		join.bigExec = b.build(v.GetChildByIndex(0))
	}
	return join
}
예제 #8
0
// tryToAddUnionScan wraps resultPlan in a PhysicalUnionScan unless the table
// source is read-only, in which case resultPlan is returned unchanged. The
// union scan's condition is the conjunction of the index filter conditions,
// the table filter conditions and the access conditions, and it reuses
// resultPlan's schema.
func (p *physicalTableSource) tryToAddUnionScan(resultPlan PhysicalPlan) PhysicalPlan {
	if p.readOnly {
		return resultPlan
	}
	// Gather the conditions into a fresh slice. Appending directly onto
	// p.indexFilterConditions could write into spare capacity of its backing
	// array and corrupt any slice sharing it.
	var conditions []expression.Expression
	conditions = append(conditions, p.indexFilterConditions...)
	conditions = append(conditions, p.tableFilterConditions...)
	conditions = append(conditions, p.AccessCondition...)
	us := &PhysicalUnionScan{
		Condition: expression.ComposeCNFCondition(conditions),
	}
	us.SetChildren(resultPlan)
	us.SetSchema(resultPlan.GetSchema())
	return us
}
예제 #9
0
파일: new_builder.go 프로젝트: xxwwbb3/tidb
// buildJoin creates a HashJoinExec for the given join plan. For left outer
// and inner joins the right child (index 1) is the small side; for right
// outer joins the left child (index 0) is. Any other join type sets b.err and
// yields nil. The small child is built before the big child.
//
// TODO: select join algorithm during cbo phase.
func (b *executorBuilder) buildJoin(v *plan.Join) Executor {
	join := &HashJoinExec{
		schema:      v.GetSchema(),
		otherFilter: expression.ComposeCNFCondition(v.OtherConditions),
		prepared:    false,
		ctx:         b.ctx,
	}
	var leftKeys, rightKeys []*expression.Column
	for _, cond := range v.EqualConditions {
		lhs, _ := cond.Args[0].(*expression.Column)
		rhs, _ := cond.Args[1].(*expression.Column)
		leftKeys = append(leftKeys, lhs)
		rightKeys = append(rightKeys, rhs)
	}
	switch v.JoinType {
	case plan.LeftOuterJoin, plan.InnerJoin:
		// TODO: for inner joins, assume right table is the small one before
		// cbo is realized.
		join.outter = v.JoinType == plan.LeftOuterJoin
		join.leftSmall = false
		join.smallFilter = expression.ComposeCNFCondition(v.RightConditions)
		join.bigFilter = expression.ComposeCNFCondition(v.LeftConditions)
		join.smallHashKey, join.bigHashKey = rightKeys, leftKeys
	case plan.RightOuterJoin:
		join.outter = true
		join.leftSmall = true
		join.smallFilter = expression.ComposeCNFCondition(v.LeftConditions)
		join.bigFilter = expression.ComposeCNFCondition(v.RightConditions)
		join.smallHashKey, join.bigHashKey = leftKeys, rightKeys
	default:
		b.err = ErrUnknownPlan.Gen("Unknown Join Type !!")
		return nil
	}
	if join.leftSmall {
		join.smallExec = b.build(v.GetChildByIndex(0))
		join.bigExec = b.build(v.GetChildByIndex(1))
	} else {
		join.smallExec = b.build(v.GetChildByIndex(1))
		join.bigExec = b.build(v.GetChildByIndex(0))
	}
	return join
}
예제 #10
0
// constructBinaryOpFunction converts (a0,a1,a2) op (b0,b1,b2) to
// (a0 op b0) and (a1 op b1) and (a2 op b2).
//
// When both operands have row length 1 the result is a single op function.
// An error is returned when the row lengths differ, or when building any of
// the element-wise functions fails.
func constructBinaryOpFunction(l expression.Expression, r expression.Expression, op string) (expression.Expression, error) {
	lLen, rLen := getRowLen(l), getRowLen(r)
	if lLen == 1 && rLen == 1 {
		return expression.NewFunction(op, types.NewFieldType(mysql.TypeTiny), l, r)
	}
	if lLen != rLen {
		return nil, errors.Errorf("Operand should contain %d column(s)", lLen)
	}
	// Build one comparison per row position, recursing for nested rows.
	pairs := make([]expression.Expression, lLen)
	for i := range pairs {
		fn, err := constructBinaryOpFunction(getRowArg(l, i), getRowArg(r, i), op)
		if err != nil {
			return nil, err
		}
		pairs[i] = fn
	}
	return expression.ComposeCNFCondition(pairs), nil
}
예제 #11
0
// constructBinaryOpFunction converts (a0,a1,a2) op (b0,b1,b2) to
// (a0 op b0) and (a1 op b1) and (a2 op b2).
//
// When both operands have row length 1 the result is a single op function.
// ErrOperandColumns is generated when the row lengths differ; errors from the
// element-wise construction are returned wrapped with errors.Trace.
func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, r expression.Expression, op string) (expression.Expression, error) {
	lLen, rLen := getRowLen(l), getRowLen(r)
	if lLen == 1 && rLen == 1 {
		return expression.NewFunction(op, types.NewFieldType(mysql.TypeTiny), l, r)
	}
	if lLen != rLen {
		return nil, ErrOperandColumns.GenByArgs(lLen)
	}
	// Build one comparison per row position, recursing for nested rows.
	pairs := make([]expression.Expression, lLen)
	for i := range pairs {
		fn, err := er.constructBinaryOpFunction(getRowArg(l, i), getRowArg(r, i), op)
		if err != nil {
			return nil, errors.Trace(err)
		}
		pairs[i] = fn
	}
	return expression.ComposeCNFCondition(pairs), nil
}
예제 #12
0
파일: new_builder.go 프로젝트: xxwwbb3/tidb
// buildNewTableScan builds the executor for a NewTableScan plan.
//
// v is the table scan plan. s is an optional parent selection (nil when there
// is none); when present, its conditions are handed to toPBExpr and
// s.Conditions is replaced by the second value toPBExpr returns.
//
// A NewTableScanExec is built only when the table is not in
// information_schema/performance_schema and the kv client supports select
// requests; if the transaction is not read-only it is wrapped through
// buildNewUnionScanExec. In every other case b.err is set and nil is
// returned.
func (b *executorBuilder) buildNewTableScan(v *plan.NewTableScan, s *plan.Selection) Executor {
	txn, err := b.ctx.GetTxn(false)
	if err != nil {
		b.err = err
		return nil
	}
	// NOTE(review): the second result of TableByID is ignored — confirm the
	// table is always known at this point.
	table, _ := b.is.TableByID(v.Table.ID)
	client := txn.GetClient()
	var memDB bool
	switch v.DBName.L {
	case "information_schema", "performance_schema":
		memDB = true
	}
	supportDesc := client.SupportRequestType(kv.ReqTypeSelect, kv.ReqSubTypeDesc)
	if !memDB && client.SupportRequestType(kv.ReqTypeSelect, 0) {
		var ret Executor
		ts := &NewTableScanExec{
			tableInfo:   v.Table,
			ctx:         b.ctx,
			supportDesc: supportDesc,
			asName:      v.TableAsName,
			table:       table,
			schema:      v.GetSchema(),
			Columns:     v.Columns,
			ranges:      v.Ranges,
		}
		ret = ts
		if !txn.IsReadOnly() {
			if s != nil {
				// Collect the conditions in a fresh slice. Appending
				// directly to s.Conditions could write into spare capacity
				// of its backing array and corrupt any slice sharing it.
				var unionConds []expression.Expression
				unionConds = append(unionConds, s.Conditions...)
				unionConds = append(unionConds, v.AccessCondition...)
				ret = b.buildNewUnionScanExec(ret, expression.ComposeCNFCondition(unionConds))
			} else {
				ret = b.buildNewUnionScanExec(ret, nil)
			}
		}
		// This must run after the union scan condition has been composed,
		// because it replaces s.Conditions.
		if s != nil {
			ts.where, s.Conditions = b.toPBExpr(s.Conditions, ts.tableInfo)
		}
		return ret
	}
	b.err = errors.New("Not implement yet.")
	return nil
}
예제 #13
0
파일: new_builder.go 프로젝트: xxwwbb3/tidb
// buildSelection builds the executor for a selection plan. A NewTableScan
// child is built together with the selection through buildNewTableScan, which
// may replace v.Conditions; any other child goes through the generic builder.
// When no conditions remain on the selection afterwards, the child executor
// is returned as is; otherwise it is wrapped in a SelectionExec.
func (b *executorBuilder) buildSelection(v *plan.Selection) Executor {
	var src Executor
	if scan, isScan := v.GetChildByIndex(0).(*plan.NewTableScan); isScan {
		src = b.buildNewTableScan(scan, v)
	} else {
		src = b.build(v.GetChildByIndex(0))
	}

	if len(v.Conditions) != 0 {
		return &SelectionExec{
			Src:       src,
			Condition: expression.ComposeCNFCondition(v.Conditions),
			schema:    v.GetSchema(),
			ctx:       b.ctx,
		}
	}
	return src
}
예제 #14
0
파일: new_builder.go 프로젝트: yubobo/tidb
// buildSelection builds the executor for a selection plan. Table scan and
// index scan children are built together with the selection, which may
// replace v.Conditions; any other child goes through the generic builder.
// When no conditions remain on the selection afterwards, the child executor
// is returned as is; otherwise it is wrapped in a SelectionExec.
func (b *executorBuilder) buildSelection(v *plan.Selection) Executor {
	var src Executor
	child := v.GetChildByIndex(0)
	if tableScan, ok := child.(*plan.PhysicalTableScan); ok {
		src = b.buildNewTableScan(tableScan, v)
	} else if indexScan, ok := child.(*plan.PhysicalIndexScan); ok {
		src = b.buildNewIndexScan(indexScan, v)
	} else {
		src = b.build(child)
	}

	if len(v.Conditions) != 0 {
		return &SelectionExec{
			Src:       src,
			Condition: expression.ComposeCNFCondition(v.Conditions),
			schema:    v.GetSchema(),
			ctx:       b.ctx,
		}
	}
	return src
}
예제 #15
0
// TestPushDownExpression checks which expressions end up as filter conditions
// on the physical table source. Each case is planned as
// "select * from t where <sql>", and the conjunction of the source's index and
// table filter conditions is compared, in string form, with cond.
func (s *testPlanSuite) TestPushDownExpression(c *C) {
	defer testleak.AfterTest(c)()
	tests := []struct {
		sql  string
		cond string // readable expressions.
	}{
		{
			sql:  "a and b",
			cond: "test.t.b",
		},
		{
			sql:  "a or (b and c)",
			cond: "or(test.t.a, and(test.t.b, test.t.c))",
		},
		{
			sql:  "c=1 and d =1 and e =1 and b=1",
			cond: "eq(test.t.b, 1)",
		},
		{
			sql:  "a or b",
			cond: "or(test.t.a, test.t.b)",
		},
		{
			sql:  "a and (b or c)",
			cond: "or(test.t.b, test.t.c)",
		},
		{
			sql:  "not a",
			cond: "not(test.t.a)",
		},
		{
			sql:  "a xor b",
			cond: "xor(test.t.a, test.t.b)",
		},
		{
			sql:  "a & b",
			cond: "bitand(test.t.a, test.t.b)",
		},
		{
			sql:  "a | b",
			cond: "bitor(test.t.a, test.t.b)",
		},
		{
			sql:  "a ^ b",
			cond: "bitxor(test.t.a, test.t.b)",
		},
		{
			sql:  "~a",
			cond: "bitneg(test.t.a)",
		},
		{
			sql:  "a = case a when b then 1 when a then 0 end",
			cond: "eq(test.t.a, case(eq(test.t.a, test.t.b), 1, eq(test.t.a, test.t.a), 0))",
		},
		// if
		{
			sql:  "a = if(a, 1, 0)",
			cond: "eq(test.t.a, if(test.t.a, 1, 0))",
		},
		// nullif
		{
			sql:  "a = nullif(a, 1)",
			cond: "eq(test.t.a, nullif(test.t.a, 1))",
		},
		// ifnull
		{
			sql:  "a = ifnull(null, a)",
			cond: "eq(test.t.a, ifnull(<nil>, test.t.a))",
		},
		// coalesce
		{
			sql:  "a = coalesce(null, null, a, b)",
			cond: "eq(test.t.a, coalesce(<nil>, <nil>, test.t.a, test.t.b))",
		},
		// isnull
		{
			sql:  "b is null",
			cond: "isnull(test.t.b)",
		},
	}
	for _, tt := range tests {
		query := "select * from t where " + tt.sql
		comment := Commentf("for %s", query)
		stmt, err := s.ParseOneStmt(query, "", "")
		c.Assert(err, IsNil, comment)
		ast.SetFlag(stmt)

		err = mockResolve(stmt)
		c.Assert(err, IsNil)
		builder := &planBuilder{
			allocator: new(idAllocator),
			ctx:       mockContext(),
			colMapper: make(map[*ast.ColumnNameExpr]int),
		}
		p := builder.build(stmt)
		c.Assert(builder.err, IsNil)
		logical := p.(LogicalPlan)

		_, logical, err = logical.PredicatePushDown(nil)
		c.Assert(err, IsNil)
		logical.PruneColumns(logical.GetSchema())
		logical.ResolveIndicesAndCorCols()
		info, err := logical.convert2PhysicalPlan(&requiredProperty{})
		c.Assert(err, IsNil)
		p = info.p

		// Walk down the first-child chain until a table or index scan shows up.
		var source *physicalTableSource
		for source == nil {
			switch x := p.(type) {
			case *PhysicalTableScan:
				source = &x.physicalTableSource
			case *PhysicalIndexScan:
				source = &x.physicalTableSource
			default:
				p = p.GetChildByIndex(0)
			}
		}
		filters := append(source.indexFilterConditions, source.tableFilterConditions...)
		got := fmt.Sprintf("%s", expression.ComposeCNFCondition(filters).String())
		c.Assert(got, Equals, tt.cond, Commentf("for %s", query))
	}
}
예제 #16
0
// propagateConstant propagates constant values of equality predicates and
// inequality predicates in a condition.
//
// conditions is treated as a conjunction. The function works in two phases:
// first it substitutes "column = constant" facts into the other conditions,
// then it copies "column <op> constant" inequalities to every column known to
// be equal to that column. The input slice is modified in place and the
// (possibly longer or reordered) result is returned.
func propagateConstant(conditions []expression.Expression) []expression.Expression {
	if len(conditions) == 0 {
		return conditions
	}
	// Propagate constants in equality predicates.
	// e.g. for condition: "a = b and b = c and c = a and a = 1";
	// we propagate constant as the following step:
	// first: "1 = b and b = c and c = 1 and a = 1";
	// next:  "1 = b and 1 = c and c = 1 and a = 1";
	// next:  "1 = b and 1 = c and 1 = 1 and a = 1";
	// next:  "1 = b and 1 = c and a = 1";

	// e.g for condition: "a = b and b = c and b = 2 and a = 1";
	// we propagate constant as the following step:
	// first: "a = 2 and 2 = c and b = 2 and a = 1";
	// next:  "a = 2 and 2 = c and b = 2 and 2 = 1";
	// next:  "0"

	// isSource[i] marks conditions that have already been consumed, either as
	// the source of a substitution or by recursive processing.
	isSource := make([]bool, len(conditions))
	type transitiveEqualityPredicate map[string]*expression.Constant // transitive equality predicates between one column and one constant
	for {
		equalities := make(transitiveEqualityPredicate)
		// Pick at most one new "column = constant" equality per round.
		for i, getOneEquality := 0, false; i < len(conditions) && !getOneEquality; i++ {
			if isSource[i] {
				continue
			}
			expr, ok := conditions[i].(*expression.ScalarFunction)
			if !ok {
				continue
			}
			// process the included OR conditions recursively to do the same for CNF item.
			switch expr.FuncName.L {
			case ast.OrOr:
				expressions := expression.SplitDNFItems(conditions[i])
				newExpression := make([]expression.Expression, 0)
				for _, v := range expressions {
					newExpression = append(newExpression, propagateConstant([]expression.Expression{v})...)
				}
				conditions[i] = expression.ComposeDNFCondition(newExpression)
				isSource[i] = true
			case ast.AndAnd:
				newExpression := propagateConstant(expression.SplitCNFItems(conditions[i]))
				conditions[i] = expression.ComposeCNFCondition(newExpression)
				isSource[i] = true
			case ast.EQ:
				var (
					col *expression.Column
					val *expression.Constant
				)
				leftConst, leftIsConst := expr.Args[0].(*expression.Constant)
				rightConst, rightIsConst := expr.Args[1].(*expression.Constant)
				leftCol, leftIsCol := expr.Args[0].(*expression.Column)
				rightCol, rightIsCol := expr.Args[1].(*expression.Column)
				if rightIsConst && leftIsCol {
					col = leftCol
					val = rightConst
				} else if leftIsConst && rightIsCol {
					col = rightCol
					val = leftConst
				} else {
					continue
				}
				equalities[string(col.HashCode())] = val
				isSource[i] = true
				getOneEquality = true
			}
		}
		if len(equalities) == 0 {
			break
		}
		// Substitute the equality found above into every condition that has
		// not been consumed yet.
		for i := 0; i < len(conditions); i++ {
			if isSource[i] {
				continue
			}
			conditions[i] = constantSubstitute(equalities, conditions[i])
		}
	}
	// Propagate transitive inequality predicates.
	// e.g for conditions "a = b and c = d and a = c and g = h and b > 0 and e != 0 and g like 'abc'",
	//     we propagate constant as the following step:
	// 1. build multiple equality predicates(mep):
	//    =(a, b, c, d), =(g, h).
	// 2. extract inequality predicates between one constant and one column,
	//    and rewrite them using the root column of a multiple equality predicate:
	//    b > 0, e != 0, g like 'abc' ==> a > 0, g like 'abc'.
	//    ATTENTION: here column 'e' doesn't belong to any mep, so we skip "e != 0".
	// 3. propagate constants in these inequality predicates, and we finally get:
	//    "a = b and c = d and a = c and e = f and g = h and e != 0 and a > 0 and b > 0 and c > 0 and d > 0 and g like 'abc' and h like 'abc' ".
	multipleEqualities := make(map[*expression.Column]*expression.Column)
	for _, cond := range conditions { // build multiple equality predicates.
		expr, ok := cond.(*expression.ScalarFunction)
		if ok && expr.FuncName.L == ast.EQ {
			left, ok1 := expr.Args[0].(*expression.Column)
			right, ok2 := expr.Args[1].(*expression.Column)
			if ok1 && ok2 {
				UnionColumns(left, right, multipleEqualities)
			}
		}
	}
	if len(multipleEqualities) == 0 {
		return conditions
	}
	inequalityFuncs := map[string]string{
		ast.LT:   ast.LT,
		ast.GT:   ast.GT,
		ast.LE:   ast.LE,
		ast.GE:   ast.GE,
		ast.NE:   ast.NE,
		ast.Like: ast.Like,
	}
	type inequalityFactor struct {
		FuncName string
		Factor   []*expression.Constant
	}
	type transitiveInEqualityPredicate map[string][]inequalityFactor // transitive inequality predicates between one column and one constant.
	inequalities := make(transitiveInEqualityPredicate)
	for i := 0; i < len(conditions); i++ { // extract inequality predicates.
		var (
			column   *expression.Column
			equalCol *expression.Column // the root column corresponding to a column in a multiple equality predicate.
			val      *expression.Constant
			funcName string
		)
		expr, ok := conditions[i].(*expression.ScalarFunction)
		if !ok {
			continue
		}
		funcName, ok = inequalityFuncs[expr.FuncName.L]
		if !ok {
			continue
		}
		leftConst, leftIsConst := expr.Args[0].(*expression.Constant)
		rightConst, rightIsConst := expr.Args[1].(*expression.Constant)
		leftCol, leftIsCol := expr.Args[0].(*expression.Column)
		rightCol, rightIsCol := expr.Args[1].(*expression.Column)
		if rightIsConst && leftIsCol {
			column = leftCol
			val = rightConst
		} else if leftIsConst && rightIsCol {
			column = rightCol
			val = leftConst
		} else {
			continue
		}
		equalCol, ok = multipleEqualities[column]
		if !ok { // no need to propagate inequality predicates whose column is only equal to itself.
			continue
		}
		factors := []*expression.Constant{val}
		if funcName == ast.Like { // func 'LIKE' need 3 input arguments, so here we handle it alone.
			// Leave the predicate untouched instead of panicking when the
			// third argument is missing or is not a constant.
			if len(expr.Args) < 3 {
				continue
			}
			escape, isConst := expr.Args[2].(*expression.Constant)
			if !isConst {
				continue
			}
			factors = append(factors, escape)
		}
		colHashCode := string(equalCol.HashCode())
		inequalities[colHashCode] = append(inequalities[colHashCode], inequalityFactor{FuncName: funcName, Factor: factors})
		// The extracted predicate is removed here and re-created for the
		// columns of its equality group in the loop below.
		conditions = append(conditions[:i], conditions[i+1:]...)
		i--
	}
	if len(inequalities) == 0 {
		return conditions
	}
	// NOTE(review): map iteration order is random, so the order of the
	// appended predicates is not deterministic. Errors from NewFunction are
	// ignored as in the original code — confirm they cannot occur here.
	for k, v := range multipleEqualities { // propagate constants in inequality predicates.
		for _, x := range inequalities[string(v.HashCode())] {
			funcName, factors := x.FuncName, x.Factor
			if funcName == ast.Like {
				// LIKE factors come in (pattern, escape) pairs.
				for i := 0; i+1 < len(factors); i += 2 {
					newFunc, _ := expression.NewFunction(funcName, types.NewFieldType(mysql.TypeTiny), k, factors[i], factors[i+1])
					conditions = append(conditions, newFunc)
				}
			} else {
				// One predicate per factor; the loop previously advanced i
				// twice per iteration and would have skipped every other one.
				for i := 0; i < len(factors); i++ {
					newFunc, _ := expression.NewFunction(funcName, types.NewFieldType(mysql.TypeTiny), k, factors[i])
					conditions = append(conditions, newFunc)
				}
			}
		}
	}
	return conditions
}
예제 #17
0
파일: plan_test.go 프로젝트: c4pt0r/tidb
// TestConstantFolding checks that constant sub-expressions in a WHERE clause
// are folded while the plan is built. Each case is planned as
// "select 1 from t where <exprStr>", and the string form of the resulting
// selection's conditions is compared with resultStr.
func (s *testPlanSuite) TestConstantFolding(c *C) {
	defer testleak.AfterTest(c)()

	cases := []struct {
		exprStr   string
		resultStr string
	}{
		{
			exprStr:   "a < 1 + 2",
			resultStr: "<(test.t.a,3,)",
		},
		{
			exprStr:   "a < greatest(1, 2)",
			resultStr: "<(test.t.a,2,)",
		},
		{
			exprStr:   "a <  1 + 2 + 3 + b",
			resultStr: "<(test.t.a,+(6,test.t.b,),)",
		},
		{
			exprStr: "a = CASE 1+2 " +
				"WHEN 3 THEN 'a' " +
				"WHEN 1 THEN 'b' " +
				"END;",
			resultStr: "=(test.t.a,a,)",
		},
		{
			exprStr:   "a in (hex(12), 'a', '9')",
			resultStr: "in(test.t.a,C,a,9,)",
		},
		{
			exprStr:   "'string' is not null",
			resultStr: "1",
		},
		{
			exprStr:   "'string' is null",
			resultStr: "0",
		},
		{
			exprStr:   "a = !(1+1)",
			resultStr: "=(test.t.a,0,)",
		},
		{
			exprStr:   "a = rand()",
			resultStr: "=(test.t.a,rand(),)",
		},
		{
			exprStr:   "a = version()",
			resultStr: "=(test.t.a,version(),)",
		},
	}

	for _, ca := range cases {
		sql := "select 1 from t where " + ca.exprStr
		stmts, err := s.Parse(sql, "", "")
		c.Assert(err, IsNil, Commentf("error %v, for expr %s", err, ca.exprStr))
		stmt := stmts[0].(*ast.SelectStmt)

		err = newMockResolve(stmt)
		c.Assert(err, IsNil)

		builder := &planBuilder{allocator: new(idAllocator), ctx: mock.NewContext()}
		p := builder.build(stmt)
		// build reports failures through builder.err; the local err still
		// holds the resolver's result and would always be nil here.
		c.Assert(builder.err, IsNil, Commentf("error %v, for build plan, expr %s", builder.err, ca.exprStr))

		selection := p.GetChildByIndex(0).(*Selection)
		c.Assert(selection, NotNil, Commentf("expr:%v", ca.exprStr))

		c.Assert(expression.ComposeCNFCondition(selection.Conditions).ToString(), Equals, ca.resultStr, Commentf("different for expr %s", ca.exprStr))
	}
}