feat(GODT-2500): Add panic handlers everywhere.

This commit is contained in:
Jakub
2023-03-22 17:18:17 +01:00
parent 9f59e61b14
commit ec92c918cd
42 changed files with 283 additions and 130 deletions

View File

@ -225,7 +225,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.Add
user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID]
case vault.SplitMode:
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
}
user.eventCh.Enqueue(events.UserAddressCreated{
@ -284,7 +284,7 @@ func (user *User) handleUpdateAddressEvent(_ context.Context, event proton.Addre
user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID]
case vault.SplitMode:
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
}
user.eventCh.Enqueue(events.UserAddressEnabled{
@ -594,7 +594,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
"subject": logging.Sensitive(message.Subject),
}).Info("Handling message created event")
full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil {
// If the message is not found, it means that it has been deleted before we could fetch it.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
@ -686,7 +686,7 @@ func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.Messa
"subject": logging.Sensitive(event.Message.Subject),
}).Info("Handling draft updated event")
full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(user.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil {
// If the message is not found, it means that it has been deleted before we could fetch it.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {

View File

@ -290,7 +290,7 @@ func (conn *imapConnector) CreateMessage(
conn.log.WithField("messageID", messageID).Warn("Message already sent")
// Query the server-side message.
full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil {
return imap.Message{}, nil, fmt.Errorf("failed to fetch message: %w", err)
}
@ -354,7 +354,7 @@ func (conn *imapConnector) CreateMessage(
}
func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) {
msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator())
if err != nil {
return nil, err
}
@ -572,7 +572,7 @@ func (conn *imapConnector) importMessage(
var err error
if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator()); err != nil {
if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(conn.panicHandler), proton.NewDefaultAttachmentAllocator()); err != nil {
return fmt.Errorf("failed to fetch message: %w", err)
}

View File

@ -48,6 +48,8 @@ import (
// sendMail sends an email from the given address to the given recipients.
func (user *User) sendMail(authID string, from string, to []string, r io.Reader) error {
defer user.handlePanic()
return safe.RLockRet(func() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -143,7 +145,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
}
// Send the message using the correct key.
sent, err := sendWithKey(
sent, err := user.sendWithKey(
ctx,
user.client,
user.reporter,
@ -167,7 +169,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
}
// sendWithKey sends the message with the given address key.
func sendWithKey(
func (user *User) sendWithKey(
ctx context.Context,
client *proton.Client,
sentry reporter.Reporter,
@ -226,12 +228,12 @@ func sendWithKey(
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
}
attKeys, err := createAttachments(ctx, client, addrKR, draft.ID, message.Attachments)
attKeys, err := user.createAttachments(ctx, client, addrKR, draft.ID, message.Attachments)
if err != nil {
return proton.Message{}, fmt.Errorf("failed to create attachments: %w", err)
}
recipients, err := getRecipients(ctx, client, userKR, settings, draft)
recipients, err := user.getRecipients(ctx, client, userKR, settings, draft)
if err != nil {
return proton.Message{}, fmt.Errorf("failed to get recipients: %w", err)
}
@ -377,7 +379,7 @@ func createDraft(
})
}
func createAttachments(
func (user *User) createAttachments(
ctx context.Context,
client *proton.Client,
addrKR *crypto.KeyRing,
@ -390,6 +392,8 @@ func createAttachments(
}
keys, err := parallel.MapContext(ctx, runtime.NumCPU(), attachments, func(ctx context.Context, att message.Attachment) (attKey, error) {
defer user.handlePanic()
logrus.WithFields(logrus.Fields{
"name": logging.Sensitive(att.Name),
"contentID": att.ContentID,
@ -455,7 +459,7 @@ func createAttachments(
return attKeys, nil
}
func getRecipients(
func (user *User) getRecipients(
ctx context.Context,
client *proton.Client,
userKR *crypto.KeyRing,
@ -467,6 +471,8 @@ func getRecipients(
})
prefs, err := parallel.MapContext(ctx, runtime.NumCPU(), addresses, func(ctx context.Context, recipient string) (proton.SendPreferences, error) {
defer user.handlePanic()
pubKeys, recType, err := client.GetPublicKeys(ctx, recipient)
if err != nil {
return proton.SendPreferences{}, fmt.Errorf("failed to get public key for %v: %w", recipient, err)

View File

@ -153,7 +153,7 @@ func (user *User) sync(ctx context.Context) error {
}
// Sync the messages.
if err := syncMessages(
if err := user.syncMessages(
ctx,
user.ID(),
messageIDs,
@ -242,7 +242,7 @@ func toMB(v uint64) float64 {
}
// nolint:gocyclo
func syncMessages(
func (user *User) syncMessages(
ctx context.Context,
userID string,
messageIDs []string,
@ -370,7 +370,7 @@ func syncMessages(
errorCh := make(chan error, maxParallelDownloads*4)
// Go routine in charge of downloading message metadata
logging.GoAnnotated(ctx, func(ctx context.Context) {
logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
defer close(downloadCh)
const MetadataDataPageSize = 150
@ -433,14 +433,14 @@ func syncMessages(
}, logging.Labels{"sync-stage": "meta-data"})
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
logging.GoAnnotated(ctx, func(ctx context.Context) {
logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
defer close(buildCh)
defer close(errorCh)
defer func() {
logrus.Debugf("sync downloader exit")
}()
attachmentDownloader := newAttachmentDownloader(ctx, client, maxParallelDownloads)
attachmentDownloader := user.newAttachmentDownloader(ctx, client, maxParallelDownloads)
defer attachmentDownloader.close()
for request := range downloadCh {
@ -456,6 +456,8 @@ func syncMessages(
}
result, err := parallel.MapContext(ctx, maxParallelDownloads, request.ids, func(ctx context.Context, id string) (proton.FullMessage, error) {
defer user.handlePanic()
var result proton.FullMessage
msg, err := client.GetMessage(ctx, id)
@ -490,7 +492,7 @@ func syncMessages(
}, logging.Labels{"sync-stage": "download"})
// Goroutine which builds messages after they have been downloaded
logging.GoAnnotated(ctx, func(ctx context.Context) {
logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
defer close(flushCh)
defer func() {
logrus.Debugf("sync builder exit")
@ -509,6 +511,8 @@ func syncMessages(
logrus.Debugf("Build request: %v of %v count=%v", index, len(chunks), len(chunk))
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (*buildRes, error) {
defer user.handlePanic()
return buildRFC822(apiLabels, msg, addrKRs[msg.AddressID], new(bytes.Buffer)), nil
})
if err != nil {
@ -526,7 +530,7 @@ func syncMessages(
}, logging.Labels{"sync-stage": "builder"})
// Goroutine which converts the messages into updates and builds a waitable structure for progress tracking.
logging.GoAnnotated(ctx, func(ctx context.Context) {
logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) {
defer close(flushUpdateCh)
defer func() {
logrus.Debugf("sync flush exit")
@ -771,12 +775,12 @@ func attachmentWorker(ctx context.Context, client *proton.Client, work <-chan at
}
}
func newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader {
func (user *User) newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader {
workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
ctx, cancel := context.WithCancel(ctx)
for i := 0; i < workerCount; i++ {
workerCh = make(chan attachmentJob)
logging.GoAnnotated(ctx, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{
logging.GoAnnotated(ctx, user.panicHandler, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{
"sync": fmt.Sprintf("att-downloader %v", i),
})
}

View File

@ -24,6 +24,7 @@ import (
"strings"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/async"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
@ -93,6 +94,6 @@ func sortSlice[Item any](items []Item, less func(Item, Item) bool) []Item {
return sorted
}
func newProtonAPIScheduler() proton.Scheduler {
return proton.NewParallelScheduler(runtime.NumCPU() / 2)
func newProtonAPIScheduler(panicHandler async.PanicHandler) proton.Scheduler {
return proton.NewParallelScheduler(runtime.NumCPU()/2, panicHandler)
}

View File

@ -91,6 +91,8 @@ type User struct {
showAllMail uint32
maxSyncMemory uint64
panicHandler async.PanicHandler
}
// New returns a new user.
@ -127,7 +129,7 @@ func New(
reporter: reporter,
sendHash: newSendRecorder(sendEntryExpiry),
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
eventCh: queue.NewQueuedChannel[events.Event](0, 0, crashHandler),
eventLock: safe.NewRWMutex(),
apiUser: apiUser,
@ -148,6 +150,8 @@ func New(
showAllMail: b32(showAllMail),
maxSyncMemory: maxSyncMemory,
panicHandler: crashHandler,
}
// Initialize the user's update channels for its current address mode.
@ -179,7 +183,10 @@ func New(
user.goPollAPIEvents = func(wait bool) {
doneCh := make(chan struct{})
go func() { user.pollAPIEventsCh <- doneCh }()
go func() {
defer user.handlePanic()
user.pollAPIEventsCh <- doneCh
}()
if wait {
<-doneCh
@ -230,6 +237,12 @@ func New(
return user, nil
}
func (user *User) handlePanic() {
if user.panicHandler != nil {
user.panicHandler.HandlePanic()
}
}
func (user *User) TriggerSync() {
user.goSync()
}
@ -596,7 +609,7 @@ func (user *User) initUpdateCh(mode vault.AddressMode) {
switch mode {
case vault.CombinedMode:
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0)
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
for addrID := range user.apiAddrs {
user.updateCh[addrID] = primaryUpdateCh
@ -604,7 +617,7 @@ func (user *User) initUpdateCh(mode vault.AddressMode) {
case vault.SplitMode:
for addrID := range user.apiAddrs {
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0, user.panicHandler)
}
}
}
@ -614,7 +627,7 @@ func (user *User) initUpdateCh(mode vault.AddressMode) {
// When we receive an API event, we attempt to handle it.
// If successful, we update the event ID in the vault.
func (user *User) startEvents(ctx context.Context) {
ticker := proton.NewTicker(EventPeriod, EventJitter)
ticker := proton.NewTicker(EventPeriod, EventJitter, user.panicHandler)
defer ticker.Stop()
for {

View File

@ -119,7 +119,7 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma
saltedKeyPass, err := salts.SaltForKey([]byte(password), apiUser.Keys.Primary().ID)
require.NoError(tb, err)
v, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key"))
v, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key"), nil)
require.NoError(tb, err)
require.False(tb, corrupt)