From f2fb94f65cb59449002e0b30d117376c4966f3e9 Mon Sep 17 00:00:00 2001 From: Scott Lamb Date: Fri, 11 Jun 2021 13:13:01 -0700 Subject: [PATCH] s/rbsp::RbspBitReader<'a>/rbsp::BitReader/ For dholroyd/h264-reader#4: make the bit reader type take a BufRead rather than a slice so we don't have to keep a buffered copy of the RBSP. While I'm at it, reduce "stuttering" by taking the module name out of the struct name. This is intrusive but mechanical. --- fuzz/fuzz_targets/fuzz_target_1.rs | 2 +- src/nal/pps.rs | 29 ++++++------ src/nal/sei/buffering_period.rs | 15 +++--- src/nal/sei/pic_timing.rs | 22 +++++---- src/nal/slice/mod.rs | 23 +++++----- src/nal/sps.rs | 69 ++++++++++++++-------------- src/rbsp.rs | 74 ++++++++++++++---------------- 7 files changed, 118 insertions(+), 116 deletions(-) diff --git a/fuzz/fuzz_targets/fuzz_target_1.rs b/fuzz/fuzz_targets/fuzz_target_1.rs index daf5f15..ba85779 100644 --- a/fuzz/fuzz_targets/fuzz_target_1.rs +++ b/fuzz/fuzz_targets/fuzz_target_1.rs @@ -113,7 +113,7 @@ impl h264_reader::nal::NalHandler for SliceFuzz { decode.push(ctx, ¤t_slice.buf[..]); decode.end(ctx); let capture = decode.into_handler(); - let mut r = rbsp::RbspBitReader::new(&capture.buf[1..]); + let mut r = rbsp::BitReader::new(&capture.buf[1..]); match nal::slice::SliceHeader::read(ctx, &mut r, current_slice.header) { Ok((header, sps, pps)) => { println!("{:#?}", header); diff --git a/src/nal/pps.rs b/src/nal/pps.rs index 06e11c2..4a39748 100644 --- a/src/nal/pps.rs +++ b/src/nal/pps.rs @@ -1,14 +1,15 @@ use super::NalHandler; use super::NalHeader; use super::sps; +use std::io::BufRead; use std::marker; use crate::{rbsp, Context}; -use crate::rbsp::RbspBitReader; +use crate::rbsp::BitReader; use log::*; #[derive(Debug)] pub enum PpsError { - RbspReaderError(rbsp::RbspBitReaderError), + RbspReaderError(rbsp::BitReaderError), InvalidSliceGroupMapType(u32), InvalidSliceGroupChangeType(u32), UnknownSeqParamSetId(ParamSetId), @@ -17,8 +18,8 @@ pub enum PpsError { ScalingMatrix(sps::ScalingMatrixError), } -impl From for PpsError { - fn from(e: rbsp::RbspBitReaderError) -> Self { +impl From for PpsError { + fn from(e: rbsp::BitReaderError) -> Self { PpsError::RbspReaderError(e) } } @@ -46,7 +47,7 @@ pub struct SliceRect { bottom_right: u32, } impl SliceRect { - fn read(r: &mut RbspBitReader<'_>) -> Result { + fn read(r: &mut BitReader) -> Result { Ok(SliceRect { top_left: r.read_ue_named("top_left")?, bottom_right: r.read_ue_named("bottom_right")?, @@ -77,7 +78,7 @@ pub enum SliceGroup { }, } impl SliceGroup { - fn read(r: &mut RbspBitReader<'_>, num_slice_groups_minus1: u32) -> Result { + fn read(r: &mut BitReader, num_slice_groups_minus1: u32) -> Result { let slice_group_map_type = r.read_ue_named("slice_group_map_type")?; match slice_group_map_type { 0 => Ok(SliceGroup::Interleaved { @@ -103,7 +104,7 @@ impl SliceGroup { } } - fn read_run_lengths(r: &mut RbspBitReader<'_>, num_slice_groups_minus1: u32) -> Result,PpsError> { + fn read_run_lengths(r: &mut BitReader, num_slice_groups_minus1: u32) -> Result,PpsError> { let mut run_length_minus1 = Vec::with_capacity(num_slice_groups_minus1 as usize + 1); for _ in 0..num_slice_groups_minus1+1 { run_length_minus1.push(r.read_ue_named("run_length_minus1")?); @@ -111,7 +112,7 @@ impl SliceGroup { Ok(run_length_minus1) } - fn read_rectangles(r: &mut RbspBitReader<'_>, num_slice_groups_minus1: u32) -> Result,PpsError> { + fn read_rectangles(r: &mut BitReader, num_slice_groups_minus1: u32) -> Result,PpsError> { let mut run_length_minus1 = Vec::with_capacity(num_slice_groups_minus1 as usize + 1); for _ in 0..num_slice_groups_minus1+1 { run_length_minus1.push(SliceRect::read(r)?); @@ -119,7 +120,7 @@ impl SliceGroup { Ok(run_length_minus1) } - fn read_group_ids(r: &mut RbspBitReader<'_>, num_slice_groups_minus1: u32) -> Result,PpsError> { + fn read_group_ids(r: &mut BitReader, num_slice_groups_minus1: u32) -> Result,PpsError> { let pic_size_in_map_units_minus1 = r.read_ue_named("pic_size_in_map_units_minus1")?; // TODO: avoid any panics due to failed conversions let size = ((1f64+f64::from(pic_size_in_map_units_minus1)).log2()) as u8; @@ -136,7 +137,7 @@ struct PicScalingMatrix { // TODO } impl PicScalingMatrix { - fn read(r: &mut RbspBitReader<'_>, sps: &sps::SeqParameterSet, transform_8x8_mode_flag: bool) -> Result,PpsError> { + fn read(r: &mut BitReader, sps: &sps::SeqParameterSet, transform_8x8_mode_flag: bool) -> Result,PpsError> { let pic_scaling_matrix_present_flag = r.read_bool()?; Ok(if pic_scaling_matrix_present_flag { let mut scaling_list4x4 = vec!(); @@ -171,8 +172,8 @@ pub struct PicParameterSetExtra { second_chroma_qp_index_offset: i32, } impl PicParameterSetExtra { - fn read(r: &mut RbspBitReader<'_>, sps: &sps::SeqParameterSet) -> Result,PpsError> { - Ok(if r.has_more_rbsp_data() { + fn read(r: &mut BitReader, sps: &sps::SeqParameterSet) -> Result,PpsError> { + Ok(if r.has_more_rbsp_data("transform_8x8_mode_flag")? { let transform_8x8_mode_flag = r.read_bool()?; Some(PicParameterSetExtra { transform_8x8_mode_flag, @@ -226,7 +227,7 @@ pub struct PicParameterSet { } impl PicParameterSet { pub fn from_bytes(ctx: &Context, buf: &[u8]) -> Result { - let mut r = RbspBitReader::new(buf); + let mut r = BitReader::new(buf); let pic_parameter_set_id = ParamSetId::from_u32(r.read_ue_named("pic_parameter_set_id")?) .map_err(PpsError::BadPicParamSetId)?; let seq_parameter_set_id = ParamSetId::from_u32(r.read_ue_named("seq_parameter_set_id")?) @@ -253,7 +254,7 @@ impl PicParameterSet { }) } - fn read_slice_groups(r: &mut RbspBitReader<'_>) -> Result,PpsError> { + fn read_slice_groups(r: &mut BitReader) -> Result,PpsError> { let num_slice_groups_minus1 = r.read_ue_named("num_slice_groups_minus1")?; Ok(if num_slice_groups_minus1 > 0 { Some(SliceGroup::read(r, num_slice_groups_minus1)?) diff --git a/src/nal/sei/buffering_period.rs b/src/nal/sei/buffering_period.rs index 15374f5..8f88c6f 100644 --- a/src/nal/sei/buffering_period.rs +++ b/src/nal/sei/buffering_period.rs @@ -1,20 +1,21 @@ use super::SeiCompletePayloadReader; +use std::io::BufRead; use std::marker; use crate::nal::pps; -use crate::rbsp::RbspBitReader; +use crate::rbsp::BitReader; use crate::Context; use crate::nal::sei::HeaderType; -use crate::rbsp::RbspBitReaderError; +use crate::rbsp::BitReaderError; use log::*; #[derive(Debug)] enum BufferingPeriodError { - ReaderError(RbspBitReaderError), + ReaderError(BitReaderError), UndefinedSeqParamSetId(pps::ParamSetId), InvalidSeqParamSetId(pps::ParamSetIdError), } -impl From for BufferingPeriodError { - fn from(e: RbspBitReaderError) -> Self { +impl From for BufferingPeriodError { + fn from(e: BitReaderError) -> Self { BufferingPeriodError::ReaderError(e) } } @@ -30,7 +31,7 @@ struct InitialCpbRemoval { initial_cpb_removal_delay_offset: u32, } -fn read_cpb_removal_delay_list(r: &mut RbspBitReader<'_>, count: usize, length: u8) -> Result,RbspBitReaderError> { +fn read_cpb_removal_delay_list(r: &mut BitReader, count: usize, length: u8) -> Result,BitReaderError> { let mut res = vec!(); for _ in 0..count { res.push(InitialCpbRemoval { @@ -48,7 +49,7 @@ struct BufferingPeriod { } impl BufferingPeriod { fn read(ctx: &Context, buf: &[u8]) -> Result { - let mut r = RbspBitReader::new(buf); + let mut r = BitReader::new(buf); let seq_parameter_set_id = pps::ParamSetId::from_u32(r.read_ue_named("seq_parameter_set_id")?)?; match ctx.sps_by_id(seq_parameter_set_id) { None => Err(BufferingPeriodError::UndefinedSeqParamSetId(seq_parameter_set_id)), diff --git a/src/nal/sei/pic_timing.rs b/src/nal/sei/pic_timing.rs index bc1d72b..ffc99c4 100644 --- a/src/nal/sei/pic_timing.rs +++ b/src/nal/sei/pic_timing.rs @@ -1,10 +1,12 @@ +use std::io::BufRead; + use crate::nal::sei::SeiCompletePayloadReader; use crate::Context; use crate::nal::sei::HeaderType; use crate::nal::pps::ParamSetId; -use crate::rbsp::RbspBitReader; +use crate::rbsp::BitReader; use crate::nal::sps; -use crate::rbsp::RbspBitReaderError; +use crate::rbsp::BitReaderError; use log::*; // FIXME: SPS selection @@ -14,12 +16,12 @@ use log::*; #[derive(Debug)] pub enum PicTimingError { - RbspError(RbspBitReaderError), + RbspError(BitReaderError), UndefinedSeqParamSetId(ParamSetId), InvalidPicStructId(u8), } -impl From for PicTimingError { - fn from(e: RbspBitReaderError) -> Self { +impl From for PicTimingError { + fn from(e: BitReaderError) -> Self { PicTimingError::RbspError(e) } } @@ -175,7 +177,7 @@ pub struct ClockTimestamp { pub time_offset: Option, } impl ClockTimestamp { - fn read(r: &mut RbspBitReader<'_>, sps: &sps::SeqParameterSet) -> Result { + fn read(r: &mut BitReader, sps: &sps::SeqParameterSet) -> Result { let ct_type = CtType::from_id(r.read_u8(2)?); let nuit_field_based_flag = r.read_bool_named("nuit_field_based_flag")?; let counting_type = CountingType::from_id(r.read_u8(5)?); @@ -247,7 +249,7 @@ pub struct PicTiming { } impl PicTiming { pub fn read(ctx: &mut Context, buf: &[u8]) -> Result { - let mut r = RbspBitReader::new(buf); + let mut r = BitReader::new(buf); let seq_parameter_set_id = ParamSetId::from_u32(0).unwrap(); match ctx.sps_by_id(seq_parameter_set_id) { None => Err(PicTimingError::UndefinedSeqParamSetId(seq_parameter_set_id)), @@ -260,7 +262,7 @@ impl PicTiming { } } - fn read_delays(r: &mut RbspBitReader<'_>, sps: &sps::SeqParameterSet) -> Result,PicTimingError> { + fn read_delays(r: &mut BitReader, sps: &sps::SeqParameterSet) -> Result,PicTimingError> { Ok(if let Some(ref vui_params) = sps.vui_parameters { if let Some(ref hrd) = vui_params.nal_hrd_parameters.as_ref().or_else(|| vui_params.nal_hrd_parameters.as_ref() ) { Some(Delays { @@ -275,7 +277,7 @@ impl PicTiming { }) } - fn read_pic_struct(r: &mut RbspBitReader<'_>, sps: &sps::SeqParameterSet) -> Result,PicTimingError> { + fn read_pic_struct(r: &mut BitReader, sps: &sps::SeqParameterSet) -> Result,PicTimingError> { Ok(if let Some(ref vui_params) = sps.vui_parameters { if vui_params.pic_struct_present_flag { let pic_struct = PicStructType::from_id(r.read_u8(4)?)?; @@ -293,7 +295,7 @@ impl PicTiming { }) } - fn read_clock_timestamps(r: &mut RbspBitReader<'_>, pic_struct: &PicStructType, sps: &sps::SeqParameterSet) -> Result>,PicTimingError> { + fn read_clock_timestamps(r: &mut BitReader, pic_struct: &PicStructType, sps: &sps::SeqParameterSet) -> Result>,PicTimingError> { let mut res = Vec::new(); for _ in 0..pic_struct.num_clock_timestamps() { res.push(if r.read_bool_named("clock_timestamp_flag")? { diff --git a/src/nal/slice/mod.rs b/src/nal/slice/mod.rs index 478091e..de2369f 100644 --- a/src/nal/slice/mod.rs +++ b/src/nal/slice/mod.rs @@ -1,10 +1,11 @@ use crate::Context; -use crate::rbsp::RbspBitReader; -use crate::rbsp::RbspBitReaderError; +use crate::rbsp::BitReader; +use crate::rbsp::BitReaderError; use crate::nal::pps::{ParamSetId, PicParameterSet}; use crate::nal::pps; use crate::nal::sps; +use std::io::BufRead; use std::marker; use crate::nal::sps::SeqParameterSet; use crate::nal::NalHeader; @@ -56,7 +57,7 @@ impl SliceType { #[derive(Debug)] pub enum SliceHeaderError { - RbspError(RbspBitReaderError), + RbspError(BitReaderError), InvalidSliceType(u32), InvalidSeqParamSetId(pps::ParamSetIdError), UndefinedPicParamSetId(pps::ParamSetId), @@ -72,8 +73,8 @@ pub enum SliceHeaderError { /// The header contained syntax elements that the parser isn't able to handle yet UnsupportedSyntax(&'static str), } -impl From for SliceHeaderError { - fn from(e: RbspBitReaderError) -> Self { +impl From for SliceHeaderError { + fn from(e: BitReaderError) -> Self { SliceHeaderError::RbspError(e) } } @@ -163,7 +164,7 @@ enum RefPicListModifications { }, } impl RefPicListModifications { - fn read(slice_family: &SliceFamily, r: &mut RbspBitReader<'_>) -> Result { + fn read(slice_family: &SliceFamily, r: &mut BitReader) -> Result { Ok(match slice_family { SliceFamily::I | SliceFamily::SI => RefPicListModifications::I, SliceFamily::B => RefPicListModifications::B { @@ -176,7 +177,7 @@ impl RefPicListModifications { }) } - fn read_list(r: &mut RbspBitReader<'_>) -> Result, SliceHeaderError> { + fn read_list(r: &mut BitReader) -> Result, SliceHeaderError> { let mut result = vec![]; // either ref_pic_list_modification_flag_l0 or ref_pic_list_modification_flag_l1 depending // on call-site, @@ -209,7 +210,7 @@ struct PredWeightTable { chroma_weights: Vec>, } impl PredWeightTable { - fn read(r: &mut RbspBitReader<'_>, slice_type: &SliceType, pps: &pps::PicParameterSet, sps: &sps::SeqParameterSet, num_ref_active: &Option) -> Result { + fn read(r: &mut BitReader, slice_type: &SliceType, pps: &pps::PicParameterSet, sps: &sps::SeqParameterSet, num_ref_active: &Option) -> Result { let chroma_array_type = if sps.chroma_info.separate_colour_plane_flag { // TODO: "Otherwise (separate_colour_plane_flag is equal to 1), ChromaArrayType is // set equal to 0." ...does this mean ChromaFormat::Monochrome then? @@ -295,7 +296,7 @@ enum DecRefPicMarking { Adaptive(Vec), } impl DecRefPicMarking { - fn read(r: &mut RbspBitReader<'_>, header: NalHeader) -> Result { + fn read(r: &mut BitReader, header: NalHeader) -> Result { Ok(if header.nal_unit_type() == crate::nal::UnitType::SliceLayerWithoutPartitioningIdr { DecRefPicMarking::Idr { no_output_of_prior_pics_flag: r.read_bool_named("no_output_of_prior_pics_flag")?, @@ -363,7 +364,7 @@ pub struct SliceHeader { disable_deblocking_filter_idc: u8, } impl SliceHeader { - pub fn read<'a, Ctx>(ctx: &'a mut Context, r: &mut RbspBitReader<'_>, header: NalHeader) -> Result<(SliceHeader, &'a SeqParameterSet, &'a PicParameterSet), SliceHeaderError> { + pub fn read<'a, Ctx, R: BufRead>(ctx: &'a mut Context, r: &mut BitReader, header: NalHeader) -> Result<(SliceHeader, &'a SeqParameterSet, &'a PicParameterSet), SliceHeaderError> { let first_mb_in_slice = r.read_ue_named("first_mb_in_slice")?; let slice_type = SliceType::from_id(r.read_ue_named("slice_type")?)?; let pic_parameter_set_id = ParamSetId::from_u32(r.read_ue_named("pic_parameter_set_id")?)?; @@ -543,7 +544,7 @@ impl super::NalHandler for SliceLayerWithoutPartitioningRbsp { match self.state { ParseState::Unstarted => panic!("start() not yet called"), ParseState::Start(header) => { - let mut r = RbspBitReader::new(buf); + let mut r = BitReader::new(buf); match SliceHeader::read(ctx, &mut r, header) { Ok(header) => info!("TODO: expose to caller: {:#?}", header), Err(e) => error!("slice_header() error: SliceHeaderError::{:?}", e), diff --git a/src/nal/sps.rs b/src/nal/sps.rs index 7cab8ff..fc92a54 100644 --- a/src/nal/sps.rs +++ b/src/nal/sps.rs @@ -1,9 +1,10 @@ -use crate::rbsp::RbspBitReader; +use crate::rbsp::BitReader; use super::NalHandler; use super::NalHeader; use crate::Context; -use crate::rbsp::RbspBitReaderError; +use crate::rbsp::BitReaderError; +use std::io::BufRead; use std::{marker, fmt}; use crate::nal::pps::ParamSetId; use crate::nal::pps::ParamSetIdError; @@ -13,7 +14,7 @@ use std::fmt::Debug; pub enum SpsError { /// Signals that bit_depth_luma_minus8 was greater than the max value, 6 BitDepthOutOfRange(u32), - RbspReaderError(RbspBitReaderError), + RbspReaderError(BitReaderError), PicOrderCnt(PicOrderCntError), ScalingMatrix(ScalingMatrixError), /// log2_max_frame_num_minus4 must be between 0 and 12 @@ -27,8 +28,8 @@ pub enum SpsError { CpbCountOutOfRange(u32), } -impl From for SpsError { - fn from(e: RbspBitReaderError) -> Self { +impl From for SpsError { + fn from(e: BitReaderError) -> Self { SpsError::RbspReaderError(e) } } @@ -279,7 +280,7 @@ pub struct ScalingList { // TODO } impl ScalingList { - pub fn read(r: &mut RbspBitReader<'_>, size: u8) -> Result { + pub fn read(r: &mut BitReader, size: u8) -> Result { let mut scaling_list = vec!(); let mut last_scale = 8; let mut next_scale = 8; @@ -303,13 +304,13 @@ impl ScalingList { #[derive(Debug)] pub enum ScalingMatrixError { - ReaderError(RbspBitReaderError), + ReaderError(BitReaderError), /// The `delta_scale` field must be between -128 and 127 inclusive. DeltaScaleOutOfRange(i32), } -impl From for ScalingMatrixError { - fn from(e: RbspBitReaderError) -> Self { +impl From for ScalingMatrixError { + fn from(e: BitReaderError) -> Self { ScalingMatrixError::ReaderError(e) } } @@ -324,7 +325,7 @@ impl Default for SeqScalingMatrix { } } impl SeqScalingMatrix { - fn read(r: &mut RbspBitReader<'_>, chroma_format_idc: u32) -> Result { + fn read(r: &mut BitReader, chroma_format_idc: u32) -> Result { let mut scaling_list4x4 = vec!(); let mut scaling_list8x8 = vec!(); @@ -353,7 +354,7 @@ pub struct ChromaInfo { pub scaling_matrix: SeqScalingMatrix, } impl ChromaInfo { - pub fn read(r: &mut RbspBitReader<'_>, profile_idc: ProfileIdc) -> Result { + pub fn read(r: &mut BitReader, profile_idc: ProfileIdc) -> Result { if profile_idc.has_chroma_info() { let chroma_format_idc = r.read_ue_named("chroma_format_idc")?; Ok(ChromaInfo { @@ -375,7 +376,7 @@ impl ChromaInfo { }) } } - fn read_bit_depth_minus8(r: &mut RbspBitReader<'_>) -> Result { + fn read_bit_depth_minus8(r: &mut BitReader) -> Result { let value = r.read_ue_named("read_bit_depth_minus8")?; if value > 6 { Err(SpsError::BitDepthOutOfRange(value)) @@ -383,7 +384,7 @@ impl ChromaInfo { Ok(value as u8) } } - fn read_scaling_matrix(r: &mut RbspBitReader<'_>, chroma_format_idc: u32) -> Result { + fn read_scaling_matrix(r: &mut BitReader, chroma_format_idc: u32) -> Result { let scaling_matrix_present_flag = r.read_bool()?; if scaling_matrix_present_flag { SeqScalingMatrix::read(r, chroma_format_idc).map_err(SpsError::ScalingMatrix) @@ -396,15 +397,15 @@ impl ChromaInfo { #[derive(Debug)] pub enum PicOrderCntError { InvalidPicOrderCountType(u32), - ReaderError(RbspBitReaderError), + ReaderError(BitReaderError), /// log2_max_pic_order_cnt_lsb_minus4 must be between 0 and 12 Log2MaxPicOrderCntLsbMinus4OutOfRange(u32), /// num_ref_frames_in_pic_order_cnt_cycle must be between 0 and 255 NumRefFramesInPicOrderCntCycleOutOfRange(u32), } -impl From for PicOrderCntError { - fn from(e: RbspBitReaderError) -> Self { +impl From for PicOrderCntError { + fn from(e: BitReaderError) -> Self { PicOrderCntError::ReaderError(e) } } @@ -423,7 +424,7 @@ pub enum PicOrderCntType { TypeTwo } impl PicOrderCntType { - fn read(r: &mut RbspBitReader<'_>) -> Result { + fn read(r: &mut BitReader) -> Result { let pic_order_cnt_type = r.read_ue_named("pic_order_cnt_type")?; match pic_order_cnt_type { 0 => { @@ -448,7 +449,7 @@ impl PicOrderCntType { } } - fn read_log2_max_pic_order_cnt_lsb_minus4(r: &mut RbspBitReader<'_>) -> Result { + fn read_log2_max_pic_order_cnt_lsb_minus4(r: &mut BitReader) -> Result { let val = r.read_ue_named("log2_max_pic_order_cnt_lsb_minus4")?; if val > 12 { Err(PicOrderCntError::Log2MaxPicOrderCntLsbMinus4OutOfRange(val)) @@ -457,7 +458,7 @@ impl PicOrderCntType { } } - fn read_offsets_for_ref_frame(r: &mut RbspBitReader<'_>) -> Result, PicOrderCntError> { + fn read_offsets_for_ref_frame(r: &mut BitReader) -> Result, PicOrderCntError> { let num_ref_frames_in_pic_order_cnt_cycle = r.read_ue_named("num_ref_frames_in_pic_order_cnt_cycle")?; if num_ref_frames_in_pic_order_cnt_cycle > 255 { return Err(PicOrderCntError::NumRefFramesInPicOrderCntCycleOutOfRange(num_ref_frames_in_pic_order_cnt_cycle)); @@ -478,7 +479,7 @@ pub enum FrameMbsFlags { } } impl FrameMbsFlags { - fn read(r: &mut RbspBitReader<'_>) -> Result { + fn read(r: &mut BitReader) -> Result { let frame_mbs_only_flag = r.read_bool()?; if frame_mbs_only_flag { Ok(FrameMbsFlags::Frames) @@ -498,7 +499,7 @@ pub struct FrameCropping { pub bottom_offset: u32, } impl FrameCropping { - fn read(r: &mut RbspBitReader<'_>) -> Result, RbspBitReaderError> { + fn read(r: &mut BitReader) -> Result, BitReaderError> { let frame_cropping_flag = r.read_bool()?; Ok(if frame_cropping_flag { Some(FrameCropping { @@ -537,7 +538,7 @@ pub enum AspectRatioInfo { } impl AspectRatioInfo { - fn read(r: &mut RbspBitReader<'_>) -> Result, RbspBitReaderError> { + fn read(r: &mut BitReader) -> Result, BitReaderError> { let aspect_ratio_info_present_flag = r.read_bool()?; Ok(if aspect_ratio_info_present_flag { let aspect_ratio_idc = r.read_u8(8)?; @@ -609,7 +610,7 @@ pub enum OverscanAppropriate { Inappropriate, } impl OverscanAppropriate { - fn read(r: &mut RbspBitReader<'_>) -> Result { + fn read(r: &mut BitReader) -> Result { let overscan_info_present_flag = r.read_bool()?; Ok(if overscan_info_present_flag { let overscan_appropriate_flag = r.read_bool()?; @@ -656,7 +657,7 @@ pub struct ColourDescription { matrix_coefficients: u8, } impl ColourDescription { - fn read(r: &mut RbspBitReader<'_>) -> Result, RbspBitReaderError> { + fn read(r: &mut BitReader) -> Result, BitReaderError> { let colour_description_present_flag = r.read_bool()?; Ok(if colour_description_present_flag { Some(ColourDescription { @@ -677,7 +678,7 @@ pub struct VideoSignalType { colour_description: Option, } impl VideoSignalType { - fn read(r: &mut RbspBitReader<'_>) -> Result, RbspBitReaderError> { + fn read(r: &mut BitReader) -> Result, BitReaderError> { let video_signal_type_present_flag = r.read_bool()?; Ok(if video_signal_type_present_flag { Some(VideoSignalType { @@ -697,7 +698,7 @@ pub struct ChromaLocInfo { chroma_sample_loc_type_bottom_field: u32, } impl ChromaLocInfo { - fn read(r: &mut RbspBitReader<'_>) -> Result, RbspBitReaderError> { + fn read(r: &mut BitReader) -> Result, BitReaderError> { let chroma_loc_info_present_flag = r.read_bool()?; Ok(if chroma_loc_info_present_flag { Some(ChromaLocInfo { @@ -717,7 +718,7 @@ pub struct TimingInfo { pub fixed_frame_rate_flag: bool, } impl TimingInfo { - fn read(r: &mut RbspBitReader<'_>) -> Result, RbspBitReaderError> { + fn read(r: &mut BitReader) -> Result, BitReaderError> { let timing_info_present_flag = r.read_bool()?; Ok(if timing_info_present_flag { Some(TimingInfo { @@ -738,7 +739,7 @@ pub struct CpbSpec { cbr_flag: bool, } impl CpbSpec { - fn read(r: &mut RbspBitReader<'_>) -> Result { + fn read(r: &mut BitReader) -> Result { Ok(CpbSpec { bit_rate_value_minus1: r.read_ue_named("bit_rate_value_minus1")?, cpb_size_value_minus1: r.read_ue_named("cpb_size_value_minus1")?, @@ -759,7 +760,7 @@ pub struct HrdParameters { pub time_offset_length: u8, } impl HrdParameters { - fn read(r: &mut RbspBitReader<'_>, hrd_parameters_present: &mut bool) -> Result, SpsError> { + fn read(r: &mut BitReader, hrd_parameters_present: &mut bool) -> Result, SpsError> { let hrd_parameters_present_flag = r.read_bool_named("hrd_parameters_present_flag")?; *hrd_parameters_present |= hrd_parameters_present_flag; Ok(if hrd_parameters_present_flag { @@ -781,7 +782,7 @@ impl HrdParameters { None }) } - fn read_cpb_specs(r: &mut RbspBitReader<'_>, cpb_cnt: u32) -> Result,RbspBitReaderError> { + fn read_cpb_specs(r: &mut BitReader, cpb_cnt: u32) -> Result,BitReaderError> { let mut cpb_specs = Vec::with_capacity(cpb_cnt as usize); for _ in 0..cpb_cnt { cpb_specs.push(CpbSpec::read(r)?); @@ -801,7 +802,7 @@ pub struct BitstreamRestrictions { max_dec_frame_buffering: u32, } impl BitstreamRestrictions { - fn read(r: &mut RbspBitReader<'_>) -> Result,RbspBitReaderError> { + fn read(r: &mut BitReader) -> Result,BitReaderError> { let bitstream_restriction_flag = r.read_bool()?; Ok(if bitstream_restriction_flag { Some(BitstreamRestrictions { @@ -833,7 +834,7 @@ pub struct VuiParameters { pub bitstream_restrictions: Option, } impl VuiParameters { - fn read(r: &mut RbspBitReader<'_>) -> Result, SpsError> { + fn read(r: &mut BitReader) -> Result, SpsError> { let vui_parameters_present_flag = r.read_bool()?; Ok(if vui_parameters_present_flag { let mut hrd_parameters_present = false; @@ -875,7 +876,7 @@ pub struct SeqParameterSet { } impl SeqParameterSet { pub fn from_bytes(buf: &[u8]) -> Result { - let mut r = RbspBitReader::new(buf); + let mut r = BitReader::new(buf); let profile_idc = r.read_u8(8)?.into(); let sps = SeqParameterSet { profile_idc, @@ -897,7 +898,7 @@ impl SeqParameterSet { Ok(sps) } - fn read_log2_max_frame_num_minus4(r: &mut RbspBitReader<'_>) -> Result { + fn read_log2_max_frame_num_minus4(r: &mut BitReader) -> Result { let val = r.read_ue_named("log2_max_frame_num_minus4")?; if val > 12 { Err(SpsError::Log2MaxFrameNumMinus4OutOfRange(val)) diff --git a/src/rbsp.rs b/src/rbsp.rs index 6cb4c79..f0c1c5c 100644 --- a/src/rbsp.rs +++ b/src/rbsp.rs @@ -185,14 +185,8 @@ pub fn decode_nal<'a>(nal_unit: &'a [u8]) -> Cow<'a, [u8]> { decoder.into_handler().data } -impl From for RbspBitReaderError { - fn from(e: std::io::Error) -> Self { - RbspBitReaderError::ReaderError(e) - } -} - #[derive(Debug)] -pub enum RbspBitReaderError { +pub enum BitReaderError { ReaderError(std::io::Error), ReaderErrorFor(&'static str, std::io::Error), @@ -200,17 +194,17 @@ pub enum RbspBitReaderError { ExpGolombTooLarge(&'static str), } -pub struct RbspBitReader<'buf> { - reader: bitstream_io::read::BitReader, bitstream_io::BigEndian>, +/// Reads H.264 bitstream syntax elements from an RBSP representation (no NAL +/// header byte or emulation prevention three bytes). +pub struct BitReader { + reader: bitstream_io::read::BitReader, } -impl<'buf> RbspBitReader<'buf> { - pub fn new(buf: &'buf [u8]) -> Self { - RbspBitReader { - reader: bitstream_io::read::BitReader::new(std::io::Cursor::new(buf)), - } +impl BitReader { + pub fn new(inner: R) -> Self { + Self { reader: bitstream_io::read::BitReader::new(inner) } } - pub fn read_ue_named(&mut self, name: &'static str) -> Result { + pub fn read_ue_named(&mut self, name: &'static str) -> Result { let count = count_zero_bits(&mut self.reader, name)?; if count > 0 { let val = self.read_u32(count)?; @@ -220,37 +214,39 @@ impl<'buf> RbspBitReader<'buf> { } } - pub fn read_se_named(&mut self, name: &'static str) -> Result { + pub fn read_se_named(&mut self, name: &'static str) -> Result { Ok(Self::golomb_to_signed(self.read_ue_named(name)?)) } - pub fn read_bool(&mut self) -> Result { - self.reader.read_bit().map_err( |e| RbspBitReaderError::ReaderError(e) ) + pub fn read_bool(&mut self) -> Result { + self.reader.read_bit().map_err( |e| BitReaderError::ReaderError(e) ) } - pub fn read_bool_named(&mut self, name: &'static str) -> Result { - self.reader.read_bit().map_err( |e| RbspBitReaderError::ReaderErrorFor(name, e) ) + pub fn read_bool_named(&mut self, name: &'static str) -> Result { + self.reader.read_bit().map_err( |e| BitReaderError::ReaderErrorFor(name, e) ) } - pub fn read_u8(&mut self, bit_count: u32) -> Result { - self.reader.read(u32::from(bit_count)).map_err( |e| RbspBitReaderError::ReaderError(e) ) + pub fn read_u8(&mut self, bit_count: u32) -> Result { + self.reader.read(u32::from(bit_count)).map_err(BitReaderError::ReaderError) } - pub fn read_u16(&mut self, bit_count: u8) -> Result { - self.reader.read(u32::from(bit_count)).map_err( |e| RbspBitReaderError::ReaderError(e) ) + pub fn read_u16(&mut self, bit_count: u8) -> Result { + self.reader.read(u32::from(bit_count)).map_err(BitReaderError::ReaderError) } - pub fn read_u32(&mut self, bit_count: u8) -> Result { - self.reader.read(u32::from(bit_count)).map_err( |e| RbspBitReaderError::ReaderError(e) ) + pub fn read_u32(&mut self, bit_count: u8) -> Result { + self.reader.read(u32::from(bit_count)).map_err(BitReaderError::ReaderError) } - pub fn read_i32(&mut self, bit_count: u8) -> Result { - self.reader.read(u32::from(bit_count)).map_err( |e| RbspBitReaderError::ReaderError(e) ) + pub fn read_i32(&mut self, bit_count: u8) -> Result { + self.reader.read(u32::from(bit_count)).map_err(BitReaderError::ReaderError) } - pub fn has_more_rbsp_data(&mut self) -> bool { + pub fn has_more_rbsp_data(&mut self, name: &'static str) -> Result { // BitReader returns its reader iff at an aligned position. - self.reader.reader().map(|r| (r.position() as usize) < r.get_ref().len()).unwrap_or(true) + self.reader.reader().map(|r| { + r.fill_buf().map(|b| !b.is_empty()).map_err(|e| BitReaderError::ReaderErrorFor(name, e)) + }).unwrap_or(Ok(true)) } fn golomb_to_signed(val: u32) -> i32 { @@ -258,12 +254,12 @@ impl<'buf> RbspBitReader<'buf> { ((val >> 1) as i32 + (val & 0x1) as i32) * sign } } -fn count_zero_bits(r: &mut R, name: &'static str) -> Result { +fn count_zero_bits(r: &mut R, name: &'static str) -> Result { let mut count = 0; - while !r.read_bit()? { + while !r.read_bit().map_err(|e| BitReaderError::ReaderErrorFor(name, e))? { count += 1; if count > 31 { - return Err(RbspBitReaderError::ExpGolombTooLarge(name)); + return Err(BitReaderError::ExpGolombTooLarge(name)); } } Ok(count) @@ -365,15 +361,15 @@ mod tests { #[test] fn bitreader_has_more_data() { - let mut reader = RbspBitReader::new(&[0x12, 0x34]); - assert!(reader.has_more_rbsp_data()); + let mut reader = BitReader::new(&[0x12, 0x34][..]); + assert!(reader.has_more_rbsp_data("1").unwrap()); assert_eq!(reader.read_u8(4).unwrap(), 0x1); - assert!(reader.has_more_rbsp_data()); // unaligned, backing reader not at EOF + assert!(reader.has_more_rbsp_data("2").unwrap()); // unaligned, backing reader not at EOF assert_eq!(reader.read_u8(4).unwrap(), 0x2); - assert!(reader.has_more_rbsp_data()); // aligned, backing reader not at EOF + assert!(reader.has_more_rbsp_data("3").unwrap()); // aligned, backing reader not at EOF assert_eq!(reader.read_u8(4).unwrap(), 0x3); - assert!(reader.has_more_rbsp_data()); // unaligned, backing reader at EOF + assert!(reader.has_more_rbsp_data("4").unwrap()); // unaligned, backing reader at EOF assert_eq!(reader.read_u8(4).unwrap(), 0x4); - assert!(!reader.has_more_rbsp_data()); // aligned, backing reader at EOF + assert!(!reader.has_more_rbsp_data("eof").unwrap()); // aligned, backing reader at EOF } }