From 79ee710064466ae38d28c292e3a3a1c8c1497b67 Mon Sep 17 00:00:00 2001 From: Davy Peter Braun <543614+dheavy@users.noreply.github.com> Date: Wed, 3 Apr 2024 09:16:32 +0200 Subject: [PATCH] Re-lint after rebase --- ROADMAP.md | 5 +- TASKS.md | 4 +- docs/client/setup.mdx | 6 +- docs/server/setup.mdx | 66 +++++----- docs/style.css | 2 +- hardware/light/README.md | 4 +- .../hardware/devices/jetson-nano/README.md | 2 +- software/.cursorignore | 1 - software/pyproject.toml | 2 +- software/source/clients/esp32/README.md | 5 +- .../clients/esp32/src/client/client.ino | 44 +++---- .../clients/esp32/src/client/platformio.ini | 4 +- software/source/clients/linux/device.py | 2 + software/source/clients/mac/device.py | 2 + software/source/clients/rpi/device.py | 4 +- software/source/clients/windows/device.py | 2 + software/source/server/llm.py | 5 +- .../source/server/services/llm/litellm/llm.py | 4 - .../server/services/llm/llamaedge/llm.py | 55 +++++--- .../stt/local-whisper/whisper-rust/.gitignore | 2 +- .../stt/local-whisper/whisper-rust/Cargo.toml | 2 +- .../local-whisper/whisper-rust/src/main.rs | 4 +- .../whisper-rust/src/transcribe.rs | 2 +- .../source/server/services/stt/openai/stt.py | 66 ++++++---- .../source/server/services/tts/piper/tts.py | 118 +++++++++++++----- .../system_messages/BaseSystemMessage.py | 6 +- .../system_messages/TeachModeSystemMessage.py | 6 +- software/source/server/tunnel.py | 35 ++++-- software/source/server/utils/bytes_to_wav.py | 27 ++-- software/source/server/utils/kernel.py | 52 +++++--- software/source/server/utils/logs.py | 5 +- software/source/utils/accumulator.py | 18 +-- software/source/utils/print_markdown.py | 3 +- software/start.py | 95 +++++++++----- 34 files changed, 418 insertions(+), 242 deletions(-) diff --git a/ROADMAP.md b/ROADMAP.md index cf4183db..58938b33 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -1,8 +1,8 @@ # Roadmap -Our goal is to power a billion devices with the 01OS over the next 10 years. The Cambrian explosion of AI devices. +Our goal is to power a billion devices with the 01OS over the next 10 years. The Cambrian explosion of AI devices. -We can do that with your help. Help extend the 01OS to run on new hardware, to connect with new peripherals like GPS and cameras, and add new locally running language models to unlock use-cases for this technology that no-one has even imagined yet. +We can do that with your help. Help extend the 01OS to run on new hardware, to connect with new peripherals like GPS and cameras, and add new locally running language models to unlock use-cases for this technology that no-one has even imagined yet. In the coming months, we're going to release: @@ -10,4 +10,3 @@ In the coming months, we're going to release: - [ ] An open-source language model for computer control - [ ] A react-native app for your phone - [ ] A hand-held device that runs fully offline. - diff --git a/TASKS.md b/TASKS.md index 8197dd65..2477fb2f 100644 --- a/TASKS.md +++ b/TASKS.md @@ -36,7 +36,7 @@ - [ ] Sends to describe API - [ ] prints and returns description - [ ] Llamafile for phi-2 + moondream - - [ ] test on rPi + Jetson (+android mini phone?) + - [ ] test on rPi + Jetson (+android mini phone?) **OS** @@ -66,7 +66,7 @@ **Hardware** -- [ ] (Hardware and software) Get the 01OS working on the **Jetson** or Pi. Pick one to move forward with. +- [ ] (Hardware and software) Get the 01OS working on the **Jetson** or Pi. Pick one to move forward with. - [ ] Connect the Seeed Sense (ESP32 with Wifi, Bluetooth and a mic) to a small DAC + amplifier + speaker. - [ ] Connect the Seeed Sense to a battery. - [ ] Configure the ESP32 to be a wireless mic + speaker for the Jetson or Pi. diff --git a/docs/client/setup.mdx b/docs/client/setup.mdx index d9ed422e..df1a6285 100644 --- a/docs/client/setup.mdx +++ b/docs/client/setup.mdx @@ -34,9 +34,9 @@ poetry run 01 --client ### Flags -- `--client` +- `--client` Run client. -- `--client-type TEXT` - Specify the client type. +- `--client-type TEXT` + Specify the client type. Default: `auto`. diff --git a/docs/server/setup.mdx b/docs/server/setup.mdx index e9e284a0..f500687a 100644 --- a/docs/server/setup.mdx +++ b/docs/server/setup.mdx @@ -44,73 +44,73 @@ For more information, please read about speec ## CLI Flags -- `--server` +- `--server` Run server. -- `--server-host TEXT` - Specify the server host where the server will deploy. +- `--server-host TEXT` + Specify the server host where the server will deploy. Default: `0.0.0.0`. -- `--server-port INTEGER` - Specify the server port where the server will deploy. +- `--server-port INTEGER` + Specify the server port where the server will deploy. Default: `10001`. -- `--tunnel-service TEXT` - Specify the tunnel service. +- `--tunnel-service TEXT` + Specify the tunnel service. Default: `ngrok`. -- `--expose` +- `--expose` Expose server to internet. -- `--server-url TEXT` - Specify the server URL that the client should expect. - Defaults to server-host and server-port. +- `--server-url TEXT` + Specify the server URL that the client should expect. + Defaults to server-host and server-port. Default: `None`. -- `--llm-service TEXT` - Specify the LLM service. +- `--llm-service TEXT` + Specify the LLM service. Default: `litellm`. -- `--model TEXT` - Specify the model. +- `--model TEXT` + Specify the model. Default: `gpt-4`. -- `--llm-supports-vision` +- `--llm-supports-vision` Specify if the LLM service supports vision. -- `--llm-supports-functions` +- `--llm-supports-functions` Specify if the LLM service supports functions. -- `--context-window INTEGER` - Specify the context window size. +- `--context-window INTEGER` + Specify the context window size. Default: `2048`. -- `--max-tokens INTEGER` - Specify the maximum number of tokens. +- `--max-tokens INTEGER` + Specify the maximum number of tokens. Default: `4096`. -- `--temperature FLOAT` - Specify the temperature for generation. +- `--temperature FLOAT` + Specify the temperature for generation. Default: `0.8`. -- `--tts-service TEXT` - Specify the TTS service. +- `--tts-service TEXT` + Specify the TTS service. Default: `openai`. -- `--stt-service TEXT` - Specify the STT service. +- `--stt-service TEXT` + Specify the STT service. Default: `openai`. -- `--local` +- `--local` Use recommended local services for LLM, STT, and TTS. -- `--install-completion [bash|zsh|fish|powershell|pwsh]` - Install completion for the specified shell. +- `--install-completion [bash|zsh|fish|powershell|pwsh]` + Install completion for the specified shell. Default: `None`. -- `--show-completion [bash|zsh|fish|powershell|pwsh]` - Show completion for the specified shell, to copy it or customize the installation. +- `--show-completion [bash|zsh|fish|powershell|pwsh]` + Show completion for the specified shell, to copy it or customize the installation. Default: `None`. -- `--help` +- `--help` Show this message and exit. diff --git a/docs/style.css b/docs/style.css index 392cac0a..52a1d79a 100644 --- a/docs/style.css +++ b/docs/style.css @@ -29,4 +29,4 @@ .body { font-weight: normal; -} \ No newline at end of file +} diff --git a/hardware/light/README.md b/hardware/light/README.md index cd6fcfc2..9ec534ae 100644 --- a/hardware/light/README.md +++ b/hardware/light/README.md @@ -22,13 +22,13 @@ Please install first [PlatformIO](http://platformio.org/) open source ecosystem ```bash cd software/source/clients/esp32/src/client/ -``` +``` And build and upload the firmware with a simple command: ```bash pio run --target upload -``` +``` ## Wifi diff --git a/project_management/hardware/devices/jetson-nano/README.md b/project_management/hardware/devices/jetson-nano/README.md index 600bda41..08a7c02f 100644 --- a/project_management/hardware/devices/jetson-nano/README.md +++ b/project_management/hardware/devices/jetson-nano/README.md @@ -19,4 +19,4 @@ ![](mac-share-internet-v2.png) - d. Now the Jetson should have connectivity! \ No newline at end of file + d. Now the Jetson should have connectivity! diff --git a/software/.cursorignore b/software/.cursorignore index b494a46c..7a81b426 100644 --- a/software/.cursorignore +++ b/software/.cursorignore @@ -1,4 +1,3 @@ _archive __pycache__ .idea - diff --git a/software/pyproject.toml b/software/pyproject.toml index 6d331eae..8b9a5341 100644 --- a/software/pyproject.toml +++ b/software/pyproject.toml @@ -54,4 +54,4 @@ target-version = ['py311'] [tool.isort] profile = "black" multi_line_output = 3 -include_trailing_comma = true \ No newline at end of file +include_trailing_comma = true diff --git a/software/source/clients/esp32/README.md b/software/source/clients/esp32/README.md index 3a80f429..48b6a3a5 100644 --- a/software/source/clients/esp32/README.md +++ b/software/source/clients/esp32/README.md @@ -19,11 +19,10 @@ Please install first [PlatformIO](http://platformio.org/) open source ecosystem ```bash cd client/ -``` +``` And build and upload the firmware with a simple command: ```bash pio run --target upload -``` - +``` diff --git a/software/source/clients/esp32/src/client/client.ino b/software/source/clients/esp32/src/client/client.ino index 77bf8b61..76ba0565 100644 --- a/software/source/clients/esp32/src/client/client.ino +++ b/software/source/clients/esp32/src/client/client.ino @@ -78,11 +78,11 @@ const char post_connected_html[] PROGMEM = R"=====( 01OS Setup @@ -144,7 +144,7 @@ const char post_connected_html[] PROGMEM = R"=====(

