Example #1
0
// TestCommands runs every case in the package-level tests table directly
// against its command implementation (no network protocol involved) and
// compares the reply with t.response.
//
// For each case the keys are locked and the command is called. Commands that
// write get a fresh write batch, which is committed only when the command did
// not return an error; this matches the server's handling of write batches.
// Afterwards the keys are unlocked. Before the comparison:
//   - a *cmdReplyStream reply is drained into a []interface{};
//   - a rawReply status line such as "+OK\r\n" is reduced to its text ("OK").
func (s CommandSuite) TestCommands(c *C) {
	for _, t := range tests {
		cmd, ok := commands[t.command]
		if !ok {
			c.Fatalf("test table references unknown command %q", t.command)
		}

		// Arguments are space separated; a positive arity caps the number of
		// pieces so the last argument may itself contain spaces.
		var args [][]byte
		if t.args != "" {
			if cmd.arity > 0 {
				args = bytes.SplitN([]byte(t.args), []byte(" "), cmd.arity)
			} else {
				args = bytes.Split([]byte(t.args), []byte(" "))
			}
		}

		res := func() interface{} {
			cmd.lockKeys(args)
			// Deferred so the keys are released (and the batch closed) even
			// when the assertion below fails and stops the test; otherwise a
			// failure here would leave the keys locked for later tests.
			defer cmd.unlockKeys(args)

			var wb *levigo.WriteBatch
			if cmd.writes {
				wb = levigo.NewWriteBatch()
				defer wb.Close()
			}

			res := cmd.function(args, wb)
			if cmd.writes {
				// Like the server, only commit the batch when the command
				// succeeded, so partial writes from a failed command are dropped.
				if _, isErr := res.(error); !isErr {
					c.Assert(DB.Write(DefaultWriteOptions, wb), IsNil)
				}
			}
			return res
		}()

		if stream, ok := res.(*cmdReplyStream); ok {
			items := make([]interface{}, 0, int(stream.size))
			for item := range stream.items {
				items = append(items, item)
			}
			res = items
		}
		// "+<text>\r\n" needs at least 3 bytes; guard before indexing/slicing.
		if reply, ok := res.(rawReply); ok && len(reply) >= 3 && reply[0] == '+' {
			res = string(reply[1 : len(reply)-2])
		}
		c.Assert(res, DeepEquals, t.response, Commentf("%s %s, obtained=%s expected=%s", t.command, t.args, res, t.response))
	}
}
Example #2
0
// protocolHandler serves a single client connection. It repeatedly reads one
// command from c.r, executes it and writes the reply to c.w. It returns when
// a read fails, when the client sends malformed input, or when a command's
// batch write fails.
//
// Two request encodings are accepted:
//   - the multi-bulk form: "*<argc>\r\n" followed by argc bulk strings, each
//     encoded as "$<len>\r\n<bytes>\r\n";
//   - the legacy inline form: any line not starting with '*', split on
//     whitespace.
//
// After a protocol error has been reported to the client, the handler returns
// instead of trying to resynchronise with the byte stream.
// NOTE(review): confirm the caller flushes c.w before closing the connection;
// otherwise the final error reply may never reach the client.
func protocolHandler(c *client) {
	// Upper bounds on client-supplied lengths, so an untrusted client cannot
	// force huge allocations. The values follow Redis's defaults
	// (1M arguments per command, 512MB per bulk string).
	const (
		maxArgCount = 1024 * 1024
		maxBulkLen  = 512 * 1024 * 1024
	)

	// readLength reads a length header such as "*3\r\n" or "$5\r\n" whose first
	// byte must be prefix, and returns the integer it carries. ok is false when
	// the connection should be dropped: either a read failed, or the header was
	// malformed or outside [0, limit]. In the malformed case a protocol error
	// has already been written to the client.
	readLength := func(prefix byte, limit int) (length int, ok bool) {
		b, err := c.r.ReadByte()
		if err != nil {
			return 0, false
		}
		if b != prefix {
			writeProtocolError(c.w, "invalid length")
			return 0, false
		}
		l, overflowed, err := c.r.ReadLine() // Read bytes will look like "123"
		if err != nil {
			return 0, false
		}
		if overflowed {
			writeProtocolError(c.w, "length line too long")
			return 0, false
		}
		if len(l) == 0 {
			writeProtocolError(c.w, "missing length")
			return 0, false
		}
		n, err := bconv.Atoi(l)
		if err != nil {
			writeProtocolError(c.w, "length is not a valid integer")
			return 0, false
		}
		if n < 0 || n > limit {
			writeProtocolError(c.w, "length out of range")
			return 0, false
		}
		return n, true
	}

	// runCommand looks up args[0] (case-insensitively), checks the argument
	// count against the command's arity and runs it, writing the reply or an
	// error to the client. It returns a non-nil error only when committing
	// the command's write batch fails; the caller then drops the connection.
	runCommand := func(args [][]byte) (err error) {
		if len(args) == 0 {
			writeProtocolError(c.w, "missing command")
			return
		}

		// lookup the command
		command, ok := commands[UnsafeBytesToString(bytes.ToLower(args[0]))]
		if !ok {
			writeError(c.w, "unknown command '"+string(args[0])+"'")
			return
		}

		// check command arity, negative arity means >= n
		if (command.arity < 0 && len(args)-1 < -command.arity) || (command.arity >= 0 && len(args)-1 > command.arity) {
			writeError(c.w, "wrong number of arguments for '"+string(args[0])+"' command")
			return
		}

		// call the command and respond
		var wb *levigo.WriteBatch
		if command.writes {
			wb = levigo.NewWriteBatch()
			defer wb.Close()
		}
		command.lockKeys(args[1:])
		res := command.function(args[1:], wb)
		if command.writes {
			if _, isErr := res.(error); !isErr { // only write the batch if the return value is not an error
				err = DB.Write(DefaultWriteOptions, wb)
			}
		}
		// Release the keys before branching on the write result. Returning
		// early on a failed write without unlocking would leave these keys
		// locked for every other client.
		command.unlockKeys(args[1:])
		if err != nil {
			writeError(c.w, "data write error: "+err.Error())
			return
		}
		writeReply(c.w, res)

		return
	}

	// processInline reads one inline command line and runs it. bytes.Fields
	// drops the line terminator ("\r\n" or a bare "\n") and any extra
	// whitespace, so short or "\n"-only lines cannot cause an out-of-range
	// slice. Blank lines are ignored, as Redis does.
	processInline := func() error {
		line, err := c.r.ReadBytes('\n')
		if err != nil {
			return err
		}
		fields := bytes.Fields(line)
		if len(fields) == 0 {
			return nil
		}
		return runCommand(fields)
	}

	scratch := make([]byte, 2)
	var args [][]byte
	// Client event loop, each iteration handles a command
	for {
		// check if we're using the old inline protocol
		b, err := c.r.Peek(1)
		if err != nil {
			return
		}
		if b[0] != '*' {
			if err = processInline(); err != nil {
				return
			}
			continue
		}

		// Step 1: get the number of arguments
		argCount, ok := readLength('*', maxArgCount)
		if !ok {
			return
		}

		// Reuse the slice header from the previous command; every argument
		// gets freshly allocated bytes below.
		args = args[:0]
		for i := 0; i < argCount; i++ {
			length, ok := readLength('$', maxBulkLen)
			if !ok {
				return
			}

			// Read the argument bytes
			arg := make([]byte, length)
			if _, err = io.ReadFull(c.r, arg); err != nil {
				return
			}

			// Each argument is followed by "\r\n". ReadFull guarantees both
			// bytes are consumed; a single Read may return fewer.
			if _, err = io.ReadFull(c.r, scratch); err != nil {
				return
			}
			if scratch[0] != '\r' || scratch[1] != '\n' {
				writeProtocolError(c.w, "expected CRLF after argument")
				return
			}
			args = append(args, arg)
		}

		if err = runCommand(args); err != nil {
			return
		}
	}
}