Skip to content

Commit

Permalink
style: format namespaces
Browse files Browse the repository at this point in the history
  • Loading branch information
CandiedCode committed Apr 24, 2024
1 parent 492bf57 commit 3e0f975
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 97 deletions.
4 changes: 3 additions & 1 deletion notebooks/pytorch_sentiment_analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@
}
],
"source": [
"import torch\n",
"import os\n",
"\n",
"import torch\n",
"from utils.pytorch_sentiment_model import download_model, predict_sentiment\n",
"\n",
"from tests.pickle_utils.codeinjection import PickleInject, get_inject_payload\n",
"\n",
"%env TOKENIZERS_PARALLELISM=false"
Expand Down
10 changes: 5 additions & 5 deletions notebooks/utils/pytorch_sentiment_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, Final
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
import numpy as np
from scipy.special import softmax
import csv
import urllib.request
from typing import Any, Final

import numpy as np
import torch
from scipy.special import softmax
from transformers import AutoModelForSequenceClassification, AutoTokenizer

SENTIMENT_TASK: Final[str] = "sentiment"

Expand Down
2 changes: 1 addition & 1 deletion notebooks/utils/xgboost_diabetes_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from numpy import loadtxt
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier


def get_data():
Expand Down
10 changes: 5 additions & 5 deletions notebooks/xgboost_diabetes_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"from pathlib import Path\n",
"import os\n",
"import numpy as np\n",
"from tests.pickle_utils.codeinjection import generate_unsafe_pickle_file\n",
"from utils.xgboost_diabetes_model import train_model, get_predictions"
"import pickle\n",
"\n",
"from utils.xgboost_diabetes_model import get_predictions, train_model\n",
"\n",
"from tests.pickle_utils.codeinjection import generate_unsafe_pickle_file"
]
},
{
Expand Down
167 changes: 84 additions & 83 deletions tests/pickle_utils/codeinjection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import os
import pickle
import struct
from enum import Enum
from typing import TYPE_CHECKING, Any

import dill
Expand All @@ -18,6 +17,64 @@
from _typeshed import SupportsWrite


class PickleInject:
"""Pickle injection"""

def __init__(self, inj_objs: Any, first: bool = True):
self.__name__ = "pickle_inject"
self.inj_objs = inj_objs
self.first = first

class _Pickler(pickle._Pickler):
"""Re-implementation of Pickler with support for injection"""

def __init__(
self,
file: SupportsWrite[bytes],
protocol: int | None,
inj_objs: Any,
first: bool = True,
) -> None:
"""
file: File object with write attribute
protocol: Pickle protocol - Currently the default protocol is 4: https://docs.python.org/3/library/pickle.html
inj_objs: _joblibInject object that has both the command, and the code to be injected
first: Boolean object to determine if inj_objs should be serialized before the safe file or after the safe file.
"""
super().__init__(file, protocol)
self.inj_objs = inj_objs
self.first = first

def dump(self, obj: Any) -> None:
"""Pickle data, inject object before or after"""
if self.proto >= 2: # type: ignore[attr-defined]
self.write(pickle.PROTO + struct.pack("<B", self.proto)) # type: ignore[attr-defined]
if self.proto >= 4: # type: ignore[attr-defined]
self.framer.start_framing() # type: ignore[attr-defined]

# Inject the object(s) before the user-supplied data?
if self.first:
# Pickle injected objects
for inj_obj in self.inj_objs:
self.save(inj_obj) # type: ignore[attr-defined]

# Pickle user-supplied data
self.save(obj) # type: ignore[attr-defined]

# Inject the object(s) after the user-supplied data?
if not self.first:
# Pickle injected objects
for inj_obj in self.inj_objs:
self.save(inj_obj) # type: ignore[attr-defined]

self.write(pickle.STOP) # type: ignore[attr-defined]
self.framer.end_framing() # type: ignore[attr-defined]

def Pickler(self, file: Any, protocol: Any) -> _Pickler:
# Initialise the pickler interface with the injected object
return self._Pickler(file, protocol, self.inj_objs)


class _PickleInject:
"""Base class for pickling injected commands."""

Expand Down Expand Up @@ -69,99 +126,43 @@ def __reduce__(self) -> tuple[Any, ...]:
return self.command, (self.args, {})


class PicklePayload(Enum):
"""Enum for different Pickle Injection Payloads."""

SYSTEM = SystemInject
EXEC = ExecInject
EVAL = EvalInject
RUNPY = RunPyInject


class PickleInject:
"""Pickle injection. Pretends to be a "module" to work with torch."""

def __init__(self, inj_objs: Any, first: bool = True) -> None:
self.__name__ = "pickle_inject"
self.inj_objs = inj_objs
self.first = first

class _Pickler(pickle.Pickler):
"""Re-implementation of Pickler with support for injection"""

def __init__(
self,
file: SupportsWrite[bytes],
protocol: None | int,
inj_objs: Any,
first: bool = True,
) -> None:
"""Initialise the pickler with injected objects.
Args:
file: File object with write attribute.
protocol: Pickle protocol - Currently the default protocol is 4: https://docs.python.org/3/library/pickle.html.
inj_objs: _joblibInject object that has both the command, and the code to be injected.
first: Boolean object to determine if inj_objs should be serialized before the safe file or after the safe file.
"""
super().__init__(file, protocol)
self.inj_objs = inj_objs
self.first = first

def dump(self, obj: Any) -> None:
"""Pickle data, inject object before or after."""
if self.proto >= 2:
self.write(pickle.PROTO + struct.pack("<B", self.proto))
if self.proto >= 4:
self.framer.start_framing()

# Inject the object(s) before the user-supplied data?
if self.first:
# Pickle injected objects
for inj_obj in self.inj_objs:
self.save(inj_obj)

# Pickle user-supplied data
self.save(obj)

# Inject the object(s) after the user-supplied data?
if not self.first:
# Pickle injected objects
for inj_obj in self.inj_objs:
self.save(inj_obj)

self.write(pickle.STOP)
self.framer.end_framing()

def Pickler(self, file, protocol): # pylint: disable=protected-access
"""Initialise the pickler interface with the injected object."""
return self._Pickler(file, protocol, self.inj_objs)


def get_inject_payload(command: str, malicious_code: str) -> _PickleInject:
def get_inject_payload(
command: str, malicious_code: str
) -> SystemInject | ExecInject | EvalInject | RunPyInject:
"""Get the payload for the pickle injection.
Args:
command: The command to be injected.
malicious_code: The code to be injected.
Returns:
_PickleInject: The payload for the pickle injection.
PickleInject object.
Raises:
ValueError: If the command is not supported.
"""
pickle_inject = PicklePayload[command.upper()].value
return pickle_inject(malicious_code)
if command == "system":
return SystemInject(malicious_code)
if command == "exec":
return ExecInject(malicious_code)
if command == "eval":
return EvalInject(malicious_code)
if command == "runpy":
return RunPyInject(malicious_code)
else:
raise ValueError(f"Invalid command: {command}")


def generate_unsafe_pickle_file(
safe_model, command: str, malicious_code: str, unsafe_model_path: str
safe_model: Any, command: str, malicious_code: str, unsafe_model_path: str
) -> None:
"""Create an unsafe pickled file with injected code.
Args:
safe_model: _description_
command: _description_
malicious_code: _description_
unsafe_model_path: _description_
safe_model: Safe model to be pickled.
command: The command to be injected.
malicious_code: The malicious code to be injected.
unsafe_model_path: Path to save the unsafe model.
"""
payload = get_inject_payload(command, malicious_code)
pickle_protocol = 4
Expand All @@ -180,16 +181,16 @@ def __init__(self, inj_objs: Any, first: bool = True):
self.inj_objs = inj_objs
self.first = first

class _Pickler(dill.Pickler):
"""Re-implementation of Pickler with support for injection"""
class _Pickler(dill._dill.Pickler): # type: ignore[misc]
"""Reimplementation of Pickler with support for injection"""

def __init__(
self,
file: SupportsWrite[bytes],
protocol: int | None,
inj_objs: Any,
first: bool = True,
) -> None:
):
super().__init__(file, protocol)
self.inj_objs = inj_objs
self.first = first
Expand Down Expand Up @@ -219,7 +220,7 @@ def dump(self, obj: Any) -> None:
self.write(pickle.STOP)
self.framer.end_framing()

def DillPickler(self, file: Any, protocol: None | int) -> _Pickler:
def DillPickler(self, file: Any, protocol: Any) -> _Pickler:
# Initialise the pickler interface with the injected object
return self._Pickler(file, protocol, self.inj_objs)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
from modelscan.settings import DEFAULT_SETTINGS
from modelscan.skip import SkipCategories
from modelscan.tools.picklescanner import scan_pickle_bytes
from tensorflow import keras

from tests.pickle_utils.codeinjection import (
generate_dill_unsafe_file,
generate_unsafe_pickle_file,
)
from tensorflow import keras

from tests.test_utils import MaliciousModule, PyTorchTestModel

settings: dict[str, Any] = DEFAULT_SETTINGS
Expand Down

0 comments on commit 3e0f975

Please sign in to comment.