From 1bc195ddf9380c07930d24e375821112f9fe9103 Mon Sep 17 00:00:00 2001 From: Johnny Zheng Date: Sun, 3 Nov 2024 01:49:01 -0700 Subject: [PATCH] updated main --- FastAPI/api/main.py | 100 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 2 deletions(-) diff --git a/FastAPI/api/main.py b/FastAPI/api/main.py index 8f3ae1f4..a410546b 100644 --- a/FastAPI/api/main.py +++ b/FastAPI/api/main.py @@ -1,3 +1,4 @@ +import os from fastapi import FastAPI, Depends, HTTPException from pydantic import BaseModel from sqlalchemy.orm import Session @@ -5,12 +6,104 @@ from database import SessionLocal, engine import models from fastapi.middleware.cors import CORSMiddleware +import google.generativeai as genai +from google.api_core.exceptions import (ResourceExhausted, FailedPrecondition, + InvalidArgument, ServiceUnavailable, + InternalServerError) +# from gemini_module import call_gemini # Import your function + +genai.configure(api_key="AIzaSyCN7ZZCZbaXxErK88XbrRz2aDF1knWcXHA") # Create database tables models.Base.metadata.create_all(bind=engine) app = FastAPI() +#functions +def call_gemini(productName: str, productDescription: str, productAdditional: str = "No Additional Information to Provide"): + """After thread is created, call Gemini, get a response, and begin executing commands + + Args: + recorder: the recorder object must be created in the main process, so it is passed in to this function as an argument and used + """ + global retries, mainChat, main_model + + # if function has been recursively called 3 times (in 3 attempts to retry a prompt), break out of loop + if retries == 3: + retries = 0 + print('An unexpected error occurred on Google\'s side. Wait a bit and retry your request. If the issue persists after retrying, please report it using the Send feedback button in Google AI Studio.') + return + + model_prompt = f""" + You are trying to sell a {productName}. Given the following information, identify at LEAST three potential communities to target for advertisement: + + ### Product Name: + {productName} + + ### Product Description: + {productDescription} + + ### Additional Information (Optional) + {productAdditional} + + ### Instructions: + Step 1. Given the information above, identify at LEAST three very specific potential communities to target for advertisement. Be AS SPECIFIC AS POSSIBLE when naming communities to target. For each community, include a short justification for why this specific community. + Step 2. You are an AI capable of generating text and images. Underneath this, for EACH of your targeted communities, identify the optimal means of advertisement for each (either text or image, you may choose to do both) + Step 3. + a. If you create a textual advertisement, generate a full length text to be able to be copy pasted straight into the community as an advertisement. Omit any unnecessary information. + b. If you create a visual advertisement (image), generate a prompt to be used for an AI image generator. You may be as specific as necessary in your vision of the optimal visual advertisement. + + ### Formatting Instructions: + - Follow the below sample output as closely as possible. + - After outputting step 1, print "BREAK HERE" for us to do some preprocessing formatting. + - After the break, list out + + + #### Sample Output: + 1. Billboards in Los Angeles\n + + + 2. Reddit Community (r/furniture)\n + + + 3. Youtube Ad (Home Improvement Channels)\n + + + ... + (include as many samples as you deem necessary, but the bare minimum is 3) + + BREAK + + 1. [sample description here. Be as detailed as necessary in your prompting to the image generation model] + + 2. [full description that can be straight copy pasted into the community of choice] + + 3. [sample description here. Be as detailed as necessary in your prompting to the image generation model]\n [full description that can be straight copy pasted into the community of choice] + """ + + try: + startTime = time.time() + response = mainChat.send_message(model_prompt) + print(response.text) + endTime = time.time() + + # time logging + print(f'Gemini took approximately {abs(round((startTime - endTime), 2))} seconds to respond') + + # return repsonse from model + # return {"text" : response.text, "name" : productName, "description" : productDescription, "additional" : productAdditional} + return response.text + + except ResourceExhausted as resource_error: + print(f'You have exceeded the API call rate. Please wait a minute before trying again... \nError message from Google:\n{resource_error}') + except InternalServerError as internal_error: + retries += 1 + print(f'An expected error occured on Google\'s side. Retrying after 1 second cooldown... Attempt {retries}/3') + print(f'Error message from Google:\n{internal_error}') + except Exception as e: + print(f'Unknown error encountered. \nError message from Google:\n{e}') +# + # Allow CORS from the frontend URL app.add_middleware( CORSMiddleware, @@ -57,7 +150,7 @@ def get_output_data_by_name(name: str, db: Session) -> Optional[models.OutputDat @app.post("/submit_output/") async def submit_output(data: OutputDataRequest, db: Session = Depends(get_db)): # Check if a record with the same name already exists - existing_data = get_output_data_by_name(data.name, db) + existing_data = db.query(models.OutputData).filter(models.OutputData.name == data.name).first() if existing_data: # Update existing record @@ -68,10 +161,13 @@ async def submit_output(data: OutputDataRequest, db: Session = Depends(get_db)): db.refresh(existing_data) return {"message": "Data updated successfully", "data": existing_data} else: + # Call Gemini to get model response based on form data + gemini_response = call_gemini(data.name, data.description, data.extra or "No Additional Information to Provide") + # Insert new record if no existing record is found new_data = models.OutputData( name=data.name, - description=data.description, + description=gemini_response, # Storing the Gemini response in the description extra=data.extra, image=data.image, )