Example #1
1
// refreshRequestForwardingConnection (re)builds the client state used to
// forward requests to the node at clusterAddr. It holds
// requestForwardingConnectionLock for the duration of the call.
//
// An empty clusterAddr tears down any existing forwarding connection and
// returns nil. Otherwise the address is parsed and, depending on whether the
// VAULT_USE_GRPC_REQUEST_FORWARDING environment variable is set, either an
// HTTP/2 transport or a gRPC client connection is installed on c. Any parse,
// TLS-configuration, or dial error is logged and returned.
func (c *Core) refreshRequestForwardingConnection(clusterAddr string) error {
	c.requestForwardingConnectionLock.Lock()
	defer c.requestForwardingConnectionLock.Unlock()

	// No address means forwarding is (potentially) disabled. If a connection
	// is still around, clean it up; if there is none, there is nothing to do.
	if clusterAddr == "" {
		if c.requestForwardingConnection != nil {
			c.clearForwardingClients()
		}
		return nil
	}

	// NOTE: There is deliberately no fast path for "same address as before":
	// the cert/key could have changed even if the active node ended up being
	// the same node. Leader() hashes the advertised info before calling in
	// here, so this function is not hit unnecessarily anyway.

	parsedAddr, err := url.Parse(clusterAddr)
	if err != nil {
		c.logger.Error("core/refreshRequestForwardingConnection: error parsing cluster address", "error", err)
		return err
	}

	if os.Getenv("VAULT_USE_GRPC_REQUEST_FORWARDING") == "" {
		// Normal HTTP forwarding: an HTTP/2 transport using the cluster TLS
		// configuration.
		tlsConfig, err := c.ClusterTLSConfig()
		if err != nil {
			c.logger.Error("core/refreshRequestForwardingConnection: error fetching cluster tls configuration", "error", err)
			return err
		}
		c.requestForwardingConnection = &activeConnection{
			transport: &http2.Transport{
				TLSClientConfig: tlsConfig,
			},
			clusterAddr: clusterAddr,
		}
		return nil
	}

	// gRPC forwarding. The dial is not really insecure: we dial manually to
	// get the ALPN header right, and it is only "insecure" in the sense that
	// gRPC itself is not managing the TLS state.
	dialCtx, cancelFunc := context.WithCancel(context.Background())
	// Keep the cancel func on c so the dial context can be cancelled later.
	c.rpcClientConnCancelFunc = cancelFunc
	c.rpcClientConn, err = grpc.DialContext(dialCtx, parsedAddr.Host, grpc.WithDialer(c.getGRPCDialer()), grpc.WithInsecure())
	if err != nil {
		c.logger.Error("core/refreshRequestForwardingConnection: err setting up rpc client", "error", err)
		return err
	}
	c.rpcForwardingClient = NewRequestForwardingClient(c.rpcClientConn)
	return nil
}
Example #2
0
// TestDropRequestFailedNonFailFast starts one backend and a remote load
// balancer that advertises that backend with DropRequest set to true. It then
// issues a non-fail-fast RPC with a 10ms deadline and expects it to end with
// codes.DeadlineExceeded.
func TestDropRequestFailedNonFailFast(t *testing.T) {
	// Start a backend.
	beLis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Failed to listen %v", err)
	}
	beAddr := strings.Split(beLis.Addr().String(), ":")
	bePort, err := strconv.Atoi(beAddr[1])
	// Check the conversion error right away; err is reused by later calls and
	// would otherwise be overwritten before anyone looked at it.
	if err != nil {
		t.Fatalf("Failed to generate the port number %v", err)
	}
	backends := startBackends(t, besn, beLis)
	defer stopBackends(backends)

	// Start a load balancer.
	lbLis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Failed to create the listener for the load balancer %v", err)
	}
	lbCreds := &serverNameCheckCreds{
		sn: lbsn,
	}
	lb := grpc.NewServer(grpc.Creds(lbCreds))
	be := &lbpb.Server{
		IpAddress:        []byte(beAddr[0]),
		Port:             int32(bePort),
		LoadBalanceToken: lbToken,
		DropRequest:      true,
	}
	var bes []*lbpb.Server
	bes = append(bes, be)
	sl := &lbpb.ServerList{
		Servers: bes,
	}
	ls := newRemoteBalancer(sl)
	lbpb.RegisterLoadBalancerServer(lb, ls)
	go func() {
		lb.Serve(lbLis)
	}()
	defer func() {
		ls.stop()
		lb.Stop()
	}()
	creds := serverNameCheckCreds{
		expected: besn,
	}
	// Keep and defer the cancel funcs so the timeout contexts are released
	// as soon as the test returns instead of lingering until they expire.
	dialCtx, dialCancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer dialCancel()
	cc, err := grpc.DialContext(dialCtx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
		addr: lbLis.Addr().String(),
	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
	if err != nil {
		t.Fatalf("Failed to dial to the backend %v", err)
	}
	// Deferred so the connection is closed on t.Fatalf paths too. Defers run
	// LIFO, so it still closes before the load balancer and backends stop.
	defer cc.Close()
	helloC := hwpb.NewGreeterClient(cc)
	rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer rpcCancel()
	if _, err := helloC.SayHello(rpcCtx, &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.DeadlineExceeded)
	}
}
Example #3
0
// GRPCDial returns the gRPC client connection cached for target, dialing it
// with the options appropriate for the context if needed.
//
// The transport option is grpc.WithInsecure() when ctx.Insecure is set and TLS
// credentials from ctx.GetClientTLSConfig() otherwise, followed by a backoff
// cap of maxBackoff and then the caller-supplied opts. After a successful dial
// a heartbeat worker is started on ctx.Stopper; when it exits, the connection
// is removed from the cache. The returned values are whatever the cache entry
// holds after initialization (meta.conn, meta.err).
func (ctx *Context) GRPCDial(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
	// Find or create the cache entry for target while holding the conns lock.
	ctx.conns.Lock()
	meta, found := ctx.conns.cache[target]
	if !found {
		meta = &connMeta{}
		ctx.conns.cache[target] = meta
	}
	ctx.conns.Unlock()

	// NOTE(review): meta.Do presumably has sync.Once semantics, so only the
	// first caller for a given entry dials; confirm against connMeta.
	meta.Do(func() {
		var credsOpt grpc.DialOption
		if !ctx.Insecure {
			tlsCfg, err := ctx.GetClientTLSConfig()
			if err != nil {
				meta.err = err
				return
			}
			credsOpt = grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg))
		} else {
			credsOpt = grpc.WithInsecure()
		}

		// Context-chosen options come first; caller options are appended last.
		allOpts := append(
			make([]grpc.DialOption, 0, 2+len(opts)),
			credsOpt,
			grpc.WithBackoffMaxDelay(maxBackoff),
		)
		allOpts = append(allOpts, opts...)

		if log.V(1) {
			log.Infof(ctx.masterCtx, "dialing %s", target)
		}
		meta.conn, meta.err = grpc.DialContext(ctx.masterCtx, target, allOpts...)
		if meta.err != nil {
			return
		}

		// Worker body: run the heartbeat loop until it returns, log anything
		// other than a closed-connection error, then drop the cache entry.
		heartbeat := func() {
			if err := ctx.runHeartbeat(meta.conn, target); err != nil && !grpcutil.IsClosedConnection(err) {
				log.Error(ctx.masterCtx, err)
			}
			ctx.removeConn(target, meta)
		}

		if err := ctx.Stopper.RunTask(func() { ctx.Stopper.RunWorker(heartbeat) }); err != nil {
			meta.err = err
			// removeConn and ctx's cleanup worker both lock ctx.conns. However,
			// to avoid racing with meta's initialization, the cleanup worker
			// blocks on meta.Do while holding ctx.conns. Invoke removeConn
			// asynchronously to avoid deadlock.
			go ctx.removeConn(target, meta)
		}
	})

	return meta.conn, meta.err
}
Example #4
0
File: dial.go Project: naunga/vault
// DialGRPC returns a GRPC connection for use communicating with a Google cloud
// service, configured with the given ClientOptions.
//
// The options are applied to a fresh internal.DialSettings. An HTTP client
// option is rejected with an error, and a pre-built GRPCConn is returned
// as-is. Otherwise a token source is resolved (service-account JSON file if
// one was named, else the Google default token source) and the endpoint is
// dialed with per-RPC OAuth credentials, TLS transport credentials, the App
// Engine dialer hook when present, any caller-supplied dial options, and the
// user agent if one was set.
func DialGRPC(ctx context.Context, opts ...option.ClientOption) (*grpc.ClientConn, error) {
	var settings internal.DialSettings
	for i := range opts {
		opts[i].Apply(&settings)
	}

	switch {
	case settings.HTTPClient != nil:
		return nil, errors.New("unsupported HTTP client specified")
	case settings.GRPCConn != nil:
		return settings.GRPCConn, nil
	}

	// A named service-account file overrides any token source set directly.
	if filename := settings.ServiceAccountJSONFilename; filename != "" {
		ts, err := serviceAcctTokenSource(ctx, filename, settings.Scopes...)
		if err != nil {
			return nil, err
		}
		settings.TokenSource = ts
	}
	// Fall back to the default token source when none has been resolved.
	if settings.TokenSource == nil {
		ts, err := google.DefaultTokenSource(ctx, settings.Scopes...)
		if err != nil {
			return nil, fmt.Errorf("google.DefaultTokenSource: %v", err)
		}
		settings.TokenSource = ts
	}

	dialOpts := make([]grpc.DialOption, 0, 4+len(settings.GRPCDialOpts))
	dialOpts = append(dialOpts,
		grpc.WithPerRPCCredentials(oauth.TokenSource{TokenSource: settings.TokenSource}),
		grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")),
	)
	if appengineDialerHook != nil {
		// Use the Socket API on App Engine.
		dialOpts = append(dialOpts, appengineDialerHook(ctx))
	}
	dialOpts = append(dialOpts, settings.GRPCDialOpts...)
	if settings.UserAgent != "" {
		dialOpts = append(dialOpts, grpc.WithUserAgent(settings.UserAgent))
	}
	return grpc.DialContext(ctx, settings.Endpoint, dialOpts...)
}
Example #5
0
// TestDropRequest starts two backends and a remote load balancer that
// advertises the first with DropRequest set to true and the second with it set
// to false. It then checks three RPCs in order: a fail-fast call that must
// fail with codes.Unavailable, a fail-fast call that must succeed, and a
// non-fail-fast call that must succeed.
func TestDropRequest(t *testing.T) {
	// Start 2 backends.
	beLis1, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Failed to listen %v", err)
	}
	beAddr1 := strings.Split(beLis1.Addr().String(), ":")
	bePort1, err := strconv.Atoi(beAddr1[1])
	// Check each conversion error immediately; err is reused by the next
	// call and would otherwise be overwritten unseen.
	if err != nil {
		t.Fatalf("Failed to generate the port number %v", err)
	}

	beLis2, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Failed to listen %v", err)
	}
	beAddr2 := strings.Split(beLis2.Addr().String(), ":")
	bePort2, err := strconv.Atoi(beAddr2[1])
	if err != nil {
		t.Fatalf("Failed to generate the port number %v", err)
	}

	backends := startBackends(t, besn, beLis1, beLis2)
	defer stopBackends(backends)

	// Start a load balancer.
	lbLis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Failed to create the listener for the load balancer %v", err)
	}
	lbCreds := &serverNameCheckCreds{
		sn: lbsn,
	}
	lb := grpc.NewServer(grpc.Creds(lbCreds))
	var bes []*lbpb.Server
	be := &lbpb.Server{
		IpAddress:        []byte(beAddr1[0]),
		Port:             int32(bePort1),
		LoadBalanceToken: lbToken,
		DropRequest:      true,
	}
	bes = append(bes, be)
	be = &lbpb.Server{
		IpAddress:        []byte(beAddr2[0]),
		Port:             int32(bePort2),
		LoadBalanceToken: lbToken,
		DropRequest:      false,
	}
	bes = append(bes, be)
	sl := &lbpb.ServerList{
		Servers: bes,
	}
	ls := newRemoteBalancer(sl)
	lbpb.RegisterLoadBalancerServer(lb, ls)
	go func() {
		lb.Serve(lbLis)
	}()
	defer func() {
		ls.stop()
		lb.Stop()
	}()
	creds := serverNameCheckCreds{
		expected: besn,
	}
	// Keep and defer the cancel func so the timeout context is released as
	// soon as the test returns.
	dialCtx, dialCancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer dialCancel()
	cc, err := grpc.DialContext(dialCtx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
		addr: lbLis.Addr().String(),
	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
	if err != nil {
		t.Fatalf("Failed to dial to the backend %v", err)
	}
	// Deferred so the connection is closed on t.Fatalf paths too. Defers run
	// LIFO, so it still closes before the load balancer and backends stop.
	defer cc.Close()
	// The 1st fail-fast RPC should fail because the 1st backend has DropRequest set to true.
	helloC := hwpb.NewGreeterClient(cc)
	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); grpc.Code(err) != codes.Unavailable {
		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.Unavailable)
	}
	// The 2nd fail-fast RPC should succeed since it chooses the non-drop-request backend according
	// to the round robin policy.
	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
	}
	// The 3rd non-fail-fast RPC should succeed.
	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); err != nil {
		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
	}
}
Example #6
0
// TestServerExpiration starts one backend and a remote load balancer that is
// given two server lists: the first with a 100ms ExpirationInterval, the
// second without one, paired with send intervals of 0 and 500ms. It checks
// that an RPC succeeds at first, that a fail-fast RPC issued after sleeping
// 150ms fails with codes.Unavailable, and that a non-fail-fast RPC then
// succeeds.
func TestServerExpiration(t *testing.T) {
	// Start a backend.
	beLis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Failed to listen %v", err)
	}
	beAddr := strings.Split(beLis.Addr().String(), ":")
	bePort, err := strconv.Atoi(beAddr[1])
	// Check the conversion error right away; err is reused by later calls and
	// would otherwise be overwritten before anyone looked at it.
	if err != nil {
		t.Fatalf("Failed to generate the port number %v", err)
	}
	backends := startBackends(t, besn, beLis)
	defer stopBackends(backends)

	// Start a load balancer.
	lbLis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Failed to create the listener for the load balancer %v", err)
	}
	lbCreds := &serverNameCheckCreds{
		sn: lbsn,
	}
	lb := grpc.NewServer(grpc.Creds(lbCreds))
	be := &lbpb.Server{
		IpAddress:        []byte(beAddr[0]),
		Port:             int32(bePort),
		LoadBalanceToken: lbToken,
	}
	var bes []*lbpb.Server
	bes = append(bes, be)
	exp := &lbpb.Duration{
		Seconds: 0,
		Nanos:   100000000, // 100ms
	}
	var sls []*lbpb.ServerList
	sl := &lbpb.ServerList{
		Servers:            bes,
		ExpirationInterval: exp,
	}
	sls = append(sls, sl)
	sl = &lbpb.ServerList{
		Servers: bes,
	}
	sls = append(sls, sl)
	var intervals []time.Duration
	intervals = append(intervals, 0)
	intervals = append(intervals, 500*time.Millisecond)
	ls := newRemoteBalancer(sls, intervals)
	lbpb.RegisterLoadBalancerServer(lb, ls)
	go func() {
		lb.Serve(lbLis)
	}()
	defer func() {
		ls.stop()
		lb.Stop()
	}()
	creds := serverNameCheckCreds{
		expected: besn,
	}
	// Keep and defer the cancel func so the timeout context is released as
	// soon as the test returns.
	dialCtx, dialCancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer dialCancel()
	cc, err := grpc.DialContext(dialCtx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
		addr: lbLis.Addr().String(),
	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
	if err != nil {
		t.Fatalf("Failed to dial to the backend %v", err)
	}
	// Deferred so the connection is closed on t.Fatalf paths too. Defers run
	// LIFO, so it still closes before the load balancer and backends stop.
	defer cc.Close()
	helloC := hwpb.NewGreeterClient(cc)
	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
	}
	// Sleep and wake up when the first server list gets expired.
	time.Sleep(150 * time.Millisecond)
	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); grpc.Code(err) != codes.Unavailable {
		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.Unavailable)
	}
	// A non-failfast rpc should be succeeded after the second server list is received from
	// the remote load balancer.
	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); err != nil {
		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
	}
}