We build too many walls and not enough bridges

This commit is contained in:
Jakub
2020-04-08 12:59:16 +02:00
commit 17f4d6097a
494 changed files with 62753 additions and 0 deletions

109
internal/store/address.go Normal file
View File

@ -0,0 +1,109 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/sirupsen/logrus"
)
// Address holds mailboxes for IMAP user (login address). In combined mode
// there is only one address, in split mode there is one object per address.
type Address struct {
store *Store
address string
addressID string
mailboxes map[string]*Mailbox
log *logrus.Entry
}
func newAddress(
store *Store,
address, addressID string,
labels []*pmapi.Label,
) (addr *Address, err error) {
l := log.WithField("addressID", addressID)
storeAddress := &Address{
store: store,
address: address,
addressID: addressID,
log: l,
}
if err = storeAddress.init(labels); err != nil {
l.WithField("address", address).
WithError(err).
Error("Could not initialise store address")
return
}
return storeAddress, nil
}
func (storeAddress *Address) init(foldersAndLabels []*pmapi.Label) (err error) {
storeAddress.log.WithField("address", storeAddress.address).Debug("Initialising store address")
storeAddress.mailboxes = make(map[string]*Mailbox)
for _, label := range foldersAndLabels {
prefix := getLabelPrefix(label)
var mailbox *Mailbox
if mailbox, err = newMailbox(storeAddress, label.ID, prefix, label.Name, label.Color); err != nil {
storeAddress.log.
WithError(err).
WithField("labelID", label.ID).
Error("Could not init mailbox for folder or label")
return
}
storeAddress.mailboxes[label.ID] = mailbox
}
return
}
// getLabelPrefix returns the correct prefix for a pmapi label according to whether it is exclusive or not.
func getLabelPrefix(l *pmapi.Label) string {
switch {
case pmapi.IsSystemLabel(l.ID):
return ""
case l.Exclusive == 1:
return UserFoldersPrefix
default:
return UserLabelsPrefix
}
}
// AddressString returns the address.
func (storeAddress *Address) AddressString() string {
return storeAddress.address
}
// AddressID returns the address ID.
func (storeAddress *Address) AddressID() string {
return storeAddress.addressID
}
// APIAddress returns the `pmapi.Address` struct.
func (storeAddress *Address) APIAddress() *pmapi.Address {
return storeAddress.store.api.Addresses().ByEmail(storeAddress.address)
}

View File

@ -0,0 +1,106 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"fmt"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
// ListMailboxes returns all mailboxes.
func (storeAddress *Address) ListMailboxes() []*Mailbox {
storeAddress.store.lock.RLock()
defer storeAddress.store.lock.RUnlock()
mailboxes := make([]*Mailbox, 0, len(storeAddress.mailboxes))
for _, m := range storeAddress.mailboxes {
mailboxes = append(mailboxes, m)
}
return mailboxes
}
// GetMailbox returns mailbox with the given IMAP name.
func (storeAddress *Address) GetMailbox(name string) (*Mailbox, error) {
storeAddress.store.lock.RLock()
defer storeAddress.store.lock.RUnlock()
for _, m := range storeAddress.mailboxes {
if m.Name() == name {
return m, nil
}
}
return nil, fmt.Errorf("mailbox %v does not exist", name)
}
// CreateMailbox creates the mailbox by calling an API.
// Mailbox is created in the structure by processing event.
func (storeAddress *Address) CreateMailbox(name string) error {
return storeAddress.store.createMailbox(name)
}
// updateMailbox updates the mailbox by calling an API.
// Mailbox is updated in the structure by processing event.
func (storeAddress *Address) updateMailbox(labelID, newName, color string) error {
return storeAddress.store.updateMailbox(labelID, newName, color)
}
// deleteMailbox deletes the mailbox by calling an API.
// Mailbox is deleted in the structure by processing event.
func (storeAddress *Address) deleteMailbox(labelID string) error {
return storeAddress.store.deleteMailbox(labelID, storeAddress.addressID)
}
// createOrUpdateMailboxEvent creates or updates the mailbox in the structure.
// This is called from the event loop.
func (storeAddress *Address) createOrUpdateMailboxEvent(label *pmapi.Label) error {
prefix := getLabelPrefix(label)
mailbox, ok := storeAddress.mailboxes[label.ID]
if !ok {
mailbox, err := newMailbox(storeAddress, label.ID, prefix, label.Name, label.Color)
if err != nil {
return err
}
storeAddress.mailboxes[label.ID] = mailbox
} else {
mailbox.labelName = prefix + label.Name
mailbox.color = label.Color
}
return nil
}
// deleteMailboxEvent deletes the mailbox in the structure.
// This is called from the event loop.
func (storeAddress *Address) deleteMailboxEvent(labelID string) error {
storeMailbox, ok := storeAddress.mailboxes[labelID]
if !ok {
log.WithField("labelID", labelID).Warn("Could not find mailbox to delete")
return nil
}
delete(storeAddress.mailboxes, labelID)
return storeMailbox.deleteMailboxEvent()
}
func (storeAddress *Address) getMailboxByID(labelID string) (*Mailbox, error) {
storeMailbox, ok := storeAddress.mailboxes[labelID]
if !ok {
return nil, fmt.Errorf("mailbox with id %q does not exist", labelID)
}
return storeMailbox, nil
}

View File

@ -0,0 +1,42 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
bolt "go.etcd.io/bbolt"
)
func (storeAddress *Address) txCreateOrUpdateMessages(tx *bolt.Tx, msgs []*pmapi.Message) error {
for _, m := range storeAddress.mailboxes {
if err := m.txCreateOrUpdateMessages(tx, msgs); err != nil {
return err
}
}
return nil
}
// txDeleteMessage deletes the message from the mailbox buckets for this address.
func (storeAddress *Address) txDeleteMessage(tx *bolt.Tx, apiID string) error {
for _, m := range storeAddress.mailboxes {
if err := m.txDeleteMessage(tx, apiID); err != nil {
return err
}
}
return nil
}

114
internal/store/cache.go Normal file
View File

@ -0,0 +1,114 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"encoding/json"
"os"
"sync"
"github.com/pkg/errors"
)
// Cache caches the last event IDs for all accounts (there should be only one instance).
type Cache struct {
// cache is map from userID => key (such as last event) => value (such as event ID).
cache map[string]map[string]string
path string
lock *sync.RWMutex
}
// NewCache constructs a new cache at the given path.
func NewCache(path string) *Cache {
return &Cache{
path: path,
lock: &sync.RWMutex{},
}
}
func (c *Cache) getEventID(userID string) string {
c.lock.Lock()
defer c.lock.Unlock()
_ = c.loadCache()
if c.cache == nil {
c.cache = map[string]map[string]string{}
}
if c.cache[userID] == nil {
c.cache[userID] = map[string]string{}
}
return c.cache[userID]["events"]
}
func (c *Cache) setEventID(userID, eventID string) error {
c.lock.Lock()
defer c.lock.Unlock()
if c.cache[userID] == nil {
c.cache[userID] = map[string]string{}
}
c.cache[userID]["events"] = eventID
return c.saveCache()
}
func (c *Cache) loadCache() error {
if c.cache != nil {
return nil
}
f, err := os.Open(c.path)
if err != nil {
return err
}
defer f.Close() //nolint[errcheck]
return json.NewDecoder(f).Decode(&c.cache)
}
func (c *Cache) saveCache() error {
if c.cache == nil {
return errors.New("events: cannot save cache: cache is nil")
}
f, err := os.Create(c.path)
if err != nil {
return err
}
defer f.Close() //nolint[errcheck]
return json.NewEncoder(f).Encode(c.cache)
}
func (c *Cache) clearCacheUser(userID string) error {
c.lock.Lock()
defer c.lock.Unlock()
if c.cache == nil {
log.WithField("user", userID).Warning("Cannot clear user from cache: cache is nil")
return nil
}
log.WithField("user", userID).Trace("Removing user from event loop cache")
delete(c.cache, userID)
return c.saveCache()
}

109
internal/store/change.go Normal file
View File

@ -0,0 +1,109 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"time"
"github.com/ProtonMail/proton-bridge/pkg/message"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
imap "github.com/emersion/go-imap"
imapBackend "github.com/emersion/go-imap/backend"
"github.com/sirupsen/logrus"
)
// SetIMAPUpdateChannel sets the channel on which imap update messages will be sent. This should be the channel
// on which the imap backend listens for imap updates.
func (store *Store) SetIMAPUpdateChannel(updates chan interface{}) {
store.log.Debug("Listening for IMAP updates")
if store.imapUpdates = updates; store.imapUpdates == nil {
store.log.Error("The IMAP Updates channel is nil")
}
}
func (store *Store) imapNotice(address, notice string) {
update := new(imapBackend.StatusUpdate)
update.Username = address
update.StatusResp = &imap.StatusResp{
Type: imap.StatusOk,
Code: imap.CodeAlert,
Info: notice,
}
store.imapSendUpdate(update)
}
func (store *Store) imapUpdateMessage(address, mailboxName string, uid, sequenceNumber uint32, msg *pmapi.Message) {
store.log.WithFields(logrus.Fields{
"address": address,
"mailbox": mailboxName,
"seqNum": sequenceNumber,
"uid": uid,
"flags": message.GetFlags(msg),
}).Trace("IDLE update")
update := new(imapBackend.MessageUpdate)
update.Username = address
update.Mailbox = mailboxName
update.Message = imap.NewMessage(sequenceNumber, []string{imap.FlagsMsgAttr, imap.UidMsgAttr})
update.Message.Flags = message.GetFlags(msg)
update.Message.Uid = uid
store.imapSendUpdate(update)
}
func (store *Store) imapDeleteMessage(address, mailboxName string, sequenceNumber uint32) {
store.log.WithFields(logrus.Fields{
"address": address,
"mailbox": mailboxName,
"seqNum": sequenceNumber,
}).Trace("IDLE delete")
update := new(imapBackend.ExpungeUpdate)
update.Username = address
update.Mailbox = mailboxName
update.SeqNum = sequenceNumber
store.imapSendUpdate(update)
}
func (store *Store) imapMailboxStatus(address, mailboxName string, total, unread uint) {
store.log.WithFields(logrus.Fields{
"address": address,
"mailbox": mailboxName,
"total": total,
"unread": unread,
}).Trace("IDLE status")
update := new(imapBackend.MailboxUpdate)
update.Username = address
update.Mailbox = mailboxName
update.MailboxStatus = imap.NewMailboxStatus(mailboxName, []string{imap.MailboxMessages, imap.MailboxUnseen})
update.MailboxStatus.Messages = uint32(total)
update.MailboxStatus.Unseen = uint32(unread)
store.imapSendUpdate(update)
}
func (store *Store) imapSendUpdate(update interface{}) {
if store.imapUpdates == nil {
store.log.Trace("IMAP IDLE unavailable")
return
}
select {
case <-time.After(1 * time.Second):
store.log.Error("Could not send IMAP update (timeout)")
return
case store.imapUpdates <- update:
}
}

View File

@ -0,0 +1,129 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"testing"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
imapBackend "github.com/emersion/go-imap/backend"
"github.com/stretchr/testify/require"
)
func TestCreateOrUpdateMessageIMAPUpdates(t *testing.T) {
m, clear := initMocks(t)
defer clear()
updates := make(chan interface{})
m.newStoreNoEvents(true)
m.store.SetIMAPUpdateChannel(updates)
go checkIMAPUpdates(t, updates, []func(interface{}) bool{
checkMessageUpdate(addr1, "All Mail", 1, 1),
checkMessageUpdate(addr1, "All Mail", 2, 2),
})
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
close(updates)
}
func TestCreateOrUpdateMessageIMAPUpdatesBulkUpdate(t *testing.T) {
m, clear := initMocks(t)
defer clear()
updates := make(chan interface{})
m.newStoreNoEvents(true)
m.store.SetIMAPUpdateChannel(updates)
go checkIMAPUpdates(t, updates, []func(interface{}) bool{
checkMessageUpdate(addr1, "All Mail", 1, 1),
checkMessageUpdate(addr1, "All Mail", 2, 2),
})
msg1 := getTestMessage("msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
msg2 := getTestMessage("msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
require.Nil(t, m.store.createOrUpdateMessagesEvent([]*pmapi.Message{msg1, msg2}))
close(updates)
}
func TestDeleteMessageIMAPUpdate(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
updates := make(chan interface{})
m.store.SetIMAPUpdateChannel(updates)
go checkIMAPUpdates(t, updates, []func(interface{}) bool{
checkMessageDelete(addr1, "All Mail", 2),
checkMessageDelete(addr1, "All Mail", 1),
})
require.Nil(t, m.store.deleteMessageEvent("msg2"))
require.Nil(t, m.store.deleteMessageEvent("msg1"))
close(updates)
}
func checkIMAPUpdates(t *testing.T, updates chan interface{}, checkFunctions []func(interface{}) bool) {
idx := 0
for update := range updates {
if idx >= len(checkFunctions) {
continue
}
if !checkFunctions[idx](update) {
continue
}
idx++
}
require.True(t, idx == len(checkFunctions), "Less updates than expected: %+v of %+v", idx, len(checkFunctions))
}
func checkMessageUpdate(username, mailbox string, seqNum, uid int) func(interface{}) bool { //nolint[unparam]
return func(update interface{}) bool {
switch u := update.(type) {
case *imapBackend.MessageUpdate:
return (u.Update.Username == username &&
u.Update.Mailbox == mailbox &&
u.Message.SeqNum == uint32(seqNum) &&
u.Message.Uid == uint32(uid))
default:
return false
}
}
}
func checkMessageDelete(username, mailbox string, seqNum int) func(interface{}) bool { //nolint[unparam]
return func(update interface{}) bool {
switch u := update.(type) {
case *imapBackend.ExpungeUpdate:
return (u.Update.Username == username &&
u.Update.Mailbox == mailbox &&
u.SeqNum == uint32(seqNum))
default:
return false
}
}
}

32
internal/store/convert.go Normal file
View File

@ -0,0 +1,32 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import "encoding/binary"
// itob returns a 4-byte big endian representation of v.
func itob(v uint32) []byte {
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, v)
return b
}
// btoi returns the uint32 represented by b.
func btoi(b []byte) uint32 {
return binary.BigEndian.Uint32(b)
}

View File

@ -0,0 +1,546 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"time"
bridgeEvents "github.com/ProtonMail/proton-bridge/internal/events"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
const pollInterval = 30 * time.Second
type eventLoop struct {
cache *Cache
currentEventID string
pollCh chan chan struct{}
stopCh chan struct{}
notifyStopCh chan struct{}
isRunning bool
hasInternet bool
log *logrus.Entry
store *Store
apiClient PMAPIProvider
user BridgeUser
events listener.Listener
}
func newEventLoop(cache *Cache, store *Store, api PMAPIProvider, user BridgeUser, events listener.Listener) *eventLoop {
eventLog := log.WithField("userID", user.ID())
eventLog.Trace("Creating new event loop")
return &eventLoop{
cache: cache,
currentEventID: cache.getEventID(user.ID()),
pollCh: make(chan chan struct{}),
isRunning: false,
log: eventLog,
store: store,
apiClient: api,
user: user,
events: events,
}
}
func (loop *eventLoop) IsRunning() bool {
return loop.isRunning
}
func (loop *eventLoop) setFirstEventID() (err error) {
loop.log.Trace("Setting first event ID")
event, err := loop.apiClient.GetEvent("")
if err != nil {
loop.log.WithError(err).Error("Could not get latest event ID")
return
}
loop.currentEventID = event.EventID
if err = loop.cache.setEventID(loop.user.ID(), loop.currentEventID); err != nil {
loop.log.WithError(err).Error("Could not set latest event ID in user cache")
return
}
return
}
// pollNow starts polling events right away and waits till the events are
// processed so we are sure updates are propagated to the database.
func (loop *eventLoop) pollNow() {
eventProcessedCh := make(chan struct{})
loop.pollCh <- eventProcessedCh
<-eventProcessedCh
close(eventProcessedCh)
}
func (loop *eventLoop) stop() {
if loop.isRunning {
loop.isRunning = false
close(loop.stopCh)
select {
case <-loop.notifyStopCh:
loop.log.Info("Event loop was stopped")
case <-time.After(1 * time.Second):
loop.log.Warn("Timed out waiting for event loop to stop")
}
}
}
func (loop *eventLoop) start() { // nolint[funlen]
if loop.isRunning {
return
}
defer func() {
loop.isRunning = false
}()
loop.stopCh = make(chan struct{})
loop.notifyStopCh = make(chan struct{})
loop.isRunning = true
events := make(chan *pmapi.Event)
defer close(events)
loop.log.WithField("lastEventID", loop.currentEventID).Info("Subscribed to events")
defer func() {
loop.log.WithField("lastEventID", loop.currentEventID).Info("Subscription stopped")
}()
t := time.NewTicker(pollInterval)
defer t.Stop()
loop.hasInternet = true
go loop.pollNow()
for {
var eventProcessedCh chan struct{}
select {
case <-loop.stopCh:
close(loop.notifyStopCh)
return
case eventProcessedCh = <-loop.pollCh:
case <-t.C:
}
// Before we fetch the first event, check whether this is the first time we've
// started the event loop, and if so, trigger a full sync.
// In case internet connection was not available during start, it will be
// handled anyway when the connection is back here.
if loop.isBeforeFirstStart() {
if eventErr := loop.setFirstEventID(); eventErr != nil {
loop.log.WithError(eventErr).Warn("Could not set initial event ID")
}
}
// If the sync is not finished then a new sync is triggered.
if !loop.store.isSyncFinished() {
loop.store.triggerSync()
}
more, err := loop.processNextEvent()
if eventProcessedCh != nil {
eventProcessedCh <- struct{}{}
}
if err != nil {
loop.log.WithError(err).Error("Cannot process event, stopping event loop")
// When event loop stops, the only way to start it again is by login.
// It should stop only when user is logged out but even if there is other
// serious error, logout is intended action.
if errLogout := loop.user.Logout(); errLogout != nil {
loop.log.
WithError(errLogout).
Error("Failed to logout user after loop finished with error")
}
return
}
if more {
go loop.pollNow()
}
}
}
// isBeforeFirstStart returns whether the initial event ID was already set or not.
func (loop *eventLoop) isBeforeFirstStart() bool {
return loop.currentEventID == ""
}
// processNextEvent saves only successfully processed `eventID` into cache
// (disk). It will filter out in defer all errors except invalid token error.
// Invalid error will be returned and stop the event loop.
func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[funlen]
l := loop.log.WithField("currentEventID", loop.currentEventID)
// We only want to consider invalid tokens as real errors because all other errors might fix themselves eventually
// (e.g. no internet, ulimit reached etc.)
defer func() {
if errors.Cause(err) == pmapi.ErrAPINotReachable {
l.Warn("Internet unavailable")
loop.events.Emit(bridgeEvents.InternetOffEvent, "")
loop.hasInternet = false
err = nil
}
if err != nil && isFdCloseToULimit() {
l.Warn("Ulimit reached")
loop.events.Emit(bridgeEvents.RestartBridgeEvent, "")
err = nil
}
if errors.Cause(err) == pmapi.ErrUpgradeApplication {
l.Warn("Need to upgrade application")
loop.events.Emit(bridgeEvents.UpgradeApplicationEvent, "")
err = nil
}
_, errUnauthorized := errors.Cause(err).(*pmapi.ErrUnauthorized)
// All errors except Invalid Token (which is not possible to recover from) are ignored.
if err != nil && !errUnauthorized && errors.Cause(err) != pmapi.ErrInvalidToken {
l.WithError(err).Trace("Error skipped")
err = nil
}
}()
l.Trace("Polling next event")
var event *pmapi.Event
if event, err = loop.apiClient.GetEvent(loop.currentEventID); err != nil {
return false, errors.Wrap(err, "failed to get event")
}
l = l.WithField("newEventID", event.EventID)
if !loop.hasInternet {
loop.events.Emit(bridgeEvents.InternetOnEvent, "")
loop.hasInternet = true
}
if err = loop.processEvent(event); err != nil {
return false, errors.Wrap(err, "failed to process event")
}
if loop.currentEventID != event.EventID {
// In case new event ID cannot be saved to cache, we update it in event loop
// anyway and continue processing new events to prevent the loop from repeatedly
// processing the same event.
// This allows the event loop to continue to function (unless the cache was broken
// and bridge stopped, in which case it will start from the old event ID anyway).
loop.currentEventID = event.EventID
if err = loop.cache.setEventID(loop.user.ID(), event.EventID); err != nil {
return false, errors.Wrap(err, "failed to save event ID to cache")
}
}
return event.More == 1, err
}
func (loop *eventLoop) processEvent(event *pmapi.Event) (err error) {
eventLog := loop.log.WithField("event", event.EventID)
eventLog.Debug("Processing event")
if (event.Refresh & pmapi.EventRefreshMail) != 0 {
eventLog.Info("Processing refresh event")
loop.store.triggerSync()
return
}
if len(event.Addresses) != 0 {
if err = loop.processAddresses(eventLog, event.Addresses); err != nil {
return errors.Wrap(err, "failed to process address events")
}
}
if len(event.Labels) != 0 {
if err = loop.processLabels(eventLog, event.Labels); err != nil {
return errors.Wrap(err, "failed to process label events")
}
}
if len(event.Messages) != 0 {
if err = loop.processMessages(eventLog, event.Messages); err != nil {
return errors.Wrap(err, "failed to process message events")
}
}
// One would expect that every event would contain MessageCount as part of
// the event.Messages, but this is apparently not the case.
// MessageCounts are served on an irregular basis, so we should update and
// compare the counts only when we receive them.
if len(event.MessageCounts) != 0 {
if err = loop.processMessageCounts(eventLog, event.MessageCounts); err != nil {
return errors.Wrap(err, "failed to process message count events")
}
}
if len(event.Notices) != 0 {
loop.processNotices(eventLog, event.Notices)
}
return err
}
func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmapi.EventAddress) (err error) {
log.Debug("Processing address change event")
// Get old addresses for comparisons before updating user.
oldList := loop.apiClient.Addresses()
if err = loop.user.UpdateUser(); err != nil {
if logoutErr := loop.user.Logout(); logoutErr != nil {
log.WithError(logoutErr).Error("Failed to logout user after failed update")
}
return errors.Wrap(err, "failed to update user")
}
for _, addressEvent := range addressEvents {
switch addressEvent.Action {
case pmapi.EventCreate:
log.WithField("email", addressEvent.Address.Email).Debug("Address was created")
loop.events.Emit(bridgeEvents.AddressChangedEvent, loop.user.GetPrimaryAddress())
case pmapi.EventUpdate:
oldAddress := oldList.ByID(addressEvent.ID)
if oldAddress == nil {
log.Warning("Event refers to an address that isn't present")
continue
}
email := oldAddress.Email
log.WithField("email", email).Debug("Address was updated")
if addressEvent.Address.Receive != oldAddress.Receive {
loop.events.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
}
case pmapi.EventDelete:
oldAddress := oldList.ByID(addressEvent.ID)
if oldAddress == nil {
log.Warning("Event refers to an address that isn't present")
continue
}
email := oldAddress.Email
log.WithField("email", email).Debug("Address was deleted")
loop.user.CloseConnection(email)
loop.events.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
}
}
if err = loop.store.createOrUpdateAddressInfo(loop.apiClient.Addresses()); err != nil {
return errors.Wrap(err, "failed to update address IDs in store")
}
if err = loop.store.createOrDeleteAddressesEvent(); err != nil {
return errors.Wrap(err, "failed to create/delete store addresses")
}
return nil
}
func (loop *eventLoop) processLabels(eventLog *logrus.Entry, labels []*pmapi.EventLabel) error {
eventLog.Debug("Processing label change event")
for _, eventLabel := range labels {
label := eventLabel.Label
switch eventLabel.Action {
case pmapi.EventCreate, pmapi.EventUpdate:
if err := loop.store.createOrUpdateMailboxEvent(label); err != nil {
return errors.Wrap(err, "failed to create or update label")
}
case pmapi.EventDelete:
if err := loop.store.deleteMailboxEvent(eventLabel.ID); err != nil {
return errors.Wrap(err, "failed to delete label")
}
}
}
return nil
}
func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi.EventMessage) (err error) {
eventLog.Debug("Processing message change event")
for _, message := range messages {
msgLog := eventLog.WithField("msgID", message.ID)
switch message.Action {
case pmapi.EventCreate:
msgLog.Debug("Processing EventCreate for message")
if message.Created == nil {
msgLog.Error("Got EventCreate with nil message")
break
}
if err = loop.store.createOrUpdateMessageEvent(message.Created); err != nil {
return errors.Wrap(err, "failed to put message into DB")
}
case pmapi.EventUpdate, pmapi.EventUpdateFlags:
msgLog.Debug("Processing EventUpdate(Flags) for message")
if message.Updated == nil {
msgLog.Errorf("Got EventUpdate(Flags) with nil message")
break
}
var msg *pmapi.Message
msg, err = loop.store.getMessageFromDB(message.ID)
if err == ErrNoSuchAPIID {
msgLog.WithError(err).Warning("Cannot get message from DB for updating. Trying fetch...")
msg, err = loop.store.fetchMessage(message.ID)
// If message does not exist anywhere, update event is probably old and off topic - skip it.
if err == ErrNoSuchAPIID {
msgLog.Warn("Skipping message update, because message does not exist nor in local DB or on API")
continue
}
}
if err != nil {
return errors.Wrap(err, "failed to get message from DB for updating")
}
updateMessage(msgLog, msg, message.Updated)
if err = loop.store.createOrUpdateMessageEvent(msg); err != nil {
return errors.Wrap(err, "failed to update message in DB")
}
case pmapi.EventDelete:
msgLog.Debug("Processing EventDelete for message")
if err = loop.store.deleteMessageEvent(message.ID); err != nil {
return errors.Wrap(err, "failed to delete message from DB")
}
}
}
return err
}
func updateMessage(msgLog *logrus.Entry, message *pmapi.Message, updates *pmapi.EventMessageUpdated) { //nolint[funlen]
msgLog.Debug("Updating message")
message.Time = updates.Time
if updates.Subject != nil {
msgLog.WithField("subject", *updates.Subject).Trace("Updating message value")
message.Subject = *updates.Subject
}
if updates.Sender != nil {
msgLog.WithField("sender", *updates.Sender).Trace("Updating message value")
message.Sender = updates.Sender
}
if updates.ToList != nil {
msgLog.WithField("toList", *updates.ToList).Trace("Updating message value")
message.ToList = *updates.ToList
}
if updates.CCList != nil {
msgLog.WithField("ccList", *updates.CCList).Trace("Updating message value")
message.CCList = *updates.CCList
}
if updates.BCCList != nil {
msgLog.WithField("bccList", *updates.BCCList).Trace("Updating message value")
message.BCCList = *updates.BCCList
}
if updates.Unread != nil {
msgLog.WithField("unread", *updates.Unread).Trace("Updating message value")
message.Unread = *updates.Unread
}
if updates.Flags != nil {
msgLog.WithField("flags", *updates.Flags).Trace("Updating message value")
message.Flags = *updates.Flags
}
if updates.LabelIDs != nil {
msgLog.WithField("labelIDs", updates.LabelIDs).Trace("Updating message value")
message.LabelIDs = updates.LabelIDs
} else {
for _, added := range updates.LabelIDsAdded {
hasLabel := false
for _, l := range message.LabelIDs {
if added == l {
hasLabel = true
break
}
}
if !hasLabel {
msgLog.WithField("added", added).Trace("Adding label to message")
message.LabelIDs = append(message.LabelIDs, added)
}
}
labels := []string{}
for _, l := range message.LabelIDs {
removeLabel := false
for _, removed := range updates.LabelIDsRemoved {
if removed == l {
removeLabel = true
break
}
}
if removeLabel {
msgLog.WithField("label", l).Trace("Removing label from message")
} else {
labels = append(labels, l)
}
}
message.LabelIDs = labels
}
}
func (loop *eventLoop) processMessageCounts(l *logrus.Entry, messageCounts []*pmapi.MessagesCount) error {
l.WithField("apiCounts", messageCounts).Debug("Processing message count change event")
isSynced, err := loop.store.isSynced(messageCounts)
if err != nil {
return err
}
if !isSynced {
loop.store.triggerSync()
}
return nil
}
func (loop *eventLoop) processNotices(l *logrus.Entry, notices []string) {
l.Debug("Processing notice change event")
for _, notice := range notices {
l.Infof("Notice: %q", notice)
for _, address := range loop.user.GetStoreAddresses() {
loop.store.imapNotice(address, notice)
}
}
}

