mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-17 07:36:44 +00:00
Other: Safer vault
This commit is contained in:
2
go.mod
2
go.mod
@ -39,7 +39,7 @@ require (
|
|||||||
github.com/sirupsen/logrus v1.9.0
|
github.com/sirupsen/logrus v1.9.0
|
||||||
github.com/stretchr/testify v1.8.0
|
github.com/stretchr/testify v1.8.0
|
||||||
github.com/urfave/cli/v2 v2.16.3
|
github.com/urfave/cli/v2 v2.16.3
|
||||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455
|
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012223142-6daad3a5912f
|
||||||
golang.org/x/exp v0.0.0-20220921164117-439092de6870
|
golang.org/x/exp v0.0.0-20220921164117-439092de6870
|
||||||
golang.org/x/net v0.1.0
|
golang.org/x/net v0.1.0
|
||||||
golang.org/x/sys v0.1.0
|
golang.org/x/sys v0.1.0
|
||||||
|
|||||||
4
go.sum
4
go.sum
@ -399,8 +399,8 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr
|
|||||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||||
github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0=
|
github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0=
|
||||||
github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA=
|
github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA=
|
||||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455 h1:TWNT/rPSUGjYsNTwWx5Fd029LipSv+h1XuBwFSd5cAo=
|
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012223142-6daad3a5912f h1:5gPPdQS+dm0A2GAE0IaGLZwgTKn2Q2dCQeMxgJUD+Nk=
|
||||||
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012150646-afdb630a0455/go.mod h1:NfsxXn1T81sz0gHnxuAfyCI4Agzm5UWVRyEtdQSch/4=
|
gitlab.protontech.ch/go/liteapi v0.33.2-0.20221012223142-6daad3a5912f/go.mod h1:NfsxXn1T81sz0gHnxuAfyCI4Agzm5UWVRyEtdQSch/4=
|
||||||
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
|
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
|
||||||
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
||||||
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
||||||
|
|||||||
@ -27,6 +27,7 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
|
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
|
||||||
|
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||||
"gitlab.protontech.ch/go/liteapi"
|
"gitlab.protontech.ch/go/liteapi"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -41,12 +42,11 @@ func (bridge *Bridge) ReportBug(ctx context.Context, osType, osVersion, descript
|
|||||||
if info, err := bridge.QueryUserInfo(username); err == nil {
|
if info, err := bridge.QueryUserInfo(username); err == nil {
|
||||||
account = info.Username
|
account = info.Username
|
||||||
} else if userIDs := bridge.GetUserIDs(); len(userIDs) > 0 {
|
} else if userIDs := bridge.GetUserIDs(); len(userIDs) > 0 {
|
||||||
user, err := bridge.vault.GetUser(userIDs[0])
|
if err := bridge.vault.GetUser(userIDs[0], func(user *vault.User) {
|
||||||
if err != nil {
|
account = user.Username()
|
||||||
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
account = user.Username()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var atts []liteapi.ReportBugAttachment
|
var atts []liteapi.ReportBugAttachment
|
||||||
|
|||||||
@ -13,7 +13,6 @@ import (
|
|||||||
"github.com/go-resty/resty/v2"
|
"github.com/go-resty/resty/v2"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"gitlab.protontech.ch/go/liteapi"
|
"gitlab.protontech.ch/go/liteapi"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserInfo struct {
|
type UserInfo struct {
|
||||||
@ -49,18 +48,21 @@ func (bridge *Bridge) GetUserIDs() []string {
|
|||||||
|
|
||||||
// GetUserInfo returns info about the given user.
|
// GetUserInfo returns info about the given user.
|
||||||
func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
|
func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
|
||||||
vaultUser, err := bridge.vault.GetUser(userID)
|
|
||||||
if err != nil {
|
|
||||||
return UserInfo{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if info, ok := safe.MapGetRet(bridge.users, userID, func(user *user.User) UserInfo {
|
if info, ok := safe.MapGetRet(bridge.users, userID, func(user *user.User) UserInfo {
|
||||||
return getConnUserInfo(user)
|
return getConnUserInfo(user)
|
||||||
}); ok {
|
}); ok {
|
||||||
return info, nil
|
return info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return getUserInfo(vaultUser.UserID(), vaultUser.Username(), vaultUser.AddressMode()), nil
|
var info UserInfo
|
||||||
|
|
||||||
|
if err := bridge.vault.GetUser(userID, func(user *vault.User) {
|
||||||
|
info = getUserInfo(user.UserID(), user.Username(), user.AddressMode())
|
||||||
|
}); err != nil {
|
||||||
|
return UserInfo{}, fmt.Errorf("failed to get user info: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryUserInfo queries the user info by username or address.
|
// QueryUserInfo queries the user info by username or address.
|
||||||
@ -108,10 +110,6 @@ func (bridge *Bridge) LoginUser(
|
|||||||
func() error {
|
func() error {
|
||||||
return client.AuthDelete(ctx)
|
return client.AuthDelete(ctx)
|
||||||
},
|
},
|
||||||
func() error {
|
|
||||||
bridge.deleteUser(ctx, auth.UserID)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to login user: %w", err)
|
return "", fmt.Errorf("failed to login user: %w", err)
|
||||||
@ -126,6 +124,7 @@ func (bridge *Bridge) LoginUser(
|
|||||||
|
|
||||||
// LoginUser authorizes a new bridge user with the given username and password.
|
// LoginUser authorizes a new bridge user with the given username and password.
|
||||||
// If necessary, a TOTP and mailbox password are requested via the callbacks.
|
// If necessary, a TOTP and mailbox password are requested via the callbacks.
|
||||||
|
// This is equivalent to doing LoginAuth and LoginUser separately.
|
||||||
func (bridge *Bridge) LoginFull(
|
func (bridge *Bridge) LoginFull(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
username string,
|
username string,
|
||||||
@ -252,10 +251,12 @@ func (bridge *Bridge) loginUser(ctx context.Context, client *liteapi.Client, aut
|
|||||||
func (bridge *Bridge) loadLoop() {
|
func (bridge *Bridge) loadLoop() {
|
||||||
for {
|
for {
|
||||||
bridge.loadWG.GoTry(func(ok bool) {
|
bridge.loadWG.GoTry(func(ok bool) {
|
||||||
if ok {
|
if !ok {
|
||||||
if err := bridge.loadUsers(); err != nil {
|
return
|
||||||
logrus.WithError(err).Error("Failed to load users")
|
}
|
||||||
}
|
|
||||||
|
if err := bridge.loadUsers(); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to load users")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -269,6 +270,7 @@ func (bridge *Bridge) loadLoop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// loadUsers tries to load each user in the vault that isn't already loaded.
|
||||||
func (bridge *Bridge) loadUsers() error {
|
func (bridge *Bridge) loadUsers() error {
|
||||||
if err := bridge.vault.ForUser(func(user *vault.User) error {
|
if err := bridge.vault.ForUser(func(user *vault.User) error {
|
||||||
if bridge.users.Has(user.UserID()) {
|
if bridge.users.Has(user.UserID()) {
|
||||||
@ -317,7 +319,7 @@ func (bridge *Bridge) loadUser(user *vault.User) error {
|
|||||||
return fmt.Errorf("failed to create API client: %w", err)
|
return fmt.Errorf("failed to create API client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := try.Catch(
|
return try.Catch(
|
||||||
func() error {
|
func() error {
|
||||||
apiUser, err := client.GetUser(ctx)
|
apiUser, err := client.GetUser(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -329,14 +331,7 @@ func (bridge *Bridge) loadUser(user *vault.User) error {
|
|||||||
func() error {
|
func() error {
|
||||||
return client.AuthDelete(ctx)
|
return client.AuthDelete(ctx)
|
||||||
},
|
},
|
||||||
func() error {
|
)
|
||||||
return bridge.logoutUser(ctx, user.UserID())
|
|
||||||
},
|
|
||||||
); err != nil {
|
|
||||||
return fmt.Errorf("failed to load user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// addUser adds a new user with an already salted mailbox password.
|
// addUser adds a new user with an already salted mailbox password.
|
||||||
@ -347,24 +342,46 @@ func (bridge *Bridge) addUser(
|
|||||||
authUID, authRef string,
|
authUID, authRef string,
|
||||||
saltedKeyPass []byte,
|
saltedKeyPass []byte,
|
||||||
) error {
|
) error {
|
||||||
var user *user.User
|
vaultUser, isNew, err := bridge.newVaultUser(client, apiUser, authUID, authRef, saltedKeyPass)
|
||||||
|
if err != nil {
|
||||||
if slices.Contains(bridge.vault.GetUserIDs(), apiUser.ID) {
|
return fmt.Errorf("failed to add vault user: %w", err)
|
||||||
existingUser, err := bridge.addExistingUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to add existing user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
user = existingUser
|
|
||||||
} else {
|
|
||||||
newUser, err := bridge.addNewUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to add new user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
user = newUser
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := bridge.addUserWithVault(ctx, client, apiUser, vaultUser); err != nil {
|
||||||
|
if err := vaultUser.Clear(); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to clear vault user")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := vaultUser.Close(); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to close vault user")
|
||||||
|
}
|
||||||
|
|
||||||
|
if isNew {
|
||||||
|
if err := bridge.vault.DeleteUser(apiUser.ID); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to delete vault user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("failed to add user with vault: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addUserWithVault adds a new user to bridge with the given vault.
|
||||||
|
func (bridge *Bridge) addUserWithVault(
|
||||||
|
ctx context.Context,
|
||||||
|
client *liteapi.Client,
|
||||||
|
apiUser liteapi.User,
|
||||||
|
vaultUser *vault.User,
|
||||||
|
) error {
|
||||||
|
user, err := user.New(ctx, vaultUser, client, apiUser)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bridge.users.Set(apiUser.ID, user)
|
||||||
|
|
||||||
// Connect the user's address(es) to gluon.
|
// Connect the user's address(es) to gluon.
|
||||||
if err := bridge.addIMAPUser(ctx, user); err != nil {
|
if err := bridge.addIMAPUser(ctx, user); err != nil {
|
||||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||||
@ -403,56 +420,37 @@ func (bridge *Bridge) addUser(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bridge *Bridge) addNewUser(
|
// newVaultUser creates a new vault user from the given auth information.
|
||||||
ctx context.Context,
|
// If one already exists in the vault, its data will be updated.
|
||||||
|
func (bridge *Bridge) newVaultUser(
|
||||||
client *liteapi.Client,
|
client *liteapi.Client,
|
||||||
apiUser liteapi.User,
|
apiUser liteapi.User,
|
||||||
authUID, authRef string,
|
authUID, authRef string,
|
||||||
saltedKeyPass []byte,
|
saltedKeyPass []byte,
|
||||||
) (*user.User, error) {
|
) (*vault.User, bool, error) {
|
||||||
vaultUser, err := bridge.vault.AddUser(apiUser.ID, apiUser.Name, authUID, authRef, saltedKeyPass)
|
if !bridge.vault.HasUser(apiUser.ID) {
|
||||||
|
user, err := bridge.vault.AddUser(apiUser.ID, apiUser.Name, authUID, authRef, saltedKeyPass)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, fmt.Errorf("failed to add user to vault: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := bridge.vault.NewUser(apiUser.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := user.New(ctx, vaultUser, client, apiUser)
|
if err := user.SetAuth(authUID, authRef); err != nil {
|
||||||
if err != nil {
|
return nil, false, err
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bridge.users.Set(apiUser.ID, user)
|
if err := user.SetKeyPass(saltedKeyPass); err != nil {
|
||||||
|
return nil, false, err
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bridge *Bridge) addExistingUser(
|
|
||||||
ctx context.Context,
|
|
||||||
client *liteapi.Client,
|
|
||||||
apiUser liteapi.User,
|
|
||||||
authUID, authRef string,
|
|
||||||
saltedKeyPass []byte,
|
|
||||||
) (*user.User, error) {
|
|
||||||
vaultUser, err := bridge.vault.GetUser(apiUser.ID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := vaultUser.SetAuth(authUID, authRef); err != nil {
|
return user, false, nil
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := vaultUser.SetKeyPass(saltedKeyPass); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := user.New(ctx, vaultUser, client, apiUser)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
bridge.users.Set(apiUser.ID, user)
|
|
||||||
|
|
||||||
return user, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// addIMAPUser connects the given user to gluon.
|
// addIMAPUser connects the given user to gluon.
|
||||||
|
|||||||
@ -262,21 +262,33 @@ func TestBridge_FailLoginRecover(t *testing.T) {
|
|||||||
read += uint64(len(b))
|
read += uint64(len(b))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
var userID string
|
||||||
|
|
||||||
// Log the user in and record how much data was read.
|
// Log the user in and record how much data was read.
|
||||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
userID := must(bridge.LoginFull(ctx, username, password, nil, nil))
|
userID = must(bridge.LoginFull(ctx, username, password, nil, nil))
|
||||||
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||||
})
|
})
|
||||||
|
|
||||||
// Simulate a partial read.
|
// Now simulate failing to login.
|
||||||
netCtl.SetReadLimit(read / 2)
|
|
||||||
|
|
||||||
// We should fail to log the user in because we can't fully read its data.
|
|
||||||
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
// Simulate a partial read.
|
||||||
|
netCtl.SetReadLimit(3 * read / 4)
|
||||||
|
|
||||||
|
// We should fail to log the user in because we can't fully read its data.
|
||||||
require.Error(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil)))
|
require.Error(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil)))
|
||||||
|
|
||||||
// There should be no users.
|
// The user should still be there (but disconnected).
|
||||||
require.Empty(t, bridge.GetUserIDs())
|
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||||
|
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Simulate the network recovering.
|
||||||
|
netCtl.SetReadLimit(0)
|
||||||
|
|
||||||
|
// We should now be able to log the user in.
|
||||||
|
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||||
|
require.NoError(t, getErr(bridge.LoginFull(ctx, username, password, nil, nil)))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/Masterminds/semver/v3"
|
"github.com/Masterminds/semver/v3"
|
||||||
"github.com/ProtonMail/proton-bridge/v2/internal/focus/proto"
|
"github.com/ProtonMail/proton-bridge/v2/internal/focus/proto"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/protobuf/types/known/emptypb"
|
"google.golang.org/protobuf/types/known/emptypb"
|
||||||
)
|
)
|
||||||
@ -22,34 +23,31 @@ var Port = 1042 // nolint:gochecknoglobals
|
|||||||
type Service struct {
|
type Service struct {
|
||||||
proto.UnimplementedFocusServer
|
proto.UnimplementedFocusServer
|
||||||
|
|
||||||
server *grpc.Server
|
server *grpc.Server
|
||||||
listener net.Listener
|
raiseCh chan struct{}
|
||||||
raiseCh chan struct{}
|
version *semver.Version
|
||||||
version *semver.Version
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewService creates a new focus service.
|
// NewService creates a new focus service.
|
||||||
// It listens on the local host and port 1042 (by default).
|
// It listens on the local host and port 1042 (by default).
|
||||||
func NewService(version *semver.Version) (*Service, error) {
|
func NewService(version *semver.Version) (*Service, error) {
|
||||||
listener, err := net.Listen("tcp", net.JoinHostPort(Host, fmt.Sprint(Port)))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to listen: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
service := &Service{
|
service := &Service{
|
||||||
server: grpc.NewServer(),
|
server: grpc.NewServer(),
|
||||||
listener: listener,
|
raiseCh: make(chan struct{}, 1),
|
||||||
raiseCh: make(chan struct{}, 1),
|
version: version,
|
||||||
version: version,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
proto.RegisterFocusServer(service.server, service)
|
proto.RegisterFocusServer(service.server, service)
|
||||||
|
|
||||||
go func() {
|
if listener, err := net.Listen("tcp", net.JoinHostPort(Host, fmt.Sprint(Port))); err != nil {
|
||||||
if err := service.server.Serve(listener); err != nil {
|
logrus.WithError(err).Warn("Failed to start focus service")
|
||||||
fmt.Printf("failed to serve: %v", err)
|
} else {
|
||||||
}
|
go func() {
|
||||||
}()
|
if err := service.server.Serve(listener); err != nil {
|
||||||
|
fmt.Printf("failed to serve: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/bradenaw/juniper/xerrors"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,7 +24,7 @@ func CatchVal[T any](try func() (T, error), handlers ...func() error) (res T, er
|
|||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
catch(handlers...)
|
catch(handlers...)
|
||||||
err = fmt.Errorf("panic: %v", r)
|
err = xerrors.WithStack(fmt.Errorf("panic: %v", r))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -38,13 +39,13 @@ func CatchVal[T any](try func() (T, error), handlers ...func() error) (res T, er
|
|||||||
func catch(handlers ...func() error) {
|
func catch(handlers ...func() error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
logrus.WithField("panic", r).Error("Panic in catch")
|
logrus.WithError(xerrors.WithStack(fmt.Errorf("panic: %v", r))).Error("Catch handler panicked")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for _, handler := range handlers {
|
for _, handler := range handlers {
|
||||||
if err := handler(); err != nil {
|
if err := handler(); err != nil {
|
||||||
logrus.WithError(err).Error("Failed to handle error")
|
logrus.WithError(xerrors.WithStack(err)).Error("Catch handler failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -341,7 +341,9 @@ func (user *User) Close() error {
|
|||||||
user.waitSync()
|
user.waitSync()
|
||||||
|
|
||||||
// Close the user's API client.
|
// Close the user's API client.
|
||||||
user.client.Close()
|
if err := user.client.Close(); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to close API client")
|
||||||
|
}
|
||||||
|
|
||||||
// Close the user's update channels.
|
// Close the user's update channels.
|
||||||
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
|
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
|
||||||
@ -353,6 +355,11 @@ func (user *User) Close() error {
|
|||||||
// Close the user's notify channel.
|
// Close the user's notify channel.
|
||||||
user.eventCh.Close()
|
user.eventCh.Close()
|
||||||
|
|
||||||
|
// Close the user's vault.
|
||||||
|
if err := user.vault.Close(); err != nil {
|
||||||
|
logrus.WithError(err).Error("Failed to close vault")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -147,3 +147,8 @@ func (user *User) Clear() error {
|
|||||||
data.KeyPass = nil
|
data.KeyPass = nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the user. This allows it to be removed from the vault.
|
||||||
|
func (user *User) Close() error {
|
||||||
|
return user.vault.detachUser(user.userID)
|
||||||
|
}
|
||||||
|
|||||||
@ -80,7 +80,11 @@ func TestUser_Delete(t *testing.T) {
|
|||||||
// The user should be listed in the store.
|
// The user should be listed in the store.
|
||||||
require.ElementsMatch(t, []string{"userID"}, s.GetUserIDs())
|
require.ElementsMatch(t, []string{"userID"}, s.GetUserIDs())
|
||||||
|
|
||||||
// Clear the user's auth information.
|
// Try to delete the user; it should fail because it is still in use.
|
||||||
|
require.Error(t, s.DeleteUser("userID"))
|
||||||
|
|
||||||
|
// Close the user; it should now be deletable.
|
||||||
|
require.NoError(t, user.Close())
|
||||||
require.NoError(t, s.DeleteUser("userID"))
|
require.NoError(t, s.DeleteUser("userID"))
|
||||||
|
|
||||||
// The store should have no users again.
|
// The store should have no users again.
|
||||||
@ -122,6 +126,56 @@ func TestUser_SyncStatus(t *testing.T) {
|
|||||||
require.Empty(t, user.SyncStatus().LastMessageID)
|
require.Empty(t, user.SyncStatus().LastMessageID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUser_ForEach(t *testing.T) {
|
||||||
|
// Create a new test vault.
|
||||||
|
s := newVault(t)
|
||||||
|
|
||||||
|
// Create some new users.
|
||||||
|
user1, err := s.AddUser("userID1", "username1", "authUID1", "authRef1", []byte("keyPass1"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
user2, err := s.AddUser("userID2", "username2", "authUID2", "authRef2", []byte("keyPass2"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Iterate through the users.
|
||||||
|
s.ForUser(func(user *vault.User) error {
|
||||||
|
switch user.UserID() {
|
||||||
|
case "userID1":
|
||||||
|
require.Equal(t, "username1", user.Username())
|
||||||
|
require.Equal(t, "authUID1", user.AuthUID())
|
||||||
|
require.Equal(t, "authRef1", user.AuthRef())
|
||||||
|
require.Equal(t, "keyPass1", string(user.KeyPass()))
|
||||||
|
|
||||||
|
case "userID2":
|
||||||
|
require.Equal(t, "username2", user.Username())
|
||||||
|
require.Equal(t, "authUID2", user.AuthUID())
|
||||||
|
require.Equal(t, "authRef2", user.AuthRef())
|
||||||
|
require.Equal(t, "keyPass2", string(user.KeyPass()))
|
||||||
|
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected user %q", user.UserID())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Try to delete the first user; it should fail because it is still in use.
|
||||||
|
require.Error(t, s.DeleteUser("userID1"))
|
||||||
|
|
||||||
|
// Close the first user; it should now be deletable.
|
||||||
|
require.NoError(t, user1.Close())
|
||||||
|
require.NoError(t, s.DeleteUser("userID1"))
|
||||||
|
|
||||||
|
// Try to delete the second user; it should fail because it is still in use.
|
||||||
|
require.Error(t, s.DeleteUser("userID2"))
|
||||||
|
|
||||||
|
// Close the second user; it should now be deletable.
|
||||||
|
require.NoError(t, user2.Close())
|
||||||
|
require.NoError(t, s.DeleteUser("userID2"))
|
||||||
|
|
||||||
|
// The store should have no users again.
|
||||||
|
require.Empty(t, s.GetUserIDs())
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
func TestUser(t *testing.T) {
|
func TestUser(t *testing.T) {
|
||||||
// Replace the token generator with a dummy one.
|
// Replace the token generator with a dummy one.
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import (
|
|||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"os"
|
"os"
|
||||||
@ -19,9 +20,13 @@ import (
|
|||||||
// Vault is an encrypted data vault that stores bridge and user data.
|
// Vault is an encrypted data vault that stores bridge and user data.
|
||||||
type Vault struct {
|
type Vault struct {
|
||||||
path string
|
path string
|
||||||
enc []byte
|
|
||||||
gcm cipher.AEAD
|
gcm cipher.AEAD
|
||||||
lock sync.RWMutex
|
|
||||||
|
enc []byte
|
||||||
|
encLock sync.RWMutex
|
||||||
|
|
||||||
|
ref map[string]int
|
||||||
|
refLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// New constructs a new encrypted data vault at the given filepath using the given encryption key.
|
// New constructs a new encrypted data vault at the given filepath using the given encryption key.
|
||||||
@ -57,27 +62,45 @@ func (vault *Vault) GetUserIDs() []string {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserIDs returns the user IDs and usernames of all users in the vault.
|
// HasUser returns true if the vault contains a user with the given ID.
|
||||||
func (vault *Vault) GetUser(userID string) (*User, error) {
|
func (vault *Vault) HasUser(userID string) bool {
|
||||||
|
return xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
||||||
|
return user.UserID == userID
|
||||||
|
}) >= 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUser provides access to a vault user. It returns an error if the user does not exist.
|
||||||
|
func (vault *Vault) GetUser(userID string, fn func(*User)) error {
|
||||||
|
user, err := vault.NewUser(userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() { _ = user.Close() }()
|
||||||
|
|
||||||
|
fn(user)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUser returns a new vault user. It must be closed before it can be deleted.
|
||||||
|
func (vault *Vault) NewUser(userID string) (*User, error) {
|
||||||
if idx := xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
if idx := xslices.IndexFunc(vault.get().Users, func(user UserData) bool {
|
||||||
return user.UserID == userID
|
return user.UserID == userID
|
||||||
}); idx < 0 {
|
}); idx < 0 {
|
||||||
return nil, errors.New("no such user")
|
return nil, errors.New("no such user")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &User{
|
return vault.attachUser(userID), nil
|
||||||
vault: vault,
|
|
||||||
userID: userID,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForUser executes a callback for each user in the vault.
|
// ForUser executes a callback for each user in the vault.
|
||||||
func (vault *Vault) ForUser(fn func(*User) error) error {
|
func (vault *Vault) ForUser(fn func(*User) error) error {
|
||||||
for _, userID := range vault.GetUserIDs() {
|
for _, userID := range vault.GetUserIDs() {
|
||||||
user, err := vault.GetUser(userID)
|
user, err := vault.NewUser(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer func() { _ = user.Close() }()
|
||||||
|
|
||||||
if err := fn(user); err != nil {
|
if err := fn(user); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -102,11 +125,18 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return vault.GetUser(userID)
|
return vault.NewUser(userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteUser removes the given user from the vault.
|
// DeleteUser removes the given user from the vault.
|
||||||
func (vault *Vault) DeleteUser(userID string) error {
|
func (vault *Vault) DeleteUser(userID string) error {
|
||||||
|
vault.refLock.Lock()
|
||||||
|
defer vault.refLock.Unlock()
|
||||||
|
|
||||||
|
if _, ok := vault.ref[userID]; ok {
|
||||||
|
return fmt.Errorf("user %s is currently in use", userID)
|
||||||
|
}
|
||||||
|
|
||||||
return vault.mod(func(data *Data) {
|
return vault.mod(func(data *Data) {
|
||||||
idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
idx := xslices.IndexFunc(data.Users, func(user UserData) bool {
|
||||||
return user.UserID == userID
|
return user.UserID == userID
|
||||||
@ -124,6 +154,35 @@ func (vault *Vault) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (vault *Vault) attachUser(userID string) *User {
|
||||||
|
vault.refLock.Lock()
|
||||||
|
defer vault.refLock.Unlock()
|
||||||
|
|
||||||
|
vault.ref[userID] += 1
|
||||||
|
|
||||||
|
return &User{
|
||||||
|
vault: vault,
|
||||||
|
userID: userID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vault *Vault) detachUser(userID string) error {
|
||||||
|
vault.refLock.Lock()
|
||||||
|
defer vault.refLock.Unlock()
|
||||||
|
|
||||||
|
if _, ok := vault.ref[userID]; !ok {
|
||||||
|
return fmt.Errorf("user %s is not attached", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
vault.ref[userID] -= 1
|
||||||
|
|
||||||
|
if vault.ref[userID] == 0 {
|
||||||
|
delete(vault.ref, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) {
|
func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) {
|
||||||
if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) {
|
if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) {
|
||||||
if _, err := initVault(path, gluonDir, gcm); err != nil {
|
if _, err := initVault(path, gluonDir, gcm); err != nil {
|
||||||
@ -149,12 +208,17 @@ func newVault(path, gluonDir string, gcm cipher.AEAD) (*Vault, bool, error) {
|
|||||||
enc = newEnc
|
enc = newEnc
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Vault{path: path, enc: enc, gcm: gcm}, corrupt, nil
|
return &Vault{
|
||||||
|
path: path,
|
||||||
|
enc: enc,
|
||||||
|
gcm: gcm,
|
||||||
|
ref: make(map[string]int),
|
||||||
|
}, corrupt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) get() Data {
|
func (vault *Vault) get() Data {
|
||||||
vault.lock.RLock()
|
vault.encLock.RLock()
|
||||||
defer vault.lock.RUnlock()
|
defer vault.encLock.RUnlock()
|
||||||
|
|
||||||
dec, err := decrypt(vault.gcm, vault.enc)
|
dec, err := decrypt(vault.gcm, vault.enc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -171,8 +235,8 @@ func (vault *Vault) get() Data {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (vault *Vault) mod(fn func(data *Data)) error {
|
func (vault *Vault) mod(fn func(data *Data)) error {
|
||||||
vault.lock.Lock()
|
vault.encLock.Lock()
|
||||||
defer vault.lock.Unlock()
|
defer vault.encLock.Unlock()
|
||||||
|
|
||||||
dec, err := decrypt(vault.gcm, vault.enc)
|
dec, err := decrypt(vault.gcm, vault.enc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user