diff --git a/server/services/touch_writer.go b/server/services/touch_writer.go index 2714e32..0d97894 100644 --- a/server/services/touch_writer.go +++ b/server/services/touch_writer.go @@ -13,7 +13,7 @@ type TouchWriter struct { h string } -func NewTouchWrtier(w http.ResponseWriter, tm *TorrentMap, h string) *TouchWriter { +func NewTouchWriter(w http.ResponseWriter, tm *TorrentMap, h string) *TouchWriter { return &TouchWriter{ ResponseWriter: w, tm: tm, diff --git a/server/services/web.go b/server/services/web.go index 9ef9626..2bb4305 100644 --- a/server/services/web.go +++ b/server/services/web.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "net/http" + "runtime/debug" logrusmiddleware "github.com/bakins/logrus-middleware" log "github.com/sirupsen/logrus" @@ -45,6 +46,25 @@ func NewWeb(c *cli.Context, ws *WebSeeder) *Web { } } +// RecoverMiddleware is a middleware that recovers from panics and logs the error. +func RecoverMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + // Log the error and stack trace + log.WithFields(log.Fields{ + "error": fmt.Sprintf("%v", err), + "stack": string(debug.Stack()), + }).Error("Recovered from panic") + + // Return 500 Internal Server Error + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + }) +} + func (s *Web) Serve() error { addr := fmt.Sprintf("%s:%d", s.host, s.port) ln, err := net.Listen("tcp", addr) @@ -58,7 +78,7 @@ func (s *Web) Serve() error { l := logrusmiddleware.Middleware{ Logger: logger, } - mux.Handle("/", l.Handler(s.ws, "")) + mux.Handle("/", l.Handler(RecoverMiddleware(s.ws), "")) log.Infof("serving Web at %v", fmt.Sprintf("%s:%d", s.host, s.port)) return http.Serve(s.ln, mux) diff --git a/server/services/web_seeder.go b/server/services/web_seeder.go index 29c810f..460b7e7 100644 --- a/server/services/web_seeder.go +++ b/server/services/web_seeder.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "net/url" + "os" "path/filepath" "regexp" "strings" @@ -20,8 +21,6 @@ var sha1R = regexp.MustCompile("^[0-9a-f]{5,40}$") const ( SourceTorrentPath = "source.torrent" - //MaxReadahead = 250 * 1024 * 1024 - //MinReadahead = 1024 * 1024 ) type WebSeeder struct { @@ -99,96 +98,96 @@ func (s *WebSeeder) serveFile(w http.ResponseWriter, r *http.Request, h string, if err != nil { log.Error(err) } - cp, err := s.fcm.Get(h, p) + + _, download := r.URL.Query()["download"] + + logWIthField := log.WithFields(log.Fields{ + "hash": h, + "path": r.URL.Path, + "method": r.Method, + "remoteAddr": r.RemoteAddr, + "download": download, + "range": r.Header.Get("Range"), + }) + + w, reader, err := s.getReader(w, h, p) if err != nil { log.Error(err) - http.Error(w, "failed to get torrent", http.StatusInternalServerError) + http.Error(w, "failed to get reader", http.StatusInternalServerError) return } - if cp != "" { - if _, ok := r.URL.Query()["stats"]; ok { - http.NotFound(w, r) - return - } - if _, ok := r.URL.Query()["done"]; ok { - return - } - w.Header().Set("Etag", fmt.Sprintf("\"%x\"", sha1.Sum([]byte(h+p)))) - http.ServeFile(w, r, cp) - return - } - if _, ok := r.URL.Query()["done"]; ok { + if reader == nil { + logWIthField.Info("file not found") http.NotFound(w, r) return } - if _, ok := r.URL.Query()["stats"]; ok { - s.serveStats(w, r, h, p) - return + defer func(reader io.ReadSeekCloser) { + _ = reader.Close() + }(reader) + + logWIthField.Info("serve file") + if download { + w.Header().Add("Content-Type", "application/octet-stream") + w.Header().Add("Content-Disposition", "attachment; filename=\""+filepath.Base(p)+"\"") } - logWIthField := log.WithField("hash", h) - logWIthField = logWIthField.WithField("path", r.URL.Path) - logWIthField = logWIthField.WithField("method", r.Method) - logWIthField = logWIthField.WithField("remoteAddr", r.RemoteAddr) - found := false - download := true - keys, ok := r.URL.Query()["download"] - if !ok || len(keys[0]) < 1 { - download = false + if r.Header.Get("Origin") != "" { + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Allow-Origin", "*") } - logWIthField = logWIthField.WithField("download", download) + w.Header().Set("Last-Modified", time.Unix(0, 0).Format(http.TimeFormat)) + w.Header().Set("Etag", fmt.Sprintf("\"%x\"", sha1.Sum([]byte(h+p)))) + http.ServeContent(w, r, p, time.Unix(0, 0), reader) +} - t, err := s.tm.Get(h) +func (s *WebSeeder) getReader(w http.ResponseWriter, h string, p string) (http.ResponseWriter, io.ReadSeekCloser, error) { + cp, err := s.fcm.Get(h, p) + if err != nil { + return nil, nil, err + } + + if cp != "" { + return s.openCachedFile(w, cp) + } + + return s.getTorrentReader(w, h, p) +} +func (s *WebSeeder) openCachedFile(w http.ResponseWriter, cp string) (http.ResponseWriter, io.ReadSeekCloser, error) { + file, err := os.Open(cp) if err != nil { - log.Error(err) - http.Error(w, "failed to get torrent", http.StatusInternalServerError) - return + return w, nil, err } + return w, file, nil +} + +func (s *WebSeeder) getTorrentReader(w http.ResponseWriter, h string, p string) (http.ResponseWriter, io.ReadSeekCloser, error) { + t, err := s.tm.Get(h) + if err != nil { + return w, nil, err + } + for _, f := range t.Files() { if f.Path() == p { - logWIthField.WithField("range", r.Header.Get("Range")).Info("serve file") - if download { - w.Header().Add("Content-Type", "application/octet-stream") - w.Header().Add("Content-Disposition", "attachment; filename=\""+filepath.Base(p)+"\"") - } - if r.Header.Get("Origin") != "" { - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Access-Control-Allow-Origin", "*") - } - var reader io.ReadSeeker torReader := f.NewReader() torReader.SetResponsive() - //torReader.SetReadaheadFunc(func(r torrent.ReadaheadContext) int64 { - // p := f.Length() / 100 - // if p < MinReadahead { - // p = MinReadahead - // } - // ra := (r.CurrentPos - r.ContiguousReadStartPos) * 2 - // if ra < p { - // return p - // } - // if ra > MaxReadahead { - // return MaxReadahead - // } - // return ra - //}) - reader = torReader - w.Header().Set("Last-Modified", time.Unix(0, 0).Format(http.TimeFormat)) - w.Header().Set("Etag", fmt.Sprintf("\"%x\"", sha1.Sum([]byte(t.InfoHash().String()+p)))) - w = NewTouchWrtier(w, s.tm, h) - http.ServeContent(w, r, f.Path(), time.Unix(0, 0), reader) - found = true + return NewTouchWriter(w, s.tm, h), torReader, nil } } - if !found { - logWIthField.Info("file not found") - - http.NotFound(w, r) - } + return w, nil, nil } func (s *WebSeeder) serveStats(w http.ResponseWriter, r *http.Request, h string, p string) { - err := s.st.Serve(w, r, h, p) + cp, err := s.fcm.Get(h, p) + if err != nil { + log.Error(err) + http.Error(w, "failed to get torrent", http.StatusInternalServerError) + return + } + if cp != "" { + http.NotFound(w, r) + return + } + err = s.st.Serve(w, r, h, p) if err != nil { log.Error(err) http.Error(w, err.Error(), http.StatusInternalServerError) @@ -225,17 +224,36 @@ func (s *WebSeeder) renderIndex(w http.ResponseWriter, r *http.Request) { } func (s *WebSeeder) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if s.getHash(r) == "" { + h := s.getHash(r) + if h == "" { s.renderIndex(w, r) } else { p := r.URL.Path[1:] - p = strings.TrimPrefix(p, s.getHash(r)+"/") + p = strings.TrimPrefix(p, h+"/") if p == "" { - s.renderTorrentIndex(w, r, s.getHash(r)) + s.renderTorrentIndex(w, r, h) } else if p == SourceTorrentPath { s.renderTorrent(w, s.getHash(r)) + } else if _, ok := r.URL.Query()["stats"]; ok { + s.serveStats(w, r, h, p) + } else if _, ok := r.URL.Query()["done"]; ok { + s.serveDone(w, r, h, p) } else { s.serveFile(w, r, s.getHash(r), p) } } } + +func (s *WebSeeder) serveDone(w http.ResponseWriter, r *http.Request, h string, p string) { + cp, err := s.fcm.Get(h, p) + if err != nil { + log.Error(err) + http.Error(w, "failed to get torrent", http.StatusInternalServerError) + return + } + if cp != "" { + return + } else { + http.NotFound(w, r) + } +}