// Find returns the route for the symbol referenced by col. // If a reference is found, the column's Metadata is set to point // it. Subsequent searches will reuse this meatadata. // If autoResolve is true, and there is only one table in the symbol table, // then an unqualified reference is assumed to be implicitly against // that table. The table info doesn't contain the full list of columns. // So, any column reference is presumed valid. If a Colsyms scope is // present, then the table scope is not searched. If a symbol is found // in the current symtab, then isLocal is set to true. Otherwise, the // search is continued in the outer symtab. If so, isLocal will be set // to false. If the symbol was not found, an error is returned. // isLocal must be checked before you can push-down (or pull-out) // a construct. // If a symbol was found in an outer scope, then the column reference // is added to the Externs field. func (st *symtab) Find(col *sqlparser.ColName, autoResolve bool) (rb *route, isLocal bool, err error) { if m, ok := col.Metadata.(sym); ok { return m.Route(), m.Symtab() == st, nil } if len(st.Colsyms) != 0 { name := sqlparser.String(col) starname := sqlparser.String(&sqlparser.ColName{ Name: sqlparser.NewColIdent("*"), Qualifier: col.Qualifier, }) for _, colsym := range st.Colsyms { if colsym.Alias.EqualString(name) || colsym.Alias.EqualString(starname) || colsym.Alias.EqualString("*") { col.Metadata = colsym return colsym.Route(), true, nil } } if st.Outer != nil { // autoResolve only allowed for innermost scope. rb, _, err = st.Outer.Find(col, false) if err == nil { st.Externs = append(st.Externs, col) } return rb, false, err } return nil, false, fmt.Errorf("symbol %s not found", sqlparser.String(col)) } qualifier := sqlparser.TableIdent(sqlparser.String(col.Qualifier)) if qualifier == "" && autoResolve && len(st.tables) == 1 { for _, t := range st.tables { qualifier = t.Alias break } } alias := st.findTable(qualifier) if alias == nil { if st.Outer != nil { // autoResolve only allowed for innermost scope. rb, _, err = st.Outer.Find(col, false) if err == nil { st.Externs = append(st.Externs, col) } return rb, false, err } return nil, false, fmt.Errorf("symbol %s not found", sqlparser.String(col)) } col.Metadata = alias return alias.Route(), true, nil }
// PushStar pushes the '*' expression into the route. func (rb *route) PushStar(expr *sqlparser.StarExpr) *colsym { colsym := newColsym(rb, rb.Symtab()) colsym.Alias = sqlparser.NewColIdent(sqlparser.String(expr)) rb.Select.SelectExprs = append(rb.Select.SelectExprs, expr) rb.Colsyms = append(rb.Colsyms, colsym) return colsym }
// split splits the query into multiple queries. validateQuery() must return // nil error before split() is called. func (qs *QuerySplitter) split(columnType querypb.Type, pkMinMax *sqltypes.Result) ([]querytypes.QuerySplit, error) { boundaries, err := qs.splitBoundaries(columnType, pkMinMax) if err != nil { return nil, err } splits := []querytypes.QuerySplit{} // No splits, return the original query as a single split if len(boundaries) == 0 { splits = append(splits, querytypes.QuerySplit{ Sql: qs.sql, BindVariables: qs.bindVariables, }) } else { boundaries = append(boundaries, sqltypes.Value{}) whereClause := qs.sel.Where // Loop through the boundaries and generated modified where clauses start := sqltypes.Value{} for _, end := range boundaries { bindVars := make(map[string]interface{}, len(qs.bindVariables)) for k, v := range qs.bindVariables { bindVars[k] = v } qs.sel.Where = qs.getWhereClause(whereClause, bindVars, start, end) split := &querytypes.QuerySplit{ Sql: sqlparser.String(qs.sel), BindVariables: bindVars, RowCount: qs.rowCount, } splits = append(splits, *split) start = end } qs.sel.Where = whereClause // reset where clause } return splits, err }
// GetStreamExecPlan generates a ExecPlan given a sql query and a TableGetter. func GetStreamExecPlan(sql string, getTable TableGetter) (plan *ExecPlan, err error) { statement, err := sqlparser.Parse(sql) if err != nil { return nil, err } plan = &ExecPlan{ PlanID: PlanSelectStream, FullQuery: GenerateFullQuery(statement), } switch stmt := statement.(type) { case *sqlparser.Select: if stmt.Lock != "" { return nil, errors.New("select with lock not allowed for streaming") } tableName, _ := analyzeFrom(stmt.From) // This will block usage of NEXTVAL. if tableName == "dual" { return nil, errors.New("select from dual not allowed for streaming") } if tableName != "" { plan.setTableInfo(tableName, getTable) } case *sqlparser.Union: // pass default: return nil, fmt.Errorf("'%v' not allowed for streaming", sqlparser.String(stmt)) } return plan, nil }
func analyzeSelectExprs(exprs sqlparser.SelectExprs, table *schema.Table) (selects []int, err error) { selects = make([]int, 0, len(exprs)) for _, expr := range exprs { switch expr := expr.(type) { case *sqlparser.StarExpr: // Append all columns. for colIndex := range table.Columns { selects = append(selects, colIndex) } case *sqlparser.NonStarExpr: name := sqlparser.GetColName(expr.Expr) if name == "" { // Not a simple column name. return nil, nil } colIndex := table.FindColumn(name) if colIndex == -1 { return nil, fmt.Errorf("column %s not found in table %s", name, table.Name) } selects = append(selects, colIndex) default: return nil, fmt.Errorf("unsupported construct: %s", sqlparser.String(expr)) } } return selects, nil }
// asInterface is similar to sqlparser.AsInterface, but it converts // numeric and string types to native go types. func asInterface(node sqlparser.ValExpr) (interface{}, error) { switch node := node.(type) { case sqlparser.ValTuple: vals := make([]interface{}, 0, len(node)) for _, val := range node { v, err := asInterface(val) if err != nil { return nil, err } vals = append(vals, v) } return vals, nil case sqlparser.ValArg: return string(node), nil case sqlparser.ListArg: return string(node), nil case sqlparser.StrVal: return []byte(node), nil case sqlparser.NumVal: val := string(node) signed, err := strconv.ParseInt(val, 0, 64) if err == nil { return signed, nil } unsigned, err := strconv.ParseUint(val, 0, 64) if err == nil { return unsigned, nil } return nil, err case *sqlparser.NullVal: return nil, nil } return nil, fmt.Errorf("%v is not a value", sqlparser.String(node)) }
// split splits the query into multiple queries. validateQuery() must return // nil error before split() is called. func (qs *QuerySplitter) split(pkMinMax *mproto.QueryResult) []proto.QuerySplit { boundaries := qs.getSplitBoundaries(pkMinMax) splits := []proto.QuerySplit{} // No splits, return the original query as a single split if len(boundaries) == 0 { split := &proto.QuerySplit{ Query: *qs.query, } splits = append(splits, *split) } else { // Loop through the boundaries and generated modified where clauses start := sqltypes.Value{} clauses := []*sqlparser.Where{} for _, end := range boundaries { clauses = append(clauses, qs.getWhereClause(start, end)) start.Inner = end.Inner } clauses = append(clauses, qs.getWhereClause(start, sqltypes.Value{})) // Generate one split per clause for _, clause := range clauses { sel := qs.sel sel.Where = clause q := &proto.BoundQuery{ Sql: sqlparser.String(sel), BindVariables: qs.query.BindVariables, } split := &proto.QuerySplit{ Query: *q, RowCount: qs.rowCount, } splits = append(splits, *split) } } return splits }
// GetStreamExecPlan generates a ExecPlan given a sql query and a TableGetter. func GetStreamExecPlan(sql string, getTable TableGetter) (plan *ExecPlan, err error) { statement, err := sqlparser.Parse(sql) if err != nil { return nil, err } plan = &ExecPlan{ PlanId: PLAN_SELECT_STREAM, FullQuery: GenerateFullQuery(statement), } switch stmt := statement.(type) { case *sqlparser.Select: if stmt.Lock != "" { return nil, errors.New("select with lock disallowed with streaming") } tableName, _ := analyzeFrom(stmt.From) if tableName != "" { plan.setTableInfo(tableName, getTable) } case *sqlparser.Union: // pass default: return nil, fmt.Errorf("'%v' not allowed for streaming", sqlparser.String(stmt)) } return plan, nil }
// buildInitialQuery returns the initial query to execute to get the // initial boundary tuple. // If the query to split (given in splitParams.sql) is // "SELECT <select exprs> FROM <table> WHERE <where>", // the Sql field of the result will be: // "SELECT sc_1,sc_2,...,sc_n FROM <table> // WHERE <where> // LIMIT splitParams.numRowsPerQueryPart, 1", // The BindVariables field of the result will contain a deep-copy of splitParams.BindVariables. func buildInitialQuery(splitParams *SplitParams) *querytypes.BoundQuery { resultSelectAST := buildInitialQueryAST(splitParams) return &querytypes.BoundQuery{ Sql: sqlparser.String(resultSelectAST), BindVariables: cloneBindVariables(splitParams.bindVariables), } }
// PushSelect pushes the select expression into the route. func (rb *route) PushSelect(expr *sqlparser.NonStarExpr, _ *route) (colsym *colsym, colnum int, err error) { colsym = newColsym(rb, rb.Symtab()) colsym.Alias = expr.As if col, ok := expr.Expr.(*sqlparser.ColName); ok { // If no alias was specified, then the base name // of the column becomes the alias. if colsym.Alias.Original() == "" { colsym.Alias = col.Name } // We should always allow other parts of the query to reference // the fully qualified name of the column. if tab, ok := col.Metadata.(*tabsym); ok { colsym.QualifiedName = sqlparser.NewColIdent(sqlparser.String(tab.Alias) + "." + col.Name.Original()) } colsym.Vindex = rb.Symtab().Vindex(col, rb, true) colsym.Underlying = newColref(col) } else { if rb.IsRHS { return nil, 0, errors.New("unsupported: complex left join and column expressions") } // We should ideally generate an alias based on the // expression, but we currently don't have the ability // to reference such expressions. So, we leave the // alias blank. } rb.Select.SelectExprs = append(rb.Select.SelectExprs, expr) rb.Colsyms = append(rb.Colsyms, colsym) return colsym, len(rb.Colsyms) - 1, nil }
// processAliasedTable produces a builder subtree for the given AliasedTableExpr. // If the expression is a subquery, then the the route built for it will contain // the entire subquery tree in the from clause, as if it was a table. // The symtab entry for the query will be a tabsym where the columns // will be built from the select expressions of the subquery. // Since the table aliases only contain vindex columns, we'll follow // the same rule: only columns from the subquery that are identified as // vindex columns will be added to the tabsym. // A symtab symbol can only point to a route. This means that we canoot // support complex joins in subqueries yet. func processAliasedTable(tableExpr *sqlparser.AliasedTableExpr, vschema VSchema) (builder, error) { switch expr := tableExpr.Expr.(type) { case *sqlparser.TableName: eroute, table, err := getTablePlan(expr, vschema) if err != nil { return nil, err } alias := sqlparser.SQLName(sqlparser.String(expr)) astName := expr.Name if tableExpr.As != "" { alias = tableExpr.As astName = alias } return newRoute( sqlparser.TableExprs([]sqlparser.TableExpr{tableExpr}), eroute, table, vschema, alias, astName, ), nil case *sqlparser.Subquery: sel, ok := expr.Select.(*sqlparser.Select) if !ok { return nil, errors.New("unsupported: union operator in subqueries") } subplan, err := processSelect(sel, vschema, nil) if err != nil { return nil, err } subroute, ok := subplan.(*route) if !ok { return nil, errors.New("unsupported: complex join in subqueries") } table := &vindexes.Table{ Keyspace: subroute.ERoute.Keyspace, } for _, colsyms := range subroute.Colsyms { if colsyms.Vindex == nil { continue } table.ColVindexes = append(table.ColVindexes, &vindexes.ColVindex{ Col: string(colsyms.Alias), Vindex: colsyms.Vindex, }) } rtb := newRoute( sqlparser.TableExprs([]sqlparser.TableExpr{tableExpr}), subroute.ERoute, table, vschema, tableExpr.As, tableExpr.As, ) subroute.Redirect = rtb return rtb, nil } panic("unreachable") }
// buildIndexPlan adds the insert value to the Values field for the specified ColumnVindex. // This value will be used at the time of insert to validate the vindex value. func buildIndexPlan(colVindex *vindexes.ColumnVindex, rowNum int, row sqlparser.ValTuple, pos int) (interface{}, error) { val, err := valConvert(row[pos]) if err != nil { return val, fmt.Errorf("could not convert val: %s, pos: %d: %v", sqlparser.String(row[pos]), pos, err) } row[pos] = sqlparser.ValArg([]byte(":_" + colVindex.Column.Original() + strconv.Itoa(rowNum))) return val, nil }
// PushStar pushes the '*' expression into the route. func (rb *route) PushStar(expr *sqlparser.StarExpr) *colsym { colsym := newColsym(rb, rb.Symtab()) // This is not perfect, but it should be good enough. // We'll match unqualified column names against Alias // and qualified column names against QualifiedName. // If someone uses 'select *' and then uses table.col // in the HAVING clause, then things won't match. But // such cases are easy to correct in the application. if expr.TableName == "" { colsym.Alias = sqlparser.NewColIdent(sqlparser.String(expr)) } else { colsym.QualifiedName = sqlparser.NewColIdent(sqlparser.String(expr)) } rb.Select.SelectExprs = append(rb.Select.SelectExprs, expr) rb.Colsyms = append(rb.Colsyms, colsym) return colsym }
// buildIndexPlan adds the insert value to the Values field for the specified ColVindex. // This value will be used at the time of insert to validate the vindex value. func buildIndexPlan(ins *sqlparser.Insert, colVindex *vindexes.ColVindex, route *engine.Route) error { row, pos := findOrInsertPos(ins, colVindex.Col) val, err := valConvert(row[pos]) if err != nil { return fmt.Errorf("could not convert val: %s, pos: %d: %v", sqlparser.String(row[pos]), pos, err) } route.Values = append(route.Values.([]interface{}), val) row[pos] = sqlparser.ValArg([]byte(":_" + colVindex.Col)) return nil }
// MarshalJSON marshals RouteBuilder into a readable form. // It's used for testing and diagnostics. The representation // cannot be used to reconstruct a RouteBuilder. func (rtb *RouteBuilder) MarshalJSON() ([]byte, error) { marshalRoute := struct { From string `json:",omitempty"` Order int Route *Route }{ From: sqlparser.String(rtb.From), Order: rtb.order, Route: rtb.Route, } return json.Marshal(marshalRoute) }
// buildNonInitialQuery returns the non-initial query to execute to get the // noninitial boundary tuples. // If the query to split (given in splitParams.sql) is // "SELECT <select exprs> FROM <table> WHERE <where>", // the Sql field of the result will be: // "SELECT sc_1,sc_2,...,sc_n FROM <table> // WHERE (<where>) AND (:prev_sc_1,...,:prev_sc_n) <= (sc_1,...,sc_n) // LIMIT splitParams.numRowsPerQueryPart, 1", // where sc_1,...,sc_n are the split columns, // and :prev_sc_1,...,:_prev_sc_n are the bind variable names for the previous tuple. // The BindVariables field of the result will contain a deep-copy of splitParams.BindVariables. // The new "prev_<sc>" bind variables are not populated yet. func buildNoninitialQuery( splitParams *SplitParams, prevBindVariableNames []string) *querytypes.BoundQuery { resultSelectAST := buildInitialQueryAST(splitParams) addAndTermToWhereClause( resultSelectAST, constructTupleInequality( convertBindVariableNamesToValExpr(prevBindVariableNames), convertColumnsToValExpr(splitParams.splitColumns), false /* strict */)) return &querytypes.BoundQuery{ Sql: sqlparser.String(resultSelectAST), BindVariables: cloneBindVariables(splitParams.bindVariables), } }
// initQueryPartSQLs initializes the firstQueryPartSQL, middleQueryPartSQL and lastQueryPartSQL // fields. func (splitter *Splitter) initQueryPartSQLs() { splitColumns := convertColumnsToValExpr(splitter.algorithm.getSplitColumns()) startBindVariables := convertBindVariableNamesToValExpr(splitter.startBindVariableNames) endBindVariables := convertBindVariableNamesToValExpr(splitter.endBindVariableNames) splitColsLessThanEnd := constructTupleInequality( splitColumns, endBindVariables, true /* strict */) splitColsGreaterThanOrEqualToStart := constructTupleInequality( startBindVariables, splitColumns, false /* not strict */) splitter.firstQueryPartSQL = sqlparser.String( queryWithAdditionalWhere(splitter.splitParams.selectAST, splitColsLessThanEnd)) splitter.middleQueryPartSQL = sqlparser.String( queryWithAdditionalWhere(splitter.splitParams.selectAST, &sqlparser.AndExpr{ Left: &sqlparser.ParenBoolExpr{Expr: splitColsGreaterThanOrEqualToStart}, Right: &sqlparser.ParenBoolExpr{Expr: splitColsLessThanEnd}, })) splitter.lastQueryPartSQL = sqlparser.String( queryWithAdditionalWhere(splitter.splitParams.selectAST, splitColsGreaterThanOrEqualToStart)) }
// generateDeleteSubquery generates the query to fetch the rows // that will be deleted. This allows VTGate to clean up any // owned vindexes as needed. func generateDeleteSubquery(del *sqlparser.Delete, table *vindexes.Table) string { if len(table.Owned) == 0 { return "" } buf := bytes.NewBuffer(nil) buf.WriteString("select ") prefix := "" for _, cv := range table.Owned { buf.WriteString(prefix) buf.WriteString(cv.Column.Original()) prefix = ", " } fmt.Fprintf(buf, " from %s", table.Name) buf.WriteString(sqlparser.String(del.Where)) buf.WriteString(" for update") return buf.String() }
func buildAutoIncrementPlan(ins *sqlparser.Insert, autoinc *vindexes.AutoIncrement, route *engine.Route, rowValue []interface{}, rowNum int) (interface{}, []interface{}, error) { var autoIncVal interface{} // If it's also a colvindex, we have to add a redirect from route.Values. // Otherwise, we have to redirect from row[pos]. if autoinc.ColumnVindexNum >= 0 { autoIncVal = rowValue[autoinc.ColumnVindexNum] rowValue[autoinc.ColumnVindexNum] = ":" + engine.SeqVarName + strconv.Itoa(rowNum) return autoIncVal, rowValue, nil } row, pos := findOrInsertPos(ins, autoinc.Column, rowNum) val, err := valConvert(row[pos]) if err != nil { return autoIncVal, rowValue, fmt.Errorf("could not convert val: %s, pos: %d: %v", sqlparser.String(row[pos]), pos, err) } autoIncVal = val row[pos] = sqlparser.ValArg([]byte(":" + engine.SeqVarName + strconv.Itoa(rowNum))) return autoIncVal, rowValue, nil }
// PushSelect pushes the select expression into the route. func (rb *route) PushSelect(expr *sqlparser.NonStarExpr, _ *route) (colsym *colsym, colnum int, err error) { colsym = newColsym(rb, rb.Symtab()) if expr.As.Original() != "" { colsym.Alias = expr.As } if col, ok := expr.Expr.(*sqlparser.ColName); ok { if colsym.Alias.Original() == "" { colsym.Alias = sqlparser.NewColIdent(sqlparser.String(col)) } colsym.Vindex = rb.Symtab().Vindex(col, rb, true) colsym.Underlying = newColref(col) } else { if rb.IsRHS { return nil, 0, errors.New("unsupported: complex left join and column expressions") } } rb.Select.SelectExprs = append(rb.Select.SelectExprs, expr) rb.Colsyms = append(rb.Colsyms, colsym) return colsym, len(rb.Colsyms) - 1, nil }
func buildAutoincPlan(ins *sqlparser.Insert, autoinc *vindexes.Autoinc, route *engine.Route) error { route.Generate = &engine.Generate{ Opcode: engine.SelectUnsharded, Keyspace: autoinc.Sequence.Keyspace, Query: fmt.Sprintf("select next value from `%s`", autoinc.Sequence.Name), } // If it's also a colvindex, we have to add a redirect from route.Values. // Otherwise, we have to redirect from row[pos]. if autoinc.ColVindexNum >= 0 { route.Generate.Value = route.Values.([]interface{})[autoinc.ColVindexNum] route.Values.([]interface{})[autoinc.ColVindexNum] = ":" + engine.SeqVarName return nil } row, pos := findOrInsertPos(ins, autoinc.Col) val, err := valConvert(row[pos]) if err != nil { return fmt.Errorf("could not convert val: %s, pos: %d: %v", sqlparser.String(row[pos]), pos, err) } route.Generate.Value = val row[pos] = sqlparser.ValArg([]byte(":" + engine.SeqVarName)) return nil }
func buildIndexPlan(ins *sqlparser.Insert, tablename string, colVindex *ColVindex, plan *Plan) error { pos := -1 for i, column := range ins.Columns { if colVindex.Col == sqlparser.GetColName(column.(*sqlparser.NonStarExpr).Expr) { pos = i break } } if pos == -1 { pos = len(ins.Columns) ins.Columns = append(ins.Columns, &sqlparser.NonStarExpr{Expr: &sqlparser.ColName{Name: sqlparser.SQLName(colVindex.Col)}}) ins.Rows.(sqlparser.Values)[0] = append(ins.Rows.(sqlparser.Values)[0].(sqlparser.ValTuple), &sqlparser.NullVal{}) } row := ins.Rows.(sqlparser.Values)[0].(sqlparser.ValTuple) val, err := asInterface(row[pos]) if err != nil { return fmt.Errorf("could not convert val: %s, pos: %d: %v", sqlparser.String(row[pos]), pos, err) } plan.Values = append(plan.Values.([]interface{}), val) row[pos] = sqlparser.ValArg([]byte(fmt.Sprintf(":_%s", colVindex.Col))) return nil }
// valConvert converts an AST value to the Value field in the route. func valConvert(node sqlparser.ValExpr) (interface{}, error) { switch node := node.(type) { case sqlparser.ValArg: return string(node), nil case sqlparser.StrVal: return []byte(node), nil case sqlparser.NumVal: val := string(node) signed, err := strconv.ParseInt(val, 0, 64) if err == nil { return signed, nil } unsigned, err := strconv.ParseUint(val, 0, 64) if err == nil { return unsigned, nil } return nil, err case *sqlparser.NullVal: return nil, nil } return nil, fmt.Errorf("%v is not a value", sqlparser.String(node)) }
// Interprets the parsed node and correctly encodes the primary key values. func encodePKValues(tuple sqlparser.ValTuple, insertid int64) (rowPk []interface{}, newinsertid int64, err error) { for _, pkVal := range tuple { switch pkVal := pkVal.(type) { case sqlparser.StrVal: rowPk = append(rowPk, []byte(pkVal)) case sqlparser.NumVal: valstr := string(pkVal) if ival, err := strconv.ParseInt(valstr, 0, 64); err == nil { rowPk = append(rowPk, ival) } else if uval, err := strconv.ParseUint(valstr, 0, 64); err == nil { rowPk = append(rowPk, uval) } else { return nil, insertid, err } case *sqlparser.NullVal: rowPk = append(rowPk, insertid) insertid++ default: return nil, insertid, fmt.Errorf("unexpected token: '%v'", sqlparser.String(pkVal)) } } return rowPk, insertid, nil }
func TestGetWhereClause(t *testing.T) { splitter := &QuerySplitter{} sql := "select * from test_table where count > :count" statement, _ := sqlparser.Parse(sql) splitter.sel, _ = statement.(*sqlparser.Select) splitter.splitColumn = "id" // no boundary case, start = end = nil, should not change the where clause nilValue := sqltypes.Value{} clause := splitter.getWhereClause(nilValue, nilValue) want := " where count > :count" got := sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want) } // Set lower bound, should add the lower bound condition to where clause start, _ := sqltypes.BuildValue(20) clause = splitter.getWhereClause(start, nilValue) want = " where count > :count and id >= 20" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } // Set upper bound, should add the upper bound condition to where clause end, _ := sqltypes.BuildValue(40) clause = splitter.getWhereClause(nilValue, end) want = " where count > :count and id < 40" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } // Set both bounds, should add two conditions to where clause clause = splitter.getWhereClause(start, end) want = " where count > :count and id >= 20 and id < 40" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } // Original query with no where clause sql = "select * from test_table" statement, _ = sqlparser.Parse(sql) splitter.sel, _ = statement.(*sqlparser.Select) // no boundary case, start = end = nil should return no where clause clause = splitter.getWhereClause(nilValue, nilValue) want = "" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want) } // Set both bounds, should add two conditions to where clause clause = splitter.getWhereClause(start, end) want = " where id >= 20 and id < 40" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } }
func TestGetWhereClause(t *testing.T) { splitter := &QuerySplitter{} sql := "select * from test_table where count > :count" statement, _ := sqlparser.Parse(sql) splitter.sel, _ = statement.(*sqlparser.Select) splitter.splitColumn = "id" bindVars := make(map[string]interface{}) // no boundary case, start = end = nil, should not change the where clause nilValue := sqltypes.Value{} clause := splitter.getWhereClause(splitter.sel.Where, bindVars, nilValue, nilValue) want := " where count > :count" got := sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want) } // Set lower bound, should add the lower bound condition to where clause startVal := int64(20) start, _ := sqltypes.BuildValue(startVal) bindVars = make(map[string]interface{}) bindVars[":count"] = 300 clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, nilValue) want = " where (count > :count) and (id >= :" + startBindVarName + ")" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } v, ok := bindVars[startBindVarName] if !ok { t.Fatalf("bind var: %s not found got: nil, want: %v", startBindVarName, startVal) } if v != startVal { t.Fatalf("bind var: %s not found got: %v, want: %v", startBindVarName, v, startVal) } // Set upper bound, should add the upper bound condition to where clause endVal := int64(40) end, _ := sqltypes.BuildValue(endVal) bindVars = make(map[string]interface{}) clause = splitter.getWhereClause(splitter.sel.Where, bindVars, nilValue, end) want = " where (count > :count) and (id < :" + endBindVarName + ")" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } v, ok = bindVars[endBindVarName] if !ok { t.Fatalf("bind var: %s not found got: nil, want: %v", endBindVarName, endVal) } if v != endVal { t.Fatalf("bind var: %s not found got: %v, want: %v", endBindVarName, v, endVal) } // Set both bounds, should add two conditions to where clause bindVars = make(map[string]interface{}) clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, end) want = fmt.Sprintf(" where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName) got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } // Original query with no where clause sql = "select * from test_table" statement, _ = sqlparser.Parse(sql) splitter.sel, _ = statement.(*sqlparser.Select) bindVars = make(map[string]interface{}) // no boundary case, start = end = nil should return no where clause clause = splitter.getWhereClause(splitter.sel.Where, bindVars, nilValue, nilValue) want = "" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want) } bindVars = make(map[string]interface{}) // Set both bounds, should add two conditions to where clause clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, end) want = fmt.Sprintf(" where id >= :%s and id < :%s", startBindVarName, endBindVarName) got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } v, ok = bindVars[startBindVarName] if !ok { t.Fatalf("bind var: %s not found got: nil, want: %v", startBindVarName, startVal) } if v != startVal { t.Fatalf("bind var: %s not found got: %v, want: %v", startBindVarName, v, startVal) } v, ok = bindVars[endBindVarName] if !ok { t.Fatalf("bind var: %s not found got: nil, want: %v", endBindVarName, endVal) } if v != endVal { t.Fatalf("bind var: %s not found got: %v, want: %v", endBindVarName, v, endVal) } }
// processAliasedTable produces a builder subtree for the given AliasedTableExpr. // If the expression is a subquery, then the the route built for it will contain // the entire subquery tree in the from clause, as if it was a table. // The symtab entry for the query will be a tabsym where the columns // will be built from the select expressions of the subquery. // Since the table aliases only contain vindex columns, we'll follow // the same rule: only columns from the subquery that are identified as // vindex columns will be added to the tabsym. // A symtab symbol can only point to a route. This means that we canoot // support complex joins in subqueries yet. func processAliasedTable(tableExpr *sqlparser.AliasedTableExpr, vschema VSchema) (builder, error) { switch expr := tableExpr.Expr.(type) { case *sqlparser.TableName: eroute, table, err := getTablePlan(expr, vschema) if err != nil { return nil, err } alias := sqlparser.TableIdent(sqlparser.String(expr)) astName := expr.Name if tableExpr.As != "" { alias = tableExpr.As astName = alias } return newRoute( sqlparser.TableExprs([]sqlparser.TableExpr{tableExpr}), eroute, table, vschema, alias, astName, ), nil case *sqlparser.Subquery: sel, ok := expr.Select.(*sqlparser.Select) if !ok { return nil, errors.New("unsupported: union operator in subqueries") } subplan, err := processSelect(sel, vschema, nil) if err != nil { return nil, err } subroute, ok := subplan.(*route) if !ok { return nil, errors.New("unsupported: complex join in subqueries") } table := &vindexes.Table{ Keyspace: subroute.ERoute.Keyspace, } for _, colsyms := range subroute.Colsyms { if colsyms.Vindex == nil { continue } // Check if a colvindex of the same name already exists. // Dups are not allowed in subqueries in this situation. for _, colVindex := range table.ColumnVindexes { if colVindex.Column.Equal(cistring.CIString(colsyms.Alias)) { return nil, fmt.Errorf("duplicate column aliases: %v", colsyms.Alias) } } table.ColumnVindexes = append(table.ColumnVindexes, &vindexes.ColumnVindex{ Column: cistring.CIString(colsyms.Alias), Vindex: colsyms.Vindex, }) } rtb := newRoute( sqlparser.TableExprs([]sqlparser.TableExpr{tableExpr}), subroute.ERoute, table, vschema, tableExpr.As, tableExpr.As, ) subroute.Redirect = rtb return rtb, nil } panic("unreachable") }
func Fuzz(data []byte) int { stmt, err := sqlparser.Parse(string(data)) if err != nil { if stmt != nil { panic("stmt is not nil on error") } return 0 } if true { data1 := sqlparser.String(stmt) stmt1, err := sqlparser.Parse(data1) if err != nil { fmt.Printf("data0: %q\n", data) fmt.Printf("data1: %q\n", data1) panic(err) } if !fuzz.DeepEqual(stmt, stmt1) { fmt.Printf("data0: %q\n", data) fmt.Printf("data1: %q\n", data1) panic("not equal") } } else { sqlparser.String(stmt) } if sel, ok := stmt.(*sqlparser.Select); ok { var nodes []sqlparser.SQLNode for _, x := range sel.From { nodes = append(nodes, x) } for _, x := range sel.From { nodes = append(nodes, x) } for _, x := range sel.SelectExprs { nodes = append(nodes, x) } for _, x := range sel.GroupBy { nodes = append(nodes, x) } for _, x := range sel.OrderBy { nodes = append(nodes, x) } nodes = append(nodes, sel.Where) nodes = append(nodes, sel.Having) nodes = append(nodes, sel.Limit) for _, n := range nodes { if n == nil { continue } if x, ok := n.(sqlparser.SimpleTableExpr); ok { sqlparser.GetTableName(x) } if x, ok := n.(sqlparser.Expr); ok { sqlparser.GetColName(x) } if x, ok := n.(sqlparser.ValExpr); ok { sqlparser.IsValue(x) } if x, ok := n.(sqlparser.ValExpr); ok { sqlparser.IsColName(x) } if x, ok := n.(sqlparser.ValExpr); ok { sqlparser.IsSimpleTuple(x) } if x, ok := n.(sqlparser.ValExpr); ok { sqlparser.AsInterface(x) } if x, ok := n.(sqlparser.BoolExpr); ok { sqlparser.HasINClause([]sqlparser.BoolExpr{x}) } } } buf := sqlparser.NewTrackedBuffer(nil) stmt.Format(buf) pq := buf.ParsedQuery() vars := map[string]interface{}{ "A": 42, "B": 123123123, "C": "", "D": "a", "E": "foobar", "F": 1.1, } pq.GenerateQuery(vars) return 1 }