Skip to content

Commit

Permalink
Implement basic Whisper transcription
Browse files Browse the repository at this point in the history
  • Loading branch information
AutonomicPerfectionist committed Sep 10, 2023
1 parent fe942bc commit f0f4ba6
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 26 deletions.
180 changes: 177 additions & 3 deletions src/main/java/org/myrobotlab/service/Whisper.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,43 @@
package org.myrobotlab.service;

import org.myrobotlab.framework.Service;
import org.myrobotlab.service.config.ServiceConfig;
import io.github.givimad.whisperjni.WhisperContext;
import io.github.givimad.whisperjni.WhisperFullParams;
import io.github.givimad.whisperjni.WhisperJNI;
import org.myrobotlab.framework.Platform;
import org.myrobotlab.service.abstracts.AbstractSpeechRecognizer;
import org.myrobotlab.service.config.LlamaConfig;
import org.myrobotlab.service.config.WhisperConfig;
import org.myrobotlab.service.data.Locale;

import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioSystem;
import javax.sound.sampled.Line;
import javax.sound.sampled.LineUnavailableException;
import javax.sound.sampled.Mixer;
import javax.sound.sampled.TargetDataLine;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.ShortBuffer;
import java.nio.channels.Channels;
import java.nio.channels.FileChannel;
import java.nio.channels.ReadableByteChannel;
import java.nio.file.Path;
import java.util.Map;

