mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-17 23:56:56 +00:00
Other: Safer user types
This commit is contained in:
@ -10,12 +10,13 @@ import (
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/stream"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -24,27 +25,43 @@ const (
|
||||
)
|
||||
|
||||
func (user *User) sync(ctx context.Context) error {
|
||||
if !user.vault.SyncStatus().HasLabels {
|
||||
if err := syncLabels(ctx, user.client, maps.Values(user.updateCh)...); err != nil {
|
||||
return fmt.Errorf("failed to sync labels: %w", err)
|
||||
return user.withAddrKRs(func(addrKRs map[string]*crypto.KeyRing) error {
|
||||
logrus.Info("Beginning sync")
|
||||
|
||||
if !user.vault.SyncStatus().HasLabels {
|
||||
logrus.Info("Syncing labels")
|
||||
|
||||
if err := user.updateCh.ValuesErr(func(updateCh []*queue.QueuedChannel[imap.Update]) error {
|
||||
return syncLabels(ctx, user.client, xslices.Unique(updateCh)...)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to sync labels: %w", err)
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasLabels(true); err != nil {
|
||||
return fmt.Errorf("failed to set has labels: %w", err)
|
||||
}
|
||||
} else {
|
||||
logrus.Info("Labels are already synced, skipping")
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasLabels(true); err != nil {
|
||||
return fmt.Errorf("failed to set has labels: %w", err)
|
||||
}
|
||||
}
|
||||
if !user.vault.SyncStatus().HasMessages {
|
||||
logrus.Info("Syncing labels")
|
||||
|
||||
if !user.vault.SyncStatus().HasMessages {
|
||||
if err := user.syncMessages(ctx); err != nil {
|
||||
return fmt.Errorf("failed to sync messages: %w", err)
|
||||
if err := user.updateCh.MapErr(func(updateCh map[string]*queue.QueuedChannel[imap.Update]) error {
|
||||
return syncMessages(ctx, user.ID(), user.client, user.vault, addrKRs, updateCh, user.eventCh)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to sync messages: %w", err)
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasMessages(true); err != nil {
|
||||
return fmt.Errorf("failed to set has messages: %w", err)
|
||||
}
|
||||
} else {
|
||||
logrus.Info("Messages are already synced, skipping")
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasMessages(true); err != nil {
|
||||
return fmt.Errorf("failed to set has messages: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.QueuedChannel[imap.Update]) error {
|
||||
@ -102,48 +119,44 @@ func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) syncMessages(ctx context.Context) error {
|
||||
func syncMessages(
|
||||
ctx context.Context,
|
||||
userID string,
|
||||
client *liteapi.Client,
|
||||
vault *vault.User,
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
updateCh map[string]*queue.QueuedChannel[imap.Update],
|
||||
eventCh *queue.QueuedChannel[events.Event],
|
||||
) error {
|
||||
// Determine which messages to sync.
|
||||
allMetadata, err := user.client.GetAllMessageMetadata(ctx, nil)
|
||||
metadata, err := client.GetAllMessageMetadata(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get all message metadata: %w", err)
|
||||
}
|
||||
|
||||
metadata := allMetadata
|
||||
// Get the message IDs to sync.
|
||||
messageIDs := xslices.Map(metadata, func(metadata liteapi.MessageMetadata) string {
|
||||
return metadata.ID
|
||||
})
|
||||
|
||||
// If possible, begin syncing from one beyond the last synced message.
|
||||
if beginID := user.vault.SyncStatus().LastMessageID; beginID != "" {
|
||||
if idx := xslices.IndexFunc(metadata, func(metadata liteapi.MessageMetadata) bool {
|
||||
return metadata.ID == beginID
|
||||
}); idx >= 0 {
|
||||
metadata = metadata[idx+1:]
|
||||
}
|
||||
if idx := xslices.Index(messageIDs, vault.SyncStatus().LastMessageID); idx >= 0 {
|
||||
messageIDs = messageIDs[idx+1:]
|
||||
}
|
||||
|
||||
// Process the metadata, building the messages.
|
||||
buildCh := stream.Chunk(stream.Map(
|
||||
user.client.GetFullMessages(ctx, xslices.Map(metadata, func(metadata liteapi.MessageMetadata) string {
|
||||
return metadata.ID
|
||||
})...),
|
||||
// Fetch and build each message.
|
||||
buildCh := stream.Map(
|
||||
client.GetFullMessages(ctx, messageIDs...),
|
||||
func(ctx context.Context, full liteapi.FullMessage) (*buildRes, error) {
|
||||
return safe.GetMapErr(
|
||||
user.addrKRs,
|
||||
full.AddressID,
|
||||
func(addrKR *crypto.KeyRing) (*buildRes, error) {
|
||||
return buildRFC822(ctx, full, addrKR)
|
||||
},
|
||||
func() (*buildRes, error) {
|
||||
return nil, fmt.Errorf("address keyring not found")
|
||||
},
|
||||
)
|
||||
return buildRFC822(ctx, full, addrKRs[full.AddressID])
|
||||
},
|
||||
), maxBatchSize)
|
||||
)
|
||||
defer buildCh.Close()
|
||||
|
||||
// Create the flushers, one per update channel.
|
||||
flushers := make(map[string]*flusher)
|
||||
|
||||
for addrID, updateCh := range user.updateCh {
|
||||
for addrID, updateCh := range updateCh {
|
||||
flusher := newFlusher(updateCh, maxUpdateSize)
|
||||
defer flusher.flush(ctx, true)
|
||||
|
||||
@ -151,42 +164,27 @@ func (user *User) syncMessages(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Create a reporter to report sync progress updates.
|
||||
reporter := newReporter(user.ID(), user.eventCh, len(metadata), time.Second)
|
||||
reporter := newReporter(userID, eventCh, len(messageIDs), time.Second)
|
||||
defer reporter.done()
|
||||
|
||||
var count int
|
||||
|
||||
// Send each update to the appropriate flusher.
|
||||
for {
|
||||
batch, err := buildCh.Next(ctx)
|
||||
if errors.Is(err, stream.End) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to get next sync batch: %w", err)
|
||||
return forEach(ctx, stream.Chunk(buildCh, maxBatchSize), func(batch []*buildRes) error {
|
||||
for _, res := range batch {
|
||||
flushers[res.addressID].push(ctx, res.update)
|
||||
}
|
||||
|
||||
user.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
|
||||
for _, res := range batch {
|
||||
if len(flushers) > 1 {
|
||||
flushers[res.addressID].push(ctx, res.update)
|
||||
} else {
|
||||
flushers[apiAddrs[0].ID].push(ctx, res.update)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
for _, flusher := range flushers {
|
||||
flusher.flush(ctx, true)
|
||||
}
|
||||
|
||||
if err := user.vault.SetLastMessageID(batch[len(batch)-1].messageID); err != nil {
|
||||
if err := vault.SetLastMessageID(batch[len(batch)-1].messageID); err != nil {
|
||||
return fmt.Errorf("failed to set last synced message ID: %w", err)
|
||||
}
|
||||
|
||||
reporter.add(len(batch))
|
||||
|
||||
count += len(batch)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func newSystemMailboxCreatedUpdate(labelID imap.LabelID, labelName string) *imap.MailboxCreated {
|
||||
@ -232,3 +230,18 @@ func wantLabelID(labelID string) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func forEach[T any](ctx context.Context, streamer stream.Stream[T], fn func(T) error) error {
|
||||
for {
|
||||
res, err := streamer.Next(ctx)
|
||||
if errors.Is(err, stream.End) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to get next stream item: %w", err)
|
||||
}
|
||||
|
||||
if err := fn(res); err != nil {
|
||||
return fmt.Errorf("failed to process stream item: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user