Compare commits

...

20 Commits

Author SHA1 Message Date
7c41c8e23a Other: Bridge James 1.8.3 2021-05-27 15:13:04 +02:00
36fdb88d96 GODT-1182: use correct contacts route 2021-05-27 14:15:00 +02:00
885fb95454 Other: Bridge James v1.8.2 2021-05-21 07:16:17 +02:00
629d6c5e4d GODT-1175: report bug 2021-05-20 15:24:43 +02:00
4072205709 Other: version bump and changelog Bridge James 1.8.2 2021-05-20 08:38:33 +02:00
5d82c218ca GODT-1165: Handle UID FETCH with sequence range of empty mailbox 2021-05-19 16:17:19 +02:00
6ff4c8a738 Other: Bridge James 1.8.0 2021-05-07 15:56:11 +02:00
dd66b7f8d0 GODT-1159 SMTP server not restarting after restored internet
- [x] write tests to check that IMAP and SMTP servers are closed when there
  is no internet
- [x] always create new go-smtp instance during listenAndServe(int)
2021-05-07 10:34:14 +00:00
0b95ed4dea GODT-1146: Refactor header filtering 2021-05-03 15:53:46 +02:00
ce64aeb05f Other: avoid API jail 2021-05-03 07:05:15 +02:00
27cfda680d GODT-1152: Correctly resolve wildcard sequence/UID set 2021-04-30 10:35:34 +00:00
323303a98b GODT-1089 Explicitly open system preferences window on BigSur. 2021-04-30 09:10:14 +00:00
8109831c07 GODT-35: Finish all details and make tests pass 2021-04-30 05:41:39 +02:00
2284e9ede1 GODT-35: New pmapi client and manager using resty 2021-04-30 05:34:36 +02:00
1d538e8540 GODT-876 Set default from if empty for importing draft 2021-04-29 14:07:20 +00:00
8ccaac8090 GODT-1056 Check encrypted size of the message before upload 2021-04-29 12:40:29 +00:00
22bf8f62ce GODT-1143 Turn off SMTP server while no connection 2021-04-28 11:23:41 +02:00
fed031ebaa Other: early release notes 1.7.1 2021-04-28 10:43:42 +02:00
7a15ebbd54 Other: update go.mod with qt docs 2021-04-27 17:11:23 +02:00
94b5799ba7 Other refactor: clean old builder 2021-04-27 08:12:50 +00:00
247 changed files with 7681 additions and 9796 deletions

View File

