diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index ac95a385..fe29ca4d 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -25,7 +25,7 @@ import ( "time" "github.com/ProtonMail/proton-bridge/internal/events" - m "github.com/ProtonMail/proton-bridge/internal/metrics" + "github.com/ProtonMail/proton-bridge/internal/metrics" "github.com/ProtonMail/proton-bridge/internal/preferences" "github.com/ProtonMail/proton-bridge/internal/store" "github.com/ProtonMail/proton-bridge/pkg/config" @@ -117,7 +117,7 @@ func New( } if pref.GetBool(preferences.FirstStartKey) { - b.SendMetric(m.New(m.Setup, m.FirstStart, m.Label(version))) + b.SendMetric(metrics.New(metrics.Setup, metrics.FirstStart, metrics.Label(version))) } go b.heartbeat() @@ -134,7 +134,7 @@ func (b *Bridge) heartbeat() { } nextTime := time.Unix(next, 0) if time.Now().After(nextTime) { - b.SendMetric(m.New(m.Heartbeat, m.Daily, m.NoLabel)) + b.SendMetric(metrics.New(metrics.Heartbeat, metrics.Daily, metrics.NoLabel)) nextTime = nextTime.Add(24 * time.Hour) b.pref.Set(preferences.NextHeartbeatKey, strconv.FormatInt(nextTime.Unix(), 10)) } @@ -180,7 +180,7 @@ func (b *Bridge) watchBridgeOutdated() { // watchAPIAuths receives auths from the client manager and sends them to the appropriate user. func (b *Bridge) watchAPIAuths() { - for auth := range b.clientManager.GetBridgeAuthChannel() { + for auth := range b.clientManager.GetAuthUpdateChannel() { logrus.Debug("Bridge received auth from ClientManager") user, ok := b.hasUser(auth.UserID) @@ -296,6 +296,9 @@ func (b *Bridge) connectExistingUser(user *User, auth *pmapi.Auth, hashedPasswor // addNewUser adds a new bridge user to the bridge. func (b *Bridge) addNewUser(user *pmapi.User, auth *pmapi.Auth, hashedPassword string) (err error) { + b.lock.Lock() + defer b.lock.Unlock() + client := b.clientManager.GetClient(user.ID) if auth, err = client.AuthRefresh(auth.GenToken()); err != nil { @@ -325,7 +328,7 @@ func (b *Bridge) addNewUser(user *pmapi.User, auth *pmapi.Auth, hashedPassword s return errors.Wrap(err, "failed to initialise user") } - b.SendMetric(m.New(m.Setup, m.NewUser, m.NoLabel)) + b.SendMetric(metrics.New(metrics.Setup, metrics.NewUser, metrics.NoLabel)) return } @@ -459,7 +462,7 @@ func (b *Bridge) ReportBug(osType, osVersion, description, accountName, address, } // SendMetric sends a metric. We don't want to return any errors, only log them. -func (b *Bridge) SendMetric(m m.Metric) { +func (b *Bridge) SendMetric(m metrics.Metric) { c := b.clientManager.GetClient("metric_reporter") defer c.Logout() diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index 40c9185d..f8dd2820 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -214,7 +214,7 @@ func testNewBridge(t *testing.T, m mocks) *Bridge { m.config.EXPECT().GetDBDir().Return("/tmp").AnyTimes() m.config.EXPECT().GetIMAPCachePath().Return(cacheFile.Name()).AnyTimes() m.eventListener.EXPECT().Add(events.UpgradeApplicationEvent, gomock.Any()) - m.clientManager.EXPECT().GetBridgeAuthChannel().Return(make(chan *pmapi.ClientAuth)) + m.clientManager.EXPECT().GetAuthUpdateChannel().Return(make(chan *pmapi.ClientAuth)) bridge := New(m.config, m.prefProvider, m.PanicHandler, m.eventListener, "ver", m.clientManager, m.credentialsStore) diff --git a/internal/bridge/credits.go b/internal/bridge/credits.go index c92ba9b7..88539ca7 100644 --- a/internal/bridge/credits.go +++ b/internal/bridge/credits.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU General Public License // along with ProtonMail Bridge. If not, see . -// Code generated by ./credits.sh at Thu Apr 16 13:43:04 CEST 2020. DO NOT EDIT. +// Code generated by ./credits.sh at Fri Apr 17 13:33:28 CEST 2020. DO NOT EDIT. package bridge diff --git a/internal/bridge/mocks/mocks.go b/internal/bridge/mocks/mocks.go index bd4e9cda..0463001e 100644 --- a/internal/bridge/mocks/mocks.go +++ b/internal/bridge/mocks/mocks.go @@ -264,18 +264,18 @@ func (mr *MockClientManagerMockRecorder) GetAnonymousClient() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnonymousClient", reflect.TypeOf((*MockClientManager)(nil).GetAnonymousClient)) } -// GetBridgeAuthChannel mocks base method -func (m *MockClientManager) GetBridgeAuthChannel() chan pmapi.ClientAuth { +// GetAuthUpdateChannel mocks base method +func (m *MockClientManager) GetAuthUpdateChannel() chan pmapi.ClientAuth { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetBridgeAuthChannel") + ret := m.ctrl.Call(m, "GetAuthUpdateChannel") ret0, _ := ret[0].(chan pmapi.ClientAuth) return ret0 } -// GetBridgeAuthChannel indicates an expected call of GetBridgeAuthChannel -func (mr *MockClientManagerMockRecorder) GetBridgeAuthChannel() *gomock.Call { +// GetAuthUpdateChannel indicates an expected call of GetBridgeAuthChannel +func (mr *MockClientManagerMockRecorder) GetAuthUpdateChannel() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBridgeAuthChannel", reflect.TypeOf((*MockClientManager)(nil).GetBridgeAuthChannel)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthUpdateChannel", reflect.TypeOf((*MockClientManager)(nil).GetBridgeAuthChannel)) } // GetClient mocks base method diff --git a/internal/bridge/types.go b/internal/bridge/types.go index ea630ee4..c3b1012b 100644 --- a/internal/bridge/types.go +++ b/internal/bridge/types.go @@ -57,5 +57,5 @@ type ClientManager interface { GetAnonymousClient() pmapi.Client AllowProxy() DisallowProxy() - GetBridgeAuthChannel() chan pmapi.ClientAuth + GetAuthUpdateChannel() chan pmapi.ClientAuth } diff --git a/pkg/pmapi/clientmanager.go b/pkg/pmapi/clientmanager.go index 821dde51..1177c75f 100644 --- a/pkg/pmapi/clientmanager.go +++ b/pkg/pmapi/clientmanager.go @@ -105,7 +105,7 @@ func (cm *ClientManager) SetRoundTripper(rt http.RoundTripper) { cm.roundTripper = rt } -// GetRoundTripper sets the roundtripper used by clients created by this client manager. +// GetRoundTripper gets the roundtripper used by clients created by this client manager. func (cm *ClientManager) GetRoundTripper() (rt http.RoundTripper) { return cm.roundTripper } @@ -212,15 +212,14 @@ func (cm *ClientManager) IsProxyEnabled() bool { return cm.host != RootURL } -// SwitchToProxy returns a usable proxy server. -// TODO: Perhaps the name could be better -- we aren't only switching to a proxy but also to the standard API. -func (cm *ClientManager) SwitchToProxy() (proxy string, err error) { +// switchToReachableServer switches to using a reachable server (either proxy or standard API). +func (cm *ClientManager) switchToReachableServer() (proxy string, err error) { cm.hostLocker.Lock() defer cm.hostLocker.Unlock() logrus.Info("Attempting to switch to a proxy") - if proxy, err = cm.proxyProvider.findProxy(); err != nil { + if proxy, err = cm.proxyProvider.findReachableServer(); err != nil { err = errors.Wrap(err, "failed to find a usable proxy") return } @@ -254,8 +253,8 @@ func (cm *ClientManager) GetToken(userID string) string { return cm.tokens[userID] } -// GetBridgeAuthChannel returns a channel on which client auths can be received. -func (cm *ClientManager) GetBridgeAuthChannel() chan ClientAuth { +// GetAuthUpdateChannel returns a channel on which client auths can be received. +func (cm *ClientManager) GetAuthUpdateChannel() chan ClientAuth { return cm.bridgeAuths } @@ -346,6 +345,6 @@ func (cm *ClientManager) watchTokenExpiration(userID string) { cm.clients[userID].AuthRefresh(cm.tokens[userID]) case <-expiration.cancel: - logrus.WithField("userID", userID).Info("Auth was refreshed before it expired") + logrus.WithField("userID", userID).Debug("Auth was refreshed before it expired") } } diff --git a/pkg/pmapi/dialer_with_proxy.go b/pkg/pmapi/dialer_with_proxy.go index 9a3c7eee..5adc34fa 100644 --- a/pkg/pmapi/dialer_with_proxy.go +++ b/pkg/pmapi/dialer_with_proxy.go @@ -303,7 +303,7 @@ func (p *DialerWithPinning) dialWithProxyFallback(network, address string) (conn } // Switch to a proxy and retry the dial. - proxy, err := p.cm.SwitchToProxy() + proxy, err := p.cm.switchToReachableServer() if err != nil { return } diff --git a/pkg/pmapi/proxy.go b/pkg/pmapi/proxy.go index 2af15940..e2628711 100644 --- a/pkg/pmapi/proxy.go +++ b/pkg/pmapi/proxy.go @@ -72,10 +72,9 @@ func newProxyProvider(providers []string, query string) (p *proxyProvider) { // return } -// findProxy returns a new working proxy domain. This includes the standard API. +// findReachableServer returns a working API server (either proxy or standard API). // It returns an error if the process takes longer than ProxySearchTime. -// TODO: Perhaps the name can be better -- we might also return the standard API. -func (p *proxyProvider) findProxy() (proxy string, err error) { +func (p *proxyProvider) findReachableServer() (proxy string, err error) { if time.Now().Before(p.lastLookup.Add(proxyLookupWait)) { return "", errors.New("not looking for a proxy, too soon") } diff --git a/pkg/pmapi/proxy_test.go b/pkg/pmapi/proxy_test.go index 19a89c4d..aa678549 100644 --- a/pkg/pmapi/proxy_test.go +++ b/pkg/pmapi/proxy_test.go @@ -42,7 +42,7 @@ func TestProxyProvider_FindProxy(t *testing.T) { p := newProxyProvider([]string{"not used"}, "not used") p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } - url, err := p.findProxy() + url, err := p.findReachableServer() require.NoError(t, err) require.Equal(t, proxy.URL, url) } @@ -61,7 +61,7 @@ func TestProxyProvider_FindProxy_ChooseReachableProxy(t *testing.T) { p := newProxyProvider([]string{"not used"}, "not used") p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, goodProxy.URL}, nil } - url, err := p.findProxy() + url, err := p.findReachableServer() require.NoError(t, err) require.Equal(t, goodProxy.URL, url) } @@ -80,7 +80,7 @@ func TestProxyProvider_FindProxy_FailIfNoneReachable(t *testing.T) { p := newProxyProvider([]string{"not used"}, "not used") p.dohLookup = func(q, p string) ([]string, error) { return []string{badProxy.URL, anotherBadProxy.URL}, nil } - _, err := p.findProxy() + _, err := p.findReachableServer() require.Error(t, err) } @@ -95,8 +95,8 @@ func TestProxyProvider_FindProxy_LookupTimeout(t *testing.T) { p.lookupTimeout = time.Second p.dohLookup = func(q, p string) ([]string, error) { time.Sleep(2 * time.Second); return nil, nil } - // The findProxy should fail because lookup takes 2 seconds but we only allow 1 second. - _, err := p.findProxy() + // The findReachableServer should fail because lookup takes 2 seconds but we only allow 1 second. + _, err := p.findReachableServer() require.Error(t, err) } @@ -113,8 +113,8 @@ func TestProxyProvider_FindProxy_FindTimeout(t *testing.T) { p.findTimeout = time.Second p.dohLookup = func(q, p string) ([]string, error) { return []string{slowProxy.URL}, nil } - // The findProxy should fail because lookup takes 2 seconds but we only allow 1 second. - _, err := p.findProxy() + // The findReachableServer should fail because lookup takes 2 seconds but we only allow 1 second. + _, err := p.findReachableServer() require.Error(t, err) } @@ -131,7 +131,7 @@ func TestProxyProvider_UseProxy(t *testing.T) { cm.proxyProvider = p p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } - url, err := cm.SwitchToProxy() + url, err := cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, proxy.URL, url) require.Equal(t, proxy.URL, cm.GetHost()) @@ -154,7 +154,7 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) { cm.proxyProvider = p p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL}, nil } - url, err := cm.SwitchToProxy() + url, err := cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, proxy1.URL, url) require.Equal(t, proxy1.URL, cm.GetHost()) @@ -163,7 +163,7 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) { time.Sleep(proxyLookupWait) p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy2.URL}, nil } - url, err = cm.SwitchToProxy() + url, err = cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, proxy2.URL, url) require.Equal(t, proxy2.URL, cm.GetHost()) @@ -172,7 +172,7 @@ func TestProxyProvider_UseProxy_MultipleTimes(t *testing.T) { time.Sleep(proxyLookupWait) p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy3.URL}, nil } - url, err = cm.SwitchToProxy() + url, err = cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, proxy3.URL, url) require.Equal(t, proxy3.URL, cm.GetHost()) @@ -192,7 +192,7 @@ func TestProxyProvider_UseProxy_RevertAfterTime(t *testing.T) { cm.proxyUseDuration = time.Second p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } - url, err := cm.SwitchToProxy() + url, err := cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, proxy.URL, url) require.Equal(t, proxy.URL, cm.GetHost()) @@ -214,7 +214,7 @@ func TestProxyProvider_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachab cm.proxyProvider = p p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy.URL}, nil } - url, err := cm.SwitchToProxy() + url, err := cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, proxy.URL, url) require.Equal(t, proxy.URL, cm.GetHost()) @@ -225,7 +225,7 @@ func TestProxyProvider_UseProxy_RevertIfProxyStopsWorkingAndOriginalAPIIsReachab time.Sleep(proxyLookupWait) // We should now find the original API URL if it is working again. - url, err = cm.SwitchToProxy() + url, err = cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, RootURL, url) require.Equal(t, RootURL, cm.GetHost()) @@ -247,7 +247,7 @@ func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBl // Find a proxy. p.dohLookup = func(q, p string) ([]string, error) { return []string{proxy1.URL, proxy2.URL}, nil } - url, err := cm.SwitchToProxy() + url, err := cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, proxy1.URL, url) require.Equal(t, proxy1.URL, cm.GetHost()) @@ -259,7 +259,7 @@ func TestProxyProvider_UseProxy_FindSecondAlternativeIfFirstFailsAndAPIIsStillBl proxy1.Close() // Should switch to the second proxy because both the first proxy and the protonmail API are blocked. - url, err = cm.SwitchToProxy() + url, err = cm.switchToReachableServer() require.NoError(t, err) require.Equal(t, proxy2.URL, url) require.Equal(t, proxy2.URL, cm.GetHost()) @@ -284,7 +284,7 @@ func TestProxyProvider_DoHLookup_Google(t *testing.T) { func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) { p := newProxyProvider([]string{TestQuad9Provider, TestGoogleProvider}, TestDoHQuery) - url, err := p.findProxy() + url, err := p.findReachableServer() require.NoError(t, err) require.NotEmpty(t, url) } @@ -292,7 +292,7 @@ func TestProxyProvider_DoHLookup_FindProxy(t *testing.T) { func TestProxyProvider_DoHLookup_FindProxyFirstProviderUnreachable(t *testing.T) { p := newProxyProvider([]string{"https://unreachable", TestGoogleProvider}, TestDoHQuery) - url, err := p.findProxy() + url, err := p.findReachableServer() require.NoError(t, err) require.NotEmpty(t, url) } diff --git a/test/liveapi/users.go b/test/liveapi/users.go index 4238a091..da1c045b 100644 --- a/test/liveapi/users.go +++ b/test/liveapi/users.go @@ -18,7 +18,7 @@ package liveapi import ( - "github.com/ProtonMail/bridge/pkg/pmapi" + "github.com/ProtonMail/proton-bridge/pkg/pmapi" "github.com/cucumber/godog" "github.com/pkg/errors" )