GODT-1158: Store full messages bodies on disk

- GODT-1158: simple on-disk cache in store
- GODT-1158: better member naming in event loop
- GODT-1158: create on-disk cache during bridge setup
- GODT-1158: better job options
- GODT-1158: rename GetLiteral to GetRFC822
- GODT-1158: rename events -> currentEvents
- GODT-1158: unlock cache per-user
- GODT-1158: clean up cache after logout
- GODT-1158: randomized encrypted cache passphrase
- GODT-1158: Opt out of on-disk cache in settings
- GODT-1158: free space in cache
- GODT-1158: make tests compile
- GODT-1158: optional compression
- GODT-1158: cache custom location
- GODT-1158: basic capacity checker
- GODT-1158: cache free space config
- GODT-1158: only unlock cache if pmapi client is unlocked as well
- GODT-1158: simple background sync worker
- GODT-1158: set size/bodystructure when caching message
- GODT-1158: limit store db update blocking with semaphore
- GODT-1158: dumb 10-semaphore
- GODT-1158: properly handle delete; remove bad bodystructure handling
- GODT-1158: hacky fix for caching after logout... baaaaad
- GODT-1158: cache worker
- GODT-1158: compute body structure lazily
- GODT-1158: cache size in store
- GODT-1158: notify cacher when adding to store
- GODT-1158: 15 second store cache watcher
- GODT-1158: enable cacher
- GODT-1158: better cache worker starting/stopping
- GODT-1158: limit cacher to less concurrency than disk cache
- GODT-1158: message builder prio + pchan pkg
- GODT-1158: fix pchan, use in message builder
- GODT-1158: no sem in cacher (rely on message builder prio)
- GODT-1158: raise priority of existing jobs when requested
- GODT-1158: pending messages in on-disk cache
- GODT-1158: WIP just a note about deleting messages from disk cache
- GODT-1158: pending wait when trying to write
- GODT-1158: pending.add to return bool
- GODT-1225: Headers in bodystructure are stored as bytes.
- GODT-1158: fixing header caching
- GODT-1158: don't cache in background
- GODT-1158: all concurrency set in settings
- GODT-1158: worker pools inside message builder
- GODT-1158: fix linter issues
- GODT-1158: remove completed builds from builder
- GODT-1158: remove builder pool
- GODT-1158: cacher defer job done properly
- GODT-1158: fix linter
- GODT-1299: Continue with bodystructure build if deserialization failed
- GODT-1324: Delete messages from the cache when they are deleted on the server
- GODT-1158: refactor cache tests
- GODT-1158: move builder to app/bridge
- GODT-1306: Migrate cache on disk when location is changed (and delete when disabled)
This commit is contained in:
James Houlahan
2021-07-30 12:20:38 +02:00
committed by Jakub
parent 5cb893fc1b
commit 6bd0739013
79 changed files with 2911 additions and 1387 deletions

View File

@ -230,7 +230,7 @@ integration-test-bridge:
mocks: mocks:
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/users Locator,PanicHandler,CredentialsStorer,StoreMaker > internal/users/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/users Locator,PanicHandler,CredentialsStorer,StoreMaker > internal/users/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/users/mocks/listener_mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/users/mocks/listener_mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser,ChangeNotifier > internal/store/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser,ChangeNotifier,Storer > internal/store/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/pmapi Client,Manager > pkg/pmapi/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/pmapi Client,Manager > pkg/pmapi/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/message Fetcher > pkg/message/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/message Fetcher > pkg/message/mocks/mocks.go
@ -288,7 +288,7 @@ run-nogui-cli: clean-vendor gofiles
PROTONMAIL_ENV=dev go run ${BUILD_FLAGS} cmd/${TARGET_CMD}/main.go ${RUN_FLAGS} -c PROTONMAIL_ENV=dev go run ${BUILD_FLAGS} cmd/${TARGET_CMD}/main.go ${RUN_FLAGS} -c
run-debug: run-debug:
PROTONMAIL_ENV=dev dlv debug --build-flags "${BUILD_FLAGS}" cmd/${TARGET_CMD}/main.go -- ${RUN_FLAGS} PROTONMAIL_ENV=dev dlv debug --build-flags "${BUILD_FLAGS}" cmd/${TARGET_CMD}/main.go -- ${RUN_FLAGS} --noninteractive
run-qml-preview: run-qml-preview:
find internal/frontend/qml/ -iname '*qmlc' | xargs rm -f find internal/frontend/qml/ -iname '*qmlc' | xargs rm -f

1
TODO.md Normal file
View File

@ -0,0 +1 @@
- when cache is full, we need to stop the watcher? don't want to keep downloading messages and throwing them away when we try to cache them.

8
go.mod
View File

