diff --git a/internal/bridge/user_test.go b/internal/bridge/user_test.go index c45b0e49..e7e80670 100644 --- a/internal/bridge/user_test.go +++ b/internal/bridge/user_test.go @@ -306,6 +306,10 @@ func TestBridge_FailLoginRecover(t *testing.T) { // We should now be able to log the user in. withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil))) + + // The user should be there, now connected. + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge)) }) }) } @@ -400,6 +404,20 @@ func TestBridge_AddressMode(t *testing.T) { }) } +func TestBridge_LoginLogoutRepeated(t *testing.T) { + withEnv(t, func(ctx context.Context, s *server.Server, netCtl *liteapi.NetCtl, locator bridge.Locator, storeKey []byte) { + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + for i := 0; i < 10; i++ { + // Log the user in. + userID := must(bridge.LoginFull(ctx, username, password, nil, nil)) + + // Log the user out. + require.NoError(t, bridge.LogoutUser(ctx, userID)) + } + }) + }) +} + // getErr returns the error that was passed to it. func getErr[T any](val T, err error) error { return err diff --git a/internal/user/sync.go b/internal/user/sync.go index ae155327..1439e737 100644 --- a/internal/user/sync.go +++ b/internal/user/sync.go @@ -58,6 +58,8 @@ func (user *User) sync(ctx context.Context) error { if err := user.vault.SetHasLabels(true); err != nil { return fmt.Errorf("failed to set has labels: %w", err) } + + logrus.Info("Synced labels") } else { logrus.Info("Labels are already synced, skipping") } @@ -74,6 +76,8 @@ func (user *User) sync(ctx context.Context) error { if err := user.vault.SetHasMessages(true); err != nil { return fmt.Errorf("failed to set has messages: %w", err) } + + logrus.Info("Synced messages") } else { logrus.Info("Messages are already synced, skipping") } diff --git a/internal/user/types.go b/internal/user/types.go index f6627de9..c31ecdcb 100644 --- a/internal/user/types.go +++ b/internal/user/types.go @@ -104,18 +104,22 @@ func getAddrEmail(apiAddrs []liteapi.Address, addrID string) (string, error) { } // contextWithStopCh returns a new context that is cancelled when the stop channel is closed or a value is sent to it. -func contextWithStopCh(ctx context.Context, stopCh <-chan struct{}) (context.Context, context.CancelFunc) { +func contextWithStopCh(ctx context.Context, stopCh ...<-chan struct{}) (context.Context, context.CancelFunc) { ctx, cancel := context.WithCancel(ctx) - go func() { - select { - case <-stopCh: - cancel() + for _, stopCh := range stopCh { + stopCh := stopCh - case <-ctx.Done(): - // ... - } - }() + go func() { + select { + case <-stopCh: + cancel() + + case <-ctx.Done(): + // ... + } + }() + } return ctx, cancel } diff --git a/internal/user/user.go b/internal/user/user.go index e4522a37..6344dd1b 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -46,6 +46,7 @@ type User struct { vault *vault.User client *liteapi.Client eventCh *queue.QueuedChannel[events.Event] + stopCh chan struct{} apiUser *safe.Value[liteapi.User] apiAddrs *safe.Map[string, liteapi.Address] @@ -101,6 +102,7 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU vault: encVault, client: client, eventCh: queue.NewQueuedChannel[events.Event](0, 0), + stopCh: make(chan struct{}), apiUser: safe.NewValue(apiUser), apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr), @@ -336,8 +338,13 @@ func (user *User) OnStatusDown() { } // Logout logs the user out from the API. -// If withVault is true, the user's vault is also cleared. func (user *User) Logout(ctx context.Context) error { + // Cancel ongoing syncs. + user.stopSync() + + // Wait for ongoing syncs to stop. + user.waitSync() + if err := user.client.AuthDelete(ctx); err != nil { return fmt.Errorf("failed to delete auth: %w", err) } @@ -351,6 +358,9 @@ func (user *User) Logout(ctx context.Context) error { // Close closes ongoing connections and cleans up resources. func (user *User) Close() error { + // Close any ongoing operations. + close(user.stopCh) + // Cancel ongoing syncs. user.stopSync() @@ -410,8 +420,11 @@ func (user *User) streamEvents(eventCh <-chan liteapi.Event) <-chan error { go func() { defer close(errCh) + ctx, cancel := contextWithStopCh(context.Background(), user.stopCh) + defer cancel() + for event := range eventCh { - if err := user.handleAPIEvent(context.Background(), event); err != nil { + if err := user.handleAPIEvent(ctx, event); err != nil { errCh <- fmt.Errorf("failed to handle API event: %w", err) } else if err := user.vault.SetEventID(event.EventID); err != nil { errCh <- fmt.Errorf("failed to update event ID: %w", err) @@ -439,7 +452,7 @@ func (user *User) startSync() <-chan error { return } - ctx, cancel := contextWithStopCh(context.Background(), user.syncStopCh) + ctx, cancel := contextWithStopCh(context.Background(), user.stopCh, user.syncStopCh) defer cancel() user.eventCh.Enqueue(events.SyncStarted{