Exemplo n.º 1
0
// Receive reads responses from the pipe and calls f once per response with the
// response's id, until the stream ends (io.EOF from the decoder, which yields a
// nil return) or f returns false (also a nil return).
//
// For an ordinary result response, the response's result is JSON-decoded into
// result before f is called, with err == nil and success == false. If the
// response carries an error, result is not touched and err is that error. If
// the response marks the successful end of a streaming call, result is not
// touched and success is true.
//
// A decode failure or an invalid response id is returned as-is. A failure to
// unmarshal a result into result is returned wrapped (%w), so callers can
// inspect it with errors.Is/errors.As.
func (r *PipeReader) Receive(result interface{}, f func(id []byte, err error, success bool) bool) error {
	for {
		var resp protocol.Response
		if err := r.dec.Decode(&resp); err == io.EOF {
			return nil
		} else if err != nil {
			return err
		}
		if err := protocol.CheckID(resp.ID); err != nil {
			return err
		}

		// Only plain result responses carry a payload to decode.
		if resp.Error == nil && !resp.Success {
			if err := json.Unmarshal(resp.Result, result); err != nil {
				return fmt.Errorf("error unmarshalling result: %w", err)
			}
		}

		// Copy resp.Error into the error interface only when it is non-nil.
		// Assigning a nil pointer to an interface would produce a non-nil
		// error value, and f would wrongly see err != nil.
		var respErr error
		if resp.Error != nil {
			respErr = resp.Error
		}

		if !f(resp.ID, respErr, resp.Success) {
			return nil
		}
	}
}
Exemplo n.º 2
0
// handleRequest decodes one request from de, dispatches it, and writes the
// response(s) to en.
//
// Return values:
//   - io.EOF, unwrapped, when the decoder has no more input.
//   - The JSON decode error, after attempting to send a parse-error response.
//   - Otherwise, the error (if any) from writing the response.
//
// Protocol-level problems are reported to the client as error responses, not
// as Go errors. These include an invalid id or version, an unknown method, and
// an error returned by the method.
//
// A panic raised while handling the request is recovered in the deferred
// function, including the internal invariant checks below. It is reported to
// the client: a *protocol.Error is sent as-is; any other value is sent as an
// internal error. err is then the result of writing that response.
func (e Endpoint) handleRequest(ctx Context, de *json.Decoder, en *json.Encoder) (err error) {
	var req protocol.Request

	// wr is given a pointer to req, so the fields decoded and adjusted below
	// (e.g. the Version downgrades) are presumably what it uses when writing
	// responses.
	wr := newResponseWriter(&req, en)
	defer func() {
		// On a panic, write an error response instead of propagating the panic.
		if p := recover(); p != nil {
			switch p := p.(type) {
			case *protocol.Error:
				err = wr.pError(p)
			case error:
				err = wr.Error(protocol.ErrorInternal, p.Error())
			default:
				err = wr.Error(protocol.ErrorInternal, fmt.Sprintf("%+v", p))
			}
		}
	}()

	// Decode request
	if err := de.Decode(&req); err == io.EOF {
		return err
	} else if err != nil {
		if werr := wr.Error(protocol.ErrorParsing, fmt.Sprintf("error decoding JSON: %v", err)); werr != nil {
			return fmt.Errorf("failed to write error response for %v: %v", err, werr)
		}
		return err
	}

	// Validate request
	if err := protocol.CheckID(req.ID); err != nil {
		return wr.Error(protocol.ErrorInvalidRequest, err.Error())
	} else if req.Version != protocol.Version2 && req.Version != protocol.Version2Streaming {
		return wr.Error(protocol.ErrorInvalidRequest, fmt.Sprintf("invalid protocol version: %q", req.Version))
	}

	serviceName, methodName, err := parseServiceMethod(req.Method)
	if err != nil {
		return wr.Error(protocol.ErrorMethodNotFound,
			fmt.Sprintf("Unknown method name format: %q", req.Method))
	}

	// Handle builtin /ServerInfo service
	if serviceName == serverInfoService {
		req.Version = protocol.Version2 // downgrade version for ServerInfo methods
		if methodName != listMethod {
			return wr.Error(protocol.ErrorMethodNotFound,
				fmt.Sprintf("%q method not found in builtin service %q", methodName, serverInfoService))
		}

		spec, err := e.serviceList()
		if err != nil {
			panic(err) // shouldn't happen; reported as an internal server error
		}

		return wr.Result(spec)
	}

	// Resolve method
	method, err := e.Resolve(serviceName, methodName)
	if err != nil {
		return wr.Error(protocol.ErrorMethodNotFound,
			fmt.Sprintf("Method not found: /%s/%s", serviceName, methodName))
	}

	if !method.Stream {
		// Ensure downgraded version if method returns a single result
		req.Version = protocol.Version2
	}

	// Construct method output handler
	var (
		results [][]byte // results collected for a protocol.Version2 reply
		out     func(result []byte)
		outErr  error // first error from streaming a result to the client
	)
	switch req.Version {
	case protocol.Version2:
		out = func(result []byte) {
			results = append(results, result)
		}
	case protocol.Version2Streaming:
		out = func(result []byte) {
			// After the first write failure, drop remaining results.
			if outErr != nil {
				return
			}
			outErr = wr.Result(result)
		}
	default:
		panic(fmt.Errorf("Unhandled version outputs: %q", req.Version))
	}

	// Invoke method with params
	err = method.Invoke(ctx, req.Params, out)
	if outErr != nil {
		// Writing a streamed result failed. Hand it to the deferred recover,
		// which attempts to report it to the client as an internal error.
		panic(outErr)
	}
	if err != nil {
		switch err := err.(type) {
		case *protocol.Error:
			return wr.pError(err)
		default:
			return wr.Error(protocol.ErrorApplication, err.Error())
		}
	}

	switch req.Version {
	case protocol.Version2:
		if method.Stream {
			// The client requested V2 but the method streams, so collect the
			// results into a single JSON array.
			arry := append(append([]byte("["), bytes.Join(results, []byte(","))...), ']')
			return wr.Result(arry)
		}
		if len(results) != 1 {
			// This shouldn't happen given how method registration works: a
			// non-streaming method MUST produce exactly one result.
			//
			// Log it server-side, then panic so the deferred recover sends the
			// client an internal error. The original used log.Fatalf, which
			// calls os.Exit: it skipped the recover and killed the whole process.
			log.Printf("Found non-singleton result for non-stream method: %q", results)
			panic(fmt.Errorf("non-stream method /%s/%s produced %d results, expected exactly 1",
				serviceName, methodName, len(results)))
		}
		// The client requested V2 and the method is non-streaming: return the
		// single result as-is.
		return wr.Result(results[0])
	case protocol.Version2Streaming:
		// We already wrote the results, just end the stream.
		return wr.Success()
	default:
		// This case should never happen if we've correctly handled all protocol
		// versions.
		panic(fmt.Errorf("Unexpected protocol version: %q", req.Version))
	}
}