Compare commits

..

42 Commits

Author SHA1 Message Date
4cc2ded001 chore: Kanmon Bridge 3.21.2 changelog. 2025-07-07 11:34:42 +02:00
15880dfe19 fix(BRIDGE-406): fixed faulty certificate chain validation logic; made certificate pin checks exclusive to leaf certs; 2025-07-04 15:19:44 +02:00
e9ea976773 chore: Kanmon Bridge 3.21.1 changelog. 2025-06-11 16:15:53 +02:00
a00af3a398 feat(BRIDGE-383): Internal mailbox conflict resolution extended; Minor alterations to mailbox conflict pre-checker 2025-06-11 16:11:20 +02:00
50ab740b92 chore: Kanmon Bridge 3.21.0 changelog. 2025-06-05 15:45:27 +02:00
39f2362996 feat(BRIDGE-379): mailbox pre-checker on startup & conflict resolver for bridge internal mailboxes; TODO potentially add this for system mailboxes as well 2025-06-05 14:34:29 +02:00
d2742c81e5 feat(BRIDGE-376): catch gluon errors related to label uniqueness constrainst... 2025-06-05 14:34:29 +02:00
9cb914cf13 fix(BRIDGE-377): Correct label field usage on label update handler 2025-06-05 14:34:29 +02:00
4088cf18c3 feat(BRIDGE-373): extend label conflict resolver logging & report sync errors to sentry 2025-06-05 14:34:29 +02:00
c02bae5eb2 fix(BRIDGE-378): Fix incorrect field usage for system mailbox names 2025-06-03 17:36:58 +02:00
2aa8acfb5b chore: changes to reconcile release/jubilee with dev 2025-05-28 16:56:28 +02:00
8109b384c5 fix(BRIDGE-362): added label conflict reconciliation logic 2025-05-28 16:56:07 +02:00
6d79ad3e41 chore: Jubilee Bridge 3.20.1 changelog. 2025-05-28 16:53:23 +02:00
5d93ee0cfc chore: Jubilee Bridge 3.20.0 changelog. 2025-05-28 16:53:23 +02:00
c3e2201945 feat(BRIDGE-366): Kill switch support for IMAP IDLE 2025-05-28 09:53:45 +02:00
89da7335b6 feat(BRIDGE-363): Observability metrics for IMAP connections; minor unleash service refactor; 2025-05-16 15:28:53 +02:00
a305ee1113 chore: Infinity Bridge 3.19.0 changelog. 2025-04-24 11:13:17 +02:00
e38f7748d0 chore: bump GPA 2025-04-22 12:41:13 +00:00
92b2024e3e chore(BRIDGE-352): bump go to 1.24.2 2025-04-17 14:06:54 +00:00
37a8fc95d2 chore(BRIDGE-353): update x/net package to 0.38.0 2025-04-17 10:48:31 +02:00
0c63533aa7 fix(BRIDGE-351): allow draft creation and import to BYOE addresses in combined mode 2025-04-15 17:26:14 +02:00
af98bc2273 fix(BRIDGE-301): don't use external, non-BYOE addresses for imports 2025-04-10 09:51:31 +00:00
b37f2d138a feat(BRIDGE-348): display BYOE addresses in Bridge 2025-04-10 10:06:40 +02:00
7831a98e6c chore(BRIDGE-346): silence http/net vulnerability 2025-04-09 10:49:34 +02:00
4d415675e0 fix(BRIDGE-341): replace go-autostart with fork; added ability to create shortcuts with unicode chars 2025-04-02 11:56:58 +00:00
291f44d1b5 fix(BRIDGE-332): filter new line characters from username and password fields in GUI 2025-04-01 14:37:18 +02:00
a4b315d67a fix(BRIDGE-336): check and create all labels in Gluon on Bridge start 2025-03-25 15:24:59 +01:00
a15d4eb3ef ci: update CODEOWNERS 2025-03-24 15:38:03 +01:00
4e764fe93d feat(BRIDGE-340): additional logging for label operations & bad events 2025-03-24 14:30:19 +01:00
df409925ec fix(BRIDGE-335): store last sucessfully used keychain helper as user preference 2025-03-19 15:10:09 +01:00
e68f3441d7 fix(BRIDGE-196): bump badssl public key 2025-03-19 10:00:23 +00:00
899d3293bc feat(BRIDGE-324): added a log entry for the vault key hash 2025-03-18 11:21:12 +00:00
c66f0b800a fix(BRIDGE-333): ignore unkown label IDs during synchronization 2025-03-17 10:43:26 +01:00
b9c75d02b2 chore: stabilize windows tests 2025-03-14 11:56:42 +01:00
4b91e66505 chore(BRIDGE-315): remove silenced vulns 2025-03-06 14:49:03 +00:00
0cbcd0bf13 fix(BRIDGE-329): fix menu bar icons not displayin on macOS 2025-03-06 15:10:52 +01:00
5c12b00e70 chore: Helix Bridge 3.18.0 changelog. 2025-03-06 10:37:52 +01:00
6e7cdfcd68 feat(BRIDGE-316): Changes required for Qt 6.8.2 bump; bumped go to 1.24.0; changes to OS bundler configs; golangci-lint bump; 2025-03-05 14:27:33 +01:00
a75f84742b chore: remove redundant log entry 2025-02-24 10:58:16 +01:00
f4ddf43ac7 chore: Grunwald Bridge 3.17.0 changelog. 2025-02-18 17:11:46 +01:00
da0f51ce5f feat(BRIDGE-309): Update to the bridge updater logic corresponding to the version file restructure 2025-02-17 15:43:15 +00:00
d711d9f562 feat(BRIDGE-154): include access token when refreshing 2025-02-17 15:10:05 +01:00
108 changed files with 5705 additions and 435 deletions

View File

@ -1 +1 @@
* @go/bridge-ppl/devs
* inbox-desktop-approvers

View File

