GODT-1779: Remove go-imap

This commit is contained in:
James Houlahan
2022-08-26 17:00:21 +02:00
parent 3b0bc1ca15
commit 39433fe707
593 changed files with 12725 additions and 91626 deletions

View File

@ -1,85 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package api provides HTTP API of the Bridge.
//
// API endpoints:
// - /focus, see focusHandler
package api
import (
"fmt"
"net/http"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
"github.com/ProtonMail/proton-bridge/v2/pkg/ports"
"github.com/sirupsen/logrus"
)
var log = logrus.WithField("pkg", "api") //nolint:gochecknoglobals
type apiServer struct {
host string
settings *settings.Settings
eventListener listener.Listener
}
// NewAPIServer returns prepared API server struct.
func NewAPIServer(settings *settings.Settings, eventListener listener.Listener) *apiServer { //nolint:revive
return &apiServer{
host: bridge.Host,
settings: settings,
eventListener: eventListener,
}
}
// Starts the server.
func (api *apiServer) ListenAndServe() {
mux := http.NewServeMux()
mux.HandleFunc("/focus", wrapper(api, focusHandler))
addr := api.getAddress()
server := &http.Server{
Addr: addr,
Handler: mux,
ReadHeaderTimeout: 5 * time.Second, // fix gosec G112 (vulnerability to [Slowloris](https://www.cloudflare.com/en-gb/learning/ddos/ddos-attack-tools/slowloris/) attack).
}
log.Info("API listening at ", addr)
if err := server.ListenAndServe(); err != nil {
api.eventListener.Emit(events.ErrorEvent, "API failed: "+err.Error())
log.Error("API failed: ", err)
}
defer server.Close() //nolint:errcheck
}
func (api *apiServer) getAddress() string {
port := api.settings.GetInt(settings.APIPortKey)
newPort := ports.FindFreePortFrom(port)
if newPort != port {
api.settings.SetInt(settings.APIPortKey, newPort)
}
return getAPIAddress(api.host, newPort)
}
func getAPIAddress(host string, port int) string {
return fmt.Sprintf("%s:%d", host, port)
}

View File

@ -1,51 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package api
import (
"net/http"
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
)
// httpHandler with Go's Response and Request.
type httpHandler func(http.ResponseWriter, *http.Request)
// handler with our context.
type handler func(handlerContext) error
type handlerContext struct {
req *http.Request
resp http.ResponseWriter
eventListener listener.Listener
}
func wrapper(api *apiServer, callback handler) httpHandler {
return func(w http.ResponseWriter, req *http.Request) {
ctx := handlerContext{
req: req,
resp: w,
eventListener: api.eventListener,
}
err := callback(ctx)
if err != nil {
log.Error("API callback of ", req.URL, " failed: ", err)
http.Error(w, err.Error(), 500)
}
}
}

View File

@ -1,51 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package api
import (
"fmt"
"net/http"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
)
// focusHandler should be called from other instances (attempt to start bridge
// for the second time) to get focus in the currently running instance.
func focusHandler(ctx handlerContext) error {
log.Info("Focus from other instance")
ctx.eventListener.Emit(events.SecondInstanceEvent, "")
fmt.Fprintf(ctx.resp, "OK")
return nil
}
// CheckOtherInstanceAndFocus is helper for new instances to check if there is
// already a running instance and get it's focus.
func CheckOtherInstanceAndFocus(port int) error {
addr := getAPIAddress(bridge.Host, port)
resp, err := (&http.Client{}).Get("http://" + addr + "/focus")
if err != nil {
return err
}
defer resp.Body.Close() //nolint:errcheck
if resp.StatusCode != 200 {
log.Error("Focus error: ", resp.StatusCode)
}
return nil
}

149
internal/app/app.go Normal file
View File

@ -0,0 +1,149 @@
package app
import (
"fmt"
"path/filepath"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/crash"
"github.com/ProtonMail/proton-bridge/v2/internal/focus"
bridgeCLI "github.com/ProtonMail/proton-bridge/v2/internal/frontend/cli"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/grpc"
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
"github.com/ProtonMail/proton-bridge/v2/internal/sentry"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
"github.com/ProtonMail/proton-bridge/v2/pkg/restarter"
"github.com/pkg/profile"
"github.com/urfave/cli/v2"
)
const (
flagCPUProfile = "cpu-prof"
flagCPUProfileShort = "p"
flagMemProfile = "mem-prof"
flagMemProfileShort = "m"
flagLogLevel = "log-level"
flagLogLevelShort = "l"
flagCLI = "cli"
flagCLIShort = "c"
flagNoWindow = "no-window"
flagNonInteractive = "non-interactive"
)
const (
appUsage = "Proton Mail IMAP and SMTP Bridge"
)
func New() *cli.App {
app := cli.NewApp()
app.Name = constants.FullAppName
app.Usage = appUsage
app.Flags = []cli.Flag{
&cli.BoolFlag{
Name: flagCPUProfile,
Aliases: []string{flagCPUProfileShort},
Usage: "Generate CPU profile",
},
&cli.BoolFlag{
Name: flagMemProfile,
Aliases: []string{flagMemProfileShort},
Usage: "Generate memory profile",
},
&cli.StringFlag{
Name: flagLogLevel,
Aliases: []string{flagLogLevelShort},
Usage: "Set the log level (one of panic, fatal, error, warn, info, debug)",
},
&cli.BoolFlag{
Name: flagCLI,
Aliases: []string{flagCLIShort},
Usage: "Use command line interface",
},
&cli.BoolFlag{
Name: flagNoWindow,
Usage: "Don't show window after start",
Hidden: true,
},
}
app.Action = run
return app
}
func run(c *cli.Context) error {
// If there's another instance already running, try to raise it and exit.
if raised := focus.TryRaise(); raised {
return nil
}
// Start CPU profile if requested.
if c.Bool(flagCPUProfile) {
p := profile.Start(profile.CPUProfile, profile.ProfilePath("cpu.pprof"))
defer p.Stop()
}
// Start memory profile if requested.
if c.Bool(flagMemProfile) {
p := profile.Start(profile.MemProfile, profile.MemProfileAllocs, profile.ProfilePath("mem.pprof"))
defer p.Stop()
}
// Create the restarter.
restarter := restarter.New()
defer restarter.Restart()
// Create a user agent that will be used for all requests.
identifier := useragent.New()
// Create a crash handler that will send crash reports to sentry.
crashHandler := crash.NewHandler(
sentry.NewReporter(constants.FullAppName, constants.Version, identifier).ReportException,
crash.ShowErrorNotification(constants.FullAppName),
func(r interface{}) error { restarter.Set(true, true); return nil },
)
defer crashHandler.HandlePanic()
// Create a locations provider to determine where to store our files.
provider, err := locations.NewDefaultProvider(filepath.Join(constants.VendorName, constants.ConfigName))
if err != nil {
return fmt.Errorf("could not create locations provider: %w", err)
}
// Create a new locations object that will be used to provide paths to store files.
locations := locations.New(provider, constants.ConfigName)
// Initialize the logging.
if err := initLogging(c, locations, crashHandler); err != nil {
return fmt.Errorf("could not initialize logging: %w", err)
}
// Create the bridge.
bridge, err := newBridge(locations, identifier)
if err != nil {
return fmt.Errorf("could not create bridge: %w", err)
}
defer bridge.Close(c.Context)
// Start the frontend.
switch {
case c.Bool(flagCLI):
return bridgeCLI.New(bridge).Loop()
case c.Bool(flagNonInteractive):
select {}
default:
service, err := grpc.NewService(crashHandler, restarter, locations, bridge, !c.Bool(flagNoWindow))
if err != nil {
return fmt.Errorf("could not create service: %w", err)
}
return service.Loop()
}
}

View File

@ -1,35 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package base
import "strings"
// StripProcessSerialNumber removes additional flag from macOS.
// More info:
// http://mirror.informatimago.com/next/developer.apple.com/documentation/Carbon/Reference/Process_Manager/prmref_main/data_type_5.html#//apple_ref/doc/uid/TP30000208/C001951
func StripProcessSerialNumber(args []string) []string {
res := args[:0]
for _, arg := range args {
if !strings.Contains(arg, "-psn_") {
res = append(res, arg)
}
}
return res
}

View File

@ -1,424 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package base implements a common application base currently shared by bridge and IE.
// The base includes the following:
// - access to standard filesystem locations like config, cache, logging dirs
// - an extensible crash handler
// - versioned cache directory
// - persistent settings
// - event listener
// - credentials store
// - pmapi Manager
//
// In addition, the base initialises logging and reacts to command line arguments
// which control the log verbosity and enable cpu/memory profiling.
package base
import (
"math/rand"
"os"
"path/filepath"
"runtime"
"runtime/pprof"
"time"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/go-autostart"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/api"
"github.com/ProtonMail/proton-bridge/v2/internal/config/cache"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
"github.com/ProtonMail/proton-bridge/v2/internal/config/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/cookies"
"github.com/ProtonMail/proton-bridge/v2/internal/crash"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
"github.com/ProtonMail/proton-bridge/v2/internal/sentry"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/v2/internal/versioner"
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
)
const (
flagCPUProfile = "cpu-prof"
flagCPUProfileShort = "p"
flagMemProfile = "mem-prof"
flagMemProfileShort = "m"
flagLogLevel = "log-level"
flagLogLevelShort = "l"
// FlagCLI indicate to start with command line interface.
FlagCLI = "cli"
flagCLIShort = "c"
flagRestart = "restart"
FlagLauncher = "launcher"
FlagNoWindow = "no-window"
)
type Base struct {
SentryReporter *sentry.Reporter
CrashHandler *crash.Handler
Locations *locations.Locations
Settings *settings.Settings
Lock *os.File
Cache *cache.Cache
Listener listener.Listener
Creds *credentials.Store
CM pmapi.Manager
CookieJar *cookies.Jar
UserAgent *useragent.UserAgent
Updater *updater.Updater
Versioner *versioner.Versioner
TLS *tls.TLS
Autostart *autostart.App
Name string // the app's name
usage string // the app's usage description
command string // the command used to launch the app (either the exe path or the launcher path)
restart bool // whether the app is currently set to restart
launcher string // launcher to be used if not set in args
mainExecutable string // mainExecutable the main executable process.
teardown []func() error // actions to perform when app is exiting
}
func New( //nolint:funlen
appName,
appUsage,
configName,
updateURLName,
keychainName,
cacheVersion string,
) (*Base, error) {
userAgent := useragent.New()
sentryReporter := sentry.NewReporter(appName, constants.Version, userAgent)
crashHandler := crash.NewHandler(
sentryReporter.ReportException,
crash.ShowErrorNotification(appName),
)
defer crashHandler.HandlePanic()
rand.Seed(time.Now().UnixNano())
os.Args = StripProcessSerialNumber(os.Args)
locationsProvider, err := locations.NewDefaultProvider(filepath.Join(constants.VendorName, configName))
if err != nil {
return nil, err
}
locations := locations.New(locationsProvider, configName)
logsPath, err := locations.ProvideLogsPath()
if err != nil {
return nil, err
}
if err := logging.Init(logsPath); err != nil {
return nil, err
}
crashHandler.AddRecoveryAction(logging.DumpStackTrace(logsPath))
if err := migrateFiles(configName); err != nil {
logrus.WithError(err).Warn("Old config files could not be migrated")
}
if err := locations.Clean(); err != nil {
return nil, err
}
settingsPath, err := locations.ProvideSettingsPath()
if err != nil {
return nil, err
}
settingsObj := settings.New(settingsPath)
lock, err := checkSingleInstance(locations.GetLockFile(), settingsObj)
if err != nil {
logrus.WithError(err).Warnf("%v is already running", appName)
return nil, api.CheckOtherInstanceAndFocus(settingsObj.GetInt(settings.APIPortKey))
}
if err := migrateRebranding(settingsObj, keychainName); err != nil {
logrus.WithError(err).Warn("Rebranding migration failed")
}
cachePath, err := locations.ProvideCachePath()
if err != nil {
return nil, err
}
cache, err := cache.New(cachePath, cacheVersion)
if err != nil {
return nil, err
}
if err := cache.RemoveOldVersions(); err != nil {
return nil, err
}
listener := listener.New()
events.SetupEvents(listener)
// If we can't load the keychain for whatever reason,
// we signal to frontend and supply a dummy keychain that always returns errors.
kc, err := keychain.NewKeychain(settingsObj, keychainName)
if err != nil {
listener.Emit(events.CredentialsErrorEvent, err.Error())
kc = keychain.NewMissingKeychain()
}
cfg := pmapi.NewConfig(configName, constants.Version)
cfg.GetUserAgent = userAgent.String
cfg.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
cfg.TLSIssueHandler = func() { listener.Emit(events.TLSCertIssue, "") }
cm := pmapi.New(cfg)
sentryReporter.SetClientFromManager(cm)
cm.AddConnectionObserver(pmapi.NewConnectionObserver(
func() { listener.Emit(events.InternetConnChangedEvent, events.InternetOff) },
func() { listener.Emit(events.InternetConnChangedEvent, events.InternetOn) },
))
jar, err := cookies.NewCookieJar(settingsObj)
if err != nil {
return nil, err
}
cm.SetCookieJar(jar)
key, err := crypto.NewKeyFromArmored(updater.DefaultPublicKey)
if err != nil {
return nil, err
}
kr, err := crypto.NewKeyRing(key)
if err != nil {
return nil, err
}
updatesDir, err := locations.ProvideUpdatesPath()
if err != nil {
return nil, err
}
versioner := versioner.New(updatesDir)
installer := updater.NewInstaller(versioner)
updater := updater.New(
cm,
installer,
settingsObj,
kr,
semver.MustParse(constants.Version),
updateURLName,
runtime.GOOS,
)
exe, err := os.Executable()
if err != nil {
return nil, err
}
autostart := &autostart.App{
Name: startupNameForRebranding(appName),
DisplayName: appName,
Exec: []string{exe, "--" + FlagNoWindow},
}
return &Base{
SentryReporter: sentryReporter,
CrashHandler: crashHandler,
Locations: locations,
Settings: settingsObj,
Lock: lock,
Cache: cache,
Listener: listener,
Creds: credentials.NewStore(kc),
CM: cm,
CookieJar: jar,
UserAgent: userAgent,
Updater: updater,
Versioner: versioner,
TLS: tls.New(settingsPath),
Autostart: autostart,
Name: appName,
usage: appUsage,
// By default, the command is the app's executable.
// This can be changed at runtime by using the "--launcher" flag.
command: exe,
// By default, the command is the app's executable.
// This can be changed at runtime by summoning the SetMainExecutable gRPC call.
mainExecutable: exe,
}, nil
}
func (b *Base) NewApp(mainLoop func(*Base, *cli.Context) error) *cli.App {
app := cli.NewApp()
app.Name = b.Name
app.Usage = b.usage
app.Version = constants.Version
app.Action = b.wrapMainLoop(mainLoop)
app.Flags = []cli.Flag{
&cli.BoolFlag{
Name: flagCPUProfile,
Aliases: []string{flagCPUProfileShort},
Usage: "Generate CPU profile",
},
&cli.BoolFlag{
Name: flagMemProfile,
Aliases: []string{flagMemProfileShort},
Usage: "Generate memory profile",
},
&cli.StringFlag{
Name: flagLogLevel,
Aliases: []string{flagLogLevelShort},
Usage: "Set the log level (one of panic, fatal, error, warn, info, debug)",
},
&cli.BoolFlag{
Name: FlagCLI,
Aliases: []string{flagCLIShort},
Usage: "Use command line interface",
},
&cli.BoolFlag{
Name: FlagNoWindow,
Usage: "Don't show window after start",
},
&cli.StringFlag{
Name: flagRestart,
Usage: "The number of times the application has already restarted",
Hidden: true,
},
&cli.StringFlag{
Name: FlagLauncher,
Usage: "The launcher to use to restart the application",
Hidden: true,
},
}
return app
}
// SetToRestart sets the app to restart the next time it is closed.
func (b *Base) SetToRestart() {
b.restart = true
}
func (b *Base) ForceLauncher(launcher string) {
b.launcher = launcher
b.setupLauncher(launcher)
}
func (b *Base) SetMainExecutable(exe string) {
logrus.Info("Main Executable set to ", exe)
b.mainExecutable = exe
}
// AddTeardownAction adds an action to perform during app teardown.
func (b *Base) AddTeardownAction(fn func() error) {
b.teardown = append(b.teardown, fn)
}
func (b *Base) wrapMainLoop(appMainLoop func(*Base, *cli.Context) error) cli.ActionFunc { //nolint:funlen
return func(c *cli.Context) error {
defer b.CrashHandler.HandlePanic()
defer func() { _ = b.Lock.Close() }()
// If launcher was used to start the app, use that for restart
// and autostart.
if launcher := c.String(FlagLauncher); launcher != "" {
b.setupLauncher(launcher)
}
if c.Bool(flagCPUProfile) {
startCPUProfile()
defer pprof.StopCPUProfile()
}
if c.Bool(flagMemProfile) {
defer makeMemoryProfile()
}
logging.SetLevel(c.String(flagLogLevel))
b.CM.SetLogging(logrus.WithField("pkg", "pmapi"), logrus.GetLevel() == logrus.TraceLevel)
logrus.
WithField("appName", b.Name).
WithField("version", constants.Version).
WithField("revision", constants.Revision).
WithField("build", constants.BuildTime).
WithField("runtime", runtime.GOOS).
WithField("args", os.Args).
Info("Run app")
b.CrashHandler.AddRecoveryAction(func(interface{}) error {
sentry.Flush(2 * time.Second)
if c.Int(flagRestart) > maxAllowedRestarts {
logrus.
WithField("restart", c.Int("restart")).
Warn("Not restarting, already restarted too many times")
os.Exit(1)
return nil
}
return b.restartApp(true)
})
if err := appMainLoop(b, c); err != nil {
return err
}
if err := b.doTeardown(); err != nil {
return err
}
if b.restart {
return b.restartApp(false)
}
return nil
}
}
func (b *Base) doTeardown() error {
for _, action := range b.teardown {
if err := action(); err != nil {
return err
}
}
return nil
}
func (b *Base) setupLauncher(launcher string) {
b.command = launcher
// Bridge supports no-window option which we should use
// for autostart.
b.Autostart.Exec = []string{launcher, "--" + FlagNoWindow}
}

View File

@ -1,131 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package base
import (
"os"
"path/filepath"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
"github.com/sirupsen/logrus"
)
// migrateFiles migrates files from their old (pre-refactor) locations to their new locations.
// We can remove this eventually.
//
// | entity | old location | new location |
// |-----------|-------------------------------------------|----------------------------------------|
// | prefs | ~/.cache/protonmail/<app>/c11/prefs.json | ~/.config/protonmail/<app>/prefs.json |
// | c11 1.5.x | ~/.cache/protonmail/<app>/c11 | ~/.cache/protonmail/<app>/cache/c11 |
// | c11 1.6.x | ~/.cache/protonmail/<app>/cache/c11 | ~/.config/protonmail/<app>/cache/c11 |
// | updates | ~/.cache/protonmail/<app>/updates | ~/.config/protonmail/<app>/updates |.
func migrateFiles(configName string) error {
locationsProvider, err := locations.NewDefaultProvider(filepath.Join(constants.VendorName, configName))
if err != nil {
return err
}
locations := locations.New(locationsProvider, configName)
userCacheDir := locationsProvider.UserCache()
if err := migratePrefsFrom15x(locations, userCacheDir); err != nil {
return err
}
if err := migrateCacheFromBoth15xAnd16x(locations, userCacheDir); err != nil {
return err
}
if err := migrateUpdatesFrom16x(configName, locations); err != nil { //nolint:revive It is more clear to structure this way
return err
}
return nil
}
func migratePrefsFrom15x(locations *locations.Locations, userCacheDir string) error {
newSettingsDir, err := locations.ProvideSettingsPath()
if err != nil {
return err
}
return moveIfExists(
filepath.Join(userCacheDir, "c11", "prefs.json"),
filepath.Join(newSettingsDir, "prefs.json"),
)
}
func migrateCacheFromBoth15xAnd16x(locations *locations.Locations, userCacheDir string) error {
olderCacheDir := userCacheDir
newerCacheDir := locations.GetOldCachePath()
latestCacheDir, err := locations.ProvideCachePath()
if err != nil {
return err
}
// Migration for versions before 1.6.x.
if err := moveIfExists(
filepath.Join(olderCacheDir, "c11"),
filepath.Join(latestCacheDir, "c11"),
); err != nil {
return err
}
// Migration for versions 1.6.x.
return moveIfExists(
filepath.Join(newerCacheDir, "c11"),
filepath.Join(latestCacheDir, "c11"),
)
}
func migrateUpdatesFrom16x(configName string, locations *locations.Locations) error {
// In order to properly update Bridge 1.6.X and higher we need to
// change the launcher first. Since this is not part of automatic
// updates the migration must wait until manual update. Until that
// we need to keep old path.
if configName == "bridge" {
return nil
}
oldUpdatesPath := locations.GetOldUpdatesPath()
// Do not use ProvideUpdatesPath, that creates dir right away.
newUpdatesPath := locations.GetUpdatesPath()
return moveIfExists(oldUpdatesPath, newUpdatesPath)
}
func moveIfExists(source, destination string) error {
l := logrus.WithField("source", source).WithField("destination", destination)
if _, err := os.Stat(source); os.IsNotExist(err) {
l.Info("No need to migrate file, source doesn't exist")
return nil
}
if _, err := os.Stat(destination); !os.IsNotExist(err) {
// Once migrated, files should not stay in source anymore. Therefore
// if some files are still in source location but target already exist,
// it's suspicious. Could happen by installing new version, then the
// old one because of some reason, and then the new one again.
// Good to see as warning because it could be a reason why Bridge is
// behaving weirdly, like wrong configuration, or db re-sync and so on.
l.Warn("No need to migrate file, target already exists")
return nil
}
l.Info("Migrating files")
return os.Rename(source, destination)
}

View File

@ -1,197 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package base
import (
"errors"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
"github.com/hashicorp/go-multierror"
"github.com/sirupsen/logrus"
)
const darwin = "darwin"
func migrateRebranding(settingsObj *settings.Settings, keychainName string) (result error) {
if err := migrateStartupBeforeRebranding(); err != nil {
result = multierror.Append(result, err)
}
lastUsedVersion := settingsObj.Get(settings.LastVersionKey)
// Skipping migration: it is first bridge start or cache was cleared.
if lastUsedVersion == "" {
settingsObj.SetBool(settings.RebrandingMigrationKey, true)
return
}
// Skipping rest of migration: already done
if settingsObj.GetBool(settings.RebrandingMigrationKey) {
return
}
switch runtime.GOOS {
case "windows", "linux":
// GODT-1260 we would need admin rights to changes desktop files
// and start menu items.
settingsObj.SetBool(settings.RebrandingMigrationKey, true)
case darwin:
if shouldContinue, err := isMacBeforeRebranding(); !shouldContinue || err != nil {
if err != nil {
result = multierror.Append(result, err)
}
break
}
if err := migrateMacKeychainBeforeRebranding(settingsObj, keychainName); err != nil {
result = multierror.Append(result, err)
}
settingsObj.SetBool(settings.RebrandingMigrationKey, true)
}
return result
}
// migrateMacKeychainBeforeRebranding deals with write access restriction to
// mac keychain passwords which are caused by application renaming. The old
// passwords are copied under new name in order to have write access afer
// renaming.
func migrateMacKeychainBeforeRebranding(settingsObj *settings.Settings, keychainName string) error {
l := logrus.WithField("pkg", "app/base/migration")
l.Warn("Migrating mac keychain")
helperConstructor, ok := keychain.Helpers["macos-keychain"]
if !ok {
return errors.New("cannot find macos-keychain helper")
}
oldKC, err := helperConstructor("ProtonMailBridgeService")
if err != nil {
l.WithError(err).Error("Keychain constructor failed")
return err
}
idByURL, err := oldKC.List()
if err != nil {
l.WithError(err).Error("List old keychain failed")
return err
}
newKC, err := keychain.NewKeychain(settingsObj, keychainName)
if err != nil {
return err
}
for url, id := range idByURL {
li := l.WithField("id", id).WithField("url", url)
userID, secret, err := oldKC.Get(url)
if err != nil {
li.WithField("userID", userID).
WithField("err", err).
Error("Faild to get old item")
continue
}
if _, _, err := newKC.Get(userID); err == nil {
li.Warn("Skipping migration, item already exists.")
continue
}
if err := newKC.Put(userID, secret); err != nil {
li.WithError(err).Error("Failed to migrate user")
}
li.Info("Item migrated")
}
return nil
}
// migrateStartupBeforeRebranding removes old startup links. The creation of new links is
// handled by bridge initialisation.
func migrateStartupBeforeRebranding() error {
path, err := os.UserHomeDir()
if err != nil {
return err
}
switch runtime.GOOS {
case "windows":
path = filepath.Join(path, `AppData\Roaming\Microsoft\Windows\Start Menu\Programs\Startup\ProtonMail Bridge.lnk`)
case "linux":
path = filepath.Join(path, `.config/autostart/ProtonMail Bridge.desktop`)
case darwin:
path = filepath.Join(path, `Library/LaunchAgents/ProtonMail Bridge.plist`)
default:
return errors.New("unknown GOOS")
}
if _, err := os.Stat(path); os.IsNotExist(err) {
return nil
}
logrus.WithField("pkg", "app/base/migration").Warn("Migrating autostartup links")
return os.Remove(path)
}
// startupNameForRebranding returns the name for autostart launcher based on
// type of rebranded instance i.e. update or manual.
//
// This only affects darwin when udpate re-writes the old startup and then
// manual installed it would not run proper exe. Therefore we return "old" name
// for updates and "new" name for manual which would be properly migrated.
//
// For orther (linux and windows) the link is always pointing to launcher which
// path didn't changed.
func startupNameForRebranding(origin string) string {
if runtime.GOOS == darwin {
if path, err := os.Executable(); err == nil && strings.Contains(path, "ProtonMail Bridge") {
return "ProtonMail Bridge"
}
}
// No need to solve for other OS. See comment above.
return origin
}
// isBeforeRebranding decide if last used version was older than 2.2.0. If
// cannot decide it returns false with error.
func isMacBeforeRebranding() (bool, error) {
// previous version | update | do mac migration |
// | first | false |
// cleared-cache | manual | false |
// cleared-cache | in-app | false |
// old | in-app | false |
// old in-app | in-app | false |
// old | manual | true |
// old in-app | manual | true |
// manual | in-app | false |
// Skip if it was in-app update and not manual
if path, err := os.Executable(); err != nil || strings.Contains(path, "ProtonMail Bridge") {
return false, err
}
return true, nil
}

View File

@ -1,56 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package base
import (
"os"
"path/filepath"
"runtime"
"runtime/pprof"
"github.com/sirupsen/logrus"
)
// startCPUProfile starts CPU pprof.
func startCPUProfile() {
f, err := os.Create("./cpu.pprof")
if err != nil {
logrus.Fatal("Could not create CPU profile: ", err)
}
if err := pprof.StartCPUProfile(f); err != nil {
logrus.Fatal("Could not start CPU profile: ", err)
}
}
// makeMemoryProfile generates memory pprof.
func makeMemoryProfile() {
name := "./mem.pprof"
f, err := os.Create(name)
if err != nil {
logrus.Fatal("Could not create memory profile: ", err)
}
if abs, err := filepath.Abs(name); err == nil {
name = abs
}
logrus.Info("Writing memory profile to ", name)
runtime.GC() // get up-to-date statistics
if err := pprof.WriteHeapProfile(f); err != nil {
logrus.Fatal("Could not write memory profile: ", err)
}
_ = f.Close()
}

View File

@ -1,112 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package base
import (
"os"
"strconv"
"github.com/sirupsen/logrus"
"golang.org/x/sys/execabs"
)
// maxAllowedRestarts controls after how many crashes the app will give up restarting.
const maxAllowedRestarts = 10
func (b *Base) restartApp(crash bool) error {
var args []string
if crash {
args = incrementRestartFlag(os.Args)[1:]
defer func() { os.Exit(1) }()
} else {
args = os.Args[1:]
}
if b.launcher != "" {
args = forceLauncherFlag(args, b.launcher)
}
args = append(args, "--wait", b.mainExecutable)
logrus.
WithField("command", b.command).
WithField("args", args).
Warn("Restarting")
return execabs.Command(b.command, args...).Start() //nolint:gosec
}
// incrementRestartFlag increments the value of the restart flag.
// If no such flag is present, it is added with initial value 1.
func incrementRestartFlag(args []string) []string {
res := append([]string{}, args...)
hasFlag := false
for k, v := range res {
if v != "--restart" {
continue
}
hasFlag = true
if k+1 >= len(res) {
continue
}
n, err := strconv.Atoi(res[k+1])
if err != nil {
res[k+1] = "1"
} else {
res[k+1] = strconv.Itoa(n + 1)
}
}
if !hasFlag {
res = append(res, "--restart", "1")
}
return res
}
// forceLauncherFlag replace or add the launcher args with the one set in the app.
func forceLauncherFlag(args []string, launcher string) []string {
res := append([]string{}, args...)
hasFlag := false
for k, v := range res {
if v != "--launcher" {
continue
}
if k+1 >= len(res) {
continue
}
hasFlag = true
res[k+1] = launcher
}
if !hasFlag {
res = append(res, "--launcher", launcher)
}
return res
}

View File

@ -1,63 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package base
import (
"strings"
"testing"
"github.com/Masterminds/semver/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIncrementRestartFlag(t *testing.T) {
tests := []struct {
in []string
out []string
}{
{[]string{"./bridge", "--restart", "1"}, []string{"./bridge", "--restart", "2"}},
{[]string{"./bridge", "--restart", "2"}, []string{"./bridge", "--restart", "3"}},
{[]string{"./bridge", "--other", "--restart", "2"}, []string{"./bridge", "--other", "--restart", "3"}},
{[]string{"./bridge", "--restart", "2", "--other"}, []string{"./bridge", "--restart", "3", "--other"}},
{[]string{"./bridge", "--restart", "2", "--other", "2"}, []string{"./bridge", "--restart", "3", "--other", "2"}},
{[]string{"./bridge"}, []string{"./bridge", "--restart", "1"}},
{[]string{"./bridge", "--something"}, []string{"./bridge", "--something", "--restart", "1"}},
{[]string{"./bridge", "--something", "--else"}, []string{"./bridge", "--something", "--else", "--restart", "1"}},
{[]string{"./bridge", "--restart", "bad"}, []string{"./bridge", "--restart", "1"}},
{[]string{"./bridge", "--restart", "bad", "--other"}, []string{"./bridge", "--restart", "1", "--other"}},
}
for _, tt := range tests {
t.Run(strings.Join(tt.in, " "), func(t *testing.T) {
assert.Equal(t, tt.out, incrementRestartFlag(tt.in))
})
}
}
func TestVersionLessThan(t *testing.T) {
r := require.New(t)
old := semver.MustParse("1.1.0")
current := semver.MustParse("1.1.1")
newer := semver.MustParse("1.1.2")
r.True(old.LessThan(current))
r.False(current.LessThan(current))
r.False(newer.LessThan(current))
}

View File

@ -1,101 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
//go:build !windows
// +build !windows
package base
import (
"errors"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/allan-simon/go-singleinstance"
"golang.org/x/sys/unix"
)
// checkSingleInstance returns error if a bridge instance is already running
// This instance should be stop and window of running window should be brought
// to focus.
//
// For macOS and Linux when already running version is older than this instance
// it will kill old and continue with this new bridge (i.e. no error returned).
func checkSingleInstance(lockFilePath string, settingsObj *settings.Settings) (*os.File, error) {
if lock, err := singleinstance.CreateLockFile(lockFilePath); err == nil {
// Bridge is not runnig, continue normally
return lock, nil
}
if err := runningVersionIsOlder(settingsObj); err != nil {
return nil, err
}
pid, err := getPID(lockFilePath)
if err != nil {
return nil, err
}
if err := unix.Kill(pid, unix.SIGTERM); err != nil {
return nil, err
}
// Need to wait some time to release file lock
time.Sleep(time.Second)
return singleinstance.CreateLockFile(lockFilePath)
}
func getPID(lockFilePath string) (int, error) {
file, err := os.Open(filepath.Clean(lockFilePath))
if err != nil {
return 0, err
}
defer func() { _ = file.Close() }()
rawPID := make([]byte, 10) // PID is probably up to 7 digits long, 10 should be enough
n, err := file.Read(rawPID)
if err != nil {
return 0, err
}
return strconv.Atoi(strings.TrimSpace(string(rawPID[:n])))
}
func runningVersionIsOlder(settingsObj *settings.Settings) error {
currentVer, err := semver.StrictNewVersion(constants.Version)
if err != nil {
return err
}
runningVer, err := semver.StrictNewVersion(settingsObj.Get(settings.LastVersionKey))
if err != nil {
return err
}
if !runningVer.LessThan(currentVer) {
return errors.New("running version is not older")
}
return nil
}

View File

@ -1,32 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
//go:build windows
// +build windows
package base
import (
"os"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/allan-simon/go-singleinstance"
)
func checkSingleInstance(lockFilePath string, _ *settings.Settings) (*os.File, error) {
return singleinstance.CreateLockFile(lockFilePath)
}

205
internal/app/bridge.go Normal file
View File

@ -0,0 +1,205 @@
package app
import (
"encoding/base64"
"fmt"
"os"
"runtime"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/go-autostart"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/dialer"
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/ProtonMail/proton-bridge/v2/internal/versioner"
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
"github.com/sirupsen/logrus"
"golang.org/x/exp/slices"
)
const vaultSecretName = "bridge-vault-key"
func newBridge(locations *locations.Locations, identifier *useragent.UserAgent) (*bridge.Bridge, error) {
// Create the underlying dialer used by the bridge.
// It only connects to trusted servers and reports any untrusted servers it finds.
pinningDialer := dialer.NewPinningTLSDialer(
dialer.NewBasicTLSDialer(constants.APIHost),
dialer.NewTLSReporter(constants.APIHost, constants.AppVersion, identifier, dialer.TrustedAPIPins),
dialer.NewTLSPinChecker(dialer.TrustedAPIPins),
)
// Create a proxy dialer which switches to a proxy if the request fails.
proxyDialer := dialer.NewProxyTLSDialer(pinningDialer, constants.APIHost)
// Create the autostarter.
autostarter, err := newAutostarter()
if err != nil {
return nil, fmt.Errorf("could not create autostarter: %w", err)
}
// Create the update installer.
updater, err := newUpdater(locations)
if err != nil {
return nil, fmt.Errorf("could not create updater: %w", err)
}
// Get the current bridge version.
version, err := semver.NewVersion(constants.Version)
if err != nil {
return nil, fmt.Errorf("could not create version: %w", err)
}
// Create the encVault.
encVault, insecure, corrupt, err := newVault(locations)
if err != nil {
return nil, fmt.Errorf("could not create vault: %w", err)
} else if insecure {
logrus.Warn("The vault key could not be retrieved; the vault will not be encrypted")
} else if corrupt {
logrus.Warn("The vault is corrupt and has been wiped")
}
// Install the certificates if needed.
if installed := encVault.GetCertsInstalled(); !installed {
if err := certs.NewInstaller().InstallCert(encVault.GetBridgeTLSCert()); err != nil {
return nil, fmt.Errorf("failed to install certs: %w", err)
}
if err := encVault.SetCertsInstalled(true); err != nil {
return nil, fmt.Errorf("failed to set certs installed: %w", err)
}
if err := encVault.SetCertsInstalled(true); err != nil {
return nil, fmt.Errorf("could not set certs installed: %w", err)
}
}
// Create a new bridge.
bridge, err := bridge.New(constants.APIHost, locations, encVault, identifier, pinningDialer, proxyDialer, autostarter, updater, version)
if err != nil {
return nil, fmt.Errorf("could not create bridge: %w", err)
}
// If the vault could not be loaded properly, push errors to the bridge.
switch {
case insecure:
bridge.PushError(vault.ErrInsecure)
case corrupt:
bridge.PushError(vault.ErrCorrupt)
}
return bridge, nil
}
func newVault(locations *locations.Locations) (*vault.Vault, bool, bool, error) {
var insecure bool
vaultDir, err := locations.ProvideSettingsPath()
if err != nil {
return nil, false, false, fmt.Errorf("could not get vault dir: %w", err)
}
var vaultKey []byte
if key, err := getVaultKey(vaultDir); err != nil {
insecure = true
} else {
vaultKey = key
}
gluonDir, err := locations.ProvideGluonPath()
if err != nil {
return nil, false, false, fmt.Errorf("could not provide gluon path: %w", err)
}
vault, corrupt, err := vault.New(vaultDir, gluonDir, vaultKey)
if err != nil {
return nil, false, false, fmt.Errorf("could not create vault: %w", err)
}
return vault, insecure, corrupt, nil
}
func getVaultKey(vaultDir string) ([]byte, error) {
helper, err := vault.GetHelper(vaultDir)
if err != nil {
return nil, fmt.Errorf("could not get keychain helper: %w", err)
}
keychain, err := keychain.NewKeychain(helper, constants.KeyChainName)
if err != nil {
return nil, fmt.Errorf("could not create keychain: %w", err)
}
secrets, err := keychain.List()
if err != nil {
return nil, fmt.Errorf("could not list keychain: %w", err)
}
if !slices.Contains(secrets, vaultSecretName) {
tok, err := crypto.RandomToken(32)
if err != nil {
return nil, fmt.Errorf("could not generate random token: %w", err)
}
if err := keychain.Put(vaultSecretName, base64.StdEncoding.EncodeToString(tok)); err != nil {
return nil, fmt.Errorf("could not put keychain item: %w", err)
}
}
_, keyEnc, err := keychain.Get(vaultSecretName)
if err != nil {
return nil, fmt.Errorf("could not get keychain item: %w", err)
}
keyDec, err := base64.StdEncoding.DecodeString(keyEnc)
if err != nil {
return nil, fmt.Errorf("could not decode keychain item: %w", err)
}
return keyDec, nil
}
func newAutostarter() (*autostart.App, error) {
exe, err := os.Executable()
if err != nil {
return nil, err
}
return &autostart.App{
Name: constants.FullAppName,
DisplayName: constants.FullAppName,
Exec: []string{exe, "--" + flagNoWindow},
}, nil
}
func newUpdater(locations *locations.Locations) (*updater.Updater, error) {
updatesDir, err := locations.ProvideUpdatesPath()
if err != nil {
return nil, fmt.Errorf("could not provide updates path: %w", err)
}
key, err := crypto.NewKeyFromArmored(updater.DefaultPublicKey)
if err != nil {
return nil, fmt.Errorf("could not create key from armored: %w", err)
}
verifier, err := crypto.NewKeyRing(key)
if err != nil {
return nil, fmt.Errorf("could not create key ring: %w", err)
}
return updater.NewUpdater(
updater.NewInstaller(versioner.New(updatesDir)),
verifier,
constants.UpdateName,
runtime.GOOS,
), nil
}

