Files
proton-bridge/internal/services/imapservice/sync_state_provider.go
Leander Beernaert efbe84964f feat(GODT-2829): Integrate new sync service
Update imap service to use the new sync service.

The new sync state is stored as simple file on disk to avoid contention
with concurrent vault writes.
2023-08-25 15:21:00 +02:00

254 lines
5.4 KiB
Go

// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imapservice
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"sync"
"github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice"
"github.com/bradenaw/juniper/xmaps"
)
type SyncState struct {
filePath string
status syncservice.Status
lock sync.Mutex
}
var ErrInvalidSyncFileVersion = errors.New("invalid sync file version")
const SyncFileVersion = 1
type syncStateFile struct {
Version int
Data string
}
type syncFileVersion1 struct {
Status syncservice.Status
}
func NewSyncState(filePath string) (*SyncState, error) {
s := &SyncState{filePath: filePath, status: syncservice.DefaultStatus()}
if err := s.loadUnsafe(); err != nil {
return nil, err
}
return s, nil
}
func (s *SyncState) AddFailedMessageID(_ context.Context, ids ...string) error {
s.lock.Lock()
defer s.lock.Unlock()
count := len(s.status.FailedMessages)
for _, id := range ids {
s.status.FailedMessages.Add(id)
}
// Only update if something change.
if count == len(s.status.FailedMessages) {
return nil
}
return s.storeUnsafe()
}
func (s *SyncState) RemFailedMessageID(_ context.Context, ids ...string) error {
s.lock.Lock()
defer s.lock.Unlock()
count := len(s.status.FailedMessages)
for _, id := range ids {
s.status.FailedMessages.Remove(id)
}
// Only update if something change.
if count == len(s.status.FailedMessages) {
return nil
}
return s.storeUnsafe()
}
func (s *SyncState) GetSyncStatus(_ context.Context) (syncservice.Status, error) {
s.lock.Lock()
defer s.lock.Unlock()
return s.status, nil
}
func (s *SyncState) ClearSyncStatus(_ context.Context) error {
s.lock.Lock()
defer s.lock.Unlock()
oldStatus := s.status
s.status = syncservice.DefaultStatus()
if err := s.storeUnsafe(); err != nil {
s.status = oldStatus
return err
}
return nil
}
func (s *SyncState) SetHasLabels(_ context.Context, b bool) error {
s.lock.Lock()
defer s.lock.Unlock()
s.status.HasLabels = b
return s.storeUnsafe()
}
func (s *SyncState) SetHasMessages(_ context.Context, b bool) error {
s.lock.Lock()
defer s.lock.Unlock()
s.status.HasMessages = b
return s.storeUnsafe()
}
func (s *SyncState) SetLastMessageID(_ context.Context, s2 string, i int64) error {
s.lock.Lock()
defer s.lock.Unlock()
s.status.LastSyncedMessageID = s2
s.status.NumSyncedMessages += i
return s.storeUnsafe()
}
func (s *SyncState) SetMessageCount(_ context.Context, i int64) error {
s.lock.Lock()
defer s.lock.Unlock()
s.status.TotalMessageCount = i
s.status.HasMessageCount = true
return s.storeUnsafe()
}
func (s *SyncState) storeUnsafe() error {
return storeImpl(&s.status, s.filePath)
}
func storeImpl(status *syncservice.Status, path string) error {
data, err := json.Marshal(syncFileVersion1{Status: *status})
if err != nil {
return fmt.Errorf("failed to marshal sync state data: %w", err)
}
syncFile := syncStateFile{
Version: SyncFileVersion,
Data: string(data),
}
syncFileData, err := json.Marshal(syncFile)
if err != nil {
return fmt.Errorf("failde to marshal sync state file: %w", err)
}
tmpFile := path + ".tmp"
if err := os.WriteFile(tmpFile, syncFileData, 0o600); err != nil {
return fmt.Errorf("failed to write sync state to tmp file: %w", err)
}
if err := os.Rename(tmpFile, path); err != nil {
return fmt.Errorf("failed to update sync state: %w", err)
}
return nil
}
func (s *SyncState) loadUnsafe() error {
data, err := os.ReadFile(s.filePath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
return err
}
var syncFile syncStateFile
if err := json.Unmarshal(data, &syncFile); err != nil {
return fmt.Errorf("failed to unmarshal sync file: %w", err)
}
if syncFile.Version != SyncFileVersion {
return ErrInvalidSyncFileVersion
}
var v1 syncFileVersion1
if err := json.Unmarshal([]byte(syncFile.Data), &v1); err != nil {
return fmt.Errorf("failed to unmarshal sync data: %w", err)
}
s.status = v1.Status
return nil
}
func DeleteSyncState(configDir, userID string) error {
path := getSyncConfigPath(configDir, userID)
return os.Remove(path)
}
func MigrateVaultSettings(
configDir, userID string,
hasLabels, hasMessages bool,
failedMessageIDs []string,
) (bool, error) {
filePath := getSyncConfigPath(configDir, userID)
_, err := os.ReadFile(filePath) //nolint:gosec
if err == nil {
// File already exists, sync has been migrated.
return false, nil
}
if err != nil && !errors.Is(err, os.ErrNotExist) {
// unexpected error occurred.
return false, err
}
status := syncservice.DefaultStatus()
status.HasLabels = hasLabels
status.HasMessages = hasMessages
status.HasMessageCount = hasMessages
status.FailedMessages = xmaps.SetFromSlice(failedMessageIDs)
return true, storeImpl(&status, filePath)
}