forked from Silverfish/proton-bridge
GODT-35: Finish all details and make tests pass
This commit is contained in:
149
pkg/pmapi/dialer_pinning_test.go
Normal file
149
pkg/pmapi/dialer_pinning_test.go
Normal file
@ -0,0 +1,149 @@
|
||||
// Copyright (c) 2021 Proton Technologies AG
|
||||
//
|
||||
// This file is part of ProtonMail Bridge.
|
||||
//
|
||||
// ProtonMail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// ProtonMail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with ProtonMail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package pmapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
a "github.com/stretchr/testify/assert"
|
||||
r "github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTLSPinValid(t *testing.T) {
|
||||
called, _, cm := createClientWithPinningDialer(getRootURL())
|
||||
|
||||
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "pass") //nolint
|
||||
checkTLSIssueHandler(t, 0, called)
|
||||
}
|
||||
|
||||
func TestTLSPinBackup(t *testing.T) {
|
||||
called, dialer, cm := createClientWithPinningDialer(getRootURL())
|
||||
copyTrustedPins(dialer.pinChecker)
|
||||
dialer.pinChecker.trustedPins[1] = dialer.pinChecker.trustedPins[0]
|
||||
dialer.pinChecker.trustedPins[0] = ""
|
||||
|
||||
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "pass") //nolint
|
||||
checkTLSIssueHandler(t, 0, called)
|
||||
}
|
||||
|
||||
func TestTLSPinInvalid(t *testing.T) {
|
||||
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSONResponsefromFile(t, w, "/auth/info/post_response.json", 0)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
called, _, cm := createClientWithPinningDialer(ts.URL)
|
||||
|
||||
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "pass") //nolint
|
||||
checkTLSIssueHandler(t, 1, called)
|
||||
}
|
||||
|
||||
func TestTLSPinNoMatch(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
called, dialer, cm := createClientWithPinningDialer(getRootURL())
|
||||
|
||||
copyTrustedPins(dialer.pinChecker)
|
||||
for i := 0; i < len(dialer.pinChecker.trustedPins); i++ {
|
||||
dialer.pinChecker.trustedPins[i] = "testing"
|
||||
}
|
||||
|
||||
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "pass") //nolint
|
||||
_, _, _ = cm.NewClientWithLogin(context.Background(), "username", "pass") //nolint
|
||||
|
||||
// Check that it will be reported only once per session, but notified every time.
|
||||
r.Equal(t, 1, len(dialer.reporter.sentReports))
|
||||
checkTLSIssueHandler(t, 2, called)
|
||||
}
|
||||
|
||||
func TestTLSSignedCertWrongPublicKey(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
_, dialer, _ := createClientWithPinningDialer("")
|
||||
_, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443")
|
||||
r.Error(t, err, "expected dial to fail because of wrong public key")
|
||||
}
|
||||
|
||||
func TestTLSSignedCertTrustedPublicKey(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
_, dialer, _ := createClientWithPinningDialer("")
|
||||
copyTrustedPins(dialer.pinChecker)
|
||||
dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="W8/42Z0ffufwnHIOSndT+eVzBJSC0E8uTIC8O6mEliQ="`)
|
||||
_, err := dialer.DialTLS("tcp", "rsa4096.badssl.com:443")
|
||||
r.NoError(t, err, "expected dial to succeed because public key is known and cert is signed by CA")
|
||||
}
|
||||
|
||||
func TestTLSSelfSignedCertTrustedPublicKey(t *testing.T) {
|
||||
skipIfProxyIsSet(t)
|
||||
|
||||
_, dialer, _ := createClientWithPinningDialer("")
|
||||
copyTrustedPins(dialer.pinChecker)
|
||||
dialer.pinChecker.trustedPins = append(dialer.pinChecker.trustedPins, `pin-sha256="9SLklscvzMYj8f+52lp5ze/hY0CFHyLSPQzSpYYIBm8="`)
|
||||
_, err := dialer.DialTLS("tcp", "self-signed.badssl.com:443")
|
||||
r.NoError(t, err, "expected dial to succeed because public key is known despite cert being self-signed")
|
||||
}
|
||||
|
||||
func createClientWithPinningDialer(hostURL string) (*int, *PinningTLSDialer, *manager) {
|
||||
called := 0
|
||||
|
||||
cfg := Config{
|
||||
AppVersion: "Bridge_1.2.4-test",
|
||||
HostURL: hostURL,
|
||||
TLSIssueHandler: func() { called++ },
|
||||
}
|
||||
|
||||
dialer := NewPinningTLSDialer(cfg, NewBasicTLSDialer(cfg))
|
||||
|
||||
cm := newManager(cfg)
|
||||
cm.SetTransport(CreateTransportWithDialer(dialer))
|
||||
|
||||
return &called, dialer, cm
|
||||
}
|
||||
|
||||
func copyTrustedPins(pinChecker *pinChecker) {
|
||||
copiedPins := make([]string, len(pinChecker.trustedPins))
|
||||
copy(copiedPins, pinChecker.trustedPins)
|
||||
pinChecker.trustedPins = copiedPins
|
||||
}
|
||||
|
||||
func checkTLSIssueHandler(t *testing.T, wantCalledAtLeast int, called *int) {
|
||||
// TLSIssueHandler is called in goroutine se we need to wait a bit to be sure it was called.
|
||||
a.Eventually(
|
||||
t,
|
||||
func() bool {
|
||||
if wantCalledAtLeast == 0 {
|
||||
return *called == 0
|
||||
}
|
||||
// Dialer can do more attempts resulting in more calls.
|
||||
return *called >= wantCalledAtLeast
|
||||
},
|
||||
time.Second,
|
||||
10*time.Millisecond,
|
||||
)
|
||||
// Repeated again so it generates nice message.
|
||||
if wantCalledAtLeast == 0 {
|
||||
r.Equal(t, 0, *called)
|
||||
} else {
|
||||
r.GreaterOrEqual(t, *called, wantCalledAtLeast)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user