@ -32,7 +32,6 @@ require (
github.com/emersion/go-imap-move v0.0.0-20190710073258-6e5a51a5b342 github.com/emersion/go-imap-move v0.0.0-20190710073258-6e5a51a5b342
github.com/emersion/go-imap-quota v0.0.0-20210203125329-619074823f3c github.com/emersion/go-imap-quota v0.0.0-20210203125329-619074823f3c
github.com/emersion/go-imap-unselect v0.0.0-20171113212723-b985794e5f26 github.com/emersion/go-imap-unselect v0.0.0-20171113212723-b985794e5f26
github.com/emersion/go-mbox v1.0.2
github.com/emersion/go-message v0.12.1-0.20201221184100-40c3f864532b github.com/emersion/go-message v0.12.1-0.20201221184100-40c3f864532b
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21
github.com/emersion/go-smtp v0.14.0 github.com/emersion/go-smtp v0.14.0
@ -45,7 +44,6 @@ require (
github.com/golang/mock v1.4.4 github.com/golang/mock v1.4.4
github.com/google/go-cmp v0.5.1 github.com/google/go-cmp v0.5.1
github.com/google/uuid v1.1.1 github.com/google/uuid v1.1.1
github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c // indirect
github.com/hashicorp/go-multierror v1.1.0 github.com/hashicorp/go-multierror v1.1.0
github.com/jaytaylor/html2text v0.0.0-20200412013138-3577fbdbcff7 github.com/jaytaylor/html2text v0.0.0-20200412013138-3577fbdbcff7
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d
@ -55,13 +53,11 @@ require (
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
github.com/olekukonko/tablewriter v0.0.4 // indirect github.com/olekukonko/tablewriter v0.0.4 // indirect
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/ricochet2200/go-disk-usage/du v0.0.0-20210707232629-ac9918953285
github.com/sirupsen/logrus v1.7.0 github.com/sirupsen/logrus v1.7.0
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
github.com/stretchr/objx v0.2.0 // indirect
github.com/stretchr/testify v1.7.0 github.com/stretchr/testify v1.7.0
github.com/therecipe/qt v0.0.0-20200701200531-7f61353ee73e
github.com/therecipe/qt/internal/binding/files/docs/5.12.0 v0.0.0-20200904063919-c0c124a5770d // indirect
github.com/therecipe/qt/internal/binding/files/docs/5.13.0 v0.0.0-20200904063919-c0c124a5770d // indirect
github.com/urfave/cli/v2 v2.2.0 github.com/urfave/cli/v2 v2.2.0
github.com/vmihailenco/msgpack/v5 v5.1.3 github.com/vmihailenco/msgpack/v5 v5.1.3
go.etcd.io/bbolt v1.3.6 go.etcd.io/bbolt v1.3.6

23
go.sum
View File

@ -124,8 +124,6 @@ github.com/emersion/go-imap-quota v0.0.0-20210203125329-619074823f3c h1:khcEdu1y
github.com/emersion/go-imap-quota v0.0.0-20210203125329-619074823f3c/go.mod h1:iApyhIQBiU4XFyr+3kdJyyGqle82TbQyuP2o+OZHrV0= github.com/emersion/go-imap-quota v0.0.0-20210203125329-619074823f3c/go.mod h1:iApyhIQBiU4XFyr+3kdJyyGqle82TbQyuP2o+OZHrV0=
github.com/emersion/go-imap-unselect v0.0.0-20171113212723-b985794e5f26 h1:FiSb8+XBQQSkcX3ubr+1tAtlRJBYaFmRZqOAweZ9Wy8= github.com/emersion/go-imap-unselect v0.0.0-20171113212723-b985794e5f26 h1:FiSb8+XBQQSkcX3ubr+1tAtlRJBYaFmRZqOAweZ9Wy8=
github.com/emersion/go-imap-unselect v0.0.0-20171113212723-b985794e5f26/go.mod h1:+gnnZx3Mg3MnCzZrv0eZdp5puxXQUgGT/6N6L7ShKfM= github.com/emersion/go-imap-unselect v0.0.0-20171113212723-b985794e5f26/go.mod h1:+gnnZx3Mg3MnCzZrv0eZdp5puxXQUgGT/6N6L7ShKfM=
github.com/emersion/go-mbox v1.0.2 h1:tE/rT+lEugK9y0myEymCCHnwlZN04hlXPrbKkxRBA5I=
github.com/emersion/go-mbox v1.0.2/go.mod h1:Yp9IVuuOYLEuMv4yjgDHvhb5mHOcYH6x92Oas3QqEZI=
github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b/go.mod h1:G/dpzLu16WtQpBfQ/z3LYiYJn3ZhKSGWn83fyoyQe/k= github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b/go.mod h1:G/dpzLu16WtQpBfQ/z3LYiYJn3ZhKSGWn83fyoyQe/k=
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 h1:OJyUGMJTzHTd1XQp98QTaHernxMYzRaOasRir9hUlFQ= github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 h1:OJyUGMJTzHTd1XQp98QTaHernxMYzRaOasRir9hUlFQ=
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ= github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
@ -197,9 +195,6 @@ github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gopherjs/gopherjs v0.0.0-20190411002643-bd77b112433e/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c h1:7lF+Vz0LqiRidnzC1Oq86fpX1q/iEv2KJdrCtttYjT4=
github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
@ -269,7 +264,6 @@ github.com/klauspost/compress v1.8.2/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0
github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
@ -351,6 +345,8 @@ github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y8
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/ricochet2200/go-disk-usage/du v0.0.0-20210707232629-ac9918953285 h1:d54EL9l+XteliUfUCGsEwwuk65dmmxX85VXF+9T6+50=
github.com/ricochet2200/go-disk-usage/du v0.0.0-20210707232629-ac9918953285/go.mod h1:fxIDly1xtudczrZeOOlfaUvd2OPb2qZAPuWdU2BsBTk=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/russross/blackfriday v1.5.2 h1:HyvC0ARfnZBqnXwABFeSZHpKvJHJJfPz81GNueLj0oo= github.com/russross/blackfriday v1.5.2 h1:HyvC0ARfnZBqnXwABFeSZHpKvJHJJfPz81GNueLj0oo=
@ -365,13 +361,9 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm
github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM= github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
@ -401,12 +393,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/therecipe/qt v0.0.0-20200701200531-7f61353ee73e h1:G0DQ/TRQyrEZjtLlLwevFjaRiG8eeCMlq9WXQ2OO2bk=
github.com/therecipe/qt v0.0.0-20200701200531-7f61353ee73e/go.mod h1:SUUR2j3aE1z6/g76SdD6NwACEpvCxb3fvG82eKbD6us=
github.com/therecipe/qt/internal/binding/files/docs/5.12.0 v0.0.0-20200904063919-c0c124a5770d h1:hAZyEG2swPRWjF0kqqdGERXUazYnRJdAk4a58f14z7Y=
github.com/therecipe/qt/internal/binding/files/docs/5.12.0 v0.0.0-20200904063919-c0c124a5770d/go.mod h1:7m8PDYDEtEVqfjoUQc2UrFqhG0CDmoVJjRlQxexndFc=
github.com/therecipe/qt/internal/binding/files/docs/5.13.0 v0.0.0-20200904063919-c0c124a5770d h1:AJRoBel/g9cDS+yE8BcN3E+TDD/xNAguG21aoR8DAIE=
github.com/therecipe/qt/internal/binding/files/docs/5.13.0 v0.0.0-20200904063919-c0c124a5770d/go.mod h1:mH55Ek7AZcdns5KPp99O0bg+78el64YCYWHiQKrOdt4=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
@ -444,7 +430,6 @@ golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnf
golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190418165655-df01cb2cc480/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -488,7 +473,6 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190327091125-710a502c58a2/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190327091125-710a502c58a2/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190420063019-afa5a82059c6/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
@ -521,9 +505,7 @@ golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -557,7 +539,6 @@ golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3
golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190327201419-c70d86f8b7cf/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190327201419-c70d86f8b7cf/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=

View File

@ -32,7 +32,9 @@ import (
"github.com/ProtonMail/proton-bridge/internal/frontend/types" "github.com/ProtonMail/proton-bridge/internal/frontend/types"
"github.com/ProtonMail/proton-bridge/internal/imap" "github.com/ProtonMail/proton-bridge/internal/imap"
"github.com/ProtonMail/proton-bridge/internal/smtp" "github.com/ProtonMail/proton-bridge/internal/smtp"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/updater" "github.com/ProtonMail/proton-bridge/internal/updater"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
@ -69,10 +71,21 @@ func New(base *base.Base) *cli.App {
func run(b *base.Base, c *cli.Context) error { // nolint[funlen] func run(b *base.Base, c *cli.Context) error { // nolint[funlen]
tlsConfig, err := loadTLSConfig(b) tlsConfig, err := loadTLSConfig(b)
if err != nil { if err != nil {
logrus.WithError(err).Fatal("Failed to load TLS config") return err
} }
bridge := bridge.New(b.Locations, b.Cache, b.Settings, b.SentryReporter, b.CrashHandler, b.Listener, b.CM, b.Creds, b.Updater, b.Versioner)
imapBackend := imap.NewIMAPBackend(b.CrashHandler, b.Listener, b.Cache, bridge) cache, err := loadCache(b)
if err != nil {
return err
}
builder := message.NewBuilder(
b.Settings.GetInt(settings.FetchWorkers),
b.Settings.GetInt(settings.AttachmentWorkers),
)
bridge := bridge.New(b.Locations, b.Cache, b.Settings, b.SentryReporter, b.CrashHandler, b.Listener, cache, builder, b.CM, b.Creds, b.Updater, b.Versioner)
imapBackend := imap.NewIMAPBackend(b.CrashHandler, b.Listener, b.Cache, b.Settings, bridge)
smtpBackend := smtp.NewSMTPBackend(b.CrashHandler, b.Listener, b.Settings, bridge) smtpBackend := smtp.NewSMTPBackend(b.CrashHandler, b.Listener, b.Settings, bridge)
go func() { go func() {
@ -233,3 +246,35 @@ func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool)
f.NotifySilentUpdateInstalled() f.NotifySilentUpdateInstalled()
} }
// NOTE(GODT-1158): How big should in-memory cache be?
// NOTE(GODT-1158): How to handle cache location migration if user changes custom path?
func loadCache(b *base.Base) (cache.Cache, error) {
if !b.Settings.GetBool(settings.CacheEnabledKey) {
return cache.NewInMemoryCache(100 * (1 << 20)), nil
}
var compressor cache.Compressor
// NOTE(GODT-1158): If user changes compression setting we have to nuke the cache.
if b.Settings.GetBool(settings.CacheCompressionKey) {
compressor = &cache.GZipCompressor{}
} else {
compressor = &cache.NoopCompressor{}
}
var path string
if customPath := b.Settings.Get(settings.CacheLocationKey); customPath != "" {
path = customPath
} else {
path = b.Cache.GetDefaultMessageCacheDir()
}
return cache.NewOnDiskCache(path, compressor, cache.Options{
MinFreeAbs: uint64(b.Settings.GetInt(settings.CacheMinFreeAbsKey)),
MinFreeRat: b.Settings.GetFloat64(settings.CacheMinFreeRatKey),
ConcurrentRead: b.Settings.GetInt(settings.CacheConcurrencyRead),
ConcurrentWrite: b.Settings.GetInt(settings.CacheConcurrencyWrite),
})
}

View File

@ -28,17 +28,17 @@ import (
"github.com/ProtonMail/proton-bridge/internal/constants" "github.com/ProtonMail/proton-bridge/internal/constants"
"github.com/ProtonMail/proton-bridge/internal/metrics" "github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/sentry" "github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/updater" "github.com/ProtonMail/proton-bridge/internal/updater"
"github.com/ProtonMail/proton-bridge/internal/users" "github.com/ProtonMail/proton-bridge/internal/users"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/listener"
logrus "github.com/sirupsen/logrus" logrus "github.com/sirupsen/logrus"
) )
var ( var log = logrus.WithField("pkg", "bridge") //nolint[gochecknoglobals]
log = logrus.WithField("pkg", "bridge") //nolint[gochecknoglobals]
)
type Bridge struct { type Bridge struct {
*users.Users *users.Users
@ -52,11 +52,13 @@ type Bridge struct {
func New( func New(
locations Locator, locations Locator,
cache Cacher, cacheProvider CacheProvider,
s SettingsProvider, setting SettingsProvider,
sentryReporter *sentry.Reporter, sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler, panicHandler users.PanicHandler,
eventListener listener.Listener, eventListener listener.Listener,
cache cache.Cache,
builder *message.Builder,
clientManager pmapi.Manager, clientManager pmapi.Manager,
credStorer users.CredentialsStorer, credStorer users.CredentialsStorer,
updater Updater, updater Updater,
@ -64,7 +66,7 @@ func New(
) *Bridge { ) *Bridge {
// Allow DoH before starting the app if the user has previously set this setting. // Allow DoH before starting the app if the user has previously set this setting.
// This allows us to start even if protonmail is blocked. // This allows us to start even if protonmail is blocked.
if s.GetBool(settings.AllowProxyKey) { if setting.GetBool(settings.AllowProxyKey) {
clientManager.AllowProxy() clientManager.AllowProxy()
} }
@ -74,25 +76,25 @@ func New(
eventListener, eventListener,
clientManager, clientManager,
credStorer, credStorer,
newStoreFactory(cache, sentryReporter, panicHandler, eventListener), newStoreFactory(cacheProvider, sentryReporter, panicHandler, eventListener, cache, builder),
) )
b := &Bridge{ b := &Bridge{
Users: u, Users: u,
locations: locations, locations: locations,
settings: s, settings: setting,
clientManager: clientManager, clientManager: clientManager,
updater: updater, updater: updater,
versioner: versioner, versioner: versioner,
} }
if s.GetBool(settings.FirstStartKey) { if setting.GetBool(settings.FirstStartKey) {
if err := b.SendMetric(metrics.New(metrics.Setup, metrics.FirstStart, metrics.Label(constants.Version))); err != nil { if err := b.SendMetric(metrics.New(metrics.Setup, metrics.FirstStart, metrics.Label(constants.Version))); err != nil {
logrus.WithError(err).Error("Failed to send metric") logrus.WithError(err).Error("Failed to send metric")
} }
s.SetBool(settings.FirstStartKey, false) setting.SetBool(settings.FirstStartKey, false)
} }
go b.heartbeat() go b.heartbeat()

View File

@ -23,47 +23,65 @@ import (
"github.com/ProtonMail/proton-bridge/internal/sentry" "github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store" "github.com/ProtonMail/proton-bridge/internal/store"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/users" "github.com/ProtonMail/proton-bridge/internal/users"
"github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/message"
) )
type storeFactory struct { type storeFactory struct {
cache Cacher cacheProvider CacheProvider
sentryReporter *sentry.Reporter sentryReporter *sentry.Reporter
panicHandler users.PanicHandler panicHandler users.PanicHandler
eventListener listener.Listener eventListener listener.Listener
storeCache *store.Cache events *store.Events
cache cache.Cache
builder *message.Builder
} }
func newStoreFactory( func newStoreFactory(
cache Cacher, cacheProvider CacheProvider,
sentryReporter *sentry.Reporter, sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler, panicHandler users.PanicHandler,
eventListener listener.Listener, eventListener listener.Listener,
cache cache.Cache,
builder *message.Builder,
) *storeFactory { ) *storeFactory {
return &storeFactory{ return &storeFactory{
cache: cache, cacheProvider: cacheProvider,
sentryReporter: sentryReporter, sentryReporter: sentryReporter,
panicHandler: panicHandler, panicHandler: panicHandler,
eventListener: eventListener, eventListener: eventListener,
storeCache: store.NewCache(cache.GetIMAPCachePath()), events: store.NewEvents(cacheProvider.GetIMAPCachePath()),
cache: cache,
builder: builder,
} }
} }
// New creates new store for given user. // New creates new store for given user.
func (f *storeFactory) New(user store.BridgeUser) (*store.Store, error) { func (f *storeFactory) New(user store.BridgeUser) (*store.Store, error) {
storePath := getUserStorePath(f.cache.GetDBDir(), user.ID()) return store.New(
return store.New(f.sentryReporter, f.panicHandler, user, f.eventListener, storePath, f.storeCache) f.sentryReporter,
f.panicHandler,
user,
f.eventListener,
f.cache,
f.builder,
getUserStorePath(f.cacheProvider.GetDBDir(), user.ID()),
f.events,
)
} }
// Remove removes all store files for given user. // Remove removes all store files for given user.
func (f *storeFactory) Remove(userID string) error { func (f *storeFactory) Remove(userID string) error {
storePath := getUserStorePath(f.cache.GetDBDir(), userID) return store.RemoveStore(
return store.RemoveStore(f.storeCache, storePath, userID) f.events,
getUserStorePath(f.cacheProvider.GetDBDir(), userID),
userID,
)
} }
// getUserStorePath returns the file path of the store database for the given userID. // getUserStorePath returns the file path of the store database for the given userID.
func getUserStorePath(storeDir string, userID string) (path string) { func getUserStorePath(storeDir string, userID string) (path string) {
fileName := fmt.Sprintf("mailbox-%v.db", userID) return filepath.Join(storeDir, fmt.Sprintf("mailbox-%v.db", userID))
return filepath.Join(storeDir, fileName)
} }

View File

@ -28,7 +28,7 @@ type Locator interface {
ClearUpdates() error ClearUpdates() error
} }
type Cacher interface { type CacheProvider interface {
GetIMAPCachePath() string GetIMAPCachePath() string
GetDBDir() string GetDBDir() string
} }
@ -38,6 +38,7 @@ type SettingsProvider interface {
Set(key string, value string) Set(key string, value string)
GetBool(key string) bool GetBool(key string) bool
SetBool(key string, val bool) SetBool(key string, val bool)
GetInt(key string) int
} }
type Updater interface { type Updater interface {

View File

@ -45,6 +45,11 @@ func (c *Cache) GetDBDir() string {
return c.getCurrentCacheDir() return c.getCurrentCacheDir()
} }
// GetDefaultMessageCacheDir returns folder for cached messages files.
func (c *Cache) GetDefaultMessageCacheDir() string {
return filepath.Join(c.getCurrentCacheDir(), "messages")
}
// GetIMAPCachePath returns path to file with IMAP status. // GetIMAPCachePath returns path to file with IMAP status.
func (c *Cache) GetIMAPCachePath() string { func (c *Cache) GetIMAPCachePath() string {
return filepath.Join(c.getCurrentCacheDir(), "user_info.json") return filepath.Join(c.getCurrentCacheDir(), "user_info.json")

View File

@ -100,18 +100,28 @@ func (p *keyValueStore) GetBool(key string) bool {
} }
func (p *keyValueStore) GetInt(key string) int { func (p *keyValueStore) GetInt(key string) int {
if p.Get(key) == "" {
return 0
}
value, err := strconv.Atoi(p.Get(key)) value, err := strconv.Atoi(p.Get(key))
if err != nil { if err != nil {
logrus.WithError(err).Error("Cannot parse int") logrus.WithError(err).Error("Cannot parse int")
} }
return value return value
} }
func (p *keyValueStore) GetFloat64(key string) float64 { func (p *keyValueStore) GetFloat64(key string) float64 {
if p.Get(key) == "" {
return 0
}
value, err := strconv.ParseFloat(p.Get(key), 64) value, err := strconv.ParseFloat(p.Get(key), 64)
if err != nil { if err != nil {
logrus.WithError(err).Error("Cannot parse float64") logrus.WithError(err).Error("Cannot parse float64")
} }
return value return value
} }

View File

@ -43,6 +43,16 @@ const (
UpdateChannelKey = "update_channel" UpdateChannelKey = "update_channel"
RolloutKey = "rollout" RolloutKey = "rollout"
PreferredKeychainKey = "preferred_keychain" PreferredKeychainKey = "preferred_keychain"
CacheEnabledKey = "cache_enabled"
CacheCompressionKey = "cache_compression"
CacheLocationKey = "cache_location"
CacheMinFreeAbsKey = "cache_min_free_abs"
CacheMinFreeRatKey = "cache_min_free_rat"
CacheConcurrencyRead = "cache_concurrent_read"
CacheConcurrencyWrite = "cache_concurrent_write"
IMAPWorkers = "imap_workers"
FetchWorkers = "fetch_workers"
AttachmentWorkers = "attachment_workers"
) )
type Settings struct { type Settings struct {
@ -80,6 +90,16 @@ func (s *Settings) setDefaultValues() {
s.setDefault(UpdateChannelKey, "") s.setDefault(UpdateChannelKey, "")
s.setDefault(RolloutKey, fmt.Sprintf("%v", rand.Float64())) //nolint[gosec] G404 It is OK to use weak random number generator here s.setDefault(RolloutKey, fmt.Sprintf("%v", rand.Float64())) //nolint[gosec] G404 It is OK to use weak random number generator here
s.setDefault(PreferredKeychainKey, "") s.setDefault(PreferredKeychainKey, "")
s.setDefault(CacheEnabledKey, "true")
s.setDefault(CacheCompressionKey, "true")
s.setDefault(CacheLocationKey, "")
s.setDefault(CacheMinFreeAbsKey, "250000000")
s.setDefault(CacheMinFreeRatKey, "")
s.setDefault(CacheConcurrencyRead, "16")
s.setDefault(CacheConcurrencyWrite, "16")
s.setDefault(IMAPWorkers, "16")
s.setDefault(FetchWorkers, "16")
s.setDefault(AttachmentWorkers, "16")
s.setDefault(APIPortKey, DefaultAPIPort) s.setDefault(APIPortKey, DefaultAPIPort)
s.setDefault(IMAPPortKey, DefaultIMAPPort) s.setDefault(IMAPPortKey, DefaultIMAPPort)

View File

@ -128,6 +128,24 @@ func New( //nolint[funlen]
}) })
fe.AddCmd(dohCmd) fe.AddCmd(dohCmd)
// Cache-On-Disk commands.
codCmd := &ishell.Cmd{Name: "local-cache",
Help: "manage the local encrypted message cache",
}
codCmd.AddCmd(&ishell.Cmd{Name: "enable",
Help: "enable the local cache",
Func: fe.enableCacheOnDisk,
})
codCmd.AddCmd(&ishell.Cmd{Name: "disable",
Help: "disable the local cache",
Func: fe.disableCacheOnDisk,
})
codCmd.AddCmd(&ishell.Cmd{Name: "change-location",
Help: "change the location of the local cache",
Func: fe.setCacheOnDiskLocation,
})
fe.AddCmd(codCmd)
// Updates commands. // Updates commands.
updatesCmd := &ishell.Cmd{Name: "updates", updatesCmd := &ishell.Cmd{Name: "updates",
Help: "manage bridge updates", Help: "manage bridge updates",

View File

@ -19,6 +19,7 @@ package cli
import ( import (
"fmt" "fmt"
"os"
"strconv" "strconv"
"strings" "strings"
@ -155,6 +156,67 @@ func (f *frontendCLI) disallowProxy(c *ishell.Context) {
} }
} }
func (f *frontendCLI) enableCacheOnDisk(c *ishell.Context) {
if f.settings.GetBool(settings.CacheEnabledKey) {
f.Println("The local cache is already enabled.")
return
}
if f.yesNoQuestion("Are you sure you want to enable the local cache") {
// Set this back to the default location before enabling.
f.settings.Set(settings.CacheLocationKey, "")
if err := f.bridge.EnableCache(); err != nil {
f.Println("The local cache could not be enabled.")
return
}
f.settings.SetBool(settings.CacheEnabledKey, true)
f.restarter.SetToRestart()
f.Stop()
}
}
func (f *frontendCLI) disableCacheOnDisk(c *ishell.Context) {
if !f.settings.GetBool(settings.CacheEnabledKey) {
f.Println("The local cache is already disabled.")
return
}
if f.yesNoQuestion("Are you sure you want to disable the local cache") {
if err := f.bridge.DisableCache(); err != nil {
f.Println("The local cache could not be disabled.")
return
}
f.settings.SetBool(settings.CacheEnabledKey, false)
f.restarter.SetToRestart()
f.Stop()
}
}
func (f *frontendCLI) setCacheOnDiskLocation(c *ishell.Context) {
if !f.settings.GetBool(settings.CacheEnabledKey) {
f.Println("The local cache must be enabled.")
return
}
if location := f.settings.Get(settings.CacheLocationKey); location != "" {
f.Println("The current local cache location is:", location)
}
if location := f.readStringInAttempts("Enter a new location for the cache", c.ReadLine, f.isCacheLocationUsable); location != "" {
if err := f.bridge.MigrateCache(f.settings.Get(settings.CacheLocationKey), location); err != nil {
f.Println("The local cache location could not be changed.")
return
}
f.settings.Set(settings.CacheLocationKey, location)
f.restarter.SetToRestart()
f.Stop()
}
}
func (f *frontendCLI) isPortFree(port string) bool { func (f *frontendCLI) isPortFree(port string) bool {
port = strings.ReplaceAll(port, ":", "") port = strings.ReplaceAll(port, ":", "")
if port == "" || port == currentPort { if port == "" || port == currentPort {
@ -171,3 +233,13 @@ func (f *frontendCLI) isPortFree(port string) bool {
} }
return true return true
} }
// NOTE(GODT-1158): Check free space in location.
func (f *frontendCLI) isCacheLocationUsable(location string) bool {
stat, err := os.Stat(location)
if err != nil {
return false
}
return stat.IsDir()
}

View File

@ -77,6 +77,9 @@ type Bridger interface {
ReportBug(osType, osVersion, description, accountName, address, emailClient string) error ReportBug(osType, osVersion, description, accountName, address, emailClient string) error
AllowProxy() AllowProxy()
DisallowProxy() DisallowProxy()
EnableCache() error
DisableCache() error
MigrateCache(from, to string) error
GetUpdateChannel() updater.UpdateChannel GetUpdateChannel() updater.UpdateChannel
SetUpdateChannel(updater.UpdateChannel) (needRestart bool, err error) SetUpdateChannel(updater.UpdateChannel) (needRestart bool, err error)
GetKeychainApp() string GetKeychainApp() string

View File

@ -37,21 +37,13 @@ import (
"time" "time"
"github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/config/settings"
"github.com/ProtonMail/proton-bridge/internal/events" "github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/emersion/go-imap" "github.com/emersion/go-imap"
goIMAPBackend "github.com/emersion/go-imap/backend" goIMAPBackend "github.com/emersion/go-imap/backend"
) )
const (
// NOTE: Each fetch worker has its own set of attach workers so there can be up to 20*5=100 API requests at once.
// This is a reasonable limit to not overwhelm API while still maintaining as much parallelism as possible.
fetchWorkers = 20 // In how many workers to fetch message (group list on IMAP).
attachWorkers = 5 // In how many workers to fetch attachments (for one message).
buildWorkers = 20 // In how many workers to build messages.
)
type panicHandler interface { type panicHandler interface {
HandlePanic() HandlePanic()
} }
@ -61,26 +53,32 @@ type imapBackend struct {
bridge bridger bridge bridger
updates *imapUpdates updates *imapUpdates
eventListener listener.Listener eventListener listener.Listener
listWorkers int
users map[string]*imapUser users map[string]*imapUser
usersLocker sync.Locker usersLocker sync.Locker
builder *message.Builder
imapCache map[string]map[string]string imapCache map[string]map[string]string
imapCachePath string imapCachePath string
imapCacheLock *sync.RWMutex imapCacheLock *sync.RWMutex
} }
type settingsProvider interface {
GetInt(string) int
}
// NewIMAPBackend returns struct implementing go-imap/backend interface. // NewIMAPBackend returns struct implementing go-imap/backend interface.
func NewIMAPBackend( func NewIMAPBackend(
panicHandler panicHandler, panicHandler panicHandler,
eventListener listener.Listener, eventListener listener.Listener,
cache cacheProvider, cache cacheProvider,
setting settingsProvider,
bridge *bridge.Bridge, bridge *bridge.Bridge,
) *imapBackend { //nolint[golint] ) *imapBackend { //nolint[golint]
bridgeWrap := newBridgeWrap(bridge) bridgeWrap := newBridgeWrap(bridge)
backend := newIMAPBackend(panicHandler, cache, bridgeWrap, eventListener)
imapWorkers := setting.GetInt(settings.IMAPWorkers)
backend := newIMAPBackend(panicHandler, cache, bridgeWrap, eventListener, imapWorkers)
go backend.monitorDisconnectedUsers() go backend.monitorDisconnectedUsers()
@ -92,6 +90,7 @@ func newIMAPBackend(
cache cacheProvider, cache cacheProvider,
bridge bridger, bridge bridger,
eventListener listener.Listener, eventListener listener.Listener,
listWorkers int,
) *imapBackend { ) *imapBackend {
return &imapBackend{ return &imapBackend{
panicHandler: panicHandler, panicHandler: panicHandler,
@ -102,10 +101,9 @@ func newIMAPBackend(
users: map[string]*imapUser{}, users: map[string]*imapUser{},
usersLocker: &sync.Mutex{}, usersLocker: &sync.Mutex{},
builder: message.NewBuilder(fetchWorkers, attachWorkers, buildWorkers),
imapCachePath: cache.GetIMAPCachePath(), imapCachePath: cache.GetIMAPCachePath(),
imapCacheLock: &sync.RWMutex{}, imapCacheLock: &sync.RWMutex{},
listWorkers: listWorkers,
} }
} }

View File

@ -1,151 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"bytes"
"sort"
"sync"
"time"
pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message"
)
type key struct {
ID string
Timestamp int64
Size int
}
type oldestFirst []key
func (s oldestFirst) Len() int { return len(s) }
func (s oldestFirst) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s oldestFirst) Less(i, j int) bool { return s[i].Timestamp < s[j].Timestamp }
type cachedMessage struct {
key
data []byte
structure pkgMsg.BodyStructure
}
//nolint[gochecknoglobals]
var (
cacheTimeLimit = int64(1 * 60 * 60 * 1000) // milliseconds
cacheSizeLimit = 100 * 1000 * 1000 // B - MUST be larger than email max size limit (~ 25 MB)
mailCache = make(map[string]cachedMessage)
// cacheMutex takes care of one single operation, whereas buildMutex takes
// care of the whole action doing multiple operations. buildMutex will protect
// you from asking server or decrypting or building the same message more
// than once. When first request to build the message comes, it will block
// all other build requests. When the first one is done, all others are
// handled by cache, not doing anything twice. With cacheMutex we are safe
// only to not mess up with the cache, but we could end up downloading and
// building message twice.
cacheMutex = &sync.Mutex{}
buildMutex = &sync.Mutex{}
buildLocks = map[string]interface{}{}
)
func (m *cachedMessage) isValidOrDel() bool {
if m.key.Timestamp+cacheTimeLimit < timestamp() {
delete(mailCache, m.key.ID)
return false
}
return true
}
func timestamp() int64 {
return time.Now().UnixNano() / int64(time.Millisecond)
}
func Clear() {
mailCache = make(map[string]cachedMessage)
}
// BuildLock locks per message level, not on global level.
// Multiple different messages can be building at once.
func BuildLock(messageID string) {
for {
buildMutex.Lock()
if _, ok := buildLocks[messageID]; ok { // if locked, wait
buildMutex.Unlock()
time.Sleep(10 * time.Millisecond)
} else { // if unlocked, lock it
buildLocks[messageID] = struct{}{}
buildMutex.Unlock()
return
}
}
}
func BuildUnlock(messageID string) {
buildMutex.Lock()
defer buildMutex.Unlock()
delete(buildLocks, messageID)
}
func LoadMail(mID string) (reader *bytes.Reader, structure *pkgMsg.BodyStructure) {
reader = &bytes.Reader{}
cacheMutex.Lock()
defer cacheMutex.Unlock()
if message, ok := mailCache[mID]; ok && message.isValidOrDel() {
reader = bytes.NewReader(message.data)
structure = &message.structure
// Update timestamp to keep emails which are used often.
message.Timestamp = timestamp()
}
return
}
func SaveMail(mID string, msg []byte, structure *pkgMsg.BodyStructure) {
cacheMutex.Lock()
defer cacheMutex.Unlock()
newMessage := cachedMessage{
key: key{
ID: mID,
Timestamp: timestamp(),
Size: len(msg),
},
data: msg,
structure: *structure,
}
// Remove old and reduce size.
totalSize := 0
messageList := []key{}
for _, message := range mailCache {
if message.isValidOrDel() {
messageList = append(messageList, message.key)
totalSize += message.key.Size
}
}
sort.Sort(oldestFirst(messageList))
var oldest key
for totalSize+newMessage.key.Size >= cacheSizeLimit {
oldest, messageList = messageList[0], messageList[1:]
delete(mailCache, oldest.ID)
totalSize -= oldest.Size
}
// Write new.
mailCache[mID] = newMessage
}

View File

@ -1,98 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"fmt"
"testing"
"time"
pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/stretchr/testify/require"
)
var bs = &pkgMsg.BodyStructure{} //nolint[gochecknoglobals]
const testUID = "testmsg"
func TestSaveAndLoad(t *testing.T) {
msg := []byte("Test message")
SaveMail(testUID, msg, bs)
require.Equal(t, mailCache[testUID].data, msg)
reader, _ := LoadMail(testUID)
require.Equal(t, reader.Len(), len(msg))
stored := make([]byte, len(msg))
_, _ = reader.Read(stored)
require.Equal(t, stored, msg)
}
func TestMissing(t *testing.T) {
reader, _ := LoadMail("non-existing")
require.Equal(t, reader.Len(), 0)
}
func TestClearOld(t *testing.T) {
cacheTimeLimit = 10
msg := []byte("Test message")
SaveMail(testUID, msg, bs)
time.Sleep(100 * time.Millisecond)
reader, _ := LoadMail(testUID)
require.Equal(t, reader.Len(), 0)
}
func TestClearBig(t *testing.T) {
r := require.New(t)
wantMessage := []byte("Test message")
wantCacheSize := 3
nTestMessages := wantCacheSize * wantCacheSize
cacheSizeLimit = wantCacheSize*len(wantMessage) + 1
cacheTimeLimit = int64(1 << 20) // be sure the message will survive
// It should never have more than nSize items.
for i := 0; i < nTestMessages; i++ {
time.Sleep(1 * time.Millisecond)
SaveMail(fmt.Sprintf("%s%d", testUID, i), wantMessage, bs)
r.LessOrEqual(len(mailCache), wantCacheSize, "cache too big when %d", i)
}
// Check that the oldest are deleted first.
for i := 0; i < nTestMessages; i++ {
iUID := fmt.Sprintf("%s%d", testUID, i)
reader, _ := LoadMail(iUID)
mail := mailCache[iUID]
if i < (nTestMessages - wantCacheSize) {
r.Zero(reader.Len(), "LoadMail should return empty, but have %s for %s time %d ", string(mail.data), iUID, mail.key.Timestamp)
} else {
stored := make([]byte, len(wantMessage))
_, err := reader.Read(stored)
r.NoError(err)
r.Equal(wantMessage, stored, "LoadMail returned wrong message: %s for %s time %d", stored, iUID, mail.key.Timestamp)
}
}
}
func TestConcurency(t *testing.T) {
msg := []byte("Test message")
for i := 0; i < 10; i++ {
go SaveMail(fmt.Sprintf("%s%d", testUID, i), msg, bs)
}
}

View File

@ -37,12 +37,10 @@ type imapMailbox struct {
storeUser storeUserProvider storeUser storeUserProvider
storeAddress storeAddressProvider storeAddress storeAddressProvider
storeMailbox storeMailboxProvider storeMailbox storeMailboxProvider
builder *message.Builder
} }
// newIMAPMailbox returns struct implementing go-imap/mailbox interface. // newIMAPMailbox returns struct implementing go-imap/mailbox interface.
func newIMAPMailbox(panicHandler panicHandler, user *imapUser, storeMailbox storeMailboxProvider, builder *message.Builder) *imapMailbox { func newIMAPMailbox(panicHandler panicHandler, user *imapUser, storeMailbox storeMailboxProvider) *imapMailbox {
return &imapMailbox{ return &imapMailbox{
panicHandler: panicHandler, panicHandler: panicHandler,
user: user, user: user,
@ -56,8 +54,6 @@ func newIMAPMailbox(panicHandler panicHandler, user *imapUser, storeMailbox stor
storeUser: user.storeUser, storeUser: user.storeUser,
storeAddress: user.storeAddress, storeAddress: user.storeAddress,
storeMailbox: storeMailbox, storeMailbox: storeMailbox,
builder: builder,
} }
} }

View File

@ -19,21 +19,13 @@ package imap
import ( import (
"bytes" "bytes"
"context"
"github.com/ProtonMail/proton-bridge/internal/imap/cache"
"github.com/ProtonMail/proton-bridge/pkg/message" "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/emersion/go-imap" "github.com/emersion/go-imap"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus"
) )
func (im *imapMailbox) getMessage( func (im *imapMailbox) getMessage(storeMessage storeMessageProvider, items []imap.FetchItem) (msg *imap.Message, err error) {
storeMessage storeMessageProvider,
items []imap.FetchItem,
msgBuildCountHistogram *msgBuildCountHistogram,
) (msg *imap.Message, err error) {
msglog := im.log.WithField("msgID", storeMessage.ID()) msglog := im.log.WithField("msgID", storeMessage.ID())
msglog.Trace("Getting message") msglog.Trace("Getting message")
@ -69,9 +61,12 @@ func (im *imapMailbox) getMessage(
// There is no point having message older than RFC itself, it's not possible. // There is no point having message older than RFC itself, it's not possible.
msg.InternalDate = message.SanitizeMessageDate(m.Time) msg.InternalDate = message.SanitizeMessageDate(m.Time)
case imap.FetchRFC822Size: case imap.FetchRFC822Size:
if msg.Size, err = im.getSize(storeMessage); err != nil { size, err := storeMessage.GetRFC822Size()
if err != nil {
return nil, err return nil, err
} }
msg.Size = size
case imap.FetchUid: case imap.FetchUid:
if msg.Uid, err = storeMessage.UID(); err != nil { if msg.Uid, err = storeMessage.UID(); err != nil {
return nil, err return nil, err
@ -79,7 +74,7 @@ func (im *imapMailbox) getMessage(
case imap.FetchAll, imap.FetchFast, imap.FetchFull, imap.FetchRFC822, imap.FetchRFC822Header, imap.FetchRFC822Text: case imap.FetchAll, imap.FetchFast, imap.FetchFull, imap.FetchRFC822, imap.FetchRFC822Header, imap.FetchRFC822Text:
fallthrough // this is list of defined items by go-imap, but items can be also sections generated from requests fallthrough // this is list of defined items by go-imap, but items can be also sections generated from requests
default: default:
if err = im.getLiteralForSection(item, msg, storeMessage, msgBuildCountHistogram); err != nil { if err = im.getLiteralForSection(item, msg, storeMessage); err != nil {
return return
} }
} }
@ -88,35 +83,7 @@ func (im *imapMailbox) getMessage(
return msg, err return msg, err
} }
// getSize returns cached size or it will build the message, save the size in func (im *imapMailbox) getLiteralForSection(itemSection imap.FetchItem, msg *imap.Message, storeMessage storeMessageProvider) error {
// DB and then returns the size after build.
//
// We are storing size in DB as part of pmapi messages metada. The size
// attribute on the server represents size of encrypted body. The value is
// cleared in Bridge and the final decrypted size (including header, attachment
// and MIME structure) is computed after building the message.
func (im *imapMailbox) getSize(storeMessage storeMessageProvider) (uint32, error) {
m := storeMessage.Message()
if m.Size <= 0 {
im.log.WithField("msgID", m.ID).Debug("Size unknown - downloading body")
// We are sure the size is not a problem right now. Clients
// might not first check sizes of all messages so we couldn't
// be sure if seeing 1st or 2nd sync is all right or not.
// Therefore, it's better to exclude getting size from the
// counting and see build count as real message build.
if _, _, err := im.getBodyAndStructure(storeMessage, nil); err != nil {
return 0, err
}
}
return uint32(m.Size), nil
}
func (im *imapMailbox) getLiteralForSection(
itemSection imap.FetchItem,
msg *imap.Message,
storeMessage storeMessageProvider,
msgBuildCountHistogram *msgBuildCountHistogram,
) error {
section, err := imap.ParseBodySectionName(itemSection) section, err := imap.ParseBodySectionName(itemSection)
if err != nil { if err != nil {
log.WithError(err).Warn("Failed to parse body section name; part will be skipped") log.WithError(err).Warn("Failed to parse body section name; part will be skipped")
@ -124,7 +91,7 @@ func (im *imapMailbox) getLiteralForSection(
} }
var literal imap.Literal var literal imap.Literal
if literal, err = im.getMessageBodySection(storeMessage, section, msgBuildCountHistogram); err != nil { if literal, err = im.getMessageBodySection(storeMessage, section); err != nil {
return err return err
} }
@ -149,88 +116,25 @@ func (im *imapMailbox) getBodyStructure(storeMessage storeMessageProvider) (bs *
// be sure if seeing 1st or 2nd sync is all right or not. // be sure if seeing 1st or 2nd sync is all right or not.
// Therefore, it's better to exclude first body structure fetch // Therefore, it's better to exclude first body structure fetch
// from the counting and see build count as real message build. // from the counting and see build count as real message build.
if bs, _, err = im.getBodyAndStructure(storeMessage, nil); err != nil { if bs, _, err = im.getBodyAndStructure(storeMessage); err != nil {
return return
} }
} }
return return
} }
func (im *imapMailbox) getBodyAndStructure( func (im *imapMailbox) getBodyAndStructure(storeMessage storeMessageProvider) (*message.BodyStructure, *bytes.Reader, error) {
storeMessage storeMessageProvider, msgBuildCountHistogram *msgBuildCountHistogram, rfc822, err := storeMessage.GetRFC822()
) ( if err != nil {
structure *message.BodyStructure, bodyReader *bytes.Reader, err error, return nil, nil, err
) {
m := storeMessage.Message()
id := im.storeUser.UserID() + m.ID
cache.BuildLock(id)
defer cache.BuildUnlock(id)
bodyReader, structure = cache.LoadMail(id)
// return the message which was found in cache
if bodyReader.Len() != 0 && structure != nil {
return structure, bodyReader, nil
} }
structure, body, err := im.buildMessage(m) structure, err := storeMessage.GetBodyStructure()
bodyReader = bytes.NewReader(body) if err != nil {
size := int64(len(body)) return nil, nil, err
l := im.log.WithField("newSize", size).WithField("msgID", m.ID)
if err != nil || structure == nil || size == 0 {
l.WithField("hasStructure", structure != nil).Warn("Failed to build message")
return structure, bodyReader, err
} }
// Save the size, body structure and header even for messages which return structure, bytes.NewReader(rfc822), nil
// were unable to decrypt. Hence they doesn't have to be computed every
// time.
m.Size = size
cacheMessageInStore(storeMessage, structure, body, l)
if msgBuildCountHistogram != nil {
times, errCount := storeMessage.IncreaseBuildCount()
if errCount != nil {
l.WithError(errCount).Warn("Cannot increase build count")
}
msgBuildCountHistogram.add(times)
}
// Drafts can change therefore we don't want to cache them.
if !isMessageInDraftFolder(m) {
cache.SaveMail(id, body, structure)
}
return structure, bodyReader, err
}
func cacheMessageInStore(storeMessage storeMessageProvider, structure *message.BodyStructure, body []byte, l *logrus.Entry) {
m := storeMessage.Message()
if errSize := storeMessage.SetSize(m.Size); errSize != nil {
l.WithError(errSize).Warn("Cannot update size while building")
}
if structure != nil && !isMessageInDraftFolder(m) {
if errStruct := storeMessage.SetBodyStructure(structure); errStruct != nil {
l.WithError(errStruct).Warn("Cannot update bodystructure while building")
}
}
header, errHead := structure.GetMailHeaderBytes(bytes.NewReader(body))
if errHead == nil && len(header) != 0 {
if errStore := storeMessage.SetHeader(header); errStore != nil {
l.WithError(errStore).Warn("Cannot update header in store")
}
} else {
l.WithError(errHead).Warn("Cannot get header bytes from structure")
}
}
func isMessageInDraftFolder(m *pmapi.Message) bool {
for _, labelID := range m.LabelIDs {
if labelID == pmapi.DraftLabel {
return true
}
}
return false
} }
// This will download message (or read from cache) and pick up the section, // This will download message (or read from cache) and pick up the section,
@ -246,11 +150,7 @@ func isMessageInDraftFolder(m *pmapi.Message) bool {
// For all other cases it is necessary to download and decrypt the message // For all other cases it is necessary to download and decrypt the message
// and drop the header which was obtained from cache. The header will // and drop the header which was obtained from cache. The header will
// will be stored in DB once successfully built. Check `getBodyAndStructure`. // will be stored in DB once successfully built. Check `getBodyAndStructure`.
func (im *imapMailbox) getMessageBodySection( func (im *imapMailbox) getMessageBodySection(storeMessage storeMessageProvider, section *imap.BodySectionName) (imap.Literal, error) {
storeMessage storeMessageProvider,
section *imap.BodySectionName,
msgBuildCountHistogram *msgBuildCountHistogram,
) (imap.Literal, error) {
var header []byte var header []byte
var response []byte var response []byte
@ -260,7 +160,7 @@ func (im *imapMailbox) getMessageBodySection(
if isMainHeaderRequested && storeMessage.IsFullHeaderCached() { if isMainHeaderRequested && storeMessage.IsFullHeaderCached() {
header = storeMessage.GetHeader() header = storeMessage.GetHeader()
} else { } else {
structure, bodyReader, err := im.getBodyAndStructure(storeMessage, msgBuildCountHistogram) structure, bodyReader, err := im.getBodyAndStructure(storeMessage)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -276,7 +176,7 @@ func (im *imapMailbox) getMessageBodySection(
case section.Specifier == imap.MIMESpecifier: // The MIME part specifier refers to the [MIME-IMB] header for this part. case section.Specifier == imap.MIMESpecifier: // The MIME part specifier refers to the [MIME-IMB] header for this part.
fallthrough fallthrough
case section.Specifier == imap.HeaderSpecifier: case section.Specifier == imap.HeaderSpecifier:
header, err = structure.GetSectionHeaderBytes(bodyReader, section.Path) header, err = structure.GetSectionHeaderBytes(section.Path)
default: default:
err = errors.New("Unknown specifier " + string(section.Specifier)) err = errors.New("Unknown specifier " + string(section.Specifier))
} }
@ -293,30 +193,3 @@ func (im *imapMailbox) getMessageBodySection(
// Trim any output if requested. // Trim any output if requested.
return bytes.NewBuffer(section.ExtractPartial(response)), nil return bytes.NewBuffer(section.ExtractPartial(response)), nil
} }
// buildMessage from PM to IMAP.
func (im *imapMailbox) buildMessage(m *pmapi.Message) (*message.BodyStructure, []byte, error) {
body, err := im.builder.NewJobWithOptions(
context.Background(),
im.user.client(),
m.ID,
message.JobOptions{
IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead.
SanitizeDate: true, // Whether to replace all dates before 1970 with RFC822's birthdate.
AddInternalID: true, // Whether to include MessageID as X-Pm-Internal-Id.
AddExternalID: true, // Whether to include ExternalID as X-Pm-External-Id.
AddMessageDate: true, // Whether to include message time as X-Pm-Date.
AddMessageIDReference: true, // Whether to include the MessageID in References.
},
).GetResult()
if err != nil {
return nil, nil, err
}
structure, err := message.NewBodyStructure(bytes.NewReader(body))
if err != nil {
return nil, nil, err
}
return structure, body, nil
}

View File

@ -479,11 +479,16 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria)
} }
// Filter by size (only if size was already calculated). // Filter by size (only if size was already calculated).
if m.Size > 0 { size, err := storeMessage.GetRFC822Size()
if criteria.Larger != 0 && m.Size <= int64(criteria.Larger) { if err != nil {
return nil, err
}
if size > 0 {
if criteria.Larger != 0 && int64(size) <= int64(criteria.Larger) {
continue continue
} }
if criteria.Smaller != 0 && m.Size >= int64(criteria.Smaller) { if criteria.Smaller != 0 && int64(size) >= int64(criteria.Smaller) {
continue continue
} }
} }
@ -513,13 +518,12 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria)
// //
// Messages must be sent to msgResponse. When the function returns, msgResponse must be closed. // Messages must be sent to msgResponse. When the function returns, msgResponse must be closed.
func (im *imapMailbox) ListMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message) error { func (im *imapMailbox) ListMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message) error {
msgBuildCountHistogram := newMsgBuildCountHistogram()
return im.logCommand(func() error { return im.logCommand(func() error {
return im.listMessages(isUID, seqSet, items, msgResponse, msgBuildCountHistogram) return im.listMessages(isUID, seqSet, items, msgResponse)
}, "FETCH", isUID, seqSet, items, msgBuildCountHistogram) }, "FETCH", isUID, seqSet, items)
} }
func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message, msgBuildCountHistogram *msgBuildCountHistogram) (err error) { //nolint[funlen] func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message) (err error) { //nolint[funlen]
defer func() { defer func() {
close(msgResponse) close(msgResponse)
if err != nil { if err != nil {
@ -564,7 +568,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima
return nil, err return nil, err
} }
msg, err := im.getMessage(storeMessage, items, msgBuildCountHistogram) msg, err := im.getMessage(storeMessage, items)
if err != nil { if err != nil {
err = fmt.Errorf("list message build: %v", err) err = fmt.Errorf("list message build: %v", err)
l.WithField("metaID", storeMessage.ID()).Error(err) l.WithField("metaID", storeMessage.ID()).Error(err)
@ -594,7 +598,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima
return nil return nil
} }
err = parallel.RunParallel(fetchWorkers, input, processCallback, collectCallback) err = parallel.RunParallel(im.user.backend.listWorkers, input, processCallback, collectCallback)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,65 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imap
import (
"fmt"
"sync"
)
// msgBuildCountHistogram is used to analyse and log the number of repetitive
// downloads of requested messages per one fetch. The number of builds per each
// messageID is stored in persistent database. The msgBuildCountHistogram will
// take this number for each message in ongoing fetch and create histogram of
// repeats.
//
// Example: During `fetch 1:300` there were
// - 100 messages were downloaded first time
// - 100 messages were downloaded second time
// - 99 messages were downloaded 10th times
// - 1 messages were downloaded 100th times.
type msgBuildCountHistogram struct {
// Key represents how many times message was build.
// Value stores how many messages are build X times based on the key.
counts map[uint32]uint32
lock sync.Locker
}
func newMsgBuildCountHistogram() *msgBuildCountHistogram {
return &msgBuildCountHistogram{
counts: map[uint32]uint32{},
lock: &sync.Mutex{},
}
}
func (c *msgBuildCountHistogram) String() string {
res := ""
for nRebuild, counts := range c.counts {
if res != "" {
res += ", "
}
res += fmt.Sprintf("[%d]:%d", nRebuild, counts)
}
return res
}
func (c *msgBuildCountHistogram) add(nRebuild uint32) {
c.lock.Lock()
defer c.lock.Unlock()
c.counts[nRebuild]++
}

View File

@ -80,7 +80,6 @@ type storeMailboxProvider interface {
GetDelimiter() string GetDelimiter() string
GetMessage(apiID string) (storeMessageProvider, error) GetMessage(apiID string) (storeMessageProvider, error)
FetchMessage(apiID string) (storeMessageProvider, error)
LabelMessages(apiID []string) error LabelMessages(apiID []string) error
UnlabelMessages(apiID []string) error UnlabelMessages(apiID []string) error
MarkMessagesRead(apiID []string) error MarkMessagesRead(apiID []string) error
@ -100,14 +99,12 @@ type storeMessageProvider interface {
Message() *pmapi.Message Message() *pmapi.Message
IsMarkedDeleted() bool IsMarkedDeleted() bool
SetSize(int64) error
SetHeader([]byte) error
GetHeader() []byte GetHeader() []byte
GetRFC822() ([]byte, error)
GetRFC822Size() (uint32, error)
GetMIMEHeader() textproto.MIMEHeader GetMIMEHeader() textproto.MIMEHeader
IsFullHeaderCached() bool IsFullHeaderCached() bool
SetBodyStructure(*pkgMsg.BodyStructure) error
GetBodyStructure() (*pkgMsg.BodyStructure, error) GetBodyStructure() (*pkgMsg.BodyStructure, error)
IncreaseBuildCount() (uint32, error)
} }
type storeUserWrap struct { type storeUserWrap struct {
@ -165,7 +162,3 @@ func newStoreMailboxWrap(mailbox *store.Mailbox) *storeMailboxWrap {
func (s *storeMailboxWrap) GetMessage(apiID string) (storeMessageProvider, error) { func (s *storeMailboxWrap) GetMessage(apiID string) (storeMessageProvider, error) {
return s.Mailbox.GetMessage(apiID) return s.Mailbox.GetMessage(apiID)
} }
func (s *storeMailboxWrap) FetchMessage(apiID string) (storeMessageProvider, error) {
return s.Mailbox.FetchMessage(apiID)
}

View File

@ -135,7 +135,7 @@ func (iu *imapUser) ListMailboxes(showOnlySubcribed bool) ([]goIMAPBackend.Mailb
if showOnlySubcribed && !iu.isSubscribed(storeMailbox.LabelID()) { if showOnlySubcribed && !iu.isSubscribed(storeMailbox.LabelID()) {
continue continue
} }
mailbox := newIMAPMailbox(iu.panicHandler, iu, storeMailbox, iu.backend.builder) mailbox := newIMAPMailbox(iu.panicHandler, iu, storeMailbox)
mailboxes = append(mailboxes, mailbox) mailboxes = append(mailboxes, mailbox)
} }
@ -167,7 +167,7 @@ func (iu *imapUser) GetMailbox(name string) (mb goIMAPBackend.Mailbox, err error
return return
} }
return newIMAPMailbox(iu.panicHandler, iu, storeMailbox, iu.backend.builder), nil return newIMAPMailbox(iu.panicHandler, iu, storeMailbox), nil
} }
// CreateMailbox creates a new mailbox. // CreateMailbox creates a new mailbox.

View File

@ -18,99 +18,113 @@
package store package store
import ( import (
"encoding/json" "github.com/ProtonMail/gopenpgp/v2/crypto"
"os" "github.com/ProtonMail/proton-bridge/pkg/message"
"sync" "github.com/sirupsen/logrus"
bolt "go.etcd.io/bbolt"
"github.com/pkg/errors"
) )
// Cache caches the last event IDs for all accounts (there should be only one instance). const passphraseKey = "passphrase"
type Cache struct {
// cache is map from userID => key (such as last event) => value (such as event ID).
cache map[string]map[string]string
path string
lock *sync.RWMutex
}
// NewCache constructs a new cache at the given path. // UnlockCache unlocks the cache for the user with the given keyring.
func NewCache(path string) *Cache { func (store *Store) UnlockCache(kr *crypto.KeyRing) error {
return &Cache{ passphrase, err := store.getCachePassphrase()
path: path,
lock: &sync.RWMutex{},
}
}
func (c *Cache) getEventID(userID string) string {
c.lock.Lock()
defer c.lock.Unlock()
if err := c.loadCache(); err != nil {
log.WithError(err).Warn("Problem to load store cache")
}
if c.cache == nil {
c.cache = map[string]map[string]string{}
}
if c.cache[userID] == nil {
c.cache[userID] = map[string]string{}
}
return c.cache[userID]["events"]
}
func (c *Cache) setEventID(userID, eventID string) error {
c.lock.Lock()
defer c.lock.Unlock()
if c.cache[userID] == nil {
c.cache[userID] = map[string]string{}
}
c.cache[userID]["events"] = eventID
return c.saveCache()
}
func (c *Cache) loadCache() error {
if c.cache != nil {
return nil
}
f, err := os.Open(c.path)
if err != nil { if err != nil {
return err return err
} }
defer f.Close() //nolint[errcheck]
return json.NewDecoder(f).Decode(&c.cache) if passphrase == nil {
} if passphrase, err = crypto.RandomToken(32); err != nil {
return err
}
func (c *Cache) saveCache() error { enc, err := kr.Encrypt(crypto.NewPlainMessage(passphrase), nil)
if c.cache == nil { if err != nil {
return errors.New("events: cannot save cache: cache is nil") return err
}
if err := store.setCachePassphrase(enc.GetBinary()); err != nil {
return err
}
} else {
dec, err := kr.Decrypt(crypto.NewPGPMessage(passphrase), nil, crypto.GetUnixTime())
if err != nil {
return err
}
passphrase = dec.GetBinary()
} }
f, err := os.Create(c.path) if err := store.cache.Unlock(store.user.ID(), passphrase); err != nil {
return err
}
store.cacher.start()
return nil
}
func (store *Store) getCachePassphrase() ([]byte, error) {
var passphrase []byte
if err := store.db.View(func(tx *bolt.Tx) error {
passphrase = tx.Bucket(cachePassphraseBucket).Get([]byte(passphraseKey))
return nil
}); err != nil {
return nil, err
}
return passphrase, nil
}
func (store *Store) setCachePassphrase(passphrase []byte) error {
return store.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(cachePassphraseBucket).Put([]byte(passphraseKey), passphrase)
})
}
func (store *Store) clearCachePassphrase() error {
return store.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(cachePassphraseBucket).Delete([]byte(passphraseKey))
})
}
func (store *Store) getCachedMessage(messageID string) ([]byte, error) {
if store.cache.Has(store.user.ID(), messageID) {
return store.cache.Get(store.user.ID(), messageID)
}
job, done := store.newBuildJob(messageID, message.ForegroundPriority)
defer done()
literal, err := job.GetResult()
if err != nil {
return nil, err
}
// NOTE(GODT-1158): No need to block until cache has been set; do this async?
if err := store.cache.Set(store.user.ID(), messageID, literal); err != nil {
logrus.WithError(err).Error("Failed to cache message")
}
return literal, nil
}
// IsCached returns whether the given message already exists in the cache.
func (store *Store) IsCached(messageID string) bool {
return store.cache.Has(store.user.ID(), messageID)
}
// BuildAndCacheMessage builds the given message (with background priority) and puts it in the cache.
// It builds with background priority.
func (store *Store) BuildAndCacheMessage(messageID string) error {
job, done := store.newBuildJob(messageID, message.BackgroundPriority)
defer done()
literal, err := job.GetResult()
if err != nil { if err != nil {
return err return err
} }
defer f.Close() //nolint[errcheck]
return json.NewEncoder(f).Encode(c.cache) return store.cache.Set(store.user.ID(), messageID, literal)
}
func (c *Cache) clearCacheUser(userID string) error {
c.lock.Lock()
defer c.lock.Unlock()
if c.cache == nil {
log.WithField("user", userID).Warning("Cannot clear user from cache: cache is nil")
return nil
}
log.WithField("user", userID).Trace("Removing user from event loop cache")
delete(c.cache, userID)
return c.saveCache()
} }

73
internal/store/cache/cache_test.go vendored Normal file
View File

@ -0,0 +1,73 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOnDiskCacheNoCompression(t *testing.T) {
cache, err := NewOnDiskCache(t.TempDir(), &NoopCompressor{}, Options{ConcurrentRead: runtime.NumCPU(), ConcurrentWrite: runtime.NumCPU()})
require.NoError(t, err)
testCache(t, cache)
}
func TestOnDiskCacheGZipCompression(t *testing.T) {
cache, err := NewOnDiskCache(t.TempDir(), &GZipCompressor{}, Options{ConcurrentRead: runtime.NumCPU(), ConcurrentWrite: runtime.NumCPU()})
require.NoError(t, err)
testCache(t, cache)
}
func TestInMemoryCache(t *testing.T) {
testCache(t, NewInMemoryCache(1<<20))
}
func testCache(t *testing.T, cache Cache) {
assert.NoError(t, cache.Unlock("userID1", []byte("my secret passphrase")))
assert.NoError(t, cache.Unlock("userID2", []byte("my other passphrase")))
getSetCachedMessage(t, cache, "userID1", "messageID1", "some secret")
assert.True(t, cache.Has("userID1", "messageID1"))
getSetCachedMessage(t, cache, "userID2", "messageID2", "some other secret")
assert.True(t, cache.Has("userID2", "messageID2"))
assert.NoError(t, cache.Rem("userID1", "messageID1"))
assert.False(t, cache.Has("userID1", "messageID1"))
assert.NoError(t, cache.Rem("userID2", "messageID2"))
assert.False(t, cache.Has("userID2", "messageID2"))
assert.NoError(t, cache.Delete("userID1"))
assert.NoError(t, cache.Delete("userID2"))
}
func getSetCachedMessage(t *testing.T, cache Cache, userID, messageID, secret string) {
assert.NoError(t, cache.Set(userID, messageID, []byte(secret)))
data, err := cache.Get(userID, messageID)
assert.NoError(t, err)
assert.Equal(t, []byte(secret), data)
}

33
internal/store/cache/compressor.go vendored Normal file
View File

@ -0,0 +1,33 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
type Compressor interface {
Compress([]byte) ([]byte, error)
Decompress([]byte) ([]byte, error)
}
type NoopCompressor struct{}
func (NoopCompressor) Compress(dec []byte) ([]byte, error) {
return dec, nil
}
func (NoopCompressor) Decompress(cmp []byte) ([]byte, error) {
return cmp, nil
}

60
internal/store/cache/compressor_gzip.go vendored Normal file
View File

@ -0,0 +1,60 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"bytes"
"compress/gzip"
)
type GZipCompressor struct{}
func (GZipCompressor) Compress(dec []byte) ([]byte, error) {
buf := new(bytes.Buffer)
zw := gzip.NewWriter(buf)
if _, err := zw.Write(dec); err != nil {
return nil, err
}
if err := zw.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func (GZipCompressor) Decompress(cmp []byte) ([]byte, error) {
zr, err := gzip.NewReader(bytes.NewReader(cmp))
if err != nil {
return nil, err
}
buf := new(bytes.Buffer)
if _, err := buf.ReadFrom(zr); err != nil {
return nil, err
}
if err := zr.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

244
internal/store/cache/disk.go vendored Normal file
View File

@ -0,0 +1,244 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"errors"
"io/ioutil"
"os"
"path/filepath"
"sync"
"github.com/ProtonMail/proton-bridge/pkg/semaphore"
"github.com/ricochet2200/go-disk-usage/du"
)
var ErrLowSpace = errors.New("not enough free space left on device")
type onDiskCache struct {
path string
opts Options
gcm map[string]cipher.AEAD
cmp Compressor
rsem, wsem semaphore.Semaphore
pending *pending
diskSize uint64
diskFree uint64
once *sync.Once
lock sync.Mutex
}
func NewOnDiskCache(path string, cmp Compressor, opts Options) (Cache, error) {
if err := os.MkdirAll(path, 0700); err != nil {
return nil, err
}
usage := du.NewDiskUsage(path)
// NOTE(GODT-1158): use Available() or Free()?
return &onDiskCache{
path: path,
opts: opts,
gcm: make(map[string]cipher.AEAD),
cmp: cmp,
rsem: semaphore.New(opts.ConcurrentRead),
wsem: semaphore.New(opts.ConcurrentWrite),
pending: newPending(),
diskSize: usage.Size(),
diskFree: usage.Available(),
once: &sync.Once{},
}, nil
}
func (c *onDiskCache) Unlock(userID string, passphrase []byte) error {
hash := sha256.New()
if _, err := hash.Write(passphrase); err != nil {
return err
}
aes, err := aes.NewCipher(hash.Sum(nil))
if err != nil {
return err
}
gcm, err := cipher.NewGCM(aes)
if err != nil {
return err
}
if err := os.MkdirAll(c.getUserPath(userID), 0700); err != nil {
return err
}
c.gcm[userID] = gcm
return nil
}
func (c *onDiskCache) Delete(userID string) error {
defer c.update()
return os.RemoveAll(c.getUserPath(userID))
}
// Has returns whether the given message exists in the cache.
func (c *onDiskCache) Has(userID, messageID string) bool {
c.pending.wait(c.getMessagePath(userID, messageID))
c.rsem.Lock()
defer c.rsem.Unlock()
_, err := os.Stat(c.getMessagePath(userID, messageID))
switch {
case err == nil:
return true
case os.IsNotExist(err):
return false
default:
panic(err)
}
}
func (c *onDiskCache) Get(userID, messageID string) ([]byte, error) {
enc, err := c.readFile(c.getMessagePath(userID, messageID))
if err != nil {
return nil, err
}
cmp, err := c.gcm[userID].Open(nil, enc[:c.gcm[userID].NonceSize()], enc[c.gcm[userID].NonceSize():], nil)
if err != nil {
return nil, err
}
return c.cmp.Decompress(cmp)
}
func (c *onDiskCache) Set(userID, messageID string, literal []byte) error {
nonce := make([]byte, c.gcm[userID].NonceSize())
if _, err := rand.Read(nonce); err != nil {
return err
}
cmp, err := c.cmp.Compress(literal)
if err != nil {
return err
}
// NOTE(GODT-1158): How to properly handle low space? Don't return error, that's bad. Instead send event?
if !c.hasSpace(len(cmp)) {
return nil
}
return c.writeFile(c.getMessagePath(userID, messageID), c.gcm[userID].Seal(nonce, nonce, cmp, nil))
}
func (c *onDiskCache) Rem(userID, messageID string) error {
defer c.update()
return os.Remove(c.getMessagePath(userID, messageID))
}
func (c *onDiskCache) readFile(path string) ([]byte, error) {
c.rsem.Lock()
defer c.rsem.Unlock()
// Wait before reading in case the file is currently being written.
c.pending.wait(path)
return ioutil.ReadFile(filepath.Clean(path))
}
func (c *onDiskCache) writeFile(path string, b []byte) error {
c.wsem.Lock()
defer c.wsem.Unlock()
// Mark the file as currently being written.
// If it's already being written, wait for it to be done and return nil.
// NOTE(GODT-1158): Let's hope it succeeded...
if ok := c.pending.add(path); !ok {
c.pending.wait(path)
return nil
}
defer c.pending.done(path)
// Reduce the approximate free space (update it exactly later).
c.lock.Lock()
c.diskFree -= uint64(len(b))
c.lock.Unlock()
// Update the diskFree eventually.
defer c.update()
// NOTE(GODT-1158): What happens when this fails? Should be fixed eventually.
return ioutil.WriteFile(filepath.Clean(path), b, 0600)
}
func (c *onDiskCache) hasSpace(size int) bool {
c.lock.Lock()
defer c.lock.Unlock()
if c.opts.MinFreeAbs > 0 {
if c.diskFree-uint64(size) < c.opts.MinFreeAbs {
return false
}
}
if c.opts.MinFreeRat > 0 {
if float64(c.diskFree-uint64(size))/float64(c.diskSize) < c.opts.MinFreeRat {
return false
}
}
return true
}
func (c *onDiskCache) update() {
go func() {
c.once.Do(func() {
c.lock.Lock()
defer c.lock.Unlock()
// Update the free space.
c.diskFree = du.NewDiskUsage(c.path).Available()
// Reset the Once object (so we can update again).
c.once = &sync.Once{}
})
}()
}
func (c *onDiskCache) getUserPath(userID string) string {
return filepath.Join(c.path, getHash(userID))
}
func (c *onDiskCache) getMessagePath(userID, messageID string) string {
return filepath.Join(c.getUserPath(userID), getHash(messageID))
}

33
internal/store/cache/hash.go vendored Normal file
View File

@ -0,0 +1,33 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"crypto/sha256"
"encoding/hex"
)
func getHash(name string) string {
hash := sha256.New()
if _, err := hash.Write([]byte(name)); err != nil {
panic(err)
}
return hex.EncodeToString(hash.Sum(nil))
}

104
internal/store/cache/memory.go vendored Normal file
View File

@ -0,0 +1,104 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"errors"
"sync"
)
type inMemoryCache struct {
lock sync.RWMutex
data map[string]map[string][]byte
size, limit int
}
// NewInMemoryCache creates a new in memory cache which stores up to the given number of bytes of cached data.
// NOTE(GODT-1158): Make this threadsafe.
func NewInMemoryCache(limit int) Cache {
return &inMemoryCache{
data: make(map[string]map[string][]byte),
limit: limit,
}
}
func (c *inMemoryCache) Unlock(userID string, passphrase []byte) error {
c.data[userID] = make(map[string][]byte)
return nil
}
func (c *inMemoryCache) Delete(userID string) error {
c.lock.Lock()
defer c.lock.Unlock()
for _, message := range c.data[userID] {
c.size -= len(message)
}
delete(c.data, userID)
return nil
}
// Has returns whether the given message exists in the cache.
func (c *inMemoryCache) Has(userID, messageID string) bool {
if _, err := c.Get(userID, messageID); err != nil {
return false
}
return true
}
func (c *inMemoryCache) Get(userID, messageID string) ([]byte, error) {
c.lock.RLock()
defer c.lock.RUnlock()
literal, ok := c.data[userID][messageID]
if !ok {
return nil, errors.New("no such message in cache")
}
return literal, nil
}
// NOTE(GODT-1158): What to actually do when memory limit is reached? Replace something existing? Return error? Drop silently?
// NOTE(GODT-1158): Pull in cache-rotating feature from old IMAP cache.
func (c *inMemoryCache) Set(userID, messageID string, literal []byte) error {
c.lock.Lock()
defer c.lock.Unlock()
if c.size+len(literal) > c.limit {
return nil
}
c.size += len(literal)
c.data[userID][messageID] = literal
return nil
}
func (c *inMemoryCache) Rem(userID, messageID string) error {
c.lock.Lock()
defer c.lock.Unlock()
c.size -= len(c.data[userID][messageID])
delete(c.data[userID], messageID)
return nil
}

25
internal/store/cache/options.go vendored Normal file
View File

@ -0,0 +1,25 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
type Options struct {
MinFreeAbs uint64
MinFreeRat float64
ConcurrentRead int
ConcurrentWrite int
}

61
internal/store/cache/pending.go vendored Normal file
View File

@ -0,0 +1,61 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import "sync"
type pending struct {
lock sync.Mutex
path map[string]chan struct{}
}
func newPending() *pending {
return &pending{path: make(map[string]chan struct{})}
}
func (p *pending) add(path string) bool {
p.lock.Lock()
defer p.lock.Unlock()
if _, ok := p.path[path]; ok {
return false
}
p.path[path] = make(chan struct{})
return true
}
func (p *pending) wait(path string) {
p.lock.Lock()
ch, ok := p.path[path]
p.lock.Unlock()
if ok {
<-ch
}
}
func (p *pending) done(path string) {
p.lock.Lock()
defer p.lock.Unlock()
defer close(p.path[path])
delete(p.path, path)
}

51
internal/store/cache/pending_test.go vendored Normal file
View File

@ -0,0 +1,51 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestPending(t *testing.T) {
pending := newPending()
pending.add("1")
pending.add("2")
pending.add("3")
resCh := make(chan string)
go func() { pending.wait("1"); resCh <- "1" }()
go func() { pending.wait("2"); resCh <- "2" }()
go func() { pending.wait("3"); resCh <- "3" }()
pending.done("1")
assert.Equal(t, "1", <-resCh)
pending.done("2")
assert.Equal(t, "2", <-resCh)
pending.done("3")
assert.Equal(t, "3", <-resCh)
}
func TestPendingUnknown(t *testing.T) {
newPending().wait("this is not currently being waited")
}

28
internal/store/cache/types.go vendored Normal file
View File

@ -0,0 +1,28 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
type Cache interface {
Unlock(userID string, passphrase []byte) error
Delete(userID string) error
Has(userID, messageID string) bool
Get(userID, messageID string) ([]byte, error)
Set(userID, messageID string, literal []byte) error
Rem(userID, messageID string) error
}

View File

@ -0,0 +1,63 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import "time"
func (store *Store) StartWatcher() {
store.done = make(chan struct{})
go func() {
ticker := time.NewTicker(3 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
// NOTE(GODT-1158): Race condition here? What if DB was already closed?
messageIDs, err := store.getAllMessageIDs()
if err != nil {
return
}
for _, messageID := range messageIDs {
if !store.IsCached(messageID) {
store.cacher.newJob(messageID)
}
}
case <-store.done:
return
}
}
}()
}
func (store *Store) stopWatcher() {
if store.done == nil {
return
}
select {
default:
close(store.done)
case <-store.done:
return
}
}

View File

@ -0,0 +1,104 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"sync"
"github.com/sirupsen/logrus"
)
type Cacher struct {
storer Storer
jobs chan string
done chan struct{}
started bool
wg *sync.WaitGroup
}
type Storer interface {
IsCached(messageID string) bool
BuildAndCacheMessage(messageID string) error
}
func newCacher(storer Storer) *Cacher {
return &Cacher{
storer: storer,
jobs: make(chan string),
done: make(chan struct{}),
wg: &sync.WaitGroup{},
}
}
// newJob sends a new job to the cacher if it's running.
func (cacher *Cacher) newJob(messageID string) {
if !cacher.started {
return
}
select {
case <-cacher.done:
return
default:
if !cacher.storer.IsCached(messageID) {
cacher.wg.Add(1)
go func() { cacher.jobs <- messageID }()
}
}
}
func (cacher *Cacher) start() {
cacher.started = true
go func() {
for {
select {
case messageID := <-cacher.jobs:
go cacher.handleJob(messageID)
case <-cacher.done:
return
}
}
}()
}
func (cacher *Cacher) handleJob(messageID string) {
defer cacher.wg.Done()
if err := cacher.storer.BuildAndCacheMessage(messageID); err != nil {
logrus.WithError(err).Error("Failed to build and cache message")
} else {
logrus.WithField("messageID", messageID).Trace("Message cached")
}
}
func (cacher *Cacher) stop() {
cacher.started = false
cacher.wg.Wait()
select {
case <-cacher.done:
return
default:
close(cacher.done)
}
}

