diff --git a/Makefile b/Makefile
index cd958edb..288eead8 100644
--- a/Makefile
+++ b/Makefile
@@ -230,7 +230,7 @@ integration-test-bridge:
mocks:
mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/users Locator,PanicHandler,CredentialsStorer,StoreMaker > internal/users/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/users/mocks/listener_mocks.go
- mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser,ChangeNotifier > internal/store/mocks/mocks.go
+ mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser,ChangeNotifier,Storer > internal/store/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/pmapi Client,Manager > pkg/pmapi/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/message Fetcher > pkg/message/mocks/mocks.go
@@ -288,7 +288,7 @@ run-nogui-cli: clean-vendor gofiles
PROTONMAIL_ENV=dev go run ${BUILD_FLAGS} cmd/${TARGET_CMD}/main.go ${RUN_FLAGS} -c
run-debug:
- PROTONMAIL_ENV=dev dlv debug --build-flags "${BUILD_FLAGS}" cmd/${TARGET_CMD}/main.go -- ${RUN_FLAGS}
+ PROTONMAIL_ENV=dev dlv debug --build-flags "${BUILD_FLAGS}" cmd/${TARGET_CMD}/main.go -- ${RUN_FLAGS} --noninteractive
run-qml-preview:
find internal/frontend/qml/ -iname '*qmlc' | xargs rm -f
diff --git a/TODO.md b/TODO.md
new file mode 100644
index 00000000..e475fa52
--- /dev/null
+++ b/TODO.md
@@ -0,0 +1 @@
+- when cache is full, we need to stop the watcher? don't want to keep downloading messages and throwing them away when we try to cache them.
diff --git a/go.mod b/go.mod
index d4fffee1..d1a5df19 100644
--- a/go.mod
+++ b/go.mod
@@ -32,7 +32,6 @@ require (
github.com/emersion/go-imap-move v0.0.0-20190710073258-6e5a51a5b342
github.com/emersion/go-imap-quota v0.0.0-20210203125329-619074823f3c
github.com/emersion/go-imap-unselect v0.0.0-20171113212723-b985794e5f26
- github.com/emersion/go-mbox v1.0.2
github.com/emersion/go-message v0.12.1-0.20201221184100-40c3f864532b
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21
github.com/emersion/go-smtp v0.14.0
@@ -45,7 +44,6 @@ require (
github.com/golang/mock v1.4.4
github.com/google/go-cmp v0.5.1
github.com/google/uuid v1.1.1
- github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c // indirect
github.com/hashicorp/go-multierror v1.1.0
github.com/jaytaylor/html2text v0.0.0-20200412013138-3577fbdbcff7
github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d
@@ -55,13 +53,11 @@ require (
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
github.com/olekukonko/tablewriter v0.0.4 // indirect
github.com/pkg/errors v0.9.1
+ github.com/ricochet2200/go-disk-usage/du v0.0.0-20210707232629-ac9918953285
github.com/sirupsen/logrus v1.7.0
- github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
+ github.com/stretchr/objx v0.2.0 // indirect
github.com/stretchr/testify v1.7.0
- github.com/therecipe/qt v0.0.0-20200701200531-7f61353ee73e
- github.com/therecipe/qt/internal/binding/files/docs/5.12.0 v0.0.0-20200904063919-c0c124a5770d // indirect
- github.com/therecipe/qt/internal/binding/files/docs/5.13.0 v0.0.0-20200904063919-c0c124a5770d // indirect
github.com/urfave/cli/v2 v2.2.0
github.com/vmihailenco/msgpack/v5 v5.1.3
go.etcd.io/bbolt v1.3.6
diff --git a/go.sum b/go.sum
index e169da7d..a4c5597d 100644
--- a/go.sum
+++ b/go.sum
@@ -124,8 +124,6 @@ github.com/emersion/go-imap-quota v0.0.0-20210203125329-619074823f3c h1:khcEdu1y
github.com/emersion/go-imap-quota v0.0.0-20210203125329-619074823f3c/go.mod h1:iApyhIQBiU4XFyr+3kdJyyGqle82TbQyuP2o+OZHrV0=
github.com/emersion/go-imap-unselect v0.0.0-20171113212723-b985794e5f26 h1:FiSb8+XBQQSkcX3ubr+1tAtlRJBYaFmRZqOAweZ9Wy8=
github.com/emersion/go-imap-unselect v0.0.0-20171113212723-b985794e5f26/go.mod h1:+gnnZx3Mg3MnCzZrv0eZdp5puxXQUgGT/6N6L7ShKfM=
-github.com/emersion/go-mbox v1.0.2 h1:tE/rT+lEugK9y0myEymCCHnwlZN04hlXPrbKkxRBA5I=
-github.com/emersion/go-mbox v1.0.2/go.mod h1:Yp9IVuuOYLEuMv4yjgDHvhb5mHOcYH6x92Oas3QqEZI=
github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b/go.mod h1:G/dpzLu16WtQpBfQ/z3LYiYJn3ZhKSGWn83fyoyQe/k=
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 h1:OJyUGMJTzHTd1XQp98QTaHernxMYzRaOasRir9hUlFQ=
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
@@ -197,9 +195,6 @@ github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
-github.com/gopherjs/gopherjs v0.0.0-20190411002643-bd77b112433e/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
-github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c h1:7lF+Vz0LqiRidnzC1Oq86fpX1q/iEv2KJdrCtttYjT4=
-github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
@@ -269,7 +264,6 @@ github.com/klauspost/compress v1.8.2/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0
github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
-github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
@@ -351,6 +345,8 @@ github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y8
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
+github.com/ricochet2200/go-disk-usage/du v0.0.0-20210707232629-ac9918953285 h1:d54EL9l+XteliUfUCGsEwwuk65dmmxX85VXF+9T6+50=
+github.com/ricochet2200/go-disk-usage/du v0.0.0-20210707232629-ac9918953285/go.mod h1:fxIDly1xtudczrZeOOlfaUvd2OPb2qZAPuWdU2BsBTk=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/russross/blackfriday v1.5.2 h1:HyvC0ARfnZBqnXwABFeSZHpKvJHJJfPz81GNueLj0oo=
@@ -365,13 +361,9 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm
github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
-github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
-github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
-github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
-github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
@@ -401,12 +393,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
-github.com/therecipe/qt v0.0.0-20200701200531-7f61353ee73e h1:G0DQ/TRQyrEZjtLlLwevFjaRiG8eeCMlq9WXQ2OO2bk=
-github.com/therecipe/qt v0.0.0-20200701200531-7f61353ee73e/go.mod h1:SUUR2j3aE1z6/g76SdD6NwACEpvCxb3fvG82eKbD6us=
-github.com/therecipe/qt/internal/binding/files/docs/5.12.0 v0.0.0-20200904063919-c0c124a5770d h1:hAZyEG2swPRWjF0kqqdGERXUazYnRJdAk4a58f14z7Y=
-github.com/therecipe/qt/internal/binding/files/docs/5.12.0 v0.0.0-20200904063919-c0c124a5770d/go.mod h1:7m8PDYDEtEVqfjoUQc2UrFqhG0CDmoVJjRlQxexndFc=
-github.com/therecipe/qt/internal/binding/files/docs/5.13.0 v0.0.0-20200904063919-c0c124a5770d h1:AJRoBel/g9cDS+yE8BcN3E+TDD/xNAguG21aoR8DAIE=
-github.com/therecipe/qt/internal/binding/files/docs/5.13.0 v0.0.0-20200904063919-c0c124a5770d/go.mod h1:mH55Ek7AZcdns5KPp99O0bg+78el64YCYWHiQKrOdt4=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
@@ -444,7 +430,6 @@ golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnf
golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20190418165655-df01cb2cc480/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@@ -488,7 +473,6 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190327091125-710a502c58a2/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20190420063019-afa5a82059c6/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
@@ -521,9 +505,7 @@ golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -557,7 +539,6 @@ golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3
golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190327201419-c70d86f8b7cf/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
-golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
diff --git a/internal/app/bridge/bridge.go b/internal/app/bridge/bridge.go
index fb39b4cc..44b0e061 100644
--- a/internal/app/bridge/bridge.go
+++ b/internal/app/bridge/bridge.go
@@ -32,7 +32,9 @@ import (
"github.com/ProtonMail/proton-bridge/internal/frontend/types"
"github.com/ProtonMail/proton-bridge/internal/imap"
"github.com/ProtonMail/proton-bridge/internal/smtp"
+ "github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/updater"
+ "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
@@ -69,10 +71,21 @@ func New(base *base.Base) *cli.App {
func run(b *base.Base, c *cli.Context) error { // nolint[funlen]
tlsConfig, err := loadTLSConfig(b)
if err != nil {
- logrus.WithError(err).Fatal("Failed to load TLS config")
+ return err
}
- bridge := bridge.New(b.Locations, b.Cache, b.Settings, b.SentryReporter, b.CrashHandler, b.Listener, b.CM, b.Creds, b.Updater, b.Versioner)
- imapBackend := imap.NewIMAPBackend(b.CrashHandler, b.Listener, b.Cache, bridge)
+
+ cache, err := loadCache(b)
+ if err != nil {
+ return err
+ }
+
+ builder := message.NewBuilder(
+ b.Settings.GetInt(settings.FetchWorkers),
+ b.Settings.GetInt(settings.AttachmentWorkers),
+ )
+
+ bridge := bridge.New(b.Locations, b.Cache, b.Settings, b.SentryReporter, b.CrashHandler, b.Listener, cache, builder, b.CM, b.Creds, b.Updater, b.Versioner)
+ imapBackend := imap.NewIMAPBackend(b.CrashHandler, b.Listener, b.Cache, b.Settings, bridge)
smtpBackend := smtp.NewSMTPBackend(b.CrashHandler, b.Listener, b.Settings, bridge)
go func() {
@@ -233,3 +246,35 @@ func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool)
f.NotifySilentUpdateInstalled()
}
+
+// NOTE(GODT-1158): How big should in-memory cache be?
+// NOTE(GODT-1158): How to handle cache location migration if user changes custom path?
+func loadCache(b *base.Base) (cache.Cache, error) {
+ if !b.Settings.GetBool(settings.CacheEnabledKey) {
+ return cache.NewInMemoryCache(100 * (1 << 20)), nil
+ }
+
+ var compressor cache.Compressor
+
+ // NOTE(GODT-1158): If user changes compression setting we have to nuke the cache.
+ if b.Settings.GetBool(settings.CacheCompressionKey) {
+ compressor = &cache.GZipCompressor{}
+ } else {
+ compressor = &cache.NoopCompressor{}
+ }
+
+ var path string
+
+ if customPath := b.Settings.Get(settings.CacheLocationKey); customPath != "" {
+ path = customPath
+ } else {
+ path = b.Cache.GetDefaultMessageCacheDir()
+ }
+
+ return cache.NewOnDiskCache(path, compressor, cache.Options{
+ MinFreeAbs: uint64(b.Settings.GetInt(settings.CacheMinFreeAbsKey)),
+ MinFreeRat: b.Settings.GetFloat64(settings.CacheMinFreeRatKey),
+ ConcurrentRead: b.Settings.GetInt(settings.CacheConcurrencyRead),
+ ConcurrentWrite: b.Settings.GetInt(settings.CacheConcurrencyWrite),
+ })
+}
diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go
index ee282dcc..398ae830 100644
--- a/internal/bridge/bridge.go
+++ b/internal/bridge/bridge.go
@@ -28,17 +28,17 @@ import (
"github.com/ProtonMail/proton-bridge/internal/constants"
"github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/sentry"
+ "github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/updater"
"github.com/ProtonMail/proton-bridge/internal/users"
+ "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/pkg/listener"
logrus "github.com/sirupsen/logrus"
)
-var (
- log = logrus.WithField("pkg", "bridge") //nolint[gochecknoglobals]
-)
+var log = logrus.WithField("pkg", "bridge") //nolint[gochecknoglobals]
type Bridge struct {
*users.Users
@@ -52,11 +52,13 @@ type Bridge struct {
func New(
locations Locator,
- cache Cacher,
- s SettingsProvider,
+ cacheProvider CacheProvider,
+ setting SettingsProvider,
sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler,
eventListener listener.Listener,
+ cache cache.Cache,
+ builder *message.Builder,
clientManager pmapi.Manager,
credStorer users.CredentialsStorer,
updater Updater,
@@ -64,7 +66,7 @@ func New(
) *Bridge {
// Allow DoH before starting the app if the user has previously set this setting.
// This allows us to start even if protonmail is blocked.
- if s.GetBool(settings.AllowProxyKey) {
+ if setting.GetBool(settings.AllowProxyKey) {
clientManager.AllowProxy()
}
@@ -74,25 +76,25 @@ func New(
eventListener,
clientManager,
credStorer,
- newStoreFactory(cache, sentryReporter, panicHandler, eventListener),
+ newStoreFactory(cacheProvider, sentryReporter, panicHandler, eventListener, cache, builder),
)
b := &Bridge{
Users: u,
locations: locations,
- settings: s,
+ settings: setting,
clientManager: clientManager,
updater: updater,
versioner: versioner,
}
- if s.GetBool(settings.FirstStartKey) {
+ if setting.GetBool(settings.FirstStartKey) {
if err := b.SendMetric(metrics.New(metrics.Setup, metrics.FirstStart, metrics.Label(constants.Version))); err != nil {
logrus.WithError(err).Error("Failed to send metric")
}
- s.SetBool(settings.FirstStartKey, false)
+ setting.SetBool(settings.FirstStartKey, false)
}
go b.heartbeat()
diff --git a/internal/bridge/store_factory.go b/internal/bridge/store_factory.go
index e31263ad..eb1e7aaf 100644
--- a/internal/bridge/store_factory.go
+++ b/internal/bridge/store_factory.go
@@ -23,47 +23,65 @@ import (
"github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store"
+ "github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/users"
"github.com/ProtonMail/proton-bridge/pkg/listener"
+ "github.com/ProtonMail/proton-bridge/pkg/message"
)
type storeFactory struct {
- cache Cacher
+ cacheProvider CacheProvider
sentryReporter *sentry.Reporter
panicHandler users.PanicHandler
eventListener listener.Listener
- storeCache *store.Cache
+ events *store.Events
+ cache cache.Cache
+ builder *message.Builder
}
func newStoreFactory(
- cache Cacher,
+ cacheProvider CacheProvider,
sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler,
eventListener listener.Listener,
+ cache cache.Cache,
+ builder *message.Builder,
) *storeFactory {
return &storeFactory{
- cache: cache,
+ cacheProvider: cacheProvider,
sentryReporter: sentryReporter,
panicHandler: panicHandler,
eventListener: eventListener,
- storeCache: store.NewCache(cache.GetIMAPCachePath()),
+ events: store.NewEvents(cacheProvider.GetIMAPCachePath()),
+ cache: cache,
+ builder: builder,
}
}
// New creates new store for given user.
func (f *storeFactory) New(user store.BridgeUser) (*store.Store, error) {
- storePath := getUserStorePath(f.cache.GetDBDir(), user.ID())
- return store.New(f.sentryReporter, f.panicHandler, user, f.eventListener, storePath, f.storeCache)
+ return store.New(
+ f.sentryReporter,
+ f.panicHandler,
+ user,
+ f.eventListener,
+ f.cache,
+ f.builder,
+ getUserStorePath(f.cacheProvider.GetDBDir(), user.ID()),
+ f.events,
+ )
}
// Remove removes all store files for given user.
func (f *storeFactory) Remove(userID string) error {
- storePath := getUserStorePath(f.cache.GetDBDir(), userID)
- return store.RemoveStore(f.storeCache, storePath, userID)
+ return store.RemoveStore(
+ f.events,
+ getUserStorePath(f.cacheProvider.GetDBDir(), userID),
+ userID,
+ )
}
// getUserStorePath returns the file path of the store database for the given userID.
func getUserStorePath(storeDir string, userID string) (path string) {
- fileName := fmt.Sprintf("mailbox-%v.db", userID)
- return filepath.Join(storeDir, fileName)
+ return filepath.Join(storeDir, fmt.Sprintf("mailbox-%v.db", userID))
}
diff --git a/internal/bridge/types.go b/internal/bridge/types.go
index 4b3cb701..1077acf9 100644
--- a/internal/bridge/types.go
+++ b/internal/bridge/types.go
@@ -28,7 +28,7 @@ type Locator interface {
ClearUpdates() error
}
-type Cacher interface {
+type CacheProvider interface {
GetIMAPCachePath() string
GetDBDir() string
}
@@ -38,6 +38,7 @@ type SettingsProvider interface {
Set(key string, value string)
GetBool(key string) bool
SetBool(key string, val bool)
+ GetInt(key string) int
}
type Updater interface {
diff --git a/internal/config/cache/cache.go b/internal/config/cache/cache.go
index 7c08bc5a..7c3ff9de 100644
--- a/internal/config/cache/cache.go
+++ b/internal/config/cache/cache.go
@@ -45,6 +45,11 @@ func (c *Cache) GetDBDir() string {
return c.getCurrentCacheDir()
}
+// GetDefaultMessageCacheDir returns folder for cached messages files.
+func (c *Cache) GetDefaultMessageCacheDir() string {
+ return filepath.Join(c.getCurrentCacheDir(), "messages")
+}
+
// GetIMAPCachePath returns path to file with IMAP status.
func (c *Cache) GetIMAPCachePath() string {
return filepath.Join(c.getCurrentCacheDir(), "user_info.json")
diff --git a/internal/config/settings/kvs.go b/internal/config/settings/kvs.go
index f40cd1a8..6310cce7 100644
--- a/internal/config/settings/kvs.go
+++ b/internal/config/settings/kvs.go
@@ -100,18 +100,28 @@ func (p *keyValueStore) GetBool(key string) bool {
}
func (p *keyValueStore) GetInt(key string) int {
+ if p.Get(key) == "" {
+ return 0
+ }
+
value, err := strconv.Atoi(p.Get(key))
if err != nil {
logrus.WithError(err).Error("Cannot parse int")
}
+
return value
}
func (p *keyValueStore) GetFloat64(key string) float64 {
+ if p.Get(key) == "" {
+ return 0
+ }
+
value, err := strconv.ParseFloat(p.Get(key), 64)
if err != nil {
logrus.WithError(err).Error("Cannot parse float64")
}
+
return value
}
diff --git a/internal/config/settings/settings.go b/internal/config/settings/settings.go
index aadaa183..feb76b29 100644
--- a/internal/config/settings/settings.go
+++ b/internal/config/settings/settings.go
@@ -43,6 +43,16 @@ const (
UpdateChannelKey = "update_channel"
RolloutKey = "rollout"
PreferredKeychainKey = "preferred_keychain"
+ CacheEnabledKey = "cache_enabled"
+ CacheCompressionKey = "cache_compression"
+ CacheLocationKey = "cache_location"
+ CacheMinFreeAbsKey = "cache_min_free_abs"
+ CacheMinFreeRatKey = "cache_min_free_rat"
+ CacheConcurrencyRead = "cache_concurrent_read"
+ CacheConcurrencyWrite = "cache_concurrent_write"
+ IMAPWorkers = "imap_workers"
+ FetchWorkers = "fetch_workers"
+ AttachmentWorkers = "attachment_workers"
)
type Settings struct {
@@ -80,6 +90,16 @@ func (s *Settings) setDefaultValues() {
s.setDefault(UpdateChannelKey, "")
s.setDefault(RolloutKey, fmt.Sprintf("%v", rand.Float64())) //nolint[gosec] G404 It is OK to use weak random number generator here
s.setDefault(PreferredKeychainKey, "")
+ s.setDefault(CacheEnabledKey, "true")
+ s.setDefault(CacheCompressionKey, "true")
+ s.setDefault(CacheLocationKey, "")
+ s.setDefault(CacheMinFreeAbsKey, "250000000")
+ s.setDefault(CacheMinFreeRatKey, "")
+ s.setDefault(CacheConcurrencyRead, "16")
+ s.setDefault(CacheConcurrencyWrite, "16")
+ s.setDefault(IMAPWorkers, "16")
+ s.setDefault(FetchWorkers, "16")
+ s.setDefault(AttachmentWorkers, "16")
s.setDefault(APIPortKey, DefaultAPIPort)
s.setDefault(IMAPPortKey, DefaultIMAPPort)
diff --git a/internal/frontend/cli/frontend.go b/internal/frontend/cli/frontend.go
index ea7899ef..dda9b7e6 100644
--- a/internal/frontend/cli/frontend.go
+++ b/internal/frontend/cli/frontend.go
@@ -128,6 +128,24 @@ func New( //nolint[funlen]
})
fe.AddCmd(dohCmd)
+ // Cache-On-Disk commands.
+ codCmd := &ishell.Cmd{Name: "local-cache",
+ Help: "manage the local encrypted message cache",
+ }
+ codCmd.AddCmd(&ishell.Cmd{Name: "enable",
+ Help: "enable the local cache",
+ Func: fe.enableCacheOnDisk,
+ })
+ codCmd.AddCmd(&ishell.Cmd{Name: "disable",
+ Help: "disable the local cache",
+ Func: fe.disableCacheOnDisk,
+ })
+ codCmd.AddCmd(&ishell.Cmd{Name: "change-location",
+ Help: "change the location of the local cache",
+ Func: fe.setCacheOnDiskLocation,
+ })
+ fe.AddCmd(codCmd)
+
// Updates commands.
updatesCmd := &ishell.Cmd{Name: "updates",
Help: "manage bridge updates",
diff --git a/internal/frontend/cli/system.go b/internal/frontend/cli/system.go
index 0ec95db3..c2643c72 100644
--- a/internal/frontend/cli/system.go
+++ b/internal/frontend/cli/system.go
@@ -19,6 +19,7 @@ package cli
import (
"fmt"
+ "os"
"strconv"
"strings"
@@ -155,6 +156,67 @@ func (f *frontendCLI) disallowProxy(c *ishell.Context) {
}
}
+func (f *frontendCLI) enableCacheOnDisk(c *ishell.Context) {
+ if f.settings.GetBool(settings.CacheEnabledKey) {
+ f.Println("The local cache is already enabled.")
+ return
+ }
+
+ if f.yesNoQuestion("Are you sure you want to enable the local cache") {
+ // Set this back to the default location before enabling.
+ f.settings.Set(settings.CacheLocationKey, "")
+
+ if err := f.bridge.EnableCache(); err != nil {
+ f.Println("The local cache could not be enabled.")
+ return
+ }
+
+ f.settings.SetBool(settings.CacheEnabledKey, true)
+ f.restarter.SetToRestart()
+ f.Stop()
+ }
+}
+
+func (f *frontendCLI) disableCacheOnDisk(c *ishell.Context) {
+ if !f.settings.GetBool(settings.CacheEnabledKey) {
+ f.Println("The local cache is already disabled.")
+ return
+ }
+
+ if f.yesNoQuestion("Are you sure you want to disable the local cache") {
+ if err := f.bridge.DisableCache(); err != nil {
+ f.Println("The local cache could not be disabled.")
+ return
+ }
+
+ f.settings.SetBool(settings.CacheEnabledKey, false)
+ f.restarter.SetToRestart()
+ f.Stop()
+ }
+}
+
+func (f *frontendCLI) setCacheOnDiskLocation(c *ishell.Context) {
+ if !f.settings.GetBool(settings.CacheEnabledKey) {
+ f.Println("The local cache must be enabled.")
+ return
+ }
+
+ if location := f.settings.Get(settings.CacheLocationKey); location != "" {
+ f.Println("The current local cache location is:", location)
+ }
+
+ if location := f.readStringInAttempts("Enter a new location for the cache", c.ReadLine, f.isCacheLocationUsable); location != "" {
+ if err := f.bridge.MigrateCache(f.settings.Get(settings.CacheLocationKey), location); err != nil {
+ f.Println("The local cache location could not be changed.")
+ return
+ }
+
+ f.settings.Set(settings.CacheLocationKey, location)
+ f.restarter.SetToRestart()
+ f.Stop()
+ }
+}
+
func (f *frontendCLI) isPortFree(port string) bool {
port = strings.ReplaceAll(port, ":", "")
if port == "" || port == currentPort {
@@ -171,3 +233,13 @@ func (f *frontendCLI) isPortFree(port string) bool {
}
return true
}
+
+// NOTE(GODT-1158): Check free space in location.
+func (f *frontendCLI) isCacheLocationUsable(location string) bool {
+ stat, err := os.Stat(location)
+ if err != nil {
+ return false
+ }
+
+ return stat.IsDir()
+}
diff --git a/internal/frontend/types/types.go b/internal/frontend/types/types.go
index 0552bcee..d1c63308 100644
--- a/internal/frontend/types/types.go
+++ b/internal/frontend/types/types.go
@@ -77,6 +77,9 @@ type Bridger interface {
ReportBug(osType, osVersion, description, accountName, address, emailClient string) error
AllowProxy()
DisallowProxy()
+ EnableCache() error
+ DisableCache() error
+ MigrateCache(from, to string) error
GetUpdateChannel() updater.UpdateChannel
SetUpdateChannel(updater.UpdateChannel) (needRestart bool, err error)
GetKeychainApp() string
diff --git a/internal/imap/backend.go b/internal/imap/backend.go
index aff3778b..85493052 100644
--- a/internal/imap/backend.go
+++ b/internal/imap/backend.go
@@ -37,21 +37,13 @@ import (
"time"
"github.com/ProtonMail/proton-bridge/internal/bridge"
+ "github.com/ProtonMail/proton-bridge/internal/config/settings"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/listener"
- "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/emersion/go-imap"
goIMAPBackend "github.com/emersion/go-imap/backend"
)
-const (
- // NOTE: Each fetch worker has its own set of attach workers so there can be up to 20*5=100 API requests at once.
- // This is a reasonable limit to not overwhelm API while still maintaining as much parallelism as possible.
- fetchWorkers = 20 // In how many workers to fetch message (group list on IMAP).
- attachWorkers = 5 // In how many workers to fetch attachments (for one message).
- buildWorkers = 20 // In how many workers to build messages.
-)
-
type panicHandler interface {
HandlePanic()
}
@@ -61,26 +53,32 @@ type imapBackend struct {
bridge bridger
updates *imapUpdates
eventListener listener.Listener
+ listWorkers int
users map[string]*imapUser
usersLocker sync.Locker
- builder *message.Builder
-
imapCache map[string]map[string]string
imapCachePath string
imapCacheLock *sync.RWMutex
}
+type settingsProvider interface {
+ GetInt(string) int
+}
+
// NewIMAPBackend returns struct implementing go-imap/backend interface.
func NewIMAPBackend(
panicHandler panicHandler,
eventListener listener.Listener,
cache cacheProvider,
+ setting settingsProvider,
bridge *bridge.Bridge,
) *imapBackend { //nolint[golint]
bridgeWrap := newBridgeWrap(bridge)
- backend := newIMAPBackend(panicHandler, cache, bridgeWrap, eventListener)
+
+ imapWorkers := setting.GetInt(settings.IMAPWorkers)
+ backend := newIMAPBackend(panicHandler, cache, bridgeWrap, eventListener, imapWorkers)
go backend.monitorDisconnectedUsers()
@@ -92,6 +90,7 @@ func newIMAPBackend(
cache cacheProvider,
bridge bridger,
eventListener listener.Listener,
+ listWorkers int,
) *imapBackend {
return &imapBackend{
panicHandler: panicHandler,
@@ -102,10 +101,9 @@ func newIMAPBackend(
users: map[string]*imapUser{},
usersLocker: &sync.Mutex{},
- builder: message.NewBuilder(fetchWorkers, attachWorkers, buildWorkers),
-
imapCachePath: cache.GetIMAPCachePath(),
imapCacheLock: &sync.RWMutex{},
+ listWorkers: listWorkers,
}
}
diff --git a/internal/imap/cache/cache.go b/internal/imap/cache/cache.go
deleted file mode 100644
index 6048f460..00000000
--- a/internal/imap/cache/cache.go
+++ /dev/null
@@ -1,151 +0,0 @@
-// Copyright (c) 2021 Proton Technologies AG
-//
-// This file is part of ProtonMail Bridge.
-//
-// ProtonMail Bridge is free software: you can redistribute it and/or modify
-// it under the terms of the GNU General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// ProtonMail Bridge is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License
-// along with ProtonMail Bridge. If not, see .
-
-package cache
-
-import (
- "bytes"
- "sort"
- "sync"
- "time"
-
- pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message"
-)
-
-type key struct {
- ID string
- Timestamp int64
- Size int
-}
-
-type oldestFirst []key
-
-func (s oldestFirst) Len() int { return len(s) }
-func (s oldestFirst) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
-func (s oldestFirst) Less(i, j int) bool { return s[i].Timestamp < s[j].Timestamp }
-
-type cachedMessage struct {
- key
- data []byte
- structure pkgMsg.BodyStructure
-}
-
-//nolint[gochecknoglobals]
-var (
- cacheTimeLimit = int64(1 * 60 * 60 * 1000) // milliseconds
- cacheSizeLimit = 100 * 1000 * 1000 // B - MUST be larger than email max size limit (~ 25 MB)
- mailCache = make(map[string]cachedMessage)
-
- // cacheMutex takes care of one single operation, whereas buildMutex takes
- // care of the whole action doing multiple operations. buildMutex will protect
- // you from asking server or decrypting or building the same message more
- // than once. When first request to build the message comes, it will block
- // all other build requests. When the first one is done, all others are
- // handled by cache, not doing anything twice. With cacheMutex we are safe
- // only to not mess up with the cache, but we could end up downloading and
- // building message twice.
- cacheMutex = &sync.Mutex{}
- buildMutex = &sync.Mutex{}
- buildLocks = map[string]interface{}{}
-)
-
-func (m *cachedMessage) isValidOrDel() bool {
- if m.key.Timestamp+cacheTimeLimit < timestamp() {
- delete(mailCache, m.key.ID)
- return false
- }
- return true
-}
-
-func timestamp() int64 {
- return time.Now().UnixNano() / int64(time.Millisecond)
-}
-
-func Clear() {
- mailCache = make(map[string]cachedMessage)
-}
-
-// BuildLock locks per message level, not on global level.
-// Multiple different messages can be building at once.
-func BuildLock(messageID string) {
- for {
- buildMutex.Lock()
- if _, ok := buildLocks[messageID]; ok { // if locked, wait
- buildMutex.Unlock()
- time.Sleep(10 * time.Millisecond)
- } else { // if unlocked, lock it
- buildLocks[messageID] = struct{}{}
- buildMutex.Unlock()
- return
- }
- }
-}
-
-func BuildUnlock(messageID string) {
- buildMutex.Lock()
- defer buildMutex.Unlock()
- delete(buildLocks, messageID)
-}
-
-func LoadMail(mID string) (reader *bytes.Reader, structure *pkgMsg.BodyStructure) {
- reader = &bytes.Reader{}
- cacheMutex.Lock()
- defer cacheMutex.Unlock()
- if message, ok := mailCache[mID]; ok && message.isValidOrDel() {
- reader = bytes.NewReader(message.data)
- structure = &message.structure
-
- // Update timestamp to keep emails which are used often.
- message.Timestamp = timestamp()
- }
- return
-}
-
-func SaveMail(mID string, msg []byte, structure *pkgMsg.BodyStructure) {
- cacheMutex.Lock()
- defer cacheMutex.Unlock()
-
- newMessage := cachedMessage{
- key: key{
- ID: mID,
- Timestamp: timestamp(),
- Size: len(msg),
- },
- data: msg,
- structure: *structure,
- }
-
- // Remove old and reduce size.
- totalSize := 0
- messageList := []key{}
- for _, message := range mailCache {
- if message.isValidOrDel() {
- messageList = append(messageList, message.key)
- totalSize += message.key.Size
- }
- }
- sort.Sort(oldestFirst(messageList))
- var oldest key
- for totalSize+newMessage.key.Size >= cacheSizeLimit {
- oldest, messageList = messageList[0], messageList[1:]
- delete(mailCache, oldest.ID)
- totalSize -= oldest.Size
- }
-
- // Write new.
- mailCache[mID] = newMessage
-}
diff --git a/internal/imap/cache/cache_test.go b/internal/imap/cache/cache_test.go
deleted file mode 100644
index 90696239..00000000
--- a/internal/imap/cache/cache_test.go
+++ /dev/null
@@ -1,98 +0,0 @@
-// Copyright (c) 2021 Proton Technologies AG
-//
-// This file is part of ProtonMail Bridge.
-//
-// ProtonMail Bridge is free software: you can redistribute it and/or modify
-// it under the terms of the GNU General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// ProtonMail Bridge is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License
-// along with ProtonMail Bridge. If not, see .
-
-package cache
-
-import (
- "fmt"
- "testing"
- "time"
-
- pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message"
- "github.com/stretchr/testify/require"
-)
-
-var bs = &pkgMsg.BodyStructure{} //nolint[gochecknoglobals]
-const testUID = "testmsg"
-
-func TestSaveAndLoad(t *testing.T) {
- msg := []byte("Test message")
-
- SaveMail(testUID, msg, bs)
- require.Equal(t, mailCache[testUID].data, msg)
-
- reader, _ := LoadMail(testUID)
- require.Equal(t, reader.Len(), len(msg))
- stored := make([]byte, len(msg))
- _, _ = reader.Read(stored)
- require.Equal(t, stored, msg)
-}
-
-func TestMissing(t *testing.T) {
- reader, _ := LoadMail("non-existing")
- require.Equal(t, reader.Len(), 0)
-}
-
-func TestClearOld(t *testing.T) {
- cacheTimeLimit = 10
- msg := []byte("Test message")
- SaveMail(testUID, msg, bs)
- time.Sleep(100 * time.Millisecond)
-
- reader, _ := LoadMail(testUID)
- require.Equal(t, reader.Len(), 0)
-}
-
-func TestClearBig(t *testing.T) {
- r := require.New(t)
- wantMessage := []byte("Test message")
-
- wantCacheSize := 3
- nTestMessages := wantCacheSize * wantCacheSize
- cacheSizeLimit = wantCacheSize*len(wantMessage) + 1
- cacheTimeLimit = int64(1 << 20) // be sure the message will survive
-
- // It should never have more than nSize items.
- for i := 0; i < nTestMessages; i++ {
- time.Sleep(1 * time.Millisecond)
- SaveMail(fmt.Sprintf("%s%d", testUID, i), wantMessage, bs)
- r.LessOrEqual(len(mailCache), wantCacheSize, "cache too big when %d", i)
- }
-
- // Check that the oldest are deleted first.
- for i := 0; i < nTestMessages; i++ {
- iUID := fmt.Sprintf("%s%d", testUID, i)
- reader, _ := LoadMail(iUID)
- mail := mailCache[iUID]
-
- if i < (nTestMessages - wantCacheSize) {
- r.Zero(reader.Len(), "LoadMail should return empty, but have %s for %s time %d ", string(mail.data), iUID, mail.key.Timestamp)
- } else {
- stored := make([]byte, len(wantMessage))
- _, err := reader.Read(stored)
- r.NoError(err)
- r.Equal(wantMessage, stored, "LoadMail returned wrong message: %s for %s time %d", stored, iUID, mail.key.Timestamp)
- }
- }
-}
-
-func TestConcurency(t *testing.T) {
- msg := []byte("Test message")
- for i := 0; i < 10; i++ {
- go SaveMail(fmt.Sprintf("%s%d", testUID, i), msg, bs)
- }
-}
diff --git a/internal/imap/mailbox.go b/internal/imap/mailbox.go
index e9df47f6..8dc09ccb 100644
--- a/internal/imap/mailbox.go
+++ b/internal/imap/mailbox.go
@@ -37,12 +37,10 @@ type imapMailbox struct {
storeUser storeUserProvider
storeAddress storeAddressProvider
storeMailbox storeMailboxProvider
-
- builder *message.Builder
}
// newIMAPMailbox returns struct implementing go-imap/mailbox interface.
-func newIMAPMailbox(panicHandler panicHandler, user *imapUser, storeMailbox storeMailboxProvider, builder *message.Builder) *imapMailbox {
+func newIMAPMailbox(panicHandler panicHandler, user *imapUser, storeMailbox storeMailboxProvider) *imapMailbox {
return &imapMailbox{
panicHandler: panicHandler,
user: user,
@@ -56,8 +54,6 @@ func newIMAPMailbox(panicHandler panicHandler, user *imapUser, storeMailbox stor
storeUser: user.storeUser,
storeAddress: user.storeAddress,
storeMailbox: storeMailbox,
-
- builder: builder,
}
}
diff --git a/internal/imap/mailbox_fetch.go b/internal/imap/mailbox_fetch.go
index b1f8d724..69aaa833 100644
--- a/internal/imap/mailbox_fetch.go
+++ b/internal/imap/mailbox_fetch.go
@@ -19,21 +19,13 @@ package imap
import (
"bytes"
- "context"
- "github.com/ProtonMail/proton-bridge/internal/imap/cache"
"github.com/ProtonMail/proton-bridge/pkg/message"
- "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/emersion/go-imap"
"github.com/pkg/errors"
- "github.com/sirupsen/logrus"
)
-func (im *imapMailbox) getMessage(
- storeMessage storeMessageProvider,
- items []imap.FetchItem,
- msgBuildCountHistogram *msgBuildCountHistogram,
-) (msg *imap.Message, err error) {
+func (im *imapMailbox) getMessage(storeMessage storeMessageProvider, items []imap.FetchItem) (msg *imap.Message, err error) {
msglog := im.log.WithField("msgID", storeMessage.ID())
msglog.Trace("Getting message")
@@ -69,9 +61,12 @@ func (im *imapMailbox) getMessage(
// There is no point having message older than RFC itself, it's not possible.
msg.InternalDate = message.SanitizeMessageDate(m.Time)
case imap.FetchRFC822Size:
- if msg.Size, err = im.getSize(storeMessage); err != nil {
+ size, err := storeMessage.GetRFC822Size()
+ if err != nil {
return nil, err
}
+
+ msg.Size = size
case imap.FetchUid:
if msg.Uid, err = storeMessage.UID(); err != nil {
return nil, err
@@ -79,7 +74,7 @@ func (im *imapMailbox) getMessage(
case imap.FetchAll, imap.FetchFast, imap.FetchFull, imap.FetchRFC822, imap.FetchRFC822Header, imap.FetchRFC822Text:
fallthrough // this is list of defined items by go-imap, but items can be also sections generated from requests
default:
- if err = im.getLiteralForSection(item, msg, storeMessage, msgBuildCountHistogram); err != nil {
+ if err = im.getLiteralForSection(item, msg, storeMessage); err != nil {
return
}
}
@@ -88,35 +83,7 @@ func (im *imapMailbox) getMessage(
return msg, err
}
-// getSize returns cached size or it will build the message, save the size in
-// DB and then returns the size after build.
-//
-// We are storing size in DB as part of pmapi messages metada. The size
-// attribute on the server represents size of encrypted body. The value is
-// cleared in Bridge and the final decrypted size (including header, attachment
-// and MIME structure) is computed after building the message.
-func (im *imapMailbox) getSize(storeMessage storeMessageProvider) (uint32, error) {
- m := storeMessage.Message()
- if m.Size <= 0 {
- im.log.WithField("msgID", m.ID).Debug("Size unknown - downloading body")
- // We are sure the size is not a problem right now. Clients
- // might not first check sizes of all messages so we couldn't
- // be sure if seeing 1st or 2nd sync is all right or not.
- // Therefore, it's better to exclude getting size from the
- // counting and see build count as real message build.
- if _, _, err := im.getBodyAndStructure(storeMessage, nil); err != nil {
- return 0, err
- }
- }
- return uint32(m.Size), nil
-}
-
-func (im *imapMailbox) getLiteralForSection(
- itemSection imap.FetchItem,
- msg *imap.Message,
- storeMessage storeMessageProvider,
- msgBuildCountHistogram *msgBuildCountHistogram,
-) error {
+func (im *imapMailbox) getLiteralForSection(itemSection imap.FetchItem, msg *imap.Message, storeMessage storeMessageProvider) error {
section, err := imap.ParseBodySectionName(itemSection)
if err != nil {
log.WithError(err).Warn("Failed to parse body section name; part will be skipped")
@@ -124,7 +91,7 @@ func (im *imapMailbox) getLiteralForSection(
}
var literal imap.Literal
- if literal, err = im.getMessageBodySection(storeMessage, section, msgBuildCountHistogram); err != nil {
+ if literal, err = im.getMessageBodySection(storeMessage, section); err != nil {
return err
}
@@ -149,88 +116,25 @@ func (im *imapMailbox) getBodyStructure(storeMessage storeMessageProvider) (bs *
// be sure if seeing 1st or 2nd sync is all right or not.
// Therefore, it's better to exclude first body structure fetch
// from the counting and see build count as real message build.
- if bs, _, err = im.getBodyAndStructure(storeMessage, nil); err != nil {
+ if bs, _, err = im.getBodyAndStructure(storeMessage); err != nil {
return
}
}
return
}
-func (im *imapMailbox) getBodyAndStructure(
- storeMessage storeMessageProvider, msgBuildCountHistogram *msgBuildCountHistogram,
-) (
- structure *message.BodyStructure, bodyReader *bytes.Reader, err error,
-) {
- m := storeMessage.Message()
- id := im.storeUser.UserID() + m.ID
- cache.BuildLock(id)
- defer cache.BuildUnlock(id)
- bodyReader, structure = cache.LoadMail(id)
-
- // return the message which was found in cache
- if bodyReader.Len() != 0 && structure != nil {
- return structure, bodyReader, nil
+func (im *imapMailbox) getBodyAndStructure(storeMessage storeMessageProvider) (*message.BodyStructure, *bytes.Reader, error) {
+ rfc822, err := storeMessage.GetRFC822()
+ if err != nil {
+ return nil, nil, err
}
- structure, body, err := im.buildMessage(m)
- bodyReader = bytes.NewReader(body)
- size := int64(len(body))
- l := im.log.WithField("newSize", size).WithField("msgID", m.ID)
-
- if err != nil || structure == nil || size == 0 {
- l.WithField("hasStructure", structure != nil).Warn("Failed to build message")
- return structure, bodyReader, err
+ structure, err := storeMessage.GetBodyStructure()
+ if err != nil {
+ return nil, nil, err
}
- // Save the size, body structure and header even for messages which
- // were unable to decrypt. Hence they doesn't have to be computed every
- // time.
- m.Size = size
- cacheMessageInStore(storeMessage, structure, body, l)
-
- if msgBuildCountHistogram != nil {
- times, errCount := storeMessage.IncreaseBuildCount()
- if errCount != nil {
- l.WithError(errCount).Warn("Cannot increase build count")
- }
- msgBuildCountHistogram.add(times)
- }
-
- // Drafts can change therefore we don't want to cache them.
- if !isMessageInDraftFolder(m) {
- cache.SaveMail(id, body, structure)
- }
-
- return structure, bodyReader, err
-}
-
-func cacheMessageInStore(storeMessage storeMessageProvider, structure *message.BodyStructure, body []byte, l *logrus.Entry) {
- m := storeMessage.Message()
- if errSize := storeMessage.SetSize(m.Size); errSize != nil {
- l.WithError(errSize).Warn("Cannot update size while building")
- }
- if structure != nil && !isMessageInDraftFolder(m) {
- if errStruct := storeMessage.SetBodyStructure(structure); errStruct != nil {
- l.WithError(errStruct).Warn("Cannot update bodystructure while building")
- }
- }
- header, errHead := structure.GetMailHeaderBytes(bytes.NewReader(body))
- if errHead == nil && len(header) != 0 {
- if errStore := storeMessage.SetHeader(header); errStore != nil {
- l.WithError(errStore).Warn("Cannot update header in store")
- }
- } else {
- l.WithError(errHead).Warn("Cannot get header bytes from structure")
- }
-}
-
-func isMessageInDraftFolder(m *pmapi.Message) bool {
- for _, labelID := range m.LabelIDs {
- if labelID == pmapi.DraftLabel {
- return true
- }
- }
- return false
+ return structure, bytes.NewReader(rfc822), nil
}
// This will download message (or read from cache) and pick up the section,
@@ -246,11 +150,7 @@ func isMessageInDraftFolder(m *pmapi.Message) bool {
// For all other cases it is necessary to download and decrypt the message
// and drop the header which was obtained from cache. The header will
// will be stored in DB once successfully built. Check `getBodyAndStructure`.
-func (im *imapMailbox) getMessageBodySection(
- storeMessage storeMessageProvider,
- section *imap.BodySectionName,
- msgBuildCountHistogram *msgBuildCountHistogram,
-) (imap.Literal, error) {
+func (im *imapMailbox) getMessageBodySection(storeMessage storeMessageProvider, section *imap.BodySectionName) (imap.Literal, error) {
var header []byte
var response []byte
@@ -260,7 +160,7 @@ func (im *imapMailbox) getMessageBodySection(
if isMainHeaderRequested && storeMessage.IsFullHeaderCached() {
header = storeMessage.GetHeader()
} else {
- structure, bodyReader, err := im.getBodyAndStructure(storeMessage, msgBuildCountHistogram)
+ structure, bodyReader, err := im.getBodyAndStructure(storeMessage)
if err != nil {
return nil, err
}
@@ -276,7 +176,7 @@ func (im *imapMailbox) getMessageBodySection(
case section.Specifier == imap.MIMESpecifier: // The MIME part specifier refers to the [MIME-IMB] header for this part.
fallthrough
case section.Specifier == imap.HeaderSpecifier:
- header, err = structure.GetSectionHeaderBytes(bodyReader, section.Path)
+ header, err = structure.GetSectionHeaderBytes(section.Path)
default:
err = errors.New("Unknown specifier " + string(section.Specifier))
}
@@ -293,30 +193,3 @@ func (im *imapMailbox) getMessageBodySection(
// Trim any output if requested.
return bytes.NewBuffer(section.ExtractPartial(response)), nil
}
-
-// buildMessage from PM to IMAP.
-func (im *imapMailbox) buildMessage(m *pmapi.Message) (*message.BodyStructure, []byte, error) {
- body, err := im.builder.NewJobWithOptions(
- context.Background(),
- im.user.client(),
- m.ID,
- message.JobOptions{
- IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead.
- SanitizeDate: true, // Whether to replace all dates before 1970 with RFC822's birthdate.
- AddInternalID: true, // Whether to include MessageID as X-Pm-Internal-Id.
- AddExternalID: true, // Whether to include ExternalID as X-Pm-External-Id.
- AddMessageDate: true, // Whether to include message time as X-Pm-Date.
- AddMessageIDReference: true, // Whether to include the MessageID in References.
- },
- ).GetResult()
- if err != nil {
- return nil, nil, err
- }
-
- structure, err := message.NewBodyStructure(bytes.NewReader(body))
- if err != nil {
- return nil, nil, err
- }
-
- return structure, body, nil
-}
diff --git a/internal/imap/mailbox_messages.go b/internal/imap/mailbox_messages.go
index c91616ea..a2ca151b 100644
--- a/internal/imap/mailbox_messages.go
+++ b/internal/imap/mailbox_messages.go
@@ -479,11 +479,16 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria)
}
// Filter by size (only if size was already calculated).
- if m.Size > 0 {
- if criteria.Larger != 0 && m.Size <= int64(criteria.Larger) {
+ size, err := storeMessage.GetRFC822Size()
+ if err != nil {
+ return nil, err
+ }
+
+ if size > 0 {
+ if criteria.Larger != 0 && int64(size) <= int64(criteria.Larger) {
continue
}
- if criteria.Smaller != 0 && m.Size >= int64(criteria.Smaller) {
+ if criteria.Smaller != 0 && int64(size) >= int64(criteria.Smaller) {
continue
}
}
@@ -513,13 +518,12 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria)
//
// Messages must be sent to msgResponse. When the function returns, msgResponse must be closed.
func (im *imapMailbox) ListMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message) error {
- msgBuildCountHistogram := newMsgBuildCountHistogram()
return im.logCommand(func() error {
- return im.listMessages(isUID, seqSet, items, msgResponse, msgBuildCountHistogram)
- }, "FETCH", isUID, seqSet, items, msgBuildCountHistogram)
+ return im.listMessages(isUID, seqSet, items, msgResponse)
+ }, "FETCH", isUID, seqSet, items)
}
-func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message, msgBuildCountHistogram *msgBuildCountHistogram) (err error) { //nolint[funlen]
+func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message) (err error) { //nolint[funlen]
defer func() {
close(msgResponse)
if err != nil {
@@ -564,7 +568,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima
return nil, err
}
- msg, err := im.getMessage(storeMessage, items, msgBuildCountHistogram)
+ msg, err := im.getMessage(storeMessage, items)
if err != nil {
err = fmt.Errorf("list message build: %v", err)
l.WithField("metaID", storeMessage.ID()).Error(err)
@@ -594,7 +598,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima
return nil
}
- err = parallel.RunParallel(fetchWorkers, input, processCallback, collectCallback)
+ err = parallel.RunParallel(im.user.backend.listWorkers, input, processCallback, collectCallback)
if err != nil {
return err
}
diff --git a/internal/imap/msg_build_counts.go b/internal/imap/msg_build_counts.go
deleted file mode 100644
index 03322cec..00000000
--- a/internal/imap/msg_build_counts.go
+++ /dev/null
@@ -1,65 +0,0 @@
-// Copyright (c) 2021 Proton Technologies AG
-//
-// This file is part of ProtonMail Bridge.
-//
-// ProtonMail Bridge is free software: you can redistribute it and/or modify
-// it under the terms of the GNU General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// ProtonMail Bridge is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License
-// along with ProtonMail Bridge. If not, see .
-
-package imap
-
-import (
- "fmt"
- "sync"
-)
-
-// msgBuildCountHistogram is used to analyse and log the number of repetitive
-// downloads of requested messages per one fetch. The number of builds per each
-// messageID is stored in persistent database. The msgBuildCountHistogram will
-// take this number for each message in ongoing fetch and create histogram of
-// repeats.
-//
-// Example: During `fetch 1:300` there were
-// - 100 messages were downloaded first time
-// - 100 messages were downloaded second time
-// - 99 messages were downloaded 10th times
-// - 1 messages were downloaded 100th times.
-type msgBuildCountHistogram struct {
- // Key represents how many times message was build.
- // Value stores how many messages are build X times based on the key.
- counts map[uint32]uint32
- lock sync.Locker
-}
-
-func newMsgBuildCountHistogram() *msgBuildCountHistogram {
- return &msgBuildCountHistogram{
- counts: map[uint32]uint32{},
- lock: &sync.Mutex{},
- }
-}
-
-func (c *msgBuildCountHistogram) String() string {
- res := ""
- for nRebuild, counts := range c.counts {
- if res != "" {
- res += ", "
- }
- res += fmt.Sprintf("[%d]:%d", nRebuild, counts)
- }
- return res
-}
-
-func (c *msgBuildCountHistogram) add(nRebuild uint32) {
- c.lock.Lock()
- defer c.lock.Unlock()
- c.counts[nRebuild]++
-}
diff --git a/internal/imap/store.go b/internal/imap/store.go
index 918976be..2c17f5b2 100644
--- a/internal/imap/store.go
+++ b/internal/imap/store.go
@@ -80,7 +80,6 @@ type storeMailboxProvider interface {
GetDelimiter() string
GetMessage(apiID string) (storeMessageProvider, error)
- FetchMessage(apiID string) (storeMessageProvider, error)
LabelMessages(apiID []string) error
UnlabelMessages(apiID []string) error
MarkMessagesRead(apiID []string) error
@@ -100,14 +99,12 @@ type storeMessageProvider interface {
Message() *pmapi.Message
IsMarkedDeleted() bool
- SetSize(int64) error
- SetHeader([]byte) error
GetHeader() []byte
+ GetRFC822() ([]byte, error)
+ GetRFC822Size() (uint32, error)
GetMIMEHeader() textproto.MIMEHeader
IsFullHeaderCached() bool
- SetBodyStructure(*pkgMsg.BodyStructure) error
GetBodyStructure() (*pkgMsg.BodyStructure, error)
- IncreaseBuildCount() (uint32, error)
}
type storeUserWrap struct {
@@ -165,7 +162,3 @@ func newStoreMailboxWrap(mailbox *store.Mailbox) *storeMailboxWrap {
func (s *storeMailboxWrap) GetMessage(apiID string) (storeMessageProvider, error) {
return s.Mailbox.GetMessage(apiID)
}
-
-func (s *storeMailboxWrap) FetchMessage(apiID string) (storeMessageProvider, error) {
- return s.Mailbox.FetchMessage(apiID)
-}
diff --git a/internal/imap/user.go b/internal/imap/user.go
index 8c4059c7..bf724656 100644
--- a/internal/imap/user.go
+++ b/internal/imap/user.go
@@ -135,7 +135,7 @@ func (iu *imapUser) ListMailboxes(showOnlySubcribed bool) ([]goIMAPBackend.Mailb
if showOnlySubcribed && !iu.isSubscribed(storeMailbox.LabelID()) {
continue
}
- mailbox := newIMAPMailbox(iu.panicHandler, iu, storeMailbox, iu.backend.builder)
+ mailbox := newIMAPMailbox(iu.panicHandler, iu, storeMailbox)
mailboxes = append(mailboxes, mailbox)
}
@@ -167,7 +167,7 @@ func (iu *imapUser) GetMailbox(name string) (mb goIMAPBackend.Mailbox, err error
return
}
- return newIMAPMailbox(iu.panicHandler, iu, storeMailbox, iu.backend.builder), nil
+ return newIMAPMailbox(iu.panicHandler, iu, storeMailbox), nil
}
// CreateMailbox creates a new mailbox.
diff --git a/internal/store/cache.go b/internal/store/cache.go
index 76757ca1..9c74552b 100644
--- a/internal/store/cache.go
+++ b/internal/store/cache.go
@@ -18,99 +18,113 @@
package store
import (
- "encoding/json"
- "os"
- "sync"
-
- "github.com/pkg/errors"
+ "github.com/ProtonMail/gopenpgp/v2/crypto"
+ "github.com/ProtonMail/proton-bridge/pkg/message"
+ "github.com/sirupsen/logrus"
+ bolt "go.etcd.io/bbolt"
)
-// Cache caches the last event IDs for all accounts (there should be only one instance).
-type Cache struct {
- // cache is map from userID => key (such as last event) => value (such as event ID).
- cache map[string]map[string]string
- path string
- lock *sync.RWMutex
-}
+const passphraseKey = "passphrase"
-// NewCache constructs a new cache at the given path.
-func NewCache(path string) *Cache {
- return &Cache{
- path: path,
- lock: &sync.RWMutex{},
- }
-}
-
-func (c *Cache) getEventID(userID string) string {
- c.lock.Lock()
- defer c.lock.Unlock()
-
- if err := c.loadCache(); err != nil {
- log.WithError(err).Warn("Problem to load store cache")
- }
-
- if c.cache == nil {
- c.cache = map[string]map[string]string{}
- }
- if c.cache[userID] == nil {
- c.cache[userID] = map[string]string{}
- }
-
- return c.cache[userID]["events"]
-}
-
-func (c *Cache) setEventID(userID, eventID string) error {
- c.lock.Lock()
- defer c.lock.Unlock()
-
- if c.cache[userID] == nil {
- c.cache[userID] = map[string]string{}
- }
- c.cache[userID]["events"] = eventID
-
- return c.saveCache()
-}
-
-func (c *Cache) loadCache() error {
- if c.cache != nil {
- return nil
- }
-
- f, err := os.Open(c.path)
+// UnlockCache unlocks the cache for the user with the given keyring.
+func (store *Store) UnlockCache(kr *crypto.KeyRing) error {
+ passphrase, err := store.getCachePassphrase()
if err != nil {
return err
}
- defer f.Close() //nolint[errcheck]
- return json.NewDecoder(f).Decode(&c.cache)
-}
+ if passphrase == nil {
+ if passphrase, err = crypto.RandomToken(32); err != nil {
+ return err
+ }
-func (c *Cache) saveCache() error {
- if c.cache == nil {
- return errors.New("events: cannot save cache: cache is nil")
+ enc, err := kr.Encrypt(crypto.NewPlainMessage(passphrase), nil)
+ if err != nil {
+ return err
+ }
+
+ if err := store.setCachePassphrase(enc.GetBinary()); err != nil {
+ return err
+ }
+ } else {
+ dec, err := kr.Decrypt(crypto.NewPGPMessage(passphrase), nil, crypto.GetUnixTime())
+ if err != nil {
+ return err
+ }
+
+ passphrase = dec.GetBinary()
}
- f, err := os.Create(c.path)
+ if err := store.cache.Unlock(store.user.ID(), passphrase); err != nil {
+ return err
+ }
+
+ store.cacher.start()
+
+ return nil
+}
+
+func (store *Store) getCachePassphrase() ([]byte, error) {
+ var passphrase []byte
+
+ if err := store.db.View(func(tx *bolt.Tx) error {
+ passphrase = tx.Bucket(cachePassphraseBucket).Get([]byte(passphraseKey))
+ return nil
+ }); err != nil {
+ return nil, err
+ }
+
+ return passphrase, nil
+}
+
+func (store *Store) setCachePassphrase(passphrase []byte) error {
+ return store.db.Update(func(tx *bolt.Tx) error {
+ return tx.Bucket(cachePassphraseBucket).Put([]byte(passphraseKey), passphrase)
+ })
+}
+
+func (store *Store) clearCachePassphrase() error {
+ return store.db.Update(func(tx *bolt.Tx) error {
+ return tx.Bucket(cachePassphraseBucket).Delete([]byte(passphraseKey))
+ })
+}
+
+func (store *Store) getCachedMessage(messageID string) ([]byte, error) {
+ if store.cache.Has(store.user.ID(), messageID) {
+ return store.cache.Get(store.user.ID(), messageID)
+ }
+
+ job, done := store.newBuildJob(messageID, message.ForegroundPriority)
+ defer done()
+
+ literal, err := job.GetResult()
+ if err != nil {
+ return nil, err
+ }
+
+ // NOTE(GODT-1158): No need to block until cache has been set; do this async?
+ if err := store.cache.Set(store.user.ID(), messageID, literal); err != nil {
+ logrus.WithError(err).Error("Failed to cache message")
+ }
+
+ return literal, nil
+}
+
+// IsCached returns whether the given message already exists in the cache.
+func (store *Store) IsCached(messageID string) bool {
+ return store.cache.Has(store.user.ID(), messageID)
+}
+
+// BuildAndCacheMessage builds the given message (with background priority) and puts it in the cache.
+// It builds with background priority.
+func (store *Store) BuildAndCacheMessage(messageID string) error {
+ job, done := store.newBuildJob(messageID, message.BackgroundPriority)
+ defer done()
+
+ literal, err := job.GetResult()
if err != nil {
return err
}
- defer f.Close() //nolint[errcheck]
- return json.NewEncoder(f).Encode(c.cache)
-}
-
-func (c *Cache) clearCacheUser(userID string) error {
- c.lock.Lock()
- defer c.lock.Unlock()
-
- if c.cache == nil {
- log.WithField("user", userID).Warning("Cannot clear user from cache: cache is nil")
- return nil
- }
-
- log.WithField("user", userID).Trace("Removing user from event loop cache")
-
- delete(c.cache, userID)
-
- return c.saveCache()
+ return store.cache.Set(store.user.ID(), messageID, literal)
}
diff --git a/internal/store/cache/cache_test.go b/internal/store/cache/cache_test.go
new file mode 100644
index 00000000..efb03ef3
--- /dev/null
+++ b/internal/store/cache/cache_test.go
@@ -0,0 +1,73 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package cache
+
+import (
+ "runtime"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestOnDiskCacheNoCompression(t *testing.T) {
+ cache, err := NewOnDiskCache(t.TempDir(), &NoopCompressor{}, Options{ConcurrentRead: runtime.NumCPU(), ConcurrentWrite: runtime.NumCPU()})
+ require.NoError(t, err)
+
+ testCache(t, cache)
+}
+
+func TestOnDiskCacheGZipCompression(t *testing.T) {
+ cache, err := NewOnDiskCache(t.TempDir(), &GZipCompressor{}, Options{ConcurrentRead: runtime.NumCPU(), ConcurrentWrite: runtime.NumCPU()})
+ require.NoError(t, err)
+
+ testCache(t, cache)
+}
+
+func TestInMemoryCache(t *testing.T) {
+ testCache(t, NewInMemoryCache(1<<20))
+}
+
+func testCache(t *testing.T, cache Cache) {
+ assert.NoError(t, cache.Unlock("userID1", []byte("my secret passphrase")))
+ assert.NoError(t, cache.Unlock("userID2", []byte("my other passphrase")))
+
+ getSetCachedMessage(t, cache, "userID1", "messageID1", "some secret")
+ assert.True(t, cache.Has("userID1", "messageID1"))
+
+ getSetCachedMessage(t, cache, "userID2", "messageID2", "some other secret")
+ assert.True(t, cache.Has("userID2", "messageID2"))
+
+ assert.NoError(t, cache.Rem("userID1", "messageID1"))
+ assert.False(t, cache.Has("userID1", "messageID1"))
+
+ assert.NoError(t, cache.Rem("userID2", "messageID2"))
+ assert.False(t, cache.Has("userID2", "messageID2"))
+
+ assert.NoError(t, cache.Delete("userID1"))
+ assert.NoError(t, cache.Delete("userID2"))
+}
+
+func getSetCachedMessage(t *testing.T, cache Cache, userID, messageID, secret string) {
+ assert.NoError(t, cache.Set(userID, messageID, []byte(secret)))
+
+ data, err := cache.Get(userID, messageID)
+ assert.NoError(t, err)
+
+ assert.Equal(t, []byte(secret), data)
+}
diff --git a/internal/store/cache/compressor.go b/internal/store/cache/compressor.go
new file mode 100644
index 00000000..c252a739
--- /dev/null
+++ b/internal/store/cache/compressor.go
@@ -0,0 +1,33 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package cache
+
+type Compressor interface {
+ Compress([]byte) ([]byte, error)
+ Decompress([]byte) ([]byte, error)
+}
+
+type NoopCompressor struct{}
+
+func (NoopCompressor) Compress(dec []byte) ([]byte, error) {
+ return dec, nil
+}
+
+func (NoopCompressor) Decompress(cmp []byte) ([]byte, error) {
+ return cmp, nil
+}
diff --git a/internal/store/cache/compressor_gzip.go b/internal/store/cache/compressor_gzip.go
new file mode 100644
index 00000000..412de937
--- /dev/null
+++ b/internal/store/cache/compressor_gzip.go
@@ -0,0 +1,60 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package cache
+
+import (
+ "bytes"
+ "compress/gzip"
+)
+
+type GZipCompressor struct{}
+
+func (GZipCompressor) Compress(dec []byte) ([]byte, error) {
+ buf := new(bytes.Buffer)
+
+ zw := gzip.NewWriter(buf)
+
+ if _, err := zw.Write(dec); err != nil {
+ return nil, err
+ }
+
+ if err := zw.Close(); err != nil {
+ return nil, err
+ }
+
+ return buf.Bytes(), nil
+}
+
+func (GZipCompressor) Decompress(cmp []byte) ([]byte, error) {
+ zr, err := gzip.NewReader(bytes.NewReader(cmp))
+ if err != nil {
+ return nil, err
+ }
+
+ buf := new(bytes.Buffer)
+
+ if _, err := buf.ReadFrom(zr); err != nil {
+ return nil, err
+ }
+
+ if err := zr.Close(); err != nil {
+ return nil, err
+ }
+
+ return buf.Bytes(), nil
+}
diff --git a/internal/store/cache/disk.go b/internal/store/cache/disk.go
new file mode 100644
index 00000000..b4b3ee33
--- /dev/null
+++ b/internal/store/cache/disk.go
@@ -0,0 +1,244 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package cache
+
+import (
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "crypto/sha256"
+ "errors"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "sync"
+
+ "github.com/ProtonMail/proton-bridge/pkg/semaphore"
+ "github.com/ricochet2200/go-disk-usage/du"
+)
+
+var ErrLowSpace = errors.New("not enough free space left on device")
+
+type onDiskCache struct {
+ path string
+ opts Options
+
+ gcm map[string]cipher.AEAD
+ cmp Compressor
+ rsem, wsem semaphore.Semaphore
+ pending *pending
+
+ diskSize uint64
+ diskFree uint64
+ once *sync.Once
+ lock sync.Mutex
+}
+
+func NewOnDiskCache(path string, cmp Compressor, opts Options) (Cache, error) {
+ if err := os.MkdirAll(path, 0700); err != nil {
+ return nil, err
+ }
+
+ usage := du.NewDiskUsage(path)
+
+ // NOTE(GODT-1158): use Available() or Free()?
+ return &onDiskCache{
+ path: path,
+ opts: opts,
+
+ gcm: make(map[string]cipher.AEAD),
+ cmp: cmp,
+ rsem: semaphore.New(opts.ConcurrentRead),
+ wsem: semaphore.New(opts.ConcurrentWrite),
+ pending: newPending(),
+
+ diskSize: usage.Size(),
+ diskFree: usage.Available(),
+ once: &sync.Once{},
+ }, nil
+}
+
+func (c *onDiskCache) Unlock(userID string, passphrase []byte) error {
+ hash := sha256.New()
+
+ if _, err := hash.Write(passphrase); err != nil {
+ return err
+ }
+
+ aes, err := aes.NewCipher(hash.Sum(nil))
+ if err != nil {
+ return err
+ }
+
+ gcm, err := cipher.NewGCM(aes)
+ if err != nil {
+ return err
+ }
+
+ if err := os.MkdirAll(c.getUserPath(userID), 0700); err != nil {
+ return err
+ }
+
+ c.gcm[userID] = gcm
+
+ return nil
+}
+
+func (c *onDiskCache) Delete(userID string) error {
+ defer c.update()
+
+ return os.RemoveAll(c.getUserPath(userID))
+}
+
+// Has returns whether the given message exists in the cache.
+func (c *onDiskCache) Has(userID, messageID string) bool {
+ c.pending.wait(c.getMessagePath(userID, messageID))
+
+ c.rsem.Lock()
+ defer c.rsem.Unlock()
+
+ _, err := os.Stat(c.getMessagePath(userID, messageID))
+
+ switch {
+ case err == nil:
+ return true
+
+ case os.IsNotExist(err):
+ return false
+
+ default:
+ panic(err)
+ }
+}
+
+func (c *onDiskCache) Get(userID, messageID string) ([]byte, error) {
+ enc, err := c.readFile(c.getMessagePath(userID, messageID))
+ if err != nil {
+ return nil, err
+ }
+
+ cmp, err := c.gcm[userID].Open(nil, enc[:c.gcm[userID].NonceSize()], enc[c.gcm[userID].NonceSize():], nil)
+ if err != nil {
+ return nil, err
+ }
+
+ return c.cmp.Decompress(cmp)
+}
+
+func (c *onDiskCache) Set(userID, messageID string, literal []byte) error {
+ nonce := make([]byte, c.gcm[userID].NonceSize())
+
+ if _, err := rand.Read(nonce); err != nil {
+ return err
+ }
+
+ cmp, err := c.cmp.Compress(literal)
+ if err != nil {
+ return err
+ }
+
+ // NOTE(GODT-1158): How to properly handle low space? Don't return error, that's bad. Instead send event?
+ if !c.hasSpace(len(cmp)) {
+ return nil
+ }
+
+ return c.writeFile(c.getMessagePath(userID, messageID), c.gcm[userID].Seal(nonce, nonce, cmp, nil))
+}
+
+func (c *onDiskCache) Rem(userID, messageID string) error {
+ defer c.update()
+
+ return os.Remove(c.getMessagePath(userID, messageID))
+}
+
+func (c *onDiskCache) readFile(path string) ([]byte, error) {
+ c.rsem.Lock()
+ defer c.rsem.Unlock()
+
+ // Wait before reading in case the file is currently being written.
+ c.pending.wait(path)
+
+ return ioutil.ReadFile(filepath.Clean(path))
+}
+
+func (c *onDiskCache) writeFile(path string, b []byte) error {
+ c.wsem.Lock()
+ defer c.wsem.Unlock()
+
+ // Mark the file as currently being written.
+ // If it's already being written, wait for it to be done and return nil.
+ // NOTE(GODT-1158): Let's hope it succeeded...
+ if ok := c.pending.add(path); !ok {
+ c.pending.wait(path)
+ return nil
+ }
+ defer c.pending.done(path)
+
+ // Reduce the approximate free space (update it exactly later).
+ c.lock.Lock()
+ c.diskFree -= uint64(len(b))
+ c.lock.Unlock()
+
+ // Update the diskFree eventually.
+ defer c.update()
+
+ // NOTE(GODT-1158): What happens when this fails? Should be fixed eventually.
+ return ioutil.WriteFile(filepath.Clean(path), b, 0600)
+}
+
+func (c *onDiskCache) hasSpace(size int) bool {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.opts.MinFreeAbs > 0 {
+ if c.diskFree-uint64(size) < c.opts.MinFreeAbs {
+ return false
+ }
+ }
+
+ if c.opts.MinFreeRat > 0 {
+ if float64(c.diskFree-uint64(size))/float64(c.diskSize) < c.opts.MinFreeRat {
+ return false
+ }
+ }
+
+ return true
+}
+
+func (c *onDiskCache) update() {
+ go func() {
+ c.once.Do(func() {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ // Update the free space.
+ c.diskFree = du.NewDiskUsage(c.path).Available()
+
+ // Reset the Once object (so we can update again).
+ c.once = &sync.Once{}
+ })
+ }()
+}
+
+func (c *onDiskCache) getUserPath(userID string) string {
+ return filepath.Join(c.path, getHash(userID))
+}
+
+func (c *onDiskCache) getMessagePath(userID, messageID string) string {
+ return filepath.Join(c.getUserPath(userID), getHash(messageID))
+}
diff --git a/internal/store/cache/hash.go b/internal/store/cache/hash.go
new file mode 100644
index 00000000..fc07771b
--- /dev/null
+++ b/internal/store/cache/hash.go
@@ -0,0 +1,33 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package cache
+
+import (
+ "crypto/sha256"
+ "encoding/hex"
+)
+
+func getHash(name string) string {
+ hash := sha256.New()
+
+ if _, err := hash.Write([]byte(name)); err != nil {
+ panic(err)
+ }
+
+ return hex.EncodeToString(hash.Sum(nil))
+}
diff --git a/internal/store/cache/memory.go b/internal/store/cache/memory.go
new file mode 100644
index 00000000..d364e5f9
--- /dev/null
+++ b/internal/store/cache/memory.go
@@ -0,0 +1,104 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package cache
+
+import (
+ "errors"
+ "sync"
+)
+
+type inMemoryCache struct {
+ lock sync.RWMutex
+ data map[string]map[string][]byte
+ size, limit int
+}
+
+// NewInMemoryCache creates a new in memory cache which stores up to the given number of bytes of cached data.
+// NOTE(GODT-1158): Make this threadsafe.
+func NewInMemoryCache(limit int) Cache {
+ return &inMemoryCache{
+ data: make(map[string]map[string][]byte),
+ limit: limit,
+ }
+}
+
+func (c *inMemoryCache) Unlock(userID string, passphrase []byte) error {
+ c.data[userID] = make(map[string][]byte)
+ return nil
+}
+
+func (c *inMemoryCache) Delete(userID string) error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ for _, message := range c.data[userID] {
+ c.size -= len(message)
+ }
+
+ delete(c.data, userID)
+
+ return nil
+}
+
+// Has returns whether the given message exists in the cache.
+func (c *inMemoryCache) Has(userID, messageID string) bool {
+ if _, err := c.Get(userID, messageID); err != nil {
+ return false
+ }
+
+ return true
+}
+
+func (c *inMemoryCache) Get(userID, messageID string) ([]byte, error) {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+
+ literal, ok := c.data[userID][messageID]
+ if !ok {
+ return nil, errors.New("no such message in cache")
+ }
+
+ return literal, nil
+}
+
+// NOTE(GODT-1158): What to actually do when memory limit is reached? Replace something existing? Return error? Drop silently?
+// NOTE(GODT-1158): Pull in cache-rotating feature from old IMAP cache.
+func (c *inMemoryCache) Set(userID, messageID string, literal []byte) error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.size+len(literal) > c.limit {
+ return nil
+ }
+
+ c.size += len(literal)
+ c.data[userID][messageID] = literal
+
+ return nil
+}
+
+func (c *inMemoryCache) Rem(userID, messageID string) error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ c.size -= len(c.data[userID][messageID])
+
+ delete(c.data[userID], messageID)
+
+ return nil
+}
diff --git a/internal/store/cache/options.go b/internal/store/cache/options.go
new file mode 100644
index 00000000..a2dd9cb3
--- /dev/null
+++ b/internal/store/cache/options.go
@@ -0,0 +1,25 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package cache
+
+type Options struct {
+ MinFreeAbs uint64
+ MinFreeRat float64
+ ConcurrentRead int
+ ConcurrentWrite int
+}
diff --git a/internal/store/cache/pending.go b/internal/store/cache/pending.go
new file mode 100644
index 00000000..29314472
--- /dev/null
+++ b/internal/store/cache/pending.go
@@ -0,0 +1,61 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package cache
+
+import "sync"
+
+type pending struct {
+ lock sync.Mutex
+ path map[string]chan struct{}
+}
+
+func newPending() *pending {
+ return &pending{path: make(map[string]chan struct{})}
+}
+
+func (p *pending) add(path string) bool {
+ p.lock.Lock()
+ defer p.lock.Unlock()
+
+ if _, ok := p.path[path]; ok {
+ return false
+ }
+
+ p.path[path] = make(chan struct{})
+
+ return true
+}
+
+func (p *pending) wait(path string) {
+ p.lock.Lock()
+ ch, ok := p.path[path]
+ p.lock.Unlock()
+
+ if ok {
+ <-ch
+ }
+}
+
+func (p *pending) done(path string) {
+ p.lock.Lock()
+ defer p.lock.Unlock()
+
+ defer close(p.path[path])
+
+ delete(p.path, path)
+}
diff --git a/internal/store/cache/pending_test.go b/internal/store/cache/pending_test.go
new file mode 100644
index 00000000..82bc2979
--- /dev/null
+++ b/internal/store/cache/pending_test.go
@@ -0,0 +1,51 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package cache
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestPending(t *testing.T) {
+ pending := newPending()
+
+ pending.add("1")
+ pending.add("2")
+ pending.add("3")
+
+ resCh := make(chan string)
+
+ go func() { pending.wait("1"); resCh <- "1" }()
+ go func() { pending.wait("2"); resCh <- "2" }()
+ go func() { pending.wait("3"); resCh <- "3" }()
+
+ pending.done("1")
+ assert.Equal(t, "1", <-resCh)
+
+ pending.done("2")
+ assert.Equal(t, "2", <-resCh)
+
+ pending.done("3")
+ assert.Equal(t, "3", <-resCh)
+}
+
+func TestPendingUnknown(t *testing.T) {
+ newPending().wait("this is not currently being waited")
+}
diff --git a/internal/store/cache/types.go b/internal/store/cache/types.go
new file mode 100644
index 00000000..a1c7c9db
--- /dev/null
+++ b/internal/store/cache/types.go
@@ -0,0 +1,28 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package cache
+
+type Cache interface {
+ Unlock(userID string, passphrase []byte) error
+ Delete(userID string) error
+
+ Has(userID, messageID string) bool
+ Get(userID, messageID string) ([]byte, error)
+ Set(userID, messageID string, literal []byte) error
+ Rem(userID, messageID string) error
+}
diff --git a/internal/store/cache_watcher.go b/internal/store/cache_watcher.go
new file mode 100644
index 00000000..c8f3811a
--- /dev/null
+++ b/internal/store/cache_watcher.go
@@ -0,0 +1,63 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package store
+
+import "time"
+
+func (store *Store) StartWatcher() {
+ store.done = make(chan struct{})
+
+ go func() {
+ ticker := time.NewTicker(3 * time.Minute)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ // NOTE(GODT-1158): Race condition here? What if DB was already closed?
+ messageIDs, err := store.getAllMessageIDs()
+ if err != nil {
+ return
+ }
+
+ for _, messageID := range messageIDs {
+ if !store.IsCached(messageID) {
+ store.cacher.newJob(messageID)
+ }
+ }
+
+ case <-store.done:
+ return
+ }
+ }
+ }()
+}
+
+func (store *Store) stopWatcher() {
+ if store.done == nil {
+ return
+ }
+
+ select {
+ default:
+ close(store.done)
+
+ case <-store.done:
+ return
+ }
+}
diff --git a/internal/store/cache_worker.go b/internal/store/cache_worker.go
new file mode 100644
index 00000000..68173428
--- /dev/null
+++ b/internal/store/cache_worker.go
@@ -0,0 +1,104 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package store
+
+import (
+ "sync"
+
+ "github.com/sirupsen/logrus"
+)
+
+type Cacher struct {
+ storer Storer
+ jobs chan string
+ done chan struct{}
+ started bool
+ wg *sync.WaitGroup
+}
+
+type Storer interface {
+ IsCached(messageID string) bool
+ BuildAndCacheMessage(messageID string) error
+}
+
+func newCacher(storer Storer) *Cacher {
+ return &Cacher{
+ storer: storer,
+ jobs: make(chan string),
+ done: make(chan struct{}),
+ wg: &sync.WaitGroup{},
+ }
+}
+
+// newJob sends a new job to the cacher if it's running.
+func (cacher *Cacher) newJob(messageID string) {
+ if !cacher.started {
+ return
+ }
+
+ select {
+ case <-cacher.done:
+ return
+
+ default:
+ if !cacher.storer.IsCached(messageID) {
+ cacher.wg.Add(1)
+ go func() { cacher.jobs <- messageID }()
+ }
+ }
+}
+
+func (cacher *Cacher) start() {
+ cacher.started = true
+
+ go func() {
+ for {
+ select {
+ case messageID := <-cacher.jobs:
+ go cacher.handleJob(messageID)
+
+ case <-cacher.done:
+ return
+ }
+ }
+ }()
+}
+
+func (cacher *Cacher) handleJob(messageID string) {
+ defer cacher.wg.Done()
+
+ if err := cacher.storer.BuildAndCacheMessage(messageID); err != nil {
+ logrus.WithError(err).Error("Failed to build and cache message")
+ } else {
+ logrus.WithField("messageID", messageID).Trace("Message cached")
+ }
+}
+
+func (cacher *Cacher) stop() {
+ cacher.started = false
+
+ cacher.wg.Wait()
+
+ select {
+ case <-cacher.done:
+ return
+
+ default:
+ close(cacher.done)
+ }
+}
diff --git a/internal/store/cache_worker_test.go b/internal/store/cache_worker_test.go
new file mode 100644
index 00000000..6bd6b121
--- /dev/null
+++ b/internal/store/cache_worker_test.go
@@ -0,0 +1,103 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package store
+
+import (
+ "testing"
+
+ storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
+ "github.com/golang/mock/gomock"
+ "github.com/pkg/errors"
+)
+
+func withTestCacher(t *testing.T, doTest func(storer *storemocks.MockStorer, cacher *Cacher)) {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ // Mock storer used to build/cache messages.
+ storer := storemocks.NewMockStorer(ctrl)
+
+ // Create a new cacher pointing to the fake store.
+ cacher := newCacher(storer)
+
+ // Start the cacher and wait for it to stop.
+ cacher.start()
+ defer cacher.stop()
+
+ doTest(storer, cacher)
+}
+
+func TestCacher(t *testing.T) {
+ // If the message is not yet cached, we should expect to try to build and cache it.
+ withTestCacher(t, func(storer *storemocks.MockStorer, cacher *Cacher) {
+ storer.EXPECT().IsCached("messageID").Return(false)
+ storer.EXPECT().BuildAndCacheMessage("messageID").Return(nil)
+ cacher.newJob("messageID")
+ })
+}
+
+func TestCacherAlreadyCached(t *testing.T) {
+ // If the message is already cached, we should not try to build it.
+ withTestCacher(t, func(storer *storemocks.MockStorer, cacher *Cacher) {
+ storer.EXPECT().IsCached("messageID").Return(true)
+ cacher.newJob("messageID")
+ })
+}
+
+func TestCacherFail(t *testing.T) {
+ // If building the message fails, we should not try to cache it.
+ withTestCacher(t, func(storer *storemocks.MockStorer, cacher *Cacher) {
+ storer.EXPECT().IsCached("messageID").Return(false)
+ storer.EXPECT().BuildAndCacheMessage("messageID").Return(errors.New("failed to build message"))
+ cacher.newJob("messageID")
+ })
+}
+
+func TestCacherStop(t *testing.T) {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ // Mock storer used to build/cache messages.
+ storer := storemocks.NewMockStorer(ctrl)
+
+ // Create a new cacher pointing to the fake store.
+ cacher := newCacher(storer)
+
+ // Start the cacher.
+ cacher.start()
+
+ // Send a job -- this should succeed.
+ storer.EXPECT().IsCached("messageID").Return(false)
+ storer.EXPECT().BuildAndCacheMessage("messageID").Return(nil)
+ cacher.newJob("messageID")
+
+ // Stop the cacher.
+ cacher.stop()
+
+ // Send more jobs -- these should all be dropped.
+ cacher.newJob("messageID2")
+ cacher.newJob("messageID3")
+ cacher.newJob("messageID4")
+ cacher.newJob("messageID5")
+
+ // Stopping the cacher multiple times is safe.
+ cacher.stop()
+ cacher.stop()
+ cacher.stop()
+ cacher.stop()
+}
diff --git a/internal/store/change_test.go b/internal/store/change_test.go
index c4866d58..7fba37bc 100644
--- a/internal/store/change_test.go
+++ b/internal/store/change_test.go
@@ -34,7 +34,7 @@ func TestNotifyChangeCreateOrUpdateMessage(t *testing.T) {
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(1), uint32(1), gomock.Any(), false)
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(2), uint32(2), gomock.Any(), false)
- m.newStoreNoEvents(true)
+ m.newStoreNoEvents(t, true)
m.store.SetChangeNotifier(m.changeNotifier)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
@@ -49,7 +49,7 @@ func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) {
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(1), uint32(1), gomock.Any(), false)
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(2), uint32(2), gomock.Any(), false)
- m.newStoreNoEvents(true)
+ m.newStoreNoEvents(t, true)
m.store.SetChangeNotifier(m.changeNotifier)
msg1 := getTestMessage("msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
@@ -61,7 +61,7 @@ func TestNotifyChangeDeleteMessage(t *testing.T) {
m, clear := initMocks(t)
defer clear()
- m.newStoreNoEvents(true)
+ m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
diff --git a/internal/store/event_loop.go b/internal/store/event_loop.go
index 50453c9a..acdc469f 100644
--- a/internal/store/event_loop.go
+++ b/internal/store/event_loop.go
@@ -38,7 +38,7 @@ const (
)
type eventLoop struct {
- cache *Cache
+ currentEvents *Events
currentEventID string
currentEvent *pmapi.Event
pollCh chan chan struct{}
@@ -51,26 +51,26 @@ type eventLoop struct {
log *logrus.Entry
- store *Store
- user BridgeUser
- events listener.Listener
+ store *Store
+ user BridgeUser
+ listener listener.Listener
}
-func newEventLoop(cache *Cache, store *Store, user BridgeUser, events listener.Listener) *eventLoop {
+func newEventLoop(currentEvents *Events, store *Store, user BridgeUser, listener listener.Listener) *eventLoop {
eventLog := log.WithField("userID", user.ID())
eventLog.Trace("Creating new event loop")
return &eventLoop{
- cache: cache,
- currentEventID: cache.getEventID(user.ID()),
+ currentEvents: currentEvents,
+ currentEventID: currentEvents.getEventID(user.ID()),
pollCh: make(chan chan struct{}),
isRunning: false,
log: eventLog,
- store: store,
- user: user,
- events: events,
+ store: store,
+ user: user,
+ listener: listener,
}
}
@@ -89,7 +89,7 @@ func (loop *eventLoop) setFirstEventID() (err error) {
loop.currentEventID = event.EventID
- if err = loop.cache.setEventID(loop.user.ID(), loop.currentEventID); err != nil {
+ if err = loop.currentEvents.setEventID(loop.user.ID(), loop.currentEventID); err != nil {
loop.log.WithError(err).Error("Could not set latest event ID in user cache")
return
}
@@ -229,7 +229,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
if err != nil && isFdCloseToULimit() {
l.Warn("Ulimit reached")
- loop.events.Emit(bridgeEvents.RestartBridgeEvent, "")
+ loop.listener.Emit(bridgeEvents.RestartBridgeEvent, "")
err = nil
}
@@ -291,7 +291,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
// This allows the event loop to continue to function (unless the cache was broken
// and bridge stopped, in which case it will start from the old event ID anyway).
loop.currentEventID = event.EventID
- if err = loop.cache.setEventID(loop.user.ID(), event.EventID); err != nil {
+ if err = loop.currentEvents.setEventID(loop.user.ID(), event.EventID); err != nil {
return false, errors.Wrap(err, "failed to save event ID to cache")
}
}
@@ -371,7 +371,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
switch addressEvent.Action {
case pmapi.EventCreate:
log.WithField("email", addressEvent.Address.Email).Debug("Address was created")
- loop.events.Emit(bridgeEvents.AddressChangedEvent, loop.user.GetPrimaryAddress())
+ loop.listener.Emit(bridgeEvents.AddressChangedEvent, loop.user.GetPrimaryAddress())
case pmapi.EventUpdate:
oldAddress := oldList.ByID(addressEvent.ID)
@@ -383,7 +383,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
email := oldAddress.Email
log.WithField("email", email).Debug("Address was updated")
if addressEvent.Address.Receive != oldAddress.Receive {
- loop.events.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
+ loop.listener.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
}
case pmapi.EventDelete:
@@ -396,7 +396,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
email := oldAddress.Email
log.WithField("email", email).Debug("Address was deleted")
loop.user.CloseConnection(email)
- loop.events.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
+ loop.listener.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
case pmapi.EventUpdateFlags:
log.Error("EventUpdateFlags for address event is uknown operation")
}
diff --git a/internal/store/event_loop_test.go b/internal/store/event_loop_test.go
index 845a03a6..3e0e23d1 100644
--- a/internal/store/event_loop_test.go
+++ b/internal/store/event_loop_test.go
@@ -53,7 +53,7 @@ func TestEventLoopProcessMoreEvents(t *testing.T) {
More: false,
}, nil),
)
- m.newStoreNoEvents(true)
+ m.newStoreNoEvents(t, true)
// Event loop runs in goroutine started during store creation (newStoreNoEvents).
// Force to run the next event.
@@ -78,7 +78,7 @@ func TestEventLoopUpdateMessageFromLoop(t *testing.T) {
subject := "old subject"
newSubject := "new subject"
- m.newStoreNoEvents(true, &pmapi.Message{
+ m.newStoreNoEvents(t, true, &pmapi.Message{
ID: "msg1",
Subject: subject,
})
@@ -106,7 +106,7 @@ func TestEventLoopDeletionNotPaused(t *testing.T) {
m, clear := initMocks(t)
defer clear()
- m.newStoreNoEvents(true, &pmapi.Message{
+ m.newStoreNoEvents(t, true, &pmapi.Message{
ID: "msg1",
Subject: "subject",
LabelIDs: []string{"label"},
@@ -133,7 +133,7 @@ func TestEventLoopDeletionPaused(t *testing.T) {
m, clear := initMocks(t)
defer clear()
- m.newStoreNoEvents(true, &pmapi.Message{
+ m.newStoreNoEvents(t, true, &pmapi.Message{
ID: "msg1",
Subject: "subject",
LabelIDs: []string{"label"},
diff --git a/internal/store/events.go b/internal/store/events.go
new file mode 100644
index 00000000..0e3d6804
--- /dev/null
+++ b/internal/store/events.go
@@ -0,0 +1,116 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package store
+
+import (
+ "encoding/json"
+ "os"
+ "sync"
+
+ "github.com/pkg/errors"
+)
+
+// Events caches the last event IDs for all accounts (there should be only one instance).
+type Events struct {
+ // eventMap is map from userID => key (such as last event) => value (such as event ID).
+ eventMap map[string]map[string]string
+ path string
+ lock *sync.RWMutex
+}
+
+// NewEvents constructs a new event cache at the given path.
+func NewEvents(path string) *Events {
+ return &Events{
+ path: path,
+ lock: &sync.RWMutex{},
+ }
+}
+
+func (c *Events) getEventID(userID string) string {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if err := c.loadEvents(); err != nil {
+ log.WithError(err).Warn("Problem to load store events")
+ }
+
+ if c.eventMap == nil {
+ c.eventMap = map[string]map[string]string{}
+ }
+ if c.eventMap[userID] == nil {
+ c.eventMap[userID] = map[string]string{}
+ }
+
+ return c.eventMap[userID]["events"]
+}
+
+func (c *Events) setEventID(userID, eventID string) error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.eventMap[userID] == nil {
+ c.eventMap[userID] = map[string]string{}
+ }
+ c.eventMap[userID]["events"] = eventID
+
+ return c.saveEvents()
+}
+
+func (c *Events) loadEvents() error {
+ if c.eventMap != nil {
+ return nil
+ }
+
+ f, err := os.Open(c.path)
+ if err != nil {
+ return err
+ }
+ defer f.Close() //nolint[errcheck]
+
+ return json.NewDecoder(f).Decode(&c.eventMap)
+}
+
+func (c *Events) saveEvents() error {
+ if c.eventMap == nil {
+ return errors.New("events: cannot save events: events map is nil")
+ }
+
+ f, err := os.Create(c.path)
+ if err != nil {
+ return err
+ }
+ defer f.Close() //nolint[errcheck]
+
+ return json.NewEncoder(f).Encode(c.eventMap)
+}
+
+func (c *Events) clearUserEvents(userID string) error {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
+ if c.eventMap == nil {
+ log.WithField("user", userID).Warning("Cannot clear user events: event map is nil")
+ return nil
+ }
+
+ log.WithField("user", userID).Trace("Removing user events from event loop")
+
+ delete(c.eventMap, userID)
+
+ return c.saveEvents()
+}
diff --git a/internal/store/mailbox_counts_test.go b/internal/store/mailbox_counts_test.go
index 21f3c374..c3352823 100644
--- a/internal/store/mailbox_counts_test.go
+++ b/internal/store/mailbox_counts_test.go
@@ -107,7 +107,7 @@ func checkCounts(t testing.TB, wantCounts []*pmapi.MessagesCount, haveStore *Sto
func TestMailboxCountRemove(t *testing.T) {
m, clear := initMocks(t)
defer clear()
- m.newStoreNoEvents(true)
+ m.newStoreNoEvents(t, true)
testCounts := []*pmapi.MessagesCount{
{LabelID: "label1", Total: 100, Unread: 0},
diff --git a/internal/store/mailbox_ids_test.go b/internal/store/mailbox_ids_test.go
index 77cd76cb..1541dff4 100644
--- a/internal/store/mailbox_ids_test.go
+++ b/internal/store/mailbox_ids_test.go
@@ -35,7 +35,7 @@ func TestGetSequenceNumberAndGetUID(t *testing.T) {
m, clear := initMocks(t)
defer clear()
- m.newStoreNoEvents(true)
+ m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
@@ -80,7 +80,7 @@ func TestGetUIDByHeader(t *testing.T) { //nolint[funlen]
m, clear := initMocks(t)
defer clear()
- m.newStoreNoEvents(true)
+ m.newStoreNoEvents(t, true)
tstMsg := getTestMessage("msg1", "Without external ID", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
diff --git a/internal/store/message.go b/internal/store/message.go
index fe05609e..e12dbda6 100644
--- a/internal/store/message.go
+++ b/internal/store/message.go
@@ -67,40 +67,19 @@ func (message *Message) Message() *pmapi.Message {
return message.msg
}
-// IsMarkedDeleted returns true if message is marked as deleted for specific
-// mailbox.
+// IsMarkedDeleted returns true if message is marked as deleted for specific mailbox.
func (message *Message) IsMarkedDeleted() bool {
- isMarkedAsDeleted := false
- err := message.storeMailbox.db().View(func(tx *bolt.Tx) error {
+ var isMarkedAsDeleted bool
+
+ if err := message.storeMailbox.db().View(func(tx *bolt.Tx) error {
isMarkedAsDeleted = message.storeMailbox.txGetDeletedIDsBucket(tx).Get([]byte(message.msg.ID)) != nil
return nil
- })
- if err != nil {
+ }); err != nil {
message.storeMailbox.log.WithError(err).Error("Not able to retrieve deleted mark, assuming false.")
return false
}
- return isMarkedAsDeleted
-}
-// SetSize updates the information about size of decrypted message which can be
-// used for IMAP. This should not trigger any IMAP update.
-// NOTE: The size from the server corresponds to pure body bytes. Hence it
-// should not be used. The correct size has to be calculated from decrypted and
-// built message.
-func (message *Message) SetSize(size int64) error {
- message.msg.Size = size
- txUpdate := func(tx *bolt.Tx) error {
- stored, err := message.store.txGetMessage(tx, message.msg.ID)
- if err != nil {
- return err
- }
- stored.Size = size
- return message.store.txPutMessage(
- tx.Bucket(metadataBucket),
- stored,
- )
- }
- return message.store.db.Update(txUpdate)
+ return isMarkedAsDeleted
}
// SetContentTypeAndHeader updates the information about content type and
@@ -112,7 +91,7 @@ func (message *Message) SetSize(size int64) error {
func (message *Message) SetContentTypeAndHeader(mimeType string, header mail.Header) error {
message.msg.MIMEType = mimeType
message.msg.Header = header
- txUpdate := func(tx *bolt.Tx) error {
+ return message.store.db.Update(func(tx *bolt.Tx) error {
stored, err := message.store.txGetMessage(tx, message.msg.ID)
if err != nil {
return err
@@ -123,34 +102,26 @@ func (message *Message) SetContentTypeAndHeader(mimeType string, header mail.Hea
tx.Bucket(metadataBucket),
stored,
)
- }
- return message.store.db.Update(txUpdate)
-}
-
-// SetHeader checks header can be parsed and if yes it stores header bytes in
-// database.
-func (message *Message) SetHeader(header []byte) error {
- _, err := textproto.NewReader(bufio.NewReader(bytes.NewReader(header))).ReadMIMEHeader()
- if err != nil {
- return err
- }
- return message.store.db.Update(func(tx *bolt.Tx) error {
- return tx.Bucket(headersBucket).Put([]byte(message.ID()), header)
})
}
// IsFullHeaderCached will check that valid full header is stored in DB.
func (message *Message) IsFullHeaderCached() bool {
- header, err := message.getRawHeader()
- return err == nil && header != nil
-}
-
-func (message *Message) getRawHeader() (raw []byte, err error) {
- err = message.store.db.View(func(tx *bolt.Tx) error {
- raw = tx.Bucket(headersBucket).Get([]byte(message.ID()))
+ var raw []byte
+ err := message.store.db.View(func(tx *bolt.Tx) error {
+ raw = tx.Bucket(bodystructureBucket).Get([]byte(message.ID()))
return nil
})
- return
+ return err == nil && raw != nil
+}
+
+func (message *Message) getRawHeader() ([]byte, error) {
+ bs, err := message.GetBodyStructure()
+ if err != nil {
+ return nil, err
+ }
+
+ return bs.GetMailHeaderBytes()
}
// GetHeader will return cached header from DB.
@@ -178,44 +149,79 @@ func (message *Message) GetMIMEHeader() textproto.MIMEHeader {
return header
}
-// SetBodyStructure stores serialized body structure in database.
-func (message *Message) SetBodyStructure(bs *pkgMsg.BodyStructure) error {
- txUpdate := func(tx *bolt.Tx) error {
- return message.store.txPutBodyStructure(
- tx.Bucket(bodystructureBucket),
- message.ID(), bs,
- )
- }
- return message.store.db.Update(txUpdate)
-}
+// GetBodyStructure returns the message's body structure.
+// It checks first if it's in the store. If it is, it returns it from store,
+// otherwise it computes it from the message cache (and saves the result to the store).
+func (message *Message) GetBodyStructure() (*pkgMsg.BodyStructure, error) {
+ var raw []byte
-// GetBodyStructure deserializes body structure from database. If body structure
-// is not in database it returns nil error and nil body structure. If error
-// occurs it returns nil body structure.
-func (message *Message) GetBodyStructure() (bs *pkgMsg.BodyStructure, err error) {
- txRead := func(tx *bolt.Tx) error {
- bs, err = message.store.txGetBodyStructure(
- tx.Bucket(bodystructureBucket),
- message.ID(),
- )
- return err
- }
- if err = message.store.db.View(txRead); err != nil {
+ if err := message.store.db.View(func(tx *bolt.Tx) error {
+ raw = tx.Bucket(bodystructureBucket).Get([]byte(message.ID()))
+ return nil
+ }); err != nil {
return nil, err
}
+
+ if len(raw) > 0 {
+ // If not possible to deserialize just continue with build.
+ if bs, err := pkgMsg.DeserializeBodyStructure(raw); err == nil {
+ return bs, nil
+ }
+ }
+
+ literal, err := message.store.getCachedMessage(message.ID())
+ if err != nil {
+ return nil, err
+ }
+
+ bs, err := pkgMsg.NewBodyStructure(bytes.NewReader(literal))
+ if err != nil {
+ return nil, err
+ }
+
+ if raw, err = bs.Serialize(); err != nil {
+ return nil, err
+ }
+
+ if err := message.store.db.Update(func(tx *bolt.Tx) error {
+ return tx.Bucket(bodystructureBucket).Put([]byte(message.ID()), raw)
+ }); err != nil {
+ return nil, err
+ }
+
return bs, nil
}
-func (message *Message) IncreaseBuildCount() (times uint32, err error) {
- txUpdate := func(tx *bolt.Tx) error {
- times, err = message.store.txIncreaseMsgBuildCount(
- tx.Bucket(msgBuildCountBucket),
- message.ID(),
- )
- return err
- }
- if err = message.store.db.Update(txUpdate); err != nil {
+// GetRFC822 returns the raw message literal.
+func (message *Message) GetRFC822() ([]byte, error) {
+ return message.store.getCachedMessage(message.ID())
+}
+
+// GetRFC822Size returns the size of the raw message literal.
+func (message *Message) GetRFC822Size() (uint32, error) {
+ var raw []byte
+
+ if err := message.store.db.View(func(tx *bolt.Tx) error {
+ raw = tx.Bucket(sizeBucket).Get([]byte(message.ID()))
+ return nil
+ }); err != nil {
return 0, err
}
- return times, nil
+
+ if len(raw) > 0 {
+ return btoi(raw), nil
+ }
+
+ literal, err := message.store.getCachedMessage(message.ID())
+ if err != nil {
+ return 0, err
+ }
+
+ if err := message.store.db.Update(func(tx *bolt.Tx) error {
+ return tx.Bucket(sizeBucket).Put([]byte(message.ID()), itob(uint32(len(literal))))
+ }); err != nil {
+ return 0, err
+ }
+
+ return uint32(len(literal)), nil
}
diff --git a/internal/store/mocks/mocks.go b/internal/store/mocks/mocks.go
index 7adf0ed3..14fb5831 100644
--- a/internal/store/mocks/mocks.go
+++ b/internal/store/mocks/mocks.go
@@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser,ChangeNotifier)
+// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser,ChangeNotifier,Storer)
// Package mocks is a generated GoMock package.
package mocks
@@ -318,3 +318,54 @@ func (mr *MockChangeNotifierMockRecorder) UpdateMessage(arg0, arg1, arg2, arg3,
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMessage", reflect.TypeOf((*MockChangeNotifier)(nil).UpdateMessage), arg0, arg1, arg2, arg3, arg4, arg5)
}
+
+// MockStorer is a mock of Storer interface.
+type MockStorer struct {
+ ctrl *gomock.Controller
+ recorder *MockStorerMockRecorder
+}
+
+// MockStorerMockRecorder is the mock recorder for MockStorer.
+type MockStorerMockRecorder struct {
+ mock *MockStorer
+}
+
+// NewMockStorer creates a new mock instance.
+func NewMockStorer(ctrl *gomock.Controller) *MockStorer {
+ mock := &MockStorer{ctrl: ctrl}
+ mock.recorder = &MockStorerMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockStorer) EXPECT() *MockStorerMockRecorder {
+ return m.recorder
+}
+
+// BuildAndCacheMessage mocks base method.
+func (m *MockStorer) BuildAndCacheMessage(arg0 string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "BuildAndCacheMessage", arg0)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// BuildAndCacheMessage indicates an expected call of BuildAndCacheMessage.
+func (mr *MockStorerMockRecorder) BuildAndCacheMessage(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BuildAndCacheMessage", reflect.TypeOf((*MockStorer)(nil).BuildAndCacheMessage), arg0)
+}
+
+// IsCached mocks base method.
+func (m *MockStorer) IsCached(arg0 string) bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "IsCached", arg0)
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// IsCached indicates an expected call of IsCached.
+func (mr *MockStorerMockRecorder) IsCached(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsCached", reflect.TypeOf((*MockStorer)(nil).IsCached), arg0)
+}
diff --git a/internal/store/store.go b/internal/store/store.go
index ac960e1f..03de5d0f 100644
--- a/internal/store/store.go
+++ b/internal/store/store.go
@@ -26,8 +26,11 @@ import (
"time"
"github.com/ProtonMail/proton-bridge/internal/sentry"
+ "github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/pkg/listener"
+ "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
+ "github.com/ProtonMail/proton-bridge/pkg/pool"
"github.com/hashicorp/go-multierror"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
@@ -52,19 +55,21 @@ var (
// Database structure:
// * metadata
- // * {messageID} -> message data (subject, from, to, time, body size, ...)
+ // * {messageID} -> message data (subject, from, to, time, ...)
// * headers
// * {messageID} -> header bytes
// * bodystructure
// * {messageID} -> message body structure
- // * msgbuildcount
- // * {messageID} -> uint32 number of message builds to track re-sync issues
+ // * size
+ // * {messageID} -> uint32 value
// * counts
// * {mailboxID} -> mailboxCounts: totalOnAPI, unreadOnAPI, labelName, labelColor, labelIsExclusive
// * address_info
// * {index} -> {address, addressID}
// * address_mode
// * mode -> string split or combined
+ // * cache_passphrase
+ // * passphrase -> cache passphrase (pgp encrypted message)
// * mailboxes_version
// * version -> uint32 value
// * sync_state
@@ -79,19 +84,20 @@ var (
// * {messageID} -> uint32 imapUID
// * deleted_ids (can be missing or have no keys)
// * {messageID} -> true
- metadataBucket = []byte("metadata") //nolint[gochecknoglobals]
- headersBucket = []byte("headers") //nolint[gochecknoglobals]
- bodystructureBucket = []byte("bodystructure") //nolint[gochecknoglobals]
- msgBuildCountBucket = []byte("msgbuildcount") //nolint[gochecknoglobals]
- countsBucket = []byte("counts") //nolint[gochecknoglobals]
- addressInfoBucket = []byte("address_info") //nolint[gochecknoglobals]
- addressModeBucket = []byte("address_mode") //nolint[gochecknoglobals]
- syncStateBucket = []byte("sync_state") //nolint[gochecknoglobals]
- mailboxesBucket = []byte("mailboxes") //nolint[gochecknoglobals]
- imapIDsBucket = []byte("imap_ids") //nolint[gochecknoglobals]
- apiIDsBucket = []byte("api_ids") //nolint[gochecknoglobals]
- deletedIDsBucket = []byte("deleted_ids") //nolint[gochecknoglobals]
- mboxVersionBucket = []byte("mailboxes_version") //nolint[gochecknoglobals]
+ metadataBucket = []byte("metadata") //nolint[gochecknoglobals]
+ headersBucket = []byte("headers") //nolint[gochecknoglobals]
+ bodystructureBucket = []byte("bodystructure") //nolint[gochecknoglobals]
+ sizeBucket = []byte("size") //nolint[gochecknoglobals]
+ countsBucket = []byte("counts") //nolint[gochecknoglobals]
+ addressInfoBucket = []byte("address_info") //nolint[gochecknoglobals]
+ addressModeBucket = []byte("address_mode") //nolint[gochecknoglobals]
+ cachePassphraseBucket = []byte("cache_passphrase") //nolint[gochecknoglobals]
+ syncStateBucket = []byte("sync_state") //nolint[gochecknoglobals]
+ mailboxesBucket = []byte("mailboxes") //nolint[gochecknoglobals]
+ imapIDsBucket = []byte("imap_ids") //nolint[gochecknoglobals]
+ apiIDsBucket = []byte("api_ids") //nolint[gochecknoglobals]
+ deletedIDsBucket = []byte("deleted_ids") //nolint[gochecknoglobals]
+ mboxVersionBucket = []byte("mailboxes_version") //nolint[gochecknoglobals]
// ErrNoSuchAPIID when mailbox does not have API ID.
ErrNoSuchAPIID = errors.New("no such api id") //nolint[gochecknoglobals]
@@ -117,18 +123,23 @@ func exposeContextForSMTP() context.Context {
type Store struct {
sentryReporter *sentry.Reporter
panicHandler PanicHandler
- eventLoop *eventLoop
user BridgeUser
+ eventLoop *eventLoop
+ currentEvents *Events
log *logrus.Entry
- cache *Cache
filePath string
db *bolt.DB
lock *sync.RWMutex
addresses map[string]*Address
notifier ChangeNotifier
+ builder *message.Builder
+ cache cache.Cache
+ cacher *Cacher
+ done chan struct{}
+
isSyncRunning bool
syncCooldown cooldown
addressMode addressMode
@@ -139,12 +150,14 @@ func New( // nolint[funlen]
sentryReporter *sentry.Reporter,
panicHandler PanicHandler,
user BridgeUser,
- events listener.Listener,
+ listener listener.Listener,
+ cache cache.Cache,
+ builder *message.Builder,
path string,
- cache *Cache,
+ currentEvents *Events,
) (store *Store, err error) {
- if user == nil || events == nil || cache == nil {
- return nil, fmt.Errorf("missing parameters - user: %v, events: %v, cache: %v", user, events, cache)
+ if user == nil || listener == nil || currentEvents == nil {
+ return nil, fmt.Errorf("missing parameters - user: %v, listener: %v, currentEvents: %v", user, listener, currentEvents)
}
l := log.WithField("user", user.ID())
@@ -160,21 +173,29 @@ func New( // nolint[funlen]
bdb, err := openBoltDatabase(path)
if err != nil {
- err = errors.Wrap(err, "failed to open store database")
- return
+ return nil, errors.Wrap(err, "failed to open store database")
}
store = &Store{
sentryReporter: sentryReporter,
panicHandler: panicHandler,
user: user,
- cache: cache,
- filePath: path,
- db: bdb,
- lock: &sync.RWMutex{},
- log: l,
+ currentEvents: currentEvents,
+
+ log: l,
+
+ filePath: path,
+ db: bdb,
+ lock: &sync.RWMutex{},
+
+ builder: builder,
+ cache: cache,
}
+ // Create a new cacher. It's not started yet.
+ // NOTE(GODT-1158): I hate this circular dependency store->cacher->store :(
+ store.cacher = newCacher(store)
+
// Minimal increase is event pollInterval, doubles every failed retry up to 5 minutes.
store.syncCooldown.setExponentialWait(pollInterval, 2, 5*time.Minute)
@@ -188,7 +209,7 @@ func New( // nolint[funlen]
}
if user.IsConnected() {
- store.eventLoop = newEventLoop(cache, store, user, events)
+ store.eventLoop = newEventLoop(currentEvents, store, user, listener)
go func() {
defer store.panicHandler.HandlePanic()
store.eventLoop.start()
@@ -216,10 +237,11 @@ func openBoltDatabase(filePath string) (db *bolt.DB, err error) {
metadataBucket,
headersBucket,
bodystructureBucket,
- msgBuildCountBucket,
+ sizeBucket,
countsBucket,
addressInfoBucket,
addressModeBucket,
+ cachePassphraseBucket,
syncStateBucket,
mailboxesBucket,
mboxVersionBucket,
@@ -365,6 +387,24 @@ func (store *Store) addAddress(address, addressID string, labels []*pmapi.Label)
return
}
+// newBuildJob returns a new build job for the given message using the store's message builder.
+func (store *Store) newBuildJob(messageID string, priority int) (*message.Job, pool.DoneFunc) {
+ return store.builder.NewJobWithOptions(
+ context.Background(),
+ store.client(),
+ messageID,
+ message.JobOptions{
+ IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead.
+ SanitizeDate: true, // Whether to replace all dates before 1970 with RFC822's birthdate.
+ AddInternalID: true, // Whether to include MessageID as X-Pm-Internal-Id.
+ AddExternalID: true, // Whether to include ExternalID as X-Pm-External-Id.
+ AddMessageDate: true, // Whether to include message time as X-Pm-Date.
+ AddMessageIDReference: true, // Whether to include the MessageID in References.
+ },
+ priority,
+ )
+}
+
// Close stops the event loop and closes the database to free the file.
func (store *Store) Close() error {
store.lock.Lock()
@@ -381,12 +421,21 @@ func (store *Store) CloseEventLoop() {
}
func (store *Store) close() error {
+ // Stop the watcher first before closing the database.
+ store.stopWatcher()
+
+ // Stop the cacher.
+ store.cacher.stop()
+
+ // Stop the event loop.
store.CloseEventLoop()
+
+ // Close the database.
return store.db.Close()
}
// Remove closes and removes the database file and clears the cache file.
-func (store *Store) Remove() (err error) {
+func (store *Store) Remove() error {
store.lock.Lock()
defer store.lock.Unlock()
@@ -394,22 +443,34 @@ func (store *Store) Remove() (err error) {
var result *multierror.Error
- if err = store.close(); err != nil {
+ if err := store.close(); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to close store"))
}
- if err = RemoveStore(store.cache, store.filePath, store.user.ID()); err != nil {
+ if err := RemoveStore(store.currentEvents, store.filePath, store.user.ID()); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to remove store"))
}
+ if err := store.RemoveCache(); err != nil {
+ result = multierror.Append(result, errors.Wrap(err, "failed to remove cache"))
+ }
+
return result.ErrorOrNil()
}
+func (store *Store) RemoveCache() error {
+ if err := store.clearCachePassphrase(); err != nil {
+ logrus.WithError(err).Error("Failed to clear cache passphrase")
+ }
+
+ return store.cache.Delete(store.user.ID())
+}
+
// RemoveStore removes the database file and clears the cache file.
-func RemoveStore(cache *Cache, path, userID string) error {
+func RemoveStore(currentEvents *Events, path, userID string) error {
var result *multierror.Error
- if err := cache.clearCacheUser(userID); err != nil {
+ if err := currentEvents.clearUserEvents(userID); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to clear event loop user cache"))
}
diff --git a/internal/store/store_test.go b/internal/store/store_test.go
index 029e26a3..f594817f 100644
--- a/internal/store/store_test.go
+++ b/internal/store/store_test.go
@@ -23,13 +23,17 @@ import (
"io/ioutil"
"os"
"path/filepath"
+ "runtime"
"testing"
"time"
"github.com/ProtonMail/gopenpgp/v2/crypto"
+ "github.com/ProtonMail/proton-bridge/internal/store/cache"
storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
+ "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
+ tests "github.com/ProtonMail/proton-bridge/test"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
@@ -139,7 +143,7 @@ type mocksForStore struct {
store *Store
tmpDir string
- cache *Cache
+ cache *Events
}
func initMocks(tb testing.TB) (*mocksForStore, func()) {
@@ -162,7 +166,7 @@ func initMocks(tb testing.TB) (*mocksForStore, func()) {
require.NoError(tb, err)
cacheFile := filepath.Join(mocks.tmpDir, "cache.json")
- mocks.cache = NewCache(cacheFile)
+ mocks.cache = NewEvents(cacheFile)
return mocks, func() {
if err := recover(); err != nil {
@@ -176,13 +180,14 @@ func initMocks(tb testing.TB) (*mocksForStore, func()) {
}
}
-func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.Message) { //nolint[unparam]
+func (mocks *mocksForStore) newStoreNoEvents(t *testing.T, combinedMode bool, msgs ...*pmapi.Message) { //nolint[unparam]
mocks.user.EXPECT().ID().Return("userID").AnyTimes()
mocks.user.EXPECT().IsConnected().Return(true)
mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode)
mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client)
+ mocks.client.EXPECT().GetUserKeyRing().Return(tests.MakeKeyRing(t), nil).AnyTimes()
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: true},
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: true},
@@ -213,6 +218,8 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
mocks.panicHandler,
mocks.user,
mocks.events,
+ cache.NewInMemoryCache(1<<20),
+ message.NewBuilder(runtime.NumCPU(), runtime.NumCPU()),
filepath.Join(mocks.tmpDir, "mailbox-test.db"),
mocks.cache,
)
diff --git a/internal/store/user_message.go b/internal/store/user_message.go
index 47470109..9b06b890 100644
--- a/internal/store/user_message.go
+++ b/internal/store/user_message.go
@@ -27,7 +27,6 @@ import (
"strings"
"github.com/ProtonMail/gopenpgp/v2/crypto"
- pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
@@ -154,11 +153,6 @@ func (store *Store) checkDraftTotalSize(message *pmapi.Message, attachments []*d
return false, err
}
- msgSize := message.Size
- if msgSize == 0 {
- msgSize = int64(len(message.Body))
- }
-
var attSize int64
for _, att := range attachments {
b, err := ioutil.ReadAll(att.encReader)
@@ -169,7 +163,7 @@ func (store *Store) checkDraftTotalSize(message *pmapi.Message, attachments []*d
att.encReader = bytes.NewBuffer(b)
}
- return msgSize+attSize <= maxUpload, nil
+ return int64(len(message.Body))+attSize <= maxUpload, nil
}
func (store *Store) getDraftAction(message *pmapi.Message) int {
@@ -237,39 +231,6 @@ func (store *Store) txPutMessage(metaBucket *bolt.Bucket, onlyMeta *pmapi.Messag
return nil
}
-func (store *Store) txPutBodyStructure(bsBucket *bolt.Bucket, msgID string, bs *pkgMsg.BodyStructure) error {
- raw, err := bs.Serialize()
- if err != nil {
- return err
- }
- err = bsBucket.Put([]byte(msgID), raw)
- if err != nil {
- return errors.Wrap(err, "cannot put bodystructure bucket")
- }
- return nil
-}
-
-func (store *Store) txGetBodyStructure(bsBucket *bolt.Bucket, msgID string) (*pkgMsg.BodyStructure, error) {
- raw := bsBucket.Get([]byte(msgID))
- if len(raw) == 0 {
- return nil, nil
- }
- return pkgMsg.DeserializeBodyStructure(raw)
-}
-
-func (store *Store) txIncreaseMsgBuildCount(b *bolt.Bucket, msgID string) (uint32, error) {
- key := []byte(msgID)
- count := uint32(0)
-
- raw := b.Get(key)
- if raw != nil {
- count = btoi(raw)
- }
-
- count++
- return count, b.Put(key, itob(count))
-}
-
// createOrUpdateMessageEvent is helper to create only one message with
// createOrUpdateMessagesEvent.
func (store *Store) createOrUpdateMessageEvent(msg *pmapi.Message) error {
@@ -287,7 +248,7 @@ func (store *Store) createOrUpdateMessagesEvent(msgs []*pmapi.Message) error { /
b := tx.Bucket(metadataBucket)
for _, msg := range msgs {
clearNonMetadata(msg)
- txUpdateMetadaFromDB(b, msg, store.log)
+ txUpdateMetadataFromDB(b, msg, store.log)
}
return nil
})
@@ -341,6 +302,11 @@ func (store *Store) createOrUpdateMessagesEvent(msgs []*pmapi.Message) error { /
return err
}
+ // Notify the cacher that it should start caching messages.
+ for _, msg := range msgs {
+ store.cacher.newJob(msg.ID)
+ }
+
return nil
}
@@ -351,16 +317,12 @@ func clearNonMetadata(onlyMeta *pmapi.Message) {
onlyMeta.Attachments = nil
}
-// txUpdateMetadaFromDB changes the the onlyMeta data.
+// txUpdateMetadataFromDB changes the the onlyMeta data.
// If there is stored message in metaBucket the size, header and MIMEType are
// not changed if already set. To change these:
// * size must be updated by Message.SetSize
// * contentType and header must be updated by Message.SetContentTypeAndHeader.
-func txUpdateMetadaFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log *logrus.Entry) {
- // Size attribute on the server is counting encrypted data. We need to compute
- // "real" size of decrypted data. Negative values will be processed during fetch.
- onlyMeta.Size = -1
-
+func txUpdateMetadataFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log *logrus.Entry) {
msgb := metaBucket.Get([]byte(onlyMeta.ID))
if msgb == nil {
return
@@ -378,8 +340,7 @@ func txUpdateMetadaFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log
return
}
- // Keep already calculated size and content type.
- onlyMeta.Size = stored.Size
+ // Keep content type.
onlyMeta.MIMEType = stored.MIMEType
if stored.Header != "" && stored.Header != "(No Header)" {
tmpMsg, err := mail.ReadMessage(
@@ -401,6 +362,12 @@ func (store *Store) deleteMessageEvent(apiID string) error {
// deleteMessagesEvent deletes the message from metadata and all mailbox buckets.
func (store *Store) deleteMessagesEvent(apiIDs []string) error {
+ for _, messageID := range apiIDs {
+ if err := store.cache.Rem(store.UserID(), messageID); err != nil {
+ logrus.WithError(err).Error("Failed to remove message from cache")
+ }
+ }
+
return store.db.Update(func(tx *bolt.Tx) error {
for _, apiID := range apiIDs {
if err := tx.Bucket(metadataBucket).Delete([]byte(apiID)); err != nil {
diff --git a/internal/store/user_message_test.go b/internal/store/user_message_test.go
index bc606723..ca196177 100644
--- a/internal/store/user_message_test.go
+++ b/internal/store/user_message_test.go
@@ -33,7 +33,7 @@ func TestGetAllMessageIDs(t *testing.T) {
m, clear := initMocks(t)
defer clear()
- m.newStoreNoEvents(true)
+ m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
@@ -47,7 +47,7 @@ func TestGetMessageFromDB(t *testing.T) {
m, clear := initMocks(t)
defer clear()
- m.newStoreNoEvents(true)
+ m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
tests := []struct{ msgID, wantErr string }{
@@ -72,7 +72,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
m, clear := initMocks(t)
defer clear()
- m.newStoreNoEvents(true)
+ m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
msg, err := m.store.getMessageFromDB("msg1")
@@ -81,12 +81,10 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
// Check non-meta and calculated data are cleared/empty.
a.Equal(t, "", msg.Body)
a.Equal(t, []*pmapi.Attachment(nil), msg.Attachments)
- a.Equal(t, int64(-1), msg.Size)
a.Equal(t, "", msg.MIMEType)
a.Equal(t, make(mail.Header), msg.Header)
// Change the calculated data.
- wantSize := int64(42)
wantMIMEType := "plain-text"
wantHeader := mail.Header{
"Key": []string{"value"},
@@ -94,13 +92,11 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
storeMsg, err := m.store.addresses[addrID1].mailboxes[pmapi.AllMailLabel].GetMessage("msg1")
require.Nil(t, err)
- require.Nil(t, storeMsg.SetSize(wantSize))
require.Nil(t, storeMsg.SetContentTypeAndHeader(wantMIMEType, wantHeader))
// Check calculated data.
msg, err = m.store.getMessageFromDB("msg1")
require.Nil(t, err)
- a.Equal(t, wantSize, msg.Size)
a.Equal(t, wantMIMEType, msg.MIMEType)
a.Equal(t, wantHeader, msg.Header)
@@ -109,7 +105,6 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
msg, err = m.store.getMessageFromDB("msg1")
require.Nil(t, err)
- a.Equal(t, wantSize, msg.Size)
a.Equal(t, wantMIMEType, msg.MIMEType)
a.Equal(t, wantHeader, msg.Header)
}
@@ -118,7 +113,7 @@ func TestDeleteMessage(t *testing.T) {
m, clear := initMocks(t)
defer clear()
- m.newStoreNoEvents(true)
+ m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
@@ -129,8 +124,7 @@ func TestDeleteMessage(t *testing.T) {
}
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread bool, labelIDs []string) { //nolint[unparam]
- msg := getTestMessage(id, subject, sender, unread, labelIDs)
- require.Nil(t, m.store.createOrUpdateMessageEvent(msg))
+ require.Nil(t, m.store.createOrUpdateMessageEvent(getTestMessage(id, subject, sender, unread, labelIDs)))
}
func getTestMessage(id, subject, sender string, unread bool, labelIDs []string) *pmapi.Message {
@@ -142,7 +136,6 @@ func getTestMessage(id, subject, sender string, unread bool, labelIDs []string)
Sender: address,
ToList: []*mail.Address{address},
LabelIDs: labelIDs,
- Size: 12345,
Body: "body of message",
Attachments: []*pmapi.Attachment{{
ID: "attachment1",
@@ -162,7 +155,7 @@ func TestCreateDraftCheckMessageSize(t *testing.T) {
m, clear := initMocks(t)
defer clear()
- m.newStoreNoEvents(false)
+ m.newStoreNoEvents(t, false)
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
MaxUpload: 100, // Decrypted message 5 chars, encrypted 500+.
}, nil)
@@ -181,7 +174,7 @@ func TestCreateDraftCheckMessageWithAttachmentSize(t *testing.T) {
m, clear := initMocks(t)
defer clear()
- m.newStoreNoEvents(false)
+ m.newStoreNoEvents(t, false)
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
MaxUpload: 800, // Decrypted message 5 chars + 5 chars of attachment, encrypted 500+ + 300+.
}, nil)
diff --git a/internal/store/user_sync_test.go b/internal/store/user_sync_test.go
index 688f7a9a..7b8da1bc 100644
--- a/internal/store/user_sync_test.go
+++ b/internal/store/user_sync_test.go
@@ -30,7 +30,7 @@ func TestLoadSaveSyncState(t *testing.T) {
m, clear := initMocks(t)
defer clear()
- m.newStoreNoEvents(true)
+ m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
diff --git a/internal/users/user.go b/internal/users/user.go
index 8dd6921c..b3f60932 100644
--- a/internal/users/user.go
+++ b/internal/users/user.go
@@ -107,6 +107,21 @@ func (u *User) connect(client pmapi.Client, creds *credentials.Credentials) erro
return err
}
+ // If the client is already unlocked, we can unlock the store cache as well.
+ if client.IsUnlocked() {
+ kr, err := client.GetUserKeyRing()
+ if err != nil {
+ return err
+ }
+
+ if err := u.store.UnlockCache(kr); err != nil {
+ return err
+ }
+
+ // NOTE(GODT-1158): If using in-memory cache we probably shouldn't start the watcher?
+ u.store.StartWatcher()
+ }
+
return nil
}
diff --git a/internal/users/user_credentials_test.go b/internal/users/user_credentials_test.go
index accee921..df994fed 100644
--- a/internal/users/user_credentials_test.go
+++ b/internal/users/user_credentials_test.go
@@ -32,7 +32,7 @@ func TestUpdateUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
- user := testNewUser(m)
+ user := testNewUser(t, m)
defer cleanUpUserData(user)
gomock.InOrder(
@@ -50,7 +50,7 @@ func TestUserSwitchAddressMode(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
- user := testNewUser(m)
+ user := testNewUser(t, m)
defer cleanUpUserData(user)
// Ignore any sync on background.
@@ -76,7 +76,7 @@ func TestUserSwitchAddressMode(t *testing.T) {
r.False(t, user.creds.IsCombinedAddressMode)
r.False(t, user.IsCombinedAddressMode())
- // MOck change to combined mode.
+ // Mock change to combined mode.
gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"),
@@ -98,7 +98,7 @@ func TestLogoutUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
- user := testNewUser(m)
+ user := testNewUser(t, m)
defer cleanUpUserData(user)
gomock.InOrder(
@@ -115,7 +115,7 @@ func TestLogoutUserFailsLogout(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
- user := testNewUser(m)
+ user := testNewUser(t, m)
defer cleanUpUserData(user)
gomock.InOrder(
@@ -133,7 +133,7 @@ func TestCheckBridgeLogin(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
- user := testNewUser(m)
+ user := testNewUser(t, m)
defer cleanUpUserData(user)
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
@@ -144,7 +144,7 @@ func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
- user := testNewUser(m)
+ user := testNewUser(t, m)
defer cleanUpUserData(user)
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "")
@@ -187,7 +187,7 @@ func TestCheckBridgeLoginBadPassword(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
- user := testNewUser(m)
+ user := testNewUser(t, m)
defer cleanUpUserData(user)
err := user.CheckBridgeLogin("wrong!")
diff --git a/internal/users/user_new_test.go b/internal/users/user_new_test.go
index f3810c06..b3b4045b 100644
--- a/internal/users/user_new_test.go
+++ b/internal/users/user_new_test.go
@@ -64,7 +64,7 @@ func TestNewUser(t *testing.T) {
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
- mockInitConnectedUser(m)
+ mockInitConnectedUser(t, m)
mockEventLoopNoAction(m)
checkNewUserHasCredentials(m, "", testCredentials)
diff --git a/internal/users/user_store_test.go b/internal/users/user_store_test.go
index 87ee0c8e..f333dcd8 100644
--- a/internal/users/user_store_test.go
+++ b/internal/users/user_store_test.go
@@ -31,7 +31,7 @@ func TestClearStoreWithStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
- user := testNewUser(m)
+ user := testNewUser(t, m)
defer cleanUpUserData(user)
r.Nil(t, user.store.Close())
@@ -43,7 +43,7 @@ func TestClearStoreWithoutStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
- user := testNewUser(m)
+ user := testNewUser(t, m)
defer cleanUpUserData(user)
r.NotNil(t, user.store)
diff --git a/internal/users/user_test.go b/internal/users/user_test.go
index 9ced4b94..7780a1b1 100644
--- a/internal/users/user_test.go
+++ b/internal/users/user_test.go
@@ -18,13 +18,15 @@
package users
import (
+ "testing"
+
r "github.com/stretchr/testify/require"
)
// testNewUser sets up a new, authorised user.
-func testNewUser(m mocks) *User {
+func testNewUser(t *testing.T, m mocks) *User {
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
- mockInitConnectedUser(m)
+ mockInitConnectedUser(t, m)
mockEventLoopNoAction(m)
user, creds, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker)
diff --git a/internal/users/users.go b/internal/users/users.go
index 87702a1a..01056515 100644
--- a/internal/users/users.go
+++ b/internal/users/users.go
@@ -20,12 +20,13 @@ package users
import (
"context"
+ "os"
+ "path/filepath"
"strings"
"sync"
"time"
"github.com/ProtonMail/proton-bridge/internal/events"
- imapcache "github.com/ProtonMail/proton-bridge/internal/imap/cache"
"github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/listener"
@@ -225,6 +226,7 @@ func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password []by
return nil, errors.Wrap(err, "failed to update password of user in credentials store")
}
+ // will go and unlock cache if not already done
if err := user.connect(client, creds); err != nil {
return nil, errors.Wrap(err, "failed to reconnect existing user")
}
@@ -341,9 +343,6 @@ func (u *Users) ClearData() error {
result = multierror.Append(result, err)
}
- // Need to clear imap cache otherwise fetch response will be remembered from previous test.
- imapcache.Clear()
-
return result
}
@@ -366,6 +365,7 @@ func (u *Users) DeleteUser(userID string, clearStore bool) error {
if err := user.closeStore(); err != nil {
log.WithError(err).Error("Failed to close user store")
}
+
if clearStore {
// Clear cache after closing connections (done in logout).
if err := user.clearStore(); err != nil {
@@ -427,6 +427,41 @@ func (u *Users) DisallowProxy() {
u.clientManager.DisallowProxy()
}
+func (u *Users) EnableCache() error {
+ // NOTE(GODT-1158): Check for available size before enabling.
+
+ return nil
+}
+
+func (u *Users) MigrateCache(from, to string) error {
+ // NOTE(GODT-1158): Is it enough to just close the store? Do we need to force-close the cacher too?
+
+ for _, user := range u.users {
+ if err := user.closeStore(); err != nil {
+ logrus.WithError(err).Error("Failed to close user's store")
+ }
+ }
+
+ // Ensure the parent directory exists.
+ if err := os.MkdirAll(filepath.Dir(to), 0700); err != nil {
+ return err
+ }
+
+ return os.Rename(from, to)
+}
+
+func (u *Users) DisableCache() error {
+ // NOTE(GODT-1158): Is it an error if we can't remove a user's cache?
+
+ for _, user := range u.users {
+ if err := user.store.RemoveCache(); err != nil {
+ logrus.WithError(err).Error("Failed to remove user's message cache")
+ }
+ }
+
+ return nil
+}
+
// hasUser returns whether the struct currently has a user with ID `id`.
func (u *Users) hasUser(id string) (user *User, ok bool) {
for _, u := range u.users {
diff --git a/internal/users/users_login_test.go b/internal/users/users_login_test.go
index 8c24a9c6..1e321ff5 100644
--- a/internal/users/users_login_test.go
+++ b/internal/users/users_login_test.go
@@ -49,7 +49,7 @@ func TestUsersFinishLoginNewUser(t *testing.T) {
// Init users with no user from keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
- mockAddingConnectedUser(m)
+ mockAddingConnectedUser(t, m)
mockEventLoopNoAction(m)
m.clientManager.EXPECT().SendSimpleMetric(gomock.Any(), string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel))
@@ -74,7 +74,7 @@ func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
m.credentialsStore.EXPECT().UpdateToken(testCredentialsDisconnected.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil),
m.credentialsStore.EXPECT().UpdatePassword(testCredentialsDisconnected.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil),
)
- mockInitConnectedUser(m)
+ mockInitConnectedUser(t, m)
mockEventLoopNoAction(m)
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentialsDisconnected.UserID)
@@ -95,7 +95,7 @@ func TestUsersFinishLoginConnectedUser(t *testing.T) {
// Mock loading connected user.
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
- mockLoadingConnectedUser(m, testCredentials)
+ mockLoadingConnectedUser(t, m, testCredentials)
mockEventLoopNoAction(m)
// Mock process of FinishLogin of already connected user.
diff --git a/internal/users/users_new_test.go b/internal/users/users_new_test.go
index 00a7b886..c00c42b1 100644
--- a/internal/users/users_new_test.go
+++ b/internal/users/users_new_test.go
@@ -49,7 +49,7 @@ func TestNewUsersWithConnectedUser(t *testing.T) {
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
- mockLoadingConnectedUser(m, testCredentials)
+ mockLoadingConnectedUser(t, m, testCredentials)
mockEventLoopNoAction(m)
checkUsersNew(t, m, []*credentials.Credentials{testCredentials})
}
@@ -71,7 +71,7 @@ func TestNewUsersWithUsers(t *testing.T) {
m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID, testCredentials.UserID}, nil)
mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
- mockLoadingConnectedUser(m, testCredentials)
+ mockLoadingConnectedUser(t, m, testCredentials)
mockEventLoopNoAction(m)
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials})
}
diff --git a/internal/users/users_test.go b/internal/users/users_test.go
index 05077ec7..87b784d5 100644
--- a/internal/users/users_test.go
+++ b/internal/users/users_test.go
@@ -21,6 +21,7 @@ import (
"fmt"
"io/ioutil"
"os"
+ "runtime"
"runtime/debug"
"testing"
"time"
@@ -28,10 +29,13 @@ import (
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store"
+ "github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
usersmocks "github.com/ProtonMail/proton-bridge/internal/users/mocks"
+ "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
+ tests "github.com/ProtonMail/proton-bridge/test"
gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
@@ -42,9 +46,11 @@ func TestMain(m *testing.M) {
if os.Getenv("VERBOSITY") == "fatal" {
logrus.SetLevel(logrus.FatalLevel)
}
+
if os.Getenv("VERBOSITY") == "trace" {
logrus.SetLevel(logrus.TraceLevel)
}
+
os.Exit(m.Run())
}
@@ -151,7 +157,7 @@ type mocks struct {
clientManager *pmapimocks.MockManager
pmapiClient *pmapimocks.MockClient
- storeCache *store.Cache
+ storeCache *store.Events
}
func initMocks(t *testing.T) mocks {
@@ -178,7 +184,7 @@ func initMocks(t *testing.T) mocks {
clientManager: pmapimocks.NewMockManager(mockCtrl),
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
- storeCache: store.NewCache(cacheFile.Name()),
+ storeCache: store.NewEvents(cacheFile.Name()),
}
// Called during clean-up.
@@ -187,9 +193,20 @@ func initMocks(t *testing.T) mocks {
// Set up store factory.
m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) {
var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
- dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
+
+ dbFile, err := ioutil.TempFile(t.TempDir(), "bridge-store-db-*.db")
r.NoError(t, err, "could not get temporary file for store db")
- return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
+
+ return store.New(
+ sentryReporter,
+ m.PanicHandler,
+ user,
+ m.eventListener,
+ cache.NewInMemoryCache(1<<20),
+ message.NewBuilder(runtime.NumCPU(), runtime.NumCPU()),
+ dbFile.Name(),
+ m.storeCache,
+ )
}).AnyTimes()
m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()
@@ -212,8 +229,8 @@ func (fr *fullStackReporter) Fatalf(format string, args ...interface{}) {
func testNewUsersWithUsers(t *testing.T, m mocks) *Users {
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID, testCredentialsSplit.UserID}, nil)
- mockLoadingConnectedUser(m, testCredentials)
- mockLoadingConnectedUser(m, testCredentialsSplit)
+ mockLoadingConnectedUser(t, m, testCredentials)
+ mockLoadingConnectedUser(t, m, testCredentialsSplit)
mockEventLoopNoAction(m)
return testNewUsers(t, m)
@@ -245,7 +262,7 @@ func cleanUpUsersData(b *Users) {
}
}
-func mockAddingConnectedUser(m mocks) {
+func mockAddingConnectedUser(t *testing.T, m mocks) {
gomock.InOrder(
// Mock of users.FinishLogin.
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
@@ -256,10 +273,10 @@ func mockAddingConnectedUser(m mocks) {
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
)
- mockInitConnectedUser(m)
+ mockInitConnectedUser(t, m)
}
-func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) {
+func mockLoadingConnectedUser(t *testing.T, m mocks, creds *credentials.Credentials) {
authRefresh := &pmapi.AuthRefresh{
UID: "uid",
AccessToken: "acc",
@@ -273,10 +290,10 @@ func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) {
m.credentialsStore.EXPECT().UpdateToken(creds.UserID, authRefresh.UID, authRefresh.RefreshToken).Return(creds, nil),
)
- mockInitConnectedUser(m)
+ mockInitConnectedUser(t, m)
}
-func mockInitConnectedUser(m mocks) {
+func mockInitConnectedUser(t *testing.T, m mocks) {
// Mock of user initialisation.
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any())
m.pmapiClient.EXPECT().IsUnlocked().Return(true).AnyTimes()
@@ -286,6 +303,7 @@ func mockInitConnectedUser(m mocks) {
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
+ m.pmapiClient.EXPECT().GetUserKeyRing().Return(tests.MakeKeyRing(t), nil).AnyTimes(),
)
}
diff --git a/pkg/message/build.go b/pkg/message/build.go
index 643bfb39..8d20fd7c 100644
--- a/pkg/message/build.go
+++ b/pkg/message/build.go
@@ -20,10 +20,12 @@ package message
import (
"context"
"io"
+ "io/ioutil"
"sync"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
+ "github.com/ProtonMail/proton-bridge/pkg/pool"
"github.com/pkg/errors"
)
@@ -32,11 +34,15 @@ var (
ErrNoSuchKeyRing = errors.New("the keyring to decrypt this message could not be found")
)
+const (
+ BackgroundPriority = 1 << iota
+ ForegroundPriority
+)
+
type Builder struct {
- reqs chan fetchReq
- done chan struct{}
- jobs map[string]*BuildJob
- locker sync.Mutex
+ pool *pool.Pool
+ jobs map[string]*Job
+ lock sync.Mutex
}
type Fetcher interface {
@@ -48,111 +54,159 @@ type Fetcher interface {
// NewBuilder creates a new builder which manages the given number of fetch/attach/build workers.
// - fetchWorkers: the number of workers which fetch messages from API
// - attachWorkers: the number of workers which fetch attachments from API.
-// - buildWorkers: the number of workers which decrypt/build RFC822 message literals.
-//
-// NOTE: Each fetch worker spawns a unique set of attachment workers!
-// There can therefore be up to fetchWorkers*attachWorkers simultaneous API connections.
//
// The returned builder is ready to handle jobs -- see (*Builder).NewJob for more information.
//
// Call (*Builder).Done to shut down the builder and stop all workers.
-func NewBuilder(fetchWorkers, attachWorkers, buildWorkers int) *Builder {
- b := newBuilder()
+func NewBuilder(fetchWorkers, attachWorkers int) *Builder {
+ attacherPool := pool.New(attachWorkers, newAttacherWorkFunc())
- fetchReqCh, fetchResCh := startFetchWorkers(fetchWorkers, attachWorkers)
- buildReqCh, buildResCh := startBuildWorkers(buildWorkers)
+ fetcherPool := pool.New(fetchWorkers, newFetcherWorkFunc(attacherPool))
- go func() {
- defer close(fetchReqCh)
-
- for {
- select {
- case req := <-b.reqs:
- fetchReqCh <- req
-
- case <-b.done:
- return
- }
- }
- }()
-
- go func() {
- defer close(buildReqCh)
-
- for res := range fetchResCh {
- if res.err != nil {
- b.jobFailure(res.messageID, res.err)
- } else {
- buildReqCh <- res
- }
- }
- }()
-
- go func() {
- for res := range buildResCh {
- if res.err != nil {
- b.jobFailure(res.messageID, res.err)
- } else {
- b.jobSuccess(res.messageID, res.literal)
- }
- }
- }()
-
- return b
-}
-
-func newBuilder() *Builder {
return &Builder{
- reqs: make(chan fetchReq),
- done: make(chan struct{}),
- jobs: make(map[string]*BuildJob),
+ pool: fetcherPool,
+ jobs: make(map[string]*Job),
}
}
-// NewJob tells the builder to begin building the message with the given ID.
-// The result (or any error which occurred during building) can be retrieved from the returned job when available.
-func (b *Builder) NewJob(ctx context.Context, api Fetcher, messageID string) *BuildJob {
- return b.NewJobWithOptions(ctx, api, messageID, JobOptions{})
+func (builder *Builder) NewJob(ctx context.Context, fetcher Fetcher, messageID string, prio int) (*Job, pool.DoneFunc) {
+ return builder.NewJobWithOptions(ctx, fetcher, messageID, JobOptions{}, prio)
}
-// NewJobWithOptions creates a new job with custom options. See NewJob for more information.
-func (b *Builder) NewJobWithOptions(ctx context.Context, api Fetcher, messageID string, opts JobOptions) *BuildJob {
- b.locker.Lock()
- defer b.locker.Unlock()
+func (builder *Builder) NewJobWithOptions(ctx context.Context, fetcher Fetcher, messageID string, opts JobOptions, prio int) (*Job, pool.DoneFunc) {
+ builder.lock.Lock()
+ defer builder.lock.Unlock()
- if job, ok := b.jobs[messageID]; ok {
- return job
+ if job, ok := builder.jobs[messageID]; ok {
+ if job.GetPriority() < prio {
+ job.SetPriority(prio)
+ }
+
+ return job, job.done
}
- b.jobs[messageID] = newBuildJob(messageID)
+ job, done := builder.pool.NewJob(
+ &fetchReq{
+ fetcher: fetcher,
+ messageID: messageID,
+ options: opts,
+ },
+ prio,
+ )
- go func() { b.reqs <- fetchReq{ctx: ctx, api: api, messageID: messageID, opts: opts} }()
+ buildJob := &Job{
+ Job: job,
+ done: done,
+ }
- return b.jobs[messageID]
+ builder.jobs[messageID] = buildJob
+
+ return buildJob, func() {
+ builder.lock.Lock()
+ defer builder.lock.Unlock()
+
+ // Remove the job from the builder.
+ delete(builder.jobs, messageID)
+
+ // And mark it as done.
+ done()
+ }
}
-// Done shuts down the builder and stops all workers.
-func (b *Builder) Done() {
- b.locker.Lock()
- defer b.locker.Unlock()
-
- close(b.done)
+func (builder *Builder) Done() {
+ // NOTE(GODT-1158): Stop worker pool.
}
-func (b *Builder) jobSuccess(messageID string, literal []byte) {
- b.locker.Lock()
- defer b.locker.Unlock()
-
- b.jobs[messageID].postSuccess(literal)
-
- delete(b.jobs, messageID)
+type fetchReq struct {
+ fetcher Fetcher
+ messageID string
+ options JobOptions
}
-func (b *Builder) jobFailure(messageID string, err error) {
- b.locker.Lock()
- defer b.locker.Unlock()
-
- b.jobs[messageID].postFailure(err)
-
- delete(b.jobs, messageID)
+type attachReq struct {
+ fetcher Fetcher
+ message *pmapi.Message
+}
+
+type Job struct {
+ *pool.Job
+
+ done pool.DoneFunc
+}
+
+func (job *Job) GetResult() ([]byte, error) {
+ res, err := job.Job.GetResult()
+ if err != nil {
+ return nil, err
+ }
+
+ return res.([]byte), nil
+}
+
+func newAttacherWorkFunc() pool.WorkFunc {
+ return func(payload interface{}, prio int) (interface{}, error) {
+ req, ok := payload.(*attachReq)
+ if !ok {
+ panic("bad payload type")
+ }
+
+ res := make(map[string][]byte)
+
+ for _, att := range req.message.Attachments {
+ rc, err := req.fetcher.GetAttachment(context.Background(), att.ID)
+ if err != nil {
+ return nil, err
+ }
+
+ b, err := ioutil.ReadAll(rc)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := rc.Close(); err != nil {
+ return nil, err
+ }
+
+ res[att.ID] = b
+ }
+
+ return res, nil
+ }
+}
+
+func newFetcherWorkFunc(attacherPool *pool.Pool) pool.WorkFunc {
+ return func(payload interface{}, prio int) (interface{}, error) {
+ req, ok := payload.(*fetchReq)
+ if !ok {
+ panic("bad payload type")
+ }
+
+ msg, err := req.fetcher.GetMessage(context.Background(), req.messageID)
+ if err != nil {
+ return nil, err
+ }
+
+ attJob, attDone := attacherPool.NewJob(&attachReq{
+ fetcher: req.fetcher,
+ message: msg,
+ }, prio)
+ defer attDone()
+
+ val, err := attJob.GetResult()
+ if err != nil {
+ return nil, err
+ }
+
+ attData, ok := val.(map[string][]byte)
+ if !ok {
+ panic("bad response type")
+ }
+
+ kr, err := req.fetcher.KeyRingForAddressID(msg.AddressID)
+ if err != nil {
+ return nil, ErrNoSuchKeyRing
+ }
+
+ return buildRFC822(kr, msg, attData, req.options)
+ }
}
diff --git a/pkg/message/build_build.go b/pkg/message/build_build.go
deleted file mode 100644
index 596f0e7f..00000000
--- a/pkg/message/build_build.go
+++ /dev/null
@@ -1,89 +0,0 @@
-// Copyright (c) 2021 Proton Technologies AG
-//
-// This file is part of ProtonMail Bridge.
-//
-// ProtonMail Bridge is free software: you can redistribute it and/or modify
-// it under the terms of the GNU General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// ProtonMail Bridge is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License
-// along with ProtonMail Bridge. If not, see .
-
-package message
-
-import (
- "sync"
-
- "github.com/pkg/errors"
-)
-
-type buildRes struct {
- messageID string
- literal []byte
- err error
-}
-
-func newBuildResSuccess(messageID string, literal []byte) buildRes {
- return buildRes{
- messageID: messageID,
- literal: literal,
- }
-}
-
-func newBuildResFailure(messageID string, err error) buildRes {
- return buildRes{
- messageID: messageID,
- err: err,
- }
-}
-
-// startBuildWorkers starts the given number of build workers.
-// These workers decrypt and build messages into RFC822 literals.
-// Two channels are returned:
-// - buildReqCh: used to send work items to the worker pool
-// - buildResCh: used to receive work results from the worker pool
-func startBuildWorkers(buildWorkers int) (chan fetchRes, chan buildRes) {
- buildReqCh := make(chan fetchRes)
- buildResCh := make(chan buildRes)
-
- go func() {
- defer close(buildResCh)
-
- var wg sync.WaitGroup
-
- wg.Add(buildWorkers)
-
- for workerID := 0; workerID < buildWorkers; workerID++ {
- go buildWorker(buildReqCh, buildResCh, &wg)
- }
-
- wg.Wait()
- }()
-
- return buildReqCh, buildResCh
-}
-
-func buildWorker(buildReqCh <-chan fetchRes, buildResCh chan<- buildRes, wg *sync.WaitGroup) {
- defer wg.Done()
-
- for req := range buildReqCh {
- l := log.
- WithField("addrID", req.msg.AddressID).
- WithField("msgID", req.msg.ID)
- if kr, err := req.api.KeyRingForAddressID(req.msg.AddressID); err != nil {
- l.WithError(err).Warn("Cannot find keyring for address")
- buildResCh <- newBuildResFailure(req.msg.ID, errors.Wrap(ErrNoSuchKeyRing, err.Error()))
- } else if literal, err := buildRFC822(kr, req.msg, req.atts, req.opts); err != nil {
- l.WithError(err).Warn("Build failed")
- buildResCh <- newBuildResFailure(req.msg.ID, err)
- } else {
- buildResCh <- newBuildResSuccess(req.msg.ID, literal)
- }
- }
-}
diff --git a/pkg/message/build_fetch.go b/pkg/message/build_fetch.go
deleted file mode 100644
index 81a99829..00000000
--- a/pkg/message/build_fetch.go
+++ /dev/null
@@ -1,141 +0,0 @@
-// Copyright (c) 2021 Proton Technologies AG
-//
-// This file is part of ProtonMail Bridge.
-//
-// ProtonMail Bridge is free software: you can redistribute it and/or modify
-// it under the terms of the GNU General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// ProtonMail Bridge is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License
-// along with ProtonMail Bridge. If not, see .
-
-package message
-
-import (
- "context"
- "io/ioutil"
- "sync"
-
- "github.com/ProtonMail/proton-bridge/pkg/parallel"
- "github.com/ProtonMail/proton-bridge/pkg/pmapi"
-)
-
-type fetchReq struct {
- ctx context.Context
- api Fetcher
- messageID string
- opts JobOptions
-}
-
-type fetchRes struct {
- fetchReq
-
- msg *pmapi.Message
- atts [][]byte
- err error
-}
-
-func newFetchResSuccess(req fetchReq, msg *pmapi.Message, atts [][]byte) fetchRes {
- return fetchRes{
- fetchReq: req,
- msg: msg,
- atts: atts,
- }
-}
-
-func newFetchResFailure(req fetchReq, err error) fetchRes {
- return fetchRes{
- fetchReq: req,
- err: err,
- }
-}
-
-// startFetchWorkers starts the given number of fetch workers.
-// These workers download message and attachment data from API.
-// Each fetch worker will use up to the given number of attachment workers to download attachments.
-// Two channels are returned:
-// - fetchReqCh: used to send work items to the worker pool
-// - fetchResCh: used to receive work results from the worker pool
-func startFetchWorkers(fetchWorkers, attachWorkers int) (chan fetchReq, chan fetchRes) {
- fetchReqCh := make(chan fetchReq)
- fetchResCh := make(chan fetchRes)
-
- go func() {
- defer close(fetchResCh)
-
- var wg sync.WaitGroup
-
- wg.Add(fetchWorkers)
-
- for workerID := 0; workerID < fetchWorkers; workerID++ {
- go fetchWorker(fetchReqCh, fetchResCh, attachWorkers, &wg)
- }
-
- wg.Wait()
- }()
-
- return fetchReqCh, fetchResCh
-}
-
-func fetchWorker(fetchReqCh <-chan fetchReq, fetchResCh chan<- fetchRes, attachWorkers int, wg *sync.WaitGroup) {
- defer wg.Done()
-
- for req := range fetchReqCh {
- msg, atts, err := fetchMessage(req, attachWorkers)
- if err != nil {
- fetchResCh <- newFetchResFailure(req, err)
- } else {
- fetchResCh <- newFetchResSuccess(req, msg, atts)
- }
- }
-}
-
-func fetchMessage(req fetchReq, attachWorkers int) (*pmapi.Message, [][]byte, error) {
- msg, err := req.api.GetMessage(req.ctx, req.messageID)
- if err != nil {
- return nil, nil, err
- }
-
- attList := make([]interface{}, len(msg.Attachments))
-
- for i, att := range msg.Attachments {
- attList[i] = att.ID
- }
-
- process := func(value interface{}) (interface{}, error) {
- rc, err := req.api.GetAttachment(req.ctx, value.(string))
- if err != nil {
- return nil, err
- }
-
- b, err := ioutil.ReadAll(rc)
- if err != nil {
- return nil, err
- }
-
- if err := rc.Close(); err != nil {
- return nil, err
- }
-
- return b, nil
- }
-
- attData := make([][]byte, len(msg.Attachments))
-
- collect := func(idx int, value interface{}) error {
- attData[idx] = value.([]byte) //nolint[forcetypeassert] we wan't to panic here
- return nil
- }
-
- if err := parallel.RunParallel(attachWorkers, attList, process, collect); err != nil {
- return nil, nil, err
- }
-
- return msg, attData, nil
-}
diff --git a/pkg/message/build_job.go b/pkg/message/build_job.go
index b800afa7..a04c452a 100644
--- a/pkg/message/build_job.go
+++ b/pkg/message/build_job.go
@@ -25,35 +25,3 @@ type JobOptions struct {
AddMessageDate bool // Whether to include message time as X-Pm-Date.
AddMessageIDReference bool // Whether to include the MessageID in References.
}
-
-type BuildJob struct {
- messageID string
- literal []byte
- err error
-
- done chan struct{}
-}
-
-func newBuildJob(messageID string) *BuildJob {
- return &BuildJob{
- messageID: messageID,
- done: make(chan struct{}),
- }
-}
-
-// GetResult returns the build result or any error which occurred during building.
-// If the result is not ready yet, it blocks.
-func (job *BuildJob) GetResult() ([]byte, error) {
- <-job.done
- return job.literal, job.err
-}
-
-func (job *BuildJob) postSuccess(literal []byte) {
- job.literal = literal
- close(job.done)
-}
-
-func (job *BuildJob) postFailure(err error) {
- job.err = err
- close(job.done)
-}
diff --git a/pkg/message/build_rfc822.go b/pkg/message/build_rfc822.go
index 9bfcaa4d..bfcdac62 100644
--- a/pkg/message/build_rfc822.go
+++ b/pkg/message/build_rfc822.go
@@ -34,7 +34,7 @@ import (
"github.com/pkg/errors"
)
-func buildRFC822(kr *crypto.KeyRing, msg *pmapi.Message, attData [][]byte, opts JobOptions) ([]byte, error) {
+func buildRFC822(kr *crypto.KeyRing, msg *pmapi.Message, attData map[string][]byte, opts JobOptions) ([]byte, error) {
switch {
case len(msg.Attachments) > 0:
return buildMultipartRFC822(kr, msg, attData, opts)
@@ -80,7 +80,7 @@ func buildSimpleRFC822(kr *crypto.KeyRing, msg *pmapi.Message, opts JobOptions)
func buildMultipartRFC822(
kr *crypto.KeyRing,
msg *pmapi.Message,
- attData [][]byte,
+ attData map[string][]byte,
opts JobOptions,
) ([]byte, error) {
boundary := newBoundary(msg.ID)
@@ -103,13 +103,13 @@ func buildMultipartRFC822(
attachData [][]byte
)
- for i, att := range msg.Attachments {
+ for _, att := range msg.Attachments {
if att.Disposition == pmapi.DispositionInline {
inlineAtts = append(inlineAtts, att)
- inlineData = append(inlineData, attData[i])
+ inlineData = append(inlineData, attData[att.ID])
} else {
attachAtts = append(attachAtts, att)
- attachData = append(attachData, attData[i])
+ attachData = append(attachData, attData[att.ID])
}
}
diff --git a/pkg/message/build_test.go b/pkg/message/build_test.go
index aca5925a..f589a27a 100644
--- a/pkg/message/build_test.go
+++ b/pkg/message/build_test.go
@@ -37,13 +37,16 @@ func TestBuildPlainMessage(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "text/plain", "body", time.Now())
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -56,14 +59,17 @@ func TestBuildPlainMessageWithLongKey(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(1, 1)
defer b.Done()
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "text/plain", "body", time.Now())
msg.Header["ReallyVeryVeryVeryVeryVeryLongLongLongLongLongLongLongKeyThatWillHaveNotSoLongValue"] = []string{"value"}
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -77,13 +83,16 @@ func TestBuildHTMLMessage(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "text/html", "
body", time.Now())
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -96,7 +105,7 @@ func TestBuildPlainEncryptedMessage(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
body := readerToString(getFileReader("pgp-mime-body-plaintext.eml"))
@@ -104,7 +113,10 @@ func TestBuildPlainEncryptedMessage(t *testing.T) {
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC))
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -124,7 +136,7 @@ func TestBuildPlainEncryptedMessageMissingHeader(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(1, 1)
defer b.Done()
body := readerToString(getFileReader("plaintext-missing-header.eml"))
@@ -132,7 +144,10 @@ func TestBuildPlainEncryptedMessageMissingHeader(t *testing.T) {
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Now())
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -144,7 +159,7 @@ func TestBuildPlainEncryptedMessageInvalidHeader(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(1, 1)
defer b.Done()
body := readerToString(getFileReader("plaintext-invalid-header.eml"))
@@ -152,7 +167,10 @@ func TestBuildPlainEncryptedMessageInvalidHeader(t *testing.T) {
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Now())
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -164,7 +182,7 @@ func TestBuildPlainSignedEncryptedMessageMissingHeader(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(1, 1)
defer b.Done()
body := readerToString(getFileReader("plaintext-missing-header.eml"))
@@ -180,7 +198,10 @@ func TestBuildPlainSignedEncryptedMessageMissingHeader(t *testing.T) {
msg := newRawTestMessage("messageID", "addressID", "multipart/mixed", arm, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC))
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -204,7 +225,7 @@ func TestBuildPlainSignedEncryptedMessageInvalidHeader(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(1, 1)
defer b.Done()
body := readerToString(getFileReader("plaintext-invalid-header.eml"))
@@ -220,7 +241,10 @@ func TestBuildPlainSignedEncryptedMessageInvalidHeader(t *testing.T) {
msg := newRawTestMessage("messageID", "addressID", "multipart/mixed", arm, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC))
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -244,7 +268,7 @@ func TestBuildPlainEncryptedLatin2Message(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
body := readerToString(getFileReader("pgp-mime-body-plaintext-latin2.eml"))
@@ -252,7 +276,10 @@ func TestBuildPlainEncryptedLatin2Message(t *testing.T) {
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC))
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -269,7 +296,7 @@ func TestBuildHTMLEncryptedMessage(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
body := readerToString(getFileReader("pgp-mime-body-html.eml"))
@@ -277,7 +304,10 @@ func TestBuildHTMLEncryptedMessage(t *testing.T) {
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC))
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -298,7 +328,7 @@ func TestBuildPlainSignedMessage(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
body := readerToString(getFileReader("text_plain.eml"))
@@ -314,7 +344,10 @@ func TestBuildPlainSignedMessage(t *testing.T) {
msg := newRawTestMessage("messageID", "addressID", "multipart/mixed", arm, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC))
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -339,7 +372,7 @@ func TestBuildPlainSignedBase64Message(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
body := readerToString(getFileReader("text_plain_base64.eml"))
@@ -355,7 +388,10 @@ func TestBuildPlainSignedBase64Message(t *testing.T) {
msg := newRawTestMessage("messageID", "addressID", "multipart/mixed", arm, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC))
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -381,7 +417,7 @@ func TestBuildSignedPlainEncryptedMessage(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
body := readerToString(getFileReader("pgp-mime-body-signed-plaintext.eml"))
@@ -389,7 +425,10 @@ func TestBuildSignedPlainEncryptedMessage(t *testing.T) {
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC))
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -421,7 +460,7 @@ func TestBuildSignedHTMLEncryptedMessage(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
body := readerToString(getFileReader("pgp-mime-body-signed-html.eml"))
@@ -429,7 +468,10 @@ func TestBuildSignedHTMLEncryptedMessage(t *testing.T) {
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC))
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -463,7 +505,7 @@ func TestBuildSignedPlainEncryptedMessageWithPubKey(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
body := readerToString(getFileReader("pgp-mime-body-signed-plaintext-with-pubkey.eml"))
@@ -471,7 +513,10 @@ func TestBuildSignedPlainEncryptedMessageWithPubKey(t *testing.T) {
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC))
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -512,7 +557,7 @@ func TestBuildSignedHTMLEncryptedMessageWithPubKey(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
body := readerToString(getFileReader("pgp-mime-body-signed-html-with-pubkey.eml"))
@@ -520,7 +565,10 @@ func TestBuildSignedHTMLEncryptedMessageWithPubKey(t *testing.T) {
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC))
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -562,7 +610,7 @@ func TestBuildSignedMultipartAlternativeEncryptedMessageWithPubKey(t *testing.T)
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
body := readerToString(getFileReader("pgp-mime-body-signed-multipart-alternative-with-pubkey.eml"))
@@ -570,7 +618,10 @@ func TestBuildSignedMultipartAlternativeEncryptedMessageWithPubKey(t *testing.T)
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC))
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -628,7 +679,7 @@ func TestBuildSignedEmbeddedMessageRFC822EncryptedMessageWithPubKey(t *testing.T
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
body := readerToString(getFileReader("pgp-mime-body-signed-embedded-message-rfc822-with-pubkey.eml"))
@@ -636,7 +687,10 @@ func TestBuildSignedEmbeddedMessageRFC822EncryptedMessageWithPubKey(t *testing.T
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC))
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -682,14 +736,17 @@ func TestBuildHTMLMessageWithAttachment(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "text/html", "body", time.Now())
att := addTestAttachment(t, kr, msg, "attachID", "file.png", "image/png", "attachment", "attachment")
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res, 1).
@@ -709,14 +766,17 @@ func TestBuildHTMLMessageWithRFC822Attachment(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "text/html", "body", time.Now())
att := addTestAttachment(t, kr, msg, "attachID", "file.eml", "message/rfc822", "attachment", "... message/rfc822 ...")
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res, 1).
@@ -736,14 +796,17 @@ func TestBuildHTMLMessageWithInlineAttachment(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "text/html", "body", time.Now())
inl := addTestAttachment(t, kr, msg, "inlineID", "file.png", "image/png", "inline", "inline")
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res, 1).
@@ -766,7 +829,7 @@ func TestBuildHTMLMessageWithComplexAttachments(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -776,7 +839,10 @@ func TestBuildHTMLMessageWithComplexAttachments(t *testing.T) {
att0 := addTestAttachment(t, kr, msg, "attachID0", "attach0.png", "image/png", "attachment", "attach0")
att1 := addTestAttachment(t, kr, msg, "attachID1", "attach1.png", "image/png", "attachment", "attach1")
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl0, inl1, att0, att1), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl0, inl1, att0, att1), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res, 1).
@@ -820,14 +886,17 @@ func TestBuildAttachmentWithExoticFilename(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "text/html", "body", time.Now())
att := addTestAttachment(t, kr, msg, "attachID", `I řeally šhould leařn czech.png`, "image/png", "attachment", "attachment")
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
// The "name" and "filename" params should actually be RFC2047-encoded because they aren't 7-bit clean.
@@ -843,7 +912,7 @@ func TestBuildAttachmentWithLongFilename(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
veryLongName := strings.Repeat("a", 200) + ".png"
@@ -852,7 +921,10 @@ func TestBuildAttachmentWithLongFilename(t *testing.T) {
msg := newTestMessage(t, kr, "messageID", "addressID", "text/html", "body", time.Now())
att := addTestAttachment(t, kr, msg, "attachID", veryLongName, "image/png", "attachment", "attachment")
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
// NOTE: hasMaxLineLength is too high! Long filenames should be linewrapped using multipart filenames.
@@ -868,13 +940,16 @@ func TestBuildMessageDate(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "text/plain", "body", time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC))
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).expectDate(is(`Wed, 01 Jan 2020 00:00:00 +0000`))
@@ -884,7 +959,7 @@ func TestBuildMessageWithInvalidDate(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -893,21 +968,26 @@ func TestBuildMessageWithInvalidDate(t *testing.T) {
msg := newTestMessage(t, kr, "messageID", "addressID", "text/html", "body", time.Unix(-1, 0))
// Build the message as usual; the date will be before 1970.
- resRaw, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ jobRaw, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ resRaw, err := jobRaw.GetResult()
require.NoError(t, err)
+ done()
section(t, resRaw).
expectDate(is(`Wed, 31 Dec 1969 23:59:59 +0000`)).
expectHeader(`X-Original-Date`, isMissing())
// Build the message with date sanitization enabled; the date will be RFC822's birthdate.
- resFix, err := b.NewJobWithOptions(
+ jobFix, done := b.NewJobWithOptions(
context.Background(),
newTestFetcher(m, kr, msg),
msg.ID,
JobOptions{SanitizeDate: true},
- ).GetResult()
+ ForegroundPriority,
+ )
+ resFix, err := jobFix.GetResult()
require.NoError(t, err)
+ done()
section(t, resFix).
expectDate(is(`Fri, 13 Aug 1982 00:00:00 +0000`)).
@@ -918,13 +998,16 @@ func TestBuildMessageInternalID(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
msg := newTestMessage(t, kr, "messageID", "addressID", "text/plain", "body", time.Now())
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).expectHeader(`Message-Id`, is(``))
@@ -934,7 +1017,7 @@ func TestBuildMessageExternalID(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -943,7 +1026,10 @@ func TestBuildMessageExternalID(t *testing.T) {
// Set the message's external ID; this should be used preferentially to set the Message-Id header field.
msg.ExternalID = "externalID"
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).expectHeader(`Message-Id`, is(``))
@@ -953,7 +1039,7 @@ func TestBuild8BitBody(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -961,7 +1047,10 @@ func TestBuild8BitBody(t *testing.T) {
// Set an 8-bit body; the charset should be set to UTF-8.
msg := newTestMessage(t, kr, "messageID", "addressID", "text/plain", "I řeally šhould leařn czech", time.Now())
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).expectContentTypeParam(`charset`, is(`utf-8`))
@@ -971,7 +1060,7 @@ func TestBuild8BitSubject(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -980,7 +1069,10 @@ func TestBuild8BitSubject(t *testing.T) {
// Set an 8-bit subject; it should be RFC2047-encoded.
msg.Subject = `I řeally šhould leařn czech`
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -992,7 +1084,7 @@ func TestBuild8BitSender(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1004,7 +1096,10 @@ func TestBuild8BitSender(t *testing.T) {
Address: `mail@example.com`,
}
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -1016,7 +1111,7 @@ func TestBuild8BitRecipients(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1028,7 +1123,10 @@ func TestBuild8BitRecipients(t *testing.T) {
{Name: `leařn czech`, Address: `mail2@example.com`},
}
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -1040,7 +1138,7 @@ func TestBuildIncludeMessageIDReference(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1049,18 +1147,23 @@ func TestBuildIncludeMessageIDReference(t *testing.T) {
// Add references.
msg.Header["References"] = []string{""}
- res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ res, err := job.GetResult()
require.NoError(t, err)
+ done()
section(t, res).expectHeader(`References`, is(``))
- resRef, err := b.NewJobWithOptions(
+ jobRef, done := b.NewJobWithOptions(
context.Background(),
newTestFetcher(m, kr, msg),
msg.ID,
JobOptions{AddMessageIDReference: true},
- ).GetResult()
+ ForegroundPriority,
+ )
+ resRef, err := jobRef.GetResult()
require.NoError(t, err)
+ done()
section(t, resRef).expectHeader(`References`, is(` `))
}
@@ -1069,7 +1172,7 @@ func TestBuildMessageIsDeterministic(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1077,11 +1180,15 @@ func TestBuildMessageIsDeterministic(t *testing.T) {
inl := addTestAttachment(t, kr, msg, "inlineID", "file.png", "image/png", "inline", "inline")
att := addTestAttachment(t, kr, msg, "attachID", "attach.png", "image/png", "attachment", "attachment")
- res1, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl, att), msg.ID).GetResult()
+ job1, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl, att), msg.ID, ForegroundPriority)
+ res1, err := job1.GetResult()
require.NoError(t, err)
+ done()
- res2, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl, att), msg.ID).GetResult()
+ job2, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl, att), msg.ID, ForegroundPriority)
+ res2, err := job2.GetResult()
require.NoError(t, err)
+ done()
assert.Equal(t, res1, res2)
}
@@ -1090,15 +1197,18 @@ func TestBuildParallel(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(2, 2, 2)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
msg1 := newTestMessage(t, kr, "messageID1", "addressID", "text/plain", "body1", time.Now())
msg2 := newTestMessage(t, kr, "messageID2", "addressID", "text/plain", "body2", time.Now())
- job1 := b.NewJob(context.Background(), newTestFetcher(m, kr, msg1), msg1.ID)
- job2 := b.NewJob(context.Background(), newTestFetcher(m, kr, msg2), msg2.ID)
+ job1, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg1), msg1.ID, ForegroundPriority)
+ defer done()
+
+ job2, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg2), msg2.ID, ForegroundPriority)
+ defer done()
res1, err := job1.GetResult()
require.NoError(t, err)
@@ -1115,7 +1225,7 @@ func TestBuildParallelSameMessage(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(2, 2, 2)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1123,8 +1233,12 @@ func TestBuildParallelSameMessage(t *testing.T) {
// Jobs for the same messageID are shared so fetcher is only called once.
fetcher := newTestFetcher(m, kr, msg)
- job1 := b.NewJob(context.Background(), fetcher, msg.ID)
- job2 := b.NewJob(context.Background(), fetcher, msg.ID)
+
+ job1, done := b.NewJob(context.Background(), fetcher, msg.ID, ForegroundPriority)
+ defer done()
+
+ job2, done := b.NewJob(context.Background(), fetcher, msg.ID, ForegroundPriority)
+ defer done()
res1, err := job1.GetResult()
require.NoError(t, err)
@@ -1141,7 +1255,7 @@ func TestBuildUndecryptableMessage(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1149,7 +1263,10 @@ func TestBuildUndecryptableMessage(t *testing.T) {
// Use a different keyring for encrypting the message; it won't be decryptable.
msg := newTestMessage(t, tests.MakeKeyRing(t), "messageID", "addressID", "text/plain", "body", time.Now())
- _, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority)
+ defer done()
+
+ _, err := job.GetResult()
assert.True(t, errors.Is(err, ErrDecryptionFailed))
}
@@ -1157,7 +1274,7 @@ func TestBuildUndecryptableAttachment(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1166,7 +1283,10 @@ func TestBuildUndecryptableAttachment(t *testing.T) {
// Use a different keyring for encrypting the attachment; it won't be decryptable.
att := addTestAttachment(t, tests.MakeKeyRing(t), msg, "attachID", "file.png", "image/png", "attachment", "attachment")
- _, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID, ForegroundPriority)
+ defer done()
+
+ _, err := job.GetResult()
assert.True(t, errors.Is(err, ErrDecryptionFailed))
}
@@ -1174,7 +1294,7 @@ func TestBuildCustomMessagePlain(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1184,12 +1304,16 @@ func TestBuildCustomMessagePlain(t *testing.T) {
msg := newTestMessage(t, foreignKR, "messageID", "addressID", "text/plain", "body", time.Now())
// Tell the job to ignore decryption errors; a custom message will be returned instead of an error.
- res, err := b.NewJobWithOptions(
+ job, done := b.NewJobWithOptions(
context.Background(),
newTestFetcher(m, kr, msg),
msg.ID,
JobOptions{IgnoreDecryptionErrors: true},
- ).GetResult()
+ ForegroundPriority,
+ )
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -1206,7 +1330,7 @@ func TestBuildCustomMessageHTML(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1216,12 +1340,16 @@ func TestBuildCustomMessageHTML(t *testing.T) {
msg := newTestMessage(t, foreignKR, "messageID", "addressID", "text/html", "body", time.Now())
// Tell the job to ignore decryption errors; a custom message will be returned instead of an error.
- res, err := b.NewJobWithOptions(
+ job, done := b.NewJobWithOptions(
context.Background(),
newTestFetcher(m, kr, msg),
msg.ID,
JobOptions{IgnoreDecryptionErrors: true},
- ).GetResult()
+ ForegroundPriority,
+ )
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -1238,7 +1366,7 @@ func TestBuildCustomMessageEncrypted(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1252,12 +1380,16 @@ func TestBuildCustomMessageEncrypted(t *testing.T) {
msg.Subject = "this is a subject to make sure we preserve subject"
// Tell the job to ignore decryption errors; a custom message will be returned instead of an error.
- res, err := b.NewJobWithOptions(
+ job, done := b.NewJobWithOptions(
context.Background(),
newTestFetcher(m, kr, msg),
msg.ID,
JobOptions{IgnoreDecryptionErrors: true},
- ).GetResult()
+ ForegroundPriority,
+ )
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -1283,7 +1415,7 @@ func TestBuildCustomMessagePlainWithAttachment(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1294,12 +1426,16 @@ func TestBuildCustomMessagePlainWithAttachment(t *testing.T) {
att := addTestAttachment(t, foreignKR, msg, "attachID", "file.png", "image/png", "attachment", "attachment")
// Tell the job to ignore decryption errors; a custom message will be returned instead of an error.
- res, err := b.NewJobWithOptions(
+ job, done := b.NewJobWithOptions(
context.Background(),
newTestFetcher(m, kr, msg, att),
msg.ID,
JobOptions{IgnoreDecryptionErrors: true},
- ).GetResult()
+ ForegroundPriority,
+ )
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -1324,7 +1460,7 @@ func TestBuildCustomMessageHTMLWithAttachment(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1335,12 +1471,16 @@ func TestBuildCustomMessageHTMLWithAttachment(t *testing.T) {
att := addTestAttachment(t, foreignKR, msg, "attachID", "file.png", "image/png", "attachment", "attachment")
// Tell the job to ignore decryption errors; a custom message will be returned instead of an error.
- res, err := b.NewJobWithOptions(
+ job, done := b.NewJobWithOptions(
context.Background(),
newTestFetcher(m, kr, msg, att),
msg.ID,
JobOptions{IgnoreDecryptionErrors: true},
- ).GetResult()
+ ForegroundPriority,
+ )
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -1365,7 +1505,7 @@ func TestBuildCustomMessageOnlyBodyIsUndecryptable(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1378,12 +1518,16 @@ func TestBuildCustomMessageOnlyBodyIsUndecryptable(t *testing.T) {
att := addTestAttachment(t, kr, msg, "attachID", "file.png", "image/png", "attachment", "attachment")
// Tell the job to ignore decryption errors; a custom message will be returned instead of an error.
- res, err := b.NewJobWithOptions(
+ job, done := b.NewJobWithOptions(
context.Background(),
newTestFetcher(m, kr, msg, att),
msg.ID,
JobOptions{IgnoreDecryptionErrors: true},
- ).GetResult()
+ ForegroundPriority,
+ )
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -1407,7 +1551,7 @@ func TestBuildCustomMessageOnlyAttachmentIsUndecryptable(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
// Use the original keyring for encrypting the message; it should decrypt fine.
@@ -1419,12 +1563,16 @@ func TestBuildCustomMessageOnlyAttachmentIsUndecryptable(t *testing.T) {
att := addTestAttachment(t, foreignKR, msg, "attachID", "file.png", "image/png", "attachment", "attachment")
// Tell the job to ignore decryption errors; a custom message will be returned instead of an error.
- res, err := b.NewJobWithOptions(
+ job, done := b.NewJobWithOptions(
context.Background(),
newTestFetcher(m, kr, msg, att),
msg.ID,
JobOptions{IgnoreDecryptionErrors: true},
- ).GetResult()
+ ForegroundPriority,
+ )
+ defer done()
+
+ res, err := job.GetResult()
require.NoError(t, err)
section(t, res).
@@ -1448,7 +1596,7 @@ func TestBuildFetchMessageFail(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1459,7 +1607,10 @@ func TestBuildFetchMessageFail(t *testing.T) {
f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(nil, errors.New("oops"))
// The job should fail, returning an error and a nil result.
- res, err := b.NewJob(context.Background(), f, msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), f, msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
assert.Error(t, err)
assert.Nil(t, res)
}
@@ -1468,7 +1619,7 @@ func TestBuildFetchAttachmentFail(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1481,7 +1632,10 @@ func TestBuildFetchAttachmentFail(t *testing.T) {
f.EXPECT().GetAttachment(gomock.Any(), msg.Attachments[0].ID).Return(nil, errors.New("oops"))
// The job should fail, returning an error and a nil result.
- res, err := b.NewJob(context.Background(), f, msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), f, msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
assert.Error(t, err)
assert.Nil(t, res)
}
@@ -1490,7 +1644,7 @@ func TestBuildNoSuchKeyRing(t *testing.T) {
m := gomock.NewController(t)
defer m.Finish()
- b := NewBuilder(1, 1, 1)
+ b := NewBuilder(2, 2)
defer b.Done()
kr := tests.MakeKeyRing(t)
@@ -1501,7 +1655,10 @@ func TestBuildNoSuchKeyRing(t *testing.T) {
f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil)
f.EXPECT().KeyRingForAddressID(msg.AddressID).Return(nil, errors.New("oops"))
- res, err := b.NewJob(context.Background(), f, msg.ID).GetResult()
+ job, done := b.NewJob(context.Background(), f, msg.ID, ForegroundPriority)
+ defer done()
+
+ res, err := job.GetResult()
assert.Error(t, err)
assert.Nil(t, res)
diff --git a/pkg/message/section.go b/pkg/message/section.go
index 9154fa4c..ee62162c 100644
--- a/pkg/message/section.go
+++ b/pkg/message/section.go
@@ -38,9 +38,10 @@ type BodyStructure map[string]*SectionInfo
// SectionInfo is used to hold data about parts of each section.
type SectionInfo struct {
- Header textproto.MIMEHeader
+ Header []byte
Start, BSize, Size, Lines int
reader io.Reader
+ isHeaderReadFinished bool
}
// Read will also count the final size of section.
@@ -48,9 +49,38 @@ func (si *SectionInfo) Read(p []byte) (n int, err error) {
n, err = si.reader.Read(p)
si.Size += n
si.Lines += bytes.Count(p, []byte("\n"))
+
+ si.readHeader(p)
return
}
+// readHeader appends read data to Header until empty line is found.
+func (si *SectionInfo) readHeader(p []byte) {
+ if si.isHeaderReadFinished {
+ return
+ }
+
+ si.Header = append(si.Header, p...)
+
+ if i := bytes.Index(si.Header, []byte("\n\r\n")); i > 0 {
+ si.Header = si.Header[:i+3]
+ si.isHeaderReadFinished = true
+ return
+ }
+
+ // textproto works also with simple line ending so we should be liberal
+ // as well.
+ if i := bytes.Index(si.Header, []byte("\n\n")); i > 0 {
+ si.Header = si.Header[:i+2]
+ si.isHeaderReadFinished = true
+ }
+}
+
+// GetMIMEHeader parses bytes and return MIME header.
+func (si *SectionInfo) GetMIMEHeader() (textproto.MIMEHeader, error) {
+ return textproto.NewReader(bufio.NewReader(bytes.NewReader(si.Header))).ReadMIMEHeader()
+}
+
func NewBodyStructure(reader io.Reader) (structure *BodyStructure, err error) {
structure = &BodyStructure{}
err = structure.Parse(reader)
@@ -93,14 +123,15 @@ func (bs *BodyStructure) parseAllChildSections(r io.Reader, currentPath []int, s
bufInfo := bufio.NewReader(info)
tp := textproto.NewReader(bufInfo)
- if info.Header, err = tp.ReadMIMEHeader(); err != nil {
+ tpHeader, err := tp.ReadMIMEHeader()
+ if err != nil {
return
}
bodyInfo := &SectionInfo{reader: tp.R}
bodyReader := bufio.NewReader(bodyInfo)
- mediaType, params, _ := pmmime.ParseMediaType(info.Header.Get("Content-Type"))
+ mediaType, params, _ := pmmime.ParseMediaType(tpHeader.Get("Content-Type"))
// If multipart, call getAllParts, else read to count lines.
if (strings.HasPrefix(mediaType, "multipart/") || mediaType == rfc822Message) && params["boundary"] != "" {
@@ -260,9 +291,9 @@ func (bs *BodyStructure) GetMailHeader() (header textproto.MIMEHeader, err error
}
// GetMailHeaderBytes returns the bytes with main mail header.
-// Warning: It can contain extra lines or multipart comment.
-func (bs *BodyStructure) GetMailHeaderBytes(wholeMail io.ReadSeeker) (header []byte, err error) {
- return bs.GetSectionHeaderBytes(wholeMail, []int{})
+// Warning: It can contain extra lines.
+func (bs *BodyStructure) GetMailHeaderBytes() (header []byte, err error) {
+ return bs.GetSectionHeaderBytes([]int{})
}
func goToOffsetAndReadNBytes(wholeMail io.ReadSeeker, offset, length int) ([]byte, error) {
@@ -283,22 +314,21 @@ func goToOffsetAndReadNBytes(wholeMail io.ReadSeeker, offset, length int) ([]byt
}
// GetSectionHeader returns the mime header of specified section.
-func (bs *BodyStructure) GetSectionHeader(sectionPath []int) (header textproto.MIMEHeader, err error) {
+func (bs *BodyStructure) GetSectionHeader(sectionPath []int) (textproto.MIMEHeader, error) {
info, err := bs.getInfoCheckSection(sectionPath)
if err != nil {
- return
+ return nil, err
}
- header = info.Header
- return
+ return info.GetMIMEHeader()
}
-func (bs *BodyStructure) GetSectionHeaderBytes(wholeMail io.ReadSeeker, sectionPath []int) (header []byte, err error) {
+// GetSectionHeaderBytes returns raw header bytes of specified section.
+func (bs *BodyStructure) GetSectionHeaderBytes(sectionPath []int) ([]byte, error) {
info, err := bs.getInfoCheckSection(sectionPath)
if err != nil {
- return
+ return nil, err
}
- headerLength := info.Size - info.BSize
- return goToOffsetAndReadNBytes(wholeMail, info.Start, headerLength)
+ return info.Header, nil
}
// IMAPBodyStructure will prepare imap bodystructure recurently for given part.
@@ -309,7 +339,12 @@ func (bs *BodyStructure) IMAPBodyStructure(currentPart []int) (imapBS *imap.Body
return
}
- mediaType, params, _ := pmmime.ParseMediaType(info.Header.Get("Content-Type"))
+ tpHeader, err := info.GetMIMEHeader()
+ if err != nil {
+ return
+ }
+
+ mediaType, params, _ := pmmime.ParseMediaType(tpHeader.Get("Content-Type"))
mediaTypeSep := strings.Split(mediaType, "/")
@@ -324,19 +359,19 @@ func (bs *BodyStructure) IMAPBodyStructure(currentPart []int) (imapBS *imap.Body
Lines: uint32(info.Lines),
}
- if val := info.Header.Get("Content-ID"); val != "" {
+ if val := tpHeader.Get("Content-ID"); val != "" {
imapBS.Id = val
}
- if val := info.Header.Get("Content-Transfer-Encoding"); val != "" {
+ if val := tpHeader.Get("Content-Transfer-Encoding"); val != "" {
imapBS.Encoding = val
}
- if val := info.Header.Get("Content-Description"); val != "" {
+ if val := tpHeader.Get("Content-Description"); val != "" {
imapBS.Description = val
}
- if val := info.Header.Get("Content-Disposition"); val != "" {
+ if val := tpHeader.Get("Content-Disposition"); val != "" {
imapBS.Disposition = val
}
diff --git a/pkg/message/section_test.go b/pkg/message/section_test.go
index b9ba8a39..16f27564 100644
--- a/pkg/message/section_test.go
+++ b/pkg/message/section_test.go
@@ -21,7 +21,6 @@ import (
"bytes"
"fmt"
"io/ioutil"
- "net/textproto"
"path/filepath"
"runtime"
"sort"
@@ -71,7 +70,9 @@ func TestParseBodyStructure(t *testing.T) {
debug("%10s: %-50s %5s %5s %5s %5s", "section", "type", "start", "size", "bsize", "lines")
for _, path := range paths {
sec := (*bs)[path]
- contentType := (*bs)[path].Header.Get("Content-Type")
+ header, err := sec.GetMIMEHeader()
+ require.NoError(t, err)
+ contentType := header.Get("Content-Type")
debug("%10s: %-50s %5d %5d %5d %5d", path, contentType, sec.Start, sec.Size, sec.BSize, sec.Lines)
require.Equal(t, expectedStructure[path], contentType)
}
@@ -100,7 +101,9 @@ func TestParseBodyStructurePGP(t *testing.T) {
haveStructure := map[string]string{}
for path := range *bs {
- haveStructure[path] = (*bs)[path].Header.Get("Content-Type")
+ header, err := (*bs)[path].GetMIMEHeader()
+ require.NoError(t, err)
+ haveStructure[path] = header.Get("Content-Type")
}
require.Equal(t, expectedStructure, haveStructure)
@@ -192,7 +195,7 @@ Content-Type: plain/text
r.NoError(err, debug(wantPath, info, haveBody))
r.Equal(wantBody, string(haveBody), debug(wantPath, info, haveBody))
- haveHeader, err := bs.GetSectionHeaderBytes(strings.NewReader(wantMail), wantPath)
+ haveHeader, err := bs.GetSectionHeaderBytes(wantPath)
r.NoError(err, debug(wantPath, info, haveHeader))
r.Equal(wantHeader, string(haveHeader), debug(wantPath, info, haveHeader))
}
@@ -211,7 +214,7 @@ Content-Type: multipart/mixed; boundary="0000MAIN"
bs, err := NewBodyStructure(structReader)
require.NoError(t, err)
- haveHeader, err := bs.GetMailHeaderBytes(strings.NewReader(sampleMail))
+ haveHeader, err := bs.GetMailHeaderBytes()
require.NoError(t, err)
require.Equal(t, wantHeader, haveHeader)
}
@@ -533,18 +536,14 @@ func TestBodyStructureSerialize(t *testing.T) {
r := require.New(t)
want := &BodyStructure{
"1": {
- Header: textproto.MIMEHeader{
- "Content": []string{"type"},
- },
- Start: 1,
- Size: 2,
- BSize: 3,
- Lines: 4,
+ Header: []byte("Content: type"),
+ Start: 1,
+ Size: 2,
+ BSize: 3,
+ Lines: 4,
},
"1.1.1": {
- Header: textproto.MIMEHeader{
- "X-Pm-Key": []string{"id"},
- },
+ Header: []byte("X-Pm-Key: id"),
Start: 11,
Size: 12,
BSize: 13,
@@ -562,3 +561,32 @@ func TestBodyStructureSerialize(t *testing.T) {
(*want)["1.1.1"].reader = nil
r.Equal(want, have)
}
+
+func TestSectionInfoReadHeader(t *testing.T) {
+ r := require.New(t)
+
+ testData := []struct {
+ wantHeader, mail string
+ }{
+ {
+ "key1: val1\nkey2: val2\n\n",
+ "key1: val1\nkey2: val2\n\nbody is here\n\nand it is not confused",
+ },
+ {
+ "key1:\n val1\n\n",
+ "key1:\n val1\n\nbody is here",
+ },
+ {
+ "key1: val1\r\nkey2: val2\r\n\r\n",
+ "key1: val1\r\nkey2: val2\r\n\r\nbody is here\r\n\r\nand it is not confused",
+ },
+ }
+
+ for _, td := range testData {
+ bs, err := NewBodyStructure(strings.NewReader(td.mail))
+ r.NoError(err, "case %q", td.mail)
+ haveHeader, err := bs.GetMailHeaderBytes()
+ r.NoError(err, "case %q", td.mail)
+ r.Equal(td.wantHeader, string(haveHeader), "case %q", td.mail)
+ }
+}
diff --git a/pkg/pchan/pchan.go b/pkg/pchan/pchan.go
new file mode 100644
index 00000000..b353a6a6
--- /dev/null
+++ b/pkg/pchan/pchan.go
@@ -0,0 +1,131 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package pchan
+
+import (
+ "sort"
+ "sync"
+)
+
+type PChan struct {
+ lock sync.Mutex
+ items []*Item
+ ready, done chan struct{}
+}
+
+type Item struct {
+ ch *PChan
+ val interface{}
+ prio int
+ done chan struct{}
+}
+
+func (item *Item) Wait() {
+ <-item.done
+}
+
+func (item *Item) GetPriority() int {
+ item.ch.lock.Lock()
+ defer item.ch.lock.Unlock()
+
+ return item.prio
+}
+
+func (item *Item) SetPriority(priority int) {
+ item.ch.lock.Lock()
+ defer item.ch.lock.Unlock()
+
+ item.prio = priority
+
+ sort.Slice(item.ch.items, func(i, j int) bool {
+ return item.ch.items[i].prio < item.ch.items[j].prio
+ })
+}
+
+func New() *PChan {
+ return &PChan{
+ ready: make(chan struct{}),
+ done: make(chan struct{}),
+ }
+}
+
+func (ch *PChan) Push(val interface{}, prio int) *Item {
+ defer ch.notify()
+
+ return ch.push(val, prio)
+}
+
+func (ch *PChan) Pop() (interface{}, int, bool) {
+ select {
+ case <-ch.ready:
+ val, prio := ch.pop()
+ return val, prio, true
+
+ case <-ch.done:
+ return nil, 0, false
+ }
+}
+
+func (ch *PChan) Close() {
+ select {
+ case <-ch.done:
+ return
+
+ default:
+ close(ch.done)
+ }
+}
+
+func (ch *PChan) push(val interface{}, prio int) *Item {
+ ch.lock.Lock()
+ defer ch.lock.Unlock()
+
+ done := make(chan struct{})
+
+ item := &Item{
+ ch: ch,
+ val: val,
+ prio: prio,
+ done: done,
+ }
+
+ ch.items = append(ch.items, item)
+
+ return item
+}
+
+func (ch *PChan) pop() (interface{}, int) {
+ ch.lock.Lock()
+ defer ch.lock.Unlock()
+
+ sort.Slice(ch.items, func(i, j int) bool {
+ return ch.items[i].prio < ch.items[j].prio
+ })
+
+ var item *Item
+
+ item, ch.items = ch.items[len(ch.items)-1], ch.items[:len(ch.items)-1]
+
+ defer close(item.done)
+
+ return item.val, item.prio
+}
+
+func (ch *PChan) notify() {
+ go func() { ch.ready <- struct{}{} }()
+}
diff --git a/pkg/pchan/pchan_test.go b/pkg/pchan/pchan_test.go
new file mode 100644
index 00000000..fdd323de
--- /dev/null
+++ b/pkg/pchan/pchan_test.go
@@ -0,0 +1,123 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package pchan
+
+import (
+ "sort"
+ "sync"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestPChanConcurrentPush(t *testing.T) {
+ ch := New()
+
+ var wg sync.WaitGroup
+
+ // We are going to test with 5 additional goroutines.
+ wg.Add(5)
+
+ // Start 5 concurrent pushes.
+ go func() { defer wg.Done(); ch.Push(1, 1) }()
+ go func() { defer wg.Done(); ch.Push(2, 2) }()
+ go func() { defer wg.Done(); ch.Push(3, 3) }()
+ go func() { defer wg.Done(); ch.Push(4, 4) }()
+ go func() { defer wg.Done(); ch.Push(5, 5) }()
+
+ // Wait for the items to be pushed.
+ wg.Wait()
+
+ // All 5 should now be ready for popping.
+ require.Len(t, ch.items, 5)
+
+ // They should be popped in priority order.
+ assert.Equal(t, 5, getValue(t, ch))
+ assert.Equal(t, 4, getValue(t, ch))
+ assert.Equal(t, 3, getValue(t, ch))
+ assert.Equal(t, 2, getValue(t, ch))
+ assert.Equal(t, 1, getValue(t, ch))
+}
+
+func TestPChanConcurrentPop(t *testing.T) {
+ ch := New()
+
+ var wg sync.WaitGroup
+
+ // We are going to test with 5 additional goroutines.
+ wg.Add(5)
+
+ // Make a list to store the results in.
+ var res list
+
+ // Start 5 concurrent pops; these consume any items pushed.
+ go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
+ go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
+ go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
+ go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
+ go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
+
+ // Push and block; items should be popped immediately by the waiting goroutines.
+ ch.Push(1, 1).Wait()
+ ch.Push(2, 2).Wait()
+ ch.Push(3, 3).Wait()
+ ch.Push(4, 4).Wait()
+ ch.Push(5, 5).Wait()
+
+ // Wait for all items to be popped then close the result channel.
+ wg.Wait()
+
+ assert.True(t, sort.IntsAreSorted(res.items))
+}
+
+func TestPChanClose(t *testing.T) {
+ ch := New()
+
+ go ch.Push(1, 1)
+
+ valOpen, _, okOpen := ch.Pop()
+ assert.True(t, okOpen)
+ assert.Equal(t, 1, valOpen)
+
+ ch.Close()
+
+ valClose, _, okClose := ch.Pop()
+ assert.False(t, okClose)
+ assert.Nil(t, valClose)
+}
+
+type list struct {
+ items []int
+ mut sync.Mutex
+}
+
+func (l *list) append(val int) {
+ l.mut.Lock()
+ defer l.mut.Unlock()
+
+ l.items = append(l.items, val)
+}
+
+func getValue(t *testing.T, ch *PChan) int {
+ val, _, ok := ch.Pop()
+
+ assert.True(t, ok)
+
+ return val.(int)
+}
diff --git a/pkg/pmapi/client_types.go b/pkg/pmapi/client_types.go
index bf87bd60..56e9afb6 100644
--- a/pkg/pmapi/client_types.go
+++ b/pkg/pmapi/client_types.go
@@ -71,6 +71,7 @@ type Client interface {
GetAttachment(ctx context.Context, id string) (att io.ReadCloser, err error)
CreateAttachment(ctx context.Context, att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error)
+ GetUserKeyRing() (*crypto.KeyRing, error)
KeyRingForAddressID(string) (kr *crypto.KeyRing, err error)
GetPublicKeysForEmail(context.Context, string) ([]PublicKey, bool, error)
}
diff --git a/pkg/pmapi/messages.go b/pkg/pmapi/messages.go
index 498213a3..10c19c14 100644
--- a/pkg/pmapi/messages.go
+++ b/pkg/pmapi/messages.go
@@ -175,7 +175,6 @@ type Message struct {
CCList []*mail.Address
BCCList []*mail.Address
Time int64 // Unix time
- Size int64
NumAttachments int
ExpirationTime int64 // Unix time
SpamScore int
diff --git a/pkg/pmapi/mocks/mocks.go b/pkg/pmapi/mocks/mocks.go
index 03085507..07b4b767 100644
--- a/pkg/pmapi/mocks/mocks.go
+++ b/pkg/pmapi/mocks/mocks.go
@@ -362,6 +362,21 @@ func (mr *MockClientMockRecorder) GetPublicKeysForEmail(arg0, arg1 interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicKeysForEmail", reflect.TypeOf((*MockClient)(nil).GetPublicKeysForEmail), arg0, arg1)
}
+// GetUserKeyRing mocks base method.
+func (m *MockClient) GetUserKeyRing() (*crypto.KeyRing, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetUserKeyRing")
+ ret0, _ := ret[0].(*crypto.KeyRing)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetUserKeyRing indicates an expected call of GetUserKeyRing.
+func (mr *MockClientMockRecorder) GetUserKeyRing() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserKeyRing", reflect.TypeOf((*MockClient)(nil).GetUserKeyRing))
+}
+
// Import mocks base method.
func (m *MockClient) Import(arg0 context.Context, arg1 pmapi.ImportMsgReqs) ([]*pmapi.ImportMsgRes, error) {
m.ctrl.T.Helper()
diff --git a/pkg/pmapi/users.go b/pkg/pmapi/users.go
index d272327c..f60fd357 100644
--- a/pkg/pmapi/users.go
+++ b/pkg/pmapi/users.go
@@ -20,6 +20,7 @@ package pmapi
import (
"context"
+ "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/getsentry/sentry-go"
"github.com/go-resty/resty/v2"
"github.com/pkg/errors"
@@ -138,3 +139,12 @@ func (c *client) CurrentUser(ctx context.Context) (*User, error) {
return c.UpdateUser(ctx)
}
+
+// CurrentUser returns currently active user or user will be updated.
+func (c *client) GetUserKeyRing() (*crypto.KeyRing, error) {
+ if c.userKeyRing == nil {
+ return nil, errors.New("user keyring is not available")
+ }
+
+ return c.userKeyRing, nil
+}
diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go
new file mode 100644
index 00000000..2b67fe54
--- /dev/null
+++ b/pkg/pool/pool.go
@@ -0,0 +1,129 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package pool
+
+import "github.com/ProtonMail/proton-bridge/pkg/pchan"
+
+type WorkFunc func(interface{}, int) (interface{}, error)
+
+type DoneFunc func()
+
+type Pool struct {
+ jobCh *pchan.PChan
+}
+
+func New(size int, work WorkFunc) *Pool {
+ jobCh := pchan.New()
+
+ for i := 0; i < size; i++ {
+ go func() {
+ for {
+ val, prio, ok := jobCh.Pop()
+ if !ok {
+ return
+ }
+
+ job, ok := val.(*Job)
+ if !ok {
+ panic("bad result type")
+ }
+
+ res, err := work(job.req, prio)
+ if err != nil {
+ job.postFailure(err)
+ } else {
+ job.postSuccess(res)
+ }
+
+ job.waitDone()
+ }
+ }()
+ }
+
+ return &Pool{jobCh: jobCh}
+}
+
+func (pool *Pool) NewJob(req interface{}, prio int) (*Job, DoneFunc) {
+ job := newJob(req)
+
+ job.setItem(pool.jobCh.Push(job, prio))
+
+ return job, job.markDone
+}
+
+type Job struct {
+ req interface{}
+ res interface{}
+ err error
+
+ item *pchan.Item
+
+ ready, done chan struct{}
+}
+
+func newJob(req interface{}) *Job {
+ return &Job{
+ req: req,
+ ready: make(chan struct{}),
+ done: make(chan struct{}),
+ }
+}
+
+func (job *Job) GetResult() (interface{}, error) {
+ <-job.ready
+
+ return job.res, job.err
+}
+
+func (job *Job) GetPriority() int {
+ return job.item.GetPriority()
+}
+
+func (job *Job) SetPriority(prio int) {
+ job.item.SetPriority(prio)
+}
+
+func (job *Job) postSuccess(res interface{}) {
+ defer close(job.ready)
+
+ job.res = res
+}
+
+func (job *Job) postFailure(err error) {
+ defer close(job.ready)
+
+ job.err = err
+}
+
+func (job *Job) setItem(item *pchan.Item) {
+ job.item = item
+}
+
+func (job *Job) markDone() {
+ select {
+ case <-job.done:
+ return
+
+ default:
+ close(job.done)
+ }
+}
+
+func (job *Job) waitDone() {
+ <-job.done
+}
diff --git a/pkg/pool/pool_test.go b/pkg/pool/pool_test.go
new file mode 100644
index 00000000..6c38b329
--- /dev/null
+++ b/pkg/pool/pool_test.go
@@ -0,0 +1,45 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package pool_test
+
+import (
+ "testing"
+
+ "github.com/ProtonMail/proton-bridge/pkg/pool"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestPool(t *testing.T) {
+ pool := pool.New(2, func(req interface{}, prio int) (interface{}, error) { return req, nil })
+
+ job1, done1 := pool.NewJob("echo", 1)
+ defer done1()
+
+ job2, done2 := pool.NewJob("this", 1)
+ defer done2()
+
+ res2, err := job2.GetResult()
+ require.NoError(t, err)
+
+ res1, err := job1.GetResult()
+ require.NoError(t, err)
+
+ assert.Equal(t, "echo", res1)
+ assert.Equal(t, "this", res2)
+}
diff --git a/pkg/semaphore/semaphore.go b/pkg/semaphore/semaphore.go
new file mode 100644
index 00000000..592e45a4
--- /dev/null
+++ b/pkg/semaphore/semaphore.go
@@ -0,0 +1,53 @@
+// Copyright (c) 2021 Proton Technologies AG
+//
+// This file is part of ProtonMail Bridge.
+//
+// ProtonMail Bridge is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// ProtonMail Bridge is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with ProtonMail Bridge. If not, see .
+
+package semaphore
+
+import "sync"
+
+type Semaphore struct {
+ ch chan struct{}
+ wg sync.WaitGroup
+}
+
+func New(max int) Semaphore {
+ return Semaphore{ch: make(chan struct{}, max)}
+}
+
+func (sem *Semaphore) Lock() {
+ sem.ch <- struct{}{}
+}
+
+func (sem *Semaphore) Unlock() {
+ <-sem.ch
+}
+
+func (sem *Semaphore) Go(fn func()) {
+ sem.Lock()
+ sem.wg.Add(1)
+
+ go func() {
+ defer sem.Unlock()
+ defer sem.wg.Done()
+
+ fn()
+ }()
+}
+
+func (sem *Semaphore) Wait() {
+ sem.wg.Wait()
+}
diff --git a/test/context/bridge.go b/test/context/bridge.go
index 66b60f60..c48bbc5f 100644
--- a/test/context/bridge.go
+++ b/test/context/bridge.go
@@ -21,11 +21,14 @@ import (
"time"
"github.com/ProtonMail/proton-bridge/internal/bridge"
+ "github.com/ProtonMail/proton-bridge/internal/config/settings"
"github.com/ProtonMail/proton-bridge/internal/config/useragent"
"github.com/ProtonMail/proton-bridge/internal/constants"
"github.com/ProtonMail/proton-bridge/internal/sentry"
+ "github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/users"
"github.com/ProtonMail/proton-bridge/pkg/listener"
+ "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
@@ -61,18 +64,28 @@ func (ctx *TestContext) RestartBridge() error {
}
// newBridgeInstance creates a new bridge instance configured to use the given config/credstore.
+// NOTE(GODT-1158): Need some tests with on-disk cache as well! Configurable in feature file or envvar?
func newBridgeInstance(
t *bddT,
locations bridge.Locator,
- cache bridge.Cacher,
- settings *fakeSettings,
+ cacheProvider bridge.CacheProvider,
+ fakeSettings *fakeSettings,
credStore users.CredentialsStorer,
eventListener listener.Listener,
clientManager pmapi.Manager,
) *bridge.Bridge {
- sentryReporter := sentry.NewReporter("bridge", constants.Version, useragent.New())
- panicHandler := &panicHandler{t: t}
- updater := newFakeUpdater()
- versioner := newFakeVersioner()
- return bridge.New(locations, cache, settings, sentryReporter, panicHandler, eventListener, clientManager, credStore, updater, versioner)
+ return bridge.New(
+ locations,
+ cacheProvider,
+ fakeSettings,
+ sentry.NewReporter("bridge", constants.Version, useragent.New()),
+ &panicHandler{t: t},
+ eventListener,
+ cache.NewInMemoryCache(100*(1<<20)),
+ message.NewBuilder(fakeSettings.GetInt(settings.FetchWorkers), fakeSettings.GetInt(settings.AttachmentWorkers)),
+ clientManager,
+ credStore,
+ newFakeUpdater(),
+ newFakeVersioner(),
+ )
}
diff --git a/test/context/imap.go b/test/context/imap.go
index 1f85f802..f5dd8ca0 100644
--- a/test/context/imap.go
+++ b/test/context/imap.go
@@ -58,7 +58,7 @@ func (ctx *TestContext) withIMAPServer() {
port := ctx.settings.GetInt(settings.IMAPPortKey)
tls, _ := tls.New(settingsPath).GetConfig()
- backend := imap.NewIMAPBackend(ph, ctx.listener, ctx.cache, ctx.bridge)
+ backend := imap.NewIMAPBackend(ph, ctx.listener, ctx.cache, ctx.settings, ctx.bridge)
server := imap.NewIMAPServer(ph, true, true, port, tls, backend, ctx.userAgent, ctx.listener)
go server.ListenAndServe()
diff --git a/test/fakeapi/user.go b/test/fakeapi/user.go
index 7f42dda4..55144f87 100644
--- a/test/fakeapi/user.go
+++ b/test/fakeapi/user.go
@@ -125,6 +125,10 @@ func (api *FakePMAPI) Addresses() pmapi.AddressList {
return *api.addresses
}
+func (api *FakePMAPI) GetUserKeyRing() (*crypto.KeyRing, error) {
+ return api.userKeyRing, nil
+}
+
func (api *FakePMAPI) KeyRingForAddressID(addrID string) (*crypto.KeyRing, error) {
return api.addrKeyRing[addrID], nil
}