func (c *ClientConn) executeInNode(conn *backend.BackendConn, sql string, args []interface{}) ([]*mysql.Result, error) { var state string r, err := conn.Execute(sql, args...) if err != nil { state = "ERROR" } else { state = "INFO" } if strings.ToLower(c.proxy.cfg.LogSql) != golog.LogSqlOff { golog.OutputSql(state, "%s->%s:%s", c.c.RemoteAddr(), conn.GetAddr(), sql, ) } if err != nil { return nil, err } return []*mysql.Result{r}, err }
func (c *ClientConn) executeInMultiNodes(conns map[string]*backend.BackendConn, sqls map[string][]string, args []interface{}) ([]*mysql.Result, error) { if len(conns) != len(sqls) { golog.Error("ClientConn", "executeInMultiNodes", errors.ErrConnNotEqual.Error(), c.connectionId, "conns", conns, "sqls", sqls, ) return nil, errors.ErrConnNotEqual } var wg sync.WaitGroup if len(conns) == 0 { return nil, errors.ErrNoPlan } wg.Add(len(conns)) resultCount := 0 for _, sqlSlice := range sqls { resultCount += len(sqlSlice) } rs := make([]interface{}, resultCount) f := func(rs []interface{}, i int, execSqls []string, co *backend.BackendConn) { var state string for _, v := range execSqls { r, err := co.Execute(v, args...) if err != nil { state = "ERROR" rs[i] = err } else { state = "INFO" rs[i] = r } if c.proxy.cfg.LogSql != golog.LogSqlOff { golog.OutputSql(state, "%s->%s:%s", c.c.RemoteAddr(), co.GetAddr(), v, ) } i++ } wg.Done() } offsert := 0 for nodeName, co := range conns { s := sqls[nodeName] //[]string go f(rs, offsert, s, co) offsert += len(s) } wg.Wait() var err error r := make([]*mysql.Result, resultCount) for i, v := range rs { if e, ok := v.(error); ok { err = e break } r[i] = rs[i].(*mysql.Result) } return r, err }
/*处理query语句*/ func (c *ClientConn) handleQuery(sql string) (err error) { defer func() { if e := recover(); e != nil { golog.OutputSql("Error", "err:%v,sql:%s", e, sql) if err, ok := e.(error); ok { const size = 4096 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] golog.Error("ClientConn", "handleQuery", err.Error(), 0, "stack", string(buf), "sql", sql) } return } }() sql = strings.TrimRight(sql, ";") //删除sql语句最后的分号 hasHandled, err := c.preHandleShard(sql) if err != nil { golog.Error("server", "preHandleShard", err.Error(), 0, "hasHandled", hasHandled) return err } if hasHandled { return nil } var stmt sqlparser.Statement stmt, err = sqlparser.Parse(sql) //解析sql语句,得到的stmt是一个interface if err != nil { golog.Error("server", "parse", err.Error(), 0, "hasHandled", hasHandled, "sql", sql) return err } switch v := stmt.(type) { case *sqlparser.Select: return c.handleSelect(v, nil) case *sqlparser.Insert: return c.handleExec(stmt, nil) case *sqlparser.Update: return c.handleExec(stmt, nil) case *sqlparser.Delete: return c.handleExec(stmt, nil) case *sqlparser.Replace: return c.handleExec(stmt, nil) case *sqlparser.Set: return c.handleSet(v, sql) case *sqlparser.Begin: return c.handleBegin() case *sqlparser.Commit: return c.handleCommit() case *sqlparser.Rollback: return c.handleRollback() case *sqlparser.Admin: return c.handleAdmin(v) case *sqlparser.UseDB: return c.handleUseDB(v) default: return fmt.Errorf("statement %T not support now", stmt) } return nil }