func (c *Conn) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame, frameTicker chan struct{}) error { if c.auth == nil { return fmt.Errorf("authentication required (using %q)", authFrame.class) } resp, challenger, err := c.auth.Challenge([]byte(authFrame.class)) if err != nil { return err } req := &writeAuthResponseFrame{data: resp} for { select { case frameTicker <- struct{}{}: case <-ctx.Done(): return ctx.Err() } framer, err := c.exec(ctx, req, nil) if err != nil { return err } frame, err := framer.parseFrame() if err != nil { return err } switch v := frame.(type) { case error: return v case *authSuccessFrame: if challenger != nil { return challenger.Success(v.data) } return nil case *authChallengeFrame: resp, challenger, err = challenger.Challenge(v.data) if err != nil { return err } req = &writeAuthResponseFrame{ data: resp, } default: return fmt.Errorf("unknown frame response during authentication: %v", v) } framerPool.Put(framer) } }
func (c *Conn) startup(ctx context.Context, frameTicker chan struct{}) error { m := map[string]string{ "CQL_VERSION": c.cfg.CQLVersion, } if c.compressor != nil { m["COMPRESSION"] = c.compressor.Name() } select { case frameTicker <- struct{}{}: case <-ctx.Done(): return ctx.Err() } framer, err := c.exec(ctx, &writeStartupFrame{opts: m}, nil) if err != nil { return err } frame, err := framer.parseFrame() if err != nil { return err } switch v := frame.(type) { case error: return v case *readyFrame: return nil case *authenticateFrame: return c.authenticateHandshake(ctx, v, frameTicker) default: return NewErrProtocol("Unknown type of response to startup frame: %s", v) } }
// exec sends req on a newly acquired stream and waits for the matching response.
//
// It takes a stream ID from c.streams and registers a callReq for it in c.calls
// under c.mu. It then writes the request through a new framer, with tracing
// enabled on the framer when tracer is non-nil. Finally it blocks until one of:
//   - a result is delivered on call.resp;
//   - the per-call timer fires (only armed when c.timeout > 0);
//   - ctx is done;
//   - c.quit is closed.
//
// When a response arrives without error, the stream is released via the
// deferred releaseStream. This also happens if the protocol-version check then
// fails. The stream is deliberately NOT released on timeout, context
// cancellation or quit. A late response could otherwise be matched to a new
// request that reused the same stream ID.
//
// Errors returned:
//   - ErrNoStreams when no stream ID is available;
//   - an error when the stream ID is unexpectedly still registered in c.calls;
//   - the write error, after it has been passed to c.closeWithError;
//   - the error received on call.resp;
//   - ErrTimeoutNoResponse, after c.handleTimeout;
//   - ctx.Err();
//   - ErrConnectionClosed;
//   - a protocol error when the response's version differs from c.version.
func (c *Conn) exec(ctx context.Context, req frameWriter, tracer Tracer) (*framer, error) {
	// TODO: move tracer onto conn
	stream, ok := c.streams.GetStream()
	if !ok {
		return nil, ErrNoStreams
	}

	// resp is basically a waiting semaphore protecting the framer
	framer := newFramer(c, c, c.compressor, c.version)

	c.mu.Lock()
	call := c.calls[stream]
	if call != nil {
		// The allocator handed out a stream that still has a registered call.
		c.mu.Unlock()
		return nil, fmt.Errorf("attempting to use stream already in use: %d -> %d", stream, call.streamID)
	} else {
		// Free stream: take a callReq from streamPool.
		call = streamPool.Get().(*callReq)
	}
	c.calls[stream] = call
	c.mu.Unlock()

	call.framer = framer
	call.timeout = make(chan struct{})
	call.streamID = stream

	if tracer != nil {
		framer.trace()
	}

	err := req.writeFrame(framer, stream)
	if err != nil {
		// closeWithError will block waiting for this stream to either receive a response
		// or for us to time out, so close the timeout chan here. I'm not entirely sure,
		// but we should not get a response after an error on the write side.
		close(call.timeout)
		// I think this is the correct thing to do, I'm not entirely sure. It is not
		// ideal as readers might still get some data, but they probably won't.
		// Here we need to be careful as the stream is not available, and if all
		// writes just time out or fail then the pool might use this connection to
		// send a frame on, with all the streams used up and not returned.
		c.closeWithError(err)
		return nil, err
	}

	var timeoutCh <-chan time.Time
	if c.timeout > 0 {
		// Reuse the callReq's timer across requests. A new timer is created
		// with 0 and drained so Reset starts from a clean state. An existing
		// timer is stopped, and any stale tick is drained without blocking.
		if call.timer == nil {
			call.timer = time.NewTimer(0)
			<-call.timer.C
		} else {
			if !call.timer.Stop() {
				select {
				case <-call.timer.C:
				default:
				}
			}
		}

		call.timer.Reset(c.timeout)
		timeoutCh = call.timer.C
	}

	// A nil ctx leaves ctxDone nil; receiving from a nil channel blocks
	// forever, which disables that select case.
	var ctxDone <-chan struct{}
	if ctx != nil {
		ctxDone = ctx.Done()
	}

	select {
	case err := <-call.resp:
		close(call.timeout)
		if err != nil {
			if !c.Closed() {
				// If the connection is closed then we can't release the stream,
				// because the request is still outstanding and we have
				// been handed another error from another stream which caused the
				// connection to close.
				c.releaseStream(stream)
			}
			return nil, err
		}
	case <-timeoutCh:
		close(call.timeout)
		c.handleTimeout()
		return nil, ErrTimeoutNoResponse
	case <-ctxDone:
		close(call.timeout)
		return nil, ctx.Err()
	case <-c.quit:
		// NOTE(review): unlike the other cases, call.timeout is not closed here.
		return nil, ErrConnectionClosed
	}

	// Don't release the stream if we detect a timeout, as another request can reuse
	// that stream and get a response for the old request, which we have no
	// easy way of detecting.
	//
	// Ensure that the stream is not released if there are potentially outstanding
	// requests on the stream to prevent nil pointer dereferences in recv().
	defer c.releaseStream(stream)

	if v := framer.header.version.version(); v != c.version {
		return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
	}

	return framer, nil
}
// Connect establishes a connection to a Cassandra node. func Connect(host *HostInfo, addr string, cfg *ConnConfig, errorHandler ConnErrorHandler, session *Session) (*Conn, error) { var ( err error conn net.Conn ) dialer := &net.Dialer{ Timeout: cfg.Timeout, } if cfg.tlsConfig != nil { // the TLS config is safe to be reused by connections but it must not // be modified after being used. conn, err = tls.DialWithDialer(dialer, "tcp", addr, cfg.tlsConfig) } else { conn, err = dialer.Dial("tcp", addr) } if err != nil { return nil, err } c := &Conn{ conn: conn, r: bufio.NewReader(conn), cfg: cfg, calls: make(map[int]*callReq), timeout: cfg.Timeout, version: uint8(cfg.ProtoVersion), addr: conn.RemoteAddr().String(), errorHandler: errorHandler, compressor: cfg.Compressor, auth: cfg.Authenticator, quit: make(chan struct{}), session: session, streams: streams.New(cfg.ProtoVersion), host: host, } if cfg.Keepalive > 0 { c.setKeepalive(cfg.Keepalive) } var ( ctx context.Context cancel func() ) if c.timeout > 0 { ctx, cancel = context.WithTimeout(context.Background(), c.timeout) } else { ctx, cancel = context.WithCancel(context.Background()) } defer cancel() frameTicker := make(chan struct{}, 1) startupErr := make(chan error) go func() { for range frameTicker { err := c.recv() if err != nil { select { case startupErr <- err: case <-ctx.Done(): } return } } }() go func() { defer close(frameTicker) err := c.startup(ctx, frameTicker) select { case startupErr <- err: case <-ctx.Done(): } }() select { case err := <-startupErr: if err != nil { c.Close() return nil, err } case <-ctx.Done(): c.Close() return nil, errors.New("gocql: no response to connection startup within timeout") } go c.serve() return c, nil }