View File

@ -1,269 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package bridge implements the bridge CLI application.
package bridge
import (
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/api"
"github.com/ProtonMail/proton-bridge/v2/internal/app/base"
pkgBridge "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/types"
"github.com/ProtonMail/proton-bridge/v2/internal/imap"
"github.com/ProtonMail/proton-bridge/v2/internal/smtp"
"github.com/ProtonMail/proton-bridge/v2/internal/store"
"github.com/ProtonMail/proton-bridge/v2/internal/store/cache"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
)
const (
flagLogIMAP = "log-imap"
flagLogSMTP = "log-smtp"
flagNonInteractive = "noninteractive"
// Memory cache was estimated by empirical usage in past and it was set to 100MB.
// NOTE: This value must not be less than maximal size of one email (~30MB).
inMemoryCacheLimnit = 100 * (1 << 20)
)
func New(base *base.Base) *cli.App {
app := base.NewApp(main)
app.Flags = append(app.Flags, []cli.Flag{
&cli.StringFlag{
Name: flagLogIMAP,
Usage: "Enable logging of IMAP communications (all|client|server) (may contain decrypted data!)",
},
&cli.BoolFlag{
Name: flagLogSMTP,
Usage: "Enable logging of SMTP communications (may contain decrypted data!)",
},
&cli.BoolFlag{
Name: flagNonInteractive,
Usage: "Start Bridge entirely noninteractively",
},
}...)
return app
}
func main(b *base.Base, c *cli.Context) error { //nolint:funlen
cache, cacheErr := loadMessageCache(b)
if cacheErr != nil {
logrus.WithError(cacheErr).Error("Could not load local cache.")
}
builder := message.NewBuilder(
b.Settings.GetInt(settings.FetchWorkers),
b.Settings.GetInt(settings.AttachmentWorkers),
)
bridge := pkgBridge.New(
b.Locations,
b.Cache,
b.Settings,
b.SentryReporter,
b.CrashHandler,
b.Listener,
b.TLS,
b.UserAgent,
cache,
builder,
b.CM,
b.Creds,
b.Updater,
b.Versioner,
b.Autostart,
)
imapBackend := imap.NewIMAPBackend(b.CrashHandler, b.Listener, b.Cache, b.Settings, bridge)
smtpBackend := smtp.NewSMTPBackend(b.CrashHandler, b.Listener, b.Settings, bridge)
tlsConfig, err := bridge.GetTLSConfig()
if err != nil {
return err
}
if cacheErr != nil {
bridge.AddError(pkgBridge.ErrLocalCacheUnavailable)
}
go func() {
defer b.CrashHandler.HandlePanic()
api.NewAPIServer(b.Settings, b.Listener).ListenAndServe()
}()
go func() {
defer b.CrashHandler.HandlePanic()
imapPort := b.Settings.GetInt(settings.IMAPPortKey)
imap.NewIMAPServer(
b.CrashHandler,
c.String(flagLogIMAP) == "client" || c.String(flagLogIMAP) == "all",
c.String(flagLogIMAP) == "server" || c.String(flagLogIMAP) == "all",
imapPort, tlsConfig, imapBackend, b.UserAgent, b.Listener).ListenAndServe()
}()
go func() {
defer b.CrashHandler.HandlePanic()
smtpPort := b.Settings.GetInt(settings.SMTPPortKey)
useSSL := b.Settings.GetBool(settings.SMTPSSLKey)
smtp.NewSMTPServer(
b.CrashHandler,
c.Bool(flagLogSMTP),
smtpPort, useSSL, tlsConfig, smtpBackend, b.Listener).ListenAndServe()
}()
// We want to remove old versions if the app exits successfully.
b.AddTeardownAction(b.Versioner.RemoveOldVersions)
// We want cookies to be saved to disk so they are loaded the next time.
b.AddTeardownAction(b.CookieJar.PersistCookies)
var frontendMode string
switch {
case c.Bool(base.FlagCLI):
frontendMode = "cli"
case c.Bool(flagNonInteractive):
return <-(make(chan error)) // Block forever.
default:
frontendMode = "grpc"
}
f := frontend.New(
frontendMode,
!c.Bool(base.FlagNoWindow),
b.CrashHandler,
b.Listener,
b.Updater,
bridge,
b,
b.Locations,
)
// Watch for updates routine
go func() {
ticker := time.NewTicker(constants.UpdateCheckInterval)
for {
checkAndHandleUpdate(b.Updater, f, b.Settings.GetBool(settings.AutoUpdateKey))
<-ticker.C
}
}()
return f.Loop()
}
func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool) {
log := logrus.WithField("pkg", "app/bridge")
version, err := u.Check()
if err != nil {
log.WithError(err).Error("An error occurred while checking for updates")
return
}
f.WaitUntilFrontendIsReady()
// Update links in UI
f.SetVersion(version)
if !u.IsUpdateApplicable(version) {
log.Info("No need to update")
return
}
log.WithField("version", version.Version).Info("An update is available")
if !autoUpdate {
f.NotifyManualUpdate(version, u.CanInstall(version))
return
}
if !u.CanInstall(version) {
log.Info("A manual update is required")
f.NotifySilentUpdateError(updater.ErrManualUpdateRequired)
return
}
if err := u.InstallUpdate(version); err != nil {
if errors.Cause(err) == updater.ErrDownloadVerify {
log.WithError(err).Warning("Skipping update installation due to temporary error")
} else {
log.WithError(err).Error("The update couldn't be installed")
f.NotifySilentUpdateError(err)
}
return
}
f.NotifySilentUpdateInstalled()
}
// loadMessageCache loads local cache in case it is enabled in settings and available.
// In any other case it is returning in-memory cache. Could also return an error in case
// local cache is enabled but unavailable (in-memory cache will be returned nevertheless).
func loadMessageCache(b *base.Base) (cache.Cache, error) {
if !b.Settings.GetBool(settings.CacheEnabledKey) {
return cache.NewInMemoryCache(inMemoryCacheLimnit), nil
}
var compressor cache.Compressor
// NOTE(GODT-1158): Changing compression is not an option currently
// available for user but, if user changes compression setting we have
// to nuke the cache.
if b.Settings.GetBool(settings.CacheCompressionKey) {
compressor = &cache.GZipCompressor{}
} else {
compressor = &cache.NoopCompressor{}
}
var path string
if customPath := b.Settings.Get(settings.CacheLocationKey); customPath != "" {
path = customPath
} else {
path = b.Cache.GetDefaultMessageCacheDir()
// Store path so it will allways persist if default location
// will be changed in new version.
b.Settings.Set(settings.CacheLocationKey, path)
}
// To prevent memory peaks we set maximal write concurency for store
// build jobs.
store.SetBuildAndCacheJobLimit(b.Settings.GetInt(settings.CacheConcurrencyWrite))
messageCache, err := cache.NewOnDiskCache(path, compressor, cache.Options{
MinFreeAbs: uint64(b.Settings.GetInt(settings.CacheMinFreeAbsKey)),
MinFreeRat: b.Settings.GetFloat64(settings.CacheMinFreeRatKey),
ConcurrentRead: b.Settings.GetInt(settings.CacheConcurrencyRead),
ConcurrentWrite: b.Settings.GetInt(settings.CacheConcurrencyWrite),
})
if err != nil {
return cache.NewInMemoryCache(inMemoryCacheLimnit), err
}
return messageCache, nil
}

28
internal/app/logging.go Normal file
View File

@ -0,0 +1,28 @@
package app
import (
"fmt"
"github.com/ProtonMail/proton-bridge/v2/internal/crash"
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
"github.com/urfave/cli/v2"
)
func initLogging(c *cli.Context, locations *locations.Locations, crashHandler *crash.Handler) error {
// Get a place to keep our logs.
logsPath, err := locations.ProvideLogsPath()
if err != nil {
return fmt.Errorf("could not provide logs path: %w", err)
}
// Initialize logging.
if err := logging.Init(logsPath, c.String(flagLogLevel)); err != nil {
return fmt.Errorf("could not initialize logging: %w", err)
}
// Ensure we dump a stack trace if we crash.
crashHandler.AddRecoveryAction(logging.DumpStackTrace(logsPath))
return nil
}

View File

@ -1,38 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package bridge provides core functionality of Bridge app.
package bridge
import "github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
// IsAutostartEnabled checks if link file exits.
func (b *Bridge) IsAutostartEnabled() bool {
return b.autostart.IsEnabled()
}
// EnableAutostart creates link and sets the preferences.
func (b *Bridge) EnableAutostart() error {
b.settings.SetBool(settings.AutostartKey, true)
return b.autostart.Enable()
}
// DisableAutostart removes link and sets the preferences.
func (b *Bridge) DisableAutostart() error {
b.settings.SetBool(settings.AutostartKey, false)
return b.autostart.Disable()
}

View File

@ -1,325 +1,318 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package bridge provides core functionality of Bridge app.
package bridge
import (
"errors"
"context"
"crypto/tls"
"fmt"
"strconv"
"time"
"net"
"net/http"
"sync"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/go-autostart"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
"github.com/ProtonMail/proton-bridge/v2/internal/config/useragent"
"github.com/ProtonMail/gluon"
"github.com/ProtonMail/gluon/watcher"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/metrics"
"github.com/ProtonMail/proton-bridge/v2/internal/sentry"
"github.com/ProtonMail/proton-bridge/v2/internal/store/cache"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/internal/users"
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
logrus "github.com/sirupsen/logrus"
"github.com/ProtonMail/proton-bridge/v2/internal/cookies"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/focus"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-smtp"
"github.com/go-resty/resty/v2"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi"
)
var log = logrus.WithField("pkg", "bridge") //nolint:gochecknoglobals
var ErrLocalCacheUnavailable = errors.New("local cache is unavailable")
type Bridge struct {
*users.Users
// vault holds bridge-specific data, such as preferences and known users (authorized or not).
vault *vault.Vault
locations Locator
settings SettingsProvider
clientManager pmapi.Manager
// users holds authorized users.
users map[string]*user.User
// api manages user API clients.
api *liteapi.Manager
cookieJar *cookies.Jar
proxyDialer ProxyDialer
identifier Identifier
// watchers holds all registered event watchers.
watchers []*watcher.Watcher[events.Event]
watchersLock sync.RWMutex
// tlsConfig holds the bridge TLS config used by the IMAP and SMTP servers.
tlsConfig *tls.Config
// imapServer is the bridge's IMAP server.
imapServer *gluon.Server
imapListener net.Listener
// smtpServer is the bridge's SMTP server.
smtpServer *smtp.Server
smtpBackend *smtpBackend
smtpListener net.Listener
// updater is the bridge's updater.
updater Updater
versioner Versioner
tls *tls.TLS
userAgent *useragent.UserAgent
cacheProvider CacheProvider
autostart *autostart.App
// Bridge's global errors list.
errors []error
curVersion *semver.Version
updateCheckCh chan struct{}
isAllMailVisible bool
isFirstStart bool
lastVersion string
// focusService is used to raise the bridge window when needed.
focusService *focus.FocusService
// autostarter is the bridge's autostarter.
autostarter Autostarter
// locator is the bridge's locator.
locator Locator
// errors contains errors encountered during startup.
errors []error
}
func New( //nolint:funlen
locations Locator,
cacheProvider CacheProvider,
setting SettingsProvider,
sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler,
eventListener listener.Listener,
tls *tls.TLS,
userAgent *useragent.UserAgent,
cache cache.Cache,
builder *message.Builder,
clientManager pmapi.Manager,
credStorer users.CredentialsStorer,
updater Updater,
versioner Versioner,
autostart *autostart.App,
) *Bridge {
// Allow DoH before starting the app if the user has previously set this setting.
// This allows us to start even if protonmail is blocked.
if setting.GetBool(settings.AllowProxyKey) {
clientManager.AllowProxy()
// New creates a new bridge.
func New(
apiURL string, // the URL of the API to use
locator Locator, // the locator to provide paths to store data
vault *vault.Vault, // the bridge's encrypted data store
identifier Identifier, // the identifier to keep track of the user agent
tlsReporter TLSReporter, // the TLS reporter to report TLS errors
proxyDialer ProxyDialer, // the DoH dialer
autostarter Autostarter, // the autostarter to manage autostart settings
updater Updater, // the updater to fetch and install updates
curVersion *semver.Version, // the current version of the bridge
) (*Bridge, error) {
if vault.GetProxyAllowed() {
proxyDialer.AllowProxy()
} else {
proxyDialer.DisallowProxy()
}
u := users.New(
locations,
panicHandler,
eventListener,
clientManager,
credStorer,
newStoreFactory(cacheProvider, sentryReporter, panicHandler, eventListener, cache, builder),
cookieJar, err := cookies.NewCookieJar(vault)
if err != nil {
return nil, fmt.Errorf("failed to create cookie jar: %w", err)
}
api := liteapi.New(
liteapi.WithHostURL(apiURL),
liteapi.WithAppVersion(constants.AppVersion),
liteapi.WithCookieJar(cookieJar),
liteapi.WithTransport(&http.Transport{DialTLSContext: proxyDialer.DialTLSContext}),
)
b := &Bridge{
Users: u,
locations: locations,
settings: setting,
clientManager: clientManager,
updater: updater,
versioner: versioner,
tls: tls,
userAgent: userAgent,
cacheProvider: cacheProvider,
autostart: autostart,
isFirstStart: false,
isAllMailVisible: setting.GetBool(settings.IsAllMailVisible),
}
if setting.GetBool(settings.FirstStartKey) {
b.isFirstStart = true
if err := b.SendMetric(metrics.New(metrics.Setup, metrics.FirstStart, metrics.Label(constants.Version))); err != nil {
logrus.WithError(err).Error("Failed to send metric")
}
setting.SetBool(settings.FirstStartKey, false)
}
// Keep in bridge and update in settings the last used version.
b.lastVersion = b.settings.Get(settings.LastVersionKey)
b.settings.Set(settings.LastVersionKey, constants.Version)
go b.heartbeat()
return b
}
// heartbeat sends a heartbeat signal once a day.
func (b *Bridge) heartbeat() {
for range time.Tick(time.Minute) {
lastHeartbeatDay, err := strconv.ParseInt(b.settings.Get(settings.LastHeartbeatKey), 10, 64)
if err != nil {
continue
}
// If we're still on the same day, don't send a heartbeat.
if time.Now().YearDay() == int(lastHeartbeatDay) {
continue
}
// We're on the next (or a different) day, so send a heartbeat.
if err := b.SendMetric(metrics.New(metrics.Heartbeat, metrics.Daily, metrics.NoLabel)); err != nil {
logrus.WithError(err).Error("Failed to send heartbeat")
continue
}
// Heartbeat was sent successfully so update the last heartbeat day.
b.settings.Set(settings.LastHeartbeatKey, fmt.Sprintf("%v", time.Now().YearDay()))
}
}
// GetUpdateChannel returns currently set update channel.
func (b *Bridge) GetUpdateChannel() updater.UpdateChannel {
return updater.UpdateChannel(b.settings.Get(settings.UpdateChannelKey))
}
// SetUpdateChannel switches update channel.
func (b *Bridge) SetUpdateChannel(channel updater.UpdateChannel) {
b.settings.Set(settings.UpdateChannelKey, string(channel))
}
func (b *Bridge) resetToLatestStable() error {
version, err := b.updater.Check()
tlsConfig, err := loadTLSConfig(vault)
if err != nil {
// If we can not check for updates - just remove all local updates and reset to base installer version.
// Not using `b.locations.ClearUpdates()` because `versioner.RemoveOtherVersions` can also handle
// case when it is needed to remove currently running verion.
if err := b.versioner.RemoveOtherVersions(semver.MustParse("0.0.0")); err != nil {
log.WithError(err).Error("Failed to clear updates while downgrading channel")
}
return nil, fmt.Errorf("failed to load TLS config: %w", err)
}
imapServer, err := newIMAPServer(vault.GetGluonDir(), curVersion, tlsConfig)
if err != nil {
return nil, fmt.Errorf("failed to create IMAP server: %w", err)
}
smtpBackend, err := newSMTPBackend()
if err != nil {
return nil, fmt.Errorf("failed to create SMTP backend: %w", err)
}
smtpServer, err := newSMTPServer(smtpBackend, tlsConfig)
if err != nil {
return nil, fmt.Errorf("failed to create SMTP server: %w", err)
}
focusService, err := focus.NewService()
if err != nil {
return nil, fmt.Errorf("failed to create focus service: %w", err)
}
bridge := &Bridge{
vault: vault,
users: make(map[string]*user.User),
api: api,
cookieJar: cookieJar,
proxyDialer: proxyDialer,
identifier: identifier,
tlsConfig: tlsConfig,
imapServer: imapServer,
smtpServer: smtpServer,
smtpBackend: smtpBackend,
updater: updater,
curVersion: curVersion,
updateCheckCh: make(chan struct{}, 1),
focusService: focusService,
autostarter: autostarter,
locator: locator,
}
api.AddStatusObserver(func(status liteapi.Status) {
bridge.publish(events.ConnStatus{
Status: status,
})
})
api.AddErrorHandler(liteapi.AppVersionBadCode, func() {
bridge.publish(events.UpdateForced{})
})
api.AddPreRequestHook(func(_ *resty.Client, req *resty.Request) error {
req.SetHeader("User-Agent", bridge.identifier.GetUserAgent())
return nil
})
go func() {
for range tlsReporter.GetTLSIssueCh() {
bridge.publish(events.TLSIssue{})
}
}()
go func() {
for range focusService.GetRaiseCh() {
bridge.publish(events.Raise{})
}
}()
go func() {
for event := range imapServer.AddWatcher() {
bridge.handleIMAPEvent(event)
}
}()
if err := bridge.loadUsers(context.Background()); err != nil {
return nil, fmt.Errorf("failed to load connected users: %w", err)
}
// If current version is same as upstream stable version - do nothing.
if version.Version.Equal(semver.MustParse(constants.Version)) {
return nil
if err := bridge.serveIMAP(); err != nil {
bridge.PushError(ErrServeIMAP)
}
if err := b.updater.InstallUpdate(version); err != nil {
return err
if err := bridge.serveSMTP(); err != nil {
bridge.PushError(ErrServeSMTP)
}
return b.versioner.RemoveOtherVersions(version.Version)
if err := bridge.watchForUpdates(); err != nil {
bridge.PushError(ErrWatchUpdates)
}
return bridge, nil
}
// FactoryReset will remove all local cache and settings.
// It will also downgrade to latest stable version if user is on early version.
func (b *Bridge) FactoryReset() {
wasEarly := b.GetUpdateChannel() == updater.EarlyChannel
// GetEvents returns a channel of events of the given type.
// If no types are supplied, all events are returned.
func (bridge *Bridge) GetEvents(ofType ...events.Event) (<-chan events.Event, func()) {
newWatcher := bridge.addWatcher(ofType...)
b.settings.Set(settings.UpdateChannelKey, string(updater.StableChannel))
return newWatcher.GetChannel(), func() { bridge.remWatcher(newWatcher) }
}
if wasEarly {
if err := b.resetToLatestStable(); err != nil {
log.WithError(err).Error("Failed to reset to latest stable version")
func (bridge *Bridge) FactoryReset(ctx context.Context) error {
panic("TODO")
}
func (bridge *Bridge) PushError(err error) {
bridge.errors = append(bridge.errors, err)
}
func (bridge *Bridge) GetErrors() []error {
return bridge.errors
}
func (bridge *Bridge) Close(ctx context.Context) error {
// Close the IMAP server.
if err := bridge.closeIMAP(ctx); err != nil {
logrus.WithError(err).Error("Failed to close IMAP server")
}
// Close the SMTP server.
if err := bridge.closeSMTP(); err != nil {
logrus.WithError(err).Error("Failed to close SMTP server")
}
// Close all users.
for _, user := range bridge.users {
if err := user.Close(ctx); err != nil {
logrus.WithError(err).Error("Failed to close user")
}
}
if err := b.Users.ClearData(); err != nil {
log.WithError(err).Error("Failed to remove bridge data")
// Persist the cookies.
if err := bridge.cookieJar.PersistCookies(); err != nil {
logrus.WithError(err).Error("Failed to persist cookies")
}
if err := b.Users.ClearUsers(); err != nil {
log.WithError(err).Error("Failed to remove bridge users")
// Close the focus service.
bridge.focusService.Close()
// Save the last version of bridge that was run.
if err := bridge.vault.SetLastVersion(bridge.curVersion); err != nil {
logrus.WithError(err).Error("Failed to save last version")
}
}
// GetKeychainApp returns current keychain helper.
func (b *Bridge) GetKeychainApp() string {
return b.settings.Get(settings.PreferredKeychainKey)
}
// SetKeychainApp sets current keychain helper.
func (b *Bridge) SetKeychainApp(helper string) {
b.settings.Set(settings.PreferredKeychainKey, helper)
}
func (b *Bridge) EnableCache() error {
if err := b.Users.EnableCache(); err != nil {
return err
}
b.settings.SetBool(settings.CacheEnabledKey, true)
return nil
}
func (b *Bridge) DisableCache() error {
if err := b.Users.DisableCache(); err != nil {
return err
}
func (bridge *Bridge) publish(event events.Event) {
bridge.watchersLock.RLock()
defer bridge.watchersLock.RUnlock()
b.settings.SetBool(settings.CacheEnabledKey, false)
// Reset back to the default location when disabling.
b.settings.Set(settings.CacheLocationKey, b.cacheProvider.GetDefaultMessageCacheDir())
return nil
}
func (b *Bridge) MigrateCache(from, to string) error {
if err := b.Users.MigrateCache(from, to); err != nil {
return err
}
b.settings.Set(settings.CacheLocationKey, to)
return nil
}
// SetProxyAllowed instructs the app whether to use DoH to access an API proxy if necessary.
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
func (b *Bridge) SetProxyAllowed(proxyAllowed bool) {
b.settings.SetBool(settings.AllowProxyKey, proxyAllowed)
if proxyAllowed {
b.clientManager.AllowProxy()
} else {
b.clientManager.DisallowProxy()
}
}
// GetProxyAllowed returns whether use of DoH is enabled to access an API proxy if necessary.
func (b *Bridge) GetProxyAllowed() bool {
return b.settings.GetBool(settings.AllowProxyKey)
}
// AddError add an error to a global error list if it does not contain it yet. Adding nil is noop.
func (b *Bridge) AddError(err error) {
if err == nil {
return
}
if b.HasError(err) {
return
}
b.errors = append(b.errors, err)
}
// DelError removes an error from global error list.
func (b *Bridge) DelError(err error) {
for idx, val := range b.errors {
if val == err {
b.errors = append(b.errors[:idx], b.errors[idx+1:]...)
return
for _, watcher := range bridge.watchers {
if watcher.IsWatching(event) {
if ok := watcher.Send(event); !ok {
logrus.WithField("event", event).Warn("Failed to send event to watcher")
}
}
}
}
// HasError returnes true if global error list contains an err.
func (b *Bridge) HasError(err error) bool {
for _, val := range b.errors {
if val == err {
return true
}
func (bridge *Bridge) addWatcher(ofType ...events.Event) *watcher.Watcher[events.Event] {
bridge.watchersLock.Lock()
defer bridge.watchersLock.Unlock()
newWatcher := watcher.New(ofType...)
bridge.watchers = append(bridge.watchers, newWatcher)
return newWatcher
}
func (bridge *Bridge) remWatcher(oldWatcher *watcher.Watcher[events.Event]) {
bridge.watchersLock.Lock()
defer bridge.watchersLock.Unlock()
bridge.watchers = xslices.Filter(bridge.watchers, func(other *watcher.Watcher[events.Event]) bool {
return other != oldWatcher
})
}
func loadTLSConfig(vault *vault.Vault) (*tls.Config, error) {
cert, err := tls.X509KeyPair(vault.GetBridgeTLSCert(), vault.GetBridgeTLSKey())
if err != nil {
return nil, err
}
return false
return &tls.Config{
Certificates: []tls.Certificate{cert},
}, nil
}
// GetLastVersion returns the version which was used in previous execution of
// Bridge.
func (b *Bridge) GetLastVersion() string {
return b.lastVersion
}
func newListener(port int, useTLS bool, tlsConfig *tls.Config) (net.Listener, error) {
if useTLS {
tlsListener, err := tls.Listen("tcp", fmt.Sprintf(":%v", port), tlsConfig)
if err != nil {
return nil, err
}
// IsFirstStart returns true when Bridge is running for first time or after
// factory reset.
func (b *Bridge) IsFirstStart() bool {
return b.isFirstStart
}
return tlsListener, nil
}
// IsAllMailVisible can be called extensively by IMAP. Therefore, it is better
// to cache the value instead of reading from settings file.
func (b *Bridge) IsAllMailVisible() bool {
return b.isAllMailVisible
}
netListener, err := net.Listen("tcp", fmt.Sprintf(":%v", port))
if err != nil {
return nil, err
}
func (b *Bridge) SetIsAllMailVisible(isVisible bool) {
b.settings.SetBool(settings.IsAllMailVisible, isVisible)
b.isAllMailVisible = isVisible
return netListener, nil
}

View File

@ -0,0 +1,362 @@
package bridge_test
import (
"context"
"testing"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/focus"
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server"
)
const (
username = "username"
password = "password"
)
var (
v2_3_0 = semver.MustParse("2.3.0")
v2_4_0 = semver.MustParse("2.4.0")
)
func TestBridge_ConnStatus(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of connection status events.
eventCh, done := bridge.GetEvents(events.ConnStatus{})
defer done()
// Simulate network disconnect.
mocks.TLSDialer.SetCanDial(false)
// Trigger some operation that will fail due to the network disconnect.
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.Error(t, err)
// Wait for the event.
require.Equal(t, events.ConnStatus{Status: liteapi.StatusDown}, <-eventCh)
// Simulate network reconnect.
mocks.TLSDialer.SetCanDial(true)
// Trigger some operation that will succeed due to the network reconnect.
userID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err)
require.NotEmpty(t, userID)
// Wait for the event.
require.Equal(t, events.ConnStatus{Status: liteapi.StatusUp}, <-eventCh)
})
})
}
func TestBridge_TLSIssue(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of TLS issue events.
tlsEventCh, done := bridge.GetEvents(events.TLSIssue{})
defer done()
// Simulate a TLS issue.
go func() {
mocks.TLSIssueCh <- struct{}{}
}()
// Wait for the event.
require.IsType(t, events.TLSIssue{}, <-tlsEventCh)
})
})
}
func TestBridge_Focus(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of TLS issue events.
raiseCh, done := bridge.GetEvents(events.Raise{})
defer done()
// Simulate a focus event.
focus.TryRaise()
// Wait for the event.
require.IsType(t, events.Raise{}, <-raiseCh)
})
})
}
func TestBridge_UserAgent(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
var calls []server.Call
s.AddCallWatcher(func(call server.Call) {
calls = append(calls, call)
})
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Set the platform to something other than the default.
bridge.SetCurrentPlatform("platform")
// Assert that the user agent then contains the platform.
require.Contains(t, bridge.GetCurrentUserAgent(), "platform")
// Login the user.
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err)
// Assert that the user agent was sent to the API.
require.Contains(t, calls[len(calls)-1].Request.Header.Get("User-Agent"), bridge.GetCurrentUserAgent())
})
})
}
func TestBridge_Cookies(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
var calls []server.Call
s.AddCallWatcher(func(call server.Call) {
calls = append(calls, call)
})
var sessionID string
// Start bridge and add a user so that API assigns us a session ID via cookie.
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err)
cookie, err := calls[len(calls)-1].Request.Cookie("Session-Id")
require.NoError(t, err)
sessionID = cookie.Value
})
// Start bridge again and check that it uses the same session ID.
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
cookie, err := calls[len(calls)-1].Request.Cookie("Session-Id")
require.NoError(t, err)
require.Equal(t, sessionID, cookie.Value)
})
})
}
func TestBridge_CheckUpdate(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Disable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(false))
// Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateAvailable{})
defer done()
// We are currently on the latest version.
bridge.CheckForUpdates()
require.Equal(t, events.UpdateNotAvailable{}, <-updateCh)
// Simulate a new version being available.
mocks.Updater.SetLatestVersion(v2_4_0, v2_3_0)
// Check for updates.
bridge.CheckForUpdates()
require.Equal(t, events.UpdateAvailable{
Version: updater.VersionInfo{
Version: v2_4_0,
MinAuto: v2_3_0,
RolloutProportion: 1.0,
},
CanInstall: true,
}, <-updateCh)
})
})
}
func TestBridge_AutoUpdate(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Enable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(true))
// Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateInstalled{})
defer done()
// Simulate a new version being available.
mocks.Updater.SetLatestVersion(v2_4_0, v2_3_0)
// Check for updates.
bridge.CheckForUpdates()
require.Equal(t, events.UpdateInstalled{
Version: updater.VersionInfo{
Version: v2_4_0,
MinAuto: v2_3_0,
RolloutProportion: 1.0,
},
}, <-updateCh)
})
})
}
func TestBridge_ManualUpdate(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Disable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(false))
// Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateAvailable{})
defer done()
// Simulate a new version being available, but it's too new for us.
mocks.Updater.SetLatestVersion(v2_4_0, v2_4_0)
// Check for updates.
bridge.CheckForUpdates()
require.Equal(t, events.UpdateAvailable{
Version: updater.VersionInfo{
Version: v2_4_0,
MinAuto: v2_4_0,
RolloutProportion: 1.0,
},
CanInstall: false,
}, <-updateCh)
})
})
}
func TestBridge_ForceUpdate(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateForced{})
defer done()
// Set the minimum accepted app version to something newer than the current version.
s.SetMinAppVersion(v2_4_0)
// Try to login the user. It will fail because the bridge is too old.
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.Error(t, err)
// We should get an update required event.
require.Equal(t, events.UpdateForced{}, <-updateCh)
})
})
}
func TestBridge_BadVaultKey(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
var userID string
// Login a user.
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
newUserID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err)
userID = newUserID
})
// Start bridge with the correct vault key -- it should load the users correctly.
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.ElementsMatch(t, []string{userID}, bridge.GetUserIDs())
})
// Start bridge with a bad vault key, the vault will be wiped and bridge will show no users.
withBridge(t, s.GetHostURL(), locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs())
})
// Start bridge with a nil vault key, the vault will be wiped and bridge will show no users.
withBridge(t, s.GetHostURL(), locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs())
})
})
}
// withEnv creates the full test environment and runs the tests.
func withEnv(t *testing.T, tests func(server *server.Server, locator bridge.Locator, vaultKey []byte)) {
// Create test API.
server := server.NewTLS()
defer server.Close()
// Add test user.
_, _, err := server.AddUser(username, password, username+"@pm.me")
require.NoError(t, err)
// Generate a random vault key.
vaultKey, err := crypto.RandomToken(32)
require.NoError(t, err)
// Run the tests.
tests(server, locations.New(bridge.NewTestLocationsProvider(t), "config-name"), vaultKey)
}
// withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done.
func withBridge(t *testing.T, apiURL string, locator bridge.Locator, vaultKey []byte, tests func(bridge *bridge.Bridge, mocks *bridge.Mocks)) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create the mock objects used in the tests.
mocks := bridge.NewMocks(t, v2_3_0, v2_3_0)
// Bridge will enable the proxy by default at startup.
mocks.ProxyDialer.EXPECT().AllowProxy()
// Get the path to the vault.
vaultDir, err := locator.ProvideSettingsPath()
require.NoError(t, err)
// Create the vault.
vault, _, err := vault.New(vaultDir, t.TempDir(), vaultKey)
require.NoError(t, err)
// Create a new bridge.
bridge, err := bridge.New(
apiURL,
locator,
vault,
useragent.New(),
mocks.TLSReporter,
mocks.ProxyDialer,
mocks.Autostarter,
mocks.Updater,
v2_3_0,
)
require.NoError(t, err)
// Use the bridge.
tests(bridge, mocks)
// Close the bridge.
require.NoError(t, bridge.Close(ctx))
}
// must is a helper function that panics on error.
func must[T any](val T, err error) T {
if err != nil {
panic(err)
}
return val
}
func getConnectedUserIDs(t *testing.T, bridge *bridge.Bridge) []string {
t.Helper()
return xslices.Filter(bridge.GetUserIDs(), func(userID string) bool {
info, err := bridge.GetUserInfo(userID)
require.NoError(t, err)
return info.Connected
})
}

View File

