Skip to content

Commit

Permalink
Use a factory pattern in WavSampleProvider.
Browse files Browse the repository at this point in the history
  - A factory pattern is easier to use correctly than a construct / initialize pattern, especially when construction and initialization were always coupled. It also prevents limbo states in which an object has been constructed but not yet initialized.
  - New behavior:
    - Forbid duplicate audio element IDs in `Create`. Previously, when there were duplicates, the constructor would overwrite its cache and keep only the metadata for the last ID (no test coverage). Now that the factory can detect duplicates, it makes sense to prevent them entirely.
    - Fix crash if the underlying codec config pointers were `nullptr`
  - Internally cache `channel_ids` in lieu of all `AudioFrameMetadata`. Previously all of the metadata needed to be cached between construction and initialization. With the factory pattern it seems cleaner to cache only what is needed.

PiperOrigin-RevId: 686987337
  • Loading branch information
jwcullen committed Oct 18, 2024
1 parent f449665 commit b650f58
Show file tree
Hide file tree
Showing 6 changed files with 305 additions and 171 deletions.
2 changes: 2 additions & 0 deletions iamf/cli/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,12 @@ cc_library(
"//iamf/cli/proto:audio_frame_cc_proto",
"//iamf/common:macros",
"//iamf/common:obu_util",
"//iamf/obu:codec_config",
"//iamf/obu:types",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_protobuf//:protobuf",
],
Expand Down
11 changes: 7 additions & 4 deletions iamf/cli/encoder_main_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,12 @@ absl::Status GenerateObus(
ia_sequence_header_obu, codec_config_obus, audio_elements,
mix_presentation_obus));

WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata());
RETURN_IF_NOT_OK(
wav_sample_provider.Initialize(input_wav_directory, audio_elements));
auto wav_sample_provider =
WavSampleProvider::Create(user_metadata.audio_frame_metadata(),
input_wav_directory, audio_elements);
if (!wav_sample_provider.ok()) {
return wav_sample_provider.status();
}

// Parameter blocks.
TimeParameterBlockMetadataMap time_parameter_block_metadata;
Expand All @@ -197,7 +200,7 @@ absl::Status GenerateObus(
absl::flat_hash_map<DecodedUleb128, LabelSamplesMap> id_to_labeled_samples;
bool no_more_real_samples = false;
RETURN_IF_NOT_OK(CollectLabeledSamplesForAudioElements(
audio_elements, wav_sample_provider, id_to_labeled_samples,
audio_elements, *wav_sample_provider, id_to_labeled_samples,
no_more_real_samples));

for (const auto& [audio_element_id, labeled_samples] :
Expand Down
2 changes: 2 additions & 0 deletions iamf/cli/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,8 @@ cc_test(
"//iamf/cli:channel_label",
"//iamf/cli:demixing_module",
"//iamf/cli:wav_sample_provider",
"//iamf/cli/proto:audio_element_cc_proto",
"//iamf/cli/proto:audio_frame_cc_proto",
"//iamf/cli/proto:user_metadata_cc_proto",
"//iamf/obu:codec_config",
"//iamf/obu:types",
Expand Down
159 changes: 92 additions & 67 deletions iamf/cli/tests/wav_sample_provider_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include "iamf/cli/audio_element_with_data.h"
#include "iamf/cli/channel_label.h"
#include "iamf/cli/demixing_module.h"
#include "iamf/cli/proto/audio_element.pb.h"
#include "iamf/cli/proto/audio_frame.pb.h"
#include "iamf/cli/proto/user_metadata.pb.h"
#include "iamf/cli/tests/cli_test_utils.h"
#include "iamf/obu/codec_config.h"
Expand All @@ -43,21 +45,27 @@ static constexpr DecodedUleb128 kAudioElementId = 300;
static constexpr DecodedUleb128 kCodecConfigId = 200;
static constexpr uint32_t kSampleRate = 48000;

void InitializeTestData(
const uint32_t sample_rate,
iamf_tools_cli_proto::UserMetadata& user_metadata,
absl::flat_hash_map<DecodedUleb128, AudioElementWithData>& audio_elements) {
void FillStereoDataForAudioElementId(
uint32_t audio_element_id,
iamf_tools_cli_proto::AudioFrameObuMetadata& audio_frame_metadata) {
ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
R"pb(
wav_filename: "stereo_8_samples_48khz_s16le.wav"
samples_to_trim_at_end: 0
samples_to_trim_at_start: 0
audio_element_id: 300
channel_ids: [ 0, 1 ]
channel_labels: [ "L2", "R2" ]
)pb",
user_metadata.add_audio_frame_metadata()));
&audio_frame_metadata));
audio_frame_metadata.set_audio_element_id(audio_element_id);
}

