mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-18 16:17:03 +00:00
Other: Fix user logout hangs due to sync
This commit is contained in:
@ -306,6 +306,10 @@ func TestBridge_FailLoginRecover(t *testing.T) {
|
|||||||
// We should now be able to log the user in.
|
// We should now be able to log the user in.
|
||||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil)))
|
require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil)))
|
||||||
|
|
||||||
|
// The user should be there, now connected.
|
||||||
|
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||||
|
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -400,6 +404,20 @@ func TestBridge_AddressMode(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBridge_LoginLogoutRepeated(t *testing.T) {
|
||||||
|
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||||
|
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
// Log the user in.
|
||||||
|
userID := must(bridge.LoginFull(ctx, username, password, nil, nil))
|
||||||
|
|
||||||
|
// Log the user out.
|
||||||
|
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// getErr returns the error that was passed to it.
|
// getErr returns the error that was passed to it.
|
||||||
func getErr[T any](val T, err error) error {
|
func getErr[T any](val T, err error) error {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@ -58,6 +58,8 @@ func (user *User) sync(ctx context.Context) error {
|
|||||||
if err := user.vault.SetHasLabels(true); err != nil {
|
if err := user.vault.SetHasLabels(true); err != nil {
|
||||||
return fmt.Errorf("failed to set has labels: %w", err)
|
return fmt.Errorf("failed to set has labels: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logrus.Info("Synced labels")
|
||||||
} else {
|
} else {
|
||||||
logrus.Info("Labels are already synced, skipping")
|
logrus.Info("Labels are already synced, skipping")
|
||||||
}
|
}
|
||||||
@ -74,6 +76,8 @@ func (user *User) sync(ctx context.Context) error {
|
|||||||
if err := user.vault.SetHasMessages(true); err != nil {
|
if err := user.vault.SetHasMessages(true); err != nil {
|
||||||
return fmt.Errorf("failed to set has messages: %w", err)
|
return fmt.Errorf("failed to set has messages: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logrus.Info("Synced messages")
|
||||||
} else {
|
} else {
|
||||||
logrus.Info("Messages are already synced, skipping")
|
logrus.Info("Messages are already synced, skipping")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -104,9 +104,12 @@ func getAddrEmail(apiAddrs []liteapi.Address, addrID string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// contextWithStopCh returns a new context that is cancelled when the stop channel is closed or a value is sent to it.
|
// contextWithStopCh returns a new context that is cancelled when the stop channel is closed or a value is sent to it.
|
||||||
func contextWithStopCh(ctx context.Context, stopCh <-chan struct{}) (context.Context, context.CancelFunc) {
|
func contextWithStopCh(ctx context.Context, stopCh ...<-chan struct{}) (context.Context, context.CancelFunc) {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
for _, stopCh := range stopCh {
|
||||||
|
stopCh := stopCh
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
select {
|
select {
|
||||||
case <-stopCh:
|
case <-stopCh:
|
||||||
@ -116,6 +119,7 @@ func contextWithStopCh(ctx context.Context, stopCh <-chan struct{}) (context.Con
|
|||||||
// ...
|
// ...
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
return ctx, cancel
|
return ctx, cancel
|
||||||
}
|
}
|
||||||
|
|||||||
@ -46,6 +46,7 @@ type User struct {
|
|||||||
vault *vault.User
|
vault *vault.User
|
||||||
client *liteapi.Client
|
client *liteapi.Client
|
||||||
eventCh *queue.QueuedChannel[events.Event]
|
eventCh *queue.QueuedChannel[events.Event]
|
||||||
|
stopCh chan struct{}
|
||||||
|
|
||||||
apiUser *safe.Value[liteapi.User]
|
apiUser *safe.Value[liteapi.User]
|
||||||
apiAddrs *safe.Map[string, liteapi.Address]
|
apiAddrs *safe.Map[string, liteapi.Address]
|
||||||
@ -101,6 +102,7 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
|||||||
vault: encVault,
|
vault: encVault,
|
||||||
client: client,
|
client: client,
|
||||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
|
||||||
apiUser: safe.NewValue(apiUser),
|
apiUser: safe.NewValue(apiUser),
|
||||||
apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr),
|
apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr),
|
||||||
@ -336,8 +338,13 @@ func (user *User) OnStatusDown() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Logout logs the user out from the API.
|
// Logout logs the user out from the API.
|
||||||
// If withVault is true, the user's vault is also cleared.
|
|
||||||
func (user *User) Logout(ctx context.Context) error {
|
func (user *User) Logout(ctx context.Context) error {
|
||||||
|
// Cancel ongoing syncs.
|
||||||
|
user.stopSync()
|
||||||
|
|
||||||
|
// Wait for ongoing syncs to stop.
|
||||||
|
user.waitSync()
|
||||||
|
|
||||||
if err := user.client.AuthDelete(ctx); err != nil {
|
if err := user.client.AuthDelete(ctx); err != nil {
|
||||||
return fmt.Errorf("failed to delete auth: %w", err)
|
return fmt.Errorf("failed to delete auth: %w", err)
|
||||||
}
|
}
|
||||||
@ -351,6 +358,9 @@ func (user *User) Logout(ctx context.Context) error {
|
|||||||
|
|
||||||
// Close closes ongoing connections and cleans up resources.
|
// Close closes ongoing connections and cleans up resources.
|
||||||
func (user *User) Close() error {
|
func (user *User) Close() error {
|
||||||
|
// Close any ongoing operations.
|
||||||
|
close(user.stopCh)
|
||||||
|
|
||||||
// Cancel ongoing syncs.
|
// Cancel ongoing syncs.
|
||||||
user.stopSync()
|
user.stopSync()
|
||||||
|
|
||||||
@ -410,8 +420,11 @@ func (user *User) streamEvents(eventCh <-chan liteapi.Event) <-chan error {
|
|||||||
go func() {
|
go func() {
|
||||||
defer close(errCh)
|
defer close(errCh)
|
||||||
|
|
||||||
|
ctx, cancel := contextWithStopCh(context.Background(), user.stopCh)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
for event := range eventCh {
|
for event := range eventCh {
|
||||||
if err := user.handleAPIEvent(context.Background(), event); err != nil {
|
if err := user.handleAPIEvent(ctx, event); err != nil {
|
||||||
errCh <- fmt.Errorf("failed to handle API event: %w", err)
|
errCh <- fmt.Errorf("failed to handle API event: %w", err)
|
||||||
} else if err := user.vault.SetEventID(event.EventID); err != nil {
|
} else if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||||
errCh <- fmt.Errorf("failed to update event ID: %w", err)
|
errCh <- fmt.Errorf("failed to update event ID: %w", err)
|
||||||
@ -439,7 +452,7 @@ func (user *User) startSync() <-chan error {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := contextWithStopCh(context.Background(), user.syncStopCh)
|
ctx, cancel := contextWithStopCh(context.Background(), user.stopCh, user.syncStopCh)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
user.eventCh.Enqueue(events.SyncStarted{
|
user.eventCh.Enqueue(events.SyncStarted{
|
||||||
|
|||||||
Reference in New Issue
Block a user