Skip to content

Commit

Permalink
Desynchronize demixing for original and decoded frames.
Browse files Browse the repository at this point in the history
  - Remove restrictions that original and decoded samples align.
    - This is only relevant for users that have the original samples.
    - Restriction checking belongs with the user that cares about the frames being synchronized, such as the `ReconGainGenerator`; otherwise it is not really relevant when there already is `DecodedAudioFrame::down_mixing_params`.
    - This helps move towards an interface suitable for users which only care about audio element reconstruction.
  - Update processing functions to handle only one type of frame at a time.
    - Update `StoreSamplesForAudioElementId` to be a template function which handles original or decoded audio frames.
    - Update `ApplyDemixers` to work only for one `TimeLabeledFrameMap` at a time.

PiperOrigin-RevId: 633259446
  • Loading branch information
jwcullen committed May 14, 2024
1 parent e931620 commit 8ff0ccc
Show file tree
Hide file tree
Showing 2 changed files with 289 additions and 77 deletions.
174 changes: 102 additions & 72 deletions iamf/cli/demixing_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <cstdint>
#include <list>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/no_destructor.h"
Expand Down Expand Up @@ -728,74 +729,81 @@ absl::Status FillRequiredDemixingMetadata(
return absl::OkStatus();
}

// Copies the frame-level metadata of an original (pre-encoding) audio frame
// into `labeled_frame`: the end timestamp, the start/end trim counts stored in
// the OBU header, and the down-mixing parameters.
void ConfigureLabeledFrame(const AudioFrameWithData& audio_frame,
                           LabeledFrame& labeled_frame) {
  // The trim counts for an original frame live in the OBU header.
  const auto& obu_header = audio_frame.obu.header_;

  labeled_frame.demixing_params = audio_frame.down_mixing_params;
  labeled_frame.samples_to_trim_at_start =
      obu_header.num_samples_to_trim_at_start;
  labeled_frame.samples_to_trim_at_end = obu_header.num_samples_to_trim_at_end;
  labeled_frame.end_timestamp = audio_frame.end_timestamp;
}

void ConfigureLabeledFrame(const DecodedAudioFrame& decoded_audio_frame,
LabeledFrame& labeled_decoded_frame) {
labeled_decoded_frame.end_timestamp = decoded_audio_frame.end_timestamp;
labeled_decoded_frame.samples_to_trim_at_end =
decoded_audio_frame.samples_to_trim_at_end;
labeled_decoded_frame.samples_to_trim_at_start =
decoded_audio_frame.samples_to_trim_at_start;
labeled_decoded_frame.demixing_params =
decoded_audio_frame.down_mixing_params;
}

// Returns the substream ID of an original audio frame, as reported by its OBU.
uint32_t GetSubstreamId(const AudioFrameWithData& audio_frame_with_data) {
  const auto& obu = audio_frame_with_data.obu;
  return obu.GetSubstreamId();
}

// Returns the substream ID stored directly on a decoded audio frame.
uint32_t GetSubstreamId(const DecodedAudioFrame& audio_frame_with_data) {
  const uint32_t substream_id = audio_frame_with_data.substream_id;
  return substream_id;
}

// Returns a reference to the raw (original) samples of an audio frame. The
// reference is only valid while `audio_frame_with_data` is alive.
const std::vector<std::vector<int32_t>>& GetSamples(
    const AudioFrameWithData& audio_frame_with_data) {
  const auto& raw_samples = audio_frame_with_data.raw_samples;
  return raw_samples;
}

// Returns a reference to the decoded samples of a decoded audio frame. The
// reference is only valid while `audio_frame_with_data` is alive.
const std::vector<std::vector<int32_t>>& GetSamples(
    const DecodedAudioFrame& audio_frame_with_data) {
  const auto& decoded_samples = audio_frame_with_data.decoded_samples;
  return decoded_samples;
}

