diff --git a/go.mod b/go.mod index 445c59fe..11cdba69 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.18 require ( github.com/0xAX/notificator v0.0.0-20220220101646-ee9b8921e557 github.com/Masterminds/semver/v3 v3.1.1 - github.com/ProtonMail/gluon v0.13.1-0.20221021093632-0b277a6d0226 + github.com/ProtonMail/gluon v0.13.1-0.20221023130957-9bcdfe15b0fb github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a github.com/ProtonMail/go-rfc5322 v0.11.0 github.com/ProtonMail/gopenpgp/v2 v2.4.10 @@ -37,14 +37,14 @@ require ( github.com/pkg/profile v1.6.0 github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.0 - github.com/urfave/cli/v2 v2.16.3 - gitlab.protontech.ch/go/liteapi v0.35.0 + github.com/urfave/cli/v2 v2.20.3 + gitlab.protontech.ch/go/liteapi v0.35.1-0.20221024102125-605f4712c351 go.uber.org/goleak v1.2.0 golang.org/x/exp v0.0.0-20220921164117-439092de6870 golang.org/x/net v0.1.0 golang.org/x/sys v0.1.0 golang.org/x/text v0.4.0 - google.golang.org/grpc v1.49.0 + google.golang.org/grpc v1.50.1 google.golang.org/protobuf v1.28.1 howett.net/plist v1.0.0 ) diff --git a/go.sum b/go.sum index 44de99b9..19b37729 100644 --- a/go.sum +++ b/go.sum @@ -28,8 +28,8 @@ github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo= github.com/ProtonMail/docker-credential-helpers v1.1.0 h1:+kvUIpwWcbtP3WFv5sSvkFn/XLzSqPOB5AAthuk9xPk= github.com/ProtonMail/docker-credential-helpers v1.1.0/go.mod h1:mK0aBveCxhnQ756AmaTfXMZDeULvheYVhF/MWMErN5g= -github.com/ProtonMail/gluon v0.13.1-0.20221021093632-0b277a6d0226 h1:PWPtXMDKHGY3tG6SEeKE5B7BgfHqTDui4DREb0GoTfU= -github.com/ProtonMail/gluon v0.13.1-0.20221021093632-0b277a6d0226/go.mod h1:XW/gcr4jErc5bX5yMqkUq3U+AucC2QZHJ5L231k3Nw4= +github.com/ProtonMail/gluon v0.13.1-0.20221023130957-9bcdfe15b0fb h1:TGRkkuOdF3mIxbu5QMp62dJ2WvfQGgZ4MUVsNuRD7sc= +github.com/ProtonMail/gluon v0.13.1-0.20221023130957-9bcdfe15b0fb/go.mod h1:XW/gcr4jErc5bX5yMqkUq3U+AucC2QZHJ5L231k3Nw4= github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a h1:D+aZah+k14Gn6kmL7eKxoo/4Dr/lK3ChBcwce2+SQP4= github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a/go.mod h1:oTGdE7/DlWIr23G0IKW3OXK9wZ5Hw1GGiaJFccTvZi4= github.com/ProtonMail/go-crypto v0.0.0-20210428141323-04723f9f07d7/go.mod h1:z4/9nQmJSSwwds7ejkxaJwO37dru3geImFUdJlaLzQo= @@ -391,16 +391,16 @@ github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1 github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= -github.com/urfave/cli/v2 v2.16.3 h1:gHoFIwpPjoyIMbJp/VFd+/vuD0dAgFK4B6DpEMFJfQk= -github.com/urfave/cli/v2 v2.16.3/go.mod h1:1CNUng3PtjQMtRzJO4FMXBQvkGtuYRxxiR9xMa7jMwI= +github.com/urfave/cli/v2 v2.20.3 h1:lOgGidH/N5loaigd9HjFsOIhXSTrzl7tBpHswZ428w4= +github.com/urfave/cli/v2 v2.20.3/go.mod h1:1CNUng3PtjQMtRzJO4FMXBQvkGtuYRxxiR9xMa7jMwI= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0= github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= -gitlab.protontech.ch/go/liteapi v0.35.0 h1:26Yx8+pp14MPxV1DL8gU+cahPigrEATzm3WpoNQs7kU= -gitlab.protontech.ch/go/liteapi v0.35.0/go.mod h1:NuLWhVn8c0bR9qUaJER7VbPFu7oEowWWcP5ANgQHwRo= +gitlab.protontech.ch/go/liteapi v0.35.1-0.20221024102125-605f4712c351 h1:0qom9c2XYMdW1J8JAMBgflL0a8SYwloi9tbE8FV2O2A= +gitlab.protontech.ch/go/liteapi v0.35.1-0.20221024102125-605f4712c351/go.mod h1:VMv325o/Z3h9c6JL7loDVjXF0mqs4/7JbzhGmABxDj8= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= @@ -575,8 +575,8 @@ google.golang.org/genproto v0.0.0-20220921223823-23cae91e6737/go.mod h1:2r/26NEF google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= -google.golang.org/grpc v1.49.0 h1:WTLtQzmQori5FUH25Pq4WT22oCsv8USpQ+F6rqtsmxw= -google.golang.org/grpc v1.49.0/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI= +google.golang.org/grpc v1.50.1 h1:DS/BukOZWp8s6p4Dt/tOaJaTQyPyOoCcrjroHuCeLzY= +google.golang.org/grpc v1.50.1/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= diff --git a/internal/async/context.go b/internal/async/context.go new file mode 100644 index 00000000..6faabd34 --- /dev/null +++ b/internal/async/context.go @@ -0,0 +1,53 @@ +// Copyright (c) 2022 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package async + +import ( + "context" + "sync" +) + +// Abortable collects groups of functions that can be aborted by calling Abort. +type Abortable struct { + abortFunc []context.CancelFunc + abortLock sync.RWMutex +} + +func (a *Abortable) Do(ctx context.Context, fn func(context.Context)) { + fn(a.newCancelCtx(ctx)) +} + +func (a *Abortable) Abort() { + a.abortLock.RLock() + defer a.abortLock.RUnlock() + + for _, fn := range a.abortFunc { + fn() + } +} + +func (a *Abortable) newCancelCtx(ctx context.Context) context.Context { + a.abortLock.Lock() + defer a.abortLock.Unlock() + + ctx, cancel := context.WithCancel(ctx) + + a.abortFunc = append(a.abortFunc, cancel) + + return ctx +} diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index beebf594..d0baa2fd 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -332,11 +332,9 @@ func (bridge *Bridge) Close(ctx context.Context) error { } // Close all users. - if err := bridge.users.IterValuesErr(func(user *user.User) error { - return user.Close() - }); err != nil { - logrus.WithError(err).Error("Failed to close users") - } + bridge.users.IterValues(func(user *user.User) { + user.Close() + }) // Close the focus service. bridge.focusService.Close() diff --git a/internal/bridge/user.go b/internal/bridge/user.go index 745de222..2eceaefe 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -505,9 +505,7 @@ func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error { logrus.WithError(err).Error("Failed to logout user") } - if err := user.Close(); err != nil { - logrus.WithError(err).Error("Failed to close user") - } + user.Close() return nil }); !ok { @@ -532,9 +530,7 @@ func (bridge *Bridge) deleteUser(ctx context.Context, userID string) { logrus.WithError(err).Error("Failed to logout user") } - if err := user.Close(); err != nil { - logrus.WithError(err).Error("Failed to close user") - } + user.Close() }); !ok { logrus.Debug("The bridge user was not connected") } diff --git a/internal/safe/map.go b/internal/safe/map.go index 1d9301f0..61e12dce 100644 --- a/internal/safe/map.go +++ b/internal/safe/map.go @@ -136,34 +136,58 @@ func (m *Map[Key, Val]) GetDeleteErr(key Key, fn func(Val) error) (bool, error) return ok, err } -func (m *Map[Key, Val]) Set(key Key, val Val) { +func (m *Map[Key, Val]) Set(key Key, val Val) bool { m.lock.Lock() defer m.lock.Unlock() + var had bool + + if _, ok := m.data[key]; ok { + had = true + } + m.data[key] = val - m.order = append(m.order, key) + if idx := xslices.Index(m.order, key); idx >= 0 { + m.order[idx] = key + } else { + m.order = append(m.order, key) + } if m.sort != nil { slices.SortFunc(m.order, func(a, b Key) bool { return m.sort(a, b, m.data) }) } + + return had } -func (m *Map[Key, Val]) SetFrom(key Key, other Key) { +func (m *Map[Key, Val]) SetFrom(key Key, other Key) bool { m.lock.Lock() defer m.lock.Unlock() + var had bool + + if _, ok := m.data[key]; ok { + had = true + } + m.data[key] = m.data[other] - m.order = append(m.order, key) + if idx := xslices.Index(m.order, key); idx >= 0 { + m.order[idx] = key + } else { + m.order = append(m.order, key) + } if m.sort != nil { slices.SortFunc(m.order, func(a, b Key) bool { return m.sort(a, b, m.data) }) } + + return had } func (m *Map[Key, Val]) Iter(fn func(key Key, val Val)) { diff --git a/internal/user/events.go b/internal/user/events.go index e19dd9e0..e50d6eea 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -27,7 +27,6 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/bradenaw/juniper/xslices" - "github.com/sirupsen/logrus" "gitlab.protontech.ch/go/liteapi" ) @@ -92,7 +91,7 @@ func (user *User) handleAddressEvents(ctx context.Context, addressEvents []litea } case liteapi.EventUpdateFlags: - logrus.Warn("Not implemented yet.") + user.log.Warn("Not implemented yet.") } } @@ -100,7 +99,9 @@ func (user *User) handleAddressEvents(ctx context.Context, addressEvents []litea } func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error { - user.apiAddrs.Set(event.Address.ID, event.Address) + if had := user.apiAddrs.Set(event.Address.ID, event.Address); had { + return fmt.Errorf("address %q already exists", event.Address.ID) + } switch user.vault.AddressMode() { case vault.CombinedMode: @@ -132,7 +133,9 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.Ad } func (user *User) handleUpdateAddressEvent(_ context.Context, event liteapi.AddressEvent) error { //nolint:unparam - user.apiAddrs.Set(event.Address.ID, event.Address) + if had := user.apiAddrs.Set(event.Address.ID, event.Address); !had { + return fmt.Errorf("address %q does not exist", event.Address.ID) + } user.eventCh.Enqueue(events.UserAddressUpdated{ UserID: user.ID(), @@ -219,9 +222,6 @@ func (user *User) handleDeleteLabelEvent(_ context.Context, event liteapi.LabelE // handleMessageEvents handles the given message events. func (user *User) handleMessageEvents(ctx context.Context, messageEvents []liteapi.MessageEvent) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - for _, event := range messageEvents { switch event.Action { case liteapi.EventCreate: @@ -249,7 +249,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.Me } return user.withAddrKR(event.Message.AddressID, func(_, addrKR *crypto.KeyRing) error { - buildRes, err := buildRFC822(ctx, full, addrKR) + buildRes, err := buildRFC822(full, addrKR) if err != nil { return fmt.Errorf("failed to build RFC822 message: %w", err) } diff --git a/internal/user/imap.go b/internal/user/imap.go index 13d37b17..ab825547 100644 --- a/internal/user/imap.go +++ b/internal/user/imap.go @@ -20,6 +20,7 @@ package user import ( "context" "fmt" + "sync/atomic" "time" "github.com/ProtonMail/gluon/imap" @@ -79,9 +80,9 @@ func (conn *imapConnector) Authorize(username string, password []byte) bool { return true } -// GetMailbox returns information about the label with the given ID. -func (conn *imapConnector) GetMailbox(ctx context.Context, labelID imap.MailboxID) (imap.Mailbox, error) { - label, err := conn.client.GetLabel(ctx, string(labelID), liteapi.LabelTypeLabel, liteapi.LabelTypeFolder) +// GetMailbox returns information about the mailbox with the given ID. +func (conn *imapConnector) GetMailbox(ctx context.Context, mailboxID imap.MailboxID) (imap.Mailbox, error) { + label, err := conn.client.GetLabel(ctx, string(mailboxID), liteapi.LabelTypeLabel, liteapi.LabelTypeFolder) if err != nil { return imap.Mailbox{}, err } @@ -362,10 +363,7 @@ func (conn *imapConnector) Close(ctx context.Context) error { return nil } -func (conn *imapConnector) IsMailboxVisible(_ context.Context, id imap.MailboxID) bool { - if !conn.GetShowAllMail() && id == liteapi.AllMailLabel { - return false - } - - return true +// IsMailboxVisible returns whether this mailbox should be visible over IMAP. +func (conn *imapConnector) IsMailboxVisible(_ context.Context, mailboxID imap.MailboxID) bool { + return atomic.LoadUint32(&conn.showAllMail) != 0 || mailboxID != liteapi.AllMailLabel } diff --git a/internal/user/smtp.go b/internal/user/smtp.go index 5b7414dc..27fce27d 100644 --- a/internal/user/smtp.go +++ b/internal/user/smtp.go @@ -35,7 +35,6 @@ import ( "github.com/ProtonMail/proton-bridge/v2/pkg/message/parser" "github.com/bradenaw/juniper/parallel" "github.com/bradenaw/juniper/xslices" - "github.com/sirupsen/logrus" "gitlab.protontech.ch/go/liteapi" "golang.org/x/exp/slices" ) @@ -104,7 +103,7 @@ func (user *User) sendMail(authID string, emails []string, from string, to []str return fmt.Errorf("failed to send message: %w", err) } - logrus.WithField("messageID", sent.ID).Info("Message sent") + user.log.WithField("messageID", sent.ID).Info("Message sent") return nil }) diff --git a/internal/user/sync.go b/internal/user/sync.go index 821f30d3..9a8584c4 100644 --- a/internal/user/sync.go +++ b/internal/user/sync.go @@ -33,7 +33,6 @@ import ( "github.com/bradenaw/juniper/stream" "github.com/bradenaw/juniper/xslices" "github.com/google/uuid" - "github.com/sirupsen/logrus" "gitlab.protontech.ch/go/liteapi" ) @@ -42,12 +41,40 @@ const ( maxBatchSize = 1 << 8 ) +// doSync begins syncing the users data. +// It sends a SyncStarted event and then either SyncFinished or SyncFailed +// depending on whether the sync was successful. +func (user *User) doSync(ctx context.Context) error { + user.log.Debug("Beginning user sync") + + user.eventCh.Enqueue(events.SyncStarted{ + UserID: user.ID(), + }) + + if err := user.sync(ctx); err != nil { + user.log.WithError(err).Debug("Failed to sync user") + + user.eventCh.Enqueue(events.SyncFailed{ + UserID: user.ID(), + Err: err, + }) + + return fmt.Errorf("failed to sync: %w", err) + } + + user.log.Debug("Finished user sync") + + user.eventCh.Enqueue(events.SyncFinished{ + UserID: user.ID(), + }) + + return nil +} + func (user *User) sync(ctx context.Context) error { return user.withAddrKRs(func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error { - logrus.Info("Beginning sync") - if !user.vault.SyncStatus().HasLabels { - logrus.Info("Syncing labels") + user.log.Debug("Syncing labels") if err := user.updateCh.ValuesErr(func(updateCh []*queue.QueuedChannel[imap.Update]) error { return syncLabels(ctx, user.client, xslices.Unique(updateCh)...) @@ -59,13 +86,13 @@ func (user *User) sync(ctx context.Context) error { return fmt.Errorf("failed to set has labels: %w", err) } - logrus.Info("Synced labels") + user.log.Debug("Synced labels") } else { - logrus.Info("Labels are already synced, skipping") + user.log.Debug("Labels are already synced, skipping") } if !user.vault.SyncStatus().HasMessages { - logrus.Info("Syncing messages") + user.log.Debug("Syncing messages") if err := user.updateCh.MapErr(func(updateCh map[string]*queue.QueuedChannel[imap.Update]) error { return syncMessages(ctx, user.ID(), user.client, user.vault, addrKRs, updateCh, user.eventCh) @@ -77,9 +104,9 @@ func (user *User) sync(ctx context.Context) error { return fmt.Errorf("failed to set has messages: %w", err) } - logrus.Info("Synced messages") + user.log.Debug("Synced messages") } else { - logrus.Info("Messages are already synced, skipping") + user.log.Debug("Messages are already synced, skipping") } return nil @@ -169,11 +196,10 @@ func syncMessages( //nolint:funlen // Fetch and build each message. buildCh := stream.Map( client.GetFullMessages(ctx, runtime.NumCPU(), runtime.NumCPU(), messageIDs...), - func(ctx context.Context, full liteapi.FullMessage) (*buildRes, error) { - return buildRFC822(ctx, full, addrKRs[full.AddressID]) + func(_ context.Context, full liteapi.FullMessage) (*buildRes, error) { + return buildRFC822(full, addrKRs[full.AddressID]) }, ) - defer buildCh.Close() // Create the flushers, one per update channel. flushers := make(map[string]*flusher) @@ -254,6 +280,8 @@ func wantLabelID(labelID string) bool { } func forEach[T any](ctx context.Context, streamer stream.Stream[T], fn func(T) error) error { + defer streamer.Close() + for { res, err := streamer.Next(ctx) if errors.Is(err, stream.End) { diff --git a/internal/user/sync_build.go b/internal/user/sync_build.go index b4da342c..8ccccad2 100644 --- a/internal/user/sync_build.go +++ b/internal/user/sync_build.go @@ -18,7 +18,6 @@ package user import ( - "context" "fmt" "time" @@ -47,7 +46,7 @@ func defaultJobOpts() message.JobOptions { } } -func buildRFC822(_ context.Context, full liteapi.FullMessage, addrKR *crypto.KeyRing) (*buildRes, error) { +func buildRFC822(full liteapi.FullMessage, addrKR *crypto.KeyRing) (*buildRes, error) { literal, err := message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()) if err != nil { return nil, fmt.Errorf("failed to build message %s: %w", full.ID, err) diff --git a/internal/user/types.go b/internal/user/types.go index 585fa094..ebe4833b 100644 --- a/internal/user/types.go +++ b/internal/user/types.go @@ -18,7 +18,6 @@ package user import ( - "context" "encoding/hex" "fmt" "reflect" @@ -92,22 +91,3 @@ func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) { return "", fmt.Errorf("address %s not found", email) } - -// contextWithStopCh returns a new context that is cancelled when the stop channel is closed or a value is sent to it. -func contextWithStopCh(ctx context.Context, channels ...<-chan struct{}) (context.Context, context.CancelFunc) { - ctx, cancel := context.WithCancel(ctx) - - for _, stopCh := range channels { - go func(ch <-chan struct{}) { - select { - case <-ch: - cancel() - - case <-ctx.Done(): - // ... - } - }(stopCh) - } - - return ctx, cancel -} diff --git a/internal/user/user.go b/internal/user/user.go index 174c1b66..1a86a907 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -30,11 +30,12 @@ import ( "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/queue" + "github.com/ProtonMail/proton-bridge/v2/internal/async" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/safe" - "github.com/ProtonMail/proton-bridge/v2/internal/try" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/bradenaw/juniper/xslices" + "github.com/bradenaw/juniper/xsync" "github.com/sirupsen/logrus" "gitlab.protontech.ch/go/liteapi" ) @@ -45,23 +46,35 @@ var ( ) type User struct { + log *logrus.Entry + vault *vault.User client *liteapi.Client eventCh *queue.QueuedChannel[events.Event] - stopCh chan struct{} apiUser *safe.Value[liteapi.User] apiAddrs *safe.Map[string, liteapi.Address] updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]] - syncStopCh chan struct{} - syncLock try.Group - syncWG sync.WaitGroup + tasks *xsync.Group + abortable async.Abortable + goSync func() - showAllMail int32 + showAllMail uint32 } -func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiUser liteapi.User, showAllMail bool) (*User, error) { //nolint:funlen +// New returns a new user. +// +// nolint:funlen +func New( + ctx context.Context, + encVault *vault.User, + client *liteapi.Client, + apiUser liteapi.User, + showAllMail bool, +) (*User, error) { //nolint:funlen + logrus.WithField("userID", apiUser.ID).Debug("Creating new user") + // Get the user's API addresses. apiAddrs, err := client.GetAddresses(ctx) if err != nil { @@ -104,25 +117,26 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU } user := &User{ + log: logrus.WithField("userID", apiUser.ID), + vault: encVault, client: client, eventCh: queue.NewQueuedChannel[events.Event](0, 0), - stopCh: make(chan struct{}), apiUser: safe.NewValue(apiUser), apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr), updateCh: safe.NewMapFrom(updateCh, nil), - syncStopCh: make(chan struct{}), - } + tasks: xsync.NewGroup(context.Background()), - user.SetShowAllMail(showAllMail) + showAllMail: b32(showAllMail), + } // When we receive an auth object, we update it in the vault. // This will be used to authorize the user on the next run. user.client.AddAuthHandler(func(auth liteapi.Auth) { if err := user.vault.SetAuth(auth.UID, auth.RefreshToken); err != nil { - logrus.WithError(err).Error("Failed to update auth in vault") + user.log.WithError(err).Error("Failed to update auth in vault") } }) @@ -134,24 +148,38 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU }) }) - // GODT-1946 - Don't start the event loop until the initial sync has finished. - eventCh := user.client.NewEventStream(EventPeriod, EventJitter, user.vault.EventID()) - - user.syncWG.Add(1) - // If we haven't synced yet, do it first. - // If it fails, we don't start the event loop. - // Otherwise, begin processing API events, logging any errors that occur. - go func() { - defer user.syncWG.Done() - - if err := <-user.startSync(); err != nil { - return + // Stream events from the API, logging any errors that occur. + // When we receive an API event, we attempt to handle it. + // If successful, we update the event ID in the vault. + goStream := user.tasks.Trigger(func(ctx context.Context) { + for event := range user.client.NewEventStream(ctx, EventPeriod, EventJitter, user.vault.EventID()) { + if err := user.handleAPIEvent(ctx, event); err != nil { + user.log.WithError(err).Error("Failed to handle API event") + } else if err := user.vault.SetEventID(event.EventID); err != nil { + user.log.WithError(err).Error("Failed to update event ID in vault") + } } + }) - for err := range user.streamEvents(eventCh) { - logrus.WithError(err).Error("Error while streaming events") - } - }() + // We only ever want to start one event streamer. + var once sync.Once + + // When triggered, attempt to sync the user. + // If successful, we start the event streamer if we haven't already. + user.goSync = user.tasks.Trigger(func(ctx context.Context) { + user.abortable.Do(ctx, func(ctx context.Context) { + if !user.vault.SyncStatus().IsComplete() { + if err := user.doSync(ctx); err != nil { + return + } + } + + once.Do(goStream) + }) + }) + + // Trigger an initial sync (if necessary) and start the event stream. + user.goSync() return user, nil } @@ -199,9 +227,8 @@ func (user *User) GetAddressMode() vault.AddressMode { // SetAddressMode sets the user's address mode. func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error { - user.stopSync() - user.lockSync() - defer user.unlockSync() + user.abortable.Abort() + defer user.goSync() user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) { for _, updateCh := range xslices.Unique(updateCh) { @@ -235,12 +262,6 @@ func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) er return fmt.Errorf("failed to clear sync status: %w", err) } - go func() { - if err := <-user.startSync(); err != nil { - logrus.WithError(err).Error("Failed to sync after setting address mode") - } - }() - return nil } @@ -364,26 +385,17 @@ func (user *User) CheckAuth(email string, password []byte) (string, error) { // OnStatusUp is called when the connection goes up. func (user *User) OnStatusUp() { - go func() { - logrus.Info("Connection up, checking if sync is needed") - - if err := <-user.startSync(); err != nil { - logrus.WithError(err).Error("Failed to sync on status up") - } - }() + user.goSync() } // OnStatusDown is called when the connection goes down. func (user *User) OnStatusDown() { - logrus.Info("Connection down, aborting any ongoing syncs") - - user.stopSync() + user.abortable.Abort() } // Logout logs the user out from the API. func (user *User) Logout(ctx context.Context) error { - // Cancel ongoing syncs. - user.stopSync() + user.tasks.Wait() if err := user.client.AuthDelete(ctx); err != nil { return fmt.Errorf("failed to delete auth: %w", err) @@ -397,14 +409,9 @@ func (user *User) Logout(ctx context.Context) error { } // Close closes ongoing connections and cleans up resources. -func (user *User) Close() error { - defer user.syncWG.Wait() - - // Close any ongoing operations. - close(user.stopCh) - - // Cancel ongoing syncs. - user.stopSync() +func (user *User) Close() { + // Stop any ongoing background tasks. + user.tasks.Wait() // Close the user's API client. user.client.Close() @@ -421,113 +428,20 @@ func (user *User) Close() error { // Close the user's vault. if err := user.vault.Close(); err != nil { - logrus.WithError(err).Error("Failed to close vault") + user.log.WithError(err).Error("Failed to close vault") } - - return nil } +// SetShowAllMail sets whether to show the All Mail mailbox. func (user *User) SetShowAllMail(show bool) { - var value int32 + atomic.StoreUint32(&user.showAllMail, b32(show)) +} - if show { - value = 1 - } else { - value = 0 +// b32 returns a uint32 0 or 1 representing b. +func b32(b bool) uint32 { + if b { + return 1 } - atomic.StoreInt32(&user.showAllMail, value) -} - -func (user *User) GetShowAllMail() bool { - return atomic.LoadInt32(&user.showAllMail) == 1 -} - -// streamEvents begins streaming API events for the user. -// When we receive an API event, we attempt to handle it. -// If successful, we update the event ID in the vault. -func (user *User) streamEvents(eventCh <-chan liteapi.Event) <-chan error { - errCh := make(chan error) - - go func() { - defer close(errCh) - - ctx, cancel := contextWithStopCh(context.Background(), user.stopCh) - defer cancel() - - for event := range eventCh { - if err := user.handleAPIEvent(ctx, event); err != nil { - errCh <- fmt.Errorf("failed to handle API event: %w", err) - } else if err := user.vault.SetEventID(event.EventID); err != nil { - errCh <- fmt.Errorf("failed to update event ID: %w", err) - } - } - }() - - return errCh -} - -// startSync begins a startSync for the user. -func (user *User) startSync() <-chan error { - errCh := make(chan error) - - user.syncLock.GoTry(func(ok bool) { - defer close(errCh) - - if user.vault.SyncStatus().IsComplete() { - logrus.Debug("Already synced, skipping") - return - } - - if !ok { - logrus.Debug("Sync already in progress, skipping") - return - } - - ctx, cancel := contextWithStopCh(context.Background(), user.stopCh, user.syncStopCh) - defer cancel() - - user.eventCh.Enqueue(events.SyncStarted{ - UserID: user.ID(), - }) - - if err := user.sync(ctx); err != nil { - user.eventCh.Enqueue(events.SyncFailed{ - UserID: user.ID(), - Err: err, - }) - - errCh <- err - } else { - user.eventCh.Enqueue(events.SyncFinished{ - UserID: user.ID(), - }) - } - }) - - return errCh -} - -// AbortSync aborts any ongoing sync. -// GODT-1947: Should probably be done automatically when one of the user's IMAP connectors is closed. -func (user *User) stopSync() { - defer user.syncLock.Wait() - - select { - case user.syncStopCh <- struct{}{}: - logrus.Debug("Sent sync abort signal") - - default: - logrus.Debug("No sync to abort") - } -} - -// lockSync prevents a new sync from starting. -func (user *User) lockSync() { - user.syncLock.Lock() -} - -// unlockSync allows a new sync to start. -func (user *User) unlockSync() { - user.syncLock.Unlock() + return 0 } diff --git a/internal/user/user_test.go b/internal/user/user_test.go index 9ec9c548..f83c5248 100644 --- a/internal/user/user_test.go +++ b/internal/user/user_test.go @@ -22,6 +22,7 @@ import ( "testing" "time" + "github.com/ProtonMail/gluon/connector" "github.com/ProtonMail/proton-bridge/v2/internal/certs" "github.com/ProtonMail/proton-bridge/v2/internal/events" "github.com/ProtonMail/proton-bridge/v2/internal/vault" @@ -30,6 +31,7 @@ import ( "gitlab.protontech.ch/go/liteapi" "gitlab.protontech.ch/go/liteapi/server" "gitlab.protontech.ch/go/liteapi/server/backend" + "go.uber.org/goleak" ) func init() { @@ -39,7 +41,11 @@ func init() { certs.GenerateCert = tests.FastGenerateCert } -func TestUser_Data(t *testing.T) { +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m, goleak.IgnoreCurrent()) +} + +func TestUser_Info(t *testing.T) { withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) { withAccount(t, s, "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(userID string, addrIDs []string) { withUser(t, ctx, s, m, "username", "password", func(user *User) { @@ -66,6 +72,9 @@ func TestUser_Sync(t *testing.T) { withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) { withAccount(t, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) { withUser(t, ctx, s, m, "username", "password", func(user *User) { + // Process the IMAP updates as if we were gluon. + handleUpdates(t, user) + // User starts a sync at startup. require.IsType(t, events.SyncStarted{}, <-user.GetEventCh()) @@ -79,6 +88,36 @@ func TestUser_Sync(t *testing.T) { }) } +func TestUser_AddressMode(t *testing.T) { + withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) { + withAccount(t, s, "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(userID string, addrIDs []string) { + withUser(t, ctx, s, m, "username", "password", func(user *User) { + // Process the IMAP updates as if we were gluon. + handleUpdates(t, user) + + // User finishes syncing at startup. + require.IsType(t, events.SyncStarted{}, <-user.GetEventCh()) + require.IsType(t, events.SyncProgress{}, <-user.GetEventCh()) + require.IsType(t, events.SyncFinished{}, <-user.GetEventCh()) + + // By default, user should be in combined mode. + require.Equal(t, vault.CombinedMode, user.GetAddressMode()) + + // User should be able to switch to split mode. + require.NoError(t, user.SetAddressMode(ctx, vault.SplitMode)) + + // Process the IMAP updates as if we were gluon. + handleUpdates(t, user) + + // User finishes syncing after switching to split mode. + require.IsType(t, events.SyncStarted{}, <-user.GetEventCh()) + require.IsType(t, events.SyncProgress{}, <-user.GetEventCh()) + require.IsType(t, events.SyncFinished{}, <-user.GetEventCh()) + }) + }) + }) +} + func TestUser_Deauth(t *testing.T) { withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) { withAccount(t, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) { @@ -126,7 +165,6 @@ func withAccount(tb testing.TB, s *server.Server, username, password string, ema func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *liteapi.Manager, username, password string, fn func(*User)) { //nolint:unparam,revive client, apiAuth, err := m.NewClientWithLogin(ctx, username, []byte(password)) require.NoError(tb, err) - defer client.Close() apiUser, err := client.GetUser(ctx) require.NoError(tb, err) @@ -146,18 +184,20 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *liteapi.M user, err := New(ctx, vaultUser, client, apiUser, true) require.NoError(tb, err) - defer func() { require.NoError(tb, user.Close()) }() - - imapConn, err := user.NewIMAPConnectors() - require.NoError(tb, err) - - go func() { - for _, imapConn := range imapConn { - for update := range imapConn.GetUpdates() { - update.Done() - } - } - }() + defer user.Close() fn(user) } + +func handleUpdates(t *testing.T, user *User) { + imapConn, err := user.NewIMAPConnectors() + require.NoError(t, err) + + for _, imapConn := range imapConn { + go func(imapConn connector.Connector) { + for update := range imapConn.GetUpdates() { + update.Done() + } + }(imapConn) + } +}