From 81dcc8ea65283b7a0d1fda50234baa50074e43a1 Mon Sep 17 00:00:00 2001 From: akadan47 Date: Mon, 4 Nov 2024 18:29:51 +0600 Subject: [PATCH] Add content types for HLS playlist & segments. Add HLS path parsing for security (countering path traversal attacks). --- protocol/hls/src/server.rs | 204 +++++++++++++++++++++++++++---------- 1 file changed, 152 insertions(+), 52 deletions(-) diff --git a/protocol/hls/src/server.rs b/protocol/hls/src/server.rs index 3bf10af..0f11d8e 100644 --- a/protocol/hls/src/server.rs +++ b/protocol/hls/src/server.rs @@ -14,81 +14,135 @@ use { type GenericError = Box; type Result = std::result::Result; + static NOTFOUND: &[u8] = b"Not Found"; static UNAUTHORIZED: &[u8] = b"Unauthorized"; -async fn handle_connection(State(auth): State>, req: Request) -> Response { - let path = req.uri().path(); +#[derive(Debug)] +enum HlsFileType { + Playlist, + Segment, +} - let query_string: Option = req.uri().query().map(|s| s.to_string()); - let mut file_path: String = String::from(""); - - if path.ends_with(".m3u8") { - //http://127.0.0.1/app_name/stream_name/stream_name.m3u8 - let m3u8_index = path.find(".m3u8").unwrap(); - - if m3u8_index > 0 { - let (left, _) = path.split_at(m3u8_index); - let rv: Vec<_> = left.split('/').collect(); - - let app_name = String::from(rv[1]); - let stream_name = String::from(rv[2]); - - if let Some(auth_val) = auth { - if auth_val - .authenticate( - &stream_name, - &query_string.map(SecretCarrier::Query), - true, - ) - .is_err() - { - return Response::builder() - .status(StatusCode::UNAUTHORIZED) - .body(UNAUTHORIZED.into()) - .unwrap(); - } - } - - file_path = format!("./{app_name}/{stream_name}/{stream_name}.m3u8"); +impl HlsFileType { + const CONTENT_TYPE_PLAYLIST: &'static str = "application/vnd.apple.mpegurl"; + const CONTENT_TYPE_SEGMENT: &'static str = "video/mp2t"; + + fn content_type(&self) -> &str { + match self { + Self::Playlist => Self::CONTENT_TYPE_PLAYLIST, + Self::Segment => Self::CONTENT_TYPE_SEGMENT, } - } else if path.ends_with(".ts") { - //http://127.0.0.1/app_name/stream_name/ts_name.m3u8 - let ts_index = path.find(".ts").unwrap(); + } +} - if ts_index > 0 { - let (left, _) = path.split_at(ts_index); +#[derive(Debug)] +struct HlsPath { + app_name: String, + stream_name: String, + file_name: String, + file_type: HlsFileType, +} - let rv: Vec<_> = left.split('/').collect(); +impl HlsPath { + const M3U8_EXT: &'static str = "m3u8"; + const TS_EXT: &'static str = "ts"; - let app_name = String::from(rv[1]); - let stream_name = String::from(rv[2]); - let ts_name = String::from(rv[3]); + fn parse(path: &str) -> Option { + if path.is_empty() || path.contains("..") { + return None; + } + + let mut parts = path[1..].split('/'); + let app_name = parts.next()?; + let stream_name = parts.next()?; + let file_part = parts.next()?; + if parts.next().is_some() { + return None; + } - file_path = format!("./{app_name}/{stream_name}/{ts_name}.ts"); + let (file_name, ext) = file_part.rsplit_once('.')?; + if file_name.is_empty() { + return None; } + + let file_type = match ext { + Self::M3U8_EXT => HlsFileType::Playlist, + Self::TS_EXT => HlsFileType::Segment, + _ => return None, + }; + + Some(Self { + app_name: app_name.into(), + stream_name: stream_name.into(), + file_name: file_name.into(), + file_type, + }) + } + + fn to_file_path(&self) -> String { + let ext = match self.file_type { + HlsFileType::Playlist => Self::M3U8_EXT, + HlsFileType::Segment => Self::TS_EXT, + }; + format!( + "./{}/{}/{}.{}", + self.app_name, self.stream_name, self.file_name, ext + ) } - simple_file_send(file_path.as_str()).await } -/// HTTP status code 404 -fn not_found() -> Response { +fn response_unauthorized() -> Response { + Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(UNAUTHORIZED.into()) + .unwrap() +} + +fn response_not_found() -> Response { Response::builder() .status(StatusCode::NOT_FOUND) .body(NOTFOUND.into()) .unwrap() } -async fn simple_file_send(filename: &str) -> Response { - // Serve a file by asynchronously reading it by chunks using tokio-util crate. +async fn response_file(hls_path: &HlsPath) -> Response { + let file_path = hls_path.to_file_path(); + + if let Ok(file) = File::open(&file_path).await { + let builder = Response::builder().header("Content-Type", hls_path.file_type.content_type()); - if let Ok(file) = File::open(filename).await { + // Serve a file by asynchronously reading it by chunks using tokio-util crate. let stream = FramedRead::new(file, BytesCodec::new()); - let body = Body::from_stream(stream); - return Response::new(body); + return builder.body(Body::from_stream(stream)).unwrap(); + } + + response_not_found() +} + +async fn handle_connection(State(auth): State>, req: Request) -> Response { + let path = req.uri().path(); + let query_string = req.uri().query().map(|s| s.to_string()); + + let hls_path = match HlsPath::parse(path) { + Some(p) => p, + None => return response_not_found(), + }; + + if let (Some(auth_val), HlsFileType::Playlist) = (auth.as_ref(), &hls_path.file_type) { + if auth_val + .authenticate( + &hls_path.stream_name, + &query_string.map(SecretCarrier::Query), + true, + ) + .is_err() + { + return response_unauthorized(); + } } - not_found() + response_file(&hls_path).await } pub async fn run(port: usize, auth: Option) -> Result<()> { @@ -105,3 +159,49 @@ pub async fn run(port: usize, auth: Option) -> Result<()> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::{HlsFileType, HlsPath}; + + #[test] + fn test_hls_path_parse() { + // Playlist + let playlist = HlsPath::parse("/live/stream/stream.m3u8").unwrap(); + assert_eq!(playlist.app_name, "live"); + assert_eq!(playlist.stream_name, "stream"); + assert_eq!(playlist.file_name, "stream"); + assert!(matches!(playlist.file_type, HlsFileType::Playlist)); + assert_eq!(playlist.to_file_path(), "./live/stream/stream.m3u8"); + assert_eq!( + playlist.file_type.content_type(), + "application/vnd.apple.mpegurl" + ); + + // Segment + let segment = HlsPath::parse("/live/stream/123.ts").unwrap(); + assert_eq!(segment.app_name, "live"); + assert_eq!(segment.stream_name, "stream"); + assert_eq!(segment.file_name, "123"); + assert!(matches!(segment.file_type, HlsFileType::Segment)); + assert_eq!(segment.to_file_path(), "./live/stream/123.ts"); + assert_eq!(segment.file_type.content_type(), "video/mp2t"); + + // Negative + assert!(HlsPath::parse("").is_none()); + assert!(HlsPath::parse("/invalid").is_none()); + assert!(HlsPath::parse("/too/many/parts/of/path.m3u8").is_none()); + assert!(HlsPath::parse("/live/stream/invalid.mp4").is_none()); + assert!(HlsPath::parse("/live/stream/../../etc/passwd").is_none()); + assert!(HlsPath::parse("/live/stream/...").is_none()); + assert!(HlsPath::parse("/live/stream.m3u8").is_none()); + assert!(HlsPath::parse("/live/stream.ts").is_none()); + assert!(HlsPath::parse("/live/stream/").is_none()); + assert!(HlsPath::parse("/live/stream.m3u8").is_none()); + assert!(HlsPath::parse("/live/stream.ts").is_none()); + assert!(HlsPath::parse("/live/stream/file.").is_none()); + assert!(HlsPath::parse("/live/stream/.m3u8").is_none()); + assert!(HlsPath::parse("/live/stream/file.M3U8").is_none()); + assert!(HlsPath::parse("/live/stream/file.TS").is_none()); + } +}