View File

@ -0,0 +1,153 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"net/mail"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
func TestEventLoopProcessMoreEvents(t *testing.T) {
m, clear := initMocks(t)
defer clear()
// Event expectations need to be defined before calling `newStoreNoEvents`
// to force to use these for this particular test.
// Also, event loop calls ListMessages again and we need to place it after
// calling `newStoreNoEvents` to not break expectations for the first sync.
gomock.InOrder(
// Doesn't matter which IDs are used.
// This test is trying to see whether event loop will immediately process
// next event if there is `More` of them.
m.api.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
EventID: "event50",
More: 1,
}, nil),
m.api.EXPECT().GetEvent("event50").Return(&pmapi.Event{
EventID: "event70",
More: 0,
}, nil),
m.api.EXPECT().GetEvent("event70").Return(&pmapi.Event{
EventID: "event71",
More: 0,
}, nil),
)
m.newStoreNoEvents(true)
m.api.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
// Event loop runs in goroutine and will be stopped by deferred mock clearing.
go m.store.eventLoop.start()
// More events are processed right away.
require.Eventually(t, func() bool {
return m.store.eventLoop.currentEventID == "event70"
}, time.Second, 10*time.Millisecond)
// For normal event we need to wait to next polling.
time.Sleep(pollInterval)
require.Eventually(t, func() bool {
return m.store.eventLoop.currentEventID == "event71"
}, time.Second, 10*time.Millisecond)
}
func TestEventLoopUpdateMessageFromLoop(t *testing.T) {
m, clear := initMocks(t)
defer clear()
subject := "old subject"
newSubject := "new subject"
// First sync will add message with old subject to database.
m.api.EXPECT().GetMessage("msg1").Return(&pmapi.Message{
ID: "msg1",
Subject: subject,
}, nil)
// Event will update the subject.
m.api.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
EventID: "event1",
Messages: []*pmapi.EventMessage{{
EventItem: pmapi.EventItem{
ID: "msg1",
Action: pmapi.EventUpdate,
},
Updated: &pmapi.EventMessageUpdated{
ID: "msg1",
Subject: &newSubject,
},
}},
}, nil)
m.newStoreNoEvents(true)
// Event loop runs in goroutine and will be stopped by deferred mock clearing.
go m.store.eventLoop.start()
require.Eventually(t, func() bool {
msg, err := m.store.getMessageFromDB("msg1")
return err == nil && msg.Subject == newSubject
}, time.Second, 10*time.Millisecond)
}
func TestEventLoopUpdateMessage(t *testing.T) {
address1 := &mail.Address{Address: "user1@example.com"}
address2 := &mail.Address{Address: "user2@example.com"}
msg := &pmapi.Message{
ID: "msg1",
Subject: "old",
Unread: 0,
Flags: 10,
Sender: address1,
ToList: []*mail.Address{address2},
CCList: []*mail.Address{address1},
BCCList: []*mail.Address{},
Time: 20,
LabelIDs: []string{"old"},
}
newMsg := &pmapi.Message{
ID: "msg1",
Subject: "new",
Unread: 1,
Flags: 11,
Sender: address2,
ToList: []*mail.Address{address1},
CCList: []*mail.Address{address2},
BCCList: []*mail.Address{address1},
Time: 21,
LabelIDs: []string{"new"},
}
updateMessage(log, msg, &pmapi.EventMessageUpdated{
ID: "msg1",
Subject: &newMsg.Subject,
Unread: &newMsg.Unread,
Flags: &newMsg.Flags,
Sender: newMsg.Sender,
ToList: &newMsg.ToList,
CCList: &newMsg.CCList,
BCCList: &newMsg.BCCList,
Time: newMsg.Time,
LabelIDs: newMsg.LabelIDs,
})
require.Equal(t, newMsg, msg)
}

265
internal/store/mailbox.go Normal file
View File

@ -0,0 +1,265 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"encoding/json"
"fmt"
"strings"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/sirupsen/logrus"
bolt "go.etcd.io/bbolt"
)
// Mailbox is mailbox for specific address and mailbox.
type Mailbox struct {
store *Store
storeAddress *Address
labelID string
labelPrefix string
labelName string
color string
log *logrus.Entry
}
func newMailbox(storeAddress *Address, labelID, labelPrefix, labelName, color string) (mb *Mailbox, err error) {
l := log.
WithField("addrID", storeAddress.addressID).
WithField("lblID", labelID)
mb = &Mailbox{
store: storeAddress.store,
storeAddress: storeAddress,
labelID: labelID,
labelPrefix: labelPrefix,
labelName: labelPrefix + labelName,
color: color,
log: l,
}
if err = mb.store.db.Update(func(tx *bolt.Tx) error {
return initMailboxBucket(tx, mb.getBucketName())
}); err != nil {
l.WithError(err).Error("Could not initialise mailbox buckets")
}
syncDraftsIfNecssary(mb)
return
}
func syncDraftsIfNecssary(mb *Mailbox) { //nolint[funlen]
// We didn't support drafts before v1.2.6 and therefore if we now created
// Drafts mailbox we need to check whether counts match (drafts are synced).
// If not, sync them from local metadata without need to do full resync,
// Can be removed with 1.2.7 or later.
if mb.labelID != pmapi.DraftLabel {
return
}
// If the drafts mailbox total is non-zero, it means it has already been used
// and there is no need to continue. Otherwise, we may need to do an initial sync.
total, _, err := mb.GetCounts()
if err != nil || total != 0 {
return
}
counts, err := mb.store.getOnAPICounts()
if err != nil {
return
}
foundCounts := false
doSync := false
for _, count := range counts {
if count.LabelID != pmapi.DraftLabel {
continue
}
foundCounts = true
log.WithField("total", total).WithField("total-api", count.TotalOnAPI).Debug("Drafts mailbox created: checking need for sync")
if count.TotalOnAPI == total {
continue
}
doSync = true
break
}
if !foundCounts {
log.Debug("Drafts mailbox created: missing counts, refreshing")
_ = mb.store.updateCountsFromServer()
}
if !foundCounts || doSync {
err := mb.store.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(metadataBucket).ForEach(func(k, v []byte) error {
msg := &pmapi.Message{}
if err := json.Unmarshal(v, msg); err != nil {
return err
}
for _, msgLabelID := range msg.LabelIDs {
if msgLabelID == pmapi.DraftLabel {
log.WithField("id", msg.ID).Debug("Drafts mailbox created: syncing draft locally")
_ = mb.txCreateOrUpdateMessages(tx, []*pmapi.Message{msg})
break
}
}
return nil
})
})
log.WithError(err).Info("Drafts mailbox created: synced localy")
}
}
func initMailboxBucket(tx *bolt.Tx, bucketName []byte) error {
bucket, err := tx.Bucket(mailboxesBucket).CreateBucketIfNotExists(bucketName)
if err != nil {
return err
}
if _, err := bucket.CreateBucketIfNotExists(imapIDsBucket); err != nil {
return err
}
if _, err := bucket.CreateBucketIfNotExists(apiIDsBucket); err != nil {
return err
}
return nil
}
// LabelID returns ID of mailbox.
func (storeMailbox *Mailbox) LabelID() string {
return storeMailbox.labelID
}
// Name returns the name of mailbox.
func (storeMailbox *Mailbox) Name() string {
return storeMailbox.labelName
}
// Color returns the color of mailbox.
func (storeMailbox *Mailbox) Color() string {
return storeMailbox.color
}
// UIDValidity returns the current value of structure version.
func (storeMailbox *Mailbox) UIDValidity() uint32 {
return storeMailbox.store.getMailboxesVersion()
}
// IsFolder returns whether the mailbox is a folder (has "Folders/" prefix).
func (storeMailbox *Mailbox) IsFolder() bool {
return storeMailbox.labelPrefix == UserFoldersPrefix
}
// IsLabel returns whether the mailbox is a label (has "Labels/" prefix).
func (storeMailbox *Mailbox) IsLabel() bool {
return storeMailbox.labelPrefix == UserLabelsPrefix
}
// IsSystem returns whether the mailbox is one of the specific system mailboxes (has no prefix).
func (storeMailbox *Mailbox) IsSystem() bool {
return storeMailbox.labelPrefix == ""
}
// Rename updates the mailbox by calling an API.
// Change has to be propagated to all the same mailboxes in all addresses.
// The propagation is processed by the event loop.
func (storeMailbox *Mailbox) Rename(newName string) error {
if storeMailbox.IsSystem() {
return fmt.Errorf("cannot rename system mailboxes")
}
if storeMailbox.IsFolder() {
if !strings.HasPrefix(newName, UserFoldersPrefix) {
return fmt.Errorf("cannot rename folder to non-folder")
}
newName = strings.TrimPrefix(newName, UserFoldersPrefix)
}
if storeMailbox.IsLabel() {
if !strings.HasPrefix(newName, UserLabelsPrefix) {
return fmt.Errorf("cannot rename label to non-label")
}
newName = strings.TrimPrefix(newName, UserLabelsPrefix)
}
return storeMailbox.storeAddress.updateMailbox(storeMailbox.labelID, newName, storeMailbox.color)
}
// Delete deletes the mailbox by calling an API.
// Deletion has to be propagated to all the same mailboxes in all addresses.
// The propagation is processed by the event loop.
func (storeMailbox *Mailbox) Delete() error {
return storeMailbox.storeAddress.deleteMailbox(storeMailbox.labelID)
}
// GetDelimiter returns the path separator.
func (storeMailbox *Mailbox) GetDelimiter() string {
return PathDelimiter
}
// deleteMailboxEvent deletes the mailbox bucket.
// This is called from the event loop.
func (storeMailbox *Mailbox) deleteMailboxEvent() error {
return storeMailbox.db().Update(func(tx *bolt.Tx) error {
return tx.Bucket(mailboxesBucket).DeleteBucket(storeMailbox.getBucketName())
})
}
// txGetIMAPIDsBucket returns the bucket mapping IMAP ID to API ID.
func (storeMailbox *Mailbox) txGetIMAPIDsBucket(tx *bolt.Tx) *bolt.Bucket {
return storeMailbox.txGetBucket(tx).Bucket(imapIDsBucket)
}
// txGetAPIIDsBucket returns the bucket mapping API ID to IMAP ID.
func (storeMailbox *Mailbox) txGetAPIIDsBucket(tx *bolt.Tx) *bolt.Bucket {
return storeMailbox.txGetBucket(tx).Bucket(apiIDsBucket)
}
// txGetBucket returns the bucket of mailbox containing mapping buckets.
func (storeMailbox *Mailbox) txGetBucket(tx *bolt.Tx) *bolt.Bucket {
return tx.Bucket(mailboxesBucket).Bucket(storeMailbox.getBucketName())
}
func getMailboxBucketName(addressID, labelID string) []byte {
return []byte(addressID + "-" + labelID)
}
// getBucketName returns the name of mailbox bucket.
func (storeMailbox *Mailbox) getBucketName() []byte {
return getMailboxBucketName(storeMailbox.storeAddress.addressID, storeMailbox.labelID)
}
// pollNow is a proxy for the store's eventloop's `pollNow()`.
func (storeMailbox *Mailbox) pollNow() {
storeMailbox.store.eventLoop.pollNow()
}
// api is a proxy for the store's `PMAPIProvider`.
func (storeMailbox *Mailbox) api() PMAPIProvider {
return storeMailbox.store.api
}
// update is a proxy for the store's db's `Update`.
func (storeMailbox *Mailbox) db() *bolt.DB {
return storeMailbox.store.db
}

View File

