Files
proton-bridge/internal/user/sync_downloader_test.go
2023-07-31 15:21:53 +02:00

473 lines
14 KiB
Go

// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package user
import (
"context"
"fmt"
"io"
"strings"
"testing"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/user/mocks"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
// TestSyncDownloader_Parallel_429 checks that a 429 (Too Many Requests) error
// is caught during the parallel download stage and that the per-message
// download state is recorded correctly:
//   - Msg1: message and its single attachment download without error.
//   - Msg2: the message body request fails with 429.
//   - Msg3: the body downloads, but one of its two attachments fails with 429.
func TestSyncDownloader_Parallel_429(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
	panicHandler := &async.NoopPanicHandler{}
	ctx := context.Background()

	requests := downloadRequest{
		ids:          []string{"Msg1", "Msg2", "Msg3"},
		expectedSize: 0,
		err:          nil,
	}

	const attachmentData = "attachment data"

	// serveAttachment stands in for a successful GetAttachmentInto call by
	// streaming attachmentData into the caller-supplied io.ReaderFrom.
	serveAttachment := func(_ context.Context, _ string, r io.ReaderFrom) error {
		_, err := r.ReadFrom(strings.NewReader(attachmentData))
		return err
	}

	messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg1")).Times(1).Return(proton.Message{
		MessageMetadata: proton.MessageMetadata{
			ID:             "MsgID1",
			NumAttachments: 1,
		},
		Attachments: []proton.Attachment{
			{ID: "Attachment1_1"},
		},
	}, nil)
	messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg2")).Times(1).Return(proton.Message{}, &proton.APIError{Status: 429})
	messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(1).Return(proton.Message{
		MessageMetadata: proton.MessageMetadata{
			ID:             "MsgID3",
			NumAttachments: 2,
		},
		Attachments: []proton.Attachment{
			{ID: "Attachment3_1"},
			{ID: "Attachment3_2"},
		},
	}, nil)

	messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("Attachment1_1"), gomock.Any()).Times(1).DoAndReturn(serveAttachment)
	messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("Attachment3_1"), gomock.Any()).Times(1).Return(&proton.APIError{Status: 429})
	messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("Attachment3_2"), gomock.Any()).Times(1).DoAndReturn(serveAttachment)

	cache := newSyncDownloadCache()
	attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, cache, 1)
	defer attachmentDownloader.close()

	result, err := downloadMessagesParallel(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, cache, 1)
	require.NoError(t, err)
	require.Len(t, result, 3)

	// Attachment payloads are converted to string before comparing: testify
	// never treats a string and a []byte as equal, so comparing attachmentData
	// with the raw slice would not check its contents at all.

	// Check message 1: fully downloaded.
	require.Equal(t, downloadStateFinished, result[0].State)
	require.Equal(t, "MsgID1", result[0].Message.ID)
	require.NotEmpty(t, result[0].Message.AttData)
	require.Equal(t, attachmentData, string(result[0].Message.AttData[0]))
	require.Nil(t, result[0].err)

	// Check message 2: the body request failed, so nothing was recorded.
	require.Equal(t, downloadStateZero, result[1].State)
	require.Empty(t, result[1].Message.ID)
	require.NotNil(t, result[1].err)

	// Check message 3: body downloaded, first attachment failed with 429,
	// second attachment downloaded.
	require.Equal(t, downloadStateHasMessage, result[2].State)
	require.Equal(t, "MsgID3", result[2].Message.ID)
	require.Len(t, result[2].Message.AttData, 2)
	require.NotNil(t, result[2].err)
	require.Nil(t, result[2].Message.AttData[0])
	require.Equal(t, attachmentData, string(result[2].Message.AttData[1]))

	// Successfully downloaded items must have been cached.
	_, ok := cache.GetMessage("MsgID1")
	require.True(t, ok)
	_, ok = cache.GetMessage("MsgID3")
	require.True(t, ok)
	att, ok := cache.GetAttachment("Attachment1_1")
	require.True(t, ok)
	require.Equal(t, attachmentData, string(att))
}
// TestSyncDownloader_Stage2_Everything200 verifies that the sequential stage
// passes through results that the parallel stage already finished. No
// expectations are registered on the mock downloader, so any download attempt
// would be reported by gomock as an unexpected call. One message is expected
// per input entry.
func TestSyncDownloader_Stage2_Everything200(t *testing.T) {
	ctrl := gomock.NewController(t)
	downloader := mocks.NewMockMessageDownloader(ctrl)
	syncCache := newSyncDownloadCache()

	finished := []downloadResult{
		{ID: "Msg1", State: downloadStateFinished},
		{ID: "Msg2", State: downloadStateFinished},
	}

	messages, err := downloadMessagesSequential(context.Background(), finished, downloader, syncCache, &noCooldown{})
	require.NoError(t, err)
	require.Len(t, messages, 2)
}
// TestSyncDownloader_Stage2_Not429 verifies that when an entry carries an error
// from the parallel stage that is not a 429, the sequential stage returns that
// same error. No mock expectations are registered, so a retry attempt would be
// reported by gomock as an unexpected call.
func TestSyncDownloader_Stage2_Not429(t *testing.T) {
	ctrl := gomock.NewController(t)
	downloader := mocks.NewMockMessageDownloader(ctrl)
	syncCache := newSyncDownloadCache()

	errNot429 := fmt.Errorf("something not 429")
	pending := []downloadResult{
		{ID: "Msg1", State: downloadStateFinished},
		{ID: "Msg2", State: downloadStateHasMessage, err: errNot429},
		{ID: "Msg3", State: downloadStateFinished},
	}

	_, err := downloadMessagesSequential(context.Background(), pending, downloader, syncCache, &noCooldown{})
	require.Error(t, err)
	require.Equal(t, errNot429, err)
}
// TestSyncDownloader_Stage2_API500 verifies that a 500 API error left over from
// the parallel stage is returned by the sequential stage as-is. No mock
// expectations are registered, so a retry attempt would be reported by gomock
// as an unexpected call.
func TestSyncDownloader_Stage2_API500(t *testing.T) {
	ctrl := gomock.NewController(t)
	downloader := mocks.NewMockMessageDownloader(ctrl)
	syncCache := newSyncDownloadCache()

	err500 := &proton.APIError{Status: 500}
	pending := []downloadResult{
		{ID: "Msg2", State: downloadStateHasMessage, err: err500},
		{ID: "Msg3", State: downloadStateFinished},
	}

	_, err := downloadMessagesSequential(context.Background(), pending, downloader, syncCache, &noCooldown{})
	require.Error(t, err)
	require.Equal(t, err500, err)
}
// TestSyncDownloader_Stage2_Some429 verifies that the sequential stage retries
// everything the parallel stage left unfinished because of 429 errors. It must
// fill in the missing pieces in both the returned messages and the cache:
//   - Msg1 has its body and attachment A4, but is missing A3.
//   - Msg2 has its body but none of its attachments (A1, A2); A1 fails with
//     429 twice before succeeding.
//   - Msg3 is missing entirely; its body fails with 429 twice before succeeding.
func TestSyncDownloader_Stage2_Some429(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
	ctx := context.Background()
	cache := newSyncDownloadCache()

	const (
		attachmentData1 = "attachment data 1"
		attachmentData2 = "attachment data 2"
		attachmentData3 = "attachment data 3"
		attachmentData4 = "attachment data 4"
	)

	err429 := &proton.APIError{Status: 429}

	// serveAttachment returns a GetAttachmentInto stand-in that simulates a
	// successful download by streaming data into the caller's io.ReaderFrom.
	serveAttachment := func(data string) func(context.Context, string, io.ReaderFrom) error {
		return func(_ context.Context, _ string, r io.ReaderFrom) error {
			_, err := r.ReadFrom(strings.NewReader(data))
			return err
		}
	}

	// Named state rather than downloadResult so the downloadResult type is not
	// shadowed for the remainder of the function.
	state := []downloadResult{
		{
			// Full message, but missing 1 of 2 attachments (A3).
			ID: "Msg1",
			Message: proton.FullMessage{
				Message: proton.Message{
					MessageMetadata: proton.MessageMetadata{
						ID:             "Msg1",
						NumAttachments: 2,
					},
					Attachments: []proton.Attachment{
						{ID: "A3"},
						{ID: "A4"},
					},
				},
				AttData: [][]byte{
					nil,
					[]byte(attachmentData4),
				},
			},
			State: downloadStateHasMessage,
			err:   err429,
		},
		{
			// Full message, but missing all attachments.
			ID: "Msg2",
			Message: proton.FullMessage{
				Message: proton.Message{
					MessageMetadata: proton.MessageMetadata{
						ID:             "Msg2",
						NumAttachments: 2,
					},
					Attachments: []proton.Attachment{
						{ID: "A1"},
						{ID: "A2"},
					},
				},
				AttData: nil,
			},
			State: downloadStateHasMessage,
			err:   err429,
		},
		{
			// Missing everything.
			ID:    "Msg3",
			State: downloadStateZero,
			Message: proton.FullMessage{
				Message: proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg3"}},
			},
			err: err429,
		},
	}

	// Message 3 body: two 429 failures, then success.
	msg3Failures := messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(2).Return(proton.Message{}, err429)
	messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).After(msg3Failures).Times(1).Return(proton.Message{
		MessageMetadata: proton.MessageMetadata{
			ID: "Msg3",
		},
	}, nil)

	// Message 2 attachments: A1 fails twice with 429 before succeeding; A2
	// succeeds on the first attempt.
	a1Failures := messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A1"), gomock.Any()).Times(2).Return(err429)
	messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A1"), gomock.Any()).After(a1Failures).Times(1).DoAndReturn(serveAttachment(attachmentData1))
	messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A2"), gomock.Any()).Times(1).DoAndReturn(serveAttachment(attachmentData2))

	// Message 1: the missing attachment A3 succeeds on the first attempt.
	messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A3"), gomock.Any()).Times(1).DoAndReturn(serveAttachment(attachmentData3))

	messages, err := downloadMessagesSequential(ctx, state, messageDownloader, cache, &noCooldown{})
	require.NoError(t, err)
	require.Len(t, messages, 3)
	require.Equal(t, "Msg1", messages[0].Message.ID)
	require.Equal(t, "Msg2", messages[1].Message.ID)
	require.Equal(t, "Msg3", messages[2].Message.ID)

	// Check attachments: data already present is kept, missing data is filled in.
	require.Equal(t, attachmentData3, string(messages[0].AttData[0]))
	require.Equal(t, attachmentData4, string(messages[0].AttData[1]))
	require.Equal(t, attachmentData1, string(messages[1].AttData[0]))
	require.Equal(t, attachmentData2, string(messages[1].AttData[1]))
	require.Empty(t, messages[2].AttData)

	// Everything fetched during this stage must also have been cached.
	_, ok := cache.GetMessage("Msg3")
	require.True(t, ok)

	att3, ok := cache.GetAttachment("A3")
	require.True(t, ok)
	require.Equal(t, attachmentData3, string(att3))

	att1, ok := cache.GetAttachment("A1")
	require.True(t, ok)
	require.Equal(t, attachmentData1, string(att1))

	att2, ok := cache.GetAttachment("A2")
	require.True(t, ok)
	require.Equal(t, attachmentData2, string(att2))
}
// TestSyncDownloader_Stage2_ErrorOnNon429MessageDownload verifies that the
// sequential stage fails when re-downloading a message body returns a non-429
// error (here a 500). In that case no messages are returned.
func TestSyncDownloader_Stage2_ErrorOnNon429MessageDownload(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
	ctx := context.Background()
	cache := newSyncDownloadCache()

	err429 := &proton.APIError{Status: 429}
	err500 := &proton.APIError{Status: 500}

	state := []downloadResult{
		{
			// Missing everything.
			ID:    "Msg3",
			State: downloadStateZero,
			Message: proton.FullMessage{
				Message: proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg3"}},
			},
			err: err429,
		},
		{
			// Full message, but missing both attachments.
			ID: "Msg1",
			Message: proton.FullMessage{
				Message: proton.Message{
					MessageMetadata: proton.MessageMetadata{
						ID:             "Msg1",
						NumAttachments: 2,
					},
					Attachments: []proton.Attachment{
						{ID: "A3"},
						{ID: "A4"},
					},
				},
			},
			State: downloadStateHasMessage,
			err:   err429,
		},
	}

	// A single 500 for the message 3 body; Times(1) means a retry would fail
	// the test. No expectations are registered for Msg1's attachments, so gomock
	// also reports any attempt to continue past the failed message.
	messageDownloader.EXPECT().GetMessage(gomock.Any(), gomock.Eq("Msg3")).Times(1).Return(proton.Message{}, err500)

	messages, err := downloadMessagesSequential(ctx, state, messageDownloader, cache, &noCooldown{})
	require.Error(t, err)
	require.Empty(t, messages)
}
// TestSyncDownloader_Stage2_ErrorOnNon429AttachmentDownload verifies that the
// sequential stage fails when re-downloading a message's missing attachments
// hits a non-429 error (here a 500). In that case no messages are returned.
func TestSyncDownloader_Stage2_ErrorOnNon429AttachmentDownload(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
	ctx := context.Background()
	cache := newSyncDownloadCache()

	err429 := &proton.APIError{Status: 429}
	err500 := &proton.APIError{Status: 500}

	state := []downloadResult{
		{
			// Full message, but missing both attachments.
			ID: "Msg1",
			Message: proton.FullMessage{
				Message: proton.Message{
					MessageMetadata: proton.MessageMetadata{
						ID:             "Msg1",
						NumAttachments: 2,
					},
					Attachments: []proton.Attachment{
						{ID: "A3"},
						{ID: "A4"},
					},
				},
			},
			State: downloadStateHasMessage,
			err:   err429,
		},
	}

	// Each attachment is expected exactly once. The 500 on A4 must end the
	// stage rather than lead to a retry of the 429 on A3.
	// 429 for first attachment.
	messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A3"), gomock.Any()).Times(1).Return(err429)
	// 500 for second attachment.
	messageDownloader.EXPECT().GetAttachmentInto(gomock.Any(), gomock.Eq("A4"), gomock.Any()).Times(1).Return(err500)

	messages, err := downloadMessagesSequential(ctx, state, messageDownloader, cache, &noCooldown{})
	require.Error(t, err)
	require.Empty(t, messages)
}
// TestSyncDownloader_Parallel_DoNotDownloadIfAlreadyInCache verifies that the
// parallel stage takes messages and attachments from the sync cache when they
// are already present. No expectations are registered on the mock downloader,
// so gomock fails the test if any download is attempted.
func TestSyncDownloader_Parallel_DoNotDownloadIfAlreadyInCache(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	messageDownloader := mocks.NewMockMessageDownloader(mockCtrl)
	panicHandler := &async.NoopPanicHandler{}
	ctx := context.Background()

	requests := downloadRequest{
		ids:          []string{"Msg1", "Msg3"},
		expectedSize: 0,
		err:          nil,
	}

	cache := newSyncDownloadCache()
	attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, messageDownloader, cache, 1)
	defer attachmentDownloader.close()

	const attachmentData = "attachment data"

	// NOTE(review): Msg3 declares NumAttachments: 2 but lists a single
	// attachment (A2); confirm whether that mismatch is intentional.
	cache.StoreMessage(proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg1", NumAttachments: 1}, Attachments: []proton.Attachment{{ID: "A1"}}})
	cache.StoreMessage(proton.Message{MessageMetadata: proton.MessageMetadata{ID: "Msg3", NumAttachments: 2}, Attachments: []proton.Attachment{{ID: "A2"}}})
	cache.StoreAttachment("A1", []byte(attachmentData))
	cache.StoreAttachment("A2", []byte(attachmentData))

	result, err := downloadMessagesParallel(ctx, panicHandler, requests, messageDownloader, attachmentDownloader, cache, 1)
	require.NoError(t, err)
	require.Len(t, result, 2)

	// Attachment payloads are converted to string before comparing: testify
	// never treats a string and a []byte as equal, so comparing attachmentData
	// with the raw slice would not check its contents at all.
	require.Equal(t, downloadStateFinished, result[0].State)
	require.Equal(t, "Msg1", result[0].Message.ID)
	require.NotEmpty(t, result[0].Message.AttData)
	require.Equal(t, attachmentData, string(result[0].Message.AttData[0]))
	require.Nil(t, result[0].err)

	require.Equal(t, downloadStateFinished, result[1].State)
	require.Equal(t, "Msg3", result[1].Message.ID)
	require.NotEmpty(t, result[1].Message.AttData)
	require.Equal(t, attachmentData, string(result[1].Message.AttData[0]))
	require.Nil(t, result[1].err)
}