Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow custom models in headlines and descriptions #15

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions services/content_generator_service.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

En general la idea está muy bien, pero yo no escribiría todo el código en este archivo. El ContentGeneratorService no debería saber a qué endpoint le tiene que pegar, ni encargarse de temas de autenticación. Para eso están los AuthenticationHelper y GeminiHelper. Trataría de reestructurarlo de manera que cada servicio se ocupe de su propósito y nada más.

Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import random
import re
import requests

from alive_progress import alive_bar
from prompts.prompts import prompts
from services.keyword_suggestion_service import KeywordSuggestionService
Expand All @@ -32,6 +31,11 @@
from utils.enums import SecondTermSource
from utils.gemini_helper import GeminiHelper
from utils.sheet_helper import GoogleSheetsHelper
from utils.authentication_helper import Authenticator
from utils.utils import Utils

config = Utils.load_config('config.json')


# Logger config
logging.basicConfig()
Expand All @@ -50,6 +54,8 @@ def __init__(self, config: dict[str, str]):
if 'google_ads_developer_token' in self.config and 'login_customer_id' in self.config:
self.keyword_suggestion_service = KeywordSuggestionService(self.config)
self.sheets_helper = GoogleSheetsHelper(self.config)
self.authenticator=Authenticator()
self.credentials=self.authenticator.authenticate(self.config)

