/*处理query语句*/ func (c *ClientConn) handleQuery(sql string) (err error) { defer func() { if e := recover(); e != nil { err = fmt.Errorf("execute %s error %v", sql, e) golog.OutputSql("Error", "%s", sql) return } golog.OutputSql("INFO", "%s", sql) }() sql = strings.TrimRight(sql, ";") //删除sql语句最后的分号 hasHandled, err := c.handleUnsupport(sql) if err != nil { golog.Error("server", "parse", err.Error(), 0, "hasHandled", hasHandled) return err } if hasHandled { return nil } var stmt sqlparser.Statement stmt, err = sqlparser.Parse(sql) //解析sql语句,得到的stmt是一个interface if err != nil { golog.Error("server", "parse", err.Error(), 0, "hasHandled", hasHandled) return err } switch v := stmt.(type) { case *sqlparser.Select: return c.handleSelect(v, sql, nil) case *sqlparser.Insert: return c.handleExec(stmt, sql, nil) case *sqlparser.Update: return c.handleExec(stmt, sql, nil) case *sqlparser.Delete: return c.handleExec(stmt, sql, nil) case *sqlparser.Replace: return c.handleExec(stmt, sql, nil) case *sqlparser.Set: return c.handleSet(v) case *sqlparser.Begin: return c.handleBegin() case *sqlparser.Commit: return c.handleCommit() case *sqlparser.Rollback: return c.handleRollback() case *sqlparser.SimpleSelect: return c.handleSimpleSelect(sql, v) case *sqlparser.Show: return c.handleShow(sql, v) case *sqlparser.Admin: return c.handleAdmin(v) default: return fmt.Errorf("statement %T not support now", stmt) } return nil }
func (c *ClientConn) executeInNode(conn *backend.BackendConn, sql string, args []interface{}) ([]*mysql.Result, error) { var state string startTime := time.Now().UnixNano() r, err := conn.Execute(sql, args...) if err != nil { state = "ERROR" } else { state = "OK" } execTime := float64(time.Now().UnixNano()-startTime) / float64(time.Millisecond) if strings.ToLower(c.proxy.logSql[c.proxy.logSqlIndex]) != golog.LogSqlOff && execTime > float64(c.proxy.slowLogTime[c.proxy.slowLogTimeIndex]) { c.proxy.counter.IncrSlowLogTotal() golog.OutputSql(state, "%.1fms - %s->%s:%s", execTime, c.c.RemoteAddr(), conn.GetAddr(), sql, ) } if err != nil { return nil, err } return []*mysql.Result{r}, err }
func (c *ClientConn) handleSet(stmt *sqlparser.Set, sql string) (err error) { if len(stmt.Exprs) != 1 && len(stmt.Exprs) != 2 { return fmt.Errorf("must set one item once, not %s", nstring(stmt)) } //log the SQL startTime := time.Now().UnixNano() defer func() { var state string if err != nil { state = "ERROR" } else { state = "OK" } execTime := float64(time.Now().UnixNano()-startTime) / float64(time.Millisecond) if c.proxy.logSql[c.proxy.logSqlIndex] != golog.LogSqlOff && execTime > float64(c.proxy.slowLogTime[c.proxy.slowLogTimeIndex]) { c.proxy.counter.IncrSlowLogTotal() golog.OutputSql(state, "%.1fms - %s->%s:%s", execTime, c.c.RemoteAddr(), c.proxy.addr, sql, ) } }() k := string(stmt.Exprs[0].Name.Name) switch strings.ToUpper(k) { case `AUTOCOMMIT`, `@@AUTOCOMMIT`, `@@SESSION.AUTOCOMMIT`: return c.handleSetAutoCommit(stmt.Exprs[0].Expr) case `NAMES`, `CHARACTER_SET_RESULTS`, `@@CHARACTER_SET_RESULTS`, `@@SESSION.CHARACTER_SET_RESULTS`, `CHARACTER_SET_CLIENT`, `@@CHARACTER_SET_CLIENT`, `@@SESSION.CHARACTER_SET_CLIENT`, `CHARACTER_SET_CONNECTION`, `@@CHARACTER_SET_CONNECTION`, `@@SESSION.CHARACTER_SET_CONNECTION`: if len(stmt.Exprs) == 2 { //SET NAMES 'charset_name' COLLATE 'collation_name' return c.handleSetNames(stmt.Exprs[0].Expr, stmt.Exprs[1].Expr) } return c.handleSetNames(stmt.Exprs[0].Expr, nil) default: golog.Error("ClientConn", "handleSet", "command not supported", c.connectionId, "sql", sql) return c.writeOK(nil) } }
func (c *ClientConn) executeInNode(conn *backend.BackendConn, sql string, args []interface{}) ([]*mysql.Result, error) { var state string r, err := conn.Execute(sql, args...) if err != nil { state = "ERROR" } else { state = "INFO" } if strings.ToLower(c.proxy.cfg.LogSql) != golog.LogSqlOff { golog.OutputSql(state, "%s->%s:%s", c.c.RemoteAddr(), conn.GetAddr(), sql, ) } if err != nil { return nil, err } return []*mysql.Result{r}, err }
func (c *ClientConn) executeInNode(conn *backend.BackendConn, sql string, args []interface{}) ([]*Result, error) { var wg sync.WaitGroup wg.Add(1) rs := make([]interface{}, 1) f := func(rs []interface{}, i int, co *backend.BackendConn) { var state string r, err := co.Execute(sql, args...) if err != nil { state = "ERROR" rs[i] = err } else { state = "INFO" rs[i] = r } golog.OutputSql(state, "%s->%s:%s", c.c.RemoteAddr(), co.GetAddr(), sql, ) wg.Done() } go f(rs, 0, conn) wg.Wait() var err error r := make([]*Result, 1) for i, v := range rs { if e, ok := v.(error); ok { err = e break } r[i] = rs[i].(*Result) } return r, err }
//preprocessing sql before parse sql func (c *ClientConn) preHandleShard(sql string) (bool, error) { var rs []*mysql.Result var err error var executeDB *ExecuteDB if len(sql) == 0 { return false, errors.ErrCmdUnsupport } //filter the blacklist sql if c.proxy.blacklistSqls[c.proxy.blacklistSqlsIndex].sqlsLen != 0 { if c.isBlacklistSql(sql) { golog.OutputSql("Forbidden", "%s->%s:%s", c.c.RemoteAddr(), c.proxy.addr, sql, ) err := mysql.NewError(mysql.ER_UNKNOWN_ERROR, "sql in blacklist.") return false, err } } tokens := strings.FieldsFunc(sql, hack.IsSqlSep) if len(tokens) == 0 { return false, errors.ErrCmdUnsupport } if c.isInTransaction() { executeDB, err = c.GetTransExecDB(tokens, sql) } else { executeDB, err = c.GetExecDB(tokens, sql) } if err != nil { //this SQL doesn't need execute in the backend. if err == errors.ErrIgnoreSQL { err = c.writeOK(nil) if err != nil { return false, err } return true, nil } return false, err } //need shard sql if executeDB == nil { return false, nil } //get connection in DB conn, err := c.getBackendConn(executeDB.ExecNode, executeDB.IsSlave) defer c.closeConn(conn, false) if err != nil { return false, err } rs, err = c.executeInNode(conn, sql, nil) if err != nil { return false, err } if len(rs) == 0 { msg := fmt.Sprintf("result is empty") golog.Error("ClientConn", "handleUnsupport", msg, 0, "sql", sql) return false, mysql.NewError(mysql.ER_UNKNOWN_ERROR, msg) } c.lastInsertId = int64(rs[0].InsertId) c.affectedRows = int64(rs[0].AffectedRows) if rs[0].Resultset != nil { err = c.writeResultset(c.status, rs[0].Resultset) } else { err = c.writeOK(rs[0]) } if err != nil { return false, err } return true, nil }
func (c *ClientConn) executeInMultiNodes(conns map[string]*backend.BackendConn, sqls map[string][]string, args []interface{}) ([]*mysql.Result, error) { if len(conns) != len(sqls) { golog.Error("ClientConn", "executeInMultiNodes", errors.ErrConnNotEqual.Error(), c.connectionId, "conns", conns, "sqls", sqls, ) return nil, errors.ErrConnNotEqual } var wg sync.WaitGroup if len(conns) == 0 { return nil, errors.ErrNoPlan } wg.Add(len(conns)) resultCount := 0 for _, sqlSlice := range sqls { resultCount += len(sqlSlice) } rs := make([]interface{}, resultCount) f := func(rs []interface{}, i int, execSqls []string, co *backend.BackendConn) { var state string for _, v := range execSqls { r, err := co.Execute(v, args...) if err != nil { state = "ERROR" rs[i] = err } else { state = "INFO" rs[i] = r } if c.proxy.cfg.LogSql != golog.LogSqlOff { golog.OutputSql(state, "%s->%s:%s", c.c.RemoteAddr(), co.GetAddr(), v, ) } i++ } wg.Done() } offsert := 0 for nodeName, co := range conns { s := sqls[nodeName] //[]string go f(rs, offsert, s, co) offsert += len(s) } wg.Wait() var err error r := make([]*mysql.Result, resultCount) for i, v := range rs { if e, ok := v.(error); ok { err = e break } r[i] = rs[i].(*mysql.Result) } return r, err }
/*处理query语句*/ func (c *ClientConn) handleQuery(sql string) (err error) { defer func() { if e := recover(); e != nil { golog.OutputSql("Error", "%s", sql) if err, ok := e.(error); ok { const size = 4096 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] golog.Error("ClientConn", "handleQuery", err.Error(), 0, "stack", string(buf), "sql", sql) } return } }() sql = strings.TrimRight(sql, ";") //删除sql语句最后的分号 hasHandled, err := c.preHandleShard(sql) if err != nil { golog.Error("server", "preHandleShard", err.Error(), 0, "hasHandled", hasHandled) return err } if hasHandled { return nil } var stmt sqlparser.Statement stmt, err = sqlparser.Parse(sql) //解析sql语句,得到的stmt是一个interface if err != nil { golog.Error("server", "parse", err.Error(), 0, "hasHandled", hasHandled, "sql", sql) return err } switch v := stmt.(type) { case *sqlparser.Select: return c.handleSelect(v, nil) case *sqlparser.Insert: return c.handleExec(stmt, nil) case *sqlparser.Update: return c.handleExec(stmt, nil) case *sqlparser.Delete: return c.handleExec(stmt, nil) case *sqlparser.Replace: return c.handleExec(stmt, nil) case *sqlparser.Set: return c.handleSet(v, sql) case *sqlparser.Begin: return c.handleBegin() case *sqlparser.Commit: return c.handleCommit() case *sqlparser.Rollback: return c.handleRollback() case *sqlparser.Admin: return c.handleAdmin(v) case *sqlparser.UseDB: return c.handleUseDB(v) default: return fmt.Errorf("statement %T not support now", stmt) } return nil }