// Ensure that the input query is a Select statement that contains no Join, // GroupBy, OrderBy, Limit or Distinct operations. Also ensure that the // source table is present in the schema and has at least one primary key. func (qs *QuerySplitter) validateQuery() error { statement, err := sqlparser.Parse(qs.query.Sql) if err != nil { return err } var ok bool qs.sel, ok = statement.(*sqlparser.Select) if !ok { return fmt.Errorf("not a select statement") } if qs.sel.Distinct != "" || qs.sel.GroupBy != nil || qs.sel.Having != nil || len(qs.sel.From) != 1 || qs.sel.OrderBy != nil || qs.sel.Limit != nil || qs.sel.Lock != "" { return fmt.Errorf("unsupported query") } node, ok := qs.sel.From[0].(*sqlparser.AliasedTableExpr) if !ok { return fmt.Errorf("unsupported query") } qs.tableName = sqlparser.GetTableName(node.Expr) if qs.tableName == "" { return fmt.Errorf("not a simple table expression") } tableInfo, ok := qs.schemaInfo.tables[qs.tableName] if !ok { return fmt.Errorf("can't find table in schema") } if len(tableInfo.PKColumns) == 0 { return fmt.Errorf("no primary keys") } qs.pkCol = tableInfo.GetPKColumn(0).Name return nil }
// GetStreamExecPlan generates a ExecPlan given a sql query and a TableGetter. func GetStreamExecPlan(sql string, getTable TableGetter) (plan *ExecPlan, err error) { statement, err := sqlparser.Parse(sql) if err != nil { return nil, err } plan = &ExecPlan{ PlanID: PlanSelectStream, FullQuery: GenerateFullQuery(statement), } switch stmt := statement.(type) { case *sqlparser.Select: if stmt.Lock != "" { return nil, errors.New("select with lock not allowed for streaming") } tableName, _ := analyzeFrom(stmt.From) // This will block usage of NEXTVAL. if tableName == "dual" { return nil, errors.New("select from dual not allowed for streaming") } if tableName != "" { plan.setTableInfo(tableName, getTable) } case *sqlparser.Union: // pass default: return nil, fmt.Errorf("'%v' not allowed for streaming", sqlparser.String(stmt)) } return plan, nil }
func BuildPlan(query string, schema *Schema) *Plan { statement, err := sqlparser.Parse(query) if err != nil { return &Plan{ ID: NoPlan, Reason: err.Error(), Original: query, } } noplan := &Plan{ ID: NoPlan, Reason: "too complex", Original: query, } var plan *Plan switch statement := statement.(type) { case *sqlparser.Select: plan = buildSelectPlan(statement, schema) case *sqlparser.Insert: plan = buildInsertPlan(statement, schema) case *sqlparser.Update: plan = buildUpdatePlan(statement, schema) case *sqlparser.Delete: plan = buildDeletePlan(statement, schema) case *sqlparser.Union, *sqlparser.Set, *sqlparser.DDL, *sqlparser.Other: return noplan default: panic("unexpected") } plan.Original = query return plan }
// Build builds a plan for a query based on the specified vschema. // It's the main entry point for this package. func Build(query string, vschema VSchema) (*engine.Plan, error) { statement, err := sqlparser.Parse(query) if err != nil { return nil, err } plan := &engine.Plan{ Original: query, } switch statement := statement.(type) { case *sqlparser.Select: plan.Instructions, err = buildSelectPlan(statement, vschema) case *sqlparser.Insert: plan.Instructions, err = buildInsertPlan(statement, vschema) case *sqlparser.Update: plan.Instructions, err = buildUpdatePlan(statement, vschema) case *sqlparser.Delete: plan.Instructions, err = buildDeletePlan(statement, vschema) case *sqlparser.Union, *sqlparser.Set, *sqlparser.DDL, *sqlparser.Other: return nil, errors.New("unsupported construct") default: panic("unexpected statement type") } if err != nil { return nil, err } return plan, nil }
func (rci *RowcacheInvalidator) handleErrEvent(event *blproto.StreamEvent) { statement, err := sqlparser.Parse(event.Sql) if err != nil { log.Errorf("Error parsing: %s: %v", event.Sql, err) internalErrors.Add("Invalidation", 1) return } var table *sqlparser.TableName switch stmt := statement.(type) { case *sqlparser.Insert: // Inserts don't affect rowcache return case *sqlparser.Update: table = stmt.Table case *sqlparser.Delete: table = stmt.Table default: log.Errorf("Unrecognized: %s", event.Sql) internalErrors.Add("Invalidation", 1) return } // If it's not a cross-db statement, try treating the statement as a DDL. // It will conservatively invalidate all rows of the table. if table.Qualifier == nil || string(table.Qualifier) == rci.dbname { log.Warningf("Treating %s as DDL for table %s", event.Sql, table.Name) rci.qe.InvalidateForDDL(&proto.DDLInvalidate{DDL: fmt.Sprintf("alter table %s alter", table.Name)}) } }
// newSplitParams validates and initializes all the fields except splitCount and // numRowsPerQueryPart. It contains the common code for the constructors above. func newSplitParams(sql string, bindVariables map[string]interface{}, splitColumns []string, schema map[string]*schema.Table) (*SplitParams, error) { statement, err := sqlparser.Parse(sql) if err != nil { return nil, fmt.Errorf("splitquery: failed parsing query: '%v', err: '%v'", sql, err) } selectAST, ok := statement.(*sqlparser.Select) if !ok { return nil, fmt.Errorf("splitquery: not a select statement") } if selectAST.Distinct != "" || selectAST.GroupBy != nil || selectAST.Having != nil || len(selectAST.From) != 1 || selectAST.OrderBy != nil || selectAST.Limit != nil || selectAST.Lock != "" { return nil, fmt.Errorf("splitquery: unsupported query: %v", sql) } var aliasedTableExpr *sqlparser.AliasedTableExpr aliasedTableExpr, ok = selectAST.From[0].(*sqlparser.AliasedTableExpr) if !ok { return nil, fmt.Errorf("splitquery: unsupported FROM clause in query: %v", sql) } tableName := sqlparser.GetTableName(aliasedTableExpr.Expr) if tableName == "" { return nil, fmt.Errorf("splitquery: unsupported FROM clause in query"+ " (must be a simple table expression): %v", sql) } tableSchema, ok := schema[tableName] if tableSchema == nil { return nil, fmt.Errorf("splitquery: can't find table in schema") } if len(splitColumns) == 0 { splitColumns = getPrimaryKeyColumns(tableSchema) if len(splitColumns) == 0 { panic(fmt.Sprintf("getPrimaryKeyColumns() returned an empty slice. %+v", tableSchema)) } } if !areColumnsAPrefixOfAnIndex(splitColumns, tableSchema) { return nil, fmt.Errorf("splitquery: split-columns must be a prefix of the columns composing"+ " an index. Sql: %v, split-columns: %v", sql, splitColumns) } // Get the split-columns types. splitColumnTypes := make([]querypb.Type, 0, len(splitColumns)) for _, splitColumn := range splitColumns { i := tableSchema.FindColumn(splitColumn) if i == -1 { return nil, fmt.Errorf("can't find split-column: %v", splitColumn) } splitColumnTypes = append(splitColumnTypes, tableSchema.Columns[i].Type) } return &SplitParams{ sql: sql, bindVariables: bindVariables, splitColumns: splitColumns, splitColumnTypes: splitColumnTypes, selectAST: selectAST, splitTableSchema: tableSchema, }, nil }
// GetStreamExecPlan generates a ExecPlan given a sql query and a TableGetter. func GetStreamExecPlan(sql string, getTable TableGetter) (plan *ExecPlan, err error) { statement, err := sqlparser.Parse(sql) if err != nil { return nil, err } plan = &ExecPlan{ PlanId: PLAN_SELECT_STREAM, FullQuery: GenerateFullQuery(statement), } switch stmt := statement.(type) { case *sqlparser.Select: if stmt.Lock != "" { return nil, errors.New("select with lock disallowed with streaming") } tableName, _ := analyzeFrom(stmt.From) if tableName != "" { plan.setTableInfo(tableName, getTable) } case *sqlparser.Union: // pass default: return nil, fmt.Errorf("'%v' not allowed for streaming", sqlparser.String(stmt)) } return plan, nil }
func buildPlan(sql string) (plan *RoutingPlan, err error) { statement, err := sqlparser.Parse(sql) if err != nil { return nil, err } return getRoutingPlan(statement) }
func TestSelect2(t *testing.T) { schema, err := LoadFile(locateFile("schema_test.json")) if err != nil { t.Fatal(err) } for tcase := range iterateExecFile("select2_cases.txt") { statement, err := sqlparser.Parse(tcase.input) if err != nil { t.Error(err) continue } sel, ok := statement.(*sqlparser.Select) if !ok { t.Errorf("unexpected type: %T", statement) continue } plan, _, err := buildSelectPlan2(sel, schema) if err != nil { t.Error(err) continue } bout, err := json.Marshal(plan) if err != nil { panic(fmt.Sprintf("Error marshalling %v: %v", plan, err)) } out := string(bout) if out != tcase.output { t.Errorf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, out) } // Comment these line out to see the expected outputs // bout, err = json.MarshalIndent(plan, "", " ") // fmt.Printf("%s\n", bout) } }
// GetExecPlan generates a ExecPlan given a sql query and a TableGetter. func GetExecPlan(sql string, getTable TableGetter) (plan *ExecPlan, err error) { statement, err := sqlparser.Parse(sql) if err != nil { return nil, err } plan, err = analyzeSQL(statement, getTable) if err != nil { return nil, err } if plan.PlanId == PLAN_PASS_DML { log.Warningf("PASS_DML: %s", sql) } return plan, nil }
func parseDDLs(sqls []string) ([]*sqlparser.DDL, error) { parsedDDLs := make([]*sqlparser.DDL, len(sqls)) for i, sql := range sqls { stat, err := sqlparser.Parse(sql) if err != nil { return nil, fmt.Errorf("failed to parse sql: %s, got error: %v", sql, err) } ddl, ok := stat.(*sqlparser.DDL) if !ok { return nil, fmt.Errorf("schema change works for DDLs only, but get non DDL statement: %s", sql) } parsedDDLs[i] = ddl } return parsedDDLs, nil }
func DDLParse(sql string) (plan *DDLPlan) { statement, err := sqlparser.Parse(sql) if err != nil { return &DDLPlan{Action: ""} } stmt, ok := statement.(*sqlparser.DDL) if !ok { return &DDLPlan{Action: ""} } return &DDLPlan{ Action: stmt.Action, TableName: string(stmt.Table), NewName: string(stmt.NewName), } }
func analyze(line []byte) { for _, ignore := range ignores { if bytes.HasPrefix(line, ignore) { return } } dml := string(bytes.TrimRight(line, "\n")) ast, err := sqlparser.Parse(dml) if err != nil { log.Errorf("Error parsing %s", dml) return } bindIndex = 0 buf := sqlparser.NewTrackedBuffer(FormatWithBind) buf.Myprintf("%v", ast) addQuery(buf.ParsedQuery().Query) }
func TestSplitParams(t *testing.T) { for _, testCase := range splitParamsTestCases { var splitParams *SplitParams var err error if testCase.NumRowsPerQueryPart != 0 { splitParams, err = NewSplitParamsGivenNumRowsPerQueryPart( testCase.SQL, testCase.BindVariables, testCase.SplitColumnNames, testCase.NumRowsPerQueryPart, testSchema) } else { splitParams, err = NewSplitParamsGivenSplitCount( testCase.SQL, testCase.BindVariables, testCase.SplitColumnNames, testCase.SplitCount, testSchema) } if testCase.ExpectedErrorRegex != nil { if !testCase.ExpectedErrorRegex.MatchString(err.Error()) { t.Errorf("Testcase: %+v, want: %+v, got: %+v", testCase, testCase.ExpectedErrorRegex, err) } continue } // Here, we don't expect an error. if err != nil { t.Errorf("TestCase: %+v, want: %+v, got: %+v", testCase, nil, err) continue } if splitParams == nil { t.Errorf("TestCase: %+v, got nil splitParams", testCase) continue } expectedSplitParams := testCase.ExpectedSplitParams // We don't require testCaset.ExpectedSplitParams to specify common expected fields like 'sql', // so we compute them here and store them in 'expectedSplitParams'. expectedSplitParams.sql = testCase.SQL expectedSplitParams.bindVariables = testCase.BindVariables statement, _ := sqlparser.Parse(testCase.SQL) expectedSplitParams.selectAST = statement.(*sqlparser.Select) if !reflect.DeepEqual(&expectedSplitParams, splitParams) { t.Errorf("TestCase: %+v, want: %+v, got: %+v", testCase, expectedSplitParams, *splitParams) } } }
// Validate validates a list of sql statements func (exec *TabletExecutor) Validate(ctx context.Context, sqls []string) error { if exec.isClosed { return fmt.Errorf("executor is closed") } parsedDDLs := make([]*sqlparser.DDL, len(sqls)) for i, sql := range sqls { stat, err := sqlparser.Parse(sql) if err != nil { return fmt.Errorf("failed to parse sql: %s, got error: %v", sql, err) } ddl, ok := stat.(*sqlparser.DDL) if !ok { return fmt.Errorf("schema change works for DDLs only, but get non DDL statement: %s", sql) } parsedDDLs[i] = ddl } return exec.detectBigSchemaChanges(ctx, parsedDDLs) }
// InvalidateForUnrecognized performs best effort rowcache invalidation // for unrecognized statements. func (qe *QueryEngine) InvalidateForUnrecognized(sql string) { statement, err := sqlparser.Parse(sql) if err != nil { log.Errorf("Error: %v: %s", err, sql) internalErrors.Add("Invalidation", 1) return } var table *sqlparser.TableName switch stmt := statement.(type) { case *sqlparser.Insert: // Inserts don't affect rowcache. return case *sqlparser.Update: table = stmt.Table case *sqlparser.Delete: table = stmt.Table default: log.Errorf("Unrecognized: %s", sql) internalErrors.Add("Invalidation", 1) return } // Ignore cross-db statements. if table.Qualifier != nil && string(table.Qualifier) != qe.dbconfig.DbName { return } // Ignore if it's an uncached table. tableName := string(table.Name) tableInfo := qe.schemaInfo.GetTable(tableName) if tableInfo == nil { log.Errorf("Table %s not found: %s", tableName, sql) internalErrors.Add("Invalidation", 1) return } if tableInfo.CacheType == schema.CACHE_NONE { return } // Treat the statement as a DDL. // It will conservatively invalidate all rows of the table. log.Warningf("Treating '%s' as DDL for table %s", sql, tableName) qe.schemaInfo.CreateOrUpdateTable(tableName) }
// Ensure that the input query is a Select statement that contains no Join, // GroupBy, OrderBy, Limit or Distinct operations. Also ensure that the // source table is present in the schema and has at least one primary key. func (qs *QuerySplitter) validateQuery() error { statement, err := sqlparser.Parse(qs.sql) if err != nil { return err } var ok bool qs.sel, ok = statement.(*sqlparser.Select) if !ok { return fmt.Errorf("not a select statement") } if qs.sel.Distinct != "" || qs.sel.GroupBy != nil || qs.sel.Having != nil || len(qs.sel.From) != 1 || qs.sel.OrderBy != nil || qs.sel.Limit != nil || qs.sel.Lock != "" { return fmt.Errorf("unsupported query") } node, ok := qs.sel.From[0].(*sqlparser.AliasedTableExpr) if !ok { return fmt.Errorf("unsupported query") } qs.tableName = sqlparser.GetTableName(node.Expr) if qs.tableName == "" { return fmt.Errorf("not a simple table expression") } tableInfo, ok := qs.schemaInfo.tables[qs.tableName] if !ok { return fmt.Errorf("can't find table in schema") } if len(tableInfo.PKColumns) == 0 { return fmt.Errorf("no primary keys") } if qs.splitColumn != "" { for _, index := range tableInfo.Indexes { for _, column := range index.Columns { if qs.splitColumn == column { return nil } } } return fmt.Errorf("split column is not indexed or does not exist in table schema, SplitColumn: %s, TableInfo.Table: %v", qs.splitColumn, tableInfo.Table) } qs.splitColumn = tableInfo.GetPKColumn(0).Name return nil }
func (rci *RowcacheInvalidator) handleUnrecognizedEvent(sql string) { statement, err := sqlparser.Parse(sql) if err != nil { log.Errorf("Error: %v: %s", err, sql) rci.qe.queryServiceStats.InternalErrors.Add("Invalidation", 1) return } var table *sqlparser.TableName switch stmt := statement.(type) { case *sqlparser.Insert: // Inserts don't affect rowcache. return case *sqlparser.Update: table = stmt.Table case *sqlparser.Delete: table = stmt.Table default: log.Errorf("Unrecognized: %s", sql) rci.qe.queryServiceStats.InternalErrors.Add("Invalidation", 1) return } // Ignore cross-db statements. if table.Qualifier != "" && string(table.Qualifier) != rci.qe.dbconfigs.App.DbName { return } // Ignore if it's an uncached table. tableName := string(table.Name) tableInfo := rci.qe.schemaInfo.GetTable(tableName) if tableInfo == nil { log.Errorf("Table %s not found: %s", tableName, sql) rci.qe.queryServiceStats.InternalErrors.Add("Invalidation", 1) return } if tableInfo.CacheType == schema.CacheNone { return } // Treat the statement as a DDL. // It will conservatively invalidate all rows of the table. log.Warningf("Treating '%s' as DDL for table %s", sql, tableName) rci.qe.schemaInfo.CreateOrUpdateTable(context.Background(), tableName) }
// SQLExecute is part of the SQLExecuter interface. func (se *splitQuerySQLExecuter) SQLExecute( sql string, bindVariables map[string]interface{}, ) (*sqltypes.Result, error) { // We need to parse the query since we're dealing with bind-vars. // TODO(erez): Add an SQLExecute() to SQLExecuterInterface that gets a parsed query so that // we don't have to parse the query again here. ast, err := sqlparser.Parse(sql) if err != nil { return nil, fmt.Errorf("splitQuerySQLExecuter: parsing sql failed with: %v", err) } parsedQuery := sqlparser.GenerateParsedQuery(ast) // We clone "bindVariables" since fullFetch() changes it. return se.queryExecutor.fullFetch( se.conn, parsedQuery, utils.CloneBindVariables(bindVariables), nil /* buildStreamComment */) }
// Validate validates a list of sql statements. func (exec *TabletExecutor) Validate(ctx context.Context, sqls []string) error { if exec.isClosed { return fmt.Errorf("executor is closed") } parsedDDLs := make([]*sqlparser.DDL, len(sqls)) for i, sql := range sqls { stat, err := sqlparser.Parse(sql) if err != nil { return fmt.Errorf("failed to parse sql: %s, got error: %v", sql, err) } ddl, ok := stat.(*sqlparser.DDL) if !ok { return fmt.Errorf("schema change works for DDLs only, but get non DDL statement: %s", sql) } parsedDDLs[i] = ddl } bigSchemaChange, err := exec.detectBigSchemaChanges(ctx, parsedDDLs) if bigSchemaChange && exec.allowBigSchemaChange { log.Warning("Processing big schema change. This may cause visible MySQL downtime.") return nil } return err }
func Normalize(sql string) (string, map[string]interface{}, error) { if sql == "" || sql[0] == '#' || sql[0] == '/' { return "", nil, nil } else { lstr := strings.ToLower(sql) if strings.HasPrefix(lstr, "begin") { return "begin", nil, nil } if strings.HasPrefix(lstr, "commit") { return "commit", nil, nil } if strings.HasPrefix(lstr, "rollback") { return "rollback", nil, nil } if strings.HasPrefix(lstr, "use") { return "", nil, nil } } tree, err := sqlparser.Parse(sql) if err != nil { return "", nil, err } return NormalizeTree(tree) }
// TODO(erez): Add an SQLExecute() to SQLExecuterInterface that gets a parsed query so that // we don't have to parse the query again here. func (se *splitQuerySQLExecuter) SQLExecute( sql string, bindVariables map[string]interface{}) (*sqltypes.Result, error) { ast, err := sqlparser.Parse(sql) if err != nil { return nil, fmt.Errorf("splitQuerySQLExecuter: parsing sql failed with: %v", err) } parsedQuery := sqlparser.GenerateParsedQuery(ast) conn, err := se.queryExecutor.getConn(se.queryExecutor.qe.connPool) if err != nil { return nil, err } defer conn.Recycle() // TODO(erez): Find out what 'buildStreamComment' is, and see if we need to use it or comment why // we don't. // We clone "bindVariables" since fullFetch() changes it. return se.queryExecutor.fullFetch( conn, parsedQuery, utils.CloneBindVariables(bindVariables), nil /* buildStreamComment */) }
// newSplitParams validates and initializes all the fields except splitCount and // numRowsPerQueryPart. It contains the common code for the constructors above. func newSplitParams( sql string, bindVariables map[string]interface{}, splitColumnNames []sqlparser.ColIdent, schemaMap map[string]*schema.Table, ) (*SplitParams, error) { statement, err := sqlparser.Parse(sql) if err != nil { return nil, fmt.Errorf("failed parsing query: '%v', err: '%v'", sql, err) } selectAST, ok := statement.(*sqlparser.Select) if !ok { return nil, fmt.Errorf("not a select statement") } if selectAST.Distinct != "" || selectAST.GroupBy != nil || selectAST.Having != nil || len(selectAST.From) != 1 || selectAST.OrderBy != nil || selectAST.Limit != nil || selectAST.Lock != "" { return nil, fmt.Errorf("unsupported query: %v", sql) } var aliasedTableExpr *sqlparser.AliasedTableExpr aliasedTableExpr, ok = selectAST.From[0].(*sqlparser.AliasedTableExpr) if !ok { return nil, fmt.Errorf("unsupported FROM clause in query: %v", sql) } tableName := sqlparser.GetTableName(aliasedTableExpr.Expr) if tableName == "" { return nil, fmt.Errorf("unsupported FROM clause in query"+ " (must be a simple table expression): %v", sql) } tableSchema, ok := schemaMap[tableName] if tableSchema == nil { return nil, fmt.Errorf("can't find table in schema") } // Get the schema.TableColumn representation of each splitColumnName. var splitColumns []*schema.TableColumn if len(splitColumnNames) == 0 { splitColumns = getPrimaryKeyColumns(tableSchema) } else { splitColumns, err = findSplitColumnsInSchema(splitColumnNames, tableSchema) if err != nil { return nil, err } if !areColumnsAPrefixOfAnIndex(splitColumns, tableSchema) { return nil, fmt.Errorf("split-columns must be a prefix of the columns composing"+ " an index. Sql: %v, split-columns: %v", sql, splitColumns) } } if len(splitColumns) == 0 { panic(fmt.Sprintf( "Empty set of split columns. splitColumns: %+v, tableSchema: %+v", splitColumns, tableSchema)) } return &SplitParams{ sql: sql, bindVariables: bindVariables, splitColumns: splitColumns, selectAST: selectAST, splitTableSchema: tableSchema, }, nil }
func TestGetWhereClause(t *testing.T) { splitter := &QuerySplitter{} sql := "select * from test_table where count > :count" statement, _ := sqlparser.Parse(sql) splitter.sel, _ = statement.(*sqlparser.Select) splitter.splitColumn = "id" // no boundary case, start = end = nil, should not change the where clause nilValue := sqltypes.Value{} clause := splitter.getWhereClause(nilValue, nilValue) want := " where count > :count" got := sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want) } // Set lower bound, should add the lower bound condition to where clause start, _ := sqltypes.BuildValue(20) clause = splitter.getWhereClause(start, nilValue) want = " where count > :count and id >= 20" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } // Set upper bound, should add the upper bound condition to where clause end, _ := sqltypes.BuildValue(40) clause = splitter.getWhereClause(nilValue, end) want = " where count > :count and id < 40" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } // Set both bounds, should add two conditions to where clause clause = splitter.getWhereClause(start, end) want = " where count > :count and id >= 20 and id < 40" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } // Original query with no where clause sql = "select * from test_table" statement, _ = sqlparser.Parse(sql) splitter.sel, _ = statement.(*sqlparser.Select) // no boundary case, start = end = nil should return no where clause clause = splitter.getWhereClause(nilValue, nilValue) want = "" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want) } // Set both bounds, should add two conditions to where clause clause = splitter.getWhereClause(start, end) want = " where id >= 20 and id < 40" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } }
func TestGetWhereClause(t *testing.T) { splitter := &QuerySplitter{} sql := "select * from test_table where count > :count" statement, _ := sqlparser.Parse(sql) splitter.sel, _ = statement.(*sqlparser.Select) splitter.splitColumn = "id" bindVars := make(map[string]interface{}) // no boundary case, start = end = nil, should not change the where clause nilValue := sqltypes.Value{} clause := splitter.getWhereClause(splitter.sel.Where, bindVars, nilValue, nilValue) want := " where count > :count" got := sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want) } // Set lower bound, should add the lower bound condition to where clause startVal := int64(20) start, _ := sqltypes.BuildValue(startVal) bindVars = make(map[string]interface{}) bindVars[":count"] = 300 clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, nilValue) want = " where (count > :count) and (id >= :" + startBindVarName + ")" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } v, ok := bindVars[startBindVarName] if !ok { t.Fatalf("bind var: %s not found got: nil, want: %v", startBindVarName, startVal) } if v != startVal { t.Fatalf("bind var: %s not found got: %v, want: %v", startBindVarName, v, startVal) } // Set upper bound, should add the upper bound condition to where clause endVal := int64(40) end, _ := sqltypes.BuildValue(endVal) bindVars = make(map[string]interface{}) clause = splitter.getWhereClause(splitter.sel.Where, bindVars, nilValue, end) want = " where (count > :count) and (id < :" + endBindVarName + ")" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } v, ok = bindVars[endBindVarName] if !ok { t.Fatalf("bind var: %s not found got: nil, want: %v", endBindVarName, endVal) } if v != endVal { t.Fatalf("bind var: %s not found got: %v, want: %v", endBindVarName, v, endVal) } // Set both bounds, should add two conditions to where clause bindVars = make(map[string]interface{}) clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, end) want = fmt.Sprintf(" where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName) got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } // Original query with no where clause sql = "select * from test_table" statement, _ = sqlparser.Parse(sql) splitter.sel, _ = statement.(*sqlparser.Select) bindVars = make(map[string]interface{}) // no boundary case, start = end = nil should return no where clause clause = splitter.getWhereClause(splitter.sel.Where, bindVars, nilValue, nilValue) want = "" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want) } bindVars = make(map[string]interface{}) // Set both bounds, should add two conditions to where clause clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, end) want = fmt.Sprintf(" where id >= :%s and id < :%s", startBindVarName, endBindVarName) got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) } v, ok = bindVars[startBindVarName] if !ok { t.Fatalf("bind var: %s not found got: nil, want: %v", startBindVarName, startVal) } if v != startVal { t.Fatalf("bind var: %s not found got: %v, want: %v", startBindVarName, v, startVal) } v, ok = bindVars[endBindVarName] if !ok { t.Fatalf("bind var: %s not found got: nil, want: %v", endBindVarName, endVal) } if v != endVal { t.Fatalf("bind var: %s not found got: %v, want: %v", endBindVarName, v, endVal) } }
func Fuzz(data []byte) int { stmt, err := sqlparser.Parse(string(data)) if err != nil { if stmt != nil { panic("stmt is not nil on error") } return 0 } if true { data1 := sqlparser.String(stmt) stmt1, err := sqlparser.Parse(data1) if err != nil { fmt.Printf("data0: %q\n", data) fmt.Printf("data1: %q\n", data1) panic(err) } if !fuzz.DeepEqual(stmt, stmt1) { fmt.Printf("data0: %q\n", data) fmt.Printf("data1: %q\n", data1) panic("not equal") } } else { sqlparser.String(stmt) } if sel, ok := stmt.(*sqlparser.Select); ok { var nodes []sqlparser.SQLNode for _, x := range sel.From { nodes = append(nodes, x) } for _, x := range sel.From { nodes = append(nodes, x) } for _, x := range sel.SelectExprs { nodes = append(nodes, x) } for _, x := range sel.GroupBy { nodes = append(nodes, x) } for _, x := range sel.OrderBy { nodes = append(nodes, x) } nodes = append(nodes, sel.Where) nodes = append(nodes, sel.Having) nodes = append(nodes, sel.Limit) for _, n := range nodes { if n == nil { continue } if x, ok := n.(sqlparser.SimpleTableExpr); ok { sqlparser.GetTableName(x) } if x, ok := n.(sqlparser.Expr); ok { sqlparser.GetColName(x) } if x, ok := n.(sqlparser.ValExpr); ok { sqlparser.IsValue(x) } if x, ok := n.(sqlparser.ValExpr); ok { sqlparser.IsColName(x) } if x, ok := n.(sqlparser.ValExpr); ok { sqlparser.IsSimpleTuple(x) } if x, ok := n.(sqlparser.ValExpr); ok { sqlparser.AsInterface(x) } if x, ok := n.(sqlparser.BoolExpr); ok { sqlparser.HasINClause([]sqlparser.BoolExpr{x}) } } } buf := sqlparser.NewTrackedBuffer(nil) stmt.Format(buf) pq := buf.ParsedQuery() vars := map[string]interface{}{ "A": 42, "B": 123123123, "C": "", "D": "a", "E": "foobar", "F": 1.1, } pq.GenerateQuery(vars) return 1 }