GODT-1482: Comment or mitigate panics, unlock cache when needed.

This commit is contained in:
Jakub
2021-12-21 11:34:21 +01:00
committed by Jakub Cuth
parent e9c05c5a6b
commit df601ecbbd
15 changed files with 312 additions and 51 deletions

View File

@ -40,9 +40,9 @@ func (im *imapMailbox) getMessage(storeMessage storeMessageProvider, items []ima
for _, item := range items { for _, item := range items {
switch item { switch item {
case imap.FetchEnvelope: case imap.FetchEnvelope:
// No need to check IsFullHeaderCached here. API header // No need to retrieve full header here. API header
// contain enough information to build the envelope. // contains enough information to build the envelope.
msg.Envelope = message.GetEnvelope(m, storeMessage.GetMIMEHeader()) msg.Envelope = message.GetEnvelope(m, storeMessage.GetMIMEHeaderFast())
case imap.FetchBody, imap.FetchBodyStructure: case imap.FetchBody, imap.FetchBodyStructure:
structure, err := im.getBodyStructure(storeMessage) structure, err := im.getBodyStructure(storeMessage)
if err != nil { if err != nil {
@ -158,7 +158,10 @@ func (im *imapMailbox) getMessageBodySection(storeMessage storeMessageProvider,
isMainHeaderRequested := len(section.Path) == 0 && section.Specifier == imap.HeaderSpecifier isMainHeaderRequested := len(section.Path) == 0 && section.Specifier == imap.HeaderSpecifier
if isMainHeaderRequested && storeMessage.IsFullHeaderCached() { if isMainHeaderRequested && storeMessage.IsFullHeaderCached() {
header = storeMessage.GetHeader() var err error
if header, err = storeMessage.GetHeader(); err != nil {
return nil, err
}
} else { } else {
structure, bodyReader, err := im.getBodyAndStructure(storeMessage) structure, bodyReader, err := im.getBodyAndStructure(storeMessage)
if err != nil { if err != nil {

View File

@ -383,8 +383,9 @@ func (im *imapMailbox) SearchMessages(isUID bool, criteria *imap.SearchCriteria)
} }
} }
// In order to speed up search it is not needed to check if IsFullHeaderCached. // In order to speed up search it is not needed to always
header := storeMessage.GetMIMEHeader() // retrieve the fully cached header.
header := storeMessage.GetMIMEHeaderFast()
if !criteria.SentBefore.IsZero() || !criteria.SentSince.IsZero() { if !criteria.SentBefore.IsZero() || !criteria.SentSince.IsZero() {
t, err := mail.Header(header).Date() t, err := mail.Header(header).Date()

View File

@ -99,10 +99,10 @@ type storeMessageProvider interface {
Message() *pmapi.Message Message() *pmapi.Message
IsMarkedDeleted() bool IsMarkedDeleted() bool
GetHeader() []byte GetHeader() ([]byte, error)
GetRFC822() ([]byte, error) GetRFC822() ([]byte, error)
GetRFC822Size() (uint32, error) GetRFC822Size() (uint32, error)
GetMIMEHeader() textproto.MIMEHeader GetMIMEHeaderFast() textproto.MIMEHeader
IsFullHeaderCached() bool IsFullHeaderCached() bool
GetBodyStructure() (*pkgMsg.BodyStructure, error) GetBodyStructure() (*pkgMsg.BodyStructure, error)
} }

View File

@ -21,8 +21,8 @@ import (
"context" "context"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/internal/store/cache"
"github.com/ProtonMail/proton-bridge/pkg/message" "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/sirupsen/logrus"
bolt "go.etcd.io/bbolt" bolt "go.etcd.io/bbolt"
) )
@ -110,8 +110,15 @@ func SetBuildAndCacheJobLimit(maxJobs int) {
} }
func (store *Store) getCachedMessage(messageID string) ([]byte, error) { func (store *Store) getCachedMessage(messageID string) ([]byte, error) {
if store.cache.Has(store.user.ID(), messageID) { if store.IsCached(messageID) {
return store.cache.Get(store.user.ID(), messageID) literal, err := store.cache.Get(store.user.ID(), messageID)
if err == nil {
return literal, nil
}
store.log.
WithField("msg", messageID).
WithError(err).
Warn("Message is cached but cannot be retrieved")
} }
job, done := store.newBuildJob(context.Background(), messageID, message.ForegroundPriority) job, done := store.newBuildJob(context.Background(), messageID, message.ForegroundPriority)
@ -123,17 +130,43 @@ func (store *Store) getCachedMessage(messageID string) ([]byte, error) {
} }
if !store.isMessageADraft(messageID) { if !store.isMessageADraft(messageID) {
if err := store.cache.Set(store.user.ID(), messageID, literal); err != nil { if err := store.writeToCacheUnlockIfFails(messageID, literal); err != nil {
logrus.WithError(err).Error("Failed to cache message") store.log.WithError(err).Error("Failed to cache message")
} }
} else {
store.log.Debug("Skipping cache draft message")
} }
return literal, nil return literal, nil
} }
func (store *Store) writeToCacheUnlockIfFails(messageID string, literal []byte) error {
err := store.cache.Set(store.user.ID(), messageID, literal)
if err == nil && err != cache.ErrCacheNeedsUnlock {
return err
}
kr, err := store.client().GetUserKeyRing()
if err != nil {
return err
}
if err := store.UnlockCache(kr); err != nil {
return err
}
return store.cache.Set(store.user.ID(), messageID, literal)
}
// IsCached returns whether the given message already exists in the cache. // IsCached returns whether the given message already exists in the cache.
func (store *Store) IsCached(messageID string) bool { func (store *Store) IsCached(messageID string) (has bool) {
return store.cache.Has(store.user.ID(), messageID) defer func() {
if r := recover(); r != nil {
store.log.WithField("recovered", r).Error("Cannot retrieve whether message exits, assuming no")
}
}()
has = store.cache.Has(store.user.ID(), messageID)
return
} }
// BuildAndCacheMessage builds the given message (with background priority) and puts it in the cache. // BuildAndCacheMessage builds the given message (with background priority) and puts it in the cache.

View File

@ -33,6 +33,7 @@ import (
"github.com/ricochet2200/go-disk-usage/du" "github.com/ricochet2200/go-disk-usage/du"
) )
var ErrMsgCorrupted = errors.New("ecrypted file was corrupted")
var ErrLowSpace = errors.New("not enough free space left on device") var ErrLowSpace = errors.New("not enough free space left on device")
// IsOnDiskCache will return true if Cache is type of onDiskCache. // IsOnDiskCache will return true if Cache is type of onDiskCache.
@ -86,6 +87,10 @@ func NewOnDiskCache(path string, cmp Compressor, opts Options) (Cache, error) {
}, nil }, nil
} }
func (c *onDiskCache) Lock(userID string) {
delete(c.gcm, userID)
}
func (c *onDiskCache) Unlock(userID string, passphrase []byte) error { func (c *onDiskCache) Unlock(userID string, passphrase []byte) error {
hash := sha256.New() hash := sha256.New()
@ -135,17 +140,29 @@ func (c *onDiskCache) Has(userID, messageID string) bool {
return false return false
default: default:
// Cannot decide whether the message is cached or not.
// Potential recover needs to be don in caller function.
panic(err) panic(err)
} }
} }
func (c *onDiskCache) Get(userID, messageID string) ([]byte, error) { func (c *onDiskCache) Get(userID, messageID string) ([]byte, error) {
gcm, ok := c.gcm[userID]
if !ok || gcm == nil {
return nil, ErrCacheNeedsUnlock
}
enc, err := c.readFile(c.getMessagePath(userID, messageID)) enc, err := c.readFile(c.getMessagePath(userID, messageID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
cmp, err := c.gcm[userID].Open(nil, enc[:c.gcm[userID].NonceSize()], enc[c.gcm[userID].NonceSize():], nil) // Data stored in file must larger than NonceSize.
if len(enc) <= gcm.NonceSize() {
return nil, ErrMsgCorrupted
}
cmp, err := gcm.Open(nil, enc[:gcm.NonceSize()], enc[gcm.NonceSize():], nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -154,7 +171,11 @@ func (c *onDiskCache) Get(userID, messageID string) ([]byte, error) {
} }
func (c *onDiskCache) Set(userID, messageID string, literal []byte) error { func (c *onDiskCache) Set(userID, messageID string, literal []byte) error {
nonce := make([]byte, c.gcm[userID].NonceSize()) gcm, ok := c.gcm[userID]
if !ok {
return ErrCacheNeedsUnlock
}
nonce := make([]byte, gcm.NonceSize())
if _, err := rand.Read(nonce); err != nil { if _, err := rand.Read(nonce); err != nil {
return err return err
@ -165,12 +186,13 @@ func (c *onDiskCache) Set(userID, messageID string, literal []byte) error {
return err return err
} }
// NOTE(GODT-1158): How to properly handle low space? Don't return error, that's bad. Instead send event? // NOTE(GODT-1158, GODT-1488): Need to properly handle low space. Don't
// return error, that's bad. Send event and clean least used message.
if !c.hasSpace(len(cmp)) { if !c.hasSpace(len(cmp)) {
return nil return nil
} }
return c.writeFile(c.getMessagePath(userID, messageID), c.gcm[userID].Seal(nonce, nonce, cmp, nil)) return c.writeFile(c.getMessagePath(userID, messageID), gcm.Seal(nonce, nonce, cmp, nil))
} }
func (c *onDiskCache) Rem(userID, messageID string) error { func (c *onDiskCache) Rem(userID, messageID string) error {

View File

@ -26,6 +26,7 @@ func getHash(name string) string {
hash := sha256.New() hash := sha256.New()
if _, err := hash.Write([]byte(name)); err != nil { if _, err := hash.Write([]byte(name)); err != nil {
// sha256.Write always returns nill err so this should never happen
panic(err) panic(err)
} }

View File

@ -28,7 +28,8 @@ type inMemoryCache struct {
size, limit int size, limit int
} }
// NewInMemoryCache creates a new in memory cache which stores up to the given number of bytes of cached data. // NewInMemoryCache creates a new in memory cache which stores up to the given
// number of bytes of cached data.
// NOTE(GODT-1158): Make this threadsafe. // NOTE(GODT-1158): Make this threadsafe.
func NewInMemoryCache(limit int) Cache { func NewInMemoryCache(limit int) Cache {
return &inMemoryCache{ return &inMemoryCache{
@ -42,7 +43,7 @@ func (c *inMemoryCache) Unlock(userID string, passphrase []byte) error {
return nil return nil
} }
func (c *inMemoryCache) Delete(userID string) error { func (c *inMemoryCache) Lock(userID string) {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
@ -51,23 +52,44 @@ func (c *inMemoryCache) Delete(userID string) error {
} }
delete(c.data, userID) delete(c.data, userID)
}
func (c *inMemoryCache) Delete(userID string) error {
c.Lock(userID)
return nil return nil
} }
// Has returns whether the given message exists in the cache. // Has returns whether the given message exists in the cache.
func (c *inMemoryCache) Has(userID, messageID string) bool { func (c *inMemoryCache) Has(userID, messageID string) bool {
if _, err := c.Get(userID, messageID); err != nil { c.lock.RLock()
return false defer c.lock.RUnlock()
if !c.isUserUnlocked(userID) {
// This might look counter intuitive but in order to be able to test
// "re-unlocking" mechanism we need to return true here.
//
// The situation is the same as it would happen for onDiskCache with
// locked user. Later during `Get` cache would return proper error
// `ErrCacheNeedsUnlock`. It is expected that store would then try to
// re-unlock.
//
// In order to do proper behaviour we should implement
// encryption for inMemoryCache.
return true
} }
return true _, ok := c.data[userID][messageID]
return ok
} }
func (c *inMemoryCache) Get(userID, messageID string) ([]byte, error) { func (c *inMemoryCache) Get(userID, messageID string) ([]byte, error) {
c.lock.RLock() c.lock.RLock()
defer c.lock.RUnlock() defer c.lock.RUnlock()
if !c.isUserUnlocked(userID) {
return nil, ErrCacheNeedsUnlock
}
literal, ok := c.data[userID][messageID] literal, ok := c.data[userID][messageID]
if !ok { if !ok {
return nil, errors.New("no such message in cache") return nil, errors.New("no such message in cache")
@ -76,12 +98,23 @@ func (c *inMemoryCache) Get(userID, messageID string) ([]byte, error) {
return literal, nil return literal, nil
} }
// NOTE(GODT-1158): What to actually do when memory limit is reached? Replace something existing? Return error? Drop silently? func (c *inMemoryCache) isUserUnlocked(userID string) bool {
// NOTE(GODT-1158): Pull in cache-rotating feature from old IMAP cache. _, ok := c.data[userID]
return ok
}
// Set saves the message literal to memory for further usage.
//
// NOTE(GODT-1158, GODT-1488): Once memory limit is reached we should do proper
// rotation based on usage frequency.
func (c *inMemoryCache) Set(userID, messageID string, literal []byte) error { func (c *inMemoryCache) Set(userID, messageID string, literal []byte) error {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
if !c.isUserUnlocked(userID) {
return ErrCacheNeedsUnlock
}
if c.size+len(literal) > c.limit { if c.size+len(literal) > c.limit {
return nil return nil
} }
@ -96,6 +129,10 @@ func (c *inMemoryCache) Rem(userID, messageID string) error {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
if !c.isUserUnlocked(userID) {
return nil
}
c.size -= len(c.data[userID][messageID]) c.size -= len(c.data[userID][messageID])
delete(c.data[userID], messageID) delete(c.data[userID], messageID)

View File

@ -17,8 +17,13 @@
package cache package cache
import "errors"
var ErrCacheNeedsUnlock = errors.New("cache needs to be unlocked")
type Cache interface { type Cache interface {
Unlock(userID string, passphrase []byte) error Unlock(userID string, passphrase []byte) error
Lock(userID string)
Delete(userID string) error Delete(userID string) error
Has(userID, messageID string) bool Has(userID, messageID string) bool

View File

@ -0,0 +1,122 @@
// Copyright (c) 2022 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"testing"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
func TestIsCachedCrashRecovers(t *testing.T) {
r := require.New(t)
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(t, true, &pmapi.Message{
ID: "msg1",
Subject: "subject",
})
r.False(m.store.IsCached("msg1"))
m.store.cache = nil
r.False(m.store.IsCached("msg1"))
}
var wantLiteral = []byte("Mime-Version: 1.0\r\nContent-Transfer-Encoding: quoted-printable\r\nContent-Type: \r\nReferences: <msg1@protonmail.internalid>\r\nX-Pm-Date: Thu, 01 Jan 1970 00:00:00 +0000\r\nX-Pm-External-Id: <>\r\nX-Pm-Internal-Id: msg1\r\nX-Original-Date: Mon, 01 Jan 0001 00:00:00 +0000\r\nDate: Fri, 13 Aug 1982 00:00:00 +0000\r\nMessage-Id: <msg1@protonmail.internalid>\r\nSubject: subject\r\n\r\n")
func TestGetCachedMessageOK(t *testing.T) {
r := require.New(t)
m, clear := initMocks(t)
defer clear()
messageID := "msg1"
m.newStoreNoEvents(t, true, &pmapi.Message{
ID: messageID,
Subject: "subject",
Flags: pmapi.FlagReceived,
Body: "body",
})
// Have build job
m.client.EXPECT().
KeyRingForAddressID(gomock.Any()).
Return(testPrivateKeyRing, nil).
Times(1)
haveLiteral, err := m.store.getCachedMessage(messageID)
r.NoError(err)
r.Equal(wantLiteral, haveLiteral)
r.True(m.store.IsCached(messageID))
// No build job
haveLiteral, err = m.store.getCachedMessage(messageID)
r.NoError(err)
r.Equal(wantLiteral, haveLiteral)
r.True(m.store.IsCached(messageID))
}
func TestGetCachedMessageCacheLocked(t *testing.T) {
r := require.New(t)
m, clear := initMocks(t)
defer clear()
messageID := "msg1"
m.newStoreNoEvents(t, true, &pmapi.Message{
ID: messageID,
Subject: "subject",
Flags: pmapi.FlagReceived,
Body: "body",
})
// Have build job
m.client.EXPECT().
KeyRingForAddressID(gomock.Any()).
Return(testPrivateKeyRing, nil).
Times(1)
haveLiteral, err := m.store.getCachedMessage(messageID)
r.NoError(err)
r.Equal(wantLiteral, haveLiteral)
r.True(m.store.IsCached(messageID))
// Lock cache
m.store.cache.Lock(m.store.user.ID())
// Have build job again due to failure
m.client.EXPECT().
KeyRingForAddressID(gomock.Any()).
Return(testPrivateKeyRing, nil).
Times(1)
haveLiteral, err = m.store.getCachedMessage(messageID)
r.NoError(err)
r.Equal(wantLiteral, haveLiteral)
r.True(m.store.IsCached(messageID))
// No build job
haveLiteral, err = m.store.getCachedMessage(messageID)
r.NoError(err)
r.Equal(wantLiteral, haveLiteral)
r.True(m.store.IsCached(messageID))
}

View File

@ -67,6 +67,10 @@ func (cacher *MsgCachePool) newJob(messageID string) {
} }
func (cacher *MsgCachePool) start() { func (cacher *MsgCachePool) start() {
if cacher.started {
return
}
cacher.started = true cacher.started = true
go func() { go func() {

View File

@ -330,7 +330,9 @@ func (storeMailbox *Mailbox) txGetFinalUID(b *bolt.Bucket) uint32 {
uid, _ := b.Cursor().Last() uid, _ := b.Cursor().Last()
if uid == nil { if uid == nil {
panic(errors.New("cannot get final UID of empty mailbox")) // This happened most probably due to empty mailbox and whole
// store needs to be re-initialize in order to fix it.
panic(errors.New("cannot get final UID"))
} }
return btoi(uid) return btoi(uid)

View File

@ -25,7 +25,10 @@ import (
func init() { //nolint[gochecknoinits] func init() { //nolint[gochecknoinits]
logrus.SetLevel(logrus.ErrorLevel) logrus.SetLevel(logrus.ErrorLevel)
if os.Getenv("VERBOSITY") == "trace" { switch os.Getenv("VERBOSITY") {
case "trace":
logrus.SetLevel(logrus.TraceLevel) logrus.SetLevel(logrus.TraceLevel)
case "debug":
logrus.SetLevel(logrus.DebugLevel)
} }
} }

View File

@ -24,7 +24,6 @@ import (
pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message" pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
bolt "go.etcd.io/bbolt" bolt "go.etcd.io/bbolt"
) )
@ -101,20 +100,42 @@ func (message *Message) getRawHeader() ([]byte, error) {
} }
// GetHeader will return cached header from DB. // GetHeader will return cached header from DB.
func (message *Message) GetHeader() []byte { func (message *Message) GetHeader() ([]byte, error) {
raw, err := message.getRawHeader() raw, err := message.getRawHeader()
if err != nil { if err != nil {
panic(errors.Wrap(err, "failed to get raw message header")) message.store.log.
WithField("msgID", message.ID()).
WithError(err).
Warn("Cannot get raw header")
return nil, err
} }
return raw return raw, nil
}
// GetMIMEHeaderFast returns full header if message was cached. If full header
// is not available it will return header from metadata.
// NOTE: Returned header may not contain all fields.
func (message *Message) GetMIMEHeaderFast() (header textproto.MIMEHeader) {
var err error
if message.IsFullHeaderCached() {
header, err = message.GetMIMEHeader()
}
if header == nil || err != nil {
header = textproto.MIMEHeader(message.Message().Header)
}
return
} }
// GetMIMEHeader will return cached header from DB, parsed as a textproto.MIMEHeader. // GetMIMEHeader will return cached header from DB, parsed as a textproto.MIMEHeader.
func (message *Message) GetMIMEHeader() textproto.MIMEHeader { func (message *Message) GetMIMEHeader() (textproto.MIMEHeader, error) {
raw, err := message.getRawHeader() raw, err := message.getRawHeader()
if err != nil { if err != nil {
panic(errors.Wrap(err, "failed to get raw message header")) message.store.log.
WithField("msgID", message.ID()).
WithError(err).
Warn("Cannot get raw header for MIME header")
return nil, err
} }
header, err := textproto.NewReader(bufio.NewReader(bytes.NewReader(raw))).ReadMIMEHeader() header, err := textproto.NewReader(bufio.NewReader(bytes.NewReader(raw))).ReadMIMEHeader()
@ -123,10 +144,10 @@ func (message *Message) GetMIMEHeader() textproto.MIMEHeader {
WithField("msgID", message.ID()). WithField("msgID", message.ID()).
WithError(err). WithError(err).
Warn("Cannot build header from bodystructure") Warn("Cannot build header from bodystructure")
return textproto.MIMEHeader(message.msg.Header) return nil, err
} }
return header return header, nil
} }
// GetBodyStructure returns the message's body structure. // GetBodyStructure returns the message's body structure.

View File

@ -187,7 +187,8 @@ func (mocks *mocksForStore) newStoreNoEvents(t *testing.T, combinedMode bool, ms
mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client) mocks.user.EXPECT().GetClient().AnyTimes().Return(mocks.client)
mocks.client.EXPECT().GetUserKeyRing().Return(tests.MakeKeyRing(t), nil).AnyTimes() testUserKeyring := tests.MakeKeyRing(t)
mocks.client.EXPECT().GetUserKeyRing().Return(testUserKeyring, nil).AnyTimes()
mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{ mocks.client.EXPECT().Addresses().Return(pmapi.AddressList{
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: true}, {ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: true},
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: true}, {ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: true},
@ -225,6 +226,8 @@ func (mocks *mocksForStore) newStoreNoEvents(t *testing.T, combinedMode bool, ms
) )
require.NoError(mocks.tb, err) require.NoError(mocks.tb, err)
require.NoError(mocks.tb, mocks.store.UnlockCache(testUserKeyring))
// We want to wait until first sync has finished. // We want to wait until first sync has finished.
// Checking that event after sync was reuested is not the best way to // Checking that event after sync was reuested is not the best way to
// do the check, because sync could take more time, but sync is going // do the check, because sync could take more time, but sync is going

View File

@ -27,7 +27,6 @@ import (
pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message" pkgMsg "github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
a "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
bolt "go.etcd.io/bbolt" bolt "go.etcd.io/bbolt"
) )
@ -72,6 +71,7 @@ func TestGetMessageFromDB(t *testing.T) {
} }
func TestCreateOrUpdateMessageMetadata(t *testing.T) { func TestCreateOrUpdateMessageMetadata(t *testing.T) {
r := require.New(t)
m, clear := initMocks(t) m, clear := initMocks(t)
defer clear() defer clear()
@ -79,33 +79,37 @@ func TestCreateOrUpdateMessageMetadata(t *testing.T) {
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
metadata, err := m.store.getMessageFromDB("msg1") metadata, err := m.store.getMessageFromDB("msg1")
require.Nil(t, err) r.NoError(err)
msg := &Message{msg: metadata, store: m.store, storeMailbox: nil} msg := &Message{msg: metadata, store: m.store, storeMailbox: nil}
// Check non-meta and calculated data are cleared/empty. // Check non-meta and calculated data are cleared/empty.
a.Equal(t, "", metadata.Body) r.Equal("", metadata.Body)
a.Equal(t, []*pmapi.Attachment(nil), metadata.Attachments) r.Equal([]*pmapi.Attachment(nil), metadata.Attachments)
a.Equal(t, "", metadata.MIMEType) r.Equal("", metadata.MIMEType)
a.Equal(t, make(mail.Header), metadata.Header) r.Equal(make(mail.Header), metadata.Header)
wantHeader, wantSize := putBodystructureAndSizeToDB(m, "msg1") wantHeader, wantSize := putBodystructureAndSizeToDB(m, "msg1")
// Check cached data. // Check cached data.
require.Nil(t, err) haveHeader, err := msg.GetMIMEHeader()
a.Equal(t, wantHeader, msg.GetMIMEHeader()) r.NoError(err)
r.Equal(wantHeader, haveHeader)
haveSize, err := msg.GetRFC822Size() haveSize, err := msg.GetRFC822Size()
require.Nil(t, err) r.NoError(err)
a.Equal(t, wantSize, haveSize) r.Equal(wantSize, haveSize)
// Check cached data are not overridden by reinsert. // Check cached data are not overridden by reinsert.
insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel}) insertMessage(t, m, "msg1", "Test message 1", addrID1, false, []string{pmapi.AllMailLabel})
require.Nil(t, err) haveHeader, err = msg.GetMIMEHeader()
a.Equal(t, wantHeader, msg.GetMIMEHeader()) r.NoError(err)
r.Equal(wantHeader, haveHeader)
haveSize, err = msg.GetRFC822Size() haveSize, err = msg.GetRFC822Size()
require.Nil(t, err) r.NoError(err)
a.Equal(t, wantSize, haveSize) r.Equal(wantSize, haveSize)
} }
func TestDeleteMessage(t *testing.T) { func TestDeleteMessage(t *testing.T) {