View File

@ -0,0 +1,103 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"testing"
storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
"github.com/golang/mock/gomock"
"github.com/pkg/errors"
)
func withTestCacher(t *testing.T, doTest func(storer *storemocks.MockStorer, cacher *Cacher)) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
// Mock storer used to build/cache messages.
storer := storemocks.NewMockStorer(ctrl)
// Create a new cacher pointing to the fake store.
cacher := newCacher(storer)
// Start the cacher and wait for it to stop.
cacher.start()
defer cacher.stop()
doTest(storer, cacher)
}
func TestCacher(t *testing.T) {
// If the message is not yet cached, we should expect to try to build and cache it.
withTestCacher(t, func(storer *storemocks.MockStorer, cacher *Cacher) {
storer.EXPECT().IsCached("messageID").Return(false)
storer.EXPECT().BuildAndCacheMessage("messageID").Return(nil)
cacher.newJob("messageID")
})
}
func TestCacherAlreadyCached(t *testing.T) {
// If the message is already cached, we should not try to build it.
withTestCacher(t, func(storer *storemocks.MockStorer, cacher *Cacher) {
storer.EXPECT().IsCached("messageID").Return(true)
cacher.newJob("messageID")
})
}
func TestCacherFail(t *testing.T) {
// If building the message fails, we should not try to cache it.
withTestCacher(t, func(storer *storemocks.MockStorer, cacher *Cacher) {
storer.EXPECT().IsCached("messageID").Return(false)
storer.EXPECT().BuildAndCacheMessage("messageID").Return(errors.New("failed to build message"))
cacher.newJob("messageID")
})
}
func TestCacherStop(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
// Mock storer used to build/cache messages.
storer := storemocks.NewMockStorer(ctrl)
// Create a new cacher pointing to the fake store.
cacher := newCacher(storer)
// Start the cacher.
cacher.start()
// Send a job -- this should succeed.
storer.EXPECT().IsCached("messageID").Return(false)
storer.EXPECT().BuildAndCacheMessage("messageID").Return(nil)
cacher.newJob("messageID")
// Stop the cacher.
cacher.stop()
// Send more jobs -- these should all be dropped.
cacher.newJob("messageID2")
cacher.newJob("messageID3")
cacher.newJob("messageID4")
cacher.newJob("messageID5")
// Stopping the cacher multiple times is safe.
cacher.stop()
cacher.stop()
cacher.stop()
cacher.stop()
}

