GODT-1158: Store full messages bodies on disk

- GODT-1158: simple on-disk cache in store
- GODT-1158: better member naming in event loop
- GODT-1158: create on-disk cache during bridge setup
- GODT-1158: better job options
- GODT-1158: rename GetLiteral to GetRFC822
- GODT-1158: rename events -> currentEvents
- GODT-1158: unlock cache per-user
- GODT-1158: clean up cache after logout
- GODT-1158: randomized encrypted cache passphrase
- GODT-1158: Opt out of on-disk cache in settings
- GODT-1158: free space in cache
- GODT-1158: make tests compile
- GODT-1158: optional compression
- GODT-1158: cache custom location
- GODT-1158: basic capacity checker
- GODT-1158: cache free space config
- GODT-1158: only unlock cache if pmapi client is unlocked as well
- GODT-1158: simple background sync worker
- GODT-1158: set size/bodystructure when caching message
- GODT-1158: limit store db update blocking with semaphore
- GODT-1158: dumb 10-semaphore
- GODT-1158: properly handle delete; remove bad bodystructure handling
- GODT-1158: hacky fix for caching after logout... baaaaad
- GODT-1158: cache worker
- GODT-1158: compute body structure lazily
- GODT-1158: cache size in store
- GODT-1158: notify cacher when adding to store
- GODT-1158: 15 second store cache watcher
- GODT-1158: enable cacher
- GODT-1158: better cache worker starting/stopping
- GODT-1158: limit cacher to less concurrency than disk cache
- GODT-1158: message builder prio + pchan pkg
- GODT-1158: fix pchan, use in message builder
- GODT-1158: no sem in cacher (rely on message builder prio)
- GODT-1158: raise priority of existing jobs when requested
- GODT-1158: pending messages in on-disk cache
- GODT-1158: WIP just a note about deleting messages from disk cache
- GODT-1158: pending wait when trying to write
- GODT-1158: pending.add to return bool
- GODT-1225: Headers in bodystructure are stored as bytes.
- GODT-1158: fixing header caching
- GODT-1158: don't cache in background
- GODT-1158: all concurrency set in settings
- GODT-1158: worker pools inside message builder
- GODT-1158: fix linter issues
- GODT-1158: remove completed builds from builder
- GODT-1158: remove builder pool
- GODT-1158: cacher defer job done properly
- GODT-1158: fix linter
- GODT-1299: Continue with bodystructure build if deserialization failed
- GODT-1324: Delete messages from the cache when they are deleted on the server
- GODT-1158: refactor cache tests
- GODT-1158: move builder to app/bridge
- GODT-1306: Migrate cache on disk when location is changed (and delete when disabled)
This commit is contained in:
James Houlahan
2021-07-30 12:20:38 +02:00
committed by Jakub
parent 5cb893fc1b
commit 6bd0739013
79 changed files with 2911 additions and 1387 deletions

View File

