GODT-1779: Remove go-imap

This commit is contained in:
James Houlahan
2022-08-26 17:00:21 +02:00
parent 3b0bc1ca15
commit 39433fe707
593 changed files with 12725 additions and 91626 deletions

View File

@ -1,38 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package bridge provides core functionality of Bridge app.
package bridge
import "github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
// IsAutostartEnabled checks if link file exits.
func (b *Bridge) IsAutostartEnabled() bool {
return b.autostart.IsEnabled()
}
// EnableAutostart creates link and sets the preferences.
func (b *Bridge) EnableAutostart() error {
b.settings.SetBool(settings.AutostartKey, true)
return b.autostart.Enable()
}
// DisableAutostart removes link and sets the preferences.
func (b *Bridge) DisableAutostart() error {
b.settings.SetBool(settings.AutostartKey, false)
return b.autostart.Disable()
}

View File

@ -1,325 +1,318 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package bridge provides core functionality of Bridge app.
package bridge
import (
"errors"
"context"
"crypto/tls"
"fmt"
"strconv"
"time"
"net"
"net/http"
"sync"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/go-autostart"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
"github.com/ProtonMail/proton-bridge/v2/internal/config/useragent"
"github.com/ProtonMail/gluon"
"github.com/ProtonMail/gluon/watcher"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/metrics"
"github.com/ProtonMail/proton-bridge/v2/internal/sentry"
"github.com/ProtonMail/proton-bridge/v2/internal/store/cache"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/internal/users"
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
logrus "github.com/sirupsen/logrus"
"github.com/ProtonMail/proton-bridge/v2/internal/cookies"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/focus"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-smtp"
"github.com/go-resty/resty/v2"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi"
)
var log = logrus.WithField("pkg", "bridge") //nolint:gochecknoglobals
var ErrLocalCacheUnavailable = errors.New("local cache is unavailable")
type Bridge struct {
*users.Users
// vault holds bridge-specific data, such as preferences and known users (authorized or not).
vault *vault.Vault
locations Locator
settings SettingsProvider
clientManager pmapi.Manager
// users holds authorized users.
users map[string]*user.User
// api manages user API clients.
api *liteapi.Manager
cookieJar *cookies.Jar
proxyDialer ProxyDialer
identifier Identifier
// watchers holds all registered event watchers.
watchers []*watcher.Watcher[events.Event]
watchersLock sync.RWMutex
// tlsConfig holds the bridge TLS config used by the IMAP and SMTP servers.
tlsConfig *tls.Config
// imapServer is the bridge's IMAP server.
imapServer *gluon.Server
imapListener net.Listener
// smtpServer is the bridge's SMTP server.
smtpServer *smtp.Server
smtpBackend *smtpBackend
smtpListener net.Listener
// updater is the bridge's updater.
updater Updater
versioner Versioner
tls *tls.TLS
userAgent *useragent.UserAgent
cacheProvider CacheProvider
autostart *autostart.App
// Bridge's global errors list.
errors []error
curVersion *semver.Version
updateCheckCh chan struct{}
isAllMailVisible bool
isFirstStart bool
lastVersion string
// focusService is used to raise the bridge window when needed.
focusService *focus.FocusService
// autostarter is the bridge's autostarter.
autostarter Autostarter
// locator is the bridge's locator.
locator Locator
// errors contains errors encountered during startup.
errors []error
}
func New( //nolint:funlen
locations Locator,
cacheProvider CacheProvider,
setting SettingsProvider,
sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler,
eventListener listener.Listener,
tls *tls.TLS,
userAgent *useragent.UserAgent,
cache cache.Cache,
builder *message.Builder,
clientManager pmapi.Manager,
credStorer users.CredentialsStorer,
updater Updater,
versioner Versioner,
autostart *autostart.App,
) *Bridge {
// Allow DoH before starting the app if the user has previously set this setting.
// This allows us to start even if protonmail is blocked.
if setting.GetBool(settings.AllowProxyKey) {
clientManager.AllowProxy()
// New creates a new bridge.
func New(
apiURL string, // the URL of the API to use
locator Locator, // the locator to provide paths to store data
vault *vault.Vault, // the bridge's encrypted data store
identifier Identifier, // the identifier to keep track of the user agent
tlsReporter TLSReporter, // the TLS reporter to report TLS errors
proxyDialer ProxyDialer, // the DoH dialer
autostarter Autostarter, // the autostarter to manage autostart settings
updater Updater, // the updater to fetch and install updates
curVersion *semver.Version, // the current version of the bridge
) (*Bridge, error) {
if vault.GetProxyAllowed() {
proxyDialer.AllowProxy()
} else {
proxyDialer.DisallowProxy()
}
u := users.New(
locations,
panicHandler,
eventListener,
clientManager,
credStorer,
newStoreFactory(cacheProvider, sentryReporter, panicHandler, eventListener, cache, builder),
cookieJar, err := cookies.NewCookieJar(vault)
if err != nil {
return nil, fmt.Errorf("failed to create cookie jar: %w", err)
}
api := liteapi.New(
liteapi.WithHostURL(apiURL),
liteapi.WithAppVersion(constants.AppVersion),
liteapi.WithCookieJar(cookieJar),
liteapi.WithTransport(&http.Transport{DialTLSContext: proxyDialer.DialTLSContext}),
)
b := &Bridge{
Users: u,
locations: locations,
settings: setting,
clientManager: clientManager,
updater: updater,
versioner: versioner,
tls: tls,
userAgent: userAgent,
cacheProvider: cacheProvider,
autostart: autostart,
isFirstStart: false,
isAllMailVisible: setting.GetBool(settings.IsAllMailVisible),
}
if setting.GetBool(settings.FirstStartKey) {
b.isFirstStart = true
if err := b.SendMetric(metrics.New(metrics.Setup, metrics.FirstStart, metrics.Label(constants.Version))); err != nil {
logrus.WithError(err).Error("Failed to send metric")
}
setting.SetBool(settings.FirstStartKey, false)
}
// Keep in bridge and update in settings the last used version.
b.lastVersion = b.settings.Get(settings.LastVersionKey)
b.settings.Set(settings.LastVersionKey, constants.Version)
go b.heartbeat()
return b
}
// heartbeat sends a heartbeat signal once a day.
func (b *Bridge) heartbeat() {
for range time.Tick(time.Minute) {
lastHeartbeatDay, err := strconv.ParseInt(b.settings.Get(settings.LastHeartbeatKey), 10, 64)
if err != nil {
continue
}
// If we're still on the same day, don't send a heartbeat.
if time.Now().YearDay() == int(lastHeartbeatDay) {
continue
}
// We're on the next (or a different) day, so send a heartbeat.
if err := b.SendMetric(metrics.New(metrics.Heartbeat, metrics.Daily, metrics.NoLabel)); err != nil {
logrus.WithError(err).Error("Failed to send heartbeat")
continue
}
// Heartbeat was sent successfully so update the last heartbeat day.
b.settings.Set(settings.LastHeartbeatKey, fmt.Sprintf("%v", time.Now().YearDay()))
}
}
// GetUpdateChannel returns currently set update channel.
func (b *Bridge) GetUpdateChannel() updater.UpdateChannel {
return updater.UpdateChannel(b.settings.Get(settings.UpdateChannelKey))
}
// SetUpdateChannel switches update channel.
func (b *Bridge) SetUpdateChannel(channel updater.UpdateChannel) {
b.settings.Set(settings.UpdateChannelKey, string(channel))
}
func (b *Bridge) resetToLatestStable() error {
version, err := b.updater.Check()
tlsConfig, err := loadTLSConfig(vault)
if err != nil {
// If we can not check for updates - just remove all local updates and reset to base installer version.
// Not using `b.locations.ClearUpdates()` because `versioner.RemoveOtherVersions` can also handle
// case when it is needed to remove currently running verion.
if err := b.versioner.RemoveOtherVersions(semver.MustParse("0.0.0")); err != nil {
log.WithError(err).Error("Failed to clear updates while downgrading channel")
}
return nil, fmt.Errorf("failed to load TLS config: %w", err)
}
imapServer, err := newIMAPServer(vault.GetGluonDir(), curVersion, tlsConfig)
if err != nil {
return nil, fmt.Errorf("failed to create IMAP server: %w", err)
}
smtpBackend, err := newSMTPBackend()
if err != nil {
return nil, fmt.Errorf("failed to create SMTP backend: %w", err)
}
smtpServer, err := newSMTPServer(smtpBackend, tlsConfig)
if err != nil {
return nil, fmt.Errorf("failed to create SMTP server: %w", err)
}
focusService, err := focus.NewService()
if err != nil {
return nil, fmt.Errorf("failed to create focus service: %w", err)
}
bridge := &Bridge{
vault: vault,
users: make(map[string]*user.User),
api: api,
cookieJar: cookieJar,
proxyDialer: proxyDialer,
identifier: identifier,
tlsConfig: tlsConfig,
imapServer: imapServer,
smtpServer: smtpServer,
smtpBackend: smtpBackend,
updater: updater,
curVersion: curVersion,
updateCheckCh: make(chan struct{}, 1),
focusService: focusService,
autostarter: autostarter,
locator: locator,
}
api.AddStatusObserver(func(status liteapi.Status) {
bridge.publish(events.ConnStatus{
Status: status,
})
})
api.AddErrorHandler(liteapi.AppVersionBadCode, func() {
bridge.publish(events.UpdateForced{})
})
api.AddPreRequestHook(func(_ *resty.Client, req *resty.Request) error {
req.SetHeader("User-Agent", bridge.identifier.GetUserAgent())
return nil
})
go func() {
for range tlsReporter.GetTLSIssueCh() {
bridge.publish(events.TLSIssue{})
}
}()
go func() {
for range focusService.GetRaiseCh() {
bridge.publish(events.Raise{})
}
}()
go func() {
for event := range imapServer.AddWatcher() {
bridge.handleIMAPEvent(event)
}
}()
if err := bridge.loadUsers(context.Background()); err != nil {
return nil, fmt.Errorf("failed to load connected users: %w", err)
}
// If current version is same as upstream stable version - do nothing.
if version.Version.Equal(semver.MustParse(constants.Version)) {
return nil
if err := bridge.serveIMAP(); err != nil {
bridge.PushError(ErrServeIMAP)
}
if err := b.updater.InstallUpdate(version); err != nil {
return err
if err := bridge.serveSMTP(); err != nil {
bridge.PushError(ErrServeSMTP)
}
return b.versioner.RemoveOtherVersions(version.Version)
if err := bridge.watchForUpdates(); err != nil {
bridge.PushError(ErrWatchUpdates)
}
return bridge, nil
}
// FactoryReset will remove all local cache and settings.
// It will also downgrade to latest stable version if user is on early version.
func (b *Bridge) FactoryReset() {
wasEarly := b.GetUpdateChannel() == updater.EarlyChannel
// GetEvents returns a channel of events of the given type.
// If no types are supplied, all events are returned.
func (bridge *Bridge) GetEvents(ofType ...events.Event) (<-chan events.Event, func()) {
newWatcher := bridge.addWatcher(ofType...)
b.settings.Set(settings.UpdateChannelKey, string(updater.StableChannel))
return newWatcher.GetChannel(), func() { bridge.remWatcher(newWatcher) }
}
if wasEarly {
if err := b.resetToLatestStable(); err != nil {
log.WithError(err).Error("Failed to reset to latest stable version")
func (bridge *Bridge) FactoryReset(ctx context.Context) error {
panic("TODO")
}
func (bridge *Bridge) PushError(err error) {
bridge.errors = append(bridge.errors, err)
}
func (bridge *Bridge) GetErrors() []error {
return bridge.errors
}
func (bridge *Bridge) Close(ctx context.Context) error {
// Close the IMAP server.
if err := bridge.closeIMAP(ctx); err != nil {
logrus.WithError(err).Error("Failed to close IMAP server")
}
// Close the SMTP server.
if err := bridge.closeSMTP(); err != nil {
logrus.WithError(err).Error("Failed to close SMTP server")
}
// Close all users.
for _, user := range bridge.users {
if err := user.Close(ctx); err != nil {
logrus.WithError(err).Error("Failed to close user")
}
}
if err := b.Users.ClearData(); err != nil {
log.WithError(err).Error("Failed to remove bridge data")
// Persist the cookies.
if err := bridge.cookieJar.PersistCookies(); err != nil {
logrus.WithError(err).Error("Failed to persist cookies")
}
if err := b.Users.ClearUsers(); err != nil {
log.WithError(err).Error("Failed to remove bridge users")
// Close the focus service.
bridge.focusService.Close()
// Save the last version of bridge that was run.
if err := bridge.vault.SetLastVersion(bridge.curVersion); err != nil {
logrus.WithError(err).Error("Failed to save last version")
}
}
// GetKeychainApp returns current keychain helper.
func (b *Bridge) GetKeychainApp() string {
return b.settings.Get(settings.PreferredKeychainKey)
}
// SetKeychainApp sets current keychain helper.
func (b *Bridge) SetKeychainApp(helper string) {
b.settings.Set(settings.PreferredKeychainKey, helper)
}
func (b *Bridge) EnableCache() error {
if err := b.Users.EnableCache(); err != nil {
return err
}
b.settings.SetBool(settings.CacheEnabledKey, true)
return nil
}
func (b *Bridge) DisableCache() error {
if err := b.Users.DisableCache(); err != nil {
return err
}
func (bridge *Bridge) publish(event events.Event) {
bridge.watchersLock.RLock()
defer bridge.watchersLock.RUnlock()
b.settings.SetBool(settings.CacheEnabledKey, false)
// Reset back to the default location when disabling.
b.settings.Set(settings.CacheLocationKey, b.cacheProvider.GetDefaultMessageCacheDir())
return nil
}
func (b *Bridge) MigrateCache(from, to string) error {
if err := b.Users.MigrateCache(from, to); err != nil {
return err
}
b.settings.Set(settings.CacheLocationKey, to)
return nil
}
// SetProxyAllowed instructs the app whether to use DoH to access an API proxy if necessary.
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
func (b *Bridge) SetProxyAllowed(proxyAllowed bool) {
b.settings.SetBool(settings.AllowProxyKey, proxyAllowed)
if proxyAllowed {
b.clientManager.AllowProxy()
} else {
b.clientManager.DisallowProxy()
}
}
// GetProxyAllowed returns whether use of DoH is enabled to access an API proxy if necessary.
func (b *Bridge) GetProxyAllowed() bool {
return b.settings.GetBool(settings.AllowProxyKey)
}
// AddError add an error to a global error list if it does not contain it yet. Adding nil is noop.
func (b *Bridge) AddError(err error) {
if err == nil {
return
}
if b.HasError(err) {
return
}
b.errors = append(b.errors, err)
}
// DelError removes an error from global error list.
func (b *Bridge) DelError(err error) {
for idx, val := range b.errors {
if val == err {
b.errors = append(b.errors[:idx], b.errors[idx+1:]...)
return
for _, watcher := range bridge.watchers {
if watcher.IsWatching(event) {
if ok := watcher.Send(event); !ok {
logrus.WithField("event", event).Warn("Failed to send event to watcher")
}
}
}
}
// HasError returnes true if global error list contains an err.
func (b *Bridge) HasError(err error) bool {
for _, val := range b.errors {
if val == err {
return true
}
func (bridge *Bridge) addWatcher(ofType ...events.Event) *watcher.Watcher[events.Event] {
bridge.watchersLock.Lock()
defer bridge.watchersLock.Unlock()
newWatcher := watcher.New(ofType...)
bridge.watchers = append(bridge.watchers, newWatcher)
return newWatcher
}
func (bridge *Bridge) remWatcher(oldWatcher *watcher.Watcher[events.Event]) {
bridge.watchersLock.Lock()
defer bridge.watchersLock.Unlock()
bridge.watchers = xslices.Filter(bridge.watchers, func(other *watcher.Watcher[events.Event]) bool {
return other != oldWatcher
})
}
func loadTLSConfig(vault *vault.Vault) (*tls.Config, error) {
cert, err := tls.X509KeyPair(vault.GetBridgeTLSCert(), vault.GetBridgeTLSKey())
if err != nil {
return nil, err
}
return false
return &tls.Config{
Certificates: []tls.Certificate{cert},
}, nil
}
// GetLastVersion returns the version which was used in previous execution of
// Bridge.
func (b *Bridge) GetLastVersion() string {
return b.lastVersion
}
func newListener(port int, useTLS bool, tlsConfig *tls.Config) (net.Listener, error) {
if useTLS {
tlsListener, err := tls.Listen("tcp", fmt.Sprintf(":%v", port), tlsConfig)
if err != nil {
return nil, err
}
// IsFirstStart returns true when Bridge is running for first time or after
// factory reset.
func (b *Bridge) IsFirstStart() bool {
return b.isFirstStart
}
return tlsListener, nil
}
// IsAllMailVisible can be called extensively by IMAP. Therefore, it is better
// to cache the value instead of reading from settings file.
func (b *Bridge) IsAllMailVisible() bool {
return b.isAllMailVisible
}
netListener, err := net.Listen("tcp", fmt.Sprintf(":%v", port))
if err != nil {
return nil, err
}
func (b *Bridge) SetIsAllMailVisible(isVisible bool) {
b.settings.SetBool(settings.IsAllMailVisible, isVisible)
b.isAllMailVisible = isVisible
return netListener, nil
}

View File

@ -0,0 +1,362 @@
package bridge_test
import (
"context"
"testing"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/focus"
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server"
)
const (
username = "username"
password = "password"
)
var (
v2_3_0 = semver.MustParse("2.3.0")
v2_4_0 = semver.MustParse("2.4.0")
)
func TestBridge_ConnStatus(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of connection status events.
eventCh, done := bridge.GetEvents(events.ConnStatus{})
defer done()
// Simulate network disconnect.
mocks.TLSDialer.SetCanDial(false)
// Trigger some operation that will fail due to the network disconnect.
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.Error(t, err)
// Wait for the event.
require.Equal(t, events.ConnStatus{Status: liteapi.StatusDown}, <-eventCh)
// Simulate network reconnect.
mocks.TLSDialer.SetCanDial(true)
// Trigger some operation that will succeed due to the network reconnect.
userID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err)
require.NotEmpty(t, userID)
// Wait for the event.
require.Equal(t, events.ConnStatus{Status: liteapi.StatusUp}, <-eventCh)
})
})
}
func TestBridge_TLSIssue(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of TLS issue events.
tlsEventCh, done := bridge.GetEvents(events.TLSIssue{})
defer done()
// Simulate a TLS issue.
go func() {
mocks.TLSIssueCh <- struct{}{}
}()
// Wait for the event.
require.IsType(t, events.TLSIssue{}, <-tlsEventCh)
})
})
}
func TestBridge_Focus(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of TLS issue events.
raiseCh, done := bridge.GetEvents(events.Raise{})
defer done()
// Simulate a focus event.
focus.TryRaise()
// Wait for the event.
require.IsType(t, events.Raise{}, <-raiseCh)
})
})
}
func TestBridge_UserAgent(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
var calls []server.Call
s.AddCallWatcher(func(call server.Call) {
calls = append(calls, call)
})
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Set the platform to something other than the default.
bridge.SetCurrentPlatform("platform")
// Assert that the user agent then contains the platform.
require.Contains(t, bridge.GetCurrentUserAgent(), "platform")
// Login the user.
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err)
// Assert that the user agent was sent to the API.
require.Contains(t, calls[len(calls)-1].Request.Header.Get("User-Agent"), bridge.GetCurrentUserAgent())
})
})
}
func TestBridge_Cookies(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
var calls []server.Call
s.AddCallWatcher(func(call server.Call) {
calls = append(calls, call)
})
var sessionID string
// Start bridge and add a user so that API assigns us a session ID via cookie.
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err)
cookie, err := calls[len(calls)-1].Request.Cookie("Session-Id")
require.NoError(t, err)
sessionID = cookie.Value
})
// Start bridge again and check that it uses the same session ID.
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
cookie, err := calls[len(calls)-1].Request.Cookie("Session-Id")
require.NoError(t, err)
require.Equal(t, sessionID, cookie.Value)
})
})
}
func TestBridge_CheckUpdate(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Disable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(false))
// Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateAvailable{})
defer done()
// We are currently on the latest version.
bridge.CheckForUpdates()
require.Equal(t, events.UpdateNotAvailable{}, <-updateCh)
// Simulate a new version being available.
mocks.Updater.SetLatestVersion(v2_4_0, v2_3_0)
// Check for updates.
bridge.CheckForUpdates()
require.Equal(t, events.UpdateAvailable{
Version: updater.VersionInfo{
Version: v2_4_0,
MinAuto: v2_3_0,
RolloutProportion: 1.0,
},
CanInstall: true,
}, <-updateCh)
})
})
}
func TestBridge_AutoUpdate(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Enable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(true))
// Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateInstalled{})
defer done()
// Simulate a new version being available.
mocks.Updater.SetLatestVersion(v2_4_0, v2_3_0)
// Check for updates.
bridge.CheckForUpdates()
require.Equal(t, events.UpdateInstalled{
Version: updater.VersionInfo{
Version: v2_4_0,
MinAuto: v2_3_0,
RolloutProportion: 1.0,
},
}, <-updateCh)
})
})
}
func TestBridge_ManualUpdate(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Disable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(false))
// Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateAvailable{})
defer done()
// Simulate a new version being available, but it's too new for us.
mocks.Updater.SetLatestVersion(v2_4_0, v2_4_0)
// Check for updates.
bridge.CheckForUpdates()
require.Equal(t, events.UpdateAvailable{
Version: updater.VersionInfo{
Version: v2_4_0,
MinAuto: v2_4_0,
RolloutProportion: 1.0,
},
CanInstall: false,
}, <-updateCh)
})
})
}
func TestBridge_ForceUpdate(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateForced{})
defer done()
// Set the minimum accepted app version to something newer than the current version.
s.SetMinAppVersion(v2_4_0)
// Try to login the user. It will fail because the bridge is too old.
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.Error(t, err)
// We should get an update required event.
require.Equal(t, events.UpdateForced{}, <-updateCh)
})
})
}
func TestBridge_BadVaultKey(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
var userID string
// Login a user.
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
newUserID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err)
userID = newUserID
})
// Start bridge with the correct vault key -- it should load the users correctly.
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.ElementsMatch(t, []string{userID}, bridge.GetUserIDs())
})
// Start bridge with a bad vault key, the vault will be wiped and bridge will show no users.
withBridge(t, s.GetHostURL(), locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs())
})
// Start bridge with a nil vault key, the vault will be wiped and bridge will show no users.
withBridge(t, s.GetHostURL(), locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs())
})
})
}
// withEnv creates the full test environment and runs the tests.
func withEnv(t *testing.T, tests func(server *server.Server, locator bridge.Locator, vaultKey []byte)) {
// Create test API.
server := server.NewTLS()
defer server.Close()
// Add test user.
_, _, err := server.AddUser(username, password, username+"@pm.me")
require.NoError(t, err)
// Generate a random vault key.
vaultKey, err := crypto.RandomToken(32)
require.NoError(t, err)
// Run the tests.
tests(server, locations.New(bridge.NewTestLocationsProvider(t), "config-name"), vaultKey)
}
// withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done.
func withBridge(t *testing.T, apiURL string, locator bridge.Locator, vaultKey []byte, tests func(bridge *bridge.Bridge, mocks *bridge.Mocks)) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create the mock objects used in the tests.
mocks := bridge.NewMocks(t, v2_3_0, v2_3_0)
// Bridge will enable the proxy by default at startup.
mocks.ProxyDialer.EXPECT().AllowProxy()
// Get the path to the vault.
vaultDir, err := locator.ProvideSettingsPath()
require.NoError(t, err)
// Create the vault.
vault, _, err := vault.New(vaultDir, t.TempDir(), vaultKey)
require.NoError(t, err)
// Create a new bridge.
bridge, err := bridge.New(
apiURL,
locator,
vault,
useragent.New(),
mocks.TLSReporter,
mocks.ProxyDialer,
mocks.Autostarter,
mocks.Updater,
v2_3_0,
)
require.NoError(t, err)
// Use the bridge.
tests(bridge, mocks)
// Close the bridge.
require.NoError(t, bridge.Close(ctx))
}
// must is a helper function that panics on error.
func must[T any](val T, err error) T {
if err != nil {
panic(err)
}
return val
}
func getConnectedUserIDs(t *testing.T, bridge *bridge.Bridge) []string {
t.Helper()
return xslices.Filter(bridge.GetUserIDs(), func(userID string) bool {
info, err := bridge.GetUserInfo(userID)
require.NoError(t, err)
return info.Connected
})
}