View File

@ -34,7 +34,7 @@ func TestNotifyChangeCreateOrUpdateMessage(t *testing.T) {
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(1), uint32(1), gomock.Any(), false) m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(1), uint32(1), gomock.Any(), false)
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(2), uint32(2), gomock.Any(), false) m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(2), uint32(2), gomock.Any(), false)
m.newStoreNoEvents(true) m.newStoreNoEvents(t, true)
m.store.SetChangeNotifier(m.changeNotifier) m.store.SetChangeNotifier(m.changeNotifier)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
@ -49,7 +49,7 @@ func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) {
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(1), uint32(1), gomock.Any(), false) m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(1), uint32(1), gomock.Any(), false)
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(2), uint32(2), gomock.Any(), false) m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(2), uint32(2), gomock.Any(), false)
m.newStoreNoEvents(true) m.newStoreNoEvents(t, true)
m.store.SetChangeNotifier(m.changeNotifier) m.store.SetChangeNotifier(m.changeNotifier)
msg1 := getTestMessage("msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) msg1 := getTestMessage("msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
@ -61,7 +61,7 @@ func TestNotifyChangeDeleteMessage(t *testing.T) {
m, clear := initMocks(t) m, clear := initMocks(t)
defer clear() defer clear()
m.newStoreNoEvents(true) m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})

View File

@ -38,7 +38,7 @@ const (
) )
type eventLoop struct { type eventLoop struct {
cache *Cache currentEvents *Events
currentEventID string currentEventID string
currentEvent *pmapi.Event currentEvent *pmapi.Event
pollCh chan chan struct{} pollCh chan chan struct{}
@ -51,26 +51,26 @@ type eventLoop struct {
log *logrus.Entry log *logrus.Entry
store *Store store *Store
user BridgeUser user BridgeUser
events listener.Listener listener listener.Listener
} }
func newEventLoop(cache *Cache, store *Store, user BridgeUser, events listener.Listener) *eventLoop { func newEventLoop(currentEvents *Events, store *Store, user BridgeUser, listener listener.Listener) *eventLoop {
eventLog := log.WithField("userID", user.ID()) eventLog := log.WithField("userID", user.ID())
eventLog.Trace("Creating new event loop") eventLog.Trace("Creating new event loop")
return &eventLoop{ return &eventLoop{
cache: cache, currentEvents: currentEvents,
currentEventID: cache.getEventID(user.ID()), currentEventID: currentEvents.getEventID(user.ID()),
pollCh: make(chan chan struct{}), pollCh: make(chan chan struct{}),
isRunning: false, isRunning: false,
log: eventLog, log: eventLog,
store: store, store: store,
user: user, user: user,
events: events, listener: listener,
} }
} }
@ -89,7 +89,7 @@ func (loop *eventLoop) setFirstEventID() (err error) {
loop.currentEventID = event.EventID loop.currentEventID = event.EventID
if err = loop.cache.setEventID(loop.user.ID(), loop.currentEventID); err != nil { if err = loop.currentEvents.setEventID(loop.user.ID(), loop.currentEventID); err != nil {
loop.log.WithError(err).Error("Could not set latest event ID in user cache") loop.log.WithError(err).Error("Could not set latest event ID in user cache")
return return
} }
@ -229,7 +229,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
if err != nil && isFdCloseToULimit() { if err != nil && isFdCloseToULimit() {
l.Warn("Ulimit reached") l.Warn("Ulimit reached")
loop.events.Emit(bridgeEvents.RestartBridgeEvent, "") loop.listener.Emit(bridgeEvents.RestartBridgeEvent, "")
err = nil err = nil
} }
@ -291,7 +291,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
// This allows the event loop to continue to function (unless the cache was broken // This allows the event loop to continue to function (unless the cache was broken
// and bridge stopped, in which case it will start from the old event ID anyway). // and bridge stopped, in which case it will start from the old event ID anyway).
loop.currentEventID = event.EventID loop.currentEventID = event.EventID
if err = loop.cache.setEventID(loop.user.ID(), event.EventID); err != nil { if err = loop.currentEvents.setEventID(loop.user.ID(), event.EventID); err != nil {
return false, errors.Wrap(err, "failed to save event ID to cache") return false, errors.Wrap(err, "failed to save event ID to cache")
} }
} }
@ -371,7 +371,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
switch addressEvent.Action { switch addressEvent.Action {
case pmapi.EventCreate: case pmapi.EventCreate:
log.WithField("email", addressEvent.Address.Email).Debug("Address was created") log.WithField("email", addressEvent.Address.Email).Debug("Address was created")
loop.events.Emit(bridgeEvents.AddressChangedEvent, loop.user.GetPrimaryAddress()) loop.listener.Emit(bridgeEvents.AddressChangedEvent, loop.user.GetPrimaryAddress())
case pmapi.EventUpdate: case pmapi.EventUpdate:
oldAddress := oldList.ByID(addressEvent.ID) oldAddress := oldList.ByID(addressEvent.ID)
@ -383,7 +383,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
email := oldAddress.Email email := oldAddress.Email
log.WithField("email", email).Debug("Address was updated") log.WithField("email", email).Debug("Address was updated")
if addressEvent.Address.Receive != oldAddress.Receive { if addressEvent.Address.Receive != oldAddress.Receive {
loop.events.Emit(bridgeEvents.AddressChangedLogoutEvent, email) loop.listener.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
} }
case pmapi.EventDelete: case pmapi.EventDelete:
@ -396,7 +396,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
email := oldAddress.Email email := oldAddress.Email
log.WithField("email", email).Debug("Address was deleted") log.WithField("email", email).Debug("Address was deleted")
loop.user.CloseConnection(email) loop.user.CloseConnection(email)
loop.events.Emit(bridgeEvents.AddressChangedLogoutEvent, email) loop.listener.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
case pmapi.EventUpdateFlags: case pmapi.EventUpdateFlags:
log.Error("EventUpdateFlags for address event is uknown operation") log.Error("EventUpdateFlags for address event is uknown operation")
} }

View File

