예제 #1
0
// DEL handles a multi-key DEL request. Each key is deleted with its own
// single-key DEL issued through the backend, and the client receives the sum
// of the per-key results as one integer reply.
func (s *Session) DEL(req *redis.Request) {
	keys := req.Args()

	// Running the deletes serially would be slow, so they run in parallel.
	// The number of goroutines must stay bounded, though: it must not grow
	// with the number of keys. Read the limit once (the field is read a single
	// time so a concurrent change cannot desynchronise capacity and usage)
	// and clamp it so a zero or negative setting cannot stall the request.
	parallel := s.MulOpParallel
	if parallel < 1 {
		parallel = 1
	}
	sem := make(chan struct{}, parallel)

	var (
		mu     sync.Mutex // guards result
		result int64
		wg     sync.WaitGroup
	)

	for _, key := range keys {
		// Take a slot before spawning so at most `parallel` goroutines exist.
		sem <- struct{}{}
		wg.Add(1)
		go func(key string) {
			defer wg.Done()
			defer func() { <-sem }()

			resp := s.Proxy.Backend.OnDEL(redis.NewRequest([]string{"DEL", key}))

			// The counter is shared by every worker; update it under the lock.
			mu.Lock()
			result += resp.Val()
			mu.Unlock()
		}(key)
	}

	wg.Wait()
	s.write2client(redis.FormatInt(result))
}
예제 #2
0
// verifyCommand validates a request against the command tables.
//
// It returns BadCommandError for a nil or empty request and for a command
// that has no entry in reqrules, CommandForbidden for a command present in
// blackList, and WrongArgumentCount when the request length falls outside
// the minimum/maximum stored in the command's rule. A bound of -1 disables
// that check. A nil return means the request passed every check.
func verifyCommand(req *redis.Request) error {
	if req == nil || req.Len() == 0 {
		return BadCommandError
	}

	cmd := req.Name()

	if _, forbidden := blackList[cmd]; forbidden {
		return CommandForbidden
	}

	limits, known := reqrules[cmd]
	if !known {
		// Commands without a rule are rejected outright.
		return BadCommandError
	}

	// Walk the rule slots; only the min/max slots carry a length bound.
	for idx := 0; idx < len(limits); idx++ {
		if idx == RI_MinCount {
			lower := limits[idx].(int)
			if lower != -1 && req.Len() < lower {
				return WrongArgumentCount
			}
		} else if idx == RI_MaxCount {
			upper := limits[idx].(int)
			if upper != -1 && req.Len() > upper {
				return WrongArgumentCount
			}
		}
	}

	return nil
}
예제 #3
0
// MSET handles a multi-key MSET request. The arguments are consumed as
// key/value pairs; each pair is written with its own SET issued through the
// backend. The client receives OK when no SET reported an error, otherwise an
// error line stating how many pairs failed and how many were set. An odd
// number of arguments is rejected with the wrong-argument-count error.
func (s *Session) MSET(req *redis.Request) {
	pair := req.Args()
	if len(pair)%2 != 0 {
		s.write2client([]byte(fmt.Sprintf("-%s\r\n", WrongArgumentCount)))
		return
	}

	// Read the parallelism limit once and clamp it: a zero or negative value
	// would otherwise leave no slot available and stall the request.
	parallel := s.MulOpParallel
	if parallel < 1 {
		parallel = 1
	}
	sem := make(chan struct{}, parallel)

	var (
		mu         sync.Mutex // guards partialErr
		partialErr int
		wg         sync.WaitGroup
	)

	// Individual replies are ignored; MSET reports OK unless some SET failed.
	for i := 0; i < len(pair); i += 2 {
		// Take a slot before spawning so at most `parallel` goroutines exist.
		sem <- struct{}{}
		wg.Add(1)
		go func(k, v string) {
			defer wg.Done()
			defer func() { <-sem }()

			resp := s.Proxy.Backend.OnSET(redis.NewRequest([]string{"SET", k, v}))
			if err := resp.Err(); err != nil && err != redis.Nil {
				// The failure counter is shared by every worker.
				mu.Lock()
				partialErr++
				mu.Unlock()
			}
		}(pair[i], pair[i+1])
	}
	wg.Wait()

	if partialErr == 0 {
		s.write2client(OK_BYTES)
		return
	}
	d := fmt.Sprintf("- %d MSET failed, partial key/value %d set\r\n", partialErr, len(pair)/2-partialErr)
	s.write2client([]byte(d))
}
예제 #4
0
// proxyConf serves the "proxy config get|set" sub-commands.
//
// Examples of accepted requests:
//
//	proxy config set loglevel info
//	proxy config set idletime 200
//	proxy config set slaveok 1|0
//	proxy config set mulparallel 30
//	proxy config get statsd
//
// "get" requires exactly three arguments and "set" exactly four; the setting
// name and the value are lower-cased before being handed to the lookup/update
// helpers. Any other operation yields an error reply. The function indexes
// args[1] directly, so it relies on the caller having checked the length.
func (s *Session) proxyConf(req *redis.Request) {
	args := req.Args()
	op := strings.ToLower(args[1])

	// Expected argument count per operation.
	var want int
	switch op {
	case "get":
		want = 3
	case "set":
		want = 4
	default:
		s.write2client([]byte("-wrong proxy config op type\r\n"))
		return
	}

	if len(args) != want {
		s.write2client([]byte(fmt.Sprintf("-%s\r\n", WrongArgumentCount)))
		return
	}

	name := strings.ToLower(args[2])
	if op == "get" {
		s.write2client(s.proxyConfigGetByName(name))
		return
	}

	newValue := strings.ToLower(args[3])
	s.write2client(s.proxyConfigSetByName(name, newValue))
}
예제 #5
0
// Dispatch routes a request to the backend method named "On" + the command
// name, resolved through reflection. Resolved methods (including lookups that
// found nothing) are stored in ps.RedisMethod so later requests for the same
// command skip the reflection lookup.
//
// When no such method exists the result of Backend.OnReflectUnvalid is
// returned; when the method's first return value is nil the result of
// Backend.OnUnDenfined is returned.
//
// NOTE(review): ps.RedisMethod is written here without any locking visible in
// this function — confirm callers serialise access or guard the map elsewhere.
func (ps *ProxyServer) Dispatch(req *redis.Request) redis.Cmder {
	cmd := req.Name()

	handler, cached := ps.RedisMethod[cmd]
	if !cached {
		handler = reflect.ValueOf(ps.Backend).MethodByName("On" + cmd)
		ps.RedisMethod[cmd] = handler
	}

	if !handler.IsValid() {
		return ps.Backend.OnReflectUnvalid(req)
	}

	out := handler.Call([]reflect.Value{reflect.ValueOf(req)})
	if res := out[0].Interface(); res != nil {
		return res.(redis.Cmder)
	}

	return ps.Backend.OnUnDenfined(req)
}
예제 #6
0
// MGET handles a multi-key MGET request. Each key is fetched with its own GET
// issued through the backend, and the raw replies are concatenated, in the
// order the keys were given, behind an array header of len(keys).
func (s *Session) MGET(req *redis.Request) {
	keys := req.Args()

	// Read the parallelism limit once and clamp it: a zero or negative value
	// would otherwise leave no slot available and stall the request.
	parallel := s.MulOpParallel
	if parallel < 1 {
		parallel = 1
	}
	sem := make(chan struct{}, parallel)

	// Replies are stored by key position so the output order matches the
	// request order no matter which fetch finishes first. Every goroutine
	// writes a distinct element, so no lock is needed.
	replies := make([][]byte, len(keys))

	var wg sync.WaitGroup
	for idx, key := range keys {
		// Take a slot before spawning so at most `parallel` goroutines exist.
		sem <- struct{}{}
		wg.Add(1)
		go func(idx int, key string) {
			defer wg.Done()
			defer func() { <-sem }()

			resp := s.Proxy.Backend.OnGET(redis.NewRequest([]string{"GET", key}))
			replies[idx] = resp.Reply()
		}(idx, key)
	}
	wg.Wait()

	mergeResp := []byte(fmt.Sprintf("*%d\r\n", len(keys)))
	for _, reply := range replies {
		mergeResp = append(mergeResp, reply...)
	}
	s.write2client(mergeResp)
}
예제 #7
0
// SpecCommandProcess runs the session-level handler for commands that get
// special treatment instead of plain forwarding. The command name selects the
// handler method of the same name. A name with no handler is logged through
// log.Fatalf, as reaching this function with such a command is not expected.
func (s *Session) SpecCommandProcess(req *redis.Request) {
	// Cases are listed alphabetically; order has no effect on dispatch.
	switch cmd := req.Name(); cmd {
	case "DEL":
		s.DEL(req)
	case "MGET":
		s.MGET(req)
	case "MSET":
		s.MSET(req)
	case "MSETNX":
		s.MSETNX(req)
	case "PROXY":
		s.PROXY(req)
	case "RENAME":
		s.RENAME(req)
	case "RENAMENX":
		s.RENAMENX(req)
	case "RPOPLPUSH":
		s.RPOPLPUSH(req)
	case "SDIFF":
		s.SDIFF(req)
	case "SDIFFSTORE":
		s.SDIFFSTORE(req)
	case "SINTER":
		s.SINTER(req)
	case "SINTERSTORE":
		s.SINTERSTORE(req)
	case "SMOVE":
		s.SMOVE(req)
	case "ZINTERSTORE":
		s.ZINTERSTORE(req)
	case "ZUNIONSTORE":
		s.ZUNIONSTORE(req)
	default:
		log.Fatalf("Unknown Spec Command: %s, we won't expect this happen ", cmd)
	}
}
예제 #8
0
// Write2client sends the request's stored result to the client and returns
// whatever error the underlying write reports.
func (s *Session) Write2client(req *redis.Request) error {
	payload := req.Result()
	return s.write2client(payload)
}
예제 #9
0
// forward dispatches the request through the proxy and stores the resulting
// response on the request itself.
func (s *Session) forward(req *redis.Request) {
	req.SetResp(s.Proxy.Dispatch(req))
}
예제 #10
0
// preCheckCommand answers the commands the proxy handles locally and
// validates everything else before it is forwarded.
//
// Returns, in order:
//   - buf: the reply to send to the client, if any;
//   - shouldClose: true when the connection should be closed afterwards (QUIT);
//   - handled: true when the request must not be forwarded, either because a
//     reply was produced here or because an error is being returned;
//   - err: BadCommandError for an empty request, WrongArgumentCount for a
//     malformed ECHO, or whatever verifyCommand reports.
func preCheckCommand(req *redis.Request) ([]byte, bool, bool, error) {
	if req.Len() == 0 {
		return nil, false, true, BadCommandError
	}

	switch req.Name() {
	case "PING":
		return []byte("+PONG\r\n"), false, true, nil
	case "QUIT":
		return OK_BYTES, true, true, nil
	case "SELECT", "AUTH":
		// SELECT is accepted, but the backend always uses db 0; AUTH is
		// acknowledged without further processing here.
		return OK_BYTES, false, true, nil
	case "ECHO":
		args := req.Args()
		if len(args) != 1 {
			return nil, false, true, WrongArgumentCount
		}
		// Reply with a length-prefixed bulk string. A "+..." simple string
		// cannot carry CR or LF, so echoing client data that way would let
		// the payload break the reply framing.
		msg := args[0]
		return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(msg), msg)), false, true, nil
	}

	if err := verifyCommand(req); err != nil {
		return nil, false, true, err
	}

	// Reject requests whose first argument is a blacklisted key.
	if args := req.Args(); len(args) >= 1 {
		if _, blacked := BlackKeyLists[args[0]]; blacked {
			return []byte("-key already be blacked \r\n"), false, true, nil
		}
	}

	return nil, false, false, nil
}
예제 #11
0
// proxyBlack serves the "proxy black" sub-commands that manage BlackKeyLists:
//
//	proxy black get                 - list every blacklisted key
//	proxy black remove <key>        - drop a key from the list
//	proxy black set <key> <seconds> - blacklist a key for 0..86400 seconds
//
// For "set" the code takes the key from argument position 2 and the duration
// from position 3. The function indexes position 1 directly, so it relies on
// the caller having checked that at least two arguments are present.
func (s *Session) proxyBlack(req *redis.Request) {
	args := req.Args()

	// Shared reply for a sub-command called with the wrong number of arguments.
	badArgc := func() {
		s.write2client([]byte(fmt.Sprintf("-%s\r\n", WrongArgumentCount)))
	}

	switch strings.ToLower(args[1]) {
	case "get":
		if len(args) != 2 {
			badArgc()
			return
		}
		names := make([]string, 0)
		for name := range BlackKeyLists {
			names = append(names, name)
		}
		s.write2client(redis.FormatStringSlice(names))

	case "remove":
		if len(args) != 3 {
			badArgc()
			return
		}
		name := args[2]
		if _, found := BlackKeyLists[name]; !found {
			s.write2client([]byte("-remove key not exists\r\n"))
			return
		}
		log.Warning("remove black key ", name)
		delete(BlackKeyLists, name)
		s.write2client(OK_BYTES)

	case "set":
		if len(args) != 4 {
			badArgc()
			return
		}
		name, rawSeconds := args[2], args[3]

		seconds, convErr := strconv.Atoi(rawSeconds)
		if convErr != nil {
			log.Warningf("black key: %s time unavailable %s", name, rawSeconds)
			s.write2client([]byte(fmt.Sprintf("-%s\r\n", BlackTimeUnavaliable)))
			return
		}
		// Only durations within one day are accepted.
		if seconds < 0 || seconds > 86400 {
			log.Warningf("black key: %s time unavailable %s", name, rawSeconds)
			s.write2client([]byte("-black time must between 0 ~ 86400\r\n"))
			return
		}

		BlackKeyLists[name] = &BlackKey{
			Name:     name,
			Startup:  time.Now(),
			Deadline: time.Now().Add(time.Duration(seconds) * time.Second),
		}
		s.write2client(OK_BYTES)

	default:
		s.write2client([]byte(fmt.Sprintf("-%s\r\n", UnknowProxyOpType)))
	}
}
예제 #12
0
// PROXY is the entry point for the proxy's own administrative command. The
// first argument selects the operation (case-insensitively):
//
//	info   - exactly one argument; handled by proxyInfo
//	black  - at least two arguments; handled by proxyBlack
//	config - three or four arguments; handled by proxyConf
//
// A request with no arguments, or with an argument count outside the range
// for its operation, gets the wrong-argument-count error. An unrecognised
// operation is logged and answered with the unknown-op-type error.
func (s *Session) PROXY(req *redis.Request) {
	replyWrongArgc := func() {
		s.write2client([]byte(fmt.Sprintf("-%s\r\n", WrongArgumentCount)))
	}

	args := req.Args()
	// Guard the args[0] access below: a bare "PROXY" would otherwise panic.
	if len(args) == 0 {
		replyWrongArgc()
		return
	}

	switch strings.ToLower(args[0]) {
	case "info":
		if len(args) != 1 {
			replyWrongArgc()
			return
		}
		s.proxyInfo(req)
	case "black":
		if len(args) < 2 {
			replyWrongArgc()
			return
		}
		s.proxyBlack(req)
	case "config":
		// proxy config set name value
		if len(args) < 3 || len(args) > 4 {
			replyWrongArgc()
			return
		}
		s.proxyConf(req)
	default:
		log.Warning("Unknow proxy op type: ", req.Args())
		s.write2client([]byte(fmt.Sprintf("-%s\r\n", UnknowProxyOpType)))
	}
}