- +

@@ -270,7 +270,7 @@ bool connectTo01OS(String server_address) portStr = server_address.substring(colonIndex + 1); } else { domain = server_address; - portStr = ""; + portStr = ""; } WiFiClient c; @@ -281,7 +281,7 @@ bool connectTo01OS(String server_address) port = portStr.toInt(); } - HttpClient http(c, domain.c_str(), port); + HttpClient http(c, domain.c_str(), port); Serial.println("Connecting to 01OS at " + domain + ":" + port + "/ping"); if (domain.indexOf("ngrok") != -1) { @@ -363,7 +363,7 @@ bool connectTo01OS(String server_address) Serial.print("Connection failed: "); Serial.println(err); } - + return connectionSuccess; } @@ -436,7 +436,7 @@ void setUpWebserver(AsyncWebServer &server, const IPAddress &localIP) { String ssid; String password; - + // Check if SSID parameter exists and assign it if(request->hasParam("ssid", true)) { ssid = request->getParam("ssid", true)->value(); @@ -446,7 +446,7 @@ void setUpWebserver(AsyncWebServer &server, const IPAddress &localIP) Serial.println("OTHER SSID SELECTED: " + ssid); } } - + // Check if Password parameter exists and assign it if(request->hasParam("password", true)) { password = request->getParam("password", true)->value(); @@ -458,7 +458,7 @@ void setUpWebserver(AsyncWebServer &server, const IPAddress &localIP) if(request->hasParam("password", true) && request->hasParam("ssid", true)) { connectToWifi(ssid, password); } - + // Redirect user or send a response back if (WiFi.status() == WL_CONNECTED) { @@ -466,7 +466,7 @@ void setUpWebserver(AsyncWebServer &server, const IPAddress &localIP) AsyncWebServerResponse *response = request->beginResponse(200, "text/html", htmlContent); response->addHeader("Cache-Control", "public,max-age=31536000"); // save this file to cache for 1 year (unless you refresh) request->send(response); - Serial.println("Served Post connection HTML Page"); + Serial.println("Served Post connection HTML Page"); } else { request->send(200, "text/plain", "Failed to connect to " + ssid); } }); @@ -474,7 +474,7 @@ void setUpWebserver(AsyncWebServer &server, const IPAddress &localIP) server.on("/submit_01os", HTTP_POST, [](AsyncWebServerRequest *request) { String server_address; - + // Check if SSID parameter exists and assign it if(request->hasParam("server_address", true)) { server_address = request->getParam("server_address", true)->value(); @@ -490,7 +490,7 @@ void setUpWebserver(AsyncWebServer &server, const IPAddress &localIP) { AsyncWebServerResponse *response = request->beginResponse(200, "text/html", successHtml); response->addHeader("Cache-Control", "no-cache, no-store, must-revalidate"); // Prevent caching of this page - request->send(response); + request->send(response); Serial.println(" "); Serial.println("Connected to 01 websocket!"); Serial.println(" "); @@ -502,7 +502,7 @@ void setUpWebserver(AsyncWebServer &server, const IPAddress &localIP) String htmlContent = String(post_connected_html); // Load your HTML template // Inject the error message htmlContent.replace("

", "

Error connecting, please try again.

"); - + AsyncWebServerResponse *response = request->beginResponse(200, "text/html", htmlContent); response->addHeader("Cache-Control", "no-cache, no-store, must-revalidate"); // Prevent caching of this page request->send(response); @@ -622,7 +622,7 @@ void InitI2SSpeakerOrMic(int mode) #if ESP_IDF_VERSION > ESP_IDF_VERSION_VAL(4, 1, 0) .communication_format = I2S_COMM_FORMAT_STAND_I2S, // Set the format of the communication. -#else +#else .communication_format = I2S_COMM_FORMAT_I2S, #endif .intr_alloc_flags = ESP_INTR_FLAG_LEVEL1, @@ -779,17 +779,17 @@ void setup() { Serial.begin(115200); // Initialize serial communication at 115200 baud rate. // Attempt to reconnect to WiFi using stored credentials. // Check if WiFi is connected but the server URL isn't stored - + Serial.setTxBufferSize(1024); // Set the transmit buffer size for the Serial object. WiFi.mode(WIFI_AP_STA); // Set WiFi mode to both AP and STA. - + // delay(100); // Short delay to ensure mode change takes effect // WiFi.softAPConfig(localIP, gatewayIP, subnetMask); // WiFi.softAP(ssid, password); startSoftAccessPoint(ssid, password, localIP, gatewayIP); setUpDNSServer(dnsServer, localIP); - + setUpWebserver(server, localIP); tryReconnectWiFi(); // Print a welcome message to the Serial port. @@ -823,7 +823,7 @@ void loop() if ((millis() - last_dns_ms) > DNS_INTERVAL) { last_dns_ms = millis(); // seems to help with stability, if you are doing other things in the loop this may not be needed dnsServer.processNextRequest(); // I call this atleast every 10ms in my other projects (can be higher but I haven't tested it for stability) - } + } // Check WiFi connection status if (WiFi.status() == WL_CONNECTED && !hasSetupWebsocket) @@ -865,4 +865,4 @@ void loop() M5.update(); webSocket.loop(); } -} \ No newline at end of file +} diff --git a/software/source/clients/esp32/src/client/platformio.ini b/software/source/clients/esp32/src/client/platformio.ini index 6061e13a..d1011e28 100644 --- a/software/source/clients/esp32/src/client/platformio.ini +++ b/software/source/clients/esp32/src/client/platformio.ini @@ -10,7 +10,7 @@ platform = espressif32 framework = arduino monitor_speed = 115200 upload_speed = 1500000 -monitor_filters = +monitor_filters = esp32_exception_decoder time build_flags = @@ -23,7 +23,7 @@ board = esp32dev [env:m5echo] extends = esp32common -lib_deps = +lib_deps = m5stack/M5Atom @ ^0.1.2 links2004/WebSockets @ ^2.4.1 ;esphome/ESPAsyncWebServer-esphome @ ^3.1.0 diff --git a/software/source/clients/linux/device.py b/software/source/clients/linux/device.py index a9a79c02..0fa0fed2 100644 --- a/software/source/clients/linux/device.py +++ b/software/source/clients/linux/device.py @@ -2,9 +2,11 @@ device = Device() + def main(server_url): device.server_url = server_url device.start() + if __name__ == "__main__": main() diff --git a/software/source/clients/mac/device.py b/software/source/clients/mac/device.py index a9a79c02..0fa0fed2 100644 --- a/software/source/clients/mac/device.py +++ b/software/source/clients/mac/device.py @@ -2,9 +2,11 @@ device = Device() + def main(server_url): device.server_url = server_url device.start() + if __name__ == "__main__": main() diff --git a/software/source/clients/rpi/device.py b/software/source/clients/rpi/device.py index 279822f9..fe0250bd 100644 --- a/software/source/clients/rpi/device.py +++ b/software/source/clients/rpi/device.py @@ -2,8 +2,10 @@ device = Device() + def main(): device.start() + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/software/source/clients/windows/device.py b/software/source/clients/windows/device.py index a9a79c02..0fa0fed2 100644 --- a/software/source/clients/windows/device.py +++ b/software/source/clients/windows/device.py @@ -2,9 +2,11 @@ device = Device() + def main(server_url): device.server_url = server_url device.start() + if __name__ == "__main__": main() diff --git a/software/source/server/llm.py b/software/source/server/llm.py index ba761a30..430e58ad 100644 --- a/software/source/server/llm.py +++ b/software/source/server/llm.py @@ -1,4 +1,5 @@ from dotenv import load_dotenv + load_dotenv() # take environment variables from .env. import os @@ -8,7 +9,7 @@ ### LLM SETUP # Define the path to a llamafile -llamafile_path = Path(__file__).parent / 'model.llamafile' +llamafile_path = Path(__file__).parent / "model.llamafile" # Check if the new llamafile exists, if not download it if not os.path.exists(llamafile_path): @@ -25,4 +26,4 @@ subprocess.run(["chmod", "+x", llamafile_path], check=True) # Run the new llamafile -subprocess.run([str(llamafile_path)], check=True) \ No newline at end of file +subprocess.run([str(llamafile_path)], check=True) diff --git a/software/source/server/services/llm/litellm/llm.py b/software/source/server/services/llm/litellm/llm.py index 906308bf..f4093e4a 100644 --- a/software/source/server/services/llm/litellm/llm.py +++ b/software/source/server/services/llm/litellm/llm.py @@ -1,6 +1,5 @@ class Llm: def __init__(self, config): - # Litellm is used by OI by default, so we just modify OI interpreter = config["interpreter"] @@ -10,6 +9,3 @@ def __init__(self, config): setattr(interpreter, key.replace("-", "_"), value) self.llm = interpreter.llm.completions - - - diff --git a/software/source/server/services/llm/llamaedge/llm.py b/software/source/server/services/llm/llamaedge/llm.py index fa77abf6..7894f544 100644 --- a/software/source/server/services/llm/llamaedge/llm.py +++ b/software/source/server/services/llm/llamaedge/llm.py @@ -3,29 +3,54 @@ import requests import json + class Llm: def __init__(self, config): self.install(config["service_directory"]) def install(self, service_directory): LLM_FOLDER_PATH = service_directory - self.llm_directory = os.path.join(LLM_FOLDER_PATH, 'llm') - if not os.path.isdir(self.llm_directory): # Check if the LLM directory exists + self.llm_directory = os.path.join(LLM_FOLDER_PATH, "llm") + if not os.path.isdir(self.llm_directory): # Check if the LLM directory exists os.makedirs(LLM_FOLDER_PATH, exist_ok=True) # Install WasmEdge - subprocess.run(['curl', '-sSf', 'https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh', '|', 'bash', '-s', '--', '--plugin', 'wasi_nn-ggml']) + subprocess.run( + [ + "curl", + "-sSf", + "https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh", + "|", + "bash", + "-s", + "--", + "--plugin", + "wasi_nn-ggml", + ] + ) # Download the Qwen1.5-0.5B-Chat model GGUF file MODEL_URL = "https://huggingface.co/second-state/Qwen1.5-0.5B-Chat-GGUF/resolve/main/Qwen1.5-0.5B-Chat-Q5_K_M.gguf" - subprocess.run(['curl', '-LO', MODEL_URL], cwd=self.llm_directory) - + subprocess.run(["curl", "-LO", MODEL_URL], cwd=self.llm_directory) + # Download the llama-api-server.wasm app APP_URL = "https://github.com/LlamaEdge/LlamaEdge/releases/latest/download/llama-api-server.wasm" - subprocess.run(['curl', '-LO', APP_URL], cwd=self.llm_directory) + subprocess.run(["curl", "-LO", APP_URL], cwd=self.llm_directory) # Run the API server - subprocess.run(['wasmedge', '--dir', '.:.', '--nn-preload', 'default:GGML:AUTO:Qwen1.5-0.5B-Chat-Q5_K_M.gguf', 'llama-api-server.wasm', '-p', 'llama-2-chat'], cwd=self.llm_directory) + subprocess.run( + [ + "wasmedge", + "--dir", + ".:.", + "--nn-preload", + "default:GGML:AUTO:Qwen1.5-0.5B-Chat-Q5_K_M.gguf", + "llama-api-server.wasm", + "-p", + "llama-2-chat", + ], + cwd=self.llm_directory, + ) print("LLM setup completed.") else: @@ -33,17 +58,11 @@ def install(self, service_directory): def llm(self, messages): url = "http://localhost:8080/v1/chat/completions" - headers = { - 'accept': 'application/json', - 'Content-Type': 'application/json' - } - data = { - "messages": messages, - "model": "llama-2-chat" - } - with requests.post(url, headers=headers, data=json.dumps(data), stream=True) as response: + headers = {"accept": "application/json", "Content-Type": "application/json"} + data = {"messages": messages, "model": "llama-2-chat"} + with requests.post( + url, headers=headers, data=json.dumps(data), stream=True + ) as response: for line in response.iter_lines(): if line: yield json.loads(line) - - diff --git a/software/source/server/services/stt/local-whisper/whisper-rust/.gitignore b/software/source/server/services/stt/local-whisper/whisper-rust/.gitignore index 71ab9a43..73fab072 100644 --- a/software/source/server/services/stt/local-whisper/whisper-rust/.gitignore +++ b/software/source/server/services/stt/local-whisper/whisper-rust/.gitignore @@ -7,4 +7,4 @@ target/ **/*.rs.bk # MSVC Windows builds of rustc generate these, which store debugging information -*.pdb \ No newline at end of file +*.pdb diff --git a/software/source/server/services/stt/local-whisper/whisper-rust/Cargo.toml b/software/source/server/services/stt/local-whisper/whisper-rust/Cargo.toml index f1726929..c3daf687 100644 --- a/software/source/server/services/stt/local-whisper/whisper-rust/Cargo.toml +++ b/software/source/server/services/stt/local-whisper/whisper-rust/Cargo.toml @@ -11,4 +11,4 @@ clap = { version = "4.4.18", features = ["derive"] } cpal = "0.15.2" hound = "3.5.1" whisper-rs = "0.10.0" -whisper-rs-sys = "0.8.0" \ No newline at end of file +whisper-rs-sys = "0.8.0" diff --git a/software/source/server/services/stt/local-whisper/whisper-rust/src/main.rs b/software/source/server/services/stt/local-whisper/whisper-rust/src/main.rs index 0688c89e..52965388 100644 --- a/software/source/server/services/stt/local-whisper/whisper-rust/src/main.rs +++ b/software/source/server/services/stt/local-whisper/whisper-rust/src/main.rs @@ -10,7 +10,7 @@ struct Args { /// This is the model for Whisper STT #[arg(short, long, value_parser, required = true)] model_path: PathBuf, - + /// This is the wav audio file that will be converted from speech to text #[arg(short, long, value_parser, required = true)] file_path: Option, @@ -31,4 +31,4 @@ fn main() { Ok(transcription) => print!("{}", transcription), Err(e) => panic!("Error: {}", e), } -} \ No newline at end of file +} diff --git a/software/source/server/services/stt/local-whisper/whisper-rust/src/transcribe.rs b/software/source/server/services/stt/local-whisper/whisper-rust/src/transcribe.rs index 35970cc9..99e1a527 100644 --- a/software/source/server/services/stt/local-whisper/whisper-rust/src/transcribe.rs +++ b/software/source/server/services/stt/local-whisper/whisper-rust/src/transcribe.rs @@ -61,4 +61,4 @@ pub fn transcribe(model_path: &PathBuf, file_path: &PathBuf) -> Result str: if mime_type == "audio/x-wav" or mime_type == "audio/wav": return "wav" @@ -29,30 +29,37 @@ def convert_mime_type_to_format(mime_type: str) -> str: return mime_type + @contextlib.contextmanager def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str: temp_dir = tempfile.gettempdir() # Create a temporary file with the appropriate extension input_ext = convert_mime_type_to_format(mime_type) - input_path = os.path.join(temp_dir, f"input_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.{input_ext}") - with open(input_path, 'wb') as f: + input_path = os.path.join( + temp_dir, f"input_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.{input_ext}" + ) + with open(input_path, "wb") as f: f.write(audio) # Check if the input file exists assert os.path.exists(input_path), f"Input file does not exist: {input_path}" # Export to wav - output_path = os.path.join(temp_dir, f"output_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav") + output_path = os.path.join( + temp_dir, f"output_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav" + ) if mime_type == "audio/raw": ffmpeg.input( input_path, - f='s16le', - ar='16000', + f="s16le", + ar="16000", ac=1, - ).output(output_path, loglevel='panic').run() + ).output(output_path, loglevel="panic").run() else: - ffmpeg.input(input_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k', loglevel='panic').run() + ffmpeg.input(input_path).output( + output_path, acodec="pcm_s16le", ac=1, ar="16k", loglevel="panic" + ).run() try: yield output_path @@ -60,39 +67,49 @@ def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str: os.remove(input_path) os.remove(output_path) + def run_command(command): - result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + result = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) return result.stdout, result.stderr -def get_transcription_file(wav_file_path: str): - local_path = os.path.join(os.path.dirname(__file__), 'local_service') - whisper_rust_path = os.path.join(os.path.dirname(__file__), 'whisper-rust', 'target', 'release') - model_name = os.getenv('WHISPER_MODEL_NAME', 'ggml-tiny.en.bin') - output, error = run_command([ - os.path.join(whisper_rust_path, 'whisper-rust'), - '--model-path', os.path.join(local_path, model_name), - '--file-path', wav_file_path - ]) +def get_transcription_file(wav_file_path: str): + local_path = os.path.join(os.path.dirname(__file__), "local_service") + whisper_rust_path = os.path.join( + os.path.dirname(__file__), "whisper-rust", "target", "release" + ) + model_name = os.getenv("WHISPER_MODEL_NAME", "ggml-tiny.en.bin") + + output, error = run_command( + [ + os.path.join(whisper_rust_path, "whisper-rust"), + "--model-path", + os.path.join(local_path, model_name), + "--file-path", + wav_file_path, + ] + ) return output + def get_transcription_bytes(audio_bytes: bytearray, mime_type): with export_audio_to_wav_ffmpeg(audio_bytes, mime_type) as wav_file_path: return get_transcription_file(wav_file_path) + def stt_bytes(audio_bytes: bytearray, mime_type="audio/wav"): with export_audio_to_wav_ffmpeg(audio_bytes, mime_type) as wav_file_path: return stt_wav(wav_file_path) -def stt_wav(wav_file_path: str): +def stt_wav(wav_file_path: str): audio_file = open(wav_file_path, "rb") try: transcript = client.audio.transcriptions.create( - model="whisper-1", - file=audio_file, - response_format="text" + model="whisper-1", file=audio_file, response_format="text" ) except openai.BadRequestError as e: print(f"openai.BadRequestError: {e}") @@ -100,10 +117,13 @@ def stt_wav(wav_file_path: str): return transcript + def stt(input_data, mime_type="audio/wav"): if isinstance(input_data, str): return stt_wav(input_data) elif isinstance(input_data, bytearray): return stt_bytes(input_data, mime_type) else: - raise ValueError("Input data should be either a path to a wav file (str) or audio bytes (bytearray)") \ No newline at end of file + raise ValueError( + "Input data should be either a path to a wav file (str) or audio bytes (bytearray)" + ) diff --git a/software/source/server/services/tts/piper/tts.py b/software/source/server/services/tts/piper/tts.py index 46d23dc8..8daa1584 100644 --- a/software/source/server/services/tts/piper/tts.py +++ b/software/source/server/services/tts/piper/tts.py @@ -13,26 +13,40 @@ def __init__(self, config): self.install(config["service_directory"]) def tts(self, text): - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file: output_file = temp_file.name piper_dir = self.piper_directory - subprocess.run([ - os.path.join(piper_dir, 'piper'), - '--model', os.path.join(piper_dir, os.getenv('PIPER_VOICE_NAME', 'en_US-lessac-medium.onnx')), - '--output_file', output_file - ], input=text, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + subprocess.run( + [ + os.path.join(piper_dir, "piper"), + "--model", + os.path.join( + piper_dir, + os.getenv("PIPER_VOICE_NAME", "en_US-lessac-medium.onnx"), + ), + "--output_file", + output_file, + ], + input=text, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) # TODO: hack to format audio correctly for device outfile = tempfile.gettempdir() + "/" + "raw.dat" - ffmpeg.input(temp_file.name).output(outfile, f="s16le", ar="16000", ac="1", loglevel='panic').run() + ffmpeg.input(temp_file.name).output( + outfile, f="s16le", ar="16000", ac="1", loglevel="panic" + ).run() return outfile def install(self, service_directory): PIPER_FOLDER_PATH = service_directory - self.piper_directory = os.path.join(PIPER_FOLDER_PATH, 'piper') - if not os.path.isdir(self.piper_directory): # Check if the Piper directory exists + self.piper_directory = os.path.join(PIPER_FOLDER_PATH, "piper") + if not os.path.isdir( + self.piper_directory + ): # Check if the Piper directory exists os.makedirs(PIPER_FOLDER_PATH, exist_ok=True) # Determine OS and architecture @@ -60,52 +74,92 @@ def install(self, service_directory): asset_url = f"{PIPER_URL}{PIPER_ASSETNAME}" if OS == "windows": - asset_url = asset_url.replace(".tar.gz", ".zip") # Download and extract Piper - urllib.request.urlretrieve(asset_url, os.path.join(PIPER_FOLDER_PATH, PIPER_ASSETNAME)) + urllib.request.urlretrieve( + asset_url, os.path.join(PIPER_FOLDER_PATH, PIPER_ASSETNAME) + ) # Extract the downloaded file if OS == "windows": import zipfile - with zipfile.ZipFile(os.path.join(PIPER_FOLDER_PATH, PIPER_ASSETNAME), 'r') as zip_ref: + + with zipfile.ZipFile( + os.path.join(PIPER_FOLDER_PATH, PIPER_ASSETNAME), "r" + ) as zip_ref: zip_ref.extractall(path=PIPER_FOLDER_PATH) else: - with tarfile.open(os.path.join(PIPER_FOLDER_PATH, PIPER_ASSETNAME), 'r:gz') as tar: + with tarfile.open( + os.path.join(PIPER_FOLDER_PATH, PIPER_ASSETNAME), "r:gz" + ) as tar: tar.extractall(path=PIPER_FOLDER_PATH) - PIPER_VOICE_URL = os.getenv('PIPER_VOICE_URL', - 'https://huggingface.co/rhasspy/piper-voices/resolve/main/en/en_US/lessac/medium/') - PIPER_VOICE_NAME = os.getenv('PIPER_VOICE_NAME', 'en_US-lessac-medium.onnx') + PIPER_VOICE_URL = os.getenv( + "PIPER_VOICE_URL", + "https://huggingface.co/rhasspy/piper-voices/resolve/main/en/en_US/lessac/medium/", + ) + PIPER_VOICE_NAME = os.getenv("PIPER_VOICE_NAME", "en_US-lessac-medium.onnx") # Download voice model and its json file - urllib.request.urlretrieve(f"{PIPER_VOICE_URL}{PIPER_VOICE_NAME}", - os.path.join(self.piper_directory, PIPER_VOICE_NAME)) - urllib.request.urlretrieve(f"{PIPER_VOICE_URL}{PIPER_VOICE_NAME}.json", - os.path.join(self.piper_directory, f"{PIPER_VOICE_NAME}.json")) + urllib.request.urlretrieve( + f"{PIPER_VOICE_URL}{PIPER_VOICE_NAME}", + os.path.join(self.piper_directory, PIPER_VOICE_NAME), + ) + urllib.request.urlretrieve( + f"{PIPER_VOICE_URL}{PIPER_VOICE_NAME}.json", + os.path.join(self.piper_directory, f"{PIPER_VOICE_NAME}.json"), + ) # Additional setup for macOS if OS == "macos": if ARCH == "x64": - subprocess.run(['softwareupdate', '--install-rosetta', '--agree-to-license']) + subprocess.run( + ["softwareupdate", "--install-rosetta", "--agree-to-license"] + ) PIPER_PHONEMIZE_ASSETNAME = f"piper-phonemize_{OS}_{ARCH}.tar.gz" PIPER_PHONEMIZE_URL = "https://github.com/rhasspy/piper-phonemize/releases/latest/download/" - urllib.request.urlretrieve(f"{PIPER_PHONEMIZE_URL}{PIPER_PHONEMIZE_ASSETNAME}", - os.path.join(self.piper_directory, PIPER_PHONEMIZE_ASSETNAME)) - - with tarfile.open(os.path.join(self.piper_directory, PIPER_PHONEMIZE_ASSETNAME), 'r:gz') as tar: + urllib.request.urlretrieve( + f"{PIPER_PHONEMIZE_URL}{PIPER_PHONEMIZE_ASSETNAME}", + os.path.join(self.piper_directory, PIPER_PHONEMIZE_ASSETNAME), + ) + + with tarfile.open( + os.path.join(self.piper_directory, PIPER_PHONEMIZE_ASSETNAME), + "r:gz", + ) as tar: tar.extractall(path=self.piper_directory) PIPER_DIR = self.piper_directory - subprocess.run(['install_name_tool', '-change', '@rpath/libespeak-ng.1.dylib', - f"{PIPER_DIR}/piper-phonemize/lib/libespeak-ng.1.dylib", f"{PIPER_DIR}/piper"]) - subprocess.run(['install_name_tool', '-change', '@rpath/libonnxruntime.1.14.1.dylib', - f"{PIPER_DIR}/piper-phonemize/lib/libonnxruntime.1.14.1.dylib", f"{PIPER_DIR}/piper"]) - subprocess.run(['install_name_tool', '-change', '@rpath/libpiper_phonemize.1.dylib', - f"{PIPER_DIR}/piper-phonemize/lib/libpiper_phonemize.1.dylib", f"{PIPER_DIR}/piper"]) + subprocess.run( + [ + "install_name_tool", + "-change", + "@rpath/libespeak-ng.1.dylib", + f"{PIPER_DIR}/piper-phonemize/lib/libespeak-ng.1.dylib", + f"{PIPER_DIR}/piper", + ] + ) + subprocess.run( + [ + "install_name_tool", + "-change", + "@rpath/libonnxruntime.1.14.1.dylib", + f"{PIPER_DIR}/piper-phonemize/lib/libonnxruntime.1.14.1.dylib", + f"{PIPER_DIR}/piper", + ] + ) + subprocess.run( + [ + "install_name_tool", + "-change", + "@rpath/libpiper_phonemize.1.dylib", + f"{PIPER_DIR}/piper-phonemize/lib/libpiper_phonemize.1.dylib", + f"{PIPER_DIR}/piper", + ] + ) print("Piper setup completed.") else: - print("Piper already set up. Skipping download.") \ No newline at end of file + print("Piper already set up. Skipping download.") diff --git a/software/source/server/system_messages/BaseSystemMessage.py b/software/source/server/system_messages/BaseSystemMessage.py index 20429f3d..7fdaefce 100644 --- a/software/source/server/system_messages/BaseSystemMessage.py +++ b/software/source/server/system_messages/BaseSystemMessage.py @@ -36,7 +36,7 @@ The user's current task is: {{ tasks[0] if tasks else "No current tasks." }} -{{ +{{ if len(tasks) > 1: print("The next task is: ", tasks[1]) }} @@ -91,7 +91,7 @@ The user's current task is: {{ tasks[0] if tasks else "No current tasks." }} -{{ +{{ if len(tasks) > 1: print("The next task is: ", tasks[1]) }} @@ -184,7 +184,7 @@ finally: sys.stdout = original_stdout sys.stderr = original_stderr - + }} # SKILLS diff --git a/software/source/server/system_messages/TeachModeSystemMessage.py b/software/source/server/system_messages/TeachModeSystemMessage.py index d88708c6..4c8ec09f 100644 --- a/software/source/server/system_messages/TeachModeSystemMessage.py +++ b/software/source/server/system_messages/TeachModeSystemMessage.py @@ -96,7 +96,7 @@ finally: sys.stdout = original_stdout sys.stderr = original_stderr - + }} # SKILLS LIBRARY @@ -131,4 +131,6 @@ Remember: You can run Python code outside a function only to run a Python function; all other code must go in a in Python function if you first write a Python function. ALL imports must go inside the function. -""".strip().replace("OI_SKILLS_DIR", os.path.abspath(os.path.join(os.path.dirname(__file__), "skills"))) \ No newline at end of file +""".strip().replace( + "OI_SKILLS_DIR", os.path.abspath(os.path.join(os.path.dirname(__file__), "skills")) +) diff --git a/software/source/server/tunnel.py b/software/source/server/tunnel.py index 6d6acb01..809db081 100644 --- a/software/source/server/tunnel.py +++ b/software/source/server/tunnel.py @@ -1,12 +1,14 @@ import subprocess import re -import shutil import pyqrcode import time from ..utils.print_markdown import print_markdown -def create_tunnel(tunnel_method='ngrok', server_host='localhost', server_port=10001, qr=False): - print_markdown(f"Exposing server to the internet...") + +def create_tunnel( + tunnel_method="ngrok", server_host="localhost", server_port=10001, qr=False +): + print_markdown("Exposing server to the internet...") server_url = "" if tunnel_method == "bore": @@ -35,9 +37,11 @@ def create_tunnel(tunnel_method='ngrok', server_host='localhost', server_port=10 if not line: break if "listening at bore.pub:" in line: - remote_port = re.search('bore.pub:([0-9]*)', line).group(1) + remote_port = re.search("bore.pub:([0-9]*)", line).group(1) server_url = f"bore.pub:{remote_port}" - print_markdown(f"Your server is being hosted at the following URL: bore.pub:{remote_port}") + print_markdown( + f"Your server is being hosted at the following URL: bore.pub:{remote_port}" + ) break elif tunnel_method == "localtunnel": @@ -69,9 +73,11 @@ def create_tunnel(tunnel_method='ngrok', server_host='localhost', server_port=10 match = url_pattern.search(line) if match: found_url = True - remote_url = match.group(0).replace('your url is: ', '') + remote_url = match.group(0).replace("your url is: ", "") server_url = remote_url - print(f"\nYour server is being hosted at the following URL: {remote_url}") + print( + f"\nYour server is being hosted at the following URL: {remote_url}" + ) break # Exit the loop once the URL is found if not found_url: @@ -93,7 +99,11 @@ def create_tunnel(tunnel_method='ngrok', server_host='localhost', server_port=10 # If ngrok is installed, start it on the specified port # process = subprocess.Popen(f'ngrok http {server_port} --log=stdout', shell=True, stdout=subprocess.PIPE) - process = subprocess.Popen(f'ngrok http {server_port} --scheme http,https --domain=marten-advanced-dragon.ngrok-free.app --log=stdout', shell=True, stdout=subprocess.PIPE) + process = subprocess.Popen( + f"ngrok http {server_port} --scheme http,https --domain=marten-advanced-dragon.ngrok-free.app --log=stdout", + shell=True, + stdout=subprocess.PIPE, + ) # Initially, no URL is found found_url = False @@ -110,15 +120,18 @@ def create_tunnel(tunnel_method='ngrok', server_host='localhost', server_port=10 found_url = True remote_url = match.group(0) server_url = remote_url - print(f"\nYour server is being hosted at the following URL: {remote_url}") + print( + f"\nYour server is being hosted at the following URL: {remote_url}" + ) break # Exit the loop once the URL is found if not found_url: - print("Failed to extract the ngrok tunnel URL. Please check ngrok's output for details.") + print( + "Failed to extract the ngrok tunnel URL. Please check ngrok's output for details." + ) if server_url and qr: text = pyqrcode.create(remote_url) print(text.terminal(quiet_zone=1)) return server_url - diff --git a/software/source/server/utils/bytes_to_wav.py b/software/source/server/utils/bytes_to_wav.py index d40ae150..a1892576 100644 --- a/software/source/server/utils/bytes_to_wav.py +++ b/software/source/server/utils/bytes_to_wav.py @@ -5,6 +5,7 @@ import ffmpeg import subprocess + def convert_mime_type_to_format(mime_type: str) -> str: if mime_type == "audio/x-wav" or mime_type == "audio/wav": return "wav" @@ -15,39 +16,49 @@ def convert_mime_type_to_format(mime_type: str) -> str: return mime_type + @contextlib.contextmanager def export_audio_to_wav_ffmpeg(audio: bytearray, mime_type: str) -> str: temp_dir = tempfile.gettempdir() # Create a temporary file with the appropriate extension input_ext = convert_mime_type_to_format(mime_type) - input_path = os.path.join(temp_dir, f"input_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.{input_ext}") - with open(input_path, 'wb') as f: + input_path = os.path.join( + temp_dir, f"input_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.{input_ext}" + ) + with open(input_path, "wb") as f: f.write(audio) # Check if the input file exists assert os.path.exists(input_path), f"Input file does not exist: {input_path}" # Export to wav - output_path = os.path.join(temp_dir, f"output_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav") + output_path = os.path.join( + temp_dir, f"output_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.wav" + ) print(mime_type, input_path, output_path) if mime_type == "audio/raw": ffmpeg.input( input_path, - f='s16le', - ar='16000', + f="s16le", + ar="16000", ac=1, - ).output(output_path, loglevel='panic').run() + ).output(output_path, loglevel="panic").run() else: - ffmpeg.input(input_path).output(output_path, acodec='pcm_s16le', ac=1, ar='16k', loglevel='panic').run() + ffmpeg.input(input_path).output( + output_path, acodec="pcm_s16le", ac=1, ar="16k", loglevel="panic" + ).run() try: yield output_path finally: os.remove(input_path) + def run_command(command): - result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + result = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) return result.stdout, result.stderr diff --git a/software/source/server/utils/kernel.py b/software/source/server/utils/kernel.py index 4b800af9..fcca1076 100644 --- a/software/source/server/utils/kernel.py +++ b/software/source/server/utils/kernel.py @@ -1,4 +1,5 @@ from dotenv import load_dotenv + load_dotenv() # take environment variables from .env. import asyncio @@ -7,42 +8,49 @@ from .logs import setup_logging from .logs import logger + setup_logging() + def get_kernel_messages(): """ Is this the way to do this? """ current_platform = platform.system() - + if current_platform == "Darwin": - process = subprocess.Popen(['syslog'], stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) + process = subprocess.Popen( + ["syslog"], stdout=subprocess.PIPE, stderr=subprocess.DEVNULL + ) output, _ = process.communicate() - return output.decode('utf-8') + return output.decode("utf-8") elif current_platform == "Linux": - with open('/var/log/dmesg', 'r') as file: + with open("/var/log/dmesg", "r") as file: return file.read() else: logger.info("Unsupported platform.") + def custom_filter(message): # Check for {TO_INTERPRETER{ message here }TO_INTERPRETER} pattern - if '{TO_INTERPRETER{' in message and '}TO_INTERPRETER}' in message: - start = message.find('{TO_INTERPRETER{') + len('{TO_INTERPRETER{') - end = message.find('}TO_INTERPRETER}', start) + if "{TO_INTERPRETER{" in message and "}TO_INTERPRETER}" in message: + start = message.find("{TO_INTERPRETER{") + len("{TO_INTERPRETER{") + end = message.find("}TO_INTERPRETER}", start) return message[start:end] # Check for USB mention # elif 'USB' in message: # return message # # Check for network related keywords # elif any(keyword in message for keyword in ['network', 'IP', 'internet', 'LAN', 'WAN', 'router', 'switch']) and "networkStatusForFlags" not in message: - + # return message else: return None - + + last_messages = "" + def check_filtered_kernel(): messages = get_kernel_messages() if messages is None: @@ -51,12 +59,12 @@ def check_filtered_kernel(): global last_messages messages.replace(last_messages, "") messages = messages.split("\n") - + filtered_messages = [] for message in messages: if custom_filter(message): filtered_messages.append(message) - + return "\n".join(filtered_messages) @@ -66,11 +74,25 @@ async def put_kernel_messages_into_queue(queue): if text: if isinstance(queue, asyncio.Queue): await queue.put({"role": "computer", "type": "console", "start": True}) - await queue.put({"role": "computer", "type": "console", "format": "output", "content": text}) + await queue.put( + { + "role": "computer", + "type": "console", + "format": "output", + "content": text, + } + ) await queue.put({"role": "computer", "type": "console", "end": True}) else: queue.put({"role": "computer", "type": "console", "start": True}) - queue.put({"role": "computer", "type": "console", "format": "output", "content": text}) + queue.put( + { + "role": "computer", + "type": "console", + "format": "output", + "content": text, + } + ) queue.put({"role": "computer", "type": "console", "end": True}) - - await asyncio.sleep(5) \ No newline at end of file + + await asyncio.sleep(5) diff --git a/software/source/server/utils/logs.py b/software/source/server/utils/logs.py index 5aca8bb6..7b071a63 100644 --- a/software/source/server/utils/logs.py +++ b/software/source/server/utils/logs.py @@ -1,4 +1,5 @@ from dotenv import load_dotenv + load_dotenv() # take environment variables from .env. import os @@ -9,9 +10,7 @@ def _basic_config() -> None: - logging.basicConfig( - format="%(message)s" - ) + logging.basicConfig(format="%(message)s") def setup_logging() -> None: diff --git a/software/source/utils/accumulator.py b/software/source/utils/accumulator.py index edecda16..37912b5c 100644 --- a/software/source/utils/accumulator.py +++ b/software/source/utils/accumulator.py @@ -1,12 +1,11 @@ class Accumulator: def __init__(self): - self.template = {"role": None, "type": None, "format": None, "content": None} + self.template = {"role": None, "type": None, "format": None, "content": None} self.message = self.template def accumulate(self, chunk): - #print(str(chunk)[:100]) + # print(str(chunk)[:100]) if type(chunk) == dict: - if "format" in chunk and chunk["format"] == "active_line": # We don't do anything with these return None @@ -17,15 +16,20 @@ def accumulate(self, chunk): return None if "content" in chunk: - - if any(self.message[key] != chunk[key] for key in self.message if key != "content"): + if any( + self.message[key] != chunk[key] + for key in self.message + if key != "content" + ): self.message = chunk if "content" not in self.message: self.message["content"] = chunk["content"] else: if type(chunk["content"]) == dict: # dict concatenation cannot happen, so we see if chunk is a dict - self.message["content"]["content"] += chunk["content"]["content"] + self.message["content"]["content"] += chunk["content"][ + "content" + ] else: self.message["content"] += chunk["content"] return None @@ -41,5 +45,3 @@ def accumulate(self, chunk): self.message["content"] = b"" self.message["content"] += chunk return None - - \ No newline at end of file diff --git a/software/source/utils/print_markdown.py b/software/source/utils/print_markdown.py index 9fbbda80..f4eff474 100644 --- a/software/source/utils/print_markdown.py +++ b/software/source/utils/print_markdown.py @@ -1,9 +1,10 @@ from rich.console import Console from rich.markdown import Markdown + def print_markdown(markdown_text): console = Console() md = Markdown(markdown_text) print("") console.print(md) - print("") \ No newline at end of file + print("") diff --git a/software/start.py b/software/start.py index d521ad01..4f3377f9 100644 --- a/software/start.py +++ b/software/start.py @@ -15,35 +15,64 @@ @app.command() def run( - server: bool = typer.Option(False, "--server", help="Run server"), - server_host: str = typer.Option("0.0.0.0", "--server-host", help="Specify the server host where the server will deploy"), - server_port: int = typer.Option(10001, "--server-port", help="Specify the server port where the server will deploy"), - - tunnel_service: str = typer.Option("ngrok", "--tunnel-service", help="Specify the tunnel service"), - expose: bool = typer.Option(False, "--expose", help="Expose server to internet"), - - client: bool = typer.Option(False, "--client", help="Run client"), - server_url: str = typer.Option(None, "--server-url", help="Specify the server URL that the client should expect. Defaults to server-host and server-port"), - client_type: str = typer.Option("auto", "--client-type", help="Specify the client type"), - - llm_service: str = typer.Option("litellm", "--llm-service", help="Specify the LLM service"), - - model: str = typer.Option("gpt-4", "--model", help="Specify the model"), - llm_supports_vision: bool = typer.Option(False, "--llm-supports-vision", help="Specify if the LLM service supports vision"), - llm_supports_functions: bool = typer.Option(False, "--llm-supports-functions", help="Specify if the LLM service supports functions"), - context_window: int = typer.Option(2048, "--context-window", help="Specify the context window size"), - max_tokens: int = typer.Option(4096, "--max-tokens", help="Specify the maximum number of tokens"), - temperature: float = typer.Option(0.8, "--temperature", help="Specify the temperature for generation"), - - tts_service: str = typer.Option("openai", "--tts-service", help="Specify the TTS service"), - - stt_service: str = typer.Option("openai", "--stt-service", help="Specify the STT service"), - - local: bool = typer.Option(False, "--local", help="Use recommended local services for LLM, STT, and TTS"), - - qr: bool = typer.Option(False, "--qr", help="Print the QR code for the server URL") - ): - + server: bool = typer.Option(False, "--server", help="Run server"), + server_host: str = typer.Option( + "0.0.0.0", + "--server-host", + help="Specify the server host where the server will deploy", + ), + server_port: int = typer.Option( + 10001, + "--server-port", + help="Specify the server port where the server will deploy", + ), + tunnel_service: str = typer.Option( + "ngrok", "--tunnel-service", help="Specify the tunnel service" + ), + expose: bool = typer.Option(False, "--expose", help="Expose server to internet"), + client: bool = typer.Option(False, "--client", help="Run client"), + server_url: str = typer.Option( + None, + "--server-url", + help="Specify the server URL that the client should expect. Defaults to server-host and server-port", + ), + client_type: str = typer.Option( + "auto", "--client-type", help="Specify the client type" + ), + llm_service: str = typer.Option( + "litellm", "--llm-service", help="Specify the LLM service" + ), + model: str = typer.Option("gpt-4", "--model", help="Specify the model"), + llm_supports_vision: bool = typer.Option( + False, + "--llm-supports-vision", + help="Specify if the LLM service supports vision", + ), + llm_supports_functions: bool = typer.Option( + False, + "--llm-supports-functions", + help="Specify if the LLM service supports functions", + ), + context_window: int = typer.Option( + 2048, "--context-window", help="Specify the context window size" + ), + max_tokens: int = typer.Option( + 4096, "--max-tokens", help="Specify the maximum number of tokens" + ), + temperature: float = typer.Option( + 0.8, "--temperature", help="Specify the temperature for generation" + ), + tts_service: str = typer.Option( + "openai", "--tts-service", help="Specify the TTS service" + ), + stt_service: str = typer.Option( + "openai", "--stt-service", help="Specify the STT service" + ), + local: bool = typer.Option( + False, "--local", help="Use recommended local services for LLM, STT, and TTS" + ), + qr: bool = typer.Option(False, "--qr", help="Print the QR code for the server URL"), +): _run( server=server, server_host=server_host, @@ -63,7 +92,7 @@ def run( tts_service=tts_service, stt_service=stt_service, local=local, - qr=qr + qr=qr, ) @@ -86,7 +115,7 @@ def _run( tts_service: str = "openai", stt_service: str = "openai", local: bool = False, - qr: bool = False + qr: bool = False, ): if local: tts_service = "piper" @@ -130,7 +159,9 @@ def handle_exit(signum, frame): server_thread.start() if expose: - tunnel_thread = threading.Thread(target=create_tunnel, args=[tunnel_service, server_host, server_port, qr]) + tunnel_thread = threading.Thread( + target=create_tunnel, args=[tunnel_service, server_host, server_port, qr] + ) tunnel_thread.start() if client: