diff --git a/.gitignore b/.gitignore
index 0b00543..d07c28c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,8 @@
.DS_Store
storage/
-erugo.db
+*.db
frontend/node_modules
frontend/dist
build/*
private
+*/*.db
diff --git a/frontend/src/api.js b/frontend/src/api.js
index 79e8ba7..71c9d00 100644
--- a/frontend/src/api.js
+++ b/frontend/src/api.js
@@ -169,8 +169,6 @@ export const createUser = async user => {
return data.data
}
-
-
export const updateUser = async user => {
const response = await fetchWithAuth(`${apiUrl}/api/users/${user.id}`, {
method: 'PUT',
@@ -274,7 +272,9 @@ export const saveLogo = async logoFile => {
return data.data
}
-export const createShare = async (files, name, description) => {
+// Share Methods
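+// createShare expects a client-generated uploadId (see simpleUUID in utils.js);
+// the server uses it to correlate this POST with the SSE progress stream.
+// Typical usage is to subscribe before uploading; a sketch, not part of this
+// diff (updateUi is a hypothetical callback):
+//   const es = new EventSource(`${apiUrl}/api/shares/progress/${uploadId}`)
+//   es.onmessage = e => updateUi(JSON.parse(e.data))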
+export const createShare = async (files, name, description, uploadId) => {
const formData = new FormData()
files.forEach(file => {
formData.append('files', file)
@@ -282,7 +282,7 @@ export const createShare = async (files, name, description) => {
formData.append('name', name)
formData.append('description', description)
- const response = await fetchWithAuth(`${apiUrl}/api/shares`, {
+ const response = await fetchWithAuth(`${apiUrl}/api/shares?uploadId=${uploadId}`, {
method: 'POST',
body: formData
})
@@ -307,6 +307,7 @@ export const getShare = async id => {
return data.data.share
}
+// Misc methods
export const getHealth = async () => {
const response = await fetch(`${apiUrl}/api/health`)
const data = await response.json()
diff --git a/frontend/src/components/uploader.vue b/frontend/src/components/uploader.vue
index fefb9f5..d835388 100644
--- a/frontend/src/components/uploader.vue
+++ b/frontend/src/components/uploader.vue
@@ -108,6 +148,17 @@
   {{ niceFileSize(totalSize) }} / {{ niceFileSize(maxShareSize) }}
+  <!-- progress bar markup elided from this diff -->
+  {{ Math.round(uploadProgress) }}%
   {{ niceFileSize(uploadedBytes) }} / {{ niceFileSize(totalBytes) }}
@@ -165,3 +216,62 @@
+  <!-- script and style additions elided from this diff -->
diff --git a/frontend/src/utils.js b/frontend/src/utils.js
index cee8c5a..1063749 100644
--- a/frontend/src/utils.js
+++ b/frontend/src/utils.js
@@ -1,5 +1,13 @@
-
+const simpleUUID = () => {
+ //this isn't cryptographically secure, but it's good enough for our purposes
+ //our purposes being a simple unique string to track upload progress via SSE
+ return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, function(c) {
+ const r = Math.random() * 16 | 0;
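+    // 'x' becomes a random hex digit; 'y' becomes 8, 9, a, or b (the RFC 4122 variant bits)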
+ const v = c === 'x' ? r : (r & 0x3 | 0x8);
+ return v.toString(16);
+ });
+}
const niceFileSize = size => {
//return in most readable format
@@ -48,4 +56,4 @@ const getApiUrl = () => {
return url
}
-export { niceFileSize, niceFileType, niceExpirationDate, timeUntilExpiration, getApiUrl }
+export { niceFileSize, niceFileType, niceExpirationDate, timeUntilExpiration, getApiUrl, simpleUUID }
diff --git a/handlers/progress.go b/handlers/progress.go
new file mode 100644
index 0000000..d7d8858
--- /dev/null
+++ b/handlers/progress.go
@@ -0,0 +1,120 @@
+package handlers
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "time"
+
+ "github.com/DeanWard/erugo/progress"
+ "github.com/gorilla/mux"
+)
+
+func UploadProgressHandler() http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ uploadID := mux.Vars(r)["uploadId"]
+ if uploadID == "" {
+ http.Error(w, "Upload ID required", http.StatusBadRequest)
+ return
+ }
+
+ // Set headers for SSE
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set("Cache-Control", "no-cache")
+ w.Header().Set("Connection", "keep-alive")
+ w.Header().Set("Access-Control-Allow-Origin", "*")
+
+ // Get the progress channel for this upload
+ tracker := progress.GetTracker()
+ progressChan, exists := tracker.GetUploadChannel(uploadID)
+ if !exists {
+ http.Error(w, "Upload not found", http.StatusNotFound)
+ return
+ }
+
+ flusher, ok := w.(http.Flusher)
+ if !ok {
+ http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
+ return
+ }
+
+ // Use the request's context to detect client disconnect
+ ctx := r.Context()
+
+ for {
+ select {
+ case <-ctx.Done():
+ // Client disconnected
+ tracker.DeleteUpload(uploadID)
+ return
+
+		case update, ok := <-progressChan:
+ if !ok {
+ // Channel closed - upload complete or failed
+ return
+ }
+
+ // Send progress update
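+			// SSE frames each payload as "data: <json>" terminated by a blank line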
+			data, _ := json.Marshal(update)
+ fmt.Fprintf(w, "data: %s\n\n", data)
+ flusher.Flush()
+ }
+ }
+ })
+}
+
+type ProgressReader struct {
+ Reader io.Reader
+ Size int64 // Current file size
+ TotalFileSize int64 // Size of all files
+ bytesRead int64 // Current file bytes read
+ totalRead int64 // Total bytes read across all files
+ lastUpdate time.Time
+ uploadID string
+ tracker *progress.ProgressTracker
+}
+
+func NewProgressReader(reader io.Reader, size int64, totalSize int64, totalRead int64, uploadID string) *ProgressReader {
+ return &ProgressReader{
+ Reader: reader,
+ Size: size,
+ TotalFileSize: totalSize,
+ totalRead: totalRead,
+ uploadID: uploadID,
+ tracker: progress.GetTracker(),
+ lastUpdate: time.Now(),
+ }
+}
+
+func (pr *ProgressReader) Read(p []byte) (int, error) {
+ n, err := pr.Reader.Read(p)
+
+ pr.bytesRead += int64(n)
+ pr.totalRead += int64(n)
+
+ // Update progress every 500ms
+ if time.Since(pr.lastUpdate) > 500*time.Millisecond {
+ if progressChan, exists := pr.tracker.GetUploadChannel(pr.uploadID); exists {
+
+ totalProgress := float64(pr.totalRead) / float64(pr.TotalFileSize) * 100
+
+ select {
+ case progressChan <- progress.Progress{
+ BytesRead: pr.bytesRead,
+ TotalSize: pr.Size,
+ TotalBytesRead: pr.totalRead,
+ TotalFileSize: pr.TotalFileSize,
+ TotalProgress: totalProgress,
+ UploadID: pr.uploadID,
+ LastUpdate: time.Now(),
+ }:
+
+ default:
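+				// Buffer full: drop this update rather than stall the upload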
+ }
+ }
+ pr.lastUpdate = time.Now()
+ }
+
+ return n, err
+}
diff --git a/handlers/shares.go b/handlers/shares.go
index a707101..5bbd9f2 100644
--- a/handlers/shares.go
+++ b/handlers/shares.go
@@ -3,21 +3,21 @@ package handlers
import (
"database/sql"
"fmt"
+ "io"
"net/http"
"os"
"path/filepath"
- "strings"
"time"
"github.com/DeanWard/erugo/config"
"github.com/DeanWard/erugo/db"
"github.com/DeanWard/erugo/middleware"
"github.com/DeanWard/erugo/models"
+ "github.com/DeanWard/erugo/progress"
"github.com/DeanWard/erugo/responses"
"github.com/DeanWard/erugo/responses/file_response"
"github.com/DeanWard/erugo/store"
"github.com/DeanWard/erugo/utils"
- "github.com/go-playground/validator/v10"
"github.com/gorilla/mux"
"github.com/yelinaung/go-haikunator"
)
@@ -45,82 +45,6 @@ func DownloadShareHandler(database *sql.DB) http.Handler {
})
}
-func handleCreateShare(database *sql.DB, w http.ResponseWriter, r *http.Request) {
-
- utils.Log("Creating share", utils.ColorGreen)
-
- userID := r.Context().Value(middleware.ContextKey("userID")).(int)
- utils.Log(fmt.Sprintf("User ID: %d", userID), utils.ColorGreen)
-
- if err := r.ParseMultipartForm(10 << 20); err != nil {
- http.Error(w, "Failed to parse form", http.StatusBadRequest)
- return
- }
-
- validate := validator.New()
- files := r.MultipartForm.File["files"]
- req := models.CreateShareRequest{
- Files: files,
- }
-
- if err := validate.Struct(req); err != nil {
- payload := map[string]interface{}{
- "errors": extractValidationErrors(err),
- }
- responses.SendResponse(w, responses.StatusError, "Validation error", payload, http.StatusBadRequest)
- //log the error
- utils.Log(fmt.Sprintf("Validation error: %v", payload), utils.ColorRed)
- return
- }
-
- maxShareSize := config.GetMaxShareSize()
- utils.Log(fmt.Sprintf("Max share size: %d", maxShareSize), utils.ColorGreen)
- if maxShareSize == 0 {
- responses.SendResponse(w, responses.StatusError, "Max share size not set", nil, http.StatusInternalServerError)
- utils.Log("Max share size not set", utils.ColorRed)
- return
- }
-
- //check if the total size of the files in the request is greater than the max share size
- totalRequestSize := r.ContentLength
- if totalRequestSize > maxShareSize {
- responses.SendResponse(w, responses.StatusError, "Total size of files is greater than the max share size", nil, http.StatusBadRequest)
- utils.Log("Total size of files is greater than the max share size", utils.ColorRed)
- return
- }
-
- longId := generateUniqueLongId(database)
- folderPath := store.StoreUploadedFiles(config.AppConfig.BaseStoragePath, files, longId)
- if folderPath == "" {
- responses.SendResponse(w, responses.StatusError, "There was an error saving your files.", nil, http.StatusBadRequest)
- utils.Log("There was an error saving users files.", utils.ColorRed)
- return
- }
- totalSize := store.GetTotalSize(folderPath)
- fileNames := store.GetFilesInFolder(folderPath)
-
- expirationDate := time.Now().AddDate(0, 0, 7).Format(time.RFC3339)
- share := &models.Share{
- FilePath: folderPath,
- ExpirationDate: expirationDate,
- LongId: longId,
- NumFiles: len(files),
- TotalSize: totalSize,
- Files: fileNames,
- UserId: userID,
- }
- savedShare, err := db.ShareCreate(database, share)
- if err != nil {
- responses.SendResponse(w, responses.StatusError, "Failed to create share", nil, http.StatusInternalServerError)
- utils.Log("Failed to create share", utils.ColorRed)
- return
- }
- payload := map[string]interface{}{
- "share": savedShare.ToShareResponse(),
- }
- responses.SendResponse(w, responses.StatusSuccess, "Share created", payload, http.StatusOK)
-}
-
func handleGetShare(database *sql.DB, w http.ResponseWriter, longId string) {
share := db.ShareByLongId(database, longId)
if share == nil {
@@ -186,18 +110,6 @@ func handleDownloadShare(database *sql.DB, w http.ResponseWriter, r *http.Reques
file_response.New(downloadFilePath, fileName).Send(w, r)
}
-func extractValidationErrors(err error) map[string]string {
- errorMap := make(map[string]string)
-
- if validationErrors, ok := err.(validator.ValidationErrors); ok {
- for _, e := range validationErrors {
- errorMap[strings.ToLower(e.Field())] = e.Tag() // e.Tag() gives the failed rule (e.g., required, email)
- }
- }
-
- return errorMap
-}
-
func IsShareExpired(share *models.Share) bool {
expirationDate, err := time.Parse(time.RFC3339, share.ExpirationDate)
if err != nil {
@@ -206,3 +118,110 @@ func IsShareExpired(share *models.Share) bool {
}
return time.Now().After(expirationDate)
}
+
+func handleCreateShare(database *sql.DB, w http.ResponseWriter, r *http.Request) {
+ utils.Log("Creating share", utils.ColorGreen)
+
+ userID := r.Context().Value(middleware.ContextKey("userID")).(int)
+ uploadID := r.URL.Query().Get("uploadId")
+ if uploadID == "" {
+ responses.SendResponse(w, responses.StatusError, "Upload ID required", nil, http.StatusBadRequest)
+ return
+ }
+
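+	// Stream the multipart body instead of buffering it via ParseMultipartForm,
+	// so progress can be reported as each part is copied to disk.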
+ reader, err := r.MultipartReader()
+ if err != nil {
+ responses.SendResponse(w, responses.StatusError, "Failed to create multipart reader", nil, http.StatusBadRequest)
+ return
+ }
+
+ // Initialize progress tracking
+ tracker := progress.GetTracker()
+ tracker.NewUpload(uploadID)
+ defer tracker.DeleteUpload(uploadID)
+
+ maxShareSize := config.GetMaxShareSize()
+ if maxShareSize == 0 {
+ responses.SendResponse(w, responses.StatusError, "Max share size not set", nil, http.StatusInternalServerError)
+ return
+ }
+
+	totalSize := r.ContentLength // Total request size (slightly above file bytes due to multipart framing)
+ utils.Log(fmt.Sprintf("Total size: %d", totalSize), utils.ColorGreen)
+ longId := generateUniqueLongId(database)
+ uploadDir := filepath.Join(config.AppConfig.BaseStoragePath, longId)
+ if err := os.MkdirAll(uploadDir, 0755); err != nil {
+ responses.SendResponse(w, responses.StatusError, "Failed to create upload directory", nil, http.StatusInternalServerError)
+ return
+ }
+
+ var fileNames []string
+ var actualTotalSize int64
+ var totalBytesRead int64
+
+ for {
+ part, err := reader.NextPart()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ responses.SendResponse(w, responses.StatusError, "Error reading upload", nil, http.StatusBadRequest)
+ return
+ }
+
+ if part.FileName() == "" {
+ continue
+ }
+
+		// filepath.Base guards against path traversal via a crafted filename
+		fileName := filepath.Base(part.FileName())
+		destPath := filepath.Join(uploadDir, fileName)
+		dest, err := os.Create(destPath)
+		if err != nil {
+			responses.SendResponse(w, responses.StatusError, "Failed to create file", nil, http.StatusInternalServerError)
+			return
+		}
+
+ // Create progress reader with both current file and total size
+		// Per-file progress isn't wired up yet, so pass a placeholder size of 1;
+		// only the aggregate totals are meaningful to the client for now.
+		pr := NewProgressReader(part, 1, totalSize, totalBytesRead, uploadID)
+
+		written, err := io.Copy(dest, pr)
+		// Close inside the loop rather than deferring, so handles are released
+		// per file and close errors are not silently dropped.
+		closeErr := dest.Close()
+		part.Close()
+		if err == nil {
+			err = closeErr
+		}
+		if err != nil {
+			responses.SendResponse(w, responses.StatusError, "Failed to save file", nil, http.StatusInternalServerError)
+			return
+		}
+
+		totalBytesRead += written
+		actualTotalSize += written
+		fileNames = append(fileNames, fileName)
+
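+		// Note: the limit is enforced after each part is written, so an oversized
+		// upload is removed only once it has exceeded the cap on disk.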
+ if actualTotalSize > maxShareSize {
+ os.RemoveAll(uploadDir)
+ responses.SendResponse(w, responses.StatusError, "Total size of files exceeds maximum allowed", nil, http.StatusBadRequest)
+ return
+ }
+ }
+
+ expirationDate := time.Now().AddDate(0, 0, 7).Format(time.RFC3339)
+ share := &models.Share{
+ FilePath: uploadDir,
+ ExpirationDate: expirationDate,
+ LongId: longId,
+ NumFiles: len(fileNames),
+ TotalSize: actualTotalSize,
+ Files: fileNames,
+ UserId: userID,
+ }
+
+ savedShare, err := db.ShareCreate(database, share)
+ if err != nil {
+ os.RemoveAll(uploadDir)
+ responses.SendResponse(w, responses.StatusError, "Failed to create share", nil, http.StatusInternalServerError)
+ return
+ }
+
+ payload := map[string]interface{}{
+ "share": savedShare.ToShareResponse(),
+ }
+ responses.SendResponse(w, responses.StatusSuccess, "Share created", payload, http.StatusOK)
+}
diff --git a/progress/progress.go b/progress/progress.go
new file mode 100644
index 0000000..49ed664
--- /dev/null
+++ b/progress/progress.go
@@ -0,0 +1,81 @@
+package progress
+
+import (
+ "sync"
+ "time"
+)
+
+type Progress struct {
+ BytesRead int64 `json:"bytesRead"` // Current file bytes read
+ TotalSize int64 `json:"totalSize"` // Current file total size
+	Percentage     float64   `json:"percentage"`     // Current file percentage (not yet populated by ProgressReader)
+ TotalBytesRead int64 `json:"totalBytesRead"` // Total bytes read across all files
+ TotalFileSize int64 `json:"totalFileSize"` // Total size of all files
+ TotalProgress float64 `json:"totalProgress"` // Overall progress percentage
+ UploadID string `json:"uploadId"`
+	LastUpdate     time.Time `json:"lastUpdate"`
+}
+
+type ProgressTracker struct {
+ mu sync.RWMutex
+ uploads map[string]chan Progress
+ cleanup map[string]chan struct{}
+}
+
+var tracker *ProgressTracker
+var once sync.Once
+
+func GetTracker() *ProgressTracker {
+ once.Do(func() {
+ tracker = &ProgressTracker{
+ uploads: make(map[string]chan Progress),
+ cleanup: make(map[string]chan struct{}),
+ }
+ })
+ return tracker
+}
+
+func (t *ProgressTracker) NewUpload(uploadID string) chan Progress {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
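+	// Buffer up to 100 updates so the uploader's non-blocking sends rarely drop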
+ progressChan := make(chan Progress, 100)
+ cleanupChan := make(chan struct{})
+ t.uploads[uploadID] = progressChan
+ t.cleanup[uploadID] = cleanupChan
+
+ // Automatic cleanup after 1 hour
+ go func() {
+ select {
+ case <-cleanupChan:
+ // Normal cleanup
+ case <-time.After(1 * time.Hour):
+ // Timeout cleanup
+ }
+ t.DeleteUpload(uploadID)
+ }()
+
+ return progressChan
+}
+
+func (t *ProgressTracker) GetUploadChannel(uploadID string) (chan Progress, bool) {
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+ ch, exists := t.uploads[uploadID]
+ return ch, exists
+}
+
+func (t *ProgressTracker) DeleteUpload(uploadID string) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ if cleanupChan, exists := t.cleanup[uploadID]; exists {
+ close(cleanupChan)
+ delete(t.cleanup, uploadID)
+ }
+
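+	// Caveat: if the uploader's ProgressReader is still running, a send on the
+	// channel closed below would panic; callers should only delete an upload
+	// once writing has finished (or add synchronization here).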
+ if progressChan, exists := t.uploads[uploadID]; exists {
+ close(progressChan)
+ delete(t.uploads, uploadID)
+ }
+}
diff --git a/routes/routes.go b/routes/routes.go
index 6d7623c..90dd136 100644
--- a/routes/routes.go
+++ b/routes/routes.go
@@ -52,6 +52,7 @@ func registerAuthRoutes(router *mux.Router, database *sql.DB) {
func registerShareRoutes(router *mux.Router, database *sql.DB) {
log.Println("registering share routes")
+
//POST /api/shares - create a new share
router.Handle("/api/shares",
middleware.JwtMiddleware(
@@ -59,6 +60,11 @@ func registerShareRoutes(router *mux.Router, database *sql.DB) {
),
).Methods("POST")
+ //GET /api/shares/progress/{uploadId} - get the progress of an upload
+ router.Handle("/api/shares/progress/{uploadId}",
+ handlers.UploadProgressHandler(),
+ ).Methods("GET")
+
//GET /api/shares/{longId} - get a share by longId
router.Handle("/api/shares/{longId}",
handlers.GetShareHandler(database),