func (c *Session) handleComStmtPrepare(sqlstmt string) error { stmt, err := parser.Parse(sqlstmt) if err != nil { log.Warningf(`parse sql "%s" error "%s"`, sqlstmt, err.Error()) return c.handleMySQLError( mysql.NewDefaultError(mysql.ER_SYNTAX_ERROR, err.Error())) } // Only a few statements supported by prepare statements // http://dev.mysql.com/worklog/task/?id=2871 switch v := stmt.(type) { case parser.ISelect, *parser.Insert, *parser.Update, *parser.Delete, *parser.Replace, parser.IDDLStatement, *parser.ShowTables, *parser.ShowColumns, *parser.ShowVariables, *parser.ShowIndex, *parser.Set, *parser.DescribeTable, *parser.Do: return c.prepare(v, sqlstmt) default: log.Warnf("statement %T[%s] not support prepare ops", stmt, sqlstmt) return c.handleMySQLError( mysql.NewDefaultError(mysql.ER_UNSUPPORTED_PS)) } }
func (session *Session) handleComStmtExecute(data []byte) error { if len(data) < 9 { return session.handleMySQLError(mysql.ErrMalformPkt) } pos := 0 id := binary.LittleEndian.Uint32(data[0:4]) pos += 4 stmt, ok := session.bc.stmts[id] if !ok { return mysql.NewDefaultError(mysql.ER_UNKNOWN_STMT_HANDLER, strconv.FormatUint(uint64(id), 10), "stmt_execute") } flag := data[pos] pos++ //now we only support CURSOR_TYPE_NO_CURSOR flag if flag != 0 { return mysql.NewDefaultError(mysql.ER_UNKNOWN_ERROR, fmt.Sprintf("unsupported flag %d", flag)) } //skip iteration-count, always 1 pos += 4 var err error switch stmt.SQL.(type) { case parser.ISelect, *parser.ShowTables, *parser.ShowVariables, *parser.ShowColumns, *parser.ShowIndex, *parser.DescribeTable: err = session.handleStmtQuery(stmt, data[pos:]) default: err = session.handleStmtExec(stmt, data[pos:]) } return err }
func (session *Session) handleComStmtSendLongData(data []byte) error { if len(data) < 6 { return session.handleMySQLError(mysql.ErrMalformPkt) } id := binary.LittleEndian.Uint32(data[0:4]) stmt, ok := session.bc.stmts[id] if !ok { return mysql.NewDefaultError(mysql.ER_UNKNOWN_STMT_HANDLER, strconv.FormatUint(uint64(id), 10), "stmt_send_longdata") } paramId := binary.LittleEndian.Uint16(data[4:6]) if paramId >= uint16(len(stmt.Params)) { return mysql.NewDefaultError(mysql.ER_WRONG_ARGUMENTS, "stmt_send_longdata") } stmt.SendLongData(int(paramId), data[6:]) return nil }
func (session *Session) useDB(db string) error { if session.cluster != nil { if session.cluster.DBName != db { return mysql.NewDefaultError(mysql.ER_BAD_DB_ERROR, db) } return nil } if _, err := session.config.GetClusterByDBName(db); err != nil { return mysql.NewDefaultError(mysql.ER_BAD_DB_ERROR, db) } else if session.cluster, err = cluster.New(session.user.ClusterName); err != nil { return err } if session.bc == nil { master, err := session.cluster.Master() if err != nil { return mysql.NewDefaultError(mysql.ER_BAD_DB_ERROR, db) } slave, err := session.cluster.Slave() if err != nil { slave = master } session.bc = &SqlConn{ master: master, slave: slave, stmts: make(map[uint32]*mysql.Stmt), tx: nil, session: session, } } return nil }
func (session *Session) handleComStmtReset(data []byte) error { if len(data) < 4 { return session.handleMySQLError(mysql.ErrMalformPkt) } id := binary.LittleEndian.Uint32(data[0:4]) stmt, ok := session.bc.stmts[id] if !ok { return mysql.NewDefaultError(mysql.ER_UNKNOWN_STMT_HANDLER, strconv.FormatUint(uint64(id), 10), "stmt_reset") } if rs, err := stmt.Reset(); err != nil { return session.handleMySQLError(err) } else { return session.fc.WriteOK(rs) } }
func (session *Session) dispatch(data []byte) (err error) { cmd := data[0] data = data[1:] defer func() { flush_error := session.fc.Flush() if err == nil { err = flush_error } }() switch cmd { case mysql.ComQuery: err = session.comQuery(hack.String(data)) case mysql.ComPing: err = session.fc.WriteOK(nil) case mysql.ComInitDB: if err := session.useDB(hack.String(data)); err != nil { err = session.handleMySQLError(err) } else { err = session.fc.WriteOK(nil) } case mysql.ComFieldList: err = session.handleFieldList(data) case mysql.ComStmtPrepare: err = session.handleComStmtPrepare(hack.String(data)) case mysql.ComStmtExecute: err = session.handleComStmtExecute(data) case mysql.ComStmtClose: err = session.handleComStmtClose(data) case mysql.ComStmtSendLongData: err = session.handleComStmtSendLongData(data) case mysql.ComStmtReset: err = session.handleComStmtReset(data) default: msg := fmt.Sprintf("command %d not supported now", cmd) log.Warnf(msg) err = mysql.NewDefaultError(mysql.ER_UNKNOWN_ERROR, msg) } return }