Ejemplo n.º 1
0
// ServeHTTP handles an HTTP callback request.
//
// POST requests carry a message. The "encrypt_type" query parameter selects
// how the request is validated:
//
//   - "aes": msg_signature is verified against the encrypted payload, the
//     payload is decrypted (current AES key first, then the previous one if
//     srv reports it as valid) and parsed into a MixedMessage.
//   - "" or "raw": signature is verified against the query parameters and the
//     body is parsed directly into a MixedMessage.
//
// On success a Request is built and passed to srv.MessageHandler().
//
// GET requests are the initial verification: signature is checked and the
// "echostr" query parameter is written back to the client.
//
// Every failure, including an unknown encrypt_type or an HTTP method other
// than GET/POST, is reported through irh.ServeInvalidRequest.
//
// NOTE: the caller guarantees that all arguments are valid.
func ServeHTTP(w http.ResponseWriter, r *http.Request, queryValues url.Values, srv Server, irh InvalidRequestHandler) {
	LogInfoln("[WECHAT_DEBUG] request uri:", r.RequestURI)
	LogInfoln("[WECHAT_DEBUG] request remote-addr:", r.RemoteAddr)
	LogInfoln("[WECHAT_DEBUG] request user-agent:", r.UserAgent())

	// reject hands a validation failure to the caller-supplied handler.
	reject := func(err error) { irh.ServeInvalidRequest(w, r, err) }

	switch r.Method {
	case "POST": // message handling
		switch encryptType := queryValues.Get("encrypt_type"); encryptType {
		case "aes": // safe mode and compatibility mode
			signature := queryValues.Get("signature") // only read, not verified in this mode

			msgSignature1 := queryValues.Get("msg_signature")
			if msgSignature1 == "" {
				reject(errors.New("msg_signature is empty"))
				return
			}
			if len(msgSignature1) != 40 { // sha1
				reject(fmt.Errorf("the length of msg_signature mismatch, have: %d, want: 40", len(msgSignature1)))
				return
			}

			timestampStr := queryValues.Get("timestamp")
			if timestampStr == "" {
				reject(errors.New("timestamp is empty"))
				return
			}
			timestamp, err := strconv.ParseInt(timestampStr, 10, 64)
			if err != nil {
				reject(errors.New("can not parse timestamp to int64: " + timestampStr))
				return
			}

			nonce := queryValues.Get("nonce")
			if nonce == "" {
				reject(errors.New("nonce is empty"))
				return
			}

			// NOTE(review): the body is read in full with no size cap here;
			// consider limiting it before calling this function.
			reqBody, err := ioutil.ReadAll(r.Body)
			if err != nil {
				reject(err)
				return
			}
			LogInfoln("[WECHAT_DEBUG] request msg http body:\r\n", string(reqBody))

			var requestHttpBody RequestHttpBody
			if err := xml.Unmarshal(reqBody, &requestHttpBody); err != nil {
				reject(err)
				return
			}

			// For safety, check ToUserName against srv.OriId(); the check is
			// skipped when srv.OriId() returns an empty string.
			haveToUserName := requestHttpBody.ToUserName
			if wantToUserName := srv.OriId(); wantToUserName != "" {
				if len(haveToUserName) != len(wantToUserName) ||
					subtle.ConstantTimeCompare([]byte(haveToUserName), []byte(wantToUserName)) != 1 {
					reject(fmt.Errorf("the RequestHttpBody's ToUserName mismatch, have: %s, want: %s", haveToUserName, wantToUserName))
					return
				}
			}

			token := srv.Token()

			// Verify the message signature.
			msgSignature2 := util.MsgSign(token, timestampStr, nonce, requestHttpBody.EncryptedMsg)
			if subtle.ConstantTimeCompare([]byte(msgSignature1), []byte(msgSignature2)) != 1 {
				reject(fmt.Errorf("check msg_signature failed, input: %s, local: %s", msgSignature1, msgSignature2))
				return
			}

			// Decrypt.
			encryptedMsgBytes, err := base64.StdEncoding.DecodeString(requestHttpBody.EncryptedMsg)
			if err != nil {
				reject(err)
				return
			}

			appId := srv.AppId()
			aesKey := srv.CurrentAESKey()

			random, rawMsgXML, err := util.AESDecryptMsg(encryptedMsgBytes, appId, aesKey)
			if err != nil {
				// Retry with the previous AES key when srv reports one as valid.
				lastAESKey, isLastAESKeyValid := srv.LastAESKey()
				if !isLastAESKeyValid {
					reject(err)
					return
				}

				aesKey = lastAESKey // NOTE: the key actually used is what ends up in Request.AESKey

				random, rawMsgXML, err = util.AESDecryptMsg(encryptedMsgBytes, appId, aesKey)
				if err != nil {
					reject(err)
					return
				}
			}

			LogInfoln("[WECHAT_DEBUG] request msg raw xml:\r\n", string(rawMsgXML))

			// Decryption succeeded; parse the MixedMessage.
			var mixedMsg MixedMessage
			if err := xml.Unmarshal(rawMsgXML, &mixedMsg); err != nil {
				reject(err)
				return
			}

			// For safety, the decrypted ToUserName must match the outer one.
			if haveToUserName != mixedMsg.ToUserName {
				reject(fmt.Errorf("the RequestHttpBody's ToUserName(==%s) mismatch the MixedMessage's ToUserName(==%s)", haveToUserName, mixedMsg.ToUserName))
				return
			}

			// Success: hand over to the MessageHandler.
			req := &Request{
				HttpRequest: r,

				QueryValues: queryValues,
				Signature:   signature,
				Timestamp:   timestamp,
				Nonce:       nonce,

				RawMsgXML: rawMsgXML,
				MixedMsg:  &mixedMsg,

				MsgSignature: msgSignature1,
				EncryptType:  encryptType,
				AESKey:       aesKey,
				Random:       random,

				Token: token,
				AppId: appId,
			}
			srv.MessageHandler().ServeMessage(w, req)

		case "", "raw": // plaintext mode
			signature1 := queryValues.Get("signature")
			if signature1 == "" {
				reject(errors.New("signature is empty"))
				return
			}
			if len(signature1) != 40 { // sha1
				reject(fmt.Errorf("the length of signature mismatch, have: %d, want: 40", len(signature1)))
				return
			}

			timestampStr := queryValues.Get("timestamp")
			if timestampStr == "" {
				reject(errors.New("timestamp is empty"))
				return
			}
			timestamp, err := strconv.ParseInt(timestampStr, 10, 64)
			if err != nil {
				reject(errors.New("can not parse timestamp to int64: " + timestampStr))
				return
			}

			nonce := queryValues.Get("nonce")
			if nonce == "" {
				reject(errors.New("nonce is empty"))
				return
			}

			token := srv.Token()

			signature2 := util.Sign(token, timestampStr, nonce)
			if subtle.ConstantTimeCompare([]byte(signature1), []byte(signature2)) != 1 {
				reject(fmt.Errorf("check signature failed, input: %s, local: %s", signature1, signature2))
				return
			}

			// Signature verified; read and parse the MixedMessage.
			rawMsgXML, err := ioutil.ReadAll(r.Body)
			if err != nil {
				reject(err)
				return
			}

			LogInfoln("[WECHAT_DEBUG] request msg raw xml:\r\n", string(rawMsgXML))

			var mixedMsg MixedMessage
			if err := xml.Unmarshal(rawMsgXML, &mixedMsg); err != nil {
				reject(err)
				return
			}

			// For safety, check ToUserName against srv.OriId(); the check is
			// skipped when srv.OriId() returns an empty string. In this mode
			// the value comes from the message itself, so the error says so.
			haveToUserName := mixedMsg.ToUserName
			if wantToUserName := srv.OriId(); wantToUserName != "" {
				if len(haveToUserName) != len(wantToUserName) ||
					subtle.ConstantTimeCompare([]byte(haveToUserName), []byte(wantToUserName)) != 1 {
					reject(fmt.Errorf("the message's ToUserName mismatch, have: %s, want: %s", haveToUserName, wantToUserName))
					return
				}
			}

			// Success: hand over to the MessageHandler.
			req := &Request{
				HttpRequest: r,

				QueryValues: queryValues,
				Signature:   signature1,
				Timestamp:   timestamp,
				Nonce:       nonce,

				RawMsgXML: rawMsgXML,
				MixedMsg:  &mixedMsg,

				EncryptType: encryptType,

				AppId: srv.AppId(),
				Token: token,
			}
			srv.MessageHandler().ServeMessage(w, req)

		default: // unknown encryption type
			reject(errors.New("unknown encrypt_type: " + encryptType))
		}

	case "GET": // initial verification
		signature1 := queryValues.Get("signature")
		if signature1 == "" {
			reject(errors.New("signature is empty"))
			return
		}
		if len(signature1) != 40 { // sha1
			reject(fmt.Errorf("the length of signature mismatch, have: %d, want: 40", len(signature1)))
			return
		}

		timestamp := queryValues.Get("timestamp")
		if timestamp == "" {
			reject(errors.New("timestamp is empty"))
			return
		}

		nonce := queryValues.Get("nonce")
		if nonce == "" {
			reject(errors.New("nonce is empty"))
			return
		}

		echostr := queryValues.Get("echostr")
		if echostr == "" {
			reject(errors.New("echostr is empty"))
			return
		}

		signature2 := util.Sign(srv.Token(), timestamp, nonce)
		if subtle.ConstantTimeCompare([]byte(signature1), []byte(signature2)) != 1 {
			reject(fmt.Errorf("check signature failed, input: %s, local: %s", signature1, signature2))
			return
		}

		io.WriteString(w, echostr)

	default:
		// Only GET and POST are handled here; report anything else instead
		// of returning without telling the handler.
		reject(errors.New("unsupported request method: " + r.Method))
	}
}
Ejemplo n.º 2
0
// ServeHTTP handles an HTTP callback request for an agent.
//
// POST requests carry an encrypted message: msg_signature is verified, the
// CorpId and AgentId of the outer body are checked against srv, the payload
// is decrypted (current AES key first, then the previous one if srv reports
// it as valid), parsed into a MixedMessage and cross-checked against the
// outer body. On success a Request is passed to srv.MessageHandler().
//
// GET requests are the initial verification: msg_signature is checked, the
// "echostr" query parameter is decrypted with the current AES key and the
// plaintext is written back to the client.
//
// Every failure, including an HTTP method other than GET/POST, is reported
// through irh.ServeInvalidRequest.
//
// NOTE: the caller guarantees that all arguments are valid.
func ServeHTTP(w http.ResponseWriter, r *http.Request, queryValues url.Values, srv AgentServer, irh InvalidRequestHandler) {
	LogInfoln("[WECHAT_DEBUG] request uri:", r.RequestURI)
	LogInfoln("[WECHAT_DEBUG] request remote-addr:", r.RemoteAddr)
	LogInfoln("[WECHAT_DEBUG] request user-agent:", r.UserAgent())

	// reject hands a validation failure to the caller-supplied handler.
	reject := func(err error) { irh.ServeInvalidRequest(w, r, err) }

	switch r.Method {
	case "POST": // message handling
		msgSignature1 := queryValues.Get("msg_signature")
		if msgSignature1 == "" {
			reject(errors.New("msg_signature is empty"))
			return
		}
		if len(msgSignature1) != 40 { // sha1
			reject(fmt.Errorf("the length of msg_signature mismatch, have: %d, want: 40", len(msgSignature1)))
			return
		}

		timestampStr := queryValues.Get("timestamp")
		if timestampStr == "" {
			reject(errors.New("timestamp is empty"))
			return
		}
		timestamp, err := strconv.ParseInt(timestampStr, 10, 64)
		if err != nil {
			reject(errors.New("can not parse timestamp to int64: " + timestampStr))
			return
		}

		nonce := queryValues.Get("nonce")
		if nonce == "" {
			reject(errors.New("nonce is empty"))
			return
		}

		// NOTE(review): the body is read in full with no size cap here;
		// consider limiting it before calling this function.
		reqBody, err := ioutil.ReadAll(r.Body)
		if err != nil {
			reject(err)
			return
		}
		LogInfoln("[WECHAT_DEBUG] request msg http body:\r\n", string(reqBody))

		// Parse the outer RequestHttpBody.
		var requestHttpBody RequestHttpBody
		if err := xml.Unmarshal(reqBody, &requestHttpBody); err != nil {
			reject(err)
			return
		}

		corpId := srv.CorpId()

		haveCorpId := requestHttpBody.CorpId
		if len(haveCorpId) != len(corpId) ||
			subtle.ConstantTimeCompare([]byte(haveCorpId), []byte(corpId)) != 1 {
			reject(fmt.Errorf("the RequestHttpBody's ToUserName mismatch, have: %s, want: %s", haveCorpId, corpId))
			return
		}

		agentId := srv.AgentId()

		// An AgentId of 0 is let through for now; it is only accepted for
		// subscribe/unsubscribe events once the message has been decrypted.
		haveAgentId := requestHttpBody.AgentId
		if haveAgentId != agentId && haveAgentId != 0 {
			reject(fmt.Errorf("the RequestHttpBody's AgentId mismatch, have: %d, want: %d", haveAgentId, agentId))
			return
		}

		agentToken := srv.Token()

		// Verify the message signature.
		msgSignature2 := util.MsgSign(agentToken, timestampStr, nonce, requestHttpBody.EncryptedMsg)
		if subtle.ConstantTimeCompare([]byte(msgSignature1), []byte(msgSignature2)) != 1 {
			reject(fmt.Errorf("check msg_signature failed, input: %s, local: %s", msgSignature1, msgSignature2))
			return
		}

		// Decrypt.
		encryptedMsgBytes, err := base64.StdEncoding.DecodeString(requestHttpBody.EncryptedMsg)
		if err != nil {
			reject(err)
			return
		}

		aesKey := srv.CurrentAESKey()
		random, rawMsgXML, err := util.AESDecryptMsg(encryptedMsgBytes, corpId, aesKey)
		if err != nil {
			// Retry with the previous AES key when srv reports one as valid.
			lastAESKey, isLastAESKeyValid := srv.LastAESKey()
			if !isLastAESKeyValid {
				reject(err)
				return
			}

			aesKey = lastAESKey // NOTE: the key actually used is what ends up in Request.AESKey

			random, rawMsgXML, err = util.AESDecryptMsg(encryptedMsgBytes, corpId, aesKey)
			if err != nil {
				reject(err)
				return
			}
		}

		LogInfoln("[WECHAT_DEBUG] request msg raw xml:\r\n", string(rawMsgXML))

		// Decryption succeeded; parse the MixedMessage.
		var mixedMsg MixedMessage
		if err = xml.Unmarshal(rawMsgXML, &mixedMsg); err != nil {
			reject(err)
			return
		}

		// For safety, the decrypted message must agree with the outer body.
		if haveCorpId != mixedMsg.ToUserName {
			reject(fmt.Errorf("the RequestHttpBody's ToUserName(==%s) mismatch the MixedMessage's ToUserName(==%s)", haveCorpId, mixedMsg.ToUserName))
			return
		}
		if haveAgentId != mixedMsg.AgentId {
			reject(fmt.Errorf("the RequestHttpBody's AgentId(==%d) mismatch the MixedMessage's AgentId(==%d)", haveAgentId, mixedMsg.AgentId))
			return
		}

		// haveAgentId can only differ from agentId here when it is 0, which
		// is accepted solely for subscribing to / unsubscribing from the
		// whole corporation account.
		if haveAgentId != agentId {
			isSubscription := mixedMsg.MsgType == "event" &&
				(mixedMsg.Event == "subscribe" || mixedMsg.Event == "unsubscribe")
			if !isSubscription {
				reject(fmt.Errorf("the RequestHttpBody's AgentId mismatch, have: %d, want: %d", haveAgentId, agentId))
				return
			}
		}

		// Success: hand over to the MessageHandler.
		req := &Request{
			HttpRequest: r,

			QueryValues:  queryValues,
			MsgSignature: msgSignature1,
			Timestamp:    timestamp,
			Nonce:        nonce,

			RawMsgXML: rawMsgXML,
			MixedMsg:  &mixedMsg,

			AESKey: aesKey,
			Random: random,

			CorpId:     haveCorpId,
			AgentId:    haveAgentId,
			AgentToken: agentToken,
		}
		srv.MessageHandler().ServeMessage(w, req)

	case "GET": // initial verification
		msgSignature1 := queryValues.Get("msg_signature")
		if msgSignature1 == "" {
			reject(errors.New("msg_signature is empty"))
			return
		}
		if len(msgSignature1) != 40 { // sha1
			reject(fmt.Errorf("the length of msg_signature mismatch, have: %d, want: 40", len(msgSignature1)))
			return
		}

		timestamp := queryValues.Get("timestamp")
		if timestamp == "" {
			reject(errors.New("timestamp is empty"))
			return
		}

		nonce := queryValues.Get("nonce")
		if nonce == "" {
			reject(errors.New("nonce is empty"))
			return
		}

		encryptedMsg := queryValues.Get("echostr")
		if encryptedMsg == "" {
			reject(errors.New("echostr is empty"))
			return
		}

		msgSignature2 := util.MsgSign(srv.Token(), timestamp, nonce, encryptedMsg)
		if subtle.ConstantTimeCompare([]byte(msgSignature1), []byte(msgSignature2)) != 1 {
			reject(fmt.Errorf("check msg_signature failed, input: %s, local: %s", msgSignature1, msgSignature2))
			return
		}

		// Decrypt echostr; only the current AES key is tried here.
		encryptedMsgBytes, err := base64.StdEncoding.DecodeString(encryptedMsg)
		if err != nil {
			reject(err)
			return
		}

		corpId := srv.CorpId()
		aesKey := srv.CurrentAESKey()
		_, echostr, err := util.AESDecryptMsg(encryptedMsgBytes, corpId, aesKey)
		if err != nil {
			reject(err)
			return
		}

		w.Write(echostr)

	default:
		// Only GET and POST are handled here; report anything else instead
		// of returning without telling the handler.
		reject(errors.New("unsupported request method: " + r.Method))
	}
}
Ejemplo n.º 3
0
// ServeHTTP handles an HTTP callback request.
//
// POST requests carry a message and must use encrypt_type=aes: msg_signature
// is verified, the AppId of the outer body is checked against srv.AppId(),
// the payload is decrypted (current AES key first, then the previous one if
// srv reports it as valid), parsed into a MixedMessage and cross-checked
// against the outer body. On success a Request is passed to
// srv.MessageHandler(). Any other encrypt_type is rejected.
//
// GET requests are the initial verification: signature is checked and the
// "echostr" query parameter is written back to the client.
//
// Every failure, including an HTTP method other than GET/POST, is reported
// through errHandler.ServeError.
//
// NOTE: the caller guarantees that all arguments are valid.
func ServeHTTP(w http.ResponseWriter, r *http.Request, queryValues url.Values, srv Server, errHandler mp.ErrorHandler) {
	// fail hands a validation failure to the caller-supplied handler.
	fail := func(err error) { errHandler.ServeError(w, r, err) }

	switch r.Method {
	case "POST": // message handling
		switch encryptType := queryValues.Get("encrypt_type"); encryptType {
		case "aes":
			msgSignature1 := queryValues.Get("msg_signature")
			if msgSignature1 == "" {
				fail(errors.New("msg_signature is empty"))
				return
			}
			if len(msgSignature1) != 40 { // sha1
				fail(fmt.Errorf("the length of msg_signature mismatch, have: %d, want: 40", len(msgSignature1)))
				return
			}

			timestampStr := queryValues.Get("timestamp")
			if timestampStr == "" {
				fail(errors.New("timestamp is empty"))
				return
			}
			timestamp, err := strconv.ParseInt(timestampStr, 10, 64)
			if err != nil {
				fail(errors.New("can not parse timestamp to int64: " + timestampStr))
				return
			}

			nonce := queryValues.Get("nonce")
			if nonce == "" {
				fail(errors.New("nonce is empty"))
				return
			}

			var requestHttpBody RequestHttpBody
			if err := xml.NewDecoder(r.Body).Decode(&requestHttpBody); err != nil {
				fail(err)
				return
			}

			appId := srv.AppId()

			// For safety, check the outer AppId first.
			haveAppId := requestHttpBody.AppId
			if len(haveAppId) != len(appId) ||
				subtle.ConstantTimeCompare([]byte(haveAppId), []byte(appId)) != 1 {
				fail(fmt.Errorf("the RequestHttpBody's AppId mismatch, have: %s, want: %s", haveAppId, appId))
				return
			}

			token := srv.Token()

			// Verify the message signature.
			msgSignature2 := util.MsgSign(token, timestampStr, nonce, requestHttpBody.EncryptedMsg)
			if subtle.ConstantTimeCompare([]byte(msgSignature1), []byte(msgSignature2)) != 1 {
				fail(fmt.Errorf("check msg_signature failed, input: %s, local: %s", msgSignature1, msgSignature2))
				return
			}

			// Decrypt.
			encryptedMsgBytes, err := base64.StdEncoding.DecodeString(requestHttpBody.EncryptedMsg)
			if err != nil {
				fail(err)
				return
			}

			aesKey := srv.CurrentAESKey()

			random, rawMsgXML, err := util.AESDecryptMsg(encryptedMsgBytes, appId, aesKey)
			if err != nil {
				// Retry with the previous AES key when srv reports one as valid.
				lastAESKey, isLastAESKeyValid := srv.LastAESKey()
				if !isLastAESKeyValid {
					fail(err)
					return
				}

				aesKey = lastAESKey // NOTE: the key actually used is what ends up in Request.AESKey

				random, rawMsgXML, err = util.AESDecryptMsg(encryptedMsgBytes, appId, aesKey)
				if err != nil {
					fail(err)
					return
				}
			}

			// Decryption succeeded; parse the MixedMessage.
			var mixedMsg MixedMessage
			if err := xml.Unmarshal(rawMsgXML, &mixedMsg); err != nil {
				fail(err)
				return
			}

			// For safety, the decrypted AppId must match the outer one.
			if haveAppId != mixedMsg.AppId {
				fail(fmt.Errorf("the RequestHttpBody's AppId(==%s) mismatch the MixedMessage's AppId(==%s)", haveAppId, mixedMsg.AppId))
				return
			}

			// Success: hand over to the MessageHandler.
			req := &Request{
				HttpRequest: r,

				QueryValues:  queryValues,
				MsgSignature: msgSignature1,
				EncryptType:  encryptType,
				Timestamp:    timestamp,
				Nonce:        nonce,

				RawMsgXML: rawMsgXML,
				MixedMsg:  &mixedMsg,

				AESKey: aesKey,
				Random: random,

				AppId: appId,
				Token: token,
			}
			srv.MessageHandler().ServeMessage(w, req)

		default: // unknown encryption type
			fail(errors.New("unknown encrypt_type: " + encryptType))
		}

	case "GET": // initial verification
		signature1 := queryValues.Get("signature")
		if signature1 == "" {
			fail(errors.New("signature is empty"))
			return
		}
		if len(signature1) != 40 { // sha1
			fail(fmt.Errorf("the length of signature mismatch, have: %d, want: 40", len(signature1)))
			return
		}

		timestamp := queryValues.Get("timestamp")
		if timestamp == "" {
			fail(errors.New("timestamp is empty"))
			return
		}

		nonce := queryValues.Get("nonce")
		if nonce == "" {
			fail(errors.New("nonce is empty"))
			return
		}

		echostr := queryValues.Get("echostr")
		if echostr == "" {
			fail(errors.New("echostr is empty"))
			return
		}

		signature2 := util.Sign(srv.Token(), timestamp, nonce)
		if subtle.ConstantTimeCompare([]byte(signature1), []byte(signature2)) != 1 {
			fail(fmt.Errorf("check signature failed, input: %s, local: %s", signature1, signature2))
			return
		}

		io.WriteString(w, echostr)

	default:
		// Only GET and POST are handled here; report anything else instead
		// of returning without telling the handler.
		fail(errors.New("unsupported request method: " + r.Method))
	}
}
Ejemplo n.º 4
0
// ServeHTTP handles an HTTP callback request for an agent.
//
// POST requests carry an encrypted message: the query is parsed with
// parsePostURLQuery, msg_signature is verified, the CorpId and AgentId of the
// outer body are checked against agentServer, the payload is decrypted
// (current AES key first, then the previous one if it is usable), parsed into
// a MixedMessage and cross-checked against the outer body. On success a
// Request is passed to agentServer.MessageHandler().
//
// GET requests are the initial verification: the query is parsed with
// parseGetURLQuery, msg_signature is checked, echostr is decrypted with the
// current AES key and the plaintext is written back to the client.
//
// Every failure, including an HTTP method other than GET/POST, is reported
// through invalidRequestHandler.ServeInvalidRequest.
//
// NOTE: the caller guarantees that all arguments are valid.
func ServeHTTP(w http.ResponseWriter, r *http.Request, urlValues url.Values,
	agentServer AgentServer, invalidRequestHandler InvalidRequestHandler) {

	// reject hands a validation failure to the caller-supplied handler.
	reject := func(err error) { invalidRequestHandler.ServeInvalidRequest(w, r, err) }

	switch r.Method {
	case "POST": // message handling
		msgSignature1, timestampStr, nonce, err := parsePostURLQuery(urlValues)
		if err != nil {
			reject(err)
			return
		}

		// Check the signature length before doing any other work.
		if len(msgSignature1) != 40 {
			reject(fmt.Errorf("the length of msg_signature mismatch, have: %d, want: 40", len(msgSignature1)))
			return
		}

		timestamp, err := strconv.ParseInt(timestampStr, 10, 64)
		if err != nil {
			reject(errors.New("can not parse timestamp to int64: " + timestampStr))
			return
		}

		// Parse the outer RequestHttpBody.
		var requestHttpBody RequestHttpBody
		if err := xml.NewDecoder(r.Body).Decode(&requestHttpBody); err != nil {
			reject(err)
			return
		}

		haveCorpId := requestHttpBody.CorpId
		wantCorpId := agentServer.CorpId()
		if len(haveCorpId) != len(wantCorpId) ||
			subtle.ConstantTimeCompare([]byte(haveCorpId), []byte(wantCorpId)) != 1 {
			reject(fmt.Errorf("the RequestHttpBody's ToUserName mismatch, have: %s, want: %s", haveCorpId, wantCorpId))
			return
		}

		// An AgentId of 0 is let through for now; it is only accepted for
		// subscribe/unsubscribe events once the message has been decrypted.
		haveAgentId := requestHttpBody.AgentId
		wantAgentId := agentServer.AgentId()
		if haveAgentId != wantAgentId && haveAgentId != 0 {
			reject(fmt.Errorf("the RequestHttpBody's AgentId mismatch, have: %d, want: %d", haveAgentId, wantAgentId))
			return
		}

		agentToken := agentServer.Token()

		// Verify the message signature.
		msgSignature2 := util.MsgSign(agentToken, timestampStr, nonce, requestHttpBody.EncryptedMsg)
		if subtle.ConstantTimeCompare([]byte(msgSignature1), []byte(msgSignature2)) != 1 {
			reject(fmt.Errorf("check msg_signature failed, input: %s, local: %s", msgSignature1, msgSignature2))
			return
		}

		// Decrypt.
		encryptedMsgBytes, err := base64.StdEncoding.DecodeString(requestHttpBody.EncryptedMsg)
		if err != nil {
			reject(err)
			return
		}

		aesKey := agentServer.CurrentAESKey()
		random, rawMsgXML, err := util.AESDecryptMsg(encryptedMsgBytes, wantCorpId, aesKey)
		if err != nil {
			// Retry with the previous AES key, unless it equals the current
			// key or zeroAESKey, in which case a retry cannot help.
			lastAESKey := agentServer.LastAESKey()
			if bytes.Equal(aesKey[:], lastAESKey[:]) || bytes.Equal(zeroAESKey[:], lastAESKey[:]) {
				reject(err)
				return
			}

			aesKey = lastAESKey // NOTE: the key actually used is what ends up in Request.AESKey
			random, rawMsgXML, err = util.AESDecryptMsg(encryptedMsgBytes, wantCorpId, aesKey)
			if err != nil {
				reject(err)
				return
			}
		}

		// Decryption succeeded; parse the MixedMessage.
		var mixedMsg MixedMessage
		if err = xml.Unmarshal(rawMsgXML, &mixedMsg); err != nil {
			reject(err)
			return
		}

		// For safety, the decrypted message must agree with the outer body.
		if haveCorpId != mixedMsg.ToUserName {
			reject(fmt.Errorf("the RequestHttpBody's ToUserName(==%s) mismatch the MixedMessage's ToUserName(==%s)", haveCorpId, mixedMsg.ToUserName))
			return
		}
		if haveAgentId != mixedMsg.AgentId {
			reject(fmt.Errorf("the RequestHttpBody's AgentId(==%d) mismatch the MixedMessage's AgentId(==%d)", haveAgentId, mixedMsg.AgentId))
			return
		}

		// haveAgentId can only differ from wantAgentId here when it is 0,
		// which is accepted solely for subscribing to / unsubscribing from
		// the whole corporation account.
		if haveAgentId != wantAgentId {
			isSubscription := mixedMsg.MsgType == "event" &&
				(mixedMsg.Event == "subscribe" || mixedMsg.Event == "unsubscribe")
			if !isSubscription {
				reject(fmt.Errorf("the RequestHttpBody's AgentId mismatch, have: %d, want: %d", haveAgentId, wantAgentId))
				return
			}
		}

		// Success: hand over to the MessageHandler.
		req := &Request{
			HttpRequest: r,

			QueryValues:  urlValues,
			MsgSignature: msgSignature1,
			TimeStamp:    timestamp,
			Nonce:        nonce,

			RawMsgXML: rawMsgXML,
			MixedMsg:  &mixedMsg,

			AESKey: aesKey,
			Random: random,

			CorpId:     haveCorpId,
			AgentId:    haveAgentId,
			AgentToken: agentToken,
		}
		agentServer.MessageHandler().ServeMessage(w, req)

	case "GET": // initial verification
		msgSignature1, timestamp, nonce, encryptedMsg, err := parseGetURLQuery(urlValues)
		if err != nil {
			reject(err)
			return
		}

		// Verify the signature.
		if len(msgSignature1) != 40 {
			reject(fmt.Errorf("the length of msg_signature mismatch, have: %d, want: 40", len(msgSignature1)))
			return
		}

		msgSignature2 := util.MsgSign(agentServer.Token(), timestamp, nonce, encryptedMsg)
		if subtle.ConstantTimeCompare([]byte(msgSignature1), []byte(msgSignature2)) != 1 {
			reject(fmt.Errorf("check msg_signature failed, input: %s, local: %s", msgSignature1, msgSignature2))
			return
		}

		// Decrypt echostr; only the current AES key is tried here.
		encryptedMsgBytes, err := base64.StdEncoding.DecodeString(encryptedMsg)
		if err != nil {
			reject(err)
			return
		}

		corpId := agentServer.CorpId()
		aesKey := agentServer.CurrentAESKey()
		_, echostr, err := util.AESDecryptMsg(encryptedMsgBytes, corpId, aesKey)
		if err != nil {
			reject(err)
			return
		}

		w.Write(echostr)

	default:
		// Only GET and POST are handled here; report anything else instead
		// of returning without telling the handler.
		reject(errors.New("unsupported request method: " + r.Method))
	}
}
Ejemplo n.º 5
0
// ServeHTTP handles an HTTP callback request.
//
// POST requests carry a message. The query is parsed with parsePostURLQuery
// and its encrypt_type selects how the request is validated:
//
//   - "aes": msg_signature is verified against the encrypted payload, the
//     payload is decrypted (current AES key first, then the previous one if
//     it is usable) and parsed into a MixedMessage.
//   - "" or "raw": signature is verified against the query parameters and the
//     body is parsed directly into a MixedMessage.
//
// In both modes ToUserName is checked against wechatServer.WechatId(). On
// success a Request is passed to wechatServer.MessageHandler().
//
// GET requests are the initial verification: the query is parsed with
// parseGetURLQuery, signature is checked and echostr is written back.
//
// Every failure, including an unknown encrypt_type or an HTTP method other
// than GET/POST, is reported through invalidRequestHandler.ServeInvalidRequest.
//
// NOTE: the caller guarantees that all arguments are valid.
func ServeHTTP(w http.ResponseWriter, r *http.Request, urlValues url.Values,
	wechatServer WechatServer, invalidRequestHandler InvalidRequestHandler) {

	// reject hands a validation failure to the caller-supplied handler.
	reject := func(err error) { invalidRequestHandler.ServeInvalidRequest(w, r, err) }

	switch r.Method {
	case "POST": // message handling
		signature1, timestampStr, nonce, encryptType, msgSignature1, err := parsePostURLQuery(urlValues)
		if err != nil {
			reject(err)
			return
		}

		timestamp, err := strconv.ParseInt(timestampStr, 10, 64)
		if err != nil {
			reject(errors.New("can not parse timestamp to int64: " + timestampStr))
			return
		}

		switch encryptType {
		case "aes": // compatibility mode and safe mode
			// NOTE: the plain signature parameter is not verified in this
			// mode; only msg_signature is checked. signature1 is still
			// forwarded in the Request.

			// Check the length of the ciphertext signature first.
			if len(msgSignature1) != 40 {
				reject(fmt.Errorf("the length of msg_signature mismatch, have: %d, want: 40", len(msgSignature1)))
				return
			}

			var requestHttpBody RequestHttpBody
			if err := xml.NewDecoder(r.Body).Decode(&requestHttpBody); err != nil {
				reject(err)
				return
			}

			// For safety, check the outer ToUserName.
			haveToUserName := requestHttpBody.ToUserName
			wantToUserName := wechatServer.WechatId()
			if len(haveToUserName) != len(wantToUserName) ||
				subtle.ConstantTimeCompare([]byte(haveToUserName), []byte(wantToUserName)) != 1 {
				reject(fmt.Errorf("the RequestHttpBody's ToUserName mismatch, have: %s, want: %s", haveToUserName, wantToUserName))
				return
			}

			wechatToken := wechatServer.Token()

			// Verify the message signature.
			msgSignature2 := util.MsgSign(wechatToken, timestampStr, nonce, requestHttpBody.EncryptedMsg)
			if subtle.ConstantTimeCompare([]byte(msgSignature1), []byte(msgSignature2)) != 1 {
				reject(fmt.Errorf("check msg_signature failed, input: %s, local: %s", msgSignature1, msgSignature2))
				return
			}

			// Decrypt.
			encryptedMsgBytes, err := base64.StdEncoding.DecodeString(requestHttpBody.EncryptedMsg)
			if err != nil {
				reject(err)
				return
			}

			wechatAppId := wechatServer.AppId()
			aesKey := wechatServer.CurrentAESKey()
			random, rawMsgXML, err := util.AESDecryptMsg(encryptedMsgBytes, wechatAppId, aesKey)
			if err != nil {
				// Retry with the previous AES key, unless it equals
				// zeroAESKey or the current key, in which case a retry
				// cannot help.
				lastAESKey := wechatServer.LastAESKey()
				if bytes.Equal(zeroAESKey[:], lastAESKey[:]) || bytes.Equal(aesKey[:], lastAESKey[:]) {
					reject(err)
					return
				}

				aesKey = lastAESKey // NOTE: the key actually used is what ends up in Request.AESKey
				random, rawMsgXML, err = util.AESDecryptMsg(encryptedMsgBytes, wechatAppId, aesKey)
				if err != nil {
					reject(err)
					return
				}
			}

			// Decryption succeeded; parse the MixedMessage.
			var mixedMsg MixedMessage
			if err = xml.Unmarshal(rawMsgXML, &mixedMsg); err != nil {
				reject(err)
				return
			}

			// For safety, the decrypted ToUserName must match the outer one.
			if haveToUserName != mixedMsg.ToUserName {
				reject(fmt.Errorf("the RequestHttpBody's ToUserName(==%s) mismatch the MixedMessage's ToUserName(==%s)", haveToUserName, mixedMsg.ToUserName))
				return
			}

			// Success: hand over to the MessageHandler.
			req := &Request{
				HttpRequest: r,

				QueryValues: urlValues,
				Signature:   signature1,
				TimeStamp:   timestamp,
				Nonce:       nonce,

				RawMsgXML: rawMsgXML,
				MixedMsg:  &mixedMsg,

				MsgSignature: msgSignature1,
				EncryptType:  encryptType,
				AESKey:       aesKey,
				Random:       random,

				WechatId:    haveToUserName,
				WechatToken: wechatToken,
				WechatAppId: wechatAppId,
			}
			wechatServer.MessageHandler().ServeMessage(w, req)

		case "", "raw": // plaintext mode
			// Verify the signature first.
			if len(signature1) != 40 {
				reject(fmt.Errorf("the length of signature mismatch, have: %d, want: 40", len(signature1)))
				return
			}

			wechatToken := wechatServer.Token()
			signature2 := util.Sign(wechatToken, timestampStr, nonce)
			if subtle.ConstantTimeCompare([]byte(signature1), []byte(signature2)) != 1 {
				reject(fmt.Errorf("check signature failed, input: %s, local: %s", signature1, signature2))
				return
			}

			// Signature verified; read and parse the MixedMessage.
			// NOTE(review): the body is read in full with no size cap here;
			// consider limiting it before calling this function.
			rawMsgXML, err := ioutil.ReadAll(r.Body)
			if err != nil {
				reject(err)
				return
			}

			var mixedMsg MixedMessage
			if err := xml.Unmarshal(rawMsgXML, &mixedMsg); err != nil {
				reject(err)
				return
			}

			// For safety, check the message's ToUserName.
			haveToUserName := mixedMsg.ToUserName
			wantToUserName := wechatServer.WechatId()
			if len(haveToUserName) != len(wantToUserName) ||
				subtle.ConstantTimeCompare([]byte(haveToUserName), []byte(wantToUserName)) != 1 {
				reject(fmt.Errorf("the message's ToUserName mismatch, have: %s, want: %s", haveToUserName, wantToUserName))
				return
			}

			// Success: hand over to the MessageHandler.
			req := &Request{
				HttpRequest: r,

				QueryValues: urlValues,
				Signature:   signature1,
				TimeStamp:   timestamp,
				Nonce:       nonce,

				RawMsgXML: rawMsgXML,
				MixedMsg:  &mixedMsg,

				WechatId:    haveToUserName,
				WechatToken: wechatToken,
				WechatAppId: wechatServer.AppId(),
			}
			wechatServer.MessageHandler().ServeMessage(w, req)

		default: // unknown encryption type
			reject(errors.New("unknown encrypt_type: " + encryptType))
		}

	case "GET": // initial verification
		signature1, timestamp, nonce, echostr, err := parseGetURLQuery(urlValues)
		if err != nil {
			reject(err)
			return
		}

		if len(signature1) != 40 {
			reject(fmt.Errorf("the length of signature mismatch, have: %d, want: 40", len(signature1)))
			return
		}

		signature2 := util.Sign(wechatServer.Token(), timestamp, nonce)
		if subtle.ConstantTimeCompare([]byte(signature1), []byte(signature2)) != 1 {
			reject(fmt.Errorf("check signature failed, input: %s, local: %s", signature1, signature2))
			return
		}

		io.WriteString(w, echostr)

	default:
		// Only GET and POST are handled here; report anything else instead
		// of returning without telling the handler.
		reject(errors.New("unsupported request method: " + r.Method))
	}
}
Ejemplo n.º 6
0
// ServeHTTP handles an HTTP callback request.
//
// POST carries an encrypted message: the query parameters and the XML
// envelope are validated, msg_signature is checked, the payload is decrypted
// (falling back to the previous AES key when the current one fails) and the
// parsed message is handed to srv.MessageHandler(). GET is the initial URL
// verification: the signed echostr is decrypted and written to w. Every
// validation failure is reported through errHandler; requests using any other
// HTTP method are ignored.
//
// NOTE: the caller guarantees that all arguments are valid.
func ServeHTTP(w http.ResponseWriter, r *http.Request, queryValues url.Values, srv AgentServer, errHandler ErrorHandler) {
	// reject reports a validation failure; callers return right afterwards.
	reject := func(err error) {
		errHandler.ServeError(w, r, err)
	}

	// param reads a mandatory query parameter and rejects the request with
	// "<name> is empty" when it is missing.
	param := func(name string) (string, bool) {
		value := queryValues.Get(name)
		if value == "" {
			reject(errors.New(name + " is empty"))
		}
		return value, value != ""
	}

	switch r.Method {
	case "POST": // message handling
		inputSignature, ok := param("msg_signature")
		if !ok {
			return
		}

		timestampStr, ok := param("timestamp")
		if !ok {
			return
		}

		timestamp, parseErr := strconv.ParseInt(timestampStr, 10, 64)
		if parseErr != nil {
			reject(errors.New("can not parse timestamp to int64: " + timestampStr))
			return
		}

		nonce, ok := param("nonce")
		if !ok {
			return
		}

		// Parse the outer (still encrypted) XML envelope.
		var envelope RequestHttpBody
		if decodeErr := xml.NewDecoder(r.Body).Decode(&envelope); decodeErr != nil {
			reject(decodeErr)
			return
		}

		// The corp id is only enforced when the server has one configured.
		haveCorpId := envelope.CorpId
		wantCorpId := srv.CorpId()
		if wantCorpId != "" && !security.SecureCompareString(haveCorpId, wantCorpId) {
			reject(fmt.Errorf("the RequestHttpBody's ToUserName mismatch, have: %s, want: %s", haveCorpId, wantCorpId))
			return
		}

		// The agent id is only enforced when both a corp id and an agent id
		// (anything but -1) are configured. An incoming agent id of 0 is let
		// through here and re-examined once the message is decrypted.
		haveAgentId := envelope.AgentId
		wantAgentId := srv.AgentId()
		enforceAgent := wantCorpId != "" && wantAgentId != -1
		if enforceAgent && haveAgentId != wantAgentId && haveAgentId != 0 {
			reject(fmt.Errorf("the RequestHttpBody's AgentId mismatch, have: %d, want: %d", haveAgentId, wantAgentId))
			return
		}

		agentToken := srv.Token()

		// Verify the signature.
		// NOTE(review): the locally computed signature is embedded in the
		// error; confirm errHandler never echoes it back to the client.
		localSignature := util.MsgSign(agentToken, timestampStr, nonce, envelope.EncryptedMsg)
		if !security.SecureCompareString(inputSignature, localSignature) {
			reject(fmt.Errorf("check msg_signature failed, input: %s, local: %s", inputSignature, localSignature))
			return
		}

		// Decrypt.
		ciphertext, err := base64.StdEncoding.DecodeString(envelope.EncryptedMsg)
		if err != nil {
			reject(err)
			return
		}

		aesKey := srv.CurrentAESKey()
		random, rawMsgXML, aesAppId, err := util.AESDecryptMsg(ciphertext, aesKey)
		if err != nil {
			// Retry with the previous AES key, if the server still has one.
			previousKey, previousKeyValid := srv.LastAESKey()
			if !previousKeyValid {
				reject(err)
				return
			}

			// Remember which key worked; it is passed on in the Request.
			aesKey = previousKey

			random, rawMsgXML, aesAppId, err = util.AESDecryptMsg(ciphertext, aesKey)
			if err != nil {
				reject(err)
				return
			}
		}
		if haveCorpId != string(aesAppId) {
			reject(fmt.Errorf("the RequestHttpBody's ToUserName(==%s) mismatch the CorpId with aes encrypt(==%s)", haveCorpId, aesAppId))
			return
		}

		// Decryption succeeded; parse the inner message.
		var mixedMsg MixedMessage
		if err = xml.Unmarshal(rawMsgXML, &mixedMsg); err != nil {
			reject(err)
			return
		}

		// For safety, check the decrypted message against the envelope.
		if haveCorpId != mixedMsg.ToUserName {
			reject(fmt.Errorf("the RequestHttpBody's ToUserName(==%s) mismatch the MixedMessage's ToUserName(==%s)", haveCorpId, mixedMsg.ToUserName))
			return
		}
		if haveAgentId != mixedMsg.AgentId {
			reject(fmt.Errorf("the RequestHttpBody's AgentId(==%d) mismatch the MixedMessage's AgengId(==%d)", haveAgentId, mixedMsg.AgentId))
			return
		}

		// A mismatching agent id (i.e. 0 at this point) is only acceptable
		// for subscribe/unsubscribe events of the whole enterprise account.
		if enforceAgent && haveAgentId != wantAgentId {
			wholeAccountEvent := mixedMsg.MsgType == "event" &&
				(mixedMsg.Event == "subscribe" || mixedMsg.Event == "unsubscribe")
			if !wholeAccountEvent {
				reject(fmt.Errorf("the RequestHttpBody's AgentId mismatch, have: %d, want: %d", haveAgentId, wantAgentId))
				return
			}
		}

		// Success: hand the message over to the MessageHandler.
		req := &Request{
			AgentToken: agentToken,

			HttpRequest: r,
			QueryValues: queryValues,

			MsgSignature: inputSignature,
			Timestamp:    timestamp,
			Nonce:        nonce,

			RawMsgXML: rawMsgXML,
			MixedMsg:  &mixedMsg,

			AESKey:  aesKey,
			Random:  random,
			CorpId:  haveCorpId,
			AgentId: haveAgentId,
		}
		srv.MessageHandler().ServeMessage(w, req)

	case "GET": // initial verification
		inputSignature, ok := param("msg_signature")
		if !ok {
			return
		}

		timestamp, ok := param("timestamp")
		if !ok {
			return
		}

		nonce, ok := param("nonce")
		if !ok {
			return
		}

		encryptedMsg, ok := param("echostr")
		if !ok {
			return
		}

		localSignature := util.MsgSign(srv.Token(), timestamp, nonce, encryptedMsg)
		if !security.SecureCompareString(inputSignature, localSignature) {
			reject(fmt.Errorf("check msg_signature failed, input: %s, local: %s", inputSignature, localSignature))
			return
		}

		// Decrypt (only the current AES key is tried here).
		ciphertext, err := base64.StdEncoding.DecodeString(encryptedMsg)
		if err != nil {
			reject(err)
			return
		}

		_, echostr, aesAppId, err := util.AESDecryptMsg(ciphertext, srv.CurrentAESKey())
		if err != nil {
			reject(err)
			return
		}

		wantCorpId := srv.CorpId()
		if wantCorpId != "" && !security.SecureCompare(aesAppId, []byte(wantCorpId)) {
			reject(fmt.Errorf("AppId with aes encrypt mismatch, have: %s, want: %s", aesAppId, wantCorpId))
			return
		}

		// Echo the decrypted challenge back to the caller.
		w.Write(echostr)
	}
}
Ejemplo n.º 7
0
// ServeHTTP handles an HTTP callback request.
//
// POST carries an encrypted message: the query string and XML envelope are
// validated, the msg_signature is checked, the payload is decrypted with the
// agent's AES key and the parsed message is passed to msgDispatch. GET is the
// initial URL verification: the signed echostr is decrypted and written to w.
// Every validation failure is reported through invalidRequestHandler;
// requests using any other HTTP method are ignored.
//
// NOTE: make sure all arguments are valid and that r.Body can be read.
func ServeHTTP(w http.ResponseWriter, r *http.Request,
	urlValues url.Values, agent Agent, invalidRequestHandler InvalidRequestHandler) {

	// Expected length of a signature: a hex-encoded SHA-1 digest.
	const signatureLen = sha1.Size * 2

	// reject reports a validation failure; callers return right afterwards.
	reject := func(err error) {
		invalidRequestHandler.ServeInvalidRequest(w, r, err)
	}

	switch r.Method {
	case "POST": // message handling
		inputSignature, timestampStr, nonce, err := parsePostURLQuery(urlValues)
		if err != nil {
			reject(err)
			return
		}

		if len(inputSignature) != signatureLen {
			reject(fmt.Errorf("the length of msg_signature mismatch, have: %d, want: %d", len(inputSignature), signatureLen))
			return
		}

		timestamp, err := strconv.ParseInt(timestampStr, 10, 64)
		if err != nil {
			reject(fmt.Errorf("can not parse timestamp(==%q) to int64, error: %s", timestampStr, err.Error()))
			return
		}

		// Parse the outer (still encrypted) XML envelope.
		var envelope request.RequestHttpBody
		if decodeErr := xml.NewDecoder(r.Body).Decode(&envelope); decodeErr != nil {
			reject(decodeErr)
			return
		}

		// ConstantTimeCompare already fails on a length mismatch; the explicit
		// length test is kept so the evaluation matches the two-step check.
		agentCorpId := agent.GetCorpId()
		if len(envelope.CorpId) != len(agentCorpId) ||
			subtle.ConstantTimeCompare([]byte(envelope.CorpId), []byte(agentCorpId)) != 1 {
			reject(fmt.Errorf("the message RequestHttpBody's ToUserName mismatch, have: %s, want: %s", envelope.CorpId, agentCorpId))
			return
		}

		// An incoming agent id of 0 is let through here and re-examined once
		// the message has been decrypted.
		agentAgentId := agent.GetAgentId()
		if envelope.AgentId != agentAgentId && envelope.AgentId != 0 {
			reject(fmt.Errorf("the message RequestHttpBody's AgentId mismatch, have: %d, want: %d", envelope.AgentId, agentAgentId))
			return
		}

		// NOTE(review): the locally computed signature is embedded in the
		// error; confirm the handler never echoes it back to the client.
		localSignature := util.MsgSignature(agent.GetToken(), timestampStr, nonce, envelope.EncryptedMsg)
		if subtle.ConstantTimeCompare([]byte(inputSignature), []byte(localSignature)) != 1 {
			reject(fmt.Errorf("check signature failed, have: %s, want: %s", inputSignature, localSignature))
			return
		}

		ciphertext, err := base64.StdEncoding.DecodeString(envelope.EncryptedMsg)
		if err != nil {
			reject(err)
			return
		}

		random, rawXMLMsg, err := util.AESDecryptMsg(ciphertext, agentCorpId, agent.GetAESKey())
		if err != nil {
			reject(err)
			return
		}

		var msgReq request.Request
		if unmarshalErr := xml.Unmarshal(rawXMLMsg, &msgReq); unmarshalErr != nil {
			reject(unmarshalErr)
			return
		}

		// Check the decrypted message against the envelope.
		if envelope.CorpId != msgReq.ToUserName {
			reject(fmt.Errorf("the RequestHttpBody's ToUserName(==%s) mismatch the Request's ToUserName(==%s)", envelope.CorpId, msgReq.ToUserName))
			return
		}

		if envelope.AgentId != msgReq.AgentId {
			reject(fmt.Errorf("the RequestHttpBody's AgentId(==%d) mismatch the Request's AgengId(==%d)", envelope.AgentId, msgReq.AgentId))
			return
		}

		// A mismatching agent id (i.e. 0 at this point) is only acceptable
		// for subscribe/unsubscribe events of the whole enterprise account.
		if envelope.AgentId != agentAgentId {
			wholeAccountEvent := msgReq.MsgType == request.MSG_TYPE_EVENT &&
				(msgReq.Event == request.EVENT_TYPE_SUBSCRIBE || msgReq.Event == request.EVENT_TYPE_UNSUBSCRIBE)
			if !wholeAccountEvent {
				reject(fmt.Errorf("the message Request's AgentId mismatch, have: %d, want: %d", msgReq.AgentId, agentAgentId))
				return
			}
		}

		msgDispatch(w, r, &msgReq, rawXMLMsg, timestamp, nonce, random, agent)

	case "GET": // initial verification
		inputSignature, timestamp, nonce, encryptedMsg, err := parseGetURLQuery(urlValues)
		if err != nil {
			reject(err)
			return
		}

		if len(inputSignature) != signatureLen {
			reject(fmt.Errorf("the length of msg_signature mismatch, have: %d, want: %d", len(inputSignature), signatureLen))
			return
		}

		localSignature := util.MsgSignature(agent.GetToken(), timestamp, nonce, encryptedMsg)
		if subtle.ConstantTimeCompare([]byte(inputSignature), []byte(localSignature)) != 1 {
			reject(fmt.Errorf("check signature failed, have: %s, want: %s", inputSignature, localSignature))
			return
		}

		ciphertext, err := base64.StdEncoding.DecodeString(encryptedMsg)
		if err != nil {
			reject(err)
			return
		}

		_, echostr, err := util.AESDecryptMsg(ciphertext, agent.GetCorpId(), agent.GetAESKey())
		if err != nil {
			reject(err)
			return
		}

		// Echo the decrypted challenge back to the caller.
		w.Write(echostr)
	}
}
Ejemplo n.º 8
0
// ServeHTTP handles an HTTP callback request.
//
// POST carries a message and is only accepted with encrypt_type=aes: the
// query parameters and XML envelope are validated, msg_signature is checked,
// the payload is decrypted (falling back to the previous AES key when the
// current one fails) and the parsed message is handed to
// srv.MessageHandler(). GET is the initial URL verification: the signature
// is checked and echostr is written back unchanged. Every validation failure
// is reported through errHandler; requests using any other HTTP method are
// only logged.
//
// NOTE: the caller guarantees that all arguments are valid.
func ServeHTTP(w http.ResponseWriter, r *http.Request, queryValues url.Values, srv Server, errHandler mp.ErrorHandler) {
	mp.LogInfoln("[WECHAT_DEBUG] request uri:", r.RequestURI)
	mp.LogInfoln("[WECHAT_DEBUG] request remote-addr:", r.RemoteAddr)
	mp.LogInfoln("[WECHAT_DEBUG] request user-agent:", r.UserAgent())

	// reject reports a validation failure; callers return right afterwards.
	reject := func(err error) {
		errHandler.ServeError(w, r, err)
	}

	// param reads a mandatory query parameter and rejects the request with
	// "<name> is empty" when it is missing.
	param := func(name string) (string, bool) {
		value := queryValues.Get(name)
		if value == "" {
			reject(errors.New(name + " is empty"))
		}
		return value, value != ""
	}

	switch r.Method {
	case "POST": // message handling
		encryptType := queryValues.Get("encrypt_type")
		if encryptType != "aes" { // unknown encryption type
			reject(errors.New("unknown encrypt_type: " + encryptType))
			return
		}

		inputSignature, ok := param("msg_signature")
		if !ok {
			return
		}

		timestampStr, ok := param("timestamp")
		if !ok {
			return
		}

		timestamp, parseErr := strconv.ParseInt(timestampStr, 10, 64)
		if parseErr != nil {
			reject(errors.New("can not parse timestamp to int64: " + timestampStr))
			return
		}

		nonce, ok := param("nonce")
		if !ok {
			return
		}

		// The body is read in full so that it can be logged before parsing.
		reqBody, err := ioutil.ReadAll(r.Body)
		if err != nil {
			reject(err)
			return
		}
		mp.LogInfoln("[WECHAT_DEBUG] request msg http body:\r\n", string(reqBody))

		var envelope RequestHttpBody
		if unmarshalErr := xml.Unmarshal(reqBody, &envelope); unmarshalErr != nil {
			reject(unmarshalErr)
			return
		}

		// The app id is only enforced when the server has one configured.
		haveAppId := envelope.AppId
		wantAppId := srv.AppId()
		if wantAppId != "" && !security.SecureCompareString(haveAppId, wantAppId) {
			reject(fmt.Errorf("the RequestHttpBody's AppId mismatch, have: %s, want: %s", haveAppId, wantAppId))
			return
		}

		token := srv.Token()

		// Verify the signature.
		// NOTE(review): the locally computed signature is embedded in the
		// error; confirm errHandler never echoes it back to the client.
		localSignature := util.MsgSign(token, timestampStr, nonce, envelope.EncryptedMsg)
		if !security.SecureCompareString(inputSignature, localSignature) {
			reject(fmt.Errorf("check msg_signature failed, input: %s, local: %s", inputSignature, localSignature))
			return
		}

		// Decrypt.
		ciphertext, err := base64.StdEncoding.DecodeString(envelope.EncryptedMsg)
		if err != nil {
			reject(err)
			return
		}

		aesKey := srv.CurrentAESKey()
		random, rawMsgXML, aesAppId, err := util.AESDecryptMsg(ciphertext, aesKey)
		if err != nil {
			// Retry with the previous AES key, if the server still has one.
			previousKey, previousKeyValid := srv.LastAESKey()
			if !previousKeyValid {
				reject(err)
				return
			}

			// Remember which key worked; it is passed on in the Request.
			aesKey = previousKey

			random, rawMsgXML, aesAppId, err = util.AESDecryptMsg(ciphertext, aesKey)
			if err != nil {
				reject(err)
				return
			}
		}
		if haveAppId != string(aesAppId) {
			reject(fmt.Errorf("the RequestHttpBody's ToUserName(==%s) mismatch the AppId with aes encrypt(==%s)", haveAppId, aesAppId))
			return
		}

		mp.LogInfoln("[WECHAT_DEBUG] request msg raw xml:\r\n", string(rawMsgXML))

		// Decryption succeeded; parse the inner message.
		var mixedMsg MixedMessage
		if unmarshalErr := xml.Unmarshal(rawMsgXML, &mixedMsg); unmarshalErr != nil {
			reject(unmarshalErr)
			return
		}

		// For safety, check the decrypted message's AppId against the envelope.
		if haveAppId != mixedMsg.AppId {
			reject(fmt.Errorf("the RequestHttpBody's AppId(==%s) mismatch the MixedMessage's AppId(==%s)", haveAppId, mixedMsg.AppId))
			return
		}

		// Success: hand the message over to the MessageHandler.
		req := &Request{
			Token: token,

			HttpRequest: r,
			QueryValues: queryValues,

			MsgSignature: inputSignature,
			EncryptType:  encryptType,
			Timestamp:    timestamp,
			Nonce:        nonce,

			RawMsgXML: rawMsgXML,
			MixedMsg:  &mixedMsg,

			AESKey: aesKey,
			Random: random,
			AppId:  haveAppId,
		}
		srv.MessageHandler().ServeMessage(w, req)

	case "GET": // initial verification
		inputSignature, ok := param("signature")
		if !ok {
			return
		}

		timestamp, ok := param("timestamp")
		if !ok {
			return
		}

		nonce, ok := param("nonce")
		if !ok {
			return
		}

		echostr, ok := param("echostr")
		if !ok {
			return
		}

		localSignature := util.Sign(srv.Token(), timestamp, nonce)
		if !security.SecureCompareString(inputSignature, localSignature) {
			reject(fmt.Errorf("check signature failed, input: %s, local: %s", inputSignature, localSignature))
			return
		}

		// Echo the challenge back to the caller.
		io.WriteString(w, echostr)
	}
}
Ejemplo n.º 9
0
// ServeHTTP handles an HTTP callback request.
//
// POST carries an encrypted message: the query parameters and the XML
// envelope are validated, msg_signature is checked, the payload is decrypted
// (falling back to the previous AES key when the current one fails) and the
// parsed message is handed to srv.MessageHandler(). GET is the initial URL
// verification: the signed echostr is decrypted and written to w. Every
// validation failure is reported through errHandler; requests using any other
// HTTP method are only logged.
//
// NOTE: the caller guarantees that all arguments are valid.
func ServeHTTP(w http.ResponseWriter, r *http.Request, queryValues url.Values, srv Server, errHandler corp.ErrorHandler) {
	corp.LogInfoln("[WECHAT_DEBUG] request uri:", r.RequestURI)
	corp.LogInfoln("[WECHAT_DEBUG] request remote-addr:", r.RemoteAddr)
	corp.LogInfoln("[WECHAT_DEBUG] request user-agent:", r.UserAgent())

	// reject reports a validation failure; callers return right afterwards.
	reject := func(err error) {
		errHandler.ServeError(w, r, err)
	}

	// param reads a mandatory query parameter and rejects the request with
	// "<name> is empty" when it is missing.
	param := func(name string) (string, bool) {
		value := queryValues.Get(name)
		if value == "" {
			reject(errors.New(name + " is empty"))
		}
		return value, value != ""
	}

	switch r.Method {
	case "POST": // message handling
		inputSignature, ok := param("msg_signature")
		if !ok {
			return
		}
		if len(inputSignature) != 40 { // sha1
			reject(fmt.Errorf("the length of msg_signature mismatch, have: %d, want: 40", len(inputSignature)))
			return
		}

		timestampStr, ok := param("timestamp")
		if !ok {
			return
		}

		timestamp, parseErr := strconv.ParseInt(timestampStr, 10, 64)
		if parseErr != nil {
			reject(errors.New("can not parse timestamp to int64: " + timestampStr))
			return
		}

		nonce, ok := param("nonce")
		if !ok {
			return
		}

		// The body is read in full so that it can be logged before parsing.
		reqBody, err := ioutil.ReadAll(r.Body)
		if err != nil {
			reject(err)
			return
		}
		corp.LogInfoln("[WECHAT_DEBUG] request msg http body:\r\n", string(reqBody))

		// Parse the outer (still encrypted) XML envelope.
		var envelope RequestHttpBody
		if unmarshalErr := xml.Unmarshal(reqBody, &envelope); unmarshalErr != nil {
			reject(unmarshalErr)
			return
		}

		// The suite id is only enforced when the server has one configured.
		// ConstantTimeCompare already fails on a length mismatch; the explicit
		// length test is kept so the evaluation matches the two-step check.
		haveSuiteId := envelope.SuiteId
		wantSuiteId := srv.SuiteId()
		if wantSuiteId != "" &&
			(len(haveSuiteId) != len(wantSuiteId) ||
				subtle.ConstantTimeCompare([]byte(haveSuiteId), []byte(wantSuiteId)) != 1) {
			reject(fmt.Errorf("the RequestHttpBody's ToUserName mismatch, have: %s, want: %s", haveSuiteId, wantSuiteId))
			return
		}

		suiteToken := srv.SuiteToken()

		// Verify the signature.
		// NOTE(review): the locally computed signature is embedded in the
		// error; confirm errHandler never echoes it back to the client.
		localSignature := util.MsgSign(suiteToken, timestampStr, nonce, envelope.EncryptedMsg)
		if subtle.ConstantTimeCompare([]byte(inputSignature), []byte(localSignature)) != 1 {
			reject(fmt.Errorf("check msg_signature failed, input: %s, local: %s", inputSignature, localSignature))
			return
		}

		// Decrypt.
		ciphertext, err := base64.StdEncoding.DecodeString(envelope.EncryptedMsg)
		if err != nil {
			reject(err)
			return
		}

		aesKey := srv.CurrentAESKey()
		random, rawMsgXML, aesSuiteId, err := util.AESDecryptMsg(ciphertext, aesKey)
		if err != nil {
			// Retry with the previous AES key, if the server still has one.
			previousKey, previousKeyValid := srv.LastAESKey()
			if !previousKeyValid {
				reject(err)
				return
			}

			// Remember which key worked; it is passed on in the Request.
			aesKey = previousKey

			random, rawMsgXML, aesSuiteId, err = util.AESDecryptMsg(ciphertext, aesKey)
			if err != nil {
				reject(err)
				return
			}
		}
		if haveSuiteId != string(aesSuiteId) {
			reject(fmt.Errorf("the RequestHttpBody's ToUserName(==%s) mismatch the SuiteId with aes encrypt(==%s)", haveSuiteId, aesSuiteId))
			return
		}

		corp.LogInfoln("[WECHAT_DEBUG] request msg raw xml:\r\n", string(rawMsgXML))

		// Decryption succeeded; parse the inner message.
		var mixedMsg MixedMessage
		if err = xml.Unmarshal(rawMsgXML, &mixedMsg); err != nil {
			reject(err)
			return
		}

		// For safety, check the decrypted message against the envelope.
		if haveSuiteId != mixedMsg.SuiteId {
			reject(fmt.Errorf("the RequestHttpBody's ToUserName(==%s) mismatch the MixedMessage's SuiteId(==%s)", haveSuiteId, mixedMsg.SuiteId))
			return
		}

		// Success: hand the message over to the MessageHandler.
		req := &Request{
			SuiteToken: suiteToken,

			HttpRequest: r,
			QueryValues: queryValues,

			MsgSignature: inputSignature,
			Timestamp:    timestamp,
			Nonce:        nonce,

			RawMsgXML: rawMsgXML,
			MixedMsg:  &mixedMsg,

			AESKey:  aesKey,
			Random:  random,
			SuiteId: haveSuiteId,
		}
		srv.MessageHandler().ServeMessage(w, req)

	case "GET": // initial verification
		inputSignature, ok := param("msg_signature")
		if !ok {
			return
		}
		if len(inputSignature) != 40 { // sha1
			reject(fmt.Errorf("the length of msg_signature mismatch, have: %d, want: 40", len(inputSignature)))
			return
		}

		timestamp, ok := param("timestamp")
		if !ok {
			return
		}

		nonce, ok := param("nonce")
		if !ok {
			return
		}

		encryptedMsg, ok := param("echostr")
		if !ok {
			return
		}

		localSignature := util.MsgSign(srv.SuiteToken(), timestamp, nonce, encryptedMsg)
		if subtle.ConstantTimeCompare([]byte(inputSignature), []byte(localSignature)) != 1 {
			reject(fmt.Errorf("check msg_signature failed, input: %s, local: %s", inputSignature, localSignature))
			return
		}

		// Decrypt (only the current AES key is tried here).
		ciphertext, err := base64.StdEncoding.DecodeString(encryptedMsg)
		if err != nil {
			reject(err)
			return
		}

		_, echostr, aesAppId, err := util.AESDecryptMsg(ciphertext, srv.CurrentAESKey())
		if err != nil {
			reject(err)
			return
		}

		// The id recovered from the ciphertext must match the configured
		// suite id, when one is configured.
		wantAppId := srv.SuiteId()
		if wantAppId != "" && wantAppId != string(aesAppId) {
			reject(fmt.Errorf("AppId with aes encrypt mismatch, have: %s, want: %s", aesAppId, wantAppId))
			return
		}

		// Echo the decrypted challenge back to the caller.
		w.Write(echostr)
	}
}
Ejemplo n.º 10
0
// ServeHTTP handles an HTTP callback request.
//
// POST carries a message in one of two modes selected by encrypt_type:
// "" or "raw" (plaintext, authenticated by signature, forwarded to
// rawMsgDispatch) and "aes" (encrypted, authenticated by msg_signature,
// decrypted with the current or previous AES key, forwarded to
// aesMsgDispatch). GET is the initial URL verification: the signature is
// checked and echostr is written back unchanged. Every validation failure is
// reported through invalidRequestHandler; requests using any other HTTP
// method are ignored.
//
// NOTE: make sure all arguments are valid and that r.Body can be read.
func ServeHTTP(w http.ResponseWriter, r *http.Request,
	urlValues url.Values, agent Agent, invalidRequestHandler InvalidRequestHandler) {

	// Expected length of a signature: a hex-encoded SHA-1 digest.
	const signatureLen = sha1.Size * 2

	// reject reports a validation failure; callers return right afterwards.
	reject := func(err error) {
		invalidRequestHandler.ServeInvalidRequest(w, r, err)
	}

	switch r.Method {
	case "POST": // message handling
		inputSignature, timestampStr, nonce, encryptType, inputMsgSignature, err := parsePostURLQuery(urlValues)
		if err != nil {
			reject(err)
			return
		}

		timestamp, err := strconv.ParseInt(timestampStr, 10, 64)
		if err != nil {
			reject(err)
			return
		}

		switch encryptType {
		case "", "raw": // plaintext mode
			if len(inputSignature) != signatureLen {
				reject(fmt.Errorf("the length of signature mismatch, have: %d, want: %d", len(inputSignature), signatureLen))
				return
			}

			// NOTE(review): the locally computed signature is embedded in
			// the error; confirm the handler never echoes it to the client.
			localSignature := util.Signature(agent.GetToken(), timestampStr, nonce)
			if subtle.ConstantTimeCompare([]byte(inputSignature), []byte(localSignature)) != 1 {
				reject(fmt.Errorf("check signature failed, have: %s, want: %s", inputSignature, localSignature))
				return
			}

			rawXMLMsg, readErr := ioutil.ReadAll(r.Body)
			if readErr != nil {
				reject(readErr)
				return
			}

			var msgReq request.Request
			if unmarshalErr := xml.Unmarshal(rawXMLMsg, &msgReq); unmarshalErr != nil {
				reject(unmarshalErr)
				return
			}

			// ConstantTimeCompare already fails on a length mismatch; the
			// explicit length test is kept so the evaluation matches the
			// two-step check.
			wantToUserName := agent.GetId()
			if len(msgReq.ToUserName) != len(wantToUserName) ||
				subtle.ConstantTimeCompare([]byte(msgReq.ToUserName), []byte(wantToUserName)) != 1 {
				reject(fmt.Errorf("the message Request's ToUserName mismatch, have: %s, want: %s", msgReq.ToUserName, wantToUserName))
				return
			}

			rawMsgDispatch(w, r, &msgReq, rawXMLMsg, timestamp, agent)

		case "aes": // compatibility mode, secure mode
			// Only msg_signature is verified in this mode; the plain
			// signature parameter is deliberately not checked.
			if len(inputMsgSignature) != signatureLen {
				reject(fmt.Errorf("the length of msg_signature mismatch, have: %d, want: %d", len(inputMsgSignature), signatureLen))
				return
			}

			// Parse the outer (still encrypted) XML envelope.
			var envelope request.RequestHttpBody
			if decodeErr := xml.NewDecoder(r.Body).Decode(&envelope); decodeErr != nil {
				reject(decodeErr)
				return
			}

			wantToUserName := agent.GetId()
			if len(envelope.ToUserName) != len(wantToUserName) ||
				subtle.ConstantTimeCompare([]byte(envelope.ToUserName), []byte(wantToUserName)) != 1 {
				reject(fmt.Errorf("the message RequestHttpBody's ToUserName mismatch, have: %s, want: %s", envelope.ToUserName, wantToUserName))
				return
			}

			localMsgSignature := util.MsgSignature(agent.GetToken(), timestampStr, nonce, envelope.EncryptedMsg)
			if subtle.ConstantTimeCompare([]byte(inputMsgSignature), []byte(localMsgSignature)) != 1 {
				reject(fmt.Errorf("check signature failed, have: %s, want: %s", inputMsgSignature, localMsgSignature))
				return
			}

			ciphertext, err := base64.StdEncoding.DecodeString(envelope.EncryptedMsg)
			if err != nil {
				reject(err)
				return
			}

			activeKey := agent.GetCurrentAESKey()

			random, rawXMLMsg, err := util.AESDecryptMsg(ciphertext, agent.GetAppId(), activeKey)
			if err != nil {
				// Retry with the previous AES key, unless it is unset (all
				// zero) or identical to the key that has just failed.
				previousKey := agent.GetLastAESKey()
				if bytes.Equal(zeroAESKey[:], previousKey[:]) || bytes.Equal(activeKey[:], previousKey[:]) {
					reject(err)
					return
				}

				// Remember which key worked; it is passed to the dispatcher.
				activeKey = previousKey

				random, rawXMLMsg, err = util.AESDecryptMsg(ciphertext, agent.GetAppId(), activeKey)
				if err != nil {
					reject(err)
					return
				}
			}

			var msgReq request.Request
			if err = xml.Unmarshal(rawXMLMsg, &msgReq); err != nil {
				reject(err)
				return
			}

			// Check the decrypted message against the envelope.
			if envelope.ToUserName != msgReq.ToUserName {
				reject(fmt.Errorf("the RequestHttpBody's ToUserName(==%s) mismatch the Request's ToUserName(==%s)", envelope.ToUserName, msgReq.ToUserName))
				return
			}

			aesMsgDispatch(w, r, &msgReq, rawXMLMsg, timestamp, nonce, activeKey, random, agent)

		default: // unknown encryption type
			reject(errors.New("unknown encrypt_type"))
			return
		}

	case "GET": // initial verification
		inputSignature, timestamp, nonce, echostr, err := parseGetURLQuery(urlValues)
		if err != nil {
			reject(err)
			return
		}

		if len(inputSignature) != signatureLen {
			reject(fmt.Errorf("the length of signature mismatch, have: %d, want: %d", len(inputSignature), signatureLen))
			return
		}

		localSignature := util.Signature(agent.GetToken(), timestamp, nonce)
		if subtle.ConstantTimeCompare([]byte(inputSignature), []byte(localSignature)) != 1 {
			reject(fmt.Errorf("check signature failed, have: %s, want: %s", inputSignature, localSignature))
			return
		}

		// Echo the challenge back to the caller.
		io.WriteString(w, echostr)
	}
}
Ejemplo n.º 11
0
// ServeHTTP handles a single HTTP callback request.
//
// A GET request is treated as a callback-URL verification: the "signature"
// query parameter is checked against util.Sign(srv.Token(), timestamp, nonce)
// and, on success, the "echostr" query parameter is written back to w.
//
// A POST request carries a message and is dispatched on the "encrypt_type"
// query parameter:
//   - "aes": the body is decoded as a RequestHttpBody, "msg_signature" is
//     verified, the payload is base64-decoded and decrypted (falling back to
//     srv.LastAESKey() when the current key fails), and the result is parsed
//     into a MixedMessage.
//   - "" or "raw": "signature" is verified and the body itself is parsed into
//     a MixedMessage.
//   - anything else is reported as an error.
//
// Every validation failure is reported through errHandler.ServeError and ends
// processing; a successfully validated message is handed to
// srv.MessageHandler().ServeMessage. Requests using any other HTTP method are
// left untouched.
//
// NOTE: the caller guarantees that all arguments are valid.
func ServeHTTP(w http.ResponseWriter, r *http.Request, queryValues url.Values, srv Server, errHandler ErrorHandler) {
	// fail reports err through the caller-supplied error handler.
	fail := func(err error) {
		errHandler.ServeError(w, r, err)
	}

	// timestampAndNonce extracts the "timestamp" query parameter (both as sent
	// and parsed as a base-10 int64) and the "nonce" query parameter, which
	// both POST modes require. When the final result is false the failure has
	// already been reported through fail.
	timestampAndNonce := func() (string, int64, string, bool) {
		rawTimestamp := queryValues.Get("timestamp")
		if rawTimestamp == "" {
			fail(errors.New("timestamp is empty"))
			return "", 0, "", false
		}

		parsedTimestamp, parseErr := strconv.ParseInt(rawTimestamp, 10, 64)
		if parseErr != nil {
			// The parse error itself is dropped in favour of a fixed message.
			fail(errors.New("can not parse timestamp to int64: " + rawTimestamp))
			return "", 0, "", false
		}

		nonceValue := queryValues.Get("nonce")
		if nonceValue == "" {
			fail(errors.New("nonce is empty"))
			return "", 0, "", false
		}

		return rawTimestamp, parsedTimestamp, nonceValue, true
	}

	// GET: verify that the callback URL is valid.
	if r.Method == "GET" {
		inputSignature := queryValues.Get("signature")
		timestamp := queryValues.Get("timestamp")
		nonce := queryValues.Get("nonce")
		echo := queryValues.Get("echostr")

		// Report the first missing parameter, in this fixed order.
		switch {
		case inputSignature == "":
			fail(errors.New("signature is empty"))
			return
		case timestamp == "":
			fail(errors.New("timestamp is empty"))
			return
		case nonce == "":
			fail(errors.New("nonce is empty"))
			return
		case echo == "":
			fail(errors.New("echostr is empty"))
			return
		}

		localSignature := util.Sign(srv.Token(), timestamp, nonce)
		if !security.SecureCompareString(inputSignature, localSignature) {
			fail(fmt.Errorf("check signature failed, input: %s, local: %s", inputSignature, localSignature))
			return
		}

		io.WriteString(w, echo)
		return
	}

	// Only POST (message delivery) remains; every other method is ignored.
	if r.Method != "POST" {
		return
	}

	encryptType := queryValues.Get("encrypt_type")
	if encryptType != "aes" && encryptType != "" && encryptType != "raw" {
		// Unknown encryption type.
		fail(errors.New("unknown encrypt_type: " + encryptType))
		return
	}

	if encryptType == "aes" {
		// The plain signature is only read and passed along, never verified here.
		signature := queryValues.Get("signature")

		inputMsgSignature := queryValues.Get("msg_signature")
		if inputMsgSignature == "" {
			fail(errors.New("msg_signature is empty"))
			return
		}

		rawTimestamp, timestamp, nonce, ok := timestampAndNonce()
		if !ok {
			return
		}

		var envelope RequestHttpBody
		if err := xml.NewDecoder(r.Body).Decode(&envelope); err != nil {
			fail(err)
			return
		}

		// For safety, check the envelope's ToUserName when srv.OriId() is set.
		envelopeToUserName := envelope.ToUserName
		if expected := srv.OriId(); expected != "" && !security.SecureCompareString(envelopeToUserName, expected) {
			fail(fmt.Errorf("the RequestHttpBody's ToUserName mismatch, have: %s, want: %s", envelopeToUserName, expected))
			return
		}

		token := srv.Token()

		// Verify the message signature before touching the ciphertext.
		// NOTE(review): the error text contains the locally computed signature;
		// whether that is ever exposed depends on the ErrorHandler in use.
		localMsgSignature := util.MsgSign(token, rawTimestamp, nonce, envelope.EncryptedMsg)
		if !security.SecureCompareString(inputMsgSignature, localMsgSignature) {
			fail(fmt.Errorf("check msg_signature failed, input: %s, local: %s", inputMsgSignature, localMsgSignature))
			return
		}

		// Decrypt.
		ciphertext, err := base64.StdEncoding.DecodeString(envelope.EncryptedMsg)
		if err != nil {
			fail(err)
			return
		}

		aesKey := srv.CurrentAESKey()
		random, rawMsgXML, appIdBytes, err := util.AESDecryptMsg(ciphertext, aesKey)
		if err != nil {
			// Retry with the previous AES key, if the server still has a valid one.
			previousKey, usable := srv.LastAESKey()
			if !usable {
				fail(err)
				return
			}

			// aesKey must end up holding the key that actually worked, because
			// it is handed to the message handler below.
			aesKey = previousKey

			if random, rawMsgXML, appIdBytes, err = util.AESDecryptMsg(ciphertext, aesKey); err != nil {
				fail(err)
				return
			}
		}

		appId := string(appIdBytes)
		if expected := srv.AppId(); expected != "" && expected != appId {
			fail(fmt.Errorf("the message's appid mismatch, have: %s, want: %s", appId, expected))
			return
		}

		// Decryption succeeded; parse the MixedMessage.
		var mixedMsg MixedMessage
		if err := xml.Unmarshal(rawMsgXML, &mixedMsg); err != nil {
			fail(err)
			return
		}

		// For safety, the decrypted message must name the same recipient as the envelope.
		if envelopeToUserName != mixedMsg.ToUserName {
			fail(fmt.Errorf("the RequestHttpBody's ToUserName(==%s) mismatch the MixedMessage's ToUserName(==%s)", envelopeToUserName, mixedMsg.ToUserName))
			return
		}

		// Success: hand over to the MessageHandler.
		srv.MessageHandler().ServeMessage(w, &Request{
			Token:        token,
			HttpRequest:  r,
			QueryValues:  queryValues,
			Signature:    signature,
			Timestamp:    timestamp,
			Nonce:        nonce,
			EncryptType:  encryptType,
			RawMsgXML:    rawMsgXML,
			MixedMsg:     &mixedMsg,
			MsgSignature: inputMsgSignature,
			AESKey:       aesKey,
			Random:       random,
			AppId:        appId,
		})
		return
	}

	// Plain-text mode: encrypt_type is "" or "raw".
	inputSignature := queryValues.Get("signature")
	if inputSignature == "" {
		fail(errors.New("signature is empty"))
		return
	}

	rawTimestamp, timestamp, nonce, ok := timestampAndNonce()
	if !ok {
		return
	}

	token := srv.Token()

	localSignature := util.Sign(token, rawTimestamp, nonce)
	if !security.SecureCompareString(inputSignature, localSignature) {
		fail(fmt.Errorf("check signature failed, input: %s, local: %s", inputSignature, localSignature))
		return
	}

	// Signature verified; read and parse the MixedMessage.
	rawMsgXML, err := ioutil.ReadAll(r.Body)
	if err != nil {
		fail(err)
		return
	}

	var mixedMsg MixedMessage
	if err := xml.Unmarshal(rawMsgXML, &mixedMsg); err != nil {
		fail(err)
		return
	}

	// For safety, check the message's ToUserName when srv.OriId() is set.
	messageToUserName := mixedMsg.ToUserName
	if expected := srv.OriId(); expected != "" && !security.SecureCompareString(messageToUserName, expected) {
		fail(fmt.Errorf("the Message's ToUserName mismatch, have: %s, want: %s", messageToUserName, expected))
		return
	}

	// Success: hand over to the MessageHandler.
	srv.MessageHandler().ServeMessage(w, &Request{
		Token:       token,
		HttpRequest: r,
		QueryValues: queryValues,
		Signature:   inputSignature,
		Timestamp:   timestamp,
		Nonce:       nonce,
		EncryptType: encryptType,
		RawMsgXML:   rawMsgXML,
		MixedMsg:    &mixedMsg,
	})
}