@ -3,14 +3,14 @@
## Prerequisites
* 64-bit OS:
- the go-rfc5322 module cannot currently be compiled for 32-bit OSes
* Go 1.23.4
* Go 1.24.0
* Bash with basic build utils: make, gcc, sed, find, grep, ...
- For Windows, it is recommended to use MinGW 64bit shell from [MSYS2](https://www.msys2.org/)
* GCC (Linux), msvc (Windows) or Xcode (macOS)
* Windres (Windows)
* libglvnd and libsecret development files (Linux)
* pkg-config (Linux)
* cmake, ninja-build and Qt 6.4.3 are required to build the graphical user interface. On Linux,
* cmake, ninja-build and Qt 6.8.2 are required to build the graphical user interface. On Linux,
the Mesa OpenGL development files are also needed.
To enable the sending of crash reports using Sentry please set the
@ -19,7 +19,7 @@ Otherwise, the sending of crash reports will be disabled.
## Build
In order to build Bridge app with Qt interface we are using
[Qt 6.4.3](https://doc.qt.io/qt-6/gettingstarted.html).
[Qt 6.8.2](https://doc.qt.io/qt-6/gettingstarted.html).
Please note that qmake path must be in your `PATH` to ensure Qt to be found.
Also, before you start build **on Windows**, please unset the `MSYSTEM` variable

View File

@ -127,6 +127,7 @@ Proton Mail Bridge includes the following 3rd party software:
* [blackfriday](https://github.com/russross/blackfriday/v2) available under [license](https://github.com/russross/blackfriday/v2/blob/master/LICENSE)
* [pflag](https://github.com/spf13/pflag) available under [license](https://github.com/spf13/pflag/blob/master/LICENSE)
* [bom](https://github.com/ssor/bom) available under [license](https://github.com/ssor/bom/blob/master/LICENSE)
* [objx](https://github.com/stretchr/objx) available under [license](https://github.com/stretchr/objx/blob/master/LICENSE)
* [golang-asm](https://github.com/twitchyliquid64/golang-asm) available under [license](https://github.com/twitchyliquid64/golang-asm/blob/master/LICENSE)
* [codec](https://github.com/ugorji/go/codec) available under [license](https://github.com/ugorji/go/codec/blob/master/LICENSE)
* [tagparser](https://github.com/vmihailenco/tagparser/v2) available under [license](https://github.com/vmihailenco/tagparser/v2/blob/master/LICENSE)
@ -141,6 +142,7 @@ Proton Mail Bridge includes the following 3rd party software:
* [appengine](https://google.golang.org/appengine) available under [license](https://pkg.go.dev/google.golang.org/appengine?tab=licenses)
* [genproto](https://google.golang.org/genproto) available under [license](https://pkg.go.dev/google.golang.org/genproto?tab=licenses)
* [yaml](https://gopkg.in/yaml.v3) available under [license](https://github.com/go-yaml/yaml/blob/v3.0.1/LICENSE) available under [license](https://github.com/go-yaml/yaml/blob/v3.0.1/LICENSE)
* [go-autostart](https://github.com/ElectroNafta/go-autostart) available under [license](https://github.com/ElectroNafta/go-autostart/blob/master/LICENSE)
* [go-message](https://github.com/ProtonMail/go-message) available under [license](https://github.com/ProtonMail/go-message/blob/master/LICENSE)
* [go-smtp](https://github.com/ProtonMail/go-smtp) available under [license](https://github.com/ProtonMail/go-smtp/blob/master/LICENSE)
* [resty](https://github.com/LBeernaertProton/resty/v2) available under [license](https://github.com/LBeernaertProton/resty/v2/blob/master/LICENSE)

View File

@ -3,6 +3,93 @@
Changelog [format](http://keepachangelog.com/en/1.0.0/)
## Kanmon Bridge 3.21.2
### Fixed
* BRIDGE-406: Fixed faulty certificate chain validation logic. Made certificate pin checks exclusive to leaf certificates.
## Kanmon Bridge 3.21.1
### Changed
* BRIDGE-383: Extended internal mailbox conflict resolution logic and minor changes to the mailbox conflict pre-checker.
## Kanmon Bridge 3.21.0
### Added
* BRIDGE-379: Mailbox pre-check on Bridge startup & conflict resolver for Bridge internal mailboxes.
### Changed
* BRIDGE-376: Explicitly catch Gluon DB mailbox name conflicts and report them to Sentry.
* BRIDGE-373: Extend user mailbox conflict resolver logging & report sync errors to Sentry.
* BRIDGE-366: Kill switch support for IMAP IDLE.
* BRIDGE-363: Observability metric support for IMAP connections.
### Fixed
* BRIDGE-377: Correct API label field usage on user label conflict resolver - update handler (event loop).
* BRIDGE-378: Fix incorrect field usage for system mailbox names.
## Jubilee Bridge 3.20.1
### Fixed
* BRIDGE-362: Implemented logic for reconciling label conflicts.
## Jubilee Bridge 3.20.0
### Added
* BRIDGE-348: Enable display of BYOE addresses in Bridge.
* BRIDGE-340: Added additional logging for label operations and related bad events.
* BRIDGE-324: Log a hash of the vault key on Bridge start.
### Changed
* BRIDGE-352: Chore: bump go to 1.24.2.
* BRIDGE-353: Chore: update x/net package to 0.38.0.
### Fixed
* BRIDGE-351: Allow draft creation and import to BYOE addresses in combined mode.
* BRIDGE-301: Prevent imports into non-BYOE external addresses.
* BRIDGE-341: Replaced go-autostart with a fork to support creating autostart shortcuts in directories with Unicode characters on Windows.
* BRIDGE-332: Strip newline characters from username and password fields in the Bridge GUI.
* BRIDGE-336: Ensure all remote labels are verified and created in Gluon at Bridge startup.
* BRIDGE-335: Persist the last successfully used keychain helper as a user preference on Linux.
* BRIDGE-333: Ignore unknown label IDs during Bridge synchronization.
## Infinity Bridge 3.19.0
### Changed
* BRIDGE-316: Update Qt to latest LTS version 6.8.2.
## Helix Bridge 3.18.0
### Changed
* BRIDGE-309: Revised update logic and structure.
* BRIDGE-154: Added access token to expiry refresh request.
## Grunwald Bridge 3.17.0
### Added
* BRIDGE-271: Report version file check failure to Sentry.
* BRIDGE-247: Test: Automate Bridge 0% update rollout.
* BRIDGE-248: Test: Additional Bridge UI e2e automation tests.
### Changed
* BRIDGE-73: Update goopenpgp.
* BRIDGE-287: Update x/net and x/crypto dependencies.
* BRIDGE-303: Update govulncheck to latest release.
* BRIDGE-226: Bump Go version to 1.23.4.
* BRIDGE-288: Extension to synchronization update handler, observability tweaks and gluon update.
### Fixed
* BRIDGE-291: Use correct field for user plan type.
* BRIDGE-143: Add missing QML component attribute, cut/paste disabled on read-only text areas.
## Flavien Bridge 3.16.0
### Added

View File

@ -12,7 +12,7 @@ ROOT_DIR:=$(realpath .)
.PHONY: build build-gui build-nogui build-launcher versioner hasher
# Keep version hardcoded so app build works also without Git repository.
BRIDGE_APP_VERSION?=3.16.0+git
BRIDGE_APP_VERSION?=3.21.2+git
APP_VERSION:=${BRIDGE_APP_VERSION}
APP_FULL_NAME:=Proton Mail Bridge
APP_VENDOR:=Proton AG
@ -189,7 +189,7 @@ ${RESOURCE_FILE}: ./dist/info.rc ./dist/${SRC_ICO} .FORCE
## Dev dependencies
.PHONY: install-devel-tools install-linter install-go-mod-outdated install-git-hooks
LINTVER:="v1.61.0"
LINTVER:="v1.64.6"
LINTSRC:="https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh"
install-dev-dependencies: install-devel-tools install-linter install-go-mod-outdated

20
go.mod
View File

@ -1,15 +1,15 @@
module github.com/ProtonMail/proton-bridge/v3
go 1.23
go 1.24
toolchain go1.23.4
toolchain go1.24.2
require (
github.com/0xAX/notificator v0.0.0-20220220101646-ee9b8921e557
github.com/Masterminds/semver/v3 v3.2.0
github.com/ProtonMail/gluon v0.17.1-0.20250116113909-2ebd96ec0bc2
github.com/ProtonMail/gluon v0.17.1-0.20250611120816-05167d499f8d
github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a
github.com/ProtonMail/go-proton-api v0.4.1-0.20250121114701-67bd01ad0bc3
github.com/ProtonMail/go-proton-api v0.4.1-0.20250417134000-e624a080f7ba
github.com/ProtonMail/gopenpgp/v2 v2.8.2-proton
github.com/PuerkitoBio/goquery v1.8.1
github.com/abiosoft/ishell v2.0.0+incompatible
@ -46,10 +46,10 @@ require (
github.com/vmihailenco/msgpack/v5 v5.3.5
go.uber.org/goleak v1.2.1
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1
golang.org/x/net v0.34.0
golang.org/x/net v0.38.0
golang.org/x/oauth2 v0.7.0
golang.org/x/sys v0.29.0
golang.org/x/text v0.21.0
golang.org/x/sys v0.31.0
golang.org/x/text v0.23.0
google.golang.org/api v0.114.0
google.golang.org/grpc v1.56.3
google.golang.org/protobuf v1.33.0
@ -114,6 +114,7 @@ require (
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
@ -121,9 +122,9 @@ require (
gitlab.com/c0b/go-ordered-json v0.0.0-20201030195603-febf46534d5a // indirect
go.opencensus.io v0.24.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.32.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sync v0.12.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect
@ -131,6 +132,7 @@ require (
)
replace (
github.com/ProtonMail/go-autostart => github.com/ElectroNafta/go-autostart v0.0.0-20250402094843-326608c16033
github.com/emersion/go-message => github.com/ProtonMail/go-message v0.13.1-0.20240919135104-3bc88e6a9423
github.com/emersion/go-smtp => github.com/ProtonMail/go-smtp v0.0.0-20231109081432-2b3d50599865
github.com/go-resty/resty/v2 => github.com/LBeernaertProton/resty/v2 v2.0.0-20231129100320-dddf8030d93a

32
go.sum
View File

@ -23,6 +23,8 @@ github.com/0xAX/notificator v0.0.0-20220220101646-ee9b8921e557 h1:l6surSnJ3RP4qA
github.com/0xAX/notificator v0.0.0-20220220101646-ee9b8921e557/go.mod h1:sTrmvD/TxuypdOERsDOS7SndZg0rzzcCi1b6wQMXUYM=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/ElectroNafta/go-autostart v0.0.0-20250402094843-326608c16033 h1:d2RB9rQmSusb0K+qSgB+DAY+8i+AXZ/o+oDHj2vAUaA=
github.com/ElectroNafta/go-autostart v0.0.0-20250402094843-326608c16033/go.mod h1:o0nKiWcK0e2G/90uL6akWRkzOV4mFcZmvpBPpigJvdw=
github.com/Kodeworks/golang-image-ico v0.0.0-20141118225523-73f0f4cfade9/go.mod h1:7uhhqiBaR4CpN0k9rMjOtjpcfGd6DG2m04zQxKnWQ0I=
github.com/LBeernaertProton/resty/v2 v2.0.0-20231129100320-dddf8030d93a h1:eQO/GF/+H8/9udc9QAgieFr+jr1tjXlJo35RAhsUbWY=
github.com/LBeernaertProton/resty/v2 v2.0.0-20231129100320-dddf8030d93a/go.mod h1:iiP/OpA0CkcL3IGt1O0+/SIItFUbkkyw5BGXiVdTu+A=
@ -34,10 +36,8 @@ github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAE
github.com/ProtonMail/bcrypt v0.0.0-20210511135022-227b4adcab57/go.mod h1:HecWFHognK8GfRDGnFQbW/LiV7A3MX3gZVs45vk5h8I=
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs69zUkSzubzjBbL+cmOXgnmt9Fyd9ug=
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo=
github.com/ProtonMail/gluon v0.17.1-0.20250116113909-2ebd96ec0bc2 h1:lDgMidI/9j2eedavcy7YICv8+F73ooVTUoUGBE4dO0s=
github.com/ProtonMail/gluon v0.17.1-0.20250116113909-2ebd96ec0bc2/go.mod h1:0/c03TzZPNiSgY5UDJK1iRDkjlDPwWugxTT6et2qDu8=
github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a h1:D+aZah+k14Gn6kmL7eKxoo/4Dr/lK3ChBcwce2+SQP4=
github.com/ProtonMail/go-autostart v0.0.0-20210130080809-00ed301c8e9a/go.mod h1:oTGdE7/DlWIr23G0IKW3OXK9wZ5Hw1GGiaJFccTvZi4=
github.com/ProtonMail/gluon v0.17.1-0.20250611120816-05167d499f8d h1:45W7G+X0w7nzLzeB0eiFkGho5DTK1jNmmNbt3IhN524=
github.com/ProtonMail/gluon v0.17.1-0.20250611120816-05167d499f8d/go.mod h1:0/c03TzZPNiSgY5UDJK1iRDkjlDPwWugxTT6et2qDu8=
github.com/ProtonMail/go-crypto v0.0.0-20230321155629-9a39f2531310/go.mod h1:8TI4H3IbrackdNgv+92dI+rhpCaLqM0IfpgCgenFvRE=
github.com/ProtonMail/go-crypto v1.1.4-proton h1:KIo9uNlk3vzlwI7o5VjhiEjI4Ld1TDixOMnoNZyfpFE=
github.com/ProtonMail/go-crypto v1.1.4-proton/go.mod h1:zNoyBJW3p/yVWiHNZgfTF9VsjwqYof5YY0M9kt2QaX0=
@ -45,8 +45,8 @@ github.com/ProtonMail/go-message v0.13.1-0.20240919135104-3bc88e6a9423 h1:p8nBDx
github.com/ProtonMail/go-message v0.13.1-0.20240919135104-3bc88e6a9423/go.mod h1:NBAn21zgCJ/52WLDyed18YvYFm5tEoeDauubFqLokM4=
github.com/ProtonMail/go-mime v0.0.0-20230322103455-7d82a3887f2f h1:tCbYj7/299ekTTXpdwKYF8eBlsYsDVoggDAuAjoK66k=
github.com/ProtonMail/go-mime v0.0.0-20230322103455-7d82a3887f2f/go.mod h1:gcr0kNtGBqin9zDW9GOHcVntrwnjrK+qdJ06mWYBybw=
github.com/ProtonMail/go-proton-api v0.4.1-0.20250121114701-67bd01ad0bc3 h1:YYnLBVcg7WrEbYVmF1PBr4AEQlob9rCphsMHAmF4CAo=
github.com/ProtonMail/go-proton-api v0.4.1-0.20250121114701-67bd01ad0bc3/go.mod h1:RYgagBFkA3zFrSt7/vviFFwjZxBo6pGzcTwFsLwsnyc=
github.com/ProtonMail/go-proton-api v0.4.1-0.20250417134000-e624a080f7ba h1:DFBngZ7u/f69flRFzPp6Ipo6PKEyflJlA5OCh52yDB4=
github.com/ProtonMail/go-proton-api v0.4.1-0.20250417134000-e624a080f7ba/go.mod h1:eXIoLyIHxvPo8Kd9e1ygYIrAwbeWJhLi3vgSz2crlK4=
github.com/ProtonMail/go-smtp v0.0.0-20231109081432-2b3d50599865 h1:EP1gnxLL5Z7xBSymE9nSTM27nRYINuvssAtDmG0suD8=
github.com/ProtonMail/go-smtp v0.0.0-20231109081432-2b3d50599865/go.mod h1:qm27SGYgoIPRot6ubfQ/GpiPy/g3PaZAVRxiO/sDUgQ=
github.com/ProtonMail/go-srp v0.0.7 h1:Sos3Qk+th4tQR64vsxGIxYpN3rdnG9Wf9K4ZloC1JrI=
@ -498,8 +498,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -555,8 +555,8 @@ golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -571,8 +571,8 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -611,8 +611,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@ -632,8 +632,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=

View File

@ -138,7 +138,7 @@ func migrateOldAccounts(locations *locations.Locations, keychains *keychain.List
if err != nil {
return fmt.Errorf("failed to get helper: %w", err)
}
keychain, err := keychain.NewKeychain(helper, "bridge", keychains.GetHelpers(), keychains.GetDefaultHelper())
keychain, _, err := keychain.NewKeychain(helper, "bridge", keychains.GetHelpers(), keychains.GetDefaultHelper())
if err != nil {
return fmt.Errorf("failed to create keychain: %w", err)
}

View File

@ -134,7 +134,7 @@ func TestKeychainMigration(t *testing.T) {
func TestUserMigration(t *testing.T) {
kcl := keychain.NewTestKeychainsList()
kc, err := keychain.NewKeychain("mock", "bridge", kcl.GetHelpers(), kcl.GetDefaultHelper())
kc, _, err := keychain.NewKeychain("mock", "bridge", kcl.GetHelpers(), kcl.GetDefaultHelper())
require.NoError(t, err)
require.NoError(t, kc.Put("brokenID", "broken"))

View File

@ -18,6 +18,8 @@
package app
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"path"
@ -67,11 +69,12 @@ func newVault(reporter *sentry.Reporter, locations *locations.Locations, keychai
logrus.WithField("vaultDir", vaultDir).Debug("Loading vault from directory")
var (
vaultKey []byte
insecure bool
vaultKey []byte
insecure bool
lastUsedHelper string
)
if key, err := loadVaultKey(vaultDir, keychains); err != nil {
if key, helper, err := loadVaultKey(vaultDir, keychains); err != nil {
if reporter != nil {
if rerr := reporter.ReportMessageWithContext("Could not load/create vault key", map[string]any{
"keychainDefaultHelper": keychains.GetDefaultHelper(),
@ -89,6 +92,8 @@ func newVault(reporter *sentry.Reporter, locations *locations.Locations, keychai
vaultDir = path.Join(vaultDir, "insecure")
} else {
vaultKey = key
lastUsedHelper = helper
logHashedVaultKey(vaultKey) // Log a hash of the vault key.
}
gluonCacheDir, err := locations.ProvideGluonCachePath()
@ -96,34 +101,47 @@ func newVault(reporter *sentry.Reporter, locations *locations.Locations, keychai
return nil, false, nil, fmt.Errorf("could not provide gluon path: %w", err)
}
vault, corrupt, err := vault.New(vaultDir, gluonCacheDir, vaultKey, panicHandler)
userVault, corrupt, err := vault.New(vaultDir, gluonCacheDir, vaultKey, panicHandler)
if err != nil {
return nil, false, corrupt, fmt.Errorf("could not create vault: %w", err)
}
return vault, insecure, corrupt, nil
// Remember the last successfully used keychain and store that as the user preference.
if err := vault.SetHelper(vaultDir, lastUsedHelper); err != nil {
logrus.WithError(err).Error("Could not store last used keychain helper")
}
return userVault, insecure, corrupt, nil
}
func loadVaultKey(vaultDir string, keychains *keychain.List) ([]byte, error) {
helper, err := vault.GetHelper(vaultDir)
// loadVaultKey - loads the key used to encrypt the vault alongside the keychain helper used to access it.
func loadVaultKey(vaultDir string, keychains *keychain.List) (key []byte, keychainHelper string, err error) {
keychainHelper, err = vault.GetHelper(vaultDir)
if err != nil {
return nil, fmt.Errorf("could not get keychain helper: %w", err)
return nil, keychainHelper, fmt.Errorf("could not get keychain helper: %w", err)
}
kc, err := keychain.NewKeychain(helper, constants.KeyChainName, keychains.GetHelpers(), keychains.GetDefaultHelper())
kc, keychainHelper, err := keychain.NewKeychain(keychainHelper, constants.KeyChainName, keychains.GetHelpers(), keychains.GetDefaultHelper())
if err != nil {
return nil, fmt.Errorf("could not create keychain: %w", err)
return nil, keychainHelper, fmt.Errorf("could not create keychain: %w", err)
}
key, err := vault.GetVaultKey(kc)
key, err = vault.GetVaultKey(kc)
if err != nil {
if keychain.IsErrKeychainNoItem(err) {
logrus.WithError(err).Warn("no vault key found, generating new")
return vault.NewVaultKey(kc)
key, err := vault.NewVaultKey(kc)
return key, keychainHelper, err
}
return nil, fmt.Errorf("could not check for vault key: %w", err)
return nil, keychainHelper, fmt.Errorf("could not check for vault key: %w", err)
}
return key, nil
return key, keychainHelper, nil
}
// logHashedVaultKey - computes a sha256 hash and encodes it to base 64. The resulting string is logged.
func logHashedVaultKey(vaultKey []byte) {
hashedKey := sha256.Sum256(vaultKey)
logrus.WithField("hashedKey", hex.EncodeToString(hashedKey[:])).Info("Found vault key")
}

View File

@ -55,6 +55,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/pkg/keychain"
"github.com/bradenaw/juniper/xslices"
"github.com/elastic/go-sysinfo/types"
"github.com/go-resty/resty/v2"
"github.com/sirupsen/logrus"
)
@ -81,8 +82,9 @@ type Bridge struct {
imapEventCh chan imapEvents.Event
// updater is the bridge's updater.
updater Updater
installCh chan installJob
updater Updater
installChLegacy chan installJobLegacy
installCh chan installJob
// heartbeat is the telemetry heartbeat for metrics.
heartbeat *heartBeatState
@ -149,6 +151,9 @@ type Bridge struct {
// notificationStore is used for notification deduplication
notificationStore *notifications.Store
// getHostVersion primarily used for testing the update logic - it should return an OS version
getHostVersion func(host types.Host) string
}
var logPkg = logrus.WithField("pkg", "bridge") //nolint:gochecknoglobals
@ -283,8 +288,9 @@ func newBridge(
tlsConfig: tlsConfig,
imapEventCh: imapEventCh,
updater: updater,
installCh: make(chan installJob),
updater: updater,
installChLegacy: make(chan installJobLegacy),
installCh: make(chan installJob),
curVersion: curVersion,
newVersion: curVersion,
@ -316,6 +322,8 @@ func newBridge(
observabilityService: observabilityService,
notificationStore: notifications.NewStore(locator.ProvideNotificationsCachePath),
getHostVersion: func(host types.Host) string { return host.Info().OS.Version },
}
bridge.serverManager = imapsmtpserver.NewService(context.Background(),
@ -327,6 +335,7 @@ func newBridge(
uidValidityGenerator,
&bridgeIMAPSMTPTelemetry{b: bridge},
observabilityService,
unleashService,
)
// Check whether username has changed and correct (macOS only)
@ -436,8 +445,17 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error {
// Check for updates when triggered.
bridge.goUpdate = bridge.tasks.PeriodicOrTrigger(constants.UpdateCheckInterval, 0, func(ctx context.Context) {
logPkg.Info("Checking for updates")
var versionLegacy updater.VersionInfoLegacy
var version updater.VersionInfo
var err error
useOldUpdateLogic := bridge.GetFeatureFlagValue(unleash.UpdateUseNewVersionFileStructureDisabled)
if useOldUpdateLogic {
versionLegacy, err = bridge.updater.GetVersionInfoLegacy(ctx, bridge.api, bridge.vault.GetUpdateChannel())
} else {
version, err = bridge.updater.GetVersionInfo(ctx, bridge.api)
}
version, err := bridge.updater.GetVersionInfo(ctx, bridge.api, bridge.vault.GetUpdateChannel())
if err != nil {
bridge.publish(events.UpdateCheckFailed{Error: err})
if errors.Is(err, updater.ErrVersionFileDownloadOrVerify) {
@ -450,12 +468,23 @@ func (bridge *Bridge) init(tlsReporter TLSReporter) error {
}
}
} else {
bridge.handleUpdate(version)
if useOldUpdateLogic {
bridge.handleUpdateLegacy(versionLegacy)
} else {
bridge.handleUpdate(version)
}
}
})
defer bridge.goUpdate()
// Install updates when available.
// Install updates when available - based on old update logic
bridge.tasks.Once(func(ctx context.Context) {
async.RangeContext(ctx, bridge.installChLegacy, func(job installJobLegacy) {
bridge.installUpdateLegacy(ctx, job)
})
})
// Install updates when available - based on new update logic
bridge.tasks.Once(func(ctx context.Context) {
async.RangeContext(ctx, bridge.installCh, func(job installJob) {
bridge.installUpdate(ctx, job)
@ -692,13 +721,13 @@ func (bridge *Bridge) verifyUsernameChange() {
func GetUpdatedCachePath(gluonDBPath, gluonCachePath string) string {
// If gluon cache is moved to an external drive; regex find will fail; as is expected
cachePathMatches := usernameChangeRegex.FindStringSubmatch(gluonCachePath)
if cachePathMatches == nil || len(cachePathMatches) < 2 {
if len(cachePathMatches) < 2 {
return ""
}
cacheUsername := cachePathMatches[1]
dbPathMatches := usernameChangeRegex.FindStringSubmatch(gluonDBPath)
if dbPathMatches == nil || len(dbPathMatches) < 2 {
if len(dbPathMatches) < 2 {
return ""
}
@ -718,7 +747,7 @@ func (bridge *Bridge) PushObservabilityMetric(metric proton.ObservabilityMetric)
bridge.observabilityService.AddMetrics(metric)
}
func (bridge *Bridge) PushDistinctObservabilityMetrics(errType observability.DistinctionErrorTypeEnum, metrics ...proton.ObservabilityMetric) {
func (bridge *Bridge) PushDistinctObservabilityMetrics(errType observability.DistinctionMetricTypeEnum, metrics ...proton.ObservabilityMetric) {
bridge.observabilityService.AddDistinctMetrics(errType, metrics...)
}
@ -740,3 +769,19 @@ func (bridge *Bridge) ReportMessageWithContext(message string, messageCtx report
func (bridge *Bridge) GetUsers() map[string]*user.User {
return bridge.users
}
// SetCurrentVersionTest - sets the current version of bridge; should only be used for tests.
func (bridge *Bridge) SetCurrentVersionTest(version *semver.Version) {
bridge.curVersion = version
bridge.newVersion = version
}
// SetHostVersionGetterTest - sets the OS version helper func; only used for testing.
func (bridge *Bridge) SetHostVersionGetterTest(fn func(host types.Host) string) {
bridge.getHostVersion = fn
}
// SetRolloutPercentageTest - sets the rollout percentage; should only be used for testing.
func (bridge *Bridge) SetRolloutPercentageTest(rollout float64) error {
return bridge.vault.SetUpdateRollout(rollout)
}

View File

@ -45,6 +45,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/focus"
"github.com/ProtonMail/proton-bridge/v3/internal/locations"
"github.com/ProtonMail/proton-bridge/v3/internal/services/imapsmtpserver"
"github.com/ProtonMail/proton-bridge/v3/internal/unleash"
"github.com/ProtonMail/proton-bridge/v3/internal/updater"
"github.com/ProtonMail/proton-bridge/v3/internal/user"
"github.com/ProtonMail/proton-bridge/v3/internal/useragent"
@ -383,9 +384,14 @@ func TestBridge_Cookies(t *testing.T) {
})
}
func TestBridge_CheckUpdate(t *testing.T) {
func TestBridge_CheckUpdate_Legacy(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) {
unleash.ModifyPollPeriodAndJitter(500*time.Millisecond, 0)
s.PushFeatureFlag(unleash.UpdateUseNewVersionFileStructureDisabled)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Wait for FF poll.
time.Sleep(600 * time.Millisecond)
// Disable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(false))
@ -400,7 +406,7 @@ func TestBridge_CheckUpdate(t *testing.T) {
require.Equal(t, events.UpdateNotAvailable{}, <-noUpdateCh)
// Simulate a new version being available.
mocks.Updater.SetLatestVersion(v2_4_0, v2_3_0)
mocks.Updater.SetLatestVersionLegacy(v2_4_0, v2_3_0)
// Get a stream of update available events.
updateCh, done := bridge.GetEvents(events.UpdateAvailable{})
@ -411,7 +417,7 @@ func TestBridge_CheckUpdate(t *testing.T) {
// We should receive an event indicating that an update is available.
require.Equal(t, events.UpdateAvailable{
Version: updater.VersionInfo{
VersionLegacy: updater.VersionInfoLegacy{
Version: v2_4_0,
MinAuto: v2_3_0,
RolloutProportion: 1.0,
@ -423,25 +429,30 @@ func TestBridge_CheckUpdate(t *testing.T) {
})
}
func TestBridge_AutoUpdate(t *testing.T) {
func TestBridge_AutoUpdate_Legacy(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
unleash.ModifyPollPeriodAndJitter(500*time.Millisecond, 0)
s.PushFeatureFlag(unleash.UpdateUseNewVersionFileStructureDisabled)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(b *bridge.Bridge, mocks *bridge.Mocks) {
// Wait for FF poll.
time.Sleep(600 * time.Millisecond)
// Enable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(true))
require.NoError(t, b.SetAutoUpdate(true))
// Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateInstalled{})
updateCh, done := b.GetEvents(events.UpdateInstalled{})
defer done()
// Simulate a new version being available.
mocks.Updater.SetLatestVersion(v2_4_0, v2_3_0)
mocks.Updater.SetLatestVersionLegacy(v2_4_0, v2_3_0)
// Check for updates.
bridge.CheckForUpdates()
b.CheckForUpdates()
// We should receive an event indicating that the update was silently installed.
require.Equal(t, events.UpdateInstalled{
Version: updater.VersionInfo{
VersionLegacy: updater.VersionInfoLegacy{
Version: v2_4_0,
MinAuto: v2_3_0,
RolloutProportion: 1.0,
@ -452,9 +463,14 @@ func TestBridge_AutoUpdate(t *testing.T) {
})
}
func TestBridge_ManualUpdate(t *testing.T) {
func TestBridge_ManualUpdate_Legacy(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) {
unleash.ModifyPollPeriodAndJitter(500*time.Millisecond, 0)
s.PushFeatureFlag(unleash.UpdateUseNewVersionFileStructureDisabled)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, mocks *bridge.Mocks) {
// Wait for FF poll.
time.Sleep(600 * time.Millisecond)
// Disable autoupdate for this test.
require.NoError(t, bridge.SetAutoUpdate(false))
@ -463,14 +479,14 @@ func TestBridge_ManualUpdate(t *testing.T) {
defer done()
// Simulate a new version being available, but it's too new for us.
mocks.Updater.SetLatestVersion(v2_4_0, v2_4_0)
mocks.Updater.SetLatestVersionLegacy(v2_4_0, v2_4_0)
// Check for updates.
bridge.CheckForUpdates()
// We should receive an event indicating an update is available, but we can't install it.
require.Equal(t, events.UpdateAvailable{
Version: updater.VersionInfo{
VersionLegacy: updater.VersionInfoLegacy{
Version: v2_4_0,
MinAuto: v2_4_0,
RolloutProportion: 1.0,
@ -484,7 +500,12 @@ func TestBridge_ManualUpdate(t *testing.T) {
func TestBridge_ForceUpdate(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) {
unleash.ModifyPollPeriodAndJitter(500*time.Millisecond, 0)
s.PushFeatureFlag(unleash.UpdateUseNewVersionFileStructureDisabled)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
// Wait for FF poll.
time.Sleep(600 * time.Millisecond)
// Get a stream of update events.
updateCh, done := bridge.GetEvents(events.UpdateForced{})
defer done()
@ -597,7 +618,7 @@ func TestBridge_AddressWithoutKeys(t *testing.T) {
require.NoError(t, err)
// Create an additional address for the user; it will not have keys.
aliasAddrID, err := s.CreateAddress(userID, "alias@pm.me", []byte("password"))
aliasAddrID, err := s.CreateAddress(userID, "alias@pm.me", []byte("password"), true)
require.NoError(t, err)
// Create an API client so we can remove the address keys.
@ -764,7 +785,7 @@ func TestBridge_ChangeAddressOrder(t *testing.T) {
require.NoError(t, err)
// Create a second address for the user.
aliasID, err := s.CreateAddress(userID, "alias@"+s.GetDomain(), password)
aliasID, err := s.CreateAddress(userID, "alias@"+s.GetDomain(), password, true)
require.NoError(t, err)
// Create 10 messages for the user.

View File

@ -54,6 +54,9 @@ func NewMocks(tb testing.TB, version, minAuto *semver.Version) *Mocks {
mocks.Heartbeat.EXPECT().IsTelemetryAvailable(gomock.Any()).AnyTimes()
mocks.Heartbeat.EXPECT().GetHeartbeatPeriodicInterval().AnyTimes().Return(500 * time.Millisecond)
// It's called whenever a context is cancelled during sync. We should ought to remove this and make it more granular in the future.
mocks.Reporter.EXPECT().ReportMessageWithContext("Failed to sync, will retry later", gomock.Any()).AnyTimes()
return mocks
}
@ -119,13 +122,14 @@ func (provider *TestLocationsProvider) UserCache() string {
}
type TestUpdater struct {
latest updater.VersionInfo
lock sync.RWMutex
latest updater.VersionInfoLegacy
releases updater.VersionInfo
lock sync.RWMutex
}
func NewTestUpdater(version, minAuto *semver.Version) *TestUpdater {
return &TestUpdater{
latest: updater.VersionInfo{
latest: updater.VersionInfoLegacy{
Version: version,
MinAuto: minAuto,
@ -134,11 +138,11 @@ func NewTestUpdater(version, minAuto *semver.Version) *TestUpdater {
}
}
func (testUpdater *TestUpdater) SetLatestVersion(version, minAuto *semver.Version) {
func (testUpdater *TestUpdater) SetLatestVersionLegacy(version, minAuto *semver.Version) {
testUpdater.lock.Lock()
defer testUpdater.lock.Unlock()
testUpdater.latest = updater.VersionInfo{
testUpdater.latest = updater.VersionInfoLegacy{
Version: version,
MinAuto: minAuto,
@ -146,17 +150,35 @@ func (testUpdater *TestUpdater) SetLatestVersion(version, minAuto *semver.Versio
}
}
func (testUpdater *TestUpdater) GetVersionInfo(_ context.Context, _ updater.Downloader, _ updater.Channel) (updater.VersionInfo, error) {
func (testUpdater *TestUpdater) GetVersionInfoLegacy(_ context.Context, _ updater.Downloader, _ updater.Channel) (updater.VersionInfoLegacy, error) {
testUpdater.lock.RLock()
defer testUpdater.lock.RUnlock()
return testUpdater.latest, nil
}
func (testUpdater *TestUpdater) InstallUpdate(_ context.Context, _ updater.Downloader, _ updater.VersionInfo) error {
func (testUpdater *TestUpdater) InstallUpdateLegacy(_ context.Context, _ updater.Downloader, _ updater.VersionInfoLegacy) error {
return nil
}
func (testUpdater *TestUpdater) RemoveOldUpdates() error {
return nil
}
func (testUpdater *TestUpdater) SetLatestVersion(releases updater.VersionInfo) {
testUpdater.lock.Lock()
defer testUpdater.lock.Unlock()
testUpdater.releases = releases
}
func (testUpdater *TestUpdater) GetVersionInfo(_ context.Context, _ updater.Downloader) (updater.VersionInfo, error) {
testUpdater.lock.RLock()
defer testUpdater.lock.RUnlock()
return testUpdater.releases, nil
}
func (testUpdater *TestUpdater) InstallUpdate(_ context.Context, _ updater.Downloader, _ updater.Release) error {
return nil
}

View File

@ -88,3 +88,18 @@ func (mr *MockReporterMockRecorder) ReportMessageWithContext(arg0, arg1 interfac
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportMessageWithContext", reflect.TypeOf((*MockReporter)(nil).ReportMessageWithContext), arg0, arg1)
}
// ReportWarningWithContext mocks base method.
func (m *MockReporter) ReportWarningWithContext(arg0 string, arg1 map[string]interface{}) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReportWarningWithContext", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// ReportWarningWithContext indicates an expected call of ReportWarningWithContext.
func (mr *MockReporterMockRecorder) ReportWarningWithContext(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportWarningWithContext", reflect.TypeOf((*MockReporter)(nil).ReportMessageWithContext), arg0, arg1)
}

View File

@ -25,7 +25,7 @@ func NewMockObservabilitySender(ctrl *gomock.Controller) *MockObservabilitySende
func (m *MockObservabilitySender) EXPECT() *MockObservabilitySenderRecorder { return m.recorder }
func (m *MockObservabilitySender) AddDistinctMetrics(errType observability.DistinctionErrorTypeEnum, _ ...proton.ObservabilityMetric) {
func (m *MockObservabilitySender) AddDistinctMetrics(errType observability.DistinctionMetricTypeEnum, _ ...proton.ObservabilityMetric) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddDistinctMetrics", errType)
}
@ -35,7 +35,18 @@ func (m *MockObservabilitySender) AddMetrics(metrics ...proton.ObservabilityMetr
m.ctrl.Call(m, "AddMetrics", metrics)
}
func (mr *MockObservabilitySenderRecorder) AddDistinctMetrics(errType observability.DistinctionErrorTypeEnum, _ ...proton.ObservabilityMetric) *gomock.Call {
func (m *MockObservabilitySender) AddTimeLimitedMetric(metricType observability.DistinctionMetricTypeEnum, metric proton.ObservabilityMetric) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddTimeLimitedMetric", metricType, metric)
}
func (m *MockObservabilitySender) GetEmailClient() string {
m.ctrl.T.Helper()
m.ctrl.Call(m, "GetEmailClient")
return ""
}
func (mr *MockObservabilitySenderRecorder) AddDistinctMetrics(errType observability.DistinctionMetricTypeEnum, _ ...proton.ObservabilityMetric) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock,
"AddDistinctMetrics",
@ -47,3 +58,13 @@ func (mr *MockObservabilitySenderRecorder) AddMetrics(metrics ...proton.Observab
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddMetrics", reflect.TypeOf((*MockObservabilitySender)(nil).AddMetrics), metrics)
}
func (mr *MockObservabilitySenderRecorder) AddTimeLimitedMetric(metricType observability.DistinctionMetricTypeEnum, metric proton.ObservabilityMetric) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddTimeLimitedMetric", reflect.TypeOf((*MockObservabilitySender)(nil).AddTimeLimitedMetric), metricType, metric)
}
func (mr *MockObservabilitySenderRecorder) GetEmailClient() {
mr.mock.ctrl.T.Helper()
mr.mock.ctrl.Call(mr.mock, "GetEmailClient", reflect.TypeOf((*MockObservabilitySender)(nil).GetEmailClient))
}

View File

@ -127,9 +127,9 @@ func TestBridge_Observability_UserMetric(t *testing.T) {
}
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, vaultKey []byte) {
userMetricPeriod := time.Millisecond * 200
userMetricPeriod := time.Millisecond * 600
heartbeatPeriod := time.Second * 10
throttlePeriod := time.Millisecond * 100
throttlePeriod := time.Millisecond * 300
observability.ModifyUserMetricInterval(userMetricPeriod)
observability.ModifyThrottlePeriod(throttlePeriod)

View File

@ -355,7 +355,7 @@ func TestBridge_CanProcessEventsDuringSync(t *testing.T) {
// Create a new address
newAddress := "foo@proton.ch"
addrID, err := s.CreateAddress(userID, newAddress, password)
addrID, err := s.CreateAddress(userID, newAddress, password, true)
require.NoError(t, err)
event := <-addressCreatedCh
@ -430,7 +430,7 @@ func TestBridge_EventReplayAfterSyncHasFinished(t *testing.T) {
createNumMessages(ctx, t, c, addrID, labelID, numMsg)
})
addrID1, err := s.CreateAddress(userID, "foo@proton.ch", password)
addrID1, err := s.CreateAddress(userID, "foo@proton.ch", password, true)
require.NoError(t, err)
var allowSyncToProgress atomic.Bool
@ -469,7 +469,7 @@ func TestBridge_EventReplayAfterSyncHasFinished(t *testing.T) {
})
// User AddrID2 event as a check point to see when the new address was created.
addrID2, err := s.CreateAddress(userID, "bar@proton.ch", password)
addrID2, err := s.CreateAddress(userID, "bar@proton.ch", password, true)
require.NoError(t, err)
allowSyncToProgress.Store(true)
@ -552,7 +552,7 @@ func TestBridge_MessageCreateDuringSync(t *testing.T) {
})
// User AddrID2 event as a check point to see when the new address was created.
addrID, err := s.CreateAddress(userID, "bar@proton.ch", password)
addrID, err := s.CreateAddress(userID, "bar@proton.ch", password, true)
require.NoError(t, err)
// At most two events can be published, one for the first address, then for the second.
@ -663,7 +663,7 @@ func TestBridge_AddressOrderChangeDuringSyncInCombinedModeDoesNotTriggerBadEvent
require.Equal(t, 1, len(info.Addresses))
require.Equal(t, info.Addresses[0], "user@proton.local")
addrID2, err := s.CreateAddress(userID, "foo@"+s.GetDomain(), password)
addrID2, err := s.CreateAddress(userID, "foo@"+s.GetDomain(), password, true)
require.NoError(t, err)
require.NoError(t, s.SetAddressOrder(userID, []string{addrID2, addrID}))

View File

@ -52,7 +52,9 @@ type Autostarter interface {
}
type Updater interface {
GetVersionInfo(context.Context, updater.Downloader, updater.Channel) (updater.VersionInfo, error)
InstallUpdate(context.Context, updater.Downloader, updater.VersionInfo) error
GetVersionInfoLegacy(context.Context, updater.Downloader, updater.Channel) (updater.VersionInfoLegacy, error)
InstallUpdateLegacy(context.Context, updater.Downloader, updater.VersionInfoLegacy) error
RemoveOldUpdates() error
GetVersionInfo(context.Context, updater.Downloader) (updater.VersionInfo, error)
InstallUpdate(context.Context, updater.Downloader, updater.Release) error
}

View File

@ -21,22 +21,168 @@ import (
"context"
"errors"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/safe"
"github.com/ProtonMail/proton-bridge/v3/internal/updater"
"github.com/elastic/go-sysinfo"
"github.com/sirupsen/logrus"
"golang.org/x/exp/slices"
)
func (bridge *Bridge) CheckForUpdates() {
bridge.goUpdate()
}
func (bridge *Bridge) InstallUpdate(version updater.VersionInfo) {
bridge.installCh <- installJob{version: version, silent: false}
func (bridge *Bridge) InstallUpdateLegacy(version updater.VersionInfoLegacy) {
bridge.installChLegacy <- installJobLegacy{version: version, silent: false}
}
func (bridge *Bridge) InstallUpdate(release updater.Release) {
bridge.installCh <- installJob{Release: release, Silent: false}
}
func (bridge *Bridge) handleUpdate(version updater.VersionInfo) {
updateChannel := bridge.vault.GetUpdateChannel()
updateRollout := bridge.vault.GetUpdateRollout()
autoUpdateEnabled := bridge.vault.GetAutoUpdate()
checkSystemVersion := true
hostInfo, err := sysinfo.Host()
// If we're unable to get host system information we skip the update's minimum/maximum OS version checks
if err != nil {
checkSystemVersion = false
logrus.WithError(err).Error("Failed to obtain host system info while handling updates")
if reporterErr := bridge.reporter.ReportMessageWithContext(
"Failed to obtain host system info while handling updates",
reporter.Context{"error": err},
); reporterErr != nil {
logrus.WithError(reporterErr).Error("Failed to report update error")
}
}
if len(version.Releases) > 0 {
// Update latest is only used to update the release notes and landing page URL
bridge.publish(events.UpdateLatest{Release: version.Releases[0]})
}
// minAutoUpdateEvent - used to determine the highest compatible update that satisfies the Minimum Bridge version
minAutoUpdateEvent := events.UpdateAvailable{
Release: updater.Release{Version: &semver.Version{}},
Compatible: false,
Silent: false,
}
// We assume that the version file is always created in descending order
// where newer versions are prepended to the top of the releases
// The logic for checking update eligibility is as follows:
// 1. Check release channel.
// 2. Check whether release version is greater.
// 3. Check if rollout is larger.
// 4. Check OS Version restrictions (provided that restrictions are provided, and we can extract the OS version).
// 5. Check Minimum Compatible Bridge Version.
// 6. Check if an update package is provided.
// 7. Check auto-update.
for _, release := range version.Releases {
log := logrus.WithFields(logrus.Fields{
"current": bridge.curVersion,
"channel": updateChannel,
"update_version": release.Version,
"update_channel": release.ReleaseCategory,
"update_min_auto": release.MinAuto,
"update_rollout": release.RolloutProportion,
"update_min_os_version": release.SystemVersion.Minimum,
"update_max_os_version": release.SystemVersion.Maximum,
})
log.Debug("Checking update release")
if !release.ReleaseCategory.UpdateEligible(updateChannel) {
log.Debug("Update does not satisfy update channel requirement")
continue
}
if !release.Version.GreaterThan(bridge.curVersion) {
log.Debug("Update version is not greater than current version")
continue
}
if release.RolloutProportion < updateRollout {
log.Debug("Update has not been rolled out yet")
continue
}
if checkSystemVersion {
shouldContinue, err := release.SystemVersion.IsHostVersionEligible(log, hostInfo, bridge.getHostVersion)
if err != nil && shouldContinue {
log.WithError(err).Error(
"Failed to verify host system version compatibility during release check." +
"Error is non-fatal continuing with checks",
)
} else if err != nil {
log.WithError(err).Error("Failed to verify host system version compatibility during update check")
continue
}
if !shouldContinue {
log.Debug("Host version does not satisfy system requirements for update")
continue
}
}
if release.MinAuto != nil && bridge.curVersion.LessThan(release.MinAuto) {
log.Debug("Update is available but is incompatible with this Bridge version")
if release.Version.GreaterThan(minAutoUpdateEvent.Release.Version) {
minAutoUpdateEvent.Release = release
}
continue
}
// Check if we have a provided installer package
if found := slices.IndexFunc(release.File, func(file updater.File) bool {
return file.Identifier == updater.PackageIdentifier
}); found == -1 {
log.Error("Update is available but does not contain update package")
if reporterErr := bridge.reporter.ReportMessageWithContext(
"Available update does not contain update package",
reporter.Context{"update_version": release.Version},
); reporterErr != nil {
log.WithError(reporterErr).Error("Failed to report update error")
}
continue
}
if !autoUpdateEnabled {
log.Info("An update is available but auto-update is disabled")
bridge.publish(events.UpdateAvailable{
Release: release,
Compatible: true,
Silent: false,
})
return
}
// If we've gotten to this point that means an automatic update is available and we should install it
safe.RLock(func() {
bridge.installCh <- installJob{Release: release, Silent: true}
}, bridge.newVersionLock)
return
}
// If there's a release with a minAuto requirement that we satisfy (alongside all other checks)
// then notify the user that a manual update is needed
if !minAutoUpdateEvent.Release.Version.Equal(&semver.Version{}) {
bridge.publish(minAutoUpdateEvent)
}
bridge.publish(events.UpdateNotAvailable{})
}
func (bridge *Bridge) handleUpdateLegacy(version updater.VersionInfoLegacy) {
log := logrus.WithFields(logrus.Fields{
"version": version.Version,
"current": bridge.curVersion,
@ -44,7 +190,7 @@ func (bridge *Bridge) handleUpdate(version updater.VersionInfo) {
})
bridge.publish(events.UpdateLatest{
Version: version,
VersionLegacy: version,
})
switch {
@ -62,33 +208,33 @@ func (bridge *Bridge) handleUpdate(version updater.VersionInfo) {
log.Info("An update is available but is incompatible with this version")
bridge.publish(events.UpdateAvailable{
Version: version,
Compatible: false,
Silent: false,
VersionLegacy: version,
Compatible: false,
Silent: false,
})
case !bridge.vault.GetAutoUpdate():
log.Info("An update is available but auto-update is disabled")
bridge.publish(events.UpdateAvailable{
Version: version,
Compatible: true,
Silent: false,
VersionLegacy: version,
Compatible: true,
Silent: false,
})
default:
safe.RLock(func() {
bridge.installCh <- installJob{version: version, silent: true}
bridge.installChLegacy <- installJobLegacy{version: version, silent: true}
}, bridge.newVersionLock)
}
}
type installJob struct {
version updater.VersionInfo
type installJobLegacy struct {
version updater.VersionInfoLegacy
silent bool
}
func (bridge *Bridge) installUpdate(ctx context.Context, job installJob) {
func (bridge *Bridge) installUpdateLegacy(ctx context.Context, job installJobLegacy) {
safe.Lock(func() {
log := logrus.WithFields(logrus.Fields{
"version": job.version.Version,
@ -103,17 +249,12 @@ func (bridge *Bridge) installUpdate(ctx context.Context, job installJob) {
log.WithField("silent", job.silent).Info("An update is available")
bridge.publish(events.UpdateAvailable{
Version: job.version,
Compatible: true,
Silent: job.silent,
VersionLegacy: job.version,
Compatible: true,
Silent: job.silent,
})
bridge.publish(events.UpdateInstalling{
Version: job.version,
Silent: job.silent,
})
err := bridge.updater.InstallUpdate(ctx, bridge.api, job.version)
err := bridge.updater.InstallUpdateLegacy(ctx, bridge.api, job.version)
switch {
case errors.Is(err, updater.ErrDownloadVerify):
@ -134,8 +275,79 @@ func (bridge *Bridge) installUpdate(ctx context.Context, job installJob) {
log.WithError(err).Error("The update could not be installed")
bridge.publish(events.UpdateFailed{
Version: job.version,
Silent: job.silent,
VersionLegacy: job.version,
Silent: job.silent,
Error: err,
})
default:
log.Info("The update was installed successfully")
bridge.publish(events.UpdateInstalled{
VersionLegacy: job.version,
Silent: job.silent,
})
bridge.newVersion = job.version.Version
}
}, bridge.newVersionLock)
}
type installJob struct {
Release updater.Release
Silent bool
}
func (bridge *Bridge) installUpdate(ctx context.Context, job installJob) {
safe.Lock(func() {
log := logrus.WithFields(logrus.Fields{
"version": job.Release.Version,
"current": bridge.curVersion,
"channel": bridge.vault.GetUpdateChannel(),
})
if !job.Release.Version.GreaterThan(bridge.newVersion) {
return
}
log.WithField("silent", job.Silent).Info("An update is available")
bridge.publish(events.UpdateAvailable{
Release: job.Release,
Compatible: true,
Silent: job.Silent,
})
err := bridge.updater.InstallUpdate(ctx, bridge.api, job.Release)
switch {
case errors.Is(err, updater.ErrReleaseUpdatePackageMissing):
log.WithError(err).Error("The update could not be installed but we will fail silently")
if reporterErr := bridge.reporter.ReportExceptionWithContext(
"Cannot download update, update package is missing",
reporter.Context{"error": err},
); reporterErr != nil {
log.WithError(reporterErr).Error("Failed to report update error")
}
case errors.Is(err, updater.ErrDownloadVerify):
// BRIDGE-207: if download or verification fails, we do not want to trigger a manual update. We report in the log and to Sentry
// and we fail silently.
log.WithError(err).Error("The update could not be installed, but we will fail silently")
if reporterErr := bridge.reporter.ReportMessageWithContext(
"Cannot download or verify update",
reporter.Context{"error": err},
); reporterErr != nil {
log.WithError(reporterErr).Error("Failed to report update error")
}
case errors.Is(err, updater.ErrUpdateAlreadyInstalled):
log.Info("The update was already installed")
case err != nil:
log.WithError(err).Error("The update could not be installed")
bridge.publish(events.UpdateFailed{
Release: job.Release,
Silent: job.Silent,
Error: err,
})
@ -143,11 +355,11 @@ func (bridge *Bridge) installUpdate(ctx context.Context, job installJob) {
log.Info("The update was installed successfully")
bridge.publish(events.UpdateInstalled{
Version: job.version,
Silent: job.silent,
Release: job.Release,
Silent: job.Silent,
})
bridge.newVersion = job.version.Version
bridge.newVersion = job.Release.Version
}
}, bridge.newVersionLock)
}

View File

@ -0,0 +1,700 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package bridge_test
import (
"context"
"runtime"
"testing"
"time"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server"
bridgePkg "github.com/ProtonMail/proton-bridge/v3/internal/bridge"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/updater"
"github.com/ProtonMail/proton-bridge/v3/internal/updater/versioncompare"
"github.com/elastic/go-sysinfo/types"
"github.com/stretchr/testify/require"
)
// NOTE: we always assume the highest version is always the first in the release json array
func Test_Update_BetaEligible(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridgePkg.Locator, vaultKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridgePkg.Bridge, mocks *bridgePkg.Mocks) {
updateCh, done := bridge.GetEvents(events.UpdateInstalled{})
defer done()
err := bridge.SetUpdateChannel(updater.EarlyChannel)
require.NoError(t, err)
bridge.SetCurrentVersionTest(semver.MustParse("2.1.1"))
expectedRelease := updater.Release{
ReleaseCategory: updater.EarlyAccessReleaseCategory,
Version: semver.MustParse("2.1.2"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 1.0,
MinAuto: &semver.Version{},
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
}
updaterData := updater.VersionInfo{Releases: []updater.Release{
expectedRelease,
}}
go func() {
time.Sleep(1 * time.Second)
mocks.Updater.SetLatestVersion(updaterData)
bridge.CheckForUpdates()
}()
select {
case update := <-updateCh:
require.Equal(t, events.UpdateInstalled{
Release: expectedRelease,
Silent: true,
}, update)
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for update")
}
})
})
}
func Test_Update_Stable(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridgePkg.Locator, vaultKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridgePkg.Bridge, mocks *bridgePkg.Mocks) {
updateCh, done := bridge.GetEvents(events.UpdateInstalled{})
defer done()
err := bridge.SetUpdateChannel(updater.StableChannel)
require.NoError(t, err)
bridge.SetCurrentVersionTest(semver.MustParse("2.1.1"))
expectedRelease := updater.Release{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.1.3"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 1.0,
MinAuto: &semver.Version{},
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
}
updaterData := updater.VersionInfo{Releases: []updater.Release{
{
ReleaseCategory: updater.EarlyAccessReleaseCategory,
Version: semver.MustParse("2.1.4"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 1.0,
MinAuto: &semver.Version{},
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
},
expectedRelease,
}}
mocks.Updater.SetLatestVersion(updaterData)
bridge.CheckForUpdates()
require.Equal(t, events.UpdateInstalled{
Release: expectedRelease,
Silent: true,
}, <-updateCh)
})
})
}
func Test_Update_CurrentReleaseNewest(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridgePkg.Locator, vaultKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridgePkg.Bridge, mocks *bridgePkg.Mocks) {
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{})
defer done()
err := bridge.SetUpdateChannel(updater.StableChannel)
require.NoError(t, err)
bridge.SetCurrentVersionTest(semver.MustParse("2.1.5"))
expectedRelease := updater.Release{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.1.3"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 1.0,
MinAuto: &semver.Version{},
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
}
updaterData := updater.VersionInfo{Releases: []updater.Release{
{
ReleaseCategory: updater.EarlyAccessReleaseCategory,
Version: semver.MustParse("2.1.4"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 1.0,
MinAuto: &semver.Version{},
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
},
expectedRelease,
}}
mocks.Updater.SetLatestVersion(updaterData)
bridge.CheckForUpdates()
require.Equal(t, events.UpdateNotAvailable{}, <-updateCh)
})
})
}
func Test_Update_NotRolledOutYet(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridgePkg.Locator, vaultKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridgePkg.Bridge, mocks *bridgePkg.Mocks) {
require.NoError(t, bridge.SetUpdateChannel(updater.EarlyChannel))
bridge.SetCurrentVersionTest(semver.MustParse("2.0.0"))
require.NoError(t, bridge.SetRolloutPercentageTest(1.0))
updaterData := updater.VersionInfo{Releases: []updater.Release{
{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.1.5"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 0.5,
MinAuto: &semver.Version{},
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
},
{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.1.4"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 0.5,
MinAuto: &semver.Version{},
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
},
}}
mocks.Updater.SetLatestVersion(updaterData)
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{})
defer done()
bridge.CheckForUpdates()
require.Equal(t, events.UpdateNotAvailable{}, <-updateCh)
})
})
}
func Test_Update_CheckOSVersion_NoUpdate(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridgePkg.Locator, vaultKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridgePkg.Bridge, mocks *bridgePkg.Mocks) {
require.NoError(t, bridge.SetAutoUpdate(true))
require.NoError(t, bridge.SetUpdateChannel(updater.StableChannel))
currentBridgeVersion := semver.MustParse("2.1.5")
bridge.SetCurrentVersionTest(currentBridgeVersion)
// Override the OS version check
bridge.SetHostVersionGetterTest(func(_ types.Host) string {
return "10.0.0"
})
updateNotAvailableCh, done := bridge.GetEvents(events.UpdateNotAvailable{})
defer done()
updateCh, updateChDone := bridge.GetEvents(events.UpdateInstalled{})
defer updateChDone()
expectedRelease := updater.Release{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.4.0"),
SystemVersion: versioncompare.SystemVersion{
Minimum: "12.0.0",
Maximum: "13.0.0",
},
RolloutProportion: 1.0,
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
}
updaterData := updater.VersionInfo{Releases: []updater.Release{
expectedRelease,
{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.3.0"),
SystemVersion: versioncompare.SystemVersion{
Minimum: "10.1.0",
Maximum: "11.5",
},
RolloutProportion: 1.0,
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
},
}}
mocks.Updater.SetLatestVersion(updaterData)
bridge.CheckForUpdates()
if runtime.GOOS == "darwin" {
require.Equal(t, events.UpdateNotAvailable{}, <-updateNotAvailableCh)
} else {
require.Equal(t, events.UpdateInstalled{
Release: expectedRelease,
Silent: true,
}, <-updateCh)
}
})
})
}
func Test_Update_CheckOSVersion_HasUpdate(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridgePkg.Locator, vaultKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridgePkg.Bridge, mocks *bridgePkg.Mocks) {
require.NoError(t, bridge.SetAutoUpdate(true))
require.NoError(t, bridge.SetUpdateChannel(updater.StableChannel))
updateCh, done := bridge.GetEvents(events.UpdateInstalled{})
defer done()
currentBridgeVersion := semver.MustParse("2.1.5")
bridge.SetCurrentVersionTest(currentBridgeVersion)
// Override the OS version check
bridge.SetHostVersionGetterTest(func(_ types.Host) string {
return "10.0.0"
})
expectedUpdateRelease := updater.Release{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.2.0"),
SystemVersion: versioncompare.SystemVersion{
Minimum: "10.0.0",
Maximum: "10.1.12",
},
RolloutProportion: 1.0,
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
}
expectedUpdateReleaseWindowsLinux := updater.Release{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.4.0"),
SystemVersion: versioncompare.SystemVersion{
Minimum: "12.0.0",
},
RolloutProportion: 1.0,
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
}
updaterData := updater.VersionInfo{Releases: []updater.Release{
expectedUpdateReleaseWindowsLinux,
{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.3.0"),
SystemVersion: versioncompare.SystemVersion{
Minimum: "11.0.0",
},
RolloutProportion: 1.0,
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
},
expectedUpdateRelease,
{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.1.0"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 1.0,
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
},
}}
mocks.Updater.SetLatestVersion(updaterData)
bridge.CheckForUpdates()
if runtime.GOOS == "darwin" {
require.Equal(t, events.UpdateInstalled{
Release: expectedUpdateRelease,
Silent: true,
}, <-updateCh)
} else {
require.Equal(t, events.UpdateInstalled{
Release: expectedUpdateReleaseWindowsLinux,
Silent: true,
}, <-updateCh)
}
})
})
}
func Test_Update_UpdateFromMinVer_UpdateAvailable(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridgePkg.Locator, vaultKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridgePkg.Bridge, mocks *bridgePkg.Mocks) {
require.NoError(t, bridge.SetAutoUpdate(true))
require.NoError(t, bridge.SetUpdateChannel(updater.StableChannel))
currentBridgeVersion := semver.MustParse("2.1.5")
bridge.SetCurrentVersionTest(currentBridgeVersion)
updateCh, done := bridge.GetEvents(events.UpdateInstalled{})
defer done()
expectedUpdateRelease := updater.Release{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.2.0"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 1.0,
MinAuto: currentBridgeVersion,
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
}
updaterData := updater.VersionInfo{Releases: []updater.Release{
{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.3.0"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 1.0,
MinAuto: semver.MustParse("2.2.1"),
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
},
{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.2.1"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 1.0,
MinAuto: semver.MustParse("2.2.0"),
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
},
expectedUpdateRelease,
}}
mocks.Updater.SetLatestVersion(updaterData)
bridge.CheckForUpdates()
require.Equal(t, events.UpdateInstalled{
Release: expectedUpdateRelease,
Silent: true,
}, <-updateCh)
})
})
}
// Test_Update_UpdateFromMinVer_NoCompatibleVersionForceManual -
// if we have an update, but we don't satisfy minVersion, a manual update to the highest possible version should be performed.
func Test_Update_UpdateFromMinVer_NoCompatibleVersionForceManual(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridgePkg.Locator, vaultKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridgePkg.Bridge, mocks *bridgePkg.Mocks) {
require.NoError(t, bridge.SetAutoUpdate(true))
require.NoError(t, bridge.SetUpdateChannel(updater.StableChannel))
currentBridgeVersion := semver.MustParse("2.1.5")
bridge.SetCurrentVersionTest(currentBridgeVersion)
updateCh, done := bridge.GetEvents(events.UpdateAvailable{})
defer done()
expectedUpdateRelease := updater.Release{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.3.0"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 1.0,
MinAuto: semver.MustParse("2.2.1"),
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
}
updaterData := updater.VersionInfo{Releases: []updater.Release{
{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.2.1"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 1.0,
MinAuto: semver.MustParse("2.2.0"),
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
},
{
ReleaseCategory: updater.StableReleaseCategory,
Version: semver.MustParse("2.2.0"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 1.0,
MinAuto: semver.MustParse("2.1.6"),
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
},
expectedUpdateRelease,
}}
mocks.Updater.SetLatestVersion(updaterData)
bridge.CheckForUpdates()
require.Equal(t, events.UpdateAvailable{
Release: expectedUpdateRelease,
Silent: false,
Compatible: false,
}, <-updateCh)
})
})
}
// Test_Update_UpdateFromMinVer_NoCompatibleVersionForceManual_BetaMismatch - only Beta updates are available
// nor do we satisfy the minVersion, we can't do anything in this case.
func Test_Update_UpdateFromMinVer_NoCompatibleVersionForceManual_BetaMismatch(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridgePkg.Locator, vaultKey []byte) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, vaultKey, func(bridge *bridgePkg.Bridge, mocks *bridgePkg.Mocks) {
require.NoError(t, bridge.SetAutoUpdate(true))
require.NoError(t, bridge.SetUpdateChannel(updater.StableChannel))
currentBridgeVersion := semver.MustParse("2.1.5")
bridge.SetCurrentVersionTest(currentBridgeVersion)
updateCh, done := bridge.GetEvents(events.UpdateNotAvailable{})
defer done()
expectedUpdateRelease := updater.Release{
ReleaseCategory: updater.EarlyAccessReleaseCategory,
Version: semver.MustParse("2.3.0"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 1.0,
MinAuto: semver.MustParse("2.2.1"),
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
}
updaterData := updater.VersionInfo{Releases: []updater.Release{
{
ReleaseCategory: updater.EarlyAccessReleaseCategory,
Version: semver.MustParse("2.2.1"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 1.0,
MinAuto: semver.MustParse("2.2.0"),
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
},
{
ReleaseCategory: updater.EarlyAccessReleaseCategory,
Version: semver.MustParse("2.2.0"),
SystemVersion: versioncompare.SystemVersion{},
RolloutProportion: 1.0,
MinAuto: semver.MustParse("2.1.6"),
File: []updater.File{
{
URL: "RANDOM_INSTALLER_URL",
Identifier: updater.InstallerIdentifier,
},
{
URL: "RANDOM_PACKAGE_URL",
Identifier: updater.PackageIdentifier,
},
},
},
expectedUpdateRelease,
}}
mocks.Updater.SetLatestVersion(updaterData)
bridge.CheckForUpdates()
require.Equal(t, events.UpdateNotAvailable{}, <-updateCh)
})
})
}

View File

@ -551,7 +551,7 @@ func (bridge *Bridge) addUserWithVault(
syncSettingsPath,
isNew,
bridge.notificationStore,
bridge.unleashService.GetFlagValue,
bridge.unleashService,
)
if err != nil {
return fmt.Errorf("failed to create user: %w", err)

View File

@ -23,6 +23,7 @@ import (
"net"
"net/http"
"net/mail"
"runtime"
"strings"
"sync/atomic"
"testing"
@ -76,6 +77,9 @@ func TestBridge_User_RefreshEvent(t *testing.T) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
syncCh, closeCh := chToType[events.Event, events.SyncFinished](bridge.GetEvents(events.SyncFinished{}))
if runtime.GOOS != "windows" {
require.Equal(t, userID, (<-syncCh).UserID)
}
require.Equal(t, userID, (<-syncCh).UserID)
closeCh()
@ -304,7 +308,7 @@ func TestBridge_User_AddressEvents_NoBadEvent(t *testing.T) {
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
userLoginAndSync(ctx, t, bridge, "user", password)
addrID, err = s.CreateAddress(userID, "other@pm.me", password)
addrID, err = s.CreateAddress(userID, "other@pm.me", password, true)
require.NoError(t, err)
userContinueEventProcess(ctx, t, s, bridge)
@ -312,7 +316,7 @@ func TestBridge_User_AddressEvents_NoBadEvent(t *testing.T) {
userContinueEventProcess(ctx, t, s, bridge)
})
otherID, err := s.CreateAddress(userID, "another@pm.me", password)
otherID, err := s.CreateAddress(userID, "another@pm.me", password, true)
require.NoError(t, err)
require.NoError(t, s.RemoveAddress(userID, otherID))
@ -328,6 +332,87 @@ func TestBridge_User_AddressEvents_NoBadEvent(t *testing.T) {
})
}
func TestBridge_User_AddressEvents_BYOEAddressAdded(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
// Create a user.
userID, addrID, err := s.CreateUser("user", password)
require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
userLoginAndSync(ctx, t, bridge, "user", password)
// Create an additional proton address
addrID, err = s.CreateAddress(userID, "other@pm.me", password, true)
require.NoError(t, err)
userContinueEventProcess(ctx, t, s, bridge)
require.NoError(t, s.AddAddressCreatedEvent(userID, addrID))
userContinueEventProcess(ctx, t, s, bridge)
userInfo, err := bridge.GetUserInfo(userID)
require.NoError(t, err)
require.Equal(t, 2, len(userInfo.Addresses))
// Create an external address with sending disabled.
externalID, err := s.CreateExternalAddress(userID, "another@yahoo.com", password, false)
require.NoError(t, err)
userContinueEventProcess(ctx, t, s, bridge)
require.NoError(t, s.AddAddressCreatedEvent(userID, externalID))
userContinueEventProcess(ctx, t, s, bridge)
// User addresses should still return 2, as we ignore the external address.
userInfo, err = bridge.GetUserInfo(userID)
require.NoError(t, err)
require.Equal(t, 2, len(userInfo.Addresses))
// Create an external address w. sending enabled. This is considered a BYOE address.
BYOEAddrID, err := s.CreateExternalAddress(userID, "other@yahoo.com", password, true)
require.NoError(t, err)
userContinueEventProcess(ctx, t, s, bridge)
require.NoError(t, s.AddAddressCreatedEvent(userID, BYOEAddrID))
userContinueEventProcess(ctx, t, s, bridge)
userInfo, err = bridge.GetUserInfo(userID)
require.NoError(t, err)
require.Equal(t, 3, len(userInfo.Addresses))
})
})
}
func TestBridge_User_AddressEvents_ExternalAddressSendChanged(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
userID, _, err := s.CreateUser("user", password)
require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
userLoginAndSync(ctx, t, bridge, "user", password)
// Create an additional external address.
externalID, err := s.CreateExternalAddress(userID, "other@yahoo.me", password, false)
require.NoError(t, err)
userContinueEventProcess(ctx, t, s, bridge)
require.NoError(t, s.AddAddressCreatedEvent(userID, externalID))
userContinueEventProcess(ctx, t, s, bridge)
// We expect only one address, the external one without sending should not be considered a valid address.
userInfo, err := bridge.GetUserInfo(userID)
require.NoError(t, err)
require.Equal(t, 1, len(userInfo.Addresses))
// Change it to allow sending such that it becomes a BYOE address.
err = s.ChangeAddressAllowSend(userID, externalID, true)
require.NoError(t, err)
userContinueEventProcess(ctx, t, s, bridge)
require.NoError(t, s.AddAddressUpdatedEvent(userID, externalID))
userContinueEventProcess(ctx, t, s, bridge)
// We should now have 2 usable addresses listed.
userInfo, err = bridge.GetUserInfo(userID)
require.NoError(t, err)
require.Equal(t, 2, len(userInfo.Addresses))
})
})
}
func TestBridge_User_AddressEventUpdatedForAddressThatDoesNotExist_NoBadEvent(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
// Create a user.
@ -694,7 +779,7 @@ func TestBridge_User_DisableEnableAddress(t *testing.T) {
require.NoError(t, err)
// Create an additional address for the user.
aliasID, err := s.CreateAddress(userID, "alias@"+s.GetDomain(), password)
aliasID, err := s.CreateAddress(userID, "alias@"+s.GetDomain(), password, true)
require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
@ -745,7 +830,7 @@ func TestBridge_User_CreateDisabledAddress(t *testing.T) {
require.NoError(t, err)
// Create an additional address for the user.
aliasID, err := s.CreateAddress(userID, "alias@"+s.GetDomain(), password)
aliasID, err := s.CreateAddress(userID, "alias@"+s.GetDomain(), password, true)
require.NoError(t, err)
// Immediately disable the address.

View File

@ -658,7 +658,7 @@ func TestBridge_UserInfo_Alias(t *testing.T) {
require.NoError(t, err)
// Give the new user an alias.
require.NoError(t, getErr(s.CreateAddress(userID, "alias@pm.me", []byte("password"))))
require.NoError(t, getErr(s.CreateAddress(userID, "alias@pm.me", []byte("password"), true)))
// Login the user.
require.NoError(t, getErr(bridge.LoginFull(ctx, "primary", []byte("password"), nil, nil)))
@ -706,7 +706,7 @@ func TestBridge_User_GetAddresses(t *testing.T) {
// Create a user.
userID, _, err := s.CreateUser("user", password)
require.NoError(t, err)
addrID2, err := s.CreateAddress(userID, "user@external.com", []byte("password"))
addrID2, err := s.CreateAddress(userID, "user@external.com", password, false)
require.NoError(t, err)
require.NoError(t, s.ChangeAddressType(userID, addrID2, proton.AddressTypeExternal))
@ -720,6 +720,29 @@ func TestBridge_User_GetAddresses(t *testing.T) {
})
}
func TestBridge_User_GetAddresses_BYOE(t *testing.T) {
withEnv(t, func(ctx context.Context, s *server.Server, netCtl *proton.NetCtl, locator bridge.Locator, storeKey []byte) {
// Create a user.
userID, _, err := s.CreateUser("user", password)
require.NoError(t, err)
// Add a non-sending external address.
_, err = s.CreateExternalAddress(userID, "user@external.com", password, false)
require.NoError(t, err)
// Add a BYOE address.
_, err = s.CreateExternalAddress(userID, "user2@external.com", password, true)
require.NoError(t, err)
withBridge(ctx, t, s.GetHostURL(), netCtl, locator, storeKey, func(bridge *bridge.Bridge, _ *bridge.Mocks) {
userLoginAndSync(ctx, t, bridge, "user", password)
info, err := bridge.GetUserInfo(userID)
require.NoError(t, err)
require.Equal(t, 2, len(info.Addresses))
require.Equal(t, info.Addresses[0], "user@proton.local")
require.Equal(t, info.Addresses[1], "user2@external.com")
})
})
}
// getErr returns the error that was passed to it.
func getErr[T any](_ T, err error) error {
return err

View File

@ -22,6 +22,8 @@ import (
"crypto/tls"
"net"
"net/http"
"net/url"
"strings"
"time"
)
@ -29,6 +31,11 @@ type TLSDialer interface {
DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error)
}
type SecureTLSDialer interface {
DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error)
ShouldSkipCertificateChainVerification(address string) bool
}
func SetBasicTransportTimeouts(t *http.Transport) {
t.MaxIdleConns = 100
t.MaxIdleConnsPerHost = 100
@ -71,6 +78,35 @@ func NewBasicTLSDialer(hostURL string) *BasicTLSDialer {
}
}
func extractDomain(hostname string) string {
parts := strings.Split(hostname, ".")
if len(parts) >= 2 {
return strings.Join(parts[len(parts)-2:], ".")
}
return hostname
}
// ShouldSkipCertificateChainVerification determines whether certificate chain validation should be skipped.
// It compares the domain of the requested address with the configured host URL domain.
// Returns true if the domains don't match (skip verification), false if they do (perform verification).
//
// NOTE: This assumes single-part TLDs (.com, .me) and won't handle multi-part TLDs correctly.
func (d *BasicTLSDialer) ShouldSkipCertificateChainVerification(address string) bool {
parsedURL, err := url.Parse(d.hostURL)
if err != nil {
return true
}
addressHost, _, err := net.SplitHostPort(address)
if err != nil {
addressHost = address
}
hostDomain := extractDomain(parsedURL.Host)
addressDomain := extractDomain(addressHost)
return addressDomain != hostDomain
}
// DialTLSContext returns a connection to the given address using the given network.
func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
return (&tls.Dialer{
@ -78,7 +114,7 @@ func (d *BasicTLSDialer) DialTLSContext(ctx context.Context, network, address st
Timeout: 30 * time.Second,
},
Config: &tls.Config{
InsecureSkipVerify: address != d.hostURL, //nolint:gosec
InsecureSkipVerify: d.ShouldSkipCertificateChainVerification(address), //nolint:gosec
},
}).DialContext(ctx, network, address)
}

View File

@ -0,0 +1,134 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package dialer
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestBasicTLSDialer_ShouldSkipCertificateChainVerification(t *testing.T) {
tests := []struct {
hostURL string
address string
expected bool
}{
{
hostURL: "https://mail-api.proton.me",
address: "mail-api.proton.me:443",
expected: false,
},
{
hostURL: "https://proton.me",
address: "proton.me",
expected: false,
},
{
hostURL: "https://api.proton.me",
address: "mail.proton.me:443",
expected: false,
},
{
hostURL: "https://proton.me",
address: "mail-api.proton.me:443",
expected: false,
},
{
hostURL: "https://mail-api.proton.me",
address: "proton.me:443",
expected: false,
},
{
hostURL: "https://mail.google.com",
address: "mail-api.proton.me:443",
expected: true,
},
{
hostURL: "https://mail-api.protonmail.com",
address: "mail-api.proton.me:443",
expected: true,
},
{
hostURL: "https://proton.me",
address: "google.com:443",
expected: true,
},
{
hostURL: "https://proton.me",
address: "proton.com:443",
expected: true,
},
{
hostURL: "https://proton.me",
address: "example.me:443",
expected: true,
},
{
hostURL: "https://proton.me",
address: "mail.example.com:443",
expected: true,
},
{
hostURL: "https://proton.me",
address: "proton.me",
expected: false,
},
{
hostURL: "https://proton.me:8080",
address: "proton.me:443",
expected: true,
},
{
hostURL: "https://proton.me/api/v1",
address: "proton.me:443",
expected: false,
},
{
hostURL: "https://proton.black",
address: "mail-api.pascal.proton.black",
expected: false,
},
{
hostURL: "https://mail-api.pascal.proton.black",
address: "mail-api.pascal.proton.black",
expected: false,
},
{
hostURL: "https://mail-api.pascal.proton.black",
address: "proton.black:332",
expected: false,
},
{
hostURL: "https://mail-api.pascal.proton.black",
address: "proton.me",
expected: true,
},
{
hostURL: "https://mail-api.pascal.proton.black",
address: "proton.me:332",
expected: true,
},
}
for _, tt := range tests {
dialer := NewBasicTLSDialer(tt.hostURL)
result := dialer.ShouldSkipCertificateChainVerification(tt.address)
require.Equal(t, tt.expected, result)
}
}

View File

@ -50,12 +50,12 @@ var TrustedAPIPins = []string{ //nolint:gochecknoglobals
}
// TLSReportURI is the address where TLS reports should be sent.
const TLSReportURI = "https://reports.protonmail.ch/reports/tls"
const TLSReportURI = "https://reports.proton.me/reports/tls"
// PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and
// to report errors if the fingerprint check fails.
type PinningTLSDialer struct {
dialer TLSDialer
dialer SecureTLSDialer
pinChecker PinChecker
reporter Reporter
tlsIssueCh chan struct{}
@ -68,13 +68,13 @@ type Reporter interface {
// PinChecker is used to check TLS keys of connections.
type PinChecker interface {
CheckCertificate(conn net.Conn) error
CheckCertificate(conn net.Conn, certificateChainVerificationSkipped bool) error
}
// NewPinningTLSDialer constructs a new dialer which only returns TCP connections to servers
// which present known certificates.
// It checks pins using the given pinChecker and reports issues using the given reporter.
func NewPinningTLSDialer(dialer TLSDialer, reporter Reporter, pinChecker PinChecker) *PinningTLSDialer {
func NewPinningTLSDialer(dialer SecureTLSDialer, reporter Reporter, pinChecker PinChecker) *PinningTLSDialer {
return &PinningTLSDialer{
dialer: dialer,
pinChecker: pinChecker,
@ -85,6 +85,7 @@ func NewPinningTLSDialer(dialer TLSDialer, reporter Reporter, pinChecker PinChec
// DialTLSContext dials the given network/address, returning an error if the certificates don't match the trusted pins.
func (p *PinningTLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
shouldSkipCertificateChainVerification := p.dialer.ShouldSkipCertificateChainVerification(address)
conn, err := p.dialer.DialTLSContext(ctx, network, address)
if err != nil {
return nil, err
@ -95,7 +96,7 @@ func (p *PinningTLSDialer) DialTLSContext(ctx context.Context, network, address
return nil, err
}
if err := p.pinChecker.CheckCertificate(conn); err != nil {
if err := p.pinChecker.CheckCertificate(conn, shouldSkipCertificateChainVerification); err != nil {
if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil {
p.reporter.ReportCertIssue(TLSReportURI, host, port, tlsConn.ConnectionState())
}

View File

@ -41,3 +41,15 @@ func NewTLSPinChecker(trustedPins []string) *TLSPinChecker {
func certFingerprint(cert *x509.Certificate) string {
return fmt.Sprintf(`pin-sha256=%q`, algo.HashBase64SHA256(string(cert.RawSubjectPublicKeyInfo)))
}
func (p *TLSPinChecker) isCertFoundInKnownPins(cert *x509.Certificate) bool {
fingerprint := certFingerprint(cert)
for _, pin := range p.trustedPins {
if pin == fingerprint {
return true
}
}
return false
}

View File

@ -25,8 +25,8 @@ import (
"net"
)
// CheckCertificate returns whether the connection presents a known TLS certificate.
func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error {
// CheckCertificate verifies that the connection presents a known pinned leaf TLS certificate.
func (p *TLSPinChecker) CheckCertificate(conn net.Conn, certificateChainVerificationSkipped bool) error {
tlsConn, ok := conn.(*tls.Conn)
if !ok {
return errors.New("connection is not a TLS connection")
@ -34,14 +34,31 @@ func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error {
connState := tlsConn.ConnectionState()
for _, peerCert := range connState.PeerCertificates {
fingerprint := certFingerprint(peerCert)
// When certificate chain verification is enabled (e.g., for known API hosts), we expect the TLS handshake to produce verified chains.
// We then validate that the leaf certificate of at least one verified chain matches a known pinned public key.
if !certificateChainVerificationSkipped {
if len(connState.VerifiedChains) == 0 {
return errors.New("no verified certificate chains")
}
for _, pin := range p.trustedPins {
if pin == fingerprint {
for _, chain := range connState.VerifiedChains {
// Check if the leaf certificate is one of the trusted pins.
if p.isCertFoundInKnownPins(chain[0]) {
return nil
}
}
return ErrTLSMismatch
}
// When certificate chain verification is skipped (e.g., for DoH proxies using self-signed certs),
// we only validate the leaf certificate against known pinned public keys.
if len(connState.PeerCertificates) == 0 {
return errors.New("no peer certificates available")
}
if p.isCertFoundInKnownPins(connState.PeerCertificates[0]) {
return nil
}
return ErrTLSMismatch

View File

@ -23,6 +23,6 @@ import "net"
// CheckCertificate returns whether the connection presents a known TLS certificate.
// The QA implementation always returns nil.
func (p *TLSPinChecker) CheckCertificate(conn net.Conn) error {
func (p *TLSPinChecker) CheckCertificate(conn net.Conn, _ bool) error {
return nil
}

View File

@ -64,8 +64,7 @@ func TestTLSPinInvalid(t *testing.T) {
checkTLSIssueHandler(t, 1, called)
}
// Disabled for now we'll need to patch this up.
func _TestTLSPinNoMatch(t *testing.T) { //nolint:unused
func TestTLSPinNoMatch(t *testing.T) {
skipIfProxyIsSet(t)
called, _, reporter, checker, cm := createClientWithPinningDialer(getRootURL())
@ -91,13 +90,12 @@ func TestTLSSignedCertWrongPublicKey(t *testing.T) {
r.Error(t, err, "expected dial to fail because of wrong public key")
}
// GODT-2293 bump badssl cert and re enable this.
func _TestTLSSignedCertTrustedPublicKey(t *testing.T) { //nolint:unused,deadcode
func TestTLSSignedCertTrustedPublicKey(t *testing.T) {
skipIfProxyIsSet(t)
_, dialer, _, checker, _ := createClientWithPinningDialer("")
copyTrustedPins(checker)
checker.trustedPins = append(checker.trustedPins, `pin-sha256="LwnIKjNLV3z243ap8y0yXNPghsqE76J08Eq3COvUt2E="`)
checker.trustedPins = append(checker.trustedPins, `pin-sha256="FlvTPG/nIMKtOj9nelnEjujwSZ5EDyfiKYxZgbXREls="`)
_, err := dialer.DialTLSContext(context.Background(), "tcp", "rsa4096.badssl.com:443")
r.NoError(t, err, "expected dial to succeed because public key is known and cert is signed by CA")
}

View File

@ -24,14 +24,35 @@ import (
)
// UpdateLatest is published when the latest version of bridge is known.
// It is only used for updating the release notes and landing page URLs.
type UpdateLatest struct {
eventBase
Version updater.VersionInfo
// VersionLegacy - holds Update version information; corresponding to the old update structure and logic;
VersionLegacy updater.VersionInfoLegacy
// Release - holds Release version data; part of the new update logic as of BRIDGE-309.
Release updater.Release
}
func (event UpdateLatest) GetLatestVersion() string {
var latestVersion string
if !event.VersionLegacy.IsEmpty() {
latestVersion = event.VersionLegacy.Version.String()
} else if !event.Release.IsEmpty() {
latestVersion = event.Release.Version.String()
}
return latestVersion
}
func (event UpdateLatest) String() string {
return fmt.Sprintf("UpdateLatest: Version: %s", event.Version.Version)
if !event.VersionLegacy.IsEmpty() {
return fmt.Sprintf("UpdateLatest: Version: %s", event.VersionLegacy.Version)
}
if !event.Release.IsEmpty() {
return fmt.Sprintf("UpdateLatest: Version: %s", event.Release.Version)
}
return ""
}
// UpdateAvailable is published when an update is available.
@ -40,7 +61,11 @@ func (event UpdateLatest) String() string {
type UpdateAvailable struct {
eventBase
Version updater.VersionInfo
// VersionLegacy - holds Update version information; corresponding to the old update structure and logic;
VersionLegacy updater.VersionInfoLegacy
// Release - holds Release version data; part of the new update logic as of BRIDGE-309.
Release updater.Release
// Compatible is true if the update can be installed automatically.
Compatible bool
@ -49,8 +74,23 @@ type UpdateAvailable struct {
Silent bool
}
func (event UpdateAvailable) GetLatestVersion() string {
var latestVersion string
if !event.VersionLegacy.IsEmpty() {
latestVersion = event.VersionLegacy.Version.String()
} else if !event.Release.IsEmpty() {
latestVersion = event.Release.Version.String()
}
return latestVersion
}
func (event UpdateAvailable) String() string {
return fmt.Sprintf("UpdateAvailable: Version %s, Compatible: %t, Silent: %t", event.Version.Version, event.Compatible, event.Silent)
if !event.Release.IsEmpty() {
return fmt.Sprintf("UpdateAvailable: Version %s, Compatible: %t, Silent: %t", event.Release.Version, event.Compatible, event.Silent)
} else if !event.VersionLegacy.IsEmpty() {
return fmt.Sprintf("UpdateAvailable: Version %s, Compatible: %t, Silent: %t", event.VersionLegacy.Version, event.Compatible, event.Silent)
}
return ""
}
// UpdateNotAvailable is published when no update is available.
@ -62,45 +102,70 @@ func (event UpdateNotAvailable) String() string {
return "UpdateNotAvailable"
}
// UpdateInstalling is published when bridge begins installing an update.
type UpdateInstalling struct {
eventBase
Version updater.VersionInfo
Silent bool
}
func (event UpdateInstalling) String() string {
return fmt.Sprintf("UpdateInstalling: Version %s, Silent: %t", event.Version.Version, event.Silent)
}
// UpdateInstalled is published when an update has been installed.
type UpdateInstalled struct {
eventBase
Version updater.VersionInfo
// VersionLegacy - holds Update version information; corresponding to the old update structure and logic;
VersionLegacy updater.VersionInfoLegacy
// Release - holds Release version data; part of the new update logic as of BRIDGE-309.
Release updater.Release
Silent bool
}
func (event UpdateInstalled) GetLatestVersion() string {
var latestVersion string
if !event.VersionLegacy.IsEmpty() {
latestVersion = event.VersionLegacy.Version.String()
} else if !event.Release.IsEmpty() {
latestVersion = event.Release.Version.String()
}
return latestVersion
}
func (event UpdateInstalled) String() string {
return fmt.Sprintf("UpdateInstalled: Version %s, Silent: %t", event.Version.Version, event.Silent)
if !event.Release.IsEmpty() {
return fmt.Sprintf("UpdateInstalled: Version %s, Silent: %t", event.Release.Version, event.Silent)
} else if !event.VersionLegacy.IsEmpty() {
return fmt.Sprintf("UpdateInstalled: Version %s, Silent: %t", event.VersionLegacy.Version, event.Silent)
}
return ""
}
// UpdateFailed is published when an update fails to be installed.
type UpdateFailed struct {
eventBase
Version updater.VersionInfo
// VersionLegacy - holds Update version information; corresponding to the old update structure and logic;
VersionLegacy updater.VersionInfoLegacy
// Release - holds Release version data; part of the new update logic as of BRIDGE-309.
Release updater.Release
Silent bool
Error error
}
func (event UpdateFailed) GetLatestVersion() string {
var latestVersion string
if !event.VersionLegacy.IsEmpty() {
latestVersion = event.VersionLegacy.Version.String()
} else if !event.Release.IsEmpty() {
latestVersion = event.Release.Version.String()
}
return latestVersion
}
func (event UpdateFailed) String() string {
return fmt.Sprintf("UpdateFailed: Version %s, Silent: %t, Error: %s", event.Version.Version, event.Silent, event.Error)
if !event.Release.IsEmpty() {
return fmt.Sprintf("UpdateFailed: Version %s, Silent: %t, Error: %s", event.Release.Version, event.Silent, event.Error)
} else if !event.VersionLegacy.IsEmpty() {
return fmt.Sprintf("UpdateFailed: Version %s, Silent: %t, Error: %s", event.VersionLegacy.Version, event.Silent, event.Error)
}
return ""
}
// UpdateForced is published when the bridge version is too old and must be updated.

View File

@ -29,7 +29,7 @@ using namespace bridgepp;
//****************************************************************************************************************************************************
BridgeApp::BridgeApp(int &argc, char **argv)
: QApplication(argc, argv) {
setAttribute(Qt::AA_DontShowIconsInMenus, false);
}

View File

@ -24,15 +24,33 @@ cmake_minimum_required(VERSION 3.22)
install(SCRIPT ${deploy_script})
# QML
install(DIRECTORY "${QT_DIR}/qml/Qt"
DESTINATION "${CMAKE_INSTALL_PREFIX}/bridge-gui.app/Contents/MacOS")
install(DIRECTORY "${QT_DIR}/qml/Qt/labs/platform"
DESTINATION "${CMAKE_INSTALL_PREFIX}/bridge-gui.app/Contents/MacOS/Qt/labs")
install(DIRECTORY "${QT_DIR}/qml/QtQml"
DESTINATION "${CMAKE_INSTALL_PREFIX}/bridge-gui.app/Contents/MacOS")
install(DIRECTORY "${QT_DIR}/qml/QtQuick"
DESTINATION "${CMAKE_INSTALL_PREFIX}/bridge-gui.app/Contents/MacOS")
DESTINATION "${CMAKE_INSTALL_PREFIX}/bridge-gui.app/Contents/MacOS"
PATTERN "VirtualKeyboard" EXCLUDE
PATTERN "Effects" EXCLUDE
PATTERN "LocalStorage" EXCLUDE
PATTERN "NativeStyle" EXCLUDE
PATTERN "Particles" EXCLUDE
PATTERN "Scene2D" EXCLUDE
PATTERN "Scene3D" EXCLUDE
PATTERN "Shapes" EXCLUDE
PATTERN "Timeline" EXCLUDE
PATTERN "VectorImage" EXCLUDE
PATTERN "Controls/FluentWinUI3" EXCLUDE
PATTERN "Controls/designer" EXCLUDE
PATTERN "Controls/Fusion" EXCLUDE
PATTERN "Controls/Imagine" EXCLUDE
PATTERN "Controls/Material" EXCLUDE
PATTERN "Controls/Universal" EXCLUDE
PATTERN "Controls/iOS" EXCLUDE
PATTERN "Controls/macOS" EXCLUDE)
# FRAMEWORKS
install(DIRECTORY "${QT_DIR}/lib/QtQmlWorkerScript.framework"
DESTINATION "${CMAKE_INSTALL_PREFIX}/bridge-gui.app/Contents/Frameworks")
install(DIRECTORY "${QT_DIR}/lib/QtQuickControls2Impl.framework"
DESTINATION "${CMAKE_INSTALL_PREFIX}/bridge-gui.app/Contents/Frameworks")
install(DIRECTORY "${QT_DIR}/lib/QtQuickLayouts.framework"
@ -43,6 +61,14 @@ install(DIRECTORY "${QT_DIR}/lib/QtQuickDialogs2QuickImpl.framework"
DESTINATION "${CMAKE_INSTALL_PREFIX}/bridge-gui.app/Contents/Frameworks")
install(DIRECTORY "${QT_DIR}/lib/QtQuickDialogs2Utils.framework"
DESTINATION "${CMAKE_INSTALL_PREFIX}/bridge-gui.app/Contents/Frameworks")
# ADDITIONAL FRAMEWORKS FOR Qt 6.8
install(DIRECTORY "${QT_DIR}/lib/QtQuickControls2Basic.framework"
DESTINATION "${CMAKE_INSTALL_PREFIX}/bridge-gui.app/Contents/Frameworks")
install(DIRECTORY "${QT_DIR}/lib/QtLabsPlatform.framework"
DESTINATION "${CMAKE_INSTALL_PREFIX}/bridge-gui.app/Contents/Frameworks")
install(DIRECTORY "${QT_DIR}/lib/QtQuickControls2BasicStyleImpl.framework"
DESTINATION "${CMAKE_INSTALL_PREFIX}/bridge-gui.app/Contents/Frameworks")
# PLUGINS
install(FILES "${QT_DIR}/plugins/imageformats/libqsvg.dylib"
DESTINATION "${CMAKE_INSTALL_PREFIX}/bridge-gui.app/Contents/PlugIns/imageformats")

View File

@ -54,9 +54,9 @@ AppendQt6Lib("libQt6Gui.so.6")
AppendQt6Lib("libQt6Core.so.6")
AppendQt6Lib("libQt6QuickTemplates2.so.6")
AppendQt6Lib("libQt6DBus.so.6")
AppendQt6Lib("libicui18n.so.56")
AppendQt6Lib("libicuuc.so.56")
AppendQt6Lib("libicudata.so.56")
AppendQt6Lib("libicui18n.so.73")
AppendQt6Lib("libicuuc.so.73")
AppendQt6Lib("libicudata.so.73")
AppendQt6Lib("libQt6XcbQpa.so.6")
AppendQt6Lib("libQt6WaylandClient.so.6")
AppendQt6Lib("libQt6WlShellIntegration.so.6")
@ -68,6 +68,10 @@ AppendQt6Lib("libQt6PrintSupport.so.6")
AppendQt6Lib("libQt6Xml.so.6")
AppendQt6Lib("libQt6OpenGLWidgets.so.6")
AppendQt6Lib("libQt6QuickWidgets.so.6")
AppendQt6Lib("libQt6QmlMeta.so.6")
AppendQt6Lib("libQt6LabsPlatform.so.6")
AppendQt6Lib("libQt6QuickControls2Basic.so.6")
AppendQt6Lib("libQt6QuickControls2BasicStyleImpl.so.6")
# QML dependencies
AppendQt6Lib("libQt6QmlWorkerScript.so.6")

View File

@ -57,20 +57,36 @@ AppendVCPKGLib("re2.dll")
AppendVCPKGLib("sentry.dll")
AppendVCPKGLib("zlib1.dll")
# QML DLLs
AppendQt6Lib("Qt6QmlWorkerScript.dll")
AppendQt6Lib("Qt6Widgets.dll")
AppendQt6Lib("Qt6QuickControls2Impl.dll")
AppendQt6Lib("Qt6QuickLayouts.dll")
AppendQt6Lib("Qt6QuickDialogs2.dll")
AppendQt6Lib("Qt6QuickDialogs2QuickImpl.dll")
AppendQt6Lib("Qt6QuickDialogs2Utils.dll")
AppendQt6Lib("Qt6LabsPlatform.dll")
AppendQt6Lib("Qt6QuickControls2.dll")
AppendQt6Lib("Qt6QuickControls2Basic.dll")
install(FILES ${DEPLOY_LIBS} DESTINATION "${CMAKE_INSTALL_PREFIX}")
# QML PlugIns
install(DIRECTORY ${QT_DIR}/qml/Qt/labs/platform DESTINATION "${CMAKE_INSTALL_PREFIX}/Qt/labs/")
install(DIRECTORY ${QT_DIR}/qml/QtQml DESTINATION "${CMAKE_INSTALL_PREFIX}")
install(DIRECTORY ${QT_DIR}/qml/QtQuick DESTINATION "${CMAKE_INSTALL_PREFIX}")
install(DIRECTORY ${QT_DIR}/qml/QtQuick DESTINATION "${CMAKE_INSTALL_PREFIX}"
PATTERN "Effects" EXCLUDE
PATTERN "LocalStorage" EXCLUDE
PATTERN "NativeStyle" EXCLUDE
PATTERN "Particles" EXCLUDE
PATTERN "Shapes" EXCLUDE
PATTERN "VectorImage" EXCLUDE
PATTERN "Controls/designer" EXCLUDE
PATTERN "Controls/FluentWinUI3" EXCLUDE
PATTERN "Controls/Fusion" EXCLUDE
PATTERN "Controls/Imagine" EXCLUDE
PATTERN "Controls/Material" EXCLUDE
PATTERN "Controls/Universal" EXCLUDE
PATTERN "Controls/Windows" EXCLUDE)
# crash handler utils
install(PROGRAMS "${VCPKG_INSTALLED_DIR}/${VCPKG_TARGET_TRIPLET}/tools/sentry-native/crashpad_handler.exe" DESTINATION "${CMAKE_INSTALL_PREFIX}")

View File

@ -58,9 +58,9 @@ Item {
}
ColorImage {
color: root.colorScheme.text_norm
height: root.colorScheme.body_font_size
height: ProtonStyle.body_font_size
source: "/qml/icons/ic-copy.svg"
sourceSize.height: root.colorScheme.body_font_size
sourceSize.height: ProtonStyle.body_font_size
MouseArea {
anchors.fill: parent

View File

@ -86,9 +86,9 @@ SettingsView {
ColorImage {
Layout.alignment: Qt.AlignCenter
color: root.colorScheme.interaction_norm
height: root.colorScheme.body_font_size
height: ProtonStyle.body_font_size
source: root._isAdvancedShown ? "/qml/icons/ic-chevron-down.svg" : "/qml/icons/ic-chevron-right.svg"
sourceSize.height: root.colorScheme.body_font_size
sourceSize.height: ProtonStyle.body_font_size
MouseArea {
anchors.fill: parent

View File

@ -72,9 +72,9 @@ Item {
ColorImage {
anchors.centerIn: parent
color: root.colorScheme.background_norm
height: root.colorScheme.body_font_size
height: ProtonStyle.body_font_size
source: "/qml/icons/ic-check.svg"
sourceSize.height: root.colorScheme.body_font_size
sourceSize.height: ProtonStyle.body_font_size
visible: root.checked
}
}
@ -82,9 +82,9 @@ Item {
id: loader
anchors.centerIn: parent
color: root.colorScheme.text_norm
height: root.colorScheme.body_font_size
height: ProtonStyle.body_font_size
source: "/qml/icons/Loader_16.svg"
sourceSize.height: root.colorScheme.body_font_size
sourceSize.height: ProtonStyle.body_font_size
visible: root.loading
RotationAnimation {

View File

@ -271,7 +271,10 @@ FocusScope {
usernameTextField.enabled = false;
passwordTextField.enabled = false;
loading = true;
Backend.login(usernameTextField.text, Qt.btoa(passwordTextField.text));
let usernameTextFiltered = usernameTextField.text.replace(/[\n\r]+$/, "");
let passwordTextFiltered = passwordTextField.text.replace(/[\n\r]+$/, "");
Backend.login(usernameTextFiltered, Qt.btoa(passwordTextFiltered));
}
Layout.fillWidth: true

View File

@ -482,16 +482,16 @@ func (f *frontendCLI) watchEvents(eventCh <-chan events.Event) { // nolint:gocyc
case events.UpdateAvailable:
if !event.Compatible {
f.Printf("A new version (%v) is available but it cannot be installed automatically.\n", event.Version.Version)
f.Printf("A new version (%v) is available but it cannot be installed automatically.\n", event.GetLatestVersion())
} else if !event.Silent {
f.Printf("A new version (%v) is available.\n", event.Version.Version)
f.Printf("A new version (%v) is available.\n", event.GetLatestVersion())
}
case events.UpdateInstalled:
f.Printf("A new version (%v) was installed.\n", event.Version.Version)
f.Printf("A new version (%v) was installed.\n", event.GetLatestVersion())
case events.UpdateFailed:
f.Printf("A new version (%v) failed to be installed (%v).\n", event.Version.Version, event.Error)
f.Printf("A new version (%v) failed to be installed (%v).\n", event.GetLatestVersion(), event.Error)
case events.UpdateForced:
f.notifyNeedUpgrade()

View File

@ -78,11 +78,13 @@ type Service struct { // nolint:structcheck
eventCh <-chan events.Event
quitCh <-chan struct{}
latest updater.VersionInfo
latestLock safe.RWMutex
latestLegacy updater.VersionInfoLegacy
latest updater.Release
latestLock safe.RWMutex
target updater.VersionInfo
targetLock safe.RWMutex
targetLegacy updater.VersionInfoLegacy
target updater.Release
targetLock safe.RWMutex
authClient *proton.Client
auth proton.Auth
@ -168,11 +170,13 @@ func NewService(
eventCh: eventCh,
quitCh: quitCh,
latest: updater.VersionInfo{},
latestLock: safe.NewRWMutex(),
latestLegacy: updater.VersionInfoLegacy{},
latest: updater.Release{},
latestLock: safe.NewRWMutex(),
target: updater.VersionInfo{},
targetLock: safe.NewRWMutex(),
targetLegacy: updater.VersionInfoLegacy{},
target: updater.Release{},
targetLock: safe.NewRWMutex(),
log: logrus.WithField("pkg", "grpc"),
initializing: sync.WaitGroup{},
@ -354,10 +358,11 @@ func (s *Service) watchEvents() {
case events.UpdateLatest:
safe.RLock(func() {
s.latest = event.Version
s.latestLegacy = event.VersionLegacy
s.latest = event.Release
}, s.latestLock)
_ = s.SendEvent(NewUpdateVersionChangedEvent())
_ = s.SendEvent(NewUpdateVersionChangedEvent()) // This updates the release notes page and landing page.
case events.UpdateAvailable:
switch {
@ -366,10 +371,11 @@ func (s *Service) watchEvents() {
case !event.Silent:
safe.RLock(func() {
s.target = event.Version
s.targetLegacy = event.VersionLegacy
s.target = event.Release
}, s.targetLock)
_ = s.SendEvent(NewUpdateManualReadyEvent(event.Version.Version.String()))
_ = s.SendEvent(NewUpdateManualReadyEvent(event.GetLatestVersion()))
}
case events.UpdateInstalled:
@ -391,8 +397,10 @@ func (s *Service) watchEvents() {
if s.latest.Version != nil {
latest = s.latest.Version.String()
} else if version, ok := s.checkLatestVersion(); ok {
latest = version.Version.String()
} else if s.latestLegacy.Version != nil {
latest = s.latestLegacy.Version.String()
} else if latestVersion, ok := s.checkLatestVersion(); ok {
latest = latestVersion
} else {
latest = "unknown"
}
@ -517,7 +525,7 @@ func (s *Service) triggerReset() {
s.bridge.FactoryReset(context.Background())
}
func (s *Service) checkLatestVersion() (updater.VersionInfo, bool) {
func (s *Service) checkLatestVersion() (string, bool) {
updateCh, done := s.bridge.GetEvents(events.UpdateLatest{})
defer done()
@ -526,14 +534,13 @@ func (s *Service) checkLatestVersion() (updater.VersionInfo, bool) {
select {
case event := <-updateCh:
if latest, ok := event.(events.UpdateLatest); ok {
return latest.Version, true
return latest.GetLatestVersion(), true
}
case <-time.After(5 * time.Second):
// ...
}
return updater.VersionInfo{}, false
return "", false
}
func newTLSConfig() (*tls.Config, []byte, error) {

View File

@ -298,7 +298,14 @@ func (s *Service) ReleaseNotesPageLink(_ context.Context, _ *emptypb.Empty) (*wr
s.latestLock.RUnlock()
}()
return wrapperspb.String(s.latest.ReleaseNotesPage), nil
var releaseNotesPage string
if !s.latestLegacy.IsEmpty() {
releaseNotesPage = s.latestLegacy.ReleaseNotesPage
} else if !s.latest.IsEmpty() {
releaseNotesPage = s.latest.ReleaseNotesPage
}
return wrapperspb.String(releaseNotesPage), nil
}
func (s *Service) LandingPageLink(_ context.Context, _ *emptypb.Empty) (*wrapperspb.StringValue, error) {
@ -308,7 +315,14 @@ func (s *Service) LandingPageLink(_ context.Context, _ *emptypb.Empty) (*wrapper
s.latestLock.RUnlock()
}()
return wrapperspb.String(s.latest.LandingPage), nil
var landingPage string
if !s.latestLegacy.IsEmpty() {
landingPage = s.latestLegacy.LandingPage
} else if !s.latest.IsEmpty() {
landingPage = s.latest.LandingPage
}
return wrapperspb.String(landingPage), nil
}
func (s *Service) SetColorSchemeName(_ context.Context, name *wrapperspb.StringValue) (*emptypb.Empty, error) {
@ -617,7 +631,11 @@ func (s *Service) InstallUpdate(_ context.Context, _ *emptypb.Empty) (*emptypb.E
defer async.HandlePanic(s.panicHandler)
safe.RLock(func() {
s.bridge.InstallUpdate(s.target)
if !s.targetLegacy.IsEmpty() {
s.bridge.InstallUpdateLegacy(s.targetLegacy)
} else if !s.target.IsEmpty() {
s.bridge.InstallUpdate(s.target)
}
}, s.targetLock)
}()

View File

@ -212,7 +212,7 @@ func buildSessionInfoList(dir string) (map[SessionID]*sessionInfo, error) {
}
rx := regexp.MustCompile(`^(\d{8}_\d{9})_.*\.log$`)
match := rx.FindStringSubmatch(entry.Name())
if match == nil || len(match) < 2 {
if len(match) < 2 {
continue
}

View File

@ -157,7 +157,7 @@ func (r *Reporter) ReportExceptionWithContext(i interface{}, context map[string]
SkipDuringUnwind()
err := fmt.Errorf("recover: %v", i)
return r.scopedReport(context, func() {
return r.scopedReport(context, func(_ *sentry.Scope) {
SkipDuringUnwind()
if eventID := sentry.CaptureException(err); eventID != nil {
logrus.WithError(err).
@ -169,7 +169,20 @@ func (r *Reporter) ReportExceptionWithContext(i interface{}, context map[string]
func (r *Reporter) ReportMessageWithContext(msg string, context map[string]interface{}) error {
SkipDuringUnwind()
return r.scopedReport(context, func() {
return r.scopedReport(context, func(_ *sentry.Scope) {
SkipDuringUnwind()
if eventID := sentry.CaptureMessage(msg); eventID != nil {
logrus.WithField("message", msg).
WithField("reportID", *eventID).
Warn("Captured message")
}
})
}
func (r *Reporter) ReportWarningWithContext(msg string, context map[string]interface{}) error {
SkipDuringUnwind()
return r.scopedReport(context, func(scope *sentry.Scope) {
scope.SetLevel(sentry.LevelWarning)
SkipDuringUnwind()
if eventID := sentry.CaptureMessage(msg); eventID != nil {
logrus.WithField("message", msg).
@ -180,7 +193,7 @@ func (r *Reporter) ReportMessageWithContext(msg string, context map[string]inter
}
// Report reports a sentry crash with stacktrace from all goroutines.
func (r *Reporter) scopedReport(context map[string]interface{}, doReport func()) error {
func (r *Reporter) scopedReport(context map[string]interface{}, doReport func(scope *sentry.Scope)) error {
SkipDuringUnwind()
if os.Getenv("PROTONMAIL_ENV") == "dev" {
@ -206,7 +219,7 @@ func (r *Reporter) scopedReport(context map[string]interface{}, doReport func())
map[string]sentry.Context{"bridge": contextToString(context)},
)
}
doReport()
doReport(scope)
})
if !sentry.Flush(time.Second * 10) {
@ -287,3 +300,25 @@ func contextToString(context sentry.Context) sentry.Context {
return res
}
type NullSentryReporter struct{}
func (n NullSentryReporter) ReportException(any) error {
return nil
}
func (n NullSentryReporter) ReportMessage(string) error {
return nil
}
func (n NullSentryReporter) ReportMessageWithContext(string, reporter.Context) error {
return nil
}
func (n NullSentryReporter) ReportWarningWithContext(string, reporter.Context) error {
return nil
}
func (n NullSentryReporter) ReportExceptionWithContext(any, reporter.Context) error {
return nil
}

View File

@ -0,0 +1,388 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imapservice
import (
"context"
"errors"
"fmt"
"strings"
"github.com/ProtonMail/gluon/db"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/unleash"
"github.com/ProtonMail/proton-bridge/v3/pkg/algo"
"github.com/sirupsen/logrus"
)
type GluonLabelNameProvider interface {
GetUserMailboxByName(ctx context.Context, addrID string, labelName []string) (imap.MailboxData, error)
}
type gluonIDProvider interface {
GetGluonID(addrID string) (string, bool)
}
type sentryReporter interface {
ReportMessageWithContext(string, reporter.Context) error
ReportWarningWithContext(string, reporter.Context) error
}
type apiClient interface {
GetLabel(ctx context.Context, labelID string, labelTypes ...proton.LabelType) (proton.Label, error)
}
type mailboxFetcherFn func(ctx context.Context, label proton.Label) (imap.MailboxData, error)
type mailboxMessageCountFetcherFn func(ctx context.Context, internalMailboxID imap.InternalMailboxID) (int, error)
type LabelConflictManager struct {
gluonLabelNameProvider GluonLabelNameProvider
gluonIDProvider gluonIDProvider
client apiClient
reporter sentryReporter
featureFlagProvider unleash.FeatureFlagValueProvider
}
func NewLabelConflictManager(
gluonLabelNameProvider GluonLabelNameProvider,
gluonIDProvider gluonIDProvider,
client apiClient,
reporter sentryReporter,
featureFlagProvider unleash.FeatureFlagValueProvider) *LabelConflictManager {
return &LabelConflictManager{
gluonLabelNameProvider: gluonLabelNameProvider,
gluonIDProvider: gluonIDProvider,
client: client,
reporter: reporter,
featureFlagProvider: featureFlagProvider,
}
}
func (m *LabelConflictManager) generateMailboxFetcher(connectors []*Connector) mailboxFetcherFn {
return func(ctx context.Context, label proton.Label) (imap.MailboxData, error) {
for _, updateCh := range connectors {
addrID, ok := m.gluonIDProvider.GetGluonID(updateCh.addrID)
if !ok {
continue
}
return m.gluonLabelNameProvider.GetUserMailboxByName(ctx, addrID, GetMailboxName(label))
}
return imap.MailboxData{}, errors.New("no gluon connectors found")
}
}
func (m *LabelConflictManager) generateMailboxMessageCountFetcher(connectors []*Connector) mailboxMessageCountFetcherFn {
return func(ctx context.Context, id imap.InternalMailboxID) (int, error) {
var countSum int
var errs []error
for _, conn := range connectors {
count, err := conn.GetMailboxMessageCount(ctx, id)
countSum += count
errs = append(errs, err)
}
return countSum, errors.Join(errs...)
}
}
type LabelConflictResolver interface {
ResolveConflict(ctx context.Context, label proton.Label, visited map[string]bool) (func() []imap.Update, error)
}
type labelConflictResolverImpl struct {
mailboxFetch mailboxFetcherFn
client apiClient
reporter sentryReporter
log *logrus.Entry
}
type nullLabelConflictResolverImpl struct {
}
func (r *nullLabelConflictResolverImpl) ResolveConflict(_ context.Context, _ proton.Label, _ map[string]bool) (func() []imap.Update, error) {
return func() []imap.Update {
return []imap.Update{}
}, nil
}
func (m *LabelConflictManager) NewConflictResolver(connectors []*Connector) LabelConflictResolver {
if m.featureFlagProvider.GetFlagValue(unleash.LabelConflictResolverDisabled) {
return &nullLabelConflictResolverImpl{}
}
return &labelConflictResolverImpl{
mailboxFetch: m.generateMailboxFetcher(connectors),
client: m.client,
reporter: m.reporter,
log: logrus.WithFields(logrus.Fields{
"pkg": "imapservice/labelConflictResolver",
"numberOfConnectors": len(connectors),
}),
}
}
func (r *labelConflictResolverImpl) ResolveConflict(ctx context.Context, label proton.Label, visited map[string]bool) (func() []imap.Update, error) {
logger := r.log.WithFields(logrus.Fields{
"labelID": label.ID,
"labelPath": hashLabelPaths(GetMailboxName(label)),
})
// For system type labels we shouldn't care.
var updateFns []func() []imap.Update
// There's a cycle, such as in a label swap operation, we'll need to temporarily rename the label.
// The change will be overwritten by one of the previous recursive calls.
if visited[label.ID] {
logrus.Info("Cycle detected, applying temporary rename")
fn := func() []imap.Update {
return []imap.Update{newMailboxUpdatedOrCreated(imap.MailboxID(label.ID), getMailboxNameWithTempPrefix(label))}
}
updateFns = append(updateFns, fn)
return combineIMAPUpdateFns(updateFns), nil
}
visited[label.ID] = true
// Fetch the gluon mailbox data and verify whether there are conflicts with the name.
mailboxData, err := r.mailboxFetch(ctx, label)
if err != nil {
// Name is free, create the mailbox.
if db.IsErrNotFound(err) {
logger.Info("Label not found in DB, creating mailbox.")
fn := func() []imap.Update {
return []imap.Update{newMailboxUpdatedOrCreated(imap.MailboxID(label.ID), GetMailboxName(label))}
}
updateFns = append(updateFns, fn)
return combineIMAPUpdateFns(updateFns), nil
}
return combineIMAPUpdateFns(updateFns), err
}
// Verify whether the label name corresponds to the same label ID. If true terminate, we don't need to update.
if mailboxData.RemoteID == label.ID {
logger.Info("Mailbox name matches label ID, no conflict.")
return combineIMAPUpdateFns(updateFns), nil
}
// This means we've found a conflict. So let's log it.
logger = logger.WithFields(logrus.Fields{
"conflictingLabelID": mailboxData.RemoteID,
"conflictingLabelPath": hashLabelPaths(mailboxData.BridgeName),
})
logger.Info("Label conflict found")
// If the label name belongs to some other label ID. Fetch it's state from the remote.
conflictingLabel, err := r.client.GetLabel(ctx, mailboxData.RemoteID, proton.LabelTypeFolder, proton.LabelTypeLabel, proton.LabelTypeSystem)
if err != nil {
// If it's not present on the remote we should delete it. And create the new label.
if errors.Is(err, proton.ErrNoSuchLabel) {
logger.Info("Conflicting label does not exist on remote. Deleting.")
fn := func() []imap.Update {
return []imap.Update{
imap.NewMailboxDeleted(imap.MailboxID(mailboxData.RemoteID)), // Should this be with remote ID
newMailboxUpdatedOrCreated(imap.MailboxID(label.ID), GetMailboxName(label)),
}
}
updateFns = append(updateFns, fn)
return combineIMAPUpdateFns(updateFns), nil
}
logger.WithError(err).Error("Failed to fetch conflicting label from remote.")
return combineIMAPUpdateFns(updateFns), err
}
// Check if the conflicting label name has changed. If not, then this is a BE inconsistency.
if compareLabelNames(GetMailboxName(conflictingLabel), mailboxData.BridgeName) {
if err := r.reporter.ReportMessageWithContext("Unexpected label conflict", reporter.Context{
"labelID": label.ID,
"conflictingLabelID": conflictingLabel.ID,
}); err != nil {
logger.WithError(err).Error("Failed to report update error")
}
err := fmt.Errorf("unexpected label conflict: the name of label ID %s is already used by label ID %s", label.ID, conflictingLabel.ID)
return combineIMAPUpdateFns(updateFns), err
}
// The name of the conflicting label has changed on the remote. We need to verify that the new name does not conflict with anything else.
// Thus, a recursive check can be performed.
logger.WithField("conflictingLabelNewPath", hashLabelPaths(conflictingLabel.Path)).
Info("Conflicting label name has changed. Recursively resolving conflict.")
childUpdateFns, err := r.ResolveConflict(ctx, conflictingLabel, visited)
if err != nil {
return combineIMAPUpdateFns(updateFns), err
}
updateFns = append(updateFns, childUpdateFns)
fn := func() []imap.Update {
return []imap.Update{newMailboxUpdatedOrCreated(imap.MailboxID(label.ID), GetMailboxName(label))}
}
updateFns = append(updateFns, fn)
return combineIMAPUpdateFns(updateFns), nil
}
func combineIMAPUpdateFns(updateFunctions []func() []imap.Update) func() []imap.Update {
return func() []imap.Update {
var updates []imap.Update
for _, fn := range updateFunctions {
updates = append(updates, fn()...)
}
return updates
}
}
func compareLabelNames(labelName1, labelName2 []string) bool {
name1 := strings.Join(labelName1, "")
name2 := strings.Join(labelName2, "")
return name1 == name2
}
func hashLabelPaths(path []string) string {
return algo.HashBase64SHA256(strings.Join(path, ""))
}
type InternalLabelConflictResolver interface {
ResolveConflict(ctx context.Context, apiLabels map[string]proton.Label) (func() []imap.Update, error)
}
type internalLabelConflictResolverImpl struct {
mailboxFetch mailboxFetcherFn
mailboxMessageCountFetch mailboxMessageCountFetcherFn
userLabelConflictResolver LabelConflictResolver
allowNonEmptyMailboxDeletion bool
client apiClient
reporter sentryReporter
log *logrus.Entry
}
type nullInternalLabelConflictResolver struct{}
func (r *nullInternalLabelConflictResolver) ResolveConflict(_ context.Context, _ map[string]proton.Label) (func() []imap.Update, error) {
return func() []imap.Update { return []imap.Update{} }, nil
}
func (m *LabelConflictManager) NewInternalLabelConflictResolver(connectors []*Connector) InternalLabelConflictResolver {
if m.featureFlagProvider.GetFlagValue(unleash.InternalLabelConflictResolverDisabled) {
return &nullInternalLabelConflictResolver{}
}
return &internalLabelConflictResolverImpl{
mailboxFetch: m.generateMailboxFetcher(connectors),
mailboxMessageCountFetch: m.generateMailboxMessageCountFetcher(connectors),
userLabelConflictResolver: m.NewConflictResolver(connectors),
allowNonEmptyMailboxDeletion: m.featureFlagProvider.GetFlagValue(unleash.ItnternalLabelConflictNonEmptyMailboxDeletion),
client: m.client,
reporter: m.reporter,
log: logrus.WithFields(logrus.Fields{
"pkg": "imapservice/internalLabelConflictResolver",
"numberOfConnectors": len(connectors),
}),
}
}
func (r *internalLabelConflictResolverImpl) ResolveConflict(ctx context.Context, apiLabels map[string]proton.Label) (func() []imap.Update, error) {
updateFns := []func() []imap.Update{}
for _, prefix := range []string{folderPrefix, labelPrefix} {
internalLabel := proton.Label{
Path: []string{prefix},
ID: prefix,
Name: prefix,
}
mbox, err := r.mailboxFetch(ctx, internalLabel)
if err != nil {
if db.IsErrNotFound(err) {
continue
}
return nil, err
}
// If the ID's match then we don't have a discrepancy.
if mbox.RemoteID == internalLabel.ID {
continue
}
logFields := logrus.Fields{
"internalLabelID": internalLabel.ID,
"internalLabelName": internalLabel.Name,
"conflictingLabelID": mbox.RemoteID,
"conflictingLabelName": strings.Join(mbox.BridgeName, "/"),
}
reporterContext := reporter.Context(logFields)
logger := r.log.WithFields(logFields)
logger.Info("Encountered conflict, resolving.")
// There is a discrepancy, let's see if it comes from API.
apiLabel, ok := apiLabels[mbox.RemoteID]
if !ok {
// Label does not come from API, we should delete it.
// Due diligence, check if there are any messages associated with the mailbox.
msgCount, _ := r.mailboxMessageCountFetch(ctx, mbox.InternalID)
if msgCount != 0 {
logger.WithField("conflictingLabelMessageCount", msgCount).Info("Non-API conflicting label has associated messages")
reporterContext["conflictingLabelMessageCount"] = msgCount
if rerr := r.reporter.ReportWarningWithContext("Internal mailbox name conflict. Conflicting non-API label has messages.",
reporterContext); rerr != nil {
logger.WithError(rerr).Error("Failed to send report to sentry")
}
if !r.allowNonEmptyMailboxDeletion {
return combineIMAPUpdateFns(updateFns), fmt.Errorf("internal mailbox conflicting non-api label has associated messages")
}
}
fn := func() []imap.Update {
return []imap.Update{imap.NewMailboxDeletedSilent(imap.MailboxID(mbox.RemoteID))}
}
updateFns = append(updateFns, fn)
continue
}
reporterContext["conflictingLabelType"] = apiLabel.Type
// Label is indeed from API let's see if it's name has changed.
if compareLabelNames(GetMailboxName(apiLabel), internalLabel.Path) {
logger.Error("Conflict, same-name mailbox is returned by API")
if err := r.reporter.ReportMessageWithContext("Internal mailbox name conflict. Same-name mailbox is returned by API", reporterContext); err != nil {
logger.WithError(err).Error("Could not send report to sentry")
}
return combineIMAPUpdateFns(updateFns), fmt.Errorf("API label %s conflicts with internal label %s",
GetMailboxName(apiLabel),
strings.Join(mbox.BridgeName, "/"),
)
}
// If it's name has changed then we ought to rename it while still taking care of potential conflicts.
labelRenameUpdates, err := r.userLabelConflictResolver.ResolveConflict(ctx, apiLabel, make(map[string]bool))
if err != nil {
reporterContext["err"] = err.Error()
if rerr := r.reporter.ReportMessageWithContext("Failed to resolve internal mailbox conflict", reporterContext); rerr != nil {
logger.WithError(rerr).Error("Could not send report to sentry")
}
return combineIMAPUpdateFns(updateFns),
fmt.Errorf("failed to resolve user label conflict for '%s': %w", apiLabel.Name, err)
}
updateFns = append(updateFns, labelRenameUpdates)
}
return combineIMAPUpdateFns(updateFns), nil
}

View File

@ -0,0 +1,961 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imapservice_test
import (
"context"
"errors"
"fmt"
"testing"
"github.com/ProtonMail/gluon/db"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/services/imapservice"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type ffProviderFalse struct{}
type ffProviderTrue struct{}
func (f ffProviderFalse) GetFlagValue(_ string) bool {
return false
}
func (f ffProviderTrue) GetFlagValue(_ string) bool {
return true
}
type mockLabelNameProvider struct {
mock.Mock
}
func (m *mockLabelNameProvider) GetUserMailboxByName(ctx context.Context, addrID string, labelName []string) (imap.MailboxData, error) {
args := m.Called(ctx, addrID, labelName)
v, ok := args.Get(0).(imap.MailboxData)
if !ok {
return imap.MailboxData{}, fmt.Errorf("failed to assert type")
}
return v, args.Error(1)
}
type mockIDProvider struct {
mock.Mock
}
func (m *mockIDProvider) GetGluonID(addrID string) (string, bool) {
args := m.Called(addrID)
return args.String(0), args.Bool(1)
}
type mockAPIClient struct {
mock.Mock
}
func (m *mockAPIClient) GetLabel(ctx context.Context, id string, types ...proton.LabelType) (proton.Label, error) {
args := m.Called(ctx, id, types)
v, ok := args.Get(0).(proton.Label)
if !ok {
return proton.Label{}, fmt.Errorf("failed to assert type")
}
return v, args.Error(1)
}
type mockReporter struct {
mock.Mock
}
func (m *mockReporter) ReportMessageWithContext(msg string, ctx reporter.Context) error {
args := m.Called(msg, ctx)
return args.Error(0)
}
func (m *mockReporter) ReportWarningWithContext(msg string, ctx reporter.Context) error {
args := m.Called(msg, ctx)
return args.Error(0)
}
func TestResolveConflict_UnexpectedLabelConflict(t *testing.T) {
ctx := context.Background()
label := proton.Label{
ID: "label-1",
Path: []string{"Work"},
Type: proton.LabelTypeLabel,
}
conflictingLabel := proton.Label{
ID: "label-2",
Path: []string{"Work"},
Type: proton.LabelTypeLabel,
}
conflictMbox := imap.MailboxData{
RemoteID: "label-2",
BridgeName: []string{"Labels", "Work"},
}
mockLabelProvider := new(mockLabelNameProvider)
mockIDProvider := new(mockIDProvider)
mockClient := new(mockAPIClient)
mockReporter := new(mockReporter)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id", imapservice.GetMailboxName(label)).
Return(conflictMbox, nil)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id", true)
mockClient.On("GetLabel", mock.Anything, "label-2", mock.Anything).
Return(conflictingLabel, nil)
mockReporter.On("ReportMessageWithContext", "Unexpected label conflict", mock.Anything).
Return(nil)
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
resolver := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{}).
NewConflictResolver([]*imapservice.Connector{connector})
visited := make(map[string]bool)
_, err := resolver.ResolveConflict(ctx, label, visited)
assert.Error(t, err)
assert.Contains(t, err.Error(), "unexpected label conflict")
}
func TestResolveDiscrepancy_LabelDoesNotExist(t *testing.T) {
ctx := context.Background()
label := proton.Label{
ID: "label-id-1",
Name: "Inbox",
Type: proton.LabelTypeLabel,
}
mockLabelProvider := new(mockLabelNameProvider)
mockIDProvider := new(mockIDProvider)
mockClient := new(mockAPIClient)
mockReporter := new(mockReporter)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", imapservice.GetMailboxName(label)).
Return(imap.MailboxData{}, db.ErrNotFound)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id-1", true)
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
connectors := []*imapservice.Connector{connector}
manager := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{})
resolver := manager.NewConflictResolver(connectors)
visited := make(map[string]bool)
fn, err := resolver.ResolveConflict(ctx, label, visited)
assert.NoError(t, err)
updates := fn()
assert.Len(t, updates, 1)
muc, ok := updates[0].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(label.ID), muc.Mailbox.ID)
}
func TestResolveConflict_MailboxFetchError(t *testing.T) {
ctx := context.Background()
label := proton.Label{
ID: "111",
Path: []string{"Work"},
Type: proton.LabelTypeLabel,
}
mockLabelProvider := new(mockLabelNameProvider)
mockIDProvider := new(mockIDProvider)
mockClient := new(mockAPIClient)
mockReporter := new(mockReporter)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id", imapservice.GetMailboxName(label)).
Return(imap.MailboxData{}, errors.New("database connection error"))
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id", true)
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
resolver := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{}).
NewConflictResolver([]*imapservice.Connector{connector})
visited := make(map[string]bool)
_, err := resolver.ResolveConflict(ctx, label, visited)
assert.Error(t, err)
assert.Contains(t, err.Error(), "database connection error")
}
func TestResolveDiscrepancy_ConflictingLabelDeletedRemotely(t *testing.T) {
ctx := context.Background()
label := proton.Label{
ID: "label-new",
Path: []string{"Work"},
Type: proton.LabelTypeLabel,
}
conflictMbox := imap.MailboxData{
RemoteID: "label-old",
BridgeName: []string{"Labels", "Work"},
}
mockLabelProvider := new(mockLabelNameProvider)
mockIDProvider := new(mockIDProvider)
mockClient := new(mockAPIClient)
mockReporter := new(mockReporter)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", imapservice.GetMailboxName(label)).
Return(conflictMbox, nil)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id-1", true)
mockClient.On("GetLabel", mock.Anything, "label-old", mock.Anything).
Return(proton.Label{}, proton.ErrNoSuchLabel)
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
connectors := []*imapservice.Connector{connector}
manager := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{})
resolver := manager.NewConflictResolver(connectors)
visited := make(map[string]bool)
fn, err := resolver.ResolveConflict(ctx, label, visited)
assert.NoError(t, err)
updates := fn()
assert.Len(t, updates, 2)
deleted, ok := updates[0].(*imap.MailboxDeleted)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID("label-old"), deleted.MailboxID)
updated, ok := updates[1].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, "Work", updated.Mailbox.Name[len(updated.Mailbox.Name)-1])
}
func TestResolveDiscrepancy_LabelAlreadyCorrect(t *testing.T) {
ctx := context.Background()
label := proton.Label{
ID: "label-id-1",
Name: "Personal",
Type: proton.LabelTypeLabel,
}
mbox := imap.MailboxData{
RemoteID: "label-id-1",
BridgeName: []string{"Labels", "Personal"},
}
mockLabelProvider := new(mockLabelNameProvider)
mockIDProvider := new(mockIDProvider)
mockClient := new(mockAPIClient)
mockReporter := new(mockReporter)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", imapservice.GetMailboxName(label)).
Return(mbox, nil)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id-1", true)
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
connectors := []*imapservice.Connector{connector}
manager := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{})
resolver := manager.NewConflictResolver(connectors)
visited := make(map[string]bool)
fn, err := resolver.ResolveConflict(ctx, label, visited)
assert.NoError(t, err)
assert.Len(t, fn(), 0)
}
func TestResolveConflict_DeepNestedPath(t *testing.T) {
ctx := context.Background()
label := proton.Label{
ID: "111",
Path: []string{"Level1", "Level2", "Level3", "DeepFolder"},
Type: proton.LabelTypeFolder,
}
mockLabelProvider := new(mockLabelNameProvider)
mockIDProvider := new(mockIDProvider)
mockClient := new(mockAPIClient)
mockReporter := new(mockReporter)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id", imapservice.GetMailboxName(label)).
Return(imap.MailboxData{}, db.ErrNotFound)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id", true)
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
resolver := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{}).
NewConflictResolver([]*imapservice.Connector{connector})
visited := make(map[string]bool)
fn, err := resolver.ResolveConflict(ctx, label, visited)
assert.NoError(t, err)
updates := fn()
assert.Len(t, updates, 1)
updated, ok := updates[0].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID("111"), updated.Mailbox.ID)
expectedName := imapservice.GetMailboxName(label)
assert.Equal(t, expectedName, updated.Mailbox.Name)
}
func TestResolveLabelDiscrepancy_LabelSwap(t *testing.T) {
apiLabels := []proton.Label{
{
ID: "111",
Path: []string{"X"},
Type: proton.LabelTypeLabel,
},
{
ID: "222",
Path: []string{"Y"},
Type: proton.LabelTypeLabel,
},
}
gluonLabels := []imap.MailboxData{
{
RemoteID: "111",
BridgeName: []string{"Labels", "Y"},
},
{
RemoteID: "222",
BridgeName: []string{"Labels", "X"},
},
}
mockLabelProvider := new(mockLabelNameProvider)
mockClient := new(mockAPIClient)
mockIDProvider := new(mockIDProvider)
mockReporter := new(mockReporter)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id-1", true)
for _, mbox := range gluonLabels {
mockLabelProvider.
On("GetUserMailboxByName", mock.Anything, "gluon-id-1", mbox.BridgeName).
Return(mbox, nil)
}
for _, label := range apiLabels {
mockClient.
On("GetLabel", mock.Anything, label.ID, []proton.LabelType{proton.LabelTypeFolder, proton.LabelTypeLabel, proton.LabelTypeSystem}).
Return(label, nil)
}
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
connectors := []*imapservice.Connector{connector}
manager := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{})
resolver := manager.NewConflictResolver(connectors)
visited := make(map[string]bool)
fn, err := resolver.ResolveConflict(context.Background(), apiLabels[0], visited)
require.NoError(t, err)
updates := fn()
assert.NotEmpty(t, updates)
assert.Equal(t, 3, len(updates)) // We expect three calls to be made for a swap operation.
updateOne, ok := updates[0].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[0].ID), updateOne.Mailbox.ID)
assert.Equal(t, "tmp_X", updateOne.Mailbox.Name[len(updateOne.Mailbox.Name)-1])
updateTwo, ok := updates[1].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[1].ID), updateTwo.Mailbox.ID)
assert.Equal(t, "Y", updateTwo.Mailbox.Name[len(updateTwo.Mailbox.Name)-1])
updateThree, ok := updates[2].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[0].ID), updateThree.Mailbox.ID)
assert.Equal(t, "X", updateThree.Mailbox.Name[len(updateThree.Mailbox.Name)-1])
}
func TestResolveLabelDiscrepancy_LabelSwapExtended(t *testing.T) {
apiLabels := []proton.Label{
{
ID: "111",
Path: []string{"X"},
Type: proton.LabelTypeLabel,
},
{
ID: "222",
Path: []string{"Y"},
Type: proton.LabelTypeLabel,
},
{
ID: "333",
Path: []string{"Z"},
Type: proton.LabelTypeLabel,
},
{
ID: "444",
Path: []string{"D"},
Type: proton.LabelTypeLabel,
},
}
gluonLabels := []imap.MailboxData{
{
RemoteID: "111",
BridgeName: []string{"Labels", "D"},
},
{
RemoteID: "222",
BridgeName: []string{"Labels", "Z"},
},
{
RemoteID: "333",
BridgeName: []string{"Labels", "Y"},
},
{
RemoteID: "444",
BridgeName: []string{"Labels", "X"},
},
}
mockLabelProvider := new(mockLabelNameProvider)
mockClient := new(mockAPIClient)
mockIDProvider := new(mockIDProvider)
mockReporter := new(mockReporter)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id-1", true)
for _, mbox := range gluonLabels {
mockLabelProvider.
On("GetUserMailboxByName", mock.Anything, "gluon-id-1", mbox.BridgeName).
Return(mbox, nil)
}
for _, label := range apiLabels {
mockClient.
On("GetLabel", mock.Anything, label.ID, []proton.LabelType{proton.LabelTypeFolder, proton.LabelTypeLabel, proton.LabelTypeSystem}).
Return(label, nil)
}
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
connectors := []*imapservice.Connector{connector}
manager := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{})
resolver := manager.NewConflictResolver(connectors)
fn, err := resolver.ResolveConflict(context.Background(), apiLabels[0], make(map[string]bool))
require.NoError(t, err)
updates := fn()
assert.NotEmpty(t, updates)
// Three calls yet again for a swap operation.
assert.Equal(t, 3, len(updates))
updateOne, ok := updates[0].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[0].ID), updateOne.Mailbox.ID)
assert.Equal(t, "tmp_X", updateOne.Mailbox.Name[len(updateOne.Mailbox.Name)-1])
updateTwo, ok := updates[1].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[3].ID), updateTwo.Mailbox.ID)
assert.Equal(t, "D", updateTwo.Mailbox.Name[len(updateTwo.Mailbox.Name)-1])
updateThree, ok := updates[2].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[0].ID), updateThree.Mailbox.ID)
assert.Equal(t, "X", updateThree.Mailbox.Name[len(updateThree.Mailbox.Name)-1])
// Fix the secondary swap.
fn, err = resolver.ResolveConflict(context.Background(), apiLabels[1], make(map[string]bool))
require.NoError(t, err)
updates = fn()
assert.Equal(t, 3, len(updates))
updateOne, ok = updates[0].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[1].ID), updateOne.Mailbox.ID)
assert.Equal(t, "tmp_Y", updateOne.Mailbox.Name[len(updateOne.Mailbox.Name)-1])
updateTwo, ok = updates[1].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[2].ID), updateTwo.Mailbox.ID)
assert.Equal(t, "Z", updateTwo.Mailbox.Name[len(updateTwo.Mailbox.Name)-1])
updateThree, ok = updates[2].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[1].ID), updateThree.Mailbox.ID)
assert.Equal(t, "Y", updateThree.Mailbox.Name[len(updateThree.Mailbox.Name)-1])
}
func TestResolveLabelDiscrepancy_LabelSwapCyclic(t *testing.T) {
apiLabels := []proton.Label{
{ID: "111", Path: []string{"A"}, Type: proton.LabelTypeLabel},
{ID: "222", Path: []string{"B"}, Type: proton.LabelTypeLabel},
{ID: "333", Path: []string{"C"}, Type: proton.LabelTypeLabel},
{ID: "444", Path: []string{"D"}, Type: proton.LabelTypeLabel},
}
gluonLabels := []imap.MailboxData{
{RemoteID: "111", BridgeName: []string{"Labels", "D"}}, // A <- D
{RemoteID: "222", BridgeName: []string{"Labels", "A"}}, // B <- A
{RemoteID: "333", BridgeName: []string{"Labels", "B"}}, // C <- B
{RemoteID: "444", BridgeName: []string{"Labels", "C"}}, // D <- C
}
mockLabelProvider := new(mockLabelNameProvider)
mockClient := new(mockAPIClient)
mockIDProvider := new(mockIDProvider)
mockReporter := new(mockReporter)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id-1", true)
for _, mbox := range gluonLabels {
mockLabelProvider.
On("GetUserMailboxByName", mock.Anything, "gluon-id-1", mbox.BridgeName).
Return(mbox, nil)
}
for _, label := range apiLabels {
mockClient.
On("GetLabel", mock.Anything, label.ID, []proton.LabelType{proton.LabelTypeFolder, proton.LabelTypeLabel, proton.LabelTypeSystem}).
Return(label, nil)
}
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
connectors := []*imapservice.Connector{connector}
manager := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{})
resolver := manager.NewConflictResolver(connectors)
fn, err := resolver.ResolveConflict(context.Background(), apiLabels[0], make(map[string]bool))
require.NoError(t, err)
updates := fn()
assert.NotEmpty(t, updates)
assert.Equal(t, 5, len(updates))
updateOne, ok := updates[0].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[0].ID), updateOne.Mailbox.ID)
assert.Equal(t, "tmp_A", updateOne.Mailbox.Name[len(updateOne.Mailbox.Name)-1])
updateTwo, ok := updates[1].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[3].ID), updateTwo.Mailbox.ID)
assert.Equal(t, "D", updateTwo.Mailbox.Name[len(updateTwo.Mailbox.Name)-1])
updateThree, ok := updates[2].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[2].ID), updateThree.Mailbox.ID)
assert.Equal(t, "C", updateThree.Mailbox.Name[len(updateThree.Mailbox.Name)-1])
updateFour, ok := updates[3].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[1].ID), updateFour.Mailbox.ID)
assert.Equal(t, "B", updateFour.Mailbox.Name[len(updateFour.Mailbox.Name)-1])
updateFive, ok := updates[4].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[0].ID), updateFive.Mailbox.ID)
assert.Equal(t, "A", updateFive.Mailbox.Name[len(updateFive.Mailbox.Name)-1])
}
func TestResolveLabelDiscrepancy_LabelSwapCyclicWithDeletedLabel(t *testing.T) {
apiLabels := []proton.Label{
{ID: "111", Path: []string{"A"}, Type: proton.LabelTypeLabel},
{ID: "333", Path: []string{"C"}, Type: proton.LabelTypeLabel},
{ID: "444", Path: []string{"D"}, Type: proton.LabelTypeLabel},
}
gluonLabels := []imap.MailboxData{
{RemoteID: "111", BridgeName: []string{"Labels", "D"}},
{RemoteID: "222", BridgeName: []string{"Labels", "A"}},
{RemoteID: "333", BridgeName: []string{"Labels", "B"}},
{RemoteID: "444", BridgeName: []string{"Labels", "C"}},
}
mockLabelProvider := new(mockLabelNameProvider)
mockClient := new(mockAPIClient)
mockIDProvider := new(mockIDProvider)
mockReporter := new(mockReporter)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id-1", true)
for _, mbox := range gluonLabels {
mockLabelProvider.
On("GetUserMailboxByName", mock.Anything, "gluon-id-1", mbox.BridgeName).
Return(mbox, nil)
}
for _, label := range apiLabels {
mockClient.
On("GetLabel", mock.Anything, label.ID, []proton.LabelType{proton.LabelTypeFolder, proton.LabelTypeLabel, proton.LabelTypeSystem}).
Return(label, nil)
}
mockClient.On("GetLabel", mock.Anything, "222", []proton.LabelType{proton.LabelTypeFolder, proton.LabelTypeLabel, proton.LabelTypeSystem}).Return(proton.Label{}, proton.ErrNoSuchLabel)
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
connectors := []*imapservice.Connector{connector}
manager := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{})
resolver := manager.NewConflictResolver(connectors)
fn, err := resolver.ResolveConflict(context.Background(), apiLabels[2], make(map[string]bool))
require.NoError(t, err)
updates := fn()
assert.NotEmpty(t, updates)
assert.Equal(t, 3, len(updates))
updateOne, ok := updates[0].(*imap.MailboxDeleted)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID("222"), updateOne.MailboxID)
updateTwo, ok := updates[1].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[0].ID), updateTwo.Mailbox.ID)
assert.Equal(t, "A", updateTwo.Mailbox.Name[len(updateTwo.Mailbox.Name)-1])
updateThree, ok := updates[2].(*imap.MailboxUpdatedOrCreated)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID(apiLabels[2].ID), updateThree.Mailbox.ID)
assert.Equal(t, "D", updateThree.Mailbox.Name[len(updateThree.Mailbox.Name)-1])
}
func TestResolveLabelDiscrepancy_LabelSwapCyclicWithDeletedLabel_KillSwitchEnabled(t *testing.T) {
apiLabels := []proton.Label{
{ID: "111", Path: []string{"A"}, Type: proton.LabelTypeLabel},
{ID: "333", Path: []string{"C"}, Type: proton.LabelTypeLabel},
{ID: "444", Path: []string{"D"}, Type: proton.LabelTypeLabel},
}
gluonLabels := []imap.MailboxData{
{RemoteID: "111", BridgeName: []string{"Labels", "D"}},
{RemoteID: "222", BridgeName: []string{"Labels", "A"}},
{RemoteID: "333", BridgeName: []string{"Labels", "B"}},
{RemoteID: "444", BridgeName: []string{"Labels", "C"}},
}
mockLabelProvider := new(mockLabelNameProvider)
mockClient := new(mockAPIClient)
mockIDProvider := new(mockIDProvider)
mockReporter := new(mockReporter)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id-1", true)
for _, mbox := range gluonLabels {
mockLabelProvider.
On("GetUserMailboxByName", mock.Anything, "gluon-id-1", mbox.BridgeName).
Return(mbox, nil)
}
for _, label := range apiLabels {
mockClient.
On("GetLabel", mock.Anything, label.ID, []proton.LabelType{proton.LabelTypeFolder, proton.LabelTypeLabel, proton.LabelTypeSystem}).
Return(label, nil)
}
mockClient.On("GetLabel", mock.Anything, "222", []proton.LabelType{proton.LabelTypeFolder, proton.LabelTypeLabel, proton.LabelTypeSystem}).Return(proton.Label{}, proton.ErrNoSuchLabel)
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
connectors := []*imapservice.Connector{connector}
manager := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderTrue{})
resolver := manager.NewConflictResolver(connectors)
fn, err := resolver.ResolveConflict(context.Background(), apiLabels[2], make(map[string]bool))
require.NoError(t, err)
updates := fn()
assert.Empty(t, updates)
}
func TestInternalLabelConflictResolver_NoConflicts(t *testing.T) {
ctx := context.Background()
mockLabelProvider := new(mockLabelNameProvider)
mockClient := new(mockAPIClient)
mockIDProvider := new(mockIDProvider)
mockReporter := new(mockReporter)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id-1", true)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", []string{"Folders"}).
Return(imap.MailboxData{}, db.ErrNotFound)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", []string{"Labels"}).
Return(imap.MailboxData{}, db.ErrNotFound)
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
connectors := []*imapservice.Connector{connector}
manager := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{})
resolver := manager.NewInternalLabelConflictResolver(connectors)
apiLabels := make(map[string]proton.Label)
fn, err := resolver.ResolveConflict(ctx, apiLabels)
assert.NoError(t, err)
updates := fn()
assert.Empty(t, updates)
}
func TestInternalLabelConflictResolver_CorrectIDs(t *testing.T) {
ctx := context.Background()
mockLabelProvider := new(mockLabelNameProvider)
mockClient := new(mockAPIClient)
mockIDProvider := new(mockIDProvider)
mockReporter := new(mockReporter)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id-1", true)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", []string{"Folders"}).
Return(imap.MailboxData{RemoteID: "Folders", BridgeName: []string{"Folders"}}, nil)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", []string{"Labels"}).
Return(imap.MailboxData{RemoteID: "Labels", BridgeName: []string{"Labels"}}, nil)
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
connectors := []*imapservice.Connector{connector}
manager := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{})
resolver := manager.NewInternalLabelConflictResolver(connectors)
apiLabels := make(map[string]proton.Label)
fn, err := resolver.ResolveConflict(ctx, apiLabels)
assert.NoError(t, err)
updates := fn()
assert.Empty(t, updates)
}
type mockMailboxCountProvider struct {
mock.Mock
}
func (m *mockMailboxCountProvider) GetUserMailboxCountByInternalID(ctx context.Context, addrID string, internalID imap.InternalMailboxID) (int, error) {
args := m.Called(ctx, addrID, internalID)
return args.Int(0), args.Error(1)
}
func TestInternalLabelConflictResolver_ConflictingNonAPILabel_ZeroCount(t *testing.T) {
ctx := context.Background()
mockLabelProvider := new(mockLabelNameProvider)
mockClient := new(mockAPIClient)
mockIDProvider := new(mockIDProvider)
mockReporter := new(mockReporter)
mockCountProvider := new(mockMailboxCountProvider)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id-1", true)
// Mock mailbox fetch to return conflicting mailbox
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", []string{"Folders"}).
Return(imap.MailboxData{RemoteID: "wrong-id", BridgeName: []string{"Folders"}, InternalID: imap.InternalMailboxID(123)}, nil)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", []string{"Labels"}).
Return(imap.MailboxData{}, db.ErrNotFound)
// Mock message count fetch to return 0 messages.
mockLabelProvider.On("GetMailboxMessageCount", mock.Anything, "gluon-id-1", imap.InternalMailboxID(123)).
Return(0, nil)
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
mockCountProvider.On("GetUserMailboxCountByInternalID",
mock.Anything,
"addr-1",
imap.InternalMailboxID(123)).
Return(0, nil)
connector.SetMailboxCountProviderTest(mockCountProvider)
connectors := []*imapservice.Connector{connector}
manager := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{})
resolver := manager.NewInternalLabelConflictResolver(connectors)
// API labels don't contain the conflicting label ID
apiLabels := make(map[string]proton.Label)
fn, err := resolver.ResolveConflict(ctx, apiLabels)
assert.NoError(t, err)
updates := fn()
assert.Len(t, updates, 1)
deleted, ok := updates[0].(*imap.MailboxDeletedSilent)
assert.True(t, ok)
assert.Equal(t, imap.MailboxID("wrong-id"), deleted.MailboxID)
}
func TestInternalLabelConflictResolver_ConflictingNonAPILabel_PositiveCount(t *testing.T) {
ctx := context.Background()
mockLabelProvider := new(mockLabelNameProvider)
mockClient := new(mockAPIClient)
mockIDProvider := new(mockIDProvider)
mockReporter := new(mockReporter)
mockCountProvider := new(mockMailboxCountProvider)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id-1", true)
mockReporter.On("ReportWarningWithContext", mock.Anything, mock.Anything).
Return(nil)
// Mock mailbox fetch to return conflicting mailbox
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", []string{"Folders"}).
Return(imap.MailboxData{RemoteID: "wrong-id", BridgeName: []string{"Folders"}, InternalID: imap.InternalMailboxID(123)}, nil)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", []string{"Labels"}).
Return(imap.MailboxData{}, db.ErrNotFound)
// Mock message count fetch to return 0 messages.
mockLabelProvider.On("GetMailboxMessageCount", mock.Anything, "gluon-id-1", imap.InternalMailboxID(123)).
Return(0, nil)
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
mockCountProvider.On("GetUserMailboxCountByInternalID",
mock.Anything,
"addr-1",
imap.InternalMailboxID(123)).
Return(10, nil)
connector.SetMailboxCountProviderTest(mockCountProvider)
connectors := []*imapservice.Connector{connector}
manager := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{})
resolver := manager.NewInternalLabelConflictResolver(connectors)
// API labels don't contain the conflicting label ID
apiLabels := make(map[string]proton.Label)
fn, err := resolver.ResolveConflict(ctx, apiLabels)
assert.EqualError(t, err, "internal mailbox conflicting non-api label has associated messages")
updates := fn()
assert.Empty(t, updates, 0)
}
func TestInternalLabelConflictResolver_ConflictingAPILabelSameName(t *testing.T) {
ctx := context.Background()
mockLabelProvider := new(mockLabelNameProvider)
mockClient := new(mockAPIClient)
mockIDProvider := new(mockIDProvider)
mockReporter := new(mockReporter)
mockCountProvider := new(mockMailboxCountProvider)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id-1", true)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", []string{"Folders"}).
Return(imap.MailboxData{RemoteID: "api-label-id", BridgeName: []string{"Folders"}}, nil)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", []string{"Labels"}).
Return(imap.MailboxData{}, db.ErrNotFound)
mockReporter.On("ReportMessageWithContext", "Internal mailbox name conflict. Same-name mailbox is returned by API", mock.Anything).
Return(nil)
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
connector.SetMailboxCountProviderTest(mockCountProvider)
connectors := []*imapservice.Connector{connector}
manager := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{})
resolver := manager.NewInternalLabelConflictResolver(connectors)
// API user label with empty path.
apiLabels := map[string]proton.Label{
"api-label-id": {
ID: "api-label-id",
Name: "Folders",
Path: []string{""},
Type: proton.LabelTypeFolder,
},
}
_, err := resolver.ResolveConflict(ctx, apiLabels)
assert.Error(t, err)
assert.Contains(t, err.Error(), "API label")
assert.Contains(t, err.Error(), "conflicts with internal label")
}
func TestInternalLabelConflictResolver_MailboxFetchError(t *testing.T) {
ctx := context.Background()
mockLabelProvider := new(mockLabelNameProvider)
mockClient := new(mockAPIClient)
mockIDProvider := new(mockIDProvider)
mockReporter := new(mockReporter)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id-1", true)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", []string{"Folders"}).
Return(imap.MailboxData{}, errors.New("database connection error"))
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
connectors := []*imapservice.Connector{connector}
manager := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderFalse{})
resolver := manager.NewInternalLabelConflictResolver(connectors)
apiLabels := make(map[string]proton.Label)
_, err := resolver.ResolveConflict(ctx, apiLabels)
assert.Error(t, err)
assert.Contains(t, err.Error(), "database connection error")
}
func TestNewInternalLabelConflictResolver_KillSwitchEnabled(t *testing.T) {
ctx := context.Background()
mockLabelProvider := new(mockLabelNameProvider)
mockClient := new(mockAPIClient)
mockIDProvider := new(mockIDProvider)
mockReporter := new(mockReporter)
mockIDProvider.On("GetGluonID", "addr-1").Return("gluon-id-1", true)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", []string{"Folders"}).
Return(imap.MailboxData{RemoteID: "wrong-folders-id", BridgeName: []string{"Folders"}}, nil)
mockLabelProvider.On("GetUserMailboxByName", mock.Anything, "gluon-id-1", []string{"Labels"}).
Return(imap.MailboxData{RemoteID: "wrong-labels-id", BridgeName: []string{"Labels"}}, nil)
connector := &imapservice.Connector{}
connector.SetAddrIDTest("addr-1")
connectors := []*imapservice.Connector{connector}
manager := imapservice.NewLabelConflictManager(mockLabelProvider, mockIDProvider, mockClient, mockReporter, ffProviderTrue{})
resolver := manager.NewInternalLabelConflictResolver(connectors)
apiLabels := map[string]proton.Label{
"some-api-label": {
ID: "some-api-label",
Name: "SomeLabel",
Path: []string{"SomeLabel"},
Type: proton.LabelTypeLabel,
},
}
fn, err := resolver.ResolveConflict(ctx, apiLabels)
assert.NoError(t, err)
updates := fn()
assert.Empty(t, updates)
}

View File

@ -45,6 +45,10 @@ import (
"golang.org/x/exp/slices"
)
type mailboxCountProvider interface {
GetUserMailboxCountByInternalID(ctx context.Context, addrID string, internalID imap.InternalMailboxID) (int, error)
}
// Connector contains all IMAP state required to satisfy sync and or imap queries.
type Connector struct {
addrID string
@ -67,6 +71,8 @@ type Connector struct {
sharedCache *SharedCache
syncState *SyncState
mailboxCountProvider mailboxCountProvider
}
var errNoSenderAddressMatch = errors.New("no matching sender found in address list")
@ -82,6 +88,7 @@ func NewConnector(
reporter reporter.Reporter,
showAllMail bool,
syncState *SyncState,
mailboxCountProvider mailboxCountProvider,
) *Connector {
userID := identityState.UserID()
@ -115,6 +122,8 @@ func NewConnector(
sharedCache: NewSharedCached(),
syncState: syncState,
mailboxCountProvider: mailboxCountProvider,
}
}
@ -257,7 +266,7 @@ func (s *Connector) DeleteMailbox(ctx context.Context, _ connector.IMAPStateWrit
wLabels := s.labels.Write()
defer wLabels.Close()
wLabels.Delete(string(mboxID))
wLabels.Delete(string(mboxID), "connectorDeleteMailbox")
return nil
}
@ -555,7 +564,7 @@ func (s *Connector) createLabel(ctx context.Context, name []string) (imap.Mailbo
wLabels := s.labels.Write()
defer wLabels.Close()
wLabels.SetLabel(label.ID, label)
wLabels.SetLabel(label.ID, label, "connectorCreateLabel")
return toIMAPMailbox(label, s.flags, s.permFlags, s.attrs), nil
}
@ -593,7 +602,7 @@ func (s *Connector) createFolder(ctx context.Context, name []string) (imap.Mailb
}
// Add label to list so subsequent sub folder create requests work correct.
wLabels.SetLabel(label.ID, label)
wLabels.SetLabel(label.ID, label, "connectorCreateFolder")
return toIMAPMailbox(label, s.flags, s.permFlags, s.attrs), nil
}
@ -619,7 +628,7 @@ func (s *Connector) updateLabel(ctx context.Context, labelID imap.MailboxID, nam
wLabels := s.labels.Write()
defer wLabels.Close()
wLabels.SetLabel(label.ID, update)
wLabels.SetLabel(label.ID, update, "connectorUpdateLabel")
return nil
}
@ -660,7 +669,7 @@ func (s *Connector) updateFolder(ctx context.Context, labelID imap.MailboxID, na
return err
}
wLabels.SetLabel(label.ID, update)
wLabels.SetLabel(label.ID, update, "connectorUpdateFolder")
return nil
}
@ -680,7 +689,7 @@ func (s *Connector) importMessage(
}
isDraft := slices.Contains(labelIDs, proton.DraftsLabel)
addr, err := s.getImportAddress(p, isDraft)
addr, err := getImportAddress(p, isDraft, s.addrID, s)
if err != nil {
return imap.Message{}, nil, err
}
@ -800,8 +809,10 @@ func (s *Connector) createDraftWithParser(ctx context.Context, parser *parser.Pa
return draft, nil
}
func (s *Connector) publishUpdate(_ context.Context, update imap.Update) {
s.updateCh.Enqueue(update)
func (s *Connector) publishUpdate(_ context.Context, updates ...imap.Update) {
for _, update := range updates {
s.updateCh.Enqueue(update)
}
}
func fixGODT3003Labels(
@ -871,45 +882,6 @@ func equalAddresses(a, b string) bool {
return strings.EqualFold(stripPlusAlias(a), stripPlusAlias(b))
}
func (s *Connector) getImportAddress(p *parser.Parser, isDraft bool) (proton.Address, error) {
// addr is primary for combined mode or active for split mode
address, ok := s.identityState.GetAddress(s.addrID)
if !ok {
return proton.Address{}, errors.New("could not find account address")
}
inCombinedMode := s.addressMode == usertypes.AddressModeCombined
if !inCombinedMode {
return address, nil
}
senderAddr, err := s.getSenderProtonAddress(p)
if err != nil {
if !errors.Is(err, errNoSenderAddressMatch) {
s.log.WithError(err).Warn("Could not get import address")
}
// We did not find a match, so we use the default address.
return address, nil
}
if senderAddr.ID == address.ID {
return address, nil
}
// GODT-3185 / BRIDGE-120 In combined mode, in certain cases we adapt the address used for encryption.
// - draft with non-default address in combined mode: using sender address
// - import with non-default address in combined mode: using sender address
// - import with non-default disabled address in combined mode: using sender address
isSenderAddressDisabled := (!bool(senderAddr.Send)) || (senderAddr.Status != proton.AddressStatusEnabled)
if isDraft && isSenderAddressDisabled {
return address, nil
}
return senderAddr, nil
}
func (s *Connector) getSenderProtonAddress(p *parser.Parser) (proton.Address, error) {
// Step 1: extract sender email address from message
if (p == nil) || (p.Root() == nil) || (p.Root().Header.Len() == 0) {
@ -942,3 +914,16 @@ func (s *Connector) getSenderProtonAddress(p *parser.Parser) (proton.Address, er
return addressList[index], nil
}
func (s *Connector) SetAddrIDTest(addrID string) {
s.addrID = addrID
}
func (s *Connector) GetMailboxMessageCount(ctx context.Context, mailboxInternalID imap.InternalMailboxID) (int, error) {
return s.mailboxCountProvider.GetUserMailboxCountByInternalID(ctx, s.addrID, mailboxInternalID)
}
// SetMailboxCountProviderTest - sets the relevant provider. Should only be used for testing.
func (s *Connector) SetMailboxCountProviderTest(provider mailboxCountProvider) {
s.mailboxCountProvider = provider
}

View File

@ -43,7 +43,7 @@ func TestFixGODT3003Labels(t *testing.T) {
Path: []string{"bar", "Foo"},
Color: "",
Type: proton.LabelTypeFolder,
})
}, "")
wr.SetLabel("0", proton.Label{
ID: "0",
@ -52,7 +52,7 @@ func TestFixGODT3003Labels(t *testing.T) {
Path: []string{"Inbox"},
Color: "",
Type: proton.LabelTypeSystem,
})
}, "")
wr.SetLabel("bar", proton.Label{
ID: "bar",
@ -61,7 +61,7 @@ func TestFixGODT3003Labels(t *testing.T) {
Path: []string{"bar"},
Color: "",
Type: proton.LabelTypeFolder,
})
}, "")
wr.SetLabel("my_label", proton.Label{
ID: "my_label",
@ -70,7 +70,7 @@ func TestFixGODT3003Labels(t *testing.T) {
Path: []string{"MyLabel"},
Color: "",
Type: proton.LabelTypeLabel,
})
}, "")
wr.SetLabel("my_label2", proton.Label{
ID: "my_label2",
@ -79,7 +79,7 @@ func TestFixGODT3003Labels(t *testing.T) {
Path: []string{labelPrefix, "MyLabel2"},
Color: "",
Type: proton.LabelTypeLabel,
})
}, "")
wr.Close()
mboxs := []imap.MailboxNoAttrib{
@ -133,7 +133,7 @@ func TestFixGODT3003Labels_Noop(t *testing.T) {
Path: []string{folderPrefix, "bar", "Foo"},
Color: "",
Type: proton.LabelTypeFolder,
})
}, "")
wr.SetLabel("0", proton.Label{
ID: "0",
@ -142,7 +142,7 @@ func TestFixGODT3003Labels_Noop(t *testing.T) {
Path: []string{"Inbox"},
Color: "",
Type: proton.LabelTypeSystem,
})
}, "")
wr.SetLabel("bar", proton.Label{
ID: "bar",
@ -151,7 +151,7 @@ func TestFixGODT3003Labels_Noop(t *testing.T) {
Path: []string{folderPrefix, "bar"},
Color: "",
Type: proton.LabelTypeFolder,
})
}, "")
wr.SetLabel("my_label", proton.Label{
ID: "my_label",
@ -160,7 +160,7 @@ func TestFixGODT3003Labels_Noop(t *testing.T) {
Path: []string{labelPrefix, "MyLabel"},
Color: "",
Type: proton.LabelTypeLabel,
})
}, "")
wr.SetLabel("my_label2", proton.Label{
ID: "my_label2",
@ -169,7 +169,7 @@ func TestFixGODT3003Labels_Noop(t *testing.T) {
Path: []string{labelPrefix, "MyLabel2"},
Color: "",
Type: proton.LabelTypeLabel,
})
}, "")
wr.Close()
mboxs := []imap.MailboxNoAttrib{

View File

@ -102,6 +102,16 @@ func newMailboxCreatedUpdate(labelID imap.MailboxID, labelName []string) *imap.M
})
}
func newMailboxUpdatedOrCreated(labelID imap.MailboxID, labelName []string) *imap.MailboxUpdatedOrCreated {
return imap.NewMailboxUpdatedOrCreated(imap.Mailbox{
ID: labelID,
Name: labelName,
Flags: defaultMailboxFlags(),
PermanentFlags: defaultMailboxPermanentFlags(),
Attributes: imap.NewFlagSet(),
})
}
func GetMailboxName(label proton.Label) []string {
var name []string
@ -112,9 +122,10 @@ func GetMailboxName(label proton.Label) []string {
case proton.LabelTypeLabel:
name = append([]string{labelPrefix}, label.Path...)
case proton.LabelTypeContactGroup:
fallthrough
case proton.LabelTypeSystem:
name = []string{label.Name}
case proton.LabelTypeContactGroup:
fallthrough
default:
name = label.Path
@ -122,3 +133,12 @@ func GetMailboxName(label proton.Label) []string {
return name
}
func nameWithTempPrefix(path []string) []string {
path[len(path)-1] = "tmp_" + path[len(path)-1]
return path
}
func getMailboxNameWithTempPrefix(label proton.Label) []string {
return nameWithTempPrefix(GetMailboxName(label))
}

View File

@ -0,0 +1,223 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imapservice
import (
"context"
"errors"
"fmt"
"strings"
"github.com/ProtonMail/gluon/db"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/pkg/algo"
"github.com/sirupsen/logrus"
)
type labelDiscrepancyType int
const (
discrepancyInternal labelDiscrepancyType = iota
discrepancySystem
discrepancyUser
)
func (t labelDiscrepancyType) String() string {
switch t {
case discrepancyInternal:
return "internal"
case discrepancySystem:
return "system"
case discrepancyUser:
return "user"
default:
return "unknown"
}
}
type labelDiscrepancy struct {
labelName string
labelPath string
labelPathParsed string
labelID string
conflictingLabelName string
conflictingLabelID string
Type labelDiscrepancyType
}
func joinStrings(input []string) string {
return strings.Join(input, "/")
}
func newLabelDiscrepancy(label proton.Label, mbox imap.MailboxData, dType labelDiscrepancyType) labelDiscrepancy {
discrepancy := labelDiscrepancy{
labelName: label.Name,
labelID: label.ID,
conflictingLabelID: mbox.RemoteID,
Type: dType,
}
if dType == discrepancyUser {
discrepancy.labelName = algo.HashBase64SHA256(label.Name)
discrepancy.labelPath = algo.HashBase64SHA256(joinStrings(label.Path))
discrepancy.labelPathParsed = algo.HashBase64SHA256(joinStrings(GetMailboxName(label)))
discrepancy.conflictingLabelName = algo.HashBase64SHA256(joinStrings(mbox.BridgeName))
} else {
discrepancy.labelName = label.Name
discrepancy.labelPath = joinStrings(label.Path)
discrepancy.labelPathParsed = joinStrings(GetMailboxName(label))
discrepancy.conflictingLabelName = joinStrings(mbox.BridgeName)
}
return discrepancy
}
func discrepanciesToContext(discrepancies []labelDiscrepancy) reporter.Context {
ctx := make(reporter.Context)
for i, d := range discrepancies {
prefix := fmt.Sprintf("discrepancy_%d_", i)
ctx[prefix+"type"] = d.Type.String()
ctx[prefix+"label_id"] = d.labelID
ctx[prefix+"label_name"] = d.labelName
ctx[prefix+"label_path"] = d.labelPath
ctx[prefix+"label_path_parsed"] = d.labelPathParsed
ctx[prefix+"conflicting_label_name"] = d.conflictingLabelName
ctx[prefix+"conflicting_label_id"] = d.conflictingLabelID
}
ctx["discrepancy_count"] = len(discrepancies)
return ctx
}
type ConnectorGetter interface {
getConnectors() []*Connector
}
type LabelConflictChecker struct {
gluonLabelNameProvider GluonLabelNameProvider
gluonIDProvider gluonIDProvider
connectorGetter ConnectorGetter
reporter reporter.Reporter
logger *logrus.Entry
}
func NewConflictChecker(connectorGetter ConnectorGetter, reporter reporter.Reporter, provider gluonIDProvider, nameProvider GluonLabelNameProvider) *LabelConflictChecker {
return &LabelConflictChecker{
gluonLabelNameProvider: nameProvider,
gluonIDProvider: provider,
connectorGetter: connectorGetter,
reporter: reporter,
logger: logrus.WithFields(logrus.Fields{
"pkg": "imapservice/labelConflictChecker",
}),
}
}
func (c *LabelConflictChecker) getFn() mailboxFetcherFn {
connectors := c.connectorGetter.getConnectors()
return func(ctx context.Context, label proton.Label) (imap.MailboxData, error) {
for _, updateCh := range connectors {
addrID, ok := c.gluonIDProvider.GetGluonID(updateCh.addrID)
if !ok {
continue
}
return c.gluonLabelNameProvider.GetUserMailboxByName(ctx, addrID, GetMailboxName(label))
}
return imap.MailboxData{}, errors.New("no gluon connectors found")
}
}
func (c *LabelConflictChecker) CheckAndReportConflicts(ctx context.Context, labels map[string]proton.Label) error {
labelDiscrepancies, err := c.checkConflicts(ctx, labels, c.getFn())
if err != nil {
return err
}
if len(labelDiscrepancies) == 0 {
return nil
}
reporterCtx := discrepanciesToContext(labelDiscrepancies)
if err := c.reporter.ReportMessageWithContext("Found label conflicts on Bridge start", reporterCtx); err != nil {
c.logger.WithError(err).Error("Failed to report label conflicts to Sentry")
}
return nil
}
func (c *LabelConflictChecker) checkConflicts(ctx context.Context, labels map[string]proton.Label, mboxFetch mailboxFetcherFn) ([]labelDiscrepancy, error) {
discrepancies := []labelDiscrepancy{}
// Verify bridge internal mailboxes.
for _, prefix := range []string{folderPrefix, labelPrefix} {
label := proton.Label{
Path: []string{prefix},
ID: prefix,
Name: prefix,
}
mbox, err := mboxFetch(ctx, label)
if err != nil {
if db.IsErrNotFound(err) {
continue
}
return nil, err
}
if mbox.RemoteID != label.ID {
discrepancies = append(discrepancies, newLabelDiscrepancy(label, mbox, discrepancyInternal))
}
}
// Verify system and user mailboxes.
for _, label := range labels {
if !WantLabel(label) {
continue
}
mbox, err := mboxFetch(ctx, label)
if err != nil {
if db.IsErrNotFound(err) {
continue
}
return nil, err
}
if mbox.RemoteID != label.ID {
var dType labelDiscrepancyType
switch label.Type {
case proton.LabelTypeSystem:
dType = discrepancySystem
case proton.LabelTypeFolder, proton.LabelTypeLabel:
dType = discrepancyUser
case proton.LabelTypeContactGroup:
fallthrough
default:
dType = discrepancySystem
}
discrepancies = append(discrepancies, newLabelDiscrepancy(label, mbox, dType))
}
}
return discrepancies, nil
}

View File

@ -21,6 +21,7 @@ import (
"context"
"github.com/ProtonMail/gluon/connector"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice"
)
@ -34,6 +35,16 @@ type IMAPServerManager interface {
) error
RemoveIMAPUser(ctx context.Context, deleteData bool, provider GluonIDProvider, addrID ...string) error
LogRemoteLabelIDs(ctx context.Context, provider GluonIDProvider, addrID ...string) error
GetUserMailboxByName(ctx context.Context, addrID string, mailboxName []string) (imap.MailboxData, error)
GetUserMailboxCountByInternalID(ctx context.Context, addrID string, internalID imap.InternalMailboxID) (int, error)
GetOpenIMAPSessionCount() int
GetRollingIMAPConnectionCount() int
}
type NullIMAPServerManager struct{}
@ -57,6 +68,30 @@ func (n NullIMAPServerManager) RemoveIMAPUser(
return nil
}
func (n NullIMAPServerManager) LogRemoteLabelIDs(
_ context.Context,
_ GluonIDProvider,
_ ...string,
) error {
return nil
}
func (n NullIMAPServerManager) GetUserMailboxByName(_ context.Context, _ string, _ []string) (imap.MailboxData, error) {
return imap.MailboxData{}, nil
}
func (n NullIMAPServerManager) GetUserMailboxCountByInternalID(_ context.Context, _ string, _ imap.InternalMailboxID) (int, error) {
return 0, nil
}
func (n NullIMAPServerManager) GetOpenIMAPSessionCount() int {
return 0
}
func (n NullIMAPServerManager) GetRollingIMAPConnectionCount() int {
return 0
}
func NewNullIMAPServerManager() *NullIMAPServerManager {
return &NullIMAPServerManager{}
}

View File

@ -36,6 +36,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice"
"github.com/ProtonMail/proton-bridge/v3/internal/services/userevents"
"github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity"
"github.com/ProtonMail/proton-bridge/v3/internal/unleash"
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/pkg/cpc"
"github.com/sirupsen/logrus"
@ -91,7 +92,9 @@ type Service struct {
lastHandledEventID string
isSyncing atomic.Bool
observabilitySender observability.Sender
observabilitySender observability.Sender
labelConflictManager *LabelConflictManager
LabelConflictChecker *LabelConflictChecker
}
func NewService(
@ -112,6 +115,7 @@ func NewService(
maxSyncMemory uint64,
showAllMail bool,
observabilitySender observability.Sender,
featureFlagProvider unleash.FeatureFlagValueProvider,
) *Service {
subscriberName := fmt.Sprintf("imap-%v", identityState.User.ID)
@ -121,11 +125,12 @@ func NewService(
})
rwIdentity := newRWIdentity(identityState, bridgePassProvider, keyPassProvider)
syncUpdateApplier := NewSyncUpdateApplier()
labelConflictManager := NewLabelConflictManager(serverManager, gluonIDProvider, client, reporter, featureFlagProvider)
syncUpdateApplier := NewSyncUpdateApplier(labelConflictManager)
syncMessageBuilder := NewSyncMessageBuilder(rwIdentity)
syncReporter := newSyncReporter(identityState.User.ID, eventPublisher, time.Second)
return &Service{
service := &Service{
cpc: cpc.NewCPC(),
client: client,
log: log,
@ -156,8 +161,12 @@ func NewService(
syncReporter: syncReporter,
syncConfigPath: GetSyncConfigPath(syncConfigDir, identityState.User.ID),
observabilitySender: observabilitySender,
observabilitySender: observabilitySender,
labelConflictManager: labelConflictManager,
}
service.LabelConflictChecker = NewConflictChecker(service, reporter, gluonIDProvider, serverManager)
return service
}
func (s *Service) Start(
@ -176,7 +185,14 @@ func (s *Service) Start(
s.syncStateProvider = syncStateProvider
}
s.syncHandler = syncservice.NewHandler(syncRegulator, s.client, s.identityState.UserID(), s.syncStateProvider, s.log, s.panicHandler)
s.syncHandler = syncservice.NewHandler(
syncRegulator,
s.client,
s.identityState.UserID(),
s.syncStateProvider,
s.log,
s.panicHandler,
s.reporter)
// Get user labels
apiLabels, err := s.client.GetLabels(ctx, proton.LabelTypeSystem, proton.LabelTypeFolder, proton.LabelTypeLabel)
@ -355,6 +371,12 @@ func (s *Service) run(ctx context.Context) { //nolint gocyclo
case *onBadEventReq:
s.log.Debug("Bad Event Request")
// // Log remote label IDs stored in the local labelMap.
s.labels.LogLabels()
// Log the remote label IDs store in Gluon.
if err := s.logRemoteMailboxIDsFromServer(ctx, s.connectors); err != nil {
s.log.Warnf("Could not obtain remote mailbox IDs from server: %v", err)
}
err := s.removeConnectorsFromServer(ctx, s.connectors, false)
req.Reply(ctx, nil, err)
@ -518,6 +540,7 @@ func (s *Service) buildConnectors() (map[string]*Connector, error) {
s.reporter,
s.showAllMail,
s.syncStateProvider,
s.serverManager,
)
return connectors, nil
@ -535,6 +558,7 @@ func (s *Service) buildConnectors() (map[string]*Connector, error) {
s.reporter,
s.showAllMail,
s.syncStateProvider,
s.serverManager,
)
}
@ -572,6 +596,16 @@ func (s *Service) addConnectorsToServer(ctx context.Context, connectors map[stri
return nil
}
func (s *Service) logRemoteMailboxIDsFromServer(ctx context.Context, connectors map[string]*Connector) error {
addrIDs := make([]string, 0, len(connectors))
for _, c := range connectors {
addrIDs = append(addrIDs, c.addrID)
}
return s.serverManager.LogRemoteLabelIDs(ctx, s.gluonIDProvider, addrIDs...)
}
func (s *Service) removeConnectorsFromServer(ctx context.Context, connectors map[string]*Connector, deleteData bool) error {
addrIDs := make([]string, 0, len(connectors))
@ -635,7 +669,7 @@ func (s *Service) setShowAllMail(v bool) {
func (s *Service) startSyncing() {
s.isSyncing.Store(true)
s.syncHandler.Execute(s.syncReporter, s.labels.GetLabelMap(), s.syncUpdateApplier, s.syncMessageBuilder, syncservice.DefaultRetryCoolDown)
s.syncHandler.Execute(s.syncReporter, s.labels.GetLabelMap(), s.syncUpdateApplier, s.syncMessageBuilder, syncservice.DefaultRetryCoolDown, s.LabelConflictChecker)
}
func (s *Service) cancelSync() {
@ -643,6 +677,10 @@ func (s *Service) cancelSync() {
s.isSyncing.Store(false)
}
func (s *Service) getConnectors() []*Connector {
return maps.Values(s.connectors)
}
type resyncReq struct{}
type getLabelsReq struct{}

View File

@ -157,6 +157,7 @@ func addNewAddressSplitMode(ctx context.Context, s *Service, addrID string) erro
s.reporter,
s.showAllMail,
s.syncStateProvider,
s.serverManager,
)
if err := s.serverManager.AddIMAPUser(ctx, connector, connector.addrID, s.gluonIDProvider, s.syncStateProvider); err != nil {
@ -165,7 +166,7 @@ func addNewAddressSplitMode(ctx context.Context, s *Service, addrID string) erro
s.connectors[connector.addrID] = connector
updates, err := syncLabels(ctx, s.labels.GetLabelMap(), []*Connector{connector})
updates, err := syncLabels(ctx, s.labels.GetLabelMap(), []*Connector{connector}, s.labelConflictManager)
if err != nil {
return fmt.Errorf("failed to create labels updates for new address: %w", err)
}

View File

@ -42,7 +42,10 @@ func (s *Service) HandleLabelEvents(ctx context.Context, events []proton.LabelEv
continue
}
updates := onLabelCreated(ctx, s, event)
updates, err := onLabelCreated(ctx, s, event)
if err != nil {
return fmt.Errorf("failed to handle create label event: %w", err)
}
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
return err
@ -74,8 +77,8 @@ func (s *Service) HandleLabelEvents(ctx context.Context, events []proton.LabelEv
return nil
}
func onLabelCreated(ctx context.Context, s *Service, event proton.LabelEvent) []imap.Update {
updates := make([]imap.Update, 0, len(s.connectors))
func onLabelCreated(ctx context.Context, s *Service, event proton.LabelEvent) ([]imap.Update, error) {
updates := []imap.Update{}
s.log.WithFields(logrus.Fields{
"labelID": event.ID,
@ -85,9 +88,19 @@ func onLabelCreated(ctx context.Context, s *Service, event proton.LabelEvent) []
wr := s.labels.Write()
defer wr.Close()
wr.SetLabel(event.Label.ID, event.Label)
wr.SetLabel(event.Label.ID, event.Label, "onLabelCreated")
labelConflictResolver := s.labelConflictManager.NewConflictResolver(maps.Values(s.connectors))
conflictUpdatesGenerator, err := labelConflictResolver.ResolveConflict(ctx, event.Label, make(map[string]bool))
if err != nil {
return updates, err
}
for _, updateCh := range maps.Values(s.connectors) {
conflictUpdates := conflictUpdatesGenerator()
updateCh.publishUpdate(ctx, conflictUpdates...)
updates = append(updates, conflictUpdates...)
update := newMailboxCreatedUpdate(imap.MailboxID(event.ID), GetMailboxName(event.Label))
updateCh.publishUpdate(ctx, update)
updates = append(updates, update)
@ -99,7 +112,7 @@ func onLabelCreated(ctx context.Context, s *Service, event proton.LabelEvent) []
Name: event.Label.Name,
})
return updates
return updates, nil
}
func onLabelUpdated(ctx context.Context, s *Service, event proton.LabelEvent) ([]imap.Update, error) {
@ -121,7 +134,7 @@ func onLabelUpdated(ctx context.Context, s *Service, event proton.LabelEvent) ([
// Only update the label if it exists; we don't want to create it as a client may have just deleted it.
if _, ok := wr.GetLabel(label.ID); ok {
wr.SetLabel(label.ID, event.Label)
wr.SetLabel(label.ID, event.Label, "onLabelUpdatedLabelEventID")
}
// API doesn't notify us that the path has changed. We need to fetch it again.
@ -134,10 +147,21 @@ func onLabelUpdated(ctx context.Context, s *Service, event proton.LabelEvent) ([
}
// Update the label in the map.
wr.SetLabel(apiLabel.ID, apiLabel)
wr.SetLabel(apiLabel.ID, apiLabel, "onLabelUpdatedApiID")
// Resolve potential conflicts
labelConflictResolver := s.labelConflictManager.NewConflictResolver(maps.Values(s.connectors))
conflictUpdatesGenerator, err := labelConflictResolver.ResolveConflict(ctx, apiLabel, make(map[string]bool))
if err != nil {
return updates, err
}
// Notify the IMAP clients.
for _, updateCh := range maps.Values(s.connectors) {
conflictUpdates := conflictUpdatesGenerator()
updateCh.publishUpdate(ctx, conflictUpdates...)
updates = append(updates, conflictUpdates...)
update := imap.NewMailboxUpdated(
imap.MailboxID(apiLabel.ID),
GetMailboxName(apiLabel),
@ -176,7 +200,7 @@ func onLabelDeleted(ctx context.Context, s *Service, event proton.LabelEvent) []
wr := s.labels.Write()
wr.Close()
wr.Delete(event.ID)
wr.Delete(event.ID, "onLabelDeleted")
s.eventPublisher.PublishEvent(ctx, events.UserLabelDeleted{
UserID: s.identityState.UserID(),

View File

@ -256,8 +256,8 @@ func onMessageUpdateDraftOrSent(ctx context.Context, s *Service, event proton.Me
res.update.Literal,
res.update.MailboxIDs,
res.update.ParsedMessage,
true, // Is the message doesn't exist, silently create it.
false,
true, // Is the message doesn't exist, silently create it.
duringSync, // Ignore unknown labelIDs during sync.
)
didPublish, err := safePublishMessageUpdate(ctx, s, full.AddressID, update, duringSync)

View File

@ -113,7 +113,7 @@ func (s syncMessageEventHandler) HandleMessageEvents(ctx context.Context, events
if err := waitOnIMAPUpdates(ctx, updates); gluon.IsNoSuchMessage(err) {
logrus.WithError(err).Error("Failed to handle update message event in gluon, will try creating it (sync)")
updates, err := onMessageCreated(ctx, s.service, event.Message, false, true)
updates, err := onMessageCreated(ctx, s.service, event.Message, true, true)
if err != nil {
s.service.observabilitySender.AddDistinctMetrics(
observability.SyncError,

View File

@ -22,6 +22,8 @@ import (
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/bradenaw/juniper/xslices"
"github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
)
@ -42,8 +44,8 @@ type labelsRead interface {
type labelsWrite interface {
labelsRead
SetLabel(id string, label proton.Label)
Delete(id string)
SetLabel(id string, label proton.Label, actionSource string)
Delete(id string, actionSource string)
}
type rwLabels struct {
@ -51,6 +53,22 @@ type rwLabels struct {
labels labelMap
}
func (r *rwLabels) LogLabels() {
r.lock.RLock()
defer r.lock.RUnlock()
remoteLabelIDs := make([]string, len(r.labels))
i := 0
for labelID := range r.labels {
remoteLabelIDs[i] = labelID
i++
}
logrus.WithFields(logrus.Fields{
"remoteLabelIDs": remoteLabelIDs,
}).Debug("Logging remote label IDs stored in labelMap")
}
func (r *rwLabels) Read() labelsRead {
r.lock.RLock()
return &rwLabelsRead{rw: r}
@ -75,6 +93,15 @@ func (r *rwLabels) SetLabels(labels []proton.Label) {
r.lock.Lock()
defer r.lock.Unlock()
labelIDs := xslices.Map(labels, func(label proton.Label) string {
return label.ID
})
logrus.WithFields(logrus.Fields{
"pkg": "rwLabels",
"labelIDs": labelIDs,
}).Info("Setting labels")
r.labels = usertypes.GroupBy(labels, func(label proton.Label) string { return label.ID })
}
@ -123,10 +150,20 @@ func (r rwLabelsWrite) GetLabels() []proton.Label {
return r.rw.getLabelsUnsafe()
}
func (r rwLabelsWrite) SetLabel(id string, label proton.Label) {
func (r rwLabelsWrite) SetLabel(id string, label proton.Label, actionSource string) {
logAction("SetLabel", actionSource, label.ID)
r.rw.labels[id] = label
}
func (r rwLabelsWrite) Delete(id string) {
func (r rwLabelsWrite) Delete(id string, actionSource string) {
logAction("Delete", actionSource, id)
delete(r.rw.labels, id)
}
func logAction(actionType, actionSource, labelID string) {
logrus.WithFields(logrus.Fields{
"pkg": "rwLabelsWrite",
"actionSource": actionSource,
"labelID": labelID,
}).Debug(actionType)
}

View File

@ -31,8 +31,9 @@ import (
)
type SyncUpdateApplier struct {
requestCh chan updateRequest
replyCh chan updateReply
requestCh chan updateRequest
replyCh chan updateReply
labelConflictManager *LabelConflictManager
}
type updateReply struct {
@ -42,10 +43,11 @@ type updateReply struct {
type updateRequest = func(ctx context.Context, mode usertypes.AddressMode, connectors map[string]*Connector) ([]imap.Update, error)
func NewSyncUpdateApplier() *SyncUpdateApplier {
func NewSyncUpdateApplier(labelConflictManager *LabelConflictManager) *SyncUpdateApplier {
return &SyncUpdateApplier{
requestCh: make(chan updateRequest),
replyCh: make(chan updateReply),
requestCh: make(chan updateRequest),
replyCh: make(chan updateReply),
labelConflictManager: labelConflictManager,
}
}
@ -111,42 +113,9 @@ func (s *SyncUpdateApplier) ApplySyncUpdates(ctx context.Context, updates []sync
return nil
}
func (s *SyncUpdateApplier) SyncSystemLabelsOnly(ctx context.Context, labels map[string]proton.Label) error {
request := func(ctx context.Context, _ usertypes.AddressMode, connectors map[string]*Connector) ([]imap.Update, error) {
updates := make([]imap.Update, 0, len(labels)*len(connectors))
for _, label := range labels {
if !WantLabel(label) {
continue
}
if label.Type != proton.LabelTypeSystem {
continue
}
for _, c := range connectors {
update := newSystemMailboxCreatedUpdate(imap.MailboxID(label.ID), label.Name)
updates = append(updates, update)
c.publishUpdate(ctx, update)
}
}
return updates, nil
}
updates, err := s.sendRequest(ctx, request)
if err != nil {
return err
}
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
return fmt.Errorf("could not sync system labels: %w", err)
}
return nil
}
func (s *SyncUpdateApplier) SyncLabels(ctx context.Context, labels map[string]proton.Label) error {
request := func(ctx context.Context, _ usertypes.AddressMode, connectors map[string]*Connector) ([]imap.Update, error) {
return syncLabels(ctx, labels, maps.Values(connectors))
return syncLabels(ctx, labels, maps.Values(connectors), s.labelConflictManager)
}
updates, err := s.sendRequest(ctx, request)
@ -161,15 +130,34 @@ func (s *SyncUpdateApplier) SyncLabels(ctx context.Context, labels map[string]pr
}
// nolint:exhaustive
func syncLabels(ctx context.Context, labels map[string]proton.Label, connectors []*Connector) ([]imap.Update, error) {
func syncLabels(ctx context.Context, labels map[string]proton.Label, connectors []*Connector, labelConflictManager *LabelConflictManager) ([]imap.Update, error) {
var updates []imap.Update
userLabelConflictResolver := labelConflictManager.NewConflictResolver(connectors)
internalLabelConflictResolver := labelConflictManager.NewInternalLabelConflictResolver(connectors)
conflictUpdateGenerator, err := internalLabelConflictResolver.ResolveConflict(ctx, labels)
if err != nil {
return updates, err
}
for _, updateCh := range connectors {
conflictUpdates := conflictUpdateGenerator()
updateCh.publishUpdate(ctx, conflictUpdates...)
updates = append(updates, conflictUpdates...)
}
// Create placeholder Folders/Labels mailboxes with the \Noselect attribute.
for _, prefix := range []string{folderPrefix, labelPrefix} {
for _, updateCh := range connectors {
update := newPlaceHolderMailboxCreatedUpdate(prefix)
updateCh.publishUpdate(ctx, update)
updates = append(updates, update)
// Ensure we perform a rename operation as well. The created event won't update the name if the ID exists.
renameUpdate := imap.NewMailboxUpdated(imap.MailboxID(prefix), []string{prefix})
updateCh.publishUpdate(ctx, renameUpdate)
updates = append(updates, renameUpdate)
}
}
@ -188,7 +176,16 @@ func syncLabels(ctx context.Context, labels map[string]proton.Label, connectors
}
case proton.LabelTypeFolder, proton.LabelTypeLabel:
conflictUpdatesGenerator, err := userLabelConflictResolver.ResolveConflict(ctx, label, make(map[string]bool))
if err != nil {
return updates, err
}
for _, updateCh := range connectors {
conflictUpdates := conflictUpdatesGenerator()
updateCh.publishUpdate(ctx, conflictUpdates...)
updates = append(updates, conflictUpdates...)
update := newMailboxCreatedUpdate(imap.MailboxID(labelID), GetMailboxName(label))
updateCh.publishUpdate(ctx, update)
updates = append(updates, update)

View File

@ -0,0 +1,100 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imapservice
import (
"errors"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/pkg/message/parser"
)
type connectorInterface interface {
getSenderProtonAddress(p *parser.Parser) (proton.Address, error)
getAddress(id string) (proton.Address, bool)
getPrimaryAddress() (proton.Address, error)
getAddressMode() usertypes.AddressMode
logError(err error, errMsg string)
}
func (s *Connector) logError(err error, errMsg string) {
s.log.WithError(err).Warn(errMsg)
}
func (s *Connector) getAddressMode() usertypes.AddressMode {
return s.addressMode
}
func (s *Connector) getPrimaryAddress() (proton.Address, error) {
return s.identityState.GetPrimaryAddress()
}
func (s *Connector) getAddress(id string) (proton.Address, bool) {
return s.identityState.GetAddress(id)
}
func getImportAddress(p *parser.Parser, isDraft bool, id string, conn connectorInterface) (proton.Address, error) {
// addr is primary for combined mode or active for split mode
address, ok := conn.getAddress(id)
if !ok {
return proton.Address{}, errors.New("could not find account address")
}
// If the address is external and not BYOE - with sending enabled, then use the primary address as an import target.
if address.Type == proton.AddressTypeExternal && !address.Send {
var err error
address, err = conn.getPrimaryAddress()
if err != nil {
return proton.Address{}, errors.New("could not get primary account address")
}
}
inCombinedMode := conn.getAddressMode() == usertypes.AddressModeCombined
if !inCombinedMode {
return address, nil
}
senderAddr, err := conn.getSenderProtonAddress(p)
if err != nil {
if !errors.Is(err, errNoSenderAddressMatch) {
conn.logError(err, "Could not get import address")
}
// We did not find a match, so we use the default address.
return address, nil
}
if senderAddr.ID == address.ID {
return address, nil
}
// GODT-3185 / BRIDGE-120 In combined mode, in certain cases we adapt the address used for encryption.
// - draft with non-default address in combined mode: using sender address
// - import with non-default address in combined mode: using sender address
// - import with non-default disabled address in combined mode: using sender address
isSenderAddressDisabled := (!bool(senderAddr.Send)) || (senderAddr.Status != proton.AddressStatusEnabled)
isSenderExternalNonBYOE := senderAddr.Type == proton.AddressTypeExternal && !bool(senderAddr.Send)
// Forbid drafts/imports for external non-BYOE addresses
if isSenderExternalNonBYOE || (isDraft && isSenderAddressDisabled) {
return address, nil
}
return senderAddr, nil
}

View File

@ -0,0 +1,380 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imapservice
import (
"errors"
"testing"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/pkg/message/parser"
"github.com/stretchr/testify/require"
)
type testConnector struct {
addressMode usertypes.AddressMode
primaryAddress proton.Address
senderAddress proton.Address
imapAddress proton.Address
senderAddressError error
}
func (t *testConnector) getSenderProtonAddress(_ *parser.Parser) (proton.Address, error) {
return t.senderAddress, t.senderAddressError
}
func (t *testConnector) getAddress(_ string) (proton.Address, bool) {
return t.imapAddress, true
}
func (t *testConnector) getPrimaryAddress() (proton.Address, error) {
return t.primaryAddress, nil
}
func (t *testConnector) getAddressMode() usertypes.AddressMode {
return t.addressMode
}
func (t *testConnector) logError(_ error, _ string) {
}
func Test_GetImportAddress_SplitMode(t *testing.T) {
primaryAddress := proton.Address{
ID: "1",
Email: "primary@proton.me",
Send: true,
Receive: true,
Type: proton.AddressTypeOriginal,
Status: proton.AddressStatusEnabled,
}
imapAddressProton := proton.Address{
ID: "2",
Email: "imap@proton.me",
Send: true,
Receive: true,
Type: proton.AddressTypeOriginal,
}
testConn := &testConnector{
addressMode: usertypes.AddressModeSplit,
primaryAddress: primaryAddress,
imapAddress: imapAddressProton,
}
// Import address is internal, we're creating a draft.
// Expected: returned address is internal.
addr, err := getImportAddress(nil, true, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, imapAddressProton.ID, addr.ID)
require.Equal(t, imapAddressProton.Email, addr.Email)
// Import address is internal, we're attempting to import a message.
// Expected: returned address is internal.
addr, err = getImportAddress(nil, false, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, imapAddressProton.ID, addr.ID)
require.Equal(t, imapAddressProton.Email, addr.Email)
imapAddressBYOE := proton.Address{
ID: "3",
Email: "byoe@external.com",
Send: true,
Receive: true,
Type: proton.AddressTypeExternal,
}
// IMAP address is BYOE, we're creating a draft
// Expected: returned address is BYOE.
testConn.imapAddress = imapAddressBYOE
addr, err = getImportAddress(nil, true, imapAddressBYOE.ID, testConn)
require.NoError(t, err)
require.Equal(t, imapAddressBYOE.ID, addr.ID)
require.Equal(t, imapAddressBYOE.Email, addr.Email)
// IMAP address is BYOE, we're importing a message
// Expected: returned address is BYOE.
addr, err = getImportAddress(nil, false, imapAddressBYOE.ID, testConn)
require.NoError(t, err)
require.Equal(t, imapAddressBYOE.ID, addr.ID)
require.Equal(t, imapAddressBYOE.Email, addr.Email)
imapAddressExternal := proton.Address{
ID: "4",
Email: "external@external.com",
Send: false,
Receive: false,
Type: proton.AddressTypeExternal,
}
// IMAP address is external, we're creating a draft.
// Expected: returned address is primary.
testConn.imapAddress = imapAddressExternal
addr, err = getImportAddress(nil, true, imapAddressExternal.ID, testConn)
require.NoError(t, err)
require.Equal(t, primaryAddress.ID, addr.ID)
require.Equal(t, primaryAddress.Email, addr.Email)
// IMAP address is external, we're trying to import.
// Expected: returned address is primary.
addr, err = getImportAddress(nil, false, imapAddressExternal.ID, testConn)
require.NoError(t, err)
require.Equal(t, primaryAddress.ID, addr.ID)
require.Equal(t, primaryAddress.Email, addr.Email)
}
func Test_GetImportAddress_CombinedMode_ProtonAddresses(t *testing.T) {
primaryAddress := proton.Address{
ID: "1",
Email: "primary@proton.me",
Send: true,
Receive: true,
Type: proton.AddressTypeOriginal,
Status: proton.AddressStatusEnabled,
}
imapAddressProton := proton.Address{
ID: "2",
Email: "imap@proton.me",
Send: true,
Receive: true,
Type: proton.AddressTypeOriginal,
}
senderAddress := proton.Address{
ID: "3",
Email: "sender@proton.me",
Send: true,
Receive: true,
Type: proton.AddressTypeOriginal,
Status: proton.AddressStatusEnabled,
}
testConn := &testConnector{
addressMode: usertypes.AddressModeCombined,
primaryAddress: primaryAddress,
imapAddress: imapAddressProton,
senderAddress: senderAddress,
}
// Both the sender address and the imap address are the same. We're creating a draft.
// Expected: IMAP address is returned.
testConn.senderAddress = imapAddressProton
testConn.imapAddress = imapAddressProton
addr, err := getImportAddress(nil, true, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, imapAddressProton.ID, addr.ID)
require.Equal(t, imapAddressProton.Email, addr.Email)
// Both the sender address and the imap address are the same. We're trying to import
// Expected: IMAP address is returned.
addr, err = getImportAddress(nil, false, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, imapAddressProton.ID, addr.ID)
require.Equal(t, imapAddressProton.Email, addr.Email)
// Sender address and imap address are different. Sender address is enabled and has sending enabled.
// We're creating a draft.
// Expected: Sender address is returned.
testConn.senderAddress = senderAddress
testConn.imapAddress = imapAddressProton
addr, err = getImportAddress(nil, true, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, senderAddress.ID, addr.ID)
require.Equal(t, senderAddress.Email, addr.Email)
// Sender address and imap address are different. Sender address is enabled and has sending enabled.
// We're importing a message.
// Expected: Sender address is returned.
addr, err = getImportAddress(nil, false, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, senderAddress.ID, addr.ID)
require.Equal(t, senderAddress.Email, addr.Email)
// Sender address and imap address are different. Sender address is disabled, but has sending enabled.
// We're creating a draft message.
// Expected: IMAP address is returned.
senderAddress.Status = proton.AddressStatusDisabled
testConn.senderAddress = senderAddress
addr, err = getImportAddress(nil, true, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, imapAddressProton.ID, addr.ID)
require.Equal(t, imapAddressProton.Email, addr.Email)
// Sender address and imap address are different. Sender address is disabled, but has sending enabled.
// We're importing a message.
// Expected: IMAP address is returned.
addr, err = getImportAddress(nil, false, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, senderAddress.ID, addr.ID)
require.Equal(t, senderAddress.Email, addr.Email)
// Sender address and imap address are different. Sender address is enabled, but has sending disabled.
// We're creating a draft.
// Expected: IMAP address is returned.
senderAddress.Status = proton.AddressStatusEnabled
senderAddress.Send = false
testConn.senderAddress = senderAddress
addr, err = getImportAddress(nil, true, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, imapAddressProton.ID, addr.ID)
require.Equal(t, imapAddressProton.Email, addr.Email)
// Sender address and imap address are different. Sender address is enabled, but has sending disabled.
// We're importing a message.
// Expected: IMAP address is returned.
addr, err = getImportAddress(nil, false, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, senderAddress.ID, addr.ID)
require.Equal(t, senderAddress.Email, addr.Email)
// Sender address and imap address are different. But sender address is not an associated proton address.
// We're creating a draft.
// Expected: Sender address is returned.
testConn.senderAddressError = errors.New("sender address is not associated with the account")
addr, err = getImportAddress(nil, true, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, imapAddressProton.ID, addr.ID)
require.Equal(t, imapAddressProton.Email, addr.Email)
// Sender address and imap address are different. But sender address is not an associated proton address.
// We're importing a message.
// Expected: Sender address is returned.
addr, err = getImportAddress(nil, false, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, imapAddressProton.ID, addr.ID)
require.Equal(t, imapAddressProton.Email, addr.Email)
}
func Test_GetImportAddress_CombinedMode_ExternalAddresses(t *testing.T) {
primaryAddress := proton.Address{
ID: "1",
Email: "primary@proton.me",
Send: true,
Receive: true,
Type: proton.AddressTypeOriginal,
Status: proton.AddressStatusEnabled,
}
imapAddressProton := proton.Address{
ID: "2",
Email: "imap@proton.me",
Send: true,
Receive: true,
Type: proton.AddressTypeOriginal,
}
senderAddressExternal := proton.Address{
ID: "3",
Email: "sender@external.me",
Send: false,
Receive: false,
Type: proton.AddressTypeExternal,
Status: proton.AddressStatusEnabled,
}
senderAddressExternalSecondary := proton.Address{
ID: "4",
Email: "sender2@external.me",
Send: false,
Receive: false,
Type: proton.AddressTypeExternal,
Status: proton.AddressStatusEnabled,
}
testConn := &testConnector{
addressMode: usertypes.AddressModeCombined,
primaryAddress: primaryAddress,
imapAddress: imapAddressProton,
senderAddress: senderAddressExternal,
}
// Sender address is external, and we're creating a draft.
// Expected: IMAP address is returned.
testConn.senderAddress = senderAddressExternal
testConn.imapAddress = imapAddressProton
addr, err := getImportAddress(nil, true, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, imapAddressProton.ID, addr.ID)
require.Equal(t, imapAddressProton.Email, addr.Email)
// Sender address is external, and we're importing a message.
// Expected: IMAP address is returned.
addr, err = getImportAddress(nil, false, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, imapAddressProton.ID, addr.ID)
require.Equal(t, imapAddressProton.Email, addr.Email)
// Sender and IMAP address are external, and we're trying to import.
// Expected: Primary address is returned.
testConn.imapAddress = senderAddressExternalSecondary
addr, err = getImportAddress(nil, false, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, primaryAddress.ID, addr.ID)
require.Equal(t, primaryAddress.Email, addr.Email)
// Sender and IMAP address are external, and we're trying to create a draft.
// Expected: Primary address is returned.
addr, err = getImportAddress(nil, true, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, primaryAddress.ID, addr.ID)
require.Equal(t, primaryAddress.Email, addr.Email)
}
func Test_GetImportAddress_CombinedMode_BYOEAddresses(t *testing.T) {
primaryAddress := proton.Address{
ID: "1",
Email: "primary@proton.me",
Send: true,
Receive: true,
Type: proton.AddressTypeOriginal,
Status: proton.AddressStatusEnabled,
}
imapAddressProton := proton.Address{
ID: "2",
Email: "imap@proton.me",
Send: true,
Receive: true,
Type: proton.AddressTypeOriginal,
}
senderAddressBYOE := proton.Address{
ID: "3",
Email: "sender@external.me",
Send: true,
Receive: true,
Type: proton.AddressTypeExternal,
Status: proton.AddressStatusEnabled,
}
testConn := &testConnector{
addressMode: usertypes.AddressModeCombined,
primaryAddress: primaryAddress,
imapAddress: imapAddressProton,
senderAddress: senderAddressBYOE,
}
// Sender address is BYOE, and we're creating a draft.
// Expected: BYOE address is returned.
testConn.senderAddress = senderAddressBYOE
testConn.imapAddress = imapAddressProton
addr, err := getImportAddress(nil, true, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, senderAddressBYOE.ID, addr.ID)
require.Equal(t, senderAddressBYOE.Email, addr.Email)
// Sender address is BYOE, and we're importing a message.
// Expected: BYOE address is returned.
addr, err = getImportAddress(nil, false, imapAddressProton.ID, testConn)
require.NoError(t, err)
require.Equal(t, senderAddressBYOE.ID, addr.ID)
require.Equal(t, senderAddressBYOE.Email, addr.Email)
}

View File

@ -24,6 +24,7 @@ import (
"io"
"os"
"path/filepath"
"time"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/gluon"
@ -37,9 +38,16 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/files"
"github.com/ProtonMail/proton-bridge/v3/internal/logging"
"github.com/ProtonMail/proton-bridge/v3/internal/services/observability"
"github.com/ProtonMail/proton-bridge/v3/internal/unleash"
"github.com/sirupsen/logrus"
)
const (
rollingCounterNewConnectionThreshold = 300
rollingCounterNumberOfBuckets = 6
rollingCounterBucketRotationInterval = time.Second * 10
)
var logIMAP = logrus.WithField("pkg", "server/imap") //nolint:gochecknoglobals
type IMAPSettingsProvider interface {
@ -81,6 +89,7 @@ func newIMAPServer(
uidValidityGenerator imap.UIDValidityGenerator,
panicHandler async.PanicHandler,
observabilitySender observability.Sender,
featureFlagProvider unleash.FeatureFlagValueProvider,
) (*gluon.Server, error) {
gluonCacheDir = ApplyGluonCachePathSuffix(gluonCacheDir)
gluonConfigDir = ApplyGluonConfigPathSuffix(gluonConfigDir)
@ -126,6 +135,8 @@ func newIMAPServer(
gluon.WithUIDValidityGenerator(uidValidityGenerator),
gluon.WithPanicHandler(panicHandler),
gluon.WithObservabilitySender(observability.NewAdapter(observabilitySender), int(observability.GluonImapError), int(observability.GluonMessageError), int(observability.GluonOtherError)),
gluon.WithConnectionRollingCounter(rollingCounterNewConnectionThreshold, rollingCounterNumberOfBuckets, rollingCounterBucketRotationInterval),
gluon.WithFeatureFlagProvider(featureFlagProvider),
}
if disableIMAPAuthenticate {

View File

@ -34,6 +34,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/services/observability"
bridgesmtp "github.com/ProtonMail/proton-bridge/v3/internal/services/smtp"
"github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice"
"github.com/ProtonMail/proton-bridge/v3/internal/unleash"
"github.com/ProtonMail/proton-bridge/v3/pkg/cpc"
"github.com/emersion/go-smtp"
"github.com/sirupsen/logrus"
@ -63,6 +64,7 @@ type Service struct {
telemetry Telemetry
observabilitySender observability.Sender
featureFlagProvider unleash.FeatureFlagValueProvider
}
func NewService(
@ -75,6 +77,7 @@ func NewService(
uidValidityGenerator imap.UIDValidityGenerator,
telemetry Telemetry,
observabilitySender observability.Sender,
featureFlagProvider unleash.FeatureFlagValueProvider,
) *Service {
return &Service{
requests: cpc.NewCPC(),
@ -91,6 +94,7 @@ func NewService(
telemetry: telemetry,
observabilitySender: observabilitySender,
featureFlagProvider: featureFlagProvider,
}
}
@ -170,6 +174,14 @@ func (sm *Service) SetGluonDir(ctx context.Context, gluonDir string) error {
return err
}
func (sm *Service) LogRemoteLabelIDs(ctx context.Context, provider imapservice.GluonIDProvider, addrID ...string) error {
_, err := sm.requests.Send(ctx, &smRequestLogRemoteMailboxIDs{
addrID: addrID,
idProvider: provider,
})
return err
}
func (sm *Service) RemoveIMAPUser(ctx context.Context, deleteData bool, provider imapservice.GluonIDProvider, addrID ...string) error {
_, err := sm.requests.Send(ctx, &smRequestRemoveIMAPUser{
withData: deleteData,
@ -192,6 +204,22 @@ func (sm *Service) RemoveSMTPAccount(ctx context.Context, service *bridgesmtp.Se
return err
}
func (sm *Service) GetUserMailboxByName(ctx context.Context, addrID string, mailboxName []string) (imap.MailboxData, error) {
return sm.imapServer.GetUserMailboxByName(ctx, addrID, mailboxName)
}
func (sm *Service) GetUserMailboxCountByInternalID(ctx context.Context, addrID string, internalID imap.InternalMailboxID) (int, error) {
return sm.imapServer.GetUserMailboxCountByInternalID(ctx, addrID, internalID)
}
func (sm *Service) GetOpenIMAPSessionCount() int {
return sm.imapServer.GetOpenSessionCount()
}
func (sm *Service) GetRollingIMAPConnectionCount() int {
return sm.imapServer.GetRollingIMAPConnectionCount()
}
func (sm *Service) run(ctx context.Context, subscription events.Subscription) {
eventSub := subscription.Add()
defer subscription.Remove(eventSub)
@ -244,6 +272,10 @@ func (sm *Service) run(ctx context.Context, subscription events.Subscription) {
sm.handleLoadedUserCountChange(ctx)
}
case *smRequestLogRemoteMailboxIDs:
err := sm.logRemoteLabelIDsFromServer(ctx, r.addrID, r.idProvider)
request.Reply(ctx, nil, err)
case *smRequestRemoveIMAPUser:
err := sm.handleRemoveIMAPUser(ctx, r.withData, r.idProvider, r.addrID...)
request.Reply(ctx, nil, err)
@ -311,6 +343,35 @@ func (sm *Service) handleAddIMAPUser(ctx context.Context,
return sm.handleAddIMAPUserImpl(ctx, connector, addrID, idProvider, syncStateProvider)
}
func (sm *Service) logRemoteLabelIDsFromServer(ctx context.Context, addrIDs []string, idProvider imapservice.GluonIDProvider) error {
if sm.imapServer == nil {
return fmt.Errorf("no imap server instance running")
}
for _, addrID := range addrIDs {
gluonID, ok := idProvider.GetGluonID(addrID)
if !ok {
sm.log.Warnf("Could not find Gluon ID for addrID %v", addrID)
continue
}
log := sm.log.WithFields(logrus.Fields{
"addrID": addrID,
"gluonID": gluonID,
})
remoteLabelIDs, err := sm.imapServer.GetAllMailboxRemoteIDsForUser(ctx, gluonID)
if err != nil {
log.WithError(err).Error("Could not obtain remote label IDs for user")
continue
}
log.WithField("remoteLabelIDs", remoteLabelIDs).Debug("Logging Gluon remote Label IDs")
}
return nil
}
func (sm *Service) handleAddIMAPUserImpl(ctx context.Context,
connector connector.Connector,
addrID string,
@ -457,6 +518,7 @@ func (sm *Service) createIMAPServer(ctx context.Context) (*gluon.Server, error)
sm.uidValidityGenerator,
sm.panicHandler,
sm.observabilitySender,
sm.featureFlagProvider,
)
if err == nil {
sm.eventPublisher.PublishEvent(ctx, events.IMAPServerCreated{})
@ -723,3 +785,8 @@ type smRequestAddSMTPAccount struct {
type smRequestRemoveSMTPAccount struct {
account *bridgesmtp.Service
}
type smRequestLogRemoteMailboxIDs struct {
addrID []string
idProvider imapservice.GluonIDProvider
}

View File

@ -44,7 +44,7 @@ type Service struct {
store *Store
getFlagValueFn unleash.GetFlagValueFn
featureFlagValueProvider unleash.FeatureFlagValueProvider
observabilitySender observability.Sender
}
@ -52,7 +52,7 @@ type Service struct {
const bitfieldRegexPattern = `^\\\d+`
func NewService(userID string, service userevents.Subscribable, eventPublisher events.EventPublisher, store *Store,
getFlagFn unleash.GetFlagValueFn, observabilitySender observability.Sender) *Service {
featureFlagValueProvider unleash.FeatureFlagValueProvider, observabilitySender observability.Sender) *Service {
return &Service{
userID: userID,
@ -68,8 +68,8 @@ func NewService(userID string, service userevents.Subscribable, eventPublisher e
store: store,
getFlagValueFn: getFlagFn,
observabilitySender: observabilitySender,
featureFlagValueProvider: featureFlagValueProvider,
observabilitySender: observabilitySender,
}
}
@ -102,7 +102,7 @@ func (s *Service) run(ctx context.Context) {
}
func (s *Service) HandleNotificationEvents(ctx context.Context, notificationEvents []proton.NotificationEvent) error {
if s.getFlagValueFn(unleash.EventLoopNotificationDisabled) {
if s.featureFlagValueProvider.GetFlagValue(unleash.EventLoopNotificationDisabled) {
s.log.Info("Received notification events. Skipping as kill switch is enabled.")
return nil
}

View File

@ -19,6 +19,7 @@ package observability
import (
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/services/observability/gluonmetrics"
)
type Adapter struct {
@ -88,6 +89,15 @@ func (adapter *Adapter) AddDistinctMetrics(errType interface{}, metrics ...map[s
}
if len(typedMetrics) > 0 {
adapter.sender.AddDistinctMetrics(DistinctionErrorTypeEnum(errTypeInt), typedMetrics...)
adapter.sender.AddDistinctMetrics(DistinctionMetricTypeEnum(errTypeInt), typedMetrics...)
}
}
func (adapter *Adapter) AddIMAPConnectionsExceededThresholdMetric(totalOpenIMAPConnections, newIMAPConnections int) {
metric := gluonmetrics.GenerateNewOpenedIMAPConnectionsExceedThreshold(
adapter.sender.GetEmailClient(),
BucketIMAPConnections(totalOpenIMAPConnections),
BucketIMAPConnections(newIMAPConnections))
adapter.sender.AddTimeLimitedMetric(NewIMAPConnectionsExceedThreshold, metric)
}

View File

@ -19,21 +19,22 @@ package observability
import "time"
// DistinctionErrorTypeEnum - maps to the specific error schema for which we
// want to send a user update.
type DistinctionErrorTypeEnum int
// DistinctionMetricTypeEnum - used to distinct specific metrics which we want to limit over some interval.
// Most enums are tied to a specific error schema for which we also send a specific distinction user update.
type DistinctionMetricTypeEnum int
const (
SyncError DistinctionErrorTypeEnum = iota
SyncError DistinctionMetricTypeEnum = iota
GluonImapError
GluonMessageError
GluonOtherError
SMTPError
EventLoopError // EventLoopError - should always be kept last when inserting new keys.
NewIMAPConnectionsExceedThreshold
)
// errorSchemaMap - maps between the DistinctionErrorTypeEnum and the relevant schema name.
var errorSchemaMap = map[DistinctionErrorTypeEnum]string{ //nolint:gochecknoglobals
// errorSchemaMap - maps between some DistinctionMetricTypeEnum and the relevant schema name.
var errorSchemaMap = map[DistinctionMetricTypeEnum]string{ //nolint:gochecknoglobals
SyncError: "bridge_sync_errors_users_total",
EventLoopError: "bridge_event_loop_events_errors_users_total",
GluonImapError: "bridge_gluon_imap_errors_users_total",
@ -43,9 +44,9 @@ var errorSchemaMap = map[DistinctionErrorTypeEnum]string{ //nolint:gochecknoglob
}
// createLastSentMap - needs to be updated whenever we make changes to the enum.
func createLastSentMap() map[DistinctionErrorTypeEnum]time.Time {
func createLastSentMap() map[DistinctionMetricTypeEnum]time.Time {
registerTime := time.Now().Add(-updateInterval)
lastSentMap := make(map[DistinctionErrorTypeEnum]time.Time)
lastSentMap := make(map[DistinctionMetricTypeEnum]time.Time)
for errType := SyncError; errType <= EventLoopError; errType++ {
lastSentMap[errType] = registerTime

View File

@ -40,7 +40,7 @@ type distinctionUtility struct {
panicHandler async.PanicHandler
lastSentMap map[DistinctionErrorTypeEnum]time.Time // Ensures we don't step over the limit of one user update every 5 mins.
lastSentMap map[DistinctionMetricTypeEnum]time.Time // Ensures we don't step over the limit of one user update every 5 mins.
observabilitySender observabilitySender
settingsGetter settingsGetter
@ -87,7 +87,7 @@ func (d *distinctionUtility) setSettingsGetter(getter settingsGetter) {
// checkAndUpdateLastSentMap - checks whether we have sent a relevant user update metric
// within the last 5 minutes.
func (d *distinctionUtility) checkAndUpdateLastSentMap(key DistinctionErrorTypeEnum) bool {
func (d *distinctionUtility) checkAndUpdateLastSentMap(key DistinctionMetricTypeEnum) bool {
curTime := time.Now()
val, ok := d.lastSentMap[key]
if !ok {
@ -107,7 +107,7 @@ func (d *distinctionUtility) checkAndUpdateLastSentMap(key DistinctionErrorTypeE
// and the relevant settings. In the future this will need to be expanded to support multiple
// versions of the metric if we ever decide to change them.
func (d *distinctionUtility) generateUserMetric(
metricType DistinctionErrorTypeEnum,
metricType DistinctionMetricTypeEnum,
) proton.ObservabilityMetric {
schemaName, ok := errorSchemaMap[metricType]
if !ok {
@ -138,7 +138,7 @@ func generateUserMetric(schemaName, plan, mailClient, dohEnabled, betaAccess str
}
}
func (d *distinctionUtility) generateDistinctMetrics(errType DistinctionErrorTypeEnum, metrics ...proton.ObservabilityMetric) []proton.ObservabilityMetric {
func (d *distinctionUtility) generateDistinctMetrics(errType DistinctionMetricTypeEnum, metrics ...proton.ObservabilityMetric) []proton.ObservabilityMetric {
d.updateHeartbeatData(errType)
if d.checkAndUpdateLastSentMap(errType) {

View File

@ -0,0 +1,45 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package gluonmetrics
import (
"time"
"github.com/ProtonMail/go-proton-api"
)
const (
newIMAPConnectionThresholdExceededSchemaName = "bridge_imap_recently_opened_connections_total"
newIMAPConnectionThresholdExceededVersion = 1
)
func GenerateNewOpenedIMAPConnectionsExceedThreshold(emailClient, totalOpenIMAPConnectionCount, newlyOpenedIMAPConnectionCount string) proton.ObservabilityMetric {
return proton.ObservabilityMetric{
Name: newIMAPConnectionThresholdExceededSchemaName,
Version: newIMAPConnectionThresholdExceededVersion,
Timestamp: time.Now().Unix(),
Data: map[string]interface{}{
"Value": 1,
"Labels": map[string]string{
"mailClient": emailClient,
"numberOfOpenIMAPConnectionsBuckets": totalOpenIMAPConnectionCount,
"numberOfRecentlyOpenedIMAPConnectionsBuckets": newlyOpenedIMAPConnectionCount,
},
},
}
}

View File

@ -42,7 +42,7 @@ func (d *distinctionUtility) resetHeartbeatData() {
d.heartbeatData.receivedGluonError = false
}
func (d *distinctionUtility) updateHeartbeatData(errType DistinctionErrorTypeEnum) {
func (d *distinctionUtility) updateHeartbeatData(errType DistinctionMetricTypeEnum) {
d.withUpdateHeartbeatDataLock(func() {
//nolint:exhaustive
switch errType {

View File

@ -45,7 +45,9 @@ type client struct {
// so we can easily pass them down to relevant components.
type Sender interface {
AddMetrics(metrics ...proton.ObservabilityMetric)
AddDistinctMetrics(errType DistinctionErrorTypeEnum, metrics ...proton.ObservabilityMetric)
AddDistinctMetrics(errType DistinctionMetricTypeEnum, metrics ...proton.ObservabilityMetric)
AddTimeLimitedMetric(metricType DistinctionMetricTypeEnum, metric proton.ObservabilityMetric)
GetEmailClient() string
}
type Service struct {
@ -325,11 +327,25 @@ func (s *Service) AddMetrics(metrics ...proton.ObservabilityMetric) {
// what number of events come from what number of users.
// As the binning interval is what allows us to do this we
// should not send these if there are no logged-in users at that moment.
func (s *Service) AddDistinctMetrics(errType DistinctionErrorTypeEnum, metrics ...proton.ObservabilityMetric) {
func (s *Service) AddDistinctMetrics(errType DistinctionMetricTypeEnum, metrics ...proton.ObservabilityMetric) {
metrics = s.distinctionUtility.generateDistinctMetrics(errType, metrics...)
s.addMetricsIfClients(metrics...)
}
// AddTimeLimitedMetric - schedules a metric to be sent if a metric of the same type has not been sent within some interval.
// The interval is defined in the distinction utility.
func (s *Service) AddTimeLimitedMetric(metricType DistinctionMetricTypeEnum, metric proton.ObservabilityMetric) {
if !s.distinctionUtility.checkAndUpdateLastSentMap(metricType) {
return
}
s.addMetricsIfClients(metric)
}
func (s *Service) GetEmailClient() string {
return s.distinctionUtility.getEmailClientUserAgent()
}
// ModifyHeartbeatInterval - should only be used for testing. Resets the heartbeat ticker.
func (s *Service) ModifyHeartbeatInterval(duration time.Duration) {
s.distinctionUtility.heartbeatTicker.Reset(duration)

View File

@ -66,3 +66,30 @@ func getEnabled(value bool) string {
}
return "enabled"
}
func BucketIMAPConnections(val int) string {
switch {
case val < 10:
return "<10"
case val < 25:
return "10-24"
case val < 50:
return "25-49"
case val < 100:
return "50-99"
case val < 200:
return "100-199"
case val < 300:
return "200-299"
case val < 500:
return "300-499"
case val < 1000:
return "500-999"
case val < 2000:
return "1000-1999"
case val < 3000:
return "2000-2999"
default:
return "3000+"
}
}

View File

@ -21,6 +21,7 @@ import (
"time"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/services/observability"
)
const (
@ -29,6 +30,9 @@ const (
smtpSendSuccessSchemaName = "bridge_smtp_send_success_total"
smtpSendSuccessSchemaVersion = 1
smtpSubmissionRequestSchemaName = "bridge_smtp_send_request_total"
smtpSubmissionRequestSchemaVersion = 1
)
func generateSMTPErrorObservabilityMetric(errorType string) proton.ObservabilityMetric {
@ -88,3 +92,19 @@ func GenerateSMTPSendSuccess() proton.ObservabilityMetric {
},
}
}
func GenerateSMTPSubmissionRequest(emailClient string, numberOfOpenIMAPConnections, numberOfRecentlyOpenedIMAPConnections int) proton.ObservabilityMetric {
return proton.ObservabilityMetric{
Name: smtpSubmissionRequestSchemaName,
Version: smtpSubmissionRequestSchemaVersion,
Timestamp: time.Now().Unix(),
Data: map[string]interface{}{
"Value": 1,
"Labels": map[string]string{
"numberOfOpenIMAPConnections": observability.BucketIMAPConnections(numberOfOpenIMAPConnections),
"numberOfRecentlyOpenedIMAPConnections": observability.BucketIMAPConnections(numberOfRecentlyOpenedIMAPConnections),
"mailClient": emailClient,
},
},
}
}

View File

@ -32,13 +32,24 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/services/observability"
"github.com/ProtonMail/proton-bridge/v3/internal/services/orderedtasks"
"github.com/ProtonMail/proton-bridge/v3/internal/services/sendrecorder"
"github.com/ProtonMail/proton-bridge/v3/internal/services/smtp/observabilitymetrics"
"github.com/ProtonMail/proton-bridge/v3/internal/services/userevents"
"github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity"
"github.com/ProtonMail/proton-bridge/v3/internal/unleash"
"github.com/ProtonMail/proton-bridge/v3/internal/usertypes"
"github.com/ProtonMail/proton-bridge/v3/pkg/cpc"
"github.com/sirupsen/logrus"
)
const (
newlyOpenedIMAPConnectionsThreshold = 300
)
type imapSessionCountProvider interface {
GetOpenIMAPSessionCount() int
GetRollingIMAPConnectionCount() int
}
type Service struct {
userID string
panicHandler async.PanicHandler
@ -59,6 +70,9 @@ type Service struct {
serverManager ServerManager
observabilitySender observability.Sender
imapSessionCountProvider imapSessionCountProvider
featureFlagValueProvider unleash.FeatureFlagValueProvider
}
func NewService(
@ -74,6 +88,8 @@ func NewService(
identityState *useridentity.State,
serverManager ServerManager,
observabilitySender observability.Sender,
imapSessionCountProvider imapSessionCountProvider,
featureFlagValueProvider unleash.FeatureFlagValueProvider,
) *Service {
subscriberName := fmt.Sprintf("smpt-%v", userID)
@ -99,7 +115,9 @@ func NewService(
addressMode: mode,
serverManager: serverManager,
observabilitySender: observabilitySender,
imapSessionCountProvider: imapSessionCountProvider,
observabilitySender: observabilitySender,
featureFlagValueProvider: featureFlagValueProvider,
}
}
@ -207,7 +225,6 @@ func (s *Service) run(ctx context.Context) {
switch r := request.Value().(type) {
case *sendMailReq:
s.log.Debug("Received send mail request")
err := s.sendMail(ctx, r)
request.Reply(ctx, nil, err)
@ -252,16 +269,38 @@ type sendMailReq struct {
func (s *Service) sendMail(ctx context.Context, req *sendMailReq) error {
defer async.HandlePanic(s.panicHandler)
openSessionCount := s.imapSessionCountProvider.GetOpenIMAPSessionCount()
newlyOpenedSessions := s.imapSessionCountProvider.GetRollingIMAPConnectionCount()
log := s.log.WithFields(logrus.Fields{
"newlyOpenedIMAPConnectionsCount": newlyOpenedSessions,
"openIMAPConnectionsCount": openSessionCount,
})
log.Debug("Received send mail request")
// Send SMTP send request metric to observability.
s.observabilitySender.AddMetrics(observabilitymetrics.GenerateSMTPSubmissionRequest(s.observabilitySender.GetEmailClient(), openSessionCount, newlyOpenedSessions))
// Send report to sentry if kill switch is disabled & number of newly opened IMAP connections exceed threshold.
if !s.featureFlagValueProvider.GetFlagValue(unleash.SMTPSubmissionRequestSentryReportDisabled) && newlyOpenedSessions >= newlyOpenedIMAPConnectionsThreshold {
if err := s.reporter.ReportMessageWithContext("SMTP Send Mail Request - newly opened IMAP connections exceed threshold", reporter.Context{
"newlyOpenedIMAPConnectionsCount": newlyOpenedSessions,
"openIMAPConnectionsCount": openSessionCount,
"emailClient": s.observabilitySender.GetEmailClient(),
}); err != nil {
s.log.WithError(err).Error("Failed to submit report to sentry (SMTP Send Mail Request)")
}
}
start := time.Now()
s.log.Debug("Received send mail request")
defer func() {
end := time.Now()
s.log.Debugf("Send mail request finished in %v", end.Sub(start))
log.Debugf("Send mail request finished in %v", end.Sub(start))
}()
if err := s.smtpSendMail(ctx, req.authID, req.from, req.to, req.r); err != nil {
if apiErr := new(proton.APIError); errors.As(err, &apiErr) {
s.log.WithError(apiErr).WithField("Details", apiErr.DetailsToString()).Error("failed to send message")
log.WithError(apiErr).WithField("Details", apiErr.DetailsToString()).Error("failed to send message")
}
return err

View File

@ -19,10 +19,13 @@ package syncservice
import (
"context"
"errors"
"fmt"
"time"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/db"
"github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/network"
"github.com/sirupsen/logrus"
@ -33,6 +36,10 @@ const NumSyncStages = 4
type LabelMap = map[string]proton.Label
type labelConflictChecker interface {
CheckAndReportConflicts(ctx context.Context, labels map[string]proton.Label) error
}
// Handler is the interface from which we control the syncing of the IMAP data. One instance should be created for each
// user and used for every subsequent sync request.
type Handler struct {
@ -45,6 +52,7 @@ type Handler struct {
syncFinishedCh chan error
panicHandler async.PanicHandler
downloadCache *DownloadCache
sentryReporter reporter.Reporter
}
func NewHandler(
@ -54,6 +62,7 @@ func NewHandler(
state StateProvider,
log *logrus.Entry,
panicHandler async.PanicHandler,
sentryReporter reporter.Reporter,
) *Handler {
return &Handler{
client: client,
@ -65,6 +74,7 @@ func NewHandler(
regulator: regulator,
panicHandler: panicHandler,
downloadCache: newDownloadCache(),
sentryReporter: sentryReporter,
}
}
@ -91,12 +101,17 @@ func (t *Handler) Execute(
updateApplier UpdateApplier,
messageBuilder MessageBuilder,
coolDown time.Duration,
labelConflictChecker labelConflictChecker,
) {
t.log.Info("Sync triggered")
t.group.Once(func(ctx context.Context) {
start := time.Now()
t.log.WithField("start", start).Info("Beginning user sync")
if err := labelConflictChecker.CheckAndReportConflicts(ctx, labels); err != nil {
t.log.WithError(err).Error("Failed to check and report label conflicts")
}
syncReporter.OnStart(ctx)
var err error
for {
@ -104,6 +119,20 @@ func (t *Handler) Execute(
t.log.WithError(err).Error("Sync aborted")
break
} else if err = t.run(ctx, syncReporter, labels, updateApplier, messageBuilder); err != nil {
if db.IsUniqueLabelConstraintError(err) {
if sentryErr := t.sentryReporter.ReportMessageWithContext("Failed to sync due to label unique constraint conflict",
reporter.Context{"err": err}); sentryErr != nil {
t.log.WithError(sentryErr).Error("Failed to report label unique constraint conflict error to Sentry")
}
} else if !(errors.Is(err, context.Canceled)) {
if sentryErr := t.sentryReporter.ReportMessageWithContext("Failed to sync, will retry later", reporter.Context{
"err": err.Error(),
"user_id": t.userID,
}); sentryErr != nil {
t.log.WithError(sentryErr).Error("Failed to report sentry message")
}
}
t.log.WithError(err).Error("Failed to sync, will retry later")
sleepCtx(ctx, coolDown)
} else {
@ -138,11 +167,10 @@ func (t *Handler) run(ctx context.Context,
}
if syncStatus.IsComplete() {
t.log.Info("Sync already complete, only system labels will be updated")
if err := updateApplier.SyncSystemLabelsOnly(ctx, labels); err != nil {
t.log.WithError(err).Error("Failed to sync system labels")
t.log.Info("Sync already complete, updating labels")
if err := updateApplier.SyncLabels(ctx, labels); err != nil {
t.log.WithError(err).Error("Failed to sync labels")
return err
}

View File

@ -25,6 +25,7 @@ import (
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/sentry"
"github.com/bradenaw/juniper/xmaps"
"github.com/golang/mock/gomock"
"github.com/sirupsen/logrus"
@ -74,8 +75,7 @@ func TestTask_NoStateAndSucceeds(t *testing.T) {
}
{
call1 := tt.updateApplier.EXPECT().SyncLabels(gomock.Any(), gomock.Eq(labels)).Times(1).Return(nil)
tt.updateApplier.EXPECT().SyncSystemLabelsOnly(gomock.Any(), gomock.Eq(labels)).After(call1).Times(1).Return(nil)
tt.updateApplier.EXPECT().SyncLabels(gomock.Any(), gomock.Eq(labels)).Times(2).Return(nil)
}
{
@ -203,12 +203,19 @@ func TestTask_StateHasSyncedState(t *testing.T) {
}, nil
})
tt.updateApplier.EXPECT().SyncSystemLabelsOnly(gomock.Any(), gomock.Eq(labels)).Return(nil)
tt.updateApplier.EXPECT().SyncLabels(gomock.Any(), gomock.Eq(labels)).Return(nil)
err := tt.task.run(context.Background(), tt.syncReporter, labels, tt.updateApplier, tt.messageBuilder)
require.NoError(t, err)
}
type mockLabelConflictChecker struct {
}
func (m *mockLabelConflictChecker) CheckAndReportConflicts(_ context.Context, _ map[string]proton.Label) error {
return nil
}
func TestTask_RepeatsOnSyncFailure(t *testing.T) {
const MessageTotal int64 = 50
const MessageID string = "foo"
@ -272,7 +279,7 @@ func TestTask_RepeatsOnSyncFailure(t *testing.T) {
tt.syncReporter.EXPECT().OnFinished(gomock.Any())
tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta))
tt.task.Execute(tt.syncReporter, labels, tt.updateApplier, tt.messageBuilder, time.Microsecond)
tt.task.Execute(tt.syncReporter, labels, tt.updateApplier, tt.messageBuilder, time.Microsecond, &mockLabelConflictChecker{})
require.NoError(t, <-tt.task.OnSyncFinishedCH())
}
@ -343,7 +350,7 @@ func newTestHandler(mockCtrl *gomock.Controller, userID string) thandler { // no
client := NewMockAPIClient(mockCtrl)
messageBuilder := NewMockMessageBuilder(mockCtrl)
syncReporter := NewMockReporter(mockCtrl)
task := NewHandler(regulator, client, userID, syncState, logrus.WithField("test", "test"), &async.NoopPanicHandler{})
task := NewHandler(regulator, client, userID, syncState, logrus.WithField("test", "test"), &async.NoopPanicHandler{}, sentry.NullSentryReporter{})
return thandler{
task: task,

View File

@ -80,7 +80,6 @@ type MessageBuilder interface {
type UpdateApplier interface {
ApplySyncUpdates(ctx context.Context, updates []BuildResult) error
SyncSystemLabelsOnly(ctx context.Context, labels map[string]proton.Label) error
SyncLabels(ctx context.Context, labels map[string]proton.Label) error
}

View File

@ -548,20 +548,6 @@ func (mr *MockUpdateApplierMockRecorder) SyncLabels(arg0, arg1 interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncLabels", reflect.TypeOf((*MockUpdateApplier)(nil).SyncLabels), arg0, arg1)
}
// SyncSystemLabelsOnly mocks base method.
func (m *MockUpdateApplier) SyncSystemLabelsOnly(arg0 context.Context, arg1 map[string]proton.Label) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SyncSystemLabelsOnly", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SyncSystemLabelsOnly indicates an expected call of SyncSystemLabelsOnly.
func (mr *MockUpdateApplierMockRecorder) SyncSystemLabelsOnly(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncSystemLabelsOnly", reflect.TypeOf((*MockUpdateApplier)(nil).SyncSystemLabelsOnly), arg0, arg1)
}
// MockMessageBuilder is a mock of MessageBuilder interface.
type MockMessageBuilder struct {
ctrl *gomock.Controller

View File

@ -29,6 +29,8 @@ import (
"time"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/db"
"github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/gluon/watcher"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal"
@ -70,6 +72,8 @@ type Service struct {
eventPollWaitersLock sync.Mutex
eventSubscription events.Subscription
eventWatcher *watcher.Watcher[events.Event]
sentryReporter reporter.Reporter
}
func NewService(
@ -82,6 +86,7 @@ func NewService(
eventTimeout time.Duration,
panicHandler async.PanicHandler,
eventSubscription events.Subscription,
sentryReporter reporter.Reporter,
) *Service {
return &Service{
cpc: cpc.NewCPC(),
@ -99,6 +104,7 @@ func NewService(
panicHandler: panicHandler,
eventSubscription: eventSubscription,
eventWatcher: eventSubscription.Add(events.ConnStatusDown{}, events.ConnStatusUp{}),
sentryReporter: sentryReporter,
}
}
@ -414,6 +420,14 @@ func (s *Service) handleEventError(ctx context.Context, lastEventID string, even
return subscriberName, fmt.Errorf("failed to handle event due to server error: %w", err)
}
if db.IsUniqueLabelConstraintError(err) {
if err := s.sentryReporter.ReportMessageWithContext("Unique label constraint error occurred on event", reporter.Context{
"err": err,
}); err != nil {
s.log.WithError(err).Error("Failed to report label constraint error to sentry")
}
}
// Otherwise, the error is a client-side issue; notify bridge to handle it.
s.log.WithField("event", event).Warn("Failed to handle API event")

View File

@ -30,6 +30,7 @@ import (
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/events/mocks"
"github.com/ProtonMail/proton-bridge/v3/internal/sentry"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
@ -49,6 +50,7 @@ func TestServiceHandleEventError_SubscriberEventUnwrapping(t *testing.T) {
time.Second,
async.NoopPanicHandler{},
events.NewNullSubscription(),
sentry.NullSentryReporter{},
)
lastEventID := "PrevEvent"
@ -87,6 +89,7 @@ func TestServiceHandleEventError_BadEventPutsServiceOnPause(t *testing.T) {
time.Second,
async.NoopPanicHandler{},
events.NewNullSubscription(),
sentry.NullSentryReporter{},
)
service.Resume()
lastEventID := "PrevEvent"
@ -121,6 +124,7 @@ func TestServiceHandleEventError_BadEventFromPublishTimeout(t *testing.T) {
time.Second,
async.NoopPanicHandler{},
events.NewNullSubscription(),
sentry.NullSentryReporter{},
)
lastEventID := "PrevEvent"
event := proton.Event{EventID: "MyEvent"}
@ -152,6 +156,7 @@ func TestServiceHandleEventError_NoBadEventCheck(t *testing.T) {
time.Second,
async.NoopPanicHandler{},
events.NewNullSubscription(),
sentry.NullSentryReporter{},
)
lastEventID := "PrevEvent"
event := proton.Event{EventID: "MyEvent"}
@ -178,6 +183,7 @@ func TestServiceHandleEventError_JsonUnmarshalEventProducesUncategorizedErrorEve
time.Second,
async.NoopPanicHandler{},
events.NewNullSubscription(),
sentry.NullSentryReporter{},
)
lastEventID := "PrevEvent"
event := proton.Event{EventID: "MyEvent"}

View File

@ -28,6 +28,7 @@ import (
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/events/mocks"
"github.com/ProtonMail/proton-bridge/v3/internal/sentry"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
@ -69,6 +70,7 @@ func TestServiceHandleEvent_CheckEventCategoriesHandledInOrder(t *testing.T) {
10*time.Second,
async.NoopPanicHandler{},
events.NewNullSubscription(),
sentry.NullSentryReporter{},
)
subscription := NewCallbackSubscriber("test", EventHandler{
@ -130,6 +132,7 @@ func TestServiceHandleEvent_CheckEventFailureCausesError(t *testing.T) {
time.Second,
async.NoopPanicHandler{},
events.NewNullSubscription(),
sentry.NullSentryReporter{},
)
subscription := NewCallbackSubscriber("test", EventHandler{
@ -168,6 +171,7 @@ func TestServiceHandleEvent_CheckEventFailureCausesErrorParallel(t *testing.T) {
time.Second,
async.NoopPanicHandler{},
events.NewNullSubscription(),
sentry.NullSentryReporter{},
)
subscription := NewCallbackSubscriber("test", EventHandler{

View File

@ -27,6 +27,7 @@ import (
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
mocks2 "github.com/ProtonMail/proton-bridge/v3/internal/events/mocks"
"github.com/ProtonMail/proton-bridge/v3/internal/sentry"
"github.com/ProtonMail/proton-bridge/v3/internal/services/orderedtasks"
"github.com/ProtonMail/proton-bridge/v3/internal/services/userevents/mocks"
"github.com/golang/mock/gomock"
@ -76,6 +77,7 @@ func TestService_EventIDLoadStore(t *testing.T) {
time.Second,
async.NoopPanicHandler{},
events.NewNullSubscription(),
sentry.NullSentryReporter{},
)
_, err := service.Start(context.Background(), group)
@ -132,6 +134,7 @@ func TestService_RetryEventOnNonCatastrophicFailure(t *testing.T) {
time.Second,
async.NoopPanicHandler{},
events.NewNullSubscription(),
sentry.NullSentryReporter{},
)
service.Subscribe(NewCallbackSubscriber("foo", EventHandler{MessageHandler: subscriber}))
@ -182,6 +185,7 @@ func TestService_OnBadEventServiceIsPaused(t *testing.T) {
time.Second,
async.NoopPanicHandler{},
events.NewNullSubscription(),
sentry.NullSentryReporter{},
)
// Event publisher expectations.
@ -249,6 +253,7 @@ func TestService_UnsubscribeDuringEventHandlingDoesNotCauseDeadlock(t *testing.T
time.Second,
async.NoopPanicHandler{},
events.NewNullSubscription(),
sentry.NullSentryReporter{},
)
subscription := NewCallbackSubscriber("foo", EventHandler{MessageHandler: subscriber})
@ -309,6 +314,7 @@ func TestService_UnsubscribeBeforeHandlingEventIsNotConsideredError(t *testing.T
time.Second,
async.NoopPanicHandler{},
events.NewNullSubscription(),
sentry.NullSentryReporter{},
)
subscription := NewEventSubscriber("Foo")
@ -369,6 +375,7 @@ func TestService_WaitOnEventPublishAfterPause(t *testing.T) {
time.Second,
async.NoopPanicHandler{},
events.NewNullSubscription(),
sentry.NullSentryReporter{},
)
subscriber.EXPECT().HandleMessageEvents(gomock.Any(), gomock.Eq(messageEvents)).Times(1).DoAndReturn(func(_ context.Context, _ []proton.MessageEvent) error {
@ -442,6 +449,7 @@ func TestService_EventRewind(t *testing.T) {
time.Second,
async.NoopPanicHandler{},
events.NewNullSubscription(),
sentry.NullSentryReporter{},
)
_, err := service.Start(context.Background(), group)

View File

@ -37,14 +37,32 @@ var pollJitter = 2 * time.Minute //nolint:gochecknoglobals
const filename = "unleash_flags"
const (
EventLoopNotificationDisabled = "InboxBridgeEventLoopNotificationDisabled"
IMAPAuthenticateCommandDisabled = "InboxBridgeImapAuthenticateCommandDisabled"
UserRemovalGluonDataCleanupDisabled = "InboxBridgeUserRemovalGluonDataCleanupDisabled"
EventLoopNotificationDisabled = "InboxBridgeEventLoopNotificationDisabled"
IMAPAuthenticateCommandDisabled = "InboxBridgeImapAuthenticateCommandDisabled"
UserRemovalGluonDataCleanupDisabled = "InboxBridgeUserRemovalGluonDataCleanupDisabled"
UpdateUseNewVersionFileStructureDisabled = "InboxBridgeUpdateWithOsFilterDisabled"
LabelConflictResolverDisabled = "InboxBridgeLabelConflictResolverDisabled"
SMTPSubmissionRequestSentryReportDisabled = "InboxBridgeSmtpSubmissionRequestSentryReportDisabled"
InternalLabelConflictResolverDisabled = "InboxBridgeUnexpectedFoldersLabelsStartupFixupDisabled"
ItnternalLabelConflictNonEmptyMailboxDeletion = "InboxBridgeUnknownNonEmptyMailboxDeletion"
)
type requestFeaturesFn func(ctx context.Context) (proton.FeatureFlagResult, error)
type GetFlagValueFn func(key string) bool
type FeatureFlagValueProvider interface {
GetFlagValue(key string) bool
}
// NullUnleashService - mock of the unleash service. Should be used for testing.
type NullUnleashService struct{}
func (n NullUnleashService) GetFlagValue(_ string) bool {
return false
}
func NewNullUnleashService() *NullUnleashService {
return &NullUnleashService{}
}
type requestFeaturesFn func(ctx context.Context) (proton.FeatureFlagResult, error)
type Service struct {
panicHandler async.PanicHandler
timer *proton.Ticker

View File

@ -0,0 +1,255 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package updater
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func Test_ReleaseCategory_UpdateEligible(t *testing.T) {
// If release is beta only beta users can update
require.True(t, EarlyAccessReleaseCategory.UpdateEligible(EarlyChannel))
require.False(t, EarlyAccessReleaseCategory.UpdateEligible(StableChannel))
// If the release is stable and is the newest then both beta and stable users can update
require.True(t, StableReleaseCategory.UpdateEligible(EarlyChannel))
require.True(t, StableReleaseCategory.UpdateEligible(StableChannel))
}
func Test_ReleaseCategory_JsonUnmarshal(t *testing.T) {
tests := []struct {
input string
expected ReleaseCategory
wantErr bool
}{
{
input: `{"ReleaseCategory": "EarlyAccess"}`,
expected: EarlyAccessReleaseCategory,
},
{
input: `{"ReleaseCategory": "Earlyaccess"}`,
expected: EarlyAccessReleaseCategory,
},
{
input: `{"ReleaseCategory": "earlyaccess"}`,
expected: EarlyAccessReleaseCategory,
},
{
input: `{"ReleaseCategory": " earlyaccess "}`,
expected: EarlyAccessReleaseCategory,
},
{
input: `{"ReleaseCategory": "Stable"}`,
expected: StableReleaseCategory,
},
{
input: `{"ReleaseCategory": "Stable "}`,
expected: StableReleaseCategory,
},
{
input: `{"ReleaseCategory": "stable"}`,
expected: StableReleaseCategory,
},
{
input: `{"ReleaseCategory": "invalid"}`,
wantErr: true,
},
}
var data struct {
ReleaseCategory ReleaseCategory
}
for _, test := range tests {
err := json.Unmarshal([]byte(test.input), &data)
if err != nil && !test.wantErr {
t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, test.wantErr)
return
}
if test.wantErr && err == nil {
t.Errorf("expected err got nil")
}
if !test.wantErr && data.ReleaseCategory != test.expected {
t.Errorf("got %v, want %v", data.ReleaseCategory, test.expected)
}
}
}
func Test_ReleaseCategory_JsonMarshal(t *testing.T) {
tests := []struct {
input struct {
ReleaseCategory ReleaseCategory `json:"ReleaseCategory"`
}
expectedOutput string
wantErr bool
}{
{
input: struct {
ReleaseCategory ReleaseCategory `json:"ReleaseCategory"`
}{ReleaseCategory: StableReleaseCategory},
expectedOutput: `{"ReleaseCategory":"Stable"}`,
},
{
input: struct {
ReleaseCategory ReleaseCategory `json:"ReleaseCategory"`
}{ReleaseCategory: EarlyAccessReleaseCategory},
expectedOutput: `{"ReleaseCategory":"EarlyAccess"}`,
},
{
input: struct {
ReleaseCategory ReleaseCategory `json:"ReleaseCategory"`
}{ReleaseCategory: 4},
wantErr: true,
},
}
for _, test := range tests {
output, err := json.Marshal(test.input)
if test.wantErr {
if err == nil && len(output) == 0 {
t.Errorf("expected error or non-empty output for invalid category")
return
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if string(output) != test.expectedOutput {
t.Errorf("json.Marshal() = %v, want %v", string(output), test.expectedOutput)
}
}
}
}
func Test_FileIdentifier_JsonUnmarshal(t *testing.T) {
tests := []struct {
input string
expected FileIdentifier
wantErr bool
}{
{
input: `{"Identifier": "package"}`,
expected: PackageIdentifier,
},
{
input: `{"Identifier": "Package"}`,
expected: PackageIdentifier,
},
{
input: `{"Identifier": "pACKage"}`,
expected: PackageIdentifier,
},
{
input: `{"Identifier": "pACKage "}`,
expected: PackageIdentifier,
},
{
input: `{"Identifier": "installer"}`,
expected: InstallerIdentifier,
},
{
input: `{"Identifier": "Installer"}`,
expected: InstallerIdentifier,
},
{
input: `{"Identifier": "iNSTaller "}`,
expected: InstallerIdentifier,
},
{
input: `{"Identifier": "error"}`,
wantErr: true,
},
}
var data struct {
Identifier FileIdentifier
}
for _, test := range tests {
err := json.Unmarshal([]byte(test.input), &data)
if err != nil && !test.wantErr {
t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, test.wantErr)
return
}
if test.wantErr && err == nil {
t.Errorf("expected err got nil")
}
if !test.wantErr && data.Identifier != test.expected {
t.Errorf("got %v, want %v", data.Identifier, test.expected)
}
}
}
func Test_FileIdentifier_JsonMarshal(t *testing.T) {
tests := []struct {
input struct {
Identifier FileIdentifier
}
expectedOutput string
wantErr bool
}{
{
input: struct {
Identifier FileIdentifier
}{Identifier: PackageIdentifier},
expectedOutput: `{"Identifier":"package"}`,
},
{
input: struct {
Identifier FileIdentifier
}{Identifier: InstallerIdentifier},
expectedOutput: `{"Identifier":"installer"}`,
},
{
input: struct {
Identifier FileIdentifier
}{Identifier: 4},
wantErr: true,
},
}
for _, test := range tests {
output, err := json.Marshal(test.input)
if test.wantErr {
if err == nil && len(output) == 0 {
t.Errorf("expected error or non-empty output for invalid identifier")
return
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if string(output) != test.expectedOutput {
t.Errorf("json.Marshal() = %v, want %v", string(output), test.expectedOutput)
}
}
}
}

View File

@ -0,0 +1,135 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package updater
import (
"encoding/json"
"fmt"
"strings"
)
type ReleaseCategory uint8
type FileIdentifier uint8
const (
EarlyAccessReleaseCategory ReleaseCategory = iota
StableReleaseCategory
)
const (
PackageIdentifier FileIdentifier = iota
InstallerIdentifier
)
var (
releaseCategoryName = map[uint8]string{ //nolint:gochecknoglobals
0: "EarlyAccess",
1: "Stable",
}
releaseCategoryValue = map[string]uint8{ //nolint:gochecknoglobals
"earlyaccess": 0,
"stable": 1,
}
fileIdentifierName = map[uint8]string{ //nolint:gochecknoglobals
0: "package",
1: "installer",
}
fileIdentifierValue = map[string]uint8{ //nolint:gochecknoglobals
"package": 0,
"installer": 1,
}
)
func ParseFileIdentifier(s string) (FileIdentifier, error) {
s = strings.TrimSpace(strings.ToLower(s))
val, ok := fileIdentifierValue[s]
if !ok {
return FileIdentifier(0), fmt.Errorf("%s is not a valid file identifier", s)
}
return FileIdentifier(val), nil
}
func (fi FileIdentifier) String() string {
return fileIdentifierName[uint8(fi)]
}
func (fi FileIdentifier) MarshalJSON() ([]byte, error) {
return json.Marshal(fi.String())
}
func (fi *FileIdentifier) UnmarshalJSON(data []byte) (err error) {
var fileIdentifier string
if err := json.Unmarshal(data, &fileIdentifier); err != nil {
return err
}
parsedFileIdentifier, err := ParseFileIdentifier(fileIdentifier)
if err != nil {
return err
}
*fi = parsedFileIdentifier
return nil
}
func ParseReleaseCategory(s string) (ReleaseCategory, error) {
s = strings.TrimSpace(strings.ToLower(s))
val, ok := releaseCategoryValue[s]
if !ok {
return ReleaseCategory(0), fmt.Errorf("%s is not a valid release category", s)
}
return ReleaseCategory(val), nil
}
func (rc ReleaseCategory) String() string {
return releaseCategoryName[uint8(rc)]
}
func (rc ReleaseCategory) MarshalJSON() ([]byte, error) {
return json.Marshal(rc.String())
}
func (rc *ReleaseCategory) UnmarshalJSON(data []byte) (err error) {
var releaseCat string
if err := json.Unmarshal(data, &releaseCat); err != nil {
return err
}
parsedCat, err := ParseReleaseCategory(releaseCat)
if err != nil {
return err
}
*rc = parsedCat
return nil
}
func (rc ReleaseCategory) UpdateEligible(channel Channel) bool {
if channel == StableChannel && rc == StableReleaseCategory {
return true
}
if channel == EarlyChannel && rc == EarlyAccessReleaseCategory || rc == StableReleaseCategory {
return true
}
return false
}

View File

@ -29,13 +29,17 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/versioner"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"golang.org/x/exp/slices"
)
const updateFileVersion = 1
var (
ErrDownloadVerify = errors.New("failed to download or verify the update")
ErrInstall = errors.New("failed to install the update")
ErrUpdateAlreadyInstalled = errors.New("update is already installed")
ErrVersionFileDownloadOrVerify = errors.New("failed to download or verify the version file")
ErrReleaseUpdatePackageMissing = errors.New("release update package is missing")
)
type Downloader interface {
@ -53,6 +57,7 @@ type Updater struct {
verifier *crypto.KeyRing
product string
platform string
version uint
}
func NewUpdater(ver *versioner.Versioner, verifier *crypto.KeyRing, product, platform string) *Updater {
@ -62,10 +67,36 @@ func NewUpdater(ver *versioner.Versioner, verifier *crypto.KeyRing, product, pla
verifier: verifier,
product: product,
platform: platform,
version: updateFileVersion,
}
}
func (u *Updater) GetVersionInfo(ctx context.Context, downloader Downloader, channel Channel) (VersionInfo, error) {
func (u *Updater) GetVersionInfoLegacy(ctx context.Context, downloader Downloader, channel Channel) (VersionInfoLegacy, error) {
b, err := downloader.DownloadAndVerify(
ctx,
u.verifier,
u.getVersionFileURLLegacy(),
u.getVersionFileURLLegacy()+".sig",
)
if err != nil {
return VersionInfoLegacy{}, fmt.Errorf("%w: %w", ErrVersionFileDownloadOrVerify, err)
}
var versionMap VersionMap
if err := json.Unmarshal(b, &versionMap); err != nil {
return VersionInfoLegacy{}, err
}
version, ok := versionMap[channel]
if !ok {
return VersionInfoLegacy{}, errors.New("no updates available for this channel")
}
return version, nil
}
func (u *Updater) GetVersionInfo(ctx context.Context, downloader Downloader) (VersionInfo, error) {
b, err := downloader.DownloadAndVerify(
ctx,
u.verifier,
@ -76,21 +107,16 @@ func (u *Updater) GetVersionInfo(ctx context.Context, downloader Downloader, cha
return VersionInfo{}, fmt.Errorf("%w: %w", ErrVersionFileDownloadOrVerify, err)
}
var versionMap VersionMap
var releases VersionInfo
if err := json.Unmarshal(b, &versionMap); err != nil {
if err := json.Unmarshal(b, &releases); err != nil {
return VersionInfo{}, err
}
version, ok := versionMap[channel]
if !ok {
return VersionInfo{}, errors.New("no updates available for this channel")
}
return version, nil
return releases, nil
}
func (u *Updater) InstallUpdate(ctx context.Context, downloader Downloader, update VersionInfo) error {
func (u *Updater) InstallUpdateLegacy(ctx context.Context, downloader Downloader, update VersionInfoLegacy) error {
if u.installer.IsAlreadyInstalled(update.Version) {
return ErrUpdateAlreadyInstalled
}
@ -113,13 +139,64 @@ func (u *Updater) InstallUpdate(ctx context.Context, downloader Downloader, upda
return nil
}
func (u *Updater) InstallUpdate(ctx context.Context, downloader Downloader, release Release) error {
if u.installer.IsAlreadyInstalled(release.Version) {
return ErrUpdateAlreadyInstalled
}
// Find update package
idx := slices.IndexFunc(release.File, func(file File) bool {
return file.Identifier == PackageIdentifier
})
if idx == -1 {
logrus.WithFields(logrus.Fields{
"release_version": release.Version,
}).Error("Update release does not contain update package")
return ErrReleaseUpdatePackageMissing
}
releaseUpdatePackage := release.File[idx]
b, err := downloader.DownloadAndVerify(
ctx,
u.verifier,
releaseUpdatePackage.URL,
releaseUpdatePackage.URL+".sig",
)
if err != nil {
return fmt.Errorf("%w: %w", ErrDownloadVerify, err)
}
if err := u.installer.InstallUpdate(release.Version, bytes.NewReader(b)); err != nil {
logrus.WithError(err).Error("Failed to install update")
return ErrInstall
}
return nil
}
func (u *Updater) RemoveOldUpdates() error {
return u.versioner.RemoveOldVersions()
}
// getVersionFileURL returns the URL of the version file.
// getVersionFileURLLegacy returns the URL of the version file.
// For example:
// - https://protonmail.com/download/bridge/version_linux.json
func (u *Updater) getVersionFileURL() string {
func (u *Updater) getVersionFileURLLegacy() string {
return fmt.Sprintf("%v/%v/version_%v.json", Host, u.product, u.platform)
}
// getVersionFileURL returns the URL of the version file.
// For example:
// - https://protonmail.com/download/windows/x86/v1/version.json
// - https://protonmail.com/download/linux/x86/v1/version.json
// - https://protonmail.com/download/darwin/universal/v1/version.json
func (u *Updater) getVersionFileURL() string {
switch u.platform {
case "darwin":
return fmt.Sprintf("%v/%v/%v/universal/v%v/version.json", Host, u.product, u.platform, u.version)
default:
return fmt.Sprintf("%v/%v/%v/x86/v%v/version.json", Host, u.product, u.platform, u.version)
}
}

View File

@ -19,10 +19,36 @@ package updater
import (
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v3/internal/updater/versioncompare"
)
// VersionInfo is information about one version of the app.
type File struct {
URL string `json:"Url"`
Sha512CheckSum string `json:"Sha512CheckSum,omitempty"`
Identifier FileIdentifier `json:"Identifier"`
}
type Release struct {
ReleaseCategory ReleaseCategory `json:"CategoryName"`
Version *semver.Version
SystemVersion versioncompare.SystemVersion `json:"SystemVersion,omitempty"`
RolloutProportion float64
MinAuto *semver.Version `json:"MinAuto,omitempty"`
ReleaseNotesPage string
LandingPage string
File []File `json:"File"`
}
func (rel Release) IsEmpty() bool {
return rel.Version == nil && len(rel.File) == 0
}
type VersionInfo struct {
Releases []Release `json:"Releases"`
}
// VersionInfoLegacy is information about one version of the app.
type VersionInfoLegacy struct {
// Version is the semantic version of the release.
Version *semver.Version
@ -46,6 +72,10 @@ type VersionInfo struct {
RolloutProportion float64
}
func (verInfo VersionInfoLegacy) IsEmpty() bool {
return verInfo.Version == nil && verInfo.ReleaseNotesPage == ""
}
// VersionMap represents the structure of the version.json file.
// It looks like this:
//
@ -79,4 +109,4 @@ type VersionInfo struct {
// }
// }.
type VersionMap map[Channel]VersionInfo
type VersionMap map[Channel]VersionInfoLegacy

View File

@ -0,0 +1,205 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package updater
import (
"encoding/json"
"testing"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/proton-bridge/v3/internal/updater/versioncompare"
)
var mockJSONData = `
{
"Releases": [
{
"CategoryName": "Stable",
"Version": "2.1.0",
"ReleaseDate": "2025-01-15T08:00:00Z",
"File": [
{
"Url": "https://downloads.example.com/v2.1.0/MyApp-2.1.0.pkg",
"Sha512CheckSum": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
"Identifier": "package"
},
{
"Url": "https://downloads.example.com/v2.1.0/MyApp-2.1.0.dmg",
"Sha512CheckSum": "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce",
"Identifier": "installer"
}
],
"RolloutProportion": 0.5,
"MinAuto": "2.0.0",
"Commit": "8f52d45c9f8c31aa391315ea24e40c4a7e0b2c1d",
"ReleaseNotesPage": "https://example.com/releases/2.1.0/notes",
"LandingPage": "https://example.com/releases/2.1.0"
},
{
"CategoryName": "EarlyAccess",
"Version": "2.2.0-beta.1",
"ReleaseDate": "2025-01-20T10:00:00Z",
"File": [
{
"Url": "https://downloads.example.com/beta/v2.2.0-beta.1/MyApp-2.2.0-beta.1.pkg",
"Sha512CheckSum": "a9f0e44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
"Identifier": "package"
}
],
"SystemVersion": {
"Minimum": "13"
},
"RolloutProportion": 0.25,
"MinAuto": "2.1.0",
"Commit": "3e72d45c9f8c31aa391315ea24e40c4a7e0b2c1d",
"ReleaseNotesPage": "https://example.com/releases/2.2.0-beta.1/notes",
"LandingPage": "https://example.com/releases/2.2.0-beta.1"
},
{
"CategoryName": "Stable",
"Version": "2.0.0",
"ReleaseDate": "2024-12-01T09:00:00Z",
"File": [
{
"Url": "https://downloads.example.com/v2.0.0/MyApp-2.0.0.pkg",
"Sha512CheckSum": "b5f0e44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
"Identifier": "package"
},
{
"Url": "https://downloads.example.com/v2.0.0/MyApp-2.0.0.dmg",
"Sha512CheckSum": "d583e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce",
"Identifier": "installer"
}
],
"SystemVersion": {
"Maximum": "12.0.0",
"Minimum": "1.0.0"
},
"RolloutProportion": 1.0,
"MinAuto": "1.9.0",
"Commit": "2a42d45c9f8c31aa391315ea24e40c4a7e0b2c1d",
"ReleaseNotesPage": "https://example.com/releases/2.0.0/notes",
"LandingPage": "https://example.com/releases/2.0.0"
}
]
}
`
var expectedVersionInfo = VersionInfo{
Releases: []Release{
{
ReleaseCategory: StableReleaseCategory,
Version: semver.MustParse("2.1.0"),
RolloutProportion: 0.5,
MinAuto: semver.MustParse("2.0.0"),
File: []File{
{
URL: "https://downloads.example.com/v2.1.0/MyApp-2.1.0.pkg",
Identifier: PackageIdentifier,
},
{
URL: "https://downloads.example.com/v2.1.0/MyApp-2.1.0.dmg",
Identifier: InstallerIdentifier,
},
},
},
{
ReleaseCategory: EarlyAccessReleaseCategory,
Version: semver.MustParse("2.2.0-beta.1"),
RolloutProportion: 0.25,
MinAuto: semver.MustParse("2.1.0"),
File: []File{
{
URL: "https://downloads.example.com/beta/v2.2.0-beta.1/MyApp-2.2.0-beta.1.pkg",
Identifier: PackageIdentifier,
},
},
SystemVersion: versioncompare.SystemVersion{Minimum: "13"},
},
{
ReleaseCategory: StableReleaseCategory,
Version: semver.MustParse("2.0.0"),
RolloutProportion: 1.0,
MinAuto: semver.MustParse("1.9.0"),
SystemVersion: versioncompare.SystemVersion{Maximum: "12.0.0", Minimum: "1.0.0"},
File: []File{
{
URL: "https://downloads.example.com/v2.0.0/MyApp-2.0.0.pkg",
Identifier: PackageIdentifier,
},
{
URL: "https://downloads.example.com/v2.0.0/MyApp-2.0.0.dmg",
Identifier: InstallerIdentifier,
},
},
},
},
}
func Test_Releases_JsonParse(t *testing.T) {
var versionInfo VersionInfo
if err := json.Unmarshal([]byte(mockJSONData), &versionInfo); err != nil {
t.Fatalf("Failed to parse JSON: %v", err)
}
if len(expectedVersionInfo.Releases) != len(versionInfo.Releases) {
t.Fatalf("expected %d releases, parsed %d releases", len(expectedVersionInfo.Releases), len(versionInfo.Releases))
}
for i, expectedRelease := range expectedVersionInfo.Releases {
release := versionInfo.Releases[i]
if release.ReleaseCategory != expectedRelease.ReleaseCategory {
t.Errorf("Release %d: expected category %v, got %v", i, expectedRelease.ReleaseCategory, release.ReleaseCategory)
}
if release.Version.String() != expectedRelease.Version.String() {
t.Errorf("Release %d: expected version %s, got %s", i, expectedRelease.Version, release.Version)
}
if release.RolloutProportion != expectedRelease.RolloutProportion {
t.Errorf("Release %d: expected rollout proportion %f, got %f", i, expectedRelease.RolloutProportion, release.RolloutProportion)
}
if expectedRelease.MinAuto != nil && release.MinAuto.String() != expectedRelease.MinAuto.String() {
t.Errorf("Release %d: expected min auto %s, got %s", i, expectedRelease.MinAuto, release.MinAuto)
}
if expectedRelease.SystemVersion.Minimum != release.SystemVersion.Minimum {
t.Errorf("Release %d: expected system version minimum %s, got %s", i, expectedRelease.SystemVersion.Minimum, release.SystemVersion.Minimum)
}
if expectedRelease.SystemVersion.Maximum != release.SystemVersion.Maximum {
t.Errorf("Release %d: expected system version minimum %s, got %s", i, expectedRelease.SystemVersion.Maximum, release.SystemVersion.Maximum)
}
if len(release.File) != len(expectedRelease.File) {
t.Errorf("Release %d: expected %d files, got %d", i, len(expectedRelease.File), len(release.File))
}
for j, expectedFile := range expectedRelease.File {
file := release.File[j]
if file.URL != expectedFile.URL {
t.Errorf("Release %d, File %d: expected URL %s, got %s", i, j, expectedFile.URL, file.URL)
}
if file.Identifier != expectedFile.Identifier {
t.Errorf("Release %d, File %d: expected Identifier %v, got %v", i, j, expectedFile.Identifier, file.Identifier)
}
}
}
}

View File

@ -0,0 +1,134 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
//go:build darwin
package versioncompare
import (
"fmt"
"strconv"
"strings"
"github.com/elastic/go-sysinfo/types"
"github.com/sirupsen/logrus"
)
func (sysVer SystemVersion) IsHostVersionEligible(log *logrus.Entry, host types.Host, getHostOSVersion func(host types.Host) string) (bool, error) {
if sysVer.Minimum == "" && sysVer.Maximum == "" {
return true, nil
}
// We use getHostOSVersion simply for testing; It's passed via Bridge.
var hostVersion string
if getHostOSVersion == nil {
hostVersion = host.Info().OS.Version
} else {
hostVersion = getHostOSVersion(host)
}
log.Debugf("Checking host OS and update system version requirements. Host: %s; Maximum: %s; Minimum: %s",
hostVersion, sysVer.Maximum, sysVer.Minimum)
hostVersionArr := strings.Split(hostVersion, ".")
if len(hostVersionArr) == 0 || hostVersion == "" {
return true, fmt.Errorf("could not get host version: %v", hostVersion)
}
hostVersionArrInt := make([]int, len(hostVersionArr))
for i := 0; i < len(hostVersionArr); i++ {
hostNum, err := strconv.Atoi(hostVersionArr[i])
if err != nil {
// If we receive an alphanumeric version - we should continue with the update and stop checking for
// OS version requirements.
return true, fmt.Errorf("invalid host version number: %s - %s", hostVersionArr[i], hostVersion)
}
hostVersionArrInt[i] = hostNum
}
if sysVer.Minimum != "" {
pass, err := compareMinimumVersion(hostVersionArrInt, sysVer.Minimum)
if err != nil {
return false, err
}
if !pass {
return false, fmt.Errorf("host version is below minimum: hostVersion %v - minimumVersion %v", hostVersion, sysVer.Minimum)
}
}
if sysVer.Maximum != "" {
pass, err := compareMaximumVersion(hostVersionArrInt, sysVer.Maximum)
if err != nil {
return false, err
}
if !pass {
return false, fmt.Errorf("host version is above maximum version: hostVersion %v - minimumVersion %v", hostVersion, sysVer.Maximum)
}
}
return true, nil
}
func compareMinimumVersion(hostVersionArr []int, minVersion string) (bool, error) {
minVersionArr := strings.Split(minVersion, ".")
iterationDepth := min(len(hostVersionArr), len(minVersionArr))
for i := 0; i < iterationDepth; i++ {
hostNum := hostVersionArr[i]
minNum, err := strconv.Atoi(minVersionArr[i])
if err != nil {
return false, fmt.Errorf("invalid minimum version number: %s - %s", minVersionArr[i], minVersion)
}
if hostNum < minNum {
return false, nil
}
if hostNum > minNum {
return true, nil
}
}
return true, nil // minVersion is inclusive
}
func compareMaximumVersion(hostVersionArr []int, maxVersion string) (bool, error) {
maxVersionArr := strings.Split(maxVersion, ".")
iterationDepth := min(len(maxVersionArr), len(hostVersionArr))
for i := 0; i < iterationDepth; i++ {
hostNum := hostVersionArr[i]
maxNum, err := strconv.Atoi(maxVersionArr[i])
if err != nil {
return false, fmt.Errorf("invalid maximum version number: %s - %s", maxVersionArr[i], maxVersion)
}
if hostNum > maxNum {
return false, nil
}
if hostNum < maxNum {
return true, nil
}
}
return true, nil // maxVersion is inclusive
}

View File

@ -0,0 +1,105 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
//go:build darwin
package versioncompare
import (
"testing"
"github.com/elastic/go-sysinfo"
"github.com/elastic/go-sysinfo/types"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
func Test_IsHost_EligibleDarwin(t *testing.T) {
host, err := sysinfo.Host()
require.NoError(t, err)
testData := []struct {
sysVer SystemVersion
getHostOsVersionFn func(host types.Host) string
shouldContinue bool
wantErr bool
}{
{
sysVer: SystemVersion{Minimum: "9.5", Maximum: "12.0"},
getHostOsVersionFn: func(_ types.Host) string { return "10.0" },
shouldContinue: true,
},
{
sysVer: SystemVersion{Minimum: "9.5.5.5", Maximum: "10.1.1.0"},
getHostOsVersionFn: func(_ types.Host) string { return "10.0" },
shouldContinue: true,
},
{
sysVer: SystemVersion{Minimum: "10.0.1", Maximum: "12.0"},
getHostOsVersionFn: func(_ types.Host) string { return "10.0" },
shouldContinue: true,
},
{
sysVer: SystemVersion{Minimum: "11.0", Maximum: "12.0"},
getHostOsVersionFn: func(_ types.Host) string { return "10.0" },
shouldContinue: false,
wantErr: true,
},
{
sysVer: SystemVersion{Minimum: "11.1.0", Maximum: "12.0.0"},
getHostOsVersionFn: func(_ types.Host) string { return "11.0.0" },
shouldContinue: false,
wantErr: true,
},
{
sysVer: SystemVersion{Minimum: "10.0", Maximum: "12.0"},
getHostOsVersionFn: func(_ types.Host) string { return "12.0" },
shouldContinue: true,
},
{
sysVer: SystemVersion{Minimum: "11.1.0", Maximum: "12.0.0"},
getHostOsVersionFn: func(_ types.Host) string { return "" },
shouldContinue: true,
wantErr: true,
},
{
sysVer: SystemVersion{Minimum: "11.1.0", Maximum: "12.0.0"},
getHostOsVersionFn: func(_ types.Host) string { return "a.b.c" },
shouldContinue: true,
wantErr: true,
},
{
sysVer: SystemVersion{},
getHostOsVersionFn: func(_ types.Host) string { return "1.2.3" },
shouldContinue: true,
wantErr: false,
},
}
for _, test := range testData {
l := logrus.WithField("test", "test")
shouldContinue, err := test.sysVer.IsHostVersionEligible(l, host, test.getHostOsVersionFn)
if test.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
require.Equal(t, test.shouldContinue, shouldContinue)
}
}

View File

@ -0,0 +1,31 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
//go:build linux
package versioncompare
import (
"github.com/elastic/go-sysinfo/types"
"github.com/sirupsen/logrus"
)
// IsHostVersionEligible - Checks whether host OS version is eligible for update. Defaults to true on Linux.
func (sysVer SystemVersion) IsHostVersionEligible(log *logrus.Entry, _ types.Host, _ func(host types.Host) string) (bool, error) {
log.Info("Checking host OS version on Linux. Defaulting to true.")
return true, nil
}

View File

@ -0,0 +1,31 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
//go:build windows
package versioncompare
import (
"github.com/elastic/go-sysinfo/types"
"github.com/sirupsen/logrus"
)
// IsHostVersionEligible - Checks whether host OS version is eligible for update. Defaults to true on Linux.
func (sysVer SystemVersion) IsHostVersionEligible(log *logrus.Entry, _ types.Host, _ func(host types.Host) string) (bool, error) {
log.Info("Checking host OS version on Windows. Defaulting to true.")
return true, nil
}

View File

@ -0,0 +1,29 @@
// Copyright (c) 2025 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package versioncompare
import "fmt"
type SystemVersion struct {
Minimum string `json:"Minimum,omitempty"`
Maximum string `json:"Maximum,omitempty"`
}
func (sysVer SystemVersion) String() string {
return fmt.Sprintf("SystemVersion: Maximum %s, Minimum %s", sysVer.Maximum, sysVer.Minimum)
}

View File

@ -110,7 +110,7 @@ func New(
syncConfigDir string,
isNew bool,
notificationStore *notifications.Store,
getFlagValFn unleash.GetFlagValueFn,
featureFlagValueProvider unleash.FeatureFlagValueProvider,
) (*User, error) {
user, err := newImpl(
ctx,
@ -130,7 +130,7 @@ func New(
syncConfigDir,
isNew,
notificationStore,
getFlagValFn,
featureFlagValueProvider,
)
if err != nil {
// Cleanup any pending resources on error
@ -163,7 +163,7 @@ func newImpl(
syncConfigDir string,
isNew bool,
notificationStore *notifications.Store,
getFlagValueFn unleash.GetFlagValueFn,
featureFlagValueProvider unleash.FeatureFlagValueProvider,
) (*User, error) {
logrus.WithField("userID", apiUser.ID).Info("Creating new user")
@ -241,6 +241,7 @@ func newImpl(
5*time.Minute,
crashHandler,
eventSubscription,
reporter,
)
addressMode := usertypes.VaultToAddressMode(encVault.AddressMode())
@ -262,6 +263,8 @@ func newImpl(
identityState.Clone(),
smtpServerManager,
observabilityService,
imapServerManager,
featureFlagValueProvider,
)
user.imapService = imapservice.NewService(
@ -282,9 +285,10 @@ func newImpl(
user.maxSyncMemory,
showAllMail,
observabilityService,
featureFlagValueProvider,
)
user.notificationService = notifications.NewService(user.id, user.eventService, user, notificationStore, getFlagValueFn, observabilityService)
user.notificationService = notifications.NewService(user.id, user.eventService, user, notificationStore, featureFlagValueProvider, observabilityService)
// When we receive an auth object, we update it in the vault.
// This will be used to authorize the user on the next run.
@ -739,7 +743,7 @@ func (user *User) protonAddresses() []proton.Address {
}
addresses := xslices.Filter(maps.Values(apiAddrs), func(addr proton.Address) bool {
return addr.Status == proton.AddressStatusEnabled && addr.Type != proton.AddressTypeExternal
return addr.Status == proton.AddressStatusEnabled && (addr.IsBYOEAddress() || addr.Type != proton.AddressTypeExternal)
})
slices.SortFunc(addresses, func(a, b proton.Address) bool {

View File

@ -28,11 +28,13 @@ import (
"github.com/ProtonMail/go-proton-api/server/backend"
"github.com/ProtonMail/proton-bridge/v3/internal/certs"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/sentry"
"github.com/ProtonMail/proton-bridge/v3/internal/services/imapservice"
"github.com/ProtonMail/proton-bridge/v3/internal/services/notifications"
"github.com/ProtonMail/proton-bridge/v3/internal/services/observability"
"github.com/ProtonMail/proton-bridge/v3/internal/services/smtp"
"github.com/ProtonMail/proton-bridge/v3/internal/telemetry/mocks"
"github.com/ProtonMail/proton-bridge/v3/internal/unleash"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/ProtonMail/proton-bridge/v3/tests"
"github.com/golang/mock/gomock"
@ -110,7 +112,7 @@ func withAccount(tb testing.TB, s *server.Server, username, password string, ali
addrIDs := []string{addrID}
for _, email := range aliases {
addrID, err := s.CreateAddress(userID, email, []byte(password))
addrID, err := s.CreateAddress(userID, email, []byte(password), true)
require.NoError(tb, err)
require.NoError(tb, s.ChangeAddressDisplayName(userID, addrID, email+" (Display Name)"))
@ -150,12 +152,13 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma
nullEventSubscription := events.NewNullSubscription()
nullIMAPServerManager := imapservice.NewNullIMAPServerManager()
nullSMTPServerManager := smtp.NewNullServerManager()
nullUnleashService := unleash.NewNullUnleashService()
user, err := New(
ctx,
vaultUser,
client,
nil,
sentry.NullSentryReporter{},
apiUser,
nil,
true,
@ -171,9 +174,7 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma
notifications.NewStore(func() (string, error) {
return "", nil
}),
func(_ string) bool {
return false
},
nullUnleashService,
)
require.NoError(tb, err)
defer user.Close()

View File

@ -63,6 +63,10 @@ func GetHelper(vaultDir string) (string, error) {
}
func SetHelper(vaultDir, helper string) error {
if helper == "" {
return nil
}
settings, err := LoadKeychainSettings(vaultDir)
if err != nil {
return err

View File

@ -82,11 +82,11 @@ func (kcl *List) GetDefaultHelper() string {
return kcl.defaultHelper
}
// NewKeychain creates a new native keychain.
func NewKeychain(preferred, keychainName string, helpers Helpers, defaultHelper string) (*Keychain, error) {
// NewKeychain creates a new native keychain. It also returns the keychain helper used to access the keychain.
func NewKeychain(preferred, keychainName string, helpers Helpers, defaultHelper string) (kc *Keychain, usedKeychainHelper string, err error) {
// There must be at least one keychain helper available.
if len(helpers) < 1 {
return nil, ErrNoKeychain
return nil, "", ErrNoKeychain
}
// If the preferred keychain is unsupported, fallback to the default one.
@ -97,16 +97,16 @@ func NewKeychain(preferred, keychainName string, helpers Helpers, defaultHelper
// Load the user's preferred keychain helper.
helperConstructor, ok := helpers[preferred]
if !ok {
return nil, ErrNoKeychain
return nil, "", ErrNoKeychain
}
// Construct the keychain helper.
helper, err := helperConstructor(hostURL(keychainName))
if err != nil {
return nil, err
return nil, preferred, err
}
return newKeychain(helper, hostURL(keychainName)), nil
return newKeychain(helper, hostURL(keychainName)), preferred, nil
}
func newKeychain(helper credentials.Helper, url string) *Keychain {

View File

@ -120,7 +120,7 @@ func TestIsErrKeychainNoItem(t *testing.T) {
helpers := NewList().GetHelpers()
for helperName := range helpers {
kc, err := NewKeychain(helperName, "bridge-test", helpers, helperName)
kc, _, err := NewKeychain(helperName, "bridge-test", helpers, helperName)
r.NoError(err)
_, _, err = kc.Get("non-existing")

View File

@ -36,6 +36,8 @@ type API interface {
GetDomain() string
GetAppVersion() string
PushFeatureFlag(string)
Close()
}
@ -61,6 +63,10 @@ func (api *fakeAPI) GetAppVersion() string {
return proton.DefaultAppVersion
}
func (api *fakeAPI) PushFeatureFlag(flagName string) {
api.Server.PushFeatureFlag(flagName)
}
type liveAPI struct {
*server.Server

View File

@ -32,6 +32,7 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/bridge"
"github.com/ProtonMail/proton-bridge/v3/internal/events"
"github.com/ProtonMail/proton-bridge/v3/internal/kb"
"github.com/ProtonMail/proton-bridge/v3/internal/unleash"
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/cucumber/godog"
"github.com/golang/mock/gomock"
@ -55,7 +56,7 @@ func (s *scenario) bridgeStops() error {
func (s *scenario) bridgeVersionIsAndTheLatestAvailableVersionIsReachableFrom(current, latest, minAuto string) error {
s.t.version = semver.MustParse(current)
s.t.mocks.Updater.SetLatestVersion(semver.MustParse(latest), semver.MustParse(minAuto))
s.t.mocks.Updater.SetLatestVersionLegacy(semver.MustParse(latest), semver.MustParse(minAuto))
return nil
}
@ -361,8 +362,8 @@ func (s *scenario) bridgeSendsAnUpdateAvailableEventForVersion(version string) e
return errors.New("expected update event to be installable")
}
if !event.Version.Version.Equal(semver.MustParse(version)) {
return fmt.Errorf("expected update event for version %s, got %s", version, event.Version.Version)
if !event.VersionLegacy.Version.Equal(semver.MustParse(version)) {
return fmt.Errorf("expected update event for version %s, got %s", version, event.VersionLegacy.Version)
}
return nil
@ -378,8 +379,8 @@ func (s *scenario) bridgeSendsAManualUpdateEventForVersion(version string) error
return errors.New("expected update event to not be installable")
}
if !event.Version.Version.Equal(semver.MustParse(version)) {
return fmt.Errorf("expected update event for version %s, got %s", version, event.Version.Version)
if !event.VersionLegacy.Version.Equal(semver.MustParse(version)) {
return fmt.Errorf("expected update event for version %s, got %s", version, event.VersionLegacy.Version)
}
return nil
@ -391,8 +392,8 @@ func (s *scenario) bridgeSendsAnUpdateInstalledEventForVersion(version string) e
return errors.New("expected update installed event, got none")
}
if !event.Version.Version.Equal(semver.MustParse(version)) {
return fmt.Errorf("expected update installed event for version %s, got %s", version, event.Version.Version)
if !event.VersionLegacy.Version.Equal(semver.MustParse(version)) {
return fmt.Errorf("expected update installed event for version %s, got %s", version, event.VersionLegacy.Version)
}
return nil
@ -483,3 +484,25 @@ func (s *scenario) bridgeSMTPPortIs(expectedPort int) error {
return nil
}
func (s *scenario) bridgeLegacyUpdateKillSwitchEnabled() error {
unleash.ModifyPollPeriodAndJitter(5*time.Second, 0)
s.t.api.PushFeatureFlag(unleash.UpdateUseNewVersionFileStructureDisabled)
return nil
}
func (s *scenario) bridgeLegacyUpdateEnabled() error {
return eventually(func() error {
res := s.t.bridge.GetFeatureFlagValue(unleash.UpdateUseNewVersionFileStructureDisabled)
fmt.Println("RES", res)
if res != true {
return fmt.Errorf("expected the %v kill-switch to be enabled", unleash.UpdateUseNewVersionFileStructureDisabled)
}
return nil
})
}
func (s *scenario) bridgeChecksForUpdates() error {
s.t.bridge.CheckForUpdates()
return nil
}

View File

@ -89,8 +89,11 @@ func (r *reportRecorder) close() {
}
func (r *reportRecorder) assertEmpty() {
if !r.skipAssert {
r.assert.Empty(r.reports)
if !r.skipAssert && len(r.reports) > 0 {
for _, report := range r.reports {
// Sentry reports with failed syncs are expected, mostly due to sync context cancellations.
r.assert.Equal(report.message, "Failed to sync, will retry later")
}
}
}
@ -143,6 +146,11 @@ func (r *reportRecorder) ReportMessageWithContext(message string, context report
return nil
}
func (r *reportRecorder) ReportWarningWithContext(message string, context reporter.Context) error {
r.add(false, message, context)
return nil
}
func (r *reportRecorder) ReportExceptionWithContext(data any, context reporter.Context) error {
if context == nil {
context = reporter.Context{}

View File

@ -1,23 +1,34 @@
Feature: Bridge checks for updates
Background:
Given the legacy update kill switch is enabled
Scenario: Update not available
Given bridge is version "2.3.0" and the latest available version is "2.3.0" reachable from "2.3.0"
When bridge starts
And bridge verifies that the legacy update is enabled
And bridge checks for updates
Then bridge sends an update not available event
Scenario: Update available without automatic updates enabled
Given bridge is version "2.3.0" and the latest available version is "2.4.0" reachable from "2.3.0"
And the user has disabled automatic updates
When bridge starts
And bridge verifies that the legacy update is enabled
And bridge checks for updates
Then bridge sends an update available event for version "2.4.0"
Scenario: Update available with automatic updates enabled
Given bridge is version "2.3.0" and the latest available version is "2.4.0" reachable from "2.3.0"
When bridge starts
And bridge verifies that the legacy update is enabled
And bridge checks for updates
Then bridge sends an update installed event for version "2.4.0"
Scenario: Manual update available with automatic updates enabled
Given bridge is version "2.3.0" and the latest available version is "2.4.0" reachable from "2.4.0"
When bridge starts
And bridge verifies that the legacy update is enabled
And bridge checks for updates
Then bridge sends a manual update event for version "2.4.0"
Scenario: Update is required to continue using bridge

Some files were not shown because too many files have changed in this diff Show More