Mirror of https://github.com/ProtonMail/proton-bridge.git (synced 2025-12-10 04:36:43 +00:00)
feat(GODT-2822): Integrate and activate all services
The bridge now runs on the new architecture.
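The shape of the change is visible in BuildMailboxToMessageMap and DebugDownloadMessages below: instead of reading the lock-guarded apiAddrs/apiLabels maps inside safe.RLock blocks, callers now take a context and fetch snapshots from the identity and IMAP services. The following is a minimal, hypothetical sketch of that access pattern only; it is not bridge code, and every type, interface, and name in it is an illustrative stand-in. The real services return go-proton-api values (proton.Address, proton.Label); the sketch only mirrors the call shape seen in the diff.

// Hypothetical sketch of the new service-based access pattern; all names here are
// stand-ins and none of this is bridge code.
package main

import (
	"context"
	"fmt"
)

// Stand-ins for the proton.Address / proton.Label values returned by the real services.
type address struct{ ID, Email string }
type label struct{ ID, Name string }

// Minimal interfaces mirroring the two calls used in the diff:
// user.identityService.GetAddresses(ctx) and user.imapService.GetLabels(ctx).
type identityGetter interface {
	GetAddresses(ctx context.Context) (map[string]address, error)
}

type labelGetter interface {
	GetLabels(ctx context.Context) (map[string]label, error)
}

// fakeServices is a toy implementation used only to make the sketch runnable.
type fakeServices struct{}

func (fakeServices) GetAddresses(context.Context) (map[string]address, error) {
	return map[string]address{"addr1": {ID: "addr1", Email: "user@example.com"}}, nil
}

func (fakeServices) GetLabels(context.Context) (map[string]label, error) {
	return map[string]label{"label1": {ID: "label1", Name: "Folders/Work"}}, nil
}

// snapshot gathers both maps up front, the way BuildMailboxToMessageMap now collects
// its inputs before building the mailbox-to-message map, instead of locking shared state.
func snapshot(ctx context.Context, ids identityGetter, lbs labelGetter) (map[string]address, map[string]label, error) {
	apiAddrs, err := ids.GetAddresses(ctx)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get addresses: %w", err)
	}

	apiLabels, err := lbs.GetLabels(ctx)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get labels: %w", err)
	}

	return apiAddrs, apiLabels, nil
}

func main() {
	apiAddrs, apiLabels, err := snapshot(context.Background(), fakeServices{}, fakeServices{})
	if err != nil {
		panic(err)
	}

	fmt.Printf("%d addresses, %d labels\n", len(apiAddrs), len(apiLabels))
}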
@@ -34,7 +34,7 @@ import (
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/constants"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||
imapservice "github.com/ProtonMail/proton-bridge/v3/internal/services/imapservice"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/bradenaw/juniper/xmaps"
|
||||
@@ -58,75 +58,83 @@ type DiagMailboxMessage struct {
|
||||
Flags imap.FlagSet
|
||||
}
|
||||
|
||||
func (apm DiagnosticMetadata) BuildMailboxToMessageMap(user *User) (map[string]AccountMailboxMap, error) {
|
||||
return safe.RLockRetErr(func() (map[string]AccountMailboxMap, error) {
|
||||
result := make(map[string]AccountMailboxMap)
|
||||
func (apm DiagnosticMetadata) BuildMailboxToMessageMap(ctx context.Context, user *User) (map[string]AccountMailboxMap, error) {
|
||||
apiAddrs, err := user.identityService.GetAddresses(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get addresses: %w", err)
|
||||
}
|
||||
|
||||
mode := user.GetAddressMode()
|
||||
primaryAddrID, err := usertypes.GetPrimaryAddr(user.apiAddrs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get primary addr for user: %w", err)
|
||||
apiLabels, err := user.imapService.GetLabels(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get labels: %w", err)
|
||||
}
|
||||
|
||||
result := make(map[string]AccountMailboxMap)
|
||||
|
||||
mode := user.GetAddressMode()
|
||||
primaryAddrID, err := usertypes.GetPrimaryAddr(apiAddrs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get primary addr for user: %w", err)
|
||||
}
|
||||
|
||||
getAccount := func(addrID string) (AccountMailboxMap, bool) {
|
||||
if mode == vault.CombinedMode {
|
||||
addrID = primaryAddrID.ID
|
||||
}
|
||||
|
||||
getAccount := func(addrID string) (AccountMailboxMap, bool) {
|
||||
if mode == vault.CombinedMode {
|
||||
addrID = primaryAddrID.ID
|
||||
}
|
||||
addr := apiAddrs[addrID]
|
||||
if addr.Status != proton.AddressStatusEnabled {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
addr := user.apiAddrs[addrID]
|
||||
if addr.Status != proton.AddressStatusEnabled {
|
||||
return nil, false
|
||||
}
|
||||
v, ok := result[addr.Email]
|
||||
if !ok {
|
||||
result[addr.Email] = make(AccountMailboxMap)
|
||||
v = result[addr.Email]
|
||||
}
|
||||
|
||||
v, ok := result[addr.Email]
|
||||
return v, true
|
||||
}
|
||||
|
||||
for _, metadata := range apm.Metadata {
|
||||
for _, label := range metadata.LabelIDs {
|
||||
details, ok := apiLabels[label]
|
||||
if !ok {
|
||||
result[addr.Email] = make(AccountMailboxMap)
|
||||
v = result[addr.Email]
|
||||
logrus.Warnf("User %v has message with unknown label '%v'", user.Name(), label)
|
||||
continue
|
||||
}
|
||||
|
||||
return v, true
|
||||
}
|
||||
if !imapservice.WantLabel(details) {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, metadata := range apm.Metadata {
|
||||
for _, label := range metadata.LabelIDs {
|
||||
details, ok := user.apiLabels[label]
|
||||
if !ok {
|
||||
logrus.Warnf("User %v has message with unknown label '%v'", user.Name(), label)
|
||||
continue
|
||||
}
|
||||
account, enabled := getAccount(metadata.AddressID)
|
||||
if !enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if !wantLabel(details) {
|
||||
continue
|
||||
}
|
||||
var mboxName string
|
||||
if details.Type == proton.LabelTypeSystem {
|
||||
mboxName = details.Name
|
||||
} else {
|
||||
mboxName = strings.Join(imapservice.GetMailboxName(details), "/")
|
||||
}
|
||||
|
||||
account, enabled := getAccount(metadata.AddressID)
|
||||
if !enabled {
|
||||
continue
|
||||
}
|
||||
mboxMessage := DiagMailboxMessage{
|
||||
UserID: user.ID(),
|
||||
ID: metadata.ID,
|
||||
AddressID: metadata.AddressID,
|
||||
Flags: imapservice.BuildFlagSetFromMessageMetadata(metadata),
|
||||
}
|
||||
|
||||
var mboxName string
|
||||
if details.Type == proton.LabelTypeSystem {
|
||||
mboxName = details.Name
|
||||
} else {
|
||||
mboxName = strings.Join(getMailboxName(details), "/")
|
||||
}
|
||||
|
||||
mboxMessage := DiagMailboxMessage{
|
||||
UserID: user.ID(),
|
||||
ID: metadata.ID,
|
||||
AddressID: metadata.AddressID,
|
||||
Flags: buildFlagSetFromMessageMetadata(metadata),
|
||||
}
|
||||
|
||||
if v, ok := account[mboxName]; ok {
|
||||
account[mboxName] = append(v, mboxMessage)
|
||||
} else {
|
||||
account[mboxName] = []DiagMailboxMessage{mboxMessage}
|
||||
}
|
||||
if v, ok := account[mboxName]; ok {
|
||||
account[mboxName] = append(v, mboxMessage)
|
||||
} else {
|
||||
account[mboxName] = []DiagMailboxMessage{mboxMessage}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}, user.apiAddrsLock, user.apiLabelsLock)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (user *User) GetDiagnosticMetadata(ctx context.Context) (DiagnosticMetadata, error) {
|
||||
@@ -161,52 +169,56 @@ func (user *User) DebugDownloadMessages(
|
||||
msgs map[string]DiagMailboxMessage,
|
||||
progressCB func(string, int, int),
|
||||
) error {
|
||||
var err error
|
||||
safe.RLock(func() {
|
||||
err = func() error {
|
||||
total := len(msgs)
|
||||
userID := user.ID()
|
||||
total := len(msgs)
|
||||
userID := user.ID()
|
||||
|
||||
counter := 1
|
||||
for _, msg := range msgs {
|
||||
if progressCB != nil {
|
||||
progressCB(userID, counter, total)
|
||||
counter++
|
||||
}
|
||||
apiUser, err := user.identityService.GetAPIUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get api user: %w", err)
|
||||
}
|
||||
|
||||
msgDir := filepath.Join(path, msg.ID)
|
||||
if err := os.MkdirAll(msgDir, 0o700); err != nil {
|
||||
return fmt.Errorf("failed to create directory '%v':%w", msgDir, err)
|
||||
}
|
||||
apiAddrs, err := user.identityService.GetAddresses(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get address: %w", err)
|
||||
}
|
||||
|
||||
message, err := user.client.GetFullMessage(ctx, msg.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download message '%v':%w", msg.ID, err)
|
||||
}
|
||||
counter := 1
|
||||
for _, msg := range msgs {
|
||||
if progressCB != nil {
|
||||
progressCB(userID, counter, total)
|
||||
counter++
|
||||
}
|
||||
|
||||
if err := writeMetadata(msgDir, message.Message); err != nil {
|
||||
return err
|
||||
}
|
||||
msgDir := filepath.Join(path, msg.ID)
|
||||
if err := os.MkdirAll(msgDir, 0o700); err != nil {
|
||||
return fmt.Errorf("failed to create directory '%v':%w", msgDir, err)
|
||||
}
|
||||
|
||||
if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[msg.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
switch {
|
||||
case len(message.Attachments) > 0:
|
||||
return decodeMultipartMessage(msgDir, addrKR, message.Message, message.AttData)
|
||||
message, err := user.client.GetFullMessage(ctx, msg.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download message '%v':%w", msg.ID, err)
|
||||
}
|
||||
|
||||
case message.MIMEType == "multipart/mixed":
|
||||
return decodePGPMessage(msgDir, addrKR, message.Message)
|
||||
if err := writeMetadata(msgDir, message.Message); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
default:
|
||||
return decodeSimpleMessage(msgDir, addrKR, message.Message)
|
||||
}
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := usertypes.WithAddrKR(apiUser, apiAddrs[msg.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
switch {
|
||||
case len(message.Attachments) > 0:
|
||||
return decodeMultipartMessage(msgDir, addrKR, message.Message, message.AttData)
|
||||
|
||||
case message.MIMEType == "multipart/mixed":
|
||||
return decodePGPMessage(msgDir, addrKR, message.Message)
|
||||
|
||||
default:
|
||||
return decodeSimpleMessage(msgDir, addrKR, message.Message)
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
}, user.apiAddrsLock, user.apiUserLock)
|
||||
return err
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getBodyName(path string) string {
|
||||
|
||||
@@ -1,864 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/ProtonMail/gluon"
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/reporter"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
// handleAPIEvent handles the given proton.Event.
|
||||
func (user *User) handleAPIEvent(ctx context.Context, event proton.Event) error {
|
||||
if event.Refresh&proton.RefreshMail != 0 {
|
||||
return user.handleRefreshEvent(ctx, event.Refresh, event.EventID)
|
||||
}
|
||||
|
||||
if event.User != nil {
|
||||
user.handleUserEvent(ctx, *event.User)
|
||||
}
|
||||
|
||||
if len(event.Addresses) > 0 {
|
||||
if err := user.handleAddressEvents(ctx, event.Addresses); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(event.Labels) > 0 {
|
||||
if err := user.handleLabelEvents(ctx, event.Labels); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(event.Messages) > 0 {
|
||||
if err := user.handleMessageEvents(ctx, event.Messages); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if event.UsedSpace != nil {
|
||||
user.handleUsedSpaceChange(*event.UsedSpace)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) handleRefreshEvent(ctx context.Context, refresh proton.RefreshFlag, eventID string) error {
|
||||
l := user.log.WithFields(logrus.Fields{
|
||||
"eventID": eventID,
|
||||
"refresh": refresh,
|
||||
})
|
||||
|
||||
l.Info("Handling refresh event")
|
||||
|
||||
// Abort the event stream
|
||||
defer user.pollAbort.Abort()
|
||||
|
||||
// Re-sync messages after the user, address and label refresh.
|
||||
defer user.goSync()
|
||||
|
||||
return user.syncUserAddressesLabelsAndClearSync(ctx, false)
|
||||
}
|
||||
|
||||
func (user *User) syncUserAddressesLabelsAndClearSync(ctx context.Context, cancelEventPool bool) error {
|
||||
return safe.LockRet(func() error {
|
||||
// Fetch latest user info.
|
||||
apiUser, err := user.client.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
// Fetch latest address info.
|
||||
apiAddrs, err := user.client.GetAddresses(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get addresses: %w", err)
|
||||
}
|
||||
|
||||
// Fetch latest label info.
|
||||
apiLabels, err := user.client.GetLabels(ctx, proton.LabelTypeSystem, proton.LabelTypeFolder, proton.LabelTypeLabel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get labels: %w", err)
|
||||
}
|
||||
|
||||
// Update the API info in the user.
|
||||
user.apiUser = apiUser
|
||||
user.apiAddrs = usertypes.GroupBy(apiAddrs, func(addr proton.Address) string { return addr.ID })
|
||||
user.apiLabels = usertypes.GroupBy(apiLabels, func(label proton.Label) string { return label.ID })
|
||||
|
||||
// Clear sync status; we want to sync everything again.
|
||||
if err := user.clearSyncStatus(); err != nil {
|
||||
return fmt.Errorf("failed to clear sync status: %w", err)
|
||||
}
|
||||
|
||||
// The user was refreshed.
|
||||
user.eventCh.Enqueue(events.UserRefreshed{
|
||||
UserID: user.apiUser.ID,
|
||||
CancelEventPool: cancelEventPool,
|
||||
})
|
||||
|
||||
return nil
|
||||
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
// handleUserEvent handles the given user event.
|
||||
func (user *User) handleUserEvent(_ context.Context, userEvent proton.User) {
|
||||
safe.Lock(func() {
|
||||
user.log.WithFields(logrus.Fields{
|
||||
"userID": userEvent.ID,
|
||||
"username": logging.Sensitive(userEvent.Name),
|
||||
}).Info("Handling user event")
|
||||
|
||||
user.apiUser = userEvent
|
||||
|
||||
user.eventCh.Enqueue(events.UserChanged{
|
||||
UserID: user.apiUser.ID,
|
||||
})
|
||||
}, user.apiUserLock)
|
||||
}
|
||||
|
||||
// handleAddressEvents handles the given address events.
|
||||
// GODT-1945: In split address mode, we need to signal back to the bridge to update the addresses.
|
||||
func (user *User) handleAddressEvents(ctx context.Context, addressEvents []proton.AddressEvent) error {
|
||||
for _, event := range addressEvents {
|
||||
switch event.Action {
|
||||
case proton.EventCreate:
|
||||
if err := user.handleCreateAddressEvent(ctx, event); err != nil {
|
||||
user.reportError("Failed to apply address create event", err)
|
||||
return fmt.Errorf("failed to handle create address event: %w", err)
|
||||
}
|
||||
|
||||
case proton.EventUpdate, proton.EventUpdateFlags:
|
||||
if err := user.handleUpdateAddressEvent(ctx, event); err != nil {
|
||||
if errors.Is(err, ErrAddressDoesNotExist) {
|
||||
logrus.Debugf("Address %v does not exist, will try create instead", event.Address.ID)
|
||||
if createErr := user.handleCreateAddressEvent(ctx, event); createErr != nil {
|
||||
user.reportError("Failed to apply address update event (with create)", createErr)
|
||||
return fmt.Errorf("failed to handle update address event (with create): %w", createErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
user.reportError("Failed to apply address update event", err)
|
||||
return fmt.Errorf("failed to handle update address event: %w", err)
|
||||
}
|
||||
|
||||
case proton.EventDelete:
|
||||
if err := user.handleDeleteAddressEvent(ctx, event); err != nil {
|
||||
user.reportError("Failed to apply address delete event", err)
|
||||
return fmt.Errorf("failed to delete address: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.AddressEvent) error {
|
||||
if err := safe.LockRet(func() error {
|
||||
user.log.WithFields(logrus.Fields{
|
||||
"addressID": event.ID,
|
||||
"email": logging.Sensitive(event.Address.Email),
|
||||
}).Info("Handling address created event")
|
||||
|
||||
if _, ok := user.apiAddrs[event.Address.ID]; ok {
|
||||
user.log.Debugf("Address %q already exists", event.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
user.apiAddrs[event.Address.ID] = event.Address
|
||||
|
||||
// If the address is disabled.
|
||||
if event.Address.Status != proton.AddressStatusEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If the address is enabled, we need to hook it up to the update channels.
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get primary address: %w", err)
|
||||
}
|
||||
|
||||
user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID]
|
||||
|
||||
case vault.SplitMode:
|
||||
user.updateCh[event.Address.ID] = async.NewQueuedChannel[imap.Update](
|
||||
0,
|
||||
0,
|
||||
user.panicHandler,
|
||||
fmt.Sprintf("user-update-split-%v", event.Address.ID),
|
||||
)
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressCreated{
|
||||
UserID: user.apiUser.ID,
|
||||
AddressID: event.Address.ID,
|
||||
Email: event.Address.Email,
|
||||
})
|
||||
|
||||
return nil
|
||||
}, user.apiAddrsLock, user.updateChLock); err != nil {
|
||||
return fmt.Errorf("failed to handle create address event: %w", err)
|
||||
}
|
||||
|
||||
// Perform the sync in an RLock.
|
||||
return safe.RLockRet(func() error {
|
||||
if event.Address.Status != proton.AddressStatusEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
if err := syncLabels(ctx, user.apiLabels, user.updateCh[event.Address.ID]); err != nil {
|
||||
return fmt.Errorf("failed to sync labels to new address: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
var ErrAddressDoesNotExist = errors.New("address does not exist")
|
||||
|
||||
func (user *User) handleUpdateAddressEvent(_ context.Context, event proton.AddressEvent) error { //nolint:unparam
|
||||
return safe.LockRet(func() error {
|
||||
user.log.WithFields(logrus.Fields{
|
||||
"addressID": event.ID,
|
||||
"email": logging.Sensitive(event.Address.Email),
|
||||
}).Info("Handling address updated event")
|
||||
|
||||
oldAddr, ok := user.apiAddrs[event.Address.ID]
|
||||
if !ok {
|
||||
return ErrAddressDoesNotExist
|
||||
}
|
||||
|
||||
user.apiAddrs[event.Address.ID] = event.Address
|
||||
|
||||
switch {
|
||||
// If the address was newly enabled:
|
||||
case oldAddr.Status != proton.AddressStatusEnabled && event.Address.Status == proton.AddressStatusEnabled:
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get primary address: %w", err)
|
||||
}
|
||||
|
||||
user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID]
|
||||
|
||||
case vault.SplitMode:
|
||||
user.updateCh[event.Address.ID] = async.NewQueuedChannel[imap.Update](
|
||||
0,
|
||||
0,
|
||||
user.panicHandler,
|
||||
fmt.Sprintf("user-update-split-%v", event.Address.ID),
|
||||
)
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressEnabled{
|
||||
UserID: user.apiUser.ID,
|
||||
AddressID: event.Address.ID,
|
||||
Email: event.Address.Email,
|
||||
})
|
||||
|
||||
// If the address was newly disabled:
|
||||
case oldAddr.Status == proton.AddressStatusEnabled && event.Address.Status != proton.AddressStatusEnabled:
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
user.updateCh[event.ID].CloseAndDiscardQueued()
|
||||
}
|
||||
|
||||
delete(user.updateCh, event.ID)
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressDisabled{
|
||||
UserID: user.apiUser.ID,
|
||||
AddressID: event.Address.ID,
|
||||
Email: event.Address.Email,
|
||||
})
|
||||
|
||||
// Otherwise it's just an update:
|
||||
default:
|
||||
user.eventCh.Enqueue(events.UserAddressUpdated{
|
||||
UserID: user.apiUser.ID,
|
||||
AddressID: event.Address.ID,
|
||||
Email: event.Address.Email,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}, user.apiAddrsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
func (user *User) handleDeleteAddressEvent(_ context.Context, event proton.AddressEvent) error {
|
||||
return safe.LockRet(func() error {
|
||||
user.log.WithField("addressID", event.ID).Info("Handling address deleted event")
|
||||
|
||||
addr, ok := user.apiAddrs[event.ID]
|
||||
if !ok {
|
||||
user.log.Debugf("Address %q does not exist", event.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
delete(user.apiAddrs, event.ID)
|
||||
|
||||
// If the address was disabled to begin with, we don't need to do anything.
|
||||
if addr.Status != proton.AddressStatusEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise, in split mode, drop the update queue.
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
user.updateCh[event.ID].CloseAndDiscardQueued()
|
||||
}
|
||||
|
||||
// And in either mode, remove the address from the update channel map.
|
||||
delete(user.updateCh, event.ID)
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressDeleted{
|
||||
UserID: user.apiUser.ID,
|
||||
AddressID: event.ID,
|
||||
Email: addr.Email,
|
||||
})
|
||||
|
||||
return nil
|
||||
}, user.apiAddrsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
// handleLabelEvents handles the given label events.
|
||||
func (user *User) handleLabelEvents(ctx context.Context, labelEvents []proton.LabelEvent) error {
|
||||
for _, event := range labelEvents {
|
||||
switch event.Action {
|
||||
case proton.EventCreate:
|
||||
updates, err := user.handleCreateLabelEvent(ctx, event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to handle create label event: %w", err)
|
||||
}
|
||||
|
||||
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case proton.EventUpdate, proton.EventUpdateFlags:
|
||||
updates, err := user.handleUpdateLabelEvent(ctx, event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to handle update label event: %w", err)
|
||||
}
|
||||
|
||||
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case proton.EventDelete:
|
||||
updates, err := user.handleDeleteLabelEvent(ctx, event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to handle delete label event: %w", err)
|
||||
}
|
||||
|
||||
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
|
||||
return fmt.Errorf("failed to handle delete label event in gluon: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) handleCreateLabelEvent(_ context.Context, event proton.LabelEvent) ([]imap.Update, error) { //nolint:unparam
|
||||
return safe.LockRetErr(func() ([]imap.Update, error) {
|
||||
var updates []imap.Update
|
||||
|
||||
user.log.WithFields(logrus.Fields{
|
||||
"labelID": event.ID,
|
||||
"name": logging.Sensitive(event.Label.Name),
|
||||
}).Info("Handling label created event")
|
||||
|
||||
user.apiLabels[event.Label.ID] = event.Label
|
||||
|
||||
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
|
||||
update := newMailboxCreatedUpdate(imap.MailboxID(event.ID), getMailboxName(event.Label))
|
||||
updateCh.Enqueue(update)
|
||||
updates = append(updates, update)
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserLabelCreated{
|
||||
UserID: user.apiUser.ID,
|
||||
LabelID: event.Label.ID,
|
||||
Name: event.Label.Name,
|
||||
})
|
||||
|
||||
return updates, nil
|
||||
}, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
func (user *User) handleUpdateLabelEvent(ctx context.Context, event proton.LabelEvent) ([]imap.Update, error) { //nolint:unparam
|
||||
return safe.LockRetErr(func() ([]imap.Update, error) {
|
||||
var updates []imap.Update
|
||||
|
||||
user.log.WithFields(logrus.Fields{
|
||||
"labelID": event.ID,
|
||||
"name": logging.Sensitive(event.Label.Name),
|
||||
}).Info("Handling label updated event")
|
||||
|
||||
stack := []proton.Label{event.Label}
|
||||
|
||||
for len(stack) > 0 {
|
||||
label := stack[0]
|
||||
stack = stack[1:]
|
||||
|
||||
// Only update the label if it exists; we don't want to create it as a client may have just deleted it.
|
||||
if _, ok := user.apiLabels[label.ID]; ok {
|
||||
user.apiLabels[label.ID] = event.Label
|
||||
}
|
||||
|
||||
// API doesn't notify us that the path has changed. We need to fetch it again.
|
||||
apiLabel, err := user.client.GetLabel(ctx, label.ID, label.Type)
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
||||
user.log.WithError(apiErr).Warn("Failed to get label: label does not exist")
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("failed to get label %q: %w", label.ID, err)
|
||||
}
|
||||
|
||||
// Update the label in the map.
|
||||
user.apiLabels[apiLabel.ID] = apiLabel
|
||||
|
||||
// Notify the IMAP clients.
|
||||
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
|
||||
update := imap.NewMailboxUpdated(
|
||||
imap.MailboxID(apiLabel.ID),
|
||||
getMailboxName(apiLabel),
|
||||
)
|
||||
updateCh.Enqueue(update)
|
||||
updates = append(updates, update)
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserLabelUpdated{
|
||||
UserID: user.apiUser.ID,
|
||||
LabelID: apiLabel.ID,
|
||||
Name: apiLabel.Name,
|
||||
})
|
||||
|
||||
children := xslices.Filter(maps.Values(user.apiLabels), func(other proton.Label) bool {
|
||||
return other.ParentID == label.ID
|
||||
})
|
||||
|
||||
stack = append(stack, children...)
|
||||
}
|
||||
|
||||
return updates, nil
|
||||
}, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
func (user *User) handleDeleteLabelEvent(_ context.Context, event proton.LabelEvent) ([]imap.Update, error) { //nolint:unparam
|
||||
return safe.LockRetErr(func() ([]imap.Update, error) {
|
||||
var updates []imap.Update
|
||||
|
||||
user.log.WithField("labelID", event.ID).Info("Handling label deleted event")
|
||||
|
||||
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
|
||||
update := imap.NewMailboxDeleted(imap.MailboxID(event.ID))
|
||||
updateCh.Enqueue(update)
|
||||
updates = append(updates, update)
|
||||
}
|
||||
|
||||
delete(user.apiLabels, event.ID)
|
||||
|
||||
user.eventCh.Enqueue(events.UserLabelDeleted{
|
||||
UserID: user.apiUser.ID,
|
||||
LabelID: event.ID,
|
||||
})
|
||||
|
||||
return updates, nil
|
||||
}, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
// handleMessageEvents handles the given message events.
|
||||
func (user *User) handleMessageEvents(ctx context.Context, messageEvents []proton.MessageEvent) error {
|
||||
for _, event := range messageEvents {
|
||||
ctx = logging.WithLogrusField(ctx, "messageID", event.ID)
|
||||
|
||||
switch event.Action {
|
||||
case proton.EventCreate:
|
||||
updates, err := user.handleCreateMessageEvent(logging.WithLogrusField(ctx, "action", "create message"), event.Message)
|
||||
if err != nil {
|
||||
user.reportError("Failed to apply create message event", err)
|
||||
return fmt.Errorf("failed to handle create message event: %w", err)
|
||||
}
|
||||
|
||||
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case proton.EventUpdate, proton.EventUpdateFlags:
|
||||
// Draft update means to completely remove old message and upload the new data again, but we should
|
||||
// only do this if the event is of type EventUpdate, otherwise label switch operations will not work.
|
||||
if (event.Message.IsDraft() || (event.Message.Flags&proton.MessageFlagSent != 0)) && event.Action == proton.EventUpdate {
|
||||
updates, err := user.handleUpdateDraftOrSentMessage(
|
||||
logging.WithLogrusField(ctx, "action", "update draft or sent message"),
|
||||
event,
|
||||
)
|
||||
if err != nil {
|
||||
user.reportError("Failed to apply update draft message event", err)
|
||||
return fmt.Errorf("failed to handle update draft event: %w", err)
|
||||
}
|
||||
|
||||
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// GODT-2028 - Use better events here. It should be possible to have 3 separate events that indicate
|
||||
// whether the flags, labels or read-only data (header+body) have been changed. This requires fixing proton
|
||||
// first so that it correctly reports those cases.
|
||||
// Issue regular update to handle mailboxes and flag changes.
|
||||
updates, err := user.handleUpdateMessageEvent(
|
||||
logging.WithLogrusField(ctx, "action", "update message"),
|
||||
event.Message,
|
||||
)
|
||||
if err != nil {
|
||||
user.reportError("Failed to apply update message event", err)
|
||||
return fmt.Errorf("failed to handle update message event: %w", err)
|
||||
}
|
||||
|
||||
// If the update fails on the gluon side because it doesn't exist, we try to create the message instead.
|
||||
if err := waitOnIMAPUpdates(ctx, updates); gluon.IsNoSuchMessage(err) {
|
||||
user.log.WithError(err).Error("Failed to handle update message event in gluon, will try creating it")
|
||||
|
||||
updates, err := user.handleCreateMessageEvent(ctx, event.Message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to handle update message event as create: %w", err)
|
||||
}
|
||||
|
||||
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case proton.EventDelete:
|
||||
updates, err := user.handleDeleteMessageEvent(
|
||||
logging.WithLogrusField(ctx, "action", "delete message"),
|
||||
event,
|
||||
)
|
||||
if err != nil {
|
||||
user.reportError("Failed to apply delete message event", err)
|
||||
return fmt.Errorf("failed to handle delete message event: %w", err)
|
||||
}
|
||||
|
||||
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
|
||||
return fmt.Errorf("failed to handle delete message event in gluon: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.MessageMetadata) ([]imap.Update, error) {
|
||||
user.log.WithFields(logrus.Fields{
|
||||
"messageID": message.ID,
|
||||
"subject": logging.Sensitive(message.Subject),
|
||||
}).Info("Handling message created event")
|
||||
|
||||
full, err := user.client.GetFullMessage(ctx, message.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
// If the message is not found, it means that it has been deleted before we could fetch it.
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
||||
user.log.WithField("messageID", message.ID).Warn("Cannot create new message: full message is missing on API")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to get full message: %w", err)
|
||||
}
|
||||
|
||||
return safe.RLockRetErr(func() ([]imap.Update, error) {
|
||||
var update imap.Update
|
||||
|
||||
if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
|
||||
|
||||
if res.err != nil {
|
||||
user.log.WithError(res.err).Error("Failed to build RFC822 message")
|
||||
|
||||
if err := user.vault.AddFailedMessageID(message.ID); err != nil {
|
||||
user.log.WithError(err).Error("Failed to add failed message ID to vault")
|
||||
}
|
||||
|
||||
user.reportErrorAndMessageID("Failed to build message (event create)", res.err, res.messageID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := user.vault.RemFailedMessageID(message.ID); err != nil {
|
||||
user.log.WithError(err).Error("Failed to remove failed message ID from vault")
|
||||
}
|
||||
|
||||
update = imap.NewMessagesCreated(false, res.update)
|
||||
didPublish, err := safePublishMessageUpdate(user, full.AddressID, update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !didPublish {
|
||||
update = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if update == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return []imap.Update{update}, nil
|
||||
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
func (user *User) handleUpdateMessageEvent(_ context.Context, message proton.MessageMetadata) ([]imap.Update, error) { //nolint:unparam
|
||||
return safe.RLockRetErr(func() ([]imap.Update, error) {
|
||||
user.log.WithFields(logrus.Fields{
|
||||
"messageID": message.ID,
|
||||
"subject": logging.Sensitive(message.Subject),
|
||||
}).Info("Handling message updated event")
|
||||
|
||||
flags := buildFlagSetFromMessageMetadata(message)
|
||||
|
||||
update := imap.NewMessageMailboxesUpdated(
|
||||
imap.MessageID(message.ID),
|
||||
usertypes.MapTo[string, imap.MailboxID](wantLabels(user.apiLabels, message.LabelIDs)),
|
||||
flags,
|
||||
)
|
||||
|
||||
didPublish, err := safePublishMessageUpdate(user, message.AddressID, update)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !didPublish {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return []imap.Update{update}, nil
|
||||
}, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
func (user *User) handleDeleteMessageEvent(_ context.Context, event proton.MessageEvent) ([]imap.Update, error) {
|
||||
return safe.RLockRetErr(func() ([]imap.Update, error) {
|
||||
user.log.WithField("messageID", event.ID).Info("Handling message deleted event")
|
||||
|
||||
var updates []imap.Update
|
||||
|
||||
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
|
||||
update := imap.NewMessagesDeleted(imap.MessageID(event.ID))
|
||||
updateCh.Enqueue(update)
|
||||
updates = append(updates, update)
|
||||
}
|
||||
|
||||
return updates, nil
|
||||
}, user.updateChLock)
|
||||
}
|
||||
|
||||
func (user *User) handleUpdateDraftOrSentMessage(ctx context.Context, event proton.MessageEvent) ([]imap.Update, error) {
|
||||
return safe.RLockRetErr(func() ([]imap.Update, error) {
|
||||
user.log.WithFields(logrus.Fields{
|
||||
"messageID": event.ID,
|
||||
"subject": logging.Sensitive(event.Message.Subject),
|
||||
"isDraft": event.Message.IsDraft(),
|
||||
}).Info("Handling draft or sent updated event")
|
||||
|
||||
full, err := user.client.GetFullMessage(ctx, event.Message.ID, usertypes.NewProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
// If the message is not found, it means that it has been deleted before we could fetch it.
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
||||
user.log.WithField("messageID", event.Message.ID).Warn("Cannot update message: full message is missing on API")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to get full draft: %w", err)
|
||||
}
|
||||
|
||||
var update imap.Update
|
||||
|
||||
if err := usertypes.WithAddrKR(user.apiUser, user.apiAddrs[event.Message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
|
||||
|
||||
if res.err != nil {
|
||||
logrus.WithError(res.err).Error("Failed to build RFC822 message")
|
||||
|
||||
if err := user.vault.AddFailedMessageID(event.ID); err != nil {
|
||||
user.log.WithError(err).Error("Failed to add failed message ID to vault")
|
||||
}
|
||||
|
||||
user.reportErrorAndMessageID("Failed to build draft message (event update)", res.err, res.messageID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := user.vault.RemFailedMessageID(event.ID); err != nil {
|
||||
user.log.WithError(err).Error("Failed to remove failed message ID from vault")
|
||||
}
|
||||
|
||||
update = imap.NewMessageUpdated(
|
||||
res.update.Message,
|
||||
res.update.Literal,
|
||||
res.update.MailboxIDs,
|
||||
res.update.ParsedMessage,
|
||||
true, // If the message doesn't exist, silently create it.
|
||||
)
|
||||
|
||||
didPublish, err := safePublishMessageUpdate(user, full.AddressID, update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !didPublish {
|
||||
update = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if update == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return []imap.Update{update}, nil
|
||||
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
func (user *User) handleUsedSpaceChange(usedSpace int) {
|
||||
safe.Lock(func() {
|
||||
if user.apiUser.UsedSpace == usedSpace {
|
||||
return
|
||||
}
|
||||
|
||||
user.apiUser.UsedSpace = usedSpace
|
||||
user.eventCh.Enqueue(events.UsedSpaceChanged{
|
||||
UserID: user.apiUser.ID,
|
||||
UsedSpace: usedSpace,
|
||||
})
|
||||
}, user.apiUserLock)
|
||||
}
|
||||
|
||||
func getMailboxName(label proton.Label) []string {
|
||||
var name []string
|
||||
|
||||
switch label.Type {
|
||||
case proton.LabelTypeFolder:
|
||||
name = append([]string{folderPrefix}, label.Path...)
|
||||
|
||||
case proton.LabelTypeLabel:
|
||||
name = append([]string{labelPrefix}, label.Path...)
|
||||
|
||||
case proton.LabelTypeContactGroup:
|
||||
fallthrough
|
||||
case proton.LabelTypeSystem:
|
||||
fallthrough
|
||||
default:
|
||||
name = label.Path
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
||||
|
||||
func waitOnIMAPUpdates(ctx context.Context, updates []imap.Update) error {
|
||||
for _, update := range updates {
|
||||
if err, ok := update.WaitContext(ctx); ok && err != nil {
|
||||
return fmt.Errorf("failed to apply gluon update %v: %w", update.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) reportError(title string, err error) {
|
||||
user.reportErrorNoContextCancel(title, err, reporter.Context{})
|
||||
}
|
||||
|
||||
func (user *User) reportErrorAndMessageID(title string, err error, messageID string) {
|
||||
user.reportErrorNoContextCancel(title, err, reporter.Context{"messageID": messageID})
|
||||
}
|
||||
|
||||
func (user *User) reportErrorNoContextCancel(title string, err error, reportContext reporter.Context) {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
reportContext["error"] = err
|
||||
reportContext["error_type"] = internal.ErrCauseType(err)
|
||||
if rerr := user.reporter.ReportMessageWithContext(title, reportContext); rerr != nil {
|
||||
user.log.WithError(err).WithField("title", title).Error("Failed to report message")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// safePublishMessageUpdate handles the rare case where the address' update channel may have been deleted in the same
|
||||
// event. This rare case can occur when the same event fetch request contains both an address deletion and a message
|
||||
// create or update.
|
||||
// If the user is in combined mode, we simply push the update to the primary address. If the user is in split mode
|
||||
// we do not publish the update as the address no longer exists.
|
||||
func safePublishMessageUpdate(user *User, addressID string, update imap.Update) (bool, error) {
|
||||
v, ok := user.updateCh[addressID]
|
||||
if !ok {
|
||||
if user.GetAddressMode() == vault.CombinedMode {
|
||||
primAddr, err := usertypes.GetPrimaryAddr(user.apiAddrs)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get primary address: %w", err)
|
||||
}
|
||||
primaryCh, ok := user.updateCh[primAddr.ID]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("primary address channel is not available")
|
||||
}
|
||||
|
||||
primaryCh.Enqueue(update)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
logrus.Warnf("Update channel not found for address %v, it may have been already deleted", addressID)
|
||||
_ = user.reporter.ReportMessage("Message Update channel does not exist")
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
v.Enqueue(update)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
@@ -1,750 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/connector"
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/rfc5322"
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/message/parser"
|
||||
"github.com/bradenaw/juniper/stream"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// Verify that *imapConnector implements connector.Connector.
|
||||
var _ connector.Connector = (*imapConnector)(nil)
|
||||
|
||||
var (
|
||||
defaultFlags = imap.NewFlagSet(imap.FlagSeen, imap.FlagFlagged, imap.FlagDeleted) // nolint:gochecknoglobals
|
||||
defaultPermanentFlags = imap.NewFlagSet(imap.FlagSeen, imap.FlagFlagged, imap.FlagDeleted) // nolint:gochecknoglobals
|
||||
defaultAttributes = imap.NewFlagSet() // nolint:gochecknoglobals
|
||||
)
|
||||
|
||||
const (
|
||||
folderPrefix = "Folders"
|
||||
labelPrefix = "Labels"
|
||||
)
|
||||
|
||||
type imapConnector struct {
|
||||
*User
|
||||
|
||||
addrID string
|
||||
|
||||
flags, permFlags, attrs imap.FlagSet
|
||||
}
|
||||
|
||||
func newIMAPConnector(user *User, addrID string) *imapConnector {
|
||||
return &imapConnector{
|
||||
User: user,
|
||||
|
||||
addrID: addrID,
|
||||
|
||||
flags: defaultFlags,
|
||||
permFlags: defaultPermanentFlags,
|
||||
attrs: defaultAttributes,
|
||||
}
|
||||
}
|
||||
|
||||
// Authorize returns whether the given username/password combination is valid for this connector.
|
||||
func (conn *imapConnector) Authorize(ctx context.Context, username string, password []byte) bool {
|
||||
addrID, err := conn.CheckAuth(username, password)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if conn.vault.AddressMode() == vault.SplitMode && addrID != conn.addrID {
|
||||
return false
|
||||
}
|
||||
|
||||
conn.User.SendConfigStatusSuccess(ctx)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// CreateMailbox creates a label with the given name.
|
||||
func (conn *imapConnector) CreateMailbox(ctx context.Context, name []string) (imap.Mailbox, error) {
|
||||
defer conn.goPollAPIEvents(false)
|
||||
|
||||
if len(name) < 2 {
|
||||
return imap.Mailbox{}, fmt.Errorf("invalid mailbox name %q: %w", name, connector.ErrOperationNotAllowed)
|
||||
}
|
||||
|
||||
switch name[0] {
|
||||
case folderPrefix:
|
||||
return conn.createFolder(ctx, name[1:])
|
||||
|
||||
case labelPrefix:
|
||||
return conn.createLabel(ctx, name[1:])
|
||||
|
||||
default:
|
||||
return imap.Mailbox{}, fmt.Errorf("invalid mailbox name %q: %w", name, connector.ErrOperationNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *imapConnector) createLabel(ctx context.Context, name []string) (imap.Mailbox, error) {
|
||||
if len(name) != 1 {
|
||||
return imap.Mailbox{}, fmt.Errorf("a label cannot have children: %w", connector.ErrOperationNotAllowed)
|
||||
}
|
||||
|
||||
return safe.LockRetErr(func() (imap.Mailbox, error) {
|
||||
label, err := conn.client.CreateLabel(ctx, proton.CreateLabelReq{
|
||||
Name: name[0],
|
||||
Color: "#f66",
|
||||
Type: proton.LabelTypeLabel,
|
||||
})
|
||||
if err != nil {
|
||||
return imap.Mailbox{}, err
|
||||
}
|
||||
|
||||
conn.apiLabels[label.ID] = label
|
||||
|
||||
return toIMAPMailbox(label, conn.flags, conn.permFlags, conn.attrs), nil
|
||||
}, conn.apiLabelsLock)
|
||||
}
|
||||
|
||||
func (conn *imapConnector) createFolder(ctx context.Context, name []string) (imap.Mailbox, error) {
|
||||
return safe.LockRetErr(func() (imap.Mailbox, error) {
|
||||
var parentID string
|
||||
|
||||
if len(name) > 1 {
|
||||
for _, label := range conn.apiLabels {
|
||||
if !slices.Equal(label.Path, name[:len(name)-1]) {
|
||||
continue
|
||||
}
|
||||
|
||||
parentID = label.ID
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if parentID == "" {
|
||||
return imap.Mailbox{}, fmt.Errorf("parent folder %q does not exist: %w", name[:len(name)-1], connector.ErrOperationNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
label, err := conn.client.CreateLabel(ctx, proton.CreateLabelReq{
|
||||
Name: name[len(name)-1],
|
||||
Color: "#f66",
|
||||
Type: proton.LabelTypeFolder,
|
||||
ParentID: parentID,
|
||||
})
|
||||
if err != nil {
|
||||
return imap.Mailbox{}, err
|
||||
}
|
||||
|
||||
// Add label to the list so that subsequent subfolder create requests work correctly.
|
||||
conn.apiLabels[label.ID] = label
|
||||
|
||||
return toIMAPMailbox(label, conn.flags, conn.permFlags, conn.attrs), nil
|
||||
}, conn.apiLabelsLock)
|
||||
}
|
||||
|
||||
// UpdateMailboxName sets the name of the label with the given ID.
|
||||
func (conn *imapConnector) UpdateMailboxName(ctx context.Context, labelID imap.MailboxID, name []string) error {
|
||||
return safe.LockRet(func() error {
|
||||
defer conn.goPollAPIEvents(false)
|
||||
|
||||
if len(name) < 2 {
|
||||
return fmt.Errorf("invalid mailbox name %q: %w", name, connector.ErrOperationNotAllowed)
|
||||
}
|
||||
|
||||
switch name[0] {
|
||||
case folderPrefix:
|
||||
return conn.updateFolder(ctx, labelID, name[1:])
|
||||
|
||||
case labelPrefix:
|
||||
return conn.updateLabel(ctx, labelID, name[1:])
|
||||
|
||||
default:
|
||||
return fmt.Errorf("invalid mailbox name %q: %w", name, connector.ErrOperationNotAllowed)
|
||||
}
|
||||
}, conn.apiLabelsLock)
|
||||
}
|
||||
|
||||
func (conn *imapConnector) updateLabel(ctx context.Context, labelID imap.MailboxID, name []string) error {
|
||||
if len(name) != 1 {
|
||||
return fmt.Errorf("a label cannot have children: %w", connector.ErrOperationNotAllowed)
|
||||
}
|
||||
|
||||
label, err := conn.client.GetLabel(ctx, string(labelID), proton.LabelTypeLabel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
update, err := conn.client.UpdateLabel(ctx, label.ID, proton.UpdateLabelReq{
|
||||
Name: name[0],
|
||||
Color: label.Color,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.apiLabels[label.ID] = update
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *imapConnector) updateFolder(ctx context.Context, labelID imap.MailboxID, name []string) error {
|
||||
var parentID string
|
||||
|
||||
if len(name) > 1 {
|
||||
for _, label := range conn.apiLabels {
|
||||
if !slices.Equal(label.Path, name[:len(name)-1]) {
|
||||
continue
|
||||
}
|
||||
|
||||
parentID = label.ID
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if parentID == "" {
|
||||
return fmt.Errorf("parent folder %q does not exist: %w", name[:len(name)-1], connector.ErrOperationNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
label, err := conn.client.GetLabel(ctx, string(labelID), proton.LabelTypeFolder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
update, err := conn.client.UpdateLabel(ctx, string(labelID), proton.UpdateLabelReq{
|
||||
Name: name[len(name)-1],
|
||||
Color: label.Color,
|
||||
ParentID: parentID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.apiLabels[label.ID] = update
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteMailbox deletes the label with the given ID.
|
||||
func (conn *imapConnector) DeleteMailbox(ctx context.Context, labelID imap.MailboxID) error {
|
||||
return safe.LockRet(func() error {
|
||||
defer conn.goPollAPIEvents(false)
|
||||
|
||||
if err := conn.client.DeleteLabel(ctx, string(labelID)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(conn.apiLabels, string(labelID))
|
||||
|
||||
return nil
|
||||
}, conn.apiLabelsLock)
|
||||
}
|
||||
|
||||
// CreateMessage creates a new message on the remote.
|
||||
func (conn *imapConnector) CreateMessage(
|
||||
ctx context.Context,
|
||||
mailboxID imap.MailboxID,
|
||||
literal []byte,
|
||||
flags imap.FlagSet,
|
||||
_ time.Time,
|
||||
) (imap.Message, []byte, error) {
|
||||
defer conn.goPollAPIEvents(false)
|
||||
|
||||
if mailboxID == proton.AllMailLabel {
|
||||
return imap.Message{}, nil, connector.ErrOperationNotAllowed
|
||||
}
|
||||
|
||||
toList, err := getLiteralToList(literal)
|
||||
if err != nil {
|
||||
return imap.Message{}, nil, fmt.Errorf("failed to retrieve addresses from literal:%w", err)
|
||||
}
|
||||
|
||||
// Compute the hash of the message (to match it against SMTP messages).
|
||||
hash, err := sendrecorder.GetMessageHash(literal)
|
||||
if err != nil {
|
||||
return imap.Message{}, nil, err
|
||||
}
|
||||
|
||||
// Check if we already tried to send this message recently.
|
||||
if messageID, ok, err := conn.sendHash.HasEntryWait(ctx, hash, time.Now().Add(90*time.Second), toList); err != nil {
|
||||
return imap.Message{}, nil, fmt.Errorf("failed to check send hash: %w", err)
|
||||
} else if ok {
|
||||
conn.log.WithField("messageID", messageID).Warn("Message already sent")
|
||||
|
||||
// Query the server-side message.
|
||||
full, err := conn.client.GetFullMessage(ctx, messageID, usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
return imap.Message{}, nil, fmt.Errorf("failed to fetch message: %w", err)
|
||||
}
|
||||
|
||||
// Build the message as it is on the server.
|
||||
if err := safe.RLockRet(func() error {
|
||||
return usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[full.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
var err error
|
||||
|
||||
if literal, err = message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}, conn.apiUserLock, conn.apiAddrsLock); err != nil {
|
||||
return imap.Message{}, nil, fmt.Errorf("failed to build message: %w", err)
|
||||
}
|
||||
|
||||
return toIMAPMessage(full.MessageMetadata), literal, nil
|
||||
}
|
||||
|
||||
wantLabelIDs := []string{string(mailboxID)}
|
||||
|
||||
if flags.Contains(imap.FlagFlagged) {
|
||||
wantLabelIDs = append(wantLabelIDs, proton.StarredLabel)
|
||||
}
|
||||
|
||||
var wantFlags proton.MessageFlag
|
||||
|
||||
unread := !flags.Contains(imap.FlagSeen)
|
||||
|
||||
if mailboxID != proton.DraftsLabel {
|
||||
header, err := rfc822.Parse(literal).ParseHeader()
|
||||
if err != nil {
|
||||
return imap.Message{}, nil, err
|
||||
}
|
||||
|
||||
switch {
|
||||
case mailboxID == proton.InboxLabel:
|
||||
wantFlags = wantFlags.Add(proton.MessageFlagReceived)
|
||||
|
||||
case mailboxID == proton.SentLabel:
|
||||
wantFlags = wantFlags.Add(proton.MessageFlagSent)
|
||||
|
||||
case header.Has("Received"):
|
||||
wantFlags = wantFlags.Add(proton.MessageFlagReceived)
|
||||
|
||||
default:
|
||||
wantFlags = wantFlags.Add(proton.MessageFlagSent)
|
||||
}
|
||||
} else {
|
||||
unread = false
|
||||
}
|
||||
|
||||
if flags.Contains(imap.FlagAnswered) {
|
||||
wantFlags = wantFlags.Add(proton.MessageFlagReplied)
|
||||
}
|
||||
|
||||
msg, literal, err := conn.importMessage(ctx, literal, wantLabelIDs, wantFlags, unread)
|
||||
if err != nil {
|
||||
if errors.Is(err, proton.ErrImportSizeExceeded) {
|
||||
// Remap error so that Gluon does not put this message in the recovery mailbox.
|
||||
err = fmt.Errorf("%v: %w", err, connector.ErrMessageSizeExceedsLimits)
|
||||
}
|
||||
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) {
|
||||
logrus.WithError(apiErr).WithField("Details", apiErr.DetailsToString()).Error("Failed to import message")
|
||||
} else {
|
||||
logrus.WithError(err).Error("Failed to import message")
|
||||
}
|
||||
}
|
||||
|
||||
return msg, literal, err
|
||||
}
|
||||
|
||||
func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) {
|
||||
msg, err := conn.client.GetFullMessage(ctx, string(id), usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return safe.RLockRetErr(func() ([]byte, error) {
|
||||
var literal []byte
|
||||
err := usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[msg.AddressID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
l, buildErr := message.BuildRFC822(addrKR, msg.Message, msg.AttData, defaultJobOpts())
|
||||
if buildErr != nil {
|
||||
return buildErr
|
||||
}
|
||||
|
||||
literal = l
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return literal, err
|
||||
}, conn.apiUserLock, conn.apiAddrsLock)
|
||||
}
|
||||
|
||||
// AddMessagesToMailbox labels the given messages with the given label ID.
|
||||
func (conn *imapConnector) AddMessagesToMailbox(ctx context.Context, messageIDs []imap.MessageID, mailboxID imap.MailboxID) error {
|
||||
defer conn.goPollAPIEvents(false)
|
||||
|
||||
if isAllMailOrScheduled(mailboxID) {
|
||||
return connector.ErrOperationNotAllowed
|
||||
}
|
||||
|
||||
return conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(mailboxID))
|
||||
}
|
||||
|
||||
// RemoveMessagesFromMailbox unlabels the given messages with the given label ID.
|
||||
func (conn *imapConnector) RemoveMessagesFromMailbox(ctx context.Context, messageIDs []imap.MessageID, mailboxID imap.MailboxID) error {
|
||||
defer conn.goPollAPIEvents(false)
|
||||
|
||||
if isAllMailOrScheduled(mailboxID) {
|
||||
return connector.ErrOperationNotAllowed
|
||||
}
|
||||
|
||||
msgIDs := usertypes.MapTo[imap.MessageID, string](messageIDs)
|
||||
if err := conn.client.UnlabelMessages(ctx, msgIDs, string(mailboxID)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if mailboxID == proton.TrashLabel || mailboxID == proton.DraftsLabel {
|
||||
if err := conn.client.DeleteMessage(ctx, msgIDs...); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MoveMessages removes the given messages from one label and adds them to the other label.
|
||||
func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.MessageID, labelFromID imap.MailboxID, labelToID imap.MailboxID) (bool, error) {
|
||||
defer conn.goPollAPIEvents(false)
|
||||
|
||||
if (labelFromID == proton.InboxLabel && labelToID == proton.SentLabel) ||
|
||||
(labelFromID == proton.SentLabel && labelToID == proton.InboxLabel) ||
|
||||
isAllMailOrScheduled(labelFromID) ||
|
||||
isAllMailOrScheduled(labelToID) {
|
||||
return false, connector.ErrOperationNotAllowed
|
||||
}
|
||||
|
||||
shouldExpungeOldLocation := func() bool {
|
||||
conn.apiLabelsLock.RLock()
|
||||
defer conn.apiLabelsLock.RUnlock()
|
||||
|
||||
var result bool
|
||||
|
||||
if v, ok := conn.apiLabels[string(labelFromID)]; ok && v.Type == proton.LabelTypeLabel {
|
||||
result = true
|
||||
}
|
||||
|
||||
if v, ok := conn.apiLabels[string(labelToID)]; ok && (v.Type == proton.LabelTypeFolder || v.Type == proton.LabelTypeSystem) {
|
||||
result = true
|
||||
}
|
||||
|
||||
return result
|
||||
}()
|
||||
|
||||
if err := conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil {
|
||||
return false, fmt.Errorf("labeling messages: %w", err)
|
||||
}
|
||||
|
||||
if shouldExpungeOldLocation {
|
||||
if err := conn.client.UnlabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(labelFromID)); err != nil {
|
||||
return false, fmt.Errorf("unlabeling messages: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return shouldExpungeOldLocation, nil
|
||||
}
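The expunge decision above reduces to the label types involved. A minimal sketch of that rule, assuming only the types matter (the real code additionally checks that both IDs exist in conn.apiLabels); shouldExpunge is a hypothetical helper, not part of the change:

func shouldExpunge(from, to proton.LabelType) bool {
// Any move out of a user label, or into a folder or system folder, also removes the
// message from its source. Moving from a folder into a user label only adds the label;
// the message stays where it was.
return from == proton.LabelTypeLabel || to == proton.LabelTypeFolder || to == proton.LabelTypeSystem
}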
|
||||
|
||||
// MarkMessagesSeen sets the seen value of the given messages.
|
||||
func (conn *imapConnector) MarkMessagesSeen(ctx context.Context, messageIDs []imap.MessageID, seen bool) error {
|
||||
defer conn.goPollAPIEvents(false)
|
||||
|
||||
if seen {
|
||||
return conn.client.MarkMessagesRead(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs)...)
|
||||
}
|
||||
|
||||
return conn.client.MarkMessagesUnread(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs)...)
|
||||
}
|
||||
|
||||
// MarkMessagesFlagged sets the flagged value of the given messages.
|
||||
func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs []imap.MessageID, flagged bool) error {
|
||||
defer conn.goPollAPIEvents(false)
|
||||
|
||||
if flagged {
|
||||
return conn.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
|
||||
}
|
||||
|
||||
return conn.client.UnlabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), proton.StarredLabel)
|
||||
}
|
||||
|
||||
// GetUpdates returns a stream of updates that the gluon server should apply.
|
||||
// It is recommended that the returned channel is buffered with at least constants.ChannelBufferCount.
|
||||
func (conn *imapConnector) GetUpdates() <-chan imap.Update {
|
||||
return safe.RLockRet(func() <-chan imap.Update {
|
||||
return conn.updateCh[conn.addrID].GetChannel()
|
||||
}, conn.updateChLock)
|
||||
}
|
||||
|
||||
// GetMailboxVisibility returns the visibility of a mailbox over IMAP.
|
||||
func (conn *imapConnector) GetMailboxVisibility(_ context.Context, mailboxID imap.MailboxID) imap.MailboxVisibility {
|
||||
switch mailboxID {
|
||||
case proton.AllMailLabel:
|
||||
if atomic.LoadUint32(&conn.showAllMail) != 0 {
|
||||
return imap.Visible
|
||||
}
|
||||
return imap.Hidden
|
||||
|
||||
case proton.AllScheduledLabel:
|
||||
return imap.HiddenIfEmpty
|
||||
default:
|
||||
return imap.Visible
|
||||
}
|
||||
}
|
||||
|
||||
// Close is called when the connector will no longer be used; all resources should be closed/released.
|
||||
func (conn *imapConnector) Close(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *imapConnector) importMessage(
|
||||
ctx context.Context,
|
||||
literal []byte,
|
||||
labelIDs []string,
|
||||
flags proton.MessageFlag,
|
||||
unread bool,
|
||||
) (imap.Message, []byte, error) {
|
||||
var full proton.FullMessage
|
||||
|
||||
if err := safe.RLockRet(func() error {
|
||||
return usertypes.WithAddrKR(conn.apiUser, conn.apiAddrs[conn.addrID], conn.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
var messageID string
|
||||
|
||||
if slices.Contains(labelIDs, proton.DraftsLabel) {
|
||||
msg, err := conn.createDraft(ctx, literal, addrKR, conn.apiAddrs[conn.addrID])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create draft: %w", err)
|
||||
}
|
||||
|
||||
// apply labels
|
||||
|
||||
messageID = msg.ID
|
||||
} else {
|
||||
str, err := conn.client.ImportMessages(ctx, addrKR, 1, 1, []proton.ImportReq{{
|
||||
Metadata: proton.ImportMetadata{
|
||||
AddressID: conn.addrID,
|
||||
LabelIDs: labelIDs,
|
||||
Unread: proton.Bool(unread),
|
||||
Flags: flags,
|
||||
},
|
||||
Message: literal,
|
||||
}}...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare message for import: %w", err)
|
||||
}
|
||||
|
||||
res, err := stream.Collect(ctx, str)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to import message: %w", err)
|
||||
}
|
||||
|
||||
messageID = res[0].MessageID
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
if full, err = conn.client.GetFullMessage(ctx, messageID, usertypes.NewProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()); err != nil {
|
||||
return fmt.Errorf("failed to fetch message: %w", err)
|
||||
}
|
||||
|
||||
if literal, err = message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()); err != nil {
|
||||
return fmt.Errorf("failed to build message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}, conn.apiUserLock, conn.apiAddrsLock); err != nil {
|
||||
return imap.Message{}, nil, err
|
||||
}
|
||||
|
||||
return toIMAPMessage(full.MessageMetadata), literal, nil
|
||||
}
|
||||
|
||||
func toIMAPMessage(message proton.MessageMetadata) imap.Message {
|
||||
flags := buildFlagSetFromMessageMetadata(message)
|
||||
|
||||
var date time.Time
|
||||
|
||||
if message.Time > 0 {
|
||||
date = time.Unix(message.Time, 0)
|
||||
} else {
|
||||
date = time.Now()
|
||||
}
|
||||
|
||||
return imap.Message{
|
||||
ID: imap.MessageID(message.ID),
|
||||
Flags: flags,
|
||||
Date: date,
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *imapConnector) createDraft(ctx context.Context, literal []byte, addrKR *crypto.KeyRing, sender proton.Address) (proton.Message, error) {
|
||||
// Create a new message parser from the reader.
|
||||
parser, err := parser.New(bytes.NewReader(literal))
|
||||
if err != nil {
|
||||
return proton.Message{}, fmt.Errorf("failed to create parser: %w", err)
|
||||
}
|
||||
|
||||
message, err := message.ParseWithParser(parser, true)
|
||||
if err != nil {
|
||||
return proton.Message{}, fmt.Errorf("failed to parse message: %w", err)
|
||||
}
|
||||
|
||||
decBody := string(message.PlainBody)
|
||||
if message.RichBody != "" {
|
||||
decBody = string(message.RichBody)
|
||||
}
|
||||
|
||||
draft, err := conn.client.CreateDraft(ctx, addrKR, proton.CreateDraftReq{
|
||||
Message: proton.DraftTemplate{
|
||||
Subject: message.Subject,
|
||||
Body: decBody,
|
||||
MIMEType: message.MIMEType,
|
||||
|
||||
Sender: &mail.Address{Name: sender.DisplayName, Address: sender.Email},
|
||||
ToList: message.ToList,
|
||||
CCList: message.CCList,
|
||||
BCCList: message.BCCList,
|
||||
|
||||
ExternalID: message.ExternalID,
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return proton.Message{}, fmt.Errorf("failed to create draft: %w", err)
|
||||
}
|
||||
|
||||
for _, att := range message.Attachments {
|
||||
disposition := proton.AttachmentDisposition
|
||||
if att.Disposition == "inline" && att.ContentID != "" {
|
||||
disposition = proton.InlineDisposition
|
||||
}
|
||||
|
||||
if _, err := conn.client.UploadAttachment(ctx, addrKR, proton.CreateAttachmentReq{
|
||||
MessageID: draft.ID,
|
||||
Filename: att.Name,
|
||||
MIMEType: rfc822.MIMEType(att.MIMEType),
|
||||
Disposition: disposition,
|
||||
ContentID: att.ContentID,
|
||||
Body: att.Data,
|
||||
}); err != nil {
|
||||
return proton.Message{}, fmt.Errorf("failed to add attachment to draft: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return draft, nil
|
||||
}
|
||||
|
||||
func toIMAPMailbox(label proton.Label, flags, permFlags, attrs imap.FlagSet) imap.Mailbox {
|
||||
if label.Type == proton.LabelTypeLabel {
|
||||
label.Path = append([]string{labelPrefix}, label.Path...)
|
||||
} else if label.Type == proton.LabelTypeFolder {
|
||||
label.Path = append([]string{folderPrefix}, label.Path...)
|
||||
}
|
||||
|
||||
return imap.Mailbox{
|
||||
ID: imap.MailboxID(label.ID),
|
||||
Name: label.Path,
|
||||
Flags: flags,
|
||||
PermanentFlags: permFlags,
|
||||
Attributes: attrs,
|
||||
}
|
||||
}
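For illustration, a hedged sketch of how the prefixing above surfaces over IMAP, assuming labelPrefix and folderPrefix resolve to the usual "Labels" and "Folders" roots used by Bridge:

folder := proton.Label{ID: "folder-id", Type: proton.LabelTypeFolder, Path: []string{"Work", "Receipts"}}
mbox := toIMAPMailbox(folder, defaultFlags, defaultPermanentFlags, imap.NewFlagSet())
// mbox.Name == []string{folderPrefix, "Work", "Receipts"}, so the mailbox appears as
// Folders/Work/Receipts; a user label would instead be rooted under Labels/, while
// system labels keep their path unchanged.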
|
||||
|
||||
func isAllMailOrScheduled(mailboxID imap.MailboxID) bool {
|
||||
return (mailboxID == proton.AllMailLabel) || (mailboxID == proton.AllScheduledLabel)
|
||||
}
|
||||
|
||||
func buildFlagSetFromMessageMetadata(message proton.MessageMetadata) imap.FlagSet {
|
||||
flags := imap.NewFlagSet()
|
||||
|
||||
if message.Seen() {
|
||||
flags.AddToSelf(imap.FlagSeen)
|
||||
}
|
||||
|
||||
if message.Starred() {
|
||||
flags.AddToSelf(imap.FlagFlagged)
|
||||
}
|
||||
|
||||
if message.IsDraft() {
|
||||
flags.AddToSelf(imap.FlagDraft)
|
||||
}
|
||||
|
||||
if message.IsRepliedAll == true || message.IsReplied == true { //nolint: gosimple
|
||||
flags.AddToSelf(imap.FlagAnswered)
|
||||
}
|
||||
|
||||
return flags
|
||||
}
|
||||
|
||||
func getLiteralToList(literal []byte) ([]string, error) {
|
||||
headerLiteral, _ := rfc822.Split(literal)
|
||||
|
||||
header, err := rfc822.NewHeader(headerLiteral)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result []string
|
||||
|
||||
parseAddress := func(field string) error {
|
||||
if fieldAddr, ok := header.GetChecked(field); ok {
|
||||
addr, err := rfc5322.ParseAddressList(fieldAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse addresses for '%v': %w", field, err)
|
||||
}
|
||||
|
||||
result = append(result, xslices.Map(addr, func(addr *mail.Address) string {
|
||||
return addr.Address
|
||||
})...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := parseAddress("To"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := parseAddress("Cc"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := parseAddress("Bcc"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@ -36,8 +36,14 @@ func BenchmarkAddrKeyRing(b *testing.B) {
|
||||
withUser(b, ctx, s, m, "username", "password", func(user *User) {
|
||||
b.StartTimer()
|
||||
|
||||
apiUser, err := user.identityService.GetAPIUser(ctx)
|
||||
require.NoError(b, err)
|
||||
|
||||
apiAddrs, err := user.identityService.GetAddresses(ctx)
|
||||
require.NoError(b, err)
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
require.NoError(b, usertypes.WithAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||
require.NoError(b, usertypes.WithAddrKRs(apiUser, apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
@ -1,918 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/logging"
|
||||
"github.com/ProtonMail/gluon/reporter"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/bradenaw/juniper/parallel"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/pbnjay/memory"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// syncSystemLabels ensures that system labels are all known to gluon.
|
||||
func (user *User) syncSystemLabels(ctx context.Context) error {
|
||||
return safe.RLockRet(func() error {
|
||||
var updates []imap.Update
|
||||
|
||||
for _, label := range xslices.Filter(maps.Values(user.apiLabels), func(label proton.Label) bool { return label.Type == proton.LabelTypeSystem }) {
|
||||
if !wantLabel(label) {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
|
||||
update := newSystemMailboxCreatedUpdate(imap.MailboxID(label.ID), label.Name)
|
||||
updateCh.Enqueue(update)
|
||||
updates = append(updates, update)
|
||||
}
|
||||
}
|
||||
|
||||
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
|
||||
return fmt.Errorf("could not sync system labels: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
// doSync begins syncing the user's data.
|
||||
// It first ensures the latest event ID is known; if not, it fetches it.
|
||||
// It sends a SyncStarted event and then either SyncFinished or SyncFailed
|
||||
// depending on whether the sync was successful.
|
||||
func (user *User) doSync(ctx context.Context) error {
|
||||
if user.vault.EventID() == "" {
|
||||
eventID, err := user.client.GetLatestEventID(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get latest event ID: %w", err)
|
||||
}
|
||||
|
||||
if err := user.vault.SetEventID(eventID); err != nil {
|
||||
return fmt.Errorf("failed to set latest event ID: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
user.log.WithField("start", start).Info("Beginning user sync")
|
||||
|
||||
user.eventCh.Enqueue(events.SyncStarted{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
|
||||
if err := user.sync(ctx); err != nil {
|
||||
user.log.WithError(err).Warn("Failed to sync user")
|
||||
|
||||
user.eventCh.Enqueue(events.SyncFailed{
|
||||
UserID: user.ID(),
|
||||
Error: err,
|
||||
})
|
||||
|
||||
return fmt.Errorf("failed to sync: %w", err)
|
||||
}
|
||||
|
||||
user.log.WithField("duration", time.Since(start)).Info("Finished user sync")
|
||||
|
||||
user.eventCh.Enqueue(events.SyncFinished{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) sync(ctx context.Context) error {
|
||||
return safe.RLockRet(func() error {
|
||||
return usertypes.WithAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||
if !user.vault.SyncStatus().HasLabels {
|
||||
user.log.Info("Syncing labels")
|
||||
|
||||
if err := syncLabels(ctx, user.apiLabels, xslices.Unique(maps.Values(user.updateCh))...); err != nil {
|
||||
return fmt.Errorf("failed to sync labels: %w", err)
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasLabels(true); err != nil {
|
||||
return fmt.Errorf("failed to set has labels: %w", err)
|
||||
}
|
||||
|
||||
user.log.Info("Synced labels")
|
||||
} else {
|
||||
user.log.Info("Labels are already synced, skipping")
|
||||
}
|
||||
|
||||
if !user.vault.SyncStatus().HasMessages {
|
||||
user.log.Info("Syncing messages")
|
||||
|
||||
// Determine which messages to sync.
|
||||
messageIDs, err := user.client.GetMessageIDs(ctx, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get message IDs to sync: %w", err)
|
||||
}
|
||||
|
||||
logrus.Debugf("User has the following failed synced message ids: %v", user.vault.SyncStatus().FailedMessageIDs)
|
||||
|
||||
// Remove any messages that have already failed to sync.
|
||||
messageIDs = xslices.Filter(messageIDs, func(messageID string) bool {
|
||||
return !slices.Contains(user.vault.SyncStatus().FailedMessageIDs, messageID)
|
||||
})
|
||||
|
||||
// Reverse the order of the message IDs so that the newest messages are synced first.
|
||||
xslices.Reverse(messageIDs)
|
||||
|
||||
// If we have a message ID that we've already synced, then we can skip all messages before it.
|
||||
if idx := xslices.Index(messageIDs, user.vault.SyncStatus().LastMessageID); idx >= 0 {
|
||||
messageIDs = messageIDs[idx+1:]
|
||||
}
|
||||
|
||||
// Sync the messages.
|
||||
if err := user.syncMessages(
|
||||
ctx,
|
||||
user.ID(),
|
||||
messageIDs,
|
||||
user.client,
|
||||
user.reporter,
|
||||
user.vault,
|
||||
user.apiLabels,
|
||||
addrKRs,
|
||||
user.updateCh,
|
||||
user.eventCh,
|
||||
user.maxSyncMemory,
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to sync messages: %w", err)
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasMessages(true); err != nil {
|
||||
return fmt.Errorf("failed to set has messages: %w", err)
|
||||
}
|
||||
|
||||
user.log.Info("Synced messages")
|
||||
} else {
|
||||
user.log.Info("Messages are already synced, skipping")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
// nolint:exhaustive
|
||||
func syncLabels(ctx context.Context, apiLabels map[string]proton.Label, updateCh ...*async.QueuedChannel[imap.Update]) error {
|
||||
var updates []imap.Update
|
||||
|
||||
// Create placeholder Folders/Labels mailboxes with the \Noselect attribute.
|
||||
for _, prefix := range []string{folderPrefix, labelPrefix} {
|
||||
for _, updateCh := range updateCh {
|
||||
update := newPlaceHolderMailboxCreatedUpdate(prefix)
|
||||
updateCh.Enqueue(update)
|
||||
updates = append(updates, update)
|
||||
}
|
||||
}
|
||||
|
||||
// Sync the user's labels.
|
||||
for labelID, label := range apiLabels {
|
||||
if !wantLabel(label) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch label.Type {
|
||||
case proton.LabelTypeSystem:
|
||||
for _, updateCh := range updateCh {
|
||||
update := newSystemMailboxCreatedUpdate(imap.MailboxID(label.ID), label.Name)
|
||||
updateCh.Enqueue(update)
|
||||
updates = append(updates, update)
|
||||
}
|
||||
|
||||
case proton.LabelTypeFolder, proton.LabelTypeLabel:
|
||||
for _, updateCh := range updateCh {
|
||||
update := newMailboxCreatedUpdate(imap.MailboxID(labelID), getMailboxName(label))
|
||||
updateCh.Enqueue(update)
|
||||
updates = append(updates, update)
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown label type: %d", label.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all label updates to be applied.
|
||||
for _, update := range updates {
|
||||
err, ok := update.WaitContext(ctx)
|
||||
if ok && err != nil {
|
||||
return fmt.Errorf("failed to apply label create update in gluon %v: %w", update.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const Kilobyte = uint64(1024)
|
||||
const Megabyte = 1024 * Kilobyte
|
||||
const Gigabyte = 1024 * Megabyte
|
||||
|
||||
func toMB(v uint64) float64 {
|
||||
return float64(v) / float64(Megabyte)
|
||||
}
|
||||
|
||||
type syncLimits struct {
|
||||
MaxDownloadRequestMem uint64
|
||||
MinDownloadRequestMem uint64
|
||||
MaxMessageBuildingMem uint64
|
||||
MinMessageBuildingMem uint64
|
||||
MaxSyncMemory uint64
|
||||
MaxParallelDownloads int
|
||||
}
|
||||
|
||||
func newSyncLimits(maxSyncMemory uint64) syncLimits {
limits := syncLimits{
// There's no point in using more than 128MB of download data per stage; beyond that we reach
// diminishing returns because we can't keep the pipeline fed fast enough.
MaxDownloadRequestMem: 128 * Megabyte,

// Any lower than this and we may fail to download messages.
MinDownloadRequestMem: 40 * Megabyte,

// This value can be increased to your heart's content. The more system memory the user has,
// the more messages we can build in parallel.
MaxMessageBuildingMem: 128 * Megabyte,
MinMessageBuildingMem: 64 * Megabyte,

// Maximum recommended value for parallel downloads, as advised by the API team.
MaxParallelDownloads: 20,

MaxSyncMemory: maxSyncMemory,
}

if _, ok := os.LookupEnv("BRIDGE_SYNC_FORCE_MINIMUM_SPEC"); ok {
logrus.Warn("Sync specs forced to minimum")
limits.MaxDownloadRequestMem = 50 * Megabyte
limits.MaxMessageBuildingMem = 80 * Megabyte
limits.MaxParallelDownloads = 2
limits.MaxSyncMemory = 800 * Megabyte
}

return limits
}
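As a hedged worked example of the defaults above (the values are read straight from the struct literal; nothing here is new behaviour):

limits := newSyncLimits(2 * Gigabyte)
// limits.MaxDownloadRequestMem == 128 MB, limits.MinDownloadRequestMem == 40 MB,
// limits.MaxMessageBuildingMem == 128 MB, limits.MinMessageBuildingMem == 64 MB,
// limits.MaxParallelDownloads == 20 and limits.MaxSyncMemory == 2 GB.
//
// With BRIDGE_SYNC_FORCE_MINIMUM_SPEC set, the same call instead yields a 50 MB download
// limit, an 80 MB building limit, 2 parallel downloads and an 800 MB sync budget.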
|
||||
|
||||
// nolint:gocyclo
|
||||
func (user *User) syncMessages(
|
||||
ctx context.Context,
|
||||
userID string,
|
||||
messageIDs []string,
|
||||
client *proton.Client,
|
||||
sentry reporter.Reporter,
|
||||
vault *vault.User,
|
||||
apiLabels map[string]proton.Label,
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
updateCh map[string]*async.QueuedChannel[imap.Update],
|
||||
eventCh *async.QueuedChannel[events.Event],
|
||||
cfgMaxSyncMemory uint64,
|
||||
) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Track the amount of time to process all the messages.
|
||||
syncStartTime := time.Now()
|
||||
defer func() { logrus.WithField("duration", time.Since(syncStartTime)).Info("Message sync completed") }()
|
||||
|
||||
user.log.WithFields(logrus.Fields{
|
||||
"messages": len(messageIDs),
|
||||
"numCPU": runtime.NumCPU(),
|
||||
}).Info("Starting message sync")
|
||||
|
||||
// Create the flushers, one per update channel.
|
||||
|
||||
// Create a reporter to report sync progress updates.
|
||||
syncReporter := newSyncReporter(userID, eventCh, len(messageIDs), time.Second)
|
||||
defer syncReporter.done()
|
||||
|
||||
// Expected memory usage for this whole process should be the sum of MaxMessageBuildingMem and
// MaxDownloadRequestMem, multiplied by a pipeline factor (roughly 4x each), plus the additional
// memory used by network requests and compression+io.
|
||||
|
||||
totalMemory := memory.TotalMemory()
|
||||
|
||||
syncLimits := newSyncLimits(cfgMaxSyncMemory)
|
||||
|
||||
if syncLimits.MaxSyncMemory >= totalMemory/2 {
|
||||
logrus.Warnf("Requested max sync memory of %v MB is greater than half of system memory (%v MB), forcing to half of system memory",
|
||||
toMB(syncLimits.MaxSyncMemory), toMB(totalMemory/2))
|
||||
syncLimits.MaxSyncMemory = totalMemory / 2
|
||||
}
|
||||
|
||||
if syncLimits.MaxSyncMemory < 800*Megabyte {
|
||||
logrus.Warnf("Requested max sync memory of %v MB, but minimum recommended is 800 MB, forcing max syncMemory to 800MB", toMB(syncLimits.MaxSyncMemory))
|
||||
syncLimits.MaxSyncMemory = 800 * Megabyte
|
||||
}
|
||||
|
||||
logrus.Debugf("Total System Memory: %v", toMB(totalMemory))
|
||||
|
||||
// Linter says it's not used. This is a lie.
|
||||
//nolint: staticcheck
|
||||
syncMaxDownloadRequestMem := syncLimits.MaxDownloadRequestMem
|
||||
|
||||
// Linter says it's not used. This is a lie.
|
||||
//nolint: staticcheck
|
||||
syncMaxMessageBuildingMem := syncLimits.MaxMessageBuildingMem
|
||||
|
||||
// If less than 2GB is available, fall back to the minimum download and message-building memory limits.
|
||||
switch {
|
||||
case syncLimits.MaxSyncMemory < 2*Gigabyte:
|
||||
if syncLimits.MaxSyncMemory < 800*Megabyte {
|
||||
logrus.Warnf("System has less than 800MB of memory, you may experience issues sycing large mailboxes")
|
||||
}
|
||||
syncMaxDownloadRequestMem = syncLimits.MinDownloadRequestMem
|
||||
syncMaxMessageBuildingMem = syncLimits.MinMessageBuildingMem
|
||||
case syncLimits.MaxSyncMemory == 2*Gigabyte:
|
||||
// Increasing the max download capacity has very little effect on sync speed. We could increase the download
|
||||
// memory but the user would see less sync notifications. A smaller value here leads to more frequent
|
||||
// updates. Additionally, most of sync time is spent in the message building.
|
||||
syncMaxDownloadRequestMem = syncLimits.MaxDownloadRequestMem
|
||||
// Currently limited so that multiple active accounts don't cause excessive memory usage.
|
||||
syncMaxMessageBuildingMem = syncLimits.MaxMessageBuildingMem
|
||||
default:
|
||||
// Divide by 8 as the download stage and build stage will use approx. 4x the specified memory.
|
||||
remainingMemory := (syncLimits.MaxSyncMemory - 2*Gigabyte) / 8
|
||||
syncMaxDownloadRequestMem = syncLimits.MaxDownloadRequestMem + remainingMemory
|
||||
syncMaxMessageBuildingMem = syncLimits.MaxMessageBuildingMem + remainingMemory
|
||||
}
|
||||
|
||||
logrus.Debugf("Max memory usage for sync Download=%vMB Building=%vMB Predicted Max Total=%vMB",
|
||||
toMB(syncMaxDownloadRequestMem),
|
||||
toMB(syncMaxMessageBuildingMem),
|
||||
toMB((syncMaxMessageBuildingMem*4)+(syncMaxDownloadRequestMem*4)),
|
||||
)
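// Hedged worked example of the partitioning above, assuming 16GB of system memory and a
// configured MaxSyncMemory of 4GB (so neither clamp fires):
//
//   remainingMemory           = (4GB - 2GB) / 8   = 256MB
//   syncMaxDownloadRequestMem = 128MB + 256MB     = 384MB
//   syncMaxMessageBuildingMem = 128MB + 256MB     = 384MB
//   predicted max total       = 4*384MB + 4*384MB = 3072MB (~3GB), within the 4GB budget.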
|
||||
|
||||
type flushUpdate struct {
|
||||
messageID string
|
||||
err error
|
||||
batchLen int
|
||||
}
|
||||
|
||||
type downloadRequest struct {
|
||||
ids []string
|
||||
expectedSize uint64
|
||||
err error
|
||||
}
|
||||
|
||||
type downloadedMessageBatch struct {
|
||||
batch []proton.FullMessage
|
||||
}
|
||||
|
||||
type builtMessageBatch struct {
|
||||
batch []*buildRes
|
||||
}
|
||||
|
||||
downloadCh := make(chan downloadRequest)
|
||||
|
||||
buildCh := make(chan downloadedMessageBatch)
|
||||
|
||||
// The higher this value, the longer we can continue our download iteration before being blocked on channel writes
|
||||
// to the update flushing goroutine.
|
||||
flushCh := make(chan builtMessageBatch)
|
||||
|
||||
flushUpdateCh := make(chan flushUpdate)
|
||||
|
||||
errorCh := make(chan error, syncLimits.MaxParallelDownloads*4)
|
||||
|
||||
// Goroutine in charge of downloading message metadata.
|
||||
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||
defer close(downloadCh)
|
||||
const MetadataDataPageSize = 150
|
||||
|
||||
var downloadReq downloadRequest
|
||||
downloadReq.ids = make([]string, 0, MetadataDataPageSize)
|
||||
|
||||
metadataChunks := xslices.Chunk(messageIDs, MetadataDataPageSize)
|
||||
for i, metadataChunk := range metadataChunks {
|
||||
logrus.Debugf("Metadata Request (%v of %v), previous: %v", i, len(metadataChunks), len(downloadReq.ids))
|
||||
metadata, err := client.GetMessageMetadataPage(ctx, 0, len(metadataChunk), proton.MessageFilter{ID: metadataChunk})
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to download message metadata for chunk %v", i)
|
||||
downloadReq.err = err
|
||||
select {
|
||||
case downloadCh <- downloadReq:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Build a lookup table so that messages are processed in the same order as requested.
|
||||
metadataMap := make(map[string]int, len(metadata))
|
||||
for i, v := range metadata {
|
||||
metadataMap[v.ID] = i
|
||||
}
|
||||
|
||||
for i, id := range metadataChunk {
|
||||
m := &metadata[metadataMap[id]]
|
||||
nextSize := downloadReq.expectedSize + uint64(m.Size)
|
||||
if nextSize >= syncMaxDownloadRequestMem || len(downloadReq.ids) >= 256 {
|
||||
logrus.Debugf("Download Request Sent at %v of %v", i, len(metadata))
|
||||
select {
|
||||
case downloadCh <- downloadReq:
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
downloadReq.expectedSize = 0
|
||||
downloadReq.ids = make([]string, 0, MetadataDataPageSize)
|
||||
nextSize = uint64(m.Size)
|
||||
}
|
||||
downloadReq.ids = append(downloadReq.ids, id)
|
||||
downloadReq.expectedSize = nextSize
|
||||
}
|
||||
}
|
||||
|
||||
if len(downloadReq.ids) != 0 {
|
||||
logrus.Debugf("Sending remaining download request")
|
||||
select {
|
||||
case downloadCh <- downloadReq:
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}, logging.Labels{"sync-stage": "meta-data"})
|
||||
|
||||
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
|
||||
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||
defer close(buildCh)
|
||||
defer close(errorCh)
|
||||
defer func() {
|
||||
logrus.Debugf("sync downloader exit")
|
||||
}()
|
||||
|
||||
attachmentDownloader := user.newAttachmentDownloader(ctx, client, syncLimits.MaxParallelDownloads)
|
||||
defer attachmentDownloader.close()
|
||||
|
||||
for request := range downloadCh {
|
||||
logrus.Debugf("Download request: %v MB:%v", len(request.ids), toMB(request.expectedSize))
|
||||
if request.err != nil {
|
||||
errorCh <- request.err
|
||||
return
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
errorCh <- ctx.Err()
|
||||
return
|
||||
}
|
||||
|
||||
result, err := parallel.MapContext(ctx, syncLimits.MaxParallelDownloads, request.ids, func(ctx context.Context, id string) (proton.FullMessage, error) {
|
||||
defer async.HandlePanic(user.panicHandler)
|
||||
|
||||
var result proton.FullMessage
|
||||
|
||||
msg, err := client.GetMessage(ctx, id)
|
||||
if err != nil {
|
||||
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message")
|
||||
return proton.FullMessage{}, err
|
||||
}
|
||||
|
||||
attachments, err := attachmentDownloader.getAttachments(ctx, msg.Attachments)
|
||||
if err != nil {
|
||||
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message attachments")
|
||||
return proton.FullMessage{}, err
|
||||
}
|
||||
|
||||
result.Message = msg
|
||||
result.AttData = attachments
|
||||
|
||||
return result, nil
|
||||
})
|
||||
if err != nil {
|
||||
errorCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case buildCh <- downloadedMessageBatch{
|
||||
batch: result,
|
||||
}:
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}, logging.Labels{"sync-stage": "download"})
|
||||
|
||||
// Goroutine which builds messages after they have been downloaded
|
||||
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||
defer close(flushCh)
|
||||
defer func() {
|
||||
logrus.Debugf("sync builder exit")
|
||||
}()
|
||||
|
||||
maxMessagesInParallel := runtime.NumCPU()
|
||||
|
||||
for buildBatch := range buildCh {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
chunks := chunkSyncBuilderBatch(buildBatch.batch, syncMaxMessageBuildingMem)
|
||||
|
||||
for index, chunk := range chunks {
|
||||
logrus.Debugf("Build request: %v of %v count=%v", index, len(chunks), len(chunk))
|
||||
|
||||
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (*buildRes, error) {
|
||||
defer async.HandlePanic(user.panicHandler)
|
||||
|
||||
kr, ok := addrKRs[msg.AddressID]
|
||||
if !ok {
|
||||
logrus.Errorf("Address '%v' on message '%v' does not have an unlocked kerying", msg.AddressID, msg.ID)
|
||||
return &buildRes{
|
||||
messageID: msg.ID,
|
||||
addressID: msg.AddressID,
|
||||
err: fmt.Errorf("address does not have an unlocked keyring"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
res := buildRFC822(apiLabels, msg, kr, new(bytes.Buffer))
|
||||
if res.err != nil {
|
||||
logrus.WithError(res.err).WithField("msgID", msg.ID).Error("Failed to build message (sync)")
|
||||
}
|
||||
|
||||
return res, nil
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case flushCh <- builtMessageBatch{result}:
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}, logging.Labels{"sync-stage": "builder"})
|
||||
|
||||
// Goroutine which converts the messages into updates and builds a waitable structure for progress tracking.
|
||||
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||
defer close(flushUpdateCh)
|
||||
defer func() {
|
||||
logrus.Debugf("sync flush exit")
|
||||
}()
|
||||
|
||||
type updateTargetInfo struct {
|
||||
queueIndex int
|
||||
ch *async.QueuedChannel[imap.Update]
|
||||
}
|
||||
|
||||
pendingUpdates := make([][]*imap.MessageCreated, len(updateCh))
|
||||
addressToIndex := make(map[string]updateTargetInfo)
|
||||
|
||||
{
|
||||
i := 0
|
||||
for addrID, updateCh := range updateCh {
|
||||
addressToIndex[addrID] = updateTargetInfo{
|
||||
ch: updateCh,
|
||||
queueIndex: i,
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
for downloadBatch := range flushCh {
|
||||
logrus.Debugf("Flush batch: %v", len(downloadBatch.batch))
|
||||
for _, res := range downloadBatch.batch {
|
||||
if res.err != nil {
|
||||
if err := vault.AddFailedMessageID(res.messageID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to add failed message ID")
|
||||
}
|
||||
|
||||
if err := sentry.ReportMessageWithContext("Failed to build message (sync)", reporter.Context{
|
||||
"messageID": res.messageID,
|
||||
"error": res.err,
|
||||
}); err != nil {
|
||||
logrus.WithError(err).Error("Failed to report message build error")
|
||||
}
|
||||
|
||||
// We could sync a placeholder message here, but for now we skip it entirely.
|
||||
continue
|
||||
}
|
||||
|
||||
if err := vault.RemFailedMessageID(res.messageID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to remove failed message ID")
|
||||
}
|
||||
|
||||
targetInfo := addressToIndex[res.addressID]
|
||||
pendingUpdates[targetInfo.queueIndex] = append(pendingUpdates[targetInfo.queueIndex], res.update)
|
||||
}
|
||||
|
||||
for _, info := range addressToIndex {
|
||||
up := imap.NewMessagesCreated(true, pendingUpdates[info.queueIndex]...)
|
||||
info.ch.Enqueue(up)
|
||||
|
||||
err, ok := up.WaitContext(ctx)
|
||||
if ok && err != nil {
|
||||
flushUpdateCh <- flushUpdate{
|
||||
err: fmt.Errorf("failed to apply sync update to gluon %v: %w", up.String(), err),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
pendingUpdates[info.queueIndex] = pendingUpdates[info.queueIndex][:0]
|
||||
}
|
||||
|
||||
select {
|
||||
case flushUpdateCh <- flushUpdate{
|
||||
messageID: downloadBatch.batch[0].messageID,
|
||||
err: nil,
|
||||
batchLen: len(downloadBatch.batch),
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}, logging.Labels{"sync-stage": "flush"})
|
||||
|
||||
for flushUpdate := range flushUpdateCh {
|
||||
if flushUpdate.err != nil {
|
||||
return flushUpdate.err
|
||||
}
|
||||
|
||||
if err := vault.SetLastMessageID(flushUpdate.messageID); err != nil {
|
||||
return fmt.Errorf("failed to set last synced message ID: %w", err)
|
||||
}
|
||||
|
||||
syncReporter.add(flushUpdate.batchLen)
|
||||
}
|
||||
|
||||
return <-errorCh
|
||||
}
|
||||
|
||||
func newSystemMailboxCreatedUpdate(labelID imap.MailboxID, labelName string) *imap.MailboxCreated {
|
||||
if strings.EqualFold(labelName, imap.Inbox) {
|
||||
labelName = imap.Inbox
|
||||
}
|
||||
|
||||
attrs := imap.NewFlagSet(imap.AttrNoInferiors)
|
||||
permanentFlags := defaultPermanentFlags
|
||||
flags := defaultFlags
|
||||
|
||||
switch labelID {
|
||||
case proton.TrashLabel:
|
||||
attrs = attrs.Add(imap.AttrTrash)
|
||||
|
||||
case proton.SpamLabel:
|
||||
attrs = attrs.Add(imap.AttrJunk)
|
||||
|
||||
case proton.AllMailLabel:
|
||||
attrs = attrs.Add(imap.AttrAll)
|
||||
flags = imap.NewFlagSet(imap.FlagSeen, imap.FlagFlagged)
|
||||
permanentFlags = imap.NewFlagSet(imap.FlagSeen, imap.FlagFlagged)
|
||||
|
||||
case proton.ArchiveLabel:
|
||||
attrs = attrs.Add(imap.AttrArchive)
|
||||
|
||||
case proton.SentLabel:
|
||||
attrs = attrs.Add(imap.AttrSent)
|
||||
|
||||
case proton.DraftsLabel:
|
||||
attrs = attrs.Add(imap.AttrDrafts)
|
||||
|
||||
case proton.StarredLabel:
|
||||
attrs = attrs.Add(imap.AttrFlagged)
|
||||
|
||||
case proton.AllScheduledLabel:
|
||||
labelName = "Scheduled" // API actual name is "All Scheduled"
|
||||
}
|
||||
|
||||
return imap.NewMailboxCreated(imap.Mailbox{
|
||||
ID: labelID,
|
||||
Name: []string{labelName},
|
||||
Flags: flags,
|
||||
PermanentFlags: permanentFlags,
|
||||
Attributes: attrs,
|
||||
})
|
||||
}
|
||||
|
||||
func newPlaceHolderMailboxCreatedUpdate(labelName string) *imap.MailboxCreated {
|
||||
return imap.NewMailboxCreated(imap.Mailbox{
|
||||
ID: imap.MailboxID(labelName),
|
||||
Name: []string{labelName},
|
||||
Flags: defaultFlags,
|
||||
PermanentFlags: defaultPermanentFlags,
|
||||
Attributes: imap.NewFlagSet(imap.AttrNoSelect),
|
||||
})
|
||||
}
|
||||
|
||||
func newMailboxCreatedUpdate(labelID imap.MailboxID, labelName []string) *imap.MailboxCreated {
|
||||
return imap.NewMailboxCreated(imap.Mailbox{
|
||||
ID: labelID,
|
||||
Name: labelName,
|
||||
Flags: defaultFlags,
|
||||
PermanentFlags: defaultPermanentFlags,
|
||||
Attributes: imap.NewFlagSet(),
|
||||
})
|
||||
}
|
||||
|
||||
func wantLabel(label proton.Label) bool {
|
||||
if label.Type != proton.LabelTypeSystem {
|
||||
return true
|
||||
}
|
||||
|
||||
// nolint:exhaustive
|
||||
switch label.ID {
|
||||
case proton.InboxLabel:
|
||||
return true
|
||||
|
||||
case proton.TrashLabel:
|
||||
return true
|
||||
|
||||
case proton.SpamLabel:
|
||||
return true
|
||||
|
||||
case proton.AllMailLabel:
|
||||
return true
|
||||
|
||||
case proton.ArchiveLabel:
|
||||
return true
|
||||
|
||||
case proton.SentLabel:
|
||||
return true
|
||||
|
||||
case proton.DraftsLabel:
|
||||
return true
|
||||
|
||||
case proton.StarredLabel:
|
||||
return true
|
||||
|
||||
case proton.AllScheduledLabel:
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func wantLabels(apiLabels map[string]proton.Label, labelIDs []string) []string {
|
||||
return xslices.Filter(labelIDs, func(labelID string) bool {
|
||||
apiLabel, ok := apiLabels[labelID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return wantLabel(apiLabel)
|
||||
})
|
||||
}
|
||||
|
||||
type attachmentResult struct {
|
||||
attachment []byte
|
||||
err error
|
||||
}
|
||||
|
||||
type attachmentJob struct {
|
||||
id string
|
||||
size int64
|
||||
result chan attachmentResult
|
||||
}
|
||||
|
||||
type attachmentDownloader struct {
|
||||
workerCh chan attachmentJob
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func attachmentWorker(ctx context.Context, client *proton.Client, work <-chan attachmentJob) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case job, ok := <-work:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var b bytes.Buffer
|
||||
b.Grow(int(job.size))
|
||||
err := client.GetAttachmentInto(ctx, job.id, &b)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
close(job.result)
|
||||
return
|
||||
case job.result <- attachmentResult{attachment: b.Bytes(), err: err}:
|
||||
close(job.result)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader {
|
||||
workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
for i := 0; i < workerCount; i++ {
|
||||
workerCh = make(chan attachmentJob)
|
||||
async.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{
|
||||
"sync": fmt.Sprintf("att-downloader %v", i),
|
||||
})
|
||||
}
|
||||
|
||||
return &attachmentDownloader{
|
||||
workerCh: workerCh,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *attachmentDownloader) getAttachments(ctx context.Context, attachments []proton.Attachment) ([][]byte, error) {
|
||||
resultChs := make([]chan attachmentResult, len(attachments))
|
||||
for i, id := range attachments {
|
||||
resultChs[i] = make(chan attachmentResult, 1)
|
||||
select {
|
||||
case a.workerCh <- attachmentJob{id: id.ID, result: resultChs[i], size: id.Size}:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
result := make([][]byte, len(attachments))
|
||||
var err error
|
||||
for i := 0; i < len(attachments); i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case r := <-resultChs[i]:
|
||||
if r.err != nil {
|
||||
err = fmt.Errorf("failed to get attachment %v: %w", attachments[i], r.err)
|
||||
}
|
||||
result[i] = r.attachment
|
||||
}
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (a *attachmentDownloader) close() {
|
||||
a.cancel()
|
||||
}
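A hedged usage sketch mirroring the download stage above; user, client, ctx, msg and maxSyncMemory are assumed to already be in scope and are not introduced by this change:

downloader := user.newAttachmentDownloader(ctx, client, newSyncLimits(maxSyncMemory).MaxParallelDownloads)
defer downloader.close()

attData, err := downloader.getAttachments(ctx, msg.Attachments)
if err != nil {
logrus.WithError(err).Error("Failed to download attachments")
} else {
full := proton.FullMessage{Message: msg, AttData: attData}
_ = full // handed to the build stage as a proton.FullMessage.
}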
|
||||
|
||||
func chunkSyncBuilderBatch(batch []proton.FullMessage, maxMemory uint64) [][]proton.FullMessage {
|
||||
var expectedMemUsage uint64
|
||||
var chunks [][]proton.FullMessage
|
||||
var lastIndex int
|
||||
var index int
|
||||
|
||||
for _, v := range batch {
|
||||
var dataSize uint64
|
||||
for _, a := range v.Attachments {
|
||||
dataSize += uint64(a.Size)
|
||||
}
|
||||
|
||||
// 2x increase for attachment due to extra memory needed for decrypting and writing
|
||||
// in memory buffer.
|
||||
dataSize *= 2
|
||||
dataSize += uint64(len(v.Body))
|
||||
|
||||
nextMemSize := expectedMemUsage + dataSize
|
||||
if nextMemSize >= maxMemory {
|
||||
chunks = append(chunks, batch[lastIndex:index])
|
||||
lastIndex = index
|
||||
expectedMemUsage = dataSize
|
||||
} else {
|
||||
expectedMemUsage = nextMemSize
|
||||
}
|
||||
|
||||
index++
|
||||
}
|
||||
|
||||
if lastIndex < len(batch) {
|
||||
chunks = append(chunks, batch[lastIndex:])
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
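A hedged illustration of the chunking rule above, in the same style as the existing test further below; the sizes are invented for the example:

msg := proton.FullMessage{Message: proton.Message{Attachments: []proton.Attachment{{Size: int64(3 * Megabyte)}}}}

chunks := chunkSyncBuilderBatch(xslices.Repeat(msg, 4), 16*Megabyte)
// Each message is costed at 2*3MB = 6MB, so a third message would push a chunk past the
// 16MB ceiling: the four messages land in two chunks of two.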
|
||||
@ -1,174 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"html/template"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/algo"
|
||||
"github.com/ProtonMail/proton-bridge/v3/pkg/message"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
)
|
||||
|
||||
type buildRes struct {
|
||||
messageID string
|
||||
addressID string
|
||||
update *imap.MessageCreated
|
||||
err error
|
||||
}
|
||||
|
||||
func defaultJobOpts() message.JobOptions {
|
||||
return message.JobOptions{
|
||||
IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead.
|
||||
SanitizeDate: true, // Whether to replace all dates before 1970 with RFC822's birthdate.
|
||||
AddInternalID: true, // Whether to include MessageID as X-Pm-Internal-Id.
|
||||
AddExternalID: true, // Whether to include ExternalID as X-Pm-External-Id.
|
||||
AddMessageDate: true, // Whether to include message time as X-Pm-Date.
|
||||
AddMessageIDReference: true, // Whether to include the MessageID in References.
|
||||
}
|
||||
}
|
||||
|
||||
func buildRFC822(apiLabels map[string]proton.Label, full proton.FullMessage, addrKR *crypto.KeyRing, buffer *bytes.Buffer) *buildRes {
|
||||
var (
|
||||
update *imap.MessageCreated
|
||||
err error
|
||||
)
|
||||
|
||||
buffer.Grow(full.Size)
|
||||
|
||||
if buildErr := message.BuildRFC822Into(addrKR, full.Message, full.AttData, defaultJobOpts(), buffer); buildErr != nil {
|
||||
update = newMessageCreatedFailedUpdate(apiLabels, full.MessageMetadata, buildErr)
|
||||
err = buildErr
|
||||
} else if created, parseErr := newMessageCreatedUpdate(apiLabels, full.MessageMetadata, buffer.Bytes()); parseErr != nil {
|
||||
update = newMessageCreatedFailedUpdate(apiLabels, full.MessageMetadata, parseErr)
|
||||
err = parseErr
|
||||
} else {
|
||||
update = created
|
||||
}
|
||||
|
||||
return &buildRes{
|
||||
messageID: full.ID,
|
||||
addressID: full.AddressID,
|
||||
update: update,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func newMessageCreatedUpdate(
|
||||
apiLabels map[string]proton.Label,
|
||||
message proton.MessageMetadata,
|
||||
literal []byte,
|
||||
) (*imap.MessageCreated, error) {
|
||||
parsedMessage, err := imap.NewParsedMessage(literal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &imap.MessageCreated{
|
||||
Message: toIMAPMessage(message),
|
||||
Literal: literal,
|
||||
MailboxIDs: usertypes.MapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
|
||||
ParsedMessage: parsedMessage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newMessageCreatedFailedUpdate(
|
||||
apiLabels map[string]proton.Label,
|
||||
message proton.MessageMetadata,
|
||||
err error,
|
||||
) *imap.MessageCreated {
|
||||
literal := newFailedMessageLiteral(message.ID, time.Unix(message.Time, 0), message.Subject, err)
|
||||
|
||||
parsedMessage, err := imap.NewParsedMessage(literal)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &imap.MessageCreated{
|
||||
Message: toIMAPMessage(message),
|
||||
MailboxIDs: usertypes.MapTo[string, imap.MailboxID](wantLabels(apiLabels, message.LabelIDs)),
|
||||
Literal: literal,
|
||||
ParsedMessage: parsedMessage,
|
||||
}
|
||||
}
|
||||
|
||||
func newFailedMessageLiteral(
|
||||
messageID string,
|
||||
date time.Time,
|
||||
subject string,
|
||||
syncErr error,
|
||||
) []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if tmpl, err := template.New("header").Parse(failedMessageHeaderTemplate); err != nil {
|
||||
panic(err)
|
||||
} else if b, err := tmplExec(tmpl, map[string]any{
|
||||
"Date": date.In(time.UTC).Format(time.RFC822),
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
} else if _, err := buf.Write(b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if tmpl, err := template.New("body").Parse(failedMessageBodyTemplate); err != nil {
|
||||
panic(err)
|
||||
} else if b, err := tmplExec(tmpl, map[string]any{
|
||||
"MessageID": messageID,
|
||||
"Subject": subject,
|
||||
"Error": syncErr.Error(),
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
} else if _, err := buf.Write(lineWrap(algo.B64Encode(b))); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func tmplExec(template *template.Template, data any) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := template.Execute(&buf, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func lineWrap(b []byte) []byte {
|
||||
return bytes.Join(xslices.Chunk(b, 76), []byte{'\r', '\n'})
|
||||
}
|
||||
|
||||
const failedMessageHeaderTemplate = `Date: {{.Date}}
|
||||
Subject: Message failed to build
|
||||
Content-Type: text/plain
|
||||
Content-Transfer-Encoding: base64
|
||||
|
||||
`
|
||||
|
||||
const failedMessageBodyTemplate = `Failed to build message:
|
||||
Subject: {{.Subject}}
|
||||
Error: {{.Error}}
|
||||
MessageID: {{.MessageID}}
|
||||
`
|
||||
@ -1,80 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewFailedMessageLiteral(t *testing.T) {
|
||||
literal := newFailedMessageLiteral("abcd-efgh", time.Unix(123456789, 0), "subject", errors.New("oops"))
|
||||
|
||||
header, err := rfc822.Parse(literal).ParseHeader()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Message failed to build", header.Get("Subject"))
|
||||
require.Equal(t, "29 Nov 73 21:33 UTC", header.Get("Date"))
|
||||
require.Equal(t, "text/plain", header.Get("Content-Type"))
|
||||
require.Equal(t, "base64", header.Get("Content-Transfer-Encoding"))
|
||||
|
||||
b, err := rfc822.Parse(literal).DecodedBody()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(b), "Failed to build message: \nSubject: subject\nError: oops\nMessageID: abcd-efgh\n")
|
||||
|
||||
parsed, err := imap.NewParsedMessage(literal)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, `("29 Nov 73 21:33 UTC" "Message failed to build" NIL NIL NIL NIL NIL NIL NIL NIL)`, parsed.Envelope)
|
||||
require.Equal(t, `("text" "plain" () NIL NIL "base64" 114 2)`, parsed.Body)
|
||||
require.Equal(t, `("text" "plain" () NIL NIL "base64" 114 2 NIL NIL NIL NIL)`, parsed.Structure)
|
||||
}
|
||||
|
||||
func TestSyncChunkSyncBuilderBatch(t *testing.T) {
|
||||
// GODT-2424 - Some messages were not fully built due to a bug in the chunking if the total memory used by the
|
||||
// message would be higher than the maximum we allowed.
|
||||
const totalMessageCount = 100
|
||||
|
||||
msg := proton.FullMessage{
|
||||
Message: proton.Message{
|
||||
Attachments: []proton.Attachment{
|
||||
{
|
||||
Size: int64(8 * Megabyte),
|
||||
},
|
||||
},
|
||||
},
|
||||
AttData: nil,
|
||||
}
|
||||
|
||||
messages := xslices.Repeat(msg, totalMessageCount)
|
||||
|
||||
chunks := chunkSyncBuilderBatch(messages, 16*Megabyte)
|
||||
|
||||
var totalMessagesInChunks int
|
||||
|
||||
for _, v := range chunks {
|
||||
totalMessagesInChunks += len(v)
|
||||
}
|
||||
|
||||
require.Equal(t, totalMessagesInChunks, totalMessageCount)
|
||||
}
|
||||
@ -1,72 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||
)
|
||||
|
||||
type syncReporter struct {
|
||||
userID string
|
||||
eventCh *async.QueuedChannel[events.Event]
|
||||
|
||||
start time.Time
|
||||
total int
|
||||
count int
|
||||
|
||||
last time.Time
|
||||
freq time.Duration
|
||||
}
|
||||
|
||||
func newSyncReporter(userID string, eventCh *async.QueuedChannel[events.Event], total int, freq time.Duration) *syncReporter {
|
||||
return &syncReporter{
|
||||
userID: userID,
|
||||
eventCh: eventCh,
|
||||
|
||||
start: time.Now(),
|
||||
total: total,
|
||||
freq: freq,
|
||||
}
|
||||
}
|
||||
|
||||
func (rep *syncReporter) add(delta int) {
|
||||
rep.count += delta
|
||||
|
||||
if time.Since(rep.last) > rep.freq {
|
||||
rep.eventCh.Enqueue(events.SyncProgress{
|
||||
UserID: rep.userID,
|
||||
Progress: float64(rep.count) / float64(rep.total),
|
||||
Elapsed: time.Since(rep.start),
|
||||
Remaining: time.Since(rep.start) * time.Duration(rep.total-(rep.count+1)) / time.Duration(rep.count+1),
|
||||
})
|
||||
|
||||
rep.last = time.Now()
|
||||
}
|
||||
}
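A hedged numeric example of the Remaining estimate above: after 50s of syncing with 25 of 100 messages flushed, the next enqueue reports Progress = 0.25 and Remaining = 50s * (100-26)/26, roughly 142s; in other words, the reporter assumes the per-message rate observed so far holds for the rest of the sync.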
|
||||
|
||||
func (rep *syncReporter) done() {
|
||||
rep.eventCh.Enqueue(events.SyncProgress{
|
||||
UserID: rep.userID,
|
||||
Progress: 1,
|
||||
Elapsed: time.Since(rep.start),
|
||||
Remaining: 0,
|
||||
})
|
||||
}
|
||||
@ -19,27 +19,19 @@ package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/async"
|
||||
"github.com/ProtonMail/gluon/connector"
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/reporter"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/configstatus"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/services/imapservice"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/services/orderedtasks"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/services/smtp"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/services/userevents"
|
||||
@ -65,6 +57,7 @@ const (
|
||||
)
|
||||
|
||||
type User struct {
|
||||
id string
|
||||
log *logrus.Entry
|
||||
|
||||
vault *vault.User
|
||||
@ -75,27 +68,7 @@ type User struct {
|
||||
eventCh *async.QueuedChannel[events.Event]
|
||||
eventLock safe.RWMutex
|
||||
|
||||
apiUser proton.User
|
||||
apiUserLock safe.RWMutex
|
||||
|
||||
apiAddrs map[string]proton.Address
|
||||
apiAddrsLock safe.RWMutex
|
||||
|
||||
apiLabels map[string]proton.Label
|
||||
apiLabelsLock safe.RWMutex
|
||||
|
||||
updateCh map[string]*async.QueuedChannel[imap.Update]
|
||||
updateChLock safe.RWMutex
|
||||
|
||||
tasks *async.Group
|
||||
syncAbort async.Abortable
|
||||
pollAbort async.Abortable
|
||||
goSync func()
|
||||
|
||||
pollAPIEventsCh chan chan struct{}
|
||||
goPollAPIEvents func(wait bool)
|
||||
|
||||
showAllMail uint32
|
||||
tasks *async.Group
|
||||
|
||||
maxSyncMemory uint64
|
||||
|
||||
@ -108,9 +81,11 @@ type User struct {
|
||||
eventService *userevents.Service
|
||||
identityService *useridentity.Service
|
||||
smtpService *smtp.Service
|
||||
imapService *imapservice.Service
|
||||
|
||||
serviceGroup *orderedtasks.OrderedCancelGroup
|
||||
}
|
||||
|
||||
// New returns a new user.
|
||||
func New(
|
||||
ctx context.Context,
|
||||
encVault *vault.User,
|
||||
@ -122,6 +97,49 @@ func New(
|
||||
maxSyncMemory uint64,
|
||||
statsDir string,
|
||||
telemetryManager telemetry.Availability,
|
||||
serverManager imapservice.IMAPServerManager,
|
||||
eventSubscription events.Subscription,
|
||||
) (*User, error) {
|
||||
user, err := newImpl(
|
||||
ctx,
|
||||
encVault,
|
||||
client,
|
||||
reporter,
|
||||
apiUser,
|
||||
crashHandler,
|
||||
showAllMail,
|
||||
maxSyncMemory,
|
||||
statsDir,
|
||||
telemetryManager,
|
||||
serverManager,
|
||||
eventSubscription,
|
||||
)
|
||||
if err != nil {
|
||||
// Cleanup any pending resources on error
|
||||
if user != nil {
|
||||
user.Close()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// newImpl returns a new user.
|
||||
func newImpl(
|
||||
ctx context.Context,
|
||||
encVault *vault.User,
|
||||
client *proton.Client,
|
||||
reporter reporter.Reporter,
|
||||
apiUser proton.User,
|
||||
crashHandler async.PanicHandler,
|
||||
showAllMail bool,
|
||||
maxSyncMemory uint64,
|
||||
statsDir string,
|
||||
telemetryManager telemetry.Availability,
|
||||
serverManager imapservice.IMAPServerManager,
|
||||
eventSubscription events.Subscription,
|
||||
) (*User, error) {
|
||||
logrus.WithField("userID", apiUser.ID).Info("Creating new user")
|
||||
|
||||
@ -137,7 +155,7 @@ func New(
|
||||
return nil, fmt.Errorf("failed to get labels: %w", err)
|
||||
}
|
||||
|
||||
identityState := useridentity.NewState(apiUser, slices.Clone(apiAddrs), client)
|
||||
identityState := useridentity.NewState(apiUser, apiAddrs, client)
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"userID": apiUser.ID,
|
||||
@ -151,31 +169,12 @@ func New(
|
||||
return nil, fmt.Errorf("failed to init configuration status file: %w", err)
|
||||
}
|
||||
|
||||
// Use null publisher for now to avoid conflicts with original event loop.
|
||||
eventPublisher := &events.NullEventPublisher{}
|
||||
|
||||
// Use in memory store to avoid conflicts with original event loop.
|
||||
idStore := userevents.NewInMemoryEventIDStore()
|
||||
_ = idStore.Store(context.Background(), encVault.EventID())
|
||||
|
||||
eventService := userevents.NewService(
|
||||
apiUser.ID,
|
||||
client,
|
||||
// Use in memory store to avoid conflicts with the original event loop.
|
||||
idStore,
|
||||
eventPublisher,
|
||||
EventPeriod,
|
||||
5*time.Minute,
|
||||
crashHandler,
|
||||
)
|
||||
|
||||
sendRecorder := sendrecorder.NewSendRecorder(sendrecorder.SendEntryExpiry)
|
||||
|
||||
identityService := useridentity.NewService(eventService, eventPublisher, identityState)
|
||||
|
||||
// Create the user object.
|
||||
user := &User{
|
||||
log: logrus.WithField("userID", apiUser.ID),
|
||||
id: apiUser.ID,
|
||||
|
||||
vault: encVault,
|
||||
client: client,
|
||||
@ -185,22 +184,7 @@ func New(
|
||||
eventCh: async.NewQueuedChannel[events.Event](0, 0, crashHandler, fmt.Sprintf("bridge-user-%v", apiUser.ID)),
|
||||
eventLock: safe.NewRWMutex(),
|
||||
|
||||
apiUser: apiUser,
|
||||
apiUserLock: safe.NewRWMutex(),
|
||||
|
||||
apiAddrs: usertypes.GroupBy(apiAddrs, func(addr proton.Address) string { return addr.ID }),
|
||||
apiAddrsLock: safe.NewRWMutex(),
|
||||
|
||||
apiLabels: usertypes.GroupBy(apiLabels, func(label proton.Label) string { return label.ID }),
|
||||
apiLabelsLock: safe.NewRWMutex(),
|
||||
|
||||
updateCh: make(map[string]*async.QueuedChannel[imap.Update]),
|
||||
updateChLock: safe.NewRWMutex(),
|
||||
|
||||
tasks: async.NewGroup(context.Background(), crashHandler),
|
||||
pollAPIEventsCh: make(chan chan struct{}),
|
||||
|
||||
showAllMail: b32(showAllMail),
|
||||
tasks: async.NewGroup(context.Background(), crashHandler),
|
||||
|
||||
maxSyncMemory: maxSyncMemory,
|
||||
|
||||
@ -209,11 +193,24 @@ func New(
|
||||
configStatus: configStatus,
|
||||
telemetryManager: telemetryManager,
|
||||
|
||||
identityService: identityService,
|
||||
smtpService: nil,
|
||||
eventService: eventService,
|
||||
serviceGroup: orderedtasks.NewOrderedCancelGroup(crashHandler),
|
||||
smtpService: nil,
|
||||
}
|
||||
|
||||
user.eventService = userevents.NewService(
|
||||
apiUser.ID,
|
||||
client,
|
||||
userevents.NewVaultEventIDStore(encVault),
|
||||
user,
|
||||
EventPeriod,
|
||||
5*time.Minute,
|
||||
crashHandler,
|
||||
)
|
||||
|
||||
addressMode := usertypes.VaultToAddressMode(encVault.AddressMode())
|
||||
|
||||
user.identityService = useridentity.NewService(user.eventService, user, identityState, encVault, user)
|
||||
|
||||
user.smtpService = smtp.NewService(
|
||||
apiUser.ID,
|
||||
client,
|
||||
@ -223,20 +220,37 @@ func New(
|
||||
encVault,
|
||||
encVault,
|
||||
user,
|
||||
eventService,
|
||||
usertypes.VaultToAddressMode(encVault.AddressMode()),
|
||||
user.eventService,
|
||||
addressMode,
|
||||
identityState.Clone(),
|
||||
)
|
||||
|
||||
user.imapService = imapservice.NewService(
|
||||
client,
|
||||
identityState.Clone(),
|
||||
user,
|
||||
encVault,
|
||||
user.eventService,
|
||||
serverManager,
|
||||
user,
|
||||
encVault,
|
||||
encVault,
|
||||
crashHandler,
|
||||
sendRecorder,
|
||||
user,
|
||||
reporter,
|
||||
addressMode,
|
||||
eventSubscription,
|
||||
user.maxSyncMemory,
|
||||
showAllMail,
|
||||
)
|
||||
|
||||
// Check for status_progress when triggered.
|
||||
user.goStatusProgress = user.tasks.PeriodicOrTrigger(configstatus.ProgressCheckInterval, 0, func(ctx context.Context) {
|
||||
user.SendConfigStatusProgress(ctx)
|
||||
})
|
||||
defer user.goStatusProgress()
|
||||
|
||||
// Initialize the user's update channels for its current address mode.
|
||||
user.initUpdateCh(encVault.AddressMode())
|
||||
|
||||
// When we receive an auth object, we update it in the vault.
|
||||
// This will be used to authorize the user on the next run.
|
||||
user.client.AddAuthHandler(func(auth proton.Auth) {
|
||||
@ -259,130 +273,93 @@ func New(
|
||||
return nil
|
||||
})
|
||||
|
||||
// When triggered, poll the API for events, optionally blocking until the poll is complete.
|
||||
user.goPollAPIEvents = func(wait bool) {
|
||||
doneCh := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer async.HandlePanic(user.panicHandler)
|
||||
user.pollAPIEventsCh <- doneCh
|
||||
}()
|
||||
|
||||
if wait {
|
||||
<-doneCh
|
||||
}
|
||||
}
|
||||
|
||||
// When triggered, sync the user and then begin streaming API events.
|
||||
user.goSync = user.tasks.Trigger(func(ctx context.Context) {
|
||||
user.log.Info("Sync triggered")
|
||||
|
||||
// Sync the user.
|
||||
user.syncAbort.Do(ctx, func(ctx context.Context) {
|
||||
if user.vault.SyncStatus().IsComplete() {
|
||||
user.log.Info("Sync already complete, only system label will be updated")
|
||||
|
||||
if err := user.syncSystemLabels(ctx); err != nil {
|
||||
user.log.WithError(err).Error("Failed to update system labels")
|
||||
return
|
||||
}
|
||||
|
||||
user.log.Info("System label update complete, starting API event stream")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
user.log.WithError(err).Error("Sync aborted")
|
||||
return
|
||||
} else if err := user.doSync(ctx); err != nil {
|
||||
user.log.WithError(err).Error("Failed to sync, will retry later")
|
||||
sleepCtx(ctx, SyncRetryCooldown)
|
||||
} else {
|
||||
user.log.Info("Sync complete, starting API event stream")
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Once we know the sync has completed, we can start polling for API events.
|
||||
if user.vault.SyncStatus().IsComplete() {
|
||||
user.pollAbort.Do(ctx, func(ctx context.Context) {
|
||||
user.startEvents(ctx)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Start Event Service
|
||||
if err := user.eventService.Start(ctx, user.tasks); err != nil {
|
||||
return nil, fmt.Errorf("failed to start event service: %w", err)
|
||||
if err := user.eventService.Start(ctx, user.serviceGroup); err != nil {
|
||||
return user, fmt.Errorf("failed to start event service: %w", err)
|
||||
}
|
||||
|
||||
// Start Identity Service
|
||||
user.identityService.Start(user.tasks)
|
||||
user.identityService.Start(ctx, user.serviceGroup)
|
||||
|
||||
// Start SMTP Service
|
||||
user.smtpService.Start(user.tasks)
|
||||
user.smtpService.Start(ctx, user.serviceGroup)
|
||||
|
||||
if err := user.eventService.Resume(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to resume event service")
|
||||
// Start IMAP Service
|
||||
if err := user.imapService.Start(ctx, user.serviceGroup); err != nil {
|
||||
return user, fmt.Errorf("failed to start imap service: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
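newImpl wires the event, identity, SMTP and IMAP services onto a single orderedtasks.OrderedCancelGroup, and the teardown paths later call CancelAndWait on that group. The exact semantics of OrderedCancelGroup are not shown in this diff, so the sketch below only illustrates the general idea of starting long-running services against a shared cancellable context and waiting for them all to stop; the Service and Group types here are hypothetical.

package main

import (
    "context"
    "fmt"
    "sync"
)

// Service is a hypothetical long-running component started with a context.
type Service struct{ name string }

func (s *Service) Run(ctx context.Context, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Println("started", s.name)
    <-ctx.Done()
    fmt.Println("stopped", s.name)
}

// Group cancels every service it started and waits for them to exit,
// loosely mirroring the role the ordered cancel group plays above.
type Group struct {
    cancel context.CancelFunc
    wg     sync.WaitGroup
}

func NewGroup(parent context.Context) (*Group, context.Context) {
    ctx, cancel := context.WithCancel(parent)
    return &Group{cancel: cancel}, ctx
}

func (g *Group) Start(ctx context.Context, s *Service) {
    g.wg.Add(1)
    go s.Run(ctx, &g.wg)
}

func (g *Group) CancelAndWait() {
    g.cancel()
    g.wg.Wait()
}

func main() {
    group, ctx := NewGroup(context.Background())
    for _, name := range []string{"events", "identity", "smtp", "imap"} {
        group.Start(ctx, &Service{name: name})
    }
    group.CancelAndWait()
}
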
func (user *User) TriggerSync() {
user.goSync()
}

// ID returns the user's ID.
func (user *User) ID() string {
return safe.RLockRet(func() string {
return user.apiUser.ID
}, user.apiUserLock)
return user.id
}

// Name returns the user's username.
func (user *User) Name() string {
return safe.RLockRet(func() string {
return user.apiUser.Name
}, user.apiUserLock)
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()

apiUser, err := user.identityService.GetAPIUser(ctx)
if err != nil {
return ""
}

return apiUser.Name
}

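Name above, and Match, Emails, UsedSpace, MaxSpace and CheckAuth further down, all follow the same shape: build a context with a one-minute deadline, query the identity service, and fall back to a zero value when the lookup fails. A generic sketch of that accessor pattern, assuming a hypothetical fetch function; none of these names come from the bridge code.

package main

import (
    "context"
    "errors"
    "fmt"
    "time"
)

// getWithDeadline bounds a lookup to one minute and returns a fallback
// value when the lookup fails, mirroring the accessor shape used above.
func getWithDeadline[T any](fetch func(context.Context) (T, error), fallback T) T {
    ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
    defer cancel()

    v, err := fetch(ctx)
    if err != nil {
        return fallback
    }

    return v
}

func main() {
    name := getWithDeadline(func(ctx context.Context) (string, error) {
        return "", errors.New("identity service unavailable")
    }, "")
    fmt.Printf("name=%q\n", name)
}
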
// Match matches the given query against the user's username and email addresses.
func (user *User) Match(query string) bool {
return safe.RLockRet(func() bool {
if query == user.apiUser.Name {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()

apiUser, err := user.identityService.GetAPIUser(ctx)
if err != nil {
return false
}

if query == apiUser.Name {
return true
}

apiAddrs, err := user.identityService.GetAddresses(ctx)
if err != nil {
return false
}

for _, addr := range apiAddrs {
if query == addr.Email {
return true
}
}

for _, addr := range user.apiAddrs {
if query == addr.Email {
return true
}
}

return false
}, user.apiUserLock, user.apiAddrsLock)
return false
}

// Emails returns all the user's active email addresses.
// It returns them in sorted order; the user's primary address is first.
func (user *User) Emails() []string {
return safe.RLockRet(func() []string {
addresses := xslices.Filter(maps.Values(user.apiAddrs), func(addr proton.Address) bool {
return addr.Status == proton.AddressStatusEnabled && addr.Type != proton.AddressTypeExternal
})
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()

slices.SortFunc(addresses, func(a, b proton.Address) bool {
return a.Order < b.Order
})
apiAddrs, err := user.identityService.GetAddresses(ctx)
if err != nil {
return nil
}

return xslices.Map(addresses, func(addr proton.Address) string {
return addr.Email
})
}, user.apiAddrsLock)
addresses := xslices.Filter(maps.Values(apiAddrs), func(addr proton.Address) bool {
return addr.Status == proton.AddressStatusEnabled && addr.Type != proton.AddressTypeExternal
})

slices.SortFunc(addresses, func(a, b proton.Address) bool {
return a.Order < b.Order
})

return xslices.Map(addresses, func(addr proton.Address) string {
return addr.Email
})
}

// GetAddressMode returns the user's current address mode.
@ -394,48 +371,58 @@ func (user *User) GetAddressMode() vault.AddressMode {
func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error {
user.log.WithField("mode", mode).Info("Setting address mode")

user.syncAbort.Abort()
user.pollAbort.Abort()
if err := user.vault.SetAddressMode(mode); err != nil {
return fmt.Errorf("failed to set address mode: %w", err)
}

return safe.LockRet(func() error {
if err := user.vault.SetAddressMode(mode); err != nil {
return fmt.Errorf("failed to set address mode: %w", err)
}
if err := user.smtpService.SetAddressMode(ctx, usertypes.VaultToAddressMode(mode)); err != nil {
return fmt.Errorf("failed to set smtp address mode: %w", err)
}

if err := user.smtpService.SetAddressMode(ctx, usertypes.VaultToAddressMode(mode)); err != nil {
return fmt.Errorf("failed to set smtp address mode: %w", err)
}
if err := user.imapService.SetAddressMode(ctx, usertypes.VaultToAddressMode(mode)); err != nil {
return fmt.Errorf("failed to set imap address mode: %w", err)
}

if err := user.clearSyncStatus(); err != nil {
return fmt.Errorf("failed to clear sync status: %w", err)
}

return nil
}, user.eventLock, user.apiAddrsLock, user.updateChLock)
}

// CancelSyncAndEventPoll stops the sync or event poll go-routine.
func (user *User) CancelSyncAndEventPoll() {
user.syncAbort.Abort()
user.pollAbort.Abort()
return nil
}

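SetAddressMode above persists the new mode in the vault first and then pushes it to the SMTP and IMAP services, returning the first wrapped error it encounters. A compact sketch of that fan-out-with-early-return shape, using a hypothetical list of appliers rather than the real services.

package main

import (
    "context"
    "errors"
    "fmt"
)

type applier struct {
    name  string
    apply func(context.Context) error
}

// setEverywhere applies a setting to each component in order and stops at
// the first failure so the caller sees which layer rejected the change.
func setEverywhere(ctx context.Context, appliers []applier) error {
    for _, a := range appliers {
        if err := a.apply(ctx); err != nil {
            return fmt.Errorf("failed to set %s address mode: %w", a.name, err)
        }
    }
    return nil
}

func main() {
    appliers := []applier{
        {name: "vault", apply: func(context.Context) error { return nil }},
        {name: "smtp", apply: func(context.Context) error { return nil }},
        {name: "imap", apply: func(context.Context) error { return errors.New("service busy") }},
    }

    if err := setEverywhere(context.Background(), appliers); err != nil {
        fmt.Println(err)
    }
}
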
// BadEventFeedbackResync sends user feedback whether should do message re-sync.
func (user *User) BadEventFeedbackResync(ctx context.Context) {
user.CancelSyncAndEventPoll()
func (user *User) BadEventFeedbackResync(ctx context.Context) error {
if err := user.imapService.OnBadEventResync(ctx); err != nil {
return fmt.Errorf("failed to execute bad event request on imap service: %w", err)
}

// We need to cancel the event poll later again as it is not guaranteed, due to timing, that we have a
// task to cancel.
if err := user.syncUserAddressesLabelsAndClearSync(ctx, true); err != nil {
user.log.WithError(err).Error("Bad event resync failed")
if err := user.identityService.Resync(ctx); err != nil {
return fmt.Errorf("failed to resync identity service: %w", err)
}

if err := user.smtpService.Resync(ctx); err != nil {
return fmt.Errorf("failed to resync smtp service: %w", err)
}

if err := user.imapService.Resync(ctx); err != nil {
return fmt.Errorf("failed to resync imap service: %w", err)
}

return nil
}

func (user *User) OnBadEvent(ctx context.Context) {
if err := user.imapService.OnBadEvent(ctx); err != nil {
user.log.WithError(err).Error("Failed to notify imap service of bad event")
}
}

// SetShowAllMail sets whether to show the All Mail mailbox.
func (user *User) SetShowAllMail(show bool) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()

user.log.WithField("show", show).Info("Setting show all mail")

atomic.StoreUint32(&user.showAllMail, b32(show))
if err := user.imapService.ShowAllMail(ctx, show); err != nil {
user.log.WithError(err).Error("Failed to set show all mail")
}
}

// GetGluonIDs returns the users gluon IDs.
@ -498,16 +485,28 @@ func (user *User) BridgePass() []byte {

// UsedSpace returns the total space used by the user on the API.
func (user *User) UsedSpace() int {
return safe.RLockRet(func() int {
return user.apiUser.UsedSpace
}, user.apiUserLock)
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()

apiUser, err := user.identityService.GetAPIUser(ctx)
if err != nil {
return 0
}

return apiUser.UsedSpace
}

// MaxSpace returns the amount of space the user can use on the API.
func (user *User) MaxSpace() int {
return safe.RLockRet(func() int {
return user.apiUser.MaxSpace
}, user.apiUserLock)
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()

apiUser, err := user.identityService.GetAPIUser(ctx)
if err != nil {
return 0
}

return apiUser.MaxSpace
}

// GetEventCh returns a channel which notifies of events happening to the user (such as deauth, address change).
@ -515,118 +514,32 @@ func (user *User) GetEventCh() <-chan events.Event {
return user.eventCh.GetChannel()
}

// NewIMAPConnector returns an IMAP connector for the given address.
// If not in split mode, this must be the primary address.
func (user *User) NewIMAPConnector(addrID string) connector.Connector {
return newIMAPConnector(user, addrID)
}

// NewIMAPConnectors returns IMAP connectors for each of the user's addresses.
// In combined mode, this is just the user's primary address.
// In split mode, this is all the user's addresses.
func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
return safe.RLockRetErr(func() (map[string]connector.Connector, error) {
imapConn := make(map[string]connector.Connector)

switch user.vault.AddressMode() {
case vault.CombinedMode:
primAddr, err := usertypes.GetAddrIdx(user.apiAddrs, 0)
if err != nil {
return nil, fmt.Errorf("failed to get primary address: %w", err)
}

imapConn[primAddr.ID] = newIMAPConnector(user, primAddr.ID)

case vault.SplitMode:
for addrID := range user.apiAddrs {
imapConn[addrID] = newIMAPConnector(user, addrID)
}
}

return imapConn, nil
}, user.apiAddrsLock)
}

// CheckAuth returns whether the given email and password can be used to authenticate over IMAP or SMTP with this user.
// It returns the address ID of the authenticated address.
func (user *User) CheckAuth(email string, password []byte) (string, error) {
user.log.WithField("email", logging.Sensitive(email)).Debug("Checking authentication")
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
defer cancel()

if email == "crash@bandicoot" {
panic("your wish is my command.. I crash")
}

dec, err := algo.B64RawDecode(password)
if err != nil {
return "", fmt.Errorf("failed to decode password: %w", err)
}

if subtle.ConstantTimeCompare(user.vault.BridgePass(), dec) != 1 {
err := fmt.Errorf("invalid password")
user.ReportConfigStatusFailure(err.Error())
return "", err
}

return safe.RLockRetErr(func() (string, error) {
for _, addr := range user.apiAddrs {
if addr.Status != proton.AddressStatusEnabled {
continue
}

if strings.EqualFold(addr.Email, email) {
return addr.ID, nil
}
}

return "", fmt.Errorf("invalid email")
}, user.apiAddrsLock)
return user.identityService.CheckAuth(ctx, email, password)
}

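CheckAuth now delegates to user.identityService.CheckAuth, but the inline version removed above spells out the essential steps: decode the supplied password, compare it against the stored bridge password in constant time, and match the email case-insensitively against the enabled addresses. A stand-alone sketch of that kind of check, with hypothetical stored credentials.

package main

import (
    "crypto/subtle"
    "errors"
    "fmt"
    "strings"
)

// checkAuth compares the supplied credentials against the stored ones in
// constant time and returns the ID of the matching, enabled address.
func checkAuth(email string, password, storedPass []byte, enabled map[string]string) (string, error) {
    if subtle.ConstantTimeCompare(storedPass, password) != 1 {
        return "", errors.New("invalid password")
    }

    for id, addr := range enabled {
        if strings.EqualFold(addr, email) {
            return id, nil
        }
    }

    return "", errors.New("invalid email")
}

func main() {
    enabled := map[string]string{"addr-1": "user@example.com"}

    id, err := checkAuth("User@Example.com", []byte("secret"), []byte("secret"), enabled)
    fmt.Println(id, err)
}
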
// OnStatusUp is called when the connection goes up.
func (user *User) OnStatusUp(context.Context) {
func (user *User) OnStatusUp(ctx context.Context) {
user.log.Info("Connection is up")

user.goSync()
if err := user.imapService.ResumeSync(ctx); err != nil {
user.log.WithError(err).Error("Failed to resume sync")
}
}

// OnStatusDown is called when the connection goes down.
func (user *User) OnStatusDown(context.Context) {
func (user *User) OnStatusDown(ctx context.Context) {
user.log.Info("Connection is down")

user.syncAbort.Abort()
user.pollAbort.Abort()
}

// GetSyncStatus returns the sync status of the user.
func (user *User) GetSyncStatus() vault.SyncStatus {
return user.vault.GetSyncStatus()
}

// ClearSyncStatus clears the sync status of the user.
// This also drops any updates in the update channel(s).
// Warning: the gluon user must be removed and re-added if this happens!
func (user *User) ClearSyncStatus() error {
user.log.Info("Clearing sync status")

return safe.LockRet(func() error {
return user.clearSyncStatus()
}, user.eventLock, user.apiAddrsLock, user.updateChLock)
}

// clearSyncStatus clears the sync status of the user.
// This also drops any updates in the update channel(s).
// Warning: the gluon user must be removed and re-added if this happens!
// It is assumed that the eventLock, apiAddrsLock and updateChLock are already locked.
func (user *User) clearSyncStatus() error {
user.log.Info("Clearing sync status")

user.initUpdateCh(user.vault.AddressMode())

if err := user.vault.ClearSyncStatus(); err != nil {
return fmt.Errorf("failed to clear sync status: %w", err)
user.eventService.Pause()
if err := user.imapService.CancelSync(ctx); err != nil {
user.log.WithError(err).Error("Failed to cancel sync")
}

return nil
}

// Logout logs the user out from the API.
@ -635,6 +548,19 @@ func (user *User) Logout(ctx context.Context, withAPI bool) error {

user.log.Debug("Canceling ongoing tasks")

if err := user.imapService.OnLogout(ctx); err != nil {
return fmt.Errorf("failed to remove user from server: %w", err)
}

// Stop Services
user.serviceGroup.CancelAndWait()

// Cleanup Event Service.
user.eventService.Close()

// Close imap service.
user.imapService.Close()

user.tasks.CancelAndWait()

if withAPI {

@ -658,27 +584,24 @@ func (user *User) Logout(ctx context.Context, withAPI bool) error {
func (user *User) Close() {
user.log.Info("Closing user")

// Stop Services
user.serviceGroup.CancelAndWait()

// Cleanup Event Service.
user.eventService.Close()

// Close imap service.
user.imapService.Close()

// Stop any ongoing background tasks.
user.tasks.CancelAndWait()

// Close the user's API client.
user.client.Close()

// Close the user's update channels.
safe.Lock(func() {
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
updateCh.CloseAndDiscardQueued()
}

user.updateCh = make(map[string]*async.QueuedChannel[imap.Update])
}, user.updateChLock)

// Close the user's notify channel.
user.eventCh.CloseAndDiscardQueued()

// Cleanup Event Service.
user.eventService.Close()

// Close the user's vault.
if err := user.vault.Close(); err != nil {
user.log.WithError(err).Error("Failed to close vault")
@ -715,12 +638,6 @@ func (user *User) SendTelemetry(ctx context.Context, data []byte) error {
return nil
}

func (user *User) WithSMTPData(ctx context.Context, op func(context.Context, map[string]proton.Address, proton.User, *vault.User) error) error {
return safe.RLockRet(func() error {
return op(ctx, user.apiAddrs, user.apiUser, user.vault)
}, user.apiUserLock, user.apiAddrsLock, user.eventLock)
}

func (user *User) ReportSMTPAuthFailed(username string) {
emails := user.Emails()
for _, mail := range emails {
@ -738,182 +655,6 @@ func (user *User) GetSMTPService() *smtp.Service {
return user.smtpService
}

// initUpdateCh initializes the user's update channels in the given address mode.
// It is assumed that user.apiAddrs and user.updateCh are already locked.
func (user *User) initUpdateCh(mode vault.AddressMode) {
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
updateCh.CloseAndDiscardQueued()
}

user.updateCh = make(map[string]*async.QueuedChannel[imap.Update])

switch mode {
case vault.CombinedMode:
primaryUpdateCh := async.NewQueuedChannel[imap.Update](
0,
0,
user.panicHandler,
"user-update-combined",
)

for addrID := range user.apiAddrs {
user.updateCh[addrID] = primaryUpdateCh
}

case vault.SplitMode:
for addrID := range user.apiAddrs {
user.updateCh[addrID] = async.NewQueuedChannel[imap.Update](
0,
0,
user.panicHandler,
fmt.Sprintf("user-update-split-%v", addrID),
)
}
}
}

// startEvents streams events from the API, logging any errors that occur.
// This does nothing until the sync has been marked as complete.
// When we receive an API event, we attempt to handle it.
// If successful, we update the event ID in the vault.
func (user *User) startEvents(ctx context.Context) {
ticker := proton.NewTicker(EventPeriod, EventJitter, user.panicHandler)
defer ticker.Stop()

for {
var doneCh chan struct{}

select {
case <-ctx.Done():
return

case doneCh = <-user.pollAPIEventsCh:
// ...

case <-ticker.C:
// ...
}

user.log.Debug("Event poll triggered")

if err := user.doEventPoll(ctx); err != nil {
user.log.WithError(err).Error("Failed to poll events")
}

if doneCh != nil {
close(doneCh)
}
}
}

// doEventPoll is called whenever API events should be polled.
func (user *User) doEventPoll(ctx context.Context) error {
user.eventLock.Lock()
defer user.eventLock.Unlock()

gpaEvents, more, err := user.client.GetEvent(ctx, user.vault.EventID())
if err != nil {
return fmt.Errorf("failed to get event (caused by %T): %w", internal.ErrCause(err), err)
}

// If the event ID hasn't changed, there are no new events.
if gpaEvents[len(gpaEvents)-1].EventID == user.vault.EventID() {
user.log.Debug("No new API events")
return nil
}

for _, event := range gpaEvents {
user.log.WithFields(logrus.Fields{
"old": user.vault.EventID(),
"new": event,
}).Info("Received new API event")

// Handle the event.
if err := user.handleAPIEvent(ctx, event); err != nil {
// If the error is a context cancellation, return error to retry later.
if errors.Is(err, context.Canceled) {
return fmt.Errorf("failed to handle event due to context cancellation: %w", err)
}

// If the error is a network error, return error to retry later.
if netErr := new(proton.NetError); errors.As(err, &netErr) {
return fmt.Errorf("failed to handle event due to network issue: %w", err)
}

// Catch all for uncategorized net errors that may slip through.
if netErr := new(net.OpError); errors.As(err, &netErr) {
return fmt.Errorf("failed to handle event due to network issues (uncategorized): %w", err)
}

// In case a json decode error slips through.
if jsonErr := new(json.UnmarshalTypeError); errors.As(err, &jsonErr) {
user.eventCh.Enqueue(events.UncategorizedEventError{
UserID: user.ID(),
Error: err,
})

return fmt.Errorf("failed to handle event due to JSON issue: %w", err)
}

// If the error is an unexpected EOF, return error to retry later.
if errors.Is(err, io.ErrUnexpectedEOF) {
return fmt.Errorf("failed to handle event due to EOF: %w", err)
}

// If the error is a server-side issue, return error to retry later.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status >= 500 {
return fmt.Errorf("failed to handle event due to server error: %w", err)
}

// Otherwise, the error is a client-side issue; notify bridge to handle it.
user.log.WithField("event", event).Warn("Failed to handle API event")

user.eventCh.Enqueue(events.UserBadEvent{
UserID: user.ID(),
OldEventID: user.vault.EventID(),
NewEventID: event.EventID,
EventInfo: event.String(),
Error: err,
})

return fmt.Errorf("failed to handle event due to client error: %w", err)
}

user.log.WithField("event", event).Debug("Handled API event")

// Update the event ID in the vault. If this fails, notify bridge to handle it.
if err := user.vault.SetEventID(event.EventID); err != nil {
user.eventCh.Enqueue(events.UserBadEvent{
UserID: user.ID(),
Error: err,
})

return fmt.Errorf("failed to update event ID: %w", err)
}

user.log.WithField("eventID", event.EventID).Debug("Updated event ID in vault")
}

if more {
user.goPollAPIEvents(false)
}

return nil
}

// b32 returns a uint32 0 or 1 representing b.
func b32(b bool) uint32 {
if b {
return 1
}

return 0
}

// sleepCtx sleeps for the given duration, or until the context is canceled.
func sleepCtx(ctx context.Context, d time.Duration) {
select {
case <-ctx.Done():
case <-time.After(d):
func (user *User) PublishEvent(_ context.Context, event events.Event) {
user.eventCh.Enqueue(event)
}

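With PublishEvent, the user itself acts as the event publisher handed to the services: events are enqueued on the user's event channel and read back through GetEventCh. A minimal sketch of that publisher/consumer split, using a plain buffered channel in place of the async.QueuedChannel used in the real code.

package main

import (
    "context"
    "fmt"
)

// Event is a placeholder for the bridge's event types.
type Event struct{ Name string }

// Publisher decouples services that emit events from the consumer that
// forwards them to the rest of the application.
type Publisher struct{ ch chan Event }

func NewPublisher(size int) *Publisher {
    return &Publisher{ch: make(chan Event, size)}
}

// PublishEvent enqueues an event; the context is accepted for interface
// compatibility but not used by this simple sketch.
func (p *Publisher) PublishEvent(_ context.Context, event Event) {
    p.ch <- event
}

// GetEventCh exposes the receive side to whoever drains the events.
func (p *Publisher) GetEventCh() <-chan Event {
    return p.ch
}

func main() {
    p := NewPublisher(8)
    p.PublishEvent(context.Background(), Event{Name: "user logged in"})

    fmt.Println(<-p.GetEventCh())
}
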
@ -26,6 +26,8 @@ import (
"github.com/ProtonMail/go-proton-api/server"
"github.com/ProtonMail/go-proton-api/server/backend"
"github.com/ProtonMail/proton-bridge/v3/internal/certs"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/services/imapservice"
"github.com/ProtonMail/proton-bridge/v3/internal/telemetry/mocks"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/tests"
@ -147,9 +149,28 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma

ctl := gomock.NewController(tb)
defer ctl.Finish()

manager := mocks.NewMockHeartbeatManager(ctl)

manager.EXPECT().IsTelemetryAvailable(context.Background()).AnyTimes()
user, err := New(ctx, vaultUser, client, nil, apiUser, nil, true, vault.DefaultMaxSyncMemory, tb.TempDir(), manager)

nullEventSubscription := events.NewNullSubscription()
nullServerManager := imapservice.NewNullIMAPServerManager()

user, err := New(
ctx,
vaultUser,
client,
nil,
apiUser,
nil,
true,
vault.DefaultMaxSyncMemory,
tb.TempDir(),
manager,
nullServerManager,
nullEventSubscription,
)
require.NoError(tb, err)
defer user.Close()

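The updated test constructs the user with events.NewNullSubscription and imapservice.NewNullIMAPServerManager, so no real IMAP server or event subscriber is required. A brief illustration of that null-object idea, with a hypothetical ServerManager interface and a no-op implementation; the method set shown here is invented for the example.

package main

import (
    "context"
    "fmt"
)

// ServerManager is a stand-in for a dependency the code under test needs
// but the test does not care about.
type ServerManager interface {
    AddUser(ctx context.Context, userID string) error
    RemoveUser(ctx context.Context, userID string) error
}

// nullServerManager satisfies the interface with no-ops, so tests can
// construct the object under test without a real server.
type nullServerManager struct{}

func (nullServerManager) AddUser(context.Context, string) error    { return nil }
func (nullServerManager) RemoveUser(context.Context, string) error { return nil }

func NewNullServerManager() ServerManager { return nullServerManager{} }

func main() {
    var m ServerManager = NewNullServerManager()
    fmt.Println(m.AddUser(context.Background(), "user-1"))
}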