feat: report corrupt update files

This commit is contained in:
James Houlahan
2021-01-15 10:31:19 +01:00
parent b9ee4a152a
commit adcf0827ee
4 changed files with 50 additions and 16 deletions

View File

@ -18,6 +18,7 @@
package main package main
import ( import (
"fmt"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
@ -43,9 +44,9 @@ var (
) )
func main() { // nolint[funlen] func main() { // nolint[funlen]
sentryReporter := sentry.NewReporter(appName, constants.Version) reporter := sentry.NewReporter(appName, constants.Version)
crashHandler := crash.NewHandler(sentryReporter.Report) crashHandler := crash.NewHandler(reporter.ReportException)
defer crashHandler.HandlePanic() defer crashHandler.HandlePanic()
locationsProvider, err := locations.NewDefaultProvider(filepath.Join(constants.VendorName, ConfigName)) locationsProvider, err := locations.NewDefaultProvider(filepath.Join(constants.VendorName, ConfigName))
@ -84,7 +85,7 @@ func main() { // nolint[funlen]
versioner := versioner.New(updatesPath) versioner := versioner.New(updatesPath)
exe, err := getPathToExecutable(ExeName, versioner, kr) exe, err := getPathToExecutable(ExeName, versioner, kr, reporter)
if err != nil { if err != nil {
if exe, err = getFallbackExecutable(ExeName, versioner); err != nil { if exe, err = getFallbackExecutable(ExeName, versioner); err != nil {
logrus.WithError(err).Fatal("Failed to find any launchable executable") logrus.WithError(err).Fatal("Failed to find any launchable executable")
@ -140,7 +141,12 @@ func appendLauncherPath(path string, args []string) []string {
return res return res
} }
func getPathToExecutable(name string, versioner *versioner.Versioner, kr *crypto.KeyRing) (string, error) { func getPathToExecutable(
name string,
versioner *versioner.Versioner,
kr *crypto.KeyRing,
reporter *sentry.Reporter,
) (string, error) {
versions, err := versioner.ListVersions() versions, err := versioner.ListVersions()
if err != nil { if err != nil {
return "", errors.Wrap(err, "failed to list available versions") return "", errors.Wrap(err, "failed to list available versions")
@ -152,6 +158,10 @@ func getPathToExecutable(name string, versioner *versioner.Versioner, kr *crypto
if err := version.VerifyFiles(kr); err != nil { if err := version.VerifyFiles(kr); err != nil {
vlog.WithError(err).Error("Files failed verification and will be removed") vlog.WithError(err).Error("Files failed verification and will be removed")
if err := reporter.ReportMessage(fmt.Sprintf("version %v failed verification: %v", version, err)); err != nil {
vlog.WithError(err).Error("Failed to report corrupt update files")
}
if err := version.Remove(); err != nil { if err := version.Remove(); err != nil {
vlog.WithError(err).Error("Failed to remove files") vlog.WithError(err).Error("Failed to remove files")
} }

View File

@ -93,7 +93,7 @@ func New( // nolint[funlen]
sentryReporter := sentry.NewReporter(appName, constants.Version) sentryReporter := sentry.NewReporter(appName, constants.Version)
crashHandler := crash.NewHandler( crashHandler := crash.NewHandler(
sentryReporter.Report, sentryReporter.ReportException,
crash.ShowErrorNotification(appName), crash.ShowErrorNotification(appName),
) )
defer crashHandler.HandlePanic() defer crashHandler.HandlePanic()

View File

@ -19,7 +19,8 @@ package versioner
import ( import (
"bytes" "bytes"
"errors" "encoding/base64"
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
@ -50,6 +51,10 @@ func (v Versions) Swap(i, j int) {
v[i], v[j] = v[j], v[i] v[i], v[j] = v[j], v[i]
} }
func (v *Version) String() string {
return fmt.Sprintf("%v", v.version)
}
// VerifyFiles verifies all files in the version directory. // VerifyFiles verifies all files in the version directory.
func (v *Version) VerifyFiles(kr *crypto.KeyRing) error { func (v *Version) VerifyFiles(kr *crypto.KeyRing) error {
fileBytes, err := ioutil.ReadFile(filepath.Join(v.path, sumFile)) // nolint[gosec] fileBytes, err := ioutil.ReadFile(filepath.Join(v.path, sumFile)) // nolint[gosec]
@ -76,7 +81,11 @@ func (v *Version) VerifyFiles(kr *crypto.KeyRing) error {
} }
if !bytes.Equal(sum, fileBytes) { if !bytes.Equal(sum, fileBytes) {
return errors.New("sum mismatch") return fmt.Errorf(
"sum mismatch: %v should be %v",
base64.RawStdEncoding.EncodeToString(sum),
base64.RawStdEncoding.EncodeToString(fileBytes),
)
} }
return nil return nil

View File

@ -67,8 +67,30 @@ func (r *Reporter) SetUserAgentProvider(uap userAgentProvider) {
r.uap = uap r.uap = uap
} }
func (r *Reporter) ReportException(i interface{}) error {
err := fmt.Errorf("recover: %v", i)
return r.scopedReport(func() {
if eventID := sentry.CaptureException(err); eventID != nil {
logrus.WithError(err).
WithField("reportID", *eventID).
Warn("Captured exception")
}
})
}
func (r *Reporter) ReportMessage(msg string) error {
return r.scopedReport(func() {
if eventID := sentry.CaptureMessage(msg); eventID != nil {
logrus.WithField("message", msg).
WithField("reportID", *eventID).
Warn("Captured message")
}
})
}
// Report reports a sentry crash with stacktrace from all goroutines. // Report reports a sentry crash with stacktrace from all goroutines.
func (r *Reporter) Report(i interface{}) (err error) { func (r *Reporter) scopedReport(doReport func()) error {
SkipDuringUnwind() SkipDuringUnwind()
if os.Getenv("PROTONMAIL_ENV") == "dev" { if os.Getenv("PROTONMAIL_ENV") == "dev" {
@ -83,8 +105,6 @@ func (r *Reporter) Report(i interface{}) (err error) {
userAgent = runtime.GOOS userAgent = runtime.GOOS
} }
reportErr := fmt.Errorf("recover: %v", i)
tags := map[string]string{ tags := map[string]string{
"OS": runtime.GOOS, "OS": runtime.GOOS,
"Client": r.appName, "Client": r.appName,
@ -93,21 +113,16 @@ func (r *Reporter) Report(i interface{}) (err error) {
"UserID": "", "UserID": "",
} }
var reportID string
sentry.WithScope(func(scope *sentry.Scope) { sentry.WithScope(func(scope *sentry.Scope) {
SkipDuringUnwind() SkipDuringUnwind()
scope.SetTags(tags) scope.SetTags(tags)
if eventID := sentry.CaptureException(reportErr); eventID != nil { doReport()
reportID = string(*eventID)
}
}) })
if !sentry.Flush(time.Second * 10) { if !sentry.Flush(time.Second * 10) {
return errors.New("failed to report sentry error") return errors.New("failed to report sentry error")
} }
logrus.WithField("error", reportErr).WithField("id", reportID).Warn("Sentry error reported")
return nil return nil
} }