diff --git a/rplugin/python3/copilot.py b/rplugin/python3/copilot.py index 27591a29..3b849637 100644 --- a/rplugin/python3/copilot.py +++ b/rplugin/python3/copilot.py @@ -87,19 +87,12 @@ def authenticate(self): self.token = self.session.get(url, headers=headers).json() - def ask(self, prompt: str, code: str, language: str = ""): + def ask(self, system_prompt: str, prompt: str, code: str, language: str = ""): # If expired, reauthenticate if self.token.get("expires_at") <= round(time.time()): self.authenticate() url = "https://api.githubcopilot.com/chat/completions" self.chat_history.append(typings.Message(prompt, "user")) - system_prompt = prompts.COPILOT_INSTRUCTIONS - if prompt == prompts.FIX_SHORTCUT: - system_prompt = prompts.COPILOT_FIX - elif prompt == prompts.TEST_SHORTCUT: - system_prompt = prompts.COPILOT_TESTS - elif prompt == prompts.EXPLAIN_SHORTCUT: - system_prompt = prompts.COPILOT_EXPLAIN data = utilities.generate_request( self.chat_history, code, language, system_prompt=system_prompt ) @@ -141,7 +134,7 @@ def _get_embeddings(self, inputs: list[typings.FileExtract]): if i + 18 > len(inputs): data = utilities.generate_embedding_request(inputs[i:]) else: - data = utilities.generate_embedding_request(inputs[i : i + 18]) + data = utilities.generate_embedding_request(inputs[i: i + 18]) response = self.session.post(url, headers=self._headers(), json=data).json() if "data" not in response: raise Exception(f"Error fetching embeddings: {response}")