
Commit

Merge pull request google-ai-edge#5673 from priankakariatyml:ios-audio-embedder-updated-tests

PiperOrigin-RevId: 684941232
copybara-github committed Oct 11, 2024
2 parents ac0542f + e4c34c8 commit e5067b2
Showing 9 changed files with 567 additions and 48 deletions.
@@ -126,6 +126,22 @@ NS_SWIFT_NAME(AudioEmbedder)
error:(NSError **)error
NS_SWIFT_NAME(classifyAsync(audioBlock:timestampInMilliseconds:));

/**
* Closes and cleans up the MediaPipe audio embedder.
*
* For audio embedders initialized with `.audioStream` mode, ensure that this method is called
* after all audio blocks in an audio stream are sent for inference using
* `embedAsync(audioBlock:timestampInMilliseconds:)`. Otherwise, the audio embedder will not
* process the last audio block (of type `AudioData`) in the stream if its `bufferLength` is shorter
* than the model's input length. Once an audio embedder is closed, you cannot send any inference
* requests to it. You must create a new instance of `AudioEmbedder` to send any pending requests.
 * Ensure that you are ready to dispose of the audio embedder before this method is invoked.
*
 * @return Returns `YES` if the task was closed successfully. Otherwise, returns `NO` and
 * populates `error` with the reason for failure.
*/
- (BOOL)closeWithError:(NSError **)error;

- (instancetype)init NS_UNAVAILABLE;

/**
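The `closeWithError:` documentation above describes the intended stream-mode lifecycle. Below is a minimal Objective-C sketch of that flow; the umbrella header name, the stream-delegate property, the `embedAsync` selector spelling, and the `MPPTimestampedAudioData` property names are assumptions inferred by analogy with the audio classifier API that appears later in this diff, not confirmed by this commit.

#import <MediaPipeTasksAudio/MediaPipeTasksAudio.h>  // umbrella header name assumed

// Hypothetical stream-mode lifecycle: configure, stream audio blocks, then close.
static void RunStreamingEmbedder(NSString *modelPath,
                                 NSArray<MPPTimestampedAudioData *> *audioBlocks,
                                 id delegate) {
  MPPAudioEmbedderOptions *options = [[MPPAudioEmbedderOptions alloc] init];
  options.baseOptions.modelAssetPath = modelPath;
  options.runningMode = MPPAudioRunningModeAudioStream;
  options.audioEmbedderStreamDelegate = delegate;  // assumed, by analogy with the classifier delegate

  NSError *error = nil;
  MPPAudioEmbedder *embedder = [[MPPAudioEmbedder alloc] initWithOptions:options error:&error];

  for (MPPTimestampedAudioData *block in audioBlocks) {
    // Selector spelling assumed from the Swift name embedAsync(audioBlock:timestampInMilliseconds:).
    [embedder embedAsyncAudioBlock:block.audioData
           timestampInMilliseconds:block.timestampInMilliseconds
                             error:&error];
  }

  // Closing flushes a final audio block that is shorter than the model's input length.
  // After closing, the embedder cannot accept further inference requests.
  [embedder closeWithError:&error];
}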
@@ -72,7 +72,7 @@ - (instancetype)initWithOptions:(MPPAudioEmbedderOptions *)options error:(NSErro
kTimestampedEmbeddingsOutStreamName]
]
taskOptions:options
enableFlowLimiting:options.runningMode == MPPAudioRunningModeAudioStream
enableFlowLimiting:NO
error:error];

if (!taskInfo) {
@@ -154,6 +154,10 @@ + (MPPAudioRecord *)createAudioRecordWithChannelCount:(NSUInteger)channelCount
error:error];
}

- (BOOL)closeWithError:(NSError **)error {
return [_audioTaskRunner closeWithError:error];
}

#pragma mark - Private

- (void)processAudioStreamResult:(absl::StatusOr<PacketMap>)audioStreamResult {
@@ -31,22 +31,27 @@ + (MPPAudioEmbedderResult *)audioEmbedderResultWithEmbeddingResultPacket:(const
NSInteger timestampInMilliseconds =
(NSInteger)(packet.Timestamp().Value() / kMicrosecondsPerMillisecond);

if (!packet.ValidateAsType<std::vector<EmbeddingResultProto>>().ok()) {
// MPPAudioEmbedderResult's timestamp is populated from timestamp `EmbeddingResultProto`'s
// timestamp_ms(). It is 0 since the packet can't be validated as a `EmbeddingResultProto`.
std::vector<EmbeddingResultProto> cppEmbeddingResults;
if (packet.ValidateAsType<EmbeddingResultProto>().ok()) {
// If `runningMode = .audioStream`, only a single `EmbeddingResult` will be returned in the
// result packet.
cppEmbeddingResults.emplace_back(packet.Get<EmbeddingResultProto>());
} else if (packet.ValidateAsType<std::vector<EmbeddingResultProto>>().ok()) {
// If `runningMode = .audioClips`, a vector of timestamped `EmbeddingResult`s will be
// returned in the result packet.
cppEmbeddingResults = packet.Get<std::vector<EmbeddingResultProto>>();
} else {
// The packet does not contain a protobuf of a type expected by the audio embedder.
return [[MPPAudioEmbedderResult alloc] initWithEmbeddingResults:nil
timestampInMilliseconds:timestampInMilliseconds];
}

std::vector<EmbeddingResultProto> cppEmbeddingResultProtos =
packet.Get<std::vector<EmbeddingResultProto>>();

NSMutableArray<MPPEmbeddingResult *> *embeddingResults =
[NSMutableArray arrayWithCapacity:cppEmbeddingResultProtos.size()];
[NSMutableArray arrayWithCapacity:cppEmbeddingResults.size()];

for (const auto &cppEmbeddingResultProto : cppEmbeddingResultProtos) {
for (const auto &cppEmbeddingResult : cppEmbeddingResults) {
MPPEmbeddingResult *embeddingResult =
[MPPEmbeddingResult embeddingResultWithProto:cppEmbeddingResultProto];
[MPPEmbeddingResult embeddingResultWithProto:cppEmbeddingResult];
[embeddingResults addObject:embeddingResult];
}

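The reworked helper above now accepts either a single `EmbeddingResultProto` (audio stream mode) or a vector of timestamped protos (audio clips mode) and wraps both in one `MPPAudioEmbedderResult`. A consumer-side sketch follows; the `embeddingResults` and `timestampInMilliseconds` property names are inferred from the initializer used in the helper and may differ.

// Hedged sketch: inspect a result produced by the helper above.
static void LogEmbedderResult(MPPAudioEmbedderResult *result) {
  // In `.audioClips` mode this array typically holds several timestamped entries;
  // in `.audioStream` mode each delegate callback delivers exactly one.
  NSLog(@"Result at %ld ms contains %lu embedding result(s)",
        (long)result.timestampInMilliseconds,
        (unsigned long)result.embeddingResults.count);
}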
@@ -352,8 +352,7 @@ - (void)testClassifyFailsWithCallingWrongApiInAudioClipsMode {
MPPAudioClassifier *audioClassifier =
[MPPAudioClassifierTests audioClassifierWithOptions:options];

MPPAudioData *audioClip =
[MPPAudioClassifierTests audioDataFromAudioFileWithInfo:kSpeech16KHzMonoFileInfo];
MPPAudioData *audioClip = [[MPPAudioData alloc] initWithFileInfo:kSpeech16KHzMonoFileInfo];
NSError *error;
XCTAssertFalse([audioClassifier classifyAsyncAudioBlock:audioClip
timestampInMilliseconds:0
@@ -380,8 +379,7 @@ - (void)testClassifyFailsWithCallingWrongApiInAudioStreamMode {
MPPAudioClassifier *audioClassifier =
[MPPAudioClassifierTests audioClassifierWithOptions:options];

MPPAudioData *audioClip =
[MPPAudioClassifierTests audioDataFromAudioFileWithInfo:kSpeech16KHzMonoFileInfo];
MPPAudioData *audioClip = [[MPPAudioData alloc] initWithFileInfo:kSpeech16KHzMonoFileInfo];

NSError *error;
XCTAssertFalse([audioClassifier classifyAudioClip:audioClip error:&error]);
@@ -562,38 +560,6 @@ - (void)classifyUsingYamnetAsyncAudioFileWithInfo:(MPPFileInfo *)audioFileInfo

#pragma mark Audio Data Initializers

+ (MPPAudioData *)audioDataFromAudioFileWithInfo:(MPPFileInfo *)fileInfo {
// Load the samples from the audio file in `Float32` interleaved format to
// an `AVAudioPCMBuffer`.
AVAudioPCMBuffer *buffer =
[AVAudioPCMBuffer interleavedFloat32BufferFromAudioFileWithInfo:fileInfo];

// Create a float buffer from the `floatChannelData` of `AVAudioPCMBuffer`. This float buffer will
// be used to load the audio data.
MPPFloatBuffer *bufferData = [[MPPFloatBuffer alloc] initWithData:buffer.floatChannelData[0]
length:buffer.frameLength];

MPPAudioData *audioData = [[MPPAudioData alloc] initWithChannelCount:buffer.format.channelCount
sampleRate:buffer.format.sampleRate
sampleCount:buffer.frameLength];

// Load all the samples in the audio file to the newly created audio data.
[audioData loadBuffer:bufferData offset:0 length:bufferData.length error:nil];
return audioData;
}

+ (MPPAudioData *)audioDataWithChannelCount:(NSUInteger)channelCount
sampleRate:(double)sampleRate
sampleCount:(NSUInteger)sampleCount {
MPPAudioDataFormat *audioDataFormat =
[[MPPAudioDataFormat alloc] initWithChannelCount:channelCount sampleRate:sampleRate];

MPPAudioData *audioData = [[MPPAudioData alloc] initWithFormat:audioDataFormat
sampleCount:sampleCount];

return audioData;
}

+ (NSArray<MPPTimestampedAudioData *> *)streamedAudioDataListforYamnet {
NSArray<MPPTimestampedAudioData *> *streamedAudioDataList =
[AVAudioFile streamedAudioBlocksFromAudioFileWithInfo:kSpeech16KHzMonoFileInfo
@@ -609,7 +575,7 @@ + (MPPAudioData *)audioDataWithChannelCount:(NSUInteger)channelCount

- (MPPAudioClassifier *)audioClassifierInStreamModeWithModelFileInfo:(MPPFileInfo *)fileInfo {
MPPAudioClassifierOptions *options =
[MPPAudioClassifierTests audioClassifierOptionsWithModelFileInfo:kYamnetModelFileInfo];
[MPPAudioClassifierTests audioClassifierOptionsWithModelFileInfo:fileInfo];
options.runningMode = MPPAudioRunningModeAudioStream;
options.audioClassifierStreamDelegate = self;

@@ -701,7 +667,7 @@ + (void)assertResultsOfClassifyAudioClipWithFileInfo:(MPPFileInfo *)fileInfo

+ (MPPAudioClassifierResult *)classifyAudioClipWithFileInfo:(MPPFileInfo *)fileInfo
usingAudioClassifier:(MPPAudioClassifier *)audioClassifier {
MPPAudioData *audioData = [MPPAudioClassifierTests audioDataFromAudioFileWithInfo:fileInfo];
MPPAudioData *audioData = [[MPPAudioData alloc] initWithFileInfo:fileInfo];
MPPAudioClassifierResult *result = [audioClassifier classifyAudioClip:audioData error:nil];
XCTAssertNotNil(result);

63 changes: 63 additions & 0 deletions mediapipe/tasks/ios/test/audio/audio_embedder/BUILD
@@ -0,0 +1,63 @@
load(
"@build_bazel_rules_apple//apple:ios.bzl",
"ios_unit_test",
)
load(
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
"tflite_ios_lab_runner",
)
load(
"//mediapipe/framework/tool:ios.bzl",
"MPP_TASK_MINIMUM_OS_VERSION",
)

package(default_visibility = ["//visibility:public"])

licenses(["notice"])

# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
TFL_DEFAULT_TAGS = [
"apple",
]

# Following sanitizer tests are not supported by iOS test targets.
TFL_DISABLED_SANITIZER_TAGS = [
"noasan",
"nomsan",
"notsan",
]

objc_library(
name = "MPPAudioEmbedderObjcTestLibrary",
testonly = 1,
srcs = ["MPPAudioEmbedderTests.mm"],
copts = [
"-ObjC++",
"-std=c++17",
"-x objective-c++",
],
data = [
"//mediapipe/tasks/testdata/audio:test_audio_clips",
"//mediapipe/tasks/testdata/audio:test_models",
],
deps = [
"//mediapipe/tasks/ios/audio/audio_embedder:MPPAudioEmbedder",
"//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/ios/test/audio/core/utils:AVAudioFileTestUtils",
"//mediapipe/tasks/ios/test/audio/core/utils:AVAudioPCMBufferTestUtils",
"//mediapipe/tasks/ios/test/audio/core/utils:MPPAudioDataTestUtils",
"//mediapipe/tasks/ios/test/utils:MPPFileInfo",
"//third_party/apple_frameworks:XCTest",
],
)

ios_unit_test(
name = "MPPAudioEmbedderObjcTest",
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
deps = [
":MPPAudioEmbedderObjcTestLibrary",
],
)
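
The new `ios_unit_test` target can presumably be invoked with Bazel as `bazel test //mediapipe/tasks/ios/test/audio/audio_embedder:MPPAudioEmbedderObjcTest`, subject to the repository's iOS build configuration (the Xcode toolchain and any required --config flags are not shown in this diff).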