View File

@ -21,67 +21,51 @@ import (
"archive/zip"
"bytes"
"context"
"errors"
"io"
"os"
"path/filepath"
"sort"
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"gitlab.protontech.ch/go/liteapi"
)
const (
MaxAttachmentSize = 7 * 1024 * 1024 // MaxAttachmentSize 7 MB total limit
MaxAttachmentSize = 7 * (1 << 20) // MaxAttachmentSize 7 MB total size of all attachments.
MaxCompressedFilesCount = 6
)
var ErrSizeTooLarge = errors.New("file is too big")
func (bridge *Bridge) ReportBug(ctx context.Context, osType, osVersion, description, username, email, client string, attachLogs bool) error {
var account string
// ReportBug reports a new bug from the user.
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string, attachLogs bool) error { //nolint:funlen
if user, err := b.GetUser(address); err == nil {
accountName = user.Username()
} else if users := b.GetUsers(); len(users) > 0 {
accountName = users[0].Username()
if info, err := bridge.QueryUserInfo(username); err == nil {
account = info.Username
} else if userIDs := bridge.GetUserIDs(); len(userIDs) > 0 {
account = bridge.users[userIDs[0]].Name()
}
report := pmapi.ReportBugReq{
OS: osType,
OSVersion: osVersion,
Browser: emailClient,
Title: "[Bridge] Bug",
Description: description,
Username: accountName,
Email: address,
}
var atts []liteapi.ReportBugAttachment
if attachLogs {
logs, err := b.getMatchingLogs(
func(filename string) bool {
return logging.MatchLogName(filename) && !logging.MatchStackTraceName(filename)
},
)
logs, err := getMatchingLogs(bridge.locator, func(filename string) bool {
return logging.MatchLogName(filename) && !logging.MatchStackTraceName(filename)
})
if err != nil {
log.WithError(err).Error("Can't get log files list")
return err
}
guiLogs, err := b.getMatchingLogs(
func(filename string) bool {
return logging.MatchGUILogName(filename) && !logging.MatchStackTraceName(filename)
},
)
guiLogs, err := getMatchingLogs(bridge.locator, func(filename string) bool {
return logging.MatchGUILogName(filename) && !logging.MatchStackTraceName(filename)
})
if err != nil {
log.WithError(err).Error("Can't get GUI log files list")
return err
}
crashes, err := b.getMatchingLogs(
func(filename string) bool {
return logging.MatchLogName(filename) && logging.MatchStackTraceName(filename)
},
)
crashes, err := getMatchingLogs(bridge.locator, func(filename string) bool {
return logging.MatchLogName(filename) && logging.MatchStackTraceName(filename)
})
if err != nil {
log.WithError(err).Error("Can't get crash files list")
return err
}
var matchFiles []string
@ -95,26 +79,42 @@ func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address,
archive, err := zipFiles(matchFiles)
if err != nil {
log.WithError(err).Error("Can't zip logs and crashes")
return err
}
if archive != nil {
report.AddAttachment("logs.zip", "application/zip", archive)
body, err := io.ReadAll(archive)
if err != nil {
return err
}
atts = append(atts, liteapi.ReportBugAttachment{
Name: "logs.zip",
Filename: "logs.zip",
MIMEType: "application/zip",
Body: body,
})
}
return b.clientManager.ReportBug(context.Background(), report)
return bridge.api.ReportBug(ctx, liteapi.ReportBugReq{
OS: osType,
OSVersion: osVersion,
Description: description,
Client: client,
Username: account,
Email: email,
}, atts...)
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
func (b *Bridge) getMatchingLogs(filenameMatchFunc func(string) bool) (filenames []string, err error) {
logsPath, err := b.locations.ProvideLogsPath()
func getMatchingLogs(locator Locator, filenameMatchFunc func(string) bool) (filenames []string, err error) {
logsPath, err := locator.ProvideLogsPath()
if err != nil {
return nil, err
}
@ -131,24 +131,25 @@ func (b *Bridge) getMatchingLogs(filenameMatchFunc func(string) bool) (filenames
matchFiles = append(matchFiles, filepath.Join(logsPath, file.Name()))
}
}
sort.Strings(matchFiles) // Sorted by timestamp: oldest first.
return matchFiles, nil
}
type LimitedBuffer struct {
type limitedBuffer struct {
capacity int
buf *bytes.Buffer
}
func NewLimitedBuffer(capacity int) *LimitedBuffer {
return &LimitedBuffer{
func newLimitedBuffer(capacity int) *limitedBuffer {
return &limitedBuffer{
capacity: capacity,
buf: bytes.NewBuffer(make([]byte, 0, capacity)),
}
}
func (b *LimitedBuffer) Write(p []byte) (n int, err error) {
func (b *limitedBuffer) Write(p []byte) (n int, err error) {
if len(p)+b.buf.Len() > b.capacity {
return 0, ErrSizeTooLarge
}
@ -156,7 +157,7 @@ func (b *LimitedBuffer) Write(p []byte) (n int, err error) {
return b.buf.Write(p)
}
func (b *LimitedBuffer) Read(p []byte) (n int, err error) {
func (b *limitedBuffer) Read(p []byte) (n int, err error) {
return b.buf.Read(p)
}
@ -165,14 +166,13 @@ func zipFiles(filenames []string) (io.Reader, error) {
return nil, nil
}
buf := NewLimitedBuffer(MaxAttachmentSize)
buf := newLimitedBuffer(MaxAttachmentSize)
w := zip.NewWriter(buf)
defer w.Close() //nolint:errcheck
for _, file := range filenames {
err := addFileToZip(file, w)
if err != nil {
if err := addFileToZip(file, w); err != nil {
return nil, err
}
}
@ -209,12 +209,9 @@ func addFileToZip(filename string, writer *zip.Writer) error {
return err
}
_, err = io.Copy(fileWriter, fileReader)
if err != nil {
if _, err := io.Copy(fileWriter, fileReader); err != nil {
return err
}
err = fileReader.Close()
return err
return fileReader.Close()
}

View File

@ -1,70 +1,38 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
import (
"strings"
"github.com/ProtonMail/proton-bridge/v2/internal/clientconfig"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/config/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
)
func (b *Bridge) ConfigureAppleMail(userID, address string) (bool, error) {
user, err := b.GetUser(userID)
if err != nil {
return false, err
func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
user, ok := bridge.users[userID]
if !ok {
return ErrNoSuchUser
}
if address == "" {
address = user.GetPrimaryAddress()
address = user.Addresses()[0]
}
username := address
addresses := address
if user.IsCombinedAddressMode() {
username = user.GetPrimaryAddress()
addresses = strings.Join(user.GetAddresses(), ",")
}
var (
restart = false
smtpSSL = b.settings.GetBool(settings.SMTPSSLKey)
)
// If configuring apple mail for Catalina or newer, users should use SSL.
if useragent.IsCatalinaOrNewer() && !smtpSSL {
smtpSSL = true
restart = true
b.settings.SetBool(settings.SMTPSSLKey, true)
if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() {
if err := bridge.SetSMTPSSL(true); err != nil {
return err
}
}
if err := (&clientconfig.AppleMail{}).Configure(
Host,
b.settings.GetInt(settings.IMAPPortKey),
b.settings.GetInt(settings.SMTPPortKey),
false, smtpSSL,
username, addresses,
user.GetBridgePassword(),
); err != nil {
return false, err
}
return restart, nil
return (&clientconfig.AppleMail{}).Configure(
constants.Host,
bridge.vault.GetIMAPPort(),
bridge.vault.GetSMTPPort(),
bridge.vault.GetIMAPSSL(),
bridge.vault.GetSMTPSSL(),
address,
strings.Join(user.Addresses(), ","),
user.BridgePass(),
)
}

View File

@ -1,23 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
// Host settings.
const (
Host = "127.0.0.1"
)

16
internal/bridge/errors.go Normal file
View File

@ -0,0 +1,16 @@
package bridge
import "errors"
var (
ErrServeIMAP = errors.New("failed to serve IMAP")
ErrServeSMTP = errors.New("failed to serve SMTP")
ErrWatchUpdates = errors.New("failed to watch for updates")
ErrNoSuchUser = errors.New("no such user")
ErrUserAlreadyExists = errors.New("user already exists")
ErrUserAlreadyLoggedIn = errors.New("user already logged in")
ErrNotImplemented = errors.New("not implemented")
ErrSizeTooLarge = errors.New("file is too big")
)

67
internal/bridge/files.go Normal file
View File

@ -0,0 +1,67 @@
package bridge
import (
"os"
"path/filepath"
)
func moveDir(from, to string) error {
entries, err := os.ReadDir(from)
if err != nil {
return err
}
for _, entry := range entries {
if entry.IsDir() {
if err := os.Mkdir(filepath.Join(to, entry.Name()), 0700); err != nil {
return err
}
if err := moveDir(filepath.Join(from, entry.Name()), filepath.Join(to, entry.Name())); err != nil {
return err
}
if err := os.RemoveAll(filepath.Join(from, entry.Name())); err != nil {
return err
}
} else {
if err := move(filepath.Join(from, entry.Name()), filepath.Join(to, entry.Name())); err != nil {
return err
}
}
}
return os.Remove(from)
}
func move(from, to string) error {
if err := os.MkdirAll(filepath.Dir(to), 0700); err != nil {
return err
}
f, err := os.Open(from)
if err != nil {
return err
}
defer f.Close()
c, err := os.Create(to)
if err != nil {
return err
}
defer c.Close()
if err := os.Chmod(to, 0600); err != nil {
return err
}
if _, err := c.ReadFrom(f); err != nil {
return err
}
if err := os.Remove(from); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,56 @@
package bridge
import (
"os"
"path/filepath"
"testing"
)
func TestMoveDir(t *testing.T) {
from, to := t.TempDir(), t.TempDir()
// Create some files in from.
if err := os.WriteFile(filepath.Join(from, "a"), []byte("a"), 0600); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(from, "b"), []byte("b"), 0600); err != nil {
t.Fatal(err)
}
if err := os.Mkdir(filepath.Join(from, "c"), 0700); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(from, "c", "d"), []byte("d"), 0600); err != nil {
t.Fatal(err)
}
// Move the files.
if err := moveDir(from, to); err != nil {
t.Fatal(err)
}
// Check that the files were moved.
if _, err := os.Stat(filepath.Join(from, "a")); !os.IsNotExist(err) {
t.Fatal(err)
}
if _, err := os.Stat(filepath.Join(to, "a")); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(filepath.Join(from, "b")); !os.IsNotExist(err) {
t.Fatal(err)
}
if _, err := os.Stat(filepath.Join(to, "b")); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(filepath.Join(from, "c")); !os.IsNotExist(err) {
t.Fatal(err)
}
if _, err := os.Stat(filepath.Join(to, "c")); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(filepath.Join(from, "c", "d")); !os.IsNotExist(err) {
t.Fatal(err)
}
if _, err := os.Stat(filepath.Join(to, "c", "d")); err != nil {
t.Fatal(err)
}
}

117
internal/bridge/imap.go Normal file
View File

@ -0,0 +1,117 @@
package bridge
import (
"context"
"crypto/tls"
"fmt"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon"
imapEvents "github.com/ProtonMail/gluon/events"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/sirupsen/logrus"
)
const (
defaultClientName = "UnknownClient"
defaultClientVersion = "0.0.1"
)
func (bridge *Bridge) GetIMAPPort() int {
return bridge.vault.GetIMAPPort()
}
func (bridge *Bridge) SetIMAPPort(newPort int) error {
if newPort == bridge.vault.GetIMAPPort() {
return nil
}
if err := bridge.vault.SetIMAPPort(newPort); err != nil {
return err
}
return bridge.restartIMAP(context.Background())
}
func (bridge *Bridge) GetIMAPSSL() bool {
return bridge.vault.GetIMAPSSL()
}
func (bridge *Bridge) SetIMAPSSL(newSSL bool) error {
if newSSL == bridge.vault.GetIMAPSSL() {
return nil
}
if err := bridge.vault.SetIMAPSSL(newSSL); err != nil {
return err
}
return bridge.restartIMAP(context.Background())
}
func (bridge *Bridge) serveIMAP() error {
imapListener, err := newListener(bridge.vault.GetIMAPPort(), bridge.vault.GetIMAPSSL(), bridge.tlsConfig)
if err != nil {
return fmt.Errorf("failed to create IMAP listener: %w", err)
}
bridge.imapListener = imapListener
return bridge.imapServer.Serve(context.Background(), bridge.imapListener)
}
func (bridge *Bridge) restartIMAP(ctx context.Context) error {
if err := bridge.imapListener.Close(); err != nil {
logrus.WithError(err).Warn("Failed to close IMAP listener")
}
return bridge.serveIMAP()
}
func (bridge *Bridge) closeIMAP(ctx context.Context) error {
if err := bridge.imapServer.Close(ctx); err != nil {
logrus.WithError(err).Warn("Failed to close IMAP server")
}
if err := bridge.imapListener.Close(); err != nil {
logrus.WithError(err).Warn("Failed to close IMAP listener")
}
return nil
}
func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) {
switch event := event.(type) {
case imapEvents.SessionAdded:
if !bridge.identifier.HasClient() {
bridge.identifier.SetClient(defaultClientName, defaultClientVersion)
}
case imapEvents.IMAPID:
bridge.identifier.SetClient(event.IMAPID.Name, event.IMAPID.Version)
}
}
func newIMAPServer(gluonDir string, version *semver.Version, tlsConfig *tls.Config) (*gluon.Server, error) {
imapServer, err := gluon.New(
gluon.WithTLS(tlsConfig),
gluon.WithDataDir(gluonDir),
gluon.WithVersionInfo(
int(version.Major()),
int(version.Minor()),
int(version.Patch()),
constants.FullAppName,
"TODO",
"TODO",
),
gluon.WithLogger(
logrus.StandardLogger().WriterLevel(logrus.InfoLevel),
logrus.StandardLogger().WriterLevel(logrus.InfoLevel),
),
)
if err != nil {
return nil, err
}
return imapServer, nil
}

View File

@ -1,30 +1,13 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
func (b *Bridge) ProvideLogsPath() (string, error) {
return b.locations.ProvideLogsPath()
func (bridge *Bridge) GetLogsPath() (string, error) {
return bridge.locator.ProvideLogsPath()
}
func (b *Bridge) GetLicenseFilePath() string {
return b.locations.GetLicenseFilePath()
func (bridge *Bridge) GetLicenseFilePath() string {
return bridge.locator.GetLicenseFilePath()
}
func (b *Bridge) GetDependencyLicensesLink() string {
return b.locations.GetDependencyLicensesLink()
func (bridge *Bridge) GetDependencyLicensesLink() string {
return bridge.locator.GetDependencyLicensesLink()
}

127
internal/bridge/mocks.go Normal file
View File

@ -0,0 +1,127 @@
package bridge
import (
"context"
"crypto/tls"
"errors"
"net"
"testing"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge/mocks"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/golang/mock/gomock"
)
type Mocks struct {
TLSDialer *TestDialer
ProxyDialer *mocks.MockProxyDialer
TLSReporter *mocks.MockTLSReporter
TLSIssueCh chan struct{}
Updater *TestUpdater
Autostarter *mocks.MockAutostarter
}
func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks {
ctl := gomock.NewController(tb)
mocks := &Mocks{
TLSDialer: NewTestDialer(),
ProxyDialer: mocks.NewMockProxyDialer(ctl),
TLSReporter: mocks.NewMockTLSReporter(ctl),
TLSIssueCh: make(chan struct{}),
Updater: NewTestUpdater(version, minAuto),
Autostarter: mocks.NewMockAutostarter(ctl),
}
// When using the proxy dialer, we want to use the test dialer.
mocks.ProxyDialer.EXPECT().DialTLSContext(
gomock.Any(),
gomock.Any(),
gomock.Any(),
).DoAndReturn(func(ctx context.Context, network, address string) (net.Conn, error) {
return mocks.TLSDialer.DialTLSContext(ctx, network, address)
}).AnyTimes()
// When getting the TLS issue channel, we want to return the test channel.
mocks.TLSReporter.EXPECT().GetTLSIssueCh().Return(mocks.TLSIssueCh).AnyTimes()
return mocks
}
type TestDialer struct {
canDial bool
}
func NewTestDialer() *TestDialer {
return &TestDialer{
canDial: true,
}
}
func (d *TestDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
if !d.canDial {
return nil, errors.New("cannot dial")
}
return (&tls.Dialer{Config: &tls.Config{InsecureSkipVerify: true}}).DialContext(ctx, network, address)
}
func (d *TestDialer) SetCanDial(canDial bool) {
d.canDial = canDial
}
type TestLocationsProvider struct {
config, cache string
}
func NewTestLocationsProvider(tb testing.TB) *TestLocationsProvider {
return &TestLocationsProvider{
config: tb.TempDir(),
cache: tb.TempDir(),
}
}
func (provider *TestLocationsProvider) UserConfig() string {
return provider.config
}
func (provider *TestLocationsProvider) UserCache() string {
return provider.cache
}
type TestUpdater struct {
latest updater.VersionInfo
}
func NewTestUpdater(version, minAuto *semver.Version) *TestUpdater {
return &TestUpdater{
latest: updater.VersionInfo{
Version: version,
MinAuto: minAuto,
RolloutProportion: 1.0,
},
}
}
func (testUpdater *TestUpdater) SetLatestVersion(version, minAuto *semver.Version) {
testUpdater.latest = updater.VersionInfo{
Version: version,
MinAuto: minAuto,
RolloutProportion: 1.0,
}
}
func (updater *TestUpdater) GetVersionInfo(downloader updater.Downloader, channel updater.Channel) (updater.VersionInfo, error) {
return updater.latest, nil
}
func (updater *TestUpdater) InstallUpdate(downloader updater.Downloader, update updater.VersionInfo) error {
return nil
}

View File

@ -0,0 +1,163 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/v2/internal/bridge (interfaces: TLSReporter,ProxyDialer,Autostarter)
// Package mocks is a generated GoMock package.
package mocks
import (
context "context"
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockTLSReporter is a mock of TLSReporter interface.
type MockTLSReporter struct {
ctrl *gomock.Controller
recorder *MockTLSReporterMockRecorder
}
// MockTLSReporterMockRecorder is the mock recorder for MockTLSReporter.
type MockTLSReporterMockRecorder struct {
mock *MockTLSReporter
}
// NewMockTLSReporter creates a new mock instance.
func NewMockTLSReporter(ctrl *gomock.Controller) *MockTLSReporter {
mock := &MockTLSReporter{ctrl: ctrl}
mock.recorder = &MockTLSReporterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTLSReporter) EXPECT() *MockTLSReporterMockRecorder {
return m.recorder
}
// GetTLSIssueCh mocks base method.
func (m *MockTLSReporter) GetTLSIssueCh() <-chan struct{} {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTLSIssueCh")
ret0, _ := ret[0].(<-chan struct{})
return ret0
}
// GetTLSIssueCh indicates an expected call of GetTLSIssueCh.
func (mr *MockTLSReporterMockRecorder) GetTLSIssueCh() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTLSIssueCh", reflect.TypeOf((*MockTLSReporter)(nil).GetTLSIssueCh))
}
// MockProxyDialer is a mock of ProxyDialer interface.
type MockProxyDialer struct {
ctrl *gomock.Controller
recorder *MockProxyDialerMockRecorder
}
// MockProxyDialerMockRecorder is the mock recorder for MockProxyDialer.
type MockProxyDialerMockRecorder struct {
mock *MockProxyDialer
}
// NewMockProxyDialer creates a new mock instance.
func NewMockProxyDialer(ctrl *gomock.Controller) *MockProxyDialer {
mock := &MockProxyDialer{ctrl: ctrl}
mock.recorder = &MockProxyDialerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProxyDialer) EXPECT() *MockProxyDialerMockRecorder {
return m.recorder
}
// AllowProxy mocks base method.
func (m *MockProxyDialer) AllowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AllowProxy")
}
// AllowProxy indicates an expected call of AllowProxy.
func (mr *MockProxyDialerMockRecorder) AllowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockProxyDialer)(nil).AllowProxy))
}
// DialTLSContext mocks base method.
func (m *MockProxyDialer) DialTLSContext(arg0 context.Context, arg1, arg2 string) (net.Conn, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DialTLSContext", arg0, arg1, arg2)
ret0, _ := ret[0].(net.Conn)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DialTLSContext indicates an expected call of DialTLSContext.
func (mr *MockProxyDialerMockRecorder) DialTLSContext(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTLSContext", reflect.TypeOf((*MockProxyDialer)(nil).DialTLSContext), arg0, arg1, arg2)
}
// DisallowProxy mocks base method.
func (m *MockProxyDialer) DisallowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DisallowProxy")
}
// DisallowProxy indicates an expected call of DisallowProxy.
func (mr *MockProxyDialerMockRecorder) DisallowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockProxyDialer)(nil).DisallowProxy))
}
// MockAutostarter is a mock of Autostarter interface.
type MockAutostarter struct {
ctrl *gomock.Controller
recorder *MockAutostarterMockRecorder
}
// MockAutostarterMockRecorder is the mock recorder for MockAutostarter.
type MockAutostarterMockRecorder struct {
mock *MockAutostarter
}
// NewMockAutostarter creates a new mock instance.
func NewMockAutostarter(ctrl *gomock.Controller) *MockAutostarter {
mock := &MockAutostarter{ctrl: ctrl}
mock.recorder = &MockAutostarterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockAutostarter) EXPECT() *MockAutostarterMockRecorder {
return m.recorder
}
// Disable mocks base method.
func (m *MockAutostarter) Disable() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Disable")
ret0, _ := ret[0].(error)
return ret0
}
// Disable indicates an expected call of Disable.
func (mr *MockAutostarterMockRecorder) Disable() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disable", reflect.TypeOf((*MockAutostarter)(nil).Disable))
}
// Enable mocks base method.
func (m *MockAutostarter) Enable() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Enable")
ret0, _ := ret[0].(error)
return ret0
}
// Enable indicates an expected call of Enable.
func (mr *MockAutostarterMockRecorder) Enable() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Enable", reflect.TypeOf((*MockAutostarter)(nil).Enable))
}

