forked from Silverfish/proton-bridge
Other: Fix user logout hangs due to sync
This commit is contained in:
@ -306,6 +306,10 @@ func TestBridge_FailLoginRecover(t *testing.T) {
|
||||
// We should now be able to log the user in.
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil)))
|
||||
|
||||
// The user should be there, now connected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
})
|
||||
}
|
||||
@ -400,6 +404,20 @@ func TestBridge_AddressMode(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_LoginLogoutRepeated(t *testing.T) {
|
||||
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
for i := 0; i < 10; i++ {
|
||||
// Log the user in.
|
||||
userID := must(bridge.LoginFull(ctx, username, password, nil, nil))
|
||||
|
||||
// Log the user out.
|
||||
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// getErr returns the error that was passed to it.
|
||||
func getErr[T any](val T, err error) error {
|
||||
return err
|
||||
|
||||
@ -58,6 +58,8 @@ func (user *User) sync(ctx context.Context) error {
|
||||
if err := user.vault.SetHasLabels(true); err != nil {
|
||||
return fmt.Errorf("failed to set has labels: %w", err)
|
||||
}
|
||||
|
||||
logrus.Info("Synced labels")
|
||||
} else {
|
||||
logrus.Info("Labels are already synced, skipping")
|
||||
}
|
||||
@ -74,6 +76,8 @@ func (user *User) sync(ctx context.Context) error {
|
||||
if err := user.vault.SetHasMessages(true); err != nil {
|
||||
return fmt.Errorf("failed to set has messages: %w", err)
|
||||
}
|
||||
|
||||
logrus.Info("Synced messages")
|
||||
} else {
|
||||
logrus.Info("Messages are already synced, skipping")
|
||||
}
|
||||
|
||||
@ -104,18 +104,22 @@ func getAddrEmail(apiAddrs []liteapi.Address, addrID string) (string, error) {
|
||||
}
|
||||
|
||||
// contextWithStopCh returns a new context that is cancelled when the stop channel is closed or a value is sent to it.
|
||||
func contextWithStopCh(ctx context.Context, stopCh <-chan struct{}) (context.Context, context.CancelFunc) {
|
||||
func contextWithStopCh(ctx context.Context, stopCh ...<-chan struct{}) (context.Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-stopCh:
|
||||
cancel()
|
||||
for _, stopCh := range stopCh {
|
||||
stopCh := stopCh
|
||||
|
||||
case <-ctx.Done():
|
||||
// ...
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
select {
|
||||
case <-stopCh:
|
||||
cancel()
|
||||
|
||||
case <-ctx.Done():
|
||||
// ...
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
@ -46,6 +46,7 @@ type User struct {
|
||||
vault *vault.User
|
||||
client *liteapi.Client
|
||||
eventCh *queue.QueuedChannel[events.Event]
|
||||
stopCh chan struct{}
|
||||
|
||||
apiUser *safe.Value[liteapi.User]
|
||||
apiAddrs *safe.Map[string, liteapi.Address]
|
||||
@ -101,6 +102,7 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
||||
vault: encVault,
|
||||
client: client,
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||
stopCh: make(chan struct{}),
|
||||
|
||||
apiUser: safe.NewValue(apiUser),
|
||||
apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr),
|
||||
@ -336,8 +338,13 @@ func (user *User) OnStatusDown() {
|
||||
}
|
||||
|
||||
// Logout logs the user out from the API.
|
||||
// If withVault is true, the user's vault is also cleared.
|
||||
func (user *User) Logout(ctx context.Context) error {
|
||||
// Cancel ongoing syncs.
|
||||
user.stopSync()
|
||||
|
||||
// Wait for ongoing syncs to stop.
|
||||
user.waitSync()
|
||||
|
||||
if err := user.client.AuthDelete(ctx); err != nil {
|
||||
return fmt.Errorf("failed to delete auth: %w", err)
|
||||
}
|
||||
@ -351,6 +358,9 @@ func (user *User) Logout(ctx context.Context) error {
|
||||
|
||||
// Close closes ongoing connections and cleans up resources.
|
||||
func (user *User) Close() error {
|
||||
// Close any ongoing operations.
|
||||
close(user.stopCh)
|
||||
|
||||
// Cancel ongoing syncs.
|
||||
user.stopSync()
|
||||
|
||||
@ -410,8 +420,11 @@ func (user *User) streamEvents(eventCh <-chan liteapi.Event) <-chan error {
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
|
||||
ctx, cancel := contextWithStopCh(context.Background(), user.stopCh)
|
||||
defer cancel()
|
||||
|
||||
for event := range eventCh {
|
||||
if err := user.handleAPIEvent(context.Background(), event); err != nil {
|
||||
if err := user.handleAPIEvent(ctx, event); err != nil {
|
||||
errCh <- fmt.Errorf("failed to handle API event: %w", err)
|
||||
} else if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||
errCh <- fmt.Errorf("failed to update event ID: %w", err)
|
||||
@ -439,7 +452,7 @@ func (user *User) startSync() <-chan error {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := contextWithStopCh(context.Background(), user.syncStopCh)
|
||||
ctx, cancel := contextWithStopCh(context.Background(), user.stopCh, user.syncStopCh)
|
||||
defer cancel()
|
||||
|
||||
user.eventCh.Enqueue(events.SyncStarted{
|
||||
|
||||
Reference in New Issue
Block a user