diff --git a/ctrl_c_nn.py b/ctrl_c_nn.py index 6f275dc..67846f1 100644 --- a/ctrl_c_nn.py +++ b/ctrl_c_nn.py @@ -4,12 +4,15 @@ __original_source__ = "https://github.com/manu12121999/ctrl_c_nn/blob/main/ctrl_c_nn.py" __email__ = "manu12121999@gmail.com" +import pickle import random +import struct +import zlib import sys import math import operator from multiprocessing import Pool -import zlib +from zipfile import ZipFile sumprod = math.sumprod if sys.version_info >= (3, 12) else lambda p, q: sum([p_i*q_i for p_i, q_i in zip(p, q)]) @@ -677,6 +680,60 @@ def forward(self, x: Tensor): # missing: weight_init, ConvTranspose2d, MaxPool2d, AvgPool2d, Softmax, BatchNorm2d, InstanceNorm2d, LayerNorm2d, Losses +class PthUnpickler(pickle.Unpickler): + def __init__(self, picklefile, zipfile): + self.zipfile = zipfile + super().__init__(picklefile) + + def persistent_load(self, pid): + # print("pid", pid) + storage, class_type, storage_dir, gpu, size = pid + data_list = None + with self.zipfile.open(f'best/data/{storage_dir}') as f: + data = f.read() + if class_type == "int64": + data_list = [int.from_bytes(data[i: i+8], "little") for i in range(0, len(data), 8)] + elif class_type == "int32": + data_list = [int.from_bytes(data[i: i+4], "little") for i in range(0, len(data), 4)] + elif class_type == "float": + data_list = [struct.unpack('f', data[i: i+4])[0] for i in range(0, len(data), 4)] + return data_list + + @staticmethod + def load_replacement(data_list, storage_offset, size, stride, *args): + assert storage_offset == 0 + # print("storage:", "--", "size:", size, "stride:", stride) + tensor = Tensor(data_list) + if len(size) > 1: + tensor = tensor.reshape(size) + return tensor + + def find_class(self, module, name): + # print("m", module, "n", name) + if module == 'torch' and name == 'FloatStorage': + return "float" + elif module == 'torch' and name == 'LongStorage': + return "int64" + elif module == 'torch' and name == 'IntStorage': + return "int32" + elif module == 'torch._utils' and name == '_rebuild_tensor_v2': + return self.load_replacement # torch._utils._rebuild_tensor_v2 + else: + print("WARNING: loading module", module, "name", name) + return super().find_class(module, name) + + +def load_model(path): + with ZipFile(path) as zip_file: + with zip_file.open(f'best/data.pkl') as pickle_file: + p = PthUnpickler(pickle_file, zip_file) + print("loading ...") + model_dict = p.load() + print("loading completed") + print(model_dict["model_state_dict"]) + return model_dict + + class ImageIO: @staticmethod