def generate_content(
self,
Expand Down Expand Up @@ -789,10 +795,45 @@ def __generate_copies(
# Get copy generation prompt
prompt = self.__get_copy_generation_prompt(t, entry, num_copies)

generated_copies=[]
#Logic to determine the use of custom models or endpoints
try:
has_client_creds=self.authenticator.has_been_authenticated_with_client_credentials()
is_authenticated=self.authenticator.has_been_authenticated_with_client_credentials() != "unauthenticated"
headlines_endpoint_is_configured=self.body_params["custom_models"]["headlines"]["endpoint"]
descriptions_endpoint_is_configured=self.body_params["custom_models"]["descriptions"]["endpoint"]
self.authenticator.authenticate(self.config)
creds=self.authenticator.creds
access_token=creds.token
if(has_client_creds and is_authenticated and headlines_endpoint_is_configured):
if(t=='headlines'):
logging.info("Using custom endpoint for headlines.")
generated_copies = self.gemini_helper.generate_text_list(
prompt,
self.body_params["custom_models"]["headlines"]["endpoint"],
access_token
)
if(has_client_creds and is_authenticated and descriptions_endpoint_is_configured):
logging.info("Using custom model for descriptions.")
if(t=='descriptions'):
generated_copies = self.gemini_helper.generate_text_list(
prompt,
self.body_params["custom_models"]["descriptions"]["endpoint"],
access_token
)

except Exception as e:
logging.info("Incorrect Authentication for Custom Endpoint")
logging.error(e)
generated_copies=[]



# Generate copies
generated_copies = self.gemini_helper.generate_text_list(
prompt
)
if len(generated_copies)==0 or generated_copies==['Generation failed']:
generated_copies = self.gemini_helper.generate_text_list(
prompt
)

generated_copies_with_size_enforced = []

Expand Down
32 changes: 24 additions & 8 deletions utils/authentication_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,20 @@ class Authenticator:
"""This class creates credentials to authenticate with google APIs."""

def __init__(self) -> None:
pass
self.client_id=None
self.client_secret=None
self.refresh_token=None
Comment on lines +38 to +40
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Estas 3 lineas no hacen falta

self.is_authentication_method_client_id='unauthenticated'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mejor inicializarlo así: self.is_authentication_method_client_id=False o self.is_authentication_method_client_id=None. De esta forma, la variable es de un solo tipo (bool)

self.creds=None

def has_been_authenticated_with_client_credentials(self):
"""
Can return true if authenticated with client credentials
Or false if it was done with service account
else it will be a tring 'unauthenticated'
"""
return self.is_authentication_method_client_id


def authenticate(self, config: dict[str, str]) -> object:
"""Authentication method.
Expand All @@ -46,18 +59,21 @@ def authenticate(self, config: dict[str, str]) -> object:
Returns:
object: The credentials object.
"""
client_id = config.get('client_id')
client_secret = config.get('client_secret')
refresh_token = config.get('refresh_token')
self.client_id = config.get('client_id')
self.client_secret = config.get('client_secret')
self.refresh_token = config.get('refresh_token')
Comment on lines -49 to +64
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No hace falta que sean variables de instancia (no hace falta el self.)


if not client_id or not client_secret or not refresh_token:
if not self.client_id or not self.client_secret or not self.refresh_token:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No hace falta que sean variables de instancia (no hace falta el self.)

creds, _ = default(scopes=API_SCOPES)
self.is_authentication_method_client_id=False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No hace falta

else:
creds = Credentials.from_authorized_user_info({
'client_id': client_id,
'client_secret': client_secret,
'refresh_token': refresh_token,
'client_id': self.client_id,
'client_secret': self.client_secret,
'refresh_token': self.refresh_token,
Comment on lines -57 to +73
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No hace falta que sean variables de instancia (no hace falta el self.)

})
self.is_authentication_method_client_id=True

creds.refresh(Request())
self.creds=creds
return creds
53 changes: 41 additions & 12 deletions utils/gemini_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
import os
import re
import time

import dirtyjson
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from prompts.prompts import prompts
import requests
import ast

# Logger config
logging.basicConfig()
Expand Down Expand Up @@ -106,7 +107,7 @@ def generate_dict(self, prompt: str) -> dict:
retries = retries - 1
return {'status': 'Error'}

def generate_text_list(self, prompt: str) -> list[str]:
def generate_text_list(self, prompt: str, custom_endpoint: str = "", access_token: str = None) -> list[str]:
"""Makes a request to Gemini and returns a list of strings.

Args:
Expand All @@ -115,19 +116,47 @@ def generate_text_list(self, prompt: str) -> list[str]:
Returns:
list[str]: a list of strings with the generated texts
"""

if custom_endpoint and not access_token:
raise ValueError("Access token is required when using a custom endpoint")


retries = RETRIES
while retries > 0:
try:
response = self.model.generate_content(
prompt,
generation_config=GENERATION_CONFIG,
safety_settings=SAFETY_SETTINGS
)
start_idx = response.text.index('[')
end_idx = response.text.index(']')
time.sleep(TIME_INTERVAL_BETWEEN_REQUESTS)

return ast.literal_eval(response.text[start_idx:end_idx+1])
if custom_endpoint:
# Custom endpoint request
payload = {
"contents": {
"role": "USER",
"parts": {"text": prompt}
},
"generation_config": GENERATION_CONFIG
}
headers = {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json"
}
response = requests.post(custom_endpoint, json=payload, headers=headers)


if response.status_code != 200:
raise Exception(f"Custom endpoint request failed with status code {response.status_code}")
response_data = response.json()
response_data = response_data['candidates'][0]['content']['parts'][0]['text']
start_idx = response_data.index('[')
end_idx = response_data.index(']')
return ast.literal_eval(response_data[start_idx:end_idx+1])
else:
response = self.model.generate_content(
prompt,
generation_config=GENERATION_CONFIG,
safety_settings=SAFETY_SETTINGS
)
start_idx = response.text.index('[')
end_idx = response.text.index(']')
time.sleep(TIME_INTERVAL_BETWEEN_REQUESTS)
return ast.literal_eval(response.text[start_idx:end_idx+1])
Comment on lines -121 to +159
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Todo esto está muy bien, pero lo haría un poco diferente. En una función aparte, supongamos __call_gemini, haría que se llame al endpoint personalizado o al estándar según la configuración y retorne siempre lo mismo. De esta manera, el código de esta parte queda casi igual que antes, no se repite el procesado de la respuesta de gemini y es más prolijo

except Exception as e:
if 'Quota exceeded' in str(e) or 'quota' in str(e).lower():
logging.error(
Expand Down
7 changes: 6 additions & 1 deletion utils/sheet_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import logging
import backoff

from decimal import Decimal
import gspread
from utils.authentication_helper import Authenticator

Expand Down Expand Up @@ -105,7 +105,12 @@ def write_data_to_sheet(
"""
logging.info(' Writing data to sheet: %s', sheet_name)

#Catch the case where there are decimals to write
try:
for row in data:
for i, value in enumerate(row):
if isinstance(value, Decimal):
row[i] = str(value)
spreadsheet = self.client.open_by_key(sheet_id)
worksheet = spreadsheet.worksheet(sheet_name)
worksheet.update(sheet_range, data)
Expand Down