feat: add reloadkeys method

This commit is contained in:
James Houlahan
2020-06-16 11:28:11 +02:00
parent f3e6af5571
commit 9241a9bdbf
10 changed files with 80 additions and 49 deletions

View File

@ -413,8 +413,8 @@ func (u *User) UpdateUser() error {
return err
}
if err = u.client().Unlock([]byte(u.creds.MailboxPassword)); err != nil {
return errors.Wrap(err, "failed to unlock user")
if err = u.client().ReloadKeys([]byte(u.creds.MailboxPassword)); err != nil {
return errors.Wrap(err, "failed to reload keys")
}
emails := u.client().Addresses().ActiveEmails()

View File

@ -39,7 +39,7 @@ func TestUpdateUser(t *testing.T) {
m.pmapiClient.EXPECT().Unlock([]byte("pass")).Return(nil),
m.pmapiClient.EXPECT().UpdateUser().Return(nil, nil),
m.pmapiClient.EXPECT().Unlock([]byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().ReloadKeys([]byte(testCredentials.MailboxPassword)).Return(nil),
m.pmapiClient.EXPECT().Addresses().Return([]*pmapi.Address{testPMAPIAddress}),
m.credentialsStore.EXPECT().UpdateEmails("user", []string{testPMAPIAddress.Email}),

View File

@ -197,25 +197,23 @@ func (c *client) Addresses() AddressList {
}
// unlockAddresses unlocks all keys for all addresses of current user.
func (c *client) unlockAddresses(passphrase []byte) (err error) {
for _, a := range c.addresses {
if a.HasKeys == MissingKeys {
continue
}
if c.addrKeyRing[a.ID] != nil {
continue
}
var kr *crypto.KeyRing
if kr, err = a.Keys.UnlockAll(passphrase, c.userKeyRing); err != nil {
return
}
c.addrKeyRing[a.ID] = kr
func (c *client) unlockAddress(passphrase []byte, address *Address) (err error) {
if address == nil {
return errors.New("address data is missing")
}
if address.HasKeys == MissingKeys {
return
}
var kr *crypto.KeyRing
if kr, err = address.Keys.UnlockAll(passphrase, c.userKeyRing); err != nil {
return
}
c.addrKeyRing[address.ID] = kr
return
}

View File

@ -459,16 +459,5 @@ func (c *client) ClearData() {
c.accessToken = ""
c.addresses = nil
c.user = nil
if c.userKeyRing != nil {
c.userKeyRing.ClearPrivateParams()
c.userKeyRing = nil
}
for addrID, addr := range c.addrKeyRing {
if addr != nil {
addr.ClearPrivateParams()
delete(c.addrKeyRing, addrID)
}
}
c.clearKeys()
}

View File

@ -148,35 +148,56 @@ func (c *client) IsUnlocked() bool {
return c.userKeyRing != nil
}
// Unlock unlocks all the user and address keys using the given passphrase.
// Unlock unlocks all the user and address keys using the given passphrase, creating user and address keyrings.
// If the keyrings are already present, they are not recreated.
func (c *client) Unlock(passphrase []byte) (err error) {
c.keyRingLock.Lock()
defer c.keyRingLock.Unlock()
// If the user already has a keyring, we already unlocked, so no need to try again.
if c.userKeyRing != nil {
return
}
if _, err = c.CurrentUser(); err != nil {
return
}
if c.user == nil || c.addresses == nil {
return errors.New("user data is not loaded")
if c.userKeyRing == nil {
if err = c.unlockUser(passphrase); err != nil {
return errors.Wrap(err, "failed to unlock user")
}
}
if err = c.unlockUser(passphrase); err != nil {
return errors.Wrap(err, "failed to unlock user")
}
if err = c.unlockAddresses(passphrase); err != nil {
return errors.Wrap(err, "failed to unlock addresses")
for _, address := range c.addresses {
if c.addrKeyRing[address.ID] == nil {
if err = c.unlockAddress(passphrase, address); err != nil {
return errors.Wrap(err, "failed to unlock address")
}
}
}
return
}
func (c *client) ReloadKeys(passphrase []byte) (err error) {
c.keyRingLock.Lock()
defer c.keyRingLock.Unlock()
c.clearKeys()
return c.Unlock(passphrase)
}
func (c *client) clearKeys() {
if c.userKeyRing != nil {
c.userKeyRing.ClearPrivateParams()
c.userKeyRing = nil
}
for id, kr := range c.addrKeyRing {
if kr != nil {
kr.ClearPrivateParams()
}
delete(c.addrKeyRing, id)
}
}
// Do makes an API request. It does not check for HTTP status code errors.
func (c *client) Do(req *http.Request, retryUnauthorized bool) (res *http.Response, err error) {
// Copy the request body in case we need to retry it.

View File

@ -38,6 +38,7 @@ type Client interface {
CurrentUser() (*User, error)
UpdateUser() (*User, error)
Unlock(passphrase []byte) (err error)
ReloadKeys(passphrase []byte) (err error)
IsUnlocked() bool
GetAddresses() (addresses AddressList, err error)

View File

@ -125,7 +125,7 @@ func (key PMKey) unlock(passphrase []byte) (unlockedKey *crypto.Key, err error)
type PMKeys []PMKey
// UnlockAll goes through each key and unlocks it, returning a keyring containing all unlocked keys,
// or an error if at least one could not be unlocked.
// or an error if no keys could be unlocked.
// The passphrase is used to unlock the key unless the key's token and signature are both non-nil,
// in which case the given userkey is used to deduce the passphrase.
func (keys *PMKeys) UnlockAll(passphrase []byte, userKey *crypto.KeyRing) (kr *crypto.KeyRing, err error) {

View File

@ -561,6 +561,20 @@ func (mr *MockClientMockRecorder) MarkMessagesUnread(arg0 interface{}) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkMessagesUnread", reflect.TypeOf((*MockClient)(nil).MarkMessagesUnread), arg0)
}
// ReloadKeys mocks base method
func (m *MockClient) ReloadKeys(arg0 []byte) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReloadKeys", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// ReloadKeys indicates an expected call of ReloadKeys
func (mr *MockClientMockRecorder) ReloadKeys(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadKeys", reflect.TypeOf((*MockClient)(nil).ReloadKeys), arg0)
}
// ReorderAddresses mocks base method
func (m *MockClient) ReorderAddresses(arg0 []string) error {
m.ctrl.T.Helper()

View File

@ -90,8 +90,8 @@ type UserRes struct {
// unlockUser unlocks all the client's user keys using the given passphrase.
func (c *client) unlockUser(passphrase []byte) (err error) {
if c.userKeyRing != nil {
return
if c.user == nil {
return errors.New("user data is not loaded")
}
if c.userKeyRing, err = c.user.Keys.UnlockAll(passphrase, nil); err != nil {

View File

@ -63,6 +63,14 @@ func (api *FakePMAPI) Unlock(passphrase []byte) (err error) {
return nil
}
func (api *FakePMAPI) ReloadKeys(passphrase []byte) (err error) {
if _, err = api.UpdateUser(); err != nil {
return
}
return api.Unlock(passphrase)
}
func (api *FakePMAPI) CurrentUser() (*pmapi.User, error) {
return api.UpdateUser()
}