mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-17 23:56:56 +00:00
GODT-1657: Stable sync (still needs more tests)
This commit is contained in:
@ -84,13 +84,13 @@ func run(c *cli.Context) error {
|
||||
|
||||
// Start CPU profile if requested.
|
||||
if c.Bool(flagCPUProfile) {
|
||||
p := profile.Start(profile.CPUProfile, profile.ProfilePath("cpu.pprof"))
|
||||
p := profile.Start(profile.CPUProfile, profile.ProfilePath("."))
|
||||
defer p.Stop()
|
||||
}
|
||||
|
||||
// Start memory profile if requested.
|
||||
if c.Bool(flagMemProfile) {
|
||||
p := profile.Start(profile.MemProfile, profile.MemProfileAllocs, profile.ProfilePath("mem.pprof"))
|
||||
p := profile.Start(profile.MemProfile, profile.MemProfileAllocs, profile.ProfilePath("."))
|
||||
defer p.Stop()
|
||||
}
|
||||
|
||||
|
||||
@ -69,6 +69,9 @@ type Bridge struct {
|
||||
|
||||
// errors contains errors encountered during startup.
|
||||
errors []error
|
||||
|
||||
// stopCh is used to stop ongoing goroutines when the bridge is closed.
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// New creates a new bridge.
|
||||
@ -153,6 +156,8 @@ func New(
|
||||
focusService: focusService,
|
||||
autostarter: autostarter,
|
||||
locator: locator,
|
||||
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
api.AddStatusObserver(func(status liteapi.Status) {
|
||||
@ -232,12 +237,8 @@ func (bridge *Bridge) GetErrors() []error {
|
||||
}
|
||||
|
||||
func (bridge *Bridge) Close(ctx context.Context) error {
|
||||
// Abort any ongoing syncs.
|
||||
for _, user := range bridge.users {
|
||||
if err := user.AbortSync(ctx); err != nil {
|
||||
return fmt.Errorf("failed to abort sync: %w", err)
|
||||
}
|
||||
}
|
||||
// Stop ongoing operations such as connectivity checks.
|
||||
close(bridge.stopCh)
|
||||
|
||||
// Close the IMAP server.
|
||||
if err := bridge.closeIMAP(ctx); err != nil {
|
||||
@ -251,7 +252,7 @@ func (bridge *Bridge) Close(ctx context.Context) error {
|
||||
|
||||
// Close all users.
|
||||
for _, user := range bridge.users {
|
||||
if err := user.Close(ctx); err != nil {
|
||||
if err := user.Close(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close user")
|
||||
}
|
||||
}
|
||||
@ -335,6 +336,9 @@ func (bridge *Bridge) onStatusDown() {
|
||||
case <-upCh:
|
||||
return
|
||||
|
||||
case <-bridge.stopCh:
|
||||
return
|
||||
|
||||
case <-time.After(backoff):
|
||||
if err := bridge.api.Ping(ctx); err != nil {
|
||||
logrus.WithError(err).Debug("Failed to ping API")
|
||||
|
||||
@ -25,20 +25,17 @@ import (
|
||||
"gitlab.protontech.ch/go/liteapi/server/backend"
|
||||
)
|
||||
|
||||
const (
|
||||
username = "username"
|
||||
)
|
||||
|
||||
var password = []byte("password")
|
||||
|
||||
var (
|
||||
username = "username"
|
||||
password = []byte("password")
|
||||
|
||||
v2_3_0 = semver.MustParse("2.3.0")
|
||||
v2_4_0 = semver.MustParse("2.4.0")
|
||||
)
|
||||
|
||||
func init() {
|
||||
user.DefaultEventPeriod = 100 * time.Millisecond
|
||||
user.DefaultEventJitter = 0
|
||||
user.EventPeriod = 100 * time.Millisecond
|
||||
user.EventJitter = 0
|
||||
backend.GenerateKey = tests.FastGenerateKey
|
||||
certs.GenerateCert = tests.FastGenerateCert
|
||||
}
|
||||
|
||||
@ -74,7 +74,7 @@ func getGluonDir(encVault *vault.Vault) (string, error) {
|
||||
|
||||
if empty {
|
||||
if err := encVault.ForUser(func(user *vault.User) error {
|
||||
return user.SetSync(false)
|
||||
return user.ClearSyncStatus()
|
||||
}); err != nil {
|
||||
return "", fmt.Errorf("failed to reset user sync status: %w", err)
|
||||
}
|
||||
|
||||
@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
@ -86,6 +85,10 @@ func (bridge *Bridge) LoginUser(
|
||||
return "", err
|
||||
}
|
||||
|
||||
if _, ok := bridge.users[auth.UserID]; ok {
|
||||
return "", ErrUserAlreadyLoggedIn
|
||||
}
|
||||
|
||||
if auth.TwoFA.Enabled == liteapi.TOTPEnabled {
|
||||
totp, err := getTOTP()
|
||||
if err != nil {
|
||||
@ -110,16 +113,26 @@ func (bridge *Bridge) LoginUser(
|
||||
keyPass = password
|
||||
}
|
||||
|
||||
apiUser, apiAddrs, userKR, addrKRs, saltedKeyPass, err := client.Unlock(ctx, keyPass)
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := bridge.addUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, auth.UID, auth.RefreshToken, saltedKeyPass); err != nil {
|
||||
salts, err := client.GetSalts(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return apiUser.ID, nil
|
||||
saltedKeyPass, err := salts.SaltForKey(keyPass, apiUser.Keys.Primary().ID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, saltedKeyPass); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return auth.UserID, nil
|
||||
}
|
||||
|
||||
// LogoutUser logs out the given user.
|
||||
@ -158,10 +171,6 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va
|
||||
return fmt.Errorf("address mode is already %q", mode)
|
||||
}
|
||||
|
||||
if err := user.AbortSync(ctx); err != nil {
|
||||
return fmt.Errorf("failed to abort sync: %w", err)
|
||||
}
|
||||
|
||||
for _, gluonID := range user.GetGluonIDs() {
|
||||
if err := bridge.imapServer.RemoveUser(ctx, gluonID, true); err != nil {
|
||||
return fmt.Errorf("failed to remove user from IMAP server: %w", err)
|
||||
@ -181,8 +190,6 @@ func (bridge *Bridge) SetAddressMode(ctx context.Context, userID string, mode va
|
||||
AddressMode: mode,
|
||||
})
|
||||
|
||||
user.DoSync(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -220,12 +227,12 @@ func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error {
|
||||
return fmt.Errorf("failed to create API client: %w", err)
|
||||
}
|
||||
|
||||
apiUser, apiAddrs, userKR, addrKRs, err := client.UnlockSalted(ctx, user.KeyPass())
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock user: %w", err)
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.addUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, auth.UID, auth.RefreshToken, user.KeyPass()); err != nil {
|
||||
if err := bridge.addUser(ctx, client, apiUser, auth.UID, auth.RefreshToken, user.KeyPass()); err != nil {
|
||||
return fmt.Errorf("failed to add user: %w", err)
|
||||
}
|
||||
|
||||
@ -241,27 +248,20 @@ func (bridge *Bridge) addUser(
|
||||
ctx context.Context,
|
||||
client *liteapi.Client,
|
||||
apiUser liteapi.User,
|
||||
apiAddrs []liteapi.Address,
|
||||
userKR *crypto.KeyRing,
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
authUID, authRef string,
|
||||
saltedKeyPass []byte,
|
||||
) error {
|
||||
if _, ok := bridge.users[apiUser.ID]; ok {
|
||||
return ErrUserAlreadyLoggedIn
|
||||
}
|
||||
|
||||
var user *user.User
|
||||
|
||||
if slices.Contains(bridge.vault.GetUserIDs(), apiUser.ID) {
|
||||
existingUser, err := bridge.addExistingUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass)
|
||||
existingUser, err := bridge.addExistingUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add existing user: %w", err)
|
||||
}
|
||||
|
||||
user = existingUser
|
||||
} else {
|
||||
newUser, err := bridge.addNewUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass)
|
||||
newUser, err := bridge.addNewUser(ctx, client, apiUser, authUID, authRef, saltedKeyPass)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add new user: %w", err)
|
||||
}
|
||||
@ -269,11 +269,16 @@ func (bridge *Bridge) addUser(
|
||||
user = newUser
|
||||
}
|
||||
|
||||
// Connects the user's address(es) to gluon.
|
||||
// Connect the user's address(es) to gluon.
|
||||
if err := bridge.addIMAPUser(ctx, user); err != nil {
|
||||
return fmt.Errorf("failed to add IMAP user: %w", err)
|
||||
}
|
||||
|
||||
// Connect the user's address(es) to the SMTP server.
|
||||
if err := bridge.smtpBackend.addUser(user); err != nil {
|
||||
return fmt.Errorf("failed to add user to SMTP backend: %w", err)
|
||||
}
|
||||
|
||||
// Handle events coming from the user before forwarding them to the bridge.
|
||||
// For example, if the user's addresses change, we need to update them in gluon.
|
||||
go func() {
|
||||
@ -299,11 +304,6 @@ func (bridge *Bridge) addUser(
|
||||
return nil
|
||||
})
|
||||
|
||||
// TODO: Replace this with proper sync manager.
|
||||
if !user.HasSync() {
|
||||
user.DoSync(ctx)
|
||||
}
|
||||
|
||||
bridge.publish(events.UserLoggedIn{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
@ -315,9 +315,6 @@ func (bridge *Bridge) addNewUser(
|
||||
ctx context.Context,
|
||||
client *liteapi.Client,
|
||||
apiUser liteapi.User,
|
||||
apiAddrs []liteapi.Address,
|
||||
userKR *crypto.KeyRing,
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
authUID, authRef string,
|
||||
saltedKeyPass []byte,
|
||||
) (*user.User, error) {
|
||||
@ -326,15 +323,11 @@ func (bridge *Bridge) addNewUser(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := user.New(ctx, vaultUser, client, apiUser, apiAddrs, userKR, addrKRs)
|
||||
user, err := user.New(ctx, vaultUser, client, apiUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := bridge.smtpBackend.addUser(user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bridge.users[apiUser.ID] = user
|
||||
|
||||
return user, nil
|
||||
@ -344,9 +337,6 @@ func (bridge *Bridge) addExistingUser(
|
||||
ctx context.Context,
|
||||
client *liteapi.Client,
|
||||
apiUser liteapi.User,
|
||||
apiAddrs []liteapi.Address,
|
||||
userKR *crypto.KeyRing,
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
authUID, authRef string,
|
||||
saltedKeyPass []byte,
|
||||
) (*user.User, error) {
|
||||
@ -363,15 +353,11 @@ func (bridge *Bridge) addExistingUser(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := user.New(ctx, vaultUser, client, apiUser, apiAddrs, userKR, addrKRs)
|
||||
user, err := user.New(ctx, vaultUser, client, apiUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := bridge.smtpBackend.addUser(user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bridge.users[apiUser.ID] = user
|
||||
|
||||
return user, nil
|
||||
@ -386,11 +372,6 @@ func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, wi
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
// TODO: The sync should be canceled by the sync manager.
|
||||
if err := user.AbortSync(ctx); err != nil {
|
||||
return fmt.Errorf("failed to abort user sync: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.smtpBackend.removeUser(user); err != nil {
|
||||
return fmt.Errorf("failed to remove SMTP user: %w", err)
|
||||
}
|
||||
@ -407,7 +388,7 @@ func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, wi
|
||||
}
|
||||
}
|
||||
|
||||
if err := user.Close(ctx); err != nil {
|
||||
if err := user.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close user: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@ -1,9 +0,0 @@
|
||||
package events
|
||||
|
||||
type MessageSent struct {
|
||||
eventBase
|
||||
|
||||
UserID string
|
||||
AddressID string
|
||||
MessageID string
|
||||
}
|
||||
@ -22,3 +22,10 @@ type SyncFinished struct {
|
||||
|
||||
UserID string
|
||||
}
|
||||
|
||||
type SyncFailed struct {
|
||||
eventBase
|
||||
|
||||
UserID string
|
||||
Err error
|
||||
}
|
||||
|
||||
@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
// ErrJobCancelled indicates the job was cancelled.
|
||||
var ErrJobCancelled = errors.New("Job cancelled by surrounding context")
|
||||
var ErrJobCancelled = errors.New("job cancelled by surrounding context")
|
||||
|
||||
// Pool is a worker pool that handles input of type In and returns results of type Out.
|
||||
type Pool[In comparable, Out any] struct {
|
||||
|
||||
97
internal/safe/map.go
Normal file
97
internal/safe/map.go
Normal file
@ -0,0 +1,97 @@
|
||||
package safe
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
type Map[Key comparable, Val any] struct {
|
||||
data map[Key]Val
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMap[Key comparable, Val any](from map[Key]Val) *Map[Key, Val] {
|
||||
m := &Map[Key, Val]{
|
||||
data: make(map[Key]Val),
|
||||
}
|
||||
|
||||
for key, val := range from {
|
||||
m.Set(key, val)
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Map[Key, Val]) Get(key Key, fn func(val Val)) bool {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
val, ok := m.data[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
fn(val)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Map[Key, Val]) GetErr(key Key, fn func(val Val) error) (bool, error) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
val, ok := m.data[key]
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, fn(val)
|
||||
}
|
||||
|
||||
func (m *Map[Key, Val]) Set(key Key, val Val) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
m.data[key] = val
|
||||
}
|
||||
|
||||
func (m *Map[Key, Val]) Keys(fn func(keys []Key)) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
fn(maps.Keys(m.data))
|
||||
}
|
||||
|
||||
func (m *Map[Key, Val]) Values(fn func(vals []Val)) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
fn(maps.Values(m.data))
|
||||
}
|
||||
|
||||
func GetMap[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(val Val) Ret) (Ret, bool) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
val, ok := m.data[key]
|
||||
if !ok {
|
||||
return *new(Ret), false
|
||||
}
|
||||
|
||||
return fn(val), true
|
||||
}
|
||||
|
||||
func GetMapErr[Key comparable, Val, Ret any](m *Map[Key, Val], key Key, fn func(val Val) (Ret, error)) (Ret, bool, error) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
val, ok := m.data[key]
|
||||
if !ok {
|
||||
return *new(Ret), false, nil
|
||||
}
|
||||
|
||||
ret, err := fn(val)
|
||||
|
||||
return ret, true, err
|
||||
}
|
||||
53
internal/safe/slice.go
Normal file
53
internal/safe/slice.go
Normal file
@ -0,0 +1,53 @@
|
||||
package safe
|
||||
|
||||
import "sync"
|
||||
|
||||
type Slice[Val any] struct {
|
||||
data []Val
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewSlice[Val any](from []Val) *Slice[Val] {
|
||||
s := &Slice[Val]{
|
||||
data: make([]Val, len(from)),
|
||||
}
|
||||
|
||||
copy(s.data, from)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Slice[Val]) Get(fn func(data []Val)) {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
fn(s.data)
|
||||
}
|
||||
|
||||
func (s *Slice[Val]) GetErr(fn func(data []Val) error) error {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return fn(s.data)
|
||||
}
|
||||
|
||||
func (s *Slice[Val]) Set(data []Val) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
s.data = data
|
||||
}
|
||||
|
||||
func GetSlice[Val, Ret any](s *Slice[Val], fn func(data []Val) Ret) Ret {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return fn(s.data)
|
||||
}
|
||||
|
||||
func GetSliceErr[Val, Ret any](s *Slice[Val], fn func(data []Val) (Ret, error)) (Ret, error) {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return fn(s.data)
|
||||
}
|
||||
49
internal/safe/type.go
Normal file
49
internal/safe/type.go
Normal file
@ -0,0 +1,49 @@
|
||||
package safe
|
||||
|
||||
import "sync"
|
||||
|
||||
type Type[T any] struct {
|
||||
data T
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewType[T any](data T) *Type[T] {
|
||||
return &Type[T]{
|
||||
data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Type[T]) Get(fn func(data T)) {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
fn(s.data)
|
||||
}
|
||||
|
||||
func (s *Type[T]) GetErr(fn func(data T) error) error {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return fn(s.data)
|
||||
}
|
||||
|
||||
func (s *Type[T]) Set(data T) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
s.data = data
|
||||
}
|
||||
|
||||
func GetType[T, Ret any](s *Type[T], fn func(data T) Ret) Ret {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return fn(s.data)
|
||||
}
|
||||
|
||||
func GetTypeErr[T, Ret any](s *Type[T], fn func(data T) (Ret, error)) (Ret, error) {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
return fn(s.data)
|
||||
}
|
||||
@ -1,50 +0,0 @@
|
||||
package user
|
||||
|
||||
import "gitlab.protontech.ch/go/liteapi"
|
||||
|
||||
type addrList struct {
|
||||
apiAddrs ordMap[string, string, liteapi.Address]
|
||||
}
|
||||
|
||||
func newAddrList(apiAddrs []liteapi.Address) *addrList {
|
||||
return &addrList{
|
||||
apiAddrs: newOrdMap(
|
||||
func(addr liteapi.Address) string { return addr.ID },
|
||||
func(addr liteapi.Address) string { return addr.Email },
|
||||
func(a, b liteapi.Address) bool { return a.Order < b.Order },
|
||||
apiAddrs...,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (list *addrList) insert(address liteapi.Address) {
|
||||
list.apiAddrs.insert(address)
|
||||
}
|
||||
|
||||
func (list *addrList) delete(addrID string) string {
|
||||
return list.apiAddrs.delete(addrID)
|
||||
}
|
||||
|
||||
func (list *addrList) primary() string {
|
||||
return list.apiAddrs.keys()[0]
|
||||
}
|
||||
|
||||
func (list *addrList) addrIDs() []string {
|
||||
return list.apiAddrs.keys()
|
||||
}
|
||||
|
||||
func (list *addrList) addrID(email string) (string, bool) {
|
||||
return list.apiAddrs.getKey(email)
|
||||
}
|
||||
|
||||
func (list *addrList) emails() []string {
|
||||
return list.apiAddrs.values()
|
||||
}
|
||||
|
||||
func (list *addrList) email(addrID string) (string, bool) {
|
||||
return list.apiAddrs.getVal(addrID)
|
||||
}
|
||||
|
||||
func (list *addrList) addrMap() map[string]string {
|
||||
return list.apiAddrs.toMap()
|
||||
}
|
||||
@ -1,95 +0,0 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/pool"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type request struct {
|
||||
messageID string
|
||||
addressID string
|
||||
addrKR *crypto.KeyRing
|
||||
}
|
||||
|
||||
type fetcher interface {
|
||||
GetMessage(context.Context, string) (liteapi.Message, error)
|
||||
GetAttachment(context.Context, string) ([]byte, error)
|
||||
}
|
||||
|
||||
func newBuilder(f fetcher, msgWorkers, attWorkers int) *pool.Pool[request, *imap.MessageCreated] {
|
||||
attPool := pool.New(attWorkers, func(ctx context.Context, attID string) ([]byte, error) {
|
||||
return f.GetAttachment(ctx, attID)
|
||||
})
|
||||
|
||||
msgPool := pool.New(msgWorkers, func(ctx context.Context, req request) (*imap.MessageCreated, error) {
|
||||
msg, err := f.GetMessage(ctx, req.messageID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var attIDs []string
|
||||
|
||||
for _, att := range msg.Attachments {
|
||||
attIDs = append(attIDs, att.ID)
|
||||
}
|
||||
|
||||
attData, err := attPool.ProcessAll(ctx, attIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
literal, err := message.BuildRFC822(req.addrKR, msg, attData, message.JobOptions{
|
||||
IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead.
|
||||
SanitizeDate: true, // Whether to replace all dates before 1970 with RFC822's birthdate.
|
||||
AddInternalID: true, // Whether to include MessageID as X-Pm-Internal-Id.
|
||||
AddExternalID: true, // Whether to include ExternalID as X-Pm-External-Id.
|
||||
AddMessageDate: true, // Whether to include message time as X-Pm-Date.
|
||||
AddMessageIDReference: true, // Whether to include the MessageID in References.
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newMessageCreatedUpdate(msg, literal)
|
||||
})
|
||||
|
||||
return msgPool
|
||||
}
|
||||
|
||||
func newMessageCreatedUpdate(message liteapi.Message, literal []byte) (*imap.MessageCreated, error) {
|
||||
parsedMessage, err := imap.NewParsedMessage(literal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flags := imap.NewFlagSet()
|
||||
|
||||
if !message.Unread {
|
||||
flags = flags.Add(imap.FlagSeen)
|
||||
}
|
||||
|
||||
if slices.Contains(message.LabelIDs, liteapi.StarredLabel) {
|
||||
flags = flags.Add(imap.FlagFlagged)
|
||||
}
|
||||
|
||||
imapMessage := imap.Message{
|
||||
ID: imap.MessageID(message.ID),
|
||||
Flags: flags,
|
||||
Date: time.Unix(message.Time, 0),
|
||||
}
|
||||
|
||||
return &imap.MessageCreated{
|
||||
Message: imapMessage,
|
||||
Literal: literal,
|
||||
LabelIDs: mapTo[string, imap.LabelID](xslices.Filter(message.LabelIDs, wantLabelID)),
|
||||
ParsedMessage: parsedMessage,
|
||||
}, nil
|
||||
}
|
||||
@ -7,6 +7,7 @@ import (
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
@ -54,7 +55,7 @@ func (user *User) handleUserEvent(ctx context.Context, userEvent liteapi.User) e
|
||||
return err
|
||||
}
|
||||
|
||||
user.apiUser = userEvent
|
||||
user.apiUser.Set(userEvent)
|
||||
|
||||
user.userKR = userKR
|
||||
|
||||
@ -96,24 +97,29 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.Ad
|
||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||
}
|
||||
|
||||
user.apiAddrs.insert(event.Address)
|
||||
apiAddrs, err := user.client.GetAddresses(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get addresses: %w", err)
|
||||
}
|
||||
|
||||
user.apiAddrs.Set(apiAddrs)
|
||||
|
||||
user.addrKRs[event.Address.ID] = addrKR
|
||||
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
if err := user.syncLabels(ctx, event.Address.ID); err != nil {
|
||||
return fmt.Errorf("failed to sync labels to new address: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressCreated{
|
||||
UserID: user.ID(),
|
||||
AddressID: event.Address.ID,
|
||||
Email: event.Address.Email,
|
||||
})
|
||||
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
if err := syncLabels(ctx, user.client, user.updateCh[event.Address.ID]); err != nil {
|
||||
return fmt.Errorf("failed to sync labels to new address: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -123,7 +129,12 @@ func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.Ad
|
||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||
}
|
||||
|
||||
user.apiAddrs.insert(event.Address)
|
||||
apiAddrs, err := user.client.GetAddresses(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get addresses: %w", err)
|
||||
}
|
||||
|
||||
user.apiAddrs.Set(apiAddrs)
|
||||
|
||||
user.addrKRs[event.Address.ID] = addrKR
|
||||
|
||||
@ -137,9 +148,23 @@ func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.Ad
|
||||
}
|
||||
|
||||
func (user *User) handleDeleteAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
|
||||
email := user.apiAddrs.delete(event.ID)
|
||||
email, err := safe.GetSliceErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) {
|
||||
return getAddrEmail(apiAddrs, event.ID)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get address email: %w", err)
|
||||
}
|
||||
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
apiAddrs, err := user.client.GetAddresses(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get addresses: %w", err)
|
||||
}
|
||||
|
||||
user.apiAddrs.Set(apiAddrs)
|
||||
|
||||
delete(user.addrKRs, event.ID)
|
||||
|
||||
if len(user.updateCh) > 1 {
|
||||
user.updateCh[event.ID].Close()
|
||||
delete(user.updateCh, event.ID)
|
||||
}
|
||||
@ -155,7 +180,7 @@ func (user *User) handleDeleteAddressEvent(ctx context.Context, event liteapi.Ad
|
||||
|
||||
// handleMailSettingsEvent handles the given mail settings event.
|
||||
func (user *User) handleMailSettingsEvent(ctx context.Context, mailSettingsEvent liteapi.MailSettings) error {
|
||||
user.settings = mailSettingsEvent
|
||||
user.settings.Set(mailSettingsEvent)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -234,24 +259,18 @@ func (user *User) handleMessageEvents(ctx context.Context, messageEvents []litea
|
||||
}
|
||||
|
||||
func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
|
||||
var addressID string
|
||||
|
||||
if user.GetAddressMode() == vault.CombinedMode {
|
||||
addressID = user.apiAddrs.primary()
|
||||
} else {
|
||||
addressID = event.Message.AddressID
|
||||
}
|
||||
|
||||
message, err := user.builder.ProcessOne(ctx, request{
|
||||
messageID: event.ID,
|
||||
addressID: addressID,
|
||||
addrKR: user.addrKRs[event.Message.AddressID],
|
||||
})
|
||||
buildRes, err := user.buildRFC822(ctx, event.Message)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to build RFC822: %w", err)
|
||||
}
|
||||
|
||||
user.updateCh[addressID].Enqueue(imap.NewMessagesCreated(message))
|
||||
if len(user.updateCh) > 1 {
|
||||
user.updateCh[buildRes.addressID].Enqueue(imap.NewMessagesCreated(buildRes.update))
|
||||
} else {
|
||||
user.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
|
||||
user.updateCh[apiAddrs[0].ID].Enqueue(imap.NewMessagesCreated(buildRes.update))
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -264,10 +283,12 @@ func (user *User) handleUpdateMessageEvent(ctx context.Context, event liteapi.Me
|
||||
event.Message.Starred(),
|
||||
)
|
||||
|
||||
if user.GetAddressMode() == vault.CombinedMode {
|
||||
user.updateCh[user.apiAddrs.primary()].Enqueue(update)
|
||||
} else {
|
||||
if len(user.updateCh) > 1 {
|
||||
user.updateCh[event.Message.AddressID].Enqueue(update)
|
||||
} else {
|
||||
user.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
|
||||
user.updateCh[apiAddrs[0].ID].Enqueue(update)
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@ -1,76 +0,0 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
)
|
||||
|
||||
type flusher struct {
|
||||
userID string
|
||||
updateCh *queue.QueuedChannel[imap.Update]
|
||||
eventCh *queue.QueuedChannel[events.Event]
|
||||
|
||||
updates []*imap.MessageCreated
|
||||
maxChunkSize int
|
||||
curChunkSize int
|
||||
|
||||
count int
|
||||
total int
|
||||
start time.Time
|
||||
|
||||
pushLock sync.Mutex
|
||||
}
|
||||
|
||||
func newFlusher(
|
||||
userID string,
|
||||
updateCh *queue.QueuedChannel[imap.Update],
|
||||
eventCh *queue.QueuedChannel[events.Event],
|
||||
total, maxChunkSize int,
|
||||
) *flusher {
|
||||
return &flusher{
|
||||
userID: userID,
|
||||
updateCh: updateCh,
|
||||
eventCh: eventCh,
|
||||
|
||||
maxChunkSize: maxChunkSize,
|
||||
|
||||
total: total,
|
||||
start: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *flusher) push(update *imap.MessageCreated) {
|
||||
f.pushLock.Lock()
|
||||
defer f.pushLock.Unlock()
|
||||
|
||||
f.updates = append(f.updates, update)
|
||||
|
||||
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize {
|
||||
f.flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *flusher) flush() {
|
||||
if len(f.updates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
f.count += len(f.updates)
|
||||
f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...))
|
||||
f.eventCh.Enqueue(newSyncProgress(f.userID, f.count, f.total, f.start))
|
||||
f.updates = nil
|
||||
f.curChunkSize = 0
|
||||
}
|
||||
|
||||
func newSyncProgress(userID string, count, total int, start time.Time) events.SyncProgress {
|
||||
return events.SyncProgress{
|
||||
UserID: userID,
|
||||
Progress: float64(count) / float64(total),
|
||||
Elapsed: time.Since(start),
|
||||
Remaining: time.Since(start) * time.Duration(total-count) / time.Duration(count),
|
||||
}
|
||||
}
|
||||
@ -1,104 +0,0 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type ordMap[Key, Val comparable, Data any] struct {
|
||||
data map[Key]Data
|
||||
order []Key
|
||||
|
||||
toKey func(Data) Key
|
||||
toVal func(Data) Val
|
||||
isLess func(Data, Data) bool
|
||||
}
|
||||
|
||||
func newOrdMap[Key, Val comparable, Data any](
|
||||
key func(Data) Key,
|
||||
value func(Data) Val,
|
||||
less func(Data, Data) bool,
|
||||
data ...Data,
|
||||
) ordMap[Key, Val, Data] {
|
||||
m := ordMap[Key, Val, Data]{
|
||||
data: make(map[Key]Data),
|
||||
|
||||
toKey: key,
|
||||
toVal: value,
|
||||
isLess: less,
|
||||
}
|
||||
|
||||
for _, d := range data {
|
||||
m.insert(d)
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (set *ordMap[Key, Val, Data]) insert(data Data) {
|
||||
if _, ok := set.data[set.toKey(data)]; ok {
|
||||
set.delete(set.toKey(data))
|
||||
}
|
||||
|
||||
set.data[set.toKey(data)] = data
|
||||
|
||||
set.order = append(set.order, set.toKey(data))
|
||||
|
||||
slices.SortFunc(set.order, func(a, b Key) bool {
|
||||
return set.isLess(set.data[a], set.data[b])
|
||||
})
|
||||
}
|
||||
|
||||
func (set *ordMap[Key, Val, Data]) delete(key Key) Val {
|
||||
data, ok := set.data[key]
|
||||
if !ok {
|
||||
return *new(Val)
|
||||
}
|
||||
|
||||
delete(set.data, key)
|
||||
|
||||
set.order = xslices.Filter(set.order, func(otherKey Key) bool {
|
||||
return otherKey != key
|
||||
})
|
||||
|
||||
return set.toVal(data)
|
||||
}
|
||||
|
||||
func (set *ordMap[Key, Val, Data]) getVal(key Key) (Val, bool) {
|
||||
data, ok := set.data[key]
|
||||
if !ok {
|
||||
return *new(Val), false
|
||||
}
|
||||
|
||||
return set.toVal(data), true
|
||||
}
|
||||
|
||||
func (set *ordMap[Key, Val, Data]) getKey(wantVal Val) (Key, bool) {
|
||||
for key, data := range set.data {
|
||||
if set.toVal(data) == wantVal {
|
||||
return key, true
|
||||
}
|
||||
}
|
||||
|
||||
return *new(Key), false
|
||||
}
|
||||
|
||||
func (set *ordMap[Key, Val, Data]) keys() []Key {
|
||||
return set.order
|
||||
}
|
||||
|
||||
func (set *ordMap[Key, Val, Data]) values() []Val {
|
||||
return xslices.Map(set.order, func(key Key) Val {
|
||||
return set.toVal(set.data[key])
|
||||
})
|
||||
}
|
||||
|
||||
func (set *ordMap[Key, Val, Data]) toMap() map[Key]Val {
|
||||
m := make(map[Key]Val)
|
||||
|
||||
for _, key := range set.order {
|
||||
m[key] = set.toVal(set.data[key])
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
@ -1,48 +0,0 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMap(t *testing.T) {
|
||||
type Key int
|
||||
|
||||
type Value string
|
||||
|
||||
type Data struct {
|
||||
key Key
|
||||
value Value
|
||||
}
|
||||
|
||||
m := newOrdMap(
|
||||
func(d Data) Key { return d.key },
|
||||
func(d Data) Value { return d.value },
|
||||
func(a, b Data) bool { return a.key < b.key },
|
||||
Data{key: 1, value: "a"},
|
||||
Data{key: 2, value: "b"},
|
||||
Data{key: 3, value: "c"},
|
||||
)
|
||||
|
||||
// Insert some new data.
|
||||
m.insert(Data{key: 4, value: "d"})
|
||||
m.insert(Data{key: 5, value: "e"})
|
||||
|
||||
// Delete some data.
|
||||
require.Equal(t, Value("c"), m.delete(3))
|
||||
require.Equal(t, Value("a"), m.delete(1))
|
||||
require.Equal(t, Value("e"), m.delete(5))
|
||||
|
||||
// Check the remaining keys and values are correct.
|
||||
require.Equal(t, []Key{2, 4}, m.keys())
|
||||
require.Equal(t, []Value{"b", "d"}, m.values())
|
||||
|
||||
// Overwrite some data.
|
||||
m.insert(Data{key: 2, value: "two"})
|
||||
m.insert(Data{key: 4, value: "four"})
|
||||
|
||||
// Check the remaining keys and values are correct.
|
||||
require.Equal(t, []Key{2, 4}, m.keys())
|
||||
require.Equal(t, []Value{"two", "four"}, m.values())
|
||||
}
|
||||
@ -6,14 +6,14 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/mail"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/ProtonMail/go-rfc5322"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/message/parser"
|
||||
@ -22,37 +22,14 @@ import (
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type smtpSession struct {
|
||||
// client is the user's API client.
|
||||
client *liteapi.Client
|
||||
*User
|
||||
|
||||
// eventCh allows the session to publish events.
|
||||
eventCh *queue.QueuedChannel[events.Event]
|
||||
|
||||
// userID is the user's ID.
|
||||
userID string
|
||||
|
||||
// addrID holds the ID of the address that is currently being used.
|
||||
addrID string
|
||||
|
||||
// addrMode holds the address mode that is currently being used.
|
||||
addrMode vault.AddressMode
|
||||
|
||||
// emails holds all email addresses associated with the user, by address ID.
|
||||
emails map[string]string
|
||||
|
||||
// settings holds the mail settings for the user.
|
||||
settings liteapi.MailSettings
|
||||
|
||||
// userKR holds the user's keyring.
|
||||
userKR *crypto.KeyRing
|
||||
|
||||
// addrKRs holds the keyrings for each address.
|
||||
addrKRs map[string]*crypto.KeyRing
|
||||
// authID holds the ID of the address that the SMTP client authenticated with to send the message.
|
||||
authID string
|
||||
|
||||
// fromAddrID is the ID of the current sending address (taken from the return path).
|
||||
fromAddrID string
|
||||
@ -61,30 +38,18 @@ type smtpSession struct {
|
||||
to []string
|
||||
}
|
||||
|
||||
func newSMTPSession(
|
||||
client *liteapi.Client,
|
||||
eventCh *queue.QueuedChannel[events.Event],
|
||||
userID, addrID string,
|
||||
addrMode vault.AddressMode,
|
||||
emails map[string]string,
|
||||
settings liteapi.MailSettings,
|
||||
userKR *crypto.KeyRing,
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
) *smtpSession {
|
||||
return &smtpSession{
|
||||
client: client,
|
||||
eventCh: eventCh,
|
||||
|
||||
userID: userID,
|
||||
addrID: addrID,
|
||||
addrMode: addrMode,
|
||||
|
||||
emails: emails,
|
||||
settings: settings,
|
||||
|
||||
userKR: userKR,
|
||||
addrKRs: addrKRs,
|
||||
func newSMTPSession(user *User, email string) (*smtpSession, error) {
|
||||
authID, err := safe.GetSliceErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) {
|
||||
return getAddrID(apiAddrs, email)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get address ID: %w", err)
|
||||
}
|
||||
|
||||
return &smtpSession{
|
||||
User: user,
|
||||
authID: authID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Discard currently processed message.
|
||||
@ -109,30 +74,35 @@ func (session *smtpSession) Logout() error {
|
||||
func (session *smtpSession) Mail(from string, opts smtp.MailOptions) error {
|
||||
logrus.Info("SMTP session mail")
|
||||
|
||||
switch {
|
||||
case opts.RequireTLS:
|
||||
return ErrNotImplemented
|
||||
return session.apiAddrs.GetErr(func(apiAddrs []liteapi.Address) error {
|
||||
|
||||
case opts.UTF8:
|
||||
return ErrNotImplemented
|
||||
|
||||
case opts.Auth != nil:
|
||||
if *opts.Auth != "" && *opts.Auth != session.emails[session.addrID] {
|
||||
switch {
|
||||
case opts.RequireTLS:
|
||||
return ErrNotImplemented
|
||||
|
||||
case opts.UTF8:
|
||||
return ErrNotImplemented
|
||||
|
||||
case opts.Auth != nil:
|
||||
email, err := getAddrEmail(apiAddrs, session.authID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid auth address: %w", err)
|
||||
}
|
||||
|
||||
if *opts.Auth != "" && *opts.Auth != email {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for addrID, email := range session.emails {
|
||||
if strings.EqualFold(from, email) {
|
||||
session.fromAddrID = addrID
|
||||
fromAddrID, err := getAddrID(apiAddrs, from)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid return path: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if session.fromAddrID == "" {
|
||||
return ErrInvalidReturnPath
|
||||
}
|
||||
session.fromAddrID = fromAddrID
|
||||
|
||||
return nil
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Add recipient for currently processed message.
|
||||
@ -168,13 +138,15 @@ func (session *smtpSession) Data(r io.Reader) error {
|
||||
}
|
||||
|
||||
// If the message contains a sender, use it instead of the one from the return path.
|
||||
if sender, ok := getMessageSender(parser); ok {
|
||||
for addrID, email := range session.emails {
|
||||
if strings.EqualFold(email, sanitizeEmail(sender)) {
|
||||
session.fromAddrID = addrID
|
||||
session.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
|
||||
if sender, ok := getMessageSender(parser); ok {
|
||||
for _, addr := range apiAddrs {
|
||||
if strings.EqualFold(addr.Email, sanitizeEmail(sender)) {
|
||||
session.fromAddrID = addr.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
addrKR, ok := session.addrKRs[session.fromAddrID]
|
||||
if !ok {
|
||||
@ -186,28 +158,38 @@ func (session *smtpSession) Data(r io.Reader) error {
|
||||
return fmt.Errorf("failed to get first key: %w", err)
|
||||
}
|
||||
|
||||
message, err := sendWithKey(
|
||||
session.client,
|
||||
session.addrID,
|
||||
session.addrMode,
|
||||
session.userKR,
|
||||
firstAddrKR,
|
||||
session.settings,
|
||||
sanitizeEmail(session.emails[session.fromAddrID]),
|
||||
session.to,
|
||||
maps.Values(session.emails),
|
||||
parser,
|
||||
)
|
||||
from, err := safe.GetSliceErr(session.apiAddrs, func(apiAddrs []liteapi.Address) (string, error) {
|
||||
email, err := getAddrEmail(apiAddrs, session.fromAddrID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get address email: %w", err)
|
||||
}
|
||||
|
||||
return sanitizeEmail(email), nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get address email: %w", err)
|
||||
}
|
||||
|
||||
message, err := safe.GetSliceErr(session.apiAddrs, func(apiAddrs []liteapi.Address) (liteapi.Message, error) {
|
||||
return safe.GetTypeErr(session.settings, func(settings liteapi.MailSettings) (liteapi.Message, error) {
|
||||
return sendWithKey(
|
||||
session.client,
|
||||
session.authID,
|
||||
session.vault.AddressMode(),
|
||||
apiAddrs,
|
||||
settings,
|
||||
session.userKR,
|
||||
firstAddrKR,
|
||||
parser,
|
||||
from,
|
||||
session.to,
|
||||
)
|
||||
})
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send message: %w", err)
|
||||
}
|
||||
|
||||
session.eventCh.Enqueue(events.MessageSent{
|
||||
UserID: session.userID,
|
||||
AddressID: session.addrID,
|
||||
MessageID: message.ID,
|
||||
})
|
||||
|
||||
logrus.WithField("messageID", message.ID).Info("Message sent")
|
||||
|
||||
return nil
|
||||
@ -216,13 +198,14 @@ func (session *smtpSession) Data(r io.Reader) error {
|
||||
// sendWithKey sends the message with the given address key.
|
||||
func sendWithKey(
|
||||
client *liteapi.Client,
|
||||
addrID string,
|
||||
authAddrID string,
|
||||
addrMode vault.AddressMode,
|
||||
userKR, addrKR *crypto.KeyRing,
|
||||
apiAddrs []liteapi.Address,
|
||||
settings liteapi.MailSettings,
|
||||
from string,
|
||||
to, emails []string,
|
||||
userKR, addrKR *crypto.KeyRing,
|
||||
parser *parser.Parser,
|
||||
from string,
|
||||
to []string,
|
||||
) (liteapi.Message, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
@ -246,11 +229,11 @@ func sendWithKey(
|
||||
return liteapi.Message{}, fmt.Errorf("failed to parse message: %w", err)
|
||||
}
|
||||
|
||||
if err := sanitizeParsedMessage(&message, from, to, emails); err != nil {
|
||||
if err := sanitizeParsedMessage(&message, apiAddrs, from, to); err != nil {
|
||||
return liteapi.Message{}, fmt.Errorf("failed to sanitize message: %w", err)
|
||||
}
|
||||
|
||||
parentID, err := getParentID(ctx, client, addrID, addrMode, message.References)
|
||||
parentID, err := getParentID(ctx, client, authAddrID, addrMode, message.References)
|
||||
if err != nil {
|
||||
return liteapi.Message{}, fmt.Errorf("failed to get parent ID: %w", err)
|
||||
}
|
||||
@ -278,7 +261,7 @@ func sendWithKey(
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func sanitizeParsedMessage(message *message.Message, from string, to, emails []string) error {
|
||||
func sanitizeParsedMessage(message *message.Message, apiAddrs []liteapi.Address, from string, to []string) error {
|
||||
// Check sender: set the sender in the parsed message if it's missing.
|
||||
if message.Sender == nil {
|
||||
message.Sender = &mail.Address{Address: from}
|
||||
@ -287,12 +270,12 @@ func sanitizeParsedMessage(message *message.Message, from string, to, emails []s
|
||||
}
|
||||
|
||||
// Check that the sending address is owned by the user, and if so, properly capitalize it.
|
||||
if idx := xslices.IndexFunc(emails, func(email string) bool {
|
||||
return strings.EqualFold(email, sanitizeEmail(message.Sender.Address))
|
||||
if idx := xslices.IndexFunc(apiAddrs, func(addr liteapi.Address) bool {
|
||||
return strings.EqualFold(addr.Email, sanitizeEmail(message.Sender.Address))
|
||||
}); idx < 0 {
|
||||
return fmt.Errorf("address %q is not owned by user", message.Sender.Address)
|
||||
} else {
|
||||
message.Sender.Address = constructEmail(message.Sender.Address, emails[idx])
|
||||
message.Sender.Address = constructEmail(message.Sender.Address, apiAddrs[idx].Email)
|
||||
}
|
||||
|
||||
// Check ToList: ensure that ToList only contains addresses we actually plan to send to.
|
||||
@ -313,7 +296,7 @@ func sanitizeParsedMessage(message *message.Message, from string, to, emails []s
|
||||
func getParentID(
|
||||
ctx context.Context,
|
||||
client *liteapi.Client,
|
||||
addrID string,
|
||||
authAddrID string,
|
||||
addrMode vault.AddressMode,
|
||||
references []string,
|
||||
) (string, error) {
|
||||
@ -334,12 +317,12 @@ func getParentID(
|
||||
|
||||
// Try to find a parent ID in the internal references.
|
||||
for _, internal := range internal {
|
||||
filter := map[string][]string{
|
||||
filter := url.Values{
|
||||
"ID": {internal},
|
||||
}
|
||||
|
||||
if addrMode == vault.SplitMode {
|
||||
filter["AddressID"] = []string{addrID}
|
||||
filter["AddressID"] = []string{authAddrID}
|
||||
}
|
||||
|
||||
metadata, err := client.GetAllMessageMetadata(ctx, filter)
|
||||
@ -359,12 +342,12 @@ func getParentID(
|
||||
// If no parent was found, try to find it in the last external reference.
|
||||
// There can be multiple messages with the same external ID; in this case, we don't pick any parent.
|
||||
if parentID == "" && len(external) > 0 {
|
||||
filter := map[string][]string{
|
||||
filter := url.Values{
|
||||
"ExternalID": {external[len(external)-1]},
|
||||
}
|
||||
|
||||
if addrMode == vault.SplitMode {
|
||||
filter["AddressID"] = []string{addrID}
|
||||
filter["AddressID"] = []string{authAddrID}
|
||||
}
|
||||
|
||||
metadata, err := client.GetAllMessageMetadata(ctx, filter)
|
||||
|
||||
@ -2,131 +2,174 @@ package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/bradenaw/juniper/iterator"
|
||||
"github.com/bradenaw/juniper/parallel"
|
||||
"github.com/bradenaw/juniper/stream"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/google/uuid"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
const chunkSize = 1 << 20
|
||||
const (
|
||||
maxUpdateSize = 1 << 25
|
||||
maxBatchSize = 1 << 8
|
||||
)
|
||||
|
||||
func (user *User) syncLabels(ctx context.Context, addrIDs ...string) error {
|
||||
func (user *User) sync(ctx context.Context) error {
|
||||
if !user.vault.SyncStatus().HasLabels {
|
||||
if err := syncLabels(ctx, user.client, maps.Values(user.updateCh)...); err != nil {
|
||||
return fmt.Errorf("failed to sync labels: %w", err)
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasLabels(true); err != nil {
|
||||
return fmt.Errorf("failed to set has labels: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !user.vault.SyncStatus().HasMessages {
|
||||
if err := user.syncMessages(ctx); err != nil {
|
||||
return fmt.Errorf("failed to sync messages: %w", err)
|
||||
}
|
||||
|
||||
if err := user.vault.SetHasMessages(true); err != nil {
|
||||
return fmt.Errorf("failed to set has messages: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func syncLabels(ctx context.Context, client *liteapi.Client, updateCh ...*queue.QueuedChannel[imap.Update]) error {
|
||||
// Sync the system folders.
|
||||
system, err := user.client.GetLabels(ctx, liteapi.LabelTypeSystem)
|
||||
system, err := client.GetLabels(ctx, liteapi.LabelTypeSystem)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to get system labels: %w", err)
|
||||
}
|
||||
|
||||
for _, label := range xslices.Filter(system, func(label liteapi.Label) bool { return wantLabelID(label.ID) }) {
|
||||
for _, addrID := range addrIDs {
|
||||
user.updateCh[addrID].Enqueue(newSystemMailboxCreatedUpdate(imap.LabelID(label.ID), label.Name))
|
||||
for _, updateCh := range updateCh {
|
||||
updateCh.Enqueue(newSystemMailboxCreatedUpdate(imap.LabelID(label.ID), label.Name))
|
||||
}
|
||||
}
|
||||
|
||||
// Create Folders/Labels mailboxes with a random ID and with the \Noselect attribute.
|
||||
for _, prefix := range []string{folderPrefix, labelPrefix} {
|
||||
for _, addrID := range addrIDs {
|
||||
user.updateCh[addrID].Enqueue(newPlaceHolderMailboxCreatedUpdate(prefix))
|
||||
for _, updateCh := range updateCh {
|
||||
updateCh.Enqueue(newPlaceHolderMailboxCreatedUpdate(prefix))
|
||||
}
|
||||
}
|
||||
|
||||
// Sync the API folders.
|
||||
folders, err := user.client.GetLabels(ctx, liteapi.LabelTypeFolder)
|
||||
folders, err := client.GetLabels(ctx, liteapi.LabelTypeFolder)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to get folders: %w", err)
|
||||
}
|
||||
|
||||
for _, folder := range folders {
|
||||
for _, addrID := range addrIDs {
|
||||
user.updateCh[addrID].Enqueue(newMailboxCreatedUpdate(imap.LabelID(folder.ID), []string{folderPrefix, folder.Path}))
|
||||
for _, updateCh := range updateCh {
|
||||
updateCh.Enqueue(newMailboxCreatedUpdate(imap.LabelID(folder.ID), []string{folderPrefix, folder.Path}))
|
||||
}
|
||||
}
|
||||
|
||||
// Sync the API labels.
|
||||
labels, err := user.client.GetLabels(ctx, liteapi.LabelTypeLabel)
|
||||
labels, err := client.GetLabels(ctx, liteapi.LabelTypeLabel)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to get labels: %w", err)
|
||||
}
|
||||
|
||||
for _, label := range labels {
|
||||
for _, addrID := range addrIDs {
|
||||
user.updateCh[addrID].Enqueue(newMailboxCreatedUpdate(imap.LabelID(label.ID), []string{labelPrefix, label.Path}))
|
||||
for _, updateCh := range updateCh {
|
||||
updateCh.Enqueue(newMailboxCreatedUpdate(imap.LabelID(label.ID), []string{labelPrefix, label.Path}))
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all label updates to be applied.
|
||||
for _, updateCh := range updateCh {
|
||||
update := imap.NewNoop()
|
||||
defer update.WaitContext(ctx)
|
||||
|
||||
updateCh.Enqueue(update)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) syncMessages(ctx context.Context) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Determine which messages to sync.
|
||||
// TODO: This needs to be done better using the new API route to retrieve just the message IDs.
|
||||
metadata, err := user.client.GetAllMessageMetadata(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("get all message metadata: %w", err)
|
||||
}
|
||||
|
||||
// If in split mode, we need to send each message to a different IMAP connector.
|
||||
isSplitMode := user.vault.AddressMode() == vault.SplitMode
|
||||
|
||||
// Collect the build requests -- we need:
|
||||
// - the message ID to build,
|
||||
// - the keyring to decrypt the message,
|
||||
// - and the address to send the message to (for split mode).
|
||||
requests := xslices.Map(metadata, func(metadata liteapi.MessageMetadata) request {
|
||||
var addressID string
|
||||
|
||||
if isSplitMode {
|
||||
addressID = metadata.AddressID
|
||||
} else {
|
||||
addressID = user.apiAddrs.primary()
|
||||
// If possible, begin syncing from the last synced message.
|
||||
if beginID := user.vault.SyncStatus().LastMessageID; beginID != "" {
|
||||
if idx := xslices.IndexFunc(metadata, func(metadata liteapi.MessageMetadata) bool {
|
||||
return metadata.ID == beginID
|
||||
}); idx >= 0 {
|
||||
metadata = metadata[idx:]
|
||||
}
|
||||
}
|
||||
|
||||
return request{
|
||||
messageID: metadata.ID,
|
||||
addressID: addressID,
|
||||
addrKR: user.addrKRs[metadata.AddressID],
|
||||
}
|
||||
})
|
||||
// Process the metadata, building the messages.
|
||||
buildCh := stream.Chunk(parallel.MapStream(
|
||||
ctx,
|
||||
stream.FromIterator(iterator.Slice(metadata)),
|
||||
runtime.NumCPU()*runtime.NumCPU()/2,
|
||||
runtime.NumCPU()*runtime.NumCPU()/2,
|
||||
user.buildRFC822,
|
||||
), maxBatchSize)
|
||||
|
||||
// Create the flushers, one per update channel.
|
||||
flushers := make(map[string]*flusher)
|
||||
|
||||
for addrID, updateCh := range user.updateCh {
|
||||
flusher := newFlusher(user.ID(), updateCh, user.eventCh, len(requests), chunkSize)
|
||||
defer flusher.flush()
|
||||
flusher := newFlusher(user.ID(), updateCh, maxUpdateSize)
|
||||
defer flusher.flush(ctx, true)
|
||||
|
||||
flushers[addrID] = flusher
|
||||
}
|
||||
|
||||
// Build the messages and send them to the correct flusher.
|
||||
if err := user.builder.Process(ctx, requests, func(req request, res *imap.MessageCreated, err error) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build message %s: %w", req.messageID, err)
|
||||
// Create a reporter to report sync progress updates.
|
||||
reporter := newReporter(user.ID(), user.eventCh, len(metadata), time.Second)
|
||||
defer reporter.done()
|
||||
|
||||
// Send each update to the appropriate flusher.
|
||||
for {
|
||||
batch, err := buildCh.Next(ctx)
|
||||
if errors.Is(err, stream.End) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to get next sync batch: %w", err)
|
||||
}
|
||||
|
||||
flushers[req.addressID].push(res)
|
||||
user.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
|
||||
for _, res := range batch {
|
||||
if len(flushers) > 1 {
|
||||
flushers[res.addressID].push(ctx, res.update)
|
||||
} else {
|
||||
flushers[apiAddrs[0].ID].push(ctx, res.update)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to build messages: %w", err)
|
||||
}
|
||||
for _, flusher := range flushers {
|
||||
flusher.flush(ctx, true)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
if err := user.vault.SetLastMessageID(batch[len(batch)-1].messageID); err != nil {
|
||||
return fmt.Errorf("failed to set last synced message ID: %w", err)
|
||||
}
|
||||
|
||||
func (user *User) syncWait() {
|
||||
for _, updateCh := range user.updateCh {
|
||||
waiter := imap.NewNoop()
|
||||
defer waiter.Wait()
|
||||
|
||||
updateCh.Enqueue(waiter)
|
||||
reporter.add(len(batch))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
88
internal/user/sync_build.go
Normal file
88
internal/user/sync_build.go
Normal file
@ -0,0 +1,88 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type buildRes struct {
|
||||
messageID string
|
||||
addressID string
|
||||
update *imap.MessageCreated
|
||||
}
|
||||
|
||||
func defaultJobOpts() message.JobOptions {
|
||||
return message.JobOptions{
|
||||
IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead.
|
||||
SanitizeDate: true, // Whether to replace all dates before 1970 with RFC822's birthdate.
|
||||
AddInternalID: true, // Whether to include MessageID as X-Pm-Internal-Id.
|
||||
AddExternalID: true, // Whether to include ExternalID as X-Pm-External-Id.
|
||||
AddMessageDate: true, // Whether to include message time as X-Pm-Date.
|
||||
AddMessageIDReference: true, // Whether to include the MessageID in References.
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) buildRFC822(ctx context.Context, metadata liteapi.MessageMetadata) (*buildRes, error) {
|
||||
msg, err := user.client.GetMessage(ctx, metadata.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get message %s: %w", metadata.ID, err)
|
||||
}
|
||||
|
||||
attData, err := user.attPool.ProcessAll(ctx, xslices.Map(msg.Attachments, func(att liteapi.Attachment) string { return att.ID }))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get attachments for message %s: %w", metadata.ID, err)
|
||||
}
|
||||
|
||||
literal, err := message.BuildRFC822(user.addrKRs[msg.AddressID], msg, attData, defaultJobOpts())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build message %s: %w", metadata.ID, err)
|
||||
}
|
||||
|
||||
update, err := newMessageCreatedUpdate(metadata, literal)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create IMAP update for message %s: %w", metadata.ID, err)
|
||||
}
|
||||
|
||||
return &buildRes{
|
||||
messageID: metadata.ID,
|
||||
addressID: metadata.AddressID,
|
||||
update: update,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newMessageCreatedUpdate(message liteapi.MessageMetadata, literal []byte) (*imap.MessageCreated, error) {
|
||||
parsedMessage, err := imap.NewParsedMessage(literal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flags := imap.NewFlagSet()
|
||||
|
||||
if !message.Unread {
|
||||
flags = flags.Add(imap.FlagSeen)
|
||||
}
|
||||
|
||||
if slices.Contains(message.LabelIDs, liteapi.StarredLabel) {
|
||||
flags = flags.Add(imap.FlagFlagged)
|
||||
}
|
||||
|
||||
imapMessage := imap.Message{
|
||||
ID: imap.MessageID(message.ID),
|
||||
Flags: flags,
|
||||
Date: time.Unix(message.Time, 0),
|
||||
}
|
||||
|
||||
return &imap.MessageCreated{
|
||||
Message: imapMessage,
|
||||
Literal: literal,
|
||||
LabelIDs: mapTo[string, imap.LabelID](xslices.Filter(message.LabelIDs, wantLabelID)),
|
||||
ParsedMessage: parsedMessage,
|
||||
}, nil
|
||||
}
|
||||
56
internal/user/sync_flusher.go
Normal file
56
internal/user/sync_flusher.go
Normal file
@ -0,0 +1,56 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
)
|
||||
|
||||
type flusher struct {
|
||||
userID string
|
||||
updateCh *queue.QueuedChannel[imap.Update]
|
||||
|
||||
updates []*imap.MessageCreated
|
||||
maxChunkSize int
|
||||
curChunkSize int
|
||||
|
||||
pushLock sync.Mutex
|
||||
}
|
||||
|
||||
func newFlusher(userID string, updateCh *queue.QueuedChannel[imap.Update], maxChunkSize int) *flusher {
|
||||
return &flusher{
|
||||
userID: userID,
|
||||
updateCh: updateCh,
|
||||
maxChunkSize: maxChunkSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *flusher) push(ctx context.Context, update *imap.MessageCreated) {
|
||||
f.pushLock.Lock()
|
||||
defer f.pushLock.Unlock()
|
||||
|
||||
f.updates = append(f.updates, update)
|
||||
|
||||
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize {
|
||||
f.flush(ctx, false)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *flusher) flush(ctx context.Context, wait bool) {
|
||||
if len(f.updates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...))
|
||||
f.updates = nil
|
||||
f.curChunkSize = 0
|
||||
|
||||
if wait {
|
||||
update := imap.NewNoop()
|
||||
defer update.WaitContext(ctx)
|
||||
|
||||
f.updateCh.Enqueue(update)
|
||||
}
|
||||
}
|
||||
55
internal/user/sync_reporter.go
Normal file
55
internal/user/sync_reporter.go
Normal file
@ -0,0 +1,55 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
)
|
||||
|
||||
type reporter struct {
|
||||
userID string
|
||||
eventCh *queue.QueuedChannel[events.Event]
|
||||
|
||||
start time.Time
|
||||
total int
|
||||
count int
|
||||
|
||||
last time.Time
|
||||
freq time.Duration
|
||||
}
|
||||
|
||||
func newReporter(userID string, eventCh *queue.QueuedChannel[events.Event], total int, freq time.Duration) *reporter {
|
||||
return &reporter{
|
||||
userID: userID,
|
||||
eventCh: eventCh,
|
||||
|
||||
start: time.Now(),
|
||||
total: total,
|
||||
freq: freq,
|
||||
}
|
||||
}
|
||||
|
||||
func (rep *reporter) add(delta int) {
|
||||
rep.count += delta
|
||||
|
||||
if time.Since(rep.last) > rep.freq {
|
||||
rep.eventCh.Enqueue(events.SyncProgress{
|
||||
UserID: rep.userID,
|
||||
Progress: float64(rep.count) / float64(rep.total),
|
||||
Elapsed: time.Since(rep.start),
|
||||
Remaining: time.Since(rep.start) * time.Duration(rep.total-(rep.count+1)) / time.Duration(rep.count+1),
|
||||
})
|
||||
|
||||
rep.last = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (rep *reporter) done() {
|
||||
rep.eventCh.Enqueue(events.SyncProgress{
|
||||
UserID: rep.userID,
|
||||
Progress: 1,
|
||||
Elapsed: time.Since(rep.start),
|
||||
Remaining: 0,
|
||||
})
|
||||
}
|
||||
@ -1,7 +1,9 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
@ -13,114 +15,125 @@ import (
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/pool"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/safe"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultEventPeriod = 20 * time.Second
|
||||
DefaultEventJitter = 20 * time.Second
|
||||
EventPeriod = 20 * time.Second
|
||||
EventJitter = 20 * time.Second
|
||||
)
|
||||
|
||||
type User struct {
|
||||
vault *vault.User
|
||||
client *liteapi.Client
|
||||
builder *pool.Pool[request, *imap.MessageCreated]
|
||||
attPool *pool.Pool[string, []byte]
|
||||
eventCh *queue.QueuedChannel[events.Event]
|
||||
|
||||
apiUser liteapi.User
|
||||
apiAddrs *addrList
|
||||
userKR *crypto.KeyRing
|
||||
addrKRs map[string]*crypto.KeyRing
|
||||
settings liteapi.MailSettings
|
||||
apiUser *safe.Type[liteapi.User]
|
||||
apiAddrs *safe.Slice[liteapi.Address]
|
||||
settings *safe.Type[liteapi.MailSettings]
|
||||
|
||||
updateCh map[string]*queue.QueuedChannel[imap.Update]
|
||||
syncWG wait.Group
|
||||
userKR *crypto.KeyRing
|
||||
addrKRs map[string]*crypto.KeyRing
|
||||
|
||||
updateCh map[string]*queue.QueuedChannel[imap.Update]
|
||||
syncStopCh chan struct{}
|
||||
syncWG wait.Group
|
||||
}
|
||||
|
||||
func New(
|
||||
ctx context.Context,
|
||||
encVault *vault.User,
|
||||
client *liteapi.Client,
|
||||
apiUser liteapi.User,
|
||||
apiAddrs []liteapi.Address,
|
||||
userKR *crypto.KeyRing,
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
) (*User, error) {
|
||||
func New(ctx context.Context, encVault *vault.User, client *liteapi.Client, apiUser liteapi.User) (*User, error) {
|
||||
// Get the user's API addresses.
|
||||
apiAddrs, err := client.GetAddresses(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get addresses: %w", err)
|
||||
}
|
||||
|
||||
// Unlock the user's keyrings.
|
||||
userKR, addrKRs, err := liteapi.Unlock(apiUser, apiAddrs, encVault.KeyPass())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unlock user: %w", err)
|
||||
}
|
||||
|
||||
// Get the latest event ID.
|
||||
if encVault.EventID() == "" {
|
||||
eventID, err := client.GetLatestEventID(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to get latest event ID: %w", err)
|
||||
}
|
||||
|
||||
if err := encVault.SetEventID(eventID); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to set event ID: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get the user's mail settings.
|
||||
settings, err := client.GetMailSettings(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to get mail settings: %w", err)
|
||||
}
|
||||
|
||||
user := &User{
|
||||
vault: encVault,
|
||||
client: client,
|
||||
builder: newBuilder(client, runtime.NumCPU()*runtime.NumCPU(), runtime.NumCPU()*runtime.NumCPU()),
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||
// Create update channels for each of the user's addresses (if in combined mode, just the primary).
|
||||
updateCh := make(map[string]*queue.QueuedChannel[imap.Update])
|
||||
|
||||
apiUser: apiUser,
|
||||
apiAddrs: newAddrList(apiAddrs),
|
||||
for _, addr := range apiAddrs {
|
||||
updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
userKR: userKR,
|
||||
addrKRs: addrKRs,
|
||||
settings: settings,
|
||||
|
||||
updateCh: make(map[string]*queue.QueuedChannel[imap.Update]),
|
||||
}
|
||||
|
||||
// Initialize update channels for each of the user's addresses.
|
||||
for _, addrID := range user.apiAddrs.addrIDs() {
|
||||
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
// If in combined mode, we only need one update channel.
|
||||
if encVault.AddressMode() == vault.CombinedMode {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// When we receive an auth object, we update it in the store.
|
||||
user := &User{
|
||||
vault: encVault,
|
||||
client: client,
|
||||
attPool: pool.New(runtime.NumCPU(), client.GetAttachment),
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||
|
||||
apiUser: safe.NewType(apiUser),
|
||||
apiAddrs: safe.NewSlice(apiAddrs),
|
||||
settings: safe.NewType(settings),
|
||||
|
||||
userKR: userKR,
|
||||
addrKRs: addrKRs,
|
||||
|
||||
updateCh: updateCh,
|
||||
syncStopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// When we receive an auth object, we update it in the vault.
|
||||
// This will be used to authorize the user on the next run.
|
||||
client.AddAuthHandler(func(auth liteapi.Auth) {
|
||||
if err := user.vault.SetAuth(auth.UID, auth.RefreshToken); err != nil {
|
||||
logrus.WithError(err).Error("Failed to update auth")
|
||||
logrus.WithError(err).Error("Failed to update auth in vault")
|
||||
}
|
||||
})
|
||||
|
||||
// When we are deauthorized, we send a deauth event to the notify channel.
|
||||
// Bridge will catch this and log the user out.
|
||||
// When we are deauthorized, we send a deauth event to the event channel.
|
||||
// Bridge will react to this event by logging out the user.
|
||||
client.AddDeauthHandler(func() {
|
||||
user.eventCh.Enqueue(events.UserDeauth{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
})
|
||||
|
||||
// When we receive an API event, we attempt to handle it.
|
||||
// If successful, we update the event ID in the vault.
|
||||
// If we haven't synced yet, do it first.
|
||||
// If it fails, we don't start the event loop.
|
||||
// Oterwise, begin processing API events, logging any errors that occur.
|
||||
go func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
for event := range user.client.NewEventStreamer(DefaultEventPeriod, DefaultEventJitter, encVault.EventID()).Subscribe() {
|
||||
if err := user.handleAPIEvent(ctx, event); err != nil {
|
||||
logrus.WithError(err).Error("Failed to handle event")
|
||||
} else if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to update event ID")
|
||||
if status := user.vault.SyncStatus(); !status.HasMessages {
|
||||
if err := <-user.startSync(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for err := range user.streamEvents() {
|
||||
logrus.WithError(err).Error("Error while streaming events")
|
||||
}
|
||||
}()
|
||||
|
||||
return user, nil
|
||||
@ -128,30 +141,44 @@ func New(
|
||||
|
||||
// ID returns the user's ID.
|
||||
func (user *User) ID() string {
|
||||
return user.apiUser.ID
|
||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) string {
|
||||
return apiUser.ID
|
||||
})
|
||||
}
|
||||
|
||||
// Name returns the user's username.
|
||||
func (user *User) Name() string {
|
||||
return user.apiUser.Name
|
||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) string {
|
||||
return apiUser.Name
|
||||
})
|
||||
}
|
||||
|
||||
// Match matches the given query against the user's username and email addresses.
|
||||
func (user *User) Match(query string) bool {
|
||||
if query == user.apiUser.Name {
|
||||
return true
|
||||
}
|
||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) bool {
|
||||
return safe.GetSlice(user.apiAddrs, func(apiAddrs []liteapi.Address) bool {
|
||||
if query == apiUser.Name {
|
||||
return true
|
||||
}
|
||||
|
||||
if _, ok := user.apiAddrs.addrID(query); ok {
|
||||
return true
|
||||
}
|
||||
for _, addr := range apiAddrs {
|
||||
if addr.Email == query {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return false
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Emails returns all the user's email addresses.
|
||||
func (user *User) Emails() []string {
|
||||
return user.apiAddrs.emails()
|
||||
return safe.GetSlice(user.apiAddrs, func(apiAddrs []liteapi.Address) []string {
|
||||
return xslices.Map(apiAddrs, func(addr liteapi.Address) string {
|
||||
return addr.Email
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// GetAddressMode returns the user's current address mode.
|
||||
@ -167,18 +194,32 @@ func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) er
|
||||
|
||||
user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update])
|
||||
|
||||
for _, addrID := range user.apiAddrs.addrIDs() {
|
||||
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
user.apiAddrs.Get(func(apiAddrs []liteapi.Address) {
|
||||
for _, addr := range apiAddrs {
|
||||
user.updateCh[addr.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
if mode == vault.CombinedMode {
|
||||
break
|
||||
if mode == vault.CombinedMode {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if err := user.vault.SetAddressMode(mode); err != nil {
|
||||
return fmt.Errorf("failed to set address mode: %w", err)
|
||||
}
|
||||
|
||||
user.stopSync()
|
||||
|
||||
if err := user.vault.ClearSyncStatus(); err != nil {
|
||||
return fmt.Errorf("failed to clear sync status: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := <-user.startSync(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to sync after setting address mode")
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -209,68 +250,27 @@ func (user *User) GluonKey() []byte {
|
||||
|
||||
// BridgePass returns the user's bridge password, used for authentication over SMTP and IMAP.
|
||||
func (user *User) BridgePass() []byte {
|
||||
return user.vault.BridgePass()
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
if _, err := hex.NewEncoder(buf).Write(user.vault.BridgePass()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// UsedSpace returns the total space used by the user on the API.
|
||||
func (user *User) UsedSpace() int {
|
||||
return user.apiUser.UsedSpace
|
||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) int {
|
||||
return apiUser.UsedSpace
|
||||
})
|
||||
}
|
||||
|
||||
// MaxSpace returns the amount of space the user can use on the API.
|
||||
func (user *User) MaxSpace() int {
|
||||
return user.apiUser.MaxSpace
|
||||
}
|
||||
|
||||
// HasSync returns whether the user has finished syncing.
|
||||
func (user *User) HasSync() bool {
|
||||
return user.vault.HasSync()
|
||||
}
|
||||
|
||||
// AbortSync aborts any ongoing sync.
|
||||
// TODO: This should abort the sync rather than just waiting.
|
||||
// Should probably be done automatically when one of the user's IMAP connectors is closed.
|
||||
func (user *User) AbortSync(ctx context.Context) error {
|
||||
user.syncWG.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoSync performs a sync for the user.
|
||||
func (user *User) DoSync(ctx context.Context) <-chan error {
|
||||
errCh := queue.NewQueuedChannel[error](0, 0)
|
||||
|
||||
user.syncWG.Go(func() {
|
||||
defer errCh.Close()
|
||||
|
||||
user.eventCh.Enqueue(events.SyncStarted{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
|
||||
errCh.Enqueue(func() error {
|
||||
if err := user.syncLabels(ctx, maps.Keys(user.updateCh)...); err != nil {
|
||||
return fmt.Errorf("failed to sync labels: %w", err)
|
||||
}
|
||||
|
||||
if err := user.syncMessages(ctx); err != nil {
|
||||
return fmt.Errorf("failed to sync messages: %w", err)
|
||||
}
|
||||
|
||||
user.syncWait()
|
||||
|
||||
if err := user.vault.SetSync(true); err != nil {
|
||||
return fmt.Errorf("failed to set sync status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}())
|
||||
|
||||
user.eventCh.Enqueue(events.SyncFinished{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
return safe.GetType(user.apiUser, func(apiUser liteapi.User) int {
|
||||
return apiUser.MaxSpace
|
||||
})
|
||||
|
||||
return errCh.GetChannel()
|
||||
}
|
||||
|
||||
// GetEventCh returns a channel which notifies of events happening to the user (such as deauth, address change)
|
||||
@ -281,31 +281,35 @@ func (user *User) GetEventCh() <-chan events.Event {
|
||||
// NewIMAPConnector returns an IMAP connector for the given address.
|
||||
// If not in split mode, this function returns an error.
|
||||
func (user *User) NewIMAPConnector(addrID string) (connector.Connector, error) {
|
||||
var emails []string
|
||||
return safe.GetSliceErr(user.apiAddrs, func(apiAddrs []liteapi.Address) (connector.Connector, error) {
|
||||
var emails []string
|
||||
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
if addrID != user.apiAddrs.primary() {
|
||||
return nil, fmt.Errorf("cannot create IMAP connector for non-primary address in combined mode")
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
if addrID != apiAddrs[0].ID {
|
||||
return nil, fmt.Errorf("cannot create IMAP connector for non-primary address in combined mode")
|
||||
}
|
||||
|
||||
emails = xslices.Map(apiAddrs, func(addr liteapi.Address) string {
|
||||
return addr.Email
|
||||
})
|
||||
|
||||
case vault.SplitMode:
|
||||
email, err := getAddrEmail(apiAddrs, addrID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
emails = []string{email}
|
||||
}
|
||||
|
||||
emails = user.apiAddrs.emails()
|
||||
|
||||
case vault.SplitMode:
|
||||
email, ok := user.apiAddrs.email(addrID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("address %s not found", addrID)
|
||||
}
|
||||
|
||||
emails = []string{email}
|
||||
}
|
||||
|
||||
return newIMAPConnector(
|
||||
user.client,
|
||||
user.updateCh[addrID].GetChannel(),
|
||||
user.vault.BridgePass(),
|
||||
emails...,
|
||||
), nil
|
||||
return newIMAPConnector(
|
||||
user.client,
|
||||
user.updateCh[addrID].GetChannel(),
|
||||
user.BridgePass(),
|
||||
emails...,
|
||||
), nil
|
||||
})
|
||||
}
|
||||
|
||||
// NewIMAPConnectors returns IMAP connectors for each of the user's addresses.
|
||||
@ -328,22 +332,7 @@ func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
|
||||
|
||||
// NewSMTPSession returns an SMTP session for the user.
|
||||
func (user *User) NewSMTPSession(email string) (smtp.Session, error) {
|
||||
addrID, ok := user.apiAddrs.addrID(email)
|
||||
if !ok {
|
||||
return nil, ErrNoSuchAddress
|
||||
}
|
||||
|
||||
return newSMTPSession(
|
||||
user.client,
|
||||
user.eventCh,
|
||||
user.apiUser.ID,
|
||||
addrID,
|
||||
user.vault.AddressMode(),
|
||||
user.apiAddrs.addrMap(),
|
||||
user.settings,
|
||||
user.userKR,
|
||||
user.addrKRs,
|
||||
), nil
|
||||
return newSMTPSession(user, email)
|
||||
}
|
||||
|
||||
// Logout logs the user out from the API.
|
||||
@ -352,12 +341,12 @@ func (user *User) Logout(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Close closes ongoing connections and cleans up resources.
|
||||
func (user *User) Close(ctx context.Context) error {
|
||||
// Wait for ongoing syncs to finish.
|
||||
user.syncWG.Wait()
|
||||
func (user *User) Close() error {
|
||||
// Cancel ongoing syncs.
|
||||
user.stopSync()
|
||||
|
||||
// Close the user's message builder.
|
||||
user.builder.Done()
|
||||
// Close the attachment pool.
|
||||
user.attPool.Done()
|
||||
|
||||
// Close the user's API client.
|
||||
user.client.Close()
|
||||
@ -372,3 +361,104 @@ func (user *User) Close(ctx context.Context) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamEvents begins streaming API events for the user.
|
||||
// When we receive an API event, we attempt to handle it.
|
||||
// If successful, we update the event ID in the vault.
|
||||
func (user *User) streamEvents() <-chan error {
|
||||
errCh := make(chan error)
|
||||
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
|
||||
for event := range user.client.NewEventStreamer(EventPeriod, EventJitter, user.vault.EventID()).Subscribe() {
|
||||
if err := user.handleAPIEvent(context.Background(), event); err != nil {
|
||||
errCh <- fmt.Errorf("failed to handle API event: %w", err)
|
||||
} else if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||
errCh <- fmt.Errorf("failed to update event ID: %w", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return errCh
|
||||
}
|
||||
|
||||
// startSync begins a startSync for the user.
|
||||
func (user *User) startSync() <-chan error {
|
||||
errCh := make(chan error)
|
||||
|
||||
user.syncWG.Go(func() {
|
||||
defer close(errCh)
|
||||
|
||||
ctx, cancel := contextWithStopCh(context.Background(), user.syncStopCh)
|
||||
defer cancel()
|
||||
|
||||
user.eventCh.Enqueue(events.SyncStarted{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
|
||||
if err := user.sync(ctx); err != nil {
|
||||
user.eventCh.Enqueue(events.SyncFailed{
|
||||
UserID: user.ID(),
|
||||
Err: err,
|
||||
})
|
||||
|
||||
errCh <- err
|
||||
} else {
|
||||
user.eventCh.Enqueue(events.SyncFinished{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
return errCh
|
||||
}
|
||||
|
||||
// AbortSync aborts any ongoing sync.
|
||||
// TODO: Should probably be done automatically when one of the user's IMAP connectors is closed.
|
||||
func (user *User) stopSync() {
|
||||
select {
|
||||
case user.syncStopCh <- struct{}{}:
|
||||
user.syncWG.Wait()
|
||||
|
||||
default:
|
||||
// ...
|
||||
}
|
||||
}
|
||||
|
||||
func getAddrID(apiAddrs []liteapi.Address, email string) (string, error) {
|
||||
for _, addr := range apiAddrs {
|
||||
if addr.Email == email {
|
||||
return addr.ID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("address %s not found", email)
|
||||
}
|
||||
|
||||
func getAddrEmail(apiAddrs []liteapi.Address, addrID string) (string, error) {
|
||||
for _, addr := range apiAddrs {
|
||||
if addr.ID == addrID {
|
||||
return addr.Email, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("address %s not found", addrID)
|
||||
}
|
||||
|
||||
// contextWithStopCh returns a new context that is cancelled when the stop channel is closed or a value is sent to it.
|
||||
func contextWithStopCh(ctx context.Context, stopCh <-chan struct{}) (context.Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-stopCh:
|
||||
cancel()
|
||||
|
||||
case <-ctx.Done():
|
||||
// ...
|
||||
}
|
||||
}()
|
||||
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
@ -17,117 +17,128 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
user.DefaultEventPeriod = 100 * time.Millisecond
|
||||
user.DefaultEventJitter = 0
|
||||
user.EventPeriod = 100 * time.Millisecond
|
||||
user.EventJitter = 0
|
||||
backend.GenerateKey = tests.FastGenerateKey
|
||||
certs.GenerateCert = tests.FastGenerateCert
|
||||
}
|
||||
|
||||
func TestUser_Data(t *testing.T) {
|
||||
withAPI(t, context.Background(), "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) {
|
||||
withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) {
|
||||
// User's ID should be correct.
|
||||
require.Equal(t, userID, user.ID())
|
||||
withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) {
|
||||
withAccount(t, s, "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(userID string, addrIDs []string) {
|
||||
withUser(t, ctx, s, m, "username", "password", func(user *user.User) {
|
||||
// User's ID should be correct.
|
||||
require.Equal(t, userID, user.ID())
|
||||
|
||||
// User's name should be correct.
|
||||
require.Equal(t, "username", user.Name())
|
||||
// User's name should be correct.
|
||||
require.Equal(t, "username", user.Name())
|
||||
|
||||
// User's email should be correct.
|
||||
require.ElementsMatch(t, []string{"email@pm.me", "alias@pm.me"}, user.Emails())
|
||||
// User's email should be correct.
|
||||
require.ElementsMatch(t, []string{"email@pm.me", "alias@pm.me"}, user.Emails())
|
||||
|
||||
// By default, user should be in combined mode.
|
||||
require.Equal(t, vault.CombinedMode, user.GetAddressMode())
|
||||
// By default, user should be in combined mode.
|
||||
require.Equal(t, vault.CombinedMode, user.GetAddressMode())
|
||||
|
||||
// By default, user should have a non-empty bridge password.
|
||||
require.NotEmpty(t, user.BridgePass())
|
||||
// By default, user should have a non-empty bridge password.
|
||||
require.NotEmpty(t, user.BridgePass())
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_Sync(t *testing.T) {
|
||||
withAPI(t, context.Background(), "username", "password", []string{"email@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) {
|
||||
withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) {
|
||||
// Get the user's IMAP connectors.
|
||||
imapConn, err := user.NewIMAPConnectors()
|
||||
require.NoError(t, err)
|
||||
withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) {
|
||||
withAccount(t, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) {
|
||||
withUser(t, ctx, s, m, "username", "password", func(user *user.User) {
|
||||
// User starts a sync at startup.
|
||||
require.IsType(t, events.SyncStarted{}, <-user.GetEventCh())
|
||||
|
||||
// Pretend to be gluon applying all the updates.
|
||||
go func() {
|
||||
for _, imapConn := range imapConn {
|
||||
for update := range imapConn.GetUpdates() {
|
||||
update.Done()
|
||||
}
|
||||
}
|
||||
}()
|
||||
// User sends sync progress.
|
||||
require.IsType(t, events.SyncProgress{}, <-user.GetEventCh())
|
||||
|
||||
// Trigger a user sync.
|
||||
errCh := user.DoSync(ctx)
|
||||
|
||||
// User starts a sync at startup.
|
||||
require.IsType(t, events.SyncStarted{}, <-user.GetEventCh())
|
||||
|
||||
// User finishes a sync at startup.
|
||||
require.IsType(t, events.SyncFinished{}, <-user.GetEventCh())
|
||||
|
||||
// The sync completes without error.
|
||||
require.NoError(t, <-errCh)
|
||||
// User finishes a sync at startup.
|
||||
require.IsType(t, events.SyncFinished{}, <-user.GetEventCh())
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_Deauth(t *testing.T) {
|
||||
withAPI(t, context.Background(), "username", "password", []string{"email@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) {
|
||||
withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) {
|
||||
eventCh := user.GetEventCh()
|
||||
withAPI(t, context.Background(), func(ctx context.Context, s *server.Server, m *liteapi.Manager) {
|
||||
withAccount(t, s, "username", "password", []string{"email@pm.me"}, func(userID string, addrIDs []string) {
|
||||
withUser(t, ctx, s, m, "username", "password", func(user *user.User) {
|
||||
eventCh := user.GetEventCh()
|
||||
|
||||
// Revoke the user's auth token.
|
||||
require.NoError(t, s.RevokeUser(userID))
|
||||
// Revoke the user's auth token.
|
||||
require.NoError(t, s.RevokeUser(user.ID()))
|
||||
|
||||
// The user should eventually be logged out.
|
||||
require.Eventually(t, func() bool { _, ok := (<-eventCh).(events.UserDeauth); return ok }, 5*time.Second, 100*time.Millisecond)
|
||||
// The user should eventually be logged out.
|
||||
require.Eventually(t, func() bool { _, ok := (<-eventCh).(events.UserDeauth); return ok }, 5*time.Second, 100*time.Millisecond)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func withAPI(t *testing.T, ctx context.Context, username, password string, emails []string, fn func(context.Context, *server.Server, string, []string)) {
|
||||
func withAPI(t *testing.T, ctx context.Context, fn func(context.Context, *server.Server, *liteapi.Manager)) {
|
||||
server := server.New()
|
||||
defer server.Close()
|
||||
|
||||
fn(ctx, server, liteapi.New(liteapi.WithHostURL(server.GetHostURL())))
|
||||
}
|
||||
|
||||
func withAccount(t *testing.T, s *server.Server, username, password string, emails []string, fn func(string, []string)) {
|
||||
var addrIDs []string
|
||||
|
||||
userID, addrID, err := server.CreateUser(username, password, emails[0])
|
||||
userID, addrID, err := s.CreateUser(username, password, emails[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
addrIDs = append(addrIDs, addrID)
|
||||
|
||||
for _, email := range emails[1:] {
|
||||
addrID, err := server.CreateAddress(userID, email, password)
|
||||
addrID, err := s.CreateAddress(userID, email, password)
|
||||
require.NoError(t, err)
|
||||
|
||||
addrIDs = append(addrIDs, addrID)
|
||||
}
|
||||
|
||||
fn(ctx, server, userID, addrIDs)
|
||||
fn(userID, addrIDs)
|
||||
}
|
||||
|
||||
func withUser(t *testing.T, ctx context.Context, apiURL, username, password string, fn func(*user.User)) {
|
||||
c, apiAuth, err := liteapi.New(liteapi.WithHostURL(apiURL)).NewClientWithLogin(ctx, username, []byte(password))
|
||||
func withUser(t *testing.T, ctx context.Context, s *server.Server, m *liteapi.Manager, username, password string, fn func(*user.User)) {
|
||||
client, apiAuth, err := m.NewClientWithLogin(ctx, username, []byte(password))
|
||||
require.NoError(t, err)
|
||||
defer func() { require.NoError(t, c.Close()) }()
|
||||
defer func() { require.NoError(t, client.Close()) }()
|
||||
|
||||
apiUser, apiAddrs, userKR, addrKRs, passphrase, err := c.Unlock(ctx, []byte(password))
|
||||
apiUser, err := client.GetUser(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
salts, err := client.GetSalts(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
saltedKeyPass, err := salts.SaltForKey([]byte(password), apiUser.Keys.Primary().ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"))
|
||||
require.NoError(t, err)
|
||||
require.False(t, corrupt)
|
||||
|
||||
vaultUser, err := vault.AddUser(apiUser.ID, username, apiAuth.UID, apiAuth.RefreshToken, passphrase)
|
||||
vaultUser, err := vault.AddUser(apiUser.ID, username, apiAuth.UID, apiAuth.RefreshToken, saltedKeyPass)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := user.New(ctx, vaultUser, c, apiUser, apiAddrs, userKR, addrKRs)
|
||||
user, err := user.New(ctx, vaultUser, client, apiUser)
|
||||
require.NoError(t, err)
|
||||
defer func() { require.NoError(t, user.Close(ctx)) }()
|
||||
defer func() { require.NoError(t, user.Close()) }()
|
||||
|
||||
imapConn, err := user.NewIMAPConnectors()
|
||||
require.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
for _, imapConn := range imapConn {
|
||||
for update := range imapConn.GetUpdates() {
|
||||
update.Done()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
fn(user)
|
||||
}
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
)
|
||||
|
||||
@ -18,12 +16,3 @@ func newRandomToken(size int) []byte {
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func newRandomString(size int) []byte {
|
||||
token, err := RandomToken(size)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return []byte(hex.EncodeToString(token))
|
||||
}
|
||||
|
||||
@ -46,33 +46,6 @@ type Settings struct {
|
||||
FirstStartGUI bool
|
||||
}
|
||||
|
||||
type AddressMode int
|
||||
|
||||
const (
|
||||
CombinedMode AddressMode = iota
|
||||
SplitMode
|
||||
)
|
||||
|
||||
// UserData holds information about a single bridge user.
|
||||
// The user may or may not be logged in.
|
||||
type UserData struct {
|
||||
UserID string
|
||||
Username string
|
||||
|
||||
GluonKey []byte
|
||||
GluonIDs map[string]string
|
||||
UIDValidity map[string]imap.UID
|
||||
BridgePass []byte
|
||||
AddressMode AddressMode
|
||||
|
||||
AuthUID string
|
||||
AuthRef string
|
||||
KeyPass []byte
|
||||
|
||||
EventID string
|
||||
HasSync bool
|
||||
}
|
||||
|
||||
func newDefaultSettings(gluonDir string) Settings {
|
||||
return Settings{
|
||||
GluonDir: gluonDir,
|
||||
@ -96,3 +69,53 @@ func newDefaultSettings(gluonDir string) Settings {
|
||||
FirstStartGUI: true,
|
||||
}
|
||||
}
|
||||
|
||||
// UserData holds information about a single bridge user.
|
||||
// The user may or may not be logged in.
|
||||
type UserData struct {
|
||||
UserID string
|
||||
Username string
|
||||
|
||||
GluonKey []byte
|
||||
GluonIDs map[string]string
|
||||
UIDValidity map[string]imap.UID
|
||||
BridgePass []byte
|
||||
AddressMode AddressMode
|
||||
|
||||
AuthUID string
|
||||
AuthRef string
|
||||
KeyPass []byte
|
||||
|
||||
SyncStatus SyncStatus
|
||||
EventID string
|
||||
}
|
||||
|
||||
type AddressMode int
|
||||
|
||||
const (
|
||||
CombinedMode AddressMode = iota
|
||||
SplitMode
|
||||
)
|
||||
|
||||
type SyncStatus struct {
|
||||
HasLabels bool
|
||||
HasMessages bool
|
||||
LastMessageID string
|
||||
}
|
||||
|
||||
func newDefaultUser(userID, username, authUID, authRef string, keyPass []byte) UserData {
|
||||
return UserData{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
|
||||
GluonKey: newRandomToken(32),
|
||||
GluonIDs: make(map[string]string),
|
||||
UIDValidity: make(map[string]imap.UID),
|
||||
BridgePass: newRandomToken(16),
|
||||
AddressMode: CombinedMode,
|
||||
|
||||
AuthUID: authUID,
|
||||
AuthRef: authRef,
|
||||
KeyPass: keyPass,
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,6 +17,11 @@ func (user *User) Username() string {
|
||||
return user.vault.getUser(user.userID).Username
|
||||
}
|
||||
|
||||
// GluonKey returns the key needed to decrypt the user's gluon database.
|
||||
func (user *User) GluonKey() []byte {
|
||||
return user.vault.getUser(user.userID).GluonKey
|
||||
}
|
||||
|
||||
func (user *User) GetGluonIDs() map[string]string {
|
||||
return user.vault.getUser(user.userID).GluonIDs
|
||||
}
|
||||
@ -42,44 +47,33 @@ func (user *User) SetUIDValidity(addrID string, validity imap.UID) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (user *User) GluonKey() []byte {
|
||||
return user.vault.getUser(user.userID).GluonKey
|
||||
}
|
||||
|
||||
// AddressMode returns the user's address mode.
|
||||
func (user *User) AddressMode() AddressMode {
|
||||
return user.vault.getUser(user.userID).AddressMode
|
||||
}
|
||||
|
||||
// SetAddressMode sets the address mode for the given user.
|
||||
func (user *User) SetAddressMode(mode AddressMode) error {
|
||||
return user.vault.modUser(user.userID, func(data *UserData) {
|
||||
data.AddressMode = mode
|
||||
})
|
||||
}
|
||||
|
||||
// BridgePass returns the user's bridge password (unencoded).
|
||||
func (user *User) BridgePass() []byte {
|
||||
return user.vault.getUser(user.userID).BridgePass
|
||||
}
|
||||
|
||||
// AuthUID returns the user's auth UID.
|
||||
func (user *User) AuthUID() string {
|
||||
return user.vault.getUser(user.userID).AuthUID
|
||||
}
|
||||
|
||||
// AuthRef returns the user's auth refresh token.
|
||||
func (user *User) AuthRef() string {
|
||||
return user.vault.getUser(user.userID).AuthRef
|
||||
}
|
||||
|
||||
func (user *User) KeyPass() []byte {
|
||||
return user.vault.getUser(user.userID).KeyPass
|
||||
}
|
||||
|
||||
func (user *User) EventID() string {
|
||||
return user.vault.getUser(user.userID).EventID
|
||||
}
|
||||
|
||||
func (user *User) HasSync() bool {
|
||||
return user.vault.getUser(user.userID).HasSync
|
||||
}
|
||||
|
||||
func (user *User) SetKeyPass(keyPass []byte) error {
|
||||
return user.vault.modUser(user.userID, func(data *UserData) {
|
||||
data.KeyPass = keyPass
|
||||
})
|
||||
}
|
||||
|
||||
// SetAuth sets the auth secrets for the given user.
|
||||
func (user *User) SetAuth(authUID, authRef string) error {
|
||||
return user.vault.modUser(user.userID, func(data *UserData) {
|
||||
@ -88,23 +82,59 @@ func (user *User) SetAuth(authUID, authRef string) error {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAddressMode sets the address mode for the given user.
|
||||
func (user *User) SetAddressMode(mode AddressMode) error {
|
||||
// KeyPass returns the user's (salted) key password.
|
||||
func (user *User) KeyPass() []byte {
|
||||
return user.vault.getUser(user.userID).KeyPass
|
||||
}
|
||||
|
||||
// SetKeyPass sets the user's (salted) key password.
|
||||
func (user *User) SetKeyPass(keyPass []byte) error {
|
||||
return user.vault.modUser(user.userID, func(data *UserData) {
|
||||
data.AddressMode = mode
|
||||
data.KeyPass = keyPass
|
||||
})
|
||||
}
|
||||
|
||||
// SyncStatus return's the user's sync status.
|
||||
func (user *User) SyncStatus() SyncStatus {
|
||||
return user.vault.getUser(user.userID).SyncStatus
|
||||
}
|
||||
|
||||
// SetHasLabels sets whether the user's labels have been synced.
|
||||
func (user *User) SetHasLabels(hasLabels bool) error {
|
||||
return user.vault.modUser(user.userID, func(data *UserData) {
|
||||
data.SyncStatus.HasLabels = hasLabels
|
||||
})
|
||||
}
|
||||
|
||||
// SetHasMessages sets whether the user's messages have been synced.
|
||||
func (user *User) SetHasMessages(hasMessages bool) error {
|
||||
return user.vault.modUser(user.userID, func(data *UserData) {
|
||||
data.SyncStatus.HasMessages = hasMessages
|
||||
})
|
||||
}
|
||||
|
||||
// SetLastMessageID sets the last synced message ID for the given user.
|
||||
func (user *User) SetLastMessageID(messageID string) error {
|
||||
return user.vault.modUser(user.userID, func(data *UserData) {
|
||||
data.SyncStatus.LastMessageID = messageID
|
||||
})
|
||||
}
|
||||
|
||||
// ClearSyncStatus clears the user's sync status.
|
||||
func (user *User) ClearSyncStatus() error {
|
||||
return user.vault.modUser(user.userID, func(data *UserData) {
|
||||
data.SyncStatus = SyncStatus{}
|
||||
})
|
||||
}
|
||||
|
||||
// EventID returns the last processed event ID of the user.
|
||||
func (user *User) EventID() string {
|
||||
return user.vault.getUser(user.userID).EventID
|
||||
}
|
||||
|
||||
// SetEventID sets the event ID for the given user.
|
||||
func (user *User) SetEventID(eventID string) error {
|
||||
return user.vault.modUser(user.userID, func(data *UserData) {
|
||||
data.EventID = eventID
|
||||
})
|
||||
}
|
||||
|
||||
// SetSync sets the sync state for the given user.
|
||||
func (user *User) SetSync(hasSync bool) error {
|
||||
return user.vault.modUser(user.userID, func(data *UserData) {
|
||||
data.HasSync = hasSync
|
||||
})
|
||||
}
|
||||
|
||||
@ -1,14 +1,128 @@
|
||||
package vault_test
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUser_New(t *testing.T) {
|
||||
// Replace the token generator with a dummy one.
|
||||
vault.RandomToken = func(size int) ([]byte, error) {
|
||||
return []byte("token"), nil
|
||||
}
|
||||
|
||||
// Create a new test vault.
|
||||
s := newVault(t)
|
||||
|
||||
// There should be no users in the store.
|
||||
require.Empty(t, s.GetUserIDs())
|
||||
|
||||
// Create a new user.
|
||||
user, err := s.AddUser("userID", "username", "authUID", "authRef", []byte("keyPass"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// The user should be listed in the store.
|
||||
require.ElementsMatch(t, []string{"userID"}, s.GetUserIDs())
|
||||
|
||||
// Check the user's default user information.
|
||||
require.Equal(t, "userID", user.UserID())
|
||||
require.Equal(t, "username", user.Username())
|
||||
|
||||
// Check the user's default auth information.
|
||||
require.Equal(t, "authUID", user.AuthUID())
|
||||
require.Equal(t, "authRef", user.AuthRef())
|
||||
require.Equal(t, "keyPass", string(user.KeyPass()))
|
||||
|
||||
// Check the user has a random bridge password and gluon key.
|
||||
require.Equal(t, "token", string(user.BridgePass()))
|
||||
require.Equal(t, "token", string(user.GluonKey()))
|
||||
|
||||
// Check the user's initial sync status.
|
||||
require.False(t, user.SyncStatus().HasLabels)
|
||||
require.False(t, user.SyncStatus().HasMessages)
|
||||
}
|
||||
|
||||
func TestUser_Clear(t *testing.T) {
|
||||
// Create a new test vault.
|
||||
s := newVault(t)
|
||||
|
||||
// Create a new user.
|
||||
user, err := s.AddUser("userID", "username", "authUID", "authRef", []byte("keyPass"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check the user's default auth information.
|
||||
require.Equal(t, "authUID", user.AuthUID())
|
||||
require.Equal(t, "authRef", user.AuthRef())
|
||||
require.Equal(t, "keyPass", string(user.KeyPass()))
|
||||
|
||||
// Clear the user's auth information.
|
||||
require.NoError(t, s.ClearUser("userID"))
|
||||
|
||||
// Check the user's cleared auth information.
|
||||
require.Empty(t, user.AuthUID())
|
||||
require.Empty(t, user.AuthRef())
|
||||
require.Empty(t, user.KeyPass())
|
||||
}
|
||||
|
||||
func TestUser_Delete(t *testing.T) {
|
||||
// Create a new test vault.
|
||||
s := newVault(t)
|
||||
|
||||
// The store should have no users.
|
||||
require.Empty(t, s.GetUserIDs())
|
||||
|
||||
// Create a new user.
|
||||
user, err := s.AddUser("userID", "username", "authUID", "authRef", []byte("keyPass"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// The user should be listed in the store.
|
||||
require.ElementsMatch(t, []string{"userID"}, s.GetUserIDs())
|
||||
|
||||
// Clear the user's auth information.
|
||||
require.NoError(t, s.DeleteUser("userID"))
|
||||
|
||||
// The store should have no users again.
|
||||
require.Empty(t, s.GetUserIDs())
|
||||
|
||||
// Attempting to use the user should return an error.
|
||||
require.Panics(t, func() { _ = user.AddressMode() })
|
||||
}
|
||||
|
||||
func TestUser_SyncStatus(t *testing.T) {
|
||||
// Create a new test vault.
|
||||
s := newVault(t)
|
||||
|
||||
// Create a new user.
|
||||
user, err := s.AddUser("userID", "username", "authUID", "authRef", []byte("keyPass"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check the user's initial sync status.
|
||||
require.False(t, user.SyncStatus().HasLabels)
|
||||
require.False(t, user.SyncStatus().HasMessages)
|
||||
require.Empty(t, user.SyncStatus().LastMessageID)
|
||||
|
||||
// Simulate having synced a message.
|
||||
require.NoError(t, user.SetLastMessageID("test"))
|
||||
require.Equal(t, "test", user.SyncStatus().LastMessageID)
|
||||
|
||||
// Simulate finishing the sync.
|
||||
require.NoError(t, user.SetHasLabels(true))
|
||||
require.NoError(t, user.SetHasMessages(true))
|
||||
require.True(t, user.SyncStatus().HasLabels)
|
||||
require.True(t, user.SyncStatus().HasMessages)
|
||||
|
||||
// Clear the sync status.
|
||||
require.NoError(t, user.ClearSyncStatus())
|
||||
|
||||
// Check the user's cleared sync status.
|
||||
require.False(t, user.SyncStatus().HasLabels)
|
||||
require.False(t, user.SyncStatus().HasMessages)
|
||||
require.Empty(t, user.SyncStatus().LastMessageID)
|
||||
}
|
||||
|
||||
/*
|
||||
func TestUser(t *testing.T) {
|
||||
// Replace the token generator with a dummy one.
|
||||
vault.RandomToken = func(size int) ([]byte, error) {
|
||||
@ -101,3 +215,5 @@ func TestUser(t *testing.T) {
|
||||
// List available userIDs. User 1 should be gone.
|
||||
require.ElementsMatch(t, []string{"userID2"}, s.GetUserIDs())
|
||||
}
|
||||
|
||||
*/
|
||||
|
||||
@ -11,7 +11,6 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
)
|
||||
@ -100,20 +99,7 @@ func (vault *Vault) AddUser(userID, username, authUID, authRef string, keyPass [
|
||||
}
|
||||
|
||||
if err := vault.mod(func(data *Data) {
|
||||
data.Users = append(data.Users, UserData{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
|
||||
GluonKey: newRandomToken(32),
|
||||
GluonIDs: make(map[string]string),
|
||||
UIDValidity: make(map[string]imap.UID),
|
||||
BridgePass: newRandomString(16),
|
||||
AddressMode: CombinedMode,
|
||||
|
||||
AuthUID: authUID,
|
||||
AuthRef: authRef,
|
||||
KeyPass: keyPass,
|
||||
})
|
||||
data.Users = append(data.Users, newDefaultUser(userID, username, authUID, authRef, keyPass))
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user