@ -21,67 +21,51 @@ import (
"archive/zip"
"bytes"
"context"
"errors"
"io"
"os"
"path/filepath"
"sort"
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"gitlab.protontech.ch/go/liteapi"
)
const (
MaxAttachmentSize = 7 * 1024 * 1024 // MaxAttachmentSize 7 MB total limit
MaxAttachmentSize = 7 * (1 << 20) // MaxAttachmentSize 7 MB total size of all attachments.
MaxCompressedFilesCount = 6
)
var ErrSizeTooLarge = errors.New("file is too big")
func (bridge *Bridge) ReportBug(ctx context.Context, osType, osVersion, description, username, email, client string, attachLogs bool) error {
var account string
// ReportBug reports a new bug from the user.
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string, attachLogs bool) error { //nolint:funlen
if user, err := b.GetUser(address); err == nil {
accountName = user.Username()
} else if users := b.GetUsers(); len(users) > 0 {
accountName = users[0].Username()
if info, err := bridge.QueryUserInfo(username); err == nil {
account = info.Username
} else if userIDs := bridge.GetUserIDs(); len(userIDs) > 0 {
account = bridge.users[userIDs[0]].Name()
}
report := pmapi.ReportBugReq{
OS: osType,
OSVersion: osVersion,
Browser: emailClient,
Title: "[Bridge] Bug",
Description: description,
Username: accountName,
Email: address,
}
var atts []liteapi.ReportBugAttachment
if attachLogs {
logs, err := b.getMatchingLogs(
func(filename string) bool {
return logging.MatchLogName(filename) && !logging.MatchStackTraceName(filename)
},
)
logs, err := getMatchingLogs(bridge.locator, func(filename string) bool {
return logging.MatchLogName(filename) && !logging.MatchStackTraceName(filename)
})
if err != nil {
log.WithError(err).Error("Can't get log files list")
return err
}
guiLogs, err := b.getMatchingLogs(
func(filename string) bool {
return logging.MatchGUILogName(filename) && !logging.MatchStackTraceName(filename)
},
)
guiLogs, err := getMatchingLogs(bridge.locator, func(filename string) bool {
return logging.MatchGUILogName(filename) && !logging.MatchStackTraceName(filename)
})
if err != nil {
log.WithError(err).Error("Can't get GUI log files list")
return err
}
crashes, err := b.getMatchingLogs(
func(filename string) bool {
return logging.MatchLogName(filename) && logging.MatchStackTraceName(filename)
},
)
crashes, err := getMatchingLogs(bridge.locator, func(filename string) bool {
return logging.MatchLogName(filename) && logging.MatchStackTraceName(filename)
})
if err != nil {
log.WithError(err).Error("Can't get crash files list")
return err
}
var matchFiles []string
@ -95,26 +79,42 @@ func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address,
archive, err := zipFiles(matchFiles)
if err != nil {
log.WithError(err).Error("Can't zip logs and crashes")
return err
}
if archive != nil {
report.AddAttachment("logs.zip", "application/zip", archive)
body, err := io.ReadAll(archive)
if err != nil {
return err
}
atts = append(atts, liteapi.ReportBugAttachment{
Name: "logs.zip",
Filename: "logs.zip",
MIMEType: "application/zip",
Body: body,
})
}
return b.clientManager.ReportBug(context.Background(), report)
return bridge.api.ReportBug(ctx, liteapi.ReportBugReq{
OS: osType,
OSVersion: osVersion,
Description: description,
Client: client,
Username: account,
Email: email,
}, atts...)
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
func (b *Bridge) getMatchingLogs(filenameMatchFunc func(string) bool) (filenames []string, err error) {
logsPath, err := b.locations.ProvideLogsPath()
func getMatchingLogs(locator Locator, filenameMatchFunc func(string) bool) (filenames []string, err error) {
logsPath, err := locator.ProvideLogsPath()
if err != nil {
return nil, err
}
@ -131,24 +131,25 @@ func (b *Bridge) getMatchingLogs(filenameMatchFunc func(string) bool) (filenames
matchFiles = append(matchFiles, filepath.Join(logsPath, file.Name()))
}
}
sort.Strings(matchFiles) // Sorted by timestamp: oldest first.
return matchFiles, nil
}
type LimitedBuffer struct {
type limitedBuffer struct {
capacity int
buf *bytes.Buffer
}
func NewLimitedBuffer(capacity int) *LimitedBuffer {
return &LimitedBuffer{
func newLimitedBuffer(capacity int) *limitedBuffer {
return &limitedBuffer{
capacity: capacity,
buf: bytes.NewBuffer(make([]byte, 0, capacity)),
}
}
func (b *LimitedBuffer) Write(p []byte) (n int, err error) {
func (b *limitedBuffer) Write(p []byte) (n int, err error) {
if len(p)+b.buf.Len() > b.capacity {
return 0, ErrSizeTooLarge
}
@ -156,7 +157,7 @@ func (b *LimitedBuffer) Write(p []byte) (n int, err error) {
return b.buf.Write(p)
}
func (b *LimitedBuffer) Read(p []byte) (n int, err error) {
func (b *limitedBuffer) Read(p []byte) (n int, err error) {
return b.buf.Read(p)
}
@ -165,14 +166,13 @@ func zipFiles(filenames []string) (io.Reader, error) {
return nil, nil
}
buf := NewLimitedBuffer(MaxAttachmentSize)
buf := newLimitedBuffer(MaxAttachmentSize)
w := zip.NewWriter(buf)
defer w.Close() //nolint:errcheck
for _, file := range filenames {
err := addFileToZip(file, w)
if err != nil {
if err := addFileToZip(file, w); err != nil {
return nil, err
}
}
@ -209,12 +209,9 @@ func addFileToZip(filename string, writer *zip.Writer) error {
return err
}
_, err = io.Copy(fileWriter, fileReader)
if err != nil {
if _, err := io.Copy(fileWriter, fileReader); err != nil {
return err
}
err = fileReader.Close()
return err
return fileReader.Close()
}

View File

@ -1,70 +1,38 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
import (
"strings"
"github.com/ProtonMail/proton-bridge/v2/internal/clientconfig"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/config/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
)
func (b *Bridge) ConfigureAppleMail(userID, address string) (bool, error) {
user, err := b.GetUser(userID)
if err != nil {
return false, err
func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
user, ok := bridge.users[userID]
if !ok {
return ErrNoSuchUser
}
if address == "" {
address = user.GetPrimaryAddress()
address = user.Addresses()[0]
}
username := address
addresses := address
if user.IsCombinedAddressMode() {
username = user.GetPrimaryAddress()
addresses = strings.Join(user.GetAddresses(), ",")
}
var (
restart = false
smtpSSL = b.settings.GetBool(settings.SMTPSSLKey)
)
// If configuring apple mail for Catalina or newer, users should use SSL.
if useragent.IsCatalinaOrNewer() && !smtpSSL {
smtpSSL = true
restart = true
b.settings.SetBool(settings.SMTPSSLKey, true)
if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() {
if err := bridge.SetSMTPSSL(true); err != nil {
return err
}
}
if err := (&clientconfig.AppleMail{}).Configure(
Host,
b.settings.GetInt(settings.IMAPPortKey),
b.settings.GetInt(settings.SMTPPortKey),
false, smtpSSL,
username, addresses,
user.GetBridgePassword(),
); err != nil {
return false, err
}
return restart, nil
return (&clientconfig.AppleMail{}).Configure(
constants.Host,
bridge.vault.GetIMAPPort(),
bridge.vault.GetSMTPPort(),
bridge.vault.GetIMAPSSL(),
bridge.vault.GetSMTPSSL(),
address,
strings.Join(user.Addresses(), ","),
user.BridgePass(),
)
}

View File

@ -1,23 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
// Host settings.
const (
Host = "127.0.0.1"
)

16
internal/bridge/errors.go Normal file
View File

@ -0,0 +1,16 @@
package bridge
import "errors"
var (
ErrServeIMAP = errors.New("failed to serve IMAP")
ErrServeSMTP = errors.New("failed to serve SMTP")
ErrWatchUpdates = errors.New("failed to watch for updates")
ErrNoSuchUser = errors.New("no such user")
ErrUserAlreadyExists = errors.New("user already exists")
ErrUserAlreadyLoggedIn = errors.New("user already logged in")
ErrNotImplemented = errors.New("not implemented")
ErrSizeTooLarge = errors.New("file is too big")
)

67
internal/bridge/files.go Normal file
View File

@ -0,0 +1,67 @@
package bridge
import (
"os"
"path/filepath"
)
func moveDir(from, to string) error {
entries, err := os.ReadDir(from)
if err != nil {
return err
}
for _, entry := range entries {
if entry.IsDir() {
if err := os.Mkdir(filepath.Join(to, entry.Name()), 0700); err != nil {
return err
}
if err := moveDir(filepath.Join(from, entry.Name()), filepath.Join(to, entry.Name())); err != nil {
return err
}
if err := os.RemoveAll(filepath.Join(from, entry.Name())); err != nil {
return err
}
} else {
if err := move(filepath.Join(from, entry.Name()), filepath.Join(to, entry.Name())); err != nil {
return err
}
}
}
return os.Remove(from)
}
func move(from, to string) error {
if err := os.MkdirAll(filepath.Dir(to), 0700); err != nil {
return err
}
f, err := os.Open(from)
if err != nil {
return err
}
defer f.Close()
c, err := os.Create(to)
if err != nil {
return err
}
defer c.Close()
if err := os.Chmod(to, 0600); err != nil {
return err
}
if _, err := c.ReadFrom(f); err != nil {
return err
}
if err := os.Remove(from); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,56 @@
package bridge
import (
"os"
"path/filepath"
"testing"
)
func TestMoveDir(t *testing.T) {
from, to := t.TempDir(), t.TempDir()
// Create some files in from.
if err := os.WriteFile(filepath.Join(from, "a"), []byte("a"), 0600); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(from, "b"), []byte("b"), 0600); err != nil {
t.Fatal(err)
}
if err := os.Mkdir(filepath.Join(from, "c"), 0700); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(from, "c", "d"), []byte("d"), 0600); err != nil {
t.Fatal(err)
}
// Move the files.
if err := moveDir(from, to); err != nil {
t.Fatal(err)
}
// Check that the files were moved.
if _, err := os.Stat(filepath.Join(from, "a")); !os.IsNotExist(err) {
t.Fatal(err)
}
if _, err := os.Stat(filepath.Join(to, "a")); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(filepath.Join(from, "b")); !os.IsNotExist(err) {
t.Fatal(err)
}
if _, err := os.Stat(filepath.Join(to, "b")); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(filepath.Join(from, "c")); !os.IsNotExist(err) {
t.Fatal(err)
}
if _, err := os.Stat(filepath.Join(to, "c")); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(filepath.Join(from, "c", "d")); !os.IsNotExist(err) {
t.Fatal(err)
}
if _, err := os.Stat(filepath.Join(to, "c", "d")); err != nil {
t.Fatal(err)
}
}

117
internal/bridge/imap.go Normal file
View File

@ -0,0 +1,117 @@
package bridge
import (
"context"
"crypto/tls"
"fmt"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon"
imapEvents "github.com/ProtonMail/gluon/events"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/sirupsen/logrus"
)
const (
defaultClientName = "UnknownClient"
defaultClientVersion = "0.0.1"
)
func (bridge *Bridge) GetIMAPPort() int {
return bridge.vault.GetIMAPPort()
}
func (bridge *Bridge) SetIMAPPort(newPort int) error {
if newPort == bridge.vault.GetIMAPPort() {
return nil
}
if err := bridge.vault.SetIMAPPort(newPort); err != nil {
return err
}
return bridge.restartIMAP(context.Background())
}
func (bridge *Bridge) GetIMAPSSL() bool {
return bridge.vault.GetIMAPSSL()
}
func (bridge *Bridge) SetIMAPSSL(newSSL bool) error {
if newSSL == bridge.vault.GetIMAPSSL() {
return nil
}
if err := bridge.vault.SetIMAPSSL(newSSL); err != nil {
return err
}
return bridge.restartIMAP(context.Background())
}
func (bridge *Bridge) serveIMAP() error {
imapListener, err := newListener(bridge.vault.GetIMAPPort(), bridge.vault.GetIMAPSSL(), bridge.tlsConfig)
if err != nil {
return fmt.Errorf("failed to create IMAP listener: %w", err)
}
bridge.imapListener = imapListener
return bridge.imapServer.Serve(context.Background(), bridge.imapListener)
}
func (bridge *Bridge) restartIMAP(ctx context.Context) error {
if err := bridge.imapListener.Close(); err != nil {
logrus.WithError(err).Warn("Failed to close IMAP listener")
}
return bridge.serveIMAP()
}
func (bridge *Bridge) closeIMAP(ctx context.Context) error {
if err := bridge.imapServer.Close(ctx); err != nil {
logrus.WithError(err).Warn("Failed to close IMAP server")
}
if err := bridge.imapListener.Close(); err != nil {
logrus.WithError(err).Warn("Failed to close IMAP listener")
}
return nil
}
func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) {
switch event := event.(type) {
case imapEvents.SessionAdded:
if !bridge.identifier.HasClient() {
bridge.identifier.SetClient(defaultClientName, defaultClientVersion)
}
case imapEvents.IMAPID:
bridge.identifier.SetClient(event.IMAPID.Name, event.IMAPID.Version)
}
}
func newIMAPServer(gluonDir string, version *semver.Version, tlsConfig *tls.Config) (*gluon.Server, error) {
imapServer, err := gluon.New(
gluon.WithTLS(tlsConfig),
gluon.WithDataDir(gluonDir),
gluon.WithVersionInfo(
int(version.Major()),
int(version.Minor()),
int(version.Patch()),
constants.FullAppName,
"TODO",
"TODO",
),
gluon.WithLogger(
logrus.StandardLogger().WriterLevel(logrus.InfoLevel),
logrus.StandardLogger().WriterLevel(logrus.InfoLevel),
),
)
if err != nil {
return nil, err
}
return imapServer, nil
}

View File

@ -1,30 +1,13 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
func (b *Bridge) ProvideLogsPath() (string, error) {
return b.locations.ProvideLogsPath()
func (bridge *Bridge) GetLogsPath() (string, error) {
return bridge.locator.ProvideLogsPath()
}
func (b *Bridge) GetLicenseFilePath() string {
return b.locations.GetLicenseFilePath()
func (bridge *Bridge) GetLicenseFilePath() string {
return bridge.locator.GetLicenseFilePath()
}
func (b *Bridge) GetDependencyLicensesLink() string {
return b.locations.GetDependencyLicensesLink()
func (bridge *Bridge) GetDependencyLicensesLink() string {
return bridge.locator.GetDependencyLicensesLink()
}

127
internal/bridge/mocks.go Normal file
View File

@ -0,0 +1,127 @@
package bridge
import (
"context"
"crypto/tls"
"errors"
"net"
"testing"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge/mocks"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/golang/mock/gomock"
)
type Mocks struct {
TLSDialer *TestDialer
ProxyDialer *mocks.MockProxyDialer
TLSReporter *mocks.MockTLSReporter
TLSIssueCh chan struct{}
Updater *TestUpdater
Autostarter *mocks.MockAutostarter
}
func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks {
ctl := gomock.NewController(tb)
mocks := &Mocks{
TLSDialer: NewTestDialer(),
ProxyDialer: mocks.NewMockProxyDialer(ctl),
TLSReporter: mocks.NewMockTLSReporter(ctl),
TLSIssueCh: make(chan struct{}),
Updater: NewTestUpdater(version, minAuto),
Autostarter: mocks.NewMockAutostarter(ctl),
}
// When using the proxy dialer, we want to use the test dialer.
mocks.ProxyDialer.EXPECT().DialTLSContext(
gomock.Any(),
gomock.Any(),
gomock.Any(),
).DoAndReturn(func(ctx context.Context, network, address string) (net.Conn, error) {
return mocks.TLSDialer.DialTLSContext(ctx, network, address)
}).AnyTimes()
// When getting the TLS issue channel, we want to return the test channel.
mocks.TLSReporter.EXPECT().GetTLSIssueCh().Return(mocks.TLSIssueCh).AnyTimes()
return mocks
}
type TestDialer struct {
canDial bool
}
func NewTestDialer() *TestDialer {
return &TestDialer{
canDial: true,
}
}
func (d *TestDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
if !d.canDial {
return nil, errors.New("cannot dial")
}
return (&tls.Dialer{Config: &tls.Config{InsecureSkipVerify: true}}).DialContext(ctx, network, address)
}
func (d *TestDialer) SetCanDial(canDial bool) {
d.canDial = canDial
}
type TestLocationsProvider struct {
config, cache string
}
func NewTestLocationsProvider(tb testing.TB) *TestLocationsProvider {
return &TestLocationsProvider{
config: tb.TempDir(),
cache: tb.TempDir(),
}
}
func (provider *TestLocationsProvider) UserConfig() string {
return provider.config
}
func (provider *TestLocationsProvider) UserCache() string {
return provider.cache
}
type TestUpdater struct {
latest updater.VersionInfo
}
func NewTestUpdater(version, minAuto *semver.Version) *TestUpdater {
return &TestUpdater{
latest: updater.VersionInfo{
Version: version,
MinAuto: minAuto,
RolloutProportion: 1.0,
},
}
}
func (testUpdater *TestUpdater) SetLatestVersion(version, minAuto *semver.Version) {
testUpdater.latest = updater.VersionInfo{
Version: version,
MinAuto: minAuto,
RolloutProportion: 1.0,
}
}
func (updater *TestUpdater) GetVersionInfo(downloader updater.Downloader, channel updater.Channel) (updater.VersionInfo, error) {
return updater.latest, nil
}
func (updater *TestUpdater) InstallUpdate(downloader updater.Downloader, update updater.VersionInfo) error {
return nil
}

View File

@ -0,0 +1,163 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/v2/internal/bridge (interfaces: TLSReporter,ProxyDialer,Autostarter)
// Package mocks is a generated GoMock package.
package mocks
import (
context "context"
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockTLSReporter is a mock of TLSReporter interface.
type MockTLSReporter struct {
ctrl *gomock.Controller
recorder *MockTLSReporterMockRecorder
}
// MockTLSReporterMockRecorder is the mock recorder for MockTLSReporter.
type MockTLSReporterMockRecorder struct {
mock *MockTLSReporter
}
// NewMockTLSReporter creates a new mock instance.
func NewMockTLSReporter(ctrl *gomock.Controller) *MockTLSReporter {
mock := &MockTLSReporter{ctrl: ctrl}
mock.recorder = &MockTLSReporterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTLSReporter) EXPECT() *MockTLSReporterMockRecorder {
return m.recorder
}
// GetTLSIssueCh mocks base method.
func (m *MockTLSReporter) GetTLSIssueCh() <-chan struct{} {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTLSIssueCh")
ret0, _ := ret[0].(<-chan struct{})
return ret0
}
// GetTLSIssueCh indicates an expected call of GetTLSIssueCh.
func (mr *MockTLSReporterMockRecorder) GetTLSIssueCh() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTLSIssueCh", reflect.TypeOf((*MockTLSReporter)(nil).GetTLSIssueCh))
}
// MockProxyDialer is a mock of ProxyDialer interface.
type MockProxyDialer struct {
ctrl *gomock.Controller
recorder *MockProxyDialerMockRecorder
}
// MockProxyDialerMockRecorder is the mock recorder for MockProxyDialer.
type MockProxyDialerMockRecorder struct {
mock *MockProxyDialer
}
// NewMockProxyDialer creates a new mock instance.
func NewMockProxyDialer(ctrl *gomock.Controller) *MockProxyDialer {
mock := &MockProxyDialer{ctrl: ctrl}
mock.recorder = &MockProxyDialerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProxyDialer) EXPECT() *MockProxyDialerMockRecorder {
return m.recorder
}
// AllowProxy mocks base method.
func (m *MockProxyDialer) AllowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AllowProxy")
}
// AllowProxy indicates an expected call of AllowProxy.
func (mr *MockProxyDialerMockRecorder) AllowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockProxyDialer)(nil).AllowProxy))
}
// DialTLSContext mocks base method.
func (m *MockProxyDialer) DialTLSContext(arg0 context.Context, arg1, arg2 string) (net.Conn, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DialTLSContext", arg0, arg1, arg2)
ret0, _ := ret[0].(net.Conn)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DialTLSContext indicates an expected call of DialTLSContext.
func (mr *MockProxyDialerMockRecorder) DialTLSContext(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTLSContext", reflect.TypeOf((*MockProxyDialer)(nil).DialTLSContext), arg0, arg1, arg2)
}
// DisallowProxy mocks base method.
func (m *MockProxyDialer) DisallowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DisallowProxy")
}
// DisallowProxy indicates an expected call of DisallowProxy.
func (mr *MockProxyDialerMockRecorder) DisallowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockProxyDialer)(nil).DisallowProxy))
}
// MockAutostarter is a mock of Autostarter interface.
type MockAutostarter struct {
ctrl *gomock.Controller
recorder *MockAutostarterMockRecorder
}
// MockAutostarterMockRecorder is the mock recorder for MockAutostarter.
type MockAutostarterMockRecorder struct {
mock *MockAutostarter
}
// NewMockAutostarter creates a new mock instance.
func NewMockAutostarter(ctrl *gomock.Controller) *MockAutostarter {
mock := &MockAutostarter{ctrl: ctrl}
mock.recorder = &MockAutostarterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockAutostarter) EXPECT() *MockAutostarterMockRecorder {
return m.recorder
}
// Disable mocks base method.
func (m *MockAutostarter) Disable() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Disable")
ret0, _ := ret[0].(error)
return ret0
}
// Disable indicates an expected call of Disable.
func (mr *MockAutostarterMockRecorder) Disable() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disable", reflect.TypeOf((*MockAutostarter)(nil).Disable))
}
// Enable mocks base method.
func (m *MockAutostarter) Enable() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Enable")
ret0, _ := ret[0].(error)
return ret0
}
// Enable indicates an expected call of Enable.
func (mr *MockAutostarterMockRecorder) Enable() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Enable", reflect.TypeOf((*MockAutostarter)(nil).Enable))
}

View File

@ -1,26 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Code generated by ./release-notes.sh at 'Fri Jan 22 11:01:06 AM CET 2021'. DO NOT EDIT.
package bridge
const ReleaseNotes = `
`
const ReleaseFixedBugs = `• Fixed sending error caused by inconsistent use of upper and lower case in senders email address
`

View File

@ -1,44 +1,175 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
import "github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
import (
"context"
func (b *Bridge) Get(key settings.Key) string {
return b.settings.Get(key)
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
)
func (bridge *Bridge) GetKeychainApp() (string, error) {
vaultDir, err := bridge.locator.ProvideSettingsPath()
if err != nil {
return "", err
}
return vault.GetHelper(vaultDir)
}
func (b *Bridge) Set(key settings.Key, value string) {
b.settings.Set(key, value)
func (bridge *Bridge) SetKeychainApp(helper string) error {
vaultDir, err := bridge.locator.ProvideSettingsPath()
if err != nil {
return err
}
return vault.SetHelper(vaultDir, helper)
}
func (b *Bridge) GetBool(key settings.Key) bool {
return b.settings.GetBool(key)
func (bridge *Bridge) GetGluonDir() string {
return bridge.vault.GetGluonDir()
}
func (b *Bridge) SetBool(key settings.Key, value bool) {
b.settings.SetBool(key, value)
func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error {
if newGluonDir == bridge.GetGluonDir() {
return nil
}
if err := bridge.closeIMAP(context.Background()); err != nil {
return err
}
if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil {
return err
}
if err := bridge.vault.SetGluonDir(newGluonDir); err != nil {
return err
}
imapServer, err := newIMAPServer(bridge.vault.GetGluonDir(), bridge.curVersion, bridge.tlsConfig)
if err != nil {
return err
}
for _, user := range bridge.users {
imapConn, err := user.NewGluonConnector(ctx)
if err != nil {
return err
}
if err := imapServer.LoadUser(context.Background(), imapConn, user.GluonID(), user.GluonKey()); err != nil {
return err
}
}
bridge.imapServer = imapServer
return bridge.serveIMAP()
}
func (b *Bridge) GetInt(key settings.Key) int {
return b.settings.GetInt(key)
func (bridge *Bridge) GetProxyAllowed() bool {
return bridge.vault.GetProxyAllowed()
}
func (b *Bridge) SetInt(key settings.Key, value int) {
b.settings.SetInt(key, value)
func (bridge *Bridge) SetProxyAllowed(allowed bool) error {
if allowed {
bridge.proxyDialer.AllowProxy()
} else {
bridge.proxyDialer.DisallowProxy()
}
return bridge.vault.SetProxyAllowed(allowed)
}
func (bridge *Bridge) GetShowAllMail() bool {
return bridge.vault.GetShowAllMail()
}
func (bridge *Bridge) SetShowAllMail(show bool) error {
panic("TODO")
}
func (bridge *Bridge) GetAutostart() bool {
return bridge.vault.GetAutostart()
}
func (bridge *Bridge) SetAutostart(autostart bool) error {
if err := bridge.vault.SetAutostart(autostart); err != nil {
return err
}
var err error
if autostart {
err = bridge.autostarter.Enable()
} else {
err = bridge.autostarter.Disable()
}
return err
}
func (bridge *Bridge) GetAutoUpdate() bool {
return bridge.vault.GetAutoUpdate()
}
func (bridge *Bridge) SetAutoUpdate(autoUpdate bool) error {
if bridge.vault.GetAutoUpdate() == autoUpdate {
return nil
}
if err := bridge.vault.SetAutoUpdate(autoUpdate); err != nil {
return err
}
bridge.updateCheckCh <- struct{}{}
return nil
}
func (bridge *Bridge) GetUpdateChannel() updater.Channel {
return updater.Channel(bridge.vault.GetUpdateChannel())
}
func (bridge *Bridge) SetUpdateChannel(channel updater.Channel) error {
if bridge.vault.GetUpdateChannel() == channel {
return nil
}
if err := bridge.vault.SetUpdateChannel(channel); err != nil {
return err
}
bridge.updateCheckCh <- struct{}{}
return nil
}
func (bridge *Bridge) GetLastVersion() *semver.Version {
return bridge.vault.GetLastVersion()
}
func (bridge *Bridge) GetFirstStart() bool {
return bridge.vault.GetFirstStart()
}
func (bridge *Bridge) SetFirstStart(firstStart bool) error {
return bridge.vault.SetFirstStart(firstStart)
}
func (bridge *Bridge) GetFirstStartGUI() bool {
return bridge.vault.GetFirstStartGUI()
}
func (bridge *Bridge) SetFirstStartGUI(firstStart bool) error {
return bridge.vault.SetFirstStartGUI(firstStart)
}
func (bridge *Bridge) GetColorScheme() string {
return bridge.vault.GetColorScheme()
}
func (bridge *Bridge) SetColorScheme(colorScheme string) error {
return bridge.vault.SetColorScheme(colorScheme)
}

View File

@ -0,0 +1,156 @@
package bridge_test
import (
"context"
"os"
"testing"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi/server"
)
func TestBridge_Settings_GluonDir(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Create a user.
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
require.NoError(t, err)
// Create a new location for the Gluon data.
newGluonDir := t.TempDir()
// Move the gluon dir; it should also move the user's data.
require.NoError(t, bridge.SetGluonDir(context.Background(), newGluonDir))
// Check that the new directory is not empty.
entries, err := os.ReadDir(newGluonDir)
require.NoError(t, err)
// There should be at least one entry.
require.NotEmpty(t, entries)
})
})
}
func TestBridge_Settings_IMAPPort(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, the port is 1143.
require.Equal(t, 1143, bridge.GetIMAPPort())
// Set the port to 1144.
require.NoError(t, bridge.SetIMAPPort(1144))
// Get the new setting.
require.Equal(t, 1144, bridge.GetIMAPPort())
})
})
}
func TestBridge_Settings_IMAPSSL(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, IMAP SSL is disabled.
require.False(t, bridge.GetIMAPSSL())
// Enable IMAP SSL.
require.NoError(t, bridge.SetIMAPSSL(true))
// Get the new setting.
require.True(t, bridge.GetIMAPSSL())
})
})
}
func TestBridge_Settings_SMTPPort(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, the port is 1025.
require.Equal(t, 1025, bridge.GetSMTPPort())
// Set the port to 1024.
require.NoError(t, bridge.SetSMTPPort(1024))
// Get the new setting.
require.Equal(t, 1024, bridge.GetSMTPPort())
})
})
}
func TestBridge_Settings_SMTPSSL(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, SMTP SSL is disabled.
require.False(t, bridge.GetSMTPSSL())
// Enable SMTP SSL.
require.NoError(t, bridge.SetSMTPSSL(true))
// Get the new setting.
require.True(t, bridge.GetSMTPSSL())
})
})
}
func TestBridge_Settings_Proxy(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, proxy is allowed.
require.True(t, bridge.GetProxyAllowed())
// Disallow proxy.
mocks.ProxyDialer.EXPECT().DisallowProxy()
require.NoError(t, bridge.SetProxyAllowed(false))
// Get the new setting.
require.False(t, bridge.GetProxyAllowed())
})
})
}
func TestBridge_Settings_Autostart(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, autostart is disabled.
require.False(t, bridge.GetAutostart())
// Enable autostart.
mocks.Autostarter.EXPECT().Enable().Return(nil)
require.NoError(t, bridge.SetAutostart(true))
// Get the new setting.
require.True(t, bridge.GetAutostart())
})
})
}
func TestBridge_Settings_FirstStart(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, first start is true.
require.True(t, bridge.GetFirstStart())
// Set first start to false.
require.NoError(t, bridge.SetFirstStart(false))
// Get the new setting.
require.False(t, bridge.GetFirstStart())
})
})
}
func TestBridge_Settings_FirstStartGUI(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// By default, first start is true.
require.True(t, bridge.GetFirstStartGUI())
// Set first start to false.
require.NoError(t, bridge.SetFirstStartGUI(false))
// Get the new setting.
require.False(t, bridge.GetFirstStartGUI())
})
})
}

109
internal/bridge/smtp.go Normal file
View File

@ -0,0 +1,109 @@
package bridge
import (
"crypto/tls"
"fmt"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/emersion/go-sasl"
"github.com/emersion/go-smtp"
"github.com/sirupsen/logrus"
)
func (bridge *Bridge) GetSMTPPort() int {
return bridge.vault.GetSMTPPort()
}
func (bridge *Bridge) SetSMTPPort(newPort int) error {
if newPort == bridge.vault.GetSMTPPort() {
return nil
}
if err := bridge.vault.SetSMTPPort(newPort); err != nil {
return err
}
return bridge.restartSMTP()
}
func (bridge *Bridge) GetSMTPSSL() bool {
return bridge.vault.GetSMTPSSL()
}
func (bridge *Bridge) SetSMTPSSL(newSSL bool) error {
if newSSL == bridge.vault.GetSMTPSSL() {
return nil
}
if err := bridge.vault.SetSMTPSSL(newSSL); err != nil {
return err
}
return bridge.restartSMTP()
}
func (bridge *Bridge) serveSMTP() error {
smtpListener, err := newListener(bridge.vault.GetSMTPPort(), bridge.vault.GetSMTPSSL(), bridge.tlsConfig)
if err != nil {
return fmt.Errorf("failed to create SMTP listener: %w", err)
}
bridge.smtpListener = smtpListener
go func() {
if err := bridge.smtpServer.Serve(bridge.smtpListener); err != nil {
logrus.WithError(err).Error("SMTP server stopped")
}
}()
return nil
}
func (bridge *Bridge) restartSMTP() error {
if err := bridge.closeSMTP(); err != nil {
return err
}
smtpServer, err := newSMTPServer(bridge.smtpBackend, bridge.tlsConfig)
if err != nil {
return err
}
bridge.smtpServer = smtpServer
return bridge.serveSMTP()
}
func (bridge *Bridge) closeSMTP() error {
if err := bridge.smtpServer.Close(); err != nil {
logrus.WithError(err).Warn("Failed to close SMTP server")
}
// Don't close the SMTP listener -- it's closed by the server.
return nil
}
func newSMTPServer(smtpBackend *smtpBackend, tlsConfig *tls.Config) (*smtp.Server, error) {
smtpServer := smtp.NewServer(smtpBackend)
smtpServer.TLSConfig = tlsConfig
smtpServer.Domain = constants.Host
smtpServer.AllowInsecureAuth = true
smtpServer.MaxLineLength = 1 << 16
smtpServer.EnableAuth(sasl.Login, func(conn *smtp.Conn) sasl.Server {
return sasl.NewLoginServer(func(address, password string) error {
user, err := conn.Server().Backend.Login(nil, address, password)
if err != nil {
return err
}
conn.SetSession(user)
return nil
})
})
return smtpServer, nil
}

View File

@ -0,0 +1,70 @@
package bridge
import (
"sync"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-smtp"
"golang.org/x/exp/slices"
)
type smtpBackend struct {
users []*user.User
usersLock sync.RWMutex
}
func newSMTPBackend() (*smtpBackend, error) {
return &smtpBackend{}, nil
}
func (backend *smtpBackend) Login(state *smtp.ConnectionState, username string, password string) (smtp.Session, error) {
backend.usersLock.RLock()
defer backend.usersLock.RUnlock()
for _, user := range backend.users {
if slices.Contains(user.Addresses(), username) && user.BridgePass() == password {
return user.NewSMTPSession(username)
}
}
return nil, ErrNoSuchUser
}
func (backend *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
return nil, ErrNotImplemented
}
// addUser adds the given user to the backend.
// It returns an error if a user with the same ID already exists.
func (backend *smtpBackend) addUser(user *user.User) error {
backend.usersLock.Lock()
defer backend.usersLock.Unlock()
for _, u := range backend.users {
if u.ID() == user.ID() {
return ErrUserAlreadyExists
}
}
backend.users = append(backend.users, user)
return nil
}
// removeUser removes the given user from the backend.
// It returns an error if the user doesn't exist.
func (backend *smtpBackend) removeUser(user *user.User) error {
backend.usersLock.Lock()
defer backend.usersLock.Unlock()
idx := xslices.Index(backend.users, user)
if idx < 0 {
return ErrNoSuchUser
}
backend.users = append(backend.users[:idx], backend.users[idx+1:]...)
return nil
}

View File

@ -1,87 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
import (
"fmt"
"path/filepath"
"github.com/ProtonMail/proton-bridge/v2/internal/sentry"
"github.com/ProtonMail/proton-bridge/v2/internal/store"
"github.com/ProtonMail/proton-bridge/v2/internal/store/cache"
"github.com/ProtonMail/proton-bridge/v2/internal/users"
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
)
type storeFactory struct {
cacheProvider CacheProvider
sentryReporter *sentry.Reporter
panicHandler users.PanicHandler
eventListener listener.Listener
events *store.Events
cache cache.Cache
builder *message.Builder
}
func newStoreFactory(
cacheProvider CacheProvider,
sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler,
eventListener listener.Listener,
cache cache.Cache,
builder *message.Builder,
) *storeFactory {
return &storeFactory{
cacheProvider: cacheProvider,
sentryReporter: sentryReporter,
panicHandler: panicHandler,
eventListener: eventListener,
events: store.NewEvents(cacheProvider.GetIMAPCachePath()),
cache: cache,
builder: builder,
}
}
// New creates new store for given user.
func (f *storeFactory) New(user store.BridgeUser) (*store.Store, error) {
return store.New(
f.sentryReporter,
f.panicHandler,
user,
f.eventListener,
f.cache,
f.builder,
getUserStorePath(f.cacheProvider.GetDBDir(), user.ID()),
f.events,
)
}
// Remove removes all store files for given user.
func (f *storeFactory) Remove(userID string) error {
return store.RemoveStore(
f.events,
getUserStorePath(f.cacheProvider.GetDBDir(), userID),
userID,
)
}
// getUserStorePath returns the file path of the store database for the given userID.
func getUserStorePath(storeDir string, userID string) (path string) {
return filepath.Join(storeDir, fmt.Sprintf("mailbox-%v.db", userID))
}

View File

