diff --git a/go.mod b/go.mod index dc027458..26c6b0be 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ require ( github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.0 github.com/urfave/cli/v2 v2.16.3 - gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455 + gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012223142-6daad3a5912f golang.org/x/exp v0.0.0-20220921164117-439092de6870 golang.org/x/net v0.1.0 golang.org/x/sys v0.1.0 diff --git a/go.sum b/go.sum index 2ba9856e..0862a588 100644 --- a/go.sum +++ b/go.sum @@ -399,8 +399,8 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0= github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= -gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455 h1:TWNT/rPSUGjYsNTwWx5Fd029LipSv+h1XuBwFSd5cAo= -gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455/go.mod h1:NfsxXn1T81sz0gHnxuAfyCI4Agzm5UWVRyEtdQSch/4= +gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012223142-6daad3a5912f h1:5gPPdQS+dm0A2GAE0IaGLZwgTKn2Q2dCQeMxgJUD+Nk= +gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012223142-6daad3a5912f/go.mod h1:NfsxXn1T81sz0gHnxuAfyCI4Agzm5UWVRyEtdQSch/4= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= diff --git a/internal/bridge/bug_report.go b/internal/bridge/bug_report.go index 55180ad4..50aeb304 100644 --- a/internal/bridge/bug_report.go +++ b/internal/bridge/bug_report.go @@ -27,6 +27,7 @@ import ( "sort" "github.com/ProtonMail/proton-bridge/v2/internal/logging" + "github.com/ProtonMail/proton-bridge/v2/internal/vault" "gitlab.protontech.ch/go/liteapi" ) @@ -41,12 +42,11 @@ func (bridge *Bridge) ReportBug(ctx context.Context, osType, osVersion, descript if info, err := bridge.QueryUserInfo(username); err == nil { account = info.Username } else if userIDs := bridge.GetUserIDs(); len(userIDs) > 0 { - user, err := bridge.vault.GetUser(userIDs[0]) - if err != nil { + if err := bridge.vault.GetUser(userIDs[0], func(user *vault.User) { + account = user.Username() + }); err != nil { return err } - - account = user.Username() } var atts []liteapi.ReportBugAttachment diff --git a/internal/bridge/user.go b/internal/bridge/user.go index d42525b3..86d6374d 100644 --- a/internal/bridge/user.go +++ b/internal/bridge/user.go @@ -13,7 +13,6 @@ import ( "github.com/go-resty/resty/v2" "github.com/sirupsen/logrus" "gitlab.protontech.ch/go/liteapi" - "golang.org/x/exp/slices" ) type UserInfo struct { @@ -49,18 +48,21 @@ func (bridge *Bridge) GetUserIDs() []string { // GetUserInfo returns info about the given user. func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) { - vaultUser, err := bridge.vault.GetUser(userID) - if err != nil { - return UserInfo{}, err - } - if info, ok := safe.MapGetRet(bridge.users, userID, func(user *user.User) UserInfo { return getConnUserInfo(user) }); ok { return info, nil } - return getUserInfo(vaultUser.UserID(), vaultUser.Username(), vaultUser.AddressMode()), nil + var info UserInfo + + if err := bridge.vault.GetUser(userID, func(user *vault.User) { + info = getUserInfo(user.UserID(), user.Username(), user.AddressMode()) + }); err != nil { + return UserInfo{}, fmt.Errorf("failed to get user info: %w", err) + } + + return info, nil } // QueryUserInfo queries the user info by username or address. @@ -108,10 +110,6 @@ func (bridge *Bridge) LoginUser( func() error { return client.AuthDelete(ctx) }, - func() error { - bridge.deleteUser(ctx, auth.UserID) - return nil - }, ) if err != nil { return "", fmt.Errorf("failed to login user: %w", err) @@ -126,6 +124,7 @@ func (bridge *Bridge) LoginUser( // LoginUser authorizes a new bridge user with the given username and password. // If necessary, a TOTP and mailbox password are requested via the callbacks. +// This is equivalent to doing LoginAuth and LoginUser separately. func (bridge *Bridge) LoginFull( ctx context.Context, username string, @@ -252,10 +251,12 @@ func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, aut func (bridge *Bridge) loadLoop() { for { bridge.loadWG.GoTry(func(ok bool) { - if ok { - if err := bridge.loadUsers(); err != nil { - logrus.WithError(err).Error("Failed to load users") - } + if !ok { + return + } + + if err := bridge.loadUsers(); err != nil { + logrus.WithError(err).Error("Failed to load users") } }) @@ -269,6 +270,7 @@ func (bridge *Bridge) loadLoop() { } } +// loadUsers tries to load each user in the vault that isn't already loaded. func (bridge *Bridge) loadUsers() error { if err := bridge.vault.ForUser(func(user *vault.User) error { if bridge.users.Has(user.UserID()) { @@ -317,7 +319,7 @@ func (bridge *Bridge) loadUser(user *vault.User) error { return fmt.Errorf("failed to create API client: %w", err) } - if err := try.Catch( + return try.Catch( func() error { apiUser, err := client.GetUser(ctx) if err != nil { @@ -329,14 +331,7 @@ func (bridge *Bridge) loadUser(user *vault.User) error { func() error { return client.AuthDelete(ctx) }, - func() error { - return bridge.logoutUser(ctx, user.UserID()) - }, - ); err != nil { - return fmt.Errorf("failed to load user: %w", err) - } - - return nil + ) } // addUser adds a new user with an already salted mailbox password. @@ -347,24 +342,46 @@ func (bridge *Bridge) addUser( authUID, authRef string, saltedKeyPass []byte, ) error { - var user *user.User - - if slices.Contains(bridge.vault.GetUserIDs(), apiUser.ID) { - existingUser, err := bridge.addExistingUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass) - if err != nil { - return fmt.Errorf("failed to add existing user: %w", err) - } - - user = existingUser - } else { - newUser, err := bridge.addNewUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass) - if err != nil { - return fmt.Errorf("failed to add new user: %w", err) - } - - user = newUser + vaultUser, isNew, err := bridge.newVaultUser(client, apiUser, authUID, authRef, saltedKeyPass) + if err != nil { + return fmt.Errorf("failed to add vault user: %w", err) } + if err := bridge.addUserWithVault(ctx, client, apiUser, vaultUser); err != nil { + if err := vaultUser.Clear(); err != nil { + logrus.WithError(err).Error("Failed to clear vault user") + } + + if err := vaultUser.Close(); err != nil { + logrus.WithError(err).Error("Failed to close vault user") + } + + if isNew { + if err := bridge.vault.DeleteUser(apiUser.ID); err != nil { + logrus.WithError(err).Error("Failed to delete vault user") + } + } + + return fmt.Errorf("failed to add user with vault: %w", err) + } + + return nil +} + +// addUserWithVault adds a new user to bridge with the given vault. +func (bridge *Bridge) addUserWithVault( + ctx context.Context, + client *liteapi.Client, + apiUser liteapi.User, + vaultUser *vault.User, +) error { + user, err := user.New(ctx, vaultUser, client, apiUser) + if err != nil { + return fmt.Errorf("failed to create user: %w", err) + } + + bridge.users.Set(apiUser.ID, user) + // Connect the user's address(es) to gluon. if err := bridge.addIMAPUser(ctx, user); err != nil { return fmt.Errorf("failed to add IMAP user: %w", err) @@ -403,56 +420,37 @@ func (bridge *Bridge) addUser( return nil } -func (bridge *Bridge) addNewUser( - ctx context.Context, +// newVaultUser creates a new vault user from the given auth information. +// If one already exists in the vault, its data will be updated. +func (bridge *Bridge) newVaultUser( client *liteapi.Client, apiUser liteapi.User, authUID, authRef string, saltedKeyPass []byte, -) (*user.User, error) { - vaultUser, err := bridge.vault.AddUser(apiUser.ID, apiUser.Name, authUID, authRef, saltedKeyPass) +) (*vault.User, bool, error) { + if !bridge.vault.HasUser(apiUser.ID) { + user, err := bridge.vault.AddUser(apiUser.ID, apiUser.Name, authUID, authRef, saltedKeyPass) + if err != nil { + return nil, false, fmt.Errorf("failed to add user to vault: %w", err) + } + + return user, true, nil + } + + user, err := bridge.vault.NewUser(apiUser.ID) if err != nil { - return nil, err + return nil, false, err } - user, err := user.New(ctx, vaultUser, client, apiUser) - if err != nil { - return nil, err + if err := user.SetAuth(authUID, authRef); err != nil { + return nil, false, err } - bridge.users.Set(apiUser.ID, user) - - return user, nil -} - -func (bridge *Bridge) addExistingUser( - ctx context.Context, - client *liteapi.Client, - apiUser liteapi.User, - authUID, authRef string, - saltedKeyPass []byte, -) (*user.User, error) { - vaultUser, err := bridge.vault.GetUser(apiUser.ID) - if err != nil { - return nil, err + if err := user.SetKeyPass(saltedKeyPass); err != nil { + return nil, false, err } - if err := vaultUser.SetAuth(authUID, authRef); err != nil { - return nil, err - } - - if err := vaultUser.SetKeyPass(saltedKeyPass); err != nil { - return nil, err - } - - user, err := user.New(ctx, vaultUser, client, apiUser) - if err != nil { - return nil, err - } - - bridge.users.Set(apiUser.ID, user) - - return user, nil + return user, false, nil } // addIMAPUser connects the given user to gluon. diff --git a/internal/bridge/user_test.go b/internal/bridge/user_test.go index 46e8709c..3b0bbb8e 100644 --- a/internal/bridge/user_test.go +++ b/internal/bridge/user_test.go @@ -262,21 +262,33 @@ func TestBridge_FailLoginRecover(t *testing.T) { read += uint64(len(b)) }) + var userID string + // Log the user in and record how much data was read. withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { - userID := must(bridge.LoginFull(ctx, username, password, nil, nil)) + userID = must(bridge.LoginFull(ctx, username, password, nil, nil)) require.NoError(t, bridge.LogoutUser(ctx, userID)) }) - // Simulate a partial read. - netCtl.SetReadLimit(read / 2) - - // We should fail to log the user in because we can't fully read its data. + // Now simulate failing to login. withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + // Simulate a partial read. + netCtl.SetReadLimit(3 * read / 4) + + // We should fail to log the user in because we can't fully read its data. require.Error(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil))) - // There should be no users. - require.Empty(t, bridge.GetUserIDs()) + // The user should still be there (but disconnected). + require.Equal(t, []string{userID}, bridge.GetUserIDs()) + require.Empty(t, getConnectedUserIDs(t, bridge)) + }) + + // Simulate the network recovering. + netCtl.SetReadLimit(0) + + // We should now be able to log the user in. + withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) { + require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil))) }) }) } diff --git a/internal/focus/service.go b/internal/focus/service.go index c61e25b5..6afa5364 100644 --- a/internal/focus/service.go +++ b/internal/focus/service.go @@ -8,6 +8,7 @@ import ( "github.com/Masterminds/semver/v3" "github.com/ProtonMail/proton-bridge/v2/internal/focus/proto" + "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/emptypb" ) @@ -22,34 +23,31 @@ var Port = 1042 // nolint:gochecknoglobals type Service struct { proto.UnimplementedFocusServer - server *grpc.Server - listener net.Listener - raiseCh chan struct{} - version *semver.Version + server *grpc.Server + raiseCh chan struct{} + version *semver.Version } // NewService creates a new focus service. // It listens on the local host and port 1042 (by default). func NewService(version *semver.Version) (*Service, error) { - listener, err := net.Listen("tcp", net.JoinHostPort(Host, fmt.Sprint(Port))) - if err != nil { - return nil, fmt.Errorf("failed to listen: %w", err) - } - service := &Service{ - server: grpc.NewServer(), - listener: listener, - raiseCh: make(chan struct{}, 1), - version: version, + server: grpc.NewServer(), + raiseCh: make(chan struct{}, 1), + version: version, } proto.RegisterFocusServer(service.server, service) - go func() { - if err := service.server.Serve(listener); err != nil { - fmt.Printf("failed to serve: %v", err) - } - }() + if listener, err := net.Listen("tcp", net.JoinHostPort(Host, fmt.Sprint(Port))); err != nil { + logrus.WithError(err).Warn("Failed to start focus service") + } else { + go func() { + if err := service.server.Serve(listener); err != nil { + fmt.Printf("failed to serve: %v", err) + } + }() + } return service, nil } diff --git a/internal/try/try.go b/internal/try/try.go index 4d97d081..490cc5ee 100644 --- a/internal/try/try.go +++ b/internal/try/try.go @@ -4,6 +4,7 @@ import ( "fmt" "sync" + "github.com/bradenaw/juniper/xerrors" "github.com/sirupsen/logrus" ) @@ -23,7 +24,7 @@ func CatchVal[T any](try func() (T, error), handlers ...func() error) (res T, er defer func() { if r := recover(); r != nil { catch(handlers...) - err = fmt.Errorf("panic: %v", r) + err = xerrors.WithStack(fmt.Errorf("panic: %v", r)) } }() @@ -38,13 +39,13 @@ func CatchVal[T any](try func() (T, error), handlers ...func() error) (res T, er func catch(handlers ...func() error) { defer func() { if r := recover(); r != nil { - logrus.WithField("panic", r).Error("Panic in catch") + logrus.WithError(xerrors.WithStack(fmt.Errorf("panic: %v", r))).Error("Catch handler panicked") } }() for _, handler := range handlers { if err := handler(); err != nil { - logrus.WithError(err).Error("Failed to handle error") + logrus.WithError(xerrors.WithStack(err)).Error("Catch handler failed") } } } diff --git a/internal/user/user.go b/internal/user/user.go index 29de575a..83dc1ecc 100644 --- a/internal/user/user.go +++ b/internal/user/user.go @@ -341,7 +341,9 @@ func (user *User) Close() error { user.waitSync() // Close the user's API client. - user.client.Close() + if err := user.client.Close(); err != nil { + logrus.WithError(err).Error("Failed to close API client") + } // Close the user's update channels. user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) { @@ -353,6 +355,11 @@ func (user *User) Close() error { // Close the user's notify channel. user.eventCh.Close() + // Close the user's vault. + if err := user.vault.Close(); err != nil { + logrus.WithError(err).Error("Failed to close vault") + } + return nil } diff --git a/internal/vault/user.go b/internal/vault/user.go index 4727f599..def8271b 100644 --- a/internal/vault/user.go +++ b/internal/vault/user.go @@ -147,3 +147,8 @@ func (user *User) Clear() error { data.KeyPass = nil }) } + +// Close closes the user. This allows it to be removed from the vault. +func (user *User) Close() error { + return user.vault.detachUser(user.userID) +} diff --git a/internal/vault/user_test.go b/internal/vault/user_test.go index 4e4c322e..9a1f77f5 100644 --- a/internal/vault/user_test.go +++ b/internal/vault/user_test.go @@ -80,7 +80,11 @@ func TestUser_Delete(t *testing.T) { // The user should be listed in the store. require.ElementsMatch(t, []string{"userID"}, s.GetUserIDs()) - // Clear the user's auth information. + // Try to delete the user; it should fail because it is still in use. + require.Error(t, s.DeleteUser("userID")) + + // Close the user; it should now be deletable. + require.NoError(t, user.Close()) require.NoError(t, s.DeleteUser("userID")) // The store should have no users again. @@ -122,6 +126,56 @@ func TestUser_SyncStatus(t *testing.T) { require.Empty(t, user.SyncStatus().LastMessageID) } +func TestUser_ForEach(t *testing.T) { + // Create a new test vault. + s := newVault(t) + + // Create some new users. + user1, err := s.AddUser("userID1", "username1", "authUID1", "authRef1", []byte("keyPass1")) + require.NoError(t, err) + user2, err := s.AddUser("userID2", "username2", "authUID2", "authRef2", []byte("keyPass2")) + require.NoError(t, err) + + // Iterate through the users. + s.ForUser(func(user *vault.User) error { + switch user.UserID() { + case "userID1": + require.Equal(t, "username1", user.Username()) + require.Equal(t, "authUID1", user.AuthUID()) + require.Equal(t, "authRef1", user.AuthRef()) + require.Equal(t, "keyPass1", string(user.KeyPass())) + + case "userID2": + require.Equal(t, "username2", user.Username()) + require.Equal(t, "authUID2", user.AuthUID()) + require.Equal(t, "authRef2", user.AuthRef()) + require.Equal(t, "keyPass2", string(user.KeyPass())) + + default: + t.Fatalf("unexpected user %q", user.UserID()) + } + + return nil + }) + + // Try to delete the first user; it should fail because it is still in use. + require.Error(t, s.DeleteUser("userID1")) + + // Close the first user; it should now be deletable. + require.NoError(t, user1.Close()) + require.NoError(t, s.DeleteUser("userID1")) + + // Try to delete the second user; it should fail because it is still in use. + require.Error(t, s.DeleteUser("userID2")) + + // Close the second user; it should now be deletable. + require.NoError(t, user2.Close()) + require.NoError(t, s.DeleteUser("userID2")) + + // The store should have no users again. + require.Empty(t, s.GetUserIDs()) +} + /* func TestUser(t *testing.T) { // Replace the token generator with a dummy one. diff --git a/internal/vault/vault.go b/internal/vault/vault.go index 78dad6f8..7d2db7ce 100644 --- a/internal/vault/vault.go +++ b/internal/vault/vault.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "encoding/json" "errors" + "fmt" "io/fs" "math/rand" "os" @@ -19,9 +20,13 @@ import ( // Vault is an encrypted data vault that stores bridge and user data. type Vault struct { path string - enc []byte gcm cipher.AEAD - lock sync.RWMutex + + enc []byte + encLock sync.RWMutex + + ref map[string]int + refLock sync.Mutex } // New constructs a new encrypted data vault at the given filepath using the given encryption key. @@ -57,27 +62,45 @@ func (vault *Vault) GetUserIDs() []string { }) } -// GetUserIDs returns the user IDs and usernames of all users in the vault. -func (vault *Vault) GetUser(userID string) (*User, error) { +// HasUser returns true if the vault contains a user with the given ID. +func (vault *Vault) HasUser(userID string) bool { + return xslices.IndexFunc(vault.get().Users, func(user UserData) bool { + return user.UserID == userID + }) >= 0 +} + +// GetUser provides access to a vault user. It returns an error if the user does not exist. +func (vault *Vault) GetUser(userID string, fn func(*User)) error { + user, err := vault.NewUser(userID) + if err != nil { + return err + } + defer func() { _ = user.Close() }() + + fn(user) + + return nil +} + +// NewUser returns a new vault user. It must be closed before it can be deleted. +func (vault *Vault) NewUser(userID string) (*User, error) { if idx := xslices.IndexFunc(vault.get().Users, func(user UserData) bool { return user.UserID == userID }); idx < 0 { return nil, errors.New("no such user") } - return &User{ - vault: vault, - userID: userID, - }, nil + return vault.attachUser(userID), nil } // ForUser executes a callback for each user in the vault. func (vault *Vault) ForUser(fn func(*User) error) error { for _, userID := range vault.GetUserIDs() { - user, err := vault.GetUser(userID) + user, err := vault.NewUser(userID) if err != nil { return err } + defer func() { _ = user.Close() }() if err := fn(user); err != nil { return err @@ -102,11 +125,18 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [ return nil, err } - return vault.GetUser(userID) + return vault.NewUser(userID) } // DeleteUser removes the given user from the vault. func (vault *Vault) DeleteUser(userID string) error { + vault.refLock.Lock() + defer vault.refLock.Unlock() + + if _, ok := vault.ref[userID]; ok { + return fmt.Errorf("user %s is currently in use", userID) + } + return vault.mod(func(data *Data) { idx := xslices.IndexFunc(data.Users, func(user UserData) bool { return user.UserID == userID @@ -124,6 +154,35 @@ func (vault *Vault) Close() error { return nil } +func (vault *Vault) attachUser(userID string) *User { + vault.refLock.Lock() + defer vault.refLock.Unlock() + + vault.ref[userID] += 1 + + return &User{ + vault: vault, + userID: userID, + } +} + +func (vault *Vault) detachUser(userID string) error { + vault.refLock.Lock() + defer vault.refLock.Unlock() + + if _, ok := vault.ref[userID]; !ok { + return fmt.Errorf("user %s is not attached", userID) + } + + vault.ref[userID] -= 1 + + if vault.ref[userID] == 0 { + delete(vault.ref, userID) + } + + return nil +} + func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) { if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) { if _, err := initVault(path, gluonDir, gcm); err != nil { @@ -149,12 +208,17 @@ func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) { enc = newEnc } - return &Vault{path: path, enc: enc, gcm: gcm}, corrupt, nil + return &Vault{ + path: path, + enc: enc, + gcm: gcm, + ref: make(map[string]int), + }, corrupt, nil } func (vault *Vault) get() Data { - vault.lock.RLock() - defer vault.lock.RUnlock() + vault.encLock.RLock() + defer vault.encLock.RUnlock() dec, err := decrypt(vault.gcm, vault.enc) if err != nil { @@ -171,8 +235,8 @@ func (vault *Vault) get() Data { } func (vault *Vault) mod(fn func(data *Data)) error { - vault.lock.Lock() - defer vault.lock.Unlock() + vault.encLock.Lock() + defer vault.encLock.Unlock() dec, err := decrypt(vault.gcm, vault.enc) if err != nil {