-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimport.go
64 lines (49 loc) · 1.09 KB
/
import.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
package bpe
import (
"encoding/json"
"io"
)
// Import reads a serialized model from r and reconstructs a *BPE from it.
//
// The stream is decoded by the ModelDecoder held in the import options. The
// defaults come from defaultImportOptions; each entry of opts is then applied
// in order and may replace them (see WithDecoder).
//
// If the options leave the Decoder as a nil interface value, for example after
// WithDecoder(nil), the default decoder is used instead. Without this guard
// the Decode call below would be a method call on a nil interface and panic.
//
// The returned model and error are exactly what the selected decoder returns.
func Import(r io.Reader, opts ...ImportOption) (*BPE, error) {
	options := defaultImportOptions()

	// Remember the default before caller-supplied options can overwrite it.
	fallback := options.Decoder

	options.Apply(opts...)
	if options.Decoder == nil {
		options.Decoder = fallback
	}

	return options.Decoder.Decode(r)
}
// defaultImportOptions returns a newly allocated importOptions whose Decoder
// is a defaultDecoder. Import starts from this value before applying any
// caller-supplied options.
func defaultImportOptions() *importOptions {
	defaults := new(importOptions)
	defaults.Decoder = &defaultDecoder{}
	return defaults
}
// ModelDecoder turns a serialized model read from an io.Reader into a *BPE.
// Import delegates the actual reading to the configured ModelDecoder; a custom
// implementation can be supplied through WithDecoder.
type ModelDecoder interface {
// Decode reads the model from r and returns it, or an error.
Decode(r io.Reader) (*BPE, error)
}
// importOptions holds the settings that Import uses for a single call.
// Values are produced by defaultImportOptions and modified by ImportOption
// functions.
type importOptions struct {
// Decoder is the ModelDecoder that Import invokes to read the model.
Decoder ModelDecoder
}
// Apply runs every option in opts against o, in the order given, so when two
// options set the same field the later one wins.
//
// Nil entries are skipped. Calling a nil function value panics, and a nil
// ImportOption can reach this method through Import's variadic parameter
// (for example Import(r, nil)).
func (o *importOptions) Apply(opts ...ImportOption) {
	for _, opt := range opts {
		if opt == nil {
			continue
		}
		opt(o)
	}
}
// ImportOption is a functional option for Import: a function that receives
// the options being assembled for the call and modifies them in place.
type ImportOption func(opts *importOptions)
// WithDecoder returns an ImportOption that sets the import options' Decoder
// to decoder, replacing whichever decoder was configured before the option
// is applied.
func WithDecoder(decoder ModelDecoder) ImportOption {
	setDecoder := func(target *importOptions) {
		target.Decoder = decoder
	}
	return setDecoder
}
// defaultDecoder is the decoder installed by defaultImportOptions. It carries
// no state; its Decode method reads the model as JSON.
type defaultDecoder struct{}
// Decode reads one JSON value from r into an exportedModel and builds a BPE
// from it.
//
// Each entry of the decoded Vocab becomes a key in the model's vocab set, so
// repeated tokens collapse into a single entry. MaxTokenLength is copied
// across unchanged. If JSON decoding fails, that error is returned as-is with
// a nil model.
func (d *defaultDecoder) Decode(r io.Reader) (*BPE, error) {
	var payload exportedModel
	if err := json.NewDecoder(r).Decode(&payload); err != nil {
		return nil, err
	}

	// Pre-size the set to the number of decoded tokens.
	tokens := make(map[string]struct{}, len(payload.Vocab))
	for _, tok := range payload.Vocab {
		tokens[tok] = struct{}{}
	}

	return &BPE{
		maxTokenLength: payload.MaxTokenLength,
		vocab:          tokens,
	}, nil
}