Skip to content

Commit

Permalink
feat: translator default model
Browse files Browse the repository at this point in the history
  • Loading branch information
Byaidu committed Dec 6, 2024
1 parent 3ee8704 commit 5612c41
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 29 deletions.
20 changes: 11 additions & 9 deletions pdf2zh/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,21 @@ def __init__(
self.noto = noto
self.translator: BaseTranslator = None
param = service.split(":", 1)
if param[0] == "google":
service_id = param[0]
service_model = param[1] if len(param) > 1 else None
if service_id == "google":
self.translator = GoogleTranslator(service, lang_out, lang_in, None)
elif param[0] == "deepl":
elif service_id == "deepl":
self.translator = DeepLTranslator(service, lang_out, lang_in, None)
elif param[0] == "deeplx":
elif service_id == "deeplx":
self.translator = DeepLXTranslator(service, lang_out, lang_in, None)
elif param[0] == "ollama":
self.translator = OllamaTranslator(service, lang_out, lang_in, param[1])
elif param[0] == "openai":
self.translator = OpenAITranslator(service, lang_out, lang_in, param[1])
elif param[0] == "azure":
elif service_id == "ollama":
self.translator = OllamaTranslator(service, lang_out, lang_in, service_model)
elif service_id == "openai":
self.translator = OpenAITranslator(service, lang_out, lang_in, service_model)
elif service_id == "azure":
self.translator = AzureTranslator(service, lang_out, lang_in, None)
elif param[0] == "tencent":
elif service_id == "tencent":
self.translator = TencentTranslator(service, lang_out, lang_in, None)
else:
raise ValueError("Unsupported translation service")
Expand Down
40 changes: 20 additions & 20 deletions pdf2zh/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ def __init__(self, service, lang_out, lang_in, model):
def translate(self, text):
    """Translate ``text``.

    This implementation performs no work and returns ``None``; presumably
    concrete translator classes override it (confirm against subclasses).
    """
    return None

def prompt(self, text):
    """Build the chat messages that ask a model to translate ``text``.

    Returns a two-element list of ``{"role": ..., "content": ...}`` dicts:
    a fixed system message, followed by a user message that embeds
    ``self.lang_out`` as the target language and ``text`` as the source.
    """
    system_message = {
        "role": "system",
        "content": "You are a professional,authentic machine translation engine.",
    }
    # The instruction tells the model to leave the $v*$ formula notation untouched.
    user_content = f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:"  # noqa: E501
    user_message = {"role": "user", "content": user_content}
    return [system_message, user_message]

def __str__(self):
    """Return the service, output language and input language, space-separated."""
    fields = (self.service, self.lang_out, self.lang_in)
    return "{} {} {}".format(*fields)

Expand Down Expand Up @@ -140,11 +152,14 @@ class OllamaTranslator(BaseTranslator):
# https://github.com/ollama/ollama-python
envs = {
"OLLAMA_HOST": "http://127.0.0.1:11434",
"OLLAMA_MODEL": "gemma2",
}

def __init__(self, service, lang_out, lang_in, model):
    """Set up an Ollama-backed translator.

    An ``"auto"`` output language becomes ``"zh-CN"`` and an ``"auto"``
    input language becomes ``"en"``. When ``model`` is falsy, the model
    name is read from the ``OLLAMA_MODEL`` environment variable, falling
    back to ``self.envs["OLLAMA_MODEL"]``.
    """
    if lang_out == "auto":
        lang_out = "zh-CN"
    if lang_in == "auto":
        lang_in = "en"
    model = model or os.getenv("OLLAMA_MODEL", self.envs["OLLAMA_MODEL"])
    super().__init__(service, lang_out, lang_in, model)
    # Random sampling may break the formula markers, so temperature is pinned to 0.
    self.options = {"temperature": 0}
    self.client = ollama.Client()
Expand All @@ -153,16 +168,7 @@ def translate(self, text):
response = self.client.chat(
model=self.model,
options=self.options,
messages=[
{
"role": "system",
"content": "You are a professional,authentic machine translation engine.",
},
{
"role": "user",
"content": f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:", # noqa: E501
},
],
messages=self.prompt(text),
)
return response["message"]["content"].strip()

Expand All @@ -172,11 +178,14 @@ class OpenAITranslator(BaseTranslator):
envs = {
"OPENAI_BASE_URL": "https://api.openai.com/v1",
"OPENAI_API_KEY": None,
"OPENAI_MODEL": "gpt-4o",
}

def __init__(self, service, lang_out, lang_in, model):
    """Set up an OpenAI-backed translator.

    An ``"auto"`` output language becomes ``"zh-CN"`` and an ``"auto"``
    input language becomes ``"en"``. When ``model`` is falsy, the model
    name is read from the ``OPENAI_MODEL`` environment variable, falling
    back to ``self.envs["OPENAI_MODEL"]``.
    """
    if lang_out == "auto":
        lang_out = "zh-CN"
    if lang_in == "auto":
        lang_in = "en"
    model = model or os.getenv("OPENAI_MODEL", self.envs["OPENAI_MODEL"])
    super().__init__(service, lang_out, lang_in, model)
    # Random sampling may break the formula markers, so temperature is pinned to 0.
    self.options = {"temperature": 0}
    self.client = openai.OpenAI()
Expand All @@ -185,16 +194,7 @@ def translate(self, text) -> str:
response = self.client.chat.completions.create(
model=self.model,
**self.options,
messages=[
{
"role": "system",
"content": "You are a professional,authentic machine translation engine.",
},
{
"role": "user",
"content": f"Translate the following markdown source text to {self.lang_out}. Keep the formula notation $v*$ unchanged. Output translation directly without any additional text.\nSource Text: {text}\nTranslated Text:", # noqa: E501
},
],
messages=self.prompt(text),
)
return response.choices[0].message.content.strip()

Expand Down

0 comments on commit 5612c41

Please sign in to comment.