refactor: GetBridgeAuthChannel --> GetAuthUpdateChannel

This commit is contained in:
James Houlahan
2020-04-08 10:29:01 +02:00
parent 042c340881
commit b01be382fc
10 changed files with 47 additions and 46 deletions

View File

@ -25,7 +25,7 @@ import (
"time" "time"
"github.com/ProtonMail/proton-bridge/internal/events" "github.com/ProtonMail/proton-bridge/internal/events"
m "github.com/ProtonMail/proton-bridge/internal/metrics" "github.com/ProtonMail/proton-bridge/internal/metrics"
"github.com/ProtonMail/proton-bridge/internal/preferences" "github.com/ProtonMail/proton-bridge/internal/preferences"
"github.com/ProtonMail/proton-bridge/internal/store" "github.com/ProtonMail/proton-bridge/internal/store"
"github.com/ProtonMail/proton-bridge/pkg/config" "github.com/ProtonMail/proton-bridge/pkg/config"
@ -117,7 +117,7 @@ func New(
} }
if pref.GetBool(preferences.FirstStartKey) { if pref.GetBool(preferences.FirstStartKey) {
b.SendMetric(m.New(m.Setup, m.FirstStart, m.Label(version))) b.SendMetric(metrics.New(metrics.Setup, metrics.FirstStart, metrics.Label(version)))
} }
go b.heartbeat() go b.heartbeat()
@ -134,7 +134,7 @@ func (b *Bridge) heartbeat() {
} }
nextTime := time.Unix(next, 0) nextTime := time.Unix(next, 0)
if time.Now().After(nextTime) { if time.Now().After(nextTime) {
b.SendMetric(m.New(m.Heartbeat, m.Daily, m.NoLabel)) b.SendMetric(metrics.New(metrics.Heartbeat, metrics.Daily, metrics.NoLabel))
nextTime = nextTime.Add(24 * time.Hour) nextTime = nextTime.Add(24 * time.Hour)
b.pref.Set(preferences.NextHeartbeatKey, strconv.FormatInt(nextTime.Unix(), 10)) b.pref.Set(preferences.NextHeartbeatKey, strconv.FormatInt(nextTime.Unix(), 10))
} }
@ -180,7 +180,7 @@ func (b *Bridge) watchBridgeOutdated() {
// watchAPIAuths receives auths from the client manager and sends them to the appropriate user. // watchAPIAuths receives auths from the client manager and sends them to the appropriate user.
func (b *Bridge) watchAPIAuths() { func (b *Bridge) watchAPIAuths() {
for auth := range b.clientManager.GetBridgeAuthChannel() { for auth := range b.clientManager.GetAuthUpdateChannel() {
logrus.Debug("Bridge received auth from ClientManager") logrus.Debug("Bridge received auth from ClientManager")
user, ok := b.hasUser(auth.UserID) user, ok := b.hasUser(auth.UserID)
@ -296,6 +296,9 @@ func (b *Bridge) connectExistingUser(user *User, auth *pmapi.Auth, hashedPasswor
// addNewUser adds a new bridge user to the bridge. // addNewUser adds a new bridge user to the bridge.
func (b *Bridge) addNewUser(user *pmapi.User, auth *pmapi.Auth, hashedPassword string) (err error) { func (b *Bridge) addNewUser(user *pmapi.User, auth *pmapi.Auth, hashedPassword string) (err error) {
b.lock.Lock()
defer b.lock.Unlock()
client := b.clientManager.GetClient(user.ID) client := b.clientManager.GetClient(user.ID)
if auth, err = client.AuthRefresh(auth.GenToken()); err != nil { if auth, err = client.AuthRefresh(auth.GenToken()); err != nil {
@ -325,7 +328,7 @@ func (b *Bridge) addNewUser(user *pmapi.User, auth *pmapi.Auth, hashedPassword s
return errors.Wrap(err, "failed to initialise user") return errors.Wrap(err, "failed to initialise user")
} }
b.SendMetric(m.New(m.Setup, m.NewUser, m.NoLabel)) b.SendMetric(metrics.New(metrics.Setup, metrics.NewUser, metrics.NoLabel))
return return
} }
@ -459,7 +462,7 @@ func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address,
} }
// SendMetric sends a metric. We don't want to return any errors, only log them. // SendMetric sends a metric. We don't want to return any errors, only log them.
func (b *Bridge) SendMetric(m m.Metric) { func (b *Bridge) SendMetric(m metrics.Metric) {
c := b.clientManager.GetClient("metric_reporter") c := b.clientManager.GetClient("metric_reporter")
defer c.Logout() defer c.Logout()

View File

@ -214,7 +214,7 @@ func testNewBridge(t *testing.T, m mocks) *Bridge {
m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes() m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes()
m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes() m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes()
m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any()) m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any())
m.clientManager.EXPECT().GetBridgeAuthChannel().Return(make(chan *pmapi.ClientAuth)) m.clientManager.EXPECT().GetAuthUpdateChannel().Return(make(chan *pmapi.ClientAuth))
bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", m.clientManager, m.credentialsStore) bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", m.clientManager, m.credentialsStore)

View File

@ -15,7 +15,7 @@
// You should have received a copy of the GNU General Public License // You should have received a copy of the GNU General Public License
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>. // along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
// Code generated by ./credits.sh at Thu Apr 16 13:43:04 CEST 2020. DO NOT EDIT. // Code generated by ./credits.sh at Fri Apr 17 13:33:28 CEST 2020. DO NOT EDIT.
package bridge package bridge

View File

@ -264,18 +264,18 @@ func (mr *MockClientManagerMockRecorder) GetAnonymousClient() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient))
} }
// GetBridgeAuthChannel mocks base method // GetAuthUpdateChannel mocks base method
func (m *MockClientManager) GetBridgeAuthChannel() chan pmapi.ClientAuth { func (m *MockClientManager) GetAuthUpdateChannel() chan pmapi.ClientAuth {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetBridgeAuthChannel") ret := m.ctrl.Call(m, "GetAuthUpdateChannel")
ret0, _ := ret[0].(chan pmapi.ClientAuth) ret0, _ := ret[0].(chan pmapi.ClientAuth)
return ret0 return ret0
} }
// GetBridgeAuthChannel indicates an expected call of GetBridgeAuthChannel // GetAuthUpdateChannel indicates an expected call of GetBridgeAuthChannel
func (mr *MockClientManagerMockRecorder) GetBridgeAuthChannel() *gomock.Call { func (mr *MockClientManagerMockRecorder) GetAuthUpdateChannel() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBridgeAuthChannel", reflect.TypeOf((*MockClientManager)(nil).GetBridgeAuthChannel)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthUpdateChannel", reflect.TypeOf((*MockClientManager)(nil).GetBridgeAuthChannel))
} }
// GetClient mocks base method // GetClient mocks base method

