forked from Silverfish/proton-bridge
206 lines
6.5 KiB
Go
206 lines
6.5 KiB
Go
// Copyright (c) 2022 Proton AG
|
|
//
|
|
// This file is part of Proton Mail Bridge.
|
|
//
|
|
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU General Public License
|
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
package bridge_test
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"sync/atomic"
|
|
"testing"
|
|
|
|
"github.com/ProtonMail/go-proton-api"
|
|
"github.com/ProtonMail/go-proton-api/server"
|
|
"github.com/ProtonMail/proton-bridge/v3/internal/bridge"
|
|
"github.com/ProtonMail/proton-bridge/v3/internal/constants"
|
|
"github.com/ProtonMail/proton-bridge/v3/internal/events"
|
|
"github.com/bradenaw/juniper/iterator"
|
|
"github.com/bradenaw/juniper/stream"
|
|
"github.com/emersion/go-imap/client"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestBridge_Sync(t *testing.T) {
|
|
numMsg := 1 << 8
|
|
|
|
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
|
|
userID, addrID, err := s.CreateUser("imap", "imap@pm.me", password)
|
|
require.NoError(t, err)
|
|
|
|
labelID, err := s.CreateLabel(userID, "folder", "", proton.LabelTypeFolder)
|
|
require.NoError(t, err)
|
|
|
|
withClient(ctx, t, s, "imap", password, func(ctx context.Context, c *proton.Client) {
|
|
createMessages(ctx, t, c, addrID, labelID, numMsg)
|
|
})
|
|
|
|
var total uint64
|
|
|
|
// The initial user should be fully synced.
|
|
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
|
syncCh, done := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{}))
|
|
defer done()
|
|
|
|
// Count how many bytes it takes to fully sync the user.
|
|
total = countBytesRead(netCtl, func() {
|
|
userID, err := bridge.LoginFull(ctx, "imap", password, nil, nil)
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, userID, (<-syncCh).UserID)
|
|
})
|
|
})
|
|
|
|
// If we then connect an IMAP client, it should see all the messages.
|
|
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(b *bridge.Bridge, mocks *bridge.Mocks) {
|
|
info, err := b.GetUserInfo(userID)
|
|
require.NoError(t, err)
|
|
require.True(t, info.State == bridge.Connected)
|
|
|
|
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
|
require.NoError(t, err)
|
|
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
|
defer func() { _ = client.Logout() }()
|
|
|
|
status, err := client.Select(`Folders/folder`, false)
|
|
require.NoError(t, err)
|
|
require.Equal(t, uint32(numMsg), status.Messages)
|
|
})
|
|
|
|
// Now let's remove the user and simulate a network error.
|
|
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
|
require.NoError(t, bridge.DeleteUser(ctx, userID))
|
|
})
|
|
|
|
// Pretend we can only sync 2/3 of the original messages.
|
|
netCtl.SetReadLimit(2 * total / 3)
|
|
|
|
// Login the user; its sync should fail.
|
|
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(b *bridge.Bridge, mocks *bridge.Mocks) {
|
|
{
|
|
syncCh, done := chToType[events.Event, events.SyncFailed](b.GetEvents(events.SyncFailed{}))
|
|
defer done()
|
|
|
|
userID, err := b.LoginFull(ctx, "imap", password, nil, nil)
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, userID, (<-syncCh).UserID)
|
|
|
|
info, err := b.GetUserInfo(userID)
|
|
require.NoError(t, err)
|
|
require.True(t, info.State == bridge.Connected)
|
|
|
|
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
|
require.NoError(t, err)
|
|
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
|
defer func() { _ = client.Logout() }()
|
|
|
|
status, err := client.Select(`Folders/folder`, false)
|
|
require.NoError(t, err)
|
|
require.Less(t, status.Messages, uint32(numMsg))
|
|
}
|
|
|
|
// Remove the network limit, allowing the sync to finish.
|
|
netCtl.SetReadLimit(0)
|
|
|
|
{
|
|
syncCh, done := chToType[events.Event, events.SyncFinished](b.GetEvents(events.SyncFinished{}))
|
|
defer done()
|
|
|
|
require.Equal(t, userID, (<-syncCh).UserID)
|
|
|
|
info, err := b.GetUserInfo(userID)
|
|
require.NoError(t, err)
|
|
require.True(t, info.State == bridge.Connected)
|
|
|
|
client, err := client.Dial(fmt.Sprintf("%v:%v", constants.Host, b.GetIMAPPort()))
|
|
require.NoError(t, err)
|
|
require.NoError(t, client.Login("imap@pm.me", string(info.BridgePass)))
|
|
defer func() { _ = client.Logout() }()
|
|
|
|
status, err := client.Select(`Folders/folder`, false)
|
|
require.NoError(t, err)
|
|
require.Equal(t, uint32(numMsg), status.Messages)
|
|
}
|
|
})
|
|
}, server.WithTLS(false))
|
|
}
|
|
|
|
func withClient(ctx context.Context, t *testing.T, s *server.Server, username string, password []byte, fn func(context.Context, *proton.Client)) {
|
|
m := proton.New(
|
|
proton.WithHostURL(s.GetHostURL()),
|
|
proton.WithTransport(proton.InsecureTransport()),
|
|
)
|
|
|
|
c, _, err := m.NewClientWithLogin(ctx, username, password)
|
|
require.NoError(t, err)
|
|
defer c.Close()
|
|
|
|
fn(ctx, c)
|
|
}
|
|
|
|
func createMessages(ctx context.Context, t *testing.T, c *proton.Client, addrID, labelID string, count int) {
|
|
literal, err := os.ReadFile(filepath.Join("testdata", "text-plain.eml"))
|
|
require.NoError(t, err)
|
|
|
|
user, err := c.GetUser(ctx)
|
|
require.NoError(t, err)
|
|
|
|
addr, err := c.GetAddresses(ctx)
|
|
require.NoError(t, err)
|
|
|
|
salt, err := c.GetSalts(ctx)
|
|
require.NoError(t, err)
|
|
|
|
keyPass, err := salt.SaltForKey(password, user.Keys.Primary().ID)
|
|
require.NoError(t, err)
|
|
|
|
_, addrKRs, err := proton.Unlock(user, addr, keyPass)
|
|
require.NoError(t, err)
|
|
|
|
require.NoError(t, getErr(stream.Collect(ctx, c.ImportMessages(
|
|
ctx,
|
|
addrKRs[addrID],
|
|
runtime.NumCPU(),
|
|
runtime.NumCPU(),
|
|
iterator.Collect(iterator.Map(iterator.Counter(count), func(i int) proton.ImportReq {
|
|
return proton.ImportReq{
|
|
Metadata: proton.ImportMetadata{
|
|
AddressID: addrID,
|
|
LabelIDs: []string{labelID},
|
|
Flags: proton.MessageFlagReceived,
|
|
},
|
|
Message: literal,
|
|
}
|
|
}))...,
|
|
))))
|
|
}
|
|
|
|
func countBytesRead(ctl *proton.NetCtl, fn func()) uint64 {
|
|
var read uint64
|
|
|
|
ctl.OnRead(func(b []byte) {
|
|
atomic.AddUint64(&read, uint64(len(b)))
|
|
})
|
|
|
|
fn()
|
|
|
|
return read
|
|
}
|