View File

@ -1,26 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Code generated by ./release-notes.sh at 'Fri Jan 22 11:01:06 AM CET 2021'. DO NOT EDIT.
package bridge
const ReleaseNotes = `
`
const ReleaseFixedBugs = `• Fixed sending error caused by inconsistent use of upper and lower case in senders email address
`

View File

@ -1,44 +1,175 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
import "github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
import (
"context"
func (b *Bridge) Get(key settings.Key) string {
return b.settings.Get(key)
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
)
func (bridge *Bridge) GetKeychainApp() (string, error) {
vaultDir, err := bridge.locator.ProvideSettingsPath()
if err != nil {
return "", err
}
return vault.GetHelper(vaultDir)
}
func (b *Bridge) Set(key settings.Key, value string) {
b.settings.Set(key, value)
func (bridge *Bridge) SetKeychainApp(helper string) error {
vaultDir, err := bridge.locator.ProvideSettingsPath()
if err != nil {
return err
}
return vault.SetHelper(vaultDir, helper)
}
func (b *Bridge) GetBool(key settings.Key) bool {
return b.settings.GetBool(key)
func (bridge *Bridge) GetGluonDir() string {
return bridge.vault.GetGluonDir()
}
func (b *Bridge) SetBool(key settings.Key, value bool) {
b.settings.SetBool(key, value)
func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error {
if newGluonDir == bridge.GetGluonDir() {
return nil
}
if err := bridge.closeIMAP(context.Background()); err != nil {
return err
}
if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil {
return err
}
if err := bridge.vault.SetGluonDir(newGluonDir); err != nil {
return err
}
imapServer, err := newIMAPServer(bridge.vault.GetGluonDir(), bridge.curVersion, bridge.tlsConfig)
if err != nil {
return err
}
for _, user := range bridge.users {
imapConn, err := user.NewGluonConnector(ctx)
if err != nil {
return err
}
if err := imapServer.LoadUser(context.Background(), imapConn, user.GluonID(), user.GluonKey()); err != nil {
return err
}
}
bridge.imapServer = imapServer
return bridge.serveIMAP()
}
func (b *Bridge) GetInt(key settings.Key) int {
return b.settings.GetInt(key)
func (bridge *Bridge) GetProxyAllowed() bool {
return bridge.vault.GetProxyAllowed()
}
func (b *Bridge) SetInt(key settings.Key, value int) {
b.settings.SetInt(key, value)
func (bridge *Bridge) SetProxyAllowed(allowed bool) error {
if allowed {
bridge.proxyDialer.AllowProxy()
} else {
bridge.proxyDialer.DisallowProxy()
}
return bridge.vault.SetProxyAllowed(allowed)
}
func (bridge *Bridge) GetShowAllMail() bool {
return bridge.vault.GetShowAllMail()
}
func (bridge *Bridge) SetShowAllMail(show bool) error {
panic("TODO")
}
func (bridge *Bridge) GetAutostart() bool {
return bridge.vault.GetAutostart()
}
func (bridge *Bridge) SetAutostart(autostart bool) error {
if err := bridge.vault.SetAutostart(autostart); err != nil {
return err
}
var err error
if autostart {
err = bridge.autostarter.Enable()
} else {
err = bridge.autostarter.Disable()
}
return err
}
func (bridge *Bridge) GetAutoUpdate() bool {
return bridge.vault.GetAutoUpdate()
}
func (bridge *Bridge) SetAutoUpdate(autoUpdate bool) error {
if bridge.vault.GetAutoUpdate() == autoUpdate {
return nil
}
if err := bridge.vault.SetAutoUpdate(autoUpdate); err != nil {
return err
}
bridge.updateCheckCh <- struct{}{}
return nil
}
func (bridge *Bridge) GetUpdateChannel() updater.Channel {
return updater.Channel(bridge.vault.GetUpdateChannel())
}
func (bridge *Bridge) SetUpdateChannel(channel updater.Channel) error {
if bridge.vault.GetUpdateChannel() == channel {
return nil
}
if err := bridge.vault.SetUpdateChannel(channel); err != nil {
return err
}
bridge.updateCheckCh <- struct{}{}
return nil
}
func (bridge *Bridge) GetLastVersion() *semver.Version {
return bridge.vault.GetLastVersion()
}
func (bridge *Bridge) GetFirstStart() bool {
return bridge.vault.GetFirstStart()
}
func (bridge *Bridge) SetFirstStart(firstStart bool) error {
return bridge.vault.SetFirstStart(firstStart)
}
func (bridge *Bridge) GetFirstStartGUI() bool {
return bridge.vault.GetFirstStartGUI()
}
func (bridge *Bridge) SetFirstStartGUI(firstStart bool) error {
return bridge.vault.SetFirstStartGUI(firstStart)
}
func (bridge *Bridge) GetColorScheme() string {
return bridge.vault.GetColorScheme()
}
func (bridge *Bridge) SetColorScheme(colorScheme string) error {
return bridge.vault.SetColorScheme(colorScheme)
}

View File

@ -0,0 +1,156 @@
package bridge_test
import (
"context"
"os"
"testing"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi/server"
)
func TestBridge_Settings_GluonDir(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Create a user.
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err)
// Create a new location for the Gluon data.
newGluonDir := t.TempDir()
// Move the gluon dir; it should also move the user's data.
require.NoError(t, bridge.SetGluonDir(context.Background(), newGluonDir))
// Check that the new directory is not empty.
entries, err := os.ReadDir(newGluonDir)
require.NoError(t, err)
// There should be at least one entry.
require.NotEmpty(t, entries)
})
})
}
func TestBridge_Settings_IMAPPort(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, the port is 1143.
require.Equal(t, 1143, bridge.GetIMAPPort())
// Set the port to 1144.
require.NoError(t, bridge.SetIMAPPort(1144))
// Get the new setting.
require.Equal(t, 1144, bridge.GetIMAPPort())
})
})
}
func TestBridge_Settings_IMAPSSL(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, IMAP SSL is disabled.
require.False(t, bridge.GetIMAPSSL())
// Enable IMAP SSL.
require.NoError(t, bridge.SetIMAPSSL(true))
// Get the new setting.
require.True(t, bridge.GetIMAPSSL())
})
})
}
func TestBridge_Settings_SMTPPort(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, the port is 1025.
require.Equal(t, 1025, bridge.GetSMTPPort())
// Set the port to 1024.
require.NoError(t, bridge.SetSMTPPort(1024))
// Get the new setting.
require.Equal(t, 1024, bridge.GetSMTPPort())
})
})
}
func TestBridge_Settings_SMTPSSL(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, SMTP SSL is disabled.
require.False(t, bridge.GetSMTPSSL())
// Enable SMTP SSL.
require.NoError(t, bridge.SetSMTPSSL(true))
// Get the new setting.
require.True(t, bridge.GetSMTPSSL())
})
})
}
func TestBridge_Settings_Proxy(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, proxy is allowed.
require.True(t, bridge.GetProxyAllowed())
// Disallow proxy.
mocks.ProxyDialer.EXPECT().DisallowProxy()
require.NoError(t, bridge.SetProxyAllowed(false))
// Get the new setting.
require.False(t, bridge.GetProxyAllowed())
})
})
}
func TestBridge_Settings_Autostart(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, autostart is disabled.
require.False(t, bridge.GetAutostart())
// Enable autostart.
mocks.Autostarter.EXPECT().Enable().Return(nil)
require.NoError(t, bridge.SetAutostart(true))
// Get the new setting.
require.True(t, bridge.GetAutostart())
})
})
}
func TestBridge_Settings_FirstStart(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, first start is true.
require.True(t, bridge.GetFirstStart())
// Set first start to false.
require.NoError(t, bridge.SetFirstStart(false))
// Get the new setting.
require.False(t, bridge.GetFirstStart())
})
})
}
func TestBridge_Settings_FirstStartGUI(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, first start is true.
require.True(t, bridge.GetFirstStartGUI())
// Set first start to false.
require.NoError(t, bridge.SetFirstStartGUI(false))
// Get the new setting.
require.False(t, bridge.GetFirstStartGUI())
})
})
}

