Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Context #52

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Installing
----------

```
go get -u github.com/tidwall/redcon
go get -u github.com/tidwall/redcon/v2
```

Example
Expand Down Expand Up @@ -52,7 +52,7 @@ import (
"strings"
"sync"

"github.com/tidwall/redcon"
"github.com/tidwall/redcon/v2"
)

var addr = ":6380"
Expand Down
5 changes: 3 additions & 2 deletions example/clone.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package main

import (
"context"
"log"
"strings"
"sync"

"github.com/tidwall/redcon"
"github.com/tidwall/redcon/v2"
)

var addr = ":6380"
Expand All @@ -16,7 +17,7 @@ func main() {
var ps redcon.PubSub
go log.Printf("started server at %s", addr)

err := redcon.ListenAndServe(addr,
err := redcon.ListenAndServe(context.Background(), addr,
func(conn redcon.Conn, cmd redcon.Command) {
switch strings.ToLower(string(cmd.Args[0])) {
default:
Expand Down
5 changes: 3 additions & 2 deletions example/mux/clone.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package main

import (
"context"
"log"

"github.com/tidwall/redcon"
"github.com/tidwall/redcon/v2"
)

var addr = ":6380"
Expand All @@ -21,7 +22,7 @@ func main() {
mux.HandleFunc("get", handler.get)
mux.HandleFunc("del", handler.delete)

err := redcon.ListenAndServe(addr,
err := redcon.ListenAndServe(context.Background(), addr,
mux.ServeRESP,
func(conn redcon.Conn) bool {
// use this function to accept or deny the connection.
Expand Down
2 changes: 1 addition & 1 deletion example/mux/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"log"
"sync"

"github.com/tidwall/redcon"
"github.com/tidwall/redcon/v2"
)

type Handler struct {
Expand Down
5 changes: 3 additions & 2 deletions example/tls/clone.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package main

import (
"context"
"crypto/tls"
"log"
"strings"
"sync"

"github.com/tidwall/redcon"
"github.com/tidwall/redcon/v2"
)

const serverKey = `-----BEGIN EC PARAMETERS-----
Expand Down Expand Up @@ -47,7 +48,7 @@ func main() {
var items = make(map[string][]byte)

go log.Printf("started server at %s", addr)
err = redcon.ListenAndServeTLS(addr,
err = redcon.ListenAndServeTLS(context.Background(), addr,
func(conn redcon.Conn, cmd redcon.Command) {
switch strings.ToLower(string(cmd.Args[0])) {
default:
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module github.com/tidwall/redcon
module github.com/tidwall/redcon/v2

go 1.15

Expand Down
55 changes: 42 additions & 13 deletions redcon.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package redcon

import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
Expand All @@ -23,6 +24,7 @@ var (
errDetached = errors.New("detached")
errIncompleteCommand = errors.New("incomplete command")
errTooMuchData = errors.New("too much data")
errContextDone = errors.New("context done")
)

type errProtocol struct {
Expand Down Expand Up @@ -114,27 +116,31 @@ type Conn interface {
}

// NewServer returns a new Redcon server configured on "tcp" network net.
func NewServer(addr string,
func NewServer(
ctx context.Context,
addr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
) *Server {
return NewServerNetwork("tcp", addr, handler, accept, closed)
return NewServerNetwork(ctx, "tcp", addr, handler, accept, closed)
}

// NewServerTLS returns a new Redcon TLS server configured on "tcp" network net.
func NewServerTLS(addr string,
func NewServerTLS(ctx context.Context,
addr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
config *tls.Config,
) *TLSServer {
return NewServerNetworkTLS("tcp", addr, handler, accept, closed, config)
return NewServerNetworkTLS(ctx, "tcp", addr, handler, accept, closed, config)
}

// NewServerNetwork returns a new Redcon server. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func NewServerNetwork(
ctx context.Context,
net, laddr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
Expand All @@ -144,6 +150,7 @@ func NewServerNetwork(
panic("handler is nil")
}
s := &Server{
ctx: ctx,
net: net,
laddr: laddr,
handler: handler,
Expand All @@ -157,6 +164,7 @@ func NewServerNetwork(
// NewServerNetworkTLS returns a new TLS Redcon server. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func NewServerNetworkTLS(
ctx context.Context,
net, laddr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
Expand All @@ -167,6 +175,7 @@ func NewServerNetworkTLS(
panic("handler is nil")
}
s := Server{
ctx: ctx,
net: net,
laddr: laddr,
handler: handler,
Expand Down Expand Up @@ -241,50 +250,58 @@ func Serve(ln net.Listener,
}

// ListenAndServe creates a new server and binds to addr configured on "tcp" network net.
func ListenAndServe(addr string,
func ListenAndServe(
ctx context.Context,
addr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
) error {
return ListenAndServeNetwork("tcp", addr, handler, accept, closed)
return ListenAndServeNetwork(ctx, "tcp", addr, handler, accept, closed)
}

// ListenAndServeTLS creates a new TLS server and binds to addr configured on "tcp" network net.
func ListenAndServeTLS(addr string,
func ListenAndServeTLS(
ctx context.Context,
addr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
config *tls.Config,
) error {
return ListenAndServeNetworkTLS("tcp", addr, handler, accept, closed, config)
return ListenAndServeNetworkTLS(ctx, "tcp", addr, handler, accept, closed, config)
}

// ListenAndServeNetwork creates a new server and binds to addr. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func ListenAndServeNetwork(
ctx context.Context,
net, laddr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
) error {
return NewServerNetwork(net, laddr, handler, accept, closed).ListenAndServe()
return NewServerNetwork(ctx, net, laddr, handler, accept, closed).ListenAndServe()
}

// ListenAndServeNetworkTLS creates a new TLS server and binds to addr. The network net must be
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
func ListenAndServeNetworkTLS(
ctx context.Context,
net, laddr string,
handler func(conn Conn, cmd Command),
accept func(conn Conn) bool,
closed func(conn Conn, err error),
config *tls.Config,
) error {
return NewServerNetworkTLS(net, laddr, handler, accept, closed, config).ListenAndServe()
return NewServerNetworkTLS(ctx, net, laddr, handler, accept, closed, config).ListenAndServe()
}

// ListenServeAndSignal serves incoming connections and passes nil or error
// when listening. signal can be nil.
func (s *Server) ListenServeAndSignal(signal chan error) error {
//var lc net.ListenConfig
//ln, err := lc.Listen(s.ctx, s.net, s.laddr)
ln, err := net.Listen(s.net, s.laddr)
if err != nil {
if signal != nil {
Expand Down Expand Up @@ -336,13 +353,26 @@ func serve(s *Server) error {
s.conns = nil
}()
}()

go func() {
select {
case <-s.ctx.Done():
s.Close()
}
}()

for {
lnconn, err := s.ln.Accept()
if err != nil {
s.mu.Lock()
done := s.done
s.mu.Unlock()
if done {
select {
case <-s.ctx.Done():
return errContextDone
default:
}
return nil
}
if s.AcceptError != nil {
Expand Down Expand Up @@ -547,6 +577,7 @@ type Command struct {

// Server defines a server for clients for managing client connections.
type Server struct {
ctx context.Context
mu sync.Mutex
net string
laddr string
Expand Down Expand Up @@ -1358,14 +1389,12 @@ func (ps *PubSub) unsubscribe(conn Conn, pattern, all bool, channel string) {
}
} else {
// unsubscribe single channel from (p)subscribe.
var entry *pubSubEntry
for ient := range sconn.entries {
if ient.pattern == pattern && ient.channel == channel {
removeEntry(entry)
removeEntry(ient)
break
}
}
removeEntry(entry)
}
sconn.dconn.Flush()
}
Expand Down
27 changes: 24 additions & 3 deletions redcon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package redcon
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -220,7 +221,8 @@ func TestServerUnix(t *testing.T) {
}

func testServerNetwork(t *testing.T, network, laddr string) {
s := NewServerNetwork(network, laddr,
ctx := context.Background()
s := NewServerNetwork(ctx, network, laddr,
func(conn Conn, cmd Command) {
switch strings.ToLower(string(cmd.Args[0])) {
default:
Expand Down Expand Up @@ -261,7 +263,7 @@ func testServerNetwork(t *testing.T, network, laddr string) {
}
go func() {
time.Sleep(time.Second / 4)
if err := ListenAndServeNetwork(network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil {
if err := ListenAndServeNetwork(ctx, network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil {
panic("expected an error, should not be able to listen on the same port")
}
time.Sleep(time.Second / 4)
Expand Down Expand Up @@ -560,6 +562,7 @@ func TestParse(t *testing.T) {
func TestPubSub(t *testing.T) {
addr := ":12346"
done := make(chan bool)
ctx := context.Background()
go func() {
var ps PubSub
go func() {
Expand Down Expand Up @@ -593,7 +596,7 @@ func TestPubSub(t *testing.T) {
ps.Publish(channel, message)
}
}()
panic(ListenAndServe(addr, func(conn Conn, cmd Command) {
panic(ListenAndServe(ctx, addr, func(conn Conn, cmd Command) {
switch strings.ToLower(string(cmd.Args[0])) {
default:
conn.WriteError("ERR unknown command '" +
Expand Down Expand Up @@ -738,3 +741,21 @@ func TestPubSub(t *testing.T) {
// stop the timeout
final <- true
}

func TestContextDone(t *testing.T) {
var err error
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(1)
s := NewServerNetwork(ctx, "tcp", ":12345", func(conn Conn, cmd Command) {}, nil, nil)
go func() {
err = s.ListenAndServe()
wg.Done()
}()
time.Sleep(1 * time.Second)
cancel()
wg.Wait()
if err != errContextDone {
t.Fatalf("expected %v but found %v", errContextDone, err)
}
}