diff --git a/go.mod b/go.mod index aae7f642..532ed649 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ require ( github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.0 github.com/urfave/cli/v2 v2.16.3 - gitlab.protontech.ch/go/liteapi v0.34.4-0.20221020085018-bbcf03261c80 + gitlab.protontech.ch/go/liteapi v0.34.4-0.20221020141054-833d961bc9b5 golang.org/x/exp v0.0.0-20220921164117-439092de6870 golang.org/x/net v0.1.0 golang.org/x/sys v0.1.0 diff --git a/go.sum b/go.sum index 11ff616a..c0af5524 100644 --- a/go.sum +++ b/go.sum @@ -28,8 +28,6 @@ github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo= github.com/ProtonMail/docker-credential-helpers v1.1.0 h1:+kvUIpwWcbtP3WFv5sSvkFn/XLzSqPOB5AAthuk9xPk= github.com/ProtonMail/docker-credential-helpers v1.1.0/go.mod h1:mK0aBveCxhnQ756AmaTfXMZDeULvheYVhF/MWMErN5g= -github.com/ProtonMail/gluon v0.12.0 h1:90kirLwZNh91B+Mhc8fLGt/HMnhSJD2XUedwKVCs5bQ= -github.com/ProtonMail/gluon v0.12.0/go.mod h1:XW/gcr4jErc5bX5yMqkUq3U+AucC2QZHJ5L231k3Nw4= github.com/ProtonMail/gluon v0.13.0 h1:WgL32KvMcanomDP3Z0mSs61QYmNHAtSEbVlimD5seiU= github.com/ProtonMail/gluon v0.13.0/go.mod h1:XW/gcr4jErc5bX5yMqkUq3U+AucC2QZHJ5L231k3Nw4= github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a h1:D+aZah+k14Gn6kmL7eKxoo/4Dr/lK3ChBcwce2+SQP4= @@ -401,8 +399,8 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsr github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/zclconf/go-cty v1.11.0 h1:726SxLdi2SDnjY+BStqB9J1hNp4+2WlzyXLuimibIe0= github.com/zclconf/go-cty v1.11.0/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= -gitlab.protontech.ch/go/liteapi v0.34.4-0.20221020085018-bbcf03261c80 h1:fxmLKxf1xsNAu6EBj+BPU1ChyZagXYhLHiqb7Jy1yOA= -gitlab.protontech.ch/go/liteapi v0.34.4-0.20221020085018-bbcf03261c80/go.mod h1:VCEA83UCi9f3XCP9W/XUIFnJKwokGB46lKUHBNzPWsQ= +gitlab.protontech.ch/go/liteapi v0.34.4-0.20221020141054-833d961bc9b5 h1:wUHwZXLiVdLjJqj6iPYXd2axeG8ROy4uHJ+da/GQ+0A= +gitlab.protontech.ch/go/liteapi v0.34.4-0.20221020141054-833d961bc9b5/go.mod h1:VCEA83UCi9f3XCP9W/XUIFnJKwokGB46lKUHBNzPWsQ= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= diff --git a/internal/user/events.go b/internal/user/events.go index 173326ff..db0b284e 100644 --- a/internal/user/events.go +++ b/internal/user/events.go @@ -248,7 +248,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.Me return fmt.Errorf("failed to get full message: %w", err) } - return user.withAddrKR(event.Message.AddressID, func(addrKR *crypto.KeyRing) error { + return user.withAddrKR(event.Message.AddressID, func(_, addrKR *crypto.KeyRing) error { buildRes, err := buildRFC822(ctx, full, addrKR) if err != nil { return fmt.Errorf("failed to build RFC822 message: %w", err) diff --git a/internal/user/imap.go b/internal/user/imap.go index 4bedb069..b6349968 100644 --- a/internal/user/imap.go +++ b/internal/user/imap.go @@ -254,7 +254,7 @@ func (conn *imapConnector) CreateMessage( imported []byte ) - if err := conn.withAddrKR(conn.addrID, func(addrKR *crypto.KeyRing) error { + if err := conn.withAddrKR(conn.addrID, func(_, addrKR *crypto.KeyRing) error { res, err := stream.Collect(ctx, conn.client.ImportMessages(ctx, addrKR, 1, 1, []liteapi.ImportReq{{ Metadata: liteapi.ImportMetadata{ AddressID: conn.addrID, diff --git a/internal/user/keys.go b/internal/user/keys.go index 695dcaa4..950e1fc2 100644 --- a/internal/user/keys.go +++ b/internal/user/keys.go @@ -36,7 +36,7 @@ func (user *User) withUserKR(fn func(*crypto.KeyRing) error) error { }) } -func (user *User) withAddrKR(addrID string, fn func(*crypto.KeyRing) error) error { +func (user *User) withAddrKR(addrID string, fn func(*crypto.KeyRing, *crypto.KeyRing) error) error { return user.withUserKR(func(userKR *crypto.KeyRing) error { if ok, err := user.apiAddrs.GetErr(addrID, func(apiAddr liteapi.Address) error { addrKR, err := apiAddr.Keys.Unlock(user.vault.KeyPass(), userKR) @@ -45,7 +45,7 @@ func (user *User) withAddrKR(addrID string, fn func(*crypto.KeyRing) error) erro } defer userKR.ClearPrivateParams() - return fn(addrKR) + return fn(userKR, addrKR) }); !ok { return fmt.Errorf("no such address %q", addrID) } else if err != nil { @@ -56,7 +56,7 @@ func (user *User) withAddrKR(addrID string, fn func(*crypto.KeyRing) error) erro }) } -func (user *User) withAddrKRs(fn func(map[string]*crypto.KeyRing) error) error { +func (user *User) withAddrKRs(fn func(*crypto.KeyRing, map[string]*crypto.KeyRing) error) error { return user.withUserKR(func(userKR *crypto.KeyRing) error { return user.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error { addrKRs := make(map[string]*crypto.KeyRing) @@ -71,7 +71,7 @@ func (user *User) withAddrKRs(fn func(map[string]*crypto.KeyRing) error) error { addrKRs[apiAddr.ID] = addrKR } - return fn(addrKRs) + return fn(userKR, addrKRs) }) }) } diff --git a/internal/user/keys_test.go b/internal/user/keys_test.go new file mode 100644 index 00000000..da678b97 --- /dev/null +++ b/internal/user/keys_test.go @@ -0,0 +1,64 @@ +// Copyright (c) 2022 Proton AG +// +// This file is part of Proton Mail Bridge. +// +// Proton Mail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// Proton Mail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with Proton Mail Bridge. If not, see . + +package user + +import ( + "context" + "testing" + + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/stretchr/testify/require" + "gitlab.protontech.ch/go/liteapi" + "gitlab.protontech.ch/go/liteapi/server" +) + +func BenchmarkUserKeyRing(b *testing.B) { + b.StopTimer() + + withAPI(b, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) { + withAccount(b, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) { + withUser(b, ctx, s, m, "username", "password", func(user *User) { + b.StartTimer() + + for i := 0; i < b.N; i++ { + require.NoError(b, user.withUserKR(func(userKR *crypto.KeyRing) error { + return nil + })) + } + }) + }) + }) +} + +func BenchmarkAddrKeyRing(b *testing.B) { + b.StopTimer() + + withAPI(b, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) { + withAccount(b, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) { + withUser(b, ctx, s, m, "username", "password", func(user *User) { + b.StartTimer() + + for i := 0; i < b.N; i++ { + require.NoError(b, user.withAddrKR(addrIDs[0], func(userKR, addrKR *crypto.KeyRing) error { + return nil + })) + } + }) + }) + }) +} diff --git a/internal/user/smtp.go b/internal/user/smtp.go index 24f21227..664034a8 100644 --- a/internal/user/smtp.go +++ b/internal/user/smtp.go @@ -163,72 +163,70 @@ func (session *smtpSession) Data(r io.Reader) error { //nolint:funlen } return session.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error { - return session.withAddrKR(session.fromAddrID, func(addrKR *crypto.KeyRing) error { - return session.withUserKR(func(userKR *crypto.KeyRing) error { - // Use the first key for encrypting the message. - addrKR, err := addrKR.FirstKey() + return session.withAddrKR(session.fromAddrID, func(userKR, addrKR *crypto.KeyRing) error { + // Use the first key for encrypting the message. + addrKR, err := addrKR.FirstKey() + if err != nil { + return fmt.Errorf("failed to get first key: %w", err) + } + + // If the message contains a sender, use it instead of the one from the return path. + if sender, ok := getMessageSender(parser); ok { + session.from = sender + } + + // Load the user's mail settings. + settings, err := session.client.GetMailSettings(ctx) + if err != nil { + return fmt.Errorf("failed to get mail settings: %w", err) + } + + // If we have to attach the public key, do it now. + if settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled { + key, err := addrKR.GetKey(0) if err != nil { - return fmt.Errorf("failed to get first key: %w", err) + return fmt.Errorf("failed to get sending key: %w", err) } - // If the message contains a sender, use it instead of the one from the return path. - if sender, ok := getMessageSender(parser); ok { - session.from = sender - } - - // Load the user's mail settings. - settings, err := session.client.GetMailSettings(ctx) + pubKey, err := key.GetArmoredPublicKey() if err != nil { - return fmt.Errorf("failed to get mail settings: %w", err) + return fmt.Errorf("failed to get public key: %w", err) } - // If we have to attach the public key, do it now. - if settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled { - key, err := addrKR.GetKey(0) - if err != nil { - return fmt.Errorf("failed to get sending key: %w", err) - } + parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8])) + } - pubKey, err := key.GetArmoredPublicKey() - if err != nil { - return fmt.Errorf("failed to get public key: %w", err) - } + // Parse the message we want to send (after we have attached the public key). + message, err := message.ParseWithParser(parser) + if err != nil { + return fmt.Errorf("failed to parse message: %w", err) + } - parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8])) - } - - // Parse the message we want to send (after we have attached the public key). - message, err := message.ParseWithParser(parser) - if err != nil { - return fmt.Errorf("failed to parse message: %w", err) - } - - // Collect all the user's emails so we can match them to the outgoing message. - emails := xslices.Map(apiAddrs, func(addr liteapi.Address) string { - return addr.Email - }) - - sent, err := sendWithKey( - ctx, - session.client, - session.authID, - session.vault.AddressMode(), - settings, - userKR, - addrKR, - emails, - session.from, - session.to, - message, - ) - if err != nil { - return fmt.Errorf("failed to send message: %w", err) - } - - logrus.WithField("messageID", sent.ID).Info("Message sent") - - return nil + // Collect all the user's emails so we can match them to the outgoing message. + emails := xslices.Map(apiAddrs, func(addr liteapi.Address) string { + return addr.Email }) + + sent, err := sendWithKey( + ctx, + session.client, + session.authID, + session.vault.AddressMode(), + settings, + userKR, + addrKR, + emails, + session.from, + session.to, + message, + ) + if err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + + logrus.WithField("messageID", sent.ID).Info("Message sent") + + return nil }) }) } diff --git a/internal/user/sync.go b/internal/user/sync.go index 1439e737..821f30d3 100644 --- a/internal/user/sync.go +++ b/internal/user/sync.go @@ -43,7 +43,7 @@ const ( ) func (user *User) sync(ctx context.Context) error { - return user.withAddrKRs(func(addrKRs map[string]*crypto.KeyRing) error { + return user.withAddrKRs(func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error { logrus.Info("Beginning sync") if !user.vault.SyncStatus().HasLabels { diff --git a/internal/user/user_test.go b/internal/user/user_test.go index d6c35a48..0fbdf67c 100644 --- a/internal/user/user_test.go +++ b/internal/user/user_test.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU General Public License // along with Proton Mail Bridge. If not, see . -package user_test +package user import ( "context" @@ -24,7 +24,6 @@ import ( "github.com/ProtonMail/proton-bridge/v2/internal/certs" "github.com/ProtonMail/proton-bridge/v2/internal/events" - "github.com/ProtonMail/proton-bridge/v2/internal/user" "github.com/ProtonMail/proton-bridge/v2/internal/vault" "github.com/ProtonMail/proton-bridge/v2/tests" "github.com/stretchr/testify/require" @@ -34,8 +33,8 @@ import ( ) func init() { - user.EventPeriod = 100 * time.Millisecond - user.EventJitter = 0 + EventPeriod = 100 * time.Millisecond + EventJitter = 0 backend.GenerateKey = tests.FastGenerateKey certs.GenerateCert = tests.FastGenerateCert } @@ -43,7 +42,7 @@ func init() { func TestUser_Data(t *testing.T) { withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) { withAccount(t, s, "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(userID string, addrIDs []string) { - withUser(t, ctx, s, m, "username", "password", func(user *user.User) { + withUser(t, ctx, s, m, "username", "password", func(user *User) { // User's ID should be correct. require.Equal(t, userID, user.ID()) @@ -66,7 +65,7 @@ func TestUser_Data(t *testing.T) { func TestUser_Sync(t *testing.T) { withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) { withAccount(t, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) { - withUser(t, ctx, s, m, "username", "password", func(user *user.User) { + withUser(t, ctx, s, m, "username", "password", func(user *User) { // User starts a sync at startup. require.IsType(t, events.SyncStarted{}, <-user.GetEventCh()) @@ -83,7 +82,7 @@ func TestUser_Sync(t *testing.T) { func TestUser_Deauth(t *testing.T) { withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) { withAccount(t, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) { - withUser(t, ctx, s, m, "username", "password", func(user *user.User) { + withUser(t, ctx, s, m, "username", "password", func(user *User) { eventCh := user.GetEventCh() // Revoke the user's auth token. @@ -96,7 +95,7 @@ func TestUser_Deauth(t *testing.T) { }) } -func withAPI(_ *testing.T, ctx context.Context, fn func(context.Context, *server.Server, *liteapi.Manager)) { //nolint:revive +func withAPI(_ testing.TB, ctx context.Context, fn func(context.Context, *server.Server, *liteapi.Manager)) { //nolint:revive server := server.New() defer server.Close() @@ -106,9 +105,9 @@ func withAPI(_ *testing.T, ctx context.Context, fn func(context.Context, *server )) } -func withAccount(t *testing.T, s *server.Server, username, password string, emails []string, fn func(string, []string)) { +func withAccount(tb testing.TB, s *server.Server, username, password string, emails []string, fn func(string, []string)) { //nolint:unparam userID, addrID, err := s.CreateUser(username, emails[0], []byte(password)) - require.NoError(t, err) + require.NoError(tb, err) addrIDs := make([]string, 0, len(emails)) @@ -116,7 +115,7 @@ func withAccount(t *testing.T, s *server.Server, username, password string, emai for _, email := range emails[1:] { addrID, err := s.CreateAddress(userID, email, []byte(password)) - require.NoError(t, err) + require.NoError(tb, err) addrIDs = append(addrIDs, addrID) } @@ -124,33 +123,33 @@ func withAccount(t *testing.T, s *server.Server, username, password string, emai fn(userID, addrIDs) } -func withUser(t *testing.T, ctx context.Context, _ *server.Server, m *liteapi.Manager, username, password string, fn func(*user.User)) { //nolint:revive +func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *liteapi.Manager, username, password string, fn func(*User)) { //nolint:unparam,revive client, apiAuth, err := m.NewClientWithLogin(ctx, username, []byte(password)) - require.NoError(t, err) - defer func() { require.NoError(t, client.Close()) }() + require.NoError(tb, err) + defer func() { require.NoError(tb, client.Close()) }() apiUser, err := client.GetUser(ctx) - require.NoError(t, err) + require.NoError(tb, err) salts, err := client.GetSalts(ctx) - require.NoError(t, err) + require.NoError(tb, err) saltedKeyPass, err := salts.SaltForKey([]byte(password), apiUser.Keys.Primary().ID) - require.NoError(t, err) + require.NoError(tb, err) - vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key")) - require.NoError(t, err) - require.False(t, corrupt) + vault, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key")) + require.NoError(tb, err) + require.False(tb, corrupt) vaultUser, err := vault.AddUser(apiUser.ID, username, apiAuth.UID, apiAuth.RefreshToken, saltedKeyPass) - require.NoError(t, err) + require.NoError(tb, err) - user, err := user.New(ctx, vaultUser, client, apiUser, true) - require.NoError(t, err) - defer func() { require.NoError(t, user.Close()) }() + user, err := New(ctx, vaultUser, client, apiUser, true) + require.NoError(tb, err) + defer func() { require.NoError(tb, user.Close()) }() imapConn, err := user.NewIMAPConnectors() - require.NoError(t, err) + require.NoError(tb, err) go func() { for _, imapConn := range imapConn {