// main_test.go
package main
import (
"crypto"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/PuerkitoBio/goquery"
"github.com/audunmo/sri-fixer/pkg/extractor"
"github.com/audunmo/sri-fixer/pkg/hash"
"github.com/audunmo/sri-fixer/pkg/injector"
scriptfetcher "github.com/audunmo/sri-fixer/pkg/script_fetcher"
)
// Request paths handled by the test server built in createTestServer; each
// path is answered with its own one-line script body. The test appends them to
// the server's base URL to form the script URLs placed in the markup.
var (
path1 = "/scripts/1.js"
path2 = "/scripts/2.js"
path3 = "/scripts/3.js"
path4 = "/scripts/4.js"
path5 = "/scripts/5.js"
)
// createTestServer starts a local HTTP server that serves a small, distinct
// JavaScript body for each of path1..path5 and answers 404 for every other
// path, so a request for an unexpected URL fails loudly instead of yielding an
// empty 200 response that would be hashed like a real script.
//
// The caller owns the returned server and is responsible for calling Close.
func createTestServer() *httptest.Server {
	// Lookup table from request path to the script body served for it.
	scripts := map[string]string{
		path1: "console.log(1)",
		path2: "console.log(2)",
		path3: "console.log(3)",
		path4: "console.log(4)",
		path5: "console.log(5)",
	}

	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, ok := scripts[r.URL.Path]
		if !ok {
			http.NotFound(w, r)
			return
		}

		// The bodies are JavaScript, so label them as such. Header.Set
		// canonicalizes the key itself.
		w.Header().Set("Content-Type", "text/javascript")

		// A failed write only means the client went away; there is nothing
		// useful a test handler can do about it.
		_, _ = w.Write([]byte(body))
	}))
}
// TestExtractHashAndInject runs the primary end-to-end test. It performs the
// following flow:
//  1. Read all script srcs from the markup
//  2. Download all the remote scripts
//  3. Hash all the scripts
//  4. Inject integrity hashes into the markup
//
// Afterwards the resulting markup is parsed again and every URL that was
// processed must be found on an element whose integrity attribute equals the
// value computed for it. The test also requires that all five script URLs in
// the markup were actually extracted, so it cannot pass on an empty result.
func TestExtractHashAndInject(t *testing.T) {
	ts := createTestServer()
	// Release the listener and its goroutines when the test finishes.
	defer ts.Close()

	url1 := ts.URL + path1
	url2 := ts.URL + path2
	url3 := ts.URL + path3
	url4 := ts.URL + path4
	url5 := ts.URL + path5
	wantScripts := []string{url1, url2, url3, url4, url5}

	markup := fmt.Sprintf(`
<html>
<head>
<script src="%v"></script>
<script src="%v"></script>
<script src="%v"></script>
</head>
<body>
<script src="%v"></script>
<script src="%v"></script>
</body>
</html>
`, url1, url2, url3, url4, url5)

	urls, err := extractor.ExtractURLS(strings.NewReader(markup))
	if err != nil {
		t.Fatalf("extracting URLs: %v", err)
	}

	f := scriptfetcher.New([]string{})
	html := markup
	integrities := map[string]string{}

	// hashAndInject downloads u, builds its "sha256 sha384 sha512" integrity
	// value, records it in integrities and injects it into html for the given
	// element kind. Scripts and links share this exact sequence.
	hashAndInject := func(u, tag string) {
		script, err := f.Fetch(u)
		if err != nil {
			t.Fatalf("fetching %v: %v", u, err)
		}

		h := hash.Hash([]byte(script), []crypto.Hash{crypto.SHA256, crypto.SHA384, crypto.SHA512})
		integrity := fmt.Sprintf("%v %v %v", h[crypto.SHA256], h[crypto.SHA384], h[crypto.SHA512])
		integrities[u] = integrity

		html, err = injector.Inject(html, u, integrity, tag)
		if err != nil {
			t.Fatalf("injecting integrity for %v: %v", u, err)
		}
	}

	for _, u := range urls.Scripts {
		hashAndInject(u, "script")
	}
	for _, u := range urls.Links {
		hashAndInject(u, "link")
	}

	// Guard against a vacuous pass: every script placed in the markup must
	// have been extracted and processed.
	for _, u := range wantScripts {
		if _, ok := integrities[u]; !ok {
			t.Fatalf("script %v was not extracted from the markup", u)
		}
	}

	newDoc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
	if err != nil {
		t.Fatalf("parsing injected markup: %v", err)
	}

	foundAndVerified := make(map[string]bool, len(integrities))
	for u := range integrities {
		foundAndVerified[u] = false
	}

	// verify marks every processed URL that appears in attr on an element
	// matching selector with the expected integrity value. Elements without
	// the attribute, or pointing at URLs that were never processed, are
	// ignored.
	verify := func(selector, attr string) {
		newDoc.Find(selector).Each(func(_ int, s *goquery.Selection) {
			ref, ok := s.Attr(attr)
			if !ok {
				return
			}
			expectedHash, tracked := integrities[ref]
			if !tracked {
				return
			}
			if actualHash, _ := s.Attr("integrity"); actualHash == expectedHash {
				foundAndVerified[ref] = true
			}
		})
	}
	verify("script", "src")
	// NOTE(review): assumes link URLs are taken from the href attribute;
	// confirm against the extractor. The markup above contains no links.
	verify("link", "href")

	// Report every unverified URL rather than stopping at the first one.
	for u, verified := range foundAndVerified {
		if !verified {
			t.Errorf("was unable to find or verify the integrity hash for script %v", u)
		}
	}
}