View File

@ -57,5 +57,5 @@ type ClientManager interface {
GetAnonymousClient() pmapi.Client GetAnonymousClient() pmapi.Client
AllowProxy() AllowProxy()
DisallowProxy() DisallowProxy()
GetBridgeAuthChannel() chan pmapi.ClientAuth GetAuthUpdateChannel() chan pmapi.ClientAuth
} }

View File

@ -105,7 +105,7 @@ func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) {
cm.roundTripper = rt cm.roundTripper = rt
} }
// GetRoundTripper sets the roundtripper used by clients created by this client manager. // GetRoundTripper gets the roundtripper used by clients created by this client manager.
func (cm *ClientManager) GetRoundTripper() (rt http.RoundTripper) { func (cm *ClientManager) GetRoundTripper() (rt http.RoundTripper) {
return cm.roundTripper return cm.roundTripper
} }
@ -212,15 +212,14 @@ func (cm *ClientManager) IsProxyEnabled() bool {
return cm.host != RootURL return cm.host != RootURL
} }
// SwitchToProxy returns a usable proxy server. // switchToReachableServer switches to using a reachable server (either proxy or standard API).
// TODO: Perhaps the name could be better -- we aren't only switching to a proxy but also to the standard API. func (cm *ClientManager) switchToReachableServer() (proxy string, err error) {
func (cm *ClientManager) SwitchToProxy() (proxy string, err error) {
cm.hostLocker.Lock() cm.hostLocker.Lock()
defer cm.hostLocker.Unlock() defer cm.hostLocker.Unlock()
logrus.Info("Attempting to switch to a proxy") logrus.Info("Attempting to switch to a proxy")
if proxy, err = cm.proxyProvider.findProxy(); err != nil { if proxy, err = cm.proxyProvider.findReachableServer(); err != nil {
err = errors.Wrap(err, "failed to find a usable proxy") err = errors.Wrap(err, "failed to find a usable proxy")
return return
} }
@ -254,8 +253,8 @@ func (cm *ClientManager) GetToken(userID string) string {
return cm.tokens[userID] return cm.tokens[userID]
} }
// GetBridgeAuthChannel returns a channel on which client auths can be received. // GetAuthUpdateChannel returns a channel on which client auths can be received.
func (cm *ClientManager) GetBridgeAuthChannel() chan ClientAuth { func (cm *ClientManager) GetAuthUpdateChannel() chan ClientAuth {
return cm.bridgeAuths return cm.bridgeAuths
} }
@ -346,6 +345,6 @@ func (cm *ClientManager) watchTokenExpiration(userID string) {
cm.clients[userID].AuthRefresh(cm.tokens[userID]) cm.clients[userID].AuthRefresh(cm.tokens[userID])
case <-expiration.cancel: case <-expiration.cancel:
logrus.WithField("userID", userID).Info("Auth was refreshed before it expired") logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired")
} }
} }

