diff --git a/x/go.mod b/x/go.mod index 640b9c73..e9c14043 100644 --- a/x/go.mod +++ b/x/go.mod @@ -6,7 +6,7 @@ require ( github.com/Jigsaw-Code/outline-sdk v0.0.16 // Use github.com/Psiphon-Labs/psiphon-tunnel-core@staging-client as per // https://github.com/Psiphon-Labs/psiphon-tunnel-core/?tab=readme-ov-file#using-psiphon-with-go-modules - github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240522172529-8fcc4b9a51cf + github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.1.0 diff --git a/x/go.sum b/x/go.sum index fe9203a3..a8dcfab0 100644 --- a/x/go.sum +++ b/x/go.sum @@ -18,8 +18,8 @@ github.com/Psiphon-Labs/goptlib v0.0.0-20200406165125-c0e32a7a3464 h1:VmnMMMheFX github.com/Psiphon-Labs/goptlib v0.0.0-20200406165125-c0e32a7a3464/go.mod h1:Pe5BqN2DdIdChorAXl6bDaQd/wghpCleJfid2NoSli0= github.com/Psiphon-Labs/psiphon-tls v0.0.0-20240424193802-52b2602ec60c h1:+SEszyxW7yu+smufzSlAszj/WmOYJ054DJjb5jllulc= github.com/Psiphon-Labs/psiphon-tls v0.0.0-20240424193802-52b2602ec60c/go.mod h1:AaKKoshr8RI1LZTheeNDtNuZ39qNVPWVK4uir2c2XIs= -github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240522172529-8fcc4b9a51cf h1:qXrGUIY9MMXIWqOmWv84qjVa8XLLjcOb+S5TEpZjFpA= -github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240522172529-8fcc4b9a51cf/go.mod h1:Z5txHi6IF67uDg206QnSxkgE1I3FJUDDJ3n0pa+bKRs= +github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647 h1:YhpvDo++9Q3FiBuaAUhrFEzEWC6es3zFohjofEwO6xg= +github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647/go.mod h1:Z5txHi6IF67uDg206QnSxkgE1I3FJUDDJ3n0pa+bKRs= github.com/Psiphon-Labs/quic-go v0.0.0-20240424181006-45545f5e1536 h1:pM5ex1QufkHV8lDR6Tc1Crk1bW5lYZjrFIJGZNBWE9k= github.com/Psiphon-Labs/quic-go v0.0.0-20240424181006-45545f5e1536/go.mod h1:2MTiPsgoOqWs3Bo6Xr3ElMBX6zzfjd3YkDFpQJLwHdQ= github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= diff --git a/x/psiphon/psiphon.go b/x/psiphon/psiphon.go index 31064699..ad35c7f9 100644 --- a/x/psiphon/psiphon.go +++ b/x/psiphon/psiphon.go @@ -18,25 +18,23 @@ import ( "context" "encoding/json" "errors" - "fmt" - "io" "net" + "runtime" + "strings" "sync" + "unicode" "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Psiphon-Labs/psiphon-tunnel-core/ClientLibrary/clientlib" - psi "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon" ) // The single [Dialer] we can have. -var singletonDialer = Dialer{ - setNoticeWriter: psi.SetNoticeWriter, -} +var singletonDialer Dialer var ( errNotStartedDial = errors.New("dialer has not been started yet") errNotStartedStop = errors.New("tried to stop dialer that is not running") - errTunnelTimeout = errors.New("tunnel establishment timed out") + errAlreadyStarted = errors.New("dialer has already started") ) // DialerConfig specifies the parameters for [Dialer]. @@ -62,11 +60,14 @@ type Dialer struct { // Controls the Dialer state and Psiphon's global state. mu sync.Mutex // Used by DialStream. - controller *psi.Controller + tunnel psiphonTunnel // Used by Stop. stop func() - // Allows for overriding the global notice writer for testing. - setNoticeWriter func(io.Writer) +} + +type psiphonTunnel interface { + Dial(remoteAddr string) (net.Conn, error) + Stop() } var _ transport.StreamDialer = (*Dialer)(nil) @@ -76,153 +77,102 @@ var _ transport.StreamDialer = (*Dialer)(nil) // you will need to add it independently. func (d *Dialer) DialStream(unusedContext context.Context, addr string) (transport.StreamConn, error) { d.mu.Lock() - controller := d.controller + tunnel := d.tunnel d.mu.Unlock() - if controller == nil { + if tunnel == nil { return nil, errNotStartedDial } - netConn, err := controller.Dial(addr, nil) + netConn, err := tunnel.Dial(addr) if err != nil { return nil, err } return streamConn{netConn}, nil } -func newPsiphonConfig(config *DialerConfig) (*psi.Config, error) { +func getClientPlatform() string { + clientPlatformAllowChars := func(r rune) bool { + return !unicode.IsSpace(r) && r != '_' + } + goos := strings.Join(strings.FieldsFunc(runtime.GOOS, clientPlatformAllowChars), "-") + goarch := strings.Join(strings.FieldsFunc(runtime.GOARCH, clientPlatformAllowChars), "-") + return "outline-sdk_" + goos + "_" + goarch +} + +// Allows for overriding in tests. +var startTunnel func(ctx context.Context, config *DialerConfig) (psiphonTunnel, error) = psiphonStartTunnel + +func psiphonStartTunnel(ctx context.Context, config *DialerConfig) (psiphonTunnel, error) { if config == nil { return nil, errors.New("config must not be nil") } - // Validate keys. We parse as a map first because we need to check for the existence - // of certain keys. - var configMap map[string]interface{} - if err := json.Unmarshal(config.ProviderConfig, &configMap); err != nil { - return nil, fmt.Errorf("failed to parse config: %w", err) - } - for key, value := range configMap { - switch key { - case "DisableLocalHTTPProxy", "DisableLocalSocksProxy": - b, ok := value.(bool) - if !ok { - return nil, fmt.Errorf("field %v must be a boolean", key) - } - if b != true { - return nil, fmt.Errorf("field %v must be true if set", key) - } - case "DataRootDirectory": - return nil, errors.New("field DataRootDirectory must not be set in the provider config. Specify it in the DialerConfig instead.") - } - } - // Parse provider config. - pConfig, err := psi.LoadConfig(config.ProviderConfig) - if err != nil { - return nil, fmt.Errorf("config load failed: %w", err) + // Note that these parameters override anything in the provider config. + clientPlatform := getClientPlatform() + trueValue := true + params := clientlib.Parameters{ + DataRootDirectory: &config.DataRootDirectory, + ClientPlatform: &clientPlatform, + // Disable Psiphon's local proxy servers, which we don't use. + DisableLocalSocksProxy: &trueValue, + DisableLocalHTTPProxy: &trueValue, } - // Force some Psiphon config defaults for the Outline SDK case. - pConfig.DisableLocalHTTPProxy = true - pConfig.DisableLocalSocksProxy = true - pConfig.DataRootDirectory = config.DataRootDirectory - - return pConfig, nil + return clientlib.StartTunnel(ctx, config.ProviderConfig, "", params, nil, nil) } // Start configures and runs the Dialer. It must be called before you can use the Dialer. It returns when the tunnel is ready. func (d *Dialer) Start(ctx context.Context, config *DialerConfig) error { - pConfig, err := newPsiphonConfig(config) - if err != nil { - return err - } - - // Will receive a value if an error occurs during the connection sequence. - // It will be closed on succesful connection. - errCh := make(chan error) - - // Start returns either when a tunnel is ready, or an error happens, whichever comes first. - // When emitting the errors, we use a select statement to ensure the channel is being listened - // on, to avoid a deadlock after the initial error. + resultCh := make(chan error) go func() { - onTunnel := func() { - select { - case errCh <- nil: - default: - } - } - err := d.runController(ctx, pConfig, onTunnel) - select { - case errCh <- err: - default: - } - }() + d.mu.Lock() + defer d.mu.Unlock() - // Wait for an active tunnel or error - return <-errCh -} + if d.stop != nil { + resultCh <- errAlreadyStarted + return + } -func (d *Dialer) runController(ctx context.Context, pConfig *psi.Config, onTunnel func()) error { - d.mu.Lock() - defer d.mu.Unlock() - if d.stop != nil { - return errors.New("tried to start dialer that is alread running") - } - ctx, cancel := context.WithCancelCause(ctx) - defer cancel(context.Canceled) - controllerDone := make(chan struct{}) - defer close(controllerDone) - d.stop = func() { - // Tell controller to stop. - cancel(context.Canceled) - // Wait for controller to return. - <-controllerDone - } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + tunnelDone := make(chan struct{}) + defer close(tunnelDone) + d.stop = func() { + // Tell start to stop. + cancel() + // Wait for tunnel to be done. + <-tunnelDone + } + defer func() { + // Cleanup. + d.stop = nil + }() - // Set up NoticeWriter to receive events. - d.setNoticeWriter(psi.NewNoticeReceiver( - func(notice []byte) { - var event clientlib.NoticeEvent - err := json.Unmarshal(notice, &event) - if err != nil { - // This is unexpected and probably indicates something fatal has occurred. - // We'll interpret it as a connection error and abort. - cancel(fmt.Errorf("failed to unmarshal notice JSON: %w", err)) - return - } - switch event.Type { - case "EstablishTunnelTimeout": - cancel(errTunnelTimeout) - case "Tunnels": - count := event.Data["count"].(float64) - if count > 0 { - onTunnel() - } - } - })) - defer psi.SetNoticeWriter(io.Discard) + d.mu.Unlock() - err := pConfig.Commit(true) - if err != nil { - return fmt.Errorf("failed to commit config: %w", err) - } + tunnel, err := startTunnel(ctx, config) - err = psi.OpenDataStore(&psi.Config{DataRootDirectory: pConfig.DataRootDirectory}) - if err != nil { - return fmt.Errorf("failed to open data store: %w", err) - } - defer psi.CloseDataStore() + d.mu.Lock() - controller, err := psi.NewController(pConfig) - if err != nil { - return fmt.Errorf("failed to create Controller: %w", err) - } - d.controller = controller - d.mu.Unlock() - - controller.Run(ctx) - - d.mu.Lock() - d.controller = nil - d.stop = nil - return context.Cause(ctx) + if ctx.Err() != nil { + err = context.Cause(ctx) + } + if err != nil { + resultCh <- err + return + } + d.tunnel = tunnel + defer func() { + d.tunnel = nil + tunnel.Stop() + }() + resultCh <- nil + + d.mu.Unlock() + // wait for Stop + <-ctx.Done() + d.mu.Lock() + }() + return <-resultCh } // Stop stops the Dialer background processes, releasing resources and allowing it to be reconfigured. @@ -230,8 +180,8 @@ func (d *Dialer) runController(ctx context.Context, pConfig *psi.Config, onTunne func (d *Dialer) Stop() error { d.mu.Lock() stop := d.stop - d.stop = nil d.mu.Unlock() + if stop == nil { return errNotStartedStop } diff --git a/x/psiphon/psiphon_test.go b/x/psiphon/psiphon_test.go index 7f8bbd55..98b1bf05 100644 --- a/x/psiphon/psiphon_test.go +++ b/x/psiphon/psiphon_test.go @@ -19,13 +19,13 @@ package psiphon import ( "context" "encoding/json" - "io" + "errors" + "net" "os" "testing" "time" - psi "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon" - + "github.com/Psiphon-Labs/psiphon-tunnel-core/ClientLibrary/clientlib" "github.com/stretchr/testify/require" ) @@ -41,76 +41,82 @@ func newTestConfig(tb testing.TB) (*DialerConfig, func()) { }, func() { os.RemoveAll(tempDir) } } -func TestNewPsiphonConfig_ParseCorrectly(t *testing.T) { - config, err := newPsiphonConfig(&DialerConfig{ - ProviderConfig: json.RawMessage(`{ - "PropagationChannelId": "ID1", - "SponsorId": "ID2" - }`), - }) - require.NoError(t, err) - require.Equal(t, "ID1", config.PropagationChannelId) - require.Equal(t, "ID2", config.SponsorId) -} +func TestDialer_Start_Successful(t *testing.T) { + dialer := GetSingletonDialer() + startTunnel = func(ctx context.Context, config *DialerConfig) (psiphonTunnel, error) { + return &clientlib.PsiphonTunnel{}, nil + } + defer func() { + startTunnel = psiphonStartTunnel + }() -func TestNewPsiphonConfig_AcceptOkOptions(t *testing.T) { - _, err := newPsiphonConfig(&DialerConfig{ - ProviderConfig: json.RawMessage(`{ - "DisableLocalHTTPProxy": true, - "DisableLocalSocksProxy": true - }`)}) - require.NoError(t, err) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + require.NoError(t, dialer.Start(ctx, nil)) + require.NotNil(t, dialer.tunnel) + require.ErrorIs(t, dialer.Start(ctx, nil), errAlreadyStarted) + require.NoError(t, dialer.Stop()) + require.Nil(t, dialer.tunnel) + require.ErrorIs(t, dialer.Stop(), errNotStartedStop) + require.NoError(t, dialer.Start(ctx, nil)) + require.NoError(t, dialer.Stop()) } -func TestNewPsiphonConfig_RejectBadOptions(t *testing.T) { - _, err := newPsiphonConfig(&DialerConfig{ - ProviderConfig: json.RawMessage(`{"DisableLocalHTTPProxy": false}`)}) - require.Error(t, err) +func TestDialer_StopOnStart(t *testing.T) { + dialer := GetSingletonDialer() + startCalled := make(chan struct{}) + startTunnel = func(ctx context.Context, config *DialerConfig) (psiphonTunnel, error) { + startCalled <- struct{}{} + select { + case <-ctx.Done(): + return nil, context.Cause(ctx) + } + } + defer func() { + startTunnel = psiphonStartTunnel + }() - _, err = newPsiphonConfig(&DialerConfig{ - ProviderConfig: json.RawMessage(`{"DisableLocalSocksProxy": false}`)}) - require.Error(t, err) - require.Error(t, err) + resultCh := make(chan error) + go func() { + resultCh <- dialer.Start(context.Background(), nil) + }() + <-startCalled + require.NoError(t, dialer.Stop()) + require.Error(t, <-resultCh) } -func TestDialer_StartSuccessful(t *testing.T) { - // Create minimal config. - cfg, delete := newTestConfig(t) - defer delete() - - // Intercept notice writer. +func TestDialer_StartOnStart(t *testing.T) { dialer := GetSingletonDialer() - wCh := make(chan io.Writer) - dialer.setNoticeWriter = func(w io.Writer) { - wCh <- w + startCalled := make(chan struct{}) + startTunnel = func(ctx context.Context, config *DialerConfig) (psiphonTunnel, error) { + startCalled <- struct{}{} + select { + case <-ctx.Done(): + return nil, context.Cause(ctx) + } } defer func() { - dialer.setNoticeWriter = psi.SetNoticeWriter + startTunnel = psiphonStartTunnel }() - errCh := make(chan error) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + resultCh := make(chan error) go func() { - errCh <- dialer.Start(ctx, cfg) + resultCh <- dialer.Start(context.Background(), nil) }() - - // We use a select because the error may happen before the notice writer is set. - select { - case w := <-wCh: - // Notify fake tunnel establishment once we have the notice writer. - psi.SetNoticeWriter(w) - psi.NoticeTunnels(1) - case err := <-errCh: - t.Fatalf("Got error from Start: %v", err) + <-startCalled + startTunnel = func(ctx context.Context, config *DialerConfig) (psiphonTunnel, error) { + return nil, errors.New("failed to start") } - - err := <-errCh - require.NoError(t, err) + require.ErrorIs(t, dialer.Start(context.Background(), nil), errAlreadyStarted) require.NoError(t, dialer.Stop()) + require.Error(t, <-resultCh) +} + +func TestDialer_Start_NilConfig(t *testing.T) { + require.Error(t, GetSingletonDialer().Start(context.Background(), nil)) } -func TestDialerStart_Cancelled(t *testing.T) { +func TestDialer_Start_Cancelled(t *testing.T) { cfg, delete := newTestConfig(t) defer delete() errCh := make(chan error) @@ -123,7 +129,7 @@ func TestDialerStart_Cancelled(t *testing.T) { require.ErrorIs(t, err, context.Canceled) } -func TestDialerStart_Timeout(t *testing.T) { +func TestDialer_Start_Timeout(t *testing.T) { cfg, delete := newTestConfig(t) defer delete() errCh := make(chan error) @@ -136,12 +142,78 @@ func TestDialerStart_Timeout(t *testing.T) { require.ErrorIs(t, err, context.DeadlineExceeded) } -func TestDialerDialStream_NotStarted(t *testing.T) { - _, err := GetSingletonDialer().DialStream(context.Background(), "") +type errorTunnel struct { + err error + stopped bool +} + +func (t *errorTunnel) Dial(addr string) (net.Conn, error) { + return nil, t.err +} + +func (t *errorTunnel) Stop() { + t.stopped = true +} + +func TestDialer_DialStream(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dialer := GetSingletonDialer() + + // Dial before Start. + _, err := dialer.DialStream(ctx, "") require.ErrorIs(t, err, errNotStartedDial) + + var tunnel errorTunnel + startTunnel = func(ctx context.Context, config *DialerConfig) (psiphonTunnel, error) { + tunnel.stopped = false + return &tunnel, nil + } + defer func() { + startTunnel = psiphonStartTunnel + }() + // Make sure it works on restarts. + for i := 0; i < 2; i++ { + // Dial after Start. + require.NoError(t, dialer.Start(ctx, nil)) + require.False(t, tunnel.stopped) + conn, err := dialer.DialStream(ctx, "") + require.NoError(t, err) + require.NoError(t, conn.CloseRead()) + require.NoError(t, conn.CloseWrite()) + + // Dial after Stop. + require.NoError(t, dialer.Stop()) + require.True(t, tunnel.stopped) + _, err = dialer.DialStream(nil, "") + require.ErrorIs(t, err, errNotStartedDial) + } +} + +func TestDialer_DialStream_Error(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dialer := GetSingletonDialer() + tunnel := errorTunnel{ + err: errors.New("failed to dial"), + } + startTunnel = func(ctx context.Context, config *DialerConfig) (psiphonTunnel, error) { + tunnel.stopped = false + return &tunnel, nil + } + defer func() { + startTunnel = psiphonStartTunnel + }() + require.NoError(t, dialer.Start(ctx, nil)) + require.False(t, tunnel.stopped) + _, err := dialer.DialStream(ctx, "") + require.Equal(t, tunnel.err, err) + require.NoError(t, dialer.Stop()) } -func TestDialerStop_NotStarted(t *testing.T) { +func TestDialer_Stop_NotStarted(t *testing.T) { err := GetSingletonDialer().Stop() require.ErrorIs(t, err, errNotStartedStop) }