@ -32,7 +32,9 @@ import (
"github.com/ProtonMail/proton-bridge/internal/frontend/types"
"github.com/ProtonMail/proton-bridge/internal/imap"
"github.com/ProtonMail/proton-bridge/internal/smtp"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/updater"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
@ -69,10 +71,21 @@ func New(base *base.Base) *cli.App {
func run(b *base.Base, c *cli.Context) error { // nolint[funlen]
tlsConfig, err := loadTLSConfig(b)
if err != nil {
logrus.WithError(err).Fatal("Failed to load TLS config")
return err
}
bridge := bridge.New(b.Locations, b.Cache, b.Settings, b.SentryReporter, b.CrashHandler, b.Listener, b.CM, b.Creds, b.Updater, b.Versioner)
imapBackend := imap.NewIMAPBackend(b.CrashHandler, b.Listener, b.Cache, bridge)
cache, err := loadCache(b)
if err != nil {
return err
}
builder := message.NewBuilder(
b.Settings.GetInt(settings.FetchWorkers),
b.Settings.GetInt(settings.AttachmentWorkers),
)
bridge := bridge.New(b.Locations, b.Cache, b.Settings, b.SentryReporter, b.CrashHandler, b.Listener, cache, builder, b.CM, b.Creds, b.Updater, b.Versioner)
imapBackend := imap.NewIMAPBackend(b.CrashHandler, b.Listener, b.Cache, b.Settings, bridge)
smtpBackend := smtp.NewSMTPBackend(b.CrashHandler, b.Listener, b.Settings, bridge)
go func() {
@ -233,3 +246,35 @@ func checkAndHandleUpdate(u types.Updater, f frontend.Frontend, autoUpdate bool)
f.NotifySilentUpdateInstalled()
}
// NOTE(GODT-1158): How big should in-memory cache be?
// NOTE(GODT-1158): How to handle cache location migration if user changes custom path?
func loadCache(b *base.Base) (cache.Cache, error) {
if !b.Settings.GetBool(settings.CacheEnabledKey) {
return cache.NewInMemoryCache(100 * (1 << 20)), nil
}
var compressor cache.Compressor
// NOTE(GODT-1158): If user changes compression setting we have to nuke the cache.
if b.Settings.GetBool(settings.CacheCompressionKey) {
compressor = &cache.GZipCompressor{}
} else {
compressor = &cache.NoopCompressor{}
}
var path string
if customPath := b.Settings.Get(settings.CacheLocationKey); customPath != "" {
path = customPath
} else {
path = b.Cache.GetDefaultMessageCacheDir()
}
return cache.NewOnDiskCache(path, compressor, cache.Options{
MinFreeAbs: uint64(b.Settings.GetInt(settings.CacheMinFreeAbsKey)),
MinFreeRat: b.Settings.GetFloat64(settings.CacheMinFreeRatKey),
ConcurrentRead: b.Settings.GetInt(settings.CacheConcurrencyRead),
ConcurrentWrite: b.Settings.GetInt(settings.CacheConcurrencyWrite),
})
}

View File

@ -28,17 +28,17 @@ import (
"github.com/ProtonMail/proton-bridge/internal/constants"
"github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/updater"
"github.com/ProtonMail/proton-bridge/internal/users"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/pkg/listener"
logrus "github.com/sirupsen/logrus"
)
var (
log = logrus.WithField("pkg", "bridge") //nolint[gochecknoglobals]
)
var log = logrus.WithField("pkg", "bridge") //nolint[gochecknoglobals]
type Bridge struct {
*users.Users
@ -52,11 +52,13 @@ type Bridge struct {
func New(
locations Locator,
cache Cacher,
s SettingsProvider,
cacheProvider CacheProvider,
setting SettingsProvider,
sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler,
eventListener listener.Listener,
cache cache.Cache,
builder *message.Builder,
clientManager pmapi.Manager,
credStorer users.CredentialsStorer,
updater Updater,
@ -64,7 +66,7 @@ func New(
) *Bridge {
// Allow DoH before starting the app if the user has previously set this setting.
// This allows us to start even if protonmail is blocked.
if s.GetBool(settings.AllowProxyKey) {
if setting.GetBool(settings.AllowProxyKey) {
clientManager.AllowProxy()
}
@ -74,25 +76,25 @@ func New(
eventListener,
clientManager,
credStorer,
newStoreFactory(cache, sentryReporter, panicHandler, eventListener),
newStoreFactory(cacheProvider, sentryReporter, panicHandler, eventListener, cache, builder),
)
b := &Bridge{
Users: u,
locations: locations,
settings: s,
settings: setting,
clientManager: clientManager,
updater: updater,
versioner: versioner,
}
if s.GetBool(settings.FirstStartKey) {
if setting.GetBool(settings.FirstStartKey) {
if err := b.SendMetric(metrics.New(metrics.Setup, metrics.FirstStart, metrics.Label(constants.Version))); err != nil {
logrus.WithError(err).Error("Failed to send metric")
}
s.SetBool(settings.FirstStartKey, false)
setting.SetBool(settings.FirstStartKey, false)
}
go b.heartbeat()

View File

@ -23,47 +23,65 @@ import (
"github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/users"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/message"
)
type storeFactory struct {
cache Cacher
cacheProvider CacheProvider
sentryReporter *sentry.Reporter
panicHandler users.PanicHandler
eventListener listener.Listener
storeCache *store.Cache
events *store.Events
cache cache.Cache
builder *message.Builder
}
func newStoreFactory(
cache Cacher,
cacheProvider CacheProvider,
sentryReporter *sentry.Reporter,
panicHandler users.PanicHandler,
eventListener listener.Listener,
cache cache.Cache,
builder *message.Builder,
) *storeFactory {
return &storeFactory{
cache: cache,
cacheProvider: cacheProvider,
sentryReporter: sentryReporter,
panicHandler: panicHandler,
eventListener: eventListener,
storeCache: store.NewCache(cache.GetIMAPCachePath()),
events: store.NewEvents(cacheProvider.GetIMAPCachePath()),
cache: cache,
builder: builder,
}
}
// New creates new store for given user.
func (f *storeFactory) New(user store.BridgeUser) (*store.Store, error) {
storePath := getUserStorePath(f.cache.GetDBDir(), user.ID())
return store.New(f.sentryReporter, f.panicHandler, user, f.eventListener, storePath, f.storeCache)
return store.New(
f.sentryReporter,
f.panicHandler,
user,
f.eventListener,
f.cache,
f.builder,
getUserStorePath(f.cacheProvider.GetDBDir(), user.ID()),
f.events,
)
}
// Remove removes all store files for given user.
func (f *storeFactory) Remove(userID string) error {
storePath := getUserStorePath(f.cache.GetDBDir(), userID)
return store.RemoveStore(f.storeCache, storePath, userID)
return store.RemoveStore(
f.events,
getUserStorePath(f.cacheProvider.GetDBDir(), userID),
userID,
)
}
// getUserStorePath returns the file path of the store database for the given userID.
func getUserStorePath(storeDir string, userID string) (path string) {
fileName := fmt.Sprintf("mailbox-%v.db", userID)
return filepath.Join(storeDir, fileName)
return filepath.Join(storeDir, fmt.Sprintf("mailbox-%v.db", userID))
}

View File

@ -28,7 +28,7 @@ type Locator interface {
ClearUpdates() error
}
type Cacher interface {
type CacheProvider interface {
GetIMAPCachePath() string
GetDBDir() string
}
@ -38,6 +38,7 @@ type SettingsProvider interface {
Set(key string, value string)
GetBool(key string) bool
SetBool(key string, val bool)
GetInt(key string) int
}
type Updater interface {

View File

@ -45,6 +45,11 @@ func (c *Cache) GetDBDir() string {
return c.getCurrentCacheDir()
}
// GetDefaultMessageCacheDir returns folder for cached messages files.
func (c *Cache) GetDefaultMessageCacheDir() string {
return filepath.Join(c.getCurrentCacheDir(), "messages")
}
// GetIMAPCachePath returns path to file with IMAP status.
func (c *Cache) GetIMAPCachePath() string {
return filepath.Join(c.getCurrentCacheDir(), "user_info.json")

View File

@ -100,18 +100,28 @@ func (p *keyValueStore) GetBool(key string) bool {
}
func (p *keyValueStore) GetInt(key string) int {
if p.Get(key) == "" {
return 0
}
value, err := strconv.Atoi(p.Get(key))
if err != nil {
logrus.WithError(err).Error("Cannot parse int")
}
return value
}
func (p *keyValueStore) GetFloat64(key string) float64 {
if p.Get(key) == "" {
return 0
}
value, err := strconv.ParseFloat(p.Get(key), 64)
if err != nil {
logrus.WithError(err).Error("Cannot parse float64")
}
return value
}

View File

@ -43,6 +43,16 @@ const (
UpdateChannelKey = "update_channel"
RolloutKey = "rollout"
PreferredKeychainKey = "preferred_keychain"
CacheEnabledKey = "cache_enabled"
CacheCompressionKey = "cache_compression"
CacheLocationKey = "cache_location"
CacheMinFreeAbsKey = "cache_min_free_abs"
CacheMinFreeRatKey = "cache_min_free_rat"
CacheConcurrencyRead = "cache_concurrent_read"
CacheConcurrencyWrite = "cache_concurrent_write"
IMAPWorkers = "imap_workers"
FetchWorkers = "fetch_workers"
AttachmentWorkers = "attachment_workers"
)
type Settings struct {
@ -80,6 +90,16 @@ func (s *Settings) setDefaultValues() {
s.setDefault(UpdateChannelKey, "")
s.setDefault(RolloutKey, fmt.Sprintf("%v", rand.Float64())) //nolint[gosec] G404 It is OK to use weak random number generator here
s.setDefault(PreferredKeychainKey, "")
s.setDefault(CacheEnabledKey, "true")
s.setDefault(CacheCompressionKey, "true")
s.setDefault(CacheLocationKey, "")
s.setDefault(CacheMinFreeAbsKey, "250000000")
s.setDefault(CacheMinFreeRatKey, "")
s.setDefault(CacheConcurrencyRead, "16")
s.setDefault(CacheConcurrencyWrite, "16")
s.setDefault(IMAPWorkers, "16")
s.setDefault(FetchWorkers, "16")
s.setDefault(AttachmentWorkers, "16")
s.setDefault(APIPortKey, DefaultAPIPort)
s.setDefault(IMAPPortKey, DefaultIMAPPort)

View File

@ -128,6 +128,24 @@ func New( //nolint[funlen]
})
fe.AddCmd(dohCmd)
// Cache-On-Disk commands.
codCmd := &ishell.Cmd{Name: "local-cache",
Help: "manage the local encrypted message cache",
}
codCmd.AddCmd(&ishell.Cmd{Name: "enable",
Help: "enable the local cache",
Func: fe.enableCacheOnDisk,
})
codCmd.AddCmd(&ishell.Cmd{Name: "disable",
Help: "disable the local cache",
Func: fe.disableCacheOnDisk,
})
codCmd.AddCmd(&ishell.Cmd{Name: "change-location",
Help: "change the location of the local cache",
Func: fe.setCacheOnDiskLocation,
})
fe.AddCmd(codCmd)
// Updates commands.
updatesCmd := &ishell.Cmd{Name: "updates",
Help: "manage bridge updates",

View File

@ -19,6 +19,7 @@ package cli
import (
"fmt"
"os"
"strconv"
"strings"
@ -155,6 +156,67 @@ func (f *frontendCLI) disallowProxy(c *ishell.Context) {
}
}
func (f *frontendCLI) enableCacheOnDisk(c *ishell.Context) {
if f.settings.GetBool(settings.CacheEnabledKey) {
f.Println("The local cache is already enabled.")
return
}
if f.yesNoQuestion("Are you sure you want to enable the local cache") {
// Set this back to the default location before enabling.
f.settings.Set(settings.CacheLocationKey, "")
if err := f.bridge.EnableCache(); err != nil {
f.Println("The local cache could not be enabled.")
return
}
f.settings.SetBool(settings.CacheEnabledKey, true)
f.restarter.SetToRestart()
f.Stop()
}
}
func (f *frontendCLI) disableCacheOnDisk(c *ishell.Context) {
if !f.settings.GetBool(settings.CacheEnabledKey) {
f.Println("The local cache is already disabled.")
return
}
if f.yesNoQuestion("Are you sure you want to disable the local cache") {
if err := f.bridge.DisableCache(); err != nil {
f.Println("The local cache could not be disabled.")
return
}
f.settings.SetBool(settings.CacheEnabledKey, false)
f.restarter.SetToRestart()
f.Stop()
}
}
func (f *frontendCLI) setCacheOnDiskLocation(c *ishell.Context) {
if !f.settings.GetBool(settings.CacheEnabledKey) {
f.Println("The local cache must be enabled.")
return
}
if location := f.settings.Get(settings.CacheLocationKey); location != "" {
f.Println("The current local cache location is:", location)
}
if location := f.readStringInAttempts("Enter a new location for the cache", c.ReadLine, f.isCacheLocationUsable); location != "" {
if err := f.bridge.MigrateCache(f.settings.Get(settings.CacheLocationKey), location); err != nil {
f.Println("The local cache location could not be changed.")
return
}
f.settings.Set(settings.CacheLocationKey, location)
f.restarter.SetToRestart()
f.Stop()
}
}
func (f *frontendCLI) isPortFree(port string) bool {
port = strings.ReplaceAll(port, ":", "")
if port == "" || port == currentPort {
@ -171,3 +233,13 @@ func (f *frontendCLI) isPortFree(port string) bool {
}
return true
}
// NOTE(GODT-1158): Check free space in location.
func (f *frontendCLI) isCacheLocationUsable(location string) bool {
stat, err := os.Stat(location)
if err != nil {
return false
}
return stat.IsDir()
}

View File

@ -77,6 +77,9 @@ type Bridger interface {
ReportBug(osType, osVersion, description, accountName, address, emailClient string) error
AllowProxy()
DisallowProxy()
EnableCache() error
DisableCache() error
MigrateCache(from, to string) error
GetUpdateChannel() updater.UpdateChannel
SetUpdateChannel(updater.UpdateChannel) (needRestart bool, err error)
GetKeychainApp() string

View File

@ -37,21 +37,13 @@ import (
"time"
"github.com/ProtonMail/proton-bridge/internal/bridge"
"github.com/ProtonMail/proton-bridge/internal/config/settings"
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/emersion/go-imap"
goIMAPBackend "github.com/emersion/go-imap/backend"
)
const (
// NOTE: Each fetch worker has its own set of attach workers so there can be up to 20*5=100 API requests at once.
// This is a reasonable limit to not overwhelm API while still maintaining as much parallelism as possible.
fetchWorkers = 20 // In how many workers to fetch message (group list on IMAP).
attachWorkers = 5 // In how many workers to fetch attachments (for one message).
buildWorkers = 20 // In how many workers to build messages.
)
type panicHandler interface {
HandlePanic()
}
@ -61,26 +53,32 @@ type imapBackend struct {
bridge bridger
updates *imapUpdates
eventListener listener.Listener
listWorkers int
users map[string]*imapUser
usersLocker sync.Locker
builder *message.Builder
imapCache map[string]map[string]string
imapCachePath string
imapCacheLock *sync.RWMutex
}
type settingsProvider interface {
GetInt(string) int
}
// NewIMAPBackend returns struct implementing go-imap/backend interface.
func NewIMAPBackend(
panicHandler panicHandler,
eventListener listener.Listener,
cache cacheProvider,
setting settingsProvider,
bridge *bridge.Bridge,
) *imapBackend { //nolint[golint]
bridgeWrap := newBridgeWrap(bridge)
backend := newIMAPBackend(panicHandler, cache, bridgeWrap, eventListener)
imapWorkers := setting.GetInt(settings.IMAPWorkers)
backend := newIMAPBackend(panicHandler, cache, bridgeWrap, eventListener, imapWorkers)
go backend.monitorDisconnectedUsers()
@ -92,6 +90,7 @@ func newIMAPBackend(
cache cacheProvider,
bridge bridger,
eventListener listener.Listener,
listWorkers int,
) *imapBackend {
return &imapBackend{
panicHandler: panicHandler,
@ -102,10 +101,9 @@ func newIMAPBackend(
users: map[string]*imapUser{},
usersLocker: &sync.Mutex{},
builder: message.NewBuilder(fetchWorkers, attachWorkers, buildWorkers),
imapCachePath: cache.GetIMAPCachePath(),
imapCacheLock: &sync.RWMutex{},
listWorkers: listWorkers,
}
}

View File

@ -1,151 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"bytes"
"sort"
"sync"
"time"
pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message"
)
type key struct {
ID string
Timestamp int64
Size int
}
type oldestFirst []key
func (s oldestFirst) Len() int { return len(s) }
func (s oldestFirst) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s oldestFirst) Less(i, j int) bool { return s[i].Timestamp < s[j].Timestamp }
type cachedMessage struct {
key
data []byte
structure pkgMsg.BodyStructure
}
//nolint[gochecknoglobals]
var (
cacheTimeLimit = int64(1 * 60 * 60 * 1000) // milliseconds
cacheSizeLimit = 100 * 1000 * 1000 // B - MUST be larger than email max size limit (~ 25 MB)
mailCache = make(map[string]cachedMessage)
// cacheMutex takes care of one single operation, whereas buildMutex takes
// care of the whole action doing multiple operations. buildMutex will protect
// you from asking server or decrypting or building the same message more
// than once. When first request to build the message comes, it will block
// all other build requests. When the first one is done, all others are
// handled by cache, not doing anything twice. With cacheMutex we are safe
// only to not mess up with the cache, but we could end up downloading and
// building message twice.
cacheMutex = &sync.Mutex{}
buildMutex = &sync.Mutex{}
buildLocks = map[string]interface{}{}
)
func (m *cachedMessage) isValidOrDel() bool {
if m.key.Timestamp+cacheTimeLimit < timestamp() {
delete(mailCache, m.key.ID)
return false
}
return true
}
func timestamp() int64 {
return time.Now().UnixNano() / int64(time.Millisecond)
}
func Clear() {
mailCache = make(map[string]cachedMessage)
}
// BuildLock locks per message level, not on global level.
// Multiple different messages can be building at once.
func BuildLock(messageID string) {
for {
buildMutex.Lock()
if _, ok := buildLocks[messageID]; ok { // if locked, wait
buildMutex.Unlock()
time.Sleep(10 * time.Millisecond)
} else { // if unlocked, lock it
buildLocks[messageID] = struct{}{}
buildMutex.Unlock()
return
}
}
}
func BuildUnlock(messageID string) {
buildMutex.Lock()
defer buildMutex.Unlock()
delete(buildLocks, messageID)
}
func LoadMail(mID string) (reader *bytes.Reader, structure *pkgMsg.BodyStructure) {
reader = &bytes.Reader{}
cacheMutex.Lock()
defer cacheMutex.Unlock()
if message, ok := mailCache[mID]; ok && message.isValidOrDel() {
reader = bytes.NewReader(message.data)
structure = &message.structure
// Update timestamp to keep emails which are used often.
message.Timestamp = timestamp()
}
return
}
func SaveMail(mID string, msg []byte, structure *pkgMsg.BodyStructure) {
cacheMutex.Lock()
defer cacheMutex.Unlock()
newMessage := cachedMessage{
key: key{
ID: mID,
Timestamp: timestamp(),
Size: len(msg),
},
data: msg,
structure: *structure,
}
// Remove old and reduce size.
totalSize := 0
messageList := []key{}
for _, message := range mailCache {
if message.isValidOrDel() {
messageList = append(messageList, message.key)
totalSize += message.key.Size
}
}
sort.Sort(oldestFirst(messageList))
var oldest key
for totalSize+newMessage.key.Size >= cacheSizeLimit {
oldest, messageList = messageList[0], messageList[1:]
delete(mailCache, oldest.ID)
totalSize -= oldest.Size
}
// Write new.
mailCache[mID] = newMessage
}

View File

@ -1,98 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"fmt"
"testing"
"time"
pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/stretchr/testify/require"
)
var bs = &pkgMsg.BodyStructure{} //nolint[gochecknoglobals]
const testUID = "testmsg"
func TestSaveAndLoad(t *testing.T) {
msg := []byte("Test message")
SaveMail(testUID, msg, bs)
require.Equal(t, mailCache[testUID].data, msg)
reader, _ := LoadMail(testUID)
require.Equal(t, reader.Len(), len(msg))
stored := make([]byte, len(msg))
_, _ = reader.Read(stored)
require.Equal(t, stored, msg)
}
func TestMissing(t *testing.T) {
reader, _ := LoadMail("non-existing")
require.Equal(t, reader.Len(), 0)
}
func TestClearOld(t *testing.T) {
cacheTimeLimit = 10
msg := []byte("Test message")
SaveMail(testUID, msg, bs)
time.Sleep(100 * time.Millisecond)
reader, _ := LoadMail(testUID)
require.Equal(t, reader.Len(), 0)
}
func TestClearBig(t *testing.T) {
r := require.New(t)
wantMessage := []byte("Test message")
wantCacheSize := 3
nTestMessages := wantCacheSize * wantCacheSize
cacheSizeLimit = wantCacheSize*len(wantMessage) + 1
cacheTimeLimit = int64(1 << 20) // be sure the message will survive
// It should never have more than nSize items.
for i := 0; i < nTestMessages; i++ {
time.Sleep(1 * time.Millisecond)
SaveMail(fmt.Sprintf("%s%d", testUID, i), wantMessage, bs)
r.LessOrEqual(len(mailCache), wantCacheSize, "cache too big when %d", i)
}
// Check that the oldest are deleted first.
for i := 0; i < nTestMessages; i++ {
iUID := fmt.Sprintf("%s%d", testUID, i)
reader, _ := LoadMail(iUID)
mail := mailCache[iUID]
if i < (nTestMessages - wantCacheSize) {
r.Zero(reader.Len(), "LoadMail should return empty, but have %s for %s time %d ", string(mail.data), iUID, mail.key.Timestamp)
} else {
stored := make([]byte, len(wantMessage))
_, err := reader.Read(stored)
r.NoError(err)
r.Equal(wantMessage, stored, "LoadMail returned wrong message: %s for %s time %d", stored, iUID, mail.key.Timestamp)
}
}
}
func TestConcurency(t *testing.T) {
msg := []byte("Test message")
for i := 0; i < 10; i++ {
go SaveMail(fmt.Sprintf("%s%d", testUID, i), msg, bs)
}
}

View File

@ -37,12 +37,10 @@ type imapMailbox struct {
storeUser storeUserProvider
storeAddress storeAddressProvider
storeMailbox storeMailboxProvider
builder *message.Builder
}
// newIMAPMailbox returns struct implementing go-imap/mailbox interface.
func newIMAPMailbox(panicHandler panicHandler, user *imapUser, storeMailbox storeMailboxProvider, builder *message.Builder) *imapMailbox {
func newIMAPMailbox(panicHandler panicHandler, user *imapUser, storeMailbox storeMailboxProvider) *imapMailbox {
return &imapMailbox{
panicHandler: panicHandler,
user: user,
@ -56,8 +54,6 @@ func newIMAPMailbox(panicHandler panicHandler, user *imapUser, storeMailbox stor
storeUser: user.storeUser,
storeAddress: user.storeAddress,
storeMailbox: storeMailbox,
builder: builder,
}
}

View File

@ -19,21 +19,13 @@ package imap
import (
"bytes"
"context"
"github.com/ProtonMail/proton-bridge/internal/imap/cache"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/emersion/go-imap"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
func (im *imapMailbox) getMessage(
storeMessage storeMessageProvider,
items []imap.FetchItem,
msgBuildCountHistogram *msgBuildCountHistogram,
) (msg *imap.Message, err error) {
func (im *imapMailbox) getMessage(storeMessage storeMessageProvider, items []imap.FetchItem) (msg *imap.Message, err error) {
msglog := im.log.WithField("msgID", storeMessage.ID())
msglog.Trace("Getting message")
@ -69,9 +61,12 @@ func (im *imapMailbox) getMessage(
// There is no point having message older than RFC itself, it's not possible.
msg.InternalDate = message.SanitizeMessageDate(m.Time)
case imap.FetchRFC822Size:
if msg.Size, err = im.getSize(storeMessage); err != nil {
size, err := storeMessage.GetRFC822Size()
if err != nil {
return nil, err
}
msg.Size = size
case imap.FetchUid:
if msg.Uid, err = storeMessage.UID(); err != nil {
return nil, err
@ -79,7 +74,7 @@ func (im *imapMailbox) getMessage(
case imap.FetchAll, imap.FetchFast, imap.FetchFull, imap.FetchRFC822, imap.FetchRFC822Header, imap.FetchRFC822Text:
fallthrough // this is list of defined items by go-imap, but items can be also sections generated from requests
default:
if err = im.getLiteralForSection(item, msg, storeMessage, msgBuildCountHistogram); err != nil {
if err = im.getLiteralForSection(item, msg, storeMessage); err != nil {
return
}
}
@ -88,35 +83,7 @@ func (im *imapMailbox) getMessage(
return msg, err
}
// getSize returns cached size or it will build the message, save the size in
// DB and then returns the size after build.
//
// We are storing size in DB as part of pmapi messages metada. The size
// attribute on the server represents size of encrypted body. The value is
// cleared in Bridge and the final decrypted size (including header, attachment
// and MIME structure) is computed after building the message.
func (im *imapMailbox) getSize(storeMessage storeMessageProvider) (uint32, error) {
m := storeMessage.Message()
if m.Size <= 0 {
im.log.WithField("msgID", m.ID).Debug("Size unknown - downloading body")
// We are sure the size is not a problem right now. Clients
// might not first check sizes of all messages so we couldn't
// be sure if seeing 1st or 2nd sync is all right or not.
// Therefore, it's better to exclude getting size from the
// counting and see build count as real message build.
if _, _, err := im.getBodyAndStructure(storeMessage, nil); err != nil {
return 0, err
}
}
return uint32(m.Size), nil
}
func (im *imapMailbox) getLiteralForSection(
itemSection imap.FetchItem,
msg *imap.Message,
storeMessage storeMessageProvider,
msgBuildCountHistogram *msgBuildCountHistogram,
) error {
func (im *imapMailbox) getLiteralForSection(itemSection imap.FetchItem, msg *imap.Message, storeMessage storeMessageProvider) error {
section, err := imap.ParseBodySectionName(itemSection)
if err != nil {
log.WithError(err).Warn("Failed to parse body section name; part will be skipped")
@ -124,7 +91,7 @@ func (im *imapMailbox) getLiteralForSection(
}
var literal imap.Literal
if literal, err = im.getMessageBodySection(storeMessage, section, msgBuildCountHistogram); err != nil {
if literal, err = im.getMessageBodySection(storeMessage, section); err != nil {
return err
}
@ -149,88 +116,25 @@ func (im *imapMailbox) getBodyStructure(storeMessage storeMessageProvider) (bs *
// be sure if seeing 1st or 2nd sync is all right or not.
// Therefore, it's better to exclude first body structure fetch
// from the counting and see build count as real message build.
if bs, _, err = im.getBodyAndStructure(storeMessage, nil); err != nil {
if bs, _, err = im.getBodyAndStructure(storeMessage); err != nil {
return
}
}
return
}
func (im *imapMailbox) getBodyAndStructure(
storeMessage storeMessageProvider, msgBuildCountHistogram *msgBuildCountHistogram,
) (
structure *message.BodyStructure, bodyReader *bytes.Reader, err error,
) {
m := storeMessage.Message()
id := im.storeUser.UserID() + m.ID
cache.BuildLock(id)
defer cache.BuildUnlock(id)
bodyReader, structure = cache.LoadMail(id)
// return the message which was found in cache
if bodyReader.Len() != 0 && structure != nil {
return structure, bodyReader, nil
func (im *imapMailbox) getBodyAndStructure(storeMessage storeMessageProvider) (*message.BodyStructure, *bytes.Reader, error) {
rfc822, err := storeMessage.GetRFC822()
if err != nil {
return nil, nil, err
}
structure, body, err := im.buildMessage(m)
bodyReader = bytes.NewReader(body)
size := int64(len(body))
l := im.log.WithField("newSize", size).WithField("msgID", m.ID)
if err != nil || structure == nil || size == 0 {
l.WithField("hasStructure", structure != nil).Warn("Failed to build message")
return structure, bodyReader, err
structure, err := storeMessage.GetBodyStructure()
if err != nil {
return nil, nil, err
}
// Save the size, body structure and header even for messages which
// were unable to decrypt. Hence they doesn't have to be computed every
// time.
m.Size = size
cacheMessageInStore(storeMessage, structure, body, l)
if msgBuildCountHistogram != nil {
times, errCount := storeMessage.IncreaseBuildCount()
if errCount != nil {
l.WithError(errCount).Warn("Cannot increase build count")
}
msgBuildCountHistogram.add(times)
}
// Drafts can change therefore we don't want to cache them.
if !isMessageInDraftFolder(m) {
cache.SaveMail(id, body, structure)
}
return structure, bodyReader, err
}
func cacheMessageInStore(storeMessage storeMessageProvider, structure *message.BodyStructure, body []byte, l *logrus.Entry) {
m := storeMessage.Message()
if errSize := storeMessage.SetSize(m.Size); errSize != nil {
l.WithError(errSize).Warn("Cannot update size while building")
}
if structure != nil && !isMessageInDraftFolder(m) {
if errStruct := storeMessage.SetBodyStructure(structure); errStruct != nil {
l.WithError(errStruct).Warn("Cannot update bodystructure while building")
}
}
header, errHead := structure.GetMailHeaderBytes(bytes.NewReader(body))
if errHead == nil && len(header) != 0 {
if errStore := storeMessage.SetHeader(header); errStore != nil {
l.WithError(errStore).Warn("Cannot update header in store")
}
} else {
l.WithError(errHead).Warn("Cannot get header bytes from structure")
}
}
func isMessageInDraftFolder(m *pmapi.Message) bool {
for _, labelID := range m.LabelIDs {
if labelID == pmapi.DraftLabel {
return true
}
}
return false
return structure, bytes.NewReader(rfc822), nil
}
// This will download message (or read from cache) and pick up the section,
@ -246,11 +150,7 @@ func isMessageInDraftFolder(m *pmapi.Message) bool {
// For all other cases it is necessary to download and decrypt the message
// and drop the header which was obtained from cache. The header will
// will be stored in DB once successfully built. Check `getBodyAndStructure`.
func (im *imapMailbox) getMessageBodySection(
storeMessage storeMessageProvider,
section *imap.BodySectionName,
msgBuildCountHistogram *msgBuildCountHistogram,
) (imap.Literal, error) {
func (im *imapMailbox) getMessageBodySection(storeMessage storeMessageProvider, section *imap.BodySectionName) (imap.Literal, error) {
var header []byte
var response []byte
@ -260,7 +160,7 @@ func (im *imapMailbox) getMessageBodySection(
if isMainHeaderRequested && storeMessage.IsFullHeaderCached() {
header = storeMessage.GetHeader()
} else {
structure, bodyReader, err := im.getBodyAndStructure(storeMessage, msgBuildCountHistogram)
structure, bodyReader, err := im.getBodyAndStructure(storeMessage)
if err != nil {
return nil, err
}
@ -276,7 +176,7 @@ func (im *imapMailbox) getMessageBodySection(
case section.Specifier == imap.MIMESpecifier: // The MIME part specifier refers to the [MIME-IMB] header for this part.
fallthrough
case section.Specifier == imap.HeaderSpecifier:
header, err = structure.GetSectionHeaderBytes(bodyReader, section.Path)
header, err = structure.GetSectionHeaderBytes(section.Path)
default:
err = errors.New("Unknown specifier " + string(section.Specifier))
}
@ -293,30 +193,3 @@ func (im *imapMailbox) getMessageBodySection(
// Trim any output if requested.
return bytes.NewBuffer(section.ExtractPartial(response)), nil
}
// buildMessage from PM to IMAP.
func (im *imapMailbox) buildMessage(m *pmapi.Message) (*message.BodyStructure, []byte, error) {
body, err := im.builder.NewJobWithOptions(
context.Background(),
im.user.client(),
m.ID,
message.JobOptions{
IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead.
SanitizeDate: true, // Whether to replace all dates before 1970 with RFC822's birthdate.
AddInternalID: true, // Whether to include MessageID as X-Pm-Internal-Id.
AddExternalID: true, // Whether to include ExternalID as X-Pm-External-Id.
AddMessageDate: true, // Whether to include message time as X-Pm-Date.
AddMessageIDReference: true, // Whether to include the MessageID in References.
},
).GetResult()
if err != nil {
return nil, nil, err
}
structure, err := message.NewBodyStructure(bytes.NewReader(body))
if err != nil {
return nil, nil, err
}
return structure, body, nil
}

View File

@ -479,11 +479,16 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria)
}
// Filter by size (only if size was already calculated).
if m.Size > 0 {
if criteria.Larger != 0 && m.Size <= int64(criteria.Larger) {
size, err := storeMessage.GetRFC822Size()
if err != nil {
return nil, err
}
if size > 0 {
if criteria.Larger != 0 && int64(size) <= int64(criteria.Larger) {
continue
}
if criteria.Smaller != 0 && m.Size >= int64(criteria.Smaller) {
if criteria.Smaller != 0 && int64(size) >= int64(criteria.Smaller) {
continue
}
}
@ -513,13 +518,12 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria)
//
// Messages must be sent to msgResponse. When the function returns, msgResponse must be closed.
func (im *imapMailbox) ListMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message) error {
msgBuildCountHistogram := newMsgBuildCountHistogram()
return im.logCommand(func() error {
return im.listMessages(isUID, seqSet, items, msgResponse, msgBuildCountHistogram)
}, "FETCH", isUID, seqSet, items, msgBuildCountHistogram)
return im.listMessages(isUID, seqSet, items, msgResponse)
}, "FETCH", isUID, seqSet, items)
}
func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message, msgBuildCountHistogram *msgBuildCountHistogram) (err error) { //nolint[funlen]
func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []imap.FetchItem, msgResponse chan<- *imap.Message) (err error) { //nolint[funlen]
defer func() {
close(msgResponse)
if err != nil {
@ -564,7 +568,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima
return nil, err
}
msg, err := im.getMessage(storeMessage, items, msgBuildCountHistogram)
msg, err := im.getMessage(storeMessage, items)
if err != nil {
err = fmt.Errorf("list message build: %v", err)
l.WithField("metaID", storeMessage.ID()).Error(err)
@ -594,7 +598,7 @@ func (im *imapMailbox) listMessages(isUID bool, seqSet *imap.SeqSet, items []ima
return nil
}
err = parallel.RunParallel(fetchWorkers, input, processCallback, collectCallback)
err = parallel.RunParallel(im.user.backend.listWorkers, input, processCallback, collectCallback)
if err != nil {
return err
}

View File

@ -1,65 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imap
import (
"fmt"
"sync"
)
// msgBuildCountHistogram is used to analyse and log the number of repetitive
// downloads of requested messages per one fetch. The number of builds per each
// messageID is stored in persistent database. The msgBuildCountHistogram will
// take this number for each message in ongoing fetch and create histogram of
// repeats.
//
// Example: During `fetch 1:300` there were
// - 100 messages were downloaded first time
// - 100 messages were downloaded second time
// - 99 messages were downloaded 10th times
// - 1 messages were downloaded 100th times.
type msgBuildCountHistogram struct {
// Key represents how many times message was build.
// Value stores how many messages are build X times based on the key.
counts map[uint32]uint32
lock sync.Locker
}
func newMsgBuildCountHistogram() *msgBuildCountHistogram {
return &msgBuildCountHistogram{
counts: map[uint32]uint32{},
lock: &sync.Mutex{},
}
}
func (c *msgBuildCountHistogram) String() string {
res := ""
for nRebuild, counts := range c.counts {
if res != "" {
res += ", "
}
res += fmt.Sprintf("[%d]:%d", nRebuild, counts)
}
return res
}
func (c *msgBuildCountHistogram) add(nRebuild uint32) {
c.lock.Lock()
defer c.lock.Unlock()
c.counts[nRebuild]++
}

View File

@ -80,7 +80,6 @@ type storeMailboxProvider interface {
GetDelimiter() string
GetMessage(apiID string) (storeMessageProvider, error)
FetchMessage(apiID string) (storeMessageProvider, error)
LabelMessages(apiID []string) error
UnlabelMessages(apiID []string) error
MarkMessagesRead(apiID []string) error
@ -100,14 +99,12 @@ type storeMessageProvider interface {
Message() *pmapi.Message
IsMarkedDeleted() bool
SetSize(int64) error
SetHeader([]byte) error
GetHeader() []byte
GetRFC822() ([]byte, error)
GetRFC822Size() (uint32, error)
GetMIMEHeader() textproto.MIMEHeader
IsFullHeaderCached() bool
SetBodyStructure(*pkgMsg.BodyStructure) error
GetBodyStructure() (*pkgMsg.BodyStructure, error)
IncreaseBuildCount() (uint32, error)
}
type storeUserWrap struct {
@ -165,7 +162,3 @@ func newStoreMailboxWrap(mailbox *store.Mailbox) *storeMailboxWrap {
func (s *storeMailboxWrap) GetMessage(apiID string) (storeMessageProvider, error) {
return s.Mailbox.GetMessage(apiID)
}
func (s *storeMailboxWrap) FetchMessage(apiID string) (storeMessageProvider, error) {
return s.Mailbox.FetchMessage(apiID)
}

View File

@ -135,7 +135,7 @@ func (iu *imapUser) ListMailboxes(showOnlySubcribed bool) ([]goIMAPBackend.Mailb
if showOnlySubcribed && !iu.isSubscribed(storeMailbox.LabelID()) {
continue
}
mailbox := newIMAPMailbox(iu.panicHandler, iu, storeMailbox, iu.backend.builder)
mailbox := newIMAPMailbox(iu.panicHandler, iu, storeMailbox)
mailboxes = append(mailboxes, mailbox)
}
@ -167,7 +167,7 @@ func (iu *imapUser) GetMailbox(name string) (mb goIMAPBackend.Mailbox, err error
return
}
return newIMAPMailbox(iu.panicHandler, iu, storeMailbox, iu.backend.builder), nil
return newIMAPMailbox(iu.panicHandler, iu, storeMailbox), nil
}
// CreateMailbox creates a new mailbox.

View File

@ -18,99 +18,113 @@
package store
import (
"encoding/json"
"os"
"sync"
"github.com/pkg/errors"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/sirupsen/logrus"
bolt "go.etcd.io/bbolt"
)
// Cache caches the last event IDs for all accounts (there should be only one instance).
type Cache struct {
// cache is map from userID => key (such as last event) => value (such as event ID).
cache map[string]map[string]string
path string
lock *sync.RWMutex
}
const passphraseKey = "passphrase"
// NewCache constructs a new cache at the given path.
func NewCache(path string) *Cache {
return &Cache{
path: path,
lock: &sync.RWMutex{},
}
}
func (c *Cache) getEventID(userID string) string {
c.lock.Lock()
defer c.lock.Unlock()
if err := c.loadCache(); err != nil {
log.WithError(err).Warn("Problem to load store cache")
}
if c.cache == nil {
c.cache = map[string]map[string]string{}
}
if c.cache[userID] == nil {
c.cache[userID] = map[string]string{}
}
return c.cache[userID]["events"]
}
func (c *Cache) setEventID(userID, eventID string) error {
c.lock.Lock()
defer c.lock.Unlock()
if c.cache[userID] == nil {
c.cache[userID] = map[string]string{}
}
c.cache[userID]["events"] = eventID
return c.saveCache()
}
func (c *Cache) loadCache() error {
if c.cache != nil {
return nil
}
f, err := os.Open(c.path)
// UnlockCache unlocks the cache for the user with the given keyring.
func (store *Store) UnlockCache(kr *crypto.KeyRing) error {
passphrase, err := store.getCachePassphrase()
if err != nil {
return err
}
defer f.Close() //nolint[errcheck]
return json.NewDecoder(f).Decode(&c.cache)
}
if passphrase == nil {
if passphrase, err = crypto.RandomToken(32); err != nil {
return err
}
func (c *Cache) saveCache() error {
if c.cache == nil {
return errors.New("events: cannot save cache: cache is nil")
enc, err := kr.Encrypt(crypto.NewPlainMessage(passphrase), nil)
if err != nil {
return err
}
if err := store.setCachePassphrase(enc.GetBinary()); err != nil {
return err
}
} else {
dec, err := kr.Decrypt(crypto.NewPGPMessage(passphrase), nil, crypto.GetUnixTime())
if err != nil {
return err
}
passphrase = dec.GetBinary()
}
f, err := os.Create(c.path)
if err := store.cache.Unlock(store.user.ID(), passphrase); err != nil {
return err
}
store.cacher.start()
return nil
}
func (store *Store) getCachePassphrase() ([]byte, error) {
var passphrase []byte
if err := store.db.View(func(tx *bolt.Tx) error {
passphrase = tx.Bucket(cachePassphraseBucket).Get([]byte(passphraseKey))
return nil
}); err != nil {
return nil, err
}
return passphrase, nil
}
func (store *Store) setCachePassphrase(passphrase []byte) error {
return store.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(cachePassphraseBucket).Put([]byte(passphraseKey), passphrase)
})
}
func (store *Store) clearCachePassphrase() error {
return store.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(cachePassphraseBucket).Delete([]byte(passphraseKey))
})
}
func (store *Store) getCachedMessage(messageID string) ([]byte, error) {
if store.cache.Has(store.user.ID(), messageID) {
return store.cache.Get(store.user.ID(), messageID)
}
job, done := store.newBuildJob(messageID, message.ForegroundPriority)
defer done()
literal, err := job.GetResult()
if err != nil {
return nil, err
}
// NOTE(GODT-1158): No need to block until cache has been set; do this async?
if err := store.cache.Set(store.user.ID(), messageID, literal); err != nil {
logrus.WithError(err).Error("Failed to cache message")
}
return literal, nil
}
// IsCached returns whether the given message already exists in the cache.
func (store *Store) IsCached(messageID string) bool {
return store.cache.Has(store.user.ID(), messageID)
}
// BuildAndCacheMessage builds the given message (with background priority) and puts it in the cache.
// It builds with background priority.
func (store *Store) BuildAndCacheMessage(messageID string) error {
job, done := store.newBuildJob(messageID, message.BackgroundPriority)
defer done()
literal, err := job.GetResult()
if err != nil {
return err
}
defer f.Close() //nolint[errcheck]
return json.NewEncoder(f).Encode(c.cache)
}
func (c *Cache) clearCacheUser(userID string) error {
c.lock.Lock()
defer c.lock.Unlock()
if c.cache == nil {
log.WithField("user", userID).Warning("Cannot clear user from cache: cache is nil")
return nil
}
log.WithField("user", userID).Trace("Removing user from event loop cache")
delete(c.cache, userID)
return c.saveCache()
return store.cache.Set(store.user.ID(), messageID, literal)
}

73
internal/store/cache/cache_test.go vendored Normal file
View File

@ -0,0 +1,73 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOnDiskCacheNoCompression(t *testing.T) {
cache, err := NewOnDiskCache(t.TempDir(), &NoopCompressor{}, Options{ConcurrentRead: runtime.NumCPU(), ConcurrentWrite: runtime.NumCPU()})
require.NoError(t, err)
testCache(t, cache)
}
func TestOnDiskCacheGZipCompression(t *testing.T) {
cache, err := NewOnDiskCache(t.TempDir(), &GZipCompressor{}, Options{ConcurrentRead: runtime.NumCPU(), ConcurrentWrite: runtime.NumCPU()})
require.NoError(t, err)
testCache(t, cache)
}
func TestInMemoryCache(t *testing.T) {
testCache(t, NewInMemoryCache(1<<20))
}
func testCache(t *testing.T, cache Cache) {
assert.NoError(t, cache.Unlock("userID1", []byte("my secret passphrase")))
assert.NoError(t, cache.Unlock("userID2", []byte("my other passphrase")))
getSetCachedMessage(t, cache, "userID1", "messageID1", "some secret")
assert.True(t, cache.Has("userID1", "messageID1"))
getSetCachedMessage(t, cache, "userID2", "messageID2", "some other secret")
assert.True(t, cache.Has("userID2", "messageID2"))
assert.NoError(t, cache.Rem("userID1", "messageID1"))
assert.False(t, cache.Has("userID1", "messageID1"))
assert.NoError(t, cache.Rem("userID2", "messageID2"))
assert.False(t, cache.Has("userID2", "messageID2"))
assert.NoError(t, cache.Delete("userID1"))
assert.NoError(t, cache.Delete("userID2"))
}
func getSetCachedMessage(t *testing.T, cache Cache, userID, messageID, secret string) {
assert.NoError(t, cache.Set(userID, messageID, []byte(secret)))
data, err := cache.Get(userID, messageID)
assert.NoError(t, err)
assert.Equal(t, []byte(secret), data)
}

33
internal/store/cache/compressor.go vendored Normal file
View File

@ -0,0 +1,33 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
type Compressor interface {
Compress([]byte) ([]byte, error)
Decompress([]byte) ([]byte, error)
}
type NoopCompressor struct{}
func (NoopCompressor) Compress(dec []byte) ([]byte, error) {
return dec, nil
}
func (NoopCompressor) Decompress(cmp []byte) ([]byte, error) {
return cmp, nil
}

60
internal/store/cache/compressor_gzip.go vendored Normal file
View File

@ -0,0 +1,60 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"bytes"
"compress/gzip"
)
type GZipCompressor struct{}
func (GZipCompressor) Compress(dec []byte) ([]byte, error) {
buf := new(bytes.Buffer)
zw := gzip.NewWriter(buf)
if _, err := zw.Write(dec); err != nil {
return nil, err
}
if err := zw.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func (GZipCompressor) Decompress(cmp []byte) ([]byte, error) {
zr, err := gzip.NewReader(bytes.NewReader(cmp))
if err != nil {
return nil, err
}
buf := new(bytes.Buffer)
if _, err := buf.ReadFrom(zr); err != nil {
return nil, err
}
if err := zr.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

244
internal/store/cache/disk.go vendored Normal file
View File

@ -0,0 +1,244 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"errors"
"io/ioutil"
"os"
"path/filepath"
"sync"
"github.com/ProtonMail/proton-bridge/pkg/semaphore"
"github.com/ricochet2200/go-disk-usage/du"
)
var ErrLowSpace = errors.New("not enough free space left on device")
type onDiskCache struct {
path string
opts Options
gcm map[string]cipher.AEAD
cmp Compressor
rsem, wsem semaphore.Semaphore
pending *pending
diskSize uint64
diskFree uint64
once *sync.Once
lock sync.Mutex
}
func NewOnDiskCache(path string, cmp Compressor, opts Options) (Cache, error) {
if err := os.MkdirAll(path, 0700); err != nil {
return nil, err
}
usage := du.NewDiskUsage(path)
// NOTE(GODT-1158): use Available() or Free()?
return &onDiskCache{
path: path,
opts: opts,
gcm: make(map[string]cipher.AEAD),
cmp: cmp,
rsem: semaphore.New(opts.ConcurrentRead),
wsem: semaphore.New(opts.ConcurrentWrite),
pending: newPending(),
diskSize: usage.Size(),
diskFree: usage.Available(),
once: &sync.Once{},
}, nil
}
func (c *onDiskCache) Unlock(userID string, passphrase []byte) error {
hash := sha256.New()
if _, err := hash.Write(passphrase); err != nil {
return err
}
aes, err := aes.NewCipher(hash.Sum(nil))
if err != nil {
return err
}
gcm, err := cipher.NewGCM(aes)
if err != nil {
return err
}
if err := os.MkdirAll(c.getUserPath(userID), 0700); err != nil {
return err
}
c.gcm[userID] = gcm
return nil
}
func (c *onDiskCache) Delete(userID string) error {
defer c.update()
return os.RemoveAll(c.getUserPath(userID))
}
// Has returns whether the given message exists in the cache.
func (c *onDiskCache) Has(userID, messageID string) bool {
c.pending.wait(c.getMessagePath(userID, messageID))
c.rsem.Lock()
defer c.rsem.Unlock()
_, err := os.Stat(c.getMessagePath(userID, messageID))
switch {
case err == nil:
return true
case os.IsNotExist(err):
return false
default:
panic(err)
}
}
func (c *onDiskCache) Get(userID, messageID string) ([]byte, error) {
enc, err := c.readFile(c.getMessagePath(userID, messageID))
if err != nil {
return nil, err
}
cmp, err := c.gcm[userID].Open(nil, enc[:c.gcm[userID].NonceSize()], enc[c.gcm[userID].NonceSize():], nil)
if err != nil {
return nil, err
}
return c.cmp.Decompress(cmp)
}
func (c *onDiskCache) Set(userID, messageID string, literal []byte) error {
nonce := make([]byte, c.gcm[userID].NonceSize())
if _, err := rand.Read(nonce); err != nil {
return err
}
cmp, err := c.cmp.Compress(literal)
if err != nil {
return err
}
// NOTE(GODT-1158): How to properly handle low space? Don't return error, that's bad. Instead send event?
if !c.hasSpace(len(cmp)) {
return nil
}
return c.writeFile(c.getMessagePath(userID, messageID), c.gcm[userID].Seal(nonce, nonce, cmp, nil))
}
func (c *onDiskCache) Rem(userID, messageID string) error {
defer c.update()
return os.Remove(c.getMessagePath(userID, messageID))
}
func (c *onDiskCache) readFile(path string) ([]byte, error) {
c.rsem.Lock()
defer c.rsem.Unlock()
// Wait before reading in case the file is currently being written.
c.pending.wait(path)
return ioutil.ReadFile(filepath.Clean(path))
}
func (c *onDiskCache) writeFile(path string, b []byte) error {
c.wsem.Lock()
defer c.wsem.Unlock()
// Mark the file as currently being written.
// If it's already being written, wait for it to be done and return nil.
// NOTE(GODT-1158): Let's hope it succeeded...
if ok := c.pending.add(path); !ok {
c.pending.wait(path)
return nil
}
defer c.pending.done(path)
// Reduce the approximate free space (update it exactly later).
c.lock.Lock()
c.diskFree -= uint64(len(b))
c.lock.Unlock()
// Update the diskFree eventually.
defer c.update()
// NOTE(GODT-1158): What happens when this fails? Should be fixed eventually.
return ioutil.WriteFile(filepath.Clean(path), b, 0600)
}
func (c *onDiskCache) hasSpace(size int) bool {
c.lock.Lock()
defer c.lock.Unlock()
if c.opts.MinFreeAbs > 0 {
if c.diskFree-uint64(size) < c.opts.MinFreeAbs {
return false
}
}
if c.opts.MinFreeRat > 0 {
if float64(c.diskFree-uint64(size))/float64(c.diskSize) < c.opts.MinFreeRat {
return false
}
}
return true
}
func (c *onDiskCache) update() {
go func() {
c.once.Do(func() {
c.lock.Lock()
defer c.lock.Unlock()
// Update the free space.
c.diskFree = du.NewDiskUsage(c.path).Available()
// Reset the Once object (so we can update again).
c.once = &sync.Once{}
})
}()
}
func (c *onDiskCache) getUserPath(userID string) string {
return filepath.Join(c.path, getHash(userID))
}
func (c *onDiskCache) getMessagePath(userID, messageID string) string {
return filepath.Join(c.getUserPath(userID), getHash(messageID))
}

33
internal/store/cache/hash.go vendored Normal file
View File

@ -0,0 +1,33 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"crypto/sha256"
"encoding/hex"
)
func getHash(name string) string {
hash := sha256.New()
if _, err := hash.Write([]byte(name)); err != nil {
panic(err)
}
return hex.EncodeToString(hash.Sum(nil))
}

104
internal/store/cache/memory.go vendored Normal file
View File

@ -0,0 +1,104 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"errors"
"sync"
)
type inMemoryCache struct {
lock sync.RWMutex
data map[string]map[string][]byte
size, limit int
}
// NewInMemoryCache creates a new in memory cache which stores up to the given number of bytes of cached data.
// NOTE(GODT-1158): Make this threadsafe.
func NewInMemoryCache(limit int) Cache {
return &inMemoryCache{
data: make(map[string]map[string][]byte),
limit: limit,
}
}
func (c *inMemoryCache) Unlock(userID string, passphrase []byte) error {
c.data[userID] = make(map[string][]byte)
return nil
}
func (c *inMemoryCache) Delete(userID string) error {
c.lock.Lock()
defer c.lock.Unlock()
for _, message := range c.data[userID] {
c.size -= len(message)
}
delete(c.data, userID)
return nil
}
// Has returns whether the given message exists in the cache.
func (c *inMemoryCache) Has(userID, messageID string) bool {
if _, err := c.Get(userID, messageID); err != nil {
return false
}
return true
}
func (c *inMemoryCache) Get(userID, messageID string) ([]byte, error) {
c.lock.RLock()
defer c.lock.RUnlock()
literal, ok := c.data[userID][messageID]
if !ok {
return nil, errors.New("no such message in cache")
}
return literal, nil
}
// NOTE(GODT-1158): What to actually do when memory limit is reached? Replace something existing? Return error? Drop silently?
// NOTE(GODT-1158): Pull in cache-rotating feature from old IMAP cache.
func (c *inMemoryCache) Set(userID, messageID string, literal []byte) error {
c.lock.Lock()
defer c.lock.Unlock()
if c.size+len(literal) > c.limit {
return nil
}
c.size += len(literal)
c.data[userID][messageID] = literal
return nil
}
func (c *inMemoryCache) Rem(userID, messageID string) error {
c.lock.Lock()
defer c.lock.Unlock()
c.size -= len(c.data[userID][messageID])
delete(c.data[userID], messageID)
return nil
}

25
internal/store/cache/options.go vendored Normal file
View File

@ -0,0 +1,25 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
type Options struct {
MinFreeAbs uint64
MinFreeRat float64
ConcurrentRead int
ConcurrentWrite int
}

61
internal/store/cache/pending.go vendored Normal file
View File

@ -0,0 +1,61 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import "sync"
type pending struct {
lock sync.Mutex
path map[string]chan struct{}
}
func newPending() *pending {
return &pending{path: make(map[string]chan struct{})}
}
func (p *pending) add(path string) bool {
p.lock.Lock()
defer p.lock.Unlock()
if _, ok := p.path[path]; ok {
return false
}
p.path[path] = make(chan struct{})
return true
}
func (p *pending) wait(path string) {
p.lock.Lock()
ch, ok := p.path[path]
p.lock.Unlock()
if ok {
<-ch
}
}
func (p *pending) done(path string) {
p.lock.Lock()
defer p.lock.Unlock()
defer close(p.path[path])
delete(p.path, path)
}

51
internal/store/cache/pending_test.go vendored Normal file
View File

@ -0,0 +1,51 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestPending(t *testing.T) {
pending := newPending()
pending.add("1")
pending.add("2")
pending.add("3")
resCh := make(chan string)
go func() { pending.wait("1"); resCh <- "1" }()
go func() { pending.wait("2"); resCh <- "2" }()
go func() { pending.wait("3"); resCh <- "3" }()
pending.done("1")
assert.Equal(t, "1", <-resCh)
pending.done("2")
assert.Equal(t, "2", <-resCh)
pending.done("3")
assert.Equal(t, "3", <-resCh)
}
func TestPendingUnknown(t *testing.T) {
newPending().wait("this is not currently being waited")
}

28
internal/store/cache/types.go vendored Normal file
View File

@ -0,0 +1,28 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package cache
type Cache interface {
Unlock(userID string, passphrase []byte) error
Delete(userID string) error
Has(userID, messageID string) bool
Get(userID, messageID string) ([]byte, error)
Set(userID, messageID string, literal []byte) error
Rem(userID, messageID string) error
}

View File

@ -0,0 +1,63 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import "time"
func (store *Store) StartWatcher() {
store.done = make(chan struct{})
go func() {
ticker := time.NewTicker(3 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
// NOTE(GODT-1158): Race condition here? What if DB was already closed?
messageIDs, err := store.getAllMessageIDs()
if err != nil {
return
}
for _, messageID := range messageIDs {
if !store.IsCached(messageID) {
store.cacher.newJob(messageID)
}
}
case <-store.done:
return
}
}
}()
}
func (store *Store) stopWatcher() {
if store.done == nil {
return
}
select {
default:
close(store.done)
case <-store.done:
return
}
}

View File

@ -0,0 +1,104 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"sync"
"github.com/sirupsen/logrus"
)
type Cacher struct {
storer Storer
jobs chan string
done chan struct{}
started bool
wg *sync.WaitGroup
}
type Storer interface {
IsCached(messageID string) bool
BuildAndCacheMessage(messageID string) error
}
func newCacher(storer Storer) *Cacher {
return &Cacher{
storer: storer,
jobs: make(chan string),
done: make(chan struct{}),
wg: &sync.WaitGroup{},
}
}
// newJob sends a new job to the cacher if it's running.
func (cacher *Cacher) newJob(messageID string) {
if !cacher.started {
return
}
select {
case <-cacher.done:
return
default:
if !cacher.storer.IsCached(messageID) {
cacher.wg.Add(1)
go func() { cacher.jobs <- messageID }()
}
}
}
func (cacher *Cacher) start() {
cacher.started = true
go func() {
for {
select {
case messageID := <-cacher.jobs:
go cacher.handleJob(messageID)
case <-cacher.done:
return
}
}
}()
}
func (cacher *Cacher) handleJob(messageID string) {
defer cacher.wg.Done()
if err := cacher.storer.BuildAndCacheMessage(messageID); err != nil {
logrus.WithError(err).Error("Failed to build and cache message")
} else {
logrus.WithField("messageID", messageID).Trace("Message cached")
}
}
func (cacher *Cacher) stop() {
cacher.started = false
cacher.wg.Wait()
select {
case <-cacher.done:
return
default:
close(cacher.done)
}
}

View File

@ -0,0 +1,103 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"testing"
storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
"github.com/golang/mock/gomock"
"github.com/pkg/errors"
)
func withTestCacher(t *testing.T, doTest func(storer *storemocks.MockStorer, cacher *Cacher)) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
// Mock storer used to build/cache messages.
storer := storemocks.NewMockStorer(ctrl)
// Create a new cacher pointing to the fake store.
cacher := newCacher(storer)
// Start the cacher and wait for it to stop.
cacher.start()
defer cacher.stop()
doTest(storer, cacher)
}
func TestCacher(t *testing.T) {
// If the message is not yet cached, we should expect to try to build and cache it.
withTestCacher(t, func(storer *storemocks.MockStorer, cacher *Cacher) {
storer.EXPECT().IsCached("messageID").Return(false)
storer.EXPECT().BuildAndCacheMessage("messageID").Return(nil)
cacher.newJob("messageID")
})
}
func TestCacherAlreadyCached(t *testing.T) {
// If the message is already cached, we should not try to build it.
withTestCacher(t, func(storer *storemocks.MockStorer, cacher *Cacher) {
storer.EXPECT().IsCached("messageID").Return(true)
cacher.newJob("messageID")
})
}
func TestCacherFail(t *testing.T) {
// If building the message fails, we should not try to cache it.
withTestCacher(t, func(storer *storemocks.MockStorer, cacher *Cacher) {
storer.EXPECT().IsCached("messageID").Return(false)
storer.EXPECT().BuildAndCacheMessage("messageID").Return(errors.New("failed to build message"))
cacher.newJob("messageID")
})
}
func TestCacherStop(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
// Mock storer used to build/cache messages.
storer := storemocks.NewMockStorer(ctrl)
// Create a new cacher pointing to the fake store.
cacher := newCacher(storer)
// Start the cacher.
cacher.start()
// Send a job -- this should succeed.
storer.EXPECT().IsCached("messageID").Return(false)
storer.EXPECT().BuildAndCacheMessage("messageID").Return(nil)
cacher.newJob("messageID")
// Stop the cacher.
cacher.stop()
// Send more jobs -- these should all be dropped.
cacher.newJob("messageID2")
cacher.newJob("messageID3")
cacher.newJob("messageID4")
cacher.newJob("messageID5")
// Stopping the cacher multiple times is safe.
cacher.stop()
cacher.stop()
cacher.stop()
cacher.stop()
}

View File

@ -34,7 +34,7 @@ func TestNotifyChangeCreateOrUpdateMessage(t *testing.T) {
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(1), uint32(1), gomock.Any(), false)
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(2), uint32(2), gomock.Any(), false)
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
m.store.SetChangeNotifier(m.changeNotifier)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
@ -49,7 +49,7 @@ func TestNotifyChangeCreateOrUpdateMessages(t *testing.T) {
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(1), uint32(1), gomock.Any(), false)
m.changeNotifier.EXPECT().UpdateMessage(addr1, "All Mail", uint32(2), uint32(2), gomock.Any(), false)
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
m.store.SetChangeNotifier(m.changeNotifier)
msg1 := getTestMessage("msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
@ -61,7 +61,7 @@ func TestNotifyChangeDeleteMessage(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})

View File

@ -38,7 +38,7 @@ const (
)
type eventLoop struct {
cache *Cache
currentEvents *Events
currentEventID string
currentEvent *pmapi.Event
pollCh chan chan struct{}
@ -51,26 +51,26 @@ type eventLoop struct {
log *logrus.Entry
store *Store
user BridgeUser
events listener.Listener
store *Store
user BridgeUser
listener listener.Listener
}
func newEventLoop(cache *Cache, store *Store, user BridgeUser, events listener.Listener) *eventLoop {
func newEventLoop(currentEvents *Events, store *Store, user BridgeUser, listener listener.Listener) *eventLoop {
eventLog := log.WithField("userID", user.ID())
eventLog.Trace("Creating new event loop")
return &eventLoop{
cache: cache,
currentEventID: cache.getEventID(user.ID()),
currentEvents: currentEvents,
currentEventID: currentEvents.getEventID(user.ID()),
pollCh: make(chan chan struct{}),
isRunning: false,
log: eventLog,
store: store,
user: user,
events: events,
store: store,
user: user,
listener: listener,
}
}
@ -89,7 +89,7 @@ func (loop *eventLoop) setFirstEventID() (err error) {
loop.currentEventID = event.EventID
if err = loop.cache.setEventID(loop.user.ID(), loop.currentEventID); err != nil {
if err = loop.currentEvents.setEventID(loop.user.ID(), loop.currentEventID); err != nil {
loop.log.WithError(err).Error("Could not set latest event ID in user cache")
return
}
@ -229,7 +229,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
if err != nil && isFdCloseToULimit() {
l.Warn("Ulimit reached")
loop.events.Emit(bridgeEvents.RestartBridgeEvent, "")
loop.listener.Emit(bridgeEvents.RestartBridgeEvent, "")
err = nil
}
@ -291,7 +291,7 @@ func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[fun
// This allows the event loop to continue to function (unless the cache was broken
// and bridge stopped, in which case it will start from the old event ID anyway).
loop.currentEventID = event.EventID
if err = loop.cache.setEventID(loop.user.ID(), event.EventID); err != nil {
if err = loop.currentEvents.setEventID(loop.user.ID(), event.EventID); err != nil {
return false, errors.Wrap(err, "failed to save event ID to cache")
}
}
@ -371,7 +371,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
switch addressEvent.Action {
case pmapi.EventCreate:
log.WithField("email", addressEvent.Address.Email).Debug("Address was created")
loop.events.Emit(bridgeEvents.AddressChangedEvent, loop.user.GetPrimaryAddress())
loop.listener.Emit(bridgeEvents.AddressChangedEvent, loop.user.GetPrimaryAddress())
case pmapi.EventUpdate:
oldAddress := oldList.ByID(addressEvent.ID)
@ -383,7 +383,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
email := oldAddress.Email
log.WithField("email", email).Debug("Address was updated")
if addressEvent.Address.Receive != oldAddress.Receive {
loop.events.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
loop.listener.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
}
case pmapi.EventDelete:
@ -396,7 +396,7 @@ func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmap
email := oldAddress.Email
log.WithField("email", email).Debug("Address was deleted")
loop.user.CloseConnection(email)
loop.events.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
loop.listener.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
case pmapi.EventUpdateFlags:
log.Error("EventUpdateFlags for address event is uknown operation")
}

View File

@ -53,7 +53,7 @@ func TestEventLoopProcessMoreEvents(t *testing.T) {
More: false,
}, nil),
)
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
// Event loop runs in goroutine started during store creation (newStoreNoEvents).
// Force to run the next event.
@ -78,7 +78,7 @@ func TestEventLoopUpdateMessageFromLoop(t *testing.T) {
subject := "old subject"
newSubject := "new subject"
m.newStoreNoEvents(true, &pmapi.Message{
m.newStoreNoEvents(t, true, &pmapi.Message{
ID: "msg1",
Subject: subject,
})
@ -106,7 +106,7 @@ func TestEventLoopDeletionNotPaused(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true, &pmapi.Message{
m.newStoreNoEvents(t, true, &pmapi.Message{
ID: "msg1",
Subject: "subject",
LabelIDs: []string{"label"},
@ -133,7 +133,7 @@ func TestEventLoopDeletionPaused(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true, &pmapi.Message{
m.newStoreNoEvents(t, true, &pmapi.Message{
ID: "msg1",
Subject: "subject",
LabelIDs: []string{"label"},

116
internal/store/events.go Normal file
View File

@ -0,0 +1,116 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"encoding/json"
"os"
"sync"
"github.com/pkg/errors"
)
// Events caches the last event IDs for all accounts (there should be only one instance).
type Events struct {
// eventMap is map from userID => key (such as last event) => value (such as event ID).
eventMap map[string]map[string]string
path string
lock *sync.RWMutex
}
// NewEvents constructs a new event cache at the given path.
func NewEvents(path string) *Events {
return &Events{
path: path,
lock: &sync.RWMutex{},
}
}
func (c *Events) getEventID(userID string) string {
c.lock.Lock()
defer c.lock.Unlock()
if err := c.loadEvents(); err != nil {
log.WithError(err).Warn("Problem to load store events")
}
if c.eventMap == nil {
c.eventMap = map[string]map[string]string{}
}
if c.eventMap[userID] == nil {
c.eventMap[userID] = map[string]string{}
}
return c.eventMap[userID]["events"]
}
func (c *Events) setEventID(userID, eventID string) error {
c.lock.Lock()
defer c.lock.Unlock()
if c.eventMap[userID] == nil {
c.eventMap[userID] = map[string]string{}
}
c.eventMap[userID]["events"] = eventID
return c.saveEvents()
}
func (c *Events) loadEvents() error {
if c.eventMap != nil {
return nil
}
f, err := os.Open(c.path)
if err != nil {
return err
}
defer f.Close() //nolint[errcheck]
return json.NewDecoder(f).Decode(&c.eventMap)
}
func (c *Events) saveEvents() error {
if c.eventMap == nil {
return errors.New("events: cannot save events: events map is nil")
}
f, err := os.Create(c.path)
if err != nil {
return err
}
defer f.Close() //nolint[errcheck]
return json.NewEncoder(f).Encode(c.eventMap)
}
func (c *Events) clearUserEvents(userID string) error {
c.lock.Lock()
defer c.lock.Unlock()
if c.eventMap == nil {
log.WithField("user", userID).Warning("Cannot clear user events: event map is nil")
return nil
}
log.WithField("user", userID).Trace("Removing user events from event loop")
delete(c.eventMap, userID)
return c.saveEvents()
}

View File

@ -107,7 +107,7 @@ func checkCounts(t testing.TB, wantCounts []*pmapi.MessagesCount, haveStore *Sto
func TestMailboxCountRemove(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
testCounts := []*pmapi.MessagesCount{
{LabelID: "label1", Total: 100, Unread: 0},

View File

@ -35,7 +35,7 @@ func TestGetSequenceNumberAndGetUID(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
@ -80,7 +80,7 @@ func TestGetUIDByHeader(t *testing.T) { //nolint[funlen]
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
tstMsg := getTestMessage("msg1", "Without external ID", addrID1, false, []string{pmapi.AllMailLabel, pmapi.SentLabel})
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))

View File

@ -67,40 +67,19 @@ func (message *Message) Message() *pmapi.Message {
return message.msg
}
// IsMarkedDeleted returns true if message is marked as deleted for specific
// mailbox.
// IsMarkedDeleted returns true if message is marked as deleted for specific mailbox.
func (message *Message) IsMarkedDeleted() bool {
isMarkedAsDeleted := false
err := message.storeMailbox.db().View(func(tx *bolt.Tx) error {
var isMarkedAsDeleted bool
if err := message.storeMailbox.db().View(func(tx *bolt.Tx) error {
isMarkedAsDeleted = message.storeMailbox.txGetDeletedIDsBucket(tx).Get([]byte(message.msg.ID)) != nil
return nil
})
if err != nil {
}); err != nil {
message.storeMailbox.log.WithError(err).Error("Not able to retrieve deleted mark, assuming false.")
return false
}
return isMarkedAsDeleted
}
// SetSize updates the information about size of decrypted message which can be
// used for IMAP. This should not trigger any IMAP update.
// NOTE: The size from the server corresponds to pure body bytes. Hence it
// should not be used. The correct size has to be calculated from decrypted and
// built message.
func (message *Message) SetSize(size int64) error {
message.msg.Size = size
txUpdate := func(tx *bolt.Tx) error {
stored, err := message.store.txGetMessage(tx, message.msg.ID)
if err != nil {
return err
}
stored.Size = size
return message.store.txPutMessage(
tx.Bucket(metadataBucket),
stored,
)
}
return message.store.db.Update(txUpdate)
return isMarkedAsDeleted
}
// SetContentTypeAndHeader updates the information about content type and
@ -112,7 +91,7 @@ func (message *Message) SetSize(size int64) error {
func (message *Message) SetContentTypeAndHeader(mimeType string, header mail.Header) error {
message.msg.MIMEType = mimeType
message.msg.Header = header
txUpdate := func(tx *bolt.Tx) error {
return message.store.db.Update(func(tx *bolt.Tx) error {
stored, err := message.store.txGetMessage(tx, message.msg.ID)
if err != nil {
return err
@ -123,34 +102,26 @@ func (message *Message) SetContentTypeAndHeader(mimeType string, header mail.Hea
tx.Bucket(metadataBucket),
stored,
)
}
return message.store.db.Update(txUpdate)
}
// SetHeader checks header can be parsed and if yes it stores header bytes in
// database.
func (message *Message) SetHeader(header []byte) error {
_, err := textproto.NewReader(bufio.NewReader(bytes.NewReader(header))).ReadMIMEHeader()
if err != nil {
return err
}
return message.store.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(headersBucket).Put([]byte(message.ID()), header)
})
}
// IsFullHeaderCached will check that valid full header is stored in DB.
func (message *Message) IsFullHeaderCached() bool {
header, err := message.getRawHeader()
return err == nil && header != nil
}
func (message *Message) getRawHeader() (raw []byte, err error) {
err = message.store.db.View(func(tx *bolt.Tx) error {
raw = tx.Bucket(headersBucket).Get([]byte(message.ID()))
var raw []byte
err := message.store.db.View(func(tx *bolt.Tx) error {
raw = tx.Bucket(bodystructureBucket).Get([]byte(message.ID()))
return nil
})
return
return err == nil && raw != nil
}
func (message *Message) getRawHeader() ([]byte, error) {
bs, err := message.GetBodyStructure()
if err != nil {
return nil, err
}
return bs.GetMailHeaderBytes()
}
// GetHeader will return cached header from DB.
@ -178,44 +149,79 @@ func (message *Message) GetMIMEHeader() textproto.MIMEHeader {
return header
}
// SetBodyStructure stores serialized body structure in database.
func (message *Message) SetBodyStructure(bs *pkgMsg.BodyStructure) error {
txUpdate := func(tx *bolt.Tx) error {
return message.store.txPutBodyStructure(
tx.Bucket(bodystructureBucket),
message.ID(), bs,
)
}
return message.store.db.Update(txUpdate)
}
// GetBodyStructure returns the message's body structure.
// It checks first if it's in the store. If it is, it returns it from store,
// otherwise it computes it from the message cache (and saves the result to the store).
func (message *Message) GetBodyStructure() (*pkgMsg.BodyStructure, error) {
var raw []byte
// GetBodyStructure deserializes body structure from database. If body structure
// is not in database it returns nil error and nil body structure. If error
// occurs it returns nil body structure.
func (message *Message) GetBodyStructure() (bs *pkgMsg.BodyStructure, err error) {
txRead := func(tx *bolt.Tx) error {
bs, err = message.store.txGetBodyStructure(
tx.Bucket(bodystructureBucket),
message.ID(),
)
return err
}
if err = message.store.db.View(txRead); err != nil {
if err := message.store.db.View(func(tx *bolt.Tx) error {
raw = tx.Bucket(bodystructureBucket).Get([]byte(message.ID()))
return nil
}); err != nil {
return nil, err
}
if len(raw) > 0 {
// If not possible to deserialize just continue with build.
if bs, err := pkgMsg.DeserializeBodyStructure(raw); err == nil {
return bs, nil
}
}
literal, err := message.store.getCachedMessage(message.ID())
if err != nil {
return nil, err
}
bs, err := pkgMsg.NewBodyStructure(bytes.NewReader(literal))
if err != nil {
return nil, err
}
if raw, err = bs.Serialize(); err != nil {
return nil, err
}
if err := message.store.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(bodystructureBucket).Put([]byte(message.ID()), raw)
}); err != nil {
return nil, err
}
return bs, nil
}
func (message *Message) IncreaseBuildCount() (times uint32, err error) {
txUpdate := func(tx *bolt.Tx) error {
times, err = message.store.txIncreaseMsgBuildCount(
tx.Bucket(msgBuildCountBucket),
message.ID(),
)
return err
}
if err = message.store.db.Update(txUpdate); err != nil {
// GetRFC822 returns the raw message literal.
func (message *Message) GetRFC822() ([]byte, error) {
return message.store.getCachedMessage(message.ID())
}
// GetRFC822Size returns the size of the raw message literal.
func (message *Message) GetRFC822Size() (uint32, error) {
var raw []byte
if err := message.store.db.View(func(tx *bolt.Tx) error {
raw = tx.Bucket(sizeBucket).Get([]byte(message.ID()))
return nil
}); err != nil {
return 0, err
}
return times, nil
if len(raw) > 0 {
return btoi(raw), nil
}
literal, err := message.store.getCachedMessage(message.ID())
if err != nil {
return 0, err
}
if err := message.store.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(sizeBucket).Put([]byte(message.ID()), itob(uint32(len(literal))))
}); err != nil {
return 0, err
}
return uint32(len(literal)), nil
}

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser,ChangeNotifier)
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser,ChangeNotifier,Storer)
// Package mocks is a generated GoMock package.
package mocks
@ -318,3 +318,54 @@ func (mr *MockChangeNotifierMockRecorder) UpdateMessage(arg0, arg1, arg2, arg3,
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMessage", reflect.TypeOf((*MockChangeNotifier)(nil).UpdateMessage), arg0, arg1, arg2, arg3, arg4, arg5)
}
// MockStorer is a mock of Storer interface.
type MockStorer struct {
ctrl *gomock.Controller
recorder *MockStorerMockRecorder
}
// MockStorerMockRecorder is the mock recorder for MockStorer.
type MockStorerMockRecorder struct {
mock *MockStorer
}
// NewMockStorer creates a new mock instance.
func NewMockStorer(ctrl *gomock.Controller) *MockStorer {
mock := &MockStorer{ctrl: ctrl}
mock.recorder = &MockStorerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStorer) EXPECT() *MockStorerMockRecorder {
return m.recorder
}
// BuildAndCacheMessage mocks base method.
func (m *MockStorer) BuildAndCacheMessage(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BuildAndCacheMessage", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// BuildAndCacheMessage indicates an expected call of BuildAndCacheMessage.
func (mr *MockStorerMockRecorder) BuildAndCacheMessage(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BuildAndCacheMessage", reflect.TypeOf((*MockStorer)(nil).BuildAndCacheMessage), arg0)
}
// IsCached mocks base method.
func (m *MockStorer) IsCached(arg0 string) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsCached", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// IsCached indicates an expected call of IsCached.
func (mr *MockStorerMockRecorder) IsCached(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsCached", reflect.TypeOf((*MockStorer)(nil).IsCached), arg0)
}

View File

@ -26,8 +26,11 @@ import (
"time"
"github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/pkg/pool"
"github.com/hashicorp/go-multierror"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
@ -52,19 +55,21 @@ var (
// Database structure:
// * metadata
// * {messageID} -> message data (subject, from, to, time, body size, ...)
// * {messageID} -> message data (subject, from, to, time, ...)
// * headers
// * {messageID} -> header bytes
// * bodystructure
// * {messageID} -> message body structure
// * msgbuildcount
// * {messageID} -> uint32 number of message builds to track re-sync issues
// * size
// * {messageID} -> uint32 value
// * counts
// * {mailboxID} -> mailboxCounts: totalOnAPI, unreadOnAPI, labelName, labelColor, labelIsExclusive
// * address_info
// * {index} -> {address, addressID}
// * address_mode
// * mode -> string split or combined
// * cache_passphrase
// * passphrase -> cache passphrase (pgp encrypted message)
// * mailboxes_version
// * version -> uint32 value
// * sync_state
@ -79,19 +84,20 @@ var (
// * {messageID} -> uint32 imapUID
// * deleted_ids (can be missing or have no keys)
// * {messageID} -> true
metadataBucket = []byte("metadata") //nolint[gochecknoglobals]
headersBucket = []byte("headers") //nolint[gochecknoglobals]
bodystructureBucket = []byte("bodystructure") //nolint[gochecknoglobals]
msgBuildCountBucket = []byte("msgbuildcount") //nolint[gochecknoglobals]
countsBucket = []byte("counts") //nolint[gochecknoglobals]
addressInfoBucket = []byte("address_info") //nolint[gochecknoglobals]
addressModeBucket = []byte("address_mode") //nolint[gochecknoglobals]
syncStateBucket = []byte("sync_state") //nolint[gochecknoglobals]
mailboxesBucket = []byte("mailboxes") //nolint[gochecknoglobals]
imapIDsBucket = []byte("imap_ids") //nolint[gochecknoglobals]
apiIDsBucket = []byte("api_ids") //nolint[gochecknoglobals]
deletedIDsBucket = []byte("deleted_ids") //nolint[gochecknoglobals]
mboxVersionBucket = []byte("mailboxes_version") //nolint[gochecknoglobals]
metadataBucket = []byte("metadata") //nolint[gochecknoglobals]
headersBucket = []byte("headers") //nolint[gochecknoglobals]
bodystructureBucket = []byte("bodystructure") //nolint[gochecknoglobals]
sizeBucket = []byte("size") //nolint[gochecknoglobals]
countsBucket = []byte("counts") //nolint[gochecknoglobals]
addressInfoBucket = []byte("address_info") //nolint[gochecknoglobals]
addressModeBucket = []byte("address_mode") //nolint[gochecknoglobals]
cachePassphraseBucket = []byte("cache_passphrase") //nolint[gochecknoglobals]
syncStateBucket = []byte("sync_state") //nolint[gochecknoglobals]
mailboxesBucket = []byte("mailboxes") //nolint[gochecknoglobals]
imapIDsBucket = []byte("imap_ids") //nolint[gochecknoglobals]
apiIDsBucket = []byte("api_ids") //nolint[gochecknoglobals]
deletedIDsBucket = []byte("deleted_ids") //nolint[gochecknoglobals]
mboxVersionBucket = []byte("mailboxes_version") //nolint[gochecknoglobals]
// ErrNoSuchAPIID when mailbox does not have API ID.
ErrNoSuchAPIID = errors.New("no such api id") //nolint[gochecknoglobals]
@ -117,18 +123,23 @@ func exposeContextForSMTP() context.Context {
type Store struct {
sentryReporter *sentry.Reporter
panicHandler PanicHandler
eventLoop *eventLoop
user BridgeUser
eventLoop *eventLoop
currentEvents *Events
log *logrus.Entry
cache *Cache
filePath string
db *bolt.DB
lock *sync.RWMutex
addresses map[string]*Address
notifier ChangeNotifier
builder *message.Builder
cache cache.Cache
cacher *Cacher
done chan struct{}
isSyncRunning bool
syncCooldown cooldown
addressMode addressMode
@ -139,12 +150,14 @@ func New( // nolint[funlen]
sentryReporter *sentry.Reporter,
panicHandler PanicHandler,
user BridgeUser,
events listener.Listener,
listener listener.Listener,
cache cache.Cache,
builder *message.Builder,
path string,
cache *Cache,
currentEvents *Events,
) (store *Store, err error) {
if user == nil || events == nil || cache == nil {
return nil, fmt.Errorf("missing parameters - user: %v, events: %v, cache: %v", user, events, cache)
if user == nil || listener == nil || currentEvents == nil {
return nil, fmt.Errorf("missing parameters - user: %v, listener: %v, currentEvents: %v", user, listener, currentEvents)
}
l := log.WithField("user", user.ID())
@ -160,21 +173,29 @@ func New( // nolint[funlen]
bdb, err := openBoltDatabase(path)
if err != nil {
err = errors.Wrap(err, "failed to open store database")
return
return nil, errors.Wrap(err, "failed to open store database")
}
store = &Store{
sentryReporter: sentryReporter,
panicHandler: panicHandler,
user: user,
cache: cache,
filePath: path,
db: bdb,
lock: &sync.RWMutex{},
log: l,
currentEvents: currentEvents,
log: l,
filePath: path,
db: bdb,
lock: &sync.RWMutex{},
builder: builder,
cache: cache,
}
// Create a new cacher. It's not started yet.
// NOTE(GODT-1158): I hate this circular dependency store->cacher->store :(
store.cacher = newCacher(store)
// Minimal increase is event pollInterval, doubles every failed retry up to 5 minutes.
store.syncCooldown.setExponentialWait(pollInterval, 2, 5*time.Minute)
@ -188,7 +209,7 @@ func New( // nolint[funlen]
}
if user.IsConnected() {
store.eventLoop = newEventLoop(cache, store, user, events)
store.eventLoop = newEventLoop(currentEvents, store, user, listener)
go func() {
defer store.panicHandler.HandlePanic()
store.eventLoop.start()
@ -216,10 +237,11 @@ func openBoltDatabase(filePath string) (db *bolt.DB, err error) {
metadataBucket,
headersBucket,
bodystructureBucket,
msgBuildCountBucket,
sizeBucket,
countsBucket,
addressInfoBucket,
addressModeBucket,
cachePassphraseBucket,
syncStateBucket,
mailboxesBucket,
mboxVersionBucket,
@ -365,6 +387,24 @@ func (store *Store) addAddress(address, addressID string, labels []*pmapi.Label)
return
}
// newBuildJob returns a new build job for the given message using the store's message builder.
func (store *Store) newBuildJob(messageID string, priority int) (*message.Job, pool.DoneFunc) {
return store.builder.NewJobWithOptions(
context.Background(),
store.client(),
messageID,
message.JobOptions{
IgnoreDecryptionErrors: true, // Whether to ignore decryption errors and create a "custom message" instead.
SanitizeDate: true, // Whether to replace all dates before 1970 with RFC822's birthdate.
AddInternalID: true, // Whether to include MessageID as X-Pm-Internal-Id.
AddExternalID: true, // Whether to include ExternalID as X-Pm-External-Id.
AddMessageDate: true, // Whether to include message time as X-Pm-Date.
AddMessageIDReference: true, // Whether to include the MessageID in References.
},
priority,
)
}
// Close stops the event loop and closes the database to free the file.
func (store *Store) Close() error {
store.lock.Lock()
@ -381,12 +421,21 @@ func (store *Store) CloseEventLoop() {
}
func (store *Store) close() error {
// Stop the watcher first before closing the database.
store.stopWatcher()
// Stop the cacher.
store.cacher.stop()
// Stop the event loop.
store.CloseEventLoop()
// Close the database.
return store.db.Close()
}
// Remove closes and removes the database file and clears the cache file.
func (store *Store) Remove() (err error) {
func (store *Store) Remove() error {
store.lock.Lock()
defer store.lock.Unlock()
@ -394,22 +443,34 @@ func (store *Store) Remove() (err error) {
var result *multierror.Error
if err = store.close(); err != nil {
if err := store.close(); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to close store"))
}
if err = RemoveStore(store.cache, store.filePath, store.user.ID()); err != nil {
if err := RemoveStore(store.currentEvents, store.filePath, store.user.ID()); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to remove store"))
}
if err := store.RemoveCache(); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to remove cache"))
}
return result.ErrorOrNil()
}
func (store *Store) RemoveCache() error {
if err := store.clearCachePassphrase(); err != nil {
logrus.WithError(err).Error("Failed to clear cache passphrase")
}
return store.cache.Delete(store.user.ID())
}
// RemoveStore removes the database file and clears the cache file.
func RemoveStore(cache *Cache, path, userID string) error {
func RemoveStore(currentEvents *Events, path, userID string) error {
var result *multierror.Error
if err := cache.clearCacheUser(userID); err != nil {
if err := currentEvents.clearUserEvents(userID); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to clear event loop user cache"))
}

View File

@ -23,13 +23,17 @@ import (
"io/ioutil"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
storemocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
tests "github.com/ProtonMail/proton-bridge/test"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
@ -139,7 +143,7 @@ type mocksForStore struct {
store *Store
tmpDir string
cache *Cache
cache *Events
}
func initMocks(tb testing.TB) (*mocksForStore, func()) {
@ -162,7 +166,7 @@ func initMocks(tb testing.TB) (*mocksForStore, func()) {
require.NoError(tb, err)
cacheFile := filepath.Join(mocks.tmpDir, "cache.json")
mocks.cache = NewCache(cacheFile)
mocks.cache = NewEvents(cacheFile)
return mocks, func() {
if err := recover(); err != nil {
@ -176,13 +180,14 @@ func initMocks(tb testing.TB) (*mocksForStore, func()) {
}
}
func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.Message) { //nolint[unparam]
func (mocks *mocksForStore) newStoreNoEvents(t *testing.T, combinedMode bool, msgs ...*pmapi.Message) { //nolint[unparam]
mocks.user.EXPECT().ID().Return("userID").AnyTimes()
mocks.user.EXPECT().IsConnected().Return(true)
mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode)
mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client)
mocks.client.EXPECT().GetUserKeyRing().Return(tests.MakeKeyRing(t), nil).AnyTimes()
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: true},
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: true},
@ -213,6 +218,8 @@ func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool, msgs ...*pmapi.M
mocks.panicHandler,
mocks.user,
mocks.events,
cache.NewInMemoryCache(1<<20),
message.NewBuilder(runtime.NumCPU(), runtime.NumCPU()),
filepath.Join(mocks.tmpDir, "mailbox-test.db"),
mocks.cache,
)

View File

@ -27,7 +27,6 @@ import (
"strings"
"github.com/ProtonMail/gopenpgp/v2/crypto"
pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
@ -154,11 +153,6 @@ func (store *Store) checkDraftTotalSize(message *pmapi.Message, attachments []*d
return false, err
}
msgSize := message.Size
if msgSize == 0 {
msgSize = int64(len(message.Body))
}
var attSize int64
for _, att := range attachments {
b, err := ioutil.ReadAll(att.encReader)
@ -169,7 +163,7 @@ func (store *Store) checkDraftTotalSize(message *pmapi.Message, attachments []*d
att.encReader = bytes.NewBuffer(b)
}
return msgSize+attSize <= maxUpload, nil
return int64(len(message.Body))+attSize <= maxUpload, nil
}
func (store *Store) getDraftAction(message *pmapi.Message) int {
@ -237,39 +231,6 @@ func (store *Store) txPutMessage(metaBucket *bolt.Bucket, onlyMeta *pmapi.Messag
return nil
}
func (store *Store) txPutBodyStructure(bsBucket *bolt.Bucket, msgID string, bs *pkgMsg.BodyStructure) error {
raw, err := bs.Serialize()
if err != nil {
return err
}
err = bsBucket.Put([]byte(msgID), raw)
if err != nil {
return errors.Wrap(err, "cannot put bodystructure bucket")
}
return nil
}
func (store *Store) txGetBodyStructure(bsBucket *bolt.Bucket, msgID string) (*pkgMsg.BodyStructure, error) {
raw := bsBucket.Get([]byte(msgID))
if len(raw) == 0 {
return nil, nil
}
return pkgMsg.DeserializeBodyStructure(raw)
}
func (store *Store) txIncreaseMsgBuildCount(b *bolt.Bucket, msgID string) (uint32, error) {
key := []byte(msgID)
count := uint32(0)
raw := b.Get(key)
if raw != nil {
count = btoi(raw)
}
count++
return count, b.Put(key, itob(count))
}
// createOrUpdateMessageEvent is helper to create only one message with
// createOrUpdateMessagesEvent.
func (store *Store) createOrUpdateMessageEvent(msg *pmapi.Message) error {
@ -287,7 +248,7 @@ func (store *Store) createOrUpdateMessagesEvent(msgs []*pmapi.Message) error { /
b := tx.Bucket(metadataBucket)
for _, msg := range msgs {
clearNonMetadata(msg)
txUpdateMetadaFromDB(b, msg, store.log)
txUpdateMetadataFromDB(b, msg, store.log)
}
return nil
})
@ -341,6 +302,11 @@ func (store *Store) createOrUpdateMessagesEvent(msgs []*pmapi.Message) error { /
return err
}
// Notify the cacher that it should start caching messages.
for _, msg := range msgs {
store.cacher.newJob(msg.ID)
}
return nil
}
@ -351,16 +317,12 @@ func clearNonMetadata(onlyMeta *pmapi.Message) {
onlyMeta.Attachments = nil
}
// txUpdateMetadaFromDB changes the the onlyMeta data.
// txUpdateMetadataFromDB changes the the onlyMeta data.
// If there is stored message in metaBucket the size, header and MIMEType are
// not changed if already set. To change these:
// * size must be updated by Message.SetSize
// * contentType and header must be updated by Message.SetContentTypeAndHeader.
func txUpdateMetadaFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log *logrus.Entry) {
// Size attribute on the server is counting encrypted data. We need to compute
// "real" size of decrypted data. Negative values will be processed during fetch.
onlyMeta.Size = -1
func txUpdateMetadataFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log *logrus.Entry) {
msgb := metaBucket.Get([]byte(onlyMeta.ID))
if msgb == nil {
return
@ -378,8 +340,7 @@ func txUpdateMetadaFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log
return
}
// Keep already calculated size and content type.
onlyMeta.Size = stored.Size
// Keep content type.
onlyMeta.MIMEType = stored.MIMEType
if stored.Header != "" && stored.Header != "(No Header)" {
tmpMsg, err := mail.ReadMessage(
@ -401,6 +362,12 @@ func (store *Store) deleteMessageEvent(apiID string) error {
// deleteMessagesEvent deletes the message from metadata and all mailbox buckets.
func (store *Store) deleteMessagesEvent(apiIDs []string) error {
for _, messageID := range apiIDs {
if err := store.cache.Rem(store.UserID(), messageID); err != nil {
logrus.WithError(err).Error("Failed to remove message from cache")
}
}
return store.db.Update(func(tx *bolt.Tx) error {
for _, apiID := range apiIDs {
if err := tx.Bucket(metadataBucket).Delete([]byte(apiID)); err != nil {

View File

@ -33,7 +33,7 @@ func TestGetAllMessageIDs(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
@ -47,7 +47,7 @@ func TestGetMessageFromDB(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
tests := []struct{ msgID, wantErr string }{
@ -72,7 +72,7 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
msg, err := m.store.getMessageFromDB("msg1")
@ -81,12 +81,10 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
// Check non-meta and calculated data are cleared/empty.
a.Equal(t, "", msg.Body)
a.Equal(t, []*pmapi.Attachment(nil), msg.Attachments)
a.Equal(t, int64(-1), msg.Size)
a.Equal(t, "", msg.MIMEType)
a.Equal(t, make(mail.Header), msg.Header)
// Change the calculated data.
wantSize := int64(42)
wantMIMEType := "plain-text"
wantHeader := mail.Header{
"Key": []string{"value"},
@ -94,13 +92,11 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
storeMsg, err := m.store.addresses[addrID1].mailboxes[pmapi.AllMailLabel].GetMessage("msg1")
require.Nil(t, err)
require.Nil(t, storeMsg.SetSize(wantSize))
require.Nil(t, storeMsg.SetContentTypeAndHeader(wantMIMEType, wantHeader))
// Check calculated data.
msg, err = m.store.getMessageFromDB("msg1")
require.Nil(t, err)
a.Equal(t, wantSize, msg.Size)
a.Equal(t, wantMIMEType, msg.MIMEType)
a.Equal(t, wantHeader, msg.Header)
@ -109,7 +105,6 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
msg, err = m.store.getMessageFromDB("msg1")
require.Nil(t, err)
a.Equal(t, wantSize, msg.Size)
a.Equal(t, wantMIMEType, msg.MIMEType)
a.Equal(t, wantHeader, msg.Header)
}
@ -118,7 +113,7 @@ func TestDeleteMessage(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel})
@ -129,8 +124,7 @@ func TestDeleteMessage(t *testing.T) {
}
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread bool, labelIDs []string) { //nolint[unparam]
msg := getTestMessage(id, subject, sender, unread, labelIDs)
require.Nil(t, m.store.createOrUpdateMessageEvent(msg))
require.Nil(t, m.store.createOrUpdateMessageEvent(getTestMessage(id, subject, sender, unread, labelIDs)))
}
func getTestMessage(id, subject, sender string, unread bool, labelIDs []string) *pmapi.Message {
@ -142,7 +136,6 @@ func getTestMessage(id, subject, sender string, unread bool, labelIDs []string)
Sender: address,
ToList: []*mail.Address{address},
LabelIDs: labelIDs,
Size: 12345,
Body: "body of message",
Attachments: []*pmapi.Attachment{{
ID: "attachment1",
@ -162,7 +155,7 @@ func TestCreateDraftCheckMessageSize(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(false)
m.newStoreNoEvents(t, false)
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
MaxUpload: 100, // Decrypted message 5 chars, encrypted 500+.
}, nil)
@ -181,7 +174,7 @@ func TestCreateDraftCheckMessageWithAttachmentSize(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(false)
m.newStoreNoEvents(t, false)
m.client.EXPECT().CurrentUser(gomock.Any()).Return(&pmapi.User{
MaxUpload: 800, // Decrypted message 5 chars + 5 chars of attachment, encrypted 500+ + 300+.
}, nil)

View File

@ -30,7 +30,7 @@ func TestLoadSaveSyncState(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
m.newStoreNoEvents(t, true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, false, []string{pmapi.AllMailLabel, pmapi.InboxLabel})

View File

@ -107,6 +107,21 @@ func (u *User) connect(client pmapi.Client, creds *credentials.Credentials) erro
return err
}
// If the client is already unlocked, we can unlock the store cache as well.
if client.IsUnlocked() {
kr, err := client.GetUserKeyRing()
if err != nil {
return err
}
if err := u.store.UnlockCache(kr); err != nil {
return err
}
// NOTE(GODT-1158): If using in-memory cache we probably shouldn't start the watcher?
u.store.StartWatcher()
}
return nil
}

View File

@ -32,7 +32,7 @@ func TestUpdateUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
gomock.InOrder(
@ -50,7 +50,7 @@ func TestUserSwitchAddressMode(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
// Ignore any sync on background.
@ -76,7 +76,7 @@ func TestUserSwitchAddressMode(t *testing.T) {
r.False(t, user.creds.IsCombinedAddressMode)
r.False(t, user.IsCombinedAddressMode())
// MOck change to combined mode.
// Mock change to combined mode.
gomock.InOrder(
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "users@pm.me"),
m.eventListener.EXPECT().Emit(events.CloseConnectionEvent, "anotheruser@pm.me"),
@ -98,7 +98,7 @@ func TestLogoutUser(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
gomock.InOrder(
@ -115,7 +115,7 @@ func TestLogoutUserFailsLogout(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
gomock.InOrder(
@ -133,7 +133,7 @@ func TestCheckBridgeLogin(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
err := user.CheckBridgeLogin(testCredentials.BridgePassword)
@ -144,7 +144,7 @@ func TestCheckBridgeLoginUpgradeApplication(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
m.eventListener.EXPECT().Emit(events.UpgradeApplicationEvent, "")
@ -187,7 +187,7 @@ func TestCheckBridgeLoginBadPassword(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
err := user.CheckBridgeLogin("wrong!")

View File

@ -64,7 +64,7 @@ func TestNewUser(t *testing.T) {
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
mockInitConnectedUser(m)
mockInitConnectedUser(t, m)
mockEventLoopNoAction(m)
checkNewUserHasCredentials(m, "", testCredentials)

View File

@ -31,7 +31,7 @@ func TestClearStoreWithStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
r.Nil(t, user.store.Close())
@ -43,7 +43,7 @@ func TestClearStoreWithoutStore(t *testing.T) {
m := initMocks(t)
defer m.ctrl.Finish()
user := testNewUser(m)
user := testNewUser(t, m)
defer cleanUpUserData(user)
r.NotNil(t, user.store)

View File

@ -18,13 +18,15 @@
package users
import (
"testing"
r "github.com/stretchr/testify/require"
)
// testNewUser sets up a new, authorised user.
func testNewUser(m mocks) *User {
func testNewUser(t *testing.T, m mocks) *User {
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil)
mockInitConnectedUser(m)
mockInitConnectedUser(t, m)
mockEventLoopNoAction(m)
user, creds, err := newUser(m.PanicHandler, "user", m.eventListener, m.credentialsStore, m.storeMaker)

View File

@ -20,12 +20,13 @@ package users
import (
"context"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/ProtonMail/proton-bridge/internal/events"
imapcache "github.com/ProtonMail/proton-bridge/internal/imap/cache"
"github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
"github.com/ProtonMail/proton-bridge/pkg/listener"
@ -225,6 +226,7 @@ func (u *Users) FinishLogin(client pmapi.Client, auth *pmapi.Auth, password []by
return nil, errors.Wrap(err, "failed to update password of user in credentials store")
}
// will go and unlock cache if not already done
if err := user.connect(client, creds); err != nil {
return nil, errors.Wrap(err, "failed to reconnect existing user")
}
@ -341,9 +343,6 @@ func (u *Users) ClearData() error {
result = multierror.Append(result, err)
}
// Need to clear imap cache otherwise fetch response will be remembered from previous test.
imapcache.Clear()
return result
}
@ -366,6 +365,7 @@ func (u *Users) DeleteUser(userID string, clearStore bool) error {
if err := user.closeStore(); err != nil {
log.WithError(err).Error("Failed to close user store")
}
if clearStore {
// Clear cache after closing connections (done in logout).
if err := user.clearStore(); err != nil {
@ -427,6 +427,41 @@ func (u *Users) DisallowProxy() {
u.clientManager.DisallowProxy()
}
func (u *Users) EnableCache() error {
// NOTE(GODT-1158): Check for available size before enabling.
return nil
}
func (u *Users) MigrateCache(from, to string) error {
// NOTE(GODT-1158): Is it enough to just close the store? Do we need to force-close the cacher too?
for _, user := range u.users {
if err := user.closeStore(); err != nil {
logrus.WithError(err).Error("Failed to close user's store")
}
}
// Ensure the parent directory exists.
if err := os.MkdirAll(filepath.Dir(to), 0700); err != nil {
return err
}
return os.Rename(from, to)
}
func (u *Users) DisableCache() error {
// NOTE(GODT-1158): Is it an error if we can't remove a user's cache?
for _, user := range u.users {
if err := user.store.RemoveCache(); err != nil {
logrus.WithError(err).Error("Failed to remove user's message cache")
}
}
return nil
}
// hasUser returns whether the struct currently has a user with ID `id`.
func (u *Users) hasUser(id string) (user *User, ok bool) {
for _, u := range u.users {

View File

@ -49,7 +49,7 @@ func TestUsersFinishLoginNewUser(t *testing.T) {
// Init users with no user from keychain.
m.credentialsStore.EXPECT().List().Return([]string{}, nil)
mockAddingConnectedUser(m)
mockAddingConnectedUser(t, m)
mockEventLoopNoAction(m)
m.clientManager.EXPECT().SendSimpleMetric(gomock.Any(), string(metrics.Setup), string(metrics.NewUser), string(metrics.NoLabel))
@ -74,7 +74,7 @@ func TestUsersFinishLoginExistingDisconnectedUser(t *testing.T) {
m.credentialsStore.EXPECT().UpdateToken(testCredentialsDisconnected.UserID, testAuthRefresh.UID, testAuthRefresh.RefreshToken).Return(testCredentials, nil),
m.credentialsStore.EXPECT().UpdatePassword(testCredentialsDisconnected.UserID, testCredentials.MailboxPassword).Return(testCredentials, nil),
)
mockInitConnectedUser(m)
mockInitConnectedUser(t, m)
mockEventLoopNoAction(m)
m.eventListener.EXPECT().Emit(events.UserRefreshEvent, testCredentialsDisconnected.UserID)
@ -95,7 +95,7 @@ func TestUsersFinishLoginConnectedUser(t *testing.T) {
// Mock loading connected user.
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
mockLoadingConnectedUser(m, testCredentials)
mockLoadingConnectedUser(t, m, testCredentials)
mockEventLoopNoAction(m)
// Mock process of FinishLogin of already connected user.

View File

@ -49,7 +49,7 @@ func TestNewUsersWithConnectedUser(t *testing.T) {
defer m.ctrl.Finish()
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID}, nil)
mockLoadingConnectedUser(m, testCredentials)
mockLoadingConnectedUser(t, m, testCredentials)
mockEventLoopNoAction(m)
checkUsersNew(t, m, []*credentials.Credentials{testCredentials})
}
@ -71,7 +71,7 @@ func TestNewUsersWithUsers(t *testing.T) {
m.credentialsStore.EXPECT().List().Return([]string{testCredentialsDisconnected.UserID, testCredentials.UserID}, nil)
mockLoadingDisconnectedUser(m, testCredentialsDisconnected)
mockLoadingConnectedUser(m, testCredentials)
mockLoadingConnectedUser(t, m, testCredentials)
mockEventLoopNoAction(m)
checkUsersNew(t, m, []*credentials.Credentials{testCredentialsDisconnected, testCredentials})
}

View File

@ -21,6 +21,7 @@ import (
"fmt"
"io/ioutil"
"os"
"runtime"
"runtime/debug"
"testing"
"time"
@ -28,10 +29,13 @@ import (
"github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/internal/sentry"
"github.com/ProtonMail/proton-bridge/internal/store"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/internal/users/credentials"
usersmocks "github.com/ProtonMail/proton-bridge/internal/users/mocks"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
pmapimocks "github.com/ProtonMail/proton-bridge/pkg/pmapi/mocks"
tests "github.com/ProtonMail/proton-bridge/test"
gomock "github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
@ -42,9 +46,11 @@ func TestMain(m *testing.M) {
if os.Getenv("VERBOSITY") == "fatal" {
logrus.SetLevel(logrus.FatalLevel)
}
if os.Getenv("VERBOSITY") == "trace" {
logrus.SetLevel(logrus.TraceLevel)
}
os.Exit(m.Run())
}
@ -151,7 +157,7 @@ type mocks struct {
clientManager *pmapimocks.MockManager
pmapiClient *pmapimocks.MockClient
storeCache *store.Cache
storeCache *store.Events
}
func initMocks(t *testing.T) mocks {
@ -178,7 +184,7 @@ func initMocks(t *testing.T) mocks {
clientManager: pmapimocks.NewMockManager(mockCtrl),
pmapiClient: pmapimocks.NewMockClient(mockCtrl),
storeCache: store.NewCache(cacheFile.Name()),
storeCache: store.NewEvents(cacheFile.Name()),
}
// Called during clean-up.
@ -187,9 +193,20 @@ func initMocks(t *testing.T) mocks {
// Set up store factory.
m.storeMaker.EXPECT().New(gomock.Any()).DoAndReturn(func(user store.BridgeUser) (*store.Store, error) {
var sentryReporter *sentry.Reporter // Sentry reporter is not used under unit tests.
dbFile, err := ioutil.TempFile("", "bridge-store-db-*.db")
dbFile, err := ioutil.TempFile(t.TempDir(), "bridge-store-db-*.db")
r.NoError(t, err, "could not get temporary file for store db")
return store.New(sentryReporter, m.PanicHandler, user, m.eventListener, dbFile.Name(), m.storeCache)
return store.New(
sentryReporter,
m.PanicHandler,
user,
m.eventListener,
cache.NewInMemoryCache(1<<20),
message.NewBuilder(runtime.NumCPU(), runtime.NumCPU()),
dbFile.Name(),
m.storeCache,
)
}).AnyTimes()
m.storeMaker.EXPECT().Remove(gomock.Any()).AnyTimes()
@ -212,8 +229,8 @@ func (fr *fullStackReporter) Fatalf(format string, args ...interface{}) {
func testNewUsersWithUsers(t *testing.T, m mocks) *Users {
m.credentialsStore.EXPECT().List().Return([]string{testCredentials.UserID, testCredentialsSplit.UserID}, nil)
mockLoadingConnectedUser(m, testCredentials)
mockLoadingConnectedUser(m, testCredentialsSplit)
mockLoadingConnectedUser(t, m, testCredentials)
mockLoadingConnectedUser(t, m, testCredentialsSplit)
mockEventLoopNoAction(m)
return testNewUsers(t, m)
@ -245,7 +262,7 @@ func cleanUpUsersData(b *Users) {
}
}
func mockAddingConnectedUser(m mocks) {
func mockAddingConnectedUser(t *testing.T, m mocks) {
gomock.InOrder(
// Mock of users.FinishLogin.
m.pmapiClient.EXPECT().AuthSalt(gomock.Any()).Return("", nil),
@ -256,10 +273,10 @@ func mockAddingConnectedUser(m mocks) {
m.credentialsStore.EXPECT().Get("user").Return(testCredentials, nil),
)
mockInitConnectedUser(m)
mockInitConnectedUser(t, m)
}
func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) {
func mockLoadingConnectedUser(t *testing.T, m mocks, creds *credentials.Credentials) {
authRefresh := &pmapi.AuthRefresh{
UID: "uid",
AccessToken: "acc",
@ -273,10 +290,10 @@ func mockLoadingConnectedUser(m mocks, creds *credentials.Credentials) {
m.credentialsStore.EXPECT().UpdateToken(creds.UserID, authRefresh.UID, authRefresh.RefreshToken).Return(creds, nil),
)
mockInitConnectedUser(m)
mockInitConnectedUser(t, m)
}
func mockInitConnectedUser(m mocks) {
func mockInitConnectedUser(t *testing.T, m mocks) {
// Mock of user initialisation.
m.pmapiClient.EXPECT().AddAuthRefreshHandler(gomock.Any())
m.pmapiClient.EXPECT().IsUnlocked().Return(true).AnyTimes()
@ -286,6 +303,7 @@ func mockInitConnectedUser(m mocks) {
m.pmapiClient.EXPECT().ListLabels(gomock.Any()).Return([]*pmapi.Label{}, nil),
m.pmapiClient.EXPECT().CountMessages(gomock.Any(), "").Return([]*pmapi.MessagesCount{}, nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.pmapiClient.EXPECT().GetUserKeyRing().Return(tests.MakeKeyRing(t), nil).AnyTimes(),
)
}