Skip to content

Commit

Permalink
Add must.Do utility function (#6955)
Browse files Browse the repository at this point in the history
This can take two values (typically the return values of a two-value
function) and panic if the error is non-nil, returning the interesting
value. This is particularly useful for cases where we statically know
the call will succeed.

Thanks to @mcpherrinm for the idea!
  • Loading branch information
jsha authored Jun 26, 2023
1 parent 6206992 commit 8dcbc4c
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 35 deletions.
7 changes: 2 additions & 5 deletions ca/ca_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/letsencrypt/boulder/linter"
blog "github.com/letsencrypt/boulder/log"
"github.com/letsencrypt/boulder/metrics"
"github.com/letsencrypt/boulder/must"
"github.com/letsencrypt/boulder/policy"
sapb "github.com/letsencrypt/boulder/sa/proto"
"github.com/letsencrypt/boulder/test"
Expand Down Expand Up @@ -107,11 +108,7 @@ const caCertFile = "../test/test-ca.pem"
const caCertFile2 = "../test/test-ca2.pem"

func mustRead(path string) []byte {
b, err := os.ReadFile(path)
if err != nil {
panic(fmt.Sprintf("unable to read %#v: %s", path, err))
}
return b
return must.Do(os.ReadFile(path))
}

type testCtx struct {
Expand Down
15 changes: 15 additions & 0 deletions must/must.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package must

// Do panics if err is not nil, otherwise returns t.
// It is useful in wrapping a two-value function call
// where you know statically that the call will succeed.
//
// Example:
//
// url := must.Do(url.Parse("http://example.com"))
func Do[T any](t T, err error) T {
if err != nil {
panic(err)
}
return t
}
13 changes: 13 additions & 0 deletions must/must_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package must

import (
"net/url"
"testing"
)

func TestDo(t *testing.T) {
url := Do(url.Parse("http://example.com"))
if url.Host != "example.com" {
t.Errorf("expected host to be example.com, got %s", url.Host)
}
}
11 changes: 3 additions & 8 deletions policy/pa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/letsencrypt/boulder/features"
"github.com/letsencrypt/boulder/identifier"
blog "github.com/letsencrypt/boulder/log"
"github.com/letsencrypt/boulder/must"
"github.com/letsencrypt/boulder/test"
"gopkg.in/yaml.v3"
)
Expand Down Expand Up @@ -394,19 +395,13 @@ func TestChallengesForWildcard(t *testing.T) {
Value: "*.zombo.com",
}

mustConstructPA := func(t *testing.T, enabledChallenges map[core.AcmeChallenge]bool) *AuthorityImpl {
pa, err := New(enabledChallenges, blog.NewMock())
test.AssertNotError(t, err, "Couldn't create policy implementation")
return pa
}

// First try to get a challenge for the wildcard ident without the
// DNS-01 challenge type enabled. This should produce an error
var enabledChallenges = map[core.AcmeChallenge]bool{
core.ChallengeTypeHTTP01: true,
core.ChallengeTypeDNS01: false,
}
pa := mustConstructPA(t, enabledChallenges)
pa := must.Do(New(enabledChallenges, blog.NewMock()))
_, err := pa.ChallengesFor(wildcardIdent)
test.AssertError(t, err, "ChallengesFor did not error for a wildcard ident "+
"when DNS-01 was disabled")
Expand All @@ -416,7 +411,7 @@ func TestChallengesForWildcard(t *testing.T) {
// Try again with DNS-01 enabled. It should not error and
// should return only one DNS-01 type challenge
enabledChallenges[core.ChallengeTypeDNS01] = true
pa = mustConstructPA(t, enabledChallenges)
pa = must.Do(New(enabledChallenges, blog.NewMock()))
challenges, err := pa.ChallengesFor(wildcardIdent)
test.AssertNotError(t, err, "ChallengesFor errored for a wildcard ident "+
"unexpectedly")
Expand Down
30 changes: 13 additions & 17 deletions va/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/letsencrypt/boulder/core"
berrors "github.com/letsencrypt/boulder/errors"
"github.com/letsencrypt/boulder/identifier"
"github.com/letsencrypt/boulder/must"
"github.com/letsencrypt/boulder/probs"
"github.com/letsencrypt/boulder/test"
"github.com/miekg/dns"
Expand Down Expand Up @@ -195,13 +196,8 @@ func TestHTTPValidationTarget(t *testing.T) {
}

func TestExtractRequestTarget(t *testing.T) {
mustURL := func(t *testing.T, rawURL string) *url.URL {
urlOb, err := url.Parse(rawURL)
if err != nil {
t.Fatalf("Unable to parse raw URL %q: %v", rawURL, err)
return nil
}
return urlOb
mustURL := func(rawURL string) *url.URL {
return must.Do(url.Parse(rawURL))
}

testCases := []struct {
Expand All @@ -218,7 +214,7 @@ func TestExtractRequestTarget(t *testing.T) {
{
Name: "invalid protocol scheme",
Req: &http.Request{
URL: mustURL(t, "gopher://letsencrypt.org"),
URL: mustURL("gopher://letsencrypt.org"),
},
ExpectedError: fmt.Errorf("Invalid protocol scheme in redirect target. " +
`Only "http" and "https" protocol schemes are supported, ` +
Expand All @@ -227,68 +223,68 @@ func TestExtractRequestTarget(t *testing.T) {
{
Name: "invalid explicit port",
Req: &http.Request{
URL: mustURL(t, "https://weird.port.letsencrypt.org:9999"),
URL: mustURL("https://weird.port.letsencrypt.org:9999"),
},
ExpectedError: fmt.Errorf("Invalid port in redirect target. Only ports 80 " +
"and 443 are supported, not 9999"),
},
{
Name: "invalid empty hostname",
Req: &http.Request{
URL: mustURL(t, "https:///who/needs/a/hostname?not=me"),
URL: mustURL("https:///who/needs/a/hostname?not=me"),
},
ExpectedError: errors.New("Invalid empty hostname in redirect target"),
},
{
Name: "invalid .well-known hostname",
Req: &http.Request{
URL: mustURL(t, "https://my.webserver.is.misconfigured.well-known/acme-challenge/xxx"),
URL: mustURL("https://my.webserver.is.misconfigured.well-known/acme-challenge/xxx"),
},
ExpectedError: errors.New(`Invalid host in redirect target "my.webserver.is.misconfigured.well-known". Check webserver config for missing '/' in redirect target.`),
},
{
Name: "invalid non-iana hostname",
Req: &http.Request{
URL: mustURL(t, "https://my.tld.is.cpu/pretty/cool/right?yeah=Ithoughtsotoo"),
URL: mustURL("https://my.tld.is.cpu/pretty/cool/right?yeah=Ithoughtsotoo"),
},
ExpectedError: errors.New("Invalid hostname in redirect target, must end in IANA registered TLD"),
},
{
Name: "bare IP",
Req: &http.Request{
URL: mustURL(t, "https://10.10.10.10"),
URL: mustURL("https://10.10.10.10"),
},
ExpectedError: fmt.Errorf(`Invalid host in redirect target "10.10.10.10". ` +
"Only domain names are supported, not IP addresses"),
},
{
Name: "valid HTTP redirect, explicit port",
Req: &http.Request{
URL: mustURL(t, "http://cpu.letsencrypt.org:80"),
URL: mustURL("http://cpu.letsencrypt.org:80"),
},
ExpectedHost: "cpu.letsencrypt.org",
ExpectedPort: 80,
},
{
Name: "valid HTTP redirect, implicit port",
Req: &http.Request{
URL: mustURL(t, "http://cpu.letsencrypt.org"),
URL: mustURL("http://cpu.letsencrypt.org"),
},
ExpectedHost: "cpu.letsencrypt.org",
ExpectedPort: 80,
},
{
Name: "valid HTTPS redirect, explicit port",
Req: &http.Request{
URL: mustURL(t, "https://cpu.letsencrypt.org:443/hello.world"),
URL: mustURL("https://cpu.letsencrypt.org:443/hello.world"),
},
ExpectedHost: "cpu.letsencrypt.org",
ExpectedPort: 443,
},
{
Name: "valid HTTPS redirect, implicit port",
Req: &http.Request{
URL: mustURL(t, "https://cpu.letsencrypt.org/hello.world"),
URL: mustURL("https://cpu.letsencrypt.org/hello.world"),
},
ExpectedHost: "cpu.letsencrypt.org",
ExpectedPort: 443,
Expand Down
7 changes: 2 additions & 5 deletions wfe2/wfe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
blog "github.com/letsencrypt/boulder/log"
"github.com/letsencrypt/boulder/metrics"
"github.com/letsencrypt/boulder/mocks"
"github.com/letsencrypt/boulder/must"
"github.com/letsencrypt/boulder/nonce"
noncepb "github.com/letsencrypt/boulder/nonce/proto"
"github.com/letsencrypt/boulder/probs"
Expand Down Expand Up @@ -400,11 +401,7 @@ func signAndPost(signer requestSigner, path, signedURL, payload string) *http.Re
}

func mustParseURL(s string) *url.URL {
if u, err := url.Parse(s); err != nil {
panic("Cannot parse URL " + s)
} else {
return u
}
return must.Do(url.Parse(s))
}

func sortHeader(s string) string {
Expand Down

0 comments on commit 8dcbc4c

Please sign in to comment.