Esempio n. 1
0
File: room.go Progetto: robot0x/heim
// GenerateMessageKey generates a new message key for the bound room and
// stores the key row and the room/key association row in one transaction.
//
// It returns the generated key binding on success. If either insert fails the
// transaction is rolled back (a rollback failure is only logged) and the
// insert error is returned.
func (rb *RoomBinding) GenerateMessageKey(ctx scope.Context, kms security.KMS) (
	proto.RoomMessageKey, error) {

	keyBinding, err := rb.Room.generateMessageKey(rb.Backend, kms)
	if err != nil {
		return nil, err
	}

	tx, err := rb.DbMap.Begin()
	if err != nil {
		return nil, err
	}

	// abort rolls the transaction back, logging rather than returning any
	// failure so the caller sees the original error.
	abort := func() {
		if rerr := tx.Rollback(); rerr != nil {
			backend.Logger(ctx).Printf("rollback error: %s", rerr)
		}
	}

	// Insert the key, then the room association, in that order.
	err = tx.Insert(&keyBinding.MessageKey)
	if err != nil {
		abort()
		return nil, err
	}

	err = tx.Insert(&keyBinding.RoomMessageKey)
	if err != nil {
		abort()
		return nil, err
	}

	err = tx.Commit()
	if err != nil {
		return nil, err
	}

	return keyBinding, nil
}
Esempio n. 2
0
// Register creates a new in-memory account for the personal identity
// namespace:id, protected by password, and tries to associate it with the
// agent identified by agentID.
//
// It returns proto.ErrPersonalIdentityInUse if the identity is already
// registered, or the error from NewAccount. Failures while locating the agent
// or setting its client key are only logged; the account is still returned.
// The backend lock is held for the whole call.
func (m *accountManager) Register(
	ctx scope.Context, kms security.KMS, namespace, id, password string,
	agentID string, agentKey *security.ManagedKey) (
	proto.Account, *security.ManagedKey, error) {

	m.b.Lock()
	defer m.b.Unlock()

	identityKey := fmt.Sprintf("%s:%s", namespace, id)
	if _, taken := m.b.accountIDs[identityKey]; taken {
		return nil, nil, proto.ErrPersonalIdentityInUse
	}

	account, clientKey, err := NewAccount(kms, password)
	if err != nil {
		return nil, nil, err
	}

	// Create the account table on first use, then record the account.
	if m.b.accounts == nil {
		m.b.accounts = map[snowflake.Snowflake]proto.Account{}
	}
	m.b.accounts[account.ID()] = account

	identity := &personalIdentity{
		accountID: account.ID(),
		namespace: namespace,
		id:        id,
	}
	account.(*memAccount).personalIdentities = []proto.PersonalIdentity{identity}

	// Same lazy initialisation for the identity index.
	if m.b.accountIDs == nil {
		m.b.accountIDs = map[string]*personalIdentity{}
	}
	m.b.accountIDs[identityKey] = identity

	agent, err := m.b.AgentTracker().Get(ctx, agentID)
	if err != nil {
		backend.Logger(ctx).Printf(
			"error locating agent %s for new account %s:%s: %s", agentID, namespace, id, err)
		return account, clientKey, nil
	}

	if err := agent.SetClientKey(agentKey, clientKey); err != nil {
		backend.Logger(ctx).Printf(
			"error associating agent %s with new account %s:%s: %s", agentID, namespace, id, err)
	}
	// The account ID is recorded on the agent even when SetClientKey failed.
	agent.AccountID = account.ID().String()

	return account, clientKey, nil
}
Esempio n. 3
0
// Serve registers the Prometheus handler at /metrics on the default mux,
// listens on addr, and blocks serving HTTP until the server stops.
//
// The listener is closed when ctx is done or when serving ends, whichever
// comes first. Listen and serve errors terminate ctx. The context's WaitGroup
// is released on every exit path, including a failed listen.
func Serve(ctx scope.Context, addr string) {
	// Release the WaitGroup no matter how we leave, so a shutdown waiting on
	// it cannot hang when listening fails.
	defer ctx.WaitGroup().Done()

	http.Handle("/metrics", prometheus.Handler())

	listener, err := net.Listen("tcp", addr)
	if err != nil {
		// There is no listener to serve on; continuing would use a nil
		// listener.
		ctx.Terminate(err)
		return
	}

	// The listener may be closed from the watcher goroutine or from this
	// function; make sure it happens exactly once.
	var closeOnce sync.Once
	closeListener := func() {
		closeOnce.Do(func() {
			listener.Close()
		})
	}

	// Spin off goroutine to watch ctx and close listener if shutdown requested.
	go func() {
		<-ctx.Done()
		closeListener()
	}()

	backend.Logger(ctx).Printf("serving /metrics on %s", addr)
	if err := http.Serve(listener, nil); err != nil {
		fmt.Printf("http[%s]: %s\n", addr, err)
		ctx.Terminate(err)
	}

	closeListener()
}
Esempio n. 4
0
// part removes session from room: it deletes this server's presence row for
// the session, broadcasts a part event, commits, and finally drops the
// session from the room's local listener map.
//
// Any failure before the commit rolls the transaction back and is returned;
// the local listener map is only touched after a successful commit.
func (b *Backend) part(ctx scope.Context, room *Room, session proto.Session) error {
	tx, err := b.DbMap.Begin()
	if err != nil {
		return err
	}

	const deletePresence = "DELETE FROM presence WHERE room = $1 AND server_id = $2 AND server_era = $3 AND session_id = $4"
	if _, err := tx.Exec(deletePresence, room.Name, b.desc.ID, b.desc.Era, session.ID()); err != nil {
		rollback(ctx, tx)
		backend.Logger(ctx).Printf("failed to persist departure: %s", err)
		return err
	}

	// Broadcast a presence event.
	// TODO: make this an explicit action via the Room protocol, to support encryption
	departure := proto.PresenceEvent(*session.View())
	if err := room.broadcast(ctx, tx, proto.PartEventType, departure, session); err != nil {
		rollback(ctx, tx)
		return err
	}

	if err := tx.Commit(); err != nil {
		return err
	}

	// Forget the session locally, if the room has any listeners registered.
	b.Lock()
	defer b.Unlock()
	if sessions, found := b.listeners[room.Name]; found {
		delete(sessions, session.ID())
	}

	return nil
}
Esempio n. 5
0
// ChangeClientKey re-encrypts the account's user key from oldKey to newKey
// and stores the updated MAC and encrypted user key.
//
// The account is loaded and updated within one transaction. It returns
// proto.ErrAccountNotFound if the account row does not exist (either on read
// or because the update touched no rows), and otherwise the first error
// encountered. Every failure before the commit rolls the transaction back.
func (b *AccountManagerBinding) ChangeClientKey(
	ctx scope.Context, accountID snowflake.Snowflake, oldKey, newKey *security.ManagedKey) error {

	tx, err := b.DbMap.Begin()
	if err != nil {
		return err
	}

	// abort rolls back and logs, but never returns, a rollback failure.
	abort := func() {
		if rerr := tx.Rollback(); rerr != nil {
			backend.Logger(ctx).Printf("rollback error: %s", rerr)
		}
	}

	var account Account
	err = tx.SelectOne(
		&account,
		"SELECT nonce, mac, encrypted_user_key, encrypted_private_key FROM account WHERE id = $1",
		accountID.String())
	switch {
	case err == nil:
		// Account loaded; carry on.
	case err == sql.ErrNoRows:
		abort()
		return proto.ErrAccountNotFound
	default:
		abort()
		return err
	}

	sec := account.Bind(b.Backend).accountSecurity()
	if err := sec.ChangeClientKey(oldKey, newKey); err != nil {
		abort()
		return err
	}

	res, err := tx.Exec(
		"UPDATE account SET mac = $2, encrypted_user_key = $3 WHERE id = $1",
		accountID.String(), sec.MAC, sec.UserKey.Ciphertext)
	if err != nil {
		abort()
		return err
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		abort()
		return err
	case affected == 0:
		// The row vanished between the read and the update.
		abort()
		return proto.ErrAccountNotFound
	}

	return tx.Commit()
}
Esempio n. 6
0
// RenameUser logs the rename of session from formerName to its identity's
// current name, builds the corresponding nick event and broadcasts it to the
// room.
//
// The event is returned together with the broadcast result, so it is non-nil
// even when the broadcast fails.
func (r *memRoom) RenameUser(
	ctx scope.Context, session proto.Session, formerName string) (*proto.NickEvent, error) {

	backend.Logger(ctx).Printf(
		"renaming %s from %s to %s\n", session.ID(), formerName, session.Identity().Name())

	event := &proto.NickEvent{
		SessionID: session.ID(),
		ID:        session.Identity().ID(),
		From:      formerName,
		To:        session.Identity().Name(),
	}

	err := r.broadcast(ctx, proto.NickType, event, session)
	return event, err
}
Esempio n. 7
0
// Resolve looks up the account that owns the personal identity namespace:id
// inside its own transaction.
//
// It returns the error from starting the transaction, from the lookup (after
// rolling back), or from the commit. A rollback failure is only logged.
func (b *AccountManagerBinding) Resolve(ctx scope.Context, namespace, id string) (proto.Account, error) {
	t, err := b.DbMap.Begin()
	if err != nil {
		// Without a transaction there is nothing to query or roll back.
		return nil, err
	}

	account, err := b.resolve(t, namespace, id)
	if err != nil {
		if rerr := t.Rollback(); rerr != nil {
			// Log the rollback failure itself, not the lookup error.
			backend.Logger(ctx).Printf("rollback error: %s", rerr)
		}
		return nil, err
	}

	if err := t.Commit(); err != nil {
		return nil, err
	}
	return account, nil
}
Esempio n. 8
0
// Get loads the account with the given id inside its own transaction.
//
// It returns the error from starting the transaction, from the lookup (after
// rolling back), or from the commit. A rollback failure is only logged.
func (b *AccountManagerBinding) Get(ctx scope.Context, id snowflake.Snowflake) (proto.Account, error) {
	t, err := b.DbMap.Begin()
	if err != nil {
		// Without a transaction there is nothing to query or roll back.
		return nil, err
	}

	account, err := b.get(t, id)
	if err != nil {
		if rerr := t.Rollback(); rerr != nil {
			// Log the rollback failure itself, not the lookup error.
			backend.Logger(ctx).Printf("rollback error: %s", rerr)
		}
		return nil, err
	}

	if err := t.Commit(); err != nil {
		return nil, err
	}
	return account, nil
}
Esempio n. 9
0
// banAgent stores a ban for agentID in this room and then broadcasts a
// "banned" bounce event for that agent.
//
// The ban's expiry is only marked valid when until is non-zero. The ban is
// written with an insert-then-update loop; any error other than a duplicate
// key on insert, or any update error, rolls back t and is returned. On
// success the result of the broadcast is returned.
func (rb *RoomBinding) banAgent(ctx scope.Context, agentID proto.UserID, until time.Time) error {
	ban := &BannedAgent{
		AgentID: agentID.String(),
		Room: sql.NullString{
			String: rb.Name,
			Valid:  true,
		},
		Created: time.Now(),
		Expires: gorp.NullTime{
			Time:  until,
			Valid: !until.IsZero(),
		},
	}

	// Loop within transaction in read committed mode to simulate UPSERT.
	// NOTE(review): the Insert and Update below are issued through rb.DbMap,
	// not through t, so t itself carries no statements — confirm whether this
	// is intentional before changing it.
	t, err := rb.DbMap.Begin()
	if err != nil {
		return err
	}
	// rollback aborts t, logging rather than returning a rollback failure.
	rollback := func() {
		if err := t.Rollback(); err != nil {
			backend.Logger(ctx).Printf("rollback error: %s", err)
		}
	}
	for {
		// Try to insert; if this fails due to duplicate key value, try to update.
		// Duplicate keys are recognised by the error message prefix.
		if err := rb.DbMap.Insert(ban); err != nil {
			if !strings.HasPrefix(err.Error(), "pq: duplicate key value") {
				rollback()
				return err
			}
		} else {
			break
		}
		n, err := rb.DbMap.Update(ban)
		if err != nil {
			rollback()
			return err
		}
		// Zero rows updated means the update did not apply either; retry the
		// insert.
		if n > 0 {
			break
		}
	}
	if err := t.Commit(); err != nil {
		return err
	}

	// Broadcast the bounce for the banned agent once the ban is stored.
	bounceEvent := &proto.BounceEvent{Reason: "banned", AgentID: agentID.String()}
	return rb.broadcast(ctx, rb.Room, proto.BounceEventType, bounceEvent)
}
Esempio n. 10
0
// Run dispatches to the subcommand named by args[0], passing it the remaining
// arguments after flag parsing, and then waits up to ten seconds for the
// command's context WaitGroup to drain.
//
// With no arguments it prints general help and returns. In every other case
// the process exits: status 2 for an unknown command or a flag parsing error,
// 1 if the command fails or shutdown times out, 0 on a clean shutdown.
func Run(args []string) {
	out = tabwriter.NewWriter(os.Stdout, 0, 8, 1, '\t', 0)

	if len(args) == 0 {
		generalHelp()
		return
	}

	exe := filepath.Base(os.Args[0])
	name := args[0]

	cmd, known := subcommands[name]
	if !known {
		fmt.Fprintf(os.Stderr, "%s: invalid command: %s\n", exe, name)
		fmt.Fprintf(os.Stderr, "Run '%s help' for usage.\n", exe)
		os.Exit(2)
	}

	flags := cmd.flags()
	if err := flags.Parse(args[1:]); err != nil {
		fmt.Fprintf(os.Stderr, "%s %s: %s\n", exe, name, err)
		os.Exit(2)
	}

	ctx := backend.LoggingContext(scope.New(), fmt.Sprintf("[%s] ", name))
	backend.Logger(ctx).Printf("starting up")
	if err := cmd.run(ctx, flags.Args()); err != nil {
		fmt.Fprintln(os.Stderr, err.Error())
		os.Exit(1)
	}

	// Start the shutdown clock before waiting on outstanding work.
	deadline := time.After(10 * time.Second)
	drained := make(chan struct{})
	go func() {
		defer close(drained)
		ctx.WaitGroup().Wait()
	}()

	fmt.Println("waiting for graceful shutdown...")
	select {
	case <-drained:
		fmt.Println("ok")
		os.Exit(0)
	case <-deadline:
		fmt.Println("timed out")
		os.Exit(1)
	}
}
Esempio n. 11
0
// invalidatePeer must be called with lock held.
//
// It builds a "partition" network event for the server identified by id and
// era and broadcasts it to every listener map known to this backend. Errors
// are logged, never returned; a failure to build the event aborts the
// broadcast, while a failed broadcast to one room does not stop the others.
func (b *Backend) invalidatePeer(ctx scope.Context, id, era string) {
	lg := backend.Logger(ctx)

	partition := &proto.NetworkEvent{
		Type:      "partition",
		ServerID:  id,
		ServerEra: era,
	}
	event, err := proto.MakeEvent(partition)
	if err != nil {
		lg.Printf("cluster: make network event error: %s", err)
		return
	}

	for _, roomListeners := range b.listeners {
		err := roomListeners.Broadcast(ctx, event)
		if err == nil {
			continue
		}
		lg.Printf("cluster: network event error: %s", err)
	}
}
Esempio n. 12
0
// RequestPasswordReset resolves the account owning namespace:id, generates a
// password reset request for it and stores that request, all within one
// transaction.
//
// It returns the resolved account and the generated request. Any failure
// before the commit rolls the transaction back and is returned. A failed
// commit is returned as-is: the transaction is already finished at that
// point, so attempting a rollback would only log a spurious error.
func (b *AccountManagerBinding) RequestPasswordReset(
	ctx scope.Context, kms security.KMS, namespace, id string) (
	proto.Account, *proto.PasswordResetRequest, error) {

	t, err := b.DbMap.Begin()
	if err != nil {
		return nil, nil, err
	}

	// rollback aborts t, logging rather than returning a rollback failure.
	rollback := func() {
		if err := t.Rollback(); err != nil {
			backend.Logger(ctx).Printf("rollback error: %s", err)
		}
	}

	account, err := b.resolve(t, namespace, id)
	if err != nil {
		rollback()
		return nil, nil, err
	}

	req, err := proto.GeneratePasswordResetRequest(kms, account.ID())
	if err != nil {
		rollback()
		return nil, nil, err
	}

	// Persist the request using the string forms of its IDs.
	stored := &PasswordResetRequest{
		ID:        req.ID.String(),
		AccountID: req.AccountID.String(),
		Key:       req.Key,
		Requested: req.Requested,
		Expires:   req.Expires,
	}
	if err := t.Insert(stored); err != nil {
		rollback()
		return nil, nil, err
	}

	if err := t.Commit(); err != nil {
		return nil, nil, err
	}

	return account, req, nil
}
Esempio n. 13
0
// Register stores agent as a new row in the database.
//
// The encrypted client key is copied only when the agent has one. It returns
// proto.ErrAgentAlreadyExists when the insert fails with a duplicate-key
// error (recognised by message prefix), any other insert error unchanged, and
// nil after logging a successful registration.
func (atb *AgentTrackerBinding) Register(ctx scope.Context, agent *proto.Agent) error {
	row := &Agent{
		ID:      agent.IDString(),
		IV:      agent.IV,
		MAC:     agent.MAC,
		Created: agent.Created,
	}
	if key := agent.EncryptedClientKey; key != nil {
		row.EncryptedClientKey = key.Ciphertext
	}

	err := atb.Backend.DbMap.Insert(row)
	switch {
	case err == nil:
		backend.Logger(ctx).Printf("registered agent %s", agent.IDString())
		return nil
	case strings.HasPrefix(err.Error(), "pq: duplicate key value"):
		return proto.ErrAgentAlreadyExists
	default:
		return err
	}
}
Esempio n. 14
0
File: room.go Progetto: robot0x/heim
// RemoveManager revokes formerManager's manager capability on this room,
// acting on behalf of actor, within a single transaction.
//
// The actor's authority is checked first: a missing capability is reported as
// proto.ErrAccessDenied. If the revocation fails because the former manager's
// capability is missing or access is denied, proto.ErrManagerNotFound is
// returned. Every failure before the commit rolls the transaction back.
func (rb *RoomBinding) RemoveManager(
	ctx scope.Context, actor proto.Account, actorKey *security.ManagedKey,
	formerManager proto.Account) error {

	tx, err := rb.Backend.DbMap.Begin()
	if err != nil {
		return err
	}

	// abort rolls back and logs, but never returns, a rollback failure.
	abort := func() {
		if rerr := tx.Rollback(); rerr != nil {
			backend.Logger(ctx).Printf("rollback error: %s", rerr)
		}
	}

	// Run all manager-key operations through the transaction.
	managerKeys := NewRoomManagerKeyBinding(rb)
	managerKeys.SetExecutor(tx)

	_, _, _, err = managerKeys.Authority(ctx, actor, actorKey)
	if err != nil {
		abort()
		if err == proto.ErrCapabilityNotFound {
			return proto.ErrAccessDenied
		}
		return err
	}

	switch err := managerKeys.RevokeFromAccount(ctx, formerManager); err {
	case nil:
		// Revoked; fall through to commit.
	case proto.ErrCapabilityNotFound, proto.ErrAccessDenied:
		abort()
		return proto.ErrManagerNotFound
	default:
		abort()
		return err
	}

	return tx.Commit()
}
Esempio n. 15
0
// SetClientKey loads the agent identified by agentID, encrypts clientKey for
// it using accessKey, and stores the resulting encrypted key together with
// accountID on the agent's row.
//
// Both the read and the write go through the same transaction, so the key is
// derived from the row state that the update applies to. Any failure before
// the commit rolls the transaction back and is returned.
func (atb *AgentTrackerBinding) SetClientKey(
	ctx scope.Context, agentID string, accessKey *security.ManagedKey,
	accountID snowflake.Snowflake, clientKey *security.ManagedKey) error {

	t, err := atb.Backend.DbMap.Begin()
	if err != nil {
		return err
	}

	// rollback aborts t, logging rather than returning a rollback failure.
	rollback := func() {
		if err := t.Rollback(); err != nil {
			backend.Logger(ctx).Printf("rollback error: %s", err)
		}
	}

	// Read through the transaction rather than the bare DbMap, so the lookup
	// and the update below share one transaction.
	agent, err := atb.getFromDB(agentID, t)
	if err != nil {
		rollback()
		return err
	}

	if err := agent.SetClientKey(accessKey, clientKey); err != nil {
		rollback()
		return err
	}

	err = atb.setClientKeyInDB(
		agentID, accountID.String(), agent.EncryptedClientKey.Ciphertext, t)
	if err != nil {
		rollback()
		return err
	}

	if err := t.Commit(); err != nil {
		return err
	}

	return nil
}
Esempio n. 16
0
// ScanLoop consumes notifications from listener until ctx is done, decoding
// each payload as a broadcast message and incrementing the per-room activity
// counter that matches the event type (bounce, join, part or send).
//
// Nil notifications and payloads that fail to decode are logged and skipped.
// Event types without a counter are ignored. The context's WaitGroup is
// released when the loop exits.
func ScanLoop(ctx scope.Context, listener *pq.Listener) {
	defer ctx.WaitGroup().Done()

	lg := backend.Logger(ctx)
	for {
		select {
		case notice := <-listener.Notify:
			if notice == nil {
				lg.Printf("received nil from listener")
				continue
			}

			var broadcast psql.BroadcastMessage
			raw := []byte(notice.Extra)
			if err := json.Unmarshal(raw, &broadcast); err != nil {
				lg.Printf("error: pq listen: invalid broadcast: %s", err)
				lg.Printf("         payload: %#v", notice.Extra)
				continue
			}

			// Count the event against the room it happened in.
			room := broadcast.Room
			switch broadcast.Event.Type {
			case proto.SendEventType:
				messageActivity.WithLabelValues(room).Inc()
			case proto.JoinEventType:
				joinActivity.WithLabelValues(room).Inc()
			case proto.PartEventType:
				partActivity.WithLabelValues(room).Inc()
			case proto.BounceEventType:
				bounceActivity.WithLabelValues(room).Inc()
			}
		case <-ctx.Done():
			lg.Printf("received cancellation signal, shutting down")
			return
		}
	}
}
Esempio n. 17
0
// ConfirmPasswordReset validates a password reset confirmation code and, if
// it is valid, sets the account's password to password.
//
// The confirmation is parsed into a request ID and MAC. Within one
// transaction the matching unexpired, unconsumed, non-invalidated request and
// its account are loaded, the MAC is verified, the account's MAC and
// encrypted user key are updated, the request is marked consumed and all
// other requests for the account are marked invalidated.
//
// It returns proto.ErrInvalidConfirmationCode when the MAC does not verify or
// no account could be loaded, and otherwise the first error encountered.
// Every failure after the transaction has begun and before the commit rolls
// it back.
func (b *AccountManagerBinding) ConfirmPasswordReset(
	ctx scope.Context, kms security.KMS, confirmation, password string) error {

	id, mac, err := proto.ParsePasswordResetConfirmation(confirmation)
	if err != nil {
		return err
	}

	t, err := b.DbMap.Begin()
	if err != nil {
		return err
	}

	// rollback aborts t, logging rather than returning a rollback failure.
	rollback := func() {
		if err := t.Rollback(); err != nil {
			backend.Logger(ctx).Printf("rollback error: %s", err)
		}
	}

	req := &proto.PasswordResetRequest{
		ID: id,
	}

	var (
		stored  PasswordResetRequest
		account *AccountBinding
	)

	cols, err := allColumns(b.DbMap, PasswordResetRequest{}, "")
	if err != nil {
		// The transaction is already open; do not leak it.
		rollback()
		return err
	}
	err = t.SelectOne(
		&stored,
		fmt.Sprintf(
			"SELECT %s FROM password_reset_request WHERE id = $1 AND expires > NOW() AND invalidated IS NULL AND consumed IS NULL",
			cols),
		id.String())
	if err != nil && err != sql.ErrNoRows {
		rollback()
		return err
	}

	// A missing request is not reported here; it falls through to the MAC
	// check below with an empty key and a nil account.
	if err == nil {
		req.Key = stored.Key
		if err := req.AccountID.FromString(stored.AccountID); err == nil {
			account, err = b.get(t, req.AccountID)
			if err != nil && err != proto.ErrAccountNotFound {
				rollback()
				return err
			}
		}
	}

	if !req.VerifyMAC(mac) || account == nil {
		rollback()
		fmt.Printf("invalid mac or no account (%#v)\n", account)
		return proto.ErrInvalidConfirmationCode
	}

	sec, err := account.accountSecurity().ResetPassword(kms, password)
	if err != nil {
		rollback()
		fmt.Printf("reset password failed: %s\n", err)
		return err
	}

	_, err = t.Exec(
		"UPDATE account SET mac = $2, encrypted_user_key = $3 WHERE id = $1",
		account.ID().String(), sec.MAC, sec.UserKey.Ciphertext)
	if err != nil {
		rollback()
		fmt.Printf("update 1 failed: %s\n", err)
		return err
	}

	_, err = t.Exec("UPDATE password_reset_request SET consumed = NOW() where id = $1", id.String())
	if err != nil {
		rollback()
		fmt.Printf("update 2 failed: %s\n", err)
		return err
	}

	// Bind the request ID in its string form, as the other queries on this
	// table do, so the comparison uses the same representation as the stored
	// IDs.
	_, err = t.Exec(
		"UPDATE password_reset_request SET invalidated = NOW() where account_id = $1 AND id != $2",
		account.ID().String(), id.String())
	if err != nil {
		rollback()
		fmt.Printf("update 3 failed: %s\n", err)
		return err
	}

	if err := t.Commit(); err != nil {
		fmt.Printf("commit failed: %s\n", err)
		return err
	}

	return nil
}
Esempio n. 18
0
// Broadcast delivers event to the listeners in lm, skipping any session IDs
// listed in exclude.
//
// Special cases:
//   - Bounce events are not delivered as such. Instead a disconnect event
//     carrying the bounce reason is sent to listeners whose identity ID
//     matches the bounced agent ID or, when no agent ID is given, whose
//     client IP matches the bounced IP. Errors sending disconnects are only
//     logged.
//   - Join events enable the excluded listeners, and listeners whose session
//     ID starts with the joining agent's ID prefix are checked for
//     abandonment. Errors from that check are only logged.
//   - Listeners that are not yet enabled do not receive the event.
//
// It returns an error if the event's payload cannot be obtained, or the first
// error from sending the event to a listener, at which point delivery to the
// remaining listeners stops.
func (lm ListenerMap) Broadcast(ctx scope.Context, event *proto.Packet, exclude ...string) error {
	payload, err := event.Payload()
	if err != nil {
		return err
	}

	excludeSet := map[string]struct{}{}
	for _, exc := range exclude {
		excludeSet[exc] = struct{}{}
	}

	// Inspect packet to see if it's a bounce event. If so, we'll deliver it
	// only to the bounced parties.
	bounceAgentID := ""
	bounceIP := ""
	if event.Type == proto.BounceEventType {
		if bounceEvent, ok := payload.(*proto.BounceEvent); ok {
			bounceAgentID = bounceEvent.AgentID
			bounceIP = bounceEvent.IP
		} else {
			backend.Logger(ctx).Printf("wtf? expected *proto.BounceEvent, got %T", payload)
		}
	}

	// Inspect packet to see if it's a join event. If so, we'll enable the excluded
	// listener, and look for aliased sessions to kick into fast-keepalive mode.
	fastKeepaliveAgentID := ""
	if event.Type == proto.JoinEventType {
		if presence, ok := payload.(*proto.PresenceEvent); ok {
			// The agent ID is the part of the presence ID before the first '-'.
			if idx := strings.IndexRune(string(presence.ID), '-'); idx >= 1 {
				fastKeepaliveAgentID = string(presence.ID[:idx])
			}
		}
		for _, sessionID := range exclude {
			listener, ok := lm[sessionID]
			if ok && !listener.enabled {
				// Map values are copies; store the modified listener back.
				listener.enabled = true
				lm[sessionID] = listener
			}
		}
	}

	for sessionID, listener := range lm {
		if _, excluded := excludeSet[sessionID]; excluded {
			continue
		}

		// bounceAgentID/bounceIP are only set when payload is a
		// *proto.BounceEvent, so the type assertions below cannot fail.
		if bounceAgentID != "" {
			if listener.Session.Identity().ID().String() == bounceAgentID {
				backend.Logger(ctx).Printf("sending disconnect to %s: %#v", listener.ID(), payload)
				discEvent := &proto.DisconnectEvent{Reason: payload.(*proto.BounceEvent).Reason}
				if err := listener.Send(ctx, proto.DisconnectEventType, discEvent); err != nil {
					backend.Logger(ctx).Printf("error sending disconnect event to %s: %s",
						listener.ID(), err)
				}
			}
			continue
		}
		if bounceIP != "" {
			if listener.Client.IP == bounceIP {
				backend.Logger(ctx).Printf("sending disconnect to %s: %#v", listener.ID(), payload)
				discEvent := &proto.DisconnectEvent{Reason: payload.(*proto.BounceEvent).Reason}
				if err := listener.Send(ctx, proto.DisconnectEventType, discEvent); err != nil {
					backend.Logger(ctx).Printf("error sending disconnect event to %s: %s",
						listener.ID(), err)
				}
			}
			continue
		}

		if fastKeepaliveAgentID != "" && strings.HasPrefix(sessionID, fastKeepaliveAgentID) {
			if err := listener.CheckAbandoned(); err != nil {
				// Best effort: report the failure instead of silently
				// discarding it, but keep delivering the event.
				backend.Logger(ctx).Printf("fast keepalive to %s: %s", listener.ID(), err)
			}
		}
		if !listener.enabled {
			// The event occurred before the listener joined, so don't deliver it.
			backend.Logger(ctx).Printf("not delivering event %s before %s joined",
				event.Type, listener.ID())
			continue
		}
		if err := listener.Send(ctx, event.Type, payload); err != nil {
			// TODO: accumulate errors
			return fmt.Errorf("send message to %s: %s", listener.ID(), err)
		}
	}

	return nil
}
Esempio n. 19
0
// join admits session to room.
//
// Within one transaction it rejects the session with proto.ErrAccessDenied if
// an active agent ban or IP ban applies (globally or for this room), replaces
// the session's log entry, stores a presence row for this server, registers
// the session in the room's local listener map and broadcasts a join event.
//
// Client data must be present in ctx. Every failure after the transaction has
// begun and before the commit rolls the transaction back.
func (b *Backend) join(ctx scope.Context, room *Room, session proto.Session) error {
	client := &proto.Client{}
	if !client.FromContext(ctx) {
		return fmt.Errorf("client data not found in scope")
	}

	// Resolve column lists before opening the transaction so these failures
	// need no rollback.
	bannedAgentCols, err := allColumns(b.DbMap, BannedAgent{}, "")
	if err != nil {
		return err
	}

	bannedIPCols, err := allColumns(b.DbMap, BannedIP{}, "")
	if err != nil {
		return err
	}

	t, err := b.DbMap.Begin()
	if err != nil {
		return err
	}

	rb := func() { rollback(ctx, t) }

	// Check for agent ID bans.
	agentBans, err := t.Select(
		BannedAgent{},
		fmt.Sprintf(
			"SELECT %s FROM banned_agent WHERE agent_id = $1 AND (room IS NULL OR room = $2) AND (expires IS NULL OR expires > NOW())",
			bannedAgentCols),
		session.Identity().ID().String(), room.Name)
	if err != nil {
		rb()
		return err
	}
	if len(agentBans) > 0 {
		backend.Logger(ctx).Printf("access denied to %s: %#v", session.Identity().ID(), agentBans)
		if err := t.Rollback(); err != nil {
			return err
		}
		return proto.ErrAccessDenied
	}

	// Check for IP bans.
	ipBans, err := t.Select(
		BannedIP{},
		fmt.Sprintf(
			"SELECT %s FROM banned_ip WHERE ip = $1 AND (room IS NULL OR room = $2) AND (expires IS NULL OR expires > NOW())",
			bannedIPCols),
		client.IP, room.Name)
	if err != nil {
		rb()
		return err
	}
	if len(ipBans) > 0 {
		backend.Logger(ctx).Printf("access denied to %s: %#v", client.IP, ipBans)
		if err := t.Rollback(); err != nil {
			return err
		}
		return proto.ErrAccessDenied
	}

	// Write to session log.
	// TODO: do proper upsert simulation
	entry := &SessionLog{
		SessionID: session.ID(),
		IP:        client.IP,
		Room:      room.Name,
		UserAgent: client.UserAgent,
		Connected: client.Connected,
	}
	if _, err := t.Delete(entry); err != nil {
		if rerr := t.Rollback(); rerr != nil {
			backend.Logger(ctx).Printf("rollback error: %s", rerr)
		}
		return err
	}
	if err := t.Insert(entry); err != nil {
		if rerr := t.Rollback(); rerr != nil {
			backend.Logger(ctx).Printf("rollback error: %s", rerr)
		}
		return err
	}

	// Broadcast a presence event.
	// TODO: make this an explicit action via the Room protocol, to support encryption

	presence := &Presence{
		Room:      room.Name,
		ServerID:  b.desc.ID,
		ServerEra: b.desc.Era,
		SessionID: session.ID(),
		Updated:   time.Now(),
	}
	err = presence.SetFact(&proto.Presence{
		SessionView:    *session.View(),
		LastInteracted: presence.Updated,
	})
	if err != nil {
		rb()
		return fmt.Errorf("presence marshal error: %s", err)
	}
	if err := t.Insert(presence); err != nil {
		// Roll back so the open transaction is not leaked on this path.
		rb()
		return fmt.Errorf("presence insert error: %s", err)
	}

	b.Lock()
	// Add session to listeners.
	lm, ok := b.listeners[room.Name]
	if !ok {
		lm = ListenerMap{}
		b.listeners[room.Name] = lm
	}
	lm[session.ID()] = Listener{Session: session, Client: client}
	b.Unlock()

	if err := room.broadcast(ctx, t, proto.JoinEventType, proto.PresenceEvent(*session.View()), session); err != nil {
		rb()
		return err
	}

	if err := t.Commit(); err != nil {
		return err
	}

	return nil
}
Esempio n. 20
0
// GrantStaff gives the account identified by accountID a staff capability
// wrapping kmsCred.
//
// The account's system-encrypted client key is read and decrypted via the KMS
// obtained from kmsCred, a shared-secret capability holding kmsCred is
// created with that key, and finally the capability is stored and linked from
// the account row within a transaction.
//
// It returns proto.ErrAccountNotFound if the account does not exist (on read,
// or when the update does not affect exactly one row), and otherwise the
// first error encountered.
func (b *AccountManagerBinding) GrantStaff(
	ctx scope.Context, accountID snowflake.Snowflake, kmsCred security.KMSCredential) error {

	// Look up the target account's (system) encrypted client key. This is
	// not part of the transaction, because we want to interact with KMS
	// before we proceed. That should be fine, since this is an infrequently
	// used action.
	var keyRow struct {
		EncryptedClientKey []byte `db:"encrypted_system_key"`
		Nonce              []byte `db:"nonce"`
	}
	err := b.DbMap.SelectOne(
		&keyRow, "SELECT encrypted_system_key, nonce FROM account WHERE id = $1", accountID.String())
	if err == sql.ErrNoRows {
		return proto.ErrAccountNotFound
	}
	if err != nil {
		return err
	}

	// Use kmsCred to obtain kms and decrypt the client's key.
	kms := kmsCred.KMS()
	clientKey := &security.ManagedKey{
		KeyType:      proto.ClientKeyType,
		Ciphertext:   keyRow.EncryptedClientKey,
		ContextKey:   "nonce",
		ContextValue: base64.URLEncoding.EncodeToString(keyRow.Nonce),
	}
	if err := kms.DecryptKey(clientKey); err != nil {
		return err
	}

	// Grant staff capability. This involves marshalling kmsCred to JSON and
	// encrypting it with the client key.
	nonce, err := kms.GenerateNonce(clientKey.KeyType.BlockSize())
	if err != nil {
		return err
	}
	capability, err := security.GrantSharedSecretCapability(clientKey, nonce, kmsCred.KMSType(), kmsCred)
	if err != nil {
		return err
	}

	// Store capability and update account table.
	tx, err := b.DbMap.Begin()
	if err != nil {
		return err
	}

	// abort rolls back and logs, but never returns, a rollback failure.
	abort := func() {
		if rerr := tx.Rollback(); rerr != nil {
			backend.Logger(ctx).Printf("rollback error: %s", rerr)
		}
	}

	capRow := &Capability{
		ID:                   capability.CapabilityID(),
		NonceBytes:           capability.Nonce(),
		EncryptedPrivateData: capability.EncryptedPayload(),
		PublicData:           capability.PublicPayload(),
	}
	if err := tx.Insert(capRow); err != nil {
		abort()
		return err
	}

	result, err := tx.Exec(
		"UPDATE account SET staff_capability_id = $2 WHERE id = $1",
		accountID.String(), capability.CapabilityID())
	if err != nil {
		abort()
		return err
	}

	affected, err := result.RowsAffected()
	switch {
	case err != nil:
		abort()
		return err
	case affected != 1:
		// The account row was not there to update.
		abort()
		return proto.ErrAccountNotFound
	}

	return tx.Commit()
}
Esempio n. 21
0
File: room.go Progetto: robot0x/heim
// EditMessage applies edit to a message in this room and records the change
// in the message edit log.
//
// Within one transaction the current message is loaded, the edit is rejected
// with proto.ErrEditInconsistent if the message already has a previous edit
// ID that differs from edit.PreviousEditID, an edit log entry holding the
// previous content and parent is inserted, and the message row is updated.
// Content and parent are only changed when non-empty / non-zero in edit; the
// deleted timestamp is set or cleared only when edit.Delete differs from the
// message's current deleted state.
//
// After a successful commit, an edit event is broadcast if edit.Announce is
// set. The reply carries the new edit ID and the requested delete flag; on
// any error the zero-valued reply is returned with the error.
func (rb *RoomBinding) EditMessage(
	ctx scope.Context, session proto.Session, edit proto.EditMessageCommand) (
	proto.EditMessageReply, error) {

	var reply proto.EditMessageReply

	// Every edit gets its own freshly generated ID.
	editID, err := snowflake.New()
	if err != nil {
		return reply, err
	}

	t, err := rb.DbMap.Begin()
	if err != nil {
		return reply, err
	}

	// rollback aborts t, logging rather than returning a rollback failure.
	rollback := func() {
		if err := t.Rollback(); err != nil {
			backend.Logger(ctx).Printf("rollback error: %s", err)
		}
	}

	// Load the message as it currently stands.
	var msg Message
	err = t.SelectOne(
		&msg,
		"SELECT room, id, previous_edit_id, parent, posted, edited, deleted,"+
			" session_id, sender_id, sender_name, server_id, server_era, content, encryption_key_id"+
			" FROM message WHERE room = $1 AND id = $2",
		rb.Name, edit.ID.String())
	if err != nil {
		rollback()
		return reply, err
	}

	// Reject edits that were based on a different previous edit. Messages
	// that have never been edited skip this check.
	if msg.PreviousEditID.Valid && msg.PreviousEditID.String != edit.PreviousEditID.String() {
		rollback()
		return reply, proto.ErrEditInconsistent
	}

	// Record the pre-edit state in the edit log.
	entry := &MessageEditLog{
		EditID:          editID.String(),
		Room:            rb.Name,
		MessageID:       edit.ID.String(),
		PreviousEditID:  msg.PreviousEditID,
		PreviousContent: msg.Content,
		PreviousParent: sql.NullString{
			String: msg.Parent,
			Valid:  true,
		},
	}
	// TODO: tests pass in a nil session, until we add support for the edit command
	if session != nil {
		entry.EditorID = sql.NullString{
			String: string(session.Identity().ID()),
			Valid:  true,
		}
	}
	if err := t.Insert(entry); err != nil {
		rollback()
		return reply, err
	}

	// Build the UPDATE dynamically. $1 and $2 are the room and message ID used
	// in the WHERE clause; each optional column takes the next placeholder
	// number, which is len(args) right after its value is appended. msg is
	// updated in step so it reflects the post-edit state for the broadcast.
	now := time.Time(proto.Now())
	sets := []string{"edited = $3", "previous_edit_id = $4"}
	args := []interface{}{rb.Name, edit.ID.String(), now, editID.String()}
	msg.Edited = gorp.NullTime{Valid: true, Time: now}
	if edit.Content != "" {
		args = append(args, edit.Content)
		sets = append(sets, fmt.Sprintf("content = $%d", len(args)))
		msg.Content = edit.Content
	}
	if edit.Parent != 0 {
		args = append(args, edit.Parent.String())
		sets = append(sets, fmt.Sprintf("parent = $%d", len(args)))
		msg.Parent = edit.Parent.String()
	}
	// Only touch the deleted column when the requested state differs from the
	// current one.
	if edit.Delete != msg.Deleted.Valid {
		if edit.Delete {
			args = append(args, now)
			sets = append(sets, fmt.Sprintf("deleted = $%d", len(args)))
			msg.Deleted = gorp.NullTime{Valid: true, Time: now}
		} else {
			sets = append(sets, "deleted = NULL")
			msg.Deleted.Valid = false
		}
	}
	query := fmt.Sprintf("UPDATE message SET %s WHERE room = $1 AND id = $2", strings.Join(sets, ", "))
	if _, err := t.Exec(query, args...); err != nil {
		rollback()
		return reply, err
	}

	if err := t.Commit(); err != nil {
		return reply, err
	}

	// Announce the edit only after it has been committed.
	if edit.Announce {
		event := &proto.EditMessageEvent{
			EditID:  editID,
			Message: msg.ToBackend(),
		}
		err = rb.Backend.broadcast(ctx, rb.Room, proto.EditMessageEventType, event, session)
		if err != nil {
			return reply, err
		}
	}

	reply.EditID = editID
	reply.Deleted = edit.Delete
	return reply, nil
}
Esempio n. 22
0
// CreateRoom creates and stores a new room called name, granting each of
// managers a manager capability and, when private is set, an access
// capability for a newly generated message key.
//
// All key material and capabilities are generated before the database
// transaction is opened. The room row, the optional message key rows and the
// capabilities are then inserted in a single transaction; any failure rolls
// it back. On success the room bound to this backend is returned.
func (b *Backend) CreateRoom(
	ctx scope.Context, kms security.KMS, private bool, name string, managers ...proto.Account) (
	proto.Room, error) {

	sec, err := proto.NewRoomSecurity(kms, name)
	if err != nil {
		return nil, err
	}

	backend.Logger(ctx).Printf("creating room: %s", name)
	room := &Room{
		Name:  name,
		IV:    sec.KeyPair.IV,
		MAC:   sec.MAC,
		Nonce: sec.Nonce,
		EncryptedManagementKey: sec.KeyEncryptingKey.Ciphertext,
		EncryptedPrivateKey:    sec.KeyPair.EncryptedPrivateKey,
		PublicKey:              sec.KeyPair.PublicKey,
	}

	// Private rooms get a message key; rmkb stays nil for public rooms.
	var (
		rmkb   *RoomMessageKeyBinding
		msgKey security.ManagedKey
	)
	if private {
		rmkb, err = room.generateMessageKey(b, kms)
		if err != nil {
			return nil, err
		}

		// Decrypt a copy of the message key so it can be granted to managers
		// below.
		msgKey = rmkb.ManagedKey()
		if err := kms.DecryptKey(&msgKey); err != nil {
			return nil, err
		}
	}

	// Generate manager capabilities.
	// Work on a clone so the room's own key-encrypting key is left untouched.
	managerKey := sec.KeyEncryptingKey.Clone()
	if err := kms.DecryptKey(&managerKey); err != nil {
		return nil, fmt.Errorf("manager key decrypt error: %s", err)
	}
	roomKeyPair, err := sec.Unlock(&managerKey)
	if err != nil {
		return nil, fmt.Errorf("room security unlock error: %s", err)
	}
	// managerCaps[i] is the capability for managers[i].
	managerCaps := make([]*security.PublicKeyCapability, len(managers))
	for i, manager := range managers {
		kp := manager.KeyPair()
		c, err := security.GrantPublicKeyCapability(
			kms, sec.Nonce, roomKeyPair, &kp, nil, managerKey.Plaintext)
		if err != nil {
			return nil, fmt.Errorf("manager grant error: %s", err)
		}
		managerCaps[i] = c
	}

	// Access capabilities exist only for private rooms; accessCaps[i] then
	// belongs to managers[i]. For public rooms the slice stays empty.
	accessCaps := []*security.PublicKeyCapability{}
	if private {
		accessCaps = make([]*security.PublicKeyCapability, len(managers))
		for i, manager := range managers {
			kp := manager.KeyPair()
			c, err := security.GrantPublicKeyCapability(
				kms, rmkb.Nonce(), roomKeyPair, &kp, nil, msgKey.Plaintext)
			if err != nil {
				return nil, fmt.Errorf("access grant error: %s", err)
			}
			accessCaps[i] = c
		}
	}

	// Insert data.
	t, err := b.DbMap.Begin()
	if err != nil {
		return nil, err
	}

	// rollback aborts t, logging rather than returning a rollback failure.
	rollback := func() {
		if err := t.Rollback(); err != nil {
			backend.Logger(ctx).Printf("rollback error: %s", err)
		}
	}

	if err := t.Insert(room); err != nil {
		backend.Logger(ctx).Printf("room creation error on %s: %s", name, err)
		rollback()
		return nil, err
	}

	// Only private rooms have message key rows to store.
	if rmkb != nil {
		if err := t.Insert(&rmkb.MessageKey, &rmkb.RoomMessageKey); err != nil {
			backend.Logger(ctx).Printf("room creation error on %s (message key): %s", name, err)
			rollback()
			return nil, err
		}
	}

	// Save capabilities through the transaction so they commit with the room.
	managerCapTable := RoomManagerCapabilities{
		Room:     room,
		Executor: t,
	}
	for i, capability := range managerCaps {
		if err := managerCapTable.Save(ctx, managers[i], capability); err != nil {
			backend.Logger(ctx).Printf(
				"room creation error on %s (manager %s): %s", name, managers[i].ID().String(), err)
			rollback()
			return nil, err
		}
	}

	messageCapTable := RoomMessageCapabilities{
		Room:     room,
		Executor: t,
	}
	for i, capability := range accessCaps {
		if err := messageCapTable.Save(ctx, managers[i], capability); err != nil {
			backend.Logger(ctx).Printf(
				"room creation error on %s (access capability): %s", name, err)
			rollback()
			return nil, err
		}
	}

	if err := t.Commit(); err != nil {
		backend.Logger(ctx).Printf("room creation error on %s (commit): %s", name, err)
		return nil, err
	}

	return room.Bind(b), nil
}
Esempio n. 23
0
// Register creates a new account for the personal identity namespace:id,
// protected by password, and associates it with the agent identified by
// agentID.
//
// The account ID and credentials are generated first. Then, in a single
// transaction, the account and personal identity rows are inserted and the
// agent's client key is set and stored. It returns the bound account and the
// client key on success, proto.ErrPersonalIdentityInUse when the identity
// insert fails with a duplicate-key error (recognised by message prefix), and
// otherwise the first error encountered. Every failure before the commit
// rolls the transaction back.
func (b *AccountManagerBinding) Register(
	ctx scope.Context, kms security.KMS, namespace, id, password string,
	agentID string, agentKey *security.ManagedKey) (
	proto.Account, *security.ManagedKey, error) {

	// Generate ID for new account.
	accountID, err := snowflake.New()
	if err != nil {
		return nil, nil, err
	}

	// Generate credentials in advance of working in DB transaction.
	sec, clientKey, err := proto.NewAccountSecurity(kms, password)
	if err != nil {
		return nil, nil, err
	}

	// Begin transaction to check on identity availability and store new account data.
	tx, err := b.DbMap.Begin()
	if err != nil {
		return nil, nil, err
	}

	// abort rolls back and logs, but never returns, a rollback failure.
	abort := func() {
		if rerr := tx.Rollback(); rerr != nil {
			backend.Logger(ctx).Printf("rollback error: %s", rerr)
		}
	}

	// Insert new rows for account.
	account := &Account{
		ID:                  accountID.String(),
		Nonce:               sec.Nonce,
		MAC:                 sec.MAC,
		EncryptedSystemKey:  sec.SystemKey.Ciphertext,
		EncryptedUserKey:    sec.UserKey.Ciphertext,
		EncryptedPrivateKey: sec.KeyPair.EncryptedPrivateKey,
		PublicKey:           sec.KeyPair.PublicKey,
	}
	identityRow := &PersonalIdentity{
		Namespace: namespace,
		ID:        id,
		AccountID: accountID.String(),
	}

	err = tx.Insert(account)
	if err != nil {
		abort()
		return nil, nil, err
	}

	err = tx.Insert(identityRow)
	if err != nil {
		abort()
		// A duplicate key here means the identity is already taken.
		if strings.HasPrefix(err.Error(), "pq: duplicate key value") {
			return nil, nil, proto.ErrPersonalIdentityInUse
		}
		return nil, nil, err
	}

	// Look up the associated agent.
	atb := &AgentTrackerBinding{b.Backend}
	agent, err := atb.getFromDB(agentID, tx)
	if err != nil {
		abort()
		return nil, nil, err
	}

	err = agent.SetClientKey(agentKey, clientKey)
	if err != nil {
		abort()
		return nil, nil, err
	}

	err = atb.setClientKeyInDB(agentID, accountID.String(), agent.EncryptedClientKey.Ciphertext, tx)
	if err != nil {
		abort()
		return nil, nil, err
	}

	// Commit the transaction.
	err = tx.Commit()
	if err != nil {
		return nil, nil, err
	}
	backend.Logger(ctx).Printf("registered new account %s for %s:%s", account.ID, namespace, id)

	// Return the bound account with its single personal identity attached.
	bound := account.Bind(b.Backend)
	bound.identities = []proto.PersonalIdentity{&PersonalIdentityBinding{identityRow}}
	return bound, clientKey, nil
}