Skip to content

Commit

Permalink
[torchcodec] stream_index should not be optional for getting frame by…
Browse files Browse the repository at this point in the history
… index

Differential Revision: D58505608
  • Loading branch information
scotts authored and facebook-github-bot committed Jun 18, 2024
1 parent de240a4 commit ca4d83e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
14 changes: 6 additions & 8 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def("get_next_frame(Tensor(a!) decoder) -> Tensor");
m.def("get_frame_at_pts(Tensor(a!) decoder, float seconds) -> Tensor");
m.def(
"get_frame_at_index(Tensor(a!) decoder, *, int frame_index, int? stream_index=None) -> Tensor");
"get_frame_at_index(Tensor(a!) decoder, *, int frame_index, int stream_index) -> Tensor");
m.def(
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices, int? stream_index=None) -> Tensor");
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices, int stream_index) -> Tensor");
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
}

Expand Down Expand Up @@ -131,22 +131,20 @@ at::Tensor get_frame_at_pts(at::Tensor& decoder, double seconds) {
at::Tensor get_frame_at_index(
at::Tensor& decoder,
int64_t frame_index,
std::optional<int64_t> stream_index) {
int64_t stream_index) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
auto result =
videoDecoder->getFrameAtIndex(stream_index.value_or(-1), frame_index);
auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index);
return result.frame;
}

at::Tensor get_frames_at_indices(
at::Tensor& decoder,
at::IntArrayRef frame_indices,
std::optional<int64_t> stream_index) {
int64_t stream_index) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
std::vector<int64_t> frameIndicesVec(
frame_indices.begin(), frame_indices.end());
auto result = videoDecoder->getFramesAtIndexes(
stream_index.value_or(-1), frameIndicesVec);
auto result = videoDecoder->getFramesAtIndexes(stream_index, frameIndicesVec);
return result.frames;
}

Expand Down
4 changes: 2 additions & 2 deletions src/torchcodec/decoders/_core/video_decoder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def get_frame_at_pts_abstract(decoder: torch.Tensor, seconds: float) -> torch.Te

@register_fake("torchcodec_ns::get_frame_at_index")
def get_frame_at_index_abstract(
decoder: torch.Tensor, *, frame_index: int, stream_index: Optional[int] = None
decoder: torch.Tensor, *, frame_index: int, stream_index: int
) -> torch.Tensor:
image_size = [get_ctx().new_dynamic_size() for _ in range(3)]
return torch.empty(image_size)
Expand All @@ -125,7 +125,7 @@ def get_frames_at_indices_abstract(
decoder: torch.Tensor,
*,
frame_indices: List[int],
stream_index: Optional[int] = None,
stream_index: int,
) -> torch.Tensor:
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
return torch.empty(image_size)
Expand Down

0 comments on commit ca4d83e

Please sign in to comment.