-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathnotebook_utils.py
32 lines (27 loc) · 1.01 KB
/
notebook_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os
from pathlib import Path
import io
import torch
import pickle
def print_models(base_path, model_string, max_n=80, max_epochs=50):
    """Print ``model_string`` followed by every matching checkpoint file found.

    For each pair ``(i, e)`` with ``i`` in ``range(max_n)`` and ``e`` in
    ``range(max_epochs)`` the path

        <base_path>/models_diff/prior_diff_real_checkpoint<model_string>_n_<i>_epoch_<e>.cpkt

    is probed, and printed if it is an existing regular file.

    Args:
        base_path: Directory that contains the ``models_diff`` folder.
        model_string: Identifier inserted into the checkpoint file name.
        max_n: Exclusive upper bound for the ``n`` index (default 80).
        max_epochs: Exclusive upper bound for the epoch index (default 50).

    Returns:
        None. All results are written to standard output.
    """
    print(model_string)

    for i in range(max_n):
        for e in range(max_epochs):
            # Build the candidate path once; it is used both for the
            # existence check and for the printed output.
            checkpoint = os.path.join(
                base_path,
                f'models_diff/prior_diff_real_checkpoint{model_string}_n_{i}_epoch_{e}.cpkt',
            )
            if Path(checkpoint).is_file():
                print(checkpoint)
        # NOTE(review): the original indentation of this print() was lost;
        # it is assumed to emit one blank separator line per ``i`` - confirm.
        print()
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that loads torch storages onto the CPU and tolerates missing globals.

    Differences from ``pickle.Unpickler``:

    * any global named ``Manager`` resolves to ``settings.Manager``;
    * ``torch.storage._load_from_bytes`` is replaced by a loader that passes
      ``map_location='cpu'`` to ``torch.load``;
    * a global that cannot be resolved yields ``None`` instead of raising.

    SECURITY: like every pickle loader, this executes code chosen by the
    pickle stream. Only use it on data from a trusted source.
    """

    def find_class(self, module, name):
        """Resolve the global ``module.name`` requested by the pickle stream.

        Args:
            module: Dotted module name recorded in the pickle.
            name: Attribute name recorded in the pickle.

        Returns:
            The resolved object, or ``None`` if resolution raised an
            ``Exception``.
        """
        if name == 'Manager':
            # Matched on the name alone, whatever module the pickle recorded.
            # Imported lazily so ``settings`` is only needed for such pickles.
            from settings import Manager
            return Manager
        try:
            return self.find_class_cpu(module, name)
        except Exception:
            # Deliberate best-effort: an unresolvable global becomes None.
            # ``Exception`` (not a bare ``except``) so that KeyboardInterrupt
            # and SystemExit still propagate to the caller.
            return None

    def find_class_cpu(self, module, name):
        """Resolve a global, redirecting torch's byte-storage loader to the CPU.

        Args:
            module: Dotted module name recorded in the pickle.
            name: Attribute name recorded in the pickle.

        Returns:
            For ``torch.storage._load_from_bytes``, a callable that takes the
            serialized bytes and loads them with ``map_location='cpu'``;
            otherwise whatever ``pickle.Unpickler.find_class`` returns.

        Raises:
            Whatever ``pickle.Unpickler.find_class`` raises when the global
            cannot be resolved.
        """
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        return super().find_class(module, name)