Skip to content

Commit

Permalink
fix: fix implement
Browse files Browse the repository at this point in the history
  • Loading branch information
hans00 committed Oct 21, 2024
1 parent 4d1324f commit 8a945e5
Show file tree
Hide file tree
Showing 16 changed files with 1,407 additions and 83 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,6 @@ lib/
# React Native Codegen
ios/generated
android/generated

# SPM
.spm.pods/
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ target "YourApp" do
# Add these lines
spm_pkg "bark",
:url => "https://github.com/PABannier/bark.cpp.git",
:branch => "1.0.0",
:branch => "main",
:products => ["bark"]

# spm_pkg should be before use_native_modules!
Expand Down
1 change: 1 addition & 0 deletions android/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ FetchContent_MakeAvailable(BARK_CPP)

add_library(bark-rn SHARED
../cpp/utils.cpp
../cpp/dr_wav.h
cpp-adapter.cpp
)

Expand Down
119 changes: 85 additions & 34 deletions android/cpp-adapter.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
#include "utils.h"
#include "bark.h"
#include <jni.h>
#include <thread>
#include <tuple>
#include <type_traits>

template <typename T>
T get_map_value(JNIEnv *env, jobject params, const char *key) {
jclass map_class = env->FindClass("java/util/Map");
jmethodID get_method = env->GetMethodID(
map_class, "get", "(Ljava/lang/Object;)Ljava/lang/Object;");
return env->CallObjectMethod(params, get_method, env->NewStringUTF(key));
jobject value = env->CallObjectMethod(params, get_method, env->NewStringUTF(key));
if constexpr (std::is_same_v<T, jfloat>) {
jclass float_class = env->FindClass("java/lang/Float");
return env->CallFloatMethod(value, env->GetMethodID(float_class, "floatValue", "()F"));
} else if constexpr (std::is_same_v<T, jint>) {
jclass int_class = env->FindClass("java/lang/Integer");
return env->CallIntMethod(value, env->GetMethodID(int_class, "intValue", "()I"));
} else {
throw std::runtime_error("Unsupported type");
}
}

bool has_map_key(JNIEnv *env, jobject params, const char *key) {
Expand All @@ -18,37 +30,78 @@ bool has_map_key(JNIEnv *env, jobject params, const char *key) {
env->NewStringUTF(key));
}

