From 3c29ae5c7f18e5e133fc63be0eb2f9f95a7e169c Mon Sep 17 00:00:00 2001 From: Austin DeNoble Date: Thu, 1 Aug 2024 18:25:48 -0400 Subject: [PATCH] update host normalization to be a bit smarter about evaluating the incoming uri, update unit tests --- pinecone/index_connection.go | 26 +++++++++++++++++--------- pinecone/index_connection_test.go | 2 +- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/pinecone/index_connection.go b/pinecone/index_connection.go index cef58b9..32351cc 100644 --- a/pinecone/index_connection.go +++ b/pinecone/index_connection.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "log" + "net/url" "strings" "github.com/pinecone-io/go-pinecone/internal/gen/data" @@ -35,16 +36,20 @@ type newIndexParameters struct { } func newIndexConnection(in newIndexParameters, dialOpts ...grpc.DialOption) (*IndexConnection, error) { - config := &tls.Config{} target := normalizeHost(in.host) // configure default gRPC DialOptions grpcOptions := []grpc.DialOption{ - grpc.WithTransportCredentials(credentials.NewTLS(config)), grpc.WithAuthority(target), grpc.WithUserAgent(useragent.BuildUserAgentGRPC(in.sourceTag)), } + // if the target includes an http:// address, don't include TLS + if strings.HasPrefix(target, "http://") { + config := &tls.Config{} + grpcOptions = append(grpcOptions, grpc.WithTransportCredentials(credentials.NewTLS(config))) + } + // if we have user-provided dialOpts, append them to the defaults here dialOpts = append(grpcOptions, dialOpts...) @@ -1094,18 +1099,21 @@ func sparseValToGrpc(sv *SparseValues) *data.SparseValues { } func normalizeHost(host string) string { - hasPort := strings.Contains(host, ":") - - // remove https:// from the host - host = strings.TrimPrefix(host, "https://") + parsedHost, err := url.Parse(host) + if err != nil { + log.Default().Printf("Failed to parse host %s: %v", host, err) + return host + } - // if plaintext without a port, strip http:// as well - if !hasPort { + // if https:// or http:// without a port, strip the scheme + if parsedHost.Scheme == "https" { + host = strings.TrimPrefix(host, "https://") + } else if parsedHost.Scheme == "http" && parsedHost.Port() == "" { host = strings.TrimPrefix(host, "http://") } // if a port was provided leave it, otherwise we append :443 - if !hasPort { + if parsedHost.Port() == "" { host = host + ":443" } diff --git a/pinecone/index_connection_test.go b/pinecone/index_connection_test.go index 0bc0b94..8f4c825 100644 --- a/pinecone/index_connection_test.go +++ b/pinecone/index_connection_test.go @@ -991,7 +991,7 @@ func TestToUsageUnit(t *testing.T) { } } -func TestNormalizeHost(t *testing.T) { +func TestNormalizeHostUnit(t *testing.T) { tests := []struct { name string host string