109
internal/bridge/smtp.go Normal file
View File

@ -0,0 +1,109 @@
package bridge
import (
"crypto/tls"
"fmt"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/emersion/go-sasl"
"github.com/emersion/go-smtp"
"github.com/sirupsen/logrus"
)
func (bridge *Bridge) GetSMTPPort() int {
return bridge.vault.GetSMTPPort()
}
func (bridge *Bridge) SetSMTPPort(newPort int) error {
if newPort == bridge.vault.GetSMTPPort() {
return nil
}
if err := bridge.vault.SetSMTPPort(newPort); err != nil {
return err
}
return bridge.restartSMTP()
}
func (bridge *Bridge) GetSMTPSSL() bool {
return bridge.vault.GetSMTPSSL()
}
func (bridge *Bridge) SetSMTPSSL(newSSL bool) error {
if newSSL == bridge.vault.GetSMTPSSL() {
return nil
}
if err := bridge.vault.SetSMTPSSL(newSSL); err != nil {
return err
}
return bridge.restartSMTP()
}
func (bridge *Bridge) serveSMTP() error {
smtpListener, err := newListener(bridge.vault.GetSMTPPort(), bridge.vault.GetSMTPSSL(), bridge.tlsConfig)
if err != nil {
return fmt.Errorf("failed to create SMTP listener: %w", err)
}
bridge.smtpListener = smtpListener
go func() {
if err := bridge.smtpServer.Serve(bridge.smtpListener); err != nil {
logrus.WithError(err).Error("SMTP server stopped")
}
}()
return nil
}
func (bridge *Bridge) restartSMTP() error {
if err := bridge.closeSMTP(); err != nil {
return err
}
smtpServer, err := newSMTPServer(bridge.smtpBackend, bridge.tlsConfig)
if err != nil {
return err
}
bridge.smtpServer = smtpServer
return bridge.serveSMTP()
}
func (bridge *Bridge) closeSMTP() error {
if err := bridge.smtpServer.Close(); err != nil {
logrus.WithError(err).Warn("Failed to close SMTP server")
}
// Don't close the SMTP listener -- it's closed by the server.
return nil
}
func newSMTPServer(smtpBackend *smtpBackend, tlsConfig *tls.Config) (*smtp.Server, error) {
smtpServer := smtp.NewServer(smtpBackend)
smtpServer.TLSConfig = tlsConfig
smtpServer.Domain = constants.Host
smtpServer.AllowInsecureAuth = true
smtpServer.MaxLineLength = 1 << 16
smtpServer.EnableAuth(sasl.Login, func(conn *smtp.Conn) sasl.Server {
return sasl.NewLoginServer(func(address, password string) error {
user, err := conn.Server().Backend.Login(nil, address, password)
if err != nil {
return err
}
conn.SetSession(user)
return nil
})
})
return smtpServer, nil
}