void InitializeTestData(
const uint32_t sample_rate,
iamf_tools_cli_proto::UserMetadata& user_metadata,
absl::flat_hash_map<DecodedUleb128, AudioElementWithData>& audio_elements) {
FillStereoDataForAudioElementId(kAudioElementId,
*user_metadata.add_audio_frame_metadata());
static absl::flat_hash_map<uint32_t, CodecConfigObu> codec_config_obus;
codec_config_obus.clear();
AddLpcmCodecConfigWithIdAndSampleRate(kCodecConfigId, sample_rate,
Expand All @@ -75,30 +83,68 @@ std::string GetInputWavDir() {
return input_wav_dir;
}

TEST(Initialize, SucceedsForStereoInput) {
TEST(Create, SucceedsForStereoInput) {
iamf_tools_cli_proto::UserMetadata user_metadata;
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
InitializeTestData(kSampleRate, user_metadata, audio_elements);

WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata());
EXPECT_THAT(wav_sample_provider.Initialize(GetInputWavDir(), audio_elements),
EXPECT_THAT(WavSampleProvider::Create(user_metadata.audio_frame_metadata(),
GetInputWavDir(), audio_elements),
IsOk());
}

TEST(Initialize, FailsForUnknownLabels) {
// Verifies that `WavSampleProvider::Create` returns a non-OK status when two
// audio frame metadata entries carry the same audio element ID.
TEST(Create, FailsWhenUserMetadataContainsDuplicateAudioElementIds) {
iamf_tools_cli_proto::UserMetadata user_metadata;
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
// Adds the first audio frame metadata entry for `kAudioElementId`.
InitializeTestData(kSampleRate, user_metadata, audio_elements);
// Adds a second entry with the same ID, producing the duplicate under test.
FillStereoDataForAudioElementId(kAudioElementId,
*user_metadata.add_audio_frame_metadata());

EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(),
GetInputWavDir(), audio_elements)
.ok());
}

// Verifies that `WavSampleProvider::Create` returns a non-OK status when the
// audio frame metadata names an audio element ID that is absent from the map
// of audio elements.
TEST(Create, FailsWhenMatchingAudioElementObuIsMissing) {
iamf_tools_cli_proto::UserMetadata user_metadata;
// Deliberately empty, so `kAudioElementId` has no matching audio element.
const absl::flat_hash_map<DecodedUleb128, AudioElementWithData>
kNoAudioElements = {};
FillStereoDataForAudioElementId(kAudioElementId,
*user_metadata.add_audio_frame_metadata());

EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(),
GetInputWavDir(), kNoAudioElements)
.ok());
}

// Verifies that `WavSampleProvider::Create` returns a non-OK status when the
// matching audio element has a null `codec_config` pointer.
TEST(Create, FailsWhenCodecConfigIsMissing) {
iamf_tools_cli_proto::UserMetadata user_metadata;
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
InitializeTestData(kSampleRate, user_metadata, audio_elements);

// Corrupt the audio element by clearing the codec config pointer.
// The `ASSERT_TRUE` guards the `at()` call below against a missing key.
ASSERT_TRUE(audio_elements.contains(kAudioElementId));
audio_elements.at(kAudioElementId).codec_config = nullptr;

EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(),
GetInputWavDir(), audio_elements)
.ok());
}

