-
Notifications
You must be signed in to change notification settings - Fork 12
/
safetensors_worker.py
243 lines (203 loc) · 7.94 KB
/
safetensors_worker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
import os, sys, json
from safetensors_file import SafeTensorsFile
def _need_force_overwrite(output_file:str,cmdLine:dict) -> bool:
if cmdLine["force_overwrite"]==False:
if os.path.exists(output_file):
print(f'output file "{output_file}" already exists, use -f flag to force overwrite',file=sys.stderr)
return True
return False
def WriteMetadataToHeader(cmdLine:dict,in_st_file:str,in_json_file:str,output_file:str) -> int:
    """Rewrite a .safetensors file, replacing the header's __metadata__ with
    the top-level __metadata__ object read from a JSON file.

    Returns 0 on success, -1 if the output exists and -f was not given,
    -2 if the JSON file lacks __metadata__, otherwise the error code from
    copying the tensor data.
    """
    if _need_force_overwrite(output_file,cmdLine): return -1
    with open(in_json_file,"rt") as f:
        inmeta=json.load(f)
    if not "__metadata__" in inmeta:
        print(f"file {in_json_file} does not contain a top-level __metadata__ item",file=sys.stderr)
        #json.dump(inmeta,fp=sys.stdout,indent=2)
        return -2
    inmeta=inmeta["__metadata__"] #keep only metadata
    #json.dump(inmeta,fp=sys.stdout,indent=2)
    s=SafeTensorsFile.open_file(in_st_file)
    js=s.get_header()
    if inmeta==[]:
        # empty list is the sentinel meaning "strip __metadata__ entirely"
        js.pop("__metadata__",0)
        print("loaded __metadata__ is an empty list, output file will not contain __metadata__ in header")
    else:
        print("adding __metadata__ to header:")
        json.dump(inmeta,fp=sys.stdout,indent=2)
        # coerce values to str — assumes safetensors metadata values must be
        # strings (TODO confirm against the safetensors format spec)
        if isinstance(inmeta,dict):
            for k in inmeta:
                inmeta[k]=str(inmeta[k])
        else:
            inmeta=str(inmeta)
        #js["__metadata__"]=json.dumps(inmeta,ensure_ascii=False)
        js["__metadata__"]=inmeta
        print()
    # serialize the new header and write: 8-byte little-endian length,
    # the JSON bytes, then space (0x20) padding to a multiple of 8
    newhdrbuf=json.dumps(js,separators=(',',':'),ensure_ascii=False).encode('utf-8')
    newhdrlen:int=int(len(newhdrbuf))
    pad:int=((newhdrlen+7)&(~7))-newhdrlen #pad to multiple of 8
    with open(output_file,"wb") as f:
        f.write(int(newhdrlen+pad).to_bytes(8,'little'))
        f.write(newhdrbuf)
        if pad>0: f.write(bytearray([32]*pad))
        # copy the tensor data block unchanged from the input file
        i:int=s.copy_data_to_file(f)
    if i==0:
        print(f"file {output_file} saved successfully")
    else:
        print(f"error {i} occurred when writing to file {output_file}")
    return i
def PrintHeader(cmdLine:dict,input_file:str) -> int:
    """Print the .safetensors header, one key/value pair per line.

    All the .safetensors files I've seen have long key names, and as a result,
    neither json nor pprint package prints text in very readable format,
    so we print it ourselves, putting key name & value on one long line.
    Note the print out is in Python format, not valid JSON format.
    Always returns 0.
    """
    # use a context manager so the input file is closed, like PrintMetadata does
    with SafeTensorsFile.open_file(input_file,cmdLine['quiet']) as s:
        js=s.get_header()
    firstKey=True
    print("{")
    for key in js:
        if firstKey:
            firstKey=False
        else:
            print(",")
        json.dump(key,fp=sys.stdout,ensure_ascii=False,separators=(',',':'))
        print(": ",end='')
        json.dump(js[key],fp=sys.stdout,ensure_ascii=False,separators=(',',':'))
    print("\n}")
    return 0
def _ParseMore(d:dict):
'''Basically try to turn this:
"ss_dataset_dirs":"{\"abc\": {\"n_repeats\": 2, \"img_count\": 60}}",
into this:
"ss_dataset_dirs":{
"abc":{
"n_repeats":2,
"img_count":60
}
},
'''
for key in d:
value=d[key]
#print("+++",key,value,type(value),"+++",sep='|')
if isinstance(value,str):
try:
v2=json.loads(value)
d[key]=v2
value=v2
except json.JSONDecodeError as e:
pass
if isinstance(value,dict):
_ParseMore(value)
def PrintMetadata(cmdLine:dict,input_file:str) -> int:
    """Dump the __metadata__ item of a .safetensors header to stdout.

    When parse_more is set on the command line, string values containing
    embedded JSON are decoded in place before printing.
    Returns 0 on success, -2 when the header has no __metadata__ item.
    """
    with SafeTensorsFile.open_file(input_file,cmdLine['quiet']) as st:
        header=st.get_header()
        if "__metadata__" not in header:
            print("file header does not contain a __metadata__ item",file=sys.stderr)
            return -2
        metadata=header["__metadata__"]
        if cmdLine['parse_more']:
            _ParseMore(metadata)
        json.dump({"__metadata__":metadata},fp=sys.stdout,ensure_ascii=False,separators=(',',':'),indent=1)
    return 0