View File

@ -0,0 +1,70 @@
package bridge
import (
"sync"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-smtp"
"golang.org/x/exp/slices"
)
type smtpBackend struct {
users []*user.User
usersLock sync.RWMutex
}
func newSMTPBackend() (*smtpBackend, error) {
return &smtpBackend{}, nil
}
func (backend *smtpBackend) Login(state *smtp.ConnectionState, username string, password string) (smtp.Session, error) {
backend.usersLock.RLock()
defer backend.usersLock.RUnlock()
for _, user := range backend.users {
if slices.Contains(user.Addresses(), username) && user.BridgePass() == password {
return user.NewSMTPSession(username)
}
}
return nil, ErrNoSuchUser
}
func (backend *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
return nil, ErrNotImplemented
}
// addUser adds the given user to the backend.
// It returns an error if a user with the same ID already exists.
func (backend *smtpBackend) addUser(user *user.User) error {
backend.usersLock.Lock()
defer backend.usersLock.Unlock()
for _, u := range backend.users {
if u.ID() == user.ID() {
return ErrUserAlreadyExists
}
}
backend.users = append(backend.users, user)
return nil
}
// removeUser removes the given user from the backend.
// It returns an error if the user doesn't exist.
func (backend *smtpBackend) removeUser(user *user.User) error {
backend.usersLock.Lock()
defer backend.usersLock.Unlock()
idx := xslices.Index(backend.users, user)
if idx < 0 {
return ErrNoSuchUser
}
backend.users = append(backend.users[:idx], backend.users[idx+1:]...)
return nil
}

