forked from Silverfish/proton-bridge
Other: Fix user sync leaks/race conditions
This fixes various race conditions and leaks related to the user's sync and API event stream. It was possible for a sync/stream to begin after a user was already closed; this change prevents that by managing the goroutines related to sync/stream within cancellable groups.
This commit is contained in:
8
go.mod
8
go.mod
@ -5,7 +5,7 @@ go 1.18
|
||||
require (
|
||||
github.com/0xAX/notificator v0.0.0-20220220101646-ee9b8921e557
|
||||
github.com/Masterminds/semver/v3 v3.1.1
|
||||
github.com/ProtonMail/gluon v0.13.1-0.20221021093632-0b277a6d0226
|
||||
github.com/ProtonMail/gluon v0.13.1-0.20221023130957-9bcdfe15b0fb
|
||||
github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a
|
||||
github.com/ProtonMail/go-rfc5322 v0.11.0
|
||||
github.com/ProtonMail/gopenpgp/v2 v2.4.10
|
||||
@ -37,14 +37,14 @@ require (
|
||||
github.com/pkg/profile v1.6.0
|
||||
github.com/sirupsen/logrus v1.9.0
|
||||
github.com/stretchr/testify v1.8.0
|
||||
github.com/urfave/cli/v2 v2.16.3
|
||||
gitlab.protontech.ch/go/liteapi v0.35.0
|
||||
github.com/urfave/cli/v2 v2.20.3
|
||||
gitlab.protontech.ch/go/liteapi v0.35.1-0.20221024102125-605f4712c351
|
||||
go.uber.org/goleak v1.2.0
|
||||
golang.org/x/exp v0.0.0-20220921164117-439092de6870
|
||||
golang.org/x/net v0.1.0
|
||||
golang.org/x/sys v0.1.0
|
||||
golang.org/x/text v0.4.0
|
||||
google.golang.org/grpc v1.49.0
|
||||
google.golang.org/grpc v1.50.1
|
||||
google.golang.org/protobuf v1.28.1
|
||||
howett.net/plist v1.0.0
|
||||
)
|
||||
|
||||
16
go.sum
16
go.sum
@ -28,8 +28,8 @@ github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs
|
||||
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo=
|
||||
github.com/ProtonMail/docker-credential-helpers v1.1.0 h1:+kvUIpwWcbtP3WFv5sSvkFn/XLzSqPOB5AAthuk9xPk=
|
||||
github.com/ProtonMail/docker-credential-helpers v1.1.0/go.mod h1:mK0aBveCxhnQ756AmaTfXMZDeULvheYVhF/MWMErN5g=
|
||||
github.com/ProtonMail/gluon v0.13.1-0.20221021093632-0b277a6d0226 h1:PWPtXMDKHGY3tG6SEeKE5B7BgfHqTDui4DREb0GoTfU=
|
||||
github.com/ProtonMail/gluon v0.13.1-0.20221021093632-0b277a6d0226/go.mod h1:XW/gcr4jErc5bX5yMqkUq3U+AucC2QZHJ5L231k3Nw4=
|
||||
github.com/ProtonMail/gluon v0.13.1-0.20221023130957-9bcdfe15b0fb h1:TGRkkuOdF3mIxbu5QMp62dJ2WvfQGgZ4MUVsNuRD7sc=
|
||||
github.com/ProtonMail/gluon v0.13.1-0.20221023130957-9bcdfe15b0fb/go.mod h1:XW/gcr4jErc5bX5yMqkUq3U+AucC2QZHJ5L231k3Nw4=
|
||||
github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a h1:D+aZah+k14Gn6kmL7eKxoo/4Dr/lK3ChBcwce2+SQP4=
|
||||
github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a/go.mod h1:oTGdE7/DlWIr23G0IKW3OXK9wZ5Hw1GGiaJFccTvZi4=
|
||||
github.com/ProtonMail/go-crypto v0.0.0-20210428141323-04723f9f07d7/go.mod h1:z4/9nQmJSSwwds7ejkxaJwO37dru3geImFUdJlaLzQo=
|
||||
@ -391,16 +391,16 @@ github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1
|
||||
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
|
||||
github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
|
||||
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
|
||||
github.com/urfave/cli/v2 v2.16.3 h1:gHoFIwpPjoyIMbJp/VFd+/vuD0dAgFK4B6DpEMFJfQk=
|
||||
github.com/urfave/cli/v2 v2.16.3/go.mod h1:1CNUng3PtjQMtRzJO4FMXBQvkGtuYRxxiR9xMa7jMwI=
|
||||
github.com/urfave/cli/v2 v2.20.3 h1:lOgGidH/N5loaigd9HjFsOIhXSTrzl7tBpHswZ428w4=
|
||||
github.com/urfave/cli/v2 v2.20.3/go.mod h1:1CNUng3PtjQMtRzJO4FMXBQvkGtuYRxxiR9xMa7jMwI=
|
||||
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
|
||||
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
|
||||
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0=
|
||||
github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA=
|
||||
gitlab.protontech.ch/go/liteapi v0.35.0 h1:26Yx8+pp14MPxV1DL8gU+cahPigrEATzm3WpoNQs7kU=
|
||||
gitlab.protontech.ch/go/liteapi v0.35.0/go.mod h1:NuLWhVn8c0bR9qUaJER7VbPFu7oEowWWcP5ANgQHwRo=
|
||||
gitlab.protontech.ch/go/liteapi v0.35.1-0.20221024102125-605f4712c351 h1:0qom9c2XYMdW1J8JAMBgflL0a8SYwloi9tbE8FV2O2A=
|
||||
gitlab.protontech.ch/go/liteapi v0.35.1-0.20221024102125-605f4712c351/go.mod h1:VMv325o/Z3h9c6JL7loDVjXF0mqs4/7JbzhGmABxDj8=
|
||||
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
|
||||
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
||||
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
||||
@ -575,8 +575,8 @@ google.golang.org/genproto v0.0.0-20220921223823-23cae91e6737/go.mod h1:2r/26NEF
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
||||
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
|
||||
google.golang.org/grpc v1.49.0 h1:WTLtQzmQori5FUH25Pq4WT22oCsv8USpQ+F6rqtsmxw=
|
||||
google.golang.org/grpc v1.49.0/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI=
|
||||
google.golang.org/grpc v1.50.1 h1:DS/BukOZWp8s6p4Dt/tOaJaTQyPyOoCcrjroHuCeLzY=
|
||||
google.golang.org/grpc v1.50.1/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
|
||||
|
||||
53
internal/async/context.go
Normal file
53
internal/async/context.go
Normal file
@ -0,0 +1,53 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package async
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Abortable collects groups of functions that can be aborted by calling Abort.
|
||||
type Abortable struct {
|
||||
abortFunc []context.CancelFunc
|
||||
abortLock sync.RWMutex
|
||||
}
|
||||
|
||||
func (a *Abortable) Do(ctx context.Context, fn func(context.Context)) {
|
||||
fn(a.newCancelCtx(ctx))
|
||||
}
|
||||
|
||||
func (a *Abortable) Abort() {
|
||||
a.abortLock.RLock()
|
||||
defer a.abortLock.RUnlock()
|
||||
|
||||
for _, fn := range a.abortFunc {
|
||||
fn()
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Abortable) newCancelCtx(ctx context.Context) context.Context {
|
||||
a.abortLock.Lock()
|
||||
defer a.abortLock.Unlock()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
a.abortFunc = append(a.abortFunc, cancel)
|
||||
|
||||
return ctx
|
||||
}
|
||||
@ -332,11 +332,9 @@ func (bridge *Bridge) Close(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Close all users.
|
||||
if err := bridge.users.IterValuesErr(func(user *user.User) error {
|
||||
return user.Close()
|
||||
}); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close users")
|
||||
}
|
||||
bridge.users.IterValues(func(user *user.User) {
|
||||
user.Close()
|
||||
})
|
||||
|
||||
// Close the focus service.
|
||||
bridge.focusService.Close()
|
||||
|
||||
@ -505,9 +505,7 @@ func (bridge *Bridge) logoutUser(ctx context.Context, userID string) error {
|
||||
logrus.WithError(err).Error("Failed to logout user")
|
||||
}
|
||||
|
||||
if err := user.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close user")
|
||||
}
|
||||
user.Close()
|
||||
|
||||
return nil
|
||||
}); !ok {
|
||||
@ -532,9 +530,7 @@ func (bridge *Bridge) deleteUser(ctx context.Context, userID string) {
|
||||
logrus.WithError(err).Error("Failed to logout user")
|
||||
}
|
||||
|
||||
if err := user.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close user")
|
||||
}
|
||||
user.Close()
|
||||
}); !ok {
|
||||
logrus.Debug("The bridge user was not connected")
|
||||
}
|
||||
|
||||
@ -136,34 +136,58 @@ func (m *Map[Key, Val]) GetDeleteErr(key Key, fn func(Val) error) (bool, error)
|
||||
return ok, err
|
||||
}
|
||||
|
||||
func (m *Map[Key, Val]) Set(key Key, val Val) {
|
||||
func (m *Map[Key, Val]) Set(key Key, val Val) bool {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
var had bool
|
||||
|
||||
if _, ok := m.data[key]; ok {
|
||||
had = true
|
||||
}
|
||||
|
||||
m.data[key] = val
|
||||
|
||||
if idx := xslices.Index(m.order, key); idx >= 0 {
|
||||
m.order[idx] = key
|
||||
} else {
|
||||
m.order = append(m.order, key)
|
||||
}
|
||||
|
||||
if m.sort != nil {
|
||||
slices.SortFunc(m.order, func(a, b Key) bool {
|
||||
return m.sort(a, b, m.data)
|
||||
})
|
||||
}
|
||||
|
||||
return had
|
||||
}
|
||||
|
||||
func (m *Map[Key, Val]) SetFrom(key Key, other Key) {
|
||||
func (m *Map[Key, Val]) SetFrom(key Key, other Key) bool {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
var had bool
|
||||
|
||||
if _, ok := m.data[key]; ok {
|
||||
had = true
|
||||
}
|
||||
|
||||
m.data[key] = m.data[other]
|
||||
|
||||
if idx := xslices.Index(m.order, key); idx >= 0 {
|
||||
m.order[idx] = key
|
||||
} else {
|
||||
m.order = append(m.order, key)
|
||||
}
|
||||
|
||||
if m.sort != nil {
|
||||
slices.SortFunc(m.order, func(a, b Key) bool {
|
||||
return m.sort(a, b, m.data)
|
||||
})
|
||||
}
|
||||
|
||||
return had
|
||||
}
|
||||
|
||||
func (m *Map[Key, Val]) Iter(fn func(key Key, val Val)) {
|
||||
|
||||
@ -27,7 +27,6 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
)
|
||||
|
||||
@ -92,7 +91,7 @@ func (user *User) handleAddressEvents(ctx context.Context, addressEvents []litea
|
||||
}
|
||||
|
||||
case liteapi.EventUpdateFlags:
|
||||
logrus.Warn("Not implemented yet.")
|
||||
user.log.Warn("Not implemented yet.")
|
||||
}
|
||||
}
|
||||
|
||||
@ -100,7 +99,9 @@ func (user *User) handleAddressEvents(ctx context.Context, addressEvents []litea
|
||||
}
|
||||
|
||||
func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
|
||||
user.apiAddrs.Set(event.Address.ID, event.Address)
|
||||
if had := user.apiAddrs.Set(event.Address.ID, event.Address); had {
|
||||
return fmt.Errorf("address %q already exists", event.Address.ID)
|
||||
}
|
||||
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
@ -132,7 +133,9 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.Ad
|
||||
}
|
||||
|
||||
func (user *User) handleUpdateAddressEvent(_ context.Context, event liteapi.AddressEvent) error { //nolint:unparam
|
||||
user.apiAddrs.Set(event.Address.ID, event.Address)
|
||||
if had := user.apiAddrs.Set(event.Address.ID, event.Address); !had {
|
||||
return fmt.Errorf("address %q does not exist", event.Address.ID)
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressUpdated{
|
||||
UserID: user.ID(),
|
||||
@ -219,9 +222,6 @@ func (user *User) handleDeleteLabelEvent(_ context.Context, event liteapi.LabelE
|
||||
|
||||
// handleMessageEvents handles the given message events.
|
||||
func (user *User) handleMessageEvents(ctx context.Context, messageEvents []liteapi.MessageEvent) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
for _, event := range messageEvents {
|
||||
switch event.Action {
|
||||
case liteapi.EventCreate:
|
||||
@ -249,7 +249,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.Me
|
||||
}
|
||||
|
||||
return user.withAddrKR(event.Message.AddressID, func(_, addrKR *crypto.KeyRing) error {
|
||||
buildRes, err := buildRFC822(ctx, full, addrKR)
|
||||
buildRes, err := buildRFC822(full, addrKR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build RFC822 message: %w", err)
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@ package user
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
@ -79,9 +80,9 @@ func (conn *imapConnector) Authorize(username string, password []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// GetMailbox returns information about the label with the given ID.
|
||||
func (conn *imapConnector) GetMailbox(ctx context.Context, labelID imap.MailboxID) (imap.Mailbox, error) {
|
||||
label, err := conn.client.GetLabel(ctx, string(labelID), liteapi.LabelTypeLabel, liteapi.LabelTypeFolder)
|
||||
// GetMailbox returns information about the mailbox with the given ID.
|
||||
func (conn *imapConnector) GetMailbox(ctx context.Context, mailboxID imap.MailboxID) (imap.Mailbox, error) {
|
||||
label, err := conn.client.GetLabel(ctx, string(mailboxID), liteapi.LabelTypeLabel, liteapi.LabelTypeFolder)
|
||||
if err != nil {
|
||||
return imap.Mailbox{}, err
|
||||
}
|
||||
@ -362,10 +363,7 @@ func (conn *imapConnector) Close(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *imapConnector) IsMailboxVisible(_ context.Context, id imap.MailboxID) bool {
|
||||
if !conn.GetShowAllMail() && id == liteapi.AllMailLabel {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
// IsMailboxVisible returns whether this mailbox should be visible over IMAP.
|
||||
func (conn *imapConnector) IsMailboxVisible(_ context.Context, mailboxID imap.MailboxID) bool {
|
||||
return atomic.LoadUint32(&conn.showAllMail) != 0 || mailboxID != liteapi.AllMailLabel
|
||||
}
|
||||
|
||||
@ -35,7 +35,6 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/message/parser"
|
||||
"github.com/bradenaw/juniper/parallel"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
@ -104,7 +103,7 @@ func (user *User) sendMail(authID string, emails []string, from string, to []str
|
||||
return fmt.Errorf("failed to send message: %w", err)
|
||||
}
|
||||
|
||||
logrus.WithField("messageID", sent.ID).Info("Message sent")
|
||||
user.log.WithField("messageID", sent.ID).Info("Message sent")
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
@ -33,7 +33,6 @@ import (
|
||||
"github.com/bradenaw/juniper/stream"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
)
|
||||
|
||||
@ -42,12 +41,40 @@ const (
|
||||
maxBatchSize = 1 << 8
|
||||
)
|
||||
|
||||
// doSync begins syncing the users data.
|
||||
// It sends a SyncStarted event and then either SyncFinished or SyncFailed
|
||||
// depending on whether the sync was successful.
|
||||
func (user *User) doSync(ctx context.Context) error {
|
||||
user.log.Debug("Beginning user sync")
|
||||
|
||||
user.eventCh.Enqueue(events.SyncStarted{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
|
||||
if err := user.sync(ctx); err != nil {
|
||||
user.log.WithError(err).Debug("Failed to sync user")
|
||||
|
||||
user.eventCh.Enqueue(events.SyncFailed{
|
||||
UserID: user.ID(),
|
||||
Err: err,
|
||||
})
|
||||
|
||||
return fmt.Errorf("failed to sync: %w", err)
|
||||
}
|
||||
|
||||
user.log.Debug("Finished user sync")
|
||||
|
||||
user.eventCh.Enqueue(events.SyncFinished{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) sync(ctx context.Context) error {
|
||||
return user.withAddrKRs(func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||
logrus.Info("Beginning sync")
|
||||
|
||||
if !user.vault.SyncStatus().HasLabels {
|
||||
logrus.Info("Syncing labels")
|
||||
user.log.Debug("Syncing labels")
|
||||
|
||||
if err := user.updateCh.ValuesErr(func(updateCh []*queue.QueuedChannel[imap.Update]) error {
|
||||
return syncLabels(ctx, user.client, xslices.Unique(updateCh)...)
|
||||
@ -59,13 +86,13 @@ func (user *User) sync(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to set has labels: %w", err)
|
||||
}
|
||||
|
||||
logrus.Info("Synced labels")
|
||||
user.log.Debug("Synced labels")
|
||||
} else {
|
||||
logrus.Info("Labels are already synced, skipping")
|
||||
user.log.Debug("Labels are already synced, skipping")
|
||||
}
|
||||
|
||||
if !user.vault.SyncStatus().HasMessages {
|
||||
logrus.Info("Syncing messages")
|
||||
user.log.Debug("Syncing messages")
|
||||
|
||||
if err := user.updateCh.MapErr(func(updateCh map[string]*queue.QueuedChannel[imap.Update]) error {
|
||||
return syncMessages(ctx, user.ID(), user.client, user.vault, addrKRs, updateCh, user.eventCh)
|
||||
@ -77,9 +104,9 @@ func (user *User) sync(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to set has messages: %w", err)
|
||||
}
|
||||
|
||||
logrus.Info("Synced messages")
|
||||
user.log.Debug("Synced messages")
|
||||
} else {
|
||||
logrus.Info("Messages are already synced, skipping")
|
||||
user.log.Debug("Messages are already synced, skipping")
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -169,11 +196,10 @@ func syncMessages( //nolint:funlen
|
||||
// Fetch and build each message.
|
||||
buildCh := stream.Map(
|
||||
client.GetFullMessages(ctx, runtime.NumCPU(), runtime.NumCPU(), messageIDs...),
|
||||
func(ctx context.Context, full liteapi.FullMessage) (*buildRes, error) {
|
||||
return buildRFC822(ctx, full, addrKRs[full.AddressID])
|
||||
func(_ context.Context, full liteapi.FullMessage) (*buildRes, error) {
|
||||
return buildRFC822(full, addrKRs[full.AddressID])
|
||||
},
|
||||
)
|
||||
defer buildCh.Close()
|
||||
|
||||
// Create the flushers, one per update channel.
|
||||
flushers := make(map[string]*flusher)
|
||||
@ -254,6 +280,8 @@ func wantLabelID(labelID string) bool {
|
||||
}
|
||||
|
||||
func forEach[T any](ctx context.Context, streamer stream.Stream[T], fn func(T) error) error {
|
||||
defer streamer.Close()
|
||||
|
||||
for {
|
||||
res, err := streamer.Next(ctx)
|
||||
if errors.Is(err, stream.End) {
|
||||
|
||||
@ -18,7 +18,6 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@ -47,7 +46,7 @@ func defaultJobOpts() message.JobOptions {
|
||||
}
|
||||
}
|
||||
|
||||
func buildRFC822(_ context.Context, full liteapi.FullMessage, addrKR *crypto.KeyRing) (*buildRes, error) {
|
||||
func buildRFC822(full liteapi.FullMessage, addrKR *crypto.KeyRing) (*buildRes, error) {
|
||||
literal, err := message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build message %s: %w", full.ID, err)
|
||||
|
||||
@ -18,7 +18,6 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"reflect"
|
||||
@ -92,22 +91,3 @@ func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) {
|
||||
|
||||
return "", fmt.Errorf("address %s not found", email)
|
||||
}
|
||||
|
||||
// contextWithStopCh returns a new context that is cancelled when the stop channel is closed or a value is sent to it.
|
||||
func contextWithStopCh(ctx context.Context, channels ...<-chan struct{}) (context.Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
for _, stopCh := range channels {
|
||||
go func(ch <-chan struct{}) {
|
||||
select {
|
||||
case <-ch:
|
||||
cancel()
|
||||
|
||||
case <-ctx.Done():
|
||||
// ...
|
||||
}
|
||||
}(stopCh)
|
||||
}
|
||||
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
@ -30,11 +30,12 @@ import (
|
||||
"github.com/ProtonMail/gluon/connector"
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/async"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/try"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/bradenaw/juniper/xsync"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
)
|
||||
@ -45,23 +46,35 @@ var (
|
||||
)
|
||||
|
||||
type User struct {
|
||||
log *logrus.Entry
|
||||
|
||||
vault *vault.User
|
||||
client *liteapi.Client
|
||||
eventCh *queue.QueuedChannel[events.Event]
|
||||
stopCh chan struct{}
|
||||
|
||||
apiUser *safe.Value[liteapi.User]
|
||||
apiAddrs *safe.Map[string, liteapi.Address]
|
||||
updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]]
|
||||
|
||||
syncStopCh chan struct{}
|
||||
syncLock try.Group
|
||||
syncWG sync.WaitGroup
|
||||
tasks *xsync.Group
|
||||
abortable async.Abortable
|
||||
goSync func()
|
||||
|
||||
showAllMail int32
|
||||
showAllMail uint32
|
||||
}
|
||||
|
||||
func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiUser liteapi.User, showAllMail bool) (*User, error) { //nolint:funlen
|
||||
// New returns a new user.
|
||||
//
|
||||
// nolint:funlen
|
||||
func New(
|
||||
ctx context.Context,
|
||||
encVault *vault.User,
|
||||
client *liteapi.Client,
|
||||
apiUser liteapi.User,
|
||||
showAllMail bool,
|
||||
) (*User, error) { //nolint:funlen
|
||||
logrus.WithField("userID", apiUser.ID).Debug("Creating new user")
|
||||
|
||||
// Get the user's API addresses.
|
||||
apiAddrs, err := client.GetAddresses(ctx)
|
||||
if err != nil {
|
||||
@ -104,25 +117,26 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
||||
}
|
||||
|
||||
user := &User{
|
||||
log: logrus.WithField("userID", apiUser.ID),
|
||||
|
||||
vault: encVault,
|
||||
client: client,
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||
stopCh: make(chan struct{}),
|
||||
|
||||
apiUser: safe.NewValue(apiUser),
|
||||
apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr),
|
||||
updateCh: safe.NewMapFrom(updateCh, nil),
|
||||
|
||||
syncStopCh: make(chan struct{}),
|
||||
}
|
||||
tasks: xsync.NewGroup(context.Background()),
|
||||
|
||||
user.SetShowAllMail(showAllMail)
|
||||
showAllMail: b32(showAllMail),
|
||||
}
|
||||
|
||||
// When we receive an auth object, we update it in the vault.
|
||||
// This will be used to authorize the user on the next run.
|
||||
user.client.AddAuthHandler(func(auth liteapi.Auth) {
|
||||
if err := user.vault.SetAuth(auth.UID, auth.RefreshToken); err != nil {
|
||||
logrus.WithError(err).Error("Failed to update auth in vault")
|
||||
user.log.WithError(err).Error("Failed to update auth in vault")
|
||||
}
|
||||
})
|
||||
|
||||
@ -134,24 +148,38 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
||||
})
|
||||
})
|
||||
|
||||
// GODT-1946 - Don't start the event loop until the initial sync has finished.
|
||||
eventCh := user.client.NewEventStream(EventPeriod, EventJitter, user.vault.EventID())
|
||||
// Stream events from the API, logging any errors that occur.
|
||||
// When we receive an API event, we attempt to handle it.
|
||||
// If successful, we update the event ID in the vault.
|
||||
goStream := user.tasks.Trigger(func(ctx context.Context) {
|
||||
for event := range user.client.NewEventStream(ctx, EventPeriod, EventJitter, user.vault.EventID()) {
|
||||
if err := user.handleAPIEvent(ctx, event); err != nil {
|
||||
user.log.WithError(err).Error("Failed to handle API event")
|
||||
} else if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||
user.log.WithError(err).Error("Failed to update event ID in vault")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
user.syncWG.Add(1)
|
||||
// If we haven't synced yet, do it first.
|
||||
// If it fails, we don't start the event loop.
|
||||
// Otherwise, begin processing API events, logging any errors that occur.
|
||||
go func() {
|
||||
defer user.syncWG.Done()
|
||||
// We only ever want to start one event streamer.
|
||||
var once sync.Once
|
||||
|
||||
if err := <-user.startSync(); err != nil {
|
||||
// When triggered, attempt to sync the user.
|
||||
// If successful, we start the event streamer if we haven't already.
|
||||
user.goSync = user.tasks.Trigger(func(ctx context.Context) {
|
||||
user.abortable.Do(ctx, func(ctx context.Context) {
|
||||
if !user.vault.SyncStatus().IsComplete() {
|
||||
if err := user.doSync(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for err := range user.streamEvents(eventCh) {
|
||||
logrus.WithError(err).Error("Error while streaming events")
|
||||
}
|
||||
}()
|
||||
|
||||
once.Do(goStream)
|
||||
})
|
||||
})
|
||||
|
||||
// Trigger an initial sync (if necessary) and start the event stream.
|
||||
user.goSync()
|
||||
|
||||
return user, nil
|
||||
}
|
||||
@ -199,9 +227,8 @@ func (user *User) GetAddressMode() vault.AddressMode {
|
||||
|
||||
// SetAddressMode sets the user's address mode.
|
||||
func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error {
|
||||
user.stopSync()
|
||||
user.lockSync()
|
||||
defer user.unlockSync()
|
||||
user.abortable.Abort()
|
||||
defer user.goSync()
|
||||
|
||||
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
|
||||
for _, updateCh := range xslices.Unique(updateCh) {
|
||||
@ -235,12 +262,6 @@ func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) er
|
||||
return fmt.Errorf("failed to clear sync status: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := <-user.startSync(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to sync after setting address mode")
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -364,26 +385,17 @@ func (user *User) CheckAuth(email string, password []byte) (string, error) {
|
||||
|
||||
// OnStatusUp is called when the connection goes up.
|
||||
func (user *User) OnStatusUp() {
|
||||
go func() {
|
||||
logrus.Info("Connection up, checking if sync is needed")
|
||||
|
||||
if err := <-user.startSync(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to sync on status up")
|
||||
}
|
||||
}()
|
||||
user.goSync()
|
||||
}
|
||||
|
||||
// OnStatusDown is called when the connection goes down.
|
||||
func (user *User) OnStatusDown() {
|
||||
logrus.Info("Connection down, aborting any ongoing syncs")
|
||||
|
||||
user.stopSync()
|
||||
user.abortable.Abort()
|
||||
}
|
||||
|
||||
// Logout logs the user out from the API.
|
||||
func (user *User) Logout(ctx context.Context) error {
|
||||
// Cancel ongoing syncs.
|
||||
user.stopSync()
|
||||
user.tasks.Wait()
|
||||
|
||||
if err := user.client.AuthDelete(ctx); err != nil {
|
||||
return fmt.Errorf("failed to delete auth: %w", err)
|
||||
@ -397,14 +409,9 @@ func (user *User) Logout(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Close closes ongoing connections and cleans up resources.
|
||||
func (user *User) Close() error {
|
||||
defer user.syncWG.Wait()
|
||||
|
||||
// Close any ongoing operations.
|
||||
close(user.stopCh)
|
||||
|
||||
// Cancel ongoing syncs.
|
||||
user.stopSync()
|
||||
func (user *User) Close() {
|
||||
// Stop any ongoing background tasks.
|
||||
user.tasks.Wait()
|
||||
|
||||
// Close the user's API client.
|
||||
user.client.Close()
|
||||
@ -421,113 +428,20 @@ func (user *User) Close() error {
|
||||
|
||||
// Close the user's vault.
|
||||
if err := user.vault.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close vault")
|
||||
user.log.WithError(err).Error("Failed to close vault")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetShowAllMail sets whether to show the All Mail mailbox.
|
||||
func (user *User) SetShowAllMail(show bool) {
|
||||
var value int32
|
||||
|
||||
if show {
|
||||
value = 1
|
||||
} else {
|
||||
value = 0
|
||||
}
|
||||
|
||||
atomic.StoreInt32(&user.showAllMail, value)
|
||||
atomic.StoreUint32(&user.showAllMail, b32(show))
|
||||
}
|
||||
|
||||
func (user *User) GetShowAllMail() bool {
|
||||
return atomic.LoadInt32(&user.showAllMail) == 1
|
||||
}
|
||||
|
||||
// streamEvents begins streaming API events for the user.
|
||||
// When we receive an API event, we attempt to handle it.
|
||||
// If successful, we update the event ID in the vault.
|
||||
func (user *User) streamEvents(eventCh <-chan liteapi.Event) <-chan error {
|
||||
errCh := make(chan error)
|
||||
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
|
||||
ctx, cancel := contextWithStopCh(context.Background(), user.stopCh)
|
||||
defer cancel()
|
||||
|
||||
for event := range eventCh {
|
||||
if err := user.handleAPIEvent(ctx, event); err != nil {
|
||||
errCh <- fmt.Errorf("failed to handle API event: %w", err)
|
||||
} else if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||
errCh <- fmt.Errorf("failed to update event ID: %w", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return errCh
|
||||
}
|
||||
|
||||
// startSync begins a startSync for the user.
|
||||
func (user *User) startSync() <-chan error {
|
||||
errCh := make(chan error)
|
||||
|
||||
user.syncLock.GoTry(func(ok bool) {
|
||||
defer close(errCh)
|
||||
|
||||
if user.vault.SyncStatus().IsComplete() {
|
||||
logrus.Debug("Already synced, skipping")
|
||||
return
|
||||
// b32 returns a uint32 0 or 1 representing b.
|
||||
func b32(b bool) uint32 {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
|
||||
if !ok {
|
||||
logrus.Debug("Sync already in progress, skipping")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := contextWithStopCh(context.Background(), user.stopCh, user.syncStopCh)
|
||||
defer cancel()
|
||||
|
||||
user.eventCh.Enqueue(events.SyncStarted{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
|
||||
if err := user.sync(ctx); err != nil {
|
||||
user.eventCh.Enqueue(events.SyncFailed{
|
||||
UserID: user.ID(),
|
||||
Err: err,
|
||||
})
|
||||
|
||||
errCh <- err
|
||||
} else {
|
||||
user.eventCh.Enqueue(events.SyncFinished{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return errCh
|
||||
}
|
||||
|
||||
// AbortSync aborts any ongoing sync.
|
||||
// GODT-1947: Should probably be done automatically when one of the user's IMAP connectors is closed.
|
||||
func (user *User) stopSync() {
|
||||
defer user.syncLock.Wait()
|
||||
|
||||
select {
|
||||
case user.syncStopCh <- struct{}{}:
|
||||
logrus.Debug("Sent sync abort signal")
|
||||
|
||||
default:
|
||||
logrus.Debug("No sync to abort")
|
||||
}
|
||||
}
|
||||
|
||||
// lockSync prevents a new sync from starting.
|
||||
func (user *User) lockSync() {
|
||||
user.syncLock.Lock()
|
||||
}
|
||||
|
||||
// unlockSync allows a new sync to start.
|
||||
func (user *User) unlockSync() {
|
||||
user.syncLock.Unlock()
|
||||
return 0
|
||||
}
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/connector"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
@ -30,6 +31,7 @@ import (
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
"gitlab.protontech.ch/go/liteapi/server/backend"
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -39,7 +41,11 @@ func init() {
|
||||
certs.GenerateCert = tests.FastGenerateCert
|
||||
}
|
||||
|
||||
func TestUser_Data(t *testing.T) {
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m, goleak.IgnoreCurrent())
|
||||
}
|
||||
|
||||
func TestUser_Info(t *testing.T) {
|
||||
withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) {
|
||||
withAccount(t, s, "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(userID string, addrIDs []string) {
|
||||
withUser(t, ctx, s, m, "username", "password", func(user *User) {
|
||||
@ -66,6 +72,9 @@ func TestUser_Sync(t *testing.T) {
|
||||
withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) {
|
||||
withAccount(t, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) {
|
||||
withUser(t, ctx, s, m, "username", "password", func(user *User) {
|
||||
// Process the IMAP updates as if we were gluon.
|
||||
handleUpdates(t, user)
|
||||
|
||||
// User starts a sync at startup.
|
||||
require.IsType(t, events.SyncStarted{}, <-user.GetEventCh())
|
||||
|
||||
@ -79,6 +88,36 @@ func TestUser_Sync(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_AddressMode(t *testing.T) {
|
||||
withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) {
|
||||
withAccount(t, s, "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(userID string, addrIDs []string) {
|
||||
withUser(t, ctx, s, m, "username", "password", func(user *User) {
|
||||
// Process the IMAP updates as if we were gluon.
|
||||
handleUpdates(t, user)
|
||||
|
||||
// User finishes syncing at startup.
|
||||
require.IsType(t, events.SyncStarted{}, <-user.GetEventCh())
|
||||
require.IsType(t, events.SyncProgress{}, <-user.GetEventCh())
|
||||
require.IsType(t, events.SyncFinished{}, <-user.GetEventCh())
|
||||
|
||||
// By default, user should be in combined mode.
|
||||
require.Equal(t, vault.CombinedMode, user.GetAddressMode())
|
||||
|
||||
// User should be able to switch to split mode.
|
||||
require.NoError(t, user.SetAddressMode(ctx, vault.SplitMode))
|
||||
|
||||
// Process the IMAP updates as if we were gluon.
|
||||
handleUpdates(t, user)
|
||||
|
||||
// User finishes syncing after switching to split mode.
|
||||
require.IsType(t, events.SyncStarted{}, <-user.GetEventCh())
|
||||
require.IsType(t, events.SyncProgress{}, <-user.GetEventCh())
|
||||
require.IsType(t, events.SyncFinished{}, <-user.GetEventCh())
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_Deauth(t *testing.T) {
|
||||
withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) {
|
||||
withAccount(t, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) {
|
||||
@ -126,7 +165,6 @@ func withAccount(tb testing.TB, s *server.Server, username, password string, ema
|
||||
func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *liteapi.Manager, username, password string, fn func(*User)) { //nolint:unparam,revive
|
||||
client, apiAuth, err := m.NewClientWithLogin(ctx, username, []byte(password))
|
||||
require.NoError(tb, err)
|
||||
defer client.Close()
|
||||
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
require.NoError(tb, err)
|
||||
@ -146,18 +184,20 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *liteapi.M
|
||||
|
||||
user, err := New(ctx, vaultUser, client, apiUser, true)
|
||||
require.NoError(tb, err)
|
||||
defer func() { require.NoError(tb, user.Close()) }()
|
||||
|
||||
imapConn, err := user.NewIMAPConnectors()
|
||||
require.NoError(tb, err)
|
||||
|
||||
go func() {
|
||||
for _, imapConn := range imapConn {
|
||||
for update := range imapConn.GetUpdates() {
|
||||
update.Done()
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer user.Close()
|
||||
|
||||
fn(user)
|
||||
}
|
||||
|
||||
func handleUpdates(t *testing.T, user *User) {
|
||||
imapConn, err := user.NewIMAPConnectors()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, imapConn := range imapConn {
|
||||
go func(imapConn connector.Connector) {
|
||||
for update := range imapConn.GetUpdates() {
|
||||
update.Done()
|
||||
}
|
||||
}(imapConn)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user