Skip to content

Commit

Permalink
Fix everything
Browse files Browse the repository at this point in the history
  • Loading branch information
graceshawyan committed Nov 3, 2024
1 parent 8daf183 commit 7765f47
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 32 deletions.
1 change: 1 addition & 0 deletions FastAPI/api/database.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# database.py
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
Expand Down
78 changes: 78 additions & 0 deletions FastAPI/api/gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# gemini.py
import os
import time
import google.generativeai as genai
from google.api_core.exceptions import (
ResourceExhausted, InternalServerError
)

# Configure the API key from environment variable
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

# AI Model Configuration
generation_config = {
"temperature": 1.0,
"top_p": 0.95,
"top_k": 40,
"max_output_tokens": 2000,
"response_mime_type": "text/plain",
}

main_model = genai.GenerativeModel(
model_name="gemini-1.5-flash-002",
generation_config=generation_config,
system_instruction="""
You are the world's best salesperson and a business analyst. Your job is to identify potential communities to target for advertisement.
"""
)

mainChat = main_model.start_chat(history=[], enable_automatic_function_calling=False)
retries = 0

def call_gemini(productName: str, productDescription: str, productAdditional: str = "No Additional Information to Provide"):
global retries, mainChat

if retries >= 3:
retries = 0
return "An unexpected error occurred. Please try again later."

model_prompt = f"""
You are trying to sell a {productName}. Given the following information, identify at least three potential communities to target for advertisement:
### Product Name:
{productName}
### Product Description:
{productDescription}
### Additional Information (Optional):
{productAdditional}
### Instructions:
Step 1. Identify at least three very specific potential communities to target for advertisement. Be as specific as possible when naming communities to target. For each community, include a short justification for why this specific community.
Step 2. For each of your targeted communities, identify the optimal means of advertisement (either text or image, or both).
Step 3.
a. If creating a textual advertisement, generate a full-length text that can be directly used.
b. If creating a visual advertisement (image), generate a prompt for an AI image generator.
### Formatting Instructions:
- Follow the sample output as closely as possible.
- After outputting step 1, print "BREAK HERE" to separate sections.
"""

try:
startTime = time.time()
response = mainChat.send_message(model_prompt)
endTime = time.time()

print(f'Gemini responded in {round(endTime - startTime, 2)} seconds')
return response.text

except (ResourceExhausted, InternalServerError) as e:
retries += 1
print(f'Error encountered: {e}. Retrying... ({retries}/3)')
time.sleep(1)
return call_gemini(productName, productDescription, productAdditional)
except Exception as e:
print(f'Unknown error: {e}')
return "An unexpected error occurred while contacting Gemini."
42 changes: 21 additions & 21 deletions FastAPI/api/main.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
# main.py
from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sqlalchemy.orm import Session
from typing import Optional
import os

from database import SessionLocal, engine
import models
from fastapi.middleware.cors import CORSMiddleware
from gemini import call_gemini # Import the function from gemini.py

# Create database tables
models.Base.metadata.create_all(bind=engine)

app = FastAPI()

# Allow CORS from the frontend URL
# Allow CORS from all origins (adjust as needed)
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000"], # Frontend URL
allow_origins=["*"], # For development; specify origins in production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
Expand All @@ -24,8 +28,8 @@
class OutputDataRequest(BaseModel):
name: str
description: str
extra: str = None
image: str = None
extra: Optional[str] = None
image: Optional[str] = None

# Dependency to get a database session
def get_db():
Expand All @@ -41,44 +45,40 @@ def read_root():

# Helper function to get output data by name
def get_output_data_by_name(name: str, db: Session) -> Optional[models.OutputData]:
"""
Retrieve output data from the database by name.
Args:
name (str): The name of the output data to retrieve.
db (Session): The database session.
Returns:
Optional[models.OutputData]: The retrieved output data or None if not found.
"""
return db.query(models.OutputData).filter(models.OutputData.name == name).first()

# POST endpoint to submit or update data
@app.post("/submit_output/")
async def submit_output(data: OutputDataRequest, db: Session = Depends(get_db)):
# Call Gemini to get model response based on form data
gemini_response = call_gemini(data.name, data.description, data.extra or "No Additional Information to Provide")

if not gemini_response:
raise HTTPException(status_code=500, detail="Failed to get response from Gemini.")

# Check if a record with the same name already exists
existing_data = get_output_data_by_name(data.name, db)