View File

@ -1,87 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
import (
"fmt"
"path/filepath"
"github.com/ProtonMail/proton-bridge/v2/internal/sentry"
"github.com/ProtonMail/proton-bridge/v2/internal/store"
"github.com/ProtonMail/proton-bridge/v2/internal/store/cache"
"github.com/ProtonMail/proton-bridge/v2/internal/users"
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
)
type storeFactory struct {
cacheProvider CacheProvider
sentryReporter *sentry.Reporter
panicHandler users.PanicHandler
eventListener listener.Listener
events *store.Events
cache cache.Cache
builder *message.Builder
}
func newStoreFactory(
cacheProvider CacheProvider,
sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler,
eventListener listener.Listener,
cache cache.Cache,
builder *message.Builder,
) *storeFactory {
return &storeFactory{
cacheProvider: cacheProvider,
sentryReporter: sentryReporter,
panicHandler: panicHandler,
eventListener: eventListener,
events: store.NewEvents(cacheProvider.GetIMAPCachePath()),
cache: cache,
builder: builder,
}
}
// New creates new store for given user.
func (f *storeFactory) New(user store.BridgeUser) (*store.Store, error) {
return store.New(
f.sentryReporter,
f.panicHandler,
user,
f.eventListener,
f.cache,
f.builder,
getUserStorePath(f.cacheProvider.GetDBDir(), user.ID()),
f.events,
)
}
// Remove removes all store files for given user.
func (f *storeFactory) Remove(userID string) error {
return store.RemoveStore(
f.events,
getUserStorePath(f.cacheProvider.GetDBDir(), userID),
userID,
)
}
// getUserStorePath returns the file path of the store database for the given userID.
func getUserStorePath(storeDir string, userID string) (path string) {
return filepath.Join(storeDir, fmt.Sprintf("mailbox-%v.db", userID))
}

