// Drop removes a channel ID associated with the given device ID from // memcached. Deregistration calls should call s.Unregister() instead. // Implements Store.Drop(). func (s *EmceeStore) Drop(uaid, chid string) (err error) { if len(uaid) == 0 { return ErrNoID } if len(chid) == 0 { return ErrNoChannel } if !id.Valid(uaid) { return ErrInvalidID } if !id.Valid(chid) { return ErrInvalidChannel } client, err := s.getClient() if err != nil { return err } defer s.releaseWithout(client, &err) key, ok := s.IDsToKey(uaid, chid) if !ok { return ErrInvalidKey } if err = client.Delete(key, 0); err == nil || isMissing(err) { return nil } return err }
// Register creates and stores a channel record for the given device ID and // channel ID. If version > 0, the record will be marked as active. Implements // Store.Register(). func (s *GomemcStore) Register(uaid, chid string, version int64) (err error) { if len(uaid) == 0 { return ErrNoID } if len(chid) == 0 { return ErrNoChannel } if !id.Valid(uaid) { return ErrInvalidID } if !id.Valid(chid) { return ErrInvalidChannel } return s.storeRegister(uaid, chid, version) }
// Unregister marks the channel ID associated with the given device ID // as inactive. Implements Store.Unregister(). func (s *GomemcStore) Unregister(uaid, chid string) (err error) { if len(uaid) == 0 { return ErrNoID } if len(chid) == 0 { return ErrNoChannel } if !id.Valid(uaid) { return ErrInvalidID } if !id.Valid(chid) { return ErrInvalidChannel } return s.storeUnregister(uaid, chid) }
// Update updates the version for the given device ID and channel ID. // Implements Store.Update(). func (s *GomemcStore) Update(uaid, chid string, version int64) (err error) { if len(uaid) == 0 { return ErrNoID } if len(chid) == 0 { return ErrNoChannel } // Normalize the device and channel IDs. if !id.Valid(uaid) { return ErrInvalidID } if !id.Valid(chid) { return ErrInvalidChannel } return s.storeUpdate(uaid, chid, version) }
// DropAll removes all channel records for the given device ID. Implements // Store.DropAll(). func (s *EmceeStore) DropAll(uaid string) error { if !id.Valid(uaid) { return ErrInvalidID } chids, err := s.fetchAppIDArray(uaid) if err != nil && !isMissing(err) { return err } client, err := s.getClient() if err != nil { return err } defer s.releaseWithout(client, &err) for _, chid := range chids { key, ok := s.IDsToKey(uaid, chid) if !ok { return ErrInvalidKey } client.Delete(key, 0) } if err = client.Delete(uaid, 0); err != nil && !isMissing(err) { return err } return nil }
func TestNilDeviceId(t *testing.T) { origin, err := testServer.Origin() if err != nil { t.Fatalf("Error initializing test server: %#v", err) } conn, err := client.DialOrigin(origin) if err != nil { t.Fatalf("Error dialing origin: %#v", err) } defer conn.Close() defer conn.Purge() request := CustomHelo{ MessageType: "hello", DeviceId: NilId, ChannelIds: []interface{}{}, Extra: "extra field", replies: make(chan client.Reply), errors: make(chan error), } reply, err := conn.WriteRequest(request) if err != nil { t.Fatalf("Error writing handshake request: %#v", err) } helo, ok := reply.(client.ServerHelo) if !ok { t.Errorf("Type assertion failed for handshake reply: %#v", reply) } if !id.Valid(helo.DeviceId) { t.Errorf("Got invalid device ID: %#v", helo.DeviceId) } }
// PutPing stores the proprietary ping info blob for the given device ID in // memcached. Implements Store.PutPing(). func (s *GomemcStore) PutPing(uaid string, pingData []byte) error { if !id.Valid(uaid) { return ErrInvalidID } return s.client.Set(&mc.Item{ Key: s.PingPrefix + uaid, Value: pingData, Expiration: 0}) }
// DropPing removes all proprietary ping info for the given device ID. // Implements Store.DropPing(). func (s *GomemcStore) DropPing(uaid string) error { if len(uaid) == 0 { return ErrNoID } if !id.Valid(uaid) { return ErrInvalidID } return s.client.Delete(s.PingPrefix + uaid) }
// Drop removes a channel ID associated with the given device ID from // memcached. Deregistration calls should call s.Unregister() instead. // Implements Store.Drop(). func (s *GomemcStore) Drop(uaid, chid string) (err error) { if len(uaid) == 0 { return ErrNoID } if len(chid) == 0 { return ErrNoChannel } if !id.Valid(uaid) { return ErrInvalidID } if !id.Valid(chid) { return ErrInvalidChannel } key := joinIDs(uaid, chid) if err = s.client.Delete(key); err != nil && err != mc.ErrCacheMiss { return err } return nil }
// PutPing stores the proprietary ping info blob for the given device ID in // memcached. Implements Store.PutPing(). func (s *EmceeStore) PutPing(uaid string, pingData []byte) (err error) { if !id.Valid(uaid) { return ErrInvalidID } client, err := s.getClient() defer s.releaseWithout(client, &err) if err != nil { return err } return client.Set(s.PingPrefix+uaid, pingData, 0) }
// Update updates the version for the given device ID and channel ID. // Implements Store.Update(). func (s *EmceeStore) Update(key string, version int64) (err error) { uaid, chid, ok := s.KeyToIDs(key) if !ok { return ErrInvalidKey } if len(uaid) == 0 { return ErrNoID } if len(chid) == 0 { return ErrNoChannel } // Normalize the device and channel IDs. if !id.Valid(uaid) { return ErrInvalidID } if !id.Valid(chid) { return ErrInvalidChannel } return s.storeUpdate(uaid, chid, version) }
func (t typeTest) Run() error { origin, err := testServer.Origin() if err != nil { return fmt.Errorf("On test %v, error initializing test server: %#v", t.name, err) } conn, err := client.DialOrigin(origin) if err != nil { return fmt.Errorf("On test %v, error dialing origin: %#v", t.name, err) } defer conn.Close() defer conn.Purge() request := CustomHelo{ MessageType: t.messageType, DeviceId: t.deviceId, ChannelIds: []interface{}{"1", "2"}, Extra: "custom value", replies: make(chan client.Reply), errors: make(chan error), } reply, err := conn.WriteRequest(request) if t.statusCode >= 200 && t.statusCode < 300 { if err != nil { return fmt.Errorf("On test %v, error writing handshake request: %#v", t.name, err) } helo, ok := reply.(client.ServerHelo) if !ok { return fmt.Errorf("On test %v, type assertion failed for handshake reply: %#v", t.name, reply) } if helo.StatusCode != t.statusCode { return fmt.Errorf("On test %v, unexpected reply status: got %#v; want %#v", t.name, helo.StatusCode, t.statusCode) } deviceId, _ := t.deviceId.(string) if len(deviceId) == 0 && !id.Valid(helo.DeviceId) { return fmt.Errorf("On test %v, got invalid device ID: %#v", t.name, helo.DeviceId) } else if !t.shouldReset && deviceId != helo.DeviceId { return fmt.Errorf("On test %v, mismatched device ID: got %#v; want %#v", t.name, helo.DeviceId, deviceId) } else if t.shouldReset && deviceId == helo.DeviceId { return fmt.Errorf("On test %v, want new device ID; got %#v", t.name, deviceId) } return nil } if err != io.EOF { return fmt.Errorf("On test %v, error writing handshake: got %#v; want io.EOF", t.name, err) } err = conn.Close() clientErr, ok := err.(client.Error) if !ok { return fmt.Errorf("On test %v, type assertion failed for close error: %#v", t.name, err) } if clientErr.Status() != t.statusCode { return fmt.Errorf("On test %v, unexpected close error status: got %#v; want %#v", t.name, clientErr.Status(), t.statusCode) } return nil }
// FetchPing retrieves proprietary ping information for the given device ID // from memcached. Implements Store.FetchPing(). func (s *GomemcStore) FetchPing(uaid string) (pingData []byte, err error) { if len(uaid) == 0 { return nil, ErrNoID } if !id.Valid(uaid) { return nil, ErrInvalidID } raw, err := s.client.Get(s.PingPrefix + uaid) if err != nil { return nil, err } return raw.Value, nil }
// DropPing removes all proprietary ping info for the given device ID. // Implements Store.DropPing(). func (s *EmceeStore) DropPing(uaid string) (err error) { if len(uaid) == 0 { return ErrNoID } if !id.Valid(uaid) { return ErrInvalidID } client, err := s.getClient() defer s.releaseWithout(client, &err) if err != nil { return err } return client.Delete(s.PingPrefix+uaid, 0) }
// FetchPing retrieves proprietary ping information for the given device ID // from memcached. Implements Store.FetchPing(). func (s *EmceeStore) FetchPing(uaid string) (pingData []byte, err error) { if len(uaid) == 0 { return nil, ErrNoID } if !id.Valid(uaid) { return nil, ErrInvalidID } client, err := s.getClient() defer s.releaseWithout(client, &err) if err != nil { return } err = client.Get(s.PingPrefix+uaid, &pingData) return }
// Exists returns a Boolean indicating whether a device has previously // registered with the Simple Push server. Implements Store.Exists(). func (s *EmceeStore) Exists(uaid string) bool { if ok, hasID := hasExistsHook(uaid); hasID { return ok } var err error if !id.Valid(uaid) { return false } if _, err = s.fetchAppIDArray(uaid); err != nil && !isMissing(err) { if s.logger.ShouldLog(ERROR) { s.logger.Error("emcee", "Exists encountered unknown error", LogFields{"error": err.Error()}) } } return err == nil }
// Exists returns a Boolean indicating whether a device has previously // registered with the Simple Push server. Implements Store.Exists(). func (s *GomemcStore) Exists(uaid string) bool { if ok, hasID := hasExistsHook(uaid); hasID { return ok } var err error if !id.Valid(uaid) { return false } if _, err = s.client.Get(uaid); err != nil && err != mc.ErrCacheMiss { if s.logger.ShouldLog(ERROR) { s.logger.Error("gomemc", "Exists encountered unknown error", LogFields{"uaid": uaid, "error": err.Error()}) } } return err == nil }
// ServeHTTP implements http.Handler.ServeHTTP. func (h *LogHandler) ServeHTTP(res http.ResponseWriter, req *http.Request) { receivedAt := time.Now() // The `X-Request-Id` header is used by Heroku, restify, etc. to correlate // logs for the same request. requestID := req.Header.Get(HeaderID) if !id.Valid(requestID) { requestID, _ = id.Generate() req.Header.Set(HeaderID, requestID) } writer := &logResponseWriter{ResponseWriter: res, StatusCode: http.StatusOK} defer h.logResponse(writer, req, requestID, receivedAt) h.Handler.ServeHTTP(writer, req) }
// DropAll removes all channel records for the given device ID. Implements // Store.DropAll(). func (s *GomemcStore) DropAll(uaid string) error { if !id.Valid(uaid) { return ErrInvalidID } chids, err := s.fetchAppIDArray(uaid) if err != nil && err != mc.ErrCacheMiss { return err } for _, chid := range chids { key := joinIDs(uaid, chid) s.client.Delete(key) } if err = s.client.Delete(uaid); err != nil && err != mc.ErrCacheMiss { return err } return nil }
// Register handles a client "register" message: it records a new ChannelID
// for this connection's UAID, creates a push endpoint for it, and writes a
// RegisterReply back to the client.
//
// Returns:
//   - ErrNoHandshake if the worker has no UAID yet;
//   - ErrInvalidParams if message is not valid JSON for RegisterRequest,
//     the ChannelID fails id.Valid, or a panic occurs while handling;
//   - the error from w.store.Register, w.store.IDsToKey or
//     w.app.CreateEndpoint if any of those steps fails.
//
// NOTE(review): the original comment mentioned optionally encrypting the
// endpoint. Nothing here does that; presumably it happens inside
// w.app.CreateEndpoint. Confirm.
func (w *WorkerWS) Register(header *RequestHeader, message []byte) (err error) {
	// Convert any panic into ErrInvalidParams on the named result.
	defer func() {
		if r := recover(); r != nil {
			// This err shadows the named result only within the if. Panic
			// values that are not errors are not logged.
			if err, _ := r.(error); err != nil && w.logger.ShouldLog(ERROR) {
				stack := make([]byte, 1<<16)
				n := runtime.Stack(stack, false)
				w.logger.Error("worker", "Unhandled error",
					LogFields{"rid": w.logID, "cmd": "register", "error": ErrStr(err), "stack": string(stack[:n])})
			}
			// Assigns the named result.
			err = ErrInvalidParams
		}
	}()
	uaid := w.UAID()
	if uaid == "" {
		return ErrNoHandshake
	}
	request := new(RegisterRequest)
	if err = json.Unmarshal(message, request); err != nil || !id.Valid(request.ChannelID) {
		return ErrInvalidParams
	}
	// The new channel is stored with version 0.
	if err = w.store.Register(uaid, request.ChannelID, 0); err != nil {
		if w.logger.ShouldLog(WARNING) {
			w.logger.Warn("worker", "Register failed, error updating backing store",
				LogFields{"rid": w.logID, "cmd": "register", "error": ErrStr(err)})
		}
		return err
	}
	key, err := w.store.IDsToKey(uaid, request.ChannelID)
	if err != nil {
		if w.logger.ShouldLog(WARNING) {
			w.logger.Warn("worker", "Error generating primary key",
				LogFields{"rid": w.logID, "cmd": "register", "error": ErrStr(err)})
		}
		return err
	}
	endpoint, err := w.app.CreateEndpoint(key)
	if err != nil {
		if w.logger.ShouldLog(WARNING) {
			w.logger.Warn("worker", "Error registering endpoint", LogFields{
				"rid":   w.logID,
				"uaid":  uaid,
				"chid":  request.ChannelID,
				"error": ErrStr(err)})
		}
		return err
	}
	// err is nil here (every failure above returned), so status is whatever
	// ErrToStatus maps a nil error to.
	status, _ := ErrToStatus(err)
	if w.logger.ShouldLog(DEBUG) {
		w.logger.Debug("worker", "Server returned", LogFields{
			"rid":  w.logID,
			"cmd":  "register",
			"code": strconv.FormatInt(int64(status), 10),
			"chid": request.ChannelID,
			"uaid": uaid})
	}
	// return the info back to the socket
	if w.logger.ShouldLog(DEBUG) {
		w.logger.Debug("worker", "sending response", LogFields{
			"rid":          w.logID,
			"cmd":          "register",
			"uaid":         uaid,
			"code":         strconv.FormatInt(int64(status), 10),
			"channelID":    request.ChannelID,
			"pushEndpoint": endpoint})
	}
	// NOTE(review): if WriteJSON returns an error, it is discarded here;
	// confirm this is intentional.
	w.WriteJSON(RegisterReply{header.Type, uaid, status, request.ChannelID, endpoint})
	w.metrics.Increment("updates.client.register")
	return nil
}
// FetchAll returns all channel updates and expired channels for a device ID // since the specified cutoff time. Implements Store.FetchAll(). func (s *EmceeStore) FetchAll(uaid string, since time.Time) ([]Update, []string, error) { var err error if len(uaid) == 0 { return nil, nil, ErrNoID } if !id.Valid(uaid) { return nil, nil, err } chids, err := s.fetchAppIDArray(uaid) if err != nil && !isMissing(err) { return nil, nil, err } updates := make([]Update, 0, 20) expired := make([]string, 0, 20) keys := make([]string, len(chids)) for i, chid := range chids { keys[i] = joinIDs(uaid, chid) } if s.logger.ShouldLog(INFO) { s.logger.Info("emcee", "Fetching items", LogFields{ "uaid": uaid, "items": fmt.Sprintf("[%s]", strings.Join(keys, ", ")), }) } client, err := s.getClient() defer s.releaseWithout(client, &err) if err != nil { return nil, nil, err } sinceUnix := since.Unix() for index, key := range keys { channel := new(ChannelRecord) if err := client.Get(key, channel); err != nil { continue } chid := chids[index] channelString := chid if s.logger.ShouldLog(DEBUG) { s.logger.Debug("emcee", "FetchAll Fetched record ", LogFields{ "uaid": uaid, "chid": channelString, "value": fmt.Sprintf("%d,%s,%d", channel.LastTouched, channel.State, channel.Version), }) } if channel.LastTouched < sinceUnix { if s.logger.ShouldLog(DEBUG) { s.logger.Debug("emcee", "Skipping record...", LogFields{ "uaid": uaid, "chid": channelString, }) } continue } // Yay! Go translates numeric interface values as float64s // Apparently float64(1) != int(1). 
switch channel.State { case StateLive: version := channel.Version if version == 0 { version = uint64(time.Now().UTC().Unix()) if s.logger.ShouldLog(DEBUG) { s.logger.Debug("emcee", "FetchAll Using Timestamp", LogFields{ "uaid": uaid, "chid": channelString, }) } } update := Update{ ChannelID: channelString, Version: version, } updates = append(updates, update) case StateDeleted: if s.logger.ShouldLog(DEBUG) { s.logger.Debug("emcee", "FetchAll Deleting record", LogFields{ "uaid": uaid, "chid": channelString, }) } expired = append(expired, chid) case StateRegistered: // Item registered, but not yet active. Ignore it. default: if s.logger.ShouldLog(WARNING) { s.logger.Warn("emcee", "Unknown state", LogFields{ "uaid": uaid, "chid": channelString, }) } } } return updates, expired, nil }
// handshake validates the client's opening "hello" request and decides which
// device ID (UAID) this connection will use.
//
// Returns:
//   - ErrNoParams if request.ChannelIDs is nil;
//   - the worker's current ID with allowRedirect=false if the worker already
//     has a UAID and the request omits or repeats it (duplicate handshake);
//   - ErrExistingID if the worker already has a different UAID;
//   - request.DeviceID with allowRedirect=true when it is accepted;
//   - a freshly generated ID with allowRedirect=true on the forceReset path
//     (empty or invalid ID, too many channels, or channel IDs sent for a
//     UAID the store does not know), or the idGenerate error if generation
//     fails.
func (w *WorkerWS) handshake(request *HelloRequest) (
	deviceID string, allowRedirect bool, err error) {

	logWarning := w.logger.ShouldLog(WARNING)
	currentID := w.UAID()

	if request.ChannelIDs == nil {
		// Must include "channelIDs" (even if empty)
		if logWarning {
			w.logger.Warn("worker", "Missing ChannelIDs", LogFields{"rid": w.logID})
		}
		return "", false, ErrNoParams
	}

	if len(currentID) > 0 {
		if len(request.DeviceID) == 0 || currentID == request.DeviceID {
			// Duplicate handshake with omitted or identical device ID. Allow the
			// caller to flush pending notifications, but avoid querying the balancer.
			if w.logger.ShouldLog(DEBUG) {
				w.logger.Debug("worker", "Duplicate client handshake", LogFields{"rid": w.logID})
			}
			return currentID, false, nil
		}
		// if there's already a Uaid for this device, don't accept a new one
		if logWarning {
			w.logger.Warn("worker", "Conflicting UAIDs", LogFields{"rid": w.logID})
		}
		return "", false, ErrExistingID
	}

	// Declared before the first goto: Go forbids a goto that jumps over
	// variable declarations into their scope.
	var (
		prevWorker      Worker
		workerConnected bool
	)
	if len(request.DeviceID) == 0 {
		if w.logger.ShouldLog(DEBUG) {
			w.logger.Debug("worker", "Generating new UAID for device", LogFields{"rid": w.logID})
		}
		goto forceReset
	}
	if !id.Valid(request.DeviceID) {
		if logWarning {
			w.logger.Warn("worker", "Invalid character in UAID", LogFields{"rid": w.logID})
		}
		goto forceReset
	}
	if !w.store.CanStore(len(request.ChannelIDs)) {
		// are there a suspicious number of channels?
		if logWarning {
			w.logger.Warn("worker", "Too many channel IDs in handshake; resetting UAID", LogFields{
				"rid":      w.logID,
				"uaid":     request.DeviceID,
				"channels": strconv.Itoa(len(request.ChannelIDs))})
		}
		// The DropAll result is not checked; the reset proceeds regardless.
		w.store.DropAll(request.DeviceID)
		goto forceReset
	}
	// If another worker currently holds this UAID, close it so this
	// connection takes over.
	prevWorker, workerConnected = w.app.GetWorker(request.DeviceID)
	if workerConnected {
		if w.logger.ShouldLog(INFO) {
			w.logger.Info("worker", "UAID collision; disconnecting previous client",
				LogFields{"rid": w.logID, "uaid": request.DeviceID})
		}
		prevWorker.Close()
	}
	// Channel IDs for a UAID the store does not know are treated as stale
	// client state; assign a new ID instead.
	if len(request.ChannelIDs) > 0 && !w.store.Exists(request.DeviceID) {
		if logWarning {
			w.logger.Warn("worker", "Channel IDs specified in handshake for nonexistent UAID",
				LogFields{"rid": w.logID, "uaid": request.DeviceID})
		}
		goto forceReset
	}
	return request.DeviceID, true, nil

forceReset:
	// Assign a freshly generated device ID.
	if deviceID, err = idGenerate(); err != nil {
		return "", false, err
	}
	return deviceID, true, nil
}