// ServeHTTP serves the SQL API by treating the request URL path // as the method, the request body as the arguments, and sets the // response body as the method reply. The request body is unmarshalled // into arguments based on the Content-Type request header. Protobuf // and JSON-encoded requests are supported. The response body is // encoded according to the request's Accept header, or if not // present, in the same format as the request's incoming Content-Type // header. func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() method := r.URL.Path if !strings.HasPrefix(method, driver.Endpoint) { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } // Check TLS settings. authenticationHook, err := security.ProtoAuthHook(s.context.Insecure, r.TLS) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } method = strings.TrimPrefix(method, driver.Endpoint) if method != driver.Execute.String() { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } // Unmarshal the request. reqBody, err := ioutil.ReadAll(r.Body) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } var args driver.Request if err := util.UnmarshalRequest(r, reqBody, &args, allowedEncodings); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } // Check request user against client certificate user. if err := authenticationHook(&args, true /* public */); err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } reply, code, err := s.Execute(args) if err != nil { http.Error(w, err.Error(), code) } // Marshal the response. body, contentType, err := util.MarshalResponse(r, &reply, allowedEncodings) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } w.Header().Set(util.ContentTypeHeader, contentType) if _, err := w.Write(body); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } }
func TestAuthenticationHook(t *testing.T) { defer leaktest.AfterTest(t) // Proto that does not implement GetUser. badRequest := &roachpb.GetResponse{} goodRequest := &roachpb.BatchRequest{} testCases := []struct { insecure bool tls *tls.ConnectionState request proto.Message buildHookSuccess bool publicHookSuccess bool privateHookSuccess bool }{ // Insecure mode, nil request. {true, nil, nil, true, false, false}, // Insecure mode, bad request. {true, nil, badRequest, true, false, false}, // Insecure mode, good request. {true, nil, goodRequest, true, true, true}, // Secure mode, no TLS state. {false, nil, nil, false, false, false}, // Secure mode, bad user. {false, makeFakeTLSState([]string{"foo"}, []int{1}), goodRequest, true, false, false}, // Secure mode, node user. {false, makeFakeTLSState([]string{security.NodeUser}, []int{1}), goodRequest, true, true, true}, // Secure mode, root user. {false, makeFakeTLSState([]string{security.RootUser}, []int{1}), goodRequest, true, false, false}, } for tcNum, tc := range testCases { hook, err := security.ProtoAuthHook(tc.insecure, tc.tls) if (err == nil) != tc.buildHookSuccess { t.Fatalf("#%d: expected success=%t, got err=%v", tcNum, tc.buildHookSuccess, err) } if err != nil { continue } err = hook(tc.request, true /*public*/) if (err == nil) != tc.publicHookSuccess { t.Fatalf("#%d: expected success=%t, got err=%v", tcNum, tc.publicHookSuccess, err) } err = hook(tc.request, false /*not public*/) if (err == nil) != tc.privateHookSuccess { t.Fatalf("#%d: expected success=%t, got err=%v", tcNum, tc.privateHookSuccess, err) } } }
// ServeHTTP implements an http.Handler that answers RPC requests. func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Note: this code was adapted from net/rpc.Server.ServeHTTP. if req.Method != "CONNECT" { http.Error(w, "405 must CONNECT", http.StatusMethodNotAllowed) return } // Construct an authentication hook for this security mode and TLS state. authHook, err := security.ProtoAuthHook(s.insecure, req.TLS) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } conn, _, err := w.(http.Hijacker).Hijack() if err != nil { log.Infof("rpc hijacking ", req.RemoteAddr, ": ", err) return } if log.V(3) { security.LogTLSState("RPC", req.TLS) } if _, err := io.WriteString(conn, "HTTP/1.0 "+codec.Connected+"\n\n"); err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } // Run open callbacks. s.runOpenCallbacks(conn) codec := codec.NewServerCodec(conn) responses := make(chan serverResponse) go func() { s.sendResponses(codec, responses) }() go func() { s.readRequests(conn, codec, authHook, responses) codec.Close() conn.Close() }() }