diff --git a/slowfast/datasets/decoder.py b/slowfast/datasets/decoder.py
index 727e60bf5..9e3122108 100644
--- a/slowfast/datasets/decoder.py
+++ b/slowfast/datasets/decoder.py
@@ -85,6 +85,7 @@ def get_multiple_start_end_idx(
     num_clips_uniform,
     min_delta=0,
     max_delta=math.inf,
+    use_offset=False,
 ):
     """
     Sample a clip of size clip_size from a video of size video_size and
@@ -114,20 +115,28 @@ def sample_clips(
         min_delta=0,
         max_delta=math.inf,
         num_retries=100,
+        use_offset=False,
     ):
         se_inds = np.empty((0, 2))
         dt = np.empty((0))
         for clip_size in clip_sizes:
             for i_try in range(num_retries):
-                clip_size = int(clip_size)
+                # clip_size = int(clip_size)
                 max_start = max(video_size - clip_size, 0)
                 if clip_idx == -1:
                     # Random temporal sampling.
                     start_idx = random.uniform(0, max_start)
-                else:
-                    # Uniformly sample the clip with the given index.
-                    start_idx = max_start * clip_idx / num_clips_uniform
-                end_idx = start_idx + clip_size  # - 1
+                else:  # Uniformly sample the clip with the given index.
+                    if use_offset:
+                        if num_clips_uniform == 1:
+                            # Take the center clip if num_clips is 1.
+                            start_idx = math.floor(max_start / 2)
+                        else:
+                            start_idx = clip_idx * math.floor(max_start / (num_clips_uniform - 1))
+                    else:
+                        start_idx = max_start * clip_idx / num_clips_uniform
+
+                end_idx = start_idx + clip_size - 1

                 se_inds_new = np.append(se_inds, [[start_idx, end_idx]], axis=0)
                 if se_inds.shape[0] < 1:
@@ -156,6 +165,7 @@ def sample_clips(
         min_delta,
         max_delta,
         100,
+        use_offset,
     )
     success = not (any(dt < min_delta) or any(dt > max_delta))
     if success or clip_idx != -1:
@@ -295,9 +305,7 @@ def torchvision_decode(
     clip_sizes = [
         np.maximum(
             1.0,
-            np.ceil(
-                sampling_rate[i] * (num_frames[i] - 1) / target_fps * fps
-            ),
+            sampling_rate[i] * num_frames[i] / target_fps * fps,
         )
         for i in range(len(sampling_rate))
     ]
@@ -308,6 +316,7 @@ def torchvision_decode(
         num_clips_uniform,
         min_delta=min_delta,
         max_delta=max_delta,
+        use_offset=use_offset,
     )
     frames_out = [None] * len(num_frames)
     for k in range(len(num_frames)):
@@ -374,6 +383,10 @@ def pyav_decode(
     num_clips_uniform=10,
     target_fps=30,
     use_offset=False,
+    modalities=("visual",),
+    max_spatial_scale=0,
+    min_delta=-math.inf,
+    max_delta=math.inf,
 ):
     """
     Convert the video from its original fps to the target_fps. If the video
@@ -411,38 +424,69 @@ def pyav_decode(
         # If failed to fetch the decoding information, decode the entire video.
         decode_all_video = True
         video_start_pts, video_end_pts = 0, math.inf
+        start_end_delta_time = None
+
+        frames = None
+        if container.streams.video:
+            video_frames, max_pts = pyav_decode_stream(
+                container,
+                video_start_pts,
+                video_end_pts,
+                container.streams.video[0],
+                {"video": 0},
+            )
+            container.close()
+
+            frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
+            frames = torch.as_tensor(np.stack(frames))
+        frames_out = [frames]
+
     else:
         # Perform selective decoding.
         decode_all_video = False
-        clip_size = np.maximum(
-            1.0, np.ceil(sampling_rate * (num_frames - 1) / target_fps * fps)
-        )
-        start_idx, end_idx, fraction = get_start_end_idx(
+        clip_sizes = [
+            np.maximum(
+                1.0,
+                np.ceil(
+                    sampling_rate[i] * (num_frames[i] - 1) / target_fps * fps
+                ),
+            )
+            for i in range(len(sampling_rate))
+        ]
+        start_end_delta_time = get_multiple_start_end_idx(
             frames_length,
-            clip_size,
+            clip_sizes,
             clip_idx,
             num_clips_uniform,
-            use_offset=use_offset,
-        )
-        timebase = duration / frames_length
-        video_start_pts = int(start_idx * timebase)
-        video_end_pts = int(end_idx * timebase)
-
-    frames = None
-    # If video stream was found, fetch video frames from the video.
-    if container.streams.video:
-        video_frames, max_pts = pyav_decode_stream(
-            container,
-            video_start_pts,
-            video_end_pts,
-            container.streams.video[0],
-            {"video": 0},
+            min_delta=min_delta,
+            max_delta=max_delta,
         )
+        frames_out = [None] * len(num_frames)
+        for k in range(len(num_frames)):
+            start_idx = start_end_delta_time[k, 0]
+            end_idx = start_end_delta_time[k, 1]
+            timebase = duration / frames_length
+            video_start_pts = int(start_idx * timebase)
+            video_end_pts = int(end_idx * timebase)
+
+            frames = None
+            # If video stream was found, fetch video frames from the video.
+            if container.streams.video:
+                video_frames, max_pts = pyav_decode_stream(
+                    container,
+                    video_start_pts,
+                    video_end_pts,
+                    container.streams.video[0],
+                    {"video": 0},
+                )
+
+                frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
+                frames = torch.as_tensor(np.stack(frames))
+
+            frames_out[k] = frames
         container.close()
-        frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
-        frames = torch.as_tensor(np.stack(frames))
-    return frames, fps, decode_all_video
+    return frames_out, fps, decode_all_video, start_end_delta_time


 def decode(
@@ -504,7 +548,7 @@ def decode(
     if backend == "pyav":
         assert min_delta == -math.inf and max_delta == math.inf, \
             "delta sampling not supported in pyav"
-        frames_decoded, fps, decode_all_video = pyav_decode(
+        frames_decoded, fps, decode_all_video, start_end_delta_time = pyav_decode(
             container,
             sampling_rate,
             num_frames,
@@ -512,6 +556,10 @@ def decode(
             num_clips_uniform,
             target_fps,
             use_offset=use_offset,
+            modalities=("visual",),
+            max_spatial_scale=max_spatial_scale,
+            min_delta=min_delta,
+            max_delta=max_delta,
         )
     elif backend == "torchvision":
         (
@@ -551,7 +599,7 @@ def decode(
     clip_sizes = [
         np.maximum(
             1.0,
-            np.ceil(sampling_rate[i] * (num_frames[i] - 1) / target_fps * fps),
+            sampling_rate[i] * num_frames[i] / target_fps * fps,
         )
         for i in range(len(sampling_rate))
     ]
@@ -565,6 +613,7 @@ def decode(
         num_clips_uniform if decode_all_video else 1,
         min_delta=min_delta,
         max_delta=max_delta,
+        use_offset=use_offset,
     )
     frames_out, start_inds, time_diff_aug = (
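
Note: the core behavioral change in this patch is the use_offset branch added to sample_clips() above. Below is a minimal, self-contained sketch of that sampling rule; the helper name sample_start_idx is hypothetical (not part of the patch), but the arithmetic mirrors the patched branch.

import math
import random


def sample_start_idx(video_size, clip_size, clip_idx, num_clips_uniform, use_offset=False):
    # Hypothetical helper: returns the start frame of one clip.
    max_start = max(video_size - clip_size, 0)
    if clip_idx == -1:
        # Random temporal sampling (training).
        return random.uniform(0, max_start)
    if use_offset:
        if num_clips_uniform == 1:
            # Take the center clip if num_clips is 1.
            return math.floor(max_start / 2)
        # Space clips evenly over [0, max_start]; with flooring, the last
        # clip starts at or just below max_start.
        return clip_idx * math.floor(max_start / (num_clips_uniform - 1))
    # Legacy spacing: the last clip never reaches the end of the video.
    return max_start * clip_idx / num_clips_uniform


# A 300-frame video, 64-frame clips, 3 uniform clips (max_start = 236):
#   use_offset=False -> starts 0.0, 78.67, 157.33 (last clip ends near frame 220)
#   use_offset=True  -> starts 0, 118, 236        (last clip ends at frame 299)

With use_offset=True the starts are spaced floor(max_start / (num_clips_uniform - 1)) apart, so the first clip begins at frame 0 and the last ends at (or just below) the final frame; the legacy formula max_start * clip_idx / num_clips_uniform never places a clip flush with the end of the video.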