Files
proton-bridge/internal/services/imapservice/sync_state_provider.go
Leander Beernaert 951c7c27fb fix(GODT-3003): Ensure IMAP State is reset after vault corruption
After we detect that the user has been affected by the GODT-3003 bug, caused by
the vault corruption not ensuring that a previous sync state was
erased, we patch the gluon db directly and then reset the sync state.

After the account is added, the sync is automatically triggered and the
account state fixes itself.
2023-10-09 11:19:36 +01:00

258 lines
5.5 KiB
Go

// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package imapservice
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"sync"
"github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice"
"github.com/bradenaw/juniper/xmaps"
)
// SyncState holds a syncservice.Status in memory and mirrors every change to a
// file on disk. All fields below are guarded by lock.
type SyncState struct {
	// filePath is where the serialized status is written and read from.
	filePath string
	// status is the current in-memory sync status.
	status syncservice.Status
	// lock serializes access to status and to the file at filePath.
	lock sync.Mutex
}
// ErrInvalidSyncFileVersion is returned when a sync state file on disk carries a
// version other than SyncFileVersion.
var ErrInvalidSyncFileVersion = errors.New("invalid sync file version")

// SyncFileVersion is the version tag written into, and required from, the sync
// state file.
const SyncFileVersion = 1

type (
	// syncStateFile is the on-disk envelope: a version tag plus the JSON-encoded
	// payload for that version, stored as a string.
	syncStateFile struct {
		Version int
		Data    string
	}

	// syncFileVersion1 is the payload stored in syncStateFile.Data when Version
	// is 1.
	syncFileVersion1 struct {
		Status syncservice.Status
	}
)
// NewSyncState creates a SyncState backed by filePath. It starts from
// syncservice.DefaultStatus() and overlays whatever is stored in the file; a
// missing file leaves the default status in place. Any load error is returned
// and no SyncState is produced.
func NewSyncState(filePath string) (*SyncState, error) {
	state := &SyncState{
		filePath: filePath,
		status:   syncservice.DefaultStatus(),
	}

	err := state.loadUnsafe()
	if err != nil {
		return nil, err
	}

	return state, nil
}
// AddFailedMessageID adds ids to the FailedMessages set and persists the change.
// The file is only rewritten if at least one id was not already present.
//
// If persisting fails, the ids added by this call are removed again. This keeps
// the in-memory state in line with the file, so a retry with the same ids will
// attempt to store them again instead of being skipped as "no change".
func (s *SyncState) AddFailedMessageID(_ context.Context, ids ...string) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	var added []string

	for _, id := range ids {
		if _, ok := s.status.FailedMessages[id]; ok {
			continue
		}

		s.status.FailedMessages.Add(id)
		added = append(added, id)
	}

	// Only update if something changed.
	if len(added) == 0 {
		return nil
	}

	if err := s.storeUnsafe(); err != nil {
		for _, id := range added {
			s.status.FailedMessages.Remove(id)
		}

		return err
	}

	return nil
}
// RemFailedMessageID removes ids from the FailedMessages set and persists the
// change. The file is only rewritten if at least one id was actually present.
//
// If persisting fails, the ids removed by this call are re-added. This keeps the
// in-memory state in line with the file, so a retry with the same ids will
// attempt to store the removal again instead of being skipped as "no change".
func (s *SyncState) RemFailedMessageID(_ context.Context, ids ...string) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	var removed []string

	for _, id := range ids {
		if _, ok := s.status.FailedMessages[id]; !ok {
			continue
		}

		s.status.FailedMessages.Remove(id)
		removed = append(removed, id)
	}

	// Only update if something changed.
	if len(removed) == 0 {
		return nil
	}

	if err := s.storeUnsafe(); err != nil {
		for _, id := range removed {
			s.status.FailedMessages.Add(id)
		}

		return err
	}

	return nil
}
// GetSyncStatus returns a snapshot of the current sync status.
//
// The FailedMessages set is copied. AddFailedMessageID and RemFailedMessageID
// mutate the internal set in place, so handing out the internal map would let
// the caller race with them once s.lock is released.
func (s *SyncState) GetSyncStatus(_ context.Context) (syncservice.Status, error) {
	s.lock.Lock()
	defer s.lock.Unlock()

	failed := make([]string, 0, len(s.status.FailedMessages))
	for id := range s.status.FailedMessages {
		failed = append(failed, id)
	}

	status := s.status
	status.FailedMessages = xmaps.SetFromSlice(failed)

	return status, nil
}
// ClearSyncStatus resets the status to syncservice.DefaultStatus() and persists
// it. If persisting fails, the previous in-memory status is restored and the
// error is returned.
func (s *SyncState) ClearSyncStatus(_ context.Context) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	previous := s.status
	s.status = syncservice.DefaultStatus()

	err := s.storeUnsafe()
	if err != nil {
		s.status = previous
	}

	return err
}
// SetHasLabels sets the HasLabels flag and persists it. If persisting fails,
// the previous value is restored so memory and file stay consistent (matching
// ClearSyncStatus).
func (s *SyncState) SetHasLabels(_ context.Context, b bool) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	prev := s.status.HasLabels
	s.status.HasLabels = b

	if err := s.storeUnsafe(); err != nil {
		s.status.HasLabels = prev
		return err
	}

	return nil
}
// SetHasMessages sets the HasMessages flag and persists it. If persisting
// fails, the previous value is restored so memory and file stay consistent
// (matching ClearSyncStatus).
func (s *SyncState) SetHasMessages(_ context.Context, b bool) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	prev := s.status.HasMessages
	s.status.HasMessages = b

	if err := s.storeUnsafe(); err != nil {
		s.status.HasMessages = prev
		return err
	}

	return nil
}
// SetLastMessageID records s2 as the last synced message ID, adds i to
// NumSyncedMessages, and persists the result.
//
// If persisting fails, both fields are restored. Otherwise a caller retrying
// the same batch after an error would add i to NumSyncedMessages twice.
func (s *SyncState) SetLastMessageID(_ context.Context, s2 string, i int64) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	prevID, prevCount := s.status.LastSyncedMessageID, s.status.NumSyncedMessages

	s.status.LastSyncedMessageID = s2
	s.status.NumSyncedMessages += i

	if err := s.storeUnsafe(); err != nil {
		s.status.LastSyncedMessageID, s.status.NumSyncedMessages = prevID, prevCount
		return err
	}

	return nil
}
// SetMessageCount stores i as TotalMessageCount, marks HasMessageCount as true,
// and persists the result. If persisting fails, both fields are restored so
// memory and file stay consistent (matching ClearSyncStatus).
func (s *SyncState) SetMessageCount(_ context.Context, i int64) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	prevTotal, prevHas := s.status.TotalMessageCount, s.status.HasMessageCount

	s.status.TotalMessageCount = i
	s.status.HasMessageCount = true

	if err := s.storeUnsafe(); err != nil {
		s.status.TotalMessageCount, s.status.HasMessageCount = prevTotal, prevHas
		return err
	}

	return nil
}
// storeUnsafe writes the current in-memory status to s.filePath.
// It takes no lock itself; call it with s.lock held.
func (s *SyncState) storeUnsafe() error {
	current := &s.status
	return storeImpl(current, s.filePath)
}
// storeImpl serializes status into the versioned sync state format and writes
// it to path.
//
// The status is wrapped in syncFileVersion1, JSON-encoded, and embedded as a
// string in a syncStateFile tagged with SyncFileVersion. The bytes are first
// written to path+".tmp" and then renamed over path. As a result, a failure
// while writing the temp file leaves any existing file at path untouched.
// On failure the temp file is removed on a best-effort basis.
func storeImpl(status *syncservice.Status, path string) error {
	data, err := json.Marshal(syncFileVersion1{Status: *status})
	if err != nil {
		return fmt.Errorf("failed to marshal sync state data: %w", err)
	}

	syncFileData, err := json.Marshal(syncStateFile{
		Version: SyncFileVersion,
		Data:    string(data),
	})
	if err != nil {
		return fmt.Errorf("failed to marshal sync state file: %w", err)
	}

	tmpFile := path + ".tmp"

	if err := os.WriteFile(tmpFile, syncFileData, 0o600); err != nil {
		// Best effort: don't leave a partially written temp file behind. The
		// write error is the one worth reporting.
		_ = os.Remove(tmpFile)
		return fmt.Errorf("failed to write sync state to tmp file: %w", err)
	}

	if err := os.Rename(tmpFile, path); err != nil {
		// Best effort cleanup; the rename error is the one worth reporting.
		_ = os.Remove(tmpFile)
		return fmt.Errorf("failed to update sync state: %w", err)
	}

	return nil
}
// loadUnsafe replaces s.status with the status stored at s.filePath.
//
// A missing file is not an error and leaves s.status unchanged. It returns
// ErrInvalidSyncFileVersion if the file's version is not SyncFileVersion, and a
// wrapped error if the file cannot be read or decoded. It takes no lock itself;
// call it with s.lock held, or before s is shared.
func (s *SyncState) loadUnsafe() error {
	data, err := os.ReadFile(s.filePath)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return nil
		}

		return fmt.Errorf("failed to read sync file: %w", err)
	}

	var syncFile syncStateFile
	if err := json.Unmarshal(data, &syncFile); err != nil {
		return fmt.Errorf("failed to unmarshal sync file: %w", err)
	}

	if syncFile.Version != SyncFileVersion {
		return ErrInvalidSyncFileVersion
	}

	var v1 syncFileVersion1
	if err := json.Unmarshal([]byte(syncFile.Data), &v1); err != nil {
		return fmt.Errorf("failed to unmarshal sync data: %w", err)
	}

	// A payload whose FailedMessages is null or absent decodes to a nil set.
	// AddFailedMessageID would then panic when it writes to that nil map, so
	// normalize it to an empty set.
	if v1.Status.FailedMessages == nil {
		v1.Status.FailedMessages = xmaps.SetFromSlice[string](nil)
	}

	s.status = v1.Status

	return nil
}
// DeleteSyncState removes the sync state file for userID under configDir.
// A file that does not exist is not treated as an error.
func DeleteSyncState(configDir, userID string) error {
	err := os.Remove(GetSyncConfigPath(configDir, userID))
	if errors.Is(err, os.ErrNotExist) {
		return nil
	}

	return err
}
// MigrateVaultSettings creates the sync state file for userID under configDir
// from the given vault-derived settings.
//
// If a sync state file already exists, nothing is written and (false, nil) is
// returned. If checking for the file fails for any reason other than it not
// existing, (false, err) is returned. Otherwise a new file is written and
// (true, err) is returned, where err is the result of the write.
func MigrateVaultSettings(
	configDir, userID string,
	hasLabels, hasMessages bool,
	failedMessageIDs []string,
) (bool, error) {
	filePath := GetSyncConfigPath(configDir, userID)

	// Only existence matters here, so stat the file instead of reading it.
	_, err := os.Stat(filePath)

	switch {
	case err == nil:
		// File already exists, sync has been migrated.
		return false, nil
	case !errors.Is(err, os.ErrNotExist):
		// Unexpected error occurred.
		return false, fmt.Errorf("failed to check sync state file: %w", err)
	}

	status := syncservice.DefaultStatus()
	status.HasLabels = hasLabels
	status.HasMessages = hasMessages
	// HasMessageCount mirrors hasMessages; TotalMessageCount keeps its default.
	status.HasMessageCount = hasMessages
	status.FailedMessages = xmaps.SetFromSlice(failedMessageIDs)

	return true, storeImpl(&status, filePath)
}