func (p *pipeTransport) SendMessage(ctx context.Context, msg rpccapnp.Message) error { if !p.startSend() { return errClosed } defer p.finishSend() buf, err := msg.Segment().Message().Marshal() if err != nil { return err } mm, err := capnp.Unmarshal(buf) if err != nil { return err } msg, err = rpccapnp.ReadRootMessage(mm) if err != nil { return err } select { case p.w <- msg: return nil case <-ctx.Done(): return ctx.Err() case <-p.finish: return errClosed case <-p.otherFin: return errBrokenPipe } }
// copyRPCMessage clones an RPC packet. func copyRPCMessage(m rpccapnp.Message) rpccapnp.Message { mm := copyMessage(m.Segment().Message()) rpcMsg, err := rpccapnp.ReadRootMessage(mm) if err != nil { panic(err) } return rpcMsg }
// handleCallMessage is run in the coordinate goroutine to handle a // received call message. It mutates the capability table of its // parameter. func (c *Conn) handleCallMessage(m rpccapnp.Message) error { mcall, err := m.Call() if err != nil { return err } mt, err := mcall.Target() if err != nil { return err } if mt.Which() != rpccapnp.MessageTarget_Which_importedCap && mt.Which() != rpccapnp.MessageTarget_Which_promisedAnswer { um := newUnimplementedMessage(nil, m) return c.sendMessage(um) } mparams, err := mcall.Params() if err != nil { return err } if err := c.populateMessageCapTable(mparams); err == errUnimplemented { um := newUnimplementedMessage(nil, m) return c.sendMessage(um) } else if err != nil { c.abort(err) return err } ctx, cancel := c.newContext() id := answerID(mcall.QuestionId()) a := c.answers.insert(id, cancel) if a == nil { // Question ID reused, error out. c.abort(errQuestionReused) return errQuestionReused } meth := capnp.Method{ InterfaceID: mcall.InterfaceId(), MethodID: mcall.MethodId(), } paramContent, err := mparams.Content() if err != nil { return err } cl := &capnp.Call{ Ctx: ctx, Method: meth, Params: capnp.ToStruct(paramContent), } if err := c.routeCallMessage(a, mt, cl); err != nil { msgs := a.reject(nil, err) for _, m := range msgs { if err := c.sendMessage(m); err != nil { return err } } return nil } return nil }
func (c *Conn) handleDisembargoMessage(msg rpccapnp.Message) error { d, err := msg.Disembargo() if err != nil { return err } dtarget, err := d.Target() if err != nil { return err } switch d.Context().Which() { case rpccapnp.Disembargo_context_Which_senderLoopback: id := embargoID(d.Context().SenderLoopback()) if dtarget.Which() != rpccapnp.MessageTarget_Which_promisedAnswer { return errDisembargoNonImport } dpa, err := dtarget.PromisedAnswer() if err != nil { return err } aid := answerID(dpa.QuestionId()) a := c.answers.get(aid) if a == nil { return errDisembargoMissingAnswer } dtrans, err := dpa.Transform() if err != nil { return err } transform := promisedAnswerOpsToTransform(dtrans) queued, err := a.queueDisembargo(transform, id, dtarget) if err != nil { return err } if !queued { // There's nothing to embargo; everything's been delivered. resp := newDisembargoMessage(nil, rpccapnp.Disembargo_context_Which_receiverLoopback, id) rd, _ := resp.Disembargo() if err := rd.SetTarget(dtarget); err != nil { return err } c.sendMessage(resp) } case rpccapnp.Disembargo_context_Which_receiverLoopback: id := embargoID(d.Context().ReceiverLoopback()) c.embargoes.disembargo(id) default: um := newUnimplementedMessage(nil, msg) c.sendMessage(um) } return nil }
func (s *streamTransport) SendMessage(ctx context.Context, msg rpccapnp.Message) error { s.wbuf.Reset() if err := s.enc.Encode(msg.Segment().Message()); err != nil { return err } if s.deadline != nil { // TODO(light): log errors if d, ok := ctx.Deadline(); ok { s.deadline.SetWriteDeadline(d) } else { s.deadline.SetWriteDeadline(time.Time{}) } } _, err := s.rwc.Write(s.wbuf.Bytes()) return err }
func formatMsg(w io.Writer, m rpccapnp.Message) { switch m.Which() { case rpccapnp.Message_Which_unimplemented: fmt.Fprint(w, "unimplemented") case rpccapnp.Message_Which_abort: mabort, _ := m.Abort() reason, _ := mabort.Reason() fmt.Fprintf(w, "abort type=%v: %s", mabort.Type(), reason) case rpccapnp.Message_Which_bootstrap: mboot, _ := m.Bootstrap() fmt.Fprintf(w, "bootstrap id=%d", mboot.QuestionId()) case rpccapnp.Message_Which_call: c, _ := m.Call() fmt.Fprintf(w, "call id=%d target=<", c.QuestionId()) tgt, _ := c.Target() formatMessageTarget(w, tgt) fmt.Fprintf(w, "> @%#x/@%d", c.InterfaceId(), c.MethodId()) case rpccapnp.Message_Which_return: r, _ := m.Return() fmt.Fprintf(w, "return id=%d", r.AnswerId()) if r.ReleaseParamCaps() { fmt.Fprint(w, " releaseParamCaps") } switch r.Which() { case rpccapnp.Return_Which_results: case rpccapnp.Return_Which_exception: exc, _ := r.Exception() reason, _ := exc.Reason() fmt.Fprintf(w, ", exception type=%v: %s", exc.Type(), reason) case rpccapnp.Return_Which_canceled: fmt.Fprint(w, ", canceled") case rpccapnp.Return_Which_resultsSentElsewhere: fmt.Fprint(w, ", results sent elsewhere") case rpccapnp.Return_Which_takeFromOtherQuestion: fmt.Fprint(w, ", results sent elsewhere") case rpccapnp.Return_Which_acceptFromThirdParty: fmt.Fprint(w, ", accept from third party") default: fmt.Fprintf(w, ", UNKNOWN RESULT which=%v", r.Which()) } case rpccapnp.Message_Which_finish: fin, _ := m.Finish() fmt.Fprintf(w, "finish id=%d", fin.QuestionId()) if fin.ReleaseResultCaps() { fmt.Fprint(w, " releaseResultCaps") } case rpccapnp.Message_Which_resolve: r, _ := m.Resolve() fmt.Fprintf(w, "resolve id=%d ", r.PromiseId()) switch r.Which() { case rpccapnp.Resolve_Which_cap: fmt.Fprint(w, "capability=") c, _ := r.Cap() formatCapDescriptor(w, c) case rpccapnp.Resolve_Which_exception: exc, _ := r.Exception() reason, _ := exc.Reason() fmt.Fprintf(w, "exception type=%v: %s", exc.Type(), reason) default: fmt.Fprintf(w, "UNKNOWN RESOLUTION which=%v", r.Which()) } case rpccapnp.Message_Which_release: rel, _ := m.Release() fmt.Fprintf(w, "release id=%d by %d", rel.Id(), rel.ReferenceCount()) case rpccapnp.Message_Which_disembargo: de, _ := m.Disembargo() tgt, _ := de.Target() fmt.Fprint(w, "disembargo <") formatMessageTarget(w, tgt) fmt.Fprint(w, "> ") dc := de.Context() switch dc.Which() { case rpccapnp.Disembargo_context_Which_senderLoopback: fmt.Fprintf(w, "sender loopback id=%d", dc.SenderLoopback()) case rpccapnp.Disembargo_context_Which_receiverLoopback: fmt.Fprintf(w, "receiver loopback id=%d", dc.ReceiverLoopback()) case rpccapnp.Disembargo_context_Which_accept: fmt.Fprint(w, "accept") case rpccapnp.Disembargo_context_Which_provide: fmt.Fprintf(w, "provide id=%d", dc.Provide()) default: fmt.Fprintf(w, "UNKNOWN CONTEXT which=%v", dc.Which()) } case rpccapnp.Message_Which_obsoleteSave: fmt.Fprint(w, "save") case rpccapnp.Message_Which_obsoleteDelete: fmt.Fprint(w, "delete") case rpccapnp.Message_Which_provide: prov, _ := m.Provide() tgt, _ := prov.Target() fmt.Fprintf(w, "provide id=%d <", prov.QuestionId()) formatMessageTarget(w, tgt) fmt.Fprint(w, ">") case rpccapnp.Message_Which_accept: acc, _ := m.Accept() fmt.Fprintf(w, "accept id=%d", acc.QuestionId()) if acc.Embargo() { fmt.Fprint(w, " with embargo") } case rpccapnp.Message_Which_join: join, _ := m.Join() tgt, _ := join.Target() fmt.Fprintf(w, "join id=%d <", join.QuestionId()) formatMessageTarget(w, tgt) fmt.Fprint(w, ">") default: fmt.Fprintf(w, "UNKNOWN MESSAGE which=%v", m.Which()) } }
// handleReturnMessage is run in the coordinate goroutine. func (c *Conn) handleReturnMessage(m rpccapnp.Message) error { ret, err := m.Return() if err != nil { return err } id := questionID(ret.AnswerId()) q := c.questions.pop(id) if q == nil { return fmt.Errorf("received return for unknown question id=%d", id) } if ret.ReleaseParamCaps() { c.exports.releaseList(q.paramCaps) } if _, _, _, resolved := q.peek(); resolved { // If the question was already resolved, that means it was canceled, // in which case we already sent the finish message. return nil } releaseResultCaps := true switch ret.Which() { case rpccapnp.Return_Which_results: releaseResultCaps = false results, err := ret.Results() if err != nil { return err } if err := c.populateMessageCapTable(results); err == errUnimplemented { um := newUnimplementedMessage(nil, m) c.sendMessage(um) return errUnimplemented } else if err != nil { c.abort(err) return err } content, err := results.Content() if err != nil { return err } disembargoes := q.fulfill(content, c.embargoes.new) for _, d := range disembargoes { if err := c.sendMessage(d); err != nil { // shutdown return nil } } case rpccapnp.Return_Which_exception: exc, err := ret.Exception() if err != nil { return err } e := error(Exception{exc}) if q.method != nil { e = &capnp.MethodError{ Method: q.method, Err: e, } } else { e = bootstrapError{e} } q.reject(questionResolved, e) case rpccapnp.Return_Which_canceled: err := &questionError{ id: id, method: q.method, err: fmt.Errorf("receiver reported canceled"), } log.Println(err) q.reject(questionResolved, err) return nil default: um := newUnimplementedMessage(nil, m) c.sendMessage(um) return errUnimplemented } fin := newFinishMessage(nil, id, releaseResultCaps) c.sendMessage(fin) return nil }
// handleMessage is run in the coordinate goroutine. func (c *Conn) handleMessage(m rpccapnp.Message) { switch m.Which() { case rpccapnp.Message_Which_unimplemented: // no-op for now to avoid feedback loop case rpccapnp.Message_Which_abort: ma, err := m.Abort() if err != nil { log.Println("rpc: decode abort:", err) // Keep going, since we're trying to abort anyway. } a := Abort{ma} log.Print(a) c.manager.shutdown(a) case rpccapnp.Message_Which_return: if err := c.handleReturnMessage(m); err != nil { log.Println("rpc: handle return:", err) } case rpccapnp.Message_Which_finish: // TODO(light): what if answers never had this ID? // TODO(light): return if cancelled mfin, err := m.Finish() if err != nil { log.Println("rpc: decode finish:", err) return } id := answerID(mfin.QuestionId()) a := c.answers.pop(id) a.cancel() if mfin.ReleaseResultCaps() { c.exports.releaseList(a.resultCaps) } case rpccapnp.Message_Which_bootstrap: boot, err := m.Bootstrap() if err != nil { log.Println("rpc: decode bootstrap:", err) return } id := answerID(boot.QuestionId()) if err := c.handleBootstrapMessage(id); err != nil { log.Println("rpc: handle bootstrap:", err) } case rpccapnp.Message_Which_call: if err := c.handleCallMessage(m); err != nil { log.Println("rpc: handle call:", err) } case rpccapnp.Message_Which_release: rel, err := m.Release() if err != nil { log.Println("rpc: decode release:", err) return } id := exportID(rel.Id()) refs := int(rel.ReferenceCount()) c.exports.release(id, refs) case rpccapnp.Message_Which_disembargo: if err := c.handleDisembargoMessage(m); err != nil { // Any failure in a disembargo is a protocol violation. c.abort(err) } default: log.Printf("rpc: received unimplemented message, which = %v", m.Which()) um := newUnimplementedMessage(nil, m) c.sendMessage(um) } }