Skip to content

Commit

Permalink
Merge pull request #15 from SunbirdAI/asr-stt-inference
Browse files: browse the repository at this point in the history
Integrate ASR/STT inference
  • Loading branch information
PatrickCmd authored Apr 21, 2024
2 parents c678384 + 95f155c commit 1f9747e
Show file tree
Hide file tree
Showing 27 changed files with 288 additions and 193 deletions.
5 changes: 5 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[flake8]
ignore = D203, E402, F403, F405, W503, W605
exclude = .git,env,__pycache__,docs/source/conf.py,old,build,dist, *migrations*,env,venv,alembic
max-complexity = 10
max-line-length = 119
7 changes: 7 additions & 0 deletions .isort.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[settings]
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=88
skip=env,migrations,alembic,venv
12 changes: 12 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
.PHONY: lint-apply lint-check

lint-check:
@echo "Checking for lint errors..."
flake8 .
black --check .
isort --check-only .

lint-apply:
@echo "apply linting ..."
black .
isort .
35 changes: 19 additions & 16 deletions app/api.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,36 @@
from fastapi import FastAPI, Request
import os
from functools import partial
from app.routers.tasks import router as tasks_router
from app.routers.auth import router as auth_router
from app.routers.frontend import router as frontend_router
from app.middleware.monitoring_middleware import log_request
from app.docs import description, tags_metadata
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pathlib import Path

import redis.asyncio as redis
from fastapi_limiter import FastAPILimiter
from dotenv import load_dotenv
from pathlib import Path
import os
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi_limiter import FastAPILimiter

from app.docs import description, tags_metadata
from app.middleware.monitoring_middleware import log_request
from app.routers.auth import router as auth_router
from app.routers.frontend import router as frontend_router
from app.routers.tasks import router as tasks_router

load_dotenv()


app = FastAPI(
title="Sunbird AI API",
description=description,
openapi_tags=tags_metadata
title="Sunbird AI API", description=description, openapi_tags=tags_metadata
)


@app.on_event("startup")
async def startup():
redis_instance = redis.from_url(os.getenv('REDIS_URL'), encoding="utf-8", decode_responses=True)
redis_instance = redis.from_url(
os.getenv("REDIS_URL"), encoding="utf-8", decode_responses=True
)
await FastAPILimiter.init(redis_instance)


static_files_directory = Path(__file__).parent.absolute() / "static"
app.mount("/static", StaticFiles(directory=static_files_directory), name="static")

Expand All @@ -41,7 +44,7 @@ async def startup():
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"]
allow_headers=["*"],
)


Expand Down
16 changes: 12 additions & 4 deletions app/crud/monitoring.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
from contextlib import contextmanager

from sqlalchemy.orm import Session

from app.database.db import SessionLocal
from app.models import monitoring as models
from app.schemas import monitoring as schemas
from app.database.db import SessionLocal


@contextmanager
def auto_session():
sess = SessionLocal()
try:
yield sess
sess.commit()
except:
except Exception:
sess.rollback()
finally:
sess.close()


def create_endpoint_log(log: schemas.EndpointLog):
with auto_session() as sess:
db_log = models.EndpointLog(username=log.username, endpoint=log.endpoint, time_taken=log.time_taken)
db_log = models.EndpointLog(
username=log.username, endpoint=log.endpoint, time_taken=log.time_taken
)
sess.add(db_log)


def get_logs_by_username(db: Session, username: str):
return db.query(models.EndpointLog).filter(models.EndpointLog.username == username).all()
return (
db.query(models.EndpointLog)
.filter(models.EndpointLog.username == username)
.all()
)
10 changes: 8 additions & 2 deletions app/crud/users.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from sqlalchemy.orm import Session
from app.schemas import users as schema

from app.models import users as models
from app.schemas import users as schema


def create_user(db: Session, user: schema.UserInDB) -> schema.User:
db_user = models.User(email=user.email, username=user.username, organization=user.organization, hashed_password=user.hashed_password)
db_user = models.User(
email=user.email,
username=user.username,
organization=user.organization,
hashed_password=user.hashed_password,
)
db.add(db_user)
db.commit()
db.refresh(db_user)
Expand Down
7 changes: 3 additions & 4 deletions app/database/db.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import os

from sqlalchemy.orm import declarative_base
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from dotenv import load_dotenv
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

load_dotenv()

engine = create_engine(os.getenv('DATABASE_URL'), pool_size=20, max_overflow=0)
engine = create_engine(os.getenv("DATABASE_URL"), pool_size=20, max_overflow=0)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()
1 change: 1 addition & 0 deletions app/deps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from app.database.db import SessionLocal


def get_db():
db = SessionLocal()
try:
Expand Down
11 changes: 4 additions & 7 deletions app/docs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
description = """
Welcome to the Sunbird AI API documentation. The Sunbird AI API provides you access to Sunbird's language models. The currently supported models are:
Welcome to the Sunbird AI API documentation. The Sunbird AI API provides you access to Sunbird's language models. The currently supported models are: # noqa E501
- **Translation (English to Multiple)**: translate from English to Acholi, Ateso, Luganda, Lugbara and Runyankole.
- **Translation (Multiple to English)**: translate from the 5 local language above to English.
- **Speech To Text (Luganda)**: Convert Luganda speech audio to text.
Expand All @@ -25,12 +25,9 @@
"""