View File

@ -1,64 +1,5 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
import (
"crypto/tls"
pkgTLS "github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
"github.com/pkg/errors"
logrus "github.com/sirupsen/logrus"
)
func (b *Bridge) GetTLSConfig() (*tls.Config, error) {
if !b.tls.HasCerts() {
if err := b.generateTLSCerts(); err != nil {
return nil, err
}
}
tlsConfig, err := b.tls.GetConfig()
if err == nil {
return tlsConfig, nil
}
logrus.WithError(err).Error("Failed to load TLS config, regenerating certificates")
if err := b.generateTLSCerts(); err != nil {
return nil, err
}
return b.tls.GetConfig()
}
func (b *Bridge) generateTLSCerts() error {
template, err := pkgTLS.NewTLSTemplate()
if err != nil {
return errors.Wrap(err, "failed to generate TLS template")
}
if err := b.tls.GenerateCerts(template); err != nil {
return errors.Wrap(err, "failed to generate TLS certs")
}
if err := b.tls.InstallCerts(); err != nil {
return errors.Wrap(err, "failed to install TLS certs")
}
return nil
func (bridge *Bridge) GetBridgeTLSCert() ([]byte, []byte) {
return bridge.vault.GetBridgeTLSCert(), bridge.vault.GetBridgeTLSKey()
}

View File

@ -1,62 +1,43 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
import (
"github.com/Masterminds/semver/v3"
"context"
"net"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
)
type Locator interface {
ProvideSettingsPath() (string, error)
ProvideLogsPath() (string, error)
GetLicenseFilePath() string
GetDependencyLicensesLink() string
Clear() error
ClearUpdates() error
}
type CacheProvider interface {
GetIMAPCachePath() string
GetDBDir() string
GetDefaultMessageCacheDir() string
type Identifier interface {
GetUserAgent() string
HasClient() bool
SetClient(name, version string)
SetPlatform(platform string)
}
type SettingsProvider interface {
Get(key settings.Key) string
Set(key settings.Key, value string)
type TLSReporter interface {
GetTLSIssueCh() <-chan struct{}
}
GetBool(key settings.Key) bool
SetBool(key settings.Key, val bool)
type ProxyDialer interface {
DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error)
GetInt(key settings.Key) int
SetInt(key settings.Key, val int)
AllowProxy()
DisallowProxy()
}
type Autostarter interface {
Enable() error
Disable() error
}
type Updater interface {
Check() (updater.VersionInfo, error)
IsDowngrade(updater.VersionInfo) bool
InstallUpdate(updater.VersionInfo) error
}
type Versioner interface {
RemoveOtherVersions(*semver.Version) error
GetVersionInfo(downloader updater.Downloader, channel updater.Channel) (updater.VersionInfo, error)
InstallUpdate(downloader updater.Downloader, update updater.VersionInfo) error
}

View File

@ -0,0 +1,72 @@
package bridge
import (
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
)
func (bridge *Bridge) CheckForUpdates() {
bridge.updateCheckCh <- struct{}{}
}
func (bridge *Bridge) watchForUpdates() error {
ticker := time.NewTicker(constants.UpdateCheckInterval)
go func() {
for {
select {
case <-bridge.updateCheckCh:
case <-ticker.C:
}
version, err := bridge.updater.GetVersionInfo(bridge.api, bridge.vault.GetUpdateChannel())
if err != nil {
continue
}
if err := bridge.handleUpdate(version); err != nil {
continue
}
}
}()
bridge.updateCheckCh <- struct{}{}
return nil
}
func (bridge *Bridge) handleUpdate(version updater.VersionInfo) error {
switch {
case !version.Version.GreaterThan(bridge.curVersion):
bridge.publish(events.UpdateNotAvailable{})
case version.RolloutProportion < bridge.vault.GetUpdateRollout():
bridge.publish(events.UpdateNotAvailable{})
case bridge.curVersion.LessThan(version.MinAuto):
bridge.publish(events.UpdateAvailable{
Version: version,
CanInstall: false,
})
case !bridge.vault.GetAutoUpdate():
bridge.publish(events.UpdateAvailable{
Version: version,
CanInstall: true,
})
default:
if err := bridge.updater.InstallUpdate(bridge.api, version); err != nil {
return err
}
bridge.publish(events.UpdateInstalled{
Version: version,
})
}
return nil
}

View File

@ -1,26 +1,9 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
func (b *Bridge) GetCurrentUserAgent() string {
return b.userAgent.String()
func (bridge *Bridge) GetCurrentUserAgent() string {
return bridge.identifier.GetUserAgent()
}
func (b *Bridge) SetCurrentPlatform(platform string) {
b.userAgent.SetPlatform(platform)
func (bridge *Bridge) SetCurrentPlatform(platform string) {
bridge.identifier.SetPlatform(platform)
}

434
internal/bridge/users.go Normal file
View File

@ -0,0 +1,434 @@
package bridge
import (
"context"
"fmt"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/go-resty/resty/v2"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/slices"
)
type UserInfo struct {
// UserID is the user's API ID.
UserID string
// Username is the user's API username.
Username string
// Connected is true if the user is logged in (has API auth).
Connected bool
// Addresses holds the user's email addresses. The first address is the primary address.
Addresses []string
// AddressMode is the user's address mode.
AddressMode AddressMode
// BridgePass is the user's bridge password.
BridgePass string
// UsedSpace is the amount of space used by the user.
UsedSpace int
// MaxSpace is the total amount of space available to the user.
MaxSpace int
}
type AddressMode int
const (
SplitMode AddressMode = iota
CombinedMode
)
// GetUserIDs returns the IDs of all known users (authorized or not).
func (bridge *Bridge) GetUserIDs() []string {
return bridge.vault.GetUserIDs()
}
// GetUserInfo returns info about the given user.
func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
vaultUser, err := bridge.vault.GetUser(userID)
if err != nil {
return UserInfo{}, err
}
user, ok := bridge.users[userID]
if !ok {
return getUserInfo(vaultUser.UserID(), vaultUser.Username()), nil
}
return getConnUserInfo(user), nil
}
// QueryUserInfo queries the user info by username or address.
func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) {
for userID, user := range bridge.users {
if user.Match(query) {
return bridge.GetUserInfo(userID)
}
}
return UserInfo{}, ErrNoSuchUser
}
// LoginUser authorizes a new bridge user with the given username and password.
// If necessary, a TOTP and mailbox password are requested via the callbacks.
func (bridge *Bridge) LoginUser(
ctx context.Context,
username, password string,
getTOTP func() (string, error),
getKeyPass func() ([]byte, error),
) (string, error) {
client, auth, err := bridge.api.NewClientWithLogin(ctx, username, password)
if err != nil {
return "", err
}
if auth.TwoFA.Enabled == liteapi.TOTPEnabled {
totp, err := getTOTP()
if err != nil {
return "", err
}
if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil {
return "", err
}
}
var keyPass []byte
if auth.PasswordMode == liteapi.TwoPasswordMode {
pass, err := getKeyPass()
if err != nil {
return "", err
}
keyPass = pass
} else {
keyPass = []byte(password)
}
apiUser, apiAddrs, userKR, addrKRs, saltedKeyPass, err := client.Unlock(ctx, keyPass)
if err != nil {
return "", err
}
if err := bridge.addUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, auth.UID, auth.RefreshToken, saltedKeyPass); err != nil {
return "", err
}
return apiUser.ID, nil
}
// LogoutUser logs out the given user.
func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error {
return bridge.logoutUser(ctx, userID, true, false)
}
// DeleteUser deletes the given user.
// If it is authorized, it is logged out first.
func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
if bridge.users[userID] != nil {
if err := bridge.logoutUser(ctx, userID, true, true); err != nil {
return err
}
}
if err := bridge.vault.DeleteUser(userID); err != nil {
return err
}
bridge.publish(events.UserDeleted{
UserID: userID,
})
return nil
}
func (bridge *Bridge) GetAddressMode(userID string) (AddressMode, error) {
panic("TODO")
}
func (bridge *Bridge) SetAddressMode(userID string, mode AddressMode) error {
panic("TODO")
}
// loadUsers loads authorized users from the vault.
func (bridge *Bridge) loadUsers(ctx context.Context) error {
for _, userID := range bridge.vault.GetUserIDs() {
user, err := bridge.vault.GetUser(userID)
if err != nil {
return err
}
if user.AuthUID() == "" {
continue
}
if err := bridge.loadUser(ctx, user); err != nil {
logrus.WithError(err).Error("Failed to load connected user")
if err := user.Clear(); err != nil {
logrus.WithError(err).Error("Failed to clear user")
}
continue
}
}
return nil
}
func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error {
client, auth, err := bridge.api.NewClientWithRefresh(ctx, user.AuthUID(), user.AuthRef())
if err != nil {
return fmt.Errorf("failed to create API client: %w", err)
}
apiUser, apiAddrs, userKR, addrKRs, err := client.UnlockSalted(ctx, user.KeyPass())
if err != nil {
return fmt.Errorf("failed to unlock user: %w", err)
}
if err := bridge.addUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, auth.UID, auth.RefreshToken, user.KeyPass()); err != nil {
return fmt.Errorf("failed to add user: %w", err)
}
bridge.publish(events.UserLoggedIn{
UserID: user.UserID(),
})
return nil
}
// addUser adds a new user with an already salted mailbox password.
func (bridge *Bridge) addUser(
ctx context.Context,
client *liteapi.Client,
apiUser liteapi.User,
apiAddrs []liteapi.Address,
userKR *crypto.KeyRing,
addrKRs map[string]*crypto.KeyRing,
authUID, authRef string,
saltedKeyPass []byte,
) error {
if _, ok := bridge.users[apiUser.ID]; ok {
return ErrUserAlreadyLoggedIn
}
var user *user.User
if slices.Contains(bridge.vault.GetUserIDs(), apiUser.ID) {
existingUser, err := bridge.addExistingUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass)
if err != nil {
return err
}
user = existingUser
} else {
newUser, err := bridge.addNewUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass)
if err != nil {
return err
}
user = newUser
}
go func() {
for event := range user.GetNotifyCh() {
switch event := event.(type) {
case events.UserDeauth:
if err := bridge.logoutUser(context.Background(), event.UserID, false, false); err != nil {
logrus.WithError(err).Error("Failed to logout user")
}
}
bridge.publish(event)
}
}()
// Gluon will set the IMAP ID in the context, if known, before making requests on behalf of this user.
client.AddPreRequestHook(func(ctx context.Context, req *resty.Request) error {
if imapID, ok := imap.GetIMAPIDFromContext(ctx); ok {
bridge.identifier.SetClient(imapID.Name, imapID.Version)
}
return nil
})
bridge.publish(events.UserLoggedIn{
UserID: user.ID(),
})
return nil
}
func (bridge *Bridge) addNewUser(
ctx context.Context,
client *liteapi.Client,
apiUser liteapi.User,
apiAddrs []liteapi.Address,
userKR *crypto.KeyRing,
addrKRs map[string]*crypto.KeyRing,
authUID, authRef string,
saltedKeyPass []byte,
) (*user.User, error) {
vaultUser, err := bridge.vault.AddUser(apiUser.ID, apiUser.Name, authUID, authRef, saltedKeyPass)
if err != nil {
return nil, err
}
user, err := user.New(ctx, vaultUser, client, apiUser, apiAddrs, userKR, addrKRs)
if err != nil {
return nil, err
}
gluonKey, err := crypto.RandomToken(32)
if err != nil {
return nil, err
}
imapConn, err := user.NewGluonConnector(ctx)
if err != nil {
return nil, err
}
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, gluonKey)
if err != nil {
return nil, err
}
if err := vaultUser.UpdateGluonData(gluonID, gluonKey); err != nil {
return nil, err
}
if err := bridge.smtpBackend.addUser(user); err != nil {
return nil, err
}
bridge.users[apiUser.ID] = user
return user, nil
}
func (bridge *Bridge) addExistingUser(
ctx context.Context,
client *liteapi.Client,
apiUser liteapi.User,
apiAddrs []liteapi.Address,
userKR *crypto.KeyRing,
addrKRs map[string]*crypto.KeyRing,
authUID, authRef string,
saltedKeyPass []byte,
) (*user.User, error) {
vaultUser, err := bridge.vault.GetUser(apiUser.ID)
if err != nil {
return nil, err
}
if err := vaultUser.UpdateAuth(authUID, authRef); err != nil {
return nil, err
}
if err := vaultUser.UpdateKeyPass(saltedKeyPass); err != nil {
return nil, err
}
user, err := user.New(ctx, vaultUser, client, apiUser, apiAddrs, userKR, addrKRs)
if err != nil {
return nil, err
}
imapConn, err := user.NewGluonConnector(ctx)
if err != nil {
return nil, err
}
if err := bridge.imapServer.LoadUser(ctx, imapConn, user.GluonID(), user.GluonKey()); err != nil {
return nil, err
}
if err := bridge.smtpBackend.addUser(user); err != nil {
return nil, err
}
bridge.users[apiUser.ID] = user
return user, nil
}
// logoutUser closes and removes the user with the given ID.
// If withAPI is true, the user will additionally be logged out from API.
// If withFiles is true, the user's files will be deleted.
func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, withFiles bool) error {
user, ok := bridge.users[userID]
if !ok {
return ErrNoSuchUser
}
vaultUser, err := bridge.vault.GetUser(userID)
if err != nil {
return err
}
if err := bridge.imapServer.RemoveUser(ctx, vaultUser.GluonID(), withFiles); err != nil {
return err
}
if err := bridge.smtpBackend.removeUser(user); err != nil {
return err
}
if withAPI {
if err := user.Logout(ctx); err != nil {
return err
}
}
if err := user.Close(ctx); err != nil {
return err
}
if err := vaultUser.Clear(); err != nil {
return err
}
delete(bridge.users, userID)
bridge.publish(events.UserLoggedOut{
UserID: userID,
})
return nil
}
// getUserInfo returns information about a disconnected user.
func getUserInfo(userID, username string) UserInfo {
return UserInfo{
UserID: userID,
Username: username,
AddressMode: CombinedMode,
}
}
// getConnUserInfo returns information about a connected user.
func getConnUserInfo(user *user.User) UserInfo {
return UserInfo{
Connected: true,
UserID: user.ID(),
Username: user.Name(),
Addresses: user.Addresses(),
AddressMode: CombinedMode,
BridgePass: user.BridgePass(),
UsedSpace: user.UsedSpace(),
MaxSpace: user.MaxSpace(),
}
}

