Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exposed TLS SNI extension #40

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package openssl

// #cgo pkg-config: libssl
// #cgo pkg-config: libssl libcrypto
// #cgo windows CFLAGS: -DWIN32_LEAN_AND_MEAN
// #cgo darwin CFLAGS: -Wno-deprecated-declarations
import "C"
54 changes: 54 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build cgo
// +build cgo

package openssl
Expand All @@ -31,6 +32,9 @@ package openssl
// const char * SSL_get_cipher_name_not_a_macro(const SSL *ssl) {
// return SSL_get_cipher_name(ssl);
// }
// int SSL_version_not_a_macro(const SSL *ssl) {
// return SSL_version(ssl);
// }
import "C"

import (
Expand Down Expand Up @@ -476,6 +480,43 @@ func (c *Conn) Read(b []byte) (n int, err error) {
return 0, err
}

func (c *Conn) peek(b []byte) (int, func() error) {
if len(b) == 0 {
return 0, nil
}
c.mtx.Lock()
defer c.mtx.Unlock()
if c.is_shutdown {
return 0, func() error { return io.EOF }
}
runtime.LockOSThread()
defer runtime.UnlockOSThread()
rv, errno := C.SSL_peek(c.ssl, unsafe.Pointer(&b[0]), C.int(len(b)))
if rv > 0 {
return int(rv), nil
}
return 0, c.getErrorHandler(rv, errno)
}

func (c *Conn) Peek(b []byte) (n int, err error) {
if len(b) == 0 {
return 0, nil
}
err = tryAgain
for err == tryAgain {
n, errcb := c.peek(b)
err = c.handleError(errcb)
if err == nil {
go c.flushOutputBuffer()
return n, nil
}
if err == io.ErrUnexpectedEOF {
err = io.EOF
}
}
return 0, err
}

func (c *Conn) write(b []byte) (int, func() error) {
if len(b) == 0 {
return 0, nil
Expand Down Expand Up @@ -548,6 +589,11 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}

func (c *Conn) SetCtx(ctx *Ctx) {
c.ctx = ctx
C.SSL_set_SSL_CTX(c.ssl, ctx.ctx)
}

func (c *Conn) UnderlyingConn() net.Conn {
return c.conn
}
Expand All @@ -566,3 +612,11 @@ func (c *Conn) SetTlsExtHostName(name string) error {
func (c *Conn) VerifyResult() VerifyResult {
return VerifyResult(C.SSL_get_verify_result(c.ssl))
}

func (c *Conn) GetServerName() string {
return C.GoString(C.SSL_get_servername(c.ssl, C.TLSEXT_NAMETYPE_host_name))
}

func (c *Conn) Version() int {
return int(C.SSL_version_not_a_macro(c.ssl))
}
78 changes: 73 additions & 5 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,40 @@ static long SSL_CTX_set_tmp_ecdh_not_a_macro(SSL_CTX* ctx, EC_KEY *key) {
return SSL_CTX_set_tmp_ecdh(ctx, key);
}

static long SSL_CTX_set_tlsext_servername_callback_not_a_macro(SSL_CTX* ctx, void (*fp)()) {
return SSL_CTX_set_tlsext_servername_callback(ctx, fp);
}

typedef struct TlsServernameData {
void *go_ctx;
SSL_CTX *ctx;
void *arg;
} TlsServernameData;

static TlsServernameData* new_TlsServernameData() {
return calloc(1, sizeof(TlsServernameData));
}

//UNUSED: openssl doesn't have a way to unset SNI callback or arg. So we just leak whatever
//the function above allocates
//static void del_TlsServernameData(TlsServernameData *tsd) {
// free(tds);
//}

extern int callServerNameCb(SSL* ssl, int ad, void* arg);

static int call_go_servername(SSL* ssl, int ad, void* arg) {
return callServerNameCb(ssl, ad, arg);
}

static int servername_gateway(TlsServernameData* cw) {
SSL_CTX* ctx = cw->ctx;
//TODO: figure out what to do with return codes. The first isn't 0
SSL_CTX_set_tlsext_servername_callback(ctx, call_go_servername);
SSL_CTX_set_tlsext_servername_arg(ctx, cw);
return 0;
}

#ifndef SSL_MODE_RELEASE_BUFFERS
#define SSL_MODE_RELEASE_BUFFERS 0
#endif
Expand Down Expand Up @@ -117,11 +151,13 @@ var (
)

type Ctx struct {
ctx *C.SSL_CTX
cert *Certificate
chain []*Certificate
key PrivateKey
verify_cb VerifyCallback
ctx *C.SSL_CTX
cert *Certificate
chain []*Certificate
key PrivateKey
verify_cb VerifyCallback
servername_cb ServerNameCallback
ted *C.TlsServernameData
}

//export get_ssl_ctx_idx
Expand Down Expand Up @@ -605,3 +641,35 @@ func (c *Ctx) SessSetCacheSize(t int) int {
func (c *Ctx) SessGetCacheSize() int {
return int(C.SSL_CTX_sess_get_cache_size_not_a_macro(c.ctx))
}

// Set SSL_CTX_set_tlsext_servername_callback
// https://www.openssl.org/docs/manmaster/ssl/???
type ServerNameCallback func(ssl Conn, ad int, arg unsafe.Pointer) int

//export callServerNameCb
func callServerNameCb(ssl *C.SSL, ad C.int, arg unsafe.Pointer) C.int {
var ted *C.TlsServernameData = (*C.TlsServernameData)(arg)
goCtx := (*Ctx)(ted.go_ctx)

//setup a dummy Conn so we can associate a SSL_CTX from user callback
conn := Conn{
ssl: ssl,
ctx: goCtx,
}
ret := goCtx.servername_cb(conn, int(ad), ted.arg)
return C.int(ret)
}

func (c *Ctx) SetTlsExtServerNameCallback(cb func(ssl Conn, ad int, arg unsafe.Pointer) int,
arg unsafe.Pointer) int {
c.servername_cb = cb
ted := C.new_TlsServernameData()
if ted == nil {
return 1
}
ted.go_ctx = unsafe.Pointer(c)
ted.ctx = c.ctx
ted.arg = arg
c.ted = ted
return int(C.servername_gateway(c.ted))
}
120 changes: 120 additions & 0 deletions tls_ext_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright (C) 2014 Space Monkey, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package openssl

import (
"bytes"
"io"
"sync"
"testing"
"unsafe"
)

var gFoundServerName bool = false
var gServerName string
var gCallbackData string = "some callback data"

func passThroughServername() func(ssl Conn, ad int, arg unsafe.Pointer) int {
x := func(ssl Conn, ad int, arg unsafe.Pointer) int {
cbData := (*string)(arg)
if *cbData != gCallbackData { //we should getthe callback data we set on the CTX
return 1
}
name := ssl.GetServerName()
if name == gServerName {
gFoundServerName = true
//here we'd normally do soemthing like get a CTX for the specific server name and
//set it on the conn.
} else {
gFoundServerName = false
}
return 0
}
return x
}

func TestTLSExtSNI(t *testing.T) {
//setup SNI On the CTX
server_conn, client_conn := NetPipe(t)
defer server_conn.Close()
defer client_conn.Close()

server, client := OpenSSLConstructor(t, server_conn, client_conn)
cconn := client.(*Conn)
sconn := server.(*Conn)
ctx := (*sconn).ctx
//setup SNI On the CTX
rc := ctx.SetTlsExtServerNameCallback(passThroughServername(), unsafe.Pointer(&gCallbackData))
if rc != 0 {
t.Fatal("Expected 0 from ctx.SetTlsExtServerNameCallback, but got %d", rc)
}
data := "first test string\n"
host := "test-host"

var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
gServerName = host
err := cconn.SetTlsExtHostName(host)
if err != nil {
t.Fatal(err)
}

err = client.Handshake()
if err != nil {
t.Fatal(err)
}

_, err = io.Copy(client, bytes.NewReader([]byte(data)))
if err != nil {
t.Fatal(err)
}

err = client.Close()
if err != nil {
t.Fatal(err)
}
}()
go func() {
defer wg.Done()

err := server.Handshake()
if err != nil {
t.Fatal(err)
}

buf := bytes.NewBuffer(make([]byte, 0, len(data)))
_, err = io.CopyN(buf, server, int64(len(data)))
if err != nil {
t.Fatal(err)
}
if string(buf.Bytes()) != data {
t.Fatal("mismatched data")
}

err = server.Close()
if err != nil {
t.Fatal(err)
}
}()
wg.Wait()
if gFoundServerName == false {
t.Fatal("Expected gFoundServerName to be set to true")
}
if gServerName != host {
t.Fatal("Expected gServerName to be '%s', but it was '%s'", host, gServerName)
}
}