Пример #1
0
// ServeHTTP serves the SQL API by treating the request URL path
// as the method, the request body as the arguments, and sets the
// response body as the method reply. The request body is unmarshalled
// into arguments based on the Content-Type request header. Protobuf
// and JSON-encoded requests are supported. The response body is
// encoded according to the request's Accept header, or if not
// present, in the same format as the request's incoming Content-Type
// header.
//
// Only the path driver.Endpoint + driver.Execute is served; any other
// path is answered with 404. TLS/user authentication failures yield 401,
// an unreadable body 500, an undecodable request 400, and a failed
// Execute the status code that Execute reports.
func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	defer r.Body.Close()
	method := r.URL.Path
	if !strings.HasPrefix(method, driver.Endpoint) {
		http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
		return
	}

	// Check TLS settings.
	authenticationHook, err := security.ProtoAuthHook(s.context.Insecure, r.TLS)
	if err != nil {
		http.Error(w, err.Error(), http.StatusUnauthorized)
		return
	}

	method = strings.TrimPrefix(method, driver.Endpoint)
	if method != driver.Execute.String() {
		http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
		return
	}

	// Unmarshal the request.
	reqBody, err := ioutil.ReadAll(r.Body)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	var args driver.Request
	if err := util.UnmarshalRequest(r, reqBody, &args, allowedEncodings); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	// Check request user against client certificate user.
	if err := authenticationHook(&args, true /* public */); err != nil {
		http.Error(w, err.Error(), http.StatusUnauthorized)
		return
	}

	reply, code, err := s.Execute(args)
	if err != nil {
		// http.Error has already written the status line and body; stop here
		// so the reply is not marshalled and appended to the error text.
		http.Error(w, err.Error(), code)
		return
	}

	// Marshal the response.
	body, contentType, err := util.MarshalResponse(r, &reply, allowedEncodings)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Header().Set(util.ContentTypeHeader, contentType)
	if _, err := w.Write(body); err != nil {
		// NOTE(review): the status line and headers were sent by the Write
		// above, so this cannot change the response status; consider logging
		// the failure instead.
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}
}
Пример #2
0
// TestAuthenticationHook exercises security.ProtoAuthHook in insecure and
// secure modes. For each case it checks whether building the hook from the
// given TLS state succeeds and, when it does, whether the hook accepts the
// request first in public mode and then in non-public mode.
func TestAuthenticationHook(t *testing.T) {
	defer leaktest.AfterTest(t)

	// GetResponse is a proto that does not implement GetUser.
	noUserProto := &roachpb.GetResponse{}
	batchProto := &roachpb.BatchRequest{}

	cases := []struct {
		insecure           bool
		tls                *tls.ConnectionState
		request            proto.Message
		buildHookSuccess   bool
		publicHookSuccess  bool
		privateHookSuccess bool
	}{
		// Insecure mode, nil request.
		{insecure: true, tls: nil, request: nil, buildHookSuccess: true},
		// Insecure mode, bad request.
		{insecure: true, tls: nil, request: noUserProto, buildHookSuccess: true},
		// Insecure mode, good request.
		{
			insecure:           true,
			tls:                nil,
			request:            batchProto,
			buildHookSuccess:   true,
			publicHookSuccess:  true,
			privateHookSuccess: true,
		},
		// Secure mode, no TLS state.
		{insecure: false, tls: nil, request: nil},
		// Secure mode, bad user.
		{
			tls:              makeFakeTLSState([]string{"foo"}, []int{1}),
			request:          batchProto,
			buildHookSuccess: true,
		},
		// Secure mode, node user.
		{
			tls:                makeFakeTLSState([]string{security.NodeUser}, []int{1}),
			request:            batchProto,
			buildHookSuccess:   true,
			publicHookSuccess:  true,
			privateHookSuccess: true,
		},
		// Secure mode, root user.
		{
			tls:              makeFakeTLSState([]string{security.RootUser}, []int{1}),
			request:          batchProto,
			buildHookSuccess: true,
		},
	}

	for i, c := range cases {
		hook, err := security.ProtoAuthHook(c.insecure, c.tls)
		if (err == nil) != c.buildHookSuccess {
			t.Fatalf("#%d: expected success=%t, got err=%v", i, c.buildHookSuccess, err)
		}
		if err != nil {
			// Construction failed as expected; there is no hook to call.
			continue
		}

		// Public mode is checked before non-public mode.
		checks := [...]struct {
			public bool
			want   bool
		}{
			{public: true, want: c.publicHookSuccess},
			{public: false, want: c.privateHookSuccess},
		}
		for _, chk := range checks {
			if err := hook(c.request, chk.public); (err == nil) != chk.want {
				t.Fatalf("#%d: expected success=%t, got err=%v", i, chk.want, err)
			}
		}
	}
}
Пример #3
0
// ServeHTTP implements an http.Handler that answers RPC requests.
//
// The client must use the CONNECT method. An authentication hook is built
// from the server's security mode and the request's TLS state. Then the
// underlying connection is hijacked and acknowledged with an
// "HTTP/1.0 <codec.Connected>" line. After that, s.sendResponses and
// s.readRequests are started in separate goroutines on a server codec
// wrapping the connection. The codec and connection are closed once
// readRequests returns.
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	// Note: this code was adapted from net/rpc.Server.ServeHTTP.
	if req.Method != "CONNECT" {
		http.Error(w, "405 must CONNECT", http.StatusMethodNotAllowed)
		return
	}

	// Construct an authentication hook for this security mode and TLS state.
	authHook, err := security.ProtoAuthHook(s.insecure, req.TLS)
	if err != nil {
		http.Error(w, err.Error(), http.StatusUnauthorized)
		return
	}

	// Not every ResponseWriter supports hijacking (e.g. HTTP/2 writers), so
	// check instead of letting the type assertion panic.
	hijacker, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "rpc: connection does not support hijacking", http.StatusInternalServerError)
		return
	}
	conn, _, err := hijacker.Hijack()
	if err != nil {
		log.Infof("rpc hijacking %s: %s", req.RemoteAddr, err)
		return
	}

	if log.V(3) {
		security.LogTLSState("RPC", req.TLS)
	}

	// w is hijacked from here on and must not be written to; failures are
	// reported via the log and by closing the raw connection.
	if _, err := io.WriteString(conn, "HTTP/1.0 "+codec.Connected+"\n\n"); err != nil {
		log.Infof("rpc connect reply to %s: %s", req.RemoteAddr, err)
		conn.Close()
		return
	}

	// Run open callbacks.
	s.runOpenCallbacks(conn)

	// Named serverCodec so the codec package is not shadowed.
	serverCodec := codec.NewServerCodec(conn)
	responses := make(chan serverResponse)
	go s.sendResponses(serverCodec, responses)
	go func() {
		s.readRequests(conn, serverCodec, authHook, responses)
		serverCodec.Close()
		conn.Close()
	}()
}