diff --git a/http-proxy/main.go b/http-proxy/main.go index 22e4dcef..4697deb1 100644 --- a/http-proxy/main.go +++ b/http-proxy/main.go @@ -186,9 +186,10 @@ var ( algenevaAddr = flag.String("algeneva-addr", "", "Address at which to listen for algenAddr connections.") - waterAddr = flag.String("water-addr", "", "Address at which to listen for WATER connections.") - waterWASM = flag.String("water-wasm", "", "Base64 encoded WASM for WATER") - waterTransport = flag.String("water-transport", "", "WATER based transport name") + waterAddr = flag.String("water-addr", "", "Address at which to listen for WATER connections.") + waterWASM = flag.String("water-wasm", "", "Base64 encoded WASM for WATER") + waterWASMAvailableAt = flag.String("water-wasm-available-at", "", "URLs where the WATER WASM is available") + waterTransport = flag.String("water-transport", "", "WATER based transport name") track = flag.String("track", "", "The track this proxy is running on") ) @@ -473,6 +474,7 @@ func main() { AlgenevaAddr: *algenevaAddr, WaterAddr: *waterAddr, WaterWASM: *waterWASM, + WaterWASMAvailableAt: *waterWASMAvailableAt, WaterTransport: *waterTransport, } if *maxmindLicenseKey != "" { diff --git a/http_proxy.go b/http_proxy.go index c6ce452f..47aa49a3 100644 --- a/http_proxy.go +++ b/http_proxy.go @@ -1,8 +1,10 @@ package proxy import ( + "bytes" "context" "crypto/tls" + "encoding/base64" "encoding/json" "fmt" "net" @@ -185,9 +187,11 @@ type Proxy struct { AlgenevaAddr string - WaterAddr string - WaterWASM string - WaterTransport string + // deprecated: use WaterWASMAvailableAt + WaterWASM string + WaterWASMAvailableAt string + WaterTransport string + WaterAddr string throttleConfig throttle.Config instrument instrument.Instrument @@ -988,8 +992,34 @@ func (p *Proxy) listenAlgeneva(baseListen func(string) (net.Listener, error)) li // Currently water doesn't support customized TCP connections and we need to listen and receive requests directly from the WATER listener func (p *Proxy) listenWATER(addr string) (net.Listener, error) { ctx := context.Background() - waterListener, err := water.NewWATERListener(ctx, p.WaterTransport, addr, p.WaterWASM) + var wasm []byte + if p.WaterWASM != "" { + var err error + wasm, err = base64.StdEncoding.DecodeString(p.WaterWASM) + if err != nil { + log.Errorf("failed to decode WASM base64: %v", err) + return nil, err + } + } + + if p.WaterWASMAvailableAt != "" { + wasmBuffer := new(bytes.Buffer) + d, err := water.NewWASMDownloader(strings.Split(p.WaterWASMAvailableAt, ","), &http.Client{Timeout: 1 * time.Minute}) + if err != nil { + return nil, log.Errorf("failed to create wasm downloader: %w", err) + } + + err = d.DownloadWASM(ctx, wasmBuffer) + if err != nil { + return nil, log.Errorf("unable to download water wasm: %w", err) + } + wasm = wasmBuffer.Bytes() + } + + // currently the WATER listener doesn't accept a multiplexed connections, so we need to listen and accept connections directly from the listener + waterListener, err := water.NewWATERListener(ctx, nil, p.WaterTransport, addr, wasm) if err != nil { + log.Errorf("failed to starte WATER listener: %w", err) return nil, err } diff --git a/water/downloader.go b/water/downloader.go new file mode 100644 index 00000000..151ed7eb --- /dev/null +++ b/water/downloader.go @@ -0,0 +1,75 @@ +package water + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "strings" +) + +//go:generate mockgen -package=water -destination=mocks_test.go . WASMDownloader + +type WASMDownloader interface { + DownloadWASM(context.Context, io.Writer) error +} + +type downloader struct { + urls []string + httpClient *http.Client + httpDownloader WASMDownloader +} + +// NewWASMDownloader creates a new WASMDownloader instance. +func NewWASMDownloader(urls []string, client *http.Client) (WASMDownloader, error) { + if len(urls) == 0 { + return nil, log.Error("WASM downloader requires URLs to download but received empty list") + } + return &downloader{ + urls: urls, + httpClient: client, + }, nil +} + +// DownloadWASM downloads the WASM file from the given URLs, verifies the hash +// sum and writes the file to the given writer. +func (d *downloader) DownloadWASM(ctx context.Context, w io.Writer) error { + joinedErrs := errors.New("failed to download WASM from all URLs") + for _, url := range d.urls { + if strings.HasPrefix(url, "magnet:?") { + // Skip magnet links for now + joinedErrs = errors.Join(joinedErrs, errors.New("magnet links are not supported")) + continue + } + tempBuffer := &bytes.Buffer{} + err := d.downloadWASM(ctx, tempBuffer, url) + if err != nil { + joinedErrs = errors.Join(joinedErrs, err) + continue + } + + _, err = tempBuffer.WriteTo(w) + if err != nil { + joinedErrs = errors.Join(joinedErrs, err) + continue + } + + return nil + } + return joinedErrs +} + +// downloadWASM checks what kind of URL was given and downloads the WASM file +// from the URL. It can be a HTTPS URL or a magnet link. +func (d *downloader) downloadWASM(ctx context.Context, w io.Writer, url string) error { + switch { + case strings.HasPrefix(url, "http://"), strings.HasPrefix(url, "https://"): + if d.httpDownloader == nil { + d.httpDownloader = NewHTTPSDownloader(d.httpClient, url) + } + return d.httpDownloader.DownloadWASM(ctx, w) + default: + return log.Errorf("unsupported protocol: %s", url) + } +} diff --git a/water/downloader_test.go b/water/downloader_test.go new file mode 100644 index 00000000..94e1519a --- /dev/null +++ b/water/downloader_test.go @@ -0,0 +1,143 @@ +package water + +import ( + "bytes" + "context" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + gomock "go.uber.org/mock/gomock" +) + +func TestNewWASMDownloader(t *testing.T) { + var tests = []struct { + name string + givenURLs []string + givenHTTPClient *http.Client + assert func(*testing.T, WASMDownloader, error) + }{ + { + name: "it should return an error when providing an empty list of URLs", + assert: func(t *testing.T, d WASMDownloader, err error) { + assert.Error(t, err) + assert.Nil(t, d) + }, + }, + { + name: "it should successfully return a wasm downloader", + givenURLs: []string{"http://example.com"}, + givenHTTPClient: http.DefaultClient, + assert: func(t *testing.T, wDownloader WASMDownloader, err error) { + assert.NoError(t, err) + d := wDownloader.(*downloader) + assert.Equal(t, []string{"http://example.com"}, d.urls) + assert.Equal(t, http.DefaultClient, d.httpClient) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d, err := NewWASMDownloader(tt.givenURLs, tt.givenHTTPClient) + tt.assert(t, d, err) + }) + } +} + +func TestDownloadWASM(t *testing.T) { + ctx := context.Background() + + contentMessage := "content" + var tests = []struct { + name string + givenHTTPClient *http.Client + givenURLs []string + givenWriter io.Writer + setupHTTPDownloader func(ctrl *gomock.Controller) WASMDownloader + assert func(*testing.T, io.Reader, error) + }{ + { + name: "it should return an error telling magnet links are not supported", + givenURLs: []string{"magnet:?"}, + assert: func(t *testing.T, r io.Reader, err error) { + b, berr := io.ReadAll(r) + require.NoError(t, berr) + assert.Empty(t, b) + assert.Error(t, err) + assert.ErrorContains(t, err, "magnet links are not supported") + }, + }, + { + name: "it should return an unupported protocol error when we provide an URL with not implemented downloader", + givenURLs: []string{ + "udp://example.com", + }, + assert: func(t *testing.T, r io.Reader, err error) { + b, berr := io.ReadAll(r) + require.NoError(t, berr) + assert.Empty(t, b) + assert.Error(t, err) + assert.ErrorContains(t, err, "unsupported protocol") + }, + }, + { + name: "it should return an error with the HTTP error", + givenURLs: []string{ + "http://example.com", + }, + setupHTTPDownloader: func(ctrl *gomock.Controller) WASMDownloader { + httpDownloader := NewMockWASMDownloader(ctrl) + httpDownloader.EXPECT().DownloadWASM(ctx, gomock.Any()).Return(assert.AnError) + return httpDownloader + }, + assert: func(t *testing.T, r io.Reader, err error) { + b, berr := io.ReadAll(r) + require.NoError(t, berr) + assert.Empty(t, b) + assert.Error(t, err) + assert.ErrorContains(t, err, assert.AnError.Error()) + assert.ErrorContains(t, err, "failed to download WASM from all URLs") + }, + }, + { + name: "it should return an io.Reader with the expected content", + givenURLs: []string{ + "http://example.com", + }, + setupHTTPDownloader: func(ctrl *gomock.Controller) WASMDownloader { + httpDownloader := NewMockWASMDownloader(ctrl) + httpDownloader.EXPECT().DownloadWASM(ctx, gomock.Any()).DoAndReturn( + func(ctx context.Context, w io.Writer) error { + _, err := w.Write([]byte(contentMessage)) + return err + }) + return httpDownloader + }, + assert: func(t *testing.T, r io.Reader, err error) { + b, berr := io.ReadAll(r) + require.NoError(t, berr) + assert.NoError(t, err) + assert.Equal(t, contentMessage, string(b)) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var httpDownloader WASMDownloader + if tt.setupHTTPDownloader != nil { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + httpDownloader = tt.setupHTTPDownloader(ctrl) + } + + b := &bytes.Buffer{} + wDownloader, err := NewWASMDownloader(tt.givenURLs, tt.givenHTTPClient) + require.NoError(t, err) + wDownloader.(*downloader).httpDownloader = httpDownloader + err = wDownloader.DownloadWASM(ctx, b) + tt.assert(t, b, err) + }) + } +} diff --git a/water/https_downloader.go b/water/https_downloader.go new file mode 100644 index 00000000..63e8fdf4 --- /dev/null +++ b/water/https_downloader.go @@ -0,0 +1,44 @@ +package water + +import ( + "context" + "fmt" + "io" + "net/http" +) + +type httpsDownloader struct { + cli *http.Client + url string +} + +func NewHTTPSDownloader(client *http.Client, url string) WASMDownloader { + return &httpsDownloader{cli: client, url: url} +} + +func (d *httpsDownloader) DownloadWASM(ctx context.Context, w io.Writer) error { + if d.cli == nil { + d.cli = http.DefaultClient + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, http.NoBody) + if err != nil { + return fmt.Errorf("failed to create a new HTTP request: %w", err) + } + resp, err := d.cli.Do(req) + if err != nil { + return fmt.Errorf("failed to send a HTTP request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to download WASM file: %s", resp.Status) + } + + _, err = io.Copy(w, resp.Body) + if err != nil { + return fmt.Errorf("failed to write the WASM file: %w", err) + } + + return nil +} diff --git a/water/https_downloader_test.go b/water/https_downloader_test.go new file mode 100644 index 00000000..68bfb631 --- /dev/null +++ b/water/https_downloader_test.go @@ -0,0 +1,90 @@ +package water + +import ( + "bytes" + "context" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type roundTripFunc struct { + f func(req *http.Request) (*http.Response, error) +} + +func (f *roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f.f(req) +} + +func TestHTTPSDownloadWASM(t *testing.T) { + ctx := context.Background() + var tests = []struct { + name string + givenHTTPClient *http.Client + givenURL string + assert func(*testing.T, io.Reader, error) + }{ + { + name: "sending request successfully", + givenHTTPClient: &http.Client{ + Transport: &roundTripFunc{ + f: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString("wasm")), + }, nil + }, + }, + }, + givenURL: "https://example.com/wasm.wasm", + assert: func(t *testing.T, r io.Reader, err error) { + assert.NoError(t, err) + b, err := io.ReadAll(r) + require.NoError(t, err) + assert.Equal(t, "wasm", string(b)) + }, + }, + { + name: "when receiving an error from the HTTP client, it should return an error", + givenHTTPClient: &http.Client{ + Transport: &roundTripFunc{ + f: func(req *http.Request) (*http.Response, error) { + return nil, assert.AnError + }, + }, + }, + givenURL: "https://example.com/wasm.wasm", + assert: func(t *testing.T, r io.Reader, err error) { + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to send a HTTP request") + }, + }, + { + name: "when the HTTP status code is not 200, it should return an error", + givenHTTPClient: &http.Client{ + Transport: &roundTripFunc{ + f: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusNotFound, + }, nil + }, + }, + }, + givenURL: "https://example.com/wasm.wasm", + assert: func(t *testing.T, r io.Reader, err error) { + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to download WASM file") + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := new(bytes.Buffer) + err := NewHTTPSDownloader(tt.givenHTTPClient, tt.givenURL).DownloadWASM(ctx, b) + tt.assert(t, b, err) + }) + } +} diff --git a/water/listener.go b/water/listener.go index 3d6b86f7..71aa04ab 100644 --- a/water/listener.go +++ b/water/listener.go @@ -2,7 +2,6 @@ package water import ( "context" - "encoding/base64" "log/slog" "net" @@ -15,17 +14,14 @@ var log = golog.LoggerFor("water") // NewWATERListener creates a WATER listener // Currently water doesn't support customized TCP connections and we need to listen and receive requests directly from the WATER listener -func NewWATERListener(ctx context.Context, transport, address, wasm string) (net.Listener, error) { - decodedWASM, err := base64.StdEncoding.DecodeString(wasm) - if err != nil { - log.Errorf("failed to decode WASM base64: %v", err) - return nil, err +func NewWATERListener(ctx context.Context, baseListener net.Listener, transport, address string, wasm []byte) (net.Listener, error) { + cfg := &water.Config{ + TransportModuleBin: wasm, + OverrideLogger: slog.New(newLogHandler(log, transport)), } - cfg := &water.Config{ - TransportModuleBin: decodedWASM, - //NetworkListener: baseListener, - OverrideLogger: slog.New(newLogHandler(log, transport)), + if baseListener != nil { + cfg.NetworkListener = baseListener } waterListener, err := cfg.ListenContext(ctx, "tcp", address) diff --git a/water/listener_test.go b/water/listener_test.go index 26a997e1..e13544ee 100644 --- a/water/listener_test.go +++ b/water/listener_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "embed" - "encoding/base64" "io" "net" "testing" @@ -25,15 +24,13 @@ func TestWATERListener(t *testing.T) { wasm, err := io.ReadAll(f) require.Nil(t, err) - b64WASM := base64.StdEncoding.EncodeToString(wasm) - ctx := context.Background() cfg := &water.Config{ TransportModuleBin: wasm, } - ll, err := NewWATERListener(ctx, "reverse_v0", "127.0.0.1:3000", b64WASM) + ll, err := NewWATERListener(ctx, nil, "reverse_v0", "127.0.0.1:3000", wasm) require.Nil(t, err) messageRequest := "hello" diff --git a/water/mocks_test.go b/water/mocks_test.go new file mode 100644 index 00000000..80635f90 --- /dev/null +++ b/water/mocks_test.go @@ -0,0 +1,55 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/getlantern/http-proxy-lantern/v2/water (interfaces: WASMDownloader) +// +// Generated by this command: +// +// mockgen -package=water -destination=mocks_test.go . WASMDownloader +// + +// Package water is a generated GoMock package. +package water + +import ( + context "context" + io "io" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockWASMDownloader is a mock of WASMDownloader interface. +type MockWASMDownloader struct { + ctrl *gomock.Controller + recorder *MockWASMDownloaderMockRecorder +} + +// MockWASMDownloaderMockRecorder is the mock recorder for MockWASMDownloader. +type MockWASMDownloaderMockRecorder struct { + mock *MockWASMDownloader +} + +// NewMockWASMDownloader creates a new mock instance. +func NewMockWASMDownloader(ctrl *gomock.Controller) *MockWASMDownloader { + mock := &MockWASMDownloader{ctrl: ctrl} + mock.recorder = &MockWASMDownloaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockWASMDownloader) EXPECT() *MockWASMDownloaderMockRecorder { + return m.recorder +} + +// DownloadWASM mocks base method. +func (m *MockWASMDownloader) DownloadWASM(arg0 context.Context, arg1 io.Writer) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DownloadWASM", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DownloadWASM indicates an expected call of DownloadWASM. +func (mr *MockWASMDownloaderMockRecorder) DownloadWASM(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DownloadWASM", reflect.TypeOf((*MockWASMDownloader)(nil).DownloadWASM), arg0, arg1) +}