// ServeHTTP implements an http.Handler that answers RPC requests. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.URL.Path != rpc.DefaultRPCPath { if s.handler != nil { s.handler.ServeHTTP(w, r) return } http.NotFound(w, r) return } // Note: this code was adapted from net/rpc.Server.ServeHTTP. if r.Method != "CONNECT" { http.Error(w, "405 must CONNECT", http.StatusMethodNotAllowed) return } // Construct an authentication hook for this security mode and TLS state. authHook, err := security.AuthenticationHook(s.context.Insecure, r.TLS) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } conn, _, err := w.(http.Hijacker).Hijack() if err != nil { log.Infof("rpc hijacking %s: %s", r.RemoteAddr, err) return } security.LogTLSState("RPC", r.TLS) io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") s.serveConn(conn, authHook) }
// ServeHTTP serves the SQL API by treating the request URL path // as the method, the request body as the arguments, and sets the // response body as the method reply. The request body is unmarshalled // into arguments based on the Content-Type request header. Protobuf // and JSON-encoded requests are supported. The response body is // encoded according to the request's Accept header, or if not // present, in the same format as the request's incoming Content-Type // header. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() method := r.URL.Path if !strings.HasPrefix(method, driver.Endpoint) { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } // Check TLS settings. authenticationHook, err := security.AuthenticationHook(s.context.Insecure, r.TLS) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } method = strings.TrimPrefix(method, driver.Endpoint) if method != driver.Execute.String() { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } // Unmarshal the request. reqBody, err := ioutil.ReadAll(r.Body) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } var args driver.Request if err := util.UnmarshalRequest(r, reqBody, &args, allowedEncodings); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } // Check request user against client certificate user. if err := authenticationHook(&args); err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } // Send the Request for SQL execution and set the application-level error // on the reply. reply, err := s.exec(args) if err != nil { errProto := proto.Error{} errProto.SetResponseGoError(err) reply.Error = &errProto } // Marshal the response. body, contentType, err := util.MarshalResponse(r, &reply, allowedEncodings) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } w.Header().Set(util.ContentTypeHeader, contentType) w.Write(body) }
// ServeHTTP implements an http.Handler that answers RPC requests. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.URL.Path != rpc.DefaultRPCPath { if s.handler != nil { s.handler.ServeHTTP(w, r) return } http.NotFound(w, r) return } // Note: this code was adapted from net/rpc.Server.ServeHTTP. if r.Method != "CONNECT" { http.Error(w, "405 must CONNECT", http.StatusMethodNotAllowed) return } // Construct an authentication hook for this security mode and TLS state. authHook, err := security.AuthenticationHook(s.context.Insecure, r.TLS) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } conn, _, err := w.(http.Hijacker).Hijack() if err != nil { log.Infof("rpc hijacking %s: %s", r.RemoteAddr, err) return } if log.V(3) { security.LogTLSState("RPC", r.TLS) } if _, err := io.WriteString(conn, "HTTP/1.0 "+codec.Connected+"\n\n"); err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } codec := codec.NewServerCodec(conn) responses := make(chan serverResponse) var wg sync.WaitGroup wg.Add(1) go func() { s.sendResponses(codec, responses) wg.Done() }() s.readRequests(codec, authHook, responses) wg.Wait() codec.Close() s.mu.Lock() if s.closeCallbacks != nil { for _, cb := range s.closeCallbacks { cb(conn) } } s.mu.Unlock() conn.Close() }
// ServeHTTP serves the SQL API by treating the request URL path // as the method, the request body as the arguments, and sets the // response body as the method reply. The request body is unmarshalled // into arguments based on the Content-Type request header. Protobuf // and JSON-encoded requests are supported. The response body is // encoded according to the request's Accept header, or if not // present, in the same format as the request's incoming Content-Type // header. func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() method := r.URL.Path if !strings.HasPrefix(method, driver.Endpoint) { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } // Check TLS settings. authenticationHook, err := security.AuthenticationHook(s.context.Insecure, r.TLS) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } method = strings.TrimPrefix(method, driver.Endpoint) if method != driver.Execute.String() { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } // Unmarshal the request. reqBody, err := ioutil.ReadAll(r.Body) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } var args driver.Request if err := util.UnmarshalRequest(r, reqBody, &args, allowedEncodings); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } // Check request user against client certificate user. if err := authenticationHook(&args, true /*public*/); err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } reply, code, err := s.Execute(args) if err != nil { http.Error(w, err.Error(), code) } // Marshal the response. body, contentType, err := util.MarshalResponse(r, &reply, allowedEncodings) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } w.Header().Set(util.ContentTypeHeader, contentType) if _, err := w.Write(body); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } }
// ServeHTTP serves the key-value API by treating the request URL path // as the method, the request body as the arguments, and sets the // response body as the method reply. The request body is unmarshalled // into arguments based on the Content-Type request header. Protobuf // and JSON-encoded requests are supported. The response body is // encoded according the the request's Accept header, or if not // present, in the same format as the request's incoming Content-Type // header. func (s *DBServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Check TLS settings before anything else. authenticationHook, err := security.AuthenticationHook(s.context.Insecure, r.TLS) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } method := r.URL.Path if !strings.HasPrefix(method, DBPrefix) { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } method = strings.TrimPrefix(method, DBPrefix) args, reply := createArgsAndReply(method) if args == nil { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } // Unmarshal the request. reqBody, err := ioutil.ReadAll(r.Body) defer r.Body.Close() if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } if err := util.UnmarshalRequest(r, reqBody, args, allowedEncodings); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } // Verify the request for public API. if err := verifyRequest(args); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } // Check request user against client certificate user. if err := authenticationHook(args); err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } // Create a call and invoke through sender. s.sender.Send(context.TODO(), proto.Call{Args: args, Reply: reply}) // Marshal the response. body, contentType, err := util.MarshalResponse(r, reply, allowedEncodings) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } w.Header().Set(util.ContentTypeHeader, contentType) w.Write(body) }
// TestHTTPSenderSend verifies sending posts. func TestHTTPSenderSend(t *testing.T) { defer leaktest.AfterTest(t) server, addr := startTestHTTPServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Make sure SSL certs were properly specified. authenticationHook, err := security.AuthenticationHook(false /* !insecure */, r.TLS) if err != nil { t.Error(err) } if r.Method != "POST" { t.Errorf("expected method POST; got %s", r.Method) } if r.URL.Path != KVDBEndpoint+"Put" { t.Errorf("expected url %s; got %s", KVDBEndpoint+"Put", r.URL.Path) } // Unmarshal the request. reqBody, err := ioutil.ReadAll(r.Body) if err != nil { t.Errorf("unexpected error reading body: %s", err) } args := &proto.PutRequest{} if err := util.UnmarshalRequest(r, reqBody, args, util.AllEncodings); err != nil { t.Errorf("unexpected error unmarshalling request: %s", err) } // Validate request against incoming user. if err := authenticationHook(args, false /*not public*/); err != nil { t.Error(err) } if !args.Key.Equal(testPutReq.Key) || !args.Timestamp.Equal(testPutReq.Timestamp) { t.Errorf("expected parsed %+v to equal %+v", args, testPutReq) } body, contentType, err := util.MarshalResponse(r, testPutResp, util.AllEncodings) if err != nil { t.Errorf("failed to marshal response: %s", err) } w.Header().Set(util.ContentTypeHeader, contentType) w.Write(body) })) defer server.Close() sender, err := newHTTPSender(addr, testutils.NewNodeTestBaseContext(), defaultRetryOptions) if err != nil { t.Fatal(err) } reply := &proto.PutResponse{} sender.Send(context.Background(), proto.Call{Args: testPutReq, Reply: reply}) if reply.GoError() != nil { t.Errorf("expected success; got %s", reply.GoError()) } if !reply.Timestamp.Equal(testPutResp.Timestamp) { t.Errorf("expected received %+v to equal %+v", reply, testPutResp) } }
func TestAuthenticationHook(t *testing.T) { defer leaktest.AfterTest(t) // Proto that does not implement GetUser. badRequest := &cockroach_proto.GetResponse{} getRequest := &cockroach_proto.GetRequest{} testCases := []struct { insecure bool tls *tls.ConnectionState request proto.Message buildHookSuccess bool publicHookSuccess bool privateHookSuccess bool }{ // Insecure mode, nil request. {true, nil, nil, true, false, false}, // Insecure mode, bad request. {true, nil, badRequest, true, false, false}, // Insecure mode, good request. {true, nil, getRequest, true, true, true}, // Secure mode, no TLS state. {false, nil, nil, false, false, false}, // Secure mode, bad user. {false, makeFakeTLSState([]string{"foo"}, []int{1}), getRequest, true, false, false}, // Secure mode, node user. {false, makeFakeTLSState([]string{security.NodeUser}, []int{1}), getRequest, true, true, true}, // Secure mode, root user. {false, makeFakeTLSState([]string{security.RootUser}, []int{1}), getRequest, true, false, false}, } for tcNum, tc := range testCases { hook, err := security.AuthenticationHook(tc.insecure, tc.tls) if (err == nil) != tc.buildHookSuccess { t.Fatalf("#%d: expected success=%t, got err=%v", tcNum, tc.buildHookSuccess, err) } if err != nil { continue } err = hook(tc.request, true /*public*/) if (err == nil) != tc.publicHookSuccess { t.Fatalf("#%d: expected success=%t, got err=%v", tcNum, tc.publicHookSuccess, err) } err = hook(tc.request, false /*not public*/) if (err == nil) != tc.privateHookSuccess { t.Fatalf("#%d: expected success=%t, got err=%v", tcNum, tc.privateHookSuccess, err) } } }
// ServeHTTP serves the SQL API by treating the request URL path // as the method, the request body as the arguments, and sets the // response body as the method reply. The request body is unmarshalled // into arguments based on the Content-Type request header. Protobuf // and JSON-encoded requests are supported. The response body is // encoded according to the request's Accept header, or if not // present, in the same format as the request's incoming Content-Type // header. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() method := r.URL.Path if !strings.HasPrefix(method, driver.Endpoint) { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } // Check TLS settings. authenticationHook, err := security.AuthenticationHook(s.context.Insecure, r.TLS) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } method = strings.TrimPrefix(method, driver.Endpoint) if method != driver.Execute.String() { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } // Unmarshal the request. reqBody, err := ioutil.ReadAll(r.Body) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } var args driver.Request if err := util.UnmarshalRequest(r, reqBody, &args, allowedEncodings); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } // Check request user against client certificate user. if err := authenticationHook(&args, true /*public*/); err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } // Pick up current session state. planMaker := planner{user: args.GetUser()} if err := gogoproto.Unmarshal(args.Session, &planMaker.session); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } // Open a pending transaction if needed. if planMaker.session.Txn != nil { txn := client.NewTxn(*s.db) txn.Proto = *planMaker.session.Txn planMaker.txn = txn } // Send the Request for SQL execution and set the application-level error // for each result in the reply. reply := s.exec(args, &planMaker) // Send back the session state even if there were application-level errors. // Add transaction to session state. if planMaker.txn != nil { planMaker.session.Txn = &planMaker.txn.Proto } else { planMaker.session.Txn = nil } bytes, err := gogoproto.Marshal(&planMaker.session) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } reply.Session = bytes // Marshal the response. body, contentType, err := util.MarshalResponse(r, &reply, allowedEncodings) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } w.Header().Set(util.ContentTypeHeader, contentType) w.Write(body) }
func TestAuthenticationHook(t *testing.T) { defer leaktest.AfterTest(t) // Proto that does not implement GetUser. badRequest := &cockroach_proto.GetResponse{} testCases := []struct { insecure bool tls *tls.ConnectionState request proto.Message buildHookSuccess bool publicHookSuccess bool privateHookSuccess bool }{ // Insecure mode, nil request. {true, nil, nil, true, false, false}, // Insecure mode, bad request. {true, nil, badRequest, true, false, false}, // Insecure mode, userRequest with empty user. {true, nil, makeUserRequest(""), true, false, false}, // Insecure mode, userRequest with good user. {true, nil, makeUserRequest("foo"), true, true, false}, // Insecure mode, userRequest with root user. {true, nil, makeUserRequest(security.RootUser), true, true, true}, // Insecure mode, userRequest with node user. {true, nil, makeUserRequest(security.NodeUser), true, true, true}, // Secure mode, no TLS state. {false, nil, nil, false, false, false}, // Secure mode, user mismatch. {false, makeFakeTLSState([]string{"foo"}, []int{1}), makeUserRequest("bar"), true, false, false}, // Secure mode, user mismatch, but client certificate is for the node user. {false, makeFakeTLSState([]string{security.NodeUser}, []int{1}), makeUserRequest("bar"), true, true, false}, // Secure mode, user mismatch, and the root user does not get blind permissions. {false, makeFakeTLSState([]string{security.RootUser}, []int{1}), makeUserRequest("bar"), true, false, false}, // Secure mode, matching users. {false, makeFakeTLSState([]string{"foo"}, []int{1}), makeUserRequest("foo"), true, true, false}, // Secure mode, root acting as itself. {false, makeFakeTLSState([]string{security.RootUser}, []int{1}), makeUserRequest(security.RootUser), true, true, true}, // Secure mode, node acting as itself. {false, makeFakeTLSState([]string{security.NodeUser}, []int{1}), makeUserRequest(security.NodeUser), true, true, true}, // Secure mode, node acting as root. {false, makeFakeTLSState([]string{security.NodeUser}, []int{1}), makeUserRequest(security.RootUser), true, true, true}, } for tcNum, tc := range testCases { hook, err := security.AuthenticationHook(tc.insecure, tc.tls) if (err == nil) != tc.buildHookSuccess { t.Fatalf("#%d: expected success=%t, got err=%v", tcNum, tc.buildHookSuccess, err) } if err != nil { continue } err = hook(tc.request, true /*public*/) if (err == nil) != tc.publicHookSuccess { t.Fatalf("#%d: expected success=%t, got err=%v", tcNum, tc.publicHookSuccess, err) } err = hook(tc.request, false /*not public*/) if (err == nil) != tc.privateHookSuccess { t.Fatalf("#%d: expected success=%t, got err=%v", tcNum, tc.privateHookSuccess, err) } } }