forked from Silverfish/proton-bridge
Other(refactor): Use normal value + mutex for user.apiAddrs
This commit is contained in:
@ -99,77 +99,87 @@ func (user *User) handleAddressEvents(ctx context.Context, addressEvents []litea
|
||||
}
|
||||
|
||||
func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
|
||||
if had := user.apiAddrs.Set(event.Address.ID, event.Address); had {
|
||||
return fmt.Errorf("address %q already exists", event.Address.ID)
|
||||
}
|
||||
return safe.LockRet(func() error {
|
||||
if _, ok := user.apiAddrs[event.Address.ID]; ok {
|
||||
return fmt.Errorf("address %q already exists", event.ID)
|
||||
}
|
||||
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
user.apiAddrs.Index(0, func(addrID string, _ liteapi.Address) {
|
||||
user.updateCh.SetFrom(event.Address.ID, addrID)
|
||||
user.apiAddrs[event.Address.ID] = event.Address
|
||||
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
primAddr, err := getAddrIdx(user.apiAddrs, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get primary address: %w", err)
|
||||
}
|
||||
|
||||
user.updateCh.SetFrom(event.Address.ID, primAddr.ID)
|
||||
|
||||
case vault.SplitMode:
|
||||
user.updateCh.Set(event.Address.ID, queue.NewQueuedChannel[imap.Update](0, 0))
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressCreated{
|
||||
UserID: user.ID(),
|
||||
AddressID: event.Address.ID,
|
||||
Email: event.Address.Email,
|
||||
})
|
||||
|
||||
case vault.SplitMode:
|
||||
user.updateCh.Set(event.Address.ID, queue.NewQueuedChannel[imap.Update](0, 0))
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressCreated{
|
||||
UserID: user.ID(),
|
||||
AddressID: event.Address.ID,
|
||||
Email: event.Address.Email,
|
||||
})
|
||||
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
if ok, err := user.updateCh.GetErr(event.Address.ID, func(updateCh *queue.QueuedChannel[imap.Update]) error {
|
||||
return syncLabels(ctx, user.client, updateCh)
|
||||
}); !ok {
|
||||
return fmt.Errorf("no such address %q", event.Address.ID)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to sync labels to new address: %w", err)
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
if ok, err := user.updateCh.GetErr(event.Address.ID, func(updateCh *queue.QueuedChannel[imap.Update]) error {
|
||||
return syncLabels(ctx, user.client, updateCh)
|
||||
}); !ok {
|
||||
return fmt.Errorf("no such address %q", event.Address.ID)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to sync labels to new address: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil
|
||||
}, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
func (user *User) handleUpdateAddressEvent(_ context.Context, event liteapi.AddressEvent) error { //nolint:unparam
|
||||
if had := user.apiAddrs.Set(event.Address.ID, event.Address); !had {
|
||||
return fmt.Errorf("address %q does not exist", event.Address.ID)
|
||||
}
|
||||
return safe.LockRet(func() error {
|
||||
if _, ok := user.apiAddrs[event.Address.ID]; ok {
|
||||
return fmt.Errorf("address %q already exists", event.ID)
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressUpdated{
|
||||
UserID: user.ID(),
|
||||
AddressID: event.Address.ID,
|
||||
Email: event.Address.Email,
|
||||
user.apiAddrs[event.Address.ID] = event.Address
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressUpdated{
|
||||
UserID: user.ID(),
|
||||
AddressID: event.Address.ID,
|
||||
Email: event.Address.Email,
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) handleDeleteAddressEvent(_ context.Context, event liteapi.AddressEvent) error {
|
||||
var email string
|
||||
|
||||
if ok := user.apiAddrs.GetDelete(event.ID, func(apiAddr liteapi.Address) {
|
||||
email = apiAddr.Email
|
||||
}); !ok {
|
||||
return fmt.Errorf("no such address %q", event.ID)
|
||||
}
|
||||
|
||||
if ok := user.updateCh.GetDelete(event.ID, func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
updateCh.CloseAndDiscardQueued()
|
||||
return safe.LockRet(func() error {
|
||||
addr, ok := user.apiAddrs[event.ID]
|
||||
if !ok {
|
||||
return fmt.Errorf("address %q does not exist", event.ID)
|
||||
}
|
||||
}); !ok {
|
||||
return fmt.Errorf("no such address %q", event.ID)
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressDeleted{
|
||||
UserID: user.ID(),
|
||||
AddressID: event.ID,
|
||||
Email: email,
|
||||
if ok := user.updateCh.GetDelete(event.ID, func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
updateCh.CloseAndDiscardQueued()
|
||||
}
|
||||
}); !ok {
|
||||
return fmt.Errorf("no such address %q", event.ID)
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressDeleted{
|
||||
UserID: user.ID(),
|
||||
AddressID: event.ID,
|
||||
Email: addr.Email,
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleLabelEvents handles the given label events.
|
||||
@ -254,18 +264,20 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.Me
|
||||
return fmt.Errorf("failed to get full message: %w", err)
|
||||
}
|
||||
|
||||
return user.withAddrKR(event.Message.AddressID, func(_, addrKR *crypto.KeyRing) error {
|
||||
buildRes, err := buildRFC822(full, addrKR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build RFC822 message: %w", err)
|
||||
}
|
||||
return safe.RLockRet(func() error {
|
||||
return withAddrKR(user.apiUser, user.apiAddrs[event.Message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
buildRes, err := buildRFC822(full, addrKR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build RFC822 message: %w", err)
|
||||
}
|
||||
|
||||
user.updateCh.Get(full.AddressID, func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||
updateCh.Enqueue(imap.NewMessagesCreated(buildRes.update))
|
||||
user.updateCh.Get(full.AddressID, func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||
updateCh.Enqueue(imap.NewMessagesCreated(buildRes.update))
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
}, &user.apiUserLock, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
func (user *User) handleUpdateMessageEvent(_ context.Context, event liteapi.MessageEvent) error { //nolint:unparam
|
||||
|
||||
@ -373,29 +373,31 @@ func (conn *imapConnector) importMessage(
|
||||
) (imap.Message, []byte, error) {
|
||||
var full liteapi.FullMessage
|
||||
|
||||
if err := conn.withAddrKR(conn.addrID, func(_, addrKR *crypto.KeyRing) error {
|
||||
res, err := stream.Collect(ctx, conn.client.ImportMessages(ctx, addrKR, 1, 1, []liteapi.ImportReq{{
|
||||
Metadata: liteapi.ImportMetadata{
|
||||
AddressID: conn.addrID,
|
||||
LabelIDs: labelIDs,
|
||||
Unread: liteapi.Bool(unread),
|
||||
Flags: flags,
|
||||
},
|
||||
Message: literal,
|
||||
}}...))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to import message: %w", err)
|
||||
}
|
||||
if err := safe.RLockRet(func() error {
|
||||
return withAddrKR(conn.apiUser, conn.apiAddrs[conn.addrID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
res, err := stream.Collect(ctx, conn.client.ImportMessages(ctx, addrKR, 1, 1, []liteapi.ImportReq{{
|
||||
Metadata: liteapi.ImportMetadata{
|
||||
AddressID: conn.addrID,
|
||||
LabelIDs: labelIDs,
|
||||
Unread: liteapi.Bool(unread),
|
||||
Flags: flags,
|
||||
},
|
||||
Message: literal,
|
||||
}}...))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to import message: %w", err)
|
||||
}
|
||||
|
||||
if full, err = conn.client.GetFullMessage(ctx, res[0].MessageID); err != nil {
|
||||
return fmt.Errorf("failed to fetch message: %w", err)
|
||||
}
|
||||
if full, err = conn.client.GetFullMessage(ctx, res[0].MessageID); err != nil {
|
||||
return fmt.Errorf("failed to fetch message: %w", err)
|
||||
}
|
||||
|
||||
if literal, err = message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()); err != nil {
|
||||
return fmt.Errorf("failed to build message: %w", err)
|
||||
}
|
||||
if literal, err = message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()); err != nil {
|
||||
return fmt.Errorf("failed to build message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil
|
||||
})
|
||||
}); err != nil {
|
||||
return imap.Message{}, nil, err
|
||||
}
|
||||
|
||||
@ -21,85 +21,43 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
)
|
||||
|
||||
func (user *User) withUserKR(fn func(*crypto.KeyRing) error) error {
|
||||
return safe.RLockRet(func() error {
|
||||
userKR, err := user.apiUser.Keys.Unlock(user.vault.KeyPass(), nil)
|
||||
func withAddrKR(apiUser liteapi.User, apiAddr liteapi.Address, keyPass []byte, fn func(userKR, addrKR *crypto.KeyRing) error) error {
|
||||
userKR, err := apiUser.Keys.Unlock(keyPass, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock user keys: %w", err)
|
||||
}
|
||||
defer userKR.ClearPrivateParams()
|
||||
|
||||
addrKR, err := apiAddr.Keys.Unlock(keyPass, userKR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||
}
|
||||
defer addrKR.ClearPrivateParams()
|
||||
|
||||
return fn(userKR, addrKR)
|
||||
}
|
||||
|
||||
func withAddrKRs(apiUser liteapi.User, apiAddr map[string]liteapi.Address, keyPass []byte, fn func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
|
||||
userKR, err := apiUser.Keys.Unlock(keyPass, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock user keys: %w", err)
|
||||
}
|
||||
defer userKR.ClearPrivateParams()
|
||||
|
||||
addrKRs := make(map[string]*crypto.KeyRing, len(apiAddr))
|
||||
|
||||
for addrID, apiAddr := range apiAddr {
|
||||
addrKR, err := apiAddr.Keys.Unlock(keyPass, userKR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock user keys: %w", err)
|
||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||
}
|
||||
defer userKR.ClearPrivateParams()
|
||||
defer addrKR.ClearPrivateParams()
|
||||
|
||||
return fn(userKR)
|
||||
}, &user.apiUserLock)
|
||||
}
|
||||
|
||||
func (user *User) withAddrKR(addrID string, fn func(*crypto.KeyRing, *crypto.KeyRing) error) error {
|
||||
return user.withUserKR(func(userKR *crypto.KeyRing) error {
|
||||
if ok, err := user.apiAddrs.GetErr(addrID, func(apiAddr liteapi.Address) error {
|
||||
addrKR, err := apiAddr.Keys.Unlock(user.vault.KeyPass(), userKR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||
}
|
||||
defer userKR.ClearPrivateParams()
|
||||
|
||||
return fn(userKR, addrKR)
|
||||
}); !ok {
|
||||
return fmt.Errorf("no such address %q", addrID)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (user *User) withAddrKRByEmail(email string, fn func(*crypto.KeyRing, *crypto.KeyRing) error) error {
|
||||
return user.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error {
|
||||
addrID, err := getAddrID(apiAddrs, email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get address ID: %w", err)
|
||||
}
|
||||
|
||||
return user.withUserKR(func(userKR *crypto.KeyRing) error {
|
||||
if ok, err := user.apiAddrs.GetErr(addrID, func(apiAddr liteapi.Address) error {
|
||||
addrKR, err := apiAddr.Keys.Unlock(user.vault.KeyPass(), userKR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||
}
|
||||
defer userKR.ClearPrivateParams()
|
||||
|
||||
return fn(userKR, addrKR)
|
||||
}); !ok {
|
||||
return fmt.Errorf("no such address %q", addrID)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (user *User) withAddrKRs(fn func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error {
|
||||
return user.withUserKR(func(userKR *crypto.KeyRing) error {
|
||||
return user.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error {
|
||||
addrKRs := make(map[string]*crypto.KeyRing)
|
||||
|
||||
for _, apiAddr := range apiAddrs {
|
||||
addrKR, err := apiAddr.Keys.Unlock(user.vault.KeyPass(), userKR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||
}
|
||||
defer userKR.ClearPrivateParams()
|
||||
|
||||
addrKRs[apiAddr.ID] = addrKR
|
||||
}
|
||||
|
||||
return fn(userKR, addrKRs)
|
||||
})
|
||||
})
|
||||
addrKRs[addrID] = addrKR
|
||||
}
|
||||
|
||||
return fn(userKR, addrKRs)
|
||||
}
|
||||
|
||||
@ -27,24 +27,6 @@ import (
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
)
|
||||
|
||||
func BenchmarkUserKeyRing(b *testing.B) {
|
||||
b.StopTimer()
|
||||
|
||||
withAPI(b, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) {
|
||||
withAccount(b, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) {
|
||||
withUser(b, ctx, s, m, "username", "password", func(user *User) {
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
require.NoError(b, user.withUserKR(func(userKR *crypto.KeyRing) error {
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkAddrKeyRing(b *testing.B) {
|
||||
b.StopTimer()
|
||||
|
||||
@ -54,7 +36,7 @@ func BenchmarkAddrKeyRing(b *testing.B) {
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
require.NoError(b, user.withAddrKR(addrIDs[0], func(userKR, addrKR *crypto.KeyRing) error {
|
||||
require.NoError(b, withAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
@ -32,6 +32,7 @@ import (
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/ProtonMail/go-rfc5322"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/message/parser"
|
||||
@ -85,54 +86,61 @@ func (user *User) sendMail(authID string, emails []string, from string, to []str
|
||||
return fmt.Errorf("failed to get mail settings: %w", err)
|
||||
}
|
||||
|
||||
return user.withAddrKRByEmail(from, func(userKR, addrKR *crypto.KeyRing) error {
|
||||
// Use the first key for encrypting the message.
|
||||
addrKR, err := addrKR.FirstKey()
|
||||
return safe.LockRet(func() error {
|
||||
addrID, err := getAddrID(user.apiAddrs, from)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get first key: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// If we have to attach the public key, do it now.
|
||||
if settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled {
|
||||
key, err := addrKR.GetKey(0)
|
||||
return withAddrKR(user.apiUser, user.apiAddrs[addrID], user.vault.KeyPass(), func(userKR, addrKR *crypto.KeyRing) error {
|
||||
// Use the first key for encrypting the message.
|
||||
addrKR, err := addrKR.FirstKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get sending key: %w", err)
|
||||
return fmt.Errorf("failed to get first key: %w", err)
|
||||
}
|
||||
|
||||
pubKey, err := key.GetArmoredPublicKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get public key: %w", err)
|
||||
// If we have to attach the public key, do it now.
|
||||
if settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled {
|
||||
key, err := addrKR.GetKey(0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get sending key: %w", err)
|
||||
}
|
||||
|
||||
pubKey, err := key.GetArmoredPublicKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
|
||||
parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8]))
|
||||
}
|
||||
|
||||
parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8]))
|
||||
}
|
||||
// Parse the message we want to send (after we have attached the public key).
|
||||
message, err := message.ParseWithParser(parser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse message: %w", err)
|
||||
}
|
||||
|
||||
// Parse the message we want to send (after we have attached the public key).
|
||||
message, err := message.ParseWithParser(parser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse message: %w", err)
|
||||
}
|
||||
// Send the message using the correct key.
|
||||
sent, err := sendWithKey(
|
||||
ctx,
|
||||
user.client,
|
||||
authID,
|
||||
user.vault.AddressMode(),
|
||||
settings,
|
||||
userKR, addrKR,
|
||||
emails, from, to,
|
||||
message,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send message: %w", err)
|
||||
}
|
||||
|
||||
// Send the message using the correct key.
|
||||
sent, err := sendWithKey(
|
||||
ctx,
|
||||
user.client,
|
||||
authID,
|
||||
user.vault.AddressMode(),
|
||||
settings,
|
||||
userKR, addrKR,
|
||||
emails, from, to,
|
||||
message,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send message: %w", err)
|
||||
}
|
||||
// If the message was successfully sent, we can update the message ID in the record.
|
||||
user.sendHash.addMessageID(hash, sent.ID)
|
||||
|
||||
// If the message was successfully sent, we can update the message ID in the record.
|
||||
user.sendHash.addMessageID(hash, sent.ID)
|
||||
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
})
|
||||
}, &user.apiUserLock, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
// sendWithKey sends the message with the given address key.
|
||||
|
||||
@ -29,6 +29,7 @@ import (
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/stream"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
@ -72,45 +73,47 @@ func (user *User) doSync(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (user *User) sync(ctx context.Context) error {
|
||||
return user.withAddrKRs(func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||
if !user.vault.SyncStatus().HasLabels {
|
||||
user.log.Debug("Syncing labels")
|
||||
return safe.RLockRet(func() error {
|
||||
return withAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||
if !user.vault.SyncStatus().HasLabels {
|
||||
user.log.Debug("Syncing labels")
|
||||
|
||||
if err := user.updateCh.ValuesErr(func(updateCh []*queue.QueuedChannel[imap.Update]) error {
|
||||
return syncLabels(ctx, user.client, xslices.Unique(updateCh)...)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to sync labels: %w", err)
|
||||
if err := user.updateCh.ValuesErr(func(updateCh []*queue.QueuedChannel[imap.Update]) error {
|
||||
return syncLabels(ctx, user.client, xslices.Unique(updateCh)...)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to sync labels: %w", err)
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasLabels(true); err != nil {
|
||||
return fmt.Errorf("failed to set has labels: %w", err)
|
||||
}
|
||||
|
||||
user.log.Debug("Synced labels")
|
||||
} else {
|
||||
user.log.Debug("Labels are already synced, skipping")
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasLabels(true); err != nil {
|
||||
return fmt.Errorf("failed to set has labels: %w", err)
|
||||
if !user.vault.SyncStatus().HasMessages {
|
||||
user.log.Debug("Syncing messages")
|
||||
|
||||
if err := user.updateCh.MapErr(func(updateCh map[string]*queue.QueuedChannel[imap.Update]) error {
|
||||
return syncMessages(ctx, user.ID(), user.client, user.vault, addrKRs, updateCh, user.eventCh)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to sync messages: %w", err)
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasMessages(true); err != nil {
|
||||
return fmt.Errorf("failed to set has messages: %w", err)
|
||||
}
|
||||
|
||||
user.log.Debug("Synced messages")
|
||||
} else {
|
||||
user.log.Debug("Messages are already synced, skipping")
|
||||
}
|
||||
|
||||
user.log.Debug("Synced labels")
|
||||
} else {
|
||||
user.log.Debug("Labels are already synced, skipping")
|
||||
}
|
||||
|
||||
if !user.vault.SyncStatus().HasMessages {
|
||||
user.log.Debug("Syncing messages")
|
||||
|
||||
if err := user.updateCh.MapErr(func(updateCh map[string]*queue.QueuedChannel[imap.Update]) error {
|
||||
return syncMessages(ctx, user.ID(), user.client, user.vault, addrKRs, updateCh, user.eventCh)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to sync messages: %w", err)
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasMessages(true); err != nil {
|
||||
return fmt.Errorf("failed to set has messages: %w", err)
|
||||
}
|
||||
|
||||
user.log.Debug("Synced messages")
|
||||
} else {
|
||||
user.log.Debug("Messages are already synced, skipping")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
})
|
||||
}, &user.apiUserLock, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.QueuedChannel[imap.Update]) error {
|
||||
|
||||
@ -24,6 +24,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// mapTo converts the slice to the given type.
|
||||
@ -82,7 +84,7 @@ func hexDecode(b []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
// getAddrID returns the address ID for the given email address.
|
||||
func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) {
|
||||
func getAddrID(apiAddrs map[string]liteapi.Address, email string) (string, error) {
|
||||
for _, addr := range apiAddrs {
|
||||
if strings.EqualFold(addr.Email, sanitizeEmail(email)) {
|
||||
return addr.ID, nil
|
||||
@ -91,3 +93,27 @@ func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) {
|
||||
|
||||
return "", fmt.Errorf("address %s not found", email)
|
||||
}
|
||||
|
||||
// getAddrIdx returns the address with the given index.
|
||||
func getAddrIdx(apiAddrs map[string]liteapi.Address, idx int) (liteapi.Address, error) {
|
||||
sorted := sortSlice(maps.Values(apiAddrs), func(a, b liteapi.Address) bool {
|
||||
return a.Order < b.Order
|
||||
})
|
||||
|
||||
if idx < 0 || idx >= len(sorted) {
|
||||
return liteapi.Address{}, fmt.Errorf("address index %d out of range", idx)
|
||||
}
|
||||
|
||||
return sorted[idx], nil
|
||||
}
|
||||
|
||||
// sortSlice returns the given slice sorted by the given comparator.
|
||||
func sortSlice[Item any](items []Item, less func(Item, Item) bool) []Item {
|
||||
sorted := make([]Item, len(items))
|
||||
|
||||
copy(sorted, items)
|
||||
|
||||
slices.SortFunc(sorted, less)
|
||||
|
||||
return sorted
|
||||
}
|
||||
|
||||
@ -38,6 +38,7 @@ import (
|
||||
"github.com/bradenaw/juniper/xsync"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -55,7 +56,9 @@ type User struct {
|
||||
apiUser liteapi.User
|
||||
apiUserLock sync.RWMutex
|
||||
|
||||
apiAddrs *safe.Map[string, liteapi.Address]
|
||||
apiAddrs map[string]liteapi.Address
|
||||
apiAddrsLock sync.RWMutex
|
||||
|
||||
apiLabels *safe.Map[string, liteapi.Label]
|
||||
updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]]
|
||||
sendHash *sendRecorder
|
||||
@ -134,7 +137,7 @@ func New(
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||
|
||||
apiUser: apiUser,
|
||||
apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr),
|
||||
apiAddrs: groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }),
|
||||
apiLabels: safe.NewMapFrom(groupBy(apiLabels, func(label liteapi.Label) string { return label.ID }), nil),
|
||||
updateCh: safe.NewMapFrom(updateCh, nil),
|
||||
sendHash: newSendRecorder(sendEntryExpiry),
|
||||
@ -217,19 +220,23 @@ func (user *User) Match(query string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
return user.apiAddrs.HasFunc(func(_ string, addr liteapi.Address) bool {
|
||||
return addr.Email == query
|
||||
})
|
||||
}, &user.apiUserLock)
|
||||
for _, addr := range user.apiAddrs {
|
||||
if query == addr.Email {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}, &user.apiUserLock, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
// Emails returns all the user's email addresses via the callback.
|
||||
func (user *User) Emails() []string {
|
||||
return safe.MapValuesRet(user.apiAddrs, func(apiAddrs []liteapi.Address) []string {
|
||||
return xslices.Map(apiAddrs, func(addr liteapi.Address) string {
|
||||
return safe.RLockRet(func() []string {
|
||||
return xslices.Map(maps.Values(user.apiAddrs), func(addr liteapi.Address) string {
|
||||
return addr.Email
|
||||
})
|
||||
})
|
||||
}, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
// GetAddressMode returns the user's current address mode.
|
||||
@ -242,37 +249,39 @@ func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) er
|
||||
user.abortable.Abort()
|
||||
defer user.goSync()
|
||||
|
||||
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
|
||||
for _, updateCh := range xslices.Unique(updateCh) {
|
||||
updateCh.CloseAndDiscardQueued()
|
||||
return safe.RLockRet(func() error {
|
||||
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
|
||||
for _, updateCh := range xslices.Unique(updateCh) {
|
||||
updateCh.CloseAndDiscardQueued()
|
||||
}
|
||||
})
|
||||
|
||||
user.updateCh.Clear()
|
||||
|
||||
switch mode {
|
||||
case vault.CombinedMode:
|
||||
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
for addrID := range user.apiAddrs {
|
||||
user.updateCh.Set(addrID, primaryUpdateCh)
|
||||
}
|
||||
|
||||
case vault.SplitMode:
|
||||
for addrID := range user.apiAddrs {
|
||||
user.updateCh.Set(addrID, queue.NewQueuedChannel[imap.Update](0, 0))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
user.updateCh.Clear()
|
||||
if err := user.vault.SetAddressMode(mode); err != nil {
|
||||
return fmt.Errorf("failed to set address mode: %w", err)
|
||||
}
|
||||
|
||||
switch mode {
|
||||
case vault.CombinedMode:
|
||||
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
if err := user.vault.ClearSyncStatus(); err != nil {
|
||||
return fmt.Errorf("failed to clear sync status: %w", err)
|
||||
}
|
||||
|
||||
user.apiAddrs.IterKeys(func(addrID string) {
|
||||
user.updateCh.Set(addrID, primaryUpdateCh)
|
||||
})
|
||||
|
||||
case vault.SplitMode:
|
||||
user.apiAddrs.IterKeys(func(addrID string) {
|
||||
user.updateCh.Set(addrID, queue.NewQueuedChannel[imap.Update](0, 0))
|
||||
})
|
||||
}
|
||||
|
||||
if err := user.vault.SetAddressMode(mode); err != nil {
|
||||
return fmt.Errorf("failed to set address mode: %w", err)
|
||||
}
|
||||
|
||||
if err := user.vault.ClearSyncStatus(); err != nil {
|
||||
return fmt.Errorf("failed to clear sync status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil
|
||||
}, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
// GetGluonIDs returns the users gluon IDs.
|
||||
@ -334,21 +343,26 @@ func (user *User) NewIMAPConnector(addrID string) connector.Connector {
|
||||
// In combined mode, this is just the user's primary address.
|
||||
// In split mode, this is all the user's addresses.
|
||||
func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
|
||||
imapConn := make(map[string]connector.Connector)
|
||||
return safe.RLockRetErr(func() (map[string]connector.Connector, error) {
|
||||
imapConn := make(map[string]connector.Connector)
|
||||
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
user.apiAddrs.Index(0, func(addrID string, _ liteapi.Address) {
|
||||
imapConn[addrID] = newIMAPConnector(user, addrID)
|
||||
})
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
primAddr, err := getAddrIdx(user.apiAddrs, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get primary address: %w", err)
|
||||
}
|
||||
|
||||
case vault.SplitMode:
|
||||
user.apiAddrs.IterKeys(func(addrID string) {
|
||||
imapConn[addrID] = newIMAPConnector(user, addrID)
|
||||
})
|
||||
}
|
||||
imapConn[primAddr.ID] = newIMAPConnector(user, primAddr.ID)
|
||||
|
||||
return imapConn, nil
|
||||
case vault.SplitMode:
|
||||
for addrID := range user.apiAddrs {
|
||||
imapConn[addrID] = newIMAPConnector(user, addrID)
|
||||
}
|
||||
}
|
||||
|
||||
return imapConn, nil
|
||||
}, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
// SendMail sends an email from the given address to the given recipients.
|
||||
@ -357,17 +371,17 @@ func (user *User) SendMail(authID string, from string, to []string, r io.Reader)
|
||||
return ErrInvalidRecipient
|
||||
}
|
||||
|
||||
return user.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error {
|
||||
if _, err := getAddrID(apiAddrs, from); err != nil {
|
||||
return safe.RLockRet(func() error {
|
||||
if _, err := getAddrID(user.apiAddrs, from); err != nil {
|
||||
return ErrInvalidReturnPath
|
||||
}
|
||||
|
||||
emails := xslices.Map(apiAddrs, func(addr liteapi.Address) string {
|
||||
emails := xslices.Map(maps.Values(user.apiAddrs), func(addr liteapi.Address) string {
|
||||
return addr.Email
|
||||
})
|
||||
|
||||
return user.sendMail(authID, emails, from, to, r)
|
||||
})
|
||||
}, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
// CheckAuth returns whether the given email and password can be used to authenticate over IMAP or SMTP with this user.
|
||||
@ -382,15 +396,15 @@ func (user *User) CheckAuth(email string, password []byte) (string, error) {
|
||||
return "", fmt.Errorf("invalid password")
|
||||
}
|
||||
|
||||
return safe.MapValuesRetErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) {
|
||||
for _, addr := range apiAddrs {
|
||||
return safe.RLockRetErr(func() (string, error) {
|
||||
for _, addr := range user.apiAddrs {
|
||||
if strings.EqualFold(addr.Email, email) {
|
||||
return addr.ID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid email")
|
||||
})
|
||||
}, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
// OnStatusUp is called when the connection goes up.
|
||||
|
||||
Reference in New Issue
Block a user