#define RESOLVE_PARAM(key, cpp_type, java_type) \
if (has_map_key(env, jParams, #key)) { \
params.key = get_map_value<java_type>(env, jParams, #key); \
}

extern "C" JNIEXPORT jlong JNICALL Java_com_barkrn_BarkContext_nativeInitContext(
JNIEnv *env, jclass type, jstring jPath, jobject jParams) {
auto params = bark_context_default_params();
RESOLVE_PARAM(verbosity, bark_verbosity_level, jint);
RESOLVE_PARAM(temp, float, jfloat);
RESOLVE_PARAM(fine_temp, float, jfloat);
RESOLVE_PARAM(min_eos_p, float, jfloat);
RESOLVE_PARAM(sliding_window_size, int, jint);
RESOLVE_PARAM(max_coarse_history, int, jint);
RESOLVE_PARAM(sample_rate, int, jint);
RESOLVE_PARAM(target_bandwidth, int, jint);
RESOLVE_PARAM(cls_token_id, int, jint);
RESOLVE_PARAM(sep_token_id, int, jint);
RESOLVE_PARAM(n_steps_text_encoder, int, jint);
RESOLVE_PARAM(text_pad_token, int, jint);
RESOLVE_PARAM(text_encoding_offset, int, jint);
RESOLVE_PARAM(semantic_rate_hz, float, jfloat);
RESOLVE_PARAM(semantic_pad_token, int, jint);
RESOLVE_PARAM(semantic_vocab_size, int, jint);
RESOLVE_PARAM(semantic_infer_token, int, jint);
RESOLVE_PARAM(coarse_rate_hz, float, jfloat);
RESOLVE_PARAM(coarse_infer_token, int, jint);
RESOLVE_PARAM(coarse_semantic_pad_token, int, jint);
RESOLVE_PARAM(n_coarse_codebooks, int, jint);
RESOLVE_PARAM(n_fine_codebooks, int, jint);
RESOLVE_PARAM(codebook_size, int, jint);
if (has_map_key(env, jParams, "verbosity")) {
params.verbosity = static_cast<bark_verbosity_level>(get_map_value<jint>(env, jParams, "verbosity"));
}
if (has_map_key(env, jParams, "temp")) {
params.temp = get_map_value<jfloat>(env, jParams, "temp");
}
if (has_map_key(env, jParams, "fine_temp")) {
params.fine_temp = get_map_value<jfloat>(env, jParams, "fine_temp");
}
if (has_map_key(env, jParams, "min_eos_p")) {
params.min_eos_p = get_map_value<jfloat>(env, jParams, "min_eos_p");
}
if (has_map_key(env, jParams, "sliding_window_size")) {
params.sliding_window_size = get_map_value<jint>(env, jParams, "sliding_window_size");
}
if (has_map_key(env, jParams, "max_coarse_history")) {
params.max_coarse_history = get_map_value<jint>(env, jParams, "max_coarse_history");
}
if (has_map_key(env, jParams, "sample_rate")) {
params.sample_rate = get_map_value<jint>(env, jParams, "sample_rate");
}
if (has_map_key(env, jParams, "target_bandwidth")) {
params.target_bandwidth = get_map_value<jint>(env, jParams, "target_bandwidth");
}
if (has_map_key(env, jParams, "cls_token_id")) {
params.cls_token_id = get_map_value<jint>(env, jParams, "cls_token_id");
}
if (has_map_key(env, jParams, "sep_token_id")) {
params.sep_token_id = get_map_value<jint>(env, jParams, "sep_token_id");
}
if (has_map_key(env, jParams, "n_steps_text_encoder")) {
params.n_steps_text_encoder = get_map_value<jint>(env, jParams, "n_steps_text_encoder");
}
if (has_map_key(env, jParams, "text_pad_token")) {
params.text_pad_token = get_map_value<jint>(env, jParams, "text_pad_token");
}
if (has_map_key(env, jParams, "text_encoding_offset")) {
params.text_encoding_offset = get_map_value<jint>(env, jParams, "text_encoding_offset");
}
if (has_map_key(env, jParams, "semantic_rate_hz")) {
params.semantic_rate_hz = get_map_value<jfloat>(env, jParams, "semantic_rate_hz");
}
if (has_map_key(env, jParams, "semantic_pad_token")) {
params.semantic_pad_token = get_map_value<jint>(env, jParams, "semantic_pad_token");
}
if (has_map_key(env, jParams, "semantic_vocab_size")) {
params.semantic_vocab_size = get_map_value<jint>(env, jParams, "semantic_vocab_size");
}
if (has_map_key(env, jParams, "semantic_infer_token")) {
params.semantic_infer_token = get_map_value<jint>(env, jParams, "semantic_infer_token");
}
if (has_map_key(env, jParams, "coarse_rate_hz")) {
params.coarse_rate_hz = get_map_value<jfloat>(env, jParams, "coarse_rate_hz");
}
if (has_map_key(env, jParams, "coarse_infer_token")) {
params.coarse_infer_token = get_map_value<jint>(env, jParams, "coarse_infer_token");
}
if (has_map_key(env, jParams, "coarse_semantic_pad_token")) {
params.coarse_semantic_pad_token = get_map_value<jint>(env, jParams, "coarse_semantic_pad_token");
}
if (has_map_key(env, jParams, "n_coarse_codebooks")) {
params.n_coarse_codebooks = get_map_value<jint>(env, jParams, "n_coarse_codebooks");
}
if (has_map_key(env, jParams, "n_fine_codebooks")) {
params.n_fine_codebooks = get_map_value<jint>(env, jParams, "n_fine_codebooks");
}
if (has_map_key(env, jParams, "codebook_size")) {
params.codebook_size = get_map_value<jint>(env, jParams, "codebook_size");
}
int seed = 0;
if (has_map_key(env, jParams, "seed")) {
seed = get_map_value<jint>(env, jParams, "seed");
Expand All @@ -61,25 +114,23 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_barkrn_BarkContext_nativeInitContext

extern "C" JNIEXPORT jobject JNICALL Java_com_barkrn_BarkContext_nativeGenerate(
JNIEnv *env, jclass type, jlong jCtx, jstring jText, jstring jOutPath,
jint jThreads) {
jint jThreads, jint sample_rate) {
auto context = reinterpret_cast<bark_context *>(jCtx);
int threads = jThreads;
if (threads < 0) {
threads = std::thread::hardware_concurrency() << 1;
}
if (threads <= 0) {
} else if (threads == 0) {
threads = 1;
}
auto text = env->GetStringUTFChars(jText, nullptr);
auto success = bark_generate_audio(context, text, threads);
env->ReleaseStringUTFChars(jText, text);
const float *audio_data = bark_get_audio_data(context);
const int audio_samples = bark_get_audio_data_size(context);
const auto sample_rate = context->params.sample_rate;
if (success) {
auto dest_path = env->GetStringUTFChars(jOutPath, nullptr);
std::vector<float> audio_data_vec(audio_data, audio_data + audio_samples);
pcmToWav(audio_data_vec, sample_rate, dest_path);
barkrn::pcmToWav(audio_data_vec, sample_rate, dest_path);
env->ReleaseStringUTFChars(jOutPath, dest_path);
}
const auto load_time = bark_get_load_time(context);
Expand Down
12 changes: 10 additions & 2 deletions android/src/main/java/com/barkrn/BarkContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package com.barkrn

class BarkContext {
private var context: Long = 0L
protected var sample_rate: Int = 24000
protected var n_threads: Int = -1

class BarkResult(success: Boolean, load_time: Int, eval_time: Int) {
val success: Boolean = success
Expand All @@ -10,18 +12,24 @@ class BarkContext {
}

external fun nativeInitContext(model_path: String, params: Map<String, Any>): Long
external fun nativeGenerate(context: Long, text: String, out_path: String, threads: Int): BarkResult
external fun nativeGenerate(context: Long, text: String, out_path: String, threads: Int, sample_rate: Int): BarkResult
external fun nativeReleaseContext(context: Long)

constructor(model_path: String, params: Map<String, Any>) {
context = nativeInitContext(model_path, params)
if (params.containsKey("sample_rate")) {
sample_rate = params["sample_rate"] as Int
}
if (params.containsKey("n_threads")) {
n_threads = params["n_threads"] as Int
}
}

fun generate(text: String, out_path: String, threads: Int = 1): BarkResult {
if (context == 0L) {
throw IllegalStateException("Context not initialized")
}
return nativeGenerate(context, text, out_path, threads)
return nativeGenerate(context, text, out_path, n_threads, sample_rate)
}

fun release() {
Expand Down
4 changes: 2 additions & 2 deletions android/src/main/java/com/barkrn/BarkRnModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ class BarkRnModule internal constructor(context: ReactApplicationContext) :
}

@ReactMethod
override fun generate(id: Int, text: String, audio_path: String, threads: Int, promise: Promise) {
override fun generate(id: Int, text: String, audio_path: String, promise: Promise) {
contexts[id]?.let { context ->
val result = context.generate(text, audio_path, threads)
val result = context.generate(text, audio_path)
val resultMap = Arguments.createMap()
resultMap.putBoolean("success", result.success)
resultMap.putInt("load_time", result.load_time)
Expand Down
2 changes: 1 addition & 1 deletion android/src/oldarch/BarkRnSpec.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ abstract class BarkRnSpec internal constructor(context: ReactApplicationContext)
ReactContextBaseJavaModule(context) {

abstract fun init_context(model_path: String, params: ReadableMap, promise: Promise)
abstract fun generate(id: Int, text: String, out_path: String, threads: Int, promise: Promise)
abstract fun generate(id: Int, text: String, out_path: String, promise: Promise)
abstract fun release_context(id: Int, promise: Promise)
abstract fun release_all_contexts(promise: Promise)
}
12 changes: 5 additions & 7 deletions bark-rn.podspec
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,14 @@ Pod::Spec.new do |s|
end
end

if const_defined?(:ReactNativePodsUtils) && ReactNativePodsUtils.respond_to?(:spm_dependency)
ReactNativePodsUtils.spm_dependency(s,
if defined?(:spm_dependency)
spm_dependency(s,
url: 'https://github.com/PABannier/bark.cpp.git',
requirement: {kind: 'upToNextMajorVersion', minimumVersion: '1.0.0'},
products: spm_products
requirement: {kind: 'branch', branch: 'main'},
products: ['bark']
)
elsif s.respond_to?(:spm_dependency)
for product in spm_products
s.spm_dependency "bark"
end
s.spm_dependency "bark/bark"
else
raise "Please upgrade React Native to >=0.75.0 or install `cocoapods-spm` plugin"
end
Expand Down
1 change: 1 addition & 0 deletions cpp/utils.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#define DR_WAV_IMPLEMENTATION
#include "utils.h"
#include "dr_wav.h"

Expand Down
2 changes: 1 addition & 1 deletion example/Gemfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
source 'https://rubygems.org'

# You may use http://rbenv.org/ or https://rvm.io/ to install and use this version
ruby ">= 2.6.10"
ruby ">= 3.0"

# Exclude problematic versions of cocoapods and activesupport that causes build failures.
gem 'cocoapods', '>= 1.13', '!= 1.15.0', '!= 1.15.1'
Expand Down
9 changes: 4 additions & 5 deletions example/ios/Podfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,19 @@ require Pod::Executable.execute_command('node', ['-p',
{paths: [process.argv[1]]},
)', __dir__]).strip

platform :ios, min_ios_version_supported
platform :ios, "14.0"
prepare_react_native_project!

plugin "cocoapods-spm"

linkage = ENV['USE_FRAMEWORKS']
if linkage != nil
Pod::UI.puts "Configuring Pod with #{linkage}ally linked Frameworks".green
use_frameworks! :linkage => linkage.to_sym
end

target 'BarkRnExample' do
spm_pkg "bark",
:url => "https://github.com/PABannier/bark.cpp.git",
:branch => "1.0.0",
:products => ["bark"]
spm_pkg "bark", :url => "https://github.com/PABannier/bark.cpp.git", :branch => "main"

config = use_native_modules!

Expand Down
11 changes: 11 additions & 0 deletions ios/BarkContext.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#ifdef __cplusplus
#import "bark.h"
#endif

@interface BarkContext : NSObject

+ (BarkContext *)initWithModelPath:(NSString *)model_path params:(NSDictionary *)ns_params;
- (void)dealloc;
- (NSDictionary *)generate:(NSString *)text out_path:(NSString *)out_path threads:(NSInteger)threads sample_rate:(NSInteger)sample_rate;

@end
64 changes: 64 additions & 0 deletions ios/BarkContext.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#import "BarkContext.h"
#import "Convert.h"

#include "utils.h"
#include <vector>
#include <thread>

@interface BarkContext ()

@property (nonatomic, assign) bark_context *context;
@property (nonatomic, assign) int sample_rate;
@property (nonatomic, assign) int n_threads;

@end

@implementation BarkContext

+ (BarkContext *)initWithModelPath:(NSString *)model_path params:(NSDictionary *)ns_params {
self = [super init];
if (self) {
int seed = 0;
if (ns_params && ns_params[@"seed"]) seed = [ns_params[@"seed"] intValue];
self.sample_rate = 24000;
if (ns_params && ns_params[@"sample_rate"]) self.sample_rate = [ns_params[@"sample_rate"] intValue];
self.n_threads = -1;
if (ns_params && ns_params[@"n_threads"]) self.n_threads = [ns_params[@"n_threads"] intValue];
if (self.n_threads < 0) self.n_threads = std::thread::hardware_concurrency() << 1;
if (self.n_threads == 0) self.n_threads = 1;
bark_context_params params = [Convert convert_params:ns_params];
try {
self.context = bark_load_model(model_path, params, seed);
} catch (const std::exception &e) {
@throw [NSException exceptionWithName:@"BarkContext" reason:[NSString stringWithUTF8String:e.what()] userInfo:nil];
}
}
return self;
}

- (void)dealloc {
bark_free(self.context);
self.context = NULL;
[super dealloc];
}

- (NSDictionary *)generate:(NSString *)text out_path:(NSString *)out_path {
try {
bool success = bark_generate_audio(self.context, text, self.n_threads);
} catch (const std::exception &e) {
@throw [NSException exceptionWithName:@"BarkContext" reason:[NSString stringWithUTF8String:e.what()] userInfo:nil];
}
if (success) {
int audio_samples = bark_get_audio_data_size(self.context);
const float *audio_data = bark_get_audio_data(self.context);
std::vector<float> audio_data_vec(audio_data, audio_data + audio_samples);
barkrn::pcmToWav(audio_data_vec, self.sample_rate, out_path);
}
return @{
@"success": @(success),
@"load_time": @(bark_get_load_time(self.context)),
@"eval_time": @(bark_get_eval_time(self.context))
};
}

@end
Loading

0 comments on commit 8a945e5

Please sign in to comment.