@ -0,0 +1,257 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"bytes"
"encoding/json"
"sort"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
bolt "go.etcd.io/bbolt"
)
// GetCounts returns numbers of total and unread messages in this mailbox bucket.
func (storeMailbox *Mailbox) GetCounts() (total, unread uint, err error) {
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
total, unread, err = storeMailbox.txGetCounts(tx)
return err
})
return
}
func (storeMailbox *Mailbox) txGetCounts(tx *bolt.Tx) (total, unread uint, err error) {
// For total it would be enough to use `bolt.Bucket.Stats().KeyN` but
// we also need to retrieve the count of unread emails therefore we are
// looping all messages in this mailbox by `bolt.Cursor`
metaBucket := tx.Bucket(metadataBucket)
b := storeMailbox.txGetIMAPIDsBucket(tx)
c := b.Cursor()
imapID, apiID := c.First()
for ; imapID != nil; imapID, apiID = c.Next() {
total++
rawMsg := metaBucket.Get(apiID)
if rawMsg == nil {
return 0, 0, ErrNoSuchAPIID
}
// Do not unmarshal whole JSON to speed up the looping.
// Instead, we assume it will contain JSON int field `Unread`
// where `1` means true (i.e. message is unread)
if bytes.Contains(rawMsg, []byte(`"Unread":1`)) {
unread++
}
}
return total, unread, err
}
type mailboxCounts struct {
LabelID string
LabelName string
Color string
Order int
IsFolder bool
TotalOnAPI uint
UnreadOnAPI uint
}
func txGetCountsFromBucketOrNew(bkt *bolt.Bucket, labelID string) (*mailboxCounts, error) {
mc := &mailboxCounts{}
if mcJSON := bkt.Get([]byte(labelID)); mcJSON != nil {
if err := json.Unmarshal(mcJSON, mc); err != nil {
return nil, err
}
}
mc.LabelID = labelID // if it was empty before we need to set labelID
return mc, nil
}
func (mc *mailboxCounts) txWriteToBucket(bucket *bolt.Bucket) error {
mcJSON, err := json.Marshal(mc)
if err != nil {
return err
}
return bucket.Put([]byte(mc.LabelID), mcJSON)
}
func getSystemFolders() []*mailboxCounts {
return []*mailboxCounts{
{pmapi.InboxLabel, "INBOX", "#000", -1000, true, 0, 0},
{pmapi.SentLabel, "Sent", "#000", -9, true, 0, 0},
{pmapi.ArchiveLabel, "Archive", "#000", -8, true, 0, 0},
{pmapi.SpamLabel, "Spam", "#000", -7, true, 0, 0},
{pmapi.TrashLabel, "Trash", "#000", -6, true, 0, 0},
{pmapi.AllMailLabel, "All Mail", "#000", -5, true, 0, 0},
{pmapi.DraftLabel, "Drafts", "#000", -4, true, 0, 0},
}
}
// skipThisLabel decides to skip labelIDs that *are* pmapi system labels but *aren't* local system labels
// (i.e. if it's in `pmapi.SystemLabels` but not in `getSystemFolders` then we skip it, otherwise we don't).
func skipThisLabel(labelID string) bool {
switch labelID {
case pmapi.StarredLabel, pmapi.AllSentLabel, pmapi.AllDraftsLabel:
return true
}
return false
}
func sortByOrder(labels []*pmapi.Label) {
sort.Slice(labels, func(i, j int) bool {
return labels[i].Order < labels[j].Order
})
}
func (mc *mailboxCounts) getPMLabel() *pmapi.Label {
return &pmapi.Label{
ID: mc.LabelID,
Name: mc.LabelName,
Color: mc.Color,
Order: mc.Order,
Type: pmapi.LabelTypeMailbox,
Exclusive: mc.isExclusive(),
}
}
func (mc *mailboxCounts) isExclusive() int {
if mc.IsFolder {
return 1
}
return 0
}
// createOrUpdateMailboxCountsBuckets will not change the on-API-counts.
func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) error {
// Don't forget about system folders.
// It should set label id, name, color, isFolder, total, unread.
tx := func(tx *bolt.Tx) error {
countsBkt := tx.Bucket(countsBucket)
for _, label := range labels {
// Skipping is probably not necessary.
if skipThisLabel(label.ID) {
continue
}
// Get current data.
mailbox, err := txGetCountsFromBucketOrNew(countsBkt, label.ID)
if err != nil {
return err
}
// Update mailbox info, but dont change on-API-counts.
mailbox.LabelName = label.Name
mailbox.Color = label.Color
mailbox.Order = label.Order
mailbox.IsFolder = label.Exclusive == 1
// Write.
if err = mailbox.txWriteToBucket(countsBkt); err != nil {
return err
}
}
return nil
}
return store.db.Update(tx)
}
func (store *Store) getLabelsFromLocalStorage() ([]*pmapi.Label, error) {
countsOnAPI, err := store.getOnAPICounts()
if err != nil {
return nil, err
}
labels := []*pmapi.Label{}
for _, counts := range countsOnAPI {
labels = append(labels, counts.getPMLabel())
}
sortByOrder(labels)
return labels, nil
}
func (store *Store) getOnAPICounts() ([]*mailboxCounts, error) {
counts := []*mailboxCounts{}
tx := func(tx *bolt.Tx) error {
c := tx.Bucket(countsBucket).Cursor()
for k, countsB := c.First(); k != nil; k, countsB = c.Next() {
l := store.log.WithField("key", string(k))
if countsB == nil {
err := errors.New("empty counts in DB")
l.WithError(err).Error("While getting local labels")
return err
}
mbCounts := &mailboxCounts{}
if err := json.Unmarshal(countsB, mbCounts); err != nil {
l.WithError(err).Error("While unmarshaling local labels")
return err
}
counts = append(counts, mbCounts)
}
return nil
}
err := store.db.View(tx)
return counts, err
}
// createOrUpdateOnAPICounts will change only on-API-counts.
func (store *Store) createOrUpdateOnAPICounts(mailboxCountsOnAPI []*pmapi.MessagesCount) error {
store.log.WithField("apiCounts", mailboxCountsOnAPI).Debug("Updating API counts")
tx := func(tx *bolt.Tx) error {
countsBkt := tx.Bucket(countsBucket)
for _, countsOnAPI := range mailboxCountsOnAPI {
if skipThisLabel(countsOnAPI.LabelID) {
continue
}
// Get current data.
counts, err := txGetCountsFromBucketOrNew(countsBkt, countsOnAPI.LabelID)
if err != nil {
return err
}
// Update only counts.
counts.TotalOnAPI = uint(countsOnAPI.Total)
counts.UnreadOnAPI = uint(countsOnAPI.Unread)
if err = counts.txWriteToBucket(countsBkt); err != nil {
return err
}
}
return nil
}
return store.db.Update(tx)
}
func (store *Store) removeMailboxCount(labelID string) error {
err := store.db.Update(func(tx *bolt.Tx) error {
return tx.Bucket(countsBucket).Delete([]byte(labelID))
})
if err != nil {
store.log.WithError(err).
WithField("labelID", labelID).
Warning("Cannot remove counts")
}
return err
}

View File

@ -0,0 +1,126 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"testing"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
a "github.com/stretchr/testify/assert"
)
func newLabel(order int, id, name string) *pmapi.Label {
return &pmapi.Label{
ID: id,
Name: name,
Order: order,
}
}
func TestSortByOrder(t *testing.T) {
want := []*pmapi.Label{
newLabel(-1000, pmapi.InboxLabel, "INBOX"),
newLabel(-5, pmapi.SentLabel, "Sent"),
newLabel(-4, pmapi.ArchiveLabel, "Archive"),
newLabel(-3, pmapi.SpamLabel, "Spam"),
newLabel(-2, pmapi.TrashLabel, "Trash"),
newLabel(-1, pmapi.AllMailLabel, "All Mail"),
newLabel(100, "labelID1", "custom_label"),
newLabel(1000, "folderID1", "custom_folder"),
}
labels := []*pmapi.Label{
want[6],
want[4],
want[3],
want[7],
want[5],
want[0],
want[2],
want[1],
}
sortByOrder(labels)
a.Equal(t, want, labels)
}
func TestMailboxNames(t *testing.T) {
want := map[string]string{
pmapi.InboxLabel: "INBOX",
pmapi.SentLabel: "Sent",
pmapi.ArchiveLabel: "Archive",
pmapi.SpamLabel: "Spam",
pmapi.TrashLabel: "Trash",
pmapi.AllMailLabel: "All Mail",
pmapi.DraftLabel: "Drafts",
"labelID1": "Labels/Label1",
"folderID1": "Folders/Folder1",
}
foldersAndLabels := []*pmapi.Label{
newLabel(100, "labelID1", "Label1"),
newLabel(1000, "folderID1", "Folder1"),
}
foldersAndLabels[1].Exclusive = 1
for _, counts := range getSystemFolders() {
foldersAndLabels = append(foldersAndLabels, counts.getPMLabel())
}
got := map[string]string{}
for _, m := range foldersAndLabels {
got[m.ID] = getLabelPrefix(m) + m.Name
}
a.Equal(t, want, got)
}
func TestAddSystemLabels(t *testing.T) {}
func checkCounts(t testing.TB, wantCounts []*pmapi.MessagesCount, haveStore *Store) {
nSystemFolders := 7
haveCounts, err := haveStore.getOnAPICounts()
a.NoError(t, err)
a.Len(t, haveCounts, len(wantCounts)+nSystemFolders)
for iWant, wantCount := range wantCounts {
iHave := iWant + nSystemFolders
haveCount := haveCounts[iHave]
a.Equal(t, wantCount.LabelID, haveCount.LabelID, "iHave:%d\niWant:%d\nHave:%v\nWant:%v", iHave, iWant, haveCount, wantCount)
a.Equal(t, wantCount.Total, int(haveCount.TotalOnAPI), "iHave:%d\niWant:%d\nHave:%v\nWant:%v", iHave, iWant, haveCount, wantCount)
a.Equal(t, wantCount.Unread, int(haveCount.UnreadOnAPI), "iHave:%d\niWant:%d\nHave:%v\nWant:%v", iHave, iWant, haveCount, wantCount)
}
}
func TestMailboxCountRemove(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
testCounts := []*pmapi.MessagesCount{
{LabelID: "label1", Total: 100, Unread: 0},
{LabelID: "label2", Total: 100, Unread: 30},
{LabelID: "label4", Total: 100, Unread: 100},
}
a.NoError(t, m.store.createOrUpdateOnAPICounts(testCounts))
a.NoError(t, m.store.removeMailboxCount("not existing"))
checkCounts(t, testCounts, m.store)
var pop *pmapi.MessagesCount
pop, testCounts = testCounts[2], testCounts[0:2]
a.NoError(t, m.store.removeMailboxCount(pop.LabelID))
checkCounts(t, testCounts, m.store)
}

View File

@ -0,0 +1,263 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"bytes"
"math"
"net/mail"
"regexp"
"strings"
"github.com/ProtonMail/proton-bridge/internal/imap/uidplus"
"github.com/pkg/errors"
bolt "go.etcd.io/bbolt"
)
// GetAPIIDsFromUIDRange returns API IDs by IMAP UID range.
//
// API IDs are the long base64 strings that the API uses to identify messages.
// UIDs are unique increasing integers that must be unique within a mailbox.
func (storeMailbox *Mailbox) GetAPIIDsFromUIDRange(start, stop uint32) (apiIDs []string, err error) {
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
b := storeMailbox.txGetIMAPIDsBucket(tx)
if stop == 0 {
// A null stop means no stop.
stop = ^uint32(0)
}
startb := itob(start)
stopb := itob(stop)
c := b.Cursor()
for k, v := c.Seek(startb); k != nil && bytes.Compare(k, stopb) <= 0; k, v = c.Next() {
apiIDs = append(apiIDs, string(v))
}
return nil
})
return
}
// GetAPIIDsFromSequenceRange returns API IDs by IMAP sequence number range.
func (storeMailbox *Mailbox) GetAPIIDsFromSequenceRange(start, stop uint32) (apiIDs []string, err error) {
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
b := storeMailbox.txGetIMAPIDsBucket(tx)
c := b.Cursor()
var i uint32
for k, v := c.First(); k != nil; k, v = c.Next() {
i++
if i < start {
continue
}
if stop > 0 && i > stop {
break
}
apiIDs = append(apiIDs, string(v))
}
return nil
})
return
}
// GetLatestAPIID returns the latest message API ID which still exists.
// Info: not the latest IMAP UID which can be already removed.
func (storeMailbox *Mailbox) GetLatestAPIID() (apiID string, err error) {
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
b := storeMailbox.txGetAPIIDsBucket(tx)
c := b.Cursor()
lastAPIID, _ := c.Last()
apiID = string(lastAPIID)
if apiID == "" {
return errors.New("cannot get latest API ID: empty mailbox")
}
return nil
})
return
}
// GetNextUID returns the next IMAP UID.
func (storeMailbox *Mailbox) GetNextUID() (uid uint32, err error) {
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
b := storeMailbox.txGetIMAPIDsBucket(tx)
uid, err = storeMailbox.txGetNextUID(b, false)
return err
})
return
}
func (storeMailbox *Mailbox) txGetNextUID(imapIDBucket *bolt.Bucket, write bool) (uint32, error) {
var uid uint64
var err error
if write {
uid, err = imapIDBucket.NextSequence()
if err != nil {
return 0, err
}
} else {
uid = imapIDBucket.Sequence() + 1
}
if math.MaxUint32 <= uid {
return 0, errors.New("too large sequence number")
}
return uint32(uid), nil
}
// getUID returns IMAP UID in this mailbox for message ID.
func (storeMailbox *Mailbox) getUID(apiID string) (uid uint32, err error) {
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
uid, err = storeMailbox.txGetUID(tx, apiID)
return err
})
return
}
func (storeMailbox *Mailbox) txGetUID(tx *bolt.Tx, apiID string) (uint32, error) {
b := storeMailbox.txGetAPIIDsBucket(tx)
v := b.Get([]byte(apiID))
if v == nil {
return 0, ErrNoSuchAPIID
}
return btoi(v), nil
}
// getSequenceNumber returns IMAP sequence number in the mailbox for the message with the given API ID `apiID`.
func (storeMailbox *Mailbox) getSequenceNumber(apiID string) (seqNum uint32, err error) {
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
b := storeMailbox.txGetIMAPIDsBucket(tx)
uid, err := storeMailbox.txGetUID(tx, apiID)
if err != nil {
return err
}
seqNum, err = storeMailbox.txGetSequenceNumberOfUID(b, itob(uid))
return err
})
return
}
// txGetSequenceNumberOfUID returns the IMAP sequence number of the message
// with the given IMAP UID bytes `uidb`.
//
// NOTE: The `bolt.Cursor.Next()` loops in order of ascending key bytes. The
// IMAP UID bucket is ordered by increasing UID because it's using BigEndian to
// encode uint into byte. Hence the sequence number (IMAP ID) corresponds to
// position of uid key in this order.
func (storeMailbox *Mailbox) txGetSequenceNumberOfUID(bucket *bolt.Bucket, uidb []byte) (uint32, error) {
seqNum := uint32(0)
c := bucket.Cursor()
// Speed up for the case of last message. This is always true for
// adding new message. It will return number of keys in bucket because
// sequence number starts with 1.
// We cannot use bucket.Stats() for that--it doesn't work in the same
// transaction because stats are updated when transaction is committed.
// But we can at least optimise to not do equal for all keys.
lastKey, _ := c.Last()
isLast := bytes.Equal(lastKey, uidb)
for k, _ := c.First(); k != nil; k, _ = c.Next() {
seqNum++ // Sequence number starts at 1.
if isLast {
continue
}
if bytes.Equal(k, uidb) {
return seqNum, nil
}
}
if isLast {
return seqNum, nil
}
return 0, ErrNoSuchUID
}
// GetUIDList returns UID list corresponding to messageIDs in a requested order.
func (storeMailbox *Mailbox) GetUIDList(apiIDs []string) *uidplus.OrderedSeq {
seqSet := &uidplus.OrderedSeq{}
_ = storeMailbox.db().View(func(tx *bolt.Tx) error {
b := storeMailbox.txGetAPIIDsBucket(tx)
for _, apiID := range apiIDs {
v := b.Get([]byte(apiID))
if v == nil {
storeMailbox.log.
WithField("msgID", apiID).
Warn("Cannot find UID")
continue
}
seqSet.Add(btoi(v))
}
return nil
})
return seqSet
}
// GetUIDByHeader returns UID of message existing in mailbox or zero if no match found.
func (storeMailbox *Mailbox) GetUIDByHeader(header *mail.Header) (foundUID uint32) {
if header == nil {
return uint32(0)
}
// Message-Id in appended-after-send mail is processed as ExternalID
// in PM message. Message-Id in normal copy/move will be the PM internal ID.
messageID := header.Get("Message-Id")
// The most often situation is that message is APPENDed after it was sent so the
// Message-ID will be reflected by ExternalID in API message meta-data.
externalID := strings.Trim(messageID, "<> ") // remove '<>' to improve match
matchExternalID := regexp.MustCompile(`"ExternalID":"` +
` *(\\u003c)? *` + // \u003c is equivalent to `<`
regexp.QuoteMeta(externalID) +
` *(\\u003e)? *` + // \u0033 is equivalent to `>`
`"`,
)
// It is possible that client will try to COPY existing message to Sent
// using APPEND command. In that case the Message-Id from header will
// be internal message ID and we need to check whether it's already there.
matchInternalID := bytes.Split([]byte(externalID), []byte("@"))[0]
_ = storeMailbox.db().View(func(tx *bolt.Tx) error {
metaBucket := tx.Bucket(metadataBucket)
b := storeMailbox.txGetIMAPIDsBucket(tx)
c := b.Cursor()
imapID, apiID := c.Last()
for ; imapID != nil; imapID, apiID = c.Prev() {
rawMeta := metaBucket.Get(apiID)
if rawMeta == nil {
storeMailbox.log.
WithField("IMAP-UID", imapID).
WithField("API-ID", apiID).
Warn("Cannot find meta-data while searching for externalID")
continue
}
if !matchExternalID.Match(rawMeta) && !bytes.Equal(apiID, matchInternalID) {
continue
}
foundUID = btoi(imapID)
return nil
}
return nil
})
return foundUID
}

View File

