-
Notifications
You must be signed in to change notification settings - Fork 5
/
nearest_embeddings_full.lean
108 lines (104 loc) · 4.05 KB
/
nearest_embeddings_full.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import LeanCodePrompts.NearestEmbeddings
import LeanCodePrompts.EpsilonClusters
import Cache.IO
import LeanAide.Aides
import Lean.Data.Json
import Batteries.Util.Pickle
open Lean Cache.IO
/-- Interactive query loop: answers one request per line read from `stdin`.

Each line is parsed as JSON. The following object fields are recognised, each
falling back to a default when absent or of the wrong type:

* `"descField"` (string, default `"docString"`): selects the data set to search,
  i.e. `docStringData`, `descData` (for `"description"`) or `concDescData`
  (for `"concise-description"`); any other value selects `docStringData`.
* `"docString"` or, failing that, `"doc_string"` (string): the query text;
  defaults to the raw input line.
* `"n"` (natural number, default `10`) and `"penalty"` (float, default `2.0`):
  forwarded to `nearestDocsToDocFull`.
* `"halt"` (boolean, default `false`): when `true`, stop after this request.

If the line is not valid JSON, the raw line is the query and all defaults apply.

The reply is a single line on `stdout`: a compressed JSON array of objects with
keys `"docString"`, `"type"`, `"isProp"`, `"name"` and `"distance"`. The text key
is always `"docString"`, whichever data set was searched.

The loop ends when a request sets `"halt"` or when `stdin` reaches end of input.
-/
unsafe def show_nearest_full (stdin stdout : IO.FS.Stream)
    (docStringData: Array ((String × String × Bool × String) × FloatArray))
    (descData: Array ((String × String × Bool × String) × FloatArray))
    (concDescData: Array ((String × String × Bool × String) × FloatArray)): IO Unit := do
  let inp ← stdin.getLine
  -- `getLine` yields the empty string only at end of input (a blank line still
  -- carries its line break). Without this guard a closed stdin would make the
  -- loop run forever, repeatedly searching for the empty query.
  if inp.isEmpty then
    return ()
  logTimed "finding parameter"
  let (descField, doc, num, penalty, halt) :=
    match Json.parse inp with
    | Except.error _ => ("docString", inp, 10, 2.0, false)
    | Except.ok j =>
      (j.getObjValAs? String "descField" |>.toOption.getD "docString",
        j.getObjValAs? String "docString" |>.toOption.orElse
          (fun _ => j.getObjValAs? String "doc_string" |>.toOption)
          |>.getD inp,
        j.getObjValAs? Nat "n" |>.toOption.getD 10,
        j.getObjValAs? Float "penalty" |>.toOption.getD 2.0,
        j.getObjValAs? Bool "halt" |>.toOption.getD false)
  logTimed s!"finding nearest to `{doc}`"
  -- Unknown field names fall back to the doc-string data set.
  let data := match descField with
    | "docString" => docStringData
    | "description" => descData
    | "concise-description" => concDescData
    | _ => docStringData
  let embs ← nearestDocsToDocFull data doc num (penalty := penalty)
  logTimed "found nearest"
  let out :=
    Lean.Json.arr <|
      embs.toArray.map fun (doc, thm, isProp, name, d) =>
        Json.mkObj <| [
          ("docString", Json.str doc),
          ("type", Json.str thm),
          ("isProp", Json.bool isProp),
          ("name", Json.str name),
          ("distance", toJson d)
        ]
  stdout.putStrLn out.compress
  -- Flush so a client reading the stream sees the reply before sending the next request.
  stdout.flush
  unless halt do
    show_nearest_full stdin stdout docStringData descData concDescData
  return ()
/-- Make sure the pickled embeddings for `descField` are available locally.