@ -1,64 +1,5 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
import (
"crypto/tls"
pkgTLS "github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
"github.com/pkg/errors"
logrus "github.com/sirupsen/logrus"
)
func (b *Bridge) GetTLSConfig() (*tls.Config, error) {
if !b.tls.HasCerts() {
if err := b.generateTLSCerts(); err != nil {
return nil, err
}
}
tlsConfig, err := b.tls.GetConfig()
if err == nil {
return tlsConfig, nil
}
logrus.WithError(err).Error("Failed to load TLS config, regenerating certificates")
if err := b.generateTLSCerts(); err != nil {
return nil, err
}
return b.tls.GetConfig()
}
func (b *Bridge) generateTLSCerts() error {
template, err := pkgTLS.NewTLSTemplate()
if err != nil {
return errors.Wrap(err, "failed to generate TLS template")
}
if err := b.tls.GenerateCerts(template); err != nil {
return errors.Wrap(err, "failed to generate TLS certs")
}
if err := b.tls.InstallCerts(); err != nil {
return errors.Wrap(err, "failed to install TLS certs")
}
return nil
func (bridge *Bridge) GetBridgeTLSCert() ([]byte, []byte) {
return bridge.vault.GetBridgeTLSCert(), bridge.vault.GetBridgeTLSKey()
}

View File

@ -1,62 +1,43 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
import (
"github.com/Masterminds/semver/v3"
"context"
"net"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
)
type Locator interface {
ProvideSettingsPath() (string, error)
ProvideLogsPath() (string, error)
GetLicenseFilePath() string
GetDependencyLicensesLink() string
Clear() error
ClearUpdates() error
}
type CacheProvider interface {
GetIMAPCachePath() string
GetDBDir() string
GetDefaultMessageCacheDir() string
type Identifier interface {
GetUserAgent() string
HasClient() bool
SetClient(name, version string)
SetPlatform(platform string)
}
type SettingsProvider interface {
Get(key settings.Key) string
Set(key settings.Key, value string)
type TLSReporter interface {
GetTLSIssueCh() <-chan struct{}
}
GetBool(key settings.Key) bool
SetBool(key settings.Key, val bool)
type ProxyDialer interface {
DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error)
GetInt(key settings.Key) int
SetInt(key settings.Key, val int)
AllowProxy()
DisallowProxy()
}
type Autostarter interface {
Enable() error
Disable() error
}
type Updater interface {
Check() (updater.VersionInfo, error)
IsDowngrade(updater.VersionInfo) bool
InstallUpdate(updater.VersionInfo) error
}
type Versioner interface {
RemoveOtherVersions(*semver.Version) error
GetVersionInfo(downloader updater.Downloader, channel updater.Channel) (updater.VersionInfo, error)
InstallUpdate(downloader updater.Downloader, update updater.VersionInfo) error
}

View File

@ -0,0 +1,72 @@
package bridge
import (
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
)
func (bridge *Bridge) CheckForUpdates() {
bridge.updateCheckCh <- struct{}{}
}
func (bridge *Bridge) watchForUpdates() error {
ticker := time.NewTicker(constants.UpdateCheckInterval)
go func() {
for {
select {
case <-bridge.updateCheckCh:
case <-ticker.C:
}
version, err := bridge.updater.GetVersionInfo(bridge.api, bridge.vault.GetUpdateChannel())
if err != nil {
continue
}
if err := bridge.handleUpdate(version); err != nil {
continue
}
}
}()
bridge.updateCheckCh <- struct{}{}
return nil
}
func (bridge *Bridge) handleUpdate(version updater.VersionInfo) error {
switch {
case !version.Version.GreaterThan(bridge.curVersion):
bridge.publish(events.UpdateNotAvailable{})
case version.RolloutProportion < bridge.vault.GetUpdateRollout():
bridge.publish(events.UpdateNotAvailable{})
case bridge.curVersion.LessThan(version.MinAuto):
bridge.publish(events.UpdateAvailable{
Version: version,
CanInstall: false,
})
case !bridge.vault.GetAutoUpdate():
bridge.publish(events.UpdateAvailable{
Version: version,
CanInstall: true,
})
default:
if err := bridge.updater.InstallUpdate(bridge.api, version); err != nil {
return err
}
bridge.publish(events.UpdateInstalled{
Version: version,
})
}
return nil
}

View File

@ -1,26 +1,9 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge
func (b *Bridge) GetCurrentUserAgent() string {
return b.userAgent.String()
func (bridge *Bridge) GetCurrentUserAgent() string {
return bridge.identifier.GetUserAgent()
}
func (b *Bridge) SetCurrentPlatform(platform string) {
b.userAgent.SetPlatform(platform)
func (bridge *Bridge) SetCurrentPlatform(platform string) {
bridge.identifier.SetPlatform(platform)
}

434
internal/bridge/users.go Normal file
View File

@ -0,0 +1,434 @@
package bridge
import (
"context"
"fmt"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/go-resty/resty/v2"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/slices"
)
type UserInfo struct {
// UserID is the user's API ID.
UserID string
// Username is the user's API username.
Username string
// Connected is true if the user is logged in (has API auth).
Connected bool
// Addresses holds the user's email addresses. The first address is the primary address.
Addresses []string
// AddressMode is the user's address mode.
AddressMode AddressMode
// BridgePass is the user's bridge password.
BridgePass string
// UsedSpace is the amount of space used by the user.
UsedSpace int
// MaxSpace is the total amount of space available to the user.
MaxSpace int
}
type AddressMode int
const (
SplitMode AddressMode = iota
CombinedMode
)
// GetUserIDs returns the IDs of all known users (authorized or not).
func (bridge *Bridge) GetUserIDs() []string {
return bridge.vault.GetUserIDs()
}
// GetUserInfo returns info about the given user.
func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
vaultUser, err := bridge.vault.GetUser(userID)
if err != nil {
return UserInfo{}, err
}
user, ok := bridge.users[userID]
if !ok {
return getUserInfo(vaultUser.UserID(), vaultUser.Username()), nil
}
return getConnUserInfo(user), nil
}
// QueryUserInfo queries the user info by username or address.
func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) {
for userID, user := range bridge.users {
if user.Match(query) {
return bridge.GetUserInfo(userID)
}
}
return UserInfo{}, ErrNoSuchUser
}
// LoginUser authorizes a new bridge user with the given username and password.
// If necessary, a TOTP and mailbox password are requested via the callbacks.
func (bridge *Bridge) LoginUser(
ctx context.Context,
username, password string,
getTOTP func() (string, error),
getKeyPass func() ([]byte, error),
) (string, error) {
client, auth, err := bridge.api.NewClientWithLogin(ctx, username, password)
if err != nil {
return "", err
}
if auth.TwoFA.Enabled == liteapi.TOTPEnabled {
totp, err := getTOTP()
if err != nil {
return "", err
}
if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil {
return "", err
}
}
var keyPass []byte
if auth.PasswordMode == liteapi.TwoPasswordMode {
pass, err := getKeyPass()
if err != nil {
return "", err
}
keyPass = pass
} else {
keyPass = []byte(password)
}
apiUser, apiAddrs, userKR, addrKRs, saltedKeyPass, err := client.Unlock(ctx, keyPass)
if err != nil {
return "", err
}
if err := bridge.addUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, auth.UID, auth.RefreshToken, saltedKeyPass); err != nil {
return "", err
}
return apiUser.ID, nil
}
// LogoutUser logs out the given user.
func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error {
return bridge.logoutUser(ctx, userID, true, false)
}
// DeleteUser deletes the given user.
// If it is authorized, it is logged out first.
func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
if bridge.users[userID] != nil {
if err := bridge.logoutUser(ctx, userID, true, true); err != nil {
return err
}
}
if err := bridge.vault.DeleteUser(userID); err != nil {
return err
}
bridge.publish(events.UserDeleted{
UserID: userID,
})
return nil
}
func (bridge *Bridge) GetAddressMode(userID string) (AddressMode, error) {
panic("TODO")
}
func (bridge *Bridge) SetAddressMode(userID string, mode AddressMode) error {
panic("TODO")
}
// loadUsers loads authorized users from the vault.
func (bridge *Bridge) loadUsers(ctx context.Context) error {
for _, userID := range bridge.vault.GetUserIDs() {
user, err := bridge.vault.GetUser(userID)
if err != nil {
return err
}
if user.AuthUID() == "" {
continue
}
if err := bridge.loadUser(ctx, user); err != nil {
logrus.WithError(err).Error("Failed to load connected user")
if err := user.Clear(); err != nil {
logrus.WithError(err).Error("Failed to clear user")
}
continue
}
}
return nil
}
func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error {
client, auth, err := bridge.api.NewClientWithRefresh(ctx, user.AuthUID(), user.AuthRef())
if err != nil {
return fmt.Errorf("failed to create API client: %w", err)
}
apiUser, apiAddrs, userKR, addrKRs, err := client.UnlockSalted(ctx, user.KeyPass())
if err != nil {
return fmt.Errorf("failed to unlock user: %w", err)
}
if err := bridge.addUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, auth.UID, auth.RefreshToken, user.KeyPass()); err != nil {
return fmt.Errorf("failed to add user: %w", err)
}
bridge.publish(events.UserLoggedIn{
UserID: user.UserID(),
})
return nil
}
// addUser adds a new user with an already salted mailbox password.
func (bridge *Bridge) addUser(
ctx context.Context,
client *liteapi.Client,
apiUser liteapi.User,
apiAddrs []liteapi.Address,
userKR *crypto.KeyRing,
addrKRs map[string]*crypto.KeyRing,
authUID, authRef string,
saltedKeyPass []byte,
) error {
if _, ok := bridge.users[apiUser.ID]; ok {
return ErrUserAlreadyLoggedIn
}
var user *user.User
if slices.Contains(bridge.vault.GetUserIDs(), apiUser.ID) {
existingUser, err := bridge.addExistingUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass)
if err != nil {
return err
}
user = existingUser
} else {
newUser, err := bridge.addNewUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass)
if err != nil {
return err
}
user = newUser
}
go func() {
for event := range user.GetNotifyCh() {
switch event := event.(type) {
case events.UserDeauth:
if err := bridge.logoutUser(context.Background(), event.UserID, false, false); err != nil {
logrus.WithError(err).Error("Failed to logout user")
}
}
bridge.publish(event)
}
}()
// Gluon will set the IMAP ID in the context, if known, before making requests on behalf of this user.
client.AddPreRequestHook(func(ctx context.Context, req *resty.Request) error {
if imapID, ok := imap.GetIMAPIDFromContext(ctx); ok {
bridge.identifier.SetClient(imapID.Name, imapID.Version)
}
return nil
})
bridge.publish(events.UserLoggedIn{
UserID: user.ID(),
})
return nil
}
func (bridge *Bridge) addNewUser(
ctx context.Context,
client *liteapi.Client,
apiUser liteapi.User,
apiAddrs []liteapi.Address,
userKR *crypto.KeyRing,
addrKRs map[string]*crypto.KeyRing,
authUID, authRef string,
saltedKeyPass []byte,
) (*user.User, error) {
vaultUser, err := bridge.vault.AddUser(apiUser.ID, apiUser.Name, authUID, authRef, saltedKeyPass)
if err != nil {
return nil, err
}
user, err := user.New(ctx, vaultUser, client, apiUser, apiAddrs, userKR, addrKRs)
if err != nil {
return nil, err
}
gluonKey, err := crypto.RandomToken(32)
if err != nil {
return nil, err
}
imapConn, err := user.NewGluonConnector(ctx)
if err != nil {
return nil, err
}
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, gluonKey)
if err != nil {
return nil, err
}
if err := vaultUser.UpdateGluonData(gluonID, gluonKey); err != nil {
return nil, err
}
if err := bridge.smtpBackend.addUser(user); err != nil {
return nil, err
}
bridge.users[apiUser.ID] = user
return user, nil
}
func (bridge *Bridge) addExistingUser(
ctx context.Context,
client *liteapi.Client,
apiUser liteapi.User,
apiAddrs []liteapi.Address,
userKR *crypto.KeyRing,
addrKRs map[string]*crypto.KeyRing,
authUID, authRef string,
saltedKeyPass []byte,
) (*user.User, error) {
vaultUser, err := bridge.vault.GetUser(apiUser.ID)
if err != nil {
return nil, err
}
if err := vaultUser.UpdateAuth(authUID, authRef); err != nil {
return nil, err
}
if err := vaultUser.UpdateKeyPass(saltedKeyPass); err != nil {
return nil, err
}
user, err := user.New(ctx, vaultUser, client, apiUser, apiAddrs, userKR, addrKRs)
if err != nil {
return nil, err
}
imapConn, err := user.NewGluonConnector(ctx)
if err != nil {
return nil, err
}
if err := bridge.imapServer.LoadUser(ctx, imapConn, user.GluonID(), user.GluonKey()); err != nil {
return nil, err
}
if err := bridge.smtpBackend.addUser(user); err != nil {
return nil, err
}
bridge.users[apiUser.ID] = user
return user, nil
}
// logoutUser closes and removes the user with the given ID.
// If withAPI is true, the user will additionally be logged out from API.
// If withFiles is true, the user's files will be deleted.
func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, withFiles bool) error {
user, ok := bridge.users[userID]
if !ok {
return ErrNoSuchUser
}
vaultUser, err := bridge.vault.GetUser(userID)
if err != nil {
return err
}
if err := bridge.imapServer.RemoveUser(ctx, vaultUser.GluonID(), withFiles); err != nil {
return err
}
if err := bridge.smtpBackend.removeUser(user); err != nil {
return err
}
if withAPI {
if err := user.Logout(ctx); err != nil {
return err
}
}
if err := user.Close(ctx); err != nil {
return err
}
if err := vaultUser.Clear(); err != nil {
return err
}
delete(bridge.users, userID)
bridge.publish(events.UserLoggedOut{
UserID: userID,
})
return nil
}
// getUserInfo returns information about a disconnected user.
func getUserInfo(userID, username string) UserInfo {
return UserInfo{
UserID: userID,
Username: username,
AddressMode: CombinedMode,
}
}
// getConnUserInfo returns information about a connected user.
func getConnUserInfo(user *user.User) UserInfo {
return UserInfo{
Connected: true,
UserID: user.ID(),
Username: user.Name(),
Addresses: user.Addresses(),
AddressMode: CombinedMode,
BridgePass: user.BridgePass(),
UsedSpace: user.UsedSpace(),
MaxSpace: user.MaxSpace(),
}
}

View File

@ -0,0 +1,286 @@
package bridge_test
import (
"context"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi/server"
)
func TestBridge_WithoutUsers(t *testing.T) {
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_Login(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
require.NoError(t, err)
// The user is now connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginLogoutLogin(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
// The user is now connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
// Logout the user.
require.NoError(t, bridge.LogoutUser(ctx, userID))
// The user is now disconnected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
// Login the user again.
newUserID := must(bridge.LoginUser(ctx, username, password, nil, nil))
require.Equal(t, userID, newUserID)
// The user is connected again.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginDeleteLogin(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
// The user is now connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
// Delete the user.
require.NoError(t, bridge.DeleteUser(ctx, userID))
// The user is now gone.
require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
// Login the user again.
newUserID := must(bridge.LoginUser(ctx, username, password, nil, nil))
require.Equal(t, userID, newUserID)
// The user is connected again.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginDeauthLogin(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
// Get a channel to receive the deauth event.
eventCh, done := bridge.GetEvents(events.UserDeauth{})
defer done()
// Deauth the user.
require.NoError(t, s.RevokeUser(userID))
// The user is eventually disconnected.
require.Eventually(t, func() bool {
return len(getConnectedUserIDs(t, bridge)) == 0
}, 10*time.Second, time.Second)
// We should get a deauth event.
require.IsType(t, events.UserDeauth{}, <-eventCh)
// Login the user after the disconnection.
newUserID := must(bridge.LoginUser(ctx, username, password, nil, nil))
require.Equal(t, userID, newUserID)
// The user is connected again.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginExpireLogin(t *testing.T) {
const authLife = 2 * time.Second
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
s.SetAuthLife(authLife)
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user. Its auth will only be valid for a short time.
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
// Wait until the auth expires.
time.Sleep(authLife)
// The user will have to refresh but the logout will still succeed.
require.NoError(t, bridge.LogoutUser(ctx, userID))
})
})
}
func TestBridge_FailToLoad(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
var userID string
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
})
// Deauth the user while bridge is stopped.
require.NoError(t, s.RevokeUser(userID))
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user is disconnected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginRestart(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
var userID string
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
})
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user is still connected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginLogoutRestart(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
var userID string
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
// Logout the user.
require.NoError(t, bridge.LogoutUser(ctx, userID))
})
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user is still disconnected.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_LoginDeleteRestart(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
var userID string
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
// Delete the user.
require.NoError(t, bridge.DeleteUser(ctx, userID))
})
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The user is still gone.
require.Empty(t, bridge.GetUserIDs())
require.Empty(t, getConnectedUserIDs(t, bridge))
})
})
}
func TestBridge_BridgePass(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
var userID, pass string
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Login the user.
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
// Retrieve the bridge pass.
pass = must(bridge.GetUserInfo(userID)).BridgePass
// Log the user out.
require.NoError(t, bridge.LogoutUser(ctx, userID))
// Log the user back in.
must(bridge.LoginUser(ctx, username, password, nil, nil))
// The bridge pass should be the same.
require.Equal(t, pass, pass)
})
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// The bridge should load schizofrenic.
require.Equal(t, []string{userID}, bridge.GetUserIDs())
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
// The bridge pass should be the same.
require.Equal(t, pass, must(bridge.GetUserInfo(userID)).BridgePass)
})
})
}

View File