// TODO(b/339037792): Unify `AudioFrameWithData` and `DecodedAudioFrame`.
template <typename T>
absl::Status StoreSamplesForAudioElementId(
const std::list<AudioFrameWithData>& audio_frames,
const std::list<DecodedAudioFrame>& decoded_audio_frames,
const std::list<T>& audio_frames_or_decoded_audio_frames,
const SubstreamIdLabelsMap& substream_id_to_labels,
LabeledFrame& labeled_frame, LabeledFrame& labeled_decoded_frame) {
auto audio_frame_iter = audio_frames.begin();
auto decoded_audio_frame_iter = decoded_audio_frames.begin();
const int32_t common_start_timestamp = audio_frame_iter->start_timestamp;
for (; audio_frame_iter != audio_frames.end() &&
decoded_audio_frame_iter != decoded_audio_frames.end();
audio_frame_iter++, decoded_audio_frame_iter++) {
const auto substream_id = audio_frame_iter->obu.GetSubstreamId();
LabeledFrame& labeled_frame) {
if (audio_frames_or_decoded_audio_frames.empty()) {
return absl::OkStatus();
}
const int32_t common_start_timestamp =
audio_frames_or_decoded_audio_frames.begin()->start_timestamp;

for (auto& audio_frame : audio_frames_or_decoded_audio_frames) {
const auto substream_id = GetSubstreamId(audio_frame);
auto substream_id_labels_iter = substream_id_to_labels.find(substream_id);
if (substream_id_labels_iter == substream_id_to_labels.end()) {
// This audio frame might belong to a different audio element; skip it.
continue;
}
if (decoded_audio_frame_iter->substream_id != substream_id) {
LOG(ERROR) << "Substream ID mismatch: " << substream_id << " vs "
<< decoded_audio_frame_iter->substream_id;
return absl::InvalidArgumentError("");
}
const auto& labels = substream_id_labels_iter->second;
// TODO(b/339037792): Remove dependency on `raw_samples`.
if (audio_frame_iter->raw_samples[0].size() != labels.size() ||
decoded_audio_frame_iter->decoded_samples[0].size() != labels.size()) {
LOG(ERROR) << "Channel number mismatch: "
<< audio_frame_iter->raw_samples[0].size() << " vs "
<< decoded_audio_frame_iter->decoded_samples[0].size()
<< " vs " << labels.size();
return absl::InvalidArgumentError("");
}

// Validate that the frames are all aligned in time.
RETURN_IF_NOT_OK(CompareTimestamps(common_start_timestamp,
audio_frame_iter->start_timestamp));
RETURN_IF_NOT_OK(CompareTimestamps(
common_start_timestamp, decoded_audio_frame_iter->start_timestamp));
RETURN_IF_NOT_OK(
CompareTimestamps(common_start_timestamp, audio_frame.start_timestamp));

const auto& labels = substream_id_labels_iter->second;
int channel_index = 0;
for (const auto& label : labels) {
const size_t num_ticks = audio_frame_iter->raw_samples.size();

labeled_frame.end_timestamp = audio_frame_iter->end_timestamp;
labeled_frame.samples_to_trim_at_end =
audio_frame_iter->obu.header_.num_samples_to_trim_at_end;
labeled_frame.samples_to_trim_at_start =
audio_frame_iter->obu.header_.num_samples_to_trim_at_start;
labeled_frame.demixing_params = audio_frame_iter->down_mixing_params;

labeled_decoded_frame.end_timestamp =
decoded_audio_frame_iter->end_timestamp;
labeled_decoded_frame.samples_to_trim_at_end =
decoded_audio_frame_iter->samples_to_trim_at_end;
labeled_decoded_frame.samples_to_trim_at_start =
decoded_audio_frame_iter->samples_to_trim_at_start;
labeled_decoded_frame.demixing_params =
decoded_audio_frame_iter->down_mixing_params;
const auto& input_samples = GetSamples(audio_frame);
const size_t num_ticks = input_samples.size();

auto& samples = labeled_frame.label_to_samples[label];
auto& decoded_samples = labeled_decoded_frame.label_to_samples[label];
ConfigureLabeledFrame(audio_frame, labeled_frame);

auto& samples = labeled_frame.label_to_samples[label];
samples.resize(num_ticks, 0);
decoded_samples.resize(num_ticks, 0);
for (int t = 0; t < audio_frame_iter->raw_samples.size(); t++) {
samples[t] = audio_frame_iter->raw_samples[t][channel_index];
decoded_samples[t] =
decoded_audio_frame_iter->decoded_samples[t][channel_index];
for (int t = 0; t < samples.size(); t++) {
samples[t] = input_samples[t][channel_index];
}
channel_index++;
}
Expand All @@ -805,13 +813,10 @@ absl::Status StoreSamplesForAudioElementId(
}

// Runs every demixer in `demixers`, in list order, on a labeled frame. Each
// demixer is called with that frame's `demixing_params` and a pointer to its
// `label_to_samples` map; `absl::OkStatus()` is returned after the loop.
// NOTE(review): `RETURN_IF_NOT_OK` presumably returns early with the first
// non-OK status - confirm against the macro definition.
// NOTE(review): this span appears to interleave removed diff lines (the
// two-frame form taking `labeled_decoded_frame`) with added ones (the
// single-frame form) - confirm against the checked-in file.
absl::Status ApplyDemixers(const std::list<Demixer>& demixers,
LabeledFrame& labeled_frame,
LabeledFrame& labeled_decoded_frame) {
LabeledFrame& labeled_frame) {
for (const auto& demixer : demixers) {
RETURN_IF_NOT_OK(demixer(labeled_frame.demixing_params,
&labeled_frame.label_to_samples));
RETURN_IF_NOT_OK(demixer(labeled_decoded_frame.demixing_params,
&labeled_decoded_frame.label_to_samples));
}
return absl::OkStatus();
}
Expand Down Expand Up @@ -901,6 +906,24 @@ absl::StatusOr<absl::flat_hash_set<std::string>> LookupLabelsToReconstruct(
}
}

// Logs, for every channel label of `audio_element_id`, the number of samples
// held by the original labeled frame and by the decoded labeled frame.
//
// Args:
//   audio_element_id: ID of the audio element to log.
//   id_to_labeled_frame: Labeled frames built from the original audio frames.
//   id_to_labeled_decoded_frame: Labeled frames built from the decoded audio
//       frames.
//
// Nothing is logged unless BOTH maps hold an entry for `audio_element_id`.
// The original and decoded frames are processed independently, so either map
// may lack the audio element, and a label present in the original frame may be
// absent from the decoded one; such labels are skipped instead of being looked
// up with a throwing/aborting `.at()`.
void LogForAudioElementId(
    DecodedUleb128 audio_element_id,
    const IdLabeledFrameMap& id_to_labeled_frame,
    const IdLabeledFrameMap& id_to_labeled_decoded_frame) {
  const auto labeled_frame_iter = id_to_labeled_frame.find(audio_element_id);
  const auto labeled_decoded_frame_iter =
      id_to_labeled_decoded_frame.find(audio_element_id);
  // Both frames are required to print the side-by-side comparison. The
  // previous check was missing the negation on the decoded map, so it bailed
  // out exactly when logging was possible and fell through to `.at()` when the
  // decoded entry was missing.
  if (labeled_frame_iter == id_to_labeled_frame.end() ||
      labeled_decoded_frame_iter == id_to_labeled_decoded_frame.end()) {
    return;
  }

  const auto& decoded_label_to_samples =
      labeled_decoded_frame_iter->second.label_to_samples;
  for (const auto& [label, samples] :
       labeled_frame_iter->second.label_to_samples) {
    const auto decoded_samples_iter = decoded_label_to_samples.find(label);
    if (decoded_samples_iter == decoded_label_to_samples.end()) {
      // No decoded counterpart for this label; there is nothing to compare.
      continue;
    }
    LOG(INFO) << " Channel " << label << ":\tframe size= " << samples.size()
              << "; decoded frame size= "
              << decoded_samples_iter->second.size();
  }
}

} // namespace

absl::Status DemixingModule::FindSamplesOrDemixedSamples(
Expand Down Expand Up @@ -1054,22 +1077,29 @@ absl::Status DemixingModule::DemixAudioSamples(
IdLabeledFrameMap& id_to_labeled_decoded_frame) const {
for (const auto& [audio_element_id, demixing_metadata] :
audio_element_id_to_demixing_metadata_) {
auto& labeled_frame = id_to_labeled_frame[audio_element_id];
auto& labeled_decoded_frame = id_to_labeled_decoded_frame[audio_element_id];

RETURN_IF_NOT_OK(
StoreSamplesForAudioElementId(audio_frames, decoded_audio_frames,
demixing_metadata.substream_id_to_labels,
labeled_frame, labeled_decoded_frame));
RETURN_IF_NOT_OK(ApplyDemixers(demixing_metadata.demixers, labeled_frame,
labeled_decoded_frame));

for (const auto& [label, samples] : labeled_frame.label_to_samples) {
const auto& decoded_samples =
labeled_decoded_frame.label_to_samples.at(label);
LOG(INFO) << " Channel " << label << ":\tframe size= " << samples.size()
<< "; decoded frame size= " << decoded_samples.size();
// Process the original audio frames.
LabeledFrame labeled_frame;
RETURN_IF_NOT_OK(StoreSamplesForAudioElementId(
audio_frames, demixing_metadata.substream_id_to_labels, labeled_frame));
if (!labeled_frame.label_to_samples.empty()) {
RETURN_IF_NOT_OK(
ApplyDemixers(demixing_metadata.demixers, labeled_frame));
id_to_labeled_frame[audio_element_id] = std::move(labeled_frame);
}
// Process the decoded audio frames.
LabeledFrame labeled_decoded_frame;
RETURN_IF_NOT_OK(StoreSamplesForAudioElementId(
decoded_audio_frames, demixing_metadata.substream_id_to_labels,
labeled_decoded_frame));
if (!labeled_decoded_frame.label_to_samples.empty()) {
RETURN_IF_NOT_OK(
ApplyDemixers(demixing_metadata.demixers, labeled_decoded_frame));
id_to_labeled_decoded_frame[audio_element_id] =
std::move(labeled_decoded_frame);
}

LogForAudioElementId(audio_element_id, id_to_labeled_frame,
id_to_labeled_decoded_frame);
}

return absl::OkStatus();
Expand Down
Loading

0 comments on commit 8ff0ccc

Please sign in to comment.