-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow custom models in headlines and descriptions #15
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,7 +35,20 @@ class Authenticator: | |
"""This class creates credentials to authenticate with google APIs.""" | ||
|
||
def __init__(self) -> None: | ||
pass | ||
self.client_id=None | ||
self.client_secret=None | ||
self.refresh_token=None | ||
Comment on lines
+38
to
+40
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Estas 3 lineas no hacen falta |
||
self.is_authentication_method_client_id='unauthenticated' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mejor inicializarlo así: |
||
self.creds=None | ||
|
||
def has_been_authenticated_with_client_credentials(self): | ||
""" | ||
Can return true if authenticated with client credentials | ||
Or false if it was done with service account | ||
else it will be a tring 'unauthenticated' | ||
""" | ||
return self.is_authentication_method_client_id | ||
|
||
|
||
def authenticate(self, config: dict[str, str]) -> object: | ||
"""Authentication method. | ||
|
@@ -46,18 +59,21 @@ def authenticate(self, config: dict[str, str]) -> object: | |
Returns: | ||
object: The credentials object. | ||
""" | ||
client_id = config.get('client_id') | ||
client_secret = config.get('client_secret') | ||
refresh_token = config.get('refresh_token') | ||
self.client_id = config.get('client_id') | ||
self.client_secret = config.get('client_secret') | ||
self.refresh_token = config.get('refresh_token') | ||
Comment on lines
-49
to
+64
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No hace falta que sean variables de instancia (no hace falta el |
||
|
||
if not client_id or not client_secret or not refresh_token: | ||
if not self.client_id or not self.client_secret or not self.refresh_token: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No hace falta que sean variables de instancia (no hace falta el |
||
creds, _ = default(scopes=API_SCOPES) | ||
self.is_authentication_method_client_id=False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No hace falta |
||
else: | ||
creds = Credentials.from_authorized_user_info({ | ||
'client_id': client_id, | ||
'client_secret': client_secret, | ||
'refresh_token': refresh_token, | ||
'client_id': self.client_id, | ||
'client_secret': self.client_secret, | ||
'refresh_token': self.refresh_token, | ||
Comment on lines
-57
to
+73
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No hace falta que sean variables de instancia (no hace falta el self.) |
||
}) | ||
self.is_authentication_method_client_id=True | ||
|
||
creds.refresh(Request()) | ||
self.creds=creds | ||
return creds |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,11 +20,12 @@ | |
import os | ||
import re | ||
import time | ||
|
||
import dirtyjson | ||
import google.generativeai as genai | ||
from google.generativeai.types import HarmCategory, HarmBlockThreshold | ||
from prompts.prompts import prompts | ||
import requests | ||
import ast | ||
|
||
# Logger config | ||
logging.basicConfig() | ||
|
@@ -106,7 +107,7 @@ def generate_dict(self, prompt: str) -> dict: | |
retries = retries - 1 | ||
return {'status': 'Error'} | ||
|
||
def generate_text_list(self, prompt: str) -> list[str]: | ||
def generate_text_list(self, prompt: str, custom_endpoint: str = "", access_token: str = None) -> list[str]: | ||
"""Makes a request to Gemini and returns a list of strings. | ||
|
||
Args: | ||
|
@@ -115,19 +116,47 @@ def generate_text_list(self, prompt: str) -> list[str]: | |
Returns: | ||
list[str]: a list of strings with the generated texts | ||
""" | ||
|
||
if custom_endpoint and not access_token: | ||
raise ValueError("Access token is required when using a custom endpoint") | ||
|
||
|
||
retries = RETRIES | ||
while retries > 0: | ||
try: | ||
response = self.model.generate_content( | ||
prompt, | ||
generation_config=GENERATION_CONFIG, | ||
safety_settings=SAFETY_SETTINGS | ||
) | ||
start_idx = response.text.index('[') | ||
end_idx = response.text.index(']') | ||
time.sleep(TIME_INTERVAL_BETWEEN_REQUESTS) | ||
|
||
return ast.literal_eval(response.text[start_idx:end_idx+1]) | ||
if custom_endpoint: | ||
# Custom endpoint request | ||
payload = { | ||
"contents": { | ||
"role": "USER", | ||
"parts": {"text": prompt} | ||
}, | ||
"generation_config": GENERATION_CONFIG | ||
} | ||
headers = { | ||
"Authorization": f"Bearer {access_token}", | ||
"Content-Type": "application/json" | ||
} | ||
response = requests.post(custom_endpoint, json=payload, headers=headers) | ||
|
||
|
||
if response.status_code != 200: | ||
raise Exception(f"Custom endpoint request failed with status code {response.status_code}") | ||
response_data = response.json() | ||
response_data = response_data['candidates'][0]['content']['parts'][0]['text'] | ||
start_idx = response_data.index('[') | ||
end_idx = response_data.index(']') | ||
return ast.literal_eval(response_data[start_idx:end_idx+1]) | ||
else: | ||
response = self.model.generate_content( | ||
prompt, | ||
generation_config=GENERATION_CONFIG, | ||
safety_settings=SAFETY_SETTINGS | ||
) | ||
start_idx = response.text.index('[') | ||
end_idx = response.text.index(']') | ||
time.sleep(TIME_INTERVAL_BETWEEN_REQUESTS) | ||
return ast.literal_eval(response.text[start_idx:end_idx+1]) | ||
Comment on lines
-121
to
+159
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Todo esto está muy bien, pero lo haría un poco diferente. En una función aparte, supongamos |
||
except Exception as e: | ||
if 'Quota exceeded' in str(e) or 'quota' in str(e).lower(): | ||
logging.error( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
En general la idea está muy bien, pero yo no escribiría todo el código en este archivo. El ContentGeneratorService no debería saber a qué endpoint le tiene que pegar, ni encargarse de temas de autenticación. Para eso están los AuthenticationHelper y GeminiHelper. Trataría de reestructurarlo de manera que cada servicio se ocupe de su propósito y nada más.