TEST(Create, FailsForUnknownLabels) {
constexpr absl::string_view kUnknownLabel = "unknown_label";
iamf_tools_cli_proto::UserMetadata user_metadata;
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
InitializeTestData(kSampleRate, user_metadata, audio_elements);
user_metadata.mutable_audio_frame_metadata(0)->set_channel_labels(
0, kUnknownLabel);
WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata());

EXPECT_FALSE(
wav_sample_provider.Initialize(GetInputWavDir(), audio_elements).ok());
EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(),
GetInputWavDir(), audio_elements)
.ok());
}

TEST(Initialize, SucceedsForDuplicateChannelIds) {
TEST(Create, SucceedsForDuplicateChannelIds) {
constexpr uint32_t kDuplicateChannelId = 0;
iamf_tools_cli_proto::UserMetadata user_metadata;
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
Expand All @@ -107,13 +153,13 @@ TEST(Initialize, SucceedsForDuplicateChannelIds) {
0, kDuplicateChannelId);
user_metadata.mutable_audio_frame_metadata(0)->set_channel_ids(
1, kDuplicateChannelId);
WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata());

EXPECT_THAT(wav_sample_provider.Initialize(GetInputWavDir(), audio_elements),
EXPECT_THAT(WavSampleProvider::Create(user_metadata.audio_frame_metadata(),
GetInputWavDir(), audio_elements),
IsOk());
}

TEST(Initialize, FailsForDuplicateChannelLabels) {
TEST(Create, FailsForDuplicateChannelLabels) {
constexpr absl::string_view kDuplicateLabel = "L2";
iamf_tools_cli_proto::UserMetadata user_metadata;
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
Expand All @@ -122,13 +168,13 @@ TEST(Initialize, FailsForDuplicateChannelLabels) {
0, kDuplicateLabel);
user_metadata.mutable_audio_frame_metadata(0)->set_channel_labels(
1, kDuplicateLabel);
WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata());

EXPECT_FALSE(
wav_sample_provider.Initialize(GetInputWavDir(), audio_elements).ok());
EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(),
GetInputWavDir(), audio_elements)
.ok());
}

TEST(Initialize, FailsForChannelIdTooLarge) {
TEST(Create, FailsForChannelIdTooLarge) {
iamf_tools_cli_proto::UserMetadata user_metadata;
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
InitializeTestData(kSampleRate, user_metadata, audio_elements);
Expand All @@ -138,27 +184,27 @@ TEST(Initialize, FailsForChannelIdTooLarge) {
constexpr uint32_t kChannelIdTooLargeForStereoWavFile = 2;
user_metadata.mutable_audio_frame_metadata(0)->mutable_channel_ids()->Set(
kFirstChannelIndex, kChannelIdTooLargeForStereoWavFile);
WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata());

EXPECT_FALSE(
wav_sample_provider.Initialize(GetInputWavDir(), audio_elements).ok());
EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(),
GetInputWavDir(), audio_elements)
.ok());
}

TEST(Initialize, FailsForDifferentSizedChannelIdsAndChannelLabels) {
TEST(Create, FailsForDifferentSizedChannelIdsAndChannelLabels) {
iamf_tools_cli_proto::UserMetadata user_metadata;
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
InitializeTestData(kSampleRate, user_metadata, audio_elements);

// Add one extra channel label, which does not have a corresponding
// channel ID, causing the `Initialize()` to fail.
// channel ID, causing the `Create()` to fail.
user_metadata.mutable_audio_frame_metadata(0)->add_channel_labels("C");

WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata());
EXPECT_FALSE(
wav_sample_provider.Initialize(GetInputWavDir(), audio_elements).ok());
EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(),
GetInputWavDir(), audio_elements)
.ok());
}

