func (*trySuite) TestOneSuccess(c *gc.C) {
	try := parallel.NewTry(0, nil)
	try.Start(tryFunc(0, result("hello"), nil))
	val, err := try.Result()
	c.Assert(err, gc.IsNil)
	c.Assert(val, gc.Equals, result("hello"))
}
func (*trySuite) TestExtraResultsAreClosed(c *gc.C) {
	try := parallel.NewTry(0, nil)
	begin := make([]chan struct{}, 4)
	results := make([]*closeResult, len(begin))
	for i := range begin {
		begin[i] = make(chan struct{})
		results[i] = &closeResult{make(chan struct{})}
		i := i
		try.Start(func(<-chan struct{}) (io.Closer, error) {
			<-begin[i]
			return results[i], nil
		})
	}
	begin[0] <- struct{}{}
	val, err := try.Result()
	c.Assert(err, gc.IsNil)
	c.Assert(val, gc.Equals, results[0])

	timeout := time.After(shortWait)
	for i, r := range results[1:] {
		begin[i+1] <- struct{}{}
		select {
		case <-r.closed:
		case <-timeout:
			c.Fatalf("timed out waiting for close")
		}
	}
	select {
	case <-results[0].closed:
		c.Fatalf("result was inappropriately closed")
	case <-time.After(shortWait):
	}
}
func (*trySuite) TestMaxParallel(c *gc.C) {
	try := parallel.NewTry(3, nil)
	var (
		mu    sync.Mutex
		count int
		max   int
	)
	for i := 0; i < 10; i++ {
		try.Start(func(<-chan struct{}) (io.Closer, error) {
			mu.Lock()
			if count++; count > max {
				max = count
			}
			c.Check(count, gc.Not(jc.GreaterThan), 3)
			mu.Unlock()
			time.Sleep(20 * time.Millisecond)
			mu.Lock()
			count--
			mu.Unlock()
			return result("hello"), nil
		})
	}
	r, err := try.Result()
	c.Assert(err, gc.IsNil)
	c.Assert(r, gc.Equals, result("hello"))
	mu.Lock()
	defer mu.Unlock()
	c.Assert(max, gc.Equals, 3)
}
func (*trySuite) TestOutOfOrderResults(c *gc.C) {
	try := parallel.NewTry(0, nil)
	try.Start(tryFunc(50*time.Millisecond, result("first"), nil))
	try.Start(tryFunc(10*time.Millisecond, result("second"), nil))
	r, err := try.Result()
	c.Assert(err, gc.IsNil)
	c.Assert(r, gc.Equals, result("second"))
}
func (*trySuite) TestCloseTwice(c *gc.C) {
	try := parallel.NewTry(0, nil)
	try.Close()
	try.Close()
	val, err := try.Result()
	c.Assert(val, gc.IsNil)
	c.Assert(err, gc.IsNil)
}
// waitSSH waits for the instance to be assigned a routable
// address, then waits until we can connect to it via SSH.
//
// waitSSH attempts on all addresses returned by the instance
// in parallel; the first succeeding one wins. We ensure that
// private addresses are for the correct machine by checking
// the presence of a file on the machine that contains the
// machine's nonce. The "checkHostScript" is a bash script
// that performs this file check.
func waitSSH(ctx environs.BootstrapContext, interrupted <-chan os.Signal, client ssh.Client, checkHostScript string, inst addresser, timeout config.SSHTimeoutOpts) (addr string, err error) {
	globalTimeout := time.After(timeout.Timeout)
	pollAddresses := time.NewTimer(0)

	// checker checks each address in a loop, in parallel,
	// until one succeeds, the global timeout is reached,
	// or the tomb is killed.
	checker := parallelHostChecker{
		Try:             parallel.NewTry(0, nil),
		client:          client,
		stderr:          ctx.GetStderr(),
		active:          make(map[network.Address]chan struct{}),
		checkDelay:      timeout.RetryDelay,
		checkHostScript: checkHostScript,
	}
	defer checker.wg.Wait()
	defer checker.Kill()

	fmt.Fprintln(ctx.GetStderr(), "Waiting for address")
	for {
		select {
		case <-pollAddresses.C:
			pollAddresses.Reset(timeout.AddressesDelay)
			if err := inst.Refresh(); err != nil {
				return "", fmt.Errorf("refreshing addresses: %v", err)
			}
			addresses, err := inst.Addresses()
			if err != nil {
				return "", fmt.Errorf("getting addresses: %v", err)
			}
			checker.UpdateAddresses(addresses)
		case <-globalTimeout:
			checker.Close()
			lastErr := checker.Wait()
			format := "waited for %v "
			args := []interface{}{timeout.Timeout}
			if len(checker.active) == 0 {
				format += "without getting any addresses"
			} else {
				format += "without being able to connect"
			}
			if lastErr != nil && lastErr != parallel.ErrStopped {
				format += ": %v"
				args = append(args, lastErr)
			}
			return "", fmt.Errorf(format, args...)
		case <-interrupted:
			return "", fmt.Errorf("interrupted")
		case <-checker.Dead():
			result, err := checker.Result()
			if err != nil {
				return "", err
			}
			return result.(*hostChecker).addr.Value, nil
		}
	}
}
func (*trySuite) TestStartReturnsErrorAfterClose(c *gc.C) {
	try := parallel.NewTry(0, nil)
	expectErr := errors.New("foo")
	err := try.Start(tryFunc(0, nil, expectErr))
	c.Assert(err, gc.IsNil)
	try.Close()
	err = try.Start(tryFunc(0, result("goodbye"), nil))
	c.Assert(err, gc.Equals, parallel.ErrClosed)
	// Wait for the first try to deliver its result
	time.Sleep(shortWait)
	try.Kill()
	err = try.Wait()
	c.Assert(err, gc.Equals, expectErr)
}
func (*trySuite) TestTriesAreStopped(c *gc.C) {
	try := parallel.NewTry(0, nil)
	stopped := make(chan struct{})
	try.Start(func(stop <-chan struct{}) (io.Closer, error) {
		<-stop
		stopped <- struct{}{}
		return nil, parallel.ErrStopped
	})
	try.Start(tryFunc(0, result("hello"), nil))
	val, err := try.Result()
	c.Assert(err, gc.IsNil)
	c.Assert(val, gc.Equals, result("hello"))

	select {
	case <-stopped:
	case <-time.After(longWait):
		c.Fatalf("timed out waiting for stop")
	}
}
// Connect establishes a websocket connection to the API server using
// the Info, API path tail and (optional) request headers provided. If
// multiple API addresses are provided in Info they will be tried
// concurrently - the first successful connection wins.
//
// The path tail may be blank, in which case the default value will be
// used. Otherwise, it must start with a "/".
func Connect(info *Info, pathTail string, header http.Header, opts DialOpts) (*websocket.Conn, error) {
	if len(info.Addrs) == 0 {
		return nil, errors.New("no API addresses to connect to")
	}
	if pathTail != "" && !strings.HasPrefix(pathTail, "/") {
		return nil, errors.New(`path tail must start with "/"`)
	}

	pool := x509.NewCertPool()
	xcert, err := cert.ParseCert(info.CACert)
	if err != nil {
		return nil, errors.Annotate(err, "cert pool creation failed")
	}
	pool.AddCert(xcert)

	path := makeAPIPath(info.EnvironTag.Id(), pathTail)

	// Dial all addresses at reasonable intervals.
	try := parallel.NewTry(0, nil)
	defer try.Kill()
	for _, addr := range info.Addrs {
		err := dialWebsocket(addr, path, header, opts, pool, try)
		if err == parallel.ErrStopped {
			break
		}
		if err != nil {
			return nil, errors.Trace(err)
		}
		select {
		case <-time.After(opts.DialAddressInterval):
		case <-try.Dead():
		}
	}
	try.Close()
	result, err := try.Result()
	if err != nil {
		return nil, errors.Trace(err)
	}
	conn := result.(*websocket.Conn)
	logger.Infof("connection established to %q", conn.RemoteAddr())
	return conn, nil
}
func (*trySuite) TestOneFailure(c *gc.C) {
	try := parallel.NewTry(0, nil)
	expectErr := errors.New("foo")
	err := try.Start(tryFunc(0, nil, expectErr))
	c.Assert(err, gc.IsNil)
	select {
	case <-try.Dead():
		c.Fatalf("try died before it should")
	case <-time.After(shortWait):
	}
	try.Close()
	select {
	case <-try.Dead():
	case <-time.After(longWait):
		c.Fatalf("timed out waiting for Try to complete")
	}
	val, err := try.Result()
	c.Assert(val, gc.IsNil)
	c.Assert(err, gc.Equals, expectErr)
}
func (*trySuite) TestAllConcurrent(c *gc.C) {
	try := parallel.NewTry(0, nil)
	started := make(chan chan struct{})
	for i := 0; i < 10; i++ {
		try.Start(func(<-chan struct{}) (io.Closer, error) {
			reply := make(chan struct{})
			started <- reply
			<-reply
			return result("hello"), nil
		})
	}
	timeout := time.After(longWait)
	for i := 0; i < 10; i++ {
		select {
		case reply := <-started:
			reply <- struct{}{}
		case <-timeout:
			c.Fatalf("timed out")
		}
	}
}
func (*trySuite) TestEverything(c *gc.C) {
	try := parallel.NewTry(5, gradedErrorCombine)
	tries := []struct {
		startAt time.Duration
		wait    time.Duration
		val     result
		err     error
	}{{
		wait: 30 * time.Millisecond,
		err:  gradedError(3),
	}, {
		startAt: 10 * time.Millisecond,
		wait:    20 * time.Millisecond,
		val:     result("result 1"),
	}, {
		startAt: 20 * time.Millisecond,
		wait:    10 * time.Millisecond,
		val:     result("result 2"),
	}, {
		startAt: 20 * time.Millisecond,
		wait:    5 * time.Second,
		val:     "delayed result",
	}, {
		startAt: 5 * time.Millisecond,
		err:     gradedError(4),
	}}
	for _, t := range tries {
		t := t
		go func() {
			time.Sleep(t.startAt)
			try.Start(tryFunc(t.wait, t.val, t.err))
		}()
	}
	val, err := try.Result()
	if val != result("result 1") && val != result("result 2") {
		c.Errorf(`expected "result 1" or "result 2" got %#v`, val)
	}
	c.Assert(err, gc.IsNil)
}
// connectWebsocket establishes a websocket connection to the RPC
// API websocket on the API server using Info. If multiple API addresses
// are provided in Info they will be tried concurrently - the first successful
// connection wins.
//
// It also returns the TLS configuration that it has derived from the Info.
func connectWebsocket(info *Info, opts DialOpts) (*websocket.Conn, *tls.Config, error) {
	if len(info.Addrs) == 0 {
		return nil, nil, errors.New("no API addresses to connect to")
	}
	tlsConfig, err := tlsConfigForCACert(info.CACert)
	if err != nil {
		return nil, nil, errors.Annotatef(err, "cannot make TLS configuration")
	}
	path := "/"
	if info.EnvironTag.Id() != "" {
		path = apiPath(info.EnvironTag, "/api")
	}

	// Dial all addresses at reasonable intervals.
	try := parallel.NewTry(0, nil)
	defer try.Kill()
	for _, addr := range info.Addrs {
		err := dialWebsocket(addr, path, opts, tlsConfig, try)
		if err == parallel.ErrStopped {
			break
		}
		if err != nil {
			return nil, nil, errors.Trace(err)
		}
		select {
		case <-time.After(opts.DialAddressInterval):
		case <-try.Dead():
		}
	}
	try.Close()
	result, err := try.Result()
	if err != nil {
		return nil, nil, errors.Trace(err)
	}
	conn := result.(*websocket.Conn)
	logger.Infof("connection established to %q", conn.RemoteAddr())
	return conn, tlsConfig, nil
}
func (*trySuite) TestErrorCombine(c *gc.C) {
	// Use maxParallel=1 to guarantee that all errors are processed sequentially.
	try := parallel.NewTry(1, func(err0, err1 error) error {
		if err0 == nil {
			err0 = &multiError{}
		}
		err0.(*multiError).errs = append(err0.(*multiError).errs, int(err1.(gradedError)))
		return err0
	})
	errors := []gradedError{3, 2, 4, 0, 5, 5, 3}
	for _, err := range errors {
		err := err
		try.Start(func(<-chan struct{}) (io.Closer, error) {
			return nil, err
		})
	}
	try.Close()
	val, err := try.Result()
	c.Assert(val, gc.IsNil)
	grades := err.(*multiError).errs
	sort.Ints(grades)
	c.Assert(grades, gc.DeepEquals, []int{0, 2, 3, 3, 4, 5, 5})
}
// dialWebsocketMulti dials a websocket with one of the provided addresses, the
// specified URL path, TLS configuration, and dial options. Each of the
// specified addresses will be attempted concurrently, and the first
// successful connection will be returned.
func dialWebsocketMulti(addrs []string, path string, tlsConfig *tls.Config, opts DialOpts) (*websocket.Conn, error) {
	// Dial all addresses at reasonable intervals.
	try := parallel.NewTry(0, nil)
	defer try.Kill()
	for _, addr := range addrs {
		err := startDialWebsocket(try, addr, path, opts, tlsConfig)
		if err == parallel.ErrStopped {
			break
		}
		if err != nil {
			return nil, errors.Trace(err)
		}
		select {
		case <-time.After(opts.DialAddressInterval):
		case <-try.Dead():
		}
	}
	try.Close()
	result, err := try.Result()
	if err != nil {
		return nil, errors.Trace(err)
	}
	return result.(*websocket.Conn), nil
}
// newAPIFromStore implements the bulk of NewAPIClientFromName
// but is separate for testing purposes.
func newAPIFromStore(envName string, store configstore.Storage, apiOpen apiOpenFunc) (apiState, error) {
	// Try to read the default environment configuration file.
	// If it doesn't exist, we carry on in case
	// there's some environment info for that environment.
	// This enables people to copy environment files
	// into their .juju/environments directory and have
	// them be directly useful with no further configuration changes.
	envs, err := environs.ReadEnvirons("")
	if err == nil {
		if envName == "" {
			envName = envs.Default
		}
		if envName == "" {
			return nil, fmt.Errorf("no default environment found")
		}
	} else if !environs.IsNoEnv(err) {
		return nil, err
	}

	// Try to connect to the API concurrently using two different
	// possible sources of truth for the API endpoint. Our
	// preference is for the API endpoint cached in the API info,
	// because we know that without needing to access any remote
	// provider. However, the addresses stored there may no longer
	// be current (and the network connection may take a very long
	// time to time out) so we also try to connect using information
	// found from the provider. We only start to make that
	// connection after some suitable delay, so that in the
	// hopefully usual case, we will make the connection to the API
	// and never hit the provider. By preference we use provider
	// attributes from the config store, but for backward
	// compatibility reasons, we fall back to information from
	// ReadEnvirons if that does not exist.
	chooseError := func(err0, err1 error) error {
		if err0 == nil {
			return err1
		}
		if errorImportance(err0) < errorImportance(err1) {
			err0, err1 = err1, err0
		}
		logger.Warningf("discarding API open error: %v", err1)
		return err0
	}
	try := parallel.NewTry(0, chooseError)

	info, err := store.ReadInfo(envName)
	if err != nil && !errors.IsNotFound(err) {
		return nil, err
	}
	var delay time.Duration
	if info != nil && len(info.APIEndpoint().Addresses) > 0 {
		logger.Debugf(
			"trying cached API connection settings - endpoints %v",
			info.APIEndpoint().Addresses,
		)
		try.Start(func(stop <-chan struct{}) (io.Closer, error) {
			return apiInfoConnect(info, apiOpen, stop)
		})
		// Delay the config connection until we've spent
		// some time trying to connect to the cached info.
		delay = providerConnectDelay
	} else {
		logger.Debugf("no cached API connection settings found")
	}
	try.Start(func(stop <-chan struct{}) (io.Closer, error) {
		cfg, err := getConfig(info, envs, envName)
		if err != nil {
			return nil, err
		}
		return apiConfigConnect(cfg, apiOpen, stop, delay, environInfoUserTag(info))
	})
	try.Close()
	val0, err := try.Result()
	if err != nil {
		if ierr, ok := err.(*infoConnectError); ok {
			// lose error encapsulation:
			err = ierr.error
		}
		return nil, err
	}

	st := val0.(apiState)
	addrConnectedTo, err := serverAddress(st.Addr())
	if err != nil {
		return nil, err
	}
	// Even though we are about to update API addresses based on
	// APIHostPorts in cacheChangedAPIInfo, we first cache the
	// addresses based on the provider lookup. This is because older API
	// servers didn't return their HostPort information on Login, and we
	// still want to cache our connection information to them.
	if cachedInfo, ok := st.(apiStateCachedInfo); ok {
		st = cachedInfo.apiState
		if cachedInfo.cachedInfo != nil && info != nil {
			// Cache the connection settings only if we used the
			// environment config, but any errors are just logged
			// as warnings, because they're not fatal.
			err = cacheAPIInfo(st, info, cachedInfo.cachedInfo)
			if err != nil {
				logger.Warningf("cannot cache API connection settings: %v", err.Error())
			} else {
				logger.Infof("updated API connection settings cache")
			}
			addrConnectedTo, err = serverAddress(st.Addr())
			if err != nil {
				return nil, err
			}
		}
	}
	// Update API addresses if they've changed. Error is non-fatal.
	// For older servers, the environ tag or server tag may not be set.
	// If they are not, we store empty values.
	var environUUID string
	var serverUUID string
	if envTag, err := st.EnvironTag(); err == nil {
		environUUID = envTag.Id()
	}
	if serverTag, err := st.ServerTag(); err == nil {
		serverUUID = serverTag.Id()
	}
	if localerr := cacheChangedAPIInfo(info, st.APIHostPorts(), addrConnectedTo, environUUID, serverUUID); localerr != nil {
		logger.Warningf("cannot cache API addresses: %v", localerr)
	}
	return st, nil
}
// WaitSSH waits for the instance to be assigned a routable
// address, then waits until we can connect to it via SSH.
//
// WaitSSH attempts on all addresses returned by the instance
// in parallel; the first succeeding one wins. We ensure that
// private addresses are for the correct machine by checking
// the presence of a file on the machine that contains the
// machine's nonce. The "checkHostScript" is a bash script
// that performs this file check.
func WaitSSH(
	stdErr io.Writer,
	interrupted <-chan os.Signal,
	client ssh.Client,
	checkHostScript string,
	inst InstanceRefresher,
	opts environs.BootstrapDialOpts,
) (addr string, err error) {
	globalTimeout := time.After(opts.Timeout)
	pollAddresses := time.NewTimer(0)

	// checker checks each address in a loop, in parallel,
	// until one succeeds, the global timeout is reached,
	// or the tomb is killed.
	checker := parallelHostChecker{
		Try:             parallel.NewTry(0, nil),
		client:          client,
		stderr:          stdErr,
		active:          make(map[network.Address]chan struct{}),
		checkDelay:      opts.RetryDelay,
		checkHostScript: checkHostScript,
	}
	defer checker.wg.Wait()
	defer checker.Kill()

	fmt.Fprintln(stdErr, "Waiting for address")
	for {
		select {
		case <-pollAddresses.C:
			pollAddresses.Reset(opts.AddressesDelay)
			if err := inst.Refresh(); err != nil {
				return "", fmt.Errorf("refreshing addresses: %v", err)
			}
			instanceStatus := inst.Status()
			if instanceStatus.Status == status.ProvisioningError {
				if instanceStatus.Message != "" {
					return "", errors.Errorf("instance provisioning failed (%v)", instanceStatus.Message)
				}
				return "", errors.Errorf("instance provisioning failed")
			}
			addresses, err := inst.Addresses()
			if err != nil {
				return "", fmt.Errorf("getting addresses: %v", err)
			}
			checker.UpdateAddresses(addresses)
		case <-globalTimeout:
			checker.Close()
			lastErr := checker.Wait()
			format := "waited for %v "
			args := []interface{}{opts.Timeout}
			if len(checker.active) == 0 {
				format += "without getting any addresses"
			} else {
				format += "without being able to connect"
			}
			if lastErr != nil && lastErr != parallel.ErrStopped {
				format += ": %v"
				args = append(args, lastErr)
			}
			return "", fmt.Errorf(format, args...)
		case <-interrupted:
			return "", fmt.Errorf("interrupted")
		case <-checker.Dead():
			result, err := checker.Result()
			if err != nil {
				return "", err
			}
			return result.(*hostChecker).addr.Value, nil
		}
	}
}
func Open(info *Info, opts DialOpts) (*State, error) {
	if len(info.Addrs) == 0 {
		return nil, fmt.Errorf("no API addresses to connect to")
	}
	pool := x509.NewCertPool()
	xcert, err := cert.ParseCert(info.CACert)
	if err != nil {
		return nil, err
	}
	pool.AddCert(xcert)

	var environUUID string
	if info.EnvironTag != nil {
		environUUID = info.EnvironTag.Id()
	}
	// Dial all addresses at reasonable intervals.
	try := parallel.NewTry(0, nil)
	defer try.Kill()
	var addrs []string
	for _, addr := range info.Addrs {
		if strings.HasPrefix(addr, "localhost:") {
			addrs = append(addrs, addr)
			break
		}
	}
	if len(addrs) == 0 {
		addrs = info.Addrs
	}
	for _, addr := range addrs {
		err := dialWebsocket(addr, environUUID, opts, pool, try)
		if err == parallel.ErrStopped {
			break
		}
		if err != nil {
			return nil, err
		}
		select {
		case <-time.After(opts.DialAddressInterval):
		case <-try.Dead():
		}
	}
	try.Close()
	result, err := try.Result()
	if err != nil {
		return nil, err
	}
	conn := result.(*websocket.Conn)
	logger.Infof("connection established to %q", conn.RemoteAddr())

	client := rpc.NewConn(jsoncodec.NewWebsocket(conn), nil)
	client.Start()
	st := &State{
		client:     client,
		conn:       conn,
		addr:       conn.Config().Location.Host,
		serverRoot: "https://" + conn.Config().Location.Host,
		// why are the contents of the tag (username and password) written into the
		// state structure BEFORE login ?!?
		tag:      toString(info.Tag),
		password: info.Password,
		certPool: pool,
	}
	if info.Tag != nil || info.Password != "" {
		if err := st.Login(info.Tag.String(), info.Password, info.Nonce); err != nil {
			conn.Close()
			return nil, err
		}
	}
	st.broken = make(chan struct{})
	st.closed = make(chan struct{})
	go st.heartbeatMonitor()
	return st, nil
}
func (*trySuite) TestStartBlocksForMaxParallel(c *gc.C) {
	try := parallel.NewTry(3, nil)

	started := make(chan struct{})
	begin := make(chan struct{})
	go func() {
		for i := 0; i < 6; i++ {
			err := try.Start(func(<-chan struct{}) (io.Closer, error) {
				<-begin
				return nil, fmt.Errorf("an error")
			})
			started <- struct{}{}
			if i < 5 {
				c.Check(err, gc.IsNil)
			} else {
				c.Check(err, gc.Equals, parallel.ErrClosed)
			}
		}
		close(started)
	}()
	// Check we can start the first three.
	timeout := time.After(longWait)
	for i := 0; i < 3; i++ {
		select {
		case <-started:
		case <-timeout:
			c.Fatalf("timed out")
		}
	}
	// Check we block when going above maxParallel.
	timeout = time.After(shortWait)
	select {
	case <-started:
		c.Fatalf("Start did not block")
	case <-timeout:
	}

	// Unblock two attempts.
	begin <- struct{}{}
	begin <- struct{}{}

	// Check we can start another two.
	timeout = time.After(longWait)
	for i := 0; i < 2; i++ {
		select {
		case <-started:
		case <-timeout:
			c.Fatalf("timed out")
		}
	}

	// Check we block again when going above maxParallel.
	timeout = time.After(shortWait)
	select {
	case <-started:
		c.Fatalf("Start did not block")
	case <-timeout:
	}

	// Close the Try - the last request should be discarded,
	// unblocking last remaining Start request.
	try.Close()

	timeout = time.After(longWait)
	select {
	case <-started:
	case <-timeout:
		c.Fatalf("Start did not unblock after Close")
	}

	// Ensure all checks are completed
	select {
	case _, ok := <-started:
		c.Assert(ok, gc.Equals, false)
	case <-timeout:
		c.Fatalf("Start goroutine did not finish")
	}
}
// newAPIFromStore implements the bulk of NewAPIConnection but is separate for
// testing purposes.
func newAPIFromStore(args NewAPIConnectionParams, apiOpen api.OpenFunc) (api.Connection, error) {

	controllerDetails, err := args.Store.ControllerByName(args.ControllerName)
	if err != nil {
		return nil, errors.Annotate(err, "getting controller details")
	}

	// Try to connect to the API concurrently using two different
	// possible sources of truth for the API endpoint. Our
	// preference is for the API endpoint cached in the API info,
	// because we know that without needing to access any remote
	// provider. However, the addresses stored there may no longer
	// be current (and the network connection may take a very long
	// time to time out) so we also try to connect using information
	// found from the provider. We only start to make that
	// connection after some suitable delay, so that in the
	// hopefully usual case, we will make the connection to the API
	// and never hit the provider.
	chooseError := func(err0, err1 error) error {
		if err0 == nil {
			return err1
		}
		if errorImportance(err0) < errorImportance(err1) {
			err0, err1 = err1, err0
		}
		logger.Warningf("discarding API open error: %v", err1)
		return err0
	}
	try := parallel.NewTry(0, chooseError)

	var delay time.Duration
	if len(controllerDetails.APIEndpoints) > 0 {
		try.Start(func(stop <-chan struct{}) (io.Closer, error) {
			return apiInfoConnect(
				controllerDetails,
				args.AccountDetails,
				args.ModelUUID, apiOpen,
				stop, args.DialOpts,
			)
		})
		// Delay the config connection until we've spent
		// some time trying to connect to the cached info.
		delay = providerConnectDelay
	} else {
		logger.Debugf("no cached API connection settings found")
	}

	// If the client has bootstrap config for the controller, we'll also
	// attempt to connect by fetching new addresses from the cloud
	// directly. This is only attempted after a delay, to give the
	// faster cached-addresses method a chance to complete.
	cfg, err := args.BootstrapConfig(args.ControllerName)
	if err == nil {
		try.Start(func(stop <-chan struct{}) (io.Closer, error) {
			cfg, err := apiConfigConnect(
				cfg, args.AccountDetails,
				args.ModelUUID, apiOpen,
				stop, delay, args.DialOpts,
			)
			if err != nil {
				// Errors are swallowed by parallel.Try, so we
				// log the failure here to aid in debugging.
				logger.Debugf("failed to connect via bootstrap config: %v", err)
			}
			return cfg, err
		})
	} else if !errors.IsNotFound(err) || len(controllerDetails.APIEndpoints) == 0 {
		return nil, err
	}

	try.Close()
	val0, err := try.Result()
	if err != nil {
		if ierr, ok := err.(*infoConnectError); ok {
			// lose error encapsulation:
			err = ierr.error
		}
		return nil, err
	}

	st := val0.(api.Connection)
	addrConnectedTo, err := serverAddress(st.Addr())
	if err != nil {
		return nil, err
	}
	// Update API addresses if they've changed. Error is non-fatal.
	hostPorts := st.APIHostPorts()
	if localerr := updateControllerAddresses(
		args.Store, args.ControllerName, controllerDetails, hostPorts, addrConnectedTo,
	); localerr != nil {
		logger.Warningf("cannot cache API addresses: %v", localerr)
	}
	return st, nil
}