func (r *Router) buildDeletePlan(statement sqlparser.Statement) (*Plan, error) { plan := &Plan{} var where *sqlparser.Where var err error stmt := statement.(*sqlparser.Delete) plan.Rule = r.GetRule(sqlparser.String(stmt.Table)) where = stmt.Where if where != nil { plan.Criteria = where.Expr //路由条件 err = plan.calRouteIndexs() if err != nil { golog.Error("Route", "BuildUpdatePlan", err.Error(), 0) return nil, err } } else { //if shard delete without where,send to all nodes and all tables plan.RouteTableIndexs = plan.Rule.SubTableIndexs plan.RouteNodeIndexs = makeList(0, len(plan.Rule.Nodes)) } if plan.Rule.Type != DefaultRuleType && len(plan.RouteTableIndexs) == 0 { golog.Error("Route", "BuildDeletePlan", errors.ErrNoCriteria.Error(), 0) return nil, errors.ErrNoCriteria } //generate sql,如果routeTableindexs为空则表示不分表,不分表则发default node err = r.generateDeleteSql(plan, stmt) if err != nil { return nil, err } return plan, nil }
func (s *Server) onConn(c net.Conn) { s.counter.IncrClientConns() conn := s.newClientConn(c) //新建一个conn defer func() { err := recover() if err != nil { const size = 4096 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] //获得当前goroutine的stacktrace golog.Error("server", "onConn", "error", 0, "remoteAddr", c.RemoteAddr().String(), "stack", string(buf), ) } conn.Close() s.counter.DecrClientConns() }() if allowConnect := conn.IsAllowConnect(); allowConnect == false { err := mysql.NewError(mysql.ER_ACCESS_DENIED_ERROR, "ip address access denied by kingshard.") conn.writeError(err) conn.Close() return } if err := conn.Handshake(); err != nil { golog.Error("server", "onConn", err.Error(), 0) c.Close() return } conn.Run() }
func (c *ClientConn) Handshake() error { if err := c.writeInitialHandshake(); err != nil { golog.Error("server", "Handshake", err.Error(), c.connectionId, "msg", "send initial handshake error") return err } if err := c.readHandshakeResponse(); err != nil { golog.Error("server", "readHandshakeResponse", err.Error(), c.connectionId, "msg", "read Handshake Response error") c.writeError(err) return err } if err := c.writeOK(nil); err != nil { golog.Error("server", "readHandshakeResponse", "write ok fail", c.connectionId, "error", err.Error()) return err } c.pkg.Sequence = 0 return nil }
func (r *Router) buildReplacePlan(statement sqlparser.Statement) (*Plan, error) { plan := &Plan{} plan.Rows = make(map[int]sqlparser.Values) stmt := statement.(*sqlparser.Replace) if _, ok := stmt.Rows.(sqlparser.SelectStatement); ok { panic(sqlparser.NewParserError("select in replace not allowed")) } if stmt.Columns == nil { return nil, errors.ErrIRNoColumns } plan.Rule = r.GetRule(sqlparser.String(stmt.Table)) err := plan.GetIRKeyIndex(stmt.Columns) if err != nil { return nil, err } plan.Criteria = plan.checkValuesType(stmt.Rows.(sqlparser.Values)) err = plan.calRouteIndexs() if err != nil { golog.Error("Route", "BuildReplacePlan", err.Error(), 0) return nil, err } err = r.generateReplaceSql(plan, stmt) if err != nil { return nil, err } return plan, nil }
func (r *Router) buildSelectPlan(statement sqlparser.Statement) (*Plan, error) { plan := &Plan{} var where *sqlparser.Where var err error var tableName string stmt := statement.(*sqlparser.Select) switch v := (stmt.From[0]).(type) { case *sqlparser.AliasedTableExpr: tableName = sqlparser.String(v.Expr) case *sqlparser.JoinTableExpr: if ate, ok := (v.LeftExpr).(*sqlparser.AliasedTableExpr); ok { tableName = sqlparser.String(ate.Expr) } else { tableName = sqlparser.String(v) } default: tableName = sqlparser.String(v) } plan.Rule = r.GetRule(tableName) //根据表名获得分表规则 where = stmt.Where if where != nil { plan.Criteria = where.Expr //路由条件 err = plan.calRouteIndexs() if err != nil { golog.Error("Route", "BuildSelectPlan", err.Error(), 0) return nil, err } } else { //if shard select without where,send to all nodes and all tables plan.RouteTableIndexs = plan.Rule.SubTableIndexs plan.RouteNodeIndexs = makeList(0, len(plan.Rule.Nodes)) } if plan.Rule.Type != DefaultRuleType && len(plan.RouteTableIndexs) == 0 { golog.Error("Route", "BuildSelectPlan", errors.ErrNoCriteria.Error(), 0) return nil, errors.ErrNoCriteria } //generate sql,如果routeTableindexs为空则表示不分表,不分表则发default node err = r.generateSelectSql(plan, stmt) if err != nil { return nil, err } return plan, nil }
//处理select语句 func (c *ClientConn) handleSelect(stmt *sqlparser.Select, args []interface{}) error { var fromSlave bool = true plan, err := c.schema.rule.BuildPlan(stmt) if err != nil { return err } if 0 < len(stmt.Comments) { comment := string(stmt.Comments[0]) if 0 < len(comment) && strings.ToLower(comment) == MasterComment { fromSlave = false } } conns, err := c.getShardConns(fromSlave, plan) if err != nil { golog.Error("ClientConn", "handleSelect", err.Error(), c.connectionId) return err } if conns == nil { r := c.newEmptyResultset(stmt) return c.writeResultset(c.status, r) } var rs []*mysql.Result rs, err = c.executeInMultiNodes(conns, plan.RewrittenSqls, args) c.closeShardConns(conns, false) if err != nil { golog.Error("ClientConn", "handleSelect", err.Error(), c.connectionId) return err } err = c.mergeSelectResult(rs, stmt) if err != nil { golog.Error("ClientConn", "handleSelect", err.Error(), c.connectionId) } return err }
func (c *ClientConn) Run() { defer func() { r := recover() if err, ok := r.(error); ok { const size = 4096 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] golog.Error("ClientConn", "Run", err.Error(), 0, "stack", string(buf)) } c.Close() }() for { data, err := c.readPacket() if err != nil { return } if err := c.dispatch(data); err != nil { c.proxy.counter.IncrErrLogTotal() golog.Error("server", "Run", err.Error(), c.connectionId, ) c.writeError(err) } if c.closed { return } c.pkg.Sequence = 0 } }
func (c *ClientConn) getBackendConn(n *backend.Node, fromSlave bool) (co *backend.BackendConn, err error) { if !c.isInTransaction() { if fromSlave { co, err = n.GetSlaveConn() if err != nil { co, err = n.GetMasterConn() } } else { co, err = n.GetMasterConn() } if err != nil { golog.Error("server", "getBackendConn", err.Error(), 0) return } } else { var ok bool co, ok = c.txConns[n] if !ok { if co, err = n.GetMasterConn(); err != nil { return } if !c.isAutoCommit() { if err = co.SetAutoCommit(0); err != nil { return } } else { if err = co.Begin(); err != nil { return } } c.txConns[n] = co } } if err = co.UseDB(c.db); err != nil { return } if err = co.SetCharset(c.charset); err != nil { return } return }
func (s *Server) Run() error { s.running = true // flush counter go s.flushCounter() for s.running { conn, err := s.listener.Accept() if err != nil { golog.Error("server", "Run", err.Error(), 0) continue } go s.onConn(conn) } return nil }
func (c *ClientConn) handleSet(stmt *sqlparser.Set, sql string) (err error) { if len(stmt.Exprs) != 1 && len(stmt.Exprs) != 2 { return fmt.Errorf("must set one item once, not %s", nstring(stmt)) } //log the SQL startTime := time.Now().UnixNano() defer func() { var state string if err != nil { state = "ERROR" } else { state = "OK" } execTime := float64(time.Now().UnixNano()-startTime) / float64(time.Millisecond) if c.proxy.logSql[c.proxy.logSqlIndex] != golog.LogSqlOff && execTime > float64(c.proxy.slowLogTime[c.proxy.slowLogTimeIndex]) { c.proxy.counter.IncrSlowLogTotal() golog.OutputSql(state, "%.1fms - %s->%s:%s", execTime, c.c.RemoteAddr(), c.proxy.addr, sql, ) } }() k := string(stmt.Exprs[0].Name.Name) switch strings.ToUpper(k) { case `AUTOCOMMIT`: return c.handleSetAutoCommit(stmt.Exprs[0].Expr) case `NAMES`, `CHARACTER_SET_RESULTS`, `CHARACTER_SET_CLIENT`, `CHARACTER_SET_CONNECTION`: if len(stmt.Exprs) == 2 { //SET NAMES 'charset_name' COLLATE 'collation_name' return c.handleSetNames(stmt.Exprs[0].Expr, stmt.Exprs[1].Expr) } return c.handleSetNames(stmt.Exprs[0].Expr, nil) default: golog.Error("ClientConn", "handleSet", "command not supported", c.connectionId, "sql", sql) return c.writeOK(nil) } }
func (r *Router) buildInsertPlan(statement sqlparser.Statement) (*Plan, error) { plan := &Plan{} plan.Rows = make(map[int]sqlparser.Values) stmt := statement.(*sqlparser.Insert) if _, ok := stmt.Rows.(sqlparser.SelectStatement); ok { return nil, errors.ErrSelectInInsert } if stmt.Columns == nil { return nil, errors.ErrIRNoColumns } //根据sql语句的表,获得对应的分片规则 plan.Rule = r.GetRule(sqlparser.String(stmt.Table)) err := plan.GetIRKeyIndex(stmt.Columns) if err != nil { return nil, err } if stmt.OnDup != nil { err := plan.Rule.checkUpdateExprs(sqlparser.UpdateExprs(stmt.OnDup)) if err != nil { return nil, err } } plan.Criteria = plan.checkValuesType(stmt.Rows.(sqlparser.Values)) err = plan.calRouteIndexs() if err != nil { golog.Error("Route", "BuildInsertPlan", err.Error(), 0) return nil, err } err = r.generateInsertSql(plan, stmt) if err != nil { return nil, err } return plan, nil }
func (c *ClientConn) dispatch(data []byte) error { c.proxy.counter.IncrClientQPS() cmd := data[0] data = data[1:] switch cmd { case mysql.COM_QUIT: c.Close() return nil case mysql.COM_QUERY: return c.handleQuery(hack.String(data)) case mysql.COM_PING: return c.writeOK(nil) case mysql.COM_INIT_DB: if err := c.useDB(hack.String(data)); err != nil { return err } else { return c.writeOK(nil) } case mysql.COM_FIELD_LIST: return c.handleFieldList(data) case mysql.COM_STMT_PREPARE: return c.handleStmtPrepare(hack.String(data)) case mysql.COM_STMT_EXECUTE: return c.handleStmtExecute(data) case mysql.COM_STMT_CLOSE: return c.handleStmtClose(data) case mysql.COM_STMT_SEND_LONG_DATA: return c.handleStmtSendLongData(data) case mysql.COM_STMT_RESET: return c.handleStmtReset(data) case mysql.COM_SET_OPTION: return c.writeEOF(0) default: msg := fmt.Sprintf("command %d not supported now", cmd) golog.Error("ClientConn", "dispatch", msg, 0) return mysql.NewError(mysql.ER_UNKNOWN_ERROR, msg) } return nil }
func (c *ClientConn) IsAllowConnect() bool { clientHost, _, err := net.SplitHostPort(c.c.RemoteAddr().String()) if err != nil { fmt.Println(err) } clientIP := net.ParseIP(clientHost) ipVec := c.proxy.allowips[c.proxy.allowipsIndex] if ipVecLen := len(ipVec); ipVecLen == 0 { return true } for _, ip := range ipVec { if ip.Equal(clientIP) { return true } } golog.Error("server", "IsAllowConnect", "error", mysql.ER_ACCESS_DENIED_ERROR, "ip address", c.c.RemoteAddr().String(), " access denied by kindshard.") return false }
func (s *Server) saveBlackSql() error { if len(s.cfg.BlsFile) == 0 { return nil } f, err := os.Create(s.cfg.BlsFile) if err != nil { golog.Error("Server", "saveBlackSql", "create file error", 0, "err", err.Error(), "blacklist_sql_file", s.cfg.BlsFile, ) return err } for _, v := range s.blacklistSqls[s.blacklistSqlsIndex].sqls { v = v + "\n" _, err = f.WriteString(v) if err != nil { return err } } return nil }
//计算表下标和node下标 func (plan *Plan) calRouteIndexs() error { var err error nodesCount := len(plan.Rule.Nodes) if plan.Rule.Type == DefaultRuleType { plan.RouteNodeIndexs = []int{0} return nil } if plan.Criteria == nil { //如果没有分表条件,则是全子表扫描 if plan.Rule.Type != DefaultRuleType { golog.Error("Plan", "calRouteIndexs", "plan have no criteria", 0, "type", plan.Rule.Type) return errors.ErrNoCriteria } } switch criteria := plan.Criteria.(type) { case sqlparser.Values: //代表insert中values plan.RouteTableIndexs, err = plan.getInsertTableIndex(criteria) if err != nil { return err } plan.RouteNodeIndexs = plan.TindexsToNindexs(plan.RouteTableIndexs) return nil case sqlparser.BoolExpr: plan.RouteTableIndexs, err = plan.getTableIndexByBoolExpr(criteria) if err != nil { return err } plan.RouteNodeIndexs = plan.TindexsToNindexs(plan.RouteTableIndexs) return nil default: plan.RouteTableIndexs = plan.Rule.SubTableIndexs plan.RouteNodeIndexs = makeList(0, nodesCount) return nil } }
func (c *ClientConn) handlePrepareSelect(stmt *sqlparser.Select, sql string, args []interface{}) error { defaultRule := c.schema.rule.DefaultRule if len(defaultRule.Nodes) == 0 { return errors.ErrNoDefaultNode } defaultNode := c.proxy.GetNode(defaultRule.Nodes[0]) //choose connection in slave DB first conn, err := c.getBackendConn(defaultNode, true) defer c.closeConn(conn, false) if err != nil { return err } if conn == nil { r := c.newEmptyResultset(stmt) return c.writeResultset(c.status, r) } var rs []*mysql.Result rs, err = c.executeInNode(conn, sql, args) if err != nil { golog.Error("ClientConn", "handlePrepareSelect", err.Error(), c.connectionId) return err } status := c.status | rs[0].Status if rs[0].Resultset != nil { err = c.writeResultset(status, rs[0].Resultset) } else { r := c.newEmptyResultset(stmt) err = c.writeResultset(status, r) } return err }
func (c *ClientConn) handlePrepareExec(stmt sqlparser.Statement, sql string, args []interface{}) error { defaultRule := c.schema.rule.DefaultRule if len(defaultRule.Nodes) == 0 { return errors.ErrNoDefaultNode } defaultNode := c.proxy.GetNode(defaultRule.Nodes[0]) //execute in Master DB conn, err := c.getBackendConn(defaultNode, false) defer c.closeConn(conn, false) if err != nil { return err } if conn == nil { return c.writeOK(nil) } var rs []*mysql.Result rs, err = c.executeInNode(conn, sql, args) c.closeConn(conn, false) if err != nil { golog.Error("ClientConn", "handlePrepareExec", err.Error(), c.connectionId) return err } status := c.status | rs[0].Status if rs[0].Resultset != nil { err = c.writeResultset(status, rs[0].Resultset) } else { err = c.writeOK(rs[0]) } return err }
func (c *ClientConn) handleExec(stmt sqlparser.Statement, args []interface{}) error { plan, err := c.schema.rule.BuildPlan(stmt) if err != nil { return err } conns, err := c.getShardConns(false, plan) defer c.closeShardConns(conns, err != nil) if err != nil { golog.Error("ClientConn", "handleExec", err.Error(), c.connectionId) return err } if conns == nil { return c.writeOK(nil) } var rs []*mysql.Result rs, err = c.executeInMultiNodes(conns, plan.RewrittenSqls, args) if err == nil { err = c.mergeExecResult(rs) } return err }
func (s *Server) parseSchema() error { schemaCfg := s.cfg.Schema if len(schemaCfg.Nodes) == 0 { //fmt.Errorf("schema [%s] must have a node.", schemaCfg.DB) golog.Error("server", "parser schema", "no schema configured...", 0, s.cfg) return nil } nodes := make(map[string]*backend.Node) for _, n := range schemaCfg.Nodes { if s.GetNode(n) == nil { fmt.Errorf("schema [%s] node [%s] config is not exists.", schemaCfg.DB, n) return nil } if _, ok := nodes[n]; ok { return fmt.Errorf("schema [%s] node [%s] duplicate.", schemaCfg.DB, n) } nodes[n] = s.GetNode(n) } rule, err := router.NewRouter(&schemaCfg) if err != nil { return err } s.schema = &Schema{ db: schemaCfg.DB, nodes: nodes, rule: rule, } s.db = schemaCfg.DB return nil }
func (c *ClientConn) executeInMultiNodes(conns map[string]*backend.BackendConn, sqls map[string][]string, args []interface{}) ([]*mysql.Result, error) { if len(conns) != len(sqls) { golog.Error("ClientConn", "executeInMultiNodes", errors.ErrConnNotEqual.Error(), c.connectionId, "conns", conns, "sqls", sqls, ) return nil, errors.ErrConnNotEqual } var wg sync.WaitGroup if len(conns) == 0 { return nil, errors.ErrNoPlan } wg.Add(len(conns)) resultCount := 0 for _, sqlSlice := range sqls { resultCount += len(sqlSlice) } rs := make([]interface{}, resultCount) f := func(rs []interface{}, i int, execSqls []string, co *backend.BackendConn) { var state string for _, v := range execSqls { startTime := time.Now().UnixNano() r, err := co.Execute(v, args...) if err != nil { state = "ERROR" rs[i] = err } else { state = "OK" rs[i] = r } execTime := float64(time.Now().UnixNano()-startTime) / float64(time.Millisecond) if c.proxy.logSql[c.proxy.logSqlIndex] != golog.LogSqlOff && execTime > float64(c.proxy.slowLogTime[c.proxy.slowLogTimeIndex]) { c.proxy.counter.IncrSlowLogTotal() golog.OutputSql(state, "%.1fms - %s->%s:%s", execTime, c.c.RemoteAddr(), co.GetAddr(), v, ) } i++ } wg.Done() } offsert := 0 for nodeName, co := range conns { s := sqls[nodeName] //[]string go f(rs, offsert, s, co) offsert += len(s) } wg.Wait() var err error r := make([]*mysql.Result, resultCount) for i, v := range rs { if e, ok := v.(error); ok { err = e break } r[i] = rs[i].(*mysql.Result) } return r, err }
func (c *ClientConn) readHandshakeResponse() error { data, err := c.readPacket() if err != nil { return err } pos := 0 //capability c.capability = binary.LittleEndian.Uint32(data[:4]) pos += 4 //skip max packet size pos += 4 //charset, skip, if you want to use another charset, use set names //c.collation = CollationId(data[pos]) pos++ //skip reserved 23[00] pos += 23 //user name c.user = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) pos += len(c.user) + 1 //auth length and auth authLen := int(data[pos]) pos++ auth := data[pos : pos+authLen] checkAuth := mysql.CalcPassword(c.salt, []byte(c.proxy.cfg.Password)) if c.user != c.proxy.cfg.User || !bytes.Equal(auth, checkAuth) { golog.Error("ClientConn", "readHandshakeResponse", "error", 0, "auth", auth, "checkAuth", checkAuth, "client_user", c.user, "config_set_user", c.proxy.cfg.User, "passworld", c.proxy.cfg.Password) return mysql.NewDefaultError(mysql.ER_ACCESS_DENIED_ERROR, c.user, c.c.RemoteAddr().String(), "Yes") } pos += authLen var db string if c.capability&mysql.CLIENT_CONNECT_WITH_DB > 0 { if len(data[pos:]) == 0 { return nil } db = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) pos += len(c.db) + 1 } golog.Error("handshake ", "response", "db ", 0, db) if err := c.useDB(db); err != nil { return err } return nil }
func main() { fmt.Print(banner) runtime.GOMAXPROCS(runtime.NumCPU()) flag.Parse() fmt.Printf("Git commit:%s\n", hack.Version) fmt.Printf("Build time:%s\n", hack.Compile) if *version { return } if len(*configFile) == 0 { fmt.Println("must use a config file") return } cfg, err := config.ParseConfigFile(*configFile) if err != nil { fmt.Printf("parse config file error:%v\n", err.Error()) return } //when the log file size greater than 1GB, kingshard will generate a new file if len(cfg.LogPath) != 0 { sysFilePath := path.Join(cfg.LogPath, sysLogName) sysFile, err := golog.NewRotatingFileHandler(sysFilePath, MaxLogSize, 1) if err != nil { fmt.Printf("new log file error:%v\n", err.Error()) return } golog.GlobalSysLogger = golog.New(sysFile, golog.Lfile|golog.Ltime|golog.Llevel) sqlFilePath := path.Join(cfg.LogPath, sqlLogName) sqlFile, err := golog.NewRotatingFileHandler(sqlFilePath, MaxLogSize, 1) if err != nil { fmt.Printf("new log file error:%v\n", err.Error()) return } golog.GlobalSqlLogger = golog.New(sqlFile, golog.Lfile|golog.Ltime|golog.Llevel) } if *logLevel != "" { setLogLevel(*logLevel) } else { setLogLevel(cfg.LogLevel) } var svr *server.Server svr, err = server.NewServer(cfg) if err != nil { golog.Error("main", "main", err.Error(), 0) golog.GlobalSysLogger.Close() golog.GlobalSqlLogger.Close() return } sc := make(chan os.Signal, 1) signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) go func() { sig := <-sc golog.Info("main", "main", "Got signal", 0, "signal", sig) golog.GlobalSysLogger.Close() golog.GlobalSqlLogger.Close() svr.Close() }() svr.Run() }
/*处理query语句*/ func (c *ClientConn) handleQuery(sql string) (err error) { defer func() { if e := recover(); e != nil { golog.OutputSql("Error", "err:%v,sql:%s", e, sql) if err, ok := e.(error); ok { const size = 4096 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] golog.Error("ClientConn", "handleQuery", err.Error(), 0, "stack", string(buf), "sql", sql) } return } }() sql = strings.TrimRight(sql, ";") //删除sql语句最后的分号 hasHandled, err := c.preHandleShard(sql) if err != nil { golog.Error("server", "preHandleShard", err.Error(), 0, "hasHandled", hasHandled) return err } if hasHandled { return nil } var stmt sqlparser.Statement stmt, err = sqlparser.Parse(sql) //解析sql语句,得到的stmt是一个interface if err != nil { golog.Error("server", "parse", err.Error(), 0, "hasHandled", hasHandled, "sql", sql) return err } switch v := stmt.(type) { case *sqlparser.Select: return c.handleSelect(v, nil) case *sqlparser.Insert: return c.handleExec(stmt, nil) case *sqlparser.Update: return c.handleExec(stmt, nil) case *sqlparser.Delete: return c.handleExec(stmt, nil) case *sqlparser.Replace: return c.handleExec(stmt, nil) case *sqlparser.Set: return c.handleSet(v, sql) case *sqlparser.Begin: return c.handleBegin() case *sqlparser.Commit: return c.handleCommit() case *sqlparser.Rollback: return c.handleRollback() case *sqlparser.Admin: return c.handleAdmin(v) case *sqlparser.AdminHelp: return c.handleAdminHelp(v) case *sqlparser.UseDB: return c.handleUseDB(v) default: return fmt.Errorf("statement %T not support now", stmt) } return nil }