Exemple #1
0
// cmdSetbit handles SETBIT key offset value.
//
// It validates the argument count, the authentication state, the bit offset
// (a non-negative integer) and the new bit (0 or 1). The actual update runs
// through withTx: the string stored at key is grown with zero bytes when the
// offset lies beyond its end, the addressed bit is set or cleared, and the
// previous value of that bit is written as the integer reply.
func (m *Miniredis) cmdSetbit(out *redeo.Responder, r *redeo.Request) error {
	if len(r.Args) != 3 {
		setDirty(r.Client())
		return r.WrongNumberOfArgs()
	}
	if !m.handleAuth(r.Client(), out) {
		return nil
	}

	key := r.Args[0]

	bit, err := strconv.Atoi(r.Args[1])
	if err != nil || bit < 0 {
		setDirty(r.Client())
		return redeo.ClientError("bit offset is not an integer or out of range")
	}

	newBit, err := strconv.Atoi(r.Args[2])
	if err != nil || newBit < 0 || newBit > 1 {
		setDirty(r.Client())
		return redeo.ClientError("bit is not an integer or out of range")
	}

	return withTx(m, out, r, func(out *redeo.Responder, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if typ, found := db.keys[key]; found && typ != "string" {
			out.WriteErrorString(msgWrongType)
			return
		}

		buf := []byte(db.stringKeys[key])
		byteIdx, bitIdx := bit/8, bit%8
		if byteIdx >= len(buf) {
			// The value is too short to hold the addressed byte: zero-extend it.
			grown := make([]byte, byteIdx+1)
			copy(grown, buf)
			buf = grown
		}

		// Remember the current bit before overwriting it; it is the reply.
		previous := 0
		if toBits(buf[byteIdx])[bitIdx] {
			previous = 1
		}

		// Offset 0 addresses the most significant bit of the byte.
		mask := byte(1) << uint8(7-bitIdx)
		if newBit == 1 {
			buf[byteIdx] |= mask
		} else {
			buf[byteIdx] &^= mask
		}

		db.stringSet(key, string(buf))
		out.WriteInt(previous)
	})
}
// cmdDiscard handles DISCARD.
//
// The command takes no arguments. When the connection has an open
// transaction it is stopped and OK is written; otherwise the client gets a
// "DISCARD without MULTI" error.
func (m *Miniredis) cmdDiscard(out *redeo.Responder, r *redeo.Request) error {
	if len(r.Args) != 0 {
		setDirty(r.Client())
		return r.WrongNumberOfArgs()
	}
	if !m.handleAuth(r.Client(), out) {
		return nil
	}

	ctx := getCtx(r.Client())
	if inTx(ctx) {
		stopTx(ctx)
		out.WriteOK()
		return nil
	}
	return redeo.ClientError("DISCARD without MULTI")
}
// cmdMulti handles MULTI.
//
// The command takes no arguments. It starts a transaction on the connection
// and writes OK, or returns a client error when a transaction is already
// open.
func (m *Miniredis) cmdMulti(out *redeo.Responder, r *redeo.Request) error {
	if len(r.Args) != 0 {
		// Mark the connection dirty on an arity error, like every other
		// command handler does, so that a malformed MULTI sent inside an open
		// transaction makes the following EXEC abort.
		setDirty(r.Client())
		return r.WrongNumberOfArgs()
	}
	if !m.handleAuth(r.Client(), out) {
		return nil
	}

	ctx := getCtx(r.Client())
	if inTx(ctx) {
		return redeo.ClientError("MULTI calls can not be nested")
	}

	startTx(ctx)
	out.WriteOK()
	return nil
}
// cmdExec handles EXEC.
//
// The command takes no arguments and requires an open transaction. The
// outcome is one of:
//   - the transaction was flagged dirty by an earlier error: an EXECABORT
//     error is written and the transaction is discarded;
//   - a WATCHed key has a newer version than the one recorded at WATCH time:
//     the transaction is discarded and an empty reply is written;
//   - otherwise every queued command callback runs, in order, while the
//     server lock is held, and the transaction is closed.
func (m *Miniredis) cmdExec(out *redeo.Responder, r *redeo.Request) error {
	if len(r.Args) != 0 {
		setDirty(r.Client())
		return r.WrongNumberOfArgs()
	}
	if !m.handleAuth(r.Client(), out) {
		return nil
	}

	ctx := getCtx(r.Client())

	if !inTx(ctx) {
		return redeo.ClientError("EXEC without MULTI")
	}

	if dirtyTx(ctx) {
		// A failed EXEC also ends the transaction. Without this the
		// connection would stay in MULTI state and keep queueing commands
		// into a transaction that can never be executed.
		stopTx(ctx)
		out.WriteErrorString("EXECABORT Transaction discarded because of previous errors.")
		return nil
	}

	m.Lock()
	defer m.Unlock()

	// Check WATCHed keys: any key modified since it was watched aborts the
	// whole transaction.
	for t, version := range ctx.watch {
		if m.db(t.db).keyVersion[t.key] > version {
			stopTx(ctx)
			// NOTE(review): Redis replies with a null array here; confirm
			// that WriteBulkLen(0) is the reply clients are expected to see.
			out.WriteBulkLen(0)
			return nil
		}
	}

	// Reply header first, then let every queued command write its own reply.
	out.WriteBulkLen(len(ctx.transaction))
	for _, cb := range ctx.transaction {
		cb(out, ctx)
	}
	// We're done.
	stopTx(ctx)
	return nil
}
Exemple #5
0
// cmdSetrange handles SETRANGE key offset value.
//
// It validates the argument count, the authentication state and the offset
// (a non-negative integer). The update runs through withTx: the string at
// key is overwritten with value starting at offset, zero-padding the string
// first when it is too short, and the resulting length is written as the
// integer reply.
//
// An empty value is a no-op, as in Redis: the key is neither created nor
// padded nor rewritten, and the reply is the current length of the string
// (0 for a missing key).
func (m *Miniredis) cmdSetrange(out *redeo.Responder, r *redeo.Request) error {
	if len(r.Args) != 3 {
		setDirty(r.Client())
		return r.WrongNumberOfArgs()
	}
	if !m.handleAuth(r.Client(), out) {
		return nil
	}

	key := r.Args[0]
	pos, err := strconv.Atoi(r.Args[1])
	if err != nil {
		setDirty(r.Client())
		out.WriteErrorString(msgInvalidInt)
		return nil
	}
	if pos < 0 {
		setDirty(r.Client())
		return redeo.ClientError("offset is out of range")
	}
	subst := r.Args[2]

	return withTx(m, out, r, func(out *redeo.Responder, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if t, ok := db.keys[key]; ok && t != "string" {
			out.WriteErrorString(msgWrongType)
			return
		}

		if len(subst) == 0 {
			// Nothing to write: report the current length without touching
			// the key, so a missing key is not created as a run of zero
			// bytes and an existing one is not rewritten.
			out.WriteInt(len(db.stringKeys[key]))
			return
		}

		v := []byte(db.stringKeys[key])
		end := pos + len(subst)
		if len(v) < end {
			// Zero-pad up to the end of the written range.
			padded := make([]byte, end)
			copy(padded, v)
			v = padded
		}
		copy(v[pos:end], subst)
		db.stringSet(key, string(v))
		out.WriteInt(len(v))
	})
}
Exemple #6
0
// cmdGetbit handles GETBIT key offset.
//
// It validates the argument count, the authentication state and the bit
// offset, which must be a non-negative integer. The lookup runs through
// withTx and writes the addressed bit (0 or 1) as the integer reply; offsets
// beyond the end of the stored string, and missing keys, read as 0.
func (m *Miniredis) cmdGetbit(out *redeo.Responder, r *redeo.Request) error {
	if len(r.Args) != 2 {
		setDirty(r.Client())
		return r.WrongNumberOfArgs()
	}
	if !m.handleAuth(r.Client(), out) {
		return nil
	}

	key := r.Args[0]
	bit, err := strconv.Atoi(r.Args[1])
	// Negative offsets must be rejected here, as SETBIT does: Go's integer
	// division and remainder truncate toward zero, so a negative offset
	// would yield a negative byte or bit index and panic below.
	if err != nil || bit < 0 {
		setDirty(r.Client())
		return redeo.ClientError("bit offset is not an integer or out of range")
	}

	return withTx(m, out, r, func(out *redeo.Responder, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		if t, ok := db.keys[key]; ok && t != "string" {
			out.WriteErrorString(msgWrongType)
			return
		}
		value := db.stringKeys[key]

		// Bytes past the end of the value read as zero.
		var ourByte byte
		if ourByteNr := bit / 8; ourByteNr < len(value) {
			ourByte = value[ourByteNr]
		}

		res := 0
		if toBits(ourByte)[bit%8] {
			res = 1
		}
		out.WriteInt(res)
	})
}
// cmdWatch handles WATCH key [key ...].
//
// At least one key is required and the command is refused inside an open
// transaction. Every given key of the connection's selected database is
// registered through watch while the server lock is held, after which OK is
// written.
func (m *Miniredis) cmdWatch(out *redeo.Responder, r *redeo.Request) error {
	if len(r.Args) == 0 {
		setDirty(r.Client())
		return r.WrongNumberOfArgs()
	}
	if !m.handleAuth(r.Client(), out) {
		return nil
	}

	ctx := getCtx(r.Client())
	if inTx(ctx) {
		return redeo.ClientError("WATCH in MULTI")
	}

	m.Lock()
	defer m.Unlock()

	db := m.db(ctx.selectedDB)
	for i := range r.Args {
		watch(db, ctx, r.Args[i])
	}

	out.WriteOK()
	return nil
}
// cmdZinterstore handles
// ZINTERSTORE destination numkeys key [key ...] [WEIGHTS w ...] [AGGREGATE SUM|MIN|MAX].
//
// Argument parsing happens up front: numkeys must be a positive integer not
// larger than the number of remaining arguments, WEIGHTS must be followed by
// exactly numkeys floats, and AGGREGATE by one of sum/min/max (keywords are
// case-insensitive). Any parse failure marks the connection dirty.
//
// The intersection itself runs through withTx. Members are collected from
// every source sorted set with their (optionally weighted) scores combined
// according to the aggregate; members that were not seen in all numkeys
// sources are then dropped. The result replaces destination and its size is
// written as the integer reply. A source of the wrong type yields a
// WRONGTYPE error and leaves destination untouched.
func (m *Miniredis) cmdZinterstore(out *redeo.Responder, r *redeo.Request) error {
	if len(r.Args) < 3 {
		setDirty(r.Client())
		return r.WrongNumberOfArgs()
	}
	if !m.handleAuth(r.Client(), out) {
		return nil
	}

	destination := r.Args[0]
	numKeys, err := strconv.Atoi(r.Args[1])
	if err != nil {
		setDirty(r.Client())
		out.WriteErrorString(msgInvalidInt)
		return nil
	}
	args := r.Args[2:]
	if len(args) < numKeys {
		setDirty(r.Client())
		out.WriteErrorString(msgSyntaxError)
		return nil
	}
	if numKeys <= 0 {
		setDirty(r.Client())
		return redeo.ClientError("at least 1 input key is needed for ZUNIONSTORE/ZINTERSTORE")
	}
	keys := args[:numKeys]
	args = args[numKeys:]

	withWeights := false
	var weights []float64
	aggregate := "sum"
	for len(args) > 0 {
		switch strings.ToLower(args[0]) {
		case "weights":
			if len(args) < numKeys+1 {
				setDirty(r.Client())
				out.WriteErrorString(msgSyntaxError)
				return nil
			}
			// Start from an empty slice so that a repeated WEIGHTS clause
			// replaces the earlier one instead of being appended after it
			// (where it would never be read).
			weights = make([]float64, 0, numKeys)
			for _, raw := range args[1 : numKeys+1] {
				f, err := strconv.ParseFloat(raw, 64)
				if err != nil {
					setDirty(r.Client())
					return redeo.ClientError("weight value is not a float")
				}
				weights = append(weights, f)
			}
			withWeights = true
			args = args[numKeys+1:]
		case "aggregate":
			if len(args) < 2 {
				setDirty(r.Client())
				out.WriteErrorString(msgSyntaxError)
				return nil
			}
			aggregate = strings.ToLower(args[1])
			switch aggregate {
			case "sum", "min", "max":
			default:
				setDirty(r.Client())
				out.WriteErrorString(msgSyntaxError)
				return nil
			}
			args = args[2:]
		default:
			setDirty(r.Client())
			out.WriteErrorString(msgSyntaxError)
			return nil
		}
	}

	return withTx(m, out, r, func(out *redeo.Responder, ctx *connCtx) {
		db := m.db(ctx.selectedDB)

		// We collect everything and remove all members which turned out not
		// to be present in every set.
		sset := map[string]float64{}
		counts := map[string]int{}
		for i, key := range keys {
			if !db.exists(key) {
				// A missing key is an empty set; its members simply never
				// reach the required count.
				continue
			}
			if db.t(key) != "zset" {
				out.WriteErrorString(msgWrongType)
				return
			}
			for _, el := range db.ssetElements(key) {
				score := el.score
				if withWeights {
					score *= weights[i]
				}
				counts[el.member]++
				old, ok := sset[el.member]
				if !ok {
					sset[el.member] = score
					continue
				}
				switch aggregate {
				default:
					// Unreachable: aggregate was validated during parsing.
					panic("Invalid aggregate")
				case "sum":
					sset[el.member] += score
				case "min":
					if score < old {
						sset[el.member] = score
					}
				case "max":
					if score > old {
						sset[el.member] = score
					}
				}
			}
		}
		for member, count := range counts {
			if count != numKeys {
				delete(sset, member)
			}
		}

		// Replace the destination only now that every source has been read
		// and type-checked. Deleting it up front would empty a destination
		// that is also one of the sources, and would destroy it even when
		// the command fails with WRONGTYPE.
		db.del(destination, true)
		db.ssetSet(destination, sset)
		out.WriteInt(len(sset))
	})
}