mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-15 14:56:42 +00:00
We build too many walls and not enough bridges
This commit is contained in:
109
internal/store/address.go
Normal file
109
internal/store/address.go
Normal file
@ -0,0 +1,109 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Address holds mailboxes for IMAP user (login address). In combined mode
|
||||
// there is only one address, in split mode there is one object per address.
|
||||
type Address struct {
|
||||
store *Store
|
||||
address string
|
||||
addressID string
|
||||
mailboxes map[string]*Mailbox
|
||||
|
||||
log *logrus.Entry
|
||||
}
|
||||
|
||||
func newAddress(
|
||||
store *Store,
|
||||
address, addressID string,
|
||||
labels []*pmapi.Label,
|
||||
) (addr *Address, err error) {
|
||||
l := log.WithField("addressID", addressID)
|
||||
|
||||
storeAddress := &Address{
|
||||
store: store,
|
||||
address: address,
|
||||
addressID: addressID,
|
||||
log: l,
|
||||
}
|
||||
|
||||
if err = storeAddress.init(labels); err != nil {
|
||||
l.WithField("address", address).
|
||||
WithError(err).
|
||||
Error("Could not initialise store address")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
return storeAddress, nil
|
||||
}
|
||||
|
||||
func (storeAddress *Address) init(foldersAndLabels []*pmapi.Label) (err error) {
|
||||
storeAddress.log.WithField("address", storeAddress.address).Debug("Initialising store address")
|
||||
|
||||
storeAddress.mailboxes = make(map[string]*Mailbox)
|
||||
|
||||
for _, label := range foldersAndLabels {
|
||||
prefix := getLabelPrefix(label)
|
||||
|
||||
var mailbox *Mailbox
|
||||
if mailbox, err = newMailbox(storeAddress, label.ID, prefix, label.Name, label.Color); err != nil {
|
||||
storeAddress.log.
|
||||
WithError(err).
|
||||
WithField("labelID", label.ID).
|
||||
Error("Could not init mailbox for folder or label")
|
||||
return
|
||||
}
|
||||
|
||||
storeAddress.mailboxes[label.ID] = mailbox
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// getLabelPrefix returns the correct prefix for a pmapi label according to whether it is exclusive or not.
|
||||
func getLabelPrefix(l *pmapi.Label) string {
|
||||
switch {
|
||||
case pmapi.IsSystemLabel(l.ID):
|
||||
return ""
|
||||
case l.Exclusive == 1:
|
||||
return UserFoldersPrefix
|
||||
default:
|
||||
return UserLabelsPrefix
|
||||
}
|
||||
}
|
||||
|
||||
// AddressString returns the address.
|
||||
func (storeAddress *Address) AddressString() string {
|
||||
return storeAddress.address
|
||||
}
|
||||
|
||||
// AddressID returns the address ID.
|
||||
func (storeAddress *Address) AddressID() string {
|
||||
return storeAddress.addressID
|
||||
}
|
||||
|
||||
// APIAddress returns the `pmapi.Address` struct.
|
||||
func (storeAddress *Address) APIAddress() *pmapi.Address {
|
||||
return storeAddress.store.api.Addresses().ByEmail(storeAddress.address)
|
||||
}
|
||||
106
internal/store/address_mailbox.go
Normal file
106
internal/store/address_mailbox.go
Normal file
@ -0,0 +1,106 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
)
|
||||
|
||||
// ListMailboxes returns all mailboxes.
|
||||
func (storeAddress *Address) ListMailboxes() []*Mailbox {
|
||||
storeAddress.store.lock.RLock()
|
||||
defer storeAddress.store.lock.RUnlock()
|
||||
|
||||
mailboxes := make([]*Mailbox, 0, len(storeAddress.mailboxes))
|
||||
for _, m := range storeAddress.mailboxes {
|
||||
mailboxes = append(mailboxes, m)
|
||||
}
|
||||
return mailboxes
|
||||
}
|
||||
|
||||
// GetMailbox returns mailbox with the given IMAP name.
|
||||
func (storeAddress *Address) GetMailbox(name string) (*Mailbox, error) {
|
||||
storeAddress.store.lock.RLock()
|
||||
defer storeAddress.store.lock.RUnlock()
|
||||
|
||||
for _, m := range storeAddress.mailboxes {
|
||||
if m.Name() == name {
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("mailbox %v does not exist", name)
|
||||
}
|
||||
|
||||
// CreateMailbox creates the mailbox by calling an API.
|
||||
// Mailbox is created in the structure by processing event.
|
||||
func (storeAddress *Address) CreateMailbox(name string) error {
|
||||
return storeAddress.store.createMailbox(name)
|
||||
}
|
||||
|
||||
// updateMailbox updates the mailbox by calling an API.
|
||||
// Mailbox is updated in the structure by processing event.
|
||||
func (storeAddress *Address) updateMailbox(labelID, newName, color string) error {
|
||||
return storeAddress.store.updateMailbox(labelID, newName, color)
|
||||
}
|
||||
|
||||
// deleteMailbox deletes the mailbox by calling an API.
|
||||
// Mailbox is deleted in the structure by processing event.
|
||||
func (storeAddress *Address) deleteMailbox(labelID string) error {
|
||||
return storeAddress.store.deleteMailbox(labelID, storeAddress.addressID)
|
||||
}
|
||||
|
||||
// createOrUpdateMailboxEvent creates or updates the mailbox in the structure.
|
||||
// This is called from the event loop.
|
||||
func (storeAddress *Address) createOrUpdateMailboxEvent(label *pmapi.Label) error {
|
||||
prefix := getLabelPrefix(label)
|
||||
mailbox, ok := storeAddress.mailboxes[label.ID]
|
||||
if !ok {
|
||||
mailbox, err := newMailbox(storeAddress, label.ID, prefix, label.Name, label.Color)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
storeAddress.mailboxes[label.ID] = mailbox
|
||||
} else {
|
||||
mailbox.labelName = prefix + label.Name
|
||||
mailbox.color = label.Color
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteMailboxEvent deletes the mailbox in the structure.
|
||||
// This is called from the event loop.
|
||||
func (storeAddress *Address) deleteMailboxEvent(labelID string) error {
|
||||
storeMailbox, ok := storeAddress.mailboxes[labelID]
|
||||
if !ok {
|
||||
log.WithField("labelID", labelID).Warn("Could not find mailbox to delete")
|
||||
return nil
|
||||
}
|
||||
delete(storeAddress.mailboxes, labelID)
|
||||
return storeMailbox.deleteMailboxEvent()
|
||||
}
|
||||
|
||||
func (storeAddress *Address) getMailboxByID(labelID string) (*Mailbox, error) {
|
||||
storeMailbox, ok := storeAddress.mailboxes[labelID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("mailbox with id %q does not exist", labelID)
|
||||
}
|
||||
return storeMailbox, nil
|
||||
}
|
||||
42
internal/store/address_message.go
Normal file
42
internal/store/address_message.go
Normal file
@ -0,0 +1,42 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
func (storeAddress *Address) txCreateOrUpdateMessages(tx *bolt.Tx, msgs []*pmapi.Message) error {
|
||||
for _, m := range storeAddress.mailboxes {
|
||||
if err := m.txCreateOrUpdateMessages(tx, msgs); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// txDeleteMessage deletes the message from the mailbox buckets for this address.
|
||||
func (storeAddress *Address) txDeleteMessage(tx *bolt.Tx, apiID string) error {
|
||||
for _, m := range storeAddress.mailboxes {
|
||||
if err := m.txDeleteMessage(tx, apiID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
114
internal/store/cache.go
Normal file
114
internal/store/cache.go
Normal file
@ -0,0 +1,114 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Cache caches the last event IDs for all accounts (there should be only one instance).
|
||||
type Cache struct {
|
||||
// cache is map from userID => key (such as last event) => value (such as event ID).
|
||||
cache map[string]map[string]string
|
||||
path string
|
||||
lock *sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCache constructs a new cache at the given path.
|
||||
func NewCache(path string) *Cache {
|
||||
return &Cache{
|
||||
path: path,
|
||||
lock: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache) getEventID(userID string) string {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
_ = c.loadCache()
|
||||
|
||||
if c.cache == nil {
|
||||
c.cache = map[string]map[string]string{}
|
||||
}
|
||||
if c.cache[userID] == nil {
|
||||
c.cache[userID] = map[string]string{}
|
||||
}
|
||||
|
||||
return c.cache[userID]["events"]
|
||||
}
|
||||
|
||||
func (c *Cache) setEventID(userID, eventID string) error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if c.cache[userID] == nil {
|
||||
c.cache[userID] = map[string]string{}
|
||||
}
|
||||
c.cache[userID]["events"] = eventID
|
||||
|
||||
return c.saveCache()
|
||||
}
|
||||
|
||||
func (c *Cache) loadCache() error {
|
||||
if c.cache != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
f, err := os.Open(c.path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close() //nolint[errcheck]
|
||||
|
||||
return json.NewDecoder(f).Decode(&c.cache)
|
||||
}
|
||||
|
||||
func (c *Cache) saveCache() error {
|
||||
if c.cache == nil {
|
||||
return errors.New("events: cannot save cache: cache is nil")
|
||||
}
|
||||
|
||||
f, err := os.Create(c.path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close() //nolint[errcheck]
|
||||
|
||||
return json.NewEncoder(f).Encode(c.cache)
|
||||
}
|
||||
|
||||
func (c *Cache) clearCacheUser(userID string) error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if c.cache == nil {
|
||||
log.WithField("user", userID).Warning("Cannot clear user from cache: cache is nil")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithField("user", userID).Trace("Removing user from event loop cache")
|
||||
|
||||
delete(c.cache, userID)
|
||||
|
||||
return c.saveCache()
|
||||
}
|
||||
109
internal/store/change.go
Normal file
109
internal/store/change.go
Normal file
@ -0,0 +1,109 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/message"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
imap "github.com/emersion/go-imap"
|
||||
imapBackend "github.com/emersion/go-imap/backend"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// SetIMAPUpdateChannel sets the channel on which imap update messages will be sent. This should be the channel
|
||||
// on which the imap backend listens for imap updates.
|
||||
func (store *Store) SetIMAPUpdateChannel(updates chan interface{}) {
|
||||
store.log.Debug("Listening for IMAP updates")
|
||||
|
||||
if store.imapUpdates = updates; store.imapUpdates == nil {
|
||||
store.log.Error("The IMAP Updates channel is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func (store *Store) imapNotice(address, notice string) {
|
||||
update := new(imapBackend.StatusUpdate)
|
||||
update.Username = address
|
||||
update.StatusResp = &imap.StatusResp{
|
||||
Type: imap.StatusOk,
|
||||
Code: imap.CodeAlert,
|
||||
Info: notice,
|
||||
}
|
||||
store.imapSendUpdate(update)
|
||||
}
|
||||
|
||||
func (store *Store) imapUpdateMessage(address, mailboxName string, uid, sequenceNumber uint32, msg *pmapi.Message) {
|
||||
store.log.WithFields(logrus.Fields{
|
||||
"address": address,
|
||||
"mailbox": mailboxName,
|
||||
"seqNum": sequenceNumber,
|
||||
"uid": uid,
|
||||
"flags": message.GetFlags(msg),
|
||||
}).Trace("IDLE update")
|
||||
update := new(imapBackend.MessageUpdate)
|
||||
update.Username = address
|
||||
update.Mailbox = mailboxName
|
||||
update.Message = imap.NewMessage(sequenceNumber, []string{imap.FlagsMsgAttr, imap.UidMsgAttr})
|
||||
update.Message.Flags = message.GetFlags(msg)
|
||||
update.Message.Uid = uid
|
||||
store.imapSendUpdate(update)
|
||||
}
|
||||
|
||||
func (store *Store) imapDeleteMessage(address, mailboxName string, sequenceNumber uint32) {
|
||||
store.log.WithFields(logrus.Fields{
|
||||
"address": address,
|
||||
"mailbox": mailboxName,
|
||||
"seqNum": sequenceNumber,
|
||||
}).Trace("IDLE delete")
|
||||
update := new(imapBackend.ExpungeUpdate)
|
||||
update.Username = address
|
||||
update.Mailbox = mailboxName
|
||||
update.SeqNum = sequenceNumber
|
||||
store.imapSendUpdate(update)
|
||||
}
|
||||
|
||||
func (store *Store) imapMailboxStatus(address, mailboxName string, total, unread uint) {
|
||||
store.log.WithFields(logrus.Fields{
|
||||
"address": address,
|
||||
"mailbox": mailboxName,
|
||||
"total": total,
|
||||
"unread": unread,
|
||||
}).Trace("IDLE status")
|
||||
update := new(imapBackend.MailboxUpdate)
|
||||
update.Username = address
|
||||
update.Mailbox = mailboxName
|
||||
update.MailboxStatus = imap.NewMailboxStatus(mailboxName, []string{imap.MailboxMessages, imap.MailboxUnseen})
|
||||
update.MailboxStatus.Messages = uint32(total)
|
||||
update.MailboxStatus.Unseen = uint32(unread)
|
||||
store.imapSendUpdate(update)
|
||||
}
|
||||
|
||||
func (store *Store) imapSendUpdate(update interface{}) {
|
||||
if store.imapUpdates == nil {
|
||||
store.log.Trace("IMAP IDLE unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(1 * time.Second):
|
||||
store.log.Error("Could not send IMAP update (timeout)")
|
||||
return
|
||||
case store.imapUpdates <- update:
|
||||
}
|
||||
}
|
||||
129
internal/store/change_test.go
Normal file
129
internal/store/change_test.go
Normal file
@ -0,0 +1,129 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
imapBackend "github.com/emersion/go-imap/backend"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateOrUpdateMessageIMAPUpdates(t *testing.T) {
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
|
||||
updates := make(chan interface{})
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
m.store.SetIMAPUpdateChannel(updates)
|
||||
|
||||
go checkIMAPUpdates(t, updates, []func(interface{}) bool{
|
||||
checkMessageUpdate(addr1, "All Mail", 1, 1),
|
||||
checkMessageUpdate(addr1, "All Mail", 2, 2),
|
||||
})
|
||||
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
|
||||
close(updates)
|
||||
}
|
||||
|
||||
func TestCreateOrUpdateMessageIMAPUpdatesBulkUpdate(t *testing.T) {
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
|
||||
updates := make(chan interface{})
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
m.store.SetIMAPUpdateChannel(updates)
|
||||
|
||||
go checkIMAPUpdates(t, updates, []func(interface{}) bool{
|
||||
checkMessageUpdate(addr1, "All Mail", 1, 1),
|
||||
checkMessageUpdate(addr1, "All Mail", 2, 2),
|
||||
})
|
||||
|
||||
msg1 := getTestMessage("msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
msg2 := getTestMessage("msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
require.Nil(t, m.store.createOrUpdateMessagesEvent([]*pmapi.Message{msg1, msg2}))
|
||||
|
||||
close(updates)
|
||||
}
|
||||
|
||||
func TestDeleteMessageIMAPUpdate(t *testing.T) {
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
|
||||
updates := make(chan interface{})
|
||||
m.store.SetIMAPUpdateChannel(updates)
|
||||
go checkIMAPUpdates(t, updates, []func(interface{}) bool{
|
||||
checkMessageDelete(addr1, "All Mail", 2),
|
||||
checkMessageDelete(addr1, "All Mail", 1),
|
||||
})
|
||||
|
||||
require.Nil(t, m.store.deleteMessageEvent("msg2"))
|
||||
require.Nil(t, m.store.deleteMessageEvent("msg1"))
|
||||
close(updates)
|
||||
}
|
||||
|
||||
func checkIMAPUpdates(t *testing.T, updates chan interface{}, checkFunctions []func(interface{}) bool) {
|
||||
idx := 0
|
||||
for update := range updates {
|
||||
if idx >= len(checkFunctions) {
|
||||
continue
|
||||
}
|
||||
if !checkFunctions[idx](update) {
|
||||
continue
|
||||
}
|
||||
idx++
|
||||
}
|
||||
require.True(t, idx == len(checkFunctions), "Less updates than expected: %+v of %+v", idx, len(checkFunctions))
|
||||
}
|
||||
|
||||
func checkMessageUpdate(username, mailbox string, seqNum, uid int) func(interface{}) bool { //nolint[unparam]
|
||||
return func(update interface{}) bool {
|
||||
switch u := update.(type) {
|
||||
case *imapBackend.MessageUpdate:
|
||||
return (u.Update.Username == username &&
|
||||
u.Update.Mailbox == mailbox &&
|
||||
u.Message.SeqNum == uint32(seqNum) &&
|
||||
u.Message.Uid == uint32(uid))
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkMessageDelete(username, mailbox string, seqNum int) func(interface{}) bool { //nolint[unparam]
|
||||
return func(update interface{}) bool {
|
||||
switch u := update.(type) {
|
||||
case *imapBackend.ExpungeUpdate:
|
||||
return (u.Update.Username == username &&
|
||||
u.Update.Mailbox == mailbox &&
|
||||
u.SeqNum == uint32(seqNum))
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
32
internal/store/convert.go
Normal file
32
internal/store/convert.go
Normal file
@ -0,0 +1,32 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// itob returns a 4-byte big endian representation of v.
|
||||
func itob(v uint32) []byte {
|
||||
b := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(b, v)
|
||||
return b
|
||||
}
|
||||
|
||||
// btoi returns the uint32 represented by b.
|
||||
func btoi(b []byte) uint32 {
|
||||
return binary.BigEndian.Uint32(b)
|
||||
}
|
||||
546
internal/store/event_loop.go
Normal file
546
internal/store/event_loop.go
Normal file
@ -0,0 +1,546 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
bridgeEvents "github.com/ProtonMail/proton-bridge/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const pollInterval = 30 * time.Second
|
||||
|
||||
type eventLoop struct {
|
||||
cache *Cache
|
||||
currentEventID string
|
||||
pollCh chan chan struct{}
|
||||
stopCh chan struct{}
|
||||
notifyStopCh chan struct{}
|
||||
isRunning bool
|
||||
hasInternet bool
|
||||
|
||||
log *logrus.Entry
|
||||
|
||||
store *Store
|
||||
apiClient PMAPIProvider
|
||||
user BridgeUser
|
||||
events listener.Listener
|
||||
}
|
||||
|
||||
func newEventLoop(cache *Cache, store *Store, api PMAPIProvider, user BridgeUser, events listener.Listener) *eventLoop {
|
||||
eventLog := log.WithField("userID", user.ID())
|
||||
eventLog.Trace("Creating new event loop")
|
||||
|
||||
return &eventLoop{
|
||||
cache: cache,
|
||||
currentEventID: cache.getEventID(user.ID()),
|
||||
pollCh: make(chan chan struct{}),
|
||||
isRunning: false,
|
||||
|
||||
log: eventLog,
|
||||
|
||||
store: store,
|
||||
apiClient: api,
|
||||
user: user,
|
||||
events: events,
|
||||
}
|
||||
}
|
||||
|
||||
func (loop *eventLoop) IsRunning() bool {
|
||||
return loop.isRunning
|
||||
}
|
||||
|
||||
func (loop *eventLoop) setFirstEventID() (err error) {
|
||||
loop.log.Trace("Setting first event ID")
|
||||
|
||||
event, err := loop.apiClient.GetEvent("")
|
||||
if err != nil {
|
||||
loop.log.WithError(err).Error("Could not get latest event ID")
|
||||
return
|
||||
}
|
||||
|
||||
loop.currentEventID = event.EventID
|
||||
|
||||
if err = loop.cache.setEventID(loop.user.ID(), loop.currentEventID); err != nil {
|
||||
loop.log.WithError(err).Error("Could not set latest event ID in user cache")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// pollNow starts polling events right away and waits till the events are
|
||||
// processed so we are sure updates are propagated to the database.
|
||||
func (loop *eventLoop) pollNow() {
|
||||
eventProcessedCh := make(chan struct{})
|
||||
loop.pollCh <- eventProcessedCh
|
||||
<-eventProcessedCh
|
||||
close(eventProcessedCh)
|
||||
}
|
||||
|
||||
func (loop *eventLoop) stop() {
|
||||
if loop.isRunning {
|
||||
loop.isRunning = false
|
||||
close(loop.stopCh)
|
||||
|
||||
select {
|
||||
case <-loop.notifyStopCh:
|
||||
loop.log.Info("Event loop was stopped")
|
||||
case <-time.After(1 * time.Second):
|
||||
loop.log.Warn("Timed out waiting for event loop to stop")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (loop *eventLoop) start() { // nolint[funlen]
|
||||
if loop.isRunning {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
loop.isRunning = false
|
||||
}()
|
||||
loop.stopCh = make(chan struct{})
|
||||
loop.notifyStopCh = make(chan struct{})
|
||||
loop.isRunning = true
|
||||
|
||||
events := make(chan *pmapi.Event)
|
||||
defer close(events)
|
||||
|
||||
loop.log.WithField("lastEventID", loop.currentEventID).Info("Subscribed to events")
|
||||
defer func() {
|
||||
loop.log.WithField("lastEventID", loop.currentEventID).Info("Subscription stopped")
|
||||
}()
|
||||
|
||||
t := time.NewTicker(pollInterval)
|
||||
defer t.Stop()
|
||||
|
||||
loop.hasInternet = true
|
||||
|
||||
go loop.pollNow()
|
||||
|
||||
for {
|
||||
var eventProcessedCh chan struct{}
|
||||
select {
|
||||
case <-loop.stopCh:
|
||||
close(loop.notifyStopCh)
|
||||
return
|
||||
case eventProcessedCh = <-loop.pollCh:
|
||||
case <-t.C:
|
||||
}
|
||||
|
||||
// Before we fetch the first event, check whether this is the first time we've
|
||||
// started the event loop, and if so, trigger a full sync.
|
||||
// In case internet connection was not available during start, it will be
|
||||
// handled anyway when the connection is back here.
|
||||
if loop.isBeforeFirstStart() {
|
||||
if eventErr := loop.setFirstEventID(); eventErr != nil {
|
||||
loop.log.WithError(eventErr).Warn("Could not set initial event ID")
|
||||
}
|
||||
}
|
||||
|
||||
// If the sync is not finished then a new sync is triggered.
|
||||
if !loop.store.isSyncFinished() {
|
||||
loop.store.triggerSync()
|
||||
}
|
||||
|
||||
more, err := loop.processNextEvent()
|
||||
if eventProcessedCh != nil {
|
||||
eventProcessedCh <- struct{}{}
|
||||
}
|
||||
if err != nil {
|
||||
loop.log.WithError(err).Error("Cannot process event, stopping event loop")
|
||||
// When event loop stops, the only way to start it again is by login.
|
||||
// It should stop only when user is logged out but even if there is other
|
||||
// serious error, logout is intended action.
|
||||
if errLogout := loop.user.Logout(); errLogout != nil {
|
||||
loop.log.
|
||||
WithError(errLogout).
|
||||
Error("Failed to logout user after loop finished with error")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if more {
|
||||
go loop.pollNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isBeforeFirstStart returns whether the initial event ID was already set or not.
|
||||
func (loop *eventLoop) isBeforeFirstStart() bool {
|
||||
return loop.currentEventID == ""
|
||||
}
|
||||
|
||||
// processNextEvent saves only successfully processed `eventID` into cache
|
||||
// (disk). It will filter out in defer all errors except invalid token error.
|
||||
// Invalid error will be returned and stop the event loop.
|
||||
func (loop *eventLoop) processNextEvent() (more bool, err error) { // nolint[funlen]
|
||||
l := loop.log.WithField("currentEventID", loop.currentEventID)
|
||||
|
||||
// We only want to consider invalid tokens as real errors because all other errors might fix themselves eventually
|
||||
// (e.g. no internet, ulimit reached etc.)
|
||||
defer func() {
|
||||
if errors.Cause(err) == pmapi.ErrAPINotReachable {
|
||||
l.Warn("Internet unavailable")
|
||||
loop.events.Emit(bridgeEvents.InternetOffEvent, "")
|
||||
loop.hasInternet = false
|
||||
err = nil
|
||||
}
|
||||
|
||||
if err != nil && isFdCloseToULimit() {
|
||||
l.Warn("Ulimit reached")
|
||||
loop.events.Emit(bridgeEvents.RestartBridgeEvent, "")
|
||||
err = nil
|
||||
}
|
||||
|
||||
if errors.Cause(err) == pmapi.ErrUpgradeApplication {
|
||||
l.Warn("Need to upgrade application")
|
||||
loop.events.Emit(bridgeEvents.UpgradeApplicationEvent, "")
|
||||
err = nil
|
||||
}
|
||||
|
||||
_, errUnauthorized := errors.Cause(err).(*pmapi.ErrUnauthorized)
|
||||
|
||||
// All errors except Invalid Token (which is not possible to recover from) are ignored.
|
||||
if err != nil && !errUnauthorized && errors.Cause(err) != pmapi.ErrInvalidToken {
|
||||
l.WithError(err).Trace("Error skipped")
|
||||
err = nil
|
||||
}
|
||||
}()
|
||||
|
||||
l.Trace("Polling next event")
|
||||
var event *pmapi.Event
|
||||
if event, err = loop.apiClient.GetEvent(loop.currentEventID); err != nil {
|
||||
return false, errors.Wrap(err, "failed to get event")
|
||||
}
|
||||
|
||||
l = l.WithField("newEventID", event.EventID)
|
||||
|
||||
if !loop.hasInternet {
|
||||
loop.events.Emit(bridgeEvents.InternetOnEvent, "")
|
||||
loop.hasInternet = true
|
||||
}
|
||||
|
||||
if err = loop.processEvent(event); err != nil {
|
||||
return false, errors.Wrap(err, "failed to process event")
|
||||
}
|
||||
|
||||
if loop.currentEventID != event.EventID {
|
||||
// In case new event ID cannot be saved to cache, we update it in event loop
|
||||
// anyway and continue processing new events to prevent the loop from repeatedly
|
||||
// processing the same event.
|
||||
// This allows the event loop to continue to function (unless the cache was broken
|
||||
// and bridge stopped, in which case it will start from the old event ID anyway).
|
||||
loop.currentEventID = event.EventID
|
||||
if err = loop.cache.setEventID(loop.user.ID(), event.EventID); err != nil {
|
||||
return false, errors.Wrap(err, "failed to save event ID to cache")
|
||||
}
|
||||
}
|
||||
|
||||
return event.More == 1, err
|
||||
}
|
||||
|
||||
func (loop *eventLoop) processEvent(event *pmapi.Event) (err error) {
|
||||
eventLog := loop.log.WithField("event", event.EventID)
|
||||
eventLog.Debug("Processing event")
|
||||
|
||||
if (event.Refresh & pmapi.EventRefreshMail) != 0 {
|
||||
eventLog.Info("Processing refresh event")
|
||||
loop.store.triggerSync()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if len(event.Addresses) != 0 {
|
||||
if err = loop.processAddresses(eventLog, event.Addresses); err != nil {
|
||||
return errors.Wrap(err, "failed to process address events")
|
||||
}
|
||||
}
|
||||
|
||||
if len(event.Labels) != 0 {
|
||||
if err = loop.processLabels(eventLog, event.Labels); err != nil {
|
||||
return errors.Wrap(err, "failed to process label events")
|
||||
}
|
||||
}
|
||||
|
||||
if len(event.Messages) != 0 {
|
||||
if err = loop.processMessages(eventLog, event.Messages); err != nil {
|
||||
return errors.Wrap(err, "failed to process message events")
|
||||
}
|
||||
}
|
||||
|
||||
// One would expect that every event would contain MessageCount as part of
|
||||
// the event.Messages, but this is apparently not the case.
|
||||
// MessageCounts are served on an irregular basis, so we should update and
|
||||
// compare the counts only when we receive them.
|
||||
if len(event.MessageCounts) != 0 {
|
||||
if err = loop.processMessageCounts(eventLog, event.MessageCounts); err != nil {
|
||||
return errors.Wrap(err, "failed to process message count events")
|
||||
}
|
||||
}
|
||||
|
||||
if len(event.Notices) != 0 {
|
||||
loop.processNotices(eventLog, event.Notices)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (loop *eventLoop) processAddresses(log *logrus.Entry, addressEvents []*pmapi.EventAddress) (err error) {
|
||||
log.Debug("Processing address change event")
|
||||
|
||||
// Get old addresses for comparisons before updating user.
|
||||
oldList := loop.apiClient.Addresses()
|
||||
|
||||
if err = loop.user.UpdateUser(); err != nil {
|
||||
if logoutErr := loop.user.Logout(); logoutErr != nil {
|
||||
log.WithError(logoutErr).Error("Failed to logout user after failed update")
|
||||
}
|
||||
return errors.Wrap(err, "failed to update user")
|
||||
}
|
||||
|
||||
for _, addressEvent := range addressEvents {
|
||||
switch addressEvent.Action {
|
||||
case pmapi.EventCreate:
|
||||
log.WithField("email", addressEvent.Address.Email).Debug("Address was created")
|
||||
loop.events.Emit(bridgeEvents.AddressChangedEvent, loop.user.GetPrimaryAddress())
|
||||
|
||||
case pmapi.EventUpdate:
|
||||
oldAddress := oldList.ByID(addressEvent.ID)
|
||||
if oldAddress == nil {
|
||||
log.Warning("Event refers to an address that isn't present")
|
||||
continue
|
||||
}
|
||||
|
||||
email := oldAddress.Email
|
||||
log.WithField("email", email).Debug("Address was updated")
|
||||
if addressEvent.Address.Receive != oldAddress.Receive {
|
||||
loop.events.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
|
||||
}
|
||||
|
||||
case pmapi.EventDelete:
|
||||
oldAddress := oldList.ByID(addressEvent.ID)
|
||||
if oldAddress == nil {
|
||||
log.Warning("Event refers to an address that isn't present")
|
||||
continue
|
||||
}
|
||||
|
||||
email := oldAddress.Email
|
||||
log.WithField("email", email).Debug("Address was deleted")
|
||||
loop.user.CloseConnection(email)
|
||||
loop.events.Emit(bridgeEvents.AddressChangedLogoutEvent, email)
|
||||
}
|
||||
}
|
||||
|
||||
if err = loop.store.createOrUpdateAddressInfo(loop.apiClient.Addresses()); err != nil {
|
||||
return errors.Wrap(err, "failed to update address IDs in store")
|
||||
}
|
||||
|
||||
if err = loop.store.createOrDeleteAddressesEvent(); err != nil {
|
||||
return errors.Wrap(err, "failed to create/delete store addresses")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (loop *eventLoop) processLabels(eventLog *logrus.Entry, labels []*pmapi.EventLabel) error {
|
||||
eventLog.Debug("Processing label change event")
|
||||
|
||||
for _, eventLabel := range labels {
|
||||
label := eventLabel.Label
|
||||
switch eventLabel.Action {
|
||||
case pmapi.EventCreate, pmapi.EventUpdate:
|
||||
if err := loop.store.createOrUpdateMailboxEvent(label); err != nil {
|
||||
return errors.Wrap(err, "failed to create or update label")
|
||||
}
|
||||
case pmapi.EventDelete:
|
||||
if err := loop.store.deleteMailboxEvent(eventLabel.ID); err != nil {
|
||||
return errors.Wrap(err, "failed to delete label")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (loop *eventLoop) processMessages(eventLog *logrus.Entry, messages []*pmapi.EventMessage) (err error) {
|
||||
eventLog.Debug("Processing message change event")
|
||||
|
||||
for _, message := range messages {
|
||||
msgLog := eventLog.WithField("msgID", message.ID)
|
||||
|
||||
switch message.Action {
|
||||
case pmapi.EventCreate:
|
||||
msgLog.Debug("Processing EventCreate for message")
|
||||
|
||||
if message.Created == nil {
|
||||
msgLog.Error("Got EventCreate with nil message")
|
||||
break
|
||||
}
|
||||
|
||||
if err = loop.store.createOrUpdateMessageEvent(message.Created); err != nil {
|
||||
return errors.Wrap(err, "failed to put message into DB")
|
||||
}
|
||||
|
||||
case pmapi.EventUpdate, pmapi.EventUpdateFlags:
|
||||
msgLog.Debug("Processing EventUpdate(Flags) for message")
|
||||
|
||||
if message.Updated == nil {
|
||||
msgLog.Errorf("Got EventUpdate(Flags) with nil message")
|
||||
break
|
||||
}
|
||||
|
||||
var msg *pmapi.Message
|
||||
msg, err = loop.store.getMessageFromDB(message.ID)
|
||||
if err == ErrNoSuchAPIID {
|
||||
msgLog.WithError(err).Warning("Cannot get message from DB for updating. Trying fetch...")
|
||||
msg, err = loop.store.fetchMessage(message.ID)
|
||||
// If message does not exist anywhere, update event is probably old and off topic - skip it.
|
||||
if err == ErrNoSuchAPIID {
|
||||
msgLog.Warn("Skipping message update, because message does not exist nor in local DB or on API")
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get message from DB for updating")
|
||||
}
|
||||
|
||||
updateMessage(msgLog, msg, message.Updated)
|
||||
|
||||
if err = loop.store.createOrUpdateMessageEvent(msg); err != nil {
|
||||
return errors.Wrap(err, "failed to update message in DB")
|
||||
}
|
||||
|
||||
case pmapi.EventDelete:
|
||||
msgLog.Debug("Processing EventDelete for message")
|
||||
|
||||
if err = loop.store.deleteMessageEvent(message.ID); err != nil {
|
||||
return errors.Wrap(err, "failed to delete message from DB")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func updateMessage(msgLog *logrus.Entry, message *pmapi.Message, updates *pmapi.EventMessageUpdated) { //nolint[funlen]
|
||||
msgLog.Debug("Updating message")
|
||||
|
||||
message.Time = updates.Time
|
||||
|
||||
if updates.Subject != nil {
|
||||
msgLog.WithField("subject", *updates.Subject).Trace("Updating message value")
|
||||
message.Subject = *updates.Subject
|
||||
}
|
||||
|
||||
if updates.Sender != nil {
|
||||
msgLog.WithField("sender", *updates.Sender).Trace("Updating message value")
|
||||
message.Sender = updates.Sender
|
||||
}
|
||||
|
||||
if updates.ToList != nil {
|
||||
msgLog.WithField("toList", *updates.ToList).Trace("Updating message value")
|
||||
message.ToList = *updates.ToList
|
||||
}
|
||||
|
||||
if updates.CCList != nil {
|
||||
msgLog.WithField("ccList", *updates.CCList).Trace("Updating message value")
|
||||
message.CCList = *updates.CCList
|
||||
}
|
||||
|
||||
if updates.BCCList != nil {
|
||||
msgLog.WithField("bccList", *updates.BCCList).Trace("Updating message value")
|
||||
message.BCCList = *updates.BCCList
|
||||
}
|
||||
|
||||
if updates.Unread != nil {
|
||||
msgLog.WithField("unread", *updates.Unread).Trace("Updating message value")
|
||||
message.Unread = *updates.Unread
|
||||
}
|
||||
|
||||
if updates.Flags != nil {
|
||||
msgLog.WithField("flags", *updates.Flags).Trace("Updating message value")
|
||||
message.Flags = *updates.Flags
|
||||
}
|
||||
|
||||
if updates.LabelIDs != nil {
|
||||
msgLog.WithField("labelIDs", updates.LabelIDs).Trace("Updating message value")
|
||||
message.LabelIDs = updates.LabelIDs
|
||||
} else {
|
||||
for _, added := range updates.LabelIDsAdded {
|
||||
hasLabel := false
|
||||
for _, l := range message.LabelIDs {
|
||||
if added == l {
|
||||
hasLabel = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasLabel {
|
||||
msgLog.WithField("added", added).Trace("Adding label to message")
|
||||
message.LabelIDs = append(message.LabelIDs, added)
|
||||
}
|
||||
}
|
||||
|
||||
labels := []string{}
|
||||
for _, l := range message.LabelIDs {
|
||||
removeLabel := false
|
||||
for _, removed := range updates.LabelIDsRemoved {
|
||||
if removed == l {
|
||||
removeLabel = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if removeLabel {
|
||||
msgLog.WithField("label", l).Trace("Removing label from message")
|
||||
} else {
|
||||
labels = append(labels, l)
|
||||
}
|
||||
}
|
||||
|
||||
message.LabelIDs = labels
|
||||
}
|
||||
}
|
||||
|
||||
func (loop *eventLoop) processMessageCounts(l *logrus.Entry, messageCounts []*pmapi.MessagesCount) error {
|
||||
l.WithField("apiCounts", messageCounts).Debug("Processing message count change event")
|
||||
|
||||
isSynced, err := loop.store.isSynced(messageCounts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !isSynced {
|
||||
loop.store.triggerSync()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (loop *eventLoop) processNotices(l *logrus.Entry, notices []string) {
|
||||
l.Debug("Processing notice change event")
|
||||
|
||||
for _, notice := range notices {
|
||||
l.Infof("Notice: %q", notice)
|
||||
for _, address := range loop.user.GetStoreAddresses() {
|
||||
loop.store.imapNotice(address, notice)
|
||||
}
|
||||
}
|
||||
}
|
||||
153
internal/store/event_loop_test.go
Normal file
153
internal/store/event_loop_test.go
Normal file
@ -0,0 +1,153 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"net/mail"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEventLoopProcessMoreEvents(t *testing.T) {
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
|
||||
// Event expectations need to be defined before calling `newStoreNoEvents`
|
||||
// to force to use these for this particular test.
|
||||
// Also, event loop calls ListMessages again and we need to place it after
|
||||
// calling `newStoreNoEvents` to not break expectations for the first sync.
|
||||
gomock.InOrder(
|
||||
// Doesn't matter which IDs are used.
|
||||
// This test is trying to see whether event loop will immediately process
|
||||
// next event if there is `More` of them.
|
||||
m.api.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
|
||||
EventID: "event50",
|
||||
More: 1,
|
||||
}, nil),
|
||||
m.api.EXPECT().GetEvent("event50").Return(&pmapi.Event{
|
||||
EventID: "event70",
|
||||
More: 0,
|
||||
}, nil),
|
||||
m.api.EXPECT().GetEvent("event70").Return(&pmapi.Event{
|
||||
EventID: "event71",
|
||||
More: 0,
|
||||
}, nil),
|
||||
)
|
||||
m.newStoreNoEvents(true)
|
||||
m.api.EXPECT().ListMessages(gomock.Any()).Return([]*pmapi.Message{}, 0, nil).AnyTimes()
|
||||
|
||||
// Event loop runs in goroutine and will be stopped by deferred mock clearing.
|
||||
go m.store.eventLoop.start()
|
||||
|
||||
// More events are processed right away.
|
||||
require.Eventually(t, func() bool {
|
||||
return m.store.eventLoop.currentEventID == "event70"
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
|
||||
// For normal event we need to wait to next polling.
|
||||
time.Sleep(pollInterval)
|
||||
require.Eventually(t, func() bool {
|
||||
return m.store.eventLoop.currentEventID == "event71"
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestEventLoopUpdateMessageFromLoop(t *testing.T) {
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
|
||||
subject := "old subject"
|
||||
newSubject := "new subject"
|
||||
|
||||
// First sync will add message with old subject to database.
|
||||
m.api.EXPECT().GetMessage("msg1").Return(&pmapi.Message{
|
||||
ID: "msg1",
|
||||
Subject: subject,
|
||||
}, nil)
|
||||
// Event will update the subject.
|
||||
m.api.EXPECT().GetEvent("latestEventID").Return(&pmapi.Event{
|
||||
EventID: "event1",
|
||||
Messages: []*pmapi.EventMessage{{
|
||||
EventItem: pmapi.EventItem{
|
||||
ID: "msg1",
|
||||
Action: pmapi.EventUpdate,
|
||||
},
|
||||
Updated: &pmapi.EventMessageUpdated{
|
||||
ID: "msg1",
|
||||
Subject: &newSubject,
|
||||
},
|
||||
}},
|
||||
}, nil)
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
|
||||
// Event loop runs in goroutine and will be stopped by deferred mock clearing.
|
||||
go m.store.eventLoop.start()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
msg, err := m.store.getMessageFromDB("msg1")
|
||||
return err == nil && msg.Subject == newSubject
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestEventLoopUpdateMessage(t *testing.T) {
|
||||
address1 := &mail.Address{Address: "user1@example.com"}
|
||||
address2 := &mail.Address{Address: "user2@example.com"}
|
||||
msg := &pmapi.Message{
|
||||
ID: "msg1",
|
||||
Subject: "old",
|
||||
Unread: 0,
|
||||
Flags: 10,
|
||||
Sender: address1,
|
||||
ToList: []*mail.Address{address2},
|
||||
CCList: []*mail.Address{address1},
|
||||
BCCList: []*mail.Address{},
|
||||
Time: 20,
|
||||
LabelIDs: []string{"old"},
|
||||
}
|
||||
newMsg := &pmapi.Message{
|
||||
ID: "msg1",
|
||||
Subject: "new",
|
||||
Unread: 1,
|
||||
Flags: 11,
|
||||
Sender: address2,
|
||||
ToList: []*mail.Address{address1},
|
||||
CCList: []*mail.Address{address2},
|
||||
BCCList: []*mail.Address{address1},
|
||||
Time: 21,
|
||||
LabelIDs: []string{"new"},
|
||||
}
|
||||
|
||||
updateMessage(log, msg, &pmapi.EventMessageUpdated{
|
||||
ID: "msg1",
|
||||
Subject: &newMsg.Subject,
|
||||
Unread: &newMsg.Unread,
|
||||
Flags: &newMsg.Flags,
|
||||
Sender: newMsg.Sender,
|
||||
ToList: &newMsg.ToList,
|
||||
CCList: &newMsg.CCList,
|
||||
BCCList: &newMsg.BCCList,
|
||||
Time: newMsg.Time,
|
||||
LabelIDs: newMsg.LabelIDs,
|
||||
})
|
||||
|
||||
require.Equal(t, newMsg, msg)
|
||||
}
|
||||
265
internal/store/mailbox.go
Normal file
265
internal/store/mailbox.go
Normal file
@ -0,0 +1,265 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/sirupsen/logrus"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
// Mailbox is mailbox for specific address and mailbox.
|
||||
type Mailbox struct {
|
||||
store *Store
|
||||
storeAddress *Address
|
||||
|
||||
labelID string
|
||||
labelPrefix string
|
||||
labelName string
|
||||
color string
|
||||
|
||||
log *logrus.Entry
|
||||
}
|
||||
|
||||
func newMailbox(storeAddress *Address, labelID, labelPrefix, labelName, color string) (mb *Mailbox, err error) {
|
||||
l := log.
|
||||
WithField("addrID", storeAddress.addressID).
|
||||
WithField("lblID", labelID)
|
||||
mb = &Mailbox{
|
||||
store: storeAddress.store,
|
||||
storeAddress: storeAddress,
|
||||
labelID: labelID,
|
||||
labelPrefix: labelPrefix,
|
||||
labelName: labelPrefix + labelName,
|
||||
color: color,
|
||||
log: l,
|
||||
}
|
||||
|
||||
if err = mb.store.db.Update(func(tx *bolt.Tx) error {
|
||||
return initMailboxBucket(tx, mb.getBucketName())
|
||||
}); err != nil {
|
||||
l.WithError(err).Error("Could not initialise mailbox buckets")
|
||||
}
|
||||
|
||||
syncDraftsIfNecssary(mb)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func syncDraftsIfNecssary(mb *Mailbox) { //nolint[funlen]
|
||||
// We didn't support drafts before v1.2.6 and therefore if we now created
|
||||
// Drafts mailbox we need to check whether counts match (drafts are synced).
|
||||
// If not, sync them from local metadata without need to do full resync,
|
||||
// Can be removed with 1.2.7 or later.
|
||||
if mb.labelID != pmapi.DraftLabel {
|
||||
return
|
||||
}
|
||||
|
||||
// If the drafts mailbox total is non-zero, it means it has already been used
|
||||
// and there is no need to continue. Otherwise, we may need to do an initial sync.
|
||||
total, _, err := mb.GetCounts()
|
||||
if err != nil || total != 0 {
|
||||
return
|
||||
}
|
||||
|
||||
counts, err := mb.store.getOnAPICounts()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
foundCounts := false
|
||||
doSync := false
|
||||
for _, count := range counts {
|
||||
if count.LabelID != pmapi.DraftLabel {
|
||||
continue
|
||||
}
|
||||
foundCounts = true
|
||||
log.WithField("total", total).WithField("total-api", count.TotalOnAPI).Debug("Drafts mailbox created: checking need for sync")
|
||||
if count.TotalOnAPI == total {
|
||||
continue
|
||||
}
|
||||
doSync = true
|
||||
break
|
||||
}
|
||||
|
||||
if !foundCounts {
|
||||
log.Debug("Drafts mailbox created: missing counts, refreshing")
|
||||
_ = mb.store.updateCountsFromServer()
|
||||
}
|
||||
|
||||
if !foundCounts || doSync {
|
||||
err := mb.store.db.Update(func(tx *bolt.Tx) error {
|
||||
return tx.Bucket(metadataBucket).ForEach(func(k, v []byte) error {
|
||||
msg := &pmapi.Message{}
|
||||
if err := json.Unmarshal(v, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, msgLabelID := range msg.LabelIDs {
|
||||
if msgLabelID == pmapi.DraftLabel {
|
||||
log.WithField("id", msg.ID).Debug("Drafts mailbox created: syncing draft locally")
|
||||
_ = mb.txCreateOrUpdateMessages(tx, []*pmapi.Message{msg})
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
log.WithError(err).Info("Drafts mailbox created: synced localy")
|
||||
}
|
||||
}
|
||||
|
||||
func initMailboxBucket(tx *bolt.Tx, bucketName []byte) error {
|
||||
bucket, err := tx.Bucket(mailboxesBucket).CreateBucketIfNotExists(bucketName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := bucket.CreateBucketIfNotExists(imapIDsBucket); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := bucket.CreateBucketIfNotExists(apiIDsBucket); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LabelID returns ID of mailbox.
|
||||
func (storeMailbox *Mailbox) LabelID() string {
|
||||
return storeMailbox.labelID
|
||||
}
|
||||
|
||||
// Name returns the name of mailbox.
|
||||
func (storeMailbox *Mailbox) Name() string {
|
||||
return storeMailbox.labelName
|
||||
}
|
||||
|
||||
// Color returns the color of mailbox.
|
||||
func (storeMailbox *Mailbox) Color() string {
|
||||
return storeMailbox.color
|
||||
}
|
||||
|
||||
// UIDValidity returns the current value of structure version.
|
||||
func (storeMailbox *Mailbox) UIDValidity() uint32 {
|
||||
return storeMailbox.store.getMailboxesVersion()
|
||||
}
|
||||
|
||||
// IsFolder returns whether the mailbox is a folder (has "Folders/" prefix).
|
||||
func (storeMailbox *Mailbox) IsFolder() bool {
|
||||
return storeMailbox.labelPrefix == UserFoldersPrefix
|
||||
}
|
||||
|
||||
// IsLabel returns whether the mailbox is a label (has "Labels/" prefix).
|
||||
func (storeMailbox *Mailbox) IsLabel() bool {
|
||||
return storeMailbox.labelPrefix == UserLabelsPrefix
|
||||
}
|
||||
|
||||
// IsSystem returns whether the mailbox is one of the specific system mailboxes (has no prefix).
|
||||
func (storeMailbox *Mailbox) IsSystem() bool {
|
||||
return storeMailbox.labelPrefix == ""
|
||||
}
|
||||
|
||||
// Rename updates the mailbox by calling an API.
|
||||
// Change has to be propagated to all the same mailboxes in all addresses.
|
||||
// The propagation is processed by the event loop.
|
||||
func (storeMailbox *Mailbox) Rename(newName string) error {
|
||||
if storeMailbox.IsSystem() {
|
||||
return fmt.Errorf("cannot rename system mailboxes")
|
||||
}
|
||||
|
||||
if storeMailbox.IsFolder() {
|
||||
if !strings.HasPrefix(newName, UserFoldersPrefix) {
|
||||
return fmt.Errorf("cannot rename folder to non-folder")
|
||||
}
|
||||
|
||||
newName = strings.TrimPrefix(newName, UserFoldersPrefix)
|
||||
}
|
||||
|
||||
if storeMailbox.IsLabel() {
|
||||
if !strings.HasPrefix(newName, UserLabelsPrefix) {
|
||||
return fmt.Errorf("cannot rename label to non-label")
|
||||
}
|
||||
|
||||
newName = strings.TrimPrefix(newName, UserLabelsPrefix)
|
||||
}
|
||||
|
||||
return storeMailbox.storeAddress.updateMailbox(storeMailbox.labelID, newName, storeMailbox.color)
|
||||
}
|
||||
|
||||
// Delete deletes the mailbox by calling an API.
|
||||
// Deletion has to be propagated to all the same mailboxes in all addresses.
|
||||
// The propagation is processed by the event loop.
|
||||
func (storeMailbox *Mailbox) Delete() error {
|
||||
return storeMailbox.storeAddress.deleteMailbox(storeMailbox.labelID)
|
||||
}
|
||||
|
||||
// GetDelimiter returns the path separator.
|
||||
func (storeMailbox *Mailbox) GetDelimiter() string {
|
||||
return PathDelimiter
|
||||
}
|
||||
|
||||
// deleteMailboxEvent deletes the mailbox bucket.
|
||||
// This is called from the event loop.
|
||||
func (storeMailbox *Mailbox) deleteMailboxEvent() error {
|
||||
return storeMailbox.db().Update(func(tx *bolt.Tx) error {
|
||||
return tx.Bucket(mailboxesBucket).DeleteBucket(storeMailbox.getBucketName())
|
||||
})
|
||||
}
|
||||
|
||||
// txGetIMAPIDsBucket returns the bucket mapping IMAP ID to API ID.
|
||||
func (storeMailbox *Mailbox) txGetIMAPIDsBucket(tx *bolt.Tx) *bolt.Bucket {
|
||||
return storeMailbox.txGetBucket(tx).Bucket(imapIDsBucket)
|
||||
}
|
||||
|
||||
// txGetAPIIDsBucket returns the bucket mapping API ID to IMAP ID.
|
||||
func (storeMailbox *Mailbox) txGetAPIIDsBucket(tx *bolt.Tx) *bolt.Bucket {
|
||||
return storeMailbox.txGetBucket(tx).Bucket(apiIDsBucket)
|
||||
}
|
||||
|
||||
// txGetBucket returns the bucket of mailbox containing mapping buckets.
|
||||
func (storeMailbox *Mailbox) txGetBucket(tx *bolt.Tx) *bolt.Bucket {
|
||||
return tx.Bucket(mailboxesBucket).Bucket(storeMailbox.getBucketName())
|
||||
}
|
||||
|
||||
func getMailboxBucketName(addressID, labelID string) []byte {
|
||||
return []byte(addressID + "-" + labelID)
|
||||
}
|
||||
|
||||
// getBucketName returns the name of mailbox bucket.
|
||||
func (storeMailbox *Mailbox) getBucketName() []byte {
|
||||
return getMailboxBucketName(storeMailbox.storeAddress.addressID, storeMailbox.labelID)
|
||||
}
|
||||
|
||||
// pollNow is a proxy for the store's eventloop's `pollNow()`.
|
||||
func (storeMailbox *Mailbox) pollNow() {
|
||||
storeMailbox.store.eventLoop.pollNow()
|
||||
}
|
||||
|
||||
// api is a proxy for the store's `PMAPIProvider`.
|
||||
func (storeMailbox *Mailbox) api() PMAPIProvider {
|
||||
return storeMailbox.store.api
|
||||
}
|
||||
|
||||
// update is a proxy for the store's db's `Update`.
|
||||
func (storeMailbox *Mailbox) db() *bolt.DB {
|
||||
return storeMailbox.store.db
|
||||
}
|
||||
257
internal/store/mailbox_counts.go
Normal file
257
internal/store/mailbox_counts.go
Normal file
@ -0,0 +1,257 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/pkg/errors"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
// GetCounts returns numbers of total and unread messages in this mailbox bucket.
|
||||
func (storeMailbox *Mailbox) GetCounts() (total, unread uint, err error) {
|
||||
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
|
||||
total, unread, err = storeMailbox.txGetCounts(tx)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (storeMailbox *Mailbox) txGetCounts(tx *bolt.Tx) (total, unread uint, err error) {
|
||||
// For total it would be enough to use `bolt.Bucket.Stats().KeyN` but
|
||||
// we also need to retrieve the count of unread emails therefore we are
|
||||
// looping all messages in this mailbox by `bolt.Cursor`
|
||||
metaBucket := tx.Bucket(metadataBucket)
|
||||
b := storeMailbox.txGetIMAPIDsBucket(tx)
|
||||
c := b.Cursor()
|
||||
imapID, apiID := c.First()
|
||||
for ; imapID != nil; imapID, apiID = c.Next() {
|
||||
total++
|
||||
rawMsg := metaBucket.Get(apiID)
|
||||
if rawMsg == nil {
|
||||
return 0, 0, ErrNoSuchAPIID
|
||||
}
|
||||
// Do not unmarshal whole JSON to speed up the looping.
|
||||
// Instead, we assume it will contain JSON int field `Unread`
|
||||
// where `1` means true (i.e. message is unread)
|
||||
if bytes.Contains(rawMsg, []byte(`"Unread":1`)) {
|
||||
unread++
|
||||
}
|
||||
}
|
||||
return total, unread, err
|
||||
}
|
||||
|
||||
type mailboxCounts struct {
|
||||
LabelID string
|
||||
LabelName string
|
||||
Color string
|
||||
Order int
|
||||
IsFolder bool
|
||||
TotalOnAPI uint
|
||||
UnreadOnAPI uint
|
||||
}
|
||||
|
||||
func txGetCountsFromBucketOrNew(bkt *bolt.Bucket, labelID string) (*mailboxCounts, error) {
|
||||
mc := &mailboxCounts{}
|
||||
if mcJSON := bkt.Get([]byte(labelID)); mcJSON != nil {
|
||||
if err := json.Unmarshal(mcJSON, mc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
mc.LabelID = labelID // if it was empty before we need to set labelID
|
||||
|
||||
return mc, nil
|
||||
}
|
||||
|
||||
func (mc *mailboxCounts) txWriteToBucket(bucket *bolt.Bucket) error {
|
||||
mcJSON, err := json.Marshal(mc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return bucket.Put([]byte(mc.LabelID), mcJSON)
|
||||
}
|
||||
|
||||
func getSystemFolders() []*mailboxCounts {
|
||||
return []*mailboxCounts{
|
||||
{pmapi.InboxLabel, "INBOX", "#000", -1000, true, 0, 0},
|
||||
{pmapi.SentLabel, "Sent", "#000", -9, true, 0, 0},
|
||||
{pmapi.ArchiveLabel, "Archive", "#000", -8, true, 0, 0},
|
||||
{pmapi.SpamLabel, "Spam", "#000", -7, true, 0, 0},
|
||||
{pmapi.TrashLabel, "Trash", "#000", -6, true, 0, 0},
|
||||
{pmapi.AllMailLabel, "All Mail", "#000", -5, true, 0, 0},
|
||||
{pmapi.DraftLabel, "Drafts", "#000", -4, true, 0, 0},
|
||||
}
|
||||
}
|
||||
|
||||
// skipThisLabel decides to skip labelIDs that *are* pmapi system labels but *aren't* local system labels
|
||||
// (i.e. if it's in `pmapi.SystemLabels` but not in `getSystemFolders` then we skip it, otherwise we don't).
|
||||
func skipThisLabel(labelID string) bool {
|
||||
switch labelID {
|
||||
case pmapi.StarredLabel, pmapi.AllSentLabel, pmapi.AllDraftsLabel:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func sortByOrder(labels []*pmapi.Label) {
|
||||
sort.Slice(labels, func(i, j int) bool {
|
||||
return labels[i].Order < labels[j].Order
|
||||
})
|
||||
}
|
||||
|
||||
func (mc *mailboxCounts) getPMLabel() *pmapi.Label {
|
||||
return &pmapi.Label{
|
||||
ID: mc.LabelID,
|
||||
Name: mc.LabelName,
|
||||
Color: mc.Color,
|
||||
Order: mc.Order,
|
||||
Type: pmapi.LabelTypeMailbox,
|
||||
Exclusive: mc.isExclusive(),
|
||||
}
|
||||
}
|
||||
|
||||
func (mc *mailboxCounts) isExclusive() int {
|
||||
if mc.IsFolder {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// createOrUpdateMailboxCountsBuckets will not change the on-API-counts.
|
||||
func (store *Store) createOrUpdateMailboxCountsBuckets(labels []*pmapi.Label) error {
|
||||
// Don't forget about system folders.
|
||||
// It should set label id, name, color, isFolder, total, unread.
|
||||
tx := func(tx *bolt.Tx) error {
|
||||
countsBkt := tx.Bucket(countsBucket)
|
||||
for _, label := range labels {
|
||||
// Skipping is probably not necessary.
|
||||
if skipThisLabel(label.ID) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get current data.
|
||||
mailbox, err := txGetCountsFromBucketOrNew(countsBkt, label.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update mailbox info, but dont change on-API-counts.
|
||||
mailbox.LabelName = label.Name
|
||||
mailbox.Color = label.Color
|
||||
mailbox.Order = label.Order
|
||||
mailbox.IsFolder = label.Exclusive == 1
|
||||
|
||||
// Write.
|
||||
if err = mailbox.txWriteToBucket(countsBkt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return store.db.Update(tx)
|
||||
}
|
||||
|
||||
func (store *Store) getLabelsFromLocalStorage() ([]*pmapi.Label, error) {
|
||||
countsOnAPI, err := store.getOnAPICounts()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
labels := []*pmapi.Label{}
|
||||
for _, counts := range countsOnAPI {
|
||||
labels = append(labels, counts.getPMLabel())
|
||||
}
|
||||
sortByOrder(labels)
|
||||
|
||||
return labels, nil
|
||||
}
|
||||
|
||||
func (store *Store) getOnAPICounts() ([]*mailboxCounts, error) {
|
||||
counts := []*mailboxCounts{}
|
||||
tx := func(tx *bolt.Tx) error {
|
||||
c := tx.Bucket(countsBucket).Cursor()
|
||||
for k, countsB := c.First(); k != nil; k, countsB = c.Next() {
|
||||
l := store.log.WithField("key", string(k))
|
||||
if countsB == nil {
|
||||
err := errors.New("empty counts in DB")
|
||||
l.WithError(err).Error("While getting local labels")
|
||||
return err
|
||||
}
|
||||
|
||||
mbCounts := &mailboxCounts{}
|
||||
if err := json.Unmarshal(countsB, mbCounts); err != nil {
|
||||
l.WithError(err).Error("While unmarshaling local labels")
|
||||
return err
|
||||
}
|
||||
|
||||
counts = append(counts, mbCounts)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
err := store.db.View(tx)
|
||||
|
||||
return counts, err
|
||||
}
|
||||
|
||||
// createOrUpdateOnAPICounts will change only on-API-counts.
|
||||
func (store *Store) createOrUpdateOnAPICounts(mailboxCountsOnAPI []*pmapi.MessagesCount) error {
|
||||
store.log.WithField("apiCounts", mailboxCountsOnAPI).Debug("Updating API counts")
|
||||
|
||||
tx := func(tx *bolt.Tx) error {
|
||||
countsBkt := tx.Bucket(countsBucket)
|
||||
for _, countsOnAPI := range mailboxCountsOnAPI {
|
||||
if skipThisLabel(countsOnAPI.LabelID) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get current data.
|
||||
counts, err := txGetCountsFromBucketOrNew(countsBkt, countsOnAPI.LabelID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update only counts.
|
||||
counts.TotalOnAPI = uint(countsOnAPI.Total)
|
||||
counts.UnreadOnAPI = uint(countsOnAPI.Unread)
|
||||
|
||||
if err = counts.txWriteToBucket(countsBkt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return store.db.Update(tx)
|
||||
}
|
||||
|
||||
func (store *Store) removeMailboxCount(labelID string) error {
|
||||
err := store.db.Update(func(tx *bolt.Tx) error {
|
||||
return tx.Bucket(countsBucket).Delete([]byte(labelID))
|
||||
})
|
||||
if err != nil {
|
||||
store.log.WithError(err).
|
||||
WithField("labelID", labelID).
|
||||
Warning("Cannot remove counts")
|
||||
}
|
||||
return err
|
||||
}
|
||||
126
internal/store/mailbox_counts_test.go
Normal file
126
internal/store/mailbox_counts_test.go
Normal file
@ -0,0 +1,126 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
a "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func newLabel(order int, id, name string) *pmapi.Label {
|
||||
return &pmapi.Label{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Order: order,
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortByOrder(t *testing.T) {
|
||||
want := []*pmapi.Label{
|
||||
newLabel(-1000, pmapi.InboxLabel, "INBOX"),
|
||||
newLabel(-5, pmapi.SentLabel, "Sent"),
|
||||
newLabel(-4, pmapi.ArchiveLabel, "Archive"),
|
||||
newLabel(-3, pmapi.SpamLabel, "Spam"),
|
||||
newLabel(-2, pmapi.TrashLabel, "Trash"),
|
||||
newLabel(-1, pmapi.AllMailLabel, "All Mail"),
|
||||
newLabel(100, "labelID1", "custom_label"),
|
||||
newLabel(1000, "folderID1", "custom_folder"),
|
||||
}
|
||||
labels := []*pmapi.Label{
|
||||
want[6],
|
||||
want[4],
|
||||
want[3],
|
||||
want[7],
|
||||
want[5],
|
||||
want[0],
|
||||
want[2],
|
||||
want[1],
|
||||
}
|
||||
|
||||
sortByOrder(labels)
|
||||
a.Equal(t, want, labels)
|
||||
}
|
||||
|
||||
func TestMailboxNames(t *testing.T) {
|
||||
want := map[string]string{
|
||||
pmapi.InboxLabel: "INBOX",
|
||||
pmapi.SentLabel: "Sent",
|
||||
pmapi.ArchiveLabel: "Archive",
|
||||
pmapi.SpamLabel: "Spam",
|
||||
pmapi.TrashLabel: "Trash",
|
||||
pmapi.AllMailLabel: "All Mail",
|
||||
pmapi.DraftLabel: "Drafts",
|
||||
"labelID1": "Labels/Label1",
|
||||
"folderID1": "Folders/Folder1",
|
||||
}
|
||||
|
||||
foldersAndLabels := []*pmapi.Label{
|
||||
newLabel(100, "labelID1", "Label1"),
|
||||
newLabel(1000, "folderID1", "Folder1"),
|
||||
}
|
||||
foldersAndLabels[1].Exclusive = 1
|
||||
|
||||
for _, counts := range getSystemFolders() {
|
||||
foldersAndLabels = append(foldersAndLabels, counts.getPMLabel())
|
||||
}
|
||||
|
||||
got := map[string]string{}
|
||||
for _, m := range foldersAndLabels {
|
||||
got[m.ID] = getLabelPrefix(m) + m.Name
|
||||
}
|
||||
a.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestAddSystemLabels(t *testing.T) {}
|
||||
|
||||
func checkCounts(t testing.TB, wantCounts []*pmapi.MessagesCount, haveStore *Store) {
|
||||
nSystemFolders := 7
|
||||
haveCounts, err := haveStore.getOnAPICounts()
|
||||
a.NoError(t, err)
|
||||
a.Len(t, haveCounts, len(wantCounts)+nSystemFolders)
|
||||
for iWant, wantCount := range wantCounts {
|
||||
iHave := iWant + nSystemFolders
|
||||
haveCount := haveCounts[iHave]
|
||||
a.Equal(t, wantCount.LabelID, haveCount.LabelID, "iHave:%d\niWant:%d\nHave:%v\nWant:%v", iHave, iWant, haveCount, wantCount)
|
||||
a.Equal(t, wantCount.Total, int(haveCount.TotalOnAPI), "iHave:%d\niWant:%d\nHave:%v\nWant:%v", iHave, iWant, haveCount, wantCount)
|
||||
a.Equal(t, wantCount.Unread, int(haveCount.UnreadOnAPI), "iHave:%d\niWant:%d\nHave:%v\nWant:%v", iHave, iWant, haveCount, wantCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMailboxCountRemove(t *testing.T) {
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
m.newStoreNoEvents(true)
|
||||
|
||||
testCounts := []*pmapi.MessagesCount{
|
||||
{LabelID: "label1", Total: 100, Unread: 0},
|
||||
{LabelID: "label2", Total: 100, Unread: 30},
|
||||
{LabelID: "label4", Total: 100, Unread: 100},
|
||||
}
|
||||
a.NoError(t, m.store.createOrUpdateOnAPICounts(testCounts))
|
||||
|
||||
a.NoError(t, m.store.removeMailboxCount("not existing"))
|
||||
checkCounts(t, testCounts, m.store)
|
||||
|
||||
var pop *pmapi.MessagesCount
|
||||
pop, testCounts = testCounts[2], testCounts[0:2]
|
||||
a.NoError(t, m.store.removeMailboxCount(pop.LabelID))
|
||||
checkCounts(t, testCounts, m.store)
|
||||
}
|
||||
263
internal/store/mailbox_ids.go
Normal file
263
internal/store/mailbox_ids.go
Normal file
@ -0,0 +1,263 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"net/mail"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/internal/imap/uidplus"
|
||||
"github.com/pkg/errors"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
// GetAPIIDsFromUIDRange returns API IDs by IMAP UID range.
|
||||
//
|
||||
// API IDs are the long base64 strings that the API uses to identify messages.
|
||||
// UIDs are unique increasing integers that must be unique within a mailbox.
|
||||
func (storeMailbox *Mailbox) GetAPIIDsFromUIDRange(start, stop uint32) (apiIDs []string, err error) {
|
||||
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
|
||||
b := storeMailbox.txGetIMAPIDsBucket(tx)
|
||||
|
||||
if stop == 0 {
|
||||
// A null stop means no stop.
|
||||
stop = ^uint32(0)
|
||||
}
|
||||
|
||||
startb := itob(start)
|
||||
stopb := itob(stop)
|
||||
|
||||
c := b.Cursor()
|
||||
for k, v := c.Seek(startb); k != nil && bytes.Compare(k, stopb) <= 0; k, v = c.Next() {
|
||||
apiIDs = append(apiIDs, string(v))
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// GetAPIIDsFromSequenceRange returns API IDs by IMAP sequence number range.
|
||||
func (storeMailbox *Mailbox) GetAPIIDsFromSequenceRange(start, stop uint32) (apiIDs []string, err error) {
|
||||
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
|
||||
b := storeMailbox.txGetIMAPIDsBucket(tx)
|
||||
c := b.Cursor()
|
||||
var i uint32
|
||||
for k, v := c.First(); k != nil; k, v = c.Next() {
|
||||
i++
|
||||
if i < start {
|
||||
continue
|
||||
}
|
||||
if stop > 0 && i > stop {
|
||||
break
|
||||
}
|
||||
apiIDs = append(apiIDs, string(v))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// GetLatestAPIID returns the latest message API ID which still exists.
|
||||
// Info: not the latest IMAP UID which can be already removed.
|
||||
func (storeMailbox *Mailbox) GetLatestAPIID() (apiID string, err error) {
|
||||
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
|
||||
b := storeMailbox.txGetAPIIDsBucket(tx)
|
||||
c := b.Cursor()
|
||||
lastAPIID, _ := c.Last()
|
||||
apiID = string(lastAPIID)
|
||||
if apiID == "" {
|
||||
return errors.New("cannot get latest API ID: empty mailbox")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// GetNextUID returns the next IMAP UID.
|
||||
func (storeMailbox *Mailbox) GetNextUID() (uid uint32, err error) {
|
||||
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
|
||||
b := storeMailbox.txGetIMAPIDsBucket(tx)
|
||||
uid, err = storeMailbox.txGetNextUID(b, false)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (storeMailbox *Mailbox) txGetNextUID(imapIDBucket *bolt.Bucket, write bool) (uint32, error) {
|
||||
var uid uint64
|
||||
var err error
|
||||
if write {
|
||||
uid, err = imapIDBucket.NextSequence()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
uid = imapIDBucket.Sequence() + 1
|
||||
}
|
||||
if math.MaxUint32 <= uid {
|
||||
return 0, errors.New("too large sequence number")
|
||||
}
|
||||
return uint32(uid), nil
|
||||
}
|
||||
|
||||
// getUID returns IMAP UID in this mailbox for message ID.
|
||||
func (storeMailbox *Mailbox) getUID(apiID string) (uid uint32, err error) {
|
||||
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
|
||||
uid, err = storeMailbox.txGetUID(tx, apiID)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (storeMailbox *Mailbox) txGetUID(tx *bolt.Tx, apiID string) (uint32, error) {
|
||||
b := storeMailbox.txGetAPIIDsBucket(tx)
|
||||
v := b.Get([]byte(apiID))
|
||||
if v == nil {
|
||||
return 0, ErrNoSuchAPIID
|
||||
}
|
||||
return btoi(v), nil
|
||||
}
|
||||
|
||||
// getSequenceNumber returns IMAP sequence number in the mailbox for the message with the given API ID `apiID`.
|
||||
func (storeMailbox *Mailbox) getSequenceNumber(apiID string) (seqNum uint32, err error) {
|
||||
err = storeMailbox.db().View(func(tx *bolt.Tx) error {
|
||||
b := storeMailbox.txGetIMAPIDsBucket(tx)
|
||||
uid, err := storeMailbox.txGetUID(tx, apiID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
seqNum, err = storeMailbox.txGetSequenceNumberOfUID(b, itob(uid))
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// txGetSequenceNumberOfUID returns the IMAP sequence number of the message
|
||||
// with the given IMAP UID bytes `uidb`.
|
||||
//
|
||||
// NOTE: The `bolt.Cursor.Next()` loops in order of ascending key bytes. The
|
||||
// IMAP UID bucket is ordered by increasing UID because it's using BigEndian to
|
||||
// encode uint into byte. Hence the sequence number (IMAP ID) corresponds to
|
||||
// position of uid key in this order.
|
||||
func (storeMailbox *Mailbox) txGetSequenceNumberOfUID(bucket *bolt.Bucket, uidb []byte) (uint32, error) {
|
||||
seqNum := uint32(0)
|
||||
c := bucket.Cursor()
|
||||
|
||||
// Speed up for the case of last message. This is always true for
|
||||
// adding new message. It will return number of keys in bucket because
|
||||
// sequence number starts with 1.
|
||||
// We cannot use bucket.Stats() for that--it doesn't work in the same
|
||||
// transaction because stats are updated when transaction is committed.
|
||||
// But we can at least optimise to not do equal for all keys.
|
||||
lastKey, _ := c.Last()
|
||||
isLast := bytes.Equal(lastKey, uidb)
|
||||
|
||||
for k, _ := c.First(); k != nil; k, _ = c.Next() {
|
||||
seqNum++ // Sequence number starts at 1.
|
||||
if isLast {
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(k, uidb) {
|
||||
return seqNum, nil
|
||||
}
|
||||
}
|
||||
|
||||
if isLast {
|
||||
return seqNum, nil
|
||||
}
|
||||
|
||||
return 0, ErrNoSuchUID
|
||||
}
|
||||
|
||||
// GetUIDList returns UID list corresponding to messageIDs in a requested order.
|
||||
func (storeMailbox *Mailbox) GetUIDList(apiIDs []string) *uidplus.OrderedSeq {
|
||||
seqSet := &uidplus.OrderedSeq{}
|
||||
_ = storeMailbox.db().View(func(tx *bolt.Tx) error {
|
||||
b := storeMailbox.txGetAPIIDsBucket(tx)
|
||||
for _, apiID := range apiIDs {
|
||||
v := b.Get([]byte(apiID))
|
||||
if v == nil {
|
||||
storeMailbox.log.
|
||||
WithField("msgID", apiID).
|
||||
Warn("Cannot find UID")
|
||||
continue
|
||||
}
|
||||
|
||||
seqSet.Add(btoi(v))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return seqSet
|
||||
}
|
||||
|
||||
// GetUIDByHeader returns UID of message existing in mailbox or zero if no match found.
|
||||
func (storeMailbox *Mailbox) GetUIDByHeader(header *mail.Header) (foundUID uint32) {
|
||||
if header == nil {
|
||||
return uint32(0)
|
||||
}
|
||||
|
||||
// Message-Id in appended-after-send mail is processed as ExternalID
|
||||
// in PM message. Message-Id in normal copy/move will be the PM internal ID.
|
||||
messageID := header.Get("Message-Id")
|
||||
|
||||
// The most often situation is that message is APPENDed after it was sent so the
|
||||
// Message-ID will be reflected by ExternalID in API message meta-data.
|
||||
externalID := strings.Trim(messageID, "<> ") // remove '<>' to improve match
|
||||
matchExternalID := regexp.MustCompile(`"ExternalID":"` +
|
||||
` *(\\u003c)? *` + // \u003c is equivalent to `<`
|
||||
regexp.QuoteMeta(externalID) +
|
||||
` *(\\u003e)? *` + // \u0033 is equivalent to `>`
|
||||
`"`,
|
||||
)
|
||||
|
||||
// It is possible that client will try to COPY existing message to Sent
|
||||
// using APPEND command. In that case the Message-Id from header will
|
||||
// be internal message ID and we need to check whether it's already there.
|
||||
matchInternalID := bytes.Split([]byte(externalID), []byte("@"))[0]
|
||||
|
||||
_ = storeMailbox.db().View(func(tx *bolt.Tx) error {
|
||||
metaBucket := tx.Bucket(metadataBucket)
|
||||
b := storeMailbox.txGetIMAPIDsBucket(tx)
|
||||
c := b.Cursor()
|
||||
imapID, apiID := c.Last()
|
||||
for ; imapID != nil; imapID, apiID = c.Prev() {
|
||||
rawMeta := metaBucket.Get(apiID)
|
||||
if rawMeta == nil {
|
||||
storeMailbox.log.
|
||||
WithField("IMAP-UID", imapID).
|
||||
WithField("API-ID", apiID).
|
||||
Warn("Cannot find meta-data while searching for externalID")
|
||||
continue
|
||||
}
|
||||
|
||||
if !matchExternalID.Match(rawMeta) && !bytes.Equal(apiID, matchInternalID) {
|
||||
continue
|
||||
}
|
||||
|
||||
foundUID = btoi(imapID)
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return foundUID
|
||||
}
|
||||
147
internal/store/mailbox_ids_test.go
Normal file
147
internal/store/mailbox_ids_test.go
Normal file
@ -0,0 +1,147 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"net/mail"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
a "github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type wantID struct {
|
||||
appID string
|
||||
uid int
|
||||
}
|
||||
|
||||
func TestGetSequenceNumberAndGetUID(t *testing.T) {
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
|
||||
insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
|
||||
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
|
||||
|
||||
checkMailboxMessageIDs(t, m, pmapi.InboxLabel, []wantID{{"msg1", 1}, {"msg3", 2}})
|
||||
checkMailboxMessageIDs(t, m, pmapi.ArchiveLabel, []wantID{{"msg2", 1}})
|
||||
checkMailboxMessageIDs(t, m, pmapi.SpamLabel, []wantID(nil))
|
||||
checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg1", 1}, {"msg2", 2}, {"msg3", 3}, {"msg4", 4}})
|
||||
}
|
||||
|
||||
// checkMailboxMessageIDs checks that the mailbox contains all API IDs with correct sequence numbers and UIDs.
|
||||
// wantIDs is map from IMAP UID to API ID. Sequence number is detected automatically by order of the ID in the map.
|
||||
func checkMailboxMessageIDs(t *testing.T, m *mocksForStore, mailboxLabel string, wantIDs []wantID) {
|
||||
storeAddress := m.store.addresses[addrID1]
|
||||
storeMailbox := storeAddress.mailboxes[mailboxLabel]
|
||||
|
||||
ids, err := storeMailbox.GetAPIIDsFromSequenceRange(0, uint32(len(wantIDs)))
|
||||
require.Nil(t, err)
|
||||
|
||||
idx := 0
|
||||
for _, wantID := range wantIDs {
|
||||
id := ids[idx]
|
||||
require.Equal(t, wantID.appID, id, "Got IDs: %+v", ids)
|
||||
|
||||
uid, err := storeMailbox.getUID(wantID.appID)
|
||||
require.Nil(t, err)
|
||||
a.Equal(t, uint32(wantID.uid), uid)
|
||||
|
||||
seqNum, err := storeMailbox.getSequenceNumber(wantID.appID)
|
||||
require.Nil(t, err)
|
||||
a.Equal(t, uint32(idx+1), seqNum)
|
||||
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUIDByHeader(t *testing.T) { //nolint[funlen]
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
|
||||
tstMsg := getTestMessage("msg1", "Without external ID", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
||||
|
||||
tstMsg = getTestMessage("msg2", "External ID with spaces", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||
tstMsg.ExternalID = " externalID-non-pm-com "
|
||||
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
||||
|
||||
tstMsg = getTestMessage("msg3", "External ID with <>", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||
tstMsg.ExternalID = "<externalID@pm.me>"
|
||||
tstMsg.Header = mail.Header{"References": []string{"wrongID", "externalID-non-pm-com", "msg2"}}
|
||||
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
||||
|
||||
// Not sure if this is a real-world scenario but we should be able to address this properly.
|
||||
tstMsg = getTestMessage("msg4", "External ID with <> and spaces and special characters", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.SentLabel})
|
||||
tstMsg.ExternalID = " < external.()+*[]ID@another.pm.me > "
|
||||
require.Nil(t, m.store.createOrUpdateMessageEvent(tstMsg))
|
||||
|
||||
testDataUIDByHeader := []struct {
|
||||
header *mail.Header
|
||||
wantID uint32
|
||||
}{
|
||||
{
|
||||
&mail.Header{"Message-Id": []string{"wrongID"}},
|
||||
0,
|
||||
},
|
||||
{
|
||||
&mail.Header{"Message-Id": []string{"ext"}},
|
||||
0,
|
||||
},
|
||||
{
|
||||
&mail.Header{"Message-Id": []string{"externalID"}},
|
||||
0,
|
||||
},
|
||||
{
|
||||
&mail.Header{"Message-Id": []string{"msg1"}},
|
||||
1,
|
||||
},
|
||||
{
|
||||
&mail.Header{"Message-Id": []string{"<msg3@pm.me>"}},
|
||||
3,
|
||||
},
|
||||
{
|
||||
&mail.Header{"Message-Id": []string{"<externalID-non-pm-com>"}},
|
||||
2,
|
||||
},
|
||||
{
|
||||
&mail.Header{"Message-Id": []string{"externalID@pm.me"}},
|
||||
3,
|
||||
},
|
||||
{
|
||||
&mail.Header{"Message-Id": []string{"external.()+*[]ID@another.pm.me"}},
|
||||
4,
|
||||
},
|
||||
}
|
||||
|
||||
storeAddress := m.store.addresses[addrID1]
|
||||
storeMailbox := storeAddress.mailboxes[pmapi.SentLabel]
|
||||
|
||||
for _, td := range testDataUIDByHeader {
|
||||
haveID := storeMailbox.GetUIDByHeader(td.header)
|
||||
a.Equal(t, td.wantID, haveID, "testing header: %v", td.header)
|
||||
}
|
||||
}
|
||||
375
internal/store/mailbox_message.go
Normal file
375
internal/store/mailbox_message.go
Normal file
@ -0,0 +1,375 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
// GetMessage returns the `pmapi.Message` struct wrapped in `StoreMessage`
|
||||
// tied to this mailbox.
|
||||
func (storeMailbox *Mailbox) GetMessage(apiID string) (*Message, error) {
|
||||
msg, err := storeMailbox.store.getMessageFromDB(apiID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newStoreMessage(storeMailbox, msg), nil
|
||||
}
|
||||
|
||||
// FetchMessage fetches the message with the given `apiID`, stores it in the database, and returns a new store message
|
||||
// wrapping it.
|
||||
func (storeMailbox *Mailbox) FetchMessage(apiID string) (*Message, error) {
|
||||
msg, err := storeMailbox.store.fetchMessage(apiID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newStoreMessage(storeMailbox, msg), nil
|
||||
}
|
||||
|
||||
// ImportMessage imports the message by calling an API.
|
||||
// It has to be propagated to all mailboxes which is done by the event loop.
|
||||
func (storeMailbox *Mailbox) ImportMessage(msg *pmapi.Message, body []byte, labelIDs []string) error {
|
||||
defer storeMailbox.pollNow()
|
||||
|
||||
if storeMailbox.labelID != pmapi.AllMailLabel {
|
||||
labelIDs = append(labelIDs, storeMailbox.labelID)
|
||||
}
|
||||
|
||||
importReqs := &pmapi.ImportMsgReq{
|
||||
AddressID: msg.AddressID,
|
||||
Body: body,
|
||||
Unread: msg.Unread,
|
||||
Flags: msg.Flags,
|
||||
Time: msg.Time,
|
||||
LabelIDs: labelIDs,
|
||||
}
|
||||
|
||||
res, err := storeMailbox.api().Import([]*pmapi.ImportMsgReq{importReqs})
|
||||
if err == nil && len(res) > 0 {
|
||||
msg.ID = res[0].MessageID
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// LabelMessages adds the label by calling an API.
|
||||
// It has to be propagated to all the same messages in all mailboxes.
|
||||
// The propagation is processed by the event loop.
|
||||
func (storeMailbox *Mailbox) LabelMessages(apiIDs []string) error {
|
||||
log.WithFields(logrus.Fields{
|
||||
"messages": apiIDs,
|
||||
"label": storeMailbox.labelID,
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Labeling messages")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.api().LabelMessages(apiIDs, storeMailbox.labelID)
|
||||
}
|
||||
|
||||
// UnlabelMessages removes the label by calling an API.
|
||||
// It has to be propagated to all the same messages in all mailboxes.
|
||||
// The propagation is processed by the event loop.
|
||||
func (storeMailbox *Mailbox) UnlabelMessages(apiIDs []string) error {
|
||||
log.WithFields(logrus.Fields{
|
||||
"messages": apiIDs,
|
||||
"label": storeMailbox.labelID,
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Unlabeling messages")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.api().UnlabelMessages(apiIDs, storeMailbox.labelID)
|
||||
}
|
||||
|
||||
// MarkMessagesRead marks the message read by calling an API.
|
||||
// It has to be propagated to metadata mailbox which is done by the event loop.
|
||||
func (storeMailbox *Mailbox) MarkMessagesRead(apiIDs []string) error {
|
||||
log.WithFields(logrus.Fields{
|
||||
"messages": apiIDs,
|
||||
"label": storeMailbox.labelID,
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as read")
|
||||
defer storeMailbox.pollNow()
|
||||
|
||||
// Before deleting a message, TB sets \Seen flag which causes an event update
|
||||
// and thus a refresh of the message by deleting and creating it again.
|
||||
// TB does not notice this and happily continues with next command to move
|
||||
// the message to the Trash but the message does not exist anymore.
|
||||
// Therefore we do not issue API update if the message is already read.
|
||||
ids := []string{}
|
||||
for _, apiID := range apiIDs {
|
||||
if message, _ := storeMailbox.store.getMessageFromDB(apiID); message == nil || message.Unread == 1 {
|
||||
ids = append(ids, apiID)
|
||||
}
|
||||
}
|
||||
return storeMailbox.api().MarkMessagesRead(ids)
|
||||
}
|
||||
|
||||
// MarkMessagesUnread marks the message unread by calling an API.
|
||||
// It has to be propagated to metadata mailbox which is done by the event loop.
|
||||
func (storeMailbox *Mailbox) MarkMessagesUnread(apiIDs []string) error {
|
||||
log.WithFields(logrus.Fields{
|
||||
"messages": apiIDs,
|
||||
"label": storeMailbox.labelID,
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as unread")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.api().MarkMessagesUnread(apiIDs)
|
||||
}
|
||||
|
||||
// MarkMessagesStarred adds the Starred label by calling an API.
|
||||
// It has to be propagated to all the same messages in all mailboxes.
|
||||
// The propagation is processed by the event loop.
|
||||
func (storeMailbox *Mailbox) MarkMessagesStarred(apiIDs []string) error {
|
||||
log.WithFields(logrus.Fields{
|
||||
"messages": apiIDs,
|
||||
"label": storeMailbox.labelID,
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as starred")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.api().LabelMessages(apiIDs, pmapi.StarredLabel)
|
||||
}
|
||||
|
||||
// MarkMessagesUnstarred removes the Starred label by calling an API.
|
||||
// It has to be propagated to all the same messages in all mailboxes.
|
||||
// The propagation is processed by the event loop.
|
||||
func (storeMailbox *Mailbox) MarkMessagesUnstarred(apiIDs []string) error {
|
||||
log.WithFields(logrus.Fields{
|
||||
"messages": apiIDs,
|
||||
"label": storeMailbox.labelID,
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Marking messages as unstarred")
|
||||
defer storeMailbox.pollNow()
|
||||
return storeMailbox.api().UnlabelMessages(apiIDs, pmapi.StarredLabel)
|
||||
}
|
||||
|
||||
// DeleteMessages deletes messages.
|
||||
// If the mailbox is All Mail or All Sent, it does nothing.
|
||||
// If the mailbox is Trash or Spam and message is not in any other mailbox, messages is deleted.
|
||||
// In all other cases the message is only removed from the mailbox.
|
||||
func (storeMailbox *Mailbox) DeleteMessages(apiIDs []string) error {
|
||||
log.WithFields(logrus.Fields{
|
||||
"messages": apiIDs,
|
||||
"label": storeMailbox.labelID,
|
||||
"mailbox": storeMailbox.Name,
|
||||
}).Trace("Deleting messages")
|
||||
defer storeMailbox.pollNow()
|
||||
|
||||
switch storeMailbox.labelID {
|
||||
case pmapi.AllMailLabel, pmapi.AllSentLabel:
|
||||
break
|
||||
case pmapi.TrashLabel, pmapi.SpamLabel:
|
||||
messageIDsToDelete := []string{}
|
||||
messageIDsToUnlabel := []string{}
|
||||
for _, apiID := range apiIDs {
|
||||
msg, err := storeMailbox.store.getMessageFromDB(apiID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
otherLabels := false
|
||||
// If the message has any custom label, we don't want to delete it, only remove trash/spam label.
|
||||
for _, label := range msg.LabelIDs {
|
||||
if label != pmapi.SpamLabel && label != pmapi.TrashLabel && label != pmapi.AllMailLabel && label != pmapi.AllSentLabel && label != pmapi.DraftLabel && label != pmapi.AllDraftsLabel {
|
||||
otherLabels = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if otherLabels {
|
||||
messageIDsToUnlabel = append(messageIDsToUnlabel, apiID)
|
||||
} else {
|
||||
messageIDsToDelete = append(messageIDsToDelete, apiID)
|
||||
}
|
||||
}
|
||||
if len(messageIDsToUnlabel) > 0 {
|
||||
if err := storeMailbox.api().UnlabelMessages(messageIDsToUnlabel, storeMailbox.labelID); err != nil {
|
||||
log.WithError(err).Warning("Cannot unlabel before deleting")
|
||||
}
|
||||
}
|
||||
if len(messageIDsToDelete) > 0 {
|
||||
if err := storeMailbox.api().DeleteMessages(messageIDsToDelete); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
default:
|
||||
if err := storeMailbox.api().UnlabelMessages(apiIDs, storeMailbox.labelID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (storeMailbox *Mailbox) txSkipAndRemoveFromMailbox(tx *bolt.Tx, msg *pmapi.Message) (skipAndRemove bool) {
|
||||
defer func() {
|
||||
if skipAndRemove {
|
||||
if err := storeMailbox.txDeleteMessage(tx, msg.ID); err != nil {
|
||||
storeMailbox.log.WithError(err).Error("Cannot remove message")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
mode, err := storeMailbox.store.getAddressMode()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Could not determine address mode")
|
||||
return
|
||||
}
|
||||
|
||||
skipAndRemove = true
|
||||
|
||||
// If it's split mode and it shouldn't be under this address, it should be skipped and removed.
|
||||
if mode == splitMode && storeMailbox.storeAddress.addressID != msg.AddressID {
|
||||
return
|
||||
}
|
||||
|
||||
// If the message belongs in this mailbox, don't skip/remove it.
|
||||
for _, labelID := range msg.LabelIDs {
|
||||
if labelID == storeMailbox.labelID {
|
||||
skipAndRemove = false
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return skipAndRemove
|
||||
}
|
||||
|
||||
// txCreateOrUpdateMessages will delete, create or update message from mailbox.
|
||||
func (storeMailbox *Mailbox) txCreateOrUpdateMessages(tx *bolt.Tx, msgs []*pmapi.Message) error { //nolint[funlen]
|
||||
// Buckets are not initialized right away because it's a heavy operation.
|
||||
// The best option is to get the same bucket only once and only when needed.
|
||||
var apiBucket, imapBucket *bolt.Bucket
|
||||
for _, msg := range msgs {
|
||||
if storeMailbox.txSkipAndRemoveFromMailbox(tx, msg) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Update message.
|
||||
if apiBucket == nil {
|
||||
apiBucket = storeMailbox.txGetAPIIDsBucket(tx)
|
||||
}
|
||||
|
||||
// Draft bodies can change and bodies are not re-fetched by IMAP clients.
|
||||
// Every change has to be a new message; we need to delete the old one and always recreate it.
|
||||
if storeMailbox.labelID == pmapi.DraftLabel {
|
||||
if err := storeMailbox.txDeleteMessage(tx, msg.ID); err != nil {
|
||||
return errors.Wrap(err, "cannot delete old draft")
|
||||
}
|
||||
} else {
|
||||
uidb := apiBucket.Get([]byte(msg.ID))
|
||||
|
||||
if uidb != nil {
|
||||
if imapBucket == nil {
|
||||
imapBucket = storeMailbox.txGetIMAPIDsBucket(tx)
|
||||
}
|
||||
seqNum, seqErr := storeMailbox.txGetSequenceNumberOfUID(imapBucket, uidb)
|
||||
if seqErr == nil {
|
||||
storeMailbox.store.imapUpdateMessage(
|
||||
storeMailbox.storeAddress.address,
|
||||
storeMailbox.labelName,
|
||||
btoi(uidb),
|
||||
seqNum,
|
||||
msg,
|
||||
)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new message.
|
||||
if imapBucket == nil {
|
||||
imapBucket = storeMailbox.txGetIMAPIDsBucket(tx)
|
||||
}
|
||||
uid, err := storeMailbox.txGetNextUID(imapBucket, true)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot generate new UID")
|
||||
}
|
||||
uidb := itob(uid)
|
||||
|
||||
if err = imapBucket.Put(uidb, []byte(msg.ID)); err != nil {
|
||||
return errors.Wrap(err, "cannot add to IMAP bucket")
|
||||
}
|
||||
if err = apiBucket.Put([]byte(msg.ID), uidb); err != nil {
|
||||
return errors.Wrap(err, "cannot add to API bucket")
|
||||
}
|
||||
|
||||
seqNum, err := storeMailbox.txGetSequenceNumberOfUID(imapBucket, uidb)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot get sequence number from UID")
|
||||
}
|
||||
storeMailbox.store.imapUpdateMessage(
|
||||
storeMailbox.storeAddress.address,
|
||||
storeMailbox.labelName,
|
||||
uid,
|
||||
seqNum,
|
||||
msg,
|
||||
)
|
||||
}
|
||||
|
||||
return storeMailbox.txMailboxStatusUpdate(tx)
|
||||
}
|
||||
|
||||
// txDeleteMessage deletes the message from the mailbox bucket.
|
||||
// and issues message delete and mailbox update changes to updates channel.
|
||||
func (storeMailbox *Mailbox) txDeleteMessage(tx *bolt.Tx, apiID string) error {
|
||||
apiBucket := storeMailbox.txGetAPIIDsBucket(tx)
|
||||
apiIDb := []byte(apiID)
|
||||
uidb := apiBucket.Get(apiIDb)
|
||||
if uidb == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
imapBucket := storeMailbox.txGetIMAPIDsBucket(tx)
|
||||
|
||||
seqNum, seqNumErr := storeMailbox.txGetSequenceNumberOfUID(imapBucket, uidb)
|
||||
if seqNumErr != nil {
|
||||
storeMailbox.log.WithField("apiID", apiID).WithError(seqNumErr).Warn("Cannot get seqNum of deleting message")
|
||||
}
|
||||
|
||||
if err := imapBucket.Delete(uidb); err != nil {
|
||||
return errors.Wrap(err, "cannot delete from IMAP bucket")
|
||||
}
|
||||
|
||||
if err := apiBucket.Delete(apiIDb); err != nil {
|
||||
return errors.Wrap(err, "cannot delete from API bucket")
|
||||
}
|
||||
|
||||
if seqNumErr == nil {
|
||||
storeMailbox.store.imapDeleteMessage(
|
||||
storeMailbox.storeAddress.address,
|
||||
storeMailbox.labelName,
|
||||
seqNum,
|
||||
)
|
||||
if err := storeMailbox.txMailboxStatusUpdate(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (storeMailbox *Mailbox) txMailboxStatusUpdate(tx *bolt.Tx) error {
|
||||
total, unread, err := storeMailbox.txGetCounts(tx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot get counts for mailbox status update")
|
||||
}
|
||||
storeMailbox.store.imapMailboxStatus(
|
||||
storeMailbox.storeAddress.address,
|
||||
storeMailbox.labelName,
|
||||
total,
|
||||
unread,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
31
internal/store/main_test.go
Normal file
31
internal/store/main_test.go
Normal file
@ -0,0 +1,31 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func init() { //nolint[gochecknoinits]
|
||||
logrus.SetLevel(logrus.ErrorLevel)
|
||||
if os.Getenv("VERBOSITY") == "trace" {
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
}
|
||||
}
|
||||
108
internal/store/message.go
Normal file
108
internal/store/message.go
Normal file
@ -0,0 +1,108 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"net/mail"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
// Message is wrapper around `pmapi.Message` with connection to
|
||||
// a specific mailbox with helper functions to get IMAP UID, sequence
|
||||
// numbers and similar.
|
||||
type Message struct {
|
||||
api PMAPIProvider
|
||||
msg *pmapi.Message
|
||||
|
||||
store *Store
|
||||
storeMailbox *Mailbox
|
||||
}
|
||||
|
||||
func newStoreMessage(storeMailbox *Mailbox, msg *pmapi.Message) *Message {
|
||||
return &Message{
|
||||
api: storeMailbox.store.api,
|
||||
msg: msg,
|
||||
store: storeMailbox.store,
|
||||
storeMailbox: storeMailbox,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns message ID on our API (always the same ID for all mailboxes).
|
||||
func (message *Message) ID() string {
|
||||
return message.msg.ID
|
||||
}
|
||||
|
||||
// UID returns message UID for IMAP, specific for mailbox used to get the message.
|
||||
func (message *Message) UID() (uint32, error) {
|
||||
return message.storeMailbox.getUID(message.ID())
|
||||
}
|
||||
|
||||
// SequenceNumber returns index of message in used mailbox.
|
||||
func (message *Message) SequenceNumber() (uint32, error) {
|
||||
return message.storeMailbox.getSequenceNumber(message.ID())
|
||||
}
|
||||
|
||||
// Message returns message struct from pmapi.
|
||||
func (message *Message) Message() *pmapi.Message {
|
||||
return message.msg
|
||||
}
|
||||
|
||||
// SetSize updates the information about size of decrypted message which can be
|
||||
// used for IMAP. This should not trigger any IMAP update.
|
||||
// NOTE: The size from the server corresponds to pure body bytes. Hence it
|
||||
// should not be used. The correct size has to be calculated from decrypted and
|
||||
// built message.
|
||||
func (message *Message) SetSize(size int64) error {
|
||||
message.msg.Size = size
|
||||
txUpdate := func(tx *bolt.Tx) error {
|
||||
stored, err := message.store.txGetMessage(tx, message.msg.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stored.Size = size
|
||||
return message.store.txPutMessage(
|
||||
tx.Bucket(metadataBucket),
|
||||
stored,
|
||||
)
|
||||
}
|
||||
return message.store.db.Update(txUpdate)
|
||||
}
|
||||
|
||||
// SetContentTypeAndHeader updates the information about content type and
|
||||
// header of decrypted message. This should not trigger any IMAP update.
|
||||
// NOTE: Content type depends on details of decrypted message which we want to
|
||||
// cache.
|
||||
func (message *Message) SetContentTypeAndHeader(mimeType string, header mail.Header) error {
|
||||
message.msg.MIMEType = mimeType
|
||||
message.msg.Header = header
|
||||
txUpdate := func(tx *bolt.Tx) error {
|
||||
stored, err := message.store.txGetMessage(tx, message.msg.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stored.MIMEType = mimeType
|
||||
stored.Header = header
|
||||
return message.store.txPutMessage(
|
||||
tx.Bucket(metadataBucket),
|
||||
stored,
|
||||
)
|
||||
}
|
||||
return message.store.db.Update(txUpdate)
|
||||
}
|
||||
193
internal/store/mocks/mocks.go
Normal file
193
internal/store/mocks/mocks.go
Normal file
@ -0,0 +1,193 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/internal/store (interfaces: PanicHandler,BridgeUser)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockPanicHandler is a mock of PanicHandler interface
|
||||
type MockPanicHandler struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockPanicHandlerMockRecorder
|
||||
}
|
||||
|
||||
// MockPanicHandlerMockRecorder is the mock recorder for MockPanicHandler
|
||||
type MockPanicHandlerMockRecorder struct {
|
||||
mock *MockPanicHandler
|
||||
}
|
||||
|
||||
// NewMockPanicHandler creates a new mock instance
|
||||
func NewMockPanicHandler(ctrl *gomock.Controller) *MockPanicHandler {
|
||||
mock := &MockPanicHandler{ctrl: ctrl}
|
||||
mock.recorder = &MockPanicHandlerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockPanicHandler) EXPECT() *MockPanicHandlerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// HandlePanic mocks base method
|
||||
func (m *MockPanicHandler) HandlePanic() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "HandlePanic")
|
||||
}
|
||||
|
||||
// HandlePanic indicates an expected call of HandlePanic
|
||||
func (mr *MockPanicHandlerMockRecorder) HandlePanic() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlePanic", reflect.TypeOf((*MockPanicHandler)(nil).HandlePanic))
|
||||
}
|
||||
|
||||
// MockBridgeUser is a mock of BridgeUser interface
|
||||
type MockBridgeUser struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockBridgeUserMockRecorder
|
||||
}
|
||||
|
||||
// MockBridgeUserMockRecorder is the mock recorder for MockBridgeUser
|
||||
type MockBridgeUserMockRecorder struct {
|
||||
mock *MockBridgeUser
|
||||
}
|
||||
|
||||
// NewMockBridgeUser creates a new mock instance
|
||||
func NewMockBridgeUser(ctrl *gomock.Controller) *MockBridgeUser {
|
||||
mock := &MockBridgeUser{ctrl: ctrl}
|
||||
mock.recorder = &MockBridgeUserMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockBridgeUser) EXPECT() *MockBridgeUserMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CloseConnection mocks base method
|
||||
func (m *MockBridgeUser) CloseConnection(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "CloseConnection", arg0)
|
||||
}
|
||||
|
||||
// CloseConnection indicates an expected call of CloseConnection
|
||||
func (mr *MockBridgeUserMockRecorder) CloseConnection(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseConnection", reflect.TypeOf((*MockBridgeUser)(nil).CloseConnection), arg0)
|
||||
}
|
||||
|
||||
// GetAddressID mocks base method
|
||||
func (m *MockBridgeUser) GetAddressID(arg0 string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAddressID", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAddressID indicates an expected call of GetAddressID
|
||||
func (mr *MockBridgeUserMockRecorder) GetAddressID(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddressID", reflect.TypeOf((*MockBridgeUser)(nil).GetAddressID), arg0)
|
||||
}
|
||||
|
||||
// GetPrimaryAddress mocks base method
|
||||
func (m *MockBridgeUser) GetPrimaryAddress() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPrimaryAddress")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetPrimaryAddress indicates an expected call of GetPrimaryAddress
|
||||
func (mr *MockBridgeUserMockRecorder) GetPrimaryAddress() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrimaryAddress", reflect.TypeOf((*MockBridgeUser)(nil).GetPrimaryAddress))
|
||||
}
|
||||
|
||||
// GetStoreAddresses mocks base method
|
||||
func (m *MockBridgeUser) GetStoreAddresses() []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetStoreAddresses")
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetStoreAddresses indicates an expected call of GetStoreAddresses
|
||||
func (mr *MockBridgeUserMockRecorder) GetStoreAddresses() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStoreAddresses", reflect.TypeOf((*MockBridgeUser)(nil).GetStoreAddresses))
|
||||
}
|
||||
|
||||
// ID mocks base method
|
||||
func (m *MockBridgeUser) ID() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ID")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ID indicates an expected call of ID
|
||||
func (mr *MockBridgeUserMockRecorder) ID() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockBridgeUser)(nil).ID))
|
||||
}
|
||||
|
||||
// IsCombinedAddressMode mocks base method
|
||||
func (m *MockBridgeUser) IsCombinedAddressMode() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsCombinedAddressMode")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IsCombinedAddressMode indicates an expected call of IsCombinedAddressMode
|
||||
func (mr *MockBridgeUserMockRecorder) IsCombinedAddressMode() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsCombinedAddressMode", reflect.TypeOf((*MockBridgeUser)(nil).IsCombinedAddressMode))
|
||||
}
|
||||
|
||||
// IsConnected mocks base method
|
||||
func (m *MockBridgeUser) IsConnected() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsConnected")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IsConnected indicates an expected call of IsConnected
|
||||
func (mr *MockBridgeUserMockRecorder) IsConnected() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockBridgeUser)(nil).IsConnected))
|
||||
}
|
||||
|
||||
// Logout mocks base method
|
||||
func (m *MockBridgeUser) Logout() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Logout")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Logout indicates an expected call of Logout
|
||||
func (mr *MockBridgeUserMockRecorder) Logout() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logout", reflect.TypeOf((*MockBridgeUser)(nil).Logout))
|
||||
}
|
||||
|
||||
// UpdateUser mocks base method
|
||||
func (m *MockBridgeUser) UpdateUser() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateUser")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateUser indicates an expected call of UpdateUser
|
||||
func (mr *MockBridgeUserMockRecorder) UpdateUser() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockBridgeUser)(nil).UpdateUser))
|
||||
}
|
||||
106
internal/store/mocks/utils_mocks.go
Normal file
106
internal/store/mocks/utils_mocks.go
Normal file
@ -0,0 +1,106 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/ProtonMail/proton-bridge/pkg/listener (interfaces: Listener)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
)
|
||||
|
||||
// MockListener is a mock of Listener interface
|
||||
type MockListener struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockListenerMockRecorder
|
||||
}
|
||||
|
||||
// MockListenerMockRecorder is the mock recorder for MockListener
|
||||
type MockListenerMockRecorder struct {
|
||||
mock *MockListener
|
||||
}
|
||||
|
||||
// NewMockListener creates a new mock instance
|
||||
func NewMockListener(ctrl *gomock.Controller) *MockListener {
|
||||
mock := &MockListener{ctrl: ctrl}
|
||||
mock.recorder = &MockListenerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Add mocks base method
|
||||
func (m *MockListener) Add(arg0 string, arg1 chan<- string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Add", arg0, arg1)
|
||||
}
|
||||
|
||||
// Add indicates an expected call of Add
|
||||
func (mr *MockListenerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockListener)(nil).Add), arg0, arg1)
|
||||
}
|
||||
|
||||
// Emit mocks base method
|
||||
func (m *MockListener) Emit(arg0, arg1 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Emit", arg0, arg1)
|
||||
}
|
||||
|
||||
// Emit indicates an expected call of Emit
|
||||
func (mr *MockListenerMockRecorder) Emit(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Emit", reflect.TypeOf((*MockListener)(nil).Emit), arg0, arg1)
|
||||
}
|
||||
|
||||
// Remove mocks base method
|
||||
func (m *MockListener) Remove(arg0 string, arg1 chan<- string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Remove", arg0, arg1)
|
||||
}
|
||||
|
||||
// Remove indicates an expected call of Remove
|
||||
func (mr *MockListenerMockRecorder) Remove(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockListener)(nil).Remove), arg0, arg1)
|
||||
}
|
||||
|
||||
// RetryEmit mocks base method
|
||||
func (m *MockListener) RetryEmit(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "RetryEmit", arg0)
|
||||
}
|
||||
|
||||
// RetryEmit indicates an expected call of RetryEmit
|
||||
func (mr *MockListenerMockRecorder) RetryEmit(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetryEmit", reflect.TypeOf((*MockListener)(nil).RetryEmit), arg0)
|
||||
}
|
||||
|
||||
// SetBuffer mocks base method
|
||||
func (m *MockListener) SetBuffer(arg0 string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetBuffer", arg0)
|
||||
}
|
||||
|
||||
// SetBuffer indicates an expected call of SetBuffer
|
||||
func (mr *MockListenerMockRecorder) SetBuffer(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBuffer", reflect.TypeOf((*MockListener)(nil).SetBuffer), arg0)
|
||||
}
|
||||
|
||||
// SetLimit mocks base method
|
||||
func (m *MockListener) SetLimit(arg0 string, arg1 time.Duration) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetLimit", arg0, arg1)
|
||||
}
|
||||
|
||||
// SetLimit indicates an expected call of SetLimit
|
||||
func (mr *MockListenerMockRecorder) SetLimit(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockListener)(nil).SetLimit), arg0, arg1)
|
||||
}
|
||||
396
internal/store/store.go
Normal file
396
internal/store/store.go
Normal file
@ -0,0 +1,396 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
// Package store communicates with API and caches metadata in a local database.
|
||||
package store
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/listener"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
const (
|
||||
// PathDelimiter for IMAP
|
||||
PathDelimiter = "/"
|
||||
// UserLabelsMailboxName for IMAP
|
||||
UserLabelsMailboxName = "Labels"
|
||||
// UserLabelsPrefix contains name with delimiter for IMAP
|
||||
UserLabelsPrefix = UserLabelsMailboxName + PathDelimiter
|
||||
// UserFoldersMailboxName for IMAP
|
||||
UserFoldersMailboxName = "Folders"
|
||||
// UserFoldersPrefix contains name with delimiter for IMAP
|
||||
UserFoldersPrefix = UserFoldersMailboxName + PathDelimiter
|
||||
)
|
||||
|
||||
var (
|
||||
log = logrus.WithField("pkg", "store") //nolint[gochecknoglobals]
|
||||
|
||||
// Database structure:
|
||||
// * metadata
|
||||
// * {messageID} -> message data (subject, from, to, time, headers, body size, ...)
|
||||
// * counts
|
||||
// * {mailboxID} -> mailboxCounts: totalOnAPI, unreadOnAPI, labelName, labelColor, labelIsExclusive
|
||||
// * address_info
|
||||
// * {index} -> {address, addressID}
|
||||
// * address_mode
|
||||
// * mode -> string split or combined
|
||||
// * mailboxes_version
|
||||
// * version -> uint32 value
|
||||
// * sync_state
|
||||
// * sync_state -> string timestamp when it was last synced (when missing, sync should be ongoing)
|
||||
// * ids_ranges -> json array of groups with start and end message ID (when missing, there is no ongoing sync)
|
||||
// * ids_to_be_deleted -> json array of message IDs to be deleted after sync (when missing, there is no ongoing sync)
|
||||
// * mailboxes
|
||||
// * {addressID+mailboxID}
|
||||
// * imap_ids
|
||||
// * {imapUID} -> string messageID
|
||||
// * api_ids
|
||||
// * {messageID} -> uint32 imapUID
|
||||
metadataBucket = []byte("metadata") //nolint[gochecknoglobals]
|
||||
countsBucket = []byte("counts") //nolint[gochecknoglobals]
|
||||
addressInfoBucket = []byte("address_info") //nolint[gochecknoglobals]
|
||||
addressModeBucket = []byte("address_mode") //nolint[gochecknoglobals]
|
||||
syncStateBucket = []byte("sync_state") //nolint[gochecknoglobals]
|
||||
mailboxesBucket = []byte("mailboxes") //nolint[gochecknoglobals]
|
||||
imapIDsBucket = []byte("imap_ids") //nolint[gochecknoglobals]
|
||||
apiIDsBucket = []byte("api_ids") //nolint[gochecknoglobals]
|
||||
mboxVersionBucket = []byte("mailboxes_version") //nolint[gochecknoglobals]
|
||||
|
||||
// ErrNoSuchAPIID when mailbox does not have API ID.
|
||||
ErrNoSuchAPIID = errors.New("no such api id") //nolint[gochecknoglobals]
|
||||
// ErrNoSuchUID when mailbox does not have IMAP UID.
|
||||
ErrNoSuchUID = errors.New("no such uid") //nolint[gochecknoglobals]
|
||||
// ErrNoSuchSeqNum when mailbox does not have IMAP ID.
|
||||
ErrNoSuchSeqNum = errors.New("no such sequence number") //nolint[gochecknoglobals]
|
||||
)
|
||||
|
||||
// Store is local user storage, which handles the synchronization between IMAP and PM API.
|
||||
type Store struct {
|
||||
panicHandler PanicHandler
|
||||
eventLoop *eventLoop
|
||||
user BridgeUser
|
||||
api PMAPIProvider
|
||||
|
||||
log *logrus.Entry
|
||||
|
||||
cache *Cache
|
||||
filePath string
|
||||
db *bolt.DB
|
||||
lock *sync.RWMutex
|
||||
addresses map[string]*Address
|
||||
imapUpdates chan interface{}
|
||||
|
||||
isSyncRunning bool
|
||||
addressMode addressMode
|
||||
}
|
||||
|
||||
// New creates or opens a store for the given `user`.
|
||||
func New(
|
||||
panicHandler PanicHandler,
|
||||
user BridgeUser,
|
||||
api PMAPIProvider,
|
||||
events listener.Listener,
|
||||
path string,
|
||||
cache *Cache,
|
||||
) (store *Store, err error) {
|
||||
if user == nil || api == nil || events == nil || cache == nil {
|
||||
return nil, fmt.Errorf("missing parameters - user: %v, api: %v, events: %v, cache: %v", user, api, events, cache)
|
||||
}
|
||||
|
||||
l := log.WithField("user", user.ID())
|
||||
|
||||
var firstInit bool
|
||||
if _, existErr := os.Stat(path); os.IsNotExist(existErr) {
|
||||
l.Info("Creating new store database file with address mode from user's credentials store")
|
||||
firstInit = true
|
||||
} else {
|
||||
l.Info("Store database file already exists, using mode already set")
|
||||
firstInit = false
|
||||
}
|
||||
|
||||
bdb, err := openBoltDatabase(path)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, "failed to open store database")
|
||||
return
|
||||
}
|
||||
|
||||
store = &Store{
|
||||
panicHandler: panicHandler,
|
||||
api: api,
|
||||
user: user,
|
||||
cache: cache,
|
||||
filePath: path,
|
||||
db: bdb,
|
||||
lock: &sync.RWMutex{},
|
||||
log: l,
|
||||
}
|
||||
|
||||
if err = store.init(firstInit); err != nil {
|
||||
l.WithError(err).Error("Could not initialise store, attempting to close")
|
||||
if storeCloseErr := store.Close(); storeCloseErr != nil {
|
||||
l.WithError(storeCloseErr).Warn("Could not close uninitialised store")
|
||||
}
|
||||
err = errors.Wrap(err, "failed to initialise store")
|
||||
return
|
||||
}
|
||||
|
||||
if user.IsConnected() {
|
||||
store.eventLoop = newEventLoop(cache, store, api, user, events)
|
||||
go func() {
|
||||
defer store.panicHandler.HandlePanic()
|
||||
store.eventLoop.start()
|
||||
}()
|
||||
}
|
||||
|
||||
return store, err
|
||||
}
|
||||
|
||||
func openBoltDatabase(filePath string) (db *bolt.DB, err error) {
|
||||
l := log.WithField("path", filePath)
|
||||
l.Debug("Opening bolt database")
|
||||
|
||||
if db, err = bolt.Open(filePath, 0600, &bolt.Options{Timeout: 1 * time.Second}); err != nil {
|
||||
l.WithError(err).Error("Could not open bolt database")
|
||||
return
|
||||
}
|
||||
|
||||
if val, set := os.LookupEnv("BRIDGESTRICTMODE"); set && val == "1" {
|
||||
db.StrictMode = true
|
||||
}
|
||||
|
||||
tx := func(tx *bolt.Tx) (err error) {
|
||||
if _, err = tx.CreateBucketIfNotExists(metadataBucket); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = tx.CreateBucketIfNotExists(countsBucket); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = tx.CreateBucketIfNotExists(addressInfoBucket); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = tx.CreateBucketIfNotExists(addressModeBucket); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = tx.CreateBucketIfNotExists(syncStateBucket); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = tx.CreateBucketIfNotExists(mailboxesBucket); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = tx.CreateBucketIfNotExists(mboxVersionBucket); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err = db.Update(tx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return db, err
|
||||
}
|
||||
|
||||
// init initialises the store for the given addresses.
|
||||
func (store *Store) init(firstInit bool) (err error) {
|
||||
if store.addresses != nil {
|
||||
store.log.Warn("Store was already initialised")
|
||||
return
|
||||
}
|
||||
|
||||
// If it's the first time we are creating the store, use the mode set in the
|
||||
// user's credentials, otherwise read it from the DB (if present).
|
||||
if firstInit {
|
||||
if store.user.IsCombinedAddressMode() {
|
||||
err = store.setAddressMode(combinedMode)
|
||||
} else {
|
||||
err = store.setAddressMode(splitMode)
|
||||
}
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "first init setting store address mode")
|
||||
}
|
||||
} else if store.addressMode, err = store.getAddressMode(); err != nil {
|
||||
store.log.WithError(err).Error("Store address mode is unknown, setting to combined mode")
|
||||
if err = store.setAddressMode(combinedMode); err != nil {
|
||||
return errors.Wrap(err, "setting store address mode")
|
||||
}
|
||||
}
|
||||
|
||||
store.log.WithField("mode", store.addressMode).Debug("Initialising store")
|
||||
|
||||
labels, err := store.initCounts()
|
||||
if err != nil {
|
||||
store.log.WithError(err).Error("Could not initialise label counts")
|
||||
return
|
||||
}
|
||||
|
||||
if err = store.initAddresses(labels); err != nil {
|
||||
store.log.WithError(err).Error("Could not initialise store addresses")
|
||||
return
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// initCounts initialises the counts for each label. It tries to use the API first to fetch the labels but if
|
||||
// the API is unavailable for whatever reason it tries to fetch the labels locally.
|
||||
func (store *Store) initCounts() (labels []*pmapi.Label, err error) {
|
||||
if labels, err = store.api.ListLabels(); err != nil {
|
||||
store.log.WithError(err).Warn("Could not list API labels. Trying with local labels.")
|
||||
if labels, err = store.getLabelsFromLocalStorage(); err != nil {
|
||||
store.log.WithError(err).Error("Cannot list local labels")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// the labels listed by PMAPI don't include system folders so we need to add them.
|
||||
for _, counts := range getSystemFolders() {
|
||||
labels = append(labels, counts.getPMLabel())
|
||||
}
|
||||
|
||||
if err = store.createOrUpdateMailboxCountsBuckets(labels); err != nil {
|
||||
store.log.WithError(err).Error("Cannot create counts")
|
||||
return
|
||||
}
|
||||
|
||||
if countsErr := store.updateCountsFromServer(); countsErr != nil {
|
||||
store.log.WithError(countsErr).Warning("Continue without new counts from server")
|
||||
}
|
||||
}
|
||||
|
||||
sortByOrder(labels)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// initAddresses creates address objects in the store for each necessary address.
|
||||
// In combined mode this means just one mailbox for all addresses but in split mode this means one mailbox per address.
|
||||
func (store *Store) initAddresses(labels []*pmapi.Label) (err error) {
|
||||
store.addresses = make(map[string]*Address)
|
||||
|
||||
addrInfo, err := store.GetAddressInfo()
|
||||
if err != nil {
|
||||
store.log.WithError(err).Error("Could not get addresses and address IDs")
|
||||
return
|
||||
}
|
||||
|
||||
// We need at least one address to continue.
|
||||
if len(addrInfo) < 1 {
|
||||
err = errors.New("no addresses to initialise")
|
||||
store.log.WithError(err).Warn("There are no addresses to initialise")
|
||||
return
|
||||
}
|
||||
|
||||
// If in combined mode, we only need the user's primary address.
|
||||
if store.addressMode == combinedMode {
|
||||
addrInfo = addrInfo[:1]
|
||||
}
|
||||
|
||||
for _, addr := range addrInfo {
|
||||
if err = store.addAddress(addr.Address, addr.AddressID, labels); err != nil {
|
||||
store.log.WithField("address", addr.Address).WithError(err).Error("Could not add address to store")
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// addAddress adds a new address to the store. If the address exists already it is overwritten.
|
||||
func (store *Store) addAddress(address, addressID string, labels []*pmapi.Label) (err error) {
|
||||
if _, ok := store.addresses[addressID]; ok {
|
||||
store.log.WithField("addressID", addressID).Debug("Overwriting store address")
|
||||
}
|
||||
|
||||
addr, err := newAddress(store, address, addressID, labels)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create store address object")
|
||||
}
|
||||
|
||||
store.addresses[addressID] = addr
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Close stops the event loop and closes the database to free the file.
|
||||
func (store *Store) Close() error {
|
||||
store.lock.Lock()
|
||||
defer store.lock.Unlock()
|
||||
|
||||
return store.close()
|
||||
}
|
||||
|
||||
// CloseEventLoop stops the eventloop (if it is present).
|
||||
func (store *Store) CloseEventLoop() {
|
||||
if store.eventLoop != nil {
|
||||
store.eventLoop.stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (store *Store) close() error {
|
||||
store.CloseEventLoop()
|
||||
return store.db.Close()
|
||||
}
|
||||
|
||||
// Remove closes and removes the database file and clears the cache file.
|
||||
func (store *Store) Remove() (err error) {
|
||||
store.lock.Lock()
|
||||
defer store.lock.Unlock()
|
||||
|
||||
store.log.Trace("Removing store")
|
||||
|
||||
var result *multierror.Error
|
||||
|
||||
if err = store.close(); err != nil {
|
||||
result = multierror.Append(result, errors.Wrap(err, "failed to close store"))
|
||||
}
|
||||
|
||||
if err = RemoveStore(store.cache, store.filePath, store.user.ID()); err != nil {
|
||||
result = multierror.Append(result, errors.Wrap(err, "failed to remove store"))
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// RemoveStore removes the database file and clears the cache file.
|
||||
func RemoveStore(cache *Cache, path, userID string) error {
|
||||
var result *multierror.Error
|
||||
|
||||
if err := cache.clearCacheUser(userID); err != nil {
|
||||
result = multierror.Append(result, errors.Wrap(err, "failed to clear event loop user cache"))
|
||||
}
|
||||
|
||||
// RemoveAll will not return an error if the path does not exist.
|
||||
if err := os.RemoveAll(path); err != nil {
|
||||
result = multierror.Append(result, errors.Wrap(err, "failed to remove database file"))
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
112
internal/store/store_address_mode.go
Normal file
112
internal/store/store_address_mode.go
Normal file
@ -0,0 +1,112 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
type addressMode string
|
||||
|
||||
const (
|
||||
splitMode addressMode = "split"
|
||||
combinedMode addressMode = "combined"
|
||||
modeKey = "mode"
|
||||
)
|
||||
|
||||
// getAddressMode returns the current address mode (split or combined) of the store.
|
||||
// It first looks in the local cache but if that is not yet set, it loads it from the database.
|
||||
func (store *Store) getAddressMode() (mode addressMode, err error) {
|
||||
if store.addressMode != "" {
|
||||
mode = store.addressMode
|
||||
return
|
||||
}
|
||||
|
||||
tx := func(tx *bolt.Tx) (err error) {
|
||||
b := tx.Bucket(addressModeBucket)
|
||||
|
||||
dbMode := b.Get([]byte(modeKey))
|
||||
if dbMode == nil {
|
||||
return errors.New("address mode not set")
|
||||
}
|
||||
|
||||
mode = addressMode(dbMode)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = store.db.View(tx)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// IsCombinedMode returns whether the store is set to combined mode.
|
||||
func (store *Store) IsCombinedMode() bool {
|
||||
return store.addressMode == combinedMode
|
||||
}
|
||||
|
||||
// UseCombinedMode sets whether the store should be set to combined mode.
|
||||
func (store *Store) UseCombinedMode(useCombined bool) (err error) {
|
||||
if useCombined {
|
||||
err = store.switchAddressMode(combinedMode)
|
||||
} else {
|
||||
err = store.switchAddressMode(splitMode)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// switchAddressMode sets the address mode to the given value and rebuilds the mailboxes.
|
||||
func (store *Store) switchAddressMode(mode addressMode) (err error) {
|
||||
if store.addressMode == mode {
|
||||
log.Debug("The store is using the correct address mode")
|
||||
return
|
||||
}
|
||||
|
||||
if err = store.setAddressMode(mode); err != nil {
|
||||
log.WithError(err).Error("Could not set store address mode")
|
||||
return
|
||||
}
|
||||
|
||||
if err = store.RebuildMailboxes(); err != nil {
|
||||
log.WithError(err).Error("Could not rebuild mailboxes after switching address mode")
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// setAddressMode sets the current address mode (split or combined) of the store.
|
||||
// It writes to database and updates the local value in the store object.
|
||||
func (store *Store) setAddressMode(mode addressMode) (err error) {
|
||||
store.log.WithField("mode", string(mode)).Info("Setting store address mode")
|
||||
|
||||
tx := func(tx *bolt.Tx) (err error) {
|
||||
b := tx.Bucket(addressModeBucket)
|
||||
return b.Put([]byte(modeKey), []byte(mode))
|
||||
}
|
||||
|
||||
if err = store.db.Update(tx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
store.addressMode = mode
|
||||
|
||||
return
|
||||
}
|
||||
68
internal/store/store_structure_version.go
Normal file
68
internal/store/store_structure_version.go
Normal file
@ -0,0 +1,68 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import bolt "go.etcd.io/bbolt"
|
||||
|
||||
const (
|
||||
versionKey = "version"
|
||||
|
||||
// versionOffset makes it possible to force email client to reload all
|
||||
// mailboxes. If increased during application update it will trigger
|
||||
// the reload on client side without needing to sync DB or re-setup account.
|
||||
versionOffset = uint32(3)
|
||||
)
|
||||
|
||||
func (store *Store) getMailboxesVersion() uint32 {
|
||||
localVersion := store.readMailboxesVersion()
|
||||
// If a read error occurs it returns 0 which is an invalid version value.
|
||||
if localVersion == 0 {
|
||||
localVersion = 1
|
||||
_ = store.writeMailboxesVersion(localVersion)
|
||||
}
|
||||
|
||||
// versionOffset will make email clients reload if increased during bridge update.
|
||||
return localVersion + versionOffset
|
||||
}
|
||||
|
||||
func (store *Store) increaseMailboxesVersion() error {
|
||||
ver := store.readMailboxesVersion()
|
||||
// The version is zero if a read error occurred. Operation ++ will make it 1
|
||||
// which is default starting value.
|
||||
ver++
|
||||
return store.writeMailboxesVersion(ver)
|
||||
}
|
||||
|
||||
func (store *Store) readMailboxesVersion() (version uint32) {
|
||||
_ = store.db.View(func(tx *bolt.Tx) (err error) {
|
||||
b := tx.Bucket(mboxVersionBucket)
|
||||
verRaw := b.Get([]byte(versionKey))
|
||||
if verRaw != nil {
|
||||
version = btoi(verRaw)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (store *Store) writeMailboxesVersion(ver uint32) error {
|
||||
return store.db.Update(func(tx *bolt.Tx) (err error) {
|
||||
b := tx.Bucket(mboxVersionBucket)
|
||||
return b.Put([]byte(versionKey), itob(ver))
|
||||
})
|
||||
}
|
||||
129
internal/store/store_test.go
Normal file
129
internal/store/store_test.go
Normal file
@ -0,0 +1,129 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
bridgemocks "github.com/ProtonMail/proton-bridge/internal/bridge/mocks"
|
||||
storeMocks "github.com/ProtonMail/proton-bridge/internal/store/mocks"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
addr1 = "niceaddress@pm.me"
|
||||
addrID1 = "niceaddressID"
|
||||
|
||||
addr2 = "jamesandmichalarecool@pm.me"
|
||||
addrID2 = "jamesandmichalarecool"
|
||||
)
|
||||
|
||||
type mocksForStore struct {
|
||||
tb testing.TB
|
||||
|
||||
ctrl *gomock.Controller
|
||||
events *storeMocks.MockListener
|
||||
api *bridgemocks.MockPMAPIProvider
|
||||
user *storeMocks.MockBridgeUser
|
||||
panicHandler *storeMocks.MockPanicHandler
|
||||
store *Store
|
||||
|
||||
tmpDir string
|
||||
cache *Cache
|
||||
}
|
||||
|
||||
func initMocks(tb testing.TB) (*mocksForStore, func()) {
|
||||
ctrl := gomock.NewController(tb)
|
||||
mocks := &mocksForStore{
|
||||
tb: tb,
|
||||
ctrl: ctrl,
|
||||
events: storeMocks.NewMockListener(ctrl),
|
||||
api: bridgemocks.NewMockPMAPIProvider(ctrl),
|
||||
user: storeMocks.NewMockBridgeUser(ctrl),
|
||||
panicHandler: storeMocks.NewMockPanicHandler(ctrl),
|
||||
}
|
||||
|
||||
// Called during clean-up.
|
||||
mocks.panicHandler.EXPECT().HandlePanic().AnyTimes()
|
||||
|
||||
var err error
|
||||
mocks.tmpDir, err = ioutil.TempDir("", "store-test")
|
||||
require.NoError(tb, err)
|
||||
|
||||
cacheFile := filepath.Join(mocks.tmpDir, "cache.json")
|
||||
mocks.cache = NewCache(cacheFile)
|
||||
|
||||
return mocks, func() {
|
||||
if err := recover(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if mocks.store != nil {
|
||||
require.Nil(tb, mocks.store.Close())
|
||||
}
|
||||
ctrl.Finish()
|
||||
require.NoError(tb, os.RemoveAll(mocks.tmpDir))
|
||||
}
|
||||
}
|
||||
|
||||
func (mocks *mocksForStore) newStoreNoEvents(combinedMode bool) { //nolint[unparam]
|
||||
mocks.user.EXPECT().ID().Return("userID").AnyTimes()
|
||||
mocks.user.EXPECT().IsConnected().Return(true)
|
||||
mocks.user.EXPECT().IsCombinedAddressMode().Return(combinedMode)
|
||||
|
||||
mocks.api.EXPECT().Addresses().Return(pmapi.AddressList{
|
||||
{ID: addrID1, Email: addr1, Type: pmapi.OriginalAddress, Receive: pmapi.CanReceive},
|
||||
{ID: addrID2, Email: addr2, Type: pmapi.AliasAddress, Receive: pmapi.CanReceive},
|
||||
})
|
||||
mocks.api.EXPECT().ListLabels()
|
||||
mocks.api.EXPECT().CountMessages("")
|
||||
mocks.api.EXPECT().GetEvent(gomock.Any()).
|
||||
Return(&pmapi.Event{
|
||||
EventID: "latestEventID",
|
||||
}, nil).AnyTimes()
|
||||
|
||||
// We want to wait until first sync has finished.
|
||||
firstSyncWaiter := sync.WaitGroup{}
|
||||
firstSyncWaiter.Add(1)
|
||||
mocks.api.EXPECT().
|
||||
ListMessages(gomock.Any()).
|
||||
DoAndReturn(func(*pmapi.MessagesFilter) ([]*pmapi.Message, int, error) {
|
||||
firstSyncWaiter.Done()
|
||||
return []*pmapi.Message{}, 0, nil
|
||||
})
|
||||
|
||||
var err error
|
||||
mocks.store, err = New(
|
||||
mocks.panicHandler,
|
||||
mocks.user,
|
||||
mocks.api,
|
||||
mocks.events,
|
||||
filepath.Join(mocks.tmpDir, "mailbox-test.db"),
|
||||
mocks.cache,
|
||||
)
|
||||
require.NoError(mocks.tb, err)
|
||||
|
||||
// Wait for sync to finish.
|
||||
firstSyncWaiter.Wait()
|
||||
}
|
||||
145
internal/store/store_test_exports.go
Normal file
145
internal/store/store_test_exports.go
Normal file
@ -0,0 +1,145 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/stretchr/testify/assert"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
// TestSync triggers a sync of the store.
|
||||
func (store *Store) TestSync() {
|
||||
store.triggerSync()
|
||||
}
|
||||
|
||||
// TestPollNow triggers a loop of the event loop.
|
||||
func (store *Store) TestPollNow() {
|
||||
store.eventLoop.pollNow()
|
||||
}
|
||||
|
||||
// TestIsSyncRunning returns whether the sync is currently ongoing.
|
||||
func (store *Store) TestIsSyncRunning() bool {
|
||||
return store.isSyncRunning
|
||||
}
|
||||
|
||||
// TestGetEventLoop returns the store's event loop.
|
||||
func (store *Store) TestGetEventLoop() *eventLoop { //nolint[golint]
|
||||
return store.eventLoop
|
||||
}
|
||||
|
||||
// TestGetStoreFilePath returns the filepath of the store's database file.
|
||||
func (store *Store) TestGetStoreFilePath() string {
|
||||
return store.filePath
|
||||
}
|
||||
|
||||
// TestDumpDB will dump store database content.
|
||||
func (store *Store) TestDumpDB(tb assert.TestingT) {
|
||||
dumpCounts := true
|
||||
fmt.Printf(">>>>>>>> DUMP %s <<<<<\n\n", store.db.Path())
|
||||
|
||||
txMails := txDumpMailsFactory(tb)
|
||||
|
||||
txDump := func(tx *bolt.Tx) error {
|
||||
if dumpCounts {
|
||||
if err := txDumpCounts(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := txMails(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
assert.NoError(tb, store.db.View(txDump))
|
||||
}
|
||||
|
||||
func txDumpMailsFactory(tb assert.TestingT) func(tx *bolt.Tx) error {
|
||||
return func(tx *bolt.Tx) error {
|
||||
mailboxes := tx.Bucket(mailboxesBucket)
|
||||
metadata := tx.Bucket(metadataBucket)
|
||||
err := mailboxes.ForEach(func(mboxName, mboxData []byte) error {
|
||||
fmt.Println("mbox:", string(mboxName))
|
||||
b := mailboxes.Bucket(mboxName).Bucket(imapIDsBucket)
|
||||
c := b.Cursor()
|
||||
i := 0
|
||||
for imapID, apiID := c.First(); imapID != nil; imapID, apiID = c.Next() {
|
||||
i++
|
||||
fmt.Println(" ", i, "imap", btoi(imapID), "api", string(apiID))
|
||||
data := metadata.Get(apiID)
|
||||
if !assert.NotNil(tb, data) {
|
||||
continue
|
||||
}
|
||||
if !assert.NoError(tb, txMailMeta(data, i)) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
fmt.Println("total:", i)
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func txDumpCounts(tx *bolt.Tx) error {
|
||||
counts := tx.Bucket(countsBucket)
|
||||
err := counts.ForEach(func(labelID, countsB []byte) error {
|
||||
defer fmt.Println()
|
||||
fmt.Printf("counts id: %q ", string(labelID))
|
||||
counts := &mailboxCounts{}
|
||||
if err := json.Unmarshal(countsB, counts); err != nil {
|
||||
fmt.Printf(" Error %v", err)
|
||||
return nil
|
||||
}
|
||||
fmt.Printf(" total :%d unread %d", counts.TotalOnAPI, counts.UnreadOnAPI)
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func txMailMeta(data []byte, i int) error {
|
||||
fullMetaDump := false
|
||||
msg := &pmapi.Message{}
|
||||
if err := json.Unmarshal(data, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Body != "" {
|
||||
fmt.Printf(" %d body %s\n\n", i, msg.Body)
|
||||
panic("NONZERO BODY")
|
||||
}
|
||||
if i >= 10 {
|
||||
return nil
|
||||
}
|
||||
if fullMetaDump {
|
||||
fmt.Printf(" %d meta %s\n\n", i, string(data))
|
||||
} else {
|
||||
fmt.Println(
|
||||
" Subj", msg.Subject,
|
||||
"\n From", msg.Sender,
|
||||
"\n Time", msg.Time,
|
||||
"\n Labels", msg.LabelIDs,
|
||||
"\n Unread", msg.Unread,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
222
internal/store/sync.go
Normal file
222
internal/store/sync.go
Normal file
@ -0,0 +1,222 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
syncMinPagesPerWorker = 10
|
||||
syncMessagesMaxWorkers = 5
|
||||
maxFilterPageSize = 150
|
||||
)
|
||||
|
||||
type storeSynchronizer interface {
|
||||
getAllMessageIDs() ([]string, error)
|
||||
createOrUpdateMessagesEvent([]*pmapi.Message) error
|
||||
deleteMessagesEvent([]string) error
|
||||
saveSyncState(finishTime int64, idRanges []*syncIDRange, idsToBeDeleted []string)
|
||||
}
|
||||
|
||||
type messageLister interface {
|
||||
ListMessages(*pmapi.MessagesFilter) ([]*pmapi.Message, int, error)
|
||||
}
|
||||
|
||||
func syncAllMail(panicHandler PanicHandler, store storeSynchronizer, api messageLister, syncState *syncState) error {
|
||||
labelID := pmapi.AllMailLabel
|
||||
|
||||
// When the full sync starts (i.e. is not already in progress), we need to load
|
||||
// - all message IDs in database, so we can see which messages we need to remove at the end of the sync
|
||||
// - ID ranges which indicate how to split work into multiple workers
|
||||
if !syncState.isIncomplete() {
|
||||
if err := syncState.loadMessageIDsToBeDeleted(); err != nil {
|
||||
return errors.Wrap(err, "failed to load message IDs")
|
||||
}
|
||||
|
||||
if err := findIDRanges(labelID, api, syncState); err != nil {
|
||||
return errors.Wrap(err, "failed to load IDs ranges")
|
||||
}
|
||||
syncState.save()
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
shouldStop := 0 // Using integer to have it atomic.
|
||||
var resultError error
|
||||
|
||||
for _, idRange := range syncState.idRanges {
|
||||
wg.Add(1)
|
||||
idRange := idRange // Bind for goroutine.
|
||||
go func() {
|
||||
defer panicHandler.HandlePanic()
|
||||
defer wg.Done()
|
||||
|
||||
err := syncBatch(labelID, store, api, syncState, idRange, &shouldStop)
|
||||
if err != nil {
|
||||
shouldStop = 1
|
||||
resultError = errors.Wrap(err, "failed to sync group")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if resultError == nil {
|
||||
if err := syncState.deleteMessagesToBeDeleted(); err != nil {
|
||||
return errors.Wrap(err, "failed to delete messages")
|
||||
}
|
||||
}
|
||||
|
||||
return resultError
|
||||
}
|
||||
|
||||
func findIDRanges(labelID string, api messageLister, syncState *syncState) error {
|
||||
_, count, err := getSplitIDAndCount(labelID, api, 0)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get first ID and count")
|
||||
}
|
||||
log.WithField("total", count).Debug("Finding ID ranges")
|
||||
if count == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
syncState.initIDRanges()
|
||||
|
||||
pages := int(math.Ceil(float64(count) / float64(maxFilterPageSize)))
|
||||
workers := (pages / syncMinPagesPerWorker) + 1
|
||||
if workers > syncMessagesMaxWorkers {
|
||||
workers = syncMessagesMaxWorkers
|
||||
}
|
||||
|
||||
if workers == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
step := int(math.Round(float64(pages) / float64(workers)))
|
||||
// Increment steps in case there are more steps than max # of workers (due to rounding).
|
||||
if (step*syncMessagesMaxWorkers)+1 < pages {
|
||||
step++
|
||||
}
|
||||
|
||||
for page := step; page < pages; page += step {
|
||||
splitID, _, err := getSplitIDAndCount(labelID, api, page)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get IDs range")
|
||||
}
|
||||
// Some messages were probably deleted and so the page does not exist anymore.
|
||||
// Would be good to start this function again, but let's rather start the sync instead of
|
||||
// wasting time of many calls to API to find where to split workers.
|
||||
if splitID == "" {
|
||||
break
|
||||
}
|
||||
syncState.addIDRange(splitID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getSplitIDAndCount(labelID string, api messageLister, page int) (string, int, error) {
|
||||
sort := "ID"
|
||||
desc := false
|
||||
filter := &pmapi.MessagesFilter{
|
||||
LabelID: labelID,
|
||||
Sort: sort,
|
||||
Desc: &desc,
|
||||
PageSize: maxFilterPageSize,
|
||||
Page: page,
|
||||
Limit: 1,
|
||||
}
|
||||
// If the page does not exist, an empty page instead of an error is returned.
|
||||
messages, total, err := api.ListMessages(filter)
|
||||
if err != nil {
|
||||
return "", 0, errors.Wrap(err, "failed to list messages")
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return "", 0, nil
|
||||
}
|
||||
return messages[0].ID, total, nil
|
||||
}
|
||||
|
||||
func syncBatch( //nolint[funlen]
|
||||
labelID string,
|
||||
store storeSynchronizer,
|
||||
api messageLister,
|
||||
syncState *syncState,
|
||||
idRange *syncIDRange,
|
||||
shouldStop *int,
|
||||
) error {
|
||||
log.WithField("start", idRange.StartID).WithField("stop", idRange.StopID).Info("Starting sync batch")
|
||||
for {
|
||||
if *shouldStop == 1 || idRange.isFinished() {
|
||||
break
|
||||
}
|
||||
|
||||
sort := "ID"
|
||||
desc := true
|
||||
filter := &pmapi.MessagesFilter{
|
||||
LabelID: labelID,
|
||||
Sort: sort,
|
||||
Desc: &desc,
|
||||
PageSize: maxFilterPageSize,
|
||||
Page: 0,
|
||||
|
||||
// Messages with BeginID and EndID are included. We will process
|
||||
// those messages twice, but that's OK.
|
||||
// When message is completely removed, it still works as expected.
|
||||
BeginID: idRange.StartID,
|
||||
EndID: idRange.StopID,
|
||||
}
|
||||
|
||||
log.WithField("begin", filter.BeginID).WithField("end", filter.EndID).Debug("Fetching page")
|
||||
|
||||
messages, _, err := api.ListMessages(filter)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to list messages")
|
||||
}
|
||||
|
||||
if len(messages) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
for _, m := range messages {
|
||||
syncState.doNotDeleteMessageID(m.ID)
|
||||
}
|
||||
syncState.save()
|
||||
|
||||
if err := store.createOrUpdateMessagesEvent(messages); err != nil {
|
||||
return errors.Wrap(err, "failed to create or update messages")
|
||||
}
|
||||
|
||||
pageLastMessageID := messages[len(messages)-1].ID
|
||||
if !desc {
|
||||
idRange.setStartID(pageLastMessageID)
|
||||
} else {
|
||||
idRange.setStopID(pageLastMessageID)
|
||||
}
|
||||
|
||||
if len(messages) < maxFilterPageSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
217
internal/store/sync_state.go
Normal file
217
internal/store/sync_state.go
Normal file
@ -0,0 +1,217 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type syncState struct {
|
||||
lock *sync.RWMutex
|
||||
store storeSynchronizer
|
||||
|
||||
// finishTime is the time, when the sync was finished for the last time.
|
||||
// When it's zero, it was never finished or the sync is ongoing.
|
||||
finishTime int64
|
||||
|
||||
// idRanges are ID ranges which are used to split work in several workers.
|
||||
// On the beginning of the sync it will find split IDs which are used to
|
||||
// create this ranges. If we have 10000 messages and five workers, it will
|
||||
// find IDs around 2000, 4000, 6000 and 8000 and then first worker will
|
||||
// sync IDs 0-2000, second 2000-4000 and so on.
|
||||
idRanges []*syncIDRange
|
||||
|
||||
// idsToBeDeletedMap is map with keys as message IDs. On the beginning
|
||||
// of the sync, it will load all message IDs in database. During the sync,
|
||||
// it will delete all messages from the map which were sycned. The rest
|
||||
// at the end of the sync will be removed as those messages were not synced
|
||||
// again. We do that because we don't want to remove everything on the
|
||||
// beginning of the sync to keep client synced.
|
||||
idsToBeDeletedMap map[string]bool
|
||||
}
|
||||
|
||||
func newSyncState(store storeSynchronizer, finishTime int64, idRanges []*syncIDRange, idsToBeDeleted []string) *syncState {
|
||||
idsToBeDeletedMap := map[string]bool{}
|
||||
for _, id := range idsToBeDeleted {
|
||||
idsToBeDeletedMap[id] = true
|
||||
}
|
||||
|
||||
syncState := &syncState{
|
||||
lock: &sync.RWMutex{},
|
||||
store: store,
|
||||
|
||||
finishTime: finishTime,
|
||||
idRanges: idRanges,
|
||||
idsToBeDeletedMap: idsToBeDeletedMap,
|
||||
}
|
||||
|
||||
for _, idRange := range idRanges {
|
||||
idRange.syncState = syncState
|
||||
}
|
||||
|
||||
return syncState
|
||||
}
|
||||
|
||||
func (s *syncState) save() {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
s.store.saveSyncState(s.finishTime, s.idRanges, s.getIDsToBeDeleted())
|
||||
}
|
||||
|
||||
// isIncomplete returns whether the sync is in progress (no matter whether
|
||||
// the sync is running or just not finished by info from database).
|
||||
func (s *syncState) isIncomplete() bool {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
return s.finishTime == 0 && len(s.idRanges) != 0
|
||||
}
|
||||
|
||||
// isFinished returns whether the sync was finished.
|
||||
func (s *syncState) isFinished() bool {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
return s.finishTime != 0
|
||||
}
|
||||
|
||||
// clearFinishTime sets finish time to zero.
|
||||
func (s *syncState) clearFinishTime() {
|
||||
s.lock.Lock()
|
||||
defer s.save()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
s.finishTime = 0
|
||||
}
|
||||
|
||||
// setFinishTime sets finish time to current time.
|
||||
func (s *syncState) setFinishTime() {
|
||||
s.lock.Lock()
|
||||
defer s.save()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
s.finishTime = time.Now().UnixNano()
|
||||
}
|
||||
|
||||
// initIDRanges inits the main full range. Then each range is added
|
||||
// by `addIDRange`.
|
||||
func (s *syncState) initIDRanges() {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
s.idRanges = []*syncIDRange{{
|
||||
syncState: s,
|
||||
StartID: "",
|
||||
StopID: "",
|
||||
}}
|
||||
}
|
||||
|
||||
// addIDRange sets `splitID` as stopID for last range and adds new one
|
||||
// starting with `splitID`.
|
||||
func (s *syncState) addIDRange(splitID string) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
lastGroup := s.idRanges[len(s.idRanges)-1]
|
||||
lastGroup.StopID = splitID
|
||||
|
||||
s.idRanges = append(s.idRanges, &syncIDRange{
|
||||
syncState: s,
|
||||
StartID: splitID,
|
||||
StopID: "",
|
||||
})
|
||||
}
|
||||
|
||||
// loadMessageIDsToBeDeleted loads all message IDs from database
|
||||
// and by default all IDs are meant for deletion. During sync for
|
||||
// each ID `doNotDeleteMessageID` has to be called to remove that
|
||||
// message from being deleted by `deleteMessagesToBeDeleted`.
|
||||
func (s *syncState) loadMessageIDsToBeDeleted() error {
|
||||
idsToBeDeletedMap := make(map[string]bool)
|
||||
ids, err := s.store.getAllMessageIDs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, id := range ids {
|
||||
idsToBeDeletedMap[id] = true
|
||||
}
|
||||
|
||||
s.lock.Lock()
|
||||
defer s.save()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
s.idsToBeDeletedMap = idsToBeDeletedMap
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *syncState) doNotDeleteMessageID(id string) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
delete(s.idsToBeDeletedMap, id)
|
||||
}
|
||||
|
||||
func (s *syncState) deleteMessagesToBeDeleted() error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
idsToBeDeleted := s.getIDsToBeDeleted()
|
||||
log.Infof("Deleting %v messages after sync", len(idsToBeDeleted))
|
||||
if err := s.store.deleteMessagesEvent(idsToBeDeleted); err != nil {
|
||||
return errors.Wrap(err, "failed to delete messages")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getIDsToBeDeleted is helper to convert internal map for easier
|
||||
// manipulation to array.
|
||||
func (s *syncState) getIDsToBeDeleted() []string {
|
||||
keys := []string{}
|
||||
for key := range s.idsToBeDeletedMap {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// syncIDRange holds range which IDs need to be synced.
|
||||
type syncIDRange struct {
|
||||
syncState *syncState
|
||||
StartID string
|
||||
StopID string
|
||||
}
|
||||
|
||||
func (r *syncIDRange) setStartID(startID string) {
|
||||
r.StartID = startID
|
||||
r.syncState.save()
|
||||
}
|
||||
|
||||
func (r *syncIDRange) setStopID(stopID string) {
|
||||
r.StopID = stopID
|
||||
r.syncState.save()
|
||||
}
|
||||
|
||||
// isFinished returns syncIDRange is finished when StartID and StopID
|
||||
// are the same. But it cannot be full range, full range cannot be
|
||||
// determined in other way than asking API.
|
||||
func (r *syncIDRange) isFinished() bool {
|
||||
return r.StartID == r.StopID && r.StartID != ""
|
||||
}
|
||||
85
internal/store/sync_state_test.go
Normal file
85
internal/store/sync_state_test.go
Normal file
@ -0,0 +1,85 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSyncState_IDRanges(t *testing.T) {
|
||||
store := &mockStoreSynchronizer{}
|
||||
syncState := newSyncState(store, 0, []*syncIDRange{}, []string{})
|
||||
|
||||
syncState.initIDRanges()
|
||||
syncState.addIDRange("100")
|
||||
syncState.addIDRange("200")
|
||||
|
||||
r := syncState.idRanges
|
||||
assert.Equal(t, "", r[0].StartID)
|
||||
assert.Equal(t, "100", r[0].StopID)
|
||||
assert.Equal(t, "100", r[1].StartID)
|
||||
assert.Equal(t, "200", r[1].StopID)
|
||||
assert.Equal(t, "200", r[2].StartID)
|
||||
assert.Equal(t, "", r[2].StopID)
|
||||
}
|
||||
|
||||
func TestSyncState_IDRangesLoaded(t *testing.T) {
|
||||
store := &mockStoreSynchronizer{}
|
||||
syncState := newSyncState(store, 0, []*syncIDRange{
|
||||
{StartID: "", StopID: "100"},
|
||||
{StartID: "100", StopID: ""},
|
||||
}, []string{})
|
||||
|
||||
r := syncState.idRanges
|
||||
assert.Equal(t, "", r[0].StartID)
|
||||
assert.Equal(t, "100", r[0].StopID)
|
||||
assert.Equal(t, "100", r[1].StartID)
|
||||
assert.Equal(t, "", r[1].StopID)
|
||||
}
|
||||
|
||||
func TestSyncState_IDsToBeDeleted(t *testing.T) {
|
||||
store := &mockStoreSynchronizer{
|
||||
allMessageIDs: generateIDs(1, 9),
|
||||
}
|
||||
syncState := newSyncState(store, 0, []*syncIDRange{}, []string{})
|
||||
|
||||
require.Nil(t, syncState.loadMessageIDsToBeDeleted())
|
||||
syncState.doNotDeleteMessageID("1")
|
||||
syncState.doNotDeleteMessageID("2")
|
||||
syncState.doNotDeleteMessageID("3")
|
||||
syncState.doNotDeleteMessageID("notthere")
|
||||
|
||||
idsToBeDeleted := syncState.getIDsToBeDeleted()
|
||||
sort.Strings(idsToBeDeleted)
|
||||
assert.Equal(t, generateIDs(4, 9), idsToBeDeleted)
|
||||
}
|
||||
|
||||
func TestSyncState_IDsToBeDeletedLoaded(t *testing.T) {
|
||||
store := &mockStoreSynchronizer{
|
||||
allMessageIDs: generateIDs(1, 9),
|
||||
}
|
||||
syncState := newSyncState(store, 0, []*syncIDRange{}, generateIDs(4, 9))
|
||||
|
||||
idsToBeDeleted := syncState.getIDsToBeDeleted()
|
||||
sort.Strings(idsToBeDeleted)
|
||||
assert.Equal(t, generateIDs(4, 9), idsToBeDeleted)
|
||||
}
|
||||
509
internal/store/sync_test.go
Normal file
509
internal/store/sync_test.go
Normal file
@ -0,0 +1,509 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockLister struct {
|
||||
err error
|
||||
messageIDs []string
|
||||
}
|
||||
|
||||
func (m *mockLister) ListMessages(filter *pmapi.MessagesFilter) (msgs []*pmapi.Message, total int, err error) {
|
||||
if m.err != nil {
|
||||
return nil, 0, m.err
|
||||
}
|
||||
skipByID := true
|
||||
skipByPaging := filter.PageSize * filter.Page
|
||||
for idx := 0; idx < len(m.messageIDs); idx++ {
|
||||
var messageID string
|
||||
if !*filter.Desc {
|
||||
messageID = m.messageIDs[idx]
|
||||
if filter.BeginID == "" || messageID == filter.BeginID {
|
||||
skipByID = false
|
||||
}
|
||||
} else {
|
||||
messageID = m.messageIDs[len(m.messageIDs)-1-idx]
|
||||
if filter.EndID == "" || messageID == filter.EndID {
|
||||
skipByID = false
|
||||
}
|
||||
}
|
||||
if skipByID {
|
||||
continue
|
||||
}
|
||||
skipByPaging--
|
||||
if skipByPaging > 0 {
|
||||
continue
|
||||
}
|
||||
msgs = append(msgs, &pmapi.Message{
|
||||
ID: messageID,
|
||||
})
|
||||
if len(msgs) == filter.PageSize || len(msgs) == filter.Limit {
|
||||
break
|
||||
}
|
||||
if !*filter.Desc {
|
||||
if messageID == filter.EndID {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
if messageID == filter.BeginID {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return msgs, len(m.messageIDs), nil
|
||||
}
|
||||
|
||||
type mockStoreSynchronizer struct {
|
||||
allMessageIDs []string
|
||||
errCreateOrUpdateMessagesEvent error
|
||||
createdMessageIDsByBatch [][]string
|
||||
}
|
||||
|
||||
func (m *mockStoreSynchronizer) getAllMessageIDs() ([]string, error) {
|
||||
return m.allMessageIDs, nil
|
||||
}
|
||||
|
||||
func (m *mockStoreSynchronizer) createOrUpdateMessagesEvent(messages []*pmapi.Message) error {
|
||||
if m.errCreateOrUpdateMessagesEvent != nil {
|
||||
return m.errCreateOrUpdateMessagesEvent
|
||||
}
|
||||
createdMessageIDs := []string{}
|
||||
for _, message := range messages {
|
||||
createdMessageIDs = append(createdMessageIDs, message.ID)
|
||||
}
|
||||
m.createdMessageIDsByBatch = append(m.createdMessageIDsByBatch, createdMessageIDs)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStoreSynchronizer) deleteMessagesEvent([]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockStoreSynchronizer) saveSyncState(finishTime int64, idRanges []*syncIDRange, idsToBeDeleted []string) {
|
||||
}
|
||||
|
||||
func newTestSyncState(store storeSynchronizer, splitIDs ...string) *syncState {
|
||||
syncState := newSyncState(store, 0, []*syncIDRange{}, []string{})
|
||||
syncState.initIDRanges()
|
||||
for _, splitID := range splitIDs {
|
||||
syncState.addIDRange(splitID)
|
||||
}
|
||||
return syncState
|
||||
}
|
||||
|
||||
func generateIDs(start, stop int) []string {
|
||||
ids := []string{}
|
||||
for x := start; x <= stop; x++ {
|
||||
ids = append(ids, strconv.Itoa(x))
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func generateIDsR(start, stop int) []string {
|
||||
ids := []string{}
|
||||
for x := start; x >= stop; x-- {
|
||||
ids = append(ids, strconv.Itoa(x))
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// Tests
|
||||
|
||||
func TestSyncAllMail(t *testing.T) { //nolint[funlen]
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
|
||||
numberOfMessages := 10000
|
||||
|
||||
api := &mockLister{
|
||||
messageIDs: generateIDs(1, numberOfMessages),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
idRanges []*syncIDRange
|
||||
idsToBeDeleted []string
|
||||
wantUpdatedIDs []string
|
||||
wantNotUpdatedIDs []string
|
||||
}{
|
||||
{
|
||||
"full sync",
|
||||
[]*syncIDRange{},
|
||||
[]string{},
|
||||
api.messageIDs,
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"continue with interrupted sync",
|
||||
[]*syncIDRange{
|
||||
{StartID: "2000", StopID: "2100"},
|
||||
{StartID: "4000", StopID: "4200"},
|
||||
{StartID: "9500", StopID: ""},
|
||||
},
|
||||
mergeArrays(generateIDs(2000, 2100), generateIDs(4000, 4200), generateIDs(9500, 10010)),
|
||||
mergeArrays(generateIDs(2000, 2100), generateIDs(4000, 4200), generateIDs(9500, 10000)),
|
||||
mergeArrays(generateIDs(1, 1999), generateIDs(2101, 3999), generateIDs(4201, 9459)),
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
store := &mockStoreSynchronizer{
|
||||
allMessageIDs: generateIDs(1, numberOfMessages+10),
|
||||
}
|
||||
syncState := newSyncState(store, 0, tc.idRanges, tc.idsToBeDeleted)
|
||||
|
||||
err := syncAllMail(m.panicHandler, store, api, syncState)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Check all messages were created or updated.
|
||||
updateMessageIDsMap := map[string]bool{}
|
||||
for _, messageIDs := range store.createdMessageIDsByBatch {
|
||||
for _, messageID := range messageIDs {
|
||||
updateMessageIDsMap[messageID] = true
|
||||
}
|
||||
}
|
||||
for _, messageID := range tc.wantUpdatedIDs {
|
||||
assert.True(t, updateMessageIDsMap[messageID], "Message %s was not created/updated, but should", messageID)
|
||||
}
|
||||
for _, messageID := range tc.wantNotUpdatedIDs {
|
||||
assert.False(t, updateMessageIDsMap[messageID], "Message %s was created/updated, but shouldn't", messageID)
|
||||
}
|
||||
|
||||
// Check all messages were deleted.
|
||||
idsToBeDeleted := syncState.getIDsToBeDeleted()
|
||||
sort.Strings(idsToBeDeleted)
|
||||
assert.Equal(t, generateIDs(numberOfMessages+1, numberOfMessages+10), idsToBeDeleted)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mergeArrays(arrays ...[]string) []string {
|
||||
result := []string{}
|
||||
for _, array := range arrays {
|
||||
result = append(result, array...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func TestSyncAllMail_FailedListing(t *testing.T) {
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
|
||||
numberOfMessages := 10000
|
||||
|
||||
store := &mockStoreSynchronizer{
|
||||
allMessageIDs: generateIDs(1, numberOfMessages+10),
|
||||
}
|
||||
api := &mockLister{
|
||||
err: errors.New("error"),
|
||||
messageIDs: generateIDs(1, numberOfMessages),
|
||||
}
|
||||
syncState := newTestSyncState(store)
|
||||
|
||||
err := syncAllMail(m.panicHandler, store, api, syncState)
|
||||
require.EqualError(t, err, "failed to sync group: failed to list messages: error")
|
||||
}
|
||||
|
||||
func TestSyncAllMail_FailedCreateOrUpdateMessage(t *testing.T) {
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
|
||||
numberOfMessages := 10000
|
||||
|
||||
store := &mockStoreSynchronizer{
|
||||
errCreateOrUpdateMessagesEvent: errors.New("error"),
|
||||
allMessageIDs: generateIDs(1, numberOfMessages+10),
|
||||
}
|
||||
api := &mockLister{
|
||||
messageIDs: generateIDs(1, numberOfMessages),
|
||||
}
|
||||
syncState := newTestSyncState(store)
|
||||
|
||||
err := syncAllMail(m.panicHandler, store, api, syncState)
|
||||
require.EqualError(t, err, "failed to sync group: failed to create or update messages: error")
|
||||
}
|
||||
|
||||
func TestFindIDRanges(t *testing.T) { //nolint[funlen]
|
||||
store := &mockStoreSynchronizer{}
|
||||
syncState := newTestSyncState(store)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
messageIDs []string
|
||||
wantBatches [][]string
|
||||
}{
|
||||
{
|
||||
"1k messages - 1 batch",
|
||||
generateIDs(1, 1000),
|
||||
[][]string{
|
||||
{"", ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
"1k messages not starting at 1",
|
||||
generateIDs(1000, 2000),
|
||||
[][]string{
|
||||
{"", ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
"2k messages - 2 batches",
|
||||
generateIDs(1, 2000),
|
||||
[][]string{
|
||||
{"", "1050"},
|
||||
{"1050", ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
"4k messages - 3 batches",
|
||||
generateIDs(1, 4000),
|
||||
[][]string{
|
||||
{"", "1350"},
|
||||
{"1350", "2700"},
|
||||
{"2700", ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
"5k messages - 4 batches",
|
||||
generateIDs(1, 5000),
|
||||
[][]string{
|
||||
{"", "1350"},
|
||||
{"1350", "2700"},
|
||||
{"2700", "4050"},
|
||||
{"4050", ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
"10k messages - 5 batches",
|
||||
generateIDs(1, 10000),
|
||||
[][]string{
|
||||
{"", "2100"},
|
||||
{"2100", "4200"},
|
||||
{"4200", "6300"},
|
||||
{"6300", "8400"},
|
||||
{"8400", ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
"150k messages - 5 batches",
|
||||
generateIDs(1, 150000),
|
||||
[][]string{
|
||||
{"", "30000"},
|
||||
{"30000", "60000"},
|
||||
{"60000", "90000"},
|
||||
{"90000", "120000"},
|
||||
{"120000", ""},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
api := &mockLister{
|
||||
messageIDs: tc.messageIDs,
|
||||
}
|
||||
|
||||
err := findIDRanges(pmapi.AllMailLabel, api, syncState)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, len(tc.wantBatches), len(syncState.idRanges))
|
||||
for idx, idRange := range syncState.idRanges {
|
||||
want := tc.wantBatches[idx]
|
||||
assert.Equal(t, want[0], idRange.StartID, "Start ID for IDs range %d does not match", idx)
|
||||
assert.Equal(t, want[1], idRange.StopID, "Stop ID for IDs range %d does not match", idx)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindIDRanges_FailedListing(t *testing.T) {
|
||||
store := &mockStoreSynchronizer{}
|
||||
api := &mockLister{
|
||||
err: errors.New("error"),
|
||||
}
|
||||
|
||||
syncState := newTestSyncState(store)
|
||||
|
||||
err := findIDRanges(pmapi.AllMailLabel, api, syncState)
|
||||
require.EqualError(t, err, "failed to get first ID and count: failed to list messages: error")
|
||||
}
|
||||
|
||||
func TestGetSplitIDAndCount(t *testing.T) { //nolint[funlen]
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
messageIDs []string
|
||||
page int
|
||||
wantID string
|
||||
wantTotal int
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
"1000 messages, first page",
|
||||
nil,
|
||||
generateIDs(1, 1000),
|
||||
0,
|
||||
"1",
|
||||
1000,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"1000 messages, 5th page",
|
||||
nil,
|
||||
generateIDs(1, 1000),
|
||||
4,
|
||||
"600",
|
||||
1000,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"one message, first page",
|
||||
nil,
|
||||
[]string{"1"},
|
||||
0,
|
||||
"1",
|
||||
1,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"no message, first page",
|
||||
nil,
|
||||
[]string{},
|
||||
0,
|
||||
"",
|
||||
0,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"listing error",
|
||||
errors.New("error"),
|
||||
generateIDs(1, 1000),
|
||||
0,
|
||||
"",
|
||||
0,
|
||||
"failed to list messages: error",
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
api := &mockLister{
|
||||
err: tc.err,
|
||||
messageIDs: tc.messageIDs,
|
||||
}
|
||||
|
||||
id, total, err := getSplitIDAndCount(pmapi.AllMailLabel, api, tc.page)
|
||||
|
||||
if tc.wantErr == "" {
|
||||
require.Nil(t, err)
|
||||
} else {
|
||||
require.EqualError(t, err, tc.wantErr)
|
||||
}
|
||||
assert.Equal(t, tc.wantID, id)
|
||||
assert.Equal(t, tc.wantTotal, total)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncBatch(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
batchIdx int
|
||||
wantCreatedMessageIDsByBatch [][]string
|
||||
}{
|
||||
{
|
||||
"first-batch",
|
||||
0,
|
||||
[][]string{generateIDsR(200, 51), generateIDsR(51, 1)},
|
||||
},
|
||||
{
|
||||
"second-batch",
|
||||
1,
|
||||
[][]string{generateIDsR(400, 251), generateIDsR(251, 200)},
|
||||
},
|
||||
{
|
||||
"third-batch",
|
||||
2,
|
||||
[][]string{generateIDsR(600, 451), generateIDsR(451, 400)},
|
||||
},
|
||||
{
|
||||
"fourth-batch",
|
||||
3,
|
||||
[][]string{generateIDsR(800, 651), generateIDsR(651, 600)},
|
||||
},
|
||||
{
|
||||
"fifth-batch",
|
||||
4,
|
||||
[][]string{generateIDsR(1000, 851), generateIDsR(851, 800)},
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
store := &mockStoreSynchronizer{}
|
||||
api := &mockLister{
|
||||
messageIDs: generateIDs(1, 1000),
|
||||
}
|
||||
|
||||
err := testSyncBatch(t, store, api, tc.batchIdx, "200", "400", "600", "800")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, tc.wantCreatedMessageIDsByBatch, store.createdMessageIDsByBatch)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncBatch_FailedListing(t *testing.T) {
|
||||
store := &mockStoreSynchronizer{}
|
||||
api := &mockLister{
|
||||
err: errors.New("error"),
|
||||
messageIDs: generateIDs(1, 1000),
|
||||
}
|
||||
|
||||
err := testSyncBatch(t, store, api, 0)
|
||||
require.EqualError(t, err, "failed to list messages: error")
|
||||
}
|
||||
|
||||
func TestSyncBatch_FailedCreateOrUpdateMessage(t *testing.T) {
|
||||
store := &mockStoreSynchronizer{
|
||||
errCreateOrUpdateMessagesEvent: errors.New("error"),
|
||||
}
|
||||
api := &mockLister{
|
||||
messageIDs: generateIDs(1, 1000),
|
||||
}
|
||||
|
||||
err := testSyncBatch(t, store, api, 0)
|
||||
require.EqualError(t, err, "failed to create or update messages: error")
|
||||
}
|
||||
|
||||
func testSyncBatch(t *testing.T, store storeSynchronizer, api messageLister, rangeIdx int, splitIDs ...string) error { //nolint[unparam]
|
||||
syncState := newTestSyncState(store, splitIDs...)
|
||||
idRange := syncState.idRanges[rangeIdx]
|
||||
shouldStop := 0
|
||||
return syncBatch(pmapi.AllMailLabel, store, api, syncState, idRange, &shouldStop)
|
||||
}
|
||||
69
internal/store/types.go
Normal file
69
internal/store/types.go
Normal file
@ -0,0 +1,69 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
)
|
||||
|
||||
type PanicHandler interface {
|
||||
HandlePanic()
|
||||
}
|
||||
|
||||
// PMAPIProvider is subset of pmapi.Client for use by the Store.
|
||||
type PMAPIProvider interface {
|
||||
CurrentUser() (*pmapi.User, error)
|
||||
Addresses() pmapi.AddressList
|
||||
|
||||
GetEvent(eventID string) (*pmapi.Event, error)
|
||||
|
||||
CountMessages(addressID string) ([]*pmapi.MessagesCount, error)
|
||||
ListMessages(filter *pmapi.MessagesFilter) ([]*pmapi.Message, int, error)
|
||||
GetMessage(apiID string) (*pmapi.Message, error)
|
||||
Import([]*pmapi.ImportMsgReq) ([]*pmapi.ImportMsgRes, error)
|
||||
DeleteMessages(apiIDs []string) error
|
||||
LabelMessages(apiIDs []string, labelID string) error
|
||||
UnlabelMessages(apiIDs []string, labelID string) error
|
||||
MarkMessagesRead(apiIDs []string) error
|
||||
MarkMessagesUnread(apiIDs []string) error
|
||||
|
||||
CreateDraft(m *pmapi.Message, parent string, action int) (created *pmapi.Message, err error)
|
||||
CreateAttachment(att *pmapi.Attachment, r io.Reader, sig io.Reader) (created *pmapi.Attachment, err error)
|
||||
SendMessage(messageID string, req *pmapi.SendMessageReq) (sent, parent *pmapi.Message, err error)
|
||||
|
||||
ListLabels() ([]*pmapi.Label, error)
|
||||
CreateLabel(label *pmapi.Label) (*pmapi.Label, error)
|
||||
UpdateLabel(label *pmapi.Label) (*pmapi.Label, error)
|
||||
DeleteLabel(labelID string) error
|
||||
EmptyFolder(labelID string, addressID string) error
|
||||
}
|
||||
|
||||
// BridgeUser is subset of bridge.User for use by the Store.
|
||||
type BridgeUser interface {
|
||||
ID() string
|
||||
GetAddressID(address string) (string, error)
|
||||
IsConnected() bool
|
||||
IsCombinedAddressMode() bool
|
||||
GetPrimaryAddress() string
|
||||
GetStoreAddresses() []string
|
||||
UpdateUser() error
|
||||
CloseConnection(string)
|
||||
Logout() error
|
||||
}
|
||||
64
internal/store/ulimit.go
Normal file
64
internal/store/ulimit.go
Normal file
@ -0,0 +1,64 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func uLimit() int {
|
||||
if runtime.GOOS != "darwin" && runtime.GOOS != "linux" {
|
||||
return 0
|
||||
}
|
||||
out, err := exec.Command("bash", "-c", "ulimit -n").Output()
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return 0
|
||||
}
|
||||
outStr := strings.Trim(string(out), " \n")
|
||||
num, err := strconv.Atoi(outStr)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return 0
|
||||
}
|
||||
return num
|
||||
}
|
||||
|
||||
func isFdCloseToULimit() bool {
|
||||
if runtime.GOOS != "darwin" && runtime.GOOS != "linux" {
|
||||
return false
|
||||
}
|
||||
|
||||
pid := fmt.Sprint(os.Getpid())
|
||||
out, err := exec.Command("lsof", "-p", pid).Output() //nolint[gosec]
|
||||
if err != nil {
|
||||
log.Warn("isFdCloseToULimit: ", err)
|
||||
return false
|
||||
}
|
||||
lines := strings.Split(string(out), "\n")
|
||||
|
||||
fd := len(lines) - 1
|
||||
ulimit := uLimit()
|
||||
log.Info("File descriptor check: num goroutines ", runtime.NumGoroutine(), " fd ", fd, " ulimit ", ulimit)
|
||||
return fd >= int(0.95*float64(ulimit))
|
||||
}
|
||||
41
internal/store/user.go
Normal file
41
internal/store/user.go
Normal file
@ -0,0 +1,41 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
// UserID returns user ID.
|
||||
func (store *Store) UserID() string {
|
||||
return store.user.ID()
|
||||
}
|
||||
|
||||
// GetSpace returns used and total space in bytes.
|
||||
func (store *Store) GetSpace() (usedSpace, maxSpace uint, err error) {
|
||||
apiUser, err := store.api.CurrentUser()
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return uint(apiUser.UsedSpace), uint(apiUser.MaxSpace), nil
|
||||
}
|
||||
|
||||
// GetMaxUpload returns max size of attachment in bytes.
|
||||
func (store *Store) GetMaxUpload() (uint, error) {
|
||||
apiUser, err := store.api.CurrentUser()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint(apiUser.MaxUpload), nil
|
||||
}
|
||||
216
internal/store/user_address.go
Normal file
216
internal/store/user_address.go
Normal file
@ -0,0 +1,216 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/pkg/errors"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
// GetAddress returns the store address by given ID.
|
||||
func (store *Store) GetAddress(addressID string) (*Address, error) {
|
||||
store.lock.RLock()
|
||||
defer store.lock.RUnlock()
|
||||
|
||||
storeAddress, ok := store.addresses[addressID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("addressID %v does not exist", addressID)
|
||||
}
|
||||
|
||||
return storeAddress, nil
|
||||
}
|
||||
|
||||
// RebuildMailboxes truncates all mailbox buckets and recreates them from the metadata bucket again.
|
||||
func (store *Store) RebuildMailboxes() (err error) {
|
||||
store.lock.Lock()
|
||||
defer store.lock.Unlock()
|
||||
|
||||
log.WithField("user", store.UserID()).Trace("Truncating mailboxes")
|
||||
|
||||
if err = store.truncateMailboxesBucket(); err != nil {
|
||||
log.WithError(err).Error("Could not truncate mailboxes bucket")
|
||||
return
|
||||
}
|
||||
|
||||
if err = store.truncateAddressInfoBucket(); err != nil {
|
||||
log.WithError(err).Error("Could not truncate address info bucket")
|
||||
return
|
||||
}
|
||||
|
||||
if err = store.init(false); err != nil {
|
||||
log.WithError(err).Error("Could not init store")
|
||||
return
|
||||
}
|
||||
|
||||
if err := store.increaseMailboxesVersion(); err != nil {
|
||||
log.WithError(err).Error("Could not increase structure version")
|
||||
// Do not return here. The truncation was already done and mode
|
||||
// was changed in DB so we need to sync so that users start to see
|
||||
// messages and not block other operations.
|
||||
}
|
||||
|
||||
log.WithField("user", store.UserID()).Trace("Rebuilding mailboxes")
|
||||
store.triggerSync()
|
||||
return nil
|
||||
}
|
||||
|
||||
// createOrDeleteAddressesEvent creates address objects in the store for each necessary address
|
||||
// and deletes any address objects that shouldn't be there.
|
||||
// It doesn't do anything to addresses that are rightfully there.
|
||||
// It should only be called from the event loop.
|
||||
func (store *Store) createOrDeleteAddressesEvent() (err error) {
|
||||
labels, err := store.initCounts()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to initialise label counts")
|
||||
}
|
||||
|
||||
addrInfo, err := store.GetAddressInfo()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get addresses and address IDs")
|
||||
}
|
||||
|
||||
// We need at least one address to continue.
|
||||
if len(addrInfo) < 1 {
|
||||
return errors.New("no addresses to initialise")
|
||||
}
|
||||
|
||||
// If in combined mode, we only need the user's primary address.
|
||||
if store.addressMode == combinedMode {
|
||||
addrInfo = addrInfo[:1]
|
||||
}
|
||||
|
||||
// Go through all addresses that *should* be there.
|
||||
for _, addr := range addrInfo {
|
||||
if _, ok := store.addresses[addr.AddressID]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// This address is missing so we create it.
|
||||
if err = store.addAddress(addr.Address, addr.AddressID, labels); err != nil {
|
||||
return errors.Wrap(err, "failed to add address to store")
|
||||
}
|
||||
}
|
||||
|
||||
// Go through all addresses that *should not* be there.
|
||||
for _, addr := range store.addresses {
|
||||
belongs := false
|
||||
|
||||
for _, a := range addrInfo {
|
||||
if addr.addressID == a.AddressID {
|
||||
belongs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if belongs {
|
||||
continue
|
||||
}
|
||||
|
||||
delete(store.addresses, addr.addressID)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// truncateAddressInfoBucket removes the address info bucket.
|
||||
func (store *Store) truncateAddressInfoBucket() (err error) {
|
||||
log.Trace("Truncating address info bucket")
|
||||
|
||||
tx := func(tx *bolt.Tx) (err error) {
|
||||
if err = tx.DeleteBucket(addressInfoBucket); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = tx.CreateBucketIfNotExists(addressInfoBucket); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
return store.db.Update(tx)
|
||||
}
|
||||
|
||||
// truncateMailboxesBucket removes the mailboxes bucket.
|
||||
func (store *Store) truncateMailboxesBucket() (err error) {
|
||||
log.Trace("Truncating mailboxes bucket")
|
||||
|
||||
store.addresses = nil
|
||||
|
||||
tx := func(tx *bolt.Tx) (err error) {
|
||||
mbs := tx.Bucket(mailboxesBucket)
|
||||
|
||||
return mbs.ForEach(func(addrIDMailbox, _ []byte) (err error) {
|
||||
addr := mbs.Bucket(addrIDMailbox)
|
||||
|
||||
if err = addr.DeleteBucket(imapIDsBucket); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = addr.CreateBucketIfNotExists(imapIDsBucket); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = addr.DeleteBucket(apiIDsBucket); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = addr.CreateBucketIfNotExists(apiIDsBucket); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
return store.db.Update(tx)
|
||||
}
|
||||
|
||||
// initMailboxesBucket recreates the mailboxes bucket from the metadata bucket.
|
||||
func (store *Store) initMailboxesBucket() error { //nolint[unused]
|
||||
i := 0
|
||||
tx := func(tx *bolt.Tx) error {
|
||||
return tx.Bucket(metadataBucket).ForEach(func(k, v []byte) error {
|
||||
msg := &pmapi.Message{}
|
||||
|
||||
if err := json.Unmarshal(v, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, a := range store.addresses {
|
||||
if err := a.txCreateOrUpdateMessages(tx, []*pmapi.Message{msg}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
i++
|
||||
if i%100 == 0 {
|
||||
store.log.WithField("i", i).
|
||||
Trace("Init mboxes heartbeat")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return store.db.Update(tx)
|
||||
}
|
||||
158
internal/store/user_address_info.go
Normal file
158
internal/store/user_address_info.go
Normal file
@ -0,0 +1,158 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/pkg/errors"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
// AddressInfo is the format of the data held in the addresses bucket in the store.
|
||||
// It allows us to easily keep an address and its ID together and serialisation/deserialisation to []byte.
|
||||
type AddressInfo struct {
|
||||
Address, AddressID string
|
||||
}
|
||||
|
||||
// GetAddressID returns the ID of the given address.
|
||||
func (store *Store) GetAddressID(addr string) (id string, err error) {
|
||||
addrs, err := store.GetAddressInfo()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, addrInfo := range addrs {
|
||||
if strings.EqualFold(addrInfo.Address, addr) {
|
||||
id = addrInfo.AddressID
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = errors.New("no such address")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetAddressInfo returns information about all addresses owned by the user.
|
||||
// The first element is the user's primary address and the rest (if present) are aliases.
|
||||
// It tries to source the information from the store but if the store doesn't yet have it, it
|
||||
// fetches it from the API and caches it for later.
|
||||
func (store *Store) GetAddressInfo() (addrs []AddressInfo, err error) {
|
||||
if addrs, err = store.getAddressInfoFromStore(); err == nil && len(addrs) > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Store does not have address info yet, need to build it first from API.
|
||||
addressList := store.api.Addresses()
|
||||
if addressList == nil {
|
||||
err = errors.New("addresses unavailable")
|
||||
store.log.WithError(err).Error("Could not get user addresses from API")
|
||||
return
|
||||
}
|
||||
|
||||
if err = store.createOrUpdateAddressInfo(addressList); err != nil {
|
||||
store.log.WithError(err).Warn("Could not update address IDs in store")
|
||||
return
|
||||
}
|
||||
|
||||
return store.getAddressInfoFromStore()
|
||||
}
|
||||
|
||||
// getAddressIDsByAddressFromStore returns a map from address to addressID for each address owned by the user.
|
||||
func (store *Store) getAddressInfoFromStore() (addrs []AddressInfo, err error) {
|
||||
store.log.Debug("Retrieving address info from store")
|
||||
|
||||
tx := func(tx *bolt.Tx) (err error) {
|
||||
c := tx.Bucket(addressInfoBucket).Cursor()
|
||||
for index, addrInfoBytes := c.First(); index != nil; index, addrInfoBytes = c.Next() {
|
||||
var addrInfo AddressInfo
|
||||
|
||||
if err = json.Unmarshal(addrInfoBytes, &addrInfo); err != nil {
|
||||
store.log.WithError(err).Error("Could not unmarshal address and addressID")
|
||||
return
|
||||
}
|
||||
|
||||
addrs = append(addrs, addrInfo)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = store.db.View(tx)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// createOrUpdateAddressInfo updates the store address/addressID bucket to match the given address list.
|
||||
// The address list supplied is assumed to contain active emails in any order.
|
||||
// It firstly (and stupidly) deletes the bucket of addresses and then fills it with up to date info.
|
||||
// This is because a user might delete an address and we don't want old addresses lying around (and finding the
|
||||
// specific ones to delete is likely not much more efficient than just rebuilding from scratch).
|
||||
func (store *Store) createOrUpdateAddressInfo(addressList pmapi.AddressList) (err error) {
|
||||
tx := func(tx *bolt.Tx) error {
|
||||
if err := tx.DeleteBucket(addressInfoBucket); err != nil {
|
||||
store.log.WithError(err).Error("Could not delete addressIDs bucket")
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := tx.CreateBucketIfNotExists(addressInfoBucket); err != nil {
|
||||
store.log.WithError(err).Error("Could not recreate addressIDs bucket")
|
||||
return err
|
||||
}
|
||||
|
||||
addrsBucket := tx.Bucket(addressInfoBucket)
|
||||
|
||||
for index, address := range filterAddresses(addressList) {
|
||||
ib := itob(uint32(index))
|
||||
|
||||
info, err := json.Marshal(AddressInfo{
|
||||
Address: address.Email,
|
||||
AddressID: address.ID,
|
||||
})
|
||||
if err != nil {
|
||||
store.log.WithError(err).Error("Could not marshal address and addressID")
|
||||
return err
|
||||
}
|
||||
|
||||
if err := addrsBucket.Put(ib, info); err != nil {
|
||||
store.log.WithError(err).Error("Could not put address and addressID into store")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return store.db.Update(tx)
|
||||
}
|
||||
|
||||
// filterAddresses filters out inactive addresses and ensures the original address is listed first.
|
||||
func filterAddresses(addressList pmapi.AddressList) (filteredList pmapi.AddressList) {
|
||||
for _, address := range addressList {
|
||||
if address.Receive != pmapi.CanReceive {
|
||||
continue
|
||||
}
|
||||
|
||||
filteredList = append(filteredList, address)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
229
internal/store/user_mailbox.go
Normal file
229
internal/store/user_mailbox.go
Normal file
@ -0,0 +1,229 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// createMailbox creates the mailbox via the API.
|
||||
// The store mailbox is created later by processing an event.
|
||||
func (store *Store) createMailbox(name string) error {
|
||||
defer store.eventLoop.pollNow()
|
||||
|
||||
log.WithField("name", name).Debug("Creating mailbox")
|
||||
|
||||
if store.hasMailbox(name) {
|
||||
return fmt.Errorf("mailbox %v already exists", name)
|
||||
}
|
||||
|
||||
color := store.leastUsedColor()
|
||||
|
||||
var exclusive int
|
||||
switch {
|
||||
case strings.HasPrefix(name, UserLabelsPrefix):
|
||||
name = strings.TrimPrefix(name, UserLabelsPrefix)
|
||||
exclusive = 0
|
||||
case strings.HasPrefix(name, UserFoldersPrefix):
|
||||
name = strings.TrimPrefix(name, UserFoldersPrefix)
|
||||
exclusive = 1
|
||||
default:
|
||||
// Ideally we would throw an error here, but then Outlook for
|
||||
// macOS keeps trying to make an IMAP Drafts folder and popping
|
||||
// up the error to the user.
|
||||
store.log.WithField("name", name).
|
||||
Warn("Ignoring creation of new mailbox in IMAP root")
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := store.api.CreateLabel(&pmapi.Label{
|
||||
Name: name,
|
||||
Color: color,
|
||||
Exclusive: exclusive,
|
||||
Type: pmapi.LabelTypeMailbox,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// allAddressesHaveMailbox returns whether each address has a mailbox with the given labelID.
|
||||
func (store *Store) allAddressesHaveMailbox(labelID string) bool {
|
||||
store.lock.RLock()
|
||||
defer store.lock.RUnlock()
|
||||
|
||||
for _, a := range store.addresses {
|
||||
addressHasMailbox := false
|
||||
for _, m := range a.mailboxes {
|
||||
if m.labelID == labelID {
|
||||
addressHasMailbox = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !addressHasMailbox {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// hasMailbox returns whether there is at least one address which has a mailbox with the given name.
|
||||
func (store *Store) hasMailbox(name string) bool {
|
||||
mailbox, _ := store.getMailbox(name)
|
||||
return mailbox != nil
|
||||
}
|
||||
|
||||
// getMailbox returns the first mailbox with the given name.
|
||||
func (store *Store) getMailbox(name string) (*Mailbox, error) {
|
||||
store.lock.RLock()
|
||||
defer store.lock.RUnlock()
|
||||
|
||||
for _, a := range store.addresses {
|
||||
for _, m := range a.mailboxes {
|
||||
if m.labelName == name {
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("mailbox %s does not exist", name)
|
||||
}
|
||||
|
||||
// leastUsedColor returns the least used color to be used for a newly created folder or label.
|
||||
func (store *Store) leastUsedColor() string {
|
||||
store.lock.RLock()
|
||||
defer store.lock.RUnlock()
|
||||
|
||||
usage := map[string]int{}
|
||||
for _, a := range store.addresses {
|
||||
for _, m := range a.mailboxes {
|
||||
if m.color != "" {
|
||||
usage[m.color]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
leastUsed := pmapi.LabelColors[0]
|
||||
for _, color := range pmapi.LabelColors {
|
||||
if usage[leastUsed] > usage[color] {
|
||||
leastUsed = color
|
||||
}
|
||||
}
|
||||
return leastUsed
|
||||
}
|
||||
|
||||
// updateMailbox updates the mailbox via the API.
|
||||
// The store mailbox is updated later by processing an event.
|
||||
func (store *Store) updateMailbox(labelID, newName, color string) error {
|
||||
defer store.eventLoop.pollNow()
|
||||
|
||||
_, err := store.api.UpdateLabel(&pmapi.Label{
|
||||
ID: labelID,
|
||||
Name: newName,
|
||||
Color: color,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// deleteMailbox deletes the mailbox via the API.
|
||||
// The store mailbox is deleted later by processing an event.
|
||||
func (store *Store) deleteMailbox(labelID, addressID string) error {
|
||||
defer store.eventLoop.pollNow()
|
||||
|
||||
if pmapi.IsSystemLabel(labelID) {
|
||||
var err error
|
||||
switch labelID {
|
||||
case pmapi.SpamLabel:
|
||||
err = store.api.EmptyFolder(pmapi.SpamLabel, addressID)
|
||||
case pmapi.TrashLabel:
|
||||
err = store.api.EmptyFolder(pmapi.TrashLabel, addressID)
|
||||
default:
|
||||
err = fmt.Errorf("cannot empty mailbox %v", labelID)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return store.api.DeleteLabel(labelID)
|
||||
}
|
||||
|
||||
func (store *Store) createLabelsIfMissing(affectedLabelIDs map[string]bool) error {
|
||||
newLabelIDs := []string{}
|
||||
for labelID := range affectedLabelIDs {
|
||||
if pmapi.IsSystemLabel(labelID) || store.allAddressesHaveMailbox(labelID) {
|
||||
continue
|
||||
}
|
||||
newLabelIDs = append(newLabelIDs, labelID)
|
||||
}
|
||||
if len(newLabelIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
labels, err := store.api.ListLabels()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, newLabelID := range newLabelIDs {
|
||||
for _, label := range labels {
|
||||
if label.ID != newLabelID {
|
||||
continue
|
||||
}
|
||||
if err := store.createOrUpdateMailboxEvent(label); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createOrUpdateMailboxEvent creates or updates the mailbox in the store.
|
||||
// This is called from the event loop.
|
||||
func (store *Store) createOrUpdateMailboxEvent(label *pmapi.Label) error {
|
||||
store.lock.Lock()
|
||||
defer store.lock.Unlock()
|
||||
|
||||
if label.Type != pmapi.LabelTypeMailbox {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := store.createOrUpdateMailboxCountsBuckets([]*pmapi.Label{label}); err != nil {
|
||||
return errors.Wrap(err, "cannot update counts")
|
||||
}
|
||||
|
||||
for _, a := range store.addresses {
|
||||
if err := a.createOrUpdateMailboxEvent(label); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteMailboxEvent deletes the mailbox in the store.
|
||||
// This is called from the event loop.
|
||||
func (store *Store) deleteMailboxEvent(labelID string) error {
|
||||
store.lock.Lock()
|
||||
defer store.lock.Unlock()
|
||||
|
||||
_ = store.removeMailboxCount(labelID)
|
||||
|
||||
for _, a := range store.addresses {
|
||||
if err := a.deleteMailboxEvent(labelID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
329
internal/store/user_message.go
Normal file
329
internal/store/user_message.go
Normal file
@ -0,0 +1,329 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/mail"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
|
||||
pmcrypto "github.com/ProtonMail/gopenpgp/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
// CreateDraft creates draft with attachments.
|
||||
// If `attachedPublicKey` is passed, it's added to attachments.
|
||||
// Both draft and attachments are encrypted with passed `kr` key.
|
||||
func (store *Store) CreateDraft(
|
||||
kr *pmcrypto.KeyRing,
|
||||
message *pmapi.Message,
|
||||
attachmentReaders []io.Reader,
|
||||
attachedPublicKey,
|
||||
attachedPublicKeyName string,
|
||||
parentID string) (*pmapi.Message, []*pmapi.Attachment, error) {
|
||||
defer store.eventLoop.pollNow()
|
||||
|
||||
// Since this is a draft, we don't need to sign it.
|
||||
if err := message.Encrypt(kr, nil); err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to encrypt draft")
|
||||
}
|
||||
|
||||
attachments := message.Attachments
|
||||
message.Attachments = nil
|
||||
|
||||
draftAction := store.getDraftAction(message)
|
||||
draft, err := store.api.CreateDraft(message, parentID, draftAction)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to create draft")
|
||||
}
|
||||
|
||||
if attachedPublicKey != "" {
|
||||
attachmentReaders = append(attachmentReaders, strings.NewReader(attachedPublicKey))
|
||||
publicKeyAttachment := &pmapi.Attachment{
|
||||
Name: attachedPublicKeyName + ".asc",
|
||||
MIMEType: "application/pgp-key",
|
||||
Header: textproto.MIMEHeader{},
|
||||
}
|
||||
attachments = append(attachments, publicKeyAttachment)
|
||||
}
|
||||
|
||||
for idx, attachment := range attachments {
|
||||
attachment.MessageID = draft.ID
|
||||
attachmentBody, _ := ioutil.ReadAll(attachmentReaders[idx])
|
||||
|
||||
createdAttachment, err := store.createAttachment(kr, attachment, attachmentBody)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "failed to create attachment for draft")
|
||||
}
|
||||
|
||||
attachments[idx] = createdAttachment
|
||||
}
|
||||
|
||||
return draft, attachments, nil
|
||||
}
|
||||
|
||||
func (store *Store) getDraftAction(message *pmapi.Message) int {
|
||||
// If not a reply, must be a forward.
|
||||
if len(message.Header["In-Reply-To"]) == 0 {
|
||||
return pmapi.DraftActionForward
|
||||
}
|
||||
return pmapi.DraftActionReply
|
||||
}
|
||||
|
||||
func (store *Store) createAttachment(kr *pmcrypto.KeyRing, attachment *pmapi.Attachment, attachmentBody []byte) (*pmapi.Attachment, error) {
|
||||
r := bytes.NewReader(attachmentBody)
|
||||
sigReader, err := attachment.DetachedSign(kr, r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to sign attachment")
|
||||
}
|
||||
|
||||
r = bytes.NewReader(attachmentBody)
|
||||
encReader, err := attachment.Encrypt(kr, r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to encrypt attachment")
|
||||
}
|
||||
|
||||
createdAttachment, err := store.api.CreateAttachment(attachment, encReader, sigReader)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create attachment")
|
||||
}
|
||||
|
||||
return createdAttachment, nil
|
||||
}
|
||||
|
||||
// SendMessage sends the message.
|
||||
func (store *Store) SendMessage(messageID string, req *pmapi.SendMessageReq) error {
|
||||
defer store.eventLoop.pollNow()
|
||||
_, _, err := store.api.SendMessage(messageID, req)
|
||||
return err
|
||||
}
|
||||
|
||||
// getAllMessageIDs returns all API IDs of messages in the local database.
|
||||
func (store *Store) getAllMessageIDs() (apiIDs []string, err error) {
|
||||
err = store.db.View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(metadataBucket)
|
||||
return b.ForEach(func(k, v []byte) error {
|
||||
apiIDs = append(apiIDs, string(k))
|
||||
return nil
|
||||
})
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// getMessageFromDB returns pmapi struct of message by API ID.
|
||||
func (store *Store) getMessageFromDB(apiID string) (msg *pmapi.Message, err error) {
|
||||
err = store.db.View(func(tx *bolt.Tx) error {
|
||||
msg, err = store.txGetMessage(tx, apiID)
|
||||
return err
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// fetchMessage returns pmapi struct of message by API ID. If the requested
|
||||
// message is not in the database, it will try to fetch it from the server.
|
||||
// NOTE: Do not update the database here to prevent issues (extreme edge case).
|
||||
// The database will be updated by the event loop anyway.
|
||||
func (store *Store) fetchMessage(apiID string) (msg *pmapi.Message, err error) {
|
||||
if msg, err = store.api.GetMessage(apiID); err != nil {
|
||||
if err.Error() == "Message does not exist" {
|
||||
return nil, ErrNoSuchAPIID
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (store *Store) txGetMessage(tx *bolt.Tx, apiID string) (*pmapi.Message, error) {
|
||||
b := tx.Bucket(metadataBucket)
|
||||
|
||||
msgb := b.Get([]byte(apiID))
|
||||
if msgb == nil {
|
||||
return nil, ErrNoSuchAPIID
|
||||
}
|
||||
msg := &pmapi.Message{}
|
||||
if err := json.Unmarshal(msgb, msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (store *Store) txPutMessage(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message) error {
|
||||
b, err := json.Marshal(onlyMeta)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot marshall metadata")
|
||||
}
|
||||
err = metaBucket.Put([]byte(onlyMeta.ID), b)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot add to metadata bucket")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createOrUpdateMessageEvent is helper to create only one message with
|
||||
// createOrUpdateMessagesEvent.
|
||||
func (store *Store) createOrUpdateMessageEvent(msg *pmapi.Message) error {
|
||||
return store.createOrUpdateMessagesEvent([]*pmapi.Message{msg})
|
||||
}
|
||||
|
||||
// createOrUpdateMessagesEvent tries to create or update messages in database.
|
||||
// This function is optimised for insertion of many messages at once.
|
||||
// It calls createLabelsIfMissing if needed.
|
||||
func (store *Store) createOrUpdateMessagesEvent(msgs []*pmapi.Message) error { //nolint[funlen]
|
||||
store.log.WithField("msgs", msgs).Trace("Creating or updating messages in the store")
|
||||
|
||||
// Strip non meta first to reduce memory (no need to keep all old msg ID data during update).
|
||||
err := store.db.View(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket(metadataBucket)
|
||||
for _, msg := range msgs {
|
||||
clearNonMetadata(msg)
|
||||
txUpdateMetadaFromDB(b, msg, store.log)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affectedLabels := map[string]bool{}
|
||||
for _, m := range msgs {
|
||||
for _, l := range m.LabelIDs {
|
||||
affectedLabels[l] = true
|
||||
}
|
||||
}
|
||||
if err = store.createLabelsIfMissing(affectedLabels); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Updating metadata and mailboxes is not atomic, but this is OK.
|
||||
// The worst case scenario is we have metadata but not updated mailboxes
|
||||
// which is OK as without information in mailboxes IMAP we will never ask
|
||||
// for metadata. Also, when doing the operation again, it will simply
|
||||
// update the metadata.
|
||||
// The reason to split is efficiency--it's more memory efficient.
|
||||
|
||||
// Update metadata.
|
||||
err = store.db.Update(func(tx *bolt.Tx) error {
|
||||
metaBucket := tx.Bucket(metadataBucket)
|
||||
for _, msg := range msgs {
|
||||
err := store.txPutMessage(metaBucket, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update mailboxes.
|
||||
err = store.db.Update(func(tx *bolt.Tx) error {
|
||||
for _, a := range store.addresses {
|
||||
if err := a.txCreateOrUpdateMessages(tx, msgs); err != nil {
|
||||
store.log.WithError(err).Error("cannot update maiboxes")
|
||||
return errors.Wrap(err, "cannot add to mailboxes bucket")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// clearNonMetadata to not allow to store decrypted or encrypted data i.e. body
|
||||
// and attachments.
|
||||
func clearNonMetadata(onlyMeta *pmapi.Message) {
|
||||
onlyMeta.Body = ""
|
||||
onlyMeta.Attachments = nil
|
||||
}
|
||||
|
||||
// txUpdateMetadaFromDB changes the the onlyMeta data.
|
||||
// If there is stored message in metaBucket the size, header and MIMEType are
|
||||
// not changed if already set. To change these:
|
||||
// * size must be updated by Message.SetSize
|
||||
// * contentType and header must be updated by Message.SetContentTypeAndHeader
|
||||
func txUpdateMetadaFromDB(metaBucket *bolt.Bucket, onlyMeta *pmapi.Message, log *logrus.Entry) {
|
||||
// Size attribute on the server is counting encrypted data. We need to compute
|
||||
// "real" size of decrypted data. Negative values will be processed during fetch.
|
||||
onlyMeta.Size = -1
|
||||
|
||||
msgb := metaBucket.Get([]byte(onlyMeta.ID))
|
||||
if msgb == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// It is faster to unmarshal only the needed items.
|
||||
stored := &struct {
|
||||
Size int64
|
||||
Header string
|
||||
MIMEType string
|
||||
}{}
|
||||
if err := json.Unmarshal(msgb, stored); err != nil {
|
||||
log.WithError(err).
|
||||
Error("Fail to unmarshal from DB, metadata will be overwritten")
|
||||
return
|
||||
}
|
||||
|
||||
// Keep already calculated size and content type.
|
||||
onlyMeta.Size = stored.Size
|
||||
onlyMeta.MIMEType = stored.MIMEType
|
||||
if stored.Header != "" && stored.Header != "(No Header)" {
|
||||
tmpMsg, err := mail.ReadMessage(
|
||||
strings.NewReader(stored.Header + "\r\n\r\n"),
|
||||
)
|
||||
if err == nil {
|
||||
onlyMeta.Header = tmpMsg.Header
|
||||
} else {
|
||||
log.WithError(err).
|
||||
Error("Fail to parse, the header will be overwritten")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// deleteMessageEvent is helper to delete only one message with deleteMessagesEvent.
|
||||
func (store *Store) deleteMessageEvent(apiID string) error {
|
||||
return store.deleteMessagesEvent([]string{apiID})
|
||||
}
|
||||
|
||||
// deleteMessagesEvent deletes the message from metadata and all mailbox buckets.
|
||||
func (store *Store) deleteMessagesEvent(apiIDs []string) error {
|
||||
return store.db.Update(func(tx *bolt.Tx) error {
|
||||
for _, apiID := range apiIDs {
|
||||
if err := tx.Bucket(metadataBucket).Delete([]byte(apiID)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, a := range store.addresses {
|
||||
if err := a.txDeleteMessage(tx, apiID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
156
internal/store/user_message_test.go
Normal file
156
internal/store/user_message_test.go
Normal file
@ -0,0 +1,156 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"net/mail"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
a "github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetAllMessageIDs(t *testing.T) {
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.ArchiveLabel})
|
||||
insertMessage(t, m, "msg3", "Test message 3", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg4", "Test message 4", addrID1, 0, []string{})
|
||||
|
||||
checkAllMessageIDs(t, m, []string{"msg1", "msg2", "msg3", "msg4"})
|
||||
}
|
||||
|
||||
func TestGetMessageFromDB(t *testing.T) {
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
|
||||
tests := []struct{ msgID, wantErr string }{
|
||||
{"msg1", ""},
|
||||
{"msg2", "no such api id"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.msgID, func(t *testing.T) {
|
||||
msg, err := m.store.getMessageFromDB(tc.msgID)
|
||||
if tc.wantErr != "" {
|
||||
require.EqualError(t, err, tc.wantErr)
|
||||
} else {
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, tc.msgID, msg.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateOrUpdateMessageMetadata(t *testing.T) {
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
|
||||
msg, err := m.store.getMessageFromDB("msg1")
|
||||
require.Nil(t, err)
|
||||
|
||||
// Check non-meta and calculated data are cleared/empty.
|
||||
a.Equal(t, "", msg.Body)
|
||||
a.Equal(t, []*pmapi.Attachment(nil), msg.Attachments)
|
||||
a.Equal(t, int64(-1), msg.Size)
|
||||
a.Equal(t, "", msg.MIMEType)
|
||||
a.Equal(t, mail.Header(nil), msg.Header)
|
||||
|
||||
// Change the calculated data.
|
||||
wantSize := int64(42)
|
||||
wantMIMEType := "plain-text"
|
||||
wantHeader := mail.Header{
|
||||
"Key": []string{"value"},
|
||||
}
|
||||
|
||||
storeMsg, err := m.store.addresses[addrID1].mailboxes[pmapi.AllMailLabel].GetMessage("msg1")
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, storeMsg.SetSize(wantSize))
|
||||
require.Nil(t, storeMsg.SetContentTypeAndHeader(wantMIMEType, wantHeader))
|
||||
|
||||
// Check calculated data.
|
||||
msg, err = m.store.getMessageFromDB("msg1")
|
||||
require.Nil(t, err)
|
||||
a.Equal(t, wantSize, msg.Size)
|
||||
a.Equal(t, wantMIMEType, msg.MIMEType)
|
||||
a.Equal(t, wantHeader, msg.Header)
|
||||
|
||||
// Check calculated data are not overridden by reinsert.
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
|
||||
msg, err = m.store.getMessageFromDB("msg1")
|
||||
require.Nil(t, err)
|
||||
a.Equal(t, wantSize, msg.Size)
|
||||
a.Equal(t, wantMIMEType, msg.MIMEType)
|
||||
a.Equal(t, wantHeader, msg.Header)
|
||||
}
|
||||
|
||||
func TestDeleteMessage(t *testing.T) {
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel})
|
||||
|
||||
require.Nil(t, m.store.deleteMessageEvent("msg1"))
|
||||
|
||||
checkAllMessageIDs(t, m, []string{"msg2"})
|
||||
checkMailboxMessageIDs(t, m, pmapi.AllMailLabel, []wantID{{"msg2", 2}})
|
||||
}
|
||||
|
||||
func insertMessage(t *testing.T, m *mocksForStore, id, subject, sender string, unread int, labelIDs []string) { //nolint[unparam]
|
||||
msg := getTestMessage(id, subject, sender, unread, labelIDs)
|
||||
require.Nil(t, m.store.createOrUpdateMessageEvent(msg))
|
||||
}
|
||||
|
||||
func getTestMessage(id, subject, sender string, unread int, labelIDs []string) *pmapi.Message {
|
||||
address := &mail.Address{Address: sender}
|
||||
return &pmapi.Message{
|
||||
ID: id,
|
||||
Subject: subject,
|
||||
Unread: unread,
|
||||
Sender: address,
|
||||
ToList: []*mail.Address{address},
|
||||
LabelIDs: labelIDs,
|
||||
Size: 12345,
|
||||
Body: "body of message",
|
||||
Attachments: []*pmapi.Attachment{{
|
||||
ID: "attachment1",
|
||||
MessageID: id,
|
||||
Name: "attachment",
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func checkAllMessageIDs(t *testing.T, m *mocksForStore, wantIDs []string) {
|
||||
allIds, allErr := m.store.getAllMessageIDs()
|
||||
require.Nil(t, allErr)
|
||||
require.Equal(t, wantIDs, allIds)
|
||||
}
|
||||
247
internal/store/user_sync.go
Normal file
247
internal/store/user_sync.go
Normal file
@ -0,0 +1,247 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
const syncFinishTimeKey = "sync_state" // The original key was sync_state and we want to keep compatibility.
|
||||
const syncIDRangesKey = "id_ranges"
|
||||
const syncIDsToBeDeletedKey = "ids_to_be_deleted"
|
||||
|
||||
// updateCountsFromServer will download and set the counts.
|
||||
func (store *Store) updateCountsFromServer() error {
|
||||
counts, err := store.api.CountMessages("")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot update counts from server")
|
||||
}
|
||||
|
||||
return store.createOrUpdateOnAPICounts(counts)
|
||||
}
|
||||
|
||||
// isSynced checks whether DB counts are synced with provided counts from API.
|
||||
func (store *Store) isSynced(countsOnAPI []*pmapi.MessagesCount) (bool, error) {
|
||||
store.log.WithField("apiCounts", countsOnAPI).Debug("Checking whether store is synced")
|
||||
|
||||
// IMPORTANT: The countsOnAPI can contain duplicates due to event merge
|
||||
// (ie one label can be present multiple times). It is important to
|
||||
// process all counts before checking whether they are synced.
|
||||
if err := store.createOrUpdateOnAPICounts(countsOnAPI); err != nil {
|
||||
store.log.WithError(err).Error("Cannot update counts before check sync")
|
||||
return false, err
|
||||
}
|
||||
|
||||
allCounts, err := store.getOnAPICounts()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
store.lock.Lock()
|
||||
defer store.lock.Unlock()
|
||||
|
||||
countsAreOK := true
|
||||
for _, counts := range allCounts {
|
||||
total, unread := uint(0), uint(0)
|
||||
for _, address := range store.addresses {
|
||||
mbox, err := address.getMailboxByID(counts.LabelID)
|
||||
if err != nil {
|
||||
return false, errors.Wrapf(
|
||||
err,
|
||||
"cannot find mailbox for address %q",
|
||||
address.addressID,
|
||||
)
|
||||
}
|
||||
|
||||
mboxTot, mboxUnread, err := mbox.GetCounts()
|
||||
if err != nil {
|
||||
errW := errors.Wrap(err, "cannot count messages")
|
||||
store.log.
|
||||
WithError(errW).
|
||||
WithField("label", counts.LabelID).
|
||||
WithField("address", address.addressID).
|
||||
Error("IsSynced failed")
|
||||
return false, err
|
||||
}
|
||||
total += mboxTot
|
||||
unread += mboxUnread
|
||||
}
|
||||
|
||||
if total != counts.TotalOnAPI || unread != counts.UnreadOnAPI {
|
||||
store.log.WithFields(logrus.Fields{
|
||||
"label": counts.LabelID,
|
||||
"db-total": total,
|
||||
"db-unread": unread,
|
||||
"api-total": counts.TotalOnAPI,
|
||||
"api-unread": counts.UnreadOnAPI,
|
||||
}).Warning("counts differ")
|
||||
countsAreOK = false
|
||||
}
|
||||
}
|
||||
|
||||
return countsAreOK, nil
|
||||
}
|
||||
|
||||
// triggerSync starts a sync of complete user by syncing All Mail mailbox.
|
||||
// All Mail mailbox contains all messages, so we download all meta data needed
|
||||
// to generate any address/mailbox IMAP UIDs.
|
||||
// Sync state can be in three states:
|
||||
// * Nothing in database. For example when user logs in for the first time.
|
||||
// `triggerSync` will start full sync.
|
||||
// * Database has syncIDRangesKey and syncIDsToBeDeletedKey keys with data.
|
||||
// Sync is in progress or was interrupted. In later case when, `triggerSync`
|
||||
// will continue where it left off.
|
||||
// * Database has only syncStateKey with time when database was last synced.
|
||||
// `triggerSync` will reset it and start full sync again.
|
||||
func (store *Store) triggerSync() {
|
||||
syncState := store.loadSyncState()
|
||||
|
||||
// We first clear the last sync state in case this sync fails.
|
||||
syncState.clearFinishTime()
|
||||
|
||||
// We don't want sync to block.
|
||||
go func() {
|
||||
defer store.panicHandler.HandlePanic()
|
||||
|
||||
store.log.Debug("Store sync triggered")
|
||||
|
||||
store.lock.Lock()
|
||||
if store.isSyncRunning {
|
||||
store.lock.Unlock()
|
||||
store.log.Info("Store sync is already ongoing")
|
||||
return
|
||||
}
|
||||
store.isSyncRunning = true
|
||||
store.lock.Unlock()
|
||||
|
||||
defer func() {
|
||||
store.lock.Lock()
|
||||
store.isSyncRunning = false
|
||||
store.lock.Unlock()
|
||||
}()
|
||||
|
||||
store.log.WithField("isIncomplete", syncState.isIncomplete()).Info("Store sync started")
|
||||
|
||||
err := syncAllMail(store.panicHandler, store, store.api, syncState)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Store sync failed")
|
||||
return
|
||||
}
|
||||
|
||||
syncState.setFinishTime()
|
||||
}()
|
||||
}
|
||||
|
||||
// isSyncFinished returns whether the database has finished a sync.
|
||||
func (store *Store) isSyncFinished() (isSynced bool) {
|
||||
return store.loadSyncState().isFinished()
|
||||
}
|
||||
|
||||
// loadSyncState loads information about sync from database.
|
||||
// See `triggerSync` to learn more about possible states.
|
||||
func (store *Store) loadSyncState() *syncState {
|
||||
finishTime := int64(0)
|
||||
idRanges := []*syncIDRange{}
|
||||
idsToBeDeleted := []string{}
|
||||
|
||||
err := store.db.View(func(tx *bolt.Tx) (err error) {
|
||||
b := tx.Bucket(syncStateBucket)
|
||||
|
||||
finishTimeByte := b.Get([]byte(syncFinishTimeKey))
|
||||
if finishTimeByte != nil {
|
||||
finishTime, err = strconv.ParseInt(string(finishTimeByte), 10, 64)
|
||||
if err != nil {
|
||||
store.log.WithError(err).Error("Failed to unmarshal sync IDs ranges")
|
||||
}
|
||||
}
|
||||
|
||||
idRangesData := b.Get([]byte(syncIDRangesKey))
|
||||
if idRangesData != nil {
|
||||
if err := json.Unmarshal(idRangesData, &idRanges); err != nil {
|
||||
store.log.WithError(err).Error("Failed to unmarshal sync IDs ranges")
|
||||
}
|
||||
}
|
||||
|
||||
idsToBeDeletedData := b.Get([]byte(syncIDsToBeDeletedKey))
|
||||
if idsToBeDeletedData != nil {
|
||||
if err := json.Unmarshal(idsToBeDeletedData, &idsToBeDeleted); err != nil {
|
||||
store.log.WithError(err).Error("Failed to unmarshal sync IDs to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
store.log.WithError(err).Error("Failed to load sync state")
|
||||
}
|
||||
|
||||
return newSyncState(store, finishTime, idRanges, idsToBeDeleted)
|
||||
}
|
||||
|
||||
// saveSyncState saves information about sync to database.
|
||||
// See `triggerSync` to learn more about possible states.
|
||||
func (store *Store) saveSyncState(finishTime int64, idRanges []*syncIDRange, idsToBeDeleted []string) {
|
||||
idRangesData, err := json.Marshal(idRanges)
|
||||
if err != nil {
|
||||
store.log.WithError(err).Error("Failed to marshall sync IDs ranges")
|
||||
}
|
||||
|
||||
idsToBeDeletedData, err := json.Marshal(idsToBeDeleted)
|
||||
if err != nil {
|
||||
store.log.WithError(err).Error("Failed to marshall sync IDs to be deleted")
|
||||
}
|
||||
|
||||
err = store.db.Update(func(tx *bolt.Tx) (err error) {
|
||||
b := tx.Bucket(syncStateBucket)
|
||||
if finishTime != 0 {
|
||||
curTime := []byte(fmt.Sprintf("%v", finishTime))
|
||||
if err := b.Put([]byte(syncFinishTimeKey), curTime); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := b.Delete([]byte(syncIDRangesKey)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := b.Delete([]byte(syncIDsToBeDeletedKey)); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := b.Delete([]byte(syncFinishTimeKey)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := b.Put([]byte(syncIDRangesKey), idRangesData); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := b.Put([]byte(syncIDsToBeDeletedKey), idsToBeDeletedData); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
store.log.WithError(err).Error("Failed to set sync state")
|
||||
}
|
||||
}
|
||||
90
internal/store/user_sync_test.go
Normal file
90
internal/store/user_sync_test.go
Normal file
@ -0,0 +1,90 @@
|
||||
// Copyright (c) 2020 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLoadSaveSyncState(t *testing.T) {
|
||||
m, clear := initMocks(t)
|
||||
defer clear()
|
||||
|
||||
m.newStoreNoEvents(true)
|
||||
insertMessage(t, m, "msg1", "Test message 1", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
insertMessage(t, m, "msg2", "Test message 2", addrID1, 0, []string{pmapi.AllMailLabel, pmapi.InboxLabel})
|
||||
|
||||
// Clear everything.
|
||||
|
||||
syncState := m.store.loadSyncState()
|
||||
syncState.clearFinishTime()
|
||||
|
||||
// Check everything is empty at the beginning.
|
||||
|
||||
syncState = m.store.loadSyncState()
|
||||
checkSyncStateAfterLoad(t, syncState, false, false, []string{})
|
||||
|
||||
// Save IDs ranges and check everything is also properly loaded.
|
||||
|
||||
syncState.initIDRanges()
|
||||
syncState.addIDRange("100")
|
||||
syncState.addIDRange("200")
|
||||
syncState.save()
|
||||
|
||||
syncState = m.store.loadSyncState()
|
||||
checkSyncStateAfterLoad(t, syncState, false, true, []string{})
|
||||
|
||||
// Save IDs to be deleted and check everything is properly loaded.
|
||||
|
||||
require.Nil(t, syncState.loadMessageIDsToBeDeleted())
|
||||
|
||||
syncState = m.store.loadSyncState()
|
||||
checkSyncStateAfterLoad(t, syncState, false, true, []string{"msg1", "msg2"})
|
||||
|
||||
// Set finish time and check everything is resetted to empty values.
|
||||
|
||||
syncState.setFinishTime()
|
||||
|
||||
syncState = m.store.loadSyncState()
|
||||
checkSyncStateAfterLoad(t, syncState, true, false, []string{})
|
||||
}
|
||||
|
||||
func checkSyncStateAfterLoad(t *testing.T, syncState *syncState, wantIsFinished bool, wantIDRanges bool, wantIDsToBeDeleted []string) {
|
||||
assert.Equal(t, wantIsFinished, syncState.isFinished())
|
||||
|
||||
if wantIDRanges {
|
||||
require.Equal(t, 3, len(syncState.idRanges))
|
||||
assert.Equal(t, "", syncState.idRanges[0].StartID)
|
||||
assert.Equal(t, "100", syncState.idRanges[0].StopID)
|
||||
assert.Equal(t, "100", syncState.idRanges[1].StartID)
|
||||
assert.Equal(t, "200", syncState.idRanges[1].StopID)
|
||||
assert.Equal(t, "200", syncState.idRanges[2].StartID)
|
||||
assert.Equal(t, "", syncState.idRanges[2].StopID)
|
||||
} else {
|
||||
assert.Empty(t, syncState.idRanges)
|
||||
}
|
||||
|
||||
idsToBeDeleted := syncState.getIDsToBeDeleted()
|
||||
sort.Strings(idsToBeDeleted)
|
||||
assert.Equal(t, wantIDsToBeDeleted, idsToBeDeleted)
|
||||
}
|
||||
Reference in New Issue
Block a user