예제 #1
0
// validateQuery checks that qs.query.Sql is a plain single-table SELECT with
// no DISTINCT, GROUP BY, HAVING, ORDER BY, LIMIT or lock clause (a join shows
// up either as more than one FROM entry or as a FROM entry that is not an
// aliased table expression), that the table is present in qs.schemaInfo, and
// that the table has at least one primary-key column.
//
// Along the way it records the parsed statement in qs.sel (nil if it is not
// a SELECT), the table name in qs.tableName and, on success, the name of the
// first primary-key column in qs.pkCol.
func (qs *QuerySplitter) validateQuery() error {
	parsed, err := sqlparser.Parse(qs.query.Sql)
	if err != nil {
		return err
	}

	sel, isSelect := parsed.(*sqlparser.Select)
	qs.sel = sel
	if !isSelect {
		return fmt.Errorf("not a select statement")
	}

	hasUnsupportedClause := sel.Distinct != "" ||
		sel.GroupBy != nil ||
		sel.Having != nil ||
		sel.OrderBy != nil ||
		sel.Limit != nil ||
		sel.Lock != ""
	if hasUnsupportedClause || len(sel.From) != 1 {
		return fmt.Errorf("unsupported query")
	}

	aliased, isAliased := sel.From[0].(*sqlparser.AliasedTableExpr)
	if !isAliased {
		return fmt.Errorf("unsupported query")
	}

	qs.tableName = sqlparser.GetTableName(aliased.Expr)
	if qs.tableName == "" {
		return fmt.Errorf("not a simple table expression")
	}

	tableInfo, found := qs.schemaInfo.tables[qs.tableName]
	switch {
	case !found:
		return fmt.Errorf("can't find table in schema")
	case len(tableInfo.PKColumns) == 0:
		return fmt.Errorf("no primary keys")
	}

	qs.pkCol = tableInfo.GetPKColumn(0).Name
	return nil
}
예제 #2
0
파일: plan.go 프로젝트: aaijazi/vitess
// GetStreamExecPlan generates a ExecPlan given a sql query and a TableGetter.
//
// The plan always has PlanID PlanSelectStream. Only SELECT and UNION
// statements are accepted. A SELECT is rejected when it has a lock clause or
// reads from "dual"; when it reads from a named table, getTable is used (via
// setTableInfo) to attach that table's information to the plan. Any other
// statement type is rejected with an error quoting the statement.
func GetStreamExecPlan(sql string, getTable TableGetter) (plan *ExecPlan, err error) {
	statement, err := sqlparser.Parse(sql)
	if err != nil {
		return nil, err
	}

	plan = &ExecPlan{
		PlanID:    PlanSelectStream,
		FullQuery: GenerateFullQuery(statement),
	}

	if _, isUnion := statement.(*sqlparser.Union); isUnion {
		return plan, nil
	}

	sel, isSelect := statement.(*sqlparser.Select)
	if !isSelect {
		return nil, fmt.Errorf("'%v' not allowed for streaming", sqlparser.String(statement))
	}
	if sel.Lock != "" {
		return nil, errors.New("select with lock not allowed for streaming")
	}

	tableName, _ := analyzeFrom(sel.From)
	switch tableName {
	case "dual":
		// This will block usage of NEXTVAL.
		return nil, errors.New("select from dual not allowed for streaming")
	case "":
		// No single named table: nothing to attach.
	default:
		plan.setTableInfo(tableName, getTable)
	}
	return plan, nil
}
예제 #3
0
파일: plan.go 프로젝트: plobsing/vitess
// BuildPlan parses query and builds a Plan for it against schema.
//
// SELECT, INSERT, UPDATE and DELETE statements are handed to their dedicated
// builders and the resulting plan gets Original set to query. A parse failure
// yields a NoPlan whose Reason is the parser's error text; UNION, SET, DDL
// and "other" statements yield a NoPlan with Reason "too complex". Any other
// statement type panics (presumably the parser never produces one).
func BuildPlan(query string, schema *Schema) *Plan {
	parsed, err := sqlparser.Parse(query)
	if err != nil {
		return &Plan{ID: NoPlan, Reason: err.Error(), Original: query}
	}

	var built *Plan
	switch stmt := parsed.(type) {
	case *sqlparser.Select:
		built = buildSelectPlan(stmt, schema)
	case *sqlparser.Insert:
		built = buildInsertPlan(stmt, schema)
	case *sqlparser.Update:
		built = buildUpdatePlan(stmt, schema)
	case *sqlparser.Delete:
		built = buildDeletePlan(stmt, schema)
	case *sqlparser.Union, *sqlparser.Set, *sqlparser.DDL, *sqlparser.Other:
		return &Plan{ID: NoPlan, Reason: "too complex", Original: query}
	default:
		panic("unexpected")
	}

	built.Original = query
	return built
}
예제 #4
0
파일: builder.go 프로젝트: CowLeo/vitess
// Build builds a plan for a query based on the specified vschema.
// It's the main entry point for this package.
//
// SELECT, INSERT, UPDATE and DELETE statements are handed to their dedicated
// builders, whose instructions become plan.Instructions. UNION, SET, DDL and
// "other" statements are rejected with an "unsupported construct" error, and
// any other statement type panics.
func Build(query string, vschema VSchema) (*engine.Plan, error) {
	parsed, err := sqlparser.Parse(query)
	if err != nil {
		return nil, err
	}

	plan := &engine.Plan{Original: query}
	var buildErr error
	switch stmt := parsed.(type) {
	case *sqlparser.Select:
		plan.Instructions, buildErr = buildSelectPlan(stmt, vschema)
	case *sqlparser.Insert:
		plan.Instructions, buildErr = buildInsertPlan(stmt, vschema)
	case *sqlparser.Update:
		plan.Instructions, buildErr = buildUpdatePlan(stmt, vschema)
	case *sqlparser.Delete:
		plan.Instructions, buildErr = buildDeletePlan(stmt, vschema)
	case *sqlparser.Union, *sqlparser.Set, *sqlparser.DDL, *sqlparser.Other:
		return nil, errors.New("unsupported construct")
	default:
		panic("unexpected statement type")
	}

	if buildErr != nil {
		return nil, buildErr
	}
	return plan, nil
}
예제 #5
0
// handleErrEvent reacts to a stream event carrying SQL that could not be
// handled normally. INSERTs are ignored. For an UPDATE or DELETE whose table
// qualifier is empty or equal to rci.dbname, the statement is treated as a
// DDL ("alter table <name> alter") on that table, which conservatively
// invalidates all of its rows. Parse failures and any other statement type
// are logged and counted under internalErrors["Invalidation"].
func (rci *RowcacheInvalidator) handleErrEvent(event *blproto.StreamEvent) {
	sql := event.Sql
	parsed, err := sqlparser.Parse(sql)
	if err != nil {
		log.Errorf("Error parsing: %s: %v", sql, err)
		internalErrors.Add("Invalidation", 1)
		return
	}

	var target *sqlparser.TableName
	switch stmt := parsed.(type) {
	case *sqlparser.Insert:
		// Inserts don't affect rowcache.
		return
	case *sqlparser.Update:
		target = stmt.Table
	case *sqlparser.Delete:
		target = stmt.Table
	default:
		log.Errorf("Unrecognized: %s", sql)
		internalErrors.Add("Invalidation", 1)
		return
	}

	// Statements qualified with another database are left alone.
	if target.Qualifier != nil && string(target.Qualifier) != rci.dbname {
		return
	}

	log.Warningf("Treating %s as DDL for table %s", sql, target.Name)
	rci.qe.InvalidateForDDL(&proto.DDLInvalidate{DDL: fmt.Sprintf("alter table %s alter", target.Name)})
}
예제 #6
0
// newSplitParams validates and initializes all the fields except splitCount and
// numRowsPerQueryPart. It contains the common code for the constructors above.
//
// sql must be a SELECT from a single simple table (no DISTINCT, GROUP BY,
// HAVING, ORDER BY, LIMIT or lock clause) that is present in schema. When
// splitColumns is empty the table's primary-key columns are used instead.
// The split columns must form a prefix of the columns of some index, and
// their types are looked up in the table schema and stored in the result.
//
// An error is returned for every validation failure, including a table that
// has no primary key when no split columns were given.
func newSplitParams(sql string, bindVariables map[string]interface{}, splitColumns []string,
	schema map[string]*schema.Table) (*SplitParams, error) {

	statement, err := sqlparser.Parse(sql)
	if err != nil {
		return nil, fmt.Errorf("splitquery: failed parsing query: '%v', err: '%v'", sql, err)
	}
	selectAST, ok := statement.(*sqlparser.Select)
	if !ok {
		return nil, fmt.Errorf("splitquery: not a select statement")
	}
	if selectAST.Distinct != "" || selectAST.GroupBy != nil ||
		selectAST.Having != nil || len(selectAST.From) != 1 ||
		selectAST.OrderBy != nil || selectAST.Limit != nil ||
		selectAST.Lock != "" {
		return nil, fmt.Errorf("splitquery: unsupported query: %v", sql)
	}
	aliasedTableExpr, ok := selectAST.From[0].(*sqlparser.AliasedTableExpr)
	if !ok {
		return nil, fmt.Errorf("splitquery: unsupported FROM clause in query: %v", sql)
	}
	tableName := sqlparser.GetTableName(aliasedTableExpr.Expr)
	if tableName == "" {
		return nil, fmt.Errorf("splitquery: unsupported FROM clause in query"+
			" (must be a simple table expression): %v", sql)
	}
	// A nil entry is treated the same as a missing one.
	tableSchema := schema[tableName]
	if tableSchema == nil {
		return nil, fmt.Errorf("splitquery: can't find table in schema")
	}
	if len(splitColumns) == 0 {
		splitColumns = getPrimaryKeyColumns(tableSchema)
		if len(splitColumns) == 0 {
			// A table without a primary key is a property of the user's
			// schema, not a programming error: report it instead of
			// panicking and taking the whole process down.
			return nil, fmt.Errorf("splitquery: no split-columns given and table %v has"+
				" no primary key. Sql: %v", tableName, sql)
		}
	}
	if !areColumnsAPrefixOfAnIndex(splitColumns, tableSchema) {
		return nil, fmt.Errorf("splitquery: split-columns must be a prefix of the columns composing"+
			" an index. Sql: %v, split-columns: %v", sql, splitColumns)
	}
	// Get the split-columns types.
	splitColumnTypes := make([]querypb.Type, 0, len(splitColumns))
	for _, splitColumn := range splitColumns {
		i := tableSchema.FindColumn(splitColumn)
		if i == -1 {
			return nil, fmt.Errorf("can't find split-column: %v", splitColumn)
		}
		splitColumnTypes = append(splitColumnTypes, tableSchema.Columns[i].Type)
	}

	return &SplitParams{
		sql:              sql,
		bindVariables:    bindVariables,
		splitColumns:     splitColumns,
		splitColumnTypes: splitColumnTypes,
		selectAST:        selectAST,
		splitTableSchema: tableSchema,
	}, nil
}
예제 #7
0
파일: plan.go 프로젝트: pranjal5215/vitess
// GetStreamExecPlan generates a ExecPlan given a sql query and a TableGetter.
//
// The plan always has PlanId PLAN_SELECT_STREAM. Only SELECT and UNION
// statements are accepted; a SELECT with a lock clause is rejected, and a
// SELECT from a named table gets that table's information attached via
// setTableInfo. Any other statement type is rejected with an error quoting
// the statement.
func GetStreamExecPlan(sql string, getTable TableGetter) (plan *ExecPlan, err error) {
	statement, err := sqlparser.Parse(sql)
	if err != nil {
		return nil, err
	}

	plan = &ExecPlan{
		PlanId:    PLAN_SELECT_STREAM,
		FullQuery: GenerateFullQuery(statement),
	}

	if _, isUnion := statement.(*sqlparser.Union); isUnion {
		return plan, nil
	}

	sel, isSelect := statement.(*sqlparser.Select)
	if !isSelect {
		return nil, fmt.Errorf("'%v' not allowed for streaming", sqlparser.String(statement))
	}
	if sel.Lock != "" {
		return nil, errors.New("select with lock disallowed with streaming")
	}
	if tableName, _ := analyzeFrom(sel.From); tableName != "" {
		plan.setTableInfo(tableName, getTable)
	}
	return plan, nil
}
예제 #8
0
파일: router.go 프로젝트: chinna1986/vitess
// buildPlan parses sql and returns whatever getRoutingPlan produces for the
// parsed statement; a parse failure is returned as the error.
func buildPlan(sql string) (plan *RoutingPlan, err error) {
	parsed, parseErr := sqlparser.Parse(sql)
	if parseErr != nil {
		return nil, parseErr
	}
	return getRoutingPlan(parsed)
}
예제 #9
0
// TestSelect2 runs every case in select2_cases.txt through buildSelectPlan2,
// using the schema loaded from schema_test.json, and compares the JSON
// encoding of each resulting plan with the case's expected output.
func TestSelect2(t *testing.T) {
	schema, err := LoadFile(locateFile("schema_test.json"))
	if err != nil {
		t.Fatal(err)
	}
	for tcase := range iterateExecFile("select2_cases.txt") {
		statement, err := sqlparser.Parse(tcase.input)
		if err != nil {
			t.Error(err)
			continue
		}
		sel, ok := statement.(*sqlparser.Select)
		if !ok {
			t.Errorf("unexpected type: %T", statement)
			continue
		}
		plan, _, err := buildSelectPlan2(sel, schema)
		if err != nil {
			t.Error(err)
			continue
		}
		bout, err := json.Marshal(plan)
		if err != nil {
			// Report through the testing framework instead of panicking: a
			// panic aborts the whole test binary and hides which case failed.
			t.Errorf("Line:%v\nError marshalling %v: %v", tcase.lineno, plan, err)
			continue
		}
		out := string(bout)
		if out != tcase.output {
			t.Errorf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, out)
		}
		// Comment these line out to see the expected outputs
		// bout, err = json.MarshalIndent(plan, "", "  ")
		// fmt.Printf("%s\n", bout)
	}
}
예제 #10
0
파일: plan.go 프로젝트: pranjal5215/vitess
// GetExecPlan generates a ExecPlan given a sql query and a TableGetter.
//
// The query is parsed and then analyzed by analyzeSQL; errors from either
// step are returned unchanged. A plan that comes out as PLAN_PASS_DML is
// logged as a warning together with the query text.
func GetExecPlan(sql string, getTable TableGetter) (plan *ExecPlan, err error) {
	parsed, parseErr := sqlparser.Parse(sql)
	if parseErr != nil {
		return nil, parseErr
	}
	analyzed, analyzeErr := analyzeSQL(parsed, getTable)
	if analyzeErr != nil {
		return nil, analyzeErr
	}
	if analyzed.PlanId == PLAN_PASS_DML {
		log.Warningf("PASS_DML: %s", sql)
	}
	return analyzed, nil
}
예제 #11
0
// parseDDLs parses each statement in sqls and returns the resulting DDL
// nodes in the same order. It stops at the first statement that fails to
// parse or is not a DDL and returns an error quoting that statement.
func parseDDLs(sqls []string) ([]*sqlparser.DDL, error) {
	ddls := make([]*sqlparser.DDL, 0, len(sqls))
	for _, sql := range sqls {
		parsed, err := sqlparser.Parse(sql)
		if err != nil {
			return nil, fmt.Errorf("failed to parse sql: %s, got error: %v", sql, err)
		}
		ddl, isDDL := parsed.(*sqlparser.DDL)
		if !isDDL {
			return nil, fmt.Errorf("schema change works for DDLs only, but get non DDL statement: %s", sql)
		}
		ddls = append(ddls, ddl)
	}
	return ddls, nil
}
예제 #12
0
파일: ddl.go 프로젝트: chinna1986/vitess
// DDLParse parses sql and, when it is a DDL statement, returns a DDLPlan
// holding its action, table name and new name. For anything else, including
// SQL that fails to parse, it returns a DDLPlan with an empty Action.
func DDLParse(sql string) (plan *DDLPlan) {
	statement, err := sqlparser.Parse(sql)
	// A comma-ok assertion on a nil interface is safe and just yields false.
	ddl, isDDL := statement.(*sqlparser.DDL)
	if err != nil || !isDDL {
		return &DDLPlan{Action: ""}
	}
	return &DDLPlan{
		Action:    ddl.Action,
		TableName: string(ddl.Table),
		NewName:   string(ddl.NewName),
	}
}
예제 #13
0
// analyze processes one line of input: lines starting with any prefix in
// ignores are skipped; otherwise trailing '\n' characters are trimmed, the
// text is parsed as SQL, re-rendered through a TrackedBuffer built with
// FormatWithBind (after resetting bindIndex to 0), and the resulting query
// text is passed to addQuery. Unparseable lines are logged and skipped.
func analyze(line []byte) {
	for _, ignore := range ignores {
		if bytes.HasPrefix(line, ignore) {
			return
		}
	}
	dml := string(bytes.TrimRight(line, "\n"))
	ast, err := sqlparser.Parse(dml)
	if err != nil {
		// Include the parser error; without it the log gives no hint why
		// the statement was skipped.
		log.Errorf("Error parsing %s: %v", dml, err)
		return
	}
	// NOTE(review): bindIndex looks like package-level state used by
	// FormatWithBind; resetting it here assumes analyze is not called
	// concurrently — confirm.
	bindIndex = 0
	buf := sqlparser.NewTrackedBuffer(FormatWithBind)
	buf.Myprintf("%v", ast)
	addQuery(buf.ParsedQuery().Query)
}
예제 #14
0
// TestSplitParams builds SplitParams for every entry in splitParamsTestCases
// (via NewSplitParamsGivenNumRowsPerQueryPart when NumRowsPerQueryPart is
// set, NewSplitParamsGivenSplitCount otherwise) and checks that either the
// construction fails with an error matching ExpectedErrorRegex, or it
// succeeds and yields ExpectedSplitParams with sql, bindVariables and
// selectAST filled in from the test case.
func TestSplitParams(t *testing.T) {
	for _, testCase := range splitParamsTestCases {
		var splitParams *SplitParams
		var err error
		if testCase.NumRowsPerQueryPart != 0 {
			splitParams, err = NewSplitParamsGivenNumRowsPerQueryPart(
				testCase.SQL,
				testCase.BindVariables,
				testCase.SplitColumnNames,
				testCase.NumRowsPerQueryPart,
				testSchema)
		} else {
			splitParams, err = NewSplitParamsGivenSplitCount(
				testCase.SQL,
				testCase.BindVariables,
				testCase.SplitColumnNames,
				testCase.SplitCount,
				testSchema)
		}
		if testCase.ExpectedErrorRegex != nil {
			// Check for a nil error first: calling err.Error() on it would
			// panic and abort the whole test binary instead of reporting
			// this case.
			if err == nil {
				t.Errorf("Testcase: %+v, want error matching: %+v, got: nil",
					testCase, testCase.ExpectedErrorRegex)
			} else if !testCase.ExpectedErrorRegex.MatchString(err.Error()) {
				t.Errorf("Testcase: %+v, want: %+v, got: %+v", testCase, testCase.ExpectedErrorRegex, err)
			}
			continue
		}
		// Here, we don't expect an error.
		if err != nil {
			t.Errorf("TestCase: %+v, want: %+v, got: %+v", testCase, nil, err)
			continue
		}
		if splitParams == nil {
			t.Errorf("TestCase: %+v, got nil splitParams", testCase)
			continue
		}
		expectedSplitParams := testCase.ExpectedSplitParams
		// We don't require testCaset.ExpectedSplitParams to specify common expected fields like 'sql',
		// so we compute them here and store them in 'expectedSplitParams'.
		expectedSplitParams.sql = testCase.SQL
		expectedSplitParams.bindVariables = testCase.BindVariables
		statement, err := sqlparser.Parse(testCase.SQL)
		if err != nil {
			t.Errorf("TestCase: %+v, failed to parse SQL: %v", testCase, err)
			continue
		}
		selectAST, ok := statement.(*sqlparser.Select)
		if !ok {
			t.Errorf("TestCase: %+v, not a select statement: %T", testCase, statement)
			continue
		}
		expectedSplitParams.selectAST = selectAST
		if !reflect.DeepEqual(&expectedSplitParams, splitParams) {
			t.Errorf("TestCase: %+v, want: %+v, got: %+v", testCase, expectedSplitParams, *splitParams)
		}
	}
}
예제 #15
0
// Validate validates a list of sql statements.
//
// It fails immediately when the executor is closed, or when any statement
// does not parse or is not a DDL. Otherwise the parsed DDLs, in input order,
// are passed to exec.detectBigSchemaChanges and its result is returned.
func (exec *TabletExecutor) Validate(ctx context.Context, sqls []string) error {
	if exec.isClosed {
		return fmt.Errorf("executor is closed")
	}
	ddls := make([]*sqlparser.DDL, 0, len(sqls))
	for _, sql := range sqls {
		parsed, err := sqlparser.Parse(sql)
		if err != nil {
			return fmt.Errorf("failed to parse sql: %s, got error: %v", sql, err)
		}
		ddl, isDDL := parsed.(*sqlparser.DDL)
		if !isDDL {
			return fmt.Errorf("schema change works for DDLs only, but get non DDL statement: %s", sql)
		}
		ddls = append(ddls, ddl)
	}
	return exec.detectBigSchemaChanges(ctx, ddls)
}
예제 #16
0
// InvalidateForUnrecognized performs best effort rowcache invalidation
// for unrecognized statements.
//
// INSERTs are ignored. For an UPDATE or DELETE on a table of this database
// (no qualifier, or a qualifier equal to qe.dbconfig.DbName) whose CacheType
// is not CACHE_NONE, the statement is treated as a DDL: CreateOrUpdateTable
// is called for the table, conservatively invalidating all of its rows.
// Parse failures, other statement types and tables missing from the schema
// are logged and counted under internalErrors["Invalidation"].
func (qe *QueryEngine) InvalidateForUnrecognized(sql string) {
	reportError := func(format string, args ...interface{}) {
		log.Errorf(format, args...)
		internalErrors.Add("Invalidation", 1)
	}

	parsed, err := sqlparser.Parse(sql)
	if err != nil {
		reportError("Error: %v: %s", err, sql)
		return
	}

	var target *sqlparser.TableName
	switch stmt := parsed.(type) {
	case *sqlparser.Insert:
		// Inserts don't affect rowcache.
		return
	case *sqlparser.Update:
		target = stmt.Table
	case *sqlparser.Delete:
		target = stmt.Table
	default:
		reportError("Unrecognized: %s", sql)
		return
	}

	// Statements qualified with another database are not ours to invalidate.
	sameDB := target.Qualifier == nil || string(target.Qualifier) == qe.dbconfig.DbName
	if !sameDB {
		return
	}

	tableName := string(target.Name)
	tableInfo := qe.schemaInfo.GetTable(tableName)
	switch {
	case tableInfo == nil:
		reportError("Table %s not found: %s", tableName, sql)
		return
	case tableInfo.CacheType == schema.CACHE_NONE:
		// Uncached table: nothing to invalidate.
		return
	}

	log.Warningf("Treating '%s' as DDL for table %s", sql, tableName)
	qe.schemaInfo.CreateOrUpdateTable(tableName)
}
예제 #17
0
// validateQuery checks that qs.sql is a plain single-table SELECT with no
// DISTINCT, GROUP BY, HAVING, ORDER BY, LIMIT or lock clause (a join shows up
// either as more than one FROM entry or as a FROM entry that is not an
// aliased table expression), that the table is present in qs.schemaInfo, and
// that it has at least one primary-key column.
//
// If qs.splitColumn is already set, it must appear (at any position) in at
// least one of the table's indexes; otherwise qs.splitColumn defaults to the
// first primary-key column. Along the way the parsed statement is recorded in
// qs.sel (nil if it is not a SELECT) and the table name in qs.tableName.
func (qs *QuerySplitter) validateQuery() error {
	parsed, err := sqlparser.Parse(qs.sql)
	if err != nil {
		return err
	}

	sel, isSelect := parsed.(*sqlparser.Select)
	qs.sel = sel
	if !isSelect {
		return fmt.Errorf("not a select statement")
	}

	hasUnsupportedClause := sel.Distinct != "" ||
		sel.GroupBy != nil ||
		sel.Having != nil ||
		sel.OrderBy != nil ||
		sel.Limit != nil ||
		sel.Lock != ""
	if hasUnsupportedClause || len(sel.From) != 1 {
		return fmt.Errorf("unsupported query")
	}

	aliased, isAliased := sel.From[0].(*sqlparser.AliasedTableExpr)
	if !isAliased {
		return fmt.Errorf("unsupported query")
	}

	qs.tableName = sqlparser.GetTableName(aliased.Expr)
	if qs.tableName == "" {
		return fmt.Errorf("not a simple table expression")
	}

	tableInfo, found := qs.schemaInfo.tables[qs.tableName]
	switch {
	case !found:
		return fmt.Errorf("can't find table in schema")
	case len(tableInfo.PKColumns) == 0:
		return fmt.Errorf("no primary keys")
	}

	if qs.splitColumn == "" {
		qs.splitColumn = tableInfo.GetPKColumn(0).Name
		return nil
	}
	for _, index := range tableInfo.Indexes {
		for _, column := range index.Columns {
			if column == qs.splitColumn {
				return nil
			}
		}
	}
	return fmt.Errorf("split column is not indexed or does not exist in table schema, SplitColumn: %s, TableInfo.Table: %v", qs.splitColumn, tableInfo.Table)
}
예제 #18
0
// handleUnrecognizedEvent performs best effort rowcache invalidation for a
// replicated statement that was not otherwise recognized.
//
// INSERTs are ignored. For an UPDATE or DELETE on a table of this database
// (empty qualifier, or one equal to rci.qe.dbconfigs.App.DbName) whose
// CacheType is not CacheNone, the statement is treated as a DDL:
// CreateOrUpdateTable is called for the table, conservatively invalidating
// all of its rows. Parse failures, other statement types and tables missing
// from the schema are logged and counted under InternalErrors["Invalidation"].
func (rci *RowcacheInvalidator) handleUnrecognizedEvent(sql string) {
	reportError := func(format string, args ...interface{}) {
		log.Errorf(format, args...)
		rci.qe.queryServiceStats.InternalErrors.Add("Invalidation", 1)
	}

	parsed, err := sqlparser.Parse(sql)
	if err != nil {
		reportError("Error: %v: %s", err, sql)
		return
	}

	var target *sqlparser.TableName
	switch stmt := parsed.(type) {
	case *sqlparser.Insert:
		// Inserts don't affect rowcache.
		return
	case *sqlparser.Update:
		target = stmt.Table
	case *sqlparser.Delete:
		target = stmt.Table
	default:
		reportError("Unrecognized: %s", sql)
		return
	}

	// Statements qualified with another database are not ours to invalidate.
	sameDB := target.Qualifier == "" || string(target.Qualifier) == rci.qe.dbconfigs.App.DbName
	if !sameDB {
		return
	}

	tableName := string(target.Name)
	tableInfo := rci.qe.schemaInfo.GetTable(tableName)
	switch {
	case tableInfo == nil:
		reportError("Table %s not found: %s", tableName, sql)
		return
	case tableInfo.CacheType == schema.CacheNone:
		// Uncached table: nothing to invalidate.
		return
	}

	log.Warningf("Treating '%s' as DDL for table %s", sql, tableName)
	rci.qe.schemaInfo.CreateOrUpdateTable(context.Background(), tableName)
}
예제 #19
0
// SQLExecute is part of the SQLExecuter interface.
//
// The query is parsed (bind variables are substituted into a parsed query)
// and run on se.conn through se.queryExecutor.fullFetch. fullFetch changes
// the bind-variable map it is given, so it receives a clone of
// bindVariables and the caller's map is left intact.
//
// TODO(erez): Add an SQLExecute() to SQLExecuterInterface that gets a parsed query so that
// we don't have to parse the query again here.
func (se *splitQuerySQLExecuter) SQLExecute(
	sql string, bindVariables map[string]interface{},
) (*sqltypes.Result, error) {
	stmt, parseErr := sqlparser.Parse(sql)
	if parseErr != nil {
		return nil, fmt.Errorf("splitQuerySQLExecuter: parsing sql failed with: %v", parseErr)
	}
	query := sqlparser.GenerateParsedQuery(stmt)
	clonedVars := utils.CloneBindVariables(bindVariables)
	return se.queryExecutor.fullFetch(se.conn, query, clonedVars, nil /* buildStreamComment */)
}
예제 #20
0
// Validate validates a list of sql statements.
//
// It fails immediately when the executor is closed, or when any statement
// does not parse or is not a DDL. The parsed DDLs are then checked by
// exec.detectBigSchemaChanges. If that reports a big schema change and
// exec.allowBigSchemaChange is set, a warning is logged and nil is returned
// (any error from the check is dropped); otherwise the check's error is
// returned as-is.
func (exec *TabletExecutor) Validate(ctx context.Context, sqls []string) error {
	if exec.isClosed {
		return fmt.Errorf("executor is closed")
	}
	ddls := make([]*sqlparser.DDL, 0, len(sqls))
	for _, sql := range sqls {
		parsed, parseErr := sqlparser.Parse(sql)
		if parseErr != nil {
			return fmt.Errorf("failed to parse sql: %s, got error: %v", sql, parseErr)
		}
		ddl, isDDL := parsed.(*sqlparser.DDL)
		if !isDDL {
			return fmt.Errorf("schema change works for DDLs only, but get non DDL statement: %s", sql)
		}
		ddls = append(ddls, ddl)
	}

	isBig, checkErr := exec.detectBigSchemaChanges(ctx, ddls)
	if !isBig || !exec.allowBigSchemaChange {
		return checkErr
	}
	log.Warning("Processing big schema change. This may cause visible MySQL downtime.")
	return nil
}
예제 #21
0
// Normalize classifies a single SQL statement and normalizes it.
//
// Empty input, and input whose first byte is '#' or '/', yields "" with no
// bind variables. Input whose lowercased text starts with "begin", "commit"
// or "rollback" yields that keyword; a "use" prefix yields "" (note this is a
// plain prefix test, so any text starting with those letters matches).
// Everything else is parsed and passed to NormalizeTree, whose result is
// returned; a parse failure is returned as the error.
func Normalize(sql string) (string, map[string]interface{}, error) {
	if sql == "" || sql[0] == '#' || sql[0] == '/' {
		return "", nil, nil
	}

	lstr := strings.ToLower(sql)
	switch {
	case strings.HasPrefix(lstr, "begin"):
		return "begin", nil, nil
	case strings.HasPrefix(lstr, "commit"):
		return "commit", nil, nil
	case strings.HasPrefix(lstr, "rollback"):
		return "rollback", nil, nil
	case strings.HasPrefix(lstr, "use"):
		return "", nil, nil
	}

	tree, err := sqlparser.Parse(sql)
	if err != nil {
		return "", nil, err
	}
	return NormalizeTree(tree)
}
예제 #22
0
// SQLExecute parses sql, takes a connection from the query engine's
// connection pool (released with conn.Recycle when done), and runs the parsed
// query through se.queryExecutor.fullFetch. fullFetch changes the
// bind-variable map it is given, so it receives a clone of bindVariables.
//
// TODO(erez): Add an SQLExecute() to SQLExecuterInterface that gets a parsed query so that
// we don't have to parse the query again here.
func (se *splitQuerySQLExecuter) SQLExecute(
	sql string, bindVariables map[string]interface{}) (*sqltypes.Result, error) {

	stmt, parseErr := sqlparser.Parse(sql)
	if parseErr != nil {
		return nil, fmt.Errorf("splitQuerySQLExecuter: parsing sql failed with: %v", parseErr)
	}
	query := sqlparser.GenerateParsedQuery(stmt)

	conn, connErr := se.queryExecutor.getConn(se.queryExecutor.qe.connPool)
	if connErr != nil {
		return nil, connErr
	}
	defer conn.Recycle()

	// TODO(erez): Find out what 'buildStreamComment' is, and see if we need to use it or comment why
	// we don't.
	clonedVars := utils.CloneBindVariables(bindVariables)
	return se.queryExecutor.fullFetch(conn, query, clonedVars, nil /* buildStreamComment */)
}
예제 #23
0
// newSplitParams validates and initializes all the fields except splitCount and
// numRowsPerQueryPart. It contains the common code for the constructors above.
//
// sql must be a SELECT from a single simple table (no DISTINCT, GROUP BY,
// HAVING, ORDER BY, LIMIT or lock clause) that is present in schemaMap. When
// splitColumnNames is empty the table's primary-key columns are used;
// otherwise the named columns are looked up in the table schema and must
// form a prefix of the columns of some index. Every validation failure,
// including a table without a primary key when no split columns were given,
// is reported as an error.
func newSplitParams(
	sql string,
	bindVariables map[string]interface{},
	splitColumnNames []sqlparser.ColIdent,
	schemaMap map[string]*schema.Table,
) (*SplitParams, error) {
	statement, err := sqlparser.Parse(sql)
	if err != nil {
		return nil, fmt.Errorf("failed parsing query: '%v', err: '%v'", sql, err)
	}
	selectAST, ok := statement.(*sqlparser.Select)
	if !ok {
		return nil, fmt.Errorf("not a select statement")
	}
	if selectAST.Distinct != "" || selectAST.GroupBy != nil ||
		selectAST.Having != nil || len(selectAST.From) != 1 ||
		selectAST.OrderBy != nil || selectAST.Limit != nil ||
		selectAST.Lock != "" {
		return nil, fmt.Errorf("unsupported query: %v", sql)
	}
	aliasedTableExpr, ok := selectAST.From[0].(*sqlparser.AliasedTableExpr)
	if !ok {
		return nil, fmt.Errorf("unsupported FROM clause in query: %v", sql)
	}
	tableName := sqlparser.GetTableName(aliasedTableExpr.Expr)
	if tableName == "" {
		return nil, fmt.Errorf("unsupported FROM clause in query"+
			" (must be a simple table expression): %v", sql)
	}
	// A nil entry is treated the same as a missing one.
	tableSchema := schemaMap[tableName]
	if tableSchema == nil {
		return nil, fmt.Errorf("can't find table in schema")
	}

	// Get the schema.TableColumn representation of each splitColumnName.
	var splitColumns []*schema.TableColumn
	if len(splitColumnNames) == 0 {
		splitColumns = getPrimaryKeyColumns(tableSchema)
	} else {
		splitColumns, err = findSplitColumnsInSchema(splitColumnNames, tableSchema)
		if err != nil {
			return nil, err
		}
		if !areColumnsAPrefixOfAnIndex(splitColumns, tableSchema) {
			// Report the requested names: formatting splitColumns
			// ([]*schema.TableColumn) with %v would only print pointers.
			return nil, fmt.Errorf("split-columns must be a prefix of the columns composing"+
				" an index. Sql: %v, split-columns: %v", sql, splitColumnNames)
		}
	}

	if len(splitColumns) == 0 {
		// Presumably reachable only when no split columns were requested and
		// the table has no primary key. That is a property of the user's
		// schema, not a programming error, so report it instead of panicking.
		return nil, fmt.Errorf("no split-columns given and table %v has no primary key. Sql: %v",
			tableName, sql)
	}

	return &SplitParams{
		sql:              sql,
		bindVariables:    bindVariables,
		splitColumns:     splitColumns,
		selectAST:        selectAST,
		splitTableSchema: tableSchema,
	}, nil
}
예제 #24
0
func TestGetWhereClause(t *testing.T) {
	splitter := &QuerySplitter{}
	sql := "select * from test_table where count > :count"
	statement, _ := sqlparser.Parse(sql)
	splitter.sel, _ = statement.(*sqlparser.Select)
	splitter.splitColumn = "id"

	// no boundary case, start = end = nil, should not change the where clause
	nilValue := sqltypes.Value{}
	clause := splitter.getWhereClause(nilValue, nilValue)
	want := " where count > :count"
	got := sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want)
	}

	// Set lower bound, should add the lower bound condition to where clause
	start, _ := sqltypes.BuildValue(20)
	clause = splitter.getWhereClause(start, nilValue)
	want = " where count > :count and id >= 20"
	got = sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
	}

	// Set upper bound, should add the upper bound condition to where clause
	end, _ := sqltypes.BuildValue(40)
	clause = splitter.getWhereClause(nilValue, end)
	want = " where count > :count and id < 40"
	got = sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
	}

	// Set both bounds, should add two conditions to where clause
	clause = splitter.getWhereClause(start, end)
	want = " where count > :count and id >= 20 and id < 40"
	got = sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
	}

	// Original query with no where clause
	sql = "select * from test_table"
	statement, _ = sqlparser.Parse(sql)
	splitter.sel, _ = statement.(*sqlparser.Select)

	// no boundary case, start = end = nil should return no where clause
	clause = splitter.getWhereClause(nilValue, nilValue)
	want = ""
	got = sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want)
	}

	// Set both bounds, should add two conditions to where clause
	clause = splitter.getWhereClause(start, end)
	want = " where id >= 20 and id < 40"
	got = sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
	}
}
예제 #25
0
func TestGetWhereClause(t *testing.T) {
	splitter := &QuerySplitter{}
	sql := "select * from test_table where count > :count"
	statement, _ := sqlparser.Parse(sql)
	splitter.sel, _ = statement.(*sqlparser.Select)
	splitter.splitColumn = "id"
	bindVars := make(map[string]interface{})
	// no boundary case, start = end = nil, should not change the where clause
	nilValue := sqltypes.Value{}
	clause := splitter.getWhereClause(splitter.sel.Where, bindVars, nilValue, nilValue)
	want := " where count > :count"
	got := sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want)
	}

	// Set lower bound, should add the lower bound condition to where clause
	startVal := int64(20)
	start, _ := sqltypes.BuildValue(startVal)
	bindVars = make(map[string]interface{})
	bindVars[":count"] = 300
	clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, nilValue)
	want = " where (count > :count) and (id >= :" + startBindVarName + ")"
	got = sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
	}
	v, ok := bindVars[startBindVarName]
	if !ok {
		t.Fatalf("bind var: %s not found got: nil, want: %v", startBindVarName, startVal)
	}
	if v != startVal {
		t.Fatalf("bind var: %s not found got: %v, want: %v", startBindVarName, v, startVal)
	}
	// Set upper bound, should add the upper bound condition to where clause
	endVal := int64(40)
	end, _ := sqltypes.BuildValue(endVal)
	bindVars = make(map[string]interface{})
	clause = splitter.getWhereClause(splitter.sel.Where, bindVars, nilValue, end)
	want = " where (count > :count) and (id < :" + endBindVarName + ")"
	got = sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
	}
	v, ok = bindVars[endBindVarName]
	if !ok {
		t.Fatalf("bind var: %s not found got: nil, want: %v", endBindVarName, endVal)
	}
	if v != endVal {
		t.Fatalf("bind var: %s not found got: %v, want: %v", endBindVarName, v, endVal)
	}

	// Set both bounds, should add two conditions to where clause
	bindVars = make(map[string]interface{})
	clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, end)
	want = fmt.Sprintf(" where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName)
	got = sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
	}

	// Original query with no where clause
	sql = "select * from test_table"
	statement, _ = sqlparser.Parse(sql)
	splitter.sel, _ = statement.(*sqlparser.Select)
	bindVars = make(map[string]interface{})
	// no boundary case, start = end = nil should return no where clause
	clause = splitter.getWhereClause(splitter.sel.Where, bindVars, nilValue, nilValue)
	want = ""
	got = sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want)
	}
	bindVars = make(map[string]interface{})
	// Set both bounds, should add two conditions to where clause
	clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, end)
	want = fmt.Sprintf(" where id >= :%s and id < :%s", startBindVarName, endBindVarName)
	got = sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
	}
	v, ok = bindVars[startBindVarName]
	if !ok {
		t.Fatalf("bind var: %s not found got: nil, want: %v", startBindVarName, startVal)
	}
	if v != startVal {
		t.Fatalf("bind var: %s not found got: %v, want: %v", startBindVarName, v, startVal)
	}
	v, ok = bindVars[endBindVarName]
	if !ok {
		t.Fatalf("bind var: %s not found got: nil, want: %v", endBindVarName, endVal)
	}
	if v != endVal {
		t.Fatalf("bind var: %s not found got: %v, want: %v", endBindVarName, v, endVal)
	}
}
예제 #26
0
파일: main.go 프로젝트: roger2000hk/go-fuzz
// Fuzz is a go-fuzz entry point for sqlparser. It checks that:
//   - a failed parse returns a nil statement;
//   - formatting a parsed statement with sqlparser.String and parsing the
//     result again yields an AST equal (per fuzz.DeepEqual) to the original;
//   - the AST helper functions can be applied to the parts of a SELECT, and
//     bind-variable query generation can be run, without crashing.
//
// It returns 0 for input that does not parse and 1 otherwise.
func Fuzz(data []byte) int {
	stmt, err := sqlparser.Parse(string(data))
	if err != nil {
		if stmt != nil {
			panic("stmt is not nil on error")
		}
		return 0
	}

	// Round trip: format the statement and re-parse it.
	data1 := sqlparser.String(stmt)
	stmt1, err := sqlparser.Parse(data1)
	if err != nil {
		fmt.Printf("data0: %q\n", data)
		fmt.Printf("data1: %q\n", data1)
		panic(err)
	}
	if !fuzz.DeepEqual(stmt, stmt1) {
		fmt.Printf("data0: %q\n", data)
		fmt.Printf("data1: %q\n", data1)
		panic("not equal")
	}

	if sel, ok := stmt.(*sqlparser.Select); ok {
		var nodes []sqlparser.SQLNode
		for _, x := range sel.From {
			nodes = append(nodes, x)
		}
		for _, x := range sel.SelectExprs {
			nodes = append(nodes, x)
		}
		for _, x := range sel.GroupBy {
			nodes = append(nodes, x)
		}
		for _, x := range sel.OrderBy {
			nodes = append(nodes, x)
		}
		// Append the optional pointer fields only when set: a nil pointer
		// stored in an SQLNode interface is not == nil, so the nil check in
		// the loop below would not skip it.
		if sel.Where != nil {
			nodes = append(nodes, sel.Where)
		}
		if sel.Having != nil {
			nodes = append(nodes, sel.Having)
		}
		if sel.Limit != nil {
			nodes = append(nodes, sel.Limit)
		}
		for _, n := range nodes {
			if n == nil {
				continue
			}
			if x, ok := n.(sqlparser.SimpleTableExpr); ok {
				sqlparser.GetTableName(x)
			}
			if x, ok := n.(sqlparser.Expr); ok {
				sqlparser.GetColName(x)
			}
			if x, ok := n.(sqlparser.ValExpr); ok {
				sqlparser.IsValue(x)
				sqlparser.IsColName(x)
				sqlparser.IsSimpleTuple(x)
				sqlparser.AsInterface(x)
			}
			if x, ok := n.(sqlparser.BoolExpr); ok {
				sqlparser.HasINClause([]sqlparser.BoolExpr{x})
			}
		}
	}

	buf := sqlparser.NewTrackedBuffer(nil)
	stmt.Format(buf)
	pq := buf.ParsedQuery()
	vars := map[string]interface{}{
		"A": 42,
		"B": 123123123,
		"C": "",
		"D": "a",
		"E": "foobar",
		"F": 1.1,
	}
	pq.GenerateQuery(vars)
	return 1
}