forked from Silverfish/proton-bridge
GODT-1657: Stable sync (still needs more tests)
This commit is contained in:
@ -1,7 +1,9 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
@ -13,114 +15,125 @@ import (
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/pool"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultEventPeriod = 20 * time.Second
|
||||
DefaultEventJitter = 20 * time.Second
|
||||
EventPeriod = 20 * time.Second
|
||||
EventJitter = 20 * time.Second
|
||||
)
|
||||
|
||||
type User struct {
|
||||
vault *vault.User
|
||||
client *liteapi.Client
|
||||
builder *pool.Pool[request, *imap.MessageCreated]
|
||||
attPool *pool.Pool[string, []byte]
|
||||
eventCh *queue.QueuedChannel[events.Event]
|
||||
|
||||
apiUser liteapi.User
|
||||
apiAddrs *addrList
|
||||
userKR *crypto.KeyRing
|
||||
addrKRs map[string]*crypto.KeyRing
|
||||
settings liteapi.MailSettings
|
||||
apiUser *safe.Type[liteapi.User]
|
||||
apiAddrs *safe.Slice[liteapi.Address]
|
||||
settings *safe.Type[liteapi.MailSettings]
|
||||
|
||||
updateCh map[string]*queue.QueuedChannel[imap.Update]
|
||||
syncWG wait.Group
|
||||
userKR *crypto.KeyRing
|
||||
addrKRs map[string]*crypto.KeyRing
|
||||
|
||||
updateCh map[string]*queue.QueuedChannel[imap.Update]
|
||||
syncStopCh chan struct{}
|
||||
syncWG wait.Group
|
||||
}
|
||||
|
||||
func New(
|
||||
ctx context.Context,
|
||||
encVault *vault.User,
|
||||
client *liteapi.Client,
|
||||
apiUser liteapi.User,
|
||||
apiAddrs []liteapi.Address,
|
||||
userKR *crypto.KeyRing,
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
) (*User, error) {
|
||||
func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiUser liteapi.User) (*User, error) {
|
||||
// Get the user's API addresses.
|
||||
apiAddrs, err := client.GetAddresses(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get addresses: %w", err)
|
||||
}
|
||||
|
||||
// Unlock the user's keyrings.
|
||||
userKR, addrKRs, err := liteapi.Unlock(apiUser, apiAddrs, encVault.KeyPass())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unlock user: %w", err)
|
||||
}
|
||||
|
||||
// Get the latest event ID.
|
||||
if encVault.EventID() == "" {
|
||||
eventID, err := client.GetLatestEventID(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to get latest event ID: %w", err)
|
||||
}
|
||||
|
||||
if err := encVault.SetEventID(eventID); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to set event ID: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get the user's mail settings.
|
||||
settings, err := client.GetMailSettings(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to get mail settings: %w", err)
|
||||
}
|
||||
|
||||
user := &User{
|
||||
vault: encVault,
|
||||
client: client,
|
||||
builder: newBuilder(client, runtime.NumCPU()*runtime.NumCPU(), runtime.NumCPU()*runtime.NumCPU()),
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||
// Create update channels for each of the user's addresses (if in combined mode, just the primary).
|
||||
updateCh := make(map[string]*queue.QueuedChannel[imap.Update])
|
||||
|
||||
apiUser: apiUser,
|
||||
apiAddrs: newAddrList(apiAddrs),
|
||||
for _, addr := range apiAddrs {
|
||||
updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
userKR: userKR,
|
||||
addrKRs: addrKRs,
|
||||
settings: settings,
|
||||
|
||||
updateCh: make(map[string]*queue.QueuedChannel[imap.Update]),
|
||||
}
|
||||
|
||||
// Initialize update channels for each of the user's addresses.
|
||||
for _, addrID := range user.apiAddrs.addrIDs() {
|
||||
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
// If in combined mode, we only need one update channel.
|
||||
if encVault.AddressMode() == vault.CombinedMode {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// When we receive an auth object, we update it in the store.
|
||||
user := &User{
|
||||
vault: encVault,
|
||||
client: client,
|
||||
attPool: pool.New(runtime.NumCPU(), client.GetAttachment),
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||
|
||||
apiUser: safe.NewType(apiUser),
|
||||
apiAddrs: safe.NewSlice(apiAddrs),
|
||||
settings: safe.NewType(settings),
|
||||
|
||||
userKR: userKR,
|
||||
addrKRs: addrKRs,
|
||||
|
||||
updateCh: updateCh,
|
||||
syncStopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// When we receive an auth object, we update it in the vault.
|
||||
// This will be used to authorize the user on the next run.
|
||||
client.AddAuthHandler(func(auth liteapi.Auth) {
|
||||
if err := user.vault.SetAuth(auth.UID, auth.RefreshToken); err != nil {
|
||||
logrus.WithError(err).Error("Failed to update auth")
|
||||
logrus.WithError(err).Error("Failed to update auth in vault")
|
||||
}
|
||||
})
|
||||
|
||||
// When we are deauthorized, we send a deauth event to the notify channel.
|
||||
// Bridge will catch this and log the user out.
|
||||
// When we are deauthorized, we send a deauth event to the event channel.
|
||||
// Bridge will react to this event by logging out the user.
|
||||
client.AddDeauthHandler(func() {
|
||||
user.eventCh.Enqueue(events.UserDeauth{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
})
|
||||
|
||||
// When we receive an API event, we attempt to handle it.
|
||||
// If successful, we update the event ID in the vault.
|
||||
// If we haven't synced yet, do it first.
|
||||
// If it fails, we don't start the event loop.
|
||||
// Oterwise, begin processing API events, logging any errors that occur.
|
||||
go func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
for event := range user.client.NewEventStreamer(DefaultEventPeriod, DefaultEventJitter, encVault.EventID()).Subscribe() {
|
||||
if err := user.handleAPIEvent(ctx, event); err != nil {
|
||||
logrus.WithError(err).Error("Failed to handle event")
|
||||
} else if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to update event ID")
|
||||
if status := user.vault.SyncStatus(); !status.HasMessages {
|
||||
if err := <-user.startSync(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for err := range user.streamEvents() {
|
||||
logrus.WithError(err).Error("Error while streaming events")
|
||||
}
|
||||
}()
|
||||
|
||||
return user, nil
|
||||
@ -128,30 +141,44 @@ func New(
|
||||
|
||||
// ID returns the user's ID.
|
||||
func (user *User) ID() string {
|
||||
return user.apiUser.ID
|
||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) string {
|
||||
return apiUser.ID
|
||||
})
|
||||
}
|
||||
|
||||
// Name returns the user's username.
|
||||
func (user *User) Name() string {
|
||||
return user.apiUser.Name
|
||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) string {
|
||||
return apiUser.Name
|
||||
})
|
||||
}
|
||||
|
||||
// Match matches the given query against the user's username and email addresses.
|
||||
func (user *User) Match(query string) bool {
|
||||
if query == user.apiUser.Name {
|
||||
return true
|
||||
}
|
||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) bool {
|
||||
return safe.GetSlice(user.apiAddrs, func(apiAddrs []liteapi.Address) bool {
|
||||
if query == apiUser.Name {
|
||||
return true
|
||||
}
|
||||
|
||||
if _, ok := user.apiAddrs.addrID(query); ok {
|
||||
return true
|
||||
}
|
||||
for _, addr := range apiAddrs {
|
||||
if addr.Email == query {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return false
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Emails returns all the user's email addresses.
|
||||
func (user *User) Emails() []string {
|
||||
return user.apiAddrs.emails()
|
||||
return safe.GetSlice(user.apiAddrs, func(apiAddrs []liteapi.Address) []string {
|
||||
return xslices.Map(apiAddrs, func(addr liteapi.Address) string {
|
||||
return addr.Email
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// GetAddressMode returns the user's current address mode.
|
||||
@ -167,18 +194,32 @@ func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) er
|
||||
|
||||
user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update])
|
||||
|
||||
for _, addrID := range user.apiAddrs.addrIDs() {
|
||||
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
user.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
|
||||
for _, addr := range apiAddrs {
|
||||
user.updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
if mode == vault.CombinedMode {
|
||||
break
|
||||
if mode == vault.CombinedMode {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if err := user.vault.SetAddressMode(mode); err != nil {
|
||||
return fmt.Errorf("failed to set address mode: %w", err)
|
||||
}
|
||||
|
||||
user.stopSync()
|
||||
|
||||
if err := user.vault.ClearSyncStatus(); err != nil {
|
||||
return fmt.Errorf("failed to clear sync status: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := <-user.startSync(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to sync after setting address mode")
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -209,68 +250,27 @@ func (user *User) GluonKey() []byte {
|
||||
|
||||
// BridgePass returns the user's bridge password, used for authentication over SMTP and IMAP.
|
||||
func (user *User) BridgePass() []byte {
|
||||
return user.vault.BridgePass()
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
if _, err := hex.NewEncoder(buf).Write(user.vault.BridgePass()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// UsedSpace returns the total space used by the user on the API.
|
||||
func (user *User) UsedSpace() int {
|
||||
return user.apiUser.UsedSpace
|
||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) int {
|
||||
return apiUser.UsedSpace
|
||||
})
|
||||
}
|
||||
|
||||
// MaxSpace returns the amount of space the user can use on the API.
|
||||
func (user *User) MaxSpace() int {
|
||||
return user.apiUser.MaxSpace
|
||||
}
|
||||
|
||||
// HasSync returns whether the user has finished syncing.
|
||||
func (user *User) HasSync() bool {
|
||||
return user.vault.HasSync()
|
||||
}
|
||||
|
||||
// AbortSync aborts any ongoing sync.
|
||||
// TODO: This should abort the sync rather than just waiting.
|
||||
// Should probably be done automatically when one of the user's IMAP connectors is closed.
|
||||
func (user *User) AbortSync(ctx context.Context) error {
|
||||
user.syncWG.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoSync performs a sync for the user.
|
||||
func (user *User) DoSync(ctx context.Context) <-chan error {
|
||||
errCh := queue.NewQueuedChannel[error](0, 0)
|
||||
|
||||
user.syncWG.Go(func() {
|
||||
defer errCh.Close()
|
||||
|
||||
user.eventCh.Enqueue(events.SyncStarted{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
|
||||
errCh.Enqueue(func() error {
|
||||
if err := user.syncLabels(ctx, maps.Keys(user.updateCh)...); err != nil {
|
||||
return fmt.Errorf("failed to sync labels: %w", err)
|
||||
}
|
||||
|
||||
if err := user.syncMessages(ctx); err != nil {
|
||||
return fmt.Errorf("failed to sync messages: %w", err)
|
||||
}
|
||||
|
||||
user.syncWait()
|
||||
|
||||
if err := user.vault.SetSync(true); err != nil {
|
||||
return fmt.Errorf("failed to set sync status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}())
|
||||
|
||||
user.eventCh.Enqueue(events.SyncFinished{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) int {
|
||||
return apiUser.MaxSpace
|
||||
})
|
||||
|
||||
return errCh.GetChannel()
|
||||
}
|
||||
|
||||
// GetEventCh returns a channel which notifies of events happening to the user (such as deauth, address change)
|
||||
@ -281,31 +281,35 @@ func (user *User) GetEventCh() <-chan events.Event {
|
||||
// NewIMAPConnector returns an IMAP connector for the given address.
|
||||
// If not in split mode, this function returns an error.
|
||||
func (user *User) NewIMAPConnector(addrID string) (connector.Connector, error) {
|
||||
var emails []string
|
||||
return safe.GetSliceErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (connector.Connector, error) {
|
||||
var emails []string
|
||||
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
if addrID != user.apiAddrs.primary() {
|
||||
return nil, fmt.Errorf("cannot create IMAP connector for non-primary address in combined mode")
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
if addrID != apiAddrs[0].ID {
|
||||
return nil, fmt.Errorf("cannot create IMAP connector for non-primary address in combined mode")
|
||||
}
|
||||
|
||||
emails = xslices.Map(apiAddrs, func(addr liteapi.Address) string {
|
||||
return addr.Email
|
||||
})
|
||||
|
||||
case vault.SplitMode:
|
||||
email, err := getAddrEmail(apiAddrs, addrID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
emails = []string{email}
|
||||
}
|
||||
|
||||
emails = user.apiAddrs.emails()
|
||||
|
||||
case vault.SplitMode:
|
||||
email, ok := user.apiAddrs.email(addrID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("address %s not found", addrID)
|
||||
}
|
||||
|
||||
emails = []string{email}
|
||||
}
|
||||
|
||||
return newIMAPConnector(
|
||||
user.client,
|
||||
user.updateCh[addrID].GetChannel(),
|
||||
user.vault.BridgePass(),
|
||||
emails...,
|
||||
), nil
|
||||
return newIMAPConnector(
|
||||
user.client,
|
||||
user.updateCh[addrID].GetChannel(),
|
||||
user.BridgePass(),
|
||||
emails...,
|
||||
), nil
|
||||
})
|
||||
}
|
||||
|
||||
// NewIMAPConnectors returns IMAP connectors for each of the user's addresses.
|
||||
@ -328,22 +332,7 @@ func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
|
||||
|
||||
// NewSMTPSession returns an SMTP session for the user.
|
||||
func (user *User) NewSMTPSession(email string) (smtp.Session, error) {
|
||||
addrID, ok := user.apiAddrs.addrID(email)
|
||||
if !ok {
|
||||
return nil, ErrNoSuchAddress
|
||||
}
|
||||
|
||||
return newSMTPSession(
|
||||
user.client,
|
||||
user.eventCh,
|
||||
user.apiUser.ID,
|
||||
addrID,
|
||||
user.vault.AddressMode(),
|
||||
user.apiAddrs.addrMap(),
|
||||
user.settings,
|
||||
user.userKR,
|
||||
user.addrKRs,
|
||||
), nil
|
||||
return newSMTPSession(user, email)
|
||||
}
|
||||
|
||||
// Logout logs the user out from the API.
|
||||
@ -352,12 +341,12 @@ func (user *User) Logout(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Close closes ongoing connections and cleans up resources.
|
||||
func (user *User) Close(ctx context.Context) error {
|
||||
// Wait for ongoing syncs to finish.
|
||||
user.syncWG.Wait()
|
||||
func (user *User) Close() error {
|
||||
// Cancel ongoing syncs.
|
||||
user.stopSync()
|
||||
|
||||
// Close the user's message builder.
|
||||
user.builder.Done()
|
||||
// Close the attachment pool.
|
||||
user.attPool.Done()
|
||||
|
||||
// Close the user's API client.
|
||||
user.client.Close()
|
||||
@ -372,3 +361,104 @@ func (user *User) Close(ctx context.Context) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamEvents begins streaming API events for the user.
|
||||
// When we receive an API event, we attempt to handle it.
|
||||
// If successful, we update the event ID in the vault.
|
||||
func (user *User) streamEvents() <-chan error {
|
||||
errCh := make(chan error)
|
||||
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
|
||||
for event := range user.client.NewEventStreamer(EventPeriod, EventJitter, user.vault.EventID()).Subscribe() {
|
||||
if err := user.handleAPIEvent(context.Background(), event); err != nil {
|
||||
errCh <- fmt.Errorf("failed to handle API event: %w", err)
|
||||
} else if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||
errCh <- fmt.Errorf("failed to update event ID: %w", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return errCh
|
||||
}
|
||||
|
||||
// startSync begins a startSync for the user.
|
||||
func (user *User) startSync() <-chan error {
|
||||
errCh := make(chan error)
|
||||
|
||||
user.syncWG.Go(func() {
|
||||
defer close(errCh)
|
||||
|
||||
ctx, cancel := contextWithStopCh(context.Background(), user.syncStopCh)
|
||||
defer cancel()
|
||||
|
||||
user.eventCh.Enqueue(events.SyncStarted{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
|
||||
if err := user.sync(ctx); err != nil {
|
||||
user.eventCh.Enqueue(events.SyncFailed{
|
||||
UserID: user.ID(),
|
||||
Err: err,
|
||||
})
|
||||
|
||||
errCh <- err
|
||||
} else {
|
||||
user.eventCh.Enqueue(events.SyncFinished{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return errCh
|
||||
}
|
||||
|
||||
// AbortSync aborts any ongoing sync.
|
||||
// TODO: Should probably be done automatically when one of the user's IMAP connectors is closed.
|
||||
func (user *User) stopSync() {
|
||||
select {
|
||||
case user.syncStopCh <- struct{}{}:
|
||||
user.syncWG.Wait()
|
||||
|
||||
default:
|
||||
// ...
|
||||
}
|
||||
}
|
||||
|
||||
func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) {
|
||||
for _, addr := range apiAddrs {
|
||||
if addr.Email == email {
|
||||
return addr.ID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("address %s not found", email)
|
||||
}
|
||||
|
||||
func getAddrEmail(apiAddrs []liteapi.Address, addrID string) (string, error) {
|
||||
for _, addr := range apiAddrs {
|
||||
if addr.ID == addrID {
|
||||
return addr.Email, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("address %s not found", addrID)
|
||||
}
|
||||
|
||||
// contextWithStopCh returns a new context that is cancelled when the stop channel is closed or a value is sent to it.
|
||||
func contextWithStopCh(ctx context.Context, stopCh <-chan struct{}) (context.Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-stopCh:
|
||||
cancel()
|
||||
|
||||
case <-ctx.Done():
|
||||
// ...
|
||||
}
|
||||
}()
|
||||
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user