func getMatch(node sqlparser.BoolExpr, index *Index) (planID PlanID, values interface{}) { switch node := node.(type) { case *sqlparser.AndExpr: if planID, values = getMatch(node.Left, index); planID != SelectScatter { return planID, values } if planID, values = getMatch(node.Right, index); planID != SelectScatter { return planID, values } case *sqlparser.ParenBoolExpr: return getMatch(node.Expr, index) case *sqlparser.ComparisonExpr: switch node.Operator { case "=": if !nameMatch(node.Left, index.Column) { return SelectScatter, nil } if !sqlparser.IsValue(node.Right) { return SelectScatter, nil } val, err := sqlparser.AsInterface(node.Right) if err != nil { return SelectScatter, nil } if index.Type == ShardKey { planID = SelectSingleShardKey } else { planID = SelectSingleLookup } return planID, val case "in": if !nameMatch(node.Left, index.Column) { return SelectScatter, nil } if !sqlparser.IsSimpleTuple(node.Right) { return SelectScatter, nil } val, err := sqlparser.AsInterface(node.Right) if err != nil { return SelectScatter, nil } node.Right = sqlparser.ListArg("::_vals") if index.Type == ShardKey { planID = SelectMultiShardKey } else { planID = SelectMultiLookup } return planID, val } } return SelectScatter, nil }
func buildIndexPlan(ins *sqlparser.Insert, tablename string, index *Index, plan *Plan) error { pos := -1 for i, column := range ins.Columns { if index.Column == sqlparser.GetColName(column.(*sqlparser.NonStarExpr).Expr) { pos = i break } } if pos == -1 && index.Owner == tablename && index.IsAutoInc { pos = len(ins.Columns) ins.Columns = append(ins.Columns, &sqlparser.NonStarExpr{Expr: &sqlparser.ColName{Name: []byte(index.Column)}}) ins.Rows.(sqlparser.Values)[0] = append(ins.Rows.(sqlparser.Values)[0].(sqlparser.ValTuple), &sqlparser.NullVal{}) } if pos == -1 { return fmt.Errorf("must supply value for indexed column: %s", index.Column) } row := ins.Rows.(sqlparser.Values)[0].(sqlparser.ValTuple) val, err := sqlparser.AsInterface(row[pos]) if err != nil { return fmt.Errorf("could not convert val: %s, pos: %d", row[pos], pos) } plan.Values = append(plan.Values.([]interface{}), val) if index.Owner == tablename && index.IsAutoInc { row[pos] = sqlparser.ValArg([]byte(fmt.Sprintf(":_%s", index.Column))) } return nil }
func getInsertPKValues(pkColumnNumbers []int, rowList sqlparser.Values, tableInfo *schema.Table) (pkValues []interface{}, err error) { pkValues = make([]interface{}, len(pkColumnNumbers)) for index, columnNumber := range pkColumnNumbers { if columnNumber == -1 { pkValues[index] = tableInfo.GetPKColumn(index).Default continue } values := make([]interface{}, len(rowList)) for j := 0; j < len(rowList); j++ { if _, ok := rowList[j].(*sqlparser.Subquery); ok { return nil, errors.New("row subquery not supported for inserts") } row := rowList[j].(sqlparser.ValTuple) if columnNumber >= len(row) { return nil, errors.New("column count doesn't match value count") } node := row[columnNumber] if !sqlparser.IsNull(node) && !sqlparser.IsValue(node) { return nil, nil } var err error values[j], err = sqlparser.AsInterface(node) if err != nil { return nil, err } } if len(values) == 1 { pkValues[index] = values[0] } else { pkValues[index] = values } } return pkValues, nil }
func getPKValues(conditions []*sqlparser.ComparisonExpr, pkIndex *schema.Index) []interface{} { pkValues := make([]interface{}, len(pkIndex.Columns)) inClauseSeen := false for _, condition := range conditions { if condition.Operator == sqlparser.InStr { if inClauseSeen { return nil } inClauseSeen = true } index := pkIndex.FindColumn(condition.Left.(*sqlparser.ColName).Name.Original()) if index == -1 { return nil } if pkValues[index] != nil { return nil } var err error pkValues[index], err = sqlparser.AsInterface(condition.Right) if err != nil { return nil } } for _, v := range pkValues { if v == nil { return nil } } return pkValues }
func getPKValues(conditions []sqlparser.BoolExpr, pkIndex *schema.Index) (pkValues []interface{}, err error) { pkIndexScore := NewIndexScore(pkIndex) pkValues = make([]interface{}, len(pkIndexScore.ColumnMatch)) for _, condition := range conditions { condition, ok := condition.(*sqlparser.ComparisonExpr) if !ok { return nil, nil } if !sqlparser.StringIn(condition.Operator, sqlparser.AST_EQ, sqlparser.AST_IN) { return nil, nil } index := pkIndexScore.FindMatch(string(condition.Left.(*sqlparser.ColName).Name)) if index == -1 { return nil, nil } switch condition.Operator { case sqlparser.AST_EQ, sqlparser.AST_IN: var err error pkValues[index], err = sqlparser.AsInterface(condition.Right) if err != nil { return nil, err } default: panic("unreachable") } } if pkIndexScore.GetScore() == PERFECT_SCORE { return pkValues, nil } return nil, nil }
func analyzeUpdateExpressions(exprs sqlparser.UpdateExprs, pkIndex *schema.Index) (pkValues []interface{}, err error) { for _, expr := range exprs { index := pkIndex.FindColumn(expr.Name.Original()) if index == -1 { continue } if !sqlparser.IsValue(expr.Expr) { return nil, ErrTooComplex } if pkValues == nil { pkValues = make([]interface{}, len(pkIndex.Columns)) } var err error pkValues[index], err = sqlparser.AsInterface(expr.Expr) if err != nil { return nil, err } } return pkValues, nil }
func analyzeUpdateExpressions(exprs sqlparser.UpdateExprs, pkIndex *schema.Index) (pkValues []interface{}, err error) { for _, expr := range exprs { index := pkIndex.FindColumn(sqlparser.GetColName(expr.Name)) if index == -1 { continue } if !sqlparser.IsValue(expr.Expr) { log.Warningf("expression is too complex %v", expr) return nil, TooComplex } if pkValues == nil { pkValues = make([]interface{}, len(pkIndex.Columns)) } var err error pkValues[index], err = sqlparser.AsInterface(expr.Expr) if err != nil { return nil, err } } return pkValues, nil }
func Fuzz(data []byte) int { stmt, err := sqlparser.Parse(string(data)) if err != nil { if stmt != nil { panic("stmt is not nil on error") } return 0 } if true { data1 := sqlparser.String(stmt) stmt1, err := sqlparser.Parse(data1) if err != nil { fmt.Printf("data0: %q\n", data) fmt.Printf("data1: %q\n", data1) panic(err) } if !fuzz.DeepEqual(stmt, stmt1) { fmt.Printf("data0: %q\n", data) fmt.Printf("data1: %q\n", data1) panic("not equal") } } else { sqlparser.String(stmt) } if sel, ok := stmt.(*sqlparser.Select); ok { var nodes []sqlparser.SQLNode for _, x := range sel.From { nodes = append(nodes, x) } for _, x := range sel.From { nodes = append(nodes, x) } for _, x := range sel.SelectExprs { nodes = append(nodes, x) } for _, x := range sel.GroupBy { nodes = append(nodes, x) } for _, x := range sel.OrderBy { nodes = append(nodes, x) } nodes = append(nodes, sel.Where) nodes = append(nodes, sel.Having) nodes = append(nodes, sel.Limit) for _, n := range nodes { if n == nil { continue } if x, ok := n.(sqlparser.SimpleTableExpr); ok { sqlparser.GetTableName(x) } if x, ok := n.(sqlparser.Expr); ok { sqlparser.GetColName(x) } if x, ok := n.(sqlparser.ValExpr); ok { sqlparser.IsValue(x) } if x, ok := n.(sqlparser.ValExpr); ok { sqlparser.IsColName(x) } if x, ok := n.(sqlparser.ValExpr); ok { sqlparser.IsSimpleTuple(x) } if x, ok := n.(sqlparser.ValExpr); ok { sqlparser.AsInterface(x) } if x, ok := n.(sqlparser.BoolExpr); ok { sqlparser.HasINClause([]sqlparser.BoolExpr{x}) } } } buf := sqlparser.NewTrackedBuffer(nil) stmt.Format(buf) pq := buf.ParsedQuery() vars := map[string]interface{}{ "A": 42, "B": 123123123, "C": "", "D": "a", "E": "foobar", "F": 1.1, } pq.GenerateQuery(vars) return 1 }