From a298f33dc5d784084fd3c65af5819485305cb235 Mon Sep 17 00:00:00 2001 From: gptlang <121417512+gptlang@users.noreply.github.com> Date: Thu, 1 Feb 2024 13:44:00 +0000 Subject: [PATCH] add system prompt as parameter to ask --- rplugin/python3/copilot.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/rplugin/python3/copilot.py b/rplugin/python3/copilot.py index 27591a29..3b849637 100644 --- a/rplugin/python3/copilot.py +++ b/rplugin/python3/copilot.py @@ -87,19 +87,12 @@ def authenticate(self): self.token = self.session.get(url, headers=headers).json() - def ask(self, prompt: str, code: str, language: str = ""): + def ask(self, system_prompt: str, prompt: str, code: str, language: str = ""): # If expired, reauthenticate if self.token.get("expires_at") <= round(time.time()): self.authenticate() url = "https://api.githubcopilot.com/chat/completions" self.chat_history.append(typings.Message(prompt, "user")) - system_prompt = prompts.COPILOT_INSTRUCTIONS - if prompt == prompts.FIX_SHORTCUT: - system_prompt = prompts.COPILOT_FIX - elif prompt == prompts.TEST_SHORTCUT: - system_prompt = prompts.COPILOT_TESTS - elif prompt == prompts.EXPLAIN_SHORTCUT: - system_prompt = prompts.COPILOT_EXPLAIN data = utilities.generate_request( self.chat_history, code, language, system_prompt=system_prompt ) @@ -141,7 +134,7 @@ def _get_embeddings(self, inputs: list[typings.FileExtract]): if i + 18 > len(inputs): data = utilities.generate_embedding_request(inputs[i:]) else: - data = utilities.generate_embedding_request(inputs[i : i + 18]) + data = utilities.generate_embedding_request(inputs[i: i + 18]) response = self.session.post(url, headers=self._headers(), json=data).json() if "data" not in response: raise Exception(f"Error fetching embeddings: {response}")