diff --git a/README.md b/README.md
index 7a7734f3..917ae13f 100644
--- a/README.md
+++ b/README.md
@@ -71,6 +71,10 @@ By default, this uses the `OPENAI_API_KEY` environment variable.
 
 By default, this uses the `ANTHROPIC_API_KEY` environment variable.
 
+#### Google API (Gemini-1.5-Flash and Gemini-1.5-Pro)
+
+By default, this uses the `GOOGLE_API_KEY` environment variable.
+
 ##### Claude models via Bedrock
 
 For Claude models provided by [Amazon Bedrock](https://aws.amazon.com/bedrock/), please install these additional packages:
diff --git a/ai_scientist/generate_ideas.py b/ai_scientist/generate_ideas.py
index a8feedfe..86334755 100644
--- a/ai_scientist/generate_ideas.py
+++ b/ai_scientist/generate_ideas.py
@@ -468,6 +468,8 @@ def check_idea_novelty(
             "gpt-4o-2024-05-13",
             "deepseek-coder-v2-0724",
             "llama3.1-405b",
+            "gemini-1.5-flash",
+            "gemini-1.5-pro",
         ],
         help="Model to use for AI Scientist.",
     )
@@ -529,6 +531,11 @@ def check_idea_novelty(
             api_key=os.environ["OPENROUTER_API_KEY"],
             base_url="https://openrouter.ai/api/v1",
         )
+    elif "gemini" in args.model:
+        import google.generativeai
+
+        print(f"Using Gemini API with model {args.model}.")
+        client = google.generativeai.GenerativeModel(args.model)
     else:
         raise ValueError(f"Model {args.model} not supported.")
 
diff --git a/ai_scientist/llm.py b/ai_scientist/llm.py
index 80996d9d..cee0eade 100644
--- a/ai_scientist/llm.py
+++ b/ai_scientist/llm.py
@@ -1,5 +1,6 @@
 import backoff
 import openai
+import google.generativeai
 import json
 
 
@@ -88,6 +89,20 @@ def get_batch_responses_from_llm(
             )
             content.append(c)
             new_msg_history.append(hist)
+    elif "gemini" in model:
+        content, new_msg_history = [], []
+        for _ in range(n_responses):
+            c, hist = get_response_from_llm(
+                msg,
+                client,
+                model,
+                system_message,
+                print_debug=False,
+                msg_history=None,
+                temperature=temperature,
+            )
+            content.append(c)
+            new_msg_history.append(hist)
     else:
         # TODO: This is only supported for GPT-4 in our reviewer pipeline.
raise ValueError(f"Model {model} not supported.") @@ -199,6 +214,17 @@ def get_response_from_llm( ) content = response.choices[0].message.content new_msg_history = new_msg_history + [{"role": "assistant", "content": content}] + elif "gemini" in model: + new_msg_history = msg_history + [{"role": "user", "parts": msg}] + response = client.generate_content( + new_msg_history, + generation_config=google.generativeai.types.GenerationConfig( + max_output_tokens=3000, + temperature=temperature + ) + ) + content = response.text + new_msg_history = new_msg_history + [{"role": "model", "parts": content}] else: raise ValueError(f"Model {model} not supported.") diff --git a/ai_scientist/perform_writeup.py b/ai_scientist/perform_writeup.py index 78908032..7a77c5c3 100644 --- a/ai_scientist/perform_writeup.py +++ b/ai_scientist/perform_writeup.py @@ -528,6 +528,9 @@ def perform_writeup( "gpt-4o-2024-05-13", "deepseek-coder-v2-0724", "llama3.1-405b", + # Gemini models + "gemini-1.5-flash", + "gemini-1.5-pro", # Anthropic Claude models via Amazon Bedrock "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", @@ -587,6 +590,11 @@ def perform_writeup( api_key=os.environ["OPENROUTER_API_KEY"], base_url="https://openrouter.ai/api/v1", ) + elif "gemini" in args.model: + import google.generativeai + + print(f"Using Gemini API with model {args.model}.") + client = google.generativeai.GenerativeModel(args.model) else: raise ValueError(f"Model {args.model} not recognized.") print("Make sure you cleaned the Aider logs if re-generating the writeup!") diff --git a/experimental/launch_oe_scientist.py b/experimental/launch_oe_scientist.py index 3c048a19..fdc10ed9 100644 --- a/experimental/launch_oe_scientist.py +++ b/experimental/launch_oe_scientist.py @@ -365,6 +365,11 @@ def do_idea( api_key=os.environ["OPENROUTER_API_KEY"], base_url="https://openrouter.ai/api/v1", ) + elif "gemini" in args.model: + import google.generativeai + + print(f"Using Gemini API with model {args.model}.") + client = google.generativeai.GenerativeModel(args.model) else: raise ValueError(f"Model {args.model} not supported.") diff --git a/launch_scientist.py b/launch_scientist.py index 489c7fc8..6b3c1307 100644 --- a/launch_scientist.py +++ b/launch_scientist.py @@ -52,6 +52,9 @@ def parse_arguments(): "gpt-4o-2024-05-13", "deepseek-coder-v2-0724", "llama3.1-405b", + # Gemini Models + "gemini-1.5-flash", + "gemini-1.5-pro", # Anthropic Claude models via Amazon Bedrock "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", @@ -360,6 +363,11 @@ def do_idea( api_key=os.environ["OPENROUTER_API_KEY"], base_url="https://openrouter.ai/api/v1", ) + elif "gemini" in args.model: + import google.generativeai + + print(f"Using Gemini API with model {args.model}.") + client = google.generativeai.GenerativeModel(args.model) else: raise ValueError(f"Model {args.model} not supported.") diff --git a/requirements.txt b/requirements.txt index 77e755fe..8971848d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ anthropic aider-chat backoff openai +google-generativeai # Viz matplotlib pypdf