GODT-35: New pmapi client and manager using resty

This commit is contained in:
James Houlahan
2021-02-22 18:23:51 +01:00
committed by Jakub
parent 1d538e8540
commit 2284e9ede1
163 changed files with 3333 additions and 8124 deletions

View File

@ -19,6 +19,7 @@ package fakeapi
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
@ -53,7 +54,7 @@ func newTestAttachment(iAtt int, msgID string) *pmapi.Attachment {
}
}
func (api *FakePMAPI) GetAttachment(attachmentID string) (io.ReadCloser, error) {
func (api *FakePMAPI) GetAttachment(_ context.Context, attachmentID string) (io.ReadCloser, error) {
if err := api.checkAndRecordCall(GET, "/mail/v4/attachments/"+attachmentID, nil); err != nil {
return nil, err
}
@ -65,7 +66,7 @@ func (api *FakePMAPI) GetAttachment(attachmentID string) (io.ReadCloser, error)
return ioutil.NopCloser(r), nil
}
func (api *FakePMAPI) CreateAttachment(attachment *pmapi.Attachment, data io.Reader, signature io.Reader) (*pmapi.Attachment, error) {
func (api *FakePMAPI) CreateAttachment(_ context.Context, attachment *pmapi.Attachment, data io.Reader, signature io.Reader) (*pmapi.Attachment, error) {
if err := api.checkAndRecordCall(POST, "/mail/v4/attachments", nil); err != nil {
return nil, err
}
@ -76,7 +77,3 @@ func (api *FakePMAPI) CreateAttachment(attachment *pmapi.Attachment, data io.Rea
attachment.KeyPackets = base64.StdEncoding.EncodeToString(bytes)
return attachment, nil
}
func (api *FakePMAPI) DeleteAttachment(attID string) error {
return api.checkAndRecordCall(DELETE, "/mail/v4/attachments/"+attID, nil)
}

View File

@ -18,76 +18,23 @@
package fakeapi
import (
"strings"
"context"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
func (api *FakePMAPI) SetAuths(auths chan<- *pmapi.Auth) {
api.auths = auths
}
func (api *FakePMAPI) AuthInfo(username string) (*pmapi.AuthInfo, error) {
if err := api.checkInternetAndRecordCall(POST, "/auth/info", &pmapi.AuthInfoReq{
Username: username,
}); err != nil {
return nil, err
}
authInfo := &pmapi.AuthInfo{}
user, ok := api.controller.usersByUsername[username]
if !ok {
// If username is wrong, API server will return empty but
// positive response
return authInfo, nil
}
authInfo.TwoFA = user.get2FAInfo()
return authInfo, nil
}
func (api *FakePMAPI) Auth(username, password string, authInfo *pmapi.AuthInfo) (*pmapi.Auth, error) {
if err := api.checkInternetAndRecordCall(POST, "/auth", &pmapi.AuthReq{
Username: username,
}); err != nil {
return nil, err
}
session, err := api.controller.createSessionIfAuthorized(username, password)
if err != nil {
return nil, err
}
api.setUID(session.uid)
if err := api.setUser(username); err != nil {
return nil, err
}
user := api.controller.usersByUsername[username]
auth := &pmapi.Auth{
TwoFA: user.get2FAInfo(),
RefreshToken: session.refreshToken,
ExpiresIn: 86400, // seconds
}
auth.DANGEROUSLYSetUID(session.uid)
api.sendAuth(auth)
return auth, nil
}
func (api *FakePMAPI) Auth2FA(twoFactorCode string, auth *pmapi.Auth) error {
if err := api.checkInternetAndRecordCall(POST, "/auth/2fa", &pmapi.Auth2FAReq{
TwoFactorCode: twoFactorCode,
}); err != nil {
func (api *FakePMAPI) Auth2FA(_ context.Context, req pmapi.Auth2FAReq) error {
if err := api.checkAndRecordCall(POST, "/auth/2fa", req); err != nil {
return err
}
if api.uid == "" {
return pmapi.ErrInvalidToken
return pmapi.ErrUnauthorized
}
session, ok := api.controller.sessionsByUID[api.uid]
if !ok {
return pmapi.ErrInvalidToken
return pmapi.ErrUnauthorized
}
session.hasFullScope = true
@ -95,92 +42,24 @@ func (api *FakePMAPI) Auth2FA(twoFactorCode string, auth *pmapi.Auth) error {
return nil
}
func (api *FakePMAPI) AuthRefresh(token string) (*pmapi.Auth, error) {
if api.lastToken == "" {
api.lastToken = token
}
split := strings.Split(token, ":")
if len(split) != 2 {
return nil, pmapi.ErrInvalidToken
}
if err := api.checkInternetAndRecordCall(POST, "/auth/refresh", &pmapi.AuthRefreshReq{
ResponseType: "token",
GrantType: "refresh_token",
UID: split[0],
RefreshToken: split[1],
RedirectURI: "https://protonmail.ch",
State: "random_string",
}); err != nil {
return nil, err
}
session, ok := api.controller.sessionsByUID[split[0]]
if !ok || session.refreshToken != split[1] {
api.log.WithField("token", token).
WithField("session", session).
Warn("Refresh token failed")
// The API server will respond normal error not 401 (check api)
// i.e. should not use `sendAuth(nil)`
api.setUID("")
return nil, pmapi.ErrInvalidToken
}
api.setUID(split[0])
if err := api.setUser(session.username); err != nil {
return nil, err
}
api.controller.refreshTheTokensForSession(session)
api.lastToken = split[0] + ":" + session.refreshToken
auth := &pmapi.Auth{
RefreshToken: session.refreshToken,
ExpiresIn: 86400,
}
auth.DANGEROUSLYSetUID(session.uid)
api.sendAuth(auth)
return auth, nil
}
func (api *FakePMAPI) AuthSalt() (string, error) {
if err := api.checkInternetAndRecordCall(GET, "/keys/salts", nil); err != nil {
func (api *FakePMAPI) AuthSalt(_ context.Context) (string, error) {
if err := api.checkAndRecordCall(GET, "/keys/salts", nil); err != nil {
return "", err
}
return "", nil
}
func (api *FakePMAPI) Logout() {
api.controller.clientManager.LogoutClient(api.userID)
func (api *FakePMAPI) AddAuthHandler(handler pmapi.AuthHandler) {
api.authHandlers = append(api.authHandlers, handler)
}
func (api *FakePMAPI) IsConnected() bool {
return api.uid != "" && api.lastToken != ""
}
func (api *FakePMAPI) DeleteAuth() error {
func (api *FakePMAPI) AuthDelete(_ context.Context) error {
if err := api.checkAndRecordCall(DELETE, "/auth", nil); err != nil {
return err
}
api.controller.deleteSession(api.uid)
return nil
}
func (api *FakePMAPI) ClearData() {
if api.userKeyRing != nil {
api.userKeyRing.ClearPrivateParams()
api.userKeyRing = nil
}
for addrID, addr := range api.addrKeyRing {
if addr != nil {
addr.ClearPrivateParams()
delete(api.addrKeyRing, addrID)
}
}
api.unsetUser()
}

View File

@ -18,6 +18,7 @@
package fakeapi
import (
"context"
"fmt"
"net/url"
"strconv"
@ -29,7 +30,7 @@ func (api *FakePMAPI) DecryptAndVerifyCards(cards []pmapi.Card) ([]pmapi.Card, e
return cards, nil
}
func (api *FakePMAPI) GetContactEmailByEmail(email string, page int, pageSize int) ([]pmapi.ContactEmail, error) {
func (api *FakePMAPI) GetContactEmailByEmail(_ context.Context, email string, page int, pageSize int) ([]pmapi.ContactEmail, error) {
v := url.Values{}
v.Set("Page", strconv.Itoa(page))
if pageSize > 0 {
@ -42,7 +43,7 @@ func (api *FakePMAPI) GetContactEmailByEmail(email string, page int, pageSize in
return []pmapi.ContactEmail{}, nil
}
func (api *FakePMAPI) GetContactByID(contactID string) (pmapi.Contact, error) {
func (api *FakePMAPI) GetContactByID(_ context.Context, contactID string) (pmapi.Contact, error) {
if err := api.checkAndRecordCall(GET, "/contacts/"+contactID, nil); err != nil {
return pmapi.Contact{}, err
}

View File

@ -32,7 +32,7 @@ type Controller struct {
labelIDGenerator idGenerator
messageIDGenerator idGenerator
tokenGenerator idGenerator
clientManager *pmapi.ClientManager
clientManager pmapi.Manager
// State controlled by test.
noInternetConnection bool
@ -46,7 +46,7 @@ type Controller struct {
log *logrus.Entry
}
func NewController(cm *pmapi.ClientManager) *Controller {
func NewController() (*Controller, pmapi.Manager) {
controller := &Controller{
lock: &sync.RWMutex{},
fakeAPIs: []*FakePMAPI{},
@ -54,7 +54,6 @@ func NewController(cm *pmapi.ClientManager) *Controller {
labelIDGenerator: 100, // We cannot use system label IDs.
messageIDGenerator: 0,
tokenGenerator: 1000, // No specific reason; 1000 simply feels right.
clientManager: cm,
noInternetConnection: false,
usersByUsername: map[string]*fakeUser{},
@ -67,11 +66,11 @@ func NewController(cm *pmapi.ClientManager) *Controller {
log: logrus.WithField("pkg", "fakeapi-controller"),
}
cm.SetClientConstructor(func(userID string) pmapi.Client {
fakeAPI := New(controller, userID)
controller.fakeAPIs = append(controller.fakeAPIs, fakeAPI)
return fakeAPI
})
cm := &fakePMAPIManager{
controller: controller,
}
return controller
controller.clientManager = cm
return controller, cm
}

View File

@ -21,6 +21,7 @@ import (
"encoding/json"
"fmt"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/nsf/jsondiff"
)
@ -39,23 +40,31 @@ type fakeCall struct {
request []byte
}
func (ctl *Controller) recordCall(method method, path string, req interface{}) {
func (ctl *Controller) recordCall(method method, path string, req interface{}) error {
ctl.lock.Lock()
defer ctl.lock.Unlock()
request := []byte{}
var request []byte
if req != nil {
var err error
request, err = json.Marshal(req)
if err != nil {
panic(err)
if request, err = json.Marshal(req); err != nil {
return err
}
}
ctl.calls = append(ctl.calls, &fakeCall{
method: method,
path: path,
request: request,
})
if ctl.noInternetConnection {
return pmapi.ErrNoConnection
}
return nil
}
func (ctl *Controller) PrintCalls() {

View File

@ -18,6 +18,7 @@
package fakeapi
import (
"context"
"errors"
"fmt"
"strings"
@ -51,7 +52,7 @@ func (ctl *Controller) ReorderAddresses(user *pmapi.User, addressIDs []string) e
return errors.New("no such user")
}
return api.ReorderAddresses(addressIDs)
return api.ReorderAddresses(context.TODO(), addressIDs)
}
func (ctl *Controller) AddUser(user *pmapi.User, addresses *pmapi.AddressList, password string, twoFAEnabled bool) error {

View File

@ -19,16 +19,36 @@ package fakeapi
import (
"errors"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
type fakeSession struct {
username string
uid, refreshToken string
hasFullScope bool
username string
uid, acc, ref string
hasFullScope bool
}
var errWrongNameOrPassword = errors.New("Incorrect login credentials. Please try again") //nolint[stylecheck]
func (ctl *Controller) checkAccessToken(uid, acc string) bool {
session, ok := ctl.sessionsByUID[uid]
if !ok {
return false
}
return session.uid == uid && session.acc == acc
}
func (ctl *Controller) checkScope(uid string) bool {
session, ok := ctl.sessionsByUID[uid]
if !ok {
return false
}
return session.hasFullScope
}
func (ctl *Controller) createSessionIfAuthorized(username, password string) (*fakeSession, error) {
// get user
user, ok := ctl.usersByUsername[username]
@ -40,16 +60,32 @@ func (ctl *Controller) createSessionIfAuthorized(username, password string) (*fa
session := &fakeSession{
username: username,
uid: ctl.tokenGenerator.next("uid"),
acc: ctl.tokenGenerator.next("acc"),
ref: ctl.tokenGenerator.next("ref"),
hasFullScope: !user.has2FA,
}
ctl.refreshTheTokensForSession(session)
ctl.sessionsByUID[session.uid] = session
return session, nil
}
func (ctl *Controller) refreshTheTokensForSession(session *fakeSession) {
session.refreshToken = ctl.tokenGenerator.next("refresh")
func (ctl *Controller) refreshSessionIfAuthorized(uid, ref string) (*fakeSession, error) {
session, ok := ctl.sessionsByUID[uid]
if !ok {
return nil, pmapi.ErrUnauthorized
}
if ref != session.ref {
return nil, pmapi.ErrUnauthorized
}
session.ref = ctl.tokenGenerator.next("ref")
session.acc = ctl.tokenGenerator.next("acc")
ctl.sessionsByUID[session.uid] = session
return session, nil
}
func (ctl *Controller) deleteSession(uid string) {

View File

@ -24,14 +24,3 @@ type fakeUser struct {
password string
has2FA bool
}
func (fu *fakeUser) get2FAInfo() *pmapi.TwoFactorInfo {
twoFAEnabled := 0
if fu.has2FA {
twoFAEnabled = 1
}
return &pmapi.TwoFactorInfo{
Enabled: twoFAEnabled,
TOTP: 0,
}
}

View File

@ -17,9 +17,13 @@
package fakeapi
import "github.com/ProtonMail/proton-bridge/pkg/pmapi"
import (
"context"
func (api *FakePMAPI) CountMessages(addressID string) ([]*pmapi.MessagesCount, error) {
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
func (api *FakePMAPI) CountMessages(_ context.Context, addressID string) ([]*pmapi.MessagesCount, error) {
if err := api.checkAndRecordCall(GET, "/mail/v4/messages/count?AddressID="+addressID, nil); err != nil {
return nil, err
}
@ -43,10 +47,16 @@ func (api *FakePMAPI) getCounts(addressID string) []*pmapi.MessagesCount {
counts.Unread++
}
} else {
var unread int
if message.Unread == pmapi.True {
unread = 1
}
allCounts[labelID] = &pmapi.MessagesCount{
LabelID: labelID,
Total: 1,
Unread: message.Unread,
Unread: unread,
}
}
}

View File

@ -18,10 +18,12 @@
package fakeapi
import (
"context"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
func (api *FakePMAPI) GetEvent(eventID string) (*pmapi.Event, error) {
func (api *FakePMAPI) GetEvent(_ context.Context, eventID string) (*pmapi.Event, error) {
if err := api.checkAndRecordCall(GET, "/events/"+eventID, nil); err != nil {
return nil, err
}

View File

@ -34,28 +34,64 @@ type FakePMAPI struct {
controller *Controller
eventIDGenerator idGenerator
auths chan<- *pmapi.Auth
user *pmapi.User
userKeyRing *crypto.KeyRing
addresses *pmapi.AddressList
addrKeyRing map[string]*crypto.KeyRing
labels []*pmapi.Label
messages []*pmapi.Message
events []*pmapi.Event
authHandlers []pmapi.AuthHandler
user *pmapi.User
userKeyRing *crypto.KeyRing
addresses *pmapi.AddressList
addrKeyRing map[string]*crypto.KeyRing
labels []*pmapi.Label
messages []*pmapi.Message
events []*pmapi.Event
// uid represents the API UID. It is the unique session ID.
uid, lastToken string
uid string
acc string // FIXME(conman): Check this is correct!
ref string // FIXME(conman): Check this is correct!
log *logrus.Entry
}
func New(controller *Controller, userID string) *FakePMAPI {
fakePMAPI := &FakePMAPI{
func newFakePMAPI(controller *Controller, userID, uid, acc, ref string) *FakePMAPI {
return &FakePMAPI{
controller: controller,
log: logrus.WithField("pkg", "fakeapi"),
log: logrus.WithField("pkg", "fakeapi").WithField("uid", uid),
uid: uid,
acc: acc, // FIXME(conman): This should be checked!
ref: ref, // FIXME(conman): This should be checked!
userID: userID,
addrKeyRing: make(map[string]*crypto.KeyRing),
}
}
func NewFakePMAPI(controller *Controller, username, userID, uid, acc, ref string) (*FakePMAPI, error) {
user, ok := controller.usersByUsername[username]
if !ok {
return nil, fmt.Errorf("user %s does not exist", username)
}
addresses, ok := controller.addressesByUsername[username]
if !ok {
addresses = &pmapi.AddressList{}
}
labels, ok := controller.labelsByUsername[username]
if !ok {
labels = []*pmapi.Label{}
}
messages, ok := controller.messagesByUsername[username]
if !ok {
messages = []*pmapi.Message{}
}
fakePMAPI := newFakePMAPI(controller, userID, uid, acc, ref)
fakePMAPI.log = fakePMAPI.log.WithField("username", username)
fakePMAPI.username = username
fakePMAPI.user = user.user
fakePMAPI.addresses = addresses
fakePMAPI.labels = labels
fakePMAPI.messages = messages
fakePMAPI.addEvent(&pmapi.Event{
EventID: fakePMAPI.eventIDGenerator.last("event"),
@ -63,7 +99,7 @@ func New(controller *Controller, userID string) *FakePMAPI {
More: 0,
})
return fakePMAPI
return fakePMAPI, nil
}
func (api *FakePMAPI) CloseConnections() {
@ -74,54 +110,24 @@ func (api *FakePMAPI) checkAndRecordCall(method method, path string, request int
api.controller.locker.Lock()
defer api.controller.locker.Unlock()
if err := api.checkInternetAndRecordCall(method, path, request); err != nil {
api.log.WithField(string(method), path).Trace("CALL")
if err := api.controller.recordCall(method, path, request); err != nil {
return err
}
// Try re-auth
if api.uid == "" && api.lastToken != "" {
api.log.WithField("lastToken", api.lastToken).Warn("Handling unauthorized status")
if _, err := api.AuthRefresh(api.lastToken); err != nil {
return err
}
// FIXME(conman): This needs to match conman behaviour. Should try auth refresh somehow.
if !api.controller.checkAccessToken(api.uid, api.acc) {
return pmapi.ErrUnauthorized
}
// Check client is authenticated. There is difference between
// * invalid token
// * and missing token
// but API treats it the same
if api.uid == "" {
return pmapi.ErrInvalidToken
}
// Any route (except Auth and AuthRefresh) can end with wrong
// token and it should be translated into logout
session, ok := api.controller.sessionsByUID[api.uid]
if !ok {
api.setUID("") // all consecutive requests will not send auth nil
api.sendAuth(nil)
return pmapi.ErrInvalidToken
} else if !session.hasFullScope {
// This is exact error string from the server (at least from documentation).
if path != "/auth/2fa" && !api.controller.checkScope(api.uid) {
return errors.New("Access token does not have sufficient scope") //nolint[stylecheck]
}
return nil
}
func (api *FakePMAPI) checkInternetAndRecordCall(method method, path string, request interface{}) error {
api.log.WithField(string(method), path).Trace("CALL")
api.controller.recordCall(method, path, request)
if api.controller.noInternetConnection {
return pmapi.ErrAPINotReachable
}
return nil
}
func (api *FakePMAPI) sendAuth(auth *pmapi.Auth) {
api.controller.clientManager.HandleAuth(pmapi.ClientAuth{UserID: api.userID, Auth: auth})
}
func (api *FakePMAPI) setUser(username string) error {
api.username = username
api.log = api.log.WithField("username", username)
@ -153,14 +159,9 @@ func (api *FakePMAPI) setUser(username string) error {
return nil
}
func (api *FakePMAPI) setUID(uid string) {
api.uid = uid
api.log = api.log.WithField("uid", api.uid)
api.log.Info("UID updated")
}
func (api *FakePMAPI) unsetUser() {
api.setUID("")
api.uid = ""
api.acc = "" // FIXME(conman): This should be checked!
api.user = nil
api.labels = nil
api.messages = nil

View File

@ -17,7 +17,11 @@
package fakeapi
import "github.com/ProtonMail/proton-bridge/pkg/pmapi"
import (
"context"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
// publicKey is used from pmapi unit tests.
// For now we need just some key, no need to have some specific one.
@ -55,7 +59,7 @@ a+hqY4Jr/a7ui40S+7xYRHKL/7ZAS4/grWllhU3dbNrwSzrOKwrA/U0/9t73
-----END PGP PUBLIC KEY BLOCK-----
`
func (api *FakePMAPI) GetPublicKeysForEmail(email string) (keys []pmapi.PublicKey, internal bool, err error) {
func (api *FakePMAPI) GetPublicKeysForEmail(_ context.Context, email string) (keys []pmapi.PublicKey, internal bool, err error) {
if err := api.checkAndRecordCall(GET, "/keys?Email="+email, nil); err != nil {
return nil, false, err
}

View File

@ -18,6 +18,7 @@
package fakeapi
import (
"context"
"fmt"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
@ -32,14 +33,14 @@ func (api *FakePMAPI) isLabelFolder(labelID string) bool {
return labelID == pmapi.InboxLabel || labelID == pmapi.ArchiveLabel || labelID == pmapi.SentLabel
}
func (api *FakePMAPI) ListLabels() ([]*pmapi.Label, error) {
func (api *FakePMAPI) ListLabels(context.Context) ([]*pmapi.Label, error) {
if err := api.checkAndRecordCall(GET, "/labels/1", nil); err != nil {
return nil, err
}
return api.labels, nil
}
func (api *FakePMAPI) CreateLabel(label *pmapi.Label) (*pmapi.Label, error) {
func (api *FakePMAPI) CreateLabel(_ context.Context, label *pmapi.Label) (*pmapi.Label, error) {
if err := api.checkAndRecordCall(POST, "/labels", &pmapi.LabelReq{Label: label}); err != nil {
return nil, err
}
@ -61,7 +62,7 @@ func (api *FakePMAPI) CreateLabel(label *pmapi.Label) (*pmapi.Label, error) {
return label, nil
}
func (api *FakePMAPI) UpdateLabel(label *pmapi.Label) (*pmapi.Label, error) {
func (api *FakePMAPI) UpdateLabel(_ context.Context, label *pmapi.Label) (*pmapi.Label, error) {
if err := api.checkAndRecordCall(PUT, "/labels", &pmapi.LabelReq{Label: label}); err != nil {
return nil, err
}
@ -81,7 +82,7 @@ func (api *FakePMAPI) UpdateLabel(label *pmapi.Label) (*pmapi.Label, error) {
return nil, fmt.Errorf("label %s does not exist", label.ID)
}
func (api *FakePMAPI) DeleteLabel(labelID string) error {
func (api *FakePMAPI) DeleteLabel(_ context.Context, labelID string) error {
if err := api.checkAndRecordCall(DELETE, "/labels/"+labelID, nil); err != nil {
return err
}

164
test/fakeapi/manager.go Normal file
View File

@ -0,0 +1,164 @@
package fakeapi
import (
"context"
"net/http"
"net/url"
"time"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/go-resty/resty/v2"
)
type fakePMAPIManager struct {
controller *Controller
}
func (m *fakePMAPIManager) NewClient(uid string, acc string, ref string, _ time.Time) pmapi.Client {
session, ok := m.controller.sessionsByUID[uid]
if !ok {
return newFakePMAPI(m.controller, "", "", "", "")
}
user, ok := m.controller.usersByUsername[session.username]
if !ok {
return newFakePMAPI(m.controller, "", "", "", "")
}
client, err := NewFakePMAPI(m.controller, session.username, user.user.ID, session.uid, session.acc, session.ref)
if err != nil {
return newFakePMAPI(m.controller, "", "", "", "")
}
m.controller.fakeAPIs = append(m.controller.fakeAPIs, client)
return client
}
func (m *fakePMAPIManager) NewClientWithRefresh(_ context.Context, uid, ref string) (pmapi.Client, *pmapi.Auth, error) {
if err := m.controller.recordCall(POST, "/auth/refresh", &pmapi.AuthRefreshReq{
UID: uid,
RefreshToken: ref,
ResponseType: "token",
GrantType: "refresh_token",
RedirectURI: "https://protonmail.ch",
State: "random_string",
}); err != nil {
return nil, nil, err
}
session, err := m.controller.refreshSessionIfAuthorized(uid, ref)
if err != nil {
return nil, nil, pmapi.ErrUnauthorized
}
user, ok := m.controller.usersByUsername[session.username]
if !ok {
return nil, nil, errWrongNameOrPassword
}
client, err := NewFakePMAPI(m.controller, session.username, user.user.ID, session.uid, session.acc, session.ref)
if err != nil {
return nil, nil, err
}
m.controller.fakeAPIs = append(m.controller.fakeAPIs, client)
auth := &pmapi.Auth{
UID: session.uid,
AccessToken: session.acc,
RefreshToken: session.ref,
ExpiresIn: 86400, // seconds,
}
if user.has2FA {
auth.TwoFA = pmapi.TwoFAInfo{
Enabled: pmapi.TOTPEnabled,
}
}
return client, auth, nil
}
func (m *fakePMAPIManager) NewClientWithLogin(_ context.Context, username string, password string) (pmapi.Client, *pmapi.Auth, error) {
if err := m.controller.recordCall(POST, "/auth/info", &pmapi.GetAuthInfoReq{Username: username}); err != nil {
return nil, nil, err
}
// If username is wrong, API server will return empty but positive response.
// However, we will fail to create a client, so we return error here.
user, ok := m.controller.usersByUsername[username]
if !ok {
return nil, nil, errWrongNameOrPassword
}
if err := m.controller.recordCall(POST, "/auth", &pmapi.AuthReq{Username: username}); err != nil {
return nil, nil, err
}
session, err := m.controller.createSessionIfAuthorized(username, password)
if err != nil {
return nil, nil, err
}
client, err := NewFakePMAPI(m.controller, username, user.user.ID, session.uid, session.acc, session.ref)
if err != nil {
return nil, nil, err
}
m.controller.fakeAPIs = append(m.controller.fakeAPIs, client)
auth := &pmapi.Auth{
UID: session.uid,
AccessToken: session.acc,
RefreshToken: session.ref,
ExpiresIn: 86400, // seconds,
}
if user.has2FA {
auth.TwoFA = pmapi.TwoFAInfo{
Enabled: pmapi.TOTPEnabled,
}
}
return client, auth, nil
}
func (*fakePMAPIManager) DownloadAndVerify(kr *crypto.KeyRing, url, sig string) ([]byte, error) {
panic("TODO")
}
func (*fakePMAPIManager) ReportBug(context.Context, pmapi.ReportBugReq) error {
panic("TODO")
}
func (m *fakePMAPIManager) SendSimpleMetric(_ context.Context, cat string, act string, lab string) error {
v := url.Values{}
v.Set("Category", cat)
v.Set("Action", act)
v.Set("Label", lab)
return m.controller.recordCall(GET, "/metrics?"+v.Encode(), nil)
}
func (*fakePMAPIManager) SetLogger(resty.Logger) {
panic("TODO")
}
func (*fakePMAPIManager) SetTransport(http.RoundTripper) {
panic("TODO")
}
func (*fakePMAPIManager) SetCookieJar(http.CookieJar) {
panic("TODO")
}
func (*fakePMAPIManager) SetRetryCount(int) {
panic("TODO")
}
func (*fakePMAPIManager) AddConnectionObserver(pmapi.ConnectionObserver) {
panic("TODO")
}

View File

@ -19,6 +19,7 @@ package fakeapi
import (
"bytes"
"context"
"fmt"
"time"
@ -29,7 +30,7 @@ import (
var errWasNotUpdated = errors.New("message was not updated")
func (api *FakePMAPI) GetMessage(apiID string) (*pmapi.Message, error) {
func (api *FakePMAPI) GetMessage(_ context.Context, apiID string) (*pmapi.Message, error) {
if err := api.checkAndRecordCall(GET, "/mail/v4/messages/"+apiID, nil); err != nil {
return nil, err
}
@ -49,7 +50,7 @@ func (api *FakePMAPI) GetMessage(apiID string) (*pmapi.Message, error) {
// * ID
// * Attachments
// * AutoWildcard
func (api *FakePMAPI) ListMessages(filter *pmapi.MessagesFilter) ([]*pmapi.Message, int, error) {
func (api *FakePMAPI) ListMessages(_ context.Context, filter *pmapi.MessagesFilter) ([]*pmapi.Message, int, error) {
if err := api.checkAndRecordCall(GET, "/mail/v4/messages", filter); err != nil {
return nil, 0, err
}
@ -131,10 +132,14 @@ func isMessageMatchingFilter(filter *pmapi.MessagesFilter, message *pmapi.Messag
return false
}
if filter.Unread != nil {
wantUnread := 0
var wantUnread pmapi.Boolean
if *filter.Unread {
wantUnread = 1
wantUnread = pmapi.True
} else {
wantUnread = pmapi.False
}
if message.Unread != wantUnread {
return false
}
@ -150,7 +155,7 @@ func copyFilteredMessage(message *pmapi.Message) *pmapi.Message {
return filteredMessage
}
func (api *FakePMAPI) CreateDraft(message *pmapi.Message, parentID string, action int) (*pmapi.Message, error) {
func (api *FakePMAPI) CreateDraft(ctx context.Context, message *pmapi.Message, parentID string, action int) (*pmapi.Message, error) {
if err := api.checkAndRecordCall(POST, "/mail/v4/messages", &pmapi.DraftReq{
Message: message,
ParentID: parentID,
@ -160,7 +165,7 @@ func (api *FakePMAPI) CreateDraft(message *pmapi.Message, parentID string, actio
return nil, err
}
if parentID != "" {
if _, err := api.GetMessage(parentID); err != nil {
if _, err := api.GetMessage(ctx, parentID); err != nil {
return nil, err
}
}
@ -174,11 +179,11 @@ func (api *FakePMAPI) CreateDraft(message *pmapi.Message, parentID string, actio
return message, nil
}
func (api *FakePMAPI) SendMessage(messageID string, sendMessageRequest *pmapi.SendMessageReq) (sent, parent *pmapi.Message, err error) {
func (api *FakePMAPI) SendMessage(ctx context.Context, messageID string, sendMessageRequest *pmapi.SendMessageReq) (sent, parent *pmapi.Message, err error) {
if err := api.checkAndRecordCall(POST, "/mail/v4/messages/"+messageID, sendMessageRequest); err != nil {
return nil, nil, err
}
message, err := api.GetMessage(messageID)
message, err := api.GetMessage(ctx, messageID)
if err != nil {
return nil, nil, errors.Wrap(err, "draft does not exist")
}
@ -188,7 +193,7 @@ func (api *FakePMAPI) SendMessage(messageID string, sendMessageRequest *pmapi.Se
return message, nil, nil
}
func (api *FakePMAPI) Import(importMessageRequests []*pmapi.ImportMsgReq) ([]*pmapi.ImportMsgRes, error) {
func (api *FakePMAPI) Import(_ context.Context, importMessageRequests pmapi.ImportMsgReqs) ([]*pmapi.ImportMsgRes, error) {
if err := api.checkAndRecordCall(POST, "/import", importMessageRequests); err != nil {
return nil, err
}
@ -211,7 +216,7 @@ func (api *FakePMAPI) Import(importMessageRequests []*pmapi.ImportMsgReq) ([]*pm
}
func (api *FakePMAPI) generateMessageFromImportRequest(msgReq *pmapi.ImportMsgReq) (*pmapi.Message, error) {
m, _, _, _, err := message.Parse(bytes.NewReader(msgReq.Body)) // nolint[dogsled]
m, _, _, _, err := message.Parse(bytes.NewReader(msgReq.Message)) // nolint[dogsled]
if err != nil {
return nil, err
}
@ -230,16 +235,16 @@ func (api *FakePMAPI) generateMessageFromImportRequest(msgReq *pmapi.ImportMsgRe
return &pmapi.Message{
ID: messageID,
ExternalID: m.ExternalID,
AddressID: msgReq.AddressID,
AddressID: msgReq.Metadata.AddressID,
Sender: m.Sender,
ToList: m.ToList,
Subject: m.Subject,
Unread: msgReq.Unread,
Unread: msgReq.Metadata.Unread,
LabelIDs: api.generateLabelIDsFromImportRequest(msgReq),
Body: m.Body,
Header: m.Header,
Flags: msgReq.Flags,
Time: msgReq.Time,
Flags: msgReq.Metadata.Flags,
Time: msgReq.Metadata.Time,
}, nil
}
@ -248,17 +253,17 @@ func (api *FakePMAPI) generateMessageFromImportRequest(msgReq *pmapi.ImportMsgRe
func (api *FakePMAPI) generateLabelIDsFromImportRequest(msgReq *pmapi.ImportMsgReq) []string {
isInSentOrInbox := false
labelIDs := []string{pmapi.AllMailLabel}
for _, labelID := range msgReq.LabelIDs {
for _, labelID := range msgReq.Metadata.LabelIDs {
if labelID == pmapi.InboxLabel || labelID == pmapi.SentLabel {
isInSentOrInbox = true
} else {
labelIDs = append(labelIDs, labelID)
}
}
if isInSentOrInbox && (msgReq.Flags&pmapi.FlagSent) != 0 {
if isInSentOrInbox && (msgReq.Metadata.Flags&pmapi.FlagSent) != 0 {
labelIDs = append(labelIDs, pmapi.SentLabel)
}
if isInSentOrInbox && (msgReq.Flags&pmapi.FlagReceived) != 0 {
if isInSentOrInbox && (msgReq.Metadata.Flags&pmapi.FlagReceived) != 0 {
labelIDs = append(labelIDs, pmapi.InboxLabel)
}
return labelIDs
@ -287,7 +292,7 @@ func (api *FakePMAPI) addMessage(message *pmapi.Message) {
api.addEventMessage(pmapi.EventCreate, message)
}
func (api *FakePMAPI) DeleteMessages(apiIDs []string) error {
func (api *FakePMAPI) DeleteMessages(_ context.Context, apiIDs []string) error {
err := api.deleteMessages(PUT, "/mail/v4/messages/delete", &pmapi.MessagesActionReq{
IDs: apiIDs,
}, func(message *pmapi.Message) bool {
@ -304,7 +309,7 @@ func (api *FakePMAPI) DeleteMessages(apiIDs []string) error {
return nil
}
func (api *FakePMAPI) EmptyFolder(labelID string, addressID string) error {
func (api *FakePMAPI) EmptyFolder(_ context.Context, labelID string, addressID string) error {
err := api.deleteMessages(DELETE, "/mail/v4/messages/empty?LabelID="+labelID+"&AddressID="+addressID, nil, func(message *pmapi.Message) bool {
return hasItem(message.LabelIDs, labelID) && message.AddressID == addressID
})
@ -340,7 +345,7 @@ func (api *FakePMAPI) deleteMessages(method method, path string, request interfa
return nil
}
func (api *FakePMAPI) LabelMessages(apiIDs []string, labelID string) error {
func (api *FakePMAPI) LabelMessages(_ context.Context, apiIDs []string, labelID string) error {
return api.updateMessages(PUT, "/mail/v4/messages/label", &pmapi.LabelMessagesReq{
IDs: apiIDs,
LabelID: labelID,
@ -366,7 +371,7 @@ func (api *FakePMAPI) LabelMessages(apiIDs []string, labelID string) error {
})
}
func (api *FakePMAPI) UnlabelMessages(apiIDs []string, labelID string) error {
func (api *FakePMAPI) UnlabelMessages(_ context.Context, apiIDs []string, labelID string) error {
return api.updateMessages(PUT, "/mail/v4/messages/unlabel", &pmapi.LabelMessagesReq{
IDs: apiIDs,
LabelID: labelID,
@ -384,7 +389,7 @@ func (api *FakePMAPI) UnlabelMessages(apiIDs []string, labelID string) error {
})
}
func (api *FakePMAPI) MarkMessagesRead(apiIDs []string) error {
func (api *FakePMAPI) MarkMessagesRead(_ context.Context, apiIDs []string) error {
return api.updateMessages(PUT, "/mail/v4/messages/read", &pmapi.MessagesActionReq{
IDs: apiIDs,
}, apiIDs, func(message *pmapi.Message) error {
@ -396,7 +401,7 @@ func (api *FakePMAPI) MarkMessagesRead(apiIDs []string) error {
})
}
func (api *FakePMAPI) MarkMessagesUnread(apiIDs []string) error {
func (api *FakePMAPI) MarkMessagesUnread(_ context.Context, apiIDs []string) error {
err := api.updateMessages(PUT, "/mail/v4/messages/unread", &pmapi.MessagesActionReq{
IDs: apiIDs,
}, apiIDs, func(message *pmapi.Message) error {

View File

@ -1,40 +0,0 @@
// Copyright (c) 2021 Proton Technologies AG
//
// This file is part of ProtonMail Bridge.Bridge.
//
// ProtonMail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// ProtonMail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
package fakeapi
import (
"net/url"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
func (api *FakePMAPI) Report(report pmapi.ReportReq) error {
return api.checkInternetAndRecordCall(POST, "/reports/bug", report)
}
func (api *FakePMAPI) SendSimpleMetric(category, action, label string) error {
v := url.Values{}
v.Set("Category", category)
v.Set("Action", action)
v.Set("Label", label)
return api.checkInternetAndRecordCall(GET, "/metrics?"+v.Encode(), nil)
}
func (api *FakePMAPI) ReportSentryCrash(err error) error {
return nil
}

View File

@ -18,11 +18,13 @@
package fakeapi
import (
"context"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/pkg/pmapi"
)
func (api *FakePMAPI) GetMailSettings() (pmapi.MailSettings, error) {
func (api *FakePMAPI) GetMailSettings(context.Context) (pmapi.MailSettings, error) {
if err := api.checkAndRecordCall(GET, "/mail/v4/settings", nil); err != nil {
return pmapi.MailSettings{}, err
}
@ -33,7 +35,7 @@ func (api *FakePMAPI) IsUnlocked() bool {
return api.userKeyRing != nil
}
func (api *FakePMAPI) Unlock(passphrase []byte) (err error) {
func (api *FakePMAPI) Unlock(_ context.Context, passphrase []byte) (err error) {
if api.userKeyRing != nil {
return
}
@ -63,19 +65,19 @@ func (api *FakePMAPI) Unlock(passphrase []byte) (err error) {
return nil
}
func (api *FakePMAPI) ReloadKeys(passphrase []byte) (err error) {
if _, err = api.UpdateUser(); err != nil {
func (api *FakePMAPI) ReloadKeys(ctx context.Context, passphrase []byte) (err error) {
if _, err = api.UpdateUser(ctx); err != nil {
return
}
return api.Unlock(passphrase)
return api.Unlock(ctx, passphrase)
}
func (api *FakePMAPI) CurrentUser() (*pmapi.User, error) {
return api.UpdateUser()
func (api *FakePMAPI) CurrentUser(ctx context.Context) (*pmapi.User, error) {
return api.UpdateUser(ctx)
}
func (api *FakePMAPI) UpdateUser() (*pmapi.User, error) {
func (api *FakePMAPI) UpdateUser(context.Context) (*pmapi.User, error) {
if err := api.checkAndRecordCall(GET, "/users", nil); err != nil {
return nil, err
}
@ -83,14 +85,14 @@ func (api *FakePMAPI) UpdateUser() (*pmapi.User, error) {
return api.user, nil
}
func (api *FakePMAPI) GetAddresses() (pmapi.AddressList, error) {
func (api *FakePMAPI) GetAddresses(context.Context) (pmapi.AddressList, error) {
if err := api.checkAndRecordCall(GET, "/addresses", nil); err != nil {
return nil, err
}
return *api.addresses, nil
}
func (api *FakePMAPI) ReorderAddresses(addressIDs []string) error {
func (api *FakePMAPI) ReorderAddresses(_ context.Context, addressIDs []string) error {
if err := api.checkAndRecordCall(PUT, "/addresses/order", nil); err != nil {
return err
}