forked from Silverfish/proton-bridge
Other: Safer user types
This commit is contained in:
@ -8,7 +8,6 @@ import (
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
@ -28,12 +27,6 @@ func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error
|
||||
}
|
||||
}
|
||||
|
||||
if event.MailSettings != nil {
|
||||
if err := user.handleMailSettingsEvent(ctx, *event.MailSettings); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(event.Labels) > 0 {
|
||||
if err := user.handleLabelEvents(ctx, event.Labels); err != nil {
|
||||
return err
|
||||
@ -51,14 +44,7 @@ func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error
|
||||
|
||||
// handleUserEvent handles the given user event.
|
||||
func (user *User) handleUserEvent(ctx context.Context, userEvent liteapi.User) error {
|
||||
userKR, err := userEvent.Keys.Unlock(user.vault.KeyPass(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.apiUser.Set(userEvent)
|
||||
|
||||
user.userKR.Set(userKR)
|
||||
user.apiUser.Save(userEvent)
|
||||
|
||||
user.eventCh.Enqueue(events.UserChanged{
|
||||
UserID: user.ID(),
|
||||
@ -93,22 +79,18 @@ func (user *User) handleAddressEvents(ctx context.Context, addressEvents []litea
|
||||
}
|
||||
|
||||
func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
|
||||
addrKR, err := safe.GetTypeErr(user.userKR, func(userKR *crypto.KeyRing) (*crypto.KeyRing, error) {
|
||||
return event.Address.Keys.Unlock(user.vault.KeyPass(), userKR)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||
user.apiAddrs.Set(event.Address.ID, event.Address)
|
||||
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
user.apiAddrs.Index(0, func(addrID string, _ liteapi.Address) {
|
||||
user.updateCh.SetFrom(event.Address.ID, addrID)
|
||||
})
|
||||
|
||||
case vault.SplitMode:
|
||||
user.updateCh.Set(event.Address.ID, queue.NewQueuedChannel[imap.Update](0, 0))
|
||||
}
|
||||
|
||||
apiAddrs, err := user.client.GetAddresses(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get addresses: %w", err)
|
||||
}
|
||||
|
||||
user.apiAddrs.Set(apiAddrs)
|
||||
|
||||
user.addrKRs.Set(event.Address.ID, addrKR)
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressCreated{
|
||||
UserID: user.ID(),
|
||||
AddressID: event.Address.ID,
|
||||
@ -116,9 +98,11 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.Ad
|
||||
})
|
||||
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
if err := syncLabels(ctx, user.client, user.updateCh[event.Address.ID]); err != nil {
|
||||
if ok, err := user.updateCh.GetErr(event.Address.ID, func(updateCh *queue.QueuedChannel[imap.Update]) error {
|
||||
return syncLabels(ctx, user.client, updateCh)
|
||||
}); !ok {
|
||||
return fmt.Errorf("no such address %q", event.Address.ID)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to sync labels to new address: %w", err)
|
||||
}
|
||||
}
|
||||
@ -127,21 +111,7 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.Ad
|
||||
}
|
||||
|
||||
func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
|
||||
addrKR, err := safe.GetTypeErr(user.userKR, func(userKR *crypto.KeyRing) (*crypto.KeyRing, error) {
|
||||
return event.Address.Keys.Unlock(user.vault.KeyPass(), userKR)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||
}
|
||||
|
||||
apiAddrs, err := user.client.GetAddresses(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get addresses: %w", err)
|
||||
}
|
||||
|
||||
user.apiAddrs.Set(apiAddrs)
|
||||
|
||||
user.addrKRs.Set(event.Address.ID, addrKR)
|
||||
user.apiAddrs.Set(event.Address.ID, event.Address)
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressUpdated{
|
||||
UserID: user.ID(),
|
||||
@ -153,25 +123,20 @@ func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.Ad
|
||||
}
|
||||
|
||||
func (user *User) handleDeleteAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
|
||||
email, err := safe.GetSliceErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) {
|
||||
return getAddrEmail(apiAddrs, event.ID)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get address email: %w", err)
|
||||
var email string
|
||||
|
||||
if ok := user.apiAddrs.GetDelete(event.ID, func(apiAddr liteapi.Address) {
|
||||
email = apiAddr.Email
|
||||
}); !ok {
|
||||
return fmt.Errorf("no such address %q", event.ID)
|
||||
}
|
||||
|
||||
apiAddrs, err := user.client.GetAddresses(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get addresses: %w", err)
|
||||
}
|
||||
|
||||
user.apiAddrs.Set(apiAddrs)
|
||||
|
||||
user.addrKRs.Delete(event.ID)
|
||||
|
||||
if len(user.updateCh) > 1 {
|
||||
user.updateCh[event.ID].Close()
|
||||
delete(user.updateCh, event.ID)
|
||||
if ok := user.updateCh.GetDelete(event.ID, func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
updateCh.Close()
|
||||
}
|
||||
}); !ok {
|
||||
return fmt.Errorf("no such address %q", event.ID)
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressDeleted{
|
||||
@ -183,13 +148,6 @@ func (user *User) handleDeleteAddressEvent(ctx context.Context, event liteapi.Ad
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMailSettingsEvent handles the given mail settings event.
|
||||
func (user *User) handleMailSettingsEvent(ctx context.Context, mailSettingsEvent liteapi.MailSettings) error {
|
||||
user.settings.Set(mailSettingsEvent)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleLabelEvents handles the given label events.
|
||||
func (user *User) handleLabelEvents(ctx context.Context, labelEvents []liteapi.LabelEvent) error {
|
||||
for _, event := range labelEvents {
|
||||
@ -215,25 +173,25 @@ func (user *User) handleLabelEvents(ctx context.Context, labelEvents []liteapi.L
|
||||
}
|
||||
|
||||
func (user *User) handleCreateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
|
||||
for _, updateCh := range user.updateCh {
|
||||
user.updateCh.IterValues(func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||
updateCh.Enqueue(newMailboxCreatedUpdate(imap.LabelID(event.ID), getMailboxName(event.Label)))
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) handleUpdateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
|
||||
for _, updateCh := range user.updateCh {
|
||||
user.updateCh.IterValues(func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||
updateCh.Enqueue(imap.NewMailboxUpdated(imap.LabelID(event.ID), getMailboxName(event.Label)))
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) handleDeleteLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
|
||||
for _, updateCh := range user.updateCh {
|
||||
user.updateCh.IterValues(func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||
updateCh.Enqueue(imap.NewMailboxDeleted(imap.LabelID(event.ID)))
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -269,29 +227,18 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.Me
|
||||
return fmt.Errorf("failed to get full message: %w", err)
|
||||
}
|
||||
|
||||
buildRes, err := safe.GetMapErr(
|
||||
user.addrKRs,
|
||||
full.AddressID,
|
||||
func(addrKR *crypto.KeyRing) (*buildRes, error) {
|
||||
return buildRFC822(ctx, full, addrKR)
|
||||
},
|
||||
func() (*buildRes, error) {
|
||||
return nil, fmt.Errorf("address keyring not found")
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build RFC822: %w", err)
|
||||
}
|
||||
return user.withAddrKR(event.Message.AddressID, func(addrKR *crypto.KeyRing) error {
|
||||
buildRes, err := buildRFC822(ctx, full, addrKR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build RFC822 message: %w", err)
|
||||
}
|
||||
|
||||
if len(user.updateCh) > 1 {
|
||||
user.updateCh[buildRes.addressID].Enqueue(imap.NewMessagesCreated(buildRes.update))
|
||||
} else {
|
||||
user.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
|
||||
user.updateCh[apiAddrs[0].ID].Enqueue(imap.NewMessagesCreated(buildRes.update))
|
||||
user.updateCh.Get(full.AddressID, func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||
updateCh.Enqueue(imap.NewMessagesCreated(buildRes.update))
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (user *User) handleUpdateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
|
||||
@ -302,13 +249,9 @@ func (user *User) handleUpdateMessageEvent(ctx context.Context, event liteapi.Me
|
||||
event.Message.Starred(),
|
||||
)
|
||||
|
||||
if len(user.updateCh) > 1 {
|
||||
user.updateCh[event.Message.AddressID].Enqueue(update)
|
||||
} else {
|
||||
user.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
|
||||
user.updateCh[apiAddrs[0].ID].Enqueue(update)
|
||||
})
|
||||
}
|
||||
user.updateCh.Get(event.Message.AddressID, func(updateCh *queue.QueuedChannel[imap.Update]) {
|
||||
updateCh.Enqueue(update)
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -2,13 +2,13 @@ package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
@ -25,27 +25,18 @@ const (
|
||||
)
|
||||
|
||||
type imapConnector struct {
|
||||
client *liteapi.Client
|
||||
updateCh <-chan imap.Update
|
||||
*User
|
||||
|
||||
emails []string
|
||||
password []byte
|
||||
addrID string
|
||||
|
||||
flags, permFlags, attrs imap.FlagSet
|
||||
}
|
||||
|
||||
func newIMAPConnector(
|
||||
client *liteapi.Client,
|
||||
updateCh <-chan imap.Update,
|
||||
password []byte,
|
||||
emails ...string,
|
||||
) *imapConnector {
|
||||
func newIMAPConnector(user *User, addrID string) *imapConnector {
|
||||
return &imapConnector{
|
||||
client: client,
|
||||
updateCh: updateCh,
|
||||
User: user,
|
||||
|
||||
emails: emails,
|
||||
password: password,
|
||||
addrID: addrID,
|
||||
|
||||
flags: defaultFlags,
|
||||
permFlags: defaultPermanentFlags,
|
||||
@ -55,13 +46,16 @@ func newIMAPConnector(
|
||||
|
||||
// Authorize returns whether the given username/password combination are valid for this connector.
|
||||
func (conn *imapConnector) Authorize(username string, password []byte) bool {
|
||||
if subtle.ConstantTimeCompare(conn.password, password) != 1 {
|
||||
addrID, err := conn.checkAuth(username, password)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return xslices.IndexFunc(conn.emails, func(address string) bool {
|
||||
return strings.EqualFold(address, username)
|
||||
}) >= 0
|
||||
if conn.vault.AddressMode() == vault.SplitMode && addrID != conn.addrID {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GetLabel returns information about the label with the given ID.
|
||||
@ -246,7 +240,14 @@ func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs [
|
||||
// GetUpdates returns a stream of updates that the gluon server should apply.
|
||||
// It is recommended that the returned channel is buffered with at least constants.ChannelBufferCount.
|
||||
func (conn *imapConnector) GetUpdates() <-chan imap.Update {
|
||||
return conn.updateCh
|
||||
updateCh, ok := safe.MapGetRet(conn.updateCh, conn.addrID, func(updateCh *queue.QueuedChannel[imap.Update]) <-chan imap.Update {
|
||||
return updateCh.GetChannel()
|
||||
})
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("update channel for %q not found", conn.addrID))
|
||||
}
|
||||
|
||||
return updateCh
|
||||
}
|
||||
|
||||
// GetUIDValidity returns the default UID validity for this user.
|
||||
|
||||
60
internal/user/keys.go
Normal file
60
internal/user/keys.go
Normal file
@ -0,0 +1,60 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
)
|
||||
|
||||
func (user *User) withUserKR(fn func(*crypto.KeyRing) error) error {
|
||||
return user.apiUser.LoadErr(func(apiUser liteapi.User) error {
|
||||
userKR, err := apiUser.Keys.Unlock(user.vault.KeyPass(), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock user keys: %w", err)
|
||||
}
|
||||
defer userKR.ClearPrivateParams()
|
||||
|
||||
return fn(userKR)
|
||||
})
|
||||
}
|
||||
|
||||
func (user *User) withAddrKR(addrID string, fn func(*crypto.KeyRing) error) error {
|
||||
return user.withUserKR(func(userKR *crypto.KeyRing) error {
|
||||
if ok, err := user.apiAddrs.GetErr(addrID, func(apiAddr liteapi.Address) error {
|
||||
addrKR, err := apiAddr.Keys.Unlock(user.vault.KeyPass(), userKR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||
}
|
||||
defer userKR.ClearPrivateParams()
|
||||
|
||||
return fn(addrKR)
|
||||
}); !ok {
|
||||
return fmt.Errorf("no such address %q", addrID)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (user *User) withAddrKRs(fn func(map[string]*crypto.KeyRing) error) error {
|
||||
return user.withUserKR(func(userKR *crypto.KeyRing) error {
|
||||
return user.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error {
|
||||
addrKRs := make(map[string]*crypto.KeyRing)
|
||||
|
||||
for _, apiAddr := range apiAddrs {
|
||||
addrKR, err := apiAddr.Keys.Unlock(user.vault.KeyPass(), userKR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||
}
|
||||
defer userKR.ClearPrivateParams()
|
||||
|
||||
addrKRs[apiAddr.ID] = addrKR
|
||||
}
|
||||
|
||||
return fn(addrKRs)
|
||||
})
|
||||
})
|
||||
}
|
||||
@ -34,22 +34,25 @@ type smtpSession struct {
|
||||
// from is the current sending address (taken from the return path).
|
||||
from string
|
||||
|
||||
// fromAddrID is the ID of the curent sending address (taken from the return path).
|
||||
fromAddrID string
|
||||
|
||||
// to holds all to for the current message.
|
||||
to []string
|
||||
}
|
||||
|
||||
func newSMTPSession(user *User, email string) (*smtpSession, error) {
|
||||
authID, err := safe.GetSliceErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) {
|
||||
return getAddrID(apiAddrs, email)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get address ID: %w", err)
|
||||
}
|
||||
return safe.MapValuesRetErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (*smtpSession, error) {
|
||||
authID, err := getAddrID(apiAddrs, email)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get address ID: %w", err)
|
||||
}
|
||||
|
||||
return &smtpSession{
|
||||
User: user,
|
||||
authID: authID,
|
||||
}, nil
|
||||
return &smtpSession{
|
||||
User: user,
|
||||
authID: authID,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// Discard currently processed message.
|
||||
@ -58,6 +61,7 @@ func (session *smtpSession) Reset() {
|
||||
|
||||
// Clear the from and to fields.
|
||||
session.from = ""
|
||||
session.fromAddrID = ""
|
||||
session.to = nil
|
||||
}
|
||||
|
||||
@ -74,7 +78,7 @@ func (session *smtpSession) Logout() error {
|
||||
func (session *smtpSession) Mail(from string, opts smtp.MailOptions) error {
|
||||
logrus.Info("SMTP session mail")
|
||||
|
||||
return session.apiAddrs.GetErr(func(apiAddrs []liteapi.Address) error {
|
||||
return session.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error {
|
||||
switch {
|
||||
case opts.RequireTLS:
|
||||
return ErrNotImplemented
|
||||
@ -93,12 +97,15 @@ func (session *smtpSession) Mail(from string, opts smtp.MailOptions) error {
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := getAddrID(apiAddrs, sanitizeEmail(from)); err != nil {
|
||||
addrID, err := getAddrID(apiAddrs, sanitizeEmail(from))
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid return path: %w", err)
|
||||
}
|
||||
|
||||
session.from = from
|
||||
|
||||
session.fromAddrID = addrID
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@ -138,18 +145,13 @@ func (session *smtpSession) Data(r io.Reader) error {
|
||||
return fmt.Errorf("failed to create parser: %w", err)
|
||||
}
|
||||
|
||||
message, err := safe.GetSliceErr(session.apiAddrs, func(apiAddrs []liteapi.Address) (liteapi.Message, error) {
|
||||
addrID, err := getAddrID(apiAddrs, session.from)
|
||||
if err != nil {
|
||||
return liteapi.Message{}, fmt.Errorf("invalid return path: %w", err)
|
||||
}
|
||||
|
||||
return safe.GetMapErr(session.addrKRs, addrID, func(addrKR *crypto.KeyRing) (liteapi.Message, error) {
|
||||
return safe.GetTypeErr(session.settings, func(settings liteapi.MailSettings) (liteapi.Message, error) {
|
||||
return session.apiAddrs.ValuesErr(func(apiAddrs []liteapi.Address) error {
|
||||
return session.withAddrKR(session.fromAddrID, func(addrKR *crypto.KeyRing) error {
|
||||
return session.withUserKR(func(userKR *crypto.KeyRing) error {
|
||||
// Use the first key for encrypting the message.
|
||||
addrKR, err := addrKR.FirstKey()
|
||||
if err != nil {
|
||||
return liteapi.Message{}, fmt.Errorf("failed to get first key: %w", err)
|
||||
return fmt.Errorf("failed to get first key: %w", err)
|
||||
}
|
||||
|
||||
// If the message contains a sender, use it instead of the one from the return path.
|
||||
@ -157,51 +159,61 @@ func (session *smtpSession) Data(r io.Reader) error {
|
||||
session.from = sender
|
||||
}
|
||||
|
||||
// Load the user's mail settings.
|
||||
settings, err := session.client.GetMailSettings(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get mail settings: %w", err)
|
||||
}
|
||||
|
||||
// If we have to attach the public key, do it now.
|
||||
if settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled {
|
||||
key, err := addrKR.GetKey(0)
|
||||
if err != nil {
|
||||
return liteapi.Message{}, fmt.Errorf("failed to get sending key: %w", err)
|
||||
return fmt.Errorf("failed to get sending key: %w", err)
|
||||
}
|
||||
|
||||
pubKey, err := key.GetArmoredPublicKey()
|
||||
if err != nil {
|
||||
return liteapi.Message{}, fmt.Errorf("failed to get public key: %w", err)
|
||||
return fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
|
||||
parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8]))
|
||||
}
|
||||
|
||||
// Parse the message we want to send (after we have attached the public key).
|
||||
message, err := message.ParseWithParser(parser)
|
||||
if err != nil {
|
||||
return liteapi.Message{}, fmt.Errorf("failed to parse message: %w", err)
|
||||
return fmt.Errorf("failed to parse message: %w", err)
|
||||
}
|
||||
|
||||
return sendWithKey(
|
||||
// Collect all the user's emails so we can match them to the outgoing message.
|
||||
emails := xslices.Map(apiAddrs, func(addr liteapi.Address) string {
|
||||
return addr.Email
|
||||
})
|
||||
|
||||
sent, err := sendWithKey(
|
||||
ctx,
|
||||
session.client,
|
||||
session.authID,
|
||||
session.vault.AddressMode(),
|
||||
apiAddrs,
|
||||
settings,
|
||||
session.userKR,
|
||||
userKR,
|
||||
addrKR,
|
||||
emails,
|
||||
session.from,
|
||||
session.to,
|
||||
message,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send message: %w", err)
|
||||
}
|
||||
|
||||
logrus.WithField("messageID", sent.ID).Info("Message sent")
|
||||
|
||||
return nil
|
||||
})
|
||||
}, func() (liteapi.Message, error) {
|
||||
return liteapi.Message{}, ErrMissingAddrKey
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send message: %w", err)
|
||||
}
|
||||
|
||||
logrus.WithField("messageID", message.ID).Info("Message sent")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendWithKey sends the message with the given address key.
|
||||
@ -210,10 +222,10 @@ func sendWithKey(
|
||||
client *liteapi.Client,
|
||||
authAddrID string,
|
||||
addrMode vault.AddressMode,
|
||||
apiAddrs []liteapi.Address,
|
||||
settings liteapi.MailSettings,
|
||||
userKR *safe.Value[*crypto.KeyRing],
|
||||
userKR *crypto.KeyRing,
|
||||
addrKR *crypto.KeyRing,
|
||||
emails []string,
|
||||
from string,
|
||||
to []string,
|
||||
message message.Message,
|
||||
@ -243,7 +255,7 @@ func sendWithKey(
|
||||
return liteapi.Message{}, fmt.Errorf("failed to get armored message body: %w", err)
|
||||
}
|
||||
|
||||
draft, err := createDraft(ctx, client, apiAddrs, from, to, parentID, liteapi.DraftTemplate{
|
||||
draft, err := createDraft(ctx, client, emails, from, to, parentID, liteapi.DraftTemplate{
|
||||
Subject: message.Subject,
|
||||
Body: armBody,
|
||||
MIMEType: message.MIMEType,
|
||||
@ -264,9 +276,7 @@ func sendWithKey(
|
||||
return liteapi.Message{}, fmt.Errorf("failed to create attachments: %w", err)
|
||||
}
|
||||
|
||||
recipients, err := safe.GetTypeErr(userKR, func(userKR *crypto.KeyRing) (recipients, error) {
|
||||
return getRecipients(ctx, client, userKR, settings, draft)
|
||||
})
|
||||
recipients, err := getRecipients(ctx, client, userKR, settings, draft)
|
||||
if err != nil {
|
||||
return liteapi.Message{}, fmt.Errorf("failed to get recipients: %w", err)
|
||||
}
|
||||
@ -357,7 +367,7 @@ func getParentID(
|
||||
func createDraft(
|
||||
ctx context.Context,
|
||||
client *liteapi.Client,
|
||||
apiAddrs []liteapi.Address,
|
||||
emails []string,
|
||||
from string,
|
||||
to []string,
|
||||
parentID string,
|
||||
@ -371,12 +381,12 @@ func createDraft(
|
||||
}
|
||||
|
||||
// Check that the sending address is owned by the user, and if so, sanitize it.
|
||||
if idx := xslices.IndexFunc(apiAddrs, func(addr liteapi.Address) bool {
|
||||
return strings.EqualFold(addr.Email, sanitizeEmail(template.Sender.Address))
|
||||
if idx := xslices.IndexFunc(emails, func(email string) bool {
|
||||
return strings.EqualFold(email, sanitizeEmail(template.Sender.Address))
|
||||
}); idx < 0 {
|
||||
return liteapi.Message{}, fmt.Errorf("address %q is not owned by user", template.Sender.Address)
|
||||
} else {
|
||||
template.Sender.Address = constructEmail(template.Sender.Address, apiAddrs[idx].Email)
|
||||
template.Sender.Address = constructEmail(template.Sender.Address, emails[idx])
|
||||
}
|
||||
|
||||
// Check ToList: ensure that ToList only contains addresses we actually plan to send to.
|
||||
|
||||
@ -10,12 +10,13 @@ import (
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/stream"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -24,27 +25,43 @@ const (
|
||||
)
|
||||
|
||||
func (user *User) sync(ctx context.Context) error {
|
||||
if !user.vault.SyncStatus().HasLabels {
|
||||
if err := syncLabels(ctx, user.client, maps.Values(user.updateCh)...); err != nil {
|
||||
return fmt.Errorf("failed to sync labels: %w", err)
|
||||
return user.withAddrKRs(func(addrKRs map[string]*crypto.KeyRing) error {
|
||||
logrus.Info("Beginning sync")
|
||||
|
||||
if !user.vault.SyncStatus().HasLabels {
|
||||
logrus.Info("Syncing labels")
|
||||
|
||||
if err := user.updateCh.ValuesErr(func(updateCh []*queue.QueuedChannel[imap.Update]) error {
|
||||
return syncLabels(ctx, user.client, xslices.Unique(updateCh)...)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to sync labels: %w", err)
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasLabels(true); err != nil {
|
||||
return fmt.Errorf("failed to set has labels: %w", err)
|
||||
}
|
||||
} else {
|
||||
logrus.Info("Labels are already synced, skipping")
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasLabels(true); err != nil {
|
||||
return fmt.Errorf("failed to set has labels: %w", err)
|
||||
}
|
||||
}
|
||||
if !user.vault.SyncStatus().HasMessages {
|
||||
logrus.Info("Syncing labels")
|
||||
|
||||
if !user.vault.SyncStatus().HasMessages {
|
||||
if err := user.syncMessages(ctx); err != nil {
|
||||
return fmt.Errorf("failed to sync messages: %w", err)
|
||||
if err := user.updateCh.MapErr(func(updateCh map[string]*queue.QueuedChannel[imap.Update]) error {
|
||||
return syncMessages(ctx, user.ID(), user.client, user.vault, addrKRs, updateCh, user.eventCh)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to sync messages: %w", err)
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasMessages(true); err != nil {
|
||||
return fmt.Errorf("failed to set has messages: %w", err)
|
||||
}
|
||||
} else {
|
||||
logrus.Info("Messages are already synced, skipping")
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasMessages(true); err != nil {
|
||||
return fmt.Errorf("failed to set has messages: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.QueuedChannel[imap.Update]) error {
|
||||
@ -102,48 +119,44 @@ func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) syncMessages(ctx context.Context) error {
|
||||
func syncMessages(
|
||||
ctx context.Context,
|
||||
userID string,
|
||||
client *liteapi.Client,
|
||||
vault *vault.User,
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
updateCh map[string]*queue.QueuedChannel[imap.Update],
|
||||
eventCh *queue.QueuedChannel[events.Event],
|
||||
) error {
|
||||
// Determine which messages to sync.
|
||||
allMetadata, err := user.client.GetAllMessageMetadata(ctx, nil)
|
||||
metadata, err := client.GetAllMessageMetadata(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get all message metadata: %w", err)
|
||||
}
|
||||
|
||||
metadata := allMetadata
|
||||
// Get the message IDs to sync.
|
||||
messageIDs := xslices.Map(metadata, func(metadata liteapi.MessageMetadata) string {
|
||||
return metadata.ID
|
||||
})
|
||||
|
||||
// If possible, begin syncing from one beyond the last synced message.
|
||||
if beginID := user.vault.SyncStatus().LastMessageID; beginID != "" {
|
||||
if idx := xslices.IndexFunc(metadata, func(metadata liteapi.MessageMetadata) bool {
|
||||
return metadata.ID == beginID
|
||||
}); idx >= 0 {
|
||||
metadata = metadata[idx+1:]
|
||||
}
|
||||
if idx := xslices.Index(messageIDs, vault.SyncStatus().LastMessageID); idx >= 0 {
|
||||
messageIDs = messageIDs[idx+1:]
|
||||
}
|
||||
|
||||
// Process the metadata, building the messages.
|
||||
buildCh := stream.Chunk(stream.Map(
|
||||
user.client.GetFullMessages(ctx, xslices.Map(metadata, func(metadata liteapi.MessageMetadata) string {
|
||||
return metadata.ID
|
||||
})...),
|
||||
// Fetch and build each message.
|
||||
buildCh := stream.Map(
|
||||
client.GetFullMessages(ctx, messageIDs...),
|
||||
func(ctx context.Context, full liteapi.FullMessage) (*buildRes, error) {
|
||||
return safe.GetMapErr(
|
||||
user.addrKRs,
|
||||
full.AddressID,
|
||||
func(addrKR *crypto.KeyRing) (*buildRes, error) {
|
||||
return buildRFC822(ctx, full, addrKR)
|
||||
},
|
||||
func() (*buildRes, error) {
|
||||
return nil, fmt.Errorf("address keyring not found")
|
||||
},
|
||||
)
|
||||
return buildRFC822(ctx, full, addrKRs[full.AddressID])
|
||||
},
|
||||
), maxBatchSize)
|
||||
)
|
||||
defer buildCh.Close()
|
||||
|
||||
// Create the flushers, one per update channel.
|
||||
flushers := make(map[string]*flusher)
|
||||
|
||||
for addrID, updateCh := range user.updateCh {
|
||||
for addrID, updateCh := range updateCh {
|
||||
flusher := newFlusher(updateCh, maxUpdateSize)
|
||||
defer flusher.flush(ctx, true)
|
||||
|
||||
@ -151,42 +164,27 @@ func (user *User) syncMessages(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Create a reporter to report sync progress updates.
|
||||
reporter := newReporter(user.ID(), user.eventCh, len(metadata), time.Second)
|
||||
reporter := newReporter(userID, eventCh, len(messageIDs), time.Second)
|
||||
defer reporter.done()
|
||||
|
||||
var count int
|
||||
|
||||
// Send each update to the appropriate flusher.
|
||||
for {
|
||||
batch, err := buildCh.Next(ctx)
|
||||
if errors.Is(err, stream.End) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to get next sync batch: %w", err)
|
||||
return forEach(ctx, stream.Chunk(buildCh, maxBatchSize), func(batch []*buildRes) error {
|
||||
for _, res := range batch {
|
||||
flushers[res.addressID].push(ctx, res.update)
|
||||
}
|
||||
|
||||
user.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
|
||||
for _, res := range batch {
|
||||
if len(flushers) > 1 {
|
||||
flushers[res.addressID].push(ctx, res.update)
|
||||
} else {
|
||||
flushers[apiAddrs[0].ID].push(ctx, res.update)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
for _, flusher := range flushers {
|
||||
flusher.flush(ctx, true)
|
||||
}
|
||||
|
||||
if err := user.vault.SetLastMessageID(batch[len(batch)-1].messageID); err != nil {
|
||||
if err := vault.SetLastMessageID(batch[len(batch)-1].messageID); err != nil {
|
||||
return fmt.Errorf("failed to set last synced message ID: %w", err)
|
||||
}
|
||||
|
||||
reporter.add(len(batch))
|
||||
|
||||
count += len(batch)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func newSystemMailboxCreatedUpdate(labelID imap.LabelID, labelName string) *imap.MailboxCreated {
|
||||
@ -232,3 +230,18 @@ func wantLabelID(labelID string) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func forEach[T any](ctx context.Context, streamer stream.Stream[T], fn func(T) error) error {
|
||||
for {
|
||||
res, err := streamer.Next(ctx)
|
||||
if errors.Is(err, stream.End) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to get next stream item: %w", err)
|
||||
}
|
||||
|
||||
if err := fn(res); err != nil {
|
||||
return fmt.Errorf("failed to process stream item: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,10 +1,17 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
)
|
||||
|
||||
// mapTo converts the slice to the given type.
|
||||
// This is not runtime safe, so make sure the slice is of the correct type!
|
||||
// (This is a workaround for the fact that slices cannot be converted to other types generically).
|
||||
func mapTo[From, To any](from []From) []To {
|
||||
to := make([]To, 0, len(from))
|
||||
|
||||
@ -19,3 +26,79 @@ func mapTo[From, To any](from []From) []To {
|
||||
|
||||
return to
|
||||
}
|
||||
|
||||
// groupBy returns a map of the given slice grouped by the given key.
|
||||
// Duplicate keys are overwritten.
|
||||
func groupBy[Key comparable, Value any](items []Value, key func(Value) Key) map[Key]Value {
|
||||
groups := make(map[Key]Value)
|
||||
|
||||
for _, item := range items {
|
||||
groups[key(item)] = item
|
||||
}
|
||||
|
||||
return groups
|
||||
}
|
||||
|
||||
// sortAddr returns whether the first address should be sorted before the second.
|
||||
func sortAddr(addrIDA, addrIDB string, apiAddrs map[string]liteapi.Address) bool {
|
||||
return apiAddrs[addrIDA].Order < apiAddrs[addrIDB].Order
|
||||
}
|
||||
|
||||
// hexEncode returns the hexadecimal encoding of the given byte slice.
|
||||
func hexEncode(b []byte) []byte {
|
||||
enc := make([]byte, hex.EncodedLen(len(b)))
|
||||
|
||||
hex.Encode(enc, b)
|
||||
|
||||
return enc
|
||||
}
|
||||
|
||||
// hexDecode returns the bytes represented by the hexadecimal encoding of the given byte slice.
|
||||
func hexDecode(b []byte) ([]byte, error) {
|
||||
dec := make([]byte, hex.DecodedLen(len(b)))
|
||||
|
||||
if _, err := hex.Decode(dec, b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dec, nil
|
||||
}
|
||||
|
||||
// getAddrID returns the address ID for the given email address.
|
||||
func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) {
|
||||
for _, addr := range apiAddrs {
|
||||
if addr.Email == email {
|
||||
return addr.ID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("address %s not found", email)
|
||||
}
|
||||
|
||||
// getAddrEmail returns the email address of the given address ID.
|
||||
func getAddrEmail(apiAddrs []liteapi.Address, addrID string) (string, error) {
|
||||
for _, addr := range apiAddrs {
|
||||
if addr.ID == addrID {
|
||||
return addr.Email, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("address %s not found", addrID)
|
||||
}
|
||||
|
||||
// contextWithStopCh returns a new context that is cancelled when the stop channel is closed or a value is sent to it.
|
||||
func contextWithStopCh(ctx context.Context, stopCh <-chan struct{}) (context.Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-stopCh:
|
||||
cancel()
|
||||
|
||||
case <-ctx.Done():
|
||||
// ...
|
||||
}
|
||||
}()
|
||||
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
@ -1,19 +1,18 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/connector"
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/gluon/wait"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/try"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/emersion/go-smtp"
|
||||
@ -32,15 +31,11 @@ type User struct {
|
||||
eventCh *queue.QueuedChannel[events.Event]
|
||||
|
||||
apiUser *safe.Value[liteapi.User]
|
||||
apiAddrs *safe.Slice[liteapi.Address]
|
||||
settings *safe.Value[liteapi.MailSettings]
|
||||
apiAddrs *safe.Map[string, liteapi.Address]
|
||||
updateCh *safe.Map[string, *queue.QueuedChannel[imap.Update]]
|
||||
|
||||
userKR *safe.Value[*crypto.KeyRing]
|
||||
addrKRs *safe.Map[string, *crypto.KeyRing]
|
||||
|
||||
updateCh map[string]*queue.QueuedChannel[imap.Update]
|
||||
syncStopCh chan struct{}
|
||||
syncWG wait.Group
|
||||
syncLock try.Group
|
||||
}
|
||||
|
||||
func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiUser liteapi.User) (*User, error) {
|
||||
@ -50,9 +45,8 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
||||
return nil, fmt.Errorf("failed to get addresses: %w", err)
|
||||
}
|
||||
|
||||
// Unlock the user's keyrings.
|
||||
userKR, addrKRs, err := liteapi.Unlock(apiUser, apiAddrs, encVault.KeyPass())
|
||||
if err != nil {
|
||||
// Check we can unlock the keyrings.
|
||||
if _, _, err := liteapi.Unlock(apiUser, apiAddrs, encVault.KeyPass()); err != nil {
|
||||
return nil, fmt.Errorf("failed to unlock user: %w", err)
|
||||
}
|
||||
|
||||
@ -68,20 +62,21 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
||||
}
|
||||
}
|
||||
|
||||
// Get the user's mail settings.
|
||||
settings, err := client.GetMailSettings(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get mail settings: %w", err)
|
||||
}
|
||||
|
||||
// Create update channels for each of the user's addresses (if in combined mode, just the primary).
|
||||
// Create update channels for each of the user's addresses.
|
||||
// In combined mode, the addresses all share the same update channel.
|
||||
updateCh := make(map[string]*queue.QueuedChannel[imap.Update])
|
||||
|
||||
for _, addr := range apiAddrs {
|
||||
updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
switch encVault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
if encVault.AddressMode() == vault.CombinedMode {
|
||||
break
|
||||
for _, addr := range apiAddrs {
|
||||
updateCh[addr.ID] = primaryUpdateCh
|
||||
}
|
||||
|
||||
case vault.SplitMode:
|
||||
for _, addr := range apiAddrs {
|
||||
updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
}
|
||||
}
|
||||
|
||||
@ -91,19 +86,15 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||
|
||||
apiUser: safe.NewValue(apiUser),
|
||||
apiAddrs: safe.NewSlice(apiAddrs),
|
||||
settings: safe.NewValue(settings),
|
||||
apiAddrs: safe.NewMapFrom(groupBy(apiAddrs, func(addr liteapi.Address) string { return addr.ID }), sortAddr),
|
||||
updateCh: safe.NewMapFrom(updateCh, nil),
|
||||
|
||||
userKR: safe.NewValue(userKR),
|
||||
addrKRs: safe.NewMap(addrKRs),
|
||||
|
||||
updateCh: updateCh,
|
||||
syncStopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// When we receive an auth object, we update it in the vault.
|
||||
// This will be used to authorize the user on the next run.
|
||||
client.AddAuthHandler(func(auth liteapi.Auth) {
|
||||
user.client.AddAuthHandler(func(auth liteapi.Auth) {
|
||||
if err := user.vault.SetAuth(auth.UID, auth.RefreshToken); err != nil {
|
||||
logrus.WithError(err).Error("Failed to update auth in vault")
|
||||
}
|
||||
@ -111,23 +102,24 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
||||
|
||||
// When we are deauthorized, we send a deauth event to the event channel.
|
||||
// Bridge will react to this event by logging out the user.
|
||||
client.AddDeauthHandler(func() {
|
||||
user.client.AddDeauthHandler(func() {
|
||||
user.eventCh.Enqueue(events.UserDeauth{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
})
|
||||
|
||||
// TODO: Don't start the event loop until the initial sync has finished!
|
||||
eventCh := user.client.NewEventStream(EventPeriod, EventJitter, user.vault.EventID())
|
||||
|
||||
// If we haven't synced yet, do it first.
|
||||
// If it fails, we don't start the event loop.
|
||||
// Otherwise, begin processing API events, logging any errors that occur.
|
||||
go func() {
|
||||
if status := user.vault.SyncStatus(); !status.HasMessages {
|
||||
if err := <-user.startSync(); err != nil {
|
||||
return
|
||||
}
|
||||
if err := <-user.startSync(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for err := range user.streamEvents() {
|
||||
for err := range user.streamEvents(eventCh) {
|
||||
logrus.WithError(err).Error("Error while streaming events")
|
||||
}
|
||||
}()
|
||||
@ -137,40 +129,34 @@ func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiU
|
||||
|
||||
// ID returns the user's ID.
|
||||
func (user *User) ID() string {
|
||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) string {
|
||||
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) string {
|
||||
return apiUser.ID
|
||||
})
|
||||
}
|
||||
|
||||
// Name returns the user's username.
|
||||
func (user *User) Name() string {
|
||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) string {
|
||||
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) string {
|
||||
return apiUser.Name
|
||||
})
|
||||
}
|
||||
|
||||
// Match matches the given query against the user's username and email addresses.
|
||||
func (user *User) Match(query string) bool {
|
||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) bool {
|
||||
return safe.GetSlice(user.apiAddrs, func(apiAddrs []liteapi.Address) bool {
|
||||
if query == apiUser.Name {
|
||||
return true
|
||||
}
|
||||
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) bool {
|
||||
if query == apiUser.Name {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, addr := range apiAddrs {
|
||||
if addr.Email == query {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return user.apiAddrs.HasFunc(func(_ string, addr liteapi.Address) bool {
|
||||
return addr.Email == query
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Emails returns all the user's email addresses.
|
||||
// Emails returns all the user's email addresses via the callback.
|
||||
func (user *User) Emails() []string {
|
||||
return safe.GetSlice(user.apiAddrs, func(apiAddrs []liteapi.Address) []string {
|
||||
return safe.MapValuesRet(user.apiAddrs, func(apiAddrs []liteapi.Address) []string {
|
||||
return xslices.Map(apiAddrs, func(addr liteapi.Address) string {
|
||||
return addr.Email
|
||||
})
|
||||
@ -184,28 +170,38 @@ func (user *User) GetAddressMode() vault.AddressMode {
|
||||
|
||||
// SetAddressMode sets the user's address mode.
|
||||
func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error {
|
||||
for _, updateCh := range user.updateCh {
|
||||
updateCh.Close()
|
||||
}
|
||||
user.stopSync()
|
||||
user.lockSync()
|
||||
defer user.unlockSync()
|
||||
|
||||
user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update])
|
||||
|
||||
user.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
|
||||
for _, addr := range apiAddrs {
|
||||
user.updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
if mode == vault.CombinedMode {
|
||||
break
|
||||
}
|
||||
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
|
||||
for _, updateCh := range xslices.Unique(updateCh) {
|
||||
updateCh.Close()
|
||||
}
|
||||
})
|
||||
|
||||
updateCh := make(map[string]*queue.QueuedChannel[imap.Update])
|
||||
|
||||
switch mode {
|
||||
case vault.CombinedMode:
|
||||
primaryUpdateCh := queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
user.apiAddrs.IterKeys(func(addrID string) {
|
||||
updateCh[addrID] = primaryUpdateCh
|
||||
})
|
||||
|
||||
case vault.SplitMode:
|
||||
user.apiAddrs.IterKeys(func(addrID string) {
|
||||
updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
})
|
||||
}
|
||||
|
||||
user.updateCh = safe.NewMapFrom(updateCh, nil)
|
||||
|
||||
if err := user.vault.SetAddressMode(mode); err != nil {
|
||||
return fmt.Errorf("failed to set address mode: %w", err)
|
||||
}
|
||||
|
||||
user.stopSync()
|
||||
|
||||
if err := user.vault.ClearSyncStatus(); err != nil {
|
||||
return fmt.Errorf("failed to clear sync status: %w", err)
|
||||
}
|
||||
@ -246,25 +242,19 @@ func (user *User) GluonKey() []byte {
|
||||
|
||||
// BridgePass returns the user's bridge password, used for authentication over SMTP and IMAP.
|
||||
func (user *User) BridgePass() []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
if _, err := hex.NewEncoder(buf).Write(user.vault.BridgePass()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
return hexEncode(user.vault.BridgePass())
|
||||
}
|
||||
|
||||
// UsedSpace returns the total space used by the user on the API.
|
||||
func (user *User) UsedSpace() int {
|
||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) int {
|
||||
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) int {
|
||||
return apiUser.UsedSpace
|
||||
})
|
||||
}
|
||||
|
||||
// MaxSpace returns the amount of space the user can use on the API.
|
||||
func (user *User) MaxSpace() int {
|
||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) int {
|
||||
return safe.LoadRet(user.apiUser, func(apiUser liteapi.User) int {
|
||||
return apiUser.MaxSpace
|
||||
})
|
||||
}
|
||||
@ -275,37 +265,9 @@ func (user *User) GetEventCh() <-chan events.Event {
|
||||
}
|
||||
|
||||
// NewIMAPConnector returns an IMAP connector for the given address.
|
||||
// If not in split mode, this function returns an error.
|
||||
func (user *User) NewIMAPConnector(addrID string) (connector.Connector, error) {
|
||||
return safe.GetSliceErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (connector.Connector, error) {
|
||||
var emails []string
|
||||
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
if addrID != apiAddrs[0].ID {
|
||||
return nil, fmt.Errorf("cannot create IMAP connector for non-primary address in combined mode")
|
||||
}
|
||||
|
||||
emails = xslices.Map(apiAddrs, func(addr liteapi.Address) string {
|
||||
return addr.Email
|
||||
})
|
||||
|
||||
case vault.SplitMode:
|
||||
email, err := getAddrEmail(apiAddrs, addrID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
emails = []string{email}
|
||||
}
|
||||
|
||||
return newIMAPConnector(
|
||||
user.client,
|
||||
user.updateCh[addrID].GetChannel(),
|
||||
user.BridgePass(),
|
||||
emails...,
|
||||
), nil
|
||||
})
|
||||
// If not in split mode, this must be the primary address.
|
||||
func (user *User) NewIMAPConnector(addrID string) connector.Connector {
|
||||
return newIMAPConnector(user, addrID)
|
||||
}
|
||||
|
||||
// NewIMAPConnectors returns IMAP connectors for each of the user's addresses.
|
||||
@ -314,23 +276,48 @@ func (user *User) NewIMAPConnector(addrID string) (connector.Connector, error) {
|
||||
func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
|
||||
imapConn := make(map[string]connector.Connector)
|
||||
|
||||
for addrID := range user.updateCh {
|
||||
conn, err := user.NewIMAPConnector(addrID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create IMAP connector: %w", err)
|
||||
}
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
user.apiAddrs.Index(0, func(addrID string, _ liteapi.Address) {
|
||||
imapConn[addrID] = newIMAPConnector(user, addrID)
|
||||
})
|
||||
|
||||
imapConn[addrID] = conn
|
||||
case vault.SplitMode:
|
||||
user.apiAddrs.IterKeys(func(addrID string) {
|
||||
imapConn[addrID] = newIMAPConnector(user, addrID)
|
||||
})
|
||||
}
|
||||
|
||||
return imapConn, nil
|
||||
}
|
||||
|
||||
// NewSMTPSession returns an SMTP session for the user.
|
||||
func (user *User) NewSMTPSession(email string) (smtp.Session, error) {
|
||||
func (user *User) NewSMTPSession(email string, password []byte) (smtp.Session, error) {
|
||||
if _, err := user.checkAuth(email, password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newSMTPSession(user, email)
|
||||
}
|
||||
|
||||
// OnStatusUp is called when the connection goes up.
|
||||
func (user *User) OnStatusUp() {
|
||||
go func() {
|
||||
logrus.Info("Connection up, checking if sync is needed")
|
||||
|
||||
if err := <-user.startSync(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to sync on status up")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// OnStatusDown is called when the connection goes down.
|
||||
func (user *User) OnStatusDown() {
|
||||
logrus.Info("Connection down, aborting any ongoing syncs")
|
||||
|
||||
user.stopSync()
|
||||
}
|
||||
|
||||
// Logout logs the user out from the API.
|
||||
// If withVault is true, the user's vault is also cleared.
|
||||
func (user *User) Logout(ctx context.Context) error {
|
||||
@ -350,13 +337,18 @@ func (user *User) Close() error {
|
||||
// Cancel ongoing syncs.
|
||||
user.stopSync()
|
||||
|
||||
// Wait for ongoing syncs to stop.
|
||||
user.waitSync()
|
||||
|
||||
// Close the user's API client.
|
||||
user.client.Close()
|
||||
|
||||
// Close the user's update channels.
|
||||
for _, updateCh := range user.updateCh {
|
||||
updateCh.Close()
|
||||
}
|
||||
user.updateCh.Values(func(updateCh []*queue.QueuedChannel[imap.Update]) {
|
||||
for _, updateCh := range xslices.Unique(updateCh) {
|
||||
updateCh.Close()
|
||||
}
|
||||
})
|
||||
|
||||
// Close the user's notify channel.
|
||||
user.eventCh.Close()
|
||||
@ -364,16 +356,37 @@ func (user *User) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) checkAuth(email string, password []byte) (string, error) {
|
||||
dec, err := hexDecode(password)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode password: %w", err)
|
||||
}
|
||||
|
||||
if subtle.ConstantTimeCompare(user.vault.BridgePass(), dec) != 1 {
|
||||
return "", fmt.Errorf("invalid password")
|
||||
}
|
||||
|
||||
return safe.MapValuesRetErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) {
|
||||
for _, addr := range apiAddrs {
|
||||
if addr.Email == strings.ToLower(email) {
|
||||
return addr.ID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid email")
|
||||
})
|
||||
}
|
||||
|
||||
// streamEvents begins streaming API events for the user.
|
||||
// When we receive an API event, we attempt to handle it.
|
||||
// If successful, we update the event ID in the vault.
|
||||
func (user *User) streamEvents() <-chan error {
|
||||
func (user *User) streamEvents(eventCh <-chan liteapi.Event) <-chan error {
|
||||
errCh := make(chan error)
|
||||
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
|
||||
for event := range user.client.NewEventStreamer(EventPeriod, EventJitter, user.vault.EventID()).Subscribe() {
|
||||
for event := range eventCh {
|
||||
if err := user.handleAPIEvent(context.Background(), event); err != nil {
|
||||
errCh <- fmt.Errorf("failed to handle API event: %w", err)
|
||||
} else if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||
@ -387,11 +400,21 @@ func (user *User) streamEvents() <-chan error {
|
||||
|
||||
// startSync begins a startSync for the user.
|
||||
func (user *User) startSync() <-chan error {
|
||||
if user.vault.SyncStatus().IsComplete() {
|
||||
logrus.Debug("Already synced, skipping")
|
||||
return nil
|
||||
}
|
||||
|
||||
errCh := make(chan error)
|
||||
|
||||
user.syncWG.Go(func() {
|
||||
user.syncLock.GoTry(func(ok bool) {
|
||||
defer close(errCh)
|
||||
|
||||
if !ok {
|
||||
logrus.Debug("Sync already in progress, skipping")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := contextWithStopCh(context.Background(), user.syncStopCh)
|
||||
defer cancel()
|
||||
|
||||
@ -421,46 +444,24 @@ func (user *User) startSync() <-chan error {
|
||||
func (user *User) stopSync() {
|
||||
select {
|
||||
case user.syncStopCh <- struct{}{}:
|
||||
user.syncWG.Wait()
|
||||
logrus.Debug("Sent sync abort signal")
|
||||
|
||||
default:
|
||||
// ...
|
||||
logrus.Debug("No sync to abort")
|
||||
}
|
||||
}
|
||||
|
||||
func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) {
|
||||
for _, addr := range apiAddrs {
|
||||
if addr.Email == email {
|
||||
return addr.ID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("address %s not found", email)
|
||||
// lockSync prevents a new sync from starting.
|
||||
func (user *User) lockSync() {
|
||||
user.syncLock.Lock()
|
||||
}
|
||||
|
||||
func getAddrEmail(apiAddrs []liteapi.Address, addrID string) (string, error) {
|
||||
for _, addr := range apiAddrs {
|
||||
if addr.ID == addrID {
|
||||
return addr.Email, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("address %s not found", addrID)
|
||||
// unlockSync allows a new sync to start.
|
||||
func (user *User) unlockSync() {
|
||||
user.syncLock.Unlock()
|
||||
}
|
||||
|
||||
// contextWithStopCh returns a new context that is cancelled when the stop channel is closed or a value is sent to it.
|
||||
func contextWithStopCh(ctx context.Context, stopCh <-chan struct{}) (context.Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-stopCh:
|
||||
cancel()
|
||||
|
||||
case <-ctx.Done():
|
||||
// ...
|
||||
}
|
||||
}()
|
||||
|
||||
return ctx, cancel
|
||||
// waitSync waits for any ongoing sync to finish.
|
||||
func (user *User) waitSync() {
|
||||
user.syncLock.Wait()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user