@ -0,0 +1,147 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"net/mail"
"testing"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
a "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type wantID struct {
appID string
uid int
}
func TestGetSequenceNumberAndGetUID(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{pmapi.AllMailLabel})
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
checkMailboxMessageIDs(t, m, pmapi.InboxLabel, []wantID{{"msg1", 1}, {"msg3", 2}})
checkMailboxMessageIDs(t, m, pmapi.ArchiveLabel, []wantID{{"msg2", 1}})
checkMailboxMessageIDs(t, m, pmapi.SpamLabel, []wantID(nil))
checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg1", 1}, {"msg2", 2}, {"msg3", 3}, {"msg4", 4}})
}
// checkMailboxMessageIDs checks that the mailbox contains all API IDs with correct sequence numbers and UIDs.
// wantIDs is map from IMAP UID to API ID. Sequence number is detected automatically by order of the ID in the map.
func checkMailboxMessageIDs(t *testing.T, m *mocksForStore, mailboxLabel string, wantIDs []wantID) {
storeAddress := m.store.addresses[addrID1]
storeMailbox := storeAddress.mailboxes[mailboxLabel]
ids, err := storeMailbox.GetAPIIDsFromSequenceRange(0, uint32(len(wantIDs)))
require.Nil(t, err)
idx := 0
for _, wantID := range wantIDs {
id := ids[idx]
require.Equal(t, wantID.appID, id, "Got IDs: %+v", ids)
uid, err := storeMailbox.getUID(wantID.appID)
require.Nil(t, err)
a.Equal(t, uint32(wantID.uid), uid)
seqNum, err := storeMailbox.getSequenceNumber(wantID.appID)
require.Nil(t, err)
a.Equal(t, uint32(idx+1), seqNum)
idx++
}
}
func TestGetUIDByHeader(t *testing.T) { //nolint[funlen]
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
tstMsg := getTestMessage("msg1", "Without external ID", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
tstMsg = getTestMessage("msg2", "External ID with spaces", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg.ExternalID = " externalID-non-pm-com "
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
tstMsg = getTestMessage("msg3", "External ID with <>", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg.ExternalID = "<externalID@pm.me>"
tstMsg.Header = mail.Header{"References": []string{"wrongID", "externalID-non-pm-com", "msg2"}}
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
// Not sure if this is a real-world scenario but we should be able to address this properly.
tstMsg = getTestMessage("msg4", "External ID with <> and spaces and special characters", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
tstMsg.ExternalID = " < external.()+*[]ID@another.pm.me > "
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
testDataUIDByHeader := []struct {
header *mail.Header
wantID uint32
}{
{
&mail.Header{"Message-Id": []string{"wrongID"}},
0,
},
{
&mail.Header{"Message-Id": []string{"ext"}},
0,
},
{
&mail.Header{"Message-Id": []string{"externalID"}},
0,
},
{
&mail.Header{"Message-Id": []string{"msg1"}},
1,
},
{
&mail.Header{"Message-Id": []string{"<msg3@pm.me>"}},
3,
},
{
&mail.Header{"Message-Id": []string{"<externalID-non-pm-com>"}},
2,
},
{
&mail.Header{"Message-Id": []string{"externalID@pm.me"}},
3,
},
{
&mail.Header{"Message-Id": []string{"external.()+*[]ID@another.pm.me"}},
4,
},
}
storeAddress := m.store.addresses[addrID1]
storeMailbox := storeAddress.mailboxes[pmapi.SentLabel]
for _, td := range testDataUIDByHeader {
haveID := storeMailbox.GetUIDByHeader(td.header)
a.Equal(t, td.wantID, haveID, "testing header: %v", td.header)
}
}

View File

@ -0,0 +1,375 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
bolt "go.etcd.io/bbolt"
)
// GetMessage returns the `pmapi.Message` struct wrapped in `StoreMessage`
// tied to this mailbox.
func (storeMailbox *Mailbox) GetMessage(apiID string) (*Message, error) {
msg, err := storeMailbox.store.getMessageFromDB(apiID)
if err != nil {
return nil, err
}
return newStoreMessage(storeMailbox, msg), nil
}
// FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message
// wrapping it.
func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) {
msg, err := storeMailbox.store.fetchMessage(apiID)
if err != nil {
return nil, err
}
return newStoreMessage(storeMailbox, msg), nil
}
// ImportMessage imports the message by calling an API.
// It has to be propagated to all mailboxes which is done by the event loop.
func (storeMailbox *Mailbox) ImportMessage(msg *pmapi.Message, body []byte, labelIDs []string) error {
defer storeMailbox.pollNow()
if storeMailbox.labelID != pmapi.AllMailLabel {
labelIDs = append(labelIDs, storeMailbox.labelID)
}
importReqs := &pmapi.ImportMsgReq{
AddressID: msg.AddressID,
Body: body,
Unread: msg.Unread,
Flags: msg.Flags,
Time: msg.Time,
LabelIDs: labelIDs,
}
res, err := storeMailbox.api().Import([]*pmapi.ImportMsgReq{importReqs})
if err == nil && len(res) > 0 {
msg.ID = res[0].MessageID
}
return err
}
// LabelMessages adds the label by calling an API.
// It has to be propagated to all the same messages in all mailboxes.
// The propagation is processed by the event loop.
func (storeMailbox *Mailbox) LabelMessages(apiIDs []string) error {
log.WithFields(logrus.Fields{
"messages": apiIDs,
"label": storeMailbox.labelID,
"mailbox": storeMailbox.Name,
}).Trace("Labeling messages")
defer storeMailbox.pollNow()
return storeMailbox.api().LabelMessages(apiIDs, storeMailbox.labelID)
}
// UnlabelMessages removes the label by calling an API.
// It has to be propagated to all the same messages in all mailboxes.
// The propagation is processed by the event loop.
func (storeMailbox *Mailbox) UnlabelMessages(apiIDs []string) error {
log.WithFields(logrus.Fields{
"messages": apiIDs,
"label": storeMailbox.labelID,
"mailbox": storeMailbox.Name,
}).Trace("Unlabeling messages")
defer storeMailbox.pollNow()
return storeMailbox.api().UnlabelMessages(apiIDs, storeMailbox.labelID)
}
// MarkMessagesRead marks the message read by calling an API.
// It has to be propagated to metadata mailbox which is done by the event loop.
func (storeMailbox *Mailbox) MarkMessagesRead(apiIDs []string) error {
log.WithFields(logrus.Fields{
"messages": apiIDs,
"label": storeMailbox.labelID,
"mailbox": storeMailbox.Name,
}).Trace("Marking messages as read")
defer storeMailbox.pollNow()
// Before deleting a message, TB sets \Seen flag which causes an event update
// and thus a refresh of the message by deleting and creating it again.
// TB does not notice this and happily continues with next command to move
// the message to the Trash but the message does not exist anymore.
// Therefore we do not issue API update if the message is already read.
ids := []string{}
for _, apiID := range apiIDs {
if message, _ := storeMailbox.store.getMessageFromDB(apiID); message == nil || message.Unread == 1 {
ids = append(ids, apiID)
}
}
return storeMailbox.api().MarkMessagesRead(ids)
}
// MarkMessagesUnread marks the message unread by calling an API.
// It has to be propagated to metadata mailbox which is done by the event loop.
func (storeMailbox *Mailbox) MarkMessagesUnread(apiIDs []string) error {
log.WithFields(logrus.Fields{
"messages": apiIDs,
"label": storeMailbox.labelID,
"mailbox": storeMailbox.Name,
}).Trace("Marking messages as unread")
defer storeMailbox.pollNow()
return storeMailbox.api().MarkMessagesUnread(apiIDs)
}
// MarkMessagesStarred adds the Starred label by calling an API.
// It has to be propagated to all the same messages in all mailboxes.
// The propagation is processed by the event loop.
func (storeMailbox *Mailbox) MarkMessagesStarred(apiIDs []string) error {
log.WithFields(logrus.Fields{
"messages": apiIDs,
"label": storeMailbox.labelID,
"mailbox": storeMailbox.Name,
}).Trace("Marking messages as starred")
defer storeMailbox.pollNow()
return storeMailbox.api().LabelMessages(apiIDs, pmapi.StarredLabel)
}
// MarkMessagesUnstarred removes the Starred label by calling an API.
// It has to be propagated to all the same messages in all mailboxes.
// The propagation is processed by the event loop.
func (storeMailbox *Mailbox) MarkMessagesUnstarred(apiIDs []string) error {
log.WithFields(logrus.Fields{
"messages": apiIDs,
"label": storeMailbox.labelID,
"mailbox": storeMailbox.Name,
}).Trace("Marking messages as unstarred")
defer storeMailbox.pollNow()
return storeMailbox.api().UnlabelMessages(apiIDs, pmapi.StarredLabel)
}
// DeleteMessages deletes messages.
// If the mailbox is All Mail or All Sent, it does nothing.
// If the mailbox is Trash or Spam and message is not in any other mailbox, messages is deleted.
// In all other cases the message is only removed from the mailbox.
func (storeMailbox *Mailbox) DeleteMessages(apiIDs []string) error {
log.WithFields(logrus.Fields{
"messages": apiIDs,
"label": storeMailbox.labelID,
"mailbox": storeMailbox.Name,
}).Trace("Deleting messages")
defer storeMailbox.pollNow()
switch storeMailbox.labelID {
case pmapi.AllMailLabel, pmapi.AllSentLabel:
break
case pmapi.TrashLabel, pmapi.SpamLabel:
messageIDsToDelete := []string{}
messageIDsToUnlabel := []string{}
for _, apiID := range apiIDs {
msg, err := storeMailbox.store.getMessageFromDB(apiID)
if err != nil {
return err
}
otherLabels := false
// If the message has any custom label, we don't want to delete it, only remove trash/spam label.
for _, label := range msg.LabelIDs {
if label != pmapi.SpamLabel && label != pmapi.TrashLabel && label != pmapi.AllMailLabel && label != pmapi.AllSentLabel && label != pmapi.DraftLabel && label != pmapi.AllDraftsLabel {
otherLabels = true
break
}
}
if otherLabels {
messageIDsToUnlabel = append(messageIDsToUnlabel, apiID)
} else {
messageIDsToDelete = append(messageIDsToDelete, apiID)
}
}
if len(messageIDsToUnlabel) > 0 {
if err := storeMailbox.api().UnlabelMessages(messageIDsToUnlabel, storeMailbox.labelID); err != nil {
log.WithError(err).Warning("Cannot unlabel before deleting")
}
}
if len(messageIDsToDelete) > 0 {
if err := storeMailbox.api().DeleteMessages(messageIDsToDelete); err != nil {
return err
}
}
default:
if err := storeMailbox.api().UnlabelMessages(apiIDs, storeMailbox.labelID); err != nil {
return err
}
}
return nil
}
func (storeMailbox *Mailbox) txSkipAndRemoveFromMailbox(tx *bolt.Tx, msg *pmapi.Message) (skipAndRemove bool) {
defer func() {
if skipAndRemove {
if err := storeMailbox.txDeleteMessage(tx, msg.ID); err != nil {
storeMailbox.log.WithError(err).Error("Cannot remove message")
}
}
}()
mode, err := storeMailbox.store.getAddressMode()
if err != nil {
log.WithError(err).Error("Could not determine address mode")
return
}
skipAndRemove = true
// If it's split mode and it shouldn't be under this address, it should be skipped and removed.
if mode == splitMode && storeMailbox.storeAddress.addressID != msg.AddressID {
return
}
// If the message belongs in this mailbox, don't skip/remove it.
for _, labelID := range msg.LabelIDs {
if labelID == storeMailbox.labelID {
skipAndRemove = false
return
}
}
return skipAndRemove
}
// txCreateOrUpdateMessages will delete, create or update message from mailbox.
func (storeMailbox *Mailbox) txCreateOrUpdateMessages(tx *bolt.Tx, msgs []*pmapi.Message) error { //nolint[funlen]
// Buckets are not initialized right away because it's a heavy operation.
// The best option is to get the same bucket only once and only when needed.
var apiBucket, imapBucket *bolt.Bucket
for _, msg := range msgs {
if storeMailbox.txSkipAndRemoveFromMailbox(tx, msg) {
continue
}
// Update message.
if apiBucket == nil {
apiBucket = storeMailbox.txGetAPIIDsBucket(tx)
}
// Draft bodies can change and bodies are not re-fetched by IMAP clients.
// Every change has to be a new message; we need to delete the old one and always recreate it.
if storeMailbox.labelID == pmapi.DraftLabel {
if err := storeMailbox.txDeleteMessage(tx, msg.ID); err != nil {
return errors.Wrap(err, "cannot delete old draft")
}
} else {
uidb := apiBucket.Get([]byte(msg.ID))
if uidb != nil {
if imapBucket == nil {
imapBucket = storeMailbox.txGetIMAPIDsBucket(tx)
}
seqNum, seqErr := storeMailbox.txGetSequenceNumberOfUID(imapBucket, uidb)
if seqErr == nil {
storeMailbox.store.imapUpdateMessage(
storeMailbox.storeAddress.address,
storeMailbox.labelName,
btoi(uidb),
seqNum,
msg,
)
}
continue
}
}
// Create a new message.
if imapBucket == nil {
imapBucket = storeMailbox.txGetIMAPIDsBucket(tx)
}
uid, err := storeMailbox.txGetNextUID(imapBucket, true)
if err != nil {
return errors.Wrap(err, "cannot generate new UID")
}
uidb := itob(uid)
if err = imapBucket.Put(uidb, []byte(msg.ID)); err != nil {
return errors.Wrap(err, "cannot add to IMAP bucket")
}
if err = apiBucket.Put([]byte(msg.ID), uidb); err != nil {
return errors.Wrap(err, "cannot add to API bucket")
}
seqNum, err := storeMailbox.txGetSequenceNumberOfUID(imapBucket, uidb)
if err != nil {
return errors.Wrap(err, "cannot get sequence number from UID")
}
storeMailbox.store.imapUpdateMessage(
storeMailbox.storeAddress.address,
storeMailbox.labelName,
uid,
seqNum,
msg,
)
}
return storeMailbox.txMailboxStatusUpdate(tx)
}
// txDeleteMessage deletes the message from the mailbox bucket.
// and issues message delete and mailbox update changes to updates channel.
func (storeMailbox *Mailbox) txDeleteMessage(tx *bolt.Tx, apiID string) error {
apiBucket := storeMailbox.txGetAPIIDsBucket(tx)
apiIDb := []byte(apiID)
uidb := apiBucket.Get(apiIDb)
if uidb == nil {
return nil
}
imapBucket := storeMailbox.txGetIMAPIDsBucket(tx)
seqNum, seqNumErr := storeMailbox.txGetSequenceNumberOfUID(imapBucket, uidb)
if seqNumErr != nil {
storeMailbox.log.WithField("apiID", apiID).WithError(seqNumErr).Warn("Cannot get seqNum of deleting message")
}
if err := imapBucket.Delete(uidb); err != nil {
return errors.Wrap(err, "cannot delete from IMAP bucket")
}
if err := apiBucket.Delete(apiIDb); err != nil {
return errors.Wrap(err, "cannot delete from API bucket")
}
if seqNumErr == nil {
storeMailbox.store.imapDeleteMessage(
storeMailbox.storeAddress.address,
storeMailbox.labelName,
seqNum,
)
if err := storeMailbox.txMailboxStatusUpdate(tx); err != nil {
return err
}
}
return nil
}
func (storeMailbox *Mailbox) txMailboxStatusUpdate(tx *bolt.Tx) error {
total, unread, err := storeMailbox.txGetCounts(tx)
if err != nil {
return errors.Wrap(err, "cannot get counts for mailbox status update")
}
storeMailbox.store.imapMailboxStatus(
storeMailbox.storeAddress.address,
storeMailbox.labelName,
total,
unread,
)
return nil
}

View File

@ -0,0 +1,31 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"os"
"github.com/sirupsen/logrus"
)
func init() { //nolint[gochecknoinits]
logrus.SetLevel(logrus.ErrorLevel)
if os.Getenv("VERBOSITY") == "trace" {
logrus.SetLevel(logrus.TraceLevel)
}
}

108
internal/store/message.go Normal file
View File

@ -0,0 +1,108 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"net/mail"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
bolt "go.etcd.io/bbolt"
)
// Message is wrapper around `pmapi.Message` with connection to
// a specific mailbox with helper functions to get IMAP UID, sequence
// numbers and similar.
type Message struct {
api PMAPIProvider
msg *pmapi.Message
store *Store
storeMailbox *Mailbox
}
func newStoreMessage(storeMailbox *Mailbox, msg *pmapi.Message) *Message {
return &Message{
api: storeMailbox.store.api,
msg: msg,
store: storeMailbox.store,
storeMailbox: storeMailbox,
}
}
// ID returns message ID on our API (always the same ID for all mailboxes).
func (message *Message) ID() string {
return message.msg.ID
}
// UID returns message UID for IMAP, specific for mailbox used to get the message.
func (message *Message) UID() (uint32, error) {
return message.storeMailbox.getUID(message.ID())
}
// SequenceNumber returns index of message in used mailbox.
func (message *Message) SequenceNumber() (uint32, error) {
return message.storeMailbox.getSequenceNumber(message.ID())
}
// Message returns message struct from pmapi.
func (message *Message) Message() *pmapi.Message {
return message.msg
}
// SetSize updates the information about size of decrypted message which can be
// used for IMAP. This should not trigger any IMAP update.
// NOTE: The size from the server corresponds to pure body bytes. Hence it
// should not be used. The correct size has to be calculated from decrypted and
// built message.
func (message *Message) SetSize(size int64) error {
message.msg.Size = size
txUpdate := func(tx *bolt.Tx) error {
stored, err := message.store.txGetMessage(tx, message.msg.ID)
if err != nil {
return err
}
stored.Size = size
return message.store.txPutMessage(
tx.Bucket(metadataBucket),
stored,
)
}
return message.store.db.Update(txUpdate)
}
// SetContentTypeAndHeader updates the information about content type and
// header of decrypted message. This should not trigger any IMAP update.
// NOTE: Content type depends on details of decrypted message which we want to
// cache.
func (message *Message) SetContentTypeAndHeader(mimeType string, header mail.Header) error {
message.msg.MIMEType = mimeType
message.msg.Header = header
txUpdate := func(tx *bolt.Tx) error {
stored, err := message.store.txGetMessage(tx, message.msg.ID)
if err != nil {
return err
}
stored.MIMEType = mimeType
stored.Header = header
return message.store.txPutMessage(
tx.Bucket(metadataBucket),
stored,
)
}
return message.store.db.Update(txUpdate)
}

View File

@ -0,0 +1,193 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser)
// Package mocks is a generated GoMock package.
package mocks
import (
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockPanicHandler is a mock of PanicHandler interface
type MockPanicHandler struct {
ctrl *gomock.Controller
recorder *MockPanicHandlerMockRecorder
}
// MockPanicHandlerMockRecorder is the mock recorder for MockPanicHandler
type MockPanicHandlerMockRecorder struct {
mock *MockPanicHandler
}
// NewMockPanicHandler creates a new mock instance
func NewMockPanicHandler(ctrl *gomock.Controller) *MockPanicHandler {
mock := &MockPanicHandler{ctrl: ctrl}
mock.recorder = &MockPanicHandlerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockPanicHandler) EXPECT() *MockPanicHandlerMockRecorder {
return m.recorder
}
// HandlePanic mocks base method
func (m *MockPanicHandler) HandlePanic() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "HandlePanic")
}
// HandlePanic indicates an expected call of HandlePanic
func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
}
// MockBridgeUser is a mock of BridgeUser interface
type MockBridgeUser struct {
ctrl *gomock.Controller
recorder *MockBridgeUserMockRecorder
}
// MockBridgeUserMockRecorder is the mock recorder for MockBridgeUser
type MockBridgeUserMockRecorder struct {
mock *MockBridgeUser
}
// NewMockBridgeUser creates a new mock instance
func NewMockBridgeUser(ctrl *gomock.Controller) *MockBridgeUser {
mock := &MockBridgeUser{ctrl: ctrl}
mock.recorder = &MockBridgeUserMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockBridgeUser) EXPECT() *MockBridgeUserMockRecorder {
return m.recorder
}
// CloseConnection mocks base method
func (m *MockBridgeUser) CloseConnection(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "CloseConnection", arg0)
}
// CloseConnection indicates an expected call of CloseConnection
func (mr *MockBridgeUserMockRecorder) CloseConnection(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseConnection", reflect.TypeOf((*MockBridgeUser)(nil).CloseConnection), arg0)
}
// GetAddressID mocks base method
func (m *MockBridgeUser) GetAddressID(arg0 string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAddressID", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAddressID indicates an expected call of GetAddressID
func (mr *MockBridgeUserMockRecorder) GetAddressID(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddressID", reflect.TypeOf((*MockBridgeUser)(nil).GetAddressID), arg0)
}
// GetPrimaryAddress mocks base method
func (m *MockBridgeUser) GetPrimaryAddress() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPrimaryAddress")
ret0, _ := ret[0].(string)
return ret0
}
// GetPrimaryAddress indicates an expected call of GetPrimaryAddress
func (mr *MockBridgeUserMockRecorder) GetPrimaryAddress() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrimaryAddress", reflect.TypeOf((*MockBridgeUser)(nil).GetPrimaryAddress))
}
// GetStoreAddresses mocks base method
func (m *MockBridgeUser) GetStoreAddresses() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetStoreAddresses")
ret0, _ := ret[0].([]string)
return ret0
}
// GetStoreAddresses indicates an expected call of GetStoreAddresses
func (mr *MockBridgeUserMockRecorder) GetStoreAddresses() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStoreAddresses", reflect.TypeOf((*MockBridgeUser)(nil).GetStoreAddresses))
}
// ID mocks base method
func (m *MockBridgeUser) ID() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ID")
ret0, _ := ret[0].(string)
return ret0
}
// ID indicates an expected call of ID
func (mr *MockBridgeUserMockRecorder) ID() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockBridgeUser)(nil).ID))
}
// IsCombinedAddressMode mocks base method
func (m *MockBridgeUser) IsCombinedAddressMode() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsCombinedAddressMode")
ret0, _ := ret[0].(bool)
return ret0
}
// IsCombinedAddressMode indicates an expected call of IsCombinedAddressMode
func (mr *MockBridgeUserMockRecorder) IsCombinedAddressMode() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsCombinedAddressMode", reflect.TypeOf((*MockBridgeUser)(nil).IsCombinedAddressMode))
}
// IsConnected mocks base method
func (m *MockBridgeUser) IsConnected() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsConnected")
ret0, _ := ret[0].(bool)
return ret0
}
// IsConnected indicates an expected call of IsConnected
func (mr *MockBridgeUserMockRecorder) IsConnected() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockBridgeUser)(nil).IsConnected))
}
// Logout mocks base method
func (m *MockBridgeUser) Logout() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Logout")
ret0, _ := ret[0].(error)
return ret0
}
// Logout indicates an expected call of Logout
func (mr *MockBridgeUserMockRecorder) Logout() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logout", reflect.TypeOf((*MockBridgeUser)(nil).Logout))
}
// UpdateUser mocks base method
func (m *MockBridgeUser) UpdateUser() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUser")
ret0, _ := ret[0].(error)
return ret0
}
// UpdateUser indicates an expected call of UpdateUser
func (mr *MockBridgeUserMockRecorder) UpdateUser() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockBridgeUser)(nil).UpdateUser))
}

View File

@ -0,0 +1,106 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/pkg/listener (interfaces: Listener)
// Package mocks is a generated GoMock package.
package mocks
import (
gomock "github.com/golang/mock/gomock"
reflect "reflect"
time "time"
)
// MockListener is a mock of Listener interface
type MockListener struct {
ctrl *gomock.Controller
recorder *MockListenerMockRecorder
}
// MockListenerMockRecorder is the mock recorder for MockListener
type MockListenerMockRecorder struct {
mock *MockListener
}
// NewMockListener creates a new mock instance
func NewMockListener(ctrl *gomock.Controller) *MockListener {
mock := &MockListener{ctrl: ctrl}
mock.recorder = &MockListenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
return m.recorder
}
// Add mocks base method
func (m *MockListener) Add(arg0 string, arg1 chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", arg0, arg1)
}
// Add indicates an expected call of Add
func (mr *MockListenerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), arg0, arg1)
}
// Emit mocks base method
func (m *MockListener) Emit(arg0, arg1 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Emit", arg0, arg1)
}
// Emit indicates an expected call of Emit
func (mr *MockListenerMockRecorder) Emit(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1)
}
// Remove mocks base method
func (m *MockListener) Remove(arg0 string, arg1 chan<- string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Remove", arg0, arg1)
}
// Remove indicates an expected call of Remove
func (mr *MockListenerMockRecorder) Remove(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), arg0, arg1)
}
// RetryEmit mocks base method
func (m *MockListener) RetryEmit(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RetryEmit", arg0)
}
// RetryEmit indicates an expected call of RetryEmit
func (mr *MockListenerMockRecorder) RetryEmit(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), arg0)
}
// SetBuffer mocks base method
func (m *MockListener) SetBuffer(arg0 string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetBuffer", arg0)
}
// SetBuffer indicates an expected call of SetBuffer
func (mr *MockListenerMockRecorder) SetBuffer(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), arg0)
}
// SetLimit mocks base method
func (m *MockListener) SetLimit(arg0 string, arg1 time.Duration) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetLimit", arg0, arg1)
}
// SetLimit indicates an expected call of SetLimit
func (mr *MockListenerMockRecorder) SetLimit(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), arg0, arg1)
}

396
internal/store/store.go Normal file
View File