public class Whisper extends AbstractSpeechRecognizer<WhisperConfig> {
private transient WhisperJNI whisper;

private transient WhisperContext ctx;

private transient WhisperFullParams params;

private transient Thread listeningThread = new Thread();


public class Whisper extends Service<ServiceConfig> {
/**
* Constructor of service, reservedkey typically is a services name and inId
* will be its process id
Expand All @@ -14,4 +48,144 @@ public class Whisper extends Service<ServiceConfig> {
public Whisper(String reservedKey, String inId) {
super(reservedKey, inId);
}

public void loadModel(String modelPath) {
try {
whisper = new WhisperJNI();
WhisperJNI.loadLibrary();
ctx = whisper.init(Path.of(modelPath));
} catch (IOException e) {
throw new RuntimeException(e);
}

params = new WhisperFullParams();
params.nThreads = Platform.getLocalInstance().getNumPhysicalProcessors();
params.printRealtime = true;
params.printProgress = true;

}

public String findModelPath(String modelName) {
// First, we loop over all user-defined
// model directories
for (String dir : config.modelPaths) {
File path = new File(dir + fs + modelName);
if (path.exists()) {
return path.getAbsolutePath();
}
}

// Now, we check our data directory for any downloaded models
File path = new File(getDataDir() + fs + modelName);
if (path.exists()) {
return path.getAbsolutePath();
} else if (config.modelUrls.containsKey(modelName)) {
// Model was not in data but we do have a URL for it
try (FileOutputStream fileOutputStream = new FileOutputStream(path)) {
ReadableByteChannel readableByteChannel = Channels.newChannel(new URL(config.modelUrls.get(modelName)).openStream());
FileChannel fileChannel = fileOutputStream.getChannel();
info("Downloading model %s to path %s from URL %s", modelName, path, config.modelUrls.get(modelName));
fileChannel.transferFrom(readableByteChannel, 0, Long.MAX_VALUE);
} catch (IOException e) {
throw new RuntimeException(e);
}
return path.getAbsolutePath();
}
// Cannot find the model anywhere
error("Could not locate model {}, add its URL to download it or add a directory where it is located", modelName);
return null;
}

@Override
public void startListening() {

listeningThread = new Thread(() -> {
AudioFormat format = new AudioFormat(16000.0f, 16, 1, true, false);
TargetDataLine microphone = null;

Mixer.Info[] mixerInfos = AudioSystem.getMixerInfo();
for (Mixer.Info info: mixerInfos){
Mixer m = AudioSystem.getMixer(info);
Line.Info[] lineInfos = m.getTargetLineInfo();
for (Line.Info lineInfo:lineInfos){
System.out.println (info.getName()+"---"+lineInfo);
// Hard-code for my mic right now
if (info.getName().contains("U0x46d0x825")) {
try {
microphone = (TargetDataLine) m.getLine(lineInfo);
microphone.open(format);
System.out.println("Sample rate: " + format.getSampleRate());
} catch (LineUnavailableException e) {
throw new RuntimeException(e);
}
}

}

}

int numBytesRead;

microphone.start();
while(config.listening) {
int CHUNK_SIZE = (int)((format.getFrameSize() * format.getFrameRate())) * 5;
ByteBuffer captureBuffer = ByteBuffer.allocate(CHUNK_SIZE);
captureBuffer.order(ByteOrder.LITTLE_ENDIAN);
numBytesRead = microphone.read(captureBuffer.array(), 0, CHUNK_SIZE);
System.out.println("Num bytes read=" + numBytesRead);
ShortBuffer shortBuffer = captureBuffer.asShortBuffer();
// transform the samples to f32 samples
float[] samples = new float[captureBuffer.capacity() / 2];
int index = 0;
shortBuffer.position(0);
while (shortBuffer.hasRemaining()) {
samples[index++] = Float.max(-1f, Float.min(((float) shortBuffer.get()) / (float) Short.MAX_VALUE, 1f));
}
int result = whisper.full(ctx, params, samples, samples.length);
if(result != 0) {
throw new RuntimeException("Transcription failed with code " + result);
}
int numSegments = whisper.fullNSegments(ctx);
System.out.println("Inference done, numSegments=" + numSegments);
for (int i = 0; i < numSegments; i++) {
System.out.println(whisper.fullGetSegmentText(ctx, i));
invoke("publishRecognized", whisper.fullGetSegmentText(ctx, i));
}

}
microphone.close();
});
super.startListening();

listeningThread.start();
}

@Override
public WhisperConfig apply(WhisperConfig c) {
super.apply(c);

if (config.selectedModel != null && !config.selectedModel.isEmpty()) {
String modelPath = findModelPath(config.selectedModel);
if (modelPath != null) {
loadModel(modelPath);
} else {
error("Could not find selected model {}", config.selectedModel);
}
}

return config;
}

/**
* locales this service supports - implementation can simply get
* runtime.getLocales() if acceptable or create their own locales
*
* @return map of string to locale
*/
@Override
public Map<String, Locale> getLocales() {
return null;
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public void clearLock() {
*/
@Override
public String getWakeWord() {
SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;
return c.wakeWord;
}

Expand All @@ -177,17 +177,16 @@ public String getWakeWord() {
*/
@Override
public boolean isListening() {
SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;
return c.listening;
}

@Override
@Deprecated /* use publishListening(boolean event) */
public void listeningEvent(Boolean event) {
SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;
c.listening = event;
broadcastState();
return;
}

@Override
Expand All @@ -213,12 +212,12 @@ public void onEndSpeaking(String utterance) {
// affect "recognizing"
// FIXME - add a deta time after ...

SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;

if (c.afterSpeakingPauseMs > 0) {
// remove previous one shot - because we are "sliding" the window of
// stopping the publishing of recognized words
addTaskOneShot(c.afterSpeakingPauseMs, "setSpeaking", new Object[] { false });
addTaskOneShot(c.afterSpeakingPauseMs, "setSpeaking", false);
log.warn("isSpeaking = false will occur in {} ms", c.afterSpeakingPauseMs);
} else {
setSpeaking(false, null);
Expand All @@ -233,17 +232,16 @@ public void onAudioStart(AudioData data) {
purgeTask("setSpeaking");
// isSpeaking = true;
setSpeaking(true, data.getFileName());
return;
}

@Override
public void onAudioEnd(AudioData data) {
log.info("sound stopped {}", data);
SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;
if (c.afterSpeakingPauseMs > 0) {
// remove previous one shot - because we are "sliding" the window of
// stopping the publishing of recognized words
addTaskOneShot(c.afterSpeakingPauseMs, "setSpeaking", new Object[] { false });
addTaskOneShot(c.afterSpeakingPauseMs, "setSpeaking", false);
log.warn("isSpeaking = false will occur in {} ms", c.afterSpeakingPauseMs);
} else {
setSpeaking(false, null);
Expand All @@ -264,7 +262,7 @@ public boolean setSpeaking(boolean b, String utterance) {

ListeningEvent event = new ListeningEvent();

SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;
event.isRecording = c.recording;
event.isListening = c.listening;
event.isAwake = isAwake;
Expand All @@ -289,7 +287,6 @@ public void onStartSpeaking(String utterance) {
purgeTask("setSpeaking");
// isSpeaking = true;
setSpeaking(true, utterance);
return;
}

@Override
Expand All @@ -304,11 +301,10 @@ public void pauseListening() {
public ListeningEvent[] processResults(ListeningEvent[] results) {
// at the moment its simply invoking other methods, but if a new speech
// recognizer is created - it might need more processing
SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;


for (int i = 0; i < results.length; ++i) {
ListeningEvent event = results[i];
for (ListeningEvent event : results) {
event.isRecording = c.recording;
event.isListening = c.listening;
event.isAwake = isAwake;
Expand Down Expand Up @@ -366,7 +362,7 @@ public void setAwake(boolean b) {
}

public void setAwake(boolean b, String text) {
SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;

if (!b && isSpeaking) {
log.info("bot is speaking - bot doesn't get tired when talking about self sliding idle timeout");
Expand Down Expand Up @@ -463,7 +459,7 @@ public void setLowerCase(boolean b) {
*/
@Override
public void setWakeWord(String word) {
SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;

if (word == null || word.trim().length() == 0) {
word = null;
Expand All @@ -487,7 +483,7 @@ public void setWakeWord(String word) {
*
*/
public void setWakeWordTimeout(Integer wakeWordTimeoutSeconds) {
SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;
c.wakeWordIdleTimeoutSeconds = wakeWordTimeoutSeconds;
broadcastState();
}
Expand All @@ -496,7 +492,7 @@ public void setWakeWordTimeout(Integer wakeWordTimeoutSeconds) {
@Override
public void startListening() {
log.debug("Start listening event seen.");
SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;
c.listening = true;
c.recording = true;
broadcastState();
Expand All @@ -518,7 +514,7 @@ public void setAutoListen(Boolean value) {
*/
@Override
public void startRecording() {
SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;
c.recording = true;
broadcastState();
}
Expand All @@ -531,7 +527,7 @@ public void startRecording() {
@Override
public void stopListening() {
log.debug("stopListening()");
SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;
c.listening = false;
broadcastState();
}
Expand All @@ -542,7 +538,7 @@ public void stopListening() {

@Override
public void stopRecording() {
SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;
c.recording = false;
broadcastState();
}
Expand All @@ -555,13 +551,13 @@ public void stopService() {
}

public long setAfterSpeakingPause(long ms) {
SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;
c.afterSpeakingPauseMs = ms;
return c.afterSpeakingPauseMs;
}

public long getAfterSpeakingPause() {
SpeechRecognizerConfig c = (SpeechRecognizerConfig)config;
SpeechRecognizerConfig c = config;
return c.afterSpeakingPauseMs;
}

Expand Down
23 changes: 23 additions & 0 deletions src/main/java/org/myrobotlab/service/config/WhisperConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package org.myrobotlab.service.config;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class WhisperConfig extends SpeechRecognizerConfig {
public String selectedModel = "ggml-tiny.en.bin";

public List<String> modelPaths = new ArrayList<>(List.of(

));

public Map<String, String> modelUrls = new HashMap<>(Map.of(
"ggml-tiny.bin", "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin",
"ggml-small.bin", "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.bin",
"ggml-tiny.en.bin", "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin",
"ggml-small.en.bin", "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin",
"ggml-medium-q5_0.bin", "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium-q5_0.bin",
"ggml-medium.en-q5_0.bin", "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.en-q5_0.bin"
));
}

0 comments on commit f0f4ba6

Please sign in to comment.