diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..1a813c1 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] + ignore = D203, E402, F403, F405, W503, W605 + exclude = .git,env,__pycache__,docs/source/conf.py,old,build,dist, *migrations*,env,venv,alembic + max-complexity = 10 + max-line-length = 119 \ No newline at end of file diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000..742448f --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,7 @@ +[settings] +multi_line_output=3 +include_trailing_comma=True +force_grid_wrap=0 +use_parentheses=True +line_length=88 +skip=env,migrations,alembic,venv \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..1bd5d5f --- /dev/null +++ b/Makefile @@ -0,0 +1,12 @@ +.PHONY: lint-apply lint-check + +lint-check: + @echo "Checking for lint errors..." + flake8 . + black --check . + isort --check-only . + +lint-apply: + @echo "apply linting ..." + black . + isort . diff --git a/app/api.py b/app/api.py index 99de695..ea6314d 100644 --- a/app/api.py +++ b/app/api.py @@ -1,33 +1,36 @@ -from fastapi import FastAPI, Request +import os from functools import partial -from app.routers.tasks import router as tasks_router -from app.routers.auth import router as auth_router -from app.routers.frontend import router as frontend_router -from app.middleware.monitoring_middleware import log_request -from app.docs import description, tags_metadata -from fastapi.middleware.cors import CORSMiddleware -from fastapi.staticfiles import StaticFiles +from pathlib import Path import redis.asyncio as redis -from fastapi_limiter import FastAPILimiter from dotenv import load_dotenv -from pathlib import Path -import os +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles +from fastapi_limiter import FastAPILimiter + +from app.docs import description, tags_metadata +from app.middleware.monitoring_middleware import log_request +from app.routers.auth import router as auth_router +from app.routers.frontend import router as frontend_router +from app.routers.tasks import router as tasks_router load_dotenv() app = FastAPI( - title="Sunbird AI API", - description=description, - openapi_tags=tags_metadata + title="Sunbird AI API", description=description, openapi_tags=tags_metadata ) + @app.on_event("startup") async def startup(): - redis_instance = redis.from_url(os.getenv('REDIS_URL'), encoding="utf-8", decode_responses=True) + redis_instance = redis.from_url( + os.getenv("REDIS_URL"), encoding="utf-8", decode_responses=True + ) await FastAPILimiter.init(redis_instance) + static_files_directory = Path(__file__).parent.absolute() / "static" app.mount("/static", StaticFiles(directory=static_files_directory), name="static") @@ -41,7 +44,7 @@ async def startup(): allow_origins=origins, allow_credentials=True, allow_methods=["*"], - allow_headers=["*"] + allow_headers=["*"], ) diff --git a/app/crud/monitoring.py b/app/crud/monitoring.py index d272d7e..7d53de6 100644 --- a/app/crud/monitoring.py +++ b/app/crud/monitoring.py @@ -1,9 +1,11 @@ from contextlib import contextmanager from sqlalchemy.orm import Session + +from app.database.db import SessionLocal from app.models import monitoring as models from app.schemas import monitoring as schemas -from app.database.db import SessionLocal + @contextmanager def auto_session(): @@ -11,7 +13,7 @@ def auto_session(): try: yield sess sess.commit() - except: + except Exception: sess.rollback() finally: sess.close() @@ -19,9 +21,15 @@ def auto_session(): def create_endpoint_log(log: schemas.EndpointLog): with auto_session() as sess: - db_log = models.EndpointLog(username=log.username, endpoint=log.endpoint, time_taken=log.time_taken) + db_log = models.EndpointLog( + username=log.username, endpoint=log.endpoint, time_taken=log.time_taken + ) sess.add(db_log) def get_logs_by_username(db: Session, username: str): - return db.query(models.EndpointLog).filter(models.EndpointLog.username == username).all() + return ( + db.query(models.EndpointLog) + .filter(models.EndpointLog.username == username) + .all() + ) diff --git a/app/crud/users.py b/app/crud/users.py index a396064..6a00bf6 100644 --- a/app/crud/users.py +++ b/app/crud/users.py @@ -1,10 +1,16 @@ from sqlalchemy.orm import Session -from app.schemas import users as schema + from app.models import users as models +from app.schemas import users as schema def create_user(db: Session, user: schema.UserInDB) -> schema.User: - db_user = models.User(email=user.email, username=user.username, organization=user.organization, hashed_password=user.hashed_password) + db_user = models.User( + email=user.email, + username=user.username, + organization=user.organization, + hashed_password=user.hashed_password, + ) db.add(db_user) db.commit() db.refresh(db_user) diff --git a/app/database/db.py b/app/database/db.py index e5f4d4b..e9edaf4 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -1,13 +1,12 @@ import os -from sqlalchemy.orm import declarative_base -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker from dotenv import load_dotenv +from sqlalchemy import create_engine +from sqlalchemy.orm import declarative_base, sessionmaker load_dotenv() -engine = create_engine(os.getenv('DATABASE_URL'), pool_size=20, max_overflow=0) +engine = create_engine(os.getenv("DATABASE_URL"), pool_size=20, max_overflow=0) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() diff --git a/app/deps.py b/app/deps.py index c79a895..577aa0c 100644 --- a/app/deps.py +++ b/app/deps.py @@ -1,5 +1,6 @@ from app.database.db import SessionLocal + def get_db(): db = SessionLocal() try: diff --git a/app/docs.py b/app/docs.py index d2a4f16..d5abf7c 100644 --- a/app/docs.py +++ b/app/docs.py @@ -1,5 +1,5 @@ description = """ -Welcome to the Sunbird AI API documentation. The Sunbird AI API provides you access to Sunbird's language models. The currently supported models are: +Welcome to the Sunbird AI API documentation. The Sunbird AI API provides you access to Sunbird's language models. The currently supported models are: # noqa E501 - **Translation (English to Multiple)**: translate from English to Acholi, Ateso, Luganda, Lugbara and Runyankole. - **Translation (Multiple to English)**: translate from the 5 local language above to English. - **Speech To Text (Luganda)**: Convert Luganda speech audio to text. @@ -25,12 +25,9 @@ """ tags_metadata = [ - { - "name": "AI Tasks", - "description": "Operations for AI inference." - }, + {"name": "AI Tasks", "description": "Operations for AI inference."}, { "name": "Authentication Endpoints", - "description": "Operations for Authentication, including Sign up and Login" - } + "description": "Operations for Authentication, including Sign up and Login", + }, ] diff --git a/app/inference_services/base.py b/app/inference_services/base.py index c531a75..c82bf9d 100644 --- a/app/inference_services/base.py +++ b/app/inference_services/base.py @@ -1,24 +1,17 @@ import json +import os + import requests from dotenv import load_dotenv -import os load_dotenv() -url = f"https://europe-west1-aiplatform.googleapis.com/v1/projects/{os.getenv('PROJECT_ID')}/locations/europe-west1/endpoints/{os.getenv('ENDPOINT_ID')}:rawPredict" +url = f"https://europe-west1-aiplatform.googleapis.com/v1/projects/{os.getenv('PROJECT_ID')}/locations/europe-west1/endpoints/{os.getenv('ENDPOINT_ID')}:rawPredict" # noqa E501 def inference_request(payload): - token = os.popen('gcloud auth print-access-token').read().strip() - headers = { - "Authorization": f"Bearer {token}", - "Content-Type": "application/json" - } - response = requests.request( - "POST", - url, - headers=headers, - data=json.dumps(payload) - ) + token = os.popen("gcloud auth print-access-token").read().strip() + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + response = requests.request("POST", url, headers=headers, data=json.dumps(payload)) return response diff --git a/app/inference_services/stt_inference.py b/app/inference_services/stt_inference.py index 2389258..bae41f4 100644 --- a/app/inference_services/stt_inference.py +++ b/app/inference_services/stt_inference.py @@ -1,5 +1,6 @@ import base64 from io import BytesIO + from app.inference_services.base import inference_request @@ -11,21 +12,15 @@ def create_payload(audio_file): audio_bytes = BytesIO(contents) encoded_audio = base64.b64encode(audio_bytes.read()) - utf_audio = encoded_audio.decode('utf-8') + utf_audio = encoded_audio.decode("utf-8") - payload = { - "instances": [ - { - "audio": utf_audio, - "task": "asr" - } - ] - } + payload = {"instances": [{"audio": utf_audio, "task": "asr"}]} return payload + def transcribe(audio_file): # TODO: Handle error cases payload = create_payload(audio_file) response = inference_request(payload).json() - return response['transcripts'][0] + return response["transcripts"][0] diff --git a/app/inference_services/translate_inference.py b/app/inference_services/translate_inference.py index ebc065b..544e964 100644 --- a/app/inference_services/translate_inference.py +++ b/app/inference_services/translate_inference.py @@ -3,7 +3,11 @@ def get_task(target_language): - return 'translate_from_english' if target_language != 'English' else 'translate_to_english' + return ( + "translate_from_english" + if target_language != "English" + else "translate_to_english" + ) def create_payload(text, source_language=None, target_language=None): @@ -14,13 +18,14 @@ def create_payload(text, source_language=None, target_language=None): "sentence": text, "task": task, "target_language": target_language, - "source_language": source_language + "source_language": source_language, } ] } return payload + def create_batch_payload(request: TranslationBatchRequest): payload = { "instances": [ @@ -28,7 +33,7 @@ def create_batch_payload(request: TranslationBatchRequest): "sentence": request.text, "task": get_task(request.target_language), "target_language": request.target_language.value, - "source_language": request.source_language.value + "source_language": request.source_language.value, } for request in request.requests ] @@ -36,23 +41,25 @@ def create_batch_payload(request: TranslationBatchRequest): return payload + def translate(text, source_language=None, target_language=None): payload = create_payload(text, source_language, target_language) response = inference_request(payload).json() # TODO: Handle error cases i.e if there's an error from the inference server. - if target_language == 'English': + if target_language == "English": response = response["to_english_translations"][0] else: response = response["from_english_translations"][0] return response + def translate_batch(request: TranslationBatchRequest): payload = create_batch_payload(request) response = inference_request(payload).json() response_list = [] - if 'to_english_translations' in response: - response_list = response['to_english_translations'] - if 'from_english_translations' in response: - response_list.extend(response['from_english_translations']) + if "to_english_translations" in response: + response_list = response["to_english_translations"] + if "from_english_translations" in response: + response_list.extend(response["from_english_translations"]) return response_list diff --git a/app/inference_services/tts_inference.py b/app/inference_services/tts_inference.py index 4f0c85b..c810a87 100644 --- a/app/inference_services/tts_inference.py +++ b/app/inference_services/tts_inference.py @@ -1,25 +1,18 @@ -from app.inference_services.base import inference_request -from app.schemas.tasks import TTSRequest - -from google.cloud import storage - import base64 -import uuid import os +import uuid + from dotenv import load_dotenv +from google.cloud import storage + +from app.inference_services.base import inference_request +from app.schemas.tasks import TTSRequest load_dotenv() def create_payload(text): - payload = { - "instances": [ - { - "sentence": text, - "task": "tts" - } - ] - } + payload = {"instances": [{"sentence": text, "task": "tts"}]} return payload @@ -42,7 +35,7 @@ def tts(request: TTSRequest): os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "service_account_key.json" bucket_name = os.getenv("GOOGLE_CLOUD_BUCKET_NAME") - bucket_file = f"{str(uuid.uuid4())}.wav" # using a uuid for the audio file name + bucket_file = f"{str(uuid.uuid4())}.wav" # using a uuid for the audio file name client = storage.Client() bucket = client.bucket(bucket_name) blob = bucket.blob(bucket_file) diff --git a/app/middleware/monitoring_middleware.py b/app/middleware/monitoring_middleware.py index 2b7855a..aac06ac 100644 --- a/app/middleware/monitoring_middleware.py +++ b/app/middleware/monitoring_middleware.py @@ -2,20 +2,23 @@ from fastapi import Request from fastapi.exceptions import HTTPException -from app.schemas.monitoring import EndpointLog -from app.crud.monitoring import create_endpoint_log from jose import jwt -from app.utils.auth_utils import SECRET_KEY, ALGORITHM + +from app.crud.monitoring import create_endpoint_log +from app.schemas.monitoring import EndpointLog +from app.utils.auth_utils import ALGORITHM, SECRET_KEY async def log_request(request: Request, call_next): - if request.url.path.startswith('/tasks'): + if request.url.path.startswith("/tasks"): try: - header = request.headers['Authorization'] - bearer, _, token = header.partition(' ') + header = request.headers["Authorization"] + bearer, _, token = header.partition(" ") # token = request.headers['Authorization'].replace('Bearer ', '') - # TODO: Find another way of getting the current user. This is inefficient as it makes 2 similar database calls which causes problems with the DB pool size + # TODO: Find another way of getting the current user. + # This is inefficient as it makes 2 similar database calls which causes + # problems with the DB pool size # user = get_current_user(token, db_session) # TODO: This is a hacky workaround for the hackathon to prevent multiple DB calls. username = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]).get("sub") @@ -28,7 +31,7 @@ async def log_request(request: Request, call_next): username=username, endpoint=request.url.path, organization=organization, - time_taken=(end - start) + time_taken=(end - start), ) create_endpoint_log(endpoint_log) except HTTPException: diff --git a/app/models/monitoring.py b/app/models/monitoring.py index 347e019..368cafd 100644 --- a/app/models/monitoring.py +++ b/app/models/monitoring.py @@ -1,5 +1,6 @@ -from sqlalchemy import Column, String, Float, Integer, DateTime +from sqlalchemy import Column, DateTime, Float, Integer, String from sqlalchemy.sql import func + from app.database.db import Base diff --git a/app/models/users.py b/app/models/users.py index a59cb94..5e648cf 100644 --- a/app/models/users.py +++ b/app/models/users.py @@ -1,4 +1,5 @@ -from sqlalchemy import Column, Integer, String +from sqlalchemy import Column, Integer, String + from app.database.db import Base diff --git a/app/routers/auth.py b/app/routers/auth.py index 13f5dc7..5e14f55 100644 --- a/app/routers/auth.py +++ b/app/routers/auth.py @@ -1,19 +1,20 @@ from datetime import timedelta -from fastapi import APIRouter, HTTPException, status, Depends +from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from app.schemas.users import UserCreate, User, UserInDB, Token, TokenData -from app.crud.users import create_user, get_user_by_username, get_user_by_email -from app.deps import get_db +from jose import JWTError from sqlalchemy.orm import Session + +from app.crud.users import create_user, get_user_by_email, get_user_by_username +from app.deps import get_db +from app.schemas.users import Token, TokenData, User, UserCreate, UserInDB from app.utils.auth_utils import ( - authenticate_user, - get_password_hash, ACCESS_TOKEN_EXPIRE_MINUTES, + authenticate_user, create_access_token, + get_password_hash, get_username_from_token, ) -from jose import JWTError router = APIRouter() diff --git a/app/routers/frontend.py b/app/routers/frontend.py index e19c615..bcbee57 100644 --- a/app/routers/frontend.py +++ b/app/routers/frontend.py @@ -1,21 +1,22 @@ import json +from datetime import timedelta -from fastapi import APIRouter, Request, Form, Depends, responses, status +from fastapi import APIRouter, Depends, Form, Request, responses, status from fastapi.templating import Jinja2Templates +from pydantic.error_wrappers import ValidationError +from sqlalchemy.orm import Session + +from app.crud.users import create_user, get_user_by_email, get_user_by_username from app.deps import get_db +from app.schemas.users import User, UserCreate, UserInDB from app.utils.auth_utils import ( + ACCESS_TOKEN_EXPIRE_MINUTES, + OAuth2PasswordBearerWithCookie, authenticate_user, create_access_token, - ACCESS_TOKEN_EXPIRE_MINUTES, get_password_hash, get_username_from_token, - OAuth2PasswordBearerWithCookie, ) -from sqlalchemy.orm import Session -from datetime import timedelta -from app.schemas.users import UserCreate, UserInDB, User -from app.crud.users import create_user, get_user_by_username, get_user_by_email -from pydantic.error_wrappers import ValidationError from app.utils.monitoring_utils import aggregate_usage_for_user router = APIRouter() @@ -36,7 +37,7 @@ async def login(request: Request): # type: ignore @router.post("/login") -async def login( +async def login( # noqa F811 request: Request, username: str = Form(...), password: str = Form(...), @@ -70,7 +71,7 @@ async def signup(request: Request): @router.post("/register") -async def signup( +async def signup( # noqa F811 request: Request, email: str = Form(...), username: str = Form(...), diff --git a/app/routers/tasks.py b/app/routers/tasks.py index 0952dd9..d3b5b22 100644 --- a/app/routers/tasks.py +++ b/app/routers/tasks.py @@ -1,49 +1,40 @@ import os -import io import re -import requests +import shutil +import time +import requests +import runpod from dotenv import load_dotenv +from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile +from fastapi_limiter.depends import RateLimiter +from twilio.rest import Client +from werkzeug.utils import secure_filename -from fastapi import ( - APIRouter, - HTTPException, - status, - File, - UploadFile, - Form, - Depends, - Header, -) +from app.inference_services.translate_inference import translate, translate_batch +from app.inference_services.tts_inference import tts +from app.routers.auth import get_current_user from app.schemas.tasks import ( + ChatRequest, + ChatResponse, + NllbLanguage, + NllbTranslationRequest, + NllbTranslationResponse, STTTranscript, - TranslationRequest, - TranslationResponse, TranslationBatchRequest, TranslationBatchResponse, - NllbTranslationRequest, - NllbTranslationResponse, + TranslationRequest, + TranslationResponse, TTSRequest, TTSResponse, - ChatRequest, - ChatResponse, - Language, ) -from typing import Annotated - -from app.inference_services.stt_inference import transcribe -from app.inference_services.translate_inference import translate, translate_batch -from app.inference_services.tts_inference import tts -from app.routers.auth import get_current_user -from pydub import AudioSegment -from fastapi_limiter.depends import RateLimiter -from twilio.rest import Client - +from app.utils.upload_audio_file_gcp import upload_audio_file router = APIRouter() load_dotenv() PER_MINUTE_RATE_LIMIT = os.getenv("PER_MINUTE_RATE_LIMIT", 10) +runpod.api_key = os.getenv("RUNPOD_API_KEY") @router.post( @@ -51,27 +42,47 @@ ) async def speech_to_text( audio: UploadFile(...) = File(...), - language: Language = Form("Luganda"), - return_confidences: bool = Form(False), + language: NllbLanguage = Form("lug"), + adapter: NllbLanguage = Form("lug"), current_user=Depends(get_current_user), -) -> STTTranscript: # TODO: Make language an enum +) -> STTTranscript: """ - We currently only support Luganda. + Upload an audio file and get the transcription text of the audio """ - if not audio.content_type.startswith("audio"): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid file type uploaded. Please upload a valid audio file", + endpoint = runpod.Endpoint(os.getenv("RUNPOD_ENDPOINT_ASR_STT_ID")) + + filename = secure_filename(audio.filename) + file_path = os.path.join("/tmp", filename) + with open(file_path, "wb") as buffer: + shutil.copyfileobj(audio.file, buffer) + + blob_name = upload_audio_file(file_path=file_path) + audio_file = blob_name + os.remove(file_path) + + start_time = time.time() + try: + run_request = endpoint.run_sync( + { + "input": { + "target_lang": language, + "adapter": adapter, + "audio_file": audio_file, + } + }, + timeout=600, # Timeout in seconds. ) - if audio.content_type != "audio/wave": - # try to convert to wave, if it fails return an error. - buf = io.BytesIO() - audio_file = audio.file - audio = AudioSegment.from_file(audio_file) - audio = audio.export(buf, format="wav") + except TimeoutError: + print("Job timed out.") + + end_time = time.time() - response = transcribe(audio) - return STTTranscript(text=response) + # Calculate the elapsed time + elapsed_time = end_time - start_time + print("Elapsed time:", elapsed_time, "seconds") + return STTTranscript( + audio_transcription=run_request.get("audio_transcription") + ) # Route for the nllb translation endpoint @@ -84,8 +95,9 @@ async def nllb_translate( translation_request: NllbTranslationRequest, current_user=Depends(get_current_user) ): """ - Source and Target Language can be one of: ach(Acholi), teo(Ateso), eng(English), lug(Luganda), lgg(Lugbara), or nyn(Runyankole). - We currently only support English to Local languages and Local to English languages, so when the source language is one of the + Source and Target Language can be one of: ach(Acholi), teo(Ateso), eng(English), + lug(Luganda), lgg(Lugbara), or nyn(Runyankole).We currently only support English to Local + languages and Local to English languages, so when the source language is one of the languages listed, the target can be any of the other languages. """ # URL for the endpoint @@ -192,7 +204,7 @@ async def chat(chat_request: ChatRequest, current_user=Depends(get_current_user) to_number = chat_request.to_number client = Client(account_sid, auth_token) - message = client.messages.create( + _ = client.messages.create( from_=f"whatsapp:{from_number}", body=response, to=f"whatsapp:{to_number}" ) diff --git a/app/schemas/tasks.py b/app/schemas/tasks.py index b02c9a9..6c3cffa 100644 --- a/app/schemas/tasks.py +++ b/app/schemas/tasks.py @@ -1,15 +1,11 @@ +from enum import Enum from typing import List + from pydantic import BaseModel, Field -from enum import Enum class STTTranscript(BaseModel): - text: str - confidences: List[int] | None = None - - -from typing import Optional -from pydantic import BaseModel + audio_transcription: str class NllbResponseOutputData(BaseModel): diff --git a/app/tests/test_auth.py b/app/tests/test_auth.py index 8c6b412..630f343 100644 --- a/app/tests/test_auth.py +++ b/app/tests/test_auth.py @@ -2,10 +2,10 @@ import pytest from fastapi.testclient import TestClient -from app.api import app - from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker + +from app.api import app from app.database.db import Base from app.deps import get_db @@ -17,12 +17,14 @@ TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + @pytest.fixture() def test_db(): Base.metadata.create_all(bind=engine) yield Base.metadata.drop_all(bind=engine) + def override_get_db(): db = TestingSessionLocal() try: @@ -30,20 +32,22 @@ def override_get_db(): finally: db.close() + app.dependency_overrides[get_db] = override_get_db client = TestClient(app) + def test_register(test_db): user_data = { - 'username': 'test_user', - 'email': 'test_user@email.com', - 'password': 'test_password', - 'organization': 'test_organization', - 'account_type': 'Free' + "username": "test_user", + "email": "test_user@email.com", + "password": "test_password", + "organization": "test_organization", + "account_type": "Free", } - response = client.post('/auth/register', content=json.dumps(user_data)) + response = client.post("/auth/register", content=json.dumps(user_data)) assert response.status_code == 201 - expected_dict = {key: user_data[key] for key in user_data if key != 'password'} - expected_dict['id'] = 1 + expected_dict = {key: user_data[key] for key in user_data if key != "password"} + expected_dict["id"] = 1 assert response.json() == expected_dict diff --git a/app/utils/auth_utils.py b/app/utils/auth_utils.py index 3d886a2..3d42939 100644 --- a/app/utils/auth_utils.py +++ b/app/utils/auth_utils.py @@ -1,20 +1,21 @@ import os -from datetime import timedelta, datetime -from typing import Optional, Dict +from datetime import datetime, timedelta +from typing import Dict, Optional -from app.crud.users import get_user_by_username -from sqlalchemy.orm import Session from dotenv import load_dotenv -from passlib.context import CryptContext -from jose import jwt +from fastapi import HTTPException, Request, status +from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel from fastapi.security import OAuth2 from fastapi.security.utils import get_authorization_scheme_param -from fastapi import Request, HTTPException, status -from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel +from jose import jwt +from passlib.context import CryptContext +from sqlalchemy.orm import Session + +from app.crud.users import get_user_by_username load_dotenv() -SECRET_KEY = os.getenv('SECRET_KEY') +SECRET_KEY = os.getenv("SECRET_KEY") ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 52560000 @@ -57,11 +58,11 @@ def get_username_from_token(token: str) -> str: class OAuth2PasswordBearerWithCookie(OAuth2): def __init__( - self, - tokenUrl: str, - scheme_name: Optional[str] = None, - scopes: Optional[Dict[str, str]] = None, - auto_error: bool = True + self, + tokenUrl: str, + scheme_name: Optional[str] = None, + scopes: Optional[Dict[str, str]] = None, + auto_error: bool = True, ): if not scopes: scopes = {} @@ -77,7 +78,7 @@ async def __call__(self, request: Request) -> Optional[str]: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated", - headers={"WWW-Authenticate": "Bearer"} + headers={"WWW-Authenticate": "Bearer"}, ) else: return None diff --git a/app/utils/upload_audio_file_gcp.py b/app/utils/upload_audio_file_gcp.py new file mode 100644 index 0000000..f389be0 --- /dev/null +++ b/app/utils/upload_audio_file_gcp.py @@ -0,0 +1,25 @@ +import os + +from dotenv import load_dotenv +from google.cloud import storage + +load_dotenv() + + +def upload_audio_file(file_path): + try: + # Initialize a client and get the bucket + storage_client = storage.Client() + bucket_name = os.getenv("AUDIO_CONTENT_BUCKET_NAME") + bucket = storage_client.bucket(bucket_name) + + blob_name = os.path.basename(file_path) + + # Upload the file to the bucket + blob = bucket.blob(blob_name) + blob.upload_from_filename(file_path) + + return blob_name + except Exception as e: + print(f"An error occurred: {e}") + return None diff --git a/coverage.svg b/coverage.svg index 2fad913..dd6df12 100644 --- a/coverage.svg +++ b/coverage.svg @@ -15,7 +15,7 @@ coverage coverage - 62% - 62% + 61% + 61% diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..648cf4b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,19 @@ +[tool.black] +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | env + |venv + | _build + | buck-out + | build + | dist + | migrations + |alembic +)/ +''' \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..614c34d --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,3 @@ +black==22.3.0 +isort==5.12.0 +flake8==6.1.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index bbfb3df..5b0cb84 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -aiohttp==3.9.0 +aiohttp aiohttp-retry==2.8.3 aiosignal==1.3.1 alembic==1.10.4 @@ -10,7 +10,7 @@ cachetools==5.3.1 certifi==2022.12.7 cffi==1.15.1 charset-normalizer==3.1.0 -click==8.1.3 +click coverage==7.2.7 cryptography==40.0.2 dnspython==2.3.0 @@ -19,13 +19,13 @@ email-validator==2.0.0.post2 fastapi==0.94.1 fastapi-limiter==0.1.5 frozenlist==1.4.0 -google-api-core==2.11.1 +google-api-core google-api-python-client==2.95.0 google-auth==2.22.0 google-auth-httplib2==0.1.0 google-cloud==0.34.0 google-cloud-core==2.3.3 -google-cloud-storage==2.10.0 +google-cloud-storage google-crc32c==1.5.0 google-resumable-media==2.5.0 googleapis-common-protos==1.60.0 @@ -57,7 +57,8 @@ python-dotenv==1.0.0 python-jose==3.3.0 python-multipart==0.0.6 redis==4.6.0 -requests==2.28.2 +requests +runpod==1.6.2 rsa==4.9 six==1.16.0 sniffio==1.3.0 @@ -68,4 +69,5 @@ typing_extensions==4.5.0 uritemplate==4.1.1 urllib3==1.26.15 uvicorn==0.21.1 +Werkzeug yarl==1.9.3