From e21035ec13c4a2e1bbc6ba92aebd0066ea52cfd4 Mon Sep 17 00:00:00 2001 From: Sean Morgan Date: Fri, 29 Mar 2024 09:37:04 -0700 Subject: [PATCH] Read model stream from start (#123) --- modelscan/model.py | 1 + modelscan/tools/picklescanner.py | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/modelscan/model.py b/modelscan/model.py index 43dd611..ff58035 100644 --- a/modelscan/model.py +++ b/modelscan/model.py @@ -51,4 +51,5 @@ def get_stream(self) -> IO[bytes]: if not self._stream: raise ModelDataEmpty("Model data is empty.") + self._stream.seek(0) return self._stream diff --git a/modelscan/tools/picklescanner.py b/modelscan/tools/picklescanner.py index b6e6078..80d7409 100644 --- a/modelscan/tools/picklescanner.py +++ b/modelscan/tools/picklescanner.py @@ -187,10 +187,11 @@ def scan_numpy(model: Model, settings: Dict[str, Any]) -> ScanResults: _ZIP_PREFIX = b"PK\x03\x04" _ZIP_SUFFIX = b"PK\x05\x06" # empty zip files start with this N = len(np.lib.format.MAGIC_PREFIX) - magic = model.get_stream().read(N) + stream = model.get_stream() + magic = stream.read(N) # If the file size is less than N, we need to make sure not # to seek past the beginning of the file - model.get_stream().seek(-min(N, len(magic)), 1) # back-up + stream.seek(-min(N, len(magic)), 1) # back-up if magic.startswith(_ZIP_PREFIX) or magic.startswith(_ZIP_SUFFIX): # .npz file return ScanResults( @@ -208,9 +209,9 @@ def scan_numpy(model: Model, settings: Dict[str, Any]) -> ScanResults: elif magic == np.lib.format.MAGIC_PREFIX: # .npy file - version = np.lib.format.read_magic(model.get_stream()) # type: ignore[no-untyped-call] + version = np.lib.format.read_magic(stream) # type: ignore[no-untyped-call] np.lib.format._check_version(version) # type: ignore[attr-defined] - _, _, dtype = np.lib.format._read_array_header(model.get_stream(), version) # type: ignore[attr-defined] + _, _, dtype = np.lib.format._read_array_header(stream, version) # type: ignore[attr-defined] if dtype.hasobject: return scan_pickle_bytes(model, settings, scan_name)