Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow missing trailing slashes in chainUploadLocation #970

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion signer/contentsignaturepki/contentsignature.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"hash"
"io"
"math/big"
"net/http"
"time"

"github.com/mozilla-services/autograph/database"
Expand Down Expand Up @@ -187,7 +188,7 @@ func (s *ContentSigner) initEE(conf signer.Configuration) error {
default:
return fmt.Errorf("contentsignaturepki %q: failed to find suitable end-entity: %w", s.ID, err)
}
_, _, err = GetX5U(buildHTTPClient(), s.X5U)
_, _, err = GetX5U(http.DefaultClient, s.X5U)
if err != nil {
return fmt.Errorf("contentsignaturepki %q: failed to verify x5u: %w", s.ID, err)
}
Expand Down
126 changes: 65 additions & 61 deletions signer/contentsignaturepki/contentsignature_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ package contentsignaturepki

import (
"crypto/ecdsa"
"fmt"
"net/http"
"strings"
"testing"

Expand All @@ -18,73 +20,75 @@ import (
func TestSign(t *testing.T) {
input := []byte("foobarbaz1234abcd")
for i, testcase := range PASSINGTESTCASES {
// initialize a signer
s, err := New(testcase.cfg)
if err != nil {
t.Fatalf("testcase %d signer initialization failed with: %v", i, err)
}
if s.Type != testcase.cfg.Type {
t.Fatalf("testcase %d signer type %q does not match configuration %q", i, s.Type, testcase.cfg.Type)
}
if s.ID != testcase.cfg.ID {
t.Fatalf("testcase %d signer id %q does not match configuration %q", i, s.ID, testcase.cfg.ID)
}
if s.PrivateKey != testcase.cfg.PrivateKey {
t.Fatalf("testcase %d signer private key %q does not match configuration %q", i, s.PrivateKey, testcase.cfg.PrivateKey)
}
if s.Mode != testcase.cfg.Mode {
t.Fatalf("testcase %d signer curve %q does not match expected %q", i, s.Mode, testcase.cfg.Mode)
}
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
// initialize a signer
s, err := New(testcase.cfg)
if err != nil {
t.Fatalf("testcase %d signer initialization failed with: %v", i, err)
}
if s.Type != testcase.cfg.Type {
t.Fatalf("testcase %d signer type %q does not match configuration %q", i, s.Type, testcase.cfg.Type)
}
if s.ID != testcase.cfg.ID {
t.Fatalf("testcase %d signer id %q does not match configuration %q", i, s.ID, testcase.cfg.ID)
}
if s.PrivateKey != testcase.cfg.PrivateKey {
t.Fatalf("testcase %d signer private key %q does not match configuration %q", i, s.PrivateKey, testcase.cfg.PrivateKey)
}
if s.Mode != testcase.cfg.Mode {
t.Fatalf("testcase %d signer curve %q does not match expected %q", i, s.Mode, testcase.cfg.Mode)
}

// sign input data
sig, err := s.SignData(input, nil)
if err != nil {
t.Fatalf("testcase %d failed to sign data: %v", i, err)
}
// convert signature to string format
sigstr, err := sig.Marshal()
if err != nil {
t.Fatalf("testcase %d failed to marshal signature: %v", i, err)
}
// sign input data
sig, err := s.SignData(input, nil)
if err != nil {
t.Fatalf("testcase %d failed to sign data: %v", i, err)
}
// convert signature to string format
sigstr, err := sig.Marshal()
if err != nil {
t.Fatalf("testcase %d failed to marshal signature: %v", i, err)
}

// convert string format back to signature
cs, err := verifier.Unmarshal(sigstr)
if err != nil {
t.Fatalf("testcase %d failed to unmarshal signature: %v", i, err)
}
// convert string format back to signature
cs, err := verifier.Unmarshal(sigstr)
if err != nil {
t.Fatalf("testcase %d failed to unmarshal signature: %v", i, err)
}

// make sure we still have the same string representation
sigstr2, err := cs.Marshal()
if err != nil {
t.Fatalf("testcase %d failed to re-marshal signature: %v", i, err)
}
if sigstr != sigstr2 {
t.Fatalf("testcase %d marshalling signature changed its format.\nexpected\t%q\nreceived\t%q",
i, sigstr, sigstr2)
}
// make sure we still have the same string representation
sigstr2, err := cs.Marshal()
if err != nil {
t.Fatalf("testcase %d failed to re-marshal signature: %v", i, err)
}
if sigstr != sigstr2 {
t.Fatalf("testcase %d marshalling signature changed its format.\nexpected\t%q\nreceived\t%q",
i, sigstr, sigstr2)
}

if cs.Len != getSignatureLen(s.Mode) {
t.Fatalf("testcase %d expected signature len of %d, got %d",
i, getSignatureLen(s.Mode), cs.Len)
}
if cs.Mode != s.Mode {
t.Fatalf("testcase %d expected curve name %q, got %q", i, s.Mode, cs.Mode)
}
if cs.Len != getSignatureLen(s.Mode) {
t.Fatalf("testcase %d expected signature len of %d, got %d",
i, getSignatureLen(s.Mode), cs.Len)
}
if cs.Mode != s.Mode {
t.Fatalf("testcase %d expected curve name %q, got %q", i, s.Mode, cs.Mode)
}

// verify the signature using the public key of the end entity
_, certs, err := GetX5U(buildHTTPClient(), s.X5U)
if err != nil {
t.Fatalf("testcase %d failed to get X5U %q: %v", i, s.X5U, err)
}
leaf := certs[0]
key := leaf.PublicKey.(*ecdsa.PublicKey)
if !sig.(*verifier.ContentSignature).VerifyData([]byte(input), key) {
t.Fatalf("testcase %d failed to verify signature", i)
}
// verify the signature using the public key of the end entity
_, certs, err := GetX5U(http.DefaultClient, s.X5U)
if err != nil {
t.Fatalf("testcase %d failed to get X5U %q: %v", i, s.X5U, err)
}
leaf := certs[0]
key := leaf.PublicKey.(*ecdsa.PublicKey)
if !sig.(*verifier.ContentSignature).VerifyData([]byte(input), key) {
t.Fatalf("testcase %d failed to verify signature", i)
}

if leaf.Subject.CommonName != testcase.expectedCommonName {
t.Errorf("testcase %d expected common name %#v, got %#v", i, testcase.expectedCommonName, leaf.Subject.CommonName)
}
if leaf.Subject.CommonName != testcase.expectedCommonName {
t.Errorf("testcase %d expected common name %#v, got %#v", i, testcase.expectedCommonName, leaf.Subject.CommonName)
}
})
}
}

