Skip to content

Commit

Permalink
java : use tiny.en for tests (ggerganov#1484)
Browse files Browse the repository at this point in the history
* java : use tiny.en for tests

* java : try to fix full params struct
  • Loading branch information
ggerganov authored Nov 13, 2023
1 parent 3e5c7fe commit d423164
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ public void enableContext(boolean enable) {
no_context = enable ? CBool.FALSE : CBool.TRUE;
}

/** Generate timestamps or not? */
public CBool no_timestamps;

/** Flag to force single segment output (useful for streaming). (default = false) */
public CBool single_segment;

Expand Down Expand Up @@ -304,10 +307,16 @@ public void setLogitsFilterCallback(WhisperLogitsFilterCallback callback) {
logits_filter_callback = CallbackReference.getFunctionPointer(callback);
}

/** Grammar stuff */
public Pointer grammar_rules;
public long n_grammar_rules;
public long i_start_rule;
public float grammar_penalty;

@Override
protected List<String> getFieldOrder() {
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
"no_context", "single_segment",
"no_context", "single_segment", "no_timestamps",
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
"tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
Expand All @@ -316,6 +325,7 @@ protected List<String> getFieldOrder() {
"new_segment_callback", "new_segment_callback_user_data",
"progress_callback", "progress_callback_user_data",
"encoder_begin_callback", "encoder_begin_callback_user_data",
"logits_filter_callback", "logits_filter_callback_user_data");
"logits_filter_callback", "logits_filter_callback_user_data",
"grammar_rules", "n_grammar_rules", "i_start_rule", "grammar_penalty");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ class WhisperCppTest {
static void init() throws FileNotFoundException {
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
// or you can provide the absolute path to the model file.
String modelName = "../../models/ggml-tiny.bin";
// String modelName = "../../models/ggml-tiny.en.bin";
//String modelName = "../../models/ggml-tiny.bin";
String modelName = "../../models/ggml-tiny.en.bin";
try {
whisper.initContext(modelName);
// whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
//whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
//whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
modelInitialised = true;
} catch (FileNotFoundException ex) {
System.out.println("Model " + modelName + " not found");
Expand Down Expand Up @@ -75,11 +75,11 @@ void testFullTranscribe() throws Exception {
byte[] b = new byte[audioInputStream.available()];
float[] floats = new float[b.length / 2];

// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
params.print_progress = CBool.FALSE;
// params.initial_prompt = "and so my fellow Americans um, like";
//params.initial_prompt = "and so my fellow Americans um, like";


try {
Expand Down Expand Up @@ -117,12 +117,11 @@ void testFullTranscribeWithTime() throws Exception {
byte[] b = new byte[audioInputStream.available()];
float[] floats = new float[b.length / 2];

// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
params.print_progress = CBool.FALSE;
// params.initial_prompt = "and so my fellow Americans um, like";

//params.initial_prompt = "and so my fellow Americans um, like";

try {
audioInputStream.read(b);
Expand Down

0 comments on commit d423164

Please sign in to comment.