コード例 #1
0
ファイル: conn.go プロジェクト: remotesyssupport/tidb
func (cc *clientConn) readHandshakeResponse() error {
	data, err := cc.readPacket()
	if err != nil {
		return errors.Trace(err)
	}

	pos := 0
	// capability
	cc.capability = binary.LittleEndian.Uint32(data[:4])
	pos += 4
	// skip max packet size
	pos += 4
	// charset, skip, if you want to use another charset, use set names
	cc.collation = data[pos]
	pos++
	// skip reserved 23[00]
	pos += 23
	// user name
	cc.user = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)])
	pos += len(cc.user) + 1
	// auth length and auth
	authLen := int(data[pos])
	pos++
	auth := data[pos : pos+authLen]
	pos += authLen
	if cc.capability|mysql.ClientConnectWithDB > 0 {
		if len(data[pos:]) > 0 {
			idx := bytes.IndexByte(data[pos:], 0)
			cc.dbname = string(data[pos : pos+idx])
		}
	}
	// Open session and do auth
	cc.ctx, err = cc.server.driver.OpenCtx(cc.capability, uint8(cc.collation), cc.dbname)
	if err != nil {
		cc.Close()
		return errors.Trace(err)
	}
	if !cc.server.skipAuth() {
		// Do Auth
		addr := cc.conn.RemoteAddr().String()
		host, _, err1 := net.SplitHostPort(addr)
		if err1 != nil {
			return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.user, addr, "Yes"))
		}
		user := fmt.Sprintf("%s@%s", cc.user, host)
		if !cc.ctx.Auth(user, auth, cc.salt) {
			return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.user, host, "Yes"))
		}
	}
	return nil
}
コード例 #2
0
ファイル: driver_tidb.go プロジェクト: awesomeleo/tidb
// AppendParam implements IStatement AppendParam method.
func (ts *TiDBStatement) AppendParam(paramID int, data []byte) error {
	if paramID >= len(ts.boundParams) {
		return mysql.NewDefaultError(mysql.ErWrongArguments, "stmt_send_longdata")
	}
	ts.boundParams[paramID] = append(ts.boundParams[paramID], data...)
	return nil
}
コード例 #3
0
ファイル: conn_stmt.go プロジェクト: nengwang/tidb
func (cc *clientConn) handleStmtReset(data []byte) (err error) {
	if len(data) < 4 {
		return mysql.ErrMalformPacket
	}

	stmtID := int(binary.LittleEndian.Uint32(data[0:4]))
	stmt := cc.ctx.GetStatement(stmtID)
	if stmt == nil {
		return mysql.NewDefaultError(mysql.ErUnknownStmtHandler,
			strconv.Itoa(stmtID), "stmt_reset")
	}
	stmt.Reset()
	return cc.writeOK()
}
コード例 #4
0
ファイル: conn_stmt.go プロジェクト: nengwang/tidb
func (cc *clientConn) handleStmtSendLongData(data []byte) (err error) {
	if len(data) < 6 {
		return mysql.ErrMalformPacket
	}

	stmtID := int(binary.LittleEndian.Uint32(data[0:4]))

	stmt := cc.ctx.GetStatement(stmtID)
	if stmt == nil {
		return mysql.NewDefaultError(mysql.ErUnknownStmtHandler,
			strconv.Itoa(stmtID), "stmt_send_longdata")
	}

	paramID := int(binary.LittleEndian.Uint16(data[4:6]))
	return stmt.AppendParam(paramID, data[6:])
}
コード例 #5
0
ファイル: conn.go プロジェクト: hxiaodon/tidb
func (cc *clientConn) readHandshakeResponse() error {
	data, err := cc.readPacket()
	if err != nil {
		return errors.Trace(err)
	}

	pos := 0
	// capability
	cc.capability = binary.LittleEndian.Uint32(data[:4])
	pos += 4
	// skip max packet size
	pos += 4
	// charset, skip, if you want to use another charset, use set names
	cc.collation = data[pos]
	pos++
	// skip reserved 23[00]
	pos += 23
	// user name
	cc.user = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)])
	pos += len(cc.user) + 1
	// auth length and auth
	authLen := int(data[pos])
	pos++
	auth := data[pos : pos+authLen]
	checkAuth := calcPassword(cc.salt, []byte(cc.server.cfgGetPwd(cc.user)))
	if !bytes.Equal(auth, checkAuth) && !cc.server.skipAuth() {
		return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.conn.RemoteAddr().String(), cc.user, "Yes"))
	}

	pos += authLen
	if cc.capability|mysql.ClientConnectWithDB > 0 {
		if len(data[pos:]) == 0 {
			return nil
		}
		idx := bytes.IndexByte(data[pos:], 0)
		cc.dbname = string(data[pos : pos+idx])
	}

	return nil
}
コード例 #6
0
ファイル: conn_stmt.go プロジェクト: nengwang/tidb
func (cc *clientConn) handleStmtExecute(data []byte) (err error) {
	if len(data) < 9 {
		return mysql.ErrMalformPacket
	}

	pos := 0
	stmtID := binary.LittleEndian.Uint32(data[0:4])
	pos += 4

	stmt := cc.ctx.GetStatement(int(stmtID))
	if stmt == nil {
		return mysql.NewDefaultError(mysql.ErUnknownStmtHandler,
			strconv.FormatUint(uint64(stmtID), 10), "stmt_execute")
	}

	flag := data[pos]
	pos++
	//now we only support CURSOR_TYPE_NO_CURSOR flag
	if flag != 0 {
		return mysql.NewError(mysql.ErUnknownError, fmt.Sprintf("unsupported flag %d", flag))
	}

	//skip iteration-count, always 1
	pos += 4

	var (
		nullBitmaps []byte
		paramTypes  []byte
		paramValues []byte
	)
	numParams := stmt.NumParams()
	args := make([]interface{}, numParams)
	if numParams > 0 {
		nullBitmapLen := (numParams + 7) >> 3
		if len(data) < (pos + nullBitmapLen + 1) {
			return mysql.ErrMalformPacket
		}
		nullBitmaps = data[pos : pos+nullBitmapLen]
		pos += nullBitmapLen

		//new param bound flag
		if data[pos] == 1 {
			pos++
			if len(data) < (pos + (numParams << 1)) {
				return mysql.ErrMalformPacket
			}

			paramTypes = data[pos : pos+(numParams<<1)]
			pos += (numParams << 1)
			paramValues = data[pos:]
		}

		err = parseStmtArgs(args, stmt.BoundParams(), nullBitmaps, paramTypes, paramValues)
		if err != nil {
			return errors.Trace(err)
		}
	}
	rs, err := stmt.Execute(args...)
	if err != nil {
		return errors.Trace(err)
	}
	if rs == nil {
		return errors.Trace(cc.writeOK())
	}

	return errors.Trace(cc.writeResultset(rs, true))
}