func reconnect(origin, deviceId, channelId, endpoint string) (err error) { socket, err := ws.Dial(origin, "", origin) if err != nil { return fmt.Errorf("Error dialing origin: %s", err) } connId, err := id.Generate() if err != nil { return fmt.Errorf("Error generating connection ID: %#v", err) } conn := client.NewConn(socket, connId, true) defer conn.Close() defer conn.Purge() actualId, err := conn.WriteHelo(deviceId, channelId) if err != nil { return fmt.Errorf("Error writing handshake request: %s", err) } if actualId != deviceId { return fmt.Errorf("Mismatched device IDs: got %q; want %q", actualId, deviceId) } if err = roundTrip(conn, deviceId, channelId, endpoint, 2); err != nil { return fmt.Errorf("Error sending notification after reconnect: %s", err) } return nil }
func TestUnregisterRace(t *testing.T) { origin, err := testServer.Origin() if err != nil { t.Fatalf("Error initializing test server: %#v", err) } socket, err := ws.Dial(origin, "", origin) if err != nil { t.Fatalf("Error dialing origin: %#v", err) } connId, err := id.Generate() if err != nil { t.Fatalf("Error generating connection ID: %#v", err) } // Spool all notifications, including those received on dregistered channels. conn := client.NewConn(socket, connId, true) defer conn.Close() if _, err = conn.WriteHelo(""); err != nil { t.Fatalf("Error writing handshake request: %#v", err) } defer conn.Purge() channelId, endpoint, err := conn.Subscribe() if err != nil { t.Fatalf("Error subscribing to channel: %#v", err) } if !isValidEndpoint(endpoint) { t.Fatalf("Invalid push endpoint for channel %#v: %#v", channelId, endpoint) } version := time.Now().UTC().Unix() var notifyWait sync.WaitGroup signal, errors := make(chan bool), make(chan error) notifyWait.Add(2) go func() { defer notifyWait.Done() timeout := time.After(1 * time.Minute) var ( isRemoved bool pendingTimer <-chan time.Time ) for ok := true; ok; { var packet client.Packet select { case ok = <-signal: case <-timeout: ok = false errors <- client.ErrTimedOut case <-pendingTimer: ok = false // Read the update, but don't call AcceptUpdate(). case packet, ok = <-conn.Packets: if !ok { err = client.ErrChanClosed break } var ( updates client.ServerUpdates hasUpdates bool ) if updates, hasUpdates = packet.(client.ServerUpdates); !hasUpdates { break } var ( update client.Update hasUpdate bool ) for _, update = range updates { if hasUpdate = update.ChannelId == channelId; hasUpdate { break } } if !hasUpdate { break } var err error if update.Version != version { err = fmt.Errorf("Expected update %#v, not %#v", version, update.Version) } else if isRemoved { err = fmt.Errorf("Update %#v resent on deregistered channel %#v", update.Version, update.ChannelId) } else { err = conn.Unregister(channelId) } if err != nil { ok = false errors <- err break } isRemoved = true timeout = nil // Queued updates should be sent immediately. pendingTimer = time.After(1 * time.Second) } } }() go func() { defer notifyWait.Done() select { case <-signal: case errors <- client.Notify(endpoint, version): } }() go func() { notifyWait.Wait() close(errors) }() for err = range errors { if err != nil { close(signal) t.Fatal(err) } } }