Skip to content

Commit

Permalink
feat (gui): add custom prompt (#275)
Browse files Browse the repository at this point in the history
* gui支持自定义页码。

* 修改页数显示bug

* GUI支持自定义prompt。

* format

---------

Co-authored-by: Byaidu <[email protected]>
  • Loading branch information
hellofinch and Byaidu authored Dec 18, 2024
1 parent 8a20102 commit 10d9cf1
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
47 changes: 43 additions & 4 deletions pdf2zh/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"All": None,
"First": [0],
"First 5 pages": list(range(0, 5)),
"Others": None,
}

flag_demo = False
Expand Down Expand Up @@ -125,6 +126,9 @@ def translate_file(
lang_from,
lang_to,
page_range,
page_input,
prompt,
threads,
recaptcha_response,
state,
progress=gr.Progress(),
Expand Down Expand Up @@ -161,7 +165,16 @@ def translate_file(
file_dual = output / f"{filename}-dual.pdf"

translator = service_map[service]
selected_page = page_map[page_range]
if page_range != "Others":
selected_page = page_map[page_range]
else:
selected_page = []
for p in page_input.split(","):
if "-" in p:
start, end = p.split("-")
selected_page.extend(range(int(start) - 1, int(end)))
else:
selected_page.append(int(p) - 1)
lang_from = lang_map[lang_from]
lang_to = lang_map[lang_to]

Expand All @@ -181,10 +194,11 @@ def progress_bar(t: tqdm.tqdm):
"lang_out": lang_to,
"service": f"{translator.name}",
"output": output,
"thread": 4,
"thread": int(threads),

This comment has been minimized.

Copy link
@timelic

timelic Dec 19, 2024

Contributor

threads为空就报错了

"callback": progress_bar,
"cancellation_event": cancellation_event_map[session_id],
"envs": _envs,
"prompt": prompt,
}
try:
translate(**param)
Expand Down Expand Up @@ -319,15 +333,30 @@ def progress_bar(t: tqdm.tqdm):
value=list(page_map.keys())[0],
)

page_input = gr.Textbox(
label="Page range",
visible=False,
interactive=True,
)

with gr.Accordion("Open for More Experimental Options!", open=False):
gr.Markdown("#### Experimental")
threads = gr.Textbox(label="number of threads", interactive=True)
prompt = gr.Textbox(
label="Custom Prompt for llm", interactive=True, visible=False
)
envs.append(prompt)

# Event handler for the translation-service selector: returns one gr.update
# per env/prompt input so that only the fields relevant to the chosen service
# are shown.
#   service -- key into service_map selecting the translator class.
#   evt     -- Gradio event data; not referenced in the visible body.
# Returns the list of gr.update objects built below.
def on_select_service(service, evt: gr.EventData):
translator = service_map[service]
_envs = []
# NOTE(review): this text is a diff view without +/- markers. The range(3)
# line appears to be the removed version and range(4) its replacement (a
# fourth slot for the custom-prompt textbox appended to envs) -- confirm
# against the repository.
for i in range(3):
for i in range(4):
# Start with every slot hidden and cleared.
_envs.append(gr.update(visible=False, value=""))
# Reveal one slot per entry of translator.envs, labelled with the variable
# name and pre-filled from the process environment, falling back to the
# default declared on the translator class.
for i, env in enumerate(translator.envs.items()):
_envs[i] = gr.update(
visible=True, label=env[0], value=os.getenv(env[0], env[1])
)
# The last slot's visibility follows the translator's CustomPrompt flag.
_envs[-1] = gr.update(visible=translator.CustomPrompt)
return _envs

def on_select_filetype(file_type):
Expand All @@ -336,6 +365,12 @@ def on_select_filetype(file_type):
gr.update(visible=file_type == "Link"),
)

def on_select_page(choice):
    """Return a visibility update that is shown only for the "Others" choice."""
    show_custom_range = choice == "Others"
    return gr.update(visible=show_custom_range)

