func (m *accountManager) GrantStaff( ctx scope.Context, accountID snowflake.Snowflake, kmsCred security.KMSCredential) error { m.b.Lock() defer m.b.Unlock() account, ok := m.b.accounts[accountID] if !ok { return proto.ErrAccountNotFound } memAcc := account.(*memAccount) kms := kmsCred.KMS() key := memAcc.sec.SystemKey.Clone() if err := kms.DecryptKey(&key); err != nil { return err } nonce, err := kms.GenerateNonce(key.KeyType.BlockSize()) if err != nil { return err } capability, err := security.GrantSharedSecretCapability(&key, nonce, kmsCred.KMSType(), kmsCred) if err != nil { return err } memAcc.staffCapability = capability return nil }
func (b *AccountManagerBinding) GrantStaff( ctx scope.Context, accountID snowflake.Snowflake, kmsCred security.KMSCredential) error { // Look up the target account's (system) encrypted client key. This is // not part of the transaction, because we want to interact with KMS // before we proceed. That should be fine, since this is an infrequently // used action. var row struct { EncryptedClientKey []byte `db:"encrypted_system_key"` Nonce []byte `db:"nonce"` } err := b.DbMap.SelectOne( &row, "SELECT encrypted_system_key, nonce FROM account WHERE id = $1", accountID.String()) if err != nil { if err == sql.ErrNoRows { return proto.ErrAccountNotFound } return err } // Use kmsCred to obtain kms and decrypt the client's key. kms := kmsCred.KMS() clientKey := &security.ManagedKey{ KeyType: proto.ClientKeyType, Ciphertext: row.EncryptedClientKey, ContextKey: "nonce", ContextValue: base64.URLEncoding.EncodeToString(row.Nonce), } if err := kms.DecryptKey(clientKey); err != nil { return err } // Grant staff capability. This involves marshalling kmsCred to JSON and // encrypting it with the client key. nonce, err := kms.GenerateNonce(clientKey.KeyType.BlockSize()) if err != nil { return err } capability, err := security.GrantSharedSecretCapability(clientKey, nonce, kmsCred.KMSType(), kmsCred) if err != nil { return err } // Store capability and update account table. t, err := b.DbMap.Begin() if err != nil { return err } rollback := func() { if err := t.Rollback(); err != nil { backend.Logger(ctx).Printf("rollback error: %s", err) } } dbCap := &Capability{ ID: capability.CapabilityID(), NonceBytes: capability.Nonce(), EncryptedPrivateData: capability.EncryptedPayload(), PublicData: capability.PublicPayload(), } if err := t.Insert(dbCap); err != nil { rollback() return err } result, err := t.Exec( "UPDATE account SET staff_capability_id = $2 WHERE id = $1", accountID.String(), capability.CapabilityID()) if err != nil { rollback() return err } n, err := result.RowsAffected() if err != nil { rollback() return err } if n != 1 { rollback() return proto.ErrAccountNotFound } if err := t.Commit(); err != nil { return err } return nil }