TEST(Initialize, FailsForBitDepthLowerThanFile) {
TEST(Create, FailsForBitDepthLowerThanFile) {
iamf_tools_cli_proto::UserMetadata user_metadata;
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
InitializeTestData(kSampleRate, user_metadata, audio_elements);
Expand All @@ -167,13 +213,12 @@ TEST(Initialize, FailsForBitDepthLowerThanFile) {
// `Create()` would refuse to lower the bit depth and fail.
user_metadata.mutable_audio_frame_metadata(0)->set_wav_filename(
"stereo_8_samples_48khz_s24le.wav");

WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata());
EXPECT_FALSE(
wav_sample_provider.Initialize(GetInputWavDir(), audio_elements).ok());
EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(),
GetInputWavDir(), audio_elements)
.ok());
}

TEST(Initialize, FailsForMismatchingSampleRates) {
TEST(Create, FailsForMismatchingSampleRates) {
iamf_tools_cli_proto::UserMetadata user_metadata;
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;

Expand All @@ -182,24 +227,24 @@ TEST(Initialize, FailsForMismatchingSampleRates) {
const uint32_t kWrongSampleRate = 16000;
InitializeTestData(kWrongSampleRate, user_metadata, audio_elements);

WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata());
EXPECT_FALSE(
wav_sample_provider.Initialize(GetInputWavDir(), audio_elements).ok());
EXPECT_FALSE(WavSampleProvider::Create(user_metadata.audio_frame_metadata(),
GetInputWavDir(), audio_elements)
.ok());
}

TEST(WavSampleProviderTest, ReadFrameSucceeds) {
iamf_tools_cli_proto::UserMetadata user_metadata;
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
InitializeTestData(kSampleRate, user_metadata, audio_elements);

WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata());
ASSERT_THAT(wav_sample_provider.Initialize(GetInputWavDir(), audio_elements),
IsOk());
auto wav_sample_provider = WavSampleProvider::Create(
user_metadata.audio_frame_metadata(), GetInputWavDir(), audio_elements);
ASSERT_THAT(wav_sample_provider, IsOk());

LabelSamplesMap labeled_samples;
bool finished_reading = false;
EXPECT_THAT(wav_sample_provider.ReadFrames(kAudioElementId, labeled_samples,
finished_reading),
EXPECT_THAT(wav_sample_provider->ReadFrames(kAudioElementId, labeled_samples,
finished_reading),
IsOk());
EXPECT_TRUE(finished_reading);

Expand All @@ -218,37 +263,17 @@ TEST(WavSampleProviderTest, ReadFrameFailsWithWrongAudioElementId) {
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
InitializeTestData(kSampleRate, user_metadata, audio_elements);

WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata());
ASSERT_THAT(wav_sample_provider.Initialize(GetInputWavDir(), audio_elements),
IsOk());
auto wav_sample_provider = WavSampleProvider::Create(
user_metadata.audio_frame_metadata(), GetInputWavDir(), audio_elements);
ASSERT_THAT(wav_sample_provider, IsOk());

// Try to read frames using a wrong Audio Element ID.
const auto kWrongAudioElementId = kAudioElementId + 99;
LabelSamplesMap labeled_samples;
bool finished_reading = false;
EXPECT_FALSE(
wav_sample_provider
.ReadFrames(kWrongAudioElementId, labeled_samples, finished_reading)
.ok());
}

TEST(WavSampleProviderTest, ReadFrameFailsWithoutCallingInitialize) {
iamf_tools_cli_proto::UserMetadata user_metadata;
absl::flat_hash_map<DecodedUleb128, AudioElementWithData> audio_elements;
InitializeTestData(kSampleRate, user_metadata, audio_elements);

WavSampleProvider wav_sample_provider(user_metadata.audio_frame_metadata());

// Miss the following call to `Initialize()`:
// ASSERT_THAT(
// wav_sample_provider.Initialize(GetInputWavDir(), audio_elements),
// IsOk());

LabelSamplesMap labeled_samples;
bool finished_reading = false;
EXPECT_FALSE(
wav_sample_provider
.ReadFrames(kAudioElementId, labeled_samples, finished_reading)
->ReadFrames(kWrongAudioElementId, labeled_samples, finished_reading)
.ok());
}

Expand Down
Loading

0 comments on commit b650f58

Please sign in to comment.