Skip to content

Commit

Permalink
s/rbsp::RbspBitReader<'a>/rbsp::BitReader<R>/
Browse files Browse the repository at this point in the history
For dholroyd#4: make the bit reader type take a BufRead
rather than a slice so we don't have to keep a buffered copy of the
RBSP. While I'm at it, reduce "stuttering" by taking the module name out
of the struct name.

This is intrusive but mechanical.
  • Loading branch information
scottlamb committed Jun 11, 2021
1 parent fd0367f commit 86ac91d
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 116 deletions.
2 changes: 1 addition & 1 deletion fuzz/fuzz_targets/fuzz_target_1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl h264_reader::nal::NalHandler for SliceFuzz {
decode.push(ctx, &current_slice.buf[..]);
decode.end(ctx);
let capture = decode.into_handler();
let mut r = rbsp::RbspBitReader::new(&capture.buf[1..]);
let mut r = rbsp::BitReader::new(&capture.buf[1..]);
match nal::slice::SliceHeader::read(ctx, &mut r, current_slice.header) {
Ok((header, sps, pps)) => {
println!("{:#?}", header);
Expand Down
29 changes: 15 additions & 14 deletions src/nal/pps.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use super::NalHandler;
use super::NalHeader;
use super::sps;
use std::io::BufRead;
use std::marker;
use crate::{rbsp, Context};
use crate::rbsp::RbspBitReader;
use crate::rbsp::BitReader;
use log::*;

#[derive(Debug)]
pub enum PpsError {
RbspReaderError(rbsp::RbspBitReaderError),
RbspReaderError(rbsp::BitReaderError),
InvalidSliceGroupMapType(u32),
InvalidSliceGroupChangeType(u32),
UnknownSeqParamSetId(ParamSetId),
Expand All @@ -17,8 +18,8 @@ pub enum PpsError {
ScalingMatrix(sps::ScalingMatrixError),
}

impl From<rbsp::RbspBitReaderError> for PpsError {
fn from(e: rbsp::RbspBitReaderError) -> Self {
impl From<rbsp::BitReaderError> for PpsError {
fn from(e: rbsp::BitReaderError) -> Self {
PpsError::RbspReaderError(e)
}
}
Expand Down Expand Up @@ -46,7 +47,7 @@ pub struct SliceRect {
bottom_right: u32,
}
impl SliceRect {
fn read(r: &mut RbspBitReader<'_>) -> Result<SliceRect,PpsError> {
fn read<R: BufRead>(r: &mut BitReader<R>) -> Result<SliceRect,PpsError> {
Ok(SliceRect {
top_left: r.read_ue_named("top_left")?,
bottom_right: r.read_ue_named("bottom_right")?,
Expand Down Expand Up @@ -77,7 +78,7 @@ pub enum SliceGroup {
},
}
impl SliceGroup {
fn read(r: &mut RbspBitReader<'_>, num_slice_groups_minus1: u32) -> Result<SliceGroup,PpsError> {
fn read<R: BufRead>(r: &mut BitReader<R>, num_slice_groups_minus1: u32) -> Result<SliceGroup,PpsError> {
let slice_group_map_type = r.read_ue_named("slice_group_map_type")?;
match slice_group_map_type {
0 => Ok(SliceGroup::Interleaved {
Expand All @@ -103,23 +104,23 @@ impl SliceGroup {
}
}

fn read_run_lengths(r: &mut RbspBitReader<'_>, num_slice_groups_minus1: u32) -> Result<Vec<u32>,PpsError> {
fn read_run_lengths<R: BufRead>(r: &mut BitReader<R>, num_slice_groups_minus1: u32) -> Result<Vec<u32>,PpsError> {
let mut run_length_minus1 = Vec::with_capacity(num_slice_groups_minus1 as usize + 1);
for _ in 0..num_slice_groups_minus1+1 {
run_length_minus1.push(r.read_ue_named("run_length_minus1")?);
}
Ok(run_length_minus1)
}

fn read_rectangles(r: &mut RbspBitReader<'_>, num_slice_groups_minus1: u32) -> Result<Vec<SliceRect>,PpsError> {
fn read_rectangles<R: BufRead>(r: &mut BitReader<R>, num_slice_groups_minus1: u32) -> Result<Vec<SliceRect>,PpsError> {
let mut run_length_minus1 = Vec::with_capacity(num_slice_groups_minus1 as usize + 1);
for _ in 0..num_slice_groups_minus1+1 {
run_length_minus1.push(SliceRect::read(r)?);
}
Ok(run_length_minus1)
}

fn read_group_ids(r: &mut RbspBitReader<'_>, num_slice_groups_minus1: u32) -> Result<Vec<u32>,PpsError> {
fn read_group_ids<R: BufRead>(r: &mut BitReader<R>, num_slice_groups_minus1: u32) -> Result<Vec<u32>,PpsError> {
let pic_size_in_map_units_minus1 = r.read_ue_named("pic_size_in_map_units_minus1")?;
// TODO: avoid any panics due to failed conversions
let size = ((1f64+f64::from(pic_size_in_map_units_minus1)).log2()) as u8;
Expand All @@ -136,7 +137,7 @@ struct PicScalingMatrix {
// TODO
}
impl PicScalingMatrix {
fn read(r: &mut RbspBitReader<'_>, sps: &sps::SeqParameterSet, transform_8x8_mode_flag: bool) -> Result<Option<PicScalingMatrix>,PpsError> {
fn read<R: BufRead>(r: &mut BitReader<R>, sps: &sps::SeqParameterSet, transform_8x8_mode_flag: bool) -> Result<Option<PicScalingMatrix>,PpsError> {
let pic_scaling_matrix_present_flag = r.read_bool()?;
Ok(if pic_scaling_matrix_present_flag {
let mut scaling_list4x4 = vec!();
Expand Down Expand Up @@ -171,8 +172,8 @@ pub struct PicParameterSetExtra {
second_chroma_qp_index_offset: i32,
}
impl PicParameterSetExtra {
fn read(r: &mut RbspBitReader<'_>, sps: &sps::SeqParameterSet) -> Result<Option<PicParameterSetExtra>,PpsError> {
Ok(if r.has_more_rbsp_data() {
fn read<R: BufRead>(r: &mut BitReader<R>, sps: &sps::SeqParameterSet) -> Result<Option<PicParameterSetExtra>,PpsError> {
Ok(if r.has_more_rbsp_data("transform_8x8_mode_flag")? {
let transform_8x8_mode_flag = r.read_bool()?;
Some(PicParameterSetExtra {
transform_8x8_mode_flag,
Expand Down Expand Up @@ -226,7 +227,7 @@ pub struct PicParameterSet {
}
impl PicParameterSet {
pub fn from_bytes<Ctx>(ctx: &Context<Ctx>, buf: &[u8]) -> Result<PicParameterSet, PpsError> {
let mut r = RbspBitReader::new(buf);
let mut r = BitReader::new(buf);
let pic_parameter_set_id = ParamSetId::from_u32(r.read_ue_named("pic_parameter_set_id")?)
.map_err(PpsError::BadPicParamSetId)?;
let seq_parameter_set_id = ParamSetId::from_u32(r.read_ue_named("seq_parameter_set_id")?)
Expand All @@ -253,7 +254,7 @@ impl PicParameterSet {
})
}

fn read_slice_groups(r: &mut RbspBitReader<'_>) -> Result<Option<SliceGroup>,PpsError> {
fn read_slice_groups<R: BufRead>(r: &mut BitReader<R>) -> Result<Option<SliceGroup>,PpsError> {
let num_slice_groups_minus1 = r.read_ue_named("num_slice_groups_minus1")?;
Ok(if num_slice_groups_minus1 > 0 {
Some(SliceGroup::read(r, num_slice_groups_minus1)?)
Expand Down
15 changes: 8 additions & 7 deletions src/nal/sei/buffering_period.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
use super::SeiCompletePayloadReader;
use std::io::BufRead;
use std::marker;
use crate::nal::pps;
use crate::rbsp::RbspBitReader;
use crate::rbsp::BitReader;
use crate::Context;
use crate::nal::sei::HeaderType;
use crate::rbsp::RbspBitReaderError;
use crate::rbsp::BitReaderError;
use log::*;

#[derive(Debug)]
enum BufferingPeriodError {
ReaderError(RbspBitReaderError),
ReaderError(BitReaderError),
UndefinedSeqParamSetId(pps::ParamSetId),
InvalidSeqParamSetId(pps::ParamSetIdError),
}
impl From<RbspBitReaderError> for BufferingPeriodError {
fn from(e: RbspBitReaderError) -> Self {
impl From<BitReaderError> for BufferingPeriodError {
fn from(e: BitReaderError) -> Self {
BufferingPeriodError::ReaderError(e)
}
}
Expand All @@ -30,7 +31,7 @@ struct InitialCpbRemoval {
initial_cpb_removal_delay_offset: u32,
}

fn read_cpb_removal_delay_list(r: &mut RbspBitReader<'_>, count: usize, length: u8) -> Result<Vec<InitialCpbRemoval>,RbspBitReaderError> {
fn read_cpb_removal_delay_list<R: BufRead>(r: &mut BitReader<R>, count: usize, length: u8) -> Result<Vec<InitialCpbRemoval>,BitReaderError> {
let mut res = vec!();
for _ in 0..count {
res.push(InitialCpbRemoval {
Expand All @@ -48,7 +49,7 @@ struct BufferingPeriod {
}
impl BufferingPeriod {
fn read<Ctx>(ctx: &Context<Ctx>, buf: &[u8]) -> Result<BufferingPeriod,BufferingPeriodError> {
let mut r = RbspBitReader::new(buf);
let mut r = BitReader::new(buf);
let seq_parameter_set_id = pps::ParamSetId::from_u32(r.read_ue_named("seq_parameter_set_id")?)?;
match ctx.sps_by_id(seq_parameter_set_id) {
None => Err(BufferingPeriodError::UndefinedSeqParamSetId(seq_parameter_set_id)),
Expand Down
22 changes: 12 additions & 10 deletions src/nal/sei/pic_timing.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::io::BufRead;

use crate::nal::sei::SeiCompletePayloadReader;
use crate::Context;
use crate::nal::sei::HeaderType;
use crate::nal::pps::ParamSetId;
use crate::rbsp::RbspBitReader;
use crate::rbsp::BitReader;
use crate::nal::sps;
use crate::rbsp::RbspBitReaderError;
use crate::rbsp::BitReaderError;
use log::*;

// FIXME: SPS selection
Expand All @@ -14,12 +16,12 @@ use log::*;

#[derive(Debug)]
pub enum PicTimingError {
RbspError(RbspBitReaderError),
RbspError(BitReaderError),
UndefinedSeqParamSetId(ParamSetId),
InvalidPicStructId(u8),
}
impl From<RbspBitReaderError> for PicTimingError {
fn from(e: RbspBitReaderError) -> Self {
impl From<BitReaderError> for PicTimingError {
fn from(e: BitReaderError) -> Self {
PicTimingError::RbspError(e)
}
}
Expand Down Expand Up @@ -175,7 +177,7 @@ pub struct ClockTimestamp {
pub time_offset: Option<i32>,
}
impl ClockTimestamp {
fn read(r: &mut RbspBitReader<'_>, sps: &sps::SeqParameterSet) -> Result<ClockTimestamp, PicTimingError> {
fn read<R: BufRead>(r: &mut BitReader<R>, sps: &sps::SeqParameterSet) -> Result<ClockTimestamp, PicTimingError> {
let ct_type = CtType::from_id(r.read_u8(2)?);
let nuit_field_based_flag = r.read_bool_named("nuit_field_based_flag")?;
let counting_type = CountingType::from_id(r.read_u8(5)?);
Expand Down Expand Up @@ -247,7 +249,7 @@ pub struct PicTiming {
}
impl PicTiming {
pub fn read<Ctx>(ctx: &mut Context<Ctx>, buf: &[u8]) -> Result<PicTiming, PicTimingError> {
let mut r = RbspBitReader::new(buf);
let mut r = BitReader::new(buf);
let seq_parameter_set_id = ParamSetId::from_u32(0).unwrap();
match ctx.sps_by_id(seq_parameter_set_id) {
None => Err(PicTimingError::UndefinedSeqParamSetId(seq_parameter_set_id)),
Expand All @@ -260,7 +262,7 @@ impl PicTiming {
}
}

fn read_delays(r: &mut RbspBitReader<'_>, sps: &sps::SeqParameterSet) -> Result<Option<Delays>,PicTimingError> {
fn read_delays<R: BufRead>(r: &mut BitReader<R>, sps: &sps::SeqParameterSet) -> Result<Option<Delays>,PicTimingError> {
Ok(if let Some(ref vui_params) = sps.vui_parameters {
if let Some(ref hrd) = vui_params.nal_hrd_parameters.as_ref().or_else(|| vui_params.nal_hrd_parameters.as_ref() ) {
Some(Delays {
Expand All @@ -275,7 +277,7 @@ impl PicTiming {
})
}

fn read_pic_struct(r: &mut RbspBitReader<'_>, sps: &sps::SeqParameterSet) -> Result<Option<PicStruct>,PicTimingError> {
fn read_pic_struct<R: BufRead>(r: &mut BitReader<R>, sps: &sps::SeqParameterSet) -> Result<Option<PicStruct>,PicTimingError> {
Ok(if let Some(ref vui_params) = sps.vui_parameters {
if vui_params.pic_struct_present_flag {
let pic_struct = PicStructType::from_id(r.read_u8(4)?)?;
Expand All @@ -293,7 +295,7 @@ impl PicTiming {
})
}

fn read_clock_timestamps(r: &mut RbspBitReader<'_>, pic_struct: &PicStructType, sps: &sps::SeqParameterSet) -> Result<Vec<Option<ClockTimestamp>>,PicTimingError> {
fn read_clock_timestamps<R: BufRead>(r: &mut BitReader<R>, pic_struct: &PicStructType, sps: &sps::SeqParameterSet) -> Result<Vec<Option<ClockTimestamp>>,PicTimingError> {
let mut res = Vec::new();
for _ in 0..pic_struct.num_clock_timestamps() {
res.push(if r.read_bool_named("clock_timestamp_flag")? {
Expand Down
23 changes: 12 additions & 11 deletions src/nal/slice/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@

use crate::Context;
use crate::rbsp::RbspBitReader;
use crate::rbsp::RbspBitReaderError;
use crate::rbsp::BitReader;
use crate::rbsp::BitReaderError;
use crate::nal::pps::{ParamSetId, PicParameterSet};
use crate::nal::pps;
use crate::nal::sps;
use std::io::BufRead;
use std::marker;
use crate::nal::sps::SeqParameterSet;
use crate::nal::NalHeader;
Expand Down Expand Up @@ -56,7 +57,7 @@ impl SliceType {

#[derive(Debug)]
pub enum SliceHeaderError {
RbspError(RbspBitReaderError),
RbspError(BitReaderError),
InvalidSliceType(u32),
InvalidSeqParamSetId(pps::ParamSetIdError),
UndefinedPicParamSetId(pps::ParamSetId),
Expand All @@ -72,8 +73,8 @@ pub enum SliceHeaderError {
/// The header contained syntax elements that the parser isn't able to handle yet
UnsupportedSyntax(&'static str),
}
impl From<RbspBitReaderError> for SliceHeaderError {
fn from(e: RbspBitReaderError) -> Self {
impl From<BitReaderError> for SliceHeaderError {
fn from(e: BitReaderError) -> Self {
SliceHeaderError::RbspError(e)
}
}
Expand Down Expand Up @@ -163,7 +164,7 @@ enum RefPicListModifications {
},
}
impl RefPicListModifications {
fn read(slice_family: &SliceFamily, r: &mut RbspBitReader<'_>) -> Result<RefPicListModifications, SliceHeaderError> {
fn read<R: BufRead>(slice_family: &SliceFamily, r: &mut BitReader<R>) -> Result<RefPicListModifications, SliceHeaderError> {
Ok(match slice_family {
SliceFamily::I | SliceFamily::SI => RefPicListModifications::I,
SliceFamily::B => RefPicListModifications::B {
Expand All @@ -176,7 +177,7 @@ impl RefPicListModifications {
})
}

fn read_list(r: &mut RbspBitReader<'_>) -> Result<Vec<ModificationOfPicNums>, SliceHeaderError> {
fn read_list<R: BufRead>(r: &mut BitReader<R>) -> Result<Vec<ModificationOfPicNums>, SliceHeaderError> {
let mut result = vec![];
// either ref_pic_list_modification_flag_l0 or ref_pic_list_modification_flag_l1 depending
// on call-site,
Expand Down Expand Up @@ -209,7 +210,7 @@ struct PredWeightTable {
chroma_weights: Vec<Vec<PredWeight>>,
}
impl PredWeightTable {
fn read(r: &mut RbspBitReader<'_>, slice_type: &SliceType, pps: &pps::PicParameterSet, sps: &sps::SeqParameterSet, num_ref_active: &Option<NumRefIdxActive>) -> Result<PredWeightTable, SliceHeaderError> {
fn read<R: BufRead>(r: &mut BitReader<R>, slice_type: &SliceType, pps: &pps::PicParameterSet, sps: &sps::SeqParameterSet, num_ref_active: &Option<NumRefIdxActive>) -> Result<PredWeightTable, SliceHeaderError> {
let chroma_array_type = if sps.chroma_info.separate_colour_plane_flag {
// TODO: "Otherwise (separate_colour_plane_flag is equal to 1), ChromaArrayType is
// set equal to 0." ...does this mean ChromaFormat::Monochrome then?
Expand Down Expand Up @@ -295,7 +296,7 @@ enum DecRefPicMarking {
Adaptive(Vec<MemoryManagementControlOperation>),
}
impl DecRefPicMarking {
fn read(r: &mut RbspBitReader<'_>, header: NalHeader) -> Result<DecRefPicMarking, SliceHeaderError> {
fn read<R: BufRead>(r: &mut BitReader<R>, header: NalHeader) -> Result<DecRefPicMarking, SliceHeaderError> {
Ok(if header.nal_unit_type() == crate::nal::UnitType::SliceLayerWithoutPartitioningIdr {
DecRefPicMarking::Idr {
no_output_of_prior_pics_flag: r.read_bool_named("no_output_of_prior_pics_flag")?,
Expand Down Expand Up @@ -363,7 +364,7 @@ pub struct SliceHeader {
disable_deblocking_filter_idc: u8,
}
impl SliceHeader {
pub fn read<'a, Ctx>(ctx: &'a mut Context<Ctx>, r: &mut RbspBitReader<'_>, header: NalHeader) -> Result<(SliceHeader, &'a SeqParameterSet, &'a PicParameterSet), SliceHeaderError> {
pub fn read<'a, Ctx, R: BufRead>(ctx: &'a mut Context<Ctx>, r: &mut BitReader<R>, header: NalHeader) -> Result<(SliceHeader, &'a SeqParameterSet, &'a PicParameterSet), SliceHeaderError> {
let first_mb_in_slice = r.read_ue_named("first_mb_in_slice")?;
let slice_type = SliceType::from_id(r.read_ue_named("slice_type")?)?;
let pic_parameter_set_id = ParamSetId::from_u32(r.read_ue_named("pic_parameter_set_id")?)?;
Expand Down Expand Up @@ -543,7 +544,7 @@ impl<Ctx> super::NalHandler for SliceLayerWithoutPartitioningRbsp<Ctx> {
match self.state {
ParseState::Unstarted => panic!("start() not yet called"),
ParseState::Start(header) => {
let mut r = RbspBitReader::new(buf);
let mut r = BitReader::new(buf);
match SliceHeader::read(ctx, &mut r, header) {
Ok(header) => info!("TODO: expose to caller: {:#?}", header),
Err(e) => error!("slice_header() error: SliceHeaderError::{:?}", e),
Expand Down
Loading

0 comments on commit 86ac91d

Please sign in to comment.