diff --git a/cmd/dashboard/controller/common_page.go b/cmd/dashboard/controller/common_page.go index 8182d50d40..24e1c77d2f 100644 --- a/cmd/dashboard/controller/common_page.go +++ b/cmd/dashboard/controller/common_page.go @@ -260,8 +260,8 @@ func (cp *commonPage) home(c *gin.Context) { } var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, + ReadBufferSize: 10240, + WriteBufferSize: 10240, } type Data struct { @@ -305,8 +305,8 @@ func (cp *commonPage) ws(c *gin.Context) { } func (cp *commonPage) terminal(c *gin.Context) { - terminalID := c.Param("id") - if _, err := rpc.NezhaHandlerSingleton.GetStream(terminalID); err != nil { + streamId := c.Param("id") + if _, err := rpc.NezhaHandlerSingleton.GetStream(streamId); err != nil { mygin.ShowErrorPage(c, mygin.ErrInfo{ Code: http.StatusForbidden, Title: "无权访问", @@ -316,7 +316,7 @@ func (cp *commonPage) terminal(c *gin.Context) { }, true) return } - defer rpc.NezhaHandlerSingleton.CloseStream(terminalID) + defer rpc.NezhaHandlerSingleton.CloseStream(streamId) wsConn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { @@ -344,11 +344,11 @@ func (cp *commonPage) terminal(c *gin.Context) { } }() - if err = rpc.NezhaHandlerSingleton.UserConnected(terminalID, conn); err != nil { + if err = rpc.NezhaHandlerSingleton.UserConnected(streamId, conn); err != nil { return } - rpc.NezhaHandlerSingleton.StartStream(terminalID, time.Second*10) + rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10) } type createTerminalRequest struct { @@ -380,7 +380,7 @@ func (cp *commonPage) createTerminal(c *gin.Context) { return } - id, err := uuid.GenerateUUID() + streamId, err := uuid.GenerateUUID() if err != nil { mygin.ShowErrorPage(c, mygin.ErrInfo{ Code: http.StatusInternalServerError, @@ -394,7 +394,7 @@ func (cp *commonPage) createTerminal(c *gin.Context) { return } - rpc.NezhaHandlerSingleton.CreateStream(id) + rpc.NezhaHandlerSingleton.CreateStream(streamId) singleton.ServerLock.RLock() server := singleton.ServerList[createTerminalReq.ID] @@ -411,7 +411,7 @@ func (cp *commonPage) createTerminal(c *gin.Context) { } terminalData, _ := utils.Json.Marshal(&model.TerminalTask{ - StreamID: id, + StreamID: streamId, }) if err := server.TaskStream.Send(&proto.Task{ Type: model.TaskTypeTerminalGRPC, @@ -428,7 +428,7 @@ func (cp *commonPage) createTerminal(c *gin.Context) { } c.HTML(http.StatusOK, "dashboard-"+singleton.Conf.Site.DashboardTheme+"/terminal", mygin.CommonEnvironment(c, gin.H{ - "SessionID": id, + "SessionID": streamId, "ServerName": server.Name, })) } diff --git a/cmd/dashboard/controller/controller.go b/cmd/dashboard/controller/controller.go index 8ede497682..927976952f 100644 --- a/cmd/dashboard/controller/controller.go +++ b/cmd/dashboard/controller/controller.go @@ -1,6 +1,7 @@ package controller import ( + "encoding/json" "fmt" "html/template" "io/fs" @@ -14,16 +15,26 @@ import ( "code.cloudfoundry.org/bytefmt" "github.com/gin-contrib/pprof" "github.com/gin-gonic/gin" + "github.com/hashicorp/go-uuid" "github.com/nicksnyder/go-i18n/v2/i18n" + "github.com/naiba/nezha/model" "github.com/naiba/nezha/pkg/mygin" + "github.com/naiba/nezha/pkg/utils" + "github.com/naiba/nezha/proto" "github.com/naiba/nezha/resource" + "github.com/naiba/nezha/service/rpc" "github.com/naiba/nezha/service/singleton" ) func ServeWeb(port uint) *http.Server { gin.SetMode(gin.ReleaseMode) r := gin.Default() + if singleton.Conf.Debug { + gin.SetMode(gin.DebugMode) + pprof.Register(r) + } + r.Use(natGateway) tmpl := template.New("").Funcs(funcMap) var err error tmpl, err = tmpl.ParseFS(resource.TemplateFS, "template/**/*.html") @@ -32,10 +43,6 @@ func ServeWeb(port uint) *http.Server { } tmpl = loadThirdPartyTemplates(tmpl) r.SetHTMLTemplate(tmpl) - if singleton.Conf.Debug { - gin.SetMode(gin.DebugMode) - pprof.Register(r) - } r.Use(mygin.RecordPath) staticFs, err := fs.Sub(resource.StaticFS, "static") if err != nil { @@ -44,7 +51,6 @@ func ServeWeb(port uint) *http.Server { r.StaticFS("/static", http.FS(staticFs)) r.Static("/static-custom", "resource/static/custom") routers(r) - page404 := func(c *gin.Context) { mygin.ShowErrorPage(c, mygin.ErrInfo{ Code: http.StatusNotFound, @@ -238,3 +244,64 @@ var funcMap = template.FuncMap{ return singleton.StatusCodeToString(singleton.GetStatusCode(val)) }, } + +func natGateway(c *gin.Context) { + natConfig := singleton.GetNATConfigByDomain(c.Request.Host) + if natConfig == nil { + return + } + + singleton.ServerLock.RLock() + server := singleton.ServerList[natConfig.ServerID] + singleton.ServerLock.RUnlock() + if server == nil || server.TaskStream == nil { + c.Writer.WriteString("server not found or not connected") + c.Abort() + return + } + + streamId, err := uuid.GenerateUUID() + if err != nil { + c.Writer.WriteString(fmt.Sprintf("stream id error: %v", err)) + c.Abort() + return + } + + rpc.NezhaHandlerSingleton.CreateStream(streamId) + defer rpc.NezhaHandlerSingleton.CloseStream(streamId) + + taskData, err := json.Marshal(model.TaskNAT{ + StreamID: streamId, + Host: natConfig.Host, + }) + if err != nil { + c.Writer.WriteString(fmt.Sprintf("task data error: %v", err)) + c.Abort() + return + } + + if err := server.TaskStream.Send(&proto.Task{ + Type: model.TaskTypeNAT, + Data: string(taskData), + }); err != nil { + c.Writer.WriteString(fmt.Sprintf("send task error: %v", err)) + c.Abort() + return + } + + w, err := utils.NewRequestWrapper(c.Request, c.Writer) + if err != nil { + c.Writer.WriteString(fmt.Sprintf("request wrapper error: %v", err)) + c.Abort() + return + } + + if err := rpc.NezhaHandlerSingleton.UserConnected(streamId, w); err != nil { + c.Writer.WriteString(fmt.Sprintf("user connected error: %v", err)) + c.Abort() + return + } + + rpc.NezhaHandlerSingleton.StartStream(streamId, time.Second*10) + c.Abort() +} diff --git a/cmd/dashboard/controller/member_api.go b/cmd/dashboard/controller/member_api.go index 72f679a908..3f6ef5e5d1 100644 --- a/cmd/dashboard/controller/member_api.go +++ b/cmd/dashboard/controller/member_api.go @@ -45,6 +45,7 @@ func (ma *memberAPI) serve() { mr.POST("/batch-update-server-group", ma.batchUpdateServerGroup) mr.POST("/batch-delete-server", ma.batchDeleteServer) mr.POST("/notification", ma.addOrEditNotification) + mr.POST("/nat", ma.addOrEditNAT) mr.POST("/alert-rule", ma.addOrEditAlertRule) mr.POST("/setting", ma.updateSetting) mr.DELETE("/:model/:id", ma.delete) @@ -209,6 +210,11 @@ func (ma *memberAPI) delete(c *gin.Context) { if err == nil { singleton.OnDeleteNotification(id) } + case "nat": + err = singleton.DB.Unscoped().Delete(&model.NAT{}, "id = ?", id).Error + if err == nil { + singleton.OnNATUpdate() + } case "monitor": err = singleton.DB.Unscoped().Delete(&model.Monitor{}, "id = ?", id).Error if err == nil { @@ -733,6 +739,45 @@ func (ma *memberAPI) addOrEditNotification(c *gin.Context) { }) } +type natForm struct { + ID uint64 + Name string + ServerID uint64 + Host string + Domain string +} + +func (ma *memberAPI) addOrEditNAT(c *gin.Context) { + var nf natForm + var n model.NAT + err := c.ShouldBindJSON(&nf) + if err == nil { + n.Name = nf.Name + n.ID = nf.ID + n.Domain = nf.Domain + n.Host = nf.Host + n.ServerID = nf.ServerID + } + if err == nil { + if n.ID == 0 { + err = singleton.DB.Create(&n).Error + } else { + err = singleton.DB.Save(&n).Error + } + } + if err != nil { + c.JSON(http.StatusOK, model.Response{ + Code: http.StatusBadRequest, + Message: fmt.Sprintf("请求错误:%s", err), + }) + return + } + singleton.OnNATUpdate() + c.JSON(http.StatusOK, model.Response{ + Code: http.StatusOK, + }) +} + type alertRuleForm struct { ID uint64 Name string diff --git a/cmd/dashboard/controller/member_page.go b/cmd/dashboard/controller/member_page.go index d6e3b5c1c9..1982144c70 100644 --- a/cmd/dashboard/controller/member_page.go +++ b/cmd/dashboard/controller/member_page.go @@ -27,6 +27,7 @@ func (mp *memberPage) serve() { mr.GET("/monitor", mp.monitor) mr.GET("/cron", mp.cron) mr.GET("/notification", mp.notification) + mr.GET("/nat", mp.nat) mr.GET("/setting", mp.setting) mr.GET("/api", mp.api) } @@ -77,6 +78,15 @@ func (mp *memberPage) notification(c *gin.Context) { })) } +func (mp *memberPage) nat(c *gin.Context) { + var data []model.NAT + singleton.DB.Find(&data) + c.HTML(http.StatusOK, "dashboard-"+singleton.Conf.Site.DashboardTheme+"/nat", mygin.CommonEnvironment(c, gin.H{ + "Title": singleton.Localizer.MustLocalize(&i18n.LocalizeConfig{MessageID: "NAT"}), + "NAT": data, + })) +} + func (mp *memberPage) setting(c *gin.Context) { c.HTML(http.StatusOK, "dashboard-"+singleton.Conf.Site.DashboardTheme+"/setting", mygin.CommonEnvironment(c, gin.H{ "Title": singleton.Localizer.MustLocalize(&i18n.LocalizeConfig{MessageID: "Settings"}), diff --git a/model/monitor.go b/model/monitor.go index b858593c0e..459a0363f6 100644 --- a/model/monitor.go +++ b/model/monitor.go @@ -21,12 +21,18 @@ const ( TaskTypeUpgrade TaskTypeKeepalive TaskTypeTerminalGRPC + TaskTypeNAT ) type TerminalTask struct { StreamID string } +type TaskNAT struct { + StreamID string + Host string +} + const ( MonitorCoverAll = iota MonitorCoverIgnoreAll diff --git a/model/nat.go b/model/nat.go new file mode 100644 index 0000000000..83ac5fac34 --- /dev/null +++ b/model/nat.go @@ -0,0 +1,9 @@ +package model + +type NAT struct { + Common + Name string + ServerID uint64 + Host string + Domain string `gorm:"unique"` +} diff --git a/pkg/mygin/mygin.go b/pkg/mygin/mygin.go index 604ffd3afc..45df0b12d5 100644 --- a/pkg/mygin/mygin.go +++ b/pkg/mygin/mygin.go @@ -16,6 +16,7 @@ var adminPage = map[string]bool{ "/monitor": true, "/setting": true, "/notification": true, + "/nat": true, "/cron": true, "/api": true, } diff --git a/pkg/utils/request_wrapper.go b/pkg/utils/request_wrapper.go new file mode 100644 index 0000000000..9699129a0e --- /dev/null +++ b/pkg/utils/request_wrapper.go @@ -0,0 +1,56 @@ +package utils + +import ( + "bytes" + "io" + "net" + "net/http" + + "github.com/gin-gonic/gin" +) + +var _ io.ReadWriteCloser = &RequestWrapper{} + +type RequestWrapper struct { + req *http.Request + reader *bytes.Buffer + writer net.Conn +} + +func NewRequestWrapper(req *http.Request, writer gin.ResponseWriter) (*RequestWrapper, error) { + conn, _, err := writer.Hijack() + if err != nil { + return nil, err + } + buf := bytes.NewBuffer(nil) + if err = req.Write(buf); err != nil { + return nil, err + } + return &RequestWrapper{ + req: req, + reader: buf, + writer: conn, + }, nil +} + +func (rw *RequestWrapper) Read(p []byte) (int, error) { + count, err := rw.reader.Read(p) + if err == nil { + return count, nil + } + if err != io.EOF { + return count, err + } + // request 数据读完之后等待客户端断开连接或 grpc 超时 + return rw.writer.Read(p) +} + +func (rw *RequestWrapper) Write(p []byte) (int, error) { + return rw.writer.Write(p) +} + +func (rw *RequestWrapper) Close() error { + rw.req.Body.Close() + rw.writer.Close() + return nil +} diff --git a/pkg/websocketx/safe_conn.go b/pkg/websocketx/safe_conn.go index 3b45c759da..1d29019af6 100644 --- a/pkg/websocketx/safe_conn.go +++ b/pkg/websocketx/safe_conn.go @@ -1,11 +1,14 @@ package websocketx import ( + "io" "sync" "github.com/gorilla/websocket" ) +var _ io.ReadWriteCloser = &Conn{} + type Conn struct { *websocket.Conn writeLock *sync.Mutex diff --git a/resource/l10n/en-US.toml b/resource/l10n/en-US.toml index 4ecdf4a9dc..ae48ea324d 100644 --- a/resource/l10n/en-US.toml +++ b/resource/l10n/en-US.toml @@ -648,3 +648,6 @@ other = "Disable Switch Template in Frontend" [ServersOnWorldMap] other = "Servers On World Map" + +[NAT] +other = "NAT" \ No newline at end of file diff --git a/resource/l10n/es-ES.toml b/resource/l10n/es-ES.toml index b57d8c1cc9..59a11b7f46 100644 --- a/resource/l10n/es-ES.toml +++ b/resource/l10n/es-ES.toml @@ -647,4 +647,7 @@ other = "Temperatura" other = "Deshabilitar Cambio de Plantilla en Frontend" [ServersOnWorldMap] -other = "Servidores en el mapa mundial" \ No newline at end of file +other = "Servidores en el mapa mundial" + +[NAT] +other = "NAT" \ No newline at end of file diff --git a/resource/l10n/zh-CN.toml b/resource/l10n/zh-CN.toml index 14fa21cce8..d7c582928d 100644 --- a/resource/l10n/zh-CN.toml +++ b/resource/l10n/zh-CN.toml @@ -648,3 +648,6 @@ other = "禁止前台切换模板" [ServersOnWorldMap] other = "服务器世界分布图" + +[NAT] +other = "内网穿透" \ No newline at end of file diff --git a/resource/l10n/zh-TW.toml b/resource/l10n/zh-TW.toml index 1ce927d8f1..e158346729 100644 --- a/resource/l10n/zh-TW.toml +++ b/resource/l10n/zh-TW.toml @@ -648,3 +648,6 @@ other = "禁止前台切換主題" [ServersOnWorldMap] other = "伺服器世界分布圖" + +[NAT] +other = "NAT" \ No newline at end of file diff --git a/resource/static/main.js b/resource/static/main.js index 9b8f500a25..e5e91a16b3 100644 --- a/resource/static/main.js +++ b/resource/static/main.js @@ -91,6 +91,7 @@ function showFormModal(modelSelector, formID, URL, getData) { item.name.endsWith("_id") || item.name === "id" || item.name === "ID" || + item.name === "ServerID" || item.name === "RequestType" || item.name === "RequestMethod" || item.name === "TriggerMode" || @@ -255,6 +256,28 @@ function addOrEditNotification(notification) { ); } +function addOrEditNAT(nat) { + const modal = $(".nat.modal"); + modal.children(".header").text((nat ? LANG.Edit : LANG.Add)); + modal + .find(".nezha-primary-btn.button") + .html( + nat + ? LANG.Edit + '' + : LANG.Add + '' + ); + modal.find("input[name=ID]").val(nat ? nat.ID : null); + modal.find("input[name=ServerID]").val(nat ? nat.ServerID : null); + modal.find("input[name=Name]").val(nat ? nat.Name : null); + modal.find("input[name=Host]").val(nat ? nat.Host : null); + modal.find("input[name=Domain]").val(nat ? nat.Domain : null); + showFormModal( + ".nat.modal", + "#natForm", + "/api/nat" + ); +} + function connectToServer(id) { post('/terminal', { Host: window.location.host, Protocol: window.location.protocol, ID: id }) } diff --git a/resource/template/common/footer.html b/resource/template/common/footer.html index 8b17c69a51..015e6ae8df 100644 --- a/resource/template/common/footer.html +++ b/resource/template/common/footer.html @@ -10,7 +10,7 @@ - + +{{end}} \ No newline at end of file diff --git a/service/rpc/io_stream.go b/service/rpc/io_stream.go index 74a7c126be..0a56d8b356 100644 --- a/service/rpc/io_stream.go +++ b/service/rpc/io_stream.go @@ -136,6 +136,5 @@ LOOP: }() <-endCh - return err } diff --git a/service/singleton/api.go b/service/singleton/api.go index 0e40eed66e..d338832bfd 100644 --- a/service/singleton/api.go +++ b/service/singleton/api.go @@ -15,8 +15,6 @@ var ( ServerAPI = &ServerAPIService{} MonitorAPI = &MonitorAPIService{} - - once = &sync.Once{} ) type ServerAPIService struct{} @@ -78,7 +76,7 @@ func InitAPI() { UserIDToApiTokenList = make(map[uint64][]string) } -func LoadAPI() { +func loadAPI() { InitAPI() var tokenList []*model.ApiToken DB.Find(&tokenList) diff --git a/service/singleton/crontask.go b/service/singleton/crontask.go index 6a1257ef95..8ab22890c0 100644 --- a/service/singleton/crontask.go +++ b/service/singleton/crontask.go @@ -24,8 +24,8 @@ func InitCronTask() { Crons = make(map[uint64]*model.Cron) } -// LoadCronTasks 加载计划任务 -func LoadCronTasks() { +// loadCronTasks 加载计划任务 +func loadCronTasks() { InitCronTask() var crons []model.Cron DB.Find(&crons) diff --git a/service/singleton/nat.go b/service/singleton/nat.go new file mode 100644 index 0000000000..5a2b7cb241 --- /dev/null +++ b/service/singleton/nat.go @@ -0,0 +1,31 @@ +package singleton + +import ( + "sync" + + "github.com/naiba/nezha/model" +) + +var natCache = make(map[string]*model.NAT) +var natCacheRwLock = new(sync.RWMutex) + +func initNAT() { + OnNATUpdate() +} + +func OnNATUpdate() { + natCacheRwLock.Lock() + defer natCacheRwLock.Unlock() + var nats []*model.NAT + DB.Find(&nats) + natCache = make(map[string]*model.NAT) + for i := 0; i < len(nats); i++ { + natCache[nats[i].Domain] = nats[i] + } +} + +func GetNATConfigByDomain(domain string) *model.NAT { + natCacheRwLock.RLock() + defer natCacheRwLock.RUnlock() + return natCache[domain] +} diff --git a/service/singleton/notification.go b/service/singleton/notification.go index afc14cdaa1..a14cfa66e1 100644 --- a/service/singleton/notification.go +++ b/service/singleton/notification.go @@ -24,8 +24,8 @@ func InitNotification() { NotificationIDToTag = make(map[uint64]string) } -// LoadNotifications 从 DB 初始化通知方式相关参数 -func LoadNotifications() { +// loadNotifications 从 DB 初始化通知方式相关参数 +func loadNotifications() { InitNotification() notificationsLock.Lock() defer notificationsLock.Unlock() diff --git a/service/singleton/server.go b/service/singleton/server.go index 554f5ff751..b1209fdead 100644 --- a/service/singleton/server.go +++ b/service/singleton/server.go @@ -25,8 +25,8 @@ func InitServer() { ServerTagToIDList = make(map[string][]uint64) } -// LoadServers 加载服务器列表并根据ID排序 -func LoadServers() { +// loadServers 加载服务器列表并根据ID排序 +func loadServers() { InitServer() var servers []model.Server DB.Find(&servers) diff --git a/service/singleton/singleton.go b/service/singleton/singleton.go index 4cca923c02..9bde10e7f0 100644 --- a/service/singleton/singleton.go +++ b/service/singleton/singleton.go @@ -34,10 +34,11 @@ func InitTimezoneAndCache() { // LoadSingleton 加载子服务并执行 func LoadSingleton() { - LoadNotifications() // 加载通知服务 - LoadServers() // 加载服务器列表 - LoadCronTasks() // 加载定时任务 - LoadAPI() + loadNotifications() // 加载通知服务 + loadServers() // 加载服务器列表 + loadCronTasks() // 加载定时任务 + loadAPI() + initNAT() } // InitConfigFromPath 从给出的文件路径中加载配置 @@ -47,11 +48,11 @@ func InitConfigFromPath(path string) { if err != nil { panic(err) } - ValidateConfig() + validateConfig() } -// ValidateConfig 验证配置文件有效性 -func ValidateConfig() { +// validateConfig 验证配置文件有效性 +func validateConfig() { var err error if Conf.DDNS.Provider == "" { err = ValidateDDNSProvidersFromProfiles() @@ -82,7 +83,8 @@ func InitDBFromPath(path string) { } err = DB.AutoMigrate(model.Server{}, model.User{}, model.Notification{}, model.AlertRule{}, model.Monitor{}, - model.MonitorHistory{}, model.Cron{}, model.Transfer{}, model.ApiToken{}) + model.MonitorHistory{}, model.Cron{}, model.Transfer{}, + model.ApiToken{}, model.NAT{}) if err != nil { panic(err) }