Commit

Derive the sample rate exclusively from TTSEngine
Don't use sample rate heuristics sprinkled all over the code, but always
use the sample rate that is derived from the current voice's native
sample rate.

Also try to simplify any related code dealing with sample rate.

Moreover, improve network voice error handling.

Signed-off-by: Daniel Schnell <[email protected]>
lumpidu committed Feb 2, 2024
1 parent 39b9cd5 commit 0c54f82
Showing 7 changed files with 139 additions and 97 deletions.
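
In essence, the commit replaces hard-coded rate constants with a query against the engine of the currently loaded voice. A minimal sketch of that pattern (only GetNativeSampleRate() comes from the code base; the interface shape and the caller below are illustrative assumptions):

// Illustrative sketch only: GetNativeSampleRate() exists in the code base, the
// interface shape and the caller are hypothetical stand-ins.
interface TTSEngine {
    int GetNativeSampleRate();   // native sample rate (Hz) of the currently loaded voice
}

final class SampleRateSketch {
    static int pickSampleRate(TTSEngine engine) {
        // Before: call sites hard-coded constants such as SAMPLE_RATE_WAV (16000)
        // or SAMPLE_RATE_ONNX. After: the rate is always queried from the voice's engine.
        return engine.GetNativeSampleRate();
    }
}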
18 changes: 11 additions & 7 deletions app/src/main/java/com/grammatek/simaromur/AppRepository.java
@@ -145,7 +145,7 @@ public boolean isCurrentVoice(Voice voice) {
*/
public void downloadVoiceAsync(Voice voice, DownloadVoiceManager.Observer finishedObserver) {
// when the download is successful, the voice is updated in the database. This happens
// asynchonously.
// asynchronously.
mDVM.downloadVoiceAsync(voice, finishedObserver, mVoiceDao);
}

@@ -466,10 +466,10 @@ public void startNetworkTTS(Voice voice, CacheItem item, TTSRequest ttsRequest,
Log.v(LOG_TAG, "startNetworkTTS: " + item.getUuid());
// map given voice to voiceId
if (voice != null) {
final TTSObserver ttsObserver = new TTSObserver(pitch, speed, AudioManager.SAMPLE_RATE_WAV);
final TTSObserver ttsObserver = new TTSObserver(pitch, speed, mNetworkSpeakController.getNativeSampleRate());
if (playIfAudioCacheHit(voice.internalName, voice.version, item, ttsObserver, ttsRequest)) return;

final String SampleRate = "" + AudioManager.SAMPLE_RATE_WAV;
final String SampleRate = "" + mNetworkSpeakController.getNativeSampleRate();
final String normalized = item.getUtterance().getNormalized();
if (normalized.trim().isEmpty()) {
Log.w(LOG_TAG, "startNetworkTTS: given text is whitespace only ?!");
@@ -529,8 +529,7 @@ public TTSEngineController.SpeakTask startDeviceSpeak(Voice voice, CacheItem ite
e.printStackTrace();
return null;
}
return mTTSEngineController.StartSpeak(item, speed, pitch,
mTTSEngineController.getEngine().GetNativeSampleRate(), observer, getCurrentTTsRequest());
return mTTSEngineController.StartSpeak(item, speed, pitch, observer, getCurrentTTsRequest());
}

/**
@@ -587,6 +586,10 @@ public String getVersionOfVoice(String internalVoiceName) {
return null;
}

public int getVoiceNativeSampleRate() {
return mTTSEngineController.getEngine().GetNativeSampleRate();
}

/**
* Find if we have the specified language available.
* Use our DB model to query availability of voices
@@ -795,6 +798,7 @@ public void showTtsBackendWarningDialog(Context context) {
*/
public void speakAssetFile(SynthesisCallback callback, String assetFilename) {
Log.v(LOG_TAG, "playAssetFile: " + assetFilename);
final int SAMPLE_RATE_ASSETS = 22050;
try {
InputStream inputStream = App.getContext().getAssets().open(assetFilename);
int size = inputStream.available();
@@ -803,7 +807,7 @@ public void speakAssetFile(SynthesisCallback callback, String assetFilename) {
Log.w(LOG_TAG, "playAssetFile: not enough bytes ?");
}
// don't provide rawText: there are no speech marks to update
callback.start(AudioManager.SAMPLE_RATE_WAV, AudioFormat.ENCODING_PCM_16BIT,
callback.start(SAMPLE_RATE_ASSETS, AudioFormat.ENCODING_PCM_16BIT,
AudioManager.N_CHANNELS);
feedBytesToSynthesisCallback(callback, buffer, "");
callback.done();
@@ -850,7 +854,7 @@ public static void feedBytesToSynthesisCallback(SynthesisCallback callback, byte
final int bytesConsumed = Math.min(maxBytes, bytesLeft);
if (callback.hasStarted()) {
// this feeds audio data to the callback, which will then be consumed by the TTS
// client. In case the current utterance is stopped(), all remaining audio data is
// client. In case the current utterance is stopped, all remaining audio data is
// consumed and discarded and afterwards TTSService.onStopped() is executed.
int cbStatus = callback.audioAvailable(buffer, offset, bytesConsumed);
switch(cbStatus) {
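
The new AppRepository accessor simply delegates to the engine of the loaded voice. A self-contained sketch of that delegation, with hypothetical stub classes so it compiles on its own (only getVoiceNativeSampleRate() and GetNativeSampleRate() appear in the diff above):

// Hypothetical stubs so the sketch is runnable; the two method names in the
// delegation chain match the diff.
class RepositorySketch {
    static class EngineStub {
        int GetNativeSampleRate() { return 22050; }           // e.g. a 22.05 kHz voice
    }
    static class EngineControllerStub {
        EngineStub getEngine() { return new EngineStub(); }
    }

    private final EngineControllerStub mTTSEngineController = new EngineControllerStub();

    // Single source of truth: TTSService, the network observer and the silence
    // generator all ask here instead of using per-format constants.
    public int getVoiceNativeSampleRate() {
        return mTTSEngineController.getEngine().GetNativeSampleRate();
    }

    public static void main(String[] args) {
        System.out.println(new RepositorySketch().getVoiceNativeSampleRate() + " Hz");
    }
}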
28 changes: 13 additions & 15 deletions app/src/main/java/com/grammatek/simaromur/TTSService.java
@@ -20,7 +20,6 @@
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;

@@ -109,7 +108,7 @@ protected int onLoadLanguage(String language, String country, String variant) {
* waiting for a TTSProcessingResult inside onSynthesizeText(). If afterwards the audio processing is
* finished, the processing result is received and discarded, because the current utterance is
* already finished and has changed.
*
* <p>
* Note: mandatory, don't synchronize this method !
*/
@Override
@@ -156,7 +155,7 @@ protected void onSynthesizeText(SynthesisRequest request,
loadedVoiceName = mRepository.getLoadedVoiceName();
} else {
Log.w(LOG_TAG, "onSynthesizeText: couldn't load voice ("+voiceNameToLoad+")");
callback.start(AudioManager.SAMPLE_RATE_WAV, AudioFormat.ENCODING_PCM_16BIT,
callback.start(mRepository.getVoiceNativeSampleRate(), AudioFormat.ENCODING_PCM_16BIT,
AudioManager.N_CHANNELS);
callback.error(TextToSpeech.ERROR_SERVICE);
if (callback.hasStarted() && ! callback.hasFinished()) {
@@ -215,12 +214,12 @@ protected void onSynthesizeText(SynthesisRequest request,
Log.v(LOG_TAG, "onSynthesizeText: finished (" + item.getUuid() + ")");
return;
}
startSynthesisCallback(callback, AudioManager.SAMPLE_RATE_WAV, true);
startSynthesisCallback(callback, mRepository.getVoiceNativeSampleRate(), true);
setSpeechMarksToBeginning(callback);
mRepository.startNetworkTTS(voice, item, ttsRequest, speechrate / 100.0f, pitch / 100.0f);
break;
case com.grammatek.simaromur.db.Voice.TYPE_ONNX:
startSynthesisCallback(callback, AudioManager.SAMPLE_RATE_ONNX, false);
startSynthesisCallback(callback, mRepository.getVoiceNativeSampleRate(), false);
setSpeechMarksToBeginning(callback);
mRepository.startDeviceTTS(voice, item, ttsRequest, speechrate / 100.0f, pitch / 100.0f);
break;
@@ -251,7 +250,7 @@ private void handleProcessingResult(SynthesisCallback callback, CacheItem item,
// todo: we need to handle timeout errors here, e.g. processing
// timeouts, some error, e.g. network timeouts are already taken care of
TTSProcessingResult elem = mRepository.dequeueTTSProcessingResult();
float rtf = estimateRTF(startTime, System.currentTimeMillis(), item, elem);
float rtf = estimateRTF(startTime, System.currentTimeMillis(), elem);
Log.v(LOG_TAG, "estimateRTF: rtf=" + rtf);
if (rtf > 500.0f && !isCached) {
Log.w(LOG_TAG, "handleProcessingResult: rtf > 500.0f, something went wrong for the estimation");
@@ -320,11 +319,10 @@ private void handleProcessingResult(SynthesisCallback callback, CacheItem item,
*
* @param startTimeMillis time when the processing started
* @param stopTimeMillis time when the processing stopped
* @param item cache item
* @param elem processing result
* @return the real time factor
*/
private float estimateRTF(long startTimeMillis, long stopTimeMillis, CacheItem item, TTSProcessingResult elem) {
private float estimateRTF(long startTimeMillis, long stopTimeMillis, TTSProcessingResult elem) {
String uuid = elem.getTTSRequest().getCacheItemUuid();
Log.v(LOG_TAG, "estimateRTF for: " + uuid);

@@ -336,7 +334,7 @@ private float estimateRTF(long startTimeMillis, long stopTimeMillis, CacheItem i
// assume currently slowest used sample rate, i.e. 16kHz and 16 bit with 1 channel
// TODO: we should use the real sample rate here, but this needs to be passed via the
// TTSProcessingResult
final int sampleRate = AudioManager.SAMPLE_RATE_WAV;
final int sampleRate = mRepository.getVoiceNativeSampleRate();
final int bytesPerSample = 2;
final int channels = 1;

@@ -445,7 +443,7 @@ private boolean testForAndHandleNetworkVoiceIssues(SynthesisCallback callback,

/**
* Signal TTS client a TTS error with given error code.
*
* <p>
* The sequence for signalling an error seems to be important: callback.start(),
* callback.error(), callback.done(). Any callback.audioAvailable() call after a callback.error()
* is ignored.
@@ -455,7 +453,7 @@ private void signalTtsError(SynthesisCallback callback, int errorCode) {
*/
private void signalTtsError(SynthesisCallback callback, int errorCode) {
Log.w(LOG_TAG, "signalTtsError(): errorCode = " + errorCode);
callback.start(AudioManager.SAMPLE_RATE_WAV, AudioFormat.ENCODING_PCM_16BIT,
callback.start(mRepository.getVoiceNativeSampleRate(), AudioFormat.ENCODING_PCM_16BIT,
AudioManager.N_CHANNELS);
callback.error(errorCode);
callback.done();
@@ -467,12 +465,12 @@ private void signalTtsError(SynthesisCallback callback, int errorCode) {
*
* @param callback TTS callback provided in the onSynthesizeText() callback
*/
private static void playSilence(SynthesisCallback callback) {
private void playSilence(SynthesisCallback callback) {
Log.v(LOG_TAG, "playSilence() ...");
callback.start(AudioManager.SAMPLE_RATE_WAV, AudioFormat.ENCODING_PCM_16BIT,
AudioManager.N_CHANNELS);
int sampleRate = mRepository.getVoiceNativeSampleRate();
callback.start(sampleRate, AudioFormat.ENCODING_PCM_16BIT, AudioManager.N_CHANNELS);
setSpeechMarksToBeginning(callback);
byte[] silenceData = AudioManager.generatePcmSilence(0.25f);
byte[] silenceData = AudioManager.generatePcmSilence(0.25f, sampleRate);
callback.audioAvailable(silenceData, 0, silenceData.length);
if (! callback.hasFinished() && callback.hasStarted()) {
callback.done();
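
For reference, the duration arithmetic that estimateRTF() relies on, now fed by the voice's native sample rate. Only sampleRate, bytesPerSample and channels below are taken from the diff; the rtf formulation is one plausible reading of the surrounding code (cached items are excluded from the sanity check because their near-zero processing time would inflate the factor):

// Sketch of the duration / real-time-factor arithmetic; the example values and the
// exact rtf formula are illustrative, not copied from the app.
final class RtfSketch {
    static float audioSeconds(int pcmBytes, int sampleRate) {
        final int bytesPerSample = 2;   // 16-bit PCM
        final int channels = 1;         // mono
        return pcmBytes / (float) (sampleRate * bytesPerSample * channels);
    }

    public static void main(String[] args) {
        long startMillis = 0, stopMillis = 500;             // 0.5 s of wall-clock processing
        int producedBytes = 44100;                           // 1.0 s of 16-bit mono audio ...
        float audio = audioSeconds(producedBytes, 22050);    // ... at a 22050 Hz voice
        float rtf = audio / ((stopMillis - startMillis) / 1000.0f);
        System.out.println("audio=" + audio + " s, rtf=" + rtf);   // audio=1.0 s, rtf=2.0
    }
}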
28 changes: 4 additions & 24 deletions app/src/main/java/com/grammatek/simaromur/audio/AudioManager.java
@@ -16,9 +16,9 @@ public class AudioManager {
private final static String LOG_TAG = "Simaromur_" + AudioManager.class.getSimpleName();

// Some constants used throughout audio conversion
public static final int SAMPLE_RATE_WAV = 16000;
//public static final int SAMPLE_RATE_WAV = 16000;
public static final int SAMPLE_RATE_MP3 = 22050;
public static final int SAMPLE_RATE_ONNX = 16000;
//public static final int SAMPLE_RATE_ONNX = 16000;
//public static final int SAMPLE_RATE_ONNX = 22050;
public static final int N_CHANNELS = 1;

@@ -79,26 +79,6 @@ static public byte[] applyPitchAndSpeed(final byte[] monoPcmData, int sampleRate
return outputConversionStream.toByteArray();
}

/**
* Either apply pitch and speed to ttsData, resulting in a potentially differently sized output
* buffer, or simply copy ttsData to the new output buffer, if no changes of speed or pitch
* are requested.
* Return the newly created output buffer.
*
* @param monoPcmData byte array of MONO PCM data to be used as input data. 22050 Hz sample rate
* is expected
* @param pitch pitch to be applied. 1.0f means no pitch change, values > 1.0 mean higher
* pitch, values < 1.0 mean lower pitch than in given pcmData
* @param speed speed to be applied. 1.0f means no speed change, values > 1.0 mean higher
* speed, values < 1.0 mean lower speed than in given pcmData. This parameter
* produces either more data for values >1.0, less data for values < 1.0, or
* no data change for a value of 1.0
* @return new byte array with converted PCM data
*/
static public byte[] applyPitchAndSpeed(final byte[] monoPcmData, float pitch, float speed) {
return applyPitchAndSpeed(monoPcmData, SAMPLE_RATE_WAV, pitch, speed);
}

/**
* Converts given float values to 16bits PCM. No resampling or interpolation is done.
* Floats are rounded to the nearest integer.
@@ -284,10 +264,10 @@ static public byte[] pcmFloatTo16BitPCMWithDither(float[] pcmFloats, float norma
return outBuf;
}

static public byte[] generatePcmSilence(float duration) {
static public byte[] generatePcmSilence(float duration, int sampleRate) {
final int nChannels = 1;
final int nBits = 16;
final int nSamples = (int) (duration * SAMPLE_RATE_WAV);
final int nSamples = (int) (duration * sampleRate);
final int nBytes = nSamples * nChannels * nBits / 8;
return new byte[nBytes];
}
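
A quick arithmetic check of the reworked generatePcmSilence(): the method body mirrors the diff above, while the example values are purely illustrative.

// Mirrors the new signature: 16-bit mono silence whose length follows the
// caller-supplied (native) sample rate instead of a hard-coded 16000 Hz.
final class SilenceSketch {
    static byte[] generatePcmSilence(float duration, int sampleRate) {
        final int nChannels = 1;
        final int nBits = 16;
        final int nSamples = (int) (duration * sampleRate);
        return new byte[nSamples * nChannels * nBits / 8];
    }

    public static void main(String[] args) {
        // 0.25 s at a 22050 Hz voice: 5512 samples -> 11024 bytes,
        // versus 8000 bytes with the old fixed 16 kHz constant.
        System.out.println(generatePcmSilence(0.25f, 22050).length);
    }
}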
