Example #1
0
// checkOrigin is a websocket handshake hook. It stores the parsed Origin of
// req in config.Origin and fails the handshake when websocket.Origin returns
// an error, or returns no origin at all (reported as "null origin").
func checkOrigin(config *websocket.Config, req *http.Request) (err error) {
	origin, err := websocket.Origin(config, req)
	config.Origin = origin
	switch {
	case err != nil:
		return err
	case origin == nil:
		return fmt.Errorf("null origin")
	default:
		return nil
	}
}
Example #2
0
// checkOrigin validates the Origin of a WebSocket handshake against the
// application's configured origins.
//
// When a.origins is empty no check is performed and nil is returned.
// Otherwise conf.Origin is populated from req and the result is:
//   - the error from websocket.Origin if parsing fails (logged at WARNING);
//   - ErrMissingOrigin if no origin was obtained;
//   - nil if isSameOrigin accepts the origin for any configured entry;
//   - ErrInvalidOrigin otherwise (also logged at WARNING).
func (a *Application) checkOrigin(conf *websocket.Config, req *http.Request) (err error) {
	if len(a.origins) == 0 {
		return nil // no allow-list configured: accept everything
	}

	conf.Origin, err = websocket.Origin(conf, req)
	switch {
	case err != nil:
		if a.log.ShouldLog(WARNING) {
			fields := LogFields{
				"rid":   req.Header.Get(HeaderID),
				"error": err.Error(),
			}
			a.log.Warn("http", "Error parsing WebSocket origin", fields)
		}
		return err
	case conf.Origin == nil:
		return ErrMissingOrigin
	}

	for _, want := range a.origins {
		if isSameOrigin(conf.Origin, want) {
			return nil
		}
	}

	if a.log.ShouldLog(WARNING) {
		fields := LogFields{
			"rid":    req.Header.Get(HeaderID),
			"origin": conf.Origin.String(),
		}
		a.log.Warn("http", "Rejected WebSocket connection from unknown origin", fields)
	}
	return ErrInvalidOrigin
}
Example #3
0
// checkOrigin populates conf.Origin from the request and, if the handler has
// an origin allow-list, requires the origin to match one of its entries.
//
// The origin is parsed unconditionally so conf.Origin is filled in even when
// no allow-list is configured. A parse error is logged at NOTICE and is not
// returned itself. With an allow-list the request is then judged on whatever
// conf.Origin holds; presumably a failed parse leaves it nil, which yields
// ErrMissingOrigin (confirm against websocket.Origin).
//
// Returns nil when there is no allow-list or the origin matches one entry via
// isSameOrigin, ErrMissingOrigin when no origin is available, and
// ErrInvalidOrigin (logged at WARNING) otherwise.
func (h *SocketHandler) checkOrigin(conf *websocket.Config, req *http.Request) (err error) {
	conf.Origin, err = websocket.Origin(conf, req)
	if err != nil && h.logger.ShouldLog(NOTICE) {
		h.logger.Notice("handlers_socket", "Error parsing WebSocket origin",
			LogFields{"rid": req.Header.Get(HeaderID), "error": err.Error()})
	}

	allowList := h.origins
	if len(allowList) == 0 {
		return nil
	}

	origin := conf.Origin
	if origin == nil {
		return ErrMissingOrigin
	}
	for _, candidate := range allowList {
		if isSameOrigin(origin, candidate) {
			return nil
		}
	}

	if h.logger.ShouldLog(WARNING) {
		fields := LogFields{
			"rid":    req.Header.Get(HeaderID),
			"origin": origin.String(),
		}
		h.logger.Warn("handlers_socket",
			"Rejected WebSocket connection from unknown origin", fields)
	}
	return ErrInvalidOrigin
}
Example #4
0
// handshake checks the origin of a request during the websocket handshake.
//
// c.Origin holds the origin the request is compared against. It must include
// an explicit port, because it is split with net.SplitHostPort. The request's
// Origin is accepted when its scheme equals c.Origin's scheme and its host
// equals c.Origin.Host, either exactly or after appending c.Origin's port.
//
// Every rejection is logged and reported as websocket.ErrBadWebSocketOrigin.
// Accepted connections are logged with the client's remote address.
func handshake(c *websocket.Config, req *http.Request) error {
	o, err := websocket.Origin(c, req)
	if err != nil {
		log.Println("bad websocket origin:", err)
		return websocket.ErrBadWebSocketOrigin
	}
	if o == nil {
		// websocket.Origin may return (nil, nil), e.g. when the client sent
		// no Origin header. Comparing o.Scheme below would then panic.
		log.Println("bad websocket origin: missing Origin header")
		return websocket.ErrBadWebSocketOrigin
	}
	if c.Origin == nil {
		// Without an expected origin there is nothing to compare against;
		// reject instead of dereferencing nil.
		log.Println("bad websocket origin: no expected origin configured")
		return websocket.ErrBadWebSocketOrigin
	}
	_, port, err := net.SplitHostPort(c.Origin.Host)
	if err != nil {
		log.Println("bad websocket origin:", err)
		return websocket.ErrBadWebSocketOrigin
	}
	sameScheme := c.Origin.Scheme == o.Scheme
	// Browsers may omit the port from Origin, so also accept the request host
	// with the expected port appended.
	sameHost := c.Origin.Host == o.Host || c.Origin.Host == net.JoinHostPort(o.Host, port)
	if !sameScheme || !sameHost {
		log.Println("bad websocket origin:", o)
		return websocket.ErrBadWebSocketOrigin
	}
	log.Println("accepting connection from:", req.RemoteAddr)
	return nil
}
Example #5
0
// apiRouter dispatches an HTTP request on its exact URL path:
//
//	"/"                 writes a fixed greeting
//	"/pub"              publishes the request body to an MQTT topic (handlePub)
//	"/stats"            currently does nothing
//	self.WebSocketMount serves MQTT over WebSocket (handleWebSocket)
//
// Any other path yields the error "404 <path>". A nil return means the request
// was handled; how the caller reports a non-nil error is not visible here.
func (self *MyHttpServer) apiRouter(w http.ResponseWriter, req *http.Request) error {
	switch req.URL.Path {
	case "/":
		fmt.Fprint(w, "HELO MOMONGA WORLD")
	case "/pub":
		return self.handlePub(w, req)
	case "/stats":
		// NOTE(review): nothing is written; presumably a placeholder.
	case self.WebSocketMount:
		self.handleWebSocket(w, req)
	default:
		return fmt.Errorf("404 %s", req.URL.Path)
	}
	return nil
}

