Skip to content

Commit

Permalink
Update with x/net
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna committed Jan 29, 2025
1 parent 5f6d095 commit cc5b0f7
Showing 1 changed file with 49 additions and 8 deletions.
57 changes: 49 additions & 8 deletions x/examples/ws2endpoint/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ package main

import (
"context"
"errors"
"flag"
"io"
"log"
"log/slog"
"net"
"net/http"
"os"
Expand All @@ -41,6 +43,43 @@ func (c *natConn) Write(b []byte) (int, error) {
return c.Conn.Write(b)
}

func websocketToConn(targetConn io.Writer, clientConn *websocket.Conn) {
var buf []byte
for {
err := websocket.Message.Receive(clientConn, &buf)
if err != nil {
if !errors.Is(err, io.EOF) {
slog.Warn("failed to read from client", "error", err)
}
break
}
_, err = targetConn.Write(buf)
if err != nil {
slog.Warn("failed to write to target", "error", err)
break
}
}
}

func connToWebsocket(clientConn *websocket.Conn, targetConn io.Reader) {
// TODO: use a buffer pool
buf := make([]byte, 64*1024)
for {
n, err := targetConn.Read(buf)
if err != nil {
if !errors.Is(err, io.EOF) {
slog.Warn("failed to read from target", "error", err)
}
break
}
err = websocket.Message.Send(clientConn, buf[:n])
if err != nil {
slog.Warn("failed to write to client", "error", err)
break
}
}
}

func main() {
listenFlag := flag.String("listen", "localhost:8080", "Local proxy address to listen on")
transportFlag := flag.String("transport", "", "Transport config")
Expand Down Expand Up @@ -71,19 +110,20 @@ func main() {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("Got stream request: %v\n", r)
handler := func(wsConn *websocket.Conn) {
defer wsConn.Close()
targetConn, err := endpoint.ConnectStream(r.Context())
if err != nil {
log.Printf("Failed to upgrade: %v\n", err)
w.WriteHeader(http.StatusBadGateway)
return
}
defer targetConn.Close()
// Relay from client to target.
go func() {
io.Copy(targetConn, wsConn)
targetConn.CloseWrite()
defer targetConn.CloseWrite()
websocketToConn(targetConn, wsConn)
}()
io.Copy(wsConn, targetConn)
wsConn.Close()
connToWebsocket(wsConn, targetConn)
}
websocket.Server{Handler: handler}.ServeHTTP(w, r)
})
Expand All @@ -98,6 +138,7 @@ func main() {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("Got packet request: %v\n", r)
handler := func(wsConn *websocket.Conn) {
defer wsConn.Close()
targetConn, err := endpoint.ConnectPacket(r.Context())
if err != nil {
log.Printf("Failed to upgrade: %v\n", err)
Expand All @@ -107,12 +148,12 @@ func main() {
// Expire connetion after 5 minutes of idle time, as per
// https://datatracker.ietf.org/doc/html/rfc4787#section-4.3
targetConn = &natConn{targetConn, 5 * time.Minute}
defer targetConn.Close()
// Relay from client to target.
go func() {
io.Copy(targetConn, wsConn)
targetConn.Close()
websocketToConn(targetConn, wsConn)
}()
io.Copy(wsConn, targetConn)
wsConn.Close()
connToWebsocket(wsConn, targetConn)
}
websocket.Server{Handler: handler}.ServeHTTP(w, r)
})
Expand Down

0 comments on commit cc5b0f7

Please sign in to comment.