diff --git a/Makefile b/Makefile index cd958edb..288eead8 100644 --- a/Makefile +++ b/Makefile @@ -230,7 +230,7 @@ integration-test-bridge: mocks: mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/users Locator,PanicHandler,CredentialsStorer,StoreMaker > internal/users/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/users/mocks/listener_mocks.go - mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser,ChangeNotifier > internal/store/mocks/mocks.go + mockgen --package mocks github.com/ProtonMail/proton-bridge/internal/store PanicHandler,BridgeUser,ChangeNotifier,Storer > internal/store/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/listener Listener > internal/store/mocks/utils_mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/pmapi Client,Manager > pkg/pmapi/mocks/mocks.go mockgen --package mocks github.com/ProtonMail/proton-bridge/pkg/message Fetcher > pkg/message/mocks/mocks.go @@ -288,7 +288,7 @@ run-nogui-cli: clean-vendor gofiles PROTONMAIL_ENV=dev go run ${BUILD_FLAGS} cmd/${TARGET_CMD}/main.go ${RUN_FLAGS} -c run-debug: - PROTONMAIL_ENV=dev dlv debug --build-flags "${BUILD_FLAGS}" cmd/${TARGET_CMD}/main.go -- ${RUN_FLAGS} + PROTONMAIL_ENV=dev dlv debug --build-flags "${BUILD_FLAGS}" cmd/${TARGET_CMD}/main.go -- ${RUN_FLAGS} --noninteractive run-qml-preview: find internal/frontend/qml/ -iname '*qmlc' | xargs rm -f diff --git a/TODO.md b/TODO.md new file mode 100644 index 00000000..e475fa52 --- /dev/null +++ b/TODO.md @@ -0,0 +1 @@ +- when cache is full, we need to stop the watcher? don't want to keep downloading messages and throwing them away when we try to cache them. diff --git a/go.mod b/go.mod index d4fffee1..d1a5df19 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,6 @@ require ( github.com/emersion/go-imap-move v0.0.0-20190710073258-6e5a51a5b342 github.com/emersion/go-imap-quota v0.0.0-20210203125329-619074823f3c github.com/emersion/go-imap-unselect v0.0.0-20171113212723-b985794e5f26 - github.com/emersion/go-mbox v1.0.2 github.com/emersion/go-message v0.12.1-0.20201221184100-40c3f864532b github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 github.com/emersion/go-smtp v0.14.0 @@ -45,7 +44,6 @@ require ( github.com/golang/mock v1.4.4 github.com/google/go-cmp v0.5.1 github.com/google/uuid v1.1.1 - github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c // indirect github.com/hashicorp/go-multierror v1.1.0 github.com/jaytaylor/html2text v0.0.0-20200412013138-3577fbdbcff7 github.com/keybase/go-keychain v0.0.0-20200502122510-cda31fe0c86d @@ -55,13 +53,11 @@ require ( github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce github.com/olekukonko/tablewriter v0.0.4 // indirect github.com/pkg/errors v0.9.1 + github.com/ricochet2200/go-disk-usage/du v0.0.0-20210707232629-ac9918953285 github.com/sirupsen/logrus v1.7.0 - github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect + github.com/stretchr/objx v0.2.0 // indirect github.com/stretchr/testify v1.7.0 - github.com/therecipe/qt v0.0.0-20200701200531-7f61353ee73e - github.com/therecipe/qt/internal/binding/files/docs/5.12.0 v0.0.0-20200904063919-c0c124a5770d // indirect - github.com/therecipe/qt/internal/binding/files/docs/5.13.0 v0.0.0-20200904063919-c0c124a5770d // indirect github.com/urfave/cli/v2 v2.2.0 github.com/vmihailenco/msgpack/v5 v5.1.3 go.etcd.io/bbolt v1.3.6 diff --git a/go.sum b/go.sum index e169da7d..a4c5597d 100644 --- a/go.sum +++ b/go.sum @@ -124,8 +124,6 @@ github.com/emersion/go-imap-quota v0.0.0-20210203125329-619074823f3c h1:khcEdu1y github.com/emersion/go-imap-quota v0.0.0-20210203125329-619074823f3c/go.mod h1:iApyhIQBiU4XFyr+3kdJyyGqle82TbQyuP2o+OZHrV0= github.com/emersion/go-imap-unselect v0.0.0-20171113212723-b985794e5f26 h1:FiSb8+XBQQSkcX3ubr+1tAtlRJBYaFmRZqOAweZ9Wy8= github.com/emersion/go-imap-unselect v0.0.0-20171113212723-b985794e5f26/go.mod h1:+gnnZx3Mg3MnCzZrv0eZdp5puxXQUgGT/6N6L7ShKfM= -github.com/emersion/go-mbox v1.0.2 h1:tE/rT+lEugK9y0myEymCCHnwlZN04hlXPrbKkxRBA5I= -github.com/emersion/go-mbox v1.0.2/go.mod h1:Yp9IVuuOYLEuMv4yjgDHvhb5mHOcYH6x92Oas3QqEZI= github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b/go.mod h1:G/dpzLu16WtQpBfQ/z3LYiYJn3ZhKSGWn83fyoyQe/k= github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 h1:OJyUGMJTzHTd1XQp98QTaHernxMYzRaOasRir9hUlFQ= github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ= @@ -197,9 +195,6 @@ github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gopherjs/gopherjs v0.0.0-20190411002643-bd77b112433e/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c h1:7lF+Vz0LqiRidnzC1Oq86fpX1q/iEv2KJdrCtttYjT4= -github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= @@ -269,7 +264,6 @@ github.com/klauspost/compress v1.8.2/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0 github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= @@ -351,6 +345,8 @@ github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y8 github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= +github.com/ricochet2200/go-disk-usage/du v0.0.0-20210707232629-ac9918953285 h1:d54EL9l+XteliUfUCGsEwwuk65dmmxX85VXF+9T6+50= +github.com/ricochet2200/go-disk-usage/du v0.0.0-20210707232629-ac9918953285/go.mod h1:fxIDly1xtudczrZeOOlfaUvd2OPb2qZAPuWdU2BsBTk= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/russross/blackfriday v1.5.2 h1:HyvC0ARfnZBqnXwABFeSZHpKvJHJJfPz81GNueLj0oo= @@ -365,13 +361,9 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= -github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= -github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= @@ -401,12 +393,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= -github.com/therecipe/qt v0.0.0-20200701200531-7f61353ee73e h1:G0DQ/TRQyrEZjtLlLwevFjaRiG8eeCMlq9WXQ2OO2bk= -github.com/therecipe/qt v0.0.0-20200701200531-7f61353ee73e/go.mod h1:SUUR2j3aE1z6/g76SdD6NwACEpvCxb3fvG82eKbD6us= -github.com/therecipe/qt/internal/binding/files/docs/5.12.0 v0.0.0-20200904063919-c0c124a5770d h1:hAZyEG2swPRWjF0kqqdGERXUazYnRJdAk4a58f14z7Y= -github.com/therecipe/qt/internal/binding/files/docs/5.12.0 v0.0.0-20200904063919-c0c124a5770d/go.mod h1:7m8PDYDEtEVqfjoUQc2UrFqhG0CDmoVJjRlQxexndFc= -github.com/therecipe/qt/internal/binding/files/docs/5.13.0 v0.0.0-20200904063919-c0c124a5770d h1:AJRoBel/g9cDS+yE8BcN3E+TDD/xNAguG21aoR8DAIE= -github.com/therecipe/qt/internal/binding/files/docs/5.13.0 v0.0.0-20200904063919-c0c124a5770d/go.mod h1:mH55Ek7AZcdns5KPp99O0bg+78el64YCYWHiQKrOdt4= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= @@ -444,7 +430,6 @@ golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnf golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190418165655-df01cb2cc480/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -488,7 +473,6 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190327091125-710a502c58a2/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190420063019-afa5a82059c6/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= @@ -521,9 +505,7 @@ golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -557,7 +539,6 @@ golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3 golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190327201419-c70d86f8b7cf/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= diff --git a/internal/app/bridge/bridge.go b/internal/app/bridge/bridge.go index fb39b4cc..44b0e061 100644 --- a/internal/app/bridge/bridge.go +++ b/internal/app/bridge/bridge.go @@ -32,7 +32,9 @@ import ( "github.com/ProtonMail/proton-bridge/internal/frontend/types" "github.com/ProtonMail/proton-bridge/internal/imap" "github.com/ProtonMail/proton-bridge/internal/smtp" + "github.com/ProtonMail/proton-bridge/internal/store/cache" "github.com/ProtonMail/proton-bridge/internal/updater" + "github.com/ProtonMail/proton-bridge/pkg/message" "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" @@ -69,10 +71,21 @@ func New(base *base.Base) *cli.App { func run(b *base.Base, c *cli.Context) error { // nolint[funlen] tlsConfig, err := loadTLSConfig(b) if err != nil { - logrus.WithError(err).Fatal("Failed to load TLS config") + return err } - bridge := bridge.New(b.Locations, b.Cache, b.Settings, b.SentryReporter, b.CrashHandler, b.Listener, b.CM, b.Creds, b.Updater, b.Versioner) - imapBackend := imap.NewIMAPBackend(b.CrashHandler, b.Listener, b.Cache, bridge) + + cache, err := loadCache(b) + if err != nil { + return err + } + + builder := message.NewBuilder( + b.Settings.GetInt(settings.FetchWorkers), + b.Settings.GetInt(settings.AttachmentWorkers), + ) + + bridge := bridge.New(b.Locations, b.Cache, b.Settings, b.SentryReporter, b.CrashHandler, b.Listener, cache, builder, b.CM, b.Creds, b.Updater, b.Versioner) + imapBackend := imap.NewIMAPBackend(b.CrashHandler, b.Listener, b.Cache, b.Settings, bridge) smtpBackend := smtp.NewSMTPBackend(b.CrashHandler, b.Listener, b.Settings, bridge) go func() { @@ -233,3 +246,35 @@ func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool) f.NotifySilentUpdateInstalled() } + +// NOTE(GODT-1158): How big should in-memory cache be? +// NOTE(GODT-1158): How to handle cache location migration if user changes custom path? +func loadCache(b *base.Base) (cache.Cache, error) { + if !b.Settings.GetBool(settings.CacheEnabledKey) { + return cache.NewInMemoryCache(100 * (1 << 20)), nil + } + + var compressor cache.Compressor + + // NOTE(GODT-1158): If user changes compression setting we have to nuke the cache. + if b.Settings.GetBool(settings.CacheCompressionKey) { + compressor = &cache.GZipCompressor{} + } else { + compressor = &cache.NoopCompressor{} + } + + var path string + + if customPath := b.Settings.Get(settings.CacheLocationKey); customPath != "" { + path = customPath + } else { + path = b.Cache.GetDefaultMessageCacheDir() + } + + return cache.NewOnDiskCache(path, compressor, cache.Options{ + MinFreeAbs: uint64(b.Settings.GetInt(settings.CacheMinFreeAbsKey)), + MinFreeRat: b.Settings.GetFloat64(settings.CacheMinFreeRatKey), + ConcurrentRead: b.Settings.GetInt(settings.CacheConcurrencyRead), + ConcurrentWrite: b.Settings.GetInt(settings.CacheConcurrencyWrite), + }) +} diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index ee282dcc..398ae830 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -28,17 +28,17 @@ import ( "github.com/ProtonMail/proton-bridge/internal/constants" "github.com/ProtonMail/proton-bridge/internal/metrics" "github.com/ProtonMail/proton-bridge/internal/sentry" + "github.com/ProtonMail/proton-bridge/internal/store/cache" "github.com/ProtonMail/proton-bridge/internal/updater" "github.com/ProtonMail/proton-bridge/internal/users" + "github.com/ProtonMail/proton-bridge/pkg/message" "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/listener" logrus "github.com/sirupsen/logrus" ) -var ( - log = logrus.WithField("pkg", "bridge") //nolint[gochecknoglobals] -) +var log = logrus.WithField("pkg", "bridge") //nolint[gochecknoglobals] type Bridge struct { *users.Users @@ -52,11 +52,13 @@ type Bridge struct { func New( locations Locator, - cache Cacher, - s SettingsProvider, + cacheProvider CacheProvider, + setting SettingsProvider, sentryReporter *sentry.Reporter, panicHandler users.PanicHandler, eventListener listener.Listener, + cache cache.Cache, + builder *message.Builder, clientManager pmapi.Manager, credStorer users.CredentialsStorer, updater Updater, @@ -64,7 +66,7 @@ func New( ) *Bridge { // Allow DoH before starting the app if the user has previously set this setting. // This allows us to start even if protonmail is blocked. - if s.GetBool(settings.AllowProxyKey) { + if setting.GetBool(settings.AllowProxyKey) { clientManager.AllowProxy() } @@ -74,25 +76,25 @@ func New( eventListener, clientManager, credStorer, - newStoreFactory(cache, sentryReporter, panicHandler, eventListener), + newStoreFactory(cacheProvider, sentryReporter, panicHandler, eventListener, cache, builder), ) b := &Bridge{ Users: u, locations: locations, - settings: s, + settings: setting, clientManager: clientManager, updater: updater, versioner: versioner, } - if s.GetBool(settings.FirstStartKey) { + if setting.GetBool(settings.FirstStartKey) { if err := b.SendMetric(metrics.New(metrics.Setup, metrics.FirstStart, metrics.Label(constants.Version))); err != nil { logrus.WithError(err).Error("Failed to send metric") } - s.SetBool(settings.FirstStartKey, false) + setting.SetBool(settings.FirstStartKey, false) } go b.heartbeat() diff --git a/internal/bridge/store_factory.go b/internal/bridge/store_factory.go index e31263ad..eb1e7aaf 100644 --- a/internal/bridge/store_factory.go +++ b/internal/bridge/store_factory.go @@ -23,47 +23,65 @@ import ( "github.com/ProtonMail/proton-bridge/internal/sentry" "github.com/ProtonMail/proton-bridge/internal/store" + "github.com/ProtonMail/proton-bridge/internal/store/cache" "github.com/ProtonMail/proton-bridge/internal/users" "github.com/ProtonMail/proton-bridge/pkg/listener" + "github.com/ProtonMail/proton-bridge/pkg/message" ) type storeFactory struct { - cache Cacher + cacheProvider CacheProvider sentryReporter *sentry.Reporter panicHandler users.PanicHandler eventListener listener.Listener - storeCache *store.Cache + events *store.Events + cache cache.Cache + builder *message.Builder } func newStoreFactory( - cache Cacher, + cacheProvider CacheProvider, sentryReporter *sentry.Reporter, panicHandler users.PanicHandler, eventListener listener.Listener, + cache cache.Cache, + builder *message.Builder, ) *storeFactory { return &storeFactory{ - cache: cache, + cacheProvider: cacheProvider, sentryReporter: sentryReporter, panicHandler: panicHandler, eventListener: eventListener, - storeCache: store.NewCache(cache.GetIMAPCachePath()), + events: store.NewEvents(cacheProvider.GetIMAPCachePath()), + cache: cache, + builder: builder, } } // New creates new store for given user. func (f *storeFactory) New(user store.BridgeUser) (*store.Store, error) { - storePath := getUserStorePath(f.cache.GetDBDir(), user.ID()) - return store.New(f.sentryReporter, f.panicHandler, user, f.eventListener, storePath, f.storeCache) + return store.New( + f.sentryReporter, + f.panicHandler, + user, + f.eventListener, + f.cache, + f.builder, + getUserStorePath(f.cacheProvider.GetDBDir(), user.ID()), + f.events, + ) } // Remove removes all store files for given user. func (f *storeFactory) Remove(userID string) error { - storePath := getUserStorePath(f.cache.GetDBDir(), userID) - return store.RemoveStore(f.storeCache, storePath, userID) + return store.RemoveStore( + f.events, + getUserStorePath(f.cacheProvider.GetDBDir(), userID), + userID, + ) } // getUserStorePath returns the file path of the store database for the given userID. func getUserStorePath(storeDir string, userID string) (path string) { - fileName := fmt.Sprintf("mailbox-%v.db", userID) - return filepath.Join(storeDir, fileName) + return filepath.Join(storeDir, fmt.Sprintf("mailbox-%v.db", userID)) } diff --git a/internal/bridge/types.go b/internal/bridge/types.go index 4b3cb701..1077acf9 100644 --- a/internal/bridge/types.go +++ b/internal/bridge/types.go @@ -28,7 +28,7 @@ type Locator interface { ClearUpdates() error } -type Cacher interface { +type CacheProvider interface { GetIMAPCachePath() string GetDBDir() string } @@ -38,6 +38,7 @@ type SettingsProvider interface { Set(key string, value string) GetBool(key string) bool SetBool(key string, val bool) + GetInt(key string) int } type Updater interface { diff --git a/internal/config/cache/cache.go b/internal/config/cache/cache.go index 7c08bc5a..7c3ff9de 100644 --- a/internal/config/cache/cache.go +++ b/internal/config/cache/cache.go @@ -45,6 +45,11 @@ func (c *Cache) GetDBDir() string { return c.getCurrentCacheDir() } +// GetDefaultMessageCacheDir returns folder for cached messages files. +func (c *Cache) GetDefaultMessageCacheDir() string { + return filepath.Join(c.getCurrentCacheDir(), "messages") +} + // GetIMAPCachePath returns path to file with IMAP status. func (c *Cache) GetIMAPCachePath() string { return filepath.Join(c.getCurrentCacheDir(), "user_info.json") diff --git a/internal/config/settings/kvs.go b/internal/config/settings/kvs.go index f40cd1a8..6310cce7 100644 --- a/internal/config/settings/kvs.go +++ b/internal/config/settings/kvs.go @@ -100,18 +100,28 @@ func (p *keyValueStore) GetBool(key string) bool { } func (p *keyValueStore) GetInt(key string) int { + if p.Get(key) == "" { + return 0 + } + value, err := strconv.Atoi(p.Get(key)) if err != nil { logrus.WithError(err).Error("Cannot parse int") } + return value } func (p *keyValueStore) GetFloat64(key string) float64 { + if p.Get(key) == "" { + return 0 + } + value, err := strconv.ParseFloat(p.Get(key), 64) if err != nil { logrus.WithError(err).Error("Cannot parse float64") } + return value } diff --git a/internal/config/settings/settings.go b/internal/config/settings/settings.go index aadaa183..feb76b29 100644 --- a/internal/config/settings/settings.go +++ b/internal/config/settings/settings.go @@ -43,6 +43,16 @@ const ( UpdateChannelKey = "update_channel" RolloutKey = "rollout" PreferredKeychainKey = "preferred_keychain" + CacheEnabledKey = "cache_enabled" + CacheCompressionKey = "cache_compression" + CacheLocationKey = "cache_location" + CacheMinFreeAbsKey = "cache_min_free_abs" + CacheMinFreeRatKey = "cache_min_free_rat" + CacheConcurrencyRead = "cache_concurrent_read" + CacheConcurrencyWrite = "cache_concurrent_write" + IMAPWorkers = "imap_workers" + FetchWorkers = "fetch_workers" + AttachmentWorkers = "attachment_workers" ) type Settings struct { @@ -80,6 +90,16 @@ func (s *Settings) setDefaultValues() { s.setDefault(UpdateChannelKey, "") s.setDefault(RolloutKey, fmt.Sprintf("%v", rand.Float64())) //nolint[gosec] G404 It is OK to use weak random number generator here s.setDefault(PreferredKeychainKey, "") + s.setDefault(CacheEnabledKey, "true") + s.setDefault(CacheCompressionKey, "true") + s.setDefault(CacheLocationKey, "") + s.setDefault(CacheMinFreeAbsKey, "250000000") + s.setDefault(CacheMinFreeRatKey, "") + s.setDefault(CacheConcurrencyRead, "16") + s.setDefault(CacheConcurrencyWrite, "16") + s.setDefault(IMAPWorkers, "16") + s.setDefault(FetchWorkers, "16") + s.setDefault(AttachmentWorkers, "16") s.setDefault(APIPortKey, DefaultAPIPort) s.setDefault(IMAPPortKey, DefaultIMAPPort) diff --git a/internal/frontend/cli/frontend.go b/internal/frontend/cli/frontend.go index ea7899ef..dda9b7e6 100644 --- a/internal/frontend/cli/frontend.go +++ b/internal/frontend/cli/frontend.go @@ -128,6 +128,24 @@ func New( //nolint[funlen] }) fe.AddCmd(dohCmd) + // Cache-On-Disk commands. + codCmd := &ishell.Cmd{Name: "local-cache", + Help: "manage the local encrypted message cache", + } + codCmd.AddCmd(&ishell.Cmd{Name: "enable", + Help: "enable the local cache", + Func: fe.enableCacheOnDisk, + }) + codCmd.AddCmd(&ishell.Cmd{Name: "disable", + Help: "disable the local cache", + Func: fe.disableCacheOnDisk, + }) + codCmd.AddCmd(&ishell.Cmd{Name: "change-location", + Help: "change the location of the local cache", + Func: fe.setCacheOnDiskLocation, + }) + fe.AddCmd(codCmd) + // Updates commands. updatesCmd := &ishell.Cmd{Name: "updates", Help: "manage bridge updates", diff --git a/internal/frontend/cli/system.go b/internal/frontend/cli/system.go index 0ec95db3..c2643c72 100644 --- a/internal/frontend/cli/system.go +++ b/internal/frontend/cli/system.go @@ -19,6 +19,7 @@ package cli import ( "fmt" + "os" "strconv" "strings" @@ -155,6 +156,67 @@ func (f *frontendCLI) disallowProxy(c *ishell.Context) { } } +func (f *frontendCLI) enableCacheOnDisk(c *ishell.Context) { + if f.settings.GetBool(settings.CacheEnabledKey) { + f.Println("The local cache is already enabled.") + return + } + + if f.yesNoQuestion("Are you sure you want to enable the local cache") { + // Set this back to the default location before enabling. + f.settings.Set(settings.CacheLocationKey, "") + + if err := f.bridge.EnableCache(); err != nil { + f.Println("The local cache could not be enabled.") + return + } + + f.settings.SetBool(settings.CacheEnabledKey, true) + f.restarter.SetToRestart() + f.Stop() + } +} + +func (f *frontendCLI) disableCacheOnDisk(c *ishell.Context) { + if !f.settings.GetBool(settings.CacheEnabledKey) { + f.Println("The local cache is already disabled.") + return + } + + if f.yesNoQuestion("Are you sure you want to disable the local cache") { + if err := f.bridge.DisableCache(); err != nil { + f.Println("The local cache could not be disabled.") + return + } + + f.settings.SetBool(settings.CacheEnabledKey, false) + f.restarter.SetToRestart() + f.Stop() + } +} + +func (f *frontendCLI) setCacheOnDiskLocation(c *ishell.Context) { + if !f.settings.GetBool(settings.CacheEnabledKey) { + f.Println("The local cache must be enabled.") + return + } + + if location := f.settings.Get(settings.CacheLocationKey); location != "" { + f.Println("The current local cache location is:", location) + } + + if location := f.readStringInAttempts("Enter a new location for the cache", c.ReadLine, f.isCacheLocationUsable); location != "" { + if err := f.bridge.MigrateCache(f.settings.Get(settings.CacheLocationKey), location); err != nil { + f.Println("The local cache location could not be changed.") + return + } + + f.settings.Set(settings.CacheLocationKey, location) + f.restarter.SetToRestart() + f.Stop() + } +} + func (f *frontendCLI) isPortFree(port string) bool { port = strings.ReplaceAll(port, ":", "") if port == "" || port == currentPort { @@ -171,3 +233,13 @@ func (f *frontendCLI) isPortFree(port string) bool { } return true } + +// NOTE(GODT-1158): Check free space in location. +func (f *frontendCLI) isCacheLocationUsable(location string) bool { + stat, err := os.Stat(location) + if err != nil { + return false + } + + return stat.IsDir() +} diff --git a/internal/frontend/types/types.go b/internal/frontend/types/types.go index 0552bcee..d1c63308 100644 --- a/internal/frontend/types/types.go +++ b/internal/frontend/types/types.go @@ -77,6 +77,9 @@ type Bridger interface { ReportBug(osType, osVersion, description, accountName, address, emailClient string) error AllowProxy() DisallowProxy() + EnableCache() error + DisableCache() error + MigrateCache(from, to string) error GetUpdateChannel() updater.UpdateChannel SetUpdateChannel(updater.UpdateChannel) (needRestart bool, err error) GetKeychainApp() string diff --git a/internal/imap/backend.go b/internal/imap/backend.go index aff3778b..85493052 100644 --- a/internal/imap/backend.go +++ b/internal/imap/backend.go @@ -37,21 +37,13 @@ import ( "time" "github.com/ProtonMail/proton-bridge/internal/bridge" + "github.com/ProtonMail/proton-bridge/internal/config/settings" "github.com/ProtonMail/proton-bridge/internal/events" "github.com/ProtonMail/proton-bridge/pkg/listener" - "github.com/ProtonMail/proton-bridge/pkg/message" "github.com/emersion/go-imap" goIMAPBackend "github.com/emersion/go-imap/backend" ) -const ( - // NOTE: Each fetch worker has its own set of attach workers so there can be up to 20*5=100 API requests at once. - // This is a reasonable limit to not overwhelm API while still maintaining as much parallelism as possible. - fetchWorkers = 20 // In how many workers to fetch message (group list on IMAP). - attachWorkers = 5 // In how many workers to fetch attachments (for one message). - buildWorkers = 20 // In how many workers to build messages. -) - type panicHandler interface { HandlePanic() } @@ -61,26 +53,32 @@ type imapBackend struct { bridge bridger updates *imapUpdates eventListener listener.Listener + listWorkers int users map[string]*imapUser usersLocker sync.Locker - builder *message.Builder - imapCache map[string]map[string]string imapCachePath string imapCacheLock *sync.RWMutex } +type settingsProvider interface { + GetInt(string) int +} + // NewIMAPBackend returns struct implementing go-imap/backend interface. func NewIMAPBackend( panicHandler panicHandler, eventListener listener.Listener, cache cacheProvider, + setting settingsProvider, bridge *bridge.Bridge, ) *imapBackend { //nolint[golint] bridgeWrap := newBridgeWrap(bridge) - backend := newIMAPBackend(panicHandler, cache, bridgeWrap, eventListener) + + imapWorkers := setting.GetInt(settings.IMAPWorkers) + backend := newIMAPBackend(panicHandler, cache, bridgeWrap, eventListener, imapWorkers) go backend.monitorDisconnectedUsers() @@ -92,6 +90,7 @@ func newIMAPBackend( cache cacheProvider, bridge bridger, eventListener listener.Listener, + listWorkers int, ) *imapBackend { return &imapBackend{ panicHandler: panicHandler, @@ -102,10 +101,9 @@ func newIMAPBackend( users: map[string]*imapUser{}, usersLocker: &sync.Mutex{}, - builder: message.NewBuilder(fetchWorkers, attachWorkers, buildWorkers), - imapCachePath: cache.GetIMAPCachePath(), imapCacheLock: &sync.RWMutex{}, + listWorkers: listWorkers, } } diff --git a/internal/imap/cache/cache.go b/internal/imap/cache/cache.go deleted file mode 100644 index 6048f460..00000000 --- a/internal/imap/cache/cache.go +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright (c) 2021 Proton Technologies AG -// -// This file is part of ProtonMail Bridge. -// -// ProtonMail Bridge is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// ProtonMail Bridge is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with ProtonMail Bridge. If not, see . - -package cache - -import ( - "bytes" - "sort" - "sync" - "time" - - pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message" -) - -type key struct { - ID string - Timestamp int64 - Size int -} - -type oldestFirst []key - -func (s oldestFirst) Len() int { return len(s) } -func (s oldestFirst) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -func (s oldestFirst) Less(i, j int) bool { return s[i].Timestamp < s[j].Timestamp } - -type cachedMessage struct { - key - data []byte - structure pkgMsg.BodyStructure -} - -//nolint[gochecknoglobals] -var ( - cacheTimeLimit = int64(1 * 60 * 60 * 1000) // milliseconds - cacheSizeLimit = 100 * 1000 * 1000 // B - MUST be larger than email max size limit (~ 25 MB) - mailCache = make(map[string]cachedMessage) - - // cacheMutex takes care of one single operation, whereas buildMutex takes - // care of the whole action doing multiple operations. buildMutex will protect - // you from asking server or decrypting or building the same message more - // than once. When first request to build the message comes, it will block - // all other build requests. When the first one is done, all others are - // handled by cache, not doing anything twice. With cacheMutex we are safe - // only to not mess up with the cache, but we could end up downloading and - // building message twice. - cacheMutex = &sync.Mutex{} - buildMutex = &sync.Mutex{} - buildLocks = map[string]interface{}{} -) - -func (m *cachedMessage) isValidOrDel() bool { - if m.key.Timestamp+cacheTimeLimit < timestamp() { - delete(mailCache, m.key.ID) - return false - } - return true -} - -func timestamp() int64 { - return time.Now().UnixNano() / int64(time.Millisecond) -} - -func Clear() { - mailCache = make(map[string]cachedMessage) -} - -// BuildLock locks per message level, not on global level. -// Multiple different messages can be building at once. -func BuildLock(messageID string) { - for { - buildMutex.Lock() - if _, ok := buildLocks[messageID]; ok { // if locked, wait - buildMutex.Unlock() - time.Sleep(10 * time.Millisecond) - } else { // if unlocked, lock it - buildLocks[messageID] = struct{}{} - buildMutex.Unlock() - return - } - } -} - -func BuildUnlock(messageID string) { - buildMutex.Lock() - defer buildMutex.Unlock() - delete(buildLocks, messageID) -} - -func LoadMail(mID string) (reader *bytes.Reader, structure *pkgMsg.BodyStructure) { - reader = &bytes.Reader{} - cacheMutex.Lock() - defer cacheMutex.Unlock() - if message, ok := mailCache[mID]; ok && message.isValidOrDel() { - reader = bytes.NewReader(message.data) - structure = &message.structure - - // Update timestamp to keep emails which are used often. - message.Timestamp = timestamp() - } - return -} - -func SaveMail(mID string, msg []byte, structure *pkgMsg.BodyStructure) { - cacheMutex.Lock() - defer cacheMutex.Unlock() - - newMessage := cachedMessage{ - key: key{ - ID: mID, - Timestamp: timestamp(), - Size: len(msg), - }, - data: msg, - structure: *structure, - } - - // Remove old and reduce size. - totalSize := 0 - messageList := []key{} - for _, message := range mailCache { - if message.isValidOrDel() { - messageList = append(messageList, message.key) - totalSize += message.key.Size - } - } - sort.Sort(oldestFirst(messageList)) - var oldest key - for totalSize+newMessage.key.Size >= cacheSizeLimit { - oldest, messageList = messageList[0], messageList[1:] - delete(mailCache, oldest.ID) - totalSize -= oldest.Size - } - - // Write new. - mailCache[mID] = newMessage -} diff --git a/internal/imap/cache/cache_test.go b/internal/imap/cache/cache_test.go deleted file mode 100644 index 90696239..00000000 --- a/internal/imap/cache/cache_test.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright (c) 2021 Proton Technologies AG -// -// This file is part of ProtonMail Bridge. -// -// ProtonMail Bridge is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// ProtonMail Bridge is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with ProtonMail Bridge. If not, see . - -package cache - -import ( - "fmt" - "testing" - "time" - - pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message" - "github.com/stretchr/testify/require" -) - -var bs = &pkgMsg.BodyStructure{} //nolint[gochecknoglobals] -const testUID = "testmsg" - -func TestSaveAndLoad(t *testing.T) { - msg := []byte("Test message") - - SaveMail(testUID, msg, bs) - require.Equal(t, mailCache[testUID].data, msg) - - reader, _ := LoadMail(testUID) - require.Equal(t, reader.Len(), len(msg)) - stored := make([]byte, len(msg)) - _, _ = reader.Read(stored) - require.Equal(t, stored, msg) -} - -func TestMissing(t *testing.T) { - reader, _ := LoadMail("non-existing") - require.Equal(t, reader.Len(), 0) -} - -func TestClearOld(t *testing.T) { - cacheTimeLimit = 10 - msg := []byte("Test message") - SaveMail(testUID, msg, bs) - time.Sleep(100 * time.Millisecond) - - reader, _ := LoadMail(testUID) - require.Equal(t, reader.Len(), 0) -} - -func TestClearBig(t *testing.T) { - r := require.New(t) - wantMessage := []byte("Test message") - - wantCacheSize := 3 - nTestMessages := wantCacheSize * wantCacheSize - cacheSizeLimit = wantCacheSize*len(wantMessage) + 1 - cacheTimeLimit = int64(1 << 20) // be sure the message will survive - - // It should never have more than nSize items. - for i := 0; i < nTestMessages; i++ { - time.Sleep(1 * time.Millisecond) - SaveMail(fmt.Sprintf("%s%d", testUID, i), wantMessage, bs) - r.LessOrEqual(len(mailCache), wantCacheSize, "cache too big when %d", i) - } - - // Check that the oldest are deleted first. - for i := 0; i < nTestMessages; i++ { - iUID := fmt.Sprintf("%s%d", testUID, i) - reader, _ := LoadMail(iUID) - mail := mailCache[iUID] - - if i < (nTestMessages - wantCacheSize) { - r.Zero(reader.Len(), "LoadMail should return empty, but have %s for %s time %d ", string(mail.data), iUID, mail.key.Timestamp) - } else { - stored := make([]byte, len(wantMessage)) - _, err := reader.Read(stored) - r.NoError(err) - r.Equal(wantMessage, stored, "LoadMail returned wrong message: %s for %s time %d", stored, iUID, mail.key.Timestamp) - } - } -} - -func TestConcurency(t *testing.T) { - msg := []byte("Test message") - for i := 0; i < 10; i++ { - go SaveMail(fmt.Sprintf("%s%d", testUID, i), msg, bs) - } -} diff --git a/internal/imap/mailbox.go b/internal/imap/mailbox.go index e9df47f6..8dc09ccb 100644 --- a/internal/imap/mailbox.go +++ b/internal/imap/mailbox.go @@ -37,12 +37,10 @@ type imapMailbox struct { storeUser storeUserProvider storeAddress storeAddressProvider storeMailbox storeMailboxProvider - - builder *message.Builder } // newIMAPMailbox returns struct implementing go-imap/mailbox interface. -func newIMAPMailbox(panicHandler panicHandler, user *imapUser, storeMailbox storeMailboxProvider, builder *message.Builder) *imapMailbox { +func newIMAPMailbox(panicHandler panicHandler, user *imapUser, storeMailbox storeMailboxProvider) *imapMailbox { return &imapMailbox{ panicHandler: panicHandler, user: user, @@ -56,8 +54,6 @@ func newIMAPMailbox(panicHandler panicHandler, user *imapUser, storeMailbox stor storeUser: user.storeUser, storeAddress: user.storeAddress, storeMailbox: storeMailbox, - - builder: builder, } } diff --git a/internal/imap/mailbox_fetch.go b/internal/imap/mailbox_fetch.go index b1f8d724..69aaa833 100644 --- a/internal/imap/mailbox_fetch.go +++ b/internal/imap/mailbox_fetch.go @@ -19,21 +19,13 @@ package imap import ( "bytes" - "context" - "github.com/ProtonMail/proton-bridge/internal/imap/cache" "github.com/ProtonMail/proton-bridge/pkg/message" - "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/emersion/go-imap" "github.com/pkg/errors" - "github.com/sirupsen/logrus" ) -func (im *imapMailbox) getMessage( - storeMessage storeMessageProvider, - items []imap.FetchItem, - msgBuildCountHistogram *msgBuildCountHistogram, -) (msg *imap.Message, err error) { +func (im *imapMailbox) getMessage(storeMessage storeMessageProvider, items []imap.FetchItem) (msg *imap.Message, err error) { msglog := im.log.WithField("msgID", storeMessage.ID()) msglog.Trace("Getting message") @@ -69,9 +61,12 @@ func (im *imapMailbox) getMessage( // There is no point having message older than RFC itself, it's not possible. msg.InternalDate = message.SanitizeMessageDate(m.Time) case imap.FetchRFC822Size: - if msg.Size, err = im.getSize(storeMessage); err != nil { + size, err := storeMessage.GetRFC822Size() + if err != nil { return nil, err } + + msg.Size = size case imap.FetchUid: if msg.Uid, err = storeMessage.UID(); err != nil { return nil, err @@ -79,7 +74,7 @@ func (im *imapMailbox) getMessage( case imap.FetchAll, imap.FetchFast, imap.FetchFull, imap.FetchRFC822, imap.FetchRFC822Header, imap.FetchRFC822Text: fallthrough // this is list of defined items by go-imap, but items can be also sections generated from requests default: - if err = im.getLiteralForSection(item, msg, storeMessage, msgBuildCountHistogram); err != nil { + if err = im.getLiteralForSection(item, msg, storeMessage); err != nil { return } } @@ -88,35 +83,7 @@ func (im *imapMailbox) getMessage( return msg, err } -// getSize returns cached size or it will build the message, save the size in -// DB and then returns the size after build. -// -// We are storing size in DB as part of pmapi messages metada. The size -// attribute on the server represents size of encrypted body. The value is -// cleared in Bridge and the final decrypted size (including header, attachment -// and MIME structure) is computed after building the message. -func (im *imapMailbox) getSize(storeMessage storeMessageProvider) (uint32, error) { - m := storeMessage.Message() - if m.Size <= 0 { - im.log.WithField("msgID", m.ID).Debug("Size unknown - downloading body") - // We are sure the size is not a problem right now. Clients - // might not first check sizes of all messages so we couldn't - // be sure if seeing 1st or 2nd sync is all right or not. - // Therefore, it's better to exclude getting size from the - // counting and see build count as real message build. - if _, _, err := im.getBodyAndStructure(storeMessage, nil); err != nil { - return 0, err - } - } - return uint32(m.Size), nil -} - -func (im *imapMailbox) getLiteralForSection( - itemSection imap.FetchItem, - msg *imap.Message, - storeMessage storeMessageProvider, - msgBuildCountHistogram *msgBuildCountHistogram, -) error { +func (im *imapMailbox) getLiteralForSection(itemSection imap.FetchItem, msg *imap.Message, storeMessage storeMessageProvider) error { section, err := imap.ParseBodySectionName(itemSection) if err != nil { log.WithError(err).Warn("Failed to parse body section name; part will be skipped") @@ -124,7 +91,7 @@ func (im *imapMailbox) getLiteralForSection( } var literal imap.Literal - if literal, err = im.getMessageBodySection(storeMessage, section, msgBuildCountHistogram); err != nil { + if literal, err = im.getMessageBodySection(storeMessage, section); err != nil { return err } @@ -149,88 +116,25 @@ func (im *imapMailbox) getBodyStructure(storeMessage storeMessageProvider) (bs * // be sure if seeing 1st or 2nd sync is all right or not. // Therefore, it's better to exclude first body structure fetch // from the counting and see build count as real message build. - if bs, _, err = im.getBodyAndStructure(storeMessage, nil); err != nil { + if bs, _, err = im.getBodyAndStructure(storeMessage); err != nil { return } } return } -func (im *imapMailbox) getBodyAndStructure( - storeMessage storeMessageProvider, msgBuildCountHistogram *msgBuildCountHistogram, -) ( - structure *message.BodyStructure, bodyReader *bytes.Reader, err error, -) { - m := storeMessage.Message() - id := im.storeUser.UserID() + m.ID - cache.BuildLock(id) - defer cache.BuildUnlock(id) - bodyReader, structure = cache.LoadMail(id) - - // return the message which was found in cache - if bodyReader.Len() != 0 && structure != nil { - return structure, bodyReader, nil +func (im *imapMailbox) getBodyAndStructure(storeMessage storeMessageProvider) (*message.BodyStructure, *bytes.Reader, error) { + rfc822, err := storeMessage.GetRFC822() + if err != nil { + return nil, nil, err } - structure, body, err := im.buildMessage(m) - bodyReader = bytes.NewReader(body) - size := int64(len(body)) - l := im.log.WithField("newSize", size).WithField("msgID", m.ID) - - if err != nil || structure == nil || size == 0 { - l.WithField("hasStructure", structure != nil).Warn("Failed to build message") - return structure, bodyReader, err + structure, err := storeMessage.GetBodyStructure() + if err != nil { + return nil, nil, err } - // Save the size, body structure and header even for messages which - // were unable to decrypt. Hence they doesn't have to be computed every - // time. - m.Size = size - cacheMessageInStore(storeMessage, structure, body, l) - - if msgBuildCountHistogram != nil { - times, errCount := storeMessage.IncreaseBuildCount() - if errCount != nil { - l.WithError(errCount).Warn("Cannot increase build count") - } - msgBuildCountHistogram.add(times) - } - - // Drafts can change therefore we don't want to cache them. - if !isMessageInDraftFolder(m) { - cache.SaveMail(id, body, structure) - } - - return structure, bodyReader, err -} - -func cacheMessageInStore(storeMessage storeMessageProvider, structure *message.BodyStructure, body []byte, l *logrus.Entry) { - m := storeMessage.Message() - if errSize := storeMessage.SetSize(m.Size); errSize != nil { - l.WithError(errSize).Warn("Cannot update size while building") - } - if structure != nil && !isMessageInDraftFolder(m) { - if errStruct := storeMessage.SetBodyStructure(structure); errStruct != nil { - l.WithError(errStruct).Warn("Cannot update bodystructure while building") - } - } - header, errHead := structure.GetMailHeaderBytes(bytes.NewReader(body)) - if errHead == nil && len(header) != 0 { - if errStore := storeMessage.SetHeader(header); errStore != nil { - l.WithError(errStore).Warn("Cannot update header in store") - } - } else { - l.WithError(errHead).Warn("Cannot get header bytes from structure") - } -} - -func isMessageInDraftFolder(m *pmapi.Message) bool { - for _, labelID := range m.LabelIDs { - if labelID == pmapi.DraftLabel { - return true - } - } - return false + return structure, bytes.NewReader(rfc822), nil } // This will download message (or read from cache) and pick up the section, @@ -246,11 +150,7 @@ func isMessageInDraftFolder(m *pmapi.Message) bool { // For all other cases it is necessary to download and decrypt the message // and drop the header which was obtained from cache. The header will // will be stored in DB once successfully built. Check `getBodyAndStructure`. -func (im *imapMailbox) getMessageBodySection( - storeMessage storeMessageProvider, - section *imap.BodySectionName, - msgBuildCountHistogram *msgBuildCountHistogram, -) (imap.Literal, error) { +func (im *imapMailbox) getMessageBodySection(storeMessage storeMessageProvider, section *imap.BodySectionName) (imap.Literal, error) { var header []byte var response []byte @@ -260,7 +160,7 @@ func (im *imapMailbox) getMessageBodySection( if isMainHeaderRequested && storeMessage.IsFullHeaderCached() { header = storeMessage.GetHeader() } else { - structure, bodyReader, err := im.getBodyAndStructure(storeMessage, msgBuildCountHistogram) + structure, bodyReader, err := im.getBodyAndStructure(storeMessage) if err != nil { return nil, err } @@ -276,7 +176,7 @@ func (im *imapMailbox) getMessageBodySection( case section.Specifier == imap.MIMESpecifier: // The MIME part specifier refers to the [MIME-IMB] header for this part. fallthrough case section.Specifier == imap.HeaderSpecifier: - header, err = structure.GetSectionHeaderBytes(bodyReader, section.Path) + header, err = structure.GetSectionHeaderBytes(section.Path) default: err = errors.New("Unknown specifier " + string(section.Specifier)) } @@ -293,30 +193,3 @@ func (im *imapMailbox) getMessageBodySection( // Trim any output if requested. return bytes.NewBuffer(section.ExtractPartial(response)), nil } - -// buildMessage from PM to IMAP. -func (im *imapMailbox) buildMessage(m *pmapi.Message) (*message.BodyStructure, []byte, error) { - body, err := im.builder.NewJobWithOptions( - context.Background(), - im.user.client(), - m.ID, - message.JobOptions{ - IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead. - SanitizeDate: true, // Whether to replace all dates before 1970 with RFC822's birthdate. - AddInternalID: true, // Whether to include MessageID as X-Pm-Internal-Id. - AddExternalID: true, // Whether to include ExternalID as X-Pm-External-Id. - AddMessageDate: true, // Whether to include message time as X-Pm-Date. - AddMessageIDReference: true, // Whether to include the MessageID in References. - }, - ).GetResult() - if err != nil { - return nil, nil, err - } - - structure, err := message.NewBodyStructure(bytes.NewReader(body)) - if err != nil { - return nil, nil, err - } - - return structure, body, nil -} diff --git a/internal/imap/mailbox_messages.go b/internal/imap/mailbox_messages.go index c91616ea..a2ca151b 100644 --- a/internal/imap/mailbox_messages.go +++ b/internal/imap/mailbox_messages.go @@ -479,11 +479,16 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria) } // Filter by size (only if size was already calculated). - if m.Size > 0 { - if criteria.Larger != 0 && m.Size <= int64(criteria.Larger) { + size, err := storeMessage.GetRFC822Size() + if err != nil { + return nil, err + } + + if size > 0 { + if criteria.Larger != 0 && int64(size) <= int64(criteria.Larger) { continue } - if criteria.Smaller != 0 && m.Size >= int64(criteria.Smaller) { + if criteria.Smaller != 0 && int64(size) >= int64(criteria.Smaller) { continue } } @@ -513,13 +518,12 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria) // // Messages must be sent to msgResponse. When the function returns, msgResponse must be closed. func (im *imapMailbox) ListMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message) error { - msgBuildCountHistogram := newMsgBuildCountHistogram() return im.logCommand(func() error { - return im.listMessages(isUID, seqSet, items, msgResponse, msgBuildCountHistogram) - }, "FETCH", isUID, seqSet, items, msgBuildCountHistogram) + return im.listMessages(isUID, seqSet, items, msgResponse) + }, "FETCH", isUID, seqSet, items) } -func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message, msgBuildCountHistogram *msgBuildCountHistogram) (err error) { //nolint[funlen] +func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message) (err error) { //nolint[funlen] defer func() { close(msgResponse) if err != nil { @@ -564,7 +568,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima return nil, err } - msg, err := im.getMessage(storeMessage, items, msgBuildCountHistogram) + msg, err := im.getMessage(storeMessage, items) if err != nil { err = fmt.Errorf("list message build: %v", err) l.WithField("metaID", storeMessage.ID()).Error(err) @@ -594,7 +598,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima return nil } - err = parallel.RunParallel(fetchWorkers, input, processCallback, collectCallback) + err = parallel.RunParallel(im.user.backend.listWorkers, input, processCallback, collectCallback) if err != nil { return err } diff --git a/internal/imap/msg_build_counts.go b/internal/imap/msg_build_counts.go deleted file mode 100644 index 03322cec..00000000 --- a/internal/imap/msg_build_counts.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) 2021 Proton Technologies AG -// -// This file is part of ProtonMail Bridge. -// -// ProtonMail Bridge is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// ProtonMail Bridge is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with ProtonMail Bridge. If not, see . - -package imap - -import ( - "fmt" - "sync" -) - -// msgBuildCountHistogram is used to analyse and log the number of repetitive -// downloads of requested messages per one fetch. The number of builds per each -// messageID is stored in persistent database. The msgBuildCountHistogram will -// take this number for each message in ongoing fetch and create histogram of -// repeats. -// -// Example: During `fetch 1:300` there were -// - 100 messages were downloaded first time -// - 100 messages were downloaded second time -// - 99 messages were downloaded 10th times -// - 1 messages were downloaded 100th times. -type msgBuildCountHistogram struct { - // Key represents how many times message was build. - // Value stores how many messages are build X times based on the key. - counts map[uint32]uint32 - lock sync.Locker -} - -func newMsgBuildCountHistogram() *msgBuildCountHistogram { - return &msgBuildCountHistogram{ - counts: map[uint32]uint32{}, - lock: &sync.Mutex{}, - } -} - -func (c *msgBuildCountHistogram) String() string { - res := "" - for nRebuild, counts := range c.counts { - if res != "" { - res += ", " - } - res += fmt.Sprintf("[%d]:%d", nRebuild, counts) - } - return res -} - -func (c *msgBuildCountHistogram) add(nRebuild uint32) { - c.lock.Lock() - defer c.lock.Unlock() - c.counts[nRebuild]++ -} diff --git a/internal/imap/store.go b/internal/imap/store.go index 918976be..2c17f5b2 100644 --- a/internal/imap/store.go +++ b/internal/imap/store.go @@ -80,7 +80,6 @@ type storeMailboxProvider interface { GetDelimiter() string GetMessage(apiID string) (storeMessageProvider, error) - FetchMessage(apiID string) (storeMessageProvider, error) LabelMessages(apiID []string) error UnlabelMessages(apiID []string) error MarkMessagesRead(apiID []string) error @@ -100,14 +99,12 @@ type storeMessageProvider interface { Message() *pmapi.Message IsMarkedDeleted() bool - SetSize(int64) error - SetHeader([]byte) error GetHeader() []byte + GetRFC822() ([]byte, error) + GetRFC822Size() (uint32, error) GetMIMEHeader() textproto.MIMEHeader IsFullHeaderCached() bool - SetBodyStructure(*pkgMsg.BodyStructure) error GetBodyStructure() (*pkgMsg.BodyStructure, error) - IncreaseBuildCount() (uint32, error) } type storeUserWrap struct { @@ -165,7 +162,3 @@ func newStoreMailboxWrap(mailbox *store.Mailbox) *storeMailboxWrap { func (s *storeMailboxWrap) GetMessage(apiID string) (storeMessageProvider, error) { return s.Mailbox.GetMessage(apiID) } - -func (s *storeMailboxWrap) FetchMessage(apiID string) (storeMessageProvider, error) { - return s.Mailbox.FetchMessage(apiID) -} diff --git a/internal/imap/user.go b/internal/imap/user.go index 8c4059c7..bf724656 100644 --- a/internal/imap/user.go +++ b/internal/imap/user.go @@ -135,7 +135,7 @@ func (iu *imapUser) ListMailboxes(showOnlySubcribed bool) ([]goIMAPBackend.Mailb if showOnlySubcribed && !iu.isSubscribed(storeMailbox.LabelID()) { continue } - mailbox := newIMAPMailbox(iu.panicHandler, iu, storeMailbox, iu.backend.builder) + mailbox := newIMAPMailbox(iu.panicHandler, iu, storeMailbox) mailboxes = append(mailboxes, mailbox) } @@ -167,7 +167,7 @@ func (iu *imapUser) GetMailbox(name string) (mb goIMAPBackend.Mailbox, err error return } - return newIMAPMailbox(iu.panicHandler, iu, storeMailbox, iu.backend.builder), nil + return newIMAPMailbox(iu.panicHandler, iu, storeMailbox), nil } // CreateMailbox creates a new mailbox. diff --git a/internal/store/cache.go b/internal/store/cache.go index 76757ca1..9c74552b 100644 --- a/internal/store/cache.go +++ b/internal/store/cache.go @@ -18,99 +18,113 @@ package store import ( - "encoding/json" - "os" - "sync" - - "github.com/pkg/errors" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/ProtonMail/proton-bridge/pkg/message" + "github.com/sirupsen/logrus" + bolt "go.etcd.io/bbolt" ) -// Cache caches the last event IDs for all accounts (there should be only one instance). -type Cache struct { - // cache is map from userID => key (such as last event) => value (such as event ID). - cache map[string]map[string]string - path string - lock *sync.RWMutex -} +const passphraseKey = "passphrase" -// NewCache constructs a new cache at the given path. -func NewCache(path string) *Cache { - return &Cache{ - path: path, - lock: &sync.RWMutex{}, - } -} - -func (c *Cache) getEventID(userID string) string { - c.lock.Lock() - defer c.lock.Unlock() - - if err := c.loadCache(); err != nil { - log.WithError(err).Warn("Problem to load store cache") - } - - if c.cache == nil { - c.cache = map[string]map[string]string{} - } - if c.cache[userID] == nil { - c.cache[userID] = map[string]string{} - } - - return c.cache[userID]["events"] -} - -func (c *Cache) setEventID(userID, eventID string) error { - c.lock.Lock() - defer c.lock.Unlock() - - if c.cache[userID] == nil { - c.cache[userID] = map[string]string{} - } - c.cache[userID]["events"] = eventID - - return c.saveCache() -} - -func (c *Cache) loadCache() error { - if c.cache != nil { - return nil - } - - f, err := os.Open(c.path) +// UnlockCache unlocks the cache for the user with the given keyring. +func (store *Store) UnlockCache(kr *crypto.KeyRing) error { + passphrase, err := store.getCachePassphrase() if err != nil { return err } - defer f.Close() //nolint[errcheck] - return json.NewDecoder(f).Decode(&c.cache) -} + if passphrase == nil { + if passphrase, err = crypto.RandomToken(32); err != nil { + return err + } -func (c *Cache) saveCache() error { - if c.cache == nil { - return errors.New("events: cannot save cache: cache is nil") + enc, err := kr.Encrypt(crypto.NewPlainMessage(passphrase), nil) + if err != nil { + return err + } + + if err := store.setCachePassphrase(enc.GetBinary()); err != nil { + return err + } + } else { + dec, err := kr.Decrypt(crypto.NewPGPMessage(passphrase), nil, crypto.GetUnixTime()) + if err != nil { + return err + } + + passphrase = dec.GetBinary() } - f, err := os.Create(c.path) + if err := store.cache.Unlock(store.user.ID(), passphrase); err != nil { + return err + } + + store.cacher.start() + + return nil +} + +func (store *Store) getCachePassphrase() ([]byte, error) { + var passphrase []byte + + if err := store.db.View(func(tx *bolt.Tx) error { + passphrase = tx.Bucket(cachePassphraseBucket).Get([]byte(passphraseKey)) + return nil + }); err != nil { + return nil, err + } + + return passphrase, nil +} + +func (store *Store) setCachePassphrase(passphrase []byte) error { + return store.db.Update(func(tx *bolt.Tx) error { + return tx.Bucket(cachePassphraseBucket).Put([]byte(passphraseKey), passphrase) + }) +} + +func (store *Store) clearCachePassphrase() error { + return store.db.Update(func(tx *bolt.Tx) error { + return tx.Bucket(cachePassphraseBucket).Delete([]byte(passphraseKey)) + }) +} + +func (store *Store) getCachedMessage(messageID string) ([]byte, error) { + if store.cache.Has(store.user.ID(), messageID) { + return store.cache.Get(store.user.ID(), messageID) + } + + job, done := store.newBuildJob(messageID, message.ForegroundPriority) + defer done() + + literal, err := job.GetResult() + if err != nil { + return nil, err + } + + // NOTE(GODT-1158): No need to block until cache has been set; do this async? + if err := store.cache.Set(store.user.ID(), messageID, literal); err != nil { + logrus.WithError(err).Error("Failed to cache message") + } + + return literal, nil +} + +// IsCached returns whether the given message already exists in the cache. +func (store *Store) IsCached(messageID string) bool { + return store.cache.Has(store.user.ID(), messageID) +} + +// BuildAndCacheMessage builds the given message (with background priority) and puts it in the cache. +// It builds with background priority. +func (store *Store) BuildAndCacheMessage(messageID string) error { + job, done := store.newBuildJob(messageID, message.BackgroundPriority) + defer done() + + literal, err := job.GetResult() if err != nil { return err } - defer f.Close() //nolint[errcheck] - return json.NewEncoder(f).Encode(c.cache) -} - -func (c *Cache) clearCacheUser(userID string) error { - c.lock.Lock() - defer c.lock.Unlock() - - if c.cache == nil { - log.WithField("user", userID).Warning("Cannot clear user from cache: cache is nil") - return nil - } - - log.WithField("user", userID).Trace("Removing user from event loop cache") - - delete(c.cache, userID) - - return c.saveCache() + return store.cache.Set(store.user.ID(), messageID, literal) } diff --git a/internal/store/cache/cache_test.go b/internal/store/cache/cache_test.go new file mode 100644 index 00000000..efb03ef3 --- /dev/null +++ b/internal/store/cache/cache_test.go @@ -0,0 +1,73 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package cache + +import ( + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOnDiskCacheNoCompression(t *testing.T) { + cache, err := NewOnDiskCache(t.TempDir(), &NoopCompressor{}, Options{ConcurrentRead: runtime.NumCPU(), ConcurrentWrite: runtime.NumCPU()}) + require.NoError(t, err) + + testCache(t, cache) +} + +func TestOnDiskCacheGZipCompression(t *testing.T) { + cache, err := NewOnDiskCache(t.TempDir(), &GZipCompressor{}, Options{ConcurrentRead: runtime.NumCPU(), ConcurrentWrite: runtime.NumCPU()}) + require.NoError(t, err) + + testCache(t, cache) +} + +func TestInMemoryCache(t *testing.T) { + testCache(t, NewInMemoryCache(1<<20)) +} + +func testCache(t *testing.T, cache Cache) { + assert.NoError(t, cache.Unlock("userID1", []byte("my secret passphrase"))) + assert.NoError(t, cache.Unlock("userID2", []byte("my other passphrase"))) + + getSetCachedMessage(t, cache, "userID1", "messageID1", "some secret") + assert.True(t, cache.Has("userID1", "messageID1")) + + getSetCachedMessage(t, cache, "userID2", "messageID2", "some other secret") + assert.True(t, cache.Has("userID2", "messageID2")) + + assert.NoError(t, cache.Rem("userID1", "messageID1")) + assert.False(t, cache.Has("userID1", "messageID1")) + + assert.NoError(t, cache.Rem("userID2", "messageID2")) + assert.False(t, cache.Has("userID2", "messageID2")) + + assert.NoError(t, cache.Delete("userID1")) + assert.NoError(t, cache.Delete("userID2")) +} + +func getSetCachedMessage(t *testing.T, cache Cache, userID, messageID, secret string) { + assert.NoError(t, cache.Set(userID, messageID, []byte(secret))) + + data, err := cache.Get(userID, messageID) + assert.NoError(t, err) + + assert.Equal(t, []byte(secret), data) +} diff --git a/internal/store/cache/compressor.go b/internal/store/cache/compressor.go new file mode 100644 index 00000000..c252a739 --- /dev/null +++ b/internal/store/cache/compressor.go @@ -0,0 +1,33 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package cache + +type Compressor interface { + Compress([]byte) ([]byte, error) + Decompress([]byte) ([]byte, error) +} + +type NoopCompressor struct{} + +func (NoopCompressor) Compress(dec []byte) ([]byte, error) { + return dec, nil +} + +func (NoopCompressor) Decompress(cmp []byte) ([]byte, error) { + return cmp, nil +} diff --git a/internal/store/cache/compressor_gzip.go b/internal/store/cache/compressor_gzip.go new file mode 100644 index 00000000..412de937 --- /dev/null +++ b/internal/store/cache/compressor_gzip.go @@ -0,0 +1,60 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package cache + +import ( + "bytes" + "compress/gzip" +) + +type GZipCompressor struct{} + +func (GZipCompressor) Compress(dec []byte) ([]byte, error) { + buf := new(bytes.Buffer) + + zw := gzip.NewWriter(buf) + + if _, err := zw.Write(dec); err != nil { + return nil, err + } + + if err := zw.Close(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +func (GZipCompressor) Decompress(cmp []byte) ([]byte, error) { + zr, err := gzip.NewReader(bytes.NewReader(cmp)) + if err != nil { + return nil, err + } + + buf := new(bytes.Buffer) + + if _, err := buf.ReadFrom(zr); err != nil { + return nil, err + } + + if err := zr.Close(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} diff --git a/internal/store/cache/disk.go b/internal/store/cache/disk.go new file mode 100644 index 00000000..b4b3ee33 --- /dev/null +++ b/internal/store/cache/disk.go @@ -0,0 +1,244 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package cache + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "errors" + "io/ioutil" + "os" + "path/filepath" + "sync" + + "github.com/ProtonMail/proton-bridge/pkg/semaphore" + "github.com/ricochet2200/go-disk-usage/du" +) + +var ErrLowSpace = errors.New("not enough free space left on device") + +type onDiskCache struct { + path string + opts Options + + gcm map[string]cipher.AEAD + cmp Compressor + rsem, wsem semaphore.Semaphore + pending *pending + + diskSize uint64 + diskFree uint64 + once *sync.Once + lock sync.Mutex +} + +func NewOnDiskCache(path string, cmp Compressor, opts Options) (Cache, error) { + if err := os.MkdirAll(path, 0700); err != nil { + return nil, err + } + + usage := du.NewDiskUsage(path) + + // NOTE(GODT-1158): use Available() or Free()? + return &onDiskCache{ + path: path, + opts: opts, + + gcm: make(map[string]cipher.AEAD), + cmp: cmp, + rsem: semaphore.New(opts.ConcurrentRead), + wsem: semaphore.New(opts.ConcurrentWrite), + pending: newPending(), + + diskSize: usage.Size(), + diskFree: usage.Available(), + once: &sync.Once{}, + }, nil +} + +func (c *onDiskCache) Unlock(userID string, passphrase []byte) error { + hash := sha256.New() + + if _, err := hash.Write(passphrase); err != nil { + return err + } + + aes, err := aes.NewCipher(hash.Sum(nil)) + if err != nil { + return err + } + + gcm, err := cipher.NewGCM(aes) + if err != nil { + return err + } + + if err := os.MkdirAll(c.getUserPath(userID), 0700); err != nil { + return err + } + + c.gcm[userID] = gcm + + return nil +} + +func (c *onDiskCache) Delete(userID string) error { + defer c.update() + + return os.RemoveAll(c.getUserPath(userID)) +} + +// Has returns whether the given message exists in the cache. +func (c *onDiskCache) Has(userID, messageID string) bool { + c.pending.wait(c.getMessagePath(userID, messageID)) + + c.rsem.Lock() + defer c.rsem.Unlock() + + _, err := os.Stat(c.getMessagePath(userID, messageID)) + + switch { + case err == nil: + return true + + case os.IsNotExist(err): + return false + + default: + panic(err) + } +} + +func (c *onDiskCache) Get(userID, messageID string) ([]byte, error) { + enc, err := c.readFile(c.getMessagePath(userID, messageID)) + if err != nil { + return nil, err + } + + cmp, err := c.gcm[userID].Open(nil, enc[:c.gcm[userID].NonceSize()], enc[c.gcm[userID].NonceSize():], nil) + if err != nil { + return nil, err + } + + return c.cmp.Decompress(cmp) +} + +func (c *onDiskCache) Set(userID, messageID string, literal []byte) error { + nonce := make([]byte, c.gcm[userID].NonceSize()) + + if _, err := rand.Read(nonce); err != nil { + return err + } + + cmp, err := c.cmp.Compress(literal) + if err != nil { + return err + } + + // NOTE(GODT-1158): How to properly handle low space? Don't return error, that's bad. Instead send event? + if !c.hasSpace(len(cmp)) { + return nil + } + + return c.writeFile(c.getMessagePath(userID, messageID), c.gcm[userID].Seal(nonce, nonce, cmp, nil)) +} + +func (c *onDiskCache) Rem(userID, messageID string) error { + defer c.update() + + return os.Remove(c.getMessagePath(userID, messageID)) +} + +func (c *onDiskCache) readFile(path string) ([]byte, error) { + c.rsem.Lock() + defer c.rsem.Unlock() + + // Wait before reading in case the file is currently being written. + c.pending.wait(path) + + return ioutil.ReadFile(filepath.Clean(path)) +} + +func (c *onDiskCache) writeFile(path string, b []byte) error { + c.wsem.Lock() + defer c.wsem.Unlock() + + // Mark the file as currently being written. + // If it's already being written, wait for it to be done and return nil. + // NOTE(GODT-1158): Let's hope it succeeded... + if ok := c.pending.add(path); !ok { + c.pending.wait(path) + return nil + } + defer c.pending.done(path) + + // Reduce the approximate free space (update it exactly later). + c.lock.Lock() + c.diskFree -= uint64(len(b)) + c.lock.Unlock() + + // Update the diskFree eventually. + defer c.update() + + // NOTE(GODT-1158): What happens when this fails? Should be fixed eventually. + return ioutil.WriteFile(filepath.Clean(path), b, 0600) +} + +func (c *onDiskCache) hasSpace(size int) bool { + c.lock.Lock() + defer c.lock.Unlock() + + if c.opts.MinFreeAbs > 0 { + if c.diskFree-uint64(size) < c.opts.MinFreeAbs { + return false + } + } + + if c.opts.MinFreeRat > 0 { + if float64(c.diskFree-uint64(size))/float64(c.diskSize) < c.opts.MinFreeRat { + return false + } + } + + return true +} + +func (c *onDiskCache) update() { + go func() { + c.once.Do(func() { + c.lock.Lock() + defer c.lock.Unlock() + + // Update the free space. + c.diskFree = du.NewDiskUsage(c.path).Available() + + // Reset the Once object (so we can update again). + c.once = &sync.Once{} + }) + }() +} + +func (c *onDiskCache) getUserPath(userID string) string { + return filepath.Join(c.path, getHash(userID)) +} + +func (c *onDiskCache) getMessagePath(userID, messageID string) string { + return filepath.Join(c.getUserPath(userID), getHash(messageID)) +} diff --git a/internal/store/cache/hash.go b/internal/store/cache/hash.go new file mode 100644 index 00000000..fc07771b --- /dev/null +++ b/internal/store/cache/hash.go @@ -0,0 +1,33 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package cache + +import ( + "crypto/sha256" + "encoding/hex" +) + +func getHash(name string) string { + hash := sha256.New() + + if _, err := hash.Write([]byte(name)); err != nil { + panic(err) + } + + return hex.EncodeToString(hash.Sum(nil)) +} diff --git a/internal/store/cache/memory.go b/internal/store/cache/memory.go new file mode 100644 index 00000000..d364e5f9 --- /dev/null +++ b/internal/store/cache/memory.go @@ -0,0 +1,104 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package cache + +import ( + "errors" + "sync" +) + +type inMemoryCache struct { + lock sync.RWMutex + data map[string]map[string][]byte + size, limit int +} + +// NewInMemoryCache creates a new in memory cache which stores up to the given number of bytes of cached data. +// NOTE(GODT-1158): Make this threadsafe. +func NewInMemoryCache(limit int) Cache { + return &inMemoryCache{ + data: make(map[string]map[string][]byte), + limit: limit, + } +} + +func (c *inMemoryCache) Unlock(userID string, passphrase []byte) error { + c.data[userID] = make(map[string][]byte) + return nil +} + +func (c *inMemoryCache) Delete(userID string) error { + c.lock.Lock() + defer c.lock.Unlock() + + for _, message := range c.data[userID] { + c.size -= len(message) + } + + delete(c.data, userID) + + return nil +} + +// Has returns whether the given message exists in the cache. +func (c *inMemoryCache) Has(userID, messageID string) bool { + if _, err := c.Get(userID, messageID); err != nil { + return false + } + + return true +} + +func (c *inMemoryCache) Get(userID, messageID string) ([]byte, error) { + c.lock.RLock() + defer c.lock.RUnlock() + + literal, ok := c.data[userID][messageID] + if !ok { + return nil, errors.New("no such message in cache") + } + + return literal, nil +} + +// NOTE(GODT-1158): What to actually do when memory limit is reached? Replace something existing? Return error? Drop silently? +// NOTE(GODT-1158): Pull in cache-rotating feature from old IMAP cache. +func (c *inMemoryCache) Set(userID, messageID string, literal []byte) error { + c.lock.Lock() + defer c.lock.Unlock() + + if c.size+len(literal) > c.limit { + return nil + } + + c.size += len(literal) + c.data[userID][messageID] = literal + + return nil +} + +func (c *inMemoryCache) Rem(userID, messageID string) error { + c.lock.Lock() + defer c.lock.Unlock() + + c.size -= len(c.data[userID][messageID]) + + delete(c.data[userID], messageID) + + return nil +} diff --git a/internal/store/cache/options.go b/internal/store/cache/options.go new file mode 100644 index 00000000..a2dd9cb3 --- /dev/null +++ b/internal/store/cache/options.go @@ -0,0 +1,25 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package cache + +type Options struct { + MinFreeAbs uint64 + MinFreeRat float64 + ConcurrentRead int + ConcurrentWrite int +} diff --git a/internal/store/cache/pending.go b/internal/store/cache/pending.go new file mode 100644 index 00000000..29314472 --- /dev/null +++ b/internal/store/cache/pending.go @@ -0,0 +1,61 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package cache + +import "sync" + +type pending struct { + lock sync.Mutex + path map[string]chan struct{} +} + +func newPending() *pending { + return &pending{path: make(map[string]chan struct{})} +} + +func (p *pending) add(path string) bool { + p.lock.Lock() + defer p.lock.Unlock() + + if _, ok := p.path[path]; ok { + return false + } + + p.path[path] = make(chan struct{}) + + return true +} + +func (p *pending) wait(path string) { + p.lock.Lock() + ch, ok := p.path[path] + p.lock.Unlock() + + if ok { + <-ch + } +} + +func (p *pending) done(path string) { + p.lock.Lock() + defer p.lock.Unlock() + + defer close(p.path[path]) + + delete(p.path, path) +} diff --git a/internal/store/cache/pending_test.go b/internal/store/cache/pending_test.go new file mode 100644 index 00000000..82bc2979 --- /dev/null +++ b/internal/store/cache/pending_test.go @@ -0,0 +1,51 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package cache + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPending(t *testing.T) { + pending := newPending() + + pending.add("1") + pending.add("2") + pending.add("3") + + resCh := make(chan string) + + go func() { pending.wait("1"); resCh <- "1" }() + go func() { pending.wait("2"); resCh <- "2" }() + go func() { pending.wait("3"); resCh <- "3" }() + + pending.done("1") + assert.Equal(t, "1", <-resCh) + + pending.done("2") + assert.Equal(t, "2", <-resCh) + + pending.done("3") + assert.Equal(t, "3", <-resCh) +} + +func TestPendingUnknown(t *testing.T) { + newPending().wait("this is not currently being waited") +} diff --git a/internal/store/cache/types.go b/internal/store/cache/types.go new file mode 100644 index 00000000..a1c7c9db --- /dev/null +++ b/internal/store/cache/types.go @@ -0,0 +1,28 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package cache + +type Cache interface { + Unlock(userID string, passphrase []byte) error + Delete(userID string) error + + Has(userID, messageID string) bool + Get(userID, messageID string) ([]byte, error) + Set(userID, messageID string, literal []byte) error + Rem(userID, messageID string) error +} diff --git a/internal/store/cache_watcher.go b/internal/store/cache_watcher.go new file mode 100644 index 00000000..c8f3811a --- /dev/null +++ b/internal/store/cache_watcher.go @@ -0,0 +1,63 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package store + +import "time" + +func (store *Store) StartWatcher() { + store.done = make(chan struct{}) + + go func() { + ticker := time.NewTicker(3 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + // NOTE(GODT-1158): Race condition here? What if DB was already closed? + messageIDs, err := store.getAllMessageIDs() + if err != nil { + return + } + + for _, messageID := range messageIDs { + if !store.IsCached(messageID) { + store.cacher.newJob(messageID) + } + } + + case <-store.done: + return + } + } + }() +} + +func (store *Store) stopWatcher() { + if store.done == nil { + return + } + + select { + default: + close(store.done) + + case <-store.done: + return + } +} diff --git a/internal/store/cache_worker.go b/internal/store/cache_worker.go new file mode 100644 index 00000000..68173428 --- /dev/null +++ b/internal/store/cache_worker.go @@ -0,0 +1,104 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package store + +import ( + "sync" + + "github.com/sirupsen/logrus" +) + +type Cacher struct { + storer Storer + jobs chan string + done chan struct{} + started bool + wg *sync.WaitGroup +} + +type Storer interface { + IsCached(messageID string) bool + BuildAndCacheMessage(messageID string) error +} + +func newCacher(storer Storer) *Cacher { + return &Cacher{ + storer: storer, + jobs: make(chan string), + done: make(chan struct{}), + wg: &sync.WaitGroup{}, + } +} + +// newJob sends a new job to the cacher if it's running. +func (cacher *Cacher) newJob(messageID string) { + if !cacher.started { + return + } + + select { + case <-cacher.done: + return + + default: + if !cacher.storer.IsCached(messageID) { + cacher.wg.Add(1) + go func() { cacher.jobs <- messageID }() + } + } +} + +func (cacher *Cacher) start() { + cacher.started = true + + go func() { + for { + select { + case messageID := <-cacher.jobs: + go cacher.handleJob(messageID) + + case <-cacher.done: + return + } + } + }() +} + +func (cacher *Cacher) handleJob(messageID string) { + defer cacher.wg.Done() + + if err := cacher.storer.BuildAndCacheMessage(messageID); err != nil { + logrus.WithError(err).Error("Failed to build and cache message") + } else { + logrus.WithField("messageID", messageID).Trace("Message cached") + } +} + +func (cacher *Cacher) stop() { + cacher.started = false + + cacher.wg.Wait() + + select { + case <-cacher.done: + return + + default: + close(cacher.done) + } +} diff --git a/internal/store/cache_worker_test.go b/internal/store/cache_worker_test.go new file mode 100644 index 00000000..6bd6b121 --- /dev/null +++ b/internal/store/cache_worker_test.go @@ -0,0 +1,103 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package store + +import ( + "testing" + + storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks" + "github.com/golang/mock/gomock" + "github.com/pkg/errors" +) + +func withTestCacher(t *testing.T, doTest func(storer *storemocks.MockStorer, cacher *Cacher)) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Mock storer used to build/cache messages. + storer := storemocks.NewMockStorer(ctrl) + + // Create a new cacher pointing to the fake store. + cacher := newCacher(storer) + + // Start the cacher and wait for it to stop. + cacher.start() + defer cacher.stop() + + doTest(storer, cacher) +} + +func TestCacher(t *testing.T) { + // If the message is not yet cached, we should expect to try to build and cache it. + withTestCacher(t, func(storer *storemocks.MockStorer, cacher *Cacher) { + storer.EXPECT().IsCached("messageID").Return(false) + storer.EXPECT().BuildAndCacheMessage("messageID").Return(nil) + cacher.newJob("messageID") + }) +} + +func TestCacherAlreadyCached(t *testing.T) { + // If the message is already cached, we should not try to build it. + withTestCacher(t, func(storer *storemocks.MockStorer, cacher *Cacher) { + storer.EXPECT().IsCached("messageID").Return(true) + cacher.newJob("messageID") + }) +} + +func TestCacherFail(t *testing.T) { + // If building the message fails, we should not try to cache it. + withTestCacher(t, func(storer *storemocks.MockStorer, cacher *Cacher) { + storer.EXPECT().IsCached("messageID").Return(false) + storer.EXPECT().BuildAndCacheMessage("messageID").Return(errors.New("failed to build message")) + cacher.newJob("messageID") + }) +} + +func TestCacherStop(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Mock storer used to build/cache messages. + storer := storemocks.NewMockStorer(ctrl) + + // Create a new cacher pointing to the fake store. + cacher := newCacher(storer) + + // Start the cacher. + cacher.start() + + // Send a job -- this should succeed. + storer.EXPECT().IsCached("messageID").Return(false) + storer.EXPECT().BuildAndCacheMessage("messageID").Return(nil) + cacher.newJob("messageID") + + // Stop the cacher. + cacher.stop() + + // Send more jobs -- these should all be dropped. + cacher.newJob("messageID2") + cacher.newJob("messageID3") + cacher.newJob("messageID4") + cacher.newJob("messageID5") + + // Stopping the cacher multiple times is safe. + cacher.stop() + cacher.stop() + cacher.stop() + cacher.stop() +} diff --git a/internal/store/change_test.go b/internal/store/change_test.go index c4866d58..7fba37bc 100644 --- a/internal/store/change_test.go +++ b/internal/store/change_test.go @@ -34,7 +34,7 @@ func TestNotifyChangeCreateOrUpdateMessage(t *testing.T) { m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(1), uint32(1), gomock.Any(), false) m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(2), uint32(2), gomock.Any(), false) - m.newStoreNoEvents(true) + m.newStoreNoEvents(t, true) m.store.SetChangeNotifier(m.changeNotifier) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) @@ -49,7 +49,7 @@ func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) { m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(1), uint32(1), gomock.Any(), false) m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(2), uint32(2), gomock.Any(), false) - m.newStoreNoEvents(true) + m.newStoreNoEvents(t, true) m.store.SetChangeNotifier(m.changeNotifier) msg1 := getTestMessage("msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) @@ -61,7 +61,7 @@ func TestNotifyChangeDeleteMessage(t *testing.T) { m, clear := initMocks(t) defer clear() - m.newStoreNoEvents(true) + m.newStoreNoEvents(t, true) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel}) diff --git a/internal/store/event_loop.go b/internal/store/event_loop.go index 50453c9a..acdc469f 100644 --- a/internal/store/event_loop.go +++ b/internal/store/event_loop.go @@ -38,7 +38,7 @@ const ( ) type eventLoop struct { - cache *Cache + currentEvents *Events currentEventID string currentEvent *pmapi.Event pollCh chan chan struct{} @@ -51,26 +51,26 @@ type eventLoop struct { log *logrus.Entry - store *Store - user BridgeUser - events listener.Listener + store *Store + user BridgeUser + listener listener.Listener } -func newEventLoop(cache *Cache, store *Store, user BridgeUser, events listener.Listener) *eventLoop { +func newEventLoop(currentEvents *Events, store *Store, user BridgeUser, listener listener.Listener) *eventLoop { eventLog := log.WithField("userID", user.ID()) eventLog.Trace("Creating new event loop") return &eventLoop{ - cache: cache, - currentEventID: cache.getEventID(user.ID()), + currentEvents: currentEvents, + currentEventID: currentEvents.getEventID(user.ID()), pollCh: make(chan chan struct{}), isRunning: false, log: eventLog, - store: store, - user: user, - events: events, + store: store, + user: user, + listener: listener, } } @@ -89,7 +89,7 @@ func (loop *eventLoop) setFirstEventID() (err error) { loop.currentEventID = event.EventID - if err = loop.cache.setEventID(loop.user.ID(), loop.currentEventID); err != nil { + if err = loop.currentEvents.setEventID(loop.user.ID(), loop.currentEventID); err != nil { loop.log.WithError(err).Error("Could not set latest event ID in user cache") return } @@ -229,7 +229,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun if err != nil && isFdCloseToULimit() { l.Warn("Ulimit reached") - loop.events.Emit(bridgeEvents.RestartBridgeEvent, "") + loop.listener.Emit(bridgeEvents.RestartBridgeEvent, "") err = nil } @@ -291,7 +291,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun // This allows the event loop to continue to function (unless the cache was broken // and bridge stopped, in which case it will start from the old event ID anyway). loop.currentEventID = event.EventID - if err = loop.cache.setEventID(loop.user.ID(), event.EventID); err != nil { + if err = loop.currentEvents.setEventID(loop.user.ID(), event.EventID); err != nil { return false, errors.Wrap(err, "failed to save event ID to cache") } } @@ -371,7 +371,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap switch addressEvent.Action { case pmapi.EventCreate: log.WithField("email", addressEvent.Address.Email).Debug("Address was created") - loop.events.Emit(bridgeEvents.AddressChangedEvent, loop.user.GetPrimaryAddress()) + loop.listener.Emit(bridgeEvents.AddressChangedEvent, loop.user.GetPrimaryAddress()) case pmapi.EventUpdate: oldAddress := oldList.ByID(addressEvent.ID) @@ -383,7 +383,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap email := oldAddress.Email log.WithField("email", email).Debug("Address was updated") if addressEvent.Address.Receive != oldAddress.Receive { - loop.events.Emit(bridgeEvents.AddressChangedLogoutEvent, email) + loop.listener.Emit(bridgeEvents.AddressChangedLogoutEvent, email) } case pmapi.EventDelete: @@ -396,7 +396,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap email := oldAddress.Email log.WithField("email", email).Debug("Address was deleted") loop.user.CloseConnection(email) - loop.events.Emit(bridgeEvents.AddressChangedLogoutEvent, email) + loop.listener.Emit(bridgeEvents.AddressChangedLogoutEvent, email) case pmapi.EventUpdateFlags: log.Error("EventUpdateFlags for address event is uknown operation") } diff --git a/internal/store/event_loop_test.go b/internal/store/event_loop_test.go index 845a03a6..3e0e23d1 100644 --- a/internal/store/event_loop_test.go +++ b/internal/store/event_loop_test.go @@ -53,7 +53,7 @@ func TestEventLoopProcessMoreEvents(t *testing.T) { More: false, }, nil), ) - m.newStoreNoEvents(true) + m.newStoreNoEvents(t, true) // Event loop runs in goroutine started during store creation (newStoreNoEvents). // Force to run the next event. @@ -78,7 +78,7 @@ func TestEventLoopUpdateMessageFromLoop(t *testing.T) { subject := "old subject" newSubject := "new subject" - m.newStoreNoEvents(true, &pmapi.Message{ + m.newStoreNoEvents(t, true, &pmapi.Message{ ID: "msg1", Subject: subject, }) @@ -106,7 +106,7 @@ func TestEventLoopDeletionNotPaused(t *testing.T) { m, clear := initMocks(t) defer clear() - m.newStoreNoEvents(true, &pmapi.Message{ + m.newStoreNoEvents(t, true, &pmapi.Message{ ID: "msg1", Subject: "subject", LabelIDs: []string{"label"}, @@ -133,7 +133,7 @@ func TestEventLoopDeletionPaused(t *testing.T) { m, clear := initMocks(t) defer clear() - m.newStoreNoEvents(true, &pmapi.Message{ + m.newStoreNoEvents(t, true, &pmapi.Message{ ID: "msg1", Subject: "subject", LabelIDs: []string{"label"}, diff --git a/internal/store/events.go b/internal/store/events.go new file mode 100644 index 00000000..0e3d6804 --- /dev/null +++ b/internal/store/events.go @@ -0,0 +1,116 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package store + +import ( + "encoding/json" + "os" + "sync" + + "github.com/pkg/errors" +) + +// Events caches the last event IDs for all accounts (there should be only one instance). +type Events struct { + // eventMap is map from userID => key (such as last event) => value (such as event ID). + eventMap map[string]map[string]string + path string + lock *sync.RWMutex +} + +// NewEvents constructs a new event cache at the given path. +func NewEvents(path string) *Events { + return &Events{ + path: path, + lock: &sync.RWMutex{}, + } +} + +func (c *Events) getEventID(userID string) string { + c.lock.Lock() + defer c.lock.Unlock() + + if err := c.loadEvents(); err != nil { + log.WithError(err).Warn("Problem to load store events") + } + + if c.eventMap == nil { + c.eventMap = map[string]map[string]string{} + } + if c.eventMap[userID] == nil { + c.eventMap[userID] = map[string]string{} + } + + return c.eventMap[userID]["events"] +} + +func (c *Events) setEventID(userID, eventID string) error { + c.lock.Lock() + defer c.lock.Unlock() + + if c.eventMap[userID] == nil { + c.eventMap[userID] = map[string]string{} + } + c.eventMap[userID]["events"] = eventID + + return c.saveEvents() +} + +func (c *Events) loadEvents() error { + if c.eventMap != nil { + return nil + } + + f, err := os.Open(c.path) + if err != nil { + return err + } + defer f.Close() //nolint[errcheck] + + return json.NewDecoder(f).Decode(&c.eventMap) +} + +func (c *Events) saveEvents() error { + if c.eventMap == nil { + return errors.New("events: cannot save events: events map is nil") + } + + f, err := os.Create(c.path) + if err != nil { + return err + } + defer f.Close() //nolint[errcheck] + + return json.NewEncoder(f).Encode(c.eventMap) +} + +func (c *Events) clearUserEvents(userID string) error { + c.lock.Lock() + defer c.lock.Unlock() + + if c.eventMap == nil { + log.WithField("user", userID).Warning("Cannot clear user events: event map is nil") + return nil + } + + log.WithField("user", userID).Trace("Removing user events from event loop") + + delete(c.eventMap, userID) + + return c.saveEvents() +} diff --git a/internal/store/mailbox_counts_test.go b/internal/store/mailbox_counts_test.go index 21f3c374..c3352823 100644 --- a/internal/store/mailbox_counts_test.go +++ b/internal/store/mailbox_counts_test.go @@ -107,7 +107,7 @@ func checkCounts(t testing.TB, wantCounts []*pmapi.MessagesCount, haveStore *Sto func TestMailboxCountRemove(t *testing.T) { m, clear := initMocks(t) defer clear() - m.newStoreNoEvents(true) + m.newStoreNoEvents(t, true) testCounts := []*pmapi.MessagesCount{ {LabelID: "label1", Total: 100, Unread: 0}, diff --git a/internal/store/mailbox_ids_test.go b/internal/store/mailbox_ids_test.go index 77cd76cb..1541dff4 100644 --- a/internal/store/mailbox_ids_test.go +++ b/internal/store/mailbox_ids_test.go @@ -35,7 +35,7 @@ func TestGetSequenceNumberAndGetUID(t *testing.T) { m, clear := initMocks(t) defer clear() - m.newStoreNoEvents(true) + m.newStoreNoEvents(t, true) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel}) @@ -80,7 +80,7 @@ func TestGetUIDByHeader(t *testing.T) { //nolint[funlen] m, clear := initMocks(t) defer clear() - m.newStoreNoEvents(true) + m.newStoreNoEvents(t, true) tstMsg := getTestMessage("msg1", "Without external ID", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel}) require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg)) diff --git a/internal/store/message.go b/internal/store/message.go index fe05609e..e12dbda6 100644 --- a/internal/store/message.go +++ b/internal/store/message.go @@ -67,40 +67,19 @@ func (message *Message) Message() *pmapi.Message { return message.msg } -// IsMarkedDeleted returns true if message is marked as deleted for specific -// mailbox. +// IsMarkedDeleted returns true if message is marked as deleted for specific mailbox. func (message *Message) IsMarkedDeleted() bool { - isMarkedAsDeleted := false - err := message.storeMailbox.db().View(func(tx *bolt.Tx) error { + var isMarkedAsDeleted bool + + if err := message.storeMailbox.db().View(func(tx *bolt.Tx) error { isMarkedAsDeleted = message.storeMailbox.txGetDeletedIDsBucket(tx).Get([]byte(message.msg.ID)) != nil return nil - }) - if err != nil { + }); err != nil { message.storeMailbox.log.WithError(err).Error("Not able to retrieve deleted mark, assuming false.") return false } - return isMarkedAsDeleted -} -// SetSize updates the information about size of decrypted message which can be -// used for IMAP. This should not trigger any IMAP update. -// NOTE: The size from the server corresponds to pure body bytes. Hence it -// should not be used. The correct size has to be calculated from decrypted and -// built message. -func (message *Message) SetSize(size int64) error { - message.msg.Size = size - txUpdate := func(tx *bolt.Tx) error { - stored, err := message.store.txGetMessage(tx, message.msg.ID) - if err != nil { - return err - } - stored.Size = size - return message.store.txPutMessage( - tx.Bucket(metadataBucket), - stored, - ) - } - return message.store.db.Update(txUpdate) + return isMarkedAsDeleted } // SetContentTypeAndHeader updates the information about content type and @@ -112,7 +91,7 @@ func (message *Message) SetSize(size int64) error { func (message *Message) SetContentTypeAndHeader(mimeType string, header mail.Header) error { message.msg.MIMEType = mimeType message.msg.Header = header - txUpdate := func(tx *bolt.Tx) error { + return message.store.db.Update(func(tx *bolt.Tx) error { stored, err := message.store.txGetMessage(tx, message.msg.ID) if err != nil { return err @@ -123,34 +102,26 @@ func (message *Message) SetContentTypeAndHeader(mimeType string, header mail.Hea tx.Bucket(metadataBucket), stored, ) - } - return message.store.db.Update(txUpdate) -} - -// SetHeader checks header can be parsed and if yes it stores header bytes in -// database. -func (message *Message) SetHeader(header []byte) error { - _, err := textproto.NewReader(bufio.NewReader(bytes.NewReader(header))).ReadMIMEHeader() - if err != nil { - return err - } - return message.store.db.Update(func(tx *bolt.Tx) error { - return tx.Bucket(headersBucket).Put([]byte(message.ID()), header) }) } // IsFullHeaderCached will check that valid full header is stored in DB. func (message *Message) IsFullHeaderCached() bool { - header, err := message.getRawHeader() - return err == nil && header != nil -} - -func (message *Message) getRawHeader() (raw []byte, err error) { - err = message.store.db.View(func(tx *bolt.Tx) error { - raw = tx.Bucket(headersBucket).Get([]byte(message.ID())) + var raw []byte + err := message.store.db.View(func(tx *bolt.Tx) error { + raw = tx.Bucket(bodystructureBucket).Get([]byte(message.ID())) return nil }) - return + return err == nil && raw != nil +} + +func (message *Message) getRawHeader() ([]byte, error) { + bs, err := message.GetBodyStructure() + if err != nil { + return nil, err + } + + return bs.GetMailHeaderBytes() } // GetHeader will return cached header from DB. @@ -178,44 +149,79 @@ func (message *Message) GetMIMEHeader() textproto.MIMEHeader { return header } -// SetBodyStructure stores serialized body structure in database. -func (message *Message) SetBodyStructure(bs *pkgMsg.BodyStructure) error { - txUpdate := func(tx *bolt.Tx) error { - return message.store.txPutBodyStructure( - tx.Bucket(bodystructureBucket), - message.ID(), bs, - ) - } - return message.store.db.Update(txUpdate) -} +// GetBodyStructure returns the message's body structure. +// It checks first if it's in the store. If it is, it returns it from store, +// otherwise it computes it from the message cache (and saves the result to the store). +func (message *Message) GetBodyStructure() (*pkgMsg.BodyStructure, error) { + var raw []byte -// GetBodyStructure deserializes body structure from database. If body structure -// is not in database it returns nil error and nil body structure. If error -// occurs it returns nil body structure. -func (message *Message) GetBodyStructure() (bs *pkgMsg.BodyStructure, err error) { - txRead := func(tx *bolt.Tx) error { - bs, err = message.store.txGetBodyStructure( - tx.Bucket(bodystructureBucket), - message.ID(), - ) - return err - } - if err = message.store.db.View(txRead); err != nil { + if err := message.store.db.View(func(tx *bolt.Tx) error { + raw = tx.Bucket(bodystructureBucket).Get([]byte(message.ID())) + return nil + }); err != nil { return nil, err } + + if len(raw) > 0 { + // If not possible to deserialize just continue with build. + if bs, err := pkgMsg.DeserializeBodyStructure(raw); err == nil { + return bs, nil + } + } + + literal, err := message.store.getCachedMessage(message.ID()) + if err != nil { + return nil, err + } + + bs, err := pkgMsg.NewBodyStructure(bytes.NewReader(literal)) + if err != nil { + return nil, err + } + + if raw, err = bs.Serialize(); err != nil { + return nil, err + } + + if err := message.store.db.Update(func(tx *bolt.Tx) error { + return tx.Bucket(bodystructureBucket).Put([]byte(message.ID()), raw) + }); err != nil { + return nil, err + } + return bs, nil } -func (message *Message) IncreaseBuildCount() (times uint32, err error) { - txUpdate := func(tx *bolt.Tx) error { - times, err = message.store.txIncreaseMsgBuildCount( - tx.Bucket(msgBuildCountBucket), - message.ID(), - ) - return err - } - if err = message.store.db.Update(txUpdate); err != nil { +// GetRFC822 returns the raw message literal. +func (message *Message) GetRFC822() ([]byte, error) { + return message.store.getCachedMessage(message.ID()) +} + +// GetRFC822Size returns the size of the raw message literal. +func (message *Message) GetRFC822Size() (uint32, error) { + var raw []byte + + if err := message.store.db.View(func(tx *bolt.Tx) error { + raw = tx.Bucket(sizeBucket).Get([]byte(message.ID())) + return nil + }); err != nil { return 0, err } - return times, nil + + if len(raw) > 0 { + return btoi(raw), nil + } + + literal, err := message.store.getCachedMessage(message.ID()) + if err != nil { + return 0, err + } + + if err := message.store.db.Update(func(tx *bolt.Tx) error { + return tx.Bucket(sizeBucket).Put([]byte(message.ID()), itob(uint32(len(literal)))) + }); err != nil { + return 0, err + } + + return uint32(len(literal)), nil } diff --git a/internal/store/mocks/mocks.go b/internal/store/mocks/mocks.go index 7adf0ed3..14fb5831 100644 --- a/internal/store/mocks/mocks.go +++ b/internal/store/mocks/mocks.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser,ChangeNotifier) +// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser,ChangeNotifier,Storer) // Package mocks is a generated GoMock package. package mocks @@ -318,3 +318,54 @@ func (mr *MockChangeNotifierMockRecorder) UpdateMessage(arg0, arg1, arg2, arg3, mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMessage", reflect.TypeOf((*MockChangeNotifier)(nil).UpdateMessage), arg0, arg1, arg2, arg3, arg4, arg5) } + +// MockStorer is a mock of Storer interface. +type MockStorer struct { + ctrl *gomock.Controller + recorder *MockStorerMockRecorder +} + +// MockStorerMockRecorder is the mock recorder for MockStorer. +type MockStorerMockRecorder struct { + mock *MockStorer +} + +// NewMockStorer creates a new mock instance. +func NewMockStorer(ctrl *gomock.Controller) *MockStorer { + mock := &MockStorer{ctrl: ctrl} + mock.recorder = &MockStorerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStorer) EXPECT() *MockStorerMockRecorder { + return m.recorder +} + +// BuildAndCacheMessage mocks base method. +func (m *MockStorer) BuildAndCacheMessage(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BuildAndCacheMessage", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// BuildAndCacheMessage indicates an expected call of BuildAndCacheMessage. +func (mr *MockStorerMockRecorder) BuildAndCacheMessage(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BuildAndCacheMessage", reflect.TypeOf((*MockStorer)(nil).BuildAndCacheMessage), arg0) +} + +// IsCached mocks base method. +func (m *MockStorer) IsCached(arg0 string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsCached", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsCached indicates an expected call of IsCached. +func (mr *MockStorerMockRecorder) IsCached(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsCached", reflect.TypeOf((*MockStorer)(nil).IsCached), arg0) +} diff --git a/internal/store/store.go b/internal/store/store.go index ac960e1f..03de5d0f 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -26,8 +26,11 @@ import ( "time" "github.com/ProtonMail/proton-bridge/internal/sentry" + "github.com/ProtonMail/proton-bridge/internal/store/cache" "github.com/ProtonMail/proton-bridge/pkg/listener" + "github.com/ProtonMail/proton-bridge/pkg/message" "github.com/ProtonMail/proton-bridge/pkg/pmapi" + "github.com/ProtonMail/proton-bridge/pkg/pool" "github.com/hashicorp/go-multierror" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -52,19 +55,21 @@ var ( // Database structure: // * metadata - // * {messageID} -> message data (subject, from, to, time, body size, ...) + // * {messageID} -> message data (subject, from, to, time, ...) // * headers // * {messageID} -> header bytes // * bodystructure // * {messageID} -> message body structure - // * msgbuildcount - // * {messageID} -> uint32 number of message builds to track re-sync issues + // * size + // * {messageID} -> uint32 value // * counts // * {mailboxID} -> mailboxCounts: totalOnAPI, unreadOnAPI, labelName, labelColor, labelIsExclusive // * address_info // * {index} -> {address, addressID} // * address_mode // * mode -> string split or combined + // * cache_passphrase + // * passphrase -> cache passphrase (pgp encrypted message) // * mailboxes_version // * version -> uint32 value // * sync_state @@ -79,19 +84,20 @@ var ( // * {messageID} -> uint32 imapUID // * deleted_ids (can be missing or have no keys) // * {messageID} -> true - metadataBucket = []byte("metadata") //nolint[gochecknoglobals] - headersBucket = []byte("headers") //nolint[gochecknoglobals] - bodystructureBucket = []byte("bodystructure") //nolint[gochecknoglobals] - msgBuildCountBucket = []byte("msgbuildcount") //nolint[gochecknoglobals] - countsBucket = []byte("counts") //nolint[gochecknoglobals] - addressInfoBucket = []byte("address_info") //nolint[gochecknoglobals] - addressModeBucket = []byte("address_mode") //nolint[gochecknoglobals] - syncStateBucket = []byte("sync_state") //nolint[gochecknoglobals] - mailboxesBucket = []byte("mailboxes") //nolint[gochecknoglobals] - imapIDsBucket = []byte("imap_ids") //nolint[gochecknoglobals] - apiIDsBucket = []byte("api_ids") //nolint[gochecknoglobals] - deletedIDsBucket = []byte("deleted_ids") //nolint[gochecknoglobals] - mboxVersionBucket = []byte("mailboxes_version") //nolint[gochecknoglobals] + metadataBucket = []byte("metadata") //nolint[gochecknoglobals] + headersBucket = []byte("headers") //nolint[gochecknoglobals] + bodystructureBucket = []byte("bodystructure") //nolint[gochecknoglobals] + sizeBucket = []byte("size") //nolint[gochecknoglobals] + countsBucket = []byte("counts") //nolint[gochecknoglobals] + addressInfoBucket = []byte("address_info") //nolint[gochecknoglobals] + addressModeBucket = []byte("address_mode") //nolint[gochecknoglobals] + cachePassphraseBucket = []byte("cache_passphrase") //nolint[gochecknoglobals] + syncStateBucket = []byte("sync_state") //nolint[gochecknoglobals] + mailboxesBucket = []byte("mailboxes") //nolint[gochecknoglobals] + imapIDsBucket = []byte("imap_ids") //nolint[gochecknoglobals] + apiIDsBucket = []byte("api_ids") //nolint[gochecknoglobals] + deletedIDsBucket = []byte("deleted_ids") //nolint[gochecknoglobals] + mboxVersionBucket = []byte("mailboxes_version") //nolint[gochecknoglobals] // ErrNoSuchAPIID when mailbox does not have API ID. ErrNoSuchAPIID = errors.New("no such api id") //nolint[gochecknoglobals] @@ -117,18 +123,23 @@ func exposeContextForSMTP() context.Context { type Store struct { sentryReporter *sentry.Reporter panicHandler PanicHandler - eventLoop *eventLoop user BridgeUser + eventLoop *eventLoop + currentEvents *Events log *logrus.Entry - cache *Cache filePath string db *bolt.DB lock *sync.RWMutex addresses map[string]*Address notifier ChangeNotifier + builder *message.Builder + cache cache.Cache + cacher *Cacher + done chan struct{} + isSyncRunning bool syncCooldown cooldown addressMode addressMode @@ -139,12 +150,14 @@ func New( // nolint[funlen] sentryReporter *sentry.Reporter, panicHandler PanicHandler, user BridgeUser, - events listener.Listener, + listener listener.Listener, + cache cache.Cache, + builder *message.Builder, path string, - cache *Cache, + currentEvents *Events, ) (store *Store, err error) { - if user == nil || events == nil || cache == nil { - return nil, fmt.Errorf("missing parameters - user: %v, events: %v, cache: %v", user, events, cache) + if user == nil || listener == nil || currentEvents == nil { + return nil, fmt.Errorf("missing parameters - user: %v, listener: %v, currentEvents: %v", user, listener, currentEvents) } l := log.WithField("user", user.ID()) @@ -160,21 +173,29 @@ func New( // nolint[funlen] bdb, err := openBoltDatabase(path) if err != nil { - err = errors.Wrap(err, "failed to open store database") - return + return nil, errors.Wrap(err, "failed to open store database") } store = &Store{ sentryReporter: sentryReporter, panicHandler: panicHandler, user: user, - cache: cache, - filePath: path, - db: bdb, - lock: &sync.RWMutex{}, - log: l, + currentEvents: currentEvents, + + log: l, + + filePath: path, + db: bdb, + lock: &sync.RWMutex{}, + + builder: builder, + cache: cache, } + // Create a new cacher. It's not started yet. + // NOTE(GODT-1158): I hate this circular dependency store->cacher->store :( + store.cacher = newCacher(store) + // Minimal increase is event pollInterval, doubles every failed retry up to 5 minutes. store.syncCooldown.setExponentialWait(pollInterval, 2, 5*time.Minute) @@ -188,7 +209,7 @@ func New( // nolint[funlen] } if user.IsConnected() { - store.eventLoop = newEventLoop(cache, store, user, events) + store.eventLoop = newEventLoop(currentEvents, store, user, listener) go func() { defer store.panicHandler.HandlePanic() store.eventLoop.start() @@ -216,10 +237,11 @@ func openBoltDatabase(filePath string) (db *bolt.DB, err error) { metadataBucket, headersBucket, bodystructureBucket, - msgBuildCountBucket, + sizeBucket, countsBucket, addressInfoBucket, addressModeBucket, + cachePassphraseBucket, syncStateBucket, mailboxesBucket, mboxVersionBucket, @@ -365,6 +387,24 @@ func (store *Store) addAddress(address, addressID string, labels []*pmapi.Label) return } +// newBuildJob returns a new build job for the given message using the store's message builder. +func (store *Store) newBuildJob(messageID string, priority int) (*message.Job, pool.DoneFunc) { + return store.builder.NewJobWithOptions( + context.Background(), + store.client(), + messageID, + message.JobOptions{ + IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead. + SanitizeDate: true, // Whether to replace all dates before 1970 with RFC822's birthdate. + AddInternalID: true, // Whether to include MessageID as X-Pm-Internal-Id. + AddExternalID: true, // Whether to include ExternalID as X-Pm-External-Id. + AddMessageDate: true, // Whether to include message time as X-Pm-Date. + AddMessageIDReference: true, // Whether to include the MessageID in References. + }, + priority, + ) +} + // Close stops the event loop and closes the database to free the file. func (store *Store) Close() error { store.lock.Lock() @@ -381,12 +421,21 @@ func (store *Store) CloseEventLoop() { } func (store *Store) close() error { + // Stop the watcher first before closing the database. + store.stopWatcher() + + // Stop the cacher. + store.cacher.stop() + + // Stop the event loop. store.CloseEventLoop() + + // Close the database. return store.db.Close() } // Remove closes and removes the database file and clears the cache file. -func (store *Store) Remove() (err error) { +func (store *Store) Remove() error { store.lock.Lock() defer store.lock.Unlock() @@ -394,22 +443,34 @@ func (store *Store) Remove() (err error) { var result *multierror.Error - if err = store.close(); err != nil { + if err := store.close(); err != nil { result = multierror.Append(result, errors.Wrap(err, "failed to close store")) } - if err = RemoveStore(store.cache, store.filePath, store.user.ID()); err != nil { + if err := RemoveStore(store.currentEvents, store.filePath, store.user.ID()); err != nil { result = multierror.Append(result, errors.Wrap(err, "failed to remove store")) } + if err := store.RemoveCache(); err != nil { + result = multierror.Append(result, errors.Wrap(err, "failed to remove cache")) + } + return result.ErrorOrNil() } +func (store *Store) RemoveCache() error { + if err := store.clearCachePassphrase(); err != nil { + logrus.WithError(err).Error("Failed to clear cache passphrase") + } + + return store.cache.Delete(store.user.ID()) +} + // RemoveStore removes the database file and clears the cache file. -func RemoveStore(cache *Cache, path, userID string) error { +func RemoveStore(currentEvents *Events, path, userID string) error { var result *multierror.Error - if err := cache.clearCacheUser(userID); err != nil { + if err := currentEvents.clearUserEvents(userID); err != nil { result = multierror.Append(result, errors.Wrap(err, "failed to clear event loop user cache")) } diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 029e26a3..f594817f 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -23,13 +23,17 @@ import ( "io/ioutil" "os" "path/filepath" + "runtime" "testing" "time" "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/ProtonMail/proton-bridge/internal/store/cache" storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks" + "github.com/ProtonMail/proton-bridge/pkg/message" "github.com/ProtonMail/proton-bridge/pkg/pmapi" pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks" + tests "github.com/ProtonMail/proton-bridge/test" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" @@ -139,7 +143,7 @@ type mocksForStore struct { store *Store tmpDir string - cache *Cache + cache *Events } func initMocks(tb testing.TB) (*mocksForStore, func()) { @@ -162,7 +166,7 @@ func initMocks(tb testing.TB) (*mocksForStore, func()) { require.NoError(tb, err) cacheFile := filepath.Join(mocks.tmpDir, "cache.json") - mocks.cache = NewCache(cacheFile) + mocks.cache = NewEvents(cacheFile) return mocks, func() { if err := recover(); err != nil { @@ -176,13 +180,14 @@ func initMocks(tb testing.TB) (*mocksForStore, func()) { } } -func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.Message) { //nolint[unparam] +func (mocks *mocksForStore) newStoreNoEvents(t *testing.T, combinedMode bool, msgs ...*pmapi.Message) { //nolint[unparam] mocks.user.EXPECT().ID().Return("userID").AnyTimes() mocks.user.EXPECT().IsConnected().Return(true) mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode) mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client) + mocks.client.EXPECT().GetUserKeyRing().Return(tests.MakeKeyRing(t), nil).AnyTimes() mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{ {ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: true}, {ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: true}, @@ -213,6 +218,8 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M mocks.panicHandler, mocks.user, mocks.events, + cache.NewInMemoryCache(1<<20), + message.NewBuilder(runtime.NumCPU(), runtime.NumCPU()), filepath.Join(mocks.tmpDir, "mailbox-test.db"), mocks.cache, ) diff --git a/internal/store/user_message.go b/internal/store/user_message.go index 47470109..9b06b890 100644 --- a/internal/store/user_message.go +++ b/internal/store/user_message.go @@ -27,7 +27,6 @@ import ( "strings" "github.com/ProtonMail/gopenpgp/v2/crypto" - pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message" "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -154,11 +153,6 @@ func (store *Store) checkDraftTotalSize(message *pmapi.Message, attachments []*d return false, err } - msgSize := message.Size - if msgSize == 0 { - msgSize = int64(len(message.Body)) - } - var attSize int64 for _, att := range attachments { b, err := ioutil.ReadAll(att.encReader) @@ -169,7 +163,7 @@ func (store *Store) checkDraftTotalSize(message *pmapi.Message, attachments []*d att.encReader = bytes.NewBuffer(b) } - return msgSize+attSize <= maxUpload, nil + return int64(len(message.Body))+attSize <= maxUpload, nil } func (store *Store) getDraftAction(message *pmapi.Message) int { @@ -237,39 +231,6 @@ func (store *Store) txPutMessage(metaBucket *bolt.Bucket, onlyMeta *pmapi.Messag return nil } -func (store *Store) txPutBodyStructure(bsBucket *bolt.Bucket, msgID string, bs *pkgMsg.BodyStructure) error { - raw, err := bs.Serialize() - if err != nil { - return err - } - err = bsBucket.Put([]byte(msgID), raw) - if err != nil { - return errors.Wrap(err, "cannot put bodystructure bucket") - } - return nil -} - -func (store *Store) txGetBodyStructure(bsBucket *bolt.Bucket, msgID string) (*pkgMsg.BodyStructure, error) { - raw := bsBucket.Get([]byte(msgID)) - if len(raw) == 0 { - return nil, nil - } - return pkgMsg.DeserializeBodyStructure(raw) -} - -func (store *Store) txIncreaseMsgBuildCount(b *bolt.Bucket, msgID string) (uint32, error) { - key := []byte(msgID) - count := uint32(0) - - raw := b.Get(key) - if raw != nil { - count = btoi(raw) - } - - count++ - return count, b.Put(key, itob(count)) -} - // createOrUpdateMessageEvent is helper to create only one message with // createOrUpdateMessagesEvent. func (store *Store) createOrUpdateMessageEvent(msg *pmapi.Message) error { @@ -287,7 +248,7 @@ func (store *Store) createOrUpdateMessagesEvent(msgs []*pmapi.Message) error { / b := tx.Bucket(metadataBucket) for _, msg := range msgs { clearNonMetadata(msg) - txUpdateMetadaFromDB(b, msg, store.log) + txUpdateMetadataFromDB(b, msg, store.log) } return nil }) @@ -341,6 +302,11 @@ func (store *Store) createOrUpdateMessagesEvent(msgs []*pmapi.Message) error { / return err } + // Notify the cacher that it should start caching messages. + for _, msg := range msgs { + store.cacher.newJob(msg.ID) + } + return nil } @@ -351,16 +317,12 @@ func clearNonMetadata(onlyMeta *pmapi.Message) { onlyMeta.Attachments = nil } -// txUpdateMetadaFromDB changes the the onlyMeta data. +// txUpdateMetadataFromDB changes the the onlyMeta data. // If there is stored message in metaBucket the size, header and MIMEType are // not changed if already set. To change these: // * size must be updated by Message.SetSize // * contentType and header must be updated by Message.SetContentTypeAndHeader. -func txUpdateMetadaFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log *logrus.Entry) { - // Size attribute on the server is counting encrypted data. We need to compute - // "real" size of decrypted data. Negative values will be processed during fetch. - onlyMeta.Size = -1 - +func txUpdateMetadataFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log *logrus.Entry) { msgb := metaBucket.Get([]byte(onlyMeta.ID)) if msgb == nil { return @@ -378,8 +340,7 @@ func txUpdateMetadaFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log return } - // Keep already calculated size and content type. - onlyMeta.Size = stored.Size + // Keep content type. onlyMeta.MIMEType = stored.MIMEType if stored.Header != "" && stored.Header != "(No Header)" { tmpMsg, err := mail.ReadMessage( @@ -401,6 +362,12 @@ func (store *Store) deleteMessageEvent(apiID string) error { // deleteMessagesEvent deletes the message from metadata and all mailbox buckets. func (store *Store) deleteMessagesEvent(apiIDs []string) error { + for _, messageID := range apiIDs { + if err := store.cache.Rem(store.UserID(), messageID); err != nil { + logrus.WithError(err).Error("Failed to remove message from cache") + } + } + return store.db.Update(func(tx *bolt.Tx) error { for _, apiID := range apiIDs { if err := tx.Bucket(metadataBucket).Delete([]byte(apiID)); err != nil { diff --git a/internal/store/user_message_test.go b/internal/store/user_message_test.go index bc606723..ca196177 100644 --- a/internal/store/user_message_test.go +++ b/internal/store/user_message_test.go @@ -33,7 +33,7 @@ func TestGetAllMessageIDs(t *testing.T) { m, clear := initMocks(t) defer clear() - m.newStoreNoEvents(true) + m.newStoreNoEvents(t, true) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel}) @@ -47,7 +47,7 @@ func TestGetMessageFromDB(t *testing.T) { m, clear := initMocks(t) defer clear() - m.newStoreNoEvents(true) + m.newStoreNoEvents(t, true) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) tests := []struct{ msgID, wantErr string }{ @@ -72,7 +72,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) { m, clear := initMocks(t) defer clear() - m.newStoreNoEvents(true) + m.newStoreNoEvents(t, true) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) msg, err := m.store.getMessageFromDB("msg1") @@ -81,12 +81,10 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) { // Check non-meta and calculated data are cleared/empty. a.Equal(t, "", msg.Body) a.Equal(t, []*pmapi.Attachment(nil), msg.Attachments) - a.Equal(t, int64(-1), msg.Size) a.Equal(t, "", msg.MIMEType) a.Equal(t, make(mail.Header), msg.Header) // Change the calculated data. - wantSize := int64(42) wantMIMEType := "plain-text" wantHeader := mail.Header{ "Key": []string{"value"}, @@ -94,13 +92,11 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) { storeMsg, err := m.store.addresses[addrID1].mailboxes[pmapi.AllMailLabel].GetMessage("msg1") require.Nil(t, err) - require.Nil(t, storeMsg.SetSize(wantSize)) require.Nil(t, storeMsg.SetContentTypeAndHeader(wantMIMEType, wantHeader)) // Check calculated data. msg, err = m.store.getMessageFromDB("msg1") require.Nil(t, err) - a.Equal(t, wantSize, msg.Size) a.Equal(t, wantMIMEType, msg.MIMEType) a.Equal(t, wantHeader, msg.Header) @@ -109,7 +105,6 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) { msg, err = m.store.getMessageFromDB("msg1") require.Nil(t, err) - a.Equal(t, wantSize, msg.Size) a.Equal(t, wantMIMEType, msg.MIMEType) a.Equal(t, wantHeader, msg.Header) } @@ -118,7 +113,7 @@ func TestDeleteMessage(t *testing.T) { m, clear := initMocks(t) defer clear() - m.newStoreNoEvents(true) + m.newStoreNoEvents(t, true) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel}) @@ -129,8 +124,7 @@ func TestDeleteMessage(t *testing.T) { } func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread bool, labelIDs []string) { //nolint[unparam] - msg := getTestMessage(id, subject, sender, unread, labelIDs) - require.Nil(t, m.store.createOrUpdateMessageEvent(msg)) + require.Nil(t, m.store.createOrUpdateMessageEvent(getTestMessage(id, subject, sender, unread, labelIDs))) } func getTestMessage(id, subject, sender string, unread bool, labelIDs []string) *pmapi.Message { @@ -142,7 +136,6 @@ func getTestMessage(id, subject, sender string, unread bool, labelIDs []string) Sender: address, ToList: []*mail.Address{address}, LabelIDs: labelIDs, - Size: 12345, Body: "body of message", Attachments: []*pmapi.Attachment{{ ID: "attachment1", @@ -162,7 +155,7 @@ func TestCreateDraftCheckMessageSize(t *testing.T) { m, clear := initMocks(t) defer clear() - m.newStoreNoEvents(false) + m.newStoreNoEvents(t, false) m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{ MaxUpload: 100, // Decrypted message 5 chars, encrypted 500+. }, nil) @@ -181,7 +174,7 @@ func TestCreateDraftCheckMessageWithAttachmentSize(t *testing.T) { m, clear := initMocks(t) defer clear() - m.newStoreNoEvents(false) + m.newStoreNoEvents(t, false) m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{ MaxUpload: 800, // Decrypted message 5 chars + 5 chars of attachment, encrypted 500+ + 300+. }, nil) diff --git a/internal/store/user_sync_test.go b/internal/store/user_sync_test.go index 688f7a9a..7b8da1bc 100644 --- a/internal/store/user_sync_test.go +++ b/internal/store/user_sync_test.go @@ -30,7 +30,7 @@ func TestLoadSaveSyncState(t *testing.T) { m, clear := initMocks(t) defer clear() - m.newStoreNoEvents(true) + m.newStoreNoEvents(t, true) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel}) diff --git a/internal/users/user.go b/internal/users/user.go index 8dd6921c..b3f60932 100644 --- a/internal/users/user.go +++ b/internal/users/user.go @@ -107,6 +107,21 @@ func (u *User) connect(client pmapi.Client, creds *credentials.Credentials) erro return err } + // If the client is already unlocked, we can unlock the store cache as well. + if client.IsUnlocked() { + kr, err := client.GetUserKeyRing() + if err != nil { + return err + } + + if err := u.store.UnlockCache(kr); err != nil { + return err + } + + // NOTE(GODT-1158): If using in-memory cache we probably shouldn't start the watcher? + u.store.StartWatcher() + } + return nil } diff --git a/internal/users/user_credentials_test.go b/internal/users/user_credentials_test.go index accee921..df994fed 100644 --- a/internal/users/user_credentials_test.go +++ b/internal/users/user_credentials_test.go @@ -32,7 +32,7 @@ func TestUpdateUser(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - user := testNewUser(m) + user := testNewUser(t, m) defer cleanUpUserData(user) gomock.InOrder( @@ -50,7 +50,7 @@ func TestUserSwitchAddressMode(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - user := testNewUser(m) + user := testNewUser(t, m) defer cleanUpUserData(user) // Ignore any sync on background. @@ -76,7 +76,7 @@ func TestUserSwitchAddressMode(t *testing.T) { r.False(t, user.creds.IsCombinedAddressMode) r.False(t, user.IsCombinedAddressMode()) - // MOck change to combined mode. + // Mock change to combined mode. gomock.InOrder( m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"), m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"), @@ -98,7 +98,7 @@ func TestLogoutUser(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - user := testNewUser(m) + user := testNewUser(t, m) defer cleanUpUserData(user) gomock.InOrder( @@ -115,7 +115,7 @@ func TestLogoutUserFailsLogout(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - user := testNewUser(m) + user := testNewUser(t, m) defer cleanUpUserData(user) gomock.InOrder( @@ -133,7 +133,7 @@ func TestCheckBridgeLogin(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - user := testNewUser(m) + user := testNewUser(t, m) defer cleanUpUserData(user) err := user.CheckBridgeLogin(testCredentials.BridgePassword) @@ -144,7 +144,7 @@ func TestCheckBridgeLoginUpgradeApplication(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - user := testNewUser(m) + user := testNewUser(t, m) defer cleanUpUserData(user) m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "") @@ -187,7 +187,7 @@ func TestCheckBridgeLoginBadPassword(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - user := testNewUser(m) + user := testNewUser(t, m) defer cleanUpUserData(user) err := user.CheckBridgeLogin("wrong!") diff --git a/internal/users/user_new_test.go b/internal/users/user_new_test.go index f3810c06..b3b4045b 100644 --- a/internal/users/user_new_test.go +++ b/internal/users/user_new_test.go @@ -64,7 +64,7 @@ func TestNewUser(t *testing.T) { defer m.ctrl.Finish() m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) - mockInitConnectedUser(m) + mockInitConnectedUser(t, m) mockEventLoopNoAction(m) checkNewUserHasCredentials(m, "", testCredentials) diff --git a/internal/users/user_store_test.go b/internal/users/user_store_test.go index 87ee0c8e..f333dcd8 100644 --- a/internal/users/user_store_test.go +++ b/internal/users/user_store_test.go @@ -31,7 +31,7 @@ func TestClearStoreWithStore(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - user := testNewUser(m) + user := testNewUser(t, m) defer cleanUpUserData(user) r.Nil(t, user.store.Close()) @@ -43,7 +43,7 @@ func TestClearStoreWithoutStore(t *testing.T) { m := initMocks(t) defer m.ctrl.Finish() - user := testNewUser(m) + user := testNewUser(t, m) defer cleanUpUserData(user) r.NotNil(t, user.store) diff --git a/internal/users/user_test.go b/internal/users/user_test.go index 9ced4b94..7780a1b1 100644 --- a/internal/users/user_test.go +++ b/internal/users/user_test.go @@ -18,13 +18,15 @@ package users import ( + "testing" + r "github.com/stretchr/testify/require" ) // testNewUser sets up a new, authorised user. -func testNewUser(m mocks) *User { +func testNewUser(t *testing.T, m mocks) *User { m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil) - mockInitConnectedUser(m) + mockInitConnectedUser(t, m) mockEventLoopNoAction(m) user, creds, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker) diff --git a/internal/users/users.go b/internal/users/users.go index 87702a1a..01056515 100644 --- a/internal/users/users.go +++ b/internal/users/users.go @@ -20,12 +20,13 @@ package users import ( "context" + "os" + "path/filepath" "strings" "sync" "time" "github.com/ProtonMail/proton-bridge/internal/events" - imapcache "github.com/ProtonMail/proton-bridge/internal/imap/cache" "github.com/ProtonMail/proton-bridge/internal/metrics" "github.com/ProtonMail/proton-bridge/internal/users/credentials" "github.com/ProtonMail/proton-bridge/pkg/listener" @@ -225,6 +226,7 @@ func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password []by return nil, errors.Wrap(err, "failed to update password of user in credentials store") } + // will go and unlock cache if not already done if err := user.connect(client, creds); err != nil { return nil, errors.Wrap(err, "failed to reconnect existing user") } @@ -341,9 +343,6 @@ func (u *Users) ClearData() error { result = multierror.Append(result, err) } - // Need to clear imap cache otherwise fetch response will be remembered from previous test. - imapcache.Clear() - return result } @@ -366,6 +365,7 @@ func (u *Users) DeleteUser(userID string, clearStore bool) error { if err := user.closeStore(); err != nil { log.WithError(err).Error("Failed to close user store") } + if clearStore { // Clear cache after closing connections (done in logout). if err := user.clearStore(); err != nil { @@ -427,6 +427,41 @@ func (u *Users) DisallowProxy() { u.clientManager.DisallowProxy() } +func (u *Users) EnableCache() error { + // NOTE(GODT-1158): Check for available size before enabling. + + return nil +} + +func (u *Users) MigrateCache(from, to string) error { + // NOTE(GODT-1158): Is it enough to just close the store? Do we need to force-close the cacher too? + + for _, user := range u.users { + if err := user.closeStore(); err != nil { + logrus.WithError(err).Error("Failed to close user's store") + } + } + + // Ensure the parent directory exists. + if err := os.MkdirAll(filepath.Dir(to), 0700); err != nil { + return err + } + + return os.Rename(from, to) +} + +func (u *Users) DisableCache() error { + // NOTE(GODT-1158): Is it an error if we can't remove a user's cache? + + for _, user := range u.users { + if err := user.store.RemoveCache(); err != nil { + logrus.WithError(err).Error("Failed to remove user's message cache") + } + } + + return nil +} + // hasUser returns whether the struct currently has a user with ID `id`. func (u *Users) hasUser(id string) (user *User, ok bool) { for _, u := range u.users { diff --git a/internal/users/users_login_test.go b/internal/users/users_login_test.go index 8c24a9c6..1e321ff5 100644 --- a/internal/users/users_login_test.go +++ b/internal/users/users_login_test.go @@ -49,7 +49,7 @@ func TestUsersFinishLoginNewUser(t *testing.T) { // Init users with no user from keychain. m.credentialsStore.EXPECT().List().Return([]string{}, nil) - mockAddingConnectedUser(m) + mockAddingConnectedUser(t, m) mockEventLoopNoAction(m) m.clientManager.EXPECT().SendSimpleMetric(gomock.Any(), string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel)) @@ -74,7 +74,7 @@ func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) { m.credentialsStore.EXPECT().UpdateToken(testCredentialsDisconnected.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil), m.credentialsStore.EXPECT().UpdatePassword(testCredentialsDisconnected.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil), ) - mockInitConnectedUser(m) + mockInitConnectedUser(t, m) mockEventLoopNoAction(m) m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentialsDisconnected.UserID) @@ -95,7 +95,7 @@ func TestUsersFinishLoginConnectedUser(t *testing.T) { // Mock loading connected user. m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil) - mockLoadingConnectedUser(m, testCredentials) + mockLoadingConnectedUser(t, m, testCredentials) mockEventLoopNoAction(m) // Mock process of FinishLogin of already connected user. diff --git a/internal/users/users_new_test.go b/internal/users/users_new_test.go index 00a7b886..c00c42b1 100644 --- a/internal/users/users_new_test.go +++ b/internal/users/users_new_test.go @@ -49,7 +49,7 @@ func TestNewUsersWithConnectedUser(t *testing.T) { defer m.ctrl.Finish() m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil) - mockLoadingConnectedUser(m, testCredentials) + mockLoadingConnectedUser(t, m, testCredentials) mockEventLoopNoAction(m) checkUsersNew(t, m, []*credentials.Credentials{testCredentials}) } @@ -71,7 +71,7 @@ func TestNewUsersWithUsers(t *testing.T) { m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID, testCredentials.UserID}, nil) mockLoadingDisconnectedUser(m, testCredentialsDisconnected) - mockLoadingConnectedUser(m, testCredentials) + mockLoadingConnectedUser(t, m, testCredentials) mockEventLoopNoAction(m) checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials}) } diff --git a/internal/users/users_test.go b/internal/users/users_test.go index 05077ec7..87b784d5 100644 --- a/internal/users/users_test.go +++ b/internal/users/users_test.go @@ -21,6 +21,7 @@ import ( "fmt" "io/ioutil" "os" + "runtime" "runtime/debug" "testing" "time" @@ -28,10 +29,13 @@ import ( "github.com/ProtonMail/proton-bridge/internal/events" "github.com/ProtonMail/proton-bridge/internal/sentry" "github.com/ProtonMail/proton-bridge/internal/store" + "github.com/ProtonMail/proton-bridge/internal/store/cache" "github.com/ProtonMail/proton-bridge/internal/users/credentials" usersmocks "github.com/ProtonMail/proton-bridge/internal/users/mocks" + "github.com/ProtonMail/proton-bridge/pkg/message" "github.com/ProtonMail/proton-bridge/pkg/pmapi" pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks" + tests "github.com/ProtonMail/proton-bridge/test" gomock "github.com/golang/mock/gomock" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -42,9 +46,11 @@ func TestMain(m *testing.M) { if os.Getenv("VERBOSITY") == "fatal" { logrus.SetLevel(logrus.FatalLevel) } + if os.Getenv("VERBOSITY") == "trace" { logrus.SetLevel(logrus.TraceLevel) } + os.Exit(m.Run()) } @@ -151,7 +157,7 @@ type mocks struct { clientManager *pmapimocks.MockManager pmapiClient *pmapimocks.MockClient - storeCache *store.Cache + storeCache *store.Events } func initMocks(t *testing.T) mocks { @@ -178,7 +184,7 @@ func initMocks(t *testing.T) mocks { clientManager: pmapimocks.NewMockManager(mockCtrl), pmapiClient: pmapimocks.NewMockClient(mockCtrl), - storeCache: store.NewCache(cacheFile.Name()), + storeCache: store.NewEvents(cacheFile.Name()), } // Called during clean-up. @@ -187,9 +193,20 @@ func initMocks(t *testing.T) mocks { // Set up store factory. m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) { var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests. - dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db") + + dbFile, err := ioutil.TempFile(t.TempDir(), "bridge-store-db-*.db") r.NoError(t, err, "could not get temporary file for store db") - return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache) + + return store.New( + sentryReporter, + m.PanicHandler, + user, + m.eventListener, + cache.NewInMemoryCache(1<<20), + message.NewBuilder(runtime.NumCPU(), runtime.NumCPU()), + dbFile.Name(), + m.storeCache, + ) }).AnyTimes() m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes() @@ -212,8 +229,8 @@ func (fr *fullStackReporter) Fatalf(format string, args ...interface{}) { func testNewUsersWithUsers(t *testing.T, m mocks) *Users { m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID, testCredentialsSplit.UserID}, nil) - mockLoadingConnectedUser(m, testCredentials) - mockLoadingConnectedUser(m, testCredentialsSplit) + mockLoadingConnectedUser(t, m, testCredentials) + mockLoadingConnectedUser(t, m, testCredentialsSplit) mockEventLoopNoAction(m) return testNewUsers(t, m) @@ -245,7 +262,7 @@ func cleanUpUsersData(b *Users) { } } -func mockAddingConnectedUser(m mocks) { +func mockAddingConnectedUser(t *testing.T, m mocks) { gomock.InOrder( // Mock of users.FinishLogin. m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil), @@ -256,10 +273,10 @@ func mockAddingConnectedUser(m mocks) { m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil), ) - mockInitConnectedUser(m) + mockInitConnectedUser(t, m) } -func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) { +func mockLoadingConnectedUser(t *testing.T, m mocks, creds *credentials.Credentials) { authRefresh := &pmapi.AuthRefresh{ UID: "uid", AccessToken: "acc", @@ -273,10 +290,10 @@ func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) { m.credentialsStore.EXPECT().UpdateToken(creds.UserID, authRefresh.UID, authRefresh.RefreshToken).Return(creds, nil), ) - mockInitConnectedUser(m) + mockInitConnectedUser(t, m) } -func mockInitConnectedUser(m mocks) { +func mockInitConnectedUser(t *testing.T, m mocks) { // Mock of user initialisation. m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any()) m.pmapiClient.EXPECT().IsUnlocked().Return(true).AnyTimes() @@ -286,6 +303,7 @@ func mockInitConnectedUser(m mocks) { m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil), m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil), m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}), + m.pmapiClient.EXPECT().GetUserKeyRing().Return(tests.MakeKeyRing(t), nil).AnyTimes(), ) } diff --git a/pkg/message/build.go b/pkg/message/build.go index 643bfb39..8d20fd7c 100644 --- a/pkg/message/build.go +++ b/pkg/message/build.go @@ -20,10 +20,12 @@ package message import ( "context" "io" + "io/ioutil" "sync" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/proton-bridge/pkg/pmapi" + "github.com/ProtonMail/proton-bridge/pkg/pool" "github.com/pkg/errors" ) @@ -32,11 +34,15 @@ var ( ErrNoSuchKeyRing = errors.New("the keyring to decrypt this message could not be found") ) +const ( + BackgroundPriority = 1 << iota + ForegroundPriority +) + type Builder struct { - reqs chan fetchReq - done chan struct{} - jobs map[string]*BuildJob - locker sync.Mutex + pool *pool.Pool + jobs map[string]*Job + lock sync.Mutex } type Fetcher interface { @@ -48,111 +54,159 @@ type Fetcher interface { // NewBuilder creates a new builder which manages the given number of fetch/attach/build workers. // - fetchWorkers: the number of workers which fetch messages from API // - attachWorkers: the number of workers which fetch attachments from API. -// - buildWorkers: the number of workers which decrypt/build RFC822 message literals. -// -// NOTE: Each fetch worker spawns a unique set of attachment workers! -// There can therefore be up to fetchWorkers*attachWorkers simultaneous API connections. // // The returned builder is ready to handle jobs -- see (*Builder).NewJob for more information. // // Call (*Builder).Done to shut down the builder and stop all workers. -func NewBuilder(fetchWorkers, attachWorkers, buildWorkers int) *Builder { - b := newBuilder() +func NewBuilder(fetchWorkers, attachWorkers int) *Builder { + attacherPool := pool.New(attachWorkers, newAttacherWorkFunc()) - fetchReqCh, fetchResCh := startFetchWorkers(fetchWorkers, attachWorkers) - buildReqCh, buildResCh := startBuildWorkers(buildWorkers) + fetcherPool := pool.New(fetchWorkers, newFetcherWorkFunc(attacherPool)) - go func() { - defer close(fetchReqCh) - - for { - select { - case req := <-b.reqs: - fetchReqCh <- req - - case <-b.done: - return - } - } - }() - - go func() { - defer close(buildReqCh) - - for res := range fetchResCh { - if res.err != nil { - b.jobFailure(res.messageID, res.err) - } else { - buildReqCh <- res - } - } - }() - - go func() { - for res := range buildResCh { - if res.err != nil { - b.jobFailure(res.messageID, res.err) - } else { - b.jobSuccess(res.messageID, res.literal) - } - } - }() - - return b -} - -func newBuilder() *Builder { return &Builder{ - reqs: make(chan fetchReq), - done: make(chan struct{}), - jobs: make(map[string]*BuildJob), + pool: fetcherPool, + jobs: make(map[string]*Job), } } -// NewJob tells the builder to begin building the message with the given ID. -// The result (or any error which occurred during building) can be retrieved from the returned job when available. -func (b *Builder) NewJob(ctx context.Context, api Fetcher, messageID string) *BuildJob { - return b.NewJobWithOptions(ctx, api, messageID, JobOptions{}) +func (builder *Builder) NewJob(ctx context.Context, fetcher Fetcher, messageID string, prio int) (*Job, pool.DoneFunc) { + return builder.NewJobWithOptions(ctx, fetcher, messageID, JobOptions{}, prio) } -// NewJobWithOptions creates a new job with custom options. See NewJob for more information. -func (b *Builder) NewJobWithOptions(ctx context.Context, api Fetcher, messageID string, opts JobOptions) *BuildJob { - b.locker.Lock() - defer b.locker.Unlock() +func (builder *Builder) NewJobWithOptions(ctx context.Context, fetcher Fetcher, messageID string, opts JobOptions, prio int) (*Job, pool.DoneFunc) { + builder.lock.Lock() + defer builder.lock.Unlock() - if job, ok := b.jobs[messageID]; ok { - return job + if job, ok := builder.jobs[messageID]; ok { + if job.GetPriority() < prio { + job.SetPriority(prio) + } + + return job, job.done } - b.jobs[messageID] = newBuildJob(messageID) + job, done := builder.pool.NewJob( + &fetchReq{ + fetcher: fetcher, + messageID: messageID, + options: opts, + }, + prio, + ) - go func() { b.reqs <- fetchReq{ctx: ctx, api: api, messageID: messageID, opts: opts} }() + buildJob := &Job{ + Job: job, + done: done, + } - return b.jobs[messageID] + builder.jobs[messageID] = buildJob + + return buildJob, func() { + builder.lock.Lock() + defer builder.lock.Unlock() + + // Remove the job from the builder. + delete(builder.jobs, messageID) + + // And mark it as done. + done() + } } -// Done shuts down the builder and stops all workers. -func (b *Builder) Done() { - b.locker.Lock() - defer b.locker.Unlock() - - close(b.done) +func (builder *Builder) Done() { + // NOTE(GODT-1158): Stop worker pool. } -func (b *Builder) jobSuccess(messageID string, literal []byte) { - b.locker.Lock() - defer b.locker.Unlock() - - b.jobs[messageID].postSuccess(literal) - - delete(b.jobs, messageID) +type fetchReq struct { + fetcher Fetcher + messageID string + options JobOptions } -func (b *Builder) jobFailure(messageID string, err error) { - b.locker.Lock() - defer b.locker.Unlock() - - b.jobs[messageID].postFailure(err) - - delete(b.jobs, messageID) +type attachReq struct { + fetcher Fetcher + message *pmapi.Message +} + +type Job struct { + *pool.Job + + done pool.DoneFunc +} + +func (job *Job) GetResult() ([]byte, error) { + res, err := job.Job.GetResult() + if err != nil { + return nil, err + } + + return res.([]byte), nil +} + +func newAttacherWorkFunc() pool.WorkFunc { + return func(payload interface{}, prio int) (interface{}, error) { + req, ok := payload.(*attachReq) + if !ok { + panic("bad payload type") + } + + res := make(map[string][]byte) + + for _, att := range req.message.Attachments { + rc, err := req.fetcher.GetAttachment(context.Background(), att.ID) + if err != nil { + return nil, err + } + + b, err := ioutil.ReadAll(rc) + if err != nil { + return nil, err + } + + if err := rc.Close(); err != nil { + return nil, err + } + + res[att.ID] = b + } + + return res, nil + } +} + +func newFetcherWorkFunc(attacherPool *pool.Pool) pool.WorkFunc { + return func(payload interface{}, prio int) (interface{}, error) { + req, ok := payload.(*fetchReq) + if !ok { + panic("bad payload type") + } + + msg, err := req.fetcher.GetMessage(context.Background(), req.messageID) + if err != nil { + return nil, err + } + + attJob, attDone := attacherPool.NewJob(&attachReq{ + fetcher: req.fetcher, + message: msg, + }, prio) + defer attDone() + + val, err := attJob.GetResult() + if err != nil { + return nil, err + } + + attData, ok := val.(map[string][]byte) + if !ok { + panic("bad response type") + } + + kr, err := req.fetcher.KeyRingForAddressID(msg.AddressID) + if err != nil { + return nil, ErrNoSuchKeyRing + } + + return buildRFC822(kr, msg, attData, req.options) + } } diff --git a/pkg/message/build_build.go b/pkg/message/build_build.go deleted file mode 100644 index 596f0e7f..00000000 --- a/pkg/message/build_build.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) 2021 Proton Technologies AG -// -// This file is part of ProtonMail Bridge. -// -// ProtonMail Bridge is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// ProtonMail Bridge is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with ProtonMail Bridge. If not, see . - -package message - -import ( - "sync" - - "github.com/pkg/errors" -) - -type buildRes struct { - messageID string - literal []byte - err error -} - -func newBuildResSuccess(messageID string, literal []byte) buildRes { - return buildRes{ - messageID: messageID, - literal: literal, - } -} - -func newBuildResFailure(messageID string, err error) buildRes { - return buildRes{ - messageID: messageID, - err: err, - } -} - -// startBuildWorkers starts the given number of build workers. -// These workers decrypt and build messages into RFC822 literals. -// Two channels are returned: -// - buildReqCh: used to send work items to the worker pool -// - buildResCh: used to receive work results from the worker pool -func startBuildWorkers(buildWorkers int) (chan fetchRes, chan buildRes) { - buildReqCh := make(chan fetchRes) - buildResCh := make(chan buildRes) - - go func() { - defer close(buildResCh) - - var wg sync.WaitGroup - - wg.Add(buildWorkers) - - for workerID := 0; workerID < buildWorkers; workerID++ { - go buildWorker(buildReqCh, buildResCh, &wg) - } - - wg.Wait() - }() - - return buildReqCh, buildResCh -} - -func buildWorker(buildReqCh <-chan fetchRes, buildResCh chan<- buildRes, wg *sync.WaitGroup) { - defer wg.Done() - - for req := range buildReqCh { - l := log. - WithField("addrID", req.msg.AddressID). - WithField("msgID", req.msg.ID) - if kr, err := req.api.KeyRingForAddressID(req.msg.AddressID); err != nil { - l.WithError(err).Warn("Cannot find keyring for address") - buildResCh <- newBuildResFailure(req.msg.ID, errors.Wrap(ErrNoSuchKeyRing, err.Error())) - } else if literal, err := buildRFC822(kr, req.msg, req.atts, req.opts); err != nil { - l.WithError(err).Warn("Build failed") - buildResCh <- newBuildResFailure(req.msg.ID, err) - } else { - buildResCh <- newBuildResSuccess(req.msg.ID, literal) - } - } -} diff --git a/pkg/message/build_fetch.go b/pkg/message/build_fetch.go deleted file mode 100644 index 81a99829..00000000 --- a/pkg/message/build_fetch.go +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright (c) 2021 Proton Technologies AG -// -// This file is part of ProtonMail Bridge. -// -// ProtonMail Bridge is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// ProtonMail Bridge is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with ProtonMail Bridge. If not, see . - -package message - -import ( - "context" - "io/ioutil" - "sync" - - "github.com/ProtonMail/proton-bridge/pkg/parallel" - "github.com/ProtonMail/proton-bridge/pkg/pmapi" -) - -type fetchReq struct { - ctx context.Context - api Fetcher - messageID string - opts JobOptions -} - -type fetchRes struct { - fetchReq - - msg *pmapi.Message - atts [][]byte - err error -} - -func newFetchResSuccess(req fetchReq, msg *pmapi.Message, atts [][]byte) fetchRes { - return fetchRes{ - fetchReq: req, - msg: msg, - atts: atts, - } -} - -func newFetchResFailure(req fetchReq, err error) fetchRes { - return fetchRes{ - fetchReq: req, - err: err, - } -} - -// startFetchWorkers starts the given number of fetch workers. -// These workers download message and attachment data from API. -// Each fetch worker will use up to the given number of attachment workers to download attachments. -// Two channels are returned: -// - fetchReqCh: used to send work items to the worker pool -// - fetchResCh: used to receive work results from the worker pool -func startFetchWorkers(fetchWorkers, attachWorkers int) (chan fetchReq, chan fetchRes) { - fetchReqCh := make(chan fetchReq) - fetchResCh := make(chan fetchRes) - - go func() { - defer close(fetchResCh) - - var wg sync.WaitGroup - - wg.Add(fetchWorkers) - - for workerID := 0; workerID < fetchWorkers; workerID++ { - go fetchWorker(fetchReqCh, fetchResCh, attachWorkers, &wg) - } - - wg.Wait() - }() - - return fetchReqCh, fetchResCh -} - -func fetchWorker(fetchReqCh <-chan fetchReq, fetchResCh chan<- fetchRes, attachWorkers int, wg *sync.WaitGroup) { - defer wg.Done() - - for req := range fetchReqCh { - msg, atts, err := fetchMessage(req, attachWorkers) - if err != nil { - fetchResCh <- newFetchResFailure(req, err) - } else { - fetchResCh <- newFetchResSuccess(req, msg, atts) - } - } -} - -func fetchMessage(req fetchReq, attachWorkers int) (*pmapi.Message, [][]byte, error) { - msg, err := req.api.GetMessage(req.ctx, req.messageID) - if err != nil { - return nil, nil, err - } - - attList := make([]interface{}, len(msg.Attachments)) - - for i, att := range msg.Attachments { - attList[i] = att.ID - } - - process := func(value interface{}) (interface{}, error) { - rc, err := req.api.GetAttachment(req.ctx, value.(string)) - if err != nil { - return nil, err - } - - b, err := ioutil.ReadAll(rc) - if err != nil { - return nil, err - } - - if err := rc.Close(); err != nil { - return nil, err - } - - return b, nil - } - - attData := make([][]byte, len(msg.Attachments)) - - collect := func(idx int, value interface{}) error { - attData[idx] = value.([]byte) //nolint[forcetypeassert] we wan't to panic here - return nil - } - - if err := parallel.RunParallel(attachWorkers, attList, process, collect); err != nil { - return nil, nil, err - } - - return msg, attData, nil -} diff --git a/pkg/message/build_job.go b/pkg/message/build_job.go index b800afa7..a04c452a 100644 --- a/pkg/message/build_job.go +++ b/pkg/message/build_job.go @@ -25,35 +25,3 @@ type JobOptions struct { AddMessageDate bool // Whether to include message time as X-Pm-Date. AddMessageIDReference bool // Whether to include the MessageID in References. } - -type BuildJob struct { - messageID string - literal []byte - err error - - done chan struct{} -} - -func newBuildJob(messageID string) *BuildJob { - return &BuildJob{ - messageID: messageID, - done: make(chan struct{}), - } -} - -// GetResult returns the build result or any error which occurred during building. -// If the result is not ready yet, it blocks. -func (job *BuildJob) GetResult() ([]byte, error) { - <-job.done - return job.literal, job.err -} - -func (job *BuildJob) postSuccess(literal []byte) { - job.literal = literal - close(job.done) -} - -func (job *BuildJob) postFailure(err error) { - job.err = err - close(job.done) -} diff --git a/pkg/message/build_rfc822.go b/pkg/message/build_rfc822.go index 9bfcaa4d..bfcdac62 100644 --- a/pkg/message/build_rfc822.go +++ b/pkg/message/build_rfc822.go @@ -34,7 +34,7 @@ import ( "github.com/pkg/errors" ) -func buildRFC822(kr *crypto.KeyRing, msg *pmapi.Message, attData [][]byte, opts JobOptions) ([]byte, error) { +func buildRFC822(kr *crypto.KeyRing, msg *pmapi.Message, attData map[string][]byte, opts JobOptions) ([]byte, error) { switch { case len(msg.Attachments) > 0: return buildMultipartRFC822(kr, msg, attData, opts) @@ -80,7 +80,7 @@ func buildSimpleRFC822(kr *crypto.KeyRing, msg *pmapi.Message, opts JobOptions) func buildMultipartRFC822( kr *crypto.KeyRing, msg *pmapi.Message, - attData [][]byte, + attData map[string][]byte, opts JobOptions, ) ([]byte, error) { boundary := newBoundary(msg.ID) @@ -103,13 +103,13 @@ func buildMultipartRFC822( attachData [][]byte ) - for i, att := range msg.Attachments { + for _, att := range msg.Attachments { if att.Disposition == pmapi.DispositionInline { inlineAtts = append(inlineAtts, att) - inlineData = append(inlineData, attData[i]) + inlineData = append(inlineData, attData[att.ID]) } else { attachAtts = append(attachAtts, att) - attachData = append(attachData, attData[i]) + attachData = append(attachData, attData[att.ID]) } } diff --git a/pkg/message/build_test.go b/pkg/message/build_test.go index aca5925a..f589a27a 100644 --- a/pkg/message/build_test.go +++ b/pkg/message/build_test.go @@ -37,13 +37,16 @@ func TestBuildPlainMessage(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "text/plain", "body", time.Now()) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -56,14 +59,17 @@ func TestBuildPlainMessageWithLongKey(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(1, 1) defer b.Done() kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "text/plain", "body", time.Now()) msg.Header["ReallyVeryVeryVeryVeryVeryLongLongLongLongLongLongLongKeyThatWillHaveNotSoLongValue"] = []string{"value"} - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -77,13 +83,16 @@ func TestBuildHTMLMessage(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "text/html", "body", time.Now()) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -96,7 +105,7 @@ func TestBuildPlainEncryptedMessage(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() body := readerToString(getFileReader("pgp-mime-body-plaintext.eml")) @@ -104,7 +113,10 @@ func TestBuildPlainEncryptedMessage(t *testing.T) { kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -124,7 +136,7 @@ func TestBuildPlainEncryptedMessageMissingHeader(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(1, 1) defer b.Done() body := readerToString(getFileReader("plaintext-missing-header.eml")) @@ -132,7 +144,10 @@ func TestBuildPlainEncryptedMessageMissingHeader(t *testing.T) { kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Now()) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -144,7 +159,7 @@ func TestBuildPlainEncryptedMessageInvalidHeader(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(1, 1) defer b.Done() body := readerToString(getFileReader("plaintext-invalid-header.eml")) @@ -152,7 +167,10 @@ func TestBuildPlainEncryptedMessageInvalidHeader(t *testing.T) { kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Now()) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -164,7 +182,7 @@ func TestBuildPlainSignedEncryptedMessageMissingHeader(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(1, 1) defer b.Done() body := readerToString(getFileReader("plaintext-missing-header.eml")) @@ -180,7 +198,10 @@ func TestBuildPlainSignedEncryptedMessageMissingHeader(t *testing.T) { msg := newRawTestMessage("messageID", "addressID", "multipart/mixed", arm, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -204,7 +225,7 @@ func TestBuildPlainSignedEncryptedMessageInvalidHeader(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(1, 1) defer b.Done() body := readerToString(getFileReader("plaintext-invalid-header.eml")) @@ -220,7 +241,10 @@ func TestBuildPlainSignedEncryptedMessageInvalidHeader(t *testing.T) { msg := newRawTestMessage("messageID", "addressID", "multipart/mixed", arm, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -244,7 +268,7 @@ func TestBuildPlainEncryptedLatin2Message(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() body := readerToString(getFileReader("pgp-mime-body-plaintext-latin2.eml")) @@ -252,7 +276,10 @@ func TestBuildPlainEncryptedLatin2Message(t *testing.T) { kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -269,7 +296,7 @@ func TestBuildHTMLEncryptedMessage(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() body := readerToString(getFileReader("pgp-mime-body-html.eml")) @@ -277,7 +304,10 @@ func TestBuildHTMLEncryptedMessage(t *testing.T) { kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -298,7 +328,7 @@ func TestBuildPlainSignedMessage(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() body := readerToString(getFileReader("text_plain.eml")) @@ -314,7 +344,10 @@ func TestBuildPlainSignedMessage(t *testing.T) { msg := newRawTestMessage("messageID", "addressID", "multipart/mixed", arm, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -339,7 +372,7 @@ func TestBuildPlainSignedBase64Message(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() body := readerToString(getFileReader("text_plain_base64.eml")) @@ -355,7 +388,10 @@ func TestBuildPlainSignedBase64Message(t *testing.T) { msg := newRawTestMessage("messageID", "addressID", "multipart/mixed", arm, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -381,7 +417,7 @@ func TestBuildSignedPlainEncryptedMessage(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() body := readerToString(getFileReader("pgp-mime-body-signed-plaintext.eml")) @@ -389,7 +425,10 @@ func TestBuildSignedPlainEncryptedMessage(t *testing.T) { kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -421,7 +460,7 @@ func TestBuildSignedHTMLEncryptedMessage(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() body := readerToString(getFileReader("pgp-mime-body-signed-html.eml")) @@ -429,7 +468,10 @@ func TestBuildSignedHTMLEncryptedMessage(t *testing.T) { kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -463,7 +505,7 @@ func TestBuildSignedPlainEncryptedMessageWithPubKey(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() body := readerToString(getFileReader("pgp-mime-body-signed-plaintext-with-pubkey.eml")) @@ -471,7 +513,10 @@ func TestBuildSignedPlainEncryptedMessageWithPubKey(t *testing.T) { kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -512,7 +557,7 @@ func TestBuildSignedHTMLEncryptedMessageWithPubKey(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() body := readerToString(getFileReader("pgp-mime-body-signed-html-with-pubkey.eml")) @@ -520,7 +565,10 @@ func TestBuildSignedHTMLEncryptedMessageWithPubKey(t *testing.T) { kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -562,7 +610,7 @@ func TestBuildSignedMultipartAlternativeEncryptedMessageWithPubKey(t *testing.T) m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() body := readerToString(getFileReader("pgp-mime-body-signed-multipart-alternative-with-pubkey.eml")) @@ -570,7 +618,10 @@ func TestBuildSignedMultipartAlternativeEncryptedMessageWithPubKey(t *testing.T) kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -628,7 +679,7 @@ func TestBuildSignedEmbeddedMessageRFC822EncryptedMessageWithPubKey(t *testing.T m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() body := readerToString(getFileReader("pgp-mime-body-signed-embedded-message-rfc822-with-pubkey.eml")) @@ -636,7 +687,10 @@ func TestBuildSignedEmbeddedMessageRFC822EncryptedMessageWithPubKey(t *testing.T kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "multipart/mixed", body, time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -682,14 +736,17 @@ func TestBuildHTMLMessageWithAttachment(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "text/html", "body", time.Now()) att := addTestAttachment(t, kr, msg, "attachID", "file.png", "image/png", "attachment", "attachment") - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res, 1). @@ -709,14 +766,17 @@ func TestBuildHTMLMessageWithRFC822Attachment(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "text/html", "body", time.Now()) att := addTestAttachment(t, kr, msg, "attachID", "file.eml", "message/rfc822", "attachment", "... message/rfc822 ...") - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res, 1). @@ -736,14 +796,17 @@ func TestBuildHTMLMessageWithInlineAttachment(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "text/html", "body", time.Now()) inl := addTestAttachment(t, kr, msg, "inlineID", "file.png", "image/png", "inline", "inline") - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res, 1). @@ -766,7 +829,7 @@ func TestBuildHTMLMessageWithComplexAttachments(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -776,7 +839,10 @@ func TestBuildHTMLMessageWithComplexAttachments(t *testing.T) { att0 := addTestAttachment(t, kr, msg, "attachID0", "attach0.png", "image/png", "attachment", "attach0") att1 := addTestAttachment(t, kr, msg, "attachID1", "attach1.png", "image/png", "attachment", "attach1") - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl0, inl1, att0, att1), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl0, inl1, att0, att1), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res, 1). @@ -820,14 +886,17 @@ func TestBuildAttachmentWithExoticFilename(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "text/html", "body", time.Now()) att := addTestAttachment(t, kr, msg, "attachID", `I řeally šhould leařn czech.png`, "image/png", "attachment", "attachment") - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) // The "name" and "filename" params should actually be RFC2047-encoded because they aren't 7-bit clean. @@ -843,7 +912,7 @@ func TestBuildAttachmentWithLongFilename(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() veryLongName := strings.Repeat("a", 200) + ".png" @@ -852,7 +921,10 @@ func TestBuildAttachmentWithLongFilename(t *testing.T) { msg := newTestMessage(t, kr, "messageID", "addressID", "text/html", "body", time.Now()) att := addTestAttachment(t, kr, msg, "attachID", veryLongName, "image/png", "attachment", "attachment") - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) // NOTE: hasMaxLineLength is too high! Long filenames should be linewrapped using multipart filenames. @@ -868,13 +940,16 @@ func TestBuildMessageDate(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "text/plain", "body", time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res).expectDate(is(`Wed, 01 Jan 2020 00:00:00 +0000`)) @@ -884,7 +959,7 @@ func TestBuildMessageWithInvalidDate(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -893,21 +968,26 @@ func TestBuildMessageWithInvalidDate(t *testing.T) { msg := newTestMessage(t, kr, "messageID", "addressID", "text/html", "body", time.Unix(-1, 0)) // Build the message as usual; the date will be before 1970. - resRaw, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + jobRaw, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + resRaw, err := jobRaw.GetResult() require.NoError(t, err) + done() section(t, resRaw). expectDate(is(`Wed, 31 Dec 1969 23:59:59 +0000`)). expectHeader(`X-Original-Date`, isMissing()) // Build the message with date sanitization enabled; the date will be RFC822's birthdate. - resFix, err := b.NewJobWithOptions( + jobFix, done := b.NewJobWithOptions( context.Background(), newTestFetcher(m, kr, msg), msg.ID, JobOptions{SanitizeDate: true}, - ).GetResult() + ForegroundPriority, + ) + resFix, err := jobFix.GetResult() require.NoError(t, err) + done() section(t, resFix). expectDate(is(`Fri, 13 Aug 1982 00:00:00 +0000`)). @@ -918,13 +998,16 @@ func TestBuildMessageInternalID(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) msg := newTestMessage(t, kr, "messageID", "addressID", "text/plain", "body", time.Now()) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res).expectHeader(`Message-Id`, is(``)) @@ -934,7 +1017,7 @@ func TestBuildMessageExternalID(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -943,7 +1026,10 @@ func TestBuildMessageExternalID(t *testing.T) { // Set the message's external ID; this should be used preferentially to set the Message-Id header field. msg.ExternalID = "externalID" - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res).expectHeader(`Message-Id`, is(``)) @@ -953,7 +1039,7 @@ func TestBuild8BitBody(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -961,7 +1047,10 @@ func TestBuild8BitBody(t *testing.T) { // Set an 8-bit body; the charset should be set to UTF-8. msg := newTestMessage(t, kr, "messageID", "addressID", "text/plain", "I řeally šhould leařn czech", time.Now()) - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res).expectContentTypeParam(`charset`, is(`utf-8`)) @@ -971,7 +1060,7 @@ func TestBuild8BitSubject(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -980,7 +1069,10 @@ func TestBuild8BitSubject(t *testing.T) { // Set an 8-bit subject; it should be RFC2047-encoded. msg.Subject = `I řeally šhould leařn czech` - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -992,7 +1084,7 @@ func TestBuild8BitSender(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1004,7 +1096,10 @@ func TestBuild8BitSender(t *testing.T) { Address: `mail@example.com`, } - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -1016,7 +1111,7 @@ func TestBuild8BitRecipients(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1028,7 +1123,10 @@ func TestBuild8BitRecipients(t *testing.T) { {Name: `leařn czech`, Address: `mail2@example.com`}, } - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -1040,7 +1138,7 @@ func TestBuildIncludeMessageIDReference(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1049,18 +1147,23 @@ func TestBuildIncludeMessageIDReference(t *testing.T) { // Add references. msg.Header["References"] = []string{""} - res, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + res, err := job.GetResult() require.NoError(t, err) + done() section(t, res).expectHeader(`References`, is(``)) - resRef, err := b.NewJobWithOptions( + jobRef, done := b.NewJobWithOptions( context.Background(), newTestFetcher(m, kr, msg), msg.ID, JobOptions{AddMessageIDReference: true}, - ).GetResult() + ForegroundPriority, + ) + resRef, err := jobRef.GetResult() require.NoError(t, err) + done() section(t, resRef).expectHeader(`References`, is(` `)) } @@ -1069,7 +1172,7 @@ func TestBuildMessageIsDeterministic(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1077,11 +1180,15 @@ func TestBuildMessageIsDeterministic(t *testing.T) { inl := addTestAttachment(t, kr, msg, "inlineID", "file.png", "image/png", "inline", "inline") att := addTestAttachment(t, kr, msg, "attachID", "attach.png", "image/png", "attachment", "attachment") - res1, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl, att), msg.ID).GetResult() + job1, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl, att), msg.ID, ForegroundPriority) + res1, err := job1.GetResult() require.NoError(t, err) + done() - res2, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl, att), msg.ID).GetResult() + job2, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, inl, att), msg.ID, ForegroundPriority) + res2, err := job2.GetResult() require.NoError(t, err) + done() assert.Equal(t, res1, res2) } @@ -1090,15 +1197,18 @@ func TestBuildParallel(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(2, 2, 2) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) msg1 := newTestMessage(t, kr, "messageID1", "addressID", "text/plain", "body1", time.Now()) msg2 := newTestMessage(t, kr, "messageID2", "addressID", "text/plain", "body2", time.Now()) - job1 := b.NewJob(context.Background(), newTestFetcher(m, kr, msg1), msg1.ID) - job2 := b.NewJob(context.Background(), newTestFetcher(m, kr, msg2), msg2.ID) + job1, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg1), msg1.ID, ForegroundPriority) + defer done() + + job2, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg2), msg2.ID, ForegroundPriority) + defer done() res1, err := job1.GetResult() require.NoError(t, err) @@ -1115,7 +1225,7 @@ func TestBuildParallelSameMessage(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(2, 2, 2) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1123,8 +1233,12 @@ func TestBuildParallelSameMessage(t *testing.T) { // Jobs for the same messageID are shared so fetcher is only called once. fetcher := newTestFetcher(m, kr, msg) - job1 := b.NewJob(context.Background(), fetcher, msg.ID) - job2 := b.NewJob(context.Background(), fetcher, msg.ID) + + job1, done := b.NewJob(context.Background(), fetcher, msg.ID, ForegroundPriority) + defer done() + + job2, done := b.NewJob(context.Background(), fetcher, msg.ID, ForegroundPriority) + defer done() res1, err := job1.GetResult() require.NoError(t, err) @@ -1141,7 +1255,7 @@ func TestBuildUndecryptableMessage(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1149,7 +1263,10 @@ func TestBuildUndecryptableMessage(t *testing.T) { // Use a different keyring for encrypting the message; it won't be decryptable. msg := newTestMessage(t, tests.MakeKeyRing(t), "messageID", "addressID", "text/plain", "body", time.Now()) - _, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg), msg.ID, ForegroundPriority) + defer done() + + _, err := job.GetResult() assert.True(t, errors.Is(err, ErrDecryptionFailed)) } @@ -1157,7 +1274,7 @@ func TestBuildUndecryptableAttachment(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1166,7 +1283,10 @@ func TestBuildUndecryptableAttachment(t *testing.T) { // Use a different keyring for encrypting the attachment; it won't be decryptable. att := addTestAttachment(t, tests.MakeKeyRing(t), msg, "attachID", "file.png", "image/png", "attachment", "attachment") - _, err := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID).GetResult() + job, done := b.NewJob(context.Background(), newTestFetcher(m, kr, msg, att), msg.ID, ForegroundPriority) + defer done() + + _, err := job.GetResult() assert.True(t, errors.Is(err, ErrDecryptionFailed)) } @@ -1174,7 +1294,7 @@ func TestBuildCustomMessagePlain(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1184,12 +1304,16 @@ func TestBuildCustomMessagePlain(t *testing.T) { msg := newTestMessage(t, foreignKR, "messageID", "addressID", "text/plain", "body", time.Now()) // Tell the job to ignore decryption errors; a custom message will be returned instead of an error. - res, err := b.NewJobWithOptions( + job, done := b.NewJobWithOptions( context.Background(), newTestFetcher(m, kr, msg), msg.ID, JobOptions{IgnoreDecryptionErrors: true}, - ).GetResult() + ForegroundPriority, + ) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -1206,7 +1330,7 @@ func TestBuildCustomMessageHTML(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1216,12 +1340,16 @@ func TestBuildCustomMessageHTML(t *testing.T) { msg := newTestMessage(t, foreignKR, "messageID", "addressID", "text/html", "body", time.Now()) // Tell the job to ignore decryption errors; a custom message will be returned instead of an error. - res, err := b.NewJobWithOptions( + job, done := b.NewJobWithOptions( context.Background(), newTestFetcher(m, kr, msg), msg.ID, JobOptions{IgnoreDecryptionErrors: true}, - ).GetResult() + ForegroundPriority, + ) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -1238,7 +1366,7 @@ func TestBuildCustomMessageEncrypted(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1252,12 +1380,16 @@ func TestBuildCustomMessageEncrypted(t *testing.T) { msg.Subject = "this is a subject to make sure we preserve subject" // Tell the job to ignore decryption errors; a custom message will be returned instead of an error. - res, err := b.NewJobWithOptions( + job, done := b.NewJobWithOptions( context.Background(), newTestFetcher(m, kr, msg), msg.ID, JobOptions{IgnoreDecryptionErrors: true}, - ).GetResult() + ForegroundPriority, + ) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -1283,7 +1415,7 @@ func TestBuildCustomMessagePlainWithAttachment(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1294,12 +1426,16 @@ func TestBuildCustomMessagePlainWithAttachment(t *testing.T) { att := addTestAttachment(t, foreignKR, msg, "attachID", "file.png", "image/png", "attachment", "attachment") // Tell the job to ignore decryption errors; a custom message will be returned instead of an error. - res, err := b.NewJobWithOptions( + job, done := b.NewJobWithOptions( context.Background(), newTestFetcher(m, kr, msg, att), msg.ID, JobOptions{IgnoreDecryptionErrors: true}, - ).GetResult() + ForegroundPriority, + ) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -1324,7 +1460,7 @@ func TestBuildCustomMessageHTMLWithAttachment(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1335,12 +1471,16 @@ func TestBuildCustomMessageHTMLWithAttachment(t *testing.T) { att := addTestAttachment(t, foreignKR, msg, "attachID", "file.png", "image/png", "attachment", "attachment") // Tell the job to ignore decryption errors; a custom message will be returned instead of an error. - res, err := b.NewJobWithOptions( + job, done := b.NewJobWithOptions( context.Background(), newTestFetcher(m, kr, msg, att), msg.ID, JobOptions{IgnoreDecryptionErrors: true}, - ).GetResult() + ForegroundPriority, + ) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -1365,7 +1505,7 @@ func TestBuildCustomMessageOnlyBodyIsUndecryptable(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1378,12 +1518,16 @@ func TestBuildCustomMessageOnlyBodyIsUndecryptable(t *testing.T) { att := addTestAttachment(t, kr, msg, "attachID", "file.png", "image/png", "attachment", "attachment") // Tell the job to ignore decryption errors; a custom message will be returned instead of an error. - res, err := b.NewJobWithOptions( + job, done := b.NewJobWithOptions( context.Background(), newTestFetcher(m, kr, msg, att), msg.ID, JobOptions{IgnoreDecryptionErrors: true}, - ).GetResult() + ForegroundPriority, + ) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -1407,7 +1551,7 @@ func TestBuildCustomMessageOnlyAttachmentIsUndecryptable(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() // Use the original keyring for encrypting the message; it should decrypt fine. @@ -1419,12 +1563,16 @@ func TestBuildCustomMessageOnlyAttachmentIsUndecryptable(t *testing.T) { att := addTestAttachment(t, foreignKR, msg, "attachID", "file.png", "image/png", "attachment", "attachment") // Tell the job to ignore decryption errors; a custom message will be returned instead of an error. - res, err := b.NewJobWithOptions( + job, done := b.NewJobWithOptions( context.Background(), newTestFetcher(m, kr, msg, att), msg.ID, JobOptions{IgnoreDecryptionErrors: true}, - ).GetResult() + ForegroundPriority, + ) + defer done() + + res, err := job.GetResult() require.NoError(t, err) section(t, res). @@ -1448,7 +1596,7 @@ func TestBuildFetchMessageFail(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1459,7 +1607,10 @@ func TestBuildFetchMessageFail(t *testing.T) { f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(nil, errors.New("oops")) // The job should fail, returning an error and a nil result. - res, err := b.NewJob(context.Background(), f, msg.ID).GetResult() + job, done := b.NewJob(context.Background(), f, msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() assert.Error(t, err) assert.Nil(t, res) } @@ -1468,7 +1619,7 @@ func TestBuildFetchAttachmentFail(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1481,7 +1632,10 @@ func TestBuildFetchAttachmentFail(t *testing.T) { f.EXPECT().GetAttachment(gomock.Any(), msg.Attachments[0].ID).Return(nil, errors.New("oops")) // The job should fail, returning an error and a nil result. - res, err := b.NewJob(context.Background(), f, msg.ID).GetResult() + job, done := b.NewJob(context.Background(), f, msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() assert.Error(t, err) assert.Nil(t, res) } @@ -1490,7 +1644,7 @@ func TestBuildNoSuchKeyRing(t *testing.T) { m := gomock.NewController(t) defer m.Finish() - b := NewBuilder(1, 1, 1) + b := NewBuilder(2, 2) defer b.Done() kr := tests.MakeKeyRing(t) @@ -1501,7 +1655,10 @@ func TestBuildNoSuchKeyRing(t *testing.T) { f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil) f.EXPECT().KeyRingForAddressID(msg.AddressID).Return(nil, errors.New("oops")) - res, err := b.NewJob(context.Background(), f, msg.ID).GetResult() + job, done := b.NewJob(context.Background(), f, msg.ID, ForegroundPriority) + defer done() + + res, err := job.GetResult() assert.Error(t, err) assert.Nil(t, res) diff --git a/pkg/message/section.go b/pkg/message/section.go index 9154fa4c..ee62162c 100644 --- a/pkg/message/section.go +++ b/pkg/message/section.go @@ -38,9 +38,10 @@ type BodyStructure map[string]*SectionInfo // SectionInfo is used to hold data about parts of each section. type SectionInfo struct { - Header textproto.MIMEHeader + Header []byte Start, BSize, Size, Lines int reader io.Reader + isHeaderReadFinished bool } // Read will also count the final size of section. @@ -48,9 +49,38 @@ func (si *SectionInfo) Read(p []byte) (n int, err error) { n, err = si.reader.Read(p) si.Size += n si.Lines += bytes.Count(p, []byte("\n")) + + si.readHeader(p) return } +// readHeader appends read data to Header until empty line is found. +func (si *SectionInfo) readHeader(p []byte) { + if si.isHeaderReadFinished { + return + } + + si.Header = append(si.Header, p...) + + if i := bytes.Index(si.Header, []byte("\n\r\n")); i > 0 { + si.Header = si.Header[:i+3] + si.isHeaderReadFinished = true + return + } + + // textproto works also with simple line ending so we should be liberal + // as well. + if i := bytes.Index(si.Header, []byte("\n\n")); i > 0 { + si.Header = si.Header[:i+2] + si.isHeaderReadFinished = true + } +} + +// GetMIMEHeader parses bytes and return MIME header. +func (si *SectionInfo) GetMIMEHeader() (textproto.MIMEHeader, error) { + return textproto.NewReader(bufio.NewReader(bytes.NewReader(si.Header))).ReadMIMEHeader() +} + func NewBodyStructure(reader io.Reader) (structure *BodyStructure, err error) { structure = &BodyStructure{} err = structure.Parse(reader) @@ -93,14 +123,15 @@ func (bs *BodyStructure) parseAllChildSections(r io.Reader, currentPath []int, s bufInfo := bufio.NewReader(info) tp := textproto.NewReader(bufInfo) - if info.Header, err = tp.ReadMIMEHeader(); err != nil { + tpHeader, err := tp.ReadMIMEHeader() + if err != nil { return } bodyInfo := &SectionInfo{reader: tp.R} bodyReader := bufio.NewReader(bodyInfo) - mediaType, params, _ := pmmime.ParseMediaType(info.Header.Get("Content-Type")) + mediaType, params, _ := pmmime.ParseMediaType(tpHeader.Get("Content-Type")) // If multipart, call getAllParts, else read to count lines. if (strings.HasPrefix(mediaType, "multipart/") || mediaType == rfc822Message) && params["boundary"] != "" { @@ -260,9 +291,9 @@ func (bs *BodyStructure) GetMailHeader() (header textproto.MIMEHeader, err error } // GetMailHeaderBytes returns the bytes with main mail header. -// Warning: It can contain extra lines or multipart comment. -func (bs *BodyStructure) GetMailHeaderBytes(wholeMail io.ReadSeeker) (header []byte, err error) { - return bs.GetSectionHeaderBytes(wholeMail, []int{}) +// Warning: It can contain extra lines. +func (bs *BodyStructure) GetMailHeaderBytes() (header []byte, err error) { + return bs.GetSectionHeaderBytes([]int{}) } func goToOffsetAndReadNBytes(wholeMail io.ReadSeeker, offset, length int) ([]byte, error) { @@ -283,22 +314,21 @@ func goToOffsetAndReadNBytes(wholeMail io.ReadSeeker, offset, length int) ([]byt } // GetSectionHeader returns the mime header of specified section. -func (bs *BodyStructure) GetSectionHeader(sectionPath []int) (header textproto.MIMEHeader, err error) { +func (bs *BodyStructure) GetSectionHeader(sectionPath []int) (textproto.MIMEHeader, error) { info, err := bs.getInfoCheckSection(sectionPath) if err != nil { - return + return nil, err } - header = info.Header - return + return info.GetMIMEHeader() } -func (bs *BodyStructure) GetSectionHeaderBytes(wholeMail io.ReadSeeker, sectionPath []int) (header []byte, err error) { +// GetSectionHeaderBytes returns raw header bytes of specified section. +func (bs *BodyStructure) GetSectionHeaderBytes(sectionPath []int) ([]byte, error) { info, err := bs.getInfoCheckSection(sectionPath) if err != nil { - return + return nil, err } - headerLength := info.Size - info.BSize - return goToOffsetAndReadNBytes(wholeMail, info.Start, headerLength) + return info.Header, nil } // IMAPBodyStructure will prepare imap bodystructure recurently for given part. @@ -309,7 +339,12 @@ func (bs *BodyStructure) IMAPBodyStructure(currentPart []int) (imapBS *imap.Body return } - mediaType, params, _ := pmmime.ParseMediaType(info.Header.Get("Content-Type")) + tpHeader, err := info.GetMIMEHeader() + if err != nil { + return + } + + mediaType, params, _ := pmmime.ParseMediaType(tpHeader.Get("Content-Type")) mediaTypeSep := strings.Split(mediaType, "/") @@ -324,19 +359,19 @@ func (bs *BodyStructure) IMAPBodyStructure(currentPart []int) (imapBS *imap.Body Lines: uint32(info.Lines), } - if val := info.Header.Get("Content-ID"); val != "" { + if val := tpHeader.Get("Content-ID"); val != "" { imapBS.Id = val } - if val := info.Header.Get("Content-Transfer-Encoding"); val != "" { + if val := tpHeader.Get("Content-Transfer-Encoding"); val != "" { imapBS.Encoding = val } - if val := info.Header.Get("Content-Description"); val != "" { + if val := tpHeader.Get("Content-Description"); val != "" { imapBS.Description = val } - if val := info.Header.Get("Content-Disposition"); val != "" { + if val := tpHeader.Get("Content-Disposition"); val != "" { imapBS.Disposition = val } diff --git a/pkg/message/section_test.go b/pkg/message/section_test.go index b9ba8a39..16f27564 100644 --- a/pkg/message/section_test.go +++ b/pkg/message/section_test.go @@ -21,7 +21,6 @@ import ( "bytes" "fmt" "io/ioutil" - "net/textproto" "path/filepath" "runtime" "sort" @@ -71,7 +70,9 @@ func TestParseBodyStructure(t *testing.T) { debug("%10s: %-50s %5s %5s %5s %5s", "section", "type", "start", "size", "bsize", "lines") for _, path := range paths { sec := (*bs)[path] - contentType := (*bs)[path].Header.Get("Content-Type") + header, err := sec.GetMIMEHeader() + require.NoError(t, err) + contentType := header.Get("Content-Type") debug("%10s: %-50s %5d %5d %5d %5d", path, contentType, sec.Start, sec.Size, sec.BSize, sec.Lines) require.Equal(t, expectedStructure[path], contentType) } @@ -100,7 +101,9 @@ func TestParseBodyStructurePGP(t *testing.T) { haveStructure := map[string]string{} for path := range *bs { - haveStructure[path] = (*bs)[path].Header.Get("Content-Type") + header, err := (*bs)[path].GetMIMEHeader() + require.NoError(t, err) + haveStructure[path] = header.Get("Content-Type") } require.Equal(t, expectedStructure, haveStructure) @@ -192,7 +195,7 @@ Content-Type: plain/text r.NoError(err, debug(wantPath, info, haveBody)) r.Equal(wantBody, string(haveBody), debug(wantPath, info, haveBody)) - haveHeader, err := bs.GetSectionHeaderBytes(strings.NewReader(wantMail), wantPath) + haveHeader, err := bs.GetSectionHeaderBytes(wantPath) r.NoError(err, debug(wantPath, info, haveHeader)) r.Equal(wantHeader, string(haveHeader), debug(wantPath, info, haveHeader)) } @@ -211,7 +214,7 @@ Content-Type: multipart/mixed; boundary="0000MAIN" bs, err := NewBodyStructure(structReader) require.NoError(t, err) - haveHeader, err := bs.GetMailHeaderBytes(strings.NewReader(sampleMail)) + haveHeader, err := bs.GetMailHeaderBytes() require.NoError(t, err) require.Equal(t, wantHeader, haveHeader) } @@ -533,18 +536,14 @@ func TestBodyStructureSerialize(t *testing.T) { r := require.New(t) want := &BodyStructure{ "1": { - Header: textproto.MIMEHeader{ - "Content": []string{"type"}, - }, - Start: 1, - Size: 2, - BSize: 3, - Lines: 4, + Header: []byte("Content: type"), + Start: 1, + Size: 2, + BSize: 3, + Lines: 4, }, "1.1.1": { - Header: textproto.MIMEHeader{ - "X-Pm-Key": []string{"id"}, - }, + Header: []byte("X-Pm-Key: id"), Start: 11, Size: 12, BSize: 13, @@ -562,3 +561,32 @@ func TestBodyStructureSerialize(t *testing.T) { (*want)["1.1.1"].reader = nil r.Equal(want, have) } + +func TestSectionInfoReadHeader(t *testing.T) { + r := require.New(t) + + testData := []struct { + wantHeader, mail string + }{ + { + "key1: val1\nkey2: val2\n\n", + "key1: val1\nkey2: val2\n\nbody is here\n\nand it is not confused", + }, + { + "key1:\n val1\n\n", + "key1:\n val1\n\nbody is here", + }, + { + "key1: val1\r\nkey2: val2\r\n\r\n", + "key1: val1\r\nkey2: val2\r\n\r\nbody is here\r\n\r\nand it is not confused", + }, + } + + for _, td := range testData { + bs, err := NewBodyStructure(strings.NewReader(td.mail)) + r.NoError(err, "case %q", td.mail) + haveHeader, err := bs.GetMailHeaderBytes() + r.NoError(err, "case %q", td.mail) + r.Equal(td.wantHeader, string(haveHeader), "case %q", td.mail) + } +} diff --git a/pkg/pchan/pchan.go b/pkg/pchan/pchan.go new file mode 100644 index 00000000..b353a6a6 --- /dev/null +++ b/pkg/pchan/pchan.go @@ -0,0 +1,131 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pchan + +import ( + "sort" + "sync" +) + +type PChan struct { + lock sync.Mutex + items []*Item + ready, done chan struct{} +} + +type Item struct { + ch *PChan + val interface{} + prio int + done chan struct{} +} + +func (item *Item) Wait() { + <-item.done +} + +func (item *Item) GetPriority() int { + item.ch.lock.Lock() + defer item.ch.lock.Unlock() + + return item.prio +} + +func (item *Item) SetPriority(priority int) { + item.ch.lock.Lock() + defer item.ch.lock.Unlock() + + item.prio = priority + + sort.Slice(item.ch.items, func(i, j int) bool { + return item.ch.items[i].prio < item.ch.items[j].prio + }) +} + +func New() *PChan { + return &PChan{ + ready: make(chan struct{}), + done: make(chan struct{}), + } +} + +func (ch *PChan) Push(val interface{}, prio int) *Item { + defer ch.notify() + + return ch.push(val, prio) +} + +func (ch *PChan) Pop() (interface{}, int, bool) { + select { + case <-ch.ready: + val, prio := ch.pop() + return val, prio, true + + case <-ch.done: + return nil, 0, false + } +} + +func (ch *PChan) Close() { + select { + case <-ch.done: + return + + default: + close(ch.done) + } +} + +func (ch *PChan) push(val interface{}, prio int) *Item { + ch.lock.Lock() + defer ch.lock.Unlock() + + done := make(chan struct{}) + + item := &Item{ + ch: ch, + val: val, + prio: prio, + done: done, + } + + ch.items = append(ch.items, item) + + return item +} + +func (ch *PChan) pop() (interface{}, int) { + ch.lock.Lock() + defer ch.lock.Unlock() + + sort.Slice(ch.items, func(i, j int) bool { + return ch.items[i].prio < ch.items[j].prio + }) + + var item *Item + + item, ch.items = ch.items[len(ch.items)-1], ch.items[:len(ch.items)-1] + + defer close(item.done) + + return item.val, item.prio +} + +func (ch *PChan) notify() { + go func() { ch.ready <- struct{}{} }() +} diff --git a/pkg/pchan/pchan_test.go b/pkg/pchan/pchan_test.go new file mode 100644 index 00000000..fdd323de --- /dev/null +++ b/pkg/pchan/pchan_test.go @@ -0,0 +1,123 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pchan + +import ( + "sort" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPChanConcurrentPush(t *testing.T) { + ch := New() + + var wg sync.WaitGroup + + // We are going to test with 5 additional goroutines. + wg.Add(5) + + // Start 5 concurrent pushes. + go func() { defer wg.Done(); ch.Push(1, 1) }() + go func() { defer wg.Done(); ch.Push(2, 2) }() + go func() { defer wg.Done(); ch.Push(3, 3) }() + go func() { defer wg.Done(); ch.Push(4, 4) }() + go func() { defer wg.Done(); ch.Push(5, 5) }() + + // Wait for the items to be pushed. + wg.Wait() + + // All 5 should now be ready for popping. + require.Len(t, ch.items, 5) + + // They should be popped in priority order. + assert.Equal(t, 5, getValue(t, ch)) + assert.Equal(t, 4, getValue(t, ch)) + assert.Equal(t, 3, getValue(t, ch)) + assert.Equal(t, 2, getValue(t, ch)) + assert.Equal(t, 1, getValue(t, ch)) +} + +func TestPChanConcurrentPop(t *testing.T) { + ch := New() + + var wg sync.WaitGroup + + // We are going to test with 5 additional goroutines. + wg.Add(5) + + // Make a list to store the results in. + var res list + + // Start 5 concurrent pops; these consume any items pushed. + go func() { defer wg.Done(); res.append(getValue(t, ch)) }() + go func() { defer wg.Done(); res.append(getValue(t, ch)) }() + go func() { defer wg.Done(); res.append(getValue(t, ch)) }() + go func() { defer wg.Done(); res.append(getValue(t, ch)) }() + go func() { defer wg.Done(); res.append(getValue(t, ch)) }() + + // Push and block; items should be popped immediately by the waiting goroutines. + ch.Push(1, 1).Wait() + ch.Push(2, 2).Wait() + ch.Push(3, 3).Wait() + ch.Push(4, 4).Wait() + ch.Push(5, 5).Wait() + + // Wait for all items to be popped then close the result channel. + wg.Wait() + + assert.True(t, sort.IntsAreSorted(res.items)) +} + +func TestPChanClose(t *testing.T) { + ch := New() + + go ch.Push(1, 1) + + valOpen, _, okOpen := ch.Pop() + assert.True(t, okOpen) + assert.Equal(t, 1, valOpen) + + ch.Close() + + valClose, _, okClose := ch.Pop() + assert.False(t, okClose) + assert.Nil(t, valClose) +} + +type list struct { + items []int + mut sync.Mutex +} + +func (l *list) append(val int) { + l.mut.Lock() + defer l.mut.Unlock() + + l.items = append(l.items, val) +} + +func getValue(t *testing.T, ch *PChan) int { + val, _, ok := ch.Pop() + + assert.True(t, ok) + + return val.(int) +} diff --git a/pkg/pmapi/client_types.go b/pkg/pmapi/client_types.go index bf87bd60..56e9afb6 100644 --- a/pkg/pmapi/client_types.go +++ b/pkg/pmapi/client_types.go @@ -71,6 +71,7 @@ type Client interface { GetAttachment(ctx context.Context, id string) (att io.ReadCloser, err error) CreateAttachment(ctx context.Context, att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error) + GetUserKeyRing() (*crypto.KeyRing, error) KeyRingForAddressID(string) (kr *crypto.KeyRing, err error) GetPublicKeysForEmail(context.Context, string) ([]PublicKey, bool, error) } diff --git a/pkg/pmapi/messages.go b/pkg/pmapi/messages.go index 498213a3..10c19c14 100644 --- a/pkg/pmapi/messages.go +++ b/pkg/pmapi/messages.go @@ -175,7 +175,6 @@ type Message struct { CCList []*mail.Address BCCList []*mail.Address Time int64 // Unix time - Size int64 NumAttachments int ExpirationTime int64 // Unix time SpamScore int diff --git a/pkg/pmapi/mocks/mocks.go b/pkg/pmapi/mocks/mocks.go index 03085507..07b4b767 100644 --- a/pkg/pmapi/mocks/mocks.go +++ b/pkg/pmapi/mocks/mocks.go @@ -362,6 +362,21 @@ func (mr *MockClientMockRecorder) GetPublicKeysForEmail(arg0, arg1 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicKeysForEmail", reflect.TypeOf((*MockClient)(nil).GetPublicKeysForEmail), arg0, arg1) } +// GetUserKeyRing mocks base method. +func (m *MockClient) GetUserKeyRing() (*crypto.KeyRing, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserKeyRing") + ret0, _ := ret[0].(*crypto.KeyRing) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserKeyRing indicates an expected call of GetUserKeyRing. +func (mr *MockClientMockRecorder) GetUserKeyRing() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserKeyRing", reflect.TypeOf((*MockClient)(nil).GetUserKeyRing)) +} + // Import mocks base method. func (m *MockClient) Import(arg0 context.Context, arg1 pmapi.ImportMsgReqs) ([]*pmapi.ImportMsgRes, error) { m.ctrl.T.Helper() diff --git a/pkg/pmapi/users.go b/pkg/pmapi/users.go index d272327c..f60fd357 100644 --- a/pkg/pmapi/users.go +++ b/pkg/pmapi/users.go @@ -20,6 +20,7 @@ package pmapi import ( "context" + "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/getsentry/sentry-go" "github.com/go-resty/resty/v2" "github.com/pkg/errors" @@ -138,3 +139,12 @@ func (c *client) CurrentUser(ctx context.Context) (*User, error) { return c.UpdateUser(ctx) } + +// CurrentUser returns currently active user or user will be updated. +func (c *client) GetUserKeyRing() (*crypto.KeyRing, error) { + if c.userKeyRing == nil { + return nil, errors.New("user keyring is not available") + } + + return c.userKeyRing, nil +} diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go new file mode 100644 index 00000000..2b67fe54 --- /dev/null +++ b/pkg/pool/pool.go @@ -0,0 +1,129 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pool + +import "github.com/ProtonMail/proton-bridge/pkg/pchan" + +type WorkFunc func(interface{}, int) (interface{}, error) + +type DoneFunc func() + +type Pool struct { + jobCh *pchan.PChan +} + +func New(size int, work WorkFunc) *Pool { + jobCh := pchan.New() + + for i := 0; i < size; i++ { + go func() { + for { + val, prio, ok := jobCh.Pop() + if !ok { + return + } + + job, ok := val.(*Job) + if !ok { + panic("bad result type") + } + + res, err := work(job.req, prio) + if err != nil { + job.postFailure(err) + } else { + job.postSuccess(res) + } + + job.waitDone() + } + }() + } + + return &Pool{jobCh: jobCh} +} + +func (pool *Pool) NewJob(req interface{}, prio int) (*Job, DoneFunc) { + job := newJob(req) + + job.setItem(pool.jobCh.Push(job, prio)) + + return job, job.markDone +} + +type Job struct { + req interface{} + res interface{} + err error + + item *pchan.Item + + ready, done chan struct{} +} + +func newJob(req interface{}) *Job { + return &Job{ + req: req, + ready: make(chan struct{}), + done: make(chan struct{}), + } +} + +func (job *Job) GetResult() (interface{}, error) { + <-job.ready + + return job.res, job.err +} + +func (job *Job) GetPriority() int { + return job.item.GetPriority() +} + +func (job *Job) SetPriority(prio int) { + job.item.SetPriority(prio) +} + +func (job *Job) postSuccess(res interface{}) { + defer close(job.ready) + + job.res = res +} + +func (job *Job) postFailure(err error) { + defer close(job.ready) + + job.err = err +} + +func (job *Job) setItem(item *pchan.Item) { + job.item = item +} + +func (job *Job) markDone() { + select { + case <-job.done: + return + + default: + close(job.done) + } +} + +func (job *Job) waitDone() { + <-job.done +} diff --git a/pkg/pool/pool_test.go b/pkg/pool/pool_test.go new file mode 100644 index 00000000..6c38b329 --- /dev/null +++ b/pkg/pool/pool_test.go @@ -0,0 +1,45 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package pool_test + +import ( + "testing" + + "github.com/ProtonMail/proton-bridge/pkg/pool" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPool(t *testing.T) { + pool := pool.New(2, func(req interface{}, prio int) (interface{}, error) { return req, nil }) + + job1, done1 := pool.NewJob("echo", 1) + defer done1() + + job2, done2 := pool.NewJob("this", 1) + defer done2() + + res2, err := job2.GetResult() + require.NoError(t, err) + + res1, err := job1.GetResult() + require.NoError(t, err) + + assert.Equal(t, "echo", res1) + assert.Equal(t, "this", res2) +} diff --git a/pkg/semaphore/semaphore.go b/pkg/semaphore/semaphore.go new file mode 100644 index 00000000..592e45a4 --- /dev/null +++ b/pkg/semaphore/semaphore.go @@ -0,0 +1,53 @@ +// Copyright (c) 2021 Proton Technologies AG +// +// This file is part of ProtonMail Bridge. +// +// ProtonMail Bridge is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// ProtonMail Bridge is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with ProtonMail Bridge. If not, see . + +package semaphore + +import "sync" + +type Semaphore struct { + ch chan struct{} + wg sync.WaitGroup +} + +func New(max int) Semaphore { + return Semaphore{ch: make(chan struct{}, max)} +} + +func (sem *Semaphore) Lock() { + sem.ch <- struct{}{} +} + +func (sem *Semaphore) Unlock() { + <-sem.ch +} + +func (sem *Semaphore) Go(fn func()) { + sem.Lock() + sem.wg.Add(1) + + go func() { + defer sem.Unlock() + defer sem.wg.Done() + + fn() + }() +} + +func (sem *Semaphore) Wait() { + sem.wg.Wait() +} diff --git a/test/context/bridge.go b/test/context/bridge.go index 66b60f60..c48bbc5f 100644 --- a/test/context/bridge.go +++ b/test/context/bridge.go @@ -21,11 +21,14 @@ import ( "time" "github.com/ProtonMail/proton-bridge/internal/bridge" + "github.com/ProtonMail/proton-bridge/internal/config/settings" "github.com/ProtonMail/proton-bridge/internal/config/useragent" "github.com/ProtonMail/proton-bridge/internal/constants" "github.com/ProtonMail/proton-bridge/internal/sentry" + "github.com/ProtonMail/proton-bridge/internal/store/cache" "github.com/ProtonMail/proton-bridge/internal/users" "github.com/ProtonMail/proton-bridge/pkg/listener" + "github.com/ProtonMail/proton-bridge/pkg/message" "github.com/ProtonMail/proton-bridge/pkg/pmapi" ) @@ -61,18 +64,28 @@ func (ctx *TestContext) RestartBridge() error { } // newBridgeInstance creates a new bridge instance configured to use the given config/credstore. +// NOTE(GODT-1158): Need some tests with on-disk cache as well! Configurable in feature file or envvar? func newBridgeInstance( t *bddT, locations bridge.Locator, - cache bridge.Cacher, - settings *fakeSettings, + cacheProvider bridge.CacheProvider, + fakeSettings *fakeSettings, credStore users.CredentialsStorer, eventListener listener.Listener, clientManager pmapi.Manager, ) *bridge.Bridge { - sentryReporter := sentry.NewReporter("bridge", constants.Version, useragent.New()) - panicHandler := &panicHandler{t: t} - updater := newFakeUpdater() - versioner := newFakeVersioner() - return bridge.New(locations, cache, settings, sentryReporter, panicHandler, eventListener, clientManager, credStore, updater, versioner) + return bridge.New( + locations, + cacheProvider, + fakeSettings, + sentry.NewReporter("bridge", constants.Version, useragent.New()), + &panicHandler{t: t}, + eventListener, + cache.NewInMemoryCache(100*(1<<20)), + message.NewBuilder(fakeSettings.GetInt(settings.FetchWorkers), fakeSettings.GetInt(settings.AttachmentWorkers)), + clientManager, + credStore, + newFakeUpdater(), + newFakeVersioner(), + ) } diff --git a/test/context/imap.go b/test/context/imap.go index 1f85f802..f5dd8ca0 100644 --- a/test/context/imap.go +++ b/test/context/imap.go @@ -58,7 +58,7 @@ func (ctx *TestContext) withIMAPServer() { port := ctx.settings.GetInt(settings.IMAPPortKey) tls, _ := tls.New(settingsPath).GetConfig() - backend := imap.NewIMAPBackend(ph, ctx.listener, ctx.cache, ctx.bridge) + backend := imap.NewIMAPBackend(ph, ctx.listener, ctx.cache, ctx.settings, ctx.bridge) server := imap.NewIMAPServer(ph, true, true, port, tls, backend, ctx.userAgent, ctx.listener) go server.ListenAndServe() diff --git a/test/fakeapi/user.go b/test/fakeapi/user.go index 7f42dda4..55144f87 100644 --- a/test/fakeapi/user.go +++ b/test/fakeapi/user.go @@ -125,6 +125,10 @@ func (api *FakePMAPI) Addresses() pmapi.AddressList { return *api.addresses } +func (api *FakePMAPI) GetUserKeyRing() (*crypto.KeyRing, error) { + return api.userKeyRing, nil +} + func (api *FakePMAPI) KeyRingForAddressID(addrID string) (*crypto.KeyRing, error) { return api.addrKeyRing[addrID], nil }