tags_metadata = [
{
"name": "AI Tasks",
"description": "Operations for AI inference."
},
{"name": "AI Tasks", "description": "Operations for AI inference."},
{
"name": "Authentication Endpoints",
"description": "Operations for Authentication, including Sign up and Login"
}
"description": "Operations for Authentication, including Sign up and Login",
},
]
19 changes: 6 additions & 13 deletions app/inference_services/base.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
import json
import os

import requests
from dotenv import load_dotenv
import os

load_dotenv()

url = f"https://europe-west1-aiplatform.googleapis.com/v1/projects/{os.getenv('PROJECT_ID')}/locations/europe-west1/endpoints/{os.getenv('ENDPOINT_ID')}:rawPredict"
url = f"https://europe-west1-aiplatform.googleapis.com/v1/projects/{os.getenv('PROJECT_ID')}/locations/europe-west1/endpoints/{os.getenv('ENDPOINT_ID')}:rawPredict" # noqa E501


def inference_request(payload):
token = os.popen('gcloud auth print-access-token').read().strip()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
}
response = requests.request(
"POST",
url,
headers=headers,
data=json.dumps(payload)
)
token = os.popen("gcloud auth print-access-token").read().strip()
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
response = requests.request("POST", url, headers=headers, data=json.dumps(payload))

return response
15 changes: 5 additions & 10 deletions app/inference_services/stt_inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
from io import BytesIO

from app.inference_services.base import inference_request


Expand All @@ -11,21 +12,15 @@ def create_payload(audio_file):
audio_bytes = BytesIO(contents)
encoded_audio = base64.b64encode(audio_bytes.read())

utf_audio = encoded_audio.decode('utf-8')
utf_audio = encoded_audio.decode("utf-8")

payload = {
"instances": [
{
"audio": utf_audio,
"task": "asr"
}
]
}
payload = {"instances": [{"audio": utf_audio, "task": "asr"}]}

return payload


def transcribe(audio_file):
# TODO: Handle error cases
payload = create_payload(audio_file)
response = inference_request(payload).json()
return response['transcripts'][0]
return response["transcripts"][0]
23 changes: 15 additions & 8 deletions app/inference_services/translate_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@


def get_task(target_language):
return 'translate_from_english' if target_language != 'English' else 'translate_to_english'
return (
"translate_from_english"
if target_language != "English"
else "translate_to_english"
)


def create_payload(text, source_language=None, target_language=None):
Expand All @@ -14,45 +18,48 @@ def create_payload(text, source_language=None, target_language=None):
"sentence": text,
"task": task,
"target_language": target_language,
"source_language": source_language
"source_language": source_language,
}
]
}

return payload


def create_batch_payload(request: TranslationBatchRequest):
payload = {
"instances": [
{
"sentence": request.text,
"task": get_task(request.target_language),
"target_language": request.target_language.value,
"source_language": request.source_language.value
"source_language": request.source_language.value,
}
for request in request.requests
]
}

return payload


def translate(text, source_language=None, target_language=None):
payload = create_payload(text, source_language, target_language)
response = inference_request(payload).json()
# TODO: Handle error cases i.e if there's an error from the inference server.
if target_language == 'English':
if target_language == "English":
response = response["to_english_translations"][0]
else:
response = response["from_english_translations"][0]
return response


def translate_batch(request: TranslationBatchRequest):
payload = create_batch_payload(request)
response = inference_request(payload).json()
response_list = []
if 'to_english_translations' in response:
response_list = response['to_english_translations']
if 'from_english_translations' in response:
response_list.extend(response['from_english_translations'])
if "to_english_translations" in response:
response_list = response["to_english_translations"]
if "from_english_translations" in response:
response_list.extend(response["from_english_translations"])

return response_list
23 changes: 8 additions & 15 deletions app/inference_services/tts_inference.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,18 @@
from app.inference_services.base import inference_request
from app.schemas.tasks import TTSRequest

from google.cloud import storage

import base64
import uuid
import os
import uuid

from dotenv import load_dotenv
from google.cloud import storage

from app.inference_services.base import inference_request
from app.schemas.tasks import TTSRequest

load_dotenv()


def create_payload(text):
payload = {
"instances": [
{
"sentence": text,
"task": "tts"
}
]
}
payload = {"instances": [{"sentence": text, "task": "tts"}]}

return payload

Expand All @@ -42,7 +35,7 @@ def tts(request: TTSRequest):

os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "service_account_key.json"
bucket_name = os.getenv("GOOGLE_CLOUD_BUCKET_NAME")
bucket_file = f"{str(uuid.uuid4())}.wav" # using a uuid for the audio file name
bucket_file = f"{str(uuid.uuid4())}.wav" # using a uuid for the audio file name
client = storage.Client()
bucket = client.bucket(bucket_name)
blob = bucket.blob(bucket_file)
Expand Down
Loading

0 comments on commit 1f9747e

Please sign in to comment.