@ -15,9 +15,31 @@
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package tls
package certs
import "golang.org/x/sys/execabs"
import (
"os"
"golang.org/x/sys/execabs"
)
func installCert(certPEM []byte) error {
name, err := writeToTempFile(certPEM)
if err != nil {
return err
}
return addTrustedCert(name)
}
func uninstallCert(certPEM []byte) error {
name, err := writeToTempFile(certPEM)
if err != nil {
return err
}
return removeTrustedCert(name)
}
func addTrustedCert(certPath string) error {
return execabs.Command( //nolint:gosec
@ -44,10 +66,20 @@ func removeTrustedCert(certPath string) error {
).Run()
}
func (t *TLS) InstallCerts() error {
return addTrustedCert(t.getTLSCertPath())
}
// writeToTempFile writes the given data to a temporary file and returns the path.
func writeToTempFile(data []byte) (string, error) {
f, err := os.CreateTemp("", "tls")
if err != nil {
return "", err
}
func (t *TLS) UninstallCerts() error {
return removeTrustedCert(t.getTLSCertPath())
if _, err := f.Write(data); err != nil {
return "", err
}
if err := f.Close(); err != nil {
return "", err
}
return f.Name(), nil
}

View File

@ -15,12 +15,12 @@
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package tls
package certs
func (t *TLS) InstallCerts() error {
func installCert([]byte) error {
return nil // Linux doesn't have a root cert store.
}
func (t *TLS) UninstallCerts() error {
func uninstallCert([]byte) error {
return nil // Linux doesn't have a root cert store.
}

View File

@ -15,12 +15,12 @@
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package tls
package certs
func (t *TLS) InstallCerts() error {
func installCert([]byte) error {
return nil // NOTE(GODT-986): Install certs to root cert store?
}
func (t *TLS) UninstallCerts() error {
func uninstallCert([]byte) error {
return nil // NOTE(GODT-986): Uninstall certs from root cert store?
}

View File

@ -0,0 +1,15 @@
package certs
type Installer struct{}
func NewInstaller() *Installer {
return &Installer{}
}
func (installer *Installer) InstallCert(certPEM []byte) error {
return installCert(certPEM)
}
func (installer *Installer) UninstallCert(certPEM []byte) error {
return uninstallCert(certPEM)
}

View File

@ -15,9 +15,10 @@
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package tls
package certs
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
@ -27,22 +28,13 @@ import (
"fmt"
"math/big"
"net"
"os"
"path/filepath"
"time"
"github.com/pkg/errors"
)
type TLS struct {
settingsPath string
}
func New(settingsPath string) *TLS {
return &TLS{
settingsPath: settingsPath,
}
}
// ErrTLSCertExpiresSoon is returned when the TLS certificate is about to expire.
var ErrTLSCertExpiresSoon = fmt.Errorf("TLS certificate will expire soon")
// NewTLSTemplate creates a new TLS template certificate with a random serial number.
func NewTLSTemplate() (*x509.Certificate, error) {
@ -69,108 +61,40 @@ func NewTLSTemplate() (*x509.Certificate, error) {
}, nil
}
// NewPEMKeyPair return a new TLS private key and certificate in PEM encoded format.
func NewPEMKeyPair() (pemCert, pemKey []byte, err error) {
template, err := NewTLSTemplate()
if err != nil {
return nil, nil, errors.Wrap(err, "failed to generate TLS template")
}
// GenerateTLSCert generates a new TLS certificate and returns it as PEM.
var GenerateCert = func(template *x509.Certificate) ([]byte, []byte, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to generate private key")
}
pemKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to create certificate")
}
pemCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
certPEM := new(bytes.Buffer)
return pemCert, pemKey, nil
}
var ErrTLSCertExpiresSoon = fmt.Errorf("TLS certificate will expire soon")
// getTLSCertPath returns path to certificate; used for TLS servers (IMAP, SMTP).
func (t *TLS) getTLSCertPath() string {
return filepath.Join(t.settingsPath, "cert.pem")
}
// getTLSKeyPath returns path to private key; used for TLS servers (IMAP, SMTP).
func (t *TLS) getTLSKeyPath() string {
return filepath.Join(t.settingsPath, "key.pem")
}
// HasCerts returns whether TLS certs have been generated.
func (t *TLS) HasCerts() bool {
if _, err := os.Stat(t.getTLSCertPath()); err != nil {
return false
if err := pem.Encode(certPEM, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
return nil, nil, err
}
if _, err := os.Stat(t.getTLSKeyPath()); err != nil {
return false
keyPEM := new(bytes.Buffer)
if err := pem.Encode(keyPEM, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil {
return nil, nil, err
}
return true
}
// GenerateCerts generates certs from the given template.
func (t *TLS) GenerateCerts(template *x509.Certificate) error {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return errors.Wrap(err, "failed to generate private key")
}
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
if err != nil {
return errors.Wrap(err, "failed to create certificate")
}
certOut, err := os.Create(t.getTLSCertPath())
if err != nil {
return err
}
defer certOut.Close() //nolint:errcheck,gosec
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
return err
}
keyOut, err := os.OpenFile(t.getTLSKeyPath(), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
if err != nil {
return err
}
defer keyOut.Close() //nolint:errcheck,gosec
return pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
return certPEM.Bytes(), keyPEM.Bytes(), nil
}
// GetConfig tries to load TLS config or generate new one which is then returned.
func (t *TLS) GetConfig() (*tls.Config, error) {
c, err := tls.LoadX509KeyPair(t.getTLSCertPath(), t.getTLSKeyPath())
func GetConfig(certPEM, keyPEM []byte) (*tls.Config, error) {
c, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, errors.Wrap(err, "failed to load keypair")
}
return getConfigFromKeyPair(c)
}
// GetConfigFromPEMKeyPair load a TLS config from PEM encoded certificate and key.
func GetConfigFromPEMKeyPair(permCert, pemKey []byte) (*tls.Config, error) {
c, err := tls.X509KeyPair(permCert, pemKey)
if err != nil {
return nil, errors.Wrap(err, "failed to load keypair")
}
return getConfigFromKeyPair(c)
}
func getConfigFromKeyPair(c tls.Certificate) (*tls.Config, error) {
var err error
c.Leaf, err = x509.ParseCertificate(c.Certificate[0])
if err != nil {
return nil, errors.Wrap(err, "failed to parse certificate")

View File

@ -15,10 +15,10 @@
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package tls
package certs
import (
"os"
"crypto/tls"
"testing"
"time"
@ -26,12 +26,6 @@ import (
)
func TestGetOldConfig(t *testing.T) {
dir, err := os.MkdirTemp("", "test-tls")
require.NoError(t, err)
// Create new tls object.
tls := New(dir)
// Create new TLS template.
tlsTemplate, err := NewTLSTemplate()
require.NoError(t, err)
@ -41,20 +35,15 @@ func TestGetOldConfig(t *testing.T) {
tlsTemplate.NotAfter = time.Now()
// Generate the certs from the template.
require.NoError(t, tls.GenerateCerts(tlsTemplate))
certPEM, keyPEM, err := GenerateCert(tlsTemplate)
require.NoError(t, err)
// Generate the config from the certs -- it's going to expire soon so we don't want to use it.
_, err = tls.GetConfig()
_, err = GetConfig(certPEM, keyPEM)
require.Equal(t, err, ErrTLSCertExpiresSoon)
}
func TestGetValidConfig(t *testing.T) {
dir, err := os.MkdirTemp("", "test-tls")
require.NoError(t, err)
// Create new tls object.
tls := New(dir)
// Create new TLS template.
tlsTemplate, err := NewTLSTemplate()
require.NoError(t, err)
@ -64,10 +53,11 @@ func TestGetValidConfig(t *testing.T) {
tlsTemplate.NotAfter = time.Now().Add(2 * 365 * 24 * time.Hour)
// Generate the certs from the template.
require.NoError(t, tls.GenerateCerts(tlsTemplate))
certPEM, keyPEM, err := GenerateCert(tlsTemplate)
require.NoError(t, err)
// Generate the config from the certs -- it's not going to expire soon so we want to use it.
config, err := tls.GetConfig()
config, err := GetConfig(certPEM, keyPEM)
require.NoError(t, err)
require.Equal(t, len(config.Certificates), 1)
@ -77,9 +67,13 @@ func TestGetValidConfig(t *testing.T) {
}
func TestNewConfig(t *testing.T) {
pemCert, pemKey, err := NewPEMKeyPair()
tlsTemplate, err := NewTLSTemplate()
require.NoError(t, err)
_, err = GetConfigFromPEMKeyPair(pemCert, pemKey)
pemCert, pemKey, err := GenerateCert(tlsTemplate)
require.NoError(t, err)
cert, err := tls.X509KeyPair(pemCert, pemKey)
require.NoError(t, err)
require.NotNil(t, cert)
}

View File

@ -23,7 +23,7 @@ import (
"strconv"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/config/useragent"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
"github.com/ProtonMail/proton-bridge/v2/pkg/mobileconfig"
"golang.org/x/sys/execabs"
)

View File

@ -1,70 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package cache provides access to contents inside a cache directory.
package cache
import (
"os"
"path/filepath"
"github.com/ProtonMail/proton-bridge/v2/pkg/files"
)
type Cache struct {
dir, version string
}
func New(dir, version string) (*Cache, error) {
if err := os.MkdirAll(filepath.Join(dir, version), 0o700); err != nil {
return nil, err
}
return &Cache{
dir: dir,
version: version,
}, nil
}
// GetDBDir returns folder for db files.
func (c *Cache) GetDBDir() string {
return c.getCurrentCacheDir()
}
// GetDefaultMessageCacheDir returns folder for cached messages files.
func (c *Cache) GetDefaultMessageCacheDir() string {
return filepath.Join(c.getCurrentCacheDir(), "messages")
}
// GetIMAPCachePath returns path to file with IMAP status.
func (c *Cache) GetIMAPCachePath() string {
return filepath.Join(c.getCurrentCacheDir(), "user_info.json")
}
// GetTransferDir returns folder for import-export rules files.
func (c *Cache) GetTransferDir() string {
return c.getCurrentCacheDir()
}
// RemoveOldVersions removes any cache dirs that are not the current version.
func (c *Cache) RemoveOldVersions() error {
return files.Remove(c.dir).Except(c.getCurrentCacheDir()).Do()
}
func (c *Cache) getCurrentCacheDir() string {
return filepath.Join(c.dir, c.version)
}

View File

@ -1,69 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRemoveOldVersions(t *testing.T) {
dir, err := os.MkdirTemp("", "test-cache")
require.NoError(t, err)
cache, err := New(dir, "c4")
require.NoError(t, err)
createFilesInDir(t, dir,
"unexpected1.txt",
"c1/unexpected1.txt",
"c2/unexpected2.txt",
"c3/unexpected3.txt",
"something.txt",
)
require.DirExists(t, filepath.Join(dir, "c4"))
require.FileExists(t, filepath.Join(dir, "unexpected1.txt"))
require.FileExists(t, filepath.Join(dir, "c1", "unexpected1.txt"))
require.FileExists(t, filepath.Join(dir, "c2", "unexpected2.txt"))
require.FileExists(t, filepath.Join(dir, "c3", "unexpected3.txt"))
require.FileExists(t, filepath.Join(dir, "something.txt"))
assert.NoError(t, cache.RemoveOldVersions())
assert.DirExists(t, filepath.Join(dir, "c4"))
assert.NoFileExists(t, filepath.Join(dir, "unexpected1.txt"))
assert.NoFileExists(t, filepath.Join(dir, "c1", "unexpected1.txt"))
assert.NoFileExists(t, filepath.Join(dir, "c2", "unexpected2.txt"))
assert.NoFileExists(t, filepath.Join(dir, "c3", "unexpected3.txt"))
assert.NoFileExists(t, filepath.Join(dir, "something.txt"))
}
func createFilesInDir(t *testing.T, dir string, files ...string) {
for _, target := range files {
require.NoError(t, os.MkdirAll(filepath.Dir(filepath.Join(dir, target)), 0o700))
f, err := os.Create(filepath.Join(dir, target))
require.NoError(t, err)
require.NoError(t, f.Close())
}
}

View File

@ -1,151 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package settings
import (
"encoding/json"
"errors"
"fmt"
"os"
"strconv"
"sync"
"github.com/sirupsen/logrus"
)
type keyValueStore struct {
vals map[Key]string
path string
lock *sync.RWMutex
}
// newKeyValueStore returns loaded preferences.
func newKeyValueStore(path string) *keyValueStore {
p := &keyValueStore{
path: path,
lock: &sync.RWMutex{},
}
if err := p.load(); err != nil {
logrus.WithError(err).Warn("Cannot load preferences file, creating new one")
}
return p
}
func (p *keyValueStore) load() error {
if p.vals != nil {
return nil
}
p.lock.Lock()
defer p.lock.Unlock()
p.vals = make(map[Key]string)
f, err := os.Open(p.path)
if err != nil {
return err
}
defer f.Close() //nolint:errcheck,gosec
return json.NewDecoder(f).Decode(&p.vals)
}
func (p *keyValueStore) save() error {
if p.vals == nil {
return errors.New("cannot save preferences: cache is nil")
}
p.lock.Lock()
defer p.lock.Unlock()
b, err := json.MarshalIndent(p.vals, "", "\t")
if err != nil {
return err
}
return os.WriteFile(p.path, b, 0o600)
}
func (p *keyValueStore) setDefault(key Key, value string) {
if p.Get(key) == "" {
p.Set(key, value)
}
}
func (p *keyValueStore) Get(key Key) string {
p.lock.RLock()
defer p.lock.RUnlock()
return p.vals[key]
}
func (p *keyValueStore) GetBool(key Key) bool {
return p.Get(key) == "true"
}
func (p *keyValueStore) GetInt(key Key) int {
if p.Get(key) == "" {
return 0
}
value, err := strconv.Atoi(p.Get(key))
if err != nil {
logrus.WithError(err).Error("Cannot parse int")
}
return value
}
func (p *keyValueStore) GetFloat64(key Key) float64 {
if p.Get(key) == "" {
return 0
}
value, err := strconv.ParseFloat(p.Get(key), 64)
if err != nil {
logrus.WithError(err).Error("Cannot parse float64")
}
return value
}
func (p *keyValueStore) Set(key Key, value string) {
p.lock.Lock()
p.vals[key] = value
p.lock.Unlock()
if err := p.save(); err != nil {
logrus.WithError(err).Warn("Cannot save preferences")
}
}
func (p *keyValueStore) SetBool(key Key, value bool) {
if value {
p.Set(key, "true")
} else {
p.Set(key, "false")
}
}
func (p *keyValueStore) SetInt(key Key, value int) {
p.Set(key, strconv.Itoa(value))
}
func (p *keyValueStore) SetFloat64(key Key, value float64) {
p.Set(key, fmt.Sprintf("%v", value))
}

View File

@ -1,141 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package settings
import (
"os"
"testing"
"github.com/stretchr/testify/require"
)
func TestLoadNoKeyValueStore(t *testing.T) {
r := require.New(t)
pref, clean := newTestEmptyKeyValueStore(r)
defer clean()
r.Equal("", pref.Get("key"))
}
func TestLoadBadKeyValueStore(t *testing.T) {
r := require.New(t)
path, clean := newTmpFile(r)
defer clean()
r.NoError(os.WriteFile(path, []byte("{\"key\":\"MISSING_QUOTES"), 0o700))
pref := newKeyValueStore(path)
r.Equal("", pref.Get("key"))
}
func TestKeyValueStor(t *testing.T) {
r := require.New(t)
pref, clean := newTestKeyValueStore(r)
defer clean()
r.Equal("value", pref.Get("str"))
r.Equal("42", pref.Get("int"))
r.Equal("true", pref.Get("bool"))
r.Equal("t", pref.Get("falseBool"))
}
func TestKeyValueStoreGetInt(t *testing.T) {
r := require.New(t)
pref, clean := newTestKeyValueStore(r)
defer clean()
r.Equal(0, pref.GetInt("str"))
r.Equal(42, pref.GetInt("int"))
r.Equal(0, pref.GetInt("bool"))
r.Equal(0, pref.GetInt("falseBool"))
}
func TestKeyValueStoreGetBool(t *testing.T) {
r := require.New(t)
pref, clean := newTestKeyValueStore(r)
defer clean()
r.Equal(false, pref.GetBool("str"))
r.Equal(false, pref.GetBool("int"))
r.Equal(true, pref.GetBool("bool"))
r.Equal(false, pref.GetBool("falseBool"))
}
func TestKeyValueStoreSetDefault(t *testing.T) {
r := require.New(t)
pref, clean := newTestEmptyKeyValueStore(r)
defer clean()
pref.setDefault("key", "value")
pref.setDefault("key", "othervalue")
r.Equal("value", pref.Get("key"))
}
func TestKeyValueStoreSet(t *testing.T) {
r := require.New(t)
pref, clean := newTestEmptyKeyValueStore(r)
defer clean()
pref.Set("str", "value")
checkSavedKeyValueStore(r, pref.path, "{\n\t\"str\": \"value\"\n}")
}
func TestKeyValueStoreSetInt(t *testing.T) {
r := require.New(t)
pref, clean := newTestEmptyKeyValueStore(r)
defer clean()
pref.SetInt("int", 42)
checkSavedKeyValueStore(r, pref.path, "{\n\t\"int\": \"42\"\n}")
}
func TestKeyValueStoreSetBool(t *testing.T) {
r := require.New(t)
pref, clean := newTestEmptyKeyValueStore(r)
defer clean()
pref.SetBool("trueBool", true)
pref.SetBool("falseBool", false)
checkSavedKeyValueStore(r, pref.path, "{\n\t\"falseBool\": \"false\",\n\t\"trueBool\": \"true\"\n}")
}
func newTmpFile(r *require.Assertions) (path string, clean func()) {
tmpfile, err := os.CreateTemp("", "pref.*.json")
r.NoError(err)
defer r.NoError(tmpfile.Close())
return tmpfile.Name(), func() {
r.NoError(os.Remove(tmpfile.Name()))
}
}
func newTestEmptyKeyValueStore(r *require.Assertions) (*keyValueStore, func()) {
path, clean := newTmpFile(r)
return newKeyValueStore(path), clean
}
func newTestKeyValueStore(r *require.Assertions) (*keyValueStore, func()) {
path, clean := newTmpFile(r)
r.NoError(os.WriteFile(path, []byte("{\"str\":\"value\",\"int\":\"42\",\"bool\":\"true\",\"falseBool\":\"t\"}"), 0o700))
return newKeyValueStore(path), clean
}
func checkSavedKeyValueStore(r *require.Assertions, path, expected string) {
data, err := os.ReadFile(path)
r.NoError(err)
r.Equal(expected, string(data))
}

View File

@ -1,116 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package settings provides access to persistent user settings.
package settings
import (
"fmt"
"math/rand"
"path/filepath"
"time"
)
type Key string
// Keys of preferences in JSON file.
const (
FirstStartKey Key = "first_time_start"
FirstStartGUIKey Key = "first_time_start_gui"
LastHeartbeatKey Key = "last_heartbeat"
APIPortKey Key = "user_port_api"
IMAPPortKey Key = "user_port_imap"
SMTPPortKey Key = "user_port_smtp"
SMTPSSLKey Key = "user_ssl_smtp"
AllowProxyKey Key = "allow_proxy"
AutostartKey Key = "autostart"
AutoUpdateKey Key = "autoupdate"
CookiesKey Key = "cookies"
LastVersionKey Key = "last_used_version"
UpdateChannelKey Key = "update_channel"
RolloutKey Key = "rollout"
PreferredKeychainKey Key = "preferred_keychain"
CacheEnabledKey Key = "cache_enabled"
CacheCompressionKey Key = "cache_compression"
CacheLocationKey Key = "cache_location"
CacheMinFreeAbsKey Key = "cache_min_free_abs"
CacheMinFreeRatKey Key = "cache_min_free_rat"
CacheConcurrencyRead Key = "cache_concurrent_read"
CacheConcurrencyWrite Key = "cache_concurrent_write"
IMAPWorkers Key = "imap_workers"
FetchWorkers Key = "fetch_workers"
AttachmentWorkers Key = "attachment_workers"
ColorScheme Key = "color_scheme"
RebrandingMigrationKey Key = "rebranding_migrated"
IsAllMailVisible Key = "is_all_mail_visible"
)
type Settings struct {
*keyValueStore
settingsPath string
}
func New(settingsPath string) *Settings {
s := &Settings{
keyValueStore: newKeyValueStore(filepath.Join(settingsPath, "prefs.json")),
settingsPath: settingsPath,
}
s.setDefaultValues()
return s
}
const (
DefaultIMAPPort = "1143"
DefaultSMTPPort = "1025"
DefaultAPIPort = "1042"
)
func (s *Settings) setDefaultValues() {
s.setDefault(FirstStartKey, "true")
s.setDefault(FirstStartGUIKey, "true")
s.setDefault(LastHeartbeatKey, fmt.Sprintf("%v", time.Now().YearDay()))
s.setDefault(AllowProxyKey, "true")
s.setDefault(AutostartKey, "true")
s.setDefault(AutoUpdateKey, "true")
s.setDefault(LastVersionKey, "")
s.setDefault(UpdateChannelKey, "")
s.setDefault(RolloutKey, fmt.Sprintf("%v", rand.Float64())) //nolint:gosec // G404 It is OK to use weak random number generator here
s.setDefault(PreferredKeychainKey, "")
s.setDefault(CacheEnabledKey, "true")
s.setDefault(CacheCompressionKey, "true")
s.setDefault(CacheLocationKey, "")
s.setDefault(CacheMinFreeAbsKey, "250000000")
s.setDefault(CacheMinFreeRatKey, "")
s.setDefault(CacheConcurrencyRead, "16")
s.setDefault(CacheConcurrencyWrite, "16")
s.setDefault(IMAPWorkers, "16")
s.setDefault(FetchWorkers, "16")
s.setDefault(AttachmentWorkers, "16")
s.setDefault(ColorScheme, "")
s.setDefault(APIPortKey, DefaultAPIPort)
s.setDefault(IMAPPortKey, DefaultIMAPPort)
s.setDefault(SMTPPortKey, DefaultSMTPPort)
// By default, stick to STARTTLS. If the user uses catalina+applemail they'll have to change to SSL.
s.setDefault(SMTPSSLKey, "false")
s.setDefault(IsAllMailVisible, "true")
}

View File

@ -18,17 +18,35 @@
// Package constants contains variables that are set via ldflags during build.
package constants
import "fmt"
import (
"fmt"
"runtime"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
const VendorName = "protonmail"
//nolint:gochecknoglobals
var (
// Version of the build.
// Full app name (to show to the user).
FullAppName = ""
// ConfigName determines the name of the location where bridge stores config files.
ConfigName = "bridge"
// UpdateName is the name of the product appearing in the update URL.
UpdateName = "bridge"
// KeyChainName is the name of the entry in the OS keychain.
KeyChainName = "bridge"
// Version of the build.
Version = ""
Version = "2.3.0+git"
// AppVersion is the full rendered version of the app (to be used in request headers).
AppVersion = getAPIOS() + cases.Title(language.Und).String(ConfigName) + "_" + Version
// Revision is current hash of the build.
Revision = ""
@ -36,9 +54,31 @@ var (
// BuildTime stamp of the build.
BuildTime = ""
// BuildVersion is derived from LongVersion and BuildTime.
BuildVersion = fmt.Sprintf("%v (%v) %v", Version, Revision, BuildTime)
// DSNSentry client keys to be able to report crashes to Sentry.
DSNSentry = ""
// BuildVersion is derived from LongVersion and BuildTime.
BuildVersion = fmt.Sprintf("%v (%v) %v", Version, Revision, BuildTime)
// APIHost is our API address.
APIHost = "https://api.protonmail.ch"
// The host name of the bridge server.
Host = "127.0.0.1"
)
func getAPIOS() string {
switch runtime.GOOS {
case "darwin":
return "macOS"
case "linux":
return "Linux"
case "windows":
return "Windows"
default:
return "Linux"
}
}

View File

@ -22,8 +22,5 @@ package constants
import "time"
//nolint:gochecknoglobals
var (
// UpdateCheckInterval defines how often we check for new version.
UpdateCheckInterval = time.Hour //nolint:gochecknoglobals
)
// UpdateCheckInterval defines how often we check for new version.
const UpdateCheckInterval = time.Hour

View File

@ -22,8 +22,5 @@ package constants
import "time"
//nolint:gochecknoglobals
var (
// UpdateCheckInterval defines how often we check for new version
UpdateCheckInterval = time.Duration(5 * time.Minute)
)
// UpdateCheckInterval defines how often we check for new version
const UpdateCheckInterval = time.Duration(5 * time.Minute)

View File

@ -26,28 +26,31 @@ import (
"net/url"
"sync"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
)
type cookiesByHost map[string][]*http.Cookie
type Persister interface {
GetCookies() ([]byte, error)
SetCookies([]byte) error
}
// Jar implements http.CookieJar by wrapping the standard library's cookiejar.Jar.
// The jar uses a pantry to load cookies at startup and save cookies when set.
type Jar struct {
jar *cookiejar.Jar
settings *settings.Settings
cookies cookiesByHost
locker sync.Locker
jar *cookiejar.Jar
persister Persister
cookies cookiesByHost
locker sync.Locker
}
func NewCookieJar(s *settings.Settings) (*Jar, error) {
func NewCookieJar(persister Persister) (*Jar, error) {
jar, err := cookiejar.New(nil)
if err != nil {
return nil, err
}
cookiesByHost, err := loadCookies(s)
cookiesByHost, err := loadCookies(persister)
if err != nil {
return nil, err
}
@ -62,10 +65,10 @@ func NewCookieJar(s *settings.Settings) (*Jar, error) {
}
return &Jar{
jar: jar,
settings: s,
cookies: cookiesByHost,
locker: &sync.Mutex{},
jar: jar,
persister: persister,
cookies: cookiesByHost,
locker: &sync.Mutex{},
}, nil
}
@ -101,16 +104,17 @@ func (j *Jar) PersistCookies() error {
return err
}
j.settings.Set(settings.CookiesKey, string(rawCookies))
return nil
return j.persister.SetCookies(rawCookies)
}
// loadCookies loads all non-expired cookies from disk.
func loadCookies(s *settings.Settings) (cookiesByHost, error) {
rawCookies := s.Get(settings.CookiesKey)
func loadCookies(persister Persister) (cookiesByHost, error) {
rawCookies, err := persister.GetCookies()
if err != nil {
return nil, err
}
if rawCookies == "" {
if len(rawCookies) == 0 {
return make(cookiesByHost), nil
}

View File

@ -18,13 +18,15 @@
package cookies
import (
"errors"
"io/fs"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -37,7 +39,7 @@ func TestJarGetSet(t *testing.T) {
})
defer ts.Close()
client, _ := getClientWithJar(t, newFakeSettings())
client, _ := getClientWithJar(t, newTestPersister(t))
// Hit a server that sets some cookies.
setRes, err := client.Get(ts.URL + "/set")
@ -63,7 +65,7 @@ func TestJarLoad(t *testing.T) {
defer ts.Close()
// This will be our "persistent storage" from which the cookie jar should load cookies.
s := newFakeSettings()
s := newTestPersister(t)
// This client saves cookies to persistent storage.
oldClient, jar := getClientWithJar(t, s)
@ -98,7 +100,7 @@ func TestJarExpiry(t *testing.T) {
defer ts.Close()
// This will be our "persistent storage" from which the cookie jar should load cookies.
s := newFakeSettings()
s := newTestPersister(t)
// This client saves cookies to persistent storage.
oldClient, jar1 := getClientWithJar(t, s)
@ -122,9 +124,12 @@ func TestJarExpiry(t *testing.T) {
// Save the cookies (expired ones were cleared out).
require.NoError(t, jar2.PersistCookies())
assert.Contains(t, s.Get(settings.CookiesKey), "TestName1")
assert.NotContains(t, s.Get(settings.CookiesKey), "TestName2")
assert.Contains(t, s.Get(settings.CookiesKey), "TestName3")
cookies, err := s.GetCookies()
require.NoError(t, err)
assert.Contains(t, string(cookies), "TestName1")
assert.NotContains(t, string(cookies), "TestName2")
assert.Contains(t, string(cookies), "TestName3")
}
type testCookie struct {
@ -132,8 +137,8 @@ type testCookie struct {
maxAge int
}
func getClientWithJar(t *testing.T, s *settings.Settings) (*http.Client, *Jar) {
jar, err := NewCookieJar(s)
func getClientWithJar(t *testing.T, persister Persister) (*http.Client, *Jar) {
jar, err := NewCookieJar(persister)
require.NoError(t, err)
return &http.Client{Jar: jar}, jar
@ -168,12 +173,26 @@ func getTestServer(t *testing.T, wantCookies []testCookie) *httptest.Server {
return httptest.NewServer(mux)
}
// newFakeSettings creates a temporary folder for files.
func newFakeSettings() *settings.Settings {
dir, err := os.MkdirTemp("", "test-settings")
if err != nil {
panic(err)
type testPersister struct {
path string
}
func newTestPersister(tb testing.TB) *testPersister {
path := filepath.Join(tb.TempDir(), "cookies.json")
if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) {
if err := os.WriteFile(path, []byte{}, 0600); err != nil {
panic(err)
}
}
return settings.New(dir)
return &testPersister{path: path}
}
func (p *testPersister) GetCookies() ([]byte, error) {
return os.ReadFile(p.path)
}
func (p *testPersister) SetCookies(rawCookies []byte) error {
return os.WriteFile(p.path, rawCookies, 0600)
}

View File

@ -41,14 +41,11 @@ func (h *Handler) AddRecoveryAction(action RecoveryAction) *Handler {
func (h *Handler) HandlePanic() {
sentry.SkipDuringUnwind()
r := recover()
if r == nil {
return
}
for _, action := range h.actions {
if err := action(r); err != nil {
logrus.WithError(err).Error("Failed to execute recovery action")
if r := recover(); r != nil {
for _, action := range h.actions {
if err := action(r); err != nil {
logrus.WithError(err).Error("Failed to execute recovery action")
}
}
}
}

View File

@ -0,0 +1,77 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
type TLSDialer interface {
DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error)
}
// CreateTransportWithDialer creates an http.Transport that uses the given dialer to make TLS connections.
func CreateTransportWithDialer(dialer TLSDialer) *http.Transport {
return &http.Transport{
DialTLSContext: dialer.DialTLSContext,
Proxy: http.ProxyFromEnvironment,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 5 * time.Minute,
ExpectContinueTimeout: 500 * time.Millisecond,
// GODT-126: this was initially 10s but logs from users showed a significant number
// were hitting this timeout, possibly due to flaky wifi taking >10s to reconnect.
// Bumping to 30s for now to avoid this problem.
ResponseHeaderTimeout: 30 * time.Second,
// If we allow up to 30 seconds for response headers, it is reasonable to allow up
// to 30 seconds for the TLS handshake to take place.
TLSHandshakeTimeout: 30 * time.Second,
}
}
// BasicTLSDialer implements TLSDialer.
type BasicTLSDialer struct {
hostURL string
}
// NewBasicTLSDialer returns a new BasicTLSDialer.
func NewBasicTLSDialer(hostURL string) *BasicTLSDialer {
return &BasicTLSDialer{
hostURL: hostURL,
}
}
// DialTLS returns a connection to the given address using the given network.
func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
return (&tls.Dialer{
NetDialer: &net.Dialer{
Timeout: 30 * time.Second,
},
Config: &tls.Config{
InsecureSkipVerify: address != d.hostURL,
},
}).DialContext(ctx, network, address)
}

View File

@ -0,0 +1,114 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"context"
"crypto/tls"
"net"
)
// TrustedAPIPins contains trusted public keys of the protonmail API and proxies.
// NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;).
var TrustedAPIPins = []string{ //nolint:gochecknoglobals
// api.protonmail.ch
`pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current
`pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot backup
`pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold backup
// protonmail.com
// \todo remove when sure no one is using it.
`pin-sha256="8joiNBdqaYiQpKskgtkJsqRxF7zN0C0aqfi8DacknnI="`, // current
`pin-sha256="JMI8yrbc6jB1FYGyyWRLFTmDNgIszrNEMGlgy972e7w="`, // hot backup
`pin-sha256="Iu44zU84EOCZ9vx/vz67/MRVrxF1IO4i4NIa8ETwiIY="`, // cold backup
// proton.me
`pin-sha256="CT56BhOTmj5ZIPgb/xD5mH8rY3BLo/MlhP7oPyJUEDo="`, // current
`pin-sha256="35Dx28/uzN3LeltkCBQ8RHK0tlNSa2kCpCRGNp34Gxc="`, // hot backup
`pin-sha256="qYIukVc63DEITct8sFT7ebIq5qsWmuscaIKeJx+5J5A="`, // col backup
// proxies
`pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // main
`pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // backup 1
`pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // backup 2
`pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // backup 3
}
// TLSReportURI is the address where TLS reports should be sent.
const TLSReportURI = "https://reports.protonmail.ch/reports/tls"
// PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and
// to report errors if the fingerprint check fails.
type PinningTLSDialer struct {
dialer TLSDialer
pinChecker PinChecker
reporter Reporter
tlsIssueCh chan struct{}
}
// Reporter is used to report TLS issues.
type Reporter interface {
ReportCertIssue(reportURI, host, port string, state tls.ConnectionState)
}
// PinChecker is used to check TLS keys of connections.
type PinChecker interface {
CheckCertificate(conn net.Conn) error
}
// NewPinningTLSDialer constructs a new dialer which only returns TCP connections to servers
// which present known certificates.
// It checks pins using the given pinChecker and reports issues using the given reporter.
func NewPinningTLSDialer(dialer TLSDialer, reporter Reporter, pinChecker PinChecker) *PinningTLSDialer {
return &PinningTLSDialer{
dialer: dialer,
pinChecker: pinChecker,
reporter: reporter,
tlsIssueCh: make(chan struct{}, 1),
}
}
// DialTLS dials the given network/address, returning an error if the certificates don't match the trusted pins.
func (p *PinningTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := p.dialer.DialTLSContext(ctx, network, address)
if err != nil {
return nil, err
}
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
if err := p.pinChecker.CheckCertificate(conn); err != nil {
if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil {
p.reporter.ReportCertIssue(TLSReportURI, host, port, tlsConn.ConnectionState())
}
p.tlsIssueCh <- struct{}{}
return nil, err
}
return conn, nil
}
// GetTLSIssueCh returns a channel which notifies when a TLS issue is reported.
func (p *PinningTLSDialer) GetTLSIssueCh() <-chan struct{} {
return p.tlsIssueCh
}

View File

@ -0,0 +1,67 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"github.com/ProtonMail/proton-bridge/v2/pkg/algo"
)
// ErrTLSMismatch indicates that no TLS fingerprint match could be found.
var ErrTLSMismatch = errors.New("no TLS fingerprint match found")
type TLSPinChecker struct {
trustedPins []string
}
func NewTLSPinChecker(trustedPins []string) *TLSPinChecker {
return &TLSPinChecker{
trustedPins: trustedPins,
}
}
// checkCertificate returns whether the connection presents a known TLS certificate.
func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error {
tlsConn, ok := conn.(*tls.Conn)
if !ok {
return errors.New("connection is not a TLS connection")
}
connState := tlsConn.ConnectionState()
for _, peerCert := range connState.PeerCertificates {
fingerprint := certFingerprint(peerCert)
for _, pin := range p.trustedPins {
if pin == fingerprint {
return nil
}
}
}
return ErrTLSMismatch
}
func certFingerprint(cert *x509.Certificate) string {
return fmt.Sprintf(`pin-sha256=%q`, algo.HashBase64SHA256(string(cert.RawSubjectPublicKeyInfo)))
}

View File

@ -0,0 +1,118 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"fmt"
"time"
"github.com/go-resty/resty/v2"
)
// tlsReport is inspired by https://tools.ietf.org/html/rfc7469#section-3.
// When a TLS key mismatch is detected, a tlsReport is posted to TLSReportURI.
type tlsReport struct {
// DateTime of observed pin validation in time.RFC3339 format.
DateTime string `json:"date-time"`
// Hostname to which the UA made original request that failed pin validation.
Hostname string `json:"hostname"`
// Port to which the UA made original request that failed pin validation.
Port string `json:"port"`
// EffectiveExpirationDate for noted pins in time.RFC3339 format.
EffectiveExpirationDate string `json:"effective-expiration-date"`
// IncludeSubdomains indicates whether or not the UA has noted the
// includeSubDomains directive for the Known Pinned Host.
IncludeSubdomains bool `json:"include-subdomains"`
// NotedHostname indicates the hostname that the UA noted when it noted
// the Known Pinned Host. This field allows operators to understand why
// Pin Validation was performed for, e.g., foo.example.com when the
// noted Known Pinned Host was example.com with includeSubDomains set.
NotedHostname string `json:"noted-hostname"`
// ServedCertificateChain is the certificate chain, as served by
// the Known Pinned Host during TLS session setup. It is provided as an
// array of strings; each string pem1, ... pemN is the Privacy-Enhanced
// Mail (PEM) representation of each X.509 certificate as described in
// [RFC7468].
ServedCertificateChain []string `json:"served-certificate-chain"`
// ValidatedCertificateChain is the certificate chain, as
// constructed by the UA during certificate chain verification. (This
// may differ from the served-certificate-chain.) It is provided as an
// array of strings; each string pem1, ... pemN is the PEM
// representation of each X.509 certificate as described in [RFC7468].
// UAs that build certificate chains in more than one way during the
// validation process SHOULD send the last chain built. In this way,
// they can avoid keeping too much state during the validation process.
ValidatedCertificateChain []string `json:"validated-certificate-chain"`
// The known-pins are the Pins that the UA has noted for the Known
// Pinned Host. They are provided as an array of strings with the
// syntax: known-pin = token "=" quoted-string
// e.g.:
// ```
// "known-pins": [
// 'pin-sha256="d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM="',
// "pin-sha256=\"E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=\""
// ]
// ```
KnownPins []string `json:"known-pins"`
// AppVersion is used to set `x-pm-appversion` json format from datatheorem/TrustKit.
AppVersion string `json:"app-version"`
}
// newTLSReport constructs a new tlsReport configured with the given app version and known pinned public keys.
// Temporal things (current date/time) are not set yet -- they are set when sendReport is called.
func newTLSReport(host, port, server string, certChain, knownPins []string, appVersion string) (report tlsReport) {
report = tlsReport{
Hostname: host,
Port: port,
NotedHostname: server,
ServedCertificateChain: certChain,
KnownPins: knownPins,
AppVersion: appVersion,
}
return
}
// sendReport posts the given TLS report to the standard TLS Report URI.
func sendReport(report tlsReport, userAgent, appVersion, hostURL, remoteURI string) error {
now := time.Now()
report.DateTime = now.Format(time.RFC3339)
report.EffectiveExpirationDate = now.Add(365 * 24 * time.Hour).Format(time.RFC3339)
if _, err := resty.New().
SetTransport(CreateTransportWithDialer(NewBasicTLSDialer(hostURL))).
SetHeader("User-Agent", userAgent).
SetHeader("x-pm-appversion", appVersion).
NewRequest().
SetBody(report).
Post(remoteURI); err != nil {
return fmt.Errorf("failed to send TLS report: %w", err)
}
return nil
}

View File

@ -0,0 +1,115 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
"github.com/google/go-cmp/cmp"
"github.com/sirupsen/logrus"
)
type sentReport struct {
r tlsReport
t time.Time
}
type TLSReporter struct {
hostURL string
appVersion string
userAgent *useragent.UserAgent
trustedPins []string
sentReports []sentReport
}
func NewTLSReporter(hostURL, appVersion string, userAgent *useragent.UserAgent, trustedPins []string) *TLSReporter {
return &TLSReporter{
hostURL: hostURL,
appVersion: appVersion,
userAgent: userAgent,
trustedPins: trustedPins,
}
}
// reportCertIssue reports a TLS key mismatch.
func (r *TLSReporter) ReportCertIssue(remoteURI, host, port string, connState tls.ConnectionState) {
var certChain []string
if len(connState.VerifiedChains) > 0 {
certChain = marshalCert7468(connState.VerifiedChains[len(connState.VerifiedChains)-1])
} else {
certChain = marshalCert7468(connState.PeerCertificates)
}
report := newTLSReport(host, port, connState.ServerName, certChain, r.trustedPins, r.appVersion)
if !r.hasRecentlySentReport(report) {
r.recordReport(report)
if err := sendReport(report, r.userAgent.GetUserAgent(), r.appVersion, r.hostURL, remoteURI); err != nil {
logrus.WithError(err).Error("Failed to send TLS pinning report")
}
}
}
// hasRecentlySentReport returns whether the report was already sent within the last 24 hours.
func (r *TLSReporter) hasRecentlySentReport(report tlsReport) bool {
var validReports []sentReport
for _, r := range r.sentReports {
if time.Since(r.t) < 24*time.Hour {
validReports = append(validReports, r)
}
}
r.sentReports = validReports
for _, r := range r.sentReports {
if cmp.Equal(report, r.r) {
return true
}
}
return false
}
// recordReport records the given report and the current time so we can check whether we recently sent this report.
func (r *TLSReporter) recordReport(report tlsReport) {
r.sentReports = append(r.sentReports, sentReport{r: report, t: time.Now()})
}
func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
var buffer bytes.Buffer
for _, cert := range certs {
if err := pem.Encode(&buffer, &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}); err != nil {
logrus.WithError(err).Error("Failed to encode TLS certificate")
}
pemCerts = append(pemCerts, buffer.String())
buffer.Reset()
}
return pemCerts
}

View File

@ -0,0 +1,59 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
"github.com/stretchr/testify/assert"
)
func TestTLSReporter_DoubleReport(t *testing.T) {
reportCounter := 0
reportServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reportCounter++
}))
r := NewTLSReporter("hostURL", "appVersion", useragent.New(), TrustedAPIPins)
// Report the same issue many times.
for i := 0; i < 10; i++ {
r.ReportCertIssue(reportServer.URL, "myhost", "443", tls.ConnectionState{})
}
// We should only report once.
assert.Eventually(t, func() bool {
return reportCounter == 1
}, time.Second, time.Millisecond)
// If we then report something else many times.
for i := 0; i < 10; i++ {
r.ReportCertIssue(reportServer.URL, "anotherhost", "443", tls.ConnectionState{})
}
// We should get a second report.
assert.Eventually(t, func() bool {
return reportCounter == 2
}, time.Second, time.Millisecond)
}

View File

@ -0,0 +1,157 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"context"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
a "github.com/stretchr/testify/assert"
r "github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server"
)
func getRootURL() string {
return "https://api.protonmail.ch"
}
func TestTLSPinValid(t *testing.T) {
called, _, _, _, cm := createClientWithPinningDialer(getRootURL())
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "password")
checkTLSIssueHandler(t, 0, called)
}
func TestTLSPinBackup(t *testing.T) {
called, _, _, checker, cm := createClientWithPinningDialer(getRootURL())
copyTrustedPins(checker)
checker.trustedPins[1] = checker.trustedPins[0]
checker.trustedPins[0] = ""
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "password")
checkTLSIssueHandler(t, 0, called)
}
func TestTLSPinInvalid(t *testing.T) {
s := server.NewTLS()
defer s.Close()
called, _, _, _, cm := createClientWithPinningDialer(s.GetHostURL())
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "password")
checkTLSIssueHandler(t, 1, called)
}
func TestTLSPinNoMatch(t *testing.T) {
skipIfProxyIsSet(t)
called, _, reporter, checker, cm := createClientWithPinningDialer(getRootURL())
copyTrustedPins(checker)
for i := 0; i < len(checker.trustedPins); i++ {
checker.trustedPins[i] = "testing"
}
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "password")
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "password")
// Check that it will be reported only once per session, but notified every time.
r.Equal(t, 1, len(reporter.sentReports))
checkTLSIssueHandler(t, 2, called)
}
func TestTLSSignedCertWrongPublicKey(t *testing.T) {
skipIfProxyIsSet(t)
_, dialer, _, _, _ := createClientWithPinningDialer("")
_, err := dialer.DialTLSContext(context.Background(), "tcp", "rsa4096.badssl.com:443")
r.Error(t, err, "expected dial to fail because of wrong public key")
}
func TestTLSSignedCertTrustedPublicKey(t *testing.T) {
skipIfProxyIsSet(t)
_, dialer, _, checker, _ := createClientWithPinningDialer("")
copyTrustedPins(checker)
checker.trustedPins = append(checker.trustedPins, `pin-sha256="LwnIKjNLV3z243ap8y0yXNPghsqE76J08Eq3COvUt2E="`)
_, err := dialer.DialTLSContext(context.Background(), "tcp", "rsa4096.badssl.com:443")
r.NoError(t, err, "expected dial to succeed because public key is known and cert is signed by CA")
}
func TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) {
skipIfProxyIsSet(t)
_, dialer, _, checker, _ := createClientWithPinningDialer("")
copyTrustedPins(checker)
checker.trustedPins = append(checker.trustedPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`)
_, err := dialer.DialTLSContext(context.Background(), "tcp", "self-signed.badssl.com:443")
r.NoError(t, err, "expected dial to succeed because public key is known despite cert being self-signed")
}
func createClientWithPinningDialer(hostURL string) (*int, *PinningTLSDialer, *TLSReporter, *TLSPinChecker, *liteapi.Manager) {
called := 0
reporter := NewTLSReporter(hostURL, "appVersion", useragent.New(), TrustedAPIPins)
checker := NewTLSPinChecker(TrustedAPIPins)
dialer := NewPinningTLSDialer(NewBasicTLSDialer(hostURL), reporter, checker)
go func() {
for range dialer.GetTLSIssueCh() {
called++
}
}()
return &called, dialer, reporter, checker, liteapi.New(
liteapi.WithHostURL(hostURL),
liteapi.WithTransport(CreateTransportWithDialer(dialer)),
)
}
func copyTrustedPins(pinChecker *TLSPinChecker) {
copiedPins := make([]string, len(pinChecker.trustedPins))
copy(copiedPins, pinChecker.trustedPins)
pinChecker.trustedPins = copiedPins
}
func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast int, called *int) {
// TLSIssueHandler is called in goroutine se we need to wait a bit to be sure it was called.
a.Eventually(
t,
func() bool {
if wantCalledAtLeast == 0 {
return *called == 0
}
// Dialer can do more attempts resulting in more calls.
return *called >= wantCalledAtLeast
},
time.Second,
10*time.Millisecond,
)
// Repeated again so it generates nice message.
if wantCalledAtLeast == 0 {
r.Equal(t, 0, *called)
} else {
r.GreaterOrEqual(t, *called, wantCalledAtLeast)
}
}

View File

@ -0,0 +1,152 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"context"
"net"
"net/url"
"sync"
"time"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
var ErrNoConnection = errors.New("no connection")
// ProxyTLSDialer wraps a TLSDialer to switch to a proxy if the initial dial fails.
type ProxyTLSDialer struct {
dialer TLSDialer
locker sync.RWMutex
directAddress string
proxyAddress string
allowProxy bool
proxyProvider *proxyProvider
proxyUseDuration time.Duration
}
// NewProxyTLSDialer constructs a dialer which provides a proxy-managing layer on top of an underlying dialer.
func NewProxyTLSDialer(dialer TLSDialer, hostURL string) *ProxyTLSDialer {
return &ProxyTLSDialer{
dialer: dialer,
locker: sync.RWMutex{},
directAddress: formatAsAddress(hostURL),
proxyAddress: formatAsAddress(hostURL),
proxyProvider: newProxyProvider(dialer, hostURL, DoHProviders),
proxyUseDuration: proxyUseDuration,
}
}
// formatAsAddress returns URL as `host:port` for easy comparison in DialTLS.
func formatAsAddress(rawURL string) string {
url, err := url.Parse(rawURL)
if err != nil {
// This means wrong configuration.
// Developer should get feedback right away.
panic(err)
}
host := url.Host
if host == "" {
host = url.Path
}
port := "443"
if url.Scheme == "http" {
port = "80"
}
return net.JoinHostPort(host, port)
}
// DialTLS dials the given network/address. If it fails, it retries using a proxy.
func (d *ProxyTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
if address == d.directAddress {
address = d.proxyAddress
}
conn, err := d.dialer.DialTLSContext(ctx, network, address)
if err == nil || !d.allowProxy {
return conn, err
}
if err := d.switchToReachableServer(); err != nil {
return nil, err
}
return d.dialer.DialTLSContext(ctx, network, d.proxyAddress)
}
// switchToReachableServer switches to using a reachable server (either proxy or standard API).
func (d *ProxyTLSDialer) switchToReachableServer() error {
d.locker.Lock()
defer d.locker.Unlock()
logrus.Info("Attempting to switch to a proxy")
proxy, err := d.proxyProvider.findReachableServer()
if err != nil {
return errors.Wrap(err, "failed to find a usable proxy")
}
proxyAddress := formatAsAddress(proxy)
// If the chosen proxy is the standard API, we want to use it but still show the troubleshooting screen.
if proxyAddress == d.directAddress {
logrus.Info("The standard API is reachable again; connection drop was only intermittent")
d.proxyAddress = proxyAddress
return ErrNoConnection
}
logrus.WithField("proxy", proxyAddress).Info("Switching to a proxy")
// If the host is currently the rootURL, it's the first time we are enabling a proxy.
// This means we want to disable it again in 24 hours.
if d.proxyAddress == d.directAddress {
go func() {
<-time.After(d.proxyUseDuration)
d.locker.Lock()
defer d.locker.Unlock()
d.proxyAddress = d.directAddress
}()
}
d.proxyAddress = proxyAddress
return nil
}
// AllowProxy allows the dialer to switch to a proxy if need be.
func (d *ProxyTLSDialer) AllowProxy() {
d.locker.Lock()
defer d.locker.Unlock()
d.allowProxy = true
}
// DisallowProxy prevents the dialer from switching to a proxy if need be.
func (d *ProxyTLSDialer) DisallowProxy() {
d.locker.Lock()
defer d.locker.Unlock()
d.allowProxy = false
d.proxyAddress = d.directAddress
}

View File

@ -0,0 +1,256 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"context"
"encoding/base64"
"strings"
"sync"
"time"
"github.com/go-resty/resty/v2"
"github.com/miekg/dns"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
const (
proxyUseDuration = 24 * time.Hour
proxyLookupWait = 5 * time.Second
proxyCacheRefreshTimeout = 20 * time.Second
proxyDoHTimeout = 20 * time.Second
proxyCanReachTimeout = 20 * time.Second
proxyQuery = "dMFYGSLTQOJXXI33ONVQWS3BOMNUA.protonpro.xyz"
Quad9Provider = "https://dns11.quad9.net/dns-query"
Quad9PortProvider = "https://dns11.quad9.net:5053/dns-query"
GoogleProvider = "https://dns.google/dns-query"
)
var DoHProviders = []string{ //nolint:gochecknoglobals
Quad9Provider,
Quad9PortProvider,
GoogleProvider,
}
// proxyProvider manages known proxies.
type proxyProvider struct {
dialer TLSDialer
hostURL string
// dohLookup is used to look up the given query at the given DoH provider, returning the TXT records>
dohLookup func(ctx context.Context, query, provider string) (urls []string, err error)
providers []string // List of known doh providers.
query string // The query string used to find proxies.
proxyCache []string // All known proxies, cached in case DoH providers are unreachable.
cacheRefreshTimeout time.Duration
dohTimeout time.Duration
canReachTimeout time.Duration
lastLookup time.Time // The time at which we last attempted to find a proxy.
}
// newProxyProvider creates a new proxyProvider that queries the given DoH providers
// to retrieve DNS records for the given query string.
func newProxyProvider(dialer TLSDialer, hostURL string, providers []string) (p *proxyProvider) {
p = &proxyProvider{
dialer: dialer,
hostURL: hostURL,
providers: providers,
query: proxyQuery,
cacheRefreshTimeout: proxyCacheRefreshTimeout,
dohTimeout: proxyDoHTimeout,
canReachTimeout: proxyCanReachTimeout,
}
// Use the default DNS lookup method; this can be overridden if necessary.
p.dohLookup = p.defaultDoHLookup
return
}
// findReachableServer returns a working API server (either proxy or standard API).
func (p *proxyProvider) findReachableServer() (proxy string, err error) {
logrus.Debug("Trying to find a reachable server")
if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) {
return "", errors.New("not looking for a proxy, too soon")
}
p.lastLookup = time.Now()
// We use a waitgroup to wait for both
// a) the check whether the API is reachable, and
// b) the DoH queries.
// This is because the Alternative Routes v2 spec says:
// Call the GET /test/ping route on normal API domain (same time as DoH requests and wait until all have finished)
var wg sync.WaitGroup
var apiReachable bool
wg.Add(2)
go func() {
defer wg.Done()
apiReachable = p.canReach(p.hostURL)
}()
go func() {
defer wg.Done()
err = p.refreshProxyCache()
}()
wg.Wait()
if apiReachable {
proxy = p.hostURL
return
}
if err != nil {
return
}
for _, url := range p.proxyCache {
if p.canReach(url) {
proxy = url
return
}
}
return "", errors.New("no reachable server could be found")
}
// refreshProxyCache loads the latest proxies from the known providers.
// If the process takes longer than proxyCacheRefreshTimeout, an error is returned.
func (p *proxyProvider) refreshProxyCache() error {
logrus.Info("Refreshing proxy cache")
ctx, cancel := context.WithTimeout(context.Background(), p.cacheRefreshTimeout)
defer cancel()
resultChan := make(chan []string)
go func() {
for _, provider := range p.providers {
if proxies, err := p.dohLookup(ctx, p.query, provider); err == nil {
resultChan <- proxies
return
}
}
// If no dohLoopkup worked, cancel right after it's done to not
// block refreshing for the whole cacheRefreshTimeout.
cancel()
}()
select {
case result := <-resultChan:
p.proxyCache = result
return nil
case <-ctx.Done():
return errors.New("timed out while refreshing proxy cache")
}
}
// canReach returns whether we can reach the given url.
func (p *proxyProvider) canReach(url string) bool {
logrus.WithField("url", url).Debug("Trying to ping proxy")
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") {
url = "https://" + url
}
pinger := resty.New().
SetBaseURL(url).
SetTimeout(p.canReachTimeout).
SetTransport(CreateTransportWithDialer(p.dialer))
if _, err := pinger.R().Get("/tests/ping"); err != nil {
logrus.WithField("proxy", url).WithError(err).Warn("Failed to ping proxy")
return false
}
return true
}
// defaultDoHLookup is the default implementation of the proxy manager's DoH lookup.
// It looks up DNS TXT records for the given query URL using the given DoH provider.
// It returns a list of all found TXT records.
// If the whole process takes more than proxyDoHTimeout then an error is returned.
func (p *proxyProvider) defaultDoHLookup(ctx context.Context, query, dohProvider string) (data []string, err error) {
ctx, cancel := context.WithTimeout(ctx, p.dohTimeout)
defer cancel()
dataChan, errChan := make(chan []string), make(chan error)
go func() {
// Build new DNS request in RFC1035 format.
dnsRequest := new(dns.Msg).SetQuestion(dns.Fqdn(query), dns.TypeTXT)
// Pack the DNS request message into wire format.
rawRequest, err := dnsRequest.Pack()
if err != nil {
errChan <- errors.Wrap(err, "failed to pack DNS request")
return
}
// Encode wire-format DNS request message as base64url (RFC4648) without padding chars.
encodedRequest := base64.RawURLEncoding.EncodeToString(rawRequest)
// Make DoH request to the given DoH provider.
rawResponse, err := resty.New().R().SetContext(ctx).SetQueryParam("dns", encodedRequest).Get(dohProvider)
if err != nil {
errChan <- errors.Wrap(err, "failed to make DoH request")
return
}
// Unpack the DNS response.
dnsResponse := new(dns.Msg)
if err = dnsResponse.Unpack(rawResponse.Body()); err != nil {
errChan <- errors.Wrap(err, "failed to unpack DNS response")
return
}
// Pick out the TXT answers.
for _, answer := range dnsResponse.Answer {
if t, ok := answer.(*dns.TXT); ok {
data = append(data, t.Txt...)
}
}
dataChan <- data
}()
select {
case data = <-dataChan:
logrus.WithField("data", data).Info("Received TXT records")
return
case err = <-errChan:
logrus.WithField("provider", dohProvider).WithError(err).Error("Failed to query DNS records")
return
case <-ctx.Done():
logrus.WithField("provider", dohProvider).Error("Timed out querying DNS records")
return []string{}, errors.New("timed out querying DNS records")
}
}

View File

@ -0,0 +1,191 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"context"
"net/http"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
r "github.com/stretchr/testify/require"
)
func TestProxyProvider_FindProxy(t *testing.T) {
proxy := getTrustedServer()
defer closeServer(proxy)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy.URL}, nil }
url, err := p.findReachableServer()
r.NoError(t, err)
r.Equal(t, proxy.URL, url)
}
func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) {
reachableProxy := getTrustedServer()
defer closeServer(reachableProxy)
// We actually close the unreachable proxy straight away rather than deferring the closure.
unreachableProxy := getTrustedServer()
closeServer(unreachableProxy)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{reachableProxy.URL, unreachableProxy.URL}, nil
}
url, err := p.findReachableServer()
r.NoError(t, err)
r.Equal(t, reachableProxy.URL, url)
}
func TestProxyProvider_FindProxy_ChooseTrustedProxy(t *testing.T) {
trustedProxy := getTrustedServer()
defer closeServer(trustedProxy)
untrustedProxy := getUntrustedServer()
defer closeServer(untrustedProxy)
reporter := NewTLSReporter("", "appVersion", useragent.New(), TrustedAPIPins)
checker := NewTLSPinChecker(TrustedAPIPins)
dialer := NewPinningTLSDialer(NewBasicTLSDialer(""), reporter, checker)
p := newProxyProvider(dialer, "", []string{"not used"})
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{untrustedProxy.URL, trustedProxy.URL}, nil
}
url, err := p.findReachableServer()
r.NoError(t, err)
r.Equal(t, trustedProxy.URL, url)
}
func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) {
unreachableProxy1 := getTrustedServer()
closeServer(unreachableProxy1)
unreachableProxy2 := getTrustedServer()
closeServer(unreachableProxy2)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{unreachableProxy1.URL, unreachableProxy2.URL}, nil
}
_, err := p.findReachableServer()
r.Error(t, err)
}
func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) {
untrustedProxy1 := getUntrustedServer()
defer closeServer(untrustedProxy1)
untrustedProxy2 := getUntrustedServer()
defer closeServer(untrustedProxy2)
reporter := NewTLSReporter("", "appVersion", useragent.New(), TrustedAPIPins)
checker := NewTLSPinChecker(TrustedAPIPins)
dialer := NewPinningTLSDialer(NewBasicTLSDialer(""), reporter, checker)
p := newProxyProvider(dialer, "", []string{"not used"})
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{untrustedProxy1.URL, untrustedProxy2.URL}, nil
}
_, err := p.findReachableServer()
r.Error(t, err)
}
func TestProxyProvider_FindProxy_RefreshCacheTimeout(t *testing.T) {
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
p.cacheRefreshTimeout = 1 * time.Second
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil }
// We should fail to refresh the proxy cache because the doh provider
// takes 2 seconds to respond but we timeout after just 1 second.
_, err := p.findReachableServer()
r.Error(t, err)
}
func TestProxyProvider_FindProxy_CanReachTimeout(t *testing.T) {
slowProxy := getTrustedServerWithHandler(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
time.Sleep(2 * time.Second)
}))
defer closeServer(slowProxy)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
p.canReachTimeout = 1 * time.Second
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{slowProxy.URL}, nil }
// We should fail to reach the returned proxy because it takes 2 seconds
// to reach it and we only allow 1.
_, err := p.findReachableServer()
r.Error(t, err)
}
func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
records, err := p.dohLookup(context.Background(), proxyQuery, Quad9Provider)
r.NoError(t, err)
r.NotEmpty(t, records)
}
// DISABLEDTestProxyProvider_DoHLookup_Quad9Port cannot run on CI due to custom
// port filter. Basic functionality should be covered by other tests. Keeping
// code here to be able to run it locally if needed.
func DISABLEDTestProxyProviderDoHLookupQuad9Port(t *testing.T) {
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
records, err := p.dohLookup(context.Background(), proxyQuery, Quad9PortProvider)
r.NoError(t, err)
r.NotEmpty(t, records)
}
func TestProxyProvider_DoHLookup_Google(t *testing.T) {
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
records, err := p.dohLookup(context.Background(), proxyQuery, GoogleProvider)
r.NoError(t, err)
r.NotEmpty(t, records)
}
func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) {
skipIfProxyIsSet(t)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
url, err := p.findReachableServer()
r.NoError(t, err)
r.NotEmpty(t, url)
}
func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) {
skipIfProxyIsSet(t)
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"https://unreachable", Quad9Provider, GoogleProvider})
url, err := p.findReachableServer()
r.NoError(t, err)
r.NotEmpty(t, url)
}

View File

@ -0,0 +1,273 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// getTrustedServer returns a server and sets its public key as one of the pinned ones.
func getTrustedServer() *httptest.Server {
return getTrustedServerWithHandler(
http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
// Do nothing.
}),
)
}
func getTrustedServerWithHandler(handler http.HandlerFunc) *httptest.Server {
proxy := httptest.NewTLSServer(handler)
pin := certFingerprint(proxy.Certificate())
TrustedAPIPins = append(TrustedAPIPins, pin)
return proxy
}
const servercrt = `
-----BEGIN CERTIFICATE-----
MIIE5TCCA82gAwIBAgIJAKsmhcMFGfGcMA0GCSqGSIb3DQEBCwUAMIGsMQswCQYD
VQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzARBgNVBAcMClJhbmRvbUNp
dHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEfMB0GA1UECwwWUmFuZG9t
T3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYRaGVsbG9AZXhhbXBsZS5j
b20xEjAQBgNVBAMMCTEyNy4wLjAuMTAeFw0yMDA0MjQxMzI3MzdaFw0yMTA5MDYx
MzI3MzdaMIGsMQswCQYDVQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzAR
BgNVBAcMClJhbmRvbUNpdHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEf
MB0GA1UECwwWUmFuZG9tT3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYR
aGVsbG9AZXhhbXBsZS5jb20xEjAQBgNVBAMMCTEyNy4wLjAuMTCCASIwDQYJKoZI
hvcNAQEBBQADggEPADCCAQoCggEBANAnYyqhosWwNzGjBwSwmDUINOaPs4TSTgKt
r6CE01atxAWzWUCyYqnQ4fPe5q2tx5t/VrmnTNpzycammKJszGLlmj9DFxSiYVw2
pTTK3DBWFkfTwxq98mM7wMnCWy1T2L2pmuYjnd7Pa6pQa9OHYoJwRzlIl2Q3YVdM
GIBDbkW728A1dcelkIdFpv3r3ayTZv01vU8JMXd4PLHwXU0x0hHlH52+kx+9Ndru
rdqqV6LqVfNlSR1jFZkwLBBqvh3XrJRD9Q01EAX6m+ufZ0yq8mK9ifMRtwQet10c
kKMnx63MwvxDFmqrBj4HMtIRUpK+LBDs1ke7DvS0eLqaojWl28ECAwEAAaOCAQYw
ggECMIHLBgNVHSMEgcMwgcChgbKkga8wgawxCzAJBgNVBAYTAlVTMRQwEgYDVQQI
DAtSYW5kb21TdGF0ZTETMBEGA1UEBwwKUmFuZG9tQ2l0eTEbMBkGA1UECgwSUmFu
ZG9tT3JnYW5pemF0aW9uMR8wHQYDVQQLDBZSYW5kb21Pcmdhbml6YXRpb25Vbml0
MSAwHgYJKoZIhvcNAQkBFhFoZWxsb0BleGFtcGxlLmNvbTESMBAGA1UEAwwJMTI3
LjAuMC4xggkAvCxbs152YckwCQYDVR0TBAIwADALBgNVHQ8EBAMCBPAwGgYDVR0R
BBMwEYIJMTI3LjAuMC4xhwR/AAABMA0GCSqGSIb3DQEBCwUAA4IBAQAC7ZycZMZ5
L+cjIpwSj0cemLkVD+kcFUCkI7ket5gbX1PmavmnpuFl9Sru0eJ5wyJ+97MQElPA
CNFgXoX7DbJWkcd/LSksvZoJnpc1sTqFKMWFmOUxmUD62lCacuhqE27ZTThQ/53P
3doLa74rKzUqlPI8OL4R34FY2deL7t5l2KSnpf7CKNeF5bkinAsn6NBqyZs2KPmg
yT1/POdlRewzGSqBTMdktNQ4vKSfdFjcfVeo8PSHBgbGXZ5KoHZ6R6DNJehEh27l
z3OteROLGoii+w3OllLq6JATif2MDIbH0s/KjGjbXSSGbM/rZu5eBZm5/vksGAzc
u53wgIhCJGuX
-----END CERTIFICATE-----
`
const serverkey = `
-----BEGIN PRIVATE KEY-----
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDQJ2MqoaLFsDcx
owcEsJg1CDTmj7OE0k4Cra+ghNNWrcQFs1lAsmKp0OHz3uatrcebf1a5p0zac8nG
ppiibMxi5Zo/QxcUomFcNqU0ytwwVhZH08MavfJjO8DJwlstU9i9qZrmI53ez2uq
UGvTh2KCcEc5SJdkN2FXTBiAQ25Fu9vANXXHpZCHRab9692sk2b9Nb1PCTF3eDyx
8F1NMdIR5R+dvpMfvTXa7q3aqlei6lXzZUkdYxWZMCwQar4d16yUQ/UNNRAF+pvr
n2dMqvJivYnzEbcEHrddHJCjJ8etzML8QxZqqwY+BzLSEVKSviwQ7NZHuw70tHi6
mqI1pdvBAgMBAAECggEAOqqPOYm63arPs462QK0hCPlaJ41i1FGNqRWYxU4KXoi1
EcI9qo1cX24+8MPnEhZDhuD56XNsprkxqmpz5Htzk4AQ3DmlfKxTcnD4WQu/yWPJ
/c6CU7wrX6qMqJC9r+XM1Y/C15A8Q3sEZkkqSsECk67fdBawjI9LQRZyZVwb7U0F
qtvbKM7VQA6hrgdSmXWJ+spp5yymVFF22Ssz31SSbCI93bnp3mukRCKWdRmA9pmT
VXa0HzJ5p70WC+Se9nA/1riWGKt4HCmjVeEtZuiwaUTlXDSeYpu2e4QrX1OnUXBu
Z7yfviTqA8o7KfiA6urumFbAMJcibxkWJoWacc5tTQKBgQD39ZdtNz8B6XJy7f5h
bo9Ag9OrkVX+HITQyWKpcCDba9SuIX3/F++2AK4oeJ3aHKMJWiP19hQvGS1xE67X
TKejOsQxORn6nAYQpFd3AOBOtKAC+VQITBqlfq2ukGmvcQ1O31hMOFbZagFA5cpU
LYb9VVDsZzhM7CccIn/EGEZjgwKBgQDW51rUA2S9naV/iEGhw1tuhoQ5OADD/n8f
pPIkbGxmACDaX/7jt+UwlDU0EsI+aBlJUDqGiEZ5z3UPmaSJUdfRCeJEdKIe1GLm
nqF3sF6Aq+S/79v/wKYn+MHcoiWog5n3McLzZ3+0rwrhMREjE2eWPwVHz/jJIFP3
Pp3+UZVsawKBgB4Az5PdjXgzwS968L7lW9wYl3I5Iciftsp0s8WA1dj3EUMItnA5
ez3wkyI+hgswT+H/0D4gyoxwZXk7Qnq2wcoUgEzcdfJHEszMtfCmYH3liT8S4EIo
w0inLWjj/IXIDi4vBEYkww2HsCMkKvlIkP7yZdpVGxDjuk/DNOaLcWj1AoGAXuyK
PiPRl7/Onmp9MwqrlEJunSeTjv8W/89H9ba+mr9rw4mreMJ9xdtxNLMkgZRRtwRt
FYeUObHdLyradp1kCr2m6D3sblm55cwj3k5VL9i9jdpQ/sMFoZpLZz1oDOs0Uu/0
ALeyvQikcZvOygOEOeVUW8gNSCmzbP6HoxI+QkkCgYBCI6oL4GPcPPqzd+2djbOD
z3rVUyHzYc1KUcBixK/uaRQKM886k4CL8/GvbHHI/yoZ7xWJGnBi59DtpqnGTZJ2
FDJwYIlQKhZmsyVcZu/4smsaejGnHn/liksVlgesSwCtOrsd2AC8fBXSyrTWJx8o
vwRMog6lPhlRhHh/FZ43Cg==
-----END PRIVATE KEY-----
`
// getUntrustedServer returns a server but it doesn't add its public key to the list of pinned ones.
func getUntrustedServer() *httptest.Server {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
cert, err := tls.X509KeyPair([]byte(servercrt), []byte(serverkey))
if err != nil {
panic(err)
}
server.TLS = &tls.Config{Certificates: []tls.Certificate{cert}}
server.StartTLS()
return server
}
// closeServer closes the given server. If it is a trusted server, its cert is removed from the trusted public keys.
func closeServer(server *httptest.Server) {
pin := certFingerprint(server.Certificate())
for i := range TrustedAPIPins {
if TrustedAPIPins[i] == pin {
TrustedAPIPins = append(TrustedAPIPins[:i], TrustedAPIPins[i:]...)
break
}
}
server.Close()
}
func TestProxyDialer_UseProxy(t *testing.T) {
trustedProxy := getTrustedServer()
defer closeServer(trustedProxy)
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
d.proxyProvider = provider
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
}
func TestProxyDialer_UseProxy_MultipleTimes(t *testing.T) {
proxy1 := getTrustedServer()
defer closeServer(proxy1)
proxy2 := getTrustedServer()
defer closeServer(proxy2)
proxy3 := getTrustedServer()
defer closeServer(proxy3)
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
d.proxyProvider = provider
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(proxy1.URL), d.proxyAddress)
// Have to wait so as to not get rejected.
time.Sleep(proxyLookupWait)
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy2.URL}, nil }
err = d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(proxy2.URL), d.proxyAddress)
// Have to wait so as to not get rejected.
time.Sleep(proxyLookupWait)
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy3.URL}, nil }
err = d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(proxy3.URL), d.proxyAddress)
}
func TestProxyDialer_UseProxy_RevertAfterTime(t *testing.T) {
trustedProxy := getTrustedServer()
defer closeServer(trustedProxy)
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
d.proxyProvider = provider
d.proxyUseDuration = time.Second
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
time.Sleep(2 * time.Second)
require.Equal(t, ":443", d.proxyAddress)
}
func TestProxyDialer_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {
trustedProxy := getTrustedServer()
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
d.proxyProvider = provider
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
// Simulate that the proxy stops working and that the standard api is reachable again.
closeServer(trustedProxy)
d.directAddress = formatAsAddress(getRootURL())
provider.hostURL = getRootURL()
time.Sleep(proxyLookupWait)
// We should now find the original API URL if it is working again.
// The error should be ErrAPINotReachable because the connection dropped intermittently but
// the original API is now reachable (see Alternative-Routing-v2 spec for details).
err = d.switchToReachableServer()
require.Error(t, err)
require.Equal(t, formatAsAddress(getRootURL()), d.proxyAddress)
}
func TestProxyDialer_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) {
// proxy1 is closed later in this test so we don't defer it here.
proxy1 := getTrustedServer()
proxy2 := getTrustedServer()
defer closeServer(proxy2)
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
d.proxyProvider = provider
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(proxy1.URL), d.proxyAddress)
// Have to wait so as to not get rejected.
time.Sleep(proxyLookupWait)
// The proxy stops working and the protonmail API is still blocked.
closeServer(proxy1)
// Should switch to the second proxy because both the first proxy and the protonmail API are blocked.
err = d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(proxy2.URL), d.proxyAddress)
}
func TestFormatAsAddress(t *testing.T) {
r := require.New(t)
testData := map[string]string{
"sub.domain.tld": "sub.domain.tld:443",
"http://sub.domain.tld": "sub.domain.tld:80",
"https://sub.domain.tld": "sub.domain.tld:443",
"ftp://sub.domain.tld": "sub.domain.tld:443",
"//sub.domain.tld": "sub.domain.tld:443",
}
for rawURL, wantURL := range testData {
r.Equal(wantURL, formatAsAddress(rawURL))
}
}

View File

@ -0,0 +1,16 @@
package dialer
import (
"testing"
"golang.org/x/net/http/httpproxy"
)
// skipIfProxyIsSet skips the tests if HTTPS proxy is set.
// Should be used for tests depending on proper certificate checks which
// is not possible under our CI setup.
func skipIfProxyIsSet(t *testing.T) {
if httpproxy.FromEnvironment().HTTPSProxy != "" {
t.SkipNow()
}
}

View File

@ -0,0 +1,13 @@
package events
import "gitlab.protontech.ch/go/liteapi"
type TLSIssue struct {
eventBase
}
type ConnStatus struct {
eventBase
Status liteapi.Status
}

7
internal/events/error.go Normal file
View File

@ -0,0 +1,7 @@
package events
type Error struct {
eventBase
Error error
}

View File

@ -1,60 +1,9 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package events provides names of events used by the event listener in bridge.
package events
import (
"time"
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
)
// Constants of events used by the event listener in bridge.
const (
ErrorEvent = "error"
CredentialsErrorEvent = "credentialsError"
CloseConnectionEvent = "closeConnection"
LogoutEvent = "logout"
AddressChangedEvent = "addressChanged"
AddressChangedLogoutEvent = "addressChangedLogout"
UserRefreshEvent = "userRefresh"
RestartBridgeEvent = "restartBridge"
InternetConnChangedEvent = "internetChanged"
InternetOff = "internetOff"
InternetOn = "internetOn"
SecondInstanceEvent = "secondInstance"
NoActiveKeyForRecipientEvent = "noActiveKeyForRecipient"
UpgradeApplicationEvent = "upgradeApplication"
TLSCertIssue = "tlsCertPinningIssue"
UserChangeDone = "QMLUserChangedDone"
// LogoutEventTimeout is the minimum time to permit between logout events being sent.
LogoutEventTimeout = 3 * time.Minute
)
// SetupEvents specific to event type and data.
func SetupEvents(listener listener.Listener) {
listener.SetLimit(LogoutEvent, LogoutEventTimeout)
listener.SetBuffer(ErrorEvent)
listener.SetBuffer(CredentialsErrorEvent)
listener.SetBuffer(InternetConnChangedEvent)
listener.SetBuffer(UpgradeApplicationEvent)
listener.SetBuffer(TLSCertIssue)
listener.SetBuffer(UserRefreshEvent)
listener.Book(UserChangeDone)
type Event interface {
_isEvent()
}
type eventBase struct{}
func (eventBase) _isEvent() {}

5
internal/events/raise.go Normal file
View File

@ -0,0 +1,5 @@
package events
type Raise struct {
eventBase
}

24
internal/events/sync.go Normal file
View File

@ -0,0 +1,24 @@
package events
import "time"
type SyncStarted struct {
eventBase
UserID string
}
type SyncProgress struct {
eventBase
UserID string
Progress float64
Elapsed time.Duration
Remaining time.Duration
}
type SyncFinished struct {
eventBase
UserID string
}

25
internal/events/update.go Normal file
View File

@ -0,0 +1,25 @@
package events
import "github.com/ProtonMail/proton-bridge/v2/internal/updater"
type UpdateAvailable struct {
eventBase
Version updater.VersionInfo
CanInstall bool
}
type UpdateNotAvailable struct {
eventBase
}
type UpdateInstalled struct {
eventBase
Version updater.VersionInfo
}
type UpdateForced struct {
eventBase
}

52
internal/events/user.go Normal file
View File

@ -0,0 +1,52 @@
package events
type UserLoggedIn struct {
eventBase
UserID string
}
type UserLoggedOut struct {
eventBase
UserID string
}
type UserDeauth struct {
eventBase
UserID string
}
type UserDeleted struct {
eventBase
UserID string
}
type UserChanged struct {
eventBase
UserID string
}
type UserAddressCreated struct {
eventBase
UserID string
Address string
}
type UserAddressChanged struct {
eventBase
UserID string
Address string
}
type UserAddressDeleted struct {
eventBase
UserID string
Address string
}

32
internal/focus/client.go Normal file
View File

@ -0,0 +1,32 @@
package focus
import (
"context"
"fmt"
"net"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/focus/proto"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/emptypb"
)
func TryRaise() bool {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, net.JoinHostPort(Host, fmt.Sprint(Port)), grpc.WithInsecure())
if err != nil {
return false
}
if _, err := proto.NewFocusClient(cc).Raise(ctx, &emptypb.Empty{}); err != nil {
return false
}
if err := cc.Close(); err != nil {
return false
}
return true
}

View File

@ -0,0 +1,25 @@
package focus
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestFocusRaise(t *testing.T) {
// Start the focus service.
service, err := NewService()
require.NoError(t, err)
// Try to dial it, it should succeed.
require.True(t, TryRaise())
// The service should report a raise call.
<-service.GetRaiseCh()
// Stop the service.
service.Close()
// Try to dial it, it should fail.
require.False(t, TryRaise())
}

View File

@ -0,0 +1,3 @@
package proto
//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative focus.proto

View File

@ -0,0 +1,93 @@
// Copyright (c) 2022 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v3.21.6
// source: focus.proto
package proto
import (
reflect "reflect"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
emptypb "google.golang.org/protobuf/types/known/emptypb"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
var File_focus_proto protoreflect.FileDescriptor
var file_focus_proto_rawDesc = []byte{
0x0a, 0x0b, 0x66, 0x6f, 0x63, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x32, 0x40, 0x0a, 0x05, 0x46, 0x6f, 0x63, 0x75, 0x73, 0x12, 0x37, 0x0a, 0x05, 0x52, 0x61,
0x69, 0x73, 0x65, 0x12, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x16, 0x2e, 0x67, 0x6f,
0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d,
0x70, 0x74, 0x79, 0x42, 0x3d, 0x5a, 0x3b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f,
0x6d, 0x2f, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x6e, 0x4d, 0x61, 0x69, 0x6c, 0x2f, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x6e, 0x2d, 0x62, 0x72, 0x69, 0x64, 0x67, 0x65, 0x2f, 0x76, 0x32, 0x2f, 0x69, 0x6e,
0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x66, 0x6f, 0x63, 0x75, 0x73, 0x2f, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var file_focus_proto_goTypes = []interface{}{
(*emptypb.Empty)(nil), // 0: google.protobuf.Empty
}
var file_focus_proto_depIdxs = []int32{
0, // 0: proto.Focus.Raise:input_type -> google.protobuf.Empty
0, // 1: proto.Focus.Raise:output_type -> google.protobuf.Empty
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_focus_proto_init() }
func file_focus_proto_init() {
if File_focus_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_focus_proto_rawDesc,
NumEnums: 0,
NumMessages: 0,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_focus_proto_goTypes,
DependencyIndexes: file_focus_proto_depIdxs,
}.Build()
File_focus_proto = out.File
file_focus_proto_rawDesc = nil
file_focus_proto_goTypes = nil
file_focus_proto_depIdxs = nil
}

View File

@ -0,0 +1,31 @@
// Copyright (c) 2022 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
syntax = "proto3";
import "google/protobuf/empty.proto";
option go_package = "github.com/ProtonMail/proton-bridge/v2/internal/focus/proto";
package proto;
//**********************************************************************************************************************
// Service Declaration
//**********************************************************************************************************************≠––
service Focus {
rpc Raise(google.protobuf.Empty) returns (google.protobuf.Empty);
}

View File

@ -0,0 +1,107 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.6
// source: focus.proto
package proto
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
emptypb "google.golang.org/protobuf/types/known/emptypb"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// FocusClient is the client API for Focus service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type FocusClient interface {
Raise(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error)
}
type focusClient struct {
cc grpc.ClientConnInterface
}
func NewFocusClient(cc grpc.ClientConnInterface) FocusClient {
return &focusClient{cc}
}
func (c *focusClient) Raise(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error) {
out := new(emptypb.Empty)
err := c.cc.Invoke(ctx, "/proto.Focus/Raise", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// FocusServer is the server API for Focus service.
// All implementations must embed UnimplementedFocusServer
// for forward compatibility
type FocusServer interface {
Raise(context.Context, *emptypb.Empty) (*emptypb.Empty, error)
mustEmbedUnimplementedFocusServer()
}
// UnimplementedFocusServer must be embedded to have forward compatible implementations.
type UnimplementedFocusServer struct {
}
func (UnimplementedFocusServer) Raise(context.Context, *emptypb.Empty) (*emptypb.Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method Raise not implemented")
}
func (UnimplementedFocusServer) mustEmbedUnimplementedFocusServer() {}
// UnsafeFocusServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to FocusServer will
// result in compilation errors.
type UnsafeFocusServer interface {
mustEmbedUnimplementedFocusServer()
}
func RegisterFocusServer(s grpc.ServiceRegistrar, srv FocusServer) {
s.RegisterService(&Focus_ServiceDesc, srv)
}
func _Focus_Raise_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(emptypb.Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(FocusServer).Raise(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/proto.Focus/Raise",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(FocusServer).Raise(ctx, req.(*emptypb.Empty))
}
return interceptor(ctx, in, info, handler)
}
// Focus_ServiceDesc is the grpc.ServiceDesc for Focus service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var Focus_ServiceDesc = grpc.ServiceDesc{
ServiceName: "proto.Focus",
HandlerType: (*FocusServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Raise",
Handler: _Focus_Raise_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "focus.proto",
}

60
internal/focus/service.go Normal file
View File

@ -0,0 +1,60 @@
package focus
import (
"context"
"fmt"
"net"
"github.com/ProtonMail/proton-bridge/v2/internal/focus/proto"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/emptypb"
)
const (
Host = "127.0.0.1"
Port = 1042
)
type FocusService struct {
proto.UnimplementedFocusServer
server *grpc.Server
listener net.Listener
raiseCh chan struct{}
}
func NewService() (*FocusService, error) {
listener, err := net.Listen("tcp", net.JoinHostPort(Host, fmt.Sprint(Port)))
if err != nil {
return nil, fmt.Errorf("failed to listen: %w", err)
}
service := &FocusService{
server: grpc.NewServer(),
listener: listener,
raiseCh: make(chan struct{}, 1),
}
proto.RegisterFocusServer(service.server, service)
go func() {
if err := service.server.Serve(listener); err != nil {
fmt.Printf("failed to serve: %v", err)
}
}()
return service, nil
}
func (service *FocusService) Raise(context.Context, *emptypb.Empty) (*emptypb.Empty, error) {
service.raiseCh <- struct{}{}
return &emptypb.Empty{}, nil
}
func (service *FocusService) GetRaiseCh() <-chan struct{} {
return service.raiseCh
}
func (service *FocusService) Close() {
service.server.Stop()
}

View File

@ -22,7 +22,7 @@ import (
"strconv"
"strings"
"github.com/ProtonMail/proton-bridge/v2/internal/users"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/abiosoft/ishell"
)
@ -35,6 +35,7 @@ func (f *frontendCLI) completeUsernames(args []string) (usernames []string) {
if len(args) == 1 {
arg = args[0]
}
for _, userID := range f.bridge.GetUserIDs() {
user, err := f.bridge.GetUserInfo(userID)
if err != nil {
@ -50,8 +51,7 @@ func (f *frontendCLI) completeUsernames(args []string) (usernames []string) {
// noAccountWrapper is a decorator for functions which need any account to be properly functional.
func (f *frontendCLI) noAccountWrapper(callback func(*ishell.Context)) func(*ishell.Context) {
return func(c *ishell.Context) {
users := f.bridge.GetUserIDs()
if len(users) == 0 {
if len(f.bridge.GetUserIDs()) == 0 {
f.Println("No active accounts. Please add account to continue.")
} else {
callback(c)
@ -59,9 +59,9 @@ func (f *frontendCLI) noAccountWrapper(callback func(*ishell.Context)) func(*ish
}
}
func (f *frontendCLI) askUserByIndexOrName(c *ishell.Context) users.UserInfo {
func (f *frontendCLI) askUserByIndexOrName(c *ishell.Context) bridge.UserInfo {
user := f.getUserByIndexOrName("")
if user.ID != "" {
if user.UserID != "" {
return user
}
@ -69,24 +69,24 @@ func (f *frontendCLI) askUserByIndexOrName(c *ishell.Context) users.UserInfo {
indexRange := fmt.Sprintf("number between 0 and %d", numberOfAccounts-1)
if len(c.Args) == 0 {
f.Printf("Please choose %s or username.\n", indexRange)
return users.UserInfo{}
return bridge.UserInfo{}
}
arg := c.Args[0]
user = f.getUserByIndexOrName(arg)
if user.ID == "" {
if user.UserID == "" {
f.Printf("Wrong input '%s'. Choose %s or username.\n", bold(arg), indexRange)
return users.UserInfo{}
return bridge.UserInfo{}
}
return user
}
func (f *frontendCLI) getUserByIndexOrName(arg string) users.UserInfo {
func (f *frontendCLI) getUserByIndexOrName(arg string) bridge.UserInfo {
userIDs := f.bridge.GetUserIDs()
numberOfAccounts := len(userIDs)
if numberOfAccounts == 0 {
return users.UserInfo{}
return bridge.UserInfo{}
}
res := make([]users.UserInfo, len(userIDs))
res := make([]bridge.UserInfo, len(userIDs))
for idx, userID := range userIDs {
user, err := f.bridge.GetUserInfo(userID)
if err != nil {
@ -99,7 +99,7 @@ func (f *frontendCLI) getUserByIndexOrName(arg string) users.UserInfo {
}
if index, err := strconv.Atoi(arg); err == nil {
if index < 0 || index >= numberOfAccounts {
return users.UserInfo{}
return bridge.UserInfo{}
}
return res[index]
}
@ -108,5 +108,5 @@ func (f *frontendCLI) getUserByIndexOrName(arg string) users.UserInfo {
return user
}
}
return users.UserInfo{}
return bridge.UserInfo{}
}

View File

@ -22,8 +22,7 @@ import (
"strings"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/users"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/abiosoft/ishell"
)
@ -40,7 +39,7 @@ func (f *frontendCLI) listAccounts(c *ishell.Context) {
connected = "connected"
}
mode := "split"
if user.Mode == users.CombinedMode {
if user.AddressMode == bridge.CombinedMode {
mode = "combined"
}
f.Printf(spacing, idx, user.Username, connected, mode)
@ -50,7 +49,7 @@ func (f *frontendCLI) listAccounts(c *ishell.Context) {
func (f *frontendCLI) showAccountInfo(c *ishell.Context) {
user := f.askUserByIndexOrName(c)
if user.ID == "" {
if user.UserID == "" {
return
}
@ -59,8 +58,8 @@ func (f *frontendCLI) showAccountInfo(c *ishell.Context) {
return
}
if user.Mode == users.CombinedMode {
f.showAccountAddressInfo(user, user.Addresses[user.Primary])
if user.AddressMode == bridge.CombinedMode {
f.showAccountAddressInfo(user, user.Addresses[0])
} else {
for _, address := range user.Addresses {
f.showAccountAddressInfo(user, address)
@ -68,25 +67,31 @@ func (f *frontendCLI) showAccountInfo(c *ishell.Context) {
}
}
func (f *frontendCLI) showAccountAddressInfo(user users.UserInfo, address string) {
func (f *frontendCLI) showAccountAddressInfo(user bridge.UserInfo, address string) {
imapSecurity := "STARTTLS"
if f.bridge.GetIMAPSSL() {
imapSecurity = "SSL"
}
smtpSecurity := "STARTTLS"
if f.bridge.GetBool(settings.SMTPSSLKey) {
if f.bridge.GetSMTPSSL() {
smtpSecurity = "SSL"
}
f.Println(bold("Configuration for " + address))
f.Printf("IMAP Settings\nAddress: %s\nIMAP port: %d\nUsername: %s\nPassword: %s\nSecurity: %s\n",
bridge.Host,
f.bridge.GetInt(settings.IMAPPortKey),
constants.Host,
f.bridge.GetIMAPPort(),
address,
user.Password,
"STARTTLS",
user.BridgePass,
imapSecurity,
)
f.Println("")
f.Printf("SMTP Settings\nAddress: %s\nSMTP port: %d\nUsername: %s\nPassword: %s\nSecurity: %s\n",
bridge.Host,
f.bridge.GetInt(settings.SMTPPortKey),
constants.Host,
f.bridge.GetSMTPPort(),
address,
user.Password,
user.BridgePass,
smtpSecurity,
)
f.Println("")
@ -99,8 +104,8 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { //nolint:funlen
loginName := ""
if len(c.Args) > 0 {
user := f.getUserByIndexOrName(c.Args[0])
if user.ID != "" {
loginName = user.Addresses[user.Primary]
if user.UserID != "" {
loginName = user.Addresses[0]
}
}
@ -119,41 +124,23 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { //nolint:funlen
}
f.Println("Authenticating ... ")
client, auth, err := f.bridge.Login(loginName, []byte(password))
userID, err := f.bridge.LoginUser(
context.Background(),
loginName,
password,
func() (string, error) {
return f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty), nil
},
func() ([]byte, error) {
return []byte(f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)), nil
},
)
if err != nil {
f.processAPIError(err)
return
}
if auth.HasTwoFactor() {
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
if twoFactor == "" {
return
}
err = client.Auth2FA(context.Background(), twoFactor)
if err != nil {
f.processAPIError(err)
return
}
}
mailboxPassword := password
if auth.HasMailboxPassword() {
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
}
if mailboxPassword == "" {
return
}
f.Println("Adding account ...")
userID, err := f.bridge.FinishLogin(client, auth, []byte(mailboxPassword))
if err != nil {
log.WithField("username", loginName).WithError(err).Error("Login was unsuccessful")
f.Println("Adding account was unsuccessful:", err)
return
}
user, err := f.bridge.GetUserInfo(userID)
if err != nil {
panic(err)
@ -167,11 +154,12 @@ func (f *frontendCLI) logoutAccount(c *ishell.Context) {
defer f.ShowPrompt(true)
user := f.askUserByIndexOrName(c)
if user.ID == "" {
if user.UserID == "" {
return
}
if f.yesNoQuestion("Are you sure you want to logout account " + bold(user.Username)) {
if err := f.bridge.LogoutUser(user.ID); err != nil {
if err := f.bridge.LogoutUser(context.Background(), user.UserID); err != nil {
f.printAndLogError("Logging out failed: ", err)
}
}
@ -182,12 +170,12 @@ func (f *frontendCLI) deleteAccount(c *ishell.Context) {
defer f.ShowPrompt(true)
user := f.askUserByIndexOrName(c)
if user.ID == "" {
if user.UserID == "" {
return
}
if f.yesNoQuestion("Are you sure you want to " + bold("remove account "+user.Username)) {
clearCache := f.yesNoQuestion("Do you want to remove cache for this account")
if err := f.bridge.DeleteUser(user.ID, clearCache); err != nil {
if err := f.bridge.DeleteUser(context.Background(), user.UserID); err != nil {
f.printAndLogError("Cannot delete account: ", err)
return
}
@ -205,10 +193,13 @@ func (f *frontendCLI) deleteAccounts(c *ishell.Context) {
for _, userID := range f.bridge.GetUserIDs() {
user, err := f.bridge.GetUserInfo(userID)
if err != nil {
panic(err)
f.printAndLogError("Cannot get user info: ", err)
return
}
if err := f.bridge.DeleteUser(user.ID, false); err != nil {
if err := f.bridge.DeleteUser(context.Background(), user.UserID); err != nil {
f.printAndLogError("Cannot delete account ", user.Username, ": ", err)
return
}
}
@ -223,37 +214,50 @@ func (f *frontendCLI) deleteEverything(c *ishell.Context) {
return
}
f.bridge.FactoryReset()
f.bridge.FactoryReset(context.Background())
c.Println("Everything cleared")
// Clearing data removes everything (db, preferences, ...) so everything has to be stopped and started again.
f.restarter.SetToRestart()
f.Stop()
}
func (f *frontendCLI) changeMode(c *ishell.Context) {
user := f.askUserByIndexOrName(c)
if user.ID == "" {
if user.UserID == "" {
return
}
var targetMode users.AddressMode
var targetMode bridge.AddressMode
if user.Mode == users.CombinedMode {
targetMode = users.SplitMode
if user.AddressMode == bridge.CombinedMode {
targetMode = bridge.SplitMode
} else {
targetMode = users.CombinedMode
targetMode = bridge.CombinedMode
}
if !f.yesNoQuestion("Are you sure you want to change the mode for account " + bold(user.Username) + " to " + bold(targetMode)) {
return
}
if err := f.bridge.SetAddressMode(user.ID, targetMode); err != nil {
if err := f.bridge.SetAddressMode(user.UserID, targetMode); err != nil {
f.printAndLogError("Cannot switch address mode:", err)
}
f.Printf("Address mode for account %s changed to %s\n", user.Username, targetMode)
}
func (f *frontendCLI) configureAppleMail(c *ishell.Context) {
user := f.askUserByIndexOrName(c)
if user.UserID == "" {
return
}
if !f.yesNoQuestion("Are you sure you want to configure Apple Mail for " + bold(user.Username) + " with address " + bold(user.Addresses[0])) {
return
}
if err := f.bridge.ConfigureAppleMail(user.UserID, user.Addresses[0]); err != nil {
f.printAndLogError(err)
return
}
f.Printf("Apple Mail configured for %v with address %v\n", user.Username, user.Addresses[0])
}

View File

@ -19,11 +19,13 @@
package cli
import (
"errors"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/types"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"gitlab.protontech.ch/go/liteapi"
"github.com/abiosoft/ishell"
"github.com/sirupsen/logrus"
@ -34,30 +36,14 @@ var log = logrus.WithField("pkg", "frontend/cli") //nolint:gochecknoglobals
type frontendCLI struct {
*ishell.Shell
eventListener listener.Listener
updater types.Updater
bridge types.Bridger
restarter types.Restarter
bridge *bridge.Bridge
}
// New returns a new CLI frontend configured with the given options.
func New( //nolint:funlen
panicHandler types.PanicHandler,
eventListener listener.Listener,
updater types.Updater,
bridge types.Bridger,
restarter types.Restarter,
) *frontendCLI { //nolint:revive
func New(bridge *bridge.Bridge) *frontendCLI {
fe := &frontendCLI{
Shell: ishell.New(),
eventListener: eventListener,
updater: updater,
bridge: bridge,
restarter: restarter,
Shell: ishell.New(),
bridge: bridge,
}
// Clear commands.
@ -66,12 +52,6 @@ func New( //nolint:funlen
Help: "remove stored accounts and preferences. (alias: cl)",
Aliases: []string{"cl"},
}
clearCmd.AddCmd(&ishell.Cmd{
Name: "cache",
Help: "remove stored preferences for accounts (aliases: c, prefs, preferences)",
Aliases: []string{"c", "prefs", "preferences"},
Func: fe.deleteCache,
})
clearCmd.AddCmd(&ishell.Cmd{
Name: "accounts",
Help: "remove all accounts from keychain. (aliases: a, k, keychain)",
@ -100,15 +80,30 @@ func New( //nolint:funlen
Completer: fe.completeUsernames,
})
changeCmd.AddCmd(&ishell.Cmd{
Name: "port",
Help: "change port numbers of IMAP and SMTP servers. (alias: p)",
Aliases: []string{"p"},
Func: fe.changePort,
Name: "change-location",
Help: "change the location of the encrypted message cache",
Func: fe.setGluonLocation,
})
changeCmd.AddCmd(&ishell.Cmd{
Name: "imap-port",
Help: "change port number of IMAP server.",
Func: fe.changeIMAPPort,
})
changeCmd.AddCmd(&ishell.Cmd{
Name: "smtp-port",
Help: "change port number of SMTP server.",
Func: fe.changeSMTPPort,
})
changeCmd.AddCmd(&ishell.Cmd{
Name: "imap-security",
Help: "change IMAP SSL settings servers.(alias: ssl-imap, starttls-imap)",
Aliases: []string{"ssl-imap", "starttls-imap"},
Func: fe.changeIMAPSecurity,
})
changeCmd.AddCmd(&ishell.Cmd{
Name: "smtp-security",
Help: "change port numbers of IMAP and SMTP servers.(alias: ssl, starttls)",
Aliases: []string{"ssl", "starttls"},
Help: "change SMTP SSL settings servers.(alias: ssl-smtp, starttls-smtp)",
Aliases: []string{"ssl-smtp", "starttls-smtp"},
Func: fe.changeSMTPSecurity,
})
fe.AddCmd(changeCmd)
@ -130,6 +125,22 @@ func New( //nolint:funlen
})
fe.AddCmd(dohCmd)
// Apple Mail commands.
configureCmd := &ishell.Cmd{
Name: "configure-apple-mail",
Help: "Configures Apple Mail to use ProtonMail Bridge",
Func: fe.configureAppleMail,
}
fe.AddCmd(configureCmd)
// TLS commands.
exportTLSCmd := &ishell.Cmd{
Name: "export-tls",
Help: "Export the TLS certificate used by the Bridge",
Func: fe.exportTLSCerts,
}
fe.AddCmd(exportTLSCmd)
// All mail visibility commands.
allMailCmd := &ishell.Cmd{
Name: "all-mail-visibility",
@ -147,28 +158,6 @@ func New( //nolint:funlen
})
fe.AddCmd(allMailCmd)
// Cache-On-Disk commands.
codCmd := &ishell.Cmd{
Name: "local-cache",
Help: "manage the local encrypted message cache",
}
codCmd.AddCmd(&ishell.Cmd{
Name: "enable",
Help: "enable the local cache",
Func: fe.enableCacheOnDisk,
})
codCmd.AddCmd(&ishell.Cmd{
Name: "disable",
Help: "disable the local cache",
Func: fe.disableCacheOnDisk,
})
codCmd.AddCmd(&ishell.Cmd{
Name: "change-location",
Help: "change the location of the local cache",
Func: fe.setCacheOnDiskLocation,
})
fe.AddCmd(codCmd)
// Updates commands.
updatesCmd := &ishell.Cmd{
Name: "updates",
@ -224,7 +213,6 @@ func New( //nolint:funlen
Aliases: []string{"man"},
Func: fe.printManual,
})
fe.AddCmd(&ishell.Cmd{
Name: "credits",
Help: "print used resources.",
@ -267,55 +255,122 @@ func New( //nolint:funlen
Completer: fe.completeUsernames,
})
// System commands.
fe.AddCmd(&ishell.Cmd{
Name: "restart",
Help: "restart the bridge.",
Func: fe.restart,
})
go fe.watchEvents()
go func() {
defer panicHandler.HandlePanic()
fe.watchEvents()
}()
return fe
}
func (f *frontendCLI) watchEvents() {
errorCh := f.eventListener.ProvideChannel(events.ErrorEvent)
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
internetConnChangedCh := f.eventListener.ProvideChannel(events.InternetConnChangedEvent)
addressChangedCh := f.eventListener.ProvideChannel(events.AddressChangedEvent)
addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
certIssue := f.eventListener.ProvideChannel(events.TLSCertIssue)
for {
select {
case errorDetails := <-errorCh:
f.Println("Bridge failed:", errorDetails)
case <-credentialsErrorCh:
eventCh, done := f.bridge.GetEvents()
defer done()
// TODO: Better error events.
for _, err := range f.bridge.GetErrors() {
switch {
case errors.Is(err, vault.ErrCorrupt):
f.notifyCredentialsError()
case stat := <-internetConnChangedCh:
if stat == events.InternetOff {
case errors.Is(err, vault.ErrInsecure):
f.notifyCredentialsError()
case errors.Is(err, bridge.ErrServeIMAP):
f.Println("IMAP server error:", err)
case errors.Is(err, bridge.ErrServeSMTP):
f.Println("SMTP server error:", err)
}
}
for event := range eventCh {
switch event := event.(type) {
case events.ConnStatus:
switch event.Status {
case liteapi.StatusUp:
f.notifyInternetOn()
case liteapi.StatusDown:
f.notifyInternetOff()
}
if stat == events.InternetOn {
f.notifyInternetOn()
}
case address := <-addressChangedCh:
f.Printf("Address changed for %s. You may need to reconfigure your email client.", address)
case address := <-addressChangedLogoutCh:
f.notifyLogout(address)
case userID := <-logoutCh:
user, err := f.bridge.GetUserInfo(userID)
case events.UserDeauth:
user, err := f.bridge.GetUserInfo(event.UserID)
if err != nil {
return
}
f.notifyLogout(user.Username)
case <-certIssue:
case events.UserAddressChanged:
user, err := f.bridge.GetUserInfo(event.UserID)
if err != nil {
return
}
f.Printf("Address changed for %s. You may need to reconfigure your email client.\n", user.Username)
case events.UserAddressDeleted:
f.notifyLogout(event.Address)
case events.SyncStarted:
user, err := f.bridge.GetUserInfo(event.UserID)
if err != nil {
return
}
f.Printf("A sync has begun for %s.\n", user.Username)
case events.SyncFinished:
user, err := f.bridge.GetUserInfo(event.UserID)
if err != nil {
return
}
f.Printf("A sync has finished for %s.\n", user.Username)
case events.SyncProgress:
user, err := f.bridge.GetUserInfo(event.UserID)
if err != nil {
return
}
f.Printf(
"Sync (%v): %.1f%% (Elapsed: %0.1fs, ETA: %0.1fs)\n",
user.Username,
100*event.Progress,
event.Elapsed.Seconds(),
event.Remaining.Seconds(),
)
case events.UpdateAvailable:
f.Printf("An update is available (version %v)\n", event.Version.Version)
case events.UpdateForced:
f.notifyNeedUpgrade()
case events.TLSIssue:
f.notifyCertIssue()
}
}
/*
errorCh := f.eventListener.ProvideChannel(events.ErrorEvent)
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
for {
select {
case errorDetails := <-errorCh:
f.Println("Bridge failed:", errorDetails)
case <-credentialsErrorCh:
f.notifyCredentialsError()
case stat := <-internetConnChangedCh:
if stat == events.InternetOff {
f.notifyInternetOff()
}
if stat == events.InternetOn {
f.notifyInternetOn()
}
}
}
*/
}
// Loop starts the frontend loop with an interactive shell.
@ -340,12 +395,3 @@ func (f *frontendCLI) Loop() error {
f.Run()
return nil
}
func (f *frontendCLI) NotifyManualUpdate(update updater.VersionInfo, canInstall bool) {
// NOTE: Save the update somewhere so that it can be installed when user chooses "install now".
}
func (f *frontendCLI) WaitUntilFrontendIsReady() {}
func (f *frontendCLI) SetVersion(version updater.VersionInfo) {}
func (f *frontendCLI) NotifySilentUpdateInstalled() {}
func (f *frontendCLI) NotifySilentUpdateError(err error) {}

View File

@ -18,28 +18,21 @@
package cli
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/pkg/ports"
"github.com/abiosoft/ishell"
)
var currentPort = "" //nolint:gochecknoglobals
func (f *frontendCLI) restart(c *ishell.Context) {
if f.yesNoQuestion("Are you sure you want to restart the Bridge") {
f.Println("Restarting Bridge...")
f.restarter.SetToRestart()
f.Stop()
}
}
func (f *frontendCLI) printLogDir(c *ishell.Context) {
if path, err := f.bridge.ProvideLogsPath(); err != nil {
if path, err := f.bridge.GetLogsPath(); err != nil {
f.Println("Failed to determine location of log files")
} else {
f.Println("Log files are stored in\n\n ", path)
@ -50,79 +43,91 @@ func (f *frontendCLI) printManual(c *ishell.Context) {
f.Println("More instructions about the Bridge can be found at\n\n https://protonmail.com/bridge")
}
func (f *frontendCLI) deleteCache(c *ishell.Context) {
func (f *frontendCLI) printCredits(c *ishell.Context) {
for _, pkg := range strings.Split(bridge.Credits, ";") {
f.Println(pkg)
}
}
func (f *frontendCLI) changeIMAPSecurity(c *ishell.Context) {
f.ShowPrompt(false)
defer f.ShowPrompt(true)
if !f.yesNoQuestion("Do you really want to remove all stored preferences") {
return
newSecurity := "SSL"
if f.bridge.GetIMAPSSL() {
newSecurity = "STARTTLS"
}
if err := f.bridge.ClearData(); err != nil {
f.printAndLogError("Cache clear failed: ", err.Error())
return
msg := fmt.Sprintf("Are you sure you want to change IMAP setting to %q", newSecurity)
if f.yesNoQuestion(msg) {
if err := f.bridge.SetIMAPSSL(!f.bridge.GetIMAPSSL()); err != nil {
f.printAndLogError(err)
return
}
}
f.Println("Cached cleared, restarting bridge")
// Clearing data removes everything (db, preferences, ...) so everything has to be stopped and started again.
f.restarter.SetToRestart()
f.Stop()
}
func (f *frontendCLI) changeSMTPSecurity(c *ishell.Context) {
f.ShowPrompt(false)
defer f.ShowPrompt(true)
isSSL := f.bridge.GetBool(settings.SMTPSSLKey)
newSecurity := "SSL"
if isSSL {
if f.bridge.GetSMTPSSL() {
newSecurity = "STARTTLS"
}
msg := fmt.Sprintf("Are you sure you want to change SMTP setting to %q and restart the Bridge", newSecurity)
msg := fmt.Sprintf("Are you sure you want to change SMTP setting to %q", newSecurity)
if f.yesNoQuestion(msg) {
f.bridge.SetBool(settings.SMTPSSLKey, !isSSL)
f.Println("Restarting Bridge...")
f.restarter.SetToRestart()
f.Stop()
if err := f.bridge.SetSMTPSSL(!f.bridge.GetSMTPSSL()); err != nil {
f.printAndLogError(err)
return
}
}
}
func (f *frontendCLI) changePort(c *ishell.Context) {
func (f *frontendCLI) changeIMAPPort(c *ishell.Context) {
f.ShowPrompt(false)
defer f.ShowPrompt(true)
currentPort = f.bridge.Get(settings.IMAPPortKey)
newIMAPPort := f.readStringInAttempts("Set IMAP port (current "+currentPort+")", c.ReadLine, f.isPortFree)
newIMAPPort := f.readStringInAttempts(fmt.Sprintf("Set IMAP port (current %v)", f.bridge.GetIMAPPort()), c.ReadLine, f.isPortFree)
if newIMAPPort == "" {
newIMAPPort = currentPort
}
imapPortChanged := newIMAPPort != currentPort
currentPort = f.bridge.Get(settings.SMTPPortKey)
newSMTPPort := f.readStringInAttempts("Set SMTP port (current "+currentPort+")", c.ReadLine, f.isPortFree)
if newSMTPPort == "" {
newSMTPPort = currentPort
}
smtpPortChanged := newSMTPPort != currentPort
if newIMAPPort == newSMTPPort {
f.Println("SMTP and IMAP ports must be different!")
f.printAndLogError(errors.New("failed to get new port"))
return
}
if imapPortChanged || smtpPortChanged {
f.Println("Saving values IMAP:", newIMAPPort, "SMTP:", newSMTPPort)
f.bridge.Set(settings.IMAPPortKey, newIMAPPort)
f.bridge.Set(settings.SMTPPortKey, newSMTPPort)
f.Println("Restarting Bridge...")
f.restarter.SetToRestart()
f.Stop()
} else {
f.Println("Nothing changed")
newIMAPPortInt, err := strconv.Atoi(newIMAPPort)
if err != nil {
f.printAndLogError(err)
return
}
if err := f.bridge.SetIMAPPort(newIMAPPortInt); err != nil {
f.printAndLogError(err)
return
}
}
func (f *frontendCLI) changeSMTPPort(c *ishell.Context) {
f.ShowPrompt(false)
defer f.ShowPrompt(true)
newSMTPPort := f.readStringInAttempts(fmt.Sprintf("Set SMTP port (current %v)", f.bridge.GetSMTPPort()), c.ReadLine, f.isPortFree)
if newSMTPPort == "" {
f.printAndLogError(errors.New("failed to get new port"))
return
}
newSMTPPortInt, err := strconv.Atoi(newSMTPPort)
if err != nil {
f.printAndLogError(err)
return
}
if err := f.bridge.SetSMTPPort(newSMTPPortInt); err != nil {
f.printAndLogError(err)
return
}
}
@ -135,7 +140,10 @@ func (f *frontendCLI) allowProxy(c *ishell.Context) {
f.Println("Bridge is currently set to NOT use alternative routing to connect to Proton if it is being blocked.")
if f.yesNoQuestion("Are you sure you want to allow bridge to do this") {
f.bridge.SetProxyAllowed(true)
if err := f.bridge.SetProxyAllowed(true); err != nil {
f.printAndLogError(err)
return
}
}
}
@ -148,12 +156,15 @@ func (f *frontendCLI) disallowProxy(c *ishell.Context) {
f.Println("Bridge is currently set to use alternative routing to connect to Proton if it is being blocked.")
if f.yesNoQuestion("Are you sure you want to stop bridge from doing this") {
f.bridge.SetProxyAllowed(false)
if err := f.bridge.SetProxyAllowed(false); err != nil {
f.printAndLogError(err)
return
}
}
}
func (f *frontendCLI) hideAllMail(c *ishell.Context) {
if !f.bridge.IsAllMailVisible() {
if !f.bridge.GetShowAllMail() {
f.Println("All Mail folder is not listed in your local client.")
return
}
@ -161,12 +172,15 @@ func (f *frontendCLI) hideAllMail(c *ishell.Context) {
f.Println("All Mail folder is listed in your client right now.")
if f.yesNoQuestion("Do you want to hide All Mail folder") {
f.bridge.SetIsAllMailVisible(false)
if err := f.bridge.SetShowAllMail(false); err != nil {
f.printAndLogError(err)
return
}
}
}
func (f *frontendCLI) showAllMail(c *ishell.Context) {
if f.bridge.IsAllMailVisible() {
if f.bridge.GetShowAllMail() {
f.Println("All Mail folder is listed in your local client.")
return
}
@ -174,68 +188,47 @@ func (f *frontendCLI) showAllMail(c *ishell.Context) {
f.Println("All Mail folder is not listed in your client right now.")
if f.yesNoQuestion("Do you want to show All Mail folder") {
f.bridge.SetIsAllMailVisible(true)
if err := f.bridge.SetShowAllMail(true); err != nil {
f.printAndLogError(err)
return
}
}
}
func (f *frontendCLI) enableCacheOnDisk(c *ishell.Context) {
if f.bridge.GetBool(settings.CacheEnabledKey) {
f.Println("The local cache is already enabled.")
return
func (f *frontendCLI) setGluonLocation(c *ishell.Context) {
if gluonDir := f.bridge.GetGluonDir(); gluonDir != "" {
f.Println("The current message cache location is:", gluonDir)
}
if f.yesNoQuestion("Are you sure you want to enable the local cache") {
if err := f.bridge.EnableCache(); err != nil {
f.Println("The local cache could not be enabled.")
if location := f.readStringInAttempts("Enter a new location for the message cache", c.ReadLine, f.isCacheLocationUsable); location != "" {
if err := f.bridge.SetGluonDir(context.Background(), location); err != nil {
f.printAndLogError(err)
return
}
f.restarter.SetToRestart()
f.Stop()
}
}
func (f *frontendCLI) disableCacheOnDisk(c *ishell.Context) {
if !f.bridge.GetBool(settings.CacheEnabledKey) {
f.Println("The local cache is already disabled.")
return
}
func (f *frontendCLI) exportTLSCerts(c *ishell.Context) {
if location := f.readStringInAttempts("Enter a path to which to export the TLS certificate used for IMAP and SMTP", c.ReadLine, f.isCacheLocationUsable); location != "" {
cert, key := f.bridge.GetBridgeTLSCert()
if f.yesNoQuestion("Are you sure you want to disable the local cache") {
if err := f.bridge.DisableCache(); err != nil {
f.Println("The local cache could not be disabled.")
if err := os.WriteFile(filepath.Join(location, "cert.pem"), cert, 0600); err != nil {
f.printAndLogError(err)
return
}
f.restarter.SetToRestart()
f.Stop()
}
}
func (f *frontendCLI) setCacheOnDiskLocation(c *ishell.Context) {
if !f.bridge.GetBool(settings.CacheEnabledKey) {
f.Println("The local cache must be enabled.")
return
}
if location := f.bridge.Get(settings.CacheLocationKey); location != "" {
f.Println("The current local cache location is:", location)
}
if location := f.readStringInAttempts("Enter a new location for the cache", c.ReadLine, f.isCacheLocationUsable); location != "" {
if err := f.bridge.MigrateCache(f.bridge.Get(settings.CacheLocationKey), location); err != nil {
f.Println("The local cache location could not be changed.")
if err := os.WriteFile(filepath.Join(location, "key.pem"), key, 0600); err != nil {
f.printAndLogError(err)
return
}
f.restarter.SetToRestart()
f.Stop()
f.Println("TLS certificate exported to", location)
}
}
func (f *frontendCLI) isPortFree(port string) bool {
port = strings.ReplaceAll(port, ":", "")
if port == "" || port == currentPort {
if port == "" {
return true
}
number, err := strconv.Atoi(port)

View File

@ -18,36 +18,16 @@
package cli
import (
"strings"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/abiosoft/ishell"
)
func (f *frontendCLI) checkUpdates(c *ishell.Context) {
version, err := f.updater.Check()
if err != nil {
f.Println("An error occurred while checking for updates.")
return
}
if f.updater.IsUpdateApplicable(version) {
f.Println("An update is available.")
} else {
f.Println("Your version is up to date.")
}
}
func (f *frontendCLI) printCredits(c *ishell.Context) {
for _, pkg := range strings.Split(bridge.Credits, ";") {
f.Println(pkg)
}
f.bridge.CheckForUpdates()
}
func (f *frontendCLI) enableAutoUpdates(c *ishell.Context) {
if f.bridge.GetBool(settings.AutoUpdateKey) {
if f.bridge.GetAutoUpdate() {
f.Println("Bridge is already set to automatically install updates.")
return
}
@ -55,12 +35,15 @@ func (f *frontendCLI) enableAutoUpdates(c *ishell.Context) {
f.Println("Bridge is currently set to NOT automatically install updates.")
if f.yesNoQuestion("Are you sure you want to allow bridge to do this") {
f.bridge.SetBool(settings.AutoUpdateKey, true)
if err := f.bridge.SetAutoUpdate(true); err != nil {
f.printAndLogError(err)
return
}
}
}
func (f *frontendCLI) disableAutoUpdates(c *ishell.Context) {
if !f.bridge.GetBool(settings.AutoUpdateKey) {
if !f.bridge.GetAutoUpdate() {
f.Println("Bridge is already set to NOT automatically install updates.")
return
}
@ -68,7 +51,10 @@ func (f *frontendCLI) disableAutoUpdates(c *ishell.Context) {
f.Println("Bridge is currently set to automatically install updates.")
if f.yesNoQuestion("Are you sure you want to stop bridge from doing this") {
f.bridge.SetBool(settings.AutoUpdateKey, false)
if err := f.bridge.SetAutoUpdate(false); err != nil {
f.printAndLogError(err)
return
}
}
}
@ -81,7 +67,10 @@ func (f *frontendCLI) selectEarlyChannel(c *ishell.Context) {
f.Println("Bridge is currently on the stable update channel.")
if f.yesNoQuestion("Are you sure you want to switch to the early-access update channel") {
f.bridge.SetUpdateChannel(updater.EarlyChannel)
if err := f.bridge.SetUpdateChannel(updater.EarlyChannel); err != nil {
f.printAndLogError(err)
return
}
}
}
@ -95,6 +84,9 @@ func (f *frontendCLI) selectStableChannel(c *ishell.Context) {
f.Println("Switching to the stable channel may reset all data!")
if f.yesNoQuestion("Are you sure you want to switch to the stable update channel") {
f.bridge.SetUpdateChannel(updater.StableChannel)
if err := f.bridge.SetUpdateChannel(updater.StableChannel); err != nil {
f.printAndLogError(err)
return
}
}
}

View File

@ -20,7 +20,6 @@ package cli
import (
"strings"
pmapi "github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"github.com/fatih/color"
)
@ -67,15 +66,7 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
}
func (f *frontendCLI) processAPIError(err error) {
log.Warn("API error: ", err)
switch err {
case pmapi.ErrNoConnection:
f.notifyInternetOff()
case pmapi.ErrUpgradeApplication:
f.notifyNeedUpgrade()
default:
f.Println("Server error:", err.Error())
}
f.printAndLogError(err)
}
func (f *frontendCLI) notifyInternetOff() {
@ -91,12 +82,7 @@ func (f *frontendCLI) notifyLogout(address string) {
}
func (f *frontendCLI) notifyNeedUpgrade() {
version, err := f.updater.Check()
if err != nil {
log.WithError(err).Error("Failed to notify need upgrade")
return
}
f.Println("Please download and install the newest version of application from", version.LandingPage)
f.Println("Please download and install the newest version of the application.")
}
func (f *frontendCLI) notifyCredentialsError() {

View File

@ -1,75 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package frontend provides all interfaces of the Bridge.
package frontend
import (
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/cli"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/grpc"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/types"
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
)
type Frontend interface {
Loop() error
NotifyManualUpdate(update updater.VersionInfo, canInstall bool)
SetVersion(update updater.VersionInfo)
NotifySilentUpdateInstalled()
NotifySilentUpdateError(error)
WaitUntilFrontendIsReady()
}
// New returns initialized frontend based on `frontendType`, which can be `cli` or `grpc`.
func New(
frontendType string,
showWindowOnStart bool,
panicHandler types.PanicHandler,
eventListener listener.Listener,
updater types.Updater,
bridge *bridge.Bridge,
restarter types.Restarter,
locations *locations.Locations,
) Frontend {
switch frontendType {
case "grpc":
return grpc.NewService(
showWindowOnStart,
panicHandler,
eventListener,
updater,
bridge,
restarter,
locations,
)
case "cli":
return cli.New(
panicHandler,
eventListener,
updater,
bridge,
restarter,
)
default:
return nil
}
}

File diff suppressed because it is too large Load Diff

View File

@ -72,7 +72,6 @@ service Bridge {
rpc IsAutomaticUpdateOn(google.protobuf.Empty) returns (google.protobuf.BoolValue);
// cache
rpc IsCacheOnDiskEnabled (google.protobuf.Empty) returns (google.protobuf.BoolValue);
rpc DiskCachePath(google.protobuf.Empty) returns (google.protobuf.StringValue);
rpc ChangeLocalCache(ChangeLocalCacheRequest) returns (google.protobuf.Empty);
@ -160,7 +159,6 @@ message LoginAbortRequest {
// Cache on disk related message
//**********************************************************
message ChangeLocalCacheRequest {
bool enableDiskCache = 1;
string diskCachePath = 2;
}

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.3
// - protoc v3.21.7
// source: bridge.proto
package grpc
@ -64,7 +64,6 @@ type BridgeClient interface {
SetIsAutomaticUpdateOn(ctx context.Context, in *wrapperspb.BoolValue, opts ...grpc.CallOption) (*emptypb.Empty, error)
IsAutomaticUpdateOn(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*wrapperspb.BoolValue, error)
// cache
IsCacheOnDiskEnabled(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*wrapperspb.BoolValue, error)
DiskCachePath(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*wrapperspb.StringValue, error)
ChangeLocalCache(ctx context.Context, in *ChangeLocalCacheRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
// mail
@ -425,15 +424,6 @@ func (c *bridgeClient) IsAutomaticUpdateOn(ctx context.Context, in *emptypb.Empt
return out, nil
}
func (c *bridgeClient) IsCacheOnDiskEnabled(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*wrapperspb.BoolValue, error) {
out := new(wrapperspb.BoolValue)
err := c.cc.Invoke(ctx, "/grpc.Bridge/IsCacheOnDiskEnabled", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *bridgeClient) DiskCachePath(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*wrapperspb.StringValue, error) {
out := new(wrapperspb.StringValue)
err := c.cc.Invoke(ctx, "/grpc.Bridge/DiskCachePath", in, out, opts...)
@ -699,7 +689,6 @@ type BridgeServer interface {
SetIsAutomaticUpdateOn(context.Context, *wrapperspb.BoolValue) (*emptypb.Empty, error)
IsAutomaticUpdateOn(context.Context, *emptypb.Empty) (*wrapperspb.BoolValue, error)
// cache
IsCacheOnDiskEnabled(context.Context, *emptypb.Empty) (*wrapperspb.BoolValue, error)
DiskCachePath(context.Context, *emptypb.Empty) (*wrapperspb.StringValue, error)
ChangeLocalCache(context.Context, *ChangeLocalCacheRequest) (*emptypb.Empty, error)
// mail
@ -841,9 +830,6 @@ func (UnimplementedBridgeServer) SetIsAutomaticUpdateOn(context.Context, *wrappe
func (UnimplementedBridgeServer) IsAutomaticUpdateOn(context.Context, *emptypb.Empty) (*wrapperspb.BoolValue, error) {
return nil, status.Errorf(codes.Unimplemented, "method IsAutomaticUpdateOn not implemented")
}
func (UnimplementedBridgeServer) IsCacheOnDiskEnabled(context.Context, *emptypb.Empty) (*wrapperspb.BoolValue, error) {
return nil, status.Errorf(codes.Unimplemented, "method IsCacheOnDiskEnabled not implemented")
}
func (UnimplementedBridgeServer) DiskCachePath(context.Context, *emptypb.Empty) (*wrapperspb.StringValue, error) {
return nil, status.Errorf(codes.Unimplemented, "method DiskCachePath not implemented")
}
@ -1571,24 +1557,6 @@ func _Bridge_IsAutomaticUpdateOn_Handler(srv interface{}, ctx context.Context, d
return interceptor(ctx, in, info, handler)
}
func _Bridge_IsCacheOnDiskEnabled_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(emptypb.Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BridgeServer).IsCacheOnDiskEnabled(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/grpc.Bridge/IsCacheOnDiskEnabled",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BridgeServer).IsCacheOnDiskEnabled(ctx, req.(*emptypb.Empty))
}
return interceptor(ctx, in, info, handler)
}
func _Bridge_DiskCachePath_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(emptypb.Empty)
if err := dec(in); err != nil {
@ -2139,10 +2107,6 @@ var Bridge_ServiceDesc = grpc.ServiceDesc{
MethodName: "IsAutomaticUpdateOn",
Handler: _Bridge_IsAutomaticUpdateOn_Handler,
},
{
MethodName: "IsCacheOnDiskEnabled",
Handler: _Bridge_IsCacheOnDiskEnabled_Handler,
},
{
MethodName: "DiskCachePath",
Handler: _Bridge_DiskCachePath_Handler,

View File

@ -15,25 +15,18 @@
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package serverutil
package grpc
import (
"github.com/sirupsen/logrus"
)
import "github.com/bradenaw/juniper/xslices"
// ServerErrorLogger implements go-imap/logger interface.
type ServerErrorLogger struct {
l *logrus.Entry
// isInternetStatus returns true iff the event is InternetStatus.
func (x *StreamEvent) isInternetStatus() bool {
appEvent := x.GetApp()
return (appEvent != nil) && (appEvent.GetInternetStatus() != nil)
}
func NewServerErrorLogger(protocol Protocol) *ServerErrorLogger {
return &ServerErrorLogger{l: logrus.WithField("protocol", protocol)}
}
func (s *ServerErrorLogger) Printf(format string, args ...interface{}) {
s.l.Errorf(format, args...)
}
func (s *ServerErrorLogger) Println(args ...interface{}) {
s.l.Errorln(args...)
// filterOutInternetStatusEvents return a copy of the events list where all internet connection events have been removed.
func filterOutInternetStatusEvents(events []*StreamEvent) []*StreamEvent {
return xslices.Filter(events, func(event *StreamEvent) bool { return !event.isInternetStatus() })
}

View File

@ -21,34 +21,29 @@ package grpc
import (
"context"
cryptotls "crypto/tls"
"crypto/tls"
"errors"
"fmt"
"net"
"path/filepath"
"strings"
"sync"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
"github.com/ProtonMail/proton-bridge/v2/internal/crash"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/types"
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/internal/users"
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/ProtonMail/proton-bridge/v2/pkg/restarter"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
codes "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
status "google.golang.org/grpc/status"
)
const (
@ -59,6 +54,7 @@ const (
// Service is the RPC service struct.
type Service struct { // nolint:structcheck
UnimplementedBridgeServer
grpcServer *grpc.Server // the gGRPC server
listener net.Listener
eventStreamCh chan *StreamEvent
@ -66,99 +62,87 @@ type Service struct { // nolint:structcheck
eventQueue []*StreamEvent
eventQueueMutex sync.Mutex
panicHandler types.PanicHandler
eventListener listener.Listener
updater types.Updater
updateCheckMutex sync.Mutex
bridge types.Bridger
restarter types.Restarter
showOnStartup bool
authClient pmapi.Client
auth *pmapi.Auth
password []byte
newVersionInfo updater.VersionInfo
panicHandler *crash.Handler
restarter *restarter.Restarter
bridge *bridge.Bridge
newVersionInfo updater.VersionInfo
log *logrus.Entry
initializing sync.WaitGroup
initializationDone sync.Once
firstTimeAutostart sync.Once
locations *locations.Locations
token string
pemCert string
showOnStartup bool
}
// NewService returns a new instance of the service.
func NewService(
showOnStartup bool,
panicHandler types.PanicHandler,
eventListener listener.Listener,
updater types.Updater,
bridge types.Bridger,
restarter types.Restarter,
panicHandler *crash.Handler,
restarter *restarter.Restarter,
locations *locations.Locations,
) *Service {
s := Service{
UnimplementedBridgeServer: UnimplementedBridgeServer{},
panicHandler: panicHandler,
eventListener: eventListener,
updater: updater,
bridge: bridge,
restarter: restarter,
showOnStartup: showOnStartup,
bridge *bridge.Bridge,
showOnStartup bool,
) (*Service, error) {
tlsConfig, certPEM, err := newTLSConfig()
if err != nil {
logrus.WithError(err).Panic("Could not generate gRPC TLS config")
}
listener, err := net.Listen("tcp", "127.0.0.1:0") // Port should be provided by the OS.
if err != nil {
logrus.WithError(err).Panic("Could not create gRPC listener")
}
token := uuid.NewString()
if path, err := saveGRPCServerConfigFile(locations, listener, token, certPEM); err != nil {
logrus.WithError(err).WithField("path", path).Panic("Could not write gRPC service config file")
} else {
logrus.WithField("path", path).Info("Successfully saved gRPC service config file")
}
s := &Service{
grpcServer: grpc.NewServer(
grpc.Creds(credentials.NewTLS(tlsConfig)),
grpc.UnaryInterceptor(newUnaryTokenValidator(token)),
grpc.StreamInterceptor(newStreamTokenValidator(token)),
),
listener: listener,
panicHandler: panicHandler,
restarter: restarter,
bridge: bridge,
log: logrus.WithField("pkg", "grpc"),
initializing: sync.WaitGroup{},
initializationDone: sync.Once{},
firstTimeAutostart: sync.Once{},
locations: locations,
token: uuid.NewString(),
showOnStartup: showOnStartup,
}
// Initializing.Done is only called sync.Once. Please keep the increment
// set to 1
// Initializing.Done is only called sync.Once. Please keep the increment set to 1
s.initializing.Add(1)
tlsConfig, pemCert, err := s.generateTLSConfig()
if err != nil {
s.log.WithError(err).Panic("Could not generate gRPC TLS config")
}
s.pemCert = string(pemCert)
// Initialize the autostart.
s.initAutostart()
s.grpcServer = grpc.NewServer(
grpc.Creds(credentials.NewTLS(tlsConfig)),
grpc.UnaryInterceptor(s.validateUnaryServerToken),
grpc.StreamInterceptor(s.validateStreamServerToken),
)
RegisterBridgeServer(s.grpcServer, &s)
s.listener, err = net.Listen("tcp", "127.0.0.1:0") // Port 0 means that the port is randomly picked by the system.
if err != nil {
s.log.WithError(err).Panic("Could not create gRPC listener")
}
if path, err := s.saveGRPCServerConfigFile(); err != nil {
s.log.WithError(err).WithField("path", path).Panic("Could not write gRPC service config file")
} else {
s.log.WithField("path", path).Info("Successfully saved gRPC service config file")
}
// Register the gRPC service implementation.
RegisterBridgeServer(s.grpcServer, s)
s.log.Info("gRPC server listening on ", s.listener.Addr())
return &s
return s, nil
}
// GODT-1507 Windows: autostart needs to be created after Qt is initialized.
// GODT-1206: if preferences file says it should be on enable it here.
// TO-DO GODT-1681 Autostart needs to be properly implement for gRPC approach.
func (s *Service) initAutostart() {
// GODT-1507 Windows: autostart needs to be created after Qt is initialized.
// GODT-1206: if preferences file says it should be on enable it here.
// TO-DO GODT-1681 Autostart needs to be properly implement for gRPC approach.
s.firstTimeAutostart.Do(func() {
shouldAutostartBeOn := s.bridge.GetBool(settings.AutostartKey)
if s.bridge.IsFirstStart() || shouldAutostartBeOn {
if err := s.bridge.EnableAutostart(); err != nil {
shouldAutostartBeOn := s.bridge.GetAutostart()
if s.bridge.GetFirstStart() || shouldAutostartBeOn {
if err := s.bridge.SetAutostart(true); err != nil {
s.log.WithField("prefs", shouldAutostartBeOn).WithError(err).Error("Failed to enable first autostart")
}
return
@ -168,7 +152,7 @@ func (s *Service) initAutostart() {
func (s *Service) Loop() error {
defer func() {
s.bridge.SetBool(settings.FirstStartGUIKey, false)
_ = s.bridge.SetFirstStartGUI(false)
}()
go func() {
@ -179,7 +163,7 @@ func (s *Service) Loop() error {
s.log.Info("Starting gRPC server")
if err := s.grpcServer.Serve(s.listener); err != nil {
s.log.WithError(err).Error("Error serving gRPC")
s.log.WithError(err).Error("Failed to serve gRPC")
return err
}
@ -212,140 +196,59 @@ func (s *Service) WaitUntilFrontendIsReady() {
s.initializing.Wait()
}
func (s *Service) watchEvents() { // nolint:funlen
if s.bridge.HasError(bridge.ErrLocalCacheUnavailable) {
_ = s.SendEvent(NewCacheErrorEvent(CacheErrorType_CACHE_UNAVAILABLE_ERROR))
}
func (s *Service) watchEvents() {
eventCh, done := s.bridge.GetEvents()
defer done()
errorCh := s.eventListener.ProvideChannel(events.ErrorEvent)
credentialsErrorCh := s.eventListener.ProvideChannel(events.CredentialsErrorEvent)
noActiveKeyForRecipientCh := s.eventListener.ProvideChannel(events.NoActiveKeyForRecipientEvent)
internetConnChangedCh := s.eventListener.ProvideChannel(events.InternetConnChangedEvent)
secondInstanceCh := s.eventListener.ProvideChannel(events.SecondInstanceEvent)
restartBridgeCh := s.eventListener.ProvideChannel(events.RestartBridgeEvent)
addressChangedCh := s.eventListener.ProvideChannel(events.AddressChangedEvent)
addressChangedLogoutCh := s.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
logoutCh := s.eventListener.ProvideChannel(events.LogoutEvent)
updateApplicationCh := s.eventListener.ProvideChannel(events.UpgradeApplicationEvent)
userChangedCh := s.eventListener.ProvideChannel(events.UserRefreshEvent)
certIssue := s.eventListener.ProvideChannel(events.TLSCertIssue)
// we forward events to the GUI/frontend via the gRPC event stream.
for {
select {
case errorDetails := <-errorCh:
if strings.Contains(errorDetails, "IMAP failed") {
_ = s.SendEvent(NewMailSettingsErrorEvent(MailSettingsErrorType_IMAP_PORT_ISSUE))
}
if strings.Contains(errorDetails, "SMTP failed") {
_ = s.SendEvent(NewMailSettingsErrorEvent(MailSettingsErrorType_SMTP_PORT_ISSUE))
}
case reason := <-credentialsErrorCh:
if reason == keychain.ErrMacKeychainRebuild.Error() {
_ = s.SendEvent(NewKeychainRebuildKeychainEvent())
continue
}
// TODO: Better error events.
for _, err := range s.bridge.GetErrors() {
switch {
case errors.Is(err, vault.ErrCorrupt):
_ = s.SendEvent(NewKeychainHasNoKeychainEvent())
case email := <-noActiveKeyForRecipientCh:
_ = s.SendEvent(NewMailNoActiveKeyForRecipientEvent(email))
case stat := <-internetConnChangedCh:
if stat == events.InternetOff {
_ = s.SendEvent(NewInternetStatusEvent(false))
}
if stat == events.InternetOn {
_ = s.SendEvent(NewInternetStatusEvent(true))
}
case <-secondInstanceCh:
_ = s.SendEvent(NewShowMainWindowEvent())
case <-restartBridgeCh:
_, _ = s.Restart(
metadata.AppendToOutgoingContext(context.Background(), serverTokenMetadataKey, s.token),
&emptypb.Empty{},
)
case address := <-addressChangedCh:
_ = s.SendEvent(NewMailAddressChangeEvent(address))
case address := <-addressChangedLogoutCh:
_ = s.SendEvent(NewMailAddressChangeLogoutEvent(address))
case userID := <-logoutCh:
user, err := s.bridge.GetUserInfo(userID)
if err != nil {
return
}
_ = s.SendEvent(NewUserDisconnectedEvent(user.Username))
case <-updateApplicationCh:
s.updateForce()
case userID := <-userChangedCh:
_ = s.SendEvent(NewUserChangedEvent(userID))
case <-certIssue:
_ = s.SendEvent(NewMailApiCertIssue())
case errors.Is(err, vault.ErrInsecure):
_ = s.SendEvent(NewKeychainHasNoKeychainEvent())
case errors.Is(err, bridge.ErrServeIMAP):
_ = s.SendEvent(NewMailSettingsErrorEvent(MailSettingsErrorType_IMAP_PORT_ISSUE))
case errors.Is(err, bridge.ErrServeSMTP):
_ = s.SendEvent(NewMailSettingsErrorEvent(MailSettingsErrorType_SMTP_PORT_ISSUE))
}
}
}
func (s *Service) loginAbort() {
s.loginClean()
}
for event := range eventCh {
switch event := event.(type) {
case events.ConnStatus:
_ = s.SendEvent(NewInternetStatusEvent(event.Status == liteapi.StatusUp))
func (s *Service) loginClean() {
s.auth = nil
s.authClient = nil
for i := range s.password {
s.password[i] = '\x00'
}
s.password = s.password[0:0]
}
case events.Raise:
_ = s.SendEvent(NewShowMainWindowEvent())
func (s *Service) finishLogin() {
defer s.loginClean()
case events.UserAddressCreated:
_ = s.SendEvent(NewMailAddressChangeEvent(event.Address))
if len(s.password) == 0 || s.auth == nil || s.authClient == nil {
s.log.
WithField("hasPass", len(s.password) != 0).
WithField("hasAuth", s.auth != nil).
WithField("hasClient", s.authClient != nil).
Error("Finish login: authentication incomplete")
case events.UserAddressChanged:
_ = s.SendEvent(NewMailAddressChangeEvent(event.Address))
_ = s.SendEvent(NewLoginError(LoginErrorType_TWO_PASSWORDS_ABORT, "Missing authentication, try again."))
return
}
case events.UserAddressDeleted:
_ = s.SendEvent(NewMailAddressChangeLogoutEvent(event.Address))
done := make(chan string)
s.eventListener.Add(events.UserChangeDone, done)
defer s.eventListener.Remove(events.UserChangeDone, done)
case events.UserChanged:
_ = s.SendEvent(NewUserChangedEvent(event.UserID))
userID, err := s.bridge.FinishLogin(s.authClient, s.auth, s.password)
if err != nil && err != users.ErrUserAlreadyConnected {
s.log.WithError(err).Errorf("Finish login failed")
_ = s.SendEvent(NewLoginError(LoginErrorType_TWO_PASSWORDS_ABORT, err.Error()))
return
}
// The user changed should be triggered by FinishLogin, but it is not
// guaranteed when this is going to happen. Therefor we should wait
// until we receive the signal from userChanged function.
s.waitForUserChangeDone(done, userID)
s.log.WithField("userID", userID).Debug("Login finished")
_ = s.SendEvent(NewLoginFinishedEvent(userID))
if err == users.ErrUserAlreadyConnected {
s.log.WithError(err).Error("User already logged in")
_ = s.SendEvent(NewLoginAlreadyLoggedInEvent(userID))
}
}
func (s *Service) waitForUserChangeDone(done <-chan string, userID string) {
for {
select {
case changedID := <-done:
if changedID == userID {
return
case events.UserDeauth:
if user, err := s.bridge.GetUserInfo(event.UserID); err != nil {
s.log.WithError(err).Error("Failed to get user info")
} else {
_ = s.SendEvent(NewUserDisconnectedEvent(user.Username))
}
case <-time.After(2 * time.Second):
s.log.WithField("ID", userID).Warning("Login finished but user not added within 2 seconds")
return
case events.TLSIssue:
_ = s.SendEvent(NewMailApiCertIssue())
case events.UpdateForced:
panic("TODO")
}
}
}
@ -354,103 +257,46 @@ func (s *Service) triggerReset() {
defer func() {
_ = s.SendEvent(NewResetFinishedEvent())
}()
s.bridge.FactoryReset()
if err := s.bridge.FactoryReset(context.Background()); err != nil {
s.log.WithError(err).Error("Failed to reset")
}
}
func (s *Service) checkUpdate() {
version, err := s.updater.Check()
func newTLSConfig() (*tls.Config, []byte, error) {
template, err := certs.NewTLSTemplate()
if err != nil {
s.log.WithError(err).Error("An error occurred while checking for updates")
s.SetVersion(updater.VersionInfo{})
return
}
s.SetVersion(version)
}
func (s *Service) updateForce() {
s.updateCheckMutex.Lock()
defer s.updateCheckMutex.Unlock()
s.checkUpdate()
_ = s.SendEvent(NewUpdateForceEvent(s.newVersionInfo.Version.String()))
}
func (s *Service) checkUpdateAndNotify(isReqFromUser bool) {
s.updateCheckMutex.Lock()
defer func() {
s.updateCheckMutex.Unlock()
_ = s.SendEvent(NewUpdateCheckFinishedEvent())
}()
s.checkUpdate()
version := s.newVersionInfo
if (version.Version == nil) || (version.Version.String() == "") {
if isReqFromUser {
_ = s.SendEvent(NewUpdateErrorEvent(UpdateErrorType_UPDATE_MANUAL_ERROR))
}
return
}
if !s.updater.IsUpdateApplicable(s.newVersionInfo) {
s.log.Info("No need to update")
if isReqFromUser {
_ = s.SendEvent(NewUpdateIsLatestVersionEvent())
}
} else if isReqFromUser {
s.NotifyManualUpdate(s.newVersionInfo, s.updater.CanInstall(s.newVersionInfo))
}
}
func (s *Service) installUpdate() {
s.updateCheckMutex.Lock()
defer s.updateCheckMutex.Unlock()
if !s.updater.CanInstall(s.newVersionInfo) {
s.log.Warning("Skipping update installation, current version too old")
_ = s.SendEvent(NewUpdateErrorEvent(UpdateErrorType_UPDATE_MANUAL_ERROR))
return
return nil, nil, fmt.Errorf("failed to create TLS template: %w", err)
}
if err := s.updater.InstallUpdate(s.newVersionInfo); err != nil {
if errors.Cause(err) == updater.ErrDownloadVerify {
s.log.WithError(err).Warning("Skipping update installation due to temporary error")
} else {
s.log.WithError(err).Error("The update couldn't be installed")
_ = s.SendEvent(NewUpdateErrorEvent(UpdateErrorType_UPDATE_MANUAL_ERROR))
}
return
}
_ = s.SendEvent(NewUpdateSilentRestartNeededEvent())
}
func (s *Service) generateTLSConfig() (tlsConfig *cryptotls.Config, pemCert []byte, err error) {
pemCert, pemKey, err := tls.NewPEMKeyPair()
certPEM, keyPEM, err := certs.GenerateCert(template)
if err != nil {
return nil, nil, errors.New("Could not get TLS config")
return nil, nil, fmt.Errorf("failed to generate cert: %w", err)
}
tlsConfig, err = tls.GetConfigFromPEMKeyPair(pemCert, pemKey)
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, nil, errors.New("Could not get TLS config")
return nil, nil, fmt.Errorf("failed to load cert: %w", err)
}
tlsConfig.ClientAuth = cryptotls.NoClientCert // skip client auth if the certificate allow it.
return tlsConfig, pemCert, nil
return &tls.Config{
Certificates: []tls.Certificate{cert},
ClientAuth: tls.NoClientCert,
}, certPEM, nil
}
func (s *Service) saveGRPCServerConfigFile() (string, error) {
address, ok := s.listener.Addr().(*net.TCPAddr)
func saveGRPCServerConfigFile(locations *locations.Locations, listener net.Listener, token string, certPEM []byte) (string, error) {
address, ok := listener.Addr().(*net.TCPAddr)
if !ok {
return "", fmt.Errorf("could not retrieve gRPC service listener address")
}
sc := config{
Port: address.Port,
Cert: s.pemCert,
Token: s.token,
Cert: string(certPEM),
Token: token,
}
settingsPath, err := s.locations.ProvideSettingsPath()
settingsPath, err := locations.ProvideSettingsPath()
if err != nil {
return "", err
}
@ -461,7 +307,7 @@ func (s *Service) saveGRPCServerConfigFile() (string, error) {
}
// validateServerToken verify that the server token provided by the client is valid.
func (s *Service) validateServerToken(ctx context.Context) error {
func validateServerToken(ctx context.Context, wantToken string) error {
values, ok := metadata.FromIncomingContext(ctx)
if !ok {
return status.Error(codes.Unauthenticated, "missing server token")
@ -476,40 +322,31 @@ func (s *Service) validateServerToken(ctx context.Context) error {
return status.Error(codes.Unauthenticated, "more than one server token was provided")
}
if token[0] != s.token {
if token[0] != wantToken {
return status.Error(codes.Unauthenticated, "invalid server token")
}
return nil
}
// validateUnaryServerToken check the server token for every unary gRPC call.
func (s *Service) validateUnaryServerToken(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (resp interface{}, err error) {
if err := s.validateServerToken(ctx); err != nil {
return nil, err
}
// newUnaryTokenValidator checks the server token for every unary gRPC call.
func newUnaryTokenValidator(wantToken string) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if err := validateServerToken(ctx, wantToken); err != nil {
return nil, err
}
return handler(ctx, req)
return handler(ctx, req)
}
}
// validateStreamServerToken check the server token for every gRPC stream request.
func (s *Service) validateStreamServerToken(
srv interface{},
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
logEntry := s.log.WithField("FullMethod", info.FullMethod)
// newStreamTokenValidator checks the server token for every gRPC stream request.
func newStreamTokenValidator(wantToken string) grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if err := validateServerToken(stream.Context(), wantToken); err != nil {
return err
}
if err := s.validateServerToken(ss.Context()); err != nil {
logEntry.WithError(err).Error("Stream validator failed")
return err
return handler(srv, stream)
}
return handler(srv, ss)
}

View File

@ -23,15 +23,13 @@ import (
"runtime"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/theme"
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/v2/pkg/ports"
"github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
@ -116,7 +114,7 @@ func (s *Service) Quit(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empt
func (s *Service) Restart(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empty, error) {
s.log.Debug("Restart")
s.restarter.SetToRestart()
s.restarter.Set(true, false)
return s.Quit(ctx, empty)
}
@ -129,25 +127,19 @@ func (s *Service) ShowOnStartup(ctx context.Context, _ *emptypb.Empty) (*wrapper
func (s *Service) ShowSplashScreen(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
s.log.Debug("ShowSplashScreen")
if s.bridge.IsFirstStart() {
return wrapperspb.Bool(false), nil
}
ver, err := semver.NewVersion(s.bridge.GetLastVersion())
if err != nil {
s.log.WithError(err).WithField("last", s.bridge.GetLastVersion()).Debug("Cannot parse last version")
if s.bridge.GetFirstStart() {
return wrapperspb.Bool(false), nil
}
// Current splash screen contains update on rebranding. Therefore, it
// should be shown only if the last used version was less than 2.2.0.
return wrapperspb.Bool(ver.LessThan(semver.MustParse("2.2.0"))), nil
return wrapperspb.Bool(s.bridge.GetLastVersion().LessThan(semver.MustParse("2.2.0"))), nil
}
func (s *Service) IsFirstGuiStart(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
s.log.Debug("IsFirstGuiStart")
return wrapperspb.Bool(s.bridge.GetBool(settings.FirstStartGUIKey)), nil
return wrapperspb.Bool(s.bridge.GetFirstStartGUI()), nil
}
func (s *Service) SetIsAutostartOn(ctx context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) {
@ -155,22 +147,16 @@ func (s *Service) SetIsAutostartOn(ctx context.Context, isOn *wrapperspb.BoolVal
defer func() { _ = s.SendEvent(NewToggleAutostartFinishedEvent()) }()
if isOn.Value == s.bridge.IsAutostartEnabled() {
if isOn.Value == s.bridge.GetAutostart() {
s.initAutostart()
return &emptypb.Empty{}, nil
}
var err error
if isOn.Value {
err = s.bridge.EnableAutostart()
} else {
err = s.bridge.DisableAutostart()
}
s.initAutostart()
if err != nil {
if err := s.bridge.SetAutostart(isOn.Value); err != nil {
s.log.WithField("makeItEnabled", isOn.Value).WithError(err).Error("Autostart change failed")
return nil, status.Errorf(codes.Internal, "failed to set autostart: %v", err)
}
return &emptypb.Empty{}, nil
@ -179,7 +165,7 @@ func (s *Service) SetIsAutostartOn(ctx context.Context, isOn *wrapperspb.BoolVal
func (s *Service) IsAutostartOn(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
s.log.Debug("IsAutostartOn")
return wrapperspb.Bool(s.bridge.IsAutostartEnabled()), nil
return wrapperspb.Bool(s.bridge.GetAutostart()), nil
}
func (s *Service) SetIsBetaEnabled(ctx context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) {
@ -190,8 +176,10 @@ func (s *Service) SetIsBetaEnabled(ctx context.Context, isEnabled *wrapperspb.Bo
channel = updater.EarlyChannel
}
s.bridge.SetUpdateChannel(channel)
s.checkUpdate()
if err := s.bridge.SetUpdateChannel(channel); err != nil {
s.log.WithError(err).Error("Failed to set update channel")
return nil, status.Errorf(codes.Internal, "failed to set update channel: %v", err)
}
return &emptypb.Empty{}, nil
}
@ -205,7 +193,10 @@ func (s *Service) IsBetaEnabled(ctx context.Context, _ *emptypb.Empty) (*wrapper
func (s *Service) SetIsAllMailVisible(ctx context.Context, isVisible *wrapperspb.BoolValue) (*emptypb.Empty, error) {
s.log.WithField("isVisible", isVisible.Value).Debug("SetIsAllMailVisible")
s.bridge.SetIsAllMailVisible(isVisible.Value)
if err := s.bridge.SetShowAllMail(isVisible.Value); err != nil {
s.log.WithError(err).Error("Failed to set show all mail")
return nil, status.Errorf(codes.Internal, "failed to set show all mail: %v", err)
}
return &emptypb.Empty{}, nil
}
@ -213,7 +204,7 @@ func (s *Service) SetIsAllMailVisible(ctx context.Context, isVisible *wrapperspb
func (s *Service) IsAllMailVisible(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
s.log.Debug("IsAllMailVisible")
return wrapperspb.Bool(s.bridge.IsAllMailVisible()), nil
return wrapperspb.Bool(s.bridge.GetShowAllMail()), nil
}
func (s *Service) GoOs(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
@ -241,7 +232,7 @@ func (s *Service) Version(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.St
func (s *Service) LogsPath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.log.Debug("LogsPath")
path, err := s.bridge.ProvideLogsPath()
path, err := s.bridge.GetLogsPath()
if err != nil {
s.log.WithError(err).Error("Cannot determine logs path")
return nil, err
@ -275,7 +266,10 @@ func (s *Service) SetColorSchemeName(ctx context.Context, name *wrapperspb.Strin
return nil, status.Error(codes.NotFound, "Color scheme not available")
}
s.bridge.Set(settings.ColorScheme, name.Value)
if err := s.bridge.SetColorScheme(name.Value); err != nil {
s.log.WithError(err).Error("Failed to set color scheme")
return nil, status.Errorf(codes.Internal, "failed to set color scheme: %v", err)
}
return &emptypb.Empty{}, nil
}
@ -283,10 +277,13 @@ func (s *Service) SetColorSchemeName(ctx context.Context, name *wrapperspb.Strin
func (s *Service) ColorSchemeName(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.log.Debug("ColorSchemeName")
current := s.bridge.Get(settings.ColorScheme)
current := s.bridge.GetColorScheme()
if !theme.IsAvailable(theme.Theme(current)) {
current = string(theme.DefaultTheme())
s.bridge.Set(settings.ColorScheme, current)
if err := s.bridge.SetColorScheme(current); err != nil {
s.log.WithError(err).Error("Failed to set color scheme")
return nil, status.Errorf(codes.Internal, "failed to set color scheme: %v", err)
}
}
return wrapperspb.String(current), nil
@ -312,6 +309,7 @@ func (s *Service) ReportBug(ctx context.Context, report *ReportBugRequest) (*emp
defer func() { _ = s.SendEvent(NewReportBugFinishedEvent()) }()
if err := s.bridge.ReportBug(
context.Background(),
report.OsType,
report.OsVersion,
report.Description,
@ -331,6 +329,7 @@ func (s *Service) ReportBug(ctx context.Context, report *ReportBugRequest) (*emp
return &emptypb.Empty{}, nil
}
/*
func (s *Service) ForceLauncher(ctx context.Context, launcher *wrapperspb.StringValue) (*emptypb.Empty, error) {
s.log.WithField("launcher", launcher.Value).Debug("ForceLauncher")
@ -350,6 +349,7 @@ func (s *Service) SetMainExecutable(ctx context.Context, exe *wrapperspb.StringV
}()
return &emptypb.Empty{}, nil
}
*/
func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) {
s.log.WithField("username", login.Username).Debug("Login")
@ -357,135 +357,44 @@ func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empt
go func() {
defer s.panicHandler.HandlePanic()
var err error
s.password, err = base64.StdEncoding.DecodeString(login.Password)
password, err := base64.StdEncoding.DecodeString(login.Password)
if err != nil {
s.log.WithError(err).Error("Cannot decode password")
_ = s.SendEvent(NewLoginError(LoginErrorType_USERNAME_PASSWORD_ERROR, "Cannot decode password"))
s.loginClean()
return
}
s.authClient, s.auth, err = s.bridge.Login(login.Username, s.password)
// TODO: Handle different error types!
// - bad credentials
// - bad proton plan
// - user already exists
userID, err := s.bridge.LoginUser(context.Background(), login.Username, string(password), nil, nil)
if err != nil {
if err == pmapi.ErrPasswordWrong {
// Remove error message since it is hardcoded in QML.
_ = s.SendEvent(NewLoginError(LoginErrorType_USERNAME_PASSWORD_ERROR, ""))
s.loginClean()
return
}
if err == pmapi.ErrPaidPlanRequired {
_ = s.SendEvent(NewLoginError(LoginErrorType_FREE_USER, ""))
s.loginClean()
return
}
_ = s.SendEvent(NewLoginError(LoginErrorType_USERNAME_PASSWORD_ERROR, err.Error()))
s.loginClean()
s.log.WithError(err).Error("Cannot login user")
_ = s.SendEvent(NewLoginError(LoginErrorType_USERNAME_PASSWORD_ERROR, "Cannot login user"))
return
}
if s.auth.HasTwoFactor() {
_ = s.SendEvent(NewLoginTfaRequestedEvent(login.Username))
return
}
if s.auth.HasMailboxPassword() {
_ = s.SendEvent(NewLoginTwoPasswordsRequestedEvent())
return
}
s.finishLogin()
}()
return &emptypb.Empty{}, nil
}
func (s *Service) Login2FA(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) {
s.log.WithField("username", login.Username).Debug("Login2FA")
go func() {
defer s.panicHandler.HandlePanic()
if s.auth == nil || s.authClient == nil {
s.log.Errorf("Login 2FA: authethication incomplete %p %p", s.auth, s.authClient)
_ = s.SendEvent(NewLoginError(LoginErrorType_TFA_ABORT, "Missing authentication, try again."))
s.loginClean()
return
}
twoFA, err := base64.StdEncoding.DecodeString(login.Password)
if err != nil {
s.log.WithError(err).Error("Cannot decode 2fa code")
_ = s.SendEvent(NewLoginError(LoginErrorType_USERNAME_PASSWORD_ERROR, "Cannot decode 2fa code"))
s.loginClean()
return
}
err = s.authClient.Auth2FA(context.Background(), string(twoFA))
if err == pmapi.ErrBad2FACodeTryAgain {
s.log.Warn("Login 2FA: retry 2fa")
_ = s.SendEvent(NewLoginError(LoginErrorType_TFA_ERROR, ""))
return
}
if err == pmapi.ErrBad2FACode {
s.log.Warn("Login 2FA: abort 2fa")
_ = s.SendEvent(NewLoginError(LoginErrorType_TFA_ABORT, ""))
s.loginClean()
return
}
if err != nil {
s.log.WithError(err).Warn("Login 2FA: failed.")
_ = s.SendEvent(NewLoginError(LoginErrorType_TFA_ABORT, err.Error()))
s.loginClean()
return
}
if s.auth.HasMailboxPassword() {
_ = s.SendEvent(NewLoginTwoPasswordsRequestedEvent())
return
}
s.finishLogin()
_ = s.SendEvent(NewLoginFinishedEvent(userID))
}()
return &emptypb.Empty{}, nil
}
func (s *Service) Login2Passwords(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) {
s.log.WithField("username", login.Username).Debug("Login2Passwords")
go func() {
defer s.panicHandler.HandlePanic()
var err error
s.password, err = base64.StdEncoding.DecodeString(login.Password)
if err != nil {
s.log.WithError(err).Error("Cannot decode mbox password")
_ = s.SendEvent(NewLoginError(LoginErrorType_USERNAME_PASSWORD_ERROR, "Cannot decode mbox password"))
s.loginClean()
return
}
s.finishLogin()
}()
return &emptypb.Empty{}, nil
func (s *Service) Login2FA(_ context.Context, login *LoginRequest) (*emptypb.Empty, error) {
panic("TODO")
}
func (s *Service) LoginAbort(ctx context.Context, loginAbort *LoginAbortRequest) (*emptypb.Empty, error) {
s.log.WithField("username", loginAbort.Username).Debug("LoginAbort")
go func() {
defer s.panicHandler.HandlePanic()
s.loginAbort()
}()
return &emptypb.Empty{}, nil
func (s *Service) Login2Passwords(_ context.Context, login *LoginRequest) (*emptypb.Empty, error) {
panic("TODO")
}
func (s *Service) CheckUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
func (s *Service) LoginAbort(_ context.Context, loginAbort *LoginAbortRequest) (*emptypb.Empty, error) {
panic("TODO")
}
/*
func (s *Service) CheckUpdate(context.Context, *emptypb.Empty) (*emptypb.Empty, error) {
s.log.Debug("CheckUpdate")
go func() {
@ -507,21 +416,20 @@ func (s *Service) InstallUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb
return &emptypb.Empty{}, nil
}
*/
func (s *Service) SetIsAutomaticUpdateOn(ctx context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) {
s.log.WithField("isOn", isOn.Value).Debug("SetIsAutomaticUpdateOn")
currentlyOn := s.bridge.GetBool(settings.AutoUpdateKey)
currentlyOn := s.bridge.GetAutoUpdate()
if currentlyOn == isOn.Value {
return &emptypb.Empty{}, nil
}
s.bridge.SetBool(settings.AutoUpdateKey, isOn.Value)
go func() {
defer s.panicHandler.HandlePanic()
s.checkUpdateAndNotify(false)
}()
if err := s.bridge.SetAutoUpdate(isOn.Value); err != nil {
s.log.WithError(err).Error("Failed to set auto update")
return nil, status.Errorf(codes.Internal, "failed to set auto update: %v", err)
}
return &emptypb.Empty{}, nil
}
@ -529,51 +437,21 @@ func (s *Service) SetIsAutomaticUpdateOn(ctx context.Context, isOn *wrapperspb.B
func (s *Service) IsAutomaticUpdateOn(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
s.log.Debug("IsAutomaticUpdateOn")
return wrapperspb.Bool(s.bridge.GetBool(settings.AutoUpdateKey)), nil
}
func (s *Service) IsCacheOnDiskEnabled(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
s.log.Debug("IsCacheOnDiskEnabled")
return wrapperspb.Bool(s.bridge.GetBool(settings.CacheEnabledKey)), nil
return wrapperspb.Bool(s.bridge.GetAutoUpdate()), nil
}
func (s *Service) DiskCachePath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.log.Debug("DiskCachePath")
return wrapperspb.String(s.bridge.Get(settings.CacheLocationKey)), nil
return wrapperspb.String(s.bridge.GetGluonDir()), nil
}
func (s *Service) ChangeLocalCache(ctx context.Context, change *ChangeLocalCacheRequest) (*emptypb.Empty, error) {
s.log.WithField("enableDiskCache", change.EnableDiskCache).
WithField("diskCachePath", change.DiskCachePath).
Debug("DiskCachePath")
s.log.WithField("diskCachePath", change.DiskCachePath).Debug("DiskCachePath")
restart := false
defer func(willRestart *bool) {
_ = s.SendEvent(NewCacheChangeLocalCacheFinishedEvent(*willRestart))
if *willRestart {
_, _ = s.Restart(ctx, &emptypb.Empty{})
}
}(&restart)
if change.EnableDiskCache != s.bridge.GetBool(settings.CacheEnabledKey) {
if change.EnableDiskCache {
if err := s.bridge.EnableCache(); err != nil {
s.log.WithError(err).Error("Cannot enable disk cache")
} else {
restart = true
_ = s.SendEvent(NewIsCacheOnDiskEnabledChanged(s.bridge.GetBool(settings.CacheEnabledKey)))
}
} else {
if err := s.bridge.DisableCache(); err != nil {
s.log.WithError(err).Error("Cannot disable disk cache")
} else {
restart = true
_ = s.SendEvent(NewIsCacheOnDiskEnabledChanged(s.bridge.GetBool(settings.CacheEnabledKey)))
}
}
}
defer func() {
_ = s.SendEvent(NewCacheChangeLocalCacheFinishedEvent(false))
}()
path := change.DiskCachePath
//goland:noinspection GoBoolExpressions
@ -581,16 +459,14 @@ func (s *Service) ChangeLocalCache(ctx context.Context, change *ChangeLocalCache
path = path[1:]
}
if change.EnableDiskCache && path != s.bridge.Get(settings.CacheLocationKey) {
if err := s.bridge.MigrateCache(s.bridge.Get(settings.CacheLocationKey), path); err != nil {
if path != s.bridge.GetGluonDir() {
if err := s.bridge.SetGluonDir(ctx, path); err != nil {
s.log.WithError(err).Error("The local cache location could not be changed.")
_ = s.SendEvent(NewCacheErrorEvent(CacheErrorType_CACHE_CANT_MOVE_ERROR))
return &emptypb.Empty{}, nil
}
s.bridge.Set(settings.CacheLocationKey, path)
restart = true
_ = s.SendEvent(NewDiskCachePathChanged(s.bridge.Get(settings.CacheLocationKey)))
_ = s.SendEvent(NewDiskCachePathChanged(s.bridge.GetGluonDir()))
}
_ = s.SendEvent(NewCacheLocationChangeSuccessEvent())
@ -601,7 +477,10 @@ func (s *Service) ChangeLocalCache(ctx context.Context, change *ChangeLocalCache
func (s *Service) SetIsDoHEnabled(ctx context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) {
s.log.WithField("isEnabled", isEnabled.Value).Debug("SetIsDohEnabled")
s.bridge.SetProxyAllowed(isEnabled.Value)
if err := s.bridge.SetProxyAllowed(isEnabled.Value); err != nil {
s.log.WithError(err).Error("Failed to set DoH")
return nil, status.Errorf(codes.Internal, "failed to set DoH: %v", err)
}
return &emptypb.Empty{}, nil
}
@ -615,13 +494,14 @@ func (s *Service) IsDoHEnabled(ctx context.Context, _ *emptypb.Empty) (*wrappers
func (s *Service) SetUseSslForSmtp(ctx context.Context, useSsl *wrapperspb.BoolValue) (*emptypb.Empty, error) { //nolint:revive,stylecheck
s.log.WithField("useSsl", useSsl.Value).Debug("SetUseSslForSmtp")
if s.bridge.GetBool(settings.SMTPSSLKey) == useSsl.Value {
if s.bridge.GetSMTPSSL() == useSsl.Value {
return &emptypb.Empty{}, nil
}
s.bridge.SetBool(settings.SMTPSSLKey, useSsl.Value)
defer func() { _, _ = s.Restart(ctx, &emptypb.Empty{}) }()
if err := s.bridge.SetSMTPSSL(useSsl.Value); err != nil {
s.log.WithError(err).Error("Failed to set SMTP SSL")
return nil, status.Errorf(codes.Internal, "failed to set SMTP SSL: %v", err)
}
return &emptypb.Empty{}, s.SendEvent(NewMailSettingsUseSslForSmtpFinishedEvent())
}
@ -629,34 +509,39 @@ func (s *Service) SetUseSslForSmtp(ctx context.Context, useSsl *wrapperspb.BoolV
func (s *Service) UseSslForSmtp(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) { //nolint:revive,stylecheck
s.log.Debug("UseSslForSmtp")
return wrapperspb.Bool(s.bridge.GetBool(settings.SMTPSSLKey)), nil
return wrapperspb.Bool(s.bridge.GetSMTPSSL()), nil
}
func (s *Service) Hostname(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.log.Debug("Hostname")
return wrapperspb.String(bridge.Host), nil
return wrapperspb.String(constants.Host), nil
}
func (s *Service) ImapPort(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.Int32Value, error) {
s.log.Debug("ImapPort")
return wrapperspb.Int32(int32(s.bridge.GetInt(settings.IMAPPortKey))), nil
return wrapperspb.Int32(int32(s.bridge.GetIMAPPort())), nil
}
func (s *Service) SmtpPort(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.Int32Value, error) { //nolint:revive,stylecheck
s.log.Debug("SmtpPort")
return wrapperspb.Int32(int32(s.bridge.GetInt(settings.SMTPPortKey))), nil
return wrapperspb.Int32(int32(s.bridge.GetSMTPPort())), nil
}
func (s *Service) ChangePorts(ctx context.Context, ports *ChangePortsRequest) (*emptypb.Empty, error) {
s.log.WithField("imapPort", ports.ImapPort).WithField("smtpPort", ports.SmtpPort).Debug("ChangePorts")
s.bridge.SetInt(settings.IMAPPortKey, int(ports.ImapPort))
s.bridge.SetInt(settings.SMTPPortKey, int(ports.SmtpPort))
if err := s.bridge.SetIMAPPort(int(ports.ImapPort)); err != nil {
s.log.WithError(err).Error("Failed to set IMAP port")
return nil, status.Errorf(codes.Internal, "failed to set IMAP port: %v", err)
}
defer func() { _, _ = s.Restart(ctx, &emptypb.Empty{}) }()
if err := s.bridge.SetSMTPPort(int(ports.SmtpPort)); err != nil {
s.log.WithError(err).Error("Failed to set SMTP port")
return nil, status.Errorf(codes.Internal, "failed to set SMTP port: %v", err)
}
return &emptypb.Empty{}, s.SendEvent(NewMailSettingsChangePortFinishedEvent())
}
@ -670,12 +555,7 @@ func (s *Service) IsPortFree(ctx context.Context, port *wrapperspb.Int32Value) (
func (s *Service) AvailableKeychains(ctx context.Context, _ *emptypb.Empty) (*AvailableKeychainsResponse, error) {
s.log.Debug("AvailableKeychains")
keychains := make([]string, 0, len(keychain.Helpers))
for chain := range keychain.Helpers {
keychains = append(keychains, chain)
}
return &AvailableKeychainsResponse{Keychains: keychains}, nil
return &AvailableKeychainsResponse{Keychains: maps.Keys(keychain.Helpers)}, nil
}
func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.StringValue) (*emptypb.Empty, error) {
@ -684,11 +564,20 @@ func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.S
defer func() { _, _ = s.Restart(ctx, &emptypb.Empty{}) }()
defer func() { _ = s.SendEvent(NewKeychainChangeKeychainFinishedEvent()) }()
if s.bridge.GetKeychainApp() == keychain.Value {
helper, err := s.bridge.GetKeychainApp()
if err != nil {
s.log.WithError(err).Error("Failed to get current keychain")
return nil, status.Errorf(codes.Internal, "failed to get current keychain: %v", err)
}
if helper == keychain.Value {
return &emptypb.Empty{}, nil
}
s.bridge.SetKeychainApp(keychain.Value)
if err := s.bridge.SetKeychainApp(keychain.Value); err != nil {
s.log.WithError(err).Error("Failed to set keychain")
return nil, status.Errorf(codes.Internal, "failed to set keychain: %v", err)
}
return &emptypb.Empty{}, nil
}
@ -696,5 +585,11 @@ func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.S
func (s *Service) CurrentKeychain(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
s.log.Debug("CurrentKeychain")
return wrapperspb.String(s.bridge.GetKeychainApp()), nil
helper, err := s.bridge.GetKeychainApp()
if err != nil {
s.log.WithError(err).Error("Failed to get current keychain")
return nil, status.Errorf(codes.Internal, "failed to get current keychain: %v", err)
}
return wrapperspb.String(helper), nil
}

View File

@ -87,12 +87,8 @@ func (s *Service) StopEventStream(ctx context.Context, _ *emptypb.Empty) (*empty
// SendEvent sends an event to the via the gRPC event stream.
func (s *Service) SendEvent(event *StreamEvent) error {
s.eventQueueMutex.Lock()
defer s.eventQueueMutex.Unlock()
if s.eventStreamCh == nil {
// nobody is connected to the event stream, we queue events
s.eventQueue = append(s.eventQueue, event)
if s.eventStreamCh == nil { // nobody is connected to the event stream, we queue events
s.queueEvent(event)
return nil
}
@ -167,3 +163,14 @@ func (s *Service) StartEventTest() error { //nolint:funlen
return nil
}
func (s *Service) queueEvent(event *StreamEvent) {
s.eventQueueMutex.Lock()
defer s.eventQueueMutex.Unlock()
if event.isInternetStatus() {
s.eventQueue = append(filterOutInternetStatusEvents(s.eventQueue), event)
} else {
s.eventQueue = append(s.eventQueue, event)
}
}

View File

@ -0,0 +1,68 @@
package grpc
/*
func (s *Service) checkUpdate() {
version, err := s.updater.Check()
if err != nil {
s.log.WithError(err).Error("An error occurred while checking for updates")
s.SetVersion(updater.VersionInfo{})
return
}
s.SetVersion(version)
}
func (s *Service) updateForce() {
s.updateCheckMutex.Lock()
defer s.updateCheckMutex.Unlock()
s.checkUpdate()
_ = s.SendEvent(NewUpdateForceEvent(s.newVersionInfo.Version.String()))
}
func (s *Service) checkUpdateAndNotify(isReqFromUser bool) {
s.updateCheckMutex.Lock()
defer func() {
s.updateCheckMutex.Unlock()
_ = s.SendEvent(NewUpdateCheckFinishedEvent())
}()
s.checkUpdate()
version := s.newVersionInfo
if version.Version.String() == "" {
if isReqFromUser {
_ = s.SendEvent(NewUpdateErrorEvent(UpdateErrorType_UPDATE_MANUAL_ERROR))
}
return
}
if !s.updater.IsUpdateApplicable(s.newVersionInfo) {
s.log.Info("No need to update")
if isReqFromUser {
_ = s.SendEvent(NewUpdateIsLatestVersionEvent())
}
} else if isReqFromUser {
s.NotifyManualUpdate(s.newVersionInfo, s.updater.CanInstall(s.newVersionInfo))
}
}
func (s *Service) installUpdate() {
s.updateCheckMutex.Lock()
defer s.updateCheckMutex.Unlock()
if !s.updater.CanInstall(s.newVersionInfo) {
s.log.Warning("Skipping update installation, current version too old")
_ = s.SendEvent(NewUpdateErrorEvent(UpdateErrorType_UPDATE_MANUAL_ERROR))
return
}
if err := s.updater.InstallUpdate(s.newVersionInfo); err != nil {
if errors.Cause(err) == updater.ErrDownloadVerify {
s.log.WithError(err).Warning("Skipping update installation due to temporary error")
} else {
s.log.WithError(err).Error("The update couldn't be installed")
_ = s.SendEvent(NewUpdateErrorEvent(UpdateErrorType_UPDATE_MANUAL_ERROR))
}
return
}
_ = s.SendEvent(NewUpdateSilentRestartNeededEvent())
}
*/

Some files were not shown because too many files have changed in this diff Show More