Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
add recover middleware
  • Loading branch information
vintikzzz committed Dec 2, 2024
1 parent d53c5f0 commit 8153ac4
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 75 deletions.
2 changes: 1 addition & 1 deletion server/services/touch_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type TouchWriter struct {
h string
}

func NewTouchWrtier(w http.ResponseWriter, tm *TorrentMap, h string) *TouchWriter {
func NewTouchWriter(w http.ResponseWriter, tm *TorrentMap, h string) *TouchWriter {
return &TouchWriter{
ResponseWriter: w,
tm: tm,
Expand Down
22 changes: 21 additions & 1 deletion server/services/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net"
"net/http"
"runtime/debug"

logrusmiddleware "github.com/bakins/logrus-middleware"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -45,6 +46,25 @@ func NewWeb(c *cli.Context, ws *WebSeeder) *Web {
}
}

// RecoverMiddleware is a middleware that recovers from panics and logs the error.
func RecoverMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
// Log the error and stack trace
log.WithFields(log.Fields{
"error": fmt.Sprintf("%v", err),
"stack": string(debug.Stack()),
}).Error("Recovered from panic")

// Return 500 Internal Server Error
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}

func (s *Web) Serve() error {
addr := fmt.Sprintf("%s:%d", s.host, s.port)
ln, err := net.Listen("tcp", addr)
Expand All @@ -58,7 +78,7 @@ func (s *Web) Serve() error {
l := logrusmiddleware.Middleware{
Logger: logger,
}
mux.Handle("/", l.Handler(s.ws, ""))
mux.Handle("/", l.Handler(RecoverMiddleware(s.ws), ""))
log.Infof("serving Web at %v", fmt.Sprintf("%s:%d", s.host, s.port))
return http.Serve(s.ln, mux)

Expand Down
164 changes: 91 additions & 73 deletions server/services/web_seeder.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
Expand All @@ -20,8 +21,6 @@ var sha1R = regexp.MustCompile("^[0-9a-f]{5,40}$")

const (
SourceTorrentPath = "source.torrent"
//MaxReadahead = 250 * 1024 * 1024
//MinReadahead = 1024 * 1024
)

type WebSeeder struct {
Expand Down Expand Up @@ -99,96 +98,96 @@ func (s *WebSeeder) serveFile(w http.ResponseWriter, r *http.Request, h string,
if err != nil {
log.Error(err)
}
cp, err := s.fcm.Get(h, p)

_, download := r.URL.Query()["download"]

logWIthField := log.WithFields(log.Fields{
"hash": h,
"path": r.URL.Path,
"method": r.Method,
"remoteAddr": r.RemoteAddr,
"download": download,
"range": r.Header.Get("Range"),
})

w, reader, err := s.getReader(w, h, p)
if err != nil {
log.Error(err)
http.Error(w, "failed to get torrent", http.StatusInternalServerError)
http.Error(w, "failed to get reader", http.StatusInternalServerError)
return
}
if cp != "" {
if _, ok := r.URL.Query()["stats"]; ok {
http.NotFound(w, r)
return
}
if _, ok := r.URL.Query()["done"]; ok {
return
}
w.Header().Set("Etag", fmt.Sprintf("\"%x\"", sha1.Sum([]byte(h+p))))
http.ServeFile(w, r, cp)
return
}
if _, ok := r.URL.Query()["done"]; ok {
if reader == nil {
logWIthField.Info("file not found")
http.NotFound(w, r)
return
}
if _, ok := r.URL.Query()["stats"]; ok {
s.serveStats(w, r, h, p)
return
defer func(reader io.ReadSeekCloser) {
_ = reader.Close()
}(reader)

logWIthField.Info("serve file")
if download {
w.Header().Add("Content-Type", "application/octet-stream")
w.Header().Add("Content-Disposition", "attachment; filename=\""+filepath.Base(p)+"\"")
}
logWIthField := log.WithField("hash", h)
logWIthField = logWIthField.WithField("path", r.URL.Path)
logWIthField = logWIthField.WithField("method", r.Method)
logWIthField = logWIthField.WithField("remoteAddr", r.RemoteAddr)
found := false
download := true
keys, ok := r.URL.Query()["download"]
if !ok || len(keys[0]) < 1 {
download = false
if r.Header.Get("Origin") != "" {
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Origin", "*")
}
logWIthField = logWIthField.WithField("download", download)
w.Header().Set("Last-Modified", time.Unix(0, 0).Format(http.TimeFormat))
w.Header().Set("Etag", fmt.Sprintf("\"%x\"", sha1.Sum([]byte(h+p))))
http.ServeContent(w, r, p, time.Unix(0, 0), reader)
}

t, err := s.tm.Get(h)
func (s *WebSeeder) getReader(w http.ResponseWriter, h string, p string) (http.ResponseWriter, io.ReadSeekCloser, error) {
cp, err := s.fcm.Get(h, p)
if err != nil {
return nil, nil, err
}

if cp != "" {
return s.openCachedFile(w, cp)
}

return s.getTorrentReader(w, h, p)
}

func (s *WebSeeder) openCachedFile(w http.ResponseWriter, cp string) (http.ResponseWriter, io.ReadSeekCloser, error) {
file, err := os.Open(cp)
if err != nil {
log.Error(err)
http.Error(w, "failed to get torrent", http.StatusInternalServerError)
return
return w, nil, err
}
return w, file, nil
}

func (s *WebSeeder) getTorrentReader(w http.ResponseWriter, h string, p string) (http.ResponseWriter, io.ReadSeekCloser, error) {
t, err := s.tm.Get(h)
if err != nil {
return w, nil, err
}

for _, f := range t.Files() {
if f.Path() == p {
logWIthField.WithField("range", r.Header.Get("Range")).Info("serve file")
if download {
w.Header().Add("Content-Type", "application/octet-stream")
w.Header().Add("Content-Disposition", "attachment; filename=\""+filepath.Base(p)+"\"")
}
if r.Header.Get("Origin") != "" {
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Origin", "*")
}
var reader io.ReadSeeker
torReader := f.NewReader()
torReader.SetResponsive()
//torReader.SetReadaheadFunc(func(r torrent.ReadaheadContext) int64 {
// p := f.Length() / 100
// if p < MinReadahead {
// p = MinReadahead
// }
// ra := (r.CurrentPos - r.ContiguousReadStartPos) * 2
// if ra < p {
// return p
// }
// if ra > MaxReadahead {
// return MaxReadahead
// }
// return ra
//})
reader = torReader
w.Header().Set("Last-Modified", time.Unix(0, 0).Format(http.TimeFormat))
w.Header().Set("Etag", fmt.Sprintf("\"%x\"", sha1.Sum([]byte(t.InfoHash().String()+p))))
w = NewTouchWrtier(w, s.tm, h)
http.ServeContent(w, r, f.Path(), time.Unix(0, 0), reader)
found = true
return NewTouchWriter(w, s.tm, h), torReader, nil
}
}
if !found {
logWIthField.Info("file not found")

http.NotFound(w, r)
}
return w, nil, nil
}

func (s *WebSeeder) serveStats(w http.ResponseWriter, r *http.Request, h string, p string) {
err := s.st.Serve(w, r, h, p)
cp, err := s.fcm.Get(h, p)
if err != nil {
log.Error(err)
http.Error(w, "failed to get torrent", http.StatusInternalServerError)
return
}
if cp != "" {
http.NotFound(w, r)
return
}
err = s.st.Serve(w, r, h, p)
if err != nil {
log.Error(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -225,17 +224,36 @@ func (s *WebSeeder) renderIndex(w http.ResponseWriter, r *http.Request) {
}

func (s *WebSeeder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if s.getHash(r) == "" {
h := s.getHash(r)
if h == "" {
s.renderIndex(w, r)
} else {
p := r.URL.Path[1:]
p = strings.TrimPrefix(p, s.getHash(r)+"/")
p = strings.TrimPrefix(p, h+"/")
if p == "" {
s.renderTorrentIndex(w, r, s.getHash(r))
s.renderTorrentIndex(w, r, h)
} else if p == SourceTorrentPath {
s.renderTorrent(w, s.getHash(r))
} else if _, ok := r.URL.Query()["stats"]; ok {
s.serveStats(w, r, h, p)
} else if _, ok := r.URL.Query()["done"]; ok {
s.serveDone(w, r, h, p)
} else {
s.serveFile(w, r, s.getHash(r), p)
}
}
}

func (s *WebSeeder) serveDone(w http.ResponseWriter, r *http.Request, h string, p string) {
cp, err := s.fcm.Get(h, p)
if err != nil {
log.Error(err)
http.Error(w, "failed to get torrent", http.StatusInternalServerError)
return
}
if cp != "" {
return
} else {
http.NotFound(w, r)
}
}

0 comments on commit 8153ac4

Please sign in to comment.