Merge pull request #19 from SunbirdAI/update_inference_endpoints
Update inference endpoints
PatrickCmd authored May 22, 2024
2 parents d663c4c + 095a314 commit 27299ae
Showing 3 changed files with 35 additions and 29 deletions.
52 changes: 31 additions & 21 deletions app/routers/tasks.py
@@ -1,3 +1,4 @@
import logging
import os
import re
import shutil
@@ -36,11 +37,14 @@
router = APIRouter()

load_dotenv()
logging.basicConfig(level=logging.INFO)

PER_MINUTE_RATE_LIMIT = os.getenv("PER_MINUTE_RATE_LIMIT", 10)
RUNPOD_ENDPOINT_LANGUAGE_ID_ID = os.getenv("RUNPOD_ENDPOINT_LANGUAGE_ID_ID")
RUNPOD_ENDPOINT_ID = os.getenv("RUNPOD_ENDPOINT_ID")
# Set RunPod API Key
runpod.api_key = os.getenv("RUNPOD_API_KEY")


# Route for the Language identification endpoint
@router.post(
"/language_id",
@@ -51,38 +55,37 @@ async def language_id(
languageId_request: LanguageIdRequest, current_user=Depends(get_current_user)
):
"""
    This endpoint identifies the language of a given text. It supports a limited
set of local languages including Acholi (ach), Ateso (teo), English (eng),
Luganda (lug), Lugbara (lgg), and Runyankole (nyn).
"""

# Define the endpoint ID for language identification
endpoint = runpod.Endpoint(RUNPOD_ENDPOINT_LANGUAGE_ID_ID)
endpoint = runpod.Endpoint(RUNPOD_ENDPOINT_ID)
request_response = {}

try:
# Run the language identification request asynchronously
run_request = endpoint.run_sync(
request_response = endpoint.run_sync(
{
"input": {
"task": "auto_detect_language",
"text": languageId_request.text,
}
},
timeout=60, # Timeout in seconds.
)

# Log the request for debugging purposes
print(run_request)
logging.info(f"Request response: {request_response}")

except TimeoutError:
# Handle timeout error and return a meaningful message to the user
print("Job timed out.")
logging.error("Job timed out.")
raise HTTPException(
status_code=408,
detail="The language identification job timed out. Please try again later.",
)

# Return the result of the language identification request
return run_request
return request_response


@router.post(
@@ -97,7 +100,7 @@ async def speech_to_text(
"""
Upload an audio file and get the transcription text of the audio
"""
endpoint = runpod.Endpoint(os.getenv("RUNPOD_ENDPOINT_ASR_STT_ID"))
endpoint = runpod.Endpoint(RUNPOD_ENDPOINT_ID)

filename = secure_filename(audio.filename)
file_path = os.path.join("/tmp", filename)
@@ -107,12 +110,14 @@
blob_name = upload_audio_file(file_path=file_path)
audio_file = blob_name
os.remove(file_path)
request_response = {}

start_time = time.time()
try:
run_request = endpoint.run_sync(
request_response = endpoint.run_sync(
{
"input": {
"task": "transcribe",
"target_lang": language,
"adapter": adapter,
"audio_file": audio_file,
@@ -121,14 +126,17 @@
timeout=600, # Timeout in seconds.
)
except TimeoutError:
print("Job timed out.")
logging.error("Job timed out.")

end_time = time.time()
logging.info(f"Response: {request_response}")

# Calculate the elapsed time
elapsed_time = end_time - start_time
print("Elapsed time:", elapsed_time, "seconds")
return STTTranscript(audio_transcription=run_request.get("audio_transcription"))
logging.info(f"Elapsed time: {elapsed_time} seconds")
return STTTranscript(
audio_transcription=request_response.get("audio_transcription")
)


# Route for the nllb translation endpoint
@@ -147,16 +155,15 @@ async def nllb_translate(
languages listed, the target can be any of the other languages.
"""
# URL for the endpoint
ENDPOINT_ID = os.getenv("RUNPOD_ENDPOINT_ID")
url = f"https://api.runpod.ai/v2/{ENDPOINT_ID}/runsync"
url = f"https://api.runpod.ai/v2/{RUNPOD_ENDPOINT_ID}/runsync"

# Authorization token
token = os.getenv("RUNPOD_API_KEY")

# Split text into chunks of 100 words each
text = translation_request.text
text_chunks = chunk_text(text, chunk_size=100)
print(f"text_chunks length: {len(text_chunks)}")
logging.info(f"text_chunks length: {len(text_chunks)}")

# Translated chunks will be stored here
translated_text_chunks = []
@@ -165,6 +172,7 @@
# Data to be sent in the request body
data = {
"input": {
"task": "translate",
"source_language": translation_request.source_language,
"target_language": translation_request.target_language,
"text": chunk.strip(), # Remove leading/trailing spaces
@@ -175,18 +183,20 @@
headers = {"Authorization": token, "Content-Type": "application/json"}

response = requests.post(url, headers=headers, json=data)
logging.info(f"response: {response.json()}")

if response.status_code == 200:
translated_chunk = response.json()["output"]["data"]["translated_text"]
translated_chunk = response.json()["output"]["translated_text"]
translated_text_chunks.append(translated_chunk)
else:
raise HTTPException(status_code=response.status_code, detail=response.text)

logging.info(f"translated_text_chunks: {translated_text_chunks}")
# Concatenate translated chunks
final_translated_text = " ".join(translated_text_chunks)
response = response.json()
response["output"]["data"]["text"] = text
response["output"]["data"]["translated_text"] = final_translated_text
response["output"]["text"] = text
response["output"]["translated_text"] = final_translated_text

return response

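For reference, the updated translation flow sends the RunPod serverless endpoint a payload that carries an explicit "task" field and reads "translated_text" directly from "output" (the nested "data" object is gone). Below is a minimal client-side sketch of that request shape, assuming the environment variables used in this diff are set; the sample text and language codes are placeholders for illustration, not values taken from this commit.

import os

import requests

# Assumed configuration; substitute real values for your deployment.
RUNPOD_ENDPOINT_ID = os.getenv("RUNPOD_ENDPOINT_ID", "<endpoint-id>")
RUNPOD_API_KEY = os.getenv("RUNPOD_API_KEY", "<api-key>")

url = f"https://api.runpod.ai/v2/{RUNPOD_ENDPOINT_ID}/runsync"
headers = {"Authorization": RUNPOD_API_KEY, "Content-Type": "application/json"}

# Request body mirroring the updated nllb_translate route: a "task"
# discriminator plus source/target languages and one text chunk.
data = {
    "input": {
        "task": "translate",
        "source_language": "eng",  # placeholder language code
        "target_language": "lug",  # placeholder language code
        "text": "Hello, how are you?",
    }
}

response = requests.post(url, headers=headers, json=data)
response.raise_for_status()

# With this change the translation sits directly under "output".
print(response.json()["output"]["translated_text"])
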
6 changes: 1 addition & 5 deletions app/schemas/tasks.py
@@ -13,15 +13,11 @@ class NllbResponseOutputData(BaseModel):
translated_text: str


class NllbOutput(BaseModel):
data: NllbResponseOutputData


class NllbTranslationResponse(BaseModel):
delayTime: int
executionTime: int
id: str
output: NllbOutput
output: NllbResponseOutputData
status: str


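The schema change above flattens the response model: NllbTranslationResponse now exposes NllbResponseOutputData directly as output instead of wrapping it in NllbOutput.data. A small sketch of how the flattened model parses a response follows; the "text" field on NllbResponseOutputData and all sample values are assumptions for illustration, since the full class definition is collapsed in this diff.

from pydantic import BaseModel


class NllbResponseOutputData(BaseModel):
    text: str  # assumed field; the route writes response["output"]["text"]
    translated_text: str


class NllbTranslationResponse(BaseModel):
    delayTime: int
    executionTime: int
    id: str
    output: NllbResponseOutputData  # previously NllbOutput wrapping .data
    status: str


# Illustrative payload only; values are made up for the example.
sample = {
    "delayTime": 120,
    "executionTime": 450,
    "id": "sync-abc123",
    "output": {"text": "Hello", "translated_text": "Mwasuze mutya"},
    "status": "COMPLETED",
}

parsed = NllbTranslationResponse(**sample)
print(parsed.output.translated_text)
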
6 changes: 3 additions & 3 deletions coverage.svg
