Skip to content

Commit

Permalink
Merge pull request facebookresearch#541
Browse files Browse the repository at this point in the history
  • Loading branch information
Yi Li committed Jul 23, 2022
2 parents 67c61a9 + e0d247a commit b6a45f1
Showing 1 changed file with 71 additions and 33 deletions.
104 changes: 71 additions & 33 deletions slowfast/datasets/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,10 @@ def pyav_decode(
num_clips_uniform=10,
target_fps=30,
use_offset=False,
modalities=("visual",),
max_spatial_scale=0,
min_delta=-math.inf,
max_delta=math.inf,
):
"""
Convert the video from its original fps to the target_fps. If the video
Expand Down Expand Up @@ -419,38 +423,69 @@ def pyav_decode(
# If failed to fetch the decoding information, decode the entire video.
decode_all_video = True
video_start_pts, video_end_pts = 0, math.inf
start_end_delta_time = None

frames = None
if container.streams.video:
video_frames, max_pts = pyav_decode_stream(
container,
video_start_pts,
video_end_pts,
container.streams.video[0],
{"video": 0},
)
container.close()

frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
frames = torch.as_tensor(np.stack(frames))
frames_out = [frames]

else:
# Perform selective decoding.
decode_all_video = False
clip_size = np.maximum(
1.0, np.ceil(sampling_rate * (num_frames - 1) / target_fps * fps)
)
start_idx, end_idx, fraction = get_start_end_idx(
clip_sizes = [
np.maximum(
1.0,
np.ceil(
sampling_rate[i] * (num_frames[i] - 1) / target_fps * fps
),
)
for i in range(len(sampling_rate))
]
start_end_delta_time = get_multiple_start_end_idx(
frames_length,
clip_size,
clip_sizes,
clip_idx,
num_clips_uniform,
use_offset=use_offset,
)
timebase = duration / frames_length
video_start_pts = int(start_idx * timebase)
video_end_pts = int(end_idx * timebase)

frames = None
# If video stream was found, fetch video frames from the video.
if container.streams.video:
video_frames, max_pts = pyav_decode_stream(
container,
video_start_pts,
video_end_pts,
container.streams.video[0],
{"video": 0},
min_delta=min_delta,
max_delta=max_delta,
)
frames_out = [None] * len(num_frames)
for k in range(len(num_frames)):
start_idx = start_end_delta_time[k, 0]
end_idx = start_end_delta_time[k, 1]
timebase = duration / frames_length
video_start_pts = int(start_idx * timebase)
video_end_pts = int(end_idx * timebase)

frames = None
# If video stream was found, fetch video frames from the video.
if container.streams.video:
video_frames, max_pts = pyav_decode_stream(
container,
video_start_pts,
video_end_pts,
container.streams.video[0],
{"video": 0},
)

frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
frames = torch.as_tensor(np.stack(frames))

frames_out[k] = frames
container.close()

frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
frames = torch.as_tensor(np.stack(frames))
return frames, fps, decode_all_video
return frames_out, fps, decode_all_video, start_end_delta_time


def decode(
Expand Down Expand Up @@ -510,17 +545,20 @@ def decode(
) # clips come temporally ordered from decoder
try:
if backend == "pyav":
assert (
min_delta == -math.inf and max_delta == math.inf
), "delta sampling not supported in pyav"
frames_decoded, fps, decode_all_video = pyav_decode(
assert min_delta == -math.inf and max_delta == math.inf, \
"delta sampling not supported in pyav"
frames_decoded, fps, decode_all_video, start_end_delta_time = pyav_decode(
container,
sampling_rate,
num_frames,
clip_idx,
num_clips_uniform,
target_fps,
use_offset=use_offset,
modalities=("visual",),
max_spatial_scale=max_spatial_scale,
min_delta=min_delta,
max_delta=max_delta,
)
elif backend == "torchvision":
(
Expand Down Expand Up @@ -558,12 +596,12 @@ def decode(
frames_decoded = [frames_decoded]
num_decoded = len(frames_decoded)
clip_sizes = [
np.maximum(
1.0,
sampling_rate[i] * num_frames[i] / target_fps * fps
)
for i in range(len(sampling_rate))
]
np.maximum(
1.0,
sampling_rate[i] * num_frames[i] / target_fps * fps
)
for i in range(len(sampling_rate))
]

if decode_all_video: # full video was decoded (not trimmed yet)
assert num_decoded == 1 and start_end_delta_time is None
Expand Down

0 comments on commit b6a45f1

Please sign in to comment.