From b43192075140102d885884cf0814ba7639c77c6f Mon Sep 17 00:00:00 2001 From: Rumen Nikiforov Date: Sun, 22 Oct 2023 15:20:14 +0300 Subject: [PATCH] Reworked wireguard implementation Added multi-backend support Linux host only for now Mikrotik backend pending --- .env.dist | 4 + go.mod | 2 +- go.sum | 4 +- main.go | 34 +- pkg/api/config.go | 9 +- pkg/api/internal/model/adapter_foreign.go | 30 +- pkg/api/internal/model/adapter_peer.go | 4 +- .../internal/mutation/mutation_resolver.go | 24 +- pkg/api/internal/peer/peer_resolver.go | 14 +- pkg/api/internal/query/query_resolver.go | 10 +- pkg/api/internal/server/server_resolver.go | 12 +- pkg/api/router.go | 3 - pkg/config/config.go | 1 + pkg/manage/service.go | 326 ++++++-- pkg/wg/foreign_peer.go | 16 - pkg/wg/foreign_server.go | 11 - pkg/wg/helper.go | 7 - pkg/wg/interface_linux.go | 270 ------- pkg/wg/interface_other.go | 29 - pkg/wg/service.go | 754 ------------------ pkg/wireguard/backend/backend.go | 14 + pkg/wireguard/backend/configure_options.go | 20 + pkg/wireguard/backend/device.go | 6 + .../backend}/foreign_interface.go | 2 +- pkg/wireguard/backend/foreign_server.go | 11 + pkg/wireguard/backend/interface.go | 7 + pkg/wireguard/backend/interface_options.go | 19 + pkg/wireguard/backend/interface_stats.go | 27 + pkg/wireguard/backend/peer.go | 15 + pkg/wireguard/backend/peer_options.go | 23 + pkg/{wg => wireguard/backend}/peer_stats.go | 2 +- pkg/wireguard/backend/wireguard.go | 10 + pkg/wireguard/backend/wireguard_options.go | 25 + pkg/wireguard/linux/adapter.go | 107 +++ pkg/wireguard/linux/backend_linux.go | 525 ++++++++++++ pkg/wireguard/linux/backend_other.go | 13 + pkg/{wg => wireguard/linux}/wg_link.go | 4 +- pkg/wireguard/service.go | 60 ++ 38 files changed, 1243 insertions(+), 1211 deletions(-) delete mode 100644 pkg/wg/foreign_peer.go delete mode 100644 pkg/wg/foreign_server.go delete mode 100644 pkg/wg/helper.go delete mode 100644 pkg/wg/interface_linux.go delete mode 100644 pkg/wg/interface_other.go delete mode 100644 pkg/wg/service.go create mode 100644 pkg/wireguard/backend/backend.go create mode 100644 pkg/wireguard/backend/configure_options.go create mode 100644 pkg/wireguard/backend/device.go rename pkg/{wg => wireguard/backend}/foreign_interface.go (86%) create mode 100644 pkg/wireguard/backend/foreign_server.go create mode 100644 pkg/wireguard/backend/interface.go create mode 100644 pkg/wireguard/backend/interface_options.go create mode 100644 pkg/wireguard/backend/interface_stats.go create mode 100644 pkg/wireguard/backend/peer.go create mode 100644 pkg/wireguard/backend/peer_options.go rename pkg/{wg => wireguard/backend}/peer_stats.go (90%) create mode 100644 pkg/wireguard/backend/wireguard.go create mode 100644 pkg/wireguard/backend/wireguard_options.go create mode 100644 pkg/wireguard/linux/adapter.go create mode 100644 pkg/wireguard/linux/backend_linux.go create mode 100644 pkg/wireguard/linux/backend_other.go rename pkg/{wg => wireguard/linux}/wg_link.go (87%) create mode 100644 pkg/wireguard/service.go diff --git a/.env.dist b/.env.dist index 5c157d0..7ac5fa0 100644 --- a/.env.dist +++ b/.env.dist @@ -1,3 +1,7 @@ +# Wireguard interface backend +# Default: linux +WG_UI_BACKEND=linux + # The main database file path # Stores all users and wireguard servers configuration WG_UI_BOLT_DB_PATH=/var/lib/wg-ui/data.db diff --git a/go.mod b/go.mod index 16c90f2..666e972 100644 --- a/go.mod +++ b/go.mod @@ -45,7 +45,7 @@ require ( golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.14.0 // indirect - golang.zx2c4.com/wireguard v0.0.0-20231018191413-24ea13351eb7 // indirect + golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9eb58a6..c321e7b 100644 --- a/go.sum +++ b/go.sum @@ -102,8 +102,8 @@ golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.14.0 h1:jvNa2pY0M4r62jkRQ6RwEZZyPcymeL9XZMLBbV7U2nc= golang.org/x/tools v0.14.0/go.mod h1:uYBEerGOWcJyEORxN+Ek8+TT266gXkNlHdJBwexUsBg= -golang.zx2c4.com/wireguard v0.0.0-20231018191413-24ea13351eb7 h1:1+bHXA5s3p7saWQTFJKtQF7WzoU0HEvIe5iUovtpzhU= -golang.zx2c4.com/wireguard v0.0.0-20231018191413-24ea13351eb7/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= +golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb h1:c5tyN8sSp8jSDxdCCDXVOpJwYXXhmTkNMt+g0zTSOic= +golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/main.go b/main.go index 8c71cd5..e3162de 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "os/signal" + "strings" "syscall" "time" @@ -25,7 +26,9 @@ import ( "github.com/UnAfraid/wg-ui/pkg/server" "github.com/UnAfraid/wg-ui/pkg/subscription" "github.com/UnAfraid/wg-ui/pkg/user" - "github.com/UnAfraid/wg-ui/pkg/wg" + "github.com/UnAfraid/wg-ui/pkg/wireguard" + "github.com/UnAfraid/wg-ui/pkg/wireguard/backend" + "github.com/UnAfraid/wg-ui/pkg/wireguard/linux" ) const ( @@ -111,18 +114,36 @@ func main() { return } - wgService, err := wg.NewService(serverService, peerService) - if err != nil { + var wireguardBackend backend.Backend + + switch strings.ToLower(conf.Backend) { + case "linux": + wireguardBackend, err = linux.NewLinuxBackend() + if err != nil { + logrus. + WithError(err). + Fatal("failed to initialize linux backend for wireguard") + return + } + default: logrus. WithError(err). - Fatal("failed to initialize WireGuard service") + Fatal("unsupported wireguard backend") return } - defer wgService.Close() + + wireguardService := wireguard.NewService(wireguardBackend) + defer func() { + if err := wireguardService.Close(context.Background()); err != nil { + logrus. + WithError(err). + Error("failed to close wireguard service") + } + }() authService := auth.NewService(jwt.SigningMethodHS256, jwtSecretBytes, jwtSecretBytes, conf.JwtDuration) - manageService := manage.NewService(transactionScoper, userService, serverService, peerService, wgService) + manageService := manage.NewService(transactionScoper, userService, serverService, peerService, wireguardService) router := api.NewRouter( conf, @@ -130,7 +151,6 @@ func main() { userService, serverService, peerService, - wgService, manageService, ) diff --git a/pkg/api/config.go b/pkg/api/config.go index b856dd9..0b45101 100755 --- a/pkg/api/config.go +++ b/pkg/api/config.go @@ -14,7 +14,6 @@ import ( "github.com/UnAfraid/wg-ui/pkg/peer" "github.com/UnAfraid/wg-ui/pkg/server" "github.com/UnAfraid/wg-ui/pkg/user" - "github.com/UnAfraid/wg-ui/pkg/wg" ) //go:generate go run github.com/99designs/gqlgen --config ../../gqlgen.yml generate @@ -23,16 +22,15 @@ func newConfig( userService user.Service, serverService server.Service, peerService peer.Service, - wgService wg.Service, manageService manage.Service, ) resolver.Config { return resolver.Config{ Resolvers: &resolverRoot{ queryResolver: query.NewQueryResolver( - wgService, peerService, serverService, userService, + manageService, ), mutationResolver: mutation.NewMutationResolver( authService, @@ -48,13 +46,10 @@ func newConfig( peerService, ), serverResolver: serverResolver.NewServerResolver( - serverService, peerService, - wgService, ), peerResolver: peerResolver.NewPeerResolver( - peerService, - wgService, + manageService, ), }, Directives: directive.NewDirectiveRoot(), diff --git a/pkg/api/internal/model/adapter_foreign.go b/pkg/api/internal/model/adapter_foreign.go index 7adfb39..3509647 100644 --- a/pkg/api/internal/model/adapter_foreign.go +++ b/pkg/api/internal/model/adapter_foreign.go @@ -1,11 +1,13 @@ package model import ( + "net" + "github.com/UnAfraid/wg-ui/pkg/internal/adapt" - "github.com/UnAfraid/wg-ui/pkg/wg" + "github.com/UnAfraid/wg-ui/pkg/wireguard/backend" ) -func ToForeignInterface(foreignInterface *wg.ForeignInterface) *ForeignInterface { +func ToForeignInterface(foreignInterface *backend.ForeignInterface) *ForeignInterface { if foreignInterface == nil { return nil } @@ -17,13 +19,13 @@ func ToForeignInterface(foreignInterface *wg.ForeignInterface) *ForeignInterface } } -func ToForeignServer(foreignServer *wg.ForeignServer) *ForeignServer { +func ToForeignServer(foreignServer *backend.ForeignServer) *ForeignServer { if foreignServer == nil { return nil } return &ForeignServer{ - ForeignInterface: ToForeignInterface(foreignServer.ForeignInterface), + ForeignInterface: ToForeignInterface(foreignServer.Interface), Name: foreignServer.Name, Type: foreignServer.Type, PublicKey: foreignServer.PublicKey, @@ -33,19 +35,21 @@ func ToForeignServer(foreignServer *wg.ForeignServer) *ForeignServer { } } -func ToForeignPeer(foreignPeer *wg.ForeignPeer) *ForeignPeer { +func ToForeignPeer(foreignPeer *backend.Peer) *ForeignPeer { if foreignPeer == nil { return nil } return &ForeignPeer{ - PublicKey: foreignPeer.PublicKey, - Endpoint: foreignPeer.Endpoint, - AllowedIps: foreignPeer.AllowedIPs, - PersistentKeepAliveInterval: int(foreignPeer.PersistentKeepaliveInterval), - LastHandshakeTime: adapt.ToPointer(foreignPeer.LastHandshakeTime), - ReceiveBytes: float64(foreignPeer.ReceiveBytes), - TransmitBytes: float64(foreignPeer.TransmitBytes), - ProtocolVersion: foreignPeer.ProtocolVersion, + PublicKey: foreignPeer.PublicKey, + Endpoint: adapt.ToPointerNilZero(foreignPeer.Endpoint), + AllowedIps: adapt.Array(foreignPeer.AllowedIPs, func(allowedIp net.IPNet) string { + return allowedIp.String() + }), + PersistentKeepAliveInterval: int(foreignPeer.PersistentKeepalive.Seconds()), + LastHandshakeTime: adapt.ToPointer(foreignPeer.Stats.LastHandshakeTime), + ReceiveBytes: float64(foreignPeer.Stats.ReceiveBytes), + TransmitBytes: float64(foreignPeer.Stats.TransmitBytes), + ProtocolVersion: foreignPeer.Stats.ProtocolVersion, } } diff --git a/pkg/api/internal/model/adapter_peer.go b/pkg/api/internal/model/adapter_peer.go index 9ca4d8c..e106906 100644 --- a/pkg/api/internal/model/adapter_peer.go +++ b/pkg/api/internal/model/adapter_peer.go @@ -3,7 +3,7 @@ package model import ( "github.com/UnAfraid/wg-ui/pkg/internal/adapt" "github.com/UnAfraid/wg-ui/pkg/peer" - "github.com/UnAfraid/wg-ui/pkg/wg" + "github.com/UnAfraid/wg-ui/pkg/wireguard/backend" ) func CreatePeerInputToCreateOptions(input CreatePeerInput) *peer.CreateOptions { @@ -140,7 +140,7 @@ func PeerHookInputToPeerHook(hook *PeerHookInput) *peer.Hook { } } -func ToPeerStats(stats *wg.PeerStats) *PeerStats { +func ToPeerStats(stats *backend.PeerStats) *PeerStats { if stats == nil { return nil } diff --git a/pkg/api/internal/mutation/mutation_resolver.go b/pkg/api/internal/mutation/mutation_resolver.go index 92eaaca..4e4ffdc 100755 --- a/pkg/api/internal/mutation/mutation_resolver.go +++ b/pkg/api/internal/mutation/mutation_resolver.go @@ -197,12 +197,22 @@ func (r *mutationResolver) DeleteServer(ctx context.Context, input model.DeleteS } func (r *mutationResolver) StartServer(ctx context.Context, input model.StartServerInput) (*model.StartServerPayload, error) { + user, err := model.ContextToUser(ctx) + if err != nil { + return nil, err + } + + userId, err := user.ID.String(model.IdKindUser) + if err != nil { + return nil, err + } + serverId, err := input.ID.String(model.IdKindServer) if err != nil { return nil, err } - srv, err := r.manageService.StartServer(ctx, serverId) + srv, err := r.manageService.StartServer(ctx, serverId, userId) if err != nil { return nil, err } @@ -214,12 +224,22 @@ func (r *mutationResolver) StartServer(ctx context.Context, input model.StartSer } func (r *mutationResolver) StopServer(ctx context.Context, input model.StopServerInput) (*model.StopServerPayload, error) { + user, err := model.ContextToUser(ctx) + if err != nil { + return nil, err + } + + userId, err := user.ID.String(model.IdKindUser) + if err != nil { + return nil, err + } + serverId, err := input.ID.String(model.IdKindServer) if err != nil { return nil, err } - srv, err := r.manageService.StopServer(ctx, serverId) + srv, err := r.manageService.StopServer(ctx, serverId, userId) if err != nil { return nil, err } diff --git a/pkg/api/internal/peer/peer_resolver.go b/pkg/api/internal/peer/peer_resolver.go index 1a1cbd5..4435314 100644 --- a/pkg/api/internal/peer/peer_resolver.go +++ b/pkg/api/internal/peer/peer_resolver.go @@ -6,22 +6,18 @@ import ( "github.com/UnAfraid/wg-ui/pkg/api/internal/handler" "github.com/UnAfraid/wg-ui/pkg/api/internal/model" "github.com/UnAfraid/wg-ui/pkg/api/internal/resolver" - "github.com/UnAfraid/wg-ui/pkg/peer" - "github.com/UnAfraid/wg-ui/pkg/wg" + "github.com/UnAfraid/wg-ui/pkg/manage" ) type peerResolver struct { - peerService peer.Service - wgService wg.Service + manageService manage.Service } func NewPeerResolver( - peerService peer.Service, - wgService wg.Service, + manageService manage.Service, ) resolver.PeerResolver { return &peerResolver{ - wgService: wgService, - peerService: peerService, + manageService: manageService, } } @@ -66,7 +62,7 @@ func (r *peerResolver) Stats(ctx context.Context, p *model.Peer) (*model.PeerSta return nil, nil } - stats, err := r.wgService.PeerStats(server.Name, p.PublicKey) + stats, err := r.manageService.PeerStats(nil, server.Name, p.PublicKey) if err != nil { return nil, err } diff --git a/pkg/api/internal/query/query_resolver.go b/pkg/api/internal/query/query_resolver.go index 8b015da..1fe6149 100755 --- a/pkg/api/internal/query/query_resolver.go +++ b/pkg/api/internal/query/query_resolver.go @@ -10,30 +10,30 @@ import ( "github.com/UnAfraid/wg-ui/pkg/api/internal/model" "github.com/UnAfraid/wg-ui/pkg/api/internal/resolver" "github.com/UnAfraid/wg-ui/pkg/internal/adapt" + "github.com/UnAfraid/wg-ui/pkg/manage" "github.com/UnAfraid/wg-ui/pkg/peer" "github.com/UnAfraid/wg-ui/pkg/server" "github.com/UnAfraid/wg-ui/pkg/user" - "github.com/UnAfraid/wg-ui/pkg/wg" ) type queryResolver struct { - wgService wg.Service peerService peer.Service serverService server.Service userService user.Service + manageService manage.Service } func NewQueryResolver( - wgService wg.Service, peerService peer.Service, serverService server.Service, userService user.Service, + manageService manage.Service, ) resolver.QueryResolver { return &queryResolver{ - wgService: wgService, peerService: peerService, serverService: serverService, userService: userService, + manageService: manageService, } } @@ -138,7 +138,7 @@ func (r *queryResolver) Peers(ctx context.Context, query *string) ([]*model.Peer } func (r *queryResolver) ForeignServers(ctx context.Context) ([]*model.ForeignServer, error) { - foreignServers, err := r.wgService.ForeignServers(ctx) + foreignServers, err := r.manageService.ForeignServers(ctx) if err != nil { return nil, err } diff --git a/pkg/api/internal/server/server_resolver.go b/pkg/api/internal/server/server_resolver.go index 4179516..066a86b 100644 --- a/pkg/api/internal/server/server_resolver.go +++ b/pkg/api/internal/server/server_resolver.go @@ -8,25 +8,17 @@ import ( "github.com/UnAfraid/wg-ui/pkg/api/internal/resolver" "github.com/UnAfraid/wg-ui/pkg/internal/adapt" "github.com/UnAfraid/wg-ui/pkg/peer" - "github.com/UnAfraid/wg-ui/pkg/server" - "github.com/UnAfraid/wg-ui/pkg/wg" ) type serverResolver struct { - serverService server.Service - peerService peer.Service - wgService wg.Service + peerService peer.Service } func NewServerResolver( - serverService server.Service, peerService peer.Service, - wgService wg.Service, ) resolver.ServerResolver { return &serverResolver{ - serverService: serverService, - peerService: peerService, - wgService: wgService, + peerService: peerService, } } diff --git a/pkg/api/router.go b/pkg/api/router.go index 1e89189..b56c23c 100755 --- a/pkg/api/router.go +++ b/pkg/api/router.go @@ -27,7 +27,6 @@ import ( "github.com/UnAfraid/wg-ui/pkg/peer" "github.com/UnAfraid/wg-ui/pkg/server" "github.com/UnAfraid/wg-ui/pkg/user" - "github.com/UnAfraid/wg-ui/pkg/wg" ) const ( @@ -42,7 +41,6 @@ func NewRouter( userService user.Service, serverService server.Service, peerService peer.Service, - wgService wg.Service, manageService manage.Service, ) http.Handler { corsMiddleware := cors.New(cors.Options{ @@ -58,7 +56,6 @@ func NewRouter( userService, serverService, peerService, - wgService, manageService, ) diff --git a/pkg/config/config.go b/pkg/config/config.go index 387c19a..30712d8 100755 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -9,6 +9,7 @@ import ( ) type Config struct { + Backend string `required:"true" default:"linux"` BoltDB *BoltDB `split_words:"true"` HttpServer *HttpServer `split_words:"true"` DebugServer *DebugServer `split_words:"true"` diff --git a/pkg/manage/service.go b/pkg/manage/service.go index 2860dbe..569f1e7 100644 --- a/pkg/manage/service.go +++ b/pkg/manage/service.go @@ -3,15 +3,20 @@ package manage import ( "context" "errors" + "fmt" + "net" + "strings" "time" "github.com/sirupsen/logrus" "github.com/UnAfraid/wg-ui/pkg/dbx" + "github.com/UnAfraid/wg-ui/pkg/internal/adapt" "github.com/UnAfraid/wg-ui/pkg/peer" "github.com/UnAfraid/wg-ui/pkg/server" "github.com/UnAfraid/wg-ui/pkg/user" - "github.com/UnAfraid/wg-ui/pkg/wg" + "github.com/UnAfraid/wg-ui/pkg/wireguard" + "github.com/UnAfraid/wg-ui/pkg/wireguard/backend" ) type Service interface { @@ -22,12 +27,14 @@ type Service interface { CreateServer(ctx context.Context, options *server.CreateOptions, userId string) (*server.Server, error) UpdateServer(ctx context.Context, serverId string, options *server.UpdateOptions, fieldMask *server.UpdateFieldMask, userId string) (*server.Server, error) DeleteServer(ctx context.Context, serverId string, userId string) (*server.Server, error) - StartServer(ctx context.Context, serverId string) (*server.Server, error) - StopServer(ctx context.Context, serverId string) (*server.Server, error) + StartServer(ctx context.Context, serverId string, userId string) (*server.Server, error) + StopServer(ctx context.Context, serverId string, userId string) (*server.Server, error) ImportForeignServer(ctx context.Context, name string, userId string) (*server.Server, error) CreatePeer(ctx context.Context, serverId string, options *peer.CreateOptions, userId string) (*peer.Peer, error) UpdatePeer(ctx context.Context, peerId string, options *peer.UpdateOptions, fieldMask *peer.UpdateFieldMask, userId string) (*peer.Peer, error) DeletePeer(ctx context.Context, peerId string, userId string) (*peer.Peer, error) + PeerStats(ctx context.Context, name string, peerPublicKey string) (*backend.PeerStats, error) + ForeignServers(ctx context.Context) ([]*backend.ForeignServer, error) } type service struct { @@ -35,7 +42,7 @@ type service struct { userService user.Service serverService server.Service peerService peer.Service - wgService wg.Service + wireguardService wireguard.Service } func NewService( @@ -43,14 +50,14 @@ func NewService( userService user.Service, serverService server.Service, peerService peer.Service, - wgService wg.Service, + wireguardService wireguard.Service, ) Service { s := &service{ transactionScoper: transactionScoper, userService: userService, serverService: serverService, peerService: peerService, - wgService: wgService, + wireguardService: wireguardService, } s.cleanup(context.Background()) @@ -104,29 +111,53 @@ func (s *service) DeleteUser(ctx context.Context, userId string) (*user.User, er } func (s *service) CreateServer(ctx context.Context, options *server.CreateOptions, userId string) (*server.Server, error) { - createdServer, err := s.serverService.CreateServer(ctx, options, userId) - if err != nil { - return nil, err - } + return dbx.InTransactionScopeWithResult(ctx, s.transactionScoper, func(ctx context.Context) (*server.Server, error) { + createdServer, err := s.serverService.CreateServer(ctx, options, userId) + if err != nil { + return nil, err + } - if createdServer.Enabled { - return s.wgService.StartServer(ctx, createdServer.Id) - } + if createdServer.Enabled { + device, err := s.configureDevice(ctx, createdServer, nil) + if err != nil { + return nil, err + } + return s.updateServer(ctx, createdServer, device, userId) + } - return createdServer, nil + return createdServer, nil + }) } func (s *service) UpdateServer(ctx context.Context, serverId string, options *server.UpdateOptions, fieldMask *server.UpdateFieldMask, userId string) (*server.Server, error) { - updatedServer, err := s.serverService.UpdateServer(ctx, serverId, options, fieldMask, userId) - if err != nil { - return nil, err - } + return dbx.InTransactionScopeWithResult(ctx, s.transactionScoper, func(ctx context.Context) (*server.Server, error) { + updatedServer, err := s.serverService.UpdateServer(ctx, serverId, options, fieldMask, userId) + if err != nil { + return nil, err + } - if !updatedServer.Enabled { - return s.wgService.StopServer(ctx, updatedServer.Id) - } + if !updatedServer.Enabled { + status, err := s.wireguardService.Status(ctx, updatedServer.Name) + if err != nil { + return nil, err + } - return updatedServer, nil + if status { + if err := s.wireguardService.Down(ctx, updatedServer.Name); err != nil { + return nil, err + } + updateOptions := server.UpdateOptions{ + Running: false, + } + updateFieldMask := server.UpdateFieldMask{ + Running: true, + } + return s.serverService.UpdateServer(ctx, updatedServer.Id, &updateOptions, &updateFieldMask, userId) + } + } + + return updatedServer, nil + }) } func (s *service) DeleteServer(ctx context.Context, serverId string, userId string) (*server.Server, error) { @@ -136,7 +167,7 @@ func (s *service) DeleteServer(ctx context.Context, serverId string, userId stri return nil, err } - if _, err = s.wgService.StopServer(ctx, svc.Id); err != nil { + if err = s.wireguardService.Down(ctx, svc.Name); err != nil { logrus. WithError(err). WithField("serverId", svc.Id). @@ -167,72 +198,247 @@ func (s *service) DeleteServer(ctx context.Context, serverId string, userId stri }) } -func (s *service) StartServer(ctx context.Context, serverId string) (*server.Server, error) { - return s.wgService.StartServer(ctx, serverId) +func (s *service) StartServer(ctx context.Context, serverId string, userId string) (*server.Server, error) { + return dbx.InTransactionScopeWithResult(ctx, s.transactionScoper, func(ctx context.Context) (*server.Server, error) { + srv, err := s.findServer(ctx, serverId) + if err != nil { + return nil, err + } + + peers, err := s.peerService.FindPeers(ctx, &peer.FindOptions{ + ServerId: &srv.Id, + }) + if err != nil { + return nil, err + } + + device, err := s.configureDevice(ctx, srv, peers) + if err != nil { + return nil, err + } + return s.updateServer(ctx, srv, device, userId) + }) } -func (s *service) StopServer(ctx context.Context, serverId string) (*server.Server, error) { - return s.wgService.StopServer(ctx, serverId) +func (s *service) StopServer(ctx context.Context, serverId string, userId string) (*server.Server, error) { + return dbx.InTransactionScopeWithResult(ctx, s.transactionScoper, func(ctx context.Context) (*server.Server, error) { + srv, err := s.findServer(ctx, serverId) + if err != nil { + return nil, err + } + + if err := s.wireguardService.Down(ctx, srv.Name); err != nil { + return nil, err + } + + updateOptions := server.UpdateOptions{ + Running: false, + } + updateFieldMask := server.UpdateFieldMask{ + Running: true, + } + return s.serverService.UpdateServer(ctx, srv.Id, &updateOptions, &updateFieldMask, userId) + }) } func (s *service) ImportForeignServer(ctx context.Context, name string, userId string) (*server.Server, error) { - return s.wgService.ImportForeignServer(ctx, name, userId) + return dbx.InTransactionScopeWithResult(ctx, s.transactionScoper, func(ctx context.Context) (*server.Server, error) { + servers, err := s.serverService.FindServers(ctx, &server.FindOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to find servers: %w", err) + } + + knownInterfaces := adapt.Array(servers, func(server *server.Server) string { + return server.Name + }) + + foreignInterfaces, err := s.wireguardService.FindForeignServers(ctx, knownInterfaces) + if err != nil { + return nil, fmt.Errorf("failed to find foreign interfaces: %w", err) + } + + var foreignInterface *backend.ForeignInterface + for _, fn := range foreignInterfaces { + if strings.EqualFold(fn.Name, name) { + foreignInterface = fn.Interface + break + } + } + + if foreignInterface == nil { + return nil, fmt.Errorf("foreign interface: %s not found", name) + } + + device, err := s.wireguardService.Device(ctx, foreignInterface.Name) + if err != nil { + return nil, fmt.Errorf("failed to open interface: %s", foreignInterface.Name) + } + + var address string + if len(foreignInterface.Addresses) != 0 { + address = foreignInterface.Addresses[0] + } + + createServer, err := s.serverService.CreateServer(ctx, &server.CreateOptions{ + Name: foreignInterface.Name, + Description: "", + Enabled: true, + Running: true, + PublicKey: device.Wireguard.PublicKey, + PrivateKey: device.Wireguard.PrivateKey, + ListenPort: adapt.ToPointerNilZero(device.Wireguard.ListenPort), + FirewallMark: adapt.ToPointerNilZero(device.Wireguard.FirewallMark), + Address: address, + DNS: nil, + MTU: foreignInterface.Mtu, + }, userId) + if err != nil { + return nil, fmt.Errorf("failed to create server: %w", err) + } + + for i, p := range device.Wireguard.Peers { + _, err := s.peerService.CreatePeer(ctx, createServer.Id, &peer.CreateOptions{ + Name: fmt.Sprintf("Peer #%d", i+1), + Description: "", + PublicKey: p.PublicKey, + Endpoint: p.Endpoint, + AllowedIPs: adapt.Array(p.AllowedIPs, func(allowedIp net.IPNet) string { + return allowedIp.String() + }), + PresharedKey: p.PresharedKey, + PersistentKeepalive: int(p.PersistentKeepalive.Seconds()), + }, userId) + if err != nil { + return nil, fmt.Errorf("failed to create peer: %w", err) + } + } + + return createServer, nil + }) } func (s *service) CreatePeer(ctx context.Context, serverId string, options *peer.CreateOptions, userId string) (*peer.Peer, error) { - createdPeer, err := s.peerService.CreatePeer(ctx, serverId, options, userId) - if err != nil { - return nil, err - } + return dbx.InTransactionScopeWithResult(ctx, s.transactionScoper, func(ctx context.Context) (*peer.Peer, error) { + createdPeer, err := s.peerService.CreatePeer(ctx, serverId, options, userId) + if err != nil { + return nil, err + } + return s.configurePeerDevice(ctx, createdPeer, userId) + }) +} - if err := s.wgService.AddPeer(ctx, createdPeer.Id); err != nil { - logrus. - WithError(err). - WithField("peerId", createdPeer.Id). - WithField("peerName", createdPeer.Name). - Warn("failed to add peer") +func (s *service) UpdatePeer(ctx context.Context, peerId string, options *peer.UpdateOptions, fieldMask *peer.UpdateFieldMask, userId string) (*peer.Peer, error) { + return dbx.InTransactionScopeWithResult(ctx, s.transactionScoper, func(ctx context.Context) (*peer.Peer, error) { + updatedPeer, err := s.peerService.UpdatePeer(ctx, peerId, options, fieldMask, userId) + if err != nil { + return nil, err + } + return s.configurePeerDevice(ctx, updatedPeer, userId) + }) +} + +func (s *service) DeletePeer(ctx context.Context, peerId string, userId string) (*peer.Peer, error) { + return dbx.InTransactionScopeWithResult(ctx, s.transactionScoper, func(ctx context.Context) (*peer.Peer, error) { + deletedPeer, err := s.peerService.DeletePeer(ctx, peerId, userId) + if err != nil { + return nil, err + } + return s.configurePeerDevice(ctx, deletedPeer, userId) + }) +} + +func (s *service) PeerStats(ctx context.Context, name string, peerPublicKey string) (*backend.PeerStats, error) { + return s.wireguardService.PeerStats(ctx, name, peerPublicKey) +} + +func (s *service) ForeignServers(ctx context.Context) ([]*backend.ForeignServer, error) { + servers, err := s.serverService.FindServers(ctx, &server.FindOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to find servers: %w", err) } - return createdPeer, nil + knownInterfaces := adapt.Array(servers, func(server *server.Server) string { + return server.Name + }) + return s.wireguardService.FindForeignServers(ctx, knownInterfaces) } -func (s *service) UpdatePeer(ctx context.Context, peerId string, options *peer.UpdateOptions, fieldMask *peer.UpdateFieldMask, userId string) (*peer.Peer, error) { - updatedPeer, err := s.peerService.UpdatePeer(ctx, peerId, options, fieldMask, userId) +func (s *service) configurePeerDevice(ctx context.Context, p *peer.Peer, userId string) (*peer.Peer, error) { + srv, err := s.findServer(ctx, p.ServerId) if err != nil { return nil, err } - if err := s.wgService.UpdatePeer(ctx, peerId); err != nil { - logrus. - WithError(err). - WithField("peerId", updatedPeer.Id). - WithField("peerName", updatedPeer.Name). - Warn("failed to update peer") + status, err := s.wireguardService.Status(ctx, srv.Name) + if err != nil { + return nil, err } - return updatedPeer, nil -} + if !status { + return p, nil + } -func (s *service) DeletePeer(ctx context.Context, peerId string, userId string) (*peer.Peer, error) { - p, err := s.findPeer(ctx, peerId) + peers, err := s.peerService.FindPeers(ctx, &peer.FindOptions{ + ServerId: &srv.Id, + }) if err != nil { return nil, err } - if err := s.wgService.RemovePeer(ctx, peerId); err != nil { - logrus. - WithError(err). - WithField("peerId", p.Id). - WithField("peerName", p.Name). - Warn("failed to remove peer") + device, err := s.configureDevice(ctx, srv, peers) + if err != nil { + return nil, err } - deletedPeer, err := s.peerService.DeletePeer(ctx, peerId, userId) - if err != nil { + if _, err := s.updateServer(ctx, srv, device, userId); err != nil { return nil, err } - return deletedPeer, nil + return p, nil +} + +func (s *service) configureDevice(ctx context.Context, srv *server.Server, peers []*peer.Peer) (*backend.Device, error) { + return s.wireguardService.Up(ctx, backend.ConfigureOptions{ + InterfaceOptions: backend.InterfaceOptions{ + Name: srv.Name, + Address: srv.Address, + Mtu: srv.MTU, + }, + WireguardOptions: backend.WireguardOptions{ + PrivateKey: srv.PrivateKey, + ListenPort: srv.ListenPort, + FirewallMark: srv.FirewallMark, + Peers: adapt.Array(peers, func(peer *peer.Peer) *backend.PeerOptions { + return &backend.PeerOptions{ + PublicKey: peer.PublicKey, + Endpoint: peer.Endpoint, + AllowedIPs: peer.AllowedIPs, + PresharedKey: peer.PresharedKey, + PersistentKeepalive: peer.PersistentKeepalive, + } + }), + }, + }) +} + +func (s *service) updateServer(ctx context.Context, srv *server.Server, device *backend.Device, userId string) (*server.Server, error) { + updateOptions := server.UpdateOptions{ + Running: true, + PublicKey: device.Wireguard.PublicKey, + PrivateKey: device.Wireguard.PrivateKey, + ListenPort: adapt.ToPointerNilZero(device.Wireguard.ListenPort), + FirewallMark: adapt.ToPointerNilZero(device.Wireguard.FirewallMark), + MTU: device.Interface.Mtu, + } + updateFieldMask := server.UpdateFieldMask{ + Running: true, + PublicKey: strings.EqualFold(srv.PublicKey, device.Wireguard.PublicKey), + PrivateKey: strings.EqualFold(srv.PrivateKey, device.Wireguard.PrivateKey), + ListenPort: adapt.Dereference(srv.ListenPort) != device.Wireguard.ListenPort, + FirewallMark: adapt.Dereference(srv.FirewallMark) != device.Wireguard.FirewallMark, + MTU: srv.MTU != device.Interface.Mtu, + } + return s.serverService.UpdateServer(ctx, srv.Id, &updateOptions, &updateFieldMask, userId) } func (s *service) findUserById(ctx context.Context, userId string) (*user.User, error) { diff --git a/pkg/wg/foreign_peer.go b/pkg/wg/foreign_peer.go deleted file mode 100644 index 83582d7..0000000 --- a/pkg/wg/foreign_peer.go +++ /dev/null @@ -1,16 +0,0 @@ -package wg - -import ( - "time" -) - -type ForeignPeer struct { - PublicKey string - Endpoint *string - AllowedIPs []string - PersistentKeepaliveInterval float64 - LastHandshakeTime time.Time - ReceiveBytes int64 - TransmitBytes int64 - ProtocolVersion int -} diff --git a/pkg/wg/foreign_server.go b/pkg/wg/foreign_server.go deleted file mode 100644 index be0ef82..0000000 --- a/pkg/wg/foreign_server.go +++ /dev/null @@ -1,11 +0,0 @@ -package wg - -type ForeignServer struct { - ForeignInterface *ForeignInterface - Name string - Type string - PublicKey string - ListenPort int - FirewallMark int - Peers []*ForeignPeer -} diff --git a/pkg/wg/helper.go b/pkg/wg/helper.go deleted file mode 100644 index 44bf065..0000000 --- a/pkg/wg/helper.go +++ /dev/null @@ -1,7 +0,0 @@ -package wg - -func interfaceName(name string) []byte { - b := make([]byte, 16) - copy(b, name+"\x00") - return b -} diff --git a/pkg/wg/interface_linux.go b/pkg/wg/interface_linux.go deleted file mode 100644 index 5e8c363..0000000 --- a/pkg/wg/interface_linux.go +++ /dev/null @@ -1,270 +0,0 @@ -package wg - -import ( - "errors" - "fmt" - "net" - "os" - "slices" - "strings" - - "github.com/sirupsen/logrus" - "github.com/vishvananda/netlink" - - "github.com/UnAfraid/wg-ui/pkg/server" -) - -func configureInterface(name string, address string, mtu int) error { - attrs := netlink.NewLinkAttrs() - attrs.Name = name - - link := wgLink{ - attrs: &attrs, - } - - if err := netlink.LinkAdd(&link); err != nil { - if !os.IsExist(err) { - return fmt.Errorf("failed to add interface: %w", err) - } - } - - addressList, err := netlink.AddrList(&link, netFamilyAll) - if err != nil { - return fmt.Errorf("failed to get interface: %s address list: %w", name, err) - } - - serverAddress, err := netlink.ParseAddr(address) - if err != nil { - return fmt.Errorf("failed to parse client ip range: %w", err) - } - - needsAddress := true - for _, addr := range addressList { - if addr.Equal(*serverAddress) { - needsAddress = false - break - } - } - - if needsAddress { - if err = netlink.AddrAdd(&link, serverAddress); err != nil { - if !os.IsExist(err) { - return fmt.Errorf("failed to add address: %w", err) - } - } - } - - if mtu != attrs.MTU { - if err = netlink.LinkSetMTU(&link, mtu); err != nil { - return fmt.Errorf("failed to set server mtu: %w", err) - } - } - - if attrs.OperState != netlink.OperUp { - if err = netlink.LinkSetUp(&link); err != nil { - return fmt.Errorf("failed to set interface up: %w", err) - } - } - - return nil -} - -func configureRoutes(name string, allowedIPs []net.IPNet) error { - link, err := netlink.LinkByName(name) - if err != nil { - if os.IsNotExist(err) || errors.As(err, &netlink.LinkNotFoundError{}) { - return nil - } - return fmt.Errorf("failed to find link by name: %w", err) - } - - routes, err := netlink.RouteList(link, netFamilyAll) - if err != nil { - return fmt.Errorf("failed to get routes: %w", err) - } - - routesToAdd, routesToUpdate, routesToRemove := computeRoutes(link, routes, allowedIPs) - - for i, route := range routesToAdd { - if err = netlink.RouteAdd(routesToAdd[i]); err != nil { - return fmt.Errorf("failed to add route for %s - %w", route.Dst.String(), err) - } - - logrus. - WithField("name", link.Attrs().Name). - WithField("route", route.Dst.String()). - Debug("route added") - } - - for i, route := range routesToUpdate { - if err = netlink.RouteReplace(routesToAdd[i]); err != nil { - return fmt.Errorf("failed to replace route for %s - %w", route.Dst.String(), err) - } - - logrus. - WithField("name", link.Attrs().Name). - WithField("route", route.Dst.String()). - Debug("route replaced") - } - - for i, route := range routesToRemove { - if err = netlink.RouteDel(routesToAdd[i]); err != nil { - return fmt.Errorf("failed to delete route for %s - %w", route.Dst.String(), err) - } - - logrus. - WithField("name", link.Attrs().Name). - WithField("route", route.Dst.String()). - Debug("route deleted") - } - return nil -} - -func computeRoutes(link netlink.Link, existingRoutes []netlink.Route, allowedIPs []net.IPNet) ([]*netlink.Route, []*netlink.Route, []*netlink.Route) { - var routesToAdd []*netlink.Route - var routesToUpdate []*netlink.Route - var routesToRemove []*netlink.Route - for i, allowedIP := range allowedIPs { - var existingRoute *netlink.Route - for _, route := range existingRoutes { - if route.Dst != nil && route.Dst.IP.Equal(allowedIP.IP) && slices.Equal(route.Dst.Mask, allowedIP.Mask) { - existingRoute = &existingRoutes[i] - break - } - } - if existingRoute != nil { - var update bool - if existingRoute.Scope != netlink.SCOPE_LINK { - existingRoute.Scope = netlink.SCOPE_LINK - update = true - } - - if existingRoute.Protocol != netlink.RouteProtocol(3) { - existingRoute.Protocol = netlink.RouteProtocol(3) - update = true - } - - if existingRoute.Type != 1 { - existingRoute.Type = 1 - update = true - } - - if update { - routesToUpdate = append(routesToUpdate, existingRoute) - } - continue - } - - routesToAdd = append(routesToAdd, &netlink.Route{ - LinkIndex: link.Attrs().Index, - Scope: netlink.SCOPE_LINK, - Dst: &allowedIP, - Protocol: netlink.RouteProtocol(3), - Type: 1, - }) - } - - for i, existingRoute := range existingRoutes { - var exists bool - for _, allowedIP := range allowedIPs { - exists = existingRoute.Dst != nil && existingRoute.Dst.IP.Equal(allowedIP.IP) && slices.Equal(existingRoute.Dst.Mask, allowedIP.Mask) - if exists { - break - } - } - if !exists { - routesToRemove = append(routesToRemove, &existingRoutes[i]) - } - } - - return routesToAdd, routesToUpdate, routesToRemove -} - -func deleteInterface(name string) error { - link, err := netlink.LinkByName(name) - if err != nil { - if os.IsNotExist(err) || errors.As(err, &netlink.LinkNotFoundError{}) { - return nil - } - return fmt.Errorf("failed to find link by name: %w", err) - } - - if err := netlink.LinkDel(link); err != nil { - return fmt.Errorf("failed to delete interface down: %w", err) - } - return nil -} - -func interfaceStats(name string) (server.Stats, error) { - link, err := netlink.LinkByName(name) - if err != nil { - if os.IsNotExist(err) || errors.As(err, &netlink.LinkNotFoundError{}) { - return server.Stats{}, nil - } - return server.Stats{}, fmt.Errorf("failed to find link by name: %w", err) - } - - statistics := link.Attrs().Statistics - return server.Stats{ - RxPackets: statistics.RxPackets, - TxPackets: statistics.TxPackets, - RxBytes: statistics.RxBytes, - TxBytes: statistics.TxBytes, - RxErrors: statistics.RxErrors, - TxErrors: statistics.TxErrors, - RxDropped: statistics.RxDropped, - TxDropped: statistics.TxDropped, - Multicast: statistics.Multicast, - Collisions: statistics.Collisions, - RxLengthErrors: statistics.RxLengthErrors, - RxOverErrors: statistics.RxOverErrors, - RxCrcErrors: statistics.RxCrcErrors, - RxFrameErrors: statistics.RxFrameErrors, - RxFifoErrors: statistics.RxFifoErrors, - RxMissedErrors: statistics.RxMissedErrors, - TxAbortedErrors: statistics.TxAbortedErrors, - TxCarrierErrors: statistics.TxCarrierErrors, - TxFifoErrors: statistics.TxFifoErrors, - TxHeartbeatErrors: statistics.TxHeartbeatErrors, - TxWindowErrors: statistics.TxWindowErrors, - RxCompressed: statistics.RxCompressed, - TxCompressed: statistics.TxCompressed, - }, nil -} - -func findForeignInterfaces(knownInterfaces []string) (foreignInterfaces []ForeignInterface, err error) { - list, err := netlink.LinkList() - if err != nil { - return nil, err - } - - for _, link := range list { - if !strings.EqualFold(link.Type(), "wireguard") { - continue - } - - attrs := link.Attrs() - name := attrs.Name - if slices.Contains(knownInterfaces, name) { - continue - } - - addrList, err := netlink.AddrList(link, netFamilyAll) - if err != nil { - return nil, fmt.Errorf("failed to get address list for interface %s", name) - } - - var addresses []string - for _, addr := range addrList { - addresses = append(addresses, addr.IPNet.String()) - } - - foreignInterfaces = append(foreignInterfaces, ForeignInterface{ - Name: name, - Addresses: addresses, - Mtu: attrs.MTU, - State: attrs.OperState.String(), - }) - } - return foreignInterfaces, nil -} diff --git a/pkg/wg/interface_other.go b/pkg/wg/interface_other.go deleted file mode 100644 index b57603b..0000000 --- a/pkg/wg/interface_other.go +++ /dev/null @@ -1,29 +0,0 @@ -//go:build !linux - -package wg - -import ( - "net" - - "github.com/UnAfraid/wg-ui/pkg/server" -) - -func configureInterface(name string, address string, mtu int) error { - return nil -} - -func configureRoutes(name string, allowedIPs []net.IPNet) error { - return nil -} - -func deleteInterface(name string) error { - return nil -} - -func interfaceStats(name string) (server.Stats, error) { - return server.Stats{}, nil -} - -func findForeignInterfaces(knownInterfaces []string) ([]ForeignInterface, error) { - return nil, nil -} diff --git a/pkg/wg/service.go b/pkg/wg/service.go deleted file mode 100644 index e453cd2..0000000 --- a/pkg/wg/service.go +++ /dev/null @@ -1,754 +0,0 @@ -package wg - -import ( - "context" - "fmt" - "net" - "os" - "strings" - "time" - - "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/UnAfraid/wg-ui/pkg/internal/adapt" - "github.com/UnAfraid/wg-ui/pkg/peer" - "github.com/UnAfraid/wg-ui/pkg/server" -) - -const ( - netFamilyAll = 0 - netFamilyV4 = 2 - netFamilyV6 = 10 - - updateServersInterval = time.Minute - updateServerStatsInterval = 30 * time.Second -) - -type Service interface { - Close() error - ForeignServers(ctx context.Context) (foreignServers []*ForeignServer, err error) - ImportForeignServer(ctx context.Context, name string, userId string) (*server.Server, error) - StartServer(ctx context.Context, serverId string) (*server.Server, error) - StopServer(ctx context.Context, serverUd string) (*server.Server, error) - ConfigureWireGuard(name string, privateKey string, listenPort *int, firewallMark *int, peers []*peer.Peer) error - PeerStats(name string, peerPublicKey string) (*PeerStats, error) - AddPeer(ctx context.Context, peerId string) error - UpdatePeer(ctx context.Context, peerId string) error - RemovePeer(ctx context.Context, peerId string) error -} - -type service struct { - client *wgctrl.Client - updateStop func() - serverService server.Service - peerService peer.Service - stopChan chan struct{} - stoppedChan chan struct{} -} - -func NewService(serverService server.Service, peerService peer.Service) (Service, error) { - client, err := wgctrl.New() - if err != nil { - return nil, err - } - - s := &service{ - client: client, - serverService: serverService, - peerService: peerService, - stopChan: make(chan struct{}), - stoppedChan: make(chan struct{}), - } - - if err := s.init(); err != nil { - return nil, err - } - - go s.run() - - return s, nil -} - -func (s *service) init() error { - servers, err := s.serverService.FindServers(context.Background(), &server.FindOptions{}) - if err != nil { - return fmt.Errorf("failed to find servers: %w", err) - } - - for _, svc := range servers { - if !svc.Enabled { - continue - } - - if _, err := s.StartServer(context.Background(), svc.Id); err != nil { - logrus.WithError(err).Warn("failed to start server") - return nil - } - } - - return nil -} - -func (s *service) run() { - defer close(s.stoppedChan) - for { - select { - case <-s.stopChan: - return - case <-time.After(updateServersInterval): - s.updateServers() - case <-time.After(updateServerStatsInterval): - s.updateServerStats() - } - } -} - -func (s *service) updateServers() { - servers, err := s.serverService.FindServers(context.Background(), &server.FindOptions{}) - if err != nil { - logrus. - WithError(err). - Error("failed to find servers") - return - } - - for _, svc := range servers { - if !svc.Enabled { - continue - } - - wg, err := s.client.Device(svc.Name) - if err != nil { - if os.IsNotExist(err) { - if svc.Running { - updateOptions := &server.UpdateOptions{Running: false} - updateFieldMask := &server.UpdateFieldMask{Running: true} - if _, err = s.serverService.UpdateServer(context.Background(), svc.Id, updateOptions, updateFieldMask, ""); err != nil { - logrus. - WithError(err). - WithField("serverId", svc.Id). - WithField("serverName", svc.Name). - Warn("failed to update wireguard server") - } - } - return - } - - logrus. - WithError(err). - WithField("serverId", svc.Id). - WithField("serverName", svc.Name). - Error("failed to find open wireguard device") - return - } - - if adapt.Dereference(svc.ListenPort) != wg.ListenPort { - updateOptions := &server.UpdateOptions{ListenPort: &wg.ListenPort} - updateFieldMask := &server.UpdateFieldMask{ListenPort: true} - svc, err = s.serverService.UpdateServer(context.Background(), svc.Id, updateOptions, updateFieldMask, "") - if err != nil { - logrus. - WithError(err). - WithField("serverId", svc.Id). - WithField("serverName", svc.Name). - Error("failed to update wireguard server") - return - } - } - - for _, p := range wg.Peers { - existingPeer, err := s.peerService.FindPeer(context.Background(), &peer.FindOneOptions{ - ServerIdPublicKeyOption: &peer.ServerIdPublicKeyOption{ - ServerId: svc.Id, - PublicKey: p.PublicKey.String(), - }, - }) - if err != nil { - logrus. - WithError(err). - WithField("serverId", svc.Id). - WithField("serverName", svc.Name). - WithField("peerPublicKey", p.PublicKey.String()). - Warn("failed to to find peer") - continue - } - if existingPeer == nil { - continue - } - - if p.Endpoint == nil || p.Endpoint.String() == existingPeer.Endpoint { - continue - } - - options := &peer.UpdateOptions{Endpoint: p.Endpoint.String()} - fieldMask := &peer.UpdateFieldMask{Endpoint: true} - _, err = s.peerService.UpdatePeer(context.Background(), existingPeer.Id, options, fieldMask, "") - if err != nil { - logrus. - WithError(err). - WithField("serverId", svc.Id). - WithField("serverName", svc.Name). - WithField("peerId", existingPeer.Id). - WithField("peerPublicKey", p.PublicKey.String()). - Error("failed to to update peer") - return - } - } - } -} - -func (s *service) updateServerStats() { - servers, err := s.serverService.FindServers(context.Background(), &server.FindOptions{}) - if err != nil { - logrus. - WithError(err). - Error("failed to find servers") - return - } - - for _, svc := range servers { - if !svc.Enabled || !svc.Running { - continue - } - - newInterfaceStats, err := interfaceStats(svc.Name) - if err != nil { - logrus. - WithError(err). - WithField("name", svc.Name). - Warn("failed to get interface stats") - continue - } - - if newInterfaceStats != svc.Stats { - updateOptions := &server.UpdateOptions{Stats: newInterfaceStats} - updateFieldMask := &server.UpdateFieldMask{Stats: true} - _, err = s.serverService.UpdateServer(context.Background(), svc.Id, updateOptions, updateFieldMask, "") - if err != nil { - logrus. - WithError(err). - WithField("name", svc.Name). - Warn("failed update server stats") - continue - } - } - } -} - -func (s *service) Close() error { - close(s.stopChan) - <-s.stoppedChan - return s.client.Close() -} - -func (s *service) ForeignServers(ctx context.Context) (foreignServers []*ForeignServer, err error) { - servers, err := s.serverService.FindServers(ctx, &server.FindOptions{}) - if err != nil { - return nil, fmt.Errorf("failed to find servers: %w", err) - } - - knownInterfaces := adapt.Array(servers, func(server *server.Server) string { - return server.Name - }) - - foreignInterfaces, err := findForeignInterfaces(knownInterfaces) - if err != nil { - return nil, fmt.Errorf("failed to find foreign interfaces: %w", err) - } - - for i, foreignInterface := range foreignInterfaces { - device, err := s.client.Device(foreignInterface.Name) - if err != nil { - return nil, fmt.Errorf("failed to open wireguard interface: %s", foreignInterface.Name) - } - - foreignServers = append(foreignServers, &ForeignServer{ - ForeignInterface: &foreignInterfaces[i], - Name: device.Name, - Type: device.Type.String(), - PublicKey: device.PublicKey.String(), - ListenPort: device.ListenPort, - FirewallMark: device.FirewallMark, - Peers: adapt.Array(device.Peers, func(peer wgtypes.Peer) *ForeignPeer { - var endpoint *string - if peer.Endpoint != nil { - endpoint = adapt.ToPointer(peer.Endpoint.String()) - } - return &ForeignPeer{ - PublicKey: peer.PublicKey.String(), - Endpoint: endpoint, - PersistentKeepaliveInterval: peer.PersistentKeepaliveInterval.Seconds(), - LastHandshakeTime: peer.LastHandshakeTime, - ReceiveBytes: peer.ReceiveBytes, - TransmitBytes: peer.TransmitBytes, - AllowedIPs: adapt.Array(peer.AllowedIPs, func(allowedIp net.IPNet) string { - return allowedIp.String() - }), - ProtocolVersion: peer.ProtocolVersion, - } - }), - }) - } - return foreignServers, nil -} - -func (s *service) ImportForeignServer(ctx context.Context, name string, userId string) (*server.Server, error) { - servers, err := s.serverService.FindServers(ctx, &server.FindOptions{}) - if err != nil { - return nil, fmt.Errorf("failed to find servers: %w", err) - } - - knownInterfaces := adapt.Array(servers, func(server *server.Server) string { - return server.Name - }) - - foreignInterfaces, err := findForeignInterfaces(knownInterfaces) - if err != nil { - return nil, fmt.Errorf("failed to find foreign interfaces: %w", err) - } - - var foreignInterface *ForeignInterface - for _, fn := range foreignInterfaces { - if strings.EqualFold(fn.Name, name) { - foreignInterface = &fn - break - } - } - - if foreignInterface == nil { - return nil, fmt.Errorf("foreign interface: %s not found", name) - } - - device, err := s.client.Device(foreignInterface.Name) - if err != nil { - return nil, fmt.Errorf("failed to open interface: %s", foreignInterface.Name) - } - - var address string - if len(foreignInterface.Addresses) != 0 { - address = foreignInterface.Addresses[0] - } - - createServer, err := s.serverService.CreateServer(ctx, &server.CreateOptions{ - Name: foreignInterface.Name, - Description: "", - Enabled: true, - Running: true, - PublicKey: device.PublicKey.String(), - PrivateKey: device.PrivateKey.String(), - ListenPort: adapt.ToPointerNilZero(device.ListenPort), - FirewallMark: adapt.ToPointerNilZero(device.FirewallMark), - Address: address, - DNS: nil, - MTU: foreignInterface.Mtu, - }, userId) - if err != nil { - return nil, fmt.Errorf("failed to create server: %w", err) - } - - for i, p := range device.Peers { - var endpoint string - if p.Endpoint != nil { - endpoint = p.Endpoint.String() - } - - _, err := s.peerService.CreatePeer(ctx, createServer.Id, &peer.CreateOptions{ - Name: fmt.Sprintf("Peer #%d", i+1), - Description: "", - PublicKey: p.PublicKey.String(), - Endpoint: endpoint, - AllowedIPs: adapt.Array(p.AllowedIPs, func(allowedIp net.IPNet) string { - return allowedIp.String() - }), - PresharedKey: p.PresharedKey.String(), - PersistentKeepalive: int(p.PersistentKeepaliveInterval.Seconds()), - }, userId) - if err != nil { - return nil, fmt.Errorf("failed to create peer: %w", err) - } - } - - return createServer, nil -} - -func (s *service) StartServer(ctx context.Context, serverId string) (*server.Server, error) { - svc, err := s.findServer(ctx, serverId) - if err != nil { - return nil, err - } - - peers, err := s.peerService.FindPeers(ctx, &peer.FindOptions{ - ServerId: &svc.Id, - }) - if err != nil { - return nil, err - } - - logrus. - WithField("name", svc.Name). - Info("starting wireguard") - - if err := configureInterface(svc.Name, svc.Address, svc.MTU); err != nil { - return nil, fmt.Errorf("failed to configure interface: %w", err) - } - - if err := s.ConfigureWireGuard(svc.Name, svc.PrivateKey, svc.ListenPort, svc.FirewallMark, peers); err != nil { - return nil, fmt.Errorf("failed to configure wireguard: %w", err) - } - - currentDevice, err := s.client.Device(svc.Name) - if err != nil { - return nil, fmt.Errorf("failed to open wireguard device: %w", err) - } - - updateServerOptions := &server.UpdateOptions{ - ListenPort: ¤tDevice.ListenPort, - Running: true, - } - updateServerFieldMask := &server.UpdateFieldMask{ - ListenPort: true, - Running: true, - } - return s.serverService.UpdateServer(ctx, serverId, updateServerOptions, updateServerFieldMask, "") -} - -func (s *service) StopServer(ctx context.Context, serverId string) (*server.Server, error) { - svc, err := s.findServer(ctx, serverId) - if err != nil { - return nil, err - } - - logrus. - WithField("name", svc.Name). - Info("stopping wireguard") - - if err := deleteInterface(svc.Name); err != nil { - return nil, fmt.Errorf("failed to configure interface: %w", err) - } - - updateServerOptions := &server.UpdateOptions{ - Running: false, - } - updateServerFieldMask := &server.UpdateFieldMask{ - Running: true, - } - return s.serverService.UpdateServer(ctx, serverId, updateServerOptions, updateServerFieldMask, "") -} - -func (s *service) getAllowedIPs(name string) ([]net.IPNet, error) { - currentDevice, err := s.client.Device(name) - if err != nil { - return nil, fmt.Errorf("failed to open wireguard device: %w", err) - } - - var allowedIPs []net.IPNet - for _, p := range currentDevice.Peers { - allowedIPs = append(allowedIPs, p.AllowedIPs...) - } - return allowedIPs, nil -} - -func (s *service) ConfigureWireGuard(name string, privateKey string, listenPort *int, firewallMark *int, peers []*peer.Peer) error { - currentDevice, err := s.client.Device(name) - if err != nil { - return fmt.Errorf("failed to open wireguard device: %w", err) - } - - var actualPeers []wgtypes.PeerConfig - for _, p := range peers { - peerConfig, err := toPeerConfig(p) - if err != nil { - return err - } - actualPeers = append(actualPeers, peerConfig) - } - - var differentPeers []wgtypes.PeerConfig - for _, currentPeer := range currentDevice.Peers { - var found bool - for _, actualPeer := range actualPeers { - if currentPeer.PublicKey == actualPeer.PublicKey { - found = true - actualPeer.UpdateOnly = true - differentPeers = append(differentPeers, actualPeer) - break - } - } - if !found { - peerToRemove := wgtypes.PeerConfig{ - PublicKey: currentPeer.PublicKey, - Remove: true, - } - differentPeers = append(differentPeers, peerToRemove) - } - } - - for _, actualPeer := range actualPeers { - var found bool - for _, currentPeer := range currentDevice.Peers { - if actualPeer.PublicKey == currentPeer.PublicKey { - found = true - break - } - } - if !found { - differentPeers = append(differentPeers, actualPeer) - } - } - - return s.configureWireguard(name, privateKey, listenPort, firewallMark, differentPeers...) -} - -func (s *service) PeerStats(name string, peerPublicKey string) (*PeerStats, error) { - publicKey, err := wgtypes.ParseKey(peerPublicKey) - if err != nil { - return nil, fmt.Errorf("invalid peer: %s public key: %w", name, err) - } - - currentDevice, err := s.client.Device(name) - if err != nil { - return nil, fmt.Errorf("failed to open wireguard device: %w", err) - } - - for _, p := range currentDevice.Peers { - if p.PublicKey == publicKey { - return &PeerStats{ - LastHandshakeTime: p.LastHandshakeTime, - ReceiveBytes: p.ReceiveBytes, - TransmitBytes: p.TransmitBytes, - ProtocolVersion: p.ProtocolVersion, - }, nil - } - } - - return nil, nil -} - -func (s *service) AddPeer(ctx context.Context, peerId string) error { - p, err := s.findPeer(ctx, peerId) - if err != nil { - return err - } - - svc, err := s.findServer(ctx, p.ServerId) - if err != nil { - return err - } - - currentDevice, err := s.client.Device(svc.Name) - if err != nil { - return fmt.Errorf("failed to open wireguard device: %w", err) - } - - peerConfig, err := toPeerConfig(p) - if err != nil { - return err - } - - var currentPeer *wgtypes.Peer - for _, p := range currentDevice.Peers { - if p.PublicKey == peerConfig.PublicKey { - currentPeer = &p - break - } - } - - if currentPeer != nil { - peerConfig.UpdateOnly = true - if len(currentPeer.AllowedIPs) != len(peerConfig.AllowedIPs) { - peerConfig.ReplaceAllowedIPs = true - } else { - for i := 0; i < len(currentPeer.AllowedIPs); i++ { - if currentPeer.AllowedIPs[i].String() != peerConfig.AllowedIPs[i].String() { - peerConfig.ReplaceAllowedIPs = true - break - } - } - } - } - - return s.configureWireguard(svc.Name, svc.PrivateKey, svc.ListenPort, svc.FirewallMark, peerConfig) -} - -func (s *service) UpdatePeer(ctx context.Context, peerId string) error { - p, err := s.findPeer(ctx, peerId) - if err != nil { - return err - } - - svc, err := s.findServer(ctx, p.ServerId) - if err != nil { - return err - } - - currentDevice, err := s.client.Device(svc.Name) - if err != nil { - return fmt.Errorf("failed to open wireguard device: %w", err) - } - - peerConfig, err := toPeerConfig(p) - if err != nil { - return err - } - peerConfig.UpdateOnly = true - - var currentPeer *wgtypes.Peer - for _, p := range currentDevice.Peers { - if p.PublicKey == peerConfig.PublicKey { - currentPeer = &p - break - } - } - if currentPeer != nil { - peerConfig.UpdateOnly = true - if len(currentPeer.AllowedIPs) != len(peerConfig.AllowedIPs) { - peerConfig.ReplaceAllowedIPs = true - } else { - for i := 0; i < len(currentPeer.AllowedIPs); i++ { - if currentPeer.AllowedIPs[i].String() != peerConfig.AllowedIPs[i].String() { - peerConfig.ReplaceAllowedIPs = true - break - } - } - } - } - - return s.configureWireguard(svc.Name, svc.PrivateKey, svc.ListenPort, svc.FirewallMark, peerConfig) -} - -func (s *service) RemovePeer(ctx context.Context, peerId string) error { - p, err := s.findPeer(ctx, peerId) - if err != nil { - return err - } - - svc, err := s.findServer(ctx, p.ServerId) - if err != nil { - return err - } - - currentDevice, err := s.client.Device(svc.Name) - if err != nil { - return fmt.Errorf("failed to open wireguard device: %w", err) - } - - peerConfig, err := toPeerConfig(p) - if err != nil { - return err - } - - var currentPeer *wgtypes.Peer - for _, p := range currentDevice.Peers { - if p.PublicKey == peerConfig.PublicKey { - currentPeer = &p - break - } - } - if currentPeer != nil { - peerConfig.Remove = true - } - - return s.configureWireguard(svc.Name, svc.PrivateKey, svc.ListenPort, svc.FirewallMark, peerConfig) -} - -func (s *service) configureWireguard(name string, privateKey string, listenPort *int, firewallMark *int, peers ...wgtypes.PeerConfig) error { - key, err := wgtypes.ParseKey(privateKey) - if err != nil { - return fmt.Errorf("invalid server private key: %w", err) - } - - wgConfig := wgtypes.Config{ - PrivateKey: &key, - ListenPort: listenPort, - FirewallMark: firewallMark, - ReplacePeers: false, - Peers: peers, - } - - if err = s.client.ConfigureDevice(name, wgConfig); err != nil { - return fmt.Errorf("failed to configure device: %w", err) - } - - allowedIPs, err := s.getAllowedIPs(name) - if err != nil { - return fmt.Errorf("failed to get allowed ips: %w", err) - } - - if err := configureRoutes(name, allowedIPs); err != nil { - return fmt.Errorf("failed to configure routes: %w", err) - } - - return nil -} - -func toPeerConfig(peer *peer.Peer) (wgtypes.PeerConfig, error) { - publicKey, err := wgtypes.ParseKey(peer.PublicKey) - if err != nil { - return wgtypes.PeerConfig{}, fmt.Errorf("invalid peer: %s public key: %w", peer.Name, err) - } - - var presharedKey *wgtypes.Key - if peer.PresharedKey != "" { - key, err := wgtypes.ParseKey(peer.PresharedKey) - if err != nil { - return wgtypes.PeerConfig{}, fmt.Errorf("invalid peer: %s preshared key - %w", peer.Name, err) - } - presharedKey = &key - } - - allowedIPs := make([]net.IPNet, len(peer.AllowedIPs)) - for i, cidr := range peer.AllowedIPs { - _, ipNet, err := net.ParseCIDR(cidr) - if err != nil { - return wgtypes.PeerConfig{}, err - } - allowedIPs[i] = *ipNet - } - - var persistentKeepaliveInterval *time.Duration - if peer.PersistentKeepalive != 0 { - persistentKeepaliveInterval = adapt.ToPointer(time.Duration(peer.PersistentKeepalive) * time.Second) - } - - return wgtypes.PeerConfig{ - PublicKey: publicKey, - Remove: false, - UpdateOnly: false, - PresharedKey: presharedKey, - PersistentKeepaliveInterval: persistentKeepaliveInterval, - ReplaceAllowedIPs: false, - AllowedIPs: allowedIPs, - }, nil -} - -func (s *service) findServer(ctx context.Context, serverId string) (*server.Server, error) { - svc, err := s.serverService.FindServer(ctx, &server.FindOneOptions{ - IdOption: &server.IdOption{ - Id: serverId, - }, - }) - if err != nil { - return nil, err - } - if svc == nil { - return nil, server.ErrServerNotFound - } - return svc, nil -} - -func (s *service) findPeer(ctx context.Context, peerId string) (*peer.Peer, error) { - p, err := s.peerService.FindPeer(ctx, &peer.FindOneOptions{ - IdOption: &peer.IdOption{ - Id: peerId, - }, - }) - if err != nil { - return nil, err - } - if p == nil { - return nil, peer.ErrPeerNotFound - } - return p, nil -} diff --git a/pkg/wireguard/backend/backend.go b/pkg/wireguard/backend/backend.go new file mode 100644 index 0000000..2619991 --- /dev/null +++ b/pkg/wireguard/backend/backend.go @@ -0,0 +1,14 @@ +package backend + +import "context" + +type Backend interface { + Device(ctx context.Context, name string) (*Device, error) + Up(ctx context.Context, options ConfigureOptions) (*Device, error) + Down(ctx context.Context, name string) error + Status(ctx context.Context, name string) (bool, error) + Stats(ctx context.Context, name string) (*InterfaceStats, error) + PeerStats(ctx context.Context, name string, peerPublicKey string) (*PeerStats, error) + FindForeignServers(ctx context.Context, knownInterfaces []string) ([]*ForeignServer, error) + Close(ctx context.Context) error +} diff --git a/pkg/wireguard/backend/configure_options.go b/pkg/wireguard/backend/configure_options.go new file mode 100644 index 0000000..9137d06 --- /dev/null +++ b/pkg/wireguard/backend/configure_options.go @@ -0,0 +1,20 @@ +package backend + +import "fmt" + +type ConfigureOptions struct { + InterfaceOptions InterfaceOptions + WireguardOptions WireguardOptions +} + +func (o ConfigureOptions) Validate() error { + if err := o.InterfaceOptions.Validate(); err != nil { + return fmt.Errorf("interface options: %w", err) + } + + if err := o.WireguardOptions.Validate(); err != nil { + return fmt.Errorf("wireguard options: %w", err) + } + + return nil +} diff --git a/pkg/wireguard/backend/device.go b/pkg/wireguard/backend/device.go new file mode 100644 index 0000000..df31caf --- /dev/null +++ b/pkg/wireguard/backend/device.go @@ -0,0 +1,6 @@ +package backend + +type Device struct { + Interface Interface + Wireguard Wireguard +} diff --git a/pkg/wg/foreign_interface.go b/pkg/wireguard/backend/foreign_interface.go similarity index 86% rename from pkg/wg/foreign_interface.go rename to pkg/wireguard/backend/foreign_interface.go index f85b0c3..ccf54a4 100644 --- a/pkg/wg/foreign_interface.go +++ b/pkg/wireguard/backend/foreign_interface.go @@ -1,4 +1,4 @@ -package wg +package backend type ForeignInterface struct { Name string diff --git a/pkg/wireguard/backend/foreign_server.go b/pkg/wireguard/backend/foreign_server.go new file mode 100644 index 0000000..0ca84d8 --- /dev/null +++ b/pkg/wireguard/backend/foreign_server.go @@ -0,0 +1,11 @@ +package backend + +type ForeignServer struct { + Interface *ForeignInterface + Name string + Type string + PublicKey string + ListenPort int + FirewallMark int + Peers []*Peer +} diff --git a/pkg/wireguard/backend/interface.go b/pkg/wireguard/backend/interface.go new file mode 100644 index 0000000..a923115 --- /dev/null +++ b/pkg/wireguard/backend/interface.go @@ -0,0 +1,7 @@ +package backend + +type Interface struct { + Name string + Addresses []string + Mtu int +} diff --git a/pkg/wireguard/backend/interface_options.go b/pkg/wireguard/backend/interface_options.go new file mode 100644 index 0000000..38cda84 --- /dev/null +++ b/pkg/wireguard/backend/interface_options.go @@ -0,0 +1,19 @@ +package backend + +import "errors" + +type InterfaceOptions struct { + Name string + Address string + Mtu int +} + +func (o InterfaceOptions) Validate() error { + if len(o.Name) == 0 { + return errors.New("name is required") + } + if len(o.Address) == 0 { + return errors.New("address is required") + } + return nil +} diff --git a/pkg/wireguard/backend/interface_stats.go b/pkg/wireguard/backend/interface_stats.go new file mode 100644 index 0000000..df708eb --- /dev/null +++ b/pkg/wireguard/backend/interface_stats.go @@ -0,0 +1,27 @@ +package backend + +type InterfaceStats struct { + RxPackets uint64 + TxPackets uint64 + RxBytes uint64 + TxBytes uint64 + RxErrors uint64 + TxErrors uint64 + RxDropped uint64 + TxDropped uint64 + Multicast uint64 + Collisions uint64 + RxLengthErrors uint64 + RxOverErrors uint64 + RxCrcErrors uint64 + RxFrameErrors uint64 + RxFifoErrors uint64 + RxMissedErrors uint64 + TxAbortedErrors uint64 + TxCarrierErrors uint64 + TxFifoErrors uint64 + TxHeartbeatErrors uint64 + TxWindowErrors uint64 + RxCompressed uint64 + TxCompressed uint64 +} diff --git a/pkg/wireguard/backend/peer.go b/pkg/wireguard/backend/peer.go new file mode 100644 index 0000000..aa028fe --- /dev/null +++ b/pkg/wireguard/backend/peer.go @@ -0,0 +1,15 @@ +package backend + +import ( + "net" + "time" +) + +type Peer struct { + PublicKey string + Endpoint string + AllowedIPs []net.IPNet + PresharedKey string + PersistentKeepalive time.Duration + Stats PeerStats +} diff --git a/pkg/wireguard/backend/peer_options.go b/pkg/wireguard/backend/peer_options.go new file mode 100644 index 0000000..2e608f7 --- /dev/null +++ b/pkg/wireguard/backend/peer_options.go @@ -0,0 +1,23 @@ +package backend + +import "errors" + +type PeerOptions struct { + PublicKey string + Endpoint string + AllowedIPs []string + PresharedKey string + PersistentKeepalive int +} + +func (o *PeerOptions) Validate() error { + if len(o.PublicKey) == 0 { + return errors.New("public key is required") + } + + if len(o.AllowedIPs) == 0 { + return errors.New("allowed ips are required") + } + + return nil +} diff --git a/pkg/wg/peer_stats.go b/pkg/wireguard/backend/peer_stats.go similarity index 90% rename from pkg/wg/peer_stats.go rename to pkg/wireguard/backend/peer_stats.go index 18af28e..17cc0cc 100644 --- a/pkg/wg/peer_stats.go +++ b/pkg/wireguard/backend/peer_stats.go @@ -1,4 +1,4 @@ -package wg +package backend import ( "time" diff --git a/pkg/wireguard/backend/wireguard.go b/pkg/wireguard/backend/wireguard.go new file mode 100644 index 0000000..9104179 --- /dev/null +++ b/pkg/wireguard/backend/wireguard.go @@ -0,0 +1,10 @@ +package backend + +type Wireguard struct { + Name string + PublicKey string + PrivateKey string + ListenPort int + FirewallMark int + Peers []*Peer +} diff --git a/pkg/wireguard/backend/wireguard_options.go b/pkg/wireguard/backend/wireguard_options.go new file mode 100644 index 0000000..c9d140c --- /dev/null +++ b/pkg/wireguard/backend/wireguard_options.go @@ -0,0 +1,25 @@ +package backend + +import ( + "errors" + "fmt" +) + +type WireguardOptions struct { + PrivateKey string + ListenPort *int + FirewallMark *int + Peers []*PeerOptions +} + +func (o WireguardOptions) Validate() error { + if len(o.PrivateKey) == 0 { + return errors.New("private key required") + } + for _, peer := range o.Peers { + if err := peer.Validate(); err != nil { + return fmt.Errorf("peer: %w", err) + } + } + return nil +} diff --git a/pkg/wireguard/linux/adapter.go b/pkg/wireguard/linux/adapter.go new file mode 100644 index 0000000..6e8e78b --- /dev/null +++ b/pkg/wireguard/linux/adapter.go @@ -0,0 +1,107 @@ +//go:build linux + +package linux + +import ( + "fmt" + "net" + "time" + + "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/UnAfraid/wg-ui/pkg/internal/adapt" + "github.com/UnAfraid/wg-ui/pkg/wireguard/backend" +) + +func wireguardPeerOptionsToPeerConfig(peer *backend.PeerOptions) (wgtypes.PeerConfig, error) { + publicKey, err := wgtypes.ParseKey(peer.PublicKey) + if err != nil { + return wgtypes.PeerConfig{}, fmt.Errorf("invalid peer: %s public key: %w", peer.PublicKey, err) + } + + var presharedKey *wgtypes.Key + if peer.PresharedKey != "" { + key, err := wgtypes.ParseKey(peer.PresharedKey) + if err != nil { + return wgtypes.PeerConfig{}, fmt.Errorf("invalid peer: %s preshared key - %w", peer.PublicKey, err) + } + presharedKey = &key + } + + allowedIPs := make([]net.IPNet, len(peer.AllowedIPs)) + for i, cidr := range peer.AllowedIPs { + _, ipNet, err := net.ParseCIDR(cidr) + if err != nil { + return wgtypes.PeerConfig{}, err + } + allowedIPs[i] = *ipNet + } + + var persistentKeepaliveInterval *time.Duration + if peer.PersistentKeepalive != 0 { + persistentKeepaliveInterval = adapt.ToPointer(time.Duration(peer.PersistentKeepalive) * time.Second) + } + + return wgtypes.PeerConfig{ + PublicKey: publicKey, + Remove: false, + UpdateOnly: false, + PresharedKey: presharedKey, + PersistentKeepaliveInterval: persistentKeepaliveInterval, + ReplaceAllowedIPs: false, + AllowedIPs: allowedIPs, + }, nil +} + +func linkStatisticsToBackendInterfaceStats(statistics *netlink.LinkStatistics) *backend.InterfaceStats { + if statistics == nil { + return nil + } + return &backend.InterfaceStats{ + RxPackets: statistics.RxPackets, + TxPackets: statistics.TxPackets, + RxBytes: statistics.RxBytes, + TxBytes: statistics.TxBytes, + RxErrors: statistics.RxErrors, + TxErrors: statistics.TxErrors, + RxDropped: statistics.RxDropped, + TxDropped: statistics.TxDropped, + Multicast: statistics.Multicast, + Collisions: statistics.Collisions, + RxLengthErrors: statistics.RxLengthErrors, + RxOverErrors: statistics.RxOverErrors, + RxCrcErrors: statistics.RxCrcErrors, + RxFrameErrors: statistics.RxFrameErrors, + RxFifoErrors: statistics.RxFifoErrors, + RxMissedErrors: statistics.RxMissedErrors, + TxAbortedErrors: statistics.TxAbortedErrors, + TxCarrierErrors: statistics.TxCarrierErrors, + TxFifoErrors: statistics.TxFifoErrors, + TxHeartbeatErrors: statistics.TxHeartbeatErrors, + TxWindowErrors: statistics.TxWindowErrors, + RxCompressed: statistics.RxCompressed, + TxCompressed: statistics.TxCompressed, + } +} + +func netlinkInterfaceToForeignInterface(link netlink.Link) (*backend.ForeignInterface, error) { + attrs := link.Attrs() + + addrList, err := netlink.AddrList(link, netlink.FAMILY_ALL) + if err != nil { + return nil, fmt.Errorf("failed to get address list for interface %s", attrs.Name) + } + + var addresses []string + for _, addr := range addrList { + addresses = append(addresses, addr.IPNet.String()) + } + + return &backend.ForeignInterface{ + Name: attrs.Name, + Addresses: addresses, + Mtu: attrs.MTU, + State: attrs.OperState.String(), + }, nil +} diff --git a/pkg/wireguard/linux/backend_linux.go b/pkg/wireguard/linux/backend_linux.go new file mode 100644 index 0000000..177c298 --- /dev/null +++ b/pkg/wireguard/linux/backend_linux.go @@ -0,0 +1,525 @@ +//go:build linux + +package linux + +import ( + "context" + "errors" + "fmt" + "net" + "os" + "slices" + "strings" + + "github.com/sirupsen/logrus" + "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/UnAfraid/wg-ui/pkg/internal/adapt" + "github.com/UnAfraid/wg-ui/pkg/wireguard/backend" +) + +type linuxBackend struct { + client *wgctrl.Client +} + +func NewLinuxBackend() (backend.Backend, error) { + client, err := wgctrl.New() + if err != nil { + return nil, fmt.Errorf("failed to initialize linux backend: %w", err) + } + + return &linuxBackend{ + client: client, + }, nil +} + +func (lb *linuxBackend) Device(_ context.Context, name string) (*backend.Device, error) { + device, err := lb.client.Device(name) + if err != nil { + return nil, fmt.Errorf("failed to find device: %s", err) + } + + return wgDeviceToBackendDevice(device, name) +} + +func (lb *linuxBackend) Up(_ context.Context, options backend.ConfigureOptions) (*backend.Device, error) { + if err := options.Validate(); err != nil { + return nil, err + } + + interfaceOptions := options.InterfaceOptions + if err := configureInterface(interfaceOptions.Name, interfaceOptions.Address, interfaceOptions.Mtu); err != nil { + return nil, fmt.Errorf("failed to configure interface: %s - %w", interfaceOptions.Name, err) + } + + wireguardOptions := options.WireguardOptions + if err := lb.configureWireguard(interfaceOptions.Name, wireguardOptions.PrivateKey, wireguardOptions.ListenPort, wireguardOptions.FirewallMark, wireguardOptions.Peers); err != nil { + return nil, fmt.Errorf("failed to configure wireguard: %s - %w", interfaceOptions.Name, err) + } + + device, err := lb.client.Device(interfaceOptions.Name) + if err != nil { + return nil, fmt.Errorf("failed to find device: %s", err) + } + + return wgDeviceToBackendDevice(device, interfaceOptions.Name) +} + +func wgDeviceToBackendDevice(device *wgtypes.Device, name string) (*backend.Device, error) { + link, err := findInterface(name) + if err != nil { + return nil, err + } + + addressList, err := netlink.AddrList(link, netlink.FAMILY_ALL) + if err != nil { + return nil, fmt.Errorf("failed to get interface: %s address list: %w", name, err) + } + + return &backend.Device{ + Interface: backend.Interface{ + Name: link.Attrs().Name, + Addresses: adapt.Array(addressList, func(addr netlink.Addr) string { + return addr.String() + }), + Mtu: link.Attrs().MTU, + }, + Wireguard: backend.Wireguard{ + Name: device.Name, + PublicKey: device.PublicKey.String(), + PrivateKey: device.PrivateKey.String(), + ListenPort: device.ListenPort, + FirewallMark: device.FirewallMark, + Peers: adapt.Array(device.Peers, func(peer wgtypes.Peer) *backend.Peer { + var endpoint string + if peer.Endpoint != nil { + endpoint = peer.Endpoint.String() + } + return &backend.Peer{ + PublicKey: peer.PublicKey.String(), + Endpoint: endpoint, + AllowedIPs: peer.AllowedIPs, + PresharedKey: peer.PresharedKey.String(), + PersistentKeepalive: peer.PersistentKeepaliveInterval, + Stats: backend.PeerStats{ + LastHandshakeTime: peer.LastHandshakeTime, + ReceiveBytes: peer.ReceiveBytes, + TransmitBytes: peer.TransmitBytes, + ProtocolVersion: peer.ProtocolVersion, + }, + } + }), + }, + }, nil +} + +func (lb *linuxBackend) Down(_ context.Context, name string) error { + return deleteInterface(name) +} + +func (lb *linuxBackend) Status(_ context.Context, name string) (bool, error) { + link, err := findInterface(name) + if err != nil { + return false, err + } + return link != nil, nil +} + +func (lb *linuxBackend) Stats(_ context.Context, name string) (*backend.InterfaceStats, error) { + return interfaceStats(name) +} + +func (lb *linuxBackend) PeerStats(_ context.Context, name string, peerPublicKey string) (*backend.PeerStats, error) { + currentDevice, err := lb.client.Device(name) + if err != nil { + return nil, fmt.Errorf("failed to open wireguard device: %w", err) + } + return peerStats(currentDevice, name, peerPublicKey) +} + +func (lb *linuxBackend) FindForeignServers(_ context.Context, knownInterfaces []string) ([]*backend.ForeignServer, error) { + return lb.findForeignServers(knownInterfaces) +} + +func (lb *linuxBackend) configureWireguard(name string, privateKey string, listenPort *int, firewallMark *int, peerOptions []*backend.PeerOptions) error { + device, err := lb.client.Device(name) + if err != nil { + return fmt.Errorf("failed to open wireguard device: %w", err) + } + + key, err := wgtypes.ParseKey(privateKey) + if err != nil { + return fmt.Errorf("invalid private key: %w", err) + } + + peers, err := computePeers(device, peerOptions) + if err != nil { + return fmt.Errorf("failed to compute peers: %w", err) + } + + return lb.applyDeviceConfiguration(device, name, &key, listenPort, firewallMark, peers) +} + +func (lb *linuxBackend) Close(_ context.Context) error { + return lb.client.Close() +} + +func computePeers(device *wgtypes.Device, peerOptions []*backend.PeerOptions) ([]wgtypes.PeerConfig, error) { + var actualPeers []wgtypes.PeerConfig + for _, p := range peerOptions { + peerConfig, err := wireguardPeerOptionsToPeerConfig(p) + if err != nil { + return nil, err + } + actualPeers = append(actualPeers, peerConfig) + } + + var peers []wgtypes.PeerConfig + for _, currentPeer := range device.Peers { + var found bool + for _, actualPeer := range actualPeers { + if currentPeer.PublicKey == actualPeer.PublicKey { + found = true + actualPeer.UpdateOnly = true + peers = append(peers, actualPeer) + break + } + } + if !found { + peerToRemove := wgtypes.PeerConfig{ + PublicKey: currentPeer.PublicKey, + Remove: true, + } + peers = append(peers, peerToRemove) + } + } + + for _, actualPeer := range actualPeers { + var found bool + for _, currentPeer := range device.Peers { + if actualPeer.PublicKey == currentPeer.PublicKey { + found = true + break + } + } + if !found { + peers = append(peers, actualPeer) + } + } + + return peers, nil +} + +func (lb *linuxBackend) applyDeviceConfiguration( + device *wgtypes.Device, + name string, + privateKey *wgtypes.Key, + listenPort *int, + firewallMark *int, + peers []wgtypes.PeerConfig, +) error { + wgConfig := wgtypes.Config{ + PrivateKey: privateKey, + ListenPort: listenPort, + FirewallMark: firewallMark, + ReplacePeers: false, + Peers: peers, + } + + if err := lb.client.ConfigureDevice(name, wgConfig); err != nil { + return fmt.Errorf("failed to configure device: %w", err) + } + + var allowedIPs []net.IPNet + for _, p := range device.Peers { + allowedIPs = append(allowedIPs, p.AllowedIPs...) + } + + if err := configureRoutes(name, allowedIPs); err != nil { + return fmt.Errorf("failed to configure routes: %w", err) + } + + return nil +} + +func findInterface(name string) (netlink.Link, error) { + link, err := netlink.LinkByName(name) + if err != nil { + if os.IsNotExist(err) || errors.As(err, &netlink.LinkNotFoundError{}) { + return nil, nil + } + return nil, fmt.Errorf("failed to find link by name: %w", err) + } + return link, nil +} + +func configureInterface(name string, address string, mtu int) error { + attrs := netlink.NewLinkAttrs() + attrs.Name = name + + link := &wgLink{ + attrs: &attrs, + } + + if err := netlink.LinkAdd(link); err != nil { + if !os.IsExist(err) { + return fmt.Errorf("failed to add interface: %w", err) + } + } + + addressList, err := netlink.AddrList(link, netlink.FAMILY_ALL) + if err != nil { + return fmt.Errorf("failed to get interface: %s address list: %w", name, err) + } + + serverAddress, err := netlink.ParseAddr(address) + if err != nil { + return fmt.Errorf("failed to parse interface address: %w", err) + } + + needsAddress := true + for _, addr := range addressList { + if addr.Equal(*serverAddress) { + needsAddress = false + break + } + } + + if needsAddress { + if err = netlink.AddrAdd(link, serverAddress); err != nil { + if !os.IsExist(err) { + return fmt.Errorf("failed to add address: %w", err) + } + } + } + + if mtu != attrs.MTU { + if err = netlink.LinkSetMTU(link, mtu); err != nil { + return fmt.Errorf("failed to set server mtu: %w", err) + } + } + + if attrs.OperState != netlink.OperUp { + if err = netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to set interface up: %w", err) + } + } + + return nil +} + +func deleteInterface(name string) error { + link, err := findInterface(name) + if err != nil { + return err + } + if link == nil { + return nil + } + + if err := netlink.LinkDel(link); err != nil { + return fmt.Errorf("failed to delete interface down: %w", err) + } + return nil +} + +func interfaceStats(name string) (*backend.InterfaceStats, error) { + link, err := findInterface(name) + if err != nil { + return nil, err + } + if link == nil { + return nil, nil + } + return linkStatisticsToBackendInterfaceStats(link.Attrs().Statistics), nil +} + +func peerStats(device *wgtypes.Device, name string, peerPublicKey string) (*backend.PeerStats, error) { + publicKey, err := wgtypes.ParseKey(peerPublicKey) + if err != nil { + return nil, fmt.Errorf("invalid peer: %s public key: %w", name, err) + } + + for _, p := range device.Peers { + if p.PublicKey == publicKey { + return &backend.PeerStats{ + LastHandshakeTime: p.LastHandshakeTime, + ReceiveBytes: p.ReceiveBytes, + TransmitBytes: p.TransmitBytes, + ProtocolVersion: p.ProtocolVersion, + }, nil + } + } + + return nil, nil +} + +func configureRoutes(name string, allowedIPs []net.IPNet) error { + link, err := findInterface(name) + if err != nil { + return fmt.Errorf("failed to find link by name: %w", err) + } + if link == nil { + return fmt.Errorf("interface not found: %s", name) + } + + routes, err := netlink.RouteList(link, netlink.FAMILY_ALL) + if err != nil { + return fmt.Errorf("failed to get routes: %w", err) + } + + routesToAdd, routesToUpdate, routesToRemove := computeRoutes(link, routes, allowedIPs) + + for i, route := range routesToAdd { + if err = netlink.RouteAdd(routesToAdd[i]); err != nil { + return fmt.Errorf("failed to add route for %s - %w", route.Dst.String(), err) + } + + logrus. + WithField("name", link.Attrs().Name). + WithField("route", route.Dst.String()). + Debug("route added") + } + + for i, route := range routesToUpdate { + if err = netlink.RouteReplace(routesToAdd[i]); err != nil { + return fmt.Errorf("failed to replace route for %s - %w", route.Dst.String(), err) + } + + logrus. + WithField("name", link.Attrs().Name). + WithField("route", route.Dst.String()). + Debug("route replaced") + } + + for i, route := range routesToRemove { + if err = netlink.RouteDel(routesToAdd[i]); err != nil { + return fmt.Errorf("failed to delete route for %s - %w", route.Dst.String(), err) + } + + logrus. + WithField("name", link.Attrs().Name). + WithField("route", route.Dst.String()). + Debug("route deleted") + } + return nil +} + +func computeRoutes(link netlink.Link, existingRoutes []netlink.Route, allowedIPs []net.IPNet) ([]*netlink.Route, []*netlink.Route, []*netlink.Route) { + var routesToAdd []*netlink.Route + var routesToUpdate []*netlink.Route + var routesToRemove []*netlink.Route + for i, allowedIP := range allowedIPs { + var existingRoute *netlink.Route + for _, route := range existingRoutes { + if route.Dst != nil && route.Dst.IP.Equal(allowedIP.IP) && slices.Equal(route.Dst.Mask, allowedIP.Mask) { + existingRoute = &existingRoutes[i] + break + } + } + if existingRoute != nil { + var update bool + if existingRoute.Scope != netlink.SCOPE_LINK { + existingRoute.Scope = netlink.SCOPE_LINK + update = true + } + + if existingRoute.Protocol != netlink.RouteProtocol(3) { + existingRoute.Protocol = netlink.RouteProtocol(3) + update = true + } + + if existingRoute.Type != 1 { + existingRoute.Type = 1 + update = true + } + + if update { + routesToUpdate = append(routesToUpdate, existingRoute) + } + continue + } + + routesToAdd = append(routesToAdd, &netlink.Route{ + LinkIndex: link.Attrs().Index, + Scope: netlink.SCOPE_LINK, + Dst: &allowedIPs[i], + Protocol: netlink.RouteProtocol(3), + Type: 1, + }) + } + + for i, existingRoute := range existingRoutes { + var exists bool + for _, allowedIP := range allowedIPs { + exists = existingRoute.Dst != nil && existingRoute.Dst.IP.Equal(allowedIP.IP) && slices.Equal(existingRoute.Dst.Mask, allowedIP.Mask) + if exists { + break + } + } + if !exists { + routesToRemove = append(routesToRemove, &existingRoutes[i]) + } + } + + return routesToAdd, routesToUpdate, routesToRemove +} + +func (lb *linuxBackend) findForeignServers(knownInterfaces []string) ([]*backend.ForeignServer, error) { + list, err := netlink.LinkList() + if err != nil { + return nil, err + } + + var foreignServers []*backend.ForeignServer + for _, link := range list { + if !strings.EqualFold(link.Type(), "wireguard") { + continue + } + + if slices.Contains(knownInterfaces, link.Attrs().Name) { + continue + } + + foreignInterface, err := netlinkInterfaceToForeignInterface(link) + if err != nil { + return nil, err + } + + device, err := lb.client.Device(foreignInterface.Name) + if err != nil { + return nil, err + } + + foreignServers = append(foreignServers, &backend.ForeignServer{ + Interface: foreignInterface, + Name: device.Name, + Type: device.Type.String(), + PublicKey: device.PublicKey.String(), + ListenPort: device.ListenPort, + FirewallMark: device.FirewallMark, + Peers: adapt.Array(device.Peers, func(peer wgtypes.Peer) *backend.Peer { + var endpoint string + if peer.Endpoint != nil { + endpoint = peer.Endpoint.String() + } + return &backend.Peer{ + PublicKey: peer.PublicKey.String(), + Endpoint: endpoint, + AllowedIPs: peer.AllowedIPs, + PresharedKey: peer.PresharedKey.String(), + PersistentKeepalive: peer.PersistentKeepaliveInterval, + Stats: backend.PeerStats{ + LastHandshakeTime: peer.LastHandshakeTime, + ReceiveBytes: peer.ReceiveBytes, + TransmitBytes: peer.TransmitBytes, + ProtocolVersion: peer.ProtocolVersion, + }, + } + }), + }) + } + return foreignServers, nil +} diff --git a/pkg/wireguard/linux/backend_other.go b/pkg/wireguard/linux/backend_other.go new file mode 100644 index 0000000..efd767f --- /dev/null +++ b/pkg/wireguard/linux/backend_other.go @@ -0,0 +1,13 @@ +//go:build !linux + +package linux + +import ( + "errors" + + "github.com/UnAfraid/wg-ui/pkg/wireguard/backend" +) + +func NewLinuxBackend() (backend.Backend, error) { + return nil, errors.New("linux backend is only supported on linux") +} diff --git a/pkg/wg/wg_link.go b/pkg/wireguard/linux/wg_link.go similarity index 87% rename from pkg/wg/wg_link.go rename to pkg/wireguard/linux/wg_link.go index 3b64e70..3b27446 100644 --- a/pkg/wg/wg_link.go +++ b/pkg/wireguard/linux/wg_link.go @@ -1,4 +1,6 @@ -package wg +//go:build linux + +package linux import ( "github.com/vishvananda/netlink" diff --git a/pkg/wireguard/service.go b/pkg/wireguard/service.go new file mode 100644 index 0000000..8084ce7 --- /dev/null +++ b/pkg/wireguard/service.go @@ -0,0 +1,60 @@ +package wireguard + +import ( + "context" + + "github.com/UnAfraid/wg-ui/pkg/wireguard/backend" +) + +type Service interface { + Device(ctx context.Context, name string) (*backend.Device, error) + Up(ctx context.Context, options backend.ConfigureOptions) (*backend.Device, error) + Down(ctx context.Context, name string) error + Status(ctx context.Context, name string) (bool, error) + Stats(ctx context.Context, name string) (*backend.InterfaceStats, error) + PeerStats(ctx context.Context, name string, peerPublicKey string) (*backend.PeerStats, error) + FindForeignServers(_ context.Context, knownInterfaces []string) ([]*backend.ForeignServer, error) + Close(ctx context.Context) error +} + +type service struct { + backend backend.Backend +} + +func NewService(backend backend.Backend) Service { + return &service{ + backend: backend, + } +} + +func (s *service) Device(ctx context.Context, name string) (*backend.Device, error) { + return s.backend.Device(ctx, name) +} + +func (s *service) Up(ctx context.Context, options backend.ConfigureOptions) (*backend.Device, error) { + return s.backend.Up(ctx, options) +} + +func (s *service) Down(ctx context.Context, name string) error { + return s.backend.Down(ctx, name) +} + +func (s *service) Status(ctx context.Context, name string) (bool, error) { + return s.backend.Status(ctx, name) +} + +func (s *service) Stats(ctx context.Context, name string) (*backend.InterfaceStats, error) { + return s.backend.Stats(ctx, name) +} + +func (s *service) PeerStats(ctx context.Context, name string, peerPublicKey string) (*backend.PeerStats, error) { + return s.backend.PeerStats(ctx, name, peerPublicKey) +} + +func (s *service) FindForeignServers(ctx context.Context, knownInterfaces []string) ([]*backend.ForeignServer, error) { + return s.backend.FindForeignServers(ctx, knownInterfaces) +} + +func (s *service) Close(ctx context.Context) error { + return s.backend.Close(ctx) +}