-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d616c20
commit 338657b
Showing
1 changed file
with
58 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,12 +4,15 @@ | |
__original_source__ = "https://github.com/manu12121999/ctrl_c_nn/blob/main/ctrl_c_nn.py" | ||
__email__ = "[email protected]" | ||
|
||
import pickle | ||
import random | ||
import struct | ||
import zlib | ||
import sys | ||
import math | ||
import operator | ||
from multiprocessing import Pool | ||
import zlib | ||
from zipfile import ZipFile | ||
|
||
sumprod = math.sumprod if sys.version_info >= (3, 12) else lambda p, q: sum([p_i*q_i for p_i, q_i in zip(p, q)]) | ||
|
||
|
@@ -677,6 +680,60 @@ def forward(self, x: Tensor): | |
# missing: weight_init, ConvTranspose2d, MaxPool2d, AvgPool2d, Softmax, BatchNorm2d, InstanceNorm2d, LayerNorm2d, Losses | ||
|
||
|
||
class PthUnpickler(pickle.Unpickler): | ||
def __init__(self, picklefile, zipfile): | ||
self.zipfile = zipfile | ||
super().__init__(picklefile) | ||
|
||
def persistent_load(self, pid): | ||
# print("pid", pid) | ||
storage, class_type, storage_dir, gpu, size = pid | ||
data_list = None | ||
with self.zipfile.open(f'best/data/{storage_dir}') as f: | ||
data = f.read() | ||
if class_type == "int64": | ||
data_list = [int.from_bytes(data[i: i+8], "little") for i in range(0, len(data), 8)] | ||
elif class_type == "int32": | ||
data_list = [int.from_bytes(data[i: i+4], "little") for i in range(0, len(data), 4)] | ||
elif class_type == "float": | ||
data_list = [struct.unpack('f', data[i: i+4])[0] for i in range(0, len(data), 4)] | ||
return data_list | ||
|
||
@staticmethod | ||
def load_replacement(data_list, storage_offset, size, stride, *args): | ||
assert storage_offset == 0 | ||
# print("storage:", "--", "size:", size, "stride:", stride) | ||
tensor = Tensor(data_list) | ||
if len(size) > 1: | ||
tensor = tensor.reshape(size) | ||
return tensor | ||
|
||
def find_class(self, module, name): | ||
# print("m", module, "n", name) | ||
if module == 'torch' and name == 'FloatStorage': | ||
return "float" | ||
elif module == 'torch' and name == 'LongStorage': | ||
return "int64" | ||
elif module == 'torch' and name == 'IntStorage': | ||
return "int32" | ||
elif module == 'torch._utils' and name == '_rebuild_tensor_v2': | ||
return self.load_replacement # torch._utils._rebuild_tensor_v2 | ||
else: | ||
print("WARNING: loading module", module, "name", name) | ||
return super().find_class(module, name) | ||
|
||
|
||
def load_model(path): | ||
with ZipFile(path) as zip_file: | ||
with zip_file.open(f'best/data.pkl') as pickle_file: | ||
p = PthUnpickler(pickle_file, zip_file) | ||
print("loading ...") | ||
model_dict = p.load() | ||
print("loading completed") | ||
print(model_dict["model_state_dict"]) | ||
return model_dict | ||
|
||
|
||
class ImageIO: | ||
|
||
@staticmethod | ||
|