diff --git a/go.mod b/go.mod index 4ee29c3b..9c561565 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.20 require ( github.com/0xAX/notificator v0.0.0-20220220101646-ee9b8921e557 github.com/Masterminds/semver/v3 v3.2.0 - github.com/ProtonMail/gluon v0.17.1-0.20230809071440-23181d8d242e + github.com/ProtonMail/gluon v0.17.1-0.20230814075013-2bbedb1b61ff github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a github.com/ProtonMail/go-proton-api v0.4.1-0.20230724135423-b7d785347afe github.com/ProtonMail/gopenpgp/v2 v2.7.1-proton diff --git a/go.sum b/go.sum index 14148172..859c66d6 100644 --- a/go.sum +++ b/go.sum @@ -23,8 +23,8 @@ github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo= github.com/ProtonMail/docker-credential-helpers v1.1.0 h1:+kvUIpwWcbtP3WFv5sSvkFn/XLzSqPOB5AAthuk9xPk= github.com/ProtonMail/docker-credential-helpers v1.1.0/go.mod h1:mK0aBveCxhnQ756AmaTfXMZDeULvheYVhF/MWMErN5g= -github.com/ProtonMail/gluon v0.17.1-0.20230809071440-23181d8d242e h1:aYzqoTzdZ0sU/fzcouEDvJlOBo6JKVd+4II7qxDa3yA= -github.com/ProtonMail/gluon v0.17.1-0.20230809071440-23181d8d242e/go.mod h1:Og5/Dz1MiGpCJn51XujZwxiLG7WzvvjE5PRpZBQmAHo= +github.com/ProtonMail/gluon v0.17.1-0.20230814075013-2bbedb1b61ff h1:EC+mqYMeA869s8SEYpKNEFzaE5hNW17ypffb7cnvEXU= +github.com/ProtonMail/gluon v0.17.1-0.20230814075013-2bbedb1b61ff/go.mod h1:Og5/Dz1MiGpCJn51XujZwxiLG7WzvvjE5PRpZBQmAHo= github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a h1:D+aZah+k14Gn6kmL7eKxoo/4Dr/lK3ChBcwce2+SQP4= github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a/go.mod h1:oTGdE7/DlWIr23G0IKW3OXK9wZ5Hw1GGiaJFccTvZi4= github.com/ProtonMail/go-crypto v0.0.0-20230321155629-9a39f2531310/go.mod h1:8TI4H3IbrackdNgv+92dI+rhpCaLqM0IfpgCgenFvRE= diff --git a/internal/bridge/settings.go b/internal/bridge/settings.go index bbee9437..5471068f 100644 --- a/internal/bridge/settings.go +++ b/internal/bridge/settings.go @@ -19,9 +19,11 @@ package bridge import ( "context" + "fmt" "github.com/Masterminds/semver/v3" "github.com/ProtonMail/proton-bridge/v3/internal/safe" + "github.com/ProtonMail/proton-bridge/v3/internal/services/userevents" "github.com/ProtonMail/proton-bridge/v3/internal/updater" "github.com/ProtonMail/proton-bridge/v3/internal/vault" "github.com/sirupsen/logrus" @@ -128,6 +130,38 @@ func (bridge *Bridge) GetGluonDataDir() (string, error) { } func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error { + bridge.usersLock.RLock() + + defer func() { + logrus.Info("Restarting user event loops") + for _, u := range bridge.users { + u.ResumeEventLoop() + } + + bridge.usersLock.RUnlock() + }() + + type waiter struct { + w *userevents.EventPollWaiter + id string + } + + waiters := make([]waiter, 0, len(bridge.users)) + + logrus.Info("Pausing user event loops for gluon dir change") + for id, u := range bridge.users { + waiters = append(waiters, waiter{w: u.PauseEventLoopWithWaiter(), id: id}) + } + + logrus.Info("Waiting on user event loop completion") + for _, waiter := range waiters { + if err := waiter.w.WaitPollFinished(ctx); err != nil { + logrus.WithError(err).Errorf("Failed to wait on event loop pause for user %v", waiter.id) + return fmt.Errorf("failed on event loop pause: %w", err) + } + } + + logrus.Info("Changing gluon directory") return bridge.serverManager.SetGluonDir(ctx, newGluonDir) } diff --git a/internal/bridge/settings_test.go b/internal/bridge/settings_test.go index f10d1932..1655e80c 100644 --- a/internal/bridge/settings_test.go +++ b/internal/bridge/settings_test.go @@ -25,6 +25,7 @@ import ( "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api/server" "github.com/ProtonMail/proton-bridge/v3/internal/bridge" + "github.com/ProtonMail/proton-bridge/v3/internal/events" "github.com/stretchr/testify/require" ) @@ -51,6 +52,45 @@ func TestBridge_Settings_GluonDir(t *testing.T) { }) } +func TestBridge_Settings_GluonDirWithOnGoingEvents(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) { + userID, addrID, err := s.CreateUser("imap", password) + require.NoError(t, err) + + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{})) + defer done() + + _, err := bridge.LoginFull(context.Background(), "imap", password, nil, nil) + require.NoError(t, err) + + <-syncCh + }) + + labelID, err := s.CreateLabel(userID, "folder", "", proton.LabelTypeFolder) + require.NoError(t, err) + + withClient(ctx, t, s, "imap", password, func(ctx context.Context, c *proton.Client) { + createNumMessages(ctx, t, c, addrID, labelID, 200) + }) + + withBridgeWaitForServers(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // Create a new location for the Gluon data. + newGluonDir := t.TempDir() + + // Move the gluon dir; it should also move the user's data. + require.NoError(t, bridge.SetGluonDir(context.Background(), newGluonDir)) + + // Check that the new directory is not empty. + entries, err := os.ReadDir(newGluonDir) + require.NoError(t, err) + + // There should be at least one entry. + require.NotEmpty(t, entries) + }) + }) +} + func TestBridge_Settings_IMAPPort(t *testing.T) { withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) { withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { diff --git a/internal/services/imapservice/connector.go b/internal/services/imapservice/connector.go index 2eb102fb..2fcb96a3 100644 --- a/internal/services/imapservice/connector.go +++ b/internal/services/imapservice/connector.go @@ -60,6 +60,8 @@ type Connector struct { labels sharedLabels updateCh *async.QueuedChannel[imap.Update] log *logrus.Entry + + sharedCache *SharedCache } func NewConnector( @@ -101,6 +103,8 @@ func NewConnector( "addr-id": addrID, "user-id": userID, }), + + sharedCache: NewSharedCached(), } } @@ -109,6 +113,11 @@ func (s *Connector) StateClose() { s.updateCh.CloseAndDiscardQueued() } +func (s *Connector) Init(_ context.Context, cache connector.IMAPState) error { + s.sharedCache.Set(cache) + return nil +} + func (s *Connector) Authorize(ctx context.Context, username string, password []byte) bool { addrID, err := s.identityState.CheckAuth(username, password, s.telemetry) if err != nil { @@ -124,7 +133,7 @@ func (s *Connector) Authorize(ctx context.Context, username string, password []b return true } -func (s *Connector) CreateMailbox(ctx context.Context, name []string) (imap.Mailbox, error) { +func (s *Connector) CreateMailbox(ctx context.Context, _ connector.IMAPStateWrite, name []string) (imap.Mailbox, error) { if len(name) < 2 { return imap.Mailbox{}, fmt.Errorf("invalid mailbox name %q: %w", name, connector.ErrOperationNotAllowed) } @@ -177,7 +186,7 @@ func (s *Connector) GetMailboxVisibility(_ context.Context, mboxID imap.MailboxI } } -func (s *Connector) UpdateMailboxName(ctx context.Context, mboxID imap.MailboxID, name []string) error { +func (s *Connector) UpdateMailboxName(ctx context.Context, _ connector.IMAPStateWrite, mboxID imap.MailboxID, name []string) error { if len(name) < 2 { return fmt.Errorf("invalid mailbox name %q: %w", name, connector.ErrOperationNotAllowed) } @@ -194,7 +203,7 @@ func (s *Connector) UpdateMailboxName(ctx context.Context, mboxID imap.MailboxID } } -func (s *Connector) DeleteMailbox(ctx context.Context, mboxID imap.MailboxID) error { +func (s *Connector) DeleteMailbox(ctx context.Context, _ connector.IMAPStateWrite, mboxID imap.MailboxID) error { if err := s.client.DeleteLabel(ctx, string(mboxID)); err != nil { return err } @@ -207,7 +216,7 @@ func (s *Connector) DeleteMailbox(ctx context.Context, mboxID imap.MailboxID) er return nil } -func (s *Connector) CreateMessage(ctx context.Context, mailboxID imap.MailboxID, literal []byte, flags imap.FlagSet, _ time.Time) (imap.Message, []byte, error) { +func (s *Connector) CreateMessage(ctx context.Context, _ connector.IMAPStateWrite, mailboxID imap.MailboxID, literal []byte, flags imap.FlagSet, _ time.Time) (imap.Message, []byte, error) { if mailboxID == proton.AllMailLabel { return imap.Message{}, nil, connector.ErrOperationNotAllowed } @@ -305,7 +314,7 @@ func (s *Connector) CreateMessage(ctx context.Context, mailboxID imap.MailboxID, return msg, literal, err } -func (s *Connector) AddMessagesToMailbox(ctx context.Context, messageIDs []imap.MessageID, mboxID imap.MailboxID) error { +func (s *Connector) AddMessagesToMailbox(ctx context.Context, _ connector.IMAPStateWrite, messageIDs []imap.MessageID, mboxID imap.MailboxID) error { if isAllMailOrScheduled(mboxID) { return connector.ErrOperationNotAllowed } @@ -313,7 +322,7 @@ func (s *Connector) AddMessagesToMailbox(ctx context.Context, messageIDs []imap. return s.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), string(mboxID)) } -func (s *Connector) RemoveMessagesFromMailbox(ctx context.Context, messageIDs []imap.MessageID, mboxID imap.MailboxID) error { +func (s *Connector) RemoveMessagesFromMailbox(ctx context.Context, _ connector.IMAPStateWrite, messageIDs []imap.MessageID, mboxID imap.MailboxID) error { if isAllMailOrScheduled(mboxID) { return connector.ErrOperationNotAllowed } @@ -332,7 +341,7 @@ func (s *Connector) RemoveMessagesFromMailbox(ctx context.Context, messageIDs [] return nil } -func (s *Connector) MoveMessages(ctx context.Context, messageIDs []imap.MessageID, mboxFromID, mboxToID imap.MailboxID) (bool, error) { +func (s *Connector) MoveMessages(ctx context.Context, _ connector.IMAPStateWrite, messageIDs []imap.MessageID, mboxFromID, mboxToID imap.MailboxID) (bool, error) { if (mboxFromID == proton.InboxLabel && mboxToID == proton.SentLabel) || (mboxFromID == proton.SentLabel && mboxToID == proton.InboxLabel) || isAllMailOrScheduled(mboxFromID) || @@ -370,7 +379,7 @@ func (s *Connector) MoveMessages(ctx context.Context, messageIDs []imap.MessageI return shouldExpungeOldLocation, nil } -func (s *Connector) MarkMessagesSeen(ctx context.Context, messageIDs []imap.MessageID, seen bool) error { +func (s *Connector) MarkMessagesSeen(ctx context.Context, _ connector.IMAPStateWrite, messageIDs []imap.MessageID, seen bool) error { if seen { return s.client.MarkMessagesRead(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs)...) } @@ -378,7 +387,7 @@ func (s *Connector) MarkMessagesSeen(ctx context.Context, messageIDs []imap.Mess return s.client.MarkMessagesUnread(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs)...) } -func (s *Connector) MarkMessagesFlagged(ctx context.Context, messageIDs []imap.MessageID, flagged bool) error { +func (s *Connector) MarkMessagesFlagged(ctx context.Context, _ connector.IMAPStateWrite, messageIDs []imap.MessageID, flagged bool) error { if flagged { return s.client.LabelMessages(ctx, usertypes.MapTo[imap.MessageID, string](messageIDs), proton.StarredLabel) } @@ -392,6 +401,7 @@ func (s *Connector) GetUpdates() <-chan imap.Update { func (s *Connector) Close(_ context.Context) error { // Nothing to do + s.sharedCache.Close() return nil } diff --git a/internal/services/imapservice/shared_cache.go b/internal/services/imapservice/shared_cache.go new file mode 100644 index 00000000..654af8c1 --- /dev/null +++ b/internal/services/imapservice/shared_cache.go @@ -0,0 +1,87 @@ +// Copyright (c) 2023 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package imapservice + +import ( + "context" + "errors" + "sync" + + "github.com/ProtonMail/gluon/connector" +) + +type CacheAccessor interface { + connector.IMAPState + Close() +} + +// SharedCache is meant to protect access to the database and guarantee it's always valid. There may be some corner +// cases where the Gluon connector can get closed while we are processing events in parallel. If for some reason +// Gluon closes the database, the instance is invalidated and any attempts to access this state will return +// `ErrCacheNotAvailable`. +type SharedCache struct { + cache connector.IMAPState + lock sync.RWMutex +} + +func NewSharedCached() *SharedCache { + return &SharedCache{} +} + +var ErrCacheNotAvailable = errors.New("cache no longer available") + +func (s *SharedCache) Set(cache connector.IMAPState) { + s.lock.Lock() + defer s.lock.Unlock() + + s.cache = cache +} + +func (s *SharedCache) Close() { + s.lock.Lock() + defer s.lock.Unlock() + + s.cache = nil +} + +func (s *SharedCache) Acquire() (CacheAccessor, error) { + s.lock.RLock() + + if s.cache == nil { + s.lock.RUnlock() + return nil, ErrCacheNotAvailable + } + + return &cacheAccessor{sharedCache: s}, nil +} + +type cacheAccessor struct { + sharedCache *SharedCache +} + +func (c cacheAccessor) Read(ctx context.Context, f func(context.Context, connector.IMAPStateRead) error) error { + return c.sharedCache.cache.Read(ctx, f) +} + +func (c cacheAccessor) Write(ctx context.Context, f func(context.Context, connector.IMAPStateWrite) error) error { + return c.sharedCache.cache.Write(ctx, f) +} + +func (c cacheAccessor) Close() { + c.sharedCache.lock.RUnlock() +} diff --git a/internal/services/userevents/service.go b/internal/services/userevents/service.go index bc8eaaab..206c72f2 100644 --- a/internal/services/userevents/service.go +++ b/internal/services/userevents/service.go @@ -116,6 +116,7 @@ func (s *Service) Unsubscribe(subscription EventSubscriber) { // Pause pauses the event polling. func (s *Service) Pause() { + s.log.Info("Pausing") atomic.StoreUint32(&s.paused, 1) } @@ -142,6 +143,7 @@ func (s *Service) PauseWithWaiter() *EventPollWaiter { // Resume resumes the event polling. func (s *Service) Resume() { + s.log.Info("Resuming") atomic.StoreUint32(&s.paused, 0) } diff --git a/internal/user/user.go b/internal/user/user.go index b7694e35..40b6d232 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -675,3 +675,15 @@ func (user *User) GetSMTPService() *smtp.Service { func (user *User) PublishEvent(_ context.Context, event events.Event) { user.eventCh.Enqueue(event) } + +func (user *User) PauseEventLoop() { + user.eventService.Pause() +} + +func (user *User) PauseEventLoopWithWaiter() *userevents.EventPollWaiter { + return user.eventService.PauseWithWaiter() +} + +func (user *User) ResumeEventLoop() { + user.eventService.Resume() +}