예제 #1
0
// handleRequest forwards req to a provider registered for req.Task.
//
// The request is first passed through s.proxy.ProxyUnix. The resulting proxied
// request is then offered to each provider socket returned by s.getProviders,
// in order, until one accepts it.
//
// An error is returned in four cases:
//   - the providers cannot be looked up;
//   - no providers are registered for the task;
//   - proxying fails;
//   - every provider fails.
//
// In the last case the error from the final attempt (a send or URL-parse
// failure) is returned.
func (s *Server) handleRequest(req *acomm.Request) error {
	providerSockets, err := s.getProviders(req.Task)
	if err != nil {
		return err
	}

	if len(providerSockets) == 0 {
		return errors.New("no providers available for task")
	}

	proxyReq, err := s.proxy.ProxyUnix(req, 0)
	if err != nil {
		return err
	}

	// Try each provider in turn until one accepts the request. err always holds
	// the most recent failure so it can be reported if every provider fails.
	for _, providerSocket := range providerSockets {
		addr, parseErr := url.ParseRequestURI(fmt.Sprintf("unix://%s", providerSocket))
		if parseErr != nil {
			// Never hand a nil URL to acomm.Send; skip this provider instead.
			err = parseErr
			continue
		}
		if err = acomm.Send(addr, proxyReq); err == nil {
			return nil
		}
	}

	return err
}
예제 #2
0
// TestStartHandleStop starts the server with a no-op "foobar" task handler and
// sends a tracked request straight to that task's socket. It then waits, with
// a bound, for the tracker to invoke the request's response handler.
func (s *ServerSuite) TestStartHandleStop() {
	// Start
	taskHandler := func(a *acomm.Request) (interface{}, *url.URL, error) {
		return nil, nil, nil
	}
	s.server.RegisterTask("foobar", taskHandler)

	if !s.NoError(s.server.Start(), "failed to start server") {
		return
	}
	// Stop the server however the test exits.
	defer s.server.Stop()
	// Give the server's listeners a moment to come up.
	time.Sleep(time.Second)

	// Handle request
	tracker := s.server.Tracker()
	handled := make(chan struct{})
	respHandler := func(req *acomm.Request, resp *acomm.Response) {
		close(handled)
	}
	req, err := acomm.NewRequest("foobar", tracker.URL().String(), struct{}{}, respHandler, respHandler)
	if !s.NoError(err, "failed to create request") {
		return
	}
	providerSocket, err := url.ParseRequestURI("unix://" + s.server.TaskSocketPath("foobar"))
	if !s.NoError(err, "failed to parse task socket url") {
		return
	}
	if !s.NoError(tracker.TrackRequest(req, 5*time.Second)) {
		return
	}
	if !s.NoError(acomm.Send(providerSocket, req)) {
		return
	}

	// Bound the wait so a lost response fails the test instead of hanging it.
	select {
	case <-handled:
	case <-time.After(10 * time.Second):
		s.Fail("timed out waiting for response handler")
	}
}
예제 #3
0
func makeRequest(coordinator, taskName, responseAddr string, taskArgs map[string]interface{}) error {
	coordinatorURL, err := url.ParseRequestURI(coordinator)
	if err != nil {
		return errors.New("invalid coordinator url")
	}

	responseHook := fmt.Sprintf("http://%s/", responseAddr)
	req, err := acomm.NewRequest(taskName, responseHook, taskArgs, nil, nil)
	if err != nil {
		return err
	}

	return acomm.Send(coordinatorURL, req)
}
예제 #4
0
// TestProxyUnix checks three behaviours of Tracker.ProxyUnix:
//   - It fails while the tracker is not listening.
//   - Once listening, it wraps s.Request in a tracked request with a "unix"
//     response hook. A response sent to that hook is forwarded to the
//     original hook, and the request is then no longer tracked.
//   - A request that already has a unix response hook is returned unchanged
//     and is not tracked.
func (s *TrackerTestSuite) TestProxyUnix() {
	unixReq, err := s.Tracker.ProxyUnix(s.Request, 0)
	s.Error(err, "should fail to proxy when tracker is not listening")
	s.Nil(unixReq, "should not return a request")

	if !s.NoError(s.Tracker.Start(), "listener should start") {
		return
	}

	unixReq, err = s.Tracker.ProxyUnix(s.Request, 0)
	s.NoError(err, "should not fail proxying when tracker is listening")
	// Stop here rather than panic on the field accesses below.
	if !s.NotNil(unixReq, "should return a request") {
		return
	}
	s.Equal(s.Request.ID, unixReq.ID, "new request should share ID with original")
	s.Equal("unix", unixReq.ResponseHook.Scheme, "new request should have a unix response hook")
	s.Equal(1, s.Tracker.NumRequests(), "should have tracked the new request")
	resp, err := acomm.NewResponse(unixReq, struct{}{}, nil, nil)
	if !s.NoError(err, "new response should not error") {
		return
	}
	if !s.NoError(acomm.Send(unixReq.ResponseHook, resp), "response send should not error") {
		return
	}

	lastResp := s.NextResp()
	if !s.NotNil(lastResp, "response should have been proxied to original http response hook") {
		return
	}
	s.Equal(resp.ID, lastResp.ID, "response should have been proxied to original http response hook")
	s.Equal(0, s.Tracker.NumRequests(), "should have removed the request from tracking")

	// Should not proxy a request already using unix response hook
	origUnixReq, err := acomm.NewRequest("foobar", "unix://foo", struct{}{}, nil, nil)
	if !s.NoError(err, "new request should not error") {
		return
	}
	unixReq, err = s.Tracker.ProxyUnix(origUnixReq, 0)
	s.NoError(err, "should not error with unix response hook")
	s.Equal(origUnixReq, unixReq, "should not proxy unix response hook")
	s.Equal(0, s.Tracker.NumRequests(), "should not track an unproxied request")
}
예제 #5
0
// TestSend exercises acomm.Send against a mock HTTP server and a mock unix
// socket listener. Valid hooks must receive exactly the response that was
// sent. Unreachable or unsupported hooks must make Send fail and deliver
// nothing.
func (s *ResponseTestSuite) TestSend() {
	// Mock HTTP response server
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := &acomm.Response{}
		body, err := ioutil.ReadAll(r.Body)
		s.NoError(err, "should not fail reading body")
		s.NoError(json.Unmarshal(body, resp), "should not fail unmarshalling response")
		ack, _ := json.Marshal(&acomm.Response{})
		_, _ = w.Write(ack)
		s.Responses <- resp
	}))
	defer ts.Close()

	// Mock Unix response listener
	f, err := ioutil.TempFile("", "acommTest-")
	if !s.NoError(err, "failed to create test unix socket") {
		return
	}
	_ = f.Close()
	_ = os.Remove(f.Name())
	socketPath := fmt.Sprintf("%s.sock", f.Name())
	listener, err := net.Listen("unix", socketPath)
	if !s.NoError(err, "failed to listen on unix socket") {
		return
	}
	defer func() { _ = listener.Close() }()
	go func() {
		for {
			conn, err := listener.Accept()
			if err != nil {
				return
			}
			resp := &acomm.Response{}
			s.NoError(acomm.UnmarshalConnData(conn, resp), "should not fail unmarshalling conn data")
			_ = acomm.SendConnData(conn, &acomm.Response{})
			s.Responses <- resp
			_ = conn.Close()
		}
	}()

	resultJ, _ := json.Marshal(map[string]string{"foo": "bar"})
	response := &acomm.Response{
		ID:     uuid.New(),
		Result: (*json.RawMessage)(&resultJ),
	}

	tests := []struct {
		responseHook string
		expectedErr  bool
	}{
		{ts.URL, false},
		{"http://badpath", true},
		{fmt.Sprintf("unix://%s", socketPath), false},
		{fmt.Sprintf("unix://%s", "badpath"), true},
		{"foobar://", true},
	}

	for _, test := range tests {
		msg := testMsgFunc(test.responseHook)
		u, _ := url.ParseRequestURI(test.responseHook)
		err := acomm.Send(u, response)
		resp := s.NextResp()
		if test.expectedErr {
			s.Error(err, msg("send should fail"))
			s.Nil(resp, msg("response hook should not receive a response"))
			continue
		}
		if !s.NoError(err, msg("send should not fail")) {
			continue
		}
		// NextResp yields nil when nothing arrived; guard before dereferencing.
		if !s.NotNil(resp, msg("response hook should receive a response")) {
			continue
		}
		s.Equal(response.ID, resp.ID, msg("response should be what was sent"))
	}
}
예제 #6
0
// SystemStatus is a task handler that collects system information for a guest.
// It issues "CPUInfo" and "DiskInfo" sub-requests to the coordinator, collects
// their responses through an acomm MultiRequest, and fills a
// SystemStatusResult with whatever results come back.
//
// An error is returned only if the arguments cannot be unmarshalled, guest_id
// is missing, or a sub-request cannot be constructed. These failures are
// logged and skipped rather than failing the task:
//   - a sub-request that cannot be added to the MultiRequest or sent;
//   - a result that cannot be unmarshalled.
func (s *Simple) SystemStatus(req *acomm.Request) (interface{}, *url.URL, error) {
	var args SystemStatusArgs
	if err := req.UnmarshalArgs(&args); err != nil {
		return nil, nil, err
	}
	if args.GuestID == "" {
		return nil, nil, errors.New("missing guest_id")
	}

	// Prepare multiple requests
	multiRequest := acomm.NewMultiRequest(s.tracker, 0)

	cpuReq, err := acomm.NewRequest("CPUInfo", s.tracker.URL().String(), &CPUInfoArgs{GuestID: args.GuestID}, nil, nil)
	if err != nil {
		return nil, nil, err
	}
	diskReq, err := acomm.NewRequest("DiskInfo", s.tracker.URL().String(), &DiskInfoArgs{GuestID: args.GuestID}, nil, nil)
	if err != nil {
		return nil, nil, err
	}

	subRequests := map[string]*acomm.Request{
		"CPUInfo":  cpuReq,
		"DiskInfo": diskReq,
	}

	// The loop variable is named subReq so it does not shadow the handler's req.
	for name, subReq := range subRequests {
		if err := multiRequest.AddRequest(name, subReq); err != nil {
			log.WithFields(log.Fields{
				"name":  name,
				"error": err,
			}).Error("failed to add request")
			continue
		}
		if err := acomm.Send(s.config.CoordinatorURL(), subReq); err != nil {
			log.WithFields(log.Fields{
				"name":  name,
				"error": err,
			}).Error("failed to send request")
			// Stop waiting on a request that was never delivered.
			multiRequest.RemoveRequest(subReq)
		}
	}

	// Wait for the results
	responses := multiRequest.Responses()
	result := &SystemStatusResult{}

	// Where each named sub-result is decoded. A missing response leaves the
	// corresponding field at its zero value.
	targets := map[string]interface{}{
		"CPUInfo":  &result.CPUs,
		"DiskInfo": &result.Disks,
	}
	for name, target := range targets {
		resp, ok := responses[name]
		if !ok {
			continue
		}
		if err := resp.UnmarshalResult(target); err != nil {
			log.WithFields(log.Fields{
				"name":  name,
				"resp":  resp,
				"error": err,
			}).Error("failed to unmarshal result")
		}
	}

	return result, nil, nil
}
예제 #7
0
// TestReqRespHandle runs request/response round trips through the coordinator.
// It uses both the external HTTP endpoint and the internal unix socket, for a
// registered task and for an unknown one. Every wait for a response is bounded
// so a lost response fails that case instead of hanging the suite.
func (s *ServerSuite) TestReqRespHandle() {
	// Start
	if !s.NoError(s.server.Start(), "failed to start server") {
		return
	}
	// Stop the server however the test exits.
	defer s.server.Stop()
	// Give the server's listeners a moment to come up.
	time.Sleep(time.Second)

	// Set up handlers
	result := make(chan *params, 10)

	// Task handler
	taskName := "foobar"
	taskListener := s.createTaskListener(taskName, result)
	if taskListener == nil {
		return
	}
	defer taskListener.Stop(0)

	// Response handlers
	responseServer, responseListener := s.createResponseHandlers(result)
	if responseServer != nil {
		defer responseServer.Close()
	}
	if responseListener != nil {
		defer responseListener.Stop(0)
	}
	if responseServer == nil || responseListener == nil {
		return
	}

	// Coordinator URLs
	internalURL, err := url.ParseRequestURI("unix://" + filepath.Join(
		s.config.SocketDir(),
		"coordinator",
		s.config.ServiceName()+".sock"),
	)
	if !s.NoError(err, "failed to parse internal coordinator url") {
		return
	}
	externalURL, err := url.ParseRequestURI(fmt.Sprintf(
		"http://localhost:%v",
		s.configData.ExternalPort),
	)
	if !s.NoError(err, "failed to parse external coordinator url") {
		return
	}

	// Test cases
	tests := []struct {
		description  string
		taskName     string
		internal     bool
		params       *params
		expectFailed bool
	}{
		{"valid http", taskName, false, &params{uuid.New()}, false},
		{"valid unix", taskName, true, &params{uuid.New()}, false},
		{"bad task http", "asdf", false, &params{uuid.New()}, true},
		{"bad task unix", "asdf", true, &params{uuid.New()}, true},
	}

	for _, test := range tests {
		msg := testMsgFunc(test.description)
		hookURL := responseServer.URL
		coordinatorURL := externalURL
		if test.internal {
			hookURL = responseListener.URL().String()
			coordinatorURL = internalURL
		}

		req, err := acomm.NewRequest(test.taskName, hookURL, test.params, nil, nil)
		if !s.NoError(err, msg("should create request")) {
			continue
		}
		// A failed send produces no response, so record the failure directly.
		if err = acomm.Send(coordinatorURL, req); err != nil {
			result <- nil
		}

		select {
		case respData := <-result:
			if test.expectFailed {
				s.Nil(respData, msg("should have failed"))
			} else {
				s.Equal(test.params, respData, msg("should have gotten the correct response data"))
			}
		case <-time.After(10 * time.Second):
			s.Fail(msg("timed out waiting for response"))
		}

		drainChan(result)
	}
}