diff --git a/tests/test_subsamplers.py b/tests/test_subsamplers.py index 7a5e706b..be6045ac 100644 --- a/tests/test_subsamplers.py +++ b/tests/test_subsamplers.py @@ -11,6 +11,7 @@ ClippingSubsampler, _get_seconds, _split_time_frame, + Streams, FFProbeSubsampler, ResolutionSubsampler, FrameSubsampler, @@ -45,8 +46,8 @@ def test_clipping_subsampler(clips): min_length = 5.0 if clips == MULTI else 2.0 max_length = 999999.0 if clips == MULTI else 3.0 subsampler = ClippingSubsampler( - 3, - {"video": "mp4", "audio": "mp3"}, + oom_clip_count=3, + encode_formats={"video": "mp4", "audio": "mp3"}, min_length=min_length, max_length=max_length, max_length_strategy="all", @@ -58,7 +59,7 @@ def test_clipping_subsampler(clips): "clips": clips, } - streams = {"video": [video_bytes], "audio": [audio_bytes]} + streams: Streams = {"video": [video_bytes], "audio": [audio_bytes]} stream_fragments, meta_fragments, error_message = subsampler(streams, metadata) video_fragments = stream_fragments["video"] audio_fragments = stream_fragments["audio"] @@ -84,7 +85,7 @@ def test_clipping_subsampler(clips): s_target, e_target = clips[key_ind] s_target, e_target = _get_seconds(s_target), _get_seconds(e_target) expected_clips = _split_time_frame(s_target, e_target, min_length, max_length) - assert (_get_seconds(s), _get_seconds(e)) in expected_clips + assert [_get_seconds(s), _get_seconds(e)] in expected_clips assert _get_seconds(e) - _get_seconds(s) >= min_length s_s, e_s = _get_seconds(s), _get_seconds(e) @@ -92,7 +93,6 @@ def test_clipping_subsampler(clips): video_stream = [stream for stream in probe["streams"] if stream["codec_type"] == "video"][0] frag_len = float(video_stream["duration"]) - # currently some segments can be pretty innacurate assert abs(frag_len - (e_s - s_s)) < 5.0 diff --git a/video2dataset/subsamplers/__init__.py b/video2dataset/subsamplers/__init__.py index 5d4741f8..90e4cd58 100644 --- a/video2dataset/subsamplers/__init__.py +++ b/video2dataset/subsamplers/__init__.py 
@@ -3,7 +3,7 @@ """ from .audio_rate_subsampler import AudioRateSubsampler -from .clipping_subsampler import ClippingSubsampler, _get_seconds, _split_time_frame +from .clipping_subsampler import ClippingSubsampler, _get_seconds, _split_time_frame, Streams from .frame_subsampler import FrameSubsampler from .ffprobe_subsampler import FFProbeSubsampler from .noop_subsampler import NoOpSubsampler diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py index 2d9e4592..73eae18f 100644 --- a/video2dataset/subsamplers/clipping_subsampler.py +++ b/video2dataset/subsamplers/clipping_subsampler.py @@ -1,18 +1,33 @@ """ clipping subsampler turns full videos into clips of videos according to clip_col """ -import os +from collections.abc import Iterable +from typing import Any, Union, List, Tuple, Dict, TypedDict, Literal, cast import copy -import glob import ffmpeg +import glob +import os import tempfile -from collections.abc import Iterable import datetime from .subsampler import Subsampler -def _get_seconds(t): +ClipSpan = List[float] # [start, end] + + +class EncodeFormats(TypedDict): + video: str + audio: str + + +class Streams(TypedDict): + video: List[bytes] + audio: List[bytes] + + +def _get_seconds(t: Union[str, float]) -> float: + """Converts time to seconds""" if not isinstance(t, str): return float(t) # already seconds time_format = "%H:%M:%S.%f" # TODO: maybe parameterize this? 
@@ -20,7 +35,8 @@ def _get_seconds(t): return t_obj.second + t_obj.microsecond / 1e6 + t_obj.minute * 60 + t_obj.hour * 3600 -def _get_strtime(t_sec): +def _get_strtime(t_sec: float) -> str: + """Converts time to string""" hour = int(t_sec // 3600) minute = int((t_sec // 60) % 60) second = int(t_sec % 60) @@ -29,36 +45,102 @@ def _get_strtime(t_sec): return f"{hour:02d}:{minute:02d}:{second:02d}.{microsecond:03d}" -def _split_time_frame(s, e, min_length, max_length): +def _split_time_frame(s: float, e: float, min_length: float, max_length: float) -> List[ClipSpan]: """Filters out cuts by min and max length""" time_d = e - s - time_frames = [ - (s + i * max_length, min(s + (i + 1) * max_length, e)) - for i in range(int(time_d // max_length) + (1 if time_d % max_length > 0 else 0)) - ] - if len(time_frames) == 0: - return [] - last_time_d = time_frames[-1][1] - time_frames[-1][0] - time_frames = time_frames if last_time_d >= min_length else time_frames[:-1] - return time_frames - - -def _adjust_ranges_to_keyframes(ranges, keyframes): - """Translates ranges into keyframe vocab""" - adjusted_ranges = [] - for start, end in ranges: + n_full_clips = int(time_d // max_length) + clip_spans = [[s + i * max_length, s + (i + 1) * max_length] for i in range(n_full_clips)] + ( + [[s + (n_full_clips) * max_length, e]] if time_d % max_length > min_length else [] + ) + return clip_spans + + +def _adjust_clip_spans_to_keyframes(clip_spans: List[ClipSpan], keyframes: List[float]) -> List[ClipSpan]: + """Translates clip_spans into keyframe vocab""" + adjusted_clip_spans = [] + for start, end in clip_spans: keyframes_in_range = [k for k in keyframes if start <= k <= end] if keyframes_in_range: adjusted_start = min(keyframes_in_range) adjusted_end = max(keyframes_in_range) if adjusted_start != adjusted_end: - adjusted_ranges.append((adjusted_start, adjusted_end)) - return adjusted_ranges - - -def _extract_subtitles(clip_span, meta_clip): + adjusted_clip_spans.append([adjusted_start, 
adjusted_end]) + return adjusted_clip_spans + + +def _adjust_clip_spans( + clip_spans: List[ClipSpan], + keyframe_timestamps: Union[List[float], None], + min_length: float, + max_length: float, + max_length_strategy: str, +) -> List[ClipSpan]: + """Adjusts cut times around keyframes, filtering by min and max length""" + if not isinstance(clip_spans[0], Iterable): # make sure clip_spans looks like [[start, end]] and not [start, end] + clip_spans = cast(List[ClipSpan], [clip_spans]) + clip_spans = [[_get_seconds(s), _get_seconds(e)] for [s, e] in clip_spans] + + if keyframe_timestamps: + clip_spans = _adjust_clip_spans_to_keyframes(clip_spans, keyframe_timestamps) + + filtered_clip_spans = [] + for s, e in clip_spans: + max_len_clip_spans = _split_time_frame(s, e, min_length, max_length) + if max_length_strategy == "first": + max_len_clip_spans = max_len_clip_spans[:1] + filtered_clip_spans += max_len_clip_spans + return filtered_clip_spans + + +def _collate_clip_spans(clip_spans: List[ClipSpan]) -> Tuple[str, List[int]]: + """Collates clip spans into a single string for ffmpeg and a list of clip idxs""" + clip_times = [] + clip_idxs = [] + e_prev = 0.0 + clip_idx = 0 + + for s, e in clip_spans: + if s == e_prev: # clip starts where last one left off + clip_times += [e] + clip_idxs.append(clip_idx) + clip_idx += 1 + else: # next clip skips over some time + clip_times += [s, e] + clip_idxs.append(clip_idx + 1) + clip_idx += 2 + e_prev = e + + clip_times_str = ",".join([str(time) for time in clip_times]) + return clip_times_str, clip_idxs + + +def _process_stream( + tmpdir: Any, # BytesPath + stream_bytes: bytes, + encode_format: str, + ffmpeg_kwargs: dict, +) -> List[str]: + """Processes a stream into clips using ffmpeg""" + # TODO: we need to put the extension into the metadata + # TODO: This can be done better using pipes I just don't feel like sinking too much time into this rn + with open(os.path.join(tmpdir, f"input.{encode_format}"), "wb") as f: + 
f.write(stream_bytes) + try: + ( + ffmpeg.input(f"{tmpdir}/input.{encode_format}") + .output(f"{tmpdir}/clip_%d.{encode_format}", **ffmpeg_kwargs) + .run(capture_stdout=True, quiet=True) + ) + except Exception as err: # pylint: disable=broad-except + raise err + stream_clips = glob.glob(f"{tmpdir}/clip*.{encode_format}") + stream_clips.sort(key=lambda x: int(x.split("_")[-1].split(".")[0])) + return stream_clips + + +def _extract_subtitles(clip_span: ClipSpan, meta_clip: dict) -> List[dict]: """Extracts subtitles and groups them by language""" - clip_subtitles = [] + clip_subtitles: List[dict] = [] s_c, e_c = _get_seconds(clip_span[0]), _get_seconds(clip_span[1]) for lang_id, (lang, subtitles) in enumerate(meta_clip["yt_meta_dict"]["subtitles"].items()): idx = 0 @@ -75,10 +157,100 @@ def _extract_subtitles(clip_span, meta_clip): clip_subtitles.append(temp_line) elif s > e_c: break - return clip_subtitles +def _get_clip_metadata( + clip_spans: List[ClipSpan], + clip_idxs: List[int], + metadata: dict, + oom_clip_count: int, + strtime_formatting: bool, +) -> List[dict]: + """Gets metadata for each clip""" + metadata_clips = [] + for clip_id, (clip_span, _) in enumerate(zip(clip_spans, clip_idxs)): + clip_key = "{clip_id:0{oom_clip_count}d}".format( # pylint: disable=consider-using-f-string + clip_id=clip_id, oom_clip_count=oom_clip_count + ) + meta_clip = copy.deepcopy(metadata) + # set the timeframe of this clip + if strtime_formatting: + # Keep clip_spans in the original format to be compatible with the data schema. 
+ meta_clip["clips"] = [(_get_strtime(clip_span[0]), _get_strtime(clip_span[1]))] + else: + meta_clip["clips"] = [clip_span] + meta_clip["key"] = f"{meta_clip['key']}_{clip_key}" + + yt_md_dict = meta_clip.get("yt_meta_dict", {}) + if (yt_md_dict is not None) and (yt_md_dict.get("subtitles", None) is not None): + meta_clip["clip_subtitles"] = _extract_subtitles(clip_span, meta_clip) + metadata_clips.append(meta_clip) + + # remove redundant metadata from clips after the first + for m_clips in metadata_clips[1:]: + m_clips["yt_meta_dict"] = {} + + return metadata_clips + + +def _get_clips( + streams: Streams, + encode_formats: EncodeFormats, + precision: str, + clip_spans: List[ClipSpan], + metadata: dict, + oom_clip_count: int, + strtime_formatting: bool, +) -> Tuple[Dict[str, List[bytes]], List[dict]]: + """Gets clips from streams""" + clip_times, clip_idxs = _collate_clip_spans(clip_spans) + + ffmpeg_kwargs = { + "map": 0, + "f": "segment", + "segment_times": clip_times, + "reset_timestamps": 1, + } + if precision == "exact": + ffmpeg_kwargs["force_key_frames"] = clip_times + else: + ffmpeg_kwargs["c"] = "copy" + + clips: Dict[str, List[bytes]] = {} + for k in streams.keys(): + k = cast(Literal["audio", "video"], k) + with tempfile.TemporaryDirectory() as tmpdir: + stream_bytes = streams[k][0] # pre-broadcast so only one + if stream_bytes is None: + continue + try: + stream_clips = _process_stream( + tmpdir=tmpdir, + stream_bytes=stream_bytes, + encode_format=encode_formats[k], + ffmpeg_kwargs=ffmpeg_kwargs, + ) + except Exception as err: # pylint: disable=broad-except + raise err + + clips[k] = [] + for clip_idx in clip_idxs: + with open(stream_clips[clip_idx], "rb") as vid_f: + clip_bytes = vid_f.read() + clips[k].append(clip_bytes) + + clip_metadata = _get_clip_metadata( + clip_spans=clip_spans, + clip_idxs=clip_idxs, + metadata=metadata, + oom_clip_count=oom_clip_count, + strtime_formatting=strtime_formatting, + ) + + return clips, clip_metadata + + class 
ClippingSubsampler(Subsampler): """ Cuts videos up into segments according to the 'clips' metadata @@ -108,146 +280,51 @@ class ClippingSubsampler(Subsampler): def __init__( self, - oom_clip_count, - encode_formats, - min_length=0.0, - max_length=999999.0, - max_length_strategy="all", - precision="low", + oom_clip_count: int, + encode_formats: EncodeFormats, + min_length: float = 0.0, + max_length: float = 999999.0, + max_length_strategy: Literal["all", "first"] = "all", + precision: Literal["low", "keyframe_adjusted", "exact"] = "low", ): + assert max_length_strategy in ["all", "first"] + assert precision in ["exact", "low", "keyframe_adjusted"] self.oom_clip_count = oom_clip_count self.encode_formats = encode_formats self.min_length = min_length - self.max_length, self.max_length_strategy = max_length, max_length_strategy - assert precision in ["exact", "low", "keyframe_adjusted"] + self.max_length = max_length + self.max_length_strategy = max_length_strategy self.precision = precision - def __call__(self, streams, metadata): - clips = metadata.pop("clips") - - if not isinstance(clips[0], Iterable): # make sure clips looks like [[start, end]] and not [start, end] - clips = [clips] - - is_strtime = isinstance(clips[0][0], str) - - if self.precision == "keyframe_adjusted": - # TODO: make it so if not present, get it yourself - keyframe_timestamps = metadata["video_metadata"].pop("keyframe_timestamps") - s_clips = [[_get_seconds(s), _get_seconds(e)] for (s, e) in clips] - clips = _adjust_ranges_to_keyframes(s_clips, keyframe_timestamps) - - filtered_clips = [] - for s, e in clips: - max_len_clips = _split_time_frame(_get_seconds(s), _get_seconds(e), self.min_length, self.max_length) - - if self.max_length_strategy == "first": - max_len_clips = max_len_clips[:1] - - filtered_clips += max_len_clips - clips = filtered_clips - - if len(clips) == 0: - # return an error - return {}, [], f"Video had no clips longer than {self.min_length}" - - start_0 = 
_get_seconds(clips[0][0]) == 0.0 - - ind = 1 + int(not start_0) - s_p, e_p = clips[0] - s_p, e_p = _get_seconds(s_p), _get_seconds(e_p) - splits = (not start_0) * [s_p] + [e_p] - # list of indicies of clips to take, used to discard non-contiguous sections - take_inds = [int(not start_0)] - - # TODO: make nicer - for s, e in clips[1:]: - s, e = _get_seconds(s), _get_seconds(e) - - if s == e_p: # situations like [0, 1], [1, 2], [2, 3] -> 1, 2 - splits += [e] - take_inds.append(ind) - ind += 1 - else: - splits += [s, e] - take_inds.append(ind + 1) - ind += 2 - e_p = e - - segment_times = ",".join([str(spl) for spl in splits]) - streams_clips = {} - - for k in streams.keys(): - stream_bytes = streams[k][0] # pre-broadcast so only one - if stream_bytes is None: - continue - encode_format = self.encode_formats[k] - - with tempfile.TemporaryDirectory() as tmpdir: - # TODO: we need to put the extension into the metadata - # TODO: This can be done better using pipes I just don't feel like sinking too much time into this rn - with open(os.path.join(tmpdir, f"input.{encode_format}"), "wb") as f: - f.write(stream_bytes) - try: - kwargs = { - "map": 0, - "f": "segment", - "segment_times": segment_times, - "reset_timestamps": 1, - } - - # Precision things, tradeoff for speed - if self.precision != "exact": - kwargs["c"] = "copy" - else: - kwargs["force_key_frames"] = segment_times - - _ = ( - ffmpeg.input(f"{tmpdir}/input.{encode_format}") - .output(f"{tmpdir}/clip_%d.{encode_format}", **kwargs) - .run(capture_stdout=True, quiet=True) - ) - - except Exception as err: # pylint: disable=broad-except - return {}, [], str(err) - - stream_clips = glob.glob(f"{tmpdir}/clip*.{encode_format}") - stream_clips.sort(key=lambda x: int(x.split("_")[-1].split(".")[0])) - - correct_clips = [] - for clip_id, (clip, ind) in enumerate(zip(clips, take_inds)): - if ind < len(stream_clips): - correct_clips.append((clip_id, clip, stream_clips[ind])) - # clips_lost = len(take_inds) - 
len(correct_clips)  # TODO report this somehow
-
-                stream_clips, metadata_clips = [], []
-                for clip_id, clip_span, clip_pth in correct_clips:
-                    with open(clip_pth, "rb") as vid_f:
-                        clip_bytes = vid_f.read()
-                    stream_clips.append(clip_bytes)
-
-                    clip_key = "{clip_id:0{oom_clip_count}d}".format(  # pylint: disable=consider-using-f-string
-                        clip_id=clip_id, oom_clip_count=self.oom_clip_count
-                    )
-                    meta_clip = copy.deepcopy(metadata)
-                    # set the timeframe of this clip
-                    if is_strtime:
-                        # Keep clips in the original format to be compatible with the data schema.
-                        meta_clip["clips"] = [(_get_strtime(clip_span[0]), _get_strtime(clip_span[1]))]
-                    else:
-                        meta_clip["clips"] = [clip_span]
-                    meta_clip["key"] = f"{meta_clip['key']}_{clip_key}"
-
-                    yt_md_dict = meta_clip.get("yt_meta_dict", {})
-                    if (yt_md_dict is not None) and (yt_md_dict.get("subtitles", None) is not None):
-                        # full video subtitles might still be useful for context
-                        meta_clip["clip_subtitles"] = _extract_subtitles(clip_span, meta_clip)
-
-                    metadata_clips.append(meta_clip)
-
-                streams_clips[k] = stream_clips
-
-        # remove redundant metadata from clips after the first
-        for m_clips in metadata_clips[1:]:
-            m_clips["yt_meta_dict"] = {}
-
-        return streams_clips, metadata_clips, None
+    def __call__(self, streams: Streams, metadata: dict):
+        # NOTE(review): guard with Iterable so a flat [start, end] float span doesn't raise
+        strtime_formatting = isinstance(metadata["clips"][0], Iterable) and isinstance(metadata["clips"][0][0], str)
+
+        clip_spans = _adjust_clip_spans(
+            clip_spans=metadata.pop("clips"),
+            keyframe_timestamps=(
+                # TODO: make it so if keyframe timestamps not present, get it yourself
+                metadata["video_metadata"].pop("keyframe_timestamps")
+                if self.precision == "keyframe_adjusted"
+                else None
+            ),
+            min_length=self.min_length,
+            max_length=self.max_length,
+            max_length_strategy=self.max_length_strategy,
+        )
+        if len(clip_spans) == 0:
+            return {}, [], f"Video had no clip_spans longer than {self.min_length}"
+
+        try:
+            clips, clip_metadata = _get_clips(
+                streams=streams,
+                encode_formats=self.encode_formats,
+                
precision=self.precision, + clip_spans=clip_spans, + metadata=metadata, + oom_clip_count=self.oom_clip_count, + strtime_formatting=strtime_formatting, + ) + except Exception as err: # pylint: disable=broad-except + return {}, [], str(err) + + return clips, clip_metadata, None