Skip to content

Commit

Permalink
Make config public
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna committed Oct 26, 2023
1 parent b9e9e41 commit 900c30b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
18 changes: 9 additions & 9 deletions transport/tls/stream_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,16 @@ func (d *StreamDialer) Dial(ctx context.Context, remoteAddr string) (transport.S
return conn, nil
}

// clientConfig encodes the parameters for a TLS client connection.
type clientConfig struct {
// ClientConfig encodes the parameters for a TLS client connection.
type ClientConfig struct {
ServerName string
CertificateName string
NextProtos []string
SessionCache tls.ClientSessionCache
}

// ToStdConfig creates a [tls.Config] based on the configured parameters.
func (cfg *clientConfig) ToStdConfig() *tls.Config {
func (cfg *ClientConfig) ToStdConfig() *tls.Config {
return &tls.Config{
ServerName: cfg.ServerName,
NextProtos: cfg.NextProtos,
Expand All @@ -106,7 +106,7 @@ func (cfg *clientConfig) ToStdConfig() *tls.Config {
}

// ClientOption allows configuring the parameters to be used for a client TLS connection.
type ClientOption func(host string, port int, config *clientConfig)
type ClientOption func(host string, port int, config *ClientConfig)

// WrapConn wraps a [transport.StreamConn] in a TLS connection.
func WrapConn(ctx context.Context, conn transport.StreamConn, remoteAdr string, options ...ClientOption) (transport.StreamConn, error) {
Expand All @@ -118,7 +118,7 @@ func WrapConn(ctx context.Context, conn transport.StreamConn, remoteAdr string,
if err != nil {
return nil, fmt.Errorf("could not resolve port: %w", err)
}
cfg := clientConfig{ServerName: host, CertificateName: host}
cfg := ClientConfig{ServerName: host, CertificateName: host}
for _, option := range options {
option(host, port, &cfg)
}
Expand All @@ -134,30 +134,30 @@ func WrapConn(ctx context.Context, conn transport.StreamConn, remoteAdr string,
// If absent, defaults to the dialed hostname.
// Note that this only changes what is sent in the SNI, not what host is used for certificate verification.
func WithSNI(hostName string) ClientOption {
return func(_ string, _ int, config *clientConfig) {
return func(_ string, _ int, config *ClientConfig) {
config.ServerName = hostName
}
}

// WithALPN sets the protocol name list for [Application-Layer Protocol Negotiation](https://datatracker.ietf.org/doc/html/rfc7301) (ALPN).
// The list of protocol IDs can be found in [IANA's registry](https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids).
func WithALPN(protocolNameList []string) ClientOption {
return func(_ string, _ int, config *clientConfig) {
return func(_ string, _ int, config *ClientConfig) {
config.NextProtos = protocolNameList
}
}

// WithSessionCache sets the [tls.ClientSessionCache] to enable session resumption of TLS connections.
func WithSessionCache(sessionCache tls.ClientSessionCache) ClientOption {
return func(_ string, _ int, config *clientConfig) {
return func(_ string, _ int, config *ClientConfig) {
config.SessionCache = sessionCache
}
}

// WithCertificateName sets the hostname to be used for the certificate cerification.
// If absent, defaults to the dialed hostname.
func WithCertificateName(hostname string) ClientOption {
return func(_ string, _ int, config *clientConfig) {
return func(_ string, _ int, config *ClientConfig) {
config.CertificateName = hostname
}
}
4 changes: 2 additions & 2 deletions transport/tls/stream_dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ func TestAllCustom(t *testing.T) {
}

func TestWithSNI(t *testing.T) {
var cfg clientConfig
var cfg ClientConfig
WithSNI("example.com")("", 0, &cfg)
require.Equal(t, "example.com", cfg.ServerName)
}

func TestWithALPN(t *testing.T) {
var cfg clientConfig
var cfg ClientConfig
WithALPN([]string{"h2", "http/1.1"})("", 0, &cfg)
require.Equal(t, []string{"h2", "http/1.1"}, cfg.NextProtos)
}

0 comments on commit 900c30b

Please sign in to comment.