View File

@ -303,7 +303,7 @@ func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn
} }
// Switch to a proxy and retry the dial. // Switch to a proxy and retry the dial.
proxy, err := p.cm.SwitchToProxy() proxy, err := p.cm.switchToReachableServer()
if err != nil { if err != nil {
return return
} }

View File

@ -72,10 +72,9 @@ func newProxyProvider(providers []string, query string) (p *proxyProvider) { //
return return
} }
// findProxy returns a new working proxy domain. This includes the standard API. // findReachableServer returns a working API server (either proxy or standard API).
// It returns an error if the process takes longer than ProxySearchTime. // It returns an error if the process takes longer than ProxySearchTime.
// TODO: Perhaps the name can be better -- we might also return the standard API. func (p *proxyProvider) findReachableServer() (proxy string, err error) {
func (p *proxyProvider) findProxy() (proxy string, err error) {
if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) { if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) {
return "", errors.New("not looking for a proxy, too soon") return "", errors.New("not looking for a proxy, too soon")
} }

View File

@ -42,7 +42,7 @@ func TestProxyProvider_FindProxy(t *testing.T) {
p := newProxyProvider([]string{"not used"}, "not used") p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
url, err := p.findProxy() url, err := p.findReachableServer()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, proxy.URL, url) require.Equal(t, proxy.URL, url)
} }
@ -61,7 +61,7 @@ func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) {
p := newProxyProvider([]string{"not used"}, "not used") p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, goodProxy.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, goodProxy.URL}, nil }
url, err := p.findProxy() url, err := p.findReachableServer()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, goodProxy.URL, url) require.Equal(t, goodProxy.URL, url)
} }
@ -80,7 +80,7 @@ func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) {
p := newProxyProvider([]string{"not used"}, "not used") p := newProxyProvider([]string{"not used"}, "not used")
p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, anotherBadProxy.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, anotherBadProxy.URL}, nil }
_, err := p.findProxy() _, err := p.findReachableServer()
require.Error(t, err) require.Error(t, err)
} }
@ -95,8 +95,8 @@ func TestProxyProvider_FindProxy_LookupTimeout(t *testing.T) {
p.lookupTimeout = time.Second p.lookupTimeout = time.Second
p.dohLookup = func(q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil } p.dohLookup = func(q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil }
// The findProxy should fail because lookup takes 2 seconds but we only allow 1 second. // The findReachableServer should fail because lookup takes 2 seconds but we only allow 1 second.
_, err := p.findProxy() _, err := p.findReachableServer()
require.Error(t, err) require.Error(t, err)
} }
@ -113,8 +113,8 @@ func TestProxyProvider_FindProxy_FindTimeout(t *testing.T) {
p.findTimeout = time.Second p.findTimeout = time.Second
p.dohLookup = func(q, p string) ([]string, error) { return []string{slowProxy.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{slowProxy.URL}, nil }
// The findProxy should fail because lookup takes 2 seconds but we only allow 1 second. // The findReachableServer should fail because lookup takes 2 seconds but we only allow 1 second.
_, err := p.findProxy() _, err := p.findReachableServer()
require.Error(t, err) require.Error(t, err)
} }
@ -131,7 +131,7 @@ func TestProxyProvider_UseProxy(t *testing.T) {
cm.proxyProvider = p cm.proxyProvider = p
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
url, err := cm.SwitchToProxy() url, err := cm.switchToReachableServer()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, proxy.URL, url) require.Equal(t, proxy.URL, url)
require.Equal(t, proxy.URL, cm.GetHost()) require.Equal(t, proxy.URL, cm.GetHost())
@ -154,7 +154,7 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) {
cm.proxyProvider = p cm.proxyProvider = p
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL}, nil }
url, err := cm.SwitchToProxy() url, err := cm.switchToReachableServer()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, proxy1.URL, url) require.Equal(t, proxy1.URL, url)
require.Equal(t, proxy1.URL, cm.GetHost()) require.Equal(t, proxy1.URL, cm.GetHost())
@ -163,7 +163,7 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) {
time.Sleep(proxyLookupWait) time.Sleep(proxyLookupWait)
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy2.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy2.URL}, nil }
url, err = cm.SwitchToProxy() url, err = cm.switchToReachableServer()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, proxy2.URL, url) require.Equal(t, proxy2.URL, url)
require.Equal(t, proxy2.URL, cm.GetHost()) require.Equal(t, proxy2.URL, cm.GetHost())
@ -172,7 +172,7 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) {
time.Sleep(proxyLookupWait) time.Sleep(proxyLookupWait)
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy3.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy3.URL}, nil }
url, err = cm.SwitchToProxy() url, err = cm.switchToReachableServer()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, proxy3.URL, url) require.Equal(t, proxy3.URL, url)
require.Equal(t, proxy3.URL, cm.GetHost()) require.Equal(t, proxy3.URL, cm.GetHost())
@ -192,7 +192,7 @@ func TestProxyProvider_UseProxy_RevertAfterTime(t *testing.T) {
cm.proxyUseDuration = time.Second cm.proxyUseDuration = time.Second
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
url, err := cm.SwitchToProxy() url, err := cm.switchToReachableServer()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, proxy.URL, url) require.Equal(t, proxy.URL, url)
require.Equal(t, proxy.URL, cm.GetHost()) require.Equal(t, proxy.URL, cm.GetHost())
@ -214,7 +214,7 @@ func TestProxyProvider_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachab
cm.proxyProvider = p cm.proxyProvider = p
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil }
url, err := cm.SwitchToProxy() url, err := cm.switchToReachableServer()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, proxy.URL, url) require.Equal(t, proxy.URL, url)
require.Equal(t, proxy.URL, cm.GetHost()) require.Equal(t, proxy.URL, cm.GetHost())
@ -225,7 +225,7 @@ func TestProxyProvider_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachab
time.Sleep(proxyLookupWait) time.Sleep(proxyLookupWait)
// We should now find the original API URL if it is working again. // We should now find the original API URL if it is working again.
url, err = cm.SwitchToProxy() url, err = cm.switchToReachableServer()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, RootURL, url) require.Equal(t, RootURL, url)
require.Equal(t, RootURL, cm.GetHost()) require.Equal(t, RootURL, cm.GetHost())
@ -247,7 +247,7 @@ func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBl
// Find a proxy. // Find a proxy.
p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil } p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil }
url, err := cm.SwitchToProxy() url, err := cm.switchToReachableServer()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, proxy1.URL, url) require.Equal(t, proxy1.URL, url)
require.Equal(t, proxy1.URL, cm.GetHost()) require.Equal(t, proxy1.URL, cm.GetHost())
@ -259,7 +259,7 @@ func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBl
proxy1.Close() proxy1.Close()
// Should switch to the second proxy because both the first proxy and the protonmail API are blocked. // Should switch to the second proxy because both the first proxy and the protonmail API are blocked.
url, err = cm.SwitchToProxy() url, err = cm.switchToReachableServer()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, proxy2.URL, url) require.Equal(t, proxy2.URL, url)
require.Equal(t, proxy2.URL, cm.GetHost()) require.Equal(t, proxy2.URL, cm.GetHost())
@ -284,7 +284,7 @@ func TestProxyProvider_DoHLookup_Google(t *testing.T) {
func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) { func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) {
p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery)
url, err := p.findProxy() url, err := p.findReachableServer()
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, url) require.NotEmpty(t, url)
} }
@ -292,7 +292,7 @@ func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) {
func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) { func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) {
p := newProxyProvider([]string{"https://unreachable", TestGoogleProvider}, TestDoHQuery) p := newProxyProvider([]string{"https://unreachable", TestGoogleProvider}, TestDoHQuery)
url, err := p.findProxy() url, err := p.findReachableServer()
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, url) require.NotEmpty(t, url)
} }

View File

@ -18,7 +18,7 @@
package liveapi package liveapi
import ( import (
"github.com/ProtonMail/bridge/pkg/pmapi" "github.com/ProtonMail/proton-bridge/pkg/pmapi"
"github.com/cucumber/godog" "github.com/cucumber/godog"
"github.com/pkg/errors" "github.com/pkg/errors"
) )