@ -2,6 +2,38 @@
Changelog [format](http://keepachangelog.com/en/1.0.0/)
## [Bridge 1.8.3] James
### Fixed
* GODT-1182: Use correct contact route.
## [Bridge 1.8.2] James
### Fixed
* GODT-1175: Bug reporting.
## [Bridge 1.8.1] James
### Fixed
* GODT-1165: Handle UID FETCH with sequence range of empty mailbox.
## [Bridge 1.8.0] James
### Added
* GODT-1056 Check encrypted size of the message before upload.
* GODT-1143 Turn off SMTP server while no connection.
* GODT-1089 Explicitly open system preferences window on BigSur.
### Fixed
* GODT-1159 SMTP server not restarting after restored internet.
* GODT-1146 Refactor handling of fetching BODY[HEADER] (and similar) regarding trailing newline.
* GODT-1152 Correctly resolve wildcard sequence/UID set.
* Other: Avoid API jail.
## [Bridge 1.7.1] Iron
### Fixed

View File

@ -10,7 +10,7 @@ TARGET_OS?=${GOOS}
.PHONY: build build-ie build-nogui build-ie-nogui build-launcher build-launcher-ie versioner hasher
# Keep version hardcoded so app build works also without Git repository.
BRIDGE_APP_VERSION?=1.7.1+git
BRIDGE_APP_VERSION?=1.8.3+git
IE_APP_VERSION?=1.3.3+git
APP_VERSION:=${BRIDGE_APP_VERSION}
SRC_ICO:=logo.ico
@ -253,13 +253,17 @@ bench:
coverage: test
go tool cover -html=/tmp/coverage.out -o=coverage.html
integration-test-bridge:
${MAKE} -C test test-bridge
mocks:
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/users Locator,PanicHandler,ClientManager,CredentialsStorer,StoreMaker > internal/users/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/transfer PanicHandler,ClientManager,IMAPClientProvider > internal/transfer/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,ClientManager,BridgeUser,ChangeNotifier > internal/store/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/users Locator,PanicHandler,CredentialsStorer,StoreMaker > internal/users/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/users/mocks/listener_mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/transfer PanicHandler,IMAPClientProvider > internal/transfer/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser,ChangeNotifier > internal/store/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/pmapi Client,Manager > pkg/pmapi/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/message Fetcher > pkg/message/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/pmapi Client > pkg/pmapi/mocks/mocks.go
lint: gofiles lint-golang lint-license lint-changelog

8
go.mod
View File

@ -40,7 +40,7 @@ require (
github.com/fatih/color v1.9.0
github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
github.com/getsentry/sentry-go v0.8.0
github.com/go-resty/resty/v2 v2.3.0
github.com/go-resty/resty/v2 v2.6.0
github.com/golang/mock v1.4.4
github.com/google/go-cmp v0.5.1
github.com/google/uuid v1.1.1
@ -50,7 +50,7 @@ require (
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d
github.com/logrusorgru/aurora v2.0.3+incompatible
github.com/mattn/go-runewidth v0.0.9 // indirect
github.com/miekg/dns v1.1.30
github.com/miekg/dns v1.1.41
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
github.com/olekukonko/tablewriter v0.0.4 // indirect
github.com/pkg/errors v0.9.1
@ -59,10 +59,12 @@ require (
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
github.com/stretchr/testify v1.6.1
github.com/therecipe/qt v0.0.0-20200701200531-7f61353ee73e
github.com/therecipe/qt/internal/binding/files/docs/5.12.0 v0.0.0-20200904063919-c0c124a5770d // indirect
github.com/therecipe/qt/internal/binding/files/docs/5.13.0 v0.0.0-20200904063919-c0c124a5770d // indirect
github.com/urfave/cli/v2 v2.2.0
github.com/vmihailenco/msgpack/v5 v5.1.3
go.etcd.io/bbolt v1.3.5
golang.org/x/net v0.0.0-20200707034311-ab3426394381
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4
golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec
)

35
go.sum
View File

@ -113,8 +113,8 @@ github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclK
github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w=
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8=
github.com/go-resty/resty/v2 v2.3.0 h1:JOOeAvjSlapTT92p8xiS19Zxev1neGikoHsXJeOq8So=
github.com/go-resty/resty/v2 v2.3.0/go.mod h1:UpN9CgLZNsv4e9XG50UU8xdI0F43UQ4HmxLBDwaroHU=
github.com/go-resty/resty/v2 v2.6.0 h1:joIR5PNLM2EFqqESUjCMGXrWmXNHEU9CEiK813oKYS4=
github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w/BIH7cC3Q=
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
@ -195,8 +195,8 @@ github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m
github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw=
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
github.com/miekg/dns v1.1.30 h1:Qww6FseFn8PRfw07jueqIXqodm0JKiiKuK0DeXSqfyo=
github.com/miekg/dns v1.1.30/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
github.com/miekg/dns v1.1.41 h1:WMszZWJG0XmzbK9FEmzH2TVcqYzFesusSIB41b8KHxY=
github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -262,6 +262,11 @@ github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/therecipe/qt v0.0.0-20200701200531-7f61353ee73e h1:G0DQ/TRQyrEZjtLlLwevFjaRiG8eeCMlq9WXQ2OO2bk=
github.com/therecipe/qt v0.0.0-20200701200531-7f61353ee73e/go.mod h1:SUUR2j3aE1z6/g76SdD6NwACEpvCxb3fvG82eKbD6us=
github.com/therecipe/qt v0.0.0-20200904063919-c0c124a5770d h1:T+d8FnaLSvM/1BdlDXhW4d5dr2F07bAbB+LpgzMxx+o=
github.com/therecipe/qt/internal/binding/files/docs/5.12.0 v0.0.0-20200904063919-c0c124a5770d h1:hAZyEG2swPRWjF0kqqdGERXUazYnRJdAk4a58f14z7Y=
github.com/therecipe/qt/internal/binding/files/docs/5.12.0 v0.0.0-20200904063919-c0c124a5770d/go.mod h1:7m8PDYDEtEVqfjoUQc2UrFqhG0CDmoVJjRlQxexndFc=
github.com/therecipe/qt/internal/binding/files/docs/5.13.0 v0.0.0-20200904063919-c0c124a5770d h1:AJRoBel/g9cDS+yE8BcN3E+TDD/xNAguG21aoR8DAIE=
github.com/therecipe/qt/internal/binding/files/docs/5.13.0 v0.0.0-20200904063919-c0c124a5770d/go.mod h1:mH55Ek7AZcdns5KPp99O0bg+78el64YCYWHiQKrOdt4=
github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
@ -305,16 +310,18 @@ golang.org/x/net v0.0.0-20190420063019-afa5a82059c6/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200707034311-ab3426394381 h1:VXak5I6aEWmAXeQjA+QSZzlgNrpq9mjcfDemuexIKsU=
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -325,14 +332,19 @@ golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04 h1:cEhElsAv9LUt9ZUUocxzWe05oFLVd+AA2nstydTeI8g=
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44 h1:Bli41pIlzTzf3KEY06n+xnzK/BESIg2ze4Pgfh/aI8c=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec h1:A1qYjneJuzBZZ2gIB8rd6zrfq6l7SoEMJ8EsSilNK/U=
golang.org/x/text v0.3.5-0.20201125200606-c27b9fd57aec/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@ -343,7 +355,6 @@ golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3
golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200117012304-6edc0a871e69 h1:yBHHx+XZqXJBm6Exke3N7V9gnlsyXxoCPEb1yVenjfk=
golang.org/x/tools v0.0.0-20200117012304-6edc0a871e69/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -23,7 +23,7 @@
// - persistent settings
// - event listener
// - credentials store
// - pmapi ClientManager
// - pmapi Manager
// In addition, the base initialises logging and reacts to command line arguments
// which control the log verbosity and enable cpu/memory profiling.
package base
@ -85,7 +85,7 @@ type Base struct {
Cache *cache.Cache
Listener listener.Listener
Creds *credentials.Store
CM *pmapi.ClientManager
CM pmapi.Manager
CookieJar *cookies.Jar
UserAgent *useragent.UserAgent
Updater *updater.Updater
@ -181,13 +181,23 @@ func New( // nolint[funlen]
kc = keychain.NewMissingKeychain()
}
cfg := pmapi.NewConfig(configName, constants.Version)
cfg.GetUserAgent = userAgent.String
cfg.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
cfg.TLSIssueHandler = func() { listener.Emit(events.TLSCertIssue, "") }
cm := pmapi.New(cfg)
cm.AddConnectionObserver(pmapi.NewConnectionObserver(
func() { listener.Emit(events.InternetOffEvent, "") },
func() { listener.Emit(events.InternetOnEvent, "") },
))
jar, err := cookies.NewCookieJar(settingsObj)
if err != nil {
return nil, err
}
cm := pmapi.NewClientManager(getAPIConfig(configName, listener), userAgent)
cm.SetRoundTripper(pmapi.GetRoundTripper(cm, listener))
cm.SetCookieJar(jar)
key, err := crypto.NewKeyFromArmored(updater.DefaultPublicKey)
@ -328,6 +338,7 @@ func (b *Base) run(appMainLoop func(*Base, *cli.Context) error) cli.ActionFunc {
}
logging.SetLevel(c.String(flagLogLevel))
b.CM.SetLogging(logrus.WithField("pkg", "pmapi"), logrus.GetLevel() == logrus.TraceLevel)
logrus.
WithField("appName", b.Name).
@ -375,13 +386,3 @@ func (b *Base) doTeardown() error {
return nil
}
func getAPIConfig(configName string, listener listener.Listener) *pmapi.ClientConfig {
apiConfig := pmapi.GetAPIConfig(configName, constants.Version)
apiConfig.ConnectionOffHandler = func() { listener.Emit(events.InternetOffEvent, "") }
apiConfig.ConnectionOnHandler = func() { listener.Emit(events.InternetOnEvent, "") }
apiConfig.UpgradeApplicationHandler = func() { listener.Emit(events.UpgradeApplicationEvent, "") }
return apiConfig
}

View File

@ -95,6 +95,7 @@ func run(b *base.Base, c *cli.Context) error { // nolint[funlen]
smtpPort := b.Settings.GetInt(settings.SMTPPortKey)
useSSL := b.Settings.GetBool(settings.SMTPSSLKey)
smtp.NewSMTPServer(
b.CrashHandler,
c.Bool(flagLogSMTP),
smtpPort, useSSL, tlsConfig, smtpBackend, b.Listener).ListenAndServe()
}()

View File

@ -19,6 +19,7 @@
package bridge
import (
"context"
"fmt"
"strconv"
"time"
@ -44,7 +45,7 @@ type Bridge struct {
locations Locator
settings SettingsProvider
clientManager users.ClientManager
clientManager pmapi.Manager
updater Updater
versioner Versioner
}
@ -56,7 +57,7 @@ func New(
sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler,
eventListener listener.Listener,
clientManager users.ClientManager,
clientManager pmapi.Manager,
credStorer users.CredentialsStorer,
updater Updater,
versioner Versioner,
@ -67,7 +68,7 @@ func New(
clientManager.AllowProxy()
}
storeFactory := newStoreFactory(cache, sentryReporter, panicHandler, clientManager, eventListener)
storeFactory := newStoreFactory(cache, sentryReporter, panicHandler, eventListener)
u := users.New(locations, panicHandler, eventListener, clientManager, credStorer, storeFactory, true)
b := &Bridge{
Users: u,
@ -118,28 +119,15 @@ func (b *Bridge) heartbeat() {
// ReportBug reports a new bug from the user.
func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
c := b.clientManager.GetAnonymousClient()
defer c.Logout()
title := "[Bridge] Bug"
report := pmapi.ReportReq{
return b.clientManager.ReportBug(context.Background(), pmapi.ReportBugReq{
OS: osType,
OSVersion: osVersion,
Browser: emailClient,
Title: title,
Title: "[Bridge] Bug",
Description: description,
Username: accountName,
Email: address,
}
if err := c.Report(report); err != nil {
log.Error("Reporting bug failed: ", err)
return err
}
log.Info("Bug successfully reported")
return nil
})
}
// GetUpdateChannel returns currently set update channel.

View File

@ -31,7 +31,6 @@ type storeFactory struct {
cache Cacher
sentryReporter *sentry.Reporter
panicHandler users.PanicHandler
clientManager users.ClientManager
eventListener listener.Listener
storeCache *store.Cache
}
@ -40,14 +39,12 @@ func newStoreFactory(
cache Cacher,
sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler,
clientManager users.ClientManager,
eventListener listener.Listener,
) *storeFactory {
return &storeFactory{
cache: cache,
sentryReporter: sentryReporter,
panicHandler: panicHandler,
clientManager: clientManager,
eventListener: eventListener,
storeCache: store.NewCache(cache.GetIMAPCachePath()),
}
@ -56,7 +53,7 @@ func newStoreFactory(
// New creates new store for given user.
func (f *storeFactory) New(user store.BridgeUser) (*store.Store, error) {
storePath := getUserStorePath(f.cache.GetDBDir(), user.ID())
return store.New(f.sentryReporter, f.panicHandler, user, f.clientManager, f.eventListener, storePath, f.storeCache)
return store.New(f.sentryReporter, f.panicHandler, user, f.eventListener, storePath, f.storeCache)
}
// Remove removes all store files for given user.

View File

@ -25,8 +25,20 @@ import (
"github.com/Masterminds/semver/v3"
)
// IsCatalinaOrNewer checks whether host is MacOS Catalina 10.15.x or higher.
// IsCatalinaOrNewer checks whether the host is MacOS Catalina 10.15.x or higher.
func IsCatalinaOrNewer() bool {
return isThisDarwinNewerOrEqual(getMinCatalina())
}
// IsBigSurOrNewer checks whether the host is MacOS BigSur 10.16.x or higher.
func IsBigSurOrNewer() bool {
return isThisDarwinNewerOrEqual(getMinBigSur())
}
func getMinCatalina() *semver.Version { return semver.MustParse("10.15.0") }
func getMinBigSur() *semver.Version { return semver.MustParse("10.16.0") }
func isThisDarwinNewerOrEqual(minVersion *semver.Version) bool {
if runtime.GOOS != "darwin" {
return false
}
@ -36,16 +48,14 @@ func IsCatalinaOrNewer() bool {
return false
}
return isVersionCatalinaOrNewer(strings.TrimSpace(string(rawVersion)))
return isVersionEqualOrNewer(minVersion, strings.TrimSpace(string(rawVersion)))
}
func isVersionCatalinaOrNewer(rawVersion string) bool {
// isVersionEqualOrNewer is separated to be able to run test on other than darwin.
func isVersionEqualOrNewer(minVersion *semver.Version, rawVersion string) bool {
semVersion, err := semver.NewVersion(rawVersion)
if err != nil {
return false
}
minVersion := semver.MustParse("10.15.0")
return semVersion.GreaterThan(minVersion) || semVersion.Equal(minVersion)
}

View File

@ -38,7 +38,27 @@ func TestIsVersionCatalinaOrNewer(t *testing.T) {
}
for args, exp := range testData {
got := isVersionCatalinaOrNewer(args.version)
got := isVersionEqualOrNewer(getMinCatalina(), args.version)
assert.Equal(t, exp, got, "version %v", args.version)
}
}
func TestIsVersionBigSurOrNewer(t *testing.T) {
testData := map[struct{ version string }]bool{
{""}: false,
{"9.0.0"}: false,
{"9.15.0"}: false,
{"10.13.0"}: false,
{"10.14.0"}: false,
{"10.14.99"}: false,
{"10.15.0"}: false,
{"10.16.0"}: true,
{"11.0.0"}: true,
{"11.1"}: true,
}
for args, exp := range testData {
got := isVersionEqualOrNewer(getMinBigSur(), args.version)
assert.Equal(t, exp, got, "version %v", args.version)
}
}

View File

@ -29,10 +29,15 @@ import (
"time"
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/config/useragent"
"github.com/ProtonMail/proton-bridge/internal/frontend/types"
"github.com/ProtonMail/proton-bridge/pkg/mobileconfig"
)
const (
bigSurPreferncesPane = "/System/Library/PreferencePanes/Profiles.prefPane"
)
func init() { //nolint[gochecknoinit]
available = append(available, &appleMail{})
}
@ -43,7 +48,22 @@ func (c *appleMail) Name() string {
return "Apple Mail"
}
func (c *appleMail) Configure(imapPort, smtpPort int, imapSSL, smtpSSL bool, user types.User, addressIndex int) error { //nolint[funlen]
func (c *appleMail) Configure(imapPort, smtpPort int, imapSSL, smtpSSL bool, user types.User, addressIndex int) error {
mc := prepareMobileConfig(imapPort, smtpPort, imapSSL, smtpSSL, user, addressIndex)
confPath, err := saveConfigTemporarily(mc)
if err != nil {
return err
}
if useragent.IsBigSurOrNewer() {
return exec.Command("open", bigSurPreferncesPane, confPath).Run() //nolint[gosec] G204: open command is safe, mobileconfig is generated by us
}
return exec.Command("open", confPath).Run() //nolint[gosec] G204: open command is safe, mobileconfig is generated by us
}
func prepareMobileConfig(imapPort, smtpPort int, imapSSL, smtpSSL bool, user types.User, addressIndex int) *mobileconfig.Config {
var addresses string
var displayName string
@ -62,7 +82,7 @@ func (c *appleMail) Configure(imapPort, smtpPort int, imapSSL, smtpSSL bool, use
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
mc := &mobileconfig.Config{
return &mobileconfig.Config{
EmailAddress: addresses,
DisplayName: displayName,
Identifier: "protonmail " + displayName + timestamp,
@ -80,10 +100,12 @@ func (c *appleMail) Configure(imapPort, smtpPort int, imapSSL, smtpSSL bool, use
Username: displayName,
},
}
}
func saveConfigTemporarily(mc *mobileconfig.Config) (fname string, err error) {
dir, err := ioutil.TempDir("", "protonmail-autoconfig")
if err != nil {
return err
return
}
// Make sure the temporary file is deleted.
@ -93,16 +115,17 @@ func (c *appleMail) Configure(imapPort, smtpPort int, imapSSL, smtpSSL bool, use
})()
// Make sure the file is only readable for the current user.
f, err := os.OpenFile(filepath.Clean(filepath.Join(dir, "protonmail.mobileconfig")), os.O_RDWR|os.O_CREATE, 0600)
fname = filepath.Clean(filepath.Join(dir, "protonmail.mobileconfig"))
f, err := os.OpenFile(fname, os.O_RDWR|os.O_CREATE, 0600)
if err != nil {
return err
return
}
if err := mc.WriteOut(f); err != nil {
if err = mc.WriteOut(f); err != nil {
_ = f.Close()
return err
return
}
_ = f.Close()
return exec.Command("open", f.Name()).Run() // nolint[gosec]
return
}

View File

@ -18,6 +18,7 @@
package cliie
import (
"context"
"strings"
"github.com/abiosoft/ishell"
@ -79,7 +80,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
return
}
err = client.Auth2FA(twoFactor, auth)
err = client.Auth2FA(context.Background(), twoFactor)
if err != nil {
f.processAPIError(err)
return

View File

@ -84,11 +84,6 @@ func New( //nolint[funlen]
Aliases: []string{"u", "version", "v"},
Func: fe.checkUpdates,
})
checkCmd.AddCmd(&ishell.Cmd{Name: "internet",
Help: "check internet connection. (aliases: i, conn, connection)",
Aliases: []string{"i", "con", "connection"},
Func: fe.checkInternetConnection,
})
fe.AddCmd(checkCmd)
// Print info commands.
@ -177,13 +172,13 @@ func New( //nolint[funlen]
}
func (f *frontendCLI) watchEvents() {
errorCh := f.getEventChannel(events.ErrorEvent)
credentialsErrorCh := f.getEventChannel(events.CredentialsErrorEvent)
internetOffCh := f.getEventChannel(events.InternetOffEvent)
internetOnCh := f.getEventChannel(events.InternetOnEvent)
addressChangedLogoutCh := f.getEventChannel(events.AddressChangedLogoutEvent)
logoutCh := f.getEventChannel(events.LogoutEvent)
certIssue := f.getEventChannel(events.TLSCertIssue)
errorCh := f.eventListener.ProvideChannel(events.ErrorEvent)
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent)
internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent)
addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
certIssue := f.eventListener.ProvideChannel(events.TLSCertIssue)
for {
select {
case errorDetails := <-errorCh:
@ -208,13 +203,6 @@ func (f *frontendCLI) watchEvents() {
}
}
func (f *frontendCLI) getEventChannel(event string) <-chan string {
ch := make(chan string)
f.eventListener.Add(event, ch)
f.eventListener.RetryEmit(event)
return ch
}
// Loop starts the frontend loop with an interactive shell.
func (f *frontendCLI) Loop() error {
f.Print(`

View File

@ -29,14 +29,6 @@ func (f *frontendCLI) restart(c *ishell.Context) {
}
}
func (f *frontendCLI) checkInternetConnection(c *ishell.Context) {
if f.ie.CheckConnection() == nil {
f.Println("Internet connection is available.")
} else {
f.Println("Can not contact the server, please check your internet connection.")
}
}
func (f *frontendCLI) printLogDir(c *ishell.Context) {
if path, err := f.locations.ProvideLogsPath(); err != nil {
f.Println("Failed to determine location of log files")

View File

@ -20,7 +20,7 @@ package cliie
import (
"strings"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/fatih/color"
)
@ -71,7 +71,7 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
func (f *frontendCLI) processAPIError(err error) {
log.Warn("API error: ", err)
switch err {
case pmapi.ErrAPINotReachable:
case pmapi.ErrNoConnection:
f.notifyInternetOff()
case pmapi.ErrUpgradeApplication:
f.notifyNeedUpgrade()

View File

@ -18,6 +18,7 @@
package cli
import (
"context"
"strings"
"github.com/ProtonMail/proton-bridge/internal/bridge"
@ -126,7 +127,7 @@ func (f *frontendCLI) loginAccount(c *ishell.Context) { // nolint[funlen]
return
}
err = client.Auth2FA(twoFactor, auth)
err = client.Auth2FA(context.Background(), twoFactor)
if err != nil {
f.processAPIError(err)
return

View File

@ -157,15 +157,6 @@ func New( //nolint[funlen]
})
fe.AddCmd(updatesCmd)
// Check commands.
checkCmd := &ishell.Cmd{Name: "check", Help: "check internet connection or new version."}
checkCmd.AddCmd(&ishell.Cmd{Name: "internet",
Help: "check internet connection. (aliases: i, conn, connection)",
Aliases: []string{"i", "con", "connection"},
Func: fe.checkInternetConnection,
})
fe.AddCmd(checkCmd)
// Print info commands.
fe.AddCmd(&ishell.Cmd{Name: "log-dir",
Help: "print path to directory with logs. (aliases: log, logs)",
@ -228,14 +219,14 @@ func New( //nolint[funlen]
}
func (f *frontendCLI) watchEvents() {
errorCh := f.getEventChannel(events.ErrorEvent)
credentialsErrorCh := f.getEventChannel(events.CredentialsErrorEvent)
internetOffCh := f.getEventChannel(events.InternetOffEvent)
internetOnCh := f.getEventChannel(events.InternetOnEvent)
addressChangedCh := f.getEventChannel(events.AddressChangedEvent)
addressChangedLogoutCh := f.getEventChannel(events.AddressChangedLogoutEvent)
logoutCh := f.getEventChannel(events.LogoutEvent)
certIssue := f.getEventChannel(events.TLSCertIssue)
errorCh := f.eventListener.ProvideChannel(events.ErrorEvent)
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent)
internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent)
addressChangedCh := f.eventListener.ProvideChannel(events.AddressChangedEvent)
addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
certIssue := f.eventListener.ProvideChannel(events.TLSCertIssue)
for {
select {
case errorDetails := <-errorCh:
@ -262,13 +253,6 @@ func (f *frontendCLI) watchEvents() {
}
}
func (f *frontendCLI) getEventChannel(event string) <-chan string {
ch := make(chan string)
f.eventListener.Add(event, ch)
f.eventListener.RetryEmit(event)
return ch
}
// Loop starts the frontend loop with an interactive shell.
func (f *frontendCLI) Loop() error {
f.Print(`

View File

@ -39,14 +39,6 @@ func (f *frontendCLI) restart(c *ishell.Context) {
}
}
func (f *frontendCLI) checkInternetConnection(c *ishell.Context) {
if f.bridge.CheckConnection() == nil {
f.Println("Internet connection is available.")
} else {
f.Println("Can not contact the server, please check your internet connection.")
}
}
func (f *frontendCLI) printLogDir(c *ishell.Context) {
if path, err := f.locations.ProvideLogsPath(); err != nil {
f.Println("Failed to determine location of log files")

View File

@ -71,7 +71,7 @@ func (f *frontendCLI) printAndLogError(args ...interface{}) {
func (f *frontendCLI) processAPIError(err error) {
log.Warn("API error: ", err)
switch err {
case pmapi.ErrAPINotReachable:
case pmapi.ErrNoConnection:
f.notifyInternetOff()
case pmapi.ErrUpgradeApplication:
f.notifyNeedUpgrade()

View File

@ -409,7 +409,6 @@ Dialog {
onShow: {
if (winMain.updateState==gui.enums.statusNoInternet) {
go.checkInternet()
if (winMain.updateState==gui.enums.statusNoInternet) {
go.notifyError(gui.enums.errNoInternet)
root.hide()

View File

@ -857,14 +857,12 @@ Dialog {
inputPort . checkIsANumber()
//emailProvider . currentIndex!=0
)) isOK = false
go.checkInternet()
if (winMain.updateState == gui.enums.statusNoInternet) { // todo: use main error dialog for this
errorPopup.show(qsTr("Please check your internet connection."))
return false
}
break
case 2: // loading structure
go.checkInternet()
if (winMain.updateState == gui.enums.statusNoInternet) {
errorPopup.show(qsTr("Please check your internet connection."))
return false
@ -949,7 +947,6 @@ Dialog {
onShow : {
root.clear()
if (winMain.updateState==gui.enums.statusNoInternet) {
go.checkInternet()
if (winMain.updateState==gui.enums.statusNoInternet) {
winMain.popupMessage.show(go.canNotReachAPI)
root.hide()

View File

@ -25,33 +25,12 @@ import ProtonUI 1.0
Rectangle {
id: root
property var iTry: 0
property var secLeft: 0
property var second: 1000 // convert millisecond to second
property var checkInterval: [ 5, 10, 30, 60, 120, 300, 600 ] // seconds
property bool isVisible: true
property var fontSize : 1.2 * Style.main.fontSize
color : "black"
state: "upToDate"
Timer {
id: retryInternet
interval: second
triggeredOnStart: false
repeat: true
onTriggered : {
secLeft--
if (secLeft <= 0) {
retryInternet.stop()
go.checkInternet()
if (iTry < checkInterval.length-1) {
iTry++
}
secLeft=checkInterval[iTry]
retryInternet.start()
}
}
}
Row {
id: messageRow
anchors.centerIn: root
@ -110,16 +89,12 @@ Rectangle {
case "internetCheck":
break;
case "noInternet" :
retryInternet.start()
secLeft=checkInterval[iTry]
break;
case "oldVersion":
break;
case "forceUpdate":
break;
case "upToDate":
iTry = 0
secLeft=checkInterval[iTry]
break;
case "updateRestart":
break;
@ -128,24 +103,6 @@ Rectangle {
default :
break;
}
if (root.state!="noInternet") {
retryInternet.stop()
}
}
function timeToRetry() {
if (secLeft==1){
return qsTr("a second", "time to wait till internet connection is retried")
} else if (secLeft<60){
return secLeft + " " + qsTr("seconds", "time to wait till internet connection is retried")
} else {
var leading = ""+secLeft%60
if (leading.length < 2) {
leading = "0" + leading
}
return Math.floor(secLeft/60) + ":" + leading
}
}
states: [
@ -194,23 +151,15 @@ Rectangle {
PropertyChanges {
target: message
color: Style.main.line
text: qsTr("Cannot contact server. Retrying in ", "displayed when the app is disconnected from the internet or server has problems")+timeToRetry()+"."
text: qsTr("Cannot contact server. Please wait...", "displayed when the app is disconnected from the internet or server has problems")
}
PropertyChanges {
target: linkText
visible: false
}
PropertyChanges {
target: actionText
visible: true
text: qsTr("Retry now", "click to try to connect to the internet when the app is disconnected from the internet")
onClicked: {
go.checkInternet()
}
}
PropertyChanges {
target: separatorText
visible: true
visible: false
text: "|"
}
PropertyChanges {

View File

@ -1331,10 +1331,6 @@ Window {
return (fname!="fail")
}
function checkInternet() {
// nothing to do
}
function loadImportReports(fname) {
console.log("load import reports for ", fname)
}

View File

@ -20,6 +20,7 @@
package qtcommon
import (
"context"
"fmt"
"strings"
"sync"
@ -164,7 +165,7 @@ func (a *Accounts) showLoginError(err error, scope string) bool {
return false
}
log.Warnf("%s: %v", scope, err)
if err == pmapi.ErrAPINotReachable {
if err == pmapi.ErrNoConnection {
a.qml.SetConnectionStatus(false)
SendNotification(a.qml, TabAccount, a.qml.CanNotReachAPI())
a.qml.ProcessFinished()
@ -207,7 +208,7 @@ func (a *Accounts) Auth2FA(twoFacAuth string) int {
if a.auth == nil || a.authClient == nil {
err = fmt.Errorf("missing authentication in auth2FA %p %p", a.auth, a.authClient)
} else {
err = a.authClient.Auth2FA(twoFacAuth, a.auth)
err = a.authClient.Auth2FA(context.Background(), twoFacAuth)
}
if a.showLoginError(err, "auth2FA") {

View File

@ -113,10 +113,3 @@ type Listener interface {
Add(string, chan<- string)
RetryEmit(string)
}
func MakeAndRegisterEvent(eventListener Listener, event string) <-chan string {
ch := make(chan string)
eventListener.Add(event, ch)
eventListener.RetryEmit(event)
return ch
}

View File

@ -143,16 +143,16 @@ func (f *FrontendQt) NotifySilentUpdateError(err error) {
}
func (f *FrontendQt) watchEvents() {
credentialsErrorCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.CredentialsErrorEvent)
internetOffCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.InternetOffEvent)
internetOnCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.InternetOnEvent)
secondInstanceCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.SecondInstanceEvent)
restartBridgeCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.RestartBridgeEvent)
addressChangedCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.AddressChangedEvent)
addressChangedLogoutCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.AddressChangedLogoutEvent)
logoutCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.LogoutEvent)
updateApplicationCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.UpgradeApplicationEvent)
newUserCh := qtcommon.MakeAndRegisterEvent(f.eventListener, events.UserRefreshEvent)
credentialsErrorCh := f.eventListener.ProvideChannel(events.CredentialsErrorEvent)
internetOffCh := f.eventListener.ProvideChannel(events.InternetOffEvent)
internetOnCh := f.eventListener.ProvideChannel(events.InternetOnEvent)
secondInstanceCh := f.eventListener.ProvideChannel(events.SecondInstanceEvent)
restartBridgeCh := f.eventListener.ProvideChannel(events.RestartBridgeEvent)
addressChangedCh := f.eventListener.ProvideChannel(events.AddressChangedEvent)
addressChangedLogoutCh := f.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
logoutCh := f.eventListener.ProvideChannel(events.LogoutEvent)
updateApplicationCh := f.eventListener.ProvideChannel(events.UpgradeApplicationEvent)
newUserCh := f.eventListener.ProvideChannel(events.UserRefreshEvent)
for {
select {
case <-credentialsErrorCh:
@ -351,11 +351,6 @@ func (f *FrontendQt) sendBug(description, emailClient, address string) bool {
// }
//}
// checkInternet is almost idetical to bridge
func (f *FrontendQt) checkInternet() {
f.Qml.SetConnectionStatus(f.ie.CheckConnection() == nil)
}
func (f *FrontendQt) showError(code int, err error) {
f.Qml.SetErrorDescription(err.Error())
log.WithField("code", code).Errorln(err.Error())

View File

@ -78,7 +78,6 @@ type GoQMLInterface struct {
_ string `property:"versionCheckFailed"`
//
_ func(isAvailable bool) `signal:"setConnectionStatus"`
_ func() `slot:"checkInternet"`
_ func() `slot:"setToRestart"`
@ -189,8 +188,6 @@ func (s *GoQMLInterface) SetFrontend(f *FrontendQt) {
return f.programVersion
})
s.ConnectCheckInternet(f.checkInternet)
s.ConnectSetToRestart(f.restarter.SetToRestart)
s.ConnectLoadStructureForExport(f.LoadStructureForExport)

View File

@ -20,6 +20,7 @@
package qt
import (
"context"
"fmt"
"strings"
@ -130,7 +131,7 @@ func (s *FrontendQt) showLoginError(err error, scope string) bool {
return false
}
log.Warnf("%s: %v", scope, err)
if err == pmapi.ErrAPINotReachable {
if err == pmapi.ErrNoConnection {
s.Qml.SetConnectionStatus(false)
s.SendNotification(TabAccount, s.Qml.CanNotReachAPI())
s.Qml.ProcessFinished()
@ -173,7 +174,7 @@ func (s *FrontendQt) auth2FA(twoFacAuth string) int {
if s.auth == nil || s.authClient == nil {
err = fmt.Errorf("missing authentication in auth2FA %p %p", s.auth, s.authClient)
} else {
err = s.authClient.Auth2FA(twoFacAuth, s.auth)
err = s.authClient.Auth2FA(context.Background(), twoFacAuth)
}
if s.showLoginError(err, "auth2FA") {

View File

@ -191,20 +191,20 @@ func (s *FrontendQt) NotifySilentUpdateError(err error) {
func (s *FrontendQt) watchEvents() {
s.WaitUntilFrontendIsReady()
errorCh := s.getEventChannel(events.ErrorEvent)
credentialsErrorCh := s.getEventChannel(events.CredentialsErrorEvent)
outgoingNoEncCh := s.getEventChannel(events.OutgoingNoEncEvent)
noActiveKeyForRecipientCh := s.getEventChannel(events.NoActiveKeyForRecipientEvent)
internetOffCh := s.getEventChannel(events.InternetOffEvent)
internetOnCh := s.getEventChannel(events.InternetOnEvent)
secondInstanceCh := s.getEventChannel(events.SecondInstanceEvent)
restartBridgeCh := s.getEventChannel(events.RestartBridgeEvent)
addressChangedCh := s.getEventChannel(events.AddressChangedEvent)
addressChangedLogoutCh := s.getEventChannel(events.AddressChangedLogoutEvent)
logoutCh := s.getEventChannel(events.LogoutEvent)
updateApplicationCh := s.getEventChannel(events.UpgradeApplicationEvent)
newUserCh := s.getEventChannel(events.UserRefreshEvent)
certIssue := s.getEventChannel(events.TLSCertIssue)
errorCh := s.eventListener.ProvideChannel(events.ErrorEvent)
credentialsErrorCh := s.eventListener.ProvideChannel(events.CredentialsErrorEvent)
outgoingNoEncCh := s.eventListener.ProvideChannel(events.OutgoingNoEncEvent)
noActiveKeyForRecipientCh := s.eventListener.ProvideChannel(events.NoActiveKeyForRecipientEvent)
internetOffCh := s.eventListener.ProvideChannel(events.InternetOffEvent)
internetOnCh := s.eventListener.ProvideChannel(events.InternetOnEvent)
secondInstanceCh := s.eventListener.ProvideChannel(events.SecondInstanceEvent)
restartBridgeCh := s.eventListener.ProvideChannel(events.RestartBridgeEvent)
addressChangedCh := s.eventListener.ProvideChannel(events.AddressChangedEvent)
addressChangedLogoutCh := s.eventListener.ProvideChannel(events.AddressChangedLogoutEvent)
logoutCh := s.eventListener.ProvideChannel(events.LogoutEvent)
updateApplicationCh := s.eventListener.ProvideChannel(events.UpgradeApplicationEvent)
newUserCh := s.eventListener.ProvideChannel(events.UserRefreshEvent)
certIssue := s.eventListener.ProvideChannel(events.TLSCertIssue)
for {
select {
case errorDetails := <-errorCh:
@ -254,13 +254,6 @@ func (s *FrontendQt) watchEvents() {
}
}
func (s *FrontendQt) getEventChannel(event string) <-chan string {
ch := make(chan string)
s.eventListener.Add(event, ch)
s.eventListener.RetryEmit(event)
return ch
}
// Loop function for tests.
//
// It runs QtExecute in new thread with function returning itself after setup.
@ -653,10 +646,6 @@ func (s *FrontendQt) isSMTPSTARTTLS() bool {
return !s.settings.GetBool(settings.SMTPSSLKey)
}
func (s *FrontendQt) checkInternet() {
s.Qml.SetConnectionStatus(s.bridge.CheckConnection() == nil)
}
func (s *FrontendQt) switchAddressModeUser(iAccount int) {
defer s.Qml.ProcessFinished()
userID := s.Accounts.get(iAccount).UserID()

View File

@ -84,7 +84,6 @@ type GoQMLInterface struct {
_ string `property:"progressDescription"`
_ func(isAvailable bool) `signal:"setConnectionStatus"`
_ func() `slot:"checkInternet"`
_ func() `slot:"setToRestart"`
@ -205,8 +204,6 @@ func (s *GoQMLInterface) SetFrontend(f *FrontendQt) {
return f.programVer
})
s.ConnectCheckInternet(f.checkInternet)
s.ConnectSetToRestart(f.restarter.SetToRestart)
s.ConnectToggleIsReportingOutgoingNoEnc(f.toggleIsReportingOutgoingNoEnc)

View File

@ -55,7 +55,6 @@ type UserManager interface {
GetUser(query string) (User, error)
DeleteUser(userID string, clearCache bool) error
ClearData() error
CheckConnection() error
}
// User is an interface of user needed by frontend.

View File

@ -38,11 +38,10 @@ type bridgeUser interface {
IsCombinedAddressMode() bool
GetAddressID(address string) (string, error)
GetPrimaryAddress() string
UpdateUser() error
Logout() error
CloseConnection(address string)
GetStore() storeUserProvider
GetTemporaryPMAPIClient() pmapi.Client
GetClient() pmapi.Client
}
type bridgeWrap struct {

View File

@ -0,0 +1,198 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imap
import (
"io"
"net/mail"
"strings"
"time"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/internal/imap/uidplus"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/emersion/go-imap"
"github.com/pkg/errors"
)
// CreateMessage appends a new message to this mailbox. The \Recent flag will
// be added regardless of whether flags is empty or not. If date is nil, the
// current time will be used.
//
// If the Backend implements Updater, it must notify the client immediately
// via a mailbox update.
func (im *imapMailbox) CreateMessage(flags []string, date time.Time, body imap.Literal) error {
return im.logCommand(func() error {
return im.createMessage(flags, date, body)
}, "APPEND", flags, date)
}
func (im *imapMailbox) createMessage(flags []string, date time.Time, body imap.Literal) error { //nolint[funlen]
// Called from go-imap in goroutines - we need to handle panics for each function.
defer im.panicHandler.HandlePanic()
m, _, _, readers, err := message.Parse(body)
if err != nil {
return err
}
addr := im.storeAddress.APIAddress()
if addr == nil {
return errors.New("no available address for encryption")
}
m.AddressID = addr.ID
kr, err := im.user.client().KeyRingForAddressID(addr.ID)
if err != nil {
return err
}
// Handle imported messages which have no "Sender" address.
// This sometimes occurs with outlook which reports errors as imported emails or for drafts.
if m.Sender == nil {
im.log.Warning("Append: Missing email sender. Will use main address")
m.Sender = &mail.Address{
Name: "",
Address: addr.Email,
}
}
// "Drafts" needs to call special API routes.
// Clients always append the whole message again and remove the old one.
if im.storeMailbox.LabelID() == pmapi.DraftLabel {
// Sender address needs to be sanitised (drafts need to match cases exactly).
m.Sender.Address = pmapi.ConstructAddress(m.Sender.Address, addr.Email)
draft, _, err := im.user.storeUser.CreateDraft(kr, m, readers, "", "", "")
if err != nil {
return errors.Wrap(err, "failed to create draft")
}
targetSeq := im.storeMailbox.GetUIDList([]string{draft.ID})
return uidplus.AppendResponse(im.storeMailbox.UIDValidity(), targetSeq)
}
// We need to make sure this is an import, and not a sent message from this account
// (sent messages from the account will be added by the event loop).
if im.storeMailbox.LabelID() == pmapi.SentLabel {
sanitizedSender := pmapi.SanitizeEmail(m.Sender.Address)
// Check whether this message was sent by a bridge user.
user, err := im.user.backend.bridge.GetUser(sanitizedSender)
if err == nil && user.ID() == im.storeUser.UserID() {
logEntry := im.log.WithField("addr", sanitizedSender).WithField("extID", m.Header.Get("Message-Id"))
// If we find the message in the store already, we can skip importing it.
if foundUID := im.storeMailbox.GetUIDByHeader(&m.Header); foundUID != uint32(0) {
logEntry.Info("Ignoring APPEND of duplicate to Sent folder")
return uidplus.AppendResponse(im.storeMailbox.UIDValidity(), &uidplus.OrderedSeq{foundUID})
}
// We didn't find the message in the store, so we are currently sending it.
logEntry.WithField("time", date).Info("No matching UID, continuing APPEND to Sent")
}
}
message.ParseFlags(m, flags)
if !date.IsZero() {
m.Time = date.Unix()
}
internalID := m.Header.Get("X-Pm-Internal-Id")
references := m.Header.Get("References")
referenceList := strings.Fields(references)
// In case there is a mail client which corrupts headers, try
// "References" too.
if internalID == "" && len(referenceList) > 0 {
lastReference := referenceList[len(referenceList)-1]
match := pmapi.RxInternalReferenceFormat.FindStringSubmatch(lastReference)
if len(match) == 2 {
internalID = match[1]
}
}
im.user.appendExpungeLock.Lock()
defer im.user.appendExpungeLock.Unlock()
// Avoid appending a message which is already on the server. Apply the
// new label instead. This always happens with Outlook (it uses APPEND
// instead of COPY).
if internalID != "" {
// Check to see if this belongs to a different address in split mode or another ProtonMail account.
msg, err := im.storeMailbox.GetMessage(internalID)
if err == nil && (im.user.user.IsCombinedAddressMode() || (im.storeAddress.AddressID() == msg.Message().AddressID)) {
IDs := []string{internalID}
// See the comment bellow.
if msg.IsMarkedDeleted() {
if err := im.storeMailbox.MarkMessagesUndeleted(IDs); err != nil {
log.WithError(err).Error("Failed to undelete re-imported internal message")
}
}
err = im.storeMailbox.LabelMessages(IDs)
if err != nil {
return err
}
targetSeq := im.storeMailbox.GetUIDList(IDs)
return uidplus.AppendResponse(im.storeMailbox.UIDValidity(), targetSeq)
}
}
im.log.Info("Importing external message")
if err := im.importMessage(m, readers, kr); err != nil {
im.log.Error("Import failed: ", err)
return err
}
// IMAP clients can move message to local folder (setting \Deleted flag)
// and then move it back (IMAP client does not remember the message,
// so instead removing the flag it imports duplicate message).
// Regular IMAP server would keep the message twice and later EXPUNGE would
// not delete the message (EXPUNGE would delete the original message and
// the new duplicate one would stay). API detects duplicates; therefore
// we need to remove \Deleted flag if IMAP client re-imports.
msg, err := im.storeMailbox.GetMessage(m.ID)
if err == nil && msg.IsMarkedDeleted() {
if err := im.storeMailbox.MarkMessagesUndeleted([]string{m.ID}); err != nil {
log.WithError(err).Error("Failed to undelete re-imported message")
}
}
targetSeq := im.storeMailbox.GetUIDList([]string{m.ID})
return uidplus.AppendResponse(im.storeMailbox.UIDValidity(), targetSeq)
}
func (im *imapMailbox) importMessage(m *pmapi.Message, readers []io.Reader, kr *crypto.KeyRing) (err error) {
body, err := message.BuildEncrypted(m, readers, kr)
if err != nil {
return err
}
labels := []string{}
for _, l := range m.LabelIDs {
if l == pmapi.StarredLabel {
labels = append(labels, pmapi.StarredLabel)
}
}
return im.storeMailbox.ImportMessage(m, body, labels)
}

View File

@ -0,0 +1,322 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imap
import (
"bytes"
"context"
"github.com/ProtonMail/proton-bridge/internal/imap/cache"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/emersion/go-imap"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
func (im *imapMailbox) getMessage(
storeMessage storeMessageProvider,
items []imap.FetchItem,
msgBuildCountHistogram *msgBuildCountHistogram,
) (msg *imap.Message, err error) {
msglog := im.log.WithField("msgID", storeMessage.ID())
msglog.Trace("Getting message")
seqNum, err := storeMessage.SequenceNumber()
if err != nil {
return
}
m := storeMessage.Message()
msg = imap.NewMessage(seqNum, items)
for _, item := range items {
switch item {
case imap.FetchEnvelope:
// No need to check IsFullHeaderCached here. API header
// contain enough information to build the envelope.
msg.Envelope = message.GetEnvelope(m, storeMessage.GetMIMEHeader())
case imap.FetchBody, imap.FetchBodyStructure:
structure, err := im.getBodyStructure(storeMessage)
if err != nil {
return nil, err
}
if msg.BodyStructure, err = structure.IMAPBodyStructure([]int{}); err != nil {
return nil, err
}
case imap.FetchFlags:
msg.Flags = message.GetFlags(m)
if storeMessage.IsMarkedDeleted() {
msg.Flags = append(msg.Flags, imap.DeletedFlag)
}
case imap.FetchInternalDate:
// Apple Mail crashes fetching messages with date older than 1970.
// There is no point having message older than RFC itself, it's not possible.
msg.InternalDate = message.SanitizeMessageDate(m.Time)
case imap.FetchRFC822Size:
if msg.Size, err = im.getSize(storeMessage); err != nil {
return nil, err
}
case imap.FetchUid:
if msg.Uid, err = storeMessage.UID(); err != nil {
return nil, err
}
case imap.FetchAll, imap.FetchFast, imap.FetchFull, imap.FetchRFC822, imap.FetchRFC822Header, imap.FetchRFC822Text:
fallthrough // this is list of defined items by go-imap, but items can be also sections generated from requests
default:
if err = im.getLiteralForSection(item, msg, storeMessage, msgBuildCountHistogram); err != nil {
return
}
}
}
return msg, err
}
// getSize returns cached size or it will build the message, save the size in
// DB and then returns the size after build.
//
// We are storing size in DB as part of pmapi messages metada. The size
// attribute on the server represents size of encrypted body. The value is
// cleared in Bridge and the final decrypted size (including header, attachment
// and MIME structure) is computed after building the message.
func (im *imapMailbox) getSize(storeMessage storeMessageProvider) (uint32, error) {
m := storeMessage.Message()
if m.Size <= 0 {
im.log.WithField("msgID", m.ID).Debug("Size unknown - downloading body")
// We are sure the size is not a problem right now. Clients
// might not first check sizes of all messages so we couldn't
// be sure if seeing 1st or 2nd sync is all right or not.
// Therefore, it's better to exclude getting size from the
// counting and see build count as real message build.
if _, _, err := im.getBodyAndStructure(storeMessage, nil); err != nil {
return 0, err
}
}
return uint32(m.Size), nil
}
func (im *imapMailbox) getLiteralForSection(
itemSection imap.FetchItem,
msg *imap.Message,
storeMessage storeMessageProvider,
msgBuildCountHistogram *msgBuildCountHistogram,
) error {
section, err := imap.ParseBodySectionName(itemSection)
if err != nil {
log.WithError(err).Warn("Failed to parse body section name; part will be skipped")
return nil //nolint[nilerr] ignore error
}
var literal imap.Literal
if literal, err = im.getMessageBodySection(storeMessage, section, msgBuildCountHistogram); err != nil {
return err
}
msg.Body[section] = literal
return nil
}
// getBodyStructure returns the cached body structure or it will build the message,
// save the structure in DB and then returns the structure after build.
//
// Apple Mail requests body structure for all messages irregularly. We cache
// bodystructure in local database in order to not re-download all messages
// from server.
func (im *imapMailbox) getBodyStructure(storeMessage storeMessageProvider) (bs *message.BodyStructure, err error) {
bs, err = storeMessage.GetBodyStructure()
if err != nil {
im.log.WithError(err).Debug("Fail to retrieve bodystructure from database")
}
if bs == nil {
// We are sure the body structure is not a problem right now.
// Clients might do first fetch body structure so we couldn't
// be sure if seeing 1st or 2nd sync is all right or not.
// Therefore, it's better to exclude first body structure fetch
// from the counting and see build count as real message build.
if bs, _, err = im.getBodyAndStructure(storeMessage, nil); err != nil {
return
}
}
return
}
func (im *imapMailbox) getBodyAndStructure(
storeMessage storeMessageProvider, msgBuildCountHistogram *msgBuildCountHistogram,
) (
structure *message.BodyStructure, bodyReader *bytes.Reader, err error,
) {
m := storeMessage.Message()
id := im.storeUser.UserID() + m.ID
cache.BuildLock(id)
defer cache.BuildUnlock(id)
bodyReader, structure = cache.LoadMail(id)
// return the message which was found in cache
if bodyReader.Len() != 0 && structure != nil {
return structure, bodyReader, nil
}
structure, body, err := im.buildMessage(m)
bodyReader = bytes.NewReader(body)
size := int64(len(body))
l := im.log.WithField("newSize", size).WithField("msgID", m.ID)
if err != nil || structure == nil || size == 0 {
l.WithField("hasStructure", structure != nil).Warn("Failed to build message")
return structure, bodyReader, err
}
// Save the size, body structure and header even for messages which
// were unable to decrypt. Hence they doesn't have to be computed every
// time.
m.Size = size
cacheMessageInStore(storeMessage, structure, body, l)
if msgBuildCountHistogram != nil {
times, errCount := storeMessage.IncreaseBuildCount()
if errCount != nil {
l.WithError(errCount).Warn("Cannot increase build count")
}
msgBuildCountHistogram.add(times)
}
// Drafts can change therefore we don't want to cache them.
if !isMessageInDraftFolder(m) {
cache.SaveMail(id, body, structure)
}
return structure, bodyReader, err
}
func cacheMessageInStore(storeMessage storeMessageProvider, structure *message.BodyStructure, body []byte, l *logrus.Entry) {
m := storeMessage.Message()
if errSize := storeMessage.SetSize(m.Size); errSize != nil {
l.WithError(errSize).Warn("Cannot update size while building")
}
if structure != nil && !isMessageInDraftFolder(m) {
if errStruct := storeMessage.SetBodyStructure(structure); errStruct != nil {
l.WithError(errStruct).Warn("Cannot update bodystructure while building")
}
}
header, errHead := structure.GetMailHeaderBytes(bytes.NewReader(body))
if errHead == nil && len(header) != 0 {
if errStore := storeMessage.SetHeader(header); errStore != nil {
l.WithError(errStore).Warn("Cannot update header in store")
}
} else {
l.WithError(errHead).Warn("Cannot get header bytes from structure")
}
}
func isMessageInDraftFolder(m *pmapi.Message) bool {
for _, labelID := range m.LabelIDs {
if labelID == pmapi.DraftLabel {
return true
}
}
return false
}
// This will download message (or read from cache) and pick up the section,
// extract data (header,body, both) and trim the output if needed.
//
// In order to speed up (avoid download and decryptions) we
// cache the header. If a mail header was requested and DB
// contains full header (it means it was already built once)
// the DB header can be used without downloading and decrypting.
// Otherwise header is incomplete and clients would have issues
// e.g. AppleMail expects `text/plain` in HTML mails.
//
// For all other cases it is necessary to download and decrypt the message
// and drop the header which was obtained from cache. The header will
// will be stored in DB once successfully built. Check `getBodyAndStructure`.
func (im *imapMailbox) getMessageBodySection(
storeMessage storeMessageProvider,
section *imap.BodySectionName,
msgBuildCountHistogram *msgBuildCountHistogram,
) (imap.Literal, error) {
var header []byte
var response []byte
im.log.WithField("msgID", storeMessage.ID()).Trace("Getting message body")
isMainHeaderRequested := len(section.Path) == 0 && section.Specifier == imap.HeaderSpecifier
if isMainHeaderRequested && storeMessage.IsFullHeaderCached() {
header = storeMessage.GetHeader()
} else {
structure, bodyReader, err := im.getBodyAndStructure(storeMessage, msgBuildCountHistogram)
if err != nil {
return nil, err
}
switch {
case section.Specifier == imap.EntireSpecifier && len(section.Path) == 0:
// An empty section specification refers to the entire message, including the header.
response, err = structure.GetSection(bodyReader, section.Path)
case section.Specifier == imap.TextSpecifier || (section.Specifier == imap.EntireSpecifier && len(section.Path) != 0):
// The TEXT specifier refers to the content of the message (or section), omitting the [RFC-2822] header.
// Non-empty section with no specifier (imap.EntireSpecifier) refers to section content without header.
response, err = structure.GetSectionContent(bodyReader, section.Path)
case section.Specifier == imap.MIMESpecifier: // The MIME part specifier refers to the [MIME-IMB] header for this part.
fallthrough
case section.Specifier == imap.HeaderSpecifier:
header, err = structure.GetSectionHeaderBytes(bodyReader, section.Path)
default:
err = errors.New("Unknown specifier " + string(section.Specifier))
}
if err != nil {
return nil, err
}
}
if header != nil {
response = filterHeader(header, section)
}
// Trim any output if requested.
return bytes.NewBuffer(section.ExtractPartial(response)), nil
}
// buildMessage from PM to IMAP.
func (im *imapMailbox) buildMessage(m *pmapi.Message) (*message.BodyStructure, []byte, error) {
body, err := im.builder.NewJobWithOptions(
context.Background(),
im.user.client(),
m.ID,
message.JobOptions{
IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead.
SanitizeDate: true, // Whether to replace all dates before 1970 with RFC822's birthdate.
AddInternalID: true, // Whether to include MessageID as X-Pm-Internal-Id.
AddExternalID: true, // Whether to include ExternalID as X-Pm-External-Id.
AddMessageDate: true, // Whether to include message time as X-Pm-Date.
AddMessageIDReference: true, // Whether to include the MessageID in References.
},
).GetResult()
if err != nil {
return nil, nil, err
}
structure, err := message.NewBodyStructure(bytes.NewReader(body))
if err != nil {
return nil, nil, err
}
return structure, body, nil
}

View File

@ -0,0 +1,67 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imap
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestFilterHeader(t *testing.T) {
const header = "To: somebody\r\nFrom: somebody else\r\nSubject: this is\r\n\ta multiline field\r\n\r\n"
assert.Equal(t, "To: somebody\r\n\r\n", string(filterHeaderLines([]byte(header), func(field string) bool {
return strings.EqualFold(field, "To")
})))
assert.Equal(t, "From: somebody else\r\n\r\n", string(filterHeaderLines([]byte(header), func(field string) bool {
return strings.EqualFold(field, "From")
})))
assert.Equal(t, "To: somebody\r\nFrom: somebody else\r\n\r\n", string(filterHeaderLines([]byte(header), func(field string) bool {
return strings.EqualFold(field, "To") || strings.EqualFold(field, "From")
})))
assert.Equal(t, "Subject: this is\r\n\ta multiline field\r\n\r\n", string(filterHeaderLines([]byte(header), func(field string) bool {
return strings.EqualFold(field, "Subject")
})))
}
// TestFilterHeaderNoNewline tests that we don't include a trailing newline when filtering
// if the original header also lacks one (which it can legally do if there is no body).
func TestFilterHeaderNoNewline(t *testing.T) {
const header = "To: somebody\r\nFrom: somebody else\r\nSubject: this is\r\n\ta multiline field\r\n"
assert.Equal(t, "To: somebody\r\n", string(filterHeaderLines([]byte(header), func(field string) bool {
return strings.EqualFold(field, "To")
})))
assert.Equal(t, "From: somebody else\r\n", string(filterHeaderLines([]byte(header), func(field string) bool {
return strings.EqualFold(field, "From")
})))
assert.Equal(t, "To: somebody\r\nFrom: somebody else\r\n", string(filterHeaderLines([]byte(header), func(field string) bool {
return strings.EqualFold(field, "To") || strings.EqualFold(field, "From")
})))
assert.Equal(t, "Subject: this is\r\n\ta multiline field\r\n", string(filterHeaderLines([]byte(header), func(field string) bool {
return strings.EqualFold(field, "Subject")
})))
}

View File

@ -0,0 +1,104 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imap
import (
"bufio"
"bytes"
"io"
"strings"
"github.com/emersion/go-imap"
"github.com/pkg/errors"
)
func filterHeader(header []byte, section *imap.BodySectionName) []byte {
// Empty section.Fields means BODY[HEADER] was requested so we should return the full header.
if len(section.Fields) == 0 {
return header
}
fieldMap := make(map[string]struct{})
for _, field := range section.Fields {
fieldMap[strings.ToLower(field)] = struct{}{}
}
return filterHeaderLines(header, func(field string) bool {
_, ok := fieldMap[strings.ToLower(field)]
if section.NotFields {
ok = !ok
}
return ok
})
}
func filterHeaderLines(header []byte, wantField func(string) bool) []byte {
var res []byte
for _, line := range headerLines(header) {
if len(bytes.TrimSpace(line)) == 0 {
res = append(res, line...)
} else {
split := bytes.SplitN(line, []byte(": "), 2)
if len(split) != 2 {
continue
}
if wantField(string(bytes.ToLower(split[0]))) {
res = append(res, line...)
}
}
}
return res
}
// NOTE: This sucks because we trim and split stuff here already, only to do it again when we use this function!
func headerLines(header []byte) [][]byte {
var lines [][]byte
r := bufio.NewReader(bytes.NewReader(header))
for {
b, err := r.ReadBytes('\n')
if err != nil {
if err != io.EOF {
panic(errors.Wrap(err, "failed to read header line"))
}
break
}
switch {
case len(bytes.TrimSpace(b)) == 0:
lines = append(lines, b)
case len(bytes.SplitN(b, []byte(": "), 2)) != 2:
lines[len(lines)-1] = append(lines[len(lines)-1], b...)
default:
lines = append(lines, b)
}
}
return lines
}

View File

@ -1,540 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imap
import (
"bytes"
"context"
"fmt"
"io"
"net/mail"
"net/textproto"
"sort"
"strings"
"time"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/internal/imap/cache"
"github.com/ProtonMail/proton-bridge/internal/imap/uidplus"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/emersion/go-imap"
"github.com/hashicorp/go-multierror"
"github.com/pkg/errors"
)
var (
rfc822Birthday = time.Date(1982, 8, 13, 0, 0, 0, 0, time.UTC) //nolint[gochecknoglobals]
)
type doNotCacheError struct{ e error }
func (dnc *doNotCacheError) Error() string { return dnc.e.Error() }
func (dnc *doNotCacheError) add(err error) { dnc.e = multierror.Append(dnc.e, err) }
func (dnc *doNotCacheError) errorOrNil() error {
if dnc == nil {
return nil
}
if dnc.e != nil {
return dnc
}
return nil
}
// CreateMessage appends a new message to this mailbox. The \Recent flag will
// be added regardless of whether flags is empty or not. If date is nil, the
// current time will be used.
//
// If the Backend implements Updater, it must notify the client immediately
// via a mailbox update.
func (im *imapMailbox) CreateMessage(flags []string, date time.Time, body imap.Literal) error {
return im.logCommand(func() error {
return im.createMessage(flags, date, body)
}, "APPEND", flags, date)
}
func (im *imapMailbox) createMessage(flags []string, date time.Time, body imap.Literal) error { // nolint[funlen]
// Called from go-imap in goroutines - we need to handle panics for each function.
defer im.panicHandler.HandlePanic()
m, _, _, readers, err := message.Parse(body)
if err != nil {
return err
}
addr := im.storeAddress.APIAddress()
if addr == nil {
return errors.New("no available address for encryption")
}
m.AddressID = addr.ID
kr, err := im.user.client().KeyRingForAddressID(addr.ID)
if err != nil {
return err
}
// Handle imported messages which have no "Sender" address.
// This sometimes occurs with outlook which reports errors as imported emails or for drafts.
if m.Sender == nil {
im.log.Warning("Append: Missing email sender. Will use main address")
m.Sender = &mail.Address{
Name: "",
Address: addr.Email,
}
}
// "Drafts" needs to call special API routes.
// Clients always append the whole message again and remove the old one.
if im.storeMailbox.LabelID() == pmapi.DraftLabel {
// Sender address needs to be sanitised (drafts need to match cases exactly).
m.Sender.Address = pmapi.ConstructAddress(m.Sender.Address, addr.Email)
draft, _, err := im.user.storeUser.CreateDraft(kr, m, readers, "", "", "")
if err != nil {
return errors.Wrap(err, "failed to create draft")
}
targetSeq := im.storeMailbox.GetUIDList([]string{draft.ID})
return uidplus.AppendResponse(im.storeMailbox.UIDValidity(), targetSeq)
}
// We need to make sure this is an import, and not a sent message from this account
// (sent messages from the account will be added by the event loop).
if im.storeMailbox.LabelID() == pmapi.SentLabel {
sanitizedSender := pmapi.SanitizeEmail(m.Sender.Address)
// Check whether this message was sent by a bridge user.
user, err := im.user.backend.bridge.GetUser(sanitizedSender)
if err == nil && user.ID() == im.storeUser.UserID() {
logEntry := im.log.WithField("addr", sanitizedSender).WithField("extID", m.Header.Get("Message-Id"))
// If we find the message in the store already, we can skip importing it.
if foundUID := im.storeMailbox.GetUIDByHeader(&m.Header); foundUID != uint32(0) {
logEntry.Info("Ignoring APPEND of duplicate to Sent folder")
return uidplus.AppendResponse(im.storeMailbox.UIDValidity(), &uidplus.OrderedSeq{foundUID})
}
// We didn't find the message in the store, so we are currently sending it.
logEntry.WithField("time", date).Info("No matching UID, continuing APPEND to Sent")
}
}
message.ParseFlags(m, flags)
if !date.IsZero() {
m.Time = date.Unix()
}
internalID := m.Header.Get("X-Pm-Internal-Id")
references := m.Header.Get("References")
referenceList := strings.Fields(references)
// In case there is a mail client which corrupts headers, try
// "References" too.
if internalID == "" && len(referenceList) > 0 {
lastReference := referenceList[len(referenceList)-1]
match := pmapi.RxInternalReferenceFormat.FindStringSubmatch(lastReference)
if len(match) == 2 {
internalID = match[1]
}
}
im.user.appendExpungeLock.Lock()
defer im.user.appendExpungeLock.Unlock()
// Avoid appending a message which is already on the server. Apply the
// new label instead. This always happens with Outlook (it uses APPEND
// instead of COPY).
if internalID != "" {
// Check to see if this belongs to a different address in split mode or another ProtonMail account.
msg, err := im.storeMailbox.GetMessage(internalID)
if err == nil && (im.user.user.IsCombinedAddressMode() || (im.storeAddress.AddressID() == msg.Message().AddressID)) {
IDs := []string{internalID}
// See the comment bellow.
if msg.IsMarkedDeleted() {
if err := im.storeMailbox.MarkMessagesUndeleted(IDs); err != nil {
log.WithError(err).Error("Failed to undelete re-imported internal message")
}
}
err = im.storeMailbox.LabelMessages(IDs)
if err != nil {
return err
}
targetSeq := im.storeMailbox.GetUIDList(IDs)
return uidplus.AppendResponse(im.storeMailbox.UIDValidity(), targetSeq)
}
}
im.log.Info("Importing external message")
if err := im.importMessage(m, readers, kr); err != nil {
im.log.Error("Import failed: ", err)
return err
}
// IMAP clients can move message to local folder (setting \Deleted flag)
// and then move it back (IMAP client does not remember the message,
// so instead removing the flag it imports duplicate message).
// Regular IMAP server would keep the message twice and later EXPUNGE would
// not delete the message (EXPUNGE would delete the original message and
// the new duplicate one would stay). API detects duplicates; therefore
// we need to remove \Deleted flag if IMAP client re-imports.
msg, err := im.storeMailbox.GetMessage(m.ID)
if err == nil && msg.IsMarkedDeleted() {
if err := im.storeMailbox.MarkMessagesUndeleted([]string{m.ID}); err != nil {
log.WithError(err).Error("Failed to undelete re-imported message")
}
}
targetSeq := im.storeMailbox.GetUIDList([]string{m.ID})
return uidplus.AppendResponse(im.storeMailbox.UIDValidity(), targetSeq)
}
func (im *imapMailbox) importMessage(m *pmapi.Message, readers []io.Reader, kr *crypto.KeyRing) (err error) { // nolint[funlen]
body, err := message.BuildEncrypted(m, readers, kr)
if err != nil {
return err
}
labels := []string{}
for _, l := range m.LabelIDs {
if l == pmapi.StarredLabel {
labels = append(labels, pmapi.StarredLabel)
}
}
return im.storeMailbox.ImportMessage(m, body, labels)
}
func (im *imapMailbox) getMessage(storeMessage storeMessageProvider, items []imap.FetchItem, msgBuildCountHistogram *msgBuildCountHistogram) (msg *imap.Message, err error) { //nolint[funlen]
msglog := im.log.WithField("msgID", storeMessage.ID())
msglog.Trace("Getting message")
seqNum, err := storeMessage.SequenceNumber()
if err != nil {
return
}
m := storeMessage.Message()
msg = imap.NewMessage(seqNum, items)
for _, item := range items {
switch item {
case imap.FetchEnvelope:
// No need to check IsFullHeaderCached here. API header
// contain enough information to build the envelope.
msg.Envelope = message.GetEnvelope(m, storeMessage.GetHeader())
case imap.FetchBody, imap.FetchBodyStructure:
var structure *message.BodyStructure
structure, err = im.getBodyStructure(storeMessage)
if err != nil {
return
}
if msg.BodyStructure, err = structure.IMAPBodyStructure([]int{}); err != nil {
return
}
case imap.FetchFlags:
msg.Flags = message.GetFlags(m)
if storeMessage.IsMarkedDeleted() {
msg.Flags = append(msg.Flags, imap.DeletedFlag)
}
case imap.FetchInternalDate:
msg.InternalDate = time.Unix(m.Time, 0)
// Apple Mail crashes fetching messages with date older than 1970.
// There is no point having message older than RFC itself, it's not possible.
if msg.InternalDate.Before(rfc822Birthday) {
msg.InternalDate = rfc822Birthday
}
case imap.FetchRFC822Size:
// Size attribute on the server counts encrypted data. The value is cleared
// on our part and we need to compute "real" size of decrypted data.
if m.Size <= 0 {
msglog.Debug("Size unknown - downloading body")
// We are sure the size is not a problem right now. Clients
// might not first check sizes of all messages so we couldn't
// be sure if seeing 1st or 2nd sync is all right or not.
// Therefore, it's better to exclude getting size from the
// counting and see build count as real message build.
if _, _, err = im.getBodyAndStructure(storeMessage, nil); err != nil {
return
}
}
msg.Size = uint32(m.Size)
case imap.FetchUid:
msg.Uid, err = storeMessage.UID()
if err != nil {
return nil, err
}
case imap.FetchAll, imap.FetchFast, imap.FetchFull, imap.FetchRFC822, imap.FetchRFC822Header, imap.FetchRFC822Text:
fallthrough // this is list of defined items by go-imap, but items can be also sections generated from requests
default:
if err = im.getLiteralForSection(item, msg, storeMessage, msgBuildCountHistogram); err != nil {
return
}
}
}
return msg, err
}
func (im *imapMailbox) getLiteralForSection(itemSection imap.FetchItem, msg *imap.Message, storeMessage storeMessageProvider, msgBuildCountHistogram *msgBuildCountHistogram) error {
section, err := imap.ParseBodySectionName(itemSection)
if err != nil {
log.WithError(err).Warn("Failed to parse body section name; part will be skipped")
return nil //nolint[nilerr] ignore error
}
var literal imap.Literal
if literal, err = im.getMessageBodySection(storeMessage, section, msgBuildCountHistogram); err != nil {
return err
}
msg.Body[section] = literal
return nil
}
func (im *imapMailbox) getBodyStructure(storeMessage storeMessageProvider) (bs *message.BodyStructure, err error) {
// Apple Mail requests body structure for all
// messages irregularly. We cache bodystructure in
// local database in order to not re-download all
// messages from server.
bs, err = storeMessage.GetBodyStructure()
if err != nil {
im.log.WithError(err).Debug("Fail to retrieve bodystructure from database")
}
if bs == nil {
// We are sure the body structure is not a problem right now.
// Clients might do first fetch body structure so we couldn't
// be sure if seeing 1st or 2nd sync is all right or not.
// Therefore, it's better to exclude first body structure fetch
// from the counting and see build count as real message build.
if bs, _, err = im.getBodyAndStructure(storeMessage, nil); err != nil {
return
}
}
return
}
//nolint[funlen] Jakub will fix in refactor
func (im *imapMailbox) getBodyAndStructure(storeMessage storeMessageProvider, msgBuildCountHistogram *msgBuildCountHistogram) (
structure *message.BodyStructure,
bodyReader *bytes.Reader, err error,
) {
m := storeMessage.Message()
id := im.storeUser.UserID() + m.ID
cache.BuildLock(id)
if bodyReader, structure = cache.LoadMail(id); bodyReader.Len() == 0 || structure == nil {
var body []byte
structure, body, err = im.buildMessage(m)
m.Size = int64(len(body))
// Save size and body structure even for messages unable to decrypt
// so the size or body structure doesn't have to be computed every time.
if err := storeMessage.SetSize(m.Size); err != nil {
im.log.WithError(err).
WithField("newSize", m.Size).
WithField("msgID", m.ID).
Warn("Cannot update size while building")
}
if structure != nil && !isMessageInDraftFolder(m) {
if err := storeMessage.SetBodyStructure(structure); err != nil {
im.log.WithError(err).
WithField("msgID", m.ID).
Warn("Cannot update bodystructure while building")
}
}
if err == nil && structure != nil && len(body) > 0 {
header, errHead := structure.GetMailHeaderBytes(bytes.NewReader(body))
if errHead == nil {
if errHead := storeMessage.SetHeader(header); errHead != nil {
im.log.WithError(errHead).
WithField("msgID", m.ID).
Warn("Cannot update header after building")
}
} else {
im.log.WithError(errHead).
WithField("msgID", m.ID).
Warn("Cannot get header bytes after building")
}
if msgBuildCountHistogram != nil {
times, err := storeMessage.IncreaseBuildCount()
if err != nil {
im.log.WithError(err).
WithField("msgID", m.ID).
Warn("Cannot increase build count")
}
msgBuildCountHistogram.add(times)
}
// Drafts can change and we don't want to cache them.
if !isMessageInDraftFolder(m) {
cache.SaveMail(id, body, structure)
}
bodyReader = bytes.NewReader(body)
}
if _, ok := err.(*doNotCacheError); ok {
im.log.WithField("msgID", m.ID).Errorf("do not cache message: %v", err)
err = nil
bodyReader = bytes.NewReader(body)
}
}
cache.BuildUnlock(id)
return structure, bodyReader, err
}
func isMessageInDraftFolder(m *pmapi.Message) bool {
for _, labelID := range m.LabelIDs {
if labelID == pmapi.DraftLabel {
return true
}
}
return false
}
// This will download message (or read from cache) and pick up the section,
// extract data (header,body, both) and trim the output if needed.
func (im *imapMailbox) getMessageBodySection( //nolint[funlen]
storeMessage storeMessageProvider,
section *imap.BodySectionName,
msgBuildCountHistogram *msgBuildCountHistogram,
) (imap.Literal, error) {
var header textproto.MIMEHeader
var extraNewlineAfterHeader bool
var response []byte
im.log.WithField("msgID", storeMessage.ID()).Trace("Getting message body")
isMainHeaderRequested := len(section.Path) == 0 && section.Specifier == imap.HeaderSpecifier
if isMainHeaderRequested && storeMessage.IsFullHeaderCached() {
// In order to speed up (avoid download and decryptions) we
// cache the header. If a mail header was requested and DB
// contains full header (it means it was already built once)
// the DB header can be used without downloading and decrypting.
// Otherwise header is incomplete and clients would have issues
// e.g. AppleMail expects `text/plain` in HTML mails.
header = storeMessage.GetHeader()
} else {
// For all other cases it is necessary to download and decrypt the message
// and drop the header which was obtained from cache. The header will
// will be stored in DB once successfully built. Check `getBodyAndStructure`.
structure, bodyReader, err := im.getBodyAndStructure(storeMessage, msgBuildCountHistogram)
if err != nil {
return nil, err
}
switch {
case section.Specifier == imap.EntireSpecifier && len(section.Path) == 0:
// An empty section specification refers to the entire message, including the header.
response, err = structure.GetSection(bodyReader, section.Path)
case section.Specifier == imap.TextSpecifier || (section.Specifier == imap.EntireSpecifier && len(section.Path) != 0):
// The TEXT specifier refers to the content of the message (or section), omitting the [RFC-2822] header.
// Non-empty section with no specifier (imap.EntireSpecifier) refers to section content without header.
response, err = structure.GetSectionContent(bodyReader, section.Path)
case section.Specifier == imap.MIMESpecifier: // The MIME part specifier refers to the [MIME-IMB] header for this part.
fallthrough
case section.Specifier == imap.HeaderSpecifier:
if content, err := structure.GetSectionContent(bodyReader, section.Path); err == nil && content != nil {
extraNewlineAfterHeader = true
}
header, err = structure.GetSectionHeader(section.Path)
default:
err = errors.New("Unknown specifier " + string(section.Specifier))
}
if err != nil {
return nil, err
}
}
if header != nil {
response = filteredHeaderAsBytes(header, section)
// The blank line is included in all header fetches,
// except in the case of a message which has no body.
if extraNewlineAfterHeader {
response = append(response, []byte("\r\n")...)
}
}
// Trim any output if requested.
return bytes.NewBuffer(section.ExtractPartial(response)), nil
}
// filteredHeaderAsBytes filters the header fields by section fields and it
// returns the filtered fields as bytes.
// Options are: all fields, only selected fields, all fields except selected.
func filteredHeaderAsBytes(header textproto.MIMEHeader, section *imap.BodySectionName) []byte {
// remove fields
if len(section.Fields) != 0 && section.NotFields {
for _, field := range section.Fields {
header.Del(field)
}
}
fields := make([]string, 0, len(header))
if len(section.Fields) == 0 || section.NotFields { // add all and sort
for f := range header {
fields = append(fields, f)
}
sort.Strings(fields)
} else { // add only requested (in requested order)
for _, f := range section.Fields {
fields = append(fields, textproto.CanonicalMIMEHeaderKey(f))
}
}
headerBuf := &bytes.Buffer{}
for _, canonical := range fields {
if values, ok := header[canonical]; !ok {
continue
} else {
for _, val := range values {
fmt.Fprintf(headerBuf, "%s: %s\r\n", canonical, val)
}
}
}
return headerBuf.Bytes()
}
// buildMessage from PM to IMAP.
func (im *imapMailbox) buildMessage(m *pmapi.Message) (*message.BodyStructure, []byte, error) {
body, err := im.builder.NewJobWithOptions(
context.Background(),
im.user.client(),
m.ID,
message.JobOptions{
IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead.
SanitizeDate: true, // Whether to replace all dates before 1970 with RFC822's birthdate.
AddInternalID: true, // Whether to include MessageID as X-Pm-Internal-Id.
AddExternalID: true, // Whether to include ExternalID as X-Pm-External-Id.
AddMessageDate: true, // Whether to include message time as X-Pm-Date.
AddMessageIDReference: true, // Whether to include the MessageID in References.
},
).GetResult()
if err != nil {
return nil, nil, err
}
structure, err := message.NewBodyStructure(bytes.NewReader(body))
if err != nil {
return nil, nil, err
}
return structure, body, nil
}

View File

@ -18,7 +18,6 @@
package imap
import (
"errors"
"fmt"
"net/mail"
"strings"
@ -30,6 +29,7 @@ import (
"github.com/ProtonMail/proton-bridge/pkg/parallel"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/emersion/go-imap"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
@ -359,9 +359,8 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria)
}
}
// In order to speed up search it is not needed to check
// if IsFullHeaderCached.
header := storeMessage.GetHeader()
// In order to speed up search it is not needed to check if IsFullHeaderCached.
header := storeMessage.GetMIMEHeader()
if !criteria.SentBefore.IsZero() || !criteria.SentSince.IsZero() {
t, err := mail.Header(header).Date()
@ -422,7 +421,7 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria)
if isStringInList(m.LabelIDs, pmapi.StarredLabel) {
messageFlagsMap[imap.FlaggedFlag] = true
}
if m.Unread == 0 {
if !m.Unread {
messageFlagsMap[imap.SeenFlag] = true
}
if m.Has(pmapi.FlagReplied) || m.Has(pmapi.FlagRepliedAll) {
@ -526,18 +525,6 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima
return err
}
// From RFC: UID range of 559:* always includes the UID of the last message
// in the mailbox, even if 559 is higher than any assigned UID value.
// See: https://tools.ietf.org/html/rfc3501#page-61
if isUID && seqSet.Dynamic() && len(apiIDs) == 0 {
l.Debug("Requesting empty UID dynamic fetch, adding latest message")
apiID, err := im.storeMailbox.GetLatestAPIID()
if err != nil {
return nil
}
apiIDs = []string{apiID}
}
input := make([]interface{}, len(apiIDs))
for i, apiID := range apiIDs {
input[i] = apiID
@ -560,7 +547,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima
return nil, err
}
if storeMessage.Message().Unread == 1 {
if storeMessage.Message().Unread {
for section := range msg.Body {
// Peek means get messages without marking them as read.
// If client does not only ask for peek, we have to mark them as read.

View File

@ -32,8 +32,8 @@ import (
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/imap/id"
"github.com/ProtonMail/proton-bridge/internal/imap/uidplus"
"github.com/ProtonMail/proton-bridge/internal/serverutil"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/ports"
"github.com/emersion/go-imap"
imapappendlimit "github.com/emersion/go-imap-appendlimit"
imapidle "github.com/emersion/go-imap-idle"
@ -116,60 +116,63 @@ func NewIMAPServer(panicHandler panicHandler, debugClient, debugServer bool, por
return server
}
// Starts the server.
func (s *imapServer) ListenAndServe() {
go s.monitorDisconnectedUsers()
go s.monitorInternetConnection()
func (s *imapServer) HandlePanic() { s.panicHandler.HandlePanic() }
func (s *imapServer) IsRunning() bool { return s.isRunning.Load().(bool) }
func (s *imapServer) Port() int { return s.port }
// When starting the Bridge, we don't want to retry to notify user
// quickly about the issue. Very probably retry will not help anyway.
s.listenAndServe(0)
// ListenAndServe starts the server and keeps it on based on internet
// availability.
func (s *imapServer) ListenAndServe() {
serverutil.ListenAndServe(s, s.eventListener)
}
func (s *imapServer) listenAndServe(retries int) {
if s.isRunning.Load().(bool) {
// ListenRetryAndServe will start listener. If port is occupied it will try
// again after coolDown time. Once listener is OK it will serve.
func (s *imapServer) ListenRetryAndServe(retries int, retryAfter time.Duration) {
if s.IsRunning() {
return
}
s.isRunning.Store(true)
log.Info("IMAP server listening at ", s.server.Addr)
l, err := net.Listen("tcp", s.server.Addr)
l := log.WithField("address", s.server.Addr)
l.Info("IMAP server is starting")
listener, err := net.Listen("tcp", s.server.Addr)
if err != nil {
s.isRunning.Store(false)
if retries > 0 {
log.WithError(err).WithField("retries", retries).Warn("IMAP listener failed")
time.Sleep(15 * time.Second)
s.listenAndServe(retries - 1)
l.WithError(err).WithField("retries", retries).Warn("IMAP listener failed")
time.Sleep(retryAfter)
s.ListenRetryAndServe(retries-1, retryAfter)
return
}
log.WithError(err).Error("IMAP listener failed")
l.WithError(err).Error("IMAP listener failed")
s.eventListener.Emit(events.ErrorEvent, "IMAP failed: "+err.Error())
return
}
err = s.server.Serve(&connListener{
Listener: l,
Listener: listener,
server: s,
userAgent: s.userAgent,
})
// Serve returns error every time, even after closing the server.
// User shouldn't be notified about error if server shouldn't be running,
// but it should in case it was not closed by `s.Close()`.
if err != nil && s.isRunning.Load().(bool) {
if err != nil && s.IsRunning() {
s.isRunning.Store(false)
log.WithError(err).Error("IMAP server failed")
l.WithError(err).Error("IMAP server failed")
s.eventListener.Emit(events.ErrorEvent, "IMAP failed: "+err.Error())
return
}
defer s.server.Close() //nolint[errcheck]
log.Info("IMAP server stopped")
l.Info("IMAP server stopped")
}
// Stops the server.
func (s *imapServer) Close() {
if !s.isRunning.Load().(bool) {
if !s.IsRunning() {
return
}
s.isRunning.Store(false)
@ -180,62 +183,16 @@ func (s *imapServer) Close() {
}
}
func (s *imapServer) monitorInternetConnection() {
on := make(chan string)
s.eventListener.Add(events.InternetOnEvent, on)
off := make(chan string)
s.eventListener.Add(events.InternetOffEvent, off)
for {
var expectedIsPortFree bool
select {
case <-on:
go func() {
defer s.panicHandler.HandlePanic()
// We had issues on Mac that from time to time something
// blocked our port for a bit after we closed IMAP server
// due to connection issues.
// Restart always helped, so we do retry to not bother user.
s.listenAndServe(10)
}()
expectedIsPortFree = false
case <-off:
s.Close()
expectedIsPortFree = true
}
start := time.Now()
for {
if ports.IsPortFree(s.port) == expectedIsPortFree {
break
}
// Safety stop if something went wrong.
if time.Since(start) > 15*time.Second {
log.WithField("expectedIsPortFree", expectedIsPortFree).Warn("Server start/stop check timeouted")
break
}
time.Sleep(100 * time.Millisecond)
}
}
}
func (s *imapServer) monitorDisconnectedUsers() {
ch := make(chan string)
s.eventListener.Add(events.CloseConnectionEvent, ch)
for address := range ch {
address := address
log.Info("Disconnecting all open IMAP connections for ", address)
disconnectUser := func(conn imapserver.Conn) {
connUser := conn.Context().User
if connUser != nil && strings.EqualFold(connUser.Username(), address) {
if err := conn.Close(); err != nil {
log.WithError(err).Error("Failed to close the connection")
}
func (s *imapServer) DisconnectUser(address string) {
log.Info("Disconnecting all open IMAP connections for ", address)
s.server.ForEachConn(func(conn imapserver.Conn) {
connUser := conn.Context().User
if connUser != nil && strings.EqualFold(connUser.Username(), address) {
if err := conn.Close(); err != nil {
log.WithError(err).Error("Failed to close the connection")
}
}
s.server.ForEachConn(disconnectUser)
}
})
}
// connListener sets debug loggers on server containing fields with local

View File

@ -20,48 +20,33 @@ package imap
import (
"fmt"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/config/useragent"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/ports"
"github.com/ProtonMail/proton-bridge/internal/serverutil/mocks"
imapserver "github.com/emersion/go-imap/server"
"github.com/stretchr/testify/require"
)
type testPanicHandler struct{}
func (ph *testPanicHandler) HandlePanic() {}
func TestIMAPServerTurnOffAndOnAgain(t *testing.T) {
panicHandler := &testPanicHandler{}
r := require.New(t)
ts := mocks.NewTestServer(12345)
eventListener := listener.New()
port := ports.FindFreePortFrom(12345)
server := imapserver.New(nil)
server.Addr = fmt.Sprintf("%v:%v", bridge.Host, port)
server.Addr = fmt.Sprintf("%v:%v", bridge.Host, ts.WantPort)
s := &imapServer{
panicHandler: panicHandler,
panicHandler: ts.PanicHandler,
server: server,
eventListener: eventListener,
port: ts.WantPort,
eventListener: ts.EventListener,
userAgent: useragent.New(),
}
s.isRunning.Store(false)
r.True(ts.IsPortFree())
go s.ListenAndServe()
time.Sleep(5 * time.Second)
require.False(t, ports.IsPortFree(port))
eventListener.Emit(events.InternetOffEvent, "")
time.Sleep(10 * time.Second)
require.True(t, ports.IsPortFree(port))
eventListener.Emit(events.InternetOnEvent, "")
time.Sleep(10 * time.Second)
require.False(t, ports.IsPortFree(port))
ts.RunServerTests(r)
}

View File

@ -102,7 +102,8 @@ type storeMessageProvider interface {
SetSize(int64) error
SetHeader([]byte) error
GetHeader() textproto.MIMEHeader
GetHeader() []byte
GetMIMEHeader() textproto.MIMEHeader
IsFullHeaderCached() bool
SetBodyStructure(*pkgMsg.BodyStructure) error
GetBodyStructure() (*pkgMsg.BodyStructure, error)

View File

@ -93,7 +93,7 @@ func newIMAPUser(
// This method should eventually no longer be necessary. Everything should go via store.
func (iu *imapUser) client() pmapi.Client {
return iu.user.GetTemporaryPMAPIClient()
return iu.user.GetClient()
}
func (iu *imapUser) isSubscribed(labelID string) bool {

View File

@ -20,7 +20,9 @@ package importexport
import (
"bytes"
"context"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/transfer"
"github.com/ProtonMail/proton-bridge/internal/users"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
@ -39,7 +41,8 @@ type ImportExport struct {
locations Locator
cache Cacher
panicHandler users.PanicHandler
clientManager users.ClientManager
eventListener listener.Listener
clientManager pmapi.Manager
}
func New(
@ -47,7 +50,7 @@ func New(
cache Cacher,
panicHandler users.PanicHandler,
eventListener listener.Listener,
clientManager users.ClientManager,
clientManager pmapi.Manager,
credStorer users.CredentialsStorer,
) *ImportExport {
u := users.New(locations, panicHandler, eventListener, clientManager, credStorer, &storeFactory{}, false)
@ -58,63 +61,38 @@ func New(
locations: locations,
cache: cache,
panicHandler: panicHandler,
eventListener: eventListener,
clientManager: clientManager,
}
}
// ReportBug reports a new bug from the user.
func (ie *ImportExport) ReportBug(osType, osVersion, description, accountName, address, emailClient string) error {
c := ie.clientManager.GetAnonymousClient()
defer c.Logout()
title := "[Import-Export] Bug"
report := pmapi.ReportReq{
return ie.clientManager.ReportBug(context.Background(), pmapi.ReportBugReq{
OS: osType,
OSVersion: osVersion,
Browser: emailClient,
Title: title,
Title: "[Import-Export] Bug",
Description: description,
Username: accountName,
Email: address,
}
if err := c.Report(report); err != nil {
log.Error("Reporting bug failed: ", err)
return err
}
log.Info("Bug successfully reported")
return nil
})
}
// ReportFile submits import report file.
func (ie *ImportExport) ReportFile(osType, osVersion, accountName, address string, logdata []byte) error {
c := ie.clientManager.GetAnonymousClient()
defer c.Logout()
title := "[Import-Export] report file"
description := "An Import-Export report from the user swam down the river."
report := pmapi.ReportReq{
report := pmapi.ReportBugReq{
OS: osType,
OSVersion: osVersion,
Description: description,
Title: title,
Description: "An Import-Export report from the user swam down the river.",
Title: "[Import-Export] report file",
Username: accountName,
Email: address,
}
report.AddAttachment("log", "report.log", bytes.NewReader(logdata))
if err := c.Report(report); err != nil {
log.Error("Sending report failed: ", err)
return err
}
log.Info("Report successfully sent")
return nil
return ie.clientManager.ReportBug(context.Background(), report)
}
// GetLocalImporter returns transferrer from local EML or MBOX structure to ProtonMail account.
@ -187,5 +165,23 @@ func (ie *ImportExport) getPMAPIProvider(username, address string) (*transfer.PM
log.WithError(err).Info("Address does not exist, using all addresses")
}
return transfer.NewPMAPIProvider(ie.clientManager, user.ID(), addressID)
provider, err := transfer.NewPMAPIProvider(user.GetClient(), user.ID(), addressID)
if err != nil {
return nil, err
}
go func() {
internetOffCh := ie.eventListener.ProvideChannel(events.InternetOffEvent)
internetOnCh := ie.eventListener.ProvideChannel(events.InternetOnEvent)
for {
select {
case <-internetOffCh:
provider.SetConnectionDown()
case <-internetOnCh:
provider.SetConnectionUp()
}
}
}()
return provider, nil
}

View File

@ -0,0 +1,150 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package mocks
import (
"fmt"
"net/http"
"sync/atomic"
"time"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/ports"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
type DummyPanicHandler struct{}
func (ph *DummyPanicHandler) HandlePanic() {}
type TestServer struct {
PanicHandler *DummyPanicHandler
WantPort int
EventListener listener.Listener
isRunning atomic.Value
srv *http.Server
}
func NewTestServer(port int) *TestServer {
s := &TestServer{
PanicHandler: &DummyPanicHandler{},
EventListener: listener.New(),
WantPort: ports.FindFreePortFrom(port),
}
s.isRunning.Store(false)
return s
}
func (s *TestServer) IsPortFree() bool {
return true
}
func (s *TestServer) IsPortOccupied() bool {
return true
}
func (s *TestServer) Emit(event string, try, iEvt int) int {
// Emit has separate go routine so it is needed to wait here to
// prevent event race condition.
time.Sleep(100 * time.Millisecond)
iEvt++
s.EventListener.Emit(event, fmt.Sprintf("%d:%d", try, iEvt))
return iEvt
}
func (s *TestServer) HandlePanic() {}
func (s *TestServer) DisconnectUser(string) {}
func (s *TestServer) Port() int { return s.WantPort }
func (s *TestServer) IsRunning() bool { return s.isRunning.Load().(bool) }
func (s *TestServer) ListenRetryAndServe(retries int, retryAfter time.Duration) {
if s.isRunning.Load().(bool) {
return
}
s.isRunning.Store(true)
// There can be delay when starting server
time.Sleep(200 * time.Millisecond)
s.srv = &http.Server{
Addr: fmt.Sprintf("127.0.0.1:%d", s.WantPort),
}
err := s.srv.ListenAndServe()
if err != nil {
s.isRunning.Store(false)
if retries > 0 {
time.Sleep(retryAfter)
s.ListenRetryAndServe(retries-1, retryAfter)
}
}
if s.IsRunning() {
logrus.Error("Not serving but isRunning is true")
s.isRunning.Store(false)
}
}
func (s *TestServer) Close() {
if !s.isRunning.Load().(bool) {
return
}
s.isRunning.Store(false)
// There can be delay when stopping server
time.Sleep(200 * time.Millisecond)
if err := s.srv.Close(); err != nil {
logrus.WithError(err).Error("Closing dummy server")
}
}
func (s *TestServer) RunServerTests(r *require.Assertions) {
// NOTE About choosing tick durations:
// In order to avoid ticks to synchronise and cause occasional race
// condition we choose the tick duration around 100ms but not exactly
// to have large common multiple.
r.Eventually(s.IsPortOccupied, 5*time.Second, 97*time.Millisecond)
// There was an issue where second time we were not able to restore server.
for try := 0; try < 3; try++ {
i := s.Emit(events.InternetOffEvent, try, 0)
r.Eventually(s.IsPortFree, 10*time.Second, 99*time.Millisecond, "signal off try %d : %d", try, i)
i = s.Emit(events.InternetOnEvent, try, i)
i = s.Emit(events.InternetOffEvent, try, i)
i = s.Emit(events.InternetOffEvent, try, i)
i = s.Emit(events.InternetOffEvent, try, i)
i = s.Emit(events.InternetOffEvent, try, i)
i = s.Emit(events.InternetOnEvent, try, i)
i = s.Emit(events.InternetOnEvent, try, i)
i = s.Emit(events.InternetOffEvent, try, i)
// Wait a bit longer if needed to process all events
r.Eventually(s.IsPortFree, 20*time.Second, 101*time.Millisecond, "again signal off number %d : %d", try, i)
i = s.Emit(events.InternetOnEvent, try, i)
r.Eventually(s.IsPortOccupied, 10*time.Second, 103*time.Millisecond, "signal on number %d : %d", try, i)
i = s.Emit(events.InternetOffEvent, try, i)
i = s.Emit(events.InternetOnEvent, try, i)
i = s.Emit(events.InternetOnEvent, try, i)
r.Eventually(s.IsPortOccupied, 10*time.Second, 107*time.Millisecond, "again signal on number %d : %d", try, i)
}
}

View File

@ -0,0 +1,132 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package serverutil
import (
"time"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/ports"
"github.com/sirupsen/logrus"
)
// Server which can handle disconnected users and lost internet connection.
type Server interface {
HandlePanic()
DisconnectUser(string)
ListenRetryAndServe(int, time.Duration)
Close()
Port() int
IsRunning() bool
}
func monitorDisconnectedUsers(s Server, l listener.Listener) {
ch := make(chan string)
l.Add(events.CloseConnectionEvent, ch)
for address := range ch {
s.DisconnectUser(address)
}
}
func redirectInternetEventsToOneChannel(l listener.Listener) (isInternetOn chan bool) {
on := make(chan string)
l.Add(events.InternetOnEvent, on)
off := make(chan string)
l.Add(events.InternetOffEvent, off)
// Redirect two channels into one. When select was used the algorithm
// first read all on channels and then read all off channels.
isInternetOn = make(chan bool, 20)
go func() {
for {
logrus.WithField("try", <-on).Trace("Internet ON")
isInternetOn <- true
}
}()
go func() {
for {
logrus.WithField("try", <-off).Trace("Internet OFF")
isInternetOn <- false
}
}()
return
}
const (
recheckPortAfter = 50 * time.Millisecond
stopPortChecksAfter = 15 * time.Second
retryListnerAfter = 5 * time.Second
)
func monitorInternetConnection(s Server, l listener.Listener) {
isInternetOn := redirectInternetEventsToOneChannel(l)
for {
var expectedIsPortFree bool
if <-isInternetOn {
if s.IsRunning() {
continue
}
go func() {
defer s.HandlePanic()
// We had issues on Mac that from time to time something
// blocked our port for a bit after we closed IMAP server
// due to connection issues.
// Restart always helped, so we do retry to not bother user.
s.ListenRetryAndServe(10, retryListnerAfter)
}()
expectedIsPortFree = false
} else {
if !s.IsRunning() {
continue
}
s.Close()
expectedIsPortFree = true
}
start := time.Now()
for {
isPortFree := ports.IsPortFree(s.Port())
logrus.
WithField("port", s.Port()).
WithField("isFree", isPortFree).
WithField("wantToBeFree", expectedIsPortFree).
Trace("Check port")
if isPortFree == expectedIsPortFree {
break
}
// Safety stop if something went wrong.
if time.Since(start) > stopPortChecksAfter {
logrus.WithField("expectedIsPortFree", expectedIsPortFree).Warn("Server start/stop check timeouted")
break
}
time.Sleep(recheckPortAfter)
}
}
}
// ListenAndServe starts the server and keeps it on based on internet
// availability. It also monitors and disconnect users if requested.
func ListenAndServe(s Server, l listener.Listener) {
go monitorDisconnectedUsers(s, l)
go monitorInternetConnection(s, l)
// When starting the Bridge, we don't want to retry to notify user
// quickly about the issue. Very probably retry will not help anyway.
s.ListenRetryAndServe(0, 0)
}

View File

@ -15,9 +15,21 @@
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package users
package serverutil
// IsAuthorized returns whether the user has received an Auth from the API yet.
func (u *User) IsAuthorized() bool {
return u.isAuthorized
import (
"testing"
"github.com/ProtonMail/proton-bridge/internal/serverutil/mocks"
"github.com/stretchr/testify/require"
)
func TestServerTurnOffAndOnAgain(t *testing.T) {
r := require.New(t)
s := mocks.NewTestServer(12321)
r.True(s.IsPortFree())
go ListenAndServe(s, s.EventListener)
s.RunServerTests(r)
}

View File

@ -31,7 +31,7 @@ type bridgeUser interface {
CheckBridgeLogin(password string) error
IsCombinedAddressMode() bool
GetAddressID(address string) (string, error)
GetTemporaryPMAPIClient() pmapi.Client
GetClient() pmapi.Client
GetStore() storeUserProvider
}

View File

@ -18,6 +18,7 @@
package smtp
import (
"context"
"crypto/sha256"
"fmt"
"strings"
@ -28,7 +29,7 @@ import (
)
type messageGetter interface {
GetMessage(string) (*pmapi.Message, error)
GetMessage(context.Context, string) (*pmapi.Message, error)
}
type sendRecorderValue struct {
@ -126,7 +127,7 @@ func (q *sendRecorder) isSendingOrSent(client messageGetter, hash string) (isSen
return true, false
}
message, err := client.GetMessage(value.messageID)
message, err := client.GetMessage(context.TODO(), value.messageID)
// Message could be deleted or there could be an internet issue or whatever,
// so let's assume the message was not sent.
if err != nil {

View File

@ -18,6 +18,7 @@
package smtp
import (
"context"
"errors"
"fmt"
"net/mail"
@ -33,7 +34,7 @@ type testSendRecorderGetMessageMock struct {
err error
}
func (m *testSendRecorderGetMessageMock) GetMessage(messageID string) (*pmapi.Message, error) {
func (m *testSendRecorderGetMessageMock) GetMessage(_ context.Context, messageID string) (*pmapi.Message, error) {
return m.message, m.err
}

View File

@ -20,30 +20,34 @@ package smtp
import (
"crypto/tls"
"fmt"
"net"
"sync/atomic"
"time"
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/serverutil"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/emersion/go-sasl"
goSMTP "github.com/emersion/go-smtp"
"github.com/sirupsen/logrus"
)
type smtpServer struct {
// Server is Bridge SMTP server implementation.
type Server struct {
panicHandler panicHandler
backend goSMTP.Backend
server *goSMTP.Server
eventListener listener.Listener
debug bool
useSSL bool
port int
tls *tls.Config
isRunning atomic.Value
}
// NewSMTPServer returns an SMTP server configured with the given options.
func NewSMTPServer(debug bool, port int, useSSL bool, tls *tls.Config, smtpBackend goSMTP.Backend, eventListener listener.Listener) *smtpServer { //nolint[golint]
s := goSMTP.NewServer(smtpBackend)
s.Addr = fmt.Sprintf("%v:%v", bridge.Host, port)
s.TLSConfig = tls
s.Domain = bridge.Host
s.AllowInsecureAuth = true
s.MaxLineLength = 2 << 16
func NewSMTPServer(panicHandler panicHandler, debug bool, port int, useSSL bool, tls *tls.Config, smtpBackend goSMTP.Backend, eventListener listener.Listener) *Server {
if debug {
fmt.Println("THE LOG WILL CONTAIN **DECRYPTED** MESSAGE DATA")
log.Warning("================================================")
@ -51,13 +55,38 @@ func NewSMTPServer(debug bool, port int, useSSL bool, tls *tls.Config, smtpBacke
log.Warning("================================================")
}
server := &Server{
panicHandler: panicHandler,
backend: smtpBackend,
eventListener: eventListener,
debug: debug,
useSSL: useSSL,
port: port,
tls: tls,
}
server.isRunning.Store(false)
return server
}
func (s *Server) HandlePanic() { s.panicHandler.HandlePanic() }
func (s *Server) IsRunning() bool { return s.isRunning.Load().(bool) }
func (s *Server) Port() int { return s.port }
func newGoSMTPServer(debug bool, smtpBackend goSMTP.Backend, port int, tls *tls.Config) *goSMTP.Server {
newSMTP := goSMTP.NewServer(smtpBackend)
newSMTP.Addr = fmt.Sprintf("%v:%v", bridge.Host, port)
newSMTP.TLSConfig = tls
newSMTP.Domain = bridge.Host
newSMTP.AllowInsecureAuth = true
newSMTP.MaxLineLength = 1 << 16
if debug {
s.Debug = logrus.
newSMTP.Debug = logrus.
WithField("pkg", "smtp/server").
WriterLevel(logrus.DebugLevel)
}
s.EnableAuth(sasl.Login, func(conn *goSMTP.Conn) sasl.Server {
newSMTP.EnableAuth(sasl.Login, func(conn *goSMTP.Conn) sasl.Server {
return sasl.NewLoginServer(func(address, password string) error {
user, err := conn.Server().Backend.Login(nil, address, password)
if err != nil {
@ -68,57 +97,92 @@ func NewSMTPServer(debug bool, port int, useSSL bool, tls *tls.Config, smtpBacke
return nil
})
})
return &smtpServer{
server: s,
eventListener: eventListener,
useSSL: useSSL,
}
return newSMTP
}
// Starts the server.
func (s *smtpServer) ListenAndServe() {
go s.monitorDisconnectedUsers()
l := log.WithField("useSSL", s.useSSL).WithField("address", s.server.Addr)
// ListenAndServe starts the server and keeps it on based on internet
// availability.
func (s *Server) ListenAndServe() {
serverutil.ListenAndServe(s, s.eventListener)
}
l.Info("SMTP server is starting")
var err error
if s.useSSL {
err = s.server.ListenAndServeTLS()
} else {
err = s.server.ListenAndServe()
}
if err != nil {
s.eventListener.Emit(events.ErrorEvent, "SMTP failed: "+err.Error())
l.Error("SMTP failed: ", err)
func (s *Server) ListenRetryAndServe(retries int, retryAfter time.Duration) {
if s.IsRunning() {
return
}
defer s.server.Close() //nolint[errcheck]
s.isRunning.Store(true)
l.Info("SMTP server stopped")
s.server = newGoSMTPServer(s.debug, s.backend, s.port, s.tls)
l := log.WithField("useSSL", s.useSSL).WithField("address", s.server.Addr)
l.Info("SMTP server is starting")
var listener net.Listener
var err error
if s.useSSL {
listener, err = tls.Listen("tcp", s.server.Addr, s.server.TLSConfig)
} else {
listener, err = net.Listen("tcp", s.server.Addr)
}
l.WithError(err).Debug("Listener for SMTP created")
if err != nil {
s.isRunning.Store(false)
if retries > 0 {
l.WithError(err).WithField("retries", retries).Warn("SMTP listener failed")
time.Sleep(retryAfter)
s.ListenRetryAndServe(retries-1, retryAfter)
return
}
l.WithError(err).Error("SMTP listener failed")
s.eventListener.Emit(events.ErrorEvent, "SMTP failed: "+err.Error())
return
}
err = s.server.Serve(listener)
l.WithError(err).Debug("GoSMTP not serving")
// Serve returns error every time, even after closing the server.
// User shouldn't be notified about error if server shouldn't be running,
// but it should in case it was not closed by `s.Close()`.
if err != nil && s.IsRunning() {
s.isRunning.Store(false)
l.WithError(err).Error("SMTP server failed")
s.eventListener.Emit(events.ErrorEvent, "SMTP failed: "+err.Error())
return
}
defer func() {
// Go SMTP server instance can be closed only once. Otherwise
// it returns an error. The error is not export therefore we
// will check the string value.
err := s.server.Close()
if err == nil || err.Error() != "smtp: server already closed" {
l.WithError(err).Warn("Server was not closed")
}
}()
l.Info("SMTP server closed")
}
// Stops the server.
func (s *smtpServer) Close() {
// Close stops the server.
func (s *Server) Close() {
if !s.IsRunning() {
return
}
s.isRunning.Store(false)
if err := s.server.Close(); err != nil {
log.WithError(err).Error("Failed to close the connection")
log.WithError(err).Error("Cannot close the server")
}
}
func (s *smtpServer) monitorDisconnectedUsers() {
ch := make(chan string)
s.eventListener.Add(events.CloseConnectionEvent, ch)
for address := range ch {
log.Info("Disconnecting all open SMTP connections for ", address)
disconnectUser := func(conn *goSMTP.Conn) {
connUser := conn.Session()
if connUser != nil {
if err := conn.Close(); err != nil {
log.WithError(err).Error("Failed to close the connection")
}
func (s *Server) DisconnectUser(address string) {
log.Info("Disconnecting all open SMTP connections for ", address)
s.server.ForEachConn(func(conn *goSMTP.Conn) {
connUser := conn.Session()
if connUser != nil {
if err := conn.Close(); err != nil {
log.WithError(err).Error("Failed to close the connection")
}
}
s.server.ForEachConn(disconnectUser)
}
})
}

View File

@ -15,29 +15,29 @@
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
package smtp
import (
"net/url"
"testing"
"github.com/ProtonMail/proton-bridge/internal/serverutil/mocks"
"github.com/stretchr/testify/require"
)
// SendSimpleMetric makes a simple GET request to send a simple metrics report.
func (c *client) SendSimpleMetric(category, action, label string) (err error) {
v := url.Values{}
v.Set("Category", category)
v.Set("Action", action)
v.Set("Label", label)
func TestSMTPServerTurnOffAndOnAgain(t *testing.T) {
r := require.New(t)
ts := mocks.NewTestServer(12342)
req, err := c.NewRequest("GET", "/metrics?"+v.Encode(), nil)
if err != nil {
return
s := &Server{
panicHandler: ts.PanicHandler,
port: ts.WantPort,
eventListener: ts.EventListener,
}
s.isRunning.Store(false)
var res Res
if err = c.DoJSON(req, &res); err != nil {
return
}
r.True(ts.IsPortFree())
err = res.Err()
return
go s.ListenAndServe()
ts.RunServerTests(r)
}

View File

@ -21,10 +21,10 @@ package smtp
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"mime"
"net/mail"
"strings"
@ -81,7 +81,7 @@ func newSMTPUser(
// This method should eventually no longer be necessary. Everything should go via store.
func (su *smtpUser) client() pmapi.Client {
return su.user.GetTemporaryPMAPIClient()
return su.user.GetClient()
}
// Send sends an email from the given address to the given addresses with the given body.
@ -123,7 +123,7 @@ func (su *smtpUser) getSendPreferences(
}
func (su *smtpUser) getContactVCardData(recipient string) (meta *ContactMetadata, err error) {
emails, err := su.client().GetContactEmailByEmail(recipient, 0, 1000)
emails, err := su.client().GetContactEmailByEmail(context.TODO(), recipient, 0, 1000)
if err != nil {
return
}
@ -135,7 +135,7 @@ func (su *smtpUser) getContactVCardData(recipient string) (meta *ContactMetadata
}
var contact pmapi.Contact
if contact, err = su.client().GetContactByID(email.ContactID); err != nil {
if contact, err = su.client().GetContactByID(context.TODO(), email.ContactID); err != nil {
return
}
@ -151,7 +151,7 @@ func (su *smtpUser) getContactVCardData(recipient string) (meta *ContactMetadata
}
func (su *smtpUser) getAPIKeyData(recipient string) (apiKeys []pmapi.PublicKey, isInternal bool, err error) {
return su.client().GetPublicKeysForEmail(recipient)
return su.client().GetPublicKeysForEmail(context.TODO(), recipient)
}
// Discard currently processed message.
@ -219,7 +219,7 @@ func (su *smtpUser) Send(returnPath string, to []string, messageReader io.Reader
messageReader = io.TeeReader(messageReader, b)
mailSettings, err := su.client().GetMailSettings()
mailSettings, err := su.client().GetMailSettings(context.TODO())
if err != nil {
return err
}
@ -325,12 +325,6 @@ func (su *smtpUser) Send(returnPath string, to []string, messageReader io.Reader
return nil
}
if ok, err := su.isTotalSizeOkay(message, attReaders); err != nil {
return err
} else if !ok {
return errors.New("message is too large")
}
su.backend.sendRecorder.addMessage(sendRecorderMessageHash)
message, atts, err := su.storeUser.CreateDraft(kr, message, attReaders, attachedPublicKey, attachedPublicKeyName, parentID)
if err != nil {
@ -346,7 +340,7 @@ func (su *smtpUser) Send(returnPath string, to []string, messageReader io.Reader
// can lead to sending the wrong message. Also clients do not necessarily
// delete the old draft.
if draftID != "" {
if err := su.client().DeleteMessages([]string{draftID}); err != nil {
if err := su.client().DeleteMessages(context.TODO(), []string{draftID}); err != nil {
log.WithError(err).WithField("draftID", draftID).Warn("Original draft cannot be deleted")
}
}
@ -400,7 +394,7 @@ func (su *smtpUser) Send(returnPath string, to []string, messageReader io.Reader
return errors.New("error decoding subject message " + message.Header.Get("Subject"))
}
if !su.continueSendingUnencryptedMail(subject) {
if err := su.client().DeleteMessages([]string{message.ID}); err != nil {
if err := su.client().DeleteMessages(context.TODO(), []string{message.ID}); err != nil {
log.WithError(err).Warn("Failed to delete canceled messages")
}
return errors.New("sending was canceled by user")
@ -429,7 +423,7 @@ func (su *smtpUser) handleReferencesHeader(m *pmapi.Message) (draftID, parentID
if su.addressID != "" {
filter.AddressID = su.addressID
}
metadata, _, _ := su.client().ListMessages(filter)
metadata, _, _ := su.client().ListMessages(context.TODO(), filter)
for _, m := range metadata {
if m.IsDraft() {
draftID = m.ID
@ -449,7 +443,7 @@ func (su *smtpUser) handleReferencesHeader(m *pmapi.Message) (draftID, parentID
if su.addressID != "" {
filter.AddressID = su.addressID
}
metadata, _, _ := su.client().ListMessages(filter)
metadata, _, _ := su.client().ListMessages(context.TODO(), filter)
// There can be two or messages with the same external ID and then we cannot
// be sure which message should be parent. Better to not choose any.
if len(metadata) == 1 {
@ -541,24 +535,3 @@ func (su *smtpUser) Logout() error {
log.Debug("SMTP client logged out user ", su.addressID)
return nil
}
func (su *smtpUser) isTotalSizeOkay(message *pmapi.Message, attReaders []io.Reader) (bool, error) {
maxUpload, err := su.storeUser.GetMaxUpload()
if err != nil {
return false, err
}
var attSize int64
for i := range attReaders {
b, err := ioutil.ReadAll(attReaders[i])
if err != nil {
return false, err
}
attSize += int64(len(b))
attReaders[i] = bytes.NewBuffer(b)
}
return message.Size+attSize <= maxUpload, nil
}

View File

@ -90,7 +90,7 @@ func getLabelPrefix(l *pmapi.Label) string {
switch {
case pmapi.IsSystemLabel(l.ID):
return ""
case l.Exclusive == 1:
case bool(l.Exclusive):
return UserFoldersPrefix
default:
return UserLabelsPrefix

View File

@ -37,8 +37,8 @@ func TestNotifyChangeCreateOrUpdateMessage(t *testing.T) {
m.newStoreNoEvents(true)
m.store.SetChangeNotifier(m.changeNotifier)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
}
func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) {
@ -52,8 +52,8 @@ func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) {
m.newStoreNoEvents(true)
m.store.SetChangeNotifier(m.changeNotifier)
msg1 := getTestMessage("msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
msg2 := getTestMessage("msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
msg1 := getTestMessage("msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
msg2 := getTestMessage("msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
require.Nil(t, m.store.createOrUpdateMessagesEvent([]*pmapi.Message{msg1, msg2}))
}
@ -63,8 +63,8 @@ func TestNotifyChangeDeleteMessage(t *testing.T) {
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(2))
m.changeNotifier.EXPECT().DeleteMessage(addr1, "All Mail", uint32(1))

View File

@ -18,6 +18,7 @@
package store
import (
"context"
"math/rand"
"time"
@ -80,7 +81,7 @@ func (loop *eventLoop) client() pmapi.Client {
func (loop *eventLoop) setFirstEventID() (err error) {
loop.log.Info("Setting first event ID")
event, err := loop.client().GetEvent("")
event, err := loop.client().GetEvent(context.Background(), "")
if err != nil {
loop.log.WithError(err).Error("Could not get latest event ID")
return
@ -99,6 +100,11 @@ func (loop *eventLoop) setFirstEventID() (err error) {
// pollNow starts polling events right away and waits till the events are
// processed so we are sure updates are propagated to the database.
func (loop *eventLoop) pollNow() {
// When event loop is not running, it would cause infinite wait.
if !loop.isRunning {
return
}
eventProcessedCh := make(chan struct{})
loop.pollCh <- eventProcessedCh
<-eventProcessedCh
@ -216,7 +222,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
// We only want to consider invalid tokens as real errors because all other errors might fix themselves eventually
// (e.g. no internet, ulimit reached etc.)
defer func() {
if errors.Cause(err) == pmapi.ErrAPINotReachable {
if errors.Cause(err) == pmapi.ErrNoConnection {
l.Warn("Internet unavailable")
err = nil
}
@ -232,13 +238,12 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
err = nil
}
_, errUnauthorized := errors.Cause(err).(*pmapi.ErrUnauthorized)
if err == nil {
loop.errCounter = 0
}
// All errors except Invalid Token (which is not possible to recover from) are ignored.
if err != nil && !errUnauthorized && errors.Cause(err) != pmapi.ErrInvalidToken {
// All errors except ErrUnauthorized (which is not possible to recover from) are ignored.
if err != nil && errors.Cause(err) != pmapi.ErrUnauthorized {
l.WithError(err).WithField("errors", loop.errCounter).Error("Error skipped")
loop.errCounter++
if loop.errCounter == errMaxSentry {
@ -259,7 +264,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
loop.pollCounter++
var event *pmapi.Event
if event, err = loop.client().GetEvent(loop.currentEventID); err != nil {
if event, err = loop.client().GetEvent(context.Background(), loop.currentEventID); err != nil {
return false, errors.Wrap(err, "failed to get event")
}
@ -286,7 +291,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
}
}
return event.More == 1, err
return bool(event.More), err
}
func (loop *eventLoop) processEvent(event *pmapi.Event) (err error) {
@ -345,7 +350,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
// Get old addresses for comparisons before updating user.
oldList := loop.client().Addresses()
if err = loop.user.UpdateUser(); err != nil {
if err = loop.user.UpdateUser(context.Background()); err != nil {
if logoutErr := loop.user.Logout(); logoutErr != nil {
log.WithError(logoutErr).Error("Failed to logout user after failed update")
}
@ -456,8 +461,8 @@ func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi
msgLog.WithError(err).Warning("Message was not present in DB. Trying fetch...")
if msg, err = loop.client().GetMessage(message.ID); err != nil {
if _, ok := err.(*pmapi.ErrUnprocessableEntity); ok {
if msg, err = loop.client().GetMessage(context.Background(), message.ID); err != nil {
if _, ok := err.(pmapi.ErrUnprocessableEntity); ok {
msgLog.WithError(err).Warn("Skipping message update because message exists neither in local DB nor on API")
err = nil
continue

View File

@ -18,6 +18,7 @@
package store
import (
"context"
"net/mail"
"testing"
"time"
@ -39,17 +40,17 @@ func TestEventLoopProcessMoreEvents(t *testing.T) {
// Doesn't matter which IDs are used.
// This test is trying to see whether event loop will immediately process
// next event if there is `More` of them.
m.client.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
m.client.EXPECT().GetEvent(gomock.Any(), "latestEventID").Return(&pmapi.Event{
EventID: "event50",
More: 1,
More: true,
}, nil),
m.client.EXPECT().GetEvent("event50").Return(&pmapi.Event{
m.client.EXPECT().GetEvent(gomock.Any(), "event50").Return(&pmapi.Event{
EventID: "event70",
More: 0,
More: false,
}, nil),
m.client.EXPECT().GetEvent("event70").Return(&pmapi.Event{
m.client.EXPECT().GetEvent(gomock.Any(), "event70").Return(&pmapi.Event{
EventID: "event71",
More: 0,
More: false,
}, nil),
)
m.newStoreNoEvents(true)
@ -165,7 +166,7 @@ func TestEventLoopDeletionPaused(t *testing.T) {
func testEvent(t *testing.T, m *mocksForStore, event *pmapi.Event) {
eventReceived := make(chan struct{})
m.client.EXPECT().GetEvent("latestEventID").DoAndReturn(func(eventID string) (*pmapi.Event, error) {
m.client.EXPECT().GetEvent(gomock.Any(), "latestEventID").DoAndReturn(func(_ context.Context, eventID string) (*pmapi.Event, error) {
defer close(eventReceived)
return event, nil
})
@ -187,7 +188,7 @@ func TestEventLoopUpdateMessage(t *testing.T) {
msg := &pmapi.Message{
ID: "msg1",
Subject: "old",
Unread: 0,
Unread: false,
Flags: 10,
Sender: address1,
ToList: []*mail.Address{address2},
@ -199,7 +200,7 @@ func TestEventLoopUpdateMessage(t *testing.T) {
newMsg := &pmapi.Message{
ID: "msg1",
Subject: "new",
Unread: 1,
Unread: true,
Flags: 11,
Sender: address2,
ToList: []*mail.Address{address1},

View File

@ -129,17 +129,10 @@ func (mc *mailboxCounts) getPMLabel() *pmapi.Label {
Color: mc.Color,
Order: mc.Order,
Type: pmapi.LabelTypeMailbox,
Exclusive: mc.isExclusive(),
Exclusive: pmapi.Boolean(mc.IsFolder),
}
}
func (mc *mailboxCounts) isExclusive() int {
if mc.IsFolder {
return 1
}
return 0
}
// createOrUpdateMailboxCountsBuckets will not change the on-API-counts.
func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) error {
// Don't forget about system folders.
@ -162,7 +155,7 @@ func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) er
mailbox.LabelName = label.Path
mailbox.Color = label.Color
mailbox.Order = label.Order
mailbox.IsFolder = label.Exclusive == 1
mailbox.IsFolder = bool(label.Exclusive)
// Write.
if err = mailbox.txWriteToBucket(countsBkt); err != nil {

View File

@ -75,7 +75,7 @@ func TestMailboxNames(t *testing.T) {
newLabel(100, "labelID1", "Label1"),
newLabel(1000, "folderID1", "Folder1"),
}
foldersAndLabels[1].Exclusive = 1
foldersAndLabels[1].Exclusive = true
for _, counts := range getSystemFolders() {
foldersAndLabels = append(foldersAndLabels, counts.getPMLabel())

View File

@ -36,23 +36,41 @@ import (
func (storeMailbox *Mailbox) GetAPIIDsFromUIDRange(start, stop uint32) (apiIDs []string, err error) {
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
b := storeMailbox.txGetIMAPIDsBucket(tx)
c := b.Cursor()
// GODT-1153 If the mailbox is empty we should reply BAD to client.
if uid, _ := c.Last(); uid == nil {
return nil
}
// If the start range is a wildcard, the range can only refer to the last message in the mailbox.
if start == 0 {
_, apiID := c.Last()
apiIDs = append(apiIDs, string(apiID))
return nil
}
// Resolve the stop value to be the final UID in the mailbox.
if stop == 0 {
// A null stop means no stop.
stop = ^uint32(0)
stop = storeMailbox.txGetFinalUID(b)
}
// After resolving the stop value, it might be less than start so we sort it.
if start > stop {
start, stop = stop, start
}
startb := itob(start)
stopb := itob(stop)
c := b.Cursor()
for k, v := c.Seek(startb); k != nil && bytes.Compare(k, stopb) <= 0; k, v = c.Next() {
apiIDs = append(apiIDs, string(v))
}
return nil
})
return
return apiIDs, err
}
// GetAPIIDsFromSequenceRange returns API IDs by IMAP sequence number range.
@ -60,28 +78,52 @@ func (storeMailbox *Mailbox) GetAPIIDsFromSequenceRange(start, stop uint32) (api
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
b := storeMailbox.txGetIMAPIDsBucket(tx)
c := b.Cursor()
// GODT-1153 If the mailbox is empty we should reply BAD to client.
if uid, _ := c.Last(); uid == nil {
return nil
}
// If the start range is a wildcard, the range can only refer to the last message in the mailbox.
if start == 0 {
_, apiID := c.Last()
apiIDs = append(apiIDs, string(apiID))
return nil
}
var i uint32
for k, v := c.First(); k != nil; k, v = c.Next() {
i++
if i < start {
continue
}
if stop > 0 && i > stop {
break
}
apiIDs = append(apiIDs, string(v))
}
if stop == 0 && len(apiIDs) == 0 {
if _, apiID := c.Last(); len(apiID) > 0 {
apiIDs = append(apiIDs, string(apiID))
}
}
return nil
})
return
return apiIDs, err
}
// GetLatestAPIID returns the latest message API ID which still exists.
// Info: not the latest IMAP UID which can be already removed.
func (storeMailbox *Mailbox) GetLatestAPIID() (apiID string, err error) {
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
b := storeMailbox.txGetAPIIDsBucket(tx)
c := b.Cursor()
c := storeMailbox.txGetAPIIDsBucket(tx).Cursor()
lastAPIID, _ := c.Last()
apiID = string(lastAPIID)
if apiID == "" {
@ -283,3 +325,13 @@ func (storeMailbox *Mailbox) GetUIDByHeader(header *mail.Header) (foundUID uint3
return foundUID
}
func (storeMailbox *Mailbox) txGetFinalUID(b *bolt.Bucket) uint32 {
uid, _ := b.Cursor().Last()
if uid == nil {
panic(errors.New("cannot get final UID of empty mailbox"))
}
return btoi(uid)
}

View File

@ -37,10 +37,10 @@ func TestGetSequenceNumberAndGetUID(t *testing.T) {
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
insertMessage(t, m, "msg3", "Test message 3", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg4", "Test message 4", addrID1, false, []string{pmapi.AllMailLabel})
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
@ -56,7 +56,7 @@ func checkMailboxMessageIDs(t *testing.T, m *mocksForStore, mailboxLabel string,
storeAddress := m.store.addresses[addrID1]
storeMailbox := storeAddress.mailboxes[mailboxLabel]
ids, err := storeMailbox.GetAPIIDsFromSequenceRange(0, uint32(len(wantIDs)))
ids, err := storeMailbox.GetAPIIDsFromSequenceRange(1, uint32(len(wantIDs)))
require.Nil(t, err)
idx := 0
@ -82,20 +82,20 @@ func TestGetUIDByHeader(t *testing.T) { //nolint[funlen]
m.newStoreNoEvents(true)
tstMsg := getTestMessage("msg1", "Without external ID", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg := getTestMessage("msg1", "Without external ID", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
tstMsg = getTestMessage("msg2", "External ID with spaces", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg = getTestMessage("msg2", "External ID with spaces", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg.ExternalID = " externalID-non-pm-com "
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
tstMsg = getTestMessage("msg3", "External ID with <>", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg = getTestMessage("msg3", "External ID with <>", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg.ExternalID = "<externalID@pm.me>"
tstMsg.Header = mail.Header{"References": []string{"wrongID", "externalID-non-pm-com", "msg2"}}
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
// Not sure if this is a real-world scenario but we should be able to address this properly.
tstMsg = getTestMessage("msg4", "External ID with <> and spaces and special characters", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg = getTestMessage("msg4", "External ID with <> and spaces and special characters", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg.ExternalID = " < external.()+*[]ID@another.pm.me > "
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))

View File

@ -41,7 +41,7 @@ func (storeMailbox *Mailbox) GetMessage(apiID string) (*Message, error) {
// FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message
// wrapping it.
func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) {
msg, err := storeMailbox.client().GetMessage(apiID)
msg, err := storeMailbox.client().GetMessage(exposeContextForIMAP(), apiID)
if err != nil {
return nil, err
}
@ -58,15 +58,17 @@ func (storeMailbox *Mailbox) ImportMessage(msg *pmapi.Message, body []byte, labe
}
importReqs := &pmapi.ImportMsgReq{
AddressID: msg.AddressID,
Body: body,
Unread: msg.Unread,
Flags: msg.Flags,
Time: msg.Time,
LabelIDs: labelIDs,
Metadata: &pmapi.ImportMetadata{
AddressID: msg.AddressID,
Unread: msg.Unread,
Flags: msg.Flags,
Time: msg.Time,
LabelIDs: labelIDs,
},
Message: body,
}
res, err := storeMailbox.client().Import([]*pmapi.ImportMsgReq{importReqs})
res, err := storeMailbox.client().Import(exposeContextForIMAP(), pmapi.ImportMsgReqs{importReqs})
if err != nil {
return err
}
@ -95,7 +97,7 @@ func (storeMailbox *Mailbox) LabelMessages(apiIDs []string) error {
return ErrAllMailOpNotAllowed
}
defer storeMailbox.pollNow()
return storeMailbox.client().LabelMessages(apiIDs, storeMailbox.labelID)
return storeMailbox.client().LabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID)
}
// UnlabelMessages removes the label by calling an API.
@ -108,7 +110,7 @@ func (storeMailbox *Mailbox) UnlabelMessages(apiIDs []string) error {
return ErrAllMailOpNotAllowed
}
defer storeMailbox.pollNow()
return storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID)
return storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID)
}
// MarkMessagesRead marks the message read by calling an API.
@ -128,14 +130,14 @@ func (storeMailbox *Mailbox) MarkMessagesRead(apiIDs []string) error {
// Therefore we do not issue API update if the message is already read.
ids := []string{}
for _, apiID := range apiIDs {
if message, _ := storeMailbox.store.getMessageFromDB(apiID); message == nil || message.Unread == 1 {
if message, _ := storeMailbox.store.getMessageFromDB(apiID); message == nil || message.Unread {
ids = append(ids, apiID)
}
}
if len(ids) == 0 {
return nil
}
return storeMailbox.client().MarkMessagesRead(ids)
return storeMailbox.client().MarkMessagesRead(exposeContextForIMAP(), ids)
}
// MarkMessagesUnread marks the message unread by calling an API.
@ -147,7 +149,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnread(apiIDs []string) error {
"mailbox": storeMailbox.Name,
}).Trace("Marking messages as unread")
defer storeMailbox.pollNow()
return storeMailbox.client().MarkMessagesUnread(apiIDs)
return storeMailbox.client().MarkMessagesUnread(exposeContextForIMAP(), apiIDs)
}
// MarkMessagesStarred adds the Starred label by calling an API.
@ -160,7 +162,7 @@ func (storeMailbox *Mailbox) MarkMessagesStarred(apiIDs []string) error {
"mailbox": storeMailbox.Name,
}).Trace("Marking messages as starred")
defer storeMailbox.pollNow()
return storeMailbox.client().LabelMessages(apiIDs, pmapi.StarredLabel)
return storeMailbox.client().LabelMessages(exposeContextForIMAP(), apiIDs, pmapi.StarredLabel)
}
// MarkMessagesUnstarred removes the Starred label by calling an API.
@ -173,7 +175,7 @@ func (storeMailbox *Mailbox) MarkMessagesUnstarred(apiIDs []string) error {
"mailbox": storeMailbox.Name,
}).Trace("Marking messages as unstarred")
defer storeMailbox.pollNow()
return storeMailbox.client().UnlabelMessages(apiIDs, pmapi.StarredLabel)
return storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, pmapi.StarredLabel)
}
// MarkMessagesDeleted adds local flag \Deleted. This is not propagated to API
@ -257,11 +259,11 @@ func (storeMailbox *Mailbox) RemoveDeleted(apiIDs []string) error {
}
case pmapi.DraftLabel:
storeMailbox.log.WithField("ids", apiIDs).Warn("Deleting drafts")
if err := storeMailbox.client().DeleteMessages(apiIDs); err != nil {
if err := storeMailbox.client().DeleteMessages(exposeContextForIMAP(), apiIDs); err != nil {
return err
}
default:
if err := storeMailbox.client().UnlabelMessages(apiIDs, storeMailbox.labelID); err != nil {
if err := storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), apiIDs, storeMailbox.labelID); err != nil {
return err
}
}
@ -299,13 +301,13 @@ func (storeMailbox *Mailbox) deleteFromTrashOrSpam(apiIDs []string) error {
}
}
if len(messageIDsToUnlabel) > 0 {
if err := storeMailbox.client().UnlabelMessages(messageIDsToUnlabel, storeMailbox.labelID); err != nil {
if err := storeMailbox.client().UnlabelMessages(exposeContextForIMAP(), messageIDsToUnlabel, storeMailbox.labelID); err != nil {
l.WithError(err).Warning("Cannot unlabel before deleting")
}
}
if len(messageIDsToDelete) > 0 {
storeMailbox.log.WithField("ids", messageIDsToDelete).Warn("Deleting messages")
if err := storeMailbox.client().DeleteMessages(messageIDsToDelete); err != nil {
if err := storeMailbox.client().DeleteMessages(exposeContextForIMAP(), messageIDsToDelete); err != nil {
return err
}
}

View File

@ -25,6 +25,7 @@ import (
pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
bolt "go.etcd.io/bbolt"
)
@ -153,15 +154,27 @@ func (message *Message) getRawHeader() (raw []byte, err error) {
}
// GetHeader will return cached header from DB.
func (message *Message) GetHeader() textproto.MIMEHeader {
func (message *Message) GetHeader() []byte {
raw, err := message.getRawHeader()
if err != nil && raw == nil {
return textproto.MIMEHeader(message.msg.Header)
if err != nil {
panic(errors.Wrap(err, "failed to get raw message header"))
}
return raw
}
// GetMIMEHeader will return cached header from DB, parsed as a textproto.MIMEHeader.
func (message *Message) GetMIMEHeader() textproto.MIMEHeader {
raw, err := message.getRawHeader()
if err != nil {
panic(errors.Wrap(err, "failed to get raw message header"))
}
header, err := textproto.NewReader(bufio.NewReader(bytes.NewReader(raw))).ReadMIMEHeader()
if err != nil {
return textproto.MIMEHeader(message.msg.Header)
}
return header
}

View File

@ -1,10 +1,11 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,ClientManager,BridgeUser,ChangeNotifier)
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser,ChangeNotifier)
// Package mocks is a generated GoMock package.
package mocks
import (
context "context"
reflect "reflect"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
@ -46,43 +47,6 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
}
// MockClientManager is a mock of ClientManager interface
type MockClientManager struct {
ctrl *gomock.Controller
recorder *MockClientManagerMockRecorder
}
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
type MockClientManagerMockRecorder struct {
mock *MockClientManager
}
// NewMockClientManager creates a new mock instance
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
mock := &MockClientManager{ctrl: ctrl}
mock.recorder = &MockClientManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
return m.recorder
}
// GetClient mocks base method
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClient", arg0)
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetClient indicates an expected call of GetClient
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
}
// MockBridgeUser is a mock of BridgeUser interface
type MockBridgeUser struct {
ctrl *gomock.Controller
@ -145,6 +109,20 @@ func (mr *MockBridgeUserMockRecorder) GetAddressID(arg0 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddressID", reflect.TypeOf((*MockBridgeUser)(nil).GetAddressID), arg0)
}
// GetClient mocks base method
func (m *MockBridgeUser) GetClient() pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClient")
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetClient indicates an expected call of GetClient
func (mr *MockBridgeUserMockRecorder) GetClient() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockBridgeUser)(nil).GetClient))
}
// GetPrimaryAddress mocks base method
func (m *MockBridgeUser) GetPrimaryAddress() string {
m.ctrl.T.Helper()
@ -230,17 +208,17 @@ func (mr *MockBridgeUserMockRecorder) Logout() *gomock.Call {
}
// UpdateUser mocks base method
func (m *MockBridgeUser) UpdateUser() error {
func (m *MockBridgeUser) UpdateUser(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUser")
ret := m.ctrl.Call(m, "UpdateUser", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateUser indicates an expected call of UpdateUser
func (mr *MockBridgeUserMockRecorder) UpdateUser() *gomock.Call {
func (mr *MockBridgeUserMockRecorder) UpdateUser(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockBridgeUser)(nil).UpdateUser))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockBridgeUser)(nil).UpdateUser), arg0)
}
// MockChangeNotifier is a mock of ChangeNotifier interface

View File

@ -58,6 +58,20 @@ func (mr *MockListenerMockRecorder) Emit(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1)
}
// ProvideChannel mocks base method
func (m *MockListener) ProvideChannel(arg0 string) <-chan string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ProvideChannel", arg0)
ret0, _ := ret[0].(<-chan string)
return ret0
}
// ProvideChannel indicates an expected call of ProvideChannel
func (mr *MockListenerMockRecorder) ProvideChannel(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProvideChannel", reflect.TypeOf((*MockListener)(nil).ProvideChannel), arg0)
}
// Remove mocks base method
func (m *MockListener) Remove(arg0 string, arg1 chan<- string) {
m.ctrl.T.Helper()

View File

@ -19,6 +19,7 @@
package store
import (
"context"
"fmt"
"os"
"sync"
@ -100,13 +101,24 @@ var (
ErrNoSuchSeqNum = errors.New("no such sequence number") //nolint[gochecknoglobals]
)
// exposeContextForIMAP should be replaced once with context passed
// as an argument from IMAP package and IMAP library should cancel
// context when IMAP client cancels the request.
func exposeContextForIMAP() context.Context {
return context.TODO()
}
// exposeContextForSMTP is the same as above but for SMTP.
func exposeContextForSMTP() context.Context {
return context.TODO()
}
// Store is local user storage, which handles the synchronization between IMAP and PM API.
type Store struct {
sentryReporter *sentry.Reporter
panicHandler PanicHandler
eventLoop *eventLoop
user BridgeUser
clientManager ClientManager
log *logrus.Entry
@ -127,13 +139,12 @@ func New( // nolint[funlen]
sentryReporter *sentry.Reporter,
panicHandler PanicHandler,
user BridgeUser,
clientManager ClientManager,
events listener.Listener,
path string,
cache *Cache,
) (store *Store, err error) {
if user == nil || clientManager == nil || events == nil || cache == nil {
return nil, fmt.Errorf("missing parameters - user: %v, api: %v, events: %v, cache: %v", user, clientManager, events, cache)
if user == nil || events == nil || cache == nil {
return nil, fmt.Errorf("missing parameters - user: %v, events: %v, cache: %v", user, events, cache)
}
l := log.WithField("user", user.ID())
@ -156,7 +167,6 @@ func New( // nolint[funlen]
store = &Store{
sentryReporter: sentryReporter,
panicHandler: panicHandler,
clientManager: clientManager,
user: user,
cache: cache,
filePath: path,
@ -274,13 +284,13 @@ func (store *Store) init(firstInit bool) (err error) {
}
func (store *Store) client() pmapi.Client {
return store.clientManager.GetClient(store.UserID())
return store.user.GetClient()
}
// initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if
// the API is unavailable for whatever reason it tries to fetch the labels locally.
func (store *Store) initCounts() (labels []*pmapi.Label, err error) {
if labels, err = store.client().ListLabels(); err != nil {
if labels, err = store.client().ListLabels(context.Background()); err != nil {
store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.")
if labels, err = store.getLabelsFromLocalStorage(); err != nil {
store.log.WithError(err).Error("Cannot list local labels")

View File

@ -18,6 +18,7 @@
package store
import (
"context"
"fmt"
"io/ioutil"
"os"
@ -25,6 +26,7 @@ import (
"testing"
"time"
"github.com/ProtonMail/gopenpgp/v2/crypto"
storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
@ -39,8 +41,92 @@ const (
addr2 = "jamesandmichalarecool@pm.me"
addrID2 = "jamesandmichalarecool"
testPrivateKeyPassword = "apple"
testPrivateKey = `-----BEGIN PGP PRIVATE KEY BLOCK-----
Version: OpenPGP.js v0.7.1
Comment: http://openpgpjs.org
xcMGBFRJbc0BCAC0mMLZPDBbtSCWvxwmOfXfJkE2+ssM3ux21LhD/bPiWefE
WSHlCjJ8PqPHy7snSiUuxuj3f9AvXPvg+mjGLBwu1/QsnSP24sl3qD2onl39
vPiLJXUqZs20ZRgnvX70gjkgEzMFBxINiy2MTIG+4RU8QA7y8KzWev0btqKi
MeVa+GLEHhgZ2KPOn4Jv1q4bI9hV0C9NUe2tTXS6/Vv3vbCY7lRR0kbJ65T5
c8CmpqJuASIJNrSXM/Q3NnnsY4kBYH0s5d2FgbASQvzrjuC2rngUg0EoPsrb
DEVRA2/BCJonw7aASiNCrSP92lkZdtYlax/pcoE/mQ4WSwySFmcFT7yFABEB
AAH+CQMIvzcDReuJkc9gnxAkfgmnkBFwRQrqT/4UAPOF8WGVo0uNvDo7Snlk
qWsJS+54+/Xx6Jur/PdBWeEu+6+6GnppYuvsaT0D0nFdFhF6pjng+02IOxfG
qlYXYcW4hRru3BfvJlSvU2LL/Z/ooBnw3T5vqd0eFHKrvabUuwf0x3+K/sru
Fp24rl2PU+bzQlUgKpWzKDmO+0RdKQ6KVCyCDMIXaAkALwNffAvYxI0wnb2y
WAV/bGn1ODnszOYPk3pEMR6kKSxLLaO69kYx4eTERFyJ+1puAxEPCk3Cfeif
yDWi4rU03YB16XH7hQLSFl61SKeIYlkKmkO5Hk1ybi/BhvOGBPVeGGbxWnwI
46G8DfBHW0+uvD5cAQtk2d/q3Ge1I+DIyvuRCcSu0XSBNv/Bkpp4IbAUPBaW
TIvf5p9oxw+AjrMtTtcdSiee1S6CvMMaHhVD7SI6qGA8GqwaXueeLuEXa0Ok
BWlehx8wibMi4a9fLcQZtzJkmGhR1WzXcJfiEg32srILwIzPQYxuFdZZ2elb
gYp/bMEIp4LKhi43IyM6peCDHDzEba8NuOSd0heEqFIm0vlXujMhkyMUvDBv
H0V5On4aMuw/aSEKcAdbazppOru/W1ndyFa5ZHQIC19g72ZaDVyYjPyvNgOV
AFqO4o3IbC5z31zMlTtMbAq2RG9svwUVejn0tmF6UPluTe0U1NuXFpLK6TCH
wqocLz4ecptfJQulpYjClVLgzaYGDuKwQpIwPWg5G/DtKSCGNtEkfqB3aemH
V5xmoYm1v5CQZAEvvsrLA6jxCk9lzqYV8QMivWNXUG+mneIEM35G0HOPzXca
LLyB+N8Zxioc9DPGfdbcxXuVgOKRepbkq4xv1pUpMQ4BUmlkejDRSP+5SIR3
iEthg+FU6GRSQbORE6nhrKjGBk8fpNpozQZVc2VySUTCwHIEEAEIACYFAlRJ
bc8GCwkIBwMCCRA+tiWe3yHfJAQVCAIKAxYCAQIbAwIeAQAA9J0H/RLR/Uwt
CakrPKtfeGaNuOI45SRTNxM8TklC6tM28sJSzkX8qKPzvI1PxyLhs/i0/fCQ
7Z5bU6n41oLuqUt2S9vy+ABlChKAeziOqCHUcMzHOtbKiPkKW88aO687nx+A
ol2XOnMTkVIC+edMUgnKp6tKtZnbO4ea6Cg88TFuli4hLHNXTfCECswuxHOc
AO1OKDRrCd08iPI5CLNCIV60QnduitE1vF6ehgrH25Vl6LEdd8vPVlTYAvsa
6ySk2RIrHNLUZZ3iII3MBFL8HyINp/XA1BQP+QbH801uSLq8agxM4iFT9C+O
D147SawUGhjD5RG7T+YtqItzgA1V9l277EXHwwYEVEltzwEIAJD57uX6bOc4
Tgf3utfL/4hdyoqIMVHkYQOvE27wPsZxX08QsdlaNeGji9Ap2ifIDuckUqn6
Ji9jtZDKtOzdTBm6rnG5nPmkn6BJXPhnecQRP8N0XBISnAGmE4t+bxtts5Wb
qeMdxJYqMiGqzrLBRJEIDTcg3+QF2Y3RywOqlcXqgG/xX++PsvR1Jiz0rEVP
TcBc7ytyb/Av7mx1S802HRYGJHOFtVLoPTrtPCvv+DRDK8JzxQW2XSQLlI0M
9s1tmYhCogYIIqKx9qOTd5mFJ1hJlL6i9xDkvE21qPFASFtww5tiYmUfFaxI
LwbXPZlQ1I/8fuaUdOxctQ+g40ZgHPcAEQEAAf4JAwgdUg8ubE2BT2DITBD+
XFgjrnUlQBilbN8/do/36KHuImSPO/GGLzKh4+oXxrvLc5fQLjeO+bzeen4u
COCBRO0hG7KpJPhQ6+T02uEF6LegE1sEz5hp6BpKUdPZ1+8799Rylb5kubC5
IKnLqqpGDbH3hIsmSV3CG/ESkaGMLc/K0ZPt1JRWtUQ9GesXT0v6fdM5GB/L
cZWFdDoYgZAw5BtymE44knIodfDAYJ4DHnPCh/oilWe1qVTQcNMdtkpBgkuo
THecqEmiODQz5EX8pVmS596XsnPO299Lo3TbaHUQo7EC6Au1Au9+b5hC1pDa
FVCLcproi/Cgch0B/NOCFkVLYmp6BEljRj2dSZRWbO0vgl9kFmJEeiiH41+k
EAI6PASSKZs3BYLFc2I8mBkcvt90kg4MTBjreuk0uWf1hdH2Rv8zprH4h5Uh
gjx5nUDX8WXyeLxTU5EBKry+A2DIe0Gm0/waxp6lBlUl+7ra28KYEoHm8Nq/
N9FCuEhFkFgw6EwUp7jsrFcqBKvmni6jyplm+mJXi3CK+IiNcqub4XPnBI97
lR19fupB/Y6M7yEaxIM8fTQXmP+x/fe8zRphdo+7o+pJQ3hk5LrrNPK8GEZ6
DLDOHjZzROhOgBvWtbxRktHk+f5YpuQL+xWd33IV1xYSSHuoAm0Zwt0QJxBs
oFBwJEq1NWM4FxXJBogvzV7KFhl/hXgtvx+GaMv3y8gucj+gE89xVv0XBXjl
5dy5/PgCI0Id+KAFHyKpJA0N0h8O4xdJoNyIBAwDZ8LHt0vlnLGwcJFR9X7/
PfWe0PFtC3d7cYY3RopDhnRP7MZs1Wo9nZ4IvlXoEsE2nPkWcns+Wv5Yaewr
s2ra9ZIK7IIJhqKKgmQtCeiXyFwTq+kfunDnxeCavuWL3HuLKIOZf7P9vXXt
XgEir9rCwF8EGAEIABMFAlRJbdIJED62JZ7fId8kAhsMAAD+LAf+KT1EpkwH
0ivTHmYako+6qG6DCtzd3TibWw51cmbY20Ph13NIS/MfBo828S9SXm/sVUzN
/r7qZgZYfI0/j57tG3BguVGm53qya4bINKyi1RjK6aKo/rrzRkh5ZVD5rVNO
E2zzvyYAnLUWG9AV1OYDxcgLrXqEMWlqZAo+Wmg7VrTBmdCGs/BPvscNgQRr
6Gpjgmv9ru6LjRL7vFhEcov/tkBLj+CtaWWFTd1s2vBLOs4rCsD9TT/23vfw
CnokvvVjKYN5oviy61yhpqF1rWlOsxZ4+2sKW3Pq7JLBtmzsZegTONfcQAf7
qqGRQm3MxoTdgQUShAwbNwNNQR9cInfMnA==
=2wIY
-----END PGP PRIVATE KEY BLOCK-----
`
)
var testPrivateKeyRing *crypto.KeyRing
func init() {
privKey, err := crypto.NewKeyFromArmored(testPrivateKey)
if err != nil {
panic(err)
}
privKeyUnlocked, err := privKey.Unlock([]byte(testPrivateKeyPassword))
if err != nil {
panic(err)
}
if testPrivateKeyRing, err = crypto.NewKeyRing(privKeyUnlocked); err != nil {
panic(err)
}
}
type mocksForStore struct {
tb testing.TB
@ -48,7 +134,6 @@ type mocksForStore struct {
events *storemocks.MockListener
user *storemocks.MockBridgeUser
client *pmapimocks.MockClient
clientManager *storemocks.MockClientManager
panicHandler *storemocks.MockPanicHandler
changeNotifier *storemocks.MockChangeNotifier
store *Store
@ -65,7 +150,6 @@ func initMocks(tb testing.TB) (*mocksForStore, func()) {
events: storemocks.NewMockListener(ctrl),
user: storemocks.NewMockBridgeUser(ctrl),
client: pmapimocks.NewMockClient(ctrl),
clientManager: storemocks.NewMockClientManager(ctrl),
panicHandler: storemocks.NewMockPanicHandler(ctrl),
changeNotifier: storemocks.NewMockChangeNotifier(ctrl),
}
@ -97,30 +181,30 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
mocks.user.EXPECT().IsConnected().Return(true)
mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode)
mocks.clientManager.EXPECT().GetClient("userID").AnyTimes().Return(mocks.client)
mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client)
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive},
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive},
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: true},
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: true},
})
mocks.client.EXPECT().ListLabels().AnyTimes()
mocks.client.EXPECT().CountMessages("")
mocks.client.EXPECT().ListLabels(gomock.Any()).AnyTimes()
mocks.client.EXPECT().CountMessages(gomock.Any(), "")
// Call to get latest event ID and then to process first event.
eventAfterSyncRequested := make(chan struct{})
mocks.client.EXPECT().GetEvent("").Return(&pmapi.Event{
mocks.client.EXPECT().GetEvent(gomock.Any(), "").Return(&pmapi.Event{
EventID: "firstEventID",
}, nil)
mocks.client.EXPECT().GetEvent("firstEventID").DoAndReturn(func(_ string) (*pmapi.Event, error) {
mocks.client.EXPECT().GetEvent(gomock.Any(), "firstEventID").DoAndReturn(func(_ context.Context, _ string) (*pmapi.Event, error) {
close(eventAfterSyncRequested)
return &pmapi.Event{
EventID: "latestEventID",
}, nil
})
mocks.client.EXPECT().ListMessages(gomock.Any()).Return(msgs, len(msgs), nil).AnyTimes()
mocks.client.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return(msgs, len(msgs), nil).AnyTimes()
for _, msg := range msgs {
mocks.client.EXPECT().GetMessage(msg.ID).Return(msg, nil).AnyTimes()
mocks.client.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil).AnyTimes()
}
var err error
@ -128,7 +212,6 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
nil, // Sentry reporter is not used under unit tests.
mocks.panicHandler,
mocks.user,
mocks.clientManager,
mocks.events,
filepath.Join(mocks.tmpDir, "mailbox-test.db"),
mocks.cache,

View File

@ -18,6 +18,7 @@
package store
import (
"context"
"math"
"sync"
@ -39,10 +40,10 @@ type storeSynchronizer interface {
}
type messageLister interface {
ListMessages(*pmapi.MessagesFilter) ([]*pmapi.Message, int, error)
ListMessages(context.Context, *pmapi.MessagesFilter) ([]*pmapi.Message, int, error)
}
func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api func() messageLister, syncState *syncState) error {
func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api messageLister, syncState *syncState) error {
labelID := pmapi.AllMailLabel
// When the full sync starts (i.e. is not already in progress), we need to load
@ -53,7 +54,7 @@ func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api func()
return errors.Wrap(err, "failed to load message IDs")
}
if err := findIDRanges(labelID, api(), syncState); err != nil {
if err := findIDRanges(labelID, api, syncState); err != nil {
return errors.Wrap(err, "failed to load IDs ranges")
}
syncState.save()
@ -71,7 +72,7 @@ func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api func()
defer panicHandler.HandlePanic()
defer wg.Done()
err := syncBatch(labelID, store, api(), syncState, idRange, &shouldStop)
err := syncBatch(labelID, store, api, syncState, idRange, &shouldStop)
if err != nil {
shouldStop = 1
resultError = errors.Wrap(err, "failed to sync group")
@ -147,7 +148,7 @@ func getSplitIDAndCount(labelID string, api messageLister, page int) (string, in
Limit: 1,
}
// If the page does not exist, an empty page instead of an error is returned.
messages, total, err := api.ListMessages(filter)
messages, total, err := api.ListMessages(context.Background(), filter)
if err != nil {
return "", 0, errors.Wrap(err, "failed to list messages")
}
@ -189,7 +190,7 @@ func syncBatch( //nolint[funlen]
log.WithField("begin", filter.BeginID).WithField("end", filter.EndID).Debug("Fetching page")
messages, _, err := api.ListMessages(filter)
messages, _, err := api.ListMessages(context.Background(), filter)
if err != nil {
return errors.Wrap(err, "failed to list messages")
}

View File

@ -18,6 +18,7 @@
package store
import (
"context"
"sort"
"strconv"
"sync"
@ -34,7 +35,7 @@ type mockLister struct {
messageIDs []string
}
func (m *mockLister) ListMessages(filter *pmapi.MessagesFilter) (msgs []*pmapi.Message, total int, err error) {
func (m *mockLister) ListMessages(_ context.Context, filter *pmapi.MessagesFilter) (msgs []*pmapi.Message, total int, err error) {
if m.err != nil {
return nil, 0, m.err
}
@ -197,7 +198,7 @@ func TestSyncAllMail(t *testing.T) { //nolint[funlen]
syncState := newSyncState(store, 0, tc.idRanges, tc.idsToBeDeleted)
err := syncAllMail(m.panicHandler, store, func() messageLister { return api }, syncState)
err := syncAllMail(m.panicHandler, store, api, syncState)
require.Nil(t, err)
// Check all messages were created or updated.
@ -245,7 +246,7 @@ func TestSyncAllMail_FailedListing(t *testing.T) {
}
syncState := newTestSyncState(store)
err := syncAllMail(m.panicHandler, store, func() messageLister { return api }, syncState)
err := syncAllMail(m.panicHandler, store, api, syncState)
require.EqualError(t, err, "failed to sync group: failed to list messages: error")
}
@ -264,7 +265,7 @@ func TestSyncAllMail_FailedCreateOrUpdateMessage(t *testing.T) {
}
syncState := newTestSyncState(store)
err := syncAllMail(m.panicHandler, store, func() messageLister { return api }, syncState)
err := syncAllMail(m.panicHandler, store, api, syncState)
require.EqualError(t, err, "failed to sync group: failed to create or update messages: error")
}

View File

@ -17,16 +17,16 @@
package store
import "github.com/ProtonMail/proton-bridge/pkg/pmapi"
import (
"context"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
type PanicHandler interface {
HandlePanic()
}
type ClientManager interface {
GetClient(userID string) pmapi.Client
}
// BridgeUser is subset of bridge.User for use by the Store.
type BridgeUser interface {
ID() string
@ -35,7 +35,8 @@ type BridgeUser interface {
IsCombinedAddressMode() bool
GetPrimaryAddress() string
GetStoreAddresses() []string
UpdateUser() error
GetClient() pmapi.Client
UpdateUser(context.Context) error
CloseAllConnections()
CloseConnection(string)
Logout() error

View File

@ -24,7 +24,7 @@ func (store *Store) UserID() string {
// GetSpace returns used and total space in bytes.
func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
apiUser, err := store.client().CurrentUser()
apiUser, err := store.client().CurrentUser(exposeContextForIMAP())
if err != nil {
return 0, 0, err
}
@ -33,7 +33,7 @@ func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
// GetMaxUpload returns max size of message + all attachments in bytes.
func (store *Store) GetMaxUpload() (int64, error) {
apiUser, err := store.client().CurrentUser()
apiUser, err := store.client().CurrentUser(exposeContextForIMAP())
if err != nil {
return 0, err
}

View File

@ -147,7 +147,7 @@ func (store *Store) createOrUpdateAddressInfo(addressList pmapi.AddressList) (er
// filterAddresses filters out inactive addresses and ensures the original address is listed first.
func filterAddresses(addressList pmapi.AddressList) (filteredList pmapi.AddressList) {
for _, address := range addressList {
if address.Receive != pmapi.CanReceive {
if !address.Receive {
continue
}

View File

@ -38,14 +38,14 @@ func (store *Store) createMailbox(name string) error {
color := store.leastUsedColor()
var exclusive int
var exclusive bool
switch {
case strings.HasPrefix(name, UserLabelsPrefix):
name = strings.TrimPrefix(name, UserLabelsPrefix)
exclusive = 0
exclusive = false
case strings.HasPrefix(name, UserFoldersPrefix):
name = strings.TrimPrefix(name, UserFoldersPrefix)
exclusive = 1
exclusive = true
default:
// Ideally we would throw an error here, but then Outlook for
// macOS keeps trying to make an IMAP Drafts folder and popping
@ -55,10 +55,10 @@ func (store *Store) createMailbox(name string) error {
return nil
}
_, err := store.client().CreateLabel(&pmapi.Label{
_, err := store.client().CreateLabel(exposeContextForIMAP(), &pmapi.Label{
Name: name,
Color: color,
Exclusive: exclusive,
Exclusive: pmapi.Boolean(exclusive),
Type: pmapi.LabelTypeMailbox,
})
return err
@ -125,7 +125,7 @@ func (store *Store) leastUsedColor() string {
func (store *Store) updateMailbox(labelID, newName, color string) error {
defer store.eventLoop.pollNow()
_, err := store.client().UpdateLabel(&pmapi.Label{
_, err := store.client().UpdateLabel(exposeContextForIMAP(), &pmapi.Label{
ID: labelID,
Name: newName,
Color: color,
@ -142,15 +142,15 @@ func (store *Store) deleteMailbox(labelID, addressID string) error {
var err error
switch labelID {
case pmapi.SpamLabel:
err = store.client().EmptyFolder(pmapi.SpamLabel, addressID)
err = store.client().EmptyFolder(exposeContextForIMAP(), pmapi.SpamLabel, addressID)
case pmapi.TrashLabel:
err = store.client().EmptyFolder(pmapi.TrashLabel, addressID)
err = store.client().EmptyFolder(exposeContextForIMAP(), pmapi.TrashLabel, addressID)
default:
err = fmt.Errorf("cannot empty mailbox %v", labelID)
}
return err
}
return store.client().DeleteLabel(labelID)
return store.client().DeleteLabel(exposeContextForIMAP(), labelID)
}
func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error {
@ -165,7 +165,7 @@ func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) erro
return nil
}
labels, err := store.client().ListLabels()
labels, err := store.client().ListLabels(exposeContextForIMAP())
if err != nil {
return err
}

View File

@ -44,45 +44,132 @@ func (store *Store) CreateDraft(
attachedPublicKey,
attachedPublicKeyName string,
parentID string) (*pmapi.Message, []*pmapi.Attachment, error) {
defer store.eventLoop.pollNow()
attachments := store.prepareDraftAttachments(message, attachmentReaders, attachedPublicKey, attachedPublicKeyName)
// Since this is a draft, we don't need to sign it.
if err := message.Encrypt(kr, nil); err != nil {
if err := encryptDraft(kr, message, attachments); err != nil {
return nil, nil, errors.Wrap(err, "failed to encrypt draft")
}
attachments := message.Attachments
message.Attachments = nil
if ok, err := store.checkDraftTotalSize(message, attachments); err != nil {
return nil, nil, err
} else if !ok {
return nil, nil, errors.New("message is too large")
}
draftAction := store.getDraftAction(message)
draft, err := store.client().CreateDraft(message, parentID, draftAction)
draft, err := store.client().CreateDraft(exposeContextForSMTP(), message, parentID, draftAction)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to create draft")
}
// Do poll only when call to API succeeded.
defer store.eventLoop.pollNow()
createdAttachments := []*pmapi.Attachment{}
for _, att := range attachments {
att.attachment.MessageID = draft.ID
createdAttachment, err := store.client().CreateAttachment(exposeContextForSMTP(), att.attachment, att.encReader, att.sigReader)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to create attachment")
}
createdAttachments = append(createdAttachments, createdAttachment)
}
return draft, createdAttachments, nil
}
type draftAttachment struct {
attachment *pmapi.Attachment
reader io.Reader
sigReader io.Reader
encReader io.Reader
}
func (store *Store) prepareDraftAttachments(
message *pmapi.Message,
attachmentReaders []io.Reader,
attachedPublicKey,
attachedPublicKeyName string) []*draftAttachment {
attachments := []*draftAttachment{}
for idx, attachment := range message.Attachments {
attachments = append(attachments, &draftAttachment{
attachment: attachment,
reader: attachmentReaders[idx],
})
}
message.Attachments = nil
if attachedPublicKey != "" {
attachmentReaders = append(attachmentReaders, strings.NewReader(attachedPublicKey))
publicKeyAttachment := &pmapi.Attachment{
Name: attachedPublicKeyName + ".asc",
MIMEType: "application/pgp-keys",
Header: textproto.MIMEHeader{},
}
attachments = append(attachments, publicKeyAttachment)
attachments = append(attachments, &draftAttachment{
attachment: publicKeyAttachment,
reader: strings.NewReader(attachedPublicKey),
})
}
for idx, attachment := range attachments {
attachment.MessageID = draft.ID
attachmentBody, _ := ioutil.ReadAll(attachmentReaders[idx])
return attachments
}
createdAttachment, err := store.createAttachment(kr, attachment, attachmentBody)
func encryptDraft(kr *crypto.KeyRing, message *pmapi.Message, attachments []*draftAttachment) error {
// Since this is a draft, we don't need to sign it.
if err := message.Encrypt(kr, nil); err != nil {
return errors.Wrap(err, "failed to encrypt message")
}
for _, att := range attachments {
attachment := att.attachment
attachmentBody, err := ioutil.ReadAll(att.reader)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to create attachment for draft")
return errors.Wrap(err, "failed to read attachment")
}
attachments[idx] = createdAttachment
r := bytes.NewReader(attachmentBody)
sigReader, err := attachment.DetachedSign(kr, r)
if err != nil {
return errors.Wrap(err, "failed to sign attachment")
}
att.sigReader = sigReader
r = bytes.NewReader(attachmentBody)
encReader, err := attachment.Encrypt(kr, r)
if err != nil {
return errors.Wrap(err, "failed to encrypt attachment")
}
att.encReader = encReader
att.reader = nil
}
return nil
}
func (store *Store) checkDraftTotalSize(message *pmapi.Message, attachments []*draftAttachment) (bool, error) {
maxUpload, err := store.GetMaxUpload()
if err != nil {
return false, err
}
return draft, attachments, nil
msgSize := message.Size
if msgSize == 0 {
msgSize = int64(len(message.Body))
}
var attSize int64
for _, att := range attachments {
b, err := ioutil.ReadAll(att.encReader)
if err != nil {
return false, err
}
attSize += int64(len(b))
att.encReader = bytes.NewBuffer(b)
}
return msgSize+attSize <= maxUpload, nil
}
func (store *Store) getDraftAction(message *pmapi.Message) int {
@ -93,31 +180,10 @@ func (store *Store) getDraftAction(message *pmapi.Message) int {
return pmapi.DraftActionReply
}
func (store *Store) createAttachment(kr *crypto.KeyRing, attachment *pmapi.Attachment, attachmentBody []byte) (*pmapi.Attachment, error) {
r := bytes.NewReader(attachmentBody)
sigReader, err := attachment.DetachedSign(kr, r)
if err != nil {
return nil, errors.Wrap(err, "failed to sign attachment")
}
r = bytes.NewReader(attachmentBody)
encReader, err := attachment.Encrypt(kr, r)
if err != nil {
return nil, errors.Wrap(err, "failed to encrypt attachment")
}
createdAttachment, err := store.client().CreateAttachment(attachment, encReader, sigReader)
if err != nil {
return nil, errors.Wrap(err, "failed to create attachment")
}
return createdAttachment, nil
}
// SendMessage sends the message.
func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error {
defer store.eventLoop.pollNow()
_, _, err := store.client().SendMessage(messageID, req)
_, _, err := store.client().SendMessage(exposeContextForSMTP(), messageID, req)
return err
}

View File

@ -18,10 +18,13 @@
package store
import (
"io"
"net/mail"
"strings"
"testing"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/golang/mock/gomock"
a "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -32,10 +35,10 @@ func TestGetAllMessageIDs(t *testing.T) {
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
insertMessage(t, m, "msg3", "Test message 3", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg4", "Test message 4", addrID1, false, []string{})
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
}
@ -45,7 +48,7 @@ func TestGetMessageFromDB(t *testing.T) {
defer clear()
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
tests := []struct{ msgID, wantErr string }{
{"msg1", ""},
@ -70,7 +73,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
defer clear()
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
msg, err := m.store.getMessageFromDB("msg1")
require.Nil(t, err)
@ -102,7 +105,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
a.Equal(t, wantHeader, msg.Header)
// Check calculated data are not overridden by reinsert.
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
msg, err = m.store.getMessageFromDB("msg1")
require.Nil(t, err)
@ -116,8 +119,8 @@ func TestDeleteMessage(t *testing.T) {
defer clear()
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
require.Nil(t, m.store.deleteMessageEvent("msg1"))
@ -125,17 +128,17 @@ func TestDeleteMessage(t *testing.T) {
checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg2", 2}})
}
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread int, labelIDs []string) { //nolint[unparam]
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread bool, labelIDs []string) { //nolint[unparam]
msg := getTestMessage(id, subject, sender, unread, labelIDs)
require.Nil(t, m.store.createOrUpdateMessageEvent(msg))
}
func getTestMessage(id, subject, sender string, unread int, labelIDs []string) *pmapi.Message {
func getTestMessage(id, subject, sender string, unread bool, labelIDs []string) *pmapi.Message {
address := &mail.Address{Address: sender}
return &pmapi.Message{
ID: id,
Subject: subject,
Unread: unread,
Unread: pmapi.Boolean(unread),
Sender: address,
ToList: []*mail.Address{address},
LabelIDs: labelIDs,
@ -154,3 +157,47 @@ func checkAllMessageIDs(t *testing.T, m *mocksForStore, wantIDs []string) {
require.Nil(t, allErr)
require.Equal(t, wantIDs, allIds)
}
func TestCreateDraftCheckMessageSize(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(false)
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
MaxUpload: 100, // Decrypted message 5 chars, encrypted 500+.
}, nil)
// Even small body is bloated to at least about 500 chars of basic pgp message.
message := &pmapi.Message{
Body: strings.Repeat("a", 5),
}
attachmentReaders := []io.Reader{}
_, _, err := m.store.CreateDraft(testPrivateKeyRing, message, attachmentReaders, "", "", "")
require.EqualError(t, err, "message is too large")
}
func TestCreateDraftCheckMessageWithAttachmentSize(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(false)
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
MaxUpload: 800, // Decrypted message 5 chars + 5 chars of attachment, encrypted 500+ + 300+.
}, nil)
// Even small body is bloated to at least about 500 chars of basic pgp message.
message := &pmapi.Message{
Body: strings.Repeat("a", 5),
Attachments: []*pmapi.Attachment{
{Name: "name"},
},
}
// Even small attachment is bloated to about 300 chars of encrypted text.
attachmentReaders := []io.Reader{
strings.NewReader(strings.Repeat("b", 5)),
}
_, _, err := m.store.CreateDraft(testPrivateKeyRing, message, attachmentReaders, "", "", "")
require.EqualError(t, err, "message is too large")
}

View File

@ -18,6 +18,7 @@
package store
import (
"context"
"encoding/json"
"fmt"
"strconv"
@ -34,7 +35,7 @@ const syncIDsToBeDeletedKey = "ids_to_be_deleted"
// updateCountsFromServer will download and set the counts.
func (store *Store) updateCountsFromServer() error {
counts, err := store.client().CountMessages("")
counts, err := store.client().CountMessages(context.Background(), "")
if err != nil {
return errors.Wrap(err, "cannot update counts from server")
}
@ -152,7 +153,7 @@ func (store *Store) triggerSync() {
store.log.WithField("isIncomplete", syncState.isIncomplete()).Info("Store sync started")
err := syncAllMail(store.panicHandler, store, func() messageLister { return store.client() }, syncState)
err := syncAllMail(store.panicHandler, store, store.client(), syncState)
if err != nil {
log.WithError(err).Error("Store sync failed")
store.syncCooldown.increaseWaitTime()

View File

@ -31,8 +31,8 @@ func TestLoadSaveSyncState(t *testing.T) {
defer clear()
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
// Clear everything.

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/internal/transfer (interfaces: PanicHandler,ClientManager,IMAPClientProvider)
// Source: github.com/ProtonMail/proton-bridge/internal/transfer (interfaces: PanicHandler,IMAPClientProvider)
// Package mocks is a generated GoMock package.
package mocks
@ -7,7 +7,6 @@ package mocks
import (
reflect "reflect"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
imap "github.com/emersion/go-imap"
sasl "github.com/emersion/go-sasl"
gomock "github.com/golang/mock/gomock"
@ -48,57 +47,6 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
}
// MockClientManager is a mock of ClientManager interface
type MockClientManager struct {
ctrl *gomock.Controller
recorder *MockClientManagerMockRecorder
}
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
type MockClientManagerMockRecorder struct {
mock *MockClientManager
}
// NewMockClientManager creates a new mock instance
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
mock := &MockClientManager{ctrl: ctrl}
mock.recorder = &MockClientManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
return m.recorder
}
// CheckConnection mocks base method
func (m *MockClientManager) CheckConnection() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CheckConnection")
ret0, _ := ret[0].(error)
return ret0
}
// CheckConnection indicates an expected call of CheckConnection
func (mr *MockClientManagerMockRecorder) CheckConnection() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckConnection", reflect.TypeOf((*MockClientManager)(nil).CheckConnection))
}
// GetClient mocks base method
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClient", arg0)
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetClient indicates an expected call of GetClient
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
}
// MockIMAPClientProvider is a mock of IMAPClientProvider interface
type MockIMAPClientProvider struct {
ctrl *gomock.Controller

View File

@ -19,13 +19,14 @@ package transfer
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"strings"
"time"
imapID "github.com/ProtonMail/go-imap-id"
"github.com/ProtonMail/proton-bridge/internal/constants"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/emersion/go-imap"
imapClient "github.com/emersion/go-imap/client"
"github.com/emersion/go-sasl"
@ -38,6 +39,8 @@ const (
imapRetries = 10
imapReconnectTimeout = 30 * time.Minute
imapReconnectSleep = time.Minute
protonStatusURL = "http://protonstatus.com/vpn_status"
)
type imapErrorLogger struct {
@ -118,7 +121,7 @@ func (p *IMAPProvider) tryReconnect(ensureSelectedIn string) error {
return previousErr
}
err := pmapi.CheckConnection()
err := checkConnection()
log.WithError(err).Debug("Connection check")
if err != nil {
time.Sleep(imapReconnectSleep)
@ -286,3 +289,23 @@ func (p *IMAPProvider) fetchHelper(uid bool, ensureSelectedIn string, seqSet *im
return err
}, ensureSelectedIn)
}
// checkConnection returns an error if there is no internet connection.
// Note we don't want to use client manager because it only reports connection
// issues with API; we are only interested here whether we can reach
// third-party IMAP servers.
func checkConnection() error {
client := &http.Client{Timeout: time.Second * 10}
resp, err := client.Get(protonStatusURL)
if err != nil {
return err
}
_ = resp.Body.Close()
if resp.StatusCode != 200 {
return fmt.Errorf("HTTP status code %d", resp.StatusCode)
}
return nil
}

View File

@ -18,6 +18,7 @@
package transfer
import (
"context"
"sort"
"github.com/ProtonMail/gopenpgp/v2/crypto"
@ -34,25 +35,27 @@ const (
// PMAPIProvider implements import and export to/from ProtonMail server.
type PMAPIProvider struct {
clientManager ClientManager
userID string
addressID string
keyRing *crypto.KeyRing
builder *message.Builder
client pmapi.Client
userID string
addressID string
keyRing *crypto.KeyRing
builder *message.Builder
nextImportRequests map[string]*pmapi.ImportMsgReq // Key is msg transfer ID.
nextImportRequestsSize int
timeIt *timeIt
connection bool
}
// NewPMAPIProvider returns new PMAPIProvider.
func NewPMAPIProvider(clientManager ClientManager, userID, addressID string) (*PMAPIProvider, error) {
func NewPMAPIProvider(client pmapi.Client, userID, addressID string) (*PMAPIProvider, error) {
provider := &PMAPIProvider{
clientManager: clientManager,
userID: userID,
addressID: addressID,
builder: message.NewBuilder(fetchWorkers, attachWorkers, buildWorkers),
client: client,
userID: userID,
addressID: addressID,
builder: message.NewBuilder(fetchWorkers, attachWorkers, buildWorkers),
nextImportRequests: map[string]*pmapi.ImportMsgReq{},
nextImportRequestsSize: 0,
@ -61,7 +64,7 @@ func NewPMAPIProvider(clientManager ClientManager, userID, addressID string) (*P
}
if addressID != "" {
keyRing, err := clientManager.GetClient(userID).KeyRingForAddressID(addressID)
keyRing, err := client.KeyRingForAddressID(addressID)
if err != nil {
return nil, errors.Wrap(err, "failed to get key ring")
}
@ -71,10 +74,6 @@ func NewPMAPIProvider(clientManager ClientManager, userID, addressID string) (*P
return provider, nil
}
func (p *PMAPIProvider) client() pmapi.Client {
return p.clientManager.GetClient(p.userID)
}
// ID returns identifier of current setup of PMAPI provider.
// Identification is unique per user.
func (p *PMAPIProvider) ID() string {
@ -83,7 +82,7 @@ func (p *PMAPIProvider) ID() string {
// Mailboxes returns all available labels in ProtonMail account.
func (p *PMAPIProvider) Mailboxes(includeEmpty, includeAllMail bool) ([]Mailbox, error) {
labels, err := p.client().ListLabels()
labels, err := p.client.ListLabels(context.Background())
if err != nil {
return nil, err
}
@ -92,7 +91,7 @@ func (p *PMAPIProvider) Mailboxes(includeEmpty, includeAllMail bool) ([]Mailbox,
emptyLabelsMap := map[string]bool{}
if !includeEmpty {
messagesCounts, err := p.client().CountMessages(p.addressID)
messagesCounts, err := p.client.CountMessages(context.Background(), p.addressID)
if err != nil {
return nil, err
}
@ -120,7 +119,7 @@ func (p *PMAPIProvider) Mailboxes(includeEmpty, includeAllMail bool) ([]Mailbox,
ID: label.ID,
Name: label.Name,
Color: label.Color,
IsExclusive: label.Exclusive == 1,
IsExclusive: bool(label.Exclusive),
})
}
return mailboxes, nil
@ -160,10 +159,10 @@ func (l byFoldersLabels) Swap(i, j int) {
// Less sorts first folders, then labels, by user order.
func (l byFoldersLabels) Less(i, j int) bool {
if l[i].Exclusive == 1 && l[j].Exclusive == 0 {
if l[i].Exclusive && !l[j].Exclusive {
return true
}
if l[i].Exclusive == 0 && l[j].Exclusive == 1 {
if !l[i].Exclusive && l[j].Exclusive {
return false
}
return l[i].Order < l[j].Order

View File

@ -157,7 +157,7 @@ func (p *PMAPIProvider) exportMessage(rule *Rule, progress *Progress, pmapiMsgID
body, err := p.builder.NewJobWithOptions(
context.Background(),
p.client(),
p.client,
msg.ID,
message.JobOptions{IgnoreDecryptionErrors: !skipEncryptedMessages},
).GetResult()
@ -169,14 +169,9 @@ func (p *PMAPIProvider) exportMessage(rule *Rule, progress *Progress, pmapiMsgID
return Message{Body: []byte(msg.Body)}, err
}
unread := false
if msg.Unread == 1 {
unread = true
}
return Message{
ID: msgID,
Unread: unread,
Unread: bool(msg.Unread),
Body: body,
Sources: []Mailbox{rule.SourceMailbox},
Targets: rule.TargetMailboxes,

View File

@ -19,6 +19,7 @@ package transfer
import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
@ -51,15 +52,10 @@ func (p *PMAPIProvider) CreateMailbox(mailbox Mailbox) (Mailbox, error) {
return Mailbox{}, errors.New("mailbox is already created")
}
exclusive := 0
if mailbox.IsExclusive {
exclusive = 1
}
label, err := p.client().CreateLabel(&pmapi.Label{
label, err := p.client.CreateLabel(context.Background(), &pmapi.Label{
Name: mailbox.Name,
Color: mailbox.Color,
Exclusive: exclusive,
Exclusive: pmapi.Boolean(mailbox.IsExclusive),
Type: pmapi.LabelTypeMailbox,
})
if err != nil {
@ -118,12 +114,20 @@ func (p *PMAPIProvider) transferDraft(rules transferRules, progress *Progress, m
progress.messageImported(msg.ID, importedID, err)
}
func (p *PMAPIProvider) importDraft(msg Message, globalMailbox *Mailbox) (string, error) {
func (p *PMAPIProvider) importDraft(msg Message, globalMailbox *Mailbox) (string, error) { //nolint[funlen]
message, attachmentReaders, err := p.parseMessage(msg)
if err != nil {
return "", errors.Wrap(err, "failed to parse message")
}
if message.Sender == nil {
mainAddress := p.client.Addresses().Main()
message.Sender = &mail.Address{
Name: mainAddress.DisplayName,
Address: mainAddress.Email,
}
}
// Trying to encrypt an encrypted draft will return an error;
// users are forbidden to import messages encrypted with foreign keys to drafts.
if message.IsEncrypted() {
@ -186,7 +190,7 @@ func (p *PMAPIProvider) transferMessage(rules transferRules, progress *Progress,
return
}
importMsgReqSize := len(importMsgReq.Body)
importMsgReqSize := len(importMsgReq.Message)
if p.nextImportRequestsSize+importMsgReqSize > pmapiImportBatchMaxSize || len(p.nextImportRequests) == pmapiImportBatchMaxItems {
preparedImportRequestsCh <- p.nextImportRequests
p.nextImportRequests = map[string]*pmapi.ImportMsgReq{}
@ -218,11 +222,6 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
}
}
unread := 0
if msg.Unread {
unread = 1
}
labelIDs := []string{}
for _, target := range msg.Targets {
// Frontend should not set All Mail to Rules, but to be sure...
@ -235,12 +234,14 @@ func (p *PMAPIProvider) generateImportMsgReq(rules transferRules, progress *Prog
}
return &pmapi.ImportMsgReq{
AddressID: p.addressID,
Body: body,
Unread: unread,
Time: message.Time,
Flags: computeMessageFlags(message.Header),
LabelIDs: labelIDs,
Metadata: &pmapi.ImportMetadata{
AddressID: p.addressID,
Unread: pmapi.Boolean(msg.Unread),
Time: message.Time,
Flags: computeMessageFlags(message.Header),
LabelIDs: labelIDs,
},
Message: body,
}, nil
}
@ -285,7 +286,7 @@ func (p *PMAPIProvider) importMessages(progress *Progress, importRequests map[st
}
importMsgIDs := []string{}
importMsgRequests := []*pmapi.ImportMsgReq{}
importMsgRequests := pmapi.ImportMsgReqs{}
for msgID, req := range importRequests {
importMsgIDs = append(importMsgIDs, msgID)
importMsgRequests = append(importMsgRequests, req)
@ -319,7 +320,7 @@ func (p *PMAPIProvider) importMessages(progress *Progress, importRequests map[st
func (p *PMAPIProvider) importMessage(msgSourceID string, progress *Progress, req *pmapi.ImportMsgReq) (importedID string, importedErr error) {
progress.callWrap(func() error {
results, err := p.importRequest(msgSourceID, []*pmapi.ImportMsgReq{req})
results, err := p.importRequest(msgSourceID, pmapi.ImportMsgReqs{req})
if err != nil {
return errors.Wrap(err, "failed to import messages")
}

View File

@ -19,6 +19,7 @@ package transfer
import (
"bytes"
"context"
"fmt"
"testing"
"time"
@ -33,7 +34,7 @@ func TestPMAPIProviderMailboxes(t *testing.T) {
defer m.ctrl.Finish()
setupPMAPIClientExpectationForExport(&m)
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
r.NoError(t, err)
tests := []struct {
@ -78,7 +79,7 @@ func TestPMAPIProviderTransferTo(t *testing.T) {
defer m.ctrl.Finish()
setupPMAPIClientExpectationForExport(&m)
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
r.NoError(t, err)
rules, rulesClose := newTestRules(t)
@ -96,7 +97,7 @@ func TestPMAPIProviderTransferFrom(t *testing.T) {
defer m.ctrl.Finish()
setupPMAPIClientExpectationForImport(&m)
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
r.NoError(t, err)
rules, rulesClose := newTestRules(t)
@ -114,7 +115,7 @@ func TestPMAPIProviderTransferFromDraft(t *testing.T) {
defer m.ctrl.Finish()
setupPMAPIClientExpectationForImportDraft(&m)
provider, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
provider, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
r.NoError(t, err)
rules, rulesClose := newTestRules(t)
@ -133,9 +134,9 @@ func TestPMAPIProviderTransferFromTo(t *testing.T) {
setupPMAPIClientExpectationForExport(&m)
setupPMAPIClientExpectationForImport(&m)
source, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
source, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
r.NoError(t, err)
target, err := NewPMAPIProvider(m.clientManager, "user", "addressID")
target, err := NewPMAPIProvider(m.pmapiClient, "user", "addressID")
r.NoError(t, err)
rules, rulesClose := newTestRules(t)
@ -151,22 +152,22 @@ func setupPMAPIRules(rules transferRules) {
func setupPMAPIClientExpectationForExport(m *mocks) {
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{
{ID: "label1", Name: "Foo", Color: "blue", Exclusive: 0, Order: 2},
{ID: "label2", Name: "Bar", Color: "green", Exclusive: 0, Order: 1},
{ID: "folder1", Name: "One", Color: "red", Exclusive: 1, Order: 1},
{ID: "folder2", Name: "Two", Color: "orange", Exclusive: 1, Order: 2},
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{
{ID: "label1", Name: "Foo", Color: "blue", Exclusive: false, Order: 2},
{ID: "label2", Name: "Bar", Color: "green", Exclusive: false, Order: 1},
{ID: "folder1", Name: "One", Color: "red", Exclusive: true, Order: 1},
{ID: "folder2", Name: "Two", Color: "orange", Exclusive: true, Order: 2},
}, nil).AnyTimes()
m.pmapiClient.EXPECT().CountMessages(gomock.Any()).Return([]*pmapi.MessagesCount{
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.MessagesCount{
{LabelID: "label1", Total: 10},
{LabelID: "label2", Total: 0},
{LabelID: "folder1", Total: 20},
}, nil).AnyTimes()
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{
{ID: "msg1"},
{ID: "msg2"},
}, 2, nil).AnyTimes()
m.pmapiClient.EXPECT().GetMessage(gomock.Any()).DoAndReturn(func(msgID string) (*pmapi.Message, error) {
m.pmapiClient.EXPECT().GetMessage(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, msgID string) (*pmapi.Message, error) {
return &pmapi.Message{
ID: msgID,
Body: string(getTestMsgBody(msgID)),
@ -177,11 +178,11 @@ func setupPMAPIClientExpectationForExport(m *mocks) {
func setupPMAPIClientExpectationForImport(m *mocks) {
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
m.pmapiClient.EXPECT().Import(gomock.Any()).DoAndReturn(func(requests []*pmapi.ImportMsgReq) ([]*pmapi.ImportMsgRes, error) {
m.pmapiClient.EXPECT().Import(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, requests pmapi.ImportMsgReqs) ([]*pmapi.ImportMsgRes, error) {
results := []*pmapi.ImportMsgRes{}
for _, request := range requests {
for _, msgID := range []string{"msg1", "msg2"} {
if bytes.Contains(request.Body, []byte(msgID)) {
if bytes.Contains(request.Message, []byte(msgID)) {
results = append(results, &pmapi.ImportMsgRes{MessageID: msgID, Error: nil})
}
}
@ -192,7 +193,7 @@ func setupPMAPIClientExpectationForImport(m *mocks) {
func setupPMAPIClientExpectationForImportDraft(m *mocks) {
m.pmapiClient.EXPECT().KeyRingForAddressID(gomock.Any()).Return(m.keyring, nil).AnyTimes()
m.pmapiClient.EXPECT().CreateDraft(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(msg *pmapi.Message, parentID string, action int) (*pmapi.Message, error) {
m.pmapiClient.EXPECT().CreateDraft(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, msg *pmapi.Message, parentID string, action int) (*pmapi.Message, error) {
r.Equal(m.t, msg.Subject, "draft1")
msg.ID = "draft1"
return msg, nil

View File

@ -18,6 +18,7 @@
package transfer
import (
"context"
"fmt"
"io"
"time"
@ -29,9 +30,17 @@ import (
const (
pmapiRetries = 10
pmapiReconnectTimeout = 30 * time.Minute
pmapiReconnectSleep = time.Minute
pmapiReconnectSleep = 10 * time.Second
)
func (p *PMAPIProvider) SetConnectionUp() {
p.connection = true
}
func (p *PMAPIProvider) SetConnectionDown() {
p.connection = false
}
func (p *PMAPIProvider) ensureConnection(callback func() error) error {
var callErr error
for i := 1; i <= pmapiRetries; i++ {
@ -57,11 +66,8 @@ func (p *PMAPIProvider) tryReconnect() error {
return previousErr
}
err := p.clientManager.CheckConnection()
log.WithError(err).Debug("Connection check")
if err != nil {
if !p.connection {
time.Sleep(pmapiReconnectSleep)
previousErr = err
continue
}
@ -77,7 +83,7 @@ func (p *PMAPIProvider) listMessages(filter *pmapi.MessagesFilter) (messages []*
p.timeIt.start("listing", key)
defer p.timeIt.stop("listing", key)
messages, count, err = p.client().ListMessages(filter)
messages, count, err = p.client.ListMessages(context.Background(), filter)
return err
})
return
@ -88,18 +94,18 @@ func (p *PMAPIProvider) getMessage(msgID string) (message *pmapi.Message, err er
p.timeIt.start("download", msgID)
defer p.timeIt.stop("download", msgID)
message, err = p.client().GetMessage(msgID)
message, err = p.client.GetMessage(context.Background(), msgID)
return err
})
return
}
func (p *PMAPIProvider) importRequest(msgSourceID string, req []*pmapi.ImportMsgReq) (res []*pmapi.ImportMsgRes, err error) {
func (p *PMAPIProvider) importRequest(msgSourceID string, req pmapi.ImportMsgReqs) (res []*pmapi.ImportMsgRes, err error) {
err = p.ensureConnection(func() error {
p.timeIt.start("upload", msgSourceID)
defer p.timeIt.stop("upload", msgSourceID)
res, err = p.client().Import(req)
res, err = p.client.Import(context.Background(), req)
return err
})
return
@ -110,7 +116,7 @@ func (p *PMAPIProvider) createDraft(msgSourceID string, message *pmapi.Message,
p.timeIt.start("upload", msgSourceID)
defer p.timeIt.stop("upload", msgSourceID)
draft, err = p.client().CreateDraft(message, parent, action)
draft, err = p.client.CreateDraft(context.Background(), message, parent, action)
return err
})
return
@ -123,7 +129,7 @@ func (p *PMAPIProvider) createAttachment(msgSourceID string, att *pmapi.Attachme
p.timeIt.start("upload", key)
defer p.timeIt.stop("upload", key)
created, err = p.client().CreateAttachment(att, r, sig)
created, err = p.client.CreateAttachment(context.Background(), att, r, sig)
return err
})
return

View File

@ -23,7 +23,6 @@ import (
"github.com/ProtonMail/gopenpgp/v2/crypto"
transfermocks "github.com/ProtonMail/proton-bridge/internal/transfer/mocks"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
gomock "github.com/golang/mock/gomock"
)
@ -33,10 +32,8 @@ type mocks struct {
ctrl *gomock.Controller
panicHandler *transfermocks.MockPanicHandler
clientManager *transfermocks.MockClientManager
imapClientProvider *transfermocks.MockIMAPClientProvider
pmapiClient *pmapimocks.MockClient
pmapiConfig *pmapi.ClientConfig
keyring *crypto.KeyRing
}
@ -49,15 +46,11 @@ func initMocks(t *testing.T) mocks {
ctrl: mockCtrl,
panicHandler: transfermocks.NewMockPanicHandler(mockCtrl),
clientManager: transfermocks.NewMockClientManager(mockCtrl),
imapClientProvider: transfermocks.NewMockIMAPClientProvider(mockCtrl),
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
pmapiConfig: &pmapi.ClientConfig{},
keyring: newTestKeyring(),
}
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).AnyTimes()
return m
}

View File

@ -17,10 +17,6 @@
package transfer
import (
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
type PanicHandler interface {
HandlePanic()
}
@ -32,8 +28,3 @@ type MetricsManager interface {
Cancel()
Fail()
}
type ClientManager interface {
GetClient(userID string) pmapi.Client
CheckConnection() error
}

View File

@ -18,6 +18,7 @@
package updater
import (
"bytes"
"encoding/json"
"io"
@ -31,10 +32,6 @@ import (
var ErrManualUpdateRequired = errors.New("manual update is required")
type ClientProvider interface {
GetAnonymousClient() pmapi.Client
}
type Installer interface {
InstallUpdate(*semver.Version, io.Reader) error
}
@ -46,7 +43,7 @@ type Settings interface {
}
type Updater struct {
cm ClientProvider
cm pmapi.Manager
installer Installer
settings Settings
kr *crypto.KeyRing
@ -59,7 +56,7 @@ type Updater struct {
}
func New(
cm ClientProvider,
cm pmapi.Manager,
installer Installer,
s Settings,
kr *crypto.KeyRing,
@ -87,13 +84,10 @@ func New(
func (u *Updater) Check() (VersionInfo, error) {
logrus.Info("Checking for updates")
client := u.cm.GetAnonymousClient()
defer client.Logout()
r, err := client.DownloadAndVerify(
b, err := u.cm.DownloadAndVerify(
u.kr,
u.getVersionFileURL(),
u.getVersionFileURL()+".sig",
u.kr,
)
if err != nil {
return VersionInfo{}, err
@ -101,7 +95,7 @@ func (u *Updater) Check() (VersionInfo, error) {
var versionMap VersionMap
if err := json.NewDecoder(r).Decode(&versionMap); err != nil {
if err := json.Unmarshal(b, &versionMap); err != nil {
return VersionInfo{}, err
}
@ -141,15 +135,12 @@ func (u *Updater) InstallUpdate(update VersionInfo) error {
return u.locker.doOnce(func() error {
logrus.WithField("package", update.Package).Info("Installing update package")
client := u.cm.GetAnonymousClient()
defer client.Logout()
r, err := client.DownloadAndVerify(update.Package, update.Package+".sig", u.kr)
b, err := u.cm.DownloadAndVerify(u.kr, update.Package, update.Package+".sig")
if err != nil {
return errors.Wrap(ErrDownloadVerify, err.Error())
}
if err := u.installer.InstallUpdate(update.Version, r); err != nil {
if err := u.installer.InstallUpdate(update.Version, bytes.NewReader(b)); err != nil {
return errors.Wrap(ErrInstall, err.Error())
}

View File

@ -18,7 +18,6 @@
package updater
import (
"bytes"
"encoding/json"
"errors"
"io"
@ -40,9 +39,9 @@ func TestCheck(t *testing.T) {
c := gomock.NewController(t)
defer c.Finish()
client := mocks.NewMockClient(c)
cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.1.0", false)
updater := newTestUpdater(cm, "1.1.0", false)
versionMap := VersionMap{
"stable": VersionInfo{
@ -53,13 +52,11 @@ func TestCheck(t *testing.T) {
},
}
client.EXPECT().DownloadAndVerify(
cm.EXPECT().DownloadAndVerify(
gomock.Any(),
updater.getVersionFileURL(),
updater.getVersionFileURL()+".sig",
gomock.Any(),
).Return(bytes.NewReader(mustMarshal(t, versionMap)), nil)
client.EXPECT().Logout()
).Return(mustMarshal(t, versionMap), nil)
version, err := updater.Check()
@ -71,9 +68,9 @@ func TestCheckEarlyAccess(t *testing.T) {
c := gomock.NewController(t)
defer c.Finish()
client := mocks.NewMockClient(c)
cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.1.0", true)
updater := newTestUpdater(cm, "1.1.0", true)
versionMap := VersionMap{
"stable": VersionInfo{
@ -90,13 +87,11 @@ func TestCheckEarlyAccess(t *testing.T) {
},
}
client.EXPECT().DownloadAndVerify(
cm.EXPECT().DownloadAndVerify(
gomock.Any(),
updater.getVersionFileURL(),
updater.getVersionFileURL()+".sig",
gomock.Any(),
).Return(bytes.NewReader(mustMarshal(t, versionMap)), nil)
client.EXPECT().Logout()
).Return(mustMarshal(t, versionMap), nil)
version, err := updater.Check()
@ -108,18 +103,16 @@ func TestCheckBadSignature(t *testing.T) {
c := gomock.NewController(t)
defer c.Finish()
client := mocks.NewMockClient(c)
cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.2.0", false)
updater := newTestUpdater(cm, "1.2.0", false)
client.EXPECT().DownloadAndVerify(
cm.EXPECT().DownloadAndVerify(
gomock.Any(),
updater.getVersionFileURL(),
updater.getVersionFileURL()+".sig",
gomock.Any(),
).Return(nil, errors.New("bad signature"))
client.EXPECT().Logout()
_, err := updater.Check()
assert.Error(t, err)
@ -129,9 +122,9 @@ func TestIsUpdateApplicable(t *testing.T) {
c := gomock.NewController(t)
defer c.Finish()
client := mocks.NewMockClient(c)
cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.4.0", false)
updater := newTestUpdater(cm, "1.4.0", false)
versionOld := VersionInfo{
Version: semver.MustParse("1.3.0"),
@ -165,9 +158,9 @@ func TestCanInstall(t *testing.T) {
c := gomock.NewController(t)
defer c.Finish()
client := mocks.NewMockClient(c)
cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.4.0", false)
updater := newTestUpdater(cm, "1.4.0", false)
versionManual := VersionInfo{
Version: semver.MustParse("1.5.0"),
@ -192,9 +185,9 @@ func TestInstallUpdate(t *testing.T) {
c := gomock.NewController(t)
defer c.Finish()
client := mocks.NewMockClient(c)
cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.4.0", false)
updater := newTestUpdater(cm, "1.4.0", false)
latestVersion := VersionInfo{
Version: semver.MustParse("1.5.0"),
@ -203,13 +196,11 @@ func TestInstallUpdate(t *testing.T) {
RolloutProportion: 1.0,
}
client.EXPECT().DownloadAndVerify(
cm.EXPECT().DownloadAndVerify(
gomock.Any(),
latestVersion.Package,
latestVersion.Package+".sig",
gomock.Any(),
).Return(bytes.NewReader([]byte("tgz_data_here")), nil)
client.EXPECT().Logout()
).Return([]byte("tgz_data_here"), nil)
err := updater.InstallUpdate(latestVersion)
@ -220,9 +211,9 @@ func TestInstallUpdateBadSignature(t *testing.T) {
c := gomock.NewController(t)
defer c.Finish()
client := mocks.NewMockClient(c)
cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.4.0", false)
updater := newTestUpdater(cm, "1.4.0", false)
latestVersion := VersionInfo{
Version: semver.MustParse("1.5.0"),
@ -231,14 +222,12 @@ func TestInstallUpdateBadSignature(t *testing.T) {
RolloutProportion: 1.0,
}
client.EXPECT().DownloadAndVerify(
cm.EXPECT().DownloadAndVerify(
gomock.Any(),
latestVersion.Package,
latestVersion.Package+".sig",
gomock.Any(),
).Return(nil, errors.New("bad signature"))
client.EXPECT().Logout()
err := updater.InstallUpdate(latestVersion)
assert.Error(t, err)
@ -248,9 +237,9 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
c := gomock.NewController(t)
defer c.Finish()
client := mocks.NewMockClient(c)
cm := mocks.NewMockManager(c)
updater := newTestUpdater(client, "1.4.0", false)
updater := newTestUpdater(cm, "1.4.0", false)
updater.installer = &fakeInstaller{delay: 2 * time.Second}
@ -261,13 +250,11 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
RolloutProportion: 1.0,
}
client.EXPECT().DownloadAndVerify(
cm.EXPECT().DownloadAndVerify(
gomock.Any(),
latestVersion.Package,
latestVersion.Package+".sig",
gomock.Any(),
).Return(bytes.NewReader([]byte("tgz_data_here")), nil)
client.EXPECT().Logout()
).Return([]byte("tgz_data_here"), nil)
wg := &sync.WaitGroup{}
@ -288,9 +275,9 @@ func TestInstallUpdateAlreadyOngoing(t *testing.T) {
wg.Wait()
}
func newTestUpdater(client *mocks.MockClient, curVer string, earlyAccess bool) *Updater {
func newTestUpdater(manager pmapi.Manager, curVer string, earlyAccess bool) *Updater {
return New(
&fakeClientProvider{client: client},
manager,
&fakeInstaller{},
newFakeSettings(0.5, earlyAccess),
nil,
@ -299,14 +286,6 @@ func newTestUpdater(client *mocks.MockClient, curVer string, earlyAccess bool) *
)
}
type fakeClientProvider struct {
client *mocks.MockClient
}
func (p *fakeClientProvider) GetAnonymousClient() pmapi.Client {
return p.client
}
type fakeInstaller struct {
bad bool
delay time.Duration

View File

@ -149,3 +149,13 @@ func (s *Credentials) Logout() {
func (s *Credentials) IsConnected() bool {
return s.APIToken != "" && s.MailboxPassword != ""
}
func (s *Credentials) SplitAPIToken() (string, string, error) {
split := strings.Split(s.APIToken, ":")
if len(split) != 2 {
return "", "", errors.New("malformed API token")
}
return split[0], split[1], nil
}

View File

@ -39,7 +39,7 @@ func NewStore(keychain *keychain.Keychain) *Store {
return &Store{secrets: keychain}
}
func (s *Store) Add(userID, userName, apiToken, mailboxPassword string, emails []string) (creds *Credentials, err error) {
func (s *Store) Add(userID, userName, uid, ref, mailboxPassword string, emails []string) (*Credentials, error) {
storeLocker.Lock()
defer storeLocker.Unlock()
@ -49,10 +49,10 @@ func (s *Store) Add(userID, userName, apiToken, mailboxPassword string, emails [
"emails": emails,
}).Trace("Adding new credentials")
creds = &Credentials{
creds := &Credentials{
UserID: userID,
Name: userName,
APIToken: apiToken,
APIToken: uid + ":" + ref,
MailboxPassword: mailboxPassword,
IsHidden: false,
}
@ -72,82 +72,82 @@ func (s *Store) Add(userID, userName, apiToken, mailboxPassword string, emails [
creds.Timestamp = time.Now().Unix()
}
if err = s.saveCredentials(creds); err != nil {
return
if err := s.saveCredentials(creds); err != nil {
return nil, err
}
return creds, err
return creds, nil
}
func (s *Store) SwitchAddressMode(userID string) error {
func (s *Store) SwitchAddressMode(userID string) (*Credentials, error) {
storeLocker.Lock()
defer storeLocker.Unlock()
credentials, err := s.get(userID)
if err != nil {
return err
return nil, err
}
credentials.IsCombinedAddressMode = !credentials.IsCombinedAddressMode
credentials.BridgePassword = generatePassword()
return s.saveCredentials(credentials)
return credentials, s.saveCredentials(credentials)
}
func (s *Store) UpdateEmails(userID string, emails []string) error {
func (s *Store) UpdateEmails(userID string, emails []string) (*Credentials, error) {
storeLocker.Lock()
defer storeLocker.Unlock()
credentials, err := s.get(userID)
if err != nil {
return err
return nil, err
}
credentials.SetEmailList(emails)
return s.saveCredentials(credentials)
return credentials, s.saveCredentials(credentials)
}
func (s *Store) UpdatePassword(userID, password string) error {
func (s *Store) UpdatePassword(userID, password string) (*Credentials, error) {
storeLocker.Lock()
defer storeLocker.Unlock()
credentials, err := s.get(userID)
if err != nil {
return err
return nil, err
}
credentials.MailboxPassword = password
return s.saveCredentials(credentials)
return credentials, s.saveCredentials(credentials)
}
func (s *Store) UpdateToken(userID, apiToken string) error {
func (s *Store) UpdateToken(userID, uid, ref string) (*Credentials, error) {
storeLocker.Lock()
defer storeLocker.Unlock()
credentials, err := s.get(userID)
if err != nil {
return err
return nil, err
}
credentials.APIToken = apiToken
credentials.APIToken = uid + ":" + ref
return s.saveCredentials(credentials)
return credentials, s.saveCredentials(credentials)
}
func (s *Store) Logout(userID string) error {
func (s *Store) Logout(userID string) (*Credentials, error) {
storeLocker.Lock()
defer storeLocker.Unlock()
credentials, err := s.get(userID)
if err != nil {
return err
return nil, err
}
credentials.Logout()
return s.saveCredentials(credentials)
return credentials, s.saveCredentials(credentials)
}
// List returns a list of usernames that have credentials stored.
@ -233,7 +233,7 @@ func (s *Store) get(userID string) (creds *Credentials, err error) {
_, secret, err := s.secrets.Get(userID)
if err != nil {
log.WithError(err).Error("Could not get credentials from native keychain")
log.WithError(err).Warn("Could not get credentials from native keychain")
return
}
@ -249,7 +249,7 @@ func (s *Store) get(userID string) (creds *Credentials, err error) {
}
// saveCredentials encrypts and saves password to the keychain store.
func (s *Store) saveCredentials(credentials *Credentials) (err error) {
func (s *Store) saveCredentials(credentials *Credentials) error {
credentials.Version = keychain.Version
return s.secrets.Put(credentials.UserID, credentials.Marshal())

View File

@ -26,8 +26,7 @@ import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
r "github.com/stretchr/testify/require"
)
const testSep = "\n"
@ -249,26 +248,26 @@ func TestMarshalFormats(t *testing.T) {
log.Infof("secretFmt %#v %d\n", secretFmt, len(secretFmt))
output := testCredentials{APIToken: "refresh"}
require.NoError(t, output.UnmarshalStrings(secretStrings))
r.NoError(t, output.UnmarshalStrings(secretStrings))
log.Infof("strings out %#v \n", output)
require.True(t, input.IsSame(&output), "strings out not same")
r.True(t, input.IsSame(&output), "strings out not same")
output = testCredentials{APIToken: "refresh"}
require.NoError(t, output.UnmarshalGob(secretGob))
r.NoError(t, output.UnmarshalGob(secretGob))
log.Infof("gob out %#v\n \n", output)
assert.Equal(t, input, output)
r.Equal(t, input, output)
output = testCredentials{APIToken: "refresh"}
require.NoError(t, output.FromJSON(secretJSON))
r.NoError(t, output.FromJSON(secretJSON))
log.Infof("json out %#v \n", output)
require.True(t, input.IsSame(&output), "json out not same")
r.True(t, input.IsSame(&output), "json out not same")
/*
// Simple Fscanf not working!
output = testCredentials{APIToken: "refresh"}
require.NoError(t, output.UnmarshalFmt(secretFmt))
r.NoError(t, output.UnmarshalFmt(secretFmt))
log.Infof("fmt out %#v \n", output)
require.True(t, input.IsSame(&output), "fmt out not same")
r.True(t, input.IsSame(&output), "fmt out not same")
*/
}
@ -291,7 +290,7 @@ func TestMarshal(t *testing.T) {
log.Infof("secret %#v %d\n", secret, len(secret))
output := Credentials{APIToken: "refresh"}
require.NoError(t, output.Unmarshal(secret))
r.NoError(t, output.Unmarshal(secret))
log.Infof("output %#v\n", output)
assert.Equal(t, input, output)
r.Equal(t, input, output)
}

View File

@ -1,107 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./listener/listener.go
// Package users is a generated GoMock package.
package users
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
)
// MockListener is a mock of Listener interface
type MockListener struct {
ctrl *gomock.Controller
recorder *MockListenerMockRecorder
}
// MockListenerMockRecorder is the mock recorder for MockListener
type MockListenerMockRecorder struct {
mock *MockListener
}
// NewMockListener creates a new mock instance
func NewMockListener(ctrl *gomock.Controller) *MockListener {
mock := &MockListener{ctrl: ctrl}
mock.recorder = &MockListenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
return m.recorder
}
// SetLimit mocks base method
func (m *MockListener) SetLimit(eventName string, limit time.Duration) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetLimit", eventName, limit)
}
// SetLimit indicates an expected call of SetLimit
func (mr *MockListenerMockRecorder) SetLimit(eventName, limit interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), eventName, limit)
}
// Add mocks base method
func (m *MockListener) Add(eventName string, channel chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", eventName, channel)
}
// Add indicates an expected call of Add
func (mr *MockListenerMockRecorder) Add(eventName, channel interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), eventName, channel)
}
// Remove mocks base method
func (m *MockListener) Remove(eventName string, channel chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Remove", eventName, channel)
}
// Remove indicates an expected call of Remove
func (mr *MockListenerMockRecorder) Remove(eventName, channel interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), eventName, channel)
}
// Emit mocks base method
func (m *MockListener) Emit(eventName, data string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Emit", eventName, data)
}
// Emit indicates an expected call of Emit
func (mr *MockListenerMockRecorder) Emit(eventName, data interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), eventName, data)
}
// SetBuffer mocks base method
func (m *MockListener) SetBuffer(eventName string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetBuffer", eventName)
}
// SetBuffer indicates an expected call of SetBuffer
func (mr *MockListenerMockRecorder) SetBuffer(eventName interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), eventName)
}
// RetryEmit mocks base method
func (m *MockListener) RetryEmit(eventName string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RetryEmit", eventName)
}
// RetryEmit indicates an expected call of RetryEmit
func (mr *MockListenerMockRecorder) RetryEmit(eventName interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), eventName)
}

View File

@ -0,0 +1,121 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/pkg/listener (interfaces: Listener)
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
)
// MockListener is a mock of Listener interface
type MockListener struct {
ctrl *gomock.Controller
recorder *MockListenerMockRecorder
}
// MockListenerMockRecorder is the mock recorder for MockListener
type MockListenerMockRecorder struct {
mock *MockListener
}
// NewMockListener creates a new mock instance
func NewMockListener(ctrl *gomock.Controller) *MockListener {
mock := &MockListener{ctrl: ctrl}
mock.recorder = &MockListenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
return m.recorder
}
// Add mocks base method
func (m *MockListener) Add(arg0 string, arg1 chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", arg0, arg1)
}
// Add indicates an expected call of Add
func (mr *MockListenerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), arg0, arg1)
}
// Emit mocks base method
func (m *MockListener) Emit(arg0, arg1 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Emit", arg0, arg1)
}
// Emit indicates an expected call of Emit
func (mr *MockListenerMockRecorder) Emit(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1)
}
// ProvideChannel mocks base method
func (m *MockListener) ProvideChannel(arg0 string) <-chan string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ProvideChannel", arg0)
ret0, _ := ret[0].(<-chan string)
return ret0
}
// ProvideChannel indicates an expected call of ProvideChannel
func (mr *MockListenerMockRecorder) ProvideChannel(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProvideChannel", reflect.TypeOf((*MockListener)(nil).ProvideChannel), arg0)
}
// Remove mocks base method
func (m *MockListener) Remove(arg0 string, arg1 chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Remove", arg0, arg1)
}
// Remove indicates an expected call of Remove
func (mr *MockListenerMockRecorder) Remove(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), arg0, arg1)
}
// RetryEmit mocks base method
func (m *MockListener) RetryEmit(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RetryEmit", arg0)
}
// RetryEmit indicates an expected call of RetryEmit
func (mr *MockListenerMockRecorder) RetryEmit(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), arg0)
}
// SetBuffer mocks base method
func (m *MockListener) SetBuffer(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetBuffer", arg0)
}
// SetBuffer indicates an expected call of SetBuffer
func (mr *MockListenerMockRecorder) SetBuffer(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), arg0)
}
// SetLimit mocks base method
func (m *MockListener) SetLimit(arg0 string, arg1 time.Duration) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetLimit", arg0, arg1)
}
// SetLimit indicates an expected call of SetLimit
func (mr *MockListenerMockRecorder) SetLimit(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), arg0, arg1)
}

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/internal/users (interfaces: Locator,PanicHandler,ClientManager,CredentialsStorer,StoreMaker)
// Source: github.com/ProtonMail/proton-bridge/internal/users (interfaces: Locator,PanicHandler,CredentialsStorer,StoreMaker)
// Package mocks is a generated GoMock package.
package mocks
@ -9,7 +9,6 @@ import (
store "github.com/ProtonMail/proton-bridge/internal/store"
credentials "github.com/ProtonMail/proton-bridge/internal/users/credentials"
pmapi "github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
)
@ -85,109 +84,6 @@ func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
}
// MockClientManager is a mock of ClientManager interface
type MockClientManager struct {
ctrl *gomock.Controller
recorder *MockClientManagerMockRecorder
}
// MockClientManagerMockRecorder is the mock recorder for MockClientManager
type MockClientManagerMockRecorder struct {
mock *MockClientManager
}
// NewMockClientManager creates a new mock instance
func NewMockClientManager(ctrl *gomock.Controller) *MockClientManager {
mock := &MockClientManager{ctrl: ctrl}
mock.recorder = &MockClientManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockClientManager) EXPECT() *MockClientManagerMockRecorder {
return m.recorder
}
// AllowProxy mocks base method
func (m *MockClientManager) AllowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AllowProxy")
}
// AllowProxy indicates an expected call of AllowProxy
func (mr *MockClientManagerMockRecorder) AllowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockClientManager)(nil).AllowProxy))
}
// CheckConnection mocks base method
func (m *MockClientManager) CheckConnection() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CheckConnection")
ret0, _ := ret[0].(error)
return ret0
}
// CheckConnection indicates an expected call of CheckConnection
func (mr *MockClientManagerMockRecorder) CheckConnection() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckConnection", reflect.TypeOf((*MockClientManager)(nil).CheckConnection))
}
// DisallowProxy mocks base method
func (m *MockClientManager) DisallowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DisallowProxy")
}
// DisallowProxy indicates an expected call of DisallowProxy
func (mr *MockClientManagerMockRecorder) DisallowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockClientManager)(nil).DisallowProxy))
}
// GetAnonymousClient mocks base method
func (m *MockClientManager) GetAnonymousClient() pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAnonymousClient")
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetAnonymousClient indicates an expected call of GetAnonymousClient
func (mr *MockClientManagerMockRecorder) GetAnonymousClient() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient))
}
// GetAuthUpdateChannel mocks base method
func (m *MockClientManager) GetAuthUpdateChannel() chan pmapi.ClientAuth {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthUpdateChannel")
ret0, _ := ret[0].(chan pmapi.ClientAuth)
return ret0
}
// GetAuthUpdateChannel indicates an expected call of GetAuthUpdateChannel
func (mr *MockClientManagerMockRecorder) GetAuthUpdateChannel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthUpdateChannel", reflect.TypeOf((*MockClientManager)(nil).GetAuthUpdateChannel))
}
// GetClient mocks base method
func (m *MockClientManager) GetClient(arg0 string) pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClient", arg0)
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// GetClient indicates an expected call of GetClient
func (mr *MockClientManagerMockRecorder) GetClient(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientManager)(nil).GetClient), arg0)
}
// MockCredentialsStorer is a mock of CredentialsStorer interface
type MockCredentialsStorer struct {
ctrl *gomock.Controller
@ -212,18 +108,18 @@ func (m *MockCredentialsStorer) EXPECT() *MockCredentialsStorerMockRecorder {
}
// Add mocks base method
func (m *MockCredentialsStorer) Add(arg0, arg1, arg2, arg3 string, arg4 []string) (*credentials.Credentials, error) {
func (m *MockCredentialsStorer) Add(arg0, arg1, arg2, arg3, arg4 string, arg5 []string) (*credentials.Credentials, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Add", arg0, arg1, arg2, arg3, arg4)
ret := m.ctrl.Call(m, "Add", arg0, arg1, arg2, arg3, arg4, arg5)
ret0, _ := ret[0].(*credentials.Credentials)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Add indicates an expected call of Add
func (mr *MockCredentialsStorerMockRecorder) Add(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
func (mr *MockCredentialsStorerMockRecorder) Add(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockCredentialsStorer)(nil).Add), arg0, arg1, arg2, arg3, arg4)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockCredentialsStorer)(nil).Add), arg0, arg1, arg2, arg3, arg4, arg5)
}
// Delete mocks base method
@ -271,11 +167,12 @@ func (mr *MockCredentialsStorerMockRecorder) List() *gomock.Call {
}
// Logout mocks base method
func (m *MockCredentialsStorer) Logout(arg0 string) error {
func (m *MockCredentialsStorer) Logout(arg0 string) (*credentials.Credentials, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Logout", arg0)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(*credentials.Credentials)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Logout indicates an expected call of Logout
@ -285,11 +182,12 @@ func (mr *MockCredentialsStorerMockRecorder) Logout(arg0 interface{}) *gomock.Ca
}
// SwitchAddressMode mocks base method
func (m *MockCredentialsStorer) SwitchAddressMode(arg0 string) error {
func (m *MockCredentialsStorer) SwitchAddressMode(arg0 string) (*credentials.Credentials, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SwitchAddressMode", arg0)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(*credentials.Credentials)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SwitchAddressMode indicates an expected call of SwitchAddressMode
@ -299,11 +197,12 @@ func (mr *MockCredentialsStorerMockRecorder) SwitchAddressMode(arg0 interface{})
}
// UpdateEmails mocks base method
func (m *MockCredentialsStorer) UpdateEmails(arg0 string, arg1 []string) error {
func (m *MockCredentialsStorer) UpdateEmails(arg0 string, arg1 []string) (*credentials.Credentials, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateEmails", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(*credentials.Credentials)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateEmails indicates an expected call of UpdateEmails
@ -313,11 +212,12 @@ func (mr *MockCredentialsStorerMockRecorder) UpdateEmails(arg0, arg1 interface{}
}
// UpdatePassword mocks base method
func (m *MockCredentialsStorer) UpdatePassword(arg0, arg1 string) error {
func (m *MockCredentialsStorer) UpdatePassword(arg0, arg1 string) (*credentials.Credentials, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdatePassword", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(*credentials.Credentials)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdatePassword indicates an expected call of UpdatePassword
@ -327,17 +227,18 @@ func (mr *MockCredentialsStorerMockRecorder) UpdatePassword(arg0, arg1 interface
}
// UpdateToken mocks base method
func (m *MockCredentialsStorer) UpdateToken(arg0, arg1 string) error {
func (m *MockCredentialsStorer) UpdateToken(arg0, arg1, arg2 string) (*credentials.Credentials, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateToken", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
ret := m.ctrl.Call(m, "UpdateToken", arg0, arg1, arg2)
ret0, _ := ret[0].(*credentials.Credentials)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateToken indicates an expected call of UpdateToken
func (mr *MockCredentialsStorerMockRecorder) UpdateToken(arg0, arg1 interface{}) *gomock.Call {
func (mr *MockCredentialsStorerMockRecorder) UpdateToken(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateToken", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateToken), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateToken", reflect.TypeOf((*MockCredentialsStorer)(nil).UpdateToken), arg0, arg1, arg2)
}
// MockStoreMaker is a mock of StoreMaker interface

View File

@ -20,14 +20,8 @@ package users
import (
"github.com/ProtonMail/proton-bridge/internal/store"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
type Configer interface {
GetAppVersion() string
GetAPIConfig() *pmapi.ClientConfig
}
type Locator interface {
Clear() error
}
@ -38,25 +32,16 @@ type PanicHandler interface {
type CredentialsStorer interface {
List() (userIDs []string, err error)
Add(userID, userName, apiToken, mailboxPassword string, emails []string) (*credentials.Credentials, error)
Add(userID, userName, uid, ref, mailboxPassword string, emails []string) (*credentials.Credentials, error)
Get(userID string) (*credentials.Credentials, error)
SwitchAddressMode(userID string) error
UpdateEmails(userID string, emails []string) error
UpdatePassword(userID, password string) error
UpdateToken(userID, apiToken string) error
Logout(userID string) error
SwitchAddressMode(userID string) (*credentials.Credentials, error)
UpdateEmails(userID string, emails []string) (*credentials.Credentials, error)
UpdatePassword(userID, password string) (*credentials.Credentials, error)
UpdateToken(userID, uid, ref string) (*credentials.Credentials, error)
Logout(userID string) (*credentials.Credentials, error)
Delete(userID string) error
}
type ClientManager interface {
GetClient(userID string) pmapi.Client
GetAnonymousClient() pmapi.Client
AllowProxy()
DisallowProxy()
GetAuthUpdateChannel() chan pmapi.ClientAuth
CheckConnection() error
}
type StoreMaker interface {
New(user store.BridgeUser) (*store.Store, error)
Remove(userID string) error

View File

@ -18,6 +18,7 @@
package users
import (
"context"
"runtime"
"strings"
"sync"
@ -36,11 +37,11 @@ var ErrLoggedOutUser = errors.New("account is logged out, use the app to login a
// User is a struct on top of API client and credentials store.
type User struct {
log *logrus.Entry
panicHandler PanicHandler
listener listener.Listener
clientManager ClientManager
credStorer CredentialsStorer
log *logrus.Entry
panicHandler PanicHandler
listener listener.Listener
client pmapi.Client
credStorer CredentialsStorer
storeFactory StoreMaker
store *store.Store
@ -48,75 +49,72 @@ type User struct {
userID string
creds *credentials.Credentials
lock sync.RWMutex
isAuthorized bool
lock sync.RWMutex
useOnlyActiveAddresses bool
}
// newUser creates a new user.
// The user is initially disconnected and must be connected by calling connect().
func newUser(
panicHandler PanicHandler,
userID string,
eventListener listener.Listener,
credStorer CredentialsStorer,
clientManager ClientManager,
storeFactory StoreMaker,
) (u *User, err error) {
useOnlyActiveAddresses bool,
) (*User, *credentials.Credentials, error) {
log := log.WithField("user", userID)
log.Debug("Creating or loading user")
creds, err := credStorer.Get(userID)
if err != nil {
return nil, errors.Wrap(err, "failed to load user credentials")
return nil, nil, errors.Wrap(err, "failed to load user credentials")
}
u = &User{
log: log,
panicHandler: panicHandler,
listener: eventListener,
credStorer: credStorer,
clientManager: clientManager,
storeFactory: storeFactory,
userID: userID,
creds: creds,
}
return
return &User{
log: log,
panicHandler: panicHandler,
listener: eventListener,
credStorer: credStorer,
storeFactory: storeFactory,
userID: userID,
creds: creds,
useOnlyActiveAddresses: useOnlyActiveAddresses,
}, creds, nil
}
func (u *User) client() pmapi.Client {
return u.clientManager.GetClient(u.userID)
}
// connect connects a user. This includes
// - providing it with an authorised API client
// - loading its credentials from the credentials store
// - loading and unlocking its PGP keys
// - loading its store.
func (u *User) connect(client pmapi.Client, creds *credentials.Credentials) error {
u.log.Info("Connecting user")
// init initialises a user. This includes reloading its credentials from the credentials store
// (such as when logging out and back in, you need to reload the credentials because the new credentials will
// have the apitoken and password), authorising the user against the api, loading the user store (creating a new one
// if necessary), and setting the imap idle updates channel (used to send imap idle updates to the imap backend if
// something in the store changed).
func (u *User) init() (err error) {
u.log.Info("Initialising user")
// Connected users have an API client.
u.client = client
// Reload the user's credentials (if they log out and back in we need the new
// version with the apitoken and mailbox password).
creds, err := u.credStorer.Get(u.userID)
if err != nil {
return errors.Wrap(err, "failed to load user credentials")
}
u.client.AddAuthRefreshHandler(u.handleAuthRefresh)
// Save the latest credentials for the user.
u.creds = creds
// Try to authorise the user if they aren't already authorised.
// Note: we still allow users to set up accounts if the internet is off.
if authErr := u.authorizeIfNecessary(false); authErr != nil {
switch errors.Cause(authErr) {
case pmapi.ErrAPINotReachable, pmapi.ErrUpgradeApplication, ErrLoggedOutUser:
u.log.WithError(authErr).Warn("Could not authorize user")
default:
if logoutErr := u.logout(); logoutErr != nil {
u.log.WithError(logoutErr).Warn("Could not logout user")
}
return errors.Wrap(authErr, "failed to authorize user")
}
// Connected users have unlocked keys.
if err := u.unlockIfNecessary(); err != nil {
return err
}
// Connected users have a store.
if err := u.loadStore(); err != nil { //nolint[revive] easier to read
return err
}
return nil
}
func (u *User) loadStore() error {
// Logged-out user keeps store running to access offline data.
// Therefore it is necessary to close it before re-init.
if u.store != nil {
@ -125,93 +123,36 @@ func (u *User) init() (err error) {
}
u.store = nil
}
store, err := u.storeFactory.New(u)
if err != nil {
return errors.Wrap(err, "failed to create store")
}
u.store = store
return err
return nil
}
// authorizeIfNecessary checks whether user is logged in and is connected to api auth channel.
// If user is not already connected to the api auth channel (for example there was no internet during start),
// it tries to connect it.
func (u *User) authorizeIfNecessary(emitEvent bool) (err error) {
// If user is connected and has an auth channel, then perfect, nothing to do here.
if u.creds.IsConnected() && u.isAuthorized {
// The keyring unlock is triggered here to resolve state where apiClient
// is authenticated (we have auth token) but it was not possible to download
// and unlock the keys (internet not reachable).
return u.unlockIfNecessary()
}
func (u *User) handleAuthRefresh(auth *pmapi.AuthRefresh) {
u.log.Debug("User received auth refresh update")
if !u.creds.IsConnected() {
err = ErrLoggedOutUser
} else if err = u.authorizeAndUnlock(); err != nil {
u.log.WithError(err).Error("Could not authorize and unlock user")
switch errors.Cause(err) {
case pmapi.ErrUpgradeApplication, pmapi.ErrAPINotReachable: // Ignore these errors.
default:
if errLogout := u.credStorer.Logout(u.userID); errLogout != nil {
u.log.WithField("err", errLogout).Error("Could not log user out from credentials store")
}
if auth == nil {
if err := u.logout(); err != nil {
log.WithError(err).
WithField("userID", u.userID).
Error("User logout failed while watching API auths")
}
return
}
if emitEvent && err != nil &&
errors.Cause(err) != pmapi.ErrUpgradeApplication &&
errors.Cause(err) != pmapi.ErrAPINotReachable {
u.listener.Emit(events.LogoutEvent, u.userID)
}
return err
}
// unlockIfNecessary will not trigger keyring unlocking if it was already successfully unlocked.
func (u *User) unlockIfNecessary() error {
if u.client().IsUnlocked() {
return nil
}
if err := u.client().Unlock([]byte(u.creds.MailboxPassword)); err != nil {
return errors.Wrap(err, "failed to unlock user")
}
return nil
}
// authorizeAndUnlock tries to authorize the user with the API using the the user's APIToken.
// If that succeeds, it tries to unlock the user's keys and addresses.
func (u *User) authorizeAndUnlock() (err error) {
if u.creds.APIToken == "" {
u.log.Warn("Could not connect to API auth channel, have no API token")
return nil
}
if _, err := u.client().AuthRefresh(u.creds.APIToken); err != nil {
return errors.Wrap(err, "failed to refresh API auth")
}
if err := u.client().Unlock([]byte(u.creds.MailboxPassword)); err != nil {
return errors.Wrap(err, "failed to unlock user")
}
return nil
}
func (u *User) updateAuthToken(auth *pmapi.Auth) {
u.log.Debug("User received auth")
if err := u.credStorer.UpdateToken(u.userID, auth.GenToken()); err != nil {
creds, err := u.credStorer.UpdateToken(u.userID, auth.UID, auth.RefreshToken)
if err != nil {
u.log.WithError(err).Error("Failed to update refresh token in credentials store")
return
}
u.refreshFromCredentials()
u.isAuthorized = true
u.creds = creds
}
// clearStore removes the database.
@ -244,13 +185,6 @@ func (u *User) closeStore() error {
return nil
}
// GetTemporaryPMAPIClient returns an authorised PMAPI client.
// Do not use! It's only for backward compatibility of old SMTP and IMAP implementations.
// After proper refactor of SMTP and IMAP remove this method.
func (u *User) GetTemporaryPMAPIClient() pmapi.Client {
return u.client()
}
// ID returns the user's userID.
func (u *User) ID() string {
return u.userID
@ -272,6 +206,44 @@ func (u *User) IsConnected() bool {
return u.creds.IsConnected()
}
func (u *User) GetClient() pmapi.Client {
if err := u.unlockIfNecessary(); err != nil {
u.log.WithError(err).Error("Failed to unlock user")
}
return u.client
}
// unlockIfNecessary will not trigger keyring unlocking if it was already successfully unlocked.
func (u *User) unlockIfNecessary() error {
if !u.creds.IsConnected() {
return nil
}
if u.client.IsUnlocked() {
return nil
}
// unlockIfNecessary is called with every access to underlying pmapi
// client. Unlock should only finish unlocking when connection is back up.
// That means it should try it fast enough and not retry if connection
// is still down.
err := u.client.Unlock(pmapi.ContextWithoutRetry(context.Background()), []byte(u.creds.MailboxPassword))
if err == nil {
return nil
}
switch errors.Cause(err) {
case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication:
u.log.WithError(err).Warn("Could not unlock user")
return nil
}
if logoutErr := u.logout(); logoutErr != nil {
u.log.WithError(logoutErr).Warn("Could not logout user")
}
return errors.Wrap(err, "failed to unlock user")
}
// IsCombinedAddressMode returns whether user is set in combined or split mode.
// Combined mode is the default mode and is what users typically need.
// Split mode is mostly for outlook as it cannot handle sending e-mails from an
@ -345,7 +317,7 @@ func (u *User) GetAddressID(address string) (id string, err error) {
return u.store.GetAddressID(address)
}
addresses := u.client().Addresses()
addresses := u.client.Addresses()
pmapiAddress := addresses.ByEmail(address)
if pmapiAddress != nil {
return pmapiAddress.ID, nil
@ -374,74 +346,70 @@ func (u *User) CheckBridgeLogin(password string) error {
u.lock.RLock()
defer u.lock.RUnlock()
// True here because users should be notified by popup of auth failure.
if err := u.authorizeIfNecessary(true); err != nil {
u.log.WithError(err).Error("Failed to authorize user")
return err
if !u.creds.IsConnected() {
u.listener.Emit(events.LogoutEvent, u.userID)
return ErrLoggedOutUser
}
return u.creds.CheckPassword(password)
}
// UpdateUser updates user details from API and saves to the credentials.
func (u *User) UpdateUser() error {
func (u *User) UpdateUser(ctx context.Context) error {
u.lock.Lock()
defer u.lock.Unlock()
if err := u.authorizeIfNecessary(true); err != nil {
return errors.Wrap(err, "cannot update user")
}
_, err := u.client().UpdateUser()
_, err := u.client.UpdateUser(ctx)
if err != nil {
return err
}
if err = u.client().ReloadKeys([]byte(u.creds.MailboxPassword)); err != nil {
if err := u.client.ReloadKeys(ctx, []byte(u.creds.MailboxPassword)); err != nil {
return errors.Wrap(err, "failed to reload keys")
}
emails := u.client().Addresses().ActiveEmails()
if err := u.credStorer.UpdateEmails(u.userID, emails); err != nil {
creds, err := u.credStorer.UpdateEmails(u.userID, u.client.Addresses().ActiveEmails())
if err != nil {
return err
}
u.refreshFromCredentials()
u.creds = creds
return nil
}
// SwitchAddressMode changes mode from combined to split and vice versa. The mode to switch to is determined by the
// state of the user's credentials in the credentials store. See `IsCombinedAddressMode` for more details.
func (u *User) SwitchAddressMode() (err error) {
func (u *User) SwitchAddressMode() error {
u.log.Trace("Switching user address mode")
u.lock.Lock()
defer u.lock.Unlock()
u.CloseAllConnections()
if u.store == nil {
err = errors.New("store is not initialised")
return
return errors.New("store is not initialised")
}
newAddressModeState := !u.IsCombinedAddressMode()
if err = u.store.UseCombinedMode(newAddressModeState); err != nil {
u.log.WithError(err).Error("Could not switch store address mode")
return
if err := u.store.UseCombinedMode(newAddressModeState); err != nil {
return errors.Wrap(err, "could not switch store address mode")
}
if u.creds.IsCombinedAddressMode != newAddressModeState {
if err = u.credStorer.SwitchAddressMode(u.userID); err != nil {
u.log.WithError(err).Error("Could not switch credentials store address mode")
return
}
if u.creds.IsCombinedAddressMode == newAddressModeState {
return nil
}
u.refreshFromCredentials()
creds, err := u.credStorer.SwitchAddressMode(u.userID)
if err != nil {
return errors.Wrap(err, "could not switch credentials store address mode")
}
return err
u.creds = creds
return nil
}
// logout is the same as Logout, but for internal purposes (logged out from
@ -458,35 +426,36 @@ func (u *User) logout() error {
u.listener.Emit(events.UserRefreshEvent, u.userID)
}
u.isAuthorized = false
return err
}
// Logout logs out the user from pmapi, the credentials store, the mail store, and tries to remove as much
// sensitive data as possible.
func (u *User) Logout() (err error) {
func (u *User) Logout() error {
u.lock.Lock()
defer u.lock.Unlock()
u.log.Debug("Logging out user")
if !u.creds.IsConnected() {
return
return nil
}
u.client().Logout()
if err := u.client.AuthDelete(context.Background()); err != nil {
u.log.WithError(err).Warn("Failed to delete auth")
}
if err = u.credStorer.Logout(u.userID); err != nil {
creds, err := u.credStorer.Logout(u.userID)
if err != nil {
u.log.WithError(err).Warn("Could not log user out from credentials store")
if err = u.credStorer.Delete(u.userID); err != nil {
if err := u.credStorer.Delete(u.userID); err != nil {
u.log.WithError(err).Error("Could not delete user from credentials store")
}
} else {
u.creds = creds
}
u.refreshFromCredentials()
// Do not close whole store, just event loop. Some information might be needed offline (e.g. addressID)
u.closeEventLoop()
@ -494,15 +463,7 @@ func (u *User) Logout() (err error) {
runtime.GC()
return err
}
func (u *User) refreshFromCredentials() {
if credentials, err := u.credStorer.Get(u.userID); err != nil {
log.WithError(err).Error("Cannot refresh user credentials")
} else {
u.creds = credentials
}
return nil
}
func (u *User) closeEventLoop() {

View File

@ -18,13 +18,14 @@
package users
import (
"context"
"testing"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
r "github.com/stretchr/testify/require"
)
func TestUpdateUser(t *testing.T) {
@ -35,25 +36,14 @@ func TestUpdateUser(t *testing.T) {
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
m.pmapiClient.EXPECT().UpdateUser().Return(nil, nil),
m.pmapiClient.EXPECT().ReloadKeys([]byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().UpdateUser(gomock.Any()).Return(nil, nil),
m.pmapiClient.EXPECT().ReloadKeys(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}).Return(testCredentials, nil),
)
gomock.InOrder(
m.pmapiClient.EXPECT().GetEvent(testPMAPIEvent.EventID).Return(testPMAPIEvent, nil).MaxTimes(1),
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).MaxTimes(1),
)
assert.NoError(t, user.UpdateUser())
waitForEvents()
r.NoError(t, user.UpdateUser(context.Background()))
}
func TestUserSwitchAddressMode(t *testing.T) {
@ -63,128 +53,91 @@ func TestUserSwitchAddressMode(t *testing.T) {
user := testNewUser(m)
defer cleanUpUserData(user)
assert.True(t, user.store.IsCombinedMode())
assert.True(t, user.creds.IsCombinedAddressMode)
assert.True(t, user.IsCombinedAddressMode())
waitForEvents()
// Ignore any sync on background.
m.pmapiClient.EXPECT().ListMessages(gomock.Any(), gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
// Check initial state.
r.True(t, user.store.IsCombinedMode())
r.True(t, user.creds.IsCombinedAddressMode)
r.True(t, user.IsCombinedAddressMode())
// Mock change to split mode.
gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsSplit, nil),
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(testCredentialsSplit, nil),
)
assert.NoError(t, user.SwitchAddressMode())
assert.False(t, user.store.IsCombinedMode())
assert.False(t, user.creds.IsCombinedAddressMode)
assert.False(t, user.IsCombinedAddressMode())
// Check switch to split mode.
r.NoError(t, user.SwitchAddressMode())
r.False(t, user.store.IsCombinedMode())
r.False(t, user.creds.IsCombinedAddressMode)
r.False(t, user.IsCombinedAddressMode())
// MOck change to combined mode.
gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "alsouser@pm.me"),
m.pmapiClient.EXPECT().ListLabels().Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages("").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().SwitchAddressMode("user").Return(testCredentials, nil),
)
m.pmapiClient.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
assert.NoError(t, user.SwitchAddressMode())
assert.True(t, user.store.IsCombinedMode())
assert.True(t, user.creds.IsCombinedAddressMode)
assert.True(t, user.IsCombinedAddressMode())
waitForEvents()
// Check switch to combined mode.
r.NoError(t, user.SwitchAddressMode())
r.True(t, user.store.IsCombinedMode())
r.True(t, user.creds.IsCombinedAddressMode)
r.True(t, user.IsCombinedAddressMode())
}
func TestLogoutUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := user.Logout()
waitForEvents()
assert.NoError(t, err)
r.NoError(t, err)
}
func TestLogoutUserFailsLogout(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().Logout().Return(),
m.credentialsStore.EXPECT().Logout("user").Return(errors.New("logout failed")),
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
m.credentialsStore.EXPECT().Logout("user").Return(nil, errors.New("logout failed")),
m.credentialsStore.EXPECT().Delete("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
)
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
err := user.Logout()
waitForEvents()
assert.NoError(t, err)
r.NoError(t, err)
}
func TestCheckBridgeLoginOK(t *testing.T) {
func TestCheckBridgeLogin(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
)
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
waitForEvents()
assert.NoError(t, err)
}
func TestCheckBridgeLoginTwiceOK(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
m.pmapiClient.EXPECT().IsUnlocked().Return(true),
)
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
waitForEvents()
assert.NoError(t, err)
err = user.CheckBridgeLogin(testCredentials.BridgePassword)
waitForEvents()
assert.NoError(t, err)
r.NoError(t, err)
}
func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
@ -199,8 +152,7 @@ func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
isApplicationOutdated = true
err := user.CheckBridgeLogin("any-pass")
waitForEvents()
assert.Equal(t, pmapi.ErrUpgradeApplication, err)
r.Equal(t, pmapi.ErrUpgradeApplication, err)
isApplicationOutdated = false
}
@ -209,28 +161,26 @@ func TestCheckBridgeLoginLoggedOut(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil)
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
assert.NoError(t, err)
m.clientManager.EXPECT().GetClient(gomock.Any()).Return(m.pmapiClient).MinTimes(1)
gomock.InOrder(
// Mock init of user.
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
m.pmapiClient.EXPECT().ListLabels().Return(nil, errors.New("ErrUnauthorized")),
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return(nil, errors.New("ErrUnauthorized")),
m.pmapiClient.EXPECT().Addresses().Return(nil),
// Mock CheckBridgeLogin.
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user"),
)
err = user.init()
assert.Error(t, err)
user, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
r.NoError(t, err)
err = user.connect(m.pmapiClient, testCredentialsDisconnected)
r.Error(t, err)
defer cleanUpUserData(user)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
err = user.CheckBridgeLogin(testCredentialsDisconnected.BridgePassword)
waitForEvents()
assert.Equal(t, ErrLoggedOutUser, err)
r.Equal(t, ErrLoggedOutUser, err)
}
func TestCheckBridgeLoginBadPassword(t *testing.T) {
@ -240,12 +190,6 @@ func TestCheckBridgeLoginBadPassword(t *testing.T) {
user := testNewUser(m)
defer cleanUpUserData(user)
gomock.InOrder(
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
)
err := user.CheckBridgeLogin("wrong!")
waitForEvents()
assert.Equal(t, "backend/credentials: incorrect password", err.Error())
r.EqualError(t, err, "backend/credentials: incorrect password")
}

View File

@ -24,7 +24,7 @@ import (
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
gomock "github.com/golang/mock/gomock"
a "github.com/stretchr/testify/assert"
r "github.com/stretchr/testify/require"
)
func TestNewUserNoCredentialsStore(t *testing.T) {
@ -33,80 +33,56 @@ func TestNewUserNoCredentialsStore(t *testing.T) {
m.credentialsStore.EXPECT().Get("user").Return(nil, errors.New("fail"))
_, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
a.Error(t, err)
}
func TestNewUserAuthRefreshFails(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
gomock.InOrder(
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(nil, errors.New("bad token")),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.pmapiClient.EXPECT().Logout(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
)
checkNewUserHasCredentials(testCredentialsDisconnected, m)
_, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
r.Error(t, err)
}
func TestNewUserUnlockFails(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user")
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user")
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me")
gomock.InOrder(
// Init of user.
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
m.pmapiClient.EXPECT().AuthRefresh("token").Return(testAuthRefresh, nil),
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()),
m.pmapiClient.EXPECT().IsUnlocked().Return(false),
m.pmapiClient.EXPECT().Unlock(gomock.Any(), []byte(testCredentials.MailboxPassword)).Return(errors.New("bad password")),
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(errors.New("bad password")),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.pmapiClient.EXPECT().Logout(),
m.credentialsStore.EXPECT().Logout("user").Return(nil),
m.credentialsStore.EXPECT().Get("user").Return(testCredentialsDisconnected, nil),
// Handle of unlock error.
m.pmapiClient.EXPECT().AuthDelete(gomock.Any()).Return(nil),
m.credentialsStore.EXPECT().Logout("user").Return(testCredentialsDisconnected, nil),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "user@pm.me"),
m.eventListener.EXPECT().Emit(events.LogoutEvent, "user"),
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, "user"),
)
checkNewUserHasCredentials(testCredentialsDisconnected, m)
checkNewUserHasCredentials(m, "failed to unlock user: bad password", testCredentialsDisconnected)
}
func TestNewUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
mockConnectedUser(m)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
mockInitConnectedUser(m)
mockEventLoopNoAction(m)
checkNewUserHasCredentials(testCredentials, m)
checkNewUserHasCredentials(m, "", testCredentials)
}
func checkNewUserHasCredentials(creds *credentials.Credentials, m mocks) {
user, _ := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
func checkNewUserHasCredentials(m mocks, wantErr string, wantCreds *credentials.Credentials) {
user, _, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
r.NoError(m.t, err)
defer cleanUpUserData(user)
_ = user.init()
err = user.connect(m.pmapiClient, testCredentials)
if wantErr == "" {
r.NoError(m.t, err)
} else {
r.EqualError(m.t, err, wantErr)
}
r.Equal(m.t, wantCreds, user.creds)
waitForEvents()
a.Equal(m.t, creds, user.creds)
}
func _TestUserEventRefreshUpdatesAddresses(t *testing.T) { // nolint[funlen]
a.Fail(t, "not implemented")
}

View File

@ -15,27 +15,37 @@
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imap
package users
import (
"errors"
"testing"
"github.com/stretchr/testify/require"
r "github.com/stretchr/testify/require"
)
func TestDoNotCache(t *testing.T) {
var dnc doNotCacheError
require.NoError(t, dnc.errorOrNil())
_, ok := dnc.errorOrNil().(*doNotCacheError)
require.True(t, !ok, "should not be type doNotCacheError")
dnc.add(errors.New("first"))
require.True(t, dnc.errorOrNil() != nil, "should be error")
_, ok = dnc.errorOrNil().(*doNotCacheError)
require.True(t, ok, "should be type doNotCacheError")
dnc.add(errors.New("second"))
dnc.add(errors.New("third"))
t.Log(dnc.errorOrNil())
func _TestNeverLongStorePath(t *testing.T) { // nolint[unused]
r.Fail(t, "not implemented")
}
func TestClearStoreWithStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
r.Nil(t, user.store.Close())
user.store = nil
r.Nil(t, user.clearStore())
}
func TestClearStoreWithoutStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
defer cleanUpUserData(user)
r.NotNil(t, user.store)
r.Nil(t, user.clearStore())
}

View File

@ -18,41 +18,20 @@
package users
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
r "github.com/stretchr/testify/require"
)
// testNewUser sets up a new, authorised user.
func testNewUser(m mocks) *User {
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
mockConnectedUser(m)
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
mockInitConnectedUser(m)
mockEventLoopNoAction(m)
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
assert.NoError(m.t, err)
user, creds, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker, false)
r.NoError(m.t, err)
err = user.init()
assert.NoError(m.t, err)
mockAuthUpdate(user, "reftok", m)
return user
}
func testNewUserForLogout(m mocks) *User {
m.clientManager.EXPECT().GetClient("user").Return(m.pmapiClient).MinTimes(1)
mockConnectedUser(m)
mockEventLoopNoAction(m)
user, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.clientManager, m.storeMaker)
assert.NoError(m.t, err)
err = user.init()
assert.NoError(m.t, err)
err = user.connect(m.pmapiClient, creds)
r.NoError(m.t, err)
return user
}
@ -60,30 +39,3 @@ func testNewUserForLogout(m mocks) *User {
func cleanUpUserData(u *User) {
_ = u.clearStore()
}
func _TestNeverLongStorePath(t *testing.T) { // nolint[unused]
assert.Fail(t, "not implemented")
}
func TestClearStoreWithStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
defer cleanUpUserData(user)
require.Nil(t, user.store.Close())
user.store = nil
assert.Nil(t, user.clearStore())
}
func TestClearStoreWithoutStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUserForLogout(m)
defer cleanUpUserData(user)
assert.NotNil(t, user.store)
assert.Nil(t, user.clearStore())
}

View File

@ -19,12 +19,15 @@
package users
import (
"context"
"strings"
"sync"
"time"
"github.com/ProtonMail/proton-bridge/internal/events"
imapcache "github.com/ProtonMail/proton-bridge/internal/imap/cache"
"github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/hashicorp/go-multierror"
@ -45,7 +48,7 @@ type Users struct {
locations Locator
panicHandler PanicHandler
events listener.Listener
clientManager ClientManager
clientManager pmapi.Manager
credStorer CredentialsStorer
storeFactory StoreMaker
@ -62,16 +65,13 @@ type Users struct {
useOnlyActiveAddresses bool
lock sync.RWMutex
// stopAll can be closed to stop all goroutines from looping (watchAppOutdated, watchAPIAuths, heartbeat etc).
stopAll chan struct{}
}
func New(
locations Locator,
panicHandler PanicHandler,
eventListener listener.Listener,
clientManager ClientManager,
clientManager pmapi.Manager,
credStorer CredentialsStorer,
storeFactory StoreMaker,
useOnlyActiveAddresses bool,
@ -87,17 +87,11 @@ func New(
storeFactory: storeFactory,
useOnlyActiveAddresses: useOnlyActiveAddresses,
lock: sync.RWMutex{},
stopAll: make(chan struct{}),
}
go func() {
defer panicHandler.HandlePanic()
u.watchAppOutdated()
}()
go func() {
defer panicHandler.HandlePanic()
u.watchAPIAuths()
u.watchEvents()
}()
if u.credStorer == nil {
@ -109,76 +103,93 @@ func New(
return u
}
func (u *Users) loadUsersFromCredentialsStore() (err error) {
func (u *Users) watchEvents() {
upgradeCh := u.events.ProvideChannel(events.UpgradeApplicationEvent)
internetOnCh := u.events.ProvideChannel(events.InternetOnEvent)
for {
select {
case <-upgradeCh:
isApplicationOutdated = true
u.closeAllConnections()
case <-internetOnCh:
for _, user := range u.users {
if user.store == nil {
if err := user.loadStore(); err != nil {
log.WithError(err).Error("Failed to load store after reconnecting")
}
}
}
}
}
}
func (u *Users) loadUsersFromCredentialsStore() error {
u.lock.Lock()
defer u.lock.Unlock()
userIDs, err := u.credStorer.List()
if err != nil {
return
return err
}
for _, userID := range userIDs {
l := log.WithField("user", userID)
user, newUserErr := newUser(u.panicHandler, userID, u.events, u.credStorer, u.clientManager, u.storeFactory)
if newUserErr != nil {
l.WithField("user", userID).WithError(newUserErr).Warn("Could not load user, skipping")
user, creds, err := newUser(u.panicHandler, userID, u.events, u.credStorer, u.storeFactory, u.useOnlyActiveAddresses)
if err != nil {
l.WithError(err).Warn("Could not create user, skipping")
continue
}
u.users = append(u.users, user)
if initUserErr := user.init(); initUserErr != nil {
l.WithField("user", userID).WithError(initUserErr).Warn("Could not initialise user")
if creds.IsConnected() {
// If there is no connection, we don't want to retry. Load should
// happen fast enough to not block GUI. When connection is back up,
// watchEvents and unlockIfNecessary will finish user init later.
if err := u.loadConnectedUser(pmapi.ContextWithoutRetry(context.Background()), user, creds); err != nil {
l.WithError(err).Warn("Could not load connected user")
}
} else {
l.Warn("User is disconnected and must be connected manually")
if err := user.connect(u.clientManager.NewClient("", "", "", time.Time{}), creds); err != nil {
l.WithError(err).Warn("Could not load disconnected user")
}
}
}
return err
}
func (u *Users) watchAppOutdated() {
ch := make(chan string)
u.events.Add(events.UpgradeApplicationEvent, ch)
for {
select {
case <-ch:
isApplicationOutdated = true
u.closeAllConnections()
case <-u.stopAll:
return
}
func (u *Users) loadConnectedUser(ctx context.Context, user *User, creds *credentials.Credentials) error {
uid, ref, err := creds.SplitAPIToken()
if err != nil {
return errors.Wrap(err, "could not get user's refresh token")
}
}
// watchAPIAuths receives auths from the client manager and sends them to the appropriate user.
func (u *Users) watchAPIAuths() {
for {
select {
case auth := <-u.clientManager.GetAuthUpdateChannel():
log.Debug("Users received auth from ClientManager")
client, auth, err := u.clientManager.NewClientWithRefresh(ctx, uid, ref)
if err != nil {
// When client cannot be refreshed right away due to no connection,
// we create client which will refresh automatically when possible.
connectErr := user.connect(u.clientManager.NewClient(uid, "", ref, time.Time{}), creds)
user, ok := u.hasUser(auth.UserID)
if !ok {
log.WithField("userID", auth.UserID).Info("User not available for auth update")
continue
}
if auth.Auth != nil {
user.updateAuthToken(auth.Auth)
} else if err := user.logout(); err != nil {
log.WithError(err).
WithField("userID", auth.UserID).
Error("User logout failed while watching API auths")
}
case <-u.stopAll:
return
switch errors.Cause(err) {
case pmapi.ErrNoConnection, pmapi.ErrUpgradeApplication:
return connectErr
}
if logoutErr := user.logout(); logoutErr != nil {
logrus.WithError(logoutErr).Warn("Could not logout user")
}
return errors.Wrap(err, "could not refresh token")
}
// Update the user's credentials with the latest auth used to connect this user.
if creds, err = u.credStorer.UpdateToken(creds.UserID, auth.UID, auth.RefreshToken); err != nil {
return errors.Wrap(err, "could not create get user's refresh token")
}
return user.connect(client, creds)
}
func (u *Users) closeAllConnections() {
@ -192,63 +203,47 @@ func (u *Users) closeAllConnections() {
func (u *Users) Login(username, password string) (authClient pmapi.Client, auth *pmapi.Auth, err error) {
u.crashBandicoot(username)
// We need to use anonymous client because we don't yet have userID and so can't save auth tokens yet.
authClient = u.clientManager.GetAnonymousClient()
authInfo, err := authClient.AuthInfo(username)
if err != nil {
log.WithField("username", username).WithError(err).Error("Could not get auth info for user")
return
}
if auth, err = authClient.Auth(username, password, authInfo); err != nil {
log.WithField("username", username).WithError(err).Error("Could not get auth for user")
return
}
return
return u.clientManager.NewClientWithLogin(context.Background(), username, password)
}
// FinishLogin finishes the login procedure and adds the user into the credentials store.
func (u *Users) FinishLogin(authClient pmapi.Client, auth *pmapi.Auth, mbPassphrase string) (user *User, err error) { //nolint[funlen]
defer func() {
if err != nil {
log.WithError(err).Debug("Login not finished; removing auth session")
if delAuthErr := authClient.DeleteAuth(); delAuthErr != nil {
log.WithError(delAuthErr).Error("Failed to clear login session after unlock")
}
}
// The anonymous client will be removed from list and authentication will not be deleted.
authClient.Logout()
}()
apiUser, hashedPassphrase, err := getAPIUser(authClient, mbPassphrase)
func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password string) (user *User, err error) { //nolint[funlen]
apiUser, passphrase, err := getAPIUser(context.Background(), client, password)
if err != nil {
log.WithError(err).Error("Failed to get API user")
return
return nil, err
}
log.Info("Got API user")
if user, ok := u.hasUser(apiUser.ID); ok {
if user.IsConnected() {
if err := client.AuthDelete(context.Background()); err != nil {
logrus.WithError(err).Warn("Failed to delete new auth session")
}
var ok bool
if user, ok = u.hasUser(apiUser.ID); ok {
if err = u.connectExistingUser(user, auth, hashedPassphrase); err != nil {
log.WithError(err).Error("Failed to connect existing user")
return
return nil, errors.New("user is already connected")
}
} else {
if err = u.addNewUser(apiUser, auth, hashedPassphrase); err != nil {
log.WithError(err).Error("Failed to add new user")
return
// Update the user's credentials with the latest auth used to connect this user.
if _, err := u.credStorer.UpdateToken(auth.UserID, auth.UID, auth.RefreshToken); err != nil {
return nil, errors.Wrap(err, "failed to load user credentials")
}
// Update the password in case the user changed it.
creds, err := u.credStorer.UpdatePassword(apiUser.ID, string(passphrase))
if err != nil {
return nil, errors.Wrap(err, "failed to update password of user in credentials store")
}
if err := user.connect(client, creds); err != nil {
return nil, errors.Wrap(err, "failed to reconnect existing user")
}
u.events.Emit(events.UserRefreshEvent, apiUser.ID)
return user, nil
}
// Old credentials use username as key (user ID) which needs to be removed
// once user logs in again with proper ID fetched from API.
if _, ok := u.hasUser(apiUser.Name); ok {
if err := u.DeleteUser(apiUser.Name, true); err != nil {
log.WithError(err).Error("Failed to delete old user")
}
if err := u.addNewUser(client, apiUser, auth, passphrase); err != nil {
return nil, errors.Wrap(err, "failed to add new user")
}
u.events.Emit(events.UserRefreshEvent, apiUser.ID)
@ -256,107 +251,63 @@ func (u *Users) FinishLogin(authClient pmapi.Client, auth *pmapi.Auth, mbPassphr
return u.GetUser(apiUser.ID)
}
// connectExistingUser connects an existing user.
func (u *Users) connectExistingUser(user *User, auth *pmapi.Auth, hashedPassphrase string) (err error) {
if user.IsConnected() {
return errors.New("user is already connected")
}
log.Info("Connecting existing user")
// Update the user's password in the cred store in case they changed it.
if err = u.credStorer.UpdatePassword(user.ID(), hashedPassphrase); err != nil {
return errors.Wrap(err, "failed to update password of user in credentials store")
}
client := u.clientManager.GetClient(user.ID())
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
return errors.Wrap(err, "failed to refresh auth token of new client")
}
if err = u.credStorer.UpdateToken(user.ID(), auth.GenToken()); err != nil {
return errors.Wrap(err, "failed to update token of user in credentials store")
}
if err = user.init(); err != nil {
return errors.Wrap(err, "failed to initialise user")
}
return
}
// addNewUser adds a new user.
func (u *Users) addNewUser(apiUser *pmapi.User, auth *pmapi.Auth, hashedPassphrase string) (err error) {
func (u *Users) addNewUser(client pmapi.Client, apiUser *pmapi.User, auth *pmapi.Auth, passphrase []byte) error {
u.lock.Lock()
defer u.lock.Unlock()
client := u.clientManager.GetClient(apiUser.ID)
var emails []string
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
return errors.Wrap(err, "failed to refresh token in new client")
}
if apiUser, err = client.CurrentUser(); err != nil {
return errors.Wrap(err, "failed to update API user")
}
var emails []string //nolint[prealloc]
if u.useOnlyActiveAddresses {
emails = client.Addresses().ActiveEmails()
} else {
emails = client.Addresses().AllEmails()
}
if _, err = u.credStorer.Add(apiUser.ID, apiUser.Name, auth.GenToken(), hashedPassphrase, emails); err != nil {
return errors.Wrap(err, "failed to add user to credentials store")
if _, err := u.credStorer.Add(apiUser.ID, apiUser.Name, auth.UID, auth.RefreshToken, string(passphrase), emails); err != nil {
return errors.Wrap(err, "failed to add user credentials to credentials store")
}
user, err := newUser(u.panicHandler, apiUser.ID, u.events, u.credStorer, u.clientManager, u.storeFactory)
user, creds, err := newUser(u.panicHandler, apiUser.ID, u.events, u.credStorer, u.storeFactory, u.useOnlyActiveAddresses)
if err != nil {
return errors.Wrap(err, "failed to create user")
return errors.Wrap(err, "failed to create new user")
}
// The user needs to be part of the users list in order for it to receive an auth during initialisation.
u.users = append(u.users, user)
if err = user.init(); err != nil {
u.users = u.users[:len(u.users)-1]
return errors.Wrap(err, "failed to initialise user")
if err := user.connect(client, creds); err != nil {
return errors.Wrap(err, "failed to connect new user")
}
if err := u.SendMetric(metrics.New(metrics.Setup, metrics.NewUser, metrics.NoLabel)); err != nil {
log.WithError(err).Error("Failed to send metric")
}
return err
u.users = append(u.users, user)
return nil
}
func getAPIUser(client pmapi.Client, mbPassphrase string) (user *pmapi.User, hashedPassphrase string, err error) {
salt, err := client.AuthSalt()
func getAPIUser(ctx context.Context, client pmapi.Client, password string) (*pmapi.User, []byte, error) {
salt, err := client.AuthSalt(ctx)
if err != nil {
log.WithError(err).Error("Could not get salt")
return nil, "", err
return nil, nil, errors.Wrap(err, "failed to get salt")
}
hashedPassphrase, err = pmapi.HashMailboxPassword(mbPassphrase, salt)
passphrase, err := pmapi.HashMailboxPassword(password, salt)
if err != nil {
log.WithError(err).Error("Could not hash mailbox password")
return nil, "", err
return nil, nil, errors.Wrap(err, "failed to hash password")
}
// We unlock the user's PGP key here to detect if the user's mailbox password is wrong.
if err = client.Unlock([]byte(hashedPassphrase)); err != nil {
log.WithError(err).Error("Wrong mailbox password")
return nil, "", ErrWrongMailboxPassword
if err := client.Unlock(ctx, passphrase); err != nil {
return nil, nil, ErrWrongMailboxPassword
}
if user, err = client.CurrentUser(); err != nil {
log.WithError(err).Error("Could not load user data")
return nil, "", err
user, err := client.CurrentUser(ctx)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to load user data")
}
return user, hashedPassphrase, nil
return user, passphrase, nil
}
// GetUsers returns all added users into keychain (even logged out users).
@ -452,11 +403,9 @@ func (u *Users) DeleteUser(userID string, clearStore bool) error {
// SendMetric sends a metric. We don't want to return any errors, only log them.
func (u *Users) SendMetric(m metrics.Metric) error {
c := u.clientManager.GetAnonymousClient()
defer c.Logout()
cat, act, lab := m.Get()
if err := c.SendSimpleMetric(string(cat), string(act), string(lab)); err != nil {
if err := u.clientManager.SendSimpleMetric(context.Background(), string(cat), string(act), string(lab)); err != nil {
return err
}
@ -481,17 +430,6 @@ func (u *Users) DisallowProxy() {
u.clientManager.DisallowProxy()
}
// CheckConnection returns whether there is an internet connection.
// This should use the connection manager when it is eventually implemented.
func (u *Users) CheckConnection() error {
return u.clientManager.CheckConnection()
}
// StopWatchers stops all goroutines.
func (u *Users) StopWatchers() {
close(u.stopAll)
}
// hasUser returns whether the struct currently has a user with ID `id`.
func (u *Users) hasUser(id string) (user *User, ok bool) {
for _, u := range u.users {

Some files were not shown because too many files have changed in this diff Show More