Improvements to:
- real-time voice conversations
- text processing for the LLM
- interrupt signal
w4ffl35 committed Oct 12, 2024
1 parent 8899b28 commit 2a867ba
Showing 16 changed files with 210 additions and 232 deletions.
1 change: 1 addition & 0 deletions src/airunner/enums.py
@@ -352,6 +352,7 @@ class LLMActionType(Enum):
TOGGLE_TTS = "TOGGLE TEXT-TO-SPEECH: If the user requests that you turn on or off or toggle text-to-speech, choose this action."
PERFORM_RAG_SEARCH = "SEARCH: If the user requests that you search for information, choose this action."
SUMMARIZE = "SUMMARIZE"
DO_NOTHING = "DO NOTHING: If the user's request is unclear or you are unable to determine the user's intent, choose this action."
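DO_NOTHING gives the action classifier an explicit fallback for ambiguous or unparseable requests. A minimal sketch of resolving the model's chosen index against the action menu, defaulting to DO_NOTHING (hypothetical dispatch code, not part of this commit):

from airunner.enums import LLMActionType

def resolve_action(choice: str, available_actions: dict) -> LLMActionType:
    # The model answers with a menu index; anything that does not parse,
    # or points outside the menu, degrades safely to DO_NOTHING.
    try:
        index = int(choice.strip())
    except ValueError:
        return LLMActionType.DO_NOTHING
    return available_actions.get(index, LLMActionType.DO_NOTHING)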



27 changes: 12 additions & 15 deletions src/airunner/handlers/llm/agent/base_agent.py
@@ -16,6 +16,7 @@
from llama_index.core.chat_engine import ContextChatEngine
from llama_index.core import SimpleKeywordTableIndex
from llama_index.core.indices.keyword_table import KeywordTableSimpleRetriever
from transformers import TextIteratorStreamer

from airunner.handlers.llm.huggingface_llm import HuggingFaceLLM
from airunner.handlers.llm.custom_embedding import CustomEmbedding
@@ -82,7 +83,7 @@ def __init__(self, *args, **kwargs):
self.action = LLMActionType.CHAT
self.rendered_template = None
self.tokenizer = kwargs.pop("tokenizer", None)
self.streamer = kwargs.pop("streamer", None)
self.streamer = TextIteratorStreamer(self.tokenizer)
self.chat_template = kwargs.pop("chat_template", "")
self.is_mistral = kwargs.pop("is_mistral", True)
self.conversation_id = None
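BaseAgent now builds its own TextIteratorStreamer from the tokenizer instead of receiving one through kwargs (the matching streamer argument is dropped from the constructor call further down). For reference, the standard consumption pattern for this transformers class, as a self-contained sketch with a placeholder model and prompt, not airunner's actual generation loop:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
inputs = tokenizer("Hello, how are you?", return_tensors="pt")

# generate() blocks, so it runs in a worker thread while the caller
# iterates the streamer and receives decoded text as it is produced.
thread = Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 64})
thread.start()
for text_chunk in streamer:
    print(text_chunk, end="", flush=True)
thread.join()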
@@ -97,12 +98,11 @@ def __init__(self, *args, **kwargs):
@property
def available_actions(self):
return {
0: LLMActionType.QUIT_APPLICATION,
1: LLMActionType.TOGGLE_FULLSCREEN,
2: LLMActionType.TOGGLE_TTS,
3: LLMActionType.GENERATE_IMAGE,
4: LLMActionType.PERFORM_RAG_SEARCH,
5: LLMActionType.CHAT,
0: LLMActionType.TOGGLE_FULLSCREEN,
1: LLMActionType.TOGGLE_TTS,
2: LLMActionType.GENERATE_IMAGE,
3: LLMActionType.PERFORM_RAG_SEARCH,
4: LLMActionType.CHAT,
}

@property
@@ -163,7 +163,7 @@ def interrupt_process(self):

def do_interrupt_process(self):
interrupt = self.do_interrupt
self.do_interrupt = False
self.streamer = TextIteratorStreamer(self.tokenizer)
return interrupt

@property
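do_interrupt_process now swaps in a fresh TextIteratorStreamer: the old streamer's internal queue may still hold tokens from the aborted generation, and a new instance guarantees the next response starts clean. One way the do_interrupt flag can actually halt generate() mid-stream is a custom stopping criterion; the sketch below shows the pattern, not necessarily how airunner wires it:

from transformers import StoppingCriteria, StoppingCriteriaList

class InterruptCriteria(StoppingCriteria):
    def __init__(self, agent):
        self.agent = agent

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Returning True aborts generation at the next decoding step.
        return self.agent.do_interrupt

# model.generate(..., stopping_criteria=StoppingCriteriaList([InterruptCriteria(agent)]))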
@@ -303,9 +303,7 @@ def build_system_prompt(
self.names_prompt(use_names, botname, username),
self.mood(botname, bot_mood, use_mood),
system_instructions,
"------\n",
"Chat History:\n",
f"{self.username}: {self.prompt}\n",
self.history_prompt(),
]

elif action is LLMActionType.SUMMARIZE:
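The inline "------ / Chat History:" block, which only ever carried the current prompt, is replaced by a history_prompt() helper. The helper's body is outside this diff; a hypothetical reconstruction, assuming the agent keeps (role, text) pairs in self.history:

def history_prompt(self) -> str:
    # Hypothetical: the real method is not shown in this commit.
    lines = ["------\n", "Chat History:\n"]
    for role, text in self.history:
        lines.append(f"{role}: {text}\n")
    return "".join(lines)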
@@ -502,10 +500,9 @@ def run(
self.create_conversation()

# Add the user's message to history
if action in (
LLMActionType.CHAT,
LLMActionType.PERFORM_RAG_SEARCH,
LLMActionType.GENERATE_IMAGE,
if action not in (
LLMActionType.APPLICATION_COMMAND,
LLMActionType.UPDATE_MOOD
):
self.add_message_to_history(self.prompt, LLMChatRole.HUMAN)

@@ -226,6 +226,8 @@ def clear_history(self):
"""
Public method to clear the chat agent history
"""
if not self._chat_agent:
return
self.logger.debug("Clearing chat history")
self._chat_agent.clear_history()

@@ -301,7 +303,6 @@ def _load_agent(self):
self._chat_agent = BaseAgent(
model=self._model,
tokenizer=self._tokenizer,
streamer=self._streamer,
chat_template=self.chat_template,
is_mistral=self.is_mistral,
)
@@ -378,8 +379,7 @@ def _load_model_local(self):

def _do_generate(self, prompt: str, action: LLMActionType):
self.logger.debug("Generating response")
model_path = self.model_path
if self._current_model_path != model_path:
if self._current_model_path != self.model_path:
self.unload()
self.load()
if action is LLMActionType.CHAT and self.chatbot.use_mood:
146 changes: 68 additions & 78 deletions src/airunner/handlers/stt/whisper_handler.py
@@ -8,7 +8,7 @@
from transformers.models.whisper.feature_extraction_whisper import WhisperFeatureExtractor

from airunner.handlers.base_handler import BaseHandler
from airunner.enums import SignalCode, ModelType, ModelStatus, LLMChatRole
from airunner.enums import SignalCode, ModelType, ModelStatus
from airunner.exceptions import NaNException
from airunner.utils.clear_memory import clear_memory

@@ -27,6 +27,10 @@ def __init__(self, *args, **kwargs):
self._feature_extractor = None
self._fs = 16000

@property
def dtype(self):
return torch.bfloat16

@property
def stt_is_loading(self):
return self.model_status is ModelStatus.LOADING
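The new dtype property pins the model, processor, and feature extractor to torch.bfloat16 in one place instead of repeating the literal. A more defensive variant (an assumption, not what this handler does) would fall back when the hardware lacks bf16 support:

import torch

def pick_dtype() -> torch.dtype:
    # Sketch: prefer bfloat16 where the GPU supports it, else degrade.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16 if torch.cuda.is_available() else torch.float32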
@@ -57,15 +61,15 @@ def process_audio(self, audio_data):
# Convert the byte string to a float32 array
inputs = np.frombuffer(item, dtype=np.int16)
inputs = inputs.astype(np.float32) / 32767.0
transcription = None
try:
transcription = self._process_inputs(inputs)
except Exception as e:
self.logger.error(f"Failed to process inputs {e}")
self.logger.error(e)
try:
self._process_human_speech(transcription)
except ValueError as e:
self.logger.error(f"Failed to process audio {e}")

if transcription:
self._send_transcription(transcription)

def load(self):
if self.stt_is_loading or self.stt_is_loaded:
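process_audio interprets the incoming bytes as 16-bit PCM and rescales them to the [-1.0, 1.0] float range the Whisper feature extractor expects. The conversion in isolation, with made-up sample bytes:

import numpy as np

# int16 spans [-32768, 32767]; dividing by 32767.0 maps it to roughly [-1, 1].
raw_bytes = b"\x00\x00\xff\x7f\x00\x80"  # samples 0, 32767, -32768 (little-endian)
pcm = np.frombuffer(raw_bytes, dtype=np.int16)
audio = pcm.astype(np.float32) / 32767.0  # -> [0.0, 1.0, -1.00003]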
@@ -99,17 +103,18 @@ def unload(self):

def _load_model(self):
self.logger.debug(f"Loading model from {self.model_path}")
device = self.device
try:
self._model = WhisperForConditionalGeneration.from_pretrained(
self.model_path,
local_files_only=True,
torch_dtype=torch.bfloat16,
device_map=self.device,
use_safetensors=True
torch_dtype=self.dtype,
device_map=device,
use_safetensors=True,
force_download=False
)
except Exception as e:
self.logger.error(f"Failed to load model")
self.logger.error(e)
self.logger.error(f"Failed to load model: {e}")
return None

def _load_processor(self):
@@ -119,12 +124,11 @@ def _load_processor(self):
self._processor = WhisperProcessor.from_pretrained(
model_path,
local_files_only=True,
torch_dtype=torch.bfloat16,
torch_dtype=self.dtype,
device_map=self.device
)
except Exception as e:
self.logger.error(f"Failed to load processor")
self.logger.error(e)
self.logger.error(f"Failed to load processor: {e}")
return None

def _load_feature_extractor(self):
@@ -134,7 +138,7 @@ def _load_feature_extractor(self):
self._feature_extractor = WhisperFeatureExtractor.from_pretrained(
model_path,
local_files_only=True,
torch_dtype=torch.bfloat16,
torch_dtype=self.dtype,
device_map=self.device
)
except Exception as e:
@@ -157,59 +161,34 @@ def _unload_feature_extractor(self):
self._feature_extractor = None
clear_memory(self.device)

def _process_inputs(
self,
inputs: np.ndarray,
role: LLMChatRole = LLMChatRole.HUMAN,
) -> str:
inputs = torch.from_numpy(inputs)
def _process_inputs(self, inputs: np.ndarray) -> str:
if not self._feature_extractor:
return ""
inputs = torch.from_numpy(inputs).to(torch.float32).to(self.device)

if torch.isnan(inputs).any():
raise NaNException

# Move inputs to CPU and ensure they are in float32 before passing to _feature_extractor
inputs = inputs.cpu().to(torch.float32)
inputs = self._feature_extractor(inputs, sampling_rate=self._fs, return_tensors="pt")
if torch.isnan(inputs.input_features).any():
raise NaNException

inputs["input_features"] = inputs["input_features"].to(torch.bfloat16)
if torch.isnan(inputs.input_features).any():
raise NaNException

inputs = inputs.to(self._model.device)
inputs["input_features"] = inputs["input_features"].to(self.dtype).to(self.device)
if torch.isnan(inputs.input_features).any():
raise NaNException

transcription = self._run(inputs, role)
transcription = self._run(inputs)
if transcription is None or 'nan' in transcription:
raise NaNException

return transcription

def _process_human_speech(self, transcription: str = None):
"""
Process the human speech.
This method is called when the model has processed the human speech
and the transcription is ready to be added to the chat history.
This should only be used for human speech.
:param transcription:
:return:
"""
if transcription == "":
raise ValueError("Transcription is empty")
self.logger.debug("Processing human speech")
data = {
"message": transcription,
"role": LLMChatRole.HUMAN
}
self.emit_signal(
SignalCode.ADD_CHATBOT_MESSAGE_SIGNAL,
data
)

def _run(
self,
inputs,
role: LLMChatRole = LLMChatRole.HUMAN,
inputs
) -> str:
"""
Run the model on the given inputs.
@@ -231,48 +210,59 @@ def _run(
if torch.isnan(input_features).any():
raise NaNException

generated_ids = self._model.generate(
input_features=input_features,
# generation_config=None,
# logits_processor=None,
# stopping_criteria=None,
# prefix_allowed_tokens_fn=None,
# synced_gpus=True,
# return_timestamps=None,
# task="transcribe",
# language="en",
# is_multilingual=True,
# prompt_ids=None,
# prompt_condition_type=None,
# condition_on_prev_tokens=None,
temperature=0.8,
# compression_ratio_threshold=None,
# logprob_threshold=None,
# no_speech_threshold=None,
# num_segment_frames=None,
# attention_mask=None,
# time_precision=0.02,
# return_token_timestamps=None,
# return_segments=False,
# return_dict_in_generate=None,
)
try:
generated_ids = self._model.generate(
input_features=input_features,
# generation_config=None,
# logits_processor=None,
# stopping_criteria=None,
# prefix_allowed_tokens_fn=None,
# synced_gpus=True,
# return_timestamps=None,
# task="transcribe",
# language="en",
is_multilingual=False,
# prompt_ids=None,
# prompt_condition_type=None,
# condition_on_prev_tokens=None,
temperature=0.8,
compression_ratio_threshold=1.35,
logprob_threshold=-1.0,
no_speech_threshold=0.2,
# num_segment_frames=None,
# attention_mask=None,
time_precision=0.02,
# return_token_timestamps=None,
# return_segments=False,
# return_dict_in_generate=None,
)
except RuntimeError as e:
generated_ids = None
self.logger.error(f"Error in model generation: {e}")

if generated_ids is None:
return ""

if torch.isnan(generated_ids).any():
raise NaNException

transcription = self.process_transcription(generated_ids)
if len(transcription) == 0 or len(transcription.split(" ")) == 1:
return ""

# Emit the transcription so that other handlers can use it
return transcription
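The newly enabled arguments are the robustness knobs transformers exposes for Whisper decoding: a pass whose output is too repetitive (compression_ratio_threshold) or too low-confidence (logprob_threshold) can be retried at the fallback temperature, and segments that are probably silence are dropped (no_speech_threshold). An annotated restatement of the call above, with model and input_features standing in for the handler's attributes:

generated_ids = model.generate(
    input_features=input_features,
    is_multilingual=False,             # single-language model: skip language detection
    temperature=0.8,                   # temperature used for fallback re-decoding
    compression_ratio_threshold=1.35,  # above this, output is too repetitive: retry
    logprob_threshold=-1.0,            # below this average logprob: retry
    no_speech_threshold=0.2,           # with logprob_threshold, treat segment as silence
    time_precision=0.02,               # seconds per token for timestamp computation
)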

def _send_transcription(self, transcription: str):
"""
Emit the transcription so that other handlers can use it
"""
self.emit_signal(SignalCode.AUDIO_PROCESSOR_RESPONSE_SIGNAL, {
"transcription": transcription,
"role": role
"transcription": transcription
})

return transcription

def process_transcription(self, generated_ids) -> str:
# Decode the generated ids
generated_ids = generated_ids.to("cpu").to(torch.float32)
transcription = self._processor.batch_decode(
generated_ids,
skip_special_tokens=True
9 changes: 4 additions & 5 deletions src/airunner/handlers/tts/speecht5_tts_handler.py
@@ -328,24 +328,23 @@ def interrupt_process_signal(self):
def _prepare_text(self, text) -> str:
text = self._replace_unspeakable_characters(text)
text = self._strip_emoji_characters(text)
text = self._roman_to_int(text)
# the following function is currently disabled because we must first find a
# reliable way to handle the word "I" and distinguish it from the Roman numeral "I"
# text = self._roman_to_int(text)
text = self._replace_numbers_with_words(text)
text = re.sub(r"\s+", " ", text) # Remove extra spaces
text = text.strip()
return text
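_prepare_text now normalizes in a fixed order: unspeakable characters, emoji, digits, then whitespace. The _replace_numbers_with_words body is outside this diff; a plausible sketch, assuming the third-party num2words package:

import re

from num2words import num2words  # assumed dependency, not confirmed by this diff

def replace_numbers_with_words(text: str) -> str:
    # "3" -> "three", "42" -> "forty-two", so the vocoder never sees digits.
    return re.sub(r"\d+", lambda m: num2words(int(m.group())), text)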

@staticmethod
def _replace_unspeakable_characters(text) -> str:
# strip things like ellipsis, etc
text = text.replace("...", " ")
text = text.replace("…", " ")
text = text.replace("’", "")
text = text.replace("’", "'")
text = text.replace("“", "")
text = text.replace("”", "")
text = text.replace("‘", "")
text = text.replace("–", "")
text = text.replace("—", "")
text = text.replace("'", "")
text = text.replace('"', "")
text = text.replace("-", "")
text = text.replace("-", "")
Expand Down
2 changes: 1 addition & 1 deletion src/airunner/tests/test_speecht5_tts_handler.py
@@ -43,7 +43,7 @@ def test_roman_to_int(self):
"M": "1000",
"MMXXI": "2021",
"This is a IV test": "This is a 4 test",
"A test with no roman numerals": "A test with no roman numerals"
"A test with no roman numerals": "A test with no roman numerals",
}

for roman, expected in test_cases.items():
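These cases pin down what a re-enabled _roman_to_int must do: convert standalone numerals ("M", "MMXXI") and numerals embedded in prose ("IV") while leaving ordinary words untouched. A sketch that also side-steps the pronoun-"I" ambiguity noted in the TTS handler (hypothetical, not the shipped helper):

import re

ROMAN_RE = re.compile(r"\b(M{0,3}(CM|CD|D?C{0,3})(XC|XL|L?X{0,3})(IX|IV|V?I{0,3}))\b")
ROMAN_VALUES = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000}

def roman_to_int(text: str) -> str:
    def convert(match: re.Match) -> str:
        token = match.group(0)
        # Every part of the pattern is optional, so empty matches occur at word
        # boundaries; skip them, and leave a bare "I" alone (the English pronoun).
        if not token or token == "I":
            return token
        total, prev = 0, 0
        for char in reversed(token):
            value = ROMAN_VALUES[char]
            total = total - value if value < prev else total + value
            prev = value
        return str(total)
    return ROMAN_RE.sub(convert, text)

# roman_to_int("This is a IV test") -> "This is a 4 test"
# All-caps words such as "MIX" remain ambiguous, which is why the helper stays disabled.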