/- ctranslate.lean -/
import Lean.Meta
import LeanCodePrompts
import LeanCodePrompts.BatchTranslate
import LeanAide.Config
import Cli
open Lean Cli LeanAide.Meta LeanAide Translator
set_option maxHeartbeats 10000000
set_option maxRecDepth 1000
set_option compiler.extract_closed false
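
/-- Parse the command-line flags, set up the chat server and parameters, load the
environment and embedding data, then translate the input and print the result. -/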
unsafe def runTranslate (p : Parsed) : IO UInt32 := do
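  -- Read the positional argument and the flags, falling back to defaults;
  -- `temperature` is given as ten times its value and converted to a `JsonNumber` below.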
  searchPathRef.set compile_time_search_path%
  let type :=
    p.positionalArg? "input" |>.map (fun s => s.as! String)
      |>.getD "thm"
  let numSim := p.flag? "prompts" |>.map (fun s => s.as! Nat)
    |>.getD 20
  let numConcise := p.flag? "concise_descriptions" |>.map (fun s => s.as! Nat)
    |>.getD 2
  let queryNum := p.flag? "responses" |>.map (fun s => s.as! Nat)
    |>.getD 10
  let temp10 := p.flag? "temperature" |>.map (fun s => s.as! Nat)
    |>.getD 8
  let temp : JsonNumber := ⟨temp10, 1⟩
  let model := p.flag? "model" |>.map (fun s => s.as! String)
    |>.getD "gpt-4o"
  let azure := p.hasFlag "azure"
  let tag := p.hasFlag "tag"
  let maxTokens := p.flag? "max_tokens" |>.map (fun s => s.as! Nat)
    |>.getD 1600
  let sysLess := p.hasFlag "no_sysprompt"
  let url? := p.flag? "url" |>.map (fun s => s.as! String)
  let showPrompt := p.hasFlag "show_prompt"
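  -- Choose the chat backend: Azure, a generic server at the given URL, or OpenAI.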
  let chatServer :=
    if azure then ChatServer.azure else
      match url? with
      | some url => ChatServer.generic model url !sysLess
      | none => ChatServer.openAI model
  let chatParams : ChatParams :=
    {temp := temp, n := queryNum, maxTokens := maxTokens}
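  -- Create the output directory `results/<model>` (or `results/<model>/<git hash>`
  -- when the `tag` flag is set).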
  let gitHash ← gitHash
  let dir :=
    if tag then System.mkFilePath <| ["results", model, gitHash]
    else System.mkFilePath <| ["results", model]
  if !(← dir.pathExists) then
    IO.FS.createDirAll dir
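  -- Load an environment with Mathlib and the LeanAide translation modules.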
  let env ←
    importModules #[{module := `Mathlib},
      {module := `LeanAide.TheoremElab},
      {module := `LeanCodePrompts.Translate}] {}
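  -- Load the pickled embedding data for documentation strings and descriptions.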
  withUnpickle (← picklePath "docString")
    <| fun (docStringData : EmbedData) => do
  withUnpickle (← picklePath "description")
    <| fun (descData : EmbedData) => do
  withUnpickle (← picklePath "concise-description")
    <| fun (concDescData : EmbedData) => do
  let dataMap : EmbedMap := Std.HashMap.ofList
    [("docString", docStringData), ("description", descData),
      ("concise-description", concDescData)]
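  -- Run the translation in `CoreM` with the loaded environment and embeddings.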
  let translator : Translator := {server := chatServer, params := chatParams}
  let core :=
    translator.translateViewVerboseM type |>.runWithEmbeddings dataMap
  let io? :=
    core.run' {fileName := "", fileMap := {source := "", positions := #[]},
        maxHeartbeats := 0, maxRecDepth := 1000000}
      {env := env}
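  -- Report the outcome: exit code 0 on success, 2 if no translation was found, 1 on error.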
  let io?' ← io?.toIO'
  match io?' with
  | Except.ok (translation?, output, prompt) =>
    IO.eprintln "Ran successfully"
    if showPrompt then
      IO.eprintln "Prompt:"
      IO.eprintln prompt.pretty
      IO.eprintln "---"
    match translation? with
    | some result =>
      if p.hasFlag "show_elaborated" then
        IO.eprintln "Elaborated terms:"
        for out in result.allElaborated do
          IO.eprintln out
        IO.eprintln "---"
        IO.eprintln "Groups:"
        for gp in result.groups do
          for out in gp do
            IO.eprintln out
          IO.eprintln "---"
      IO.eprintln "Translation:"
      IO.println result.view
      return 0
    | none =>
      IO.eprintln "No translation"
      IO.eprintln "All outputs:"
      for out in output do
        IO.eprintln <| "* " ++ out
      return 2
  | Except.error e => do
    IO.eprintln "Ran with error"
    let msg ← e.toMessageData.toString
    IO.eprintln msg
    return 1
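
/-- The `translate` command: the flags correspond to the options read in `runTranslate`. -/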
unsafe def translate : Cmd := `[Cli|
  translate VIA runTranslate;
  "Elaborate a set of inputs and report whether successful and the result if successful."

  FLAGS:
    include_fixed; "Include the 'Lean Chat' fixed prompts."
    p, prompts : Nat; "Number of example prompts (default 20)."
    concise_descriptions : Nat; "Number of example descriptions (default 2)."
    r, responses : Nat; "Number of responses to ask for (default 10)."
    t, temperature : Nat; "Scaled temperature `t*10` for temperature `t` (default 8)."
    m, model : String; "Model to be used (default `gpt-4o`)."
    azure; "Use Azure instead of OpenAI."
    url : String; "URL to query (for a local server)."
    tag; "Tag the results directory with the current git hash."
    show_prompt; "Output the prompt to the LLM."
    show_elaborated; "Output the elaborated terms."
    max_tokens : Nat; "Maximum tokens to use in the translation."
    no_sysprompt; "The model has no system prompt (not relevant for GPT models)."

  ARGS:
    input : String; "The input file in the `data` folder."
]
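
/- Example invocation (hypothetical; assumes the executable is exposed as `translate`
in the project's lakefile, so adjust the target name and arguments to the actual setup):

  lake exe translate thm --responses 10 --temperature 8 --model gpt-4o
-/

/-- Entry point: validate the arguments and run the `translate` command. -/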
unsafe def main (args : List String) : IO UInt32 :=
  translate.validate args