diff --git a/platform/dvr-local-disk.go b/platform/dvr-local-disk.go index 9c3a4791..cb1907e1 100644 --- a/platform/dvr-local-disk.go +++ b/platform/dvr-local-disk.go @@ -4,6 +4,7 @@ package main import ( + "bytes" "context" "encoding/json" "fmt" @@ -535,21 +536,16 @@ func (v *RecordWorker) Handle(ctx context.Context, handler *http.ServeMux) error return nil } -func (v *RecordWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) error { +func (v *RecordWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage, data []byte) error { // Copy the ts file to temporary cache dir. tsid := uuid.NewString() tsfile := path.Join("record", fmt.Sprintf("%v.ts", tsid)) - // Always use execFile when params contains user inputs, see https://auth0.com/blog/preventing-command-injection-attacks-in-node-js-apps/ - // Note that should never use fs.copyFileSync(file, tsfile, fs.constants.COPYFILE_FICLONE_FORCE) which fails in macOS. - if err := exec.CommandContext(ctx, "cp", "-f", msg.File, tsfile).Run(); err != nil { - return errors.Wrapf(err, "copy file %v to %v", msg.File, tsfile) - } - - // Get the file size. - stats, err := os.Stat(msg.File) - if err != nil { - return errors.Wrapf(err, "stat file %v", msg.File) + if file, err := os.Create(tsfile); err != nil { + return errors.Wrapf(err, "create file %v error", tsfile) + } else { + defer file.Close() + io.Copy(file, bytes.NewReader(data)) } // Create a local ts file object. @@ -558,7 +554,7 @@ func (v *RecordWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) URL: msg.URL, SeqNo: msg.SeqNo, Duration: msg.Duration, - Size: uint64(stats.Size()), + Size: uint64(len(data)), File: tsfile, } diff --git a/platform/dvr-tencent-cos.go b/platform/dvr-tencent-cos.go index 210960c6..d4adad4d 100644 --- a/platform/dvr-tencent-cos.go +++ b/platform/dvr-tencent-cos.go @@ -4,13 +4,14 @@ package main import ( + "bytes" "context" "encoding/json" "fmt" + "io" "net/http" "net/url" "os" - "os/exec" "path" "strings" "sync" @@ -231,7 +232,7 @@ func (v *DvrWorker) Handle(ctx context.Context, handler *http.ServeMux) error { return nil } -func (v *DvrWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) error { +func (v *DvrWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage, data []byte) error { // Ignore for Tencent Cloud credentials not ready. if !v.ready() { return nil @@ -241,16 +242,11 @@ func (v *DvrWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) er tsid := uuid.NewString() tsfile := path.Join("dvr", fmt.Sprintf("%v.ts", tsid)) - // Always use execFile when params contains user inputs, see https://auth0.com/blog/preventing-command-injection-attacks-in-node-js-apps/ - // Note that should never use fs.copyFileSync(file, tsfile, fs.constants.COPYFILE_FICLONE_FORCE) which fails in macOS. - if err := exec.CommandContext(ctx, "cp", "-f", msg.File, tsfile).Run(); err != nil { - return errors.Wrapf(err, "copy file %v to %v", msg.File, tsfile) - } - - // Get the file size. - stats, err := os.Stat(msg.File) - if err != nil { - return errors.Wrapf(err, "stat file %v", msg.File) + if file, err := os.Create(tsfile); err != nil { + return errors.Wrapf(err, "create file %v error", tsfile) + } else { + defer file.Close() + io.Copy(file, bytes.NewReader(data)) } // Create a local ts file object. @@ -259,7 +255,7 @@ func (v *DvrWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) er URL: msg.URL, SeqNo: msg.SeqNo, Duration: msg.Duration, - Size: uint64(stats.Size()), + Size: uint64(len(data)), File: tsfile, } diff --git a/platform/dvr-tencent-vod.go b/platform/dvr-tencent-vod.go index 6c261a1f..f26d6567 100644 --- a/platform/dvr-tencent-vod.go +++ b/platform/dvr-tencent-vod.go @@ -4,13 +4,14 @@ package main import ( + "bytes" "context" "encoding/json" "fmt" + "io" "net/http" "net/url" "os" - "os/exec" "path" "strconv" "strings" @@ -321,7 +322,7 @@ func (v *VodWorker) Handle(ctx context.Context, handler *http.ServeMux) error { return nil } -func (v *VodWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) error { +func (v *VodWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage, data []byte) error { // Ignore for Tencent Cloud credentials not ready. if !v.ready() { return nil @@ -331,16 +332,11 @@ func (v *VodWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) er tsid := uuid.NewString() tsfile := path.Join("vod", fmt.Sprintf("%v.ts", tsid)) - // Always use execFile when params contains user inputs, see https://auth0.com/blog/preventing-command-injection-attacks-in-node-js-apps/ - // Note that should never use fs.copyFileSync(file, tsfile, fs.constants.COPYFILE_FICLONE_FORCE) which fails in macOS. - if err := exec.CommandContext(ctx, "cp", "-f", msg.File, tsfile).Run(); err != nil { - return errors.Wrapf(err, "copy file %v to %v", msg.File, tsfile) - } - - // Get the file size. - stats, err := os.Stat(msg.File) - if err != nil { - return errors.Wrapf(err, "stat file %v", msg.File) + if file, err := os.Create(tsfile); err != nil { + return errors.Wrapf(err, "create file %v error", tsfile) + } else { + defer file.Close() + io.Copy(file, bytes.NewReader(data)) } // Create a local ts file object. @@ -349,7 +345,7 @@ func (v *VodWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) er URL: msg.URL, SeqNo: msg.SeqNo, Duration: msg.Duration, - Size: uint64(stats.Size()), + Size: uint64(len(data)), File: tsfile, } diff --git a/platform/ocr.go b/platform/ocr.go index 5c739198..1c61f535 100644 --- a/platform/ocr.go +++ b/platform/ocr.go @@ -4,6 +4,7 @@ package main import ( + "bytes" "context" "encoding/base64" "encoding/json" @@ -21,6 +22,7 @@ import ( "github.com/ossrs/go-oryx-lib/errors" ohttp "github.com/ossrs/go-oryx-lib/http" "github.com/ossrs/go-oryx-lib/logger" + // Use v8 because we use Go 1.16+, while v9 requires Go 1.18+ "github.com/go-redis/redis/v8" "github.com/google/uuid" @@ -39,17 +41,12 @@ type OCRWorker struct { // The global OCR task, only support one OCR task. task *OCRTask - // Use async goroutine to process on_hls messages. - msgs chan *SrsOnHlsMessage - // Got message from SRS, a new TS segment file is generated. tsfiles chan *SrsOnHlsObject } func NewOCRWorker() *OCRWorker { v := &OCRWorker{ - // Message on_hls. - msgs: make(chan *SrsOnHlsMessage, 1024), // TS files. tsfiles: make(chan *SrsOnHlsObject, 1024), } @@ -547,16 +544,7 @@ func (v *OCRWorker) Enabled() bool { return v.task.enabled() } -func (v *OCRWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) error { - select { - case <-ctx.Done(): - case v.msgs <- msg: - } - - return nil -} - -func (v *OCRWorker) OnHlsTsMessageImpl(ctx context.Context, msg *SrsOnHlsMessage) error { +func (v *OCRWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage, data []byte) error { // Ignore if not natch the task config. if !v.task.match(msg) { return nil @@ -566,16 +554,11 @@ func (v *OCRWorker) OnHlsTsMessageImpl(ctx context.Context, msg *SrsOnHlsMessage tsid := fmt.Sprintf("%v-org-%v", msg.SeqNo, uuid.NewString()) tsfile := path.Join("ocr", fmt.Sprintf("%v.ts", tsid)) - // Always use execFile when params contains user inputs, see https://auth0.com/blog/preventing-command-injection-attacks-in-node-js-apps/ - // Note that should never use fs.copyFileSync(file, tsfile, fs.constants.COPYFILE_FICLONE_FORCE) which fails in macOS. - if err := exec.CommandContext(ctx, "cp", "-f", msg.File, tsfile).Run(); err != nil { - return errors.Wrapf(err, "copy file %v to %v", msg.File, tsfile) - } - - // Get the file size. - stats, err := os.Stat(msg.File) - if err != nil { - return errors.Wrapf(err, "stat file %v", msg.File) + if file, err := os.Create(tsfile); err != nil { + return errors.Wrapf(err, "create file %v error", tsfile) + } else { + defer file.Close() + io.Copy(file, bytes.NewReader(data)) } // Create a local ts file object. @@ -584,7 +567,7 @@ func (v *OCRWorker) OnHlsTsMessageImpl(ctx context.Context, msg *SrsOnHlsMessage URL: msg.URL, SeqNo: msg.SeqNo, Duration: msg.Duration, - Size: uint64(stats.Size()), + Size: uint64(len(data)), File: tsfile, } @@ -659,22 +642,6 @@ func (v *OCRWorker) Start(ctx context.Context) error { } }() - // Consume all on_hls messages. - wg.Add(1) - go func() { - defer wg.Done() - - for ctx.Err() == nil { - select { - case <-ctx.Done(): - case msg := <-v.msgs: - if err := v.OnHlsTsMessageImpl(ctx, msg); err != nil { - logger.Wf(ctx, "ocr: handle on hls message %v err %+v", msg.String(), err) - } - } - } - }() - // Consume all ts files by task. wg.Add(1) go func() { diff --git a/platform/srs-hooks.go b/platform/srs-hooks.go index 7001b648..700c38d0 100644 --- a/platform/srs-hooks.go +++ b/platform/srs-hooks.go @@ -730,38 +730,50 @@ func handleOnHls(ctx context.Context, handler *http.ServeMux) error { return errors.Errorf("invalid action=%v", msg.Action) } - if _, err := os.Stat(msg.File); err != nil { - logger.Tf(ctx, "invalid ts file %v", msg.File) + allowedDir, err := os.Getwd() + if err != nil { + return errors.Wrapf(err, "can not get current working directory") + } - if err := os.MkdirAll(filepath.Dir(msg.File), 0755); err != nil { - return errors.Wrapf(err, "failed to create ts file directory %v", filepath.Dir(msg.File)) - } + safePath := filepath.Join(allowedDir, filepath.Clean(msg.File)) + logger.Tf(ctx, "safePath is %v", safePath) + absPath, err := filepath.Abs(safePath) + if err != nil { + return errors.Wrapf(err, "can not get absolute path from %v", safePath) + } - if tsFile, err := os.Create(msg.File); err != nil { - return errors.Wrapf(err, "failed to create ts file %v", msg.File) - } else { - tsUrl := "http://" + os.Getenv("SRS_HOST") + ":" + os.Getenv("SRS_HTTP_STREAM_PORT") + "/" + msg.URL - logger.Tf(ctx, "download ts from %v", tsUrl) - client := http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - r.URL.Opaque = r.URL.Path - return nil - }, - } + if !filepath.HasPrefix(absPath, allowedDir) { + return errors.Errorf("Access denied, %v is outside allowed directory", absPath) + } - resp, err := client.Get(tsUrl) - if err != nil { - return errors.Wrapf(err, "http error to get url %v", tsUrl) - } - defer resp.Body.Close() + var data []byte + if _, err := os.Stat(safePath); err != nil { + logger.Tf(ctx, "invalid ts file %v", safePath) + tsUrl := "http://" + os.Getenv("SRS_HOST") + ":" + os.Getenv("SRS_HTTP_STREAM_PORT") + "/" + msg.URL + logger.Tf(ctx, "download ts from %v", tsUrl) + client := http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + r.URL.Opaque = r.URL.Path + return nil + }, + } - size, err := io.Copy(tsFile, resp.Body) - if err != nil { - return errors.Wrapf(err, "copy http resp to file %v", tsFile) - } - defer tsFile.Close() - logger.Tf(ctx, "Download ts file %s with size %d", tsUrl, size) + res, err := client.Get(tsUrl) + if err != nil { + return errors.Wrapf(err, "http error to get url %v", tsUrl) + } + defer res.Body.Close() + + if b, err := io.ReadAll(res.Body); err != nil { + return errors.Wrapf(err, "read http response error") + } else { + data = b } + logger.Tf(ctx, "Download ts file %s with size %d", tsUrl, len(data)) + } else if b, err := os.ReadFile(safePath); err != nil { + return errors.Wrapf(err, "read %v error", safePath) + } else { + data = b } logger.Tf(ctx, "on_hls ok, %v", string(b)) @@ -769,7 +781,7 @@ func handleOnHls(ctx context.Context, handler *http.ServeMux) error { if recordAll, err := rdb.HGet(ctx, SRS_RECORD_PATTERNS, "all").Result(); err != nil && err != redis.Nil { return errors.Wrapf(err, "hget %v all", SRS_RECORD_PATTERNS) } else if recordAll == "true" { - if err = recordWorker.OnHlsTsMessage(ctx, &msg); err != nil { + if err = recordWorker.OnHlsTsMessage(ctx, &msg, data); err != nil { return errors.Wrapf(err, "feed %v", msg.String()) } logger.Tf(ctx, "record %v", msg.String()) @@ -779,7 +791,7 @@ func handleOnHls(ctx context.Context, handler *http.ServeMux) error { if dvrAll, err := rdb.HGet(ctx, SRS_DVR_PATTERNS, "all").Result(); err != nil && err != redis.Nil { return errors.Wrapf(err, "hget %v all", SRS_DVR_PATTERNS) } else if dvrAll == "true" { - if err = dvrWorker.OnHlsTsMessage(ctx, &msg); err != nil { + if err = dvrWorker.OnHlsTsMessage(ctx, &msg, data); err != nil { return errors.Wrapf(err, "feed %v", msg.String()) } logger.Tf(ctx, "dvr %v", msg.String()) @@ -789,7 +801,7 @@ func handleOnHls(ctx context.Context, handler *http.ServeMux) error { if vodAll, err := rdb.HGet(ctx, SRS_VOD_PATTERNS, "all").Result(); err != nil && err != redis.Nil { return errors.Wrapf(err, "hget %v all", SRS_VOD_PATTERNS) } else if vodAll == "true" { - if err = vodWorker.OnHlsTsMessage(ctx, &msg); err != nil { + if err = vodWorker.OnHlsTsMessage(ctx, &msg, data); err != nil { return errors.Wrapf(err, "feed %v", msg.String()) } logger.Tf(ctx, "vod %v", msg.String()) @@ -797,7 +809,7 @@ func handleOnHls(ctx context.Context, handler *http.ServeMux) error { // Handle TS file by Transcript task if enabled. if transcriptWorker.Enabled() { - if err = transcriptWorker.OnHlsTsMessage(ctx, &msg); err != nil { + if err = transcriptWorker.OnHlsTsMessage(ctx, &msg, data); err != nil { return errors.Wrapf(err, "feed %v", msg.String()) } logger.Tf(ctx, "transcript %v", msg.String()) @@ -805,7 +817,7 @@ func handleOnHls(ctx context.Context, handler *http.ServeMux) error { // Handle TS file by OCR task if enabled. if ocrWorker.Enabled() { - if err = ocrWorker.OnHlsTsMessage(ctx, &msg); err != nil { + if err = ocrWorker.OnHlsTsMessage(ctx, &msg, data); err != nil { return errors.Wrapf(err, "feed %v", msg.String()) } logger.Tf(ctx, "ocr %v", msg.String()) diff --git a/platform/transcript.go b/platform/transcript.go index 51340915..17b5b2a4 100644 --- a/platform/transcript.go +++ b/platform/transcript.go @@ -4,6 +4,7 @@ package main import ( + "bytes" "context" "encoding/json" "fmt" @@ -21,6 +22,7 @@ import ( "github.com/ossrs/go-oryx-lib/errors" ohttp "github.com/ossrs/go-oryx-lib/http" "github.com/ossrs/go-oryx-lib/logger" + // Use v8 because we use Go 1.16+, while v9 requires Go 1.18+ "github.com/go-redis/redis/v8" "github.com/google/uuid" @@ -38,18 +40,12 @@ type TranscriptWorker struct { // The global transcript task, only support one transcript task. task *TranscriptTask - - // Use async goroutine to process on_hls messages. - msgs chan *SrsOnHlsMessage - // Got message from SRS, a new TS segment file is generated. tsfiles chan *SrsOnHlsObject } func NewTranscriptWorker() *TranscriptWorker { v := &TranscriptWorker{ - // Message on_hls. - msgs: make(chan *SrsOnHlsMessage, 1024), // TS files. tsfiles: make(chan *SrsOnHlsObject, 1024), } @@ -942,16 +938,7 @@ func (v *TranscriptWorker) Enabled() bool { return v.task.enabled() } -func (v *TranscriptWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage) error { - select { - case <-ctx.Done(): - case v.msgs <- msg: - } - - return nil -} - -func (v *TranscriptWorker) OnHlsTsMessageImpl(ctx context.Context, msg *SrsOnHlsMessage) error { +func (v *TranscriptWorker) OnHlsTsMessage(ctx context.Context, msg *SrsOnHlsMessage, data []byte) error { // Ignore if not natch the task config. if !v.task.match(msg) { return nil @@ -961,16 +948,11 @@ func (v *TranscriptWorker) OnHlsTsMessageImpl(ctx context.Context, msg *SrsOnHls tsid := fmt.Sprintf("%v-org-%v", msg.SeqNo, uuid.NewString()) tsfile := path.Join("transcript", fmt.Sprintf("%v.ts", tsid)) - // Always use execFile when params contains user inputs, see https://auth0.com/blog/preventing-command-injection-attacks-in-node-js-apps/ - // Note that should never use fs.copyFileSync(file, tsfile, fs.constants.COPYFILE_FICLONE_FORCE) which fails in macOS. - if err := exec.CommandContext(ctx, "cp", "-f", msg.File, tsfile).Run(); err != nil { - return errors.Wrapf(err, "copy file %v to %v", msg.File, tsfile) - } - - // Get the file size. - stats, err := os.Stat(msg.File) - if err != nil { - return errors.Wrapf(err, "stat file %v", msg.File) + if file, err := os.Create(tsfile); err != nil { + return errors.Wrapf(err, "create file %v error", tsfile) + } else { + defer file.Close() + io.Copy(file, bytes.NewReader(data)) } // Create a local ts file object. @@ -979,7 +961,7 @@ func (v *TranscriptWorker) OnHlsTsMessageImpl(ctx context.Context, msg *SrsOnHls URL: msg.URL, SeqNo: msg.SeqNo, Duration: msg.Duration, - Size: uint64(stats.Size()), + Size: uint64(len(data)), File: tsfile, } @@ -1054,22 +1036,6 @@ func (v *TranscriptWorker) Start(ctx context.Context) error { } }() - // Consume all on_hls messages. - wg.Add(1) - go func() { - defer wg.Done() - - for ctx.Err() == nil { - select { - case <-ctx.Done(): - case msg := <-v.msgs: - if err := v.OnHlsTsMessageImpl(ctx, msg); err != nil { - logger.Wf(ctx, "transcript: handle on hls message %v err %+v", msg.String(), err) - } - } - } - }() - // Consume all ts files by task. wg.Add(1) go func() {