View File

@ -0,0 +1,286 @@
package bridge_test
import (
"context"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi/server"
)
func TestBridge_WithoutUsers(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_Login(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
require.NoError(t, err)
// The user is now connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginLogoutLogin(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
// The user is now connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
// Logout the user.
require.NoError(t, bridge.LogoutUser(ctx, userID))
// The user is now disconnected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
// Login the user again.
newUserID := must(bridge.LoginUser(ctx, username, password, nil, nil))
require.Equal(t, userID, newUserID)
// The user is connected again.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginDeleteLogin(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
// The user is now connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
// Delete the user.
require.NoError(t, bridge.DeleteUser(ctx, userID))
// The user is now gone.
require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
// Login the user again.
newUserID := must(bridge.LoginUser(ctx, username, password, nil, nil))
require.Equal(t, userID, newUserID)
// The user is connected again.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginDeauthLogin(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
// Get a channel to receive the deauth event.
eventCh, done := bridge.GetEvents(events.UserDeauth{})
defer done()
// Deauth the user.
require.NoError(t, s.RevokeUser(userID))
// The user is eventually disconnected.
require.Eventually(t, func() bool {
return len(getConnectedUserIDs(t, bridge)) == 0
}, 10*time.Second, time.Second)
// We should get a deauth event.
require.IsType(t, events.UserDeauth{}, <-eventCh)
// Login the user after the disconnection.
newUserID := must(bridge.LoginUser(ctx, username, password, nil, nil))
require.Equal(t, userID, newUserID)
// The user is connected again.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginExpireLogin(t *testing.T) {
const authLife = 2 * time.Second
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
s.SetAuthLife(authLife)
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user. Its auth will only be valid for a short time.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
// Wait until the auth expires.
time.Sleep(authLife)
// The user will have to refresh but the logout will still succeed.
require.NoError(t, bridge.LogoutUser(ctx, userID))
})
})
}
func TestBridge_FailToLoad(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
var userID string
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
})
// Deauth the user while bridge is stopped.
require.NoError(t, s.RevokeUser(userID))
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user is disconnected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginRestart(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
var userID string
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
})
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user is still connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginLogoutRestart(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
var userID string
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
// Logout the user.
require.NoError(t, bridge.LogoutUser(ctx, userID))
})
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user is still disconnected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginDeleteRestart(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
var userID string
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
// Delete the user.
require.NoError(t, bridge.DeleteUser(ctx, userID))
})
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user is still gone.
require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_BridgePass(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
var userID, pass string
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
// Retrieve the bridge pass.
pass = must(bridge.GetUserInfo(userID)).BridgePass
// Log the user out.
require.NoError(t, bridge.LogoutUser(ctx, userID))
// Log the user back in.
must(bridge.LoginUser(ctx, username, password, nil, nil))
// The bridge pass should be the same.
require.Equal(t, pass, pass)
})
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The bridge should load schizofrenic.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
// The bridge pass should be the same.
require.Equal(t, pass, must(bridge.GetUserInfo(userID)).BridgePass)
})
})
}