Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Gemini models #75

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ By default, this uses the `OPENAI_API_KEY` environment variable.

By default, this uses the `ANTHROPIC_API_KEY` environment variable.

#### Google API (Gemini-1.5-Flash and Gemini-1.5-Pro)

By default, this uses the `GOOGLE_API_KEY` environment variable.

##### Claude models via Bedrock

For Claude models provided by [Amazon Bedrock](https://aws.amazon.com/bedrock/), please install these additional packages:
Expand Down
7 changes: 7 additions & 0 deletions ai_scientist/generate_ideas.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ def check_idea_novelty(
"gpt-4o-2024-05-13",
"deepseek-coder-v2-0724",
"llama3.1-405b",
"gemini-1.5-flash",
"gemini-1.5-pro"
],
help="Model to use for AI Scientist.",
)
Expand Down Expand Up @@ -529,6 +531,11 @@ def check_idea_novelty(
api_key=os.environ["OPENROUTER_API_KEY"],
base_url="https://openrouter.ai/api/v1",
)
elif "gemini" in args.model:
import google.generativeai

print(f"Using Gemini API with model {args.model}.")
client = google.generativeai.GenerativeModel(args.model)
else:
raise ValueError(f"Model {args.model} not supported.")

Expand Down
26 changes: 26 additions & 0 deletions ai_scientist/llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import backoff
import openai
import google.generativeai
import json


Expand Down Expand Up @@ -88,6 +89,20 @@ def get_batch_responses_from_llm(
)
content.append(c)
new_msg_history.append(hist)
elif "gemini" in model:
content, new_msg_history = [], []
for _ in range(n_responses):
c, hist = get_response_from_llm(
msg,
client,
model,
system_message,
print_debug=False,
msg_history=None,
temperature=temperature,
)
content.append(c)
new_msg_history.append(hist)
else:
# TODO: This is only supported for GPT-4 in our reviewer pipeline.
raise ValueError(f"Model {model} not supported.")
Expand Down Expand Up @@ -199,6 +214,17 @@ def get_response_from_llm(
)
content = response.choices[0].message.content
new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
elif "gemini" in model:
new_msg_history = msg_history + [{"role": "user", "parts": msg}]
response = client.generate_content(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No system message!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, looks like Gemini doesn't have support for system messages as of now. This is what I found online:
Discussion1, Discussion2

new_msg_history,
generation_config=google.generativeai.types.GenerationConfig(
max_output_tokens=3000,
temperature=temperature
)
)
content = response.text
new_msg_history = new_msg_history + [{"role": "model", "parts": content}]
else:
raise ValueError(f"Model {model} not supported.")

Expand Down
8 changes: 8 additions & 0 deletions ai_scientist/perform_writeup.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,9 @@ def perform_writeup(
"gpt-4o-2024-05-13",
"deepseek-coder-v2-0724",
"llama3.1-405b",
# Gemini models
"gemini-1.5-flash",
"gemini-1.5-pro",
# Anthropic Claude models via Amazon Bedrock
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
Expand Down Expand Up @@ -587,6 +590,11 @@ def perform_writeup(
api_key=os.environ["OPENROUTER_API_KEY"],
base_url="https://openrouter.ai/api/v1",
)
elif "gemini" in args.model:
import google.generativeai

print(f"Using Gemini API with model {args.model}.")
client = google.generativeai.GenerativeModel(args.model)
else:
raise ValueError(f"Model {args.model} not recognized.")
print("Make sure you cleaned the Aider logs if re-generating the writeup!")
Expand Down
5 changes: 5 additions & 0 deletions experimental/launch_oe_scientist.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,11 @@ def do_idea(
api_key=os.environ["OPENROUTER_API_KEY"],
base_url="https://openrouter.ai/api/v1",
)
elif "gemini" in args.model:
import google.generativeai

print(f"Using Gemini API with model {args.model}.")
client = google.generativeai.GenerativeModel(args.model)
else:
raise ValueError(f"Model {args.model} not supported.")

Expand Down
8 changes: 8 additions & 0 deletions launch_scientist.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def parse_arguments():
"gpt-4o-2024-05-13",
"deepseek-coder-v2-0724",
"llama3.1-405b",
# Gemini Models
"gemini-1.5-flash",
"gemini-1.5-pro",
# Anthropic Claude models via Amazon Bedrock
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
Expand Down Expand Up @@ -360,6 +363,11 @@ def do_idea(
api_key=os.environ["OPENROUTER_API_KEY"],
base_url="https://openrouter.ai/api/v1",
)
elif "gemini" in args.model:
import google.generativeai

print(f"Using Gemini API with model {args.model}.")
client = google.generativeai.GenerativeModel(args.model)
else:
raise ValueError(f"Model {args.model} not supported.")

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ anthropic
aider-chat
backoff
openai
google-generativeai
# Viz
matplotlib
pypdf
Expand Down