The pickle counts as present only if its path (as given by `picklePath descField`)
exists *and* it can be unpickled as `EmbedData`; any exception while unpickling is
treated as "not present". When it is not present, the file is downloaded with
`runCurl` from the `leanaide_data` storage bucket to that path, and the output of
the curl call is echoed to standard error.
-/
unsafe def checkAndFetch (descField: String) : IO Unit := do
  let path ← picklePath descField
  let picklePresent ←
    if ← path.pathExists then
      try
        -- Loading is the validity check; the data itself is discarded.
        withUnpickle path <|
          fun (_ : EmbedData) => do
            pure true
      catch _ => pure false
    else pure false
  unless picklePresent do
    IO.eprintln "Fetching embeddings ..."
    -- The URL must be an interpolated (`s!`) string: as a plain literal the
    -- `{…}` placeholder would be sent to curl verbatim instead of the file name.
    let out ← runCurl #["--output", path.toString, s!"https://storage.googleapis.com/leanaide_data/{path.fileName.get!}"]
    IO.eprintln "Fetched embeddings"
    IO.eprintln out
/-- Entry point for the nearest-embeddings tool.

First ensures (via `checkAndFetch`) that the embeddings for the `"docString"`,
`"description"` and `"concise-description"` fields are available.

* **One-shot mode** (`args` non-empty): `args[0]` is the query text. The optional
  remaining arguments choose the description field and the number of results:
  - `query` searches the `"docString"` embeddings for 10 results;
  - `query N` (with `N` a natural number) searches `"docString"` for `N` results;
  - `query FIELD` searches the `FIELD` embeddings for 10 results;
  - `query FIELD N` searches the `FIELD` embeddings for `N` results.

  The matches are printed to standard output as an array of JSON objects with keys
  `FIELD`, `"type"`, `"isProp"`, `"name"` and `"distance"`; progress and timing
  information go to standard error.
* **Interactive mode** (`args` empty): loads all three embedding sets and hands
  control to `show_nearest_full`, reading requests from standard input.
-/
unsafe def main (args: List String) : IO Unit := do
  for descField in ["docString", "description", "concise-description"] do
    checkAndFetch descField
  match args[0]? with
  | some doc =>
    logTimed "starting nearest embedding process"
    -- Previously both the field name and the count were read from position 1, so a
    -- count was also taken as the field name and a field name forced the count to 10.
    -- A numeric second argument is now the count; otherwise it is the field name and
    -- an optional count follows in position 2.
    let countAtOne := args[1]?.bind fun s => s.toNat?
    let descField :=
      if countAtOne.isSome then "docString" else args.getD 1 "docString"
    let num := countAtOne.getD ((args[2]?.bind fun s => s.toNat?).getD 10)
    let path ← picklePath descField
    withUnpickle path <|
      fun (data : EmbedData) => do
        IO.eprintln s!"Searching among {data.size} embeddings"
        logTimed s!"finding nearest to `{doc}`"
        let start ← IO.monoMsNow
        let embs ← nearestDocsToDocFull data doc num (penalty := 2.0)
        IO.println <|
          embs.toArray.map fun (doc, thm, isProp, name, d) =>
            Json.mkObj <| [
              (descField, Json.str doc),
              ("type", Json.str thm),
              ("isProp", Json.bool isProp),
              ("name", Json.str name),
              ("distance", toJson d)
            ]
        let finish ← IO.monoMsNow
        logTimed "found nearest"
        IO.eprintln s!"Time taken: {finish - start} ms"
  | none =>
    logTimed "No arguments provided, starting interactive mode"
    -- The three data sets are only valid inside their `withUnpickle` callbacks,
    -- so the interactive loop runs in the innermost one.
    withUnpickle (← picklePath "docString") <| fun (docStringData : EmbedData) => do
      withUnpickle (← picklePath "description") <| fun (descData : EmbedData) => do
        withUnpickle (← picklePath "concise-description") <| fun (concDescData : EmbedData) => do
          IO.eprintln "Enter the document string to find the nearest embeddings"
          let stdin ← IO.getStdin
          let stdout ← IO.getStdout
          show_nearest_full stdin stdout docStringData descData concDescData