if existing_data:
# Update existing record
existing_data.description = data.description
existing_data.description = gemini_response # Store Gemini response
existing_data.extra = data.extra
existing_data.image = data.image
db.commit()
db.refresh(existing_data)
return {"message": "Data updated successfully", "data": existing_data}
return {"message": "Data updated successfully", "data": existing_data, "gemini_response": gemini_response}
else:
# Insert new record if no existing record is found
# Insert new record
new_data = models.OutputData(
name=data.name,
description=data.description,
description=gemini_response, # Store Gemini response
extra=data.extra,
image=data.image,
)
db.add(new_data)
db.commit()
db.refresh(new_data)
return {"message": "Data stored successfully", "data": new_data}
return {"message": "Data stored successfully", "data": new_data, "gemini_response": gemini_response}

# GET endpoint to retrieve data by name
@app.get("/output_data/{name}")
Expand Down
5 changes: 3 additions & 2 deletions FastAPI/api/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# models.py
from sqlalchemy import Column, Integer, String
from database import Base

class OutputData(Base):
__tablename__ = 'output_data'

id = Column(Integer, primary_key=True, index=True)
name = Column(String, unique=True, index=True) # Keeps name unique to avoid duplicates
name = Column(String, unique=True, index=True)
description = Column(String)
extra = Column(String, nullable=True)
image = Column(String, nullable=True) # Stores image URL or base64 string
image = Column(String, nullable=True)
Binary file modified FastAPI/api/workout_app.db
Binary file not shown.
36 changes: 27 additions & 9 deletions src/components/Output.jsx
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
// Output.jsx
import React, { useEffect, useState } from 'react';
import axios from 'axios';

const Output = ({ formData }) => {
const [submittedData, setSubmittedData] = useState(null);
const [geminiResponse, setGeminiResponse] = useState(null);
const [loading, setLoading] = useState(true);
const [error, setError] = useState(null);

Expand All @@ -15,13 +17,16 @@ const Output = ({ formData }) => {
const submitAndFetchData = async () => {
try {
// Submit data to backend
await axios.post('http://127.0.0.1:8000/submit_output/', {
const submitResponse = await axios.post('http://127.0.0.1:8000/submit_output/', {
name: formData.name,
description: formData.description,
extra: formData.extra,
image: formData.image,
});

// Set geminiResponse using data from response
setGeminiResponse(submitResponse.data.gemini_response);

// Fetch data from backend after submission
const response = await axios.get(`http://127.0.0.1:8000/output_data/${formData.name}`);
setSubmittedData(response.data);
Expand All @@ -43,16 +48,29 @@ const Output = ({ formData }) => {
<div className="text-white h-screen bg-black flex flex-col items-center justify-center">
<h1 className="text-5xl font-bold mb-6">Submission Summary</h1>
<div className="text-lg max-w-[600px] p-4 bg-black/50 border-2 border-[#00df9a] rounded-md">
<p><strong>ID:</strong> {submittedData.id}</p>
<p><strong>Product Name:</strong> {submittedData.name}</p>
<p><strong>Description:</strong> {submittedData.description}</p>
{submittedData.extra && <p><strong>Additional Info:</strong> {submittedData.extra}</p>}
{submittedData.image && (
<div className="mt-4">
<img src={submittedData.image} alt="Uploaded" className="w-full max-w-[400px] rounded-md shadow-lg" />
</div>
{submittedData && (
<>
<p><strong>ID:</strong> {submittedData.id}</p>
<p><strong>Product Name:</strong> {submittedData.name}</p>
<p><strong>Description:</strong> {submittedData.description}</p>
{submittedData.extra && <p><strong>Additional Info:</strong> {submittedData.extra}</p>}
{submittedData.image && (
<div className="mt-4">
<img src={submittedData.image} alt="Uploaded" className="w-full max-w-[400px] rounded-md shadow-lg" />
</div>
)}
</>
)}
</div>
<h1 className="text-5xl font-bold mb-6">Proposed Solution</h1>
<div className="text-lg max-w-[600px] p-4 bg-black/50 border-2 border-[#00df9a] rounded-md">
{geminiResponse ? (
<p>{geminiResponse}</p>
) : (
<p>No Gemini response available.</p>
)}
</div>

<button
onClick={() => window.location.reload()}
className="bg-[#00df9a] mt-6 w-[200px] rounded-md font-medium py-3 text-black"
Expand Down

0 comments on commit 7765f47

Please sign in to comment.