From cb7ae4d812d778e303fa8e02ee423695a87e2bfa Mon Sep 17 00:00:00 2001 From: Andrew Bernat Date: Sun, 15 Aug 2021 11:40:12 -0700 Subject: [PATCH] Refactor tlsobs main routine. Invoke `postScan` to post to scan endpoint instead of having the code inline. Added tests for postScan and building url to scan endpoint. --- README.md | 8 +++ tlsobs/main.go | 119 ++++++++++++++++++++++++++++---------------- tlsobs/scan_test.go | 38 ++++++++++++++ tlsobs/url_test.go | 51 +++++++++++++++++++ 4 files changed, 174 insertions(+), 42 deletions(-) create mode 100644 tlsobs/scan_test.go create mode 100644 tlsobs/url_test.go diff --git a/README.md b/README.md index 69c85bb6c..f7a75f492 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,14 @@ $ mkdir $GOPATH $ export PATH=$GOPATH/bin:$PATH ``` +Also required is bash >=v4. +```bash +$ /usr/bin/env bash +$ echo ${BASH_VERSINFO[0]} +4 +``` +Version 4 or more is ok. + Then get the binary: ```bash diff --git a/tlsobs/main.go b/tlsobs/main.go index 0b7e686c4..e5c7f48b5 100644 --- a/tlsobs/main.go +++ b/tlsobs/main.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "log" "net/http" + "net/url" "os" "strings" "text/tabwriter" @@ -34,10 +35,6 @@ func usage() { os.Args[0], os.Args[0]) } -type scan struct { - ID int64 `json:"scan_id"` -} - var ( observatory = flag.String("observatory", "https://tls-observatory.services.mozilla.com", "URL of the observatory") scanid = flag.Int64("scanid", 0, "View results from a previous scan instead of starting a new one. eg `1234`") @@ -54,55 +51,26 @@ var exitCode int = 0 func main() { var ( err error - scan scan - rescanP string results database.Scan resp *http.Response body []byte - target string ) flag.Usage = func() { usage() flag.PrintDefaults() } flag.Parse() - if *scanid > 0 { - goto getresults - } - if len(flag.Args()) != 1 { - fmt.Println("error: must take only 1 non-flag argument as the target") - usage() - os.Exit(1) - } - target = strings.TrimPrefix(flag.Arg(0), "https://") - // also trim http:// prefix ( in case someone has a really wrong idea of what - // the observatory does...) - target = strings.TrimPrefix(target, "http://") - target = strings.TrimSuffix(target, "/") // trailing slash - - if *rescan { - rescanP = "&rescan=true" - } - resp, err = http.Post(*observatory+"/api/v1/scan?target="+target+rescanP, "application/json", nil) - if err != nil { - panic(err) - } - defer resp.Body.Close() - body, err = ioutil.ReadAll(resp.Body) - if err != nil { - panic(err) - } - if resp.StatusCode != http.StatusOK { - log.Fatalf("Scan failed. HTTP %d: %s", resp.StatusCode, body) - } - err = json.Unmarshal(body, &scan) - if err != nil { - log.Fatalf("Scan initiation failed: %s", body) + if *scanid == 0 { + exitOnInvalidArg() + targetURL := mustURL(buildScanURL(*observatory, flag.Arg(0), *rescan)) + *scanid, err = postScan(targetURL) + if err != nil { + panic(err) + } } - *scanid = scan.ID fmt.Printf("Scanning %s (id %d)\n", flag.Arg(0), *scanid) -getresults: + has_cert := false for { resp, err = http.Get(fmt.Sprintf("%s/api/v1/results?id=%d", *observatory, *scanid)) @@ -147,7 +115,7 @@ getresults: } fmt.Printf("\n") if !results.Has_tls { - fmt.Printf("%s does not support SSL/TLS\n", target) + fmt.Printf("%s does not support SSL/TLS\n", results.Target) exitCode = 5 } else { if *printRaw { @@ -160,6 +128,35 @@ getresults: os.Exit(exitCode) } +func postScan(scanURL *url.URL) (int64, error) { + type scan struct { + ID int64 `json:"scan_id"` + } + + resp, err := http.Post(scanURL.String(), "application/json", nil) + if err != nil { + log.Printf("unable to post to '%s': %v", scanURL, err) + return 0, err + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Printf("unable to read response from '%s': %v", scanURL, err) + return 0, err + } + if resp.StatusCode != http.StatusOK { + return 0, fmt.Errorf("scan failed with error code %d: %s", resp.StatusCode, body) + } + + var scanJSON scan + err = json.Unmarshal(body, &scanJSON) + if err != nil { + return 0, fmt.Errorf("unexpected scan response %s: %v", body, err) + + } + return scanJSON.ID, nil +} + func printCert(id int64) { var ( cert certificate.Certificate @@ -326,3 +323,41 @@ func getPaths(id int64) (paths certificate.Paths) { } return } + +func mustURL(u string) *url.URL { + parsed, err := url.Parse(u) + if err != nil { + log.Fatalf("url '%s' is invalid: %v", u, err) + } + return parsed +} + +func buildScanURL(observatoryURL string, targetURL string, rescan bool) string { + //goland:noinspection HttpUrlsUsage + normalizedTargetURL := string(trim(targetURL). + TrimPrefix("https://"). + TrimPrefix("http://"). //someone has a really wrong idea of what the observatory does... + TrimSuffix("/")) // trailing slash + rescanP := "" + if rescan { + rescanP = "&rescan=true" + } + return observatoryURL + "/api/v1/scan?target=" + normalizedTargetURL + rescanP +} + +func exitOnInvalidArg() { + if len(flag.Args()) != 1 { + fmt.Println("error: must take only 1 non-flag argument as the target") + usage() + os.Exit(1) + } +} + +type trim string + +func (t trim) TrimPrefix(prefix string) trim { + return trim(strings.TrimPrefix(string(t), prefix)) +} +func (t trim) TrimSuffix(suffix string) trim { + return trim(strings.TrimSuffix(string(t), suffix)) +} diff --git a/tlsobs/scan_test.go b/tlsobs/scan_test.go new file mode 100644 index 000000000..251ba491a --- /dev/null +++ b/tlsobs/scan_test.go @@ -0,0 +1,38 @@ +package main + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func TestScanTooManyRequests(t *testing.T) { + t.Parallel() + server := testServer(http.StatusTooManyRequests, "try later") + defer server.Close() + _, err := postScan(mustURL(server.URL)) + if err.Error() != fmt.Sprintf("scan failed with error code 429: try later") { + t.Fatalf("server responded with too many requests and client did not handle it") + } +} + +func TestScanOK(t *testing.T) { + t.Parallel() + server := testServer(http.StatusOK, "{\"scan_id\": 3}") + defer server.Close() + scanID, err := postScan(mustURL(server.URL)) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if scanID != 3 { + t.Fatalf("unexpected scan id: %d [expected 3]", scanID) + } +} + +func testServer(statusCode int, body string) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(statusCode) + _, _ = w.Write([]byte(body)) + })) +} diff --git a/tlsobs/url_test.go b/tlsobs/url_test.go new file mode 100644 index 000000000..e259c828f --- /dev/null +++ b/tlsobs/url_test.go @@ -0,0 +1,51 @@ +package main + +import ( + "fmt" + "testing" +) + +func TestBuildScanURL(t *testing.T) { + t.Parallel() + //goland:noinspection HttpUrlsUsage + testCases := []struct { + targetURL string + rescan bool + expectedURL string + }{ + {"target.com", false, "https://observatory.com/api/v1/scan?target=target.com"}, + {"http://target.com", false, "https://observatory.com/api/v1/scan?target=target.com"}, + {"https://target.com", false, "https://observatory.com/api/v1/scan?target=target.com"}, + {"https://target.com", true, "https://observatory.com/api/v1/scan?target=target.com&rescan=true"}, + {"https://target.com/", true, "https://observatory.com/api/v1/scan?target=target.com&rescan=true"}, + } + + for _, tc := range testCases { + name := fmt.Sprintf("target=%s rescan=%t", tc.targetURL, tc.rescan) + t.Run(name, func(t *testing.T) { + t.Parallel() + u := buildScanURL("https://observatory.com", tc.targetURL, tc.rescan) + if u != tc.expectedURL { + t.Fatalf("expected '%s' == '%s'", u, tc.expectedURL) + } + }) + } +} + +func TestMustURL(t *testing.T) { + t.Parallel() + testCases := []string{ + "https://observatory.com", + "observatory.com", + "foo", + } + for _, tc := range testCases { + t.Run(tc, func(t *testing.T) { + t.Parallel() + u := mustURL(tc) + if u.String() != tc { + t.Fatalf("expected '%s'", tc) + } + }) + } +}