-
Notifications
You must be signed in to change notification settings - Fork 12
/
safetensors_file.py
125 lines (103 loc) · 4 KB
/
safetensors_file.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os, sys, json
class SafeTensorsException(Exception):
    """Error raised for structurally invalid .safetensors files."""

    def __init__(self, msg:str):
        self.msg=msg  #kept so __str__ can return exactly the message given
        super().__init__(msg)

    @staticmethod
    def invalid_file(filename:str, whatiswrong:str):
        """Build an exception describing why *filename* is not a valid
        .safetensors file.

        Fix: the original message hard-coded "(unknown)" and ignored the
        *filename* argument entirely; it is now interpolated.
        """
        s=f"{filename} is not a valid .safetensors file: {whatiswrong}"
        return SafeTensorsException(msg=s)

    def __str__(self):
        return self.msg
class SafeTensorsChunk:
    """One tensor entry from a .safetensors header: its name, dtype string,
    shape, and the [offset0, offset1) byte range inside the data section."""

    def __init__(self, name:str, dtype:str, shape:list[int], offset0:int, offset1:int):
        # Plain value object: store every field verbatim.
        self.name, self.dtype = name, dtype
        self.shape = shape
        self.offset0, self.offset1 = offset0, offset1
class SafeTensorsFile:
    """Read-only accessor for a .safetensors file.

    Validates the 8-byte little-endian header-length prefix, optionally
    parses the JSON header, and exposes helpers to read individual tensor
    blobs or copy the whole data section. Usable as a context manager.
    """

    def __init__(self):
        self.f=None          #open "rb" file handle, None until open() succeeds
        self.hdrbuf=None     #raw JSON header bytes
        self.header=None     #parsed header dict (None when parseHeader=False)
        self.error=0         #last error code (0 = ok)
        self.filename=""     #path of the currently open file ("" when closed)

    def __del__(self):
        self.close_file()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close_file()

    def close_file(self):
        """Close the underlying handle if open; safe to call repeatedly."""
        if self.f is not None:
            self.f.close()
            self.f=None
            self.filename=""

    #test file: duplicate_keys_in_header.safetensors
    def _CheckDuplicateHeaderKeys(self):
        """Raise if the JSON header repeats a top-level key.

        json.loads would silently keep only the last occurrence, so duplicates
        are detected with an object_pairs_hook that preserves every key.
        """
        def parse_object_pairs(pairs):
            return [k for k, _ in pairs]
        keys = json.loads(self.hdrbuf, object_pairs_hook=parse_object_pairs)
        counts = {}
        for k in keys:
            counts[k] = counts.get(k, 0) + 1
        hasError = False
        for k, v in counts.items():
            if v > 1:
                print(f"key {k} used {v} times in header", file=sys.stderr)
                hasError = True
        if hasError:
            raise SafeTensorsException.invalid_file(self.filename, "duplicate keys in header")

    @staticmethod
    def open_file(filename:str, quiet=False, parseHeader=True):
        """Convenience constructor: create a SafeTensorsFile and open *filename*."""
        s = SafeTensorsFile()
        s.open(filename, quiet, parseHeader)
        return s

    def open(self, fn:str, quiet=False, parseHeader=True) -> int:
        """Open *fn*, validate its size/header, and optionally parse the header.

        Returns 0 on success; raises SafeTensorsException on any structural
        problem. Fix: the file handle is now closed when validation fails
        part-way through (the original leaked it on every raise after open()).
        """
        st = os.stat(fn)
        if st.st_size < 8:  #test file: zero_len_file.safetensors
            raise SafeTensorsException.invalid_file(fn, "length less than 8 bytes")
        f = open(fn, "rb")
        try:
            b8 = f.read(8)  #first 8 bytes = little-endian unsigned header length
            if len(b8) != 8:
                raise SafeTensorsException.invalid_file(fn, f"read only {len(b8)} bytes at start of file")
            headerlen = int.from_bytes(b8, 'little', signed=False)
            if (8 + headerlen > st.st_size):  #test file: header_size_too_big.safetensors
                raise SafeTensorsException.invalid_file(fn, "header extends past end of file")
            if quiet == False:
                print(f"{fn}: length={st.st_size}, header length={headerlen}")
            hdrbuf = f.read(headerlen)
            if len(hdrbuf) != headerlen:
                raise SafeTensorsException.invalid_file(fn, f"header size is {headerlen}, but read {len(hdrbuf)} bytes")
        except Exception:
            f.close()  #don't leak the handle when the file is malformed
            raise
        self.filename = fn
        self.f = f
        self.st = st          #cached os.stat result; copy_data_to_file uses st_size
        self.hdrbuf = hdrbuf
        self.error = 0
        self.headerlen = headerlen
        if parseHeader == True:
            self._CheckDuplicateHeaderKeys()
            self.header = json.loads(self.hdrbuf)
        return 0

    def get_header(self):
        """Return the parsed header dict (None if opened with parseHeader=False)."""
        return self.header

    def load_one_tensor(self, tensor_name:str):
        """Read and return the raw bytes of one tensor.

        Returns None when *tensor_name* is absent (or the header was never
        parsed — previously that case raised TypeError). A short read is
        reported to stderr but the partial data is still returned.
        """
        if self.header is None or tensor_name not in self.header:
            return None
        t = self.header[tensor_name]
        self.f.seek(8 + self.headerlen + t['data_offsets'][0])
        bytesToRead = t['data_offsets'][1] - t['data_offsets'][0]
        data = self.f.read(bytesToRead)  #renamed: original shadowed builtin `bytes`
        if len(data) != bytesToRead:
            print(f"{tensor_name}: length={bytesToRead}, only read {len(data)} bytes", file=sys.stderr)
        return data

    def copy_data_to_file(self, file_handle) -> int:
        """Copy the raw data section (everything after the header) into
        *file_handle* in 16 MB chunks. Returns 0."""
        self.f.seek(8 + self.headerlen)
        bytesLeft:int = self.st.st_size - 8 - self.headerlen
        while bytesLeft > 0:
            chunklen:int = min(bytesLeft, int(16*1024*1024))  #copy in blocks of 16 MB
            file_handle.write(self.f.read(chunklen))
            bytesLeft -= chunklen
        return 0