# llms.py
import os
from typing import Literal

from dotenv import load_dotenv

# Load API keys and other configuration from the local .env file.
load_dotenv()


def get_chatbot(
    llm: Literal[
        "gemini-2.0-flash-exp",
        "gemini-1.5-pro",
        "gemini-1.5-flash",
        "gpt-4o",
        "gpt-4",
        "gpt-4o-mini",
        "gpt-3.5-turbo-16k",
        "mistral",
        "claude",
        "llama-3.1",
    ] = "claude",
    **kwargs,
):
    """Get a chatbot instance.

    Args:
        llm: The language model to use. Options are "gemini-2.0-flash-exp",
            "gemini-1.5-pro", "gemini-1.5-flash", "gpt-4o", "gpt-4",
            "gpt-4o-mini", "gpt-3.5-turbo-16k", "mistral", "claude" or
            "llama-3.1".
        **kwargs: Optional keyword arguments to pass to the chat model.

    Returns:
        A chatbot instance.
    """
    if llm in ["gpt-4o", "gpt-4", "gpt-4o-mini", "gpt-3.5-turbo-16k"]:
        from langchain_openai import ChatOpenAI

        return ChatOpenAI(model=llm, **kwargs)
    elif llm == "llama-3.1":
        from langchain_nvidia_ai_endpoints import ChatNVIDIA

        if "NVIDIA_API_KEY" not in os.environ:
            raise ValueError(
                "`NVIDIA_API_KEY` is needed to call llama-3.1 but it is not set in .env."
            )
        # Sampling parameters are fixed for this endpoint; `**kwargs` is not
        # forwarded here.
        return ChatNVIDIA(
            model="meta/llama-3.1-405b-instruct",
            api_key=os.getenv("NVIDIA_API_KEY"),
            temperature=0.2,
            top_p=0.7,
            max_tokens=4096,
        )
    elif llm in ["gemini-1.5-pro", "gemini-1.5-flash", "gemini-2.0-flash-exp"]:
        from langchain_google_vertexai import ChatVertexAI

        # Vertex AI needs both a project ID and application credentials.
        if "GOOGLE_PROJECT_ID" not in os.environ or not os.getenv(
            "GOOGLE_APPLICATION_CREDENTIALS"
        ):
            raise ValueError(
                "`GOOGLE_PROJECT_ID` and `GOOGLE_APPLICATION_CREDENTIALS` are both "
                "needed to call Gemini but are not set in .env. See "
                "https://cloud.google.com/docs/authentication/provide-credentials-adc#how-to "
                "for details."
            )
        return ChatVertexAI(
            model=llm,
            project=os.getenv("GOOGLE_PROJECT_ID"),
            **kwargs,
        )
    elif llm == "mistral":
        from langchain_community.chat_models import ChatOllama

        # Mistral runs locally through Ollama, so no API key is required.
        return ChatOllama(model="mistral", temperature=0, **kwargs)
    elif llm == "claude":
        from langchain_anthropic import ChatAnthropic

        return ChatAnthropic(
            model="claude-3-5-sonnet-20240620",
            temperature=0,
            timeout=None,
            max_retries=2,
        )
    else:
        # Defensive guard: the Literal annotation above should prevent this,
        # but fail loudly rather than silently returning None.
        raise ValueError(f"Unknown llm: {llm!r}")
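

# A minimal usage sketch, not part of the module API: it assumes the relevant
# key for the chosen backend (e.g. `ANTHROPIC_API_KEY` for the default
# "claude") is set in .env. `invoke` is the standard LangChain chat-model
# entry point and returns a message whose text lives in `.content`.
def _demo_chatbot() -> None:
    bot = get_chatbot("claude")
    reply = bot.invoke("Reply with a single word: hello.")
    print(reply.content)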


def get_max_token_length(llm: str) -> int:
    """Return the maximum number of tokens to send to ``llm`` in one prompt."""
    if llm in ("gemini-2.0-flash-exp", "gemini-1.5-flash", "gemini-1.5-pro"):
        # Gemini can handle 1 million tokens, but prompts that long take too
        # long to process, so it is not worth it.
        return 500_000
    else:
        # Default budget for all other models, including "gpt-4o".
        return 24_000
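

if __name__ == "__main__":
    # Illustrative smoke test: print the documented token budgets for a few
    # models. This makes no network calls, so it is safe to run without any
    # API keys configured.
    for name in ("gpt-4o", "gemini-1.5-pro", "claude"):
        print(name, get_max_token_length(name))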