Expand Down
72 changes: 39 additions & 33 deletions signer/contentsignaturepki/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"time"

Expand Down Expand Up @@ -53,9 +55,12 @@ func (s *ContentSigner) upload(data, name string) error {
}

func uploadToS3(client S3UploadAPI, data, name string, target *url.URL) error {
// aws-sdk-go-v2 now includes leading slashes in the key name, where v1 did
// not. So, to keep this code compatible, we have to trim it.
keyName := strings.TrimPrefix(path.Join(target.Path, name), "/")
_, err := client.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String(target.Host),
Key: aws.String(target.Path + name),
Key: aws.String(keyName),
ACL: types.ObjectCannedACLPublicRead,
Body: strings.NewReader(data),
ContentType: aws.String("binary/octet-stream"),
Expand All @@ -78,56 +83,57 @@ func writeLocalFile(data, name string, target *url.URL) error {
return err
}
}
// write the file into the target dir
return os.WriteFile(target.Path+name, []byte(data), 0755)
}

// buildHTTPClient returns the default HTTP.Client for fetching X5Us
func buildHTTPClient() *http.Client {
return &http.Client{}
return os.WriteFile(filepath.Join(target.Path, name), []byte(data), 0755)
}

// GetX5U retrieves a chain file of certs from upload location, parses
// and verifies it, then returns a byte slice of the response body and
// a slice of parsed certificates.
func GetX5U(client *http.Client, x5u string) (body []byte, certs []*x509.Certificate, err error) {
func GetX5U(client *http.Client, x5u string) ([]byte, []*x509.Certificate, error) {
parsedURL, err := url.Parse(x5u)
if err != nil {
err = fmt.Errorf("failed to parse chain upload location: %w", err)
return
}
if parsedURL.Scheme == "file" {
t := &http.Transport{}
t.RegisterProtocol("file", http.NewFileTransport(http.Dir("/")))
client.Transport = t
}
resp, err := client.Get(x5u)
if err != nil {
err = fmt.Errorf("failed to retrieve x5u: %w", err)
return
return nil, nil, fmt.Errorf("failed to parse chain upload location: %w", err)

}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("failed to retrieve x5u from %s: %s", x5u, resp.Status)
return
var bodyReader io.ReadCloser
switch parsedURL.Scheme {
case "https":
resp, err := client.Get(x5u)
if err != nil {
return nil, nil, fmt.Errorf("failed to retrieve x5u from %#v: %w", x5u, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, nil, fmt.Errorf("failed to retrieve x5u from %#v: %s", x5u, resp.Status)
}
bodyReader = resp.Body

case "file":
bodyReader, err = os.Open(parsedURL.Path)
if err != nil {
return nil, nil, fmt.Errorf("failed to open x5u file:// at %#v: %w", x5u, err)
}
defer bodyReader.Close()
default:
return nil, nil, fmt.Errorf("unsupported x5u scheme: %#v", parsedURL.Scheme)
}
body, err = io.ReadAll(resp.Body)

body, err := io.ReadAll(bodyReader)
if err != nil {
err = fmt.Errorf("failed to parse x5u body: %w", err)
return
return nil, nil, fmt.Errorf("failed to parse x5u body from %#v: %w", x5u, err)
}
certs, err = csigverifier.ParseChain(body)
certs, err := csigverifier.ParseChain(body)
if err != nil {
err = fmt.Errorf("failed to parse x5u : %w", err)
return

return nil, nil, fmt.Errorf("failed to parse x5u : %w", err)
}
rootHash := sha2Fingerprint(certs[2])
err = csigverifier.VerifyChain(rootHash, certs, time.Now())
if err != nil {
err = fmt.Errorf("failed to verify certificate chain: %w", err)
return
return nil, nil, fmt.Errorf("failed to verify certificate chain: %w", err)
}
return
return body, certs, nil
}

func sha2Fingerprint(cert *x509.Certificate) string {
Expand Down
70 changes: 53 additions & 17 deletions signer/contentsignaturepki/upload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,42 +19,78 @@ func (m mockUploadAPI) Upload(ctx context.Context, input *s3.PutObjectInput, opt

func TestUploadToS3(t *testing.T) {
cases := []struct {
client func(t *testing.T) S3UploadAPI
data string
name string
target string
expectErr bool
testName string
client func(t *testing.T) S3UploadAPI
data string
name string
chainUploadLocation string
expectErr bool
}{
{
testName: "successful_upload",
client: func(t *testing.T) S3UploadAPI {
return mockUploadAPI(func(ctx context.Context, input *s3.PutObjectInput, opts ...func(*manager.Uploader)) (*manager.UploadOutput, error) {
t.Helper()
expectedBucket := "foo.bar"
if *input.Bucket != expectedBucket {
t.Errorf("bucket: want %#v, got %#v", expectedBucket, *input.Bucket)
}
if *input.Key != "somestuff/successful_chain" {
t.Errorf("key: want \"somestuff/successful_chain\", got %#v", *input.Key)
}
return &manager.UploadOutput{}, nil
})
},
data: "foo",
name: "successful_upload",
target: "https://foo.bar",
expectErr: false,
data: "foo",
name: "successful_chain",
chainUploadLocation: "s3://foo.bar/somestuff/",
expectErr: false,
},
{
testName: "successful_upload_with_missing_slash",
client: func(t *testing.T) S3UploadAPI {
return mockUploadAPI(func(ctx context.Context, input *s3.PutObjectInput, opts ...func(*manager.Uploader)) (*manager.UploadOutput, error) {
t.Helper()
expectedBucket := "foo.bar"
if *input.Bucket != expectedBucket {
t.Errorf("bucket: want %#v, got %#v", expectedBucket, *input.Bucket)
}
expectedKey := "somestuff/successful_chain"
if *input.Key != expectedKey {
t.Errorf("key: want %#v, got %#v", expectedKey, *input.Key)
}
return &manager.UploadOutput{}, nil
})
},
data: "foo",
name: "successful_chain",
chainUploadLocation: "s3://foo.bar/somestuff",
expectErr: false,
},
{
testName: "failed_upload",
client: func(t *testing.T) S3UploadAPI {
return mockUploadAPI(func(ctx context.Context, input *s3.PutObjectInput, opts ...func(*manager.Uploader)) (*manager.UploadOutput, error) {
expectedBucket := "foo.quux"
if *input.Bucket != expectedBucket {
t.Errorf("bucket: want %#v, got %#v", expectedBucket, *input.Bucket)
}
expectedKey := "something/will_fail_chain"
if *input.Key != expectedKey {
t.Errorf("key: want %#v, got %#v", expectedKey, *input.Key)
}
return nil, errors.New("upload failed")
})
},
data: "foo",
name: "failed_upload",
target: "https://foo.bar",
expectErr: true,
data: "foo",
name: "will_fail_chain",
chainUploadLocation: "s3://foo.quux/something/",
expectErr: true,
},
}

for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
t.Run(tt.testName, func(t *testing.T) {
t.Parallel()
url, err := url.Parse(tt.target)
url, err := url.Parse(tt.chainUploadLocation)
if err != nil {
t.Fatalf("error parsing test url: %v", err)
}
Expand Down
Loading
Loading