// handlePub publishes the request body to the topic named by the "topic"
// query parameter. The QoS comes from the "qos" parameter and defaults to 0.
// It writes "OK" on success.
//
// Returns an error when the query cannot be parsed, the topic is missing, the
// qos is not 0, 1 or 2, the body cannot be read, or the body is empty or
// longer than 8192 bytes.
func (self *MyHttpServer) handlePub(w http.ResponseWriter, req *http.Request) error {
	params, err := url.ParseQuery(req.URL.RawQuery)
	if err != nil {
		return err
	}

	topic := params.Get("topic")
	if topic == "" {
		return fmt.Errorf("topic required")
	}

	qos := params.Get("qos")
	if qos == "" {
		qos = "0"
	}
	rqos, err := strconv.ParseInt(qos, 10, 32)
	if err != nil || rqos < 0 || rqos > 2 {
		return fmt.Errorf("invalid qos %q", qos)
	}

	const readMax = 8192
	// Read one extra byte so an oversized body is reported instead of being
	// silently truncated and published.
	body, err := ioutil.ReadAll(io.LimitReader(req.Body, readMax+1))
	if err != nil {
		return err
	}
	if len(body) < 1 {
		return fmt.Errorf("body required")
	}
	if len(body) > readMax {
		return fmt.Errorf("body too large")
	}

	self.Engine.SendMessage(topic, body, int(rqos))
	w.Write([]byte("OK"))
	return nil
}

// handleWebSocket upgrades req to a WebSocket connection and hands it to the
// engine as a new MQTT connection.
func (self *MyHttpServer) handleWebSocket(w http.ResponseWriter, req *http.Request) {
	s := websocket.Server{
		Handler: websocket.Handler(func(ws *websocket.Conn) {
			// MQTT packets are carried in binary frames.
			ws.PayloadType = websocket.BinaryFrame

			myconf := GetDefaultMyConfig()
			myconf.MaxMessageSize = self.Engine.Config().Server.MessageSizeLimit
			conn := NewMyConnection(myconf)
			conn.SetMyConnection(ws)
			conn.SetId(ws.RemoteAddr().String())
			self.Engine.HandleConnection(conn)
		}),
		Handshake: mqttWebSocketHandshake,
	}
	s.ServeHTTP(w, req)
}

// mqttWebSocketHandshake is the handshake hook for the MQTT WebSocket endpoint.
//
// Origin handling: a request for which websocket.Origin yields no origin and
// no error is rejected with "null origin". A request whose Origin fails to
// parse is accepted with config.Origin set to http://localhost.
// NOTE(review): accepting malformed origins looks unintended; it is kept
// unchanged here, so confirm the intent before tightening it.
//
// Subprotocol handling: a non-empty config.Protocol is narrowed to the single
// subprotocol chosen by selectMQTTProtocol, or the handshake fails.
func mqttWebSocketHandshake(config *websocket.Config, req *http.Request) (err error) {
	config.Origin, err = websocket.Origin(config, req)
	if err == nil && config.Origin == nil {
		return fmt.Errorf("null origin")
	}
	if config.Origin == nil {
		config.Origin = &url.URL{Scheme: "http", Host: "localhost"}
		err = nil
	}

	if len(config.Protocol) > 0 {
		selected, perr := selectMQTTProtocol(config.Protocol)
		if perr != nil {
			return perr
		}
		config.Protocol = selected
	}
	return err
}

