Example #1
0
// authenticateHandshake runs the authentication exchange that follows an
// authenticate response to the startup frame.
//
// It asks c.auth for the initial response to the authenticator class named in
// authFrame and then loops. Each iteration first sends a tick on frameTicker
// (or gives up if ctx is done), sends the current response with c.exec, parses
// the reply and acts on it:
//   - an error frame is returned as the error;
//   - an auth success frame ends the handshake, after giving the current
//     challenger (if any) a chance to validate the final server data;
//   - an auth challenge frame is handed to the current challenger, whose
//     answer becomes the next response to send;
//   - any other frame is reported as an error.
//
// It returns an error when no authenticator is configured, when the
// authenticator or challenger fails, when ctx is done, when c.exec or frame
// parsing fails, or when the server sends a challenge while no challenger is
// available to answer it.
func (c *Conn) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, frameTicker chan struct{}) error {
	if c.auth == nil {
		return fmt.Errorf("authentication required (using %q)", authFrame.class)
	}

	resp, challenger, err := c.auth.Challenge([]byte(authFrame.class))
	if err != nil {
		return err
	}

	req := &writeAuthResponseFrame{data: resp}

	for {
		// One tick per request sent; presumably the receiver of frameTicker
		// reads one frame off the connection per tick — confirm against caller.
		select {
		case frameTicker <- struct{}{}:
		case <-ctx.Done():
			return ctx.Err()
		}

		framer, err := c.exec(ctx, req, nil)
		if err != nil {
			return err
		}

		frame, err := framer.parseFrame()
		if err != nil {
			return err
		}

		switch v := frame.(type) {
		case error:
			return v
		case *authSuccessFrame:
			if challenger != nil {
				return challenger.Success(v.data)
			}
			return nil
		case *authChallengeFrame:
			// The challenger may legitimately be nil (the success branch
			// above already allows for that). Without one there is nothing
			// that can answer the challenge, so fail with an error rather
			// than dereferencing nil.
			if challenger == nil {
				return fmt.Errorf("authentication challenge received but no challenger is available (using %q)", authFrame.class)
			}

			resp, challenger, err = challenger.Challenge(v.data)
			if err != nil {
				return err
			}

			req = &writeAuthResponseFrame{
				data: resp,
			}
		default:
			return fmt.Errorf("unknown frame response during authentication: %v", v)
		}

		// Only reached after a challenge round: the reply has been consumed,
		// so the framer can go back to the pool before the next request.
		framerPool.Put(framer)
	}
}
Example #2
0
// startup sends the startup frame and handles the server's first reply.
//
// The startup options always carry the configured CQL version and, when a
// compressor is set on the connection, its name. Before the frame is sent a
// tick is pushed onto frameTicker; if ctx is done first, ctx.Err() is returned.
//
// A ready frame completes startup with a nil error. An authenticate frame
// hands over to authenticateHandshake. An error frame is returned as the
// error, and any other frame type yields a protocol error.
func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error {
	opts := make(map[string]string, 2)
	opts["CQL_VERSION"] = c.cfg.CQLVersion
	if comp := c.compressor; comp != nil {
		opts["COMPRESSION"] = comp.Name()
	}

	// Signal that a response is about to be expected, unless the context
	// has already ended.
	select {
	case <-ctx.Done():
		return ctx.Err()
	case frameTicker <- struct{}{}:
	}

	fr, err := c.exec(ctx, &writeStartupFrame{opts: opts}, nil)
	if err != nil {
		return err
	}

	parsed, err := fr.parseFrame()
	if err != nil {
		return err
	}

	switch reply := parsed.(type) {
	case error:
		return reply
	case *readyFrame:
		return nil
	case *authenticateFrame:
		return c.authenticateHandshake(ctx, reply, frameTicker)
	default:
		return NewErrProtocol("Unknown type of response to startup frame: %s", reply)
	}
}
Example #3
0
// exec sends req on a newly acquired stream and waits for the outcome.
//
// It takes a stream ID from c.streams, registers a callReq for that stream in
// c.calls (under c.mu), writes the request with req.writeFrame and then blocks
// until one of the following happens:
//   - a value is received from call.resp: a non-nil error is returned to the
//     caller, a nil value leads to the framer being returned after its header
//     version has been checked against c.version;
//   - the timer armed with c.timeout fires (only when c.timeout > 0), in which
//     case c.handleTimeout is called and ErrTimeoutNoResponse is returned;
//   - ctx is done (only watched when ctx is non-nil), returning ctx.Err();
//   - a receive on c.quit succeeds, returning ErrConnectionClosed.
//
// It returns ErrNoStreams when no stream is available, and an error if the
// stream handed out is already present in c.calls. A failed write closes the
// connection via c.closeWithError.
//
// tracer is only tested for nil here: when it is non-nil, framer.trace() is
// called before the request is written.
//
// The stream is released on the success path (including the version mismatch
// error) and on an error from call.resp while the connection is not closed.
// It is deliberately not released on timeout or context cancellation; see the
// comments in the body.
func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
	// TODO: move tracer onto conn
	stream, ok := c.streams.GetStream()
	if !ok {
		return nil, ErrNoStreams
	}

	// resp is basically a waiting semaphore protecting the framer
	framer := newFramer(c, c, c.compressor, c.version)

	// Register the call for this stream; finding an existing entry means the
	// stream ID was handed out while still in use, which is reported as an error.
	c.mu.Lock()
	call := c.calls[stream]
	if call != nil {
		c.mu.Unlock()
		return nil, fmt.Errorf("attempting to use stream already in use: %d -> %d", stream, call.streamID)
	} else {
		call = streamPool.Get().(*callReq)
	}
	c.calls[stream] = call
	c.mu.Unlock()

	// These fields are set after c.mu has been released.
	call.framer = framer
	call.timeout = make(chan struct{})
	call.streamID = stream

	if tracer != nil {
		framer.trace()
	}

	err := req.writeFrame(framer, stream)
	if err != nil {
		// closeWithError will block waiting for this stream to either receive a
		// response or for us to time out, so close the timeout chan here. It is
		// not certain, but we should not get a response after an error on the
		// write side.
		close(call.timeout)
		// This is believed to be the correct thing to do, though it is not
		// ideal: readers might still get some data, but they probably won't.
		// Care is needed here because the stream is not made available again,
		// and if all writes just time out or fail then the pool might use this
		// connection to send a frame on with all the streams used up and not
		// returned.
		c.closeWithError(err)
		return nil, err
	}

	// timeoutCh stays nil (and so never fires in the select below) unless a
	// positive timeout is configured.
	var timeoutCh <-chan time.Time
	if c.timeout > 0 {
		if call.timer == nil {
			// First use of this callReq: create the timer already expired and
			// drain it, so that Reset below starts from an empty channel.
			call.timer = time.NewTimer(0)
			<-call.timer.C
		} else {
			// Reused timer: stop it and, if Stop reports it was no longer
			// active, drain a pending tick without blocking.
			if !call.timer.Stop() {
				select {
				case <-call.timer.C:
				default:
				}
			}
		}

		call.timer.Reset(c.timeout)
		timeoutCh = call.timer.C
	}

	// A nil ctx is tolerated: ctxDone stays nil and that select case never fires.
	var ctxDone <-chan struct{}
	if ctx != nil {
		ctxDone = ctx.Done()
	}

	select {
	case err := <-call.resp:
		close(call.timeout)
		if err != nil {
			if !c.Closed() {
				// if the connection is closed then we can't release the stream,
				// this is because the request is still outstanding and we have
				// been handed another error from another stream which caused the
				// connection to close.
				c.releaseStream(stream)
			}
			return nil, err
		}
	case <-timeoutCh:
		close(call.timeout)
		c.handleTimeout()
		return nil, ErrTimeoutNoResponse
	case <-ctxDone:
		close(call.timeout)
		return nil, ctx.Err()
	case <-c.quit:
		// NOTE(review): unlike the other branches, call.timeout is not closed
		// here — confirm this is intended.
		return nil, ErrConnectionClosed
	}

	// don't release the stream if we detect a timeout, as another request can
	// reuse that stream and get a response for the old request, which we have
	// no easy way of detecting.
	//
	// Ensure that the stream is not released if there are potentially outstanding
	// requests on the stream to prevent nil pointer dereferences in recv().
	defer c.releaseStream(stream)

	// Reject a response whose header carries a different protocol version than
	// the one this connection uses.
	if v := framer.header.version.version(); v != c.version {
		return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
	}

	return framer, nil
}
Example #4
0
// Connect dials addr over TCP, using TLS when cfg.tlsConfig is set, wraps the
// resulting socket in a Conn and runs the startup exchange on it.
//
// cfg.Timeout bounds the dial and, when positive, the whole startup phase. If
// startup fails or the deadline passes first, the connection is closed and an
// error is returned. On success a goroutine running c.serve is started and the
// ready Conn is returned.
func Connect(host *HostInfo, addr string, cfg *ConnConfig,
	errorHandler ConnErrorHandler, session *Session) (*Conn, error) {

	d := &net.Dialer{Timeout: cfg.Timeout}

	var (
		sock    net.Conn
		dialErr error
	)
	if cfg.tlsConfig == nil {
		sock, dialErr = d.Dial("tcp", addr)
	} else {
		// The TLS config may be shared between connections, but it must not
		// be modified once it has been used.
		sock, dialErr = tls.DialWithDialer(d, "tcp", addr, cfg.tlsConfig)
	}
	if dialErr != nil {
		return nil, dialErr
	}

	c := &Conn{
		conn:         sock,
		r:            bufio.NewReader(sock),
		cfg:          cfg,
		calls:        make(map[int]*callReq),
		timeout:      cfg.Timeout,
		version:      uint8(cfg.ProtoVersion),
		addr:         sock.RemoteAddr().String(),
		errorHandler: errorHandler,
		compressor:   cfg.Compressor,
		auth:         cfg.Authenticator,
		quit:         make(chan struct{}),
		session:      session,
		streams:      streams.New(cfg.ProtoVersion),
		host:         host,
	}

	if cfg.Keepalive > 0 {
		c.setKeepalive(cfg.Keepalive)
	}

	// The startup phase is bounded by the connection timeout when one is
	// configured; otherwise the context only ends when Connect returns.
	var (
		ctx    context.Context
		cancel context.CancelFunc
	)
	if c.timeout <= 0 {
		ctx, cancel = context.WithCancel(context.Background())
	} else {
		ctx, cancel = context.WithTimeout(context.Background(), c.timeout)
	}
	defer cancel()

	frameTicker := make(chan struct{}, 1)
	startupErr := make(chan error)

	// readFrames reads one frame per tick until the ticker is closed, and
	// reports the first read failure unless the context has ended.
	readFrames := func() {
		for range frameTicker {
			if recvErr := c.recv(); recvErr != nil {
				select {
				case startupErr <- recvErr:
				case <-ctx.Done():
				}

				return
			}
		}
	}

	// runStartup performs the startup exchange, reports its result unless the
	// context has ended, and closes the ticker on the way out.
	runStartup := func() {
		defer close(frameTicker)
		result := c.startup(ctx, frameTicker)
		select {
		case startupErr <- result:
		case <-ctx.Done():
		}
	}

	go readFrames()
	go runStartup()

	select {
	case <-ctx.Done():
		c.Close()
		return nil, errors.New("gocql: no response to connection startup within timeout")
	case firstErr := <-startupErr:
		if firstErr != nil {
			c.Close()
			return nil, firstErr
		}
	}

	go c.serve()

	return c, nil
}