mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-24 19:06:44 +00:00
GODT-1779: Remove go-imap
This commit is contained in:
@ -1,85 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Package api provides HTTP API of the Bridge.
|
||||
//
|
||||
// API endpoints:
|
||||
// - /focus, see focusHandler
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/ports"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var log = logrus.WithField("pkg", "api") //nolint:gochecknoglobals
|
||||
|
||||
type apiServer struct {
|
||||
host string
|
||||
settings *settings.Settings
|
||||
eventListener listener.Listener
|
||||
}
|
||||
|
||||
// NewAPIServer returns prepared API server struct.
|
||||
func NewAPIServer(settings *settings.Settings, eventListener listener.Listener) *apiServer { //nolint:revive
|
||||
return &apiServer{
|
||||
host: bridge.Host,
|
||||
settings: settings,
|
||||
eventListener: eventListener,
|
||||
}
|
||||
}
|
||||
|
||||
// Starts the server.
|
||||
func (api *apiServer) ListenAndServe() {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/focus", wrapper(api, focusHandler))
|
||||
|
||||
addr := api.getAddress()
|
||||
server := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 5 * time.Second, // fix gosec G112 (vulnerability to [Slowloris](https://www.cloudflare.com/en-gb/learning/ddos/ddos-attack-tools/slowloris/) attack).
|
||||
}
|
||||
|
||||
log.Info("API listening at ", addr)
|
||||
if err := server.ListenAndServe(); err != nil {
|
||||
api.eventListener.Emit(events.ErrorEvent, "API failed: "+err.Error())
|
||||
log.Error("API failed: ", err)
|
||||
}
|
||||
defer server.Close() //nolint:errcheck
|
||||
}
|
||||
|
||||
func (api *apiServer) getAddress() string {
|
||||
port := api.settings.GetInt(settings.APIPortKey)
|
||||
newPort := ports.FindFreePortFrom(port)
|
||||
if newPort != port {
|
||||
api.settings.SetInt(settings.APIPortKey, newPort)
|
||||
}
|
||||
return getAPIAddress(api.host, newPort)
|
||||
}
|
||||
|
||||
func getAPIAddress(host string, port int) string {
|
||||
return fmt.Sprintf("%s:%d", host, port)
|
||||
}
|
||||
@ -1,51 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
|
||||
)
|
||||
|
||||
// httpHandler with Go's Response and Request.
|
||||
type httpHandler func(http.ResponseWriter, *http.Request)
|
||||
|
||||
// handler with our context.
|
||||
type handler func(handlerContext) error
|
||||
|
||||
type handlerContext struct {
|
||||
req *http.Request
|
||||
resp http.ResponseWriter
|
||||
eventListener listener.Listener
|
||||
}
|
||||
|
||||
func wrapper(api *apiServer, callback handler) httpHandler {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := handlerContext{
|
||||
req: req,
|
||||
resp: w,
|
||||
eventListener: api.eventListener,
|
||||
}
|
||||
err := callback(ctx)
|
||||
if err != nil {
|
||||
log.Error("API callback of ", req.URL, " failed: ", err)
|
||||
http.Error(w, err.Error(), 500)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,51 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
)
|
||||
|
||||
// focusHandler should be called from other instances (attempt to start bridge
|
||||
// for the second time) to get focus in the currently running instance.
|
||||
func focusHandler(ctx handlerContext) error {
|
||||
log.Info("Focus from other instance")
|
||||
ctx.eventListener.Emit(events.SecondInstanceEvent, "")
|
||||
fmt.Fprintf(ctx.resp, "OK")
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckOtherInstanceAndFocus is helper for new instances to check if there is
|
||||
// already a running instance and get it's focus.
|
||||
func CheckOtherInstanceAndFocus(port int) error {
|
||||
addr := getAPIAddress(bridge.Host, port)
|
||||
resp, err := (&http.Client{}).Get("http://" + addr + "/focus")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close() //nolint:errcheck
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
log.Error("Focus error: ", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
149
internal/app/app.go
Normal file
149
internal/app/app.go
Normal file
@ -0,0 +1,149 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/crash"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/focus"
|
||||
bridgeCLI "github.com/ProtonMail/proton-bridge/v2/internal/frontend/cli"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/grpc"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/sentry"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/restarter"
|
||||
"github.com/pkg/profile"
|
||||
"github.com/urfave/cli/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
flagCPUProfile = "cpu-prof"
|
||||
flagCPUProfileShort = "p"
|
||||
|
||||
flagMemProfile = "mem-prof"
|
||||
flagMemProfileShort = "m"
|
||||
|
||||
flagLogLevel = "log-level"
|
||||
flagLogLevelShort = "l"
|
||||
|
||||
flagCLI = "cli"
|
||||
flagCLIShort = "c"
|
||||
|
||||
flagNoWindow = "no-window"
|
||||
flagNonInteractive = "non-interactive"
|
||||
)
|
||||
|
||||
const (
|
||||
appUsage = "Proton Mail IMAP and SMTP Bridge"
|
||||
)
|
||||
|
||||
func New() *cli.App {
|
||||
app := cli.NewApp()
|
||||
|
||||
app.Name = constants.FullAppName
|
||||
app.Usage = appUsage
|
||||
app.Flags = []cli.Flag{
|
||||
&cli.BoolFlag{
|
||||
Name: flagCPUProfile,
|
||||
Aliases: []string{flagCPUProfileShort},
|
||||
Usage: "Generate CPU profile",
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: flagMemProfile,
|
||||
Aliases: []string{flagMemProfileShort},
|
||||
Usage: "Generate memory profile",
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: flagLogLevel,
|
||||
Aliases: []string{flagLogLevelShort},
|
||||
Usage: "Set the log level (one of panic, fatal, error, warn, info, debug)",
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: flagCLI,
|
||||
Aliases: []string{flagCLIShort},
|
||||
Usage: "Use command line interface",
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: flagNoWindow,
|
||||
Usage: "Don't show window after start",
|
||||
Hidden: true,
|
||||
},
|
||||
}
|
||||
|
||||
app.Action = run
|
||||
|
||||
return app
|
||||
}
|
||||
|
||||
func run(c *cli.Context) error {
|
||||
// If there's another instance already running, try to raise it and exit.
|
||||
if raised := focus.TryRaise(); raised {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start CPU profile if requested.
|
||||
if c.Bool(flagCPUProfile) {
|
||||
p := profile.Start(profile.CPUProfile, profile.ProfilePath("cpu.pprof"))
|
||||
defer p.Stop()
|
||||
}
|
||||
|
||||
// Start memory profile if requested.
|
||||
if c.Bool(flagMemProfile) {
|
||||
p := profile.Start(profile.MemProfile, profile.MemProfileAllocs, profile.ProfilePath("mem.pprof"))
|
||||
defer p.Stop()
|
||||
}
|
||||
|
||||
// Create the restarter.
|
||||
restarter := restarter.New()
|
||||
defer restarter.Restart()
|
||||
|
||||
// Create a user agent that will be used for all requests.
|
||||
identifier := useragent.New()
|
||||
|
||||
// Create a crash handler that will send crash reports to sentry.
|
||||
crashHandler := crash.NewHandler(
|
||||
sentry.NewReporter(constants.FullAppName, constants.Version, identifier).ReportException,
|
||||
crash.ShowErrorNotification(constants.FullAppName),
|
||||
func(r interface{}) error { restarter.Set(true, true); return nil },
|
||||
)
|
||||
defer crashHandler.HandlePanic()
|
||||
|
||||
// Create a locations provider to determine where to store our files.
|
||||
provider, err := locations.NewDefaultProvider(filepath.Join(constants.VendorName, constants.ConfigName))
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create locations provider: %w", err)
|
||||
}
|
||||
|
||||
// Create a new locations object that will be used to provide paths to store files.
|
||||
locations := locations.New(provider, constants.ConfigName)
|
||||
|
||||
// Initialize the logging.
|
||||
if err := initLogging(c, locations, crashHandler); err != nil {
|
||||
return fmt.Errorf("could not initialize logging: %w", err)
|
||||
}
|
||||
|
||||
// Create the bridge.
|
||||
bridge, err := newBridge(locations, identifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create bridge: %w", err)
|
||||
}
|
||||
defer bridge.Close(c.Context)
|
||||
|
||||
// Start the frontend.
|
||||
switch {
|
||||
case c.Bool(flagCLI):
|
||||
return bridgeCLI.New(bridge).Loop()
|
||||
|
||||
case c.Bool(flagNonInteractive):
|
||||
select {}
|
||||
|
||||
default:
|
||||
service, err := grpc.NewService(crashHandler, restarter, locations, bridge, !c.Bool(flagNoWindow))
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create service: %w", err)
|
||||
}
|
||||
|
||||
return service.Loop()
|
||||
}
|
||||
}
|
||||
@ -1,35 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package base
|
||||
|
||||
import "strings"
|
||||
|
||||
// StripProcessSerialNumber removes additional flag from macOS.
|
||||
// More info:
|
||||
// http://mirror.informatimago.com/next/developer.apple.com/documentation/Carbon/Reference/Process_Manager/prmref_main/data_type_5.html#//apple_ref/doc/uid/TP30000208/C001951
|
||||
func StripProcessSerialNumber(args []string) []string {
|
||||
res := args[:0]
|
||||
|
||||
for _, arg := range args {
|
||||
if !strings.Contains(arg, "-psn_") {
|
||||
res = append(res, arg)
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
@ -1,424 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Package base implements a common application base currently shared by bridge and IE.
|
||||
// The base includes the following:
|
||||
// - access to standard filesystem locations like config, cache, logging dirs
|
||||
// - an extensible crash handler
|
||||
// - versioned cache directory
|
||||
// - persistent settings
|
||||
// - event listener
|
||||
// - credentials store
|
||||
// - pmapi Manager
|
||||
//
|
||||
// In addition, the base initialises logging and reacts to command line arguments
|
||||
// which control the log verbosity and enable cpu/memory profiling.
|
||||
package base
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/go-autostart"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/api"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/cache"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/useragent"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/cookies"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/crash"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/sentry"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/users/credentials"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/versioner"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/urfave/cli/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
flagCPUProfile = "cpu-prof"
|
||||
flagCPUProfileShort = "p"
|
||||
flagMemProfile = "mem-prof"
|
||||
flagMemProfileShort = "m"
|
||||
flagLogLevel = "log-level"
|
||||
flagLogLevelShort = "l"
|
||||
// FlagCLI indicate to start with command line interface.
|
||||
FlagCLI = "cli"
|
||||
flagCLIShort = "c"
|
||||
flagRestart = "restart"
|
||||
FlagLauncher = "launcher"
|
||||
FlagNoWindow = "no-window"
|
||||
)
|
||||
|
||||
type Base struct {
|
||||
SentryReporter *sentry.Reporter
|
||||
CrashHandler *crash.Handler
|
||||
Locations *locations.Locations
|
||||
Settings *settings.Settings
|
||||
Lock *os.File
|
||||
Cache *cache.Cache
|
||||
Listener listener.Listener
|
||||
Creds *credentials.Store
|
||||
CM pmapi.Manager
|
||||
CookieJar *cookies.Jar
|
||||
UserAgent *useragent.UserAgent
|
||||
Updater *updater.Updater
|
||||
Versioner *versioner.Versioner
|
||||
TLS *tls.TLS
|
||||
Autostart *autostart.App
|
||||
|
||||
Name string // the app's name
|
||||
usage string // the app's usage description
|
||||
command string // the command used to launch the app (either the exe path or the launcher path)
|
||||
restart bool // whether the app is currently set to restart
|
||||
launcher string // launcher to be used if not set in args
|
||||
mainExecutable string // mainExecutable the main executable process.
|
||||
|
||||
teardown []func() error // actions to perform when app is exiting
|
||||
}
|
||||
|
||||
func New( //nolint:funlen
|
||||
appName,
|
||||
appUsage,
|
||||
configName,
|
||||
updateURLName,
|
||||
keychainName,
|
||||
cacheVersion string,
|
||||
) (*Base, error) {
|
||||
userAgent := useragent.New()
|
||||
|
||||
sentryReporter := sentry.NewReporter(appName, constants.Version, userAgent)
|
||||
|
||||
crashHandler := crash.NewHandler(
|
||||
sentryReporter.ReportException,
|
||||
crash.ShowErrorNotification(appName),
|
||||
)
|
||||
defer crashHandler.HandlePanic()
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
os.Args = StripProcessSerialNumber(os.Args)
|
||||
|
||||
locationsProvider, err := locations.NewDefaultProvider(filepath.Join(constants.VendorName, configName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
locations := locations.New(locationsProvider, configName)
|
||||
|
||||
logsPath, err := locations.ProvideLogsPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := logging.Init(logsPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
crashHandler.AddRecoveryAction(logging.DumpStackTrace(logsPath))
|
||||
|
||||
if err := migrateFiles(configName); err != nil {
|
||||
logrus.WithError(err).Warn("Old config files could not be migrated")
|
||||
}
|
||||
|
||||
if err := locations.Clean(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settingsPath, err := locations.ProvideSettingsPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
settingsObj := settings.New(settingsPath)
|
||||
|
||||
lock, err := checkSingleInstance(locations.GetLockFile(), settingsObj)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Warnf("%v is already running", appName)
|
||||
return nil, api.CheckOtherInstanceAndFocus(settingsObj.GetInt(settings.APIPortKey))
|
||||
}
|
||||
|
||||
if err := migrateRebranding(settingsObj, keychainName); err != nil {
|
||||
logrus.WithError(err).Warn("Rebranding migration failed")
|
||||
}
|
||||
|
||||
cachePath, err := locations.ProvideCachePath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cache, err := cache.New(cachePath, cacheVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := cache.RemoveOldVersions(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
listener := listener.New()
|
||||
events.SetupEvents(listener)
|
||||
|
||||
// If we can't load the keychain for whatever reason,
|
||||
// we signal to frontend and supply a dummy keychain that always returns errors.
|
||||
kc, err := keychain.NewKeychain(settingsObj, keychainName)
|
||||
if err != nil {
|
||||
listener.Emit(events.CredentialsErrorEvent, err.Error())
|
||||
kc = keychain.NewMissingKeychain()
|
||||
}
|
||||
|
||||
cfg := pmapi.NewConfig(configName, constants.Version)
|
||||
cfg.GetUserAgent = userAgent.String
|
||||
cfg.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
|
||||
cfg.TLSIssueHandler = func() { listener.Emit(events.TLSCertIssue, "") }
|
||||
|
||||
cm := pmapi.New(cfg)
|
||||
|
||||
sentryReporter.SetClientFromManager(cm)
|
||||
|
||||
cm.AddConnectionObserver(pmapi.NewConnectionObserver(
|
||||
func() { listener.Emit(events.InternetConnChangedEvent, events.InternetOff) },
|
||||
func() { listener.Emit(events.InternetConnChangedEvent, events.InternetOn) },
|
||||
))
|
||||
|
||||
jar, err := cookies.NewCookieJar(settingsObj)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cm.SetCookieJar(jar)
|
||||
|
||||
key, err := crypto.NewKeyFromArmored(updater.DefaultPublicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kr, err := crypto.NewKeyRing(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updatesDir, err := locations.ProvideUpdatesPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
versioner := versioner.New(updatesDir)
|
||||
installer := updater.NewInstaller(versioner)
|
||||
updater := updater.New(
|
||||
cm,
|
||||
installer,
|
||||
settingsObj,
|
||||
kr,
|
||||
semver.MustParse(constants.Version),
|
||||
updateURLName,
|
||||
runtime.GOOS,
|
||||
)
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
autostart := &autostart.App{
|
||||
Name: startupNameForRebranding(appName),
|
||||
DisplayName: appName,
|
||||
Exec: []string{exe, "--" + FlagNoWindow},
|
||||
}
|
||||
|
||||
return &Base{
|
||||
SentryReporter: sentryReporter,
|
||||
CrashHandler: crashHandler,
|
||||
Locations: locations,
|
||||
Settings: settingsObj,
|
||||
Lock: lock,
|
||||
Cache: cache,
|
||||
Listener: listener,
|
||||
Creds: credentials.NewStore(kc),
|
||||
CM: cm,
|
||||
CookieJar: jar,
|
||||
UserAgent: userAgent,
|
||||
Updater: updater,
|
||||
Versioner: versioner,
|
||||
TLS: tls.New(settingsPath),
|
||||
Autostart: autostart,
|
||||
|
||||
Name: appName,
|
||||
usage: appUsage,
|
||||
|
||||
// By default, the command is the app's executable.
|
||||
// This can be changed at runtime by using the "--launcher" flag.
|
||||
command: exe,
|
||||
// By default, the command is the app's executable.
|
||||
// This can be changed at runtime by summoning the SetMainExecutable gRPC call.
|
||||
mainExecutable: exe,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *Base) NewApp(mainLoop func(*Base, *cli.Context) error) *cli.App {
|
||||
app := cli.NewApp()
|
||||
|
||||
app.Name = b.Name
|
||||
app.Usage = b.usage
|
||||
app.Version = constants.Version
|
||||
app.Action = b.wrapMainLoop(mainLoop)
|
||||
app.Flags = []cli.Flag{
|
||||
&cli.BoolFlag{
|
||||
Name: flagCPUProfile,
|
||||
Aliases: []string{flagCPUProfileShort},
|
||||
Usage: "Generate CPU profile",
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: flagMemProfile,
|
||||
Aliases: []string{flagMemProfileShort},
|
||||
Usage: "Generate memory profile",
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: flagLogLevel,
|
||||
Aliases: []string{flagLogLevelShort},
|
||||
Usage: "Set the log level (one of panic, fatal, error, warn, info, debug)",
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: FlagCLI,
|
||||
Aliases: []string{flagCLIShort},
|
||||
Usage: "Use command line interface",
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: FlagNoWindow,
|
||||
Usage: "Don't show window after start",
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: flagRestart,
|
||||
Usage: "The number of times the application has already restarted",
|
||||
Hidden: true,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: FlagLauncher,
|
||||
Usage: "The launcher to use to restart the application",
|
||||
Hidden: true,
|
||||
},
|
||||
}
|
||||
|
||||
return app
|
||||
}
|
||||
|
||||
// SetToRestart sets the app to restart the next time it is closed.
|
||||
func (b *Base) SetToRestart() {
|
||||
b.restart = true
|
||||
}
|
||||
|
||||
func (b *Base) ForceLauncher(launcher string) {
|
||||
b.launcher = launcher
|
||||
b.setupLauncher(launcher)
|
||||
}
|
||||
|
||||
func (b *Base) SetMainExecutable(exe string) {
|
||||
logrus.Info("Main Executable set to ", exe)
|
||||
b.mainExecutable = exe
|
||||
}
|
||||
|
||||
// AddTeardownAction adds an action to perform during app teardown.
|
||||
func (b *Base) AddTeardownAction(fn func() error) {
|
||||
b.teardown = append(b.teardown, fn)
|
||||
}
|
||||
|
||||
func (b *Base) wrapMainLoop(appMainLoop func(*Base, *cli.Context) error) cli.ActionFunc { //nolint:funlen
|
||||
return func(c *cli.Context) error {
|
||||
defer b.CrashHandler.HandlePanic()
|
||||
defer func() { _ = b.Lock.Close() }()
|
||||
|
||||
// If launcher was used to start the app, use that for restart
|
||||
// and autostart.
|
||||
if launcher := c.String(FlagLauncher); launcher != "" {
|
||||
b.setupLauncher(launcher)
|
||||
}
|
||||
|
||||
if c.Bool(flagCPUProfile) {
|
||||
startCPUProfile()
|
||||
defer pprof.StopCPUProfile()
|
||||
}
|
||||
|
||||
if c.Bool(flagMemProfile) {
|
||||
defer makeMemoryProfile()
|
||||
}
|
||||
|
||||
logging.SetLevel(c.String(flagLogLevel))
|
||||
b.CM.SetLogging(logrus.WithField("pkg", "pmapi"), logrus.GetLevel() == logrus.TraceLevel)
|
||||
|
||||
logrus.
|
||||
WithField("appName", b.Name).
|
||||
WithField("version", constants.Version).
|
||||
WithField("revision", constants.Revision).
|
||||
WithField("build", constants.BuildTime).
|
||||
WithField("runtime", runtime.GOOS).
|
||||
WithField("args", os.Args).
|
||||
Info("Run app")
|
||||
|
||||
b.CrashHandler.AddRecoveryAction(func(interface{}) error {
|
||||
sentry.Flush(2 * time.Second)
|
||||
|
||||
if c.Int(flagRestart) > maxAllowedRestarts {
|
||||
logrus.
|
||||
WithField("restart", c.Int("restart")).
|
||||
Warn("Not restarting, already restarted too many times")
|
||||
os.Exit(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return b.restartApp(true)
|
||||
})
|
||||
|
||||
if err := appMainLoop(b, c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := b.doTeardown(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if b.restart {
|
||||
return b.restartApp(false)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Base) doTeardown() error {
|
||||
for _, action := range b.teardown {
|
||||
if err := action(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Base) setupLauncher(launcher string) {
|
||||
b.command = launcher
|
||||
// Bridge supports no-window option which we should use
|
||||
// for autostart.
|
||||
b.Autostart.Exec = []string{launcher, "--" + FlagNoWindow}
|
||||
}
|
||||
@ -1,131 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package base
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// migrateFiles migrates files from their old (pre-refactor) locations to their new locations.
|
||||
// We can remove this eventually.
|
||||
//
|
||||
// | entity | old location | new location |
|
||||
// |-----------|-------------------------------------------|----------------------------------------|
|
||||
// | prefs | ~/.cache/protonmail/<app>/c11/prefs.json | ~/.config/protonmail/<app>/prefs.json |
|
||||
// | c11 1.5.x | ~/.cache/protonmail/<app>/c11 | ~/.cache/protonmail/<app>/cache/c11 |
|
||||
// | c11 1.6.x | ~/.cache/protonmail/<app>/cache/c11 | ~/.config/protonmail/<app>/cache/c11 |
|
||||
// | updates | ~/.cache/protonmail/<app>/updates | ~/.config/protonmail/<app>/updates |.
|
||||
func migrateFiles(configName string) error {
|
||||
locationsProvider, err := locations.NewDefaultProvider(filepath.Join(constants.VendorName, configName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
locations := locations.New(locationsProvider, configName)
|
||||
userCacheDir := locationsProvider.UserCache()
|
||||
|
||||
if err := migratePrefsFrom15x(locations, userCacheDir); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := migrateCacheFromBoth15xAnd16x(locations, userCacheDir); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := migrateUpdatesFrom16x(configName, locations); err != nil { //nolint:revive It is more clear to structure this way
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migratePrefsFrom15x(locations *locations.Locations, userCacheDir string) error {
|
||||
newSettingsDir, err := locations.ProvideSettingsPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return moveIfExists(
|
||||
filepath.Join(userCacheDir, "c11", "prefs.json"),
|
||||
filepath.Join(newSettingsDir, "prefs.json"),
|
||||
)
|
||||
}
|
||||
|
||||
func migrateCacheFromBoth15xAnd16x(locations *locations.Locations, userCacheDir string) error {
|
||||
olderCacheDir := userCacheDir
|
||||
newerCacheDir := locations.GetOldCachePath()
|
||||
latestCacheDir, err := locations.ProvideCachePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Migration for versions before 1.6.x.
|
||||
if err := moveIfExists(
|
||||
filepath.Join(olderCacheDir, "c11"),
|
||||
filepath.Join(latestCacheDir, "c11"),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Migration for versions 1.6.x.
|
||||
return moveIfExists(
|
||||
filepath.Join(newerCacheDir, "c11"),
|
||||
filepath.Join(latestCacheDir, "c11"),
|
||||
)
|
||||
}
|
||||
|
||||
func migrateUpdatesFrom16x(configName string, locations *locations.Locations) error {
|
||||
// In order to properly update Bridge 1.6.X and higher we need to
|
||||
// change the launcher first. Since this is not part of automatic
|
||||
// updates the migration must wait until manual update. Until that
|
||||
// we need to keep old path.
|
||||
if configName == "bridge" {
|
||||
return nil
|
||||
}
|
||||
|
||||
oldUpdatesPath := locations.GetOldUpdatesPath()
|
||||
// Do not use ProvideUpdatesPath, that creates dir right away.
|
||||
newUpdatesPath := locations.GetUpdatesPath()
|
||||
|
||||
return moveIfExists(oldUpdatesPath, newUpdatesPath)
|
||||
}
|
||||
|
||||
func moveIfExists(source, destination string) error {
|
||||
l := logrus.WithField("source", source).WithField("destination", destination)
|
||||
|
||||
if _, err := os.Stat(source); os.IsNotExist(err) {
|
||||
l.Info("No need to migrate file, source doesn't exist")
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := os.Stat(destination); !os.IsNotExist(err) {
|
||||
// Once migrated, files should not stay in source anymore. Therefore
|
||||
// if some files are still in source location but target already exist,
|
||||
// it's suspicious. Could happen by installing new version, then the
|
||||
// old one because of some reason, and then the new one again.
|
||||
// Good to see as warning because it could be a reason why Bridge is
|
||||
// behaving weirdly, like wrong configuration, or db re-sync and so on.
|
||||
l.Warn("No need to migrate file, target already exists")
|
||||
return nil
|
||||
}
|
||||
|
||||
l.Info("Migrating files")
|
||||
return os.Rename(source, destination)
|
||||
}
|
||||
@ -1,197 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package base
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const darwin = "darwin"
|
||||
|
||||
func migrateRebranding(settingsObj *settings.Settings, keychainName string) (result error) {
|
||||
if err := migrateStartupBeforeRebranding(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
lastUsedVersion := settingsObj.Get(settings.LastVersionKey)
|
||||
|
||||
// Skipping migration: it is first bridge start or cache was cleared.
|
||||
if lastUsedVersion == "" {
|
||||
settingsObj.SetBool(settings.RebrandingMigrationKey, true)
|
||||
return
|
||||
}
|
||||
|
||||
// Skipping rest of migration: already done
|
||||
if settingsObj.GetBool(settings.RebrandingMigrationKey) {
|
||||
return
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows", "linux":
|
||||
// GODT-1260 we would need admin rights to changes desktop files
|
||||
// and start menu items.
|
||||
settingsObj.SetBool(settings.RebrandingMigrationKey, true)
|
||||
case darwin:
|
||||
if shouldContinue, err := isMacBeforeRebranding(); !shouldContinue || err != nil {
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if err := migrateMacKeychainBeforeRebranding(settingsObj, keychainName); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
settingsObj.SetBool(settings.RebrandingMigrationKey, true)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// migrateMacKeychainBeforeRebranding deals with write access restriction to
|
||||
// mac keychain passwords which are caused by application renaming. The old
|
||||
// passwords are copied under new name in order to have write access afer
|
||||
// renaming.
|
||||
func migrateMacKeychainBeforeRebranding(settingsObj *settings.Settings, keychainName string) error {
|
||||
l := logrus.WithField("pkg", "app/base/migration")
|
||||
l.Warn("Migrating mac keychain")
|
||||
|
||||
helperConstructor, ok := keychain.Helpers["macos-keychain"]
|
||||
if !ok {
|
||||
return errors.New("cannot find macos-keychain helper")
|
||||
}
|
||||
|
||||
oldKC, err := helperConstructor("ProtonMailBridgeService")
|
||||
if err != nil {
|
||||
l.WithError(err).Error("Keychain constructor failed")
|
||||
return err
|
||||
}
|
||||
|
||||
idByURL, err := oldKC.List()
|
||||
if err != nil {
|
||||
l.WithError(err).Error("List old keychain failed")
|
||||
return err
|
||||
}
|
||||
|
||||
newKC, err := keychain.NewKeychain(settingsObj, keychainName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for url, id := range idByURL {
|
||||
li := l.WithField("id", id).WithField("url", url)
|
||||
userID, secret, err := oldKC.Get(url)
|
||||
if err != nil {
|
||||
li.WithField("userID", userID).
|
||||
WithField("err", err).
|
||||
Error("Faild to get old item")
|
||||
continue
|
||||
}
|
||||
|
||||
if _, _, err := newKC.Get(userID); err == nil {
|
||||
li.Warn("Skipping migration, item already exists.")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := newKC.Put(userID, secret); err != nil {
|
||||
li.WithError(err).Error("Failed to migrate user")
|
||||
}
|
||||
|
||||
li.Info("Item migrated")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateStartupBeforeRebranding removes old startup links. The creation of new links is
|
||||
// handled by bridge initialisation.
|
||||
func migrateStartupBeforeRebranding() error {
|
||||
path, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
path = filepath.Join(path, `AppData\Roaming\Microsoft\Windows\Start Menu\Programs\Startup\ProtonMail Bridge.lnk`)
|
||||
case "linux":
|
||||
path = filepath.Join(path, `.config/autostart/ProtonMail Bridge.desktop`)
|
||||
case darwin:
|
||||
path = filepath.Join(path, `Library/LaunchAgents/ProtonMail Bridge.plist`)
|
||||
default:
|
||||
return errors.New("unknown GOOS")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
|
||||
logrus.WithField("pkg", "app/base/migration").Warn("Migrating autostartup links")
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
// startupNameForRebranding returns the name for autostart launcher based on
|
||||
// type of rebranded instance i.e. update or manual.
|
||||
//
|
||||
// This only affects darwin when udpate re-writes the old startup and then
|
||||
// manual installed it would not run proper exe. Therefore we return "old" name
|
||||
// for updates and "new" name for manual which would be properly migrated.
|
||||
//
|
||||
// For orther (linux and windows) the link is always pointing to launcher which
|
||||
// path didn't changed.
|
||||
func startupNameForRebranding(origin string) string {
|
||||
if runtime.GOOS == darwin {
|
||||
if path, err := os.Executable(); err == nil && strings.Contains(path, "ProtonMail Bridge") {
|
||||
return "ProtonMail Bridge"
|
||||
}
|
||||
}
|
||||
|
||||
// No need to solve for other OS. See comment above.
|
||||
return origin
|
||||
}
|
||||
|
||||
// isBeforeRebranding decide if last used version was older than 2.2.0. If
|
||||
// cannot decide it returns false with error.
|
||||
func isMacBeforeRebranding() (bool, error) {
|
||||
// previous version | update | do mac migration |
|
||||
// | first | false |
|
||||
// cleared-cache | manual | false |
|
||||
// cleared-cache | in-app | false |
|
||||
// old | in-app | false |
|
||||
// old in-app | in-app | false |
|
||||
// old | manual | true |
|
||||
// old in-app | manual | true |
|
||||
// manual | in-app | false |
|
||||
|
||||
// Skip if it was in-app update and not manual
|
||||
if path, err := os.Executable(); err != nil || strings.Contains(path, "ProtonMail Bridge") {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
@ -1,56 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package base
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// startCPUProfile starts CPU pprof.
|
||||
func startCPUProfile() {
|
||||
f, err := os.Create("./cpu.pprof")
|
||||
if err != nil {
|
||||
logrus.Fatal("Could not create CPU profile: ", err)
|
||||
}
|
||||
if err := pprof.StartCPUProfile(f); err != nil {
|
||||
logrus.Fatal("Could not start CPU profile: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
// makeMemoryProfile generates memory pprof.
|
||||
func makeMemoryProfile() {
|
||||
name := "./mem.pprof"
|
||||
f, err := os.Create(name)
|
||||
if err != nil {
|
||||
logrus.Fatal("Could not create memory profile: ", err)
|
||||
}
|
||||
if abs, err := filepath.Abs(name); err == nil {
|
||||
name = abs
|
||||
}
|
||||
logrus.Info("Writing memory profile to ", name)
|
||||
runtime.GC() // get up-to-date statistics
|
||||
if err := pprof.WriteHeapProfile(f); err != nil {
|
||||
logrus.Fatal("Could not write memory profile: ", err)
|
||||
}
|
||||
_ = f.Close()
|
||||
}
|
||||
@ -1,112 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package base
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/execabs"
|
||||
)
|
||||
|
||||
// maxAllowedRestarts controls after how many crashes the app will give up restarting.
|
||||
const maxAllowedRestarts = 10
|
||||
|
||||
func (b *Base) restartApp(crash bool) error {
|
||||
var args []string
|
||||
|
||||
if crash {
|
||||
args = incrementRestartFlag(os.Args)[1:]
|
||||
defer func() { os.Exit(1) }()
|
||||
} else {
|
||||
args = os.Args[1:]
|
||||
}
|
||||
|
||||
if b.launcher != "" {
|
||||
args = forceLauncherFlag(args, b.launcher)
|
||||
}
|
||||
|
||||
args = append(args, "--wait", b.mainExecutable)
|
||||
|
||||
logrus.
|
||||
WithField("command", b.command).
|
||||
WithField("args", args).
|
||||
Warn("Restarting")
|
||||
|
||||
return execabs.Command(b.command, args...).Start() //nolint:gosec
|
||||
}
|
||||
|
||||
// incrementRestartFlag increments the value of the restart flag.
|
||||
// If no such flag is present, it is added with initial value 1.
|
||||
func incrementRestartFlag(args []string) []string {
|
||||
res := append([]string{}, args...)
|
||||
|
||||
hasFlag := false
|
||||
|
||||
for k, v := range res {
|
||||
if v != "--restart" {
|
||||
continue
|
||||
}
|
||||
|
||||
hasFlag = true
|
||||
|
||||
if k+1 >= len(res) {
|
||||
continue
|
||||
}
|
||||
|
||||
n, err := strconv.Atoi(res[k+1])
|
||||
if err != nil {
|
||||
res[k+1] = "1"
|
||||
} else {
|
||||
res[k+1] = strconv.Itoa(n + 1)
|
||||
}
|
||||
}
|
||||
|
||||
if !hasFlag {
|
||||
res = append(res, "--restart", "1")
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// forceLauncherFlag replace or add the launcher args with the one set in the app.
|
||||
func forceLauncherFlag(args []string, launcher string) []string {
|
||||
res := append([]string{}, args...)
|
||||
|
||||
hasFlag := false
|
||||
|
||||
for k, v := range res {
|
||||
if v != "--launcher" {
|
||||
continue
|
||||
}
|
||||
|
||||
if k+1 >= len(res) {
|
||||
continue
|
||||
}
|
||||
|
||||
hasFlag = true
|
||||
res[k+1] = launcher
|
||||
}
|
||||
|
||||
if !hasFlag {
|
||||
res = append(res, "--launcher", launcher)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
@ -1,63 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package base
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIncrementRestartFlag(t *testing.T) {
|
||||
tests := []struct {
|
||||
in []string
|
||||
out []string
|
||||
}{
|
||||
{[]string{"./bridge", "--restart", "1"}, []string{"./bridge", "--restart", "2"}},
|
||||
{[]string{"./bridge", "--restart", "2"}, []string{"./bridge", "--restart", "3"}},
|
||||
{[]string{"./bridge", "--other", "--restart", "2"}, []string{"./bridge", "--other", "--restart", "3"}},
|
||||
{[]string{"./bridge", "--restart", "2", "--other"}, []string{"./bridge", "--restart", "3", "--other"}},
|
||||
{[]string{"./bridge", "--restart", "2", "--other", "2"}, []string{"./bridge", "--restart", "3", "--other", "2"}},
|
||||
{[]string{"./bridge"}, []string{"./bridge", "--restart", "1"}},
|
||||
{[]string{"./bridge", "--something"}, []string{"./bridge", "--something", "--restart", "1"}},
|
||||
{[]string{"./bridge", "--something", "--else"}, []string{"./bridge", "--something", "--else", "--restart", "1"}},
|
||||
{[]string{"./bridge", "--restart", "bad"}, []string{"./bridge", "--restart", "1"}},
|
||||
{[]string{"./bridge", "--restart", "bad", "--other"}, []string{"./bridge", "--restart", "1", "--other"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(strings.Join(tt.in, " "), func(t *testing.T) {
|
||||
assert.Equal(t, tt.out, incrementRestartFlag(tt.in))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersionLessThan(t *testing.T) {
|
||||
r := require.New(t)
|
||||
|
||||
old := semver.MustParse("1.1.0")
|
||||
current := semver.MustParse("1.1.1")
|
||||
newer := semver.MustParse("1.1.2")
|
||||
|
||||
r.True(old.LessThan(current))
|
||||
r.False(current.LessThan(current))
|
||||
r.False(newer.LessThan(current))
|
||||
}
|
||||
@ -1,101 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package base
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/allan-simon/go-singleinstance"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// checkSingleInstance returns error if a bridge instance is already running
|
||||
// This instance should be stop and window of running window should be brought
|
||||
// to focus.
|
||||
//
|
||||
// For macOS and Linux when already running version is older than this instance
|
||||
// it will kill old and continue with this new bridge (i.e. no error returned).
|
||||
func checkSingleInstance(lockFilePath string, settingsObj *settings.Settings) (*os.File, error) {
|
||||
if lock, err := singleinstance.CreateLockFile(lockFilePath); err == nil {
|
||||
// Bridge is not runnig, continue normally
|
||||
return lock, nil
|
||||
}
|
||||
|
||||
if err := runningVersionIsOlder(settingsObj); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pid, err := getPID(lockFilePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := unix.Kill(pid, unix.SIGTERM); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Need to wait some time to release file lock
|
||||
time.Sleep(time.Second)
|
||||
|
||||
return singleinstance.CreateLockFile(lockFilePath)
|
||||
}
|
||||
|
||||
func getPID(lockFilePath string) (int, error) {
|
||||
file, err := os.Open(filepath.Clean(lockFilePath))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
rawPID := make([]byte, 10) // PID is probably up to 7 digits long, 10 should be enough
|
||||
n, err := file.Read(rawPID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return strconv.Atoi(strings.TrimSpace(string(rawPID[:n])))
|
||||
}
|
||||
|
||||
func runningVersionIsOlder(settingsObj *settings.Settings) error {
|
||||
currentVer, err := semver.StrictNewVersion(constants.Version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
runningVer, err := semver.StrictNewVersion(settingsObj.Get(settings.LastVersionKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !runningVer.LessThan(currentVer) {
|
||||
return errors.New("running version is not older")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,32 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package base
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/allan-simon/go-singleinstance"
|
||||
)
|
||||
|
||||
func checkSingleInstance(lockFilePath string, _ *settings.Settings) (*os.File, error) {
|
||||
return singleinstance.CreateLockFile(lockFilePath)
|
||||
}
|
||||
205
internal/app/bridge.go
Normal file
205
internal/app/bridge.go
Normal file
@ -0,0 +1,205 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/go-autostart"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/dialer"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/versioner"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
const vaultSecretName = "bridge-vault-key"
|
||||
|
||||
func newBridge(locations *locations.Locations, identifier *useragent.UserAgent) (*bridge.Bridge, error) {
|
||||
// Create the underlying dialer used by the bridge.
|
||||
// It only connects to trusted servers and reports any untrusted servers it finds.
|
||||
pinningDialer := dialer.NewPinningTLSDialer(
|
||||
dialer.NewBasicTLSDialer(constants.APIHost),
|
||||
dialer.NewTLSReporter(constants.APIHost, constants.AppVersion, identifier, dialer.TrustedAPIPins),
|
||||
dialer.NewTLSPinChecker(dialer.TrustedAPIPins),
|
||||
)
|
||||
|
||||
// Create a proxy dialer which switches to a proxy if the request fails.
|
||||
proxyDialer := dialer.NewProxyTLSDialer(pinningDialer, constants.APIHost)
|
||||
|
||||
// Create the autostarter.
|
||||
autostarter, err := newAutostarter()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create autostarter: %w", err)
|
||||
}
|
||||
|
||||
// Create the update installer.
|
||||
updater, err := newUpdater(locations)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create updater: %w", err)
|
||||
}
|
||||
|
||||
// Get the current bridge version.
|
||||
version, err := semver.NewVersion(constants.Version)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create version: %w", err)
|
||||
}
|
||||
|
||||
// Create the encVault.
|
||||
encVault, insecure, corrupt, err := newVault(locations)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create vault: %w", err)
|
||||
} else if insecure {
|
||||
logrus.Warn("The vault key could not be retrieved; the vault will not be encrypted")
|
||||
} else if corrupt {
|
||||
logrus.Warn("The vault is corrupt and has been wiped")
|
||||
}
|
||||
|
||||
// Install the certificates if needed.
|
||||
if installed := encVault.GetCertsInstalled(); !installed {
|
||||
if err := certs.NewInstaller().InstallCert(encVault.GetBridgeTLSCert()); err != nil {
|
||||
return nil, fmt.Errorf("failed to install certs: %w", err)
|
||||
}
|
||||
|
||||
if err := encVault.SetCertsInstalled(true); err != nil {
|
||||
return nil, fmt.Errorf("failed to set certs installed: %w", err)
|
||||
}
|
||||
|
||||
if err := encVault.SetCertsInstalled(true); err != nil {
|
||||
return nil, fmt.Errorf("could not set certs installed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new bridge.
|
||||
bridge, err := bridge.New(constants.APIHost, locations, encVault, identifier, pinningDialer, proxyDialer, autostarter, updater, version)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create bridge: %w", err)
|
||||
}
|
||||
|
||||
// If the vault could not be loaded properly, push errors to the bridge.
|
||||
switch {
|
||||
case insecure:
|
||||
bridge.PushError(vault.ErrInsecure)
|
||||
|
||||
case corrupt:
|
||||
bridge.PushError(vault.ErrCorrupt)
|
||||
}
|
||||
|
||||
return bridge, nil
|
||||
}
|
||||
|
||||
func newVault(locations *locations.Locations) (*vault.Vault, bool, bool, error) {
|
||||
var insecure bool
|
||||
|
||||
vaultDir, err := locations.ProvideSettingsPath()
|
||||
if err != nil {
|
||||
return nil, false, false, fmt.Errorf("could not get vault dir: %w", err)
|
||||
}
|
||||
|
||||
var vaultKey []byte
|
||||
|
||||
if key, err := getVaultKey(vaultDir); err != nil {
|
||||
insecure = true
|
||||
} else {
|
||||
vaultKey = key
|
||||
}
|
||||
|
||||
gluonDir, err := locations.ProvideGluonPath()
|
||||
if err != nil {
|
||||
return nil, false, false, fmt.Errorf("could not provide gluon path: %w", err)
|
||||
}
|
||||
|
||||
vault, corrupt, err := vault.New(vaultDir, gluonDir, vaultKey)
|
||||
if err != nil {
|
||||
return nil, false, false, fmt.Errorf("could not create vault: %w", err)
|
||||
}
|
||||
|
||||
return vault, insecure, corrupt, nil
|
||||
}
|
||||
|
||||
func getVaultKey(vaultDir string) ([]byte, error) {
|
||||
helper, err := vault.GetHelper(vaultDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get keychain helper: %w", err)
|
||||
}
|
||||
|
||||
keychain, err := keychain.NewKeychain(helper, constants.KeyChainName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create keychain: %w", err)
|
||||
}
|
||||
|
||||
secrets, err := keychain.List()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not list keychain: %w", err)
|
||||
}
|
||||
|
||||
if !slices.Contains(secrets, vaultSecretName) {
|
||||
tok, err := crypto.RandomToken(32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not generate random token: %w", err)
|
||||
}
|
||||
|
||||
if err := keychain.Put(vaultSecretName, base64.StdEncoding.EncodeToString(tok)); err != nil {
|
||||
return nil, fmt.Errorf("could not put keychain item: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, keyEnc, err := keychain.Get(vaultSecretName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get keychain item: %w", err)
|
||||
}
|
||||
|
||||
keyDec, err := base64.StdEncoding.DecodeString(keyEnc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not decode keychain item: %w", err)
|
||||
}
|
||||
|
||||
return keyDec, nil
|
||||
}
|
||||
|
||||
func newAutostarter() (*autostart.App, error) {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &autostart.App{
|
||||
Name: constants.FullAppName,
|
||||
DisplayName: constants.FullAppName,
|
||||
Exec: []string{exe, "--" + flagNoWindow},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newUpdater(locations *locations.Locations) (*updater.Updater, error) {
|
||||
updatesDir, err := locations.ProvideUpdatesPath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not provide updates path: %w", err)
|
||||
}
|
||||
|
||||
key, err := crypto.NewKeyFromArmored(updater.DefaultPublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create key from armored: %w", err)
|
||||
}
|
||||
|
||||
verifier, err := crypto.NewKeyRing(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create key ring: %w", err)
|
||||
}
|
||||
|
||||
return updater.NewUpdater(
|
||||
updater.NewInstaller(versioner.New(updatesDir)),
|
||||
verifier,
|
||||
constants.UpdateName,
|
||||
runtime.GOOS,
|
||||
), nil
|
||||
}
|
||||
@ -1,269 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Package bridge implements the bridge CLI application.
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/api"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/app/base"
|
||||
pkgBridge "github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/frontend"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/types"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/imap"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/smtp"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/store"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/store/cache"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/urfave/cli/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
flagLogIMAP = "log-imap"
|
||||
flagLogSMTP = "log-smtp"
|
||||
flagNonInteractive = "noninteractive"
|
||||
|
||||
// Memory cache was estimated by empirical usage in past and it was set to 100MB.
|
||||
// NOTE: This value must not be less than maximal size of one email (~30MB).
|
||||
inMemoryCacheLimnit = 100 * (1 << 20)
|
||||
)
|
||||
|
||||
func New(base *base.Base) *cli.App {
|
||||
app := base.NewApp(main)
|
||||
|
||||
app.Flags = append(app.Flags, []cli.Flag{
|
||||
&cli.StringFlag{
|
||||
Name: flagLogIMAP,
|
||||
Usage: "Enable logging of IMAP communications (all|client|server) (may contain decrypted data!)",
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: flagLogSMTP,
|
||||
Usage: "Enable logging of SMTP communications (may contain decrypted data!)",
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: flagNonInteractive,
|
||||
Usage: "Start Bridge entirely noninteractively",
|
||||
},
|
||||
}...)
|
||||
|
||||
return app
|
||||
}
|
||||
|
||||
func main(b *base.Base, c *cli.Context) error { //nolint:funlen
|
||||
cache, cacheErr := loadMessageCache(b)
|
||||
if cacheErr != nil {
|
||||
logrus.WithError(cacheErr).Error("Could not load local cache.")
|
||||
}
|
||||
|
||||
builder := message.NewBuilder(
|
||||
b.Settings.GetInt(settings.FetchWorkers),
|
||||
b.Settings.GetInt(settings.AttachmentWorkers),
|
||||
)
|
||||
|
||||
bridge := pkgBridge.New(
|
||||
b.Locations,
|
||||
b.Cache,
|
||||
b.Settings,
|
||||
b.SentryReporter,
|
||||
b.CrashHandler,
|
||||
b.Listener,
|
||||
b.TLS,
|
||||
b.UserAgent,
|
||||
cache,
|
||||
builder,
|
||||
b.CM,
|
||||
b.Creds,
|
||||
b.Updater,
|
||||
b.Versioner,
|
||||
b.Autostart,
|
||||
)
|
||||
imapBackend := imap.NewIMAPBackend(b.CrashHandler, b.Listener, b.Cache, b.Settings, bridge)
|
||||
smtpBackend := smtp.NewSMTPBackend(b.CrashHandler, b.Listener, b.Settings, bridge)
|
||||
|
||||
tlsConfig, err := bridge.GetTLSConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cacheErr != nil {
|
||||
bridge.AddError(pkgBridge.ErrLocalCacheUnavailable)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer b.CrashHandler.HandlePanic()
|
||||
api.NewAPIServer(b.Settings, b.Listener).ListenAndServe()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer b.CrashHandler.HandlePanic()
|
||||
imapPort := b.Settings.GetInt(settings.IMAPPortKey)
|
||||
imap.NewIMAPServer(
|
||||
b.CrashHandler,
|
||||
c.String(flagLogIMAP) == "client" || c.String(flagLogIMAP) == "all",
|
||||
c.String(flagLogIMAP) == "server" || c.String(flagLogIMAP) == "all",
|
||||
imapPort, tlsConfig, imapBackend, b.UserAgent, b.Listener).ListenAndServe()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer b.CrashHandler.HandlePanic()
|
||||
smtpPort := b.Settings.GetInt(settings.SMTPPortKey)
|
||||
useSSL := b.Settings.GetBool(settings.SMTPSSLKey)
|
||||
smtp.NewSMTPServer(
|
||||
b.CrashHandler,
|
||||
c.Bool(flagLogSMTP),
|
||||
smtpPort, useSSL, tlsConfig, smtpBackend, b.Listener).ListenAndServe()
|
||||
}()
|
||||
|
||||
// We want to remove old versions if the app exits successfully.
|
||||
b.AddTeardownAction(b.Versioner.RemoveOldVersions)
|
||||
|
||||
// We want cookies to be saved to disk so they are loaded the next time.
|
||||
b.AddTeardownAction(b.CookieJar.PersistCookies)
|
||||
|
||||
var frontendMode string
|
||||
|
||||
switch {
|
||||
case c.Bool(base.FlagCLI):
|
||||
frontendMode = "cli"
|
||||
case c.Bool(flagNonInteractive):
|
||||
return <-(make(chan error)) // Block forever.
|
||||
default:
|
||||
frontendMode = "grpc"
|
||||
}
|
||||
|
||||
f := frontend.New(
|
||||
frontendMode,
|
||||
!c.Bool(base.FlagNoWindow),
|
||||
b.CrashHandler,
|
||||
b.Listener,
|
||||
b.Updater,
|
||||
bridge,
|
||||
b,
|
||||
b.Locations,
|
||||
)
|
||||
|
||||
// Watch for updates routine
|
||||
go func() {
|
||||
ticker := time.NewTicker(constants.UpdateCheckInterval)
|
||||
|
||||
for {
|
||||
checkAndHandleUpdate(b.Updater, f, b.Settings.GetBool(settings.AutoUpdateKey))
|
||||
<-ticker.C
|
||||
}
|
||||
}()
|
||||
|
||||
return f.Loop()
|
||||
}
|
||||
|
||||
func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool) {
|
||||
log := logrus.WithField("pkg", "app/bridge")
|
||||
version, err := u.Check()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("An error occurred while checking for updates")
|
||||
return
|
||||
}
|
||||
|
||||
f.WaitUntilFrontendIsReady()
|
||||
|
||||
// Update links in UI
|
||||
f.SetVersion(version)
|
||||
|
||||
if !u.IsUpdateApplicable(version) {
|
||||
log.Info("No need to update")
|
||||
return
|
||||
}
|
||||
|
||||
log.WithField("version", version.Version).Info("An update is available")
|
||||
|
||||
if !autoUpdate {
|
||||
f.NotifyManualUpdate(version, u.CanInstall(version))
|
||||
return
|
||||
}
|
||||
|
||||
if !u.CanInstall(version) {
|
||||
log.Info("A manual update is required")
|
||||
f.NotifySilentUpdateError(updater.ErrManualUpdateRequired)
|
||||
return
|
||||
}
|
||||
|
||||
if err := u.InstallUpdate(version); err != nil {
|
||||
if errors.Cause(err) == updater.ErrDownloadVerify {
|
||||
log.WithError(err).Warning("Skipping update installation due to temporary error")
|
||||
} else {
|
||||
log.WithError(err).Error("The update couldn't be installed")
|
||||
f.NotifySilentUpdateError(err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
f.NotifySilentUpdateInstalled()
|
||||
}
|
||||
|
||||
// loadMessageCache loads local cache in case it is enabled in settings and available.
|
||||
// In any other case it is returning in-memory cache. Could also return an error in case
|
||||
// local cache is enabled but unavailable (in-memory cache will be returned nevertheless).
|
||||
func loadMessageCache(b *base.Base) (cache.Cache, error) {
|
||||
if !b.Settings.GetBool(settings.CacheEnabledKey) {
|
||||
return cache.NewInMemoryCache(inMemoryCacheLimnit), nil
|
||||
}
|
||||
|
||||
var compressor cache.Compressor
|
||||
|
||||
// NOTE(GODT-1158): Changing compression is not an option currently
|
||||
// available for user but, if user changes compression setting we have
|
||||
// to nuke the cache.
|
||||
if b.Settings.GetBool(settings.CacheCompressionKey) {
|
||||
compressor = &cache.GZipCompressor{}
|
||||
} else {
|
||||
compressor = &cache.NoopCompressor{}
|
||||
}
|
||||
|
||||
var path string
|
||||
|
||||
if customPath := b.Settings.Get(settings.CacheLocationKey); customPath != "" {
|
||||
path = customPath
|
||||
} else {
|
||||
path = b.Cache.GetDefaultMessageCacheDir()
|
||||
// Store path so it will allways persist if default location
|
||||
// will be changed in new version.
|
||||
b.Settings.Set(settings.CacheLocationKey, path)
|
||||
}
|
||||
|
||||
// To prevent memory peaks we set maximal write concurency for store
|
||||
// build jobs.
|
||||
store.SetBuildAndCacheJobLimit(b.Settings.GetInt(settings.CacheConcurrencyWrite))
|
||||
|
||||
messageCache, err := cache.NewOnDiskCache(path, compressor, cache.Options{
|
||||
MinFreeAbs: uint64(b.Settings.GetInt(settings.CacheMinFreeAbsKey)),
|
||||
MinFreeRat: b.Settings.GetFloat64(settings.CacheMinFreeRatKey),
|
||||
ConcurrentRead: b.Settings.GetInt(settings.CacheConcurrencyRead),
|
||||
ConcurrentWrite: b.Settings.GetInt(settings.CacheConcurrencyWrite),
|
||||
})
|
||||
if err != nil {
|
||||
return cache.NewInMemoryCache(inMemoryCacheLimnit), err
|
||||
}
|
||||
|
||||
return messageCache, nil
|
||||
}
|
||||
28
internal/app/logging.go
Normal file
28
internal/app/logging.go
Normal file
@ -0,0 +1,28 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/crash"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
|
||||
"github.com/urfave/cli/v2"
|
||||
)
|
||||
|
||||
func initLogging(c *cli.Context, locations *locations.Locations, crashHandler *crash.Handler) error {
|
||||
// Get a place to keep our logs.
|
||||
logsPath, err := locations.ProvideLogsPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not provide logs path: %w", err)
|
||||
}
|
||||
|
||||
// Initialize logging.
|
||||
if err := logging.Init(logsPath, c.String(flagLogLevel)); err != nil {
|
||||
return fmt.Errorf("could not initialize logging: %w", err)
|
||||
}
|
||||
|
||||
// Ensure we dump a stack trace if we crash.
|
||||
crashHandler.AddRecoveryAction(logging.DumpStackTrace(logsPath))
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,38 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Package bridge provides core functionality of Bridge app.
|
||||
package bridge
|
||||
|
||||
import "github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
|
||||
// IsAutostartEnabled checks if link file exits.
|
||||
func (b *Bridge) IsAutostartEnabled() bool {
|
||||
return b.autostart.IsEnabled()
|
||||
}
|
||||
|
||||
// EnableAutostart creates link and sets the preferences.
|
||||
func (b *Bridge) EnableAutostart() error {
|
||||
b.settings.SetBool(settings.AutostartKey, true)
|
||||
return b.autostart.Enable()
|
||||
}
|
||||
|
||||
// DisableAutostart removes link and sets the preferences.
|
||||
func (b *Bridge) DisableAutostart() error {
|
||||
b.settings.SetBool(settings.AutostartKey, false)
|
||||
return b.autostart.Disable()
|
||||
}
|
||||
@ -1,325 +1,318 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Package bridge provides core functionality of Bridge app.
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/go-autostart"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/useragent"
|
||||
"github.com/ProtonMail/gluon"
|
||||
"github.com/ProtonMail/gluon/watcher"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/metrics"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/sentry"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/store/cache"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/users"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
|
||||
logrus "github.com/sirupsen/logrus"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/cookies"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/focus"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
)
|
||||
|
||||
var log = logrus.WithField("pkg", "bridge") //nolint:gochecknoglobals
|
||||
|
||||
var ErrLocalCacheUnavailable = errors.New("local cache is unavailable")
|
||||
|
||||
type Bridge struct {
|
||||
*users.Users
|
||||
// vault holds bridge-specific data, such as preferences and known users (authorized or not).
|
||||
vault *vault.Vault
|
||||
|
||||
locations Locator
|
||||
settings SettingsProvider
|
||||
clientManager pmapi.Manager
|
||||
// users holds authorized users.
|
||||
users map[string]*user.User
|
||||
|
||||
// api manages user API clients.
|
||||
api *liteapi.Manager
|
||||
cookieJar *cookies.Jar
|
||||
proxyDialer ProxyDialer
|
||||
identifier Identifier
|
||||
|
||||
// watchers holds all registered event watchers.
|
||||
watchers []*watcher.Watcher[events.Event]
|
||||
watchersLock sync.RWMutex
|
||||
|
||||
// tlsConfig holds the bridge TLS config used by the IMAP and SMTP servers.
|
||||
tlsConfig *tls.Config
|
||||
|
||||
// imapServer is the bridge's IMAP server.
|
||||
imapServer *gluon.Server
|
||||
imapListener net.Listener
|
||||
|
||||
// smtpServer is the bridge's SMTP server.
|
||||
smtpServer *smtp.Server
|
||||
smtpBackend *smtpBackend
|
||||
smtpListener net.Listener
|
||||
|
||||
// updater is the bridge's updater.
|
||||
updater Updater
|
||||
versioner Versioner
|
||||
tls *tls.TLS
|
||||
userAgent *useragent.UserAgent
|
||||
cacheProvider CacheProvider
|
||||
autostart *autostart.App
|
||||
// Bridge's global errors list.
|
||||
errors []error
|
||||
curVersion *semver.Version
|
||||
updateCheckCh chan struct{}
|
||||
|
||||
isAllMailVisible bool
|
||||
isFirstStart bool
|
||||
lastVersion string
|
||||
// focusService is used to raise the bridge window when needed.
|
||||
focusService *focus.FocusService
|
||||
|
||||
// autostarter is the bridge's autostarter.
|
||||
autostarter Autostarter
|
||||
|
||||
// locator is the bridge's locator.
|
||||
locator Locator
|
||||
|
||||
// errors contains errors encountered during startup.
|
||||
errors []error
|
||||
}
|
||||
|
||||
func New( //nolint:funlen
|
||||
locations Locator,
|
||||
cacheProvider CacheProvider,
|
||||
setting SettingsProvider,
|
||||
sentryReporter *sentry.Reporter,
|
||||
panicHandler users.PanicHandler,
|
||||
eventListener listener.Listener,
|
||||
tls *tls.TLS,
|
||||
userAgent *useragent.UserAgent,
|
||||
cache cache.Cache,
|
||||
builder *message.Builder,
|
||||
clientManager pmapi.Manager,
|
||||
credStorer users.CredentialsStorer,
|
||||
updater Updater,
|
||||
versioner Versioner,
|
||||
autostart *autostart.App,
|
||||
) *Bridge {
|
||||
// Allow DoH before starting the app if the user has previously set this setting.
|
||||
// This allows us to start even if protonmail is blocked.
|
||||
if setting.GetBool(settings.AllowProxyKey) {
|
||||
clientManager.AllowProxy()
|
||||
// New creates a new bridge.
|
||||
func New(
|
||||
apiURL string, // the URL of the API to use
|
||||
locator Locator, // the locator to provide paths to store data
|
||||
vault *vault.Vault, // the bridge's encrypted data store
|
||||
identifier Identifier, // the identifier to keep track of the user agent
|
||||
tlsReporter TLSReporter, // the TLS reporter to report TLS errors
|
||||
proxyDialer ProxyDialer, // the DoH dialer
|
||||
autostarter Autostarter, // the autostarter to manage autostart settings
|
||||
updater Updater, // the updater to fetch and install updates
|
||||
curVersion *semver.Version, // the current version of the bridge
|
||||
) (*Bridge, error) {
|
||||
if vault.GetProxyAllowed() {
|
||||
proxyDialer.AllowProxy()
|
||||
} else {
|
||||
proxyDialer.DisallowProxy()
|
||||
}
|
||||
|
||||
u := users.New(
|
||||
locations,
|
||||
panicHandler,
|
||||
eventListener,
|
||||
clientManager,
|
||||
credStorer,
|
||||
newStoreFactory(cacheProvider, sentryReporter, panicHandler, eventListener, cache, builder),
|
||||
cookieJar, err := cookies.NewCookieJar(vault)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cookie jar: %w", err)
|
||||
}
|
||||
|
||||
api := liteapi.New(
|
||||
liteapi.WithHostURL(apiURL),
|
||||
liteapi.WithAppVersion(constants.AppVersion),
|
||||
liteapi.WithCookieJar(cookieJar),
|
||||
liteapi.WithTransport(&http.Transport{DialTLSContext: proxyDialer.DialTLSContext}),
|
||||
)
|
||||
|
||||
b := &Bridge{
|
||||
Users: u,
|
||||
locations: locations,
|
||||
settings: setting,
|
||||
clientManager: clientManager,
|
||||
updater: updater,
|
||||
versioner: versioner,
|
||||
tls: tls,
|
||||
userAgent: userAgent,
|
||||
cacheProvider: cacheProvider,
|
||||
autostart: autostart,
|
||||
isFirstStart: false,
|
||||
isAllMailVisible: setting.GetBool(settings.IsAllMailVisible),
|
||||
}
|
||||
|
||||
if setting.GetBool(settings.FirstStartKey) {
|
||||
b.isFirstStart = true
|
||||
if err := b.SendMetric(metrics.New(metrics.Setup, metrics.FirstStart, metrics.Label(constants.Version))); err != nil {
|
||||
logrus.WithError(err).Error("Failed to send metric")
|
||||
}
|
||||
setting.SetBool(settings.FirstStartKey, false)
|
||||
}
|
||||
|
||||
// Keep in bridge and update in settings the last used version.
|
||||
b.lastVersion = b.settings.Get(settings.LastVersionKey)
|
||||
b.settings.Set(settings.LastVersionKey, constants.Version)
|
||||
|
||||
go b.heartbeat()
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// heartbeat sends a heartbeat signal once a day.
|
||||
func (b *Bridge) heartbeat() {
|
||||
for range time.Tick(time.Minute) {
|
||||
lastHeartbeatDay, err := strconv.ParseInt(b.settings.Get(settings.LastHeartbeatKey), 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// If we're still on the same day, don't send a heartbeat.
|
||||
if time.Now().YearDay() == int(lastHeartbeatDay) {
|
||||
continue
|
||||
}
|
||||
|
||||
// We're on the next (or a different) day, so send a heartbeat.
|
||||
if err := b.SendMetric(metrics.New(metrics.Heartbeat, metrics.Daily, metrics.NoLabel)); err != nil {
|
||||
logrus.WithError(err).Error("Failed to send heartbeat")
|
||||
continue
|
||||
}
|
||||
|
||||
// Heartbeat was sent successfully so update the last heartbeat day.
|
||||
b.settings.Set(settings.LastHeartbeatKey, fmt.Sprintf("%v", time.Now().YearDay()))
|
||||
}
|
||||
}
|
||||
|
||||
// GetUpdateChannel returns currently set update channel.
|
||||
func (b *Bridge) GetUpdateChannel() updater.UpdateChannel {
|
||||
return updater.UpdateChannel(b.settings.Get(settings.UpdateChannelKey))
|
||||
}
|
||||
|
||||
// SetUpdateChannel switches update channel.
|
||||
func (b *Bridge) SetUpdateChannel(channel updater.UpdateChannel) {
|
||||
b.settings.Set(settings.UpdateChannelKey, string(channel))
|
||||
}
|
||||
|
||||
func (b *Bridge) resetToLatestStable() error {
|
||||
version, err := b.updater.Check()
|
||||
tlsConfig, err := loadTLSConfig(vault)
|
||||
if err != nil {
|
||||
// If we can not check for updates - just remove all local updates and reset to base installer version.
|
||||
// Not using `b.locations.ClearUpdates()` because `versioner.RemoveOtherVersions` can also handle
|
||||
// case when it is needed to remove currently running verion.
|
||||
if err := b.versioner.RemoveOtherVersions(semver.MustParse("0.0.0")); err != nil {
|
||||
log.WithError(err).Error("Failed to clear updates while downgrading channel")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to load TLS config: %w", err)
|
||||
}
|
||||
|
||||
imapServer, err := newIMAPServer(vault.GetGluonDir(), curVersion, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create IMAP server: %w", err)
|
||||
}
|
||||
|
||||
smtpBackend, err := newSMTPBackend()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create SMTP backend: %w", err)
|
||||
}
|
||||
|
||||
smtpServer, err := newSMTPServer(smtpBackend, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create SMTP server: %w", err)
|
||||
}
|
||||
|
||||
focusService, err := focus.NewService()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create focus service: %w", err)
|
||||
}
|
||||
|
||||
bridge := &Bridge{
|
||||
vault: vault,
|
||||
users: make(map[string]*user.User),
|
||||
|
||||
api: api,
|
||||
cookieJar: cookieJar,
|
||||
proxyDialer: proxyDialer,
|
||||
identifier: identifier,
|
||||
|
||||
tlsConfig: tlsConfig,
|
||||
imapServer: imapServer,
|
||||
smtpServer: smtpServer,
|
||||
smtpBackend: smtpBackend,
|
||||
|
||||
updater: updater,
|
||||
curVersion: curVersion,
|
||||
updateCheckCh: make(chan struct{}, 1),
|
||||
|
||||
focusService: focusService,
|
||||
autostarter: autostarter,
|
||||
locator: locator,
|
||||
}
|
||||
|
||||
api.AddStatusObserver(func(status liteapi.Status) {
|
||||
bridge.publish(events.ConnStatus{
|
||||
Status: status,
|
||||
})
|
||||
})
|
||||
|
||||
api.AddErrorHandler(liteapi.AppVersionBadCode, func() {
|
||||
bridge.publish(events.UpdateForced{})
|
||||
})
|
||||
|
||||
api.AddPreRequestHook(func(_ *resty.Client, req *resty.Request) error {
|
||||
req.SetHeader("User-Agent", bridge.identifier.GetUserAgent())
|
||||
return nil
|
||||
})
|
||||
|
||||
go func() {
|
||||
for range tlsReporter.GetTLSIssueCh() {
|
||||
bridge.publish(events.TLSIssue{})
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for range focusService.GetRaiseCh() {
|
||||
bridge.publish(events.Raise{})
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for event := range imapServer.AddWatcher() {
|
||||
bridge.handleIMAPEvent(event)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := bridge.loadUsers(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("failed to load connected users: %w", err)
|
||||
}
|
||||
|
||||
// If current version is same as upstream stable version - do nothing.
|
||||
if version.Version.Equal(semver.MustParse(constants.Version)) {
|
||||
return nil
|
||||
if err := bridge.serveIMAP(); err != nil {
|
||||
bridge.PushError(ErrServeIMAP)
|
||||
}
|
||||
|
||||
if err := b.updater.InstallUpdate(version); err != nil {
|
||||
return err
|
||||
if err := bridge.serveSMTP(); err != nil {
|
||||
bridge.PushError(ErrServeSMTP)
|
||||
}
|
||||
|
||||
return b.versioner.RemoveOtherVersions(version.Version)
|
||||
if err := bridge.watchForUpdates(); err != nil {
|
||||
bridge.PushError(ErrWatchUpdates)
|
||||
}
|
||||
|
||||
return bridge, nil
|
||||
}
|
||||
|
||||
// FactoryReset will remove all local cache and settings.
|
||||
// It will also downgrade to latest stable version if user is on early version.
|
||||
func (b *Bridge) FactoryReset() {
|
||||
wasEarly := b.GetUpdateChannel() == updater.EarlyChannel
|
||||
// GetEvents returns a channel of events of the given type.
|
||||
// If no types are supplied, all events are returned.
|
||||
func (bridge *Bridge) GetEvents(ofType ...events.Event) (<-chan events.Event, func()) {
|
||||
newWatcher := bridge.addWatcher(ofType...)
|
||||
|
||||
b.settings.Set(settings.UpdateChannelKey, string(updater.StableChannel))
|
||||
return newWatcher.GetChannel(), func() { bridge.remWatcher(newWatcher) }
|
||||
}
|
||||
|
||||
if wasEarly {
|
||||
if err := b.resetToLatestStable(); err != nil {
|
||||
log.WithError(err).Error("Failed to reset to latest stable version")
|
||||
func (bridge *Bridge) FactoryReset(ctx context.Context) error {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
func (bridge *Bridge) PushError(err error) {
|
||||
bridge.errors = append(bridge.errors, err)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetErrors() []error {
|
||||
return bridge.errors
|
||||
}
|
||||
|
||||
func (bridge *Bridge) Close(ctx context.Context) error {
|
||||
// Close the IMAP server.
|
||||
if err := bridge.closeIMAP(ctx); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close IMAP server")
|
||||
}
|
||||
|
||||
// Close the SMTP server.
|
||||
if err := bridge.closeSMTP(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close SMTP server")
|
||||
}
|
||||
|
||||
// Close all users.
|
||||
for _, user := range bridge.users {
|
||||
if err := user.Close(ctx); err != nil {
|
||||
logrus.WithError(err).Error("Failed to close user")
|
||||
}
|
||||
}
|
||||
|
||||
if err := b.Users.ClearData(); err != nil {
|
||||
log.WithError(err).Error("Failed to remove bridge data")
|
||||
// Persist the cookies.
|
||||
if err := bridge.cookieJar.PersistCookies(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to persist cookies")
|
||||
}
|
||||
|
||||
if err := b.Users.ClearUsers(); err != nil {
|
||||
log.WithError(err).Error("Failed to remove bridge users")
|
||||
// Close the focus service.
|
||||
bridge.focusService.Close()
|
||||
|
||||
// Save the last version of bridge that was run.
|
||||
if err := bridge.vault.SetLastVersion(bridge.curVersion); err != nil {
|
||||
logrus.WithError(err).Error("Failed to save last version")
|
||||
}
|
||||
}
|
||||
|
||||
// GetKeychainApp returns current keychain helper.
|
||||
func (b *Bridge) GetKeychainApp() string {
|
||||
return b.settings.Get(settings.PreferredKeychainKey)
|
||||
}
|
||||
|
||||
// SetKeychainApp sets current keychain helper.
|
||||
func (b *Bridge) SetKeychainApp(helper string) {
|
||||
b.settings.Set(settings.PreferredKeychainKey, helper)
|
||||
}
|
||||
|
||||
func (b *Bridge) EnableCache() error {
|
||||
if err := b.Users.EnableCache(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.settings.SetBool(settings.CacheEnabledKey, true)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Bridge) DisableCache() error {
|
||||
if err := b.Users.DisableCache(); err != nil {
|
||||
return err
|
||||
}
|
||||
func (bridge *Bridge) publish(event events.Event) {
|
||||
bridge.watchersLock.RLock()
|
||||
defer bridge.watchersLock.RUnlock()
|
||||
|
||||
b.settings.SetBool(settings.CacheEnabledKey, false)
|
||||
// Reset back to the default location when disabling.
|
||||
b.settings.Set(settings.CacheLocationKey, b.cacheProvider.GetDefaultMessageCacheDir())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Bridge) MigrateCache(from, to string) error {
|
||||
if err := b.Users.MigrateCache(from, to); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.settings.Set(settings.CacheLocationKey, to)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetProxyAllowed instructs the app whether to use DoH to access an API proxy if necessary.
|
||||
// It also needs to work before the app is initialised (because we may need to use the proxy at startup).
|
||||
func (b *Bridge) SetProxyAllowed(proxyAllowed bool) {
|
||||
b.settings.SetBool(settings.AllowProxyKey, proxyAllowed)
|
||||
if proxyAllowed {
|
||||
b.clientManager.AllowProxy()
|
||||
} else {
|
||||
b.clientManager.DisallowProxy()
|
||||
}
|
||||
}
|
||||
|
||||
// GetProxyAllowed returns whether use of DoH is enabled to access an API proxy if necessary.
|
||||
func (b *Bridge) GetProxyAllowed() bool {
|
||||
return b.settings.GetBool(settings.AllowProxyKey)
|
||||
}
|
||||
|
||||
// AddError add an error to a global error list if it does not contain it yet. Adding nil is noop.
|
||||
func (b *Bridge) AddError(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if b.HasError(err) {
|
||||
return
|
||||
}
|
||||
|
||||
b.errors = append(b.errors, err)
|
||||
}
|
||||
|
||||
// DelError removes an error from global error list.
|
||||
func (b *Bridge) DelError(err error) {
|
||||
for idx, val := range b.errors {
|
||||
if val == err {
|
||||
b.errors = append(b.errors[:idx], b.errors[idx+1:]...)
|
||||
return
|
||||
for _, watcher := range bridge.watchers {
|
||||
if watcher.IsWatching(event) {
|
||||
if ok := watcher.Send(event); !ok {
|
||||
logrus.WithField("event", event).Warn("Failed to send event to watcher")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HasError returnes true if global error list contains an err.
|
||||
func (b *Bridge) HasError(err error) bool {
|
||||
for _, val := range b.errors {
|
||||
if val == err {
|
||||
return true
|
||||
}
|
||||
func (bridge *Bridge) addWatcher(ofType ...events.Event) *watcher.Watcher[events.Event] {
|
||||
bridge.watchersLock.Lock()
|
||||
defer bridge.watchersLock.Unlock()
|
||||
|
||||
newWatcher := watcher.New(ofType...)
|
||||
|
||||
bridge.watchers = append(bridge.watchers, newWatcher)
|
||||
|
||||
return newWatcher
|
||||
}
|
||||
|
||||
func (bridge *Bridge) remWatcher(oldWatcher *watcher.Watcher[events.Event]) {
|
||||
bridge.watchersLock.Lock()
|
||||
defer bridge.watchersLock.Unlock()
|
||||
|
||||
bridge.watchers = xslices.Filter(bridge.watchers, func(other *watcher.Watcher[events.Event]) bool {
|
||||
return other != oldWatcher
|
||||
})
|
||||
}
|
||||
|
||||
func loadTLSConfig(vault *vault.Vault) (*tls.Config, error) {
|
||||
cert, err := tls.X509KeyPair(vault.GetBridgeTLSCert(), vault.GetBridgeTLSKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return false
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetLastVersion returns the version which was used in previous execution of
|
||||
// Bridge.
|
||||
func (b *Bridge) GetLastVersion() string {
|
||||
return b.lastVersion
|
||||
}
|
||||
func newListener(port int, useTLS bool, tlsConfig *tls.Config) (net.Listener, error) {
|
||||
if useTLS {
|
||||
tlsListener, err := tls.Listen("tcp", fmt.Sprintf(":%v", port), tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// IsFirstStart returns true when Bridge is running for first time or after
|
||||
// factory reset.
|
||||
func (b *Bridge) IsFirstStart() bool {
|
||||
return b.isFirstStart
|
||||
}
|
||||
return tlsListener, nil
|
||||
}
|
||||
|
||||
// IsAllMailVisible can be called extensively by IMAP. Therefore, it is better
|
||||
// to cache the value instead of reading from settings file.
|
||||
func (b *Bridge) IsAllMailVisible() bool {
|
||||
return b.isAllMailVisible
|
||||
}
|
||||
netListener, err := net.Listen("tcp", fmt.Sprintf(":%v", port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (b *Bridge) SetIsAllMailVisible(isVisible bool) {
|
||||
b.settings.SetBool(settings.IsAllMailVisible, isVisible)
|
||||
b.isAllMailVisible = isVisible
|
||||
return netListener, nil
|
||||
}
|
||||
|
||||
362
internal/bridge/bridge_test.go
Normal file
362
internal/bridge/bridge_test.go
Normal file
@ -0,0 +1,362 @@
|
||||
package bridge_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/focus"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
)
|
||||
|
||||
const (
|
||||
username = "username"
|
||||
password = "password"
|
||||
)
|
||||
|
||||
var (
|
||||
v2_3_0 = semver.MustParse("2.3.0")
|
||||
v2_4_0 = semver.MustParse("2.4.0")
|
||||
)
|
||||
|
||||
func TestBridge_ConnStatus(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Get a stream of connection status events.
|
||||
eventCh, done := bridge.GetEvents(events.ConnStatus{})
|
||||
defer done()
|
||||
|
||||
// Simulate network disconnect.
|
||||
mocks.TLSDialer.SetCanDial(false)
|
||||
|
||||
// Trigger some operation that will fail due to the network disconnect.
|
||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
require.Error(t, err)
|
||||
|
||||
// Wait for the event.
|
||||
require.Equal(t, events.ConnStatus{Status: liteapi.StatusDown}, <-eventCh)
|
||||
|
||||
// Simulate network reconnect.
|
||||
mocks.TLSDialer.SetCanDial(true)
|
||||
|
||||
// Trigger some operation that will succeed due to the network reconnect.
|
||||
userID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, userID)
|
||||
|
||||
// Wait for the event.
|
||||
require.Equal(t, events.ConnStatus{Status: liteapi.StatusUp}, <-eventCh)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_TLSIssue(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Get a stream of TLS issue events.
|
||||
tlsEventCh, done := bridge.GetEvents(events.TLSIssue{})
|
||||
defer done()
|
||||
|
||||
// Simulate a TLS issue.
|
||||
go func() {
|
||||
mocks.TLSIssueCh <- struct{}{}
|
||||
}()
|
||||
|
||||
// Wait for the event.
|
||||
require.IsType(t, events.TLSIssue{}, <-tlsEventCh)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_Focus(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Get a stream of TLS issue events.
|
||||
raiseCh, done := bridge.GetEvents(events.Raise{})
|
||||
defer done()
|
||||
|
||||
// Simulate a focus event.
|
||||
focus.TryRaise()
|
||||
|
||||
// Wait for the event.
|
||||
require.IsType(t, events.Raise{}, <-raiseCh)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_UserAgent(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
|
||||
var calls []server.Call
|
||||
|
||||
s.AddCallWatcher(func(call server.Call) {
|
||||
calls = append(calls, call)
|
||||
})
|
||||
|
||||
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Set the platform to something other than the default.
|
||||
bridge.SetCurrentPlatform("platform")
|
||||
|
||||
// Assert that the user agent then contains the platform.
|
||||
require.Contains(t, bridge.GetCurrentUserAgent(), "platform")
|
||||
|
||||
// Login the user.
|
||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert that the user agent was sent to the API.
|
||||
require.Contains(t, calls[len(calls)-1].Request.Header.Get("User-Agent"), bridge.GetCurrentUserAgent())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_Cookies(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
|
||||
var calls []server.Call
|
||||
|
||||
s.AddCallWatcher(func(call server.Call) {
|
||||
calls = append(calls, call)
|
||||
})
|
||||
|
||||
var sessionID string
|
||||
|
||||
// Start bridge and add a user so that API assigns us a session ID via cookie.
|
||||
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
cookie, err := calls[len(calls)-1].Request.Cookie("Session-Id")
|
||||
require.NoError(t, err)
|
||||
|
||||
sessionID = cookie.Value
|
||||
})
|
||||
|
||||
// Start bridge again and check that it uses the same session ID.
|
||||
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
cookie, err := calls[len(calls)-1].Request.Cookie("Session-Id")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, sessionID, cookie.Value)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_CheckUpdate(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Disable autoupdate for this test.
|
||||
require.NoError(t, bridge.SetAutoUpdate(false))
|
||||
|
||||
// Get a stream of update events.
|
||||
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateAvailable{})
|
||||
defer done()
|
||||
|
||||
// We are currently on the latest version.
|
||||
bridge.CheckForUpdates()
|
||||
require.Equal(t, events.UpdateNotAvailable{}, <-updateCh)
|
||||
|
||||
// Simulate a new version being available.
|
||||
mocks.Updater.SetLatestVersion(v2_4_0, v2_3_0)
|
||||
|
||||
// Check for updates.
|
||||
bridge.CheckForUpdates()
|
||||
require.Equal(t, events.UpdateAvailable{
|
||||
Version: updater.VersionInfo{
|
||||
Version: v2_4_0,
|
||||
MinAuto: v2_3_0,
|
||||
RolloutProportion: 1.0,
|
||||
},
|
||||
CanInstall: true,
|
||||
}, <-updateCh)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_AutoUpdate(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Enable autoupdate for this test.
|
||||
require.NoError(t, bridge.SetAutoUpdate(true))
|
||||
|
||||
// Get a stream of update events.
|
||||
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateInstalled{})
|
||||
defer done()
|
||||
|
||||
// Simulate a new version being available.
|
||||
mocks.Updater.SetLatestVersion(v2_4_0, v2_3_0)
|
||||
|
||||
// Check for updates.
|
||||
bridge.CheckForUpdates()
|
||||
require.Equal(t, events.UpdateInstalled{
|
||||
Version: updater.VersionInfo{
|
||||
Version: v2_4_0,
|
||||
MinAuto: v2_3_0,
|
||||
RolloutProportion: 1.0,
|
||||
},
|
||||
}, <-updateCh)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_ManualUpdate(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Disable autoupdate for this test.
|
||||
require.NoError(t, bridge.SetAutoUpdate(false))
|
||||
|
||||
// Get a stream of update events.
|
||||
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{}, events.UpdateAvailable{})
|
||||
defer done()
|
||||
|
||||
// Simulate a new version being available, but it's too new for us.
|
||||
mocks.Updater.SetLatestVersion(v2_4_0, v2_4_0)
|
||||
|
||||
// Check for updates.
|
||||
bridge.CheckForUpdates()
|
||||
require.Equal(t, events.UpdateAvailable{
|
||||
Version: updater.VersionInfo{
|
||||
Version: v2_4_0,
|
||||
MinAuto: v2_4_0,
|
||||
RolloutProportion: 1.0,
|
||||
},
|
||||
CanInstall: false,
|
||||
}, <-updateCh)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_ForceUpdate(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Get a stream of update events.
|
||||
updateCh, done := bridge.GetEvents(events.UpdateForced{})
|
||||
defer done()
|
||||
|
||||
// Set the minimum accepted app version to something newer than the current version.
|
||||
s.SetMinAppVersion(v2_4_0)
|
||||
|
||||
// Try to login the user. It will fail because the bridge is too old.
|
||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
require.Error(t, err)
|
||||
|
||||
// We should get an update required event.
|
||||
require.Equal(t, events.UpdateForced{}, <-updateCh)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_BadVaultKey(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, vaultKey []byte) {
|
||||
var userID string
|
||||
|
||||
// Login a user.
|
||||
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
newUserID, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID = newUserID
|
||||
})
|
||||
|
||||
// Start bridge with the correct vault key -- it should load the users correctly.
|
||||
withBridge(t, s.GetHostURL(), locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.ElementsMatch(t, []string{userID}, bridge.GetUserIDs())
|
||||
})
|
||||
|
||||
// Start bridge with a bad vault key, the vault will be wiped and bridge will show no users.
|
||||
withBridge(t, s.GetHostURL(), locator, []byte("bad"), func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
})
|
||||
|
||||
// Start bridge with a nil vault key, the vault will be wiped and bridge will show no users.
|
||||
withBridge(t, s.GetHostURL(), locator, nil, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// withEnv creates the full test environment and runs the tests.
|
||||
func withEnv(t *testing.T, tests func(server *server.Server, locator bridge.Locator, vaultKey []byte)) {
|
||||
// Create test API.
|
||||
server := server.NewTLS()
|
||||
defer server.Close()
|
||||
|
||||
// Add test user.
|
||||
_, _, err := server.AddUser(username, password, username+"@pm.me")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate a random vault key.
|
||||
vaultKey, err := crypto.RandomToken(32)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Run the tests.
|
||||
tests(server, locations.New(bridge.NewTestLocationsProvider(t), "config-name"), vaultKey)
|
||||
}
|
||||
|
||||
// withBridge creates a new bridge which points to the given API URL and uses the given keychain, and closes it when done.
|
||||
func withBridge(t *testing.T, apiURL string, locator bridge.Locator, vaultKey []byte, tests func(bridge *bridge.Bridge, mocks *bridge.Mocks)) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create the mock objects used in the tests.
|
||||
mocks := bridge.NewMocks(t, v2_3_0, v2_3_0)
|
||||
|
||||
// Bridge will enable the proxy by default at startup.
|
||||
mocks.ProxyDialer.EXPECT().AllowProxy()
|
||||
|
||||
// Get the path to the vault.
|
||||
vaultDir, err := locator.ProvideSettingsPath()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create the vault.
|
||||
vault, _, err := vault.New(vaultDir, t.TempDir(), vaultKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a new bridge.
|
||||
bridge, err := bridge.New(
|
||||
apiURL,
|
||||
locator,
|
||||
vault,
|
||||
useragent.New(),
|
||||
mocks.TLSReporter,
|
||||
mocks.ProxyDialer,
|
||||
mocks.Autostarter,
|
||||
mocks.Updater,
|
||||
v2_3_0,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Use the bridge.
|
||||
tests(bridge, mocks)
|
||||
|
||||
// Close the bridge.
|
||||
require.NoError(t, bridge.Close(ctx))
|
||||
}
|
||||
|
||||
// must is a helper function that panics on error.
|
||||
func must[T any](val T, err error) T {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
func getConnectedUserIDs(t *testing.T, bridge *bridge.Bridge) []string {
|
||||
t.Helper()
|
||||
|
||||
return xslices.Filter(bridge.GetUserIDs(), func(userID string) bool {
|
||||
info, err := bridge.GetUserInfo(userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
return info.Connected
|
||||
})
|
||||
}
|
||||
@ -21,67 +21,51 @@ import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/logging"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxAttachmentSize = 7 * 1024 * 1024 // MaxAttachmentSize 7 MB total limit
|
||||
MaxAttachmentSize = 7 * (1 << 20) // MaxAttachmentSize 7 MB total size of all attachments.
|
||||
MaxCompressedFilesCount = 6
|
||||
)
|
||||
|
||||
var ErrSizeTooLarge = errors.New("file is too big")
|
||||
func (bridge *Bridge) ReportBug(ctx context.Context, osType, osVersion, description, username, email, client string, attachLogs bool) error {
|
||||
var account string
|
||||
|
||||
// ReportBug reports a new bug from the user.
|
||||
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string, attachLogs bool) error { //nolint:funlen
|
||||
if user, err := b.GetUser(address); err == nil {
|
||||
accountName = user.Username()
|
||||
} else if users := b.GetUsers(); len(users) > 0 {
|
||||
accountName = users[0].Username()
|
||||
if info, err := bridge.QueryUserInfo(username); err == nil {
|
||||
account = info.Username
|
||||
} else if userIDs := bridge.GetUserIDs(); len(userIDs) > 0 {
|
||||
account = bridge.users[userIDs[0]].Name()
|
||||
}
|
||||
|
||||
report := pmapi.ReportBugReq{
|
||||
OS: osType,
|
||||
OSVersion: osVersion,
|
||||
Browser: emailClient,
|
||||
Title: "[Bridge] Bug",
|
||||
Description: description,
|
||||
Username: accountName,
|
||||
Email: address,
|
||||
}
|
||||
var atts []liteapi.ReportBugAttachment
|
||||
|
||||
if attachLogs {
|
||||
logs, err := b.getMatchingLogs(
|
||||
func(filename string) bool {
|
||||
return logging.MatchLogName(filename) && !logging.MatchStackTraceName(filename)
|
||||
},
|
||||
)
|
||||
logs, err := getMatchingLogs(bridge.locator, func(filename string) bool {
|
||||
return logging.MatchLogName(filename) && !logging.MatchStackTraceName(filename)
|
||||
})
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Can't get log files list")
|
||||
return err
|
||||
}
|
||||
|
||||
guiLogs, err := b.getMatchingLogs(
|
||||
func(filename string) bool {
|
||||
return logging.MatchGUILogName(filename) && !logging.MatchStackTraceName(filename)
|
||||
},
|
||||
)
|
||||
guiLogs, err := getMatchingLogs(bridge.locator, func(filename string) bool {
|
||||
return logging.MatchGUILogName(filename) && !logging.MatchStackTraceName(filename)
|
||||
})
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Can't get GUI log files list")
|
||||
return err
|
||||
}
|
||||
|
||||
crashes, err := b.getMatchingLogs(
|
||||
func(filename string) bool {
|
||||
return logging.MatchLogName(filename) && logging.MatchStackTraceName(filename)
|
||||
},
|
||||
)
|
||||
crashes, err := getMatchingLogs(bridge.locator, func(filename string) bool {
|
||||
return logging.MatchLogName(filename) && logging.MatchStackTraceName(filename)
|
||||
})
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Can't get crash files list")
|
||||
return err
|
||||
}
|
||||
|
||||
var matchFiles []string
|
||||
@ -95,26 +79,42 @@ func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address,
|
||||
|
||||
archive, err := zipFiles(matchFiles)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Can't zip logs and crashes")
|
||||
return err
|
||||
}
|
||||
|
||||
if archive != nil {
|
||||
report.AddAttachment("logs.zip", "application/zip", archive)
|
||||
body, err := io.ReadAll(archive)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
atts = append(atts, liteapi.ReportBugAttachment{
|
||||
Name: "logs.zip",
|
||||
Filename: "logs.zip",
|
||||
MIMEType: "application/zip",
|
||||
Body: body,
|
||||
})
|
||||
}
|
||||
|
||||
return b.clientManager.ReportBug(context.Background(), report)
|
||||
return bridge.api.ReportBug(ctx, liteapi.ReportBugReq{
|
||||
OS: osType,
|
||||
OSVersion: osVersion,
|
||||
Description: description,
|
||||
Client: client,
|
||||
Username: account,
|
||||
Email: email,
|
||||
}, atts...)
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Bridge) getMatchingLogs(filenameMatchFunc func(string) bool) (filenames []string, err error) {
|
||||
logsPath, err := b.locations.ProvideLogsPath()
|
||||
func getMatchingLogs(locator Locator, filenameMatchFunc func(string) bool) (filenames []string, err error) {
|
||||
logsPath, err := locator.ProvideLogsPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -131,24 +131,25 @@ func (b *Bridge) getMatchingLogs(filenameMatchFunc func(string) bool) (filenames
|
||||
matchFiles = append(matchFiles, filepath.Join(logsPath, file.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(matchFiles) // Sorted by timestamp: oldest first.
|
||||
|
||||
return matchFiles, nil
|
||||
}
|
||||
|
||||
type LimitedBuffer struct {
|
||||
type limitedBuffer struct {
|
||||
capacity int
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func NewLimitedBuffer(capacity int) *LimitedBuffer {
|
||||
return &LimitedBuffer{
|
||||
func newLimitedBuffer(capacity int) *limitedBuffer {
|
||||
return &limitedBuffer{
|
||||
capacity: capacity,
|
||||
buf: bytes.NewBuffer(make([]byte, 0, capacity)),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *LimitedBuffer) Write(p []byte) (n int, err error) {
|
||||
func (b *limitedBuffer) Write(p []byte) (n int, err error) {
|
||||
if len(p)+b.buf.Len() > b.capacity {
|
||||
return 0, ErrSizeTooLarge
|
||||
}
|
||||
@ -156,7 +157,7 @@ func (b *LimitedBuffer) Write(p []byte) (n int, err error) {
|
||||
return b.buf.Write(p)
|
||||
}
|
||||
|
||||
func (b *LimitedBuffer) Read(p []byte) (n int, err error) {
|
||||
func (b *limitedBuffer) Read(p []byte) (n int, err error) {
|
||||
return b.buf.Read(p)
|
||||
}
|
||||
|
||||
@ -165,14 +166,13 @@ func zipFiles(filenames []string) (io.Reader, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
buf := NewLimitedBuffer(MaxAttachmentSize)
|
||||
buf := newLimitedBuffer(MaxAttachmentSize)
|
||||
|
||||
w := zip.NewWriter(buf)
|
||||
defer w.Close() //nolint:errcheck
|
||||
|
||||
for _, file := range filenames {
|
||||
err := addFileToZip(file, w)
|
||||
if err != nil {
|
||||
if err := addFileToZip(file, w); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@ -209,12 +209,9 @@ func addFileToZip(filename string, writer *zip.Writer) error {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(fileWriter, fileReader)
|
||||
if err != nil {
|
||||
if _, err := io.Copy(fileWriter, fileReader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = fileReader.Close()
|
||||
|
||||
return err
|
||||
return fileReader.Close()
|
||||
}
|
||||
|
||||
@ -1,70 +1,38 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/clientconfig"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/useragent"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
||||
)
|
||||
|
||||
func (b *Bridge) ConfigureAppleMail(userID, address string) (bool, error) {
|
||||
user, err := b.GetUser(userID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
func (bridge *Bridge) ConfigureAppleMail(userID, address string) error {
|
||||
user, ok := bridge.users[userID]
|
||||
if !ok {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
if address == "" {
|
||||
address = user.GetPrimaryAddress()
|
||||
address = user.Addresses()[0]
|
||||
}
|
||||
|
||||
username := address
|
||||
addresses := address
|
||||
|
||||
if user.IsCombinedAddressMode() {
|
||||
username = user.GetPrimaryAddress()
|
||||
addresses = strings.Join(user.GetAddresses(), ",")
|
||||
}
|
||||
|
||||
var (
|
||||
restart = false
|
||||
smtpSSL = b.settings.GetBool(settings.SMTPSSLKey)
|
||||
)
|
||||
|
||||
// If configuring apple mail for Catalina or newer, users should use SSL.
|
||||
if useragent.IsCatalinaOrNewer() && !smtpSSL {
|
||||
smtpSSL = true
|
||||
restart = true
|
||||
b.settings.SetBool(settings.SMTPSSLKey, true)
|
||||
if useragent.IsCatalinaOrNewer() && !bridge.vault.GetSMTPSSL() {
|
||||
if err := bridge.SetSMTPSSL(true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := (&clientconfig.AppleMail{}).Configure(
|
||||
Host,
|
||||
b.settings.GetInt(settings.IMAPPortKey),
|
||||
b.settings.GetInt(settings.SMTPPortKey),
|
||||
false, smtpSSL,
|
||||
username, addresses,
|
||||
user.GetBridgePassword(),
|
||||
); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return restart, nil
|
||||
return (&clientconfig.AppleMail{}).Configure(
|
||||
constants.Host,
|
||||
bridge.vault.GetIMAPPort(),
|
||||
bridge.vault.GetSMTPPort(),
|
||||
bridge.vault.GetIMAPSSL(),
|
||||
bridge.vault.GetSMTPSSL(),
|
||||
address,
|
||||
strings.Join(user.Addresses(), ","),
|
||||
user.BridgePass(),
|
||||
)
|
||||
}
|
||||
|
||||
@ -1,23 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package bridge
|
||||
|
||||
// Host settings.
|
||||
const (
|
||||
Host = "127.0.0.1"
|
||||
)
|
||||
16
internal/bridge/errors.go
Normal file
16
internal/bridge/errors.go
Normal file
@ -0,0 +1,16 @@
|
||||
package bridge
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrServeIMAP = errors.New("failed to serve IMAP")
|
||||
ErrServeSMTP = errors.New("failed to serve SMTP")
|
||||
ErrWatchUpdates = errors.New("failed to watch for updates")
|
||||
|
||||
ErrNoSuchUser = errors.New("no such user")
|
||||
ErrUserAlreadyExists = errors.New("user already exists")
|
||||
ErrUserAlreadyLoggedIn = errors.New("user already logged in")
|
||||
ErrNotImplemented = errors.New("not implemented")
|
||||
|
||||
ErrSizeTooLarge = errors.New("file is too big")
|
||||
)
|
||||
67
internal/bridge/files.go
Normal file
67
internal/bridge/files.go
Normal file
@ -0,0 +1,67 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func moveDir(from, to string) error {
|
||||
entries, err := os.ReadDir(from)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
if err := os.Mkdir(filepath.Join(to, entry.Name()), 0700); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := moveDir(filepath.Join(from, entry.Name()), filepath.Join(to, entry.Name())); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.RemoveAll(filepath.Join(from, entry.Name())); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := move(filepath.Join(from, entry.Name()), filepath.Join(to, entry.Name())); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return os.Remove(from)
|
||||
}
|
||||
|
||||
func move(from, to string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(to), 0700); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Open(from)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
c, err := os.Create(to)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
if err := os.Chmod(to, 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := c.ReadFrom(f); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Remove(from); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
56
internal/bridge/files_test.go
Normal file
56
internal/bridge/files_test.go
Normal file
@ -0,0 +1,56 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMoveDir(t *testing.T) {
|
||||
from, to := t.TempDir(), t.TempDir()
|
||||
|
||||
// Create some files in from.
|
||||
if err := os.WriteFile(filepath.Join(from, "a"), []byte("a"), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(from, "b"), []byte("b"), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.Mkdir(filepath.Join(from, "c"), 0700); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(from, "c", "d"), []byte("d"), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Move the files.
|
||||
if err := moveDir(from, to); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Check that the files were moved.
|
||||
if _, err := os.Stat(filepath.Join(from, "a")); !os.IsNotExist(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(to, "a")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(from, "b")); !os.IsNotExist(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(to, "b")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(from, "c")); !os.IsNotExist(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(to, "c")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(from, "c", "d")); !os.IsNotExist(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(to, "c", "d")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
117
internal/bridge/imap.go
Normal file
117
internal/bridge/imap.go
Normal file
@ -0,0 +1,117 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/gluon"
|
||||
imapEvents "github.com/ProtonMail/gluon/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultClientName = "UnknownClient"
|
||||
defaultClientVersion = "0.0.1"
|
||||
)
|
||||
|
||||
func (bridge *Bridge) GetIMAPPort() int {
|
||||
return bridge.vault.GetIMAPPort()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetIMAPPort(newPort int) error {
|
||||
if newPort == bridge.vault.GetIMAPPort() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := bridge.vault.SetIMAPPort(newPort); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bridge.restartIMAP(context.Background())
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetIMAPSSL() bool {
|
||||
return bridge.vault.GetIMAPSSL()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetIMAPSSL(newSSL bool) error {
|
||||
if newSSL == bridge.vault.GetIMAPSSL() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := bridge.vault.SetIMAPSSL(newSSL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bridge.restartIMAP(context.Background())
|
||||
}
|
||||
|
||||
func (bridge *Bridge) serveIMAP() error {
|
||||
imapListener, err := newListener(bridge.vault.GetIMAPPort(), bridge.vault.GetIMAPSSL(), bridge.tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create IMAP listener: %w", err)
|
||||
}
|
||||
|
||||
bridge.imapListener = imapListener
|
||||
|
||||
return bridge.imapServer.Serve(context.Background(), bridge.imapListener)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) restartIMAP(ctx context.Context) error {
|
||||
if err := bridge.imapListener.Close(); err != nil {
|
||||
logrus.WithError(err).Warn("Failed to close IMAP listener")
|
||||
}
|
||||
|
||||
return bridge.serveIMAP()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) closeIMAP(ctx context.Context) error {
|
||||
if err := bridge.imapServer.Close(ctx); err != nil {
|
||||
logrus.WithError(err).Warn("Failed to close IMAP server")
|
||||
}
|
||||
|
||||
if err := bridge.imapListener.Close(); err != nil {
|
||||
logrus.WithError(err).Warn("Failed to close IMAP listener")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) handleIMAPEvent(event imapEvents.Event) {
|
||||
switch event := event.(type) {
|
||||
case imapEvents.SessionAdded:
|
||||
if !bridge.identifier.HasClient() {
|
||||
bridge.identifier.SetClient(defaultClientName, defaultClientVersion)
|
||||
}
|
||||
|
||||
case imapEvents.IMAPID:
|
||||
bridge.identifier.SetClient(event.IMAPID.Name, event.IMAPID.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func newIMAPServer(gluonDir string, version *semver.Version, tlsConfig *tls.Config) (*gluon.Server, error) {
|
||||
imapServer, err := gluon.New(
|
||||
gluon.WithTLS(tlsConfig),
|
||||
gluon.WithDataDir(gluonDir),
|
||||
gluon.WithVersionInfo(
|
||||
int(version.Major()),
|
||||
int(version.Minor()),
|
||||
int(version.Patch()),
|
||||
constants.FullAppName,
|
||||
"TODO",
|
||||
"TODO",
|
||||
),
|
||||
gluon.WithLogger(
|
||||
logrus.StandardLogger().WriterLevel(logrus.InfoLevel),
|
||||
logrus.StandardLogger().WriterLevel(logrus.InfoLevel),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return imapServer, nil
|
||||
}
|
||||
@ -1,30 +1,13 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package bridge
|
||||
|
||||
func (b *Bridge) ProvideLogsPath() (string, error) {
|
||||
return b.locations.ProvideLogsPath()
|
||||
func (bridge *Bridge) GetLogsPath() (string, error) {
|
||||
return bridge.locator.ProvideLogsPath()
|
||||
}
|
||||
|
||||
func (b *Bridge) GetLicenseFilePath() string {
|
||||
return b.locations.GetLicenseFilePath()
|
||||
func (bridge *Bridge) GetLicenseFilePath() string {
|
||||
return bridge.locator.GetLicenseFilePath()
|
||||
}
|
||||
|
||||
func (b *Bridge) GetDependencyLicensesLink() string {
|
||||
return b.locations.GetDependencyLicensesLink()
|
||||
func (bridge *Bridge) GetDependencyLicensesLink() string {
|
||||
return bridge.locator.GetDependencyLicensesLink()
|
||||
}
|
||||
|
||||
127
internal/bridge/mocks.go
Normal file
127
internal/bridge/mocks.go
Normal file
@ -0,0 +1,127 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge/mocks"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
"github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
type Mocks struct {
|
||||
TLSDialer *TestDialer
|
||||
ProxyDialer *mocks.MockProxyDialer
|
||||
|
||||
TLSReporter *mocks.MockTLSReporter
|
||||
TLSIssueCh chan struct{}
|
||||
|
||||
Updater *TestUpdater
|
||||
Autostarter *mocks.MockAutostarter
|
||||
}
|
||||
|
||||
func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks {
|
||||
ctl := gomock.NewController(tb)
|
||||
|
||||
mocks := &Mocks{
|
||||
TLSDialer: NewTestDialer(),
|
||||
ProxyDialer: mocks.NewMockProxyDialer(ctl),
|
||||
|
||||
TLSReporter: mocks.NewMockTLSReporter(ctl),
|
||||
TLSIssueCh: make(chan struct{}),
|
||||
|
||||
Updater: NewTestUpdater(version, minAuto),
|
||||
Autostarter: mocks.NewMockAutostarter(ctl),
|
||||
}
|
||||
|
||||
// When using the proxy dialer, we want to use the test dialer.
|
||||
mocks.ProxyDialer.EXPECT().DialTLSContext(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).DoAndReturn(func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return mocks.TLSDialer.DialTLSContext(ctx, network, address)
|
||||
}).AnyTimes()
|
||||
|
||||
// When getting the TLS issue channel, we want to return the test channel.
|
||||
mocks.TLSReporter.EXPECT().GetTLSIssueCh().Return(mocks.TLSIssueCh).AnyTimes()
|
||||
|
||||
return mocks
|
||||
}
|
||||
|
||||
type TestDialer struct {
|
||||
canDial bool
|
||||
}
|
||||
|
||||
func NewTestDialer() *TestDialer {
|
||||
return &TestDialer{
|
||||
canDial: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *TestDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
||||
if !d.canDial {
|
||||
return nil, errors.New("cannot dial")
|
||||
}
|
||||
|
||||
return (&tls.Dialer{Config: &tls.Config{InsecureSkipVerify: true}}).DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
func (d *TestDialer) SetCanDial(canDial bool) {
|
||||
d.canDial = canDial
|
||||
}
|
||||
|
||||
type TestLocationsProvider struct {
|
||||
config, cache string
|
||||
}
|
||||
|
||||
func NewTestLocationsProvider(tb testing.TB) *TestLocationsProvider {
|
||||
return &TestLocationsProvider{
|
||||
config: tb.TempDir(),
|
||||
cache: tb.TempDir(),
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *TestLocationsProvider) UserConfig() string {
|
||||
return provider.config
|
||||
}
|
||||
|
||||
func (provider *TestLocationsProvider) UserCache() string {
|
||||
return provider.cache
|
||||
}
|
||||
|
||||
type TestUpdater struct {
|
||||
latest updater.VersionInfo
|
||||
}
|
||||
|
||||
func NewTestUpdater(version, minAuto *semver.Version) *TestUpdater {
|
||||
return &TestUpdater{
|
||||
latest: updater.VersionInfo{
|
||||
Version: version,
|
||||
MinAuto: minAuto,
|
||||
|
||||
RolloutProportion: 1.0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (testUpdater *TestUpdater) SetLatestVersion(version, minAuto *semver.Version) {
|
||||
testUpdater.latest = updater.VersionInfo{
|
||||
Version: version,
|
||||
MinAuto: minAuto,
|
||||
|
||||
RolloutProportion: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
func (updater *TestUpdater) GetVersionInfo(downloader updater.Downloader, channel updater.Channel) (updater.VersionInfo, error) {
|
||||
return updater.latest, nil
|
||||
}
|
||||
|
||||
func (updater *TestUpdater) InstallUpdate(downloader updater.Downloader, update updater.VersionInfo) error {
|
||||
return nil
|
||||
}
|
||||
163
internal/bridge/mocks/mocks.go
Normal file
163
internal/bridge/mocks/mocks.go
Normal file
@ -0,0 +1,163 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/v2/internal/bridge (interfaces: TLSReporter,ProxyDialer,Autostarter)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockTLSReporter is a mock of TLSReporter interface.
|
||||
type MockTLSReporter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockTLSReporterMockRecorder
|
||||
}
|
||||
|
||||
// MockTLSReporterMockRecorder is the mock recorder for MockTLSReporter.
|
||||
type MockTLSReporterMockRecorder struct {
|
||||
mock *MockTLSReporter
|
||||
}
|
||||
|
||||
// NewMockTLSReporter creates a new mock instance.
|
||||
func NewMockTLSReporter(ctrl *gomock.Controller) *MockTLSReporter {
|
||||
mock := &MockTLSReporter{ctrl: ctrl}
|
||||
mock.recorder = &MockTLSReporterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockTLSReporter) EXPECT() *MockTLSReporterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetTLSIssueCh mocks base method.
|
||||
func (m *MockTLSReporter) GetTLSIssueCh() <-chan struct{} {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetTLSIssueCh")
|
||||
ret0, _ := ret[0].(<-chan struct{})
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetTLSIssueCh indicates an expected call of GetTLSIssueCh.
|
||||
func (mr *MockTLSReporterMockRecorder) GetTLSIssueCh() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTLSIssueCh", reflect.TypeOf((*MockTLSReporter)(nil).GetTLSIssueCh))
|
||||
}
|
||||
|
||||
// MockProxyDialer is a mock of ProxyDialer interface.
|
||||
type MockProxyDialer struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockProxyDialerMockRecorder
|
||||
}
|
||||
|
||||
// MockProxyDialerMockRecorder is the mock recorder for MockProxyDialer.
|
||||
type MockProxyDialerMockRecorder struct {
|
||||
mock *MockProxyDialer
|
||||
}
|
||||
|
||||
// NewMockProxyDialer creates a new mock instance.
|
||||
func NewMockProxyDialer(ctrl *gomock.Controller) *MockProxyDialer {
|
||||
mock := &MockProxyDialer{ctrl: ctrl}
|
||||
mock.recorder = &MockProxyDialerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockProxyDialer) EXPECT() *MockProxyDialerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AllowProxy mocks base method.
|
||||
func (m *MockProxyDialer) AllowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AllowProxy")
|
||||
}
|
||||
|
||||
// AllowProxy indicates an expected call of AllowProxy.
|
||||
func (mr *MockProxyDialerMockRecorder) AllowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockProxyDialer)(nil).AllowProxy))
|
||||
}
|
||||
|
||||
// DialTLSContext mocks base method.
|
||||
func (m *MockProxyDialer) DialTLSContext(arg0 context.Context, arg1, arg2 string) (net.Conn, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DialTLSContext", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(net.Conn)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DialTLSContext indicates an expected call of DialTLSContext.
|
||||
func (mr *MockProxyDialerMockRecorder) DialTLSContext(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialTLSContext", reflect.TypeOf((*MockProxyDialer)(nil).DialTLSContext), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// DisallowProxy mocks base method.
|
||||
func (m *MockProxyDialer) DisallowProxy() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DisallowProxy")
|
||||
}
|
||||
|
||||
// DisallowProxy indicates an expected call of DisallowProxy.
|
||||
func (mr *MockProxyDialerMockRecorder) DisallowProxy() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockProxyDialer)(nil).DisallowProxy))
|
||||
}
|
||||
|
||||
// MockAutostarter is a mock of Autostarter interface.
|
||||
type MockAutostarter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockAutostarterMockRecorder
|
||||
}
|
||||
|
||||
// MockAutostarterMockRecorder is the mock recorder for MockAutostarter.
|
||||
type MockAutostarterMockRecorder struct {
|
||||
mock *MockAutostarter
|
||||
}
|
||||
|
||||
// NewMockAutostarter creates a new mock instance.
|
||||
func NewMockAutostarter(ctrl *gomock.Controller) *MockAutostarter {
|
||||
mock := &MockAutostarter{ctrl: ctrl}
|
||||
mock.recorder = &MockAutostarterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockAutostarter) EXPECT() *MockAutostarterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Disable mocks base method.
|
||||
func (m *MockAutostarter) Disable() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Disable")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Disable indicates an expected call of Disable.
|
||||
func (mr *MockAutostarterMockRecorder) Disable() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disable", reflect.TypeOf((*MockAutostarter)(nil).Disable))
|
||||
}
|
||||
|
||||
// Enable mocks base method.
|
||||
func (m *MockAutostarter) Enable() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Enable")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Enable indicates an expected call of Enable.
|
||||
func (mr *MockAutostarterMockRecorder) Enable() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Enable", reflect.TypeOf((*MockAutostarter)(nil).Enable))
|
||||
}
|
||||
@ -1,26 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Code generated by ./release-notes.sh at 'Fri Jan 22 11:01:06 AM CET 2021'. DO NOT EDIT.
|
||||
|
||||
package bridge
|
||||
|
||||
const ReleaseNotes = `
|
||||
`
|
||||
|
||||
const ReleaseFixedBugs = `• Fixed sending error caused by inconsistent use of upper and lower case in sender’s email address
|
||||
`
|
||||
@ -1,44 +1,175 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package bridge
|
||||
|
||||
import "github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
import (
|
||||
"context"
|
||||
|
||||
func (b *Bridge) Get(key settings.Key) string {
|
||||
return b.settings.Get(key)
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
)
|
||||
|
||||
func (bridge *Bridge) GetKeychainApp() (string, error) {
|
||||
vaultDir, err := bridge.locator.ProvideSettingsPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return vault.GetHelper(vaultDir)
|
||||
}
|
||||
|
||||
func (b *Bridge) Set(key settings.Key, value string) {
|
||||
b.settings.Set(key, value)
|
||||
func (bridge *Bridge) SetKeychainApp(helper string) error {
|
||||
vaultDir, err := bridge.locator.ProvideSettingsPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return vault.SetHelper(vaultDir, helper)
|
||||
}
|
||||
|
||||
func (b *Bridge) GetBool(key settings.Key) bool {
|
||||
return b.settings.GetBool(key)
|
||||
func (bridge *Bridge) GetGluonDir() string {
|
||||
return bridge.vault.GetGluonDir()
|
||||
}
|
||||
|
||||
func (b *Bridge) SetBool(key settings.Key, value bool) {
|
||||
b.settings.SetBool(key, value)
|
||||
func (bridge *Bridge) SetGluonDir(ctx context.Context, newGluonDir string) error {
|
||||
if newGluonDir == bridge.GetGluonDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := bridge.closeIMAP(context.Background()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := moveDir(bridge.GetGluonDir(), newGluonDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := bridge.vault.SetGluonDir(newGluonDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
imapServer, err := newIMAPServer(bridge.vault.GetGluonDir(), bridge.curVersion, bridge.tlsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, user := range bridge.users {
|
||||
imapConn, err := user.NewGluonConnector(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := imapServer.LoadUser(context.Background(), imapConn, user.GluonID(), user.GluonKey()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
bridge.imapServer = imapServer
|
||||
|
||||
return bridge.serveIMAP()
|
||||
}
|
||||
|
||||
func (b *Bridge) GetInt(key settings.Key) int {
|
||||
return b.settings.GetInt(key)
|
||||
func (bridge *Bridge) GetProxyAllowed() bool {
|
||||
return bridge.vault.GetProxyAllowed()
|
||||
}
|
||||
|
||||
func (b *Bridge) SetInt(key settings.Key, value int) {
|
||||
b.settings.SetInt(key, value)
|
||||
func (bridge *Bridge) SetProxyAllowed(allowed bool) error {
|
||||
if allowed {
|
||||
bridge.proxyDialer.AllowProxy()
|
||||
} else {
|
||||
bridge.proxyDialer.DisallowProxy()
|
||||
}
|
||||
|
||||
return bridge.vault.SetProxyAllowed(allowed)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetShowAllMail() bool {
|
||||
return bridge.vault.GetShowAllMail()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetShowAllMail(show bool) error {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetAutostart() bool {
|
||||
return bridge.vault.GetAutostart()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetAutostart(autostart bool) error {
|
||||
if err := bridge.vault.SetAutostart(autostart); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
if autostart {
|
||||
err = bridge.autostarter.Enable()
|
||||
} else {
|
||||
err = bridge.autostarter.Disable()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetAutoUpdate() bool {
|
||||
return bridge.vault.GetAutoUpdate()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetAutoUpdate(autoUpdate bool) error {
|
||||
if bridge.vault.GetAutoUpdate() == autoUpdate {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := bridge.vault.SetAutoUpdate(autoUpdate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bridge.updateCheckCh <- struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetUpdateChannel() updater.Channel {
|
||||
return updater.Channel(bridge.vault.GetUpdateChannel())
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetUpdateChannel(channel updater.Channel) error {
|
||||
if bridge.vault.GetUpdateChannel() == channel {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := bridge.vault.SetUpdateChannel(channel); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bridge.updateCheckCh <- struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetLastVersion() *semver.Version {
|
||||
return bridge.vault.GetLastVersion()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetFirstStart() bool {
|
||||
return bridge.vault.GetFirstStart()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetFirstStart(firstStart bool) error {
|
||||
return bridge.vault.SetFirstStart(firstStart)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetFirstStartGUI() bool {
|
||||
return bridge.vault.GetFirstStartGUI()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetFirstStartGUI(firstStart bool) error {
|
||||
return bridge.vault.SetFirstStartGUI(firstStart)
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetColorScheme() string {
|
||||
return bridge.vault.GetColorScheme()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetColorScheme(colorScheme string) error {
|
||||
return bridge.vault.SetColorScheme(colorScheme)
|
||||
}
|
||||
|
||||
156
internal/bridge/settings_test.go
Normal file
156
internal/bridge/settings_test.go
Normal file
@ -0,0 +1,156 @@
|
||||
package bridge_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
)
|
||||
|
||||
func TestBridge_Settings_GluonDir(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Create a user.
|
||||
_, err := bridge.LoginUser(context.Background(), username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a new location for the Gluon data.
|
||||
newGluonDir := t.TempDir()
|
||||
|
||||
// Move the gluon dir; it should also move the user's data.
|
||||
require.NoError(t, bridge.SetGluonDir(context.Background(), newGluonDir))
|
||||
|
||||
// Check that the new directory is not empty.
|
||||
entries, err := os.ReadDir(newGluonDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// There should be at least one entry.
|
||||
require.NotEmpty(t, entries)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_Settings_IMAPPort(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, the port is 1143.
|
||||
require.Equal(t, 1143, bridge.GetIMAPPort())
|
||||
|
||||
// Set the port to 1144.
|
||||
require.NoError(t, bridge.SetIMAPPort(1144))
|
||||
|
||||
// Get the new setting.
|
||||
require.Equal(t, 1144, bridge.GetIMAPPort())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_Settings_IMAPSSL(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, IMAP SSL is disabled.
|
||||
require.False(t, bridge.GetIMAPSSL())
|
||||
|
||||
// Enable IMAP SSL.
|
||||
require.NoError(t, bridge.SetIMAPSSL(true))
|
||||
|
||||
// Get the new setting.
|
||||
require.True(t, bridge.GetIMAPSSL())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_Settings_SMTPPort(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, the port is 1025.
|
||||
require.Equal(t, 1025, bridge.GetSMTPPort())
|
||||
|
||||
// Set the port to 1024.
|
||||
require.NoError(t, bridge.SetSMTPPort(1024))
|
||||
|
||||
// Get the new setting.
|
||||
require.Equal(t, 1024, bridge.GetSMTPPort())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_Settings_SMTPSSL(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, SMTP SSL is disabled.
|
||||
require.False(t, bridge.GetSMTPSSL())
|
||||
|
||||
// Enable SMTP SSL.
|
||||
require.NoError(t, bridge.SetSMTPSSL(true))
|
||||
|
||||
// Get the new setting.
|
||||
require.True(t, bridge.GetSMTPSSL())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_Settings_Proxy(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, proxy is allowed.
|
||||
require.True(t, bridge.GetProxyAllowed())
|
||||
|
||||
// Disallow proxy.
|
||||
mocks.ProxyDialer.EXPECT().DisallowProxy()
|
||||
require.NoError(t, bridge.SetProxyAllowed(false))
|
||||
|
||||
// Get the new setting.
|
||||
require.False(t, bridge.GetProxyAllowed())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_Settings_Autostart(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, autostart is disabled.
|
||||
require.False(t, bridge.GetAutostart())
|
||||
|
||||
// Enable autostart.
|
||||
mocks.Autostarter.EXPECT().Enable().Return(nil)
|
||||
require.NoError(t, bridge.SetAutostart(true))
|
||||
|
||||
// Get the new setting.
|
||||
require.True(t, bridge.GetAutostart())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_Settings_FirstStart(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, first start is true.
|
||||
require.True(t, bridge.GetFirstStart())
|
||||
|
||||
// Set first start to false.
|
||||
require.NoError(t, bridge.SetFirstStart(false))
|
||||
|
||||
// Get the new setting.
|
||||
require.False(t, bridge.GetFirstStart())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_Settings_FirstStartGUI(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// By default, first start is true.
|
||||
require.True(t, bridge.GetFirstStartGUI())
|
||||
|
||||
// Set first start to false.
|
||||
require.NoError(t, bridge.SetFirstStartGUI(false))
|
||||
|
||||
// Get the new setting.
|
||||
require.False(t, bridge.GetFirstStartGUI())
|
||||
})
|
||||
})
|
||||
}
|
||||
109
internal/bridge/smtp.go
Normal file
109
internal/bridge/smtp.go
Normal file
@ -0,0 +1,109 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/emersion/go-sasl"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (bridge *Bridge) GetSMTPPort() int {
|
||||
return bridge.vault.GetSMTPPort()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetSMTPPort(newPort int) error {
|
||||
if newPort == bridge.vault.GetSMTPPort() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := bridge.vault.SetSMTPPort(newPort); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bridge.restartSMTP()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetSMTPSSL() bool {
|
||||
return bridge.vault.GetSMTPSSL()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetSMTPSSL(newSSL bool) error {
|
||||
if newSSL == bridge.vault.GetSMTPSSL() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := bridge.vault.SetSMTPSSL(newSSL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bridge.restartSMTP()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) serveSMTP() error {
|
||||
smtpListener, err := newListener(bridge.vault.GetSMTPPort(), bridge.vault.GetSMTPSSL(), bridge.tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create SMTP listener: %w", err)
|
||||
}
|
||||
|
||||
bridge.smtpListener = smtpListener
|
||||
|
||||
go func() {
|
||||
if err := bridge.smtpServer.Serve(bridge.smtpListener); err != nil {
|
||||
logrus.WithError(err).Error("SMTP server stopped")
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) restartSMTP() error {
|
||||
if err := bridge.closeSMTP(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
smtpServer, err := newSMTPServer(bridge.smtpBackend, bridge.tlsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bridge.smtpServer = smtpServer
|
||||
|
||||
return bridge.serveSMTP()
|
||||
}
|
||||
|
||||
func (bridge *Bridge) closeSMTP() error {
|
||||
if err := bridge.smtpServer.Close(); err != nil {
|
||||
logrus.WithError(err).Warn("Failed to close SMTP server")
|
||||
}
|
||||
|
||||
// Don't close the SMTP listener -- it's closed by the server.
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newSMTPServer(smtpBackend *smtpBackend, tlsConfig *tls.Config) (*smtp.Server, error) {
|
||||
smtpServer := smtp.NewServer(smtpBackend)
|
||||
|
||||
smtpServer.TLSConfig = tlsConfig
|
||||
smtpServer.Domain = constants.Host
|
||||
smtpServer.AllowInsecureAuth = true
|
||||
smtpServer.MaxLineLength = 1 << 16
|
||||
|
||||
smtpServer.EnableAuth(sasl.Login, func(conn *smtp.Conn) sasl.Server {
|
||||
return sasl.NewLoginServer(func(address, password string) error {
|
||||
user, err := conn.Server().Backend.Login(nil, address, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.SetSession(user)
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
return smtpServer, nil
|
||||
}
|
||||
70
internal/bridge/smtp_backend.go
Normal file
70
internal/bridge/smtp_backend.go
Normal file
@ -0,0 +1,70 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/emersion/go-smtp"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type smtpBackend struct {
|
||||
users []*user.User
|
||||
usersLock sync.RWMutex
|
||||
}
|
||||
|
||||
func newSMTPBackend() (*smtpBackend, error) {
|
||||
return &smtpBackend{}, nil
|
||||
}
|
||||
|
||||
func (backend *smtpBackend) Login(state *smtp.ConnectionState, username string, password string) (smtp.Session, error) {
|
||||
backend.usersLock.RLock()
|
||||
defer backend.usersLock.RUnlock()
|
||||
|
||||
for _, user := range backend.users {
|
||||
if slices.Contains(user.Addresses(), username) && user.BridgePass() == password {
|
||||
return user.NewSMTPSession(username)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrNoSuchUser
|
||||
}
|
||||
|
||||
func (backend *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
// addUser adds the given user to the backend.
|
||||
// It returns an error if a user with the same ID already exists.
|
||||
func (backend *smtpBackend) addUser(user *user.User) error {
|
||||
backend.usersLock.Lock()
|
||||
defer backend.usersLock.Unlock()
|
||||
|
||||
for _, u := range backend.users {
|
||||
if u.ID() == user.ID() {
|
||||
return ErrUserAlreadyExists
|
||||
}
|
||||
}
|
||||
|
||||
backend.users = append(backend.users, user)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeUser removes the given user from the backend.
|
||||
// It returns an error if the user doesn't exist.
|
||||
func (backend *smtpBackend) removeUser(user *user.User) error {
|
||||
backend.usersLock.Lock()
|
||||
defer backend.usersLock.Unlock()
|
||||
|
||||
idx := xslices.Index(backend.users, user)
|
||||
|
||||
if idx < 0 {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
backend.users = append(backend.users[:idx], backend.users[idx+1:]...)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,87 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/sentry"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/store"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/store/cache"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/users"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
|
||||
)
|
||||
|
||||
type storeFactory struct {
|
||||
cacheProvider CacheProvider
|
||||
sentryReporter *sentry.Reporter
|
||||
panicHandler users.PanicHandler
|
||||
eventListener listener.Listener
|
||||
events *store.Events
|
||||
cache cache.Cache
|
||||
builder *message.Builder
|
||||
}
|
||||
|
||||
func newStoreFactory(
|
||||
cacheProvider CacheProvider,
|
||||
sentryReporter *sentry.Reporter,
|
||||
panicHandler users.PanicHandler,
|
||||
eventListener listener.Listener,
|
||||
cache cache.Cache,
|
||||
builder *message.Builder,
|
||||
) *storeFactory {
|
||||
return &storeFactory{
|
||||
cacheProvider: cacheProvider,
|
||||
sentryReporter: sentryReporter,
|
||||
panicHandler: panicHandler,
|
||||
eventListener: eventListener,
|
||||
events: store.NewEvents(cacheProvider.GetIMAPCachePath()),
|
||||
cache: cache,
|
||||
builder: builder,
|
||||
}
|
||||
}
|
||||
|
||||
// New creates new store for given user.
|
||||
func (f *storeFactory) New(user store.BridgeUser) (*store.Store, error) {
|
||||
return store.New(
|
||||
f.sentryReporter,
|
||||
f.panicHandler,
|
||||
user,
|
||||
f.eventListener,
|
||||
f.cache,
|
||||
f.builder,
|
||||
getUserStorePath(f.cacheProvider.GetDBDir(), user.ID()),
|
||||
f.events,
|
||||
)
|
||||
}
|
||||
|
||||
// Remove removes all store files for given user.
|
||||
func (f *storeFactory) Remove(userID string) error {
|
||||
return store.RemoveStore(
|
||||
f.events,
|
||||
getUserStorePath(f.cacheProvider.GetDBDir(), userID),
|
||||
userID,
|
||||
)
|
||||
}
|
||||
|
||||
// getUserStorePath returns the file path of the store database for the given userID.
|
||||
func getUserStorePath(storeDir string, userID string) (path string) {
|
||||
return filepath.Join(storeDir, fmt.Sprintf("mailbox-%v.db", userID))
|
||||
}
|
||||
@ -1,64 +1,5 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
pkgTLS "github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
|
||||
"github.com/pkg/errors"
|
||||
logrus "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (b *Bridge) GetTLSConfig() (*tls.Config, error) {
|
||||
if !b.tls.HasCerts() {
|
||||
if err := b.generateTLSCerts(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
tlsConfig, err := b.tls.GetConfig()
|
||||
if err == nil {
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
logrus.WithError(err).Error("Failed to load TLS config, regenerating certificates")
|
||||
|
||||
if err := b.generateTLSCerts(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.tls.GetConfig()
|
||||
}
|
||||
|
||||
func (b *Bridge) generateTLSCerts() error {
|
||||
template, err := pkgTLS.NewTLSTemplate()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to generate TLS template")
|
||||
}
|
||||
|
||||
if err := b.tls.GenerateCerts(template); err != nil {
|
||||
return errors.Wrap(err, "failed to generate TLS certs")
|
||||
}
|
||||
|
||||
if err := b.tls.InstallCerts(); err != nil {
|
||||
return errors.Wrap(err, "failed to install TLS certs")
|
||||
}
|
||||
|
||||
return nil
|
||||
func (bridge *Bridge) GetBridgeTLSCert() ([]byte, []byte) {
|
||||
return bridge.vault.GetBridgeTLSCert(), bridge.vault.GetBridgeTLSKey()
|
||||
}
|
||||
|
||||
@ -1,62 +1,43 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
)
|
||||
|
||||
type Locator interface {
|
||||
ProvideSettingsPath() (string, error)
|
||||
ProvideLogsPath() (string, error)
|
||||
|
||||
GetLicenseFilePath() string
|
||||
GetDependencyLicensesLink() string
|
||||
|
||||
Clear() error
|
||||
ClearUpdates() error
|
||||
}
|
||||
|
||||
type CacheProvider interface {
|
||||
GetIMAPCachePath() string
|
||||
GetDBDir() string
|
||||
GetDefaultMessageCacheDir() string
|
||||
type Identifier interface {
|
||||
GetUserAgent() string
|
||||
HasClient() bool
|
||||
SetClient(name, version string)
|
||||
SetPlatform(platform string)
|
||||
}
|
||||
|
||||
type SettingsProvider interface {
|
||||
Get(key settings.Key) string
|
||||
Set(key settings.Key, value string)
|
||||
type TLSReporter interface {
|
||||
GetTLSIssueCh() <-chan struct{}
|
||||
}
|
||||
|
||||
GetBool(key settings.Key) bool
|
||||
SetBool(key settings.Key, val bool)
|
||||
type ProxyDialer interface {
|
||||
DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
GetInt(key settings.Key) int
|
||||
SetInt(key settings.Key, val int)
|
||||
AllowProxy()
|
||||
DisallowProxy()
|
||||
}
|
||||
|
||||
type Autostarter interface {
|
||||
Enable() error
|
||||
Disable() error
|
||||
}
|
||||
|
||||
type Updater interface {
|
||||
Check() (updater.VersionInfo, error)
|
||||
IsDowngrade(updater.VersionInfo) bool
|
||||
InstallUpdate(updater.VersionInfo) error
|
||||
}
|
||||
|
||||
type Versioner interface {
|
||||
RemoveOtherVersions(*semver.Version) error
|
||||
GetVersionInfo(downloader updater.Downloader, channel updater.Channel) (updater.VersionInfo, error)
|
||||
InstallUpdate(downloader updater.Downloader, update updater.VersionInfo) error
|
||||
}
|
||||
|
||||
72
internal/bridge/updates.go
Normal file
72
internal/bridge/updates.go
Normal file
@ -0,0 +1,72 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
)
|
||||
|
||||
func (bridge *Bridge) CheckForUpdates() {
|
||||
bridge.updateCheckCh <- struct{}{}
|
||||
}
|
||||
|
||||
func (bridge *Bridge) watchForUpdates() error {
|
||||
ticker := time.NewTicker(constants.UpdateCheckInterval)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-bridge.updateCheckCh:
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
version, err := bridge.updater.GetVersionInfo(bridge.api, bridge.vault.GetUpdateChannel())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := bridge.handleUpdate(version); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
bridge.updateCheckCh <- struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) handleUpdate(version updater.VersionInfo) error {
|
||||
switch {
|
||||
case !version.Version.GreaterThan(bridge.curVersion):
|
||||
bridge.publish(events.UpdateNotAvailable{})
|
||||
|
||||
case version.RolloutProportion < bridge.vault.GetUpdateRollout():
|
||||
bridge.publish(events.UpdateNotAvailable{})
|
||||
|
||||
case bridge.curVersion.LessThan(version.MinAuto):
|
||||
bridge.publish(events.UpdateAvailable{
|
||||
Version: version,
|
||||
CanInstall: false,
|
||||
})
|
||||
|
||||
case !bridge.vault.GetAutoUpdate():
|
||||
bridge.publish(events.UpdateAvailable{
|
||||
Version: version,
|
||||
CanInstall: true,
|
||||
})
|
||||
|
||||
default:
|
||||
if err := bridge.updater.InstallUpdate(bridge.api, version); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bridge.publish(events.UpdateInstalled{
|
||||
Version: version,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,26 +1,9 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package bridge
|
||||
|
||||
func (b *Bridge) GetCurrentUserAgent() string {
|
||||
return b.userAgent.String()
|
||||
func (bridge *Bridge) GetCurrentUserAgent() string {
|
||||
return bridge.identifier.GetUserAgent()
|
||||
}
|
||||
|
||||
func (b *Bridge) SetCurrentPlatform(platform string) {
|
||||
b.userAgent.SetPlatform(platform)
|
||||
func (bridge *Bridge) SetCurrentPlatform(platform string) {
|
||||
bridge.identifier.SetPlatform(platform)
|
||||
}
|
||||
|
||||
434
internal/bridge/users.go
Normal file
434
internal/bridge/users.go
Normal file
@ -0,0 +1,434 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type UserInfo struct {
|
||||
// UserID is the user's API ID.
|
||||
UserID string
|
||||
|
||||
// Username is the user's API username.
|
||||
Username string
|
||||
|
||||
// Connected is true if the user is logged in (has API auth).
|
||||
Connected bool
|
||||
|
||||
// Addresses holds the user's email addresses. The first address is the primary address.
|
||||
Addresses []string
|
||||
|
||||
// AddressMode is the user's address mode.
|
||||
AddressMode AddressMode
|
||||
|
||||
// BridgePass is the user's bridge password.
|
||||
BridgePass string
|
||||
|
||||
// UsedSpace is the amount of space used by the user.
|
||||
UsedSpace int
|
||||
|
||||
// MaxSpace is the total amount of space available to the user.
|
||||
MaxSpace int
|
||||
}
|
||||
|
||||
type AddressMode int
|
||||
|
||||
const (
|
||||
SplitMode AddressMode = iota
|
||||
CombinedMode
|
||||
)
|
||||
|
||||
// GetUserIDs returns the IDs of all known users (authorized or not).
|
||||
func (bridge *Bridge) GetUserIDs() []string {
|
||||
return bridge.vault.GetUserIDs()
|
||||
}
|
||||
|
||||
// GetUserInfo returns info about the given user.
|
||||
func (bridge *Bridge) GetUserInfo(userID string) (UserInfo, error) {
|
||||
vaultUser, err := bridge.vault.GetUser(userID)
|
||||
if err != nil {
|
||||
return UserInfo{}, err
|
||||
}
|
||||
|
||||
user, ok := bridge.users[userID]
|
||||
if !ok {
|
||||
return getUserInfo(vaultUser.UserID(), vaultUser.Username()), nil
|
||||
}
|
||||
|
||||
return getConnUserInfo(user), nil
|
||||
}
|
||||
|
||||
// QueryUserInfo queries the user info by username or address.
|
||||
func (bridge *Bridge) QueryUserInfo(query string) (UserInfo, error) {
|
||||
for userID, user := range bridge.users {
|
||||
if user.Match(query) {
|
||||
return bridge.GetUserInfo(userID)
|
||||
}
|
||||
}
|
||||
|
||||
return UserInfo{}, ErrNoSuchUser
|
||||
}
|
||||
|
||||
// LoginUser authorizes a new bridge user with the given username and password.
|
||||
// If necessary, a TOTP and mailbox password are requested via the callbacks.
|
||||
func (bridge *Bridge) LoginUser(
|
||||
ctx context.Context,
|
||||
username, password string,
|
||||
getTOTP func() (string, error),
|
||||
getKeyPass func() ([]byte, error),
|
||||
) (string, error) {
|
||||
client, auth, err := bridge.api.NewClientWithLogin(ctx, username, password)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if auth.TwoFA.Enabled == liteapi.TOTPEnabled {
|
||||
totp, err := getTOTP()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := client.Auth2FA(ctx, liteapi.Auth2FAReq{TwoFactorCode: totp}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
var keyPass []byte
|
||||
|
||||
if auth.PasswordMode == liteapi.TwoPasswordMode {
|
||||
pass, err := getKeyPass()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
keyPass = pass
|
||||
} else {
|
||||
keyPass = []byte(password)
|
||||
}
|
||||
|
||||
apiUser, apiAddrs, userKR, addrKRs, saltedKeyPass, err := client.Unlock(ctx, keyPass)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := bridge.addUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, auth.UID, auth.RefreshToken, saltedKeyPass); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return apiUser.ID, nil
|
||||
}
|
||||
|
||||
// LogoutUser logs out the given user.
|
||||
func (bridge *Bridge) LogoutUser(ctx context.Context, userID string) error {
|
||||
return bridge.logoutUser(ctx, userID, true, false)
|
||||
}
|
||||
|
||||
// DeleteUser deletes the given user.
|
||||
// If it is authorized, it is logged out first.
|
||||
func (bridge *Bridge) DeleteUser(ctx context.Context, userID string) error {
|
||||
if bridge.users[userID] != nil {
|
||||
if err := bridge.logoutUser(ctx, userID, true, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := bridge.vault.DeleteUser(userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bridge.publish(events.UserDeleted{
|
||||
UserID: userID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) GetAddressMode(userID string) (AddressMode, error) {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
func (bridge *Bridge) SetAddressMode(userID string, mode AddressMode) error {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
// loadUsers loads authorized users from the vault.
|
||||
func (bridge *Bridge) loadUsers(ctx context.Context) error {
|
||||
for _, userID := range bridge.vault.GetUserIDs() {
|
||||
user, err := bridge.vault.GetUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AuthUID() == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := bridge.loadUser(ctx, user); err != nil {
|
||||
logrus.WithError(err).Error("Failed to load connected user")
|
||||
|
||||
if err := user.Clear(); err != nil {
|
||||
logrus.WithError(err).Error("Failed to clear user")
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) loadUser(ctx context.Context, user *vault.User) error {
|
||||
client, auth, err := bridge.api.NewClientWithRefresh(ctx, user.AuthUID(), user.AuthRef())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create API client: %w", err)
|
||||
}
|
||||
|
||||
apiUser, apiAddrs, userKR, addrKRs, err := client.UnlockSalted(ctx, user.KeyPass())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock user: %w", err)
|
||||
}
|
||||
|
||||
if err := bridge.addUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, auth.UID, auth.RefreshToken, user.KeyPass()); err != nil {
|
||||
return fmt.Errorf("failed to add user: %w", err)
|
||||
}
|
||||
|
||||
bridge.publish(events.UserLoggedIn{
|
||||
UserID: user.UserID(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addUser adds a new user with an already salted mailbox password.
|
||||
func (bridge *Bridge) addUser(
|
||||
ctx context.Context,
|
||||
client *liteapi.Client,
|
||||
apiUser liteapi.User,
|
||||
apiAddrs []liteapi.Address,
|
||||
userKR *crypto.KeyRing,
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
authUID, authRef string,
|
||||
saltedKeyPass []byte,
|
||||
) error {
|
||||
if _, ok := bridge.users[apiUser.ID]; ok {
|
||||
return ErrUserAlreadyLoggedIn
|
||||
}
|
||||
|
||||
var user *user.User
|
||||
|
||||
if slices.Contains(bridge.vault.GetUserIDs(), apiUser.ID) {
|
||||
existingUser, err := bridge.addExistingUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user = existingUser
|
||||
} else {
|
||||
newUser, err := bridge.addNewUser(ctx, client, apiUser, apiAddrs, userKR, addrKRs, authUID, authRef, saltedKeyPass)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user = newUser
|
||||
}
|
||||
|
||||
go func() {
|
||||
for event := range user.GetNotifyCh() {
|
||||
switch event := event.(type) {
|
||||
case events.UserDeauth:
|
||||
if err := bridge.logoutUser(context.Background(), event.UserID, false, false); err != nil {
|
||||
logrus.WithError(err).Error("Failed to logout user")
|
||||
}
|
||||
}
|
||||
|
||||
bridge.publish(event)
|
||||
}
|
||||
}()
|
||||
|
||||
// Gluon will set the IMAP ID in the context, if known, before making requests on behalf of this user.
|
||||
client.AddPreRequestHook(func(ctx context.Context, req *resty.Request) error {
|
||||
if imapID, ok := imap.GetIMAPIDFromContext(ctx); ok {
|
||||
bridge.identifier.SetClient(imapID.Name, imapID.Version)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
bridge.publish(events.UserLoggedIn{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) addNewUser(
|
||||
ctx context.Context,
|
||||
client *liteapi.Client,
|
||||
apiUser liteapi.User,
|
||||
apiAddrs []liteapi.Address,
|
||||
userKR *crypto.KeyRing,
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
authUID, authRef string,
|
||||
saltedKeyPass []byte,
|
||||
) (*user.User, error) {
|
||||
vaultUser, err := bridge.vault.AddUser(apiUser.ID, apiUser.Name, authUID, authRef, saltedKeyPass)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := user.New(ctx, vaultUser, client, apiUser, apiAddrs, userKR, addrKRs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gluonKey, err := crypto.RandomToken(32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
imapConn, err := user.NewGluonConnector(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gluonID, err := bridge.imapServer.AddUser(ctx, imapConn, gluonKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := vaultUser.UpdateGluonData(gluonID, gluonKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := bridge.smtpBackend.addUser(user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bridge.users[apiUser.ID] = user
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (bridge *Bridge) addExistingUser(
|
||||
ctx context.Context,
|
||||
client *liteapi.Client,
|
||||
apiUser liteapi.User,
|
||||
apiAddrs []liteapi.Address,
|
||||
userKR *crypto.KeyRing,
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
authUID, authRef string,
|
||||
saltedKeyPass []byte,
|
||||
) (*user.User, error) {
|
||||
vaultUser, err := bridge.vault.GetUser(apiUser.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := vaultUser.UpdateAuth(authUID, authRef); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := vaultUser.UpdateKeyPass(saltedKeyPass); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := user.New(ctx, vaultUser, client, apiUser, apiAddrs, userKR, addrKRs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
imapConn, err := user.NewGluonConnector(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := bridge.imapServer.LoadUser(ctx, imapConn, user.GluonID(), user.GluonKey()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := bridge.smtpBackend.addUser(user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bridge.users[apiUser.ID] = user
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// logoutUser closes and removes the user with the given ID.
|
||||
// If withAPI is true, the user will additionally be logged out from API.
|
||||
// If withFiles is true, the user's files will be deleted.
|
||||
func (bridge *Bridge) logoutUser(ctx context.Context, userID string, withAPI, withFiles bool) error {
|
||||
user, ok := bridge.users[userID]
|
||||
if !ok {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
vaultUser, err := bridge.vault.GetUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := bridge.imapServer.RemoveUser(ctx, vaultUser.GluonID(), withFiles); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := bridge.smtpBackend.removeUser(user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if withAPI {
|
||||
if err := user.Logout(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := user.Close(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := vaultUser.Clear(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(bridge.users, userID)
|
||||
|
||||
bridge.publish(events.UserLoggedOut{
|
||||
UserID: userID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getUserInfo returns information about a disconnected user.
|
||||
func getUserInfo(userID, username string) UserInfo {
|
||||
return UserInfo{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
AddressMode: CombinedMode,
|
||||
}
|
||||
}
|
||||
|
||||
// getConnUserInfo returns information about a connected user.
|
||||
func getConnUserInfo(user *user.User) UserInfo {
|
||||
return UserInfo{
|
||||
Connected: true,
|
||||
UserID: user.ID(),
|
||||
Username: user.Name(),
|
||||
Addresses: user.Addresses(),
|
||||
AddressMode: CombinedMode,
|
||||
BridgePass: user.BridgePass(),
|
||||
UsedSpace: user.UsedSpace(),
|
||||
MaxSpace: user.MaxSpace(),
|
||||
}
|
||||
}
|
||||
286
internal/bridge/users_test.go
Normal file
286
internal/bridge/users_test.go
Normal file
@ -0,0 +1,286 @@
|
||||
package bridge_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
)
|
||||
|
||||
func TestBridge_WithoutUsers(t *testing.T) {
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_Login(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID, err := bridge.LoginUser(ctx, username, password, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The user is now connected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_LoginLogoutLogin(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
// The user is now connected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
|
||||
// Logout the user.
|
||||
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||
|
||||
// The user is now disconnected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
|
||||
// Login the user again.
|
||||
newUserID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
require.Equal(t, userID, newUserID)
|
||||
|
||||
// The user is connected again.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_LoginDeleteLogin(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
// The user is now connected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
|
||||
// Delete the user.
|
||||
require.NoError(t, bridge.DeleteUser(ctx, userID))
|
||||
|
||||
// The user is now gone.
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
|
||||
// Login the user again.
|
||||
newUserID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
require.Equal(t, userID, newUserID)
|
||||
|
||||
// The user is connected again.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_LoginDeauthLogin(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
// Get a channel to receive the deauth event.
|
||||
eventCh, done := bridge.GetEvents(events.UserDeauth{})
|
||||
defer done()
|
||||
|
||||
// Deauth the user.
|
||||
require.NoError(t, s.RevokeUser(userID))
|
||||
|
||||
// The user is eventually disconnected.
|
||||
require.Eventually(t, func() bool {
|
||||
return len(getConnectedUserIDs(t, bridge)) == 0
|
||||
}, 10*time.Second, time.Second)
|
||||
|
||||
// We should get a deauth event.
|
||||
require.IsType(t, events.UserDeauth{}, <-eventCh)
|
||||
|
||||
// Login the user after the disconnection.
|
||||
newUserID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
require.Equal(t, userID, newUserID)
|
||||
|
||||
// The user is connected again.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_LoginExpireLogin(t *testing.T) {
|
||||
const authLife = 2 * time.Second
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
s.SetAuthLife(authLife)
|
||||
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user. Its auth will only be valid for a short time.
|
||||
userID := must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
// Wait until the auth expires.
|
||||
time.Sleep(authLife)
|
||||
|
||||
// The user will have to refresh but the logout will still succeed.
|
||||
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_FailToLoad(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
})
|
||||
|
||||
// Deauth the user while bridge is stopped.
|
||||
require.NoError(t, s.RevokeUser(userID))
|
||||
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// The user is disconnected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_LoginRestart(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
})
|
||||
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// The user is still connected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_LoginLogoutRestart(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
// Logout the user.
|
||||
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||
})
|
||||
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// The user is still disconnected.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_LoginDeleteRestart(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
var userID string
|
||||
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
// Delete the user.
|
||||
require.NoError(t, bridge.DeleteUser(ctx, userID))
|
||||
})
|
||||
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// The user is still gone.
|
||||
require.Empty(t, bridge.GetUserIDs())
|
||||
require.Empty(t, getConnectedUserIDs(t, bridge))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestBridge_BridgePass(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
withEnv(t, func(s *server.Server, locator bridge.Locator, storeKey []byte) {
|
||||
var userID, pass string
|
||||
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// Login the user.
|
||||
userID = must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
// Retrieve the bridge pass.
|
||||
pass = must(bridge.GetUserInfo(userID)).BridgePass
|
||||
|
||||
// Log the user out.
|
||||
require.NoError(t, bridge.LogoutUser(ctx, userID))
|
||||
|
||||
// Log the user back in.
|
||||
must(bridge.LoginUser(ctx, username, password, nil, nil))
|
||||
|
||||
// The bridge pass should be the same.
|
||||
require.Equal(t, pass, pass)
|
||||
})
|
||||
|
||||
withBridge(t, s.GetHostURL(), locator, storeKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
|
||||
// The bridge should load schizofrenic.
|
||||
require.Equal(t, []string{userID}, bridge.GetUserIDs())
|
||||
require.Equal(t, []string{userID}, getConnectedUserIDs(t, bridge))
|
||||
|
||||
// The bridge pass should be the same.
|
||||
require.Equal(t, pass, must(bridge.GetUserInfo(userID)).BridgePass)
|
||||
})
|
||||
})
|
||||
}
|
||||
@ -15,9 +15,31 @@
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package tls
|
||||
package certs
|
||||
|
||||
import "golang.org/x/sys/execabs"
|
||||
import (
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/execabs"
|
||||
)
|
||||
|
||||
func installCert(certPEM []byte) error {
|
||||
name, err := writeToTempFile(certPEM)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return addTrustedCert(name)
|
||||
}
|
||||
|
||||
func uninstallCert(certPEM []byte) error {
|
||||
name, err := writeToTempFile(certPEM)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return removeTrustedCert(name)
|
||||
}
|
||||
|
||||
func addTrustedCert(certPath string) error {
|
||||
return execabs.Command( //nolint:gosec
|
||||
@ -44,10 +66,20 @@ func removeTrustedCert(certPath string) error {
|
||||
).Run()
|
||||
}
|
||||
|
||||
func (t *TLS) InstallCerts() error {
|
||||
return addTrustedCert(t.getTLSCertPath())
|
||||
}
|
||||
// writeToTempFile writes the given data to a temporary file and returns the path.
|
||||
func writeToTempFile(data []byte) (string, error) {
|
||||
f, err := os.CreateTemp("", "tls")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
func (t *TLS) UninstallCerts() error {
|
||||
return removeTrustedCert(t.getTLSCertPath())
|
||||
if _, err := f.Write(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
}
|
||||
@ -15,12 +15,12 @@
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package tls
|
||||
package certs
|
||||
|
||||
func (t *TLS) InstallCerts() error {
|
||||
func installCert([]byte) error {
|
||||
return nil // Linux doesn't have a root cert store.
|
||||
}
|
||||
|
||||
func (t *TLS) UninstallCerts() error {
|
||||
func uninstallCert([]byte) error {
|
||||
return nil // Linux doesn't have a root cert store.
|
||||
}
|
||||
@ -15,12 +15,12 @@
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package tls
|
||||
package certs
|
||||
|
||||
func (t *TLS) InstallCerts() error {
|
||||
func installCert([]byte) error {
|
||||
return nil // NOTE(GODT-986): Install certs to root cert store?
|
||||
}
|
||||
|
||||
func (t *TLS) UninstallCerts() error {
|
||||
func uninstallCert([]byte) error {
|
||||
return nil // NOTE(GODT-986): Uninstall certs from root cert store?
|
||||
}
|
||||
15
internal/certs/installer.go
Normal file
15
internal/certs/installer.go
Normal file
@ -0,0 +1,15 @@
|
||||
package certs
|
||||
|
||||
type Installer struct{}
|
||||
|
||||
func NewInstaller() *Installer {
|
||||
return &Installer{}
|
||||
}
|
||||
|
||||
func (installer *Installer) InstallCert(certPEM []byte) error {
|
||||
return installCert(certPEM)
|
||||
}
|
||||
|
||||
func (installer *Installer) UninstallCert(certPEM []byte) error {
|
||||
return uninstallCert(certPEM)
|
||||
}
|
||||
@ -15,9 +15,10 @@
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package tls
|
||||
package certs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
@ -27,22 +28,13 @@ import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type TLS struct {
|
||||
settingsPath string
|
||||
}
|
||||
|
||||
func New(settingsPath string) *TLS {
|
||||
return &TLS{
|
||||
settingsPath: settingsPath,
|
||||
}
|
||||
}
|
||||
// ErrTLSCertExpiresSoon is returned when the TLS certificate is about to expire.
|
||||
var ErrTLSCertExpiresSoon = fmt.Errorf("TLS certificate will expire soon")
|
||||
|
||||
// NewTLSTemplate creates a new TLS template certificate with a random serial number.
|
||||
func NewTLSTemplate() (*x509.Certificate, error) {
|
||||
@ -69,108 +61,40 @@ func NewTLSTemplate() (*x509.Certificate, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewPEMKeyPair return a new TLS private key and certificate in PEM encoded format.
|
||||
func NewPEMKeyPair() (pemCert, pemKey []byte, err error) {
|
||||
template, err := NewTLSTemplate()
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to generate TLS template")
|
||||
}
|
||||
|
||||
// GenerateTLSCert generates a new TLS certificate and returns it as PEM.
|
||||
var GenerateCert = func(template *x509.Certificate) ([]byte, []byte, error) {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to generate private key")
|
||||
}
|
||||
|
||||
pemKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to create certificate")
|
||||
}
|
||||
|
||||
pemCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
certPEM := new(bytes.Buffer)
|
||||
|
||||
return pemCert, pemKey, nil
|
||||
}
|
||||
|
||||
var ErrTLSCertExpiresSoon = fmt.Errorf("TLS certificate will expire soon")
|
||||
|
||||
// getTLSCertPath returns path to certificate; used for TLS servers (IMAP, SMTP).
|
||||
func (t *TLS) getTLSCertPath() string {
|
||||
return filepath.Join(t.settingsPath, "cert.pem")
|
||||
}
|
||||
|
||||
// getTLSKeyPath returns path to private key; used for TLS servers (IMAP, SMTP).
|
||||
func (t *TLS) getTLSKeyPath() string {
|
||||
return filepath.Join(t.settingsPath, "key.pem")
|
||||
}
|
||||
|
||||
// HasCerts returns whether TLS certs have been generated.
|
||||
func (t *TLS) HasCerts() bool {
|
||||
if _, err := os.Stat(t.getTLSCertPath()); err != nil {
|
||||
return false
|
||||
if err := pem.Encode(certPEM, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if _, err := os.Stat(t.getTLSKeyPath()); err != nil {
|
||||
return false
|
||||
keyPEM := new(bytes.Buffer)
|
||||
|
||||
if err := pem.Encode(keyPEM, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GenerateCerts generates certs from the given template.
|
||||
func (t *TLS) GenerateCerts(template *x509.Certificate) error {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to generate private key")
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create certificate")
|
||||
}
|
||||
|
||||
certOut, err := os.Create(t.getTLSCertPath())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer certOut.Close() //nolint:errcheck,gosec
|
||||
|
||||
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyOut, err := os.OpenFile(t.getTLSKeyPath(), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer keyOut.Close() //nolint:errcheck,gosec
|
||||
|
||||
return pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
|
||||
return certPEM.Bytes(), keyPEM.Bytes(), nil
|
||||
}
|
||||
|
||||
// GetConfig tries to load TLS config or generate new one which is then returned.
|
||||
func (t *TLS) GetConfig() (*tls.Config, error) {
|
||||
c, err := tls.LoadX509KeyPair(t.getTLSCertPath(), t.getTLSKeyPath())
|
||||
func GetConfig(certPEM, keyPEM []byte) (*tls.Config, error) {
|
||||
c, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load keypair")
|
||||
}
|
||||
|
||||
return getConfigFromKeyPair(c)
|
||||
}
|
||||
|
||||
// GetConfigFromPEMKeyPair load a TLS config from PEM encoded certificate and key.
|
||||
func GetConfigFromPEMKeyPair(permCert, pemKey []byte) (*tls.Config, error) {
|
||||
c, err := tls.X509KeyPair(permCert, pemKey)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to load keypair")
|
||||
}
|
||||
|
||||
return getConfigFromKeyPair(c)
|
||||
}
|
||||
|
||||
func getConfigFromKeyPair(c tls.Certificate) (*tls.Config, error) {
|
||||
var err error
|
||||
c.Leaf, err = x509.ParseCertificate(c.Certificate[0])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse certificate")
|
||||
@ -15,10 +15,10 @@
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package tls
|
||||
package certs
|
||||
|
||||
import (
|
||||
"os"
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -26,12 +26,6 @@ import (
|
||||
)
|
||||
|
||||
func TestGetOldConfig(t *testing.T) {
|
||||
dir, err := os.MkdirTemp("", "test-tls")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create new tls object.
|
||||
tls := New(dir)
|
||||
|
||||
// Create new TLS template.
|
||||
tlsTemplate, err := NewTLSTemplate()
|
||||
require.NoError(t, err)
|
||||
@ -41,20 +35,15 @@ func TestGetOldConfig(t *testing.T) {
|
||||
tlsTemplate.NotAfter = time.Now()
|
||||
|
||||
// Generate the certs from the template.
|
||||
require.NoError(t, tls.GenerateCerts(tlsTemplate))
|
||||
certPEM, keyPEM, err := GenerateCert(tlsTemplate)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate the config from the certs -- it's going to expire soon so we don't want to use it.
|
||||
_, err = tls.GetConfig()
|
||||
_, err = GetConfig(certPEM, keyPEM)
|
||||
require.Equal(t, err, ErrTLSCertExpiresSoon)
|
||||
}
|
||||
|
||||
func TestGetValidConfig(t *testing.T) {
|
||||
dir, err := os.MkdirTemp("", "test-tls")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create new tls object.
|
||||
tls := New(dir)
|
||||
|
||||
// Create new TLS template.
|
||||
tlsTemplate, err := NewTLSTemplate()
|
||||
require.NoError(t, err)
|
||||
@ -64,10 +53,11 @@ func TestGetValidConfig(t *testing.T) {
|
||||
tlsTemplate.NotAfter = time.Now().Add(2 * 365 * 24 * time.Hour)
|
||||
|
||||
// Generate the certs from the template.
|
||||
require.NoError(t, tls.GenerateCerts(tlsTemplate))
|
||||
certPEM, keyPEM, err := GenerateCert(tlsTemplate)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate the config from the certs -- it's not going to expire soon so we want to use it.
|
||||
config, err := tls.GetConfig()
|
||||
config, err := GetConfig(certPEM, keyPEM)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(config.Certificates), 1)
|
||||
|
||||
@ -77,9 +67,13 @@ func TestGetValidConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewConfig(t *testing.T) {
|
||||
pemCert, pemKey, err := NewPEMKeyPair()
|
||||
tlsTemplate, err := NewTLSTemplate()
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = GetConfigFromPEMKeyPair(pemCert, pemKey)
|
||||
pemCert, pemKey, err := GenerateCert(tlsTemplate)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert, err := tls.X509KeyPair(pemCert, pemKey)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cert)
|
||||
}
|
||||
@ -23,7 +23,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/useragent"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/mobileconfig"
|
||||
"golang.org/x/sys/execabs"
|
||||
)
|
||||
|
||||
70
internal/config/cache/cache.go
vendored
70
internal/config/cache/cache.go
vendored
@ -1,70 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Package cache provides access to contents inside a cache directory.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/files"
|
||||
)
|
||||
|
||||
type Cache struct {
|
||||
dir, version string
|
||||
}
|
||||
|
||||
func New(dir, version string) (*Cache, error) {
|
||||
if err := os.MkdirAll(filepath.Join(dir, version), 0o700); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Cache{
|
||||
dir: dir,
|
||||
version: version,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetDBDir returns folder for db files.
|
||||
func (c *Cache) GetDBDir() string {
|
||||
return c.getCurrentCacheDir()
|
||||
}
|
||||
|
||||
// GetDefaultMessageCacheDir returns folder for cached messages files.
|
||||
func (c *Cache) GetDefaultMessageCacheDir() string {
|
||||
return filepath.Join(c.getCurrentCacheDir(), "messages")
|
||||
}
|
||||
|
||||
// GetIMAPCachePath returns path to file with IMAP status.
|
||||
func (c *Cache) GetIMAPCachePath() string {
|
||||
return filepath.Join(c.getCurrentCacheDir(), "user_info.json")
|
||||
}
|
||||
|
||||
// GetTransferDir returns folder for import-export rules files.
|
||||
func (c *Cache) GetTransferDir() string {
|
||||
return c.getCurrentCacheDir()
|
||||
}
|
||||
|
||||
// RemoveOldVersions removes any cache dirs that are not the current version.
|
||||
func (c *Cache) RemoveOldVersions() error {
|
||||
return files.Remove(c.dir).Except(c.getCurrentCacheDir()).Do()
|
||||
}
|
||||
|
||||
func (c *Cache) getCurrentCacheDir() string {
|
||||
return filepath.Join(c.dir, c.version)
|
||||
}
|
||||
69
internal/config/cache/cache_test.go
vendored
69
internal/config/cache/cache_test.go
vendored
@ -1,69 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRemoveOldVersions(t *testing.T) {
|
||||
dir, err := os.MkdirTemp("", "test-cache")
|
||||
require.NoError(t, err)
|
||||
|
||||
cache, err := New(dir, "c4")
|
||||
require.NoError(t, err)
|
||||
|
||||
createFilesInDir(t, dir,
|
||||
"unexpected1.txt",
|
||||
"c1/unexpected1.txt",
|
||||
"c2/unexpected2.txt",
|
||||
"c3/unexpected3.txt",
|
||||
"something.txt",
|
||||
)
|
||||
|
||||
require.DirExists(t, filepath.Join(dir, "c4"))
|
||||
require.FileExists(t, filepath.Join(dir, "unexpected1.txt"))
|
||||
require.FileExists(t, filepath.Join(dir, "c1", "unexpected1.txt"))
|
||||
require.FileExists(t, filepath.Join(dir, "c2", "unexpected2.txt"))
|
||||
require.FileExists(t, filepath.Join(dir, "c3", "unexpected3.txt"))
|
||||
require.FileExists(t, filepath.Join(dir, "something.txt"))
|
||||
|
||||
assert.NoError(t, cache.RemoveOldVersions())
|
||||
|
||||
assert.DirExists(t, filepath.Join(dir, "c4"))
|
||||
assert.NoFileExists(t, filepath.Join(dir, "unexpected1.txt"))
|
||||
assert.NoFileExists(t, filepath.Join(dir, "c1", "unexpected1.txt"))
|
||||
assert.NoFileExists(t, filepath.Join(dir, "c2", "unexpected2.txt"))
|
||||
assert.NoFileExists(t, filepath.Join(dir, "c3", "unexpected3.txt"))
|
||||
assert.NoFileExists(t, filepath.Join(dir, "something.txt"))
|
||||
}
|
||||
|
||||
func createFilesInDir(t *testing.T, dir string, files ...string) {
|
||||
for _, target := range files {
|
||||
require.NoError(t, os.MkdirAll(filepath.Dir(filepath.Join(dir, target)), 0o700))
|
||||
|
||||
f, err := os.Create(filepath.Join(dir, target))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, f.Close())
|
||||
}
|
||||
}
|
||||
@ -1,151 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package settings
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type keyValueStore struct {
|
||||
vals map[Key]string
|
||||
path string
|
||||
lock *sync.RWMutex
|
||||
}
|
||||
|
||||
// newKeyValueStore returns loaded preferences.
|
||||
func newKeyValueStore(path string) *keyValueStore {
|
||||
p := &keyValueStore{
|
||||
path: path,
|
||||
lock: &sync.RWMutex{},
|
||||
}
|
||||
if err := p.load(); err != nil {
|
||||
logrus.WithError(err).Warn("Cannot load preferences file, creating new one")
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *keyValueStore) load() error {
|
||||
if p.vals != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
p.vals = make(map[Key]string)
|
||||
|
||||
f, err := os.Open(p.path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close() //nolint:errcheck,gosec
|
||||
|
||||
return json.NewDecoder(f).Decode(&p.vals)
|
||||
}
|
||||
|
||||
func (p *keyValueStore) save() error {
|
||||
if p.vals == nil {
|
||||
return errors.New("cannot save preferences: cache is nil")
|
||||
}
|
||||
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
b, err := json.MarshalIndent(p.vals, "", "\t")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(p.path, b, 0o600)
|
||||
}
|
||||
|
||||
func (p *keyValueStore) setDefault(key Key, value string) {
|
||||
if p.Get(key) == "" {
|
||||
p.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *keyValueStore) Get(key Key) string {
|
||||
p.lock.RLock()
|
||||
defer p.lock.RUnlock()
|
||||
|
||||
return p.vals[key]
|
||||
}
|
||||
|
||||
func (p *keyValueStore) GetBool(key Key) bool {
|
||||
return p.Get(key) == "true"
|
||||
}
|
||||
|
||||
func (p *keyValueStore) GetInt(key Key) int {
|
||||
if p.Get(key) == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
value, err := strconv.Atoi(p.Get(key))
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("Cannot parse int")
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
func (p *keyValueStore) GetFloat64(key Key) float64 {
|
||||
if p.Get(key) == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
value, err := strconv.ParseFloat(p.Get(key), 64)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("Cannot parse float64")
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
func (p *keyValueStore) Set(key Key, value string) {
|
||||
p.lock.Lock()
|
||||
p.vals[key] = value
|
||||
p.lock.Unlock()
|
||||
|
||||
if err := p.save(); err != nil {
|
||||
logrus.WithError(err).Warn("Cannot save preferences")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *keyValueStore) SetBool(key Key, value bool) {
|
||||
if value {
|
||||
p.Set(key, "true")
|
||||
} else {
|
||||
p.Set(key, "false")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *keyValueStore) SetInt(key Key, value int) {
|
||||
p.Set(key, strconv.Itoa(value))
|
||||
}
|
||||
|
||||
func (p *keyValueStore) SetFloat64(key Key, value float64) {
|
||||
p.Set(key, fmt.Sprintf("%v", value))
|
||||
}
|
||||
@ -1,141 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package settings
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLoadNoKeyValueStore(t *testing.T) {
|
||||
r := require.New(t)
|
||||
pref, clean := newTestEmptyKeyValueStore(r)
|
||||
defer clean()
|
||||
|
||||
r.Equal("", pref.Get("key"))
|
||||
}
|
||||
|
||||
func TestLoadBadKeyValueStore(t *testing.T) {
|
||||
r := require.New(t)
|
||||
path, clean := newTmpFile(r)
|
||||
defer clean()
|
||||
|
||||
r.NoError(os.WriteFile(path, []byte("{\"key\":\"MISSING_QUOTES"), 0o700))
|
||||
pref := newKeyValueStore(path)
|
||||
r.Equal("", pref.Get("key"))
|
||||
}
|
||||
|
||||
func TestKeyValueStor(t *testing.T) {
|
||||
r := require.New(t)
|
||||
pref, clean := newTestKeyValueStore(r)
|
||||
defer clean()
|
||||
|
||||
r.Equal("value", pref.Get("str"))
|
||||
r.Equal("42", pref.Get("int"))
|
||||
r.Equal("true", pref.Get("bool"))
|
||||
r.Equal("t", pref.Get("falseBool"))
|
||||
}
|
||||
|
||||
func TestKeyValueStoreGetInt(t *testing.T) {
|
||||
r := require.New(t)
|
||||
pref, clean := newTestKeyValueStore(r)
|
||||
defer clean()
|
||||
|
||||
r.Equal(0, pref.GetInt("str"))
|
||||
r.Equal(42, pref.GetInt("int"))
|
||||
r.Equal(0, pref.GetInt("bool"))
|
||||
r.Equal(0, pref.GetInt("falseBool"))
|
||||
}
|
||||
|
||||
func TestKeyValueStoreGetBool(t *testing.T) {
|
||||
r := require.New(t)
|
||||
pref, clean := newTestKeyValueStore(r)
|
||||
defer clean()
|
||||
|
||||
r.Equal(false, pref.GetBool("str"))
|
||||
r.Equal(false, pref.GetBool("int"))
|
||||
r.Equal(true, pref.GetBool("bool"))
|
||||
r.Equal(false, pref.GetBool("falseBool"))
|
||||
}
|
||||
|
||||
func TestKeyValueStoreSetDefault(t *testing.T) {
|
||||
r := require.New(t)
|
||||
pref, clean := newTestEmptyKeyValueStore(r)
|
||||
defer clean()
|
||||
|
||||
pref.setDefault("key", "value")
|
||||
pref.setDefault("key", "othervalue")
|
||||
r.Equal("value", pref.Get("key"))
|
||||
}
|
||||
|
||||
func TestKeyValueStoreSet(t *testing.T) {
|
||||
r := require.New(t)
|
||||
pref, clean := newTestEmptyKeyValueStore(r)
|
||||
defer clean()
|
||||
|
||||
pref.Set("str", "value")
|
||||
checkSavedKeyValueStore(r, pref.path, "{\n\t\"str\": \"value\"\n}")
|
||||
}
|
||||
|
||||
func TestKeyValueStoreSetInt(t *testing.T) {
|
||||
r := require.New(t)
|
||||
pref, clean := newTestEmptyKeyValueStore(r)
|
||||
defer clean()
|
||||
|
||||
pref.SetInt("int", 42)
|
||||
checkSavedKeyValueStore(r, pref.path, "{\n\t\"int\": \"42\"\n}")
|
||||
}
|
||||
|
||||
func TestKeyValueStoreSetBool(t *testing.T) {
|
||||
r := require.New(t)
|
||||
pref, clean := newTestEmptyKeyValueStore(r)
|
||||
defer clean()
|
||||
|
||||
pref.SetBool("trueBool", true)
|
||||
pref.SetBool("falseBool", false)
|
||||
checkSavedKeyValueStore(r, pref.path, "{\n\t\"falseBool\": \"false\",\n\t\"trueBool\": \"true\"\n}")
|
||||
}
|
||||
|
||||
func newTmpFile(r *require.Assertions) (path string, clean func()) {
|
||||
tmpfile, err := os.CreateTemp("", "pref.*.json")
|
||||
r.NoError(err)
|
||||
defer r.NoError(tmpfile.Close())
|
||||
|
||||
return tmpfile.Name(), func() {
|
||||
r.NoError(os.Remove(tmpfile.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
func newTestEmptyKeyValueStore(r *require.Assertions) (*keyValueStore, func()) {
|
||||
path, clean := newTmpFile(r)
|
||||
return newKeyValueStore(path), clean
|
||||
}
|
||||
|
||||
func newTestKeyValueStore(r *require.Assertions) (*keyValueStore, func()) {
|
||||
path, clean := newTmpFile(r)
|
||||
r.NoError(os.WriteFile(path, []byte("{\"str\":\"value\",\"int\":\"42\",\"bool\":\"true\",\"falseBool\":\"t\"}"), 0o700))
|
||||
return newKeyValueStore(path), clean
|
||||
}
|
||||
|
||||
func checkSavedKeyValueStore(r *require.Assertions, path, expected string) {
|
||||
data, err := os.ReadFile(path)
|
||||
r.NoError(err)
|
||||
r.Equal(expected, string(data))
|
||||
}
|
||||
@ -1,116 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Package settings provides access to persistent user settings.
|
||||
package settings
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Key string
|
||||
|
||||
// Keys of preferences in JSON file.
|
||||
const (
|
||||
FirstStartKey Key = "first_time_start"
|
||||
FirstStartGUIKey Key = "first_time_start_gui"
|
||||
LastHeartbeatKey Key = "last_heartbeat"
|
||||
APIPortKey Key = "user_port_api"
|
||||
IMAPPortKey Key = "user_port_imap"
|
||||
SMTPPortKey Key = "user_port_smtp"
|
||||
SMTPSSLKey Key = "user_ssl_smtp"
|
||||
AllowProxyKey Key = "allow_proxy"
|
||||
AutostartKey Key = "autostart"
|
||||
AutoUpdateKey Key = "autoupdate"
|
||||
CookiesKey Key = "cookies"
|
||||
LastVersionKey Key = "last_used_version"
|
||||
UpdateChannelKey Key = "update_channel"
|
||||
RolloutKey Key = "rollout"
|
||||
PreferredKeychainKey Key = "preferred_keychain"
|
||||
CacheEnabledKey Key = "cache_enabled"
|
||||
CacheCompressionKey Key = "cache_compression"
|
||||
CacheLocationKey Key = "cache_location"
|
||||
CacheMinFreeAbsKey Key = "cache_min_free_abs"
|
||||
CacheMinFreeRatKey Key = "cache_min_free_rat"
|
||||
CacheConcurrencyRead Key = "cache_concurrent_read"
|
||||
CacheConcurrencyWrite Key = "cache_concurrent_write"
|
||||
IMAPWorkers Key = "imap_workers"
|
||||
FetchWorkers Key = "fetch_workers"
|
||||
AttachmentWorkers Key = "attachment_workers"
|
||||
ColorScheme Key = "color_scheme"
|
||||
RebrandingMigrationKey Key = "rebranding_migrated"
|
||||
IsAllMailVisible Key = "is_all_mail_visible"
|
||||
)
|
||||
|
||||
type Settings struct {
|
||||
*keyValueStore
|
||||
|
||||
settingsPath string
|
||||
}
|
||||
|
||||
func New(settingsPath string) *Settings {
|
||||
s := &Settings{
|
||||
keyValueStore: newKeyValueStore(filepath.Join(settingsPath, "prefs.json")),
|
||||
settingsPath: settingsPath,
|
||||
}
|
||||
|
||||
s.setDefaultValues()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultIMAPPort = "1143"
|
||||
DefaultSMTPPort = "1025"
|
||||
DefaultAPIPort = "1042"
|
||||
)
|
||||
|
||||
func (s *Settings) setDefaultValues() {
|
||||
s.setDefault(FirstStartKey, "true")
|
||||
s.setDefault(FirstStartGUIKey, "true")
|
||||
s.setDefault(LastHeartbeatKey, fmt.Sprintf("%v", time.Now().YearDay()))
|
||||
s.setDefault(AllowProxyKey, "true")
|
||||
s.setDefault(AutostartKey, "true")
|
||||
s.setDefault(AutoUpdateKey, "true")
|
||||
s.setDefault(LastVersionKey, "")
|
||||
s.setDefault(UpdateChannelKey, "")
|
||||
s.setDefault(RolloutKey, fmt.Sprintf("%v", rand.Float64())) //nolint:gosec // G404 It is OK to use weak random number generator here
|
||||
s.setDefault(PreferredKeychainKey, "")
|
||||
s.setDefault(CacheEnabledKey, "true")
|
||||
s.setDefault(CacheCompressionKey, "true")
|
||||
s.setDefault(CacheLocationKey, "")
|
||||
s.setDefault(CacheMinFreeAbsKey, "250000000")
|
||||
s.setDefault(CacheMinFreeRatKey, "")
|
||||
s.setDefault(CacheConcurrencyRead, "16")
|
||||
s.setDefault(CacheConcurrencyWrite, "16")
|
||||
s.setDefault(IMAPWorkers, "16")
|
||||
s.setDefault(FetchWorkers, "16")
|
||||
s.setDefault(AttachmentWorkers, "16")
|
||||
s.setDefault(ColorScheme, "")
|
||||
|
||||
s.setDefault(APIPortKey, DefaultAPIPort)
|
||||
s.setDefault(IMAPPortKey, DefaultIMAPPort)
|
||||
s.setDefault(SMTPPortKey, DefaultSMTPPort)
|
||||
|
||||
// By default, stick to STARTTLS. If the user uses catalina+applemail they'll have to change to SSL.
|
||||
s.setDefault(SMTPSSLKey, "false")
|
||||
|
||||
s.setDefault(IsAllMailVisible, "true")
|
||||
}
|
||||
@ -18,17 +18,35 @@
|
||||
// Package constants contains variables that are set via ldflags during build.
|
||||
package constants
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
const VendorName = "protonmail"
|
||||
|
||||
//nolint:gochecknoglobals
|
||||
var (
|
||||
// Version of the build.
|
||||
// Full app name (to show to the user).
|
||||
FullAppName = ""
|
||||
|
||||
// ConfigName determines the name of the location where bridge stores config files.
|
||||
ConfigName = "bridge"
|
||||
|
||||
// UpdateName is the name of the product appearing in the update URL.
|
||||
UpdateName = "bridge"
|
||||
|
||||
// KeyChainName is the name of the entry in the OS keychain.
|
||||
KeyChainName = "bridge"
|
||||
|
||||
// Version of the build.
|
||||
Version = ""
|
||||
Version = "2.3.0+git"
|
||||
|
||||
// AppVersion is the full rendered version of the app (to be used in request headers).
|
||||
AppVersion = getAPIOS() + cases.Title(language.Und).String(ConfigName) + "_" + Version
|
||||
|
||||
// Revision is current hash of the build.
|
||||
Revision = ""
|
||||
@ -36,9 +54,31 @@ var (
|
||||
// BuildTime stamp of the build.
|
||||
BuildTime = ""
|
||||
|
||||
// BuildVersion is derived from LongVersion and BuildTime.
|
||||
BuildVersion = fmt.Sprintf("%v (%v) %v", Version, Revision, BuildTime)
|
||||
|
||||
// DSNSentry client keys to be able to report crashes to Sentry.
|
||||
DSNSentry = ""
|
||||
|
||||
// BuildVersion is derived from LongVersion and BuildTime.
|
||||
BuildVersion = fmt.Sprintf("%v (%v) %v", Version, Revision, BuildTime)
|
||||
// APIHost is our API address.
|
||||
APIHost = "https://api.protonmail.ch"
|
||||
|
||||
// The host name of the bridge server.
|
||||
Host = "127.0.0.1"
|
||||
)
|
||||
|
||||
func getAPIOS() string {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return "macOS"
|
||||
|
||||
case "linux":
|
||||
return "Linux"
|
||||
|
||||
case "windows":
|
||||
return "Windows"
|
||||
|
||||
default:
|
||||
return "Linux"
|
||||
}
|
||||
}
|
||||
|
||||
@ -22,8 +22,5 @@ package constants
|
||||
|
||||
import "time"
|
||||
|
||||
//nolint:gochecknoglobals
|
||||
var (
|
||||
// UpdateCheckInterval defines how often we check for new version.
|
||||
UpdateCheckInterval = time.Hour //nolint:gochecknoglobals
|
||||
)
|
||||
// UpdateCheckInterval defines how often we check for new version.
|
||||
const UpdateCheckInterval = time.Hour
|
||||
|
||||
@ -22,8 +22,5 @@ package constants
|
||||
|
||||
import "time"
|
||||
|
||||
//nolint:gochecknoglobals
|
||||
var (
|
||||
// UpdateCheckInterval defines how often we check for new version
|
||||
UpdateCheckInterval = time.Duration(5 * time.Minute)
|
||||
)
|
||||
// UpdateCheckInterval defines how often we check for new version
|
||||
const UpdateCheckInterval = time.Duration(5 * time.Minute)
|
||||
|
||||
@ -26,28 +26,31 @@ import (
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
)
|
||||
|
||||
type cookiesByHost map[string][]*http.Cookie
|
||||
|
||||
type Persister interface {
|
||||
GetCookies() ([]byte, error)
|
||||
SetCookies([]byte) error
|
||||
}
|
||||
|
||||
// Jar implements http.CookieJar by wrapping the standard library's cookiejar.Jar.
|
||||
// The jar uses a pantry to load cookies at startup and save cookies when set.
|
||||
type Jar struct {
|
||||
jar *cookiejar.Jar
|
||||
settings *settings.Settings
|
||||
cookies cookiesByHost
|
||||
locker sync.Locker
|
||||
jar *cookiejar.Jar
|
||||
persister Persister
|
||||
cookies cookiesByHost
|
||||
locker sync.Locker
|
||||
}
|
||||
|
||||
func NewCookieJar(s *settings.Settings) (*Jar, error) {
|
||||
func NewCookieJar(persister Persister) (*Jar, error) {
|
||||
jar, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cookiesByHost, err := loadCookies(s)
|
||||
cookiesByHost, err := loadCookies(persister)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -62,10 +65,10 @@ func NewCookieJar(s *settings.Settings) (*Jar, error) {
|
||||
}
|
||||
|
||||
return &Jar{
|
||||
jar: jar,
|
||||
settings: s,
|
||||
cookies: cookiesByHost,
|
||||
locker: &sync.Mutex{},
|
||||
jar: jar,
|
||||
persister: persister,
|
||||
cookies: cookiesByHost,
|
||||
locker: &sync.Mutex{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -101,16 +104,17 @@ func (j *Jar) PersistCookies() error {
|
||||
return err
|
||||
}
|
||||
|
||||
j.settings.Set(settings.CookiesKey, string(rawCookies))
|
||||
|
||||
return nil
|
||||
return j.persister.SetCookies(rawCookies)
|
||||
}
|
||||
|
||||
// loadCookies loads all non-expired cookies from disk.
|
||||
func loadCookies(s *settings.Settings) (cookiesByHost, error) {
|
||||
rawCookies := s.Get(settings.CookiesKey)
|
||||
func loadCookies(persister Persister) (cookiesByHost, error) {
|
||||
rawCookies, err := persister.GetCookies()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rawCookies == "" {
|
||||
if len(rawCookies) == 0 {
|
||||
return make(cookiesByHost), nil
|
||||
}
|
||||
|
||||
|
||||
@ -18,13 +18,15 @@
|
||||
package cookies
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -37,7 +39,7 @@ func TestJarGetSet(t *testing.T) {
|
||||
})
|
||||
defer ts.Close()
|
||||
|
||||
client, _ := getClientWithJar(t, newFakeSettings())
|
||||
client, _ := getClientWithJar(t, newTestPersister(t))
|
||||
|
||||
// Hit a server that sets some cookies.
|
||||
setRes, err := client.Get(ts.URL + "/set")
|
||||
@ -63,7 +65,7 @@ func TestJarLoad(t *testing.T) {
|
||||
defer ts.Close()
|
||||
|
||||
// This will be our "persistent storage" from which the cookie jar should load cookies.
|
||||
s := newFakeSettings()
|
||||
s := newTestPersister(t)
|
||||
|
||||
// This client saves cookies to persistent storage.
|
||||
oldClient, jar := getClientWithJar(t, s)
|
||||
@ -98,7 +100,7 @@ func TestJarExpiry(t *testing.T) {
|
||||
defer ts.Close()
|
||||
|
||||
// This will be our "persistent storage" from which the cookie jar should load cookies.
|
||||
s := newFakeSettings()
|
||||
s := newTestPersister(t)
|
||||
|
||||
// This client saves cookies to persistent storage.
|
||||
oldClient, jar1 := getClientWithJar(t, s)
|
||||
@ -122,9 +124,12 @@ func TestJarExpiry(t *testing.T) {
|
||||
// Save the cookies (expired ones were cleared out).
|
||||
require.NoError(t, jar2.PersistCookies())
|
||||
|
||||
assert.Contains(t, s.Get(settings.CookiesKey), "TestName1")
|
||||
assert.NotContains(t, s.Get(settings.CookiesKey), "TestName2")
|
||||
assert.Contains(t, s.Get(settings.CookiesKey), "TestName3")
|
||||
cookies, err := s.GetCookies()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, string(cookies), "TestName1")
|
||||
assert.NotContains(t, string(cookies), "TestName2")
|
||||
assert.Contains(t, string(cookies), "TestName3")
|
||||
}
|
||||
|
||||
type testCookie struct {
|
||||
@ -132,8 +137,8 @@ type testCookie struct {
|
||||
maxAge int
|
||||
}
|
||||
|
||||
func getClientWithJar(t *testing.T, s *settings.Settings) (*http.Client, *Jar) {
|
||||
jar, err := NewCookieJar(s)
|
||||
func getClientWithJar(t *testing.T, persister Persister) (*http.Client, *Jar) {
|
||||
jar, err := NewCookieJar(persister)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &http.Client{Jar: jar}, jar
|
||||
@ -168,12 +173,26 @@ func getTestServer(t *testing.T, wantCookies []testCookie) *httptest.Server {
|
||||
return httptest.NewServer(mux)
|
||||
}
|
||||
|
||||
// newFakeSettings creates a temporary folder for files.
|
||||
func newFakeSettings() *settings.Settings {
|
||||
dir, err := os.MkdirTemp("", "test-settings")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
type testPersister struct {
|
||||
path string
|
||||
}
|
||||
|
||||
func newTestPersister(tb testing.TB) *testPersister {
|
||||
path := filepath.Join(tb.TempDir(), "cookies.json")
|
||||
|
||||
if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) {
|
||||
if err := os.WriteFile(path, []byte{}, 0600); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
return settings.New(dir)
|
||||
return &testPersister{path: path}
|
||||
}
|
||||
|
||||
func (p *testPersister) GetCookies() ([]byte, error) {
|
||||
return os.ReadFile(p.path)
|
||||
}
|
||||
|
||||
func (p *testPersister) SetCookies(rawCookies []byte) error {
|
||||
return os.WriteFile(p.path, rawCookies, 0600)
|
||||
}
|
||||
|
||||
@ -41,14 +41,11 @@ func (h *Handler) AddRecoveryAction(action RecoveryAction) *Handler {
|
||||
func (h *Handler) HandlePanic() {
|
||||
sentry.SkipDuringUnwind()
|
||||
|
||||
r := recover()
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, action := range h.actions {
|
||||
if err := action(r); err != nil {
|
||||
logrus.WithError(err).Error("Failed to execute recovery action")
|
||||
if r := recover(); r != nil {
|
||||
for _, action := range h.actions {
|
||||
if err := action(r); err != nil {
|
||||
logrus.WithError(err).Error("Failed to execute recovery action")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
77
internal/dialer/dialer_basic.go
Normal file
77
internal/dialer/dialer_basic.go
Normal file
@ -0,0 +1,77 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TLSDialer interface {
|
||||
DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error)
|
||||
}
|
||||
|
||||
// CreateTransportWithDialer creates an http.Transport that uses the given dialer to make TLS connections.
|
||||
func CreateTransportWithDialer(dialer TLSDialer) *http.Transport {
|
||||
return &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
IdleConnTimeout: 5 * time.Minute,
|
||||
|
||||
ExpectContinueTimeout: 500 * time.Millisecond,
|
||||
|
||||
// GODT-126: this was initially 10s but logs from users showed a significant number
|
||||
// were hitting this timeout, possibly due to flaky wifi taking >10s to reconnect.
|
||||
// Bumping to 30s for now to avoid this problem.
|
||||
ResponseHeaderTimeout: 30 * time.Second,
|
||||
|
||||
// If we allow up to 30 seconds for response headers, it is reasonable to allow up
|
||||
// to 30 seconds for the TLS handshake to take place.
|
||||
TLSHandshakeTimeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// BasicTLSDialer implements TLSDialer.
|
||||
type BasicTLSDialer struct {
|
||||
hostURL string
|
||||
}
|
||||
|
||||
// NewBasicTLSDialer returns a new BasicTLSDialer.
|
||||
func NewBasicTLSDialer(hostURL string) *BasicTLSDialer {
|
||||
return &BasicTLSDialer{
|
||||
hostURL: hostURL,
|
||||
}
|
||||
}
|
||||
|
||||
// DialTLS returns a connection to the given address using the given network.
|
||||
func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
||||
return (&tls.Dialer{
|
||||
NetDialer: &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
Config: &tls.Config{
|
||||
InsecureSkipVerify: address != d.hostURL,
|
||||
},
|
||||
}).DialContext(ctx, network, address)
|
||||
}
|
||||
114
internal/dialer/dialer_pinning.go
Normal file
114
internal/dialer/dialer_pinning.go
Normal file
@ -0,0 +1,114 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
)
|
||||
|
||||
// TrustedAPIPins contains trusted public keys of the protonmail API and proxies.
|
||||
// NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;).
|
||||
var TrustedAPIPins = []string{ //nolint:gochecknoglobals
|
||||
// api.protonmail.ch
|
||||
`pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current
|
||||
`pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot backup
|
||||
`pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold backup
|
||||
|
||||
// protonmail.com
|
||||
// \todo remove when sure no one is using it.
|
||||
`pin-sha256="8joiNBdqaYiQpKskgtkJsqRxF7zN0C0aqfi8DacknnI="`, // current
|
||||
`pin-sha256="JMI8yrbc6jB1FYGyyWRLFTmDNgIszrNEMGlgy972e7w="`, // hot backup
|
||||
`pin-sha256="Iu44zU84EOCZ9vx/vz67/MRVrxF1IO4i4NIa8ETwiIY="`, // cold backup
|
||||
|
||||
// proton.me
|
||||
`pin-sha256="CT56BhOTmj5ZIPgb/xD5mH8rY3BLo/MlhP7oPyJUEDo="`, // current
|
||||
`pin-sha256="35Dx28/uzN3LeltkCBQ8RHK0tlNSa2kCpCRGNp34Gxc="`, // hot backup
|
||||
`pin-sha256="qYIukVc63DEITct8sFT7ebIq5qsWmuscaIKeJx+5J5A="`, // col backup
|
||||
|
||||
// proxies
|
||||
`pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // main
|
||||
`pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // backup 1
|
||||
`pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // backup 2
|
||||
`pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // backup 3
|
||||
}
|
||||
|
||||
// TLSReportURI is the address where TLS reports should be sent.
|
||||
const TLSReportURI = "https://reports.protonmail.ch/reports/tls"
|
||||
|
||||
// PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and
|
||||
// to report errors if the fingerprint check fails.
|
||||
type PinningTLSDialer struct {
|
||||
dialer TLSDialer
|
||||
pinChecker PinChecker
|
||||
reporter Reporter
|
||||
tlsIssueCh chan struct{}
|
||||
}
|
||||
|
||||
// Reporter is used to report TLS issues.
|
||||
type Reporter interface {
|
||||
ReportCertIssue(reportURI, host, port string, state tls.ConnectionState)
|
||||
}
|
||||
|
||||
// PinChecker is used to check TLS keys of connections.
|
||||
type PinChecker interface {
|
||||
CheckCertificate(conn net.Conn) error
|
||||
}
|
||||
|
||||
// NewPinningTLSDialer constructs a new dialer which only returns TCP connections to servers
|
||||
// which present known certificates.
|
||||
// It checks pins using the given pinChecker and reports issues using the given reporter.
|
||||
func NewPinningTLSDialer(dialer TLSDialer, reporter Reporter, pinChecker PinChecker) *PinningTLSDialer {
|
||||
return &PinningTLSDialer{
|
||||
dialer: dialer,
|
||||
pinChecker: pinChecker,
|
||||
reporter: reporter,
|
||||
tlsIssueCh: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// DialTLS dials the given network/address, returning an error if the certificates don't match the trusted pins.
|
||||
func (p *PinningTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := p.dialer.DialTLSContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.pinChecker.CheckCertificate(conn); err != nil {
|
||||
if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil {
|
||||
p.reporter.ReportCertIssue(TLSReportURI, host, port, tlsConn.ConnectionState())
|
||||
}
|
||||
|
||||
p.tlsIssueCh <- struct{}{}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// GetTLSIssueCh returns a channel which notifies when a TLS issue is reported.
|
||||
func (p *PinningTLSDialer) GetTLSIssueCh() <-chan struct{} {
|
||||
return p.tlsIssueCh
|
||||
}
|
||||
67
internal/dialer/dialer_pinning_checker.go
Normal file
67
internal/dialer/dialer_pinning_checker.go
Normal file
@ -0,0 +1,67 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/algo"
|
||||
)
|
||||
|
||||
// ErrTLSMismatch indicates that no TLS fingerprint match could be found.
|
||||
var ErrTLSMismatch = errors.New("no TLS fingerprint match found")
|
||||
|
||||
type TLSPinChecker struct {
|
||||
trustedPins []string
|
||||
}
|
||||
|
||||
func NewTLSPinChecker(trustedPins []string) *TLSPinChecker {
|
||||
return &TLSPinChecker{
|
||||
trustedPins: trustedPins,
|
||||
}
|
||||
}
|
||||
|
||||
// checkCertificate returns whether the connection presents a known TLS certificate.
|
||||
func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error {
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
if !ok {
|
||||
return errors.New("connection is not a TLS connection")
|
||||
}
|
||||
|
||||
connState := tlsConn.ConnectionState()
|
||||
|
||||
for _, peerCert := range connState.PeerCertificates {
|
||||
fingerprint := certFingerprint(peerCert)
|
||||
|
||||
for _, pin := range p.trustedPins {
|
||||
if pin == fingerprint {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ErrTLSMismatch
|
||||
}
|
||||
|
||||
func certFingerprint(cert *x509.Certificate) string {
|
||||
return fmt.Sprintf(`pin-sha256=%q`, algo.HashBase64SHA256(string(cert.RawSubjectPublicKeyInfo)))
|
||||
}
|
||||
118
internal/dialer/dialer_pinning_report.go
Normal file
118
internal/dialer/dialer_pinning_report.go
Normal file
@ -0,0 +1,118 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
// tlsReport is inspired by https://tools.ietf.org/html/rfc7469#section-3.
|
||||
// When a TLS key mismatch is detected, a tlsReport is posted to TLSReportURI.
|
||||
type tlsReport struct {
|
||||
// DateTime of observed pin validation in time.RFC3339 format.
|
||||
DateTime string `json:"date-time"`
|
||||
|
||||
// Hostname to which the UA made original request that failed pin validation.
|
||||
Hostname string `json:"hostname"`
|
||||
|
||||
// Port to which the UA made original request that failed pin validation.
|
||||
Port string `json:"port"`
|
||||
|
||||
// EffectiveExpirationDate for noted pins in time.RFC3339 format.
|
||||
EffectiveExpirationDate string `json:"effective-expiration-date"`
|
||||
|
||||
// IncludeSubdomains indicates whether or not the UA has noted the
|
||||
// includeSubDomains directive for the Known Pinned Host.
|
||||
IncludeSubdomains bool `json:"include-subdomains"`
|
||||
|
||||
// NotedHostname indicates the hostname that the UA noted when it noted
|
||||
// the Known Pinned Host. This field allows operators to understand why
|
||||
// Pin Validation was performed for, e.g., foo.example.com when the
|
||||
// noted Known Pinned Host was example.com with includeSubDomains set.
|
||||
NotedHostname string `json:"noted-hostname"`
|
||||
|
||||
// ServedCertificateChain is the certificate chain, as served by
|
||||
// the Known Pinned Host during TLS session setup. It is provided as an
|
||||
// array of strings; each string pem1, ... pemN is the Privacy-Enhanced
|
||||
// Mail (PEM) representation of each X.509 certificate as described in
|
||||
// [RFC7468].
|
||||
ServedCertificateChain []string `json:"served-certificate-chain"`
|
||||
|
||||
// ValidatedCertificateChain is the certificate chain, as
|
||||
// constructed by the UA during certificate chain verification. (This
|
||||
// may differ from the served-certificate-chain.) It is provided as an
|
||||
// array of strings; each string pem1, ... pemN is the PEM
|
||||
// representation of each X.509 certificate as described in [RFC7468].
|
||||
// UAs that build certificate chains in more than one way during the
|
||||
// validation process SHOULD send the last chain built. In this way,
|
||||
// they can avoid keeping too much state during the validation process.
|
||||
ValidatedCertificateChain []string `json:"validated-certificate-chain"`
|
||||
|
||||
// The known-pins are the Pins that the UA has noted for the Known
|
||||
// Pinned Host. They are provided as an array of strings with the
|
||||
// syntax: known-pin = token "=" quoted-string
|
||||
// e.g.:
|
||||
// ```
|
||||
// "known-pins": [
|
||||
// 'pin-sha256="d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM="',
|
||||
// "pin-sha256=\"E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=\""
|
||||
// ]
|
||||
// ```
|
||||
KnownPins []string `json:"known-pins"`
|
||||
|
||||
// AppVersion is used to set `x-pm-appversion` json format from datatheorem/TrustKit.
|
||||
AppVersion string `json:"app-version"`
|
||||
}
|
||||
|
||||
// newTLSReport constructs a new tlsReport configured with the given app version and known pinned public keys.
|
||||
// Temporal things (current date/time) are not set yet -- they are set when sendReport is called.
|
||||
func newTLSReport(host, port, server string, certChain, knownPins []string, appVersion string) (report tlsReport) {
|
||||
report = tlsReport{
|
||||
Hostname: host,
|
||||
Port: port,
|
||||
NotedHostname: server,
|
||||
ServedCertificateChain: certChain,
|
||||
KnownPins: knownPins,
|
||||
AppVersion: appVersion,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// sendReport posts the given TLS report to the standard TLS Report URI.
|
||||
func sendReport(report tlsReport, userAgent, appVersion, hostURL, remoteURI string) error {
|
||||
now := time.Now()
|
||||
|
||||
report.DateTime = now.Format(time.RFC3339)
|
||||
report.EffectiveExpirationDate = now.Add(365 * 24 * time.Hour).Format(time.RFC3339)
|
||||
|
||||
if _, err := resty.New().
|
||||
SetTransport(CreateTransportWithDialer(NewBasicTLSDialer(hostURL))).
|
||||
SetHeader("User-Agent", userAgent).
|
||||
SetHeader("x-pm-appversion", appVersion).
|
||||
NewRequest().
|
||||
SetBody(report).
|
||||
Post(remoteURI); err != nil {
|
||||
return fmt.Errorf("failed to send TLS report: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
115
internal/dialer/dialer_pinning_reporter.go
Normal file
115
internal/dialer/dialer_pinning_reporter.go
Normal file
@ -0,0 +1,115 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type sentReport struct {
|
||||
r tlsReport
|
||||
t time.Time
|
||||
}
|
||||
|
||||
type TLSReporter struct {
|
||||
hostURL string
|
||||
appVersion string
|
||||
userAgent *useragent.UserAgent
|
||||
trustedPins []string
|
||||
sentReports []sentReport
|
||||
}
|
||||
|
||||
func NewTLSReporter(hostURL, appVersion string, userAgent *useragent.UserAgent, trustedPins []string) *TLSReporter {
|
||||
return &TLSReporter{
|
||||
hostURL: hostURL,
|
||||
appVersion: appVersion,
|
||||
userAgent: userAgent,
|
||||
trustedPins: trustedPins,
|
||||
}
|
||||
}
|
||||
|
||||
// reportCertIssue reports a TLS key mismatch.
|
||||
func (r *TLSReporter) ReportCertIssue(remoteURI, host, port string, connState tls.ConnectionState) {
|
||||
var certChain []string
|
||||
|
||||
if len(connState.VerifiedChains) > 0 {
|
||||
certChain = marshalCert7468(connState.VerifiedChains[len(connState.VerifiedChains)-1])
|
||||
} else {
|
||||
certChain = marshalCert7468(connState.PeerCertificates)
|
||||
}
|
||||
|
||||
report := newTLSReport(host, port, connState.ServerName, certChain, r.trustedPins, r.appVersion)
|
||||
|
||||
if !r.hasRecentlySentReport(report) {
|
||||
r.recordReport(report)
|
||||
|
||||
if err := sendReport(report, r.userAgent.GetUserAgent(), r.appVersion, r.hostURL, remoteURI); err != nil {
|
||||
logrus.WithError(err).Error("Failed to send TLS pinning report")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// hasRecentlySentReport returns whether the report was already sent within the last 24 hours.
|
||||
func (r *TLSReporter) hasRecentlySentReport(report tlsReport) bool {
|
||||
var validReports []sentReport
|
||||
|
||||
for _, r := range r.sentReports {
|
||||
if time.Since(r.t) < 24*time.Hour {
|
||||
validReports = append(validReports, r)
|
||||
}
|
||||
}
|
||||
|
||||
r.sentReports = validReports
|
||||
|
||||
for _, r := range r.sentReports {
|
||||
if cmp.Equal(report, r.r) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// recordReport records the given report and the current time so we can check whether we recently sent this report.
|
||||
func (r *TLSReporter) recordReport(report tlsReport) {
|
||||
r.sentReports = append(r.sentReports, sentReport{r: report, t: time.Now()})
|
||||
}
|
||||
|
||||
func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
|
||||
var buffer bytes.Buffer
|
||||
for _, cert := range certs {
|
||||
if err := pem.Encode(&buffer, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
}); err != nil {
|
||||
logrus.WithError(err).Error("Failed to encode TLS certificate")
|
||||
}
|
||||
pemCerts = append(pemCerts, buffer.String())
|
||||
buffer.Reset()
|
||||
}
|
||||
|
||||
return pemCerts
|
||||
}
|
||||
59
internal/dialer/dialer_pinning_reporter_test.go
Normal file
59
internal/dialer/dialer_pinning_reporter_test.go
Normal file
@ -0,0 +1,59 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTLSReporter_DoubleReport(t *testing.T) {
|
||||
reportCounter := 0
|
||||
|
||||
reportServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reportCounter++
|
||||
}))
|
||||
|
||||
r := NewTLSReporter("hostURL", "appVersion", useragent.New(), TrustedAPIPins)
|
||||
|
||||
// Report the same issue many times.
|
||||
for i := 0; i < 10; i++ {
|
||||
r.ReportCertIssue(reportServer.URL, "myhost", "443", tls.ConnectionState{})
|
||||
}
|
||||
|
||||
// We should only report once.
|
||||
assert.Eventually(t, func() bool {
|
||||
return reportCounter == 1
|
||||
}, time.Second, time.Millisecond)
|
||||
|
||||
// If we then report something else many times.
|
||||
for i := 0; i < 10; i++ {
|
||||
r.ReportCertIssue(reportServer.URL, "anotherhost", "443", tls.ConnectionState{})
|
||||
}
|
||||
|
||||
// We should get a second report.
|
||||
assert.Eventually(t, func() bool {
|
||||
return reportCounter == 2
|
||||
}, time.Second, time.Millisecond)
|
||||
}
|
||||
157
internal/dialer/dialer_pinning_test.go
Normal file
157
internal/dialer/dialer_pinning_test.go
Normal file
@ -0,0 +1,157 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
||||
a "github.com/stretchr/testify/assert"
|
||||
r "github.com/stretchr/testify/require"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
)
|
||||
|
||||
func getRootURL() string {
|
||||
return "https://api.protonmail.ch"
|
||||
}
|
||||
|
||||
func TestTLSPinValid(t *testing.T) {
|
||||
called, _, _, _, cm := createClientWithPinningDialer(getRootURL())
|
||||
|
||||
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "password")
|
||||
|
||||
checkTLSIssueHandler(t, 0, called)
|
||||
}
|
||||
|
||||
func TestTLSPinBackup(t *testing.T) {
|
||||
called, _, _, checker, cm := createClientWithPinningDialer(getRootURL())
|
||||
copyTrustedPins(checker)
|
||||
checker.trustedPins[1] = checker.trustedPins[0]
|
||||
checker.trustedPins[0] = ""
|
||||
|
||||
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "password")
|
||||
|
||||
checkTLSIssueHandler(t, 0, called)
|
||||
}
|
||||
|
||||
func TestTLSPinInvalid(t *testing.T) {
|
||||
s := server.NewTLS()
|
||||
defer s.Close()
|
||||
|
||||
called, _, _, _, cm := createClientWithPinningDialer(s.GetHostURL())
|
||||
|
||||
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "password")
|
||||
|
||||
checkTLSIssueHandler(t, 1, called)
|
||||
}
|
||||
|
||||
func TestTLSPinNoMatch(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
called, _, reporter, checker, cm := createClientWithPinningDialer(getRootURL())
|
||||
|
||||
copyTrustedPins(checker)
|
||||
for i := 0; i < len(checker.trustedPins); i++ {
|
||||
checker.trustedPins[i] = "testing"
|
||||
}
|
||||
|
||||
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "password")
|
||||
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "password")
|
||||
|
||||
// Check that it will be reported only once per session, but notified every time.
|
||||
r.Equal(t, 1, len(reporter.sentReports))
|
||||
checkTLSIssueHandler(t, 2, called)
|
||||
}
|
||||
|
||||
func TestTLSSignedCertWrongPublicKey(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
_, dialer, _, _, _ := createClientWithPinningDialer("")
|
||||
_, err := dialer.DialTLSContext(context.Background(), "tcp", "rsa4096.badssl.com:443")
|
||||
r.Error(t, err, "expected dial to fail because of wrong public key")
|
||||
}
|
||||
|
||||
func TestTLSSignedCertTrustedPublicKey(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
_, dialer, _, checker, _ := createClientWithPinningDialer("")
|
||||
copyTrustedPins(checker)
|
||||
checker.trustedPins = append(checker.trustedPins, `pin-sha256="LwnIKjNLV3z243ap8y0yXNPghsqE76J08Eq3COvUt2E="`)
|
||||
_, err := dialer.DialTLSContext(context.Background(), "tcp", "rsa4096.badssl.com:443")
|
||||
r.NoError(t, err, "expected dial to succeed because public key is known and cert is signed by CA")
|
||||
}
|
||||
|
||||
func TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
_, dialer, _, checker, _ := createClientWithPinningDialer("")
|
||||
copyTrustedPins(checker)
|
||||
checker.trustedPins = append(checker.trustedPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`)
|
||||
_, err := dialer.DialTLSContext(context.Background(), "tcp", "self-signed.badssl.com:443")
|
||||
r.NoError(t, err, "expected dial to succeed because public key is known despite cert being self-signed")
|
||||
}
|
||||
|
||||
func createClientWithPinningDialer(hostURL string) (*int, *PinningTLSDialer, *TLSReporter, *TLSPinChecker, *liteapi.Manager) {
|
||||
called := 0
|
||||
|
||||
reporter := NewTLSReporter(hostURL, "appVersion", useragent.New(), TrustedAPIPins)
|
||||
checker := NewTLSPinChecker(TrustedAPIPins)
|
||||
dialer := NewPinningTLSDialer(NewBasicTLSDialer(hostURL), reporter, checker)
|
||||
|
||||
go func() {
|
||||
for range dialer.GetTLSIssueCh() {
|
||||
called++
|
||||
}
|
||||
}()
|
||||
|
||||
return &called, dialer, reporter, checker, liteapi.New(
|
||||
liteapi.WithHostURL(hostURL),
|
||||
liteapi.WithTransport(CreateTransportWithDialer(dialer)),
|
||||
)
|
||||
}
|
||||
|
||||
func copyTrustedPins(pinChecker *TLSPinChecker) {
|
||||
copiedPins := make([]string, len(pinChecker.trustedPins))
|
||||
copy(copiedPins, pinChecker.trustedPins)
|
||||
pinChecker.trustedPins = copiedPins
|
||||
}
|
||||
|
||||
func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast int, called *int) {
|
||||
// TLSIssueHandler is called in goroutine se we need to wait a bit to be sure it was called.
|
||||
a.Eventually(
|
||||
t,
|
||||
func() bool {
|
||||
if wantCalledAtLeast == 0 {
|
||||
return *called == 0
|
||||
}
|
||||
// Dialer can do more attempts resulting in more calls.
|
||||
return *called >= wantCalledAtLeast
|
||||
},
|
||||
time.Second,
|
||||
10*time.Millisecond,
|
||||
)
|
||||
// Repeated again so it generates nice message.
|
||||
if wantCalledAtLeast == 0 {
|
||||
r.Equal(t, 0, *called)
|
||||
} else {
|
||||
r.GreaterOrEqual(t, *called, wantCalledAtLeast)
|
||||
}
|
||||
}
|
||||
152
internal/dialer/dialer_proxy.go
Normal file
152
internal/dialer/dialer_proxy.go
Normal file
@ -0,0 +1,152 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var ErrNoConnection = errors.New("no connection")
|
||||
|
||||
// ProxyTLSDialer wraps a TLSDialer to switch to a proxy if the initial dial fails.
|
||||
type ProxyTLSDialer struct {
|
||||
dialer TLSDialer
|
||||
|
||||
locker sync.RWMutex
|
||||
directAddress string
|
||||
proxyAddress string
|
||||
allowProxy bool
|
||||
proxyProvider *proxyProvider
|
||||
proxyUseDuration time.Duration
|
||||
}
|
||||
|
||||
// NewProxyTLSDialer constructs a dialer which provides a proxy-managing layer on top of an underlying dialer.
|
||||
func NewProxyTLSDialer(dialer TLSDialer, hostURL string) *ProxyTLSDialer {
|
||||
return &ProxyTLSDialer{
|
||||
dialer: dialer,
|
||||
locker: sync.RWMutex{},
|
||||
directAddress: formatAsAddress(hostURL),
|
||||
proxyAddress: formatAsAddress(hostURL),
|
||||
proxyProvider: newProxyProvider(dialer, hostURL, DoHProviders),
|
||||
proxyUseDuration: proxyUseDuration,
|
||||
}
|
||||
}
|
||||
|
||||
// formatAsAddress returns URL as `host:port` for easy comparison in DialTLS.
|
||||
func formatAsAddress(rawURL string) string {
|
||||
url, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
// This means wrong configuration.
|
||||
// Developer should get feedback right away.
|
||||
panic(err)
|
||||
}
|
||||
|
||||
host := url.Host
|
||||
if host == "" {
|
||||
host = url.Path
|
||||
}
|
||||
|
||||
port := "443"
|
||||
if url.Scheme == "http" {
|
||||
port = "80"
|
||||
}
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
// DialTLS dials the given network/address. If it fails, it retries using a proxy.
|
||||
func (d *ProxyTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if address == d.directAddress {
|
||||
address = d.proxyAddress
|
||||
}
|
||||
|
||||
conn, err := d.dialer.DialTLSContext(ctx, network, address)
|
||||
if err == nil || !d.allowProxy {
|
||||
return conn, err
|
||||
}
|
||||
|
||||
if err := d.switchToReachableServer(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return d.dialer.DialTLSContext(ctx, network, d.proxyAddress)
|
||||
}
|
||||
|
||||
// switchToReachableServer switches to using a reachable server (either proxy or standard API).
|
||||
func (d *ProxyTLSDialer) switchToReachableServer() error {
|
||||
d.locker.Lock()
|
||||
defer d.locker.Unlock()
|
||||
|
||||
logrus.Info("Attempting to switch to a proxy")
|
||||
|
||||
proxy, err := d.proxyProvider.findReachableServer()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to find a usable proxy")
|
||||
}
|
||||
|
||||
proxyAddress := formatAsAddress(proxy)
|
||||
|
||||
// If the chosen proxy is the standard API, we want to use it but still show the troubleshooting screen.
|
||||
if proxyAddress == d.directAddress {
|
||||
logrus.Info("The standard API is reachable again; connection drop was only intermittent")
|
||||
d.proxyAddress = proxyAddress
|
||||
return ErrNoConnection
|
||||
}
|
||||
|
||||
logrus.WithField("proxy", proxyAddress).Info("Switching to a proxy")
|
||||
|
||||
// If the host is currently the rootURL, it's the first time we are enabling a proxy.
|
||||
// This means we want to disable it again in 24 hours.
|
||||
if d.proxyAddress == d.directAddress {
|
||||
go func() {
|
||||
<-time.After(d.proxyUseDuration)
|
||||
|
||||
d.locker.Lock()
|
||||
defer d.locker.Unlock()
|
||||
|
||||
d.proxyAddress = d.directAddress
|
||||
}()
|
||||
}
|
||||
|
||||
d.proxyAddress = proxyAddress
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowProxy allows the dialer to switch to a proxy if need be.
|
||||
func (d *ProxyTLSDialer) AllowProxy() {
|
||||
d.locker.Lock()
|
||||
defer d.locker.Unlock()
|
||||
|
||||
d.allowProxy = true
|
||||
}
|
||||
|
||||
// DisallowProxy prevents the dialer from switching to a proxy if need be.
|
||||
func (d *ProxyTLSDialer) DisallowProxy() {
|
||||
d.locker.Lock()
|
||||
defer d.locker.Unlock()
|
||||
|
||||
d.allowProxy = false
|
||||
d.proxyAddress = d.directAddress
|
||||
}
|
||||
256
internal/dialer/dialer_proxy_provider.go
Normal file
256
internal/dialer/dialer_proxy_provider.go
Normal file
@ -0,0 +1,256 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
proxyUseDuration = 24 * time.Hour
|
||||
proxyLookupWait = 5 * time.Second
|
||||
proxyCacheRefreshTimeout = 20 * time.Second
|
||||
proxyDoHTimeout = 20 * time.Second
|
||||
proxyCanReachTimeout = 20 * time.Second
|
||||
|
||||
proxyQuery = "dMFYGSLTQOJXXI33ONVQWS3BOMNUA.protonpro.xyz"
|
||||
Quad9Provider = "https://dns11.quad9.net/dns-query"
|
||||
Quad9PortProvider = "https://dns11.quad9.net:5053/dns-query"
|
||||
GoogleProvider = "https://dns.google/dns-query"
|
||||
)
|
||||
|
||||
var DoHProviders = []string{ //nolint:gochecknoglobals
|
||||
Quad9Provider,
|
||||
Quad9PortProvider,
|
||||
GoogleProvider,
|
||||
}
|
||||
|
||||
// proxyProvider manages known proxies.
|
||||
type proxyProvider struct {
|
||||
dialer TLSDialer
|
||||
|
||||
hostURL string
|
||||
|
||||
// dohLookup is used to look up the given query at the given DoH provider, returning the TXT records>
|
||||
dohLookup func(ctx context.Context, query, provider string) (urls []string, err error)
|
||||
|
||||
providers []string // List of known doh providers.
|
||||
query string // The query string used to find proxies.
|
||||
proxyCache []string // All known proxies, cached in case DoH providers are unreachable.
|
||||
|
||||
cacheRefreshTimeout time.Duration
|
||||
dohTimeout time.Duration
|
||||
canReachTimeout time.Duration
|
||||
|
||||
lastLookup time.Time // The time at which we last attempted to find a proxy.
|
||||
}
|
||||
|
||||
// newProxyProvider creates a new proxyProvider that queries the given DoH providers
|
||||
// to retrieve DNS records for the given query string.
|
||||
func newProxyProvider(dialer TLSDialer, hostURL string, providers []string) (p *proxyProvider) {
|
||||
p = &proxyProvider{
|
||||
dialer: dialer,
|
||||
hostURL: hostURL,
|
||||
providers: providers,
|
||||
query: proxyQuery,
|
||||
cacheRefreshTimeout: proxyCacheRefreshTimeout,
|
||||
dohTimeout: proxyDoHTimeout,
|
||||
canReachTimeout: proxyCanReachTimeout,
|
||||
}
|
||||
|
||||
// Use the default DNS lookup method; this can be overridden if necessary.
|
||||
p.dohLookup = p.defaultDoHLookup
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// findReachableServer returns a working API server (either proxy or standard API).
|
||||
func (p *proxyProvider) findReachableServer() (proxy string, err error) {
|
||||
logrus.Debug("Trying to find a reachable server")
|
||||
|
||||
if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) {
|
||||
return "", errors.New("not looking for a proxy, too soon")
|
||||
}
|
||||
|
||||
p.lastLookup = time.Now()
|
||||
|
||||
// We use a waitgroup to wait for both
|
||||
// a) the check whether the API is reachable, and
|
||||
// b) the DoH queries.
|
||||
// This is because the Alternative Routes v2 spec says:
|
||||
// Call the GET /test/ping route on normal API domain (same time as DoH requests and wait until all have finished)
|
||||
var wg sync.WaitGroup
|
||||
var apiReachable bool
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
apiReachable = p.canReach(p.hostURL)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err = p.refreshProxyCache()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if apiReachable {
|
||||
proxy = p.hostURL
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, url := range p.proxyCache {
|
||||
if p.canReach(url) {
|
||||
proxy = url
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("no reachable server could be found")
|
||||
}
|
||||
|
||||
// refreshProxyCache loads the latest proxies from the known providers.
|
||||
// If the process takes longer than proxyCacheRefreshTimeout, an error is returned.
|
||||
func (p *proxyProvider) refreshProxyCache() error {
|
||||
logrus.Info("Refreshing proxy cache")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.cacheRefreshTimeout)
|
||||
defer cancel()
|
||||
|
||||
resultChan := make(chan []string)
|
||||
|
||||
go func() {
|
||||
for _, provider := range p.providers {
|
||||
if proxies, err := p.dohLookup(ctx, p.query, provider); err == nil {
|
||||
resultChan <- proxies
|
||||
return
|
||||
}
|
||||
}
|
||||
// If no dohLoopkup worked, cancel right after it's done to not
|
||||
// block refreshing for the whole cacheRefreshTimeout.
|
||||
cancel()
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
p.proxyCache = result
|
||||
return nil
|
||||
|
||||
case <-ctx.Done():
|
||||
return errors.New("timed out while refreshing proxy cache")
|
||||
}
|
||||
}
|
||||
|
||||
// canReach returns whether we can reach the given url.
|
||||
func (p *proxyProvider) canReach(url string) bool {
|
||||
logrus.WithField("url", url).Debug("Trying to ping proxy")
|
||||
|
||||
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") {
|
||||
url = "https://" + url
|
||||
}
|
||||
|
||||
pinger := resty.New().
|
||||
SetBaseURL(url).
|
||||
SetTimeout(p.canReachTimeout).
|
||||
SetTransport(CreateTransportWithDialer(p.dialer))
|
||||
|
||||
if _, err := pinger.R().Get("/tests/ping"); err != nil {
|
||||
logrus.WithField("proxy", url).WithError(err).Warn("Failed to ping proxy")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// defaultDoHLookup is the default implementation of the proxy manager's DoH lookup.
|
||||
// It looks up DNS TXT records for the given query URL using the given DoH provider.
|
||||
// It returns a list of all found TXT records.
|
||||
// If the whole process takes more than proxyDoHTimeout then an error is returned.
|
||||
func (p *proxyProvider) defaultDoHLookup(ctx context.Context, query, dohProvider string) (data []string, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, p.dohTimeout)
|
||||
defer cancel()
|
||||
|
||||
dataChan, errChan := make(chan []string), make(chan error)
|
||||
|
||||
go func() {
|
||||
// Build new DNS request in RFC1035 format.
|
||||
dnsRequest := new(dns.Msg).SetQuestion(dns.Fqdn(query), dns.TypeTXT)
|
||||
|
||||
// Pack the DNS request message into wire format.
|
||||
rawRequest, err := dnsRequest.Pack()
|
||||
if err != nil {
|
||||
errChan <- errors.Wrap(err, "failed to pack DNS request")
|
||||
return
|
||||
}
|
||||
|
||||
// Encode wire-format DNS request message as base64url (RFC4648) without padding chars.
|
||||
encodedRequest := base64.RawURLEncoding.EncodeToString(rawRequest)
|
||||
|
||||
// Make DoH request to the given DoH provider.
|
||||
rawResponse, err := resty.New().R().SetContext(ctx).SetQueryParam("dns", encodedRequest).Get(dohProvider)
|
||||
if err != nil {
|
||||
errChan <- errors.Wrap(err, "failed to make DoH request")
|
||||
return
|
||||
}
|
||||
|
||||
// Unpack the DNS response.
|
||||
dnsResponse := new(dns.Msg)
|
||||
if err = dnsResponse.Unpack(rawResponse.Body()); err != nil {
|
||||
errChan <- errors.Wrap(err, "failed to unpack DNS response")
|
||||
return
|
||||
}
|
||||
|
||||
// Pick out the TXT answers.
|
||||
for _, answer := range dnsResponse.Answer {
|
||||
if t, ok := answer.(*dns.TXT); ok {
|
||||
data = append(data, t.Txt...)
|
||||
}
|
||||
}
|
||||
|
||||
dataChan <- data
|
||||
}()
|
||||
|
||||
select {
|
||||
case data = <-dataChan:
|
||||
logrus.WithField("data", data).Info("Received TXT records")
|
||||
return
|
||||
|
||||
case err = <-errChan:
|
||||
logrus.WithField("provider", dohProvider).WithError(err).Error("Failed to query DNS records")
|
||||
return
|
||||
|
||||
case <-ctx.Done():
|
||||
logrus.WithField("provider", dohProvider).Error("Timed out querying DNS records")
|
||||
return []string{}, errors.New("timed out querying DNS records")
|
||||
}
|
||||
}
|
||||
191
internal/dialer/dialer_proxy_provider_test.go
Normal file
191
internal/dialer/dialer_proxy_provider_test.go
Normal file
@ -0,0 +1,191 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/useragent"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProxyProvider_FindProxy(t *testing.T) {
|
||||
proxy := getTrustedServer()
|
||||
defer closeServer(proxy)
|
||||
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy.URL}, nil }
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, proxy.URL, url)
|
||||
}
|
||||
|
||||
func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) {
|
||||
reachableProxy := getTrustedServer()
|
||||
defer closeServer(reachableProxy)
|
||||
|
||||
// We actually close the unreachable proxy straight away rather than deferring the closure.
|
||||
unreachableProxy := getTrustedServer()
|
||||
closeServer(unreachableProxy)
|
||||
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{reachableProxy.URL, unreachableProxy.URL}, nil
|
||||
}
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, reachableProxy.URL, url)
|
||||
}
|
||||
|
||||
func TestProxyProvider_FindProxy_ChooseTrustedProxy(t *testing.T) {
|
||||
trustedProxy := getTrustedServer()
|
||||
defer closeServer(trustedProxy)
|
||||
|
||||
untrustedProxy := getUntrustedServer()
|
||||
defer closeServer(untrustedProxy)
|
||||
|
||||
reporter := NewTLSReporter("", "appVersion", useragent.New(), TrustedAPIPins)
|
||||
checker := NewTLSPinChecker(TrustedAPIPins)
|
||||
dialer := NewPinningTLSDialer(NewBasicTLSDialer(""), reporter, checker)
|
||||
|
||||
p := newProxyProvider(dialer, "", []string{"not used"})
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{untrustedProxy.URL, trustedProxy.URL}, nil
|
||||
}
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
r.NoError(t, err)
|
||||
r.Equal(t, trustedProxy.URL, url)
|
||||
}
|
||||
|
||||
func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) {
|
||||
unreachableProxy1 := getTrustedServer()
|
||||
closeServer(unreachableProxy1)
|
||||
|
||||
unreachableProxy2 := getTrustedServer()
|
||||
closeServer(unreachableProxy2)
|
||||
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{unreachableProxy1.URL, unreachableProxy2.URL}, nil
|
||||
}
|
||||
|
||||
_, err := p.findReachableServer()
|
||||
r.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) {
|
||||
untrustedProxy1 := getUntrustedServer()
|
||||
defer closeServer(untrustedProxy1)
|
||||
|
||||
untrustedProxy2 := getUntrustedServer()
|
||||
defer closeServer(untrustedProxy2)
|
||||
|
||||
reporter := NewTLSReporter("", "appVersion", useragent.New(), TrustedAPIPins)
|
||||
checker := NewTLSPinChecker(TrustedAPIPins)
|
||||
dialer := NewPinningTLSDialer(NewBasicTLSDialer(""), reporter, checker)
|
||||
|
||||
p := newProxyProvider(dialer, "", []string{"not used"})
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
|
||||
return []string{untrustedProxy1.URL, untrustedProxy2.URL}, nil
|
||||
}
|
||||
|
||||
_, err := p.findReachableServer()
|
||||
r.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProxyProvider_FindProxy_RefreshCacheTimeout(t *testing.T) {
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
|
||||
p.cacheRefreshTimeout = 1 * time.Second
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil }
|
||||
|
||||
// We should fail to refresh the proxy cache because the doh provider
|
||||
// takes 2 seconds to respond but we timeout after just 1 second.
|
||||
_, err := p.findReachableServer()
|
||||
|
||||
r.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProxyProvider_FindProxy_CanReachTimeout(t *testing.T) {
|
||||
slowProxy := getTrustedServerWithHandler(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
|
||||
time.Sleep(2 * time.Second)
|
||||
}))
|
||||
defer closeServer(slowProxy)
|
||||
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"not used"})
|
||||
p.canReachTimeout = 1 * time.Second
|
||||
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{slowProxy.URL}, nil }
|
||||
|
||||
// We should fail to reach the returned proxy because it takes 2 seconds
|
||||
// to reach it and we only allow 1.
|
||||
_, err := p.findReachableServer()
|
||||
|
||||
r.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
|
||||
|
||||
records, err := p.dohLookup(context.Background(), proxyQuery, Quad9Provider)
|
||||
r.NoError(t, err)
|
||||
r.NotEmpty(t, records)
|
||||
}
|
||||
|
||||
// DISABLEDTestProxyProvider_DoHLookup_Quad9Port cannot run on CI due to custom
|
||||
// port filter. Basic functionality should be covered by other tests. Keeping
|
||||
// code here to be able to run it locally if needed.
|
||||
func DISABLEDTestProxyProviderDoHLookupQuad9Port(t *testing.T) {
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
|
||||
|
||||
records, err := p.dohLookup(context.Background(), proxyQuery, Quad9PortProvider)
|
||||
r.NoError(t, err)
|
||||
r.NotEmpty(t, records)
|
||||
}
|
||||
|
||||
func TestProxyProvider_DoHLookup_Google(t *testing.T) {
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
|
||||
|
||||
records, err := p.dohLookup(context.Background(), proxyQuery, GoogleProvider)
|
||||
r.NoError(t, err)
|
||||
r.NotEmpty(t, records)
|
||||
}
|
||||
|
||||
func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{Quad9Provider, GoogleProvider})
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
r.NoError(t, err)
|
||||
r.NotEmpty(t, url)
|
||||
}
|
||||
|
||||
func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
p := newProxyProvider(NewBasicTLSDialer(""), "", []string{"https://unreachable", Quad9Provider, GoogleProvider})
|
||||
|
||||
url, err := p.findReachableServer()
|
||||
r.NoError(t, err)
|
||||
r.NotEmpty(t, url)
|
||||
}
|
||||
273
internal/dialer/dialer_proxy_test.go
Normal file
273
internal/dialer/dialer_proxy_test.go
Normal file
@ -0,0 +1,273 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// getTrustedServer returns a server and sets its public key as one of the pinned ones.
|
||||
func getTrustedServer() *httptest.Server {
|
||||
return getTrustedServerWithHandler(
|
||||
http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
|
||||
// Do nothing.
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
func getTrustedServerWithHandler(handler http.HandlerFunc) *httptest.Server {
|
||||
proxy := httptest.NewTLSServer(handler)
|
||||
|
||||
pin := certFingerprint(proxy.Certificate())
|
||||
TrustedAPIPins = append(TrustedAPIPins, pin)
|
||||
|
||||
return proxy
|
||||
}
|
||||
|
||||
const servercrt = `
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIE5TCCA82gAwIBAgIJAKsmhcMFGfGcMA0GCSqGSIb3DQEBCwUAMIGsMQswCQYD
|
||||
VQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzARBgNVBAcMClJhbmRvbUNp
|
||||
dHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEfMB0GA1UECwwWUmFuZG9t
|
||||
T3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYRaGVsbG9AZXhhbXBsZS5j
|
||||
b20xEjAQBgNVBAMMCTEyNy4wLjAuMTAeFw0yMDA0MjQxMzI3MzdaFw0yMTA5MDYx
|
||||
MzI3MzdaMIGsMQswCQYDVQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzAR
|
||||
BgNVBAcMClJhbmRvbUNpdHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEf
|
||||
MB0GA1UECwwWUmFuZG9tT3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYR
|
||||
aGVsbG9AZXhhbXBsZS5jb20xEjAQBgNVBAMMCTEyNy4wLjAuMTCCASIwDQYJKoZI
|
||||
hvcNAQEBBQADggEPADCCAQoCggEBANAnYyqhosWwNzGjBwSwmDUINOaPs4TSTgKt
|
||||
r6CE01atxAWzWUCyYqnQ4fPe5q2tx5t/VrmnTNpzycammKJszGLlmj9DFxSiYVw2
|
||||
pTTK3DBWFkfTwxq98mM7wMnCWy1T2L2pmuYjnd7Pa6pQa9OHYoJwRzlIl2Q3YVdM
|
||||
GIBDbkW728A1dcelkIdFpv3r3ayTZv01vU8JMXd4PLHwXU0x0hHlH52+kx+9Ndru
|
||||
rdqqV6LqVfNlSR1jFZkwLBBqvh3XrJRD9Q01EAX6m+ufZ0yq8mK9ifMRtwQet10c
|
||||
kKMnx63MwvxDFmqrBj4HMtIRUpK+LBDs1ke7DvS0eLqaojWl28ECAwEAAaOCAQYw
|
||||
ggECMIHLBgNVHSMEgcMwgcChgbKkga8wgawxCzAJBgNVBAYTAlVTMRQwEgYDVQQI
|
||||
DAtSYW5kb21TdGF0ZTETMBEGA1UEBwwKUmFuZG9tQ2l0eTEbMBkGA1UECgwSUmFu
|
||||
ZG9tT3JnYW5pemF0aW9uMR8wHQYDVQQLDBZSYW5kb21Pcmdhbml6YXRpb25Vbml0
|
||||
MSAwHgYJKoZIhvcNAQkBFhFoZWxsb0BleGFtcGxlLmNvbTESMBAGA1UEAwwJMTI3
|
||||
LjAuMC4xggkAvCxbs152YckwCQYDVR0TBAIwADALBgNVHQ8EBAMCBPAwGgYDVR0R
|
||||
BBMwEYIJMTI3LjAuMC4xhwR/AAABMA0GCSqGSIb3DQEBCwUAA4IBAQAC7ZycZMZ5
|
||||
L+cjIpwSj0cemLkVD+kcFUCkI7ket5gbX1PmavmnpuFl9Sru0eJ5wyJ+97MQElPA
|
||||
CNFgXoX7DbJWkcd/LSksvZoJnpc1sTqFKMWFmOUxmUD62lCacuhqE27ZTThQ/53P
|
||||
3doLa74rKzUqlPI8OL4R34FY2deL7t5l2KSnpf7CKNeF5bkinAsn6NBqyZs2KPmg
|
||||
yT1/POdlRewzGSqBTMdktNQ4vKSfdFjcfVeo8PSHBgbGXZ5KoHZ6R6DNJehEh27l
|
||||
z3OteROLGoii+w3OllLq6JATif2MDIbH0s/KjGjbXSSGbM/rZu5eBZm5/vksGAzc
|
||||
u53wgIhCJGuX
|
||||
-----END CERTIFICATE-----
|
||||
`
|
||||
|
||||
const serverkey = `
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDQJ2MqoaLFsDcx
|
||||
owcEsJg1CDTmj7OE0k4Cra+ghNNWrcQFs1lAsmKp0OHz3uatrcebf1a5p0zac8nG
|
||||
ppiibMxi5Zo/QxcUomFcNqU0ytwwVhZH08MavfJjO8DJwlstU9i9qZrmI53ez2uq
|
||||
UGvTh2KCcEc5SJdkN2FXTBiAQ25Fu9vANXXHpZCHRab9692sk2b9Nb1PCTF3eDyx
|
||||
8F1NMdIR5R+dvpMfvTXa7q3aqlei6lXzZUkdYxWZMCwQar4d16yUQ/UNNRAF+pvr
|
||||
n2dMqvJivYnzEbcEHrddHJCjJ8etzML8QxZqqwY+BzLSEVKSviwQ7NZHuw70tHi6
|
||||
mqI1pdvBAgMBAAECggEAOqqPOYm63arPs462QK0hCPlaJ41i1FGNqRWYxU4KXoi1
|
||||
EcI9qo1cX24+8MPnEhZDhuD56XNsprkxqmpz5Htzk4AQ3DmlfKxTcnD4WQu/yWPJ
|
||||
/c6CU7wrX6qMqJC9r+XM1Y/C15A8Q3sEZkkqSsECk67fdBawjI9LQRZyZVwb7U0F
|
||||
qtvbKM7VQA6hrgdSmXWJ+spp5yymVFF22Ssz31SSbCI93bnp3mukRCKWdRmA9pmT
|
||||
VXa0HzJ5p70WC+Se9nA/1riWGKt4HCmjVeEtZuiwaUTlXDSeYpu2e4QrX1OnUXBu
|
||||
Z7yfviTqA8o7KfiA6urumFbAMJcibxkWJoWacc5tTQKBgQD39ZdtNz8B6XJy7f5h
|
||||
bo9Ag9OrkVX+HITQyWKpcCDba9SuIX3/F++2AK4oeJ3aHKMJWiP19hQvGS1xE67X
|
||||
TKejOsQxORn6nAYQpFd3AOBOtKAC+VQITBqlfq2ukGmvcQ1O31hMOFbZagFA5cpU
|
||||
LYb9VVDsZzhM7CccIn/EGEZjgwKBgQDW51rUA2S9naV/iEGhw1tuhoQ5OADD/n8f
|
||||
pPIkbGxmACDaX/7jt+UwlDU0EsI+aBlJUDqGiEZ5z3UPmaSJUdfRCeJEdKIe1GLm
|
||||
nqF3sF6Aq+S/79v/wKYn+MHcoiWog5n3McLzZ3+0rwrhMREjE2eWPwVHz/jJIFP3
|
||||
Pp3+UZVsawKBgB4Az5PdjXgzwS968L7lW9wYl3I5Iciftsp0s8WA1dj3EUMItnA5
|
||||
ez3wkyI+hgswT+H/0D4gyoxwZXk7Qnq2wcoUgEzcdfJHEszMtfCmYH3liT8S4EIo
|
||||
w0inLWjj/IXIDi4vBEYkww2HsCMkKvlIkP7yZdpVGxDjuk/DNOaLcWj1AoGAXuyK
|
||||
PiPRl7/Onmp9MwqrlEJunSeTjv8W/89H9ba+mr9rw4mreMJ9xdtxNLMkgZRRtwRt
|
||||
FYeUObHdLyradp1kCr2m6D3sblm55cwj3k5VL9i9jdpQ/sMFoZpLZz1oDOs0Uu/0
|
||||
ALeyvQikcZvOygOEOeVUW8gNSCmzbP6HoxI+QkkCgYBCI6oL4GPcPPqzd+2djbOD
|
||||
z3rVUyHzYc1KUcBixK/uaRQKM886k4CL8/GvbHHI/yoZ7xWJGnBi59DtpqnGTZJ2
|
||||
FDJwYIlQKhZmsyVcZu/4smsaejGnHn/liksVlgesSwCtOrsd2AC8fBXSyrTWJx8o
|
||||
vwRMog6lPhlRhHh/FZ43Cg==
|
||||
-----END PRIVATE KEY-----
|
||||
`
|
||||
|
||||
// getUntrustedServer returns a server but it doesn't add its public key to the list of pinned ones.
|
||||
func getUntrustedServer() *httptest.Server {
|
||||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
cert, err := tls.X509KeyPair([]byte(servercrt), []byte(serverkey))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
server.TLS = &tls.Config{Certificates: []tls.Certificate{cert}}
|
||||
|
||||
server.StartTLS()
|
||||
return server
|
||||
}
|
||||
|
||||
// closeServer closes the given server. If it is a trusted server, its cert is removed from the trusted public keys.
|
||||
func closeServer(server *httptest.Server) {
|
||||
pin := certFingerprint(server.Certificate())
|
||||
|
||||
for i := range TrustedAPIPins {
|
||||
if TrustedAPIPins[i] == pin {
|
||||
TrustedAPIPins = append(TrustedAPIPins[:i], TrustedAPIPins[i:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
server.Close()
|
||||
}
|
||||
|
||||
func TestProxyDialer_UseProxy(t *testing.T) {
|
||||
trustedProxy := getTrustedServer()
|
||||
defer closeServer(trustedProxy)
|
||||
|
||||
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
|
||||
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
|
||||
d.proxyProvider = provider
|
||||
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
|
||||
|
||||
err := d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
|
||||
}
|
||||
|
||||
func TestProxyDialer_UseProxy_MultipleTimes(t *testing.T) {
|
||||
proxy1 := getTrustedServer()
|
||||
defer closeServer(proxy1)
|
||||
proxy2 := getTrustedServer()
|
||||
defer closeServer(proxy2)
|
||||
proxy3 := getTrustedServer()
|
||||
defer closeServer(proxy3)
|
||||
|
||||
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
|
||||
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
|
||||
d.proxyProvider = provider
|
||||
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL}, nil }
|
||||
|
||||
err := d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(proxy1.URL), d.proxyAddress)
|
||||
|
||||
// Have to wait so as to not get rejected.
|
||||
time.Sleep(proxyLookupWait)
|
||||
|
||||
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy2.URL}, nil }
|
||||
err = d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(proxy2.URL), d.proxyAddress)
|
||||
|
||||
// Have to wait so as to not get rejected.
|
||||
time.Sleep(proxyLookupWait)
|
||||
|
||||
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy3.URL}, nil }
|
||||
err = d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(proxy3.URL), d.proxyAddress)
|
||||
}
|
||||
|
||||
func TestProxyDialer_UseProxy_RevertAfterTime(t *testing.T) {
|
||||
trustedProxy := getTrustedServer()
|
||||
defer closeServer(trustedProxy)
|
||||
|
||||
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
|
||||
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
|
||||
d.proxyProvider = provider
|
||||
d.proxyUseDuration = time.Second
|
||||
|
||||
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
|
||||
err := d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
require.Equal(t, ":443", d.proxyAddress)
|
||||
}
|
||||
|
||||
func TestProxyDialer_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {
|
||||
trustedProxy := getTrustedServer()
|
||||
|
||||
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
|
||||
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
|
||||
d.proxyProvider = provider
|
||||
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
|
||||
|
||||
err := d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
|
||||
|
||||
// Simulate that the proxy stops working and that the standard api is reachable again.
|
||||
closeServer(trustedProxy)
|
||||
d.directAddress = formatAsAddress(getRootURL())
|
||||
provider.hostURL = getRootURL()
|
||||
time.Sleep(proxyLookupWait)
|
||||
|
||||
// We should now find the original API URL if it is working again.
|
||||
// The error should be ErrAPINotReachable because the connection dropped intermittently but
|
||||
// the original API is now reachable (see Alternative-Routing-v2 spec for details).
|
||||
err = d.switchToReachableServer()
|
||||
require.Error(t, err)
|
||||
require.Equal(t, formatAsAddress(getRootURL()), d.proxyAddress)
|
||||
}
|
||||
|
||||
func TestProxyDialer_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) {
|
||||
// proxy1 is closed later in this test so we don't defer it here.
|
||||
proxy1 := getTrustedServer()
|
||||
|
||||
proxy2 := getTrustedServer()
|
||||
defer closeServer(proxy2)
|
||||
|
||||
provider := newProxyProvider(NewBasicTLSDialer(""), "", DoHProviders)
|
||||
d := NewProxyTLSDialer(NewBasicTLSDialer(""), "")
|
||||
d.proxyProvider = provider
|
||||
provider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil }
|
||||
|
||||
err := d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(proxy1.URL), d.proxyAddress)
|
||||
|
||||
// Have to wait so as to not get rejected.
|
||||
time.Sleep(proxyLookupWait)
|
||||
|
||||
// The proxy stops working and the protonmail API is still blocked.
|
||||
closeServer(proxy1)
|
||||
|
||||
// Should switch to the second proxy because both the first proxy and the protonmail API are blocked.
|
||||
err = d.switchToReachableServer()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, formatAsAddress(proxy2.URL), d.proxyAddress)
|
||||
}
|
||||
|
||||
func TestFormatAsAddress(t *testing.T) {
|
||||
r := require.New(t)
|
||||
testData := map[string]string{
|
||||
"sub.domain.tld": "sub.domain.tld:443",
|
||||
"http://sub.domain.tld": "sub.domain.tld:80",
|
||||
"https://sub.domain.tld": "sub.domain.tld:443",
|
||||
"ftp://sub.domain.tld": "sub.domain.tld:443",
|
||||
"//sub.domain.tld": "sub.domain.tld:443",
|
||||
}
|
||||
|
||||
for rawURL, wantURL := range testData {
|
||||
r.Equal(wantURL, formatAsAddress(rawURL))
|
||||
}
|
||||
}
|
||||
16
internal/dialer/dialer_test.go
Normal file
16
internal/dialer/dialer_test.go
Normal file
@ -0,0 +1,16 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/http/httpproxy"
|
||||
)
|
||||
|
||||
// skipIfProxyIsSet skips the tests if HTTPS proxy is set.
|
||||
// Should be used for tests depending on proper certificate checks which
|
||||
// is not possible under our CI setup.
|
||||
func skipIfProxyIsSet(t *testing.T) {
|
||||
if httpproxy.FromEnvironment().HTTPSProxy != "" {
|
||||
t.SkipNow()
|
||||
}
|
||||
}
|
||||
13
internal/events/connection.go
Normal file
13
internal/events/connection.go
Normal file
@ -0,0 +1,13 @@
|
||||
package events
|
||||
|
||||
import "gitlab.protontech.ch/go/liteapi"
|
||||
|
||||
type TLSIssue struct {
|
||||
eventBase
|
||||
}
|
||||
|
||||
type ConnStatus struct {
|
||||
eventBase
|
||||
|
||||
Status liteapi.Status
|
||||
}
|
||||
7
internal/events/error.go
Normal file
7
internal/events/error.go
Normal file
@ -0,0 +1,7 @@
|
||||
package events
|
||||
|
||||
type Error struct {
|
||||
eventBase
|
||||
|
||||
Error error
|
||||
}
|
||||
@ -1,60 +1,9 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Package events provides names of events used by the event listener in bridge.
|
||||
package events
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
|
||||
)
|
||||
|
||||
// Constants of events used by the event listener in bridge.
|
||||
const (
|
||||
ErrorEvent = "error"
|
||||
CredentialsErrorEvent = "credentialsError"
|
||||
CloseConnectionEvent = "closeConnection"
|
||||
LogoutEvent = "logout"
|
||||
AddressChangedEvent = "addressChanged"
|
||||
AddressChangedLogoutEvent = "addressChangedLogout"
|
||||
UserRefreshEvent = "userRefresh"
|
||||
RestartBridgeEvent = "restartBridge"
|
||||
InternetConnChangedEvent = "internetChanged"
|
||||
InternetOff = "internetOff"
|
||||
InternetOn = "internetOn"
|
||||
SecondInstanceEvent = "secondInstance"
|
||||
NoActiveKeyForRecipientEvent = "noActiveKeyForRecipient"
|
||||
UpgradeApplicationEvent = "upgradeApplication"
|
||||
TLSCertIssue = "tlsCertPinningIssue"
|
||||
UserChangeDone = "QMLUserChangedDone"
|
||||
|
||||
// LogoutEventTimeout is the minimum time to permit between logout events being sent.
|
||||
LogoutEventTimeout = 3 * time.Minute
|
||||
)
|
||||
|
||||
// SetupEvents specific to event type and data.
|
||||
func SetupEvents(listener listener.Listener) {
|
||||
listener.SetLimit(LogoutEvent, LogoutEventTimeout)
|
||||
listener.SetBuffer(ErrorEvent)
|
||||
listener.SetBuffer(CredentialsErrorEvent)
|
||||
listener.SetBuffer(InternetConnChangedEvent)
|
||||
listener.SetBuffer(UpgradeApplicationEvent)
|
||||
listener.SetBuffer(TLSCertIssue)
|
||||
listener.SetBuffer(UserRefreshEvent)
|
||||
listener.Book(UserChangeDone)
|
||||
type Event interface {
|
||||
_isEvent()
|
||||
}
|
||||
|
||||
type eventBase struct{}
|
||||
|
||||
func (eventBase) _isEvent() {}
|
||||
|
||||
5
internal/events/raise.go
Normal file
5
internal/events/raise.go
Normal file
@ -0,0 +1,5 @@
|
||||
package events
|
||||
|
||||
type Raise struct {
|
||||
eventBase
|
||||
}
|
||||
24
internal/events/sync.go
Normal file
24
internal/events/sync.go
Normal file
@ -0,0 +1,24 @@
|
||||
package events
|
||||
|
||||
import "time"
|
||||
|
||||
type SyncStarted struct {
|
||||
eventBase
|
||||
|
||||
UserID string
|
||||
}
|
||||
|
||||
type SyncProgress struct {
|
||||
eventBase
|
||||
|
||||
UserID string
|
||||
Progress float64
|
||||
Elapsed time.Duration
|
||||
Remaining time.Duration
|
||||
}
|
||||
|
||||
type SyncFinished struct {
|
||||
eventBase
|
||||
|
||||
UserID string
|
||||
}
|
||||
25
internal/events/update.go
Normal file
25
internal/events/update.go
Normal file
@ -0,0 +1,25 @@
|
||||
package events
|
||||
|
||||
import "github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
|
||||
type UpdateAvailable struct {
|
||||
eventBase
|
||||
|
||||
Version updater.VersionInfo
|
||||
|
||||
CanInstall bool
|
||||
}
|
||||
|
||||
type UpdateNotAvailable struct {
|
||||
eventBase
|
||||
}
|
||||
|
||||
type UpdateInstalled struct {
|
||||
eventBase
|
||||
|
||||
Version updater.VersionInfo
|
||||
}
|
||||
|
||||
type UpdateForced struct {
|
||||
eventBase
|
||||
}
|
||||
52
internal/events/user.go
Normal file
52
internal/events/user.go
Normal file
@ -0,0 +1,52 @@
|
||||
package events
|
||||
|
||||
type UserLoggedIn struct {
|
||||
eventBase
|
||||
|
||||
UserID string
|
||||
}
|
||||
|
||||
type UserLoggedOut struct {
|
||||
eventBase
|
||||
|
||||
UserID string
|
||||
}
|
||||
|
||||
type UserDeauth struct {
|
||||
eventBase
|
||||
|
||||
UserID string
|
||||
}
|
||||
|
||||
type UserDeleted struct {
|
||||
eventBase
|
||||
|
||||
UserID string
|
||||
}
|
||||
|
||||
type UserChanged struct {
|
||||
eventBase
|
||||
|
||||
UserID string
|
||||
}
|
||||
|
||||
type UserAddressCreated struct {
|
||||
eventBase
|
||||
|
||||
UserID string
|
||||
Address string
|
||||
}
|
||||
|
||||
type UserAddressChanged struct {
|
||||
eventBase
|
||||
|
||||
UserID string
|
||||
Address string
|
||||
}
|
||||
|
||||
type UserAddressDeleted struct {
|
||||
eventBase
|
||||
|
||||
UserID string
|
||||
Address string
|
||||
}
|
||||
32
internal/focus/client.go
Normal file
32
internal/focus/client.go
Normal file
@ -0,0 +1,32 @@
|
||||
package focus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/focus/proto"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
func TryRaise() bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
cc, err := grpc.DialContext(ctx, net.JoinHostPort(Host, fmt.Sprint(Port)), grpc.WithInsecure())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if _, err := proto.NewFocusClient(cc).Raise(ctx, &emptypb.Empty{}); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := cc.Close(); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
25
internal/focus/focus_test.go
Normal file
25
internal/focus/focus_test.go
Normal file
@ -0,0 +1,25 @@
|
||||
package focus
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFocusRaise(t *testing.T) {
|
||||
// Start the focus service.
|
||||
service, err := NewService()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to dial it, it should succeed.
|
||||
require.True(t, TryRaise())
|
||||
|
||||
// The service should report a raise call.
|
||||
<-service.GetRaiseCh()
|
||||
|
||||
// Stop the service.
|
||||
service.Close()
|
||||
|
||||
// Try to dial it, it should fail.
|
||||
require.False(t, TryRaise())
|
||||
}
|
||||
3
internal/focus/proto/focus.go
Normal file
3
internal/focus/proto/focus.go
Normal file
@ -0,0 +1,3 @@
|
||||
package proto
|
||||
|
||||
//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative focus.proto
|
||||
93
internal/focus/proto/focus.pb.go
Normal file
93
internal/focus/proto/focus.pb.go
Normal file
@ -0,0 +1,93 @@
|
||||
// Copyright (c) 2022 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.28.1
|
||||
// protoc v3.21.6
|
||||
// source: focus.proto
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
emptypb "google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
var File_focus_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_focus_proto_rawDesc = []byte{
|
||||
0x0a, 0x0b, 0x66, 0x6f, 0x63, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x32, 0x40, 0x0a, 0x05, 0x46, 0x6f, 0x63, 0x75, 0x73, 0x12, 0x37, 0x0a, 0x05, 0x52, 0x61,
|
||||
0x69, 0x73, 0x65, 0x12, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x16, 0x2e, 0x67, 0x6f,
|
||||
0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d,
|
||||
0x70, 0x74, 0x79, 0x42, 0x3d, 0x5a, 0x3b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f,
|
||||
0x6d, 0x2f, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x6e, 0x4d, 0x61, 0x69, 0x6c, 0x2f, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x6e, 0x2d, 0x62, 0x72, 0x69, 0x64, 0x67, 0x65, 0x2f, 0x76, 0x32, 0x2f, 0x69, 0x6e,
|
||||
0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x66, 0x6f, 0x63, 0x75, 0x73, 0x2f, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var file_focus_proto_goTypes = []interface{}{
|
||||
(*emptypb.Empty)(nil), // 0: google.protobuf.Empty
|
||||
}
|
||||
var file_focus_proto_depIdxs = []int32{
|
||||
0, // 0: proto.Focus.Raise:input_type -> google.protobuf.Empty
|
||||
0, // 1: proto.Focus.Raise:output_type -> google.protobuf.Empty
|
||||
1, // [1:2] is the sub-list for method output_type
|
||||
0, // [0:1] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_focus_proto_init() }
|
||||
func file_focus_proto_init() {
|
||||
if File_focus_proto != nil {
|
||||
return
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_focus_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 0,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_focus_proto_goTypes,
|
||||
DependencyIndexes: file_focus_proto_depIdxs,
|
||||
}.Build()
|
||||
File_focus_proto = out.File
|
||||
file_focus_proto_rawDesc = nil
|
||||
file_focus_proto_goTypes = nil
|
||||
file_focus_proto_depIdxs = nil
|
||||
}
|
||||
31
internal/focus/proto/focus.proto
Normal file
31
internal/focus/proto/focus.proto
Normal file
@ -0,0 +1,31 @@
|
||||
// Copyright (c) 2022 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/empty.proto";
|
||||
|
||||
option go_package = "github.com/ProtonMail/proton-bridge/v2/internal/focus/proto";
|
||||
|
||||
package proto;
|
||||
|
||||
//**********************************************************************************************************************
|
||||
// Service Declaration
|
||||
//**********************************************************************************************************************≠––
|
||||
service Focus {
|
||||
rpc Raise(google.protobuf.Empty) returns (google.protobuf.Empty);
|
||||
}
|
||||
107
internal/focus/proto/focus_grpc.pb.go
Normal file
107
internal/focus/proto/focus_grpc.pb.go
Normal file
@ -0,0 +1,107 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.2.0
|
||||
// - protoc v3.21.6
|
||||
// source: focus.proto
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
emptypb "google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7
|
||||
|
||||
// FocusClient is the client API for Focus service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type FocusClient interface {
|
||||
Raise(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||
}
|
||||
|
||||
type focusClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewFocusClient(cc grpc.ClientConnInterface) FocusClient {
|
||||
return &focusClient{cc}
|
||||
}
|
||||
|
||||
func (c *focusClient) Raise(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error) {
|
||||
out := new(emptypb.Empty)
|
||||
err := c.cc.Invoke(ctx, "/proto.Focus/Raise", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// FocusServer is the server API for Focus service.
|
||||
// All implementations must embed UnimplementedFocusServer
|
||||
// for forward compatibility
|
||||
type FocusServer interface {
|
||||
Raise(context.Context, *emptypb.Empty) (*emptypb.Empty, error)
|
||||
mustEmbedUnimplementedFocusServer()
|
||||
}
|
||||
|
||||
// UnimplementedFocusServer must be embedded to have forward compatible implementations.
|
||||
type UnimplementedFocusServer struct {
|
||||
}
|
||||
|
||||
func (UnimplementedFocusServer) Raise(context.Context, *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Raise not implemented")
|
||||
}
|
||||
func (UnimplementedFocusServer) mustEmbedUnimplementedFocusServer() {}
|
||||
|
||||
// UnsafeFocusServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to FocusServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeFocusServer interface {
|
||||
mustEmbedUnimplementedFocusServer()
|
||||
}
|
||||
|
||||
func RegisterFocusServer(s grpc.ServiceRegistrar, srv FocusServer) {
|
||||
s.RegisterService(&Focus_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _Focus_Raise_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(emptypb.Empty)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(FocusServer).Raise(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/proto.Focus/Raise",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(FocusServer).Raise(ctx, req.(*emptypb.Empty))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// Focus_ServiceDesc is the grpc.ServiceDesc for Focus service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var Focus_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "proto.Focus",
|
||||
HandlerType: (*FocusServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Raise",
|
||||
Handler: _Focus_Raise_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "focus.proto",
|
||||
}
|
||||
60
internal/focus/service.go
Normal file
60
internal/focus/service.go
Normal file
@ -0,0 +1,60 @@
|
||||
package focus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/focus/proto"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
const (
|
||||
Host = "127.0.0.1"
|
||||
Port = 1042
|
||||
)
|
||||
|
||||
type FocusService struct {
|
||||
proto.UnimplementedFocusServer
|
||||
|
||||
server *grpc.Server
|
||||
listener net.Listener
|
||||
raiseCh chan struct{}
|
||||
}
|
||||
|
||||
func NewService() (*FocusService, error) {
|
||||
listener, err := net.Listen("tcp", net.JoinHostPort(Host, fmt.Sprint(Port)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen: %w", err)
|
||||
}
|
||||
|
||||
service := &FocusService{
|
||||
server: grpc.NewServer(),
|
||||
listener: listener,
|
||||
raiseCh: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
proto.RegisterFocusServer(service.server, service)
|
||||
|
||||
go func() {
|
||||
if err := service.server.Serve(listener); err != nil {
|
||||
fmt.Printf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (service *FocusService) Raise(context.Context, *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
service.raiseCh <- struct{}{}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (service *FocusService) GetRaiseCh() <-chan struct{} {
|
||||
return service.raiseCh
|
||||
}
|
||||
|
||||
func (service *FocusService) Close() {
|
||||
service.server.Stop()
|
||||
}
|
||||
@ -22,7 +22,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/users"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/abiosoft/ishell"
|
||||
)
|
||||
|
||||
@ -35,6 +35,7 @@ func (f *frontendCLI) completeUsernames(args []string) (usernames []string) {
|
||||
if len(args) == 1 {
|
||||
arg = args[0]
|
||||
}
|
||||
|
||||
for _, userID := range f.bridge.GetUserIDs() {
|
||||
user, err := f.bridge.GetUserInfo(userID)
|
||||
if err != nil {
|
||||
@ -50,8 +51,7 @@ func (f *frontendCLI) completeUsernames(args []string) (usernames []string) {
|
||||
// noAccountWrapper is a decorator for functions which need any account to be properly functional.
|
||||
func (f *frontendCLI) noAccountWrapper(callback func(*ishell.Context)) func(*ishell.Context) {
|
||||
return func(c *ishell.Context) {
|
||||
users := f.bridge.GetUserIDs()
|
||||
if len(users) == 0 {
|
||||
if len(f.bridge.GetUserIDs()) == 0 {
|
||||
f.Println("No active accounts. Please add account to continue.")
|
||||
} else {
|
||||
callback(c)
|
||||
@ -59,9 +59,9 @@ func (f *frontendCLI) noAccountWrapper(callback func(*ishell.Context)) func(*ish
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) askUserByIndexOrName(c *ishell.Context) users.UserInfo {
|
||||
func (f *frontendCLI) askUserByIndexOrName(c *ishell.Context) bridge.UserInfo {
|
||||
user := f.getUserByIndexOrName("")
|
||||
if user.ID != "" {
|
||||
if user.UserID != "" {
|
||||
return user
|
||||
}
|
||||
|
||||
@ -69,24 +69,24 @@ func (f *frontendCLI) askUserByIndexOrName(c *ishell.Context) users.UserInfo {
|
||||
indexRange := fmt.Sprintf("number between 0 and %d", numberOfAccounts-1)
|
||||
if len(c.Args) == 0 {
|
||||
f.Printf("Please choose %s or username.\n", indexRange)
|
||||
return users.UserInfo{}
|
||||
return bridge.UserInfo{}
|
||||
}
|
||||
arg := c.Args[0]
|
||||
user = f.getUserByIndexOrName(arg)
|
||||
if user.ID == "" {
|
||||
if user.UserID == "" {
|
||||
f.Printf("Wrong input '%s'. Choose %s or username.\n", bold(arg), indexRange)
|
||||
return users.UserInfo{}
|
||||
return bridge.UserInfo{}
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
func (f *frontendCLI) getUserByIndexOrName(arg string) users.UserInfo {
|
||||
func (f *frontendCLI) getUserByIndexOrName(arg string) bridge.UserInfo {
|
||||
userIDs := f.bridge.GetUserIDs()
|
||||
numberOfAccounts := len(userIDs)
|
||||
if numberOfAccounts == 0 {
|
||||
return users.UserInfo{}
|
||||
return bridge.UserInfo{}
|
||||
}
|
||||
res := make([]users.UserInfo, len(userIDs))
|
||||
res := make([]bridge.UserInfo, len(userIDs))
|
||||
for idx, userID := range userIDs {
|
||||
user, err := f.bridge.GetUserInfo(userID)
|
||||
if err != nil {
|
||||
@ -99,7 +99,7 @@ func (f *frontendCLI) getUserByIndexOrName(arg string) users.UserInfo {
|
||||
}
|
||||
if index, err := strconv.Atoi(arg); err == nil {
|
||||
if index < 0 || index >= numberOfAccounts {
|
||||
return users.UserInfo{}
|
||||
return bridge.UserInfo{}
|
||||
}
|
||||
return res[index]
|
||||
}
|
||||
@ -108,5 +108,5 @@ func (f *frontendCLI) getUserByIndexOrName(arg string) users.UserInfo {
|
||||
return user
|
||||
}
|
||||
}
|
||||
return users.UserInfo{}
|
||||
return bridge.UserInfo{}
|
||||
}
|
||||
|
||||
@ -22,8 +22,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/users"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/abiosoft/ishell"
|
||||
)
|
||||
|
||||
@ -40,7 +39,7 @@ func (f *frontendCLI) listAccounts(c *ishell.Context) {
|
||||
connected = "connected"
|
||||
}
|
||||
mode := "split"
|
||||
if user.Mode == users.CombinedMode {
|
||||
if user.AddressMode == bridge.CombinedMode {
|
||||
mode = "combined"
|
||||
}
|
||||
f.Printf(spacing, idx, user.Username, connected, mode)
|
||||
@ -50,7 +49,7 @@ func (f *frontendCLI) listAccounts(c *ishell.Context) {
|
||||
|
||||
func (f *frontendCLI) showAccountInfo(c *ishell.Context) {
|
||||
user := f.askUserByIndexOrName(c)
|
||||
if user.ID == "" {
|
||||
if user.UserID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
@ -59,8 +58,8 @@ func (f *frontendCLI) showAccountInfo(c *ishell.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if user.Mode == users.CombinedMode {
|
||||
f.showAccountAddressInfo(user, user.Addresses[user.Primary])
|
||||
if user.AddressMode == bridge.CombinedMode {
|
||||
f.showAccountAddressInfo(user, user.Addresses[0])
|
||||
} else {
|
||||
for _, address := range user.Addresses {
|
||||
f.showAccountAddressInfo(user, address)
|
||||
@ -68,25 +67,31 @@ func (f *frontendCLI) showAccountInfo(c *ishell.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) showAccountAddressInfo(user users.UserInfo, address string) {
|
||||
func (f *frontendCLI) showAccountAddressInfo(user bridge.UserInfo, address string) {
|
||||
imapSecurity := "STARTTLS"
|
||||
if f.bridge.GetIMAPSSL() {
|
||||
imapSecurity = "SSL"
|
||||
}
|
||||
|
||||
smtpSecurity := "STARTTLS"
|
||||
if f.bridge.GetBool(settings.SMTPSSLKey) {
|
||||
if f.bridge.GetSMTPSSL() {
|
||||
smtpSecurity = "SSL"
|
||||
}
|
||||
|
||||
f.Println(bold("Configuration for " + address))
|
||||
f.Printf("IMAP Settings\nAddress: %s\nIMAP port: %d\nUsername: %s\nPassword: %s\nSecurity: %s\n",
|
||||
bridge.Host,
|
||||
f.bridge.GetInt(settings.IMAPPortKey),
|
||||
constants.Host,
|
||||
f.bridge.GetIMAPPort(),
|
||||
address,
|
||||
user.Password,
|
||||
"STARTTLS",
|
||||
user.BridgePass,
|
||||
imapSecurity,
|
||||
)
|
||||
f.Println("")
|
||||
f.Printf("SMTP Settings\nAddress: %s\nSMTP port: %d\nUsername: %s\nPassword: %s\nSecurity: %s\n",
|
||||
bridge.Host,
|
||||
f.bridge.GetInt(settings.SMTPPortKey),
|
||||
constants.Host,
|
||||
f.bridge.GetSMTPPort(),
|
||||
address,
|
||||
user.Password,
|
||||
user.BridgePass,
|
||||
smtpSecurity,
|
||||
)
|
||||
f.Println("")
|
||||
@ -99,8 +104,8 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { //nolint:funlen
|
||||
loginName := ""
|
||||
if len(c.Args) > 0 {
|
||||
user := f.getUserByIndexOrName(c.Args[0])
|
||||
if user.ID != "" {
|
||||
loginName = user.Addresses[user.Primary]
|
||||
if user.UserID != "" {
|
||||
loginName = user.Addresses[0]
|
||||
}
|
||||
}
|
||||
|
||||
@ -119,41 +124,23 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { //nolint:funlen
|
||||
}
|
||||
|
||||
f.Println("Authenticating ... ")
|
||||
client, auth, err := f.bridge.Login(loginName, []byte(password))
|
||||
|
||||
userID, err := f.bridge.LoginUser(
|
||||
context.Background(),
|
||||
loginName,
|
||||
password,
|
||||
func() (string, error) {
|
||||
return f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty), nil
|
||||
},
|
||||
func() ([]byte, error) {
|
||||
return []byte(f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)), nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
f.processAPIError(err)
|
||||
return
|
||||
}
|
||||
|
||||
if auth.HasTwoFactor() {
|
||||
twoFactor := f.readStringInAttempts("Two factor code", c.ReadLine, isNotEmpty)
|
||||
if twoFactor == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err = client.Auth2FA(context.Background(), twoFactor)
|
||||
if err != nil {
|
||||
f.processAPIError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
mailboxPassword := password
|
||||
if auth.HasMailboxPassword() {
|
||||
mailboxPassword = f.readStringInAttempts("Mailbox password", c.ReadPassword, isNotEmpty)
|
||||
}
|
||||
if mailboxPassword == "" {
|
||||
return
|
||||
}
|
||||
|
||||
f.Println("Adding account ...")
|
||||
userID, err := f.bridge.FinishLogin(client, auth, []byte(mailboxPassword))
|
||||
if err != nil {
|
||||
log.WithField("username", loginName).WithError(err).Error("Login was unsuccessful")
|
||||
f.Println("Adding account was unsuccessful:", err)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := f.bridge.GetUserInfo(userID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@ -167,11 +154,12 @@ func (f *frontendCLI) logoutAccount(c *ishell.Context) {
|
||||
defer f.ShowPrompt(true)
|
||||
|
||||
user := f.askUserByIndexOrName(c)
|
||||
if user.ID == "" {
|
||||
if user.UserID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if f.yesNoQuestion("Are you sure you want to logout account " + bold(user.Username)) {
|
||||
if err := f.bridge.LogoutUser(user.ID); err != nil {
|
||||
if err := f.bridge.LogoutUser(context.Background(), user.UserID); err != nil {
|
||||
f.printAndLogError("Logging out failed: ", err)
|
||||
}
|
||||
}
|
||||
@ -182,12 +170,12 @@ func (f *frontendCLI) deleteAccount(c *ishell.Context) {
|
||||
defer f.ShowPrompt(true)
|
||||
|
||||
user := f.askUserByIndexOrName(c)
|
||||
if user.ID == "" {
|
||||
if user.UserID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if f.yesNoQuestion("Are you sure you want to " + bold("remove account "+user.Username)) {
|
||||
clearCache := f.yesNoQuestion("Do you want to remove cache for this account")
|
||||
if err := f.bridge.DeleteUser(user.ID, clearCache); err != nil {
|
||||
if err := f.bridge.DeleteUser(context.Background(), user.UserID); err != nil {
|
||||
f.printAndLogError("Cannot delete account: ", err)
|
||||
return
|
||||
}
|
||||
@ -205,10 +193,13 @@ func (f *frontendCLI) deleteAccounts(c *ishell.Context) {
|
||||
for _, userID := range f.bridge.GetUserIDs() {
|
||||
user, err := f.bridge.GetUserInfo(userID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
f.printAndLogError("Cannot get user info: ", err)
|
||||
return
|
||||
}
|
||||
if err := f.bridge.DeleteUser(user.ID, false); err != nil {
|
||||
|
||||
if err := f.bridge.DeleteUser(context.Background(), user.UserID); err != nil {
|
||||
f.printAndLogError("Cannot delete account ", user.Username, ": ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@ -223,37 +214,50 @@ func (f *frontendCLI) deleteEverything(c *ishell.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
f.bridge.FactoryReset()
|
||||
f.bridge.FactoryReset(context.Background())
|
||||
|
||||
c.Println("Everything cleared")
|
||||
|
||||
// Clearing data removes everything (db, preferences, ...) so everything has to be stopped and started again.
|
||||
f.restarter.SetToRestart()
|
||||
|
||||
f.Stop()
|
||||
}
|
||||
|
||||
func (f *frontendCLI) changeMode(c *ishell.Context) {
|
||||
user := f.askUserByIndexOrName(c)
|
||||
if user.ID == "" {
|
||||
if user.UserID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var targetMode users.AddressMode
|
||||
var targetMode bridge.AddressMode
|
||||
|
||||
if user.Mode == users.CombinedMode {
|
||||
targetMode = users.SplitMode
|
||||
if user.AddressMode == bridge.CombinedMode {
|
||||
targetMode = bridge.SplitMode
|
||||
} else {
|
||||
targetMode = users.CombinedMode
|
||||
targetMode = bridge.CombinedMode
|
||||
}
|
||||
|
||||
if !f.yesNoQuestion("Are you sure you want to change the mode for account " + bold(user.Username) + " to " + bold(targetMode)) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := f.bridge.SetAddressMode(user.ID, targetMode); err != nil {
|
||||
if err := f.bridge.SetAddressMode(user.UserID, targetMode); err != nil {
|
||||
f.printAndLogError("Cannot switch address mode:", err)
|
||||
}
|
||||
|
||||
f.Printf("Address mode for account %s changed to %s\n", user.Username, targetMode)
|
||||
}
|
||||
|
||||
func (f *frontendCLI) configureAppleMail(c *ishell.Context) {
|
||||
user := f.askUserByIndexOrName(c)
|
||||
if user.UserID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if !f.yesNoQuestion("Are you sure you want to configure Apple Mail for " + bold(user.Username) + " with address " + bold(user.Addresses[0])) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := f.bridge.ConfigureAppleMail(user.UserID, user.Addresses[0]); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
|
||||
f.Printf("Apple Mail configured for %v with address %v\n", user.Username, user.Addresses[0])
|
||||
}
|
||||
|
||||
@ -19,11 +19,13 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/types"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
|
||||
"github.com/abiosoft/ishell"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -34,30 +36,14 @@ var log = logrus.WithField("pkg", "frontend/cli") //nolint:gochecknoglobals
|
||||
type frontendCLI struct {
|
||||
*ishell.Shell
|
||||
|
||||
eventListener listener.Listener
|
||||
updater types.Updater
|
||||
bridge types.Bridger
|
||||
|
||||
restarter types.Restarter
|
||||
bridge *bridge.Bridge
|
||||
}
|
||||
|
||||
// New returns a new CLI frontend configured with the given options.
|
||||
func New( //nolint:funlen
|
||||
panicHandler types.PanicHandler,
|
||||
|
||||
eventListener listener.Listener,
|
||||
updater types.Updater,
|
||||
bridge types.Bridger,
|
||||
restarter types.Restarter,
|
||||
) *frontendCLI { //nolint:revive
|
||||
func New(bridge *bridge.Bridge) *frontendCLI {
|
||||
fe := &frontendCLI{
|
||||
Shell: ishell.New(),
|
||||
|
||||
eventListener: eventListener,
|
||||
updater: updater,
|
||||
bridge: bridge,
|
||||
|
||||
restarter: restarter,
|
||||
Shell: ishell.New(),
|
||||
bridge: bridge,
|
||||
}
|
||||
|
||||
// Clear commands.
|
||||
@ -66,12 +52,6 @@ func New( //nolint:funlen
|
||||
Help: "remove stored accounts and preferences. (alias: cl)",
|
||||
Aliases: []string{"cl"},
|
||||
}
|
||||
clearCmd.AddCmd(&ishell.Cmd{
|
||||
Name: "cache",
|
||||
Help: "remove stored preferences for accounts (aliases: c, prefs, preferences)",
|
||||
Aliases: []string{"c", "prefs", "preferences"},
|
||||
Func: fe.deleteCache,
|
||||
})
|
||||
clearCmd.AddCmd(&ishell.Cmd{
|
||||
Name: "accounts",
|
||||
Help: "remove all accounts from keychain. (aliases: a, k, keychain)",
|
||||
@ -100,15 +80,30 @@ func New( //nolint:funlen
|
||||
Completer: fe.completeUsernames,
|
||||
})
|
||||
changeCmd.AddCmd(&ishell.Cmd{
|
||||
Name: "port",
|
||||
Help: "change port numbers of IMAP and SMTP servers. (alias: p)",
|
||||
Aliases: []string{"p"},
|
||||
Func: fe.changePort,
|
||||
Name: "change-location",
|
||||
Help: "change the location of the encrypted message cache",
|
||||
Func: fe.setGluonLocation,
|
||||
})
|
||||
changeCmd.AddCmd(&ishell.Cmd{
|
||||
Name: "imap-port",
|
||||
Help: "change port number of IMAP server.",
|
||||
Func: fe.changeIMAPPort,
|
||||
})
|
||||
changeCmd.AddCmd(&ishell.Cmd{
|
||||
Name: "smtp-port",
|
||||
Help: "change port number of SMTP server.",
|
||||
Func: fe.changeSMTPPort,
|
||||
})
|
||||
changeCmd.AddCmd(&ishell.Cmd{
|
||||
Name: "imap-security",
|
||||
Help: "change IMAP SSL settings servers.(alias: ssl-imap, starttls-imap)",
|
||||
Aliases: []string{"ssl-imap", "starttls-imap"},
|
||||
Func: fe.changeIMAPSecurity,
|
||||
})
|
||||
changeCmd.AddCmd(&ishell.Cmd{
|
||||
Name: "smtp-security",
|
||||
Help: "change port numbers of IMAP and SMTP servers.(alias: ssl, starttls)",
|
||||
Aliases: []string{"ssl", "starttls"},
|
||||
Help: "change SMTP SSL settings servers.(alias: ssl-smtp, starttls-smtp)",
|
||||
Aliases: []string{"ssl-smtp", "starttls-smtp"},
|
||||
Func: fe.changeSMTPSecurity,
|
||||
})
|
||||
fe.AddCmd(changeCmd)
|
||||
@ -130,6 +125,22 @@ func New( //nolint:funlen
|
||||
})
|
||||
fe.AddCmd(dohCmd)
|
||||
|
||||
// Apple Mail commands.
|
||||
configureCmd := &ishell.Cmd{
|
||||
Name: "configure-apple-mail",
|
||||
Help: "Configures Apple Mail to use ProtonMail Bridge",
|
||||
Func: fe.configureAppleMail,
|
||||
}
|
||||
fe.AddCmd(configureCmd)
|
||||
|
||||
// TLS commands.
|
||||
exportTLSCmd := &ishell.Cmd{
|
||||
Name: "export-tls",
|
||||
Help: "Export the TLS certificate used by the Bridge",
|
||||
Func: fe.exportTLSCerts,
|
||||
}
|
||||
fe.AddCmd(exportTLSCmd)
|
||||
|
||||
// All mail visibility commands.
|
||||
allMailCmd := &ishell.Cmd{
|
||||
Name: "all-mail-visibility",
|
||||
@ -147,28 +158,6 @@ func New( //nolint:funlen
|
||||
})
|
||||
fe.AddCmd(allMailCmd)
|
||||
|
||||
// Cache-On-Disk commands.
|
||||
codCmd := &ishell.Cmd{
|
||||
Name: "local-cache",
|
||||
Help: "manage the local encrypted message cache",
|
||||
}
|
||||
codCmd.AddCmd(&ishell.Cmd{
|
||||
Name: "enable",
|
||||
Help: "enable the local cache",
|
||||
Func: fe.enableCacheOnDisk,
|
||||
})
|
||||
codCmd.AddCmd(&ishell.Cmd{
|
||||
Name: "disable",
|
||||
Help: "disable the local cache",
|
||||
Func: fe.disableCacheOnDisk,
|
||||
})
|
||||
codCmd.AddCmd(&ishell.Cmd{
|
||||
Name: "change-location",
|
||||
Help: "change the location of the local cache",
|
||||
Func: fe.setCacheOnDiskLocation,
|
||||
})
|
||||
fe.AddCmd(codCmd)
|
||||
|
||||
// Updates commands.
|
||||
updatesCmd := &ishell.Cmd{
|
||||
Name: "updates",
|
||||
@ -224,7 +213,6 @@ func New( //nolint:funlen
|
||||
Aliases: []string{"man"},
|
||||
Func: fe.printManual,
|
||||
})
|
||||
|
||||
fe.AddCmd(&ishell.Cmd{
|
||||
Name: "credits",
|
||||
Help: "print used resources.",
|
||||
@ -267,55 +255,122 @@ func New( //nolint:funlen
|
||||
Completer: fe.completeUsernames,
|
||||
})
|
||||
|
||||
// System commands.
|
||||
fe.AddCmd(&ishell.Cmd{
|
||||
Name: "restart",
|
||||
Help: "restart the bridge.",
|
||||
Func: fe.restart,
|
||||
})
|
||||
go fe.watchEvents()
|
||||
|
||||
go func() {
|
||||
defer panicHandler.HandlePanic()
|
||||
fe.watchEvents()
|
||||
}()
|
||||
return fe
|
||||
}
|
||||
|
||||
func (f *frontendCLI) watchEvents() {
|
||||
errorCh := f.eventListener.ProvideChannel(events.ErrorEvent)
|
||||
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
|
||||
internetConnChangedCh := f.eventListener.ProvideChannel(events.InternetConnChangedEvent)
|
||||
addressChangedCh := f.eventListener.ProvideChannel(events.AddressChangedEvent)
|
||||
addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
|
||||
logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
|
||||
certIssue := f.eventListener.ProvideChannel(events.TLSCertIssue)
|
||||
for {
|
||||
select {
|
||||
case errorDetails := <-errorCh:
|
||||
f.Println("Bridge failed:", errorDetails)
|
||||
case <-credentialsErrorCh:
|
||||
eventCh, done := f.bridge.GetEvents()
|
||||
defer done()
|
||||
|
||||
// TODO: Better error events.
|
||||
for _, err := range f.bridge.GetErrors() {
|
||||
switch {
|
||||
case errors.Is(err, vault.ErrCorrupt):
|
||||
f.notifyCredentialsError()
|
||||
case stat := <-internetConnChangedCh:
|
||||
if stat == events.InternetOff {
|
||||
|
||||
case errors.Is(err, vault.ErrInsecure):
|
||||
f.notifyCredentialsError()
|
||||
|
||||
case errors.Is(err, bridge.ErrServeIMAP):
|
||||
f.Println("IMAP server error:", err)
|
||||
|
||||
case errors.Is(err, bridge.ErrServeSMTP):
|
||||
f.Println("SMTP server error:", err)
|
||||
}
|
||||
}
|
||||
|
||||
for event := range eventCh {
|
||||
switch event := event.(type) {
|
||||
case events.ConnStatus:
|
||||
switch event.Status {
|
||||
case liteapi.StatusUp:
|
||||
f.notifyInternetOn()
|
||||
|
||||
case liteapi.StatusDown:
|
||||
f.notifyInternetOff()
|
||||
}
|
||||
if stat == events.InternetOn {
|
||||
f.notifyInternetOn()
|
||||
}
|
||||
case address := <-addressChangedCh:
|
||||
f.Printf("Address changed for %s. You may need to reconfigure your email client.", address)
|
||||
case address := <-addressChangedLogoutCh:
|
||||
f.notifyLogout(address)
|
||||
case userID := <-logoutCh:
|
||||
user, err := f.bridge.GetUserInfo(userID)
|
||||
|
||||
case events.UserDeauth:
|
||||
user, err := f.bridge.GetUserInfo(event.UserID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
f.notifyLogout(user.Username)
|
||||
case <-certIssue:
|
||||
|
||||
case events.UserAddressChanged:
|
||||
user, err := f.bridge.GetUserInfo(event.UserID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
f.Printf("Address changed for %s. You may need to reconfigure your email client.\n", user.Username)
|
||||
|
||||
case events.UserAddressDeleted:
|
||||
f.notifyLogout(event.Address)
|
||||
|
||||
case events.SyncStarted:
|
||||
user, err := f.bridge.GetUserInfo(event.UserID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
f.Printf("A sync has begun for %s.\n", user.Username)
|
||||
|
||||
case events.SyncFinished:
|
||||
user, err := f.bridge.GetUserInfo(event.UserID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
f.Printf("A sync has finished for %s.\n", user.Username)
|
||||
|
||||
case events.SyncProgress:
|
||||
user, err := f.bridge.GetUserInfo(event.UserID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
f.Printf(
|
||||
"Sync (%v): %.1f%% (Elapsed: %0.1fs, ETA: %0.1fs)\n",
|
||||
user.Username,
|
||||
100*event.Progress,
|
||||
event.Elapsed.Seconds(),
|
||||
event.Remaining.Seconds(),
|
||||
)
|
||||
|
||||
case events.UpdateAvailable:
|
||||
f.Printf("An update is available (version %v)\n", event.Version.Version)
|
||||
|
||||
case events.UpdateForced:
|
||||
f.notifyNeedUpgrade()
|
||||
|
||||
case events.TLSIssue:
|
||||
f.notifyCertIssue()
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
errorCh := f.eventListener.ProvideChannel(events.ErrorEvent)
|
||||
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
|
||||
for {
|
||||
select {
|
||||
case errorDetails := <-errorCh:
|
||||
f.Println("Bridge failed:", errorDetails)
|
||||
case <-credentialsErrorCh:
|
||||
f.notifyCredentialsError()
|
||||
case stat := <-internetConnChangedCh:
|
||||
if stat == events.InternetOff {
|
||||
f.notifyInternetOff()
|
||||
}
|
||||
if stat == events.InternetOn {
|
||||
f.notifyInternetOn()
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
// Loop starts the frontend loop with an interactive shell.
|
||||
@ -340,12 +395,3 @@ func (f *frontendCLI) Loop() error {
|
||||
f.Run()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *frontendCLI) NotifyManualUpdate(update updater.VersionInfo, canInstall bool) {
|
||||
// NOTE: Save the update somewhere so that it can be installed when user chooses "install now".
|
||||
}
|
||||
|
||||
func (f *frontendCLI) WaitUntilFrontendIsReady() {}
|
||||
func (f *frontendCLI) SetVersion(version updater.VersionInfo) {}
|
||||
func (f *frontendCLI) NotifySilentUpdateInstalled() {}
|
||||
func (f *frontendCLI) NotifySilentUpdateError(err error) {}
|
||||
|
||||
@ -18,28 +18,21 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/ports"
|
||||
"github.com/abiosoft/ishell"
|
||||
)
|
||||
|
||||
var currentPort = "" //nolint:gochecknoglobals
|
||||
|
||||
func (f *frontendCLI) restart(c *ishell.Context) {
|
||||
if f.yesNoQuestion("Are you sure you want to restart the Bridge") {
|
||||
f.Println("Restarting Bridge...")
|
||||
f.restarter.SetToRestart()
|
||||
f.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) printLogDir(c *ishell.Context) {
|
||||
if path, err := f.bridge.ProvideLogsPath(); err != nil {
|
||||
if path, err := f.bridge.GetLogsPath(); err != nil {
|
||||
f.Println("Failed to determine location of log files")
|
||||
} else {
|
||||
f.Println("Log files are stored in\n\n ", path)
|
||||
@ -50,79 +43,91 @@ func (f *frontendCLI) printManual(c *ishell.Context) {
|
||||
f.Println("More instructions about the Bridge can be found at\n\n https://protonmail.com/bridge")
|
||||
}
|
||||
|
||||
func (f *frontendCLI) deleteCache(c *ishell.Context) {
|
||||
func (f *frontendCLI) printCredits(c *ishell.Context) {
|
||||
for _, pkg := range strings.Split(bridge.Credits, ";") {
|
||||
f.Println(pkg)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) changeIMAPSecurity(c *ishell.Context) {
|
||||
f.ShowPrompt(false)
|
||||
defer f.ShowPrompt(true)
|
||||
|
||||
if !f.yesNoQuestion("Do you really want to remove all stored preferences") {
|
||||
return
|
||||
newSecurity := "SSL"
|
||||
if f.bridge.GetIMAPSSL() {
|
||||
newSecurity = "STARTTLS"
|
||||
}
|
||||
|
||||
if err := f.bridge.ClearData(); err != nil {
|
||||
f.printAndLogError("Cache clear failed: ", err.Error())
|
||||
return
|
||||
msg := fmt.Sprintf("Are you sure you want to change IMAP setting to %q", newSecurity)
|
||||
|
||||
if f.yesNoQuestion(msg) {
|
||||
if err := f.bridge.SetIMAPSSL(!f.bridge.GetIMAPSSL()); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
f.Println("Cached cleared, restarting bridge")
|
||||
|
||||
// Clearing data removes everything (db, preferences, ...) so everything has to be stopped and started again.
|
||||
f.restarter.SetToRestart()
|
||||
|
||||
f.Stop()
|
||||
}
|
||||
|
||||
func (f *frontendCLI) changeSMTPSecurity(c *ishell.Context) {
|
||||
f.ShowPrompt(false)
|
||||
defer f.ShowPrompt(true)
|
||||
|
||||
isSSL := f.bridge.GetBool(settings.SMTPSSLKey)
|
||||
newSecurity := "SSL"
|
||||
if isSSL {
|
||||
if f.bridge.GetSMTPSSL() {
|
||||
newSecurity = "STARTTLS"
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Are you sure you want to change SMTP setting to %q and restart the Bridge", newSecurity)
|
||||
msg := fmt.Sprintf("Are you sure you want to change SMTP setting to %q", newSecurity)
|
||||
|
||||
if f.yesNoQuestion(msg) {
|
||||
f.bridge.SetBool(settings.SMTPSSLKey, !isSSL)
|
||||
f.Println("Restarting Bridge...")
|
||||
f.restarter.SetToRestart()
|
||||
f.Stop()
|
||||
if err := f.bridge.SetSMTPSSL(!f.bridge.GetSMTPSSL()); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) changePort(c *ishell.Context) {
|
||||
func (f *frontendCLI) changeIMAPPort(c *ishell.Context) {
|
||||
f.ShowPrompt(false)
|
||||
defer f.ShowPrompt(true)
|
||||
|
||||
currentPort = f.bridge.Get(settings.IMAPPortKey)
|
||||
newIMAPPort := f.readStringInAttempts("Set IMAP port (current "+currentPort+")", c.ReadLine, f.isPortFree)
|
||||
newIMAPPort := f.readStringInAttempts(fmt.Sprintf("Set IMAP port (current %v)", f.bridge.GetIMAPPort()), c.ReadLine, f.isPortFree)
|
||||
if newIMAPPort == "" {
|
||||
newIMAPPort = currentPort
|
||||
}
|
||||
imapPortChanged := newIMAPPort != currentPort
|
||||
|
||||
currentPort = f.bridge.Get(settings.SMTPPortKey)
|
||||
newSMTPPort := f.readStringInAttempts("Set SMTP port (current "+currentPort+")", c.ReadLine, f.isPortFree)
|
||||
if newSMTPPort == "" {
|
||||
newSMTPPort = currentPort
|
||||
}
|
||||
smtpPortChanged := newSMTPPort != currentPort
|
||||
|
||||
if newIMAPPort == newSMTPPort {
|
||||
f.Println("SMTP and IMAP ports must be different!")
|
||||
f.printAndLogError(errors.New("failed to get new port"))
|
||||
return
|
||||
}
|
||||
|
||||
if imapPortChanged || smtpPortChanged {
|
||||
f.Println("Saving values IMAP:", newIMAPPort, "SMTP:", newSMTPPort)
|
||||
f.bridge.Set(settings.IMAPPortKey, newIMAPPort)
|
||||
f.bridge.Set(settings.SMTPPortKey, newSMTPPort)
|
||||
f.Println("Restarting Bridge...")
|
||||
f.restarter.SetToRestart()
|
||||
f.Stop()
|
||||
} else {
|
||||
f.Println("Nothing changed")
|
||||
newIMAPPortInt, err := strconv.Atoi(newIMAPPort)
|
||||
if err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := f.bridge.SetIMAPPort(newIMAPPortInt); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) changeSMTPPort(c *ishell.Context) {
|
||||
f.ShowPrompt(false)
|
||||
defer f.ShowPrompt(true)
|
||||
|
||||
newSMTPPort := f.readStringInAttempts(fmt.Sprintf("Set SMTP port (current %v)", f.bridge.GetSMTPPort()), c.ReadLine, f.isPortFree)
|
||||
if newSMTPPort == "" {
|
||||
f.printAndLogError(errors.New("failed to get new port"))
|
||||
return
|
||||
}
|
||||
|
||||
newSMTPPortInt, err := strconv.Atoi(newSMTPPort)
|
||||
if err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := f.bridge.SetSMTPPort(newSMTPPortInt); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@ -135,7 +140,10 @@ func (f *frontendCLI) allowProxy(c *ishell.Context) {
|
||||
f.Println("Bridge is currently set to NOT use alternative routing to connect to Proton if it is being blocked.")
|
||||
|
||||
if f.yesNoQuestion("Are you sure you want to allow bridge to do this") {
|
||||
f.bridge.SetProxyAllowed(true)
|
||||
if err := f.bridge.SetProxyAllowed(true); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -148,12 +156,15 @@ func (f *frontendCLI) disallowProxy(c *ishell.Context) {
|
||||
f.Println("Bridge is currently set to use alternative routing to connect to Proton if it is being blocked.")
|
||||
|
||||
if f.yesNoQuestion("Are you sure you want to stop bridge from doing this") {
|
||||
f.bridge.SetProxyAllowed(false)
|
||||
if err := f.bridge.SetProxyAllowed(false); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) hideAllMail(c *ishell.Context) {
|
||||
if !f.bridge.IsAllMailVisible() {
|
||||
if !f.bridge.GetShowAllMail() {
|
||||
f.Println("All Mail folder is not listed in your local client.")
|
||||
return
|
||||
}
|
||||
@ -161,12 +172,15 @@ func (f *frontendCLI) hideAllMail(c *ishell.Context) {
|
||||
f.Println("All Mail folder is listed in your client right now.")
|
||||
|
||||
if f.yesNoQuestion("Do you want to hide All Mail folder") {
|
||||
f.bridge.SetIsAllMailVisible(false)
|
||||
if err := f.bridge.SetShowAllMail(false); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) showAllMail(c *ishell.Context) {
|
||||
if f.bridge.IsAllMailVisible() {
|
||||
if f.bridge.GetShowAllMail() {
|
||||
f.Println("All Mail folder is listed in your local client.")
|
||||
return
|
||||
}
|
||||
@ -174,68 +188,47 @@ func (f *frontendCLI) showAllMail(c *ishell.Context) {
|
||||
f.Println("All Mail folder is not listed in your client right now.")
|
||||
|
||||
if f.yesNoQuestion("Do you want to show All Mail folder") {
|
||||
f.bridge.SetIsAllMailVisible(true)
|
||||
if err := f.bridge.SetShowAllMail(true); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) enableCacheOnDisk(c *ishell.Context) {
|
||||
if f.bridge.GetBool(settings.CacheEnabledKey) {
|
||||
f.Println("The local cache is already enabled.")
|
||||
return
|
||||
func (f *frontendCLI) setGluonLocation(c *ishell.Context) {
|
||||
if gluonDir := f.bridge.GetGluonDir(); gluonDir != "" {
|
||||
f.Println("The current message cache location is:", gluonDir)
|
||||
}
|
||||
|
||||
if f.yesNoQuestion("Are you sure you want to enable the local cache") {
|
||||
if err := f.bridge.EnableCache(); err != nil {
|
||||
f.Println("The local cache could not be enabled.")
|
||||
if location := f.readStringInAttempts("Enter a new location for the message cache", c.ReadLine, f.isCacheLocationUsable); location != "" {
|
||||
if err := f.bridge.SetGluonDir(context.Background(), location); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
|
||||
f.restarter.SetToRestart()
|
||||
f.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) disableCacheOnDisk(c *ishell.Context) {
|
||||
if !f.bridge.GetBool(settings.CacheEnabledKey) {
|
||||
f.Println("The local cache is already disabled.")
|
||||
return
|
||||
}
|
||||
func (f *frontendCLI) exportTLSCerts(c *ishell.Context) {
|
||||
if location := f.readStringInAttempts("Enter a path to which to export the TLS certificate used for IMAP and SMTP", c.ReadLine, f.isCacheLocationUsable); location != "" {
|
||||
cert, key := f.bridge.GetBridgeTLSCert()
|
||||
|
||||
if f.yesNoQuestion("Are you sure you want to disable the local cache") {
|
||||
if err := f.bridge.DisableCache(); err != nil {
|
||||
f.Println("The local cache could not be disabled.")
|
||||
if err := os.WriteFile(filepath.Join(location, "cert.pem"), cert, 0600); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
|
||||
f.restarter.SetToRestart()
|
||||
f.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) setCacheOnDiskLocation(c *ishell.Context) {
|
||||
if !f.bridge.GetBool(settings.CacheEnabledKey) {
|
||||
f.Println("The local cache must be enabled.")
|
||||
return
|
||||
}
|
||||
|
||||
if location := f.bridge.Get(settings.CacheLocationKey); location != "" {
|
||||
f.Println("The current local cache location is:", location)
|
||||
}
|
||||
|
||||
if location := f.readStringInAttempts("Enter a new location for the cache", c.ReadLine, f.isCacheLocationUsable); location != "" {
|
||||
if err := f.bridge.MigrateCache(f.bridge.Get(settings.CacheLocationKey), location); err != nil {
|
||||
f.Println("The local cache location could not be changed.")
|
||||
if err := os.WriteFile(filepath.Join(location, "key.pem"), key, 0600); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
|
||||
f.restarter.SetToRestart()
|
||||
f.Stop()
|
||||
f.Println("TLS certificate exported to", location)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) isPortFree(port string) bool {
|
||||
port = strings.ReplaceAll(port, ":", "")
|
||||
if port == "" || port == currentPort {
|
||||
if port == "" {
|
||||
return true
|
||||
}
|
||||
number, err := strconv.Atoi(port)
|
||||
|
||||
@ -18,36 +18,16 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
"github.com/abiosoft/ishell"
|
||||
)
|
||||
|
||||
func (f *frontendCLI) checkUpdates(c *ishell.Context) {
|
||||
version, err := f.updater.Check()
|
||||
if err != nil {
|
||||
f.Println("An error occurred while checking for updates.")
|
||||
return
|
||||
}
|
||||
|
||||
if f.updater.IsUpdateApplicable(version) {
|
||||
f.Println("An update is available.")
|
||||
} else {
|
||||
f.Println("Your version is up to date.")
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) printCredits(c *ishell.Context) {
|
||||
for _, pkg := range strings.Split(bridge.Credits, ";") {
|
||||
f.Println(pkg)
|
||||
}
|
||||
f.bridge.CheckForUpdates()
|
||||
}
|
||||
|
||||
func (f *frontendCLI) enableAutoUpdates(c *ishell.Context) {
|
||||
if f.bridge.GetBool(settings.AutoUpdateKey) {
|
||||
if f.bridge.GetAutoUpdate() {
|
||||
f.Println("Bridge is already set to automatically install updates.")
|
||||
return
|
||||
}
|
||||
@ -55,12 +35,15 @@ func (f *frontendCLI) enableAutoUpdates(c *ishell.Context) {
|
||||
f.Println("Bridge is currently set to NOT automatically install updates.")
|
||||
|
||||
if f.yesNoQuestion("Are you sure you want to allow bridge to do this") {
|
||||
f.bridge.SetBool(settings.AutoUpdateKey, true)
|
||||
if err := f.bridge.SetAutoUpdate(true); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *frontendCLI) disableAutoUpdates(c *ishell.Context) {
|
||||
if !f.bridge.GetBool(settings.AutoUpdateKey) {
|
||||
if !f.bridge.GetAutoUpdate() {
|
||||
f.Println("Bridge is already set to NOT automatically install updates.")
|
||||
return
|
||||
}
|
||||
@ -68,7 +51,10 @@ func (f *frontendCLI) disableAutoUpdates(c *ishell.Context) {
|
||||
f.Println("Bridge is currently set to automatically install updates.")
|
||||
|
||||
if f.yesNoQuestion("Are you sure you want to stop bridge from doing this") {
|
||||
f.bridge.SetBool(settings.AutoUpdateKey, false)
|
||||
if err := f.bridge.SetAutoUpdate(false); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -81,7 +67,10 @@ func (f *frontendCLI) selectEarlyChannel(c *ishell.Context) {
|
||||
f.Println("Bridge is currently on the stable update channel.")
|
||||
|
||||
if f.yesNoQuestion("Are you sure you want to switch to the early-access update channel") {
|
||||
f.bridge.SetUpdateChannel(updater.EarlyChannel)
|
||||
if err := f.bridge.SetUpdateChannel(updater.EarlyChannel); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -95,6 +84,9 @@ func (f *frontendCLI) selectStableChannel(c *ishell.Context) {
|
||||
f.Println("Switching to the stable channel may reset all data!")
|
||||
|
||||
if f.yesNoQuestion("Are you sure you want to switch to the stable update channel") {
|
||||
f.bridge.SetUpdateChannel(updater.StableChannel)
|
||||
if err := f.bridge.SetUpdateChannel(updater.StableChannel); err != nil {
|
||||
f.printAndLogError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -20,7 +20,6 @@ package cli
|
||||
import (
|
||||
"strings"
|
||||
|
||||
pmapi "github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
@ -67,15 +66,7 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
|
||||
}
|
||||
|
||||
func (f *frontendCLI) processAPIError(err error) {
|
||||
log.Warn("API error: ", err)
|
||||
switch err {
|
||||
case pmapi.ErrNoConnection:
|
||||
f.notifyInternetOff()
|
||||
case pmapi.ErrUpgradeApplication:
|
||||
f.notifyNeedUpgrade()
|
||||
default:
|
||||
f.Println("Server error:", err.Error())
|
||||
}
|
||||
f.printAndLogError(err)
|
||||
}
|
||||
|
||||
func (f *frontendCLI) notifyInternetOff() {
|
||||
@ -91,12 +82,7 @@ func (f *frontendCLI) notifyLogout(address string) {
|
||||
}
|
||||
|
||||
func (f *frontendCLI) notifyNeedUpgrade() {
|
||||
version, err := f.updater.Check()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to notify need upgrade")
|
||||
return
|
||||
}
|
||||
f.Println("Please download and install the newest version of application from", version.LandingPage)
|
||||
f.Println("Please download and install the newest version of the application.")
|
||||
}
|
||||
|
||||
func (f *frontendCLI) notifyCredentialsError() {
|
||||
|
||||
@ -1,75 +0,0 @@
|
||||
// Copyright (c) 2022 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Package frontend provides all interfaces of the Bridge.
|
||||
package frontend
|
||||
|
||||
import (
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/cli"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/grpc"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/types"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
|
||||
)
|
||||
|
||||
type Frontend interface {
|
||||
Loop() error
|
||||
NotifyManualUpdate(update updater.VersionInfo, canInstall bool)
|
||||
SetVersion(update updater.VersionInfo)
|
||||
NotifySilentUpdateInstalled()
|
||||
NotifySilentUpdateError(error)
|
||||
WaitUntilFrontendIsReady()
|
||||
}
|
||||
|
||||
// New returns initialized frontend based on `frontendType`, which can be `cli` or `grpc`.
|
||||
func New(
|
||||
frontendType string,
|
||||
showWindowOnStart bool,
|
||||
panicHandler types.PanicHandler,
|
||||
eventListener listener.Listener,
|
||||
updater types.Updater,
|
||||
bridge *bridge.Bridge,
|
||||
restarter types.Restarter,
|
||||
locations *locations.Locations,
|
||||
) Frontend {
|
||||
switch frontendType {
|
||||
case "grpc":
|
||||
return grpc.NewService(
|
||||
showWindowOnStart,
|
||||
panicHandler,
|
||||
eventListener,
|
||||
updater,
|
||||
bridge,
|
||||
restarter,
|
||||
locations,
|
||||
)
|
||||
|
||||
case "cli":
|
||||
return cli.New(
|
||||
panicHandler,
|
||||
eventListener,
|
||||
updater,
|
||||
bridge,
|
||||
restarter,
|
||||
)
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -72,7 +72,6 @@ service Bridge {
|
||||
rpc IsAutomaticUpdateOn(google.protobuf.Empty) returns (google.protobuf.BoolValue);
|
||||
|
||||
// cache
|
||||
rpc IsCacheOnDiskEnabled (google.protobuf.Empty) returns (google.protobuf.BoolValue);
|
||||
rpc DiskCachePath(google.protobuf.Empty) returns (google.protobuf.StringValue);
|
||||
rpc ChangeLocalCache(ChangeLocalCacheRequest) returns (google.protobuf.Empty);
|
||||
|
||||
@ -160,7 +159,6 @@ message LoginAbortRequest {
|
||||
// Cache on disk related message
|
||||
//**********************************************************
|
||||
message ChangeLocalCacheRequest {
|
||||
bool enableDiskCache = 1;
|
||||
string diskCachePath = 2;
|
||||
}
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.2.0
|
||||
// - protoc v3.21.3
|
||||
// - protoc v3.21.7
|
||||
// source: bridge.proto
|
||||
|
||||
package grpc
|
||||
@ -64,7 +64,6 @@ type BridgeClient interface {
|
||||
SetIsAutomaticUpdateOn(ctx context.Context, in *wrapperspb.BoolValue, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||
IsAutomaticUpdateOn(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*wrapperspb.BoolValue, error)
|
||||
// cache
|
||||
IsCacheOnDiskEnabled(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*wrapperspb.BoolValue, error)
|
||||
DiskCachePath(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*wrapperspb.StringValue, error)
|
||||
ChangeLocalCache(ctx context.Context, in *ChangeLocalCacheRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||
// mail
|
||||
@ -425,15 +424,6 @@ func (c *bridgeClient) IsAutomaticUpdateOn(ctx context.Context, in *emptypb.Empt
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *bridgeClient) IsCacheOnDiskEnabled(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*wrapperspb.BoolValue, error) {
|
||||
out := new(wrapperspb.BoolValue)
|
||||
err := c.cc.Invoke(ctx, "/grpc.Bridge/IsCacheOnDiskEnabled", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *bridgeClient) DiskCachePath(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*wrapperspb.StringValue, error) {
|
||||
out := new(wrapperspb.StringValue)
|
||||
err := c.cc.Invoke(ctx, "/grpc.Bridge/DiskCachePath", in, out, opts...)
|
||||
@ -699,7 +689,6 @@ type BridgeServer interface {
|
||||
SetIsAutomaticUpdateOn(context.Context, *wrapperspb.BoolValue) (*emptypb.Empty, error)
|
||||
IsAutomaticUpdateOn(context.Context, *emptypb.Empty) (*wrapperspb.BoolValue, error)
|
||||
// cache
|
||||
IsCacheOnDiskEnabled(context.Context, *emptypb.Empty) (*wrapperspb.BoolValue, error)
|
||||
DiskCachePath(context.Context, *emptypb.Empty) (*wrapperspb.StringValue, error)
|
||||
ChangeLocalCache(context.Context, *ChangeLocalCacheRequest) (*emptypb.Empty, error)
|
||||
// mail
|
||||
@ -841,9 +830,6 @@ func (UnimplementedBridgeServer) SetIsAutomaticUpdateOn(context.Context, *wrappe
|
||||
func (UnimplementedBridgeServer) IsAutomaticUpdateOn(context.Context, *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method IsAutomaticUpdateOn not implemented")
|
||||
}
|
||||
func (UnimplementedBridgeServer) IsCacheOnDiskEnabled(context.Context, *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method IsCacheOnDiskEnabled not implemented")
|
||||
}
|
||||
func (UnimplementedBridgeServer) DiskCachePath(context.Context, *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method DiskCachePath not implemented")
|
||||
}
|
||||
@ -1571,24 +1557,6 @@ func _Bridge_IsAutomaticUpdateOn_Handler(srv interface{}, ctx context.Context, d
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Bridge_IsCacheOnDiskEnabled_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(emptypb.Empty)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(BridgeServer).IsCacheOnDiskEnabled(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/grpc.Bridge/IsCacheOnDiskEnabled",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(BridgeServer).IsCacheOnDiskEnabled(ctx, req.(*emptypb.Empty))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Bridge_DiskCachePath_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(emptypb.Empty)
|
||||
if err := dec(in); err != nil {
|
||||
@ -2139,10 +2107,6 @@ var Bridge_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "IsAutomaticUpdateOn",
|
||||
Handler: _Bridge_IsAutomaticUpdateOn_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "IsCacheOnDiskEnabled",
|
||||
Handler: _Bridge_IsCacheOnDiskEnabled_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "DiskCachePath",
|
||||
Handler: _Bridge_DiskCachePath_Handler,
|
||||
|
||||
@ -15,25 +15,18 @@
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package serverutil
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
import "github.com/bradenaw/juniper/xslices"
|
||||
|
||||
// ServerErrorLogger implements go-imap/logger interface.
|
||||
type ServerErrorLogger struct {
|
||||
l *logrus.Entry
|
||||
// isInternetStatus returns true iff the event is InternetStatus.
|
||||
func (x *StreamEvent) isInternetStatus() bool {
|
||||
appEvent := x.GetApp()
|
||||
|
||||
return (appEvent != nil) && (appEvent.GetInternetStatus() != nil)
|
||||
}
|
||||
|
||||
func NewServerErrorLogger(protocol Protocol) *ServerErrorLogger {
|
||||
return &ServerErrorLogger{l: logrus.WithField("protocol", protocol)}
|
||||
}
|
||||
|
||||
func (s *ServerErrorLogger) Printf(format string, args ...interface{}) {
|
||||
s.l.Errorf(format, args...)
|
||||
}
|
||||
|
||||
func (s *ServerErrorLogger) Println(args ...interface{}) {
|
||||
s.l.Errorln(args...)
|
||||
// filterOutInternetStatusEvents return a copy of the events list where all internet connection events have been removed.
|
||||
func filterOutInternetStatusEvents(events []*StreamEvent) []*StreamEvent {
|
||||
return xslices.Filter(events, func(event *StreamEvent) bool { return !event.isInternetStatus() })
|
||||
}
|
||||
@ -21,34 +21,29 @@ package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
cryptotls "crypto/tls"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/tls"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/crash"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/types"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/locations"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/users"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/restarter"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -59,6 +54,7 @@ const (
|
||||
// Service is the RPC service struct.
|
||||
type Service struct { // nolint:structcheck
|
||||
UnimplementedBridgeServer
|
||||
|
||||
grpcServer *grpc.Server // the gGRPC server
|
||||
listener net.Listener
|
||||
eventStreamCh chan *StreamEvent
|
||||
@ -66,99 +62,87 @@ type Service struct { // nolint:structcheck
|
||||
eventQueue []*StreamEvent
|
||||
eventQueueMutex sync.Mutex
|
||||
|
||||
panicHandler types.PanicHandler
|
||||
eventListener listener.Listener
|
||||
updater types.Updater
|
||||
updateCheckMutex sync.Mutex
|
||||
bridge types.Bridger
|
||||
restarter types.Restarter
|
||||
showOnStartup bool
|
||||
authClient pmapi.Client
|
||||
auth *pmapi.Auth
|
||||
password []byte
|
||||
newVersionInfo updater.VersionInfo
|
||||
panicHandler *crash.Handler
|
||||
restarter *restarter.Restarter
|
||||
bridge *bridge.Bridge
|
||||
newVersionInfo updater.VersionInfo
|
||||
|
||||
log *logrus.Entry
|
||||
initializing sync.WaitGroup
|
||||
initializationDone sync.Once
|
||||
firstTimeAutostart sync.Once
|
||||
locations *locations.Locations
|
||||
token string
|
||||
pemCert string
|
||||
|
||||
showOnStartup bool
|
||||
}
|
||||
|
||||
// NewService returns a new instance of the service.
|
||||
func NewService(
|
||||
showOnStartup bool,
|
||||
panicHandler types.PanicHandler,
|
||||
eventListener listener.Listener,
|
||||
updater types.Updater,
|
||||
bridge types.Bridger,
|
||||
restarter types.Restarter,
|
||||
panicHandler *crash.Handler,
|
||||
restarter *restarter.Restarter,
|
||||
locations *locations.Locations,
|
||||
) *Service {
|
||||
s := Service{
|
||||
UnimplementedBridgeServer: UnimplementedBridgeServer{},
|
||||
panicHandler: panicHandler,
|
||||
eventListener: eventListener,
|
||||
updater: updater,
|
||||
bridge: bridge,
|
||||
restarter: restarter,
|
||||
showOnStartup: showOnStartup,
|
||||
bridge *bridge.Bridge,
|
||||
showOnStartup bool,
|
||||
) (*Service, error) {
|
||||
tlsConfig, certPEM, err := newTLSConfig()
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panic("Could not generate gRPC TLS config")
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0") // Port should be provided by the OS.
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panic("Could not create gRPC listener")
|
||||
}
|
||||
|
||||
token := uuid.NewString()
|
||||
|
||||
if path, err := saveGRPCServerConfigFile(locations, listener, token, certPEM); err != nil {
|
||||
logrus.WithError(err).WithField("path", path).Panic("Could not write gRPC service config file")
|
||||
} else {
|
||||
logrus.WithField("path", path).Info("Successfully saved gRPC service config file")
|
||||
}
|
||||
|
||||
s := &Service{
|
||||
grpcServer: grpc.NewServer(
|
||||
grpc.Creds(credentials.NewTLS(tlsConfig)),
|
||||
grpc.UnaryInterceptor(newUnaryTokenValidator(token)),
|
||||
grpc.StreamInterceptor(newStreamTokenValidator(token)),
|
||||
),
|
||||
listener: listener,
|
||||
|
||||
panicHandler: panicHandler,
|
||||
restarter: restarter,
|
||||
bridge: bridge,
|
||||
|
||||
log: logrus.WithField("pkg", "grpc"),
|
||||
initializing: sync.WaitGroup{},
|
||||
initializationDone: sync.Once{},
|
||||
firstTimeAutostart: sync.Once{},
|
||||
locations: locations,
|
||||
token: uuid.NewString(),
|
||||
|
||||
showOnStartup: showOnStartup,
|
||||
}
|
||||
|
||||
// Initializing.Done is only called sync.Once. Please keep the increment
|
||||
// set to 1
|
||||
// Initializing.Done is only called sync.Once. Please keep the increment set to 1
|
||||
s.initializing.Add(1)
|
||||
|
||||
tlsConfig, pemCert, err := s.generateTLSConfig()
|
||||
if err != nil {
|
||||
s.log.WithError(err).Panic("Could not generate gRPC TLS config")
|
||||
}
|
||||
|
||||
s.pemCert = string(pemCert)
|
||||
|
||||
// Initialize the autostart.
|
||||
s.initAutostart()
|
||||
s.grpcServer = grpc.NewServer(
|
||||
grpc.Creds(credentials.NewTLS(tlsConfig)),
|
||||
grpc.UnaryInterceptor(s.validateUnaryServerToken),
|
||||
grpc.StreamInterceptor(s.validateStreamServerToken),
|
||||
)
|
||||
|
||||
RegisterBridgeServer(s.grpcServer, &s)
|
||||
|
||||
s.listener, err = net.Listen("tcp", "127.0.0.1:0") // Port 0 means that the port is randomly picked by the system.
|
||||
if err != nil {
|
||||
s.log.WithError(err).Panic("Could not create gRPC listener")
|
||||
}
|
||||
|
||||
if path, err := s.saveGRPCServerConfigFile(); err != nil {
|
||||
s.log.WithError(err).WithField("path", path).Panic("Could not write gRPC service config file")
|
||||
} else {
|
||||
s.log.WithField("path", path).Info("Successfully saved gRPC service config file")
|
||||
}
|
||||
// Register the gRPC service implementation.
|
||||
RegisterBridgeServer(s.grpcServer, s)
|
||||
|
||||
s.log.Info("gRPC server listening on ", s.listener.Addr())
|
||||
|
||||
return &s
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// GODT-1507 Windows: autostart needs to be created after Qt is initialized.
|
||||
// GODT-1206: if preferences file says it should be on enable it here.
|
||||
// TO-DO GODT-1681 Autostart needs to be properly implement for gRPC approach.
|
||||
func (s *Service) initAutostart() {
|
||||
// GODT-1507 Windows: autostart needs to be created after Qt is initialized.
|
||||
// GODT-1206: if preferences file says it should be on enable it here.
|
||||
|
||||
// TO-DO GODT-1681 Autostart needs to be properly implement for gRPC approach.
|
||||
|
||||
s.firstTimeAutostart.Do(func() {
|
||||
shouldAutostartBeOn := s.bridge.GetBool(settings.AutostartKey)
|
||||
if s.bridge.IsFirstStart() || shouldAutostartBeOn {
|
||||
if err := s.bridge.EnableAutostart(); err != nil {
|
||||
shouldAutostartBeOn := s.bridge.GetAutostart()
|
||||
if s.bridge.GetFirstStart() || shouldAutostartBeOn {
|
||||
if err := s.bridge.SetAutostart(true); err != nil {
|
||||
s.log.WithField("prefs", shouldAutostartBeOn).WithError(err).Error("Failed to enable first autostart")
|
||||
}
|
||||
return
|
||||
@ -168,7 +152,7 @@ func (s *Service) initAutostart() {
|
||||
|
||||
func (s *Service) Loop() error {
|
||||
defer func() {
|
||||
s.bridge.SetBool(settings.FirstStartGUIKey, false)
|
||||
_ = s.bridge.SetFirstStartGUI(false)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
@ -179,7 +163,7 @@ func (s *Service) Loop() error {
|
||||
s.log.Info("Starting gRPC server")
|
||||
|
||||
if err := s.grpcServer.Serve(s.listener); err != nil {
|
||||
s.log.WithError(err).Error("Error serving gRPC")
|
||||
s.log.WithError(err).Error("Failed to serve gRPC")
|
||||
return err
|
||||
}
|
||||
|
||||
@ -212,140 +196,59 @@ func (s *Service) WaitUntilFrontendIsReady() {
|
||||
s.initializing.Wait()
|
||||
}
|
||||
|
||||
func (s *Service) watchEvents() { // nolint:funlen
|
||||
if s.bridge.HasError(bridge.ErrLocalCacheUnavailable) {
|
||||
_ = s.SendEvent(NewCacheErrorEvent(CacheErrorType_CACHE_UNAVAILABLE_ERROR))
|
||||
}
|
||||
func (s *Service) watchEvents() {
|
||||
eventCh, done := s.bridge.GetEvents()
|
||||
defer done()
|
||||
|
||||
errorCh := s.eventListener.ProvideChannel(events.ErrorEvent)
|
||||
credentialsErrorCh := s.eventListener.ProvideChannel(events.CredentialsErrorEvent)
|
||||
noActiveKeyForRecipientCh := s.eventListener.ProvideChannel(events.NoActiveKeyForRecipientEvent)
|
||||
internetConnChangedCh := s.eventListener.ProvideChannel(events.InternetConnChangedEvent)
|
||||
secondInstanceCh := s.eventListener.ProvideChannel(events.SecondInstanceEvent)
|
||||
restartBridgeCh := s.eventListener.ProvideChannel(events.RestartBridgeEvent)
|
||||
addressChangedCh := s.eventListener.ProvideChannel(events.AddressChangedEvent)
|
||||
addressChangedLogoutCh := s.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
|
||||
logoutCh := s.eventListener.ProvideChannel(events.LogoutEvent)
|
||||
updateApplicationCh := s.eventListener.ProvideChannel(events.UpgradeApplicationEvent)
|
||||
userChangedCh := s.eventListener.ProvideChannel(events.UserRefreshEvent)
|
||||
certIssue := s.eventListener.ProvideChannel(events.TLSCertIssue)
|
||||
|
||||
// we forward events to the GUI/frontend via the gRPC event stream.
|
||||
for {
|
||||
select {
|
||||
case errorDetails := <-errorCh:
|
||||
if strings.Contains(errorDetails, "IMAP failed") {
|
||||
_ = s.SendEvent(NewMailSettingsErrorEvent(MailSettingsErrorType_IMAP_PORT_ISSUE))
|
||||
}
|
||||
if strings.Contains(errorDetails, "SMTP failed") {
|
||||
_ = s.SendEvent(NewMailSettingsErrorEvent(MailSettingsErrorType_SMTP_PORT_ISSUE))
|
||||
}
|
||||
case reason := <-credentialsErrorCh:
|
||||
if reason == keychain.ErrMacKeychainRebuild.Error() {
|
||||
_ = s.SendEvent(NewKeychainRebuildKeychainEvent())
|
||||
continue
|
||||
}
|
||||
// TODO: Better error events.
|
||||
for _, err := range s.bridge.GetErrors() {
|
||||
switch {
|
||||
case errors.Is(err, vault.ErrCorrupt):
|
||||
_ = s.SendEvent(NewKeychainHasNoKeychainEvent())
|
||||
case email := <-noActiveKeyForRecipientCh:
|
||||
_ = s.SendEvent(NewMailNoActiveKeyForRecipientEvent(email))
|
||||
case stat := <-internetConnChangedCh:
|
||||
if stat == events.InternetOff {
|
||||
_ = s.SendEvent(NewInternetStatusEvent(false))
|
||||
}
|
||||
if stat == events.InternetOn {
|
||||
_ = s.SendEvent(NewInternetStatusEvent(true))
|
||||
}
|
||||
|
||||
case <-secondInstanceCh:
|
||||
_ = s.SendEvent(NewShowMainWindowEvent())
|
||||
case <-restartBridgeCh:
|
||||
_, _ = s.Restart(
|
||||
metadata.AppendToOutgoingContext(context.Background(), serverTokenMetadataKey, s.token),
|
||||
&emptypb.Empty{},
|
||||
)
|
||||
case address := <-addressChangedCh:
|
||||
_ = s.SendEvent(NewMailAddressChangeEvent(address))
|
||||
case address := <-addressChangedLogoutCh:
|
||||
_ = s.SendEvent(NewMailAddressChangeLogoutEvent(address))
|
||||
case userID := <-logoutCh:
|
||||
user, err := s.bridge.GetUserInfo(userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = s.SendEvent(NewUserDisconnectedEvent(user.Username))
|
||||
case <-updateApplicationCh:
|
||||
s.updateForce()
|
||||
case userID := <-userChangedCh:
|
||||
_ = s.SendEvent(NewUserChangedEvent(userID))
|
||||
case <-certIssue:
|
||||
_ = s.SendEvent(NewMailApiCertIssue())
|
||||
case errors.Is(err, vault.ErrInsecure):
|
||||
_ = s.SendEvent(NewKeychainHasNoKeychainEvent())
|
||||
|
||||
case errors.Is(err, bridge.ErrServeIMAP):
|
||||
_ = s.SendEvent(NewMailSettingsErrorEvent(MailSettingsErrorType_IMAP_PORT_ISSUE))
|
||||
|
||||
case errors.Is(err, bridge.ErrServeSMTP):
|
||||
_ = s.SendEvent(NewMailSettingsErrorEvent(MailSettingsErrorType_SMTP_PORT_ISSUE))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) loginAbort() {
|
||||
s.loginClean()
|
||||
}
|
||||
for event := range eventCh {
|
||||
switch event := event.(type) {
|
||||
case events.ConnStatus:
|
||||
_ = s.SendEvent(NewInternetStatusEvent(event.Status == liteapi.StatusUp))
|
||||
|
||||
func (s *Service) loginClean() {
|
||||
s.auth = nil
|
||||
s.authClient = nil
|
||||
for i := range s.password {
|
||||
s.password[i] = '\x00'
|
||||
}
|
||||
s.password = s.password[0:0]
|
||||
}
|
||||
case events.Raise:
|
||||
_ = s.SendEvent(NewShowMainWindowEvent())
|
||||
|
||||
func (s *Service) finishLogin() {
|
||||
defer s.loginClean()
|
||||
case events.UserAddressCreated:
|
||||
_ = s.SendEvent(NewMailAddressChangeEvent(event.Address))
|
||||
|
||||
if len(s.password) == 0 || s.auth == nil || s.authClient == nil {
|
||||
s.log.
|
||||
WithField("hasPass", len(s.password) != 0).
|
||||
WithField("hasAuth", s.auth != nil).
|
||||
WithField("hasClient", s.authClient != nil).
|
||||
Error("Finish login: authentication incomplete")
|
||||
case events.UserAddressChanged:
|
||||
_ = s.SendEvent(NewMailAddressChangeEvent(event.Address))
|
||||
|
||||
_ = s.SendEvent(NewLoginError(LoginErrorType_TWO_PASSWORDS_ABORT, "Missing authentication, try again."))
|
||||
return
|
||||
}
|
||||
case events.UserAddressDeleted:
|
||||
_ = s.SendEvent(NewMailAddressChangeLogoutEvent(event.Address))
|
||||
|
||||
done := make(chan string)
|
||||
s.eventListener.Add(events.UserChangeDone, done)
|
||||
defer s.eventListener.Remove(events.UserChangeDone, done)
|
||||
case events.UserChanged:
|
||||
_ = s.SendEvent(NewUserChangedEvent(event.UserID))
|
||||
|
||||
userID, err := s.bridge.FinishLogin(s.authClient, s.auth, s.password)
|
||||
|
||||
if err != nil && err != users.ErrUserAlreadyConnected {
|
||||
s.log.WithError(err).Errorf("Finish login failed")
|
||||
_ = s.SendEvent(NewLoginError(LoginErrorType_TWO_PASSWORDS_ABORT, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// The user changed should be triggered by FinishLogin, but it is not
|
||||
// guaranteed when this is going to happen. Therefor we should wait
|
||||
// until we receive the signal from userChanged function.
|
||||
s.waitForUserChangeDone(done, userID)
|
||||
|
||||
s.log.WithField("userID", userID).Debug("Login finished")
|
||||
_ = s.SendEvent(NewLoginFinishedEvent(userID))
|
||||
|
||||
if err == users.ErrUserAlreadyConnected {
|
||||
s.log.WithError(err).Error("User already logged in")
|
||||
_ = s.SendEvent(NewLoginAlreadyLoggedInEvent(userID))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) waitForUserChangeDone(done <-chan string, userID string) {
|
||||
for {
|
||||
select {
|
||||
case changedID := <-done:
|
||||
if changedID == userID {
|
||||
return
|
||||
case events.UserDeauth:
|
||||
if user, err := s.bridge.GetUserInfo(event.UserID); err != nil {
|
||||
s.log.WithError(err).Error("Failed to get user info")
|
||||
} else {
|
||||
_ = s.SendEvent(NewUserDisconnectedEvent(user.Username))
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
s.log.WithField("ID", userID).Warning("Login finished but user not added within 2 seconds")
|
||||
return
|
||||
|
||||
case events.TLSIssue:
|
||||
_ = s.SendEvent(NewMailApiCertIssue())
|
||||
|
||||
case events.UpdateForced:
|
||||
panic("TODO")
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -354,103 +257,46 @@ func (s *Service) triggerReset() {
|
||||
defer func() {
|
||||
_ = s.SendEvent(NewResetFinishedEvent())
|
||||
}()
|
||||
s.bridge.FactoryReset()
|
||||
if err := s.bridge.FactoryReset(context.Background()); err != nil {
|
||||
s.log.WithError(err).Error("Failed to reset")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) checkUpdate() {
|
||||
version, err := s.updater.Check()
|
||||
func newTLSConfig() (*tls.Config, []byte, error) {
|
||||
template, err := certs.NewTLSTemplate()
|
||||
if err != nil {
|
||||
s.log.WithError(err).Error("An error occurred while checking for updates")
|
||||
s.SetVersion(updater.VersionInfo{})
|
||||
return
|
||||
}
|
||||
s.SetVersion(version)
|
||||
}
|
||||
|
||||
func (s *Service) updateForce() {
|
||||
s.updateCheckMutex.Lock()
|
||||
defer s.updateCheckMutex.Unlock()
|
||||
s.checkUpdate()
|
||||
_ = s.SendEvent(NewUpdateForceEvent(s.newVersionInfo.Version.String()))
|
||||
}
|
||||
|
||||
func (s *Service) checkUpdateAndNotify(isReqFromUser bool) {
|
||||
s.updateCheckMutex.Lock()
|
||||
defer func() {
|
||||
s.updateCheckMutex.Unlock()
|
||||
_ = s.SendEvent(NewUpdateCheckFinishedEvent())
|
||||
}()
|
||||
|
||||
s.checkUpdate()
|
||||
version := s.newVersionInfo
|
||||
if (version.Version == nil) || (version.Version.String() == "") {
|
||||
if isReqFromUser {
|
||||
_ = s.SendEvent(NewUpdateErrorEvent(UpdateErrorType_UPDATE_MANUAL_ERROR))
|
||||
}
|
||||
return
|
||||
}
|
||||
if !s.updater.IsUpdateApplicable(s.newVersionInfo) {
|
||||
s.log.Info("No need to update")
|
||||
if isReqFromUser {
|
||||
_ = s.SendEvent(NewUpdateIsLatestVersionEvent())
|
||||
}
|
||||
} else if isReqFromUser {
|
||||
s.NotifyManualUpdate(s.newVersionInfo, s.updater.CanInstall(s.newVersionInfo))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) installUpdate() {
|
||||
s.updateCheckMutex.Lock()
|
||||
defer s.updateCheckMutex.Unlock()
|
||||
|
||||
if !s.updater.CanInstall(s.newVersionInfo) {
|
||||
s.log.Warning("Skipping update installation, current version too old")
|
||||
_ = s.SendEvent(NewUpdateErrorEvent(UpdateErrorType_UPDATE_MANUAL_ERROR))
|
||||
return
|
||||
return nil, nil, fmt.Errorf("failed to create TLS template: %w", err)
|
||||
}
|
||||
|
||||
if err := s.updater.InstallUpdate(s.newVersionInfo); err != nil {
|
||||
if errors.Cause(err) == updater.ErrDownloadVerify {
|
||||
s.log.WithError(err).Warning("Skipping update installation due to temporary error")
|
||||
} else {
|
||||
s.log.WithError(err).Error("The update couldn't be installed")
|
||||
_ = s.SendEvent(NewUpdateErrorEvent(UpdateErrorType_UPDATE_MANUAL_ERROR))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
_ = s.SendEvent(NewUpdateSilentRestartNeededEvent())
|
||||
}
|
||||
|
||||
func (s *Service) generateTLSConfig() (tlsConfig *cryptotls.Config, pemCert []byte, err error) {
|
||||
pemCert, pemKey, err := tls.NewPEMKeyPair()
|
||||
certPEM, keyPEM, err := certs.GenerateCert(template)
|
||||
if err != nil {
|
||||
return nil, nil, errors.New("Could not get TLS config")
|
||||
return nil, nil, fmt.Errorf("failed to generate cert: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig, err = tls.GetConfigFromPEMKeyPair(pemCert, pemKey)
|
||||
cert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
return nil, nil, errors.New("Could not get TLS config")
|
||||
return nil, nil, fmt.Errorf("failed to load cert: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig.ClientAuth = cryptotls.NoClientCert // skip client auth if the certificate allow it.
|
||||
|
||||
return tlsConfig, pemCert, nil
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
ClientAuth: tls.NoClientCert,
|
||||
}, certPEM, nil
|
||||
}
|
||||
|
||||
func (s *Service) saveGRPCServerConfigFile() (string, error) {
|
||||
address, ok := s.listener.Addr().(*net.TCPAddr)
|
||||
func saveGRPCServerConfigFile(locations *locations.Locations, listener net.Listener, token string, certPEM []byte) (string, error) {
|
||||
address, ok := listener.Addr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("could not retrieve gRPC service listener address")
|
||||
}
|
||||
|
||||
sc := config{
|
||||
Port: address.Port,
|
||||
Cert: s.pemCert,
|
||||
Token: s.token,
|
||||
Cert: string(certPEM),
|
||||
Token: token,
|
||||
}
|
||||
|
||||
settingsPath, err := s.locations.ProvideSettingsPath()
|
||||
settingsPath, err := locations.ProvideSettingsPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -461,7 +307,7 @@ func (s *Service) saveGRPCServerConfigFile() (string, error) {
|
||||
}
|
||||
|
||||
// validateServerToken verify that the server token provided by the client is valid.
|
||||
func (s *Service) validateServerToken(ctx context.Context) error {
|
||||
func validateServerToken(ctx context.Context, wantToken string) error {
|
||||
values, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Unauthenticated, "missing server token")
|
||||
@ -476,40 +322,31 @@ func (s *Service) validateServerToken(ctx context.Context) error {
|
||||
return status.Error(codes.Unauthenticated, "more than one server token was provided")
|
||||
}
|
||||
|
||||
if token[0] != s.token {
|
||||
if token[0] != wantToken {
|
||||
return status.Error(codes.Unauthenticated, "invalid server token")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUnaryServerToken check the server token for every unary gRPC call.
|
||||
func (s *Service) validateUnaryServerToken(
|
||||
ctx context.Context,
|
||||
req interface{},
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (resp interface{}, err error) {
|
||||
if err := s.validateServerToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// newUnaryTokenValidator checks the server token for every unary gRPC call.
|
||||
func newUnaryTokenValidator(wantToken string) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
if err := validateServerToken(ctx, wantToken); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return handler(ctx, req)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
// validateStreamServerToken check the server token for every gRPC stream request.
|
||||
func (s *Service) validateStreamServerToken(
|
||||
srv interface{},
|
||||
ss grpc.ServerStream,
|
||||
info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler,
|
||||
) error {
|
||||
logEntry := s.log.WithField("FullMethod", info.FullMethod)
|
||||
// newStreamTokenValidator checks the server token for every gRPC stream request.
|
||||
func newStreamTokenValidator(wantToken string) grpc.StreamServerInterceptor {
|
||||
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
if err := validateServerToken(stream.Context(), wantToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.validateServerToken(ss.Context()); err != nil {
|
||||
logEntry.WithError(err).Error("Stream validator failed")
|
||||
return err
|
||||
return handler(srv, stream)
|
||||
}
|
||||
|
||||
return handler(srv, ss)
|
||||
}
|
||||
|
||||
@ -23,15 +23,13 @@ import (
|
||||
"runtime"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/bridge"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/constants"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/frontend/theme"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/updater"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/keychain"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/ports"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
@ -116,7 +114,7 @@ func (s *Service) Quit(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empt
|
||||
func (s *Service) Restart(ctx context.Context, empty *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
s.log.Debug("Restart")
|
||||
|
||||
s.restarter.SetToRestart()
|
||||
s.restarter.Set(true, false)
|
||||
return s.Quit(ctx, empty)
|
||||
}
|
||||
|
||||
@ -129,25 +127,19 @@ func (s *Service) ShowOnStartup(ctx context.Context, _ *emptypb.Empty) (*wrapper
|
||||
func (s *Service) ShowSplashScreen(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("ShowSplashScreen")
|
||||
|
||||
if s.bridge.IsFirstStart() {
|
||||
return wrapperspb.Bool(false), nil
|
||||
}
|
||||
|
||||
ver, err := semver.NewVersion(s.bridge.GetLastVersion())
|
||||
if err != nil {
|
||||
s.log.WithError(err).WithField("last", s.bridge.GetLastVersion()).Debug("Cannot parse last version")
|
||||
if s.bridge.GetFirstStart() {
|
||||
return wrapperspb.Bool(false), nil
|
||||
}
|
||||
|
||||
// Current splash screen contains update on rebranding. Therefore, it
|
||||
// should be shown only if the last used version was less than 2.2.0.
|
||||
return wrapperspb.Bool(ver.LessThan(semver.MustParse("2.2.0"))), nil
|
||||
return wrapperspb.Bool(s.bridge.GetLastVersion().LessThan(semver.MustParse("2.2.0"))), nil
|
||||
}
|
||||
|
||||
func (s *Service) IsFirstGuiStart(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsFirstGuiStart")
|
||||
|
||||
return wrapperspb.Bool(s.bridge.GetBool(settings.FirstStartGUIKey)), nil
|
||||
return wrapperspb.Bool(s.bridge.GetFirstStartGUI()), nil
|
||||
}
|
||||
|
||||
func (s *Service) SetIsAutostartOn(ctx context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
@ -155,22 +147,16 @@ func (s *Service) SetIsAutostartOn(ctx context.Context, isOn *wrapperspb.BoolVal
|
||||
|
||||
defer func() { _ = s.SendEvent(NewToggleAutostartFinishedEvent()) }()
|
||||
|
||||
if isOn.Value == s.bridge.IsAutostartEnabled() {
|
||||
if isOn.Value == s.bridge.GetAutostart() {
|
||||
s.initAutostart()
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if isOn.Value {
|
||||
err = s.bridge.EnableAutostart()
|
||||
} else {
|
||||
err = s.bridge.DisableAutostart()
|
||||
}
|
||||
|
||||
s.initAutostart()
|
||||
|
||||
if err != nil {
|
||||
if err := s.bridge.SetAutostart(isOn.Value); err != nil {
|
||||
s.log.WithField("makeItEnabled", isOn.Value).WithError(err).Error("Autostart change failed")
|
||||
return nil, status.Errorf(codes.Internal, "failed to set autostart: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
@ -179,7 +165,7 @@ func (s *Service) SetIsAutostartOn(ctx context.Context, isOn *wrapperspb.BoolVal
|
||||
func (s *Service) IsAutostartOn(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsAutostartOn")
|
||||
|
||||
return wrapperspb.Bool(s.bridge.IsAutostartEnabled()), nil
|
||||
return wrapperspb.Bool(s.bridge.GetAutostart()), nil
|
||||
}
|
||||
|
||||
func (s *Service) SetIsBetaEnabled(ctx context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
@ -190,8 +176,10 @@ func (s *Service) SetIsBetaEnabled(ctx context.Context, isEnabled *wrapperspb.Bo
|
||||
channel = updater.EarlyChannel
|
||||
}
|
||||
|
||||
s.bridge.SetUpdateChannel(channel)
|
||||
s.checkUpdate()
|
||||
if err := s.bridge.SetUpdateChannel(channel); err != nil {
|
||||
s.log.WithError(err).Error("Failed to set update channel")
|
||||
return nil, status.Errorf(codes.Internal, "failed to set update channel: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
@ -205,7 +193,10 @@ func (s *Service) IsBetaEnabled(ctx context.Context, _ *emptypb.Empty) (*wrapper
|
||||
func (s *Service) SetIsAllMailVisible(ctx context.Context, isVisible *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("isVisible", isVisible.Value).Debug("SetIsAllMailVisible")
|
||||
|
||||
s.bridge.SetIsAllMailVisible(isVisible.Value)
|
||||
if err := s.bridge.SetShowAllMail(isVisible.Value); err != nil {
|
||||
s.log.WithError(err).Error("Failed to set show all mail")
|
||||
return nil, status.Errorf(codes.Internal, "failed to set show all mail: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
@ -213,7 +204,7 @@ func (s *Service) SetIsAllMailVisible(ctx context.Context, isVisible *wrapperspb
|
||||
func (s *Service) IsAllMailVisible(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsAllMailVisible")
|
||||
|
||||
return wrapperspb.Bool(s.bridge.IsAllMailVisible()), nil
|
||||
return wrapperspb.Bool(s.bridge.GetShowAllMail()), nil
|
||||
}
|
||||
|
||||
func (s *Service) GoOs(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
@ -241,7 +232,7 @@ func (s *Service) Version(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.St
|
||||
func (s *Service) LogsPath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("LogsPath")
|
||||
|
||||
path, err := s.bridge.ProvideLogsPath()
|
||||
path, err := s.bridge.GetLogsPath()
|
||||
if err != nil {
|
||||
s.log.WithError(err).Error("Cannot determine logs path")
|
||||
return nil, err
|
||||
@ -275,7 +266,10 @@ func (s *Service) SetColorSchemeName(ctx context.Context, name *wrapperspb.Strin
|
||||
return nil, status.Error(codes.NotFound, "Color scheme not available")
|
||||
}
|
||||
|
||||
s.bridge.Set(settings.ColorScheme, name.Value)
|
||||
if err := s.bridge.SetColorScheme(name.Value); err != nil {
|
||||
s.log.WithError(err).Error("Failed to set color scheme")
|
||||
return nil, status.Errorf(codes.Internal, "failed to set color scheme: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
@ -283,10 +277,13 @@ func (s *Service) SetColorSchemeName(ctx context.Context, name *wrapperspb.Strin
|
||||
func (s *Service) ColorSchemeName(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("ColorSchemeName")
|
||||
|
||||
current := s.bridge.Get(settings.ColorScheme)
|
||||
current := s.bridge.GetColorScheme()
|
||||
if !theme.IsAvailable(theme.Theme(current)) {
|
||||
current = string(theme.DefaultTheme())
|
||||
s.bridge.Set(settings.ColorScheme, current)
|
||||
if err := s.bridge.SetColorScheme(current); err != nil {
|
||||
s.log.WithError(err).Error("Failed to set color scheme")
|
||||
return nil, status.Errorf(codes.Internal, "failed to set color scheme: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return wrapperspb.String(current), nil
|
||||
@ -312,6 +309,7 @@ func (s *Service) ReportBug(ctx context.Context, report *ReportBugRequest) (*emp
|
||||
defer func() { _ = s.SendEvent(NewReportBugFinishedEvent()) }()
|
||||
|
||||
if err := s.bridge.ReportBug(
|
||||
context.Background(),
|
||||
report.OsType,
|
||||
report.OsVersion,
|
||||
report.Description,
|
||||
@ -331,6 +329,7 @@ func (s *Service) ReportBug(ctx context.Context, report *ReportBugRequest) (*emp
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
/*
|
||||
func (s *Service) ForceLauncher(ctx context.Context, launcher *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("launcher", launcher.Value).Debug("ForceLauncher")
|
||||
|
||||
@ -350,6 +349,7 @@ func (s *Service) SetMainExecutable(ctx context.Context, exe *wrapperspb.StringV
|
||||
}()
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
*/
|
||||
|
||||
func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("username", login.Username).Debug("Login")
|
||||
@ -357,135 +357,44 @@ func (s *Service) Login(ctx context.Context, login *LoginRequest) (*emptypb.Empt
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
|
||||
var err error
|
||||
s.password, err = base64.StdEncoding.DecodeString(login.Password)
|
||||
password, err := base64.StdEncoding.DecodeString(login.Password)
|
||||
if err != nil {
|
||||
s.log.WithError(err).Error("Cannot decode password")
|
||||
_ = s.SendEvent(NewLoginError(LoginErrorType_USERNAME_PASSWORD_ERROR, "Cannot decode password"))
|
||||
s.loginClean()
|
||||
return
|
||||
}
|
||||
|
||||
s.authClient, s.auth, err = s.bridge.Login(login.Username, s.password)
|
||||
// TODO: Handle different error types!
|
||||
// - bad credentials
|
||||
// - bad proton plan
|
||||
// - user already exists
|
||||
userID, err := s.bridge.LoginUser(context.Background(), login.Username, string(password), nil, nil)
|
||||
if err != nil {
|
||||
if err == pmapi.ErrPasswordWrong {
|
||||
// Remove error message since it is hardcoded in QML.
|
||||
_ = s.SendEvent(NewLoginError(LoginErrorType_USERNAME_PASSWORD_ERROR, ""))
|
||||
s.loginClean()
|
||||
return
|
||||
}
|
||||
if err == pmapi.ErrPaidPlanRequired {
|
||||
_ = s.SendEvent(NewLoginError(LoginErrorType_FREE_USER, ""))
|
||||
s.loginClean()
|
||||
return
|
||||
}
|
||||
_ = s.SendEvent(NewLoginError(LoginErrorType_USERNAME_PASSWORD_ERROR, err.Error()))
|
||||
s.loginClean()
|
||||
s.log.WithError(err).Error("Cannot login user")
|
||||
_ = s.SendEvent(NewLoginError(LoginErrorType_USERNAME_PASSWORD_ERROR, "Cannot login user"))
|
||||
return
|
||||
}
|
||||
|
||||
if s.auth.HasTwoFactor() {
|
||||
_ = s.SendEvent(NewLoginTfaRequestedEvent(login.Username))
|
||||
return
|
||||
}
|
||||
if s.auth.HasMailboxPassword() {
|
||||
_ = s.SendEvent(NewLoginTwoPasswordsRequestedEvent())
|
||||
return
|
||||
}
|
||||
|
||||
s.finishLogin()
|
||||
}()
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Login2FA(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("username", login.Username).Debug("Login2FA")
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
|
||||
if s.auth == nil || s.authClient == nil {
|
||||
s.log.Errorf("Login 2FA: authethication incomplete %p %p", s.auth, s.authClient)
|
||||
_ = s.SendEvent(NewLoginError(LoginErrorType_TFA_ABORT, "Missing authentication, try again."))
|
||||
s.loginClean()
|
||||
return
|
||||
}
|
||||
|
||||
twoFA, err := base64.StdEncoding.DecodeString(login.Password)
|
||||
if err != nil {
|
||||
s.log.WithError(err).Error("Cannot decode 2fa code")
|
||||
_ = s.SendEvent(NewLoginError(LoginErrorType_USERNAME_PASSWORD_ERROR, "Cannot decode 2fa code"))
|
||||
s.loginClean()
|
||||
return
|
||||
}
|
||||
|
||||
err = s.authClient.Auth2FA(context.Background(), string(twoFA))
|
||||
if err == pmapi.ErrBad2FACodeTryAgain {
|
||||
s.log.Warn("Login 2FA: retry 2fa")
|
||||
_ = s.SendEvent(NewLoginError(LoginErrorType_TFA_ERROR, ""))
|
||||
return
|
||||
}
|
||||
|
||||
if err == pmapi.ErrBad2FACode {
|
||||
s.log.Warn("Login 2FA: abort 2fa")
|
||||
_ = s.SendEvent(NewLoginError(LoginErrorType_TFA_ABORT, ""))
|
||||
s.loginClean()
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
s.log.WithError(err).Warn("Login 2FA: failed.")
|
||||
_ = s.SendEvent(NewLoginError(LoginErrorType_TFA_ABORT, err.Error()))
|
||||
s.loginClean()
|
||||
return
|
||||
}
|
||||
|
||||
if s.auth.HasMailboxPassword() {
|
||||
_ = s.SendEvent(NewLoginTwoPasswordsRequestedEvent())
|
||||
return
|
||||
}
|
||||
|
||||
s.finishLogin()
|
||||
_ = s.SendEvent(NewLoginFinishedEvent(userID))
|
||||
}()
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Login2Passwords(ctx context.Context, login *LoginRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("username", login.Username).Debug("Login2Passwords")
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
|
||||
var err error
|
||||
s.password, err = base64.StdEncoding.DecodeString(login.Password)
|
||||
|
||||
if err != nil {
|
||||
s.log.WithError(err).Error("Cannot decode mbox password")
|
||||
_ = s.SendEvent(NewLoginError(LoginErrorType_USERNAME_PASSWORD_ERROR, "Cannot decode mbox password"))
|
||||
s.loginClean()
|
||||
return
|
||||
}
|
||||
|
||||
s.finishLogin()
|
||||
}()
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
func (s *Service) Login2FA(_ context.Context, login *LoginRequest) (*emptypb.Empty, error) {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
func (s *Service) LoginAbort(ctx context.Context, loginAbort *LoginAbortRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("username", loginAbort.Username).Debug("LoginAbort")
|
||||
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
|
||||
s.loginAbort()
|
||||
}()
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
func (s *Service) Login2Passwords(_ context.Context, login *LoginRequest) (*emptypb.Empty, error) {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
func (s *Service) CheckUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
func (s *Service) LoginAbort(_ context.Context, loginAbort *LoginAbortRequest) (*emptypb.Empty, error) {
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
/*
|
||||
func (s *Service) CheckUpdate(context.Context, *emptypb.Empty) (*emptypb.Empty, error) {
|
||||
s.log.Debug("CheckUpdate")
|
||||
|
||||
go func() {
|
||||
@ -507,21 +416,20 @@ func (s *Service) InstallUpdate(ctx context.Context, _ *emptypb.Empty) (*emptypb
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
*/
|
||||
|
||||
func (s *Service) SetIsAutomaticUpdateOn(ctx context.Context, isOn *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("isOn", isOn.Value).Debug("SetIsAutomaticUpdateOn")
|
||||
|
||||
currentlyOn := s.bridge.GetBool(settings.AutoUpdateKey)
|
||||
currentlyOn := s.bridge.GetAutoUpdate()
|
||||
if currentlyOn == isOn.Value {
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
s.bridge.SetBool(settings.AutoUpdateKey, isOn.Value)
|
||||
go func() {
|
||||
defer s.panicHandler.HandlePanic()
|
||||
|
||||
s.checkUpdateAndNotify(false)
|
||||
}()
|
||||
if err := s.bridge.SetAutoUpdate(isOn.Value); err != nil {
|
||||
s.log.WithError(err).Error("Failed to set auto update")
|
||||
return nil, status.Errorf(codes.Internal, "failed to set auto update: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
@ -529,51 +437,21 @@ func (s *Service) SetIsAutomaticUpdateOn(ctx context.Context, isOn *wrapperspb.B
|
||||
func (s *Service) IsAutomaticUpdateOn(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsAutomaticUpdateOn")
|
||||
|
||||
return wrapperspb.Bool(s.bridge.GetBool(settings.AutoUpdateKey)), nil
|
||||
}
|
||||
|
||||
func (s *Service) IsCacheOnDiskEnabled(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) {
|
||||
s.log.Debug("IsCacheOnDiskEnabled")
|
||||
|
||||
return wrapperspb.Bool(s.bridge.GetBool(settings.CacheEnabledKey)), nil
|
||||
return wrapperspb.Bool(s.bridge.GetAutoUpdate()), nil
|
||||
}
|
||||
|
||||
func (s *Service) DiskCachePath(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("DiskCachePath")
|
||||
|
||||
return wrapperspb.String(s.bridge.Get(settings.CacheLocationKey)), nil
|
||||
return wrapperspb.String(s.bridge.GetGluonDir()), nil
|
||||
}
|
||||
|
||||
func (s *Service) ChangeLocalCache(ctx context.Context, change *ChangeLocalCacheRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("enableDiskCache", change.EnableDiskCache).
|
||||
WithField("diskCachePath", change.DiskCachePath).
|
||||
Debug("DiskCachePath")
|
||||
s.log.WithField("diskCachePath", change.DiskCachePath).Debug("DiskCachePath")
|
||||
|
||||
restart := false
|
||||
defer func(willRestart *bool) {
|
||||
_ = s.SendEvent(NewCacheChangeLocalCacheFinishedEvent(*willRestart))
|
||||
if *willRestart {
|
||||
_, _ = s.Restart(ctx, &emptypb.Empty{})
|
||||
}
|
||||
}(&restart)
|
||||
|
||||
if change.EnableDiskCache != s.bridge.GetBool(settings.CacheEnabledKey) {
|
||||
if change.EnableDiskCache {
|
||||
if err := s.bridge.EnableCache(); err != nil {
|
||||
s.log.WithError(err).Error("Cannot enable disk cache")
|
||||
} else {
|
||||
restart = true
|
||||
_ = s.SendEvent(NewIsCacheOnDiskEnabledChanged(s.bridge.GetBool(settings.CacheEnabledKey)))
|
||||
}
|
||||
} else {
|
||||
if err := s.bridge.DisableCache(); err != nil {
|
||||
s.log.WithError(err).Error("Cannot disable disk cache")
|
||||
} else {
|
||||
restart = true
|
||||
_ = s.SendEvent(NewIsCacheOnDiskEnabledChanged(s.bridge.GetBool(settings.CacheEnabledKey)))
|
||||
}
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
_ = s.SendEvent(NewCacheChangeLocalCacheFinishedEvent(false))
|
||||
}()
|
||||
|
||||
path := change.DiskCachePath
|
||||
//goland:noinspection GoBoolExpressions
|
||||
@ -581,16 +459,14 @@ func (s *Service) ChangeLocalCache(ctx context.Context, change *ChangeLocalCache
|
||||
path = path[1:]
|
||||
}
|
||||
|
||||
if change.EnableDiskCache && path != s.bridge.Get(settings.CacheLocationKey) {
|
||||
if err := s.bridge.MigrateCache(s.bridge.Get(settings.CacheLocationKey), path); err != nil {
|
||||
if path != s.bridge.GetGluonDir() {
|
||||
if err := s.bridge.SetGluonDir(ctx, path); err != nil {
|
||||
s.log.WithError(err).Error("The local cache location could not be changed.")
|
||||
_ = s.SendEvent(NewCacheErrorEvent(CacheErrorType_CACHE_CANT_MOVE_ERROR))
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
s.bridge.Set(settings.CacheLocationKey, path)
|
||||
restart = true
|
||||
_ = s.SendEvent(NewDiskCachePathChanged(s.bridge.Get(settings.CacheLocationKey)))
|
||||
_ = s.SendEvent(NewDiskCachePathChanged(s.bridge.GetGluonDir()))
|
||||
}
|
||||
|
||||
_ = s.SendEvent(NewCacheLocationChangeSuccessEvent())
|
||||
@ -601,7 +477,10 @@ func (s *Service) ChangeLocalCache(ctx context.Context, change *ChangeLocalCache
|
||||
func (s *Service) SetIsDoHEnabled(ctx context.Context, isEnabled *wrapperspb.BoolValue) (*emptypb.Empty, error) {
|
||||
s.log.WithField("isEnabled", isEnabled.Value).Debug("SetIsDohEnabled")
|
||||
|
||||
s.bridge.SetProxyAllowed(isEnabled.Value)
|
||||
if err := s.bridge.SetProxyAllowed(isEnabled.Value); err != nil {
|
||||
s.log.WithError(err).Error("Failed to set DoH")
|
||||
return nil, status.Errorf(codes.Internal, "failed to set DoH: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
@ -615,13 +494,14 @@ func (s *Service) IsDoHEnabled(ctx context.Context, _ *emptypb.Empty) (*wrappers
|
||||
func (s *Service) SetUseSslForSmtp(ctx context.Context, useSsl *wrapperspb.BoolValue) (*emptypb.Empty, error) { //nolint:revive,stylecheck
|
||||
s.log.WithField("useSsl", useSsl.Value).Debug("SetUseSslForSmtp")
|
||||
|
||||
if s.bridge.GetBool(settings.SMTPSSLKey) == useSsl.Value {
|
||||
if s.bridge.GetSMTPSSL() == useSsl.Value {
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
s.bridge.SetBool(settings.SMTPSSLKey, useSsl.Value)
|
||||
|
||||
defer func() { _, _ = s.Restart(ctx, &emptypb.Empty{}) }()
|
||||
if err := s.bridge.SetSMTPSSL(useSsl.Value); err != nil {
|
||||
s.log.WithError(err).Error("Failed to set SMTP SSL")
|
||||
return nil, status.Errorf(codes.Internal, "failed to set SMTP SSL: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, s.SendEvent(NewMailSettingsUseSslForSmtpFinishedEvent())
|
||||
}
|
||||
@ -629,34 +509,39 @@ func (s *Service) SetUseSslForSmtp(ctx context.Context, useSsl *wrapperspb.BoolV
|
||||
func (s *Service) UseSslForSmtp(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.BoolValue, error) { //nolint:revive,stylecheck
|
||||
s.log.Debug("UseSslForSmtp")
|
||||
|
||||
return wrapperspb.Bool(s.bridge.GetBool(settings.SMTPSSLKey)), nil
|
||||
return wrapperspb.Bool(s.bridge.GetSMTPSSL()), nil
|
||||
}
|
||||
|
||||
func (s *Service) Hostname(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("Hostname")
|
||||
|
||||
return wrapperspb.String(bridge.Host), nil
|
||||
return wrapperspb.String(constants.Host), nil
|
||||
}
|
||||
|
||||
func (s *Service) ImapPort(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.Int32Value, error) {
|
||||
s.log.Debug("ImapPort")
|
||||
|
||||
return wrapperspb.Int32(int32(s.bridge.GetInt(settings.IMAPPortKey))), nil
|
||||
return wrapperspb.Int32(int32(s.bridge.GetIMAPPort())), nil
|
||||
}
|
||||
|
||||
func (s *Service) SmtpPort(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.Int32Value, error) { //nolint:revive,stylecheck
|
||||
s.log.Debug("SmtpPort")
|
||||
|
||||
return wrapperspb.Int32(int32(s.bridge.GetInt(settings.SMTPPortKey))), nil
|
||||
return wrapperspb.Int32(int32(s.bridge.GetSMTPPort())), nil
|
||||
}
|
||||
|
||||
func (s *Service) ChangePorts(ctx context.Context, ports *ChangePortsRequest) (*emptypb.Empty, error) {
|
||||
s.log.WithField("imapPort", ports.ImapPort).WithField("smtpPort", ports.SmtpPort).Debug("ChangePorts")
|
||||
|
||||
s.bridge.SetInt(settings.IMAPPortKey, int(ports.ImapPort))
|
||||
s.bridge.SetInt(settings.SMTPPortKey, int(ports.SmtpPort))
|
||||
if err := s.bridge.SetIMAPPort(int(ports.ImapPort)); err != nil {
|
||||
s.log.WithError(err).Error("Failed to set IMAP port")
|
||||
return nil, status.Errorf(codes.Internal, "failed to set IMAP port: %v", err)
|
||||
}
|
||||
|
||||
defer func() { _, _ = s.Restart(ctx, &emptypb.Empty{}) }()
|
||||
if err := s.bridge.SetSMTPPort(int(ports.SmtpPort)); err != nil {
|
||||
s.log.WithError(err).Error("Failed to set SMTP port")
|
||||
return nil, status.Errorf(codes.Internal, "failed to set SMTP port: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, s.SendEvent(NewMailSettingsChangePortFinishedEvent())
|
||||
}
|
||||
@ -670,12 +555,7 @@ func (s *Service) IsPortFree(ctx context.Context, port *wrapperspb.Int32Value) (
|
||||
func (s *Service) AvailableKeychains(ctx context.Context, _ *emptypb.Empty) (*AvailableKeychainsResponse, error) {
|
||||
s.log.Debug("AvailableKeychains")
|
||||
|
||||
keychains := make([]string, 0, len(keychain.Helpers))
|
||||
for chain := range keychain.Helpers {
|
||||
keychains = append(keychains, chain)
|
||||
}
|
||||
|
||||
return &AvailableKeychainsResponse{Keychains: keychains}, nil
|
||||
return &AvailableKeychainsResponse{Keychains: maps.Keys(keychain.Helpers)}, nil
|
||||
}
|
||||
|
||||
func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.StringValue) (*emptypb.Empty, error) {
|
||||
@ -684,11 +564,20 @@ func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.S
|
||||
defer func() { _, _ = s.Restart(ctx, &emptypb.Empty{}) }()
|
||||
defer func() { _ = s.SendEvent(NewKeychainChangeKeychainFinishedEvent()) }()
|
||||
|
||||
if s.bridge.GetKeychainApp() == keychain.Value {
|
||||
helper, err := s.bridge.GetKeychainApp()
|
||||
if err != nil {
|
||||
s.log.WithError(err).Error("Failed to get current keychain")
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current keychain: %v", err)
|
||||
}
|
||||
|
||||
if helper == keychain.Value {
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
s.bridge.SetKeychainApp(keychain.Value)
|
||||
if err := s.bridge.SetKeychainApp(keychain.Value); err != nil {
|
||||
s.log.WithError(err).Error("Failed to set keychain")
|
||||
return nil, status.Errorf(codes.Internal, "failed to set keychain: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
@ -696,5 +585,11 @@ func (s *Service) SetCurrentKeychain(ctx context.Context, keychain *wrapperspb.S
|
||||
func (s *Service) CurrentKeychain(ctx context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
|
||||
s.log.Debug("CurrentKeychain")
|
||||
|
||||
return wrapperspb.String(s.bridge.GetKeychainApp()), nil
|
||||
helper, err := s.bridge.GetKeychainApp()
|
||||
if err != nil {
|
||||
s.log.WithError(err).Error("Failed to get current keychain")
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current keychain: %v", err)
|
||||
}
|
||||
|
||||
return wrapperspb.String(helper), nil
|
||||
}
|
||||
|
||||
@ -87,12 +87,8 @@ func (s *Service) StopEventStream(ctx context.Context, _ *emptypb.Empty) (*empty
|
||||
|
||||
// SendEvent sends an event to the via the gRPC event stream.
|
||||
func (s *Service) SendEvent(event *StreamEvent) error {
|
||||
s.eventQueueMutex.Lock()
|
||||
defer s.eventQueueMutex.Unlock()
|
||||
|
||||
if s.eventStreamCh == nil {
|
||||
// nobody is connected to the event stream, we queue events
|
||||
s.eventQueue = append(s.eventQueue, event)
|
||||
if s.eventStreamCh == nil { // nobody is connected to the event stream, we queue events
|
||||
s.queueEvent(event)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -167,3 +163,14 @@ func (s *Service) StartEventTest() error { //nolint:funlen
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) queueEvent(event *StreamEvent) {
|
||||
s.eventQueueMutex.Lock()
|
||||
defer s.eventQueueMutex.Unlock()
|
||||
|
||||
if event.isInternetStatus() {
|
||||
s.eventQueue = append(filterOutInternetStatusEvents(s.eventQueue), event)
|
||||
} else {
|
||||
s.eventQueue = append(s.eventQueue, event)
|
||||
}
|
||||
}
|
||||
|
||||
68
internal/frontend/grpc/service_updates.go
Normal file
68
internal/frontend/grpc/service_updates.go
Normal file
@ -0,0 +1,68 @@
|
||||
package grpc
|
||||
|
||||
/*
|
||||
func (s *Service) checkUpdate() {
|
||||
version, err := s.updater.Check()
|
||||
if err != nil {
|
||||
s.log.WithError(err).Error("An error occurred while checking for updates")
|
||||
s.SetVersion(updater.VersionInfo{})
|
||||
return
|
||||
}
|
||||
s.SetVersion(version)
|
||||
}
|
||||
|
||||
func (s *Service) updateForce() {
|
||||
s.updateCheckMutex.Lock()
|
||||
defer s.updateCheckMutex.Unlock()
|
||||
s.checkUpdate()
|
||||
_ = s.SendEvent(NewUpdateForceEvent(s.newVersionInfo.Version.String()))
|
||||
}
|
||||
|
||||
func (s *Service) checkUpdateAndNotify(isReqFromUser bool) {
|
||||
s.updateCheckMutex.Lock()
|
||||
defer func() {
|
||||
s.updateCheckMutex.Unlock()
|
||||
_ = s.SendEvent(NewUpdateCheckFinishedEvent())
|
||||
}()
|
||||
|
||||
s.checkUpdate()
|
||||
version := s.newVersionInfo
|
||||
if version.Version.String() == "" {
|
||||
if isReqFromUser {
|
||||
_ = s.SendEvent(NewUpdateErrorEvent(UpdateErrorType_UPDATE_MANUAL_ERROR))
|
||||
}
|
||||
return
|
||||
}
|
||||
if !s.updater.IsUpdateApplicable(s.newVersionInfo) {
|
||||
s.log.Info("No need to update")
|
||||
if isReqFromUser {
|
||||
_ = s.SendEvent(NewUpdateIsLatestVersionEvent())
|
||||
}
|
||||
} else if isReqFromUser {
|
||||
s.NotifyManualUpdate(s.newVersionInfo, s.updater.CanInstall(s.newVersionInfo))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) installUpdate() {
|
||||
s.updateCheckMutex.Lock()
|
||||
defer s.updateCheckMutex.Unlock()
|
||||
|
||||
if !s.updater.CanInstall(s.newVersionInfo) {
|
||||
s.log.Warning("Skipping update installation, current version too old")
|
||||
_ = s.SendEvent(NewUpdateErrorEvent(UpdateErrorType_UPDATE_MANUAL_ERROR))
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.updater.InstallUpdate(s.newVersionInfo); err != nil {
|
||||
if errors.Cause(err) == updater.ErrDownloadVerify {
|
||||
s.log.WithError(err).Warning("Skipping update installation due to temporary error")
|
||||
} else {
|
||||
s.log.WithError(err).Error("The update couldn't be installed")
|
||||
_ = s.SendEvent(NewUpdateErrorEvent(UpdateErrorType_UPDATE_MANUAL_ERROR))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
_ = s.SendEvent(NewUpdateSilentRestartNeededEvent())
|
||||
}
|
||||
*/
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user