@ -53,7 +53,7 @@ func TestEventLoopProcessMoreEvents(t *testing.T) {
More: false, More: false,
}, nil), }, nil),
) )
m.newStoreNoEvents(true) m.newStoreNoEvents(t, true)
// Event loop runs in goroutine started during store creation (newStoreNoEvents). // Event loop runs in goroutine started during store creation (newStoreNoEvents).
// Force to run the next event. // Force to run the next event.
@ -78,7 +78,7 @@ func TestEventLoopUpdateMessageFromLoop(t *testing.T) {
subject := "old subject" subject := "old subject"
newSubject := "new subject" newSubject := "new subject"
m.newStoreNoEvents(true, &pmapi.Message{ m.newStoreNoEvents(t, true, &pmapi.Message{
ID: "msg1", ID: "msg1",
Subject: subject, Subject: subject,
}) })
@ -106,7 +106,7 @@ func TestEventLoopDeletionNotPaused(t *testing.T) {
m, clear := initMocks(t) m, clear := initMocks(t)
defer clear() defer clear()
m.newStoreNoEvents(true, &pmapi.Message{ m.newStoreNoEvents(t, true, &pmapi.Message{
ID: "msg1", ID: "msg1",
Subject: "subject", Subject: "subject",
LabelIDs: []string{"label"}, LabelIDs: []string{"label"},
@ -133,7 +133,7 @@ func TestEventLoopDeletionPaused(t *testing.T) {
m, clear := initMocks(t) m, clear := initMocks(t)
defer clear() defer clear()
m.newStoreNoEvents(true, &pmapi.Message{ m.newStoreNoEvents(t, true, &pmapi.Message{
ID: "msg1", ID: "msg1",
Subject: "subject", Subject: "subject",
LabelIDs: []string{"label"}, LabelIDs: []string{"label"},

116
internal/store/events.go Normal file
View File

@ -0,0 +1,116 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"encoding/json"
"os"
"sync"
"github.com/pkg/errors"
)
// Events caches the last event IDs for all accounts (there should be only one instance).
type Events struct {
// eventMap is map from userID => key (such as last event) => value (such as event ID).
eventMap map[string]map[string]string
path string
lock *sync.RWMutex
}
// NewEvents constructs a new event cache at the given path.
func NewEvents(path string) *Events {
return &Events{
path: path,
lock: &sync.RWMutex{},
}
}
func (c *Events) getEventID(userID string) string {
c.lock.Lock()
defer c.lock.Unlock()
if err := c.loadEvents(); err != nil {
log.WithError(err).Warn("Problem to load store events")
}
if c.eventMap == nil {
c.eventMap = map[string]map[string]string{}
}
if c.eventMap[userID] == nil {
c.eventMap[userID] = map[string]string{}
}
return c.eventMap[userID]["events"]
}
func (c *Events) setEventID(userID, eventID string) error {
c.lock.Lock()
defer c.lock.Unlock()
if c.eventMap[userID] == nil {
c.eventMap[userID] = map[string]string{}
}
c.eventMap[userID]["events"] = eventID
return c.saveEvents()
}
func (c *Events) loadEvents() error {
if c.eventMap != nil {
return nil
}
f, err := os.Open(c.path)
if err != nil {
return err
}
defer f.Close() //nolint[errcheck]
return json.NewDecoder(f).Decode(&c.eventMap)
}
func (c *Events) saveEvents() error {
if c.eventMap == nil {
return errors.New("events: cannot save events: events map is nil")
}
f, err := os.Create(c.path)
if err != nil {
return err
}
defer f.Close() //nolint[errcheck]
return json.NewEncoder(f).Encode(c.eventMap)
}
func (c *Events) clearUserEvents(userID string) error {
c.lock.Lock()
defer c.lock.Unlock()
if c.eventMap == nil {
log.WithField("user", userID).Warning("Cannot clear user events: event map is nil")
return nil
}
log.WithField("user", userID).Trace("Removing user events from event loop")
delete(c.eventMap, userID)
return c.saveEvents()
}

View File

@ -107,7 +107,7 @@ func checkCounts(t testing.TB, wantCounts []*pmapi.MessagesCount, haveStore *Sto
func TestMailboxCountRemove(t *testing.T) { func TestMailboxCountRemove(t *testing.T) {
m, clear := initMocks(t) m, clear := initMocks(t)
defer clear() defer clear()
m.newStoreNoEvents(true) m.newStoreNoEvents(t, true)
testCounts := []*pmapi.MessagesCount{ testCounts := []*pmapi.MessagesCount{
{LabelID: "label1", Total: 100, Unread: 0}, {LabelID: "label1", Total: 100, Unread: 0},

View File

@ -35,7 +35,7 @@ func TestGetSequenceNumberAndGetUID(t *testing.T) {
m, clear := initMocks(t) m, clear := initMocks(t)
defer clear() defer clear()
m.newStoreNoEvents(true) m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
@ -80,7 +80,7 @@ func TestGetUIDByHeader(t *testing.T) { //nolint[funlen]
m, clear := initMocks(t) m, clear := initMocks(t)
defer clear() defer clear()
m.newStoreNoEvents(true) m.newStoreNoEvents(t, true)
tstMsg := getTestMessage("msg1", "Without external ID", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel}) tstMsg := getTestMessage("msg1", "Without external ID", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg)) require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))

View File

@ -67,40 +67,19 @@ func (message *Message) Message() *pmapi.Message {
return message.msg return message.msg
} }
// IsMarkedDeleted returns true if message is marked as deleted for specific // IsMarkedDeleted returns true if message is marked as deleted for specific mailbox.
// mailbox.
func (message *Message) IsMarkedDeleted() bool { func (message *Message) IsMarkedDeleted() bool {
isMarkedAsDeleted := false var isMarkedAsDeleted bool
err := message.storeMailbox.db().View(func(tx *bolt.Tx) error {
if err := message.storeMailbox.db().View(func(tx *bolt.Tx) error {
isMarkedAsDeleted = message.storeMailbox.txGetDeletedIDsBucket(tx).Get([]byte(message.msg.ID)) != nil isMarkedAsDeleted = message.storeMailbox.txGetDeletedIDsBucket(tx).Get([]byte(message.msg.ID)) != nil
return nil return nil
}) }); err != nil {
if err != nil {
message.storeMailbox.log.WithError(err).Error("Not able to retrieve deleted mark, assuming false.") message.storeMailbox.log.WithError(err).Error("Not able to retrieve deleted mark, assuming false.")
return false return false
} }
return isMarkedAsDeleted
}
// SetSize updates the information about size of decrypted message which can be return isMarkedAsDeleted
// used for IMAP. This should not trigger any IMAP update.
// NOTE: The size from the server corresponds to pure body bytes. Hence it
// should not be used. The correct size has to be calculated from decrypted and
// built message.
func (message *Message) SetSize(size int64) error {
message.msg.Size = size
txUpdate := func(tx *bolt.Tx) error {
stored, err := message.store.txGetMessage(tx, message.msg.ID)
if err != nil {
return err
}
stored.Size = size
return message.store.txPutMessage(
tx.Bucket(metadataBucket),
stored,
)
}
return message.store.db.Update(txUpdate)
} }
// SetContentTypeAndHeader updates the information about content type and // SetContentTypeAndHeader updates the information about content type and
@ -112,7 +91,7 @@ func (message *Message) SetSize(size int64) error {
func (message *Message) SetContentTypeAndHeader(mimeType string, header mail.Header) error { func (message *Message) SetContentTypeAndHeader(mimeType string, header mail.Header) error {
message.msg.MIMEType = mimeType message.msg.MIMEType = mimeType
message.msg.Header = header message.msg.Header = header
txUpdate := func(tx *bolt.Tx) error { return message.store.db.Update(func(tx *bolt.Tx) error {
stored, err := message.store.txGetMessage(tx, message.msg.ID) stored, err := message.store.txGetMessage(tx, message.msg.ID)
if err != nil { if err != nil {
return err return err
@ -123,34 +102,26 @@ func (message *Message) SetContentTypeAndHeader(mimeType string, header mail.Hea
tx.Bucket(metadataBucket), tx.Bucket(metadataBucket),
stored, stored,
) )
}
return message.store.db.Update(txUpdate)
}
// SetHeader checks header can be parsed and if yes it stores header bytes in
// database.
func (message *Message) SetHeader(header []byte) error {
_, err := textproto.NewReader(bufio.NewReader(bytes.NewReader(header))).ReadMIMEHeader()
if err != nil {
return err
}
return message.store.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(headersBucket).Put([]byte(message.ID()), header)
}) })
} }
// IsFullHeaderCached will check that valid full header is stored in DB. // IsFullHeaderCached will check that valid full header is stored in DB.
func (message *Message) IsFullHeaderCached() bool { func (message *Message) IsFullHeaderCached() bool {
header, err := message.getRawHeader() var raw []byte
return err == nil && header != nil err := message.store.db.View(func(tx *bolt.Tx) error {
} raw = tx.Bucket(bodystructureBucket).Get([]byte(message.ID()))
func (message *Message) getRawHeader() (raw []byte, err error) {
err = message.store.db.View(func(tx *bolt.Tx) error {
raw = tx.Bucket(headersBucket).Get([]byte(message.ID()))
return nil return nil
}) })
return return err == nil && raw != nil
}
func (message *Message) getRawHeader() ([]byte, error) {
bs, err := message.GetBodyStructure()
if err != nil {
return nil, err
}
return bs.GetMailHeaderBytes()
} }
// GetHeader will return cached header from DB. // GetHeader will return cached header from DB.
@ -178,44 +149,79 @@ func (message *Message) GetMIMEHeader() textproto.MIMEHeader {
return header return header
} }
// SetBodyStructure stores serialized body structure in database. // GetBodyStructure returns the message's body structure.
func (message *Message) SetBodyStructure(bs *pkgMsg.BodyStructure) error { // It checks first if it's in the store. If it is, it returns it from store,
txUpdate := func(tx *bolt.Tx) error { // otherwise it computes it from the message cache (and saves the result to the store).
return message.store.txPutBodyStructure( func (message *Message) GetBodyStructure() (*pkgMsg.BodyStructure, error) {
tx.Bucket(bodystructureBucket), var raw []byte
message.ID(), bs,
)
}
return message.store.db.Update(txUpdate)
}
// GetBodyStructure deserializes body structure from database. If body structure if err := message.store.db.View(func(tx *bolt.Tx) error {
// is not in database it returns nil error and nil body structure. If error raw = tx.Bucket(bodystructureBucket).Get([]byte(message.ID()))
// occurs it returns nil body structure. return nil
func (message *Message) GetBodyStructure() (bs *pkgMsg.BodyStructure, err error) { }); err != nil {
txRead := func(tx *bolt.Tx) error {
bs, err = message.store.txGetBodyStructure(
tx.Bucket(bodystructureBucket),
message.ID(),
)
return err
}
if err = message.store.db.View(txRead); err != nil {
return nil, err return nil, err
} }
if len(raw) > 0 {
// If not possible to deserialize just continue with build.
if bs, err := pkgMsg.DeserializeBodyStructure(raw); err == nil {
return bs, nil
}
}
literal, err := message.store.getCachedMessage(message.ID())
if err != nil {
return nil, err
}
bs, err := pkgMsg.NewBodyStructure(bytes.NewReader(literal))
if err != nil {
return nil, err
}
if raw, err = bs.Serialize(); err != nil {
return nil, err
}
if err := message.store.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(bodystructureBucket).Put([]byte(message.ID()), raw)
}); err != nil {
return nil, err
}
return bs, nil return bs, nil
} }
func (message *Message) IncreaseBuildCount() (times uint32, err error) { // GetRFC822 returns the raw message literal.
txUpdate := func(tx *bolt.Tx) error { func (message *Message) GetRFC822() ([]byte, error) {
times, err = message.store.txIncreaseMsgBuildCount( return message.store.getCachedMessage(message.ID())
tx.Bucket(msgBuildCountBucket), }
message.ID(),
) // GetRFC822Size returns the size of the raw message literal.
return err func (message *Message) GetRFC822Size() (uint32, error) {
} var raw []byte
if err = message.store.db.Update(txUpdate); err != nil {
if err := message.store.db.View(func(tx *bolt.Tx) error {
raw = tx.Bucket(sizeBucket).Get([]byte(message.ID()))
return nil
}); err != nil {
return 0, err return 0, err
} }
return times, nil
if len(raw) > 0 {
return btoi(raw), nil
}
literal, err := message.store.getCachedMessage(message.ID())
if err != nil {
return 0, err
}
if err := message.store.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(sizeBucket).Put([]byte(message.ID()), itob(uint32(len(literal))))
}); err != nil {
return 0, err
}
return uint32(len(literal)), nil
} }

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser,ChangeNotifier) // Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser,ChangeNotifier,Storer)
// Package mocks is a generated GoMock package. // Package mocks is a generated GoMock package.
package mocks package mocks
@ -318,3 +318,54 @@ func (mr *MockChangeNotifierMockRecorder) UpdateMessage(arg0, arg1, arg2, arg3,
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMessage", reflect.TypeOf((*MockChangeNotifier)(nil).UpdateMessage), arg0, arg1, arg2, arg3, arg4, arg5) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMessage", reflect.TypeOf((*MockChangeNotifier)(nil).UpdateMessage), arg0, arg1, arg2, arg3, arg4, arg5)
} }
// MockStorer is a mock of Storer interface.
type MockStorer struct {
ctrl *gomock.Controller
recorder *MockStorerMockRecorder
}
// MockStorerMockRecorder is the mock recorder for MockStorer.
type MockStorerMockRecorder struct {
mock *MockStorer
}
// NewMockStorer creates a new mock instance.
func NewMockStorer(ctrl *gomock.Controller) *MockStorer {
mock := &MockStorer{ctrl: ctrl}
mock.recorder = &MockStorerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStorer) EXPECT() *MockStorerMockRecorder {
return m.recorder
}
// BuildAndCacheMessage mocks base method.
func (m *MockStorer) BuildAndCacheMessage(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BuildAndCacheMessage", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// BuildAndCacheMessage indicates an expected call of BuildAndCacheMessage.
func (mr *MockStorerMockRecorder) BuildAndCacheMessage(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BuildAndCacheMessage", reflect.TypeOf((*MockStorer)(nil).BuildAndCacheMessage), arg0)
}
// IsCached mocks base method.
func (m *MockStorer) IsCached(arg0 string) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsCached", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// IsCached indicates an expected call of IsCached.
func (mr *MockStorerMockRecorder) IsCached(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsCached", reflect.TypeOf((*MockStorer)(nil).IsCached), arg0)
}

View File

@ -26,8 +26,11 @@ import (
"time" "time"
"github.com/ProtonMail/proton-bridge/internal/sentry" "github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/pkg/pool"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -52,19 +55,21 @@ var (
// Database structure: // Database structure:
// * metadata // * metadata
// * {messageID} -> message data (subject, from, to, time, body size, ...) // * {messageID} -> message data (subject, from, to, time, ...)
// * headers // * headers
// * {messageID} -> header bytes // * {messageID} -> header bytes
// * bodystructure // * bodystructure
// * {messageID} -> message body structure // * {messageID} -> message body structure
// * msgbuildcount // * size
// * {messageID} -> uint32 number of message builds to track re-sync issues // * {messageID} -> uint32 value
// * counts // * counts
// * {mailboxID} -> mailboxCounts: totalOnAPI, unreadOnAPI, labelName, labelColor, labelIsExclusive // * {mailboxID} -> mailboxCounts: totalOnAPI, unreadOnAPI, labelName, labelColor, labelIsExclusive
// * address_info // * address_info
// * {index} -> {address, addressID} // * {index} -> {address, addressID}
// * address_mode // * address_mode
// * mode -> string split or combined // * mode -> string split or combined
// * cache_passphrase
// * passphrase -> cache passphrase (pgp encrypted message)
// * mailboxes_version // * mailboxes_version
// * version -> uint32 value // * version -> uint32 value
// * sync_state // * sync_state
@ -79,19 +84,20 @@ var (
// * {messageID} -> uint32 imapUID // * {messageID} -> uint32 imapUID
// * deleted_ids (can be missing or have no keys) // * deleted_ids (can be missing or have no keys)
// * {messageID} -> true // * {messageID} -> true
metadataBucket = []byte("metadata") //nolint[gochecknoglobals] metadataBucket = []byte("metadata") //nolint[gochecknoglobals]
headersBucket = []byte("headers") //nolint[gochecknoglobals] headersBucket = []byte("headers") //nolint[gochecknoglobals]
bodystructureBucket = []byte("bodystructure") //nolint[gochecknoglobals] bodystructureBucket = []byte("bodystructure") //nolint[gochecknoglobals]
msgBuildCountBucket = []byte("msgbuildcount") //nolint[gochecknoglobals] sizeBucket = []byte("size") //nolint[gochecknoglobals]
countsBucket = []byte("counts") //nolint[gochecknoglobals] countsBucket = []byte("counts") //nolint[gochecknoglobals]
addressInfoBucket = []byte("address_info") //nolint[gochecknoglobals] addressInfoBucket = []byte("address_info") //nolint[gochecknoglobals]
addressModeBucket = []byte("address_mode") //nolint[gochecknoglobals] addressModeBucket = []byte("address_mode") //nolint[gochecknoglobals]
syncStateBucket = []byte("sync_state") //nolint[gochecknoglobals] cachePassphraseBucket = []byte("cache_passphrase") //nolint[gochecknoglobals]
mailboxesBucket = []byte("mailboxes") //nolint[gochecknoglobals] syncStateBucket = []byte("sync_state") //nolint[gochecknoglobals]
imapIDsBucket = []byte("imap_ids") //nolint[gochecknoglobals] mailboxesBucket = []byte("mailboxes") //nolint[gochecknoglobals]
apiIDsBucket = []byte("api_ids") //nolint[gochecknoglobals] imapIDsBucket = []byte("imap_ids") //nolint[gochecknoglobals]
deletedIDsBucket = []byte("deleted_ids") //nolint[gochecknoglobals] apiIDsBucket = []byte("api_ids") //nolint[gochecknoglobals]
mboxVersionBucket = []byte("mailboxes_version") //nolint[gochecknoglobals] deletedIDsBucket = []byte("deleted_ids") //nolint[gochecknoglobals]
mboxVersionBucket = []byte("mailboxes_version") //nolint[gochecknoglobals]
// ErrNoSuchAPIID when mailbox does not have API ID. // ErrNoSuchAPIID when mailbox does not have API ID.
ErrNoSuchAPIID = errors.New("no such api id") //nolint[gochecknoglobals] ErrNoSuchAPIID = errors.New("no such api id") //nolint[gochecknoglobals]
@ -117,18 +123,23 @@ func exposeContextForSMTP() context.Context {
type Store struct { type Store struct {
sentryReporter *sentry.Reporter sentryReporter *sentry.Reporter
panicHandler PanicHandler panicHandler PanicHandler
eventLoop *eventLoop
user BridgeUser user BridgeUser
eventLoop *eventLoop
currentEvents *Events
log *logrus.Entry log *logrus.Entry
cache *Cache
filePath string filePath string
db *bolt.DB db *bolt.DB
lock *sync.RWMutex lock *sync.RWMutex
addresses map[string]*Address addresses map[string]*Address
notifier ChangeNotifier notifier ChangeNotifier
builder *message.Builder
cache cache.Cache
cacher *Cacher
done chan struct{}
isSyncRunning bool isSyncRunning bool
syncCooldown cooldown syncCooldown cooldown
addressMode addressMode addressMode addressMode
@ -139,12 +150,14 @@ func New( // nolint[funlen]
sentryReporter *sentry.Reporter, sentryReporter *sentry.Reporter,
panicHandler PanicHandler, panicHandler PanicHandler,
user BridgeUser, user BridgeUser,
events listener.Listener, listener listener.Listener,
cache cache.Cache,
builder *message.Builder,
path string, path string,
cache *Cache, currentEvents *Events,
) (store *Store, err error) { ) (store *Store, err error) {
if user == nil || events == nil || cache == nil { if user == nil || listener == nil || currentEvents == nil {
return nil, fmt.Errorf("missing parameters - user: %v, events: %v, cache: %v", user, events, cache) return nil, fmt.Errorf("missing parameters - user: %v, listener: %v, currentEvents: %v", user, listener, currentEvents)
} }
l := log.WithField("user", user.ID()) l := log.WithField("user", user.ID())
@ -160,21 +173,29 @@ func New( // nolint[funlen]
bdb, err := openBoltDatabase(path) bdb, err := openBoltDatabase(path)
if err != nil { if err != nil {
err = errors.Wrap(err, "failed to open store database") return nil, errors.Wrap(err, "failed to open store database")
return
} }
store = &Store{ store = &Store{
sentryReporter: sentryReporter, sentryReporter: sentryReporter,
panicHandler: panicHandler, panicHandler: panicHandler,
user: user, user: user,
cache: cache, currentEvents: currentEvents,
filePath: path,
db: bdb, log: l,
lock: &sync.RWMutex{},
log: l, filePath: path,
db: bdb,
lock: &sync.RWMutex{},
builder: builder,
cache: cache,
} }
// Create a new cacher. It's not started yet.
// NOTE(GODT-1158): I hate this circular dependency store->cacher->store :(
store.cacher = newCacher(store)
// Minimal increase is event pollInterval, doubles every failed retry up to 5 minutes. // Minimal increase is event pollInterval, doubles every failed retry up to 5 minutes.
store.syncCooldown.setExponentialWait(pollInterval, 2, 5*time.Minute) store.syncCooldown.setExponentialWait(pollInterval, 2, 5*time.Minute)
@ -188,7 +209,7 @@ func New( // nolint[funlen]
} }
if user.IsConnected() { if user.IsConnected() {
store.eventLoop = newEventLoop(cache, store, user, events) store.eventLoop = newEventLoop(currentEvents, store, user, listener)
go func() { go func() {
defer store.panicHandler.HandlePanic() defer store.panicHandler.HandlePanic()
store.eventLoop.start() store.eventLoop.start()
@ -216,10 +237,11 @@ func openBoltDatabase(filePath string) (db *bolt.DB, err error) {
metadataBucket, metadataBucket,
headersBucket, headersBucket,
bodystructureBucket, bodystructureBucket,
msgBuildCountBucket, sizeBucket,
countsBucket, countsBucket,
addressInfoBucket, addressInfoBucket,
addressModeBucket, addressModeBucket,
cachePassphraseBucket,
syncStateBucket, syncStateBucket,
mailboxesBucket, mailboxesBucket,
mboxVersionBucket, mboxVersionBucket,
@ -365,6 +387,24 @@ func (store *Store) addAddress(address, addressID string, labels []*pmapi.Label)
return return
} }
// newBuildJob returns a new build job for the given message using the store's message builder.
func (store *Store) newBuildJob(messageID string, priority int) (*message.Job, pool.DoneFunc) {
return store.builder.NewJobWithOptions(
context.Background(),
store.client(),
messageID,
message.JobOptions{
IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead.
SanitizeDate: true, // Whether to replace all dates before 1970 with RFC822's birthdate.
AddInternalID: true, // Whether to include MessageID as X-Pm-Internal-Id.
AddExternalID: true, // Whether to include ExternalID as X-Pm-External-Id.
AddMessageDate: true, // Whether to include message time as X-Pm-Date.
AddMessageIDReference: true, // Whether to include the MessageID in References.
},
priority,
)
}
// Close stops the event loop and closes the database to free the file. // Close stops the event loop and closes the database to free the file.
func (store *Store) Close() error { func (store *Store) Close() error {
store.lock.Lock() store.lock.Lock()
@ -381,12 +421,21 @@ func (store *Store) CloseEventLoop() {
} }
func (store *Store) close() error { func (store *Store) close() error {
// Stop the watcher first before closing the database.
store.stopWatcher()
// Stop the cacher.
store.cacher.stop()
// Stop the event loop.
store.CloseEventLoop() store.CloseEventLoop()
// Close the database.
return store.db.Close() return store.db.Close()
} }
// Remove closes and removes the database file and clears the cache file. // Remove closes and removes the database file and clears the cache file.
func (store *Store) Remove() (err error) { func (store *Store) Remove() error {
store.lock.Lock() store.lock.Lock()
defer store.lock.Unlock() defer store.lock.Unlock()
@ -394,22 +443,34 @@ func (store *Store) Remove() (err error) {
var result *multierror.Error var result *multierror.Error
if err = store.close(); err != nil { if err := store.close(); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to close store")) result = multierror.Append(result, errors.Wrap(err, "failed to close store"))
} }
if err = RemoveStore(store.cache, store.filePath, store.user.ID()); err != nil { if err := RemoveStore(store.currentEvents, store.filePath, store.user.ID()); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to remove store")) result = multierror.Append(result, errors.Wrap(err, "failed to remove store"))
} }
if err := store.RemoveCache(); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to remove cache"))
}
return result.ErrorOrNil() return result.ErrorOrNil()
} }
func (store *Store) RemoveCache() error {
if err := store.clearCachePassphrase(); err != nil {
logrus.WithError(err).Error("Failed to clear cache passphrase")
}
return store.cache.Delete(store.user.ID())
}
// RemoveStore removes the database file and clears the cache file. // RemoveStore removes the database file and clears the cache file.
func RemoveStore(cache *Cache, path, userID string) error { func RemoveStore(currentEvents *Events, path, userID string) error {
var result *multierror.Error var result *multierror.Error
if err := cache.clearCacheUser(userID); err != nil { if err := currentEvents.clearUserEvents(userID); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to clear event loop user cache")) result = multierror.Append(result, errors.Wrap(err, "failed to clear event loop user cache"))
} }

View File

@ -23,13 +23,17 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"testing" "testing"
"time" "time"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks" storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks" pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
tests "github.com/ProtonMail/proton-bridge/test"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -139,7 +143,7 @@ type mocksForStore struct {
store *Store store *Store
tmpDir string tmpDir string
cache *Cache cache *Events
} }
func initMocks(tb testing.TB) (*mocksForStore, func()) { func initMocks(tb testing.TB) (*mocksForStore, func()) {
@ -162,7 +166,7 @@ func initMocks(tb testing.TB) (*mocksForStore, func()) {
require.NoError(tb, err) require.NoError(tb, err)
cacheFile := filepath.Join(mocks.tmpDir, "cache.json") cacheFile := filepath.Join(mocks.tmpDir, "cache.json")
mocks.cache = NewCache(cacheFile) mocks.cache = NewEvents(cacheFile)
return mocks, func() { return mocks, func() {
if err := recover(); err != nil { if err := recover(); err != nil {
@ -176,13 +180,14 @@ func initMocks(tb testing.TB) (*mocksForStore, func()) {
} }
} }
func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.Message) { //nolint[unparam] func (mocks *mocksForStore) newStoreNoEvents(t *testing.T, combinedMode bool, msgs ...*pmapi.Message) { //nolint[unparam]
mocks.user.EXPECT().ID().Return("userID").AnyTimes() mocks.user.EXPECT().ID().Return("userID").AnyTimes()
mocks.user.EXPECT().IsConnected().Return(true) mocks.user.EXPECT().IsConnected().Return(true)
mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode) mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode)
mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client) mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client)
mocks.client.EXPECT().GetUserKeyRing().Return(tests.MakeKeyRing(t), nil).AnyTimes()
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{ mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: true}, {ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: true},
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: true}, {ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: true},
@ -213,6 +218,8 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
mocks.panicHandler, mocks.panicHandler,
mocks.user, mocks.user,
mocks.events, mocks.events,
cache.NewInMemoryCache(1<<20),
message.NewBuilder(runtime.NumCPU(), runtime.NumCPU()),
filepath.Join(mocks.tmpDir, "mailbox-test.db"), filepath.Join(mocks.tmpDir, "mailbox-test.db"),
mocks.cache, mocks.cache,
) )

View File

@ -27,7 +27,6 @@ import (
"strings" "strings"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -154,11 +153,6 @@ func (store *Store) checkDraftTotalSize(message *pmapi.Message, attachments []*d
return false, err return false, err
} }
msgSize := message.Size
if msgSize == 0 {
msgSize = int64(len(message.Body))
}
var attSize int64 var attSize int64
for _, att := range attachments { for _, att := range attachments {
b, err := ioutil.ReadAll(att.encReader) b, err := ioutil.ReadAll(att.encReader)
@ -169,7 +163,7 @@ func (store *Store) checkDraftTotalSize(message *pmapi.Message, attachments []*d
att.encReader = bytes.NewBuffer(b) att.encReader = bytes.NewBuffer(b)
} }
return msgSize+attSize <= maxUpload, nil return int64(len(message.Body))+attSize <= maxUpload, nil
} }
func (store *Store) getDraftAction(message *pmapi.Message) int { func (store *Store) getDraftAction(message *pmapi.Message) int {
@ -237,39 +231,6 @@ func (store *Store) txPutMessage(metaBucket *bolt.Bucket, onlyMeta *pmapi.Messag
return nil return nil
} }
func (store *Store) txPutBodyStructure(bsBucket *bolt.Bucket, msgID string, bs *pkgMsg.BodyStructure) error {
raw, err := bs.Serialize()
if err != nil {
return err
}
err = bsBucket.Put([]byte(msgID), raw)
if err != nil {
return errors.Wrap(err, "cannot put bodystructure bucket")
}
return nil
}
func (store *Store) txGetBodyStructure(bsBucket *bolt.Bucket, msgID string) (*pkgMsg.BodyStructure, error) {
raw := bsBucket.Get([]byte(msgID))
if len(raw) == 0 {
return nil, nil
}
return pkgMsg.DeserializeBodyStructure(raw)
}
func (store *Store) txIncreaseMsgBuildCount(b *bolt.Bucket, msgID string) (uint32, error) {
key := []byte(msgID)
count := uint32(0)
raw := b.Get(key)
if raw != nil {
count = btoi(raw)
}
count++
return count, b.Put(key, itob(count))
}
// createOrUpdateMessageEvent is helper to create only one message with // createOrUpdateMessageEvent is helper to create only one message with
// createOrUpdateMessagesEvent. // createOrUpdateMessagesEvent.
func (store *Store) createOrUpdateMessageEvent(msg *pmapi.Message) error { func (store *Store) createOrUpdateMessageEvent(msg *pmapi.Message) error {
@ -287,7 +248,7 @@ func (store *Store) createOrUpdateMessagesEvent(msgs []*pmapi.Message) error { /
b := tx.Bucket(metadataBucket) b := tx.Bucket(metadataBucket)
for _, msg := range msgs { for _, msg := range msgs {
clearNonMetadata(msg) clearNonMetadata(msg)
txUpdateMetadaFromDB(b, msg, store.log) txUpdateMetadataFromDB(b, msg, store.log)
} }
return nil return nil
}) })
@ -341,6 +302,11 @@ func (store *Store) createOrUpdateMessagesEvent(msgs []*pmapi.Message) error { /
return err return err
} }
// Notify the cacher that it should start caching messages.
for _, msg := range msgs {
store.cacher.newJob(msg.ID)
}
return nil return nil
} }
@ -351,16 +317,12 @@ func clearNonMetadata(onlyMeta *pmapi.Message) {
onlyMeta.Attachments = nil onlyMeta.Attachments = nil
} }
// txUpdateMetadaFromDB changes the the onlyMeta data. // txUpdateMetadataFromDB changes the the onlyMeta data.
// If there is stored message in metaBucket the size, header and MIMEType are // If there is stored message in metaBucket the size, header and MIMEType are
// not changed if already set. To change these: // not changed if already set. To change these:
// * size must be updated by Message.SetSize // * size must be updated by Message.SetSize
// * contentType and header must be updated by Message.SetContentTypeAndHeader. // * contentType and header must be updated by Message.SetContentTypeAndHeader.
func txUpdateMetadaFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log *logrus.Entry) { func txUpdateMetadataFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log *logrus.Entry) {
// Size attribute on the server is counting encrypted data. We need to compute
// "real" size of decrypted data. Negative values will be processed during fetch.
onlyMeta.Size = -1
msgb := metaBucket.Get([]byte(onlyMeta.ID)) msgb := metaBucket.Get([]byte(onlyMeta.ID))
if msgb == nil { if msgb == nil {
return return
@ -378,8 +340,7 @@ func txUpdateMetadaFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log
return return
} }
// Keep already calculated size and content type. // Keep content type.
onlyMeta.Size = stored.Size
onlyMeta.MIMEType = stored.MIMEType onlyMeta.MIMEType = stored.MIMEType
if stored.Header != "" && stored.Header != "(No Header)" { if stored.Header != "" && stored.Header != "(No Header)" {
tmpMsg, err := mail.ReadMessage( tmpMsg, err := mail.ReadMessage(
@ -401,6 +362,12 @@ func (store *Store) deleteMessageEvent(apiID string) error {
// deleteMessagesEvent deletes the message from metadata and all mailbox buckets. // deleteMessagesEvent deletes the message from metadata and all mailbox buckets.
func (store *Store) deleteMessagesEvent(apiIDs []string) error { func (store *Store) deleteMessagesEvent(apiIDs []string) error {
for _, messageID := range apiIDs {
if err := store.cache.Rem(store.UserID(), messageID); err != nil {
logrus.WithError(err).Error("Failed to remove message from cache")
}
}
return store.db.Update(func(tx *bolt.Tx) error { return store.db.Update(func(tx *bolt.Tx) error {
for _, apiID := range apiIDs { for _, apiID := range apiIDs {
if err := tx.Bucket(metadataBucket).Delete([]byte(apiID)); err != nil { if err := tx.Bucket(metadataBucket).Delete([]byte(apiID)); err != nil {

View File

@ -33,7 +33,7 @@ func TestGetAllMessageIDs(t *testing.T) {
m, clear := initMocks(t) m, clear := initMocks(t)
defer clear() defer clear()
m.newStoreNoEvents(true) m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
@ -47,7 +47,7 @@ func TestGetMessageFromDB(t *testing.T) {
m, clear := initMocks(t) m, clear := initMocks(t)
defer clear() defer clear()
m.newStoreNoEvents(true) m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
tests := []struct{ msgID, wantErr string }{ tests := []struct{ msgID, wantErr string }{
@ -72,7 +72,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
m, clear := initMocks(t) m, clear := initMocks(t)
defer clear() defer clear()
m.newStoreNoEvents(true) m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
msg, err := m.store.getMessageFromDB("msg1") msg, err := m.store.getMessageFromDB("msg1")
@ -81,12 +81,10 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
// Check non-meta and calculated data are cleared/empty. // Check non-meta and calculated data are cleared/empty.
a.Equal(t, "", msg.Body) a.Equal(t, "", msg.Body)
a.Equal(t, []*pmapi.Attachment(nil), msg.Attachments) a.Equal(t, []*pmapi.Attachment(nil), msg.Attachments)
a.Equal(t, int64(-1), msg.Size)
a.Equal(t, "", msg.MIMEType) a.Equal(t, "", msg.MIMEType)
a.Equal(t, make(mail.Header), msg.Header) a.Equal(t, make(mail.Header), msg.Header)
// Change the calculated data. // Change the calculated data.
wantSize := int64(42)
wantMIMEType := "plain-text" wantMIMEType := "plain-text"
wantHeader := mail.Header{ wantHeader := mail.Header{
"Key": []string{"value"}, "Key": []string{"value"},
@ -94,13 +92,11 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
storeMsg, err := m.store.addresses[addrID1].mailboxes[pmapi.AllMailLabel].GetMessage("msg1") storeMsg, err := m.store.addresses[addrID1].mailboxes[pmapi.AllMailLabel].GetMessage("msg1")
require.Nil(t, err) require.Nil(t, err)
require.Nil(t, storeMsg.SetSize(wantSize))
require.Nil(t, storeMsg.SetContentTypeAndHeader(wantMIMEType, wantHeader)) require.Nil(t, storeMsg.SetContentTypeAndHeader(wantMIMEType, wantHeader))
// Check calculated data. // Check calculated data.
msg, err = m.store.getMessageFromDB("msg1") msg, err = m.store.getMessageFromDB("msg1")
require.Nil(t, err) require.Nil(t, err)
a.Equal(t, wantSize, msg.Size)
a.Equal(t, wantMIMEType, msg.MIMEType) a.Equal(t, wantMIMEType, msg.MIMEType)
a.Equal(t, wantHeader, msg.Header) a.Equal(t, wantHeader, msg.Header)
@ -109,7 +105,6 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
msg, err = m.store.getMessageFromDB("msg1") msg, err = m.store.getMessageFromDB("msg1")
require.Nil(t, err) require.Nil(t, err)
a.Equal(t, wantSize, msg.Size)
a.Equal(t, wantMIMEType, msg.MIMEType) a.Equal(t, wantMIMEType, msg.MIMEType)
a.Equal(t, wantHeader, msg.Header) a.Equal(t, wantHeader, msg.Header)
} }
@ -118,7 +113,7 @@ func TestDeleteMessage(t *testing.T) {
m, clear := initMocks(t) m, clear := initMocks(t)
defer clear() defer clear()
m.newStoreNoEvents(true) m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
@ -129,8 +124,7 @@ func TestDeleteMessage(t *testing.T) {
} }
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread bool, labelIDs []string) { //nolint[unparam] func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread bool, labelIDs []string) { //nolint[unparam]
msg := getTestMessage(id, subject, sender, unread, labelIDs) require.Nil(t, m.store.createOrUpdateMessageEvent(getTestMessage(id, subject, sender, unread, labelIDs)))
require.Nil(t, m.store.createOrUpdateMessageEvent(msg))
} }
func getTestMessage(id, subject, sender string, unread bool, labelIDs []string) *pmapi.Message { func getTestMessage(id, subject, sender string, unread bool, labelIDs []string) *pmapi.Message {
@ -142,7 +136,6 @@ func getTestMessage(id, subject, sender string, unread bool, labelIDs []string)
Sender: address, Sender: address,
ToList: []*mail.Address{address}, ToList: []*mail.Address{address},
LabelIDs: labelIDs, LabelIDs: labelIDs,
Size: 12345,
Body: "body of message", Body: "body of message",
Attachments: []*pmapi.Attachment{{ Attachments: []*pmapi.Attachment{{
ID: "attachment1", ID: "attachment1",
@ -162,7 +155,7 @@ func TestCreateDraftCheckMessageSize(t *testing.T) {
m, clear := initMocks(t) m, clear := initMocks(t)
defer clear() defer clear()
m.newStoreNoEvents(false) m.newStoreNoEvents(t, false)
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{ m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
MaxUpload: 100, // Decrypted message 5 chars, encrypted 500+. MaxUpload: 100, // Decrypted message 5 chars, encrypted 500+.
}, nil) }, nil)
@ -181,7 +174,7 @@ func TestCreateDraftCheckMessageWithAttachmentSize(t *testing.T) {
m, clear := initMocks(t) m, clear := initMocks(t)
defer clear() defer clear()
m.newStoreNoEvents(false) m.newStoreNoEvents(t, false)
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{ m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
MaxUpload: 800, // Decrypted message 5 chars + 5 chars of attachment, encrypted 500+ + 300+. MaxUpload: 800, // Decrypted message 5 chars + 5 chars of attachment, encrypted 500+ + 300+.
}, nil) }, nil)

View File

@ -30,7 +30,7 @@ func TestLoadSaveSyncState(t *testing.T) {
m, clear := initMocks(t) m, clear := initMocks(t)
defer clear() defer clear()
m.newStoreNoEvents(true) m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})

View File

@ -107,6 +107,21 @@ func (u *User) connect(client pmapi.Client, creds *credentials.Credentials) erro
return err return err
} }
// If the client is already unlocked, we can unlock the store cache as well.
if client.IsUnlocked() {
kr, err := client.GetUserKeyRing()
if err != nil {
return err
}
if err := u.store.UnlockCache(kr); err != nil {
return err
}
// NOTE(GODT-1158): If using in-memory cache we probably shouldn't start the watcher?
u.store.StartWatcher()
}
return nil return nil
} }

View File

@ -32,7 +32,7 @@ func TestUpdateUser(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
user := testNewUser(m) user := testNewUser(t, m)
defer cleanUpUserData(user) defer cleanUpUserData(user)
gomock.InOrder( gomock.InOrder(
@ -50,7 +50,7 @@ func TestUserSwitchAddressMode(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
user := testNewUser(m) user := testNewUser(t, m)
defer cleanUpUserData(user) defer cleanUpUserData(user)
// Ignore any sync on background. // Ignore any sync on background.
@ -76,7 +76,7 @@ func TestUserSwitchAddressMode(t *testing.T) {
r.False(t, user.creds.IsCombinedAddressMode) r.False(t, user.creds.IsCombinedAddressMode)
r.False(t, user.IsCombinedAddressMode()) r.False(t, user.IsCombinedAddressMode())
// MOck change to combined mode. // Mock change to combined mode.
gomock.InOrder( gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"), m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"), m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"),
@ -98,7 +98,7 @@ func TestLogoutUser(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
user := testNewUser(m) user := testNewUser(t, m)
defer cleanUpUserData(user) defer cleanUpUserData(user)
gomock.InOrder( gomock.InOrder(
@ -115,7 +115,7 @@ func TestLogoutUserFailsLogout(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
user := testNewUser(m) user := testNewUser(t, m)
defer cleanUpUserData(user) defer cleanUpUserData(user)
gomock.InOrder( gomock.InOrder(
@ -133,7 +133,7 @@ func TestCheckBridgeLogin(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
user := testNewUser(m) user := testNewUser(t, m)
defer cleanUpUserData(user) defer cleanUpUserData(user)
err := user.CheckBridgeLogin(testCredentials.BridgePassword) err := user.CheckBridgeLogin(testCredentials.BridgePassword)
@ -144,7 +144,7 @@ func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
user := testNewUser(m) user := testNewUser(t, m)
defer cleanUpUserData(user) defer cleanUpUserData(user)
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "") m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "")
@ -187,7 +187,7 @@ func TestCheckBridgeLoginBadPassword(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
user := testNewUser(m) user := testNewUser(t, m)
defer cleanUpUserData(user) defer cleanUpUserData(user)
err := user.CheckBridgeLogin("wrong!") err := user.CheckBridgeLogin("wrong!")

View File

@ -64,7 +64,7 @@ func TestNewUser(t *testing.T) {
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
mockInitConnectedUser(m) mockInitConnectedUser(t, m)
mockEventLoopNoAction(m) mockEventLoopNoAction(m)
checkNewUserHasCredentials(m, "", testCredentials) checkNewUserHasCredentials(m, "", testCredentials)

View File

@ -31,7 +31,7 @@ func TestClearStoreWithStore(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
user := testNewUser(m) user := testNewUser(t, m)
defer cleanUpUserData(user) defer cleanUpUserData(user)
r.Nil(t, user.store.Close()) r.Nil(t, user.store.Close())
@ -43,7 +43,7 @@ func TestClearStoreWithoutStore(t *testing.T) {
m := initMocks(t) m := initMocks(t)
defer m.ctrl.Finish() defer m.ctrl.Finish()
user := testNewUser(m) user := testNewUser(t, m)
defer cleanUpUserData(user) defer cleanUpUserData(user)
r.NotNil(t, user.store) r.NotNil(t, user.store)

View File

@ -18,13 +18,15 @@
package users package users
import ( import (
"testing"
r "github.com/stretchr/testify/require" r "github.com/stretchr/testify/require"
) )
// testNewUser sets up a new, authorised user. // testNewUser sets up a new, authorised user.
func testNewUser(m mocks) *User { func testNewUser(t *testing.T, m mocks) *User {
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
mockInitConnectedUser(m) mockInitConnectedUser(t, m)
mockEventLoopNoAction(m) mockEventLoopNoAction(m)
user, creds, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker) user, creds, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker)

View File

@ -20,12 +20,13 @@ package users
import ( import (
"context" "context"
"os"
"path/filepath"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/ProtonMail/proton-bridge/internal/events" "github.com/ProtonMail/proton-bridge/internal/events"
imapcache "github.com/ProtonMail/proton-bridge/internal/imap/cache"
"github.com/ProtonMail/proton-bridge/internal/metrics" "github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/users/credentials" "github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/listener"
@ -225,6 +226,7 @@ func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password []by
return nil, errors.Wrap(err, "failed to update password of user in credentials store") return nil, errors.Wrap(err, "failed to update password of user in credentials store")
} }
// will go and unlock cache if not already done
if err := user.connect(client, creds); err != nil { if err := user.connect(client, creds); err != nil {
return nil, errors.Wrap(err, "failed to reconnect existing user") return nil, errors.Wrap(err, "failed to reconnect existing user")
} }
@ -341,9 +343,6 @@ func (u *Users) ClearData() error {
result = multierror.Append(result, err) result = multierror.Append(result, err)
} }
// Need to clear imap cache otherwise fetch response will be remembered from previous test.
imapcache.Clear()
return result return result
} }
@ -366,6 +365,7 @@ func (u *Users) DeleteUser(userID string, clearStore bool) error {
if err := user.closeStore(); err != nil { if err := user.closeStore(); err != nil {
log.WithError(err).Error("Failed to close user store") log.WithError(err).Error("Failed to close user store")
} }
if clearStore { if clearStore {
// Clear cache after closing connections (done in logout). // Clear cache after closing connections (done in logout).
if err := user.clearStore(); err != nil { if err := user.clearStore(); err != nil {
@ -427,6 +427,41 @@ func (u *Users) DisallowProxy() {
u.clientManager.DisallowProxy() u.clientManager.DisallowProxy()
} }
func (u *Users) EnableCache() error {
// NOTE(GODT-1158): Check for available size before enabling.
return nil
}
func (u *Users) MigrateCache(from, to string) error {
// NOTE(GODT-1158): Is it enough to just close the store? Do we need to force-close the cacher too?
for _, user := range u.users {
if err := user.closeStore(); err != nil {
logrus.WithError(err).Error("Failed to close user's store")
}
}
// Ensure the parent directory exists.
if err := os.MkdirAll(filepath.Dir(to), 0700); err != nil {
return err
}
return os.Rename(from, to)
}
func (u *Users) DisableCache() error {
// NOTE(GODT-1158): Is it an error if we can't remove a user's cache?
for _, user := range u.users {
if err := user.store.RemoveCache(); err != nil {
logrus.WithError(err).Error("Failed to remove user's message cache")
}
}
return nil
}
// hasUser returns whether the struct currently has a user with ID `id`. // hasUser returns whether the struct currently has a user with ID `id`.
func (u *Users) hasUser(id string) (user *User, ok bool) { func (u *Users) hasUser(id string) (user *User, ok bool) {
for _, u := range u.users { for _, u := range u.users {

View File

@ -49,7 +49,7 @@ func TestUsersFinishLoginNewUser(t *testing.T) {
// Init users with no user from keychain. // Init users with no user from keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil) m.credentialsStore.EXPECT().List().Return([]string{}, nil)
mockAddingConnectedUser(m) mockAddingConnectedUser(t, m)
mockEventLoopNoAction(m) mockEventLoopNoAction(m)
m.clientManager.EXPECT().SendSimpleMetric(gomock.Any(), string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel)) m.clientManager.EXPECT().SendSimpleMetric(gomock.Any(), string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel))
@ -74,7 +74,7 @@ func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
m.credentialsStore.EXPECT().UpdateToken(testCredentialsDisconnected.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil), m.credentialsStore.EXPECT().UpdateToken(testCredentialsDisconnected.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil),
m.credentialsStore.EXPECT().UpdatePassword(testCredentialsDisconnected.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil), m.credentialsStore.EXPECT().UpdatePassword(testCredentialsDisconnected.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil),
) )
mockInitConnectedUser(m) mockInitConnectedUser(t, m)
mockEventLoopNoAction(m) mockEventLoopNoAction(m)
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentialsDisconnected.UserID) m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentialsDisconnected.UserID)
@ -95,7 +95,7 @@ func TestUsersFinishLoginConnectedUser(t *testing.T) {
// Mock loading connected user. // Mock loading connected user.
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil) m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
mockLoadingConnectedUser(m, testCredentials) mockLoadingConnectedUser(t, m, testCredentials)
mockEventLoopNoAction(m) mockEventLoopNoAction(m)
// Mock process of FinishLogin of already connected user. // Mock process of FinishLogin of already connected user.

View File

@ -49,7 +49,7 @@ func TestNewUsersWithConnectedUser(t *testing.T) {
defer m.ctrl.Finish() defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil) m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
mockLoadingConnectedUser(m, testCredentials) mockLoadingConnectedUser(t, m, testCredentials)
mockEventLoopNoAction(m) mockEventLoopNoAction(m)
checkUsersNew(t, m, []*credentials.Credentials{testCredentials}) checkUsersNew(t, m, []*credentials.Credentials{testCredentials})
} }
@ -71,7 +71,7 @@ func TestNewUsersWithUsers(t *testing.T) {
m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID, testCredentials.UserID}, nil) m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID, testCredentials.UserID}, nil)
mockLoadingDisconnectedUser(m, testCredentialsDisconnected) mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
mockLoadingConnectedUser(m, testCredentials) mockLoadingConnectedUser(t, m, testCredentials)
mockEventLoopNoAction(m) mockEventLoopNoAction(m)
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials}) checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials})
} }

View File

@ -21,6 +21,7 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"runtime"
"runtime/debug" "runtime/debug"
"testing" "testing"
"time" "time"
@ -28,10 +29,13 @@ import (
"github.com/ProtonMail/proton-bridge/internal/events" "github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/sentry" "github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store" "github.com/ProtonMail/proton-bridge/internal/store"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/users/credentials" "github.com/ProtonMail/proton-bridge/internal/users/credentials"
usersmocks "github.com/ProtonMail/proton-bridge/internal/users/mocks" usersmocks "github.com/ProtonMail/proton-bridge/internal/users/mocks"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks" pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
tests "github.com/ProtonMail/proton-bridge/test"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -42,9 +46,11 @@ func TestMain(m *testing.M) {
if os.Getenv("VERBOSITY") == "fatal" { if os.Getenv("VERBOSITY") == "fatal" {
logrus.SetLevel(logrus.FatalLevel) logrus.SetLevel(logrus.FatalLevel)
} }
if os.Getenv("VERBOSITY") == "trace" { if os.Getenv("VERBOSITY") == "trace" {
logrus.SetLevel(logrus.TraceLevel) logrus.SetLevel(logrus.TraceLevel)
} }
os.Exit(m.Run()) os.Exit(m.Run())
} }
@ -151,7 +157,7 @@ type mocks struct {
clientManager *pmapimocks.MockManager clientManager *pmapimocks.MockManager
pmapiClient *pmapimocks.MockClient pmapiClient *pmapimocks.MockClient
storeCache *store.Cache storeCache *store.Events
} }
func initMocks(t *testing.T) mocks { func initMocks(t *testing.T) mocks {
@ -178,7 +184,7 @@ func initMocks(t *testing.T) mocks {
clientManager: pmapimocks.NewMockManager(mockCtrl), clientManager: pmapimocks.NewMockManager(mockCtrl),
pmapiClient: pmapimocks.NewMockClient(mockCtrl), pmapiClient: pmapimocks.NewMockClient(mockCtrl),
storeCache: store.NewCache(cacheFile.Name()), storeCache: store.NewEvents(cacheFile.Name()),
} }
// Called during clean-up. // Called during clean-up.
@ -187,9 +193,20 @@ func initMocks(t *testing.T) mocks {
// Set up store factory. // Set up store factory.
m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) { m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) {
var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests. var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
dbFile, err := ioutil.TempFile(t.TempDir(), "bridge-store-db-*.db")
r.NoError(t, err, "could not get temporary file for store db") r.NoError(t, err, "could not get temporary file for store db")
return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
return store.New(
sentryReporter,
m.PanicHandler,
user,
m.eventListener,
cache.NewInMemoryCache(1<<20),
message.NewBuilder(runtime.NumCPU(), runtime.NumCPU()),
dbFile.Name(),
m.storeCache,
)
}).AnyTimes() }).AnyTimes()
m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes() m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()
@ -212,8 +229,8 @@ func (fr *fullStackReporter) Fatalf(format string, args ...interface{}) {
func testNewUsersWithUsers(t *testing.T, m mocks) *Users { func testNewUsersWithUsers(t *testing.T, m mocks) *Users {
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID, testCredentialsSplit.UserID}, nil) m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID, testCredentialsSplit.UserID}, nil)
mockLoadingConnectedUser(m, testCredentials) mockLoadingConnectedUser(t, m, testCredentials)
mockLoadingConnectedUser(m, testCredentialsSplit) mockLoadingConnectedUser(t, m, testCredentialsSplit)
mockEventLoopNoAction(m) mockEventLoopNoAction(m)
return testNewUsers(t, m) return testNewUsers(t, m)
@ -245,7 +262,7 @@ func cleanUpUsersData(b *Users) {
} }
} }
func mockAddingConnectedUser(m mocks) { func mockAddingConnectedUser(t *testing.T, m mocks) {
gomock.InOrder( gomock.InOrder(
// Mock of users.FinishLogin. // Mock of users.FinishLogin.
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil), m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
@ -256,10 +273,10 @@ func mockAddingConnectedUser(m mocks) {
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
) )
mockInitConnectedUser(m) mockInitConnectedUser(t, m)
} }
func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) { func mockLoadingConnectedUser(t *testing.T, m mocks, creds *credentials.Credentials) {
authRefresh := &pmapi.AuthRefresh{ authRefresh := &pmapi.AuthRefresh{
UID: "uid", UID: "uid",
AccessToken: "acc", AccessToken: "acc",
@ -273,10 +290,10 @@ func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) {
m.credentialsStore.EXPECT().UpdateToken(creds.UserID, authRefresh.UID, authRefresh.RefreshToken).Return(creds, nil), m.credentialsStore.EXPECT().UpdateToken(creds.UserID, authRefresh.UID, authRefresh.RefreshToken).Return(creds, nil),
) )
mockInitConnectedUser(m) mockInitConnectedUser(t, m)
} }
func mockInitConnectedUser(m mocks) { func mockInitConnectedUser(t *testing.T, m mocks) {
// Mock of user initialisation. // Mock of user initialisation.
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()) m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any())
m.pmapiClient.EXPECT().IsUnlocked().Return(true).AnyTimes() m.pmapiClient.EXPECT().IsUnlocked().Return(true).AnyTimes()
@ -286,6 +303,7 @@ func mockInitConnectedUser(m mocks) {
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil), m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil), m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.pmapiClient.EXPECT().GetUserKeyRing().Return(tests.MakeKeyRing(t), nil).AnyTimes(),
) )
} }

View File

@ -20,10 +20,12 @@ package message
import ( import (
"context" "context"
"io" "io"
"io/ioutil"
"sync" "sync"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/pkg/pool"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -32,11 +34,15 @@ var (
ErrNoSuchKeyRing = errors.New("the keyring to decrypt this message could not be found") ErrNoSuchKeyRing = errors.New("the keyring to decrypt this message could not be found")
) )
const (
BackgroundPriority = 1 << iota
ForegroundPriority
)
type Builder struct { type Builder struct {
reqs chan fetchReq pool *pool.Pool
done chan struct{} jobs map[string]*Job
jobs map[string]*BuildJob lock sync.Mutex
locker sync.Mutex
} }
type Fetcher interface { type Fetcher interface {
@ -48,111 +54,159 @@ type Fetcher interface {
// NewBuilder creates a new builder which manages the given number of fetch/attach/build workers. // NewBuilder creates a new builder which manages the given number of fetch/attach/build workers.
// - fetchWorkers: the number of workers which fetch messages from API // - fetchWorkers: the number of workers which fetch messages from API
// - attachWorkers: the number of workers which fetch attachments from API. // - attachWorkers: the number of workers which fetch attachments from API.
// - buildWorkers: the number of workers which decrypt/build RFC822 message literals.
//
// NOTE: Each fetch worker spawns a unique set of attachment workers!
// There can therefore be up to fetchWorkers*attachWorkers simultaneous API connections.
// //
// The returned builder is ready to handle jobs -- see (*Builder).NewJob for more information. // The returned builder is ready to handle jobs -- see (*Builder).NewJob for more information.
// //
// Call (*Builder).Done to shut down the builder and stop all workers. // Call (*Builder).Done to shut down the builder and stop all workers.
func NewBuilder(fetchWorkers, attachWorkers, buildWorkers int) *Builder { func NewBuilder(fetchWorkers, attachWorkers int) *Builder {
b := newBuilder() attacherPool := pool.New(attachWorkers, newAttacherWorkFunc())
fetchReqCh, fetchResCh := startFetchWorkers(fetchWorkers, attachWorkers) fetcherPool := pool.New(fetchWorkers, newFetcherWorkFunc(attacherPool))
buildReqCh, buildResCh := startBuildWorkers(buildWorkers)
go func() {
defer close(fetchReqCh)
for {
select {
case req := <-b.reqs:
fetchReqCh <- req
case <-b.done:
return
}
}
}()
go func() {
defer close(buildReqCh)
for res := range fetchResCh {
if res.err != nil {
b.jobFailure(res.messageID, res.err)
} else {
buildReqCh <- res
}
}
}()
go func() {
for res := range buildResCh {
if res.err != nil {
b.jobFailure(res.messageID, res.err)
} else {
b.jobSuccess(res.messageID, res.literal)
}
}
}()
return b
}
func newBuilder() *Builder {
return &Builder{ return &Builder{
reqs: make(chan fetchReq), pool: fetcherPool,
done: make(chan struct{}), jobs: make(map[string]*Job),
jobs: make(map[string]*BuildJob),
} }
} }
// NewJob tells the builder to begin building the message with the given ID. func (builder *Builder) NewJob(ctx context.Context, fetcher Fetcher, messageID string, prio int) (*Job, pool.DoneFunc) {
// The result (or any error which occurred during building) can be retrieved from the returned job when available. return builder.NewJobWithOptions(ctx, fetcher, messageID, JobOptions{}, prio)
func (b *Builder) NewJob(ctx context.Context, api Fetcher, messageID string) *BuildJob {
return b.NewJobWithOptions(ctx, api, messageID, JobOptions{})
} }
// NewJobWithOptions creates a new job with custom options. See NewJob for more information. func (builder *Builder) NewJobWithOptions(ctx context.Context, fetcher Fetcher, messageID string, opts JobOptions, prio int) (*Job, pool.DoneFunc) {
func (b *Builder) NewJobWithOptions(ctx context.Context, api Fetcher, messageID string, opts JobOptions) *BuildJob { builder.lock.Lock()
b.locker.Lock() defer builder.lock.Unlock()
defer b.locker.Unlock()
if job, ok := b.jobs[messageID]; ok { if job, ok := builder.jobs[messageID]; ok {
return job if job.GetPriority() < prio {
job.SetPriority(prio)
}
return job, job.done
} }
b.jobs[messageID] = newBuildJob(messageID) job, done := builder.pool.NewJob(
&fetchReq{
fetcher: fetcher,
messageID: messageID,
options: opts,
},
prio,
)
go func() { b.reqs <- fetchReq{ctx: ctx, api: api, messageID: messageID, opts: opts} }() buildJob := &Job{
Job: job,
done: done,
}
return b.jobs[messageID] builder.jobs[messageID] = buildJob
return buildJob, func() {
builder.lock.Lock()
defer builder.lock.Unlock()
// Remove the job from the builder.
delete(builder.jobs, messageID)
// And mark it as done.
done()
}
} }
// Done shuts down the builder and stops all workers. func (builder *Builder) Done() {
func (b *Builder) Done() { // NOTE(GODT-1158): Stop worker pool.
b.locker.Lock()
defer b.locker.Unlock()
close(b.done)
} }
func (b *Builder) jobSuccess(messageID string, literal []byte) { type fetchReq struct {
b.locker.Lock() fetcher Fetcher
defer b.locker.Unlock() messageID string
options JobOptions
b.jobs[messageID].postSuccess(literal)
delete(b.jobs, messageID)
} }
func (b *Builder) jobFailure(messageID string, err error) { type attachReq struct {
b.locker.Lock() fetcher Fetcher
defer b.locker.Unlock() message *pmapi.Message
}
b.jobs[messageID].postFailure(err)
type Job struct {
delete(b.jobs, messageID) *pool.Job
done pool.DoneFunc
}
func (job *Job) GetResult() ([]byte, error) {
res, err := job.Job.GetResult()
if err != nil {
return nil, err
}
return res.([]byte), nil
}
func newAttacherWorkFunc() pool.WorkFunc {
return func(payload interface{}, prio int) (interface{}, error) {
req, ok := payload.(*attachReq)
if !ok {
panic("bad payload type")
}
res := make(map[string][]byte)
for _, att := range req.message.Attachments {
rc, err := req.fetcher.GetAttachment(context.Background(), att.ID)
if err != nil {
return nil, err
}
b, err := ioutil.ReadAll(rc)
if err != nil {
return nil, err
}
if err := rc.Close(); err != nil {
return nil, err
}
res[att.ID] = b
}
return res, nil
}
}
func newFetcherWorkFunc(attacherPool *pool.Pool) pool.WorkFunc {
return func(payload interface{}, prio int) (interface{}, error) {
req, ok := payload.(*fetchReq)
if !ok {
panic("bad payload type")
}
msg, err := req.fetcher.GetMessage(context.Background(), req.messageID)
if err != nil {
return nil, err
}
attJob, attDone := attacherPool.NewJob(&attachReq{
fetcher: req.fetcher,
message: msg,
}, prio)
defer attDone()
val, err := attJob.GetResult()
if err != nil {
return nil, err
}
attData, ok := val.(map[string][]byte)
if !ok {
panic("bad response type")
}
kr, err := req.fetcher.KeyRingForAddressID(msg.AddressID)
if err != nil {
return nil, ErrNoSuchKeyRing
}
return buildRFC822(kr, msg, attData, req.options)
}
} }

View File

@ -1,89 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"sync"
"github.com/pkg/errors"
)
type buildRes struct {
messageID string
literal []byte
err error
}
func newBuildResSuccess(messageID string, literal []byte) buildRes {
return buildRes{
messageID: messageID,
literal: literal,
}
}
func newBuildResFailure(messageID string, err error) buildRes {
return buildRes{
messageID: messageID,
err: err,
}
}
// startBuildWorkers starts the given number of build workers.
// These workers decrypt and build messages into RFC822 literals.
// Two channels are returned:
// - buildReqCh: used to send work items to the worker pool
// - buildResCh: used to receive work results from the worker pool
func startBuildWorkers(buildWorkers int) (chan fetchRes, chan buildRes) {
buildReqCh := make(chan fetchRes)
buildResCh := make(chan buildRes)
go func() {
defer close(buildResCh)
var wg sync.WaitGroup
wg.Add(buildWorkers)
for workerID := 0; workerID < buildWorkers; workerID++ {
go buildWorker(buildReqCh, buildResCh, &wg)
}
wg.Wait()
}()
return buildReqCh, buildResCh
}
func buildWorker(buildReqCh <-chan fetchRes, buildResCh chan<- buildRes, wg *sync.WaitGroup) {
defer wg.Done()
for req := range buildReqCh {
l := log.
WithField("addrID", req.msg.AddressID).
WithField("msgID", req.msg.ID)
if kr, err := req.api.KeyRingForAddressID(req.msg.AddressID); err != nil {
l.WithError(err).Warn("Cannot find keyring for address")
buildResCh <- newBuildResFailure(req.msg.ID, errors.Wrap(ErrNoSuchKeyRing, err.Error()))
} else if literal, err := buildRFC822(kr, req.msg, req.atts, req.opts); err != nil {
l.WithError(err).Warn("Build failed")
buildResCh <- newBuildResFailure(req.msg.ID, err)
} else {
buildResCh <- newBuildResSuccess(req.msg.ID, literal)
}
}
}

View File

@ -1,141 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"context"
"io/ioutil"
"sync"
"github.com/ProtonMail/proton-bridge/pkg/parallel"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
type fetchReq struct {
ctx context.Context
api Fetcher
messageID string
opts JobOptions
}
type fetchRes struct {
fetchReq
msg *pmapi.Message
atts [][]byte
err error
}
func newFetchResSuccess(req fetchReq, msg *pmapi.Message, atts [][]byte) fetchRes {
return fetchRes{
fetchReq: req,
msg: msg,
atts: atts,
}
}
func newFetchResFailure(req fetchReq, err error) fetchRes {
return fetchRes{
fetchReq: req,
err: err,
}
}
// startFetchWorkers starts the given number of fetch workers.
// These workers download message and attachment data from API.
// Each fetch worker will use up to the given number of attachment workers to download attachments.
// Two channels are returned:
// - fetchReqCh: used to send work items to the worker pool
// - fetchResCh: used to receive work results from the worker pool
func startFetchWorkers(fetchWorkers, attachWorkers int) (chan fetchReq, chan fetchRes) {
fetchReqCh := make(chan fetchReq)
fetchResCh := make(chan fetchRes)
go func() {
defer close(fetchResCh)
var wg sync.WaitGroup
wg.Add(fetchWorkers)
for workerID := 0; workerID < fetchWorkers; workerID++ {
go fetchWorker(fetchReqCh, fetchResCh, attachWorkers, &wg)
}
wg.Wait()
}()
return fetchReqCh, fetchResCh
}
func fetchWorker(fetchReqCh <-chan fetchReq, fetchResCh chan<- fetchRes, attachWorkers int, wg *sync.WaitGroup) {
defer wg.Done()
for req := range fetchReqCh {
msg, atts, err := fetchMessage(req, attachWorkers)
if err != nil {
fetchResCh <- newFetchResFailure(req, err)
} else {
fetchResCh <- newFetchResSuccess(req, msg, atts)
}
}
}
func fetchMessage(req fetchReq, attachWorkers int) (*pmapi.Message, [][]byte, error) {
msg, err := req.api.GetMessage(req.ctx, req.messageID)
if err != nil {
return nil, nil, err
}
attList := make([]interface{}, len(msg.Attachments))
for i, att := range msg.Attachments {
attList[i] = att.ID
}
process := func(value interface{}) (interface{}, error) {
rc, err := req.api.GetAttachment(req.ctx, value.(string))
if err != nil {
return nil, err
}
b, err := ioutil.ReadAll(rc)
if err != nil {
return nil, err
}
if err := rc.Close(); err != nil {
return nil, err
}
return b, nil
}
attData := make([][]byte, len(msg.Attachments))
collect := func(idx int, value interface{}) error {
attData[idx] = value.([]byte) //nolint[forcetypeassert] we wan't to panic here
return nil
}
if err := parallel.RunParallel(attachWorkers, attList, process, collect); err != nil {
return nil, nil, err
}
return msg, attData, nil
}

View File

@ -25,35 +25,3 @@ type JobOptions struct {
AddMessageDate bool // Whether to include message time as X-Pm-Date. AddMessageDate bool // Whether to include message time as X-Pm-Date.
AddMessageIDReference bool // Whether to include the MessageID in References. AddMessageIDReference bool // Whether to include the MessageID in References.
} }
type BuildJob struct {
messageID string
literal []byte
err error
done chan struct{}
}
func newBuildJob(messageID string) *BuildJob {
return &BuildJob{
messageID: messageID,
done: make(chan struct{}),
}
}
// GetResult returns the build result or any error which occurred during building.
// If the result is not ready yet, it blocks.
func (job *BuildJob) GetResult() ([]byte, error) {
<-job.done
return job.literal, job.err
}
func (job *BuildJob) postSuccess(literal []byte) {
job.literal = literal
close(job.done)
}
func (job *BuildJob) postFailure(err error) {
job.err = err
close(job.done)
}

View File

@ -34,7 +34,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func buildRFC822(kr *crypto.KeyRing, msg *pmapi.Message, attData [][]byte, opts JobOptions) ([]byte, error) { func buildRFC822(kr *crypto.KeyRing, msg *pmapi.Message, attData map[string][]byte, opts JobOptions) ([]byte, error) {
switch { switch {
case len(msg.Attachments) > 0: case len(msg.Attachments) > 0:
return buildMultipartRFC822(kr, msg, attData, opts) return buildMultipartRFC822(kr, msg, attData, opts)
@ -80,7 +80,7 @@ func buildSimpleRFC822(kr *crypto.KeyRing, msg *pmapi.Message, opts JobOptions)
func buildMultipartRFC822( func buildMultipartRFC822(
kr *crypto.KeyRing, kr *crypto.KeyRing,
msg *pmapi.Message, msg *pmapi.Message,
attData [][]byte, attData map[string][]byte,
opts JobOptions, opts JobOptions,
) ([]byte, error) { ) ([]byte, error) {
boundary := newBoundary(msg.ID) boundary := newBoundary(msg.ID)
@ -103,13 +103,13 @@ func buildMultipartRFC822(
attachData [][]byte attachData [][]byte
) )
for i, att := range msg.Attachments { for _, att := range msg.Attachments {
if att.Disposition == pmapi.DispositionInline { if att.Disposition == pmapi.DispositionInline {
inlineAtts = append(inlineAtts, att) inlineAtts = append(inlineAtts, att)
inlineData = append(inlineData, attData[i]) inlineData = append(inlineData, attData[att.ID])
} else { } else {
attachAtts = append(attachAtts, att) attachAtts = append(attachAtts, att)
attachData = append(attachData, attData[i]) attachData = append(attachData, attData[att.ID])
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -38,9 +38,10 @@ type BodyStructure map[string]*SectionInfo
// SectionInfo is used to hold data about parts of each section. // SectionInfo is used to hold data about parts of each section.
type SectionInfo struct { type SectionInfo struct {
Header textproto.MIMEHeader Header []byte
Start, BSize, Size, Lines int Start, BSize, Size, Lines int
reader io.Reader reader io.Reader
isHeaderReadFinished bool
} }
// Read will also count the final size of section. // Read will also count the final size of section.
@ -48,9 +49,38 @@ func (si *SectionInfo) Read(p []byte) (n int, err error) {
n, err = si.reader.Read(p) n, err = si.reader.Read(p)
si.Size += n si.Size += n
si.Lines += bytes.Count(p, []byte("\n")) si.Lines += bytes.Count(p, []byte("\n"))
si.readHeader(p)
return return
} }
// readHeader appends read data to Header until empty line is found.
func (si *SectionInfo) readHeader(p []byte) {
if si.isHeaderReadFinished {
return
}
si.Header = append(si.Header, p...)
if i := bytes.Index(si.Header, []byte("\n\r\n")); i > 0 {
si.Header = si.Header[:i+3]
si.isHeaderReadFinished = true
return
}
// textproto works also with simple line ending so we should be liberal
// as well.
if i := bytes.Index(si.Header, []byte("\n\n")); i > 0 {
si.Header = si.Header[:i+2]
si.isHeaderReadFinished = true
}
}
// GetMIMEHeader parses bytes and return MIME header.
func (si *SectionInfo) GetMIMEHeader() (textproto.MIMEHeader, error) {
return textproto.NewReader(bufio.NewReader(bytes.NewReader(si.Header))).ReadMIMEHeader()
}
func NewBodyStructure(reader io.Reader) (structure *BodyStructure, err error) { func NewBodyStructure(reader io.Reader) (structure *BodyStructure, err error) {
structure = &BodyStructure{} structure = &BodyStructure{}
err = structure.Parse(reader) err = structure.Parse(reader)
@ -93,14 +123,15 @@ func (bs *BodyStructure) parseAllChildSections(r io.Reader, currentPath []int, s
bufInfo := bufio.NewReader(info) bufInfo := bufio.NewReader(info)
tp := textproto.NewReader(bufInfo) tp := textproto.NewReader(bufInfo)
if info.Header, err = tp.ReadMIMEHeader(); err != nil { tpHeader, err := tp.ReadMIMEHeader()
if err != nil {
return return
} }
bodyInfo := &SectionInfo{reader: tp.R} bodyInfo := &SectionInfo{reader: tp.R}
bodyReader := bufio.NewReader(bodyInfo) bodyReader := bufio.NewReader(bodyInfo)
mediaType, params, _ := pmmime.ParseMediaType(info.Header.Get("Content-Type")) mediaType, params, _ := pmmime.ParseMediaType(tpHeader.Get("Content-Type"))
// If multipart, call getAllParts, else read to count lines. // If multipart, call getAllParts, else read to count lines.
if (strings.HasPrefix(mediaType, "multipart/") || mediaType == rfc822Message) && params["boundary"] != "" { if (strings.HasPrefix(mediaType, "multipart/") || mediaType == rfc822Message) && params["boundary"] != "" {
@ -260,9 +291,9 @@ func (bs *BodyStructure) GetMailHeader() (header textproto.MIMEHeader, err error
} }
// GetMailHeaderBytes returns the bytes with main mail header. // GetMailHeaderBytes returns the bytes with main mail header.
// Warning: It can contain extra lines or multipart comment. // Warning: It can contain extra lines.
func (bs *BodyStructure) GetMailHeaderBytes(wholeMail io.ReadSeeker) (header []byte, err error) { func (bs *BodyStructure) GetMailHeaderBytes() (header []byte, err error) {
return bs.GetSectionHeaderBytes(wholeMail, []int{}) return bs.GetSectionHeaderBytes([]int{})
} }
func goToOffsetAndReadNBytes(wholeMail io.ReadSeeker, offset, length int) ([]byte, error) { func goToOffsetAndReadNBytes(wholeMail io.ReadSeeker, offset, length int) ([]byte, error) {
@ -283,22 +314,21 @@ func goToOffsetAndReadNBytes(wholeMail io.ReadSeeker, offset, length int) ([]byt
} }
// GetSectionHeader returns the mime header of specified section. // GetSectionHeader returns the mime header of specified section.
func (bs *BodyStructure) GetSectionHeader(sectionPath []int) (header textproto.MIMEHeader, err error) { func (bs *BodyStructure) GetSectionHeader(sectionPath []int) (textproto.MIMEHeader, error) {
info, err := bs.getInfoCheckSection(sectionPath) info, err := bs.getInfoCheckSection(sectionPath)
if err != nil { if err != nil {
return return nil, err
} }
header = info.Header return info.GetMIMEHeader()
return
} }
func (bs *BodyStructure) GetSectionHeaderBytes(wholeMail io.ReadSeeker, sectionPath []int) (header []byte, err error) { // GetSectionHeaderBytes returns raw header bytes of specified section.
func (bs *BodyStructure) GetSectionHeaderBytes(sectionPath []int) ([]byte, error) {
info, err := bs.getInfoCheckSection(sectionPath) info, err := bs.getInfoCheckSection(sectionPath)
if err != nil { if err != nil {
return return nil, err
} }
headerLength := info.Size - info.BSize return info.Header, nil
return goToOffsetAndReadNBytes(wholeMail, info.Start, headerLength)
} }
// IMAPBodyStructure will prepare imap bodystructure recurently for given part. // IMAPBodyStructure will prepare imap bodystructure recurently for given part.
@ -309,7 +339,12 @@ func (bs *BodyStructure) IMAPBodyStructure(currentPart []int) (imapBS *imap.Body
return return
} }
mediaType, params, _ := pmmime.ParseMediaType(info.Header.Get("Content-Type")) tpHeader, err := info.GetMIMEHeader()
if err != nil {
return
}
mediaType, params, _ := pmmime.ParseMediaType(tpHeader.Get("Content-Type"))
mediaTypeSep := strings.Split(mediaType, "/") mediaTypeSep := strings.Split(mediaType, "/")
@ -324,19 +359,19 @@ func (bs *BodyStructure) IMAPBodyStructure(currentPart []int) (imapBS *imap.Body
Lines: uint32(info.Lines), Lines: uint32(info.Lines),
} }
if val := info.Header.Get("Content-ID"); val != "" { if val := tpHeader.Get("Content-ID"); val != "" {
imapBS.Id = val imapBS.Id = val
} }
if val := info.Header.Get("Content-Transfer-Encoding"); val != "" { if val := tpHeader.Get("Content-Transfer-Encoding"); val != "" {
imapBS.Encoding = val imapBS.Encoding = val
} }
if val := info.Header.Get("Content-Description"); val != "" { if val := tpHeader.Get("Content-Description"); val != "" {
imapBS.Description = val imapBS.Description = val
} }
if val := info.Header.Get("Content-Disposition"); val != "" { if val := tpHeader.Get("Content-Disposition"); val != "" {
imapBS.Disposition = val imapBS.Disposition = val
} }

View File

@ -21,7 +21,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/textproto"
"path/filepath" "path/filepath"
"runtime" "runtime"
"sort" "sort"
@ -71,7 +70,9 @@ func TestParseBodyStructure(t *testing.T) {
debug("%10s: %-50s %5s %5s %5s %5s", "section", "type", "start", "size", "bsize", "lines") debug("%10s: %-50s %5s %5s %5s %5s", "section", "type", "start", "size", "bsize", "lines")
for _, path := range paths { for _, path := range paths {
sec := (*bs)[path] sec := (*bs)[path]
contentType := (*bs)[path].Header.Get("Content-Type") header, err := sec.GetMIMEHeader()
require.NoError(t, err)
contentType := header.Get("Content-Type")
debug("%10s: %-50s %5d %5d %5d %5d", path, contentType, sec.Start, sec.Size, sec.BSize, sec.Lines) debug("%10s: %-50s %5d %5d %5d %5d", path, contentType, sec.Start, sec.Size, sec.BSize, sec.Lines)
require.Equal(t, expectedStructure[path], contentType) require.Equal(t, expectedStructure[path], contentType)
} }
@ -100,7 +101,9 @@ func TestParseBodyStructurePGP(t *testing.T) {
haveStructure := map[string]string{} haveStructure := map[string]string{}
for path := range *bs { for path := range *bs {
haveStructure[path] = (*bs)[path].Header.Get("Content-Type") header, err := (*bs)[path].GetMIMEHeader()
require.NoError(t, err)
haveStructure[path] = header.Get("Content-Type")
} }
require.Equal(t, expectedStructure, haveStructure) require.Equal(t, expectedStructure, haveStructure)
@ -192,7 +195,7 @@ Content-Type: plain/text
r.NoError(err, debug(wantPath, info, haveBody)) r.NoError(err, debug(wantPath, info, haveBody))
r.Equal(wantBody, string(haveBody), debug(wantPath, info, haveBody)) r.Equal(wantBody, string(haveBody), debug(wantPath, info, haveBody))
haveHeader, err := bs.GetSectionHeaderBytes(strings.NewReader(wantMail), wantPath) haveHeader, err := bs.GetSectionHeaderBytes(wantPath)
r.NoError(err, debug(wantPath, info, haveHeader)) r.NoError(err, debug(wantPath, info, haveHeader))
r.Equal(wantHeader, string(haveHeader), debug(wantPath, info, haveHeader)) r.Equal(wantHeader, string(haveHeader), debug(wantPath, info, haveHeader))
} }
@ -211,7 +214,7 @@ Content-Type: multipart/mixed; boundary="0000MAIN"
bs, err := NewBodyStructure(structReader) bs, err := NewBodyStructure(structReader)
require.NoError(t, err) require.NoError(t, err)
haveHeader, err := bs.GetMailHeaderBytes(strings.NewReader(sampleMail)) haveHeader, err := bs.GetMailHeaderBytes()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, wantHeader, haveHeader) require.Equal(t, wantHeader, haveHeader)
} }
@ -533,18 +536,14 @@ func TestBodyStructureSerialize(t *testing.T) {
r := require.New(t) r := require.New(t)
want := &BodyStructure{ want := &BodyStructure{
"1": { "1": {
Header: textproto.MIMEHeader{ Header: []byte("Content: type"),
"Content": []string{"type"}, Start: 1,
}, Size: 2,
Start: 1, BSize: 3,
Size: 2, Lines: 4,
BSize: 3,
Lines: 4,
}, },
"1.1.1": { "1.1.1": {
Header: textproto.MIMEHeader{ Header: []byte("X-Pm-Key: id"),
"X-Pm-Key": []string{"id"},
},
Start: 11, Start: 11,
Size: 12, Size: 12,
BSize: 13, BSize: 13,
@ -562,3 +561,32 @@ func TestBodyStructureSerialize(t *testing.T) {
(*want)["1.1.1"].reader = nil (*want)["1.1.1"].reader = nil
r.Equal(want, have) r.Equal(want, have)
} }
func TestSectionInfoReadHeader(t *testing.T) {
r := require.New(t)
testData := []struct {
wantHeader, mail string
}{
{
"key1: val1\nkey2: val2\n\n",
"key1: val1\nkey2: val2\n\nbody is here\n\nand it is not confused",
},
{
"key1:\n val1\n\n",
"key1:\n val1\n\nbody is here",
},
{
"key1: val1\r\nkey2: val2\r\n\r\n",
"key1: val1\r\nkey2: val2\r\n\r\nbody is here\r\n\r\nand it is not confused",
},
}
for _, td := range testData {
bs, err := NewBodyStructure(strings.NewReader(td.mail))
r.NoError(err, "case %q", td.mail)
haveHeader, err := bs.GetMailHeaderBytes()
r.NoError(err, "case %q", td.mail)
r.Equal(td.wantHeader, string(haveHeader), "case %q", td.mail)
}
}

131
pkg/pchan/pchan.go Normal file
View File

@ -0,0 +1,131 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pchan
import (
"sort"
"sync"
)
type PChan struct {
lock sync.Mutex
items []*Item
ready, done chan struct{}
}
type Item struct {
ch *PChan
val interface{}
prio int
done chan struct{}
}
func (item *Item) Wait() {
<-item.done
}
func (item *Item) GetPriority() int {
item.ch.lock.Lock()
defer item.ch.lock.Unlock()
return item.prio
}
func (item *Item) SetPriority(priority int) {
item.ch.lock.Lock()
defer item.ch.lock.Unlock()
item.prio = priority
sort.Slice(item.ch.items, func(i, j int) bool {
return item.ch.items[i].prio < item.ch.items[j].prio
})
}
func New() *PChan {
return &PChan{
ready: make(chan struct{}),
done: make(chan struct{}),
}
}
func (ch *PChan) Push(val interface{}, prio int) *Item {
defer ch.notify()
return ch.push(val, prio)
}
func (ch *PChan) Pop() (interface{}, int, bool) {
select {
case <-ch.ready:
val, prio := ch.pop()
return val, prio, true
case <-ch.done:
return nil, 0, false
}
}
func (ch *PChan) Close() {
select {
case <-ch.done:
return
default:
close(ch.done)
}
}
func (ch *PChan) push(val interface{}, prio int) *Item {
ch.lock.Lock()
defer ch.lock.Unlock()
done := make(chan struct{})
item := &Item{
ch: ch,
val: val,
prio: prio,
done: done,
}
ch.items = append(ch.items, item)
return item
}
func (ch *PChan) pop() (interface{}, int) {
ch.lock.Lock()
defer ch.lock.Unlock()
sort.Slice(ch.items, func(i, j int) bool {
return ch.items[i].prio < ch.items[j].prio
})
var item *Item
item, ch.items = ch.items[len(ch.items)-1], ch.items[:len(ch.items)-1]
defer close(item.done)
return item.val, item.prio
}
func (ch *PChan) notify() {
go func() { ch.ready <- struct{}{} }()
}

123
pkg/pchan/pchan_test.go Normal file
View File

@ -0,0 +1,123 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pchan
import (
"sort"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPChanConcurrentPush(t *testing.T) {
ch := New()
var wg sync.WaitGroup
// We are going to test with 5 additional goroutines.
wg.Add(5)
// Start 5 concurrent pushes.
go func() { defer wg.Done(); ch.Push(1, 1) }()
go func() { defer wg.Done(); ch.Push(2, 2) }()
go func() { defer wg.Done(); ch.Push(3, 3) }()
go func() { defer wg.Done(); ch.Push(4, 4) }()
go func() { defer wg.Done(); ch.Push(5, 5) }()
// Wait for the items to be pushed.
wg.Wait()
// All 5 should now be ready for popping.
require.Len(t, ch.items, 5)
// They should be popped in priority order.
assert.Equal(t, 5, getValue(t, ch))
assert.Equal(t, 4, getValue(t, ch))
assert.Equal(t, 3, getValue(t, ch))
assert.Equal(t, 2, getValue(t, ch))
assert.Equal(t, 1, getValue(t, ch))
}
func TestPChanConcurrentPop(t *testing.T) {
ch := New()
var wg sync.WaitGroup
// We are going to test with 5 additional goroutines.
wg.Add(5)
// Make a list to store the results in.
var res list
// Start 5 concurrent pops; these consume any items pushed.
go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
// Push and block; items should be popped immediately by the waiting goroutines.
ch.Push(1, 1).Wait()
ch.Push(2, 2).Wait()
ch.Push(3, 3).Wait()
ch.Push(4, 4).Wait()
ch.Push(5, 5).Wait()
// Wait for all items to be popped then close the result channel.
wg.Wait()
assert.True(t, sort.IntsAreSorted(res.items))
}
func TestPChanClose(t *testing.T) {
ch := New()
go ch.Push(1, 1)
valOpen, _, okOpen := ch.Pop()
assert.True(t, okOpen)
assert.Equal(t, 1, valOpen)
ch.Close()
valClose, _, okClose := ch.Pop()
assert.False(t, okClose)
assert.Nil(t, valClose)
}
type list struct {
items []int
mut sync.Mutex
}
func (l *list) append(val int) {
l.mut.Lock()
defer l.mut.Unlock()
l.items = append(l.items, val)
}
func getValue(t *testing.T, ch *PChan) int {
val, _, ok := ch.Pop()
assert.True(t, ok)
return val.(int)
}

View File

@ -71,6 +71,7 @@ type Client interface {
GetAttachment(ctx context.Context, id string) (att io.ReadCloser, err error) GetAttachment(ctx context.Context, id string) (att io.ReadCloser, err error)
CreateAttachment(ctx context.Context, att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error) CreateAttachment(ctx context.Context, att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error)
GetUserKeyRing() (*crypto.KeyRing, error)
KeyRingForAddressID(string) (kr *crypto.KeyRing, err error) KeyRingForAddressID(string) (kr *crypto.KeyRing, err error)
GetPublicKeysForEmail(context.Context, string) ([]PublicKey, bool, error) GetPublicKeysForEmail(context.Context, string) ([]PublicKey, bool, error)
} }

View File

@ -175,7 +175,6 @@ type Message struct {
CCList []*mail.Address CCList []*mail.Address
BCCList []*mail.Address BCCList []*mail.Address
Time int64 // Unix time Time int64 // Unix time
Size int64
NumAttachments int NumAttachments int
ExpirationTime int64 // Unix time ExpirationTime int64 // Unix time
SpamScore int SpamScore int

View File

@ -362,6 +362,21 @@ func (mr *MockClientMockRecorder) GetPublicKeysForEmail(arg0, arg1 interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicKeysForEmail", reflect.TypeOf((*MockClient)(nil).GetPublicKeysForEmail), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicKeysForEmail", reflect.TypeOf((*MockClient)(nil).GetPublicKeysForEmail), arg0, arg1)
} }
// GetUserKeyRing mocks base method.
func (m *MockClient) GetUserKeyRing() (*crypto.KeyRing, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserKeyRing")
ret0, _ := ret[0].(*crypto.KeyRing)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserKeyRing indicates an expected call of GetUserKeyRing.
func (mr *MockClientMockRecorder) GetUserKeyRing() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserKeyRing", reflect.TypeOf((*MockClient)(nil).GetUserKeyRing))
}
// Import mocks base method. // Import mocks base method.
func (m *MockClient) Import(arg0 context.Context, arg1 pmapi.ImportMsgReqs) ([]*pmapi.ImportMsgRes, error) { func (m *MockClient) Import(arg0 context.Context, arg1 pmapi.ImportMsgReqs) ([]*pmapi.ImportMsgRes, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -20,6 +20,7 @@ package pmapi
import ( import (
"context" "context"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/go-resty/resty/v2" "github.com/go-resty/resty/v2"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -138,3 +139,12 @@ func (c *client) CurrentUser(ctx context.Context) (*User, error) {
return c.UpdateUser(ctx) return c.UpdateUser(ctx)
} }
// CurrentUser returns currently active user or user will be updated.
func (c *client) GetUserKeyRing() (*crypto.KeyRing, error) {
if c.userKeyRing == nil {
return nil, errors.New("user keyring is not available")
}
return c.userKeyRing, nil
}

129
pkg/pool/pool.go Normal file
View File

@ -0,0 +1,129 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pool
import "github.com/ProtonMail/proton-bridge/pkg/pchan"
type WorkFunc func(interface{}, int) (interface{}, error)
type DoneFunc func()
type Pool struct {
jobCh *pchan.PChan
}
func New(size int, work WorkFunc) *Pool {
jobCh := pchan.New()
for i := 0; i < size; i++ {
go func() {
for {
val, prio, ok := jobCh.Pop()
if !ok {
return
}
job, ok := val.(*Job)
if !ok {
panic("bad result type")
}
res, err := work(job.req, prio)
if err != nil {
job.postFailure(err)
} else {
job.postSuccess(res)
}
job.waitDone()
}
}()
}
return &Pool{jobCh: jobCh}
}
func (pool *Pool) NewJob(req interface{}, prio int) (*Job, DoneFunc) {
job := newJob(req)
job.setItem(pool.jobCh.Push(job, prio))
return job, job.markDone
}
type Job struct {
req interface{}
res interface{}
err error
item *pchan.Item
ready, done chan struct{}
}
func newJob(req interface{}) *Job {
return &Job{
req: req,
ready: make(chan struct{}),
done: make(chan struct{}),
}
}
func (job *Job) GetResult() (interface{}, error) {
<-job.ready
return job.res, job.err
}
func (job *Job) GetPriority() int {
return job.item.GetPriority()
}
func (job *Job) SetPriority(prio int) {
job.item.SetPriority(prio)
}
func (job *Job) postSuccess(res interface{}) {
defer close(job.ready)
job.res = res
}
func (job *Job) postFailure(err error) {
defer close(job.ready)
job.err = err
}
func (job *Job) setItem(item *pchan.Item) {
job.item = item
}
func (job *Job) markDone() {
select {
case <-job.done:
return
default:
close(job.done)
}
}
func (job *Job) waitDone() {
<-job.done
}

45
pkg/pool/pool_test.go Normal file
View File

@ -0,0 +1,45 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pool_test
import (
"testing"
"github.com/ProtonMail/proton-bridge/pkg/pool"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPool(t *testing.T) {
pool := pool.New(2, func(req interface{}, prio int) (interface{}, error) { return req, nil })
job1, done1 := pool.NewJob("echo", 1)
defer done1()
job2, done2 := pool.NewJob("this", 1)
defer done2()
res2, err := job2.GetResult()
require.NoError(t, err)
res1, err := job1.GetResult()
require.NoError(t, err)
assert.Equal(t, "echo", res1)
assert.Equal(t, "this", res2)
}

View File

@ -0,0 +1,53 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package semaphore
import "sync"
type Semaphore struct {
ch chan struct{}
wg sync.WaitGroup
}
func New(max int) Semaphore {
return Semaphore{ch: make(chan struct{}, max)}
}
func (sem *Semaphore) Lock() {
sem.ch <- struct{}{}
}
func (sem *Semaphore) Unlock() {
<-sem.ch
}
func (sem *Semaphore) Go(fn func()) {
sem.Lock()
sem.wg.Add(1)
go func() {
defer sem.Unlock()
defer sem.wg.Done()
fn()
}()
}
func (sem *Semaphore) Wait() {
sem.wg.Wait()
}

View File

@ -21,11 +21,14 @@ import (
"time" "time"
"github.com/ProtonMail/proton-bridge/internal/bridge" "github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/config/settings"
"github.com/ProtonMail/proton-bridge/internal/config/useragent" "github.com/ProtonMail/proton-bridge/internal/config/useragent"
"github.com/ProtonMail/proton-bridge/internal/constants" "github.com/ProtonMail/proton-bridge/internal/constants"
"github.com/ProtonMail/proton-bridge/internal/sentry" "github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/users" "github.com/ProtonMail/proton-bridge/internal/users"
"github.com/ProtonMail/proton-bridge/pkg/listener" "github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
) )
@ -61,18 +64,28 @@ func (ctx *TestContext) RestartBridge() error {
} }
// newBridgeInstance creates a new bridge instance configured to use the given config/credstore. // newBridgeInstance creates a new bridge instance configured to use the given config/credstore.
// NOTE(GODT-1158): Need some tests with on-disk cache as well! Configurable in feature file or envvar?
func newBridgeInstance( func newBridgeInstance(
t *bddT, t *bddT,
locations bridge.Locator, locations bridge.Locator,
cache bridge.Cacher, cacheProvider bridge.CacheProvider,
settings *fakeSettings, fakeSettings *fakeSettings,
credStore users.CredentialsStorer, credStore users.CredentialsStorer,
eventListener listener.Listener, eventListener listener.Listener,
clientManager pmapi.Manager, clientManager pmapi.Manager,
) *bridge.Bridge { ) *bridge.Bridge {
sentryReporter := sentry.NewReporter("bridge", constants.Version, useragent.New()) return bridge.New(
panicHandler := &panicHandler{t: t} locations,
updater := newFakeUpdater() cacheProvider,
versioner := newFakeVersioner() fakeSettings,
return bridge.New(locations, cache, settings, sentryReporter, panicHandler, eventListener, clientManager, credStore, updater, versioner) sentry.NewReporter("bridge", constants.Version, useragent.New()),
&panicHandler{t: t},
eventListener,
cache.NewInMemoryCache(100*(1<<20)),
message.NewBuilder(fakeSettings.GetInt(settings.FetchWorkers), fakeSettings.GetInt(settings.AttachmentWorkers)),
clientManager,
credStore,
newFakeUpdater(),
newFakeVersioner(),
)
} }

View File

@ -58,7 +58,7 @@ func (ctx *TestContext) withIMAPServer() {
port := ctx.settings.GetInt(settings.IMAPPortKey) port := ctx.settings.GetInt(settings.IMAPPortKey)
tls, _ := tls.New(settingsPath).GetConfig() tls, _ := tls.New(settingsPath).GetConfig()
backend := imap.NewIMAPBackend(ph, ctx.listener, ctx.cache, ctx.bridge) backend := imap.NewIMAPBackend(ph, ctx.listener, ctx.cache, ctx.settings, ctx.bridge)
server := imap.NewIMAPServer(ph, true, true, port, tls, backend, ctx.userAgent, ctx.listener) server := imap.NewIMAPServer(ph, true, true, port, tls, backend, ctx.userAgent, ctx.listener)
go server.ListenAndServe() go server.ListenAndServe()

View File

@ -125,6 +125,10 @@ func (api *FakePMAPI) Addresses() pmapi.AddressList {
return *api.addresses return *api.addresses
} }
func (api *FakePMAPI) GetUserKeyRing() (*crypto.KeyRing, error) {
return api.userKeyRing, nil
}
func (api *FakePMAPI) KeyRingForAddressID(addrID string) (*crypto.KeyRing, error) { func (api *FakePMAPI) KeyRingForAddressID(addrID string) (*crypto.KeyRing, error) {
return api.addrKeyRing[addrID], nil return api.addrKeyRing[addrID], nil
} }