@ -0,0 +1,396 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Package store communicates with API and caches metadata in a local database.
package store
import (
"fmt"
"os"
"sync"
"time"
"github.com/ProtonMail/proton-bridge/pkg/listener"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/hashicorp/go-multierror"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
bolt "go.etcd.io/bbolt"
)
const (
// PathDelimiter for IMAP
PathDelimiter = "/"
// UserLabelsMailboxName for IMAP
UserLabelsMailboxName = "Labels"
// UserLabelsPrefix contains name with delimiter for IMAP
UserLabelsPrefix = UserLabelsMailboxName + PathDelimiter
// UserFoldersMailboxName for IMAP
UserFoldersMailboxName = "Folders"
// UserFoldersPrefix contains name with delimiter for IMAP
UserFoldersPrefix = UserFoldersMailboxName + PathDelimiter
)
var (
log = logrus.WithField("pkg", "store") //nolint[gochecknoglobals]
// Database structure:
// * metadata
// * {messageID} -> message data (subject, from, to, time, headers, body size, ...)
// * counts
// * {mailboxID} -> mailboxCounts: totalOnAPI, unreadOnAPI, labelName, labelColor, labelIsExclusive
// * address_info
// * {index} -> {address, addressID}
// * address_mode
// * mode -> string split or combined
// * mailboxes_version
// * version -> uint32 value
// * sync_state
// * sync_state -> string timestamp when it was last synced (when missing, sync should be ongoing)
// * ids_ranges -> json array of groups with start and end message ID (when missing, there is no ongoing sync)
// * ids_to_be_deleted -> json array of message IDs to be deleted after sync (when missing, there is no ongoing sync)
// * mailboxes
// * {addressID+mailboxID}
// * imap_ids
// * {imapUID} -> string messageID
// * api_ids
// * {messageID} -> uint32 imapUID
metadataBucket = []byte("metadata") //nolint[gochecknoglobals]
countsBucket = []byte("counts") //nolint[gochecknoglobals]
addressInfoBucket = []byte("address_info") //nolint[gochecknoglobals]
addressModeBucket = []byte("address_mode") //nolint[gochecknoglobals]
syncStateBucket = []byte("sync_state") //nolint[gochecknoglobals]
mailboxesBucket = []byte("mailboxes") //nolint[gochecknoglobals]
imapIDsBucket = []byte("imap_ids") //nolint[gochecknoglobals]
apiIDsBucket = []byte("api_ids") //nolint[gochecknoglobals]
mboxVersionBucket = []byte("mailboxes_version") //nolint[gochecknoglobals]
// ErrNoSuchAPIID when mailbox does not have API ID.
ErrNoSuchAPIID = errors.New("no such api id") //nolint[gochecknoglobals]
// ErrNoSuchUID when mailbox does not have IMAP UID.
ErrNoSuchUID = errors.New("no such uid") //nolint[gochecknoglobals]
// ErrNoSuchSeqNum when mailbox does not have IMAP ID.
ErrNoSuchSeqNum = errors.New("no such sequence number") //nolint[gochecknoglobals]
)
// Store is local user storage, which handles the synchronization between IMAP and PM API.
type Store struct {
panicHandler PanicHandler
eventLoop *eventLoop
user BridgeUser
api PMAPIProvider
log *logrus.Entry
cache *Cache
filePath string
db *bolt.DB
lock *sync.RWMutex
addresses map[string]*Address
imapUpdates chan interface{}
isSyncRunning bool
addressMode addressMode
}
// New creates or opens a store for the given `user`.
func New(
panicHandler PanicHandler,
user BridgeUser,
api PMAPIProvider,
events listener.Listener,
path string,
cache *Cache,
) (store *Store, err error) {
if user == nil || api == nil || events == nil || cache == nil {
return nil, fmt.Errorf("missing parameters - user: %v, api: %v, events: %v, cache: %v", user, api, events, cache)
}
l := log.WithField("user", user.ID())
var firstInit bool
if _, existErr := os.Stat(path); os.IsNotExist(existErr) {
l.Info("Creating new store database file with address mode from user's credentials store")
firstInit = true
} else {
l.Info("Store database file already exists, using mode already set")
firstInit = false
}
bdb, err := openBoltDatabase(path)
if err != nil {
err = errors.Wrap(err, "failed to open store database")
return
}
store = &Store{
panicHandler: panicHandler,
api: api,
user: user,
cache: cache,
filePath: path,
db: bdb,
lock: &sync.RWMutex{},
log: l,
}
if err = store.init(firstInit); err != nil {
l.WithError(err).Error("Could not initialise store, attempting to close")
if storeCloseErr := store.Close(); storeCloseErr != nil {
l.WithError(storeCloseErr).Warn("Could not close uninitialised store")
}
err = errors.Wrap(err, "failed to initialise store")
return
}
if user.IsConnected() {
store.eventLoop = newEventLoop(cache, store, api, user, events)
go func() {
defer store.panicHandler.HandlePanic()
store.eventLoop.start()
}()
}
return store, err
}
func openBoltDatabase(filePath string) (db *bolt.DB, err error) {
l := log.WithField("path", filePath)
l.Debug("Opening bolt database")
if db, err = bolt.Open(filePath, 0600, &bolt.Options{Timeout: 1 * time.Second}); err != nil {
l.WithError(err).Error("Could not open bolt database")
return
}
if val, set := os.LookupEnv("BRIDGESTRICTMODE"); set && val == "1" {
db.StrictMode = true
}
tx := func(tx *bolt.Tx) (err error) {
if _, err = tx.CreateBucketIfNotExists(metadataBucket); err != nil {
return
}
if _, err = tx.CreateBucketIfNotExists(countsBucket); err != nil {
return
}
if _, err = tx.CreateBucketIfNotExists(addressInfoBucket); err != nil {
return
}
if _, err = tx.CreateBucketIfNotExists(addressModeBucket); err != nil {
return
}
if _, err = tx.CreateBucketIfNotExists(syncStateBucket); err != nil {
return
}
if _, err = tx.CreateBucketIfNotExists(mailboxesBucket); err != nil {
return
}
if _, err = tx.CreateBucketIfNotExists(mboxVersionBucket); err != nil {
return
}
return
}
if err = db.Update(tx); err != nil {
return
}
return db, err
}
// init initialises the store for the given addresses.
func (store *Store) init(firstInit bool) (err error) {
if store.addresses != nil {
store.log.Warn("Store was already initialised")
return
}
// If it's the first time we are creating the store, use the mode set in the
// user's credentials, otherwise read it from the DB (if present).
if firstInit {
if store.user.IsCombinedAddressMode() {
err = store.setAddressMode(combinedMode)
} else {
err = store.setAddressMode(splitMode)
}
if err != nil {
return errors.Wrap(err, "first init setting store address mode")
}
} else if store.addressMode, err = store.getAddressMode(); err != nil {
store.log.WithError(err).Error("Store address mode is unknown, setting to combined mode")
if err = store.setAddressMode(combinedMode); err != nil {
return errors.Wrap(err, "setting store address mode")
}
}
store.log.WithField("mode", store.addressMode).Debug("Initialising store")
labels, err := store.initCounts()
if err != nil {
store.log.WithError(err).Error("Could not initialise label counts")
return
}
if err = store.initAddresses(labels); err != nil {
store.log.WithError(err).Error("Could not initialise store addresses")
return
}
return err
}
// initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if
// the API is unavailable for whatever reason it tries to fetch the labels locally.
func (store *Store) initCounts() (labels []*pmapi.Label, err error) {
if labels, err = store.api.ListLabels(); err != nil {
store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.")
if labels, err = store.getLabelsFromLocalStorage(); err != nil {
store.log.WithError(err).Error("Cannot list local labels")
return
}
} else {
// the labels listed by PMAPI don't include system folders so we need to add them.
for _, counts := range getSystemFolders() {
labels = append(labels, counts.getPMLabel())
}
if err = store.createOrUpdateMailboxCountsBuckets(labels); err != nil {
store.log.WithError(err).Error("Cannot create counts")
return
}
if countsErr := store.updateCountsFromServer(); countsErr != nil {
store.log.WithError(countsErr).Warning("Continue without new counts from server")
}
}
sortByOrder(labels)
return
}
// initAddresses creates address objects in the store for each necessary address.
// In combined mode this means just one mailbox for all addresses but in split mode this means one mailbox per address.
func (store *Store) initAddresses(labels []*pmapi.Label) (err error) {
store.addresses = make(map[string]*Address)
addrInfo, err := store.GetAddressInfo()
if err != nil {
store.log.WithError(err).Error("Could not get addresses and address IDs")
return
}
// We need at least one address to continue.
if len(addrInfo) < 1 {
err = errors.New("no addresses to initialise")
store.log.WithError(err).Warn("There are no addresses to initialise")
return
}
// If in combined mode, we only need the user's primary address.
if store.addressMode == combinedMode {
addrInfo = addrInfo[:1]
}
for _, addr := range addrInfo {
if err = store.addAddress(addr.Address, addr.AddressID, labels); err != nil {
store.log.WithField("address", addr.Address).WithError(err).Error("Could not add address to store")
}
}
return err
}
// addAddress adds a new address to the store. If the address exists already it is overwritten.
func (store *Store) addAddress(address, addressID string, labels []*pmapi.Label) (err error) {
if _, ok := store.addresses[addressID]; ok {
store.log.WithField("addressID", addressID).Debug("Overwriting store address")
}
addr, err := newAddress(store, address, addressID, labels)
if err != nil {
return errors.Wrap(err, "failed to create store address object")
}
store.addresses[addressID] = addr
return
}
// Close stops the event loop and closes the database to free the file.
func (store *Store) Close() error {
store.lock.Lock()
defer store.lock.Unlock()
return store.close()
}
// CloseEventLoop stops the eventloop (if it is present).
func (store *Store) CloseEventLoop() {
if store.eventLoop != nil {
store.eventLoop.stop()
}
}
func (store *Store) close() error {
store.CloseEventLoop()
return store.db.Close()
}
// Remove closes and removes the database file and clears the cache file.
func (store *Store) Remove() (err error) {
store.lock.Lock()
defer store.lock.Unlock()
store.log.Trace("Removing store")
var result *multierror.Error
if err = store.close(); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to close store"))
}
if err = RemoveStore(store.cache, store.filePath, store.user.ID()); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to remove store"))
}
return result.ErrorOrNil()
}
// RemoveStore removes the database file and clears the cache file.
func RemoveStore(cache *Cache, path, userID string) error {
var result *multierror.Error
if err := cache.clearCacheUser(userID); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to clear event loop user cache"))
}
// RemoveAll will not return an error if the path does not exist.
if err := os.RemoveAll(path); err != nil {
result = multierror.Append(result, errors.Wrap(err, "failed to remove database file"))
}
return result.ErrorOrNil()
}

View File

@ -0,0 +1,112 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"github.com/pkg/errors"
bolt "go.etcd.io/bbolt"
)
type addressMode string
const (
splitMode addressMode = "split"
combinedMode addressMode = "combined"
modeKey = "mode"
)
// getAddressMode returns the current address mode (split or combined) of the store.
// It first looks in the local cache but if that is not yet set, it loads it from the database.
func (store *Store) getAddressMode() (mode addressMode, err error) {
if store.addressMode != "" {
mode = store.addressMode
return
}
tx := func(tx *bolt.Tx) (err error) {
b := tx.Bucket(addressModeBucket)
dbMode := b.Get([]byte(modeKey))
if dbMode == nil {
return errors.New("address mode not set")
}
mode = addressMode(dbMode)
return
}
err = store.db.View(tx)
return
}
// IsCombinedMode returns whether the store is set to combined mode.
func (store *Store) IsCombinedMode() bool {
return store.addressMode == combinedMode
}
// UseCombinedMode sets whether the store should be set to combined mode.
func (store *Store) UseCombinedMode(useCombined bool) (err error) {
if useCombined {
err = store.switchAddressMode(combinedMode)
} else {
err = store.switchAddressMode(splitMode)
}
return
}
// switchAddressMode sets the address mode to the given value and rebuilds the mailboxes.
func (store *Store) switchAddressMode(mode addressMode) (err error) {
if store.addressMode == mode {
log.Debug("The store is using the correct address mode")
return
}
if err = store.setAddressMode(mode); err != nil {
log.WithError(err).Error("Could not set store address mode")
return
}
if err = store.RebuildMailboxes(); err != nil {
log.WithError(err).Error("Could not rebuild mailboxes after switching address mode")
return
}
return
}
// setAddressMode sets the current address mode (split or combined) of the store.
// It writes to database and updates the local value in the store object.
func (store *Store) setAddressMode(mode addressMode) (err error) {
store.log.WithField("mode", string(mode)).Info("Setting store address mode")
tx := func(tx *bolt.Tx) (err error) {
b := tx.Bucket(addressModeBucket)
return b.Put([]byte(modeKey), []byte(mode))
}
if err = store.db.Update(tx); err != nil {
return
}
store.addressMode = mode
return
}

View File

@ -0,0 +1,68 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import bolt "go.etcd.io/bbolt"
const (
versionKey = "version"
// versionOffset makes it possible to force email client to reload all
// mailboxes. If increased during application update it will trigger
// the reload on client side without needing to sync DB or re-setup account.
versionOffset = uint32(3)
)
func (store *Store) getMailboxesVersion() uint32 {
localVersion := store.readMailboxesVersion()
// If a read error occurs it returns 0 which is an invalid version value.
if localVersion == 0 {
localVersion = 1
_ = store.writeMailboxesVersion(localVersion)
}
// versionOffset will make email clients reload if increased during bridge update.
return localVersion + versionOffset
}
func (store *Store) increaseMailboxesVersion() error {
ver := store.readMailboxesVersion()
// The version is zero if a read error occurred. Operation ++ will make it 1
// which is default starting value.
ver++
return store.writeMailboxesVersion(ver)
}
func (store *Store) readMailboxesVersion() (version uint32) {
_ = store.db.View(func(tx *bolt.Tx) (err error) {
b := tx.Bucket(mboxVersionBucket)
verRaw := b.Get([]byte(versionKey))
if verRaw != nil {
version = btoi(verRaw)
}
return nil
})
return
}
func (store *Store) writeMailboxesVersion(ver uint32) error {
return store.db.Update(func(tx *bolt.Tx) (err error) {
b := tx.Bucket(mboxVersionBucket)
return b.Put([]byte(versionKey), itob(ver))
})
}

View File

@ -0,0 +1,129 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"io/ioutil"
"os"
"path/filepath"
"sync"
"testing"
bridgemocks "github.com/ProtonMail/proton-bridge/internal/bridge/mocks"
storeMocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
const (
addr1 = "niceaddress@pm.me"
addrID1 = "niceaddressID"
addr2 = "jamesandmichalarecool@pm.me"
addrID2 = "jamesandmichalarecool"
)
type mocksForStore struct {
tb testing.TB
ctrl *gomock.Controller
events *storeMocks.MockListener
api *bridgemocks.MockPMAPIProvider
user *storeMocks.MockBridgeUser
panicHandler *storeMocks.MockPanicHandler
store *Store
tmpDir string
cache *Cache
}
func initMocks(tb testing.TB) (*mocksForStore, func()) {
ctrl := gomock.NewController(tb)
mocks := &mocksForStore{
tb: tb,
ctrl: ctrl,
events: storeMocks.NewMockListener(ctrl),
api: bridgemocks.NewMockPMAPIProvider(ctrl),
user: storeMocks.NewMockBridgeUser(ctrl),
panicHandler: storeMocks.NewMockPanicHandler(ctrl),
}
// Called during clean-up.
mocks.panicHandler.EXPECT().HandlePanic().AnyTimes()
var err error
mocks.tmpDir, err = ioutil.TempDir("", "store-test")
require.NoError(tb, err)
cacheFile := filepath.Join(mocks.tmpDir, "cache.json")
mocks.cache = NewCache(cacheFile)
return mocks, func() {
if err := recover(); err != nil {
panic(err)
}
if mocks.store != nil {
require.Nil(tb, mocks.store.Close())
}
ctrl.Finish()
require.NoError(tb, os.RemoveAll(mocks.tmpDir))
}
}
func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool) { //nolint[unparam]
mocks.user.EXPECT().ID().Return("userID").AnyTimes()
mocks.user.EXPECT().IsConnected().Return(true)
mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode)
mocks.api.EXPECT().Addresses().Return(pmapi.AddressList{
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive},
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive},
})
mocks.api.EXPECT().ListLabels()
mocks.api.EXPECT().CountMessages("")
mocks.api.EXPECT().GetEvent(gomock.Any()).
Return(&pmapi.Event{
EventID: "latestEventID",
}, nil).AnyTimes()
// We want to wait until first sync has finished.
firstSyncWaiter := sync.WaitGroup{}
firstSyncWaiter.Add(1)
mocks.api.EXPECT().
ListMessages(gomock.Any()).
DoAndReturn(func(*pmapi.MessagesFilter) ([]*pmapi.Message, int, error) {
firstSyncWaiter.Done()
return []*pmapi.Message{}, 0, nil
})
var err error
mocks.store, err = New(
mocks.panicHandler,
mocks.user,
mocks.api,
mocks.events,
filepath.Join(mocks.tmpDir, "mailbox-test.db"),
mocks.cache,
)
require.NoError(mocks.tb, err)
// Wait for sync to finish.
firstSyncWaiter.Wait()
}

View File

@ -0,0 +1,145 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"encoding/json"
"fmt"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/stretchr/testify/assert"
bolt "go.etcd.io/bbolt"
)
// TestSync triggers a sync of the store.
func (store *Store) TestSync() {
store.triggerSync()
}
// TestPollNow triggers a loop of the event loop.
func (store *Store) TestPollNow() {
store.eventLoop.pollNow()
}
// TestIsSyncRunning returns whether the sync is currently ongoing.
func (store *Store) TestIsSyncRunning() bool {
return store.isSyncRunning
}
// TestGetEventLoop returns the store's event loop.
func (store *Store) TestGetEventLoop() *eventLoop { //nolint[golint]
return store.eventLoop
}
// TestGetStoreFilePath returns the filepath of the store's database file.
func (store *Store) TestGetStoreFilePath() string {
return store.filePath
}
// TestDumpDB will dump store database content.
func (store *Store) TestDumpDB(tb assert.TestingT) {
dumpCounts := true
fmt.Printf(">>>>>>>> DUMP %s <<<<<\n\n", store.db.Path())
txMails := txDumpMailsFactory(tb)
txDump := func(tx *bolt.Tx) error {
if dumpCounts {
if err := txDumpCounts(tx); err != nil {
return err
}
}
if err := txMails(tx); err != nil {
return err
}
return nil
}
assert.NoError(tb, store.db.View(txDump))
}
func txDumpMailsFactory(tb assert.TestingT) func(tx *bolt.Tx) error {
return func(tx *bolt.Tx) error {
mailboxes := tx.Bucket(mailboxesBucket)
metadata := tx.Bucket(metadataBucket)
err := mailboxes.ForEach(func(mboxName, mboxData []byte) error {
fmt.Println("mbox:", string(mboxName))
b := mailboxes.Bucket(mboxName).Bucket(imapIDsBucket)
c := b.Cursor()
i := 0
for imapID, apiID := c.First(); imapID != nil; imapID, apiID = c.Next() {
i++
fmt.Println(" ", i, "imap", btoi(imapID), "api", string(apiID))
data := metadata.Get(apiID)
if !assert.NotNil(tb, data) {
continue
}
if !assert.NoError(tb, txMailMeta(data, i)) {
continue
}
}
fmt.Println("total:", i)
return nil
})
return err
}
}
func txDumpCounts(tx *bolt.Tx) error {
counts := tx.Bucket(countsBucket)
err := counts.ForEach(func(labelID, countsB []byte) error {
defer fmt.Println()
fmt.Printf("counts id: %q ", string(labelID))
counts := &mailboxCounts{}
if err := json.Unmarshal(countsB, counts); err != nil {
fmt.Printf(" Error %v", err)
return nil
}
fmt.Printf(" total :%d unread %d", counts.TotalOnAPI, counts.UnreadOnAPI)
return nil
})
return err
}
func txMailMeta(data []byte, i int) error {
fullMetaDump := false
msg := &pmapi.Message{}
if err := json.Unmarshal(data, msg); err != nil {
return err
}
if msg.Body != "" {
fmt.Printf(" %d body %s\n\n", i, msg.Body)
panic("NONZERO BODY")
}
if i >= 10 {
return nil
}
if fullMetaDump {
fmt.Printf(" %d meta %s\n\n", i, string(data))
} else {
fmt.Println(
" Subj", msg.Subject,
"\n From", msg.Sender,
"\n Time", msg.Time,
"\n Labels", msg.LabelIDs,
"\n Unread", msg.Unread,
)
}
return nil
}

222
internal/store/sync.go Normal file
View File

