func (cc *clientConn) readHandshakeResponse() error { data, err := cc.readPacket() if err != nil { return errors.Trace(err) } pos := 0 // capability cc.capability = binary.LittleEndian.Uint32(data[:4]) pos += 4 // skip max packet size pos += 4 // charset, skip, if you want to use another charset, use set names cc.collation = data[pos] pos++ // skip reserved 23[00] pos += 23 // user name cc.user = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) pos += len(cc.user) + 1 // auth length and auth authLen := int(data[pos]) pos++ auth := data[pos : pos+authLen] pos += authLen if cc.capability&mysql.ClientConnectWithDB > 0 { if len(data[pos:]) > 0 { idx := bytes.IndexByte(data[pos:], 0) cc.dbname = string(data[pos : pos+idx]) } } // Open session and do auth cc.ctx, err = cc.server.driver.OpenCtx(cc.capability, uint8(cc.collation), cc.dbname) if err != nil { cc.Close() return errors.Trace(err) } if !cc.server.skipAuth() { // Do Auth addr := cc.conn.RemoteAddr().String() host, _, err1 := net.SplitHostPort(addr) if err1 != nil { return errors.Trace(mysql.NewErr(mysql.ErrAccessDenied, cc.user, addr, "Yes")) } user := fmt.Sprintf("%s@%s", cc.user, host) if !cc.ctx.Auth(user, auth, cc.salt) { return errors.Trace(mysql.NewErr(mysql.ErrAccessDenied, cc.user, host, "Yes")) } } return nil }
func (s *ShowPlan) getTable(ctx context.Context) (table.Table, error) { is := sessionctx.GetDomain(ctx).InfoSchema() dbName := model.NewCIStr(s.DBName) if !is.SchemaExists(dbName) { // MySQL returns no such table here if database doesn't exist. return nil, errors.Trace(mysql.NewErr(mysql.ErrNoSuchTable, s.DBName, s.TableName)) } tbName := model.NewCIStr(s.TableName) tb, err := is.TableByName(dbName, tbName) if err != nil { return nil, errors.Trace(mysql.NewErr(mysql.ErrNoSuchTable, s.DBName, s.TableName)) } return tb, nil }
// AppendParam implements IStatement AppendParam method. func (ts *TiDBStatement) AppendParam(paramID int, data []byte) error { if paramID >= len(ts.boundParams) { return mysql.NewErr(mysql.ErrWrongArguments, "stmt_send_longdata") } ts.boundParams[paramID] = append(ts.boundParams[paramID], data...) return nil }
func (cc *clientConn) handleStmtReset(data []byte) (err error) { if len(data) < 4 { return mysql.ErrMalformPacket } stmtID := int(binary.LittleEndian.Uint32(data[0:4])) stmt := cc.ctx.GetStatement(stmtID) if stmt == nil { return mysql.NewErr(mysql.ErrUnknownStmtHandler, strconv.Itoa(stmtID), "stmt_reset") } stmt.Reset() return cc.writeOK() }
func (cc *clientConn) readHandshakeResponse() error { data, err := cc.readPacket() if err != nil { return errors.Trace(err) } var p handshakeResponse41 if err = handshakeResponseFromData(&p, data); err != nil { return errors.Trace(err) } cc.capability = p.Capability & defaultCapability cc.user = p.User cc.dbname = p.DBName cc.collation = p.Collation cc.attrs = p.Attrs // Open session and do auth cc.ctx, err = cc.server.driver.OpenCtx(uint64(cc.connectionID), cc.capability, uint8(cc.collation), cc.dbname) if err != nil { cc.Close() return errors.Trace(err) } if !cc.server.skipAuth() { // Do Auth addr := cc.conn.RemoteAddr().String() host, _, err1 := net.SplitHostPort(addr) if err1 != nil { return errors.Trace(mysql.NewErr(mysql.ErrAccessDenied, cc.user, addr, "Yes")) } user := fmt.Sprintf("%s@%s", cc.user, host) if !cc.ctx.Auth(user, p.Auth, cc.salt) { return errors.Trace(mysql.NewErr(mysql.ErrAccessDenied, cc.user, host, "Yes")) } } return nil }
func (cc *clientConn) handleStmtSendLongData(data []byte) (err error) { if len(data) < 6 { return mysql.ErrMalformPacket } stmtID := int(binary.LittleEndian.Uint32(data[0:4])) stmt := cc.ctx.GetStatement(stmtID) if stmt == nil { return mysql.NewErr(mysql.ErrUnknownStmtHandler, strconv.Itoa(stmtID), "stmt_send_longdata") } paramID := int(binary.LittleEndian.Uint16(data[4:6])) return stmt.AppendParam(paramID, data[6:]) }
func (cc *clientConn) handleStmtExecute(data []byte) (err error) { if len(data) < 9 { return mysql.ErrMalformPacket } pos := 0 stmtID := binary.LittleEndian.Uint32(data[0:4]) pos += 4 stmt := cc.ctx.GetStatement(int(stmtID)) if stmt == nil { return mysql.NewErr(mysql.ErrUnknownStmtHandler, strconv.FormatUint(uint64(stmtID), 10), "stmt_execute") } flag := data[pos] pos++ //now we only support CURSOR_TYPE_NO_CURSOR flag if flag != 0 { return mysql.NewErrf(mysql.ErrUnknown, "unsupported flag %d", flag) } //skip iteration-count, always 1 pos += 4 var ( nullBitmaps []byte paramTypes []byte paramValues []byte ) numParams := stmt.NumParams() args := make([]interface{}, numParams) if numParams > 0 { nullBitmapLen := (numParams + 7) >> 3 if len(data) < (pos + nullBitmapLen + 1) { return mysql.ErrMalformPacket } nullBitmaps = data[pos : pos+nullBitmapLen] pos += nullBitmapLen //new param bound flag if data[pos] == 1 { pos++ if len(data) < (pos + (numParams << 1)) { return mysql.ErrMalformPacket } paramTypes = data[pos : pos+(numParams<<1)] pos += (numParams << 1) paramValues = data[pos:] } err = parseStmtArgs(args, stmt.BoundParams(), nullBitmaps, paramTypes, paramValues) if err != nil { return errors.Trace(err) } } rs, err := stmt.Execute(args...) if err != nil { return errors.Trace(err) } if rs == nil { return errors.Trace(cc.writeOK()) } return errors.Trace(cc.writeResultset(rs, true)) }