Skip to content

Commit

Permalink
Handle MPD download errors and provide same status code on response
Browse files Browse the repository at this point in the history
  • Loading branch information
mhdzumair committed Nov 16, 2024
1 parent 035ae1a commit bcf22d2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
15 changes: 9 additions & 6 deletions mediaflow_proxy/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,15 @@ async def get_playlist(
Returns:
Response: The HTTP response with the HLS playlist.
"""
mpd_dict = await get_cached_mpd(
playlist_params.destination,
headers=proxy_headers.request,
parse_drm=not playlist_params.key_id and not playlist_params.key,
parse_segment_profile_id=playlist_params.profile_id,
)
try:
mpd_dict = await get_cached_mpd(
playlist_params.destination,
headers=proxy_headers.request,
parse_drm=not playlist_params.key_id and not playlist_params.key,
parse_segment_profile_id=playlist_params.profile_id,
)
except DownloadError as e:
raise HTTPException(status_code=e.status_code, detail=f"Failed to download MPD: {e.message}")
return await process_playlist(request, mpd_dict, playlist_params.profile_id, proxy_headers)


Expand Down
13 changes: 8 additions & 5 deletions mediaflow_proxy/utils/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pydantic import ValidationError

from mediaflow_proxy.speedtest.models import SpeedTestTask
from mediaflow_proxy.utils.http_utils import download_file_with_retry
from mediaflow_proxy.utils.http_utils import download_file_with_retry, DownloadError
from mediaflow_proxy.utils.mpd_utils import parse_mpd, parse_mpd_dict

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -344,7 +344,7 @@ async def get_cached_mpd(
headers: dict,
parse_drm: bool,
parse_segment_profile_id: str | None = None,
) -> Optional[dict]:
) -> dict:
"""Get MPD from cache or download and parse it."""
# Try cache first
cached_data = await MPD_CACHE.get(mpd_url)
Expand All @@ -364,9 +364,12 @@ async def get_cached_mpd(
# Cache the original MPD dict
await MPD_CACHE.set(mpd_url, json.dumps(mpd_dict).encode())
return parsed_dict
except Exception as e:
logger.error(f"Error processing MPD: {e}")
return None
except DownloadError as error:
logger.error(f"Error downloading MPD: {error}")
raise error
except Exception as error:
logger.exception(f"Error processing MPD: {e}")
raise error


async def get_cached_speedtest(task_id: str) -> Optional[SpeedTestTask]:
Expand Down

0 comments on commit bcf22d2

Please sign in to comment.