mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
feat: add reloadkeys method
This commit is contained in:
@ -413,8 +413,8 @@ func (u *User) UpdateUser() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = u.client().Unlock([]byte(u.creds.MailboxPassword)); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
if err = u.client().ReloadKeys([]byte(u.creds.MailboxPassword)); err != nil {
|
||||
return errors.Wrap(err, "failed to reload keys")
|
||||
}
|
||||
|
||||
emails := u.client().Addresses().ActiveEmails()
|
||||
|
||||
@ -39,7 +39,7 @@ func TestUpdateUser(t *testing.T) {
|
||||
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
|
||||
|
||||
m.pmapiClient.EXPECT().UpdateUser().Return(nil, nil),
|
||||
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
m.pmapiClient.EXPECT().ReloadKeys([]byte(testCredentials.MailboxPassword)).Return(nil),
|
||||
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
|
||||
|
||||
m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}),
|
||||
|
||||
@ -197,25 +197,23 @@ func (c *client) Addresses() AddressList {
|
||||
}
|
||||
|
||||
// unlockAddresses unlocks all keys for all addresses of current user.
|
||||
func (c *client) unlockAddresses(passphrase []byte) (err error) {
|
||||
for _, a := range c.addresses {
|
||||
if a.HasKeys == MissingKeys {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.addrKeyRing[a.ID] != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var kr *crypto.KeyRing
|
||||
|
||||
if kr, err = a.Keys.UnlockAll(passphrase, c.userKeyRing); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.addrKeyRing[a.ID] = kr
|
||||
func (c *client) unlockAddress(passphrase []byte, address *Address) (err error) {
|
||||
if address == nil {
|
||||
return errors.New("address data is missing")
|
||||
}
|
||||
|
||||
if address.HasKeys == MissingKeys {
|
||||
return
|
||||
}
|
||||
|
||||
var kr *crypto.KeyRing
|
||||
|
||||
if kr, err = address.Keys.UnlockAll(passphrase, c.userKeyRing); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.addrKeyRing[address.ID] = kr
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -459,16 +459,5 @@ func (c *client) ClearData() {
|
||||
c.accessToken = ""
|
||||
c.addresses = nil
|
||||
c.user = nil
|
||||
|
||||
if c.userKeyRing != nil {
|
||||
c.userKeyRing.ClearPrivateParams()
|
||||
c.userKeyRing = nil
|
||||
}
|
||||
|
||||
for addrID, addr := range c.addrKeyRing {
|
||||
if addr != nil {
|
||||
addr.ClearPrivateParams()
|
||||
delete(c.addrKeyRing, addrID)
|
||||
}
|
||||
}
|
||||
c.clearKeys()
|
||||
}
|
||||
|
||||
@ -148,35 +148,56 @@ func (c *client) IsUnlocked() bool {
|
||||
return c.userKeyRing != nil
|
||||
}
|
||||
|
||||
// Unlock unlocks all the user and address keys using the given passphrase.
|
||||
// Unlock unlocks all the user and address keys using the given passphrase, creating user and address keyrings.
|
||||
// If the keyrings are already present, they are not recreated.
|
||||
func (c *client) Unlock(passphrase []byte) (err error) {
|
||||
c.keyRingLock.Lock()
|
||||
defer c.keyRingLock.Unlock()
|
||||
|
||||
// If the user already has a keyring, we already unlocked, so no need to try again.
|
||||
if c.userKeyRing != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = c.CurrentUser(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.user == nil || c.addresses == nil {
|
||||
return errors.New("user data is not loaded")
|
||||
if c.userKeyRing == nil {
|
||||
if err = c.unlockUser(passphrase); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
}
|
||||
}
|
||||
|
||||
if err = c.unlockUser(passphrase); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock user")
|
||||
}
|
||||
|
||||
if err = c.unlockAddresses(passphrase); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock addresses")
|
||||
for _, address := range c.addresses {
|
||||
if c.addrKeyRing[address.ID] == nil {
|
||||
if err = c.unlockAddress(passphrase, address); err != nil {
|
||||
return errors.Wrap(err, "failed to unlock address")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *client) ReloadKeys(passphrase []byte) (err error) {
|
||||
c.keyRingLock.Lock()
|
||||
defer c.keyRingLock.Unlock()
|
||||
|
||||
c.clearKeys()
|
||||
|
||||
return c.Unlock(passphrase)
|
||||
}
|
||||
|
||||
func (c *client) clearKeys() {
|
||||
if c.userKeyRing != nil {
|
||||
c.userKeyRing.ClearPrivateParams()
|
||||
c.userKeyRing = nil
|
||||
}
|
||||
|
||||
for id, kr := range c.addrKeyRing {
|
||||
if kr != nil {
|
||||
kr.ClearPrivateParams()
|
||||
}
|
||||
delete(c.addrKeyRing, id)
|
||||
}
|
||||
}
|
||||
|
||||
// Do makes an API request. It does not check for HTTP status code errors.
|
||||
func (c *client) Do(req *http.Request, retryUnauthorized bool) (res *http.Response, err error) {
|
||||
// Copy the request body in case we need to retry it.
|
||||
|
||||
@ -38,6 +38,7 @@ type Client interface {
|
||||
CurrentUser() (*User, error)
|
||||
UpdateUser() (*User, error)
|
||||
Unlock(passphrase []byte) (err error)
|
||||
ReloadKeys(passphrase []byte) (err error)
|
||||
IsUnlocked() bool
|
||||
|
||||
GetAddresses() (addresses AddressList, err error)
|
||||
|
||||
@ -125,7 +125,7 @@ func (key PMKey) unlock(passphrase []byte) (unlockedKey *crypto.Key, err error)
|
||||
type PMKeys []PMKey
|
||||
|
||||
// UnlockAll goes through each key and unlocks it, returning a keyring containing all unlocked keys,
|
||||
// or an error if at least one could not be unlocked.
|
||||
// or an error if no keys could be unlocked.
|
||||
// The passphrase is used to unlock the key unless the key's token and signature are both non-nil,
|
||||
// in which case the given userkey is used to deduce the passphrase.
|
||||
func (keys *PMKeys) UnlockAll(passphrase []byte, userKey *crypto.KeyRing) (kr *crypto.KeyRing, err error) {
|
||||
|
||||
@ -561,6 +561,20 @@ func (mr *MockClientMockRecorder) MarkMessagesUnread(arg0 interface{}) *gomock.C
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkMessagesUnread", reflect.TypeOf((*MockClient)(nil).MarkMessagesUnread), arg0)
|
||||
}
|
||||
|
||||
// ReloadKeys mocks base method
|
||||
func (m *MockClient) ReloadKeys(arg0 []byte) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ReloadKeys", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ReloadKeys indicates an expected call of ReloadKeys
|
||||
func (mr *MockClientMockRecorder) ReloadKeys(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadKeys", reflect.TypeOf((*MockClient)(nil).ReloadKeys), arg0)
|
||||
}
|
||||
|
||||
// ReorderAddresses mocks base method
|
||||
func (m *MockClient) ReorderAddresses(arg0 []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@ -90,8 +90,8 @@ type UserRes struct {
|
||||
|
||||
// unlockUser unlocks all the client's user keys using the given passphrase.
|
||||
func (c *client) unlockUser(passphrase []byte) (err error) {
|
||||
if c.userKeyRing != nil {
|
||||
return
|
||||
if c.user == nil {
|
||||
return errors.New("user data is not loaded")
|
||||
}
|
||||
|
||||
if c.userKeyRing, err = c.user.Keys.UnlockAll(passphrase, nil); err != nil {
|
||||
|
||||
@ -63,6 +63,14 @@ func (api *FakePMAPI) Unlock(passphrase []byte) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (api *FakePMAPI) ReloadKeys(passphrase []byte) (err error) {
|
||||
if _, err = api.UpdateUser(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return api.Unlock(passphrase)
|
||||
}
|
||||
|
||||
func (api *FakePMAPI) CurrentUser() (*pmapi.User, error) {
|
||||
return api.UpdateUser()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user