From bb3af0827224e927b9c594b59e52aa47fac1a25a Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Thu, 15 Feb 2024 17:26:24 -0500 Subject: [PATCH 01/25] Create dns.NewStreamDialer --- dns/stream_dialer.go | 67 +++++++++++++++++++++++++++++++++++++++ dns/stream_dialer_test.go | 59 ++++++++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+) create mode 100644 dns/stream_dialer.go create mode 100644 dns/stream_dialer_test.go diff --git a/dns/stream_dialer.go b/dns/stream_dialer.go new file mode 100644 index 00000000..33bc45ab --- /dev/null +++ b/dns/stream_dialer.go @@ -0,0 +1,67 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dns + +import ( + "context" + "fmt" + "net/netip" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "golang.org/x/net/dns/dnsmessage" +) + +func resolveIP(ctx context.Context, resolver Resolver, rrType dnsmessage.Type, hostname string) ([]netip.Addr, error) { + ips := []netip.Addr{} + q, err := NewQuestion(hostname, rrType) + if err != nil { + return nil, err + } + response, err := resolver.Query(ctx, *q) + if err != nil { + return nil, err + } + if response.RCode != dnsmessage.RCodeSuccess { + return nil, fmt.Errorf("got %v (%d)", response.RCode.String(), response.RCode) + } + for _, answer := range response.Answers { + if answer.Header.Type != rrType { + continue + } + if rr, ok := answer.Body.(*dnsmessage.AResource); ok { + ips = append(ips, netip.AddrFrom4(rr.A)) + } + if rr, ok := answer.Body.(*dnsmessage.AAAAResource); ok { + ips = append(ips, netip.AddrFrom16(rr.AAAA)) + } + } + return ips, nil +} + +// NewStreamDialer creates a [transport.StreamDialer] that uses Happy Eyeballs v2 to establish a connection. +// It uses resolver to map host names to IP addresses, and the given dialer to attempt connections. +func NewStreamDialer(resolver Resolver, dialer transport.StreamDialer) transport.StreamDialer { + return &transport.HappyEyeballsStreamDialer{ + Dialer: dialer, + Resolve: transport.NewParallelHappyEyeballsResolveFunc( + func(ctx context.Context, hostname string) ([]netip.Addr, error) { + return resolveIP(ctx, resolver, dnsmessage.TypeAAAA, hostname) + }, + func(ctx context.Context, hostname string) ([]netip.Addr, error) { + return resolveIP(ctx, resolver, dnsmessage.TypeA, hostname) + }, + ), + } +} diff --git a/dns/stream_dialer_test.go b/dns/stream_dialer_test.go new file mode 100644 index 00000000..82c5bfbf --- /dev/null +++ b/dns/stream_dialer_test.go @@ -0,0 +1,59 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dns + +import ( + "context" + "errors" + "net/netip" + "testing" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/stretchr/testify/require" + "golang.org/x/net/dns/dnsmessage" +) + +func TestNewStreamDialer(t *testing.T) { + resolver := FuncResolver(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) { + resp := new(dnsmessage.Message) + resp.Header.Response = true + resp.Questions = []dnsmessage.Question{q} + answerRR := dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{Name: q.Name, Type: q.Type, Class: q.Class, TTL: 0}, + } + switch q.Type { + case dnsmessage.TypeA: + answerRR.Body = &dnsmessage.AResource{A: netip.MustParseAddr("127.0.0.1").As4()} + case dnsmessage.TypeAAAA: + answerRR.Body = &dnsmessage.AAAAResource{AAAA: netip.MustParseAddr("::1").As16()} + default: + return nil, errors.New("bad query type: " + q.Type.String()) + } + resp.Answers = []dnsmessage.Resource{answerRR} + resp.Authorities = []dnsmessage.Resource{} + resp.Additionals = []dnsmessage.Resource{} + return resp, nil + }) + addrs := []string{} + baseDialer := transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { + addrs = append(addrs, addr) + return nil, errors.New("not implemented") + }) + dialer := NewStreamDialer(resolver, baseDialer) + conn, err := dialer.DialStream(context.Background(), "localhost:8080") + require.Error(t, err) + require.Nil(t, conn) + require.Equal(t, []string{"[::1]:8080", "127.0.0.1:8080"}, addrs) +} From fb07d9004ca11c9497937726d1ad516860e3f682 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Thu, 15 Feb 2024 19:39:30 -0500 Subject: [PATCH 02/25] Create Smart dialer --- x/examples/smart-proxy/config.json | 115 +++++ x/examples/smart-proxy/config_broken.json | 19 + x/examples/smart-proxy/main.go | 145 ++++++ x/go.mod | 2 +- x/go.sum | 4 +- x/smart/cache.go | 117 +++++ x/smart/cname.go | 26 + x/smart/cname_unix.go | 72 +++ x/smart/stream_dialer.go | 594 ++++++++++++++++++++++ 9 files changed, 1091 insertions(+), 3 deletions(-) create mode 100644 x/examples/smart-proxy/config.json create mode 100644 x/examples/smart-proxy/config_broken.json create mode 100644 x/examples/smart-proxy/main.go create mode 100644 x/smart/cache.go create mode 100644 x/smart/cname.go create mode 100644 x/smart/cname_unix.go create mode 100644 x/smart/stream_dialer.go diff --git a/x/examples/smart-proxy/config.json b/x/examples/smart-proxy/config.json new file mode 100644 index 00000000..48535d38 --- /dev/null +++ b/x/examples/smart-proxy/config.json @@ -0,0 +1,115 @@ +{ + "dns": [ + {"udp": {"address": "8.8.8.8"}}, + + {"https": {"name": "2620:fe::fe"}, "//": "Quad9"}, + {"https": {"name": "9.9.9.9"}}, + {"https": {"name": "149.112.112.112"}}, + + {"https": {"name": "2001:4860:4860::8888"}, "//": "Google"}, + {"https": {"name": "8.8.8.8"}}, + {"https": {"name": "2001:4860:4860::8844"}}, + {"https": {"name": "8.8.4.4"}}, + + {"https": {"name": "2606:4700:4700::1111"}, "//": "Cloudflare"}, + {"https": {"name": "1.1.1.1"}}, + {"https": {"name": "2606:4700:4700::1001"}}, + {"https": {"name": "1.0.0.1"}}, + {"https": {"name": "cloudflare-dns.com.", "address": "cloudflare.net."}}, + + {"https": {"name": "2620:119:35::35"}, "//": "OpenDNS"}, + {"https": {"name": "208.67.220.220"}}, + {"https": {"name": "2620:119:53::53"}}, + {"https": {"name": "208.67.222.222"}}, + + {"https": {"name": "2001:67c:930::1"}, "//": "Wikimedia DNS"}, + {"https": {"name": "185.71.138.138"}}, + + {"https": {"name": "doh.dns.sb", "address": "cloudflare.net:443"}, "//": "DNS.SB"}, + + + {"tls": {"name": "2620:fe::fe"}, "//": "Quad9"}, + {"tls": {"name": "9.9.9.9"}}, + {"tls": {"name": "149.112.112.112"}}, + + {"tls": {"name": "2001:4860:4860::8888"}, "//": "Google"}, + {"tls": {"name": "8.8.8.8"}}, + {"tls": {"name": "2001:4860:4860::8844"}}, + {"tls": {"name": "8.8.4.4"}}, + + {"tls": {"name": "2606:4700:4700::1111"}, "//": "Cloudflare"}, + {"tls": {"name": "1.1.1.1"}}, + {"tls": {"name": "2606:4700:4700::1001"}}, + {"tls": {"name": "1.0.0.1"}}, + + {"tls": {"name": "2620:119:35::35"}, "//": "OpenDNS"}, + {"tls": {"name": "208.67.220.220"}}, + {"tls": {"name": "2620:119:53::53"}}, + {"tls": {"name": "208.67.222.222"}}, + + {"tls": {"name": "2001:67c:930::1"}, "//": "Wikimedia DNS"}, + {"tls": {"name": "185.71.138.138"}}, + + + {"tcp": {"address": "2620:fe::fe"}, "//": "Quad9"}, + {"tcp": {"address": "9.9.9.9"}}, + {"tcp": {"address": "149.112.112.112"}}, + {"tcp": {"address": "[2620:fe::fe]:9953"}}, + {"tcp": {"address": "9.9.9.9:9953"}}, + {"tcp": {"address": "149.112.112.112:9953"}}, + + {"tcp": {"address": "2001:4860:4860::8888"}, "//": "Google"}, + {"tcp": {"address": "8.8.8.8"}}, + {"tcp": {"address": "2001:4860:4860::8844"}}, + {"tcp": {"address": "8.8.4.4"}}, + + {"tcp": {"address": "2606:4700:4700::1111"}, "//": "Cloudflare"}, + {"tcp": {"address": "1.1.1.1"}}, + {"tcp": {"address": "2606:4700:4700::1001"}}, + {"tcp": {"address": "1.0.0.1"}}, + + {"tcp": {"address": "2620:119:35::35"}, "//": "OpenDNS"}, + {"tcp": {"address": "208.67.220.220"}}, + {"tcp": {"address": "2620:119:53::53"}}, + {"tcp": {"address": "208.67.222.222"}}, + {"tcp": {"address": "[2620:119:35::35]:443"}}, + {"tcp": {"address": "208.67.220.220:443"}}, + {"tcp": {"address": "[2620:119:35::35]:5353"}}, + {"tcp": {"address": "208.67.220.220:5353"}}, + + + {"udp": {"address": "2620:fe::fe"}, "//": "Quad9"}, + {"udp": {"address": "9.9.9.9"}}, + {"udp": {"address": "149.112.112.112"}}, + {"udp": {"address": "[2620:fe::fe]:9953"}}, + {"udp": {"address": "9.9.9.9:9953"}}, + {"udp": {"address": "149.112.112.112:9953"}}, + + {"udp": {"address": "2001:4860:4860::8888"}, "//": "Google"}, + {"udp": {"address": "8.8.8.8"}}, + {"udp": {"address": "2001:4860:4860::8844"}}, + {"udp": {"address": "8.8.4.4"}}, + + {"udp": {"address": "2606:4700:4700::1111"}, "//": "Cloudflare"}, + {"udp": {"address": "1.1.1.1"}}, + {"udp": {"address": "2606:4700:4700::1001"}}, + {"udp": {"address": "1.0.0.1"}}, + + {"udp": {"address": "2620:119:35::35"}, "//": "OpenDNS"}, + {"udp": {"address": "208.67.220.220"}}, + {"udp": {"address": "2620:119:53::53"}}, + {"udp": {"address": "208.67.222.222"}}, + {"udp": {"address": "[2620:119:35::35]:443"}}, + {"udp": {"address": "208.67.220.220:443"}}, + {"udp": {"address": "[2620:119:35::35]:5353"}}, + {"udp": {"address": "208.67.220.220:5353"}} + ], + + "tls": [ + "", + "split:1", + "split:2", + "split:5", + "tlsfrag:1" + ] +} diff --git a/x/examples/smart-proxy/config_broken.json b/x/examples/smart-proxy/config_broken.json new file mode 100644 index 00000000..296d572d --- /dev/null +++ b/x/examples/smart-proxy/config_broken.json @@ -0,0 +1,19 @@ +{ + "dns": [ + {"udp": {"address": "china.cn"}}, + {"udp": {"address": "ns1.tic.ir"}}, + {"tcp": {"address": "ns1.tic.ir"}}, + {"udp": {"address": "tmcell.tm"}}, + {"udp": {"address": "dns1.transtelecom.net."}}, + {"tls": {"name": "captive-portal.badssl.com", "address": "captive-portal.badssl.com:443"}}, + {"https": {"name": "mitm-software.badssl.com"}} + ], + + "tls": [ + "", + "split:1", + "split:2", + "split:5", + "tlsfrag:1" + ] +} diff --git a/x/examples/smart-proxy/main.go b/x/examples/smart-proxy/main.go new file mode 100644 index 00000000..6db27d1c --- /dev/null +++ b/x/examples/smart-proxy/main.go @@ -0,0 +1,145 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "flag" + "fmt" + "io" + "log" + "net" + "net/http" + "os" + "os/signal" + "strings" + "time" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/Jigsaw-Code/outline-sdk/x/config" + "github.com/Jigsaw-Code/outline-sdk/x/httpproxy" + "github.com/Jigsaw-Code/outline-sdk/x/smart" +) + +var debugLog log.Logger = *log.New(io.Discard, "", 0) + +type stringArrayFlagValue []string + +func (v *stringArrayFlagValue) String() string { + return fmt.Sprint(*v) +} + +func (v *stringArrayFlagValue) Set(value string) error { + *v = append(*v, value) + return nil +} + +func main() { + verboseFlag := flag.Bool("v", false, "Enable debug output") + addrFlag := flag.String("localAddr", "localhost:1080", "Local proxy address") + configFlag := flag.String("config", "config.json", "Address of the config file") + transportFlag := flag.String("transport", "", "The base transport for the connections") + var domainsFlag stringArrayFlagValue + flag.Var(&domainsFlag, "domain", "The test domains to find strategies.") + + flag.Parse() + if *verboseFlag { + debugLog = *log.New(os.Stderr, "", log.LstdFlags|log.Lmicroseconds|log.Lshortfile) + } + + if len(domainsFlag) == 0 { + log.Fatal("Must specify flag --domain") + } + + if *configFlag == "" { + log.Fatal("Must specify flag --config") + } + + finderConfig, err := os.ReadFile(*configFlag) + if err != nil { + log.Fatalf("Could not read config: %v", err) + } + + packetDialer, err := config.NewPacketDialer(*transportFlag) + if err != nil { + log.Fatalf("Could not create packet dialer: %v", err) + } + streamDialer, err := config.NewStreamDialer(*transportFlag) + if err != nil { + log.Fatalf("Could not create stream dialer: %v", err) + } + if strings.HasPrefix(*transportFlag, "ss:") { + innerDialer := streamDialer + // Hack to disable IPv6 with Shadowsocks, since it doesn't communicate connection success. + streamDialer = transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + if ip := net.ParseIP(host); ip != nil && ip.To4() == nil { + return nil, fmt.Errorf("IPv6 not supported") + } + return innerDialer.DialStream(ctx, addr) + }) + } + finder := smart.StrategyFinder{ + LogWriter: debugLog.Writer(), + TestTimeout: 5 * time.Second, + StreamDialer: streamDialer, + PacketDialer: packetDialer, + } + + fmt.Println("Finding strategy") + dialer, err := finder.NewDialer(context.Background(), domainsFlag, finderConfig) + if err != nil { + log.Fatalf("Failed to find dialer: %v", err) + } + logDialer := transport.FuncStreamDialer(func(ctx context.Context, address string) (transport.StreamConn, error) { + conn, err := dialer.DialStream(ctx, address) + if err != nil { + debugLog.Printf("Failed to dial %v: %v\n", address, err) + } + return conn, err + }) + + listener, err := net.Listen("tcp", *addrFlag) + if err != nil { + log.Fatalf("Could not listen on address %v: %v", *addrFlag, err) + } + defer listener.Close() + fmt.Printf("Proxy listening on %v\n", listener.Addr().String()) + + server := http.Server{ + Handler: httpproxy.NewProxyHandler(logDialer), + ErrorLog: &debugLog, + } + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + log.Fatalf("Error running web server: %v", err) + } + }() + + // Wait for interrupt signal to stop the proxy. + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt) + <-sig + fmt.Print("Shutting down") + // Gracefully shut down the server, with a 5s timeout. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := server.Shutdown(ctx); err != nil { + log.Fatalf("Failed to shutdown gracefully: %v", err) + } +} diff --git a/x/go.mod b/x/go.mod index e21d972f..3127652e 100644 --- a/x/go.mod +++ b/x/go.mod @@ -3,7 +3,7 @@ module github.com/Jigsaw-Code/outline-sdk/x go 1.20 require ( - github.com/Jigsaw-Code/outline-sdk v0.0.12-0.20240117212550-6cd87709dc1e + github.com/Jigsaw-Code/outline-sdk v0.0.14-0.20240215222624-bb3af0827224 github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b github.com/stretchr/testify v1.8.2 github.com/vishvananda/netlink v1.1.0 diff --git a/x/go.sum b/x/go.sum index 0806aa8a..af7d0359 100644 --- a/x/go.sum +++ b/x/go.sum @@ -1,5 +1,5 @@ -github.com/Jigsaw-Code/outline-sdk v0.0.12-0.20240117212550-6cd87709dc1e h1:56ZI48e68EYYb3m2slu3YJ6C+gWqh8v9bIWk+Bl9dfY= -github.com/Jigsaw-Code/outline-sdk v0.0.12-0.20240117212550-6cd87709dc1e/go.mod h1:9cEaF6sWWMzY8orcUI9pV5D0oFp2FZArTSyJiYtMQQs= +github.com/Jigsaw-Code/outline-sdk v0.0.14-0.20240215222624-bb3af0827224 h1:LUueXcQtgO2T7rsQalP3lUr3AdW/ddIjL4AdcBH5G9Q= +github.com/Jigsaw-Code/outline-sdk v0.0.14-0.20240215222624-bb3af0827224/go.mod h1:9cEaF6sWWMzY8orcUI9pV5D0oFp2FZArTSyJiYtMQQs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/x/smart/cache.go b/x/smart/cache.go new file mode 100644 index 00000000..4cd0b3f6 --- /dev/null +++ b/x/smart/cache.go @@ -0,0 +1,117 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package smart + +import ( + "context" + "strings" + "time" + + "github.com/Jigsaw-Code/outline-sdk/dns" + "golang.org/x/net/dns/dnsmessage" +) + +// canonicalName returns the domain name in canonical form. A name in canonical +// form is lowercase and fully qualified. Only US-ASCII letters are affected. See +// Section 6.2 in RFC 4034. +func canonicalName(s string) string { + return strings.Map(func(r rune) rune { + if r >= 'A' && r <= 'Z' { + r += 'a' - 'A' + } + return r + }, s) +} + +type cacheEntry struct { + key string + msg *dnsmessage.Message + expire time.Time +} + +// cacheResolver is a very simple caching [RoundTripper]. +// It doesn't use the response TTL and doesn't cache empty answers. +// It also doesn't dedup duplicate in-flight requests. +type cacheResolver struct { + resolver dns.Resolver + cache []cacheEntry +} + +var _ dns.Resolver = (*cacheResolver)(nil) + +func newCacheResolver(resolver dns.Resolver, numEntries int) dns.Resolver { + return &cacheResolver{resolver: resolver, cache: make([]cacheEntry, numEntries)} +} + +func (r *cacheResolver) removeExpired() { + now := time.Now() + last := 0 + for _, entry := range r.cache { + if entry.expire.After(now) { + r.cache[last] = entry + last++ + } + } + r.cache = r.cache[:last] +} + +func (r *cacheResolver) moveToFront(index int) { + entry := r.cache[index] + copy(r.cache[1:], r.cache[:index]) + r.cache[0] = entry +} + +func makeCacheKey(q dnsmessage.Question) string { + domainKey := canonicalName(q.Name.String()) + return strings.Join([]string{domainKey, q.Type.String(), q.Class.String()}, "|") +} + +func (r *cacheResolver) searchCache(key string) *dnsmessage.Message { + for ei, entry := range r.cache { + if entry.key == key { + r.moveToFront(ei) + // TODO: update TTLs + // TODO: make names match + return entry.msg + } + } + return nil +} + +func (r *cacheResolver) addToCache(key string, msg *dnsmessage.Message) { + newSize := len(r.cache) + 1 + if newSize > cap(r.cache) { + newSize = cap(r.cache) + } + r.cache = r.cache[:newSize] + copy(r.cache[1:], r.cache[:newSize-1]) + // TODO: copy and normalize names + r.cache[0] = cacheEntry{key: key, msg: msg, expire: time.Now().Add(60 * time.Second)} +} + +func (r *cacheResolver) Query(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) { + r.removeExpired() + cacheKey := makeCacheKey(q) + if msg := r.searchCache(cacheKey); msg != nil { + return msg, nil + } + msg, err := r.resolver.Query(ctx, q) + if err != nil { + // TODO: cache NXDOMAIN. See https://datatracker.ietf.org/doc/html/rfc2308. + return nil, err + } + r.addToCache(cacheKey, msg) + return msg, nil +} diff --git a/x/smart/cname.go b/x/smart/cname.go new file mode 100644 index 00000000..a1eec5ab --- /dev/null +++ b/x/smart/cname.go @@ -0,0 +1,26 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !unix + +package smart + +import ( + "context" + "net" +) + +func lookupCNAME(ctx context.Context, domain string) (string, error) { + return net.DefaultResolver.LookupCNAME(ctx, domain) +} diff --git a/x/smart/cname_unix.go b/x/smart/cname_unix.go new file mode 100644 index 00000000..521ffd0a --- /dev/null +++ b/x/smart/cname_unix.go @@ -0,0 +1,72 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build unix + +package smart + +/* +#include +#include +#include +#include +*/ +import "C" + +import ( + "context" + "fmt" + "unsafe" +) + +func lookupCNAME(ctx context.Context, domain string) (string, error) { + type result struct { + cname string + err error + } + + results := make(chan result) + go func() { + cname, err := lookupCNAMEBlocking(domain) + results <- result{cname, err} + }() + + select { + case r := <-results: + return r.cname, r.err + case <-ctx.Done(): + return "", ctx.Err() + } +} + +func lookupCNAMEBlocking(host string) (string, error) { + var hints C.struct_addrinfo + var result *C.struct_addrinfo + + chost := C.CString(host) + defer C.free(unsafe.Pointer(chost)) + + hints.ai_family = C.AF_UNSPEC + hints.ai_flags = C.AI_CANONNAME + + res := C.getaddrinfo(chost, nil, &hints, &result) + if res != 0 { + return "", fmt.Errorf("getaddrinfo error: %s", C.GoString(C.gai_strerror(res))) + } + defer C.freeaddrinfo(result) + + // Extract canonical name + cname := C.GoString(result.ai_canonname) + return cname, nil +} diff --git a/x/smart/stream_dialer.go b/x/smart/stream_dialer.go new file mode 100644 index 00000000..d72b14ed --- /dev/null +++ b/x/smart/stream_dialer.go @@ -0,0 +1,594 @@ +// Copyright 2023 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package smart + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "math/rand" + "net" + "net/url" + "sync" + "time" + "unicode" + + "github.com/Jigsaw-Code/outline-sdk/dns" + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/Jigsaw-Code/outline-sdk/x/config" + "golang.org/x/net/dns/dnsmessage" +) + +// To test one strategy: +// go run ./x/examples/smart-proxy -v -localAddr=localhost:1080 --transport="" --domain www.rferl.org --config=<(echo '{"dns": [{"https": {"name": "doh.sb"}}]}') + +// mixCase randomizes the case of the domain letters. +func mixCase(domain string) string { + var mixed []rune + for _, r := range domain { + if rand.Intn(2) == 0 { + mixed = append(mixed, unicode.ToLower(r)) + } else { + mixed = append(mixed, unicode.ToUpper(r)) + } + } + return string(mixed) +} + +func getARootNameserver() (string, error) { + nsList, err := net.LookupNS(".") + if err != nil { + return "", fmt.Errorf("could not get list of root nameservers: %v", err) + } + if len(nsList) == 0 { + return "", fmt.Errorf("empty list of root nameservers") + } + return nsList[0].Host, nil +} + +func fingerprint(pd transport.PacketDialer, sd transport.StreamDialer, testDomain string) { + rootNS, err := getARootNameserver() + if err != nil { + log.Fatalf("Failed to find root nameserver: %v", err) + } + + allNSIPs, err := net.LookupIP(rootNS) + if err != nil { + log.Fatalf("Failed to resolve root nameserver: %v", err) + } + ips := []net.IP{} + for _, ip := range allNSIPs { + if ip.To4() != nil { + ips = append(ips, ip) + break + } + } + for _, ip := range allNSIPs { + if ip.To16() != nil { + ips = append(ips, ip) + break + } + } + + q, err := dns.NewQuestion(testDomain, dnsmessage.TypeA) + if err != nil { + log.Fatalf("failed to parse domain name: %v", err) + } + for _, rootNSIP := range ips { + resolvedNS := net.JoinHostPort(rootNSIP.String(), "53") + for _, proto := range []string{"udp", "tcp"} { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var resolver dns.Resolver + switch proto { + case "tcp": + resolver = dns.NewTCPResolver(sd, resolvedNS) + default: + resolver = dns.NewUDPResolver(pd, resolvedNS) + } + + response, err := resolver.Query(ctx, *q) + fmt.Printf("%v:%v", proto, resolvedNS) + if err != nil { + fmt.Printf("; status=error: %v\n", err) + continue + } + if len(response.Answers) > 0 { + fmt.Printf("; status=unexpected answer (injected): %v ⚠️\n", response.Answers) + // TODO: use RCODE, CNAME and IPs as blocking fingerprint. + continue + } + if response.RCode != dnsmessage.RCodeSuccess { + fmt.Printf("; status=unexpected rcode (injected): %v ⚠️\n", response.Answers) + // TODO: use RCODE, CNAME and IPs as blocking fingerprint. + continue + } + fmt.Print("; status=ok (no injection) ✓\n") + } + } +} + +func evaluateNetResolver(ctx context.Context, resolver *net.Resolver, testDomain string) ([]net.IP, error) { + requestDomain := mixCase(testDomain) + _, err := lookupCNAME(ctx, requestDomain) + if err != nil { + return nil, fmt.Errorf("could not get cname: %w", err) + } + ips, err := resolver.LookupIP(ctx, "ip", requestDomain) + if err != nil { + return nil, fmt.Errorf("failed to lookup IPs: %w", err) + } + if len(ips) == 0 { + return nil, fmt.Errorf("no ip answer") + } + for _, ip := range ips { + if ip.IsLoopback() { + return nil, fmt.Errorf("localhost ip: %v", ip) // -1 + } + if ip.IsPrivate() { + return nil, fmt.Errorf("private ip: %v", ip) // -1 + } + if ip.IsUnspecified() { + return nil, fmt.Errorf("zero ip: %v", ip) // -1 + } + // TODO: consider validating the IPs: fingerprint, hardcoded ground truth, trusted response, TLS connection. + } + return ips, nil +} + +func evaluateAddressResponse(response dnsmessage.Message, requestDomain string) ([]net.IP, error) { + if response.RCode != dnsmessage.RCodeSuccess { + return nil, fmt.Errorf("rcode is not success: %v", response.RCode) + } + var ips []net.IP + if len(response.Answers) == 0 { + return ips, errors.New("no answers") // -1 + } + for _, answer := range response.Answers { + if answer.Header.Type != dnsmessage.TypeA && answer.Header.Type != dnsmessage.TypeAAAA { + continue + } + var ip net.IP + switch rr := answer.Body.(type) { + case *dnsmessage.AResource: + ip = net.IP(rr.A[:]) + case *dnsmessage.AAAAResource: + ip = net.IP(rr.AAAA[:]) + default: + continue + } + if ip.IsLoopback() { + return nil, fmt.Errorf("localhost ip: %v", ip) // -1 + } + if ip.IsPrivate() { + return nil, fmt.Errorf("private ip: %v", ip) // -1 + } + if ip.IsUnspecified() { + return nil, fmt.Errorf("zero ip: %v", ip) // -1 + } + ips = append(ips, ip) + } + if len(ips) == 0 { + return ips, fmt.Errorf("no ip answer: %v", response.Answers) // -1 + } + // All popular recursive resolvers we tested maintain the domain case of the request. + // Note that this is not the case of authoritative resolvers. Some of them will return + // a fully normalized domain name, or normalize part of it. + if response.Answers[0].Header.Name.String() != requestDomain { + return ips, fmt.Errorf("domain mismatch: got %v, expected %v", response.Answers[0].Header.Name, requestDomain) // -0.5 or +0.5 if match + } + return ips, nil +} + +func evaluateCNAMEResponse(response dnsmessage.Message, requestDomain string) error { + if response.RCode != dnsmessage.RCodeSuccess { + return fmt.Errorf("rcode is not success: %v", response.RCode) + } + if len(response.Answers) == 0 { + var numSOA int + for _, answer := range response.Authorities { + if _, ok := answer.Body.(*dnsmessage.SOAResource); ok { + numSOA++ + } + } + if numSOA != 1 { + return fmt.Errorf("SOA records is %v, expected 1", numSOA) + } + return nil + } + var cname string + for _, answer := range response.Answers { + if answer.Header.Type != dnsmessage.TypeCNAME { + return fmt.Errorf("bad answer type: %v", answer.Header.Type) + } + if rr, ok := answer.Body.(*dnsmessage.CNAMEResource); ok { + if cname != "" { + return fmt.Errorf("found too many CNAMEs: %v %v", cname, rr.CNAME) + } + cname = rr.CNAME.String() + } + } + if cname == "" { + return fmt.Errorf("no CNAME in answers") + } + return nil +} + +type StrategyFinder struct { + TestTimeout time.Duration + LogWriter io.Writer + StreamDialer transport.StreamDialer + PacketDialer transport.PacketDialer + logMu sync.Mutex +} + +func (f *StrategyFinder) log(format string, a ...any) { + if f.LogWriter != nil { + f.logMu.Lock() + defer f.logMu.Unlock() + fmt.Fprintf(f.LogWriter, format, a...) + } +} + +func (f *StrategyFinder) testDNSResolver(baseCtx context.Context, resolver dns.Resolver, testDomain string) ([]net.IP, error) { + // We special case the system resolver, since we can't get a dns.RoundTripper. + if resolver == nil { + ctx, cancel := context.WithTimeout(baseCtx, f.TestTimeout) + defer cancel() + return evaluateNetResolver(ctx, new(net.Resolver), testDomain) + } + + requestDomain := mixCase(testDomain) + + q, err := dns.NewQuestion(requestDomain, dnsmessage.TypeA) + if err != nil { + return nil, fmt.Errorf("failed to create question: %v", err) + } + ctxA, cancelA := context.WithTimeout(baseCtx, f.TestTimeout) + defer cancelA() + response, err := resolver.Query(ctxA, *q) + if err != nil { + return nil, fmt.Errorf("request for A query failed: %w", err) + } + ips, err := evaluateAddressResponse(*response, requestDomain) + if err != nil { + return ips, fmt.Errorf("failed A test: %w", err) + } + // TODO(fortuna): Consider testing whether we can establish a TCP connection to ip:443. + + q, err = dns.NewQuestion(requestDomain, dnsmessage.TypeCNAME) + if err != nil { + return nil, fmt.Errorf("failed to create question: %v", err) + } + ctxCNAME, cancelCNAME := context.WithTimeout(baseCtx, f.TestTimeout) + defer cancelCNAME() + response, err = resolver.Query(ctxCNAME, *q) + if err != nil { + return nil, fmt.Errorf("request for CNAME query failed: %w", err) + } + err = evaluateCNAMEResponse(*response, requestDomain) + if err != nil { + return nil, fmt.Errorf("failed CNAME test: %w", err) + } + return ips, nil +} + +type httpsEntryJSON struct { + Name string `json:"name,omitempty"` + Address string `json:"address,omitempty"` +} + +type tlsEntryJSON struct { + Name string `json:"name,omitempty"` + Address string `json:"address,omitempty"` +} + +type udpEntryJSON struct { + Address string `json:"address,omitempty"` +} + +type tcpEntryJSON struct { + Address string `json:"address,omitempty"` +} + +type dnsEntryJSON struct { + System *struct{} `json:"system,omitempty"` + HTTPS *httpsEntryJSON `json:"https,omitempty"` + TLS *tlsEntryJSON `json:"tls,omitempty"` + UDP *udpEntryJSON `json:"udp,omitempty"` + TCP *tcpEntryJSON `json:"tcp,omitempty"` +} + +type configJSON struct { + DNS []dnsEntryJSON `json:"dns,omitempty"` + TLS []string `json:"tls,omitempty"` +} + +func (f *StrategyFinder) newDNSResolverFromEntry(entry dnsEntryJSON) (dns.Resolver, error) { + if entry.System != nil { + return nil, nil + } else if cfg := entry.HTTPS; cfg != nil { + if cfg.Name == "" { + return nil, fmt.Errorf("https entry has empty server name") + } + serverAddr := cfg.Address + if serverAddr == "" { + serverAddr = cfg.Name + } + _, port, err := net.SplitHostPort(serverAddr) + if err != nil { + serverAddr = net.JoinHostPort(serverAddr, "443") + port = "443" + } + dohURL := url.URL{Scheme: "https", Host: net.JoinHostPort(cfg.Name, port), Path: "/dns-query"} + return dns.NewHTTPSResolver(f.StreamDialer, serverAddr, dohURL.String()), nil + } else if cfg := entry.TLS; cfg != nil { + if cfg.Name == "" { + return nil, fmt.Errorf("tls entry has empty server name") + } + serverAddr := cfg.Address + if serverAddr == "" { + serverAddr = cfg.Name + } + _, _, err := net.SplitHostPort(serverAddr) + if err != nil { + serverAddr = net.JoinHostPort(serverAddr, "853") + } + return dns.NewTLSResolver(f.StreamDialer, serverAddr, cfg.Name), nil + } else if cfg := entry.TCP; cfg != nil { + if cfg.Address == "" { + return nil, fmt.Errorf("tcp entry has empty server address") + } + host, port, err := net.SplitHostPort(cfg.Address) + if err != nil { + host = cfg.Address + port = "53" + } + serverAddr := net.JoinHostPort(host, port) + return dns.NewTCPResolver(f.StreamDialer, serverAddr), nil + } else if cfg := entry.UDP; cfg != nil { + if cfg.Address == "" { + return nil, fmt.Errorf("udp entry has empty server address") + } + host, port, err := net.SplitHostPort(cfg.Address) + if err != nil { + host = cfg.Address + port = "53" + } + serverAddr := net.JoinHostPort(host, port) + return dns.NewUDPResolver(f.PacketDialer, serverAddr), nil + } else { + return nil, errors.New("invalid DNS entry") + } +} + +type resolverEntry struct { + ID string + Resolver dns.Resolver +} + +func (f *StrategyFinder) dnsConfigToRoundTrippers(dnsConfig []dnsEntryJSON) ([]resolverEntry, error) { + if len(dnsConfig) == 0 { + return nil, errors.New("no DNS config entry") + } + rts := make([]resolverEntry, 0, len(dnsConfig)) + for ei, entry := range dnsConfig { + idBytes, err := json.Marshal(entry) + if err != nil { + return nil, fmt.Errorf("cannot serialize entry %v: %w", ei, err) + } + id := string(idBytes) + resolver, err := f.newDNSResolverFromEntry(entry) + if err != nil { + return nil, fmt.Errorf("failed to process entry %v: %w", ei, err) + } + rts = append(rts, resolverEntry{ID: id, Resolver: resolver}) + } + return rts, nil +} + +// Returns a [context.Context] that is already done. +func newDoneContext() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx +} + +func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) (dns.Resolver, error) { + resolvers, err := f.dnsConfigToRoundTrippers(dnsConfig) + if err != nil { + return nil, err + } + type testResult struct { + ID string + Resolver dns.Resolver + Err error + } + // Communicates the result of each test. + resultChan := make(chan testResult) + // Indicates to tests that the search is done, so they don't get stuck writing to the results channel that will no longer be read. + searchCtx, searchDone := context.WithCancel(context.Background()) + defer searchDone() + // Used to space out each test. The initial value is done because there's no wait needed. + waitCtx := newDoneContext() + // Next entry to start testing. + nextResolver := 0 + // How many test entries are not done. + resolversToTest := len(resolvers) + for resolversToTest > 0 { + if nextResolver == len(resolvers) { + // No more tests to start. Make sure the select doesn't trigger on waitCtx. + waitCtx = searchCtx + } + select { + case <-waitCtx.Done(): + // Start a new test. + entry := resolvers[nextResolver] + nextResolver++ + var waitDone context.CancelFunc + waitCtx, waitDone = context.WithTimeout(searchCtx, 250*time.Millisecond) + go func(entry resolverEntry, testDone context.CancelFunc) { + defer testDone() + for _, testDomain := range testDomains { + select { + case <-searchCtx.Done(): + return + default: + } + f.log("🏃 run dns: %v (domain: %v)\n", entry.ID, testDomain) + startTime := time.Now() + ips, err := f.testDNSResolver(searchCtx, entry.Resolver, testDomain) + duration := time.Since(startTime) + status := "ok ✅" + if err != nil { + status = fmt.Sprintf("%v ❌", err) + } + f.log("🏁 got dns: %v (domain: %v), duration=%v, ips=%v, status=%v\n", entry.ID, testDomain, duration, ips, status) + if err != nil { + select { + case <-searchCtx.Done(): + return + case resultChan <- testResult{ID: entry.ID, Resolver: entry.Resolver, Err: err}: + return + } + } + } + select { + case <-searchCtx.Done(): + case resultChan <- testResult{ID: entry.ID, Resolver: entry.Resolver, Err: nil}: + } + }(entry, waitDone) + + case result := <-resultChan: + resolversToTest-- + // Process the result of a test. + if result.Err != nil { + continue + } + f.log("✅ selected resolver %v\n", result.ID) + // Tested all domains on this resolver. Return + if result.Resolver != nil { + return result.Resolver, nil + } else { + return nil, nil + } + } + } + return nil, errors.New("could not find working resolver") +} + +func (f *StrategyFinder) findTLS(testDomains []string, baseDialer transport.StreamDialer, tlsConfig []string) (transport.StreamDialer, error) { + if len(tlsConfig) == 0 { + return nil, errors.New("config for TLS is empty. Please specify at least one transport") + } + for _, transportCfg := range tlsConfig { + for di, testDomain := range testDomains { + testAddr := net.JoinHostPort(testDomain, "443") + f.log(" tls=%v (domain: %v)", transportCfg, testDomain) + + tlsDialer, err := config.WrapStreamDialer(baseDialer, transportCfg) + if err != nil { + f.log("; wrap_error=%v ❌\n", err) + break + } + ctx, cancel := context.WithTimeout(context.Background(), f.TestTimeout) + defer cancel() + testConn, err := tlsDialer.DialStream(ctx, testAddr) + if err != nil { + f.log("; dial_error=%v ❌\n", err) + break + } + tlsConn := tls.Client(testConn, &tls.Config{ServerName: testDomain}) + err = tlsConn.HandshakeContext(ctx) + tlsConn.Close() + if err != nil { + f.log("; handshake=%v ❌\n", err) + break + } + f.log("; status=ok ✅\n") + if di+1 < len(testDomains) { + // More domains to test + continue + } + return transport.FuncStreamDialer(func(ctx context.Context, raddr string) (transport.StreamConn, error) { + _, portStr, err := net.SplitHostPort(raddr) + if err != nil { + return nil, fmt.Errorf("failed to parse address: %w", err) + } + portNum, err := net.DefaultResolver.LookupPort(ctx, "tcp", portStr) + if err != nil { + return nil, fmt.Errorf("could not resolve port: %w", err) + } + selectedDialer := baseDialer + if portNum == 443 || portNum == 853 { + selectedDialer = tlsDialer + } + return selectedDialer.DialStream(ctx, raddr) + }), nil + } + } + return nil, errors.New("could not find TLS strategy") +} + +// makeFullyQualified makes the domain fully-qualified, ending on a dot ("."). +// This is useful in domain resolution to avoid ambiguity with local domains +// and domain search. +func makeFullyQualified(domain string) string { + if len(domain) > 0 && domain[len(domain)-1] == '.' { + return domain + } + return domain + "." +} + +// NewDialer uses the config in configBytes to search for a strategy that unblocks all of the testDomains, returning a dialer with the found strategy. +// It returns an error if no strategy was found that unblocks the testDomains. +// The testDomains must be domains with a TLS service running on port 443. +func (f *StrategyFinder) NewDialer(ctx context.Context, testDomains []string, configBytes []byte) (transport.StreamDialer, error) { + var parsedConfig configJSON + err := json.Unmarshal(configBytes, &parsedConfig) + if err != nil { + return nil, fmt.Errorf("failed to parse config: %v", err) + } + + // Make domain fully-qualified to prevent confusing domain search. + testDomains = append(make([]string, 0, len(testDomains)), testDomains...) + for di, domain := range testDomains { + testDomains[di] = makeFullyQualified(domain) + } + + dnsRT, err := f.findDNS(testDomains, parsedConfig.DNS) + if err != nil { + return nil, err + } + var dnsDialer transport.StreamDialer + if dnsRT == nil { + if _, ok := f.StreamDialer.(*transport.TCPDialer); !ok { + return nil, fmt.Errorf("cannot use system resolver with base dialer of type %T", f.StreamDialer) + } + dnsDialer = f.StreamDialer + } else { + dnsRT = newCacheResolver(dnsRT, 100) + dnsDialer = dns.NewStreamDialer(dnsRT, f.StreamDialer) + } + + if len(parsedConfig.TLS) == 0 { + return dnsDialer, nil + } + return f.findTLS(testDomains, dnsDialer, parsedConfig.TLS) +} From 5fb6dece360fbc09e8cb6675ba825919853ab946 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 20 Feb 2024 15:00:38 -0500 Subject: [PATCH 03/25] Fix config --- x/examples/smart-proxy/config.json | 2 -- 1 file changed, 2 deletions(-) diff --git a/x/examples/smart-proxy/config.json b/x/examples/smart-proxy/config.json index 48535d38..082c3911 100644 --- a/x/examples/smart-proxy/config.json +++ b/x/examples/smart-proxy/config.json @@ -1,7 +1,5 @@ { "dns": [ - {"udp": {"address": "8.8.8.8"}}, - {"https": {"name": "2620:fe::fe"}, "//": "Quad9"}, {"https": {"name": "9.9.9.9"}}, {"https": {"name": "149.112.112.112"}}, From 5dbaaf6da1596bfe13d6e89b39243f605c70c72b Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 20 Feb 2024 16:57:01 -0500 Subject: [PATCH 04/25] Clean up --- x/go.mod | 2 +- x/go.sum | 4 +- x/smart/cache.go | 27 +++-- x/smart/stream_dialer.go | 215 +++++++++++++++------------------------ 4 files changed, 105 insertions(+), 143 deletions(-) diff --git a/x/go.mod b/x/go.mod index 3127652e..0bba0e1f 100644 --- a/x/go.mod +++ b/x/go.mod @@ -3,7 +3,7 @@ module github.com/Jigsaw-Code/outline-sdk/x go 1.20 require ( - github.com/Jigsaw-Code/outline-sdk v0.0.14-0.20240215222624-bb3af0827224 + github.com/Jigsaw-Code/outline-sdk v0.0.14-0.20240216220040-f741c57bf854 github.com/songgao/water v0.0.0-20190725173103-fd331bda3f4b github.com/stretchr/testify v1.8.2 github.com/vishvananda/netlink v1.1.0 diff --git a/x/go.sum b/x/go.sum index af7d0359..c0f2c440 100644 --- a/x/go.sum +++ b/x/go.sum @@ -1,5 +1,5 @@ -github.com/Jigsaw-Code/outline-sdk v0.0.14-0.20240215222624-bb3af0827224 h1:LUueXcQtgO2T7rsQalP3lUr3AdW/ddIjL4AdcBH5G9Q= -github.com/Jigsaw-Code/outline-sdk v0.0.14-0.20240215222624-bb3af0827224/go.mod h1:9cEaF6sWWMzY8orcUI9pV5D0oFp2FZArTSyJiYtMQQs= +github.com/Jigsaw-Code/outline-sdk v0.0.14-0.20240216220040-f741c57bf854 h1:SXp/tNjb70hpjF/MXAuLDkgCttlRA9qxLR7FCosGydg= +github.com/Jigsaw-Code/outline-sdk v0.0.14-0.20240216220040-f741c57bf854/go.mod h1:9cEaF6sWWMzY8orcUI9pV5D0oFp2FZArTSyJiYtMQQs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/x/smart/cache.go b/x/smart/cache.go index 4cd0b3f6..1bc218d4 100644 --- a/x/smart/cache.go +++ b/x/smart/cache.go @@ -17,6 +17,7 @@ package smart import ( "context" "strings" + "sync" "time" "github.com/Jigsaw-Code/outline-sdk/dns" @@ -41,12 +42,13 @@ type cacheEntry struct { expire time.Time } -// cacheResolver is a very simple caching [RoundTripper]. +// cacheResolver is a very simple caching [dns.Resolver]. // It doesn't use the response TTL and doesn't cache empty answers. // It also doesn't dedup duplicate in-flight requests. type cacheResolver struct { resolver dns.Resolver cache []cacheEntry + mux sync.Mutex } var _ dns.Resolver = (*cacheResolver)(nil) @@ -55,9 +57,11 @@ func newCacheResolver(resolver dns.Resolver, numEntries int) dns.Resolver { return &cacheResolver{resolver: resolver, cache: make([]cacheEntry, numEntries)} } -func (r *cacheResolver) removeExpired() { +func (r *cacheResolver) RemoveExpired() { now := time.Now() last := 0 + r.mux.Lock() + defer r.mux.Unlock() for _, entry := range r.cache { if entry.expire.After(now) { r.cache[last] = entry @@ -78,7 +82,9 @@ func makeCacheKey(q dnsmessage.Question) string { return strings.Join([]string{domainKey, q.Type.String(), q.Class.String()}, "|") } -func (r *cacheResolver) searchCache(key string) *dnsmessage.Message { +func (r *cacheResolver) SearchCache(key string) *dnsmessage.Message { + r.mux.Lock() + defer r.mux.Unlock() for ei, entry := range r.cache { if entry.key == key { r.moveToFront(ei) @@ -90,7 +96,9 @@ func (r *cacheResolver) searchCache(key string) *dnsmessage.Message { return nil } -func (r *cacheResolver) addToCache(key string, msg *dnsmessage.Message) { +func (r *cacheResolver) AddToCache(key string, msg *dnsmessage.Message) { + r.mux.Lock() + defer r.mux.Unlock() newSize := len(r.cache) + 1 if newSize > cap(r.cache) { newSize = cap(r.cache) @@ -101,17 +109,20 @@ func (r *cacheResolver) addToCache(key string, msg *dnsmessage.Message) { r.cache[0] = cacheEntry{key: key, msg: msg, expire: time.Now().Add(60 * time.Second)} } +// Query implements [dns.Resolver]. func (r *cacheResolver) Query(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) { - r.removeExpired() + r.RemoveExpired() cacheKey := makeCacheKey(q) - if msg := r.searchCache(cacheKey); msg != nil { + if msg := r.SearchCache(cacheKey); msg != nil { return msg, nil } msg, err := r.resolver.Query(ctx, q) if err != nil { - // TODO: cache NXDOMAIN. See https://datatracker.ietf.org/doc/html/rfc2308. + // TODO: cache server failures. See https://datatracker.ietf.org/doc/html/rfc2308. return nil, err } - r.addToCache(cacheKey, msg) + if msg.RCode == dnsmessage.RCodeSuccess || msg.RCode == dnsmessage.RCodeNameError { + r.AddToCache(cacheKey, msg) + } return msg, nil } diff --git a/x/smart/stream_dialer.go b/x/smart/stream_dialer.go index d72b14ed..881a26eb 100644 --- a/x/smart/stream_dialer.go +++ b/x/smart/stream_dialer.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package smart provides utilities to dynamically find serverless strategies for circumvention. package smart import ( @@ -21,7 +22,6 @@ import ( "errors" "fmt" "io" - "log" "math/rand" "net" "net/url" @@ -51,79 +51,6 @@ func mixCase(domain string) string { return string(mixed) } -func getARootNameserver() (string, error) { - nsList, err := net.LookupNS(".") - if err != nil { - return "", fmt.Errorf("could not get list of root nameservers: %v", err) - } - if len(nsList) == 0 { - return "", fmt.Errorf("empty list of root nameservers") - } - return nsList[0].Host, nil -} - -func fingerprint(pd transport.PacketDialer, sd transport.StreamDialer, testDomain string) { - rootNS, err := getARootNameserver() - if err != nil { - log.Fatalf("Failed to find root nameserver: %v", err) - } - - allNSIPs, err := net.LookupIP(rootNS) - if err != nil { - log.Fatalf("Failed to resolve root nameserver: %v", err) - } - ips := []net.IP{} - for _, ip := range allNSIPs { - if ip.To4() != nil { - ips = append(ips, ip) - break - } - } - for _, ip := range allNSIPs { - if ip.To16() != nil { - ips = append(ips, ip) - break - } - } - - q, err := dns.NewQuestion(testDomain, dnsmessage.TypeA) - if err != nil { - log.Fatalf("failed to parse domain name: %v", err) - } - for _, rootNSIP := range ips { - resolvedNS := net.JoinHostPort(rootNSIP.String(), "53") - for _, proto := range []string{"udp", "tcp"} { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - var resolver dns.Resolver - switch proto { - case "tcp": - resolver = dns.NewTCPResolver(sd, resolvedNS) - default: - resolver = dns.NewUDPResolver(pd, resolvedNS) - } - - response, err := resolver.Query(ctx, *q) - fmt.Printf("%v:%v", proto, resolvedNS) - if err != nil { - fmt.Printf("; status=error: %v\n", err) - continue - } - if len(response.Answers) > 0 { - fmt.Printf("; status=unexpected answer (injected): %v ⚠️\n", response.Answers) - // TODO: use RCODE, CNAME and IPs as blocking fingerprint. - continue - } - if response.RCode != dnsmessage.RCodeSuccess { - fmt.Printf("; status=unexpected rcode (injected): %v ⚠️\n", response.Answers) - // TODO: use RCODE, CNAME and IPs as blocking fingerprint. - continue - } - fmt.Print("; status=ok (no injection) ✓\n") - } - } -} - func evaluateNetResolver(ctx context.Context, resolver *net.Resolver, testDomain string) ([]net.IP, error) { requestDomain := mixCase(testDomain) _, err := lookupCNAME(ctx, requestDomain) @@ -152,15 +79,9 @@ func evaluateNetResolver(ctx context.Context, resolver *net.Resolver, testDomain return ips, nil } -func evaluateAddressResponse(response dnsmessage.Message, requestDomain string) ([]net.IP, error) { - if response.RCode != dnsmessage.RCodeSuccess { - return nil, fmt.Errorf("rcode is not success: %v", response.RCode) - } +func getIPs(answers []dnsmessage.Resource) []net.IP { var ips []net.IP - if len(response.Answers) == 0 { - return ips, errors.New("no answers") // -1 - } - for _, answer := range response.Answers { + for _, answer := range answers { if answer.Header.Type != dnsmessage.TypeA && answer.Header.Type != dnsmessage.TypeAAAA { continue } @@ -173,25 +94,38 @@ func evaluateAddressResponse(response dnsmessage.Message, requestDomain string) default: continue } + ips = append(ips, ip) + } + return ips +} + +func evaluateAddressResponse(response dnsmessage.Message, requestDomain string) ([]net.IP, error) { + if response.RCode != dnsmessage.RCodeSuccess { + return nil, fmt.Errorf("rcode is not success: %v", response.RCode) + } + if len(response.Answers) == 0 { + return nil, errors.New("no answers") + } + ips := getIPs(response.Answers) + if len(ips) == 0 { + return ips, fmt.Errorf("no ip answer: %v", response.Answers) + } + for _, ip := range ips { if ip.IsLoopback() { - return nil, fmt.Errorf("localhost ip: %v", ip) // -1 + return nil, fmt.Errorf("localhost ip: %v", ip) } if ip.IsPrivate() { - return nil, fmt.Errorf("private ip: %v", ip) // -1 + return nil, fmt.Errorf("private ip: %v", ip) } if ip.IsUnspecified() { - return nil, fmt.Errorf("zero ip: %v", ip) // -1 + return nil, fmt.Errorf("zero ip: %v", ip) } - ips = append(ips, ip) - } - if len(ips) == 0 { - return ips, fmt.Errorf("no ip answer: %v", response.Answers) // -1 } // All popular recursive resolvers we tested maintain the domain case of the request. // Note that this is not the case of authoritative resolvers. Some of them will return // a fully normalized domain name, or normalize part of it. if response.Answers[0].Header.Name.String() != requestDomain { - return ips, fmt.Errorf("domain mismatch: got %v, expected %v", response.Answers[0].Header.Name, requestDomain) // -0.5 or +0.5 if match + return ips, fmt.Errorf("domain mismatch: got %v, expected %v", response.Answers[0].Header.Name, requestDomain) } return ips, nil } @@ -246,9 +180,9 @@ func (f *StrategyFinder) log(format string, a ...any) { } } -func (f *StrategyFinder) testDNSResolver(baseCtx context.Context, resolver dns.Resolver, testDomain string) ([]net.IP, error) { +func (f *StrategyFinder) testDNSResolver(baseCtx context.Context, resolver *smartResolver, testDomain string) ([]net.IP, error) { // We special case the system resolver, since we can't get a dns.RoundTripper. - if resolver == nil { + if resolver.Resolver == nil { ctx, cancel := context.WithTimeout(baseCtx, f.TestTimeout) defer cancel() return evaluateNetResolver(ctx, new(net.Resolver), testDomain) @@ -266,12 +200,22 @@ func (f *StrategyFinder) testDNSResolver(baseCtx context.Context, resolver dns.R if err != nil { return nil, fmt.Errorf("request for A query failed: %w", err) } + + if resolver.Secure { + // For secure DNS, we just need to check if we can communicate with it. + // No need to analyze content, since it is protected by TLS. + return getIPs(response.Answers), nil + } + ips, err := evaluateAddressResponse(*response, requestDomain) if err != nil { return ips, fmt.Errorf("failed A test: %w", err) } + // TODO(fortuna): Consider testing whether we can establish a TCP connection to ip:443. + // Run CNAME test, which helps in case the resolver returns a public IP, as is the + // case in China. q, err = dns.NewQuestion(requestDomain, dnsmessage.TypeCNAME) if err != nil { return nil, fmt.Errorf("failed to create question: %v", err) @@ -290,20 +234,26 @@ func (f *StrategyFinder) testDNSResolver(baseCtx context.Context, resolver dns.R } type httpsEntryJSON struct { - Name string `json:"name,omitempty"` + // Domain name of the host. + Name string `json:"name,omitempty"` + // Host:port. Defaults to Name:443. Address string `json:"address,omitempty"` } type tlsEntryJSON struct { - Name string `json:"name,omitempty"` + // Domain name of the host. + Name string `json:"name,omitempty"` + // Host:port. Defaults to Name:853. Address string `json:"address,omitempty"` } type udpEntryJSON struct { + // Host:port. Address string `json:"address,omitempty"` } type tcpEntryJSON struct { + // Host:port. Address string `json:"address,omitempty"` } @@ -320,12 +270,14 @@ type configJSON struct { TLS []string `json:"tls,omitempty"` } -func (f *StrategyFinder) newDNSResolverFromEntry(entry dnsEntryJSON) (dns.Resolver, error) { +// newDNSResolverFromEntry creates a [dns.Resolver] based on the config, returning the resolver +// a boolean indicating whether the resolver is secure (TLS, HTTPS) and a possible error. +func (f *StrategyFinder) newDNSResolverFromEntry(entry dnsEntryJSON) (dns.Resolver, bool, error) { if entry.System != nil { - return nil, nil + return nil, false, nil } else if cfg := entry.HTTPS; cfg != nil { if cfg.Name == "" { - return nil, fmt.Errorf("https entry has empty server name") + return nil, true, fmt.Errorf("https entry has empty server name") } serverAddr := cfg.Address if serverAddr == "" { @@ -337,10 +289,10 @@ func (f *StrategyFinder) newDNSResolverFromEntry(entry dnsEntryJSON) (dns.Resolv port = "443" } dohURL := url.URL{Scheme: "https", Host: net.JoinHostPort(cfg.Name, port), Path: "/dns-query"} - return dns.NewHTTPSResolver(f.StreamDialer, serverAddr, dohURL.String()), nil + return dns.NewHTTPSResolver(f.StreamDialer, serverAddr, dohURL.String()), true, nil } else if cfg := entry.TLS; cfg != nil { if cfg.Name == "" { - return nil, fmt.Errorf("tls entry has empty server name") + return nil, true, fmt.Errorf("tls entry has empty server name") } serverAddr := cfg.Address if serverAddr == "" { @@ -350,10 +302,10 @@ func (f *StrategyFinder) newDNSResolverFromEntry(entry dnsEntryJSON) (dns.Resolv if err != nil { serverAddr = net.JoinHostPort(serverAddr, "853") } - return dns.NewTLSResolver(f.StreamDialer, serverAddr, cfg.Name), nil + return dns.NewTLSResolver(f.StreamDialer, serverAddr, cfg.Name), true, nil } else if cfg := entry.TCP; cfg != nil { if cfg.Address == "" { - return nil, fmt.Errorf("tcp entry has empty server address") + return nil, false, fmt.Errorf("tcp entry has empty server address") } host, port, err := net.SplitHostPort(cfg.Address) if err != nil { @@ -361,10 +313,10 @@ func (f *StrategyFinder) newDNSResolverFromEntry(entry dnsEntryJSON) (dns.Resolv port = "53" } serverAddr := net.JoinHostPort(host, port) - return dns.NewTCPResolver(f.StreamDialer, serverAddr), nil + return dns.NewTCPResolver(f.StreamDialer, serverAddr), false, nil } else if cfg := entry.UDP; cfg != nil { if cfg.Address == "" { - return nil, fmt.Errorf("udp entry has empty server address") + return nil, false, fmt.Errorf("udp entry has empty server address") } host, port, err := net.SplitHostPort(cfg.Address) if err != nil { @@ -372,33 +324,34 @@ func (f *StrategyFinder) newDNSResolverFromEntry(entry dnsEntryJSON) (dns.Resolv port = "53" } serverAddr := net.JoinHostPort(host, port) - return dns.NewUDPResolver(f.PacketDialer, serverAddr), nil + return dns.NewUDPResolver(f.PacketDialer, serverAddr), false, nil } else { - return nil, errors.New("invalid DNS entry") + return nil, false, errors.New("invalid DNS entry") } } -type resolverEntry struct { - ID string - Resolver dns.Resolver +type smartResolver struct { + dns.Resolver + ID string + Secure bool } -func (f *StrategyFinder) dnsConfigToRoundTrippers(dnsConfig []dnsEntryJSON) ([]resolverEntry, error) { +func (f *StrategyFinder) dnsConfigToResolver(dnsConfig []dnsEntryJSON) ([]*smartResolver, error) { if len(dnsConfig) == 0 { return nil, errors.New("no DNS config entry") } - rts := make([]resolverEntry, 0, len(dnsConfig)) + rts := make([]*smartResolver, 0, len(dnsConfig)) for ei, entry := range dnsConfig { idBytes, err := json.Marshal(entry) if err != nil { return nil, fmt.Errorf("cannot serialize entry %v: %w", ei, err) } id := string(idBytes) - resolver, err := f.newDNSResolverFromEntry(entry) + resolver, isSecure, err := f.newDNSResolverFromEntry(entry) if err != nil { return nil, fmt.Errorf("failed to process entry %v: %w", ei, err) } - rts = append(rts, resolverEntry{ID: id, Resolver: resolver}) + rts = append(rts, &smartResolver{Resolver: resolver, ID: id, Secure: isSecure}) } return rts, nil } @@ -411,13 +364,12 @@ func newDoneContext() context.Context { } func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) (dns.Resolver, error) { - resolvers, err := f.dnsConfigToRoundTrippers(dnsConfig) + resolvers, err := f.dnsConfigToResolver(dnsConfig) if err != nil { return nil, err } type testResult struct { - ID string - Resolver dns.Resolver + Resolver *smartResolver Err error } // Communicates the result of each test. @@ -443,7 +395,7 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) nextResolver++ var waitDone context.CancelFunc waitCtx, waitDone = context.WithTimeout(searchCtx, 250*time.Millisecond) - go func(entry resolverEntry, testDone context.CancelFunc) { + go func(resolver *smartResolver, testDone context.CancelFunc) { defer testDone() for _, testDomain := range testDomains { select { @@ -451,27 +403,27 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) return default: } - f.log("🏃 run dns: %v (domain: %v)\n", entry.ID, testDomain) + f.log("🏃 run dns: %v (domain: %v)\n", resolver.ID, testDomain) startTime := time.Now() - ips, err := f.testDNSResolver(searchCtx, entry.Resolver, testDomain) + ips, err := f.testDNSResolver(searchCtx, resolver, testDomain) duration := time.Since(startTime) status := "ok ✅" if err != nil { status = fmt.Sprintf("%v ❌", err) } - f.log("🏁 got dns: %v (domain: %v), duration=%v, ips=%v, status=%v\n", entry.ID, testDomain, duration, ips, status) + f.log("🏁 got dns: %v (domain: %v), duration=%v, ips=%v, status=%v\n", resolver.ID, testDomain, duration, ips, status) if err != nil { select { case <-searchCtx.Done(): return - case resultChan <- testResult{ID: entry.ID, Resolver: entry.Resolver, Err: err}: + case resultChan <- testResult{Resolver: resolver, Err: err}: return } } } select { case <-searchCtx.Done(): - case resultChan <- testResult{ID: entry.ID, Resolver: entry.Resolver, Err: nil}: + case resultChan <- testResult{Resolver: resolver, Err: nil}: } }(entry, waitDone) @@ -481,13 +433,9 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) if result.Err != nil { continue } - f.log("✅ selected resolver %v\n", result.ID) - // Tested all domains on this resolver. Return - if result.Resolver != nil { - return result.Resolver, nil - } else { - return nil, nil - } + f.log("✅ selected resolver %v\n", result.Resolver.ID) + // Tested all domains on this resolver. Unwrap and return. + return result.Resolver.Resolver, nil } } return nil, errors.New("could not find working resolver") @@ -572,19 +520,22 @@ func (f *StrategyFinder) NewDialer(ctx context.Context, testDomains []string, co testDomains[di] = makeFullyQualified(domain) } - dnsRT, err := f.findDNS(testDomains, parsedConfig.DNS) + resolver, err := f.findDNS(testDomains, parsedConfig.DNS) if err != nil { return nil, err } var dnsDialer transport.StreamDialer - if dnsRT == nil { + if resolver == nil { if _, ok := f.StreamDialer.(*transport.TCPDialer); !ok { return nil, fmt.Errorf("cannot use system resolver with base dialer of type %T", f.StreamDialer) } dnsDialer = f.StreamDialer } else { - dnsRT = newCacheResolver(dnsRT, 100) - dnsDialer = dns.NewStreamDialer(dnsRT, f.StreamDialer) + resolver = newCacheResolver(resolver, 100) + dnsDialer, err = dns.NewStreamDialer(resolver, f.StreamDialer) + if err != nil { + return nil, fmt.Errorf("dns.NewStreamDialer failed: %w", err) + } } if len(parsedConfig.TLS) == 0 { From 83162e977e5caea6e2d688e5663ff48216fd2e43 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 20 Feb 2024 17:49:23 -0500 Subject: [PATCH 05/25] Slight delay --- transport/happyeyeballs_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/transport/happyeyeballs_test.go b/transport/happyeyeballs_test.go index b36711b2..4029123e 100644 --- a/transport/happyeyeballs_test.go +++ b/transport/happyeyeballs_test.go @@ -315,6 +315,7 @@ func ExampleNewParallelHappyEyeballsResolveFunc() { dialer := HappyEyeballsStreamDialer{ Dialer: FuncStreamDialer(func(ctx context.Context, addr string) (StreamConn, error) { ips = append(ips, netip.MustParseAddrPort(addr).Addr()) + time.Sleep(1 * time.Millisecond) return nil, errors.New("not implemented") }), Resolve: NewParallelHappyEyeballsResolveFunc( From d4227d5242222dd0b8d5772b6e7b600aafa144f2 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 20 Feb 2024 18:13:25 -0500 Subject: [PATCH 06/25] Simplify --- x/smart/stream_dialer.go | 58 +++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 31 deletions(-) diff --git a/x/smart/stream_dialer.go b/x/smart/stream_dialer.go index 881a26eb..b0840bfb 100644 --- a/x/smart/stream_dialer.go +++ b/x/smart/stream_dialer.go @@ -356,11 +356,11 @@ func (f *StrategyFinder) dnsConfigToResolver(dnsConfig []dnsEntryJSON) ([]*smart return rts, nil } -// Returns a [context.Context] that is already done. -func newDoneContext() context.Context { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - return ctx +// Returns a read channel that is already closed. +func newClosedChanel() <-chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch } func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) (dns.Resolver, error) { @@ -373,28 +373,28 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) Err error } // Communicates the result of each test. - resultChan := make(chan testResult) + resultChan := make(chan testResult, len(resolvers)) // Indicates to tests that the search is done, so they don't get stuck writing to the results channel that will no longer be read. searchCtx, searchDone := context.WithCancel(context.Background()) defer searchDone() // Used to space out each test. The initial value is done because there's no wait needed. - waitCtx := newDoneContext() - // Next entry to start testing. + waitCh := newClosedChanel() nextResolver := 0 - // How many test entries are not done. - resolversToTest := len(resolvers) - for resolversToTest > 0 { - if nextResolver == len(resolvers) { - // No more tests to start. Make sure the select doesn't trigger on waitCtx. - waitCtx = searchCtx - } + for resolversToTest := len(resolvers); resolversToTest > 0; { select { - case <-waitCtx.Done(): - // Start a new test. - entry := resolvers[nextResolver] + // Ready to start testing another resolver. + case <-waitCh: + resolver := resolvers[nextResolver] nextResolver++ - var waitDone context.CancelFunc - waitCtx, waitDone = context.WithTimeout(searchCtx, 250*time.Millisecond) + + waitCtx, waitDone := context.WithTimeout(searchCtx, 250*time.Millisecond) + if nextResolver == len(resolvers) { + // Done with resolvers. No longer trigger on waitCh. + waitCh = nil + } else { + waitCh = waitCtx.Done() + } + go func(resolver *smartResolver, testDone context.CancelFunc) { defer testDone() for _, testDomain := range testDomains { @@ -403,29 +403,25 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) return default: } + f.log("🏃 run dns: %v (domain: %v)\n", resolver.ID, testDomain) startTime := time.Now() ips, err := f.testDNSResolver(searchCtx, resolver, testDomain) duration := time.Since(startTime) + status := "ok ✅" if err != nil { status = fmt.Sprintf("%v ❌", err) } f.log("🏁 got dns: %v (domain: %v), duration=%v, ips=%v, status=%v\n", resolver.ID, testDomain, duration, ips, status) + if err != nil { - select { - case <-searchCtx.Done(): - return - case resultChan <- testResult{Resolver: resolver, Err: err}: - return - } + resultChan <- testResult{Resolver: resolver, Err: err} + return } } - select { - case <-searchCtx.Done(): - case resultChan <- testResult{Resolver: resolver, Err: nil}: - } - }(entry, waitDone) + resultChan <- testResult{Resolver: resolver, Err: nil} + }(resolver, waitDone) case result := <-resultChan: resolversToTest-- From 97a82de956007b3f58ee1bd1f089f59eb34bfd1f Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 20 Feb 2024 18:36:28 -0500 Subject: [PATCH 07/25] Create raceTests --- x/smart/stream_dialer.go | 113 +++++++++++++++++++++------------------ 1 file changed, 62 insertions(+), 51 deletions(-) diff --git a/x/smart/stream_dialer.go b/x/smart/stream_dialer.go index b0840bfb..b861ddf2 100644 --- a/x/smart/stream_dialer.go +++ b/x/smart/stream_dialer.go @@ -363,78 +363,89 @@ func newClosedChanel() <-chan struct{} { return ch } -func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) (dns.Resolver, error) { - resolvers, err := f.dnsConfigToResolver(dnsConfig) - if err != nil { - return nil, err - } +// raceTests races will call the test function on each entry until it finds an entry for which the test returns nil. +// That entry is returned. A test is only started after the previous test finished or maxWait is done, whichever +// happens first. That way you bound the wait for a test, and they may overlap. +func raceTests[E any](ctx context.Context, maxWait time.Duration, entries []*E, test func(entry *E) error) (*E, error) { type testResult struct { - Resolver *smartResolver - Err error + Entry *E + Err error } // Communicates the result of each test. - resultChan := make(chan testResult, len(resolvers)) - // Indicates to tests that the search is done, so they don't get stuck writing to the results channel that will no longer be read. - searchCtx, searchDone := context.WithCancel(context.Background()) - defer searchDone() - // Used to space out each test. The initial value is done because there's no wait needed. + resultChan := make(chan testResult, len(entries)) waitCh := newClosedChanel() - nextResolver := 0 - for resolversToTest := len(resolvers); resolversToTest > 0; { + + next := 0 + for toTest := len(entries); toTest > 0; { select { // Ready to start testing another resolver. case <-waitCh: - resolver := resolvers[nextResolver] - nextResolver++ + entry := entries[next] + next++ - waitCtx, waitDone := context.WithTimeout(searchCtx, 250*time.Millisecond) - if nextResolver == len(resolvers) { - // Done with resolvers. No longer trigger on waitCh. + waitCtx, waitDone := context.WithTimeout(ctx, 250*time.Millisecond) + if next == len(entries) { + // Done with entries. No longer trigger on waitCh. waitCh = nil } else { waitCh = waitCtx.Done() } - go func(resolver *smartResolver, testDone context.CancelFunc) { + go func(entry *E, testDone context.CancelFunc) { defer testDone() - for _, testDomain := range testDomains { - select { - case <-searchCtx.Done(): - return - default: - } - - f.log("🏃 run dns: %v (domain: %v)\n", resolver.ID, testDomain) - startTime := time.Now() - ips, err := f.testDNSResolver(searchCtx, resolver, testDomain) - duration := time.Since(startTime) - - status := "ok ✅" - if err != nil { - status = fmt.Sprintf("%v ❌", err) - } - f.log("🏁 got dns: %v (domain: %v), duration=%v, ips=%v, status=%v\n", resolver.ID, testDomain, duration, ips, status) - - if err != nil { - resultChan <- testResult{Resolver: resolver, Err: err} - return - } - } - resultChan <- testResult{Resolver: resolver, Err: nil} - }(resolver, waitDone) + err := test(entry) + resultChan <- testResult{Entry: entry, Err: err} + }(entry, waitDone) case result := <-resultChan: - resolversToTest-- - // Process the result of a test. + toTest-- if result.Err != nil { continue } - f.log("✅ selected resolver %v\n", result.Resolver.ID) - // Tested all domains on this resolver. Unwrap and return. - return result.Resolver.Resolver, nil + return result.Entry, nil } } - return nil, errors.New("could not find working resolver") + return nil, errors.New("all tests failed") +} + +func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) (dns.Resolver, error) { + resolvers, err := f.dnsConfigToResolver(dnsConfig) + if err != nil { + return nil, err + } + + ctx, searchDone := context.WithCancel(context.Background()) + defer searchDone() + resolver, err := raceTests[smartResolver](ctx, 250*time.Millisecond, resolvers, func(resolver *smartResolver) error { + for _, testDomain := range testDomains { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + f.log("🏃 run dns: %v (domain: %v)\n", resolver.ID, testDomain) + startTime := time.Now() + ips, err := f.testDNSResolver(ctx, resolver, testDomain) + duration := time.Since(startTime) + + status := "ok ✅" + if err != nil { + status = fmt.Sprintf("%v ❌", err) + } + f.log("🏁 got dns: %v (domain: %v), duration=%v, ips=%v, status=%v\n", resolver.ID, testDomain, duration, ips, status) + + if err != nil { + return err + } + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("could not find working resolver: %w", err) + } + f.log("✅ selected resolver %v\n", resolver.ID) + return resolver.Resolver, nil } func (f *StrategyFinder) findTLS(testDomains []string, baseDialer transport.StreamDialer, tlsConfig []string) (transport.StreamDialer, error) { From 4e93fe59f6ccd75da838d3b42b8da654541b3cdf Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 20 Feb 2024 19:11:22 -0500 Subject: [PATCH 08/25] Race TLS --- x/smart/stream_dialer.go | 111 +++++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 50 deletions(-) diff --git a/x/smart/stream_dialer.go b/x/smart/stream_dialer.go index b861ddf2..447203b8 100644 --- a/x/smart/stream_dialer.go +++ b/x/smart/stream_dialer.go @@ -366,10 +366,10 @@ func newClosedChanel() <-chan struct{} { // raceTests races will call the test function on each entry until it finds an entry for which the test returns nil. // That entry is returned. A test is only started after the previous test finished or maxWait is done, whichever // happens first. That way you bound the wait for a test, and they may overlap. -func raceTests[E any](ctx context.Context, maxWait time.Duration, entries []*E, test func(entry *E) error) (*E, error) { +func raceTests[E any, R any](ctx context.Context, maxWait time.Duration, entries []E, test func(entry E) (R, error)) (R, error) { type testResult struct { - Entry *E - Err error + Result R + Err error } // Communicates the result of each test. resultChan := make(chan testResult, len(entries)) @@ -378,6 +378,11 @@ func raceTests[E any](ctx context.Context, maxWait time.Duration, entries []*E, next := 0 for toTest := len(entries); toTest > 0; { select { + // Search cancelled, quit. + case <-ctx.Done(): + var empty R + return empty, ctx.Err() + // Ready to start testing another resolver. case <-waitCh: entry := entries[next] @@ -391,21 +396,23 @@ func raceTests[E any](ctx context.Context, maxWait time.Duration, entries []*E, waitCh = waitCtx.Done() } - go func(entry *E, testDone context.CancelFunc) { + go func(entry E, testDone context.CancelFunc) { defer testDone() - err := test(entry) - resultChan <- testResult{Entry: entry, Err: err} + result, err := test(entry) + resultChan <- testResult{Result: result, Err: err} }(entry, waitDone) + // Got a test result. case result := <-resultChan: toTest-- if result.Err != nil { continue } - return result.Entry, nil + return result.Result, nil } } - return nil, errors.New("all tests failed") + var empty R + return empty, errors.New("all tests failed") } func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) (dns.Resolver, error) { @@ -416,15 +423,15 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) ctx, searchDone := context.WithCancel(context.Background()) defer searchDone() - resolver, err := raceTests[smartResolver](ctx, 250*time.Millisecond, resolvers, func(resolver *smartResolver) error { + resolver, err := raceTests[*smartResolver](ctx, 250*time.Millisecond, resolvers, func(resolver *smartResolver) (*smartResolver, error) { for _, testDomain := range testDomains { select { case <-ctx.Done(): - return ctx.Err() + return nil, ctx.Err() default: } - f.log("🏃 run dns: %v (domain: %v)\n", resolver.ID, testDomain) + f.log("🏃 run DNS: %v (domain: %v)\n", resolver.ID, testDomain) startTime := time.Now() ips, err := f.testDNSResolver(ctx, resolver, testDomain) duration := time.Since(startTime) @@ -433,18 +440,18 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) if err != nil { status = fmt.Sprintf("%v ❌", err) } - f.log("🏁 got dns: %v (domain: %v), duration=%v, ips=%v, status=%v\n", resolver.ID, testDomain, duration, ips, status) + f.log("🏁 got DNS: %v (domain: %v), duration=%v, ips=%v, status=%v\n", resolver.ID, testDomain, duration, ips, status) if err != nil { - return err + return nil, err } } - return nil + return resolver, nil }) if err != nil { return nil, fmt.Errorf("could not find working resolver: %w", err) } - f.log("✅ selected resolver %v\n", resolver.ID) + f.log("🏆 selected DNS resolver %v\n", resolver.ID) return resolver.Resolver, nil } @@ -452,53 +459,57 @@ func (f *StrategyFinder) findTLS(testDomains []string, baseDialer transport.Stre if len(tlsConfig) == 0 { return nil, errors.New("config for TLS is empty. Please specify at least one transport") } - for _, transportCfg := range tlsConfig { - for di, testDomain := range testDomains { + + ctx, searchDone := context.WithCancel(context.Background()) + defer searchDone() + tlsDialer, err := raceTests(ctx, 250*time.Millisecond, tlsConfig, func(transportCfg string) (transport.StreamDialer, error) { + tlsDialer, err := config.WrapStreamDialer(baseDialer, transportCfg) + if err != nil { + return nil, fmt.Errorf("WrapStreamDialer failed: %w", err) + } + for _, testDomain := range testDomains { + startTime := time.Now() + testAddr := net.JoinHostPort(testDomain, "443") - f.log(" tls=%v (domain: %v)", transportCfg, testDomain) + f.log("🏃 run TLS: '%v' (domain: %v)\n", transportCfg, testDomain) - tlsDialer, err := config.WrapStreamDialer(baseDialer, transportCfg) - if err != nil { - f.log("; wrap_error=%v ❌\n", err) - break - } ctx, cancel := context.WithTimeout(context.Background(), f.TestTimeout) defer cancel() testConn, err := tlsDialer.DialStream(ctx, testAddr) if err != nil { - f.log("; dial_error=%v ❌\n", err) - break + f.log("🏁 got TLS: '%v' (domain: %v), duration=%v, dial_error=%v ❌\n", transportCfg, testDomain, time.Since(startTime), err) + return nil, err } tlsConn := tls.Client(testConn, &tls.Config{ServerName: testDomain}) err = tlsConn.HandshakeContext(ctx) tlsConn.Close() if err != nil { - f.log("; handshake=%v ❌\n", err) - break + f.log("🏁 got TLS: '%v' (domain: %v), duration=%v, handshake=%v ❌\n", transportCfg, testDomain, time.Since(startTime), err) + return nil, err } - f.log("; status=ok ✅\n") - if di+1 < len(testDomains) { - // More domains to test - continue - } - return transport.FuncStreamDialer(func(ctx context.Context, raddr string) (transport.StreamConn, error) { - _, portStr, err := net.SplitHostPort(raddr) - if err != nil { - return nil, fmt.Errorf("failed to parse address: %w", err) - } - portNum, err := net.DefaultResolver.LookupPort(ctx, "tcp", portStr) - if err != nil { - return nil, fmt.Errorf("could not resolve port: %w", err) - } - selectedDialer := baseDialer - if portNum == 443 || portNum == 853 { - selectedDialer = tlsDialer - } - return selectedDialer.DialStream(ctx, raddr) - }), nil - } - } - return nil, errors.New("could not find TLS strategy") + f.log("🏁 got TLS: '%v' (domain: %v), duration=%v, status=ok ✅\n", transportCfg, testDomain, time.Since(startTime)) + } + f.log("🏆 selected TLS strategy '%v'\n", transportCfg) + return tlsDialer, nil + }) + if err != nil { + return nil, fmt.Errorf("could not find TLS strategy: %w", err) + } + return transport.FuncStreamDialer(func(ctx context.Context, raddr string) (transport.StreamConn, error) { + _, portStr, err := net.SplitHostPort(raddr) + if err != nil { + return nil, fmt.Errorf("failed to parse address: %w", err) + } + portNum, err := net.DefaultResolver.LookupPort(ctx, "tcp", portStr) + if err != nil { + return nil, fmt.Errorf("could not resolve port: %w", err) + } + selectedDialer := baseDialer + if portNum == 443 || portNum == 853 { + selectedDialer = tlsDialer + } + return selectedDialer.DialStream(ctx, raddr) + }), nil } // makeFullyQualified makes the domain fully-qualified, ending on a dot ("."). From cf6f50f075ab6ea78535d0e4a0c51e30137fa565 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 20 Feb 2024 19:39:25 -0500 Subject: [PATCH 09/25] Improve output --- x/examples/smart-proxy/main.go | 2 + x/smart/dns.go | 217 +++++++++++++++++++++++++++ x/smart/doc.go | 18 +++ x/smart/racer.go | 80 ++++++++++ x/smart/stream_dialer.go | 262 +-------------------------------- 5 files changed, 323 insertions(+), 256 deletions(-) create mode 100644 x/smart/dns.go create mode 100644 x/smart/doc.go create mode 100644 x/smart/racer.go diff --git a/x/examples/smart-proxy/main.go b/x/examples/smart-proxy/main.go index 6db27d1c..b9581a4f 100644 --- a/x/examples/smart-proxy/main.go +++ b/x/examples/smart-proxy/main.go @@ -102,10 +102,12 @@ func main() { } fmt.Println("Finding strategy") + startTime := time.Now() dialer, err := finder.NewDialer(context.Background(), domainsFlag, finderConfig) if err != nil { log.Fatalf("Failed to find dialer: %v", err) } + fmt.Printf("Found strategy in %0.2fs\n", time.Since(startTime).Seconds()) logDialer := transport.FuncStreamDialer(func(ctx context.Context, address string) (transport.StreamConn, error) { conn, err := dialer.DialStream(ctx, address) if err != nil { diff --git a/x/smart/dns.go b/x/smart/dns.go new file mode 100644 index 00000000..985ee2dc --- /dev/null +++ b/x/smart/dns.go @@ -0,0 +1,217 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package smart + +import ( + "context" + "errors" + "fmt" + "math/rand" + "net" + "time" + "unicode" + + "github.com/Jigsaw-Code/outline-sdk/dns" + "golang.org/x/net/dns/dnsmessage" +) + +// makeFullyQualified makes the domain fully-qualified, ending on a dot ("."). +// This is useful in domain resolution to avoid ambiguity with local domains +// and domain search. +func makeFullyQualified(domain string) string { + if len(domain) > 0 && domain[len(domain)-1] == '.' { + return domain + } + return domain + "." +} + +// mixCase randomizes the case of the domain letters. +func mixCase(domain string) string { + var mixed []rune + for _, r := range domain { + if rand.Intn(2) == 0 { + mixed = append(mixed, unicode.ToLower(r)) + } else { + mixed = append(mixed, unicode.ToUpper(r)) + } + } + return string(mixed) +} + +func evaluateNetResolver(ctx context.Context, resolver *net.Resolver, testDomain string) ([]net.IP, error) { + requestDomain := mixCase(testDomain) + _, err := lookupCNAME(ctx, requestDomain) + if err != nil { + return nil, fmt.Errorf("could not get cname: %w", err) + } + ips, err := resolver.LookupIP(ctx, "ip", requestDomain) + if err != nil { + return nil, fmt.Errorf("failed to lookup IPs: %w", err) + } + if len(ips) == 0 { + return nil, fmt.Errorf("no ip answer") + } + for _, ip := range ips { + if ip.IsLoopback() { + return nil, fmt.Errorf("localhost ip: %v", ip) // -1 + } + if ip.IsPrivate() { + return nil, fmt.Errorf("private ip: %v", ip) // -1 + } + if ip.IsUnspecified() { + return nil, fmt.Errorf("zero ip: %v", ip) // -1 + } + // TODO: consider validating the IPs: fingerprint, hardcoded ground truth, trusted response, TLS connection. + } + return ips, nil +} + +func getIPs(answers []dnsmessage.Resource) []net.IP { + var ips []net.IP + for _, answer := range answers { + if answer.Header.Type != dnsmessage.TypeA && answer.Header.Type != dnsmessage.TypeAAAA { + continue + } + var ip net.IP + switch rr := answer.Body.(type) { + case *dnsmessage.AResource: + ip = net.IP(rr.A[:]) + case *dnsmessage.AAAAResource: + ip = net.IP(rr.AAAA[:]) + default: + continue + } + ips = append(ips, ip) + } + return ips +} + +func evaluateAddressResponse(response dnsmessage.Message, requestDomain string) ([]net.IP, error) { + if response.RCode != dnsmessage.RCodeSuccess { + return nil, fmt.Errorf("rcode is not success: %v", response.RCode) + } + if len(response.Answers) == 0 { + return nil, errors.New("no answers") + } + ips := getIPs(response.Answers) + if len(ips) == 0 { + return ips, fmt.Errorf("no ip answer: %v", response.Answers) + } + for _, ip := range ips { + if ip.IsLoopback() { + return nil, fmt.Errorf("localhost ip: %v", ip) + } + if ip.IsPrivate() { + return nil, fmt.Errorf("private ip: %v", ip) + } + if ip.IsUnspecified() { + return nil, fmt.Errorf("zero ip: %v", ip) + } + } + // All popular recursive resolvers we tested maintain the domain case of the request. + // Note that this is not the case of authoritative resolvers. Some of them will return + // a fully normalized domain name, or normalize part of it. + if response.Answers[0].Header.Name.String() != requestDomain { + return ips, fmt.Errorf("domain mismatch: got %v, expected %v", response.Answers[0].Header.Name, requestDomain) + } + return ips, nil +} + +func evaluateCNAMEResponse(response dnsmessage.Message, requestDomain string) error { + if response.RCode != dnsmessage.RCodeSuccess { + return fmt.Errorf("rcode is not success: %v", response.RCode) + } + if len(response.Answers) == 0 { + var numSOA int + for _, answer := range response.Authorities { + if _, ok := answer.Body.(*dnsmessage.SOAResource); ok { + numSOA++ + } + } + if numSOA != 1 { + return fmt.Errorf("SOA records is %v, expected 1", numSOA) + } + return nil + } + var cname string + for _, answer := range response.Answers { + if answer.Header.Type != dnsmessage.TypeCNAME { + return fmt.Errorf("bad answer type: %v", answer.Header.Type) + } + if rr, ok := answer.Body.(*dnsmessage.CNAMEResource); ok { + if cname != "" { + return fmt.Errorf("found too many CNAMEs: %v %v", cname, rr.CNAME) + } + cname = rr.CNAME.String() + } + } + if cname == "" { + return fmt.Errorf("no CNAME in answers") + } + return nil +} + +func testDNSResolver(baseCtx context.Context, oneTestTimeout time.Duration, resolver *smartResolver, testDomain string) ([]net.IP, error) { + // We special case the system resolver, since we can't get a dns.RoundTripper. + if resolver.Resolver == nil { + ctx, cancel := context.WithTimeout(baseCtx, oneTestTimeout) + defer cancel() + return evaluateNetResolver(ctx, new(net.Resolver), testDomain) + } + + requestDomain := mixCase(testDomain) + + q, err := dns.NewQuestion(requestDomain, dnsmessage.TypeA) + if err != nil { + return nil, fmt.Errorf("failed to create question: %v", err) + } + ctxA, cancelA := context.WithTimeout(baseCtx, oneTestTimeout) + defer cancelA() + response, err := resolver.Query(ctxA, *q) + if err != nil { + return nil, fmt.Errorf("request for A query failed: %w", err) + } + + if resolver.Secure { + // For secure DNS, we just need to check if we can communicate with it. + // No need to analyze content, since it is protected by TLS. + return getIPs(response.Answers), nil + } + + ips, err := evaluateAddressResponse(*response, requestDomain) + if err != nil { + return ips, fmt.Errorf("failed A test: %w", err) + } + + // TODO(fortuna): Consider testing whether we can establish a TCP connection to ip:443. + + // Run CNAME test, which helps in case the resolver returns a public IP, as is the + // case in China. + q, err = dns.NewQuestion(requestDomain, dnsmessage.TypeCNAME) + if err != nil { + return nil, fmt.Errorf("failed to create question: %v", err) + } + ctxCNAME, cancelCNAME := context.WithTimeout(baseCtx, oneTestTimeout) + defer cancelCNAME() + response, err = resolver.Query(ctxCNAME, *q) + if err != nil { + return nil, fmt.Errorf("request for CNAME query failed: %w", err) + } + err = evaluateCNAMEResponse(*response, requestDomain) + if err != nil { + return nil, fmt.Errorf("failed CNAME test: %w", err) + } + return ips, nil +} diff --git a/x/smart/doc.go b/x/smart/doc.go new file mode 100644 index 00000000..6ca22dd0 --- /dev/null +++ b/x/smart/doc.go @@ -0,0 +1,18 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* +Package smart provides utilities to dynamically find serverless strategies for circumvention. +*/ +package smart diff --git a/x/smart/racer.go b/x/smart/racer.go new file mode 100644 index 00000000..c66037d9 --- /dev/null +++ b/x/smart/racer.go @@ -0,0 +1,80 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package smart + +import ( + "context" + "errors" + "time" +) + +// Returns a read channel that is already closed. +func newClosedChanel() <-chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch +} + +// raceTests races will call the test function on each entry until it finds an entry for which the test returns nil. +// That entry is returned. A test is only started after the previous test finished or maxWait is done, whichever +// happens first. That way you bound the wait for a test, and they may overlap. +func raceTests[E any, R any](ctx context.Context, maxWait time.Duration, entries []E, test func(entry E) (R, error)) (R, error) { + type testResult struct { + Result R + Err error + } + // Communicates the result of each test. + resultChan := make(chan testResult, len(entries)) + waitCh := newClosedChanel() + + next := 0 + for toTest := len(entries); toTest > 0; { + select { + // Search cancelled, quit. + case <-ctx.Done(): + var empty R + return empty, ctx.Err() + + // Ready to start testing another resolver. + case <-waitCh: + entry := entries[next] + next++ + + waitCtx, waitDone := context.WithTimeout(ctx, 250*time.Millisecond) + if next == len(entries) { + // Done with entries. No longer trigger on waitCh. + waitCh = nil + } else { + waitCh = waitCtx.Done() + } + + go func(entry E, testDone context.CancelFunc) { + defer testDone() + result, err := test(entry) + resultChan <- testResult{Result: result, Err: err} + }(entry, waitDone) + + // Got a test result. + case result := <-resultChan: + toTest-- + if result.Err != nil { + continue + } + return result.Result, nil + } + } + var empty R + return empty, errors.New("all tests failed") +} diff --git a/x/smart/stream_dialer.go b/x/smart/stream_dialer.go index 447203b8..9a6bb0d2 100644 --- a/x/smart/stream_dialer.go +++ b/x/smart/stream_dialer.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package smart provides utilities to dynamically find serverless strategies for circumvention. package smart import ( @@ -22,148 +21,19 @@ import ( "errors" "fmt" "io" - "math/rand" "net" "net/url" "sync" "time" - "unicode" "github.com/Jigsaw-Code/outline-sdk/dns" "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/x/config" - "golang.org/x/net/dns/dnsmessage" ) // To test one strategy: // go run ./x/examples/smart-proxy -v -localAddr=localhost:1080 --transport="" --domain www.rferl.org --config=<(echo '{"dns": [{"https": {"name": "doh.sb"}}]}') -// mixCase randomizes the case of the domain letters. -func mixCase(domain string) string { - var mixed []rune - for _, r := range domain { - if rand.Intn(2) == 0 { - mixed = append(mixed, unicode.ToLower(r)) - } else { - mixed = append(mixed, unicode.ToUpper(r)) - } - } - return string(mixed) -} - -func evaluateNetResolver(ctx context.Context, resolver *net.Resolver, testDomain string) ([]net.IP, error) { - requestDomain := mixCase(testDomain) - _, err := lookupCNAME(ctx, requestDomain) - if err != nil { - return nil, fmt.Errorf("could not get cname: %w", err) - } - ips, err := resolver.LookupIP(ctx, "ip", requestDomain) - if err != nil { - return nil, fmt.Errorf("failed to lookup IPs: %w", err) - } - if len(ips) == 0 { - return nil, fmt.Errorf("no ip answer") - } - for _, ip := range ips { - if ip.IsLoopback() { - return nil, fmt.Errorf("localhost ip: %v", ip) // -1 - } - if ip.IsPrivate() { - return nil, fmt.Errorf("private ip: %v", ip) // -1 - } - if ip.IsUnspecified() { - return nil, fmt.Errorf("zero ip: %v", ip) // -1 - } - // TODO: consider validating the IPs: fingerprint, hardcoded ground truth, trusted response, TLS connection. - } - return ips, nil -} - -func getIPs(answers []dnsmessage.Resource) []net.IP { - var ips []net.IP - for _, answer := range answers { - if answer.Header.Type != dnsmessage.TypeA && answer.Header.Type != dnsmessage.TypeAAAA { - continue - } - var ip net.IP - switch rr := answer.Body.(type) { - case *dnsmessage.AResource: - ip = net.IP(rr.A[:]) - case *dnsmessage.AAAAResource: - ip = net.IP(rr.AAAA[:]) - default: - continue - } - ips = append(ips, ip) - } - return ips -} - -func evaluateAddressResponse(response dnsmessage.Message, requestDomain string) ([]net.IP, error) { - if response.RCode != dnsmessage.RCodeSuccess { - return nil, fmt.Errorf("rcode is not success: %v", response.RCode) - } - if len(response.Answers) == 0 { - return nil, errors.New("no answers") - } - ips := getIPs(response.Answers) - if len(ips) == 0 { - return ips, fmt.Errorf("no ip answer: %v", response.Answers) - } - for _, ip := range ips { - if ip.IsLoopback() { - return nil, fmt.Errorf("localhost ip: %v", ip) - } - if ip.IsPrivate() { - return nil, fmt.Errorf("private ip: %v", ip) - } - if ip.IsUnspecified() { - return nil, fmt.Errorf("zero ip: %v", ip) - } - } - // All popular recursive resolvers we tested maintain the domain case of the request. - // Note that this is not the case of authoritative resolvers. Some of them will return - // a fully normalized domain name, or normalize part of it. - if response.Answers[0].Header.Name.String() != requestDomain { - return ips, fmt.Errorf("domain mismatch: got %v, expected %v", response.Answers[0].Header.Name, requestDomain) - } - return ips, nil -} - -func evaluateCNAMEResponse(response dnsmessage.Message, requestDomain string) error { - if response.RCode != dnsmessage.RCodeSuccess { - return fmt.Errorf("rcode is not success: %v", response.RCode) - } - if len(response.Answers) == 0 { - var numSOA int - for _, answer := range response.Authorities { - if _, ok := answer.Body.(*dnsmessage.SOAResource); ok { - numSOA++ - } - } - if numSOA != 1 { - return fmt.Errorf("SOA records is %v, expected 1", numSOA) - } - return nil - } - var cname string - for _, answer := range response.Answers { - if answer.Header.Type != dnsmessage.TypeCNAME { - return fmt.Errorf("bad answer type: %v", answer.Header.Type) - } - if rr, ok := answer.Body.(*dnsmessage.CNAMEResource); ok { - if cname != "" { - return fmt.Errorf("found too many CNAMEs: %v %v", cname, rr.CNAME) - } - cname = rr.CNAME.String() - } - } - if cname == "" { - return fmt.Errorf("no CNAME in answers") - } - return nil -} - type StrategyFinder struct { TestTimeout time.Duration LogWriter io.Writer @@ -180,59 +50,6 @@ func (f *StrategyFinder) log(format string, a ...any) { } } -func (f *StrategyFinder) testDNSResolver(baseCtx context.Context, resolver *smartResolver, testDomain string) ([]net.IP, error) { - // We special case the system resolver, since we can't get a dns.RoundTripper. - if resolver.Resolver == nil { - ctx, cancel := context.WithTimeout(baseCtx, f.TestTimeout) - defer cancel() - return evaluateNetResolver(ctx, new(net.Resolver), testDomain) - } - - requestDomain := mixCase(testDomain) - - q, err := dns.NewQuestion(requestDomain, dnsmessage.TypeA) - if err != nil { - return nil, fmt.Errorf("failed to create question: %v", err) - } - ctxA, cancelA := context.WithTimeout(baseCtx, f.TestTimeout) - defer cancelA() - response, err := resolver.Query(ctxA, *q) - if err != nil { - return nil, fmt.Errorf("request for A query failed: %w", err) - } - - if resolver.Secure { - // For secure DNS, we just need to check if we can communicate with it. - // No need to analyze content, since it is protected by TLS. - return getIPs(response.Answers), nil - } - - ips, err := evaluateAddressResponse(*response, requestDomain) - if err != nil { - return ips, fmt.Errorf("failed A test: %w", err) - } - - // TODO(fortuna): Consider testing whether we can establish a TCP connection to ip:443. - - // Run CNAME test, which helps in case the resolver returns a public IP, as is the - // case in China. - q, err = dns.NewQuestion(requestDomain, dnsmessage.TypeCNAME) - if err != nil { - return nil, fmt.Errorf("failed to create question: %v", err) - } - ctxCNAME, cancelCNAME := context.WithTimeout(baseCtx, f.TestTimeout) - defer cancelCNAME() - response, err = resolver.Query(ctxCNAME, *q) - if err != nil { - return nil, fmt.Errorf("request for CNAME query failed: %w", err) - } - err = evaluateCNAMEResponse(*response, requestDomain) - if err != nil { - return nil, fmt.Errorf("failed CNAME test: %w", err) - } - return ips, nil -} - type httpsEntryJSON struct { // Domain name of the host. Name string `json:"name,omitempty"` @@ -356,65 +173,6 @@ func (f *StrategyFinder) dnsConfigToResolver(dnsConfig []dnsEntryJSON) ([]*smart return rts, nil } -// Returns a read channel that is already closed. -func newClosedChanel() <-chan struct{} { - ch := make(chan struct{}) - close(ch) - return ch -} - -// raceTests races will call the test function on each entry until it finds an entry for which the test returns nil. -// That entry is returned. A test is only started after the previous test finished or maxWait is done, whichever -// happens first. That way you bound the wait for a test, and they may overlap. -func raceTests[E any, R any](ctx context.Context, maxWait time.Duration, entries []E, test func(entry E) (R, error)) (R, error) { - type testResult struct { - Result R - Err error - } - // Communicates the result of each test. - resultChan := make(chan testResult, len(entries)) - waitCh := newClosedChanel() - - next := 0 - for toTest := len(entries); toTest > 0; { - select { - // Search cancelled, quit. - case <-ctx.Done(): - var empty R - return empty, ctx.Err() - - // Ready to start testing another resolver. - case <-waitCh: - entry := entries[next] - next++ - - waitCtx, waitDone := context.WithTimeout(ctx, 250*time.Millisecond) - if next == len(entries) { - // Done with entries. No longer trigger on waitCh. - waitCh = nil - } else { - waitCh = waitCtx.Done() - } - - go func(entry E, testDone context.CancelFunc) { - defer testDone() - result, err := test(entry) - resultChan <- testResult{Result: result, Err: err} - }(entry, waitDone) - - // Got a test result. - case result := <-resultChan: - toTest-- - if result.Err != nil { - continue - } - return result.Result, nil - } - } - var empty R - return empty, errors.New("all tests failed") -} - func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) (dns.Resolver, error) { resolvers, err := f.dnsConfigToResolver(dnsConfig) if err != nil { @@ -423,6 +181,7 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) ctx, searchDone := context.WithCancel(context.Background()) defer searchDone() + raceStart := time.Now() resolver, err := raceTests[*smartResolver](ctx, 250*time.Millisecond, resolvers, func(resolver *smartResolver) (*smartResolver, error) { for _, testDomain := range testDomains { select { @@ -433,7 +192,7 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) f.log("🏃 run DNS: %v (domain: %v)\n", resolver.ID, testDomain) startTime := time.Now() - ips, err := f.testDNSResolver(ctx, resolver, testDomain) + ips, err := testDNSResolver(ctx, f.TestTimeout, resolver, testDomain) duration := time.Since(startTime) status := "ok ✅" @@ -451,7 +210,7 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) if err != nil { return nil, fmt.Errorf("could not find working resolver: %w", err) } - f.log("🏆 selected DNS resolver %v\n", resolver.ID) + f.log("🏆 selected DNS resolver %v in %0.2f\n", resolver.ID, time.Since(raceStart).Seconds()) return resolver.Resolver, nil } @@ -462,6 +221,7 @@ func (f *StrategyFinder) findTLS(testDomains []string, baseDialer transport.Stre ctx, searchDone := context.WithCancel(context.Background()) defer searchDone() + raceStart := time.Now() tlsDialer, err := raceTests(ctx, 250*time.Millisecond, tlsConfig, func(transportCfg string) (transport.StreamDialer, error) { tlsDialer, err := config.WrapStreamDialer(baseDialer, transportCfg) if err != nil { @@ -489,7 +249,7 @@ func (f *StrategyFinder) findTLS(testDomains []string, baseDialer transport.Stre } f.log("🏁 got TLS: '%v' (domain: %v), duration=%v, status=ok ✅\n", transportCfg, testDomain, time.Since(startTime)) } - f.log("🏆 selected TLS strategy '%v'\n", transportCfg) + f.log("🏆 selected TLS strategy '%v' in %0.2fs\n", transportCfg, time.Since(raceStart).Seconds()) return tlsDialer, nil }) if err != nil { @@ -512,17 +272,7 @@ func (f *StrategyFinder) findTLS(testDomains []string, baseDialer transport.Stre }), nil } -// makeFullyQualified makes the domain fully-qualified, ending on a dot ("."). -// This is useful in domain resolution to avoid ambiguity with local domains -// and domain search. -func makeFullyQualified(domain string) string { - if len(domain) > 0 && domain[len(domain)-1] == '.' { - return domain - } - return domain + "." -} - -// NewDialer uses the config in configBytes to search for a strategy that unblocks all of the testDomains, returning a dialer with the found strategy. +// NewDialer uses the config in configBytes to search for a strategy that unblocks DNS and TLS for all of the testDomains, returning a dialer with the found strategy. // It returns an error if no strategy was found that unblocks the testDomains. // The testDomains must be domains with a TLS service running on port 443. func (f *StrategyFinder) NewDialer(ctx context.Context, testDomains []string, configBytes []byte) (transport.StreamDialer, error) { From 2ba3ba8f1109babce5dc2cf8253de240f2b6d0a9 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 20 Feb 2024 19:46:15 -0500 Subject: [PATCH 10/25] Add comment --- transport/happyeyeballs_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/transport/happyeyeballs_test.go b/transport/happyeyeballs_test.go index 4029123e..5eb5502b 100644 --- a/transport/happyeyeballs_test.go +++ b/transport/happyeyeballs_test.go @@ -315,6 +315,7 @@ func ExampleNewParallelHappyEyeballsResolveFunc() { dialer := HappyEyeballsStreamDialer{ Dialer: FuncStreamDialer(func(ctx context.Context, addr string) (StreamConn, error) { ips = append(ips, netip.MustParseAddrPort(addr).Addr()) + // Add a slight delay to simulate a more real life ordering. time.Sleep(1 * time.Millisecond) return nil, errors.New("not implemented") }), From 80cca1fd74b7df46b2ed7a1b4bb4e357d92fdfe1 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 20 Feb 2024 19:55:25 -0500 Subject: [PATCH 11/25] Fix unit --- x/smart/stream_dialer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x/smart/stream_dialer.go b/x/smart/stream_dialer.go index 9a6bb0d2..3893c62d 100644 --- a/x/smart/stream_dialer.go +++ b/x/smart/stream_dialer.go @@ -210,7 +210,7 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) if err != nil { return nil, fmt.Errorf("could not find working resolver: %w", err) } - f.log("🏆 selected DNS resolver %v in %0.2f\n", resolver.ID, time.Since(raceStart).Seconds()) + f.log("🏆 selected DNS resolver %v in %0.2fs\n", resolver.ID, time.Since(raceStart).Seconds()) return resolver.Resolver, nil } From 5b0966e3322ad874707afae1480268caeb56ea95 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 20 Feb 2024 19:59:47 -0500 Subject: [PATCH 12/25] Cleaner log --- x/smart/stream_dialer.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/x/smart/stream_dialer.go b/x/smart/stream_dialer.go index 3893c62d..01d45dd0 100644 --- a/x/smart/stream_dialer.go +++ b/x/smart/stream_dialer.go @@ -199,7 +199,12 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) if err != nil { status = fmt.Sprintf("%v ❌", err) } - f.log("🏁 got DNS: %v (domain: %v), duration=%v, ips=%v, status=%v\n", resolver.ID, testDomain, duration, ips, status) + select { + case <-ctx.Done(): + default: + // Only output log if the search is not done yet. + f.log("🏁 got DNS: %v (domain: %v), duration=%v, ips=%v, status=%v\n", resolver.ID, testDomain, duration, ips, status) + } if err != nil { return nil, err From 0a2c16d1f8f721444181ec44fb43ab31cc578fd2 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 20 Feb 2024 20:13:47 -0500 Subject: [PATCH 13/25] Remove split 5 --- x/examples/smart-proxy/config.json | 1 - 1 file changed, 1 deletion(-) diff --git a/x/examples/smart-proxy/config.json b/x/examples/smart-proxy/config.json index 082c3911..8fa3a168 100644 --- a/x/examples/smart-proxy/config.json +++ b/x/examples/smart-proxy/config.json @@ -107,7 +107,6 @@ "", "split:1", "split:2", - "split:5", "tlsfrag:1" ] } From a4a873bd241e2cc03b4b4bbda7b90e099bc83471 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 20 Feb 2024 20:15:32 -0500 Subject: [PATCH 14/25] Add line --- x/smart/stream_dialer.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x/smart/stream_dialer.go b/x/smart/stream_dialer.go index 01d45dd0..1fd335b3 100644 --- a/x/smart/stream_dialer.go +++ b/x/smart/stream_dialer.go @@ -215,7 +215,7 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) if err != nil { return nil, fmt.Errorf("could not find working resolver: %w", err) } - f.log("🏆 selected DNS resolver %v in %0.2fs\n", resolver.ID, time.Since(raceStart).Seconds()) + f.log("🏆 selected DNS resolver %v in %0.2fs\n\n", resolver.ID, time.Since(raceStart).Seconds()) return resolver.Resolver, nil } @@ -254,7 +254,7 @@ func (f *StrategyFinder) findTLS(testDomains []string, baseDialer transport.Stre } f.log("🏁 got TLS: '%v' (domain: %v), duration=%v, status=ok ✅\n", transportCfg, testDomain, time.Since(startTime)) } - f.log("🏆 selected TLS strategy '%v' in %0.2fs\n", transportCfg, time.Since(raceStart).Seconds()) + f.log("🏆 selected TLS strategy '%v' in %0.2fs\n\n", transportCfg, time.Since(raceStart).Seconds()) return tlsDialer, nil }) if err != nil { From 1ab30d7f4d38d3d021480d20107da45e8f811a9c Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 20 Feb 2024 23:20:16 -0500 Subject: [PATCH 15/25] Update mobileproxy --- x/mobileproxy/README.md | 228 +++++++++++++++++++++++++++++++++-- x/mobileproxy/mobileproxy.go | 107 ++++++++++++++-- 2 files changed, 318 insertions(+), 17 deletions(-) diff --git a/x/mobileproxy/README.md b/x/mobileproxy/README.md index 2065f38a..56955e18 100644 --- a/x/mobileproxy/README.md +++ b/x/mobileproxy/README.md @@ -47,6 +47,14 @@ The header file below is an example of the Objective-C interface that Go Mobile @class MobileproxyProxy; +@class MobileproxyStreamDialer; +@class MobileproxyStringList; +@protocol MobileproxyLogWriter; +@class MobileproxyLogWriter; + +@protocol MobileproxyLogWriter +- (BOOL)writeString:(NSString* _Nullable)s n:(long* _Nullable)n error:(NSError* _Nullable* _Nullable)error; +@end /** * Proxy enables you to get the actual address bound by the server and stop the service when no longer needed. @@ -77,10 +85,83 @@ The function takes a timeoutSeconds number instead of a [time.Duration] so it's @end /** - * RunProxy runs a local web proxy that listens on localAddress, and uses the transportConfig to -create a [transport.StreamDialer] that is used to connect to the requested destination. + * StreamDialer encapsulates the logic to create stream connections (like TCP). + */ +@interface MobileproxyStreamDialer : NSObject { +} +@property(strong, readonly) _Nonnull id _ref; + +- (nonnull instancetype)initWithRef:(_Nonnull id)ref; +/** + * NewStreamDialerFromConfig creates a [StreamDialer] based on the given config. +The config format is specified in https://pkg.go.dev/github.com/Jigsaw-Code/outline-sdk/x/config#hdr-Config_Format. + */ +- (nullable instancetype)initFromConfig:(NSString* _Nullable)transportConfig; +// skipped field StreamDialer.StreamDialer with unsupported type: github.com/Jigsaw-Code/outline-sdk/transport.StreamDialer + +// skipped method StreamDialer.DialStream with unsupported parameter or return types + +@end + +/** + * StringList allows us to pass a list of strings to the Go Mobile functions, since Go Mobiule doesn't +support slices as parameters. + */ +@interface MobileproxyStringList : NSObject { +} +@property(strong, readonly) _Nonnull id _ref; + +- (nonnull instancetype)initWithRef:(_Nonnull id)ref; +- (nonnull instancetype)init; +/** + * Append adds the string value to the end of the list. + */ +- (void)append:(NSString* _Nullable)value; +@end + +/** + * NewListFromLines creates a StringList by splitting the input string on new lines. + */ +FOUNDATION_EXPORT MobileproxyStringList* _Nullable MobileproxyNewListFromLines(NSString* _Nullable lines); + +/** + * NewSmartStreamDialer automatically selects a DNS and TLS strategy to use, and return a [StreamDialer] +that will use the selected strategy. +It uses testDomain to find a strategy that works when accessing those domains. +The strategies to search are given in the searchConfig. An example can be found in +https://github.com/Jigsaw-Code/outline-sdk/x/examples/smart-proxy/config.json + */ +FOUNDATION_EXPORT MobileproxyStreamDialer* _Nullable MobileproxyNewSmartStreamDialer(MobileproxyStringList* _Nullable testDomains, NSString* _Nullable searchConfig, id _Nullable logWriter, NSError* _Nullable* _Nullable error); + +/** + * NewStderrLogWriter creates a [LogWriter] that writes to the standard error output. + */ +FOUNDATION_EXPORT id _Nullable MobileproxyNewStderrLogWriter(void); + +/** + * NewStreamDialerFromConfig creates a [StreamDialer] based on the given config. +The config format is specified in https://pkg.go.dev/github.com/Jigsaw-Code/outline-sdk/x/config#hdr-Config_Format. + */ +FOUNDATION_EXPORT MobileproxyStreamDialer* _Nullable MobileproxyNewStreamDialerFromConfig(NSString* _Nullable transportConfig, NSError* _Nullable* _Nullable error); + +/** + * RunProxy runs a local web proxy that listens on localAddress, and handles proxy requests by +establishing connections to requested destination using the [StreamDialer]. + */ +FOUNDATION_EXPORT MobileproxyProxy* _Nullable MobileproxyRunProxy(NSString* _Nullable localAddress, MobileproxyStreamDialer* _Nullable dialer, NSError* _Nullable* _Nullable error); + +@class MobileproxyLogWriter; + +/** + * LogWriter is used as a sink for logging. */ -FOUNDATION_EXPORT MobileproxyProxy* _Nullable MobileproxyRunProxy(NSString* _Nullable localAddress, NSString* _Nullable transportConfig, NSError* _Nullable* _Nullable error); +@interface MobileproxyLogWriter : NSObject { +} +@property(strong, readonly) _Nonnull id _ref; + +- (nonnull instancetype)initWithRef:(_Nonnull id)ref; +- (BOOL)writeString:(NSString* _Nullable)s n:(long* _Nullable)n error:(NSError* _Nullable* _Nullable)error; +@end #endif ``` @@ -94,7 +175,28 @@ The files below are examples of the Java interface that Go Mobile generates. > **Warning**: this example may diverge from what is actually generated by the current code. Use the coed you generate instead. -`mobileproxy.java`: +`LogWriter.java`: + +```java +// Code generated by gobind. DO NOT EDIT. + +// Java class mobileproxy.LogWriter is a proxy for talking to a Go program. +// +// autogenerated by gobind -lang=java github.com/Jigsaw-Code/outline-sdk/x/mobileproxy +package mobileproxy; + +import go.Seq; + +/** + * LogWriter is used as a sink for logging. + */ +public interface LogWriter { + public long writeString(String s) throws Exception; + +} +``` + +`Mobileproxy.java`: ```java // Code generated by gobind. DO NOT EDIT. @@ -119,13 +221,46 @@ public abstract class Mobileproxy { private static native void _init(); + private static final class proxyLogWriter implements Seq.Proxy, LogWriter { + private final int refnum; + + @Override public final int incRefnum() { + Seq.incGoRef(refnum, this); + return refnum; + } + + proxyLogWriter(int refnum) { this.refnum = refnum; Seq.trackGoRef(refnum, this); } + + public native long writeString(String s) throws Exception; + } /** - * RunProxy runs a local web proxy that listens on localAddress, and uses the transportConfig to - create a [transport.StreamDialer] that is used to connect to the requested destination. + * NewListFromLines creates a StringList by splitting the input string on new lines. */ - public static native Proxy runProxy(String localAddress, String transportConfig) throws Exception; + public static native StringList newListFromLines(String lines); + /** + * NewSmartStreamDialer automatically selects a DNS and TLS strategy to use, and return a [StreamDialer] + that will use the selected strategy. + It uses testDomain to find a strategy that works when accessing those domains. + The strategies to search are given in the searchConfig. An example can be found in + https://github.com/Jigsaw-Code/outline-sdk/x/examples/smart-proxy/config.json + */ + public static native StreamDialer newSmartStreamDialer(StringList testDomains, String searchConfig, LogWriter logWriter) throws Exception; + /** + * NewStderrLogWriter creates a [LogWriter] that writes to the standard error output. + */ + public static native LogWriter newStderrLogWriter(); + /** + * NewStreamDialerFromConfig creates a [StreamDialer] based on the given config. + The config format is specified in https://pkg.go.dev/github.com/Jigsaw-Code/outline-sdk/x/config#hdr-Config_Format. + */ + public static native StreamDialer newStreamDialerFromConfig(String transportConfig) throws Exception; + /** + * RunProxy runs a local web proxy that listens on localAddress, and handles proxy requests by + establishing connections to requested destination using the [StreamDialer]. + */ + public static native Proxy runProxy(String localAddress, StreamDialer dialer) throws Exception; } ``` @@ -197,13 +332,71 @@ public final class Proxy implements Seq.Proxy { } ``` +`StringList.java`: + +```java +// Code generated by gobind. DO NOT EDIT. + +// Java class mobileproxy.StringList is a proxy for talking to a Go program. +// +// autogenerated by gobind -lang=java github.com/Jigsaw-Code/outline-sdk/x/mobileproxy +package mobileproxy; + +import go.Seq; + +/** + * StringList allows us to pass a list of strings to the Go Mobile functions, since Go Mobiule doesn't +support slices as parameters. + */ +public final class StringList implements Seq.Proxy { + static { Mobileproxy.touch(); } + + private final int refnum; + + @Override public final int incRefnum() { + Seq.incGoRef(refnum, this); + return refnum; + } + + StringList(int refnum) { this.refnum = refnum; Seq.trackGoRef(refnum, this); } + + public StringList() { this.refnum = __New(); Seq.trackGoRef(refnum, this); } + + private static native int __New(); + + /** + * Append adds the string value to the end of the list. + */ + public native void append(String value); + @Override public boolean equals(Object o) { + if (o == null || !(o instanceof StringList)) { + return false; + } + StringList that = (StringList)o; + return true; + } + + @Override public int hashCode() { + return java.util.Arrays.hashCode(new Object[] {}); + } + + @Override public String toString() { + StringBuilder b = new StringBuilder(); + b.append("StringList").append("{"); + return b.append("}").toString(); + } +} +``` + + + ## Add the library to your mobile project To add the library to your mobile project, see Go Mobile's [Building and deploying to iOS](https://github.com/golang/go/wiki/Mobile#building-and-deploying-to-ios-1) and [Building and deploying to Android](https://github.com/golang/go/wiki/Mobile#building-and-deploying-to-android-1). -## Use the library +## Using the basic local proxy forwarder You need to call the `RunProxy` function passing the local address to use, and the transport configuration. @@ -217,6 +410,23 @@ val proxy = mobileproxy.runProxy("localhost:0", "split:3") proxy.stop() ``` +## Using the smart local proxy forwarder ("Smart Proxy") + +The Smart Proxy can automatically try multiple strategies to unblock access to the test domains you specify. +You need to specify a strategy config in JSON format ([example](../examples/smart-proxy/config.json)). + +On Android, the Kotlin code would look like this: +```kotlin +// Use port zero to let the system pick an open port for you. +val testDomains = mobileproxy.newListFromLines("www.youtube.com\ni.ytimg.com") +val strategiesConfig = "..." // Config JSON. +val proxy = mobileproxy.runSmartProxy("localhost:0", testDomains, strategies) +// Configure your networking library using proxy.host() and proxy.port() or proxy.address(). +// ... +// Stop running the proxy. +proxy.stop() +``` + ## Configure your HTTP client or networking library You need to configure your networking library to use the local proxy. How you do it depends on the networking library you are using. @@ -265,7 +475,7 @@ We are working on instructions on how use the local proxy in a Webview. On Android, you will likely have to implement [WebViewClient.shouldInterceptRequest](https://developer.android.com/reference/android/webkit/WebViewClient#shouldInterceptRequest(android.webkit.WebView,%20android.webkit.WebResourceRequest)) to fulfill requests using an HTTP client that uses the local proxy. -On iOS, you will have to use [NWParameters.PrivacyContext.proxyConfigurations](https://developer.apple.com/documentation/network/nwparameters/privacycontext/4156642-proxyconfigurations). It is iOS 17.0+ and MacOS 14.0+ only. As a fallback you can force encrypted DNS in iOS 14+ via [NWParameters.PrivacyContext.requireEncryptedNameResolution(_:fallbackResolver:)](https://developer.apple.com/documentation/network/nwparameters/privacycontext/3548851-requireencryptednameresolution). +On iOS, we are still looking for ideas. There's [WKWebViewConfiguration.setURLSchemeHandler](https://developer.apple.com/documentation/webkit/wkwebviewconfiguration/2875766-seturlschemehandler), but the documentation says it can't be used to intercept HTTPS. If you know how to use a proxy with the WKWebView, please let us know! ## Clean up diff --git a/x/mobileproxy/mobileproxy.go b/x/mobileproxy/mobileproxy.go index 3f89648a..cfd3df89 100644 --- a/x/mobileproxy/mobileproxy.go +++ b/x/mobileproxy/mobileproxy.go @@ -21,14 +21,19 @@ package mobileproxy import ( "context" "fmt" + "io" "log" "net" "net/http" + "os" "strconv" + "strings" "time" + "github.com/Jigsaw-Code/outline-sdk/transport" "github.com/Jigsaw-Code/outline-sdk/x/config" "github.com/Jigsaw-Code/outline-sdk/x/httpproxy" + "github.com/Jigsaw-Code/outline-sdk/x/smart" ) // Proxy enables you to get the actual address bound by the server and stop the service when no longer needed. @@ -64,14 +69,9 @@ func (p *Proxy) Stop(timeoutSeconds int) { } } -// RunProxy runs a local web proxy that listens on localAddress, and uses the transportConfig to -// create a [transport.StreamDialer] that is used to connect to the requested destination. -func RunProxy(localAddress string, transportConfig string) (*Proxy, error) { - dialer, err := config.NewStreamDialer(transportConfig) - if err != nil { - return nil, fmt.Errorf("could not create dialer: %w", err) - } - +// RunProxy runs a local web proxy that listens on localAddress, and handles proxy requests by +// establishing connections to requested destination using the [StreamDialer]. +func RunProxy(localAddress string, dialer *StreamDialer) (*Proxy, error) { listener, err := net.Listen("tcp", localAddress) if err != nil { return nil, fmt.Errorf("could not listen on address %v: %v", localAddress, err) @@ -90,3 +90,94 @@ func RunProxy(localAddress string, transportConfig string) (*Proxy, error) { } return &Proxy{host: host, port: port, server: server}, nil } + +// StreamDialer encapsulates the logic to create stream connections (like TCP). +type StreamDialer struct { + transport.StreamDialer +} + +// NewStreamDialerFromConfig creates a [StreamDialer] based on the given config. +// The config format is specified in https://pkg.go.dev/github.com/Jigsaw-Code/outline-sdk/x/config#hdr-Config_Format. +func NewStreamDialerFromConfig(transportConfig string) (*StreamDialer, error) { + dialer, err := config.NewStreamDialer(transportConfig) + if err != nil { + return nil, err + } + return &StreamDialer{dialer}, nil +} + +// LogWriter is used as a sink for logging. +type LogWriter io.StringWriter + +// Adaptor to convert an [io.StringWriter] to a [io.Writer]. +type stringToBytesWriter struct { + w io.Writer +} + +// WriteString implements [io.StringWriter]. +func (w *stringToBytesWriter) WriteString(logText string) (int, error) { + return io.WriteString(w.w, logText) +} + +// NewStderrLogWriter creates a [LogWriter] that writes to the standard error output. +func NewStderrLogWriter() LogWriter { + return &stringToBytesWriter{os.Stderr} +} + +// Adaptor to convert an [io.Writer] to a [io.StringWriter]. +type bytestoStringWriter struct { + sw io.StringWriter +} + +// Write implements [io.Writer]. +func (w *bytestoStringWriter) Write(b []byte) (int, error) { + return w.sw.WriteString(string(b)) +} + +func toWriter(logWriter LogWriter) io.Writer { + if logWriter == nil { + return nil + } + if w, ok := logWriter.(io.Writer); ok { + return w + } + return &bytestoStringWriter{logWriter} +} + +// NewSmartStreamDialer automatically selects a DNS and TLS strategy to use, and return a [StreamDialer] +// that will use the selected strategy. +// It uses testDomain to find a strategy that works when accessing those domains. +// The strategies to search are given in the searchConfig. An example can be found in +// https://github.com/Jigsaw-Code/outline-sdk/x/examples/smart-proxy/config.json +func NewSmartStreamDialer(testDomains *StringList, searchConfig string, logWriter LogWriter) (*StreamDialer, error) { + logBytesWriter := toWriter(logWriter) + // TODO: inject the base dialer for tests. + finder := smart.StrategyFinder{ + LogWriter: logBytesWriter, + TestTimeout: 5 * time.Second, + StreamDialer: &transport.TCPDialer{}, + PacketDialer: &transport.UDPDialer{}, + } + testDomainsSlice := append(make([]string, 0, len(testDomains.list)), testDomains.list...) + dialer, err := finder.NewDialer(context.Background(), testDomainsSlice, []byte(searchConfig)) + if err != nil { + return nil, fmt.Errorf("failed to find dialer: %v", err) + } + return &StreamDialer{dialer}, nil +} + +// StringList allows us to pass a list of strings to the Go Mobile functions, since Go Mobiule doesn't +// support slices as parameters. +type StringList struct { + list []string +} + +// Append adds the string value to the end of the list. +func (l *StringList) Append(value string) { + l.list = append(l.list, value) +} + +// NewListFromLines creates a StringList by splitting the input string on new lines. +func NewListFromLines(lines string) *StringList { + return &StringList{list: strings.Split(lines, "\n")} +} From 7bce93d387dca6bdbbcd3c036a0751a4e76bd0dc Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 20 Feb 2024 23:33:22 -0500 Subject: [PATCH 16/25] Updates --- x/mobileproxy/.gitignore | 1 + x/mobileproxy/README.md | 74 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 3 deletions(-) create mode 100644 x/mobileproxy/.gitignore diff --git a/x/mobileproxy/.gitignore b/x/mobileproxy/.gitignore new file mode 100644 index 00000000..e2e7327c --- /dev/null +++ b/x/mobileproxy/.gitignore @@ -0,0 +1 @@ +/out diff --git a/x/mobileproxy/README.md b/x/mobileproxy/README.md index 56955e18..2bacdd8c 100644 --- a/x/mobileproxy/README.md +++ b/x/mobileproxy/README.md @@ -195,6 +195,70 @@ public interface LogWriter { } ``` +`StreamDialer.java`: + +```java +// Code generated by gobind. DO NOT EDIT. + +// Java class mobileproxy.StreamDialer is a proxy for talking to a Go program. +// +// autogenerated by gobind -lang=java github.com/Jigsaw-Code/outline-sdk/x/mobileproxy +package mobileproxy; + +import go.Seq; + +/** + * StreamDialer encapsulates the logic to create stream connections (like TCP). + */ +public final class StreamDialer implements Seq.Proxy { + static { Mobileproxy.touch(); } + + private final int refnum; + + @Override public final int incRefnum() { + Seq.incGoRef(refnum, this); + return refnum; + } + + /** + * NewStreamDialerFromConfig creates a [StreamDialer] based on the given config. + The config format is specified in https://pkg.go.dev/github.com/Jigsaw-Code/outline-sdk/x/config#hdr-Config_Format. + */ + public StreamDialer(String transportConfig) { + this.refnum = __NewStreamDialerFromConfig(transportConfig); + Seq.trackGoRef(refnum, this); + } + + private static native int __NewStreamDialerFromConfig(String transportConfig); + + StreamDialer(int refnum) { this.refnum = refnum; Seq.trackGoRef(refnum, this); } + + // skipped field StreamDialer.StreamDialer with unsupported type: github.com/Jigsaw-Code/outline-sdk/transport.StreamDialer + + // skipped method StreamDialer.DialStream with unsupported parameter or return types + + @Override public boolean equals(Object o) { + if (o == null || !(o instanceof StreamDialer)) { + return false; + } + StreamDialer that = (StreamDialer)o; + // skipped field StreamDialer.StreamDialer with unsupported type: github.com/Jigsaw-Code/outline-sdk/transport.StreamDialer + + return true; + } + + @Override public int hashCode() { + return java.util.Arrays.hashCode(new Object[] {}); + } + + @Override public String toString() { + StringBuilder b = new StringBuilder(); + b.append("StreamDialer").append("{"); + return b.append("}").toString(); + } +} +``` + `Mobileproxy.java`: @@ -403,7 +467,9 @@ You need to call the `RunProxy` function passing the local address to use, and t On Android, you can have the following Kotlin code: ```kotlin // Use port zero to let the system pick an open port for you. -val proxy = mobileproxy.runProxy("localhost:0", "split:3") +val dialer = StreamDialer("split:3") + +val proxy = Mobileproxy.runProxy("localhost:0", dialer) // Configure your networking library using proxy.host() and proxy.port() or proxy.address(). // ... // Stop running the proxy. @@ -418,9 +484,11 @@ You need to specify a strategy config in JSON format ([example](../examples/smar On Android, the Kotlin code would look like this: ```kotlin // Use port zero to let the system pick an open port for you. -val testDomains = mobileproxy.newListFromLines("www.youtube.com\ni.ytimg.com") +val testDomains = Mobileproxy.newListFromLines("www.youtube.com\ni.ytimg.com") val strategiesConfig = "..." // Config JSON. -val proxy = mobileproxy.runSmartProxy("localhost:0", testDomains, strategies) +val dialer = Mobileproxy.newSmartStreamDialer(testDomains, strategies, Mobileproxy.newStderrLogWriter()) + +val proxy = Mobileproxy.runProxy("localhost:0", dialer) // Configure your networking library using proxy.host() and proxy.port() or proxy.address(). // ... // Stop running the proxy. From 19391f6b07be6bcf8f40ea5d0a97e335297bccd3 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 20 Feb 2024 23:37:57 -0500 Subject: [PATCH 17/25] Fix fetch-proxy --- x/examples/fetch-proxy/main.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/x/examples/fetch-proxy/main.go b/x/examples/fetch-proxy/main.go index e9b8c5c5..d452caf9 100644 --- a/x/examples/fetch-proxy/main.go +++ b/x/examples/fetch-proxy/main.go @@ -34,9 +34,13 @@ func main() { log.Fatal("Need to pass the URL to fetch in the command-line") } - proxy, err := mobileproxy.RunProxy("localhost:0", *transportFlag) + dialer, err := mobileproxy.NewStreamDialerFromConfig(*transportFlag) if err != nil { - log.Fatalf("Cmobileproxy start proxy: %v", err) + log.Fatalf("NewStreamDialerFromConfig failed: %v", err) + } + proxy, err := mobileproxy.RunProxy("localhost:0", dialer) + if err != nil { + log.Fatalf("RunProxy failed: %v", err) } httpClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(&url.URL{Scheme: "http", Host: proxy.Address()})}} From d658b329272e90f65f7223111b701397419d76d8 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Wed, 21 Feb 2024 13:23:47 -0500 Subject: [PATCH 18/25] Address review comments --- x/examples/smart-proxy/main.go | 20 ++++++++++++++++---- x/mobileproxy/mobileproxy.go | 4 ++-- x/smart/cname_unix.go | 4 ++++ 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/x/examples/smart-proxy/main.go b/x/examples/smart-proxy/main.go index b9581a4f..c529ce6c 100644 --- a/x/examples/smart-proxy/main.go +++ b/x/examples/smart-proxy/main.go @@ -24,7 +24,6 @@ import ( "net/http" "os" "os/signal" - "strings" "time" "github.com/Jigsaw-Code/outline-sdk/transport" @@ -46,6 +45,19 @@ func (v *stringArrayFlagValue) Set(value string) error { return nil } +func supportsHappyEyeballs(dialer transport.StreamDialer) bool { + // Some proxy protocols, most notably Shadowsocks, can't communicate connection success. + // Our shadowsocks.StreamDialer will return a connection successfully as long as it can + // connect to the proxy server, regardless of whether it can connect to the target. + // This breaks HappyEyeballs. + conn, err := dialer.DialStream(context.Background(), "invalid:0") + if conn != nil { + conn.Close() + } + // If the dialer returns success on an invalid address, it doesn't support Happy Eyeballs. + return err != nil +} + func main() { verboseFlag := flag.Bool("v", false, "Enable debug output") addrFlag := flag.String("localAddr", "localhost:1080", "Local proxy address") @@ -80,9 +92,9 @@ func main() { if err != nil { log.Fatalf("Could not create stream dialer: %v", err) } - if strings.HasPrefix(*transportFlag, "ss:") { + if !supportsHappyEyeballs(streamDialer) { innerDialer := streamDialer - // Hack to disable IPv6 with Shadowsocks, since it doesn't communicate connection success. + // Disable IPv6 if the dialer doesn't support HappyEyballs. streamDialer = transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { host, _, err := net.SplitHostPort(addr) if err != nil { @@ -137,7 +149,7 @@ func main() { sig := make(chan os.Signal, 1) signal.Notify(sig, os.Interrupt) <-sig - fmt.Print("Shutting down") + fmt.Println("Shutting down") // Gracefully shut down the server, with a 5s timeout. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/x/mobileproxy/mobileproxy.go b/x/mobileproxy/mobileproxy.go index cfd3df89..cf9dfc76 100644 --- a/x/mobileproxy/mobileproxy.go +++ b/x/mobileproxy/mobileproxy.go @@ -161,12 +161,12 @@ func NewSmartStreamDialer(testDomains *StringList, searchConfig string, logWrite testDomainsSlice := append(make([]string, 0, len(testDomains.list)), testDomains.list...) dialer, err := finder.NewDialer(context.Background(), testDomainsSlice, []byte(searchConfig)) if err != nil { - return nil, fmt.Errorf("failed to find dialer: %v", err) + return nil, fmt.Errorf("failed to find dialer: %w", err) } return &StreamDialer{dialer}, nil } -// StringList allows us to pass a list of strings to the Go Mobile functions, since Go Mobiule doesn't +// StringList allows us to pass a list of strings to the Go Mobile functions, since Go Mobile doesn't // support slices as parameters. type StringList struct { list []string diff --git a/x/smart/cname_unix.go b/x/smart/cname_unix.go index 521ffd0a..0685b2ef 100644 --- a/x/smart/cname_unix.go +++ b/x/smart/cname_unix.go @@ -30,6 +30,10 @@ import ( "unsafe" ) +// lookupCNAME provides functionality equivalent to net.DefaultResolver.LookupCNAME. However, +// net.DefaultResolver.LookupCNAME uses libresolv on unix, and, on Android and iOS, it tries +// to connect to [::1]:53 (probably from /etc/resolv.conf) and the connection is refused. +// Instead, we use getaddrinfo to get the canonical name. func lookupCNAME(ctx context.Context, domain string) (string, error) { type result struct { cname string From 4ec7c5b399c5a1c52561853a31cabccb0be2b145 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Wed, 21 Feb 2024 13:37:41 -0500 Subject: [PATCH 19/25] Emit warning --- x/examples/smart-proxy/main.go | 1 + 1 file changed, 1 insertion(+) diff --git a/x/examples/smart-proxy/main.go b/x/examples/smart-proxy/main.go index c529ce6c..4676b282 100644 --- a/x/examples/smart-proxy/main.go +++ b/x/examples/smart-proxy/main.go @@ -93,6 +93,7 @@ func main() { log.Fatalf("Could not create stream dialer: %v", err) } if !supportsHappyEyeballs(streamDialer) { + fmt.Println("⚠️ Warning: base transport is not compatible with Happy Eyeballs. Disabling IPv6.") innerDialer := streamDialer // Disable IPv6 if the dialer doesn't support HappyEyballs. streamDialer = transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { From 771dc1cbb689079e8977ddc22be54dbfe344c4f6 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Wed, 21 Feb 2024 16:18:55 -0500 Subject: [PATCH 20/25] Remove type --- x/smart/stream_dialer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x/smart/stream_dialer.go b/x/smart/stream_dialer.go index 1fd335b3..a42b805e 100644 --- a/x/smart/stream_dialer.go +++ b/x/smart/stream_dialer.go @@ -182,7 +182,7 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) ctx, searchDone := context.WithCancel(context.Background()) defer searchDone() raceStart := time.Now() - resolver, err := raceTests[*smartResolver](ctx, 250*time.Millisecond, resolvers, func(resolver *smartResolver) (*smartResolver, error) { + resolver, err := raceTests(ctx, 250*time.Millisecond, resolvers, func(resolver *smartResolver) (*smartResolver, error) { for _, testDomain := range testDomains { select { case <-ctx.Done(): From d7e5a94d47a05b05c7c31d9dec8b4dd2240ba953 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Thu, 22 Feb 2024 09:44:27 -0600 Subject: [PATCH 21/25] Apply suggestions from code review Co-authored-by: J. Yi <93548144+jyyi1@users.noreply.github.com> --- x/mobileproxy/README.md | 2 +- x/mobileproxy/mobileproxy.go | 4 ++-- x/smart/dns.go | 6 +++--- x/smart/racer.go | 2 +- x/smart/stream_dialer.go | 10 +++++----- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/x/mobileproxy/README.md b/x/mobileproxy/README.md index 2bacdd8c..16ae5418 100644 --- a/x/mobileproxy/README.md +++ b/x/mobileproxy/README.md @@ -486,7 +486,7 @@ On Android, the Kotlin code would look like this: // Use port zero to let the system pick an open port for you. val testDomains = Mobileproxy.newListFromLines("www.youtube.com\ni.ytimg.com") val strategiesConfig = "..." // Config JSON. -val dialer = Mobileproxy.newSmartStreamDialer(testDomains, strategies, Mobileproxy.newStderrLogWriter()) +val dialer = Mobileproxy.newSmartStreamDialer(testDomains, strategiesConfig, Mobileproxy.newStderrLogWriter()) val proxy = Mobileproxy.runProxy("localhost:0", dialer) // Configure your networking library using proxy.host() and proxy.port() or proxy.address(). diff --git a/x/mobileproxy/mobileproxy.go b/x/mobileproxy/mobileproxy.go index cf9dfc76..0d44adcf 100644 --- a/x/mobileproxy/mobileproxy.go +++ b/x/mobileproxy/mobileproxy.go @@ -144,9 +144,9 @@ func toWriter(logWriter LogWriter) io.Writer { return &bytestoStringWriter{logWriter} } -// NewSmartStreamDialer automatically selects a DNS and TLS strategy to use, and return a [StreamDialer] +// NewSmartStreamDialer automatically selects a DNS and TLS strategy to use, and returns a [StreamDialer] // that will use the selected strategy. -// It uses testDomain to find a strategy that works when accessing those domains. +// It uses testDomains to find a strategy that works when accessing those domains. // The strategies to search are given in the searchConfig. An example can be found in // https://github.com/Jigsaw-Code/outline-sdk/x/examples/smart-proxy/config.json func NewSmartStreamDialer(testDomains *StringList, searchConfig string, logWriter LogWriter) (*StreamDialer, error) { diff --git a/x/smart/dns.go b/x/smart/dns.go index 985ee2dc..183afe50 100644 --- a/x/smart/dns.go +++ b/x/smart/dns.go @@ -61,7 +61,7 @@ func evaluateNetResolver(ctx context.Context, resolver *net.Resolver, testDomain return nil, fmt.Errorf("failed to lookup IPs: %w", err) } if len(ips) == 0 { - return nil, fmt.Errorf("no ip answer") + return nil, errors.New("no ip answer") } for _, ip := range ips { if ip.IsLoopback() { @@ -158,7 +158,7 @@ func evaluateCNAMEResponse(response dnsmessage.Message, requestDomain string) er } } if cname == "" { - return fmt.Errorf("no CNAME in answers") + return errors.New("no CNAME in answers") } return nil } @@ -201,7 +201,7 @@ func testDNSResolver(baseCtx context.Context, oneTestTimeout time.Duration, reso // case in China. q, err = dns.NewQuestion(requestDomain, dnsmessage.TypeCNAME) if err != nil { - return nil, fmt.Errorf("failed to create question: %v", err) + return nil, fmt.Errorf("failed to create question: %w", err) } ctxCNAME, cancelCNAME := context.WithTimeout(baseCtx, oneTestTimeout) defer cancelCNAME() diff --git a/x/smart/racer.go b/x/smart/racer.go index c66037d9..d2a05912 100644 --- a/x/smart/racer.go +++ b/x/smart/racer.go @@ -27,7 +27,7 @@ func newClosedChanel() <-chan struct{} { return ch } -// raceTests races will call the test function on each entry until it finds an entry for which the test returns nil. +// raceTests will call the test function on each entry until it finds an entry for which the test returns nil error. // That entry is returned. A test is only started after the previous test finished or maxWait is done, whichever // happens first. That way you bound the wait for a test, and they may overlap. func raceTests[E any, R any](ctx context.Context, maxWait time.Duration, entries []E, test func(entry E) (R, error)) (R, error) { diff --git a/x/smart/stream_dialer.go b/x/smart/stream_dialer.go index a42b805e..5b6fc650 100644 --- a/x/smart/stream_dialer.go +++ b/x/smart/stream_dialer.go @@ -87,14 +87,14 @@ type configJSON struct { TLS []string `json:"tls,omitempty"` } -// newDNSResolverFromEntry creates a [dns.Resolver] based on the config, returning the resolver +// newDNSResolverFromEntry creates a [dns.Resolver] based on the config, returning the resolver and // a boolean indicating whether the resolver is secure (TLS, HTTPS) and a possible error. func (f *StrategyFinder) newDNSResolverFromEntry(entry dnsEntryJSON) (dns.Resolver, bool, error) { if entry.System != nil { return nil, false, nil } else if cfg := entry.HTTPS; cfg != nil { if cfg.Name == "" { - return nil, true, fmt.Errorf("https entry has empty server name") + return nil, true, errors.New("https entry has empty server name") } serverAddr := cfg.Address if serverAddr == "" { @@ -109,7 +109,7 @@ func (f *StrategyFinder) newDNSResolverFromEntry(entry dnsEntryJSON) (dns.Resolv return dns.NewHTTPSResolver(f.StreamDialer, serverAddr, dohURL.String()), true, nil } else if cfg := entry.TLS; cfg != nil { if cfg.Name == "" { - return nil, true, fmt.Errorf("tls entry has empty server name") + return nil, true, errors.New("tls entry has empty server name") } serverAddr := cfg.Address if serverAddr == "" { @@ -122,7 +122,7 @@ func (f *StrategyFinder) newDNSResolverFromEntry(entry dnsEntryJSON) (dns.Resolv return dns.NewTLSResolver(f.StreamDialer, serverAddr, cfg.Name), true, nil } else if cfg := entry.TCP; cfg != nil { if cfg.Address == "" { - return nil, false, fmt.Errorf("tcp entry has empty server address") + return nil, false, errors.New("tcp entry has empty server address") } host, port, err := net.SplitHostPort(cfg.Address) if err != nil { @@ -133,7 +133,7 @@ func (f *StrategyFinder) newDNSResolverFromEntry(entry dnsEntryJSON) (dns.Resolv return dns.NewTCPResolver(f.StreamDialer, serverAddr), false, nil } else if cfg := entry.UDP; cfg != nil { if cfg.Address == "" { - return nil, false, fmt.Errorf("udp entry has empty server address") + return nil, false, errors.New("udp entry has empty server address") } host, port, err := net.SplitHostPort(cfg.Address) if err != nil { From ea13524f64be3bfdf2879e0f558ebf928f5b9ea0 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Thu, 22 Feb 2024 10:00:15 -0600 Subject: [PATCH 22/25] Update x/smart/dns.go Co-authored-by: J. Yi <93548144+jyyi1@users.noreply.github.com> --- x/smart/dns.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x/smart/dns.go b/x/smart/dns.go index 183afe50..9814c5d5 100644 --- a/x/smart/dns.go +++ b/x/smart/dns.go @@ -175,7 +175,7 @@ func testDNSResolver(baseCtx context.Context, oneTestTimeout time.Duration, reso q, err := dns.NewQuestion(requestDomain, dnsmessage.TypeA) if err != nil { - return nil, fmt.Errorf("failed to create question: %v", err) + return nil, fmt.Errorf("failed to create question: %w", err) } ctxA, cancelA := context.WithTimeout(baseCtx, oneTestTimeout) defer cancelA() From 59612c27e96fee7c753a38cc9eb34eb0a15ea740 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Thu, 22 Feb 2024 11:10:30 -0500 Subject: [PATCH 23/25] Address review comments --- x/mobileproxy/mobileproxy.go | 3 +-- x/smart/cache.go | 22 +++++++++++----------- x/smart/dns.go | 8 ++++---- x/smart/racer.go | 1 + x/smart/stream_dialer.go | 2 +- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/x/mobileproxy/mobileproxy.go b/x/mobileproxy/mobileproxy.go index 0d44adcf..217a47c2 100644 --- a/x/mobileproxy/mobileproxy.go +++ b/x/mobileproxy/mobileproxy.go @@ -158,8 +158,7 @@ func NewSmartStreamDialer(testDomains *StringList, searchConfig string, logWrite StreamDialer: &transport.TCPDialer{}, PacketDialer: &transport.UDPDialer{}, } - testDomainsSlice := append(make([]string, 0, len(testDomains.list)), testDomains.list...) - dialer, err := finder.NewDialer(context.Background(), testDomainsSlice, []byte(searchConfig)) + dialer, err := finder.NewDialer(context.Background(), testDomains.list, []byte(searchConfig)) if err != nil { return nil, fmt.Errorf("failed to find dialer: %w", err) } diff --git a/x/smart/cache.go b/x/smart/cache.go index 1bc218d4..99ed55e9 100644 --- a/x/smart/cache.go +++ b/x/smart/cache.go @@ -42,22 +42,22 @@ type cacheEntry struct { expire time.Time } -// cacheResolver is a very simple caching [dns.Resolver]. -// It doesn't use the response TTL and doesn't cache empty answers. +// simpleLRUCacheResolver is a very simple caching [dns.Resolver]. +// It doesn't use the response TTL. // It also doesn't dedup duplicate in-flight requests. -type cacheResolver struct { +type simpleLRUCacheResolver struct { resolver dns.Resolver cache []cacheEntry mux sync.Mutex } -var _ dns.Resolver = (*cacheResolver)(nil) +var _ dns.Resolver = (*simpleLRUCacheResolver)(nil) -func newCacheResolver(resolver dns.Resolver, numEntries int) dns.Resolver { - return &cacheResolver{resolver: resolver, cache: make([]cacheEntry, numEntries)} +func newSimpleLRUCacheResolver(resolver dns.Resolver, numEntries int) dns.Resolver { + return &simpleLRUCacheResolver{resolver: resolver, cache: make([]cacheEntry, numEntries)} } -func (r *cacheResolver) RemoveExpired() { +func (r *simpleLRUCacheResolver) RemoveExpired() { now := time.Now() last := 0 r.mux.Lock() @@ -71,7 +71,7 @@ func (r *cacheResolver) RemoveExpired() { r.cache = r.cache[:last] } -func (r *cacheResolver) moveToFront(index int) { +func (r *simpleLRUCacheResolver) moveToFront(index int) { entry := r.cache[index] copy(r.cache[1:], r.cache[:index]) r.cache[0] = entry @@ -82,7 +82,7 @@ func makeCacheKey(q dnsmessage.Question) string { return strings.Join([]string{domainKey, q.Type.String(), q.Class.String()}, "|") } -func (r *cacheResolver) SearchCache(key string) *dnsmessage.Message { +func (r *simpleLRUCacheResolver) SearchCache(key string) *dnsmessage.Message { r.mux.Lock() defer r.mux.Unlock() for ei, entry := range r.cache { @@ -96,7 +96,7 @@ func (r *cacheResolver) SearchCache(key string) *dnsmessage.Message { return nil } -func (r *cacheResolver) AddToCache(key string, msg *dnsmessage.Message) { +func (r *simpleLRUCacheResolver) AddToCache(key string, msg *dnsmessage.Message) { r.mux.Lock() defer r.mux.Unlock() newSize := len(r.cache) + 1 @@ -110,7 +110,7 @@ func (r *cacheResolver) AddToCache(key string, msg *dnsmessage.Message) { } // Query implements [dns.Resolver]. -func (r *cacheResolver) Query(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) { +func (r *simpleLRUCacheResolver) Query(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) { r.RemoveExpired() cacheKey := makeCacheKey(q) if msg := r.SearchCache(cacheKey); msg != nil { diff --git a/x/smart/dns.go b/x/smart/dns.go index 9814c5d5..2e496055 100644 --- a/x/smart/dns.go +++ b/x/smart/dns.go @@ -65,15 +65,15 @@ func evaluateNetResolver(ctx context.Context, resolver *net.Resolver, testDomain } for _, ip := range ips { if ip.IsLoopback() { - return nil, fmt.Errorf("localhost ip: %v", ip) // -1 + return nil, fmt.Errorf("localhost ip: %v", ip) } if ip.IsPrivate() { - return nil, fmt.Errorf("private ip: %v", ip) // -1 + return nil, fmt.Errorf("private ip: %v", ip) } if ip.IsUnspecified() { - return nil, fmt.Errorf("zero ip: %v", ip) // -1 + return nil, fmt.Errorf("zero ip: %v", ip) } - // TODO: consider validating the IPs: fingerprint, hardcoded ground truth, trusted response, TLS connection. + // TODO: consider validating the IPs: fingerprint, TCP connection, hardcoded ground truth, trusted response, TLS connection. } return ips, nil } diff --git a/x/smart/racer.go b/x/smart/racer.go index d2a05912..70d4826d 100644 --- a/x/smart/racer.go +++ b/x/smart/racer.go @@ -30,6 +30,7 @@ func newClosedChanel() <-chan struct{} { // raceTests will call the test function on each entry until it finds an entry for which the test returns nil error. // That entry is returned. A test is only started after the previous test finished or maxWait is done, whichever // happens first. That way you bound the wait for a test, and they may overlap. +// The test function should make use of the context to stop doing work when the race is done and it is no longer needed. func raceTests[E any, R any](ctx context.Context, maxWait time.Duration, entries []E, test func(entry E) (R, error)) (R, error) { type testResult struct { Result R diff --git a/x/smart/stream_dialer.go b/x/smart/stream_dialer.go index 5b6fc650..2d750cdb 100644 --- a/x/smart/stream_dialer.go +++ b/x/smart/stream_dialer.go @@ -304,7 +304,7 @@ func (f *StrategyFinder) NewDialer(ctx context.Context, testDomains []string, co } dnsDialer = f.StreamDialer } else { - resolver = newCacheResolver(resolver, 100) + resolver = newSimpleLRUCacheResolver(resolver, 100) dnsDialer, err = dns.NewStreamDialer(resolver, f.StreamDialer) if err != nil { return nil, fmt.Errorf("dns.NewStreamDialer failed: %w", err) From 783f0beccec48c137813d2263d89dd12ebd3986a Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Thu, 22 Feb 2024 11:49:53 -0500 Subject: [PATCH 24/25] Add timeout to HappyEyeballs test --- x/examples/smart-proxy/main.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/x/examples/smart-proxy/main.go b/x/examples/smart-proxy/main.go index 4676b282..61c368bc 100644 --- a/x/examples/smart-proxy/main.go +++ b/x/examples/smart-proxy/main.go @@ -50,7 +50,9 @@ func supportsHappyEyeballs(dialer transport.StreamDialer) bool { // Our shadowsocks.StreamDialer will return a connection successfully as long as it can // connect to the proxy server, regardless of whether it can connect to the target. // This breaks HappyEyeballs. - conn, err := dialer.DialStream(context.Background(), "invalid:0") + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + conn, err := dialer.DialStream(ctx, "invalid:0") + cancel() if conn != nil { conn.Close() } From 453488151fc5efbffd0721c5fced0e7ad456a5a2 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Thu, 22 Feb 2024 11:50:12 -0500 Subject: [PATCH 25/25] Add context --- x/smart/stream_dialer.go | 55 +++++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/x/smart/stream_dialer.go b/x/smart/stream_dialer.go index 2d750cdb..c34b61da 100644 --- a/x/smart/stream_dialer.go +++ b/x/smart/stream_dialer.go @@ -50,6 +50,18 @@ func (f *StrategyFinder) log(format string, a ...any) { } } +// Only log if context is not done +func (f *StrategyFinder) logCtx(ctx context.Context, format string, a ...any) { + if ctx != nil { + select { + case <-ctx.Done(): + return + default: + } + } + f.log(format, a...) +} + type httpsEntryJSON struct { // Domain name of the host. Name string `json:"name,omitempty"` @@ -173,13 +185,13 @@ func (f *StrategyFinder) dnsConfigToResolver(dnsConfig []dnsEntryJSON) ([]*smart return rts, nil } -func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) (dns.Resolver, error) { +func (f *StrategyFinder) findDNS(ctx context.Context, testDomains []string, dnsConfig []dnsEntryJSON) (dns.Resolver, error) { resolvers, err := f.dnsConfigToResolver(dnsConfig) if err != nil { return nil, err } - ctx, searchDone := context.WithCancel(context.Background()) + ctx, searchDone := context.WithCancel(ctx) defer searchDone() raceStart := time.Now() resolver, err := raceTests(ctx, 250*time.Millisecond, resolvers, func(resolver *smartResolver) (*smartResolver, error) { @@ -190,7 +202,7 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) default: } - f.log("🏃 run DNS: %v (domain: %v)\n", resolver.ID, testDomain) + f.logCtx(ctx, "🏃 run DNS: %v (domain: %v)\n", resolver.ID, testDomain) startTime := time.Now() ips, err := testDNSResolver(ctx, f.TestTimeout, resolver, testDomain) duration := time.Since(startTime) @@ -199,12 +211,8 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) if err != nil { status = fmt.Sprintf("%v ❌", err) } - select { - case <-ctx.Done(): - default: - // Only output log if the search is not done yet. - f.log("🏁 got DNS: %v (domain: %v), duration=%v, ips=%v, status=%v\n", resolver.ID, testDomain, duration, ips, status) - } + // Only output log if the search is not done yet. + f.logCtx(ctx, "🏁 got DNS: %v (domain: %v), duration=%v, ips=%v, status=%v\n", resolver.ID, testDomain, duration, ips, status) if err != nil { return nil, err @@ -219,15 +227,19 @@ func (f *StrategyFinder) findDNS(testDomains []string, dnsConfig []dnsEntryJSON) return resolver.Resolver, nil } -func (f *StrategyFinder) findTLS(testDomains []string, baseDialer transport.StreamDialer, tlsConfig []string) (transport.StreamDialer, error) { +func (f *StrategyFinder) findTLS(ctx context.Context, testDomains []string, baseDialer transport.StreamDialer, tlsConfig []string) (transport.StreamDialer, error) { if len(tlsConfig) == 0 { return nil, errors.New("config for TLS is empty. Please specify at least one transport") } - ctx, searchDone := context.WithCancel(context.Background()) + ctx, searchDone := context.WithCancel(ctx) defer searchDone() raceStart := time.Now() - tlsDialer, err := raceTests(ctx, 250*time.Millisecond, tlsConfig, func(transportCfg string) (transport.StreamDialer, error) { + type SearchResult struct { + Dialer transport.StreamDialer + Config string + } + result, err := raceTests(ctx, 250*time.Millisecond, tlsConfig, func(transportCfg string) (*SearchResult, error) { tlsDialer, err := config.WrapStreamDialer(baseDialer, transportCfg) if err != nil { return nil, fmt.Errorf("WrapStreamDialer failed: %w", err) @@ -236,30 +248,31 @@ func (f *StrategyFinder) findTLS(testDomains []string, baseDialer transport.Stre startTime := time.Now() testAddr := net.JoinHostPort(testDomain, "443") - f.log("🏃 run TLS: '%v' (domain: %v)\n", transportCfg, testDomain) + f.logCtx(ctx, "🏃 run TLS: '%v' (domain: %v)\n", transportCfg, testDomain) - ctx, cancel := context.WithTimeout(context.Background(), f.TestTimeout) + ctx, cancel := context.WithTimeout(ctx, f.TestTimeout) defer cancel() testConn, err := tlsDialer.DialStream(ctx, testAddr) if err != nil { - f.log("🏁 got TLS: '%v' (domain: %v), duration=%v, dial_error=%v ❌\n", transportCfg, testDomain, time.Since(startTime), err) + f.logCtx(ctx, "🏁 got TLS: '%v' (domain: %v), duration=%v, dial_error=%v ❌\n", transportCfg, testDomain, time.Since(startTime), err) return nil, err } tlsConn := tls.Client(testConn, &tls.Config{ServerName: testDomain}) err = tlsConn.HandshakeContext(ctx) tlsConn.Close() if err != nil { - f.log("🏁 got TLS: '%v' (domain: %v), duration=%v, handshake=%v ❌\n", transportCfg, testDomain, time.Since(startTime), err) + f.logCtx(ctx, "🏁 got TLS: '%v' (domain: %v), duration=%v, handshake=%v ❌\n", transportCfg, testDomain, time.Since(startTime), err) return nil, err } - f.log("🏁 got TLS: '%v' (domain: %v), duration=%v, status=ok ✅\n", transportCfg, testDomain, time.Since(startTime)) + f.logCtx(ctx, "🏁 got TLS: '%v' (domain: %v), duration=%v, status=ok ✅\n", transportCfg, testDomain, time.Since(startTime)) } - f.log("🏆 selected TLS strategy '%v' in %0.2fs\n\n", transportCfg, time.Since(raceStart).Seconds()) - return tlsDialer, nil + return &SearchResult{tlsDialer, transportCfg}, nil }) if err != nil { return nil, fmt.Errorf("could not find TLS strategy: %w", err) } + f.log("🏆 selected TLS strategy '%v' in %0.2fs\n\n", result.Config, time.Since(raceStart).Seconds()) + tlsDialer := result.Dialer return transport.FuncStreamDialer(func(ctx context.Context, raddr string) (transport.StreamConn, error) { _, portStr, err := net.SplitHostPort(raddr) if err != nil { @@ -293,7 +306,7 @@ func (f *StrategyFinder) NewDialer(ctx context.Context, testDomains []string, co testDomains[di] = makeFullyQualified(domain) } - resolver, err := f.findDNS(testDomains, parsedConfig.DNS) + resolver, err := f.findDNS(ctx, testDomains, parsedConfig.DNS) if err != nil { return nil, err } @@ -314,5 +327,5 @@ func (f *StrategyFinder) NewDialer(ctx context.Context, testDomains []string, co if len(parsedConfig.TLS) == 0 { return dnsDialer, nil } - return f.findTLS(testDomains, dnsDialer, parsedConfig.TLS) + return f.findTLS(ctx, testDomains, dnsDialer, parsedConfig.TLS) }