diff --git a/iamf/cli/demixing_module.cc b/iamf/cli/demixing_module.cc index 4c471e54..1a58cdef 100644 --- a/iamf/cli/demixing_module.cc +++ b/iamf/cli/demixing_module.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include #include "absl/base/no_destructor.h" @@ -728,74 +729,81 @@ absl::Status FillRequiredDemixingMetadata( return absl::OkStatus(); } +void ConfigureLabeledFrame(const AudioFrameWithData& audio_frame, + LabeledFrame& labeled_frame) { + labeled_frame.end_timestamp = audio_frame.end_timestamp; + labeled_frame.samples_to_trim_at_end = + audio_frame.obu.header_.num_samples_to_trim_at_end; + labeled_frame.samples_to_trim_at_start = + audio_frame.obu.header_.num_samples_to_trim_at_start; + labeled_frame.demixing_params = audio_frame.down_mixing_params; +} + +void ConfigureLabeledFrame(const DecodedAudioFrame& decoded_audio_frame, + LabeledFrame& labeled_decoded_frame) { + labeled_decoded_frame.end_timestamp = decoded_audio_frame.end_timestamp; + labeled_decoded_frame.samples_to_trim_at_end = + decoded_audio_frame.samples_to_trim_at_end; + labeled_decoded_frame.samples_to_trim_at_start = + decoded_audio_frame.samples_to_trim_at_start; + labeled_decoded_frame.demixing_params = + decoded_audio_frame.down_mixing_params; +} + +uint32_t GetSubstreamId(const AudioFrameWithData& audio_frame_with_data) { + return audio_frame_with_data.obu.GetSubstreamId(); +} + +uint32_t GetSubstreamId(const DecodedAudioFrame& audio_frame_with_data) { + return audio_frame_with_data.substream_id; +} + +const std::vector<std::vector<int32_t>>& GetSamples( + const AudioFrameWithData& audio_frame_with_data) { + return audio_frame_with_data.raw_samples; +} + +const std::vector<std::vector<int32_t>>& GetSamples( + const DecodedAudioFrame& audio_frame_with_data) { + return audio_frame_with_data.decoded_samples; +} + +// TODO(b/339037792): Unify `AudioFrameWithData` and `DecodedAudioFrame`. 
+template <typename T> absl::Status StoreSamplesForAudioElementId( - const std::list<AudioFrameWithData>& audio_frames, - const std::list<DecodedAudioFrame>& decoded_audio_frames, + const std::list<T>& audio_frames_or_decoded_audio_frames, const SubstreamIdLabelsMap& substream_id_to_labels, - LabeledFrame& labeled_frame, LabeledFrame& labeled_decoded_frame) { - auto audio_frame_iter = audio_frames.begin(); - auto decoded_audio_frame_iter = decoded_audio_frames.begin(); - const int32_t common_start_timestamp = audio_frame_iter->start_timestamp; - for (; audio_frame_iter != audio_frames.end() && - decoded_audio_frame_iter != decoded_audio_frames.end(); - audio_frame_iter++, decoded_audio_frame_iter++) { - const auto substream_id = audio_frame_iter->obu.GetSubstreamId(); + LabeledFrame& labeled_frame) { + if (audio_frames_or_decoded_audio_frames.empty()) { + return absl::OkStatus(); + } + const int32_t common_start_timestamp = + audio_frames_or_decoded_audio_frames.begin()->start_timestamp; + + for (auto& audio_frame : audio_frames_or_decoded_audio_frames) { + const auto substream_id = GetSubstreamId(audio_frame); auto substream_id_labels_iter = substream_id_to_labels.find(substream_id); if (substream_id_labels_iter == substream_id_to_labels.end()) { // This audio frame might belong to a different audio element; skip it. continue; } - if (decoded_audio_frame_iter->substream_id != substream_id) { - LOG(ERROR) << "Substream ID mismatch: " << substream_id << " vs " - << decoded_audio_frame_iter->substream_id; - return absl::InvalidArgumentError(""); - } - const auto& labels = substream_id_labels_iter->second; - // TODO(b/339037792): Remove dependency on `raw_samples`. 
- if (audio_frame_iter->raw_samples[0].size() != labels.size() || - decoded_audio_frame_iter->decoded_samples[0].size() != labels.size()) { - LOG(ERROR) << "Channel number mismatch: " - << audio_frame_iter->raw_samples[0].size() << " vs " - << decoded_audio_frame_iter->decoded_samples[0].size() - << " vs " << labels.size(); - return absl::InvalidArgumentError(""); - } // Validate that the frames are all aligned in time. - RETURN_IF_NOT_OK(CompareTimestamps(common_start_timestamp, - audio_frame_iter->start_timestamp)); - RETURN_IF_NOT_OK(CompareTimestamps( - common_start_timestamp, decoded_audio_frame_iter->start_timestamp)); + RETURN_IF_NOT_OK( + CompareTimestamps(common_start_timestamp, audio_frame.start_timestamp)); + const auto& labels = substream_id_labels_iter->second; int channel_index = 0; for (const auto& label : labels) { - const size_t num_ticks = audio_frame_iter->raw_samples.size(); - - labeled_frame.end_timestamp = audio_frame_iter->end_timestamp; - labeled_frame.samples_to_trim_at_end = - audio_frame_iter->obu.header_.num_samples_to_trim_at_end; - labeled_frame.samples_to_trim_at_start = - audio_frame_iter->obu.header_.num_samples_to_trim_at_start; - labeled_frame.demixing_params = audio_frame_iter->down_mixing_params; - - labeled_decoded_frame.end_timestamp = - decoded_audio_frame_iter->end_timestamp; - labeled_decoded_frame.samples_to_trim_at_end = - decoded_audio_frame_iter->samples_to_trim_at_end; - labeled_decoded_frame.samples_to_trim_at_start = - decoded_audio_frame_iter->samples_to_trim_at_start; - labeled_decoded_frame.demixing_params = - decoded_audio_frame_iter->down_mixing_params; + const auto& input_samples = GetSamples(audio_frame); + const size_t num_ticks = input_samples.size(); - auto& samples = labeled_frame.label_to_samples[label]; - auto& decoded_samples = labeled_decoded_frame.label_to_samples[label]; + ConfigureLabeledFrame(audio_frame, labeled_frame); + auto& samples = labeled_frame.label_to_samples[label]; 
samples.resize(num_ticks, 0); - decoded_samples.resize(num_ticks, 0); - for (int t = 0; t < audio_frame_iter->raw_samples.size(); t++) { - samples[t] = audio_frame_iter->raw_samples[t][channel_index]; - decoded_samples[t] = - decoded_audio_frame_iter->decoded_samples[t][channel_index]; + for (int t = 0; t < samples.size(); t++) { + samples[t] = input_samples[t][channel_index]; } channel_index++; } @@ -805,13 +813,10 @@ absl::Status StoreSamplesForAudioElementId( } absl::Status ApplyDemixers(const std::list& demixers, - LabeledFrame& labeled_frame, - LabeledFrame& labeled_decoded_frame) { + LabeledFrame& labeled_frame) { for (const auto& demixer : demixers) { RETURN_IF_NOT_OK(demixer(labeled_frame.demixing_params, &labeled_frame.label_to_samples)); - RETURN_IF_NOT_OK(demixer(labeled_decoded_frame.demixing_params, - &labeled_decoded_frame.label_to_samples)); } return absl::OkStatus(); } @@ -901,6 +906,24 @@ absl::StatusOr> LookupLabelsToReconstruct( } } +void LogForAudioElementId( + DecodedUleb128 audio_element_id, + const IdLabeledFrameMap& id_to_labeled_frame, + const IdLabeledFrameMap& id_to_labeled_decoded_frame) { + if (!id_to_labeled_frame.contains(audio_element_id) || + !id_to_labeled_decoded_frame.contains(audio_element_id)) { + return; + } + for (const auto& [label, samples] : + id_to_labeled_frame.at(audio_element_id).label_to_samples) { + const auto& decoded_samples = + id_to_labeled_decoded_frame.at(audio_element_id) + .label_to_samples.at(label); + LOG(INFO) << " Channel " << label << ":\tframe size= " << samples.size() + << "; decoded frame size= " << decoded_samples.size(); + } +} + } // namespace absl::Status DemixingModule::FindSamplesOrDemixedSamples( @@ -1054,22 +1077,29 @@ absl::Status DemixingModule::DemixAudioSamples( IdLabeledFrameMap& id_to_labeled_decoded_frame) const { for (const auto& [audio_element_id, demixing_metadata] : audio_element_id_to_demixing_metadata_) { - auto& labeled_frame = id_to_labeled_frame[audio_element_id]; - auto&
labeled_decoded_frame = id_to_labeled_decoded_frame[audio_element_id]; - - RETURN_IF_NOT_OK( - StoreSamplesForAudioElementId(audio_frames, decoded_audio_frames, - demixing_metadata.substream_id_to_labels, - labeled_frame, labeled_decoded_frame)); - RETURN_IF_NOT_OK(ApplyDemixers(demixing_metadata.demixers, labeled_frame, - labeled_decoded_frame)); - - for (const auto& [label, samples] : labeled_frame.label_to_samples) { - const auto& decoded_samples = - labeled_decoded_frame.label_to_samples.at(label); - LOG(INFO) << " Channel " << label << ":\tframe size= " << samples.size() - << "; decoded frame size= " << decoded_samples.size(); + // Process the original audio frames. + LabeledFrame labeled_frame; + RETURN_IF_NOT_OK(StoreSamplesForAudioElementId( + audio_frames, demixing_metadata.substream_id_to_labels, labeled_frame)); + if (!labeled_frame.label_to_samples.empty()) { + RETURN_IF_NOT_OK( + ApplyDemixers(demixing_metadata.demixers, labeled_frame)); + id_to_labeled_frame[audio_element_id] = std::move(labeled_frame); } + // Process the decoded audio frames. 
+ LabeledFrame labeled_decoded_frame; + RETURN_IF_NOT_OK(StoreSamplesForAudioElementId( + decoded_audio_frames, demixing_metadata.substream_id_to_labels, + labeled_decoded_frame)); + if (!labeled_decoded_frame.label_to_samples.empty()) { + RETURN_IF_NOT_OK( + ApplyDemixers(demixing_metadata.demixers, labeled_decoded_frame)); + id_to_labeled_decoded_frame[audio_element_id] = + std::move(labeled_decoded_frame); + } + + LogForAudioElementId(audio_element_id, id_to_labeled_frame, + id_to_labeled_decoded_frame); } return absl::OkStatus(); diff --git a/iamf/cli/tests/demixing_module_test.cc b/iamf/cli/tests/demixing_module_test.cc index 6da95c34..530ab1df 100644 --- a/iamf/cli/tests/demixing_module_test.cc +++ b/iamf/cli/tests/demixing_module_test.cc @@ -39,6 +39,12 @@ namespace iamf_tools { namespace { constexpr DecodedUleb128 kAudioElementId = 137; +const uint32_t kZeroSamplesToTrimAtEnd = 0; +const uint32_t kZeroSamplesToTrimAtStart = 0; +const int kStartTimestamp = 0; +const int kEndTimestamp = 4; +const DecodedUleb128 kMonoSubstreamId = 0; +const DecodedUleb128 kL2SubstreamId = 1; // TODO(b/305927287): Test computation of linear output gains. Test some cases // of erroneous input. 
@@ -238,6 +244,183 @@ TEST(InitializeForReconstruction, CreatesNoDemixersForAmbisonics) { EXPECT_TRUE(demixer->empty()); } +TEST(DemixAudioSamples, OutputContainsOriginalAndDemixedSamples) { + absl::flat_hash_map audio_elements; + InitAudioElementWithLabelsAndLayers( + {{kMonoSubstreamId, {"M"}}, {kL2SubstreamId, {"L2"}}}, + {ChannelAudioLayerConfig::kLayoutMono, + ChannelAudioLayerConfig::kLayoutStereo}, + audio_elements); + std::list decoded_audio_frames; + decoded_audio_frames.push_back( + DecodedAudioFrame{.substream_id = kMonoSubstreamId, + .start_timestamp = kStartTimestamp, + .end_timestamp = kEndTimestamp, + .samples_to_trim_at_end = kZeroSamplesToTrimAtEnd, + .samples_to_trim_at_start = kZeroSamplesToTrimAtStart, + .decoded_samples = {{0}}, + .down_mixing_params = DownMixingParams()}); + decoded_audio_frames.push_back( + DecodedAudioFrame{.substream_id = kL2SubstreamId, + .start_timestamp = kStartTimestamp, + .end_timestamp = kEndTimestamp, + .samples_to_trim_at_end = kZeroSamplesToTrimAtEnd, + .samples_to_trim_at_start = kZeroSamplesToTrimAtStart, + .decoded_samples = {{0}}, + .down_mixing_params = DownMixingParams()}); + DemixingModule demixing_module; + EXPECT_TRUE(demixing_module.InitializeForReconstruction(audio_elements).ok()); + IdLabeledFrameMap id_labeled_frame; + IdLabeledFrameMap id_to_labeled_decoded_frame; + EXPECT_TRUE(demixing_module + .DemixAudioSamples({}, decoded_audio_frames, id_labeled_frame, + id_to_labeled_decoded_frame) + .ok()); + + const auto& labeled_frame = id_to_labeled_decoded_frame.at(kAudioElementId); + EXPECT_TRUE(labeled_frame.label_to_samples.contains("L2")); + EXPECT_TRUE(labeled_frame.label_to_samples.contains("M")); + EXPECT_TRUE(labeled_frame.label_to_samples.contains("D_R2")); + // When being used for reconstruction the original audio frames are not + // output. 
+ EXPECT_FALSE(id_labeled_frame.contains(kAudioElementId)); +} + +TEST(DemixAudioSamples, OutputEchoesTimingInformation) { + // These values are not very sensible, but as long as they are consistent + // between related frames it is OK. + const DecodedUleb128 kExpectedStartTimestamp = 99; + const DecodedUleb128 kExpectedEndTimestamp = 123; + const DecodedUleb128 kExpectedNumSamplesToTrimAtEnd = 999; + const DecodedUleb128 kExpectedNumSamplesToTrimAtStart = 9999; + const DecodedUleb128 kL2SubstreamId = 1; + absl::flat_hash_map audio_elements; + InitAudioElementWithLabelsAndLayers( + {{kMonoSubstreamId, {"M"}}, {kL2SubstreamId, {"L2"}}}, + {ChannelAudioLayerConfig::kLayoutMono, + ChannelAudioLayerConfig::kLayoutStereo}, + audio_elements); + std::list decoded_audio_frames; + decoded_audio_frames.push_back(DecodedAudioFrame{ + .substream_id = kMonoSubstreamId, + .start_timestamp = kExpectedStartTimestamp, + .end_timestamp = kExpectedEndTimestamp, + .samples_to_trim_at_end = kExpectedNumSamplesToTrimAtEnd, + .samples_to_trim_at_start = kExpectedNumSamplesToTrimAtStart, + .decoded_samples = {{0}}, + .down_mixing_params = DownMixingParams()}); + decoded_audio_frames.push_back(DecodedAudioFrame{ + .substream_id = kL2SubstreamId, + .start_timestamp = kExpectedStartTimestamp, + .end_timestamp = kExpectedEndTimestamp, + .samples_to_trim_at_end = kExpectedNumSamplesToTrimAtEnd, + .samples_to_trim_at_start = kExpectedNumSamplesToTrimAtStart, + .decoded_samples = {{0}}, + .down_mixing_params = DownMixingParams()}); + DemixingModule demixing_module; + EXPECT_TRUE(demixing_module.InitializeForReconstruction(audio_elements).ok()); + IdLabeledFrameMap unused_id_labeled_frame; + IdLabeledFrameMap id_to_labeled_decoded_frame; + EXPECT_TRUE(demixing_module + .DemixAudioSamples({}, decoded_audio_frames, + unused_id_labeled_frame, + id_to_labeled_decoded_frame) + .ok()); + + const auto& labeled_frame = id_to_labeled_decoded_frame.at(kAudioElementId); + 
EXPECT_EQ(labeled_frame.end_timestamp, kExpectedEndTimestamp); + EXPECT_EQ(labeled_frame.samples_to_trim_at_end, + kExpectedNumSamplesToTrimAtEnd); + EXPECT_EQ(labeled_frame.samples_to_trim_at_start, + kExpectedNumSamplesToTrimAtStart); +} + +TEST(DemixAudioSamples, OutputEchoesOriginalLabels) { + absl::flat_hash_map audio_elements; + InitAudioElementWithLabelsAndLayers( + {{kMonoSubstreamId, {"M"}}, {kL2SubstreamId, {"L2"}}}, + {ChannelAudioLayerConfig::kLayoutMono, + ChannelAudioLayerConfig::kLayoutStereo}, + audio_elements); + std::list decoded_audio_frames; + decoded_audio_frames.push_back( + DecodedAudioFrame{.substream_id = kMonoSubstreamId, + .start_timestamp = kStartTimestamp, + .end_timestamp = kEndTimestamp, + .samples_to_trim_at_end = kZeroSamplesToTrimAtEnd, + .samples_to_trim_at_start = kZeroSamplesToTrimAtStart, + .decoded_samples = {{1}, {2}, {3}}, + .down_mixing_params = DownMixingParams()}); + decoded_audio_frames.push_back( + DecodedAudioFrame{.substream_id = kL2SubstreamId, + .start_timestamp = kStartTimestamp, + .end_timestamp = kEndTimestamp, + .samples_to_trim_at_end = kZeroSamplesToTrimAtEnd, + .samples_to_trim_at_start = kZeroSamplesToTrimAtStart, + .decoded_samples = {{9}, {10}, {11}}, + .down_mixing_params = DownMixingParams()}); + DemixingModule demixing_module; + EXPECT_TRUE(demixing_module.InitializeForReconstruction(audio_elements).ok()); + + IdLabeledFrameMap unused_id_labeled_frame; + IdLabeledFrameMap id_to_labeled_decoded_frame; + EXPECT_TRUE(demixing_module + .DemixAudioSamples({}, decoded_audio_frames, + unused_id_labeled_frame, + id_to_labeled_decoded_frame) + .ok()); + + // Examine the demixed frame. 
+ const auto& labeled_frame = id_to_labeled_decoded_frame.at(kAudioElementId); + EXPECT_EQ(labeled_frame.label_to_samples.at("M"), + std::vector({1, 2, 3})); + EXPECT_EQ(labeled_frame.label_to_samples.at("L2"), + std::vector({9, 10, 11})); +} + +TEST(DemixAudioSamples, OutputHasReconstructedLayers) { + absl::flat_hash_map audio_elements; + + InitAudioElementWithLabelsAndLayers( + {{kMonoSubstreamId, {"M"}}, {kL2SubstreamId, {"L2"}}}, + {ChannelAudioLayerConfig::kLayoutMono, + ChannelAudioLayerConfig::kLayoutStereo}, + audio_elements); + std::list decoded_audio_frames; + decoded_audio_frames.push_back( + DecodedAudioFrame{.substream_id = kMonoSubstreamId, + .start_timestamp = kStartTimestamp, + .end_timestamp = kEndTimestamp, + .samples_to_trim_at_end = kZeroSamplesToTrimAtEnd, + .samples_to_trim_at_start = kZeroSamplesToTrimAtStart, + .decoded_samples = {{750}}, + .down_mixing_params = DownMixingParams()}); + decoded_audio_frames.push_back( + DecodedAudioFrame{.substream_id = kL2SubstreamId, + .start_timestamp = kStartTimestamp, + .end_timestamp = kEndTimestamp, + .samples_to_trim_at_end = kZeroSamplesToTrimAtEnd, + .samples_to_trim_at_start = kZeroSamplesToTrimAtStart, + .decoded_samples = {{1000}}, + .down_mixing_params = DownMixingParams()}); + DemixingModule demixing_module; + EXPECT_TRUE(demixing_module.InitializeForReconstruction(audio_elements).ok()); + + IdLabeledFrameMap unused_id_time_labeled_frame; + IdLabeledFrameMap id_to_labeled_decoded_frame; + EXPECT_TRUE(demixing_module + .DemixAudioSamples({}, decoded_audio_frames, + unused_id_time_labeled_frame, + id_to_labeled_decoded_frame) + .ok()); + + // Examine the demixed frame. + const auto& labeled_frame = id_to_labeled_decoded_frame.at(kAudioElementId); + // D_R2 = M - (L2 - 6 dB) + 6 dB. 
+ EXPECT_EQ(labeled_frame.label_to_samples.at("D_R2"), + std::vector({500})); +} + class DemixingModuleTestBase { public: DemixingModuleTestBase() { @@ -657,7 +840,7 @@ class DemixingModuleTest : public DemixingModuleTestBase, DownMixingParams down_mixing_params = { .alpha = 1, .beta = .866, .gamma = .866, .delta = .866, .w = 0.25}) { // The substream ID itself does not matter. Generate a unique one. - const uint32_t substream_id = substream_id_to_labels_.size(); + const DecodedUleb128 substream_id = substream_id_to_labels_.size(); substream_id_to_labels_[substream_id] = labels; // Configure a pair of audio frames and decoded audio frames. They share a @@ -669,10 +852,13 @@ class DemixingModuleTest : public DemixingModuleTestBase, .raw_samples = raw_samples, .down_mixing_params = down_mixing_params, }); + decoded_audio_frames_.push_back( DecodedAudioFrame{.substream_id = substream_id, .start_timestamp = kStartTimestamp, .end_timestamp = kEndTimestamp, + .samples_to_trim_at_end = kZeroSamplesToTrimAtEnd, + .samples_to_trim_at_start = kZeroSamplesToTrimAtStart, .decoded_samples = raw_samples, .down_mixing_params = down_mixing_params}); @@ -723,10 +909,6 @@ class DemixingModuleTest : public DemixingModuleTestBase, std::list decoded_audio_frames_; IdLabeledFrameMap expected_id_to_labeled_decoded_frame_; - - private: - const int32_t kStartTimestamp = 0; - const int32_t kEndTimestamp = 1; }; // namespace TEST_F(DemixingModuleTest, DemixingAudioSamplesSucceedsWithEmptyInputs) {