예제 #1
0
// writePrepare sends the COM_STMT_PREPARE_OK response for s.
//
// It first writes the OK header packet. Then, when s.params > 0, it writes
// one parameter-definition packet per parameter followed by an EOF packet.
// Then, when s.columns > 0, it writes one column-definition packet per column
// followed by an EOF packet. Every packet is assembled in a single reusable
// buffer whose first four bytes are reserved for the packet header
// (presumably filled in by writePacket — confirm). The first write error is
// returned unchanged.
func (c *ClientConn) writePrepare(s *Stmt) error {
	buf := make([]byte, 4, 128)

	// OK header: status byte (0 = OK), 4-byte statement id, 2-byte column
	// count, 2-byte parameter count, one reserved filler byte and a 2-byte
	// warning count (always zero here).
	buf = append(buf, 0)
	buf = append(buf, mysql.Uint32ToBytes(s.id)...)
	buf = append(buf, mysql.Uint16ToBytes(uint16(s.columns))...)
	buf = append(buf, mysql.Uint16ToBytes(uint16(s.params))...)
	buf = append(buf, 0, 0, 0)

	if err := c.writePacket(buf); err != nil {
		return err
	}

	// sendDefinitions writes def in count separate packets and closes the run
	// with an EOF packet carrying the connection status. Callers only invoke
	// it for a positive count, so an empty section emits nothing at all.
	sendDefinitions := func(count int, def []byte) error {
		for n := 0; n < count; n++ {
			buf = append(buf[:4], def...)
			if err := c.writePacket(buf); err != nil {
				return err
			}
		}
		return c.writeEOF(c.status)
	}

	if s.params > 0 {
		if err := sendDefinitions(s.params, []byte(paramFieldData)); err != nil {
			return err
		}
	}

	if s.columns > 0 {
		if err := sendDefinitions(s.columns, []byte(columnFieldData)); err != nil {
			return err
		}
	}

	return nil
}
예제 #2
0
// write encodes args as the payload of a COM_STMT_EXECUTE command for s and
// sends it on s.conn after resetting s.conn.pkg.Sequence to 0.
//
// len(args) must equal s.params. Each argument may be one of the following:
//   - an untyped nil, which is sent as NULL through the null bitmap;
//   - int8, int16, int32, int or int64;
//   - uint8, uint16, uint32, uint or uint64;
//   - bool, float32, float64, string or []byte.
//
// Any other type yields an error before anything is written. The
// parameter-type block and the encoded values are only emitted when at least
// one argument is non-nil.
func (s *Stmt) write(args ...interface{}) error {
	n := s.params
	if len(args) != n {
		return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args))
	}

	bitmapLen := (n + 7) >> 3
	nullBitmap := make([]byte, bitmapLen)
	// Two bytes per parameter: the type code, then a flag byte that is set
	// to 0x80 for unsigned integer arguments.
	paramTypes := make([]byte, 2*n)
	paramValues := make([][]byte, n)

	// Payload size, used only to pre-size the output buffer. It covers the
	// command byte, statement id, flags, iteration count, null bitmap, bound
	// flag and type block, plus every encoded value (added in the loop).
	size := 1 + 4 + 1 + 4 + bitmapLen + 1 + 2*n
	bound := false

	for i, arg := range args {
		if arg == nil {
			nullBitmap[i/8] |= 1 << (uint(i) % 8)
			paramTypes[2*i] = mysql.MYSQL_TYPE_NULL
			continue
		}
		bound = true

		var (
			typ      byte
			unsigned bool
			val      []byte
		)
		switch v := arg.(type) {
		case int8:
			typ, val = mysql.MYSQL_TYPE_TINY, []byte{byte(v)}
		case int16:
			typ, val = mysql.MYSQL_TYPE_SHORT, mysql.Uint16ToBytes(uint16(v))
		case int32:
			typ, val = mysql.MYSQL_TYPE_LONG, mysql.Uint32ToBytes(uint32(v))
		case int:
			typ, val = mysql.MYSQL_TYPE_LONGLONG, mysql.Uint64ToBytes(uint64(v))
		case int64:
			typ, val = mysql.MYSQL_TYPE_LONGLONG, mysql.Uint64ToBytes(uint64(v))
		case uint8:
			typ, unsigned, val = mysql.MYSQL_TYPE_TINY, true, []byte{v}
		case uint16:
			typ, unsigned, val = mysql.MYSQL_TYPE_SHORT, true, mysql.Uint16ToBytes(v)
		case uint32:
			typ, unsigned, val = mysql.MYSQL_TYPE_LONG, true, mysql.Uint32ToBytes(v)
		case uint:
			typ, unsigned, val = mysql.MYSQL_TYPE_LONGLONG, true, mysql.Uint64ToBytes(uint64(v))
		case uint64:
			typ, unsigned, val = mysql.MYSQL_TYPE_LONGLONG, true, mysql.Uint64ToBytes(v)
		case bool:
			typ, val = mysql.MYSQL_TYPE_TINY, []byte{0}
			if v {
				val[0] = 1
			}
		case float32:
			typ, val = mysql.MYSQL_TYPE_FLOAT, mysql.Uint32ToBytes(math.Float32bits(v))
		case float64:
			typ, val = mysql.MYSQL_TYPE_DOUBLE, mysql.Uint64ToBytes(math.Float64bits(v))
		case string:
			// Length prefix followed by the raw bytes.
			typ, val = mysql.MYSQL_TYPE_STRING, append(mysql.PutLengthEncodedInt(uint64(len(v))), v...)
		case []byte:
			typ, val = mysql.MYSQL_TYPE_STRING, append(mysql.PutLengthEncodedInt(uint64(len(v))), v...)
		default:
			return fmt.Errorf("invalid argument type %T", arg)
		}

		paramTypes[2*i] = typ
		if unsigned {
			paramTypes[2*i+1] = 0x80
		}
		paramValues[i] = val
		size += len(val)
	}

	data := make([]byte, 4, 4+size)
	data = append(data, mysql.COM_STMT_EXECUTE)
	data = append(data, byte(s.id), byte(s.id>>8), byte(s.id>>16), byte(s.id>>24))
	// Flags byte (CURSOR_TYPE_NO_CURSOR), then a 4-byte iteration count of 1.
	data = append(data, 0x00, 1, 0, 0, 0)

	if n > 0 {
		data = append(data, nullBitmap...)
		if bound {
			// new-params-bound flag, then the type block and every value.
			data = append(data, 1)
			data = append(data, paramTypes...)
			for _, val := range paramValues {
				data = append(data, val...)
			}
		} else {
			data = append(data, 0)
		}
	}

	s.conn.pkg.Sequence = 0
	return s.conn.writePacket(data)
}