Skip to content

Commit

Permalink
Start working on loading pth files
Browse files Browse the repository at this point in the history
  • Loading branch information
manu12121999 authored Dec 11, 2024
1 parent d616c20 commit 338657b
Showing 1 changed file with 58 additions and 1 deletion.
59 changes: 58 additions & 1 deletion ctrl_c_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
__original_source__ = "https://github.com/manu12121999/ctrl_c_nn/blob/main/ctrl_c_nn.py"
__email__ = "[email protected]"

import pickle
import random
import struct
import zlib
import sys
import math
import operator
from multiprocessing import Pool
import zlib
from zipfile import ZipFile

sumprod = math.sumprod if sys.version_info >= (3, 12) else lambda p, q: sum([p_i*q_i for p_i, q_i in zip(p, q)])

Expand Down Expand Up @@ -677,6 +680,60 @@ def forward(self, x: Tensor):
# missing: weight_init, ConvTranspose2d, MaxPool2d, AvgPool2d, Softmax, BatchNorm2d, InstanceNorm2d, LayerNorm2d, Losses


class PthUnpickler(pickle.Unpickler):
def __init__(self, picklefile, zipfile):
self.zipfile = zipfile
super().__init__(picklefile)

def persistent_load(self, pid):
# print("pid", pid)
storage, class_type, storage_dir, gpu, size = pid
data_list = None
with self.zipfile.open(f'best/data/{storage_dir}') as f:
data = f.read()
if class_type == "int64":
data_list = [int.from_bytes(data[i: i+8], "little") for i in range(0, len(data), 8)]
elif class_type == "int32":
data_list = [int.from_bytes(data[i: i+4], "little") for i in range(0, len(data), 4)]
elif class_type == "float":
data_list = [struct.unpack('f', data[i: i+4])[0] for i in range(0, len(data), 4)]
return data_list

@staticmethod
def load_replacement(data_list, storage_offset, size, stride, *args):
assert storage_offset == 0
# print("storage:", "--", "size:", size, "stride:", stride)
tensor = Tensor(data_list)
if len(size) > 1:
tensor = tensor.reshape(size)
return tensor

def find_class(self, module, name):
# print("m", module, "n", name)
if module == 'torch' and name == 'FloatStorage':
return "float"
elif module == 'torch' and name == 'LongStorage':
return "int64"
elif module == 'torch' and name == 'IntStorage':
return "int32"
elif module == 'torch._utils' and name == '_rebuild_tensor_v2':
return self.load_replacement # torch._utils._rebuild_tensor_v2
else:
print("WARNING: loading module", module, "name", name)
return super().find_class(module, name)


def load_model(path):
with ZipFile(path) as zip_file:
with zip_file.open(f'best/data.pkl') as pickle_file:
p = PthUnpickler(pickle_file, zip_file)
print("loading ...")
model_dict = p.load()
print("loading completed")
print(model_dict["model_state_dict"])
return model_dict


class ImageIO:

@staticmethod
Expand Down

0 comments on commit 338657b

Please sign in to comment.