Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding support to fetch WATER WASM files #628

Merged
merged 16 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions http-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,11 @@ var (

algenevaAddr = flag.String("algeneva-addr", "", "Address at which to listen for algenAddr connections.")

waterAddr = flag.String("water-addr", "", "Address at which to listen for WATER connections.")
waterWASM = flag.String("water-wasm", "", "Base64 encoded WASM for WATER")
waterTransport = flag.String("water-transport", "", "WATER based transport name")
waterAddr = flag.String("water-addr", "", "Address at which to listen for WATER connections.")
waterWASM = flag.String("water-wasm", "", "Base64 encoded WASM for WATER")
waterWASMAvailableAt = flag.String("water-wasm-available-at", "", "URLs where the WATER WASM is available")
waterWASMHashsum = flag.String("water-wasm-hashsum", "", "Hashsum of the WATER WASM")
waterTransport = flag.String("water-transport", "", "WATER based transport name")

track = flag.String("track", "", "The track this proxy is running on")
)
Expand Down Expand Up @@ -473,6 +475,8 @@ func main() {
AlgenevaAddr: *algenevaAddr,
WaterAddr: *waterAddr,
WaterWASM: *waterWASM,
WaterWASMAvailableAt: *waterWASMAvailableAt,
WaterWASMHashsum: *waterWASMHashsum,
WaterTransport: *waterTransport,
}
if *maxmindLicenseKey != "" {
Expand Down
37 changes: 33 additions & 4 deletions http_proxy.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package proxy

import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"fmt"
"net"
Expand Down Expand Up @@ -185,9 +187,12 @@ type Proxy struct {

AlgenevaAddr string

WaterAddr string
WaterWASM string
WaterTransport string
// deprecated: use WaterWASMAvailableAt
WaterWASM string
WaterWASMAvailableAt string
WaterWASMHashsum string
WendelHime marked this conversation as resolved.
Show resolved Hide resolved
WaterTransport string
WaterAddr string

throttleConfig throttle.Config
instrument instrument.Instrument
Expand Down Expand Up @@ -988,7 +993,31 @@ func (p *Proxy) listenAlgeneva(baseListen func(string) (net.Listener, error)) li
// Currently water doesn't support customized TCP connections and we need to listen and receive requests directly from the WATER listener
func (p *Proxy) listenWATER(addr string) (net.Listener, error) {
ctx := context.Background()
waterListener, err := water.NewWATERListener(ctx, p.WaterTransport, addr, p.WaterWASM)
var wasm []byte
if p.WaterWASM != "" {
var err error
wasm, err = base64.StdEncoding.DecodeString(p.WaterWASM)
if err != nil {
log.Errorf("failed to decode WASM base64: %v", err)
return nil, err
}
}

if p.WaterWASMAvailableAt != "" {
wasmBuffer := new(bytes.Buffer)
err := water.NewWASMDownloader(
water.WithURLs(strings.Split(p.WaterWASMAvailableAt, ",")),
water.WithExpectedHashsum(p.WaterWASMHashsum),
water.WithHTTPClient(&http.Client{Timeout: 1 * time.Minute}),
).DownloadWASM(ctx, wasmBuffer)
if err != nil {
return nil, fmt.Errorf("unable to download water wasm: %v", err)
WendelHime marked this conversation as resolved.
Show resolved Hide resolved
}
wasm = wasmBuffer.Bytes()
}

// currently the WATER listener doesn't accept a multiplexed connections, so we need to listen and accept connections directly from the listener
waterListener, err := water.NewWATERListener(ctx, nil, p.WaterTransport, addr, wasm)
if err != nil {
return nil, err
}
Expand Down
118 changes: 118 additions & 0 deletions water/downloader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package water

import (
"bytes"
"context"
"crypto/sha256"
"errors"
"fmt"
"io"
"net/http"
"strings"
)

//go:generate mockgen -destination=mock_downloader.go -package=water . WASMDownloader

type WASMDownloader interface {
DownloadWASM(context.Context, io.Writer) error
}

type downloader struct {
urls []string
httpClient *http.Client
expectedHashSum string
httpDownloader WASMDownloader
}

type DownloaderOption func(*downloader)

func WithURLs(urls []string) DownloaderOption {
return func(d *downloader) {
d.urls = urls
}
}

func WithHTTPClient(httpClient *http.Client) DownloaderOption {
return func(d *downloader) {
d.httpClient = httpClient
}
}

func WithExpectedHashsum(hashsum string) DownloaderOption {
WendelHime marked this conversation as resolved.
Show resolved Hide resolved
return func(d *downloader) {
d.expectedHashSum = hashsum
}
}

func WithHTTPDownloader(httpDownloader WASMDownloader) DownloaderOption {
return func(d *downloader) {
d.httpDownloader = httpDownloader
}
}

// NewWASMDownloader creates a new WASMDownloader instance.
func NewWASMDownloader(withOpts ...DownloaderOption) WASMDownloader {
downloader := new(downloader)
for _, opt := range withOpts {
opt(downloader)
}
return downloader
}
WendelHime marked this conversation as resolved.
Show resolved Hide resolved

// DownloadWASM downloads the WASM file from the given URLs, verifies the hash
// sum and writes the file to the given writer.
func (d *downloader) DownloadWASM(ctx context.Context, w io.Writer) error {
joinedErrs := errors.New("failed to download WASM from all URLs")
for _, url := range d.urls {
WendelHime marked this conversation as resolved.
Show resolved Hide resolved
if strings.HasPrefix(url, "magnet:?") {
// Skip magnet links for now
joinedErrs = errors.Join(joinedErrs, errors.New("magnet links are not supported"))
continue
}
tempBuffer := &bytes.Buffer{}
err := d.downloadWASM(ctx, tempBuffer, url)
if err != nil {
joinedErrs = errors.Join(joinedErrs, err)
continue
}

err = d.verifyHashSum(tempBuffer.Bytes())
if err != nil {
joinedErrs = errors.Join(joinedErrs, err)
continue
}

_, err = tempBuffer.WriteTo(w)
if err != nil {
joinedErrs = errors.Join(joinedErrs, err)
continue
}

return nil
}
return joinedErrs
}

// downloadWASM checks what kind of URL was given and downloads the WASM file
// from the URL. It can be a HTTPS URL or a magnet link.
func (d *downloader) downloadWASM(ctx context.Context, w io.Writer, url string) error {
switch {
case strings.HasPrefix(url, "http://"), strings.HasPrefix(url, "https://"):
if d.httpDownloader == nil {
d.httpDownloader = NewHTTPSDownloader(d.httpClient, url)
}
return d.httpDownloader.DownloadWASM(ctx, w)
default:
return fmt.Errorf("unsupported protocol: %s", url)
}
}

var ErrFailedToVerifyHashSum = errors.New("failed to verify hash sum")

func (d *downloader) verifyHashSum(data []byte) error {
sha256Hashsum := sha256.Sum256(data)
if d.expectedHashSum == "" || d.expectedHashSum != fmt.Sprintf("%x", sha256Hashsum[:]) {
return ErrFailedToVerifyHashSum
}
return nil
}
Loading
Loading