mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2026-02-11 19:48:32 +00:00
Other(refactor): Use normal value + mutex for user.apiAddrs
This commit is contained in:
@ -38,6 +38,7 @@ import (
|
||||
"github.com/bradenaw/juniper/xsync"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -55,7 +56,9 @@ type User struct {
|
||||
apiUser liteapi.User
|
||||
apiUserLock sync.RWMutex
|
||||
|
||||
apiAddrs *safe.Map[string, liteapi.Address]
|
||||
apiAddrs map[string]liteapi.Address
|
||||
apiAddrsLock sync.RWMutex
|
||||
|
||||
apiLabels *safe.Map[string, liteapi.Label]
|
||||
updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]]
|
||||
sendHash *sendRecorder
|
||||
@ -134,7 +137,7 @@ func New(
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||
|
||||
apiUser: apiUser,
|
||||
apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr),
|
||||
apiAddrs: groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }),
|
||||
apiLabels: safe.NewMapFrom(groupBy(apiLabels, func(label liteapi.Label) string { return label.ID }), nil),
|
||||
updateCh: safe.NewMapFrom(updateCh, nil),
|
||||
sendHash: newSendRecorder(sendEntryExpiry),
|
||||
@ -217,19 +220,23 @@ func (user *User) Match(query string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
return user.apiAddrs.HasFunc(func(_ string, addr liteapi.Address) bool {
|
||||
return addr.Email == query
|
||||
})
|
||||
}, &user.apiUserLock)
|
||||
for _, addr := range user.apiAddrs {
|
||||
if query == addr.Email {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}, &user.apiUserLock, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
// Emails returns all the user's email addresses via the callback.
|
||||
func (user *User) Emails() []string {
|
||||
return safe.MapValuesRet(user.apiAddrs, func(apiAddrs []liteapi.Address) []string {
|
||||
return xslices.Map(apiAddrs, func(addr liteapi.Address) string {
|
||||
return safe.RLockRet(func() []string {
|
||||
return xslices.Map(maps.Values(user.apiAddrs), func(addr liteapi.Address) string {
|
||||
return addr.Email
|
||||
})
|
||||
})
|
||||
}, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
// GetAddressMode returns the user's current address mode.
|
||||
@ -242,37 +249,39 @@ func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) er
|
||||
user.abortable.Abort()
|
||||
defer user.goSync()
|
||||
|
||||
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
|
||||
for _, updateCh := range xslices.Unique(updateCh) {
|
||||
updateCh.CloseAndDiscardQueued()
|
||||
return safe.RLockRet(func() error {
|
||||
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
|
||||
for _, updateCh := range xslices.Unique(updateCh) {
|
||||
updateCh.CloseAndDiscardQueued()
|
||||
}
|
||||
})
|
||||
|
||||
user.updateCh.Clear()
|
||||
|
||||
switch mode {
|
||||
case vault.CombinedMode:
|
||||
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
for addrID := range user.apiAddrs {
|
||||
user.updateCh.Set(addrID, primaryUpdateCh)
|
||||
}
|
||||
|
||||
case vault.SplitMode:
|
||||
for addrID := range user.apiAddrs {
|
||||
user.updateCh.Set(addrID, queue.NewQueuedChannel[imap.Update](0, 0))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
user.updateCh.Clear()
|
||||
if err := user.vault.SetAddressMode(mode); err != nil {
|
||||
return fmt.Errorf("failed to set address mode: %w", err)
|
||||
}
|
||||
|
||||
switch mode {
|
||||
case vault.CombinedMode:
|
||||
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
if err := user.vault.ClearSyncStatus(); err != nil {
|
||||
return fmt.Errorf("failed to clear sync status: %w", err)
|
||||
}
|
||||
|
||||
user.apiAddrs.IterKeys(func(addrID string) {
|
||||
user.updateCh.Set(addrID, primaryUpdateCh)
|
||||
})
|
||||
|
||||
case vault.SplitMode:
|
||||
user.apiAddrs.IterKeys(func(addrID string) {
|
||||
user.updateCh.Set(addrID, queue.NewQueuedChannel[imap.Update](0, 0))
|
||||
})
|
||||
}
|
||||
|
||||
if err := user.vault.SetAddressMode(mode); err != nil {
|
||||
return fmt.Errorf("failed to set address mode: %w", err)
|
||||
}
|
||||
|
||||
if err := user.vault.ClearSyncStatus(); err != nil {
|
||||
return fmt.Errorf("failed to clear sync status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil
|
||||
}, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
// GetGluonIDs returns the users gluon IDs.
|
||||
@ -334,21 +343,26 @@ func (user *User) NewIMAPConnector(addrID string) connector.Connector {
|
||||
// In combined mode, this is just the user's primary address.
|
||||
// In split mode, this is all the user's addresses.
|
||||
func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
|
||||
imapConn := make(map[string]connector.Connector)
|
||||
return safe.RLockRetErr(func() (map[string]connector.Connector, error) {
|
||||
imapConn := make(map[string]connector.Connector)
|
||||
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
user.apiAddrs.Index(0, func(addrID string, _ liteapi.Address) {
|
||||
imapConn[addrID] = newIMAPConnector(user, addrID)
|
||||
})
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
primAddr, err := getAddrIdx(user.apiAddrs, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get primary address: %w", err)
|
||||
}
|
||||
|
||||
case vault.SplitMode:
|
||||
user.apiAddrs.IterKeys(func(addrID string) {
|
||||
imapConn[addrID] = newIMAPConnector(user, addrID)
|
||||
})
|
||||
}
|
||||
imapConn[primAddr.ID] = newIMAPConnector(user, primAddr.ID)
|
||||
|
||||
return imapConn, nil
|
||||
case vault.SplitMode:
|
||||
for addrID := range user.apiAddrs {
|
||||
imapConn[addrID] = newIMAPConnector(user, addrID)
|
||||
}
|
||||
}
|
||||
|
||||
return imapConn, nil
|
||||
}, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
// SendMail sends an email from the given address to the given recipients.
|
||||
@ -357,17 +371,17 @@ func (user *User) SendMail(authID string, from string, to []string, r io.Reader)
|
||||
return ErrInvalidRecipient
|
||||
}
|
||||
|
||||
return user.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error {
|
||||
if _, err := getAddrID(apiAddrs, from); err != nil {
|
||||
return safe.RLockRet(func() error {
|
||||
if _, err := getAddrID(user.apiAddrs, from); err != nil {
|
||||
return ErrInvalidReturnPath
|
||||
}
|
||||
|
||||
emails := xslices.Map(apiAddrs, func(addr liteapi.Address) string {
|
||||
emails := xslices.Map(maps.Values(user.apiAddrs), func(addr liteapi.Address) string {
|
||||
return addr.Email
|
||||
})
|
||||
|
||||
return user.sendMail(authID, emails, from, to, r)
|
||||
})
|
||||
}, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
// CheckAuth returns whether the given email and password can be used to authenticate over IMAP or SMTP with this user.
|
||||
@ -382,15 +396,15 @@ func (user *User) CheckAuth(email string, password []byte) (string, error) {
|
||||
return "", fmt.Errorf("invalid password")
|
||||
}
|
||||
|
||||
return safe.MapValuesRetErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) {
|
||||
for _, addr := range apiAddrs {
|
||||
return safe.RLockRetErr(func() (string, error) {
|
||||
for _, addr := range user.apiAddrs {
|
||||
if strings.EqualFold(addr.Email, email) {
|
||||
return addr.ID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid email")
|
||||
})
|
||||
}, &user.apiAddrsLock)
|
||||
}
|
||||
|
||||
// OnStatusUp is called when the connection goes up.
|
||||
|
||||
Reference in New Issue
Block a user