@ -0,0 +1,222 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"math"
"sync"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
)
const (
syncMinPagesPerWorker = 10
syncMessagesMaxWorkers = 5
maxFilterPageSize = 150
)
type storeSynchronizer interface {
getAllMessageIDs() ([]string, error)
createOrUpdateMessagesEvent([]*pmapi.Message) error
deleteMessagesEvent([]string) error
saveSyncState(finishTime int64, idRanges []*syncIDRange, idsToBeDeleted []string)
}
type messageLister interface {
ListMessages(*pmapi.MessagesFilter) ([]*pmapi.Message, int, error)
}
func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api messageLister, syncState *syncState) error {
labelID := pmapi.AllMailLabel
// When the full sync starts (i.e. is not already in progress), we need to load
// - all message IDs in database, so we can see which messages we need to remove at the end of the sync
// - ID ranges which indicate how to split work into multiple workers
if !syncState.isIncomplete() {
if err := syncState.loadMessageIDsToBeDeleted(); err != nil {
return errors.Wrap(err, "failed to load message IDs")
}
if err := findIDRanges(labelID, api, syncState); err != nil {
return errors.Wrap(err, "failed to load IDs ranges")
}
syncState.save()
}
wg := &sync.WaitGroup{}
shouldStop := 0 // Using integer to have it atomic.
var resultError error
for _, idRange := range syncState.idRanges {
wg.Add(1)
idRange := idRange // Bind for goroutine.
go func() {
defer panicHandler.HandlePanic()
defer wg.Done()
err := syncBatch(labelID, store, api, syncState, idRange, &shouldStop)
if err != nil {
shouldStop = 1
resultError = errors.Wrap(err, "failed to sync group")
}
}()
}
wg.Wait()
if resultError == nil {
if err := syncState.deleteMessagesToBeDeleted(); err != nil {
return errors.Wrap(err, "failed to delete messages")
}
}
return resultError
}
func findIDRanges(labelID string, api messageLister, syncState *syncState) error {
_, count, err := getSplitIDAndCount(labelID, api, 0)
if err != nil {
return errors.Wrap(err, "failed to get first ID and count")
}
log.WithField("total", count).Debug("Finding ID ranges")
if count == 0 {
return nil
}
syncState.initIDRanges()
pages := int(math.Ceil(float64(count) / float64(maxFilterPageSize)))
workers := (pages / syncMinPagesPerWorker) + 1
if workers > syncMessagesMaxWorkers {
workers = syncMessagesMaxWorkers
}
if workers == 1 {
return nil
}
step := int(math.Round(float64(pages) / float64(workers)))
// Increment steps in case there are more steps than max # of workers (due to rounding).
if (step*syncMessagesMaxWorkers)+1 < pages {
step++
}
for page := step; page < pages; page += step {
splitID, _, err := getSplitIDAndCount(labelID, api, page)
if err != nil {
return errors.Wrap(err, "failed to get IDs range")
}
// Some messages were probably deleted and so the page does not exist anymore.
// Would be good to start this function again, but let's rather start the sync instead of
// wasting time of many calls to API to find where to split workers.
if splitID == "" {
break
}
syncState.addIDRange(splitID)
}
return nil
}
func getSplitIDAndCount(labelID string, api messageLister, page int) (string, int, error) {
sort := "ID"
desc := false
filter := &pmapi.MessagesFilter{
LabelID: labelID,
Sort: sort,
Desc: &desc,
PageSize: maxFilterPageSize,
Page: page,
Limit: 1,
}
// If the page does not exist, an empty page instead of an error is returned.
messages, total, err := api.ListMessages(filter)
if err != nil {
return "", 0, errors.Wrap(err, "failed to list messages")
}
if len(messages) == 0 {
return "", 0, nil
}
return messages[0].ID, total, nil
}
func syncBatch( //nolint[funlen]
labelID string,
store storeSynchronizer,
api messageLister,
syncState *syncState,
idRange *syncIDRange,
shouldStop *int,
) error {
log.WithField("start", idRange.StartID).WithField("stop", idRange.StopID).Info("Starting sync batch")
for {
if *shouldStop == 1 || idRange.isFinished() {
break
}
sort := "ID"
desc := true
filter := &pmapi.MessagesFilter{
LabelID: labelID,
Sort: sort,
Desc: &desc,
PageSize: maxFilterPageSize,
Page: 0,
// Messages with BeginID and EndID are included. We will process
// those messages twice, but that's OK.
// When message is completely removed, it still works as expected.
BeginID: idRange.StartID,
EndID: idRange.StopID,
}
log.WithField("begin", filter.BeginID).WithField("end", filter.EndID).Debug("Fetching page")
messages, _, err := api.ListMessages(filter)
if err != nil {
return errors.Wrap(err, "failed to list messages")
}
if len(messages) == 0 {
break
}
for _, m := range messages {
syncState.doNotDeleteMessageID(m.ID)
}
syncState.save()
if err := store.createOrUpdateMessagesEvent(messages); err != nil {
return errors.Wrap(err, "failed to create or update messages")
}
pageLastMessageID := messages[len(messages)-1].ID
if !desc {
idRange.setStartID(pageLastMessageID)
} else {
idRange.setStopID(pageLastMessageID)
}
if len(messages) < maxFilterPageSize {
break
}
}
return nil
}

View File

@ -0,0 +1,217 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"sync"
"time"
"github.com/pkg/errors"
)
type syncState struct {
lock *sync.RWMutex
store storeSynchronizer
// finishTime is the time, when the sync was finished for the last time.
// When it's zero, it was never finished or the sync is ongoing.
finishTime int64
// idRanges are ID ranges which are used to split work in several workers.
// On the beginning of the sync it will find split IDs which are used to
// create this ranges. If we have 10000 messages and five workers, it will
// find IDs around 2000, 4000, 6000 and 8000 and then first worker will
// sync IDs 0-2000, second 2000-4000 and so on.
idRanges []*syncIDRange
// idsToBeDeletedMap is map with keys as message IDs. On the beginning
// of the sync, it will load all message IDs in database. During the sync,
// it will delete all messages from the map which were sycned. The rest
// at the end of the sync will be removed as those messages were not synced
// again. We do that because we don't want to remove everything on the
// beginning of the sync to keep client synced.
idsToBeDeletedMap map[string]bool
}
func newSyncState(store storeSynchronizer, finishTime int64, idRanges []*syncIDRange, idsToBeDeleted []string) *syncState {
idsToBeDeletedMap := map[string]bool{}
for _, id := range idsToBeDeleted {
idsToBeDeletedMap[id] = true
}
syncState := &syncState{
lock: &sync.RWMutex{},
store: store,
finishTime: finishTime,
idRanges: idRanges,
idsToBeDeletedMap: idsToBeDeletedMap,
}
for _, idRange := range idRanges {
idRange.syncState = syncState
}
return syncState
}
func (s *syncState) save() {
s.lock.Lock()
defer s.lock.Unlock()
s.store.saveSyncState(s.finishTime, s.idRanges, s.getIDsToBeDeleted())
}
// isIncomplete returns whether the sync is in progress (no matter whether
// the sync is running or just not finished by info from database).
func (s *syncState) isIncomplete() bool {
s.lock.Lock()
defer s.lock.Unlock()
return s.finishTime == 0 && len(s.idRanges) != 0
}
// isFinished returns whether the sync was finished.
func (s *syncState) isFinished() bool {
s.lock.Lock()
defer s.lock.Unlock()
return s.finishTime != 0
}
// clearFinishTime sets finish time to zero.
func (s *syncState) clearFinishTime() {
s.lock.Lock()
defer s.save()
defer s.lock.Unlock()
s.finishTime = 0
}
// setFinishTime sets finish time to current time.
func (s *syncState) setFinishTime() {
s.lock.Lock()
defer s.save()
defer s.lock.Unlock()
s.finishTime = time.Now().UnixNano()
}
// initIDRanges inits the main full range. Then each range is added
// by `addIDRange`.
func (s *syncState) initIDRanges() {
s.lock.Lock()
defer s.lock.Unlock()
s.idRanges = []*syncIDRange{{
syncState: s,
StartID: "",
StopID: "",
}}
}
// addIDRange sets `splitID` as stopID for last range and adds new one
// starting with `splitID`.
func (s *syncState) addIDRange(splitID string) {
s.lock.Lock()
defer s.lock.Unlock()
lastGroup := s.idRanges[len(s.idRanges)-1]
lastGroup.StopID = splitID
s.idRanges = append(s.idRanges, &syncIDRange{
syncState: s,
StartID: splitID,
StopID: "",
})
}
// loadMessageIDsToBeDeleted loads all message IDs from database
// and by default all IDs are meant for deletion. During sync for
// each ID `doNotDeleteMessageID` has to be called to remove that
// message from being deleted by `deleteMessagesToBeDeleted`.
func (s *syncState) loadMessageIDsToBeDeleted() error {
idsToBeDeletedMap := make(map[string]bool)
ids, err := s.store.getAllMessageIDs()
if err != nil {
return err
}
for _, id := range ids {
idsToBeDeletedMap[id] = true
}
s.lock.Lock()
defer s.save()
defer s.lock.Unlock()
s.idsToBeDeletedMap = idsToBeDeletedMap
return nil
}
func (s *syncState) doNotDeleteMessageID(id string) {
s.lock.Lock()
defer s.lock.Unlock()
delete(s.idsToBeDeletedMap, id)
}
func (s *syncState) deleteMessagesToBeDeleted() error {
s.lock.Lock()
defer s.lock.Unlock()
idsToBeDeleted := s.getIDsToBeDeleted()
log.Infof("Deleting %v messages after sync", len(idsToBeDeleted))
if err := s.store.deleteMessagesEvent(idsToBeDeleted); err != nil {
return errors.Wrap(err, "failed to delete messages")
}
return nil
}
// getIDsToBeDeleted is helper to convert internal map for easier
// manipulation to array.
func (s *syncState) getIDsToBeDeleted() []string {
keys := []string{}
for key := range s.idsToBeDeletedMap {
keys = append(keys, key)
}
return keys
}
// syncIDRange holds range which IDs need to be synced.
type syncIDRange struct {
syncState *syncState
StartID string
StopID string
}
func (r *syncIDRange) setStartID(startID string) {
r.StartID = startID
r.syncState.save()
}
func (r *syncIDRange) setStopID(stopID string) {
r.StopID = stopID
r.syncState.save()
}
// isFinished returns syncIDRange is finished when StartID and StopID
// are the same. But it cannot be full range, full range cannot be
// determined in other way than asking API.
func (r *syncIDRange) isFinished() bool {
return r.StartID == r.StopID && r.StartID != ""
}

View File

@ -0,0 +1,85 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"sort"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSyncState_IDRanges(t *testing.T) {
store := &mockStoreSynchronizer{}
syncState := newSyncState(store, 0, []*syncIDRange{}, []string{})
syncState.initIDRanges()
syncState.addIDRange("100")
syncState.addIDRange("200")
r := syncState.idRanges
assert.Equal(t, "", r[0].StartID)
assert.Equal(t, "100", r[0].StopID)
assert.Equal(t, "100", r[1].StartID)
assert.Equal(t, "200", r[1].StopID)
assert.Equal(t, "200", r[2].StartID)
assert.Equal(t, "", r[2].StopID)
}
func TestSyncState_IDRangesLoaded(t *testing.T) {
store := &mockStoreSynchronizer{}
syncState := newSyncState(store, 0, []*syncIDRange{
{StartID: "", StopID: "100"},
{StartID: "100", StopID: ""},
}, []string{})
r := syncState.idRanges
assert.Equal(t, "", r[0].StartID)
assert.Equal(t, "100", r[0].StopID)
assert.Equal(t, "100", r[1].StartID)
assert.Equal(t, "", r[1].StopID)
}
func TestSyncState_IDsToBeDeleted(t *testing.T) {
store := &mockStoreSynchronizer{
allMessageIDs: generateIDs(1, 9),
}
syncState := newSyncState(store, 0, []*syncIDRange{}, []string{})
require.Nil(t, syncState.loadMessageIDsToBeDeleted())
syncState.doNotDeleteMessageID("1")
syncState.doNotDeleteMessageID("2")
syncState.doNotDeleteMessageID("3")
syncState.doNotDeleteMessageID("notthere")
idsToBeDeleted := syncState.getIDsToBeDeleted()
sort.Strings(idsToBeDeleted)
assert.Equal(t, generateIDs(4, 9), idsToBeDeleted)
}
func TestSyncState_IDsToBeDeletedLoaded(t *testing.T) {
store := &mockStoreSynchronizer{
allMessageIDs: generateIDs(1, 9),
}
syncState := newSyncState(store, 0, []*syncIDRange{}, generateIDs(4, 9))
idsToBeDeleted := syncState.getIDsToBeDeleted()
sort.Strings(idsToBeDeleted)
assert.Equal(t, generateIDs(4, 9), idsToBeDeleted)
}

509
internal/store/sync_test.go Normal file
View File

@ -0,0 +1,509 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"sort"
"strconv"
"testing"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockLister struct {
err error
messageIDs []string
}
func (m *mockLister) ListMessages(filter *pmapi.MessagesFilter) (msgs []*pmapi.Message, total int, err error) {
if m.err != nil {
return nil, 0, m.err
}
skipByID := true
skipByPaging := filter.PageSize * filter.Page
for idx := 0; idx < len(m.messageIDs); idx++ {
var messageID string
if !*filter.Desc {
messageID = m.messageIDs[idx]
if filter.BeginID == "" || messageID == filter.BeginID {
skipByID = false
}
} else {
messageID = m.messageIDs[len(m.messageIDs)-1-idx]
if filter.EndID == "" || messageID == filter.EndID {
skipByID = false
}
}
if skipByID {
continue
}
skipByPaging--
if skipByPaging > 0 {
continue
}
msgs = append(msgs, &pmapi.Message{
ID: messageID,
})
if len(msgs) == filter.PageSize || len(msgs) == filter.Limit {
break
}
if !*filter.Desc {
if messageID == filter.EndID {
break
}
} else {
if messageID == filter.BeginID {
break
}
}
}
return msgs, len(m.messageIDs), nil
}
type mockStoreSynchronizer struct {
allMessageIDs []string
errCreateOrUpdateMessagesEvent error
createdMessageIDsByBatch [][]string
}
func (m *mockStoreSynchronizer) getAllMessageIDs() ([]string, error) {
return m.allMessageIDs, nil
}
func (m *mockStoreSynchronizer) createOrUpdateMessagesEvent(messages []*pmapi.Message) error {
if m.errCreateOrUpdateMessagesEvent != nil {
return m.errCreateOrUpdateMessagesEvent
}
createdMessageIDs := []string{}
for _, message := range messages {
createdMessageIDs = append(createdMessageIDs, message.ID)
}
m.createdMessageIDsByBatch = append(m.createdMessageIDsByBatch, createdMessageIDs)
return nil
}
func (m *mockStoreSynchronizer) deleteMessagesEvent([]string) error {
return nil
}
func (m *mockStoreSynchronizer) saveSyncState(finishTime int64, idRanges []*syncIDRange, idsToBeDeleted []string) {
}
func newTestSyncState(store storeSynchronizer, splitIDs ...string) *syncState {
syncState := newSyncState(store, 0, []*syncIDRange{}, []string{})
syncState.initIDRanges()
for _, splitID := range splitIDs {
syncState.addIDRange(splitID)
}
return syncState
}
func generateIDs(start, stop int) []string {
ids := []string{}
for x := start; x <= stop; x++ {
ids = append(ids, strconv.Itoa(x))
}
return ids
}
func generateIDsR(start, stop int) []string {
ids := []string{}
for x := start; x >= stop; x-- {
ids = append(ids, strconv.Itoa(x))
}
return ids
}
// Tests
func TestSyncAllMail(t *testing.T) { //nolint[funlen]
m, clear := initMocks(t)
defer clear()
numberOfMessages := 10000
api := &mockLister{
messageIDs: generateIDs(1, numberOfMessages),
}
tests := []struct {
name string
idRanges []*syncIDRange
idsToBeDeleted []string
wantUpdatedIDs []string
wantNotUpdatedIDs []string
}{
{
"full sync",
[]*syncIDRange{},
[]string{},
api.messageIDs,
[]string{},
},
{
"continue with interrupted sync",
[]*syncIDRange{
{StartID: "2000", StopID: "2100"},
{StartID: "4000", StopID: "4200"},
{StartID: "9500", StopID: ""},
},
mergeArrays(generateIDs(2000, 2100), generateIDs(4000, 4200), generateIDs(9500, 10010)),
mergeArrays(generateIDs(2000, 2100), generateIDs(4000, 4200), generateIDs(9500, 10000)),
mergeArrays(generateIDs(1, 1999), generateIDs(2101, 3999), generateIDs(4201, 9459)),
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
store := &mockStoreSynchronizer{
allMessageIDs: generateIDs(1, numberOfMessages+10),
}
syncState := newSyncState(store, 0, tc.idRanges, tc.idsToBeDeleted)
err := syncAllMail(m.panicHandler, store, api, syncState)
require.Nil(t, err)
// Check all messages were created or updated.
updateMessageIDsMap := map[string]bool{}
for _, messageIDs := range store.createdMessageIDsByBatch {
for _, messageID := range messageIDs {
updateMessageIDsMap[messageID] = true
}
}
for _, messageID := range tc.wantUpdatedIDs {
assert.True(t, updateMessageIDsMap[messageID], "Message %s was not created/updated, but should", messageID)
}
for _, messageID := range tc.wantNotUpdatedIDs {
assert.False(t, updateMessageIDsMap[messageID], "Message %s was created/updated, but shouldn't", messageID)
}
// Check all messages were deleted.
idsToBeDeleted := syncState.getIDsToBeDeleted()
sort.Strings(idsToBeDeleted)
assert.Equal(t, generateIDs(numberOfMessages+1, numberOfMessages+10), idsToBeDeleted)
})
}
}
func mergeArrays(arrays ...[]string) []string {
result := []string{}
for _, array := range arrays {
result = append(result, array...)
}
return result
}
func TestSyncAllMail_FailedListing(t *testing.T) {
m, clear := initMocks(t)
defer clear()
numberOfMessages := 10000
store := &mockStoreSynchronizer{
allMessageIDs: generateIDs(1, numberOfMessages+10),
}
api := &mockLister{
err: errors.New("error"),
messageIDs: generateIDs(1, numberOfMessages),
}
syncState := newTestSyncState(store)
err := syncAllMail(m.panicHandler, store, api, syncState)
require.EqualError(t, err, "failed to sync group: failed to list messages: error")
}
func TestSyncAllMail_FailedCreateOrUpdateMessage(t *testing.T) {
m, clear := initMocks(t)
defer clear()
numberOfMessages := 10000
store := &mockStoreSynchronizer{
errCreateOrUpdateMessagesEvent: errors.New("error"),
allMessageIDs: generateIDs(1, numberOfMessages+10),
}
api := &mockLister{
messageIDs: generateIDs(1, numberOfMessages),
}
syncState := newTestSyncState(store)
err := syncAllMail(m.panicHandler, store, api, syncState)
require.EqualError(t, err, "failed to sync group: failed to create or update messages: error")
}
func TestFindIDRanges(t *testing.T) { //nolint[funlen]
store := &mockStoreSynchronizer{}
syncState := newTestSyncState(store)
tests := []struct {
name string
messageIDs []string
wantBatches [][]string
}{
{
"1k messages - 1 batch",
generateIDs(1, 1000),
[][]string{
{"", ""},
},
},
{
"1k messages not starting at 1",
generateIDs(1000, 2000),
[][]string{
{"", ""},
},
},
{
"2k messages - 2 batches",
generateIDs(1, 2000),
[][]string{
{"", "1050"},
{"1050", ""},
},
},
{
"4k messages - 3 batches",
generateIDs(1, 4000),
[][]string{
{"", "1350"},
{"1350", "2700"},
{"2700", ""},
},
},
{
"5k messages - 4 batches",
generateIDs(1, 5000),
[][]string{
{"", "1350"},
{"1350", "2700"},
{"2700", "4050"},
{"4050", ""},
},
},
{
"10k messages - 5 batches",
generateIDs(1, 10000),
[][]string{
{"", "2100"},
{"2100", "4200"},
{"4200", "6300"},
{"6300", "8400"},
{"8400", ""},
},
},
{
"150k messages - 5 batches",
generateIDs(1, 150000),
[][]string{
{"", "30000"},
{"30000", "60000"},
{"60000", "90000"},
{"90000", "120000"},
{"120000", ""},
},
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
api := &mockLister{
messageIDs: tc.messageIDs,
}
err := findIDRanges(pmapi.AllMailLabel, api, syncState)
require.Nil(t, err)
require.Equal(t, len(tc.wantBatches), len(syncState.idRanges))
for idx, idRange := range syncState.idRanges {
want := tc.wantBatches[idx]
assert.Equal(t, want[0], idRange.StartID, "Start ID for IDs range %d does not match", idx)
assert.Equal(t, want[1], idRange.StopID, "Stop ID for IDs range %d does not match", idx)
}
})
}
}
func TestFindIDRanges_FailedListing(t *testing.T) {
store := &mockStoreSynchronizer{}
api := &mockLister{
err: errors.New("error"),
}
syncState := newTestSyncState(store)
err := findIDRanges(pmapi.AllMailLabel, api, syncState)
require.EqualError(t, err, "failed to get first ID and count: failed to list messages: error")
}
func TestGetSplitIDAndCount(t *testing.T) { //nolint[funlen]
tests := []struct {
name string
err error
messageIDs []string
page int
wantID string
wantTotal int
wantErr string
}{
{
"1000 messages, first page",
nil,
generateIDs(1, 1000),
0,
"1",
1000,
"",
},
{
"1000 messages, 5th page",
nil,
generateIDs(1, 1000),
4,
"600",
1000,
"",
},
{
"one message, first page",
nil,
[]string{"1"},
0,
"1",
1,
"",
},
{
"no message, first page",
nil,
[]string{},
0,
"",
0,
"",
},
{
"listing error",
errors.New("error"),
generateIDs(1, 1000),
0,
"",
0,
"failed to list messages: error",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
api := &mockLister{
err: tc.err,
messageIDs: tc.messageIDs,
}
id, total, err := getSplitIDAndCount(pmapi.AllMailLabel, api, tc.page)
if tc.wantErr == "" {
require.Nil(t, err)
} else {
require.EqualError(t, err, tc.wantErr)
}
assert.Equal(t, tc.wantID, id)
assert.Equal(t, tc.wantTotal, total)
})
}
}
func TestSyncBatch(t *testing.T) {
tests := []struct {
name string
batchIdx int
wantCreatedMessageIDsByBatch [][]string
}{
{
"first-batch",
0,
[][]string{generateIDsR(200, 51), generateIDsR(51, 1)},
},
{
"second-batch",
1,
[][]string{generateIDsR(400, 251), generateIDsR(251, 200)},
},
{
"third-batch",
2,
[][]string{generateIDsR(600, 451), generateIDsR(451, 400)},
},
{
"fourth-batch",
3,
[][]string{generateIDsR(800, 651), generateIDsR(651, 600)},
},
{
"fifth-batch",
4,
[][]string{generateIDsR(1000, 851), generateIDsR(851, 800)},
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
store := &mockStoreSynchronizer{}
api := &mockLister{
messageIDs: generateIDs(1, 1000),
}
err := testSyncBatch(t, store, api, tc.batchIdx, "200", "400", "600", "800")
require.Nil(t, err)
require.Equal(t, tc.wantCreatedMessageIDsByBatch, store.createdMessageIDsByBatch)
})
}
}
func TestSyncBatch_FailedListing(t *testing.T) {
store := &mockStoreSynchronizer{}
api := &mockLister{
err: errors.New("error"),
messageIDs: generateIDs(1, 1000),
}
err := testSyncBatch(t, store, api, 0)
require.EqualError(t, err, "failed to list messages: error")
}
func TestSyncBatch_FailedCreateOrUpdateMessage(t *testing.T) {
store := &mockStoreSynchronizer{
errCreateOrUpdateMessagesEvent: errors.New("error"),
}
api := &mockLister{
messageIDs: generateIDs(1, 1000),
}
err := testSyncBatch(t, store, api, 0)
require.EqualError(t, err, "failed to create or update messages: error")
}
func testSyncBatch(t *testing.T, store storeSynchronizer, api messageLister, rangeIdx int, splitIDs ...string) error { //nolint[unparam]
syncState := newTestSyncState(store, splitIDs...)
idRange := syncState.idRanges[rangeIdx]
shouldStop := 0
return syncBatch(pmapi.AllMailLabel, store, api, syncState, idRange, &shouldStop)
}

