func (rb *ManagedRoomBinding) GenerateMessageKey(ctx scope.Context, kms security.KMS) ( proto.RoomMessageKey, error) { rmkb, err := rb.Room.generateMessageKey(rb.Backend, kms) if err != nil { return nil, err } // Insert key and room association into the DB. transaction, err := rb.DbMap.Begin() if err != nil { return nil, err } if err := transaction.Insert(&rmkb.MessageKey); err != nil { if rerr := transaction.Rollback(); rerr != nil { logging.Logger(ctx).Printf("rollback error: %s", rerr) } return nil, err } if err := transaction.Insert(&rmkb.RoomMessageKey); err != nil { if rerr := transaction.Rollback(); rerr != nil { logging.Logger(ctx).Printf("rollback error: %s", rerr) } return nil, err } if err := transaction.Commit(); err != nil { return nil, err } return rmkb, nil }
func (s *session) join() error { msgs, err := s.room.Latest(s.ctx, 100, 0) if err != nil { return err } listing, err := s.room.Listing(s.ctx) if err != nil { return err } if err := s.room.Join(s.ctx, s); err != nil { logging.Logger(s.ctx).Printf("join failed: %s", err) return err } s.onClose = func() { if err := s.room.Part(s.ctx, s); err != nil { // TODO: error handling return } } if err := s.sendSnapshot(msgs, listing); err != nil { logging.Logger(s.ctx).Printf("snapshot failed: %s", err) return err } s.joined = true return nil }
func (c *Controller) background(ctx scope.Context) { defer ctx.WaitGroup().Done() var lastStatCheck time.Time for { logging.Logger(ctx).Printf("[%s] background loop", c.w.QueueName()) if time.Now().Sub(lastStatCheck) > StatsInterval { logging.Logger(ctx).Printf("[%s] collecting stats", c.w.QueueName()) stats, err := c.jq.Stats(ctx) if err != nil { logging.Logger(ctx).Printf("error: %s stats: %s", c.w.QueueName(), err) return } lastStatCheck = time.Now() labels := map[string]string{"queue": c.w.QueueName()} claimedGauge.With(labels).Set(float64(stats.Claimed)) dueGauge.With(labels).Set(float64(stats.Due)) waitingGauge.With(labels).Set(float64(stats.Waiting)) } if err := c.processOne(ctx); err != nil { // TODO: retry a couple times before giving up logging.Logger(ctx).Printf("error: %s: %s", c.w.QueueName(), err) return } } }
func (s *session) join() error { nick, ok, err := s.room.ResolveNick(s.ctx, s.Identity().ID()) if err != nil { return err } if ok { s.identity.name = nick } addr, err := s.room.Join(s.ctx, s) if err != nil { logging.Logger(s.ctx).Printf("join failed: %s", err) return err } s.vClientAddr = addr s.onClose = func() { // Use a fork of the server's root context, because the session's context // might be closed. ctx := s.server.rootCtx.Fork() if err := s.room.Part(ctx, s); err != nil { logging.Logger(ctx).Printf("room part failed: %s", err) return } } if err := s.sendSnapshot(); err != nil { logging.Logger(s.ctx).Printf("snapshot failed: %s", err) return err } s.joined = true return nil }
func (m *accountManager) Register( ctx scope.Context, kms security.KMS, namespace, id, password string, agentID string, agentKey *security.ManagedKey) ( proto.Account, *security.ManagedKey, error) { m.b.Lock() defer m.b.Unlock() key := fmt.Sprintf("%s:%s", namespace, id) if _, ok := m.b.accountIDs[key]; ok { return nil, nil, proto.ErrPersonalIdentityInUse } account, clientKey, err := NewAccount(kms, password) if err != nil { return nil, nil, err } if namespace == "email" { account.(*memAccount).email = id } if m.b.accounts == nil { m.b.accounts = map[snowflake.Snowflake]proto.Account{account.ID(): account} } else { m.b.accounts[account.ID()] = account } pid := &personalIdentity{ accountID: account.ID(), namespace: namespace, id: id, } account.(*memAccount).personalIdentities = []proto.PersonalIdentity{pid} if m.b.accountIDs == nil { m.b.accountIDs = map[string]*personalIdentity{key: pid} } else { m.b.accountIDs[key] = pid } agent, err := m.b.AgentTracker().Get(ctx, agentID) if err != nil { logging.Logger(ctx).Printf( "error locating agent %s for new account %s:%s: %s", agentID, namespace, id, err) } else { if err := agent.SetClientKey(agentKey, clientKey); err != nil { logging.Logger(ctx).Printf( "error associating agent %s with new account %s:%s: %s", agentID, namespace, id, err) } agent.AccountID = account.ID().String() } return account, clientKey, nil }
func (b *Backend) part(ctx scope.Context, rb *RoomBinding, session proto.Session) error { t, err := b.DbMap.Begin() if err != nil { return err } _, err = t.Exec( "DELETE FROM presence WHERE room = $1 AND server_id = $2 AND server_era = $3 AND session_id = $4", rb.RoomName, b.desc.ID, b.desc.Era, session.ID()) if err != nil { rollback(ctx, t) logging.Logger(ctx).Printf("failed to persist departure: %s", err) return err } // Broadcast a presence event. // TODO: make this an explicit action via the Room protocol, to support encryption event := proto.PresenceEvent(session.View(proto.Staff)) if err := rb.broadcast(ctx, t, proto.PartEventType, event, session); err != nil { rollback(ctx, t) return err } if err := t.Commit(); err != nil { return err } b.Lock() if lm, ok := b.listeners[rb.RoomName]; ok { delete(lm, session.ID()) } b.Unlock() return nil }
func (s *session) CheckAbandoned() error { s.m.Lock() defer s.m.Unlock() logger := logging.Logger(s.ctx) if s.maybeAbandoned { // already in fast-keepalive state return nil } s.maybeAbandoned = true child := s.ctx.Fork() s.fastKeepAliveCancel = child.Cancel go func() { logger.Printf("starting fast-keepalive timer") timer := time.After(FastKeepAlive) select { case <-child.Done(): logger.Printf("aliased session still alive") case <-timer: logger.Printf("connection replaced") s.ctx.Terminate(ErrReplaced) } }() return s.sendPing() }
func (s *session) sendPing() error { logger := logging.Logger(s.ctx) now := time.Now() cmd, err := proto.MakeEvent(&proto.PingEvent{ UnixTime: proto.Time(now), NextUnixTime: proto.Time(now.Add(3 * KeepAlive / 2)), }) if err != nil { logger.Printf("error: ping event: %s", err) return err } data, err := cmd.Encode() if err != nil { logger.Printf("error: ping event encode: %s", err) return err } if err := s.conn.WriteMessage(websocket.TextMessage, data); err != nil { logger.Printf("error: write ping event: %s", err) return err } s.expectedPingReply = now.Unix() s.outstandingPings++ return nil }
func (s *session) sendHello(roomIsPrivate, accountHasAccess bool) error { logger := logging.Logger(s.ctx) event := &proto.HelloEvent{ SessionView: s.View(), AccountHasAccess: accountHasAccess, RoomIsPrivate: roomIsPrivate, Version: s.room.Version(), } if s.client.Account != nil { event.AccountView = s.client.Account.View(s.roomName) } event.ID = event.SessionView.ID cmd, err := proto.MakeEvent(event) if err != nil { logger.Printf("error: hello event: %s", err) return err } data, err := cmd.Encode() if err != nil { logger.Printf("error: hello event encode: %s", err) return err } if err := s.conn.WriteMessage(websocket.TextMessage, data); err != nil { logger.Printf("error: write hello event: %s", err) return err } return nil }
func (s *session) readMessages() { logger := logging.Logger(s.ctx) defer s.Close() for s.ctx.Err() == nil { messageType, data, err := s.conn.ReadMessage() if err != nil { if err == io.EOF { logger.Printf("client disconnected") return } logger.Printf("error: read message: %s", err) return } switch messageType { case websocket.TextMessage: cmd, err := proto.ParseRequest(data) if err != nil { logger.Printf("error: ParseRequest: %s", err) return } s.incoming <- cmd default: logger.Printf("error: unsupported message type: %v", messageType) return } } }
func Serve(ctx scope.Context, addr string) { http.Handle("/metrics", prometheus.Handler()) listener, err := net.Listen("tcp", addr) if err != nil { ctx.Terminate(err) } closed := false m := sync.Mutex{} closeListener := func() { m.Lock() if !closed { listener.Close() closed = true } m.Unlock() } // Spin off goroutine to watch ctx and close listener if shutdown requested. go func() { <-ctx.Done() closeListener() }() logging.Logger(ctx).Printf("serving /metrics on %s", addr) if err := http.Serve(listener, nil); err != nil { fmt.Printf("http[%s]: %s\n", addr, err) ctx.Terminate(err) } closeListener() ctx.WaitGroup().Done() }
func (et *EmailTracker) Send( ctx scope.Context, js jobs.JobService, templater *templates.Templater, deliverer emails.Deliverer, account proto.Account, to, templateName string, data interface{}) ( *emails.EmailRef, error) { if to == "" { to, _ = account.Email() } sf, err := snowflake.New() if err != nil { return nil, err } msgID := fmt.Sprintf("<%s@%s>", sf, deliverer.LocalName()) ref, err := emails.NewEmail(templater, msgID, to, templateName, data) if err != nil { return nil, err } ref.AccountID = account.ID() jq, err := js.GetQueue(ctx, jobs.EmailQueue) if err != nil { return nil, err } payload := &jobs.EmailJob{ AccountID: account.ID(), EmailID: ref.ID, } job, err := jq.AddAndClaim(ctx, jobs.EmailJobType, payload, "immediate", jobs.EmailJobOptions...) if err != nil { return nil, err } ref.JobID = job.ID et.m.Lock() if et.emailsByAccount == nil { et.emailsByAccount = map[snowflake.Snowflake][]*emails.EmailRef{} } et.emailsByAccount[account.ID()] = append(et.emailsByAccount[account.ID()], ref) et.m.Unlock() child := ctx.Fork() child.WaitGroup().Add(1) go job.Exec(child, func(ctx scope.Context) error { defer ctx.WaitGroup().Done() logging.Logger(ctx).Printf("delivering to %s\n", to) if err := deliverer.Deliver(ctx, ref); err != nil { return err } return nil }) return ref, nil }
// handleAuthCommand processes an auth command for this session.
//
// If the session has already joined, it replies with success immediately.
// After at least one earlier failure, it first sleeps for authDelay plus a
// random jitter (no sleep in security.TestMode), presumably to slow down
// repeated guessing. Only proto.AuthPasscode is handled, and only when the
// session has a managed room. Every other case fails with an
// "auth type not supported" reason.
//
// Outcomes:
//   - an error from rand.Read or AuthenticateWithPasscode is returned as a
//     response error;
//   - a non-empty failure reason increments authFailCount and is sent back in
//     the reply. Reaching MaxAuthFailures also logs, increments
//     authTerminations, and switches the session to ignoreState;
//   - on success, the session takes the client's current message key ID,
//     enters joinedState, and joins the room. If join fails, keyID is reset to
//     "", state to unauthedState, and the join error is returned.
func (s *session) handleAuthCommand(msg *proto.AuthCommand) *response {
	if s.joined {
		return &response{packet: &proto.AuthReply{Success: true}}
	}

	if s.authFailCount > 0 {
		// NOTE(review): assumes rand is crypto/rand — confirm the import.
		buf := []byte{0}
		if _, err := rand.Read(buf); err != nil {
			return &response{err: err}
		}
		// int(buf[0])-128 is in [-128, 127], so jitter is in [-512ms, +508ms].
		// If delay ends up non-positive, time.Sleep returns immediately.
		jitter := 4 * time.Duration(int(buf[0])-128) * time.Millisecond
		delay := authDelay + jitter
		if security.TestMode {
			delay = 0
		}
		time.Sleep(delay)
	}

	// Counted for every attempt that gets this far, including ones that fail below.
	authAttempts.WithLabelValues(s.roomName).Inc()

	var (
		failureReason string
		err           error
	)

	switch msg.Type {
	case proto.AuthPasscode:
		if s.managedRoom == nil {
			failureReason = fmt.Sprintf("auth type not supported: %s", msg.Type)
		} else {
			failureReason, err = s.client.AuthenticateWithPasscode(s.ctx, s.managedRoom, msg.Passcode)
		}
	default:
		failureReason = fmt.Sprintf("auth type not supported: %s", msg.Type)
	}

	if err != nil {
		return &response{err: err}
	}

	if failureReason != "" {
		authFailures.WithLabelValues(s.roomName).Inc()
		s.authFailCount++
		if s.authFailCount >= MaxAuthFailures {
			logging.Logger(s.ctx).Printf(
				"max authentication failures on room %s by %s", s.roomName, s.Identity().ID())
			authTerminations.WithLabelValues(s.roomName).Inc()
			s.state = s.ignoreState
		}
		return &response{packet: &proto.AuthReply{Reason: failureReason}}
	}

	// keyID and state are set before join(), presumably because join relies on
	// them; both are reverted below if join fails.
	s.keyID = s.client.Authorization.CurrentMessageKeyID
	s.state = s.joinedState
	if err := s.join(); err != nil {
		s.keyID = ""
		s.state = s.unauthedState
		return &response{err: err}
	}

	return &response{packet: &proto.AuthReply{Success: true}}
}
func (s *session) Send(ctx scope.Context, cmdType proto.PacketType, payload interface{}) error { // Special case: certain events have privileged info that may need to be stripped from them switch event := payload.(type) { case *proto.PresenceEvent: switch s.privilegeLevel() { case proto.Staff: case proto.Host: event.RealClientAddress = "" default: event.RealClientAddress = "" event.ClientAddress = "" } case *proto.Message: if s.privilegeLevel() == proto.General { event.Sender.ClientAddress = "" } case *proto.EditMessageEvent: if s.privilegeLevel() == proto.General { event.Sender.ClientAddress = "" } } var err error payload, err = proto.DecryptPayload(payload, &s.client.Authorization, s.privilegeLevel()) if err != nil { return err } encoded, err := json.Marshal(payload) if err != nil { return err } cmd := &proto.Packet{ Type: cmdType, Data: encoded, } // Add to outgoing channel. If channel is full, defer to goroutine so as not to block // the caller (this may result in deliveries coming out of order). select { case <-ctx.Done(): // Session is closed, return error. return ctx.Err() case s.outgoing <- cmd: // Packet delivered to queue. default: // Queue is full. logging.Logger(s.ctx).Printf("outgoing channel full, ordering cannot be guaranteed") go func() { s.outgoing <- cmd }() } return nil }
func (b *AccountManagerBinding) ChangeClientKey( ctx scope.Context, accountID snowflake.Snowflake, oldKey, newKey *security.ManagedKey) error { t, err := b.DbMap.Begin() if err != nil { return err } rollback := func() { if err := t.Rollback(); err != nil { logging.Logger(ctx).Printf("rollback error: %s", err) } } row, err := t.Get(Account{}, accountID.String()) if err != nil { rollback() if err == sql.ErrNoRows { return proto.ErrAccountNotFound } return err } account := row.(*Account) sec := account.Bind(b.Backend).accountSecurity() if err := sec.ChangeClientKey(oldKey, newKey); err != nil { rollback() return err } res, err := t.Exec( "UPDATE account SET mac = $2, encrypted_user_key = $3 WHERE id = $1", accountID.String(), sec.MAC, sec.UserKey.Ciphertext) if err != nil { rollback() return err } n, err := res.RowsAffected() if err != nil { rollback() return err } if n == 0 { rollback() return proto.ErrAccountNotFound } if err := t.Commit(); err != nil { return err } return nil }
func (s *session) handleRegisterAccountCommand(cmd *proto.RegisterAccountCommand) *response { // Session must not be logged in. if s.client.Account != nil { return &response{packet: &proto.RegisterAccountReply{Reason: "already logged in"}} } // Agent must be of sufficient age. if time.Now().Sub(s.client.Agent.Created) < s.server.newAccountMinAgentAge { return &response{packet: &proto.RegisterAccountReply{Reason: "not familiar yet, try again later"}} } // Validate givens. if ok, reason := proto.ValidatePersonalIdentity(cmd.Namespace, cmd.ID); !ok { return &response{packet: &proto.RegisterAccountReply{Reason: reason}} } if ok, reason := proto.ValidateAccountPassword(cmd.Password); !ok { return &response{packet: &proto.RegisterAccountReply{Reason: reason}} } // Register the account. account, clientKey, err := s.backend.AccountManager().Register( s.ctx, s.kms, cmd.Namespace, cmd.ID, cmd.Password, s.client.Agent.IDString(), s.agentKey) if err != nil { switch err { case proto.ErrPersonalIdentityInUse: return &response{packet: &proto.RegisterAccountReply{Reason: err.Error()}} default: return &response{err: err} } } // Kick off on-registration tasks. if err := s.heim.OnAccountRegistration(s.ctx, s.backend, account, clientKey); err != nil { // Log this error only. logging.Logger(s.ctx).Printf("error on account registration: %s", err) } // Authorize session's agent to unlock account. err = s.backend.AgentTracker().SetClientKey( s.ctx, s.client.Agent.IDString(), s.agentKey, account.ID(), clientKey) if err != nil { return &response{err: err} } // Return successful response. reply := &proto.RegisterAccountReply{ Success: true, AccountID: account.ID(), } return &response{packet: reply} }
func (s *Server) serveRoomWebsocket( ctx scope.Context, room proto.Room, cookie *http.Cookie, client *proto.Client, agentKey *security.ManagedKey, w http.ResponseWriter, r *http.Request) { // Upgrade to a websocket and set cookie. headers := http.Header{} if cookie != nil { headers.Add("Set-Cookie", cookie.String()) } conn, err := upgrader.Upgrade(w, r, headers) if err != nil { logging.Logger(ctx).Printf("upgrade error: %s", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } defer conn.Close() // Determine client address. clientAddress := r.Header.Get("X-Forwarded-For") if clientAddress == "" { addr := conn.RemoteAddr() switch a := addr.(type) { case *net.TCPAddr: clientAddress = a.IP.String() default: clientAddress = addr.String() } } // Serve the session. session := newSession(ctx, s, conn, clientAddress, room, client, agentKey) if err = session.serve(); err != nil { // TODO: error handling logging.Logger(ctx).Printf("session serve error: %s", err) return } }
func (b *AccountManagerBinding) Resolve(ctx scope.Context, namespace, id string) (proto.Account, error) { t, err := b.DbMap.Begin() account, err := b.resolve(t, namespace, id) if err != nil { if rerr := t.Rollback(); rerr != nil { logging.Logger(ctx).Printf("rollback error: %s", err) } return nil, err } if err := t.Commit(); err != nil { return nil, err } return account, nil }
func (b *AccountManagerBinding) Get(ctx scope.Context, id snowflake.Snowflake) (proto.Account, error) { t, err := b.DbMap.Begin() account, err := b.get(t, id) if err != nil { if rerr := t.Rollback(); rerr != nil { logging.Logger(ctx).Printf("rollback error: %s", err) } return nil, err } if err := t.Commit(); err != nil { return nil, err } return account, nil }
func (r *memRoom) RenameUser( ctx scope.Context, session proto.Session, formerName string) (*proto.NickEvent, error) { r.m.Lock() defer r.m.Unlock() logging.Logger(ctx).Printf( "renaming %s from %s to %s\n", session.ID(), formerName, session.Identity().Name()) payload := &proto.NickEvent{ SessionID: session.ID(), ID: session.Identity().ID(), From: formerName, To: session.Identity().Name(), } return payload, r.broadcast(ctx, proto.NickType, payload, session) }
func Run(args []string) { out = tabwriter.NewWriter(os.Stdout, 0, 8, 1, '\t', 0) if len(args) == 0 { generalHelp() return } exe := filepath.Base(os.Args[0]) cmd, ok := subcommands[args[0]] if !ok { fmt.Fprintf(os.Stderr, "%s: invalid command: %s\n", exe, args[0]) fmt.Fprintf(os.Stderr, "Run '%s help' for usage.\n", exe) os.Exit(2) } flags := cmd.flags() if err := flags.Parse(args[1:]); err != nil { fmt.Fprintf(os.Stderr, "%s %s: %s\n", exe, args[0], err) os.Exit(2) } ctx := logging.LoggingContext(scope.New(), os.Stdout, fmt.Sprintf("[%s] ", args[0])) logging.Logger(ctx).Printf("starting up") if err := cmd.run(ctx, flags.Args()); err != nil { fmt.Fprintln(os.Stderr, err.Error()) os.Exit(1) } timeout := time.After(10 * time.Second) completed := make(chan struct{}) go func() { ctx.WaitGroup().Wait() close(completed) }() fmt.Println("waiting for graceful shutdown...") select { case <-timeout: fmt.Println("timed out") os.Exit(1) case <-completed: fmt.Println("ok") os.Exit(0) } }
func (b *AccountManagerBinding) RequestPasswordReset( ctx scope.Context, kms security.KMS, namespace, id string) ( proto.Account, *proto.PasswordResetRequest, error) { t, err := b.DbMap.Begin() if err != nil { return nil, nil, err } rollback := func() { if err := t.Rollback(); err != nil { logging.Logger(ctx).Printf("rollback error: %s", err) } } account, err := b.resolve(t, namespace, id) if err != nil { rollback() return nil, nil, err } req, err := proto.GeneratePasswordResetRequest(kms, account.ID()) if err != nil { rollback() return nil, nil, err } stored := &PasswordResetRequest{ ID: req.ID.String(), AccountID: req.AccountID.String(), Key: req.Key, Requested: req.Requested, Expires: req.Expires, } if err := t.Insert(stored); err != nil { rollback() return nil, nil, err } if err := t.Commit(); err != nil { rollback() return nil, nil, err } return account, req, nil }
// invalidatePeer must be called with lock held func (b *Backend) invalidatePeer(ctx scope.Context, id, era string) { logger := logging.Logger(ctx) packet, err := proto.MakeEvent(&proto.NetworkEvent{ Type: "partition", ServerID: id, ServerEra: era, }) if err != nil { logger.Printf("cluster: make network event error: %s", err) return } for _, lm := range b.listeners { if err := lm.Broadcast(ctx, packet); err != nil { logger.Printf("cluster: network event error: %s", err) } } }
func (w *EmailWorker) send(ctx scope.Context, accountID snowflake.Snowflake, msgID string) error { ref, err := w.et.Get(ctx, accountID, msgID) if err != nil { return err } if err := w.d.Deliver(ctx, ref); err != nil { return err } if err := w.et.MarkDelivered(ctx, accountID, msgID); err != nil { // We failed to mark the email as delivered, which is unfortunate, // but not quite as unfortunate as delivering it twice would be. // So we swallow the error here but log it noisily. logging.Logger(ctx).Printf("failed to mark email %s/%s as delivered: %s", accountID, msgID, err) } return nil }
func (rb *ManagedRoomBinding) RemoveManager( ctx scope.Context, actor proto.Account, actorKey *security.ManagedKey, formerManager proto.Account) error { t, err := rb.Backend.DbMap.Begin() if err != nil { return err } rollback := func() { if err := t.Rollback(); err != nil { logging.Logger(ctx).Printf("rollback error: %s", err) } } rmkb := NewRoomManagerKeyBinding(rb) rmkb.SetExecutor(t) if _, _, _, err := rmkb.Authority(ctx, actor, actorKey); err != nil { rollback() if err == proto.ErrCapabilityNotFound { return proto.ErrAccessDenied } return err } if err := rmkb.RevokeFromAccount(ctx, formerManager); err != nil { rollback() if err == proto.ErrCapabilityNotFound || err == proto.ErrAccessDenied { return proto.ErrManagerNotFound } return err } if err := t.Commit(); err != nil { return err } return nil }
func (j *Job) Exec(ctx scope.Context, f func(scope.Context) error) error { if j.JobClaim == nil { return ErrJobNotClaimed } w := io.MultiWriter(os.Stdout, j) prefix := fmt.Sprintf("[%s-%s] ", j.Queue.Name(), j.HandlerID) deadline := time.Now().Add(j.MaxWorkDuration) child := logging.LoggingContext(ctx.ForkWithTimeout(j.MaxWorkDuration), w, prefix) if err := f(child); err != nil { logging.Logger(child).Printf("error: %s", err) if err != scope.TimedOut { delay := time.Duration(j.AttemptsMade+1) * BackoffDuration if time.Now().Add(delay).After(deadline) { delay = deadline.Sub(time.Now()) } time.Sleep(delay) } return j.Fail(ctx, err.Error()) } return j.Complete(ctx) }
func (atb *AgentTrackerBinding) Register(ctx scope.Context, agent *proto.Agent) error { row := &Agent{ ID: agent.IDString(), IV: agent.IV, MAC: agent.MAC, Created: agent.Created, Bot: agent.Bot, } if agent.EncryptedClientKey != nil { row.EncryptedClientKey = agent.EncryptedClientKey.Ciphertext } if err := atb.Backend.DbMap.Insert(row); err != nil { if strings.HasPrefix(err.Error(), "pq: duplicate key value") { return proto.ErrAgentAlreadyExists } return err } logging.Logger(ctx).Printf("registered agent %s", agent.IDString()) return nil }
func (atb *AgentTrackerBinding) SetClientKey( ctx scope.Context, agentID string, accessKey *security.ManagedKey, accountID snowflake.Snowflake, clientKey *security.ManagedKey) error { t, err := atb.Backend.DbMap.Begin() if err != nil { return err } rollback := func() { if err := t.Rollback(); err != nil { logging.Logger(ctx).Printf("rollback error: %s", err) } } agent, err := atb.getFromDB(agentID, atb.Backend.DbMap) if err != nil { rollback() return err } if err := agent.SetClientKey(accessKey, clientKey); err != nil { rollback() return err } err = atb.setClientKeyInDB( agentID, accountID.String(), agent.EncryptedClientKey.Ciphertext, t) if err != nil { rollback() return err } if err := t.Commit(); err != nil { return err } return nil }
func ScanLoop(ctx scope.Context, listener *pq.Listener) { defer ctx.WaitGroup().Done() logger := logging.Logger(ctx) for { select { case <-ctx.Done(): logger.Printf("received cancellation signal, shutting down") return case notice := <-listener.Notify: if notice == nil { logger.Printf("received nil from listener") continue } var msg psql.BroadcastMessage if err := json.Unmarshal([]byte(notice.Extra), &msg); err != nil { logger.Printf("error: pq listen: invalid broadcast: %s", err) logger.Printf(" payload: %#v", notice.Extra) continue } switch msg.Event.Type { case proto.BounceEventType: bounceActivity.WithLabelValues(msg.Room).Inc() case proto.JoinEventType: joinActivity.WithLabelValues(msg.Room).Inc() case proto.PartEventType: partActivity.WithLabelValues(msg.Room).Inc() case proto.SendEventType: messageActivity.WithLabelValues(msg.Room).Inc() } } } }
// EditMessage applies an edit to a stored message in one transaction.
//
// Steps:
//  1. Load the message (room + ID).
//  2. If the message has a previous edit ID, reject the edit with
//     proto.ErrEditInconsistent unless edit.PreviousEditID matches it.
//  3. Write a MessageEditLog entry capturing the prior edit ID, content, and
//     parent, plus the editor's identity when session is non-nil.
//  4. Update the message. edited and previous_edit_id are always set; content
//     and parent only when provided (non-empty / non-zero); deleted is set or
//     cleared when edit.Delete differs from the current deleted state.
//  5. If edit.Announce is set, broadcast an EditMessageEvent within the same
//     transaction.
//
// Any failure after Begin and before Commit rolls back (rollback failures are
// only logged). On success the reply carries the new edit ID and the updated
// message.
func (rb *RoomBinding) EditMessage(
	ctx scope.Context, session proto.Session, edit proto.EditMessageCommand) (
	proto.EditMessageReply, error) {

	var reply proto.EditMessageReply

	editID, err := snowflake.New()
	if err != nil {
		return reply, err
	}

	cols, err := allColumns(rb.DbMap, Message{}, "")
	if err != nil {
		return reply, err
	}

	t, err := rb.DbMap.Begin()
	if err != nil {
		return reply, err
	}

	rollback := func() {
		if err := t.Rollback(); err != nil {
			logging.Logger(ctx).Printf("rollback error: %s", err)
		}
	}

	var msg Message
	err = t.SelectOne(&msg,
		fmt.Sprintf("SELECT %s FROM message WHERE room = $1 AND id = $2", cols),
		rb.RoomName, edit.ID.String())
	if err != nil {
		rollback()
		return reply, err
	}

	// Optimistic consistency check: only enforced once the message has been
	// edited at least once (PreviousEditID.Valid).
	if msg.PreviousEditID.Valid && msg.PreviousEditID.String != edit.PreviousEditID.String() {
		rollback()
		return reply, proto.ErrEditInconsistent
	}

	// NOTE(review): PreviousParent is always marked Valid, even when msg.Parent
	// is empty — confirm whether top-level messages should log NULL instead.
	entry := &MessageEditLog{
		EditID:          editID.String(),
		Room:            rb.RoomName,
		MessageID:       edit.ID.String(),
		PreviousEditID:  msg.PreviousEditID,
		PreviousContent: msg.Content,
		PreviousParent: sql.NullString{
			String: msg.Parent,
			Valid:  true,
		},
	}
	// TODO: tests pass in a nil session, until we add support for the edit command
	if session != nil {
		entry.EditorID = sql.NullString{
			String: string(session.Identity().ID()),
			Valid:  true,
		}
	}
	if err := t.Insert(entry); err != nil {
		rollback()
		return reply, err
	}

	// Build the UPDATE incrementally. $1/$2 are room/id for the WHERE clause,
	// $3/$4 are edited/previous_edit_id; each optional column takes the next
	// placeholder number (len(args) right after appending its value). msg is
	// updated in step so the broadcast and reply reflect the new state.
	now := time.Time(proto.Now())
	sets := []string{"edited = $3", "previous_edit_id = $4"}
	args := []interface{}{rb.RoomName, edit.ID.String(), now, editID.String()}
	msg.Edited = gorp.NullTime{Valid: true, Time: now}
	if edit.Content != "" {
		args = append(args, edit.Content)
		sets = append(sets, fmt.Sprintf("content = $%d", len(args)))
		msg.Content = edit.Content
	}
	if edit.Parent != 0 {
		args = append(args, edit.Parent.String())
		sets = append(sets, fmt.Sprintf("parent = $%d", len(args)))
		msg.Parent = edit.Parent.String()
	}
	// Only touch deleted when the requested state differs from the current one.
	if edit.Delete != msg.Deleted.Valid {
		if edit.Delete {
			args = append(args, now)
			sets = append(sets, fmt.Sprintf("deleted = $%d",
				len(args)))
			msg.Deleted = gorp.NullTime{Valid: true, Time: now}
		} else {
			sets = append(sets, "deleted = NULL")
			msg.Deleted.Valid = false
		}
	}
	query := fmt.Sprintf("UPDATE message SET %s WHERE room = $1 AND id = $2", strings.Join(sets, ", "))
	if _, err := t.Exec(query, args...); err != nil {
		rollback()
		return reply, err
	}

	if edit.Announce {
		event := &proto.EditMessageEvent{
			EditID:  editID,
			Message: msg.ToTransmission(),
		}
		err = rb.broadcast(ctx, t, proto.EditMessageEventType, event, session)
		if err != nil {
			rollback()
			return reply, err
		}
	}

	if err := t.Commit(); err != nil {
		return reply, err
	}

	reply.EditID = editID
	reply.Message = msg.ToTransmission()
	return reply, nil
}