output_title = gr.Markdown("## Translated", visible=False)
output_file_mono = gr.File(
label="Download Translation (Mono)", visible=False
Expand All @@ -358,6 +393,7 @@ def on_select_filetype(file_type):
""",
elem_classes=["secondary-text"],
)
page_range.select(on_select_page, page_range, page_input)
service.select(
on_select_service,
service,
Expand Down Expand Up @@ -422,6 +458,9 @@ def on_select_filetype(file_type):
lang_from,
lang_to,
page_range,
page_input,
prompt,
threads,
recaptcha_response,
state,
*envs,
Expand All @@ -445,7 +484,7 @@ def on_select_filetype(file_type):
def readuserandpasswd(file_path):
tuple_list = []
content = ""
if file_path is None:
if not file_path:
return tuple_list, content
if len(file_path) == 2:
try:
Expand Down
9 changes: 9 additions & 0 deletions pdf2zh/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class BaseTranslator:
name = "base"
envs = {}
lang_map = {}
CustomPrompt = False

def __init__(self, lang_in, lang_out, model):
lang_in = self.lang_map.get(lang_in.lower(), lang_in)
Expand Down Expand Up @@ -200,6 +201,7 @@ class OllamaTranslator(BaseTranslator):
"OLLAMA_HOST": "http://127.0.0.1:11434",
"OLLAMA_MODEL": "gemma2",
}
CustomPrompt = True

def __init__(self, lang_in, lang_out, model, envs=None, prompt=None):
self.set_envs(envs)
Expand Down Expand Up @@ -230,6 +232,7 @@ class OpenAITranslator(BaseTranslator):
"OPENAI_API_KEY": None,
"OPENAI_MODEL": "gpt-4o-mini",
}
CustomPrompt = True

def __init__(
self,
Expand Down Expand Up @@ -265,6 +268,7 @@ class AzureOpenAITranslator(BaseTranslator):
"AZURE_OPENAI_API_KEY": None,
"AZURE_OPENAI_MODEL": "gpt-4o-mini",
}
CustomPrompt = True

def __init__(
self,
Expand Down Expand Up @@ -306,6 +310,7 @@ class ModelScopeTranslator(OpenAITranslator):
"MODELSCOPE_API_KEY": None,
"MODELSCOPE_MODEL": "Qwen/Qwen2.5-32B-Instruct",
}
CustomPrompt = True

def __init__(
self,
Expand Down Expand Up @@ -333,6 +338,7 @@ class ZhipuTranslator(OpenAITranslator):
"ZHIPU_API_KEY": None,
"ZHIPU_MODEL": "glm-4-flash",
}
CustomPrompt = True

def __init__(self, lang_in, lang_out, model, envs=None, prompt=None):
self.set_envs(envs)
Expand Down Expand Up @@ -367,6 +373,7 @@ class SiliconTranslator(OpenAITranslator):
"SILICON_API_KEY": None,
"SILICON_MODEL": "Qwen/Qwen2.5-7B-Instruct",
}
CustomPrompt = True

def __init__(self, lang_in, lang_out, model, envs=None, prompt=None):
self.set_envs(envs)
Expand All @@ -385,6 +392,7 @@ class GeminiTranslator(OpenAITranslator):
"GEMINI_API_KEY": None,
"GEMINI_MODEL": "gemini-1.5-flash",
}
CustomPrompt = True

def __init__(self, lang_in, lang_out, model, envs=None, prompt=None):
self.set_envs(envs)
Expand Down Expand Up @@ -458,6 +466,7 @@ class AnythingLLMTranslator(BaseTranslator):
"AnythingLLM_URL": None,
"AnythingLLM_APIKEY": None,
}
CustomPrompt = True

def __init__(self, lang_out, lang_in, model, envs=None, prompt=None):
self.set_envs(envs)
Expand Down

0 comments on commit 10d9cf1

Please sign in to comment.