GODT-1779: Remove go-imap

This commit is contained in:
James Houlahan
2022-08-26 17:00:21 +02:00
parent 3b0bc1ca15
commit 39433fe707
593 changed files with 12725 additions and 91626 deletions

View File

@ -92,8 +92,7 @@ func TestRemoveWithExceptions(t *testing.T) {
}
func newTestDir(t *testing.T, subdirs ...string) string {
dir, err := os.MkdirTemp("", "test-files-dir")
require.NoError(t, err)
dir := t.TempDir()
for _, target := range subdirs {
require.NoError(t, os.MkdirAll(filepath.Join(dir, target), 0o700))

View File

@ -23,7 +23,6 @@ import (
"fmt"
"sync"
"github.com/ProtonMail/proton-bridge/v2/internal/config/settings"
"github.com/docker/docker-credential-helpers/credentials"
)
@ -48,20 +47,19 @@ var (
)
// NewKeychain creates a new native keychain.
func NewKeychain(s *settings.Settings, keychainName string) (*Keychain, error) {
func NewKeychain(preferred, keychainName string) (*Keychain, error) {
// There must be at least one keychain helper available.
if len(Helpers) < 1 {
return nil, ErrNoKeychain
}
// If the preferred keychain is unsupported, fallback to the default one.
// NOTE: Maybe we want to error out here and show something in the GUI instead?
if _, ok := Helpers[s.Get(settings.PreferredKeychainKey)]; !ok {
s.Set(settings.PreferredKeychainKey, defaultHelper)
if _, ok := Helpers[preferred]; !ok {
preferred = defaultHelper
}
// Load the user's preferred keychain helper.
helperConstructor, ok := Helpers[s.Get(settings.PreferredKeychainKey)]
helperConstructor, ok := Helpers[preferred]
if !ok {
return nil, ErrNoKeychain
}

View File

@ -1,214 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package listener
import (
"sync"
"time"
"github.com/sirupsen/logrus"
)
var log = logrus.WithField("pkg", "bridgeUtils/listener") //nolint:gochecknoglobals
// Listener has a list of channels watching for updates.
type Listener interface {
SetLimit(eventName string, limit time.Duration)
ProvideChannel(eventName string) <-chan string
Add(eventName string, channel chan<- string)
Remove(eventName string, channel chan<- string)
Emit(eventName string, data string)
SetBuffer(eventName string)
RetryEmit(eventName string)
Book(eventName string)
}
type listener struct {
channels map[string][]chan<- string
limits map[string]time.Duration
lastEmits map[string]map[string]time.Time
buffered map[string][]string
lock *sync.RWMutex
}
// New returns a new Listener which initially has no topics.
func New() Listener {
return &listener{
channels: nil,
limits: make(map[string]time.Duration),
lastEmits: make(map[string]map[string]time.Time),
buffered: make(map[string][]string),
lock: &sync.RWMutex{},
}
}
// Book wil create the list of channels for specific eventName. This should be
// used when there is not always listening channel available and it should not
// be logged when no channel is awaiting an emitted event.
func (l *listener) Book(eventName string) {
if l.channels == nil {
l.channels = make(map[string][]chan<- string)
}
if _, ok := l.channels[eventName]; !ok {
l.channels[eventName] = []chan<- string{}
}
log.WithField("name", eventName).Debug("Channel booked")
}
// SetLimit sets the limit for the `eventName`. When the same event (name and data)
// is emitted within last time duration (`limit`), event is dropped. Zero limit clears
// the limit for the specific `eventName`.
func (l *listener) SetLimit(eventName string, limit time.Duration) {
l.lock.Lock()
defer l.lock.Unlock()
if limit == 0 {
delete(l.limits, eventName)
return
}
l.limits[eventName] = limit
}
// ProvideChannel creates new channel, adds it to listener and sends to it
// bufferent events.
func (l *listener) ProvideChannel(eventName string) <-chan string {
ch := make(chan string)
l.Add(eventName, ch)
l.RetryEmit(eventName)
return ch
}
// Add adds an event listener.
func (l *listener) Add(eventName string, channel chan<- string) {
l.lock.Lock()
defer l.lock.Unlock()
if l.channels == nil {
l.channels = make(map[string][]chan<- string)
}
log := log.WithField("name", eventName).WithField("i", len(l.channels[eventName]))
l.channels[eventName] = append(l.channels[eventName], channel)
log.Debug("Added event listener")
}
// Remove removes an event listener.
func (l *listener) Remove(eventName string, channel chan<- string) {
l.lock.Lock()
defer l.lock.Unlock()
if _, ok := l.channels[eventName]; ok {
for i := range l.channels[eventName] {
if l.channels[eventName][i] == channel {
l.channels[eventName] = append(l.channels[eventName][:i], l.channels[eventName][i+1:]...)
break
}
}
}
}
// Emit emits an event in parallel to all listeners (channels).
func (l *listener) Emit(eventName string, data string) {
l.lock.Lock()
defer l.lock.Unlock()
l.emit(eventName, data, false)
}
func (l *listener) emit(eventName, data string, isReEmit bool) {
if !l.shouldEmit(eventName, data) {
log.Warn("Emit of ", eventName, " with data ", data, " skipped")
return
}
if _, ok := l.channels[eventName]; ok {
for i, handler := range l.channels[eventName] {
go func(handler chan<- string, i int) {
log := log.WithField("name", eventName).WithField("i", i).WithField("data", data)
log.Debug("Send event")
handler <- data
log.Debug("Event sent")
}(handler, i)
}
} else if !isReEmit {
if bufferedData, ok := l.buffered[eventName]; ok {
l.buffered[eventName] = append(bufferedData, data)
log.Debugf("Buffering event %s data %s", eventName, data)
} else {
log.Warnf("No channel is listening to %s data %s", eventName, data)
}
}
}
func (l *listener) shouldEmit(eventName, data string) bool {
if _, ok := l.limits[eventName]; !ok {
return true
}
l.clearLastEmits()
if eventLastEmits, ok := l.lastEmits[eventName]; ok {
if _, ok := eventLastEmits[data]; ok {
return false
}
} else {
l.lastEmits[eventName] = make(map[string]time.Time)
}
l.lastEmits[eventName][data] = time.Now()
return true
}
func (l *listener) clearLastEmits() {
for eventName, lastEmits := range l.lastEmits {
limit, ok := l.limits[eventName]
if !ok { // Limits were disabled.
delete(l.lastEmits, eventName)
continue
}
for key, lastEmit := range lastEmits {
if time.Since(lastEmit) > limit {
delete(lastEmits, key)
}
}
}
}
func (l *listener) SetBuffer(eventName string) {
l.lock.Lock()
defer l.lock.Unlock()
if _, ok := l.buffered[eventName]; !ok {
l.buffered[eventName] = []string{}
}
}
func (l *listener) RetryEmit(eventName string) {
l.lock.Lock()
defer l.lock.Unlock()
if _, ok := l.channels[eventName]; !ok || len(l.channels[eventName]) == 0 {
return
}
if bufferedData, ok := l.buffered[eventName]; ok {
for _, data := range bufferedData {
l.emit(eventName, data, true)
}
l.buffered[eventName] = []string{}
}
}

View File

@ -1,177 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package listener
import (
"fmt"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)
const minEventReceiveTime = 100 * time.Millisecond
func Example() {
eventListener := New()
ch := make(chan string)
eventListener.Add("eventname", ch)
for eventdata := range ch {
fmt.Println(eventdata + " world")
}
eventListener.Emit("eventname", "hello")
}
func TestAddAndEmitSameEvent(t *testing.T) {
listener, channel := newListener()
listener.Emit("event", "hello!")
checkChannelEmitted(t, channel, "hello!")
}
func TestAddAndEmitDifferentEvent(t *testing.T) {
listener, channel := newListener()
listener.Emit("other", "hello!")
checkChannelNotEmitted(t, channel)
}
func TestAddAndRemove(t *testing.T) {
listener := New()
channel := make(chan string)
listener.Add("event", channel)
listener.Emit("event", "hello!")
checkChannelEmitted(t, channel, "hello!")
listener.Remove("event", channel)
listener.Emit("event", "hello!")
checkChannelNotEmitted(t, channel)
}
func TestNoLimit(t *testing.T) {
listener, channel := newListener()
listener.Emit("event", "hello!")
checkChannelEmitted(t, channel, "hello!")
listener.Emit("event", "hello!")
checkChannelEmitted(t, channel, "hello!")
}
func TestLimit(t *testing.T) {
listener, channel := newListener()
listener.SetLimit("event", 1*time.Second)
channel2 := make(chan string)
listener.Add("event", channel2)
listener.Emit("event", "hello!")
checkChannelEmitted(t, channel, "hello!")
checkChannelEmitted(t, channel2, "hello!")
listener.Emit("event", "hello!")
checkChannelNotEmitted(t, channel)
checkChannelNotEmitted(t, channel2)
time.Sleep(1 * time.Second)
listener.Emit("event", "hello!")
checkChannelEmitted(t, channel, "hello!")
checkChannelEmitted(t, channel2, "hello!")
}
func TestLimitDifferentData(t *testing.T) {
listener, channel := newListener()
listener.SetLimit("event", 1*time.Second)
listener.Emit("event", "hello!")
checkChannelEmitted(t, channel, "hello!")
listener.Emit("event", "hello?")
checkChannelEmitted(t, channel, "hello?")
}
func TestReEmit(t *testing.T) {
logrus.SetLevel(logrus.DebugLevel)
listener := New()
listener.Emit("event", "hello?")
listener.SetBuffer("event")
listener.SetBuffer("other")
listener.Emit("event", "hello1")
listener.Emit("event", "hello2")
listener.Emit("other", "hello!")
listener.Emit("event", "hello3")
listener.Emit("other", "hello!")
eventCH := make(chan string, 3)
listener.Add("event", eventCH)
otherCH := make(chan string)
listener.Add("other", otherCH)
listener.RetryEmit("event")
listener.RetryEmit("other")
time.Sleep(time.Millisecond)
receivedEvents := map[string]int{}
for i := 0; i < 5; i++ {
select {
case res := <-eventCH:
receivedEvents[res]++
case res := <-otherCH:
receivedEvents[res+":other"]++
case <-time.After(minEventReceiveTime):
t.Fatalf("Channel not emitted %d times", i+1)
}
}
expectedEvents := map[string]int{"hello1": 1, "hello2": 1, "hello3": 1, "hello!:other": 2}
require.Equal(t, expectedEvents, receivedEvents)
}
func newListener() (Listener, chan string) {
listener := New()
channel := make(chan string)
listener.Add("event", channel)
return listener, channel
}
func checkChannelEmitted(t testing.TB, channel chan string, expectedData string) {
select {
case res := <-channel:
require.Equal(t, expectedData, res)
case <-time.After(minEventReceiveTime):
t.Fatalf("Channel not emitted with expected data: %s", expectedData)
}
}
func checkChannelNotEmitted(t testing.TB, channel chan string) {
select {
case res := <-channel:
t.Fatalf("Channel emitted with a unexpected response: %s", res)
case <-time.After(minEventReceiveTime):
}
}

View File

@ -1,130 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"bufio"
"bytes"
"io"
)
type boundaryReader struct {
reader *bufio.Reader
closed, first bool
skipped int
nl []byte // "\r\n" or "\n" (set after seeing first boundary line)
nlDashBoundary []byte // nl + "--boundary"
dashBoundaryDash []byte // "--boundary--"
dashBoundary []byte // "--boundary"
}
func newBoundaryReader(r *bufio.Reader, boundary string) (br *boundaryReader, err error) {
b := []byte("\r\n--" + boundary + "--")
br = &boundaryReader{
reader: r,
closed: false,
first: true,
nl: b[:2],
nlDashBoundary: b[:len(b)-2],
dashBoundaryDash: b[2:],
dashBoundary: b[2 : len(b)-2],
}
err = br.writeNextPartTo(nil)
return
}
// writeNextPartTo will copy the the bytes of next part and write them to
// writer. Will return EOF if the underlying reader is empty.
func (br *boundaryReader) writeNextPartTo(part io.Writer) (err error) {
if br.closed {
return io.EOF
}
var line, slice []byte
br.skipped = 0
for {
slice, err = br.reader.ReadSlice('\n')
line = append(line, slice...)
if err == bufio.ErrBufferFull {
continue
}
br.skipped += len(line)
if err == io.EOF && br.isFinalBoundary(line) {
err = nil
br.closed = true
return
}
if err != nil {
return
}
if br.isBoundaryDelimiterLine(line) {
br.first = false
return
}
if br.isFinalBoundary(line) {
br.closed = true
return
}
if part != nil {
if _, err = part.Write(line); err != nil {
return
}
}
line = []byte{}
}
}
func (br *boundaryReader) isFinalBoundary(line []byte) bool {
if !bytes.HasPrefix(line, br.dashBoundaryDash) {
return false
}
rest := line[len(br.dashBoundaryDash):]
rest = skipLWSPChar(rest)
return len(rest) == 0 || bytes.Equal(rest, br.nl)
}
func (br *boundaryReader) isBoundaryDelimiterLine(line []byte) (ret bool) {
if !bytes.HasPrefix(line, br.dashBoundary) {
return false
}
rest := line[len(br.dashBoundary):]
rest = skipLWSPChar(rest)
if br.first && len(rest) == 1 && rest[0] == '\n' {
br.nl = br.nl[1:]
br.nlDashBoundary = br.nlDashBoundary[1:]
}
return bytes.Equal(rest, br.nl)
}
func skipLWSPChar(b []byte) []byte {
for len(b) > 0 && (b[0] == ' ' || b[0] == '\t') {
b = b[1:]
}
return b
}

View File

@ -18,14 +18,22 @@
package message
import (
"context"
"io"
"sync"
"bytes"
"encoding/base64"
"mime"
"net/mail"
"strings"
"time"
"unicode/utf8"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-rfc5322"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"github.com/ProtonMail/proton-bridge/v2/pkg/pool"
"github.com/ProtonMail/proton-bridge/v2/pkg/algo"
"github.com/emersion/go-message"
"github.com/emersion/go-message/textproto"
"github.com/pkg/errors"
"gitlab.protontech.ch/go/liteapi"
)
var (
@ -33,199 +41,526 @@ var (
ErrNoSuchKeyRing = errors.New("the keyring to decrypt this message could not be found")
)
const (
BackgroundPriority = 1 << iota
ForegroundPriority
)
// InternalIDDomain is used as a placeholder for reference/message ID headers to improve compatibility with various clients.
const InternalIDDomain = `protonmail.internalid`
type Builder struct {
pool *pool.Pool
jobs map[string]*Job
lock sync.Mutex
}
func BuildRFC822(kr *crypto.KeyRing, msg liteapi.Message, attData map[string][]byte, opts JobOptions) ([]byte, error) {
switch {
case len(msg.Attachments) > 0:
return buildMultipartRFC822(kr, msg, attData, opts)
type Fetcher interface {
GetMessage(context.Context, string) (*pmapi.Message, error)
GetAttachment(context.Context, string) (io.ReadCloser, error)
KeyRingForAddressID(string) (*crypto.KeyRing, error)
}
case msg.MIMEType == "multipart/mixed":
return buildPGPRFC822(kr, msg, opts)
// NewBuilder creates a new builder which manages the given number of fetch/attach/build workers.
// - fetchWorkers: the number of workers which fetch messages from API
// - attachWorkers: the number of workers which fetch attachments from API.
//
// The returned builder is ready to handle jobs -- see (*Builder).NewJob for more information.
//
// Call (*Builder).Done to shut down the builder and stop all workers.
func NewBuilder(fetchWorkers, attachmentWorkers int) *Builder {
attachmentPool := pool.New(attachmentWorkers, newAttacherWorkFunc())
fetcherPool := pool.New(fetchWorkers, newFetcherWorkFunc(attachmentPool))
return &Builder{
pool: fetcherPool,
jobs: make(map[string]*Job),
default:
return buildSimpleRFC822(kr, msg, opts)
}
}
func (builder *Builder) NewJob(ctx context.Context, fetcher Fetcher, messageID string, prio int) (*Job, pool.DoneFunc) {
return builder.NewJobWithOptions(ctx, fetcher, messageID, JobOptions{}, prio)
}
func (builder *Builder) NewJobWithOptions(ctx context.Context, fetcher Fetcher, messageID string, opts JobOptions, prio int) (*Job, pool.DoneFunc) {
builder.lock.Lock()
defer builder.lock.Unlock()
if job, ok := builder.jobs[messageID]; ok {
if job.GetPriority() < prio {
job.SetPriority(prio)
func buildSimpleRFC822(kr *crypto.KeyRing, msg liteapi.Message, opts JobOptions) ([]byte, error) {
dec, err := msg.Decrypt(kr)
if err != nil {
if !opts.IgnoreDecryptionErrors {
return nil, errors.Wrap(ErrDecryptionFailed, err.Error())
}
return job, job.done
return buildMultipartRFC822(kr, msg, nil, opts)
}
job, done := builder.pool.NewJob(
&fetchReq{
ctx: ctx,
fetcher: fetcher,
messageID: messageID,
options: opts,
},
prio,
)
hdr := getTextPartHeader(getMessageHeader(msg, opts), dec, msg.MIMEType)
buildDone := func() {
builder.lock.Lock()
defer builder.lock.Unlock()
buf := new(bytes.Buffer)
// Remove the job from the builder.
delete(builder.jobs, messageID)
// And mark it as done.
done()
}
buildJob := &Job{
Job: job,
done: buildDone,
}
builder.jobs[messageID] = buildJob
return buildJob, buildDone
}
func (builder *Builder) Done() {
// NOTE(GODT-1158): Stop worker pool.
}
type fetchReq struct {
ctx context.Context
fetcher Fetcher
messageID string
options JobOptions
}
type attachReq struct {
ctx context.Context
fetcher Fetcher
message *pmapi.Message
}
type Job struct {
*pool.Job
done pool.DoneFunc
}
func (job *Job) GetResult() ([]byte, error) {
res, err := job.Job.GetResult()
w, err := message.CreateWriter(buf, hdr)
if err != nil {
return nil, err
}
return res.([]byte), nil //nolint:forcetypeassert
}
// NOTE: This is not used because it is actually not doing what was expected: It
// downloads all the attachments which belongs to one message sequentially
// within one goroutine. We should have one job per one attachment. This doesn't look
// like a bottle neck right now.
func newAttacherWorkFunc() pool.WorkFunc {
return func(payload interface{}, prio int) (interface{}, error) {
req, ok := payload.(*attachReq)
if !ok {
panic("bad payload type")
}
res := make(map[string][]byte)
for _, att := range req.message.Attachments {
rc, err := req.fetcher.GetAttachment(req.ctx, att.ID)
if err != nil {
return nil, err
}
b, err := io.ReadAll(rc)
if err != nil {
return nil, err
}
if err := rc.Close(); err != nil {
return nil, err
}
res[att.ID] = b
}
return res, nil
if _, err := w.Write(dec); err != nil {
return nil, err
}
if err := w.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func newFetcherWorkFunc(attachmentPool *pool.Pool) pool.WorkFunc {
return func(payload interface{}, prio int) (interface{}, error) {
req, ok := payload.(*fetchReq)
if !ok {
panic("bad payload type")
}
func buildMultipartRFC822(
kr *crypto.KeyRing,
msg liteapi.Message,
attData map[string][]byte,
opts JobOptions,
) ([]byte, error) {
boundary := newBoundary(msg.ID)
msg, err := req.fetcher.GetMessage(req.ctx, req.messageID)
if err != nil {
hdr := getMessageHeader(msg, opts)
hdr.SetContentType("multipart/mixed", map[string]string{"boundary": boundary.gen()})
buf := new(bytes.Buffer)
w, err := message.CreateWriter(buf, hdr)
if err != nil {
return nil, err
}
var (
inlineAtts []liteapi.Attachment
inlineData [][]byte
attachAtts []liteapi.Attachment
attachData [][]byte
)
for _, att := range msg.Attachments {
if att.Disposition == liteapi.InlineDisposition {
inlineAtts = append(inlineAtts, att)
inlineData = append(inlineData, attData[att.ID])
} else {
attachAtts = append(attachAtts, att)
attachData = append(attachData, attData[att.ID])
}
}
if len(inlineAtts) > 0 {
if err := writeRelatedParts(w, kr, boundary, msg, inlineAtts, inlineData, opts); err != nil {
return nil, err
}
} else if err := writeTextPart(w, kr, msg, opts); err != nil {
return nil, err
}
attData := make(map[string][]byte)
for i, att := range attachAtts {
if err := writeAttachmentPart(w, kr, att, attachData[i], opts); err != nil {
return nil, err
}
}
for _, att := range msg.Attachments {
// NOTE: Potential place for optimization:
// Use attachmentPool to download each attachment in
// separate parallel job. It is not straightforward
// because we need to make sure we call attachment-job-done
// function in case of any error or after we collect all
// attachment bytes asynchronously.
rc, err := req.fetcher.GetAttachment(req.ctx, att.ID)
if err != nil {
return nil, err
}
if err := w.Close(); err != nil {
return nil, err
}
b, err := io.ReadAll(rc)
if err != nil {
_ = rc.Close()
return nil, err
}
return buf.Bytes(), nil
}
if err := rc.Close(); err != nil {
return nil, err
}
attData[att.ID] = b
func writeTextPart(
w *message.Writer,
kr *crypto.KeyRing,
msg liteapi.Message,
opts JobOptions,
) error {
dec, err := msg.Decrypt(kr)
if err != nil {
if !opts.IgnoreDecryptionErrors {
return errors.Wrap(ErrDecryptionFailed, err.Error())
}
kr, err := req.fetcher.KeyRingForAddressID(msg.AddressID)
if err != nil {
return nil, ErrNoSuchKeyRing
return writeCustomTextPart(w, msg, err)
}
return writePart(w, getTextPartHeader(message.Header{}, dec, msg.MIMEType), dec)
}
func writeAttachmentPart(
w *message.Writer,
kr *crypto.KeyRing,
att liteapi.Attachment,
attData []byte,
opts JobOptions,
) error {
kps, err := base64.StdEncoding.DecodeString(att.KeyPackets)
if err != nil {
return err
}
msg := crypto.NewPGPSplitMessage(kps, attData).GetPGPMessage()
dec, err := kr.Decrypt(msg, nil, crypto.GetUnixTime())
if err != nil {
if !opts.IgnoreDecryptionErrors {
return errors.Wrap(ErrDecryptionFailed, err.Error())
}
return buildRFC822(kr, msg, attData, req.options)
log.
WithField("attID", att.ID).
WithError(err).
Warn("Attachment decryption failed")
return writeCustomAttachmentPart(w, att, msg, err)
}
return writePart(w, getAttachmentPartHeader(att), dec.GetBinary())
}
func writeRelatedParts(
w *message.Writer,
kr *crypto.KeyRing,
boundary *boundary,
msg liteapi.Message,
atts []liteapi.Attachment,
attData [][]byte,
opts JobOptions,
) error {
hdr := message.Header{}
hdr.SetContentType("multipart/related", map[string]string{"boundary": boundary.gen()})
return createPart(w, hdr, func(rel *message.Writer) error {
if err := writeTextPart(rel, kr, msg, opts); err != nil {
return err
}
for i, att := range atts {
if err := writeAttachmentPart(rel, kr, att, attData[i], opts); err != nil {
return err
}
}
return nil
})
}
func buildPGPRFC822(kr *crypto.KeyRing, msg liteapi.Message, opts JobOptions) ([]byte, error) {
dec, err := msg.Decrypt(kr)
if err != nil {
if !opts.IgnoreDecryptionErrors {
return nil, errors.Wrap(ErrDecryptionFailed, err.Error())
}
return buildPGPMIMEFallbackRFC822(msg, opts)
}
hdr := getMessageHeader(msg, opts)
sigs, err := msg.ExtractSignatures(kr)
if err != nil {
log.WithError(err).WithField("id", msg.ID).Warn("Extract signature failed")
}
if len(sigs) > 0 {
return writeMultipartSignedRFC822(hdr, dec, sigs[0])
}
return writeMultipartEncryptedRFC822(hdr, dec)
}
func buildPGPMIMEFallbackRFC822(msg liteapi.Message, opts JobOptions) ([]byte, error) {
hdr := getMessageHeader(msg, opts)
hdr.SetContentType("multipart/encrypted", map[string]string{
"boundary": newBoundary(msg.ID).gen(),
"protocol": "application/pgp-encrypted",
})
buf := new(bytes.Buffer)
w, err := message.CreateWriter(buf, hdr)
if err != nil {
return nil, err
}
var encHdr message.Header
encHdr.SetContentType("application/pgp-encrypted", nil)
encHdr.Set("Content-Description", "PGP/MIME version identification")
if err := writePart(w, encHdr, []byte("Version: 1")); err != nil {
return nil, err
}
var dataHdr message.Header
dataHdr.SetContentType("application/octet-stream", map[string]string{"name": "encrypted.asc"})
dataHdr.SetContentDisposition("inline", map[string]string{"filename": "encrypted.asc"})
dataHdr.Set("Content-Description", "OpenPGP encrypted message")
if err := writePart(w, dataHdr, []byte(msg.Body)); err != nil {
return nil, err
}
if err := w.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func writeMultipartSignedRFC822(header message.Header, body []byte, sig liteapi.Signature) ([]byte, error) { //nolint:funlen
buf := new(bytes.Buffer)
boundary := newBoundary("").gen()
header.SetContentType("multipart/signed", map[string]string{
"micalg": sig.Hash,
"protocol": "application/pgp-signature",
"boundary": boundary,
})
if err := textproto.WriteHeader(buf, header.Header); err != nil {
return nil, err
}
mw := textproto.NewMultipartWriter(buf)
if err := mw.SetBoundary(boundary); err != nil {
return nil, err
}
bodyHeader, bodyData, err := readHeaderBody(body)
if err != nil {
return nil, err
}
bodyPart, err := mw.CreatePart(*bodyHeader)
if err != nil {
return nil, err
}
if _, err := bodyPart.Write(bodyData); err != nil {
return nil, err
}
var sigHeader message.Header
sigHeader.SetContentType("application/pgp-signature", map[string]string{"name": "OpenPGP_signature.asc"})
sigHeader.SetContentDisposition("attachment", map[string]string{"filename": "OpenPGP_signature"})
sigHeader.Set("Content-Description", "OpenPGP digital signature")
sigPart, err := mw.CreatePart(sigHeader.Header)
if err != nil {
return nil, err
}
sigData, err := sig.Data.GetArmored()
if err != nil {
return nil, err
}
if _, err := sigPart.Write([]byte(sigData)); err != nil {
return nil, err
}
if err := mw.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func writeMultipartEncryptedRFC822(header message.Header, body []byte) ([]byte, error) {
buf := new(bytes.Buffer)
bodyHeader, bodyData, err := readHeaderBody(body)
if err != nil {
return nil, err
}
// If parsed header is empty then either it is malformed or it is missing.
// Anyway message could not be considered multipart/mixed anymore since there will be no boundary.
if bodyHeader.Len() == 0 {
header.Del("Content-Type")
}
entFields := bodyHeader.Fields()
for entFields.Next() {
header.Set(entFields.Key(), entFields.Value())
}
if err := textproto.WriteHeader(buf, header.Header); err != nil {
return nil, err
}
if _, err := buf.Write(bodyData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func getMessageHeader(msg liteapi.Message, opts JobOptions) message.Header { //nolint:funlen
hdr := toMessageHeader(msg.ParsedHeaders)
// SetText will RFC2047-encode.
if msg.Subject != "" {
hdr.SetText("Subject", msg.Subject)
}
// mail.Address.String() will RFC2047-encode if necessary.
if msg.Sender != nil {
hdr.Set("From", msg.Sender.String())
}
if len(msg.ReplyTos) > 0 {
hdr.Set("Reply-To", toAddressList(msg.ReplyTos))
}
if len(msg.ToList) > 0 {
hdr.Set("To", toAddressList(msg.ToList))
}
if len(msg.CCList) > 0 {
hdr.Set("Cc", toAddressList(msg.CCList))
}
if len(msg.BCCList) > 0 {
hdr.Set("Bcc", toAddressList(msg.BCCList))
}
setMessageIDIfNeeded(msg, &hdr)
// Sanitize the date; it needs to have a valid unix timestamp.
if opts.SanitizeDate {
if date, err := rfc5322.ParseDateTime(hdr.Get("Date")); err != nil || date.Before(time.Unix(0, 0)) {
msgDate := SanitizeMessageDate(msg.Time)
hdr.Set("Date", msgDate.In(time.UTC).Format(time.RFC1123Z))
// We clobbered the date so we save it under X-Original-Date.
hdr.Set("X-Original-Date", date.In(time.UTC).Format(time.RFC1123Z))
}
}
// Set our internal ID if requested.
// This is important for us to detect whether APPENDed things are actually "move like outlook".
if opts.AddInternalID {
hdr.Set("X-Pm-Internal-Id", msg.ID)
}
// Set our external ID if requested.
// This was useful during debugging of applemail recovered messages; doesn't help with any behaviour.
if opts.AddExternalID {
hdr.Set("X-Pm-External-Id", "<"+msg.ExternalID+">")
}
// Set our server date if requested.
// Can be useful to see how long it took for a message to arrive.
if opts.AddMessageDate {
hdr.Set("X-Pm-Date", time.Unix(msg.Time, 0).In(time.UTC).Format(time.RFC1123Z))
}
// Include the message ID in the references (supposedly this somehow improves outlook support...).
if opts.AddMessageIDReference {
if references := hdr.Get("References"); !strings.Contains(references, msg.ID) {
hdr.Set("References", references+" <"+msg.ID+"@"+InternalIDDomain+">")
}
}
return hdr
}
// SanitizeMessageDate will return time from msgTime timestamp. If timestamp is
// not after epoch the RFC822 publish day will be used. No message should
// realistically be older than RFC822 itself.
func SanitizeMessageDate(msgTime int64) time.Time {
if msgTime := time.Unix(msgTime, 0); msgTime.After(time.Unix(0, 0)) {
return msgTime
}
return time.Date(1982, 8, 13, 0, 0, 0, 0, time.UTC)
}
// setMessageIDIfNeeded sets Message-Id from ExternalID or ID if it's not
// already set.
func setMessageIDIfNeeded(msg liteapi.Message, hdr *message.Header) {
if hdr.Get("Message-Id") == "" {
if msg.ExternalID != "" {
hdr.Set("Message-Id", "<"+msg.ExternalID+">")
} else {
hdr.Set("Message-Id", "<"+msg.ID+"@"+InternalIDDomain+">")
}
}
}
func getTextPartHeader(hdr message.Header, body []byte, mimeType rfc822.MIMEType) message.Header {
params := make(map[string]string)
if utf8.Valid(body) {
params["charset"] = "utf-8"
}
hdr.SetContentType(string(mimeType), params)
// Use quoted-printable for all text/... parts
hdr.Set("Content-Transfer-Encoding", "quoted-printable")
return hdr
}
func getAttachmentPartHeader(att liteapi.Attachment) message.Header {
hdr := toMessageHeader(liteapi.Headers(att.Headers))
// All attachments have a content type.
hdr.SetContentType(string(att.MIMEType), map[string]string{"name": mime.QEncoding.Encode("utf-8", att.Name)})
// All attachments have a content disposition.
hdr.SetContentDisposition(string(att.Disposition), map[string]string{"filename": mime.QEncoding.Encode("utf-8", att.Name)})
// Use base64 for all attachments except embedded RFC822 messages.
if att.MIMEType != rfc822.MessageRFC822 {
hdr.Set("Content-Transfer-Encoding", "base64")
} else {
hdr.Del("Content-Transfer-Encoding")
}
return hdr
}
func toMessageHeader(hdr liteapi.Headers) message.Header {
var res message.Header
for key, val := range hdr {
for _, val := range val {
// Using AddRaw instead of Add to save key-value pair as byte buffer within Header.
// This buffer is used latter on in message writer to construct message and avoid crash
// when key length is more than 76 characters long.
res.AddRaw([]byte(key + ": " + val + "\r\n"))
}
}
return res
}
func toAddressList(addrs []*mail.Address) string {
res := make([]string, len(addrs))
for i, addr := range addrs {
res[i] = addr.String()
}
return strings.Join(res, ", ")
}
func createPart(w *message.Writer, hdr message.Header, fn func(*message.Writer) error) error {
part, err := w.CreatePart(hdr)
if err != nil {
return err
}
if err := fn(part); err != nil {
return err
}
return part.Close()
}
func writePart(w *message.Writer, hdr message.Header, body []byte) error {
return createPart(w, hdr, func(part *message.Writer) error {
if _, err := part.Write(body); err != nil {
return errors.Wrap(err, "failed to write part body")
}
return nil
})
}
type boundary struct {
val string
}
func newBoundary(seed string) *boundary {
return &boundary{val: seed}
}
func (bw *boundary) gen() string {
bw.val = algo.HashHexSHA256(bw.val)
return bw.val
}

View File

@ -1,35 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"github.com/ProtonMail/proton-bridge/v2/pkg/algo"
)
type boundary struct {
val string
}
func newBoundary(seed string) *boundary {
return &boundary{val: seed}
}
func (bw *boundary) gen() string {
bw.val = algo.HashHexSHA256(bw.val)
return bw.val
}

View File

@ -23,14 +23,14 @@ import (
"github.com/ProtonMail/gopenpgp/v2/constants"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"github.com/emersion/go-message"
"gitlab.protontech.ch/go/liteapi"
)
// writeCustomTextPart writes an armored-PGP text part for a message body that couldn't be decrypted.
func writeCustomTextPart(
w *message.Writer,
msg *pmapi.Message,
msg liteapi.Message,
decError error,
) error {
enc, err := crypto.NewPGPMessageFromArmored(msg.Body)
@ -48,7 +48,7 @@ func writeCustomTextPart(
var hdr message.Header
hdr.SetContentType(msg.MIMEType, nil)
hdr.SetContentType(string(msg.MIMEType), nil)
part, err := w.CreatePart(hdr)
if err != nil {
@ -65,7 +65,7 @@ func writeCustomTextPart(
// writeCustomAttachmentPart writes an armored-PGP data part for an attachment that couldn't be decrypted.
func writeCustomAttachmentPart(
w *message.Writer,
att *pmapi.Attachment,
att liteapi.Attachment,
msg *crypto.PGPMessage,
decError error,
) error {
@ -82,7 +82,7 @@ func writeCustomAttachmentPart(
var hdr message.Header
hdr.SetContentType("application/octet-stream", map[string]string{"name": filename})
hdr.SetContentDisposition(att.Disposition, map[string]string{"filename": filename})
hdr.SetContentDisposition(string(att.Disposition), map[string]string{"filename": filename})
part, err := w.CreatePart(hdr)
if err != nil {

View File

@ -1,168 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"bytes"
"encoding/base64"
"io"
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"strings"
"github.com/ProtonMail/gopenpgp/v2/crypto"
pmmime "github.com/ProtonMail/proton-bridge/v2/pkg/mime"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"github.com/emersion/go-message"
"github.com/emersion/go-textwrapper"
)
// BuildEncrypted is used for importing encrypted message.
func BuildEncrypted(m *pmapi.Message, readers []io.Reader, kr *crypto.KeyRing) ([]byte, error) { //nolint:funlen
b := &bytes.Buffer{}
boundary := newBoundary(m.ID).gen()
// Overwrite content for main header for import.
// Even if message has just simple body we should upload as multipart/mixed.
// Each part has encrypted body and header reflects the original header.
mainHeader := convertGoMessageToTextprotoHeader(getMessageHeader(m, JobOptions{}))
mainHeader.Set("Content-Type", "multipart/mixed; boundary="+boundary)
mainHeader.Del("Content-Disposition")
mainHeader.Del("Content-Transfer-Encoding")
if err := WriteHeader(b, mainHeader); err != nil {
return nil, err
}
mw := multipart.NewWriter(b)
if err := mw.SetBoundary(boundary); err != nil {
return nil, err
}
// Write the body part.
bodyHeader := make(textproto.MIMEHeader)
bodyHeader.Set("Content-Type", m.MIMEType+"; charset=utf-8")
bodyHeader.Set("Content-Disposition", pmapi.DispositionInline)
bodyHeader.Set("Content-Transfer-Encoding", "7bit")
p, err := mw.CreatePart(bodyHeader)
if err != nil {
return nil, err
}
// First, encrypt the message body.
if err := m.Encrypt(kr, kr); err != nil {
return nil, err
}
if _, err := io.WriteString(p, m.Body); err != nil {
return nil, err
}
// Write the attachments parts.
for i := 0; i < len(m.Attachments); i++ {
att := m.Attachments[i]
r := readers[i]
h := getAttachmentHeader(att, false)
p, err := mw.CreatePart(h)
if err != nil {
return nil, err
}
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}
// Create encrypted writer.
pgpMessage, err := kr.Encrypt(crypto.NewPlainMessage(data), nil)
if err != nil {
return nil, err
}
ww := textwrapper.NewRFC822(p)
bw := base64.NewEncoder(base64.StdEncoding, ww)
if _, err := bw.Write(pgpMessage.GetBinary()); err != nil {
return nil, err
}
if err := bw.Close(); err != nil {
return nil, err
}
}
if err := mw.Close(); err != nil {
return nil, err
}
return b.Bytes(), nil
}
func convertGoMessageToTextprotoHeader(h message.Header) textproto.MIMEHeader {
out := make(textproto.MIMEHeader)
hf := h.Fields()
for hf.Next() {
// go-message fields are in the reverse order.
// textproto.MIMEHeader is not ordered except for the values of
// the same key which are ordered
key := textproto.CanonicalMIMEHeaderKey(hf.Key())
out[key] = append([]string{hf.Value()}, out[key]...)
}
return out
}
func getAttachmentHeader(att *pmapi.Attachment, buildForIMAP bool) textproto.MIMEHeader {
mediaType := att.MIMEType
if mediaType == "application/pgp-encrypted" {
mediaType = "application/octet-stream"
}
transferEncoding := "base64"
if mediaType == rfc822Message && buildForIMAP {
transferEncoding = "8bit"
}
encodedName := pmmime.EncodeHeader(att.Name)
disposition := "attachment" //nolint:goconst
if strings.Contains(att.Header.Get("Content-Disposition"), pmapi.DispositionInline) {
disposition = pmapi.DispositionInline
}
h := make(textproto.MIMEHeader)
h.Set("Content-Type", mime.FormatMediaType(mediaType, map[string]string{"name": encodedName}))
if transferEncoding != "" {
h.Set("Content-Transfer-Encoding", transferEncoding)
}
h.Set("Content-Disposition", mime.FormatMediaType(disposition, map[string]string{"filename": encodedName}))
// Forward some original header lines.
forward := []string{"Content-Id", "Content-Description", "Content-Location"}
for _, k := range forward {
v := att.Header.Get(k)
if v != "" {
h.Set(k, v)
}
}
return h
}
func WriteHeader(w io.Writer, h textproto.MIMEHeader) (err error) {
if err = http.Header(h).Write(w); err != nil {
return
}
_, err = io.WriteString(w, "\r\n")
return
}

View File

@ -21,46 +21,24 @@ import (
"bufio"
"bytes"
"encoding/base64"
"io"
"strings"
"testing"
"time"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/pkg/message/mocks"
"github.com/ProtonMail/proton-bridge/v2/pkg/message/parser"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"golang.org/x/text/encoding/htmlindex"
)
func newTestFetcher(
m *gomock.Controller,
kr *crypto.KeyRing,
msg *pmapi.Message,
attData ...[]byte,
) Fetcher {
f := mocks.NewMockFetcher(m)
f.EXPECT().GetMessage(gomock.Any(), msg.ID).Return(msg, nil)
for i, att := range msg.Attachments {
f.EXPECT().GetAttachment(gomock.Any(), att.ID).Return(newTestReadCloser(attData[i]), nil)
}
f.EXPECT().KeyRingForAddressID(msg.AddressID).Return(kr, nil)
return f
}
func newTestMessage(
t *testing.T,
kr *crypto.KeyRing,
messageID, addressID, mimeType, body string, //nolint:unparam
date time.Time,
) *pmapi.Message {
) liteapi.Message {
enc, err := kr.Encrypt(crypto.NewPlainMessageFromString(body), kr)
require.NoError(t, err)
@ -70,57 +48,47 @@ func newTestMessage(
return newRawTestMessage(messageID, addressID, mimeType, arm, date)
}
func newRawTestMessage(messageID, addressID, mimeType, body string, date time.Time) *pmapi.Message {
return &pmapi.Message{
ID: messageID,
AddressID: addressID,
MIMEType: mimeType,
Header: map[string][]string{
func newRawTestMessage(messageID, addressID, mimeType, body string, date time.Time) liteapi.Message {
return liteapi.Message{
MessageMetadata: liteapi.MessageMetadata{
ID: messageID,
AddressID: addressID,
Time: date.Unix(),
},
ParsedHeaders: liteapi.Headers{
"Content-Type": {mimeType},
"Date": {date.In(time.UTC).Format(time.RFC1123Z)},
},
Body: body,
Time: date.Unix(),
MIMEType: rfc822.MIMEType(mimeType),
Body: body,
}
}
func addTestAttachment(
t *testing.T,
kr *crypto.KeyRing,
msg *pmapi.Message,
msg *liteapi.Message,
attachmentID, name, mimeType, disposition, data string,
) []byte {
enc, err := kr.EncryptAttachment(crypto.NewPlainMessageFromString(data), attachmentID+".bin")
require.NoError(t, err)
msg.Attachments = append(msg.Attachments, &pmapi.Attachment{
msg.Attachments = append(msg.Attachments, liteapi.Attachment{
ID: attachmentID,
Name: name,
MIMEType: mimeType,
Header: map[string][]string{
MIMEType: rfc822.MIMEType(mimeType),
Headers: liteapi.Headers{
"Content-Type": {mimeType},
"Content-Disposition": {disposition},
"Content-Transfer-Encoding": {"base64"},
},
Disposition: disposition,
Disposition: liteapi.Disposition(disposition),
KeyPackets: base64.StdEncoding.EncodeToString(enc.GetBinaryKeyPacket()),
})
return enc.GetBinaryDataPacket()
}
type testReadCloser struct {
io.Reader
}
func newTestReadCloser(b []byte) *testReadCloser {
return &testReadCloser{Reader: bytes.NewReader(b)}
}
func (testReadCloser) Close() error {
return nil
}
type testSection struct {
t *testing.T
part *parser.Part
@ -130,21 +98,18 @@ type testSection struct {
// NOTE: Each section is parsed individually --> cleaner test code but slower... improve this one day?
func section(t *testing.T, b []byte, section ...int) *testSection {
p, err := parser.New(bytes.NewReader(b))
assert.NoError(t, err)
require.NoError(t, err)
part, err := p.Section(section)
require.NoError(t, err)
bs, err := NewBodyStructure(bytes.NewReader(b))
require.NoError(t, err)
raw, err := bs.GetSection(bytes.NewReader(b), section)
s, err := rfc822.Parse(b).Part(section...)
require.NoError(t, err)
return &testSection{
t: t,
part: part,
raw: raw,
raw: s.Literal(),
}
}
@ -249,7 +214,7 @@ type isMatcher struct {
}
func (matcher isMatcher) match(t *testing.T, have string) {
assert.Equal(t, matcher.want, have)
require.Equal(t, matcher.want, have)
}
func is(want string) isMatcher {
@ -265,7 +230,7 @@ type isNotMatcher struct {
}
func (matcher isNotMatcher) match(t *testing.T, have string) {
assert.NotEqual(t, matcher.notWant, have)
require.NotEqual(t, matcher.notWant, have)
}
func isNot(notWant string) isNotMatcher {
@ -277,7 +242,7 @@ type containsMatcher struct {
}
func (matcher containsMatcher) match(t *testing.T, have string) {
assert.Contains(t, have, matcher.contains)
require.Contains(t, have, matcher.contains)
}
func contains(contains string) containsMatcher {
@ -296,7 +261,7 @@ func (matcher decryptsToMatcher) match(t *testing.T, have string) {
dec, err := matcher.kr.Decrypt(haveMsg, nil, crypto.GetUnixTime())
require.NoError(t, err)
assert.Equal(t, matcher.want, string(dec.GetBinary()))
require.Equal(t, matcher.want, string(dec.GetBinary()))
}
func decryptsTo(kr *crypto.KeyRing, want string) decryptsToMatcher {
@ -315,7 +280,7 @@ func (matcher decodesToMatcher) match(t *testing.T, have string) {
dec, err := enc.NewDecoder().String(have)
require.NoError(t, err)
assert.Equal(t, matcher.want, dec)
require.Equal(t, matcher.want, dec)
}
func decodesTo(charset string, want string) decodesToMatcher {
@ -328,8 +293,8 @@ type verifiesAgainstMatcher struct {
}
func (matcher verifiesAgainstMatcher) match(t *testing.T, have string) {
assert.NoError(t, matcher.kr.VerifyDetached(
crypto.NewPlainMessage(bytes.TrimSuffix([]byte(have), []byte("\r\n"))),
require.NoError(t, matcher.kr.VerifyDetached(
crypto.NewPlainMessage([]byte(have)),
matcher.sig,
crypto.GetUnixTime()),
)
@ -347,7 +312,7 @@ func (matcher maxLineLengthMatcher) match(t *testing.T, have string) {
scanner := bufio.NewScanner(strings.NewReader(have))
for scanner.Scan() {
assert.Less(t, len(scanner.Text()), matcher.wantMax)
require.Less(t, len(scanner.Text()), matcher.wantMax)
}
}

View File

@ -1,544 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"bytes"
"encoding/base64"
"mime"
"net/mail"
"strings"
"time"
"unicode/utf8"
"github.com/ProtonMail/go-rfc5322"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"github.com/emersion/go-message"
"github.com/emersion/go-message/textproto"
"github.com/pkg/errors"
)
func buildRFC822(kr *crypto.KeyRing, msg *pmapi.Message, attData map[string][]byte, opts JobOptions) ([]byte, error) {
switch {
case len(msg.Attachments) > 0:
return buildMultipartRFC822(kr, msg, attData, opts)
case msg.MIMEType == "multipart/mixed":
return buildPGPRFC822(kr, msg, opts)
default:
return buildSimpleRFC822(kr, msg, opts)
}
}
func buildSimpleRFC822(kr *crypto.KeyRing, msg *pmapi.Message, opts JobOptions) ([]byte, error) {
dec, err := msg.Decrypt(kr)
if err != nil {
if !opts.IgnoreDecryptionErrors {
return nil, errors.Wrap(ErrDecryptionFailed, err.Error())
}
return buildMultipartRFC822(kr, msg, nil, opts)
}
hdr := getTextPartHeader(getMessageHeader(msg, opts), dec, msg.MIMEType)
buf := new(bytes.Buffer)
w, err := message.CreateWriter(buf, hdr)
if err != nil {
return nil, err
}
if _, err := w.Write(dec); err != nil {
return nil, err
}
if err := w.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func buildMultipartRFC822(
kr *crypto.KeyRing,
msg *pmapi.Message,
attData map[string][]byte,
opts JobOptions,
) ([]byte, error) {
boundary := newBoundary(msg.ID)
hdr := getMessageHeader(msg, opts)
hdr.SetContentType("multipart/mixed", map[string]string{"boundary": boundary.gen()})
buf := new(bytes.Buffer)
w, err := message.CreateWriter(buf, hdr)
if err != nil {
return nil, err
}
var (
inlineAtts []*pmapi.Attachment
inlineData [][]byte
attachAtts []*pmapi.Attachment
attachData [][]byte
)
for _, att := range msg.Attachments {
if att.Disposition == pmapi.DispositionInline {
inlineAtts = append(inlineAtts, att)
inlineData = append(inlineData, attData[att.ID])
} else {
attachAtts = append(attachAtts, att)
attachData = append(attachData, attData[att.ID])
}
}
if len(inlineAtts) > 0 {
if err := writeRelatedParts(w, kr, boundary, msg, inlineAtts, inlineData, opts); err != nil {
return nil, err
}
} else if err := writeTextPart(w, kr, msg, opts); err != nil {
return nil, err
}
for i, att := range attachAtts {
if err := writeAttachmentPart(w, kr, att, attachData[i], opts); err != nil {
return nil, err
}
}
if err := w.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func writeTextPart(
w *message.Writer,
kr *crypto.KeyRing,
msg *pmapi.Message,
opts JobOptions,
) error {
dec, err := msg.Decrypt(kr)
if err != nil {
if !opts.IgnoreDecryptionErrors {
return errors.Wrap(ErrDecryptionFailed, err.Error())
}
return writeCustomTextPart(w, msg, err)
}
return writePart(w, getTextPartHeader(message.Header{}, dec, msg.MIMEType), dec)
}
func writeAttachmentPart(
w *message.Writer,
kr *crypto.KeyRing,
att *pmapi.Attachment,
attData []byte,
opts JobOptions,
) error {
kps, err := base64.StdEncoding.DecodeString(att.KeyPackets)
if err != nil {
return err
}
msg := crypto.NewPGPSplitMessage(kps, attData).GetPGPMessage()
dec, err := kr.Decrypt(msg, nil, crypto.GetUnixTime())
if err != nil {
if !opts.IgnoreDecryptionErrors {
return errors.Wrap(ErrDecryptionFailed, err.Error())
}
log.
WithField("attID", att.ID).
WithField("msgID", att.MessageID).
WithError(err).
Warn("Attachment decryption failed")
return writeCustomAttachmentPart(w, att, msg, err)
}
return writePart(w, getAttachmentPartHeader(att), dec.GetBinary())
}
func writeRelatedParts(
w *message.Writer,
kr *crypto.KeyRing,
boundary *boundary,
msg *pmapi.Message,
atts []*pmapi.Attachment,
attData [][]byte,
opts JobOptions,
) error {
hdr := message.Header{}
hdr.SetContentType("multipart/related", map[string]string{"boundary": boundary.gen()})
return createPart(w, hdr, func(rel *message.Writer) error {
if err := writeTextPart(rel, kr, msg, opts); err != nil {
return err
}
for i, att := range atts {
if err := writeAttachmentPart(rel, kr, att, attData[i], opts); err != nil {
return err
}
}
return nil
})
}
func buildPGPRFC822(kr *crypto.KeyRing, msg *pmapi.Message, opts JobOptions) ([]byte, error) {
dec, err := msg.Decrypt(kr)
if err != nil {
if !opts.IgnoreDecryptionErrors {
return nil, errors.Wrap(ErrDecryptionFailed, err.Error())
}
return buildPGPMIMEFallbackRFC822(msg, opts)
}
hdr := getMessageHeader(msg, opts)
sigs, err := msg.ExtractSignatures(kr)
if err != nil {
log.WithError(err).WithField("id", msg.ID).Warn("Extract signature failed")
}
if len(sigs) > 0 {
return writeMultipartSignedRFC822(hdr, dec, sigs[0])
}
return writeMultipartEncryptedRFC822(hdr, dec)
}
func buildPGPMIMEFallbackRFC822(msg *pmapi.Message, opts JobOptions) ([]byte, error) {
hdr := getMessageHeader(msg, opts)
hdr.SetContentType("multipart/encrypted", map[string]string{
"boundary": newBoundary(msg.ID).gen(),
"protocol": "application/pgp-encrypted",
})
buf := new(bytes.Buffer)
w, err := message.CreateWriter(buf, hdr)
if err != nil {
return nil, err
}
var encHdr message.Header
encHdr.SetContentType("application/pgp-encrypted", nil)
encHdr.Set("Content-Description", "PGP/MIME version identification")
if err := writePart(w, encHdr, []byte("Version: 1")); err != nil {
return nil, err
}
var dataHdr message.Header
dataHdr.SetContentType("application/octet-stream", map[string]string{"name": "encrypted.asc"})
dataHdr.SetContentDisposition("inline", map[string]string{"filename": "encrypted.asc"})
dataHdr.Set("Content-Description", "OpenPGP encrypted message")
if err := writePart(w, dataHdr, []byte(msg.Body)); err != nil {
return nil, err
}
if err := w.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func writeMultipartSignedRFC822(header message.Header, body []byte, sig pmapi.Signature) ([]byte, error) { //nolint:funlen
buf := new(bytes.Buffer)
boundary := newBoundary("").gen()
header.SetContentType("multipart/signed", map[string]string{
"micalg": sig.Hash,
"protocol": "application/pgp-signature",
"boundary": boundary,
})
if err := textproto.WriteHeader(buf, header.Header); err != nil {
return nil, err
}
mw := textproto.NewMultipartWriter(buf)
if err := mw.SetBoundary(boundary); err != nil {
return nil, err
}
bodyHeader, bodyData, err := readHeaderBody(body)
if err != nil {
return nil, err
}
bodyPart, err := mw.CreatePart(*bodyHeader)
if err != nil {
return nil, err
}
if _, err := bodyPart.Write(bodyData); err != nil {
return nil, err
}
var sigHeader message.Header
sigHeader.SetContentType("application/pgp-signature", map[string]string{"name": "OpenPGP_signature.asc"})
sigHeader.SetContentDisposition("attachment", map[string]string{"filename": "OpenPGP_signature"})
sigHeader.Set("Content-Description", "OpenPGP digital signature")
sigPart, err := mw.CreatePart(sigHeader.Header)
if err != nil {
return nil, err
}
sigData, err := crypto.NewPGPSignature(sig.Data).GetArmored()
if err != nil {
return nil, err
}
if _, err := sigPart.Write([]byte(sigData)); err != nil {
return nil, err
}
if err := mw.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func writeMultipartEncryptedRFC822(header message.Header, body []byte) ([]byte, error) {
buf := new(bytes.Buffer)
bodyHeader, bodyData, err := readHeaderBody(body)
if err != nil {
return nil, err
}
// If parsed header is empty then either it is malformed or it is missing.
// Anyway message could not be considered multipart/mixed anymore since there will be no boundary.
if bodyHeader.Len() == 0 {
header.Del("Content-Type")
}
entFields := bodyHeader.Fields()
for entFields.Next() {
header.Set(entFields.Key(), entFields.Value())
}
if err := textproto.WriteHeader(buf, header.Header); err != nil {
return nil, err
}
if _, err := buf.Write(bodyData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func getMessageHeader(msg *pmapi.Message, opts JobOptions) message.Header { //nolint:funlen
hdr := toMessageHeader(msg.Header)
// SetText will RFC2047-encode.
if msg.Subject != "" {
hdr.SetText("Subject", msg.Subject)
}
// mail.Address.String() will RFC2047-encode if necessary.
if msg.Sender != nil {
hdr.Set("From", msg.Sender.String())
}
if len(msg.ReplyTos) > 0 {
hdr.Set("Reply-To", toAddressList(msg.ReplyTos))
}
if len(msg.ToList) > 0 {
hdr.Set("To", toAddressList(msg.ToList))
}
if len(msg.CCList) > 0 {
hdr.Set("Cc", toAddressList(msg.CCList))
}
if len(msg.BCCList) > 0 {
hdr.Set("Bcc", toAddressList(msg.BCCList))
}
setMessageIDIfNeeded(msg, &hdr)
// Sanitize the date; it needs to have a valid unix timestamp.
if opts.SanitizeDate {
if date, err := rfc5322.ParseDateTime(hdr.Get("Date")); err != nil || date.Before(time.Unix(0, 0)) {
msgDate := SanitizeMessageDate(msg.Time)
hdr.Set("Date", msgDate.In(time.UTC).Format(time.RFC1123Z))
// We clobbered the date so we save it under X-Original-Date.
hdr.Set("X-Original-Date", date.In(time.UTC).Format(time.RFC1123Z))
}
}
// Set our internal ID if requested.
// This is important for us to detect whether APPENDed things are actually "move like outlook".
if opts.AddInternalID {
hdr.Set("X-Pm-Internal-Id", msg.ID)
}
// Set our external ID if requested.
// This was useful during debugging of applemail recovered messages; doesn't help with any behaviour.
if opts.AddExternalID {
hdr.Set("X-Pm-External-Id", "<"+msg.ExternalID+">")
}
// Set our server date if requested.
// Can be useful to see how long it took for a message to arrive.
if opts.AddMessageDate {
hdr.Set("X-Pm-Date", time.Unix(msg.Time, 0).In(time.UTC).Format(time.RFC1123Z))
}
// Include the message ID in the references (supposedly this somehow improves outlook support...).
if opts.AddMessageIDReference {
if references := hdr.Get("References"); !strings.Contains(references, msg.ID) {
hdr.Set("References", references+" <"+msg.ID+"@"+pmapi.InternalIDDomain+">")
}
}
return hdr
}
// SanitizeMessageDate will return time from msgTime timestamp. If timestamp is
// not after epoch the RFC822 publish day will be used. No message should
// realistically be older than RFC822 itself.
func SanitizeMessageDate(msgTime int64) time.Time {
if msgTime := time.Unix(msgTime, 0); msgTime.After(time.Unix(0, 0)) {
return msgTime
}
return time.Date(1982, 8, 13, 0, 0, 0, 0, time.UTC)
}
// setMessageIDIfNeeded sets Message-Id from ExternalID or ID if it's not
// already set.
func setMessageIDIfNeeded(msg *pmapi.Message, hdr *message.Header) {
if hdr.Get("Message-Id") == "" {
if msg.ExternalID != "" {
hdr.Set("Message-Id", "<"+msg.ExternalID+">")
} else {
hdr.Set("Message-Id", "<"+msg.ID+"@"+pmapi.InternalIDDomain+">")
}
}
}
func getTextPartHeader(hdr message.Header, body []byte, mimeType string) message.Header {
params := make(map[string]string)
if utf8.Valid(body) {
params["charset"] = "utf-8"
}
hdr.SetContentType(mimeType, params)
// Use quoted-printable for all text/... parts
hdr.Set("Content-Transfer-Encoding", "quoted-printable")
return hdr
}
func getAttachmentPartHeader(att *pmapi.Attachment) message.Header {
hdr := toMessageHeader(mail.Header(att.Header))
// All attachments have a content type.
hdr.SetContentType(att.MIMEType, map[string]string{"name": mime.QEncoding.Encode("utf-8", att.Name)})
// All attachments have a content disposition.
hdr.SetContentDisposition(att.Disposition, map[string]string{"filename": mime.QEncoding.Encode("utf-8", att.Name)})
// Use base64 for all attachments except embedded RFC822 messages.
if att.MIMEType != rfc822Message {
hdr.Set("Content-Transfer-Encoding", "base64")
} else {
hdr.Del("Content-Transfer-Encoding")
}
return hdr
}
func toMessageHeader(hdr mail.Header) message.Header {
var res message.Header
for key, val := range hdr {
for _, val := range val {
// Using AddRaw instead of Add to save key-value pair as byte buffer within Header.
// This buffer is used latter on in message writer to construct message and avoid crash
// when key length is more than 76 characters long.
res.AddRaw([]byte(key + ": " + val + "\r\n"))
}
}
return res
}
func toAddressList(addrs []*mail.Address) string {
res := make([]string, len(addrs))
for i, addr := range addrs {
res[i] = addr.String()
}
return strings.Join(res, ", ")
}
func createPart(w *message.Writer, hdr message.Header, fn func(*message.Writer) error) error {
part, err := w.CreatePart(hdr)
if err != nil {
return err
}
if err := fn(part); err != nil {
return err
}
return part.Close()
}
func writePart(w *message.Writer, hdr message.Header, body []byte) error {
return createPart(w, hdr, func(part *message.Writer) error {
if _, err := part.Write(body); err != nil {
return errors.Wrap(err, "failed to write part body")
}
return nil
})
}

File diff suppressed because it is too large Load Diff

View File

@ -1,245 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"bytes"
"encoding/base64"
"io"
"mime"
"mime/quotedprintable"
"strings"
"github.com/ProtonMail/gopenpgp/v2/crypto"
pmmime "github.com/ProtonMail/proton-bridge/v2/pkg/mime"
"github.com/emersion/go-message/textproto"
"github.com/pkg/errors"
)
func EncryptRFC822(kr *crypto.KeyRing, r io.Reader) ([]byte, error) {
b, err := io.ReadAll(r)
if err != nil {
return nil, err
}
header, body, err := readHeaderBody(b)
if err != nil {
return nil, err
}
buf := new(bytes.Buffer)
result, err := writeEncryptedPart(kr, header, bytes.NewReader(body))
if err != nil {
return nil, err
}
if err := textproto.WriteHeader(buf, *header); err != nil {
return nil, err
}
if _, err := result.WriteTo(buf); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func writeEncryptedPart(kr *crypto.KeyRing, header *textproto.Header, r io.Reader) (io.WriterTo, error) {
decoder := getTransferDecoder(r, header.Get("Content-Transfer-Encoding"))
encoded := new(bytes.Buffer)
contentType, contentParams, err := parseContentType(header.Get("Content-Type"))
// Ignoring invalid media parameter makes it work for invalid tutanota RFC2047-encoded attachment filenames since we often only really need the content type and not the optional media parameters.
if err != nil && !errors.Is(err, mime.ErrInvalidMediaParameter) {
return nil, err
}
switch {
case contentType == "", strings.HasPrefix(contentType, "text/"), strings.HasPrefix(contentType, "message/"):
header.Del("Content-Transfer-Encoding")
if charset, ok := contentParams["charset"]; ok {
if reader, err := pmmime.CharsetReader(charset, decoder); err == nil {
decoder = reader
// We can decode the charset to utf-8 so let's set that as the content type charset parameter.
contentParams["charset"] = "utf-8"
header.Set("Content-Type", mime.FormatMediaType(contentType, contentParams))
}
}
if err := encode(&writeCloser{encoded}, func(w io.Writer) error {
return writeEncryptedTextPart(w, decoder, kr)
}); err != nil {
return nil, err
}
case contentType == "multipart/encrypted":
if _, err := encoded.ReadFrom(decoder); err != nil {
return nil, err
}
case strings.HasPrefix(contentType, "multipart/"):
if err := encode(&writeCloser{encoded}, func(w io.Writer) error {
return writeEncryptedMultiPart(kr, w, header, decoder)
}); err != nil {
return nil, err
}
default:
header.Set("Content-Transfer-Encoding", "base64")
if err := encode(base64.NewEncoder(base64.StdEncoding, encoded), func(w io.Writer) error {
return writeEncryptedAttachmentPart(w, decoder, kr)
}); err != nil {
return nil, err
}
}
return encoded, nil
}
func writeEncryptedTextPart(w io.Writer, r io.Reader, kr *crypto.KeyRing) error {
dec, err := io.ReadAll(r)
if err != nil {
return err
}
var arm string
if msg, err := crypto.NewPGPMessageFromArmored(string(dec)); err != nil {
enc, err := kr.Encrypt(crypto.NewPlainMessage(dec), kr)
if err != nil {
return err
}
if arm, err = enc.GetArmored(); err != nil {
return err
}
} else if arm, err = msg.GetArmored(); err != nil {
return err
}
if _, err := io.WriteString(w, arm); err != nil {
return err
}
return nil
}
func writeEncryptedAttachmentPart(w io.Writer, r io.Reader, kr *crypto.KeyRing) error {
dec, err := io.ReadAll(r)
if err != nil {
return err
}
enc, err := kr.Encrypt(crypto.NewPlainMessage(dec), kr)
if err != nil {
return err
}
if _, err := w.Write(enc.GetBinary()); err != nil {
return err
}
return nil
}
func writeEncryptedMultiPart(kr *crypto.KeyRing, w io.Writer, header *textproto.Header, r io.Reader) error {
_, contentParams, err := parseContentType(header.Get("Content-Type"))
if err != nil {
return err
}
scanner, err := newPartScanner(r, contentParams["boundary"])
if err != nil {
return err
}
parts, err := scanner.scanAll()
if err != nil {
return err
}
writer := newPartWriter(w, contentParams["boundary"])
for _, part := range parts {
header, body, err := readHeaderBody(part.b)
if err != nil {
return err
}
result, err := writeEncryptedPart(kr, header, bytes.NewReader(body))
if err != nil {
return err
}
if err := writer.createPart(func(w io.Writer) error {
if err := textproto.WriteHeader(w, *header); err != nil {
return err
}
if _, err := result.WriteTo(w); err != nil {
return err
}
return nil
}); err != nil {
return err
}
}
return writer.done()
}
func getTransferDecoder(r io.Reader, encoding string) io.Reader {
switch strings.ToLower(encoding) {
case "base64":
return base64.NewDecoder(base64.StdEncoding, r)
case "quoted-printable":
return quotedprintable.NewReader(r)
default:
return r
}
}
func encode(wc io.WriteCloser, fn func(io.Writer) error) error {
if err := fn(wc); err != nil {
return err
}
return wc.Close()
}
type writeCloser struct {
io.Writer
}
func (writeCloser) Close() error { return nil }
func parseContentType(val string) (string, map[string]string, error) {
if val == "" {
val = "text/plain"
}
return pmmime.ParseMediaType(val)
}

View File

@ -1,101 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"bytes"
"os"
"testing"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/stretchr/testify/require"
)
func TestEncryptRFC822(t *testing.T) {
literal, err := os.ReadFile("testdata/text_plain_latin1.eml")
require.NoError(t, err)
key, err := crypto.GenerateKey("name", "email", "rsa", 2048)
require.NoError(t, err)
kr, err := crypto.NewKeyRing(key)
require.NoError(t, err)
enc, err := EncryptRFC822(kr, bytes.NewReader(literal))
require.NoError(t, err)
section(t, enc).
expectContentType(is(`text/plain`)).
expectContentTypeParam(`charset`, is(`utf-8`)).
expectBody(decryptsTo(kr, `ééééééé`))
}
func TestEncryptRFC822Multipart(t *testing.T) {
literal, err := os.ReadFile("testdata/multipart_alternative_nested.eml")
require.NoError(t, err)
key, err := crypto.GenerateKey("name", "email", "rsa", 2048)
require.NoError(t, err)
kr, err := crypto.NewKeyRing(key)
require.NoError(t, err)
enc, err := EncryptRFC822(kr, bytes.NewReader(literal))
require.NoError(t, err)
section(t, enc).
expectContentType(is(`multipart/alternative`))
section(t, enc, 1).
expectContentType(is(`multipart/alternative`))
section(t, enc, 1, 1).
expectContentType(is(`text/plain`)).
expectBody(decryptsTo(kr, "*multipart 1.1*\n\n"))
section(t, enc, 1, 2).
expectContentType(is(`text/html`)).
expectBody(decryptsTo(kr, `<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8">
</head>
<body>
<b>multipart 1.2</b>
</body>
</html>
`))
section(t, enc, 2).
expectContentType(is(`multipart/alternative`))
section(t, enc, 2, 1).
expectContentType(is(`text/plain`)).
expectBody(decryptsTo(kr, "*multipart 2.1*\n\n"))
section(t, enc, 2, 2).
expectContentType(is(`text/html`)).
expectBody(decryptsTo(kr, `<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8">
</head>
<body>
<b>multipart 2.2</b>
</body>
</html>
`))
}

View File

@ -1,67 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"net/mail"
"net/textproto"
"strings"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"github.com/emersion/go-imap"
)
// GetEnvelope will prepare envelope from pmapi message and cached header.
func GetEnvelope(msg *pmapi.Message, header textproto.MIMEHeader) *imap.Envelope {
hdr := toMessageHeader(mail.Header(header))
setMessageIDIfNeeded(msg, &hdr)
return &imap.Envelope{
Date: SanitizeMessageDate(msg.Time),
Subject: msg.Subject,
From: getAddresses([]*mail.Address{msg.Sender}),
Sender: getAddresses([]*mail.Address{msg.Sender}),
ReplyTo: getAddresses(msg.ReplyTos),
To: getAddresses(msg.ToList),
Cc: getAddresses(msg.CCList),
Bcc: getAddresses(msg.BCCList),
InReplyTo: hdr.Get("In-Reply-To"),
MessageId: hdr.Get("Message-Id"),
}
}
func getAddresses(addrs []*mail.Address) (imapAddrs []*imap.Address) {
for _, a := range addrs {
if a == nil {
continue
}
parts := strings.SplitN(a.Address, "@", 2)
if len(parts) != 2 {
continue
}
imapAddrs = append(imapAddrs, &imap.Address{
PersonalName: a.Name,
MailboxName: parts[0],
HostName: parts[1],
})
}
return
}

View File

@ -1,61 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"github.com/emersion/go-imap"
)
// Various client specific flags.
const (
AppleMailJunkFlag = "$Junk"
ThunderbirdJunkFlag = "Junk"
ThunderbirdNonJunkFlag = "NonJunk"
)
// GetFlags returns imap flags from pmapi message attributes.
func GetFlags(m *pmapi.Message) (flags []string) {
if !m.Unread {
flags = append(flags, imap.SeenFlag)
}
if !m.Has(pmapi.FlagSent) && !m.Has(pmapi.FlagReceived) {
flags = append(flags, imap.DraftFlag)
}
if m.Has(pmapi.FlagReplied) || m.Has(pmapi.FlagRepliedAll) {
flags = append(flags, imap.AnsweredFlag)
}
hasSpam := false
for _, l := range m.LabelIDs {
if l == pmapi.StarredLabel {
flags = append(flags, imap.FlaggedFlag)
}
if l == pmapi.SpamLabel {
flags = append(flags, AppleMailJunkFlag, ThunderbirdJunkFlag)
hasSpam = true
}
}
if !hasSpam {
flags = append(flags, ThunderbirdNonJunkFlag)
}
return
}

View File

@ -1,27 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"github.com/ProtonMail/go-rfc5322"
pmmime "github.com/ProtonMail/proton-bridge/v2/pkg/mime"
)
func init() { //nolint:gochecknoinits
rfc5322.CharsetReader = pmmime.CharsetReader
}

View File

@ -23,8 +23,4 @@ import (
"github.com/sirupsen/logrus"
)
const (
rfc822Message = "message/rfc822"
)
var log = logrus.WithField("pkg", "pkg/message") //nolint:gochecknoglobals

View File

@ -1,83 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/v2/pkg/message (interfaces: Fetcher)
// Package mocks is a generated GoMock package.
package mocks
import (
context "context"
io "io"
reflect "reflect"
crypto "github.com/ProtonMail/gopenpgp/v2/crypto"
pmapi "github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
)
// MockFetcher is a mock of Fetcher interface.
type MockFetcher struct {
ctrl *gomock.Controller
recorder *MockFetcherMockRecorder
}
// MockFetcherMockRecorder is the mock recorder for MockFetcher.
type MockFetcherMockRecorder struct {
mock *MockFetcher
}
// NewMockFetcher creates a new mock instance.
func NewMockFetcher(ctrl *gomock.Controller) *MockFetcher {
mock := &MockFetcher{ctrl: ctrl}
mock.recorder = &MockFetcherMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockFetcher) EXPECT() *MockFetcherMockRecorder {
return m.recorder
}
// GetAttachment mocks base method.
func (m *MockFetcher) GetAttachment(arg0 context.Context, arg1 string) (io.ReadCloser, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAttachment", arg0, arg1)
ret0, _ := ret[0].(io.ReadCloser)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAttachment indicates an expected call of GetAttachment.
func (mr *MockFetcherMockRecorder) GetAttachment(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachment", reflect.TypeOf((*MockFetcher)(nil).GetAttachment), arg0, arg1)
}
// GetMessage mocks base method.
func (m *MockFetcher) GetMessage(arg0 context.Context, arg1 string) (*pmapi.Message, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMessage", arg0, arg1)
ret0, _ := ret[0].(*pmapi.Message)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMessage indicates an expected call of GetMessage.
func (mr *MockFetcherMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockFetcher)(nil).GetMessage), arg0, arg1)
}
// KeyRingForAddressID mocks base method.
func (m *MockFetcher) KeyRingForAddressID(arg0 string) (*crypto.KeyRing, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeyRingForAddressID", arg0)
ret0, _ := ret[0].(*crypto.KeyRing)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// KeyRingForAddressID indicates an expected call of KeyRingForAddressID.
func (mr *MockFetcherMockRecorder) KeyRingForAddressID(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyRingForAddressID", reflect.TypeOf((*MockFetcher)(nil).KeyRingForAddressID), arg0)
}

View File

@ -23,98 +23,146 @@ import (
"io"
"mime"
"net/mail"
"net/textproto"
"regexp"
"strings"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-rfc5322"
"github.com/ProtonMail/proton-bridge/v2/pkg/message/parser"
pmmime "github.com/ProtonMail/proton-bridge/v2/pkg/mime"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-message"
"github.com/jaytaylor/html2text"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi"
)
// Parse parses RAW message.
func Parse(r io.Reader) (m *pmapi.Message, mimeBody, plainBody string, attReaders []io.Reader, err error) {
defer func() {
r := recover()
if r == nil {
return
}
type MIMEBody string
err = fmt.Errorf("panic while parsing message: %v", r)
type Body string
type Message struct {
Header mail.Header
MIMEBody MIMEBody
RichBody Body
PlainBody Body
Time int64
ExternalID string
Subject string
Sender *mail.Address
ToList []*mail.Address
CCList []*mail.Address
BCCList []*mail.Address
ReplyTos []*mail.Address
MIMEType rfc822.MIMEType
Attachments []Attachment
}
func (m *Message) Recipients() []string {
var recipients []string
for _, addresses := range [][]*mail.Address{m.ToList, m.CCList, m.BCCList} {
recipients = append(recipients, xslices.Map(addresses, func(address *mail.Address) string {
return address.Address
})...)
}
return recipients
}
type Attachment struct {
Header mail.Header
Name string
ContentID string
MIMEType string
Disposition string
Data []byte
}
// Parse parses an RFC822 message.
func Parse(r io.Reader) (m Message, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic while parsing message: %v", r)
}
}()
p, err := parser.New(r)
if err != nil {
return nil, "", "", nil, errors.Wrap(err, "failed to create new parser")
return Message{}, errors.Wrap(err, "failed to create new parser")
}
m, plainBody, attReaders, err = ParserWithParser(p)
if err != nil {
return nil, "", "", nil, errors.Wrap(err, "failed to parse the message")
}
mimeBody, err = BuildMIMEBody(p)
if err != nil {
return nil, "", "", nil, errors.Wrap(err, "failed to build mime body")
}
return m, mimeBody, plainBody, attReaders, nil
return parse(p)
}
// ParserWithParser parses message from Parser without building MIME body.
func ParserWithParser(p *parser.Parser) (m *pmapi.Message, plainBody string, attReaders []io.Reader, err error) {
logrus.Trace("Parsing message")
// Parse parses an RFC822 message using an existing parser.
func ParseWithParser(p *parser.Parser) (m Message, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic while parsing message: %v", r)
}
}()
if err = convertEncodedTransferEncoding(p); err != nil {
err = errors.Wrap(err, "failed to convert encoded transfer encodings")
return
}
if err = convertForeignEncodings(p); err != nil {
err = errors.Wrap(err, "failed to convert foreign encodings")
return
}
m = pmapi.NewMessage()
if err = parseMessageHeader(m, p.Root().Header); err != nil {
err = errors.Wrap(err, "failed to parse message header")
return
}
if m.Attachments, attReaders, err = collectAttachments(p); err != nil {
err = errors.Wrap(err, "failed to collect attachments")
return
}
if m.Body, plainBody, err = buildBodies(p); err != nil {
err = errors.Wrap(err, "failed to build bodies")
return
}
if m.MIMEType, err = determineMIMEType(p); err != nil {
err = errors.Wrap(err, "failed to determine mime type")
return
}
return m, plainBody, attReaders, nil
return parse(p)
}
// BuildMIMEBody builds mime body from the parser returned by NewParser.
func BuildMIMEBody(p *parser.Parser) (mimeBody string, err error) {
mimeBodyBuffer := new(bytes.Buffer)
if err = p.NewWriter().Write(mimeBodyBuffer); err != nil {
err = errors.Wrap(err, "failed to write out mime message")
return
func parse(p *parser.Parser) (Message, error) {
if err := convertEncodedTransferEncoding(p); err != nil {
return Message{}, errors.Wrap(err, "failed to convert encoded transfer encoding")
}
return mimeBodyBuffer.String(), nil
if err := convertForeignEncodings(p); err != nil {
return Message{}, errors.Wrap(err, "failed to convert foreign encodings")
}
m, err := parseMessageHeader(p.Root().Header)
if err != nil {
return Message{}, errors.Wrap(err, "failed to parse message header")
}
atts, err := collectAttachments(p)
if err != nil {
return Message{}, errors.Wrap(err, "failed to collect attachments")
}
m.Attachments = atts
richBody, plainBody, err := buildBodies(p)
if err != nil {
return Message{}, errors.Wrap(err, "failed to build bodies")
}
mimeBody, err := buildMIMEBody(p)
if err != nil {
return Message{}, errors.Wrap(err, "failed to build mime body")
}
m.RichBody = Body(richBody)
m.PlainBody = Body(plainBody)
m.MIMEBody = MIMEBody(mimeBody)
mimeType, err := determineMIMEType(p)
if err != nil {
return Message{}, errors.Wrap(err, "failed to get mime type")
}
m.MIMEType = rfc822.MIMEType(mimeType)
return m, nil
}
// buildMIMEBody builds mime body from the parser returned by NewParser.
func buildMIMEBody(p *parser.Parser) (mimeBody string, err error) {
buf := new(bytes.Buffer)
if err := p.NewWriter().Write(buf); err != nil {
return "", fmt.Errorf("failed to write message: %w", err)
}
return buf.String(), nil
}
// convertEncodedTransferEncoding decodes any RFC2047-encoded content transfer encodings.
@ -158,33 +206,30 @@ func convertForeignEncodings(p *parser.Parser) error {
Walk()
}
func collectAttachments(p *parser.Parser) ([]*pmapi.Attachment, []io.Reader, error) {
func collectAttachments(p *parser.Parser) ([]Attachment, error) {
var (
atts []*pmapi.Attachment
data []io.Reader
atts []Attachment
err error
)
w := p.NewWalker().
RegisterContentDispositionHandler("attachment", func(p *parser.Part) error {
att, err := parseAttachment(p.Header)
att, err := parseAttachment(p.Header, p.Body)
if err != nil {
return err
}
atts = append(atts, att)
data = append(data, bytes.NewReader(p.Body))
return nil
}).
RegisterContentTypeHandler("text/calendar", func(p *parser.Part) error {
att, err := parseAttachment(p.Header)
att, err := parseAttachment(p.Header, p.Body)
if err != nil {
return err
}
atts = append(atts, att)
data = append(data, bytes.NewReader(p.Body))
return nil
}).
@ -196,22 +241,21 @@ func collectAttachments(p *parser.Parser) ([]*pmapi.Attachment, []io.Reader, err
return nil
}
att, err := parseAttachment(p.Header)
att, err := parseAttachment(p.Header, p.Body)
if err != nil {
return err
}
atts = append(atts, att)
data = append(data, bytes.NewReader(p.Body))
return nil
})
if err = w.Walk(); err != nil {
return nil, nil, err
return nil, err
}
return atts, data, nil
return atts, nil
}
// buildBodies collects all text/html and text/plain parts and returns two bodies,
@ -400,24 +444,14 @@ func getPlainBody(part *parser.Part) []byte {
}
}
func AttachPublicKey(p *parser.Parser, key, keyName string) {
h := message.Header{}
func parseMessageHeader(h message.Header) (Message, error) { //nolint:funlen
var m Message
h.Set("Content-Type", fmt.Sprintf(`application/pgp-keys; name="%v.asc"; filename="%v.asc"`, keyName, keyName))
h.Set("Content-Disposition", fmt.Sprintf(`attachment; name="%v.asc"; filename="%v.asc"`, keyName, keyName))
h.Set("Content-Transfer-Encoding", "base64")
p.Root().AddChild(&parser.Part{
Header: h,
Body: []byte(key),
})
}
func parseMessageHeader(m *pmapi.Message, h message.Header) error { //nolint:funlen
mimeHeader, err := toMailHeader(h)
if err != nil {
return err
return Message{}, err
}
m.Header = mimeHeader
fields := h.Fields()
@ -428,7 +462,7 @@ func parseMessageHeader(m *pmapi.Message, h message.Header) error { //nolint:fun
s, err := fields.Text()
if err != nil {
if s, err = pmmime.DecodeHeader(fields.Value()); err != nil {
return errors.Wrap(err, "failed to parse subject")
return Message{}, errors.Wrap(err, "failed to parse subject")
}
}
@ -437,7 +471,7 @@ func parseMessageHeader(m *pmapi.Message, h message.Header) error { //nolint:fun
case "from":
sender, err := rfc5322.ParseAddressList(fields.Value())
if err != nil {
return errors.Wrap(err, "failed to parse from")
return Message{}, errors.Wrap(err, "failed to parse from")
}
if len(sender) > 0 {
m.Sender = sender[0]
@ -446,35 +480,35 @@ func parseMessageHeader(m *pmapi.Message, h message.Header) error { //nolint:fun
case "to":
toList, err := rfc5322.ParseAddressList(fields.Value())
if err != nil {
return errors.Wrap(err, "failed to parse to")
return Message{}, errors.Wrap(err, "failed to parse to")
}
m.ToList = toList
case "reply-to":
replyTos, err := rfc5322.ParseAddressList(fields.Value())
if err != nil {
return errors.Wrap(err, "failed to parse reply-to")
return Message{}, errors.Wrap(err, "failed to parse reply-to")
}
m.ReplyTos = replyTos
case "cc":
ccList, err := rfc5322.ParseAddressList(fields.Value())
if err != nil {
return errors.Wrap(err, "failed to parse cc")
return Message{}, errors.Wrap(err, "failed to parse cc")
}
m.CCList = ccList
case "bcc":
bccList, err := rfc5322.ParseAddressList(fields.Value())
if err != nil {
return errors.Wrap(err, "failed to parse bcc")
return Message{}, errors.Wrap(err, "failed to parse bcc")
}
m.BCCList = bccList
case "date":
date, err := rfc5322.ParseDateTime(fields.Value())
if err != nil {
return errors.Wrap(err, "failed to parse date")
return Message{}, errors.Wrap(err, "failed to parse date")
}
m.Time = date.Unix()
@ -483,48 +517,47 @@ func parseMessageHeader(m *pmapi.Message, h message.Header) error { //nolint:fun
}
}
return nil
return m, nil
}
func parseAttachment(h message.Header) (*pmapi.Attachment, error) {
att := &pmapi.Attachment{}
func parseAttachment(h message.Header, body []byte) (Attachment, error) {
att := Attachment{
Data: body,
}
mimeHeader, err := toMIMEHeader(h)
mimeHeader, err := toMailHeader(h)
if err != nil {
return nil, err
return Attachment{}, err
}
att.Header = mimeHeader
mimeType, mimeTypeParams, err := h.ContentType()
if err != nil {
return nil, err
return Attachment{}, err
}
att.MIMEType = mimeType
// Prefer attachment name from filename param in content disposition.
// If not available, try to get it from name param in content type.
// Otherwise fallback to attachment.bin.
_, dispParams, dispErr := h.ContentDisposition()
if dispErr != nil {
ext, err := mime.ExtensionsByType(att.MIMEType)
if err != nil {
return nil, err
}
if disp, dispParams, err := h.ContentDisposition(); err == nil {
att.Disposition = disp
if len(ext) > 0 {
att.Name = "attachment" + ext[0]
if filename, ok := dispParams["filename"]; ok {
att.Name = filename
}
} else {
att.Name = dispParams["filename"]
}
if att.Name == "" {
att.Name = mimeTypeParams["name"]
}
if att.Name == "" && mimeType == rfc822Message {
att.Name = "message.eml"
}
if att.Name == "" {
att.Name = "attachment.bin"
if filename, ok := mimeTypeParams["name"]; ok {
att.Name = filename
} else if mimeType == string(rfc822.MessageRFC822) {
att.Name = "message.eml"
} else if ext, err := mime.ExtensionsByType(att.MIMEType); err == nil && len(ext) > 0 {
att.Name = "attachment" + ext[0]
} else {
att.Name = "attachment.bin"
}
}
// Only set ContentID if it should be inline;
@ -534,9 +567,12 @@ func parseAttachment(h message.Header) (*pmapi.Attachment, error) {
// (This is necessary because some clients don't set Content-Disposition at all,
// so we need to rely on other information to deduce if it's inline or attachment.)
if h.Has("Content-Disposition") {
if disp, _, err := h.ContentDisposition(); err != nil {
return nil, err
} else if disp == pmapi.DispositionInline {
disp, _, err := h.ContentDisposition()
if err != nil {
return Attachment{}, err
}
if disp == string(liteapi.InlineDisposition) {
att.ContentID = strings.Trim(h.Get("Content-Id"), " <>")
}
} else if h.Has("Content-Id") {
@ -559,19 +595,6 @@ func toMailHeader(h message.Header) (mail.Header, error) {
return mimeHeader, nil
}
func toMIMEHeader(h message.Header) (textproto.MIMEHeader, error) {
mimeHeader := make(textproto.MIMEHeader)
if err := forEachDecodedHeaderField(h, func(key, val string) error {
mimeHeader[key] = []string{val}
return nil
}); err != nil {
return nil, err
}
return mimeHeader, nil
}
func forEachDecodedHeaderField(h message.Header, fn func(string, string) error) error {
fields := h.Fields()

View File

@ -18,6 +18,7 @@
package parser
import (
"fmt"
"io"
"github.com/emersion/go-message"
@ -67,6 +68,19 @@ func (p *Parser) Root() *Part {
return p.root
}
func (p *Parser) AttachPublicKey(key, keyName string) {
h := message.Header{}
h.Set("Content-Type", fmt.Sprintf(`application/pgp-keys; name="%v.asc"; filename="%v.asc"`, keyName, keyName))
h.Set("Content-Disposition", fmt.Sprintf(`attachment; name="%v.asc"; filename="%v.asc"`, keyName, keyName))
h.Set("Content-Transfer-Encoding", "base64")
p.Root().AddChild(&Part{
Header: h,
Body: []byte(key),
})
}
// Section returns the message part referred to by the given section. A section
// is zero or more integers. For example, section 1.2.3 will return the third
// part of the second part of the first part of the message.

View File

@ -18,6 +18,7 @@
package message
import (
"bytes"
"image/png"
"io"
"os"
@ -33,129 +34,129 @@ import (
func TestParseLongHeaderLine(t *testing.T) {
f := getFileReader("long_header_line.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseLongHeaderLineMultiline(t *testing.T) {
f := getFileReader("long_header_line_multiline.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseTextPlain(t *testing.T) {
f := getFileReader("text_plain.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseTextPlainUTF8(t *testing.T) {
f := getFileReader("text_plain_utf8.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseTextPlainLatin1(t *testing.T) {
f := getFileReader("text_plain_latin1.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "ééééééé", m.Body)
assert.Equal(t, "ééééééé", plainBody)
assert.Equal(t, "ééééééé", string(m.RichBody))
assert.Equal(t, "ééééééé", string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseTextPlainUTF8Subject(t *testing.T) {
f := getFileReader("text_plain_utf8_subject.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, `汉字汉字汉`, m.Subject)
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseTextPlainLatin2Subject(t *testing.T) {
f := getFileReader("text_plain_latin2_subject.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, `If you can read this you understand the example.`, m.Subject)
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseTextPlainUnknownCharsetIsActuallyLatin1(t *testing.T) {
f := getFileReader("text_plain_unknown_latin1.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "ééééééé", m.Body)
assert.Equal(t, "ééééééé", plainBody)
assert.Equal(t, "ééééééé", string(m.RichBody))
assert.Equal(t, "ééééééé", string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseTextPlainUnknownCharsetIsActuallyLatin2(t *testing.T) {
f := getFileReader("text_plain_unknown_latin2.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
@ -167,97 +168,97 @@ func TestParseTextPlainUnknownCharsetIsActuallyLatin2(t *testing.T) {
expect, _ := charmap.ISO8859_1.NewDecoder().Bytes(latin2)
assert.NotEqual(t, []byte("řšřšřš"), expect)
assert.Equal(t, string(expect), m.Body)
assert.Equal(t, string(expect), plainBody)
assert.Equal(t, string(expect), string(m.RichBody))
assert.Equal(t, string(expect), string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseTextPlainAlready7Bit(t *testing.T) {
f := getFileReader("text_plain_7bit.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseTextPlainWithOctetAttachment(t *testing.T) {
f := getFileReader("text_plain_octet_attachment.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
require.Len(t, attReaders, 1)
assert.Equal(t, readerToString(attReaders[0]), "if you are reading this, hi!")
require.Len(t, m.Attachments, 1)
assert.Equal(t, string(m.Attachments[0].Data), "if you are reading this, hi!")
}
func TestParseTextPlainWithOctetAttachmentGoodFilename(t *testing.T) {
f := getFileReader("text_plain_octet_attachment_good_2231_filename.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
assert.Len(t, attReaders, 1)
assert.Equal(t, readerToString(attReaders[0]), "if you are reading this, hi!")
assert.Len(t, m.Attachments, 1)
assert.Equal(t, string(m.Attachments[0].Data), "if you are reading this, hi!")
assert.Equal(t, "😁😂.txt", m.Attachments[0].Name)
}
func TestParseTextPlainWithRFC822Attachment(t *testing.T) {
f := getFileReader("text_plain_rfc822_attachment.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
assert.Len(t, attReaders, 1)
assert.Len(t, m.Attachments, 1)
assert.Equal(t, "message.eml", m.Attachments[0].Name)
}
func TestParseTextPlainWithOctetAttachmentBadFilename(t *testing.T) {
f := getFileReader("text_plain_octet_attachment_bad_2231_filename.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
assert.Len(t, attReaders, 1)
assert.Equal(t, readerToString(attReaders[0]), "if you are reading this, hi!")
assert.Len(t, m.Attachments, 1)
assert.Equal(t, string(m.Attachments[0].Data), "if you are reading this, hi!")
assert.Equal(t, "attachment.bin", m.Attachments[0].Name)
}
func TestParseTextPlainWithOctetAttachmentNameInContentType(t *testing.T) {
f := getFileReader("text_plain_octet_attachment_name_in_contenttype.eml")
m, _, _, _, err := Parse(f) //nolint:dogsled
m, err := Parse(f) //nolint:dogsled
require.NoError(t, err)
assert.Equal(t, "attachment-contenttype.txt", m.Attachments[0].Name)
@ -266,7 +267,7 @@ func TestParseTextPlainWithOctetAttachmentNameInContentType(t *testing.T) {
func TestParseTextPlainWithOctetAttachmentNameConflict(t *testing.T) {
f := getFileReader("text_plain_octet_attachment_name_conflict.eml")
m, _, _, _, err := Parse(f) //nolint:dogsled
m, err := Parse(f) //nolint:dogsled
require.NoError(t, err)
assert.Equal(t, "attachment-disposition.txt", m.Attachments[0].Name)
@ -275,49 +276,49 @@ func TestParseTextPlainWithOctetAttachmentNameConflict(t *testing.T) {
func TestParseTextPlainWithPlainAttachment(t *testing.T) {
f := getFileReader("text_plain_plain_attachment.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
require.Len(t, attReaders, 1)
assert.Equal(t, readerToString(attReaders[0]), "attachment")
require.Len(t, m.Attachments, 1)
assert.Equal(t, string(m.Attachments[0].Data), "attachment")
}
func TestParseTextPlainEmptyAddresses(t *testing.T) {
f := getFileReader("text_plain_empty_addresses.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseTextPlainWithImageInline(t *testing.T) {
f := getFileReader("text_plain_image_inline.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
// The inline image is an 8x8 mic-dropping gopher.
require.Len(t, attReaders, 1)
img, err := png.DecodeConfig(attReaders[0])
require.Len(t, m.Attachments, 1)
img, err := png.DecodeConfig(bytes.NewReader(m.Attachments[0].Data))
require.NoError(t, err)
assert.Equal(t, 8, img.Width)
assert.Equal(t, 8, img.Height)
@ -326,111 +327,111 @@ func TestParseTextPlainWithImageInline(t *testing.T) {
func TestParseTextPlainWithDuplicateCharset(t *testing.T) {
f := getFileReader("text_plain_duplicate_charset.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseWithMultipleTextParts(t *testing.T) {
f := getFileReader("multiple_text_parts.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "body\nsome other part of the message", m.Body)
assert.Equal(t, "body\nsome other part of the message", plainBody)
assert.Equal(t, "body\nsome other part of the message", string(m.RichBody))
assert.Equal(t, "body\nsome other part of the message", string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseTextHTML(t *testing.T) {
f := getFileReader("text_html.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "<html><head></head><body>This is body of <b>HTML mail</b> without attachment</body></html>", m.Body)
assert.Equal(t, "This is body of *HTML mail* without attachment", plainBody)
assert.Equal(t, "<html><head></head><body>This is body of <b>HTML mail</b> without attachment</body></html>", string(m.RichBody))
assert.Equal(t, "This is body of *HTML mail* without attachment", string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseTextHTMLAlready7Bit(t *testing.T) {
f := getFileReader("text_html_7bit.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
assert.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "<html><head></head><body>This is body of <b>HTML mail</b> without attachment</body></html>", m.Body)
assert.Equal(t, "This is body of *HTML mail* without attachment", plainBody)
assert.Equal(t, "<html><head></head><body>This is body of <b>HTML mail</b> without attachment</body></html>", string(m.RichBody))
assert.Equal(t, "This is body of *HTML mail* without attachment", string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseTextHTMLWithOctetAttachment(t *testing.T) {
f := getFileReader("text_html_octet_attachment.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "<html><head></head><body>This is body of <b>HTML mail</b> with attachment</body></html>", m.Body)
assert.Equal(t, "This is body of *HTML mail* with attachment", plainBody)
assert.Equal(t, "<html><head></head><body>This is body of <b>HTML mail</b> with attachment</body></html>", string(m.RichBody))
assert.Equal(t, "This is body of *HTML mail* with attachment", string(m.PlainBody))
require.Len(t, attReaders, 1)
assert.Equal(t, readerToString(attReaders[0]), "if you are reading this, hi!")
require.Len(t, m.Attachments, 1)
assert.Equal(t, string(m.Attachments[0].Data), "if you are reading this, hi!")
}
func TestParseTextHTMLWithPlainAttachment(t *testing.T) {
f := getFileReader("text_html_plain_attachment.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
// BAD: plainBody should not be empty!
assert.Equal(t, "<html><head></head><body>This is body of <b>HTML mail</b> with attachment</body></html>", m.Body)
assert.Equal(t, "This is body of *HTML mail* with attachment", plainBody)
assert.Equal(t, "<html><head></head><body>This is body of <b>HTML mail</b> with attachment</body></html>", string(m.RichBody))
assert.Equal(t, "This is body of *HTML mail* with attachment", string(m.PlainBody))
require.Len(t, attReaders, 1)
assert.Equal(t, readerToString(attReaders[0]), "attachment")
require.Len(t, m.Attachments, 1)
assert.Equal(t, string(m.Attachments[0].Data), "attachment")
}
func TestParseTextHTMLWithImageInline(t *testing.T) {
f := getFileReader("text_html_image_inline.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
assert.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "<html><head></head><body>This is body of <b>HTML mail</b> with attachment</body></html>", m.Body)
assert.Equal(t, "This is body of *HTML mail* with attachment", plainBody)
assert.Equal(t, "<html><head></head><body>This is body of <b>HTML mail</b> with attachment</body></html>", string(m.RichBody))
assert.Equal(t, "This is body of *HTML mail* with attachment", string(m.PlainBody))
// The inline image is an 8x8 mic-dropping gopher.
require.Len(t, attReaders, 1)
img, err := png.DecodeConfig(attReaders[0])
require.Len(t, m.Attachments, 1)
img, err := png.DecodeConfig(bytes.NewReader(m.Attachments[0].Data))
require.NoError(t, err)
assert.Equal(t, 8, img.Width)
assert.Equal(t, 8, img.Height)
@ -441,40 +442,42 @@ func TestParseWithAttachedPublicKey(t *testing.T) {
p, err := parser.New(f)
require.NoError(t, err)
m, plainBody, attReaders, err := ParserWithParser(p)
AttachPublicKey(p, "publickey", "publickeyname")
m, err := ParseWithParser(p)
require.NoError(t, err)
p.AttachPublicKey("publickey", "publickeyname")
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, "body", m.Body)
assert.Equal(t, "body", plainBody)
assert.Equal(t, "body", string(m.RichBody))
assert.Equal(t, "body", string(m.PlainBody))
// The pubkey should not be collected as an attachment.
// We upload the pubkey when creating the draft.
require.Len(t, attReaders, 0)
require.Len(t, m.Attachments, 0)
}
func TestParseTextHTMLWithEmbeddedForeignEncoding(t *testing.T) {
f := getFileReader("text_html_embedded_foreign_encoding.eml")
m, _, plainBody, attReaders, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@pm.me>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@pm.me>`, m.ToList[0].String())
assert.Equal(t, `<html><head><meta charset="UTF-8"/></head><body>latin2 řšřš</body></html>`, m.Body)
assert.Equal(t, `latin2 řšřš`, plainBody)
assert.Equal(t, `<html><head><meta charset="UTF-8"/></head><body>latin2 řšřš</body></html>`, string(m.RichBody))
assert.Equal(t, `latin2 řšřš`, string(m.PlainBody))
assert.Len(t, attReaders, 0)
assert.Len(t, m.Attachments, 0)
}
func TestParseMultipartAlternative(t *testing.T) {
f := getFileReader("multipart_alternative.eml")
m, _, plainBody, _, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"schizofrenic" <schizofrenic@pm.me>`, m.Sender.String())
@ -487,15 +490,15 @@ func TestParseMultipartAlternative(t *testing.T) {
<b>aoeuaoeu</b>
</body></html>`, m.Body)
</body></html>`, string(m.RichBody))
assert.Equal(t, "*aoeuaoeu*\n\n", plainBody)
assert.Equal(t, "*aoeuaoeu*\n\n", string(m.PlainBody))
}
func TestParseMultipartAlternativeNested(t *testing.T) {
f := getFileReader("multipart_alternative_nested.eml")
m, _, plainBody, _, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"schizofrenic" <schizofrenic@pm.me>`, m.Sender.String())
@ -508,15 +511,15 @@ func TestParseMultipartAlternativeNested(t *testing.T) {
<b>multipart 2.2</b>
</body></html>`, m.Body)
</body></html>`, string(m.RichBody))
assert.Equal(t, "*multipart 2.1*\n\n", plainBody)
assert.Equal(t, "*multipart 2.1*\n\n", string(m.PlainBody))
}
func TestParseMultipartAlternativeLatin1(t *testing.T) {
f := getFileReader("multipart_alternative_latin1.eml")
m, _, plainBody, _, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"schizofrenic" <schizofrenic@pm.me>`, m.Sender.String())
@ -529,52 +532,52 @@ func TestParseMultipartAlternativeLatin1(t *testing.T) {
<b>aoeuaoeu</b>
</body></html>`, m.Body)
</body></html>`, string(m.RichBody))
assert.Equal(t, "*aoeuaoeu*\n\n", plainBody)
assert.Equal(t, "*aoeuaoeu*\n\n", string(m.PlainBody))
}
func TestParseWithTrailingEndOfMailIndicator(t *testing.T) {
f := getFileReader("text_html_trailing_end_of_mail.eml")
m, _, plainBody, _, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@sender.com>`, m.Sender.String())
assert.Equal(t, `"Receiver" <receiver@receiver.com>`, m.ToList[0].String())
assert.Equal(t, "<!DOCTYPE html><html><head></head><body>boo!</body></html>", m.Body)
assert.Equal(t, "boo!", plainBody)
assert.Equal(t, "<!DOCTYPE html><html><head></head><body>boo!</body></html>", string(m.RichBody))
assert.Equal(t, "boo!", string(m.PlainBody))
}
func TestParseEncodedContentType(t *testing.T) {
f := getFileReader("rfc2047-content-transfer-encoding.eml")
m, _, plainBody, _, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@sender.com>`, m.Sender.String())
assert.Equal(t, `<user@somewhere.org>`, m.ToList[0].String())
assert.Equal(t, "bodybodybody\n", plainBody)
assert.Equal(t, "bodybodybody\n", string(m.PlainBody))
}
func TestParseNonEncodedContentType(t *testing.T) {
f := getFileReader("non-encoded-content-transfer-encoding.eml")
m, _, plainBody, _, err := Parse(f)
m, err := Parse(f)
require.NoError(t, err)
assert.Equal(t, `"Sender" <sender@sender.com>`, m.Sender.String())
assert.Equal(t, `<user@somewhere.org>`, m.ToList[0].String())
assert.Equal(t, "bodybodybody\n", plainBody)
assert.Equal(t, "bodybodybody\n", string(m.PlainBody))
}
func TestParseEncodedContentTypeBad(t *testing.T) {
f := getFileReader("rfc2047-content-transfer-encoding-bad.eml")
_, _, _, _, err := Parse(f) //nolint:dogsled
_, err := Parse(f) //nolint:dogsled
require.Error(t, err)
}
@ -587,7 +590,7 @@ func (panicReader) Read(p []byte) (int, error) {
func TestParsePanic(t *testing.T) {
var err error
require.NotPanics(t, func() {
_, _, _, _, err = Parse(&panicReader{})
_, err = Parse(&panicReader{})
})
require.Error(t, err)
}
@ -600,12 +603,3 @@ func getFileReader(filename string) io.Reader {
return f
}
func readerToString(r io.Reader) string {
b, err := io.ReadAll(r)
if err != nil {
panic(err)
}
return string(b)
}

View File

@ -1,96 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"bufio"
"bytes"
"errors"
"io"
)
type partScanner struct {
r *bufio.Reader
boundary string
progress int
}
type part struct {
b []byte
offset int
}
func newPartScanner(r io.Reader, boundary string) (*partScanner, error) {
scanner := &partScanner{r: bufio.NewReader(r), boundary: boundary}
if _, _, err := scanner.readToBoundary(); err != nil {
return nil, err
}
return scanner, nil
}
func (s *partScanner) scanAll() ([]part, error) {
var parts []part
for {
offset := s.progress
b, more, err := s.readToBoundary()
if err != nil {
return nil, err
}
if !more {
return parts, nil
}
parts = append(parts, part{b: b, offset: offset})
}
}
func (s *partScanner) readToBoundary() ([]byte, bool, error) {
var res []byte
for {
line, err := s.r.ReadBytes('\n')
if err != nil {
if !errors.Is(err, io.EOF) {
return nil, false, err
}
if len(line) == 0 {
return nil, false, nil
}
}
s.progress += len(line)
switch {
case bytes.HasPrefix(bytes.TrimSpace(line), []byte("--"+s.boundary)):
return bytes.TrimSuffix(bytes.TrimSuffix(res, []byte("\n")), []byte("\r")), true, nil
case bytes.HasSuffix(bytes.TrimSpace(line), []byte(s.boundary+"--")):
return bytes.TrimSuffix(bytes.TrimSuffix(res, []byte("\n")), []byte("\r")), false, nil
default:
res = append(res, line...)
}
}
}

View File

@ -1,136 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestScanner(t *testing.T) {
const literal = `this part of the text should be ignored
--longrandomstring
body1
--longrandomstring
body2
--longrandomstring--
`
scanner, err := newPartScanner(strings.NewReader(literal), "longrandomstring")
require.NoError(t, err)
parts, err := scanner.scanAll()
require.NoError(t, err)
assert.Equal(t, "\nbody1\n", string(parts[0].b))
assert.Equal(t, "\nbody2\n", string(parts[1].b))
assert.Equal(t, "\nbody1\n", literal[parts[0].offset:parts[0].offset+len(parts[0].b)])
assert.Equal(t, "\nbody2\n", literal[parts[1].offset:parts[1].offset+len(parts[1].b)])
}
func TestScannerNested(t *testing.T) {
const literal = `This is the preamble. It is to be ignored, though it
is a handy place for mail composers to include an
explanatory note to non-MIME compliant readers.
--simple boundary
Content-type: multipart/mixed; boundary="nested boundary"
This is the preamble. It is to be ignored, though it
is a handy place for mail composers to include an
explanatory note to non-MIME compliant readers.
--nested boundary
Content-type: text/plain; charset=us-ascii
This part does not end with a linebreak.
--nested boundary
Content-type: text/plain; charset=us-ascii
This part does end with a linebreak.
--nested boundary--
--simple boundary
Content-type: text/plain; charset=us-ascii
This part does end with a linebreak.
--simple boundary--
This is the epilogue. It is also to be ignored.
`
scanner, err := newPartScanner(strings.NewReader(literal), "simple boundary")
require.NoError(t, err)
parts, err := scanner.scanAll()
require.NoError(t, err)
assert.Equal(t, `Content-type: multipart/mixed; boundary="nested boundary"
This is the preamble. It is to be ignored, though it
is a handy place for mail composers to include an
explanatory note to non-MIME compliant readers.
--nested boundary
Content-type: text/plain; charset=us-ascii
This part does not end with a linebreak.
--nested boundary
Content-type: text/plain; charset=us-ascii
This part does end with a linebreak.
--nested boundary--`, string(parts[0].b))
assert.Equal(t, `Content-type: text/plain; charset=us-ascii
This part does end with a linebreak.
`, string(parts[1].b))
}
func TestScannerNoFinalLinebreak(t *testing.T) {
const literal = `--nested boundary
Content-type: text/plain; charset=us-ascii
This part does not end with a linebreak.
--nested boundary
Content-type: text/plain; charset=us-ascii
This part does end with a linebreak.
--nested boundary--`
scanner, err := newPartScanner(strings.NewReader(literal), "nested boundary")
require.NoError(t, err)
parts, err := scanner.scanAll()
require.NoError(t, err)
assert.Equal(t, `Content-type: text/plain; charset=us-ascii
This part does not end with a linebreak.`, string(parts[0].b))
assert.Equal(t, `Content-type: text/plain; charset=us-ascii
This part does end with a linebreak.
`, string(parts[1].b))
}

View File

@ -1,395 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"bufio"
"bytes"
"io"
"net/textproto"
"strconv"
"strings"
pmmime "github.com/ProtonMail/proton-bridge/v2/pkg/mime"
"github.com/emersion/go-imap"
"github.com/pkg/errors"
"github.com/vmihailenco/msgpack/v5"
)
// BodyStructure is used to parse an email into MIME sections and then generate
// body structure for IMAP server.
type BodyStructure map[string]*SectionInfo
// SectionInfo is used to hold data about parts of each section.
type SectionInfo struct {
Header []byte
Start, BSize, Size, Lines int
reader io.Reader
isHeaderReadFinished bool
}
// Read will also count the final size of section.
func (si *SectionInfo) Read(p []byte) (n int, err error) {
n, err = si.reader.Read(p)
si.Size += n
si.Lines += bytes.Count(p, []byte("\n"))
si.readHeader(p)
return
}
// readHeader appends read data to Header until empty line is found.
func (si *SectionInfo) readHeader(p []byte) {
if si.isHeaderReadFinished {
return
}
si.Header = append(si.Header, p...)
if i := bytes.Index(si.Header, []byte("\n\r\n")); i > 0 {
si.Header = si.Header[:i+3]
si.isHeaderReadFinished = true
return
}
// textproto works also with simple line ending so we should be liberal
// as well.
if i := bytes.Index(si.Header, []byte("\n\n")); i > 0 {
si.Header = si.Header[:i+2]
si.isHeaderReadFinished = true
}
}
// GetMIMEHeader parses bytes and return MIME header.
func (si *SectionInfo) GetMIMEHeader() (textproto.MIMEHeader, error) {
return textproto.NewReader(bufio.NewReader(bytes.NewReader(si.Header))).ReadMIMEHeader()
}
func NewBodyStructure(reader io.Reader) (structure *BodyStructure, err error) {
structure = &BodyStructure{}
err = structure.Parse(reader)
return
}
// DeserializeBodyStructure will create new structure from msgpack bytes.
func DeserializeBodyStructure(raw []byte) (*BodyStructure, error) {
bs := &BodyStructure{}
err := msgpack.Unmarshal(raw, bs)
if err != nil {
return nil, errors.Wrap(err, "cannot deserialize bodystructure")
}
return bs, err
}
// Serialize will write msgpack bytes.
func (bs *BodyStructure) Serialize() ([]byte, error) {
data, err := msgpack.Marshal(bs)
if err != nil {
return nil, errors.Wrap(err, "cannot serialize bodystructure")
}
return data, nil
}
// Parse will read the mail and create all body structures.
func (bs *BodyStructure) Parse(r io.Reader) error {
return bs.parseAllChildSections(r, []int{}, 0)
}
func (bs *BodyStructure) parseAllChildSections(r io.Reader, currentPath []int, start int) (err error) { //nolint:funlen
info := &SectionInfo{
Start: start,
Size: 0,
BSize: 0,
Lines: 0,
reader: r,
}
bufInfo := bufio.NewReader(info)
tp := textproto.NewReader(bufInfo)
tpHeader, err := tp.ReadMIMEHeader()
if err != nil {
return
}
bodyInfo := &SectionInfo{reader: tp.R}
bodyReader := bufio.NewReader(bodyInfo)
mediaType, params, _ := pmmime.ParseMediaType(tpHeader.Get("Content-Type"))
// If multipart, call getAllParts, else read to count lines.
if (strings.HasPrefix(mediaType, "multipart/") || mediaType == rfc822Message) && params["boundary"] != "" {
nextPath := getChildPath(currentPath)
var br *boundaryReader
br, err = newBoundaryReader(bodyReader, params["boundary"])
// New reader seeks first boundary.
if err != nil {
// Return also EOF.
return
}
for err == nil {
start += br.skipped
part := &bytes.Buffer{}
err = br.writeNextPartTo(part)
if err != nil {
break
}
err = bs.parseAllChildSections(part, nextPath, start)
part.Reset()
nextPath[len(nextPath)-1]++
}
br.reader = nil
if err == io.EOF {
err = nil
}
if err != nil {
return
}
} else {
// Count length.
_, _ = bodyReader.WriteTo(io.Discard)
}
// Clear all buffers.
bodyReader = nil //nolint:wastedassign // just to be sure we clear garbage collector
bodyInfo.reader = nil
tp.R = nil
tp = nil //nolint:wastedassign // just to be sure we clear garbage collector
bufInfo = nil //nolint:ineffassign,wastedassign // just to be sure we clear garbage collector
info.reader = nil
// Store boundaries.
info.BSize = bodyInfo.Size
path := stringPathFromInts(currentPath)
(*bs)[path] = info
// Fix start of subsections.
newPath := getChildPath(currentPath)
shift := info.Size - info.BSize
subInfo, err := bs.getInfo(newPath)
// If it has subparts.
for err == nil {
subInfo.Start += shift
// Level down.
subInfo, err = bs.getInfo(append(newPath, 1))
if err == nil {
newPath = append(newPath, 1)
continue
}
// Next.
newPath[len(newPath)-1]++
subInfo, err = bs.getInfo(newPath)
if err == nil {
continue
}
// Level up.
for {
newPath = newPath[:len(newPath)-1]
if len(newPath) > 0 {
newPath[len(newPath)-1]++
subInfo, err = bs.getInfo(newPath)
if err != nil {
err = nil
continue
}
}
break
}
// The end.
if len(newPath) == 0 {
break
}
}
return nil
}
// getChildPath will return the first child path of parent path.
// NOTE: Return value can be used to iterate over parts so it is necessary to
// copy parrent values in order to not rewrite values in parent.
func getChildPath(parent []int) []int {
// append alloc inline is the fasted way to copy
return append(append(make([]int, 0, len(parent)+1), parent...), 1)
}
func stringPathFromInts(ints []int) (ret string) {
for i, n := range ints {
if i != 0 {
ret += "."
}
ret += strconv.Itoa(n)
}
return
}
func (bs *BodyStructure) hasInfo(sectionPath []int) bool {
_, err := bs.getInfo(sectionPath)
return err == nil
}
func (bs *BodyStructure) getInfoCheckSection(sectionPath []int) (sectionInfo *SectionInfo, err error) {
if len(*bs) == 1 && len(sectionPath) == 1 && sectionPath[0] == 1 {
sectionPath = []int{}
}
return bs.getInfo(sectionPath)
}
func (bs *BodyStructure) getInfo(sectionPath []int) (sectionInfo *SectionInfo, err error) {
path := stringPathFromInts(sectionPath)
sectionInfo, ok := (*bs)[path]
if !ok {
err = errors.New("wrong section " + path)
}
return
}
// GetSection returns bytes of section including MIME header.
func (bs *BodyStructure) GetSection(wholeMail io.ReadSeeker, sectionPath []int) (section []byte, err error) {
info, err := bs.getInfoCheckSection(sectionPath)
if err != nil {
return
}
return goToOffsetAndReadNBytes(wholeMail, info.Start, info.Size)
}
// GetSectionContent returns bytes of section content (excluding MIME header).
func (bs *BodyStructure) GetSectionContent(wholeMail io.ReadSeeker, sectionPath []int) (section []byte, err error) {
info, err := bs.getInfoCheckSection(sectionPath)
if err != nil {
return
}
return goToOffsetAndReadNBytes(wholeMail, info.Start+info.Size-info.BSize, info.BSize)
}
// GetMailHeader returns the main header of mail.
func (bs *BodyStructure) GetMailHeader() (header textproto.MIMEHeader, err error) {
return bs.GetSectionHeader([]int{})
}
// GetMailHeaderBytes returns the bytes with main mail header.
// Warning: It can contain extra lines.
func (bs *BodyStructure) GetMailHeaderBytes() (header []byte, err error) {
return bs.GetSectionHeaderBytes([]int{})
}
func goToOffsetAndReadNBytes(wholeMail io.ReadSeeker, offset, length int) ([]byte, error) {
if length == 0 {
return []byte{}, nil
}
if length < 0 {
return nil, errors.New("requested negative length")
}
if offset > 0 {
if _, err := wholeMail.Seek(int64(offset), io.SeekStart); err != nil {
return nil, err
}
}
out := make([]byte, length)
_, err := wholeMail.Read(out)
return out, err
}
// GetSectionHeader returns the mime header of specified section.
func (bs *BodyStructure) GetSectionHeader(sectionPath []int) (textproto.MIMEHeader, error) {
info, err := bs.getInfoCheckSection(sectionPath)
if err != nil {
return nil, err
}
return info.GetMIMEHeader()
}
// GetSectionHeaderBytes returns raw header bytes of specified section.
func (bs *BodyStructure) GetSectionHeaderBytes(sectionPath []int) ([]byte, error) {
info, err := bs.getInfoCheckSection(sectionPath)
if err != nil {
return nil, err
}
return info.Header, nil
}
// IMAPBodyStructure will prepare imap bodystructure recurently for given part.
// Use empty path to create whole email structure.
func (bs *BodyStructure) IMAPBodyStructure(currentPart []int) (imapBS *imap.BodyStructure, err error) {
var info *SectionInfo
if info, err = bs.getInfo(currentPart); err != nil {
return
}
tpHeader, err := info.GetMIMEHeader()
if err != nil {
return
}
mediaType, params, _ := pmmime.ParseMediaType(tpHeader.Get("Content-Type"))
mediaTypeSep := strings.Split(mediaType, "/")
// If it is empty or missing it will not crash.
mediaTypeSep = append(mediaTypeSep, "")
imapBS = &imap.BodyStructure{
MIMEType: mediaTypeSep[0],
MIMESubType: mediaTypeSep[1],
Params: params,
Size: uint32(info.BSize),
Lines: uint32(info.Lines),
}
if val := tpHeader.Get("Content-ID"); val != "" {
imapBS.Id = val
}
if val := tpHeader.Get("Content-Transfer-Encoding"); val != "" {
imapBS.Encoding = val
}
if val := tpHeader.Get("Content-Description"); val != "" {
imapBS.Description = val
}
if val := tpHeader.Get("Content-Disposition"); val != "" {
imapBS.Disposition = val
}
nextPart := append(currentPart, 1) //nolint:gocritic
for {
if !bs.hasInfo(nextPart) {
break
}
var subStruct *imap.BodyStructure
subStruct, err = bs.IMAPBodyStructure(nextPart)
if err != nil {
return
}
if imapBS.Parts == nil {
imapBS.Parts = []*imap.BodyStructure{}
}
imapBS.Parts = append(imapBS.Parts, subStruct)
nextPart[len(nextPart)-1]++
}
return imapBS, nil
}

View File

@ -1,599 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"bytes"
"fmt"
"os"
"path/filepath"
"runtime"
"sort"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
var enableDebug = false //nolint:global
func debug(msg string, v ...interface{}) {
if !enableDebug {
return
}
_, file, line, _ := runtime.Caller(1)
fmt.Printf("%s:%d: \033[2;33m"+msg+"\033[0;39m\n", append([]interface{}{filepath.Base(file), line}, v...)...)
}
func TestParseBodyStructure(t *testing.T) {
expectedStructure := map[string]string{
"": "multipart/mixed; boundary=\"0000MAIN\"",
"1": "text/plain",
"2": "application/octet-stream",
"3": "message/rfc822; boundary=\"0003MSG\"",
"3.1": "text/plain",
"3.2": "application/octet-stream",
"4": "multipart/mixed; boundary=\"0004ATTACH\"",
"4.1": "image/gif",
"4.2": "message/rfc822; boundary=\"0042MSG\"",
"4.2.1": "text/plain",
"4.2.2": "multipart/alternative; boundary=\"0422ALTER\"",
"4.2.2.1": "text/plain",
"4.2.2.2": "text/html",
}
mailReader := strings.NewReader(sampleMail)
bs, err := NewBodyStructure(mailReader)
require.NoError(t, err)
paths := []string{}
for path := range *bs {
paths = append(paths, path)
}
sort.Strings(paths)
debug("%10s: %-50s %5s %5s %5s %5s", "section", "type", "start", "size", "bsize", "lines")
for _, path := range paths {
sec := (*bs)[path]
header, err := sec.GetMIMEHeader()
require.NoError(t, err)
contentType := header.Get("Content-Type")
debug("%10s: %-50s %5d %5d %5d %5d", path, contentType, sec.Start, sec.Size, sec.BSize, sec.Lines)
require.Equal(t, expectedStructure[path], contentType)
}
require.True(t, len(*bs) == len(expectedStructure), "Wrong number of sections expected %d but have %d", len(expectedStructure), len(*bs))
}
func TestParseBodyStructurePGP(t *testing.T) {
expectedStructure := map[string]string{
"": "multipart/signed; micalg=pgp-sha256; protocol=\"application/pgp-signature\"; boundary=\"MHEDFShwcX18dyE3X7RXujo5fjpgdjHNM\"",
"1": "multipart/mixed; boundary=\"FBBl2LNv76z8UkvHhSkT9vLwVwxqV8378\"; protected-headers=\"v1\"",
"1.1": "multipart/mixed; boundary=\"------------F97C8ED4878E94675762AE43\"",
"1.1.1": "multipart/alternative; boundary=\"------------041318B15DD3FA540FED32C6\"",
"1.1.1.1": "text/plain; charset=utf-8; format=flowed",
"1.1.1.2": "text/html; charset=utf-8",
"1.1.2": "application/pdf; name=\"minimal.pdf\"",
"1.1.3": "application/pgp-keys; name=\"OpenPGP_0x161C0875822359F7.asc\"",
"2": "application/pgp-signature; name=\"OpenPGP_signature.asc\"",
}
b, err := os.ReadFile("testdata/enc-body-structure.eml")
require.NoError(t, err)
bs, err := NewBodyStructure(bytes.NewReader(b))
require.NoError(t, err)
haveStructure := map[string]string{}
for path := range *bs {
header, err := (*bs)[path].GetMIMEHeader()
require.NoError(t, err)
haveStructure[path] = header.Get("Content-Type")
}
require.Equal(t, expectedStructure, haveStructure)
}
func TestGetSection(t *testing.T) {
structReader := strings.NewReader(sampleMail)
bs, err := NewBodyStructure(structReader)
require.NoError(t, err)
// Bad paths
wantPaths := [][]int{{0}, {-1}, {3, 2, 3}}
for _, wantPath := range wantPaths {
_, err = bs.getInfo(wantPath)
require.Error(t, err, "path %v", wantPath)
}
// Whole section.
for _, try := range testPaths {
mailReader := strings.NewReader(sampleMail)
info, err := bs.getInfo(try.path)
require.NoError(t, err)
section, err := bs.GetSection(mailReader, try.path)
require.NoError(t, err)
debug("section %v: %d %d\n___\n%s\n‾‾‾\n", try.path, info.Start, info.Size, string(section))
require.True(t, string(section) == try.expectedSection, "not same as expected:\n___\n%s\n‾‾‾", try.expectedSection)
}
// Body content.
for _, try := range testPaths {
mailReader := strings.NewReader(sampleMail)
info, err := bs.getInfo(try.path)
require.NoError(t, err)
section, err := bs.GetSectionContent(mailReader, try.path)
require.NoError(t, err)
debug("content %v: %d %d\n___\n%s\n‾‾‾\n", try.path, info.Start+info.Size-info.BSize, info.BSize, string(section))
require.True(t, string(section) == try.expectedBody, "not same as expected:\n___\n%s\n‾‾‾", try.expectedBody)
}
}
func TestGetSecionNoMIMEParts(t *testing.T) {
wantBody := "This is just a simple mail with no multipart structure.\n"
wantHeader := `Subject: Sample mail
From: John Doe <jdoe@machine.example>
To: Mary Smith <mary@example.net>
Date: Fri, 21 Nov 1997 09:55:06 -0600
Content-Type: plain/text
`
wantMail := wantHeader + wantBody
r := require.New(t)
bs, err := NewBodyStructure(strings.NewReader(wantMail))
r.NoError(err)
// Bad parts
wantPaths := [][]int{{0}, {2}, {1, 2, 3}}
for _, wantPath := range wantPaths {
_, err = bs.getInfoCheckSection(wantPath)
r.Error(err, "path %v: %d %d\n__\n%s\n", wantPath)
}
debug := func(wantPath []int, info *SectionInfo, section []byte) string {
if info == nil {
info = &SectionInfo{}
}
return fmt.Sprintf("path %v %q: %d %d\n___\n%s\n‾‾‾\n",
wantPath, stringPathFromInts(wantPath), info.Start, info.Size,
string(section),
)
}
// Ok Parts
wantPaths = [][]int{{}, {1}}
for _, p := range wantPaths {
wantPath := append([]int{}, p...)
info, err := bs.getInfoCheckSection(wantPath)
r.NoError(err, debug(wantPath, info, []byte{}))
section, err := bs.GetSection(strings.NewReader(wantMail), wantPath)
r.NoError(err, debug(wantPath, info, section))
r.Equal(wantMail, string(section), debug(wantPath, info, section))
haveBody, err := bs.GetSectionContent(strings.NewReader(wantMail), wantPath)
r.NoError(err, debug(wantPath, info, haveBody))
r.Equal(wantBody, string(haveBody), debug(wantPath, info, haveBody))
haveHeader, err := bs.GetSectionHeaderBytes(wantPath)
r.NoError(err, debug(wantPath, info, haveHeader))
r.Equal(wantHeader, string(haveHeader), debug(wantPath, info, haveHeader))
}
}
func TestGetMainHeaderBytes(t *testing.T) {
wantHeader := []byte(`Subject: Sample mail
From: John Doe <jdoe@machine.example>
To: Mary Smith <mary@example.net>
Date: Fri, 21 Nov 1997 09:55:06 -0600
Content-Type: multipart/mixed; boundary="0000MAIN"
`)
structReader := strings.NewReader(sampleMail)
bs, err := NewBodyStructure(structReader)
require.NoError(t, err)
haveHeader, err := bs.GetMailHeaderBytes()
require.NoError(t, err)
require.Equal(t, wantHeader, haveHeader)
}
/* Structure example:
HEADER ([RFC-2822] header of the message)
TEXT ([RFC-2822] text body of the message) MULTIPART/MIXED
1 TEXT/PLAIN
2 APPLICATION/OCTET-STREAM
3 MESSAGE/RFC822
3.HEADER ([RFC-2822] header of the message)
3.TEXT ([RFC-2822] text body of the message) MULTIPART/MIXED
3.1 TEXT/PLAIN
3.2 APPLICATION/OCTET-STREAM
4 MULTIPART/MIXED
4.1 IMAGE/GIF
4.1.MIME ([MIME-IMB] header for the IMAGE/GIF)
4.2 MESSAGE/RFC822
4.2.HEADER ([RFC-2822] header of the message)
4.2.TEXT ([RFC-2822] text body of the message) MULTIPART/MIXED
4.2.1 TEXT/PLAIN
4.2.2 MULTIPART/ALTERNATIVE
4.2.2.1 TEXT/PLAIN
4.2.2.2 TEXT/RICHTEXT
*/
var sampleMail = `Subject: Sample mail
From: John Doe <jdoe@machine.example>
To: Mary Smith <mary@example.net>
Date: Fri, 21 Nov 1997 09:55:06 -0600
Content-Type: multipart/mixed; boundary="0000MAIN"
main summary
--0000MAIN
Content-Type: text/plain
1. main message
--0000MAIN
Content-Type: application/octet-stream
Content-Disposition: inline; filename="main_signature.sig"
Content-Transfer-Encoding: base64
2/MainOctetStream
--0000MAIN
Subject: Inside mail 3
From: Mary Smith <mary@example.net>
To: John Doe <jdoe@machine.example>
Date: Fri, 20 Nov 1997 09:55:06 -0600
Content-Type: message/rfc822; boundary="0003MSG"
3. message summary
--0003MSG
Content-Type: text/plain
3.1 message text
--0003MSG
Content-Type: application/octet-stream
Content-Disposition: attachment; filename="msg_3_signature.sig"
Content-Transfer-Encoding: base64
3/2/MessageOctestStream/==
--0003MSG--
--0000MAIN
Content-Type: multipart/mixed; boundary="0004ATTACH"
4 attach summary
--0004ATTACH
Content-Type: image/gif
Content-Disposition: attachment; filename="att4.1_gif.sig"
Content-Transfer-Encoding: base64
4/1/Gif=
--0004ATTACH
Subject: Inside mail 4.2
From: Mary Smith <mary@example.net>
To: John Doe <jdoe@machine.example>
Date: Fri, 10 Nov 1997 09:55:06 -0600
Content-Type: message/rfc822; boundary="0042MSG"
4.2 message summary
--0042MSG
Content-Type: text/plain
4.2.1 message text
--0042MSG
Content-Type: multipart/alternative; boundary="0422ALTER"
4.2.2 alternative summary
--0422ALTER
Content-Type: text/plain
4.2.2.1 plain text
--0422ALTER
Content-Type: text/html
<h1>4.2.2.2 html text</h1>
--0422ALTER--
--0042MSG--
--0004ATTACH--
--0000MAIN--
`
var testPaths = []struct {
path []int
expectedSection, expectedBody string
}{
{
[]int{},
sampleMail,
`main summary
--0000MAIN
Content-Type: text/plain
1. main message
--0000MAIN
Content-Type: application/octet-stream
Content-Disposition: inline; filename="main_signature.sig"
Content-Transfer-Encoding: base64
2/MainOctetStream
--0000MAIN
Subject: Inside mail 3
From: Mary Smith <mary@example.net>
To: John Doe <jdoe@machine.example>
Date: Fri, 20 Nov 1997 09:55:06 -0600
Content-Type: message/rfc822; boundary="0003MSG"
3. message summary
--0003MSG
Content-Type: text/plain
3.1 message text
--0003MSG
Content-Type: application/octet-stream
Content-Disposition: attachment; filename="msg_3_signature.sig"
Content-Transfer-Encoding: base64
3/2/MessageOctestStream/==
--0003MSG--
--0000MAIN
Content-Type: multipart/mixed; boundary="0004ATTACH"
4 attach summary
--0004ATTACH
Content-Type: image/gif
Content-Disposition: attachment; filename="att4.1_gif.sig"
Content-Transfer-Encoding: base64
4/1/Gif=
--0004ATTACH
Subject: Inside mail 4.2
From: Mary Smith <mary@example.net>
To: John Doe <jdoe@machine.example>
Date: Fri, 10 Nov 1997 09:55:06 -0600
Content-Type: message/rfc822; boundary="0042MSG"
4.2 message summary
--0042MSG
Content-Type: text/plain
4.2.1 message text
--0042MSG
Content-Type: multipart/alternative; boundary="0422ALTER"
4.2.2 alternative summary
--0422ALTER
Content-Type: text/plain
4.2.2.1 plain text
--0422ALTER
Content-Type: text/html
<h1>4.2.2.2 html text</h1>
--0422ALTER--
--0042MSG--
--0004ATTACH--
--0000MAIN--
`,
},
{
[]int{1},
`Content-Type: text/plain
1. main message
`,
`1. main message
`,
},
{
[]int{3},
`Subject: Inside mail 3
From: Mary Smith <mary@example.net>
To: John Doe <jdoe@machine.example>
Date: Fri, 20 Nov 1997 09:55:06 -0600
Content-Type: message/rfc822; boundary="0003MSG"
3. message summary
--0003MSG
Content-Type: text/plain
3.1 message text
--0003MSG
Content-Type: application/octet-stream
Content-Disposition: attachment; filename="msg_3_signature.sig"
Content-Transfer-Encoding: base64
3/2/MessageOctestStream/==
--0003MSG--
`,
`3. message summary
--0003MSG
Content-Type: text/plain
3.1 message text
--0003MSG
Content-Type: application/octet-stream
Content-Disposition: attachment; filename="msg_3_signature.sig"
Content-Transfer-Encoding: base64
3/2/MessageOctestStream/==
--0003MSG--
`,
},
{
[]int{3, 1},
`Content-Type: text/plain
3.1 message text
`,
`3.1 message text
`,
},
{
[]int{3, 2},
`Content-Type: application/octet-stream
Content-Disposition: attachment; filename="msg_3_signature.sig"
Content-Transfer-Encoding: base64
3/2/MessageOctestStream/==
`,
`3/2/MessageOctestStream/==
`,
},
{
[]int{4, 2, 2, 1},
`Content-Type: text/plain
4.2.2.1 plain text
`,
`4.2.2.1 plain text
`,
},
{
[]int{4, 2, 2, 2},
`Content-Type: text/html
<h1>4.2.2.2 html text</h1>
`,
`<h1>4.2.2.2 html text</h1>
`,
},
}
func TestBodyStructureSerialize(t *testing.T) {
r := require.New(t)
want := &BodyStructure{
"1": {
Header: []byte("Content: type"),
Start: 1,
Size: 2,
BSize: 3,
Lines: 4,
},
"1.1.1": {
Header: []byte("X-Pm-Key: id"),
Start: 11,
Size: 12,
BSize: 13,
Lines: 14,
reader: bytes.NewBuffer([]byte("this should not be serialized")),
},
}
raw, err := want.Serialize()
r.NoError(err)
have, err := DeserializeBodyStructure(raw)
r.NoError(err)
// Before compare remove reader (should not be serialized)
(*want)["1.1.1"].reader = nil
r.Equal(want, have)
}
func TestSectionInfoReadHeader(t *testing.T) {
r := require.New(t)
testData := []struct {
wantHeader, mail string
}{
{
"key1: val1\nkey2: val2\n\n",
"key1: val1\nkey2: val2\n\nbody is here\n\nand it is not confused",
},
{
"key1:\n val1\n\n",
"key1:\n val1\n\nbody is here",
},
{
"key1: val1\r\nkey2: val2\r\n\r\n",
"key1: val1\r\nkey2: val2\r\n\r\nbody is here\r\n\r\nand it is not confused",
},
}
for _, td := range testData {
bs, err := NewBodyStructure(strings.NewReader(td.mail))
r.NoError(err, "case %q", td.mail)
haveHeader, err := bs.GetMailHeaderBytes()
r.NoError(err, "case %q", td.mail)
r.Equal(td.wantHeader, string(haveHeader), "case %q", td.mail)
}
}

View File

@ -1,48 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package message
import (
"fmt"
"io"
)
type partWriter struct {
w io.Writer
boundary string
}
func newPartWriter(w io.Writer, boundary string) *partWriter {
return &partWriter{w: w, boundary: boundary}
}
func (w *partWriter) createPart(fn func(io.Writer) error) error {
if _, err := fmt.Fprintf(w.w, "\r\n--%v\r\n", w.boundary); err != nil {
return err
}
return fn(w.w)
}
func (w *partWriter) done() error {
if _, err := fmt.Fprintf(w.w, "\r\n--%v--\r\n", w.boundary); err != nil {
return err
}
return nil
}

View File

@ -1,136 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package parallel
import (
"sync"
"time"
)
// parallelJob is to be used for passing items between input, worker and
// collector. `idx` is there to know the original order.
type parallelJob struct {
idx int
value interface{}
}
// RunParallel starts `workers` number of workers and feeds them with `input` data.
// Each worker calls `process`. Processed data is collected in the same order as
// the input and is passed in order to the `collect` callback. If an error
// occurs, the execution is stopped and the error returned.
// runParallel blocks until everything is done.
func RunParallel( //nolint:funlen
workers int,
input []interface{},
process func(interface{}) (interface{}, error),
collect func(int, interface{}) error,
) (resultError error) {
wgProcess := &sync.WaitGroup{}
wgCollect := &sync.WaitGroup{}
// Optimise by not executing the code at all if there is no input
// or run less workers than requested if there are few inputs.
inputLen := len(input)
if inputLen == 0 {
return nil
}
if inputLen < workers {
workers = inputLen
}
inputChan := make(chan *parallelJob)
outputChan := make(chan *parallelJob)
orderedCollectLock := &sync.Mutex{}
orderedCollect := make(map[int]interface{})
// Feed input channel used by workers with input data with index for ordering.
go func() {
defer close(inputChan)
for idx, item := range input {
if resultError != nil {
break
}
inputChan <- &parallelJob{idx, item}
}
}()
// Start workers and process all the inputs.
wgProcess.Add(workers)
for i := 0; i < workers; i++ {
go func() {
defer wgProcess.Done()
for item := range inputChan {
if output, err := process(item.value); err != nil {
resultError = err
break
} else {
outputChan <- &parallelJob{item.idx, output}
}
}
}()
}
// Collect data into map with the original position in the array.
wgCollect.Add(1)
go func() {
defer wgCollect.Done()
for output := range outputChan {
orderedCollectLock.Lock()
orderedCollect[output.idx] = output.value
orderedCollectLock.Unlock()
}
}()
// Collect data in the same order as in the input array.
wgCollect.Add(1)
go func() {
defer wgCollect.Done()
idx := 0
for {
if idx >= inputLen || resultError != nil {
break
}
orderedCollectLock.Lock()
value, ok := orderedCollect[idx]
if ok {
if err := collect(idx, value); err != nil {
resultError = err
}
delete(orderedCollect, idx)
idx++
}
orderedCollectLock.Unlock()
if !ok {
time.Sleep(10 * time.Millisecond)
}
}
}()
// When input channel is closed, all workers will finish. We need to wait
// for all of them and close the output channel only once.
wgProcess.Wait()
close(outputChan)
// When workers are done, the last job is to finish collecting data. First
// collector is finished when output channel is closed and the second one
// when all items are passed to `collect` in the order or after an error.
wgCollect.Wait()
return resultError
}

View File

@ -1,136 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package parallel
import (
"errors"
"fmt"
"math"
"runtime"
"testing"
"time"
r "github.com/stretchr/testify/require"
)
//nolint:gochecknoglobals
var (
testInput = []interface{}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
wantOutput = []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
testProcessSleep = 100 // ms
runParallelTimeOverhead = 150 // ms
windowsCIExtra = 500 // ms - estimated experimentally
)
func TestParallel(t *testing.T) {
workersTests := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
for _, workers := range workersTests {
workers := workers
t.Run(fmt.Sprintf("%d", workers), func(t *testing.T) {
collected := make([]int, 0)
collect := func(idx int, value interface{}) error {
collected = append(collected, value.(int)) //nolint:forcetypeassert
return nil
}
tstart := time.Now()
err := RunParallel(workers, testInput, processSleep, collect)
duration := time.Since(tstart)
r.Nil(t, err)
r.Equal(t, wantOutput, collected) // Check the order is always kept.
wantMinDuration := int(math.Ceil(float64(len(testInput))/float64(workers))) * testProcessSleep
wantMaxDuration := wantMinDuration + runParallelTimeOverhead
if runtime.GOOS == "windows" {
wantMaxDuration += windowsCIExtra
}
r.True(t, duration.Nanoseconds() > int64(wantMinDuration*1000000), "Duration too short: %v (expected: %v)", duration, wantMinDuration)
r.True(t, duration.Nanoseconds() < int64(wantMaxDuration*1000000), "Duration too long: %v (expected: %v)", duration, wantMaxDuration)
})
}
}
func TestParallelEmptyInput(t *testing.T) {
workersTests := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
for _, workers := range workersTests {
workers := workers
t.Run(fmt.Sprintf("%d", workers), func(t *testing.T) {
err := RunParallel(workers, []interface{}{}, processSleep, collectNil)
r.Nil(t, err)
})
}
}
func TestParallelErrorInProcess(t *testing.T) {
workersTests := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
for _, workers := range workersTests {
workers := workers
t.Run(fmt.Sprintf("%d", workers), func(t *testing.T) {
var lastCollected int
process := func(value interface{}) (interface{}, error) {
time.Sleep(10 * time.Millisecond)
if value.(int) == 5 { //nolint:forcetypeassert
return nil, errors.New("Error")
}
return value, nil
}
collect := func(idx int, value interface{}) error {
lastCollected = value.(int) //nolint:forcetypeassert
return nil
}
err := RunParallel(workers, testInput, process, collect)
r.EqualError(t, err, "Error")
time.Sleep(10 * time.Millisecond)
r.True(t, lastCollected < 5, "Last collected cannot be higher that 5, got: %d", lastCollected)
})
}
}
func TestParallelErrorInCollect(t *testing.T) {
workersTests := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
for _, workers := range workersTests {
workers := workers
t.Run(fmt.Sprintf("%d", workers), func(t *testing.T) {
collect := func(idx int, value interface{}) error {
if value.(int) == 5 { //nolint:forcetypeassert
return errors.New("Error")
}
return nil
}
err := RunParallel(workers, testInput, processSleep, collect)
r.EqualError(t, err, "Error")
})
}
}
func processSleep(value interface{}) (interface{}, error) {
time.Sleep(time.Duration(testProcessSleep) * time.Millisecond)
return value.(int), nil //nolint:forcetypeassert
}
func collectNil(idx int, value interface{}) error {
return nil
}

View File

@ -1,125 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pchan
import (
"sort"
"sync"
)
type PChan struct {
lock sync.Mutex
items []*Item
ready, done chan struct{}
once sync.Once
}
type Item struct {
ch *PChan
val interface{}
prio int
done sync.WaitGroup
}
func (item *Item) Wait() {
item.done.Wait()
}
func (item *Item) GetPriority() int {
item.ch.lock.Lock()
defer item.ch.lock.Unlock()
return item.prio
}
func (item *Item) SetPriority(priority int) {
item.ch.lock.Lock()
defer item.ch.lock.Unlock()
item.prio = priority
sort.Slice(item.ch.items, func(i, j int) bool {
return item.ch.items[i].prio < item.ch.items[j].prio
})
}
func New() *PChan {
return &PChan{
ready: make(chan struct{}),
done: make(chan struct{}),
}
}
func (ch *PChan) Push(val interface{}, prio int) *Item {
defer ch.notify()
return ch.push(val, prio)
}
func (ch *PChan) Pop() (interface{}, int, bool) {
select {
case <-ch.ready:
val, prio := ch.pop()
return val, prio, true
case <-ch.done:
return nil, 0, false
}
}
func (ch *PChan) Close() {
ch.once.Do(func() { close(ch.done) })
}
func (ch *PChan) push(val interface{}, prio int) *Item {
ch.lock.Lock()
defer ch.lock.Unlock()
item := &Item{
ch: ch,
val: val,
prio: prio,
}
item.done.Add(1)
ch.items = append(ch.items, item)
return item
}
func (ch *PChan) pop() (interface{}, int) {
ch.lock.Lock()
defer ch.lock.Unlock()
sort.Slice(ch.items, func(i, j int) bool {
return ch.items[i].prio < ch.items[j].prio
})
var item *Item
item, ch.items = ch.items[len(ch.items)-1], ch.items[:len(ch.items)-1]
defer item.done.Done()
return item.val, item.prio
}
func (ch *PChan) notify() {
go func() { ch.ready <- struct{}{} }()
}

View File

@ -1,123 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pchan
import (
"sort"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPChanConcurrentPush(t *testing.T) {
ch := New()
var wg sync.WaitGroup
// We are going to test with 5 additional goroutines.
wg.Add(5)
// Start 5 concurrent pushes.
go func() { defer wg.Done(); ch.Push(1, 1) }()
go func() { defer wg.Done(); ch.Push(2, 2) }()
go func() { defer wg.Done(); ch.Push(3, 3) }()
go func() { defer wg.Done(); ch.Push(4, 4) }()
go func() { defer wg.Done(); ch.Push(5, 5) }()
// Wait for the items to be pushed.
wg.Wait()
// All 5 should now be ready for popping.
require.Len(t, ch.items, 5)
// They should be popped in priority order.
assert.Equal(t, 5, getValue(t, ch))
assert.Equal(t, 4, getValue(t, ch))
assert.Equal(t, 3, getValue(t, ch))
assert.Equal(t, 2, getValue(t, ch))
assert.Equal(t, 1, getValue(t, ch))
}
func TestPChanConcurrentPop(t *testing.T) {
ch := New()
var wg sync.WaitGroup
// We are going to test with 5 additional goroutines.
wg.Add(5)
// Make a list to store the results in.
var res list
// Start 5 concurrent pops; these consume any items pushed.
go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
go func() { defer wg.Done(); res.append(getValue(t, ch)) }()
// Push and block; items should be popped immediately by the waiting goroutines.
ch.Push(1, 1).Wait()
ch.Push(2, 2).Wait()
ch.Push(3, 3).Wait()
ch.Push(4, 4).Wait()
ch.Push(5, 5).Wait()
// Wait for all items to be popped then close the result channel.
wg.Wait()
assert.True(t, sort.IntsAreSorted(res.items))
}
func TestPChanClose(t *testing.T) {
ch := New()
go ch.Push(1, 1)
valOpen, _, okOpen := ch.Pop()
assert.True(t, okOpen)
assert.Equal(t, 1, valOpen)
ch.Close()
valClose, _, okClose := ch.Pop()
assert.False(t, okClose)
assert.Nil(t, valClose)
}
type list struct {
items []int
mut sync.Mutex
}
func (l *list) append(val int) {
l.mut.Lock()
defer l.mut.Unlock()
l.items = append(l.items, val)
}
func getValue(t *testing.T, ch *PChan) int {
val, _, ok := ch.Pop()
assert.True(t, ok)
return val.(int) //nolint:forcetypeassert
}

View File

@ -1,217 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"strings"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2"
"github.com/pkg/errors"
)
// Address statuses.
const (
DisabledAddress = iota
EnabledAddress
)
// Address HasKeys values.
const (
MissingKeys = iota
KeysPresent
)
// Address types.
const (
_ = iota // Skip first.
OriginalAddress
AliasAddress
CustomAddress
PremiumAddress
)
// Address Send values.
const (
NoSendAddress = iota
MainSendAddress
SecondarySendAddress
)
// Address represents a user's address.
type Address struct {
ID string
DomainID string
Email string
Send int
Receive Boolean
Status int
Order int `json:",omitempty"`
Type int
DisplayName string
Signature string
MemberID string `json:",omitempty"`
MemberName string `json:",omitempty"`
HasKeys int
Keys PMKeys
}
// AddressList is a list of addresses.
type AddressList []*Address
// ByID returns an address by id. Returns nil if no address is found.
func (l AddressList) ByID(id string) *Address {
for _, addr := range l {
if addr.ID == id {
return addr
}
}
return nil
}
// AllEmails returns all emails.
func (l AddressList) AllEmails() (addresses []string) {
for _, a := range l {
addresses = append(addresses, a.Email)
}
return
}
// ActiveEmails returns only active emails.
func (l AddressList) ActiveEmails() (addresses []string) {
for _, a := range l {
if a.Receive {
addresses = append(addresses, a.Email)
}
}
return
}
// Main gets the main address.
func (l AddressList) Main() *Address {
for _, addr := range l {
if addr.Order == 1 {
return addr
}
}
return nil
}
// ByEmail gets an address by email. Returns nil if no address is found.
func (l AddressList) ByEmail(email string) *Address {
email = SanitizeEmail(email)
for _, addr := range l {
if strings.EqualFold(addr.Email, email) {
return addr
}
}
return nil
}
func SanitizeEmail(email string) string {
splitAt := strings.Split(email, "@")
if len(splitAt) != 2 {
return email
}
splitPlus := strings.Split(splitAt[0], "+")
email = splitPlus[0] + "@" + splitAt[1]
return email
}
func ConstructAddress(headerEmail string, addressEmail string) string {
splitAtHeader := strings.Split(headerEmail, "@")
if len(splitAtHeader) != 2 {
return addressEmail
}
splitPlus := strings.Split(splitAtHeader[0], "+")
if len(splitPlus) != 2 {
return addressEmail
}
splitAtAddress := strings.Split(addressEmail, "@")
if len(splitAtAddress) != 2 {
return addressEmail
}
return splitAtAddress[0] + "+" + splitPlus[1] + "@" + splitAtAddress[1]
}
// GetAddresses requests all of current user addresses (without pagination).
func (c *client) GetAddresses(ctx context.Context) (addresses AddressList, err error) {
var res struct {
Addresses []*Address
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/addresses")
}); err != nil {
return nil, err
}
return res.Addresses, nil
}
func (c *client) ReorderAddresses(ctx context.Context, addressIDs []string) error {
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(&struct {
AddressIDs []string
}{
AddressIDs: addressIDs,
}).Put("/addresses/order")
}); err != nil {
return err
}
_, err := c.UpdateUser(ctx)
return err
}
// Addresses returns the addresses stored in the client object itself rather than fetching from the API.
func (c *client) Addresses() AddressList {
return c.addresses
}
// unlockAddresses unlocks all keys for all addresses of current user.
func (c *client) unlockAddress(passphrase []byte, address *Address) error {
if address == nil {
return errors.New("address data is missing")
}
if address.HasKeys == MissingKeys {
return nil
}
kr, err := address.Keys.UnlockAll(passphrase, c.userKeyRing)
if err != nil {
return errors.Wrap(err, "cannot unlock address keys for "+address.ID)
}
c.addrKeyRing[address.ID] = kr
return nil
}
func (c *client) KeyRingForAddressID(addrID string) (*crypto.KeyRing, error) {
if kr, ok := c.addrKeyRing[addrID]; ok {
return kr, nil
}
return nil, errors.New("no keyring available")
}

View File

@ -1,76 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"net/http"
"testing"
r "github.com/stretchr/testify/require"
)
var testAddressList = AddressList{
&Address{
ID: "1",
Email: "root@nsa.gov",
Send: SecondarySendAddress,
Status: EnabledAddress,
Order: 2,
},
&Address{
ID: "2",
Email: "root@gchq.gov.uk",
Send: MainSendAddress,
Status: EnabledAddress,
Order: 1,
},
&Address{
ID: "3",
Email: "root@protonmail.com",
Send: NoSendAddress,
Status: DisabledAddress,
Order: 3,
},
}
func routeGetAddresses(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
r.NoError(tb, checkMethodAndPath(req, "GET", "/addresses"))
r.NoError(tb, isAuthReq(req, testUID, testAccessToken))
return "addresses/get_response.json"
}
func TestAddressList(t *testing.T) {
input := "1"
addr := testAddressList.ByID(input)
r.Equal(t, testAddressList[0], addr)
input = "42"
addr = testAddressList.ByID(input)
r.Nil(t, addr)
input = "root@protonmail.com"
addr = testAddressList.ByEmail(input)
r.Equal(t, testAddressList[2], addr)
input = "idontexist@protonmail.com"
addr = testAddressList.ByEmail(input)
r.Nil(t, addr)
addr = testAddressList.Main()
r.Equal(t, testAddressList[1], addr)
}

View File

@ -1,176 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/textproto"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2"
)
type header textproto.MIMEHeader
type rawHeader map[string]json.RawMessage
func (h *header) UnmarshalJSON(b []byte) error {
if *h == nil {
*h = make(header)
}
raw := make(rawHeader)
if err := json.Unmarshal(b, &raw); err != nil {
return err
}
for k, v := range raw {
// Most headers are string because they have only one value.
var s string
if err := json.Unmarshal(v, &s); err == nil {
textproto.MIMEHeader(*h).Set(k, s)
continue
}
// If it's not a string, it must be an array of strings.
var a []string
if err := json.Unmarshal(v, &a); err != nil {
return fmt.Errorf("pmapi: attachment header field is neither a string nor an array of strings: %v", err)
}
for _, vv := range a {
textproto.MIMEHeader(*h).Add(k, vv)
}
}
return nil
}
const (
DispositionInline = "inline"
DispositionAttachment = "attachment"
)
// Attachment represents a message attachment.
type Attachment struct {
ID string `json:",omitempty"`
MessageID string `json:",omitempty"` // msg v3 ???
Name string `json:",omitempty"`
Size int64 `json:",omitempty"`
MIMEType string `json:",omitempty"`
ContentID string `json:",omitempty"`
Disposition string
KeyPackets string `json:",omitempty"`
Signature string `json:",omitempty"`
Header textproto.MIMEHeader `json:"-"`
}
// Define a new type to prevent MarshalJSON/UnmarshalJSON infinite loops.
type attachment Attachment
type rawAttachment struct {
attachment
Header header `json:"Headers,omitempty"`
}
func (a *Attachment) MarshalJSON() ([]byte, error) {
var raw rawAttachment
raw.attachment = attachment(*a)
if a.Header != nil {
raw.Header = header(a.Header)
}
return json.Marshal(&raw)
}
func (a *Attachment) UnmarshalJSON(b []byte) error {
var raw rawAttachment
if err := json.Unmarshal(b, &raw); err != nil {
return err
}
*a = Attachment(raw.attachment)
if raw.Header != nil {
a.Header = textproto.MIMEHeader(raw.Header)
}
return nil
}
// Decrypt decrypts this attachment's data from r using the keys from kr.
func (a *Attachment) Decrypt(r io.Reader, kr *crypto.KeyRing) (decrypted io.Reader, err error) {
keyPackets, err := base64.StdEncoding.DecodeString(a.KeyPackets)
if err != nil {
return
}
return decryptAttachment(kr, keyPackets, r)
}
// Encrypt encrypts an attachment.
func (a *Attachment) Encrypt(kr *crypto.KeyRing, att io.Reader) (encrypted io.Reader, err error) {
return encryptAttachment(kr, att, a.Name)
}
func (a *Attachment) DetachedSign(kr *crypto.KeyRing, att io.Reader) (signed io.Reader, err error) {
return signAttachment(kr, att)
}
// CreateAttachment uploads an attachment. It must be already encrypted and contain a MessageID.
//
// The returned created attachment contains the new attachment ID and its size.
func (c *client) CreateAttachment(ctx context.Context, att *Attachment, attData io.Reader, sigData io.Reader) (*Attachment, error) {
var res struct {
Attachment *Attachment
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).
SetMultipartFormData(map[string]string{
"Filename": att.Name,
"MessageID": att.MessageID,
"MIMEType": att.MIMEType,
"ContentID": att.ContentID,
}).
SetMultipartField("DataPacket", "DataPacket.pgp", "application/octet-stream", attData).
SetMultipartField("Signature", "Signature.pgp", "application/octet-stream", sigData).
Post("/mail/v4/attachments")
}); err != nil {
return nil, err
}
return res.Attachment, nil
}
// GetAttachment gets an attachment's content. The returned data is encrypted.
func (c *client) GetAttachment(ctx context.Context, attachmentID string) (att io.ReadCloser, err error) {
res, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetDoNotParseResponse(true).Get("/mail/v4/attachments/" + attachmentID)
})
if err != nil {
return nil, err
}
return res.RawBody(), nil
}

View File

@ -1,201 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/textproto"
"strings"
"testing"
pmmime "github.com/ProtonMail/proton-bridge/v2/pkg/mime"
"github.com/stretchr/testify/require"
)
const testAttachmentCleartext = `cc,
dille.
`
// Attachment cleartext encrypted with testPrivateKeyRing.
const testKeyPacket = `wcBMA0fcZ7XLgmf2AQf/cHhfDRM9zlIuBi+h2W6DKjbbyIHMkgF6ER3JEvn/tSruUH8KTGt0N7Z+a80FFMCuXn1Y1I/nW7MVrNhGuJZAF4OymD8ugvuoAMIQX0eCYEpPXzRIWJBZg82AuowmFMsv8Dgvq4bTZq4cttI3CZcxKUNXuAearmNpmgplUKWj5USmRXK4iGB3VFGjidXkxbElrP4fD5A/rfEZ5aJgCsegqcXxX3MEjWXi9pFzgd/9phOvl1ZFm9U9hNoVAW3QsgmVeihnKaDZUyf2Qsigij21QKAUxw9U3y89eTUIqZAcmIgqeDujA3RWBgJwjtY/lOyhEmkf3AWKzehvf1xtJmCWDg==`
const testDataPacket = `0ksB6S4f4l8C1NB8yzmd/jNi0xqEZsyTDLdTP+N4Qxh3NZjla+yGRvC9rGmoUL7XVyowsG/GKTf2LXF/5E5FkX/3WMYwIv1n11ExyAE=`
var testAttachment = &Attachment{
ID: "y6uKIlc2HdoHPAwPSrvf7dXoZNMYvBgxshYUN67cY5DJjL2O8NYewuvGHcYvCfd8LpEoAI_GdymO0Jr0mHlsEw==",
Name: "croutonmail.txt",
Size: 77,
MIMEType: "text/plain",
KeyPackets: testKeyPacket,
Header: textproto.MIMEHeader{
"Content-Description": {"You'll never believe what's in this text file"},
"X-Mailer": {"Microsoft Outlook 15.0", "Microsoft Live Mail 42.0"},
},
MessageID: "h3CD-DT7rLoAw1vmpcajvIPAl-wwDfXR2MHtWID3wuQURDBKTiGUAwd6E2WBbS44QQKeXImW-axm6X0hAfcVCA==",
}
// Part of GET /mail/messages/{id} response from server.
const testAttachmentJSON = `{
"ID": "y6uKIlc2HdoHPAwPSrvf7dXoZNMYvBgxshYUN67cY5DJjL2O8NYewuvGHcYvCfd8LpEoAI_GdymO0Jr0mHlsEw==",
"Name": "croutonmail.txt",
"Size": 77,
"MIMEType": "text/plain",
"KeyPackets": "` + testKeyPacket + `",
"Headers": {
"content-description": "You'll never believe what's in this text file",
"x-mailer": [
"Microsoft Outlook 15.0",
"Microsoft Live Mail 42.0"
]
}
}
`
// POST /mail/attachment/ response from server.
const testCreatedAttachmentBody = `{
"Code": 1000,
"Attachment": {"ID": "y6uKIlc2HdoHPAwPSrvf7dXoZNMYvBgxshYUN67cY5DJjL2O8NYewuvGHcYvCfd8LpEoAI_GdymO0Jr0mHlsEw=="}
}`
func TestAttachment_UnmarshalJSON(t *testing.T) {
r := require.New(t)
att := new(Attachment)
err := json.Unmarshal([]byte(testAttachmentJSON), att)
r.NoError(err)
att.MessageID = testAttachment.MessageID // This isn't in the server response
r.Equal(testAttachment, att)
}
func TestClient_CreateAttachment(t *testing.T) {
r := require.New(t)
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(checkMethodAndPath(req, "POST", "/mail/v4/attachments"))
contentType, params, err := pmmime.ParseMediaType(req.Header.Get("Content-Type"))
r.NoError(err)
r.Equal("multipart/form-data", contentType)
mr := multipart.NewReader(req.Body, params["boundary"])
form, err := mr.ReadForm(10 * 1024)
r.NoError(err)
defer r.NoError(form.RemoveAll())
r.Equal(testAttachment.Name, form.Value["Filename"][0])
r.Equal(testAttachment.MessageID, form.Value["MessageID"][0])
r.Equal(testAttachment.MIMEType, form.Value["MIMEType"][0])
dataFile, err := form.File["DataPacket"][0].Open()
r.NoError(err)
defer r.NoError(dataFile.Close())
b, err := io.ReadAll(dataFile)
r.NoError(err)
r.Equal(testAttachmentCleartext, string(b))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testCreatedAttachmentBody)
}))
defer s.Close()
reader := strings.NewReader(testAttachmentCleartext) // In reality, this thing is encrypted
created, err := c.CreateAttachment(context.Background(), testAttachment, reader, strings.NewReader(""))
r.NoError(err)
r.Equal(testAttachment.ID, created.ID)
}
func TestClient_GetAttachment(t *testing.T) {
r := require.New(t)
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(checkMethodAndPath(req, "GET", "/mail/v4/attachments/"+testAttachment.ID))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testAttachmentCleartext)
}))
defer s.Close()
att, err := c.GetAttachment(context.Background(), testAttachment.ID)
r.NoError(err)
defer att.Close() //nolint:errcheck
// In reality, r contains encrypted data
b, err := io.ReadAll(att)
r.NoError(err)
r.Equal(testAttachmentCleartext, string(b))
}
func TestAttachmentDecrypt(t *testing.T) {
r := require.New(t)
rawKeyPacket, err := base64.StdEncoding.DecodeString(testKeyPacket)
r.NoError(err)
rawDataPacket, err := base64.StdEncoding.DecodeString(testDataPacket)
r.NoError(err)
decryptAndCheck(r, bytes.NewBuffer(append(rawKeyPacket, rawDataPacket...)))
}
func TestAttachmentEncrypt(t *testing.T) {
r := require.New(t)
encryptedReader, err := testAttachment.Encrypt(
testPublicKeyRing,
bytes.NewBufferString(testAttachmentCleartext),
)
r.NoError(err)
// The result is always different due to session key. The best way is to
// test result of encryption by decrypting again acn coparet to cleartext.
decryptAndCheck(r, encryptedReader)
}
func decryptAndCheck(r *require.Assertions, data io.Reader) {
// First separate KeyPacket from encrypted data. In our case keypacket
// has 271 bytes.
raw, err := io.ReadAll(data)
r.NoError(err)
rawKeyPacket := raw[:271]
rawDataPacket := raw[271:]
// KeyPacket is retrieve by get GET /mail/messages/{id}
haveAttachment := &Attachment{
KeyPackets: base64.StdEncoding.EncodeToString(rawKeyPacket),
}
// DataPacket is received from GET /mail/attachments/{id}
decryptedReader, err := haveAttachment.Decrypt(bytes.NewBuffer(rawDataPacket), testPrivateKeyRing)
r.NoError(err)
b, err := io.ReadAll(decryptedReader)
r.NoError(err)
r.Equal(testAttachmentCleartext, string(b))
}

View File

@ -1,229 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"io"
"net/http"
"time"
"github.com/go-resty/resty/v2"
)
type AuthModulus struct {
Modulus string
ModulusID string
}
type GetAuthInfoReq struct {
Username string
}
type AuthInfo struct {
Version int
Modulus string
ServerEphemeral string
Salt string
SRPSession string
}
type TwoFAInfo struct {
Enabled TwoFAStatus
}
func (twoFAInfo TwoFAInfo) hasTwoFactor() bool {
return twoFAInfo.Enabled > TwoFADisabled
}
type TwoFAStatus int
const (
TwoFADisabled TwoFAStatus = iota
TOTPEnabled
U2FEnabled
TOTPAndU2FEnabled
)
type PasswordMode int
const (
OnePasswordMode PasswordMode = iota + 1
TwoPasswordMode
)
type AuthReq struct {
Username string
ClientProof string
ClientEphemeral string
SRPSession string
}
type AuthRefresh struct {
UID string
AccessToken string
RefreshToken string
ExpiresIn int64
Scopes []string
}
type Auth struct {
AuthRefresh
UserID string
ServerProof string
PasswordMode PasswordMode
TwoFA *TwoFAInfo `json:"2FA,omitempty"`
}
func (a Auth) HasTwoFactor() bool {
if a.TwoFA == nil {
return false
}
return a.TwoFA.hasTwoFactor()
}
func (a Auth) HasMailboxPassword() bool {
return a.PasswordMode == TwoPasswordMode
}
type auth2FAReq struct {
TwoFactorCode string
}
type authRefreshReq struct {
UID string
RefreshToken string
ResponseType string
GrantType string
RedirectURI string
State string
}
func (c *client) Auth2FA(ctx context.Context, twoFactorCode string) error {
// 2FA is called during login procedure during which refresh token should
// be valid, therefore, no refresh is needed if there is an error.
ctx = ContextWithoutAuthRefresh(ctx)
if res, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(auth2FAReq{TwoFactorCode: twoFactorCode}).Post("/auth/2fa")
}); err != nil {
if res != nil {
switch res.StatusCode() {
case http.StatusUnauthorized:
return ErrBad2FACode
case http.StatusUnprocessableEntity:
return ErrBad2FACodeTryAgain
}
}
return err
}
return nil
}
func (c *client) AuthDelete(ctx context.Context) error {
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.Delete("/auth")
}); err != nil {
return err
}
c.uid, c.acc, c.ref, c.exp = "", "", "", time.Time{}
c.sendAuthRefresh(nil)
return nil
}
func (c *client) AuthSalt(ctx context.Context) (string, error) {
salts, err := c.GetKeySalts(ctx)
if err != nil {
return "", err
}
if _, err := c.CurrentUser(ctx); err != nil {
return "", err
}
for _, s := range salts {
if s.ID == c.user.Keys[0].ID {
return s.KeySalt, nil
}
}
return "", errors.New("no matching salt found")
}
func (c *client) AddAuthRefreshHandler(handler AuthRefreshHandler) {
c.authHandlers = append(c.authHandlers, handler)
}
func (c *client) authRefresh(ctx context.Context) error {
c.authLocker.Lock()
defer c.authLocker.Unlock()
if c.ref == "" {
return ErrUnauthorized
}
auth, err := c.manager.authRefresh(ctx, c.uid, c.ref)
if err != nil {
if IsFailedAuth(err) {
c.sendAuthRefresh(nil)
}
return err
}
c.acc = auth.AccessToken
c.ref = auth.RefreshToken
c.exp = expiresIn(auth.ExpiresIn)
c.sendAuthRefresh(auth)
return nil
}
func (c *client) sendAuthRefresh(auth *AuthRefresh) {
for _, handler := range c.authHandlers {
go handler(auth)
}
if auth == nil {
c.authHandlers = []AuthRefreshHandler{}
}
}
func randomString(length int) string {
noise := make([]byte, length)
if _, err := io.ReadFull(rand.Reader, noise); err != nil {
panic(err)
}
return base64.StdEncoding.EncodeToString(noise)[:length]
}
func (c *client) GetCurrentAuth() *Auth {
return &Auth{
UserID: c.user.ID,
AuthRefresh: AuthRefresh{
UID: c.uid,
RefreshToken: c.ref,
},
}
}

View File

@ -1,122 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/stretchr/testify/require"
)
type testRefreshResponse struct {
Code int
AccessToken string
ExpiresIn int
TokenType string
Scope string
Scopes []string
UID string
RefreshToken string
LocalID int
r *require.Assertions
}
var tokenID = 0
func newTestRefreshToken(r *require.Assertions) testRefreshResponse {
tokenID++
scopes := []string{
"full",
"self",
"parent",
"user",
"loggedin",
"paid",
"nondelinquent",
"mail",
"verified",
}
return testRefreshResponse{
Code: 1000,
AccessToken: fmt.Sprintf("acc%d", tokenID),
ExpiresIn: 3600,
TokenType: "Bearer",
Scope: strings.Join(scopes, " "),
Scopes: scopes,
UID: fmt.Sprintf("uid%d", tokenID),
RefreshToken: fmt.Sprintf("ref%d", tokenID),
r: r,
}
}
func (r *testRefreshResponse) isCorrectRefreshToken(body io.ReadCloser) int {
request := authRefreshReq{}
err := json.NewDecoder(body).Decode(&request)
r.r.NoError(body.Close())
r.r.NoError(err)
if r.UID != request.UID {
return http.StatusUnprocessableEntity
}
if r.RefreshToken != request.RefreshToken {
return http.StatusBadRequest
}
return http.StatusOK
}
func (r *testRefreshResponse) handleAuthRefresh(response http.ResponseWriter, request *http.Request) {
if code := r.isCorrectRefreshToken(request.Body); code != http.StatusOK {
response.WriteHeader(code)
return
}
tokenID++
r.AccessToken = fmt.Sprintf("acc%d", tokenID)
r.RefreshToken = fmt.Sprintf("ref%d", tokenID)
response.Header().Set("Content-Type", "application/json")
response.WriteHeader(http.StatusOK)
r.r.NoError(json.NewEncoder(response).Encode(r))
}
func (r *testRefreshResponse) wantAuthRefresh() AuthRefresh {
return AuthRefresh{
UID: r.UID,
AccessToken: r.AccessToken,
RefreshToken: r.RefreshToken,
ExpiresIn: int64(r.ExpiresIn),
Scopes: r.Scopes,
}
}
func (r *testRefreshResponse) isAuthorized(header http.Header) bool {
return header.Get("x-pm-uid") == r.UID && header.Get("Authorization") == "Bearer "+r.AccessToken
}
func (r *testRefreshResponse) handleAuthCheckOnly(response http.ResponseWriter, request *http.Request) {
if r.isAuthorized(request.Header) {
response.WriteHeader(http.StatusOK)
} else {
response.WriteHeader(http.StatusUnauthorized)
}
}

View File

@ -1,268 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestAutomaticAuthRefresh(t *testing.T) {
r := require.New(t)
mux := http.NewServeMux()
currentTokens := newTestRefreshToken(r)
testUID := currentTokens.UID
testAcc := currentTokens.AccessToken
testRef := currentTokens.RefreshToken
currentTokens.ExpiresIn = 100
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
ts := httptest.NewServer(mux)
var gotAuthRefresh *AuthRefresh
c := New(Config{HostURL: ts.URL}).
NewClient(testUID, testAcc, testRef, time.Now().Add(-time.Second))
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
// Make a request with an access token that already expired one second ago.
_, err := c.GetAddresses(context.Background())
r.NoError(err)
wantAuthRefresh := currentTokens.wantAuthRefresh()
// The auth callback should have been called.
r.NotNil(gotAuthRefresh)
r.Equal(wantAuthRefresh, *gotAuthRefresh)
cl := c.(*client) //nolint:forcetypeassert // we want to panic here
r.Equal(wantAuthRefresh.AccessToken, cl.acc)
r.Equal(wantAuthRefresh.RefreshToken, cl.ref)
r.WithinDuration(expiresIn(100), cl.exp, time.Second)
}
func Test401AuthRefresh(t *testing.T) {
r := require.New(t)
currentTokens := newTestRefreshToken(r)
testUID := currentTokens.UID
testRef := currentTokens.RefreshToken
mux := http.NewServeMux()
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
ts := httptest.NewServer(mux)
var gotAuthRefresh *AuthRefresh
// Create a new client.
m := New(Config{HostURL: ts.URL})
c := m.NewClient(testUID, "oldAccToken", testRef, time.Now().Add(time.Hour))
// Register an auth handler.
c.AddAuthRefreshHandler(func(auth *AuthRefresh) { gotAuthRefresh = auth })
// The first request will fail with 401, triggering a refresh and retry.
_, err := c.GetAddresses(context.Background())
r.NoError(err)
// The auth callback should have been called.
r.NotNil(gotAuthRefresh)
r.Equal(currentTokens.wantAuthRefresh(), *gotAuthRefresh)
}
func Test401RevokedAuth(t *testing.T) {
r := require.New(t)
currentTokens := newTestRefreshToken(r)
mux := http.NewServeMux()
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
ts := httptest.NewServer(mux)
c := New(Config{HostURL: ts.URL}).
NewClient("badUID", "badAcc", "badRef", time.Now().Add(time.Hour))
// The request will fail with 401, triggering a refresh.
// The retry will also fail with 401, returning an error.
_, err := c.GetAddresses(context.Background())
r.True(IsFailedAuth(err))
}
func Test401OldRefreshToken(t *testing.T) {
r := require.New(t)
currentTokens := newTestRefreshToken(r)
mux := http.NewServeMux()
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
ts := httptest.NewServer(mux)
c := New(Config{HostURL: ts.URL}).
NewClient(currentTokens.UID, "oldAcc", "oldRef", time.Now().Add(time.Hour))
// The request will fail with 401, triggering a refresh.
// The retry will also fail with 401, returning an error.
_, err := c.GetAddresses(context.Background())
r.True(IsFailedAuth(err))
}
func Test401NoAccessToken(t *testing.T) {
r := require.New(t)
currentTokens := newTestRefreshToken(r)
testUID := currentTokens.UID
testRef := currentTokens.RefreshToken
mux := http.NewServeMux()
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
mux.HandleFunc("/addresses", currentTokens.handleAuthCheckOnly)
ts := httptest.NewServer(mux)
c := New(Config{HostURL: ts.URL}).
NewClient(testUID, "", testRef, time.Now().Add(time.Hour))
// The request will fail with 401, triggering a refresh. After the refresh it should succeed.
_, err := c.GetAddresses(context.Background())
r.NoError(err)
}
func Test401ExpiredAuthUpdateUser(t *testing.T) {
r := require.New(t)
mux := http.NewServeMux()
currentTokens := newTestRefreshToken(r)
testUID := currentTokens.UID
testRef := currentTokens.RefreshToken
mux.HandleFunc("/auth/refresh", currentTokens.handleAuthRefresh)
mux.HandleFunc("/users", func(w http.ResponseWriter, r *http.Request) {
if !currentTokens.isAuthorized(r.Header) {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
respObj := struct {
Code int
User *User
}{
Code: 1000,
User: &User{
ID: "MJLke8kWh1BBvG95JBIrZvzpgsZ94hNNgjNHVyhXMiv4g9cn6SgvqiIFR5cigpml2LD_iUk_3DkV29oojTt3eA==",
Name: "jason",
UsedSpace: &usedSpace,
},
}
if err := json.NewEncoder(w).Encode(respObj); err != nil {
panic(err)
}
})
mux.HandleFunc("/addresses", func(w http.ResponseWriter, r *http.Request) {
if !currentTokens.isAuthorized(r.Header) {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
respObj := struct {
Code int
Addresses []*Address
}{
Code: 1000,
Addresses: []*Address{},
}
if err := json.NewEncoder(w).Encode(respObj); err != nil {
panic(err)
}
})
ts := httptest.NewServer(mux)
m := New(Config{HostURL: ts.URL})
c, _, err := m.NewClientWithRefresh(context.Background(), testUID, testRef)
r.NoError(err)
// The request will fail with 401, triggering a refresh. After the refresh it should succeed.
_, err = c.UpdateUser(context.Background())
r.NoError(err)
}
func TestAuth2FA(t *testing.T) {
r := require.New(t)
twoFACode := "code"
finish, c := newTestClientCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
r.NoError(checkMethodAndPath(req, "POST", "/auth/2fa"))
var twoFAreq auth2FAReq
r.NoError(json.NewDecoder(req.Body).Decode(&twoFAreq))
r.Equal(twoFAreq.TwoFactorCode, twoFACode)
return "/auth/2fa/post_response.json"
},
)
defer finish()
err := c.Auth2FA(context.Background(), twoFACode)
r.NoError(err)
}
func TestAuth2FA_Fail(t *testing.T) {
r := require.New(t)
finish, c := newTestClientCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
r.NoError(checkMethodAndPath(req, "POST", "/auth/2fa"))
return "/auth/2fa/post_401_bad_password.json"
},
)
defer finish()
err := c.Auth2FA(context.Background(), "code")
r.Equal(ErrBad2FACode, err)
}
func TestAuth2FA_Retry(t *testing.T) {
r := require.New(t)
finish, c := newTestClientCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
r.NoError(checkMethodAndPath(req, "POST", "/auth/2fa"))
return "/auth/2fa/post_422_bad_password.json"
},
)
defer finish()
err := c.Auth2FA(context.Background(), "code")
r.Equal(ErrBad2FACodeTryAgain, err)
}

View File

@ -1,41 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import "encoding/json"
type Boolean bool
func (boolean *Boolean) UnmarshalJSON(b []byte) error {
var value int
err := json.Unmarshal(b, &value)
if err != nil {
return err
}
*boolean = Boolean(value == 1)
return nil
}
func (boolean Boolean) MarshalJSON() ([]byte, error) {
var value int
if boolean {
value = 1
}
return json.Marshal(value)
}

View File

@ -1,120 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"net/http"
"sync"
"time"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2"
)
// client is a client of the protonmail API. It implements the Client interface.
type client struct {
manager clientManager
uid, acc, ref string
authHandlers []AuthRefreshHandler
authLocker sync.RWMutex
user *User
addresses AddressList
userKeyRing *crypto.KeyRing
addrKeyRing map[string]*crypto.KeyRing
keyRingLock sync.Locker
exp time.Time
}
func newClient(manager clientManager, uid string) *client {
return &client{
manager: manager,
uid: uid,
addrKeyRing: make(map[string]*crypto.KeyRing),
keyRingLock: &sync.RWMutex{},
}
}
func (c *client) withAuth(acc, ref string, exp time.Time) *client {
c.acc = acc
c.ref = ref
c.exp = exp
return c
}
func (c *client) r(ctx context.Context) (*resty.Request, error) {
r := c.manager.r(ctx)
if c.uid != "" {
r.SetHeader("x-pm-uid", c.uid)
}
if time.Now().After(c.exp) {
if err := c.authRefresh(ctx); err != nil {
return nil, err
}
}
c.authLocker.RLock()
defer c.authLocker.RUnlock()
if c.acc != "" {
r.SetAuthToken(c.acc)
}
return r, nil
}
// do executes fn and may repeat execution in case of retry after "401 Unauthorized" error.
// Note: fn may be called more than once.
func (c *client) do(ctx context.Context, fn func(*resty.Request) (*resty.Response, error)) (*resty.Response, error) {
r, err := c.r(ctx)
if err != nil {
return nil, err
}
res, err := wrapNoConnection(fn(r))
if err != nil {
if res.StatusCode() != http.StatusUnauthorized {
// Return also response so caller has more options to decide what to do.
return res, err
}
if !isAuthRefreshDisabled(ctx) {
if err := c.authRefresh(ctx); err != nil {
return nil, err
}
// We need to reconstruct request since access token is changed with authRefresh.
r, err := c.r(ctx)
if err != nil {
return nil, err
}
return wrapNoConnection(fn(r))
}
return res, err
}
return res, nil
}

View File

@ -1,92 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
)
// Unlock unlocks all the user and address keys using the given passphrase, creating user and address keyrings.
// If the keyrings are already present, they are not recreated.
func (c *client) Unlock(ctx context.Context, passphrase []byte) (err error) {
c.keyRingLock.Lock()
defer c.keyRingLock.Unlock()
return c.unlock(ctx, passphrase)
}
// unlock unlocks the user's keys but without locking the keyring lock first.
// Should only be used internally by methods that first lock the lock.
func (c *client) unlock(ctx context.Context, passphrase []byte) error {
if _, err := c.CurrentUser(ctx); err != nil {
return err
}
if c.userKeyRing == nil {
if err := c.unlockUser(passphrase); err != nil {
return ErrUnlockFailed{err}
}
}
for _, address := range c.addresses {
if c.addrKeyRing[address.ID] == nil {
if err := c.unlockAddress(passphrase, address); err != nil {
return ErrUnlockFailed{err}
}
}
}
return nil
}
func (c *client) ReloadKeys(ctx context.Context, passphrase []byte) (err error) {
c.keyRingLock.Lock()
defer c.keyRingLock.Unlock()
c.clearKeys()
return c.unlock(ctx, passphrase)
}
func (c *client) clearKeys() {
if c.userKeyRing != nil {
c.userKeyRing.ClearPrivateParams()
c.userKeyRing = nil
}
for id, kr := range c.addrKeyRing {
if kr != nil {
kr.ClearPrivateParams()
}
delete(c.addrKeyRing, id)
}
}
func (c *client) IsUnlocked() bool {
if c.userKeyRing == nil {
return false
}
for _, address := range c.addresses {
if address.HasKeys != MissingKeys && c.addrKeyRing[address.ID] == nil {
return false
}
}
return true
}

View File

@ -1,93 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"io"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2"
)
// Client defines the interface of a PMAPI client.
type Client interface {
Auth2FA(context.Context, string) error
AuthSalt(ctx context.Context) (string, error)
AuthDelete(context.Context) error
AddAuthRefreshHandler(AuthRefreshHandler)
GetUser(ctx context.Context) (*User, error)
CurrentUser(ctx context.Context) (*User, error)
UpdateUser(ctx context.Context) (*User, error)
Unlock(ctx context.Context, passphrase []byte) (err error)
ReloadKeys(ctx context.Context, passphrase []byte) (err error)
IsUnlocked() bool
Addresses() AddressList
GetAddresses(context.Context) (addresses AddressList, err error)
ReorderAddresses(ctx context.Context, addressIDs []string) error
GetEvent(ctx context.Context, eventID string) (*Event, error)
SendMessage(context.Context, string, *SendMessageReq) (sent, parent *Message, err error)
CreateDraft(ctx context.Context, m *Message, parent string, action int) (created *Message, err error)
Import(context.Context, ImportMsgReqs) ([]*ImportMsgRes, error)
CountMessages(ctx context.Context, addressID string) ([]*MessagesCount, error)
ListMessages(ctx context.Context, filter *MessagesFilter) ([]*Message, int, error)
GetMessage(ctx context.Context, apiID string) (*Message, error)
DeleteMessages(ctx context.Context, apiIDs []string) error
LabelMessages(ctx context.Context, apiIDs []string, labelID string) error
UnlabelMessages(ctx context.Context, apiIDs []string, labelID string) error
MarkMessagesRead(ctx context.Context, apiIDs []string) error
MarkMessagesUnread(ctx context.Context, apiIDs []string) error
ListLabels(ctx context.Context) ([]*Label, error)
CreateLabel(ctx context.Context, label *Label) (*Label, error)
UpdateLabel(ctx context.Context, label *Label) (*Label, error)
DeleteLabel(ctx context.Context, labelID string) error
EmptyFolder(ctx context.Context, labelID string, addressID string) error
// /core/V4/labels routes
ListLabelsOnly(ctx context.Context) ([]*Label, error)
ListFoldersOnly(ctx context.Context) ([]*Label, error)
CreateLabelV4(ctx context.Context, label *Label) (*Label, error)
UpdateLabelV4(ctx context.Context, label *Label) (*Label, error)
DeleteLabelV4(ctx context.Context, labelID string) error
GetMailSettings(ctx context.Context) (MailSettings, error)
GetContactEmailByEmail(context.Context, string, int, int) ([]ContactEmail, error)
GetContactByID(context.Context, string) (Contact, error)
DecryptAndVerifyCards([]Card) ([]Card, error)
GetAttachment(ctx context.Context, id string) (att io.ReadCloser, err error)
CreateAttachment(ctx context.Context, att *Attachment, r io.Reader, sig io.Reader) (created *Attachment, err error)
GetUserKeyRing() (*crypto.KeyRing, error)
KeyRingForAddressID(string) (kr *crypto.KeyRing, err error)
GetPublicKeysForEmail(context.Context, string) ([]PublicKey, bool, error)
}
type AuthRefreshHandler func(*AuthRefresh)
type clientManager interface {
r(context.Context) *resty.Request
authRefresh(context.Context, string, string) (*AuthRefresh, error)
setSentryUserID(userID string)
}

View File

@ -1,74 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"runtime"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
type Config struct {
// HostURL is the base URL of API.
HostURL string
// AppVersion sets version to headers of each request.
AppVersion string
// UserAgent sets user agent to headers of each request.
// Used only if GetUserAgent is not set.
UserAgent string
// GetUserAgent is dynamic version of UserAgent.
// Overrides UserAgent.
GetUserAgent func() string
// UpgradeApplicationHandler is used to notify when there is a force upgrade.
UpgradeApplicationHandler func()
// TLSIssueHandler is used to notify when there is a TLS issue.
TLSIssueHandler func()
}
func NewConfig(appVersionName, appVersion string) Config {
return Config{
HostURL: getRootURL(),
AppVersion: getAPIOS() + cases.Title(language.Und).String(appVersionName) + "_" + appVersion,
}
}
func (c *Config) getUserAgent() string {
if c.GetUserAgent == nil {
return c.UserAgent
}
return c.GetUserAgent()
}
// getAPIOS returns actual operating system.
func getAPIOS() string {
switch os := runtime.GOOS; os {
case "darwin": // nolint: const
return "macOS"
case "linux":
return "Linux"
case "windows":
return "Windows"
}
return "Linux"
}

View File

@ -1,36 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
//go:build !build_qa
// +build !build_qa
package pmapi
import (
"net/http"
)
func getRootURL() string {
return "https://api.protonmail.ch"
}
func newProxyDialerAndTransport(cfg Config) (*ProxyTLSDialer, http.RoundTripper) {
basicDialer := NewBasicTLSDialer(cfg)
pinningDialer := NewPinningTLSDialer(cfg, basicDialer)
proxyDialer := NewProxyTLSDialer(cfg, pinningDialer)
return proxyDialer, CreateTransportWithDialer(proxyDialer)
}

View File

@ -1,49 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
//go:build build_qa
// +build build_qa
package pmapi
import (
"crypto/tls"
"net/http"
"os"
"strings"
)
func getRootURL() string {
// This config allows to dynamically change ROOT URL.
url := os.Getenv("PMAPI_ROOT_URL")
if strings.HasPrefix(url, "http") {
return url
}
if url != "" {
return "https://" + url
}
return "https://api.protonmail.ch"
}
func newProxyDialerAndTransport(cfg Config) (*ProxyTLSDialer, http.RoundTripper) {
transport := CreateTransportWithDialer(NewBasicTLSDialer(cfg))
// TLS certificate of testing environment might be self-signed.
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
return nil, transport
}

View File

@ -1,130 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"errors"
"strconv"
"github.com/go-resty/resty/v2"
)
type Card struct {
Type int
Data string
Signature string
}
const (
CardEncrypted = 1
CardSigned = 2
)
type Contact struct {
ID string
Name string
UID string
Size int64
CreateTime int64
ModifyTime int64
LabelIDs []string
ContactEmails []ContactEmail
Cards []Card
}
type ContactEmail struct {
ID string
Name string
Email string
Type []string
Defaults int
Order int
ContactID string
LabelIDs []string
}
var errVerificationFailed = errors.New("signature verification failed")
// ================= Public utility functions ======================
func (c *client) DecryptAndVerifyCards(cards []Card) ([]Card, error) {
for i := range cards {
card := &cards[i]
if isEncryptedCardType(card.Type) {
signedCard, err := c.decrypt(card.Data)
if err != nil {
return nil, err
}
card.Data = string(signedCard)
}
if isSignedCardType(card.Type) {
err := c.verify(card.Data, card.Signature)
if err != nil {
return cards, errVerificationFailed
}
}
}
return cards, nil
}
// GetContactByID gets contact details specified by contact ID.
func (c *client) GetContactByID(ctx context.Context, contactID string) (contactDetail Contact, err error) {
var res struct {
Contact Contact
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/contacts/v4/" + contactID)
}); err != nil {
return Contact{}, err
}
return res.Contact, nil
}
// GetContactEmailByEmail gets all emails from all contacts matching a specified email string.
func (c *client) GetContactEmailByEmail(ctx context.Context, email string, page int, pageSize int) (contactEmails []ContactEmail, err error) {
var res struct {
ContactEmails []ContactEmail
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
r = r.SetQueryParams(map[string]string{
"Email": email,
"Page": strconv.Itoa(page),
})
if pageSize != 0 {
r.SetQueryParam("PageSize", strconv.Itoa(pageSize))
}
return r.SetResult(&res).Get("/contacts/v4/emails")
}); err != nil {
return nil, err
}
return res.ContactEmails, nil
}
func isSignedCardType(cardType int) bool {
return (cardType & CardSigned) == CardSigned
}
func isEncryptedCardType(cardType int) bool {
return (cardType & CardEncrypted) == CardEncrypted
}

View File

@ -1,219 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"fmt"
"net/http"
"reflect"
"testing"
r "github.com/stretchr/testify/require"
)
var (
CleartextCard = 0
EncryptedCard = 1
SignedCard = 2
EncryptedSignedCard = 3
)
var testGetContactByIDResponseBody = `{
"Code": 1000,
"Contact": {
"ID": "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
"Name": "Alice",
"UID": "proton-web-98c8de5e-4536-140b-9ab0-bd8ab6a2050b",
"Size": 243,
"CreateTime": 1517395498,
"ModifyTime": 1517395498,
"Cards": [
{
"Type": 3,
"Data": "-----BEGIN PGP MESSAGE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwcBMA1vYAFKnBP8gAQf/RnOQRpo8DVJHQSJRgckEaUQvdMcADiM4L23diyiS\nQfclby/Ve2WInmvZc2RJ3rWENfeqyDZE6krQT642pKiW09GOIyVIjl+hje9y\nE4HBX0AIAWv7QhhKX6UZcM5dYSFbV3j3QxQB8A4Thng2G6ltotMTlbtcHbhu\n96Lt6ngA1tngXLSF5seyflnoiSQ5gLi2qVzrd95dIP6D4Ottcp929/4hDGmq\nPyxw9dColx6gVd1bmIDSI6ewkET4Grmo6QYqjSvjqLOf0PqHKzqypSFLkI5l\nmmnWKYTQCgl9wX+hq6Qz5E+m/BtbkdeX0YxYUss2e+oSAzJmnfdETErG9U5z\n3NJqAc3sgdwDzfWHBzogAxAbDHiqrF6zMlR5SFvZ6nRU7M2DTOE5dJhf+zOp\n1WSKn5LR46LGyt0m5wJPDjaGyQdPffAO4EULvwhGENe10UxRjY1qcUmjYOtS\nunl/vh3afI9PC1jj+HHJD2VgCA==\n=UpcY\n-----END PGP MESSAGE-----\n",
"Signature": "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZ4pCRDMO9BwcW4mpAAA6h0H/2+97koXzly5pu9hpbaW\n75d1Q976RjMr5DjAx6tKFtSzznel8YfWgvA6OQmMGdPY8ae7/+3mwCJZYWy/\nXVvUfCSflmYpSIKGfP+Vm1XezWY1W84DGhiFj5n8sdaWisv3bpFwFf1YR3Ae\noBoZ4ufNzaQALRqGPMgXETtXZCtzuL/+0vGSKj5SLECiRcSE4jCPEVRy2bcl\nWJyB9r4VmcjF042OMHxphXoYmTEWvgigyaQFHNORu5cK9EHfHpCG6IcjGbdx\n+9Px5YnDY1ix+YpBKePGSTlLE0u6ow0VTUrdvNjl7IUBaRcfJcIIdgCBOTMw\n1uQ/yeyP46V5AFXFnIKeZeQ=\n=FlOf\n-----END PGP SIGNATURE-----\n"
},
{
"Type": 2,
"Data": "BEGIN:VCARD\nVERSION:4.0\nFN;TYPE=fn:Alice\nitem1.EMAIL:alice@protonmail.com\nUID:proton-web-98c8de5e-4536-140b-9ab0-bd8ab6a2050b\nEND:VCARD",
"Signature": "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZ4qCRDMO9BwcW4mpAAA3jUIAJ88mIyO8Yj0+evSFXnK\nNxNdjNgn7t1leY0BWlh1nkK76XrZEPipdw2QU8cOcZzn1Wby2SGfZVkwoPc4\nzAhPT4WKbkFVqXhDry5399kLwGYJCxdEcw/oPyYj+YgpQKMxhTrQq21tbEwr\n7JDRBXgi3Cckh/XsteFHOIiAVnM7BV6zFudipnYxa4uNF0Bf4VbUZx1Mm0Wb\nMJaGsO5reqQUQzDPO5TdSAZ8qGSdjVv7RESgUu5DckcDSsnB987Zbh9uFc22\nfPYmb6zA0cEZh3dAjpDPT7cg8hlvfYBb+kP3sLFyLiIkdEG8Pcagjf0k+l76\nr1IsPlYBx2LJmsJf+WDNlj8=\n=Xn+3\n-----END PGP SIGNATURE-----\n"
}
],
"ContactEmails": [
{
"ID": "4m2sBxLq4McqD0D330Kuy5xG-yyDNXyLEjG5_RYcjy9X-3qHGNP07DNOWLY40TYtUAQr4fAVp8zOcZ_z2o6H-A==",
"Name": "Alice",
"Email": "alice@protonmail.com",
"Type": [],
"Defaults": 1,
"Order": 1,
"ContactID": "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
"LabelIDs": []
}
],
"LabelIDs": []
}
}`
var testGetContactEmailByEmailResponseBody = `{
"Code": 1000,
"ContactEmails": [
{
"ID": "aefew4323jFv0BhSMw==",
"Name": "ProtonMail Features",
"Email": "features@protonmail.black",
"Type": [
"work"
],
"Defaults": 1,
"Order": 1,
"ContactID": "a29olIjFv0rnXxBhSMw==",
"LabelIDs": [
"I6hgx3Ol-d3HYa3E394T_ACXDmTaBub14w=="
],
"CanonicalEmail": "features@protonmail.black",
"LastUsedTime": 1612546350
}
],
"Total": 2
}`
var testGetContactByID = Contact{
ID: "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
Name: "Alice",
UID: "proton-web-98c8de5e-4536-140b-9ab0-bd8ab6a2050b",
Size: 243,
CreateTime: 1517395498,
ModifyTime: 1517395498,
Cards: []Card{
{
Type: 3,
Data: "-----BEGIN PGP MESSAGE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwcBMA1vYAFKnBP8gAQf/RnOQRpo8DVJHQSJRgckEaUQvdMcADiM4L23diyiS\nQfclby/Ve2WInmvZc2RJ3rWENfeqyDZE6krQT642pKiW09GOIyVIjl+hje9y\nE4HBX0AIAWv7QhhKX6UZcM5dYSFbV3j3QxQB8A4Thng2G6ltotMTlbtcHbhu\n96Lt6ngA1tngXLSF5seyflnoiSQ5gLi2qVzrd95dIP6D4Ottcp929/4hDGmq\nPyxw9dColx6gVd1bmIDSI6ewkET4Grmo6QYqjSvjqLOf0PqHKzqypSFLkI5l\nmmnWKYTQCgl9wX+hq6Qz5E+m/BtbkdeX0YxYUss2e+oSAzJmnfdETErG9U5z\n3NJqAc3sgdwDzfWHBzogAxAbDHiqrF6zMlR5SFvZ6nRU7M2DTOE5dJhf+zOp\n1WSKn5LR46LGyt0m5wJPDjaGyQdPffAO4EULvwhGENe10UxRjY1qcUmjYOtS\nunl/vh3afI9PC1jj+HHJD2VgCA==\n=UpcY\n-----END PGP MESSAGE-----\n",
Signature: "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZ4pCRDMO9BwcW4mpAAA6h0H/2+97koXzly5pu9hpbaW\n75d1Q976RjMr5DjAx6tKFtSzznel8YfWgvA6OQmMGdPY8ae7/+3mwCJZYWy/\nXVvUfCSflmYpSIKGfP+Vm1XezWY1W84DGhiFj5n8sdaWisv3bpFwFf1YR3Ae\noBoZ4ufNzaQALRqGPMgXETtXZCtzuL/+0vGSKj5SLECiRcSE4jCPEVRy2bcl\nWJyB9r4VmcjF042OMHxphXoYmTEWvgigyaQFHNORu5cK9EHfHpCG6IcjGbdx\n+9Px5YnDY1ix+YpBKePGSTlLE0u6ow0VTUrdvNjl7IUBaRcfJcIIdgCBOTMw\n1uQ/yeyP46V5AFXFnIKeZeQ=\n=FlOf\n-----END PGP SIGNATURE-----\n",
},
{
Type: 2,
Data: "BEGIN:VCARD\nVERSION:4.0\nFN;TYPE=fn:Alice\nitem1.EMAIL:alice@protonmail.com\nUID:proton-web-98c8de5e-4536-140b-9ab0-bd8ab6a2050b\nEND:VCARD",
Signature: "-----BEGIN PGP SIGNATURE-----\nVersion: ProtonMail\nComment: https://protonmail.com\n\nwsBcBAEBCAAQBQJacZ4qCRDMO9BwcW4mpAAA3jUIAJ88mIyO8Yj0+evSFXnK\nNxNdjNgn7t1leY0BWlh1nkK76XrZEPipdw2QU8cOcZzn1Wby2SGfZVkwoPc4\nzAhPT4WKbkFVqXhDry5399kLwGYJCxdEcw/oPyYj+YgpQKMxhTrQq21tbEwr\n7JDRBXgi3Cckh/XsteFHOIiAVnM7BV6zFudipnYxa4uNF0Bf4VbUZx1Mm0Wb\nMJaGsO5reqQUQzDPO5TdSAZ8qGSdjVv7RESgUu5DckcDSsnB987Zbh9uFc22\nfPYmb6zA0cEZh3dAjpDPT7cg8hlvfYBb+kP3sLFyLiIkdEG8Pcagjf0k+l76\nr1IsPlYBx2LJmsJf+WDNlj8=\n=Xn+3\n-----END PGP SIGNATURE-----\n",
},
},
ContactEmails: []ContactEmail{
{
ID: "4m2sBxLq4McqD0D330Kuy5xG-yyDNXyLEjG5_RYcjy9X-3qHGNP07DNOWLY40TYtUAQr4fAVp8zOcZ_z2o6H-A==",
Name: "Alice",
Email: "alice@protonmail.com",
Type: []string{},
Defaults: 1,
Order: 1,
ContactID: "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==",
LabelIDs: []string{},
},
},
LabelIDs: []string{},
}
var testGetContactEmailByEmail = []ContactEmail{
{
ID: "aefew4323jFv0BhSMw==",
Name: "ProtonMail Features",
Email: "features@protonmail.black",
Type: []string{"work"},
Defaults: 1,
Order: 1,
ContactID: "a29olIjFv0rnXxBhSMw==",
LabelIDs: []string{"I6hgx3Ol-d3HYa3E394T_ACXDmTaBub14w=="},
},
}
func TestContact_GetContactById(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "GET", "/contacts/v4/s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg=="))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testGetContactByIDResponseBody)
}))
defer s.Close()
contact, err := c.GetContactByID(context.Background(), "s_SN9y1q0jczjYCH4zhvfOdHv1QNovKhnJ9bpDcTE0u7WCr2Z-NV9uubHXvOuRozW-HRVam6bQupVYRMC3BCqg==")
r.NoError(t, err)
if !reflect.DeepEqual(contact, testGetContactByID) {
t.Fatalf("Invalid got contact: expected %+v, got %+v", testGetContactByID, contact)
}
}
func TestContact_GetContactEmailByEmail(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "GET", "/contacts/v4/emails?Email=someone%40pm.me&Page=1&PageSize=10"))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testGetContactEmailByEmailResponseBody)
}))
defer s.Close()
contact, err := c.GetContactEmailByEmail(context.Background(), "someone@pm.me", 1, 10)
r.NoError(t, err)
if !reflect.DeepEqual(contact, testGetContactEmailByEmail) {
t.Fatalf("Invalid got contact: expected %+v, got %+v", testGetContactByID, contact)
}
}
func TestContact_isSignedCardType(t *testing.T) {
if !isSignedCardType(SignedCard) || !isSignedCardType(EncryptedSignedCard) {
t.Fatal("isSignedCardType shouldn't return false for signed card types")
}
if isSignedCardType(CleartextCard) || isSignedCardType(EncryptedCard) {
t.Fatal("isSignedCardType shouldn't return true for non-signed card types")
}
}
func TestContact_isEncryptedCardType(t *testing.T) {
if !isEncryptedCardType(EncryptedCard) || !isEncryptedCardType(EncryptedSignedCard) {
t.Fatal("isEncryptedCardType shouldn't return false for encrypted card types")
}
if isEncryptedCardType(CleartextCard) || isEncryptedCardType(SignedCard) {
t.Fatal("isEncryptedCardType shouldn't return true for non-encrypted card types")
}
}
var testCardsEncrypted = []Card{
{
Type: EncryptedSignedCard,
Data: "-----BEGIN PGP MESSAGE-----\nVersion: GopenPGP 0.0.1 (ddacebe0)\nComment: https://gopenpgp.org\n\nwcBMA0fcZ7XLgmf2AQf/fLKA6ZCkDxumpDoUoFQfO86B9LFuqGEJq+voP12C6UXo\nfB2nTy/K4+VosLKYOkU9sW1PZOCL+i00z+zkqUZ6jchbZBpzwy/UCTmpPRw5zrmr\nW6bZCwwgqJSGVWrvcrDA3bW9cn/HHqQqU6jNeXIF+IuhTscRAJVGehJZYWjr1lgB\nToJhg4+//Bgp/Fxzz8Fej/fsokgOlRJ8xcZKYx0rKL/+Il0u2jnd08kJTegpaY+6\nBlsYBzfYq25WkS02iy02wHbt6XD7AxFDi4WDjsM8bryLSm/KNWrejqfDYb/tMAKa\nKNJqK39/EUewzp1gHEXiGmdDEIFTKCHTDTPV84mwf9I1Ae4yoLs+ilYE6sSk7DCh\nPSWjDC8lpKzmw93slsejTG93HJKQPcZ0rLBpv6qPZX6widNYjDE=\n=QFxr\n-----END PGP MESSAGE-----",
Signature: "-----BEGIN PGP SIGNATURE-----\nVersion: GopenPGP 0.0.1 (ddacebe0)\nComment: https://gopenpgp.org\n\nwsBcBAABCgAQBQJdZQ1kCRA+tiWe3yHfJAAA9nMH/0X7pS8TGt6Ox0BewRh0vjfQ\n9LPLwbOiHdj97LNqutZcLlDTfm9SPH82221ZpVILWhB0u2kFeNUGihVbjAqJGYJn\nEk2TELLwn8csYRy9r5JkyUirqrvh7jgl4vs1yt8O/3Yb4ARudOoZr8Yrb4+NVNe0\nCcwQJnH/fJPtF1hbarKwtKtCo3IFwTis4pc8qWJRpBH61z1mO0Yr/LIh85QndhnF\nnZ/3MkWOY0kp2gl4ptqtNUw7z+JJ4LLVdT3ycdVK7GVTZmIG90y5KKxwJvrwbS7/\n8rmPGPQ5diLEMrzuKC2plXT6Pdy0ShtZxie2C3JY86e7ol7xvl0pNqxzOrj424w=\n=AOTG\n-----END PGP SIGNATURE-----",
},
}
var testCardsCleartext = []Card{
{
Type: EncryptedSignedCard,
Data: "data",
Signature: "-----BEGIN PGP SIGNATURE-----\nVersion: GopenPGP 0.0.1 (ddacebe0)\nComment: https://gopenpgp.org\n\nwsBcBAABCgAQBQJdZQ1kCRA+tiWe3yHfJAAA9nMH/0X7pS8TGt6Ox0BewRh0vjfQ\n9LPLwbOiHdj97LNqutZcLlDTfm9SPH82221ZpVILWhB0u2kFeNUGihVbjAqJGYJn\nEk2TELLwn8csYRy9r5JkyUirqrvh7jgl4vs1yt8O/3Yb4ARudOoZr8Yrb4+NVNe0\nCcwQJnH/fJPtF1hbarKwtKtCo3IFwTis4pc8qWJRpBH61z1mO0Yr/LIh85QndhnF\nnZ/3MkWOY0kp2gl4ptqtNUw7z+JJ4LLVdT3ycdVK7GVTZmIG90y5KKxwJvrwbS7/\n8rmPGPQ5diLEMrzuKC2plXT6Pdy0ShtZxie2C3JY86e7ol7xvl0pNqxzOrj424w=\n=AOTG\n-----END PGP SIGNATURE-----",
},
}
func TestClient_Decrypt(t *testing.T) {
c := newClient(newManager(Config{}), "")
c.userKeyRing = testPrivateKeyRing
cardCleartext, err := c.DecryptAndVerifyCards(testCardsEncrypted)
r.Nil(t, err)
r.Equal(t, testCardsCleartext[0].Data, cardCleartext[0].Data)
}

View File

@ -1,54 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
)
type pmapiContextKey string
const (
retryContextKey = pmapiContextKey("retry")
retryDisabled = "disabled"
authRefreshContextKey = pmapiContextKey("authRefresh")
authRefreshDisabled = "disabled"
)
func ContextWithoutRetry(parent context.Context) context.Context {
return context.WithValue(parent, retryContextKey, retryDisabled)
}
func isRetryDisabled(ctx context.Context) bool {
if v := ctx.Value(retryContextKey); v != nil {
return v == retryDisabled
}
return false
}
func ContextWithoutAuthRefresh(parent context.Context) context.Context {
return context.WithValue(parent, authRefreshContextKey, authRefreshDisabled)
}
func isAuthRefreshDisabled(ctx context.Context) bool {
if v := ctx.Value(authRefreshContextKey); v != nil {
return v == authRefreshDisabled
}
return false
}

View File

@ -1,31 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import "github.com/ProtonMail/gopenpgp/v2/crypto"
var testIdentity = &crypto.Identity{
Name: "UserID",
Email: "",
}
const (
testUID = "729ad6012421d67ad26950dc898bebe3a6e3caa2" //nolint:gosec
testAccessToken = "de0423049b44243afeec7d9c1d99be7b46da1e8a" //nolint:gosec
testRefreshToken = "a49b98256745bb497bec20e9b55f5de16f01fb52" //nolint:gosec
)

View File

@ -1,78 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"crypto/tls"
"net"
"net/http"
"time"
)
type TLSDialer interface {
DialTLS(network, address string) (conn net.Conn, err error)
}
// CreateTransportWithDialer creates an http.Transport that uses the given dialer to make TLS connections.
func CreateTransportWithDialer(dialer TLSDialer) *http.Transport {
return &http.Transport{
DialTLS: dialer.DialTLS,
Proxy: http.ProxyFromEnvironment,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 5 * time.Minute,
ExpectContinueTimeout: 500 * time.Millisecond,
// GODT-126: this was initially 10s but logs from users showed a significant number
// were hitting this timeout, possibly due to flaky wifi taking >10s to reconnect.
// Bumping to 30s for now to avoid this problem.
ResponseHeaderTimeout: 30 * time.Second,
// If we allow up to 30 seconds for response headers, it is reasonable to allow up
// to 30 seconds for the TLS handshake to take place.
TLSHandshakeTimeout: 30 * time.Second,
}
}
// BasicTLSDialer implements TLSDialer.
type BasicTLSDialer struct {
cfg Config
}
// NewBasicTLSDialer returns a new BasicTLSDialer.
func NewBasicTLSDialer(cfg Config) *BasicTLSDialer {
return &BasicTLSDialer{
cfg: cfg,
}
}
// DialTLS returns a connection to the given address using the given network.
func (d *BasicTLSDialer) DialTLS(network, address string) (conn net.Conn, err error) {
dialer := &net.Dialer{Timeout: 30 * time.Second} // Alternative Routes spec says this should be a 30s timeout.
var tlsConfig *tls.Config
// If we are not dialing the standard API then we should skip cert verification checks.
if address != d.cfg.HostURL {
tlsConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec
}
return tls.DialWithDialer(dialer, network, address, tlsConfig)
}

View File

@ -1,116 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"crypto/tls"
"net"
"github.com/sirupsen/logrus"
)
// TrustedAPIPins contains trusted public keys of the protonmail API and proxies.
// NOTE: the proxy pins are the same for all proxy servers, guaranteed by infra team ;).
var TrustedAPIPins = []string{ //nolint:gochecknoglobals
// api.protonmail.ch
`pin-sha256="drtmcR2kFkM8qJClsuWgUzxgBkePfRCkRpqUesyDmeE="`, // current
`pin-sha256="YRGlaY0jyJ4Jw2/4M8FIftwbDIQfh8Sdro96CeEel54="`, // hot backup
`pin-sha256="AfMENBVvOS8MnISprtvyPsjKlPooqh8nMB/pvCrpJpw="`, // cold backup
// protonmail.com
// \todo remove when sure no one is using it.
`pin-sha256="8joiNBdqaYiQpKskgtkJsqRxF7zN0C0aqfi8DacknnI="`, // current
`pin-sha256="JMI8yrbc6jB1FYGyyWRLFTmDNgIszrNEMGlgy972e7w="`, // hot backup
`pin-sha256="Iu44zU84EOCZ9vx/vz67/MRVrxF1IO4i4NIa8ETwiIY="`, // cold backup
// proton.me
`pin-sha256="CT56BhOTmj5ZIPgb/xD5mH8rY3BLo/MlhP7oPyJUEDo="`, // current
`pin-sha256="35Dx28/uzN3LeltkCBQ8RHK0tlNSa2kCpCRGNp34Gxc="`, // hot backup
`pin-sha256="qYIukVc63DEITct8sFT7ebIq5qsWmuscaIKeJx+5J5A="`, // col backup
// proxies
`pin-sha256="EU6TS9MO0L/GsDHvVc9D5fChYLNy5JdGYpJw0ccgetM="`, // main
`pin-sha256="iKPIHPnDNqdkvOnTClQ8zQAIKG0XavaPkcEo0LBAABA="`, // backup 1
`pin-sha256="MSlVrBCdL0hKyczvgYVSRNm88RicyY04Q2y5qrBt0xA="`, // backup 2
`pin-sha256="C2UxW0T1Ckl9s+8cXfjXxlEqwAfPM4HiW2y3UdtBeCw="`, // backup 3
}
// TLSReportURI is the address where TLS reports should be sent.
const TLSReportURI = "https://reports.protonmail.ch/reports/tls"
// PinningTLSDialer wraps a TLSDialer to check fingerprints after connecting and
// to report errors if the fingerprint check fails.
type PinningTLSDialer struct {
dialer TLSDialer
// pinChecker is used to check TLS keys of connections.
pinChecker *pinChecker
reporter *tlsReporter
// tlsIssueNotifier is used to notify something when there is a TLS issue.
tlsIssueNotifier func()
// A logger for logging messages.
log logrus.FieldLogger
}
// NewPinningTLSDialer constructs a new dialer which only returns tcp connections to servers
// which present known certificates.
// If enabled, it reports any invalid certificates it finds.
func NewPinningTLSDialer(cfg Config, dialer TLSDialer) *PinningTLSDialer {
return &PinningTLSDialer{
dialer: dialer,
pinChecker: newPinChecker(TrustedAPIPins),
reporter: newTLSReporter(cfg, TrustedAPIPins),
tlsIssueNotifier: cfg.TLSIssueHandler,
log: logrus.WithField("pkg", "pmapi/tls-pinning"),
}
}
// DialTLS dials the given network/address, returning an error if the certificates don't match the trusted pins.
func (p *PinningTLSDialer) DialTLS(network, address string) (net.Conn, error) {
conn, err := p.dialer.DialTLS(network, address)
if err != nil {
return nil, err
}
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
if err := p.pinChecker.checkCertificate(conn); err != nil {
if p.tlsIssueNotifier != nil {
go p.tlsIssueNotifier()
}
if tlsConn, ok := conn.(*tls.Conn); ok && p.reporter != nil {
p.reporter.reportCertIssue(
TLSReportURI,
host,
port,
tlsConn.ConnectionState(),
)
}
return nil, err
}
return conn, nil
}

View File

@ -1,67 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"github.com/ProtonMail/proton-bridge/v2/pkg/algo"
)
// ErrTLSMismatch indicates that no TLS fingerprint match could be found.
var ErrTLSMismatch = errors.New("no TLS fingerprint match found")
type pinChecker struct {
trustedPins []string
}
func newPinChecker(trustedPins []string) *pinChecker {
return &pinChecker{
trustedPins: trustedPins,
}
}
// checkCertificate returns whether the connection presents a known TLS certificate.
func (p *pinChecker) checkCertificate(conn net.Conn) error {
tlsConn, ok := conn.(*tls.Conn)
if !ok {
return errors.New("connection is not a TLS connection")
}
connState := tlsConn.ConnectionState()
for _, peerCert := range connState.PeerCertificates {
fingerprint := certFingerprint(peerCert)
for _, pin := range p.trustedPins {
if pin == fingerprint {
return nil
}
}
}
return ErrTLSMismatch
}
func certFingerprint(cert *x509.Certificate) string {
return fmt.Sprintf(`pin-sha256=%q`, algo.HashBase64SHA256(string(cert.RawSubjectPublicKeyInfo)))
}

View File

@ -1,144 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"bytes"
"encoding/json"
"io"
"net/http"
"strconv"
"time"
"github.com/sirupsen/logrus"
)
// tlsReport is inspired by https://tools.ietf.org/html/rfc7469#section-3.
// When a TLS key mismatch is detected, a tlsReport is posted to TLSReportURI.
type tlsReport struct {
// DateTime of observed pin validation in time.RFC3339 format.
DateTime string `json:"date-time"`
// Hostname to which the UA made original request that failed pin validation.
Hostname string `json:"hostname"`
// Port to which the UA made original request that failed pin validation.
Port int `json:"port"`
// EffectiveExpirationDate for noted pins in time.RFC3339 format.
EffectiveExpirationDate string `json:"effective-expiration-date"`
// IncludeSubdomains indicates whether or not the UA has noted the
// includeSubDomains directive for the Known Pinned Host.
IncludeSubdomains bool `json:"include-subdomains"`
// NotedHostname indicates the hostname that the UA noted when it noted
// the Known Pinned Host. This field allows operators to understand why
// Pin Validation was performed for, e.g., foo.example.com when the
// noted Known Pinned Host was example.com with includeSubDomains set.
NotedHostname string `json:"noted-hostname"`
// ServedCertificateChain is the certificate chain, as served by
// the Known Pinned Host during TLS session setup. It is provided as an
// array of strings; each string pem1, ... pemN is the Privacy-Enhanced
// Mail (PEM) representation of each X.509 certificate as described in
// [RFC7468].
ServedCertificateChain []string `json:"served-certificate-chain"`
// ValidatedCertificateChain is the certificate chain, as
// constructed by the UA during certificate chain verification. (This
// may differ from the served-certificate-chain.) It is provided as an
// array of strings; each string pem1, ... pemN is the PEM
// representation of each X.509 certificate as described in [RFC7468].
// UAs that build certificate chains in more than one way during the
// validation process SHOULD send the last chain built. In this way,
// they can avoid keeping too much state during the validation process.
ValidatedCertificateChain []string `json:"validated-certificate-chain"`
// The known-pins are the Pins that the UA has noted for the Known
// Pinned Host. They are provided as an array of strings with the
// syntax: known-pin = token "=" quoted-string
// e.g.:
// ```
// "known-pins": [
// 'pin-sha256="d6qzRu9zOECb90Uez27xWltNsj0e1Md7GkYYkVoZWmM="',
// "pin-sha256=\"E9CZ9INDbd+2eRQozYqqbQ2yXLVKB9+xcprMF+44U1g=\""
// ]
// ```
KnownPins []string `json:"known-pins"`
// AppVersion is used to set `x-pm-appversion` json format from datatheorem/TrustKit.
AppVersion string `json:"app-version"`
}
// newTLSReport constructs a new tlsReport configured with the given app version and known pinned public keys.
// Temporal things (current date/time) are not set yet -- they are set when sendReport is called.
func newTLSReport(host, port, server string, certChain, knownPins []string, appVersion string) (report tlsReport) {
// If we can't parse the port for whatever reason, it doesn't really matter; we should report anyway.
intPort, _ := strconv.Atoi(port)
report = tlsReport{
Hostname: host,
Port: intPort,
NotedHostname: server,
ServedCertificateChain: certChain,
KnownPins: knownPins,
AppVersion: appVersion,
}
return
}
// sendReport posts the given TLS report to the standard TLS Report URI.
func (r tlsReport) sendReport(cfg Config, uri string) {
now := time.Now()
r.DateTime = now.Format(time.RFC3339)
r.EffectiveExpirationDate = now.Add(365 * 24 * 60 * 60 * time.Second).Format(time.RFC3339)
b, err := json.Marshal(r)
if err != nil {
logrus.WithError(err).Error("Failed to marshal TLS report")
return
}
req, err := http.NewRequest("POST", uri, bytes.NewReader(b))
if err != nil {
logrus.WithError(err).Error("Failed to create http request")
return
}
req.Header.Add("Content-Type", "application/json")
req.Header.Set("User-Agent", cfg.getUserAgent())
req.Header.Set("x-pm-appversion", r.AppVersion)
logrus.WithField("request", req).Warn("Reporting TLS mismatch")
res, err := (&http.Client{Transport: CreateTransportWithDialer(NewBasicTLSDialer(cfg))}).Do(req)
if err != nil {
logrus.WithError(err).Error("Failed to report TLS mismatch")
return
}
logrus.WithField("response", res).Error("Reported TLS mismatch")
if res.StatusCode != http.StatusOK {
logrus.WithField("status", http.StatusOK).Error("StatusCode was not OK")
}
_, _ = io.ReadAll(res.Body)
_ = res.Body.Close()
}

View File

@ -1,107 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"time"
"github.com/google/go-cmp/cmp"
"github.com/sirupsen/logrus"
)
type sentReport struct {
r tlsReport
t time.Time
}
type tlsReporter struct {
cfg Config
trustedPins []string
sentReports []sentReport
}
func newTLSReporter(cfg Config, trustedPins []string) *tlsReporter {
return &tlsReporter{
cfg: cfg,
trustedPins: trustedPins,
}
}
// reportCertIssue reports a TLS key mismatch.
func (r *tlsReporter) reportCertIssue(remoteURI, host, port string, connState tls.ConnectionState) {
var certChain []string
if len(connState.VerifiedChains) > 0 {
certChain = marshalCert7468(connState.VerifiedChains[len(connState.VerifiedChains)-1])
} else {
certChain = marshalCert7468(connState.PeerCertificates)
}
report := newTLSReport(host, port, connState.ServerName, certChain, r.trustedPins, r.cfg.AppVersion)
if !r.hasRecentlySentReport(report) {
r.recordReport(report)
go report.sendReport(r.cfg, remoteURI)
}
}
// hasRecentlySentReport returns whether the report was already sent within the last 24 hours.
func (r *tlsReporter) hasRecentlySentReport(report tlsReport) bool {
var validReports []sentReport
for _, r := range r.sentReports {
if time.Since(r.t) < 24*time.Hour {
validReports = append(validReports, r)
}
}
r.sentReports = validReports
for _, r := range r.sentReports {
if cmp.Equal(report, r.r) {
return true
}
}
return false
}
// recordReport records the given report and the current time so we can check whether we recently sent this report.
func (r *tlsReporter) recordReport(report tlsReport) {
r.sentReports = append(r.sentReports, sentReport{r: report, t: time.Now()})
}
func marshalCert7468(certs []*x509.Certificate) (pemCerts []string) {
var buffer bytes.Buffer
for _, cert := range certs {
if err := pem.Encode(&buffer, &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}); err != nil {
logrus.WithField("pkg", "pmapi/tls-pinning").WithError(err).Error("Failed to encode TLS certificate")
}
pemCerts = append(pemCerts, buffer.String())
buffer.Reset()
}
return pemCerts
}

View File

@ -1,62 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestTLSReporter_DoubleReport(t *testing.T) {
reportCounter := 0
reportServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reportCounter++
}))
cfg := Config{
AppVersion: "3",
UserAgent: "useragent",
}
r := newTLSReporter(cfg, TrustedAPIPins)
// Report the same issue many times.
for i := 0; i < 10; i++ {
r.reportCertIssue(reportServer.URL, "myhost", "443", tls.ConnectionState{})
}
// We should only report once.
assert.Eventually(t, func() bool {
return reportCounter == 1
}, time.Second, time.Millisecond)
// If we then report something else many times.
for i := 0; i < 10; i++ {
r.reportCertIssue(reportServer.URL, "anotherhost", "443", tls.ConnectionState{})
}
// We should get a second report.
assert.Eventually(t, func() bool {
return reportCounter == 2
}, time.Second, time.Millisecond)
}

View File

@ -1,149 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
a "github.com/stretchr/testify/assert"
r "github.com/stretchr/testify/require"
)
func TestTLSPinValid(t *testing.T) {
called, _, cm := createClientWithPinningDialer(getRootURL())
_, _ = cm.getAuthInfo(context.Background(), GetAuthInfoReq{Username: "username"})
checkTLSIssueHandler(t, 0, called)
}
func TestTLSPinBackup(t *testing.T) {
called, dialer, cm := createClientWithPinningDialer(getRootURL())
copyTrustedPins(dialer.pinChecker)
dialer.pinChecker.trustedPins[1] = dialer.pinChecker.trustedPins[0]
dialer.pinChecker.trustedPins[0] = ""
_, _ = cm.getAuthInfo(context.Background(), GetAuthInfoReq{Username: "username"})
checkTLSIssueHandler(t, 0, called)
}
func TestTLSPinInvalid(t *testing.T) {
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
writeJSONResponsefromFile(t, w, "/auth/info/post_response.json", 0)
}))
defer ts.Close()
called, _, cm := createClientWithPinningDialer(ts.URL)
_, _ = cm.getAuthInfo(context.Background(), GetAuthInfoReq{Username: "username"})
checkTLSIssueHandler(t, 1, called)
}
func TestTLSPinNoMatch(t *testing.T) {
skipIfProxyIsSet(t)
called, dialer, cm := createClientWithPinningDialer(getRootURL())
copyTrustedPins(dialer.pinChecker)
for i := 0; i < len(dialer.pinChecker.trustedPins); i++ {
dialer.pinChecker.trustedPins[i] = "testing"
}
_, _ = cm.getAuthInfo(context.Background(), GetAuthInfoReq{Username: "username"})
_, _ = cm.getAuthInfo(context.Background(), GetAuthInfoReq{Username: "username"})
// Check that it will be reported only once per session, but notified every time.
r.Equal(t, 1, len(dialer.reporter.sentReports))
checkTLSIssueHandler(t, 2, called)
}
func TestTLSSignedCertWrongPublicKey(t *testing.T) {
skipIfProxyIsSet(t)
_, dialer, _ := createClientWithPinningDialer("")
_, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443")
r.Error(t, err, "expected dial to fail because of wrong public key")
}
func TestTLSSignedCertTrustedPublicKey(t *testing.T) {
skipIfProxyIsSet(t)
_, dialer, _ := createClientWithPinningDialer("")
copyTrustedPins(dialer.pinChecker)
dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="LwnIKjNLV3z243ap8y0yXNPghsqE76J08Eq3COvUt2E="`)
_, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443")
r.NoError(t, err, "expected dial to succeed because public key is known and cert is signed by CA")
}
func TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) {
skipIfProxyIsSet(t)
_, dialer, _ := createClientWithPinningDialer("")
copyTrustedPins(dialer.pinChecker)
dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`)
_, err := dialer.DialTLS("tcp", "self-signed.badssl.com:443")
r.NoError(t, err, "expected dial to succeed because public key is known despite cert being self-signed")
}
func createClientWithPinningDialer(hostURL string) (*int, *PinningTLSDialer, *manager) {
called := 0
cfg := Config{
AppVersion: "Bridge_1.2.4-test",
HostURL: hostURL,
TLSIssueHandler: func() { called++ },
}
dialer := NewPinningTLSDialer(cfg, NewBasicTLSDialer(cfg))
cm := newManager(cfg)
cm.SetTransport(CreateTransportWithDialer(dialer))
return &called, dialer, cm
}
func copyTrustedPins(pinChecker *pinChecker) {
copiedPins := make([]string, len(pinChecker.trustedPins))
copy(copiedPins, pinChecker.trustedPins)
pinChecker.trustedPins = copiedPins
}
func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast int, called *int) {
// TLSIssueHandler is called in goroutine se we need to wait a bit to be sure it was called.
a.Eventually(
t,
func() bool {
if wantCalledAtLeast == 0 {
return *called == 0
}
// Dialer can do more attempts resulting in more calls.
return *called >= wantCalledAtLeast
},
time.Second,
10*time.Millisecond,
)
// Repeated again so it generates nice message.
if wantCalledAtLeast == 0 {
r.Equal(t, 0, *called)
} else {
r.GreaterOrEqual(t, *called, wantCalledAtLeast)
}
}

View File

@ -1,149 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"net"
"net/url"
"sync"
"time"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
// ProxyTLSDialer wraps a TLSDialer to switch to a proxy if the initial dial fails.
type ProxyTLSDialer struct {
dialer TLSDialer
locker sync.RWMutex
directAddress string
proxyAddress string
allowProxy bool
proxyProvider *proxyProvider
proxyUseDuration time.Duration
}
// NewProxyTLSDialer constructs a dialer which provides a proxy-managing layer on top of an underlying dialer.
func NewProxyTLSDialer(cfg Config, dialer TLSDialer) *ProxyTLSDialer {
return &ProxyTLSDialer{
dialer: dialer,
locker: sync.RWMutex{},
directAddress: formatAsAddress(cfg.HostURL),
proxyAddress: formatAsAddress(cfg.HostURL),
proxyProvider: newProxyProvider(cfg, dohProviders, proxyQuery),
proxyUseDuration: proxyUseDuration,
}
}
// formatAsAddress returns URL as `host:port` for easy comparison in DialTLS.
func formatAsAddress(rawURL string) string {
url, err := url.Parse(rawURL)
if err != nil {
// This means wrong configuration.
// Developer should get feedback right away.
panic(err)
}
host := url.Host
if host == "" {
host = url.Path
}
port := "443"
if url.Scheme == "http" {
port = "80"
}
return net.JoinHostPort(host, port)
}
// DialTLS dials the given network/address. If it fails, it retries using a proxy.
func (d *ProxyTLSDialer) DialTLS(network, address string) (net.Conn, error) {
if address == d.directAddress {
address = d.proxyAddress
}
conn, err := d.dialer.DialTLS(network, address)
if err == nil || !d.allowProxy {
return conn, err
}
err = d.switchToReachableServer()
if err != nil {
return nil, err
}
return d.dialer.DialTLS(network, d.proxyAddress)
}
// switchToReachableServer switches to using a reachable server (either proxy or standard API).
func (d *ProxyTLSDialer) switchToReachableServer() error {
d.locker.Lock()
defer d.locker.Unlock()
logrus.Info("Attempting to switch to a proxy")
proxy, err := d.proxyProvider.findReachableServer()
if err != nil {
return errors.Wrap(err, "failed to find a usable proxy")
}
proxyAddress := formatAsAddress(proxy)
// If the chosen proxy is the standard API, we want to use it but still show the troubleshooting screen.
if proxyAddress == d.directAddress {
logrus.Info("The standard API is reachable again; connection drop was only intermittent")
d.proxyAddress = proxyAddress
return ErrNoConnection
}
logrus.WithField("proxy", proxyAddress).Info("Switching to a proxy")
// If the host is currently the rootURL, it's the first time we are enabling a proxy.
// This means we want to disable it again in 24 hours.
if d.proxyAddress == d.directAddress {
go func() {
<-time.After(d.proxyUseDuration)
d.locker.Lock()
defer d.locker.Unlock()
d.proxyAddress = d.directAddress
}()
}
d.proxyAddress = proxyAddress
return nil
}
// AllowProxy allows the dialer to switch to a proxy if need be.
func (d *ProxyTLSDialer) AllowProxy() {
d.locker.Lock()
defer d.locker.Unlock()
d.allowProxy = true
}
// DisallowProxy prevents the dialer from switching to a proxy if need be.
func (d *ProxyTLSDialer) DisallowProxy() {
d.locker.Lock()
defer d.locker.Unlock()
d.allowProxy = false
d.proxyAddress = d.directAddress
}

View File

@ -1,254 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"encoding/base64"
"strings"
"sync"
"time"
"github.com/go-resty/resty/v2"
"github.com/miekg/dns"
"github.com/pkg/errors"
)
const (
proxyUseDuration = 24 * time.Hour
proxyLookupWait = 5 * time.Second
proxyCacheRefreshTimeout = 20 * time.Second
proxyDoHTimeout = 20 * time.Second
proxyCanReachTimeout = 20 * time.Second
proxyQuery = "dMFYGSLTQOJXXI33ONVQWS3BOMNUA.protonpro.xyz"
Quad9Provider = "https://dns11.quad9.net/dns-query"
Quad9PortProvider = "https://dns11.quad9.net:5053/dns-query"
GoogleProvider = "https://dns.google/dns-query"
)
var dohProviders = []string{ //nolint:gochecknoglobals
Quad9Provider,
Quad9PortProvider,
GoogleProvider,
}
// proxyProvider manages known proxies.
type proxyProvider struct {
cfg Config
// dohLookup is used to look up the given query at the given DoH provider, returning the TXT records>
dohLookup func(ctx context.Context, query, provider string) (urls []string, err error)
providers []string // List of known doh providers.
query string // The query string used to find proxies.
proxyCache []string // All known proxies, cached in case DoH providers are unreachable.
cacheRefreshTimeout time.Duration
dohTimeout time.Duration
canReachTimeout time.Duration
lastLookup time.Time // The time at which we last attempted to find a proxy.
}
// newProxyProvider creates a new proxyProvider that queries the given DoH providers
// to retrieve DNS records for the given query string.
func newProxyProvider(cfg Config, providers []string, query string) (p *proxyProvider) { //nolint:unparam
p = &proxyProvider{
cfg: cfg,
providers: providers,
query: query,
cacheRefreshTimeout: proxyCacheRefreshTimeout,
dohTimeout: proxyDoHTimeout,
canReachTimeout: proxyCanReachTimeout,
}
// Use the default DNS lookup method; this can be overridden if necessary.
p.dohLookup = p.defaultDoHLookup
return
}
// findReachableServer returns a working API server (either proxy or standard API).
func (p *proxyProvider) findReachableServer() (proxy string, err error) {
log.Debug("Trying to find a reachable server")
if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) {
return "", errors.New("not looking for a proxy, too soon")
}
p.lastLookup = time.Now()
// We use a waitgroup to wait for both
// a) the check whether the API is reachable, and
// b) the DoH queries.
// This is because the Alternative Routes v2 spec says:
// Call the GET /test/ping route on normal API domain (same time as DoH requests and wait until all have finished)
var wg sync.WaitGroup
var apiReachable bool
wg.Add(2)
go func() {
defer wg.Done()
apiReachable = p.canReach(p.cfg.HostURL)
}()
go func() {
defer wg.Done()
err = p.refreshProxyCache()
}()
wg.Wait()
if apiReachable {
proxy = p.cfg.HostURL
return
}
if err != nil {
return
}
for _, url := range p.proxyCache {
if p.canReach(url) {
proxy = url
return
}
}
return "", errors.New("no reachable server could be found")
}
// refreshProxyCache loads the latest proxies from the known providers.
// If the process takes longer than proxyCacheRefreshTimeout, an error is returned.
func (p *proxyProvider) refreshProxyCache() error {
log.Info("Refreshing proxy cache")
ctx, cancel := context.WithTimeout(context.Background(), p.cacheRefreshTimeout)
defer cancel()
resultChan := make(chan []string)
go func() {
for _, provider := range p.providers {
if proxies, err := p.dohLookup(ctx, p.query, provider); err == nil {
resultChan <- proxies
return
}
}
// If no dohLoopkup worked, cancel right after it's done to not
// block refreshing for the whole cacheRefreshTimeout.
cancel()
}()
select {
case result := <-resultChan:
p.proxyCache = result
return nil
case <-ctx.Done():
return errors.New("timed out while refreshing proxy cache")
}
}
// canReach returns whether we can reach the given url.
func (p *proxyProvider) canReach(url string) bool {
log.WithField("url", url).Debug("Trying to ping proxy")
if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "http://") {
url = "https://" + url
}
dialer := NewPinningTLSDialer(p.cfg, NewBasicTLSDialer(p.cfg))
pinger := resty.New().
SetBaseURL(url).
SetTimeout(p.canReachTimeout).
SetTransport(CreateTransportWithDialer(dialer))
if _, err := pinger.R().Get("/tests/ping"); err != nil {
log.WithField("proxy", url).WithError(err).Warn("Failed to ping proxy")
return false
}
return true
}
// defaultDoHLookup is the default implementation of the proxy manager's DoH lookup.
// It looks up DNS TXT records for the given query URL using the given DoH provider.
// It returns a list of all found TXT records.
// If the whole process takes more than proxyDoHTimeout then an error is returned.
func (p *proxyProvider) defaultDoHLookup(ctx context.Context, query, dohProvider string) (data []string, err error) {
ctx, cancel := context.WithTimeout(ctx, p.dohTimeout)
defer cancel()
dataChan, errChan := make(chan []string), make(chan error)
go func() {
// Build new DNS request in RFC1035 format.
dnsRequest := new(dns.Msg).SetQuestion(dns.Fqdn(query), dns.TypeTXT)
// Pack the DNS request message into wire format.
rawRequest, err := dnsRequest.Pack()
if err != nil {
errChan <- errors.Wrap(err, "failed to pack DNS request")
return
}
// Encode wire-format DNS request message as base64url (RFC4648) without padding chars.
encodedRequest := base64.RawURLEncoding.EncodeToString(rawRequest)
// Make DoH request to the given DoH provider.
rawResponse, err := resty.New().R().SetContext(ctx).SetQueryParam("dns", encodedRequest).Get(dohProvider)
if err != nil {
errChan <- errors.Wrap(err, "failed to make DoH request")
return
}
// Unpack the DNS response.
dnsResponse := new(dns.Msg)
if err = dnsResponse.Unpack(rawResponse.Body()); err != nil {
errChan <- errors.Wrap(err, "failed to unpack DNS response")
return
}
// Pick out the TXT answers.
for _, answer := range dnsResponse.Answer {
if t, ok := answer.(*dns.TXT); ok {
data = append(data, t.Txt...)
}
}
dataChan <- data
}()
select {
case data = <-dataChan:
log.WithField("data", data).Info("Received TXT records")
return
case err = <-errChan:
log.WithField("provider", dohProvider).WithError(err).Error("Failed to query DNS records")
return
case <-ctx.Done():
log.WithField("provider", dohProvider).Error("Timed out querying DNS records")
return []string{}, errors.New("timed out querying DNS records")
}
}

View File

@ -1,192 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"net/http"
"testing"
"time"
r "github.com/stretchr/testify/require"
"golang.org/x/net/http/httpproxy"
)
func TestProxyProvider_FindProxy(t *testing.T) {
proxy := getTrustedServer()
defer closeServer(proxy)
p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used")
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy.URL}, nil }
url, err := p.findReachableServer()
r.NoError(t, err)
r.Equal(t, proxy.URL, url)
}
func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) {
reachableProxy := getTrustedServer()
defer closeServer(reachableProxy)
// We actually close the unreachable proxy straight away rather than deferring the closure.
unreachableProxy := getTrustedServer()
closeServer(unreachableProxy)
p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used")
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{reachableProxy.URL, unreachableProxy.URL}, nil
}
url, err := p.findReachableServer()
r.NoError(t, err)
r.Equal(t, reachableProxy.URL, url)
}
func TestProxyProvider_FindProxy_ChooseTrustedProxy(t *testing.T) {
trustedProxy := getTrustedServer()
defer closeServer(trustedProxy)
untrustedProxy := getUntrustedServer()
defer closeServer(untrustedProxy)
p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used")
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{untrustedProxy.URL, trustedProxy.URL}, nil
}
url, err := p.findReachableServer()
r.NoError(t, err)
r.Equal(t, trustedProxy.URL, url)
}
func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) {
unreachableProxy1 := getTrustedServer()
closeServer(unreachableProxy1)
unreachableProxy2 := getTrustedServer()
closeServer(unreachableProxy2)
p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used")
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{unreachableProxy1.URL, unreachableProxy2.URL}, nil
}
_, err := p.findReachableServer()
r.Error(t, err)
}
func TestProxyProvider_FindProxy_FailIfNoneTrusted(t *testing.T) {
untrustedProxy1 := getUntrustedServer()
defer closeServer(untrustedProxy1)
untrustedProxy2 := getUntrustedServer()
defer closeServer(untrustedProxy2)
p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used")
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) {
return []string{untrustedProxy1.URL, untrustedProxy2.URL}, nil
}
_, err := p.findReachableServer()
r.Error(t, err)
}
func TestProxyProvider_FindProxy_RefreshCacheTimeout(t *testing.T) {
p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used")
p.cacheRefreshTimeout = 1 * time.Second
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil }
// We should fail to refresh the proxy cache because the doh provider
// takes 2 seconds to respond but we timeout after just 1 second.
_, err := p.findReachableServer()
r.Error(t, err)
}
func TestProxyProvider_FindProxy_CanReachTimeout(t *testing.T) {
slowProxy := getTrustedServerWithHandler(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
time.Sleep(2 * time.Second)
}))
defer closeServer(slowProxy)
p := newProxyProvider(Config{HostURL: ""}, []string{"not used"}, "not used")
p.canReachTimeout = 1 * time.Second
p.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{slowProxy.URL}, nil }
// We should fail to reach the returned proxy because it takes 2 seconds
// to reach it and we only allow 1.
_, err := p.findReachableServer()
r.Error(t, err)
}
func TestProxyProvider_DoHLookup_Quad9(t *testing.T) {
p := newProxyProvider(Config{}, []string{Quad9Provider, GoogleProvider}, proxyQuery)
records, err := p.dohLookup(context.Background(), proxyQuery, Quad9Provider)
r.NoError(t, err)
r.NotEmpty(t, records)
}
// DISABLEDTestProxyProvider_DoHLookup_Quad9Port cannot run on CI due to custom
// port filter. Basic functionality should be covered by other tests. Keeping
// code here to be able to run it locally if needed.
func DISABLEDTestProxyProviderDoHLookupQuad9Port(t *testing.T) {
p := newProxyProvider(Config{}, []string{Quad9PortProvider, GoogleProvider}, proxyQuery)
records, err := p.dohLookup(context.Background(), proxyQuery, Quad9PortProvider)
r.NoError(t, err)
r.NotEmpty(t, records)
}
func TestProxyProvider_DoHLookup_Google(t *testing.T) {
p := newProxyProvider(Config{}, []string{Quad9Provider, GoogleProvider}, proxyQuery)
records, err := p.dohLookup(context.Background(), proxyQuery, GoogleProvider)
r.NoError(t, err)
r.NotEmpty(t, records)
}
func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) {
skipIfProxyIsSet(t)
p := newProxyProvider(Config{}, []string{Quad9Provider, GoogleProvider}, proxyQuery)
url, err := p.findReachableServer()
r.NoError(t, err)
r.NotEmpty(t, url)
}
func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) {
skipIfProxyIsSet(t)
p := newProxyProvider(Config{}, []string{"https://unreachable", Quad9Provider, GoogleProvider}, proxyQuery)
url, err := p.findReachableServer()
r.NoError(t, err)
r.NotEmpty(t, url)
}
// skipIfProxyIsSet skips the tests if HTTPS proxy is set.
// Should be used for tests depending on proper certificate checks which
// is not possible under our CI setup.
func skipIfProxyIsSet(t *testing.T) {
if httpproxy.FromEnvironment().HTTPSProxy != "" {
t.SkipNow()
}
}

View File

@ -1,268 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// getTrustedServer returns a server and sets its public key as one of the pinned ones.
func getTrustedServer() *httptest.Server {
return getTrustedServerWithHandler(
http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
// Do nothing.
}),
)
}
func getTrustedServerWithHandler(handler http.HandlerFunc) *httptest.Server {
proxy := httptest.NewTLSServer(handler)
pin := certFingerprint(proxy.Certificate())
TrustedAPIPins = append(TrustedAPIPins, pin)
return proxy
}
const servercrt = `
-----BEGIN CERTIFICATE-----
MIIE5TCCA82gAwIBAgIJAKsmhcMFGfGcMA0GCSqGSIb3DQEBCwUAMIGsMQswCQYD
VQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzARBgNVBAcMClJhbmRvbUNp
dHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEfMB0GA1UECwwWUmFuZG9t
T3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYRaGVsbG9AZXhhbXBsZS5j
b20xEjAQBgNVBAMMCTEyNy4wLjAuMTAeFw0yMDA0MjQxMzI3MzdaFw0yMTA5MDYx
MzI3MzdaMIGsMQswCQYDVQQGEwJVUzEUMBIGA1UECAwLUmFuZG9tU3RhdGUxEzAR
BgNVBAcMClJhbmRvbUNpdHkxGzAZBgNVBAoMElJhbmRvbU9yZ2FuaXphdGlvbjEf
MB0GA1UECwwWUmFuZG9tT3JnYW5pemF0aW9uVW5pdDEgMB4GCSqGSIb3DQEJARYR
aGVsbG9AZXhhbXBsZS5jb20xEjAQBgNVBAMMCTEyNy4wLjAuMTCCASIwDQYJKoZI
hvcNAQEBBQADggEPADCCAQoCggEBANAnYyqhosWwNzGjBwSwmDUINOaPs4TSTgKt
r6CE01atxAWzWUCyYqnQ4fPe5q2tx5t/VrmnTNpzycammKJszGLlmj9DFxSiYVw2
pTTK3DBWFkfTwxq98mM7wMnCWy1T2L2pmuYjnd7Pa6pQa9OHYoJwRzlIl2Q3YVdM
GIBDbkW728A1dcelkIdFpv3r3ayTZv01vU8JMXd4PLHwXU0x0hHlH52+kx+9Ndru
rdqqV6LqVfNlSR1jFZkwLBBqvh3XrJRD9Q01EAX6m+ufZ0yq8mK9ifMRtwQet10c
kKMnx63MwvxDFmqrBj4HMtIRUpK+LBDs1ke7DvS0eLqaojWl28ECAwEAAaOCAQYw
ggECMIHLBgNVHSMEgcMwgcChgbKkga8wgawxCzAJBgNVBAYTAlVTMRQwEgYDVQQI
DAtSYW5kb21TdGF0ZTETMBEGA1UEBwwKUmFuZG9tQ2l0eTEbMBkGA1UECgwSUmFu
ZG9tT3JnYW5pemF0aW9uMR8wHQYDVQQLDBZSYW5kb21Pcmdhbml6YXRpb25Vbml0
MSAwHgYJKoZIhvcNAQkBFhFoZWxsb0BleGFtcGxlLmNvbTESMBAGA1UEAwwJMTI3
LjAuMC4xggkAvCxbs152YckwCQYDVR0TBAIwADALBgNVHQ8EBAMCBPAwGgYDVR0R
BBMwEYIJMTI3LjAuMC4xhwR/AAABMA0GCSqGSIb3DQEBCwUAA4IBAQAC7ZycZMZ5
L+cjIpwSj0cemLkVD+kcFUCkI7ket5gbX1PmavmnpuFl9Sru0eJ5wyJ+97MQElPA
CNFgXoX7DbJWkcd/LSksvZoJnpc1sTqFKMWFmOUxmUD62lCacuhqE27ZTThQ/53P
3doLa74rKzUqlPI8OL4R34FY2deL7t5l2KSnpf7CKNeF5bkinAsn6NBqyZs2KPmg
yT1/POdlRewzGSqBTMdktNQ4vKSfdFjcfVeo8PSHBgbGXZ5KoHZ6R6DNJehEh27l
z3OteROLGoii+w3OllLq6JATif2MDIbH0s/KjGjbXSSGbM/rZu5eBZm5/vksGAzc
u53wgIhCJGuX
-----END CERTIFICATE-----
`
const serverkey = `
-----BEGIN PRIVATE KEY-----
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDQJ2MqoaLFsDcx
owcEsJg1CDTmj7OE0k4Cra+ghNNWrcQFs1lAsmKp0OHz3uatrcebf1a5p0zac8nG
ppiibMxi5Zo/QxcUomFcNqU0ytwwVhZH08MavfJjO8DJwlstU9i9qZrmI53ez2uq
UGvTh2KCcEc5SJdkN2FXTBiAQ25Fu9vANXXHpZCHRab9692sk2b9Nb1PCTF3eDyx
8F1NMdIR5R+dvpMfvTXa7q3aqlei6lXzZUkdYxWZMCwQar4d16yUQ/UNNRAF+pvr
n2dMqvJivYnzEbcEHrddHJCjJ8etzML8QxZqqwY+BzLSEVKSviwQ7NZHuw70tHi6
mqI1pdvBAgMBAAECggEAOqqPOYm63arPs462QK0hCPlaJ41i1FGNqRWYxU4KXoi1
EcI9qo1cX24+8MPnEhZDhuD56XNsprkxqmpz5Htzk4AQ3DmlfKxTcnD4WQu/yWPJ
/c6CU7wrX6qMqJC9r+XM1Y/C15A8Q3sEZkkqSsECk67fdBawjI9LQRZyZVwb7U0F
qtvbKM7VQA6hrgdSmXWJ+spp5yymVFF22Ssz31SSbCI93bnp3mukRCKWdRmA9pmT
VXa0HzJ5p70WC+Se9nA/1riWGKt4HCmjVeEtZuiwaUTlXDSeYpu2e4QrX1OnUXBu
Z7yfviTqA8o7KfiA6urumFbAMJcibxkWJoWacc5tTQKBgQD39ZdtNz8B6XJy7f5h
bo9Ag9OrkVX+HITQyWKpcCDba9SuIX3/F++2AK4oeJ3aHKMJWiP19hQvGS1xE67X
TKejOsQxORn6nAYQpFd3AOBOtKAC+VQITBqlfq2ukGmvcQ1O31hMOFbZagFA5cpU
LYb9VVDsZzhM7CccIn/EGEZjgwKBgQDW51rUA2S9naV/iEGhw1tuhoQ5OADD/n8f
pPIkbGxmACDaX/7jt+UwlDU0EsI+aBlJUDqGiEZ5z3UPmaSJUdfRCeJEdKIe1GLm
nqF3sF6Aq+S/79v/wKYn+MHcoiWog5n3McLzZ3+0rwrhMREjE2eWPwVHz/jJIFP3
Pp3+UZVsawKBgB4Az5PdjXgzwS968L7lW9wYl3I5Iciftsp0s8WA1dj3EUMItnA5
ez3wkyI+hgswT+H/0D4gyoxwZXk7Qnq2wcoUgEzcdfJHEszMtfCmYH3liT8S4EIo
w0inLWjj/IXIDi4vBEYkww2HsCMkKvlIkP7yZdpVGxDjuk/DNOaLcWj1AoGAXuyK
PiPRl7/Onmp9MwqrlEJunSeTjv8W/89H9ba+mr9rw4mreMJ9xdtxNLMkgZRRtwRt
FYeUObHdLyradp1kCr2m6D3sblm55cwj3k5VL9i9jdpQ/sMFoZpLZz1oDOs0Uu/0
ALeyvQikcZvOygOEOeVUW8gNSCmzbP6HoxI+QkkCgYBCI6oL4GPcPPqzd+2djbOD
z3rVUyHzYc1KUcBixK/uaRQKM886k4CL8/GvbHHI/yoZ7xWJGnBi59DtpqnGTZJ2
FDJwYIlQKhZmsyVcZu/4smsaejGnHn/liksVlgesSwCtOrsd2AC8fBXSyrTWJx8o
vwRMog6lPhlRhHh/FZ43Cg==
-----END PRIVATE KEY-----
`
// getUntrustedServer returns a server but it doesn't add its public key to the list of pinned ones.
func getUntrustedServer() *httptest.Server {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
cert, err := tls.X509KeyPair([]byte(servercrt), []byte(serverkey))
if err != nil {
panic(err)
}
server.TLS = &tls.Config{Certificates: []tls.Certificate{cert}}
server.StartTLS()
return server
}
// closeServer closes the given server. If it is a trusted server, its cert is removed from the trusted public keys.
func closeServer(server *httptest.Server) {
pin := certFingerprint(server.Certificate())
for i := range TrustedAPIPins {
if TrustedAPIPins[i] == pin {
TrustedAPIPins = append(TrustedAPIPins[:i], TrustedAPIPins[i:]...)
break
}
}
server.Close()
}
func TestProxyDialer_UseProxy(t *testing.T) {
trustedProxy := getTrustedServer()
defer closeServer(trustedProxy)
cfg := Config{HostURL: ""}
d := NewProxyTLSDialer(cfg, NewBasicTLSDialer(cfg))
d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
}
func TestProxyDialer_UseProxy_MultipleTimes(t *testing.T) {
proxy1 := getTrustedServer()
defer closeServer(proxy1)
proxy2 := getTrustedServer()
defer closeServer(proxy2)
proxy3 := getTrustedServer()
defer closeServer(proxy3)
cfg := Config{HostURL: ""}
d := NewProxyTLSDialer(cfg, NewBasicTLSDialer(cfg))
d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(proxy1.URL), d.proxyAddress)
// Have to wait so as to not get rejected.
time.Sleep(proxyLookupWait)
d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy2.URL}, nil }
err = d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(proxy2.URL), d.proxyAddress)
// Have to wait so as to not get rejected.
time.Sleep(proxyLookupWait)
d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy3.URL}, nil }
err = d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(proxy3.URL), d.proxyAddress)
}
func TestProxyDialer_UseProxy_RevertAfterTime(t *testing.T) {
trustedProxy := getTrustedServer()
defer closeServer(trustedProxy)
cfg := Config{HostURL: ""}
d := NewProxyTLSDialer(cfg, NewBasicTLSDialer(cfg))
d.proxyUseDuration = time.Second
d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
time.Sleep(2 * time.Second)
require.Equal(t, ":443", d.proxyAddress)
}
func TestProxyDialer_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachable(t *testing.T) {
trustedProxy := getTrustedServer()
cfg := Config{HostURL: ""}
d := NewProxyTLSDialer(cfg, NewBasicTLSDialer(cfg))
d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{trustedProxy.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(trustedProxy.URL), d.proxyAddress)
// Simulate that the proxy stops working and that the standard api is reachable again.
closeServer(trustedProxy)
d.directAddress = formatAsAddress(getRootURL())
d.proxyProvider.cfg.HostURL = getRootURL()
time.Sleep(proxyLookupWait)
// We should now find the original API URL if it is working again.
// The error should be ErrAPINotReachable because the connection dropped intermittently but
// the original API is now reachable (see Alternative-Routing-v2 spec for details).
err = d.switchToReachableServer()
require.Error(t, err)
require.Equal(t, formatAsAddress(getRootURL()), d.proxyAddress)
}
func TestProxyDialer_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBlocked(t *testing.T) {
// proxy1 is closed later in this test so we don't defer it here.
proxy1 := getTrustedServer()
proxy2 := getTrustedServer()
defer closeServer(proxy2)
cfg := Config{HostURL: ""}
d := NewProxyTLSDialer(cfg, NewBasicTLSDialer(cfg))
d.proxyProvider.dohLookup = func(ctx context.Context, q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil }
err := d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(proxy1.URL), d.proxyAddress)
// Have to wait so as to not get rejected.
time.Sleep(proxyLookupWait)
// The proxy stops working and the protonmail API is still blocked.
closeServer(proxy1)
// Should switch to the second proxy because both the first proxy and the protonmail API are blocked.
err = d.switchToReachableServer()
require.NoError(t, err)
require.Equal(t, formatAsAddress(proxy2.URL), d.proxyAddress)
}
func TestFormatAsAddress(t *testing.T) {
r := require.New(t)
testData := map[string]string{
"sub.domain.tld": "sub.domain.tld:443",
"http://sub.domain.tld": "sub.domain.tld:80",
"https://sub.domain.tld": "sub.domain.tld:443",
"ftp://sub.domain.tld": "sub.domain.tld:443",
"//sub.domain.tld": "sub.domain.tld:443",
}
for rawURL, wantURL := range testData {
r.Equal(wantURL, formatAsAddress(rawURL))
}
}

View File

@ -1,88 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import "errors"
var (
ErrNoConnection = errors.New("no internet connection")
ErrUnauthorized = errors.New("API client is unauthorized")
ErrUpgradeApplication = errors.New("application upgrade required")
ErrBad2FACode = errors.New("incorrect 2FA code")
ErrBad2FACodeTryAgain = errors.New("incorrect 2FA code: please try again")
ErrPaidPlanRequired = errors.New("paid subscription plan is required")
ErrPasswordWrong = errors.New("wrong password")
)
// ErrUnprocessableEntity ...
type ErrUnprocessableEntity struct {
OriginalError error
}
func IsUnprocessableEntity(err error) bool {
_, ok := err.(ErrUnprocessableEntity)
return ok
}
func (err ErrUnprocessableEntity) Error() string {
return err.OriginalError.Error()
}
// ErrBadRequest ...
type ErrBadRequest struct {
OriginalError error
}
func IsBadRequest(err error) bool {
_, ok := err.(ErrBadRequest)
return ok
}
func (err ErrBadRequest) Error() string {
return err.OriginalError.Error()
}
// ErrAuthFailed ...
type ErrAuthFailed struct {
OriginalError error
}
func IsFailedAuth(err error) bool {
_, ok := err.(ErrAuthFailed)
return ok
}
func (err ErrAuthFailed) Error() string {
return err.OriginalError.Error()
}
// ErrUnlockFailed ...
type ErrUnlockFailed struct {
OriginalError error
}
func IsFailedUnlock(err error) bool {
_, ok := err.(ErrUnlockFailed)
return ok
}
func (err ErrUnlockFailed) Error() string {
return err.OriginalError.Error()
}

View File

@ -1,247 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"encoding/json"
"net/mail"
"github.com/go-resty/resty/v2"
)
// Event represents changes since the last check.
type Event struct {
// The current event ID.
EventID string
// If set to one, all cached data must be fetched again.
Refresh int
// If set to one, fetch more events.
More Boolean
// Changes applied to messages.
Messages []*EventMessage
// Counts of messages per labels.
MessageCounts []*MessagesCount
// Changes applied to labels.
Labels []*EventLabel
// Current user status.
User *User
// Changes to addresses.
Addresses []*EventAddress
// Messages to show to the user.
Notices []string
// Update of used user space
UsedSpace *int64
}
// EventAction is the action that created a change.
type EventAction int
const (
EventDelete EventAction = iota // EventDelete Item has been deleted.
EventCreate // EventCreate Item has been created.
EventUpdate // EventUpdate Item has been updated.
EventUpdateFlags // EventUpdateFlags For messages: flags have been updated.
)
// Flags for event refresh.
const (
EventRefreshMail = 1
EventRefreshContact = 2
EventRefreshAll = 255
)
// maxNumberOfMergedEvents limits how many events are merged into one. It means
// when GetEvent is called and event returns there is more events, it will
// automatically fetch next one and merge it up to this number of events.
const maxNumberOfMergedEvents = 50
// EventItem is an item that has changed.
type EventItem struct {
ID string
Action EventAction
}
// EventMessage is a message that has changed.
type EventMessage struct {
EventItem
// If the message has been created, the new message.
Created *Message `json:"-"`
// If the message has been updated, the updated fields.
Updated *EventMessageUpdated `json:"-"`
}
// eventMessage defines a new type to prevent MarshalJSON/UnmarshalJSON infinite loops.
type eventMessage EventMessage
type rawEventMessage struct {
eventMessage
// This will be parsed depending on the action.
Message json.RawMessage `json:",omitempty"`
}
func (em *EventMessage) UnmarshalJSON(b []byte) (err error) {
var raw rawEventMessage
if err := json.Unmarshal(b, &raw); err != nil {
return err
}
*em = EventMessage(raw.eventMessage)
switch em.Action {
case EventCreate:
em.Created = &Message{ID: raw.ID}
return json.Unmarshal(raw.Message, em.Created)
case EventUpdate, EventUpdateFlags:
em.Updated = &EventMessageUpdated{ID: raw.ID}
return json.Unmarshal(raw.Message, em.Updated)
case EventDelete:
return nil
}
return nil
}
func (em *EventMessage) MarshalJSON() ([]byte, error) {
var raw rawEventMessage
raw.eventMessage = eventMessage(*em)
var err error
switch em.Action {
case EventCreate:
raw.Message, err = json.Marshal(em.Created)
case EventUpdate, EventUpdateFlags:
raw.Message, err = json.Marshal(em.Updated)
case EventDelete:
}
if err != nil {
return nil, err
}
return json.Marshal(raw)
}
// EventMessageUpdated contains changed fields for an updated message.
type EventMessageUpdated struct {
ID string
Subject *string
Unread *Boolean
Flags *int64
Sender *mail.Address
ToList *[]*mail.Address
CCList *[]*mail.Address
BCCList *[]*mail.Address
Time int64
// Fields only present for EventUpdateFlags.
LabelIDs []string
LabelIDsAdded []string
LabelIDsRemoved []string
}
// EventLabel is a label that has changed.
type EventLabel struct {
EventItem
Label *Label
}
// EventAddress is an address that has changed.
type EventAddress struct {
EventItem
Address *Address
}
// GetEvent returns a summary of events that occurred since last. To get the latest event,
// provide an empty last value. The latest event is always empty.
func (c *client) GetEvent(ctx context.Context, eventID string) (*Event, error) {
return c.getEvent(ctx, eventID, 1)
}
func (c *client) getEvent(ctx context.Context, eventID string, numberOfMergedEvents int) (*Event, error) {
if eventID == "" {
eventID = "latest"
}
var event *Event
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&event).Get("/events/" + eventID)
}); err != nil {
return nil, err
}
// API notifies about used space two ways:
// - by `event.User.UsedSpace`
// - by `event.UsedSpace`
//
// Because event merging is implemented for User object we copy the
// value from event.UsedSpace to event.User.UsedSpace and continue with
// user.
if event.UsedSpace != nil {
if event.User == nil {
event.User = &User{UsedSpace: event.UsedSpace}
} else {
event.User.UsedSpace = event.UsedSpace
}
}
if event.More && numberOfMergedEvents < maxNumberOfMergedEvents {
nextEvent, err := c.getEvent(ctx, event.EventID, numberOfMergedEvents+1)
if err != nil {
return nil, err
}
event = mergeEvents(event, nextEvent)
}
return event, nil
}
// mergeEvents combines an old events and a new events object.
// This is not as simple as just blindly joining the two because some things should only be taken from the new events.
func mergeEvents(eventsOld *Event, eventsNew *Event) (mergedEvents *Event) {
return &Event{
EventID: eventsNew.EventID,
Refresh: eventsOld.Refresh | eventsNew.Refresh,
More: eventsNew.More,
Messages: append(eventsOld.Messages, eventsNew.Messages...),
MessageCounts: append(eventsOld.MessageCounts, eventsNew.MessageCounts...),
Labels: append(eventsOld.Labels, eventsNew.Labels...),
User: mergeUserEvents(eventsOld.User, eventsNew.User),
Addresses: append(eventsOld.Addresses, eventsNew.Addresses...),
Notices: append(eventsOld.Notices, eventsNew.Notices...),
}
}
func mergeUserEvents(userOld, userNew *User) *User {
if userNew == nil {
return userOld
}
if userOld != nil {
if userNew.MaxSpace == nil {
userNew.MaxSpace = userOld.MaxSpace
}
if userNew.UsedSpace == nil {
userNew.UsedSpace = userOld.UsedSpace
}
}
return userNew
}

View File

@ -1,538 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"fmt"
"net/http"
"net/mail"
"regexp"
"strconv"
"strings"
"testing"
r "github.com/stretchr/testify/require"
)
func TestClient_GetEvent(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "GET", "/events/latest"))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testEventBody)
}))
defer s.Close()
event, err := c.GetEvent(context.Background(), "")
r.NoError(t, err)
r.Equal(t, testEvent, event)
}
func TestClient_GetEvent_withID(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "GET", "/events/"+testEvent.EventID))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testEventBody)
}))
defer s.Close()
event, err := c.GetEvent(context.Background(), testEvent.EventID)
r.NoError(t, err)
r.Equal(t, testEvent, event)
}
// We first call GetEvent with id of eventID1, which returns More=1 so we fetch with id eventID2.
func TestClient_GetEvent_mergeEvents(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch req.URL.RequestURI() {
case "/events/eventID1":
r.NoError(t, checkMethodAndPath(req, "GET", "/events/eventID1"))
fmt.Fprint(w, testEventBodyMore1)
case "/events/eventID2":
r.NoError(t, checkMethodAndPath(req, "GET", "/events/eventID2"))
fmt.Fprint(w, testEventBodyMore2)
default:
t.Fail()
}
}))
defer s.Close()
event, err := c.GetEvent(context.Background(), "eventID1")
r.NoError(t, err)
r.Equal(t, testEventMerged, event)
}
func TestClient_GetEvent_mergeMaxNumberOfEvents(t *testing.T) {
numberOfCalls := 0
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
numberOfCalls++
re := regexp.MustCompile(`/eventID([0-9]+)`)
eventIDString := re.FindStringSubmatch(req.URL.RequestURI())[1]
eventID, err := strconv.Atoi(eventIDString)
r.NoError(t, err)
if numberOfCalls > maxNumberOfMergedEvents*2 {
r.Fail(t, "Too many calls!")
}
body := strings.ReplaceAll(testEventBodyMore1, "eventID2", "eventID"+strconv.Itoa(eventID+1))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, body)
}))
defer s.Close()
event, err := c.GetEvent(context.Background(), "eventID1")
r.NoError(t, err)
r.Equal(t, maxNumberOfMergedEvents, numberOfCalls)
r.True(t, bool(event.More))
}
var (
testEventMessageUpdateUnread = Boolean(false)
testEvent = &Event{
EventID: "eventID1",
Refresh: 0,
Messages: []*EventMessage{
{
EventItem: EventItem{ID: "hdI7aIgUO1hFplCIcJHB0jShRVsAzS0AB75wGCaiNVeIHXLmaUnt4eJ8l7c7L6uk4g0ZdXhGWG5gfh6HHgAZnw==", Action: EventCreate},
Created: &Message{
ID: "hdI7aIgUO1hFplCIcJHB0jShRVsAzS0AB75wGCaiNVeIHXLmaUnt4eJ8l7c7L6uk4g0ZdXhGWG5gfh6HHgAZnw==",
Header: make(mail.Header),
Subject: "Hey there",
},
},
{
EventItem: EventItem{ID: "bSFLAimPSfGz2Kj0aV3l3AyXsof_Vf7sfrrMJ8ifgGJe-f2NG2eLaEGXLytjMhq9wnLMtkoZpO2uBXM4nOVa5g==", Action: EventUpdateFlags},
Updated: &EventMessageUpdated{
ID: "bSFLAimPSfGz2Kj0aV3l3AyXsof_Vf7sfrrMJ8ifgGJe-f2NG2eLaEGXLytjMhq9wnLMtkoZpO2uBXM4nOVa5g==",
Unread: &testEventMessageUpdateUnread,
Time: 1472391377,
LabelIDsAdded: []string{ArchiveLabel},
LabelIDsRemoved: []string{InboxLabel},
},
},
{
EventItem: EventItem{ID: "XRBMBYnSkaEJWtqFACp2kjlNc-7GjzX3SnPcOtWK4PyLG11Nhsg0uxPYjTXoClQfB-EHVDl9gE3w2PVuj93jBg==", Action: EventDelete},
},
},
MessageCounts: []*MessagesCount{
{
LabelID: "0",
Total: 19,
Unread: 2,
},
{
LabelID: "6",
Total: 1,
Unread: 0,
},
},
Notices: []string{"Server will be down in 2min because of a NSA attack"},
}
testEventMerged = &Event{
EventID: "eventID3",
Refresh: 1,
Messages: []*EventMessage{
{
EventItem: EventItem{ID: "msgID1", Action: EventCreate},
Created: &Message{
ID: "id",
Header: make(mail.Header),
Subject: "Hey there",
},
},
{
EventItem: EventItem{ID: "msgID2", Action: EventCreate},
Created: &Message{
ID: "id",
Header: make(mail.Header),
Subject: "Hey there again",
},
},
},
MessageCounts: []*MessagesCount{
{
LabelID: "label1",
Total: 19,
Unread: 2,
},
{
LabelID: "label2",
Total: 1,
Unread: 0,
},
{
LabelID: "label2",
Total: 2,
Unread: 1,
},
{
LabelID: "label3",
Total: 1,
Unread: 0,
},
},
Notices: []string{"Server will be down in 2min because of a NSA attack", "Just kidding lol"},
Labels: []*EventLabel{
{
EventItem: EventItem{
ID: "labelID1",
Action: 1,
},
Label: &Label{
ID: "id",
Name: "Event Label 1",
},
},
{
EventItem: EventItem{
ID: "labelID2",
Action: 1,
},
Label: &Label{
ID: "id",
Name: "Event Label 2",
},
},
},
User: &User{
ID: "userID1",
Name: "user",
UsedSpace: &usedSpace,
MaxSpace: &maxSpace,
},
Addresses: []*EventAddress{
{
EventItem: EventItem{
ID: "addressID1",
Action: 2,
},
Address: &Address{
ID: "id",
DisplayName: "address 1",
},
},
{
EventItem: EventItem{
ID: "addressID2",
Action: 2,
},
Address: &Address{
ID: "id",
DisplayName: "address 2",
},
},
},
}
)
const (
testEventBody = `{
"EventID": "eventID1",
"Refresh": 0,
"Messages": [
{
"ID": "hdI7aIgUO1hFplCIcJHB0jShRVsAzS0AB75wGCaiNVeIHXLmaUnt4eJ8l7c7L6uk4g0ZdXhGWG5gfh6HHgAZnw==",
"Action": 1,
"Message": {
"ID": "hdI7aIgUO1hFplCIcJHB0jShRVsAzS0AB75wGCaiNVeIHXLmaUnt4eJ8l7c7L6uk4g0ZdXhGWG5gfh6HHgAZnw==",
"Subject": "Hey there"
}
},
{
"ID": "bSFLAimPSfGz2Kj0aV3l3AyXsof_Vf7sfrrMJ8ifgGJe-f2NG2eLaEGXLytjMhq9wnLMtkoZpO2uBXM4nOVa5g==",
"Action": 3,
"Message": {
"ConversationID": "2oX3EILYRuZ0IRBVlzMg1oV5eazQL67sFIHlcR8bjickPn7K4id4sJZuAB6n0pdtI3hRIVsjCpgWfRm8c_x3IQ==",
"Unread": 0,
"Time": 1472391377,
"Location": 6,
"LabelIDsAdded": [
"6"
],
"LabelIDsRemoved": [
"0"
]
}
},
{
"ID": "XRBMBYnSkaEJWtqFACp2kjlNc-7GjzX3SnPcOtWK4PyLG11Nhsg0uxPYjTXoClQfB-EHVDl9gE3w2PVuj93jBg==",
"Action": 0
}
],
"Conversations": [
{
"ID": "2oX3EILYRuZ0IRBVlzMg1oV5eazQL67sFIHlcR8bjickPn7K4id4sJZuAB6n0pdtI3hRIVsjCpgWfRm8c_x3IQ==",
"Action": 1,
"Conversation": {
"ID": "2oX3EILYRuZ0IRBVlzMg1oV5eazQL67sFIHlcR8bjickPn7K4id4sJZuAB6n0pdtI3hRIVsjCpgWfRm8c_x3IQ==",
"Order": 1616,
"Subject": "Hey there",
"Senders": [
{
"Address": "apple@protonmail.com",
"Name": "apple@protonmail.com"
}
],
"Recipients": [
{
"Address": "apple@protonmail.com",
"Name": "apple@protonmail.com"
}
],
"NumMessages": 1,
"NumUnread": 1,
"NumAttachments": 0,
"ExpirationTime": 0,
"TotalSize": 636,
"AddressID": "QMJs2dzTx7uqpH5PNgIzjULywU4gO9uMBhEMVFOAVJOoUml54gC0CCHtW9qYwzH-zYbZwMv3MFYncPjW1Usq7Q==",
"LabelIDs": [
"0"
],
"Labels": [
{
"Count": 1,
"NumMessages": 1,
"NumUnread": 1,
"ID": "0"
}
]
}
}
],
"Total": {
"Locations": [
{
"Location": 0,
"Count": 19
},
{
"Location": 1,
"Count": 16
},
{
"Location": 2,
"Count": 16
},
{
"Location": 3,
"Count": 17
},
{
"Location": 6,
"Count": 1
}
],
"Labels": [
{
"LabelID": "LLz8ysmVxwr4dF6mWpClePT0SpSWOEvzTdq17RydSl4ndMckvY1K63HeXDzn03BJQwKYvgf-eWT8Qfd9WVuIEQ==",
"Count": 2
},
{
"LabelID": "BvbqbySUPo9uWW_eR8tLA13NUsQMz3P4Zhw4UnpvrKqURnrHlE6L2Au0nplHfHlVXFgGz4L4hJ9-BYllOL-L5g==",
"Count": 2
}
],
"Starred": 3
},
"Unread": {
"Locations": [
{
"Location": 0,
"Count": 2
},
{
"Location": 1,
"Count": 0
},
{
"Location": 2,
"Count": 0
},
{
"Location": 3,
"Count": 0
},
{
"Location": 6,
"Count": 0
}
],
"Labels": [
{
"LabelID": "LLz8ysmVxwr4dF6mWpClePT0SpSWOEvzTdq17RydSl4ndMckvY1K63HeXDzn03BJQwKYvgf-eWT8Qfd9WVuIEQ==",
"Count": 0
},
{
"LabelID": "BvbqbySUPo9uWW_eR8tLA13NUsQMz3P4Zhw4UnpvrKqURnrHlE6L2Au0nplHfHlVXFgGz4L4hJ9-BYllOL-L5g==",
"Count": 0
}
],
"Starred": 0
},
"MessageCounts": [
{
"LabelID": "0",
"Total": 19,
"Unread": 2
},
{
"LabelID": "6",
"Total": 1,
"Unread": 0
}
],
"ConversationCounts": [
{
"LabelID": "0",
"Total": 19,
"Unread": 2
},
{
"LabelID": "6",
"Total": 1,
"Unread": 0
}
],
"Notices": ["Server will be down in 2min because of a NSA attack"],
"Code": 1000
}
`
testEventBodyMore1 = `{
"EventID": "eventID2",
"More": 1,
"Refresh": 1,
"Messages": [
{
"ID": "msgID1",
"Action": 1,
"Message": {
"ID": "id",
"Subject": "Hey there"
}
}
],
"MessageCounts": [
{
"LabelID": "label1",
"Total": 19,
"Unread": 2
},
{
"LabelID": "label2",
"Total": 1,
"Unread": 0
}
],
"Labels": [
{
"ID":"labelID1",
"Action":1,
"Label":{
"ID":"id",
"Name":"Event Label 1"
}
}
],
"User": {
"ID": "userID1",
"Name": "user",
"UsedSpace": 444,
"MaxSpace": 12345678
},
"Addresses": [
{
"ID": "addressID1",
"Action": 2,
"Address": {
"ID": "id",
"DisplayName": "address 1"
}
}
],
"UsedSpace": 12345,
"Notices": ["Server will be down in 2min because of a NSA attack"]
}
`
testEventBodyMore2 = `{
"EventID": "eventID3",
"Refresh": 0,
"Messages": [
{
"ID": "msgID2",
"Action": 1,
"Message": {
"ID": "id",
"Subject": "Hey there again"
}
}
],
"MessageCounts": [
{
"LabelID": "label2",
"Total": 2,
"Unread": 1
},
{
"LabelID": "label3",
"Total": 1,
"Unread": 0
}
],
"Labels": [
{
"ID":"labelID2",
"Action":1,
"Label":{
"ID":"id",
"Name":"Event Label 2"
}
}
],
"User": {
"ID": "userID1",
"Name": "user",
"UsedSpace": 23456
},
"Addresses": [
{
"ID": "addressID2",
"Action": 2,
"Address": {
"ID": "id",
"DisplayName": "address 2"
}
}
],
"Notices": ["Just kidding lol"]
}
`
)

View File

@ -1,153 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"bytes"
"context"
"encoding/json"
"errors"
"strconv"
"github.com/go-resty/resty/v2"
)
const (
MaxImportMessageRequestLength = 10
MaxImportMessageRequestSize = 25 * 1024 * 1024 // MaxImportMessageRequestSize 25 MB total limit
)
type ImportMsgReq struct {
Metadata *ImportMetadata // Metadata about the message to import.
Message []byte // The raw RFC822 message.
}
type ImportMsgReqs []*ImportMsgReq
func (reqs ImportMsgReqs) buildMultipartFormData() ([]*resty.MultipartField, error) {
metadata := make(map[string]*ImportMetadata, len(reqs))
fields := make([]*resty.MultipartField, 0, len(reqs))
for i, req := range reqs {
name := strconv.Itoa(i)
metadata[name] = req.Metadata
fields = append(fields, &resty.MultipartField{
Param: name,
FileName: name + ".eml",
ContentType: "message/rfc822",
Reader: bytes.NewReader(req.Message),
})
}
b, err := json.Marshal(metadata)
if err != nil {
return nil, err
}
fields = append(fields, &resty.MultipartField{
Param: "Metadata",
ContentType: "application/json",
Reader: bytes.NewReader(b),
})
return fields, nil
}
type ImportMetadata struct {
AddressID string
Unread Boolean // 0: read, 1: unread.
IsReplied Boolean // 1 if the message has been replied.
IsRepliedAll Boolean // 1 if the message has been replied to all.
IsForwarded Boolean // 1 if the message has been forwarded.
Time int64 // The time when the message was received as a Unix time.
Flags int64 // The type of the imported message.
LabelIDs []string // The labels to apply to the imported message. Must contain at least one system label.
}
type ImportMsgRes struct {
// The error encountered while importing the message, if any.
Error error
// The newly created message ID.
MessageID string
}
// Import imports messages to the user's account.
func (c *client) Import(ctx context.Context, reqs ImportMsgReqs) ([]*ImportMsgRes, error) {
if len(reqs) == 0 {
return nil, errors.New("missing import requests")
}
if len(reqs) > MaxImportMessageRequestLength {
log.
WithField("count", len(reqs)).
Warn("Importing too many messages at once.")
return nil, errors.New("request is too long")
}
remainingSize := MaxImportMessageRequestSize
for _, req := range reqs {
remainingSize -= len(req.Message)
if remainingSize < 0 {
log.
WithField("count", len(reqs)).
WithField("size", MaxImportMessageRequestLength-remainingSize).
Warn("Importing too big message(s)")
return nil, errors.New("request size is too big")
}
}
fields, err := reqs.buildMultipartFormData()
if err != nil {
return nil, err
}
var res struct {
Responses []struct {
Name string
Response struct {
Error
MessageID string
}
}
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetMultipartFields(fields...).SetResult(&res).Post("/mail/v4/messages/import")
}); err != nil {
return nil, err
}
resps := make([]*ImportMsgRes, 0, len(res.Responses))
for _, resp := range res.Responses {
var err error
if resp.Response.Code != 1000 {
err = resp.Response.Error
}
resps = append(resps, &ImportMsgRes{
Error: err,
MessageID: resp.Response.MessageID,
})
}
return resps, nil
}

View File

@ -1,146 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"encoding/json"
"fmt"
"io"
"math/rand"
"mime/multipart"
"net/http"
"testing"
pmmime "github.com/ProtonMail/proton-bridge/v2/pkg/mime"
r "github.com/stretchr/testify/require"
)
var testImportReqs = []*ImportMsgReq{
{
Metadata: &ImportMetadata{
AddressID: "QMJs2dzTx7uqpH5PNgIzjULywU4gO9uMBhEMVFOAVJOoUml54gC0CCHtW9qYwzH-zYbZwMv3MFYncPjW1Usq7Q==",
Unread: Boolean(false),
Flags: FlagReceived | FlagImported,
LabelIDs: []string{ArchiveLabel},
},
Message: []byte("Hello World!"),
},
}
const testImportBody = `{
"Code": 1001,
"Responses": [{
"Name": "0",
"Response": {"Code": 1000, "MessageID": "UKjSNz95KubYjrYmfbv1mbIfGxzY6D64mmHmVpWhkeEau-u0PIS4ru5IFMHgX6WjKpWYKCht3oiOtL5-wZChNg=="}
}]
}`
var testImportRes = &ImportMsgRes{
Error: nil,
MessageID: "UKjSNz95KubYjrYmfbv1mbIfGxzY6D64mmHmVpWhkeEau-u0PIS4ru5IFMHgX6WjKpWYKCht3oiOtL5-wZChNg==",
}
func TestClient_Import(t *testing.T) { //nolint:funlen
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "POST", "/mail/v4/messages/import"))
contentType, params, err := pmmime.ParseMediaType(req.Header.Get("Content-Type"))
r.NoError(t, err)
r.Equal(t, "multipart/form-data", contentType)
mr := multipart.NewReader(req.Body, params["boundary"])
// First part is message body.
p, err := mr.NextPart()
r.NoError(t, err)
contentDisp, params, err := pmmime.ParseMediaType(p.Header.Get("Content-Disposition"))
r.NoError(t, err)
r.Equal(t, "form-data", contentDisp)
r.Equal(t, "0", params["name"])
b, err := io.ReadAll(p)
r.NoError(t, err)
r.Equal(t, string(testImportReqs[0].Message), string(b))
// Second part is metadata.
p, err = mr.NextPart()
r.NoError(t, err)
contentDisp, params, err = pmmime.ParseMediaType(p.Header.Get("Content-Disposition"))
r.NoError(t, err)
r.Equal(t, "form-data", contentDisp)
r.Equal(t, "Metadata", params["name"])
metadata := map[string]*ImportMetadata{}
err = json.NewDecoder(p).Decode(&metadata)
r.NoError(t, err)
r.Equal(t, 1, len(metadata))
importReq := metadata["0"]
r.NotNil(t, req)
expected := *testImportReqs[0].Metadata
r.Equal(t, &expected, importReq)
// No more parts.
_, err = mr.NextPart()
r.EqualError(t, err, io.EOF.Error())
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testImportBody)
}))
defer s.Close()
imported, err := c.Import(context.Background(), testImportReqs)
r.NoError(t, err)
r.Equal(t, 1, len(imported))
r.Equal(t, testImportRes, imported[0])
}
func TestClientImportBigSize(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.FailNow(t, "request is not dropped")
}))
defer s.Close()
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const size = MaxImportMessageRequestSize + 1
msg := make([]byte, size)
for i := 0; i < size; i++ {
msg[i] = letterBytes[rand.Intn(len(letterBytes))]
}
importRequest := []*ImportMsgReq{
{
Metadata: &ImportMetadata{
AddressID: "addressID",
Unread: Boolean(false),
Flags: FlagReceived | FlagImported,
LabelIDs: []string{ArchiveLabel},
},
Message: msg,
},
}
_, err := c.Import(context.Background(), importRequest)
r.EqualError(t, err, "request size is too big")
}

View File

@ -1,78 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"github.com/go-resty/resty/v2"
)
// Key flags.
const (
UseToVerifyFlag = 1 << iota
UseToEncryptFlag
)
type PublicKey struct {
Flags int
PublicKey string
}
type RecipientType int
const (
RecipientTypeInternal RecipientType = iota + 1
RecipientTypeExternal
)
// GetPublicKeysForEmail returns all sending public keys for the given email address.
func (c *client) GetPublicKeysForEmail(ctx context.Context, email string) (keys []PublicKey, internal bool, err error) {
var res struct {
Keys []PublicKey
RecipientType RecipientType
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).SetQueryParam("Email", email).Get("/keys")
}); err != nil {
return nil, false, err
}
return res.Keys, res.RecipientType == RecipientTypeInternal, nil
}
// KeySalt contains id and salt for key.
type KeySalt struct {
ID, KeySalt string
}
// GetKeySalts sends request to get list of key salts (n.b. locked route).
func (c *client) GetKeySalts(ctx context.Context) (keySalts []KeySalt, err error) {
var res struct {
KeySalts []KeySalt
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/keys/salts")
}); err != nil {
return nil, err
}
return res.KeySalts, nil
}

View File

@ -1,335 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"bytes"
"encoding/base64"
"encoding/json"
"io"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
type PMKey struct {
ID string
Version int
Flags int
Fingerprint string
PrivateKey *crypto.Key
Primary int
Token string
Active Boolean
Signature string
}
type clearable []byte
func (c *clearable) UnmarshalJSON(b []byte) error {
b = bytes.Trim(b, "\"")
b = bytes.ReplaceAll(b, []byte("\\n"), []byte("\n"))
b = bytes.ReplaceAll(b, []byte("\\r"), []byte("\r"))
*c = b
return nil
}
func (c *clearable) clear() {
for i := range *c {
(*c)[i] = 0
}
}
func (key *PMKey) UnmarshalJSON(b []byte) (err error) {
type _PMKey PMKey
rawKey := struct {
_PMKey
PrivateKey clearable
}{}
defer rawKey.PrivateKey.clear()
if err = json.Unmarshal(b, &rawKey); err != nil {
return
}
*key = PMKey(rawKey._PMKey)
if key.PrivateKey, err = crypto.NewKeyFromArmoredReader(bytes.NewReader(rawKey.PrivateKey)); err != nil {
return errors.Wrap(err, "failed to create crypto key from armored private key")
}
return
}
func (key PMKey) getPassphraseFromToken(kr *crypto.KeyRing) (passphrase []byte, err error) {
if kr == nil {
return nil, errors.New("no user key was provided")
}
msg, err := crypto.NewPGPMessageFromArmored(key.Token)
if err != nil {
return
}
sig, err := crypto.NewPGPSignatureFromArmored(key.Signature)
if err != nil {
return
}
token, err := kr.Decrypt(msg, nil, 0)
if err != nil {
return
}
if err = kr.VerifyDetached(token, sig, 0); err != nil {
return
}
return token.GetBinary(), nil
}
func (key PMKey) unlock(passphrase []byte) (unlockedKey *crypto.Key, err error) {
if unlockedKey, err = key.PrivateKey.Unlock(passphrase); err != nil {
return
}
ok, err := unlockedKey.Check()
if err != nil {
return
}
if !ok {
err = errors.New("private and public keys do not match")
return
}
return
}
type PMKeys []PMKey
// UnlockAll goes through each key and unlocks it, returning a keyring containing all unlocked keys,
// or an error if no keys could be unlocked.
// The passphrase is used to unlock the key unless the key's token and signature are both non-nil,
// in which case the given userkey is used to deduce the passphrase.
func (keys *PMKeys) UnlockAll(passphrase []byte, userKey *crypto.KeyRing) (kr *crypto.KeyRing, err error) {
if kr, err = crypto.NewKeyRing(nil); err != nil {
return
}
for _, key := range *keys {
if !key.Active {
logrus.WithField("fingerprint", key.Fingerprint).Warn("Skipping inactive key")
continue
}
var secret []byte
if key.Token == "" || key.Signature == "" {
secret = passphrase
} else if secret, err = key.getPassphraseFromToken(userKey); err != nil {
return
}
k, unlockErr := key.unlock(secret)
if unlockErr != nil {
logrus.WithError(unlockErr).WithField("fingerprint", key.Fingerprint).Warn("Failed to unlock key")
continue
}
if addKeyErr := kr.AddKey(k); addKeyErr != nil {
logrus.WithError(addKeyErr).Warn("Failed to add key to keyring")
continue
}
}
if kr.CountEntities() == 0 {
err = errors.New("no keys could be unlocked")
return
}
return kr, err
}
// ErrNoKeyringAvailable represents an error caused by a keyring being nil or having no entities.
var ErrNoKeyringAvailable = errors.New("no keyring available")
func encrypt(encrypter *crypto.KeyRing, plain string, signer *crypto.KeyRing) (armored string, err error) {
if encrypter == nil {
return "", ErrNoKeyringAvailable
}
firstKey, err := encrypter.FirstKey()
if err != nil {
return "", err
}
plainMessage := crypto.NewPlainMessageFromString(plain)
// We use only primary key to encrypt the message. Our keyring contains all keys (primary, old and deacivated ones).
pgpMessage, err := firstKey.Encrypt(plainMessage, signer)
if err != nil {
return
}
return pgpMessage.GetArmored()
}
func (c *client) decrypt(armored string) (plain []byte, err error) {
return decrypt(c.userKeyRing, armored)
}
func decrypt(decrypter *crypto.KeyRing, armored string) (plainBody []byte, err error) {
if decrypter == nil {
return nil, ErrNoKeyringAvailable
}
pgpMessage, err := crypto.NewPGPMessageFromArmored(armored)
if err != nil {
return
}
plainMessage, err := decrypter.Decrypt(pgpMessage, nil, 0)
if err != nil {
return
}
return plainMessage.GetBinary(), nil
}
func (c *client) verify(plain, amroredSignature string) (err error) {
plainMessage := crypto.NewPlainMessageFromString(plain)
pgpSignature, err := crypto.NewPGPSignatureFromArmored(amroredSignature)
if err != nil {
return
}
verifyTime := int64(0) // By default it will use current timestamp.
return c.userKeyRing.VerifyDetached(plainMessage, pgpSignature, verifyTime)
}
func encryptAttachment(kr *crypto.KeyRing, data io.Reader, filename string) (encrypted io.Reader, err error) {
if kr == nil {
return nil, ErrNoKeyringAvailable
}
firstKey, err := kr.FirstKey()
if err != nil {
return nil, err
}
dataBytes, err := io.ReadAll(data)
if err != nil {
return
}
plainMessage := crypto.NewPlainMessage(dataBytes)
// We use only primary key to encrypt the message. Our keyring contains all keys (primary, old and deacivated ones).
pgpSplitMessage, err := firstKey.EncryptAttachment(plainMessage, filename)
if err != nil {
return
}
packets := append(pgpSplitMessage.KeyPacket, pgpSplitMessage.DataPacket...) //nolint:gocritic
return bytes.NewReader(packets), nil
}
func decryptAttachment(kr *crypto.KeyRing, keyPackets []byte, data io.Reader) (decrypted io.Reader, err error) {
if kr == nil {
return nil, ErrNoKeyringAvailable
}
dataBytes, err := io.ReadAll(data)
if err != nil {
return
}
pgpSplitMessage := crypto.NewPGPSplitMessage(keyPackets, dataBytes)
plainMessage, err := kr.DecryptAttachment(pgpSplitMessage)
if err != nil {
return
}
return plainMessage.NewReader(), nil
}
func signAttachment(encrypter *crypto.KeyRing, data io.Reader) (signature io.Reader, err error) {
if encrypter == nil {
return nil, ErrNoKeyringAvailable
}
dataBytes, err := io.ReadAll(data)
if err != nil {
return
}
plainMessage := crypto.NewPlainMessage(dataBytes)
sig, err := encrypter.SignDetached(plainMessage)
if err != nil {
return
}
return bytes.NewReader(sig.GetBinary()), nil
}
func encryptAndEncodeSessionKeys(
pubkey *crypto.KeyRing,
bodyKey *crypto.SessionKey,
attkeys map[string]*crypto.SessionKey,
) (bodyPacket string, attachmentPackets map[string]string, err error) {
// Encrypt message body keys.
packetBytes, err := pubkey.EncryptSessionKey(bodyKey)
if err != nil {
return
}
bodyPacket = base64.StdEncoding.EncodeToString(packetBytes)
// Encrypt attachment keys.
attachmentPackets = make(map[string]string)
for id, attkey := range attkeys {
var packets []byte
if packets, err = pubkey.EncryptSessionKey(attkey); err != nil {
return
}
attachmentPackets[id] = base64.StdEncoding.EncodeToString(packets)
}
return
}
func encryptSymmDecryptKey(
kr *crypto.KeyRing,
textToEncrypt string,
) (decryptedKey *crypto.SessionKey, symEncryptedData []byte, err error) {
// We use only primary key to encrypt the message. Our keyring contains all keys (primary, old and deacivated ones).
firstKey, err := kr.FirstKey()
if err != nil {
return
}
pgpMessage, err := firstKey.Encrypt(crypto.NewPlainMessageFromString(textToEncrypt), kr)
if err != nil {
return
}
pgpSplitMessage, err := pgpMessage.SplitMessage()
if err != nil {
return
}
decryptedKey, err = kr.DecryptSessionKey(pgpSplitMessage.GetBinaryKeyPacket())
if err != nil {
return
}
symEncryptedData = pgpSplitMessage.GetBinaryDataPacket()
return
}

View File

@ -1,115 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"encoding/json"
"testing"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/stretchr/testify/require"
)
func loadPMKeys(jsonKeys string) (keys *PMKeys) {
_ = json.Unmarshal([]byte(jsonKeys), &keys)
return
}
func TestPMKeys_GetKeyRingAndUnlock(t *testing.T) {
r := require.New(t)
addrKeysWithTokens := loadPMKeys(readTestFile("keyring_addressKeysWithTokens_JSON", false))
addrKeysWithoutTokens := loadPMKeys(readTestFile("keyring_addressKeysWithoutTokens_JSON", false))
addrKeysPrimaryHasToken := loadPMKeys(readTestFile("keyring_addressKeysPrimaryHasToken_JSON", false))
addrKeysSecondaryHasToken := loadPMKeys(readTestFile("keyring_addressKeysSecondaryHasToken_JSON", false))
key, err := crypto.NewKeyFromArmored(readTestFile("keyring_userKey", false))
if err != nil {
panic(err)
}
userKey, err := crypto.NewKeyRing(key)
r.NoError(err, "Expected not to receive an error unlocking user key")
type args struct {
userKeyring *crypto.KeyRing
passphrase []byte
}
tests := []struct {
name string
keys *PMKeys
args args
}{
{
name: "AddressKeys locked with tokens",
keys: addrKeysWithTokens,
args: args{userKeyring: userKey, passphrase: []byte("testpassphrase")},
},
{
name: "AddressKeys locked with passphrase, not tokens",
keys: addrKeysWithoutTokens,
args: args{userKeyring: userKey, passphrase: []byte("testpassphrase")},
},
{
name: "AddressKeys, primary locked with token, secondary with passphrase",
keys: addrKeysPrimaryHasToken,
args: args{userKeyring: userKey, passphrase: []byte("testpassphrase")},
},
{
name: "AddressKeys, primary locked with passphrase, secondary with token",
keys: addrKeysSecondaryHasToken,
args: args{userKeyring: userKey, passphrase: []byte("testpassphrase")},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
kr, err := tt.keys.UnlockAll(tt.args.passphrase, tt.args.userKeyring) //nolint:scopelint
r.NoError(err)
// assert at least one key has been decrypted
atLeastOneDecrypted := false
for _, k := range kr.GetKeys() { //nolint:scopelint
ok, err := k.IsUnlocked()
if err != nil {
panic(err)
}
if ok {
atLeastOneDecrypted = true
break
}
}
r.True(atLeastOneDecrypted)
})
}
}
func TestGopenpgpEncryptAttachment(t *testing.T) {
r := require.New(t)
wantMessage := crypto.NewPlainMessage([]byte(testAttachmentCleartext))
pgpSplitMessage, err := testPublicKeyRing.EncryptAttachment(wantMessage, "")
r.NoError(err)
haveMessage, err := testPrivateKeyRing.DecryptAttachment(pgpSplitMessage)
r.NoError(err)
r.Equal(wantMessage.Data, haveMessage.Data)
}

View File

@ -1,187 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"errors"
"strconv"
"github.com/go-resty/resty/v2"
)
// System labels.
const (
InboxLabel = "0"
AllDraftsLabel = "1"
AllSentLabel = "2"
TrashLabel = "3"
SpamLabel = "4"
AllMailLabel = "5"
ArchiveLabel = "6"
SentLabel = "7"
DraftLabel = "8"
StarredLabel = "10"
LabelTypeMailBox = 1
LabelTypeContactGroup = 2
)
// IsSystemLabel checks if a label is a pre-defined system label.
func IsSystemLabel(label string) bool {
switch label {
case InboxLabel, DraftLabel, SentLabel, TrashLabel, SpamLabel, ArchiveLabel, StarredLabel, AllMailLabel, AllSentLabel, AllDraftsLabel:
return true
}
return false
}
// LabelColors provides the RGB values of the available label colors.
var LabelColors = []string{ //nolint:gochecknoglobals
"#7272a7",
"#cf5858",
"#c26cc7",
"#7569d1",
"#69a9d1",
"#5ec7b7",
"#72bb75",
"#c3d261",
"#e6c04c",
"#e6984c",
"#8989ac",
"#cf7e7e",
"#c793ca",
"#9b94d1",
"#a8c4d5",
"#97c9c1",
"#9db99f",
"#c6cd97",
"#e7d292",
"#dfb286",
}
// Label for message.
type Label struct { //nolint:maligned
ID string
Name string
Path string
Color string
Order int `json:",omitempty"`
Display int // Not used for now, leave it empty.
Exclusive Boolean
Type int
Notify Boolean
}
func (c *client) ListLabels(ctx context.Context) (labels []*Label, err error) {
return c.listLabelType(ctx, LabelTypeMailBox)
}
func (c *client) ListContactGroups(ctx context.Context) (labels []*Label, err error) {
return c.listLabelType(ctx, LabelTypeContactGroup)
}
// listLabelType lists all labels created by the user.
func (c *client) listLabelType(ctx context.Context, labelType int) (labels []*Label, err error) {
var res struct {
Labels []*Label
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetQueryParam("Type", strconv.Itoa(labelType)).SetResult(&res).Get("/labels")
}); err != nil {
return nil, err
}
return res.Labels, nil
}
type LabelReq struct {
*Label
}
// CreateLabel creates a new label.
func (c *client) CreateLabel(ctx context.Context, label *Label) (created *Label, err error) {
if label.Name == "" {
return nil, errors.New("name is required")
}
var res struct {
Label *Label
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(&LabelReq{
Label: label,
}).SetResult(&res).Post("/labels")
}); err != nil {
return nil, err
}
return res.Label, nil
}
// UpdateLabel updates a label.
func (c *client) UpdateLabel(ctx context.Context, label *Label) (updated *Label, err error) {
if label.Name == "" {
return nil, errors.New("name is required")
}
var res struct {
Label *Label
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(&LabelReq{
Label: label,
}).SetResult(&res).Put("/labels/" + label.ID)
}); err != nil {
return nil, err
}
return res.Label, nil
}
// DeleteLabel deletes a label.
func (c *client) DeleteLabel(ctx context.Context, labelID string) error {
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.Delete("/labels/" + labelID)
}); err != nil {
return err
}
return nil
}
// LeastUsedColor is intended to return color for creating a new inbox or label.
func LeastUsedColor(colors []string) (color string) {
color = LabelColors[0]
frequency := map[string]int{}
for _, c := range colors {
frequency[c]++
}
for _, c := range LabelColors {
if frequency[color] > frequency[c] {
color = c
}
}
return
}

View File

@ -1,204 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"reflect"
"testing"
r "github.com/stretchr/testify/require"
)
const testLabelsBody = `{
"Labels": [
{
"ID": "LLz8ysmVxwr4dF6mWpClePT0SpSWOEvzTdq17RydSl4ndMckvY1K63HeXDzn03BJQwKYvgf-eWT8Qfd9WVuIEQ==",
"Name": "CroutonMail is awesome :)",
"Color": "#7272a7",
"Display": 0,
"Order": 1,
"Type": 1
},
{
"ID": "BvbqbySUPo9uWW_eR8tLA13NUsQMz3P4Zhw4UnpvrKqURnrHlE6L2Au0nplHfHlVXFgGz4L4hJ9-BYllOL-L5g==",
"Name": "Royal sausage",
"Color": "#cf5858",
"Display": 1,
"Order": 2,
"Type": 1
}
],
"Code": 1000
}
`
var testLabels = []*Label{
{ID: "LLz8ysmVxwr4dF6mWpClePT0SpSWOEvzTdq17RydSl4ndMckvY1K63HeXDzn03BJQwKYvgf-eWT8Qfd9WVuIEQ==", Name: "CroutonMail is awesome :)", Color: "#7272a7", Order: 1, Display: 0, Type: LabelTypeMailBox},
{ID: "BvbqbySUPo9uWW_eR8tLA13NUsQMz3P4Zhw4UnpvrKqURnrHlE6L2Au0nplHfHlVXFgGz4L4hJ9-BYllOL-L5g==", Name: "Royal sausage", Color: "#cf5858", Order: 2, Display: 1, Type: LabelTypeMailBox},
}
var testLabelReq = LabelReq{&Label{
Name: "sava",
Color: "#c26cc7",
Display: 1,
}}
const testCreateLabelBody = `{
"Label": {
"ID": "otkpEZzG--8dMXvwyLXLQWB72hhBhNGzINjH14rUDfywvOyeN01cDxDrS3Koifxf6asA7Xcwtldm0r_MCmWiAQ==",
"Name": "sava",
"Color": "#c26cc7",
"Display": 1,
"Order": 3,
"Type": 1
},
"Code": 1000
}
`
var testLabelCreated = &Label{
ID: "otkpEZzG--8dMXvwyLXLQWB72hhBhNGzINjH14rUDfywvOyeN01cDxDrS3Koifxf6asA7Xcwtldm0r_MCmWiAQ==",
Name: "sava",
Color: "#c26cc7",
Order: 3,
Display: 1,
Type: LabelTypeMailBox,
}
const testDeleteLabelBody = `{
"Code": 1000
}
`
func TestClient_ListLabels(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "GET", "/labels?Type=1"))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testLabelsBody)
}))
defer s.Close()
labels, err := c.ListLabels(context.Background())
r.NoError(t, err)
r.Equal(t, testLabels, labels)
}
func TestClient_CreateLabel(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "POST", "/labels"))
body := &bytes.Buffer{}
_, err := body.ReadFrom(req.Body)
r.NoError(t, err)
if bytes.Contains(body.Bytes(), []byte("Order")) {
t.Fatal("Body contains `Order`: ", body.String())
}
var labelReq LabelReq
err = json.NewDecoder(body).Decode(&labelReq)
r.NoError(t, err)
r.Equal(t, testLabelReq.Label, labelReq.Label)
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testCreateLabelBody)
}))
defer s.Close()
created, err := c.CreateLabel(context.Background(), testLabelReq.Label)
r.NoError(t, err)
if !reflect.DeepEqual(created, testLabelCreated) {
t.Fatalf("Invalid created label: expected %+v, got %+v", testLabelCreated, created)
}
}
func TestClient_CreateEmptyLabel(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
r.Fail(t, "API should not be called")
}))
defer s.Close()
_, err := c.CreateLabel(context.Background(), &Label{})
r.EqualError(t, err, "name is required")
}
func TestClient_UpdateLabel(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "PUT", "/labels/"+testLabelCreated.ID))
var labelReq LabelReq
err := json.NewDecoder(req.Body).Decode(&labelReq)
r.NoError(t, err)
r.Equal(t, testLabelCreated, labelReq.Label)
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testCreateLabelBody)
}))
defer s.Close()
updated, err := c.UpdateLabel(context.Background(), testLabelCreated)
r.NoError(t, err)
if !reflect.DeepEqual(updated, testLabelCreated) {
t.Fatalf("Invalid updated label: expected %+v, got %+v", testLabelCreated, updated)
}
}
func TestClient_UpdateLabelToEmptyName(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
r.Fail(t, "API should not be called")
}))
defer s.Close()
_, err := c.UpdateLabel(context.Background(), &Label{ID: "label"})
r.EqualError(t, err, "name is required")
}
func TestClient_DeleteLabel(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "DELETE", "/labels/"+testLabelCreated.ID))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testDeleteLabelBody)
}))
defer s.Close()
err := c.DeleteLabel(context.Background(), testLabelCreated.ID)
r.NoError(t, err)
}
func TestLeastUsedColor(t *testing.T) {
// No colors at all, should use first available color
colors := []string{}
r.Equal(t, "#7272a7", LeastUsedColor(colors))
// All colors have same frequency, should use first available color
colors = []string{"#7272a7", "#cf5858", "#c26cc7", "#7569d1", "#69a9d1", "#5ec7b7", "#72bb75", "#c3d261", "#e6c04c", "#e6984c", "#8989ac", "#cf7e7e", "#c793ca", "#9b94d1", "#a8c4d5", "#97c9c1", "#9db99f", "#c6cd97", "#e7d292", "#dfb286"}
r.Equal(t, "#7272a7", LeastUsedColor(colors))
// First three colors already used, but others wasn't. Should use first non-used one.
colors = []string{"#7272a7", "#cf5858", "#c26cc7"}
r.Equal(t, "#7569d1", LeastUsedColor(colors))
}

View File

@ -1,110 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"errors"
"strconv"
"github.com/go-resty/resty/v2"
)
type LabelTypeV4 int
const (
LabelTypeV4Label = 1
LabelTypeV4ContactGroup = 2
LabelTypeV4Folder = 3
)
func (c *client) ListLabelsOnly(ctx context.Context) (labels []*Label, err error) {
return c.listLabelTypeV4(ctx, LabelTypeV4Label)
}
func (c *client) ListFoldersOnly(ctx context.Context) (labels []*Label, err error) {
return c.listLabelTypeV4(ctx, LabelTypeV4Folder)
}
// listLabelType lists all labels created by the user.
func (c *client) listLabelTypeV4(ctx context.Context, labelType LabelTypeV4) (labels []*Label, err error) {
var res struct {
Labels []*Label
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetQueryParam("Type", strconv.Itoa(int(labelType))).SetResult(&res).Get("/core/v4/labels")
}); err != nil {
return nil, err
}
return res.Labels, nil
}
// CreateLabel creates a new label.
func (c *client) CreateLabelV4(ctx context.Context, label *Label) (created *Label, err error) {
if label.Name == "" {
return nil, errors.New("name is required")
}
var res struct {
Label *Label
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(&LabelReq{
Label: label,
}).SetResult(&res).Post("/core/v4/labels")
}); err != nil {
return nil, err
}
return res.Label, nil
}
// UpdateLabel updates a label.
func (c *client) UpdateLabelV4(ctx context.Context, label *Label) (updated *Label, err error) {
if label.Name == "" {
return nil, errors.New("name is required")
}
var res struct {
Label *Label
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(&LabelReq{
Label: label,
}).SetResult(&res).Put("/core/v4/labels/" + label.ID)
}); err != nil {
return nil, err
}
return res.Label, nil
}
// DeleteLabel deletes a label.
func (c *client) DeleteLabelV4(ctx context.Context, labelID string) error {
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.Delete("/core/v4/labels/" + labelID)
}); err != nil {
return err
}
return nil
}

View File

@ -1,170 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"reflect"
"testing"
r "github.com/stretchr/testify/require"
)
const testFoldersBody = `{
"Labels": [
{
"ID": "LLz8ysmVxwr4dF6mWpClePT0SpSWOEvzTdq17RydSl4ndMckvY1K63HeXDzn03BJQwKYvgf-eWT8Qfd9WVuIEQ==",
"Name": "CroutonMail is awesome :)",
"Color": "#7272a7",
"Display": 0,
"Order": 1,
"Type": 3
},
{
"ID": "BvbqbySUPo9uWW_eR8tLA13NUsQMz3P4Zhw4UnpvrKqURnrHlE6L2Au0nplHfHlVXFgGz4L4hJ9-BYllOL-L5g==",
"Name": "Royal sausage",
"Color": "#cf5858",
"Display": 1,
"Order": 2,
"Type": 3
}
],
"Code": 1000
}
`
var testFolders = []*Label{
{ID: "LLz8ysmVxwr4dF6mWpClePT0SpSWOEvzTdq17RydSl4ndMckvY1K63HeXDzn03BJQwKYvgf-eWT8Qfd9WVuIEQ==", Name: "CroutonMail is awesome :)", Color: "#7272a7", Order: 1, Display: 0, Type: LabelTypeV4Folder},
{ID: "BvbqbySUPo9uWW_eR8tLA13NUsQMz3P4Zhw4UnpvrKqURnrHlE6L2Au0nplHfHlVXFgGz4L4hJ9-BYllOL-L5g==", Name: "Royal sausage", Color: "#cf5858", Order: 2, Display: 1, Type: LabelTypeV4Folder},
}
func TestClient_ListLabelsOnly(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "GET", "/core/v4/labels?Type=1"))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testLabelsBody)
}))
defer s.Close()
labels, err := c.ListLabelsOnly(context.Background())
r.NoError(t, err)
r.Equal(t, testLabels, labels)
}
func TestClient_ListFoldersOnly(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "GET", "/core/v4/labels?Type=3"))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testFoldersBody)
}))
defer s.Close()
folders, err := c.ListFoldersOnly(context.Background())
r.NoError(t, err)
r.Equal(t, testFolders, folders)
}
func TestClient_CreateLabelV4(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "POST", "/core/v4/labels"))
body := &bytes.Buffer{}
_, err := body.ReadFrom(req.Body)
r.NoError(t, err)
if bytes.Contains(body.Bytes(), []byte("Order")) {
t.Fatal("Body contains `Order`: ", body.String())
}
var labelReq LabelReq
err = json.NewDecoder(body).Decode(&labelReq)
r.NoError(t, err)
r.Equal(t, testLabelReq.Label, labelReq.Label)
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testCreateLabelBody)
}))
defer s.Close()
created, err := c.CreateLabelV4(context.Background(), testLabelReq.Label)
r.NoError(t, err)
if !reflect.DeepEqual(created, testLabelCreated) {
t.Fatalf("Invalid created label: expected %+v, got %+v", testLabelCreated, created)
}
}
func TestClient_CreateEmptyLabelV4(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
r.Fail(t, "API should not be called")
}))
defer s.Close()
_, err := c.CreateLabelV4(context.Background(), &Label{})
r.EqualError(t, err, "name is required")
}
func TestClient_UpdateLabelV4(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "PUT", "/core/v4/labels/"+testLabelCreated.ID))
var labelReq LabelReq
err := json.NewDecoder(req.Body).Decode(&labelReq)
r.NoError(t, err)
r.Equal(t, testLabelCreated, labelReq.Label)
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testCreateLabelBody)
}))
defer s.Close()
updated, err := c.UpdateLabelV4(context.Background(), testLabelCreated)
r.NoError(t, err)
if !reflect.DeepEqual(updated, testLabelCreated) {
t.Fatalf("Invalid updated label: expected %+v, got %+v", testLabelCreated, updated)
}
}
func TestClient_UpdateLabelToEmptyNameV4(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
r.Fail(t, "API should not be called")
}))
defer s.Close()
_, err := c.UpdateLabelV4(context.Background(), &Label{ID: "label"})
r.EqualError(t, err, "name is required")
}
func TestClient_DeleteLabelV4(t *testing.T) {
s, c := newTestClient(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "DELETE", "/core/v4/labels/"+testLabelCreated.ID))
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testDeleteLabelBody)
}))
defer s.Close()
err := c.DeleteLabelV4(context.Background(), testLabelCreated.ID)
r.NoError(t, err)
}

View File

@ -1,173 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"net/http"
"sync"
"time"
"github.com/getsentry/sentry-go"
"github.com/go-resty/resty/v2"
)
type manager struct {
cfg Config
rc *resty.Client
isDown bool
locker sync.Locker
refreshingAuth sync.Locker
connectionObservers []ConnectionObserver
proxyDialer *ProxyTLSDialer
pingMutex *sync.RWMutex
isPinging bool
setSentryUserIDOnce sync.Once
}
func New(cfg Config) Manager {
return newManager(cfg)
}
func newManager(cfg Config) *manager {
m := &manager{
cfg: cfg,
rc: resty.New().EnableTrace(),
locker: &sync.Mutex{},
refreshingAuth: &sync.Mutex{},
pingMutex: &sync.RWMutex{},
isPinging: false,
setSentryUserIDOnce: sync.Once{},
}
proxyDialer, transport := newProxyDialerAndTransport(cfg)
m.proxyDialer = proxyDialer
m.rc.SetTransport(transport)
m.rc.SetBaseURL(cfg.HostURL)
m.rc.OnBeforeRequest(m.setHeaderValues)
// Any HTTP status code higher than 399 with JSON inside (and proper header)
// is converted to Error. `catchAPIError` then processes API custom errors
// wrapped in JSON. If error is returned, `handleRequestFailure` is called,
// otherwise `handleRequestSuccess` is called.
m.rc.SetError(&Error{})
m.rc.OnAfterResponse(logConnReuse)
m.rc.OnAfterResponse(updateTime)
m.rc.OnAfterResponse(m.catchAPIError)
m.rc.OnAfterResponse(m.handleRequestSuccess)
m.rc.OnError(m.handleRequestFailure)
// Configure retry mechanism.
//
// SetRetryCount(5): The most probable value of Retry-After from our
// API is 1s (max 10s). Retrying up to 5 times will on average cause a
// delay of 40s.
//
// NOTE: Increasing to values larger than 10 causing significant delay.
// The resty is increasing the delay between retries up to 1 minute
// (SetRetryMaxWaitTime) so for 10 retries the cumulative delay can be
// up to 5min.
m.rc.SetRetryCount(3)
m.rc.SetRetryMaxWaitTime(time.Minute)
m.rc.SetRetryAfter(catchRetryAfter)
m.rc.AddRetryCondition(m.shouldRetry)
return m
}
func (m *manager) SetTransport(transport http.RoundTripper) {
m.rc.SetTransport(transport)
m.proxyDialer = nil
}
func (m *manager) SetCookieJar(jar http.CookieJar) {
m.rc.SetCookieJar(jar)
}
func (m *manager) SetRetryCount(count int) {
m.rc.SetRetryCount(count)
}
func (m *manager) AddConnectionObserver(observer ConnectionObserver) {
m.connectionObservers = append(m.connectionObservers, observer)
}
func (m *manager) setHeaderValues(_ *resty.Client, req *resty.Request) error {
req.SetHeaders(map[string]string{
"x-pm-appversion": m.cfg.AppVersion,
"User-Agent": m.cfg.getUserAgent(),
})
return nil
}
func (m *manager) r(ctx context.Context) *resty.Request {
return m.rc.R().SetContext(ctx)
}
func (m *manager) handleRequestSuccess(_ *resty.Client, res *resty.Response) error {
m.locker.Lock()
defer m.locker.Unlock()
if !m.isDown {
return nil
}
// We successfully got a response; connection must be up.
m.isDown = false
for _, observer := range m.connectionObservers {
observer.OnUp()
}
return nil
}
func (m *manager) handleRequestFailure(req *resty.Request, err error) {
m.locker.Lock()
defer m.locker.Unlock()
if m.isDown {
return
}
if res, ok := err.(*resty.ResponseError); ok && res.Response.RawResponse != nil {
return
}
// We didn't get any response; connection must be down.
m.isDown = true
for _, observer := range m.connectionObservers {
observer.OnDown()
}
go m.pingUntilSuccess()
}
func (m *manager) setSentryUserID(userID string) {
m.setSentryUserIDOnce.Do(func() {
sentry.ConfigureScope(func(scope *sentry.Scope) {
scope.SetTag("UserID", userID)
})
})
}

View File

@ -1,138 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"encoding/base64"
"time"
"github.com/ProtonMail/go-srp"
)
func (m *manager) NewClient(uid, acc, ref string, exp time.Time) Client {
log.Trace("New client")
return newClient(m, uid).withAuth(acc, ref, exp)
}
func (m *manager) NewClientWithRefresh(ctx context.Context, uid, ref string) (Client, *AuthRefresh, error) {
log.Trace("New client with refresh")
c := newClient(m, uid)
auth, err := m.authRefresh(ctx, uid, ref)
if err != nil {
return nil, nil, err
}
return c.withAuth(auth.AccessToken, auth.RefreshToken, expiresIn(auth.ExpiresIn)), auth, nil
}
func (m *manager) NewClientWithLogin(ctx context.Context, username string, password []byte) (Client, *Auth, error) {
log.Trace("New client with login")
info, err := m.getAuthInfo(ctx, GetAuthInfoReq{Username: username})
if err != nil {
return nil, nil, err
}
srpAuth, err := srp.NewAuth(info.Version, username, password, info.Salt, info.Modulus, info.ServerEphemeral)
if err != nil {
return nil, nil, err
}
proofs, err := srpAuth.GenerateProofs(2048)
if err != nil {
return nil, nil, err
}
// Do not retry requests after this point. The ephemeral from auth info
// won't be valid any more
ctx = ContextWithoutRetry(ctx)
auth, err := m.auth(ctx, AuthReq{
Username: username,
ClientProof: base64.StdEncoding.EncodeToString(proofs.ClientProof),
ClientEphemeral: base64.StdEncoding.EncodeToString(proofs.ClientEphemeral),
SRPSession: info.SRPSession,
})
if err != nil {
return nil, nil, err
}
return newClient(m, auth.UID).withAuth(auth.AccessToken, auth.RefreshToken, expiresIn(auth.ExpiresIn)), auth, nil
}
func (m *manager) getAuthInfo(ctx context.Context, req GetAuthInfoReq) (*AuthInfo, error) {
var res struct {
*AuthInfo
}
_, err := wrapNoConnection(m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/info"))
if err != nil {
return nil, err
}
return res.AuthInfo, nil
}
func (m *manager) auth(ctx context.Context, req AuthReq) (*Auth, error) {
var res struct {
*Auth
}
_, err := wrapNoConnection(m.r(ctx).SetBody(req).SetResult(&res).Post("/auth"))
if err != nil {
return nil, err
}
return res.Auth, nil
}
func (m *manager) authRefresh(ctx context.Context, uid, ref string) (*AuthRefresh, error) {
m.refreshingAuth.Lock()
defer m.refreshingAuth.Unlock()
req := authRefreshReq{
UID: uid,
RefreshToken: ref,
ResponseType: "token",
GrantType: "refresh_token",
RedirectURI: "https://protonmail.ch",
State: randomString(32),
}
var res struct {
*AuthRefresh
}
_, err := wrapNoConnection(m.r(ctx).SetBody(req).SetResult(&res).Post("/auth/refresh"))
if err != nil {
if IsBadRequest(err) || IsUnprocessableEntity(err) {
err = ErrAuthFailed{err}
}
return nil, err
}
return res.AuthRefresh, nil
}
func expiresIn(seconds int64) time.Time {
return time.Now().Add(time.Duration(seconds) * time.Second)
}

View File

@ -1,69 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"io"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"golang.org/x/net/context"
)
// DownloadAndVerify downloads a file and its signature from the given locations `file` and `sig`.
// The file and its signature are verified using the given keyring `kr`.
// If the file is verified successfully, it can be read from the returned reader.
// TLS fingerprinting is used to verify that connections are only made to known servers.
func (m *manager) DownloadAndVerify(kr *crypto.KeyRing, url, sig string) ([]byte, error) {
fb, err := m.fetchFile(url)
if err != nil {
return nil, err
}
sb, err := m.fetchFile(sig)
if err != nil {
return nil, err
}
if err := kr.VerifyDetached(
crypto.NewPlainMessage(fb),
crypto.NewPGPSignature(sb),
crypto.GetUnixTime(),
); err != nil {
return nil, err
}
return fb, nil
}
func (m *manager) fetchFile(url string) ([]byte, error) {
res, err := m.r(ContextWithoutRetry(context.Background())).SetDoNotParseResponse(true).Get(url)
if err != nil {
return nil, err
}
b, err := io.ReadAll(res.RawBody())
if err != nil {
return nil, err
}
if err := res.RawBody().Close(); err != nil {
return nil, err
}
return b, nil
}

View File

@ -1,71 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"github.com/go-resty/resty/v2"
"github.com/sirupsen/logrus"
)
// restyLogger decreases debug level to trace level so resty logs
// are not logged as debug but trace instead. Resty logging is too
// verbose which we don't want to have in debug level.
type restyLogger struct {
logrus *logrus.Entry
}
func (l *restyLogger) Errorf(format string, v ...interface{}) {
l.logrus.Errorf(format, v...)
}
func (l *restyLogger) Warnf(format string, v ...interface{}) {
l.logrus.Warnf(format, v...)
}
func (l *restyLogger) Debugf(format string, v ...interface{}) {
l.logrus.Tracef(format, v...)
}
func (m *manager) SetLogging(logger *logrus.Entry, verbose bool) {
if verbose {
m.rc.SetLogger(&restyLogger{logrus: logger})
m.rc.SetDebug(true)
return
}
m.rc.OnBeforeRequest(func(_ *resty.Client, req *resty.Request) error {
logger.Infof("Requesting %s %s", req.Method, req.URL)
return nil
})
m.rc.OnAfterResponse(func(_ *resty.Client, res *resty.Response) error {
log := logger.WithFields(logrus.Fields{
"error": res.Error(),
"status": res.StatusCode(),
"duration": res.Time(),
})
if res.Request == nil {
log.Warn("Requested unknown request")
return nil
}
log.Debugf("Requested %s %s", res.Request.Method, res.Request.URL)
return nil
})
m.rc.OnError(func(req *resty.Request, err error) {
logger.WithError(err).Warnf("Failed request %s %s", req.Method, req.URL)
})
}

View File

@ -1,34 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
)
func (m *manager) SendSimpleMetric(ctx context.Context, category, action, label string) error {
r := m.r(ctx).SetQueryParams(map[string]string{
"Category": category,
"Action": action,
"Label": label,
})
if _, err := wrapNoConnection(r.Get("/metrics")); err != nil {
return err
}
return nil
}

View File

@ -1,48 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
r "github.com/stretchr/testify/require"
)
const testSendSimpleMetricsBody = `{
"Code": 1000
}
`
func TestClient_SendSimpleMetric(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "GET", "/metrics?Action=some_action&Category=some_category&Label=some_label"))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, testSendSimpleMetricsBody)
}))
defer s.Close()
m := newManager(newTestConfig(s.URL))
err := m.SendSimpleMetric(context.Background(), "some_category", "some_action", "some_label")
r.NoError(t, err)
}

View File

@ -1,85 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"time"
"github.com/sirupsen/logrus"
)
// retryConnectionSleeps defines a smooth cool down in seconds.
var retryConnectionSleeps = []int{2, 5, 10, 30, 60} //nolint:gochecknoglobals
func (m *manager) pingUntilSuccess() {
if m.isPingOngoing() {
logrus.Debug("Ping already ongoing")
return
}
m.pingingStarted()
defer m.pingingStopped()
attempt := 0
for {
ctx := ContextWithoutRetry(context.Background())
err := m.testPing(ctx)
if err == nil {
return
}
waitTime := getRetryConnectionSleep(attempt)
attempt++
logrus.WithError(err).WithField("attempt", attempt).WithField("wait", waitTime).Debug("Connection (still) not available")
time.Sleep(waitTime)
}
}
func (m *manager) isPingOngoing() bool {
m.pingMutex.RLock()
defer m.pingMutex.RUnlock()
return m.isPinging
}
func (m *manager) pingingStarted() {
m.pingMutex.Lock()
defer m.pingMutex.Unlock()
m.isPinging = true
}
func (m *manager) pingingStopped() {
m.pingMutex.Lock()
defer m.pingMutex.Unlock()
m.isPinging = false
}
func getRetryConnectionSleep(idx int) time.Duration {
if idx >= len(retryConnectionSleeps) {
idx = len(retryConnectionSleeps) - 1
}
sec := retryConnectionSleeps[idx]
return time.Duration(sec) * time.Second
}
func (m *manager) testPing(ctx context.Context) error {
if _, err := m.r(ctx).Get("/tests/ping"); err != nil {
return err
}
return nil
}

View File

@ -1,32 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
// AllowProxy allows the client manager to switch clients over to a proxy if need be.
func (m *manager) AllowProxy() {
if m.proxyDialer != nil {
m.proxyDialer.AllowProxy()
}
}
// DisallowProxy prevents the client manager from switching clients over to a proxy if need be.
func (m *manager) DisallowProxy() {
if m.proxyDialer != nil {
m.proxyDialer.DisallowProxy()
}
}

View File

@ -1,49 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
)
// Report sends request as json or multipart (if has attachment).
func (m *manager) ReportBug(ctx context.Context, rep ReportBugReq) error {
if rep.ClientType == 0 {
rep.ClientType = EmailClientType
}
if rep.Client == "" {
rep.Client = m.cfg.GetUserAgent()
}
if rep.ClientVersion == "" {
rep.ClientVersion = m.cfg.AppVersion
}
r := m.r(ctx).SetMultipartFormData(rep.GetMultipartFormData())
for _, att := range rep.Attachments {
r = r.SetMultipartField(att.name, att.name, att.mime, att.body)
}
if _, err := wrapNoConnection(r.Post("/reports/bug")); err != nil {
return err
}
return nil
}

View File

@ -1,86 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
r "github.com/stretchr/testify/require"
)
var testBugReportReq = ReportBugReq{
OS: "Mac OSX",
OSVersion: "10.11.6",
Browser: "AppleMail",
Client: "demoapp",
ClientVersion: "GoPMAPI_1.0.14",
ClientType: 1,
Title: "Big Bug",
Description: "Cannot fetch new messages",
Username: "Apple",
Email: "apple@gmail.com",
}
const testBugsBody = `{
"Code": 1000
}
`
func TestClient_BugReportWithAttachment(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
r.NoError(t, checkMethodAndPath(req, "POST", "/reports/bug"))
r.NoError(t, req.ParseMultipartForm(10*1024))
for field, expected := range map[string]string{
"OS": testBugReportReq.OS,
"OSVersion": testBugReportReq.OSVersion,
"Client": testBugReportReq.Client,
"ClientVersion": testBugReportReq.ClientVersion,
"ClientType": fmt.Sprintf("%d", testBugReportReq.ClientType),
"Title": testBugReportReq.Title,
"Description": testBugReportReq.Description,
"Username": testBugReportReq.Username,
"Email": testBugReportReq.Email,
} {
r.Equal(t, expected, req.PostFormValue(field))
}
attReader, err := req.MultipartForm.File["log"][0].Open()
r.NoError(t, err)
_, err = io.ReadAll(attReader)
r.NoError(t, err)
w.Header().Set("Content-Type", "application/json")
fmt.Fprint(w, testBugsBody)
}))
defer s.Close()
cm := newManager(newTestConfig(s.URL))
rep := testBugReportReq
rep.AddAttachment("log", "last.log", strings.NewReader(testAttachmentJSON))
err := cm.ReportBug(context.Background(), rep)
r.NoError(t, err)
}

View File

@ -1,83 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"fmt"
"io"
)
// ClientType is required by API.
const (
EmailClientType = iota + 1
VPNClientType
)
type reportAtt struct {
name, mime string
body io.Reader
}
// ReportBugReq stores data for report.
type ReportBugReq struct {
OS string `json:",omitempty"`
OSVersion string `json:",omitempty"`
Browser string `json:",omitempty"`
BrowserVersion string `json:",omitempty"`
BrowserExtensions string `json:",omitempty"`
Resolution string `json:",omitempty"`
DisplayMode string `json:",omitempty"`
Client string `json:",omitempty"`
ClientVersion string `json:",omitempty"`
ClientType int `json:",omitempty"`
Title string `json:",omitempty"`
Description string `json:",omitempty"`
Username string `json:",omitempty"`
Email string `json:",omitempty"`
Country string `json:",omitempty"`
ISP string `json:",omitempty"`
Debug string `json:",omitempty"`
Attachments []reportAtt `json:",omitempty"`
}
// AddAttachment to report.
func (rep *ReportBugReq) AddAttachment(name, mime string, r io.Reader) {
rep.Attachments = append(rep.Attachments, reportAtt{name: name, mime: mime, body: r})
}
func (rep *ReportBugReq) GetMultipartFormData() map[string]string {
return map[string]string{
"OS": rep.OS,
"OSVersion": rep.OSVersion,
"Browser": rep.Browser,
"BrowserVersion": rep.BrowserVersion,
"BrowserExtensions": rep.BrowserExtensions,
"Resolution": rep.Resolution,
"DisplayMode": rep.DisplayMode,
"Client": rep.Client,
"ClientVersion": rep.ClientVersion,
"ClientType": fmt.Sprintf("%d", rep.ClientType),
"Title": rep.Title,
"Description": rep.Description,
"Username": rep.Username,
"Email": rep.Email,
"Country": rep.Country,
"ISP": rep.ISP,
"Debug": rep.Debug,
}
}

View File

@ -1,237 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
r "github.com/stretchr/testify/require"
)
const testForceUpgradeBody = `{
"Code":5003,
"Error":"Upgrade!"
}`
const testTooManyAPIRequests = `{
"Code":85131,
"Error":"Too many recent API requests"
}`
func TestHandleTooManyRequests(t *testing.T) {
var numCalls int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numCalls++
if numCalls < 5 {
w.WriteHeader(http.StatusTooManyRequests)
w.Header().Set("content-type", "application/json;charset=utf-8")
fmt.Fprint(w, testTooManyAPIRequests)
} else {
w.WriteHeader(http.StatusOK)
}
}))
m := New(Config{HostURL: ts.URL})
m.SetRetryCount(5)
// The call should succeed because the 5th retry should succeed (429s are retried).
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
r.NoError(t, err)
// The server should be called 5 times.
// The first four calls should return 429 and the last call should return 200.
r.Equal(t, 5, numCalls)
}
func TestHandleUnprocessableEntity(t *testing.T) {
var numCalls int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numCalls++
w.WriteHeader(http.StatusUnprocessableEntity)
}))
m := New(Config{HostURL: ts.URL})
m.SetRetryCount(5)
// The call should fail because the first call should fail (422s are not retried).
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
r.EqualError(t, err, "422 Unprocessable Entity")
// The server should be called 1 time.
// The first call should return 422.
r.Equal(t, 1, numCalls)
}
func TestHandleDialFailure(t *testing.T) {
var numCalls int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numCalls++
w.WriteHeader(http.StatusOK)
}))
// The failingRoundTripper will fail the first 5 times it is used.
m := New(Config{HostURL: ts.URL})
m.SetTransport(newFailingRoundTripper(5))
m.SetRetryCount(5)
// The call should succeed because the last retry should succeed (dial errors are retried).
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
r.NoError(t, err)
// The server should be called 1 time.
// The first 4 attempts don't reach the server.
r.Equal(t, 1, numCalls)
}
func TestHandleTooManyDialFailures(t *testing.T) {
var numCalls int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numCalls++
w.WriteHeader(http.StatusOK)
}))
// The failingRoundTripper will fail the first 10 times it is used.
// This is more than the number of retries we permit.
// Thus, dials will fail.
m := New(Config{HostURL: ts.URL})
m.SetTransport(newFailingRoundTripper(10))
m.SetRetryCount(5)
// The call should fail because every dial will fail and we'll run out of retries.
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
r.EqualError(t, err, "no internet connection")
// The server should never be called.
r.Equal(t, 0, numCalls)
}
func TestRetriesWithContextTimeout(t *testing.T) {
var numCalls int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numCalls++
if numCalls < 5 {
w.WriteHeader(http.StatusTooManyRequests)
} else {
w.WriteHeader(http.StatusOK)
}
}))
// Theoretically, this should succeed; on the fifth retry, we'll get StatusOK.
m := New(Config{HostURL: ts.URL})
m.SetRetryCount(5)
// However, that will take ~0.5s, and we only allow 10ms in the context.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
// Thus, it will fail.
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(ctx)
r.EqualError(t, err, context.DeadlineExceeded.Error())
}
func TestObserveConnectionStatus(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
var onDown, onUp bool
m := New(Config{HostURL: ts.URL})
m.SetTransport(newFailingRoundTripper(10))
m.SetRetryCount(5)
m.AddConnectionObserver(NewConnectionObserver(func() { onDown = true }, func() { onUp = true }))
// The call should fail because every dial will fail and we'll run out of retries.
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
r.Error(t, err)
r.False(t, onUp)
r.True(t, onDown)
onDown, onUp = false, false
// The call should succeed because the last dial attempt will succeed.
_, err = m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
r.NoError(t, err)
r.True(t, onUp)
r.False(t, onDown)
}
func TestReturnErrNoConnection(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// We will fail more times than we retry, so requests should fail with ErrNoConnection.
m := New(Config{HostURL: ts.URL})
m.SetTransport(newFailingRoundTripper(10))
m.SetRetryCount(5)
// The call should fail because every dial will fail and we'll run out of retries.
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
r.EqualError(t, err, "no internet connection")
}
func TestReturnErrUpgradeApplication(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", "application/json")
w.WriteHeader(http.StatusUnprocessableEntity)
fmt.Fprint(w, testForceUpgradeBody)
}))
m := New(Config{HostURL: ts.URL})
// The call should fail because every call return force upgrade error.
_, err := m.NewClient("", "", "", time.Now().Add(time.Hour)).GetAddresses(context.Background())
r.EqualError(t, err, ErrUpgradeApplication.Error())
}
type failingRoundTripper struct {
http.RoundTripper
fails, calls int
}
func newFailingRoundTripper(fails int) http.RoundTripper {
return &failingRoundTripper{
RoundTripper: http.DefaultTransport,
fails: fails,
}
}
func (rt *failingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
rt.calls++
if rt.calls < rt.fails {
return nil, errors.New("simulating network error")
}
return rt.RoundTripper.RoundTrip(req)
}

View File

@ -1,46 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"net/http"
"time"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/sirupsen/logrus"
)
type Manager interface {
NewClient(string, string, string, time.Time) Client
NewClientWithRefresh(context.Context, string, string) (Client, *AuthRefresh, error)
NewClientWithLogin(context.Context, string, []byte) (Client, *Auth, error)
DownloadAndVerify(kr *crypto.KeyRing, url, sig string) ([]byte, error)
ReportBug(context.Context, ReportBugReq) error
SendSimpleMetric(context.Context, string, string, string) error
SetLogging(logger *logrus.Entry, verbose bool)
SetTransport(http.RoundTripper)
SetCookieJar(http.CookieJar)
SetRetryCount(int)
AddConnectionObserver(ConnectionObserver)
AllowProxy()
DisallowProxy()
}

View File

@ -1,363 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"encoding/base64"
"errors"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2"
)
// Draft actions.
const (
DraftActionReply = 0
DraftActionReplyAll = 1
DraftActionForward = 2
)
// PackageFlag for send message package types.
type PackageFlag int
func (p *PackageFlag) Has(flag PackageFlag) bool { return iHasFlag(int(*p), int(flag)) }
func (p *PackageFlag) HasAtLeastOne(flag PackageFlag) bool {
return iHasAtLeastOneFlag(int(*p), int(flag))
}
func (p *PackageFlag) Is(flag PackageFlag) bool { return iIsFlag(int(*p), int(flag)) }
func (p *PackageFlag) HasNo(flag PackageFlag) bool { return iHasNoneOfFlag(int(*p), int(flag)) }
// Send message package types.
const (
InternalPackage = PackageFlag(1)
EncryptedOutsidePackage = PackageFlag(2)
ClearPackage = PackageFlag(4)
PGPInlinePackage = PackageFlag(8)
PGPMIMEPackage = PackageFlag(16)
ClearMIMEPackage = PackageFlag(32)
)
// SignatureFlag for send signature types.
type SignatureFlag int
func (p *SignatureFlag) Is(flag SignatureFlag) bool { return iIsFlag(int(*p), int(flag)) }
func (p *SignatureFlag) Has(flag SignatureFlag) bool { return iHasFlag(int(*p), int(flag)) }
func (p *SignatureFlag) HasNo(flag SignatureFlag) bool { return iHasNoneOfFlag(int(*p), int(flag)) }
// Send signature types.
const (
SignatureNone = SignatureFlag(0)
SignatureDetached = SignatureFlag(1)
SignatureAttachedArmored = SignatureFlag(2)
)
// DraftReq defines paylod for creating drafts.
type DraftReq struct {
Message *Message
ParentID string `json:",omitempty"`
Action int
AttachmentKeyPackets []string
}
func (c *client) CreateDraft(ctx context.Context, m *Message, parent string, action int) (created *Message, err error) {
var res struct {
Message *Message
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(&DraftReq{
Message: m,
ParentID: parent,
Action: action,
AttachmentKeyPackets: []string{},
}).SetResult(&res).Post("/mail/v4/messages")
}); err != nil {
return nil, err
}
return res.Message, nil
}
type AlgoKey struct {
Key string
Algorithm string
}
type MessageAddress struct {
Type PackageFlag
EncryptedBodyKeyPacket string `json:"BodyKeyPacket,omitempty"` // base64-encoded key packet.
Signature SignatureFlag
EncryptedAttachmentKeyPackets map[string]string `json:"AttachmentKeyPackets,omitempty"`
}
type MessagePackage struct {
Addresses map[string]*MessageAddress
Type PackageFlag
MIMEType string
EncryptedBody string `json:"Body"` // base64-encoded encrypted data packet.
DecryptedBodyKey *AlgoKey `json:"BodyKey,omitempty"` // base64-encoded session key (only if cleartext recipients).
DecryptedAttachmentKeys map[string]AlgoKey `json:"AttachmentKeys,omitempty"` // Only include if cleartext & attachments.
}
func newMessagePackage(
send sendData,
attKeys map[string]AlgoKey,
) (pkg *MessagePackage) {
pkg = &MessagePackage{
EncryptedBody: base64.StdEncoding.EncodeToString(send.ciphertext),
Addresses: send.addressMap,
MIMEType: send.contentType,
Type: send.sharedScheme,
}
if send.sharedScheme.HasAtLeastOne(ClearPackage | ClearMIMEPackage) {
pkg.DecryptedBodyKey = &AlgoKey{
Key: send.decryptedBodyKey.GetBase64Key(),
Algorithm: send.decryptedBodyKey.Algo,
}
}
if len(attKeys) != 0 && send.sharedScheme.Has(ClearPackage) {
pkg.DecryptedAttachmentKeys = attKeys
}
return pkg
}
type sendData struct {
decryptedBodyKey *crypto.SessionKey // body session key
addressMap map[string]*MessageAddress
sharedScheme PackageFlag
ciphertext []byte
cleartext string
contentType string
}
type SendMessageReq struct {
ExpirationTime int64 `json:",omitempty"`
// AutoSaveContacts int `json:",omitempty"`
// Data for encrypted recipients.
Packages []*MessagePackage `json:",omitempty"`
mime, plain, rich sendData
attKeys map[string]*crypto.SessionKey
kr *crypto.KeyRing
}
func NewSendMessageReq(
kr *crypto.KeyRing,
mimeBody, plainBody, richBody string,
attKeys map[string]*crypto.SessionKey,
) *SendMessageReq {
req := &SendMessageReq{}
req.mime.addressMap = make(map[string]*MessageAddress)
req.plain.addressMap = make(map[string]*MessageAddress)
req.rich.addressMap = make(map[string]*MessageAddress)
req.mime.cleartext = mimeBody
req.plain.cleartext = plainBody
req.rich.cleartext = richBody
req.attKeys = attKeys
req.kr = kr
return req
}
var (
errUnknownContentType = errors.New("unknown content type")
errMultipartInNonMIME = errors.New("multipart mixed not allowed in this scheme")
errAttSignNotSupported = errors.New("attached signature not supported")
errEncryptMustSign = errors.New("encrypted package must be signed")
errEncryptedOutsideNotSupported = errors.New("encrypted outside is not supported")
errWrongSendScheme = errors.New("wrong send scheme")
errInternalMustEncrypt = errors.New("internal package must be encrypted")
errInlineMustBePlain = errors.New("PGP Inline package must be plain text")
errMissingPubkey = errors.New("cannot encrypt body key packet: missing pubkey")
errClearSignMustNotBeHTML = errors.New("clear signed packet must be multipart or plain")
errMIMEMustBeMultipart = errors.New("MIME packet must be multipart")
errClearMIMEMustSign = errors.New("clear MIME must be signed")
errClearSignMustNotBePGPInline = errors.New("clear sign must not be PGP inline")
)
func (req *SendMessageReq) AddRecipient(
email string, sendScheme PackageFlag,
pubkey *crypto.KeyRing, signature SignatureFlag,
contentType string, doEncrypt bool,
) (err error) {
if signature.Has(SignatureAttachedArmored) {
return errAttSignNotSupported
}
if doEncrypt && signature.HasNo(SignatureDetached) {
return errEncryptMustSign
}
switch sendScheme {
case PGPMIMEPackage, ClearMIMEPackage:
if contentType != ContentTypeMultipartMixed {
return errMIMEMustBeMultipart
}
return req.addMIMERecipient(email, sendScheme, pubkey, signature)
case InternalPackage, ClearPackage, PGPInlinePackage:
if contentType == ContentTypeMultipartMixed {
return errMultipartInNonMIME
}
return req.addNonMIMERecipient(email, sendScheme, pubkey, signature, contentType, doEncrypt)
case EncryptedOutsidePackage:
return errEncryptedOutsideNotSupported
default:
return errWrongSendScheme
}
}
func (req *SendMessageReq) addNonMIMERecipient(
email string, sendScheme PackageFlag,
pubkey *crypto.KeyRing, signature SignatureFlag,
contentType string, doEncrypt bool,
) (err error) {
if signature.Is(SignatureDetached) && !doEncrypt {
if sendScheme.Is(PGPInlinePackage) {
return errClearSignMustNotBePGPInline
}
if sendScheme.Is(ClearPackage) && contentType == ContentTypeHTML {
return errClearSignMustNotBeHTML
}
}
var send *sendData
switch contentType {
case ContentTypePlainText:
send = &req.plain
send.contentType = ContentTypePlainText
case ContentTypeHTML, "":
send = &req.rich
send.contentType = ContentTypeHTML
case ContentTypeMultipartMixed:
return errMultipartInNonMIME
default:
return errUnknownContentType
}
if send.decryptedBodyKey == nil {
if send.decryptedBodyKey, send.ciphertext, err = encryptSymmDecryptKey(req.kr, send.cleartext); err != nil {
return err
}
}
newAddress := &MessageAddress{Type: sendScheme, Signature: signature}
if sendScheme.Is(PGPInlinePackage) && contentType == ContentTypeHTML {
return errInlineMustBePlain
}
if sendScheme.Is(InternalPackage) && !doEncrypt {
return errInternalMustEncrypt
}
if doEncrypt && pubkey == nil {
return errMissingPubkey
}
if doEncrypt {
newAddress.EncryptedBodyKeyPacket, newAddress.EncryptedAttachmentKeyPackets, err = encryptAndEncodeSessionKeys(pubkey, send.decryptedBodyKey, req.attKeys)
if err != nil {
return err
}
}
send.addressMap[email] = newAddress
send.sharedScheme |= sendScheme
return nil
}
func (req *SendMessageReq) addMIMERecipient(
email string, sendScheme PackageFlag,
pubkey *crypto.KeyRing, signature SignatureFlag,
) (err error) {
if sendScheme.Is(ClearMIMEPackage) && signature.HasNo(SignatureDetached) {
return errClearMIMEMustSign
}
req.mime.contentType = ContentTypeMultipartMixed
if req.mime.decryptedBodyKey == nil {
if req.mime.decryptedBodyKey, req.mime.ciphertext, err = encryptSymmDecryptKey(req.kr, req.mime.cleartext); err != nil {
return err
}
}
if sendScheme.Is(PGPMIMEPackage) {
if pubkey == nil {
return errMissingPubkey
}
// Attachment keys are not needed because attachments are part
// of MIME body and therefore attachments are encrypted with
// body session key.
mimeBodyPacket, _, err := encryptAndEncodeSessionKeys(pubkey, req.mime.decryptedBodyKey, map[string]*crypto.SessionKey{})
if err != nil {
return err
}
req.mime.addressMap[email] = &MessageAddress{Type: sendScheme, EncryptedBodyKeyPacket: mimeBodyPacket, Signature: signature}
} else {
req.mime.addressMap[email] = &MessageAddress{Type: sendScheme, Signature: signature}
}
req.mime.sharedScheme |= sendScheme
return nil
}
func (req *SendMessageReq) PreparePackages() {
attkeysEncoded := make(map[string]AlgoKey)
for attID, attkey := range req.attKeys {
attkeysEncoded[attID] = AlgoKey{
Key: attkey.GetBase64Key(),
Algorithm: attkey.Algo,
}
}
for _, send := range []sendData{req.mime, req.plain, req.rich} {
if len(send.addressMap) == 0 {
continue
}
req.Packages = append(req.Packages, newMessagePackage(send, attkeysEncoded))
}
}
func (c *client) SendMessage(ctx context.Context, draftID string, req *SendMessageReq) (*Message, *Message, error) {
if draftID == "" {
return nil, nil, errors.New("pmapi: cannot send message with an empty draftID")
}
if req.Packages == nil {
req.Packages = []*MessagePackage{}
}
var res struct {
Sent *Message
Parent *Message
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).Post("/mail/v4/messages/" + draftID)
}); err != nil {
return nil, nil, err
}
return res.Sent, res.Parent, nil
}

View File

@ -1,632 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"encoding/base64"
"encoding/json"
"testing"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/stretchr/testify/require"
)
type recipient struct {
email string
sendScheme PackageFlag
pubkey *crypto.KeyRing
signature SignatureFlag
contentType string
doEncrypt bool
wantError error
}
type testData struct {
emails []string
recipients []recipient
wantPackages []*MessagePackage
allRecipients map[string]recipient
allAddresses map[string]*MessageAddress
attKeys map[string]*crypto.SessionKey
mimeBody, plainBody, richBody string
}
func (td *testData) addRecipients(t testing.TB) {
for _, email := range td.emails {
rcp, ok := td.allRecipients[email]
require.True(t, ok, "missing recipient %s", email)
rcp.email = email
td.recipients = append(td.recipients, rcp)
}
}
func (td *testData) addAddresses(t testing.TB) {
for i, wantPackage := range td.wantPackages {
for email := range wantPackage.Addresses {
address, ok := td.allAddresses[email]
require.True(t, ok, "missing address %s", email)
td.wantPackages[i].Addresses[email] = address
}
}
}
func (td *testData) prepareAndCheck(t *testing.T) {
r := require.New(t)
matchPresence := func(want string) require.ValueAssertionFunc {
if len(want) == 0 {
return require.Empty
}
return require.NotEmpty
}
have := NewSendMessageReq(testPrivateKeyRing, td.mimeBody, td.plainBody, td.richBody, td.attKeys)
for _, rec := range td.recipients {
err := have.AddRecipient(rec.email, rec.sendScheme, rec.pubkey, rec.signature, rec.contentType, rec.doEncrypt)
if rec.wantError == nil {
r.NoError(err, "email %s", rec.email)
} else {
r.EqualError(err, rec.wantError.Error(), "email %s", rec.email)
}
}
have.PreparePackages()
r.Equal(len(td.wantPackages), len(have.Packages))
for i, wantPackage := range td.wantPackages {
havePackage := have.Packages[i]
r.Equal(wantPackage.MIMEType, havePackage.MIMEType, "pkg %d", i)
r.Equal(wantPackage.Type, havePackage.Type, "pkg %d", i)
r.Equal(len(wantPackage.Addresses), len(havePackage.Addresses), "pkg %d", i)
for email, wantAddress := range wantPackage.Addresses {
haveAddress, ok := havePackage.Addresses[email]
r.True(ok, "pkg %d email %s", i, email)
r.Equal(wantAddress.Type, haveAddress.Type, "pkg %d email %s", i, email)
matchPresence(wantAddress.EncryptedBodyKeyPacket)(t, haveAddress.EncryptedBodyKeyPacket, "pkg %d email %s", i, email)
r.Equal(wantAddress.Signature, haveAddress.Signature, "pkg %d email %s", i, email)
if len(td.attKeys) == 0 {
r.Len(haveAddress.EncryptedAttachmentKeyPackets, 0)
} else {
r.Equal(
len(wantAddress.EncryptedAttachmentKeyPackets),
len(haveAddress.EncryptedAttachmentKeyPackets),
"pkg %d email %s", i, email,
)
for attID, wantAttKey := range wantAddress.EncryptedAttachmentKeyPackets {
haveAttKey, ok := haveAddress.EncryptedAttachmentKeyPackets[attID]
r.True(ok, "pkg %d email %s att %s", i, email, attID)
matchPresence(wantAttKey)(t, haveAttKey, "pkg %d email %s att %s", i, email, attID)
}
}
}
matchPresence(wantPackage.EncryptedBody)(t, havePackage.EncryptedBody, "pkg %d", i)
wantBodyKey := wantPackage.DecryptedBodyKey
haveBodyKey := havePackage.DecryptedBodyKey
if wantBodyKey == nil {
r.Nil(haveBodyKey, "pkg %d: expected empty body key but got %v", i, haveBodyKey)
} else {
r.NotNil(haveBodyKey, "pkg %d: expected body key but got nil", i)
r.NotEmpty(haveBodyKey.Algorithm, "pkg %d", i)
r.NotEmpty(haveBodyKey.Key, "pkg %d", i)
}
if len(td.attKeys) == 0 {
r.Len(havePackage.DecryptedAttachmentKeys, 0)
} else {
r.Equal(
len(wantPackage.DecryptedAttachmentKeys),
len(havePackage.DecryptedAttachmentKeys),
"pkg %d", i,
)
for attID, wantAttKey := range wantPackage.DecryptedAttachmentKeys {
haveAttKey, ok := havePackage.DecryptedAttachmentKeys[attID]
r.True(ok, "pkg %d att %s", i, attID)
matchPresence(wantAttKey.Key)(t, haveAttKey.Key, "pkg %d att %s", i, attID)
matchPresence(wantAttKey.Algorithm)(t, haveAttKey.Algorithm, "pkg %d att %s", i, attID)
}
}
}
haveBytes, err := json.Marshal(have)
r.NoError(err)
haveString := string(haveBytes)
// Added `:` to avoid false-fail if the whole output results to empty object.
r.NotContains(haveString, ":\"\"", "found empty string: %s", haveString)
r.NotContains(haveString, ":[]", "found empty list: %s", haveString)
r.NotContains(haveString, ":{}", "found empty object: %s", haveString)
r.NotContains(haveString, ":null", "found null: %s", haveString)
}
func TestSendReq(t *testing.T) {
attKeyB64 := "EvjO/2RIJNn6HdoU6ACqFdZglzJhpjQ/PpjsvL3mB5Q="
token, err := base64.StdEncoding.DecodeString(attKeyB64)
require.NoError(t, err)
attKey := crypto.NewSessionKeyFromToken(token, "aes256")
attKeyPackets := map[string]string{"attID": "not-empty"}
attAlgoKeys := map[string]AlgoKey{"attID": {"not-empty", "not-empty"}}
allRecipients := map[string]recipient{
// Internal OK
"none@pm.me": {"", InternalPackage, testPublicKeyRing, SignatureDetached, "", true, nil},
"html@pm.me": {"", InternalPackage, testPublicKeyRing, SignatureDetached, ContentTypeHTML, true, nil},
"plain@pm.me": {"", InternalPackage, testPublicKeyRing, SignatureDetached, ContentTypePlainText, true, nil},
// Internal bad
"wrongtype@pm.me": {"", InternalPackage, testPublicKeyRing, SignatureDetached, "application/rfc822", true, errUnknownContentType},
"multipart@pm.me": {"", InternalPackage, testPublicKeyRing, SignatureDetached, ContentTypeMultipartMixed, true, errMultipartInNonMIME},
"noencrypt@pm.me": {"", InternalPackage, testPublicKeyRing, SignatureDetached, ContentTypeHTML, false, errInternalMustEncrypt},
"no-pubkey@pm.me": {"", InternalPackage, nil, SignatureDetached, ContentTypeHTML, true, errMissingPubkey},
"nosigning@pm.me": {"", InternalPackage, testPublicKeyRing, SignatureNone, ContentTypeHTML, true, errEncryptMustSign},
// testing combination
"internal1@pm.me": {"", InternalPackage, testPublicKeyRing, SignatureDetached, ContentTypePlainText, true, nil},
// Clear OK
"html@email.com": {"", ClearPackage, nil, SignatureNone, ContentTypeHTML, false, nil},
"none@email.com": {"", ClearPackage, nil, SignatureNone, "", false, nil},
"plain@email.com": {"", ClearPackage, nil, SignatureNone, ContentTypePlainText, false, nil},
"plain-sign@email.com": {"", ClearPackage, nil, SignatureDetached, ContentTypePlainText, false, nil},
"mime-sign@email.com": {"", ClearMIMEPackage, nil, SignatureDetached, ContentTypeMultipartMixed, false, nil},
// Clear bad
"mime@email.com": {"", ClearMIMEPackage, nil, SignatureNone, ContentTypeMultipartMixed, false, errClearMIMEMustSign},
"clear-plain-sign@email.com": {"", PGPInlinePackage, nil, SignatureDetached, ContentTypePlainText, false, errClearSignMustNotBePGPInline},
"html-sign@email.com": {"", ClearPackage, nil, SignatureDetached, ContentTypeHTML, false, errClearSignMustNotBeHTML},
"mime-plain@email.com": {"", ClearMIMEPackage, nil, SignatureDetached, ContentTypePlainText, false, errMIMEMustBeMultipart},
"mime-html@email.com": {"", ClearMIMEPackage, nil, SignatureDetached, ContentTypeHTML, false, errMIMEMustBeMultipart},
// External Encryption OK
"mime@gpg.com": {"", PGPMIMEPackage, testPublicKeyRing, SignatureDetached, ContentTypeMultipartMixed, true, nil},
"plain@gpg.com": {"", PGPInlinePackage, testPublicKeyRing, SignatureDetached, ContentTypePlainText, true, nil},
// External Encryption bad
"eo@gpg.com": {"", EncryptedOutsidePackage, testPublicKeyRing, SignatureDetached, ContentTypeHTML, true, errEncryptedOutsideNotSupported},
"inline-html@gpg.com": {"", PGPInlinePackage, testPublicKeyRing, SignatureDetached, ContentTypeHTML, true, errInlineMustBePlain},
"inline-mixed@gpg.com": {"", PGPInlinePackage, testPublicKeyRing, SignatureDetached, ContentTypeMultipartMixed, true, errMultipartInNonMIME},
"mime-plain@gpg.com": {"", PGPMIMEPackage, nil, SignatureDetached, ContentTypePlainText, true, errMIMEMustBeMultipart},
"mime-html@sgpg.com": {"", PGPMIMEPackage, nil, SignatureDetached, ContentTypeHTML, true, errMIMEMustBeMultipart},
"no-pubkey@gpg.com": {"", PGPMIMEPackage, nil, SignatureDetached, ContentTypeMultipartMixed, true, errMissingPubkey},
"not-signed@gpg.com": {"", PGPMIMEPackage, testPublicKeyRing, SignatureNone, ContentTypeMultipartMixed, true, errEncryptMustSign},
}
allAddresses := map[string]*MessageAddress{
"none@pm.me": {
Type: InternalPackage,
Signature: SignatureDetached,
EncryptedBodyKeyPacket: "not-empty",
EncryptedAttachmentKeyPackets: attKeyPackets,
},
"plain@pm.me": {
Type: InternalPackage,
Signature: SignatureDetached,
EncryptedBodyKeyPacket: "not-empty",
EncryptedAttachmentKeyPackets: attKeyPackets,
},
"html@pm.me": {
Type: InternalPackage,
Signature: SignatureDetached,
EncryptedBodyKeyPacket: "not-empty",
EncryptedAttachmentKeyPackets: attKeyPackets,
},
"internal1@pm.me": {
Type: InternalPackage,
Signature: SignatureDetached,
EncryptedBodyKeyPacket: "not-empty",
EncryptedAttachmentKeyPackets: attKeyPackets,
},
"html@email.com": {
Type: ClearPackage,
Signature: SignatureNone,
},
"none@email.com": {
Type: ClearPackage,
Signature: SignatureNone,
},
"plain@email.com": {
Type: ClearPackage,
Signature: SignatureNone,
},
"plain-sign@email.com": {
Type: ClearPackage,
Signature: SignatureDetached,
},
"mime-sign@email.com": {
Type: ClearMIMEPackage,
Signature: SignatureDetached,
},
"mime@gpg.com": {
Type: PGPMIMEPackage,
Signature: SignatureDetached,
EncryptedBodyKeyPacket: "non-empty",
},
"plain@gpg.com": {
Type: PGPInlinePackage,
Signature: SignatureDetached,
EncryptedBodyKeyPacket: "non-empty",
EncryptedAttachmentKeyPackets: attKeyPackets,
},
}
// NOTE naming
// Single: there should be one package
// Multiple: there should be more than one package
// Internal: there should be internal package
// Clear: there should be non-encrypted package
// Encrypted: there should be encrypted package
// NotAllowed: combination of inputs which are not allowed
newTests := map[string]testData{
"Nothing": { // expect no crash
emails: []string{},
wantPackages: []*MessagePackage{},
},
"Fails": {
emails: []string{
"wrongtype@pm.me",
"multipart@pm.me",
"noencrypt@pm.me",
"no-pubkey@pm.me",
"nosigning@pm.me",
"html-sign@email.com",
"mime-plain@email.com",
"mime-html@email.com",
"mime@email.com",
"clear-plain-sign@email.com",
"eo@gpg.com",
"inline-html@gpg.com",
"inline-mixed@gpg.com",
"mime-plain@gpg.com",
"mime-html@sgpg.com",
"no-pubkey@gpg.com",
"not-signed@gpg.com",
},
},
// one scheme in one package
"SingleInternalHTML": {
emails: []string{"none@pm.me", "html@pm.me"},
wantPackages: []*MessagePackage{
{
Addresses: map[string]*MessageAddress{
"none@pm.me": nil,
"html@pm.me": nil,
},
Type: InternalPackage,
MIMEType: ContentTypeHTML,
EncryptedBody: "non-empty",
},
},
},
"SingleInternalPlain": {
emails: []string{"plain@pm.me"},
wantPackages: []*MessagePackage{
{
Addresses: map[string]*MessageAddress{
"plain@pm.me": nil,
},
Type: InternalPackage,
MIMEType: ContentTypePlainText,
EncryptedBody: "non-empty",
},
},
},
"SingleClearHTML": {
emails: []string{"none@email.com", "html@email.com"},
wantPackages: []*MessagePackage{
{
Addresses: map[string]*MessageAddress{
"html@email.com": nil,
"none@email.com": nil,
},
Type: ClearPackage,
MIMEType: ContentTypeHTML,
EncryptedBody: "non-empty",
DecryptedBodyKey: &AlgoKey{"non-empty", "non-empty"},
DecryptedAttachmentKeys: attAlgoKeys,
},
},
},
"SingleClearPlain": {
emails: []string{"plain@email.com", "plain-sign@email.com"},
wantPackages: []*MessagePackage{
{
Addresses: map[string]*MessageAddress{
"plain@email.com": nil,
"plain-sign@email.com": nil,
},
Type: ClearPackage,
MIMEType: ContentTypePlainText,
EncryptedBody: "non-empty",
DecryptedBodyKey: &AlgoKey{"non-empty", "non-empty"},
DecryptedAttachmentKeys: attAlgoKeys,
},
},
},
"SingleClearMIME": {
emails: []string{"mime-sign@email.com"},
wantPackages: []*MessagePackage{
{
Addresses: map[string]*MessageAddress{
"mime-sign@email.com": nil,
},
Type: ClearMIMEPackage,
MIMEType: ContentTypeMultipartMixed,
EncryptedBody: "non-empty",
DecryptedBodyKey: &AlgoKey{"non-empty", "non-empty"},
},
},
},
"SingleEncyptedPlain": {
emails: []string{"plain@gpg.com"},
wantPackages: []*MessagePackage{
{
Addresses: map[string]*MessageAddress{
"plain@gpg.com": nil,
},
Type: PGPInlinePackage,
MIMEType: ContentTypePlainText,
EncryptedBody: "non-empty",
},
},
},
"SingleEncyptedMIME": {
emails: []string{"mime@gpg.com"},
wantPackages: []*MessagePackage{
{
Addresses: map[string]*MessageAddress{
"mime@gpg.com": nil,
},
Type: PGPMIMEPackage,
MIMEType: ContentTypeMultipartMixed,
EncryptedBody: "non-empty",
},
},
},
// two schemes combined to one package
"SingleClearInternalPlain": {
emails: []string{"plain@email.com", "plain-sign@email.com", "plain@pm.me"},
wantPackages: []*MessagePackage{
{
Addresses: map[string]*MessageAddress{
"plain@pm.me": nil,
"plain@email.com": nil,
"plain-sign@email.com": nil,
},
Type: InternalPackage | ClearPackage,
MIMEType: ContentTypePlainText,
EncryptedBody: "non-empty",
DecryptedBodyKey: &AlgoKey{"non-empty", "non-empty"},
DecryptedAttachmentKeys: attAlgoKeys,
},
},
},
"SingleClearInternalHTML": {
emails: []string{"none@email.com", "html@email.com", "html@pm.me", "none@pm.me"},
wantPackages: []*MessagePackage{
{
Addresses: map[string]*MessageAddress{
"none@pm.me": nil,
"html@pm.me": nil,
"html@email.com": nil,
"none@email.com": nil,
},
Type: InternalPackage | ClearPackage,
MIMEType: ContentTypeHTML,
EncryptedBody: "non-empty",
DecryptedBodyKey: &AlgoKey{"non-empty", "non-empty"},
DecryptedAttachmentKeys: attAlgoKeys,
},
},
},
"SingleEncryptedInternalPlain": {
emails: []string{"plain@gpg.com", "plain@pm.me"},
wantPackages: []*MessagePackage{
{
Addresses: map[string]*MessageAddress{
"plain@pm.me": nil,
"plain@gpg.com": nil,
},
Type: InternalPackage | PGPInlinePackage,
MIMEType: ContentTypePlainText,
EncryptedBody: "non-empty",
},
},
},
"SingleEncryptedClearMIME": {
emails: []string{"mime@gpg.com", "mime-sign@email.com"},
wantPackages: []*MessagePackage{
{
Addresses: map[string]*MessageAddress{
"mime@gpg.com": nil,
"mime-sign@email.com": nil,
},
Type: ClearMIMEPackage | PGPMIMEPackage,
MIMEType: ContentTypeMultipartMixed,
EncryptedBody: "non-empty",
DecryptedBodyKey: &AlgoKey{"non-empty", "non-empty"},
},
},
},
// one scheme separated to multiple packages
"MultipleInternal": {
emails: []string{"none@pm.me", "html@pm.me", "plain@pm.me"},
wantPackages: []*MessagePackage{
{
Addresses: map[string]*MessageAddress{
"plain@pm.me": nil,
},
Type: InternalPackage,
MIMEType: ContentTypePlainText,
EncryptedBody: "non-empty",
},
{
Addresses: map[string]*MessageAddress{
"none@pm.me": nil,
"html@pm.me": nil,
},
Type: InternalPackage,
MIMEType: ContentTypeHTML,
EncryptedBody: "non-empty",
},
},
},
"MultipleClear": {
emails: []string{
"none@email.com", "html@email.com",
"plain@email.com", "plain-sign@email.com",
"mime-sign@email.com",
},
wantPackages: []*MessagePackage{
{
Addresses: map[string]*MessageAddress{
"mime-sign@email.com": nil,
},
Type: ClearMIMEPackage,
MIMEType: ContentTypeMultipartMixed,
EncryptedBody: "non-empty",
DecryptedBodyKey: &AlgoKey{"non-empty", "non-empty"},
},
{
Addresses: map[string]*MessageAddress{
"plain@email.com": nil,
"plain-sign@email.com": nil,
},
Type: ClearPackage,
MIMEType: ContentTypePlainText,
EncryptedBody: "non-empty",
DecryptedBodyKey: &AlgoKey{"non-empty", "non-empty"},
DecryptedAttachmentKeys: attAlgoKeys,
},
{
Addresses: map[string]*MessageAddress{
"html@email.com": nil,
"none@email.com": nil,
},
Type: ClearPackage,
MIMEType: ContentTypeHTML,
EncryptedBody: "non-empty",
DecryptedBodyKey: &AlgoKey{"non-empty", "non-empty"},
DecryptedAttachmentKeys: attAlgoKeys,
},
},
},
"MultipleEncrypted": {
emails: []string{"plain@gpg.com", "mime@gpg.com"},
wantPackages: []*MessagePackage{
{
Addresses: map[string]*MessageAddress{
"mime@gpg.com": nil,
},
Type: PGPMIMEPackage,
MIMEType: ContentTypeMultipartMixed,
EncryptedBody: "non-empty",
},
{
Addresses: map[string]*MessageAddress{
"plain@gpg.com": nil,
},
Type: PGPInlinePackage,
MIMEType: ContentTypePlainText,
EncryptedBody: "non-empty",
},
},
},
"MultipleComboAll": {
emails: []string{
"none@pm.me",
"plain@pm.me",
"html@pm.me",
"none@email.com",
"html@email.com",
"plain@email.com",
"plain-sign@email.com",
"mime-sign@email.com",
"mime@gpg.com",
"plain@gpg.com",
},
wantPackages: []*MessagePackage{
{
Addresses: map[string]*MessageAddress{
"mime@gpg.com": nil,
"mime-sign@email.com": nil,
},
Type: ClearMIMEPackage | PGPMIMEPackage,
MIMEType: ContentTypeMultipartMixed,
EncryptedBody: "non-empty",
DecryptedBodyKey: &AlgoKey{"non-empty", "non-empty"},
},
{
Addresses: map[string]*MessageAddress{
"plain@gpg.com": nil,
"plain@email.com": nil,
"plain-sign@email.com": nil,
"plain@pm.me": nil,
},
Type: InternalPackage | ClearPackage | PGPInlinePackage,
MIMEType: ContentTypePlainText,
EncryptedBody: "non-empty",
DecryptedBodyKey: &AlgoKey{"non-empty", "non-empty"},
DecryptedAttachmentKeys: attAlgoKeys,
},
{
Addresses: map[string]*MessageAddress{
"none@pm.me": nil,
"html@pm.me": nil,
"none@email.com": nil,
"html@email.com": nil,
},
Type: InternalPackage | ClearPackage,
MIMEType: ContentTypeHTML,
EncryptedBody: "non-empty",
DecryptedBodyKey: &AlgoKey{"non-empty", "non-empty"},
DecryptedAttachmentKeys: attAlgoKeys,
},
},
},
}
for name, test := range newTests {
test.mimeBody = "Mime body"
test.plainBody = "Plain body"
test.richBody = "HTML body"
test.allRecipients = allRecipients
test.allAddresses = allAddresses
test.addRecipients(t)
test.addAddresses(t)
t.Run("NoAtt"+name, test.prepareAndCheck)
test.attKeys = map[string]*crypto.SessionKey{"attID": attKey}
t.Run("Att"+name, test.prepareAndCheck)
}
}

View File

@ -1,739 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/mail"
"net/url"
"regexp"
"strconv"
"strings"
"github.com/ProtonMail/go-crypto/openpgp"
"github.com/ProtonMail/go-crypto/openpgp/armor"
"github.com/ProtonMail/go-crypto/openpgp/packet"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2"
"github.com/sirupsen/logrus"
)
// Header types.
const (
MessageHeader = "-----BEGIN PGP MESSAGE-----"
MessageTail = "-----END PGP MESSAGE-----"
MessageHeaderLegacy = "---BEGIN ENCRYPTED MESSAGE---"
MessageTailLegacy = "---END ENCRYPTED MESSAGE---"
RandomKeyHeader = "---BEGIN ENCRYPTED RANDOM KEY---"
RandomKeyTail = "---END ENCRYPTED RANDOM KEY---"
)
// Sort types.
const (
SortByTo = "To"
SortByFrom = "From"
SortBySubject = "Subject"
SortBySize = "Size"
SortByTime = "Time"
SortByID = "ID"
SortDesc = true
SortAsc = false
)
// Message actions.
const (
ActionReply = 0
ActionReplyAll = 1
ActionForward = 2
)
// Message flag definitions.
const (
FlagReceived = int64(1)
FlagSent = int64(2)
FlagInternal = int64(4)
FlagE2E = int64(8)
FlagAuto = int64(16)
FlagReplied = int64(32)
FlagRepliedAll = int64(64)
FlagForwarded = int64(128)
FlagAutoreplied = int64(256)
FlagImported = int64(512)
FlagOpened = int64(1024)
FlagReceiptSent = int64(2048)
)
// Draft flags.
const (
FlagReceiptRequest = 1 << 16
FlagPublicKey = 1 << 17
FlagSign = 1 << 18
)
// Spam flags.
const (
FlagSpfFail = 1 << 24
FlagDkimFail = 1 << 25
FlagDmarcFail = 1 << 26
FlagHamManual = 1 << 27
FlagSpamAuto = 1 << 28
FlagSpamManual = 1 << 29
FlagPhishingAuto = 1 << 30
FlagPhishingManual = 1 << 31
)
// Message flag masks.
const (
FlagMaskGeneral = 4095
FlagMaskDraft = FlagReceiptRequest * 7
FlagMaskSpam = FlagSpfFail * 255
FlagMask = FlagMaskGeneral | FlagMaskDraft | FlagMaskSpam
)
// INTERNAL, AUTO are immutable. E2E is immutable except for drafts on send.
const (
FlagMaskAdd = 4067 + (16777216 * 168)
)
// Content types.
const (
ContentTypeMultipartMixed = "multipart/mixed"
ContentTypeMultipartEncrypted = "multipart/encrypted"
ContentTypePlainText = "text/plain"
ContentTypeHTML = "text/html"
)
// LabelsOperation is the operation to apply to labels.
type LabelsOperation int
const (
KeepLabels LabelsOperation = iota // KeepLabels Do nothing.
ReplaceLabels // ReplaceLabels Replace current labels with new ones.
AddLabels // AddLabels Add new labels to current ones.
RemoveLabels // RemoveLabels Remove specified labels from current ones.
)
// Due to API limitations, we shouldn't make requests with more than 100 message IDs at a time.
const messageIDPageSize = 100
// ConversationIDDomain is used as a placeholder for conversation reference headers to improve compatibility with various clients.
const ConversationIDDomain = `protonmail.conversationid`
// InternalIDDomain is used as a placeholder for reference/message ID headers to improve compatibility with various clients.
const InternalIDDomain = `protonmail.internalid`
// RxInternalReferenceFormat is compiled regexp which describes the match for
// a message ID used in reference headers.
var RxInternalReferenceFormat = regexp.MustCompile(`(?U)<(.+)@` + regexp.QuoteMeta(InternalIDDomain) + `>`) //nolint:gochecknoglobals
// Message structure.
type Message struct {
ID string `json:",omitempty"`
Order int64 `json:",omitempty"`
ConversationID string `json:",omitempty"` // only filter
Subject string
Unread Boolean
Flags int64
Sender *mail.Address
ReplyTo *mail.Address `json:",omitempty"`
ReplyTos []*mail.Address `json:",omitempty"`
ToList []*mail.Address
CCList []*mail.Address
BCCList []*mail.Address
Time int64 // Unix time
NumAttachments int
ExpirationTime int64 // Unix time
SpamScore int
AddressID string
Body string `json:",omitempty"`
Attachments []*Attachment
LabelIDs []string
ExternalID string
Header mail.Header
MIMEType string
}
// NewMessage initializes a new message.
func NewMessage() *Message {
return &Message{
ToList: []*mail.Address{},
CCList: []*mail.Address{},
BCCList: []*mail.Address{},
Attachments: []*Attachment{},
LabelIDs: []string{},
}
}
// Define a new type to prevent MarshalJSON/UnmarshalJSON infinite loops.
type message Message
type rawMessage struct {
message
Header string `json:",omitempty"`
}
func (m *Message) MarshalJSON() ([]byte, error) {
var raw rawMessage
raw.message = message(*m)
b := &bytes.Buffer{}
_ = http.Header(m.Header).Write(b)
raw.Header = b.String()
return json.Marshal(&raw)
}
func (m *Message) UnmarshalJSON(b []byte) error {
var raw rawMessage
if err := json.Unmarshal(b, &raw); err != nil {
return err
}
*m = Message(raw.message)
if raw.Header != "" && raw.Header != "(No Header)" {
msg, err := mail.ReadMessage(strings.NewReader(raw.Header + "\r\n\r\n"))
if err != nil {
logrus.WithField("rawHeader", raw.Header).Trace("Failed to parse header")
return fmt.Errorf("failed to parse header of message %v: %v", m.ID, err.Error())
}
m.Header = msg.Header
} else {
m.Header = make(mail.Header)
}
return nil
}
// IsDraft returns whether the message should be considered to be a draft.
// A draft is complicated. It might have pmapi.DraftLabel but it might not.
// The real API definition of IsDraft is that it is neither sent nor received -- we should use that here.
func (m *Message) IsDraft() bool {
return (m.Flags & (FlagReceived | FlagSent)) == 0
}
// HasLabelID returns whether the message has the `labelID`.
func (m *Message) HasLabelID(labelID string) bool {
for _, l := range m.LabelIDs {
if l == labelID {
return true
}
}
return false
}
func (m *Message) IsEncrypted() bool {
return strings.HasPrefix(m.Header.Get("Content-Type"), "multipart/encrypted") || m.IsBodyEncrypted()
}
func (m *Message) IsBodyEncrypted() bool {
trimmedBody := strings.TrimSpace(m.Body)
return strings.HasPrefix(trimmedBody, MessageHeader) &&
strings.HasSuffix(trimmedBody, MessageTail)
}
func (m *Message) IsLegacyMessage() bool {
return strings.Contains(m.Body, RandomKeyHeader) &&
strings.Contains(m.Body, RandomKeyTail) &&
strings.Contains(m.Body, MessageHeaderLegacy) &&
strings.Contains(m.Body, MessageTailLegacy) &&
strings.Contains(m.Body, MessageHeader) &&
strings.Contains(m.Body, MessageTail)
}
func (m *Message) Decrypt(kr *crypto.KeyRing) ([]byte, error) {
if m.IsLegacyMessage() {
return m.decryptLegacy(kr)
}
if !m.IsBodyEncrypted() {
return []byte(m.Body), nil
}
armored := strings.TrimSpace(m.Body)
body, err := decrypt(kr, armored)
if err != nil {
return nil, err
}
return body, nil
}
type Signature struct {
Hash string
Data []byte
}
func (m *Message) ExtractSignatures(kr *crypto.KeyRing) ([]Signature, error) {
var entities openpgp.EntityList
for _, key := range kr.GetKeys() {
entities = append(entities, key.GetEntity())
}
p, err := armor.Decode(strings.NewReader(m.Body))
if err != nil {
return nil, err
}
msg, err := openpgp.ReadMessage(p.Body, entities, nil, nil)
if err != nil {
return nil, err
}
if _, err := io.ReadAll(msg.UnverifiedBody); err != nil {
return nil, err
}
if !msg.IsSigned {
return nil, nil
}
signatures := make([]Signature, 0, len(msg.UnverifiedSignatures))
for _, signature := range msg.UnverifiedSignatures {
buf := new(bytes.Buffer)
if err := signature.Serialize(buf); err != nil {
return nil, err
}
signatures = append(signatures, Signature{
Hash: signature.Hash.String(),
Data: buf.Bytes(),
})
}
return signatures, nil
}
func (m *Message) decryptLegacy(kr *crypto.KeyRing) (dec []byte, err error) {
randomKeyStart := strings.Index(m.Body, RandomKeyHeader) + len(RandomKeyHeader)
randomKeyEnd := strings.Index(m.Body, RandomKeyTail)
randomKey := m.Body[randomKeyStart:randomKeyEnd]
signedKey, err := decrypt(kr, strings.TrimSpace(randomKey))
if err != nil {
return
}
bytesKey, err := decodeBase64UTF8(string(signedKey))
if err != nil {
return
}
messageStart := strings.Index(m.Body, MessageHeaderLegacy) + len(MessageHeaderLegacy)
messageEnd := strings.Index(m.Body, MessageTailLegacy)
message := m.Body[messageStart:messageEnd]
bytesMessage, err := decodeBase64UTF8(message)
if err != nil {
return
}
block, err := aes.NewCipher(bytesKey)
if err != nil {
return
}
prefix := make([]byte, block.BlockSize()+2)
bytesMessageReader := bytes.NewReader(bytesMessage)
_, err = io.ReadFull(bytesMessageReader, prefix)
if err != nil {
return
}
s := packet.NewOCFBDecrypter(block, prefix, packet.OCFBResync)
if s == nil {
err = errors.New("pmapi: incorrect key for legacy decryption")
return
}
reader := cipher.StreamReader{S: s, R: bytesMessageReader}
buf := new(bytes.Buffer)
_, _ = buf.ReadFrom(reader)
plaintextBytes := buf.Bytes()
plaintext := ""
for i := 0; i < len(plaintextBytes); i++ {
plaintext += string(plaintextBytes[i])
}
bytesPlaintext, err := decodeBase64UTF8(plaintext)
if err != nil {
return
}
return bytesPlaintext, nil
}
func decodeBase64UTF8(input string) (output []byte, err error) {
input = strings.TrimSpace(input)
decodedMessage, err := base64.StdEncoding.DecodeString(input)
if err != nil {
return
}
utf8DecodedMessage := []rune(string(decodedMessage))
output = make([]byte, len(utf8DecodedMessage))
for i := 0; i < len(utf8DecodedMessage); i++ {
output[i] = byte(int(utf8DecodedMessage[i]))
}
return
}
func (m *Message) Encrypt(encrypter, signer *crypto.KeyRing) (err error) {
if m.IsBodyEncrypted() {
err = errors.New("pmapi: trying to encrypt an already encrypted message")
return
}
m.Body, err = encrypt(encrypter, m.Body, signer)
return
}
func (m *Message) Has(flag int64) bool {
return (m.Flags & flag) == flag
}
func (m *Message) Recipients() []*mail.Address {
var recipients []*mail.Address
recipients = append(recipients, m.ToList...)
recipients = append(recipients, m.CCList...)
recipients = append(recipients, m.BCCList...)
return recipients
}
// MessagesCount contains message counts for one label.
type MessagesCount struct {
LabelID string
Total int
Unread int
}
// MessagesFilter contains fields to filter messages.
type MessagesFilter struct {
Page int
PageSize int
Limit int
LabelID string
Sort string // Time by default (Time, To, From, Subject, Size).
Desc *bool
Begin int64 // Unix time.
End int64 // Unix time.
BeginID string
EndID string
Keyword string
To string
From string
Subject string
ConversationID string
AddressID string
ID []string
Attachments *bool
Unread *bool
ExternalID string // MIME Message-Id (only valid for messages).
AutoWildcard *bool
}
func (filter *MessagesFilter) urlValues() url.Values { //nolint:funlen
v := url.Values{}
if filter.Page != 0 {
v.Set("Page", strconv.Itoa(filter.Page))
}
if filter.PageSize != 0 {
v.Set("PageSize", strconv.Itoa(filter.PageSize))
}
if filter.Limit != 0 {
v.Set("Limit", strconv.Itoa(filter.Limit))
}
if filter.LabelID != "" {
v.Set("LabelID", filter.LabelID)
}
if filter.Sort != "" {
v.Set("Sort", filter.Sort)
}
if filter.Desc != nil {
if *filter.Desc {
v.Set("Desc", "1")
} else {
v.Set("Desc", "0")
}
}
if filter.Begin != 0 {
v.Set("Begin", strconv.Itoa(int(filter.Begin)))
}
if filter.End != 0 {
v.Set("End", strconv.Itoa(int(filter.End)))
}
if filter.BeginID != "" {
v.Set("BeginID", filter.BeginID)
}
if filter.EndID != "" {
v.Set("EndID", filter.EndID)
}
if filter.Keyword != "" {
v.Set("Keyword", filter.Keyword)
}
if filter.To != "" {
v.Set("To", filter.To)
}
if filter.From != "" {
v.Set("From", filter.From)
}
if filter.Subject != "" {
v.Set("Subject", filter.Subject)
}
if filter.ConversationID != "" {
v.Set("ConversationID", filter.ConversationID)
}
if filter.AddressID != "" {
v.Set("AddressID", filter.AddressID)
}
if len(filter.ID) > 0 {
for _, id := range filter.ID {
v.Add("ID[]", id)
}
}
if filter.Attachments != nil {
if *filter.Attachments {
v.Set("Attachments", "1")
} else {
v.Set("Attachments", "0")
}
}
if filter.Unread != nil {
if *filter.Unread {
v.Set("Unread", "1")
} else {
v.Set("Unread", "0")
}
}
if filter.ExternalID != "" {
v.Set("ExternalID", filter.ExternalID)
}
if filter.AutoWildcard != nil {
if *filter.AutoWildcard {
v.Set("AutoWildcard", "1")
} else {
v.Set("AutoWildcard", "0")
}
}
return v
}
// ListMessages gets message metadata.
func (c *client) ListMessages(ctx context.Context, filter *MessagesFilter) ([]*Message, int, error) {
var res struct {
Messages []*Message
Total int
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetQueryParamsFromValues(filter.urlValues()).
SetResult(&res).
Get("/mail/v4/messages")
}); err != nil {
return nil, 0, err
}
return res.Messages, res.Total, nil
}
// CountMessages counts messages by label.
func (c *client) CountMessages(ctx context.Context, addressID string) (counts []*MessagesCount, err error) {
var res struct {
Counts []*MessagesCount
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
if addressID != "" {
r = r.SetQueryParam("AddressID", addressID)
}
return r.SetResult(&res).Get("/mail/v4/messages/count")
}); err != nil {
return nil, err
}
return res.Counts, nil
}
// GetMessage retrieves a message.
func (c *client) GetMessage(ctx context.Context, messageID string) (msg *Message, err error) {
var res struct {
Message *Message
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/mail/v4/messages/" + messageID)
}); err != nil {
return nil, err
}
return res.Message, nil
}
type MessagesActionReq struct {
IDs []string
}
func (c *client) MarkMessagesRead(ctx context.Context, messageIDs []string) error {
return doPaged(messageIDs, defaultPageSize, func(messageIDs []string) (err error) {
req := MessagesActionReq{IDs: messageIDs}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).Put("/mail/v4/messages/read")
}); err != nil {
return err
}
return nil
})
}
func (c *client) MarkMessagesUnread(ctx context.Context, messageIDs []string) error {
return doPaged(messageIDs, defaultPageSize, func(messageIDs []string) (err error) {
req := MessagesActionReq{IDs: messageIDs}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).Put("/mail/v4/messages/unread")
}); err != nil {
return err
}
return nil
})
}
func (c *client) DeleteMessages(ctx context.Context, messageIDs []string) error {
return doPaged(messageIDs, defaultPageSize, func(messageIDs []string) (err error) {
req := MessagesActionReq{IDs: messageIDs}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).Put("/mail/v4/messages/delete")
}); err != nil {
return err
}
return nil
})
}
func (c *client) UndeleteMessages(ctx context.Context, messageIDs []string) error {
return doPaged(messageIDs, defaultPageSize, func(messageIDs []string) (err error) {
req := MessagesActionReq{IDs: messageIDs}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).Put("/mail/v4/messages/undelete")
}); err != nil {
return err
}
return nil
})
}
type LabelMessagesReq struct {
LabelID string
IDs []string
}
// LabelMessages labels the given message IDs with the given label.
// The requests are performed paged; this can eventually be done in parallel.
func (c *client) LabelMessages(ctx context.Context, messageIDs []string, labelID string) error {
return doPaged(messageIDs, defaultPageSize, func(messageIDs []string) (err error) {
req := LabelMessagesReq{
LabelID: labelID,
IDs: messageIDs,
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).Put("/mail/v4/messages/label")
}); err != nil {
return err
}
return nil
})
}
// UnlabelMessages removes the given label from the given message IDs.
// The requests are performed paged; this can eventually be done in parallel.
func (c *client) UnlabelMessages(ctx context.Context, messageIDs []string, labelID string) error {
return doPaged(messageIDs, defaultPageSize, func(messageIDs []string) (err error) {
req := LabelMessagesReq{
LabelID: labelID,
IDs: messageIDs,
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).Put("/mail/v4/messages/unlabel")
}); err != nil {
return err
}
return nil
})
}
func (c *client) EmptyFolder(ctx context.Context, labelID, addressID string) error {
if labelID == "" {
return errors.New("labelID parameter is empty string")
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
if addressID != "" {
r.SetQueryParam("AddressID", addressID)
}
return r.SetQueryParam("LabelID", labelID).Delete("/mail/v4/messages/empty")
}); err != nil {
return err
}
return nil
}
// ComputeMessageFlagsByLabels returns flags based on labels.
func ComputeMessageFlagsByLabels(labels []string) (flag int64) {
for _, labelID := range labels {
switch labelID {
case SentLabel:
flag = (flag | FlagSent)
case ArchiveLabel, InboxLabel:
flag = (flag | FlagReceived)
}
}
// NOTE: if the labels are custom only
if flag == 0 {
flag = FlagReceived
}
return flag
}

View File

@ -1,257 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"fmt"
"net/http"
"testing"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/stretchr/testify/require"
)
const (
testMessageCleartext = `<div>jeej saas<br></div><div><br></div><div class="protonmail_signature_block"><div>Sent from <a href="https://protonmail.ch">ProtonMail</a>, encrypted email based in Switzerland.<br></div><div><br></div></div>`
testMessageCleartextLegacy = `<div>flkasjfkjasdklfjasd<br></div><div>fasd<br></div><div>jfasjdfjasd<br></div><div>fj<br></div><div>asdfj<br></div><div>sadjf<br></div><div>sadjf<br></div><div>asjdf<br></div><div>jasd<br></div><div>fj<br></div><div>asdjf<br></div><div>asdjfsad<br></div><div>fasdlkfjasdjfkljsadfljsdfjsdljflkdsjfkljsdlkfjsdlk<br></div><div>jasfd<br></div><div>jsd<br></div><div>jf<br></div><div>sdjfjsdf<br></div><div><br></div><div>djfskjsladf<br></div><div>asd<br></div><div>fja<br></div><div>sdjfajsf<br></div><div>jas<br></div><div>fas<br></div><div>fj<br></div><div>afj<br></div><div>ajf<br></div><div>af<br></div><div>asdfasdfasd<br></div><div>Sent from <a href="https://protonmail.ch">ProtonMail</a>, encrypted email based in Switzerland.<br></div><div>dshfljsadfasdf<br></div><div>as<br></div><div>df<br></div><div>asd<br></div><div>fasd<br></div><div>f<br></div><div>asd<br></div><div>fasdflasdklfjsadlkjf</div><div>asd<br></div><div>fasdlkfjasdlkfjklasdjflkasjdflaslkfasdfjlasjflkasflksdjflkjasdf<br></div><div>asdflkasdjflajsfljaslkflasf<br></div><div>asdfkas<br></div><div>dfjas<br></div><div>djf<br></div><div>asjf<br></div><div>asj<br></div><div>faj<br></div><div>f<br></div><div>afj<br></div><div>sdjaf<br></div><div>jas<br></div><div>sdfj<br></div><div>ajf<br></div><div>aj<br></div><div>ajsdafafdaaf<br></div><div>a<br></div><div>f<br></div><div>lasl;ga<br></div><div>sags<br></div><div>ad<br></div><div>gags<br></div><div>g<br></div><div>ga<br></div><div>a<br></div><div>gg<br></div><div>a<br></div><div>ag<br></div><div>ag<br></div><div>agga.g.ga,ag.ag./ga<br></div><div><br></div><div>dsga<br></div><div>sg<br></div><div><br></div><div>gasga\g\g\g\g\g\n\y\t\r\\r\r\\n\n\n\<br></div><div><br></div><div><br></div><div>sd<br></div><div>asdf<br></div><div>asdf<br></div><div>dsa<br></div><div>fasd<br></div><div>f</div>`
)
const testMessageEncrypted = `-----BEGIN PGP MESSAGE-----
Version: OpenPGP.js v1.2.0
Comment: http://openpgpjs.org
wcBMA0fcZ7XLgmf2AQf+JPulpEOWwmY/Sfze8rBpYvrO2cebSSkjCgapFfXG
CI4PA+rb+WGkn9uBJf3FgEEg76c2ZqGh9zXTyrdHyFLm8ekarvxzgLpvcei/
p18IzcxsWnaM+1uknL4bKUtK3298gIl6xrfc4eVEA8tqUPUkSLSGk7uggjhj
zEYR4zIgMa0c6sMVcZ1Idvy9gGsTIvvcZJ4h1lKVUl8gba+qr1D76RaAf5xS
SBT74q9HhgfEMZwk6hXAp4MYY5h+lIsuhFu5kQ9fhZKU0PWS7ljddv854ZxS
9gHKPBerv4NBjkkCLp9xa2QNjDnu1fNlzlJpfCavp6wDdC83GiT61VRHPE4s
J9LASwFwgOrPmB8Mi867AQM0dddbj4Qe5ghlUcF1XnybkwfHqvQA1QT50d5n
ddFyxwIjvI/Nsn8MTCSnmrWCrjQ7v8JC73NyGxO5k6ZlUnc6BQVie78QJo5a
ftzl5b6nwlCYuXI8R6N/t5MXzrC5GwR8nvjH6kgbUVTLL1hO2Sbgyq5bBKLW
jjylTsZDHUGi4OX7q7eet5/RhKusWdvR0cHEaZAVD6BhTNN0mFBJ5bM1SINI
9gxJVqKJe7j4nJP4PGZBJrokZihhiBS/WEbJdvS54frYajGKjMavB3VhFP6k
qi5aiqGJKOJOV/G8yIwtdtxac3UL34eWo69U39Zx2mNfSXCzSjuafCr1nmAS
4g==
=Uw3B
-----END PGP MESSAGE-----
`
const testMessageEncryptedLegacy = `---BEGIN ENCRYPTED MESSAGE---esK5w7TCgVnDj8KQHBvDvhJObcOvw6/Cv2/CjMOpw5UES8KQwq/CiMOpI3MrexLDimzDmsKqVmwQw7vDkcKlRgXCosOpwoJgV8KEBCslSGbDtsOlw5gow7NxG8OSw6JNPlYuwrHCg8K5w6vDi8Kww5V5wo/Dl8KgwpnCi8Kww7nChMKdw5FHwoxmCGbCm8O6wpDDmRVEWsO7wqnCtVnDlMKORDbDnjbCqcOnNMKEwoPClFlaw6k1w5TDpcOGJsOUw5Unw5fCrcK3XnLCoRBBwo/DpsKAJiTDrUHDuGEQXz/DjMOhTCN7esO5ZjVIQSoFZMOyF8Kgw6nChcKmw6fCtcOBcW7Ck8KJwpTDnCzCnz3DjFY7wp5jUsOhw7XDosKQNsOUBmLDksKzPcO4fE/Dmw1GecKew4/CmcOJTFXDsB5uMcOFd1vDmX9ow4bDpCPDoU3Drw8oScKOXznDisKfYF3DvMKoEy0DDmzDhlHDjwIyC8OzRS/CnEZ4woM9w5cnw51fw6MZMAzDk8O3CDXDoyHDvzlFwqDCg8KsTnAiaMOsIyfCmUEaw6nChMK5TMOxG8KEHUNIwo1seMOXw5HDhyVawrzCr8KmFWHDpMO3asKpwrQbbMOlwoMew4t1Jz51wp9Jw6kGWcOzc8KgwpLCpsOHOMOgYB3DiMOxLcOQB8K7AcOyWF3CmnwfK8Kxw6XDm2TCiT/CnVTCg8Omw7Ngwp3CuUAHw6/CjRLDgcKsU8O/w6gXJ0cIw6pZMcOxEWETwpd4w58Mwr5SBMKORQjCi3FYcULDgx09w5M7SH7DrMKrw4gnXMKjwqUrBMOLwqQyF0nDhcKuwqTDqsO2w7LCnGjCvkbDgDgcw54xAkEiQMKUFlzDkMOew73CmkU4wrnCjw3DvsKaW8K0InA+w4sPSXfDuhbClMKgUcKeCMORw5ZYJcKnNEzDoMOhw7MYCX4DwqIQwoHCvsOaB1UAI8KVw6LCvcOTw53CuSgow4kZdHw5aRkYw7ZyV8OsP0LCh8KnwpIuw4p1NisoEcKcwrjDhcOtMzdvw5rDmsK3IAdAw7M4J8K+w6zCmR3CuMKUw4lqw6osPMObw53Dg8K3wqLCrsKZwr8mPcK4w4QWw5LCnwZeH1bDgwwiXcKbUhHDk1DDk0MLwoDDqMKXw5skNsKAAcOFw77Di8KNGCBzP8OcwrI5wodQQwQyw5V0wrInwrPDt8O+T8KbNsKVw7Mzw7HCsMOjwpcewoPCuMOUEsOow6QZVDjDpgbDlMOBGDXCtMOmw6jDuMKfw4nDlWTDq8Kqd0TDvwPCpSzDlA4JO3EHwrlBWcK5w7DCscOwCMK2wpsvwrYNIcOgBBXChMK0w6nCosKWEVd+w7cEal5hIcO4SWrCu0TDrW5Yw4XCmBgCwpc7YVwIwqPCi8OlGDzDmyJ/woHCscOtw4zDuC7CpUXCrDAJwp7Cj8KxPX3CrhDCvVB2w7PCosKbw7F+V11hY8Omwq1eQcO8w4wcRMKBJ2LDgW/DomXDhwkgAlxmQcKew6HDq8Ouw6ASeG/DlcKgUcKmLMOowpQWNcKJJcKDa3XDksK/woHCo3d6wrHDpMOqwqs/UUXCjUpnwrHCmsOyJx4bwoHChAnDi0TCpjLDrBvCvEghw5VtfhPCk8K5KsKIw75FCsOyDsKtV17CicOjwqAnF8OHHC0qMsOEwrgEwr13c8KZw4fDn8KXw73CksKAw4QTGRgIG8KMMXwpwrRBT2DDq8K3AsOQXl/DqMKYMivClsKiXcOhGkvDmsK9w77Cmmpvwrhsd8Kaw7bDgQ/DuCU2CyTDtjnCgn/DiMOtSyPDnsOfVTstccO6EVXDrj03MUHDvDDCgsO7BFQFEX3DszIyw7Rsw7pNwpjCs8OCLR9UbsOlw5USw73DiWJqVXTCl2tFw7FaAcKaw7l5a3Mvw5TCpMKCwpbDi3fCi8KHwrfDugUZwo5hw7fChsKDw5ZhPjA7w7HDjcO9wrrCjUbDoy4JXA1JICRDw49UNsOYOsK9FGE5wqhAw67DumnDqW0cwqbCu8OedEbDqcOfw50MVH8twpVLH8O3LsKvacKJw75xTMKkOcOJw4/DvsOYwqRwZcOnwqfCm2XCnRJFwqEgX8KLPsKfwpQWw6nChm82w6hME10KTRhGw5LCj1stPiXClsO8w7rCocOLw6lFw7tAZ8K0O3wswpZ4wqvCmMOFwpzDhMKVRRQjw53CikECPMOKZcOOwoAKcMK7WMO3K8Okw4bCjgrCisKLRsKewqzDvmtnw584wrtiw6RFVsKPecOpIhx7TsKzw4TCisKyw6nCqcK+w6fChsKxw5kWSsOgfD7CkRfCncKGKMOubsKoBA9Fe2YHwrx4aQNSG8Kpw5zDrMO1FMOPZcKSIVnDrHxOBsKyBcKmYwQMOl7CiRvCnDNVw7NaesOoPR3CrnQEwr9Xw600BSFYECnDgi1OFS7DoFYJw4M6wrzCog09WFPCmiHDogjDpQFjdsKKIsOWFsKXd0TDjXU3CsONRX3DssOrw4HDmX0Mw7rDiENvwpPCghsXacK2w6XCkMOICcKVw4nCkMO8RcOUw4zCn1VJw752RAUawqhdw5dEwqbDh0wAMH/DlTrChC/DosOoGsOPw5nClTcyw5XDlsKhNsKAcBINwpxUAi8Rw5Jvwpckwq4uBy0nw51dP2UGbidATX1FLMKFw5zDsQxewp3DlMKwwo3CrhBPJGR7cVHCnTUnwrDDksO0AcO5T3jCm245OnUVUT8WD1HDhTnCqnbCt8OjMDvCsAzCjsKSwoDDlDhtw7cFwpsDaS7CvVLDu0zDnlvDlMOEwrnCgVzCgcOZN8Oxwp0LSMKswq/DrMK9fcKTL1zDgcOvwofCtWAoL0IKR8OWwqpPw6QfVsKcwqxTXGEPKCFydX4Mw5jDmcOEWlPCgMKDPcOJw7HDgcOMahzCjMO7HyPDo8K3Y8OswqPDgSQ+w6wfw67Cr8O/w61oMsO+woTDrnECI2TDuMK5wrzDusOHw5/CosKFwrciQF3Csj5aw7DDpMKwZMK3Z8KlRBIcLcKvM2/CtBk8JMKWwqVyw6RNwoUhwoDCsXbCrD04wpQ4F8KOcMKIw7PDtMKqZRTCjsKSOMOKCMKYQ8OhwqZ1dGrChcKXLSnDiT7CrEjCihckNcOXw63CkUYpT8KTwq7CgMKiw7PCqmBzwq/Crz50XcKEGlLCrUBjw6ASVsObD8K9wpZ6eBHCi2FTMVcDSzvDgwtxw5ZJHlF5woDDtsKTwovChMOyYMKOSCt7w7hGDDsFaMOewrrCjRbDrGPDg2rCpsO3wo8IEMO9wqjCrG0mRXHDocKJwqQYdsKOw7UUwqIUwq/CqUlKW8ObwpcZGizCpgd4dAZBXMOYw5s5w6HDvkEgw6sbRxAwwoBSOyXCjDPDpsKlwrPCrl/DqsOswoJJDWzDp8Ocw5nDrE5FWm3DncKVwpnCqMKiwoDDmMONQcOEwpwRwonCsh0Tw7FCw6Nfw7U7wp7DnMKnfMOHCMOnw4TClcOVwrzCiiddUj3CmsOgwqvDhxfDjsOMWcKDZnvDocObw77Do1rDgMKHVsKCLcOXRMOHD0RNwpEdwozCrBnDqBYWwojCiVzCjTTCqcO5wqgAwqhhw7tnw5ZuOcOYNGTDiR1GAEzDuE0PeErDnlQlfsOjw6UGWUUNw6TCmgx8NMKzDMKgL8O3esKDwprDoTl8wrbDvVDCvU4Iw5sAwr/DugcoR8KMw4hNeMKSw7Jmw4rDjG8NbcO8w7jCs8OvfFXCoBBNfcOqNsK0EQLCncKPw53DrsOiwolvwqjCr8OZDsORw47DiyA+VcOMSg5wworDgGx0w7sgKMOyDMOyZRkgw43CqUHDicKfwpDCo8OII8KvKsOxDcKoFsOaw7HCgXTDssK7B8KIwoNcw4zCu8KBw4vCvFjDkWLDl8OyB8O/w4oYw5DCslzDk2kDw7jDgcOJw4jComXDkwdfw61xw53Cv8KPf11iwq0kKsKDw7nCmiVNF0NqLMKvwqvDjhQ3ZXbDomvDs8OKQQ7CocOnwr1Fw7xZRMK6w41cw5DDgzzCthIoAMOBQcOPbcOPVx/Cm8OYw7pHwo/CvCxhCcKVw7vChShnw6rClUQ7w6dbZMOrw4hpw7lZXMOxw5pnUXHDiMOLDxrDiA/DtMKqw6zDjXRJwp07BsKEwoTClBHCritDYXgzT3RWDcOlw4lfw4Vbw7fCj8K0w4AnwqjCrxPDpCVXF8KbY8OMPwQvwqdaw6E8w4AHPcKbNGl8wpQMX2PDp0pJfcOyGsOUXkNww5jCg8Obwo7DryjCisKeYiQ/XUzDvRvDncOtCMKJwqxHw6LDh8KwwrV7LGPCkcKOIXbCv8KHwpnDi1keQkLDssOSw7XCk8K+w7YdSMKAQmbDo8KPw7xywpnCsgANNTJYScKkNAvDo8KZw6Ayw6tmC8KaTsKEbcOZTx3DilrDtUjDi8OWV8K/wrocwpNKLlYbbcOmPcKPwrvCsTpLey5Xw58XJBPCo8KEPWJrwqZJX1fCncKDw4AZw4hWw5pTw7pidlzDtMO6w7t9DcK+R8KefMOfETvCskgjOgHCqcK7UgHCgsOfwrt8bcKQw5FeZcOiw4Faw7hRTjDDocOuEMOoEm04NQTCrCjDvMOaNDV6V8OHc8OTdMOndCh7HMOqw7HDnlzCl3MqwpjDiiDDtcKmCknCuBcQwobDvcOUN2LDmsOeHMOmPMKeH0nCt0nDgsO8w73CkRDDmMOuacO9w5J1KsKswqY7UMKyHHzDjMOjw5QOSWUhw4jCpMKJw4DCtcKNdcKPLcOFJsOqQ14=---END ENCRYPTED MESSAGE---||---BEGIN ENCRYPTED RANDOM KEY--------BEGIN PGP MESSAGE-----
Version: OpenPGP.js v0.9.0
Comment: http://openpgpjs.org
wcBMA2tjJVxNCRhtAQf/YzkQoUqaqa3NJ/c1apIF/dsl7yJ4GdVrC3/w7lxE
2CO5ioQD4s6QMWP2Y9dOdVl2INwz8eXOds9NS+1nMs4SoMbrpJnAjx8Cthti
1Z/8eWMU023LYahds8BYM0T435K/2tTB5GTA4uTl2y8Xzz2PbptQ4PrUDaII
+egeQQyPA0yuoRDwpaeTiaBYOSa06YYuK5Agr0buQAxRIMCxI2o+fucjoabv
FsQHKGu20U5GlJroSIyIVVkaH3evhNti/AnYX1HuokcGEQNsF5vo4SjWcH23
2P86EIV+w5lUWC1FN9vZCyvbvyuqLHQMtqKVn4GBOkIc3bYQ0jru3a0FG4Cx
bNJ0ASps2+p3Vxe0d+so2iFV92ByQ+0skyCUwCNUlwOV5V5f2fy1ImXk4mXI
cO/bcbqRxx3pG9gkPIh43FoQktTT+tsJ5vS53qfaLGdhCYfkrWjsKu+2P9Xg
+Cr8clh6NTblhfkoAS1gzjA3XgsgEFrtP+OGqwg=
=c5WU
-----END PGP MESSAGE-----
---END ENCRYPTED RANDOM KEY---
`
const testMessageSigned = `-----BEGIN PGP MESSAGE-----
Version: OpenPGP.js v4.5.3
Comment: https://openpgpjs.org
wcBMA0fcZ7XLgmf2AQgAgnHOlcAwVu2AnVfi2fIQHSkTQ0OFnZMJMRR3MJ1q
HtUW8jkSLcurL0Sn/tBFLIIR4YT2tQMzV7cvZzZyBEuZM4OYnDp8xSmoszPh
Gc/nvYG0A0pmKAQkL27v05Dul8oUWA0APT51urghH2Pzm7NdOMtTKIE4LQjS
mBfQ6Cf14uKV0xGS9v2dSFjFxxXEEpMQ+k60NCKRYClN2LVVxf3OKXbuugds
m2GUGn3CuFsiabosIUv4EcdE3aD9HbNo+PIWLJWRJIYJSc5+FWcbwXuIIFgC
XX1s7OV53ceZJnhjCmDE0N2ZOLLAYWED2zRvUa+CAqG+hZgc/3Ia+UmJUVuZ
BNLAugFuRsOVgh3olUIz0vazHhyGG0XIsNqmRm0U9SIfhWkPPHBmU6Xht6Qw
EvLbBfKTYHxX01yQUNgIv4S/TULeQuUjZQfsNYNXXGepS+jiCoIdEgUwpvre
OMFGsypwQXVCFYO/GQdYanMQRTckEexyBY4hGYVrevDM1yG/zGJIdbfI2L+1
1cz76jI8PtzL+S0zcVkevLcjjsHm2Je959uSida9jara7Bymr0y56UdoXoWX
4vZ0kQNo58eEEV0zg7dit4lDvwcuSZMW6K//xNtRQ4QX7/EDtlcYqBJXPwJY
eQSBVeYbeUbZ+PHJdu5gbI85BJNE2dKcS1bdOhEU2lPLYpvmMpPdot9TwnJb
dN3l8yDyhScGvTIZqlxhU7HCM9VHAS0bDqCUoO8EruztUSgjMI+gKC9+xdVU
yrkF7K23UNLWflROMv4cp0LDRB57619Y2w5lY/MG5bS0jSfMWBwnJG2AF28c
2tYKnHw6rpZXvXnlDmEDT8suTzuTGA==
=Sir8
-----END PGP MESSAGE-----
`
const testMessageSigner = `-----BEGIN PGP PUBLIC KEY BLOCK-----
Version: OpenPGP.js v0.7.1
Comment: http://openpgpjs.org
xsBNBFSI0BMBB/9td6B5RDzVSFTlFzYOS4JxIb5agtNW1rbA4FeLoC47bGLR
8E42IA6aKcO4H0vOZ1lFms0URiKk1DjCMXn3AUErbxqiV5IATRZLwliH6vwy
PI6j5rtGF8dyxYfwmLtoUNkDcPdcFEb4NCdowsN7e8tKU0bcpouZcQhAqawC
9nEdaG/gS5w+2k4hZX2lOKS1EF5SvP48UadlspEK2PLAIp5wB9XsFS9ey2wu
elzkSfDh7KUAlteqFGSMqIgYH62/gaKm+TcckfZeyiMHWFw6sfrcFQ3QOZPq
ahWt0Rn9XM5xBAxx5vW0oceuQ1vpvdfFlM5ix4gn/9w6MhmStaCee8/fABEB
AAHNBlVzZXJJRMLAcgQQAQgAJgUCVIjQHQYLCQgHAwIJEASDR1Fk7GNTBBUI
AgoDFgIBAhsDAh4BAADmhAf/Yt0mCfWqQ25NNGUN14pKKgnPm68zwj1SmMGa
pU7+7ItRpoFNaDwV5QYiQSLC1SvSb1ZeKoY928GPKfqYyJlBpTPL9zC1OHQj
9+2yYauHjYW9JWQM7hst2S2LBcdiQPOs3ybWPaO9yaccV4thxKOCPvyClaS5
b9T4Iv9GEVZQIUvArkwI8hyzIi6skRgxflGheq1O+S1W4Gzt2VtYvo8g8r6W
GzAGMw2nrs2h0+vUr+dLDgIbFCTc5QU99d5jE/e5Hw8iqBxv9tqB1hVATf8T
wC8aU5MTtxtabOiBgG0PsBs6oIwjFqEjpOIza2/AflPZfo7stp6IiwbwvTHo
1NlHoM7ATQRUiNAdAQf/eOLJYxX4lUQUzrNQgASDNE8gJPj7ywcGzySyqr0Y
5rbG57EjtKMIgZrpzJRpSCuRbBjfsltqJ5Q9TBAbPO+oR3rue0LqPKMnmr/q
KsHswBJRfsb/dbktUNmv/f7R9IVyOuvyP6RgdGeloxdGNeWiZSA6AZYI+WGc
xaOvVDPz8thtnML4G4MUhXxxNZ7JzQ0Lfz6mN8CCkblIP5xpcJsyRU7lUsGD
EJGZX0JH/I8bRVN1Xu08uFinIkZyiXRJ5ZGgF3Dns6VbIWmbttY54tBELtk+
5g9pNSl9qiYwiCdwuZrA//NmD3xlZIN8sG4eM7ZUibZ23vEq+bUt1++6Mpba
GQARAQABwsBfBBgBCAATBQJUiNAfCRAEg0dRZOxjUwIbDAAAlpMH/085qZdO
mGRAlbvViUNhF2rtHvCletC48WHGO1ueSh9VTxalkP21YAYLJ4JgJzArJ7tH
lEeiKiHm8YU9KhLe11Yv/o3AiKIAQjJiQluvk+mWdMcddB4fBjL6ttMTRAXe
gHnjtMoamHbSZdeUTUadv05Fl6ivWtpXlODG4V02YvDiGBUbDosdGXEqDtpT
g6MYlj3QMvUiUNQvt7YGMJS8A9iQ9qBNzErgRW8L6CON2RmpQ/wgwP5nwUHz
JjY51d82Vj8bZeI8LdsX41SPoUhyC7kmNYpw9ZRy7NlrCt8dBIOB4/BKEJ2G
ClW54lp9eeOfYTsdTSbn9VaSO0E6m2/Q4Tk=
=WFtr
-----END PGP PUBLIC KEY BLOCK-----`
func TestMessage_IsBodyEncrypted(t *testing.T) {
r := require.New(t)
msg := &Message{Body: testMessageEncrypted}
r.True(msg.IsBodyEncrypted(), "the body should be encrypted")
msg.Body = testMessageCleartext
r.True(!msg.IsBodyEncrypted(), "the body should not be encrypted")
}
func TestMessage_Decrypt(t *testing.T) {
r := require.New(t)
msg := &Message{Body: testMessageEncrypted}
dec, err := msg.Decrypt(testPrivateKeyRing)
r.NoError(err)
r.Equal(testMessageCleartext, string(dec))
}
func TestMessage_Decrypt_Legacy(t *testing.T) {
r := require.New(t)
testPrivateKeyLegacy := readTestFile("testPrivateKeyLegacy", false)
key, err := crypto.NewKeyFromArmored(testPrivateKeyLegacy)
r.NoError(err)
unlockedKey, err := key.Unlock([]byte(testMailboxPasswordLegacy))
r.NoError(err)
testPrivateKeyRingLegacy, err := crypto.NewKeyRing(unlockedKey)
r.NoError(err)
msg := &Message{Body: testMessageEncryptedLegacy}
dec, err := msg.Decrypt(testPrivateKeyRingLegacy)
r.NoError(err)
r.Equal(testMessageCleartextLegacy, string(dec))
}
func TestMessage_Decrypt_signed(t *testing.T) {
r := require.New(t)
msg := &Message{Body: testMessageSigned}
dec, err := msg.Decrypt(testPrivateKeyRing)
r.NoError(err)
r.Equal(testMessageCleartext, string(dec))
}
func TestMessage_Encrypt(t *testing.T) {
r := require.New(t)
key, err := crypto.NewKeyFromArmored(testMessageSigner)
r.NoError(err)
signer, err := crypto.NewKeyRing(key)
r.NoError(err)
msg := &Message{Body: testMessageCleartext}
r.NoError(msg.Encrypt(testPrivateKeyRing, testPrivateKeyRing))
dec, err := msg.Decrypt(testPrivateKeyRing)
r.NoError(err)
r.Equal(testMessageCleartext, string(dec))
r.Equal(testIdentity, signer.GetIdentities()[0])
}
func routeLabelMessages(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
require.NoError(tb, checkMethodAndPath(req, "PUT", "/mail/v4/messages/label"))
return "messages/label/put_response.json"
}
func TestMessage_LabelMessages_NoPaging(t *testing.T) {
r := require.New(t)
// This should be only enough IDs to produce one page.
testIDs := []string{}
for i := 0; i < messageIDPageSize-1; i++ {
testIDs = append(testIDs, fmt.Sprintf("%v", i))
}
// There should be enough IDs to produce just one page so the endpoint should be called once.
finish, c := newTestClientCallbacks(t,
routeLabelMessages,
)
defer finish()
r.NoError(c.LabelMessages(context.Background(), testIDs, "mylabel"))
}
func TestMessage_LabelMessages_Paging(t *testing.T) {
r := require.New(t)
// This should be enough IDs to produce three pages.
testIDs := []string{}
for i := 0; i < 3*messageIDPageSize; i++ {
testIDs = append(testIDs, fmt.Sprintf("%v", i))
}
// There should be enough IDs to produce three pages so the endpoint should be called three times.
finish, c := newTestClientCallbacks(t,
routeLabelMessages,
routeLabelMessages,
routeLabelMessages,
)
defer finish()
r.NoError(c.LabelMessages(context.Background(), testIDs, "mylabel"))
}
// TestClient_GetMessage might look like no actual functionality is tested
// here. But there was case when API was responding with bad payload and it was
// useful to have this to quickly test it.
func TestClient_GetMessage(t *testing.T) {
r := require.New(t)
testID := "AeUizgtA3H44qRgcr-HdBApwLiUhlQg5kB81mg_QalWotmQJIHep9OScWIo7Wu9pnYxM4RqQxJnr3BE4kh4y_Q=="
finish, c := newTestClientCallbacks(t,
func(tb testing.TB, w http.ResponseWriter, req *http.Request) string {
r.NoError(checkMethodAndPath(req, "GET", "/mail/v4/messages/"+testID))
return "/messages/get_response.json"
},
)
defer finish()
msg, err := c.GetMessage(context.Background(), testID)
r.NoError(err)
r.Equal(testID, msg.ID)
}

View File

@ -1,883 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/v2/pkg/pmapi (interfaces: Client,Manager)
// Package mocks is a generated GoMock package.
package mocks
import (
context "context"
io "io"
http "net/http"
reflect "reflect"
time "time"
crypto "github.com/ProtonMail/gopenpgp/v2/crypto"
pmapi "github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
gomock "github.com/golang/mock/gomock"
logrus "github.com/sirupsen/logrus"
)
// MockClient is a mock of Client interface.
type MockClient struct {
ctrl *gomock.Controller
recorder *MockClientMockRecorder
}
// MockClientMockRecorder is the mock recorder for MockClient.
type MockClientMockRecorder struct {
mock *MockClient
}
// NewMockClient creates a new mock instance.
func NewMockClient(ctrl *gomock.Controller) *MockClient {
mock := &MockClient{ctrl: ctrl}
mock.recorder = &MockClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockClient) EXPECT() *MockClientMockRecorder {
return m.recorder
}
// AddAuthRefreshHandler mocks base method.
func (m *MockClient) AddAuthRefreshHandler(arg0 pmapi.AuthRefreshHandler) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddAuthRefreshHandler", arg0)
}
// AddAuthRefreshHandler indicates an expected call of AddAuthRefreshHandler.
func (mr *MockClientMockRecorder) AddAuthRefreshHandler(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddAuthRefreshHandler", reflect.TypeOf((*MockClient)(nil).AddAuthRefreshHandler), arg0)
}
// Addresses mocks base method.
func (m *MockClient) Addresses() pmapi.AddressList {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Addresses")
ret0, _ := ret[0].(pmapi.AddressList)
return ret0
}
// Addresses indicates an expected call of Addresses.
func (mr *MockClientMockRecorder) Addresses() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addresses", reflect.TypeOf((*MockClient)(nil).Addresses))
}
// Auth2FA mocks base method.
func (m *MockClient) Auth2FA(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Auth2FA", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// Auth2FA indicates an expected call of Auth2FA.
func (mr *MockClientMockRecorder) Auth2FA(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Auth2FA", reflect.TypeOf((*MockClient)(nil).Auth2FA), arg0, arg1)
}
// AuthDelete mocks base method.
func (m *MockClient) AuthDelete(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthDelete", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// AuthDelete indicates an expected call of AuthDelete.
func (mr *MockClientMockRecorder) AuthDelete(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthDelete", reflect.TypeOf((*MockClient)(nil).AuthDelete), arg0)
}
// AuthSalt mocks base method.
func (m *MockClient) AuthSalt(arg0 context.Context) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthSalt", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AuthSalt indicates an expected call of AuthSalt.
func (mr *MockClientMockRecorder) AuthSalt(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthSalt", reflect.TypeOf((*MockClient)(nil).AuthSalt), arg0)
}
// CountMessages mocks base method.
func (m *MockClient) CountMessages(arg0 context.Context, arg1 string) ([]*pmapi.MessagesCount, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountMessages", arg0, arg1)
ret0, _ := ret[0].([]*pmapi.MessagesCount)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountMessages indicates an expected call of CountMessages.
func (mr *MockClientMockRecorder) CountMessages(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountMessages", reflect.TypeOf((*MockClient)(nil).CountMessages), arg0, arg1)
}
// CreateAttachment mocks base method.
func (m *MockClient) CreateAttachment(arg0 context.Context, arg1 *pmapi.Attachment, arg2, arg3 io.Reader) (*pmapi.Attachment, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateAttachment", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*pmapi.Attachment)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateAttachment indicates an expected call of CreateAttachment.
func (mr *MockClientMockRecorder) CreateAttachment(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAttachment", reflect.TypeOf((*MockClient)(nil).CreateAttachment), arg0, arg1, arg2, arg3)
}
// CreateDraft mocks base method.
func (m *MockClient) CreateDraft(arg0 context.Context, arg1 *pmapi.Message, arg2 string, arg3 int) (*pmapi.Message, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateDraft", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*pmapi.Message)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateDraft indicates an expected call of CreateDraft.
func (mr *MockClientMockRecorder) CreateDraft(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDraft", reflect.TypeOf((*MockClient)(nil).CreateDraft), arg0, arg1, arg2, arg3)
}
// CreateLabel mocks base method.
func (m *MockClient) CreateLabel(arg0 context.Context, arg1 *pmapi.Label) (*pmapi.Label, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateLabel", arg0, arg1)
ret0, _ := ret[0].(*pmapi.Label)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateLabel indicates an expected call of CreateLabel.
func (mr *MockClientMockRecorder) CreateLabel(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateLabel", reflect.TypeOf((*MockClient)(nil).CreateLabel), arg0, arg1)
}
// CreateLabelV4 mocks base method.
func (m *MockClient) CreateLabelV4(arg0 context.Context, arg1 *pmapi.Label) (*pmapi.Label, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateLabelV4", arg0, arg1)
ret0, _ := ret[0].(*pmapi.Label)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateLabelV4 indicates an expected call of CreateLabelV4.
func (mr *MockClientMockRecorder) CreateLabelV4(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateLabelV4", reflect.TypeOf((*MockClient)(nil).CreateLabelV4), arg0, arg1)
}
// CurrentUser mocks base method.
func (m *MockClient) CurrentUser(arg0 context.Context) (*pmapi.User, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CurrentUser", arg0)
ret0, _ := ret[0].(*pmapi.User)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CurrentUser indicates an expected call of CurrentUser.
func (mr *MockClientMockRecorder) CurrentUser(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CurrentUser", reflect.TypeOf((*MockClient)(nil).CurrentUser), arg0)
}
// DecryptAndVerifyCards mocks base method.
func (m *MockClient) DecryptAndVerifyCards(arg0 []pmapi.Card) ([]pmapi.Card, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DecryptAndVerifyCards", arg0)
ret0, _ := ret[0].([]pmapi.Card)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DecryptAndVerifyCards indicates an expected call of DecryptAndVerifyCards.
func (mr *MockClientMockRecorder) DecryptAndVerifyCards(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptAndVerifyCards", reflect.TypeOf((*MockClient)(nil).DecryptAndVerifyCards), arg0)
}
// DeleteLabel mocks base method.
func (m *MockClient) DeleteLabel(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteLabel", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteLabel indicates an expected call of DeleteLabel.
func (mr *MockClientMockRecorder) DeleteLabel(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteLabel", reflect.TypeOf((*MockClient)(nil).DeleteLabel), arg0, arg1)
}
// DeleteLabelV4 mocks base method.
func (m *MockClient) DeleteLabelV4(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteLabelV4", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteLabelV4 indicates an expected call of DeleteLabelV4.
func (mr *MockClientMockRecorder) DeleteLabelV4(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteLabelV4", reflect.TypeOf((*MockClient)(nil).DeleteLabelV4), arg0, arg1)
}
// DeleteMessages mocks base method.
func (m *MockClient) DeleteMessages(arg0 context.Context, arg1 []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteMessages", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteMessages indicates an expected call of DeleteMessages.
func (mr *MockClientMockRecorder) DeleteMessages(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMessages", reflect.TypeOf((*MockClient)(nil).DeleteMessages), arg0, arg1)
}
// EmptyFolder mocks base method.
func (m *MockClient) EmptyFolder(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "EmptyFolder", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// EmptyFolder indicates an expected call of EmptyFolder.
func (mr *MockClientMockRecorder) EmptyFolder(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EmptyFolder", reflect.TypeOf((*MockClient)(nil).EmptyFolder), arg0, arg1, arg2)
}
// GetAddresses mocks base method.
func (m *MockClient) GetAddresses(arg0 context.Context) (pmapi.AddressList, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAddresses", arg0)
ret0, _ := ret[0].(pmapi.AddressList)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAddresses indicates an expected call of GetAddresses.
func (mr *MockClientMockRecorder) GetAddresses(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddresses", reflect.TypeOf((*MockClient)(nil).GetAddresses), arg0)
}
// GetAttachment mocks base method.
func (m *MockClient) GetAttachment(arg0 context.Context, arg1 string) (io.ReadCloser, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAttachment", arg0, arg1)
ret0, _ := ret[0].(io.ReadCloser)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAttachment indicates an expected call of GetAttachment.
func (mr *MockClientMockRecorder) GetAttachment(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttachment", reflect.TypeOf((*MockClient)(nil).GetAttachment), arg0, arg1)
}
// GetContactByID mocks base method.
func (m *MockClient) GetContactByID(arg0 context.Context, arg1 string) (pmapi.Contact, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetContactByID", arg0, arg1)
ret0, _ := ret[0].(pmapi.Contact)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetContactByID indicates an expected call of GetContactByID.
func (mr *MockClientMockRecorder) GetContactByID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetContactByID", reflect.TypeOf((*MockClient)(nil).GetContactByID), arg0, arg1)
}
// GetContactEmailByEmail mocks base method.
func (m *MockClient) GetContactEmailByEmail(arg0 context.Context, arg1 string, arg2, arg3 int) ([]pmapi.ContactEmail, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetContactEmailByEmail", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]pmapi.ContactEmail)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetContactEmailByEmail indicates an expected call of GetContactEmailByEmail.
func (mr *MockClientMockRecorder) GetContactEmailByEmail(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetContactEmailByEmail", reflect.TypeOf((*MockClient)(nil).GetContactEmailByEmail), arg0, arg1, arg2, arg3)
}
// GetEvent mocks base method.
func (m *MockClient) GetEvent(arg0 context.Context, arg1 string) (*pmapi.Event, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEvent", arg0, arg1)
ret0, _ := ret[0].(*pmapi.Event)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetEvent indicates an expected call of GetEvent.
func (mr *MockClientMockRecorder) GetEvent(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEvent", reflect.TypeOf((*MockClient)(nil).GetEvent), arg0, arg1)
}
// GetMailSettings mocks base method.
func (m *MockClient) GetMailSettings(arg0 context.Context) (pmapi.MailSettings, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMailSettings", arg0)
ret0, _ := ret[0].(pmapi.MailSettings)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMailSettings indicates an expected call of GetMailSettings.
func (mr *MockClientMockRecorder) GetMailSettings(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMailSettings", reflect.TypeOf((*MockClient)(nil).GetMailSettings), arg0)
}
// GetMessage mocks base method.
func (m *MockClient) GetMessage(arg0 context.Context, arg1 string) (*pmapi.Message, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMessage", arg0, arg1)
ret0, _ := ret[0].(*pmapi.Message)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMessage indicates an expected call of GetMessage.
func (mr *MockClientMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockClient)(nil).GetMessage), arg0, arg1)
}
// GetPublicKeysForEmail mocks base method.
func (m *MockClient) GetPublicKeysForEmail(arg0 context.Context, arg1 string) ([]pmapi.PublicKey, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPublicKeysForEmail", arg0, arg1)
ret0, _ := ret[0].([]pmapi.PublicKey)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetPublicKeysForEmail indicates an expected call of GetPublicKeysForEmail.
func (mr *MockClientMockRecorder) GetPublicKeysForEmail(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicKeysForEmail", reflect.TypeOf((*MockClient)(nil).GetPublicKeysForEmail), arg0, arg1)
}
// GetUser mocks base method.
func (m *MockClient) GetUser(arg0 context.Context) (*pmapi.User, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUser", arg0)
ret0, _ := ret[0].(*pmapi.User)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUser indicates an expected call of GetUser.
func (mr *MockClientMockRecorder) GetUser(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUser", reflect.TypeOf((*MockClient)(nil).GetUser), arg0)
}
// GetUserKeyRing mocks base method.
func (m *MockClient) GetUserKeyRing() (*crypto.KeyRing, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserKeyRing")
ret0, _ := ret[0].(*crypto.KeyRing)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserKeyRing indicates an expected call of GetUserKeyRing.
func (mr *MockClientMockRecorder) GetUserKeyRing() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserKeyRing", reflect.TypeOf((*MockClient)(nil).GetUserKeyRing))
}
// Import mocks base method.
func (m *MockClient) Import(arg0 context.Context, arg1 pmapi.ImportMsgReqs) ([]*pmapi.ImportMsgRes, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Import", arg0, arg1)
ret0, _ := ret[0].([]*pmapi.ImportMsgRes)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Import indicates an expected call of Import.
func (mr *MockClientMockRecorder) Import(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Import", reflect.TypeOf((*MockClient)(nil).Import), arg0, arg1)
}
// IsUnlocked mocks base method.
func (m *MockClient) IsUnlocked() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsUnlocked")
ret0, _ := ret[0].(bool)
return ret0
}
// IsUnlocked indicates an expected call of IsUnlocked.
func (mr *MockClientMockRecorder) IsUnlocked() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsUnlocked", reflect.TypeOf((*MockClient)(nil).IsUnlocked))
}
// KeyRingForAddressID mocks base method.
func (m *MockClient) KeyRingForAddressID(arg0 string) (*crypto.KeyRing, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeyRingForAddressID", arg0)
ret0, _ := ret[0].(*crypto.KeyRing)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// KeyRingForAddressID indicates an expected call of KeyRingForAddressID.
func (mr *MockClientMockRecorder) KeyRingForAddressID(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyRingForAddressID", reflect.TypeOf((*MockClient)(nil).KeyRingForAddressID), arg0)
}
// LabelMessages mocks base method.
func (m *MockClient) LabelMessages(arg0 context.Context, arg1 []string, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LabelMessages", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// LabelMessages indicates an expected call of LabelMessages.
func (mr *MockClientMockRecorder) LabelMessages(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LabelMessages", reflect.TypeOf((*MockClient)(nil).LabelMessages), arg0, arg1, arg2)
}
// ListLabels mocks base method.
func (m *MockClient) ListLabels(arg0 context.Context) ([]*pmapi.Label, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListLabels", arg0)
ret0, _ := ret[0].([]*pmapi.Label)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListLabels indicates an expected call of ListLabels.
func (mr *MockClientMockRecorder) ListLabels(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListLabels", reflect.TypeOf((*MockClient)(nil).ListLabels), arg0)
}
// ListLabelsOnly mocks base method.
func (m *MockClient) ListLabelsOnly(arg0 context.Context) ([]*pmapi.Label, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListLabelsOnly", arg0)
ret0, _ := ret[0].([]*pmapi.Label)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListLabelsOnly indicates an expected call of ListLabelsOnly.
func (mr *MockClientMockRecorder) ListLabelsOnly(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListLabelsOnly", reflect.TypeOf((*MockClient)(nil).ListLabelsOnly), arg0)
}
// ListFoldersOnly mocks base method.
func (m *MockClient) ListFoldersOnly(arg0 context.Context) ([]*pmapi.Label, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListFoldersOnly", arg0)
ret0, _ := ret[0].([]*pmapi.Label)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListFoldersOnly indicates an expected call of ListFoldersOnly.
func (mr *MockClientMockRecorder) ListFoldersOnly(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListFoldersOnly", reflect.TypeOf((*MockClient)(nil).ListFoldersOnly), arg0)
}
// ListMessages mocks base method.
func (m *MockClient) ListMessages(arg0 context.Context, arg1 *pmapi.MessagesFilter) ([]*pmapi.Message, int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListMessages", arg0, arg1)
ret0, _ := ret[0].([]*pmapi.Message)
ret1, _ := ret[1].(int)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// ListMessages indicates an expected call of ListMessages.
func (mr *MockClientMockRecorder) ListMessages(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListMessages", reflect.TypeOf((*MockClient)(nil).ListMessages), arg0, arg1)
}
// MarkMessagesRead mocks base method.
func (m *MockClient) MarkMessagesRead(arg0 context.Context, arg1 []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkMessagesRead", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// MarkMessagesRead indicates an expected call of MarkMessagesRead.
func (mr *MockClientMockRecorder) MarkMessagesRead(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkMessagesRead", reflect.TypeOf((*MockClient)(nil).MarkMessagesRead), arg0, arg1)
}
// MarkMessagesUnread mocks base method.
func (m *MockClient) MarkMessagesUnread(arg0 context.Context, arg1 []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkMessagesUnread", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// MarkMessagesUnread indicates an expected call of MarkMessagesUnread.
func (mr *MockClientMockRecorder) MarkMessagesUnread(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkMessagesUnread", reflect.TypeOf((*MockClient)(nil).MarkMessagesUnread), arg0, arg1)
}
// ReloadKeys mocks base method.
func (m *MockClient) ReloadKeys(arg0 context.Context, arg1 []byte) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReloadKeys", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// ReloadKeys indicates an expected call of ReloadKeys.
func (mr *MockClientMockRecorder) ReloadKeys(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadKeys", reflect.TypeOf((*MockClient)(nil).ReloadKeys), arg0, arg1)
}
// ReorderAddresses mocks base method.
func (m *MockClient) ReorderAddresses(arg0 context.Context, arg1 []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReorderAddresses", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// ReorderAddresses indicates an expected call of ReorderAddresses.
func (mr *MockClientMockRecorder) ReorderAddresses(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReorderAddresses", reflect.TypeOf((*MockClient)(nil).ReorderAddresses), arg0, arg1)
}
// SendMessage mocks base method.
func (m *MockClient) SendMessage(arg0 context.Context, arg1 string, arg2 *pmapi.SendMessageReq) (*pmapi.Message, *pmapi.Message, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendMessage", arg0, arg1, arg2)
ret0, _ := ret[0].(*pmapi.Message)
ret1, _ := ret[1].(*pmapi.Message)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// SendMessage indicates an expected call of SendMessage.
func (mr *MockClientMockRecorder) SendMessage(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockClient)(nil).SendMessage), arg0, arg1, arg2)
}
// UnlabelMessages mocks base method.
func (m *MockClient) UnlabelMessages(arg0 context.Context, arg1 []string, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnlabelMessages", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// UnlabelMessages indicates an expected call of UnlabelMessages.
func (mr *MockClientMockRecorder) UnlabelMessages(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnlabelMessages", reflect.TypeOf((*MockClient)(nil).UnlabelMessages), arg0, arg1, arg2)
}
// Unlock mocks base method.
func (m *MockClient) Unlock(arg0 context.Context, arg1 []byte) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Unlock", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// Unlock indicates an expected call of Unlock.
func (mr *MockClientMockRecorder) Unlock(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unlock", reflect.TypeOf((*MockClient)(nil).Unlock), arg0, arg1)
}
// UpdateLabel mocks base method.
func (m *MockClient) UpdateLabel(arg0 context.Context, arg1 *pmapi.Label) (*pmapi.Label, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateLabel", arg0, arg1)
ret0, _ := ret[0].(*pmapi.Label)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateLabel indicates an expected call of UpdateLabel.
func (mr *MockClientMockRecorder) UpdateLabel(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLabel", reflect.TypeOf((*MockClient)(nil).UpdateLabel), arg0, arg1)
}
// UpdateLabelV4 mocks base method.
func (m *MockClient) UpdateLabelV4(arg0 context.Context, arg1 *pmapi.Label) (*pmapi.Label, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateLabelV4", arg0, arg1)
ret0, _ := ret[0].(*pmapi.Label)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateLabelV4 indicates an expected call of UpdateLabelV4.
func (mr *MockClientMockRecorder) UpdateLabelV4(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLabelUpdateLabelV4", reflect.TypeOf((*MockClient)(nil).UpdateLabelV4), arg0, arg1)
}
// UpdateUser mocks base method.
func (m *MockClient) UpdateUser(arg0 context.Context) (*pmapi.User, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUser", arg0)
ret0, _ := ret[0].(*pmapi.User)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateUser indicates an expected call of UpdateUser.
func (mr *MockClientMockRecorder) UpdateUser(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUser", reflect.TypeOf((*MockClient)(nil).UpdateUser), arg0)
}
// MockManager is a mock of Manager interface.
type MockManager struct {
ctrl *gomock.Controller
recorder *MockManagerMockRecorder
}
// MockManagerMockRecorder is the mock recorder for MockManager.
type MockManagerMockRecorder struct {
mock *MockManager
}
// NewMockManager creates a new mock instance.
func NewMockManager(ctrl *gomock.Controller) *MockManager {
mock := &MockManager{ctrl: ctrl}
mock.recorder = &MockManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// AddConnectionObserver mocks base method.
func (m *MockManager) AddConnectionObserver(arg0 pmapi.ConnectionObserver) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddConnectionObserver", arg0)
}
// AddConnectionObserver indicates an expected call of AddConnectionObserver.
func (mr *MockManagerMockRecorder) AddConnectionObserver(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConnectionObserver", reflect.TypeOf((*MockManager)(nil).AddConnectionObserver), arg0)
}
// AllowProxy mocks base method.
func (m *MockManager) AllowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AllowProxy")
}
// AllowProxy indicates an expected call of AllowProxy.
func (mr *MockManagerMockRecorder) AllowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AllowProxy", reflect.TypeOf((*MockManager)(nil).AllowProxy))
}
// DisallowProxy mocks base method.
func (m *MockManager) DisallowProxy() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DisallowProxy")
}
// DisallowProxy indicates an expected call of DisallowProxy.
func (mr *MockManagerMockRecorder) DisallowProxy() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisallowProxy", reflect.TypeOf((*MockManager)(nil).DisallowProxy))
}
// DownloadAndVerify mocks base method.
func (m *MockManager) DownloadAndVerify(arg0 *crypto.KeyRing, arg1, arg2 string) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DownloadAndVerify", arg0, arg1, arg2)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DownloadAndVerify indicates an expected call of DownloadAndVerify.
func (mr *MockManagerMockRecorder) DownloadAndVerify(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DownloadAndVerify", reflect.TypeOf((*MockManager)(nil).DownloadAndVerify), arg0, arg1, arg2)
}
// NewClient mocks base method.
func (m *MockManager) NewClient(arg0, arg1, arg2 string, arg3 time.Time) pmapi.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NewClient", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(pmapi.Client)
return ret0
}
// NewClient indicates an expected call of NewClient.
func (mr *MockManagerMockRecorder) NewClient(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewClient", reflect.TypeOf((*MockManager)(nil).NewClient), arg0, arg1, arg2, arg3)
}
// NewClientWithLogin mocks base method.
func (m *MockManager) NewClientWithLogin(arg0 context.Context, arg1 string, arg2 []byte) (pmapi.Client, *pmapi.Auth, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NewClientWithLogin", arg0, arg1, arg2)
ret0, _ := ret[0].(pmapi.Client)
ret1, _ := ret[1].(*pmapi.Auth)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// NewClientWithLogin indicates an expected call of NewClientWithLogin.
func (mr *MockManagerMockRecorder) NewClientWithLogin(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewClientWithLogin", reflect.TypeOf((*MockManager)(nil).NewClientWithLogin), arg0, arg1, arg2)
}
// NewClientWithRefresh mocks base method.
func (m *MockManager) NewClientWithRefresh(arg0 context.Context, arg1, arg2 string) (pmapi.Client, *pmapi.AuthRefresh, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NewClientWithRefresh", arg0, arg1, arg2)
ret0, _ := ret[0].(pmapi.Client)
ret1, _ := ret[1].(*pmapi.AuthRefresh)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// NewClientWithRefresh indicates an expected call of NewClientWithRefresh.
func (mr *MockManagerMockRecorder) NewClientWithRefresh(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewClientWithRefresh", reflect.TypeOf((*MockManager)(nil).NewClientWithRefresh), arg0, arg1, arg2)
}
// ReportBug mocks base method.
func (m *MockManager) ReportBug(arg0 context.Context, arg1 pmapi.ReportBugReq) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReportBug", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// ReportBug indicates an expected call of ReportBug.
func (mr *MockManagerMockRecorder) ReportBug(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportBug", reflect.TypeOf((*MockManager)(nil).ReportBug), arg0, arg1)
}
// SendSimpleMetric mocks base method.
func (m *MockManager) SendSimpleMetric(arg0 context.Context, arg1, arg2, arg3 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendSimpleMetric", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(error)
return ret0
}
// SendSimpleMetric indicates an expected call of SendSimpleMetric.
func (mr *MockManagerMockRecorder) SendSimpleMetric(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendSimpleMetric", reflect.TypeOf((*MockManager)(nil).SendSimpleMetric), arg0, arg1, arg2, arg3)
}
// SetCookieJar mocks base method.
func (m *MockManager) SetCookieJar(arg0 http.CookieJar) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetCookieJar", arg0)
}
// SetCookieJar indicates an expected call of SetCookieJar.
func (mr *MockManagerMockRecorder) SetCookieJar(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCookieJar", reflect.TypeOf((*MockManager)(nil).SetCookieJar), arg0)
}
// SetLogging mocks base method.
func (m *MockManager) SetLogging(arg0 *logrus.Entry, arg1 bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetLogging", arg0, arg1)
}
// SetLogging indicates an expected call of SetLogging.
func (mr *MockManagerMockRecorder) SetLogging(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLogging", reflect.TypeOf((*MockManager)(nil).SetLogging), arg0, arg1)
}
// SetRetryCount mocks base method.
func (m *MockManager) SetRetryCount(arg0 int) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetRetryCount", arg0)
}
// SetRetryCount indicates an expected call of SetRetryCount.
func (mr *MockManagerMockRecorder) SetRetryCount(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetRetryCount", reflect.TypeOf((*MockManager)(nil).SetRetryCount), arg0)
}
// SetTransport mocks base method.
func (m *MockManager) SetTransport(arg0 http.RoundTripper) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetTransport", arg0)
}
// SetTransport indicates an expected call of SetTransport.
func (mr *MockManagerMockRecorder) SetTransport(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTransport", reflect.TypeOf((*MockManager)(nil).SetTransport), arg0)
}

View File

@ -1,40 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
type ConnectionObserver interface {
OnDown()
OnUp()
}
type observer struct {
onDown, onUp func()
}
// NewConnectionObserver is a helper function to create a new connection observer from two callbacks.
// It doesn't need to be used; anything which implements the ConnectionObserver interface can be an observer.
func NewConnectionObserver(onDown, onUp func()) ConnectionObserver {
return &observer{
onDown: onDown,
onUp: onUp,
}
}
func (o observer) OnDown() { o.onDown() }
func (o observer) OnUp() { o.onUp() }

View File

@ -1,32 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
const defaultPageSize = 100
func doPaged(elements []string, pageSize int, fn func([]string) error) error { //nolint:unparam
for len(elements) > pageSize {
if err := fn(elements[:pageSize]); err != nil {
return err
}
elements = elements[pageSize:]
}
return fn(elements)
}

View File

@ -1,44 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"encoding/base64"
"github.com/ProtonMail/go-srp"
"github.com/pkg/errors"
)
// HashMailboxPassword expectects 128bit long salt encoded by standard base64.
func HashMailboxPassword(password []byte, salt string) ([]byte, error) {
if salt == "" {
return password, nil
}
decodedSalt, err := base64.StdEncoding.DecodeString(salt)
if err != nil {
return nil, errors.Wrap(err, "failed to decode salt")
}
hash, err := srp.MailboxPassword(password, decodedSalt)
if err != nil {
return nil, errors.Wrap(err, "failed to hash password")
}
return hash[len(hash)-31:], nil
}

View File

@ -1,44 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestMailboxPassword(t *testing.T) {
// wantHash was generated with passprase and salt defined below. It
// should not change when changing implementation of the function.
wantHash := []byte("B5nwpsJQSTJ16ldr64Vdq6oeCCn32Fi")
// Valid salt is 128bit long (16bytes)
// $echo aaaabbbbccccdddd | base64
salt := "YWFhYWJiYmJjY2NjZGRkZAo="
passphrase := []byte("random")
r := require.New(t)
_, err := HashMailboxPassword(passphrase, "badsalt")
r.Error(err)
haveHash, err := HashMailboxPassword(passphrase, salt)
r.NoError(err)
r.Equal(wantHash, haveHash)
}

View File

@ -1,24 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"github.com/sirupsen/logrus"
)
var log = logrus.WithField("pkg", "pmapi") //nolint:gochecknoglobals

View File

@ -1,76 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"os"
"strings"
"github.com/ProtonMail/gopenpgp/v2/crypto"
)
const (
testMailboxPassword = "apple"
testMailboxPasswordLegacy = "123"
)
var (
testPrivateKeyRing *crypto.KeyRing
testPublicKeyRing *crypto.KeyRing
)
func init() {
testPrivateKey := readTestFile("testPrivateKey", false)
testPublicKey := readTestFile("testPublicKey", false)
var err error
privKey, err := crypto.NewKeyFromArmored(testPrivateKey)
if err != nil {
panic(err)
}
privKeyUnlocked, err := privKey.Unlock([]byte(testMailboxPassword))
if err != nil {
panic(err)
}
pubKey, err := crypto.NewKeyFromArmored(testPublicKey)
if err != nil {
panic(err)
}
if testPrivateKeyRing, err = crypto.NewKeyRing(privKeyUnlocked); err != nil {
panic(err)
}
if testPublicKeyRing, err = crypto.NewKeyRing(pubKey); err != nil {
panic(err)
}
}
func readTestFile(name string, trimNewlines bool) string { //nolint:unparam
data, err := os.ReadFile("testdata/" + name)
if err != nil {
panic(err)
}
if trimNewlines {
return strings.TrimRight(string(data), "\n")
}
return string(data)
}

View File

@ -1,181 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"math/rand"
"net/http"
"strconv"
"time"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
const (
errCodeUpgradeApplication = 5003
errCodePasswordWrong = 8002
errCodeAuthPaidPlanRequired = 10004
)
type Error struct {
Code int
Message string `json:"Error"`
}
func (err Error) Error() string {
return err.Message
}
func (m *manager) catchAPIError(_ *resty.Client, res *resty.Response) error {
if !res.IsError() {
return nil
}
if res.StatusCode() == http.StatusUnauthorized {
return ErrUnauthorized
}
var err error
if apiErr, ok := res.Error().(*Error); ok {
switch {
case apiErr.Code == errCodeUpgradeApplication:
if m.cfg.UpgradeApplicationHandler != nil {
m.cfg.UpgradeApplicationHandler()
}
return ErrUpgradeApplication
case apiErr.Code == errCodePasswordWrong:
return ErrPasswordWrong
case apiErr.Code == errCodeAuthPaidPlanRequired:
return ErrPaidPlanRequired
default:
err = apiErr
}
} else {
err = errors.New(res.Status())
}
switch res.StatusCode() {
case http.StatusUnprocessableEntity:
err = ErrUnprocessableEntity{err}
case http.StatusBadRequest:
err = ErrBadRequest{err}
}
return err
}
func updateTime(_ *resty.Client, res *resty.Response) error {
if date, err := time.Parse(time.RFC1123, res.Header().Get("Date")); err != nil {
log.WithError(err).Warning("Cannot parse header date")
} else {
crypto.UpdateTime(date.Unix())
}
return nil
}
func logConnReuse(_ *resty.Client, res *resty.Response) error {
if !res.Request.TraceInfo().IsConnReused {
logrus.WithField("host", res.Request.URL).Trace("Connection was NOT reused")
}
return nil
}
func catchRetryAfter(_ *resty.Client, res *resty.Response) (time.Duration, error) {
if res.StatusCode() == http.StatusTooManyRequests {
if after := res.Header().Get("Retry-After"); after != "" {
l := log.
WithField("statusCode", res.StatusCode()).
WithField("url", res.Request.URL).
WithField("verb", res.Request.Method)
seconds, err := strconv.Atoi(after)
if err != nil {
l.WithError(err).Warning("Cannot convert Retry-After to number")
seconds = 10
}
// To avoid spikes when all clients retry at the same time, we add some random wait.
seconds += rand.Intn(10) //nolint:gosec // It is OK to use weak random number generator here.
l = l.WithField("seconds", seconds).WithField("start", time.Now().Unix())
// Maximum retry time in client is is one minute. But
// here wait times can be longer e.g. high API load
l.Warn("Retrying after induced by http code. Waiting now...")
time.Sleep(time.Duration(seconds) * time.Second)
l.Warn("Wait done")
return 0, nil
}
}
// 0 and no error means default behaviour which is exponential backoff with jitter.
return 0, nil
}
func (m *manager) shouldRetry(res *resty.Response, err error) bool {
if isRetryDisabled(res.Request.Context()) {
return false
}
if isTooManyRequest(res) {
return true
}
if isNoResponse(res, err) {
// Even if the context of request allows to retry we should check
// whether the server is reachable or not. In some cases the we can
// keep retrying but also report that connection is lost.
go m.pingUntilSuccess()
return true
}
return false
}
func isTooManyRequest(res *resty.Response) bool {
return res.StatusCode() == http.StatusTooManyRequests
}
func isNoResponse(res *resty.Response, err error) bool {
// Do not retry TLS failures
if errors.Is(err, ErrTLSMismatch) {
return false
}
return res.RawResponse == nil && err != nil
}
func wrapNoConnection(res *resty.Response, err error) (*resty.Response, error) {
if err, ok := err.(*resty.ResponseError); ok {
return res, err
}
if errors.Is(err, context.Canceled) {
return res, err
}
if res.RawResponse != nil {
return res, err
}
// Log useful information and return back nicer and clear error message.
logrus.WithError(err).WithField("url", res.Request.URL).Warn("No internet connection")
return res, ErrNoConnection
}

View File

@ -1,138 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"regexp"
"runtime"
"strconv"
"testing"
"time"
"github.com/hashicorp/go-multierror"
r "github.com/stretchr/testify/require"
)
var (
colRed = "\033[1;31m"
colNon = "\033[0;39m"
reHTTPCode = regexp.MustCompile(`(HTTP|get|post|put|delete)_(\d{3}).*.json`)
)
func newTestConfig(url string) Config {
return Config{
HostURL: url,
AppVersion: "GoPMAPI_1.0.14",
}
}
// newTestClient is old function and should be replaced everywhere by newTestClientCallbacks.
func newTestClient(h http.Handler) (*httptest.Server, Client) {
s := httptest.NewServer(h)
return s, newManager(newTestConfig(s.URL)).NewClient(testUID, testAccessToken, testRefreshToken, time.Now().Add(time.Hour))
}
func newTestClientCallbacks(tb testing.TB, callbacks ...func(testing.TB, http.ResponseWriter, *http.Request) string) (func(), Client) {
reqNum := 0
_, file, line, _ := runtime.Caller(1)
file = filepath.Base(file)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqNum++
if reqNum > len(callbacks) {
fmt.Printf(
"%s:%d: %sServer was requested %d times which is more requests than expected %d times%s\n\n",
file, line, colRed, reqNum, len(callbacks), colNon,
)
tb.FailNow()
}
response := callbacks[reqNum-1](tb, w, r)
if response != "" {
writeJSONResponsefromFile(tb, w, response, reqNum-1)
}
}))
finish := func() {
server.CloseClientConnections() // Closing without waiting for finishing requests.
if reqNum != len(callbacks) {
fmt.Printf(
"%s:%d: %sServer was requested %d times but expected to be %d times%s\n\n",
file, line, colRed, reqNum, len(callbacks), colNon,
)
tb.Error("server failed")
}
}
return finish, newManager(newTestConfig(server.URL)).NewClient(testUID, testAccessToken, testRefreshToken, time.Now().Add(time.Hour))
}
func checkMethodAndPath(r *http.Request, method, path string) error {
var result *multierror.Error
if err := checkHeader(r.Header, "x-pm-appversion", "GoPMAPI_1.0.14"); err != nil {
result = multierror.Append(result, err)
}
if r.Method != method {
err := fmt.Errorf("Invalid request method expected %v, got %v", method, r.Method)
result = multierror.Append(result, err)
}
if r.URL.RequestURI() != path {
err := fmt.Errorf("Invalid request path expected %v, got %v", path, r.URL.RequestURI())
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
}
func writeJSONResponsefromFile(tb testing.TB, w http.ResponseWriter, response string, reqNum int) {
if match := reHTTPCode.FindAllSubmatch([]byte(response), -1); len(match) != 0 {
httpCode, err := strconv.Atoi(string(match[0][len(match[0])-1]))
r.NoError(tb, err)
w.WriteHeader(httpCode)
}
f, err := os.Open("./testdata/routes/" + response)
r.NoError(tb, err)
w.Header().Set("content-type", "application/json;charset=utf-8")
w.Header().Set("x-test-pmapi-response", fmt.Sprintf("%s:%d", tb.Name(), reqNum))
_, err = io.Copy(w, f)
r.NoError(tb, err)
}
func checkHeader(h http.Header, field, exp string) error {
val := h.Get(field)
if val != exp {
msg := "wrong field %s expected %q but have %q"
return fmt.Errorf(msg, field, exp, val)
}
return nil
}
func isAuthReq(r *http.Request, uid, token string) error { //nolint:unparam always retrieves testUID
if err := checkHeader(r.Header, "x-pm-uid", uid); err != nil {
return err
}
if err := checkHeader(r.Header, "authorization", "Bearer "+token); err != nil { //nolint:revive can return the error right away but this is easier to read
return err
}
return nil
}

View File

@ -1,73 +0,0 @@
// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package pmapi
import (
"context"
"github.com/go-resty/resty/v2"
)
type MailSettings struct {
DisplayName string
Signature string `json:",omitempty"`
Theme string `json:",omitempty"`
AutoSaveContacts int
AutoWildcardSearch int
ComposerMode int
MessageButtons int
ShowImages int
ShowMoved int
ViewMode int
ViewLayout int
SwipeLeft int
SwipeRight int
AlsoArchive int
Hotkeys int
PMSignature int
ImageProxy int
TLS int
RightToLeft int
AttachPublicKey int
Sign int
PGPScheme PackageFlag
PromptPin int
Autocrypt int
NumMessagePerPage int
DraftMIMEType string
ReceiveMIMEType string
ShowMIMEType string
// Undocumented -- there's only `null` in example:
// AutoResponder string
}
// GetMailSettings gets contact details specified by contact ID.
func (c *client) GetMailSettings(ctx context.Context) (settings MailSettings, err error) {
var res struct {
MailSettings MailSettings
}
if _, err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/mail/v4/settings")
}); err != nil {
return MailSettings{}, err
}
return res.MailSettings, nil
}

Some files were not shown because too many files have changed in this diff Show More