From ea6d0683be9f5284b6e1ad78751020067cb48a4d Mon Sep 17 00:00:00 2001 From: Dean Ward Date: Wed, 12 Feb 2025 22:10:41 +0000 Subject: [PATCH] Added streaming file uploads to reduce memory usage and provide progress feedback on the front end --- .gitignore | 3 +- frontend/src/api.js | 9 +- frontend/src/components/uploader.vue | 120 +++++++++++++++- frontend/src/utils.js | 12 +- handlers/progress.go | 120 ++++++++++++++++ handlers/shares.go | 199 +++++++++++++++------------ progress/progress.go | 81 +++++++++++ routes/routes.go | 6 + 8 files changed, 448 insertions(+), 102 deletions(-) create mode 100644 handlers/progress.go create mode 100644 progress/progress.go diff --git a/.gitignore b/.gitignore index 0b00543..d07c28c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,8 @@ .DS_Store storage/ -erugo.db +*.db frontend/node_modules frontend/dist build/* private +*/*.db diff --git a/frontend/src/api.js b/frontend/src/api.js index 79e8ba7..71c9d00 100644 --- a/frontend/src/api.js +++ b/frontend/src/api.js @@ -169,8 +169,6 @@ export const createUser = async user => { return data.data } - - export const updateUser = async user => { const response = await fetchWithAuth(`${apiUrl}/api/users/${user.id}`, { method: 'PUT', @@ -274,7 +272,9 @@ export const saveLogo = async logoFile => { return data.data } -export const createShare = async (files, name, description) => { +// Share Methods +export const createShare = async (files, name, description, uploadId) => { + const formData = new FormData() files.forEach(file => { formData.append('files', file) @@ -282,7 +282,7 @@ export const createShare = async (files, name, description) => { formData.append('name', name) formData.append('description', description) - const response = await fetchWithAuth(`${apiUrl}/api/shares`, { + const response = await fetchWithAuth(`${apiUrl}/api/shares?uploadId=${uploadId}`, { method: 'POST', body: formData }) @@ -307,6 +307,7 @@ export const getShare = async id => { return data.data.share } +//misc methods export const getHealth = async () => { const response = await fetch(`${apiUrl}/api/health`) const data = await response.json() diff --git a/frontend/src/components/uploader.vue b/frontend/src/components/uploader.vue index fefb9f5..d835388 100644 --- a/frontend/src/components/uploader.vue +++ b/frontend/src/components/uploader.vue @@ -1,16 +1,19 @@ @@ -108,6 +148,17 @@
{{ niceFileSize(totalSize) }} / {{ niceFileSize(maxShareSize) }}
+
+
+
+
+
+
+ {{ Math.round(uploadProgress) }}% +
{{ niceFileSize(uploadedBytes) }} / {{ niceFileSize(totalBytes) }}
+
+
+
@@ -165,3 +216,62 @@
+ + diff --git a/frontend/src/utils.js b/frontend/src/utils.js index cee8c5a..1063749 100644 --- a/frontend/src/utils.js +++ b/frontend/src/utils.js @@ -1,5 +1,13 @@ - +const simpleUUID = () => { + //this isn't cryptographically secure, but it's good enough for our purposes + //our purposes being a simple unique string to track upload progress via SSE + return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, function(c) { + const r = Math.random() * 16 | 0; + const v = c === 'x' ? r : (r & 0x3 | 0x8); + return v.toString(16); + }); +} const niceFileSize = size => { //return in most readable format @@ -48,4 +56,4 @@ const getApiUrl = () => { return url } -export { niceFileSize, niceFileType, niceExpirationDate, timeUntilExpiration, getApiUrl } +export { niceFileSize, niceFileType, niceExpirationDate, timeUntilExpiration, getApiUrl, simpleUUID } diff --git a/handlers/progress.go b/handlers/progress.go new file mode 100644 index 0000000..d7d8858 --- /dev/null +++ b/handlers/progress.go @@ -0,0 +1,120 @@ +package handlers + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/DeanWard/erugo/progress" + "github.com/gorilla/mux" +) + +func UploadProgressHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + uploadID := mux.Vars(r)["uploadId"] + if uploadID == "" { + http.Error(w, "Upload ID required", http.StatusBadRequest) + return + } + + // Set headers for SSE + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + // Get the progress channel for this upload + tracker := progress.GetTracker() + progressChan, exists := tracker.GetUploadChannel(uploadID) + if !exists { + http.Error(w, "Upload not found", http.StatusNotFound) + return + } + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported", http.StatusInternalServerError) + return + } + + // Use the request's context to detect client disconnect + ctx := r.Context() + + for { + select { + case <-ctx.Done(): + // Client disconnected + tracker.DeleteUpload(uploadID) + return + + case progress, ok := <-progressChan: + if !ok { + // Channel closed - upload complete or failed + return + } + + // Send progress update + data, _ := json.Marshal(progress) + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + } + } + }) +} + +type ProgressReader struct { + Reader io.Reader + Size int64 // Current file size + TotalFileSize int64 // Size of all files + bytesRead int64 // Current file bytes read + totalRead int64 // Total bytes read across all files + lastUpdate time.Time + uploadID string + tracker *progress.ProgressTracker +} + +func NewProgressReader(reader io.Reader, size int64, totalSize int64, totalRead int64, uploadID string) *ProgressReader { + return &ProgressReader{ + Reader: reader, + Size: size, + TotalFileSize: totalSize, + totalRead: totalRead, + uploadID: uploadID, + tracker: progress.GetTracker(), + lastUpdate: time.Now(), + } +} + +func (pr *ProgressReader) Read(p []byte) (int, error) { + n, err := pr.Reader.Read(p) + + pr.bytesRead += int64(n) + pr.totalRead += int64(n) + + // Update progress every 500ms + if time.Since(pr.lastUpdate) > 500*time.Millisecond { + if progressChan, exists := pr.tracker.GetUploadChannel(pr.uploadID); exists { + + totalProgress := float64(pr.totalRead) / float64(pr.TotalFileSize) * 100 + + select { + case progressChan <- progress.Progress{ + BytesRead: pr.bytesRead, + TotalSize: pr.Size, + TotalBytesRead: pr.totalRead, + TotalFileSize: pr.TotalFileSize, + TotalProgress: totalProgress, + UploadID: pr.uploadID, + LastUpdate: time.Now(), + }: + + default: + } + } + pr.lastUpdate = time.Now() + } + + return n, err +} diff --git a/handlers/shares.go b/handlers/shares.go index a707101..5bbd9f2 100644 --- a/handlers/shares.go +++ b/handlers/shares.go @@ -3,21 +3,21 @@ package handlers import ( "database/sql" "fmt" + "io" "net/http" "os" "path/filepath" - "strings" "time" "github.com/DeanWard/erugo/config" "github.com/DeanWard/erugo/db" "github.com/DeanWard/erugo/middleware" "github.com/DeanWard/erugo/models" + "github.com/DeanWard/erugo/progress" "github.com/DeanWard/erugo/responses" "github.com/DeanWard/erugo/responses/file_response" "github.com/DeanWard/erugo/store" "github.com/DeanWard/erugo/utils" - "github.com/go-playground/validator/v10" "github.com/gorilla/mux" "github.com/yelinaung/go-haikunator" ) @@ -45,82 +45,6 @@ func DownloadShareHandler(database *sql.DB) http.Handler { }) } -func handleCreateShare(database *sql.DB, w http.ResponseWriter, r *http.Request) { - - utils.Log("Creating share", utils.ColorGreen) - - userID := r.Context().Value(middleware.ContextKey("userID")).(int) - utils.Log(fmt.Sprintf("User ID: %d", userID), utils.ColorGreen) - - if err := r.ParseMultipartForm(10 << 20); err != nil { - http.Error(w, "Failed to parse form", http.StatusBadRequest) - return - } - - validate := validator.New() - files := r.MultipartForm.File["files"] - req := models.CreateShareRequest{ - Files: files, - } - - if err := validate.Struct(req); err != nil { - payload := map[string]interface{}{ - "errors": extractValidationErrors(err), - } - responses.SendResponse(w, responses.StatusError, "Validation error", payload, http.StatusBadRequest) - //log the error - utils.Log(fmt.Sprintf("Validation error: %v", payload), utils.ColorRed) - return - } - - maxShareSize := config.GetMaxShareSize() - utils.Log(fmt.Sprintf("Max share size: %d", maxShareSize), utils.ColorGreen) - if maxShareSize == 0 { - responses.SendResponse(w, responses.StatusError, "Max share size not set", nil, http.StatusInternalServerError) - utils.Log("Max share size not set", utils.ColorRed) - return - } - - //check if the total size of the files in the request is greater than the max share size - totalRequestSize := r.ContentLength - if totalRequestSize > maxShareSize { - responses.SendResponse(w, responses.StatusError, "Total size of files is greater than the max share size", nil, http.StatusBadRequest) - utils.Log("Total size of files is greater than the max share size", utils.ColorRed) - return - } - - longId := generateUniqueLongId(database) - folderPath := store.StoreUploadedFiles(config.AppConfig.BaseStoragePath, files, longId) - if folderPath == "" { - responses.SendResponse(w, responses.StatusError, "There was an error saving your files.", nil, http.StatusBadRequest) - utils.Log("There was an error saving users files.", utils.ColorRed) - return - } - totalSize := store.GetTotalSize(folderPath) - fileNames := store.GetFilesInFolder(folderPath) - - expirationDate := time.Now().AddDate(0, 0, 7).Format(time.RFC3339) - share := &models.Share{ - FilePath: folderPath, - ExpirationDate: expirationDate, - LongId: longId, - NumFiles: len(files), - TotalSize: totalSize, - Files: fileNames, - UserId: userID, - } - savedShare, err := db.ShareCreate(database, share) - if err != nil { - responses.SendResponse(w, responses.StatusError, "Failed to create share", nil, http.StatusInternalServerError) - utils.Log("Failed to create share", utils.ColorRed) - return - } - payload := map[string]interface{}{ - "share": savedShare.ToShareResponse(), - } - responses.SendResponse(w, responses.StatusSuccess, "Share created", payload, http.StatusOK) -} - func handleGetShare(database *sql.DB, w http.ResponseWriter, longId string) { share := db.ShareByLongId(database, longId) if share == nil { @@ -186,18 +110,6 @@ func handleDownloadShare(database *sql.DB, w http.ResponseWriter, r *http.Reques file_response.New(downloadFilePath, fileName).Send(w, r) } -func extractValidationErrors(err error) map[string]string { - errorMap := make(map[string]string) - - if validationErrors, ok := err.(validator.ValidationErrors); ok { - for _, e := range validationErrors { - errorMap[strings.ToLower(e.Field())] = e.Tag() // e.Tag() gives the failed rule (e.g., required, email) - } - } - - return errorMap -} - func IsShareExpired(share *models.Share) bool { expirationDate, err := time.Parse(time.RFC3339, share.ExpirationDate) if err != nil { @@ -206,3 +118,110 @@ func IsShareExpired(share *models.Share) bool { } return time.Now().After(expirationDate) } + +func handleCreateShare(database *sql.DB, w http.ResponseWriter, r *http.Request) { + utils.Log("Creating share", utils.ColorGreen) + + userID := r.Context().Value(middleware.ContextKey("userID")).(int) + uploadID := r.URL.Query().Get("uploadId") + if uploadID == "" { + responses.SendResponse(w, responses.StatusError, "Upload ID required", nil, http.StatusBadRequest) + return + } + + reader, err := r.MultipartReader() + if err != nil { + responses.SendResponse(w, responses.StatusError, "Failed to create multipart reader", nil, http.StatusBadRequest) + return + } + + // Initialize progress tracking + tracker := progress.GetTracker() + tracker.NewUpload(uploadID) + defer tracker.DeleteUpload(uploadID) + + maxShareSize := config.GetMaxShareSize() + if maxShareSize == 0 { + responses.SendResponse(w, responses.StatusError, "Max share size not set", nil, http.StatusInternalServerError) + return + } + + totalSize := r.ContentLength // Total size of the upload + utils.Log(fmt.Sprintf("Total size: %d", totalSize), utils.ColorGreen) + longId := generateUniqueLongId(database) + uploadDir := filepath.Join(config.AppConfig.BaseStoragePath, longId) + if err := os.MkdirAll(uploadDir, 0755); err != nil { + responses.SendResponse(w, responses.StatusError, "Failed to create upload directory", nil, http.StatusInternalServerError) + return + } + + var fileNames []string + var actualTotalSize int64 + var totalBytesRead int64 + + for { + part, err := reader.NextPart() + if err == io.EOF { + break + } + if err != nil { + responses.SendResponse(w, responses.StatusError, "Error reading upload", nil, http.StatusBadRequest) + return + } + defer part.Close() + + if part.FileName() == "" { + continue + } + + destPath := filepath.Join(uploadDir, part.FileName()) + dest, err := os.Create(destPath) + if err != nil { + responses.SendResponse(w, responses.StatusError, "Failed to create file", nil, http.StatusInternalServerError) + return + } + defer dest.Close() + + // Create progress reader with both current file and total size + pr := NewProgressReader(part, 1, totalSize, totalBytesRead, uploadID) // per-file progress isn't working yet so just use 1 for now + + written, err := io.Copy(dest, pr) + if err != nil { + responses.SendResponse(w, responses.StatusError, "Failed to save file", nil, http.StatusInternalServerError) + return + } + + totalBytesRead += written + actualTotalSize += written + fileNames = append(fileNames, part.FileName()) + + if actualTotalSize > maxShareSize { + os.RemoveAll(uploadDir) + responses.SendResponse(w, responses.StatusError, "Total size of files exceeds maximum allowed", nil, http.StatusBadRequest) + return + } + } + + expirationDate := time.Now().AddDate(0, 0, 7).Format(time.RFC3339) + share := &models.Share{ + FilePath: uploadDir, + ExpirationDate: expirationDate, + LongId: longId, + NumFiles: len(fileNames), + TotalSize: actualTotalSize, + Files: fileNames, + UserId: userID, + } + + savedShare, err := db.ShareCreate(database, share) + if err != nil { + os.RemoveAll(uploadDir) + responses.SendResponse(w, responses.StatusError, "Failed to create share", nil, http.StatusInternalServerError) + return + } + + payload := map[string]interface{}{ + "share": savedShare.ToShareResponse(), + } + responses.SendResponse(w, responses.StatusSuccess, "Share created", payload, http.StatusOK) +} diff --git a/progress/progress.go b/progress/progress.go new file mode 100644 index 0000000..49ed664 --- /dev/null +++ b/progress/progress.go @@ -0,0 +1,81 @@ +package progress + +import ( + "sync" + "time" +) + +type Progress struct { + BytesRead int64 `json:"bytesRead"` // Current file bytes read + TotalSize int64 `json:"totalSize"` // Current file total size + Percentage float64 `json:"percentage"` // Current file percentage + TotalBytesRead int64 `json:"totalBytesRead"` // Total bytes read across all files + TotalFileSize int64 `json:"totalFileSize"` // Total size of all files + TotalProgress float64 `json:"totalProgress"` // Overall progress percentage + UploadID string `json:"uploadId"` + LastUpdate time.Time +} + +type ProgressTracker struct { + mu sync.RWMutex + uploads map[string]chan Progress + cleanup map[string]chan struct{} +} + +var tracker *ProgressTracker +var once sync.Once + +func GetTracker() *ProgressTracker { + once.Do(func() { + tracker = &ProgressTracker{ + uploads: make(map[string]chan Progress), + cleanup: make(map[string]chan struct{}), + } + }) + return tracker +} + +func (t *ProgressTracker) NewUpload(uploadID string) chan Progress { + t.mu.Lock() + defer t.mu.Unlock() + + progressChan := make(chan Progress, 100) + cleanupChan := make(chan struct{}) + t.uploads[uploadID] = progressChan + t.cleanup[uploadID] = cleanupChan + + // Automatic cleanup after 1 hour + go func() { + select { + case <-cleanupChan: + // Normal cleanup + case <-time.After(1 * time.Hour): + // Timeout cleanup + } + t.DeleteUpload(uploadID) + }() + + return progressChan +} + +func (t *ProgressTracker) GetUploadChannel(uploadID string) (chan Progress, bool) { + t.mu.RLock() + defer t.mu.RUnlock() + ch, exists := t.uploads[uploadID] + return ch, exists +} + +func (t *ProgressTracker) DeleteUpload(uploadID string) { + t.mu.Lock() + defer t.mu.Unlock() + + if cleanupChan, exists := t.cleanup[uploadID]; exists { + close(cleanupChan) + delete(t.cleanup, uploadID) + } + + if progressChan, exists := t.uploads[uploadID]; exists { + close(progressChan) + delete(t.uploads, uploadID) + } +} diff --git a/routes/routes.go b/routes/routes.go index 6d7623c..90dd136 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -52,6 +52,7 @@ func registerAuthRoutes(router *mux.Router, database *sql.DB) { func registerShareRoutes(router *mux.Router, database *sql.DB) { log.Println("registering share routes") + //POST /api/shares - create a new share router.Handle("/api/shares", middleware.JwtMiddleware( @@ -59,6 +60,11 @@ func registerShareRoutes(router *mux.Router, database *sql.DB) { ), ).Methods("POST") + //GET /api/shares/progress/{uploadId} - get the progress of an upload + router.Handle("/api/shares/progress/{uploadId}", + handlers.UploadProgressHandler(), + ).Methods("GET") + //GET /api/shares/{longId} - get a share by longId router.Handle("/api/shares/{longId}", handlers.GetShareHandler(database),