From e1325c3a14065d7d1cf15d82a9acd8c789865814 Mon Sep 17 00:00:00 2001 From: seekwe Date: Sat, 7 Sep 2024 11:05:46 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9A=99=EF=B8=8F=20refactor:=20Refactoring=20?= =?UTF-8?q?handles=20static=20files?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- znet/router.go | 47 +++++++++++++++++++++++++++++++++-------------- znet/web_test.go | 17 ++++++++++++++--- 2 files changed, 47 insertions(+), 17 deletions(-) diff --git a/znet/router.go b/znet/router.go index f94f85a..8a9148d 100644 --- a/znet/router.go +++ b/znet/router.go @@ -3,6 +3,8 @@ package znet import ( "errors" "fmt" + "io" + "io/fs" "net/http" "path" "regexp" @@ -63,6 +65,18 @@ func temporarilyTurnOffTheLog(e *Engine, msg string) func() { } } +func (c *Context) toHTTPError(err error) { + if errors.Is(err, fs.ErrNotExist) { + c.String(http.StatusNotFound, "404 page not found") + return + } + if errors.Is(err, fs.ErrPermission) { + c.String(http.StatusForbidden, "403 Forbidden") + return + } + c.String(http.StatusInternalServerError, "500 Internal Server Error") +} + func (e *Engine) StaticFS(relativePath string, fs http.FileSystem, moreHandler ...Handler) { var urlPattern string @@ -72,19 +86,23 @@ func (e *Engine) StaticFS(relativePath string, fs http.FileSystem, moreHandler . f = "%s %-40s" } log := temporarilyTurnOffTheLog(e, routeLog(e.Log, f, "FILE", ap)) - fileServer := http.StripPrefix(ap, http.FileServer(fs)) handler := func(c *Context) { - for key, value := range c.header { - for i := range value { - header := value[i] - if i == 0 { - c.Writer.Header().Set(key, header) - } else { - c.Writer.Header().Add(key, header) - } - } + p := strings.TrimPrefix(c.Request.URL.Path, relativePath) + f, err := fs.Open(p) + if err != nil { + c.toHTTPError(err) + return } - fileServer.ServeHTTP(c.Writer, c.Request) + + defer f.Close() + c.prevData.Content, err = io.ReadAll(f) + if err != nil { + c.toHTTPError(err) + return + } + + c.prevData.Type = zfile.GetMimeType(p, c.prevData.Content) + c.prevData.Code.Store(http.StatusOK) } if strings.HasSuffix(relativePath, "/") { urlPattern = path.Join(relativePath, "*") @@ -106,7 +124,7 @@ func (e *Engine) Static(relativePath, root string, moreHandler ...Handler) { e.StaticFS(relativePath, http.Dir(root), moreHandler...) } -func (e *Engine) StaticFile(relativePath, filepath string) { +func (e *Engine) StaticFile(relativePath, filepath string, moreHandler ...Handler) { handler := func(c *Context) { c.File(filepath) } @@ -116,8 +134,9 @@ func (e *Engine) StaticFile(relativePath, filepath string) { tip = routeLog(e.Log, "%s %-40s", "FILE", relativePath) } log := temporarilyTurnOffTheLog(e, tip) - e.GET(relativePath, handler) - e.HEAD(relativePath, handler) + e.GET(relativePath, handler, moreHandler...) + e.HEAD(relativePath, handler, moreHandler...) + e.OPTIONS(relativePath, handler, moreHandler...) log() } diff --git a/znet/web_test.go b/znet/web_test.go index a889f6e..4419897 100644 --- a/znet/web_test.go +++ b/znet/web_test.go @@ -180,7 +180,6 @@ func TestMoreMethod(t *testing.T) { g.TRACE("/", h("TRACE")) g.POST("/", h("POST")) g.PUT("/", h("PUT")) - for _, v := range []string{"CONNECT", "TRACE", "PUT", "DELETE", "POST", "OPTIONS"} { w = httptest.NewRecorder() req, _ = http.NewRequest(v, "/TestMore/", nil) @@ -658,8 +657,10 @@ func TestBind(t *testing.T) { } tt := zlsgo.NewTest(t) r := newServer() - w := newRequest(r, "POST", []string{"/TestBind", - `{"appid":"Aid","appids":[{"label":"isLabel","id":"333"}]}`, ContentTypeJSON}, "/TestBind", func(c *Context) { + w := newRequest(r, "POST", []string{ + "/TestBind", + `{"appid":"Aid","appids":[{"label":"isLabel","id":"333"}]}`, ContentTypeJSON, + }, "/TestBind", func(c *Context) { json, _ := c.GetJSONs() var appids []AppInfo json.Get("appids").ForEach(func(key, value *zjson.Res) bool { @@ -887,3 +888,13 @@ func TestMethodAndName(t *testing.T) { t.Log(r.GenerateURL(http.MethodPost, "non existent", nil)) } + +func TestStatic(t *testing.T) { + tt := zlsgo.NewTest(t) + r := New("Static") + r.StaticFile("/web.go", "../znet/web.go") + r.Static("/ss", "../") + w := request(r, "GET", "/ss/znet/web.go", nil) + tt.Equal(200, w.Code) + tt.EqualTrue(len(w.Body.String()) > 10*zfile.KB) +}