// selectMQTTProtocol picks the subprotocol to answer from the non-empty list
// a client offered. "mqttv3.1" is preferred over "mqtt", and other entries are
// ignored. The answer is always one of the offered values, as RFC 6455
// requires, and is returned as a one-element slice. The error
// "unsupported protocol" is returned when nothing offered is supported.
func selectMQTTProtocol(offered []string) ([]string, error) {
	selected := ""
	for _, p := range offered {
		switch p {
		case "mqttv3.1":
			selected = p
		case "mqtt":
			if selected == "" {
				selected = p
			}
		}
	}
	if selected == "" {
		return nil, fmt.Errorf("unsupported protocol")
	}
	return []string{selected}, nil
}
Example #6
0
// checkOrigin is a websocket handshake hook. It populates wsconf.Origin, the
// same way the stock handshaker does, and enforces the configured origin
// policy. Returning an error makes the handshaker answer with 403.
//
// Policy, as implemented below:
//   - A missing Origin header, or the literal "null" when config.AllowOrigins
//     is nil, is rewritten to "file:" before parsing.
//   - An Origin that fails to parse, or that websocket.Origin reports as nil,
//     is rejected.
//   - With config.SameOrigin, the origin's host and port (via tellHostPort)
//     must equal those derived from req.Host.
//   - With config.AllowOrigins non-nil, the origin must match at least one
//     entry, using the rules described in the loop below.
//
// Every rejection is recorded with log.Access.
func checkOrigin(wsconf *websocket.Config, req *http.Request, config *Config, log *LogScope) (err error) {
	origin := req.Header.Get("Origin")
	if origin == "" || (origin == "null" && config.AllowOrigins == nil) {
		// The literal "null" is only tolerated when no allow-list is active.
		// With an allow-list it is left as-is, and presumably fails to parse
		// below, since "null" is not a valid request URI.
		// NOTE(review): this mutates the caller's request headers.
		req.Header.Set("Origin", "file:")
	}

	wsconf.Origin, err = websocket.Origin(wsconf, req)
	if err == nil && wsconf.Origin == nil {
		log.Access("session", "rejected null origin")
		return fmt.Errorf("null origin not allowed")
	}
	if err != nil {
		log.Access("session", "Origin parsing error: %s", err)
		return err
	}
	log.Associate("origin", wsconf.Origin.String())

	// Nothing more to check unless some origin restriction is configured.
	if !config.SameOrigin && config.AllowOrigins == nil {
		return nil
	}

	originServer, originPort, err := tellHostPort(wsconf.Origin.Host, wsconf.Origin.Scheme == "https")
	if err != nil {
		log.Access("session", "Origin hostname parsing error: %s", err)
		return err
	}

	if config.SameOrigin {
		localServer, localPort, err := tellHostPort(req.Host, req.TLS != nil)
		if err != nil {
			log.Access("session", "Request hostname parsing error: %s", err)
			return err
		}
		if originServer != localServer || originPort != localPort {
			log.Access("session", "Same origin policy mismatch")
			return fmt.Errorf("same origin policy violated")
		}
	}

	if config.AllowOrigins != nil {
		matchFound := false
		for _, allowed := range config.AllowOrigins {
			if pos := strings.Index(allowed, "://"); pos > 0 {
				// An entry with a scheme only matches origins of that scheme.
				allowedURL, err := url.Parse(allowed)
				if err != nil {
					continue // skip bad URLs in the origin list
				}
				if allowedURL.Scheme != wsconf.Origin.Scheme {
					continue // scheme mismatch
				}
				allowed = allowed[pos+3:]
			}
			allowServer, allowPort, err := tellHostPort(allowed, false)
			if err != nil {
				continue // unparseable entry
			}
			// HasSuffix is safe for entries shorter than 3 bytes, which the
			// previous allowed[len(allowed)-3:] slice would panic on.
			if allowPort == "80" && !strings.HasSuffix(allowed, ":80") {
				// Port 80 was not spelled out (presumably tellHostPort's
				// default), so any origin port is accepted and only the host
				// names must match.
				matchFound = allowServer == originServer
			} else {
				// Exact match of host names and ports.
				matchFound = allowServer == originServer && allowPort == originPort
			}
			if matchFound {
				break
			}
		}
		if !matchFound {
			log.Access("session", "Origin is not listed in allowed list")
			return fmt.Errorf("origin list matches were not found")
		}
	}
	return nil
}