-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexport.go
61 lines (47 loc) · 1.13 KB
/
export.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
package bpe
import (
	"encoding/json"
	"errors"
	"io"
	"sort"
)
// Export serializes model's vocabulary and maximum token length to w using
// the configured ModelEncoder (encoding/json unless overridden with
// WithEncoder).
//
// The vocabulary is written in sorted order. model.vocab is a map and Go
// randomizes map iteration order, so emitting it in range order would make
// two exports of the same model differ. Sorting makes the output
// deterministic and diff-friendly.
//
// Export returns an error if model is nil, if the applied options leave no
// encoder configured, or if the encoder itself fails.
func Export(model *BPE, w io.Writer, opts ...ExportOption) error {
	if model == nil {
		return errors.New("bpe: cannot export nil model")
	}

	options := defaultExportOptions()
	options.Apply(opts...)
	if options.Encoder == nil {
		return errors.New("bpe: no model encoder configured")
	}

	vocab := make([]string, 0, len(model.vocab))
	for t := range model.vocab {
		vocab = append(vocab, t)
	}
	// Map iteration order is random; sort for reproducible output.
	sort.Strings(vocab)

	m := exportedModel{
		MaxTokenLength: model.maxTokenLength,
		Vocab:          vocab,
	}
	return options.Encoder.Encode(w, m)
}
// defaultExportOptions returns the baseline settings Export starts from before
// any caller-supplied ExportOption is applied. By default the model is
// serialized with defaultEncoder (encoding/json).
func defaultExportOptions() *exportOptions {
	opts := new(exportOptions)
	opts.Encoder = new(defaultEncoder)
	return opts
}
// ModelEncoder serializes an exported model to a writer. Pass a custom
// implementation to WithEncoder to change the format Export produces.
type ModelEncoder interface {
	// Encode writes the serialized form of model to w. A non-nil error is
	// returned to the caller of Export unchanged.
	Encode(w io.Writer, model interface{}) error
}
// exportOptions holds the configurable settings used by Export.
type exportOptions struct {
	// Encoder serializes the exported model to the destination writer.
	Encoder ModelEncoder
}

// Apply runs each option against o in the order given. When two options set
// the same field, the later one wins. Nil options are skipped instead of
// being called, which would panic. This lets callers pass conditionally
// constructed options without guarding each one.
func (o *exportOptions) Apply(opts ...ExportOption) {
	for _, opt := range opts {
		if opt == nil {
			continue
		}
		opt(o)
	}
}
// ExportOption customizes the settings Export uses. Build one with a helper
// such as WithEncoder.
type ExportOption func(o *exportOptions)
// WithEncoder returns an ExportOption that makes Export serialize the model
// with enc instead of the default JSON encoder.
//
// A nil enc is ignored, so the previously configured encoder stays in place.
// Previously, a nil enc replaced the default and Export then panicked calling
// Encode on a nil interface.
func WithEncoder(enc ModelEncoder) ExportOption {
	return func(opts *exportOptions) {
		if enc == nil {
			return
		}
		opts.Encoder = enc
	}
}
// exportedModel is the value Export hands to the ModelEncoder. Its struct
// tags fix the JSON field names ("max_token_length", "vocab") emitted by the
// default encoder.
type exportedModel struct {
	MaxTokenLength int      `json:"max_token_length"`
	Vocab          []string `json:"vocab"`
}
// defaultEncoder is the ModelEncoder Export uses unless WithEncoder supplies
// another one. It writes the model as JSON via encoding/json.
type defaultEncoder struct{}

// Encode writes model to w as a single JSON value. json.Encoder terminates
// the value with a newline. Marshalling and write errors are returned as-is.
func (*defaultEncoder) Encode(w io.Writer, model interface{}) error {
	enc := json.NewEncoder(w)
	return enc.Encode(model)
}