Skip to content

Commit

Permalink
Fixes the demixing module to be truly iterative
Browse files Browse the repository at this point in the history
- Previously it demixed all the frames, including both the current and past ones.
- Removes the time dimension of the demixing module -- it should always work
  on frames in the same temporal unit.

PiperOrigin-RevId: 633204740
  • Loading branch information
yero authored and jwcullen committed May 14, 2024
1 parent 1dea906 commit df7b8f7
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 91 deletions.
1 change: 1 addition & 0 deletions iamf/cli/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ cc_library(
":audio_element_with_data",
":audio_frame_decoder",
":audio_frame_with_data",
":cli_util",
"//iamf/cli/proto:audio_frame_cc_proto",
"//iamf/cli/proto:user_metadata_cc_proto",
"//iamf/common:macros",
Expand Down
83 changes: 33 additions & 50 deletions iamf/cli/demixing_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "iamf/cli/audio_element_with_data.h"
#include "iamf/cli/audio_frame_decoder.h"
#include "iamf/cli/audio_frame_with_data.h"
#include "iamf/cli/cli_util.h"
#include "iamf/cli/proto/audio_frame.pb.h"
#include "iamf/cli/proto/user_metadata.pb.h"
#include "iamf/common/macros.h"
Expand Down Expand Up @@ -731,10 +732,10 @@ absl::Status StoreSamplesForAudioElementId(
const std::list<AudioFrameWithData>& audio_frames,
const std::list<DecodedAudioFrame>& decoded_audio_frames,
const SubstreamIdLabelsMap& substream_id_to_labels,
TimeLabeledFrameMap& time_to_labeled_frame,
TimeLabeledFrameMap& time_to_labeled_decoded_frame) {
LabeledFrame& labeled_frame, LabeledFrame& labeled_decoded_frame) {
auto audio_frame_iter = audio_frames.begin();
auto decoded_audio_frame_iter = decoded_audio_frames.begin();
const int32_t common_start_timestamp = audio_frame_iter->start_timestamp;
for (; audio_frame_iter != audio_frames.end() &&
decoded_audio_frame_iter != decoded_audio_frames.end();
audio_frame_iter++, decoded_audio_frame_iter++) {
Expand All @@ -760,26 +761,23 @@ absl::Status StoreSamplesForAudioElementId(
return absl::InvalidArgumentError("");
}

// Validate that the frames are all aligned in time.
RETURN_IF_NOT_OK(CompareTimestamps(common_start_timestamp,
audio_frame_iter->start_timestamp));
RETURN_IF_NOT_OK(CompareTimestamps(
common_start_timestamp, decoded_audio_frame_iter->start_timestamp));

int channel_index = 0;
for (const auto& label : labels) {
const size_t num_ticks = audio_frame_iter->raw_samples.size();
const int32_t start_timestamp = audio_frame_iter->start_timestamp;
if (decoded_audio_frame_iter->start_timestamp != start_timestamp) {
LOG(ERROR) << "Start timestamp mismatch: " << start_timestamp << " vs "
<< decoded_audio_frame_iter->start_timestamp;
return absl::InvalidArgumentError("");
}

auto& labeled_frame = time_to_labeled_frame[start_timestamp];
labeled_frame.end_timestamp = audio_frame_iter->end_timestamp;
labeled_frame.samples_to_trim_at_end =
audio_frame_iter->obu.header_.num_samples_to_trim_at_end;
labeled_frame.samples_to_trim_at_start =
audio_frame_iter->obu.header_.num_samples_to_trim_at_start;
labeled_frame.demixing_params = audio_frame_iter->down_mixing_params;

auto& labeled_decoded_frame =
time_to_labeled_decoded_frame[start_timestamp];
labeled_decoded_frame.end_timestamp =
decoded_audio_frame_iter->end_timestamp;
labeled_decoded_frame.samples_to_trim_at_end =
Expand All @@ -805,18 +803,14 @@ absl::Status StoreSamplesForAudioElementId(
}

absl::Status ApplyDemixers(const std::list<Demixer>& demixers,
TimeLabeledFrameMap* time_to_labeled_frame,
TimeLabeledFrameMap* time_to_labeled_decoded_frame) {
for (auto& [time, labeled_frame] : *time_to_labeled_frame) {
auto& labeled_decoded_frame = time_to_labeled_decoded_frame->at(time);
for (const auto& demixer : demixers) {
RETURN_IF_NOT_OK(demixer(labeled_frame.demixing_params,
&labeled_frame.label_to_samples));
RETURN_IF_NOT_OK(demixer(labeled_frame.demixing_params,
&labeled_decoded_frame.label_to_samples));
}
LabeledFrame& labeled_frame,
LabeledFrame& labeled_decoded_frame) {
for (const auto& demixer : demixers) {
RETURN_IF_NOT_OK(demixer(labeled_frame.demixing_params,
&labeled_frame.label_to_samples));
RETURN_IF_NOT_OK(demixer(labeled_frame.demixing_params,
&labeled_decoded_frame.label_to_samples));
}

return absl::OkStatus();
}

Expand Down Expand Up @@ -1054,36 +1048,25 @@ absl::Status DemixingModule::DownMixSamplesToSubstreams(
absl::Status DemixingModule::DemixAudioSamples(
const std::list<AudioFrameWithData>& audio_frames,
const std::list<DecodedAudioFrame>& decoded_audio_frames,
IdTimeLabeledFrameMap& id_to_time_to_labeled_frame,
IdTimeLabeledFrameMap& id_to_time_to_labeled_decoded_frame) const {
IdLabeledFrameMap& id_to_labeled_frame,
IdLabeledFrameMap& id_to_labeled_decoded_frame) const {
for (const auto& [audio_element_id, demixing_metadata] :
audio_element_id_to_demixing_metadata_) {
auto& time_to_labeled_frame = id_to_time_to_labeled_frame[audio_element_id];
auto& time_to_labeled_decoded_frame =
id_to_time_to_labeled_decoded_frame[audio_element_id];

RETURN_IF_NOT_OK(StoreSamplesForAudioElementId(
audio_frames, decoded_audio_frames,
demixing_metadata.substream_id_to_labels, time_to_labeled_frame,
time_to_labeled_decoded_frame));
RETURN_IF_NOT_OK(ApplyDemixers(demixing_metadata.demixers,
&time_to_labeled_frame,
&time_to_labeled_decoded_frame));

LOG(INFO) << "Demixing Audio Element ID= " << audio_element_id;
LOG(INFO) << " Samples has " << time_to_labeled_frame.size() << " frames";
LOG(INFO) << " Decoded Samples has "
<< time_to_labeled_decoded_frame.size() << " frames";
if (!time_to_labeled_frame.empty() &&
!time_to_labeled_decoded_frame.empty()) {
for (const auto& [label, samples] :
time_to_labeled_frame.begin()->second.label_to_samples) {
const auto& decoded_samples = time_to_labeled_decoded_frame.begin()
->second.label_to_samples[label];
LOG(INFO) << " Channel " << label
<< ":\tframe size= " << samples.size()
<< "; decoded frame size= " << decoded_samples.size();
}
auto& labeled_frame = id_to_labeled_frame[audio_element_id];
auto& labeled_decoded_frame = id_to_labeled_decoded_frame[audio_element_id];

RETURN_IF_NOT_OK(
StoreSamplesForAudioElementId(audio_frames, decoded_audio_frames,
demixing_metadata.substream_id_to_labels,
labeled_frame, labeled_decoded_frame));
RETURN_IF_NOT_OK(ApplyDemixers(demixing_metadata.demixers, labeled_frame,
labeled_decoded_frame));

for (const auto& [label, samples] : labeled_frame.label_to_samples) {
const auto& decoded_samples =
labeled_decoded_frame.label_to_samples.at(label);
LOG(INFO) << " Channel " << label << ":\tframe size= " << samples.size()
<< "; decoded frame size= " << decoded_samples.size();
}
}

Expand Down
17 changes: 10 additions & 7 deletions iamf/cli/demixing_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,13 @@ struct LabeledFrame {
DownMixingParams demixing_params;
};

// Mapping from starting timestamp to a `LabeledFrame`.
// Mapping from audio element ids to `LabeledFrame`s.
typedef absl::flat_hash_map<DecodedUleb128, LabeledFrame> IdLabeledFrameMap;

// Mapping from starting timestamps to `LabeledFrame`s.
typedef absl::btree_map<int32_t, LabeledFrame> TimeLabeledFrameMap;

// Mapping from audio element id to a `TimeLabeledFrameMap`.
// Mapping from audio element ids to `TimeLabeledFrameMap`s.
typedef absl::flat_hash_map<DecodedUleb128, TimeLabeledFrameMap>
IdTimeLabeledFrameMap;

Expand Down Expand Up @@ -162,16 +165,16 @@ class DemixingModule {
*
* \param audio_frames Audio Frames.
* \param decoded_audio_frames Decoded Audio Frames.
* \param id_to_time_to_labeled_frame Output data structure for samples.
* \param id_to_time_to_labeled_decoded_frame Output data structure for
* decoded samples.
* \param id_to_labeled_frame Output data structure for samples.
* \param id_to_labeled_decoded_frame Output data structure for decoded
* samples.
* \return `absl::OkStatus()` on success. A specific status on failure.
*/
absl::Status DemixAudioSamples(
const std::list<AudioFrameWithData>& audio_frames,
const std::list<DecodedAudioFrame>& decoded_audio_frames,
IdTimeLabeledFrameMap& id_to_time_to_labeled_frame,
IdTimeLabeledFrameMap& id_to_time_to_labeled_decoded_frame) const;
IdLabeledFrameMap& id_to_labeled_frame,
IdLabeledFrameMap& id_to_labeled_decoded_frame) const;

/*\!brief Gets the down-mixers associated with an Audio Element ID.
*
Expand Down
25 changes: 19 additions & 6 deletions iamf/cli/encoder_main_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,19 +309,32 @@ absl::Status GenerateObus(
while (audio_frame_generator.GeneratingFrames()) {
std::list<AudioFrameWithData> temp_audio_frames;
RETURN_IF_NOT_OK(audio_frame_generator.OutputFrames(temp_audio_frames));

if (temp_audio_frames.empty()) {
absl::SleepFor(absl::Milliseconds(50));
} else {
// Decode the audio frames. They are required to determine the demixed
// frames.
std::list<DecodedAudioFrame> decoded_audio_frames;
RETURN_IF_NOT_OK(
audio_frame_decoder.Decode(temp_audio_frames, decoded_audio_frames));
std::list<DecodedAudioFrame> temp_decoded_audio_frames;
RETURN_IF_NOT_OK(audio_frame_decoder.Decode(temp_audio_frames,
temp_decoded_audio_frames));

// Demix the audio frames.
IdLabeledFrameMap id_to_labeled_frame;
IdLabeledFrameMap id_to_labeled_decoded_frame;
RETURN_IF_NOT_OK(demixing_module.DemixAudioSamples(
temp_audio_frames, decoded_audio_frames, id_to_time_to_labeled_frame,
id_to_time_to_labeled_decoded_frame));
temp_audio_frames, temp_decoded_audio_frames, id_to_labeled_frame,
id_to_labeled_decoded_frame));

// Collect and organize in time.
const auto start_timestamp = temp_audio_frames.front().start_timestamp;
for (const auto& [id, labeled_frame] : id_to_labeled_frame) {
id_to_time_to_labeled_frame[id][start_timestamp] = labeled_frame;
}
for (const auto& [id, labeled_decoded_frame] :
id_to_labeled_decoded_frame) {
id_to_time_to_labeled_decoded_frame[id][start_timestamp] =
labeled_decoded_frame;
}

audio_frames.splice(audio_frames.end(), temp_audio_frames);
}
Expand Down
43 changes: 15 additions & 28 deletions iamf/cli/tests/demixing_module_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -676,9 +676,7 @@ class DemixingModuleTest : public DemixingModuleTestBase,
.decoded_samples = raw_samples});

auto& expected_label_to_samples =
expected_id_to_time_to_labeled_decoded_frame_[kAudioElementId]
[kStartTimestamp]
.label_to_samples;
expected_id_to_labeled_decoded_frame_[kAudioElementId].label_to_samples;
// `raw_samples` is arranged in (time, channel axes). Arrange the samples
// associated with each channel by time. The demixing process never changes
// data for the input labels.
Expand All @@ -698,42 +696,32 @@ class DemixingModuleTest : public DemixingModuleTestBase,
const std::string& label, std::vector<int32_t> expected_demixed_samples) {
// Configure the expected demixed channels. Typically the input `label`
// should have a "D_" prefix.
expected_id_to_time_to_labeled_decoded_frame_[kAudioElementId]
[kStartTimestamp]
.label_to_samples[label] =
expected_demixed_samples;
expected_id_to_labeled_decoded_frame_[kAudioElementId]
.label_to_samples[label] = expected_demixed_samples;
}

void TestDemixing(int expected_number_of_down_mixers) {
IdTimeLabeledFrameMap unused_id_to_time_to_labeled_frame,
id_to_time_to_labeled_decoded_frame;
IdLabeledFrameMap unused_id_to_labeled_frame, id_to_labeled_decoded_frame;

TestCreateDemixingModule(expected_number_of_down_mixers);

EXPECT_TRUE(demixing_module_
.DemixAudioSamples(audio_frames_, decoded_audio_frames_,
unused_id_to_time_to_labeled_frame,
id_to_time_to_labeled_decoded_frame)
unused_id_to_labeled_frame,
id_to_labeled_decoded_frame)
.ok());

// Check that the demixed samples have the correct values.
EXPECT_EQ(
id_to_time_to_labeled_decoded_frame[kAudioElementId].size(),
expected_id_to_time_to_labeled_decoded_frame_[kAudioElementId].size());
for (const auto& [time, labeled_frame] :
id_to_time_to_labeled_decoded_frame[kAudioElementId]) {
EXPECT_EQ(
labeled_frame.label_to_samples,
expected_id_to_time_to_labeled_decoded_frame_[kAudioElementId][time]
.label_to_samples);
}
EXPECT_EQ(id_to_labeled_decoded_frame[kAudioElementId].label_to_samples,
expected_id_to_labeled_decoded_frame_[kAudioElementId]
.label_to_samples);
}

protected:
std::list<AudioFrameWithData> audio_frames_;
std::list<DecodedAudioFrame> decoded_audio_frames_;

IdTimeLabeledFrameMap expected_id_to_time_to_labeled_decoded_frame_;
IdLabeledFrameMap expected_id_to_labeled_decoded_frame_;

private:
const int32_t kStartTimestamp = 0;
Expand All @@ -751,18 +739,17 @@ TEST_F(DemixingModuleTest, DemixingAudioSamplesSucceedsWithEmptyInputs) {
.ok());

// Call `DemixAudioSamples()`.
IdTimeLabeledFrameMap id_to_time_to_labeled_frame,
id_to_time_to_labeled_decoded_frame;
IdLabeledFrameMap id_to_labeled_frame, id_to_labeled_decoded_frame;
EXPECT_TRUE(demixing_module_
.DemixAudioSamples(
/*audio_frames=*/{},
/*decoded_audio_frames=*/{}, id_to_time_to_labeled_frame,
id_to_time_to_labeled_decoded_frame)
/*decoded_audio_frames=*/{}, id_to_labeled_frame,
id_to_labeled_decoded_frame)
.ok());

// Expect empty outputs.
EXPECT_TRUE(id_to_time_to_labeled_frame.empty());
EXPECT_TRUE(id_to_time_to_labeled_decoded_frame.empty());
EXPECT_TRUE(id_to_labeled_frame.empty());
EXPECT_TRUE(id_to_labeled_decoded_frame.empty());
}

TEST_F(DemixingModuleTest, AmbisonicsHasNoDemixers) {
Expand Down

0 comments on commit df7b8f7

Please sign in to comment.