69
internal/store/types.go Normal file
View File

@ -0,0 +1,69 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"io"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
type PanicHandler interface {
HandlePanic()
}
// PMAPIProvider is subset of pmapi.Client for use by the Store.
type PMAPIProvider interface {
CurrentUser() (*pmapi.User, error)
Addresses() pmapi.AddressList
GetEvent(eventID string) (*pmapi.Event, error)
CountMessages(addressID string) ([]*pmapi.MessagesCount, error)
ListMessages(filter *pmapi.MessagesFilter) ([]*pmapi.Message, int, error)
GetMessage(apiID string) (*pmapi.Message, error)
Import([]*pmapi.ImportMsgReq) ([]*pmapi.ImportMsgRes, error)
DeleteMessages(apiIDs []string) error
LabelMessages(apiIDs []string, labelID string) error
UnlabelMessages(apiIDs []string, labelID string) error
MarkMessagesRead(apiIDs []string) error
MarkMessagesUnread(apiIDs []string) error
CreateDraft(m *pmapi.Message, parent string, action int) (created *pmapi.Message, err error)
CreateAttachment(att *pmapi.Attachment, r io.Reader, sig io.Reader) (created *pmapi.Attachment, err error)
SendMessage(messageID string, req *pmapi.SendMessageReq) (sent, parent *pmapi.Message, err error)
ListLabels() ([]*pmapi.Label, error)
CreateLabel(label *pmapi.Label) (*pmapi.Label, error)
UpdateLabel(label *pmapi.Label) (*pmapi.Label, error)
DeleteLabel(labelID string) error
EmptyFolder(labelID string, addressID string) error
}
// BridgeUser is subset of bridge.User for use by the Store.
type BridgeUser interface {
ID() string
GetAddressID(address string) (string, error)
IsConnected() bool
IsCombinedAddressMode() bool
GetPrimaryAddress() string
GetStoreAddresses() []string
UpdateUser() error
CloseConnection(string)
Logout() error
}

64
internal/store/ulimit.go Normal file
View File

@ -0,0 +1,64 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"fmt"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
)
func uLimit() int {
if runtime.GOOS != "darwin" && runtime.GOOS != "linux" {
return 0
}
out, err := exec.Command("bash", "-c", "ulimit -n").Output()
if err != nil {
log.Print(err)
return 0
}
outStr := strings.Trim(string(out), " \n")
num, err := strconv.Atoi(outStr)
if err != nil {
log.Print(err)
return 0
}
return num
}
func isFdCloseToULimit() bool {
if runtime.GOOS != "darwin" && runtime.GOOS != "linux" {
return false
}
pid := fmt.Sprint(os.Getpid())
out, err := exec.Command("lsof", "-p", pid).Output() //nolint[gosec]
if err != nil {
log.Warn("isFdCloseToULimit: ", err)
return false
}
lines := strings.Split(string(out), "\n")
fd := len(lines) - 1
ulimit := uLimit()
log.Info("File descriptor check: num goroutines ", runtime.NumGoroutine(), " fd ", fd, " ulimit ", ulimit)
return fd >= int(0.95*float64(ulimit))
}

41
internal/store/user.go Normal file
View File

@ -0,0 +1,41 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
// UserID returns user ID.
func (store *Store) UserID() string {
return store.user.ID()
}
// GetSpace returns used and total space in bytes.
func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
apiUser, err := store.api.CurrentUser()
if err != nil {
return 0, 0, err
}
return uint(apiUser.UsedSpace), uint(apiUser.MaxSpace), nil
}
// GetMaxUpload returns max size of attachment in bytes.
func (store *Store) GetMaxUpload() (uint, error) {
apiUser, err := store.api.CurrentUser()
if err != nil {
return 0, err
}
return uint(apiUser.MaxUpload), nil
}

View File

@ -0,0 +1,216 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"encoding/json"
"fmt"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
bolt "go.etcd.io/bbolt"
)
// GetAddress returns the store address by given ID.
func (store *Store) GetAddress(addressID string) (*Address, error) {
store.lock.RLock()
defer store.lock.RUnlock()
storeAddress, ok := store.addresses[addressID]
if !ok {
return nil, fmt.Errorf("addressID %v does not exist", addressID)
}
return storeAddress, nil
}
// RebuildMailboxes truncates all mailbox buckets and recreates them from the metadata bucket again.
func (store *Store) RebuildMailboxes() (err error) {
store.lock.Lock()
defer store.lock.Unlock()
log.WithField("user", store.UserID()).Trace("Truncating mailboxes")
if err = store.truncateMailboxesBucket(); err != nil {
log.WithError(err).Error("Could not truncate mailboxes bucket")
return
}
if err = store.truncateAddressInfoBucket(); err != nil {
log.WithError(err).Error("Could not truncate address info bucket")
return
}
if err = store.init(false); err != nil {
log.WithError(err).Error("Could not init store")
return
}
if err := store.increaseMailboxesVersion(); err != nil {
log.WithError(err).Error("Could not increase structure version")
// Do not return here. The truncation was already done and mode
// was changed in DB so we need to sync so that users start to see
// messages and not block other operations.
}
log.WithField("user", store.UserID()).Trace("Rebuilding mailboxes")
store.triggerSync()
return nil
}
// createOrDeleteAddressesEvent creates address objects in the store for each necessary address
// and deletes any address objects that shouldn't be there.
// It doesn't do anything to addresses that are rightfully there.
// It should only be called from the event loop.
func (store *Store) createOrDeleteAddressesEvent() (err error) {
labels, err := store.initCounts()
if err != nil {
return errors.Wrap(err, "failed to initialise label counts")
}
addrInfo, err := store.GetAddressInfo()
if err != nil {
return errors.Wrap(err, "failed to get addresses and address IDs")
}
// We need at least one address to continue.
if len(addrInfo) < 1 {
return errors.New("no addresses to initialise")
}
// If in combined mode, we only need the user's primary address.
if store.addressMode == combinedMode {
addrInfo = addrInfo[:1]
}
// Go through all addresses that *should* be there.
for _, addr := range addrInfo {
if _, ok := store.addresses[addr.AddressID]; ok {
continue
}
// This address is missing so we create it.
if err = store.addAddress(addr.Address, addr.AddressID, labels); err != nil {
return errors.Wrap(err, "failed to add address to store")
}
}
// Go through all addresses that *should not* be there.
for _, addr := range store.addresses {
belongs := false
for _, a := range addrInfo {
if addr.addressID == a.AddressID {
belongs = true
break
}
}
if belongs {
continue
}
delete(store.addresses, addr.addressID)
}
return err
}
// truncateAddressInfoBucket removes the address info bucket.
func (store *Store) truncateAddressInfoBucket() (err error) {
log.Trace("Truncating address info bucket")
tx := func(tx *bolt.Tx) (err error) {
if err = tx.DeleteBucket(addressInfoBucket); err != nil {
return
}
if _, err = tx.CreateBucketIfNotExists(addressInfoBucket); err != nil {
return
}
return
}
return store.db.Update(tx)
}
// truncateMailboxesBucket removes the mailboxes bucket.
func (store *Store) truncateMailboxesBucket() (err error) {
log.Trace("Truncating mailboxes bucket")
store.addresses = nil
tx := func(tx *bolt.Tx) (err error) {
mbs := tx.Bucket(mailboxesBucket)
return mbs.ForEach(func(addrIDMailbox, _ []byte) (err error) {
addr := mbs.Bucket(addrIDMailbox)
if err = addr.DeleteBucket(imapIDsBucket); err != nil {
return
}
if _, err = addr.CreateBucketIfNotExists(imapIDsBucket); err != nil {
return
}
if err = addr.DeleteBucket(apiIDsBucket); err != nil {
return
}
if _, err = addr.CreateBucketIfNotExists(apiIDsBucket); err != nil {
return
}
return
})
}
return store.db.Update(tx)
}
// initMailboxesBucket recreates the mailboxes bucket from the metadata bucket.
func (store *Store) initMailboxesBucket() error { //nolint[unused]
i := 0
tx := func(tx *bolt.Tx) error {
return tx.Bucket(metadataBucket).ForEach(func(k, v []byte) error {
msg := &pmapi.Message{}
if err := json.Unmarshal(v, msg); err != nil {
return err
}
for _, a := range store.addresses {
if err := a.txCreateOrUpdateMessages(tx, []*pmapi.Message{msg}); err != nil {
return err
}
}
i++
if i%100 == 0 {
store.log.WithField("i", i).
Trace("Init mboxes heartbeat")
}
return nil
})
}
return store.db.Update(tx)
}

View File

@ -0,0 +1,158 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"encoding/json"
"strings"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
bolt "go.etcd.io/bbolt"
)
// AddressInfo is the format of the data held in the addresses bucket in the store.
// It allows us to easily keep an address and its ID together and serialisation/deserialisation to []byte.
type AddressInfo struct {
Address, AddressID string
}
// GetAddressID returns the ID of the given address.
func (store *Store) GetAddressID(addr string) (id string, err error) {
addrs, err := store.GetAddressInfo()
if err != nil {
return
}
for _, addrInfo := range addrs {
if strings.EqualFold(addrInfo.Address, addr) {
id = addrInfo.AddressID
return
}
}
err = errors.New("no such address")
return
}
// GetAddressInfo returns information about all addresses owned by the user.
// The first element is the user's primary address and the rest (if present) are aliases.
// It tries to source the information from the store but if the store doesn't yet have it, it
// fetches it from the API and caches it for later.
func (store *Store) GetAddressInfo() (addrs []AddressInfo, err error) {
if addrs, err = store.getAddressInfoFromStore(); err == nil && len(addrs) > 0 {
return
}
// Store does not have address info yet, need to build it first from API.
addressList := store.api.Addresses()
if addressList == nil {
err = errors.New("addresses unavailable")
store.log.WithError(err).Error("Could not get user addresses from API")
return
}
if err = store.createOrUpdateAddressInfo(addressList); err != nil {
store.log.WithError(err).Warn("Could not update address IDs in store")
return
}
return store.getAddressInfoFromStore()
}
// getAddressIDsByAddressFromStore returns a map from address to addressID for each address owned by the user.
func (store *Store) getAddressInfoFromStore() (addrs []AddressInfo, err error) {
store.log.Debug("Retrieving address info from store")
tx := func(tx *bolt.Tx) (err error) {
c := tx.Bucket(addressInfoBucket).Cursor()
for index, addrInfoBytes := c.First(); index != nil; index, addrInfoBytes = c.Next() {
var addrInfo AddressInfo
if err = json.Unmarshal(addrInfoBytes, &addrInfo); err != nil {
store.log.WithError(err).Error("Could not unmarshal address and addressID")
return
}
addrs = append(addrs, addrInfo)
}
return
}
err = store.db.View(tx)
return
}
// createOrUpdateAddressInfo updates the store address/addressID bucket to match the given address list.
// The address list supplied is assumed to contain active emails in any order.
// It firstly (and stupidly) deletes the bucket of addresses and then fills it with up to date info.
// This is because a user might delete an address and we don't want old addresses lying around (and finding the
// specific ones to delete is likely not much more efficient than just rebuilding from scratch).
func (store *Store) createOrUpdateAddressInfo(addressList pmapi.AddressList) (err error) {
tx := func(tx *bolt.Tx) error {
if err := tx.DeleteBucket(addressInfoBucket); err != nil {
store.log.WithError(err).Error("Could not delete addressIDs bucket")
return err
}
if _, err := tx.CreateBucketIfNotExists(addressInfoBucket); err != nil {
store.log.WithError(err).Error("Could not recreate addressIDs bucket")
return err
}
addrsBucket := tx.Bucket(addressInfoBucket)
for index, address := range filterAddresses(addressList) {
ib := itob(uint32(index))
info, err := json.Marshal(AddressInfo{
Address: address.Email,
AddressID: address.ID,
})
if err != nil {
store.log.WithError(err).Error("Could not marshal address and addressID")
return err
}
if err := addrsBucket.Put(ib, info); err != nil {
store.log.WithError(err).Error("Could not put address and addressID into store")
return err
}
}
return nil
}
return store.db.Update(tx)
}
// filterAddresses filters out inactive addresses and ensures the original address is listed first.
func filterAddresses(addressList pmapi.AddressList) (filteredList pmapi.AddressList) {
for _, address := range addressList {
if address.Receive != pmapi.CanReceive {
continue
}
filteredList = append(filteredList, address)
}
return
}

View File