def HeaderKeysToLists(cmdLine:dict,input_file:str) -> int:
    """Print the file's tensor key names as Python source code: a sorted
    list of (name, is_scalar) tuples suitable for a lora_keys-style module.
    Always returns 0.
    """
    s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
    js=s.get_header()
    # fix: tuple(str,bool) was a (never-evaluated) call, not a type
    _lora_keys:list[tuple[str,bool]]=[] # use list to sort by name
    for key in js:
        if key=='__metadata__': continue
        v=js[key]
        # a scalar tensor is one whose 'shape' entry is the empty list
        isScalar=isinstance(v,dict) and 'shape' in v and len(v['shape'])==0
        _lora_keys.append((key,isScalar))
    _lora_keys.sort(key=lambda x:x[0])

    def printkeylist(kl):
        # comma-separated, one entry per line, no trailing comma
        firstKey=True
        for key in kl:
            if firstKey: firstKey=False
            else: print(",")
            print(key,end='')
        print()

    print("# use list to keep insertion order")
    print("_lora_keys:list[tuple[str,bool]]=[")
    printkeylist(_lora_keys)
    print("]")
    return 0
def ExtractHeader(cmdLine:dict,input_file:str,output_file:str)->int:
    """Save the raw, unparsed header bytes of a .safetensors file.

    Returns 0 on success, the open error code on failure to read, or -1
    on an overwrite refusal or a short write.
    """
    if _need_force_overwrite(output_file,cmdLine): return -1
    src=SafeTensorsFile.open_file(input_file,parseHeader=False)
    if src.error!=0: return src.error
    rawheader=src.hdrbuf
    src.close_file() #close it in case user wants to write back to input_file itself
    with open(output_file,"wb") as fo:
        written=fo.write(rawheader)
    if written!=len(rawheader):
        print(f"write output file failed, tried to write {len(rawheader)} bytes, only wrote {written} bytes",file=sys.stderr)
        return -1
    print(f"raw header saved to file {output_file}")
    return 0
def _CheckLoRA_internal(s:SafeTensorsFile)->int:
    """Check the header of s against the expected SD 1.x LoRA key set.

    Prints unrecognized keys (informational only), missing keys, and keys
    whose scalar/nonscalar shape disagrees with the expectation.
    Returns 1 if any real problem was found, else 0.
    """
    import lora_keys_sd15 as lora_keys
    js=s.get_header()
    # partition expected key names by whether each should be a scalar
    set_scalar=set()
    set_nonscalar=set()
    for name,is_scalar in lora_keys._lora_keys:
        if is_scalar: set_scalar.add(name)
        else: set_nonscalar.add(name)
    bad_unknowns:list[str]=[] # unrecognized keys
    bad_scalars:list[str]=[] # expected scalar, file has nonscalar
    bad_nonscalars:list[str]=[] # expected nonscalar, file has scalar
    for key in js:
        if key in set_nonscalar:
            # scalar tensors have shape []; a nonscalar key must not
            if js[key]['shape']==[]: bad_nonscalars.append(key)
            set_nonscalar.remove(key) # what remains afterwards is missing
        elif key in set_scalar:
            if js[key]['shape']!=[]: bad_scalars.append(key)
            set_scalar.remove(key)
        else:
            if "__metadata__"!=key:
                bad_unknowns.append(key)
    hasError=False
    if len(bad_unknowns)!=0:
        print("INFO: unrecognized items:")
        for x in bad_unknowns: print(" ",x)
        #hasError=True
    if len(set_scalar)>0:
        print("missing scalar keys:")
        for x in set_scalar: print(" ",x)
        hasError=True
    if len(set_nonscalar)>0:
        print("missing nonscalar keys:")
        for x in set_nonscalar: print(" ",x)
        hasError=True
    if len(bad_scalars)!=0:
        print("keys expected to be scalar but are nonscalar:")
        for x in bad_scalars: print(" ",x)
        hasError=True
    if len(bad_nonscalars)!=0:
        print("keys expected to be nonscalar but are scalar:")
        for x in bad_nonscalars: print(" ",x)
        hasError=True
    return (1 if hasError else 0)
def CheckLoRA(cmdLine:dict,input_file:str)->int:
    """Run the SD 1.x LoRA key check on input_file.

    NOTE(review): the internal result only controls the success message;
    this function always returns 0 regardless of the check outcome.
    """
    st=SafeTensorsFile.open_file(input_file)
    result:int=_CheckLoRA_internal(st)
    if result==0:
        print("looks like an OK SD 1.x LoRA file")
    return 0
def ExtractData(cmdLine:dict,input_file:str,key_name:str,output_file:str)->int:
    """Extract the raw binary data of one tensor into a separate file.

    key_name must match a header key exactly (case-sensitive).
    Returns 0 on success, the open error code on failure to read, or -1
    on an overwrite refusal, an unknown key, or a short write.
    """
    if _need_force_overwrite(output_file,cmdLine): return -1
    s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
    if s.error!=0: return s.error
    bindata=s.load_one_tensor(key_name)
    s.close_file() #close it just in case user wants to write back to input_file itself
    if bindata is None:
        print(f'key "{key_name}" not found in header (key names are case-sensitive)',file=sys.stderr)
        return -1
    with open(output_file,"wb") as fo:
        wn=fo.write(bindata)
        if wn!=len(bindata):
            print(f"write output file failed, tried to write {len(bindata)} bytes, only wrote {wn} bytes",file=sys.stderr)
            return -1
    if not cmdLine['quiet']: print(f"{key_name} saved to {output_file}, len={wn}")
    return 0