func TestBadUpdateExpr(t *testing.T) { var sql string var db string r := newTestDBRule() db = "kingshard" sql = "insert into test1 (id) values (5) on duplicate key update id = 10" stmt, err := sqlparser.Parse(sql) if err != nil { t.Fatal(err.Error()) } if _, err := r.BuildPlan(db, stmt); err == nil { t.Fatal("must err") } sql = "update test1 set id = 10 where id = 5" stmt, err = sqlparser.Parse(sql) if err != nil { t.Fatal(err.Error()) } if _, err := r.BuildPlan(db, stmt); err == nil { t.Fatal("must err") } }
func checkPlan(t *testing.T, sql string, tableIndexs []int, nodeIndexs []int) { r := newTestRouter() db := "kingshard" stmt, err := sqlparser.Parse(sql) if err != nil { t.Fatal(err.Error()) } plan, err := r.BuildPlan(db, stmt) if err != nil { t.Fatal(err.Error()) } if isListEqual(plan.RouteTableIndexs, tableIndexs) == false { err := fmt.Errorf("RouteTableIndexs=%v but tableIndexs=%v", plan.RouteTableIndexs, tableIndexs) t.Fatal(err.Error()) } if isListEqual(plan.RouteNodeIndexs, nodeIndexs) == false { err := fmt.Errorf("RouteNodeIndexs=%v but nodeIndexs=%v", plan.RouteNodeIndexs, nodeIndexs) t.Fatal(err.Error()) } t.Logf("rewritten_sql=%v", plan.RewrittenSqls) }
/*处理query语句*/ func (c *ClientConn) handleQuery(sql string) (err error) { defer func() { if e := recover(); e != nil { err = fmt.Errorf("execute %s error %v", sql, e) golog.OutputSql("Error", "%s", sql) return } golog.OutputSql("INFO", "%s", sql) }() sql = strings.TrimRight(sql, ";") //删除sql语句最后的分号 hasHandled, err := c.handleUnsupport(sql) if err != nil { golog.Error("server", "parse", err.Error(), 0, "hasHandled", hasHandled) return err } if hasHandled { return nil } var stmt sqlparser.Statement stmt, err = sqlparser.Parse(sql) //解析sql语句,得到的stmt是一个interface if err != nil { golog.Error("server", "parse", err.Error(), 0, "hasHandled", hasHandled) return err } switch v := stmt.(type) { case *sqlparser.Select: return c.handleSelect(v, sql, nil) case *sqlparser.Insert: return c.handleExec(stmt, sql, nil) case *sqlparser.Update: return c.handleExec(stmt, sql, nil) case *sqlparser.Delete: return c.handleExec(stmt, sql, nil) case *sqlparser.Replace: return c.handleExec(stmt, sql, nil) case *sqlparser.Set: return c.handleSet(v) case *sqlparser.Begin: return c.handleBegin() case *sqlparser.Commit: return c.handleCommit() case *sqlparser.Rollback: return c.handleRollback() case *sqlparser.SimpleSelect: return c.handleSimpleSelect(sql, v) case *sqlparser.Show: return c.handleShow(sql, v) case *sqlparser.Admin: return c.handleAdmin(v) default: return fmt.Errorf("statement %T not support now", stmt) } return nil }
func (c *ClientConn) handleStmtPrepare(sql string) error { if c.schema == nil { return mysql.NewDefaultError(mysql.ER_NO_DB_ERROR) } s := new(Stmt) sql = strings.TrimRight(sql, ";") var err error s.s, err = sqlparser.Parse(sql) if err != nil { return fmt.Errorf(`parse sql "%s" error`, sql) } s.sql = sql defaultRule := c.schema.rule.DefaultRule n := c.proxy.GetNode(defaultRule.Nodes[0]) co, err := n.GetMasterConn() defer c.closeConn(co, false) if err != nil { return fmt.Errorf("prepare error %s", err) } err = co.UseDB(c.db) if err != nil { //reset the database to null c.db = "" return fmt.Errorf("prepare error %s", err) } t, err := co.Prepare(sql) if err != nil { return fmt.Errorf("prepare error %s", err) } s.params = t.ParamNum() s.columns = t.ColumnNum() s.id = c.stmtId c.stmtId++ if err = c.writePrepare(s); err != nil { return err } s.ResetParams() c.stmts[s.id] = s err = co.ClosePrepare(t.GetId()) if err != nil { return err } return nil }
/*由sql语句获得shard node index*/ func (r *Router) GetShardListIndex(sql string, bindVars map[string]interface{}) (nodes []int, err error) { var stmt sqlparser.Statement stmt, err = sqlparser.Parse(sql) if err != nil { return nil, err } return r.GetStmtShardListIndex(stmt, bindVars) }
func (c *ClientConn) handleStmtPrepare(sql string) error { if c.schema == nil { return NewDefaultError(ER_NO_DB_ERROR) } s := new(Stmt) sql = strings.TrimRight(sql, ";") var err error s.s, err = sqlparser.Parse(sql) if err != nil { return fmt.Errorf(`parse sql "%s" error`, sql) } s.sql = sql defaultRule := c.schema.rule.DefaultRule n := c.proxy.GetNode(defaultRule.Nodes[0]) if co, err := n.GetMasterConn(); err != nil { return fmt.Errorf("prepare error %s", err) } else { defer co.Close() if err = co.UseDB(c.schema.db); err != nil { return fmt.Errorf("parepre error %s", err) } if t, err := co.Prepare(sql); err != nil { return fmt.Errorf("parepre error %s", err) } else { s.params = t.ParamNum() s.columns = t.ColumnNum() } } s.id = c.stmtId c.stmtId++ if err = c.writePrepare(s); err != nil { return err } s.ResetParams() c.stmts[s.id] = s return nil }
func (c *ClientConn) handleStmtPrepare(sql string) error { if c.schema == nil { return NewDefaultError(ER_NO_DB_ERROR) } s := new(Stmt) sql = strings.TrimRight(sql, ";") var err error s.s, err = sqlparser.Parse(sql) if err != nil { return fmt.Errorf(`parse sql "%s" error`, sql) } s.sql = sql var tableName string switch s := s.s.(type) { case *sqlparser.Select: tableName = nstring(s.From) case *sqlparser.Insert: tableName = nstring(s.Table) case *sqlparser.Update: tableName = nstring(s.Table) case *sqlparser.Delete: tableName = nstring(s.Table) case *sqlparser.Replace: tableName = nstring(s.Table) default: return fmt.Errorf(`unsupport prepare sql "%s"`, sql) } r := c.schema.rule.GetRule(tableName) n := c.proxy.GetNode(r.Nodes[0]) if co, err := n.GetMasterConn(); err != nil { return fmt.Errorf("prepare error %s", err) } else { defer co.Close() if err = co.UseDB(c.schema.db); err != nil { return fmt.Errorf("parepre error %s", err) } if t, err := co.Prepare(sql); err != nil { return fmt.Errorf("parepre error %s", err) } else { s.params = t.ParamNum() s.columns = t.ColumnNum() } } s.id = c.stmtId c.stmtId++ if err = c.writePrepare(s); err != nil { return err } s.ResetParams() c.stmts[s.id] = s return nil }
/*处理query语句*/ func (c *ClientConn) handleQuery(sql string) (err error) { defer func() { if e := recover(); e != nil { golog.OutputSql("Error", "%s", sql) if err, ok := e.(error); ok { const size = 4096 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] golog.Error("ClientConn", "handleQuery", err.Error(), 0, "stack", string(buf), "sql", sql) } return } }() sql = strings.TrimRight(sql, ";") //删除sql语句最后的分号 hasHandled, err := c.preHandleShard(sql) if err != nil { golog.Error("server", "preHandleShard", err.Error(), 0, "hasHandled", hasHandled) return err } if hasHandled { return nil } var stmt sqlparser.Statement stmt, err = sqlparser.Parse(sql) //解析sql语句,得到的stmt是一个interface if err != nil { golog.Error("server", "parse", err.Error(), 0, "hasHandled", hasHandled, "sql", sql) return err } switch v := stmt.(type) { case *sqlparser.Select: return c.handleSelect(v, nil) case *sqlparser.Insert: return c.handleExec(stmt, nil) case *sqlparser.Update: return c.handleExec(stmt, nil) case *sqlparser.Delete: return c.handleExec(stmt, nil) case *sqlparser.Replace: return c.handleExec(stmt, nil) case *sqlparser.Set: return c.handleSet(v, sql) case *sqlparser.Begin: return c.handleBegin() case *sqlparser.Commit: return c.handleCommit() case *sqlparser.Rollback: return c.handleRollback() case *sqlparser.Admin: return c.handleAdmin(v) case *sqlparser.UseDB: return c.handleUseDB(v) default: return fmt.Errorf("statement %T not support now", stmt) } return nil }