func (cc *clientConn) dispatch(data []byte) error { cmd := data[0] data = data[1:] cc.lastCmd = hack.String(data) token := cc.server.getToken() defer func() { cc.server.releaseToken(token) }() switch cmd { case mysql.ComQuit: return io.EOF case mysql.ComQuery: return cc.handleQuery(hack.String(data)) case mysql.ComPing: return cc.writeOK() case mysql.ComInitDB: log.Debug("init db", hack.String(data)) if err := cc.useDB(hack.String(data)); err != nil { return errors.Trace(err) } return cc.writeOK() case mysql.ComFieldList: return cc.handleFieldList(hack.String(data)) case mysql.ComStmtPrepare: return cc.handleStmtPrepare(hack.String(data)) case mysql.ComStmtExecute: return cc.handleStmtExecute(data) case mysql.ComStmtClose: return cc.handleStmtClose(data) case mysql.ComStmtSendLongData: return cc.handleStmtSendLongData(data) case mysql.ComStmtReset: return cc.handleStmtReset(data) default: msg := fmt.Sprintf("command %d not supported now", cmd) return mysql.NewError(mysql.ErUnknownError, msg) } }
func (cc *clientConn) writeError(e error) error { var m *mysql.SQLError var ok bool if m, ok = e.(*mysql.SQLError); !ok { m = mysql.NewError(mysql.ErUnknownError, e.Error()) } data := make([]byte, 4, 16+len(m.Message)) data = append(data, mysql.ErrHeader) data = append(data, byte(m.Code), byte(m.Code>>8)) if cc.capability&mysql.ClientProtocol41 > 0 { data = append(data, '#') data = append(data, m.State...) } data = append(data, m.Message...) err := cc.writePacket(data) if err != nil { return errors.Trace(err) } return errors.Trace(cc.flush()) }
func (cc *clientConn) handleStmtExecute(data []byte) (err error) { if len(data) < 9 { return mysql.ErrMalformPacket } pos := 0 stmtID := binary.LittleEndian.Uint32(data[0:4]) pos += 4 stmt := cc.ctx.GetStatement(int(stmtID)) if stmt == nil { return mysql.NewDefaultError(mysql.ErUnknownStmtHandler, strconv.FormatUint(uint64(stmtID), 10), "stmt_execute") } flag := data[pos] pos++ //now we only support CURSOR_TYPE_NO_CURSOR flag if flag != 0 { return mysql.NewError(mysql.ErUnknownError, fmt.Sprintf("unsupported flag %d", flag)) } //skip iteration-count, always 1 pos += 4 var ( nullBitmaps []byte paramTypes []byte paramValues []byte ) numParams := stmt.NumParams() args := make([]interface{}, numParams) if numParams > 0 { nullBitmapLen := (numParams + 7) >> 3 if len(data) < (pos + nullBitmapLen + 1) { return mysql.ErrMalformPacket } nullBitmaps = data[pos : pos+nullBitmapLen] pos += nullBitmapLen //new param bound flag if data[pos] == 1 { pos++ if len(data) < (pos + (numParams << 1)) { return mysql.ErrMalformPacket } paramTypes = data[pos : pos+(numParams<<1)] pos += (numParams << 1) paramValues = data[pos:] } err = parseStmtArgs(args, stmt.BoundParams(), nullBitmaps, paramTypes, paramValues) if err != nil { return errors.Trace(err) } } rs, err := stmt.Execute(args...) if err != nil { return errors.Trace(err) } if rs == nil { return errors.Trace(cc.writeOK()) } return errors.Trace(cc.writeResultset(rs, true)) }