func (s *Stmt) write(args ...interface{}) error { paramsNum := s.params if len(args) != paramsNum { return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args)) } paramTypes := make([]byte, paramsNum<<1) paramValues := make([][]byte, paramsNum) //NULL-bitmap, length: (num-params+7) nullBitmap := make([]byte, (paramsNum+7)>>3) var length int = int(1 + 4 + 1 + 4 + ((paramsNum + 7) >> 3) + 1 + (paramsNum << 1)) var newParamBoundFlag byte = 0 for i := range args { if args[i] == nil { nullBitmap[i/8] |= (1 << (uint(i) % 8)) paramTypes[i<<1] = mysql.MYSQL_TYPE_NULL continue } newParamBoundFlag = 1 switch v := args[i].(type) { case int8: paramTypes[i<<1] = mysql.MYSQL_TYPE_TINY paramValues[i] = []byte{byte(v)} case int16: paramTypes[i<<1] = mysql.MYSQL_TYPE_SHORT paramValues[i] = mysql.Uint16ToBytes(uint16(v)) case int32: paramTypes[i<<1] = mysql.MYSQL_TYPE_LONG paramValues[i] = mysql.Uint32ToBytes(uint32(v)) case int: paramTypes[i<<1] = mysql.MYSQL_TYPE_LONGLONG paramValues[i] = mysql.Uint64ToBytes(uint64(v)) case int64: paramTypes[i<<1] = mysql.MYSQL_TYPE_LONGLONG paramValues[i] = mysql.Uint64ToBytes(uint64(v)) case uint8: paramTypes[i<<1] = mysql.MYSQL_TYPE_TINY paramTypes[(i<<1)+1] = 0x80 paramValues[i] = []byte{v} case uint16: paramTypes[i<<1] = mysql.MYSQL_TYPE_SHORT paramTypes[(i<<1)+1] = 0x80 paramValues[i] = mysql.Uint16ToBytes(uint16(v)) case uint32: paramTypes[i<<1] = mysql.MYSQL_TYPE_LONG paramTypes[(i<<1)+1] = 0x80 paramValues[i] = mysql.Uint32ToBytes(uint32(v)) case uint: paramTypes[i<<1] = mysql.MYSQL_TYPE_LONGLONG paramTypes[(i<<1)+1] = 0x80 paramValues[i] = mysql.Uint64ToBytes(uint64(v)) case uint64: paramTypes[i<<1] = mysql.MYSQL_TYPE_LONGLONG paramTypes[(i<<1)+1] = 0x80 paramValues[i] = mysql.Uint64ToBytes(uint64(v)) case bool: paramTypes[i<<1] = mysql.MYSQL_TYPE_TINY if v { paramValues[i] = []byte{1} } else { paramValues[i] = []byte{0} } case float32: paramTypes[i<<1] = mysql.MYSQL_TYPE_FLOAT paramValues[i] = mysql.Uint32ToBytes(math.Float32bits(v)) case float64: paramTypes[i<<1] = mysql.MYSQL_TYPE_DOUBLE paramValues[i] = mysql.Uint64ToBytes(math.Float64bits(v)) case string: paramTypes[i<<1] = mysql.MYSQL_TYPE_STRING paramValues[i] = append(mysql.PutLengthEncodedInt(uint64(len(v))), v...) case []byte: paramTypes[i<<1] = mysql.MYSQL_TYPE_STRING paramValues[i] = append(mysql.PutLengthEncodedInt(uint64(len(v))), v...) default: return fmt.Errorf("invalid argument type %T", args[i]) } length += len(paramValues[i]) } data := make([]byte, 4, 4+length) data = append(data, mysql.COM_STMT_EXECUTE) data = append(data, byte(s.id), byte(s.id>>8), byte(s.id>>16), byte(s.id>>24)) //flag: CURSOR_TYPE_NO_CURSOR data = append(data, 0x00) //iteration-count, always 1 data = append(data, 1, 0, 0, 0) if s.params > 0 { data = append(data, nullBitmap...) //new-params-bound-flag data = append(data, newParamBoundFlag) if newParamBoundFlag == 1 { //type of each parameter, length: num-params * 2 data = append(data, paramTypes...) //value of each parameter for _, v := range paramValues { data = append(data, v...) } } } s.conn.pkg.Sequence = 0 return s.conn.writePacket(data) }
func (c *ClientConn) writePrepare(s *Stmt) error { var err error data := make([]byte, 4, 128) total := make([]byte, 0, 1024) //status ok data = append(data, 0) //stmt id data = append(data, mysql.Uint32ToBytes(s.id)...) //number columns data = append(data, mysql.Uint16ToBytes(uint16(s.columns))...) //number params data = append(data, mysql.Uint16ToBytes(uint16(s.params))...) //filter [00] data = append(data, 0) //warning count data = append(data, 0, 0) total, err = c.writePacketBatch(total, data, false) if err != nil { return err } if s.params > 0 { for i := 0; i < s.params; i++ { data = data[0:4] data = append(data, []byte(paramFieldData)...) total, err = c.writePacketBatch(total, data, false) if err != nil { return err } } total, err = c.writeEOFBatch(total, c.status, false) if err != nil { return err } } if s.columns > 0 { for i := 0; i < s.columns; i++ { data = data[0:4] data = append(data, []byte(columnFieldData)...) total, err = c.writePacketBatch(total, data, false) if err != nil { return err } } total, err = c.writeEOFBatch(total, c.status, false) if err != nil { return err } } total, err = c.writePacketBatch(total, nil, true) total = nil if err != nil { return err } return nil }