@ -0,0 +1,229 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"fmt"
"strings"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
)
// createMailbox creates the mailbox via the API.
// The store mailbox is created later by processing an event.
func (store *Store) createMailbox(name string) error {
defer store.eventLoop.pollNow()
log.WithField("name", name).Debug("Creating mailbox")
if store.hasMailbox(name) {
return fmt.Errorf("mailbox %v already exists", name)
}
color := store.leastUsedColor()
var exclusive int
switch {
case strings.HasPrefix(name, UserLabelsPrefix):
name = strings.TrimPrefix(name, UserLabelsPrefix)
exclusive = 0
case strings.HasPrefix(name, UserFoldersPrefix):
name = strings.TrimPrefix(name, UserFoldersPrefix)
exclusive = 1
default:
// Ideally we would throw an error here, but then Outlook for
// macOS keeps trying to make an IMAP Drafts folder and popping
// up the error to the user.
store.log.WithField("name", name).
Warn("Ignoring creation of new mailbox in IMAP root")
return nil
}
_, err := store.api.CreateLabel(&pmapi.Label{
Name: name,
Color: color,
Exclusive: exclusive,
Type: pmapi.LabelTypeMailbox,
})
return err
}
// allAddressesHaveMailbox returns whether each address has a mailbox with the given labelID.
func (store *Store) allAddressesHaveMailbox(labelID string) bool {
store.lock.RLock()
defer store.lock.RUnlock()
for _, a := range store.addresses {
addressHasMailbox := false
for _, m := range a.mailboxes {
if m.labelID == labelID {
addressHasMailbox = true
break
}
}
if !addressHasMailbox {
return false
}
}
return true
}
// hasMailbox returns whether there is at least one address which has a mailbox with the given name.
func (store *Store) hasMailbox(name string) bool {
mailbox, _ := store.getMailbox(name)
return mailbox != nil
}
// getMailbox returns the first mailbox with the given name.
func (store *Store) getMailbox(name string) (*Mailbox, error) {
store.lock.RLock()
defer store.lock.RUnlock()
for _, a := range store.addresses {
for _, m := range a.mailboxes {
if m.labelName == name {
return m, nil
}
}
}
return nil, fmt.Errorf("mailbox %s does not exist", name)
}
// leastUsedColor returns the least used color to be used for a newly created folder or label.
func (store *Store) leastUsedColor() string {
store.lock.RLock()
defer store.lock.RUnlock()
usage := map[string]int{}
for _, a := range store.addresses {
for _, m := range a.mailboxes {
if m.color != "" {
usage[m.color]++
}
}
}
leastUsed := pmapi.LabelColors[0]
for _, color := range pmapi.LabelColors {
if usage[leastUsed] > usage[color] {
leastUsed = color
}
}
return leastUsed
}
// updateMailbox updates the mailbox via the API.
// The store mailbox is updated later by processing an event.
func (store *Store) updateMailbox(labelID, newName, color string) error {
defer store.eventLoop.pollNow()
_, err := store.api.UpdateLabel(&pmapi.Label{
ID: labelID,
Name: newName,
Color: color,
})
return err
}
// deleteMailbox deletes the mailbox via the API.
// The store mailbox is deleted later by processing an event.
func (store *Store) deleteMailbox(labelID, addressID string) error {
defer store.eventLoop.pollNow()
if pmapi.IsSystemLabel(labelID) {
var err error
switch labelID {
case pmapi.SpamLabel:
err = store.api.EmptyFolder(pmapi.SpamLabel, addressID)
case pmapi.TrashLabel:
err = store.api.EmptyFolder(pmapi.TrashLabel, addressID)
default:
err = fmt.Errorf("cannot empty mailbox %v", labelID)
}
return err
}
return store.api.DeleteLabel(labelID)
}
func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error {
newLabelIDs := []string{}
for labelID := range affectedLabelIDs {
if pmapi.IsSystemLabel(labelID) || store.allAddressesHaveMailbox(labelID) {
continue
}
newLabelIDs = append(newLabelIDs, labelID)
}
if len(newLabelIDs) == 0 {
return nil
}
labels, err := store.api.ListLabels()
if err != nil {
return err
}
for _, newLabelID := range newLabelIDs {
for _, label := range labels {
if label.ID != newLabelID {
continue
}
if err := store.createOrUpdateMailboxEvent(label); err != nil {
return err
}
}
}
return nil
}
// createOrUpdateMailboxEvent creates or updates the mailbox in the store.
// This is called from the event loop.
func (store *Store) createOrUpdateMailboxEvent(label *pmapi.Label) error {
store.lock.Lock()
defer store.lock.Unlock()
if label.Type != pmapi.LabelTypeMailbox {
return nil
}
if err := store.createOrUpdateMailboxCountsBuckets([]*pmapi.Label{label}); err != nil {
return errors.Wrap(err, "cannot update counts")
}
for _, a := range store.addresses {
if err := a.createOrUpdateMailboxEvent(label); err != nil {
return err
}
}
return nil
}
// deleteMailboxEvent deletes the mailbox in the store.
// This is called from the event loop.
func (store *Store) deleteMailboxEvent(labelID string) error {
store.lock.Lock()
defer store.lock.Unlock()
_ = store.removeMailboxCount(labelID)
for _, a := range store.addresses {
if err := a.deleteMailboxEvent(labelID); err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,329 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"bytes"
"encoding/json"
"io"
"io/ioutil"
"net/mail"
"net/textproto"
"strings"
pmcrypto "github.com/ProtonMail/gopenpgp/crypto"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
bolt "go.etcd.io/bbolt"
)
// CreateDraft creates draft with attachments.
// If `attachedPublicKey` is passed, it's added to attachments.
// Both draft and attachments are encrypted with passed `kr` key.
func (store *Store) CreateDraft(
kr *pmcrypto.KeyRing,
message *pmapi.Message,
attachmentReaders []io.Reader,
attachedPublicKey,
attachedPublicKeyName string,
parentID string) (*pmapi.Message, []*pmapi.Attachment, error) {
defer store.eventLoop.pollNow()
// Since this is a draft, we don't need to sign it.
if err := message.Encrypt(kr, nil); err != nil {
return nil, nil, errors.Wrap(err, "failed to encrypt draft")
}
attachments := message.Attachments
message.Attachments = nil
draftAction := store.getDraftAction(message)
draft, err := store.api.CreateDraft(message, parentID, draftAction)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to create draft")
}
if attachedPublicKey != "" {
attachmentReaders = append(attachmentReaders, strings.NewReader(attachedPublicKey))
publicKeyAttachment := &pmapi.Attachment{
Name: attachedPublicKeyName + ".asc",
MIMEType: "application/pgp-key",
Header: textproto.MIMEHeader{},
}
attachments = append(attachments, publicKeyAttachment)
}
for idx, attachment := range attachments {
attachment.MessageID = draft.ID
attachmentBody, _ := ioutil.ReadAll(attachmentReaders[idx])
createdAttachment, err := store.createAttachment(kr, attachment, attachmentBody)
if err != nil {
return nil, nil, errors.Wrap(err, "failed to create attachment for draft")
}
attachments[idx] = createdAttachment
}
return draft, attachments, nil
}
func (store *Store) getDraftAction(message *pmapi.Message) int {
// If not a reply, must be a forward.
if len(message.Header["In-Reply-To"]) == 0 {
return pmapi.DraftActionForward
}
return pmapi.DraftActionReply
}
func (store *Store) createAttachment(kr *pmcrypto.KeyRing, attachment *pmapi.Attachment, attachmentBody []byte) (*pmapi.Attachment, error) {
r := bytes.NewReader(attachmentBody)
sigReader, err := attachment.DetachedSign(kr, r)
if err != nil {
return nil, errors.Wrap(err, "failed to sign attachment")
}
r = bytes.NewReader(attachmentBody)
encReader, err := attachment.Encrypt(kr, r)
if err != nil {
return nil, errors.Wrap(err, "failed to encrypt attachment")
}
createdAttachment, err := store.api.CreateAttachment(attachment, encReader, sigReader)
if err != nil {
return nil, errors.Wrap(err, "failed to create attachment")
}
return createdAttachment, nil
}
// SendMessage sends the message.
func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error {
defer store.eventLoop.pollNow()
_, _, err := store.api.SendMessage(messageID, req)
return err
}
// getAllMessageIDs returns all API IDs of messages in the local database.
func (store *Store) getAllMessageIDs() (apiIDs []string, err error) {
err = store.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket(metadataBucket)
return b.ForEach(func(k, v []byte) error {
apiIDs = append(apiIDs, string(k))
return nil
})
})
return
}
// getMessageFromDB returns pmapi struct of message by API ID.
func (store *Store) getMessageFromDB(apiID string) (msg *pmapi.Message, err error) {
err = store.db.View(func(tx *bolt.Tx) error {
msg, err = store.txGetMessage(tx, apiID)
return err
})
return
}
// fetchMessage returns pmapi struct of message by API ID. If the requested
// message is not in the database, it will try to fetch it from the server.
// NOTE: Do not update the database here to prevent issues (extreme edge case).
// The database will be updated by the event loop anyway.
func (store *Store) fetchMessage(apiID string) (msg *pmapi.Message, err error) {
if msg, err = store.api.GetMessage(apiID); err != nil {
if err.Error() == "Message does not exist" {
return nil, ErrNoSuchAPIID
}
}
return
}
func (store *Store) txGetMessage(tx *bolt.Tx, apiID string) (*pmapi.Message, error) {
b := tx.Bucket(metadataBucket)
msgb := b.Get([]byte(apiID))
if msgb == nil {
return nil, ErrNoSuchAPIID
}
msg := &pmapi.Message{}
if err := json.Unmarshal(msgb, msg); err != nil {
return nil, err
}
return msg, nil
}
func (store *Store) txPutMessage(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message) error {
b, err := json.Marshal(onlyMeta)
if err != nil {
return errors.Wrap(err, "cannot marshall metadata")
}
err = metaBucket.Put([]byte(onlyMeta.ID), b)
if err != nil {
return errors.Wrap(err, "cannot add to metadata bucket")
}
return nil
}
// createOrUpdateMessageEvent is helper to create only one message with
// createOrUpdateMessagesEvent.
func (store *Store) createOrUpdateMessageEvent(msg *pmapi.Message) error {
return store.createOrUpdateMessagesEvent([]*pmapi.Message{msg})
}
// createOrUpdateMessagesEvent tries to create or update messages in database.
// This function is optimised for insertion of many messages at once.
// It calls createLabelsIfMissing if needed.
func (store *Store) createOrUpdateMessagesEvent(msgs []*pmapi.Message) error { //nolint[funlen]
store.log.WithField("msgs", msgs).Trace("Creating or updating messages in the store")
// Strip non meta first to reduce memory (no need to keep all old msg ID data during update).
err := store.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket(metadataBucket)
for _, msg := range msgs {
clearNonMetadata(msg)
txUpdateMetadaFromDB(b, msg, store.log)
}
return nil
})
if err != nil {
return err
}
affectedLabels := map[string]bool{}
for _, m := range msgs {
for _, l := range m.LabelIDs {
affectedLabels[l] = true
}
}
if err = store.createLabelsIfMissing(affectedLabels); err != nil {
return err
}
// Updating metadata and mailboxes is not atomic, but this is OK.
// The worst case scenario is we have metadata but not updated mailboxes
// which is OK as without information in mailboxes IMAP we will never ask
// for metadata. Also, when doing the operation again, it will simply
// update the metadata.
// The reason to split is efficiency--it's more memory efficient.
// Update metadata.
err = store.db.Update(func(tx *bolt.Tx) error {
metaBucket := tx.Bucket(metadataBucket)
for _, msg := range msgs {
err := store.txPutMessage(metaBucket, msg)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return err
}
// Update mailboxes.
err = store.db.Update(func(tx *bolt.Tx) error {
for _, a := range store.addresses {
if err := a.txCreateOrUpdateMessages(tx, msgs); err != nil {
store.log.WithError(err).Error("cannot update maiboxes")
return errors.Wrap(err, "cannot add to mailboxes bucket")
}
}
return nil
})
if err != nil {
return err
}
return nil
}
// clearNonMetadata to not allow to store decrypted or encrypted data i.e. body
// and attachments.
func clearNonMetadata(onlyMeta *pmapi.Message) {
onlyMeta.Body = ""
onlyMeta.Attachments = nil
}
// txUpdateMetadaFromDB changes the the onlyMeta data.
// If there is stored message in metaBucket the size, header and MIMEType are
// not changed if already set. To change these:
// * size must be updated by Message.SetSize
// * contentType and header must be updated by Message.SetContentTypeAndHeader
func txUpdateMetadaFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log *logrus.Entry) {
// Size attribute on the server is counting encrypted data. We need to compute
// "real" size of decrypted data. Negative values will be processed during fetch.
onlyMeta.Size = -1
msgb := metaBucket.Get([]byte(onlyMeta.ID))
if msgb == nil {
return
}
// It is faster to unmarshal only the needed items.
stored := &struct {
Size int64
Header string
MIMEType string
}{}
if err := json.Unmarshal(msgb, stored); err != nil {
log.WithError(err).
Error("Fail to unmarshal from DB, metadata will be overwritten")
return
}
// Keep already calculated size and content type.
onlyMeta.Size = stored.Size
onlyMeta.MIMEType = stored.MIMEType
if stored.Header != "" && stored.Header != "(No Header)" {
tmpMsg, err := mail.ReadMessage(
strings.NewReader(stored.Header + "\r\n\r\n"),
)
if err == nil {
onlyMeta.Header = tmpMsg.Header
} else {
log.WithError(err).
Error("Fail to parse, the header will be overwritten")
}
}
}
// deleteMessageEvent is helper to delete only one message with deleteMessagesEvent.
func (store *Store) deleteMessageEvent(apiID string) error {
return store.deleteMessagesEvent([]string{apiID})
}
// deleteMessagesEvent deletes the message from metadata and all mailbox buckets.
func (store *Store) deleteMessagesEvent(apiIDs []string) error {
return store.db.Update(func(tx *bolt.Tx) error {
for _, apiID := range apiIDs {
if err := tx.Bucket(metadataBucket).Delete([]byte(apiID)); err != nil {
return err
}
for _, a := range store.addresses {
if err := a.txDeleteMessage(tx, apiID); err != nil {
return err
}
}
}
return nil
})
}

View File

@ -0,0 +1,156 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"net/mail"
"testing"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
a "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetAllMessageIDs(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{})
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
}
func TestGetMessageFromDB(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
tests := []struct{ msgID, wantErr string }{
{"msg1", ""},
{"msg2", "no such api id"},
}
for _, tc := range tests {
tc := tc
t.Run(tc.msgID, func(t *testing.T) {
msg, err := m.store.getMessageFromDB(tc.msgID)
if tc.wantErr != "" {
require.EqualError(t, err, tc.wantErr)
} else {
require.Nil(t, err)
require.Equal(t, tc.msgID, msg.ID)
}
})
}
}
func TestCreateOrUpdateMessageMetadata(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
msg, err := m.store.getMessageFromDB("msg1")
require.Nil(t, err)
// Check non-meta and calculated data are cleared/empty.
a.Equal(t, "", msg.Body)
a.Equal(t, []*pmapi.Attachment(nil), msg.Attachments)
a.Equal(t, int64(-1), msg.Size)
a.Equal(t, "", msg.MIMEType)
a.Equal(t, mail.Header(nil), msg.Header)
// Change the calculated data.
wantSize := int64(42)
wantMIMEType := "plain-text"
wantHeader := mail.Header{
"Key": []string{"value"},
}
storeMsg, err := m.store.addresses[addrID1].mailboxes[pmapi.AllMailLabel].GetMessage("msg1")
require.Nil(t, err)
require.Nil(t, storeMsg.SetSize(wantSize))
require.Nil(t, storeMsg.SetContentTypeAndHeader(wantMIMEType, wantHeader))
// Check calculated data.
msg, err = m.store.getMessageFromDB("msg1")
require.Nil(t, err)
a.Equal(t, wantSize, msg.Size)
a.Equal(t, wantMIMEType, msg.MIMEType)
a.Equal(t, wantHeader, msg.Header)
// Check calculated data are not overridden by reinsert.
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
msg, err = m.store.getMessageFromDB("msg1")
require.Nil(t, err)
a.Equal(t, wantSize, msg.Size)
a.Equal(t, wantMIMEType, msg.MIMEType)
a.Equal(t, wantHeader, msg.Header)
}
func TestDeleteMessage(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
require.Nil(t, m.store.deleteMessageEvent("msg1"))
checkAllMessageIDs(t, m, []string{"msg2"})
checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg2", 2}})
}
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread int, labelIDs []string) { //nolint[unparam]
msg := getTestMessage(id, subject, sender, unread, labelIDs)
require.Nil(t, m.store.createOrUpdateMessageEvent(msg))
}
func getTestMessage(id, subject, sender string, unread int, labelIDs []string) *pmapi.Message {
address := &mail.Address{Address: sender}
return &pmapi.Message{
ID: id,
Subject: subject,
Unread: unread,
Sender: address,
ToList: []*mail.Address{address},
LabelIDs: labelIDs,
Size: 12345,
Body: "body of message",
Attachments: []*pmapi.Attachment{{
ID: "attachment1",
MessageID: id,
Name: "attachment",
}},
}
}
func checkAllMessageIDs(t *testing.T, m *mocksForStore, wantIDs []string) {
allIds, allErr := m.store.getAllMessageIDs()
require.Nil(t, allErr)
require.Equal(t, wantIDs, allIds)
}

247
internal/store/user_sync.go Normal file
View File

@ -0,0 +1,247 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"encoding/json"
"fmt"
"strconv"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
bolt "go.etcd.io/bbolt"
)
const syncFinishTimeKey = "sync_state" // The original key was sync_state and we want to keep compatibility.
const syncIDRangesKey = "id_ranges"
const syncIDsToBeDeletedKey = "ids_to_be_deleted"
// updateCountsFromServer will download and set the counts.
func (store *Store) updateCountsFromServer() error {
counts, err := store.api.CountMessages("")
if err != nil {
return errors.Wrap(err, "cannot update counts from server")
}
return store.createOrUpdateOnAPICounts(counts)
}
// isSynced checks whether DB counts are synced with provided counts from API.
func (store *Store) isSynced(countsOnAPI []*pmapi.MessagesCount) (bool, error) {
store.log.WithField("apiCounts", countsOnAPI).Debug("Checking whether store is synced")
// IMPORTANT: The countsOnAPI can contain duplicates due to event merge
// (ie one label can be present multiple times). It is important to
// process all counts before checking whether they are synced.
if err := store.createOrUpdateOnAPICounts(countsOnAPI); err != nil {
store.log.WithError(err).Error("Cannot update counts before check sync")
return false, err
}
allCounts, err := store.getOnAPICounts()
if err != nil {
return false, err
}
store.lock.Lock()
defer store.lock.Unlock()
countsAreOK := true
for _, counts := range allCounts {
total, unread := uint(0), uint(0)
for _, address := range store.addresses {
mbox, err := address.getMailboxByID(counts.LabelID)
if err != nil {
return false, errors.Wrapf(
err,
"cannot find mailbox for address %q",
address.addressID,
)
}
mboxTot, mboxUnread, err := mbox.GetCounts()
if err != nil {
errW := errors.Wrap(err, "cannot count messages")
store.log.
WithError(errW).
WithField("label", counts.LabelID).
WithField("address", address.addressID).
Error("IsSynced failed")
return false, err
}
total += mboxTot
unread += mboxUnread
}
if total != counts.TotalOnAPI || unread != counts.UnreadOnAPI {
store.log.WithFields(logrus.Fields{
"label": counts.LabelID,
"db-total": total,
"db-unread": unread,
"api-total": counts.TotalOnAPI,
"api-unread": counts.UnreadOnAPI,
}).Warning("counts differ")
countsAreOK = false
}
}
return countsAreOK, nil
}
// triggerSync starts a sync of complete user by syncing All Mail mailbox.
// All Mail mailbox contains all messages, so we download all meta data needed
// to generate any address/mailbox IMAP UIDs.
// Sync state can be in three states:
// * Nothing in database. For example when user logs in for the first time.
// `triggerSync` will start full sync.
// * Database has syncIDRangesKey and syncIDsToBeDeletedKey keys with data.
// Sync is in progress or was interrupted. In later case when, `triggerSync`
// will continue where it left off.
// * Database has only syncStateKey with time when database was last synced.
// `triggerSync` will reset it and start full sync again.
func (store *Store) triggerSync() {
syncState := store.loadSyncState()
// We first clear the last sync state in case this sync fails.
syncState.clearFinishTime()
// We don't want sync to block.
go func() {
defer store.panicHandler.HandlePanic()
store.log.Debug("Store sync triggered")
store.lock.Lock()
if store.isSyncRunning {
store.lock.Unlock()
store.log.Info("Store sync is already ongoing")
return
}
store.isSyncRunning = true
store.lock.Unlock()
defer func() {
store.lock.Lock()
store.isSyncRunning = false
store.lock.Unlock()
}()
store.log.WithField("isIncomplete", syncState.isIncomplete()).Info("Store sync started")
err := syncAllMail(store.panicHandler, store, store.api, syncState)
if err != nil {
log.WithError(err).Error("Store sync failed")
return
}
syncState.setFinishTime()
}()
}
// isSyncFinished returns whether the database has finished a sync.
func (store *Store) isSyncFinished() (isSynced bool) {
return store.loadSyncState().isFinished()
}
// loadSyncState loads information about sync from database.
// See `triggerSync` to learn more about possible states.
func (store *Store) loadSyncState() *syncState {
finishTime := int64(0)
idRanges := []*syncIDRange{}
idsToBeDeleted := []string{}
err := store.db.View(func(tx *bolt.Tx) (err error) {
b := tx.Bucket(syncStateBucket)
finishTimeByte := b.Get([]byte(syncFinishTimeKey))
if finishTimeByte != nil {
finishTime, err = strconv.ParseInt(string(finishTimeByte), 10, 64)
if err != nil {
store.log.WithError(err).Error("Failed to unmarshal sync IDs ranges")
}
}
idRangesData := b.Get([]byte(syncIDRangesKey))
if idRangesData != nil {
if err := json.Unmarshal(idRangesData, &idRanges); err != nil {
store.log.WithError(err).Error("Failed to unmarshal sync IDs ranges")
}
}
idsToBeDeletedData := b.Get([]byte(syncIDsToBeDeletedKey))
if idsToBeDeletedData != nil {
if err := json.Unmarshal(idsToBeDeletedData, &idsToBeDeleted); err != nil {
store.log.WithError(err).Error("Failed to unmarshal sync IDs to be deleted")
}
}
return
})
if err != nil {
store.log.WithError(err).Error("Failed to load sync state")
}
return newSyncState(store, finishTime, idRanges, idsToBeDeleted)
}
// saveSyncState saves information about sync to database.
// See `triggerSync` to learn more about possible states.
func (store *Store) saveSyncState(finishTime int64, idRanges []*syncIDRange, idsToBeDeleted []string) {
idRangesData, err := json.Marshal(idRanges)
if err != nil {
store.log.WithError(err).Error("Failed to marshall sync IDs ranges")
}
idsToBeDeletedData, err := json.Marshal(idsToBeDeleted)
if err != nil {
store.log.WithError(err).Error("Failed to marshall sync IDs to be deleted")
}
err = store.db.Update(func(tx *bolt.Tx) (err error) {
b := tx.Bucket(syncStateBucket)
if finishTime != 0 {
curTime := []byte(fmt.Sprintf("%v", finishTime))
if err := b.Put([]byte(syncFinishTimeKey), curTime); err != nil {
return err
}
if err := b.Delete([]byte(syncIDRangesKey)); err != nil {
return err
}
if err := b.Delete([]byte(syncIDsToBeDeletedKey)); err != nil {
return err
}
} else {
if err := b.Delete([]byte(syncFinishTimeKey)); err != nil {
return err
}
if err := b.Put([]byte(syncIDRangesKey), idRangesData); err != nil {
return err
}
if err := b.Put([]byte(syncIDsToBeDeletedKey), idsToBeDeletedData); err != nil {
return err
}
}
return nil
})
if err != nil {
store.log.WithError(err).Error("Failed to set sync state")
}
}

View File

@ -0,0 +1,90 @@
// Copyright (c) 2020 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"sort"
"testing"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLoadSaveSyncState(t *testing.T) {
m, clear := initMocks(t)
defer clear()
m.newStoreNoEvents(true)
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
// Clear everything.
syncState := m.store.loadSyncState()
syncState.clearFinishTime()
// Check everything is empty at the beginning.
syncState = m.store.loadSyncState()
checkSyncStateAfterLoad(t, syncState, false, false, []string{})
// Save IDs ranges and check everything is also properly loaded.
syncState.initIDRanges()
syncState.addIDRange("100")
syncState.addIDRange("200")
syncState.save()
syncState = m.store.loadSyncState()
checkSyncStateAfterLoad(t, syncState, false, true, []string{})
// Save IDs to be deleted and check everything is properly loaded.
require.Nil(t, syncState.loadMessageIDsToBeDeleted())
syncState = m.store.loadSyncState()
checkSyncStateAfterLoad(t, syncState, false, true, []string{"msg1", "msg2"})
// Set finish time and check everything is resetted to empty values.
syncState.setFinishTime()
syncState = m.store.loadSyncState()
checkSyncStateAfterLoad(t, syncState, true, false, []string{})
}
func checkSyncStateAfterLoad(t *testing.T, syncState *syncState, wantIsFinished bool, wantIDRanges bool, wantIDsToBeDeleted []string) {
assert.Equal(t, wantIsFinished, syncState.isFinished())
if wantIDRanges {
require.Equal(t, 3, len(syncState.idRanges))
assert.Equal(t, "", syncState.idRanges[0].StartID)
assert.Equal(t, "100", syncState.idRanges[0].StopID)
assert.Equal(t, "100", syncState.idRanges[1].StartID)
assert.Equal(t, "200", syncState.idRanges[1].StopID)
assert.Equal(t, "200", syncState.idRanges[2].StartID)
assert.Equal(t, "", syncState.idRanges[2].StopID)
} else {
assert.Empty(t, syncState.idRanges)
}
idsToBeDeleted := syncState.getIDsToBeDeleted()
sort.Strings(idsToBeDeleted)
assert.Equal(t, wantIDsToBeDeleted, idsToBeDeleted)
}