Skip to content

Commit

Permalink
Add wush cp for single file transfers without rsync
Browse files Browse the repository at this point in the history
  • Loading branch information
coadler committed Aug 26, 2024
1 parent 32f2838 commit a0d2cc9
Show file tree
Hide file tree
Showing 9 changed files with 393 additions and 13 deletions.
194 changes: 194 additions & 0 deletions cmd/wush/cp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package main

import (
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httputil"
"net/netip"
"os"
"path/filepath"

"github.com/charmbracelet/huh"
"github.com/coder/serpent"
"github.com/coder/wush/cliui"
"github.com/coder/wush/overlay"
"github.com/coder/wush/tsserver"
"github.com/schollz/progressbar/v3"
"tailscale.com/net/netns"
)

func cpCmd() *serpent.Command {
var (
authID string
waitP2P bool
stunAddrOverride string
stunAddrOverrideIP netip.Addr
)
return &serpent.Command{
Use: "cp <file>",
Short: "Transfer files.",
Long: "Transfer files to a " + cliui.Code("wush") + " peer. ",
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
logF := func(str string, args ...any) {
fmt.Fprintf(inv.Stderr, str+"\n", args...)
}

if authID == "" {
err := huh.NewInput().
Title("Enter your Auth ID:").
Value(&authID).
Run()
if err != nil {
return fmt.Errorf("get auth id: %w", err)
}
}

dm, err := tsserver.DERPMapTailscale(ctx)
if err != nil {
return err
}

if stunAddrOverride != "" {
stunAddrOverrideIP, err = netip.ParseAddr(stunAddrOverride)
if err != nil {
return fmt.Errorf("parse stun addr override: %w", err)
}
}

send := overlay.NewSendOverlay(logger, dm)
send.STUNIPOverride = stunAddrOverrideIP

err = send.Auth.Parse(authID)
if err != nil {
return fmt.Errorf("parse auth key: %w", err)
}

logF("Auth information:")
stunStr := send.Auth.ReceiverStunAddr.String()
if !send.Auth.ReceiverStunAddr.IsValid() {
stunStr = "Disabled"
}
logF("\t> Server overlay STUN address: %s", cliui.Code(stunStr))
derpStr := "Disabled"
if send.Auth.ReceiverDERPRegionID > 0 {
derpStr = dm.Regions[int(send.Auth.ReceiverDERPRegionID)].RegionName
}
logF("\t> Server overlay DERP home: %s", cliui.Code(derpStr))
logF("\t> Server overlay public key: %s", cliui.Code(send.Auth.ReceiverPublicKey.ShortString()))
logF("\t> Server overlay auth key: %s", cliui.Code(send.Auth.OverlayPrivateKey.Public().ShortString()))

s, err := tsserver.NewServer(ctx, logger, send)
if err != nil {
return err
}

if send.Auth.ReceiverDERPRegionID != 0 {
go send.ListenOverlayDERP(ctx)
} else if send.Auth.ReceiverStunAddr.IsValid() {
go send.ListenOverlaySTUN(ctx)
} else {
return errors.New("auth key provided neither DERP nor STUN")
}

go s.ListenAndServe(ctx)
netns.SetDialerOverride(s.Dialer())
ts, err := newTSNet("send")
if err != nil {
return err
}
ts.Logf = func(string, ...any) {}
ts.UserLogf = func(string, ...any) {}

logF("Bringing Wireguard up..")
ts.Up(ctx)
logF("Wireguard is ready!")

lc, err := ts.LocalClient()
if err != nil {
return err
}

ip, err := waitUntilHasPeerHasIP(ctx, logF, lc)
if err != nil {
return err
}

if waitP2P {
err := waitUntilHasP2P(ctx, logF, lc)
if err != nil {
return err
}
}

fiPath := inv.Args[0]
fiName := filepath.Base(inv.Args[0])

fi, err := os.Open(fiPath)
if err != nil {
return err
}
defer fi.Close()

fiStat, err := fi.Stat()
if err != nil {
return err
}

bar := progressbar.DefaultBytes(
fiStat.Size(),
fmt.Sprintf("Uploading %q", fiPath),
)
barReader := progressbar.NewReader(fi, bar)

hc := ts.HTTPClient()
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://%s:4444/%s", ip.String(), fiName), &barReader)
if err != nil {
return err
}
req.ContentLength = fiStat.Size()

res, err := hc.Do(req)
if err != nil {
return err
}
defer res.Body.Close()

out, err := httputil.DumpResponse(res, true)
if err != nil {
return err
}
bar.Close()
fmt.Println(string(out))

return nil
},
Options: []serpent.Option{
{
Flag: "auth-id",
Env: "WUSH_AUTH_ID",
Description: "The auth id returned by " + cliui.Code("wush receive") + ". If not provided, it will be asked for on startup.",
Default: "",
Value: serpent.StringOf(&authID),
},
{
Flag: "stun-ip-override",
Default: "",
Value: serpent.StringOf(&stunAddrOverride),
},
{
Flag: "wait-p2p",
Description: "Waits for the connection to be p2p.",
Default: "false",
Value: serpent.BoolOf(&waitP2P),
},
},
}
}
1 change: 1 addition & 0 deletions cmd/wush/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ func main() {
sshCmd(),
receiveCmd(),
rsyncCmd(),
cpCmd(),
},
Options: []serpent.Option{
{
Expand Down
46 changes: 44 additions & 2 deletions cmd/wush/receive.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import (
"fmt"
"io"
"log/slog"
"net/http"
"os"
"strings"
"time"

"github.com/prometheus/client_golang/prometheus"
"github.com/schollz/progressbar/v3"
"github.com/spf13/afero"
"golang.org/x/xerrors"
"tailscale.com/ipn/store"
Expand All @@ -31,6 +34,7 @@ func receiveCmd() *serpent.Command {
return &serpent.Command{
Use: "receive",
Aliases: []string{"host"},
Short: "Run the wush server.",
Long: "Runs the wush server. Allows other wush CLIs to connect to this computer.",
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
Expand Down Expand Up @@ -95,19 +99,57 @@ func receiveCmd() *serpent.Command {
return err
}

ls, err := ts.Listen("tcp", ":3")
sshListener, err := ts.Listen("tcp", ":3")
if err != nil {
return err
}

go func() {
fmt.Println(cliui.Timestamp(time.Now()), "SSH server listening")
err := sshSrv.Serve(ls)
err := sshSrv.Serve(sshListener)
if err != nil {
logger.Info("ssh server exited", "err", err)
}
}()

cpListener, err := ts.Listen("tcp", ":4444")
if err != nil {
return err
}

go http.Serve(cpListener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
return
}

fiName := strings.TrimPrefix(r.URL.Path, "/")
defer r.Body.Close()

fi, err := os.OpenFile(fiName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

bar := progressbar.DefaultBytes(
r.ContentLength,
fmt.Sprintf("Downloading %q", fiName),
)
_, err = io.Copy(io.MultiWriter(fi, bar), r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
fi.Close()
bar.Close()

w.WriteHeader(http.StatusOK)
w.Write([]byte(fmt.Sprintf("File %q written", fiName)))
fmt.Printf("Received file %s from %s\n", fiName, r.RemoteAddr)
}))

ctx, ctxCancel := inv.SignalNotifyContext(ctx, os.Interrupt)
defer ctxCancel()

Expand Down
17 changes: 10 additions & 7 deletions cmd/wush/rsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ func rsyncCmd() *serpent.Command {
sshStdio bool
)
return &serpent.Command{
Use: "rsync [flags] -- [rsync args]",
Use: "rsync [flags] -- [rsync args]",
Short: "Transfer files over rsync.",
Long: "Runs rsync to transfer files to a " + cliui.Code("wush") + " peer. " +
"Use " + cliui.Code("wush receive") + " on the computer you would like to connect to." +
"\n\n" +
Expand Down Expand Up @@ -81,11 +82,13 @@ func rsyncCmd() *serpent.Command {

progPath := os.Args[0]
args := []string{
"-e", fmt.Sprintf("%s --auth-id %s --stdio --", progPath, send.Auth.AuthKey()),
"-c",
fmt.Sprintf(`rsync -e "%s ssh --auth-key %s --stdio --" %s`,
progPath, send.Auth.AuthKey(), strings.Join(inv.Args, " "),
),
}
args = append(args, inv.Args...)
fmt.Println("Running: rsync", strings.Join(args, " "))
cmd := exec.CommandContext(ctx, "rsync", args...)
cmd := exec.CommandContext(ctx, "sh", args...)
cmd.Stdin = inv.Stdin
cmd.Stdout = inv.Stdout
cmd.Stderr = inv.Stderr
Expand All @@ -94,9 +97,9 @@ func rsyncCmd() *serpent.Command {
},
Options: []serpent.Option{
{
Flag: "auth-id",
Env: "WUSH_AUTH_ID",
Description: "The auth id returned by " + cliui.Code("wush receive") + ". If not provided, it will be asked for on startup.",
Flag: "auth-key",
Env: "WUSH_AUTH_KEY",
Description: "The auth key returned by " + cliui.Code("wush receive") + ". If not provided, it will be asked for on startup.",
Default: "",
Value: serpent.StringOf(&authID),
},
Expand Down
6 changes: 3 additions & 3 deletions cmd/wush/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func sshCmd() *serpent.Command {
return &serpent.Command{
Use: "ssh",
Aliases: []string{},
Short: "Open a shell.",
Long: "Opens an SSH connection to a " + cliui.Code("wush") + " peer. " +
"Use " + cliui.Code("wush receive") + " on the computer you would like to connect to.",
Handler: func(inv *serpent.Invocation) error {

ctx := inv.Context()
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
logF := func(str string, args ...any) {
Expand Down Expand Up @@ -135,8 +135,8 @@ func sshCmd() *serpent.Command {
},
Options: []serpent.Option{
{
Flag: "auth",
Env: "WUSH_AUTH",
Flag: "auth-key",
Env: "WUSH_AUTH_KEY",
Description: "The auth key returned by " + cliui.Code("wush receive") + ". If not provided, it will be asked for on startup.",
Default: "",
Value: serpent.StringOf(&authKey),
Expand Down
2 changes: 1 addition & 1 deletion cmd/wush/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
func versionCmd() *serpent.Command {
cmd := &serpent.Command{
Use: "version",
Short: "Show wush version",
Short: "Show wush version.",
Handler: func(inv *serpent.Invocation) error {
bi := getBuildInfo()
fmt.Printf("Wush %s-%s %s\n", bi.version, bi.commitHash[:7], bi.commitTime.Format(time.RFC1123))
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ require (
github.com/pion/stun/v3 v3.0.0
github.com/prometheus/client_golang v1.19.1
github.com/puzpuzpuz/xsync/v3 v3.2.0
github.com/schollz/progressbar/v3 v3.14.6
github.com/spf13/afero v1.11.0
github.com/valyala/fasthttp v1.55.0
go4.org/mem v0.0.0-20220726221520-4f986261bf13
Expand Down Expand Up @@ -139,6 +140,7 @@ require (
github.com/mdlayher/sdnotify v1.0.0 // indirect
github.com/mdlayher/socket v0.5.0 // indirect
github.com/miekg/dns v1.1.58 // indirect
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
github.com/mitchellh/copystructure v1.2.0 // indirect
github.com/mitchellh/go-ps v1.0.0 // indirect
github.com/mitchellh/go-testing-interface v1.14.1 // indirect
Expand Down
Loading

0 comments on commit a0d2cc9

Please sign in to comment.