diff --git a/iamf/cli/BUILD b/iamf/cli/BUILD index 678fde8..f2ad02f 100644 --- a/iamf/cli/BUILD +++ b/iamf/cli/BUILD @@ -447,10 +447,12 @@ cc_library( "//iamf/cli/proto:audio_frame_cc_proto", "//iamf/common:macros", "//iamf/common:obu_util", + "//iamf/obu:codec_config", "//iamf/obu:types", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], diff --git a/iamf/cli/encoder_main_lib.cc b/iamf/cli/encoder_main_lib.cc index f1ca845..6d0cad9 100644 --- a/iamf/cli/encoder_main_lib.cc +++ b/iamf/cli/encoder_main_lib.cc @@ -169,9 +169,12 @@ absl::Status GenerateObus( ia_sequence_header_obu, codec_config_obus, audio_elements, mix_presentation_obus)); - WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata()); - RETURN_IF_NOT_OK( - wav_sample_provider.Initialize(input_wav_directory, audio_elements)); + auto wav_sample_provider = + WavSampleProvider::Create(user_metadata.audio_frame_metadata(), + input_wav_directory, audio_elements); + if (!wav_sample_provider.ok()) { + return wav_sample_provider.status(); + } // Parameter blocks. 
TimeParameterBlockMetadataMap time_parameter_block_metadata; @@ -197,7 +200,7 @@ absl::Status GenerateObus( absl::flat_hash_map id_to_labeled_samples; bool no_more_real_samples = false; RETURN_IF_NOT_OK(CollectLabeledSamplesForAudioElements( - audio_elements, wav_sample_provider, id_to_labeled_samples, + audio_elements, *wav_sample_provider, id_to_labeled_samples, no_more_real_samples)); for (const auto& [audio_element_id, labeled_samples] : diff --git a/iamf/cli/tests/BUILD b/iamf/cli/tests/BUILD index dfe7dbf..01809ae 100644 --- a/iamf/cli/tests/BUILD +++ b/iamf/cli/tests/BUILD @@ -429,6 +429,8 @@ cc_test( "//iamf/cli:channel_label", "//iamf/cli:demixing_module", "//iamf/cli:wav_sample_provider", + "//iamf/cli/proto:audio_element_cc_proto", + "//iamf/cli/proto:audio_frame_cc_proto", "//iamf/cli/proto:user_metadata_cc_proto", "//iamf/obu:codec_config", "//iamf/obu:types", diff --git a/iamf/cli/tests/wav_sample_provider_test.cc b/iamf/cli/tests/wav_sample_provider_test.cc index 59fccf2..73afe3b 100644 --- a/iamf/cli/tests/wav_sample_provider_test.cc +++ b/iamf/cli/tests/wav_sample_provider_test.cc @@ -25,6 +25,8 @@ #include "iamf/cli/audio_element_with_data.h" #include "iamf/cli/channel_label.h" #include "iamf/cli/demixing_module.h" +#include "iamf/cli/proto/audio_element.pb.h" +#include "iamf/cli/proto/audio_frame.pb.h" #include "iamf/cli/proto/user_metadata.pb.h" #include "iamf/cli/tests/cli_test_utils.h" #include "iamf/obu/codec_config.h" @@ -43,21 +45,27 @@ static constexpr DecodedUleb128 kAudioElementId = 300; static constexpr DecodedUleb128 kCodecConfigId = 200; static constexpr uint32_t kSampleRate = 48000; -void InitializeTestData( - const uint32_t sample_rate, - iamf_tools_cli_proto::UserMetadata& user_metadata, - absl::flat_hash_map& audio_elements) { +void FillStereoDataForAudioElementId( + uint32_t audio_element_id, + iamf_tools_cli_proto::AudioFrameObuMetadata& audio_frame_metadata) { ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( 
wav_filename: "stereo_8_samples_48khz_s16le.wav" samples_to_trim_at_end: 0 samples_to_trim_at_start: 0 - audio_element_id: 300 channel_ids: [ 0, 1 ] channel_labels: [ "L2", "R2" ] )pb", - user_metadata.add_audio_frame_metadata())); + &audio_frame_metadata)); + audio_frame_metadata.set_audio_element_id(audio_element_id); +} +void InitializeTestData( + const uint32_t sample_rate, + iamf_tools_cli_proto::UserMetadata& user_metadata, + absl::flat_hash_map& audio_elements) { + FillStereoDataForAudioElementId(kAudioElementId, + *user_metadata.add_audio_frame_metadata()); static absl::flat_hash_map codec_config_obus; codec_config_obus.clear(); AddLpcmCodecConfigWithIdAndSampleRate(kCodecConfigId, sample_rate, @@ -75,30 +83,68 @@ std::string GetInputWavDir() { return input_wav_dir; } -TEST(Initialize, SucceedsForStereoInput) { +TEST(Create, SucceedsForStereoInput) { iamf_tools_cli_proto::UserMetadata user_metadata; absl::flat_hash_map audio_elements; InitializeTestData(kSampleRate, user_metadata, audio_elements); - WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata()); - EXPECT_THAT(wav_sample_provider.Initialize(GetInputWavDir(), audio_elements), + EXPECT_THAT(WavSampleProvider::Create(user_metadata.audio_frame_metadata(), + GetInputWavDir(), audio_elements), IsOk()); } -TEST(Initialize, FailsForUnknownLabels) { +TEST(Create, FailsWhenUserMetadataContainsDuplicateAudioElementIds) { + iamf_tools_cli_proto::UserMetadata user_metadata; + absl::flat_hash_map audio_elements; + InitializeTestData(kSampleRate, user_metadata, audio_elements); + FillStereoDataForAudioElementId(kAudioElementId, + *user_metadata.add_audio_frame_metadata()); + + EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(), + GetInputWavDir(), audio_elements) + .ok()); +} + +TEST(Create, FailsWhenMatchingAudioElementObuIsMissing) { + iamf_tools_cli_proto::UserMetadata user_metadata; + const absl::flat_hash_map + kNoAudioElements = {}; + 
FillStereoDataForAudioElementId(kAudioElementId, + *user_metadata.add_audio_frame_metadata()); + + EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(), + GetInputWavDir(), kNoAudioElements) + .ok()); +} + +TEST(Create, FailsWhenCodecConfigIsMissing) { + iamf_tools_cli_proto::UserMetadata user_metadata; + absl::flat_hash_map audio_elements; + InitializeTestData(kSampleRate, user_metadata, audio_elements); + + // Corrupt the audio element by clearing the codec config pointer. + ASSERT_TRUE(audio_elements.contains(kAudioElementId)); + audio_elements.at(kAudioElementId).codec_config = nullptr; + + EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(), + GetInputWavDir(), audio_elements) + .ok()); +} + +TEST(Create, FailsForUnknownLabels) { constexpr absl::string_view kUnknownLabel = "unknown_label"; iamf_tools_cli_proto::UserMetadata user_metadata; absl::flat_hash_map audio_elements; InitializeTestData(kSampleRate, user_metadata, audio_elements); user_metadata.mutable_audio_frame_metadata(0)->set_channel_labels( 0, kUnknownLabel); - WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata()); - EXPECT_FALSE( - wav_sample_provider.Initialize(GetInputWavDir(), audio_elements).ok()); + EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(), + GetInputWavDir(), audio_elements) + .ok()); } -TEST(Initialize, SucceedsForDuplicateChannelIds) { +TEST(Create, SucceedsForDuplicateChannelIds) { constexpr uint32_t kDuplicateChannelId = 0; iamf_tools_cli_proto::UserMetadata user_metadata; absl::flat_hash_map audio_elements; @@ -107,13 +153,13 @@ TEST(Initialize, SucceedsForDuplicateChannelIds) { 0, kDuplicateChannelId); user_metadata.mutable_audio_frame_metadata(0)->set_channel_ids( 1, kDuplicateChannelId); - WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata()); - EXPECT_THAT(wav_sample_provider.Initialize(GetInputWavDir(), audio_elements), + 
EXPECT_THAT(WavSampleProvider::Create(user_metadata.audio_frame_metadata(), + GetInputWavDir(), audio_elements), IsOk()); } -TEST(Initialize, FailsForDuplicateChannelLabels) { +TEST(Create, FailsForDuplicateChannelLabels) { constexpr absl::string_view kDuplicateLabel = "L2"; iamf_tools_cli_proto::UserMetadata user_metadata; absl::flat_hash_map audio_elements; @@ -122,13 +168,13 @@ TEST(Initialize, FailsForDuplicateChannelLabels) { 0, kDuplicateLabel); user_metadata.mutable_audio_frame_metadata(0)->set_channel_labels( 1, kDuplicateLabel); - WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata()); - EXPECT_FALSE( - wav_sample_provider.Initialize(GetInputWavDir(), audio_elements).ok()); + EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(), + GetInputWavDir(), audio_elements) + .ok()); } -TEST(Initialize, FailsForChannelIdTooLarge) { +TEST(Create, FailsForChannelIdTooLarge) { iamf_tools_cli_proto::UserMetadata user_metadata; absl::flat_hash_map audio_elements; InitializeTestData(kSampleRate, user_metadata, audio_elements); @@ -138,27 +184,27 @@ TEST(Initialize, FailsForChannelIdTooLarge) { constexpr uint32_t kChannelIdTooLargeForStereoWavFile = 2; user_metadata.mutable_audio_frame_metadata(0)->mutable_channel_ids()->Set( kFirstChannelIndex, kChannelIdTooLargeForStereoWavFile); - WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata()); - EXPECT_FALSE( - wav_sample_provider.Initialize(GetInputWavDir(), audio_elements).ok()); + EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(), + GetInputWavDir(), audio_elements) + .ok()); } -TEST(Initialize, FailsForDifferentSizedChannelIdsAndChannelLabels) { +TEST(Create, FailsForDifferentSizedChannelIdsAndChannelLabels) { iamf_tools_cli_proto::UserMetadata user_metadata; absl::flat_hash_map audio_elements; InitializeTestData(kSampleRate, user_metadata, audio_elements); // Add one extra channel label, which does not have a corresponding - // 
channel ID, causing the `Initialize()` to fail. + // channel ID, causing the `Create()` to fail. user_metadata.mutable_audio_frame_metadata(0)->add_channel_labels("C"); - WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata()); - EXPECT_FALSE( - wav_sample_provider.Initialize(GetInputWavDir(), audio_elements).ok()); + EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(), + GetInputWavDir(), audio_elements) + .ok()); } -TEST(Initialize, FailsForBitDepthLowerThanFile) { +TEST(Create, FailsForBitDepthLowerThanFile) { iamf_tools_cli_proto::UserMetadata user_metadata; absl::flat_hash_map audio_elements; InitializeTestData(kSampleRate, user_metadata, audio_elements); @@ -167,13 +213,12 @@ TEST(Initialize, FailsForBitDepthLowerThanFile) { // The `Initialize()` would refuse to lower the bit depth and fail. user_metadata.mutable_audio_frame_metadata(0)->set_wav_filename( "stereo_8_samples_48khz_s24le.wav"); - - WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata()); - EXPECT_FALSE( - wav_sample_provider.Initialize(GetInputWavDir(), audio_elements).ok()); + EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(), + GetInputWavDir(), audio_elements) + .ok()); } -TEST(Initialize, FailsForMismatchingSampleRates) { +TEST(Create, FailsForMismatchingSampleRates) { iamf_tools_cli_proto::UserMetadata user_metadata; absl::flat_hash_map audio_elements; @@ -182,9 +227,9 @@ TEST(Initialize, FailsForMismatchingSampleRates) { const uint32_t kWrongSampleRate = 16000; InitializeTestData(kWrongSampleRate, user_metadata, audio_elements); - WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata()); - EXPECT_FALSE( - wav_sample_provider.Initialize(GetInputWavDir(), audio_elements).ok()); + EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(), + GetInputWavDir(), audio_elements) + .ok()); } TEST(WavSampleProviderTest, ReadFrameSucceeds) { @@ -192,14 +237,14 @@ 
TEST(WavSampleProviderTest, ReadFrameSucceeds) { absl::flat_hash_map audio_elements; InitializeTestData(kSampleRate, user_metadata, audio_elements); - WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata()); - ASSERT_THAT(wav_sample_provider.Initialize(GetInputWavDir(), audio_elements), - IsOk()); + auto wav_sample_provider = WavSampleProvider::Create( + user_metadata.audio_frame_metadata(), GetInputWavDir(), audio_elements); + ASSERT_THAT(wav_sample_provider, IsOk()); LabelSamplesMap labeled_samples; bool finished_reading = false; - EXPECT_THAT(wav_sample_provider.ReadFrames(kAudioElementId, labeled_samples, - finished_reading), + EXPECT_THAT(wav_sample_provider->ReadFrames(kAudioElementId, labeled_samples, + finished_reading), IsOk()); EXPECT_TRUE(finished_reading); @@ -218,9 +263,9 @@ TEST(WavSampleProviderTest, ReadFrameFailsWithWrongAudioElementId) { absl::flat_hash_map audio_elements; InitializeTestData(kSampleRate, user_metadata, audio_elements); - WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata()); - ASSERT_THAT(wav_sample_provider.Initialize(GetInputWavDir(), audio_elements), - IsOk()); + auto wav_sample_provider = WavSampleProvider::Create( + user_metadata.audio_frame_metadata(), GetInputWavDir(), audio_elements); + ASSERT_THAT(wav_sample_provider, IsOk()); // Try to read frames using a wrong Audio Element ID. 
const auto kWrongAudioElementId = kAudioElementId + 99; @@ -228,27 +273,7 @@ TEST(WavSampleProviderTest, ReadFrameFailsWithWrongAudioElementId) { bool finished_reading = false; EXPECT_FALSE( wav_sample_provider - .ReadFrames(kWrongAudioElementId, labeled_samples, finished_reading) - .ok()); -} - -TEST(WavSampleProviderTest, ReadFrameFailsWithoutCallingInitialize) { - iamf_tools_cli_proto::UserMetadata user_metadata; - absl::flat_hash_map audio_elements; - InitializeTestData(kSampleRate, user_metadata, audio_elements); - - WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata()); - - // Miss the following call to `Initialize()`: - // ASSERT_THAT( - // wav_sample_provider.Initialize(GetInputWavDir(), audio_elements), - // IsOk()); - - LabelSamplesMap labeled_samples; - bool finished_reading = false; - EXPECT_FALSE( - wav_sample_provider - .ReadFrames(kAudioElementId, labeled_samples, finished_reading) + ->ReadFrames(kWrongAudioElementId, labeled_samples, finished_reading) .ok()); } diff --git a/iamf/cli/wav_sample_provider.cc b/iamf/cli/wav_sample_provider.cc index 2f3e403..33b7f10 100644 --- a/iamf/cli/wav_sample_provider.cc +++ b/iamf/cli/wav_sample_provider.cc @@ -23,104 +23,185 @@ #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "iamf/cli/audio_element_with_data.h" #include "iamf/cli/channel_label.h" #include "iamf/cli/demixing_module.h" #include "iamf/cli/wav_reader.h" #include "iamf/common/macros.h" #include "iamf/common/obu_util.h" +#include "iamf/obu/codec_config.h" #include "iamf/obu/types.h" +#include "src/google/protobuf/repeated_ptr_field.h" namespace iamf_tools { -absl::Status WavSampleProvider::Initialize( - const std::string& input_wav_directory, - const absl::flat_hash_map& - audio_elements) { - for (const auto& [audio_element_id, audio_frame_metadata] : - audio_frame_metadata_) { - if (audio_frame_metadata.channel_ids_size() != - 
audio_frame_metadata.channel_labels_size()) { - return absl::InvalidArgumentError( - absl::StrCat("#channel IDs and #channel labels differ: (", - audio_frame_metadata.channel_ids_size(), " vs ", - audio_frame_metadata.channel_labels_size(), ")")); - } - if (!ValidateUnique(audio_frame_metadata.channel_ids().begin(), - audio_frame_metadata.channel_ids().end(), "channel ids") - .ok()) { - // OK. The user is claiming some channel IDs are shared between labels. - // This is strange, but permitted. - LOG(WARNING) << "Usually channel labels should be unique. Did you use " - "the same channel ID for different channels?"; - } +namespace { + +absl::Status FillChannelIdsAndLabels( + const iamf_tools_cli_proto::AudioFrameObuMetadata audio_frame_metadata, + std::vector& channel_ids, + std::vector& labels) { + if (audio_frame_metadata.channel_ids_size() != + audio_frame_metadata.channel_labels_size()) { + return absl::InvalidArgumentError( + absl::StrCat("#channel IDs and #channel labels differ: (", + audio_frame_metadata.channel_ids_size(), " vs ", + audio_frame_metadata.channel_labels_size(), ")")); + } - // Precompute the `ChannelLabel::Label` for each channel label string. - RETURN_IF_NOT_OK(ChannelLabel::ConvertAndFillLabels( - audio_frame_metadata.channel_labels(), - audio_element_id_to_labels_[audio_element_id])); + // Precompute the channel IDs for the audio element. + channel_ids.reserve(audio_frame_metadata.channel_ids_size()); + for (const uint32_t channel_id : audio_frame_metadata.channel_ids()) { + channel_ids.push_back(channel_id); + } + if (!ValidateUnique(channel_ids.begin(), channel_ids.end(), "channel ids") + .ok()) { + // OK. The user is claiming some channel IDs are shared between labels. + // This is strange, but permitted. + LOG(WARNING) << "Usually channel labels should be unique. 
Did you use " + "the same channel ID for different channels?"; + } - const auto audio_element_iter = audio_elements.find(audio_element_id); - if (audio_element_iter == audio_elements.end()) { + // Precompute the `ChannelLabel::Label` for each channel label string. + RETURN_IF_NOT_OK(ChannelLabel::ConvertAndFillLabels( + audio_frame_metadata.channel_labels(), labels)); + + return absl::OkStatus(); +} +absl::Status ValidateWavReaderIsConsistentWithData( + absl::string_view wav_filename_for_debugging, const WavReader& wav_reader, + const CodecConfigObu& codec_config, + const std::vector& channel_ids) { + const std::string pretty_print_wav_filename = + absl::StrCat("WAV (", wav_filename_for_debugging, ")"); + const int encoder_input_pcm_bit_depth = + static_cast(codec_config.GetBitDepthToMeasureLoudness()); + if (wav_reader.bit_depth() > encoder_input_pcm_bit_depth) { + return absl::InvalidArgumentError(absl::StrCat( + "Refusing to lower bit-depth of ", pretty_print_wav_filename, + " with bit_depth= ", wav_reader.bit_depth(), + " to bit_depth=", encoder_input_pcm_bit_depth)); + } + + const uint32_t encoder_input_sample_rate = codec_config.GetInputSampleRate(); + if (wav_reader.sample_rate_hz() != encoder_input_sample_rate) { + return absl::InvalidArgumentError(absl::StrCat( + pretty_print_wav_filename, " has a sample rate of ", + wav_reader.sample_rate_hz(), " Hz. Expected a sample rate of ", + encoder_input_sample_rate, + " Hz based on the Codec Config OBU. 
Consider using a third party " + "resampler on the WAV file, or picking Codec Config OBU settings to " + "match the WAV file before trying again.")); + } + + const uint32_t decoder_output_sample_rate = + codec_config.GetOutputSampleRate(); + if (encoder_input_sample_rate != decoder_output_sample_rate) { + return absl::InvalidArgumentError(absl::StrCat( + "Input and output sample rates differ: (", encoder_input_sample_rate, + " vs ", decoder_output_sample_rate, ")")); + } + + // To prevent indexing out of bounds after the `WavSampleProvider` is + // created, we ensure all user-specified channel IDs are in range of the + // number of channels in the input file. + for (const uint32_t channel_id : channel_ids) { + if (channel_id >= wav_reader.num_channels()) { return absl::InvalidArgumentError( - absl::StrCat("No Audio Element found for ID= ", audio_element_id)); - } - const auto& codec_config = *audio_element_iter->second.codec_config; - const auto& wav_filename = std::filesystem::path(input_wav_directory) / - audio_frame_metadata.wav_filename(); - - auto wav_reader = WavReader::CreateFromFile( - wav_filename.string(), - static_cast(codec_config.GetNumSamplesPerFrame())); - if (!wav_reader.ok()) { - return wav_reader.status(); + absl::StrCat(pretty_print_wav_filename, + " has num_channels= ", wav_reader.num_channels(), + ". 
channel_id= ", channel_id, " is out of bounds.")); } + } - const int encoder_input_pcm_bit_depth = - static_cast(codec_config.GetBitDepthToMeasureLoudness()); - if (wav_reader->bit_depth() > encoder_input_pcm_bit_depth) { - return absl::InvalidArgumentError(absl::StrCat( - "Refusing to lower bit-depth of WAV (", wav_filename.string(), - ") with bit_depth= ", wav_reader->bit_depth(), - " to bit_depth=", encoder_input_pcm_bit_depth)); - } + return absl::OkStatus(); +} - const uint32_t encoder_input_sample_rate = - codec_config.GetInputSampleRate(); - if (wav_reader->sample_rate_hz() != encoder_input_sample_rate) { - return absl::InvalidArgumentError(absl::StrCat( - "WAV (", wav_filename.string(), ") has a sample rate of ", - wav_reader->sample_rate_hz(), " Hz. Expected a sample rate of ", - encoder_input_sample_rate, - " Hz based on the Codec Config OBU. Consider using a third party " - "resampler on the WAV file, or picking Codec Config OBU settings to " - "match the WAV file before trying again.")); - } +// Fills in `channel_ids`, `labels`, and creates a `WavReader` from the input +// metadata and other input data. 
+absl::Status InitializeForAudioElement( + uint32_t audio_element_id, + const iamf_tools_cli_proto::AudioFrameObuMetadata audio_frame_metadata, + const std::string& wav_filename, const CodecConfigObu& codec_config, + std::vector& channel_ids, + std::vector& labels, + absl::flat_hash_map& + audio_element_id_to_wav_reader) { + RETURN_IF_NOT_OK( + FillChannelIdsAndLabels(audio_frame_metadata, channel_ids, labels)); + + auto wav_reader = WavReader::CreateFromFile( + wav_filename, static_cast(codec_config.GetNumSamplesPerFrame())); + if (!wav_reader.ok()) { + return wav_reader.status(); + } + RETURN_IF_NOT_OK(ValidateWavReaderIsConsistentWithData( + wav_filename, *wav_reader, codec_config, channel_ids)); + + audio_element_id_to_wav_reader.emplace(audio_element_id, + std::move(*wav_reader)); + + return absl::OkStatus(); +} + +} // namespace - const uint32_t decoder_output_sample_rate = - codec_config.GetOutputSampleRate(); - if (encoder_input_sample_rate != decoder_output_sample_rate) { - return absl::InvalidArgumentError(absl::StrCat( - "Input and output sample rates differ: (", encoder_input_sample_rate, - " vs ", decoder_output_sample_rate, ")")); +absl::StatusOr WavSampleProvider::Create( + const ::google::protobuf::RepeatedPtrField< + iamf_tools_cli_proto::AudioFrameObuMetadata>& audio_frame_metadata, + absl::string_view input_wav_directory, + const absl::flat_hash_map& + audio_elements) { + // Precompute, validate, and cache data for each audio element. 
+ absl::flat_hash_map wav_readers; + absl::flat_hash_map> + audio_element_id_to_channel_ids; + absl::flat_hash_map> + audio_element_id_to_labels; + + const std::filesystem::path input_wav_directory_path(input_wav_directory); + for (const auto& audio_frame_obu_metadata : audio_frame_metadata) { + const uint32_t audio_element_id = + audio_frame_obu_metadata.audio_element_id(); + const auto& wav_filename = + input_wav_directory_path / + std::filesystem::path(audio_frame_obu_metadata.wav_filename()); + + // Retrieve the Codec Config OBU for the audio element. + auto audio_element_iter = audio_elements.find(audio_element_id); + if (audio_element_iter == audio_elements.end()) { + return absl::InvalidArgumentError( + absl::StrCat("No Audio Element found for ID= ", + audio_frame_obu_metadata.audio_element_id())); + } + const CodecConfigObu* codec_config = + audio_element_iter->second.codec_config; + if (codec_config == nullptr) { + return absl::InvalidArgumentError( + absl::StrCat("No Codec Config found for Audio Element ID= ", + audio_frame_obu_metadata.audio_element_id())); } - // To prevent indexing out of bounds after the `WavSampleProvider` is - // created, we ensure all user-specified channel IDs are in range of the - // number of channels in the input file. - for (const uint32_t channel_id : audio_frame_metadata.channel_ids()) { - if (channel_id >= wav_reader->num_channels()) { - return absl::InvalidArgumentError( - absl::StrCat("WAV (", wav_filename.string(), - ") has num_channels= ", wav_reader->num_channels(), - ". 
channel_id= ", channel_id, " is out of bounds.")); - } + auto [channel_ids_iter, inserted] = audio_element_id_to_channel_ids.emplace( + audio_element_id, std::vector()); + if (!inserted) { + return absl::InvalidArgumentError( + absl::StrCat("List of AudioFrameObuMetadata contains duplicate " + "Audio Element ID= ", + audio_element_id)); } + // Internals add to the maps in parallel; if one had an empty slot, then + // the others will have an empty slot. - wav_readers_.emplace(audio_element_id, std::move(*wav_reader)); + RETURN_IF_NOT_OK(InitializeForAudioElement( + audio_element_id, audio_frame_obu_metadata, wav_filename.string(), + *codec_config, channel_ids_iter->second, + audio_element_id_to_labels[audio_element_id], wav_readers)); } - - return absl::OkStatus(); + return WavSampleProvider(std::move(wav_readers), + std::move(audio_element_id_to_channel_ids), + std::move(audio_element_id_to_labels)); } absl::Status WavSampleProvider::ReadFrames( @@ -136,11 +217,11 @@ absl::Status WavSampleProvider::ReadFrames( LOG_FIRST_N(INFO, 1) << samples_read << " samples read"; // Note if the WAV reader is found for the Audio Element ID, then it's - // guaranteed to have a corresponding audio frame metadata (otherwise the - // `Initialize()` would have failed). - const auto& audio_frame_metadata = audio_frame_metadata_.at(audio_element_id); + // guaranteed to have the other corresponding metadata (otherwise the + // `Create()` would have failed). 
const size_t num_time_ticks = samples_read / wav_reader.num_channels(); - const auto& channel_ids = audio_frame_metadata.channel_ids(); + const auto& channel_ids = + audio_element_id_to_channel_ids_.at(audio_element_id); const auto& channel_labels = audio_element_id_to_labels_.at(audio_element_id); labeled_samples.clear(); for (int c = 0; c < channel_labels.size(); ++c) { @@ -156,4 +237,15 @@ absl::Status WavSampleProvider::ReadFrames( return absl::OkStatus(); } +WavSampleProvider::WavSampleProvider( + absl::flat_hash_map&& wav_readers, + absl::flat_hash_map>&& + audio_element_id_to_channel_ids, + absl::flat_hash_map>&& + audio_element_id_to_labels) + : wav_readers_(std::move(wav_readers)), + audio_element_id_to_channel_ids_( + std::move(audio_element_id_to_channel_ids)), + audio_element_id_to_labels_(std::move(audio_element_id_to_labels)) {}; + } // namespace iamf_tools diff --git a/iamf/cli/wav_sample_provider.h b/iamf/cli/wav_sample_provider.h index 7c6821e..98ca284 100644 --- a/iamf/cli/wav_sample_provider.h +++ b/iamf/cli/wav_sample_provider.h @@ -13,11 +13,15 @@ #ifndef CLI_WAV_SAMPLE_PROVIDER_H_ #define CLI_WAV_SAMPLE_PROVIDER_H_ +#include #include +#include #include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "iamf/cli/audio_element_with_data.h" #include "iamf/cli/channel_label.h" #include "iamf/cli/demixing_module.h" @@ -30,27 +34,17 @@ namespace iamf_tools { class WavSampleProvider { public: - /*!\brief Constructor. + /*!\brief Factory function. * * \param audio_frame_metadata Input audio frame metadata. 
- */ - WavSampleProvider( - const ::google::protobuf::RepeatedPtrField< - iamf_tools_cli_proto::AudioFrameObuMetadata>& audio_frame_metadata) { - for (const auto& audio_frame_obu_metadata : audio_frame_metadata) { - audio_frame_metadata_[audio_frame_obu_metadata.audio_element_id()] = - audio_frame_obu_metadata; - } - } - - /*!\brief Initializes WAV readers that provide samples for the audio frames. - * * \param input_wav_directory Directory containing the input WAV files. * \param audio_elements Input Audio Element OBUs with data. * \return `absl::OkStatus()` on success. A specific status on failure. */ - absl::Status Initialize( - const std::string& input_wav_directory, + static absl::StatusOr Create( + const ::google::protobuf::RepeatedPtrField< + iamf_tools_cli_proto::AudioFrameObuMetadata>& audio_frame_metadata, + absl::string_view input_wav_directory, const absl::flat_hash_map& audio_elements); @@ -67,16 +61,32 @@ class WavSampleProvider { bool& finished_reading); private: + /*!\brief Constructor. + * + * Used only by factory function. Moves from all input arguments. + * + * \param wav_readers Mapping from Audio Element ID to `WavReader`. + * \param audio_element_id_to_channel_ids Mapping from Audio Element ID to + * channel IDs. + * \param audio_element_id_to_labels Mapping from Audio Element ID to channel + * labels. + */ + WavSampleProvider( + absl::flat_hash_map&& wav_readers, + absl::flat_hash_map>&& + audio_element_id_to_channel_ids, + absl::flat_hash_map>&& + audio_element_id_to_labels); + // Mapping from Audio Element ID to `WavReader`. absl::flat_hash_map wav_readers_; - // Mapping from Audio Element ID to audio frame metadata. - absl::flat_hash_map - audio_frame_metadata_; + // Mapping from Audio Element ID to channel IDs. + const absl::flat_hash_map> + audio_element_id_to_channel_ids_; // Mapping from Audio Element ID to channel labels. - absl::flat_hash_map> + const absl::flat_hash_map> audio_element_id_to_labels_; };