Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor tlsobs main routine. #428

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ $ mkdir $GOPATH
$ export PATH=$GOPATH/bin:$PATH
```

Also required is bash >=v4.
```bash
$ /usr/bin/env bash
$ echo ${BASH_VERSINFO[0]}
4
```
Version 4 or more is ok.

Then get the binary:

```bash
Expand Down
119 changes: 77 additions & 42 deletions tlsobs/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io/ioutil"
"log"
"net/http"
"net/url"
"os"
"strings"
"text/tabwriter"
Expand Down Expand Up @@ -34,10 +35,6 @@ func usage() {
os.Args[0], os.Args[0])
}

type scan struct {
ID int64 `json:"scan_id"`
}

var (
observatory = flag.String("observatory", "https://tls-observatory.services.mozilla.com", "URL of the observatory")
scanid = flag.Int64("scanid", 0, "View results from a previous scan instead of starting a new one. eg `1234`")
Expand All @@ -54,55 +51,26 @@ var exitCode int = 0
func main() {
var (
err error
scan scan
rescanP string
results database.Scan
resp *http.Response
body []byte
target string
)
flag.Usage = func() {
usage()
flag.PrintDefaults()
}
flag.Parse()
if *scanid > 0 {
goto getresults
}
if len(flag.Args()) != 1 {
fmt.Println("error: must take only 1 non-flag argument as the target")
usage()
os.Exit(1)
}

target = strings.TrimPrefix(flag.Arg(0), "https://")
// also trim http:// prefix ( in case someone has a really wrong idea of what
// the observatory does...)
target = strings.TrimPrefix(target, "http://")
target = strings.TrimSuffix(target, "/") // trailing slash

if *rescan {
rescanP = "&rescan=true"
}
resp, err = http.Post(*observatory+"/api/v1/scan?target="+target+rescanP, "application/json", nil)
if err != nil {
panic(err)
}
defer resp.Body.Close()
body, err = ioutil.ReadAll(resp.Body)
if err != nil {
panic(err)
}
if resp.StatusCode != http.StatusOK {
log.Fatalf("Scan failed. HTTP %d: %s", resp.StatusCode, body)
}
err = json.Unmarshal(body, &scan)
if err != nil {
log.Fatalf("Scan initiation failed: %s", body)
if *scanid == 0 {
exitOnInvalidArg()
targetURL := mustURL(buildScanURL(*observatory, flag.Arg(0), *rescan))
*scanid, err = postScan(targetURL)
if err != nil {
panic(err)
}
}
*scanid = scan.ID
fmt.Printf("Scanning %s (id %d)\n", flag.Arg(0), *scanid)
getresults:

has_cert := false
for {
resp, err = http.Get(fmt.Sprintf("%s/api/v1/results?id=%d", *observatory, *scanid))
Expand Down Expand Up @@ -147,7 +115,7 @@ getresults:
}
fmt.Printf("\n")
if !results.Has_tls {
fmt.Printf("%s does not support SSL/TLS\n", target)
fmt.Printf("%s does not support SSL/TLS\n", results.Target)
exitCode = 5
} else {
if *printRaw {
Expand All @@ -160,6 +128,35 @@ getresults:
os.Exit(exitCode)
}

func postScan(scanURL *url.URL) (int64, error) {
type scan struct {
ID int64 `json:"scan_id"`
}

resp, err := http.Post(scanURL.String(), "application/json", nil)
if err != nil {
log.Printf("unable to post to '%s': %v", scanURL, err)
return 0, err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Printf("unable to read response from '%s': %v", scanURL, err)
return 0, err
}
if resp.StatusCode != http.StatusOK {
return 0, fmt.Errorf("scan failed with error code %d: %s", resp.StatusCode, body)
}

var scanJSON scan
err = json.Unmarshal(body, &scanJSON)
if err != nil {
return 0, fmt.Errorf("unexpected scan response %s: %v", body, err)

}
return scanJSON.ID, nil
}

func printCert(id int64) {
var (
cert certificate.Certificate
Expand Down Expand Up @@ -326,3 +323,41 @@ func getPaths(id int64) (paths certificate.Paths) {
}
return
}

func mustURL(u string) *url.URL {
parsed, err := url.Parse(u)
if err != nil {
log.Fatalf("url '%s' is invalid: %v", u, err)
}
return parsed
}

func buildScanURL(observatoryURL string, targetURL string, rescan bool) string {
//goland:noinspection HttpUrlsUsage
normalizedTargetURL := string(trim(targetURL).
TrimPrefix("https://").
TrimPrefix("http://"). //someone has a really wrong idea of what the observatory does...
TrimSuffix("/")) // trailing slash
rescanP := ""
if rescan {
rescanP = "&rescan=true"
}
return observatoryURL + "/api/v1/scan?target=" + normalizedTargetURL + rescanP
}

func exitOnInvalidArg() {
if len(flag.Args()) != 1 {
fmt.Println("error: must take only 1 non-flag argument as the target")
usage()
os.Exit(1)
}
}

type trim string

func (t trim) TrimPrefix(prefix string) trim {
return trim(strings.TrimPrefix(string(t), prefix))
}
func (t trim) TrimSuffix(suffix string) trim {
return trim(strings.TrimSuffix(string(t), suffix))
}
38 changes: 38 additions & 0 deletions tlsobs/scan_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package main

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
)

func TestScanTooManyRequests(t *testing.T) {
t.Parallel()
server := testServer(http.StatusTooManyRequests, "try later")
defer server.Close()
_, err := postScan(mustURL(server.URL))
if err.Error() != fmt.Sprintf("scan failed with error code 429: try later") {
t.Fatalf("server responded with too many requests and client did not handle it")
}
}

func TestScanOK(t *testing.T) {
t.Parallel()
server := testServer(http.StatusOK, "{\"scan_id\": 3}")
defer server.Close()
scanID, err := postScan(mustURL(server.URL))
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if scanID != 3 {
t.Fatalf("unexpected scan id: %d [expected 3]", scanID)
}
}

func testServer(statusCode int, body string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(statusCode)
_, _ = w.Write([]byte(body))
}))
}
51 changes: 51 additions & 0 deletions tlsobs/url_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package main

import (
"fmt"
"testing"
)

func TestBuildScanURL(t *testing.T) {
t.Parallel()
//goland:noinspection HttpUrlsUsage
testCases := []struct {
targetURL string
rescan bool
expectedURL string
}{
{"target.com", false, "https://observatory.com/api/v1/scan?target=target.com"},
{"http://target.com", false, "https://observatory.com/api/v1/scan?target=target.com"},
{"https://target.com", false, "https://observatory.com/api/v1/scan?target=target.com"},
{"https://target.com", true, "https://observatory.com/api/v1/scan?target=target.com&rescan=true"},
{"https://target.com/", true, "https://observatory.com/api/v1/scan?target=target.com&rescan=true"},
}

for _, tc := range testCases {
name := fmt.Sprintf("target=%s rescan=%t", tc.targetURL, tc.rescan)
t.Run(name, func(t *testing.T) {
t.Parallel()
u := buildScanURL("https://observatory.com", tc.targetURL, tc.rescan)
if u != tc.expectedURL {
t.Fatalf("expected '%s' == '%s'", u, tc.expectedURL)
}
})
}
}

func TestMustURL(t *testing.T) {
t.Parallel()
testCases := []string{
"https://observatory.com",
"observatory.com",
"foo",
}
for _, tc := range testCases {
t.Run(tc, func(t *testing.T) {
t.Parallel()
u := mustURL(tc)
if u.String() != tc {
t.Fatalf("expected '%s'", tc)
}
})
}
}