예제 #1
0
파일: where.go 프로젝트: plobsing/vitess
// getMatch searches a boolean expression for a comparison on index.Column
// that can be served through index.
//
// Parenthesized expressions are unwrapped. For an AND expression the left
// side is tried first and the right side only if the left one yields
// SelectScatter. An "=" comparison whose right side is a value produces a
// single-key plan, and an "in" comparison whose right side is a simple tuple
// produces a multi-key plan; index.Type decides between the shard-key and
// lookup variants. On an "in" match the comparison's right side is replaced
// in place by the "::_vals" list argument.
//
// It returns the chosen plan together with the converted right-hand value,
// or SelectScatter and nil when nothing usable is found.
func getMatch(node sqlparser.BoolExpr, index *Index) (planID PlanID, values interface{}) {
	switch expr := node.(type) {
	case *sqlparser.ParenBoolExpr:
		return getMatch(expr.Expr, index)
	case *sqlparser.AndExpr:
		if id, vals := getMatch(expr.Left, index); id != SelectScatter {
			return id, vals
		}
		// A scatter result always carries nil values, so the right-hand
		// outcome can be handed back unchanged.
		return getMatch(expr.Right, index)
	case *sqlparser.ComparisonExpr:
		switch expr.Operator {
		case "=":
			if !nameMatch(expr.Left, index.Column) || !sqlparser.IsValue(expr.Right) {
				return SelectScatter, nil
			}
			key, err := sqlparser.AsInterface(expr.Right)
			if err != nil {
				return SelectScatter, nil
			}
			if index.Type == ShardKey {
				return SelectSingleShardKey, key
			}
			return SelectSingleLookup, key
		case "in":
			if !nameMatch(expr.Left, index.Column) || !sqlparser.IsSimpleTuple(expr.Right) {
				return SelectScatter, nil
			}
			keys, err := sqlparser.AsInterface(expr.Right)
			if err != nil {
				return SelectScatter, nil
			}
			// Only rewrite the query once the tuple has been converted.
			expr.Right = sqlparser.ListArg("::_vals")
			if index.Type == ShardKey {
				return SelectMultiShardKey, keys
			}
			return SelectMultiLookup, keys
		}
	}
	return SelectScatter, nil
}
예제 #2
0
파일: insert.go 프로젝트: plobsing/vitess
// buildIndexPlan appends to plan.Values the value that the first row of ins
// supplies for index.Column.
//
// If the column is not listed and the index is an auto-increment index owned
// by tablename, the column is appended to ins.Columns and a NULL value is
// appended to the first row. After the value has been captured, an owned
// auto-increment column's value in the row is replaced by the bind variable
// ":_<column>".
//
// An error is returned when the column is neither listed nor auto-generated,
// when the first row has no value at the column's position, or when the value
// cannot be converted.
//
// NOTE(review): only the first row of ins.Rows is inspected, and ins.Rows is
// assumed to be a non-empty sqlparser.Values and plan.Values a []interface{}
// (the type assertions panic otherwise) — confirm callers guarantee this.
func buildIndexPlan(ins *sqlparser.Insert, tablename string, index *Index, plan *Plan) error {
	pos := -1
	for i, column := range ins.Columns {
		if index.Column == sqlparser.GetColName(column.(*sqlparser.NonStarExpr).Expr) {
			pos = i
			break
		}
	}
	ownedAutoInc := index.Owner == tablename && index.IsAutoInc
	if pos == -1 && ownedAutoInc {
		pos = len(ins.Columns)
		ins.Columns = append(ins.Columns, &sqlparser.NonStarExpr{Expr: &sqlparser.ColName{Name: []byte(index.Column)}})
		ins.Rows.(sqlparser.Values)[0] = append(ins.Rows.(sqlparser.Values)[0].(sqlparser.ValTuple), &sqlparser.NullVal{})
	}
	if pos == -1 {
		return fmt.Errorf("must supply value for indexed column: %s", index.Column)
	}
	row := ins.Rows.(sqlparser.Values)[0].(sqlparser.ValTuple)
	// A row with fewer values than listed columns would otherwise panic with
	// an index-out-of-range below.
	if pos >= len(row) {
		return fmt.Errorf("column count doesn't match value count for indexed column: %s", index.Column)
	}
	val, err := sqlparser.AsInterface(row[pos])
	if err != nil {
		return fmt.Errorf("could not convert val: %s, pos: %d", row[pos], pos)
	}
	plan.Values = append(plan.Values.([]interface{}), val)
	if ownedAutoInc {
		row[pos] = sqlparser.ValArg([]byte(fmt.Sprintf(":_%s", index.Column)))
	}
	return nil
}
예제 #3
0
파일: dml.go 프로젝트: jmptrader/vitess
// getInsertPKValues collects, for every primary key column, the values that
// the rows of an INSERT supply for it.
//
// pkColumnNumbers[i] is the position of the i-th primary key column within
// each row, or -1 when the INSERT does not list that column; in the latter
// case the column's Default from tableInfo is used. A single-row insert
// stores the bare value for a column, a multi-row insert stores a slice
// holding one value per row.
//
// It returns (nil, nil) when any primary key value is neither NULL nor a
// plain value, and an error when a row is a subquery, a row is shorter than
// the column position, or a value cannot be converted.
func getInsertPKValues(pkColumnNumbers []int, rowList sqlparser.Values, tableInfo *schema.Table) (pkValues []interface{}, err error) {
	pkValues = make([]interface{}, len(pkColumnNumbers))
	for index, columnNumber := range pkColumnNumbers {
		if columnNumber == -1 {
			pkValues[index] = tableInfo.GetPKColumn(index).Default
			continue
		}
		perRow := make([]interface{}, len(rowList))
		for j, tuple := range rowList {
			if _, isSubquery := tuple.(*sqlparser.Subquery); isSubquery {
				return nil, errors.New("row subquery not supported for inserts")
			}
			row := tuple.(sqlparser.ValTuple)
			if len(row) <= columnNumber {
				return nil, errors.New("column count doesn't match value count")
			}
			node := row[columnNumber]
			if !(sqlparser.IsNull(node) || sqlparser.IsValue(node)) {
				return nil, nil
			}
			if perRow[j], err = sqlparser.AsInterface(node); err != nil {
				return nil, err
			}
		}
		switch len(perRow) {
		case 1:
			pkValues[index] = perRow[0]
		default:
			pkValues[index] = perRow
		}
	}
	return pkValues, nil
}
예제 #4
0
파일: dml.go 프로젝트: jmptrader/vitess
// getPKValues maps a list of comparison conditions onto the columns of
// pkIndex and returns one converted right-hand value per primary key column.
//
// It returns nil when the conditions cannot be used as a full primary key
// match: more than one condition uses the IN operator, a condition's column
// is not part of pkIndex, a column is assigned twice, a value cannot be
// converted, or some primary key column ends up without a value.
//
// NOTE(review): the left side of every condition is asserted to be a
// *sqlparser.ColName and panics otherwise — presumably callers pre-filter.
func getPKValues(conditions []*sqlparser.ComparisonExpr, pkIndex *schema.Index) []interface{} {
	pkValues := make([]interface{}, len(pkIndex.Columns))
	sawIn := false
	for _, cond := range conditions {
		if cond.Operator == sqlparser.InStr {
			if sawIn {
				return nil
			}
			sawIn = true
		}
		col := pkIndex.FindColumn(cond.Left.(*sqlparser.ColName).Name.Original())
		if col == -1 || pkValues[col] != nil {
			return nil
		}
		val, err := sqlparser.AsInterface(cond.Right)
		if err != nil {
			return nil
		}
		pkValues[col] = val
	}
	for i := range pkValues {
		if pkValues[i] == nil {
			return nil
		}
	}
	return pkValues
}
예제 #5
0
// getPKValues tries to derive one value per primary key column from a list
// of boolean conditions.
//
// Every condition must be a comparison accepted by the
// StringIn(AST_EQ, AST_IN) check, and its left-hand column must be matched by
// the index score built from pkIndex. The converted right-hand side is stored
// at the matched position. The values are returned only when the score
// reports PERFECT_SCORE once all conditions have been consumed; in every
// other non-error case the result is (nil, nil). A conversion failure is
// returned as the error.
//
// NOTE(review): the left side of each comparison is asserted to be a
// *sqlparser.ColName and panics otherwise — presumably callers pre-filter.
func getPKValues(conditions []sqlparser.BoolExpr, pkIndex *schema.Index) (pkValues []interface{}, err error) {
	score := NewIndexScore(pkIndex)
	pkValues = make([]interface{}, len(score.ColumnMatch))
	for _, cond := range conditions {
		comparison, isComparison := cond.(*sqlparser.ComparisonExpr)
		if !isComparison || !sqlparser.StringIn(comparison.Operator, sqlparser.AST_EQ, sqlparser.AST_IN) {
			return nil, nil
		}
		pos := score.FindMatch(string(comparison.Left.(*sqlparser.ColName).Name))
		if pos == -1 {
			return nil, nil
		}
		// Defensive guard: the operator was already filtered above.
		if op := comparison.Operator; op != sqlparser.AST_EQ && op != sqlparser.AST_IN {
			panic("unreachable")
		}
		if pkValues[pos], err = sqlparser.AsInterface(comparison.Right); err != nil {
			return nil, err
		}
	}
	if pkIndexScore := score.GetScore(); pkIndexScore != PERFECT_SCORE {
		return nil, nil
	}
	return pkValues, nil
}
예제 #6
0
파일: dml.go 프로젝트: jmptrader/vitess
// analyzeUpdateExpressions extracts the new primary key values assigned by
// the SET clause of an UPDATE.
//
// Expressions that target columns outside pkIndex are skipped. The result
// slice has one slot per pkIndex column and is allocated only once the first
// primary key column is seen, so it stays nil when no primary key column is
// updated. ErrTooComplex is returned when a primary key column is assigned
// something other than a plain value, and a conversion failure is returned
// as is.
func analyzeUpdateExpressions(exprs sqlparser.UpdateExprs, pkIndex *schema.Index) (pkValues []interface{}, err error) {
	for _, update := range exprs {
		pos := pkIndex.FindColumn(update.Name.Original())
		switch {
		case pos == -1:
			continue
		case !sqlparser.IsValue(update.Expr):
			return nil, ErrTooComplex
		case pkValues == nil:
			pkValues = make([]interface{}, len(pkIndex.Columns))
		}
		if pkValues[pos], err = sqlparser.AsInterface(update.Expr); err != nil {
			return nil, err
		}
	}
	return pkValues, nil
}
예제 #7
0
파일: dml.go 프로젝트: miffa/vitess
// analyzeUpdateExpressions gathers the values that an UPDATE's SET clause
// assigns to primary key columns.
//
// Assignments to columns that are not part of pkIndex are ignored. The
// returned slice is indexed by position within pkIndex and remains nil when
// no primary key column is assigned. If a primary key column is set to
// anything other than a plain value, a warning is logged and TooComplex is
// returned. A failed value conversion is returned as the error.
func analyzeUpdateExpressions(exprs sqlparser.UpdateExprs, pkIndex *schema.Index) (pkValues []interface{}, err error) {
	for _, assignment := range exprs {
		col := pkIndex.FindColumn(sqlparser.GetColName(assignment.Name))
		if col != -1 {
			if !sqlparser.IsValue(assignment.Expr) {
				log.Warningf("expression is too complex %v", assignment)
				return nil, TooComplex
			}
			// Allocate lazily so that "no PK column updated" stays nil.
			if pkValues == nil {
				pkValues = make([]interface{}, len(pkIndex.Columns))
			}
			converted, convErr := sqlparser.AsInterface(assignment.Expr)
			pkValues[col] = converted
			if convErr != nil {
				return nil, convErr
			}
		}
	}
	return pkValues, nil
}
예제 #8
0
파일: main.go 프로젝트: roger2000hk/go-fuzz
// Fuzz is a fuzzing entry point that exercises the sqlparser package with
// arbitrary input.
//
// The input is parsed as SQL. Input that fails to parse returns 0 (after
// checking that no statement was returned alongside the error). Otherwise
// the statement is:
//   - printed back to SQL and re-parsed; the function panics if the re-parse
//     fails or yields a statement that is not deeply equal to the first one,
//   - for SELECT statements, walked so that the sqlparser helper predicates
//     and conversions run on each top-level clause node,
//   - formatted into a tracked buffer and expanded with a fixed set of bind
//     variables.
//
// It returns 1 for every input that parsed successfully. Results of the
// helper calls are discarded on purpose: only panics are of interest.
func Fuzz(data []byte) int {
	stmt, err := sqlparser.Parse(string(data))
	if err != nil {
		if stmt != nil {
			panic("stmt is not nil on error")
		}
		return 0
	}

	// Round trip: the printed form must parse to an identical statement.
	data1 := sqlparser.String(stmt)
	stmt1, err := sqlparser.Parse(data1)
	if err != nil {
		fmt.Printf("data0: %q\n", data)
		fmt.Printf("data1: %q\n", data1)
		panic(err)
	}
	if !fuzz.DeepEqual(stmt, stmt1) {
		fmt.Printf("data0: %q\n", data)
		fmt.Printf("data1: %q\n", data1)
		panic("not equal")
	}

	if sel, ok := stmt.(*sqlparser.Select); ok {
		var nodes []sqlparser.SQLNode
		for _, x := range sel.From {
			nodes = append(nodes, x)
		}
		for _, x := range sel.SelectExprs {
			nodes = append(nodes, x)
		}
		for _, x := range sel.GroupBy {
			nodes = append(nodes, x)
		}
		for _, x := range sel.OrderBy {
			nodes = append(nodes, x)
		}
		nodes = append(nodes, sel.Where, sel.Having, sel.Limit)

		for _, n := range nodes {
			if n == nil {
				continue
			}
			if x, ok := n.(sqlparser.SimpleTableExpr); ok {
				sqlparser.GetTableName(x)
			}
			if x, ok := n.(sqlparser.Expr); ok {
				sqlparser.GetColName(x)
			}
			if x, ok := n.(sqlparser.ValExpr); ok {
				sqlparser.IsValue(x)
				sqlparser.IsColName(x)
				sqlparser.IsSimpleTuple(x)
				// Conversion errors are expected for arbitrary nodes.
				sqlparser.AsInterface(x)
			}
			if x, ok := n.(sqlparser.BoolExpr); ok {
				sqlparser.HasINClause([]sqlparser.BoolExpr{x})
			}
		}
	}

	buf := sqlparser.NewTrackedBuffer(nil)
	stmt.Format(buf)
	pq := buf.ParsedQuery()
	vars := map[string]interface{}{
		"A": 42,
		"B": 123123123,
		"C": "",
		"D": "a",
		"E": "foobar",
		"F": 1.1,
	}
	// Missing or mismatched bind variables are expected; ignore the outcome.
	pq.GenerateQuery(vars)
	return 1
}