func (cc *clientConn) readHandshakeResponse() error { data, err := cc.readPacket() if err != nil { return errors.Trace(err) } pos := 0 // capability cc.capability = binary.LittleEndian.Uint32(data[:4]) pos += 4 // skip max packet size pos += 4 // charset, skip, if you want to use another charset, use set names cc.collation = data[pos] pos++ // skip reserved 23[00] pos += 23 // user name cc.user = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) pos += len(cc.user) + 1 // auth length and auth authLen := int(data[pos]) pos++ auth := data[pos : pos+authLen] pos += authLen if cc.capability|mysql.ClientConnectWithDB > 0 { if len(data[pos:]) > 0 { idx := bytes.IndexByte(data[pos:], 0) cc.dbname = string(data[pos : pos+idx]) } } // Open session and do auth cc.ctx, err = cc.server.driver.OpenCtx(cc.capability, uint8(cc.collation), cc.dbname) if err != nil { cc.Close() return errors.Trace(err) } if !cc.server.skipAuth() { // Do Auth addr := cc.conn.RemoteAddr().String() host, _, err1 := net.SplitHostPort(addr) if err1 != nil { return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.user, addr, "Yes")) } user := fmt.Sprintf("%s@%s", cc.user, host) if !cc.ctx.Auth(user, auth, cc.salt) { return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.user, host, "Yes")) } } return nil }
// AppendParam implements IStatement AppendParam method. func (ts *TiDBStatement) AppendParam(paramID int, data []byte) error { if paramID >= len(ts.boundParams) { return mysql.NewDefaultError(mysql.ErWrongArguments, "stmt_send_longdata") } ts.boundParams[paramID] = append(ts.boundParams[paramID], data...) return nil }
func (cc *clientConn) handleStmtReset(data []byte) (err error) { if len(data) < 4 { return mysql.ErrMalformPacket } stmtID := int(binary.LittleEndian.Uint32(data[0:4])) stmt := cc.ctx.GetStatement(stmtID) if stmt == nil { return mysql.NewDefaultError(mysql.ErUnknownStmtHandler, strconv.Itoa(stmtID), "stmt_reset") } stmt.Reset() return cc.writeOK() }
func (cc *clientConn) handleStmtSendLongData(data []byte) (err error) { if len(data) < 6 { return mysql.ErrMalformPacket } stmtID := int(binary.LittleEndian.Uint32(data[0:4])) stmt := cc.ctx.GetStatement(stmtID) if stmt == nil { return mysql.NewDefaultError(mysql.ErUnknownStmtHandler, strconv.Itoa(stmtID), "stmt_send_longdata") } paramID := int(binary.LittleEndian.Uint16(data[4:6])) return stmt.AppendParam(paramID, data[6:]) }
func (cc *clientConn) readHandshakeResponse() error { data, err := cc.readPacket() if err != nil { return errors.Trace(err) } pos := 0 // capability cc.capability = binary.LittleEndian.Uint32(data[:4]) pos += 4 // skip max packet size pos += 4 // charset, skip, if you want to use another charset, use set names cc.collation = data[pos] pos++ // skip reserved 23[00] pos += 23 // user name cc.user = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) pos += len(cc.user) + 1 // auth length and auth authLen := int(data[pos]) pos++ auth := data[pos : pos+authLen] checkAuth := calcPassword(cc.salt, []byte(cc.server.cfgGetPwd(cc.user))) if !bytes.Equal(auth, checkAuth) && !cc.server.skipAuth() { return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.conn.RemoteAddr().String(), cc.user, "Yes")) } pos += authLen if cc.capability|mysql.ClientConnectWithDB > 0 { if len(data[pos:]) == 0 { return nil } idx := bytes.IndexByte(data[pos:], 0) cc.dbname = string(data[pos : pos+idx]) } return nil }
func (cc *clientConn) handleStmtExecute(data []byte) (err error) { if len(data) < 9 { return mysql.ErrMalformPacket } pos := 0 stmtID := binary.LittleEndian.Uint32(data[0:4]) pos += 4 stmt := cc.ctx.GetStatement(int(stmtID)) if stmt == nil { return mysql.NewDefaultError(mysql.ErUnknownStmtHandler, strconv.FormatUint(uint64(stmtID), 10), "stmt_execute") } flag := data[pos] pos++ //now we only support CURSOR_TYPE_NO_CURSOR flag if flag != 0 { return mysql.NewError(mysql.ErUnknownError, fmt.Sprintf("unsupported flag %d", flag)) } //skip iteration-count, always 1 pos += 4 var ( nullBitmaps []byte paramTypes []byte paramValues []byte ) numParams := stmt.NumParams() args := make([]interface{}, numParams) if numParams > 0 { nullBitmapLen := (numParams + 7) >> 3 if len(data) < (pos + nullBitmapLen + 1) { return mysql.ErrMalformPacket } nullBitmaps = data[pos : pos+nullBitmapLen] pos += nullBitmapLen //new param bound flag if data[pos] == 1 { pos++ if len(data) < (pos + (numParams << 1)) { return mysql.ErrMalformPacket } paramTypes = data[pos : pos+(numParams<<1)] pos += (numParams << 1) paramValues = data[pos:] } err = parseStmtArgs(args, stmt.BoundParams(), nullBitmaps, paramTypes, paramValues) if err != nil { return errors.Trace(err) } } rs, err := stmt.Execute(args...) if err != nil { return errors.Trace(err) } if rs == nil { return errors.Trace(cc.writeOK()) } return errors.Trace(cc.writeResultset(rs, true)) }