Skip to content

Commit

Permalink
Merge pull request #4 from mosquito/bugfixes
Browse files Browse the repository at this point in the history
Refactor JWT Class into JWTDecoder and JWTSigner, Improve Type Safety and Test Coverage
  • Loading branch information
mosquito authored Jan 6, 2025
2 parents 9baab51 + a213617 commit 5ecd789
Show file tree
Hide file tree
Showing 11 changed files with 213 additions and 174 deletions.
26 changes: 26 additions & 0 deletions .github/workflows/pythonpackage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,32 @@ on:
branches: [ master ]

jobs:
mypy:
runs-on: ubuntu-latest
strategy:
fail-fast: false

steps:
- uses: actions/checkout@v2

- name: Setup python3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"

- name: Install poetry
run: python -m pip install poetry

- name: Install dependencies
run: poetry install
env:
FORCE_COLOR: yes

- name: Run mypy
run: poetry run mypy jwt_rsa
env:
FORCE_COLOR: yes

tests:
runs-on: ubuntu-latest
strategy:
Expand Down
4 changes: 3 additions & 1 deletion jwt_rsa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
RSAJWKPrivateKey, RSAJWKPublicKey, generate_rsa, load_private_key,
load_public_key, rsa_to_jwk,
)
from .token import JWT
from .token import JWT, JWTDecoder, JWTSigner
from .types import RSAPrivateKey, RSAPublicKey


__all__ = (
"JWT",
"JWTDecoder",
"JWTSigner",
"RSAJWKPrivateKey",
"RSAJWKPublicKey",
"RSAPrivateKey",
Expand Down
5 changes: 2 additions & 3 deletions jwt_rsa/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from argparse import ArgumentParser
from pathlib import Path

from jwt_rsa.types import AlgorithmType

from . import convert, issue, key_tester, keygen, pubkey, verify
from .token import ALGORITHMS


parser = ArgumentParser()
Expand All @@ -20,7 +19,7 @@
"--kid", dest="kid", type=str, default="", help="Key ID, will be generated if missing",
)
keygen_parser.add_argument(
"-a", "--algorithm", choices=AlgorithmType.__args__,
"-a", "--algorithm", choices=ALGORITHMS,
help="Key ID, will be generated if missing", default="RS512",
)
keygen_parser.add_argument("-u", "--use", dest="use", type=str, default="sig", choices=["sig", "enc"])
Expand Down
2 changes: 1 addition & 1 deletion jwt_rsa/issue.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@


def main(arguments: SimpleNamespace) -> None:
jwt = JWT(private_key=load_private_key(arguments.private_key))
jwt = JWT(load_private_key(arguments.private_key))

whoami = pwd.getpwuid(os.getuid())

Expand Down
35 changes: 22 additions & 13 deletions jwt_rsa/rsa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import json
from pathlib import Path
from typing import NamedTuple, Optional, TypedDict, Union, overload
from typing import NamedTuple, Optional, TypedDict, overload

from cryptography.hazmat.backends import default_backend

Expand All @@ -11,6 +11,11 @@


class KeyPair(NamedTuple):
private: RSAPrivateKey
public: RSAPublicKey


class JWKKeyPair(NamedTuple):
private: Optional[RSAPrivateKey]
public: RSAPublicKey

Expand Down Expand Up @@ -80,8 +85,8 @@ def load_jwk_private_key(jwk: RSAJWKPrivateKey) -> RSAPrivateKey:
return private_numbers.private_key(default_backend())


def load_jwk(jwk: Union[RSAJWKPublicKey, RSAJWKPrivateKey, str]) -> KeyPair:
jwk_dict: Union[RSAJWKPublicKey, RSAJWKPrivateKey]
def load_jwk(jwk: RSAJWKPublicKey | RSAJWKPrivateKey | str) -> JWKKeyPair:
jwk_dict: RSAJWKPublicKey | RSAJWKPrivateKey

if isinstance(jwk, str):
jwk_dict = json.loads(jwk)
Expand All @@ -92,10 +97,10 @@ def load_jwk(jwk: Union[RSAJWKPublicKey, RSAJWKPrivateKey, str]) -> KeyPair:
private_key = load_jwk_private_key(jwk_dict) # type: ignore
public_key = private_key.public_key()
else: # Public key
public_key = load_jwk_public_key(jwk_dict) # type: ignore
public_key = load_jwk_public_key(jwk_dict)
private_key = None

return KeyPair(private=private_key, public=public_key)
return JWKKeyPair(private=private_key, public=public_key)


def int_to_base64url(value: int) -> str:
Expand All @@ -106,24 +111,24 @@ def int_to_base64url(value: int) -> str:

@overload
def rsa_to_jwk(
key: RSAPublicKey, *, kid: str = "", alg: AlgorithmType = "RS256", use: str = "sig"
key: RSAPublicKey, *, kid: str = "", alg: AlgorithmType = "RS256", use: str = "sig",
) -> RSAJWKPublicKey: ...


@overload
def rsa_to_jwk( # type: ignore[overload-cannot-match]
def rsa_to_jwk(
key: RSAPrivateKey, *, kid: str = "", alg: AlgorithmType = "RS256", use: str = "sig",
) -> RSAJWKPrivateKey: ...


def rsa_to_jwk(
key: Union[RSAPrivateKey, RSAPublicKey],
key: RSAPrivateKey | RSAPublicKey,
*,
kid: str = "",
alg: AlgorithmType = "RS256",
use: str = "sig",
kty: str = "RSA",
) -> Union[RSAJWKPublicKey, RSAJWKPrivateKey]:
) -> RSAJWKPublicKey | RSAJWKPrivateKey:
if isinstance(key, RSAPublicKey):
public_numbers = key.public_numbers()
private_numbers = None
Expand Down Expand Up @@ -161,12 +166,14 @@ def rsa_to_jwk(
)


def load_private_key(data: Union[str, RSAJWKPrivateKey, Path]) -> RSAPrivateKey:
def load_private_key(data: str | RSAJWKPrivateKey | Path) -> RSAPrivateKey:
if isinstance(data, Path):
data = data.read_text()
if isinstance(data, str):
if data.startswith("-----BEGIN "):
return serialization.load_pem_private_key(data.encode(), None, default_backend())
result = serialization.load_pem_private_key(data.encode(), None, default_backend())
assert isinstance(result, RSAPrivateKey)
return result
if data.strip().startswith("{"):
return load_jwk_private_key(json.loads(data))
if isinstance(data, dict):
Expand All @@ -177,12 +184,14 @@ def load_private_key(data: Union[str, RSAJWKPrivateKey, Path]) -> RSAPrivateKey:
return key


def load_public_key(data: Union[str, RSAJWKPublicKey, Path]) -> RSAPublicKey:
def load_public_key(data: str | RSAJWKPublicKey | Path) -> RSAPublicKey:
if isinstance(data, Path):
data = data.read_text()
if isinstance(data, str):
if data.startswith("-----BEGIN "):
return serialization.load_pem_public_key(data.encode(), default_backend())
result = serialization.load_pem_public_key(data.encode(), default_backend())
assert isinstance(result, RSAPublicKey)
return result
if data.strip().startswith("{"):
return load_jwk_public_key(json.loads(data))
if isinstance(data, dict):
Expand Down
Loading

0 comments on commit 5ecd789

Please sign in to comment.