mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-17 23:56:56 +00:00
feat(GODT-2500): Add panic handlers everywhere.
This commit is contained in:
@ -225,7 +225,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.Add
|
||||
user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID]
|
||||
|
||||
case vault.SplitMode:
|
||||
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressCreated{
|
||||
@ -284,7 +284,7 @@ func (user *User) handleUpdateAddressEvent(_ context.Context, event proton.Addre
|
||||
user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID]
|
||||
|
||||
case vault.SplitMode:
|
||||
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressEnabled{
|
||||
@ -594,7 +594,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
|
||||
"subject": logging.Sensitive(message.Subject),
|
||||
}).Info("Handling message created event")
|
||||
|
||||
full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
|
||||
full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
// If the message is not found, it means that it has been deleted before we could fetch it.
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
||||
@ -686,7 +686,7 @@ func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.Messa
|
||||
"subject": logging.Sensitive(event.Message.Subject),
|
||||
}).Info("Handling draft updated event")
|
||||
|
||||
full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
|
||||
full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
// If the message is not found, it means that it has been deleted before we could fetch it.
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
||||
|
||||
@ -290,7 +290,7 @@ func (conn *imapConnector) CreateMessage(
|
||||
conn.log.WithField("messageID", messageID).Warn("Message already sent")
|
||||
|
||||
// Query the server-side message.
|
||||
full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
|
||||
full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
return imap.Message{}, nil, fmt.Errorf("failed to fetch message: %w", err)
|
||||
}
|
||||
@ -354,7 +354,7 @@ func (conn *imapConnector) CreateMessage(
|
||||
}
|
||||
|
||||
func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) {
|
||||
msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
|
||||
msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -572,7 +572,7 @@ func (conn *imapConnector) importMessage(
|
||||
|
||||
var err error
|
||||
|
||||
if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator()); err != nil {
|
||||
if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()); err != nil {
|
||||
return fmt.Errorf("failed to fetch message: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@ -48,6 +48,8 @@ import (
|
||||
|
||||
// sendMail sends an email from the given address to the given recipients.
|
||||
func (user *User) sendMail(authID string, from string, to []string, r io.Reader) error {
|
||||
defer user.handlePanic()
|
||||
|
||||
return safe.RLockRet(func() error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
@ -143,7 +145,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
|
||||
}
|
||||
|
||||
// Send the message using the correct key.
|
||||
sent, err := sendWithKey(
|
||||
sent, err := user.sendWithKey(
|
||||
ctx,
|
||||
user.client,
|
||||
user.reporter,
|
||||
@ -167,7 +169,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
|
||||
}
|
||||
|
||||
// sendWithKey sends the message with the given address key.
|
||||
func sendWithKey(
|
||||
func (user *User) sendWithKey(
|
||||
ctx context.Context,
|
||||
client *proton.Client,
|
||||
sentry reporter.Reporter,
|
||||
@ -226,12 +228,12 @@ func sendWithKey(
|
||||
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
|
||||
}
|
||||
|
||||
attKeys, err := createAttachments(ctx, client, addrKR, draft.ID, message.Attachments)
|
||||
attKeys, err := user.createAttachments(ctx, client, addrKR, draft.ID, message.Attachments)
|
||||
if err != nil {
|
||||
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
|
||||
}
|
||||
|
||||
recipients, err := getRecipients(ctx, client, userKR, settings, draft)
|
||||
recipients, err := user.getRecipients(ctx, client, userKR, settings, draft)
|
||||
if err != nil {
|
||||
return proton.Message{}, fmt.Errorf("failed to get recipients: %w", err)
|
||||
}
|
||||
@ -377,7 +379,7 @@ func createDraft(
|
||||
})
|
||||
}
|
||||
|
||||
func createAttachments(
|
||||
func (user *User) createAttachments(
|
||||
ctx context.Context,
|
||||
client *proton.Client,
|
||||
addrKR *crypto.KeyRing,
|
||||
@ -390,6 +392,8 @@ func createAttachments(
|
||||
}
|
||||
|
||||
keys, err := parallel.MapContext(ctx, runtime.NumCPU(), attachments, func(ctx context.Context, att message.Attachment) (attKey, error) {
|
||||
defer user.handlePanic()
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"name": logging.Sensitive(att.Name),
|
||||
"contentID": att.ContentID,
|
||||
@ -455,7 +459,7 @@ func createAttachments(
|
||||
return attKeys, nil
|
||||
}
|
||||
|
||||
func getRecipients(
|
||||
func (user *User) getRecipients(
|
||||
ctx context.Context,
|
||||
client *proton.Client,
|
||||
userKR *crypto.KeyRing,
|
||||
@ -467,6 +471,8 @@ func getRecipients(
|
||||
})
|
||||
|
||||
prefs, err := parallel.MapContext(ctx, runtime.NumCPU(), addresses, func(ctx context.Context, recipient string) (proton.SendPreferences, error) {
|
||||
defer user.handlePanic()
|
||||
|
||||
pubKeys, recType, err := client.GetPublicKeys(ctx, recipient)
|
||||
if err != nil {
|
||||
return proton.SendPreferences{}, fmt.Errorf("failed to get public key for %v: %w", recipient, err)
|
||||
|
||||
@ -153,7 +153,7 @@ func (user *User) sync(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Sync the messages.
|
||||
if err := syncMessages(
|
||||
if err := user.syncMessages(
|
||||
ctx,
|
||||
user.ID(),
|
||||
messageIDs,
|
||||
@ -242,7 +242,7 @@ func toMB(v uint64) float64 {
|
||||
}
|
||||
|
||||
// nolint:gocyclo
|
||||
func syncMessages(
|
||||
func (user *User) syncMessages(
|
||||
ctx context.Context,
|
||||
userID string,
|
||||
messageIDs []string,
|
||||
@ -370,7 +370,7 @@ func syncMessages(
|
||||
errorCh := make(chan error, maxParallelDownloads*4)
|
||||
|
||||
// Go routine in charge of downloading message metadata
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||
defer close(downloadCh)
|
||||
const MetadataDataPageSize = 150
|
||||
|
||||
@ -433,14 +433,14 @@ func syncMessages(
|
||||
}, logging.Labels{"sync-stage": "meta-data"})
|
||||
|
||||
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||
defer close(buildCh)
|
||||
defer close(errorCh)
|
||||
defer func() {
|
||||
logrus.Debugf("sync downloader exit")
|
||||
}()
|
||||
|
||||
attachmentDownloader := newAttachmentDownloader(ctx, client, maxParallelDownloads)
|
||||
attachmentDownloader := user.newAttachmentDownloader(ctx, client, maxParallelDownloads)
|
||||
defer attachmentDownloader.close()
|
||||
|
||||
for request := range downloadCh {
|
||||
@ -456,6 +456,8 @@ func syncMessages(
|
||||
}
|
||||
|
||||
result, err := parallel.MapContext(ctx, maxParallelDownloads, request.ids, func(ctx context.Context, id string) (proton.FullMessage, error) {
|
||||
defer user.handlePanic()
|
||||
|
||||
var result proton.FullMessage
|
||||
|
||||
msg, err := client.GetMessage(ctx, id)
|
||||
@ -490,7 +492,7 @@ func syncMessages(
|
||||
}, logging.Labels{"sync-stage": "download"})
|
||||
|
||||
// Goroutine which builds messages after they have been downloaded
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||
defer close(flushCh)
|
||||
defer func() {
|
||||
logrus.Debugf("sync builder exit")
|
||||
@ -509,6 +511,8 @@ func syncMessages(
|
||||
logrus.Debugf("Build request: %v of %v count=%v", index, len(chunks), len(chunk))
|
||||
|
||||
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (*buildRes, error) {
|
||||
defer user.handlePanic()
|
||||
|
||||
return buildRFC822(apiLabels, msg, addrKRs[msg.AddressID], new(bytes.Buffer)), nil
|
||||
})
|
||||
if err != nil {
|
||||
@ -526,7 +530,7 @@ func syncMessages(
|
||||
}, logging.Labels{"sync-stage": "builder"})
|
||||
|
||||
// Goroutine which converts the messages into updates and builds a waitable structure for progress tracking.
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
|
||||
defer close(flushUpdateCh)
|
||||
defer func() {
|
||||
logrus.Debugf("sync flush exit")
|
||||
@ -771,12 +775,12 @@ func attachmentWorker(ctx context.Context, client *proton.Client, work <-chan at
|
||||
}
|
||||
}
|
||||
|
||||
func newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader {
|
||||
func (user *User) newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader {
|
||||
workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
for i := 0; i < workerCount; i++ {
|
||||
workerCh = make(chan attachmentJob)
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{
|
||||
logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{
|
||||
"sync": fmt.Sprintf("att-downloader %v", i),
|
||||
})
|
||||
}
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/async"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
@ -93,6 +94,6 @@ func sortSlice[Item any](items []Item, less func(Item, Item) bool) []Item {
|
||||
return sorted
|
||||
}
|
||||
|
||||
func newProtonAPIScheduler() proton.Scheduler {
|
||||
return proton.NewParallelScheduler(runtime.NumCPU() / 2)
|
||||
func newProtonAPIScheduler(panicHandler async.PanicHandler) proton.Scheduler {
|
||||
return proton.NewParallelScheduler(runtime.NumCPU()/2, panicHandler)
|
||||
}
|
||||
|
||||
@ -91,6 +91,8 @@ type User struct {
|
||||
showAllMail uint32
|
||||
|
||||
maxSyncMemory uint64
|
||||
|
||||
panicHandler async.PanicHandler
|
||||
}
|
||||
|
||||
// New returns a new user.
|
||||
@ -127,7 +129,7 @@ func New(
|
||||
reporter: reporter,
|
||||
sendHash: newSendRecorder(sendEntryExpiry),
|
||||
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0, crashHandler),
|
||||
eventLock: safe.NewRWMutex(),
|
||||
|
||||
apiUser: apiUser,
|
||||
@ -148,6 +150,8 @@ func New(
|
||||
showAllMail: b32(showAllMail),
|
||||
|
||||
maxSyncMemory: maxSyncMemory,
|
||||
|
||||
panicHandler: crashHandler,
|
||||
}
|
||||
|
||||
// Initialize the user's update channels for its current address mode.
|
||||
@ -179,7 +183,10 @@ func New(
|
||||
user.goPollAPIEvents = func(wait bool) {
|
||||
doneCh := make(chan struct{})
|
||||
|
||||
go func() { user.pollAPIEventsCh <- doneCh }()
|
||||
go func() {
|
||||
defer user.handlePanic()
|
||||
user.pollAPIEventsCh <- doneCh
|
||||
}()
|
||||
|
||||
if wait {
|
||||
<-doneCh
|
||||
@ -230,6 +237,12 @@ func New(
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (user *User) handlePanic() {
|
||||
if user.panicHandler != nil {
|
||||
user.panicHandler.HandlePanic()
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) TriggerSync() {
|
||||
user.goSync()
|
||||
}
|
||||
@ -596,7 +609,7 @@ func (user *User) initUpdateCh(mode vault.AddressMode) {
|
||||
|
||||
switch mode {
|
||||
case vault.CombinedMode:
|
||||
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
|
||||
|
||||
for addrID := range user.apiAddrs {
|
||||
user.updateCh[addrID] = primaryUpdateCh
|
||||
@ -604,7 +617,7 @@ func (user *User) initUpdateCh(mode vault.AddressMode) {
|
||||
|
||||
case vault.SplitMode:
|
||||
for addrID := range user.apiAddrs {
|
||||
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -614,7 +627,7 @@ func (user *User) initUpdateCh(mode vault.AddressMode) {
|
||||
// When we receive an API event, we attempt to handle it.
|
||||
// If successful, we update the event ID in the vault.
|
||||
func (user *User) startEvents(ctx context.Context) {
|
||||
ticker := proton.NewTicker(EventPeriod, EventJitter)
|
||||
ticker := proton.NewTicker(EventPeriod, EventJitter, user.panicHandler)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
|
||||
@ -119,7 +119,7 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma
|
||||
saltedKeyPass, err := salts.SaltForKey([]byte(password), apiUser.Keys.Primary().ID)
|
||||
require.NoError(tb, err)
|
||||
|
||||
v, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key"))
|
||||
v, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key"), nil)
|
||||
require.NoError(tb, err)
|
||||
require.False(tb, corrupt)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user