Skip to content

Commit

Permalink
Add support for customizing OpenAI endpoint (#34)
Browse files Browse the repository at this point in the history
* Bump `orch` to `0.0.16`

* Ignore `.DS_Store` files

* Add ability to customize OpenAI base endpoint

* Fix subcommand search tests

* Add test
  • Loading branch information
guywaldman authored Jul 27, 2024
1 parent 4f34754 commit 9ab64b4
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jobs:
- name: Run E2E tests
run: ./scripts/e2e-tests.sh
env:
OPENAI_API_KEY_E2E: ${{ secrets.OPENAI_API_KEY_E2E }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY_E2E }}

linting:
name: Linting and formatting
Expand Down
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,7 @@ __pycache__
.pytest_cache

.env*.local
temp
temp

# macOS
.DS_Store
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ repository = "https://github.com/guywaldman/magic-cli"
edition = "2021"

[dependencies]
orch = { version = "0.0.15" }
orch_response = { version = "0.0.15" } # Will be bundled inside `orch` (#10)
orch = { version = "0.0.16" }
orch_response = { version = "0.0.16" } # Will be bundled inside `orch` (#10)
chrono = "0.4.38"
clap = { version = "4.5.7", features = ["derive"] }
clipboard = "0.5.0"
Expand Down
14 changes: 14 additions & 0 deletions src/cli/config/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,20 @@ impl ConfigKeys {
}),
).secret()
);
keys.insert(
"openai.api_endpoint".to_string(),
ConfigurationKey::new(
"openai.api_endpoint".to_string(),
"Custom API endpoint for the OpenAI API.".to_string(),
Box::new(|config: &mut MagicCliConfig, value: &str| {
if config.openai_config.is_none() {
config.openai_config = Some(OpenAiConfig::default());
}
config.openai_config.as_mut().unwrap().api_endpoint = Some(value.to_string());
Ok(())
}),
).secret()
);
keys.insert(
"openai.model".to_string(),
ConfigurationKey::new(
Expand Down
9 changes: 7 additions & 2 deletions src/cli/config/mcli_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,19 @@ impl MagicCliConfigManager {
let Some(api_key) = openai_config.api_key.clone() else {
return Err(MagicCliConfigError::MissingConfigKey("api_key".to_owned()));
};
let openai = OpenAiBuilder::new()
let mut openai_builder = OpenAiBuilder::new()
.with_model(model)
.with_embeddings_model(embedding_model)
.with_api_key(api_key)
.with_api_key(api_key);
if let Some(api_endpoint) = openai_config.api_endpoint.clone() {
openai_builder = openai_builder.with_api_endpoint(api_endpoint);
}
let openai = openai_builder
.try_build()
.map_err(|e| MagicCliConfigError::Configuration(e.to_string()))?;
Ok(Box::new(openai))
}
_ => Err(MagicCliConfigError::Configuration("Invalid LLM provider".to_string())),
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/cli/subcommand_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl MagicCliSubcommand for SearchSubcommand {
};
if !allow_remote_llm {
return Err(Box::new(MagicCliConfigError::Configuration(
"Using remote LLM but `allow_remote_llm` is set to false. Set it to `true` if you are willing for remote LLM providers such as OpenAI to embed your shell history which may contains sensitive information.".to_string(),
"Using remote LLM but `search.allow_remote_llm` is set to false. Set it to `true` if you are willing for remote LLM providers such as OpenAI to embed your shell history which may contains sensitive information.".to_string(),
)));
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/lm/openai_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::cli::config::{ConfigOptions, MagicCliConfigError};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiConfig {
pub api_key: Option<String>,
pub api_endpoint: Option<String>,
pub model: Option<String>,
pub embedding_model: Option<String>,
}
Expand All @@ -14,6 +15,7 @@ impl Default for OpenAiConfig {
fn default() -> Self {
Self {
api_key: None,
api_endpoint: None,
model: Some(openai_model::GPT_4O_MINI.to_string()),
embedding_model: Some(openai_embedding_model::TEXT_EMBEDDING_ADA_002.to_string()),
}
Expand All @@ -29,6 +31,10 @@ impl ConfigOptions for OpenAiConfig {
populated = true;
self.api_key = defaults.api_key;
}
if self.api_endpoint.is_none() {
populated = true;
self.api_endpoint = defaults.api_endpoint;
}
if self.model.is_none() {
populated = true;
self.model = defaults.model;
Expand Down
18 changes: 18 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,21 @@ def test_custom_config_set(self):
results = run_subcommand("config", ["get", "general.llm"], mcli_args=["--config", f.name])
assert results.status == 0
assert "openai" in results.stdout

def test_config_openai_api_endpoint_misconfigured(self):
"""Tests that the config subcommand fails if the user tries to set the OpenAI API endpoint but the configuration is not set correctly."""
with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp:
with open(tmp.name, "w") as f:
f.write(json.dumps({"general": {"llm": "openai"}}))

results = run_subcommand(
"config", ["set", "--key", "openai.api_endpoint", "--value", "http://example.org"], mcli_args=["--config", f.name]
)
assert results.status == 0
results = run_subcommand(
"config", ["set", "--key", "openai.api_key", "--value", "non-existent-api-key"], mcli_args=["--config", f.name]
)
assert results.status == 0

results = run_subcommand("suggest", ["'Print the current directory using `ls`. Use only `ls`'"], mcli_args=["--config", f.name])
assert results.status != 0
25 changes: 21 additions & 4 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,21 @@ class TestSearchSubcommand:
"""Tests for the `search` subcommand, which allows users to semantically search for commands across their shell history."""

def test_basic_search_openai(self):
"""Tests basic search with OpenAI."""
env = Env.from_env()

with tempfile.TemporaryDirectory() as index_dir, tempfile.NamedTemporaryFile(
mode="w", delete=False
) as config, tempfile.NamedTemporaryFile(mode="w", delete=False) as shell_history:
with open(config.name, "w") as f:
f.write(json.dumps({"general": {"llm": "openai"}, "search": {"shell_history": shell_history.name, "index_dir": index_dir}}))
f.write(
json.dumps(
{
"general": {"llm": "openai"},
"search": {"shell_history": shell_history.name, "index_dir": index_dir, "allow_remote_llm": True},
}
)
)

with open(shell_history.name, "w") as f:
f.write("echo foobar\n")
Expand All @@ -28,14 +36,22 @@ def test_basic_search_openai(self):
assert results.status == 0
assert "echo foobar" in results.stdout

def test_search_with_remote_llm(self):
def test_search_with_remote_llm_misconfigured(self):
"""Tests that the search subcommand fails if the user tries to use a remote LLM but the configuration is not set correctly."""
env = Env.from_env()

with tempfile.TemporaryDirectory() as index_dir, tempfile.NamedTemporaryFile(
mode="w", delete=False
) as config, tempfile.NamedTemporaryFile(mode="w", delete=False) as shell_history:
with open(config.name, "w") as f:
f.write(json.dumps({"general": {"llm": "openai"}, "search": {"shell_history": shell_history.name, "index_dir": index_dir}}))
f.write(
json.dumps(
{
"general": {"llm": "openai"},
"search": {"shell_history": shell_history.name, "index_dir": index_dir},
}
)
)

with open(shell_history.name, "w") as f:
f.write("echo foobar\n")
Expand All @@ -46,4 +62,5 @@ def test_search_with_remote_llm(self):
set_config("openai.api_key", env.openai_api_key, config_path=config.name)

results = run_subcommand("search", ["foobar", "--output-only"], mcli_args=["--config", config.name])
print(results)
assert results.status != 0
assert "search.allow_remote_llm" in results.stderr
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def from_env(cls):
print(f"Loading environment variables from '{env_path}'")
dotenv.load_dotenv(dotenv_path=env_path)

openai_api_key = os.environ.get("OPENAI_API_KEY_E2E")
openai_api_key = os.environ.get("OPENAI_API_KEY")
if openai_api_key is None:
raise Exception("OPENAI_API_KEY_E2E environment variable not set")
raise Exception("OPENAI_API_KEY environment variable not set")

return cls(openai_api_key=openai_api_key)

0 comments on commit 9ab64b4

Please sign in to comment.