diff --git a/internal/safe/map.go b/internal/safe/map.go index 461a9db5..0898a695 100644 --- a/internal/safe/map.go +++ b/internal/safe/map.go @@ -136,6 +136,24 @@ func (m *Map[Key, Val]) GetDeleteErr(key Key, fn func(Val) error) (bool, error) return ok, err } +func (m *Map[Key, Val]) GetFunc(where func(Val) bool, fn func(Val)) bool { + m.lock.RLock() + defer m.lock.RUnlock() + + for _, key := range m.order { + if where(m.data[key]) { + fn(m.data[key]) + return true + } + } + + return false +} + +func (m *Map[Key, Val]) Delete(key Key) bool { + return m.GetDelete(key, func(val Val) {}) +} + func (m *Map[Key, Val]) Set(key Key, val Val) bool { m.lock.Lock() defer m.lock.Unlock() @@ -287,6 +305,20 @@ func MapGetRet[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func( return ret, ok } +func MapGetRetErr[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(Val) (Ret, error)) (Ret, bool, error) { + var ret Ret + + ok, err := m.GetErr(key, func(val Val) error { + var err error + + ret, err = fn(val) + + return err + }) + + return ret, ok, err +} + func MapValuesRet[Key comparable, Val, Ret any](m *Map[Key, Val], fn func([]Val) Ret) Ret { var ret Ret diff --git a/internal/user/events.go b/internal/user/events.go index d4723602..fab08722 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -194,6 +194,8 @@ func (user *User) handleLabelEvents(ctx context.Context, labelEvents []liteapi.L } func (user *User) handleCreateLabelEvent(_ context.Context, event liteapi.LabelEvent) error { //nolint:unparam + user.apiLabels.Set(event.Label.ID, event.Label) + user.updateCh.IterValues(func(updateCh *queue.QueuedChannel[imap.Update]) { updateCh.Enqueue(newMailboxCreatedUpdate(imap.MailboxID(event.ID), getMailboxName(event.Label))) }) @@ -202,6 +204,8 @@ func (user *User) handleCreateLabelEvent(_ context.Context, event liteapi.LabelE } func (user *User) handleUpdateLabelEvent(_ context.Context, event liteapi.LabelEvent) error { //nolint:unparam + user.apiLabels.Set(event.Label.ID, event.Label) + user.updateCh.IterValues(func(updateCh *queue.QueuedChannel[imap.Update]) { updateCh.Enqueue(imap.NewMailboxUpdated(imap.MailboxID(event.ID), getMailboxName(event.Label))) }) @@ -210,6 +214,8 @@ func (user *User) handleUpdateLabelEvent(_ context.Context, event liteapi.LabelE } func (user *User) handleDeleteLabelEvent(_ context.Context, event liteapi.LabelEvent) error { //nolint:unparam + user.apiLabels.Delete(event.Label.ID) + user.updateCh.IterValues(func(updateCh *queue.QueuedChannel[imap.Update]) { updateCh.Enqueue(imap.NewMailboxDeleted(imap.MailboxID(event.ID))) }) diff --git a/internal/user/imap.go b/internal/user/imap.go index fa900edc..b1d403e1 100644 --- a/internal/user/imap.go +++ b/internal/user/imap.go @@ -31,7 +31,6 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/pkg/message" "github.com/bradenaw/juniper/stream" - "github.com/bradenaw/juniper/xslices" "github.com/google/go-cmp/cmp" "gitlab.protontech.ch/go/liteapi" "golang.org/x/exp/slices" @@ -84,12 +83,14 @@ func (conn *imapConnector) Authorize(username string, password []byte) bool { // GetMailbox returns information about the mailbox with the given ID. func (conn *imapConnector) GetMailbox(ctx context.Context, mailboxID imap.MailboxID) (imap.Mailbox, error) { - label, err := conn.client.GetLabel(ctx, string(mailboxID), liteapi.LabelTypeLabel, liteapi.LabelTypeFolder, liteapi.LabelTypeSystem) - if err != nil { - return imap.Mailbox{}, err + mailbox, ok := safe.MapGetRet(conn.apiLabels, string(mailboxID), func(label liteapi.Label) imap.Mailbox { + return toIMAPMailbox(label, conn.flags, conn.permFlags, conn.attrs) + }) + if !ok { + return imap.Mailbox{}, fmt.Errorf("no such mailbox: %s", mailboxID) } - return toIMAPMailbox(label, conn.flags, conn.permFlags, conn.attrs), nil + return mailbox, nil } // CreateMailbox creates a label with the given name. @@ -131,20 +132,13 @@ func (conn *imapConnector) createFolder(ctx context.Context, name []string) (ima var parentID string if len(name) > 1 { - folders, err := conn.client.GetLabels(ctx, liteapi.LabelTypeFolder) - if err != nil { - return imap.Mailbox{}, err - } - - idx := xslices.IndexFunc(folders, func(folder liteapi.Label) bool { - return cmp.Equal(folder.Path, name[:len(name)-1]) - }) - - if idx < 0 { + if ok := conn.apiLabels.GetFunc(func(label liteapi.Label) bool { + return cmp.Equal(label.Path, name[:len(name)-1]) + }, func(label liteapi.Label) { + parentID = label.ID + }); !ok { return imap.Mailbox{}, fmt.Errorf("parent folder %q does not exist", name[:len(name)-1]) } - - parentID = folders[idx].ID } label, err := conn.client.CreateLabel(ctx, liteapi.CreateLabelReq{ @@ -202,20 +196,13 @@ func (conn *imapConnector) updateFolder(ctx context.Context, labelID imap.Mailbo var parentID string if len(name) > 1 { - folders, err := conn.client.GetLabels(ctx, liteapi.LabelTypeFolder) - if err != nil { - return err - } - - idx := xslices.IndexFunc(folders, func(folder liteapi.Label) bool { - return cmp.Equal(folder.Path, name[:len(name)-1]) - }) - - if idx < 0 { + if ok := conn.apiLabels.GetFunc(func(label liteapi.Label) bool { + return cmp.Equal(label.Path, name[:len(name)-1]) + }, func(label liteapi.Label) { + parentID = label.ID + }); !ok { return fmt.Errorf("parent folder %q does not exist", name[:len(name)-1]) } - - parentID = folders[idx].ID } label, err := conn.client.GetLabel(ctx, string(labelID), liteapi.LabelTypeFolder) diff --git a/internal/user/user.go b/internal/user/user.go index eaa4796f..e0345962 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -52,10 +52,11 @@ type User struct { client *liteapi.Client eventCh *queue.QueuedChannel[events.Event] - apiUser *safe.Value[liteapi.User] - apiAddrs *safe.Map[string, liteapi.Address] - updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]] - sendHash *sendRecorder + apiUser *safe.Value[liteapi.User] + apiAddrs *safe.Map[string, liteapi.Address] + apiLabels *safe.Map[string, liteapi.Label] + updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]] + sendHash *sendRecorder tasks *xsync.Group abortable async.Abortable @@ -87,6 +88,12 @@ func New( return nil, fmt.Errorf("failed to unlock user: %w", err) } + // Get the user's API labels. + apiLabels, err := client.GetLabels(ctx, liteapi.LabelTypeSystem, liteapi.LabelTypeFolder, liteapi.LabelTypeLabel) + if err != nil { + return nil, fmt.Errorf("failed to get labels: %w", err) + } + // Get the latest event ID. if encVault.EventID() == "" { eventID, err := client.GetLatestEventID(ctx) @@ -124,10 +131,11 @@ func New( client: client, eventCh: queue.NewQueuedChannel[events.Event](0, 0), - apiUser: safe.NewValue(apiUser), - apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr), - updateCh: safe.NewMapFrom(updateCh, nil), - sendHash: newSendRecorder(sendEntryExpiry), + apiUser: safe.NewValue(apiUser), + apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr), + apiLabels: safe.NewMapFrom(groupBy(apiLabels, func(label liteapi.Label) string { return label.ID }), nil), + updateCh: safe.NewMapFrom(updateCh, nil), + sendHash: newSendRecorder(sendEntryExpiry), tasks: xsync.NewGroup(context.Background()), @@ -154,13 +162,13 @@ func New( // When we receive an API event, we attempt to handle it. // If successful, we update the event ID in the vault. goStream := user.tasks.Trigger(func(ctx context.Context) { - for event := range user.client.NewEventStream(ctx, EventPeriod, EventJitter, user.vault.EventID()) { + async.RangeContext(ctx, user.client.NewEventStream(ctx, EventPeriod, EventJitter, user.vault.EventID()), func(event liteapi.Event) { if err := user.handleAPIEvent(ctx, event); err != nil { user.log.WithError(err).Error("Failed to handle API event") } else if err := user.vault.SetEventID(event.EventID); err != nil { user.log.WithError(err).Error("Failed to update event ID in vault") } - } + }) }) // We only ever want to start one event streamer.