back to Reference (Gold) summary
Reference (Gold): pyjwt
Pytest summary for the `tests` directory
status | count |
---|---|
passed | 259 |
skipped | 2 |
xfailed | 1 |
total | 262 |
collected | 262 |
Failed pytests (the entry below is the expected failure counted as "xfailed" above; no tests actually failed):
test_utils.py::test_to_base64url_uint[-1-]
test_utils.py::test_to_base64url_uint[-1-]
```
inputval = -1, expected = ''

    @pytest.mark.parametrize(
        "inputval,expected",
        [
            (0, b"AA"),
            (1, b"AQ"),
            (255, b"_w"),
            (65537, b"AQAB"),
            (123456789, b"B1vNFQ"),
            pytest.param(-1, "", marks=pytest.mark.xfail(raises=ValueError)),
        ],
    )
    def test_to_base64url_uint(inputval, expected):
>       actual = to_base64url_uint(inputval)

tests/test_utils.py:18:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

val = -1

    def to_base64url_uint(val: int) -> bytes:
        if val < 0:
>           raise ValueError("Must be a positive integer")
E           ValueError: Must be a positive integer

jwt/utils.py:42: ValueError
```
Patch diff
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 15c200a..ed18715 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -1,49 +1,146 @@
from __future__ import annotations
+
import hashlib
import hmac
import json
import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload
+
from .exceptions import InvalidKeyError
from .types import HashlibHash, JWKDict
-from .utils import base64url_decode, base64url_encode, der_to_raw_signature, force_bytes, from_base64url_uint, is_pem_format, is_ssh_key, raw_to_der_signature, to_base64url_uint
+from .utils import (
+ base64url_decode,
+ base64url_encode,
+ der_to_raw_signature,
+ force_bytes,
+ from_base64url_uint,
+ is_pem_format,
+ is_ssh_key,
+ raw_to_der_signature,
+ to_base64url_uint,
+)
+
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
+
+
try:
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
- from cryptography.hazmat.primitives.asymmetric.ec import ECDSA, SECP256K1, SECP256R1, SECP384R1, SECP521R1, EllipticCurve, EllipticCurvePrivateKey, EllipticCurvePrivateNumbers, EllipticCurvePublicKey, EllipticCurvePublicNumbers
- from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey, Ed448PublicKey
- from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey
- from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPrivateNumbers, RSAPublicKey, RSAPublicNumbers, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp, rsa_recover_prime_factors
- from cryptography.hazmat.primitives.serialization import Encoding, NoEncryption, PrivateFormat, PublicFormat, load_pem_private_key, load_pem_public_key, load_ssh_public_key
+ from cryptography.hazmat.primitives.asymmetric.ec import (
+ ECDSA,
+ SECP256K1,
+ SECP256R1,
+ SECP384R1,
+ SECP521R1,
+ EllipticCurve,
+ EllipticCurvePrivateKey,
+ EllipticCurvePrivateNumbers,
+ EllipticCurvePublicKey,
+ EllipticCurvePublicNumbers,
+ )
+ from cryptography.hazmat.primitives.asymmetric.ed448 import (
+ Ed448PrivateKey,
+ Ed448PublicKey,
+ )
+ from cryptography.hazmat.primitives.asymmetric.ed25519 import (
+ Ed25519PrivateKey,
+ Ed25519PublicKey,
+ )
+ from cryptography.hazmat.primitives.asymmetric.rsa import (
+ RSAPrivateKey,
+ RSAPrivateNumbers,
+ RSAPublicKey,
+ RSAPublicNumbers,
+ rsa_crt_dmp1,
+ rsa_crt_dmq1,
+ rsa_crt_iqmp,
+ rsa_recover_prime_factors,
+ )
+ from cryptography.hazmat.primitives.serialization import (
+ Encoding,
+ NoEncryption,
+ PrivateFormat,
+ PublicFormat,
+ load_pem_private_key,
+ load_pem_public_key,
+ load_ssh_public_key,
+ )
+
has_crypto = True
except ModuleNotFoundError:
has_crypto = False
+
+
if TYPE_CHECKING:
+ # Type aliases for convenience in algorithms method signatures
AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
- AllowedOKPKeys = (Ed25519PrivateKey | Ed25519PublicKey |
- Ed448PrivateKey | Ed448PublicKey)
+ AllowedOKPKeys = (
+ Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
+ )
AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
- AllowedPrivateKeys = (RSAPrivateKey | EllipticCurvePrivateKey |
- Ed25519PrivateKey | Ed448PrivateKey)
- AllowedPublicKeys = (RSAPublicKey | EllipticCurvePublicKey |
- Ed25519PublicKey | Ed448PublicKey)
-requires_cryptography = {'RS256', 'RS384', 'RS512', 'ES256', 'ES256K',
- 'ES384', 'ES521', 'ES512', 'PS256', 'PS384', 'PS512', 'EdDSA'}
-
-
-def get_default_algorithms() ->dict[str, Algorithm]:
+ AllowedPrivateKeys = (
+ RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
+ )
+ AllowedPublicKeys = (
+ RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
+ )
+
+
+requires_cryptography = {
+ "RS256",
+ "RS384",
+ "RS512",
+ "ES256",
+ "ES256K",
+ "ES384",
+ "ES521",
+ "ES512",
+ "PS256",
+ "PS384",
+ "PS512",
+ "EdDSA",
+}
+
+
+def get_default_algorithms() -> dict[str, Algorithm]:
"""
Returns the algorithms that are implemented by the library.
"""
- pass
+ default_algorithms = {
+ "none": NoneAlgorithm(),
+ "HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
+ "HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
+ "HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
+ }
+
+ if has_crypto:
+ default_algorithms.update(
+ {
+ "RS256": RSAAlgorithm(RSAAlgorithm.SHA256),
+ "RS384": RSAAlgorithm(RSAAlgorithm.SHA384),
+ "RS512": RSAAlgorithm(RSAAlgorithm.SHA512),
+ "ES256": ECAlgorithm(ECAlgorithm.SHA256),
+ "ES256K": ECAlgorithm(ECAlgorithm.SHA256),
+ "ES384": ECAlgorithm(ECAlgorithm.SHA384),
+ "ES521": ECAlgorithm(ECAlgorithm.SHA512),
+ "ES512": ECAlgorithm(
+ ECAlgorithm.SHA512
+ ), # Backward compat for #219 fix
+ "PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
+ "PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
+ "PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
+ "EdDSA": OKPAlgorithm(),
+ }
+ )
+
+ return default_algorithms
class Algorithm(ABC):
@@ -51,53 +148,74 @@ class Algorithm(ABC):
The interface for an algorithm used to sign and verify tokens.
"""
- def compute_hash_digest(self, bytestr: bytes) ->bytes:
+ def compute_hash_digest(self, bytestr: bytes) -> bytes:
"""
Compute a hash digest using the specified algorithm's hash algorithm.
If there is no hash algorithm, raises a NotImplementedError.
"""
- pass
+ # lookup self.hash_alg if defined in a way that mypy can understand
+ hash_alg = getattr(self, "hash_alg", None)
+ if hash_alg is None:
+ raise NotImplementedError
+
+ if (
+ has_crypto
+ and isinstance(hash_alg, type)
+ and issubclass(hash_alg, hashes.HashAlgorithm)
+ ):
+ digest = hashes.Hash(hash_alg(), backend=default_backend())
+ digest.update(bytestr)
+ return bytes(digest.finalize())
+ else:
+ return bytes(hash_alg(bytestr).digest())
@abstractmethod
- def prepare_key(self, key: Any) ->Any:
+ def prepare_key(self, key: Any) -> Any:
"""
Performs necessary validation and conversions on the key and returns
the key value in the proper format for sign() and verify().
"""
- pass
@abstractmethod
- def sign(self, msg: bytes, key: Any) ->bytes:
+ def sign(self, msg: bytes, key: Any) -> bytes:
"""
Returns a digital signature for the specified message
using the specified key value.
"""
- pass
@abstractmethod
- def verify(self, msg: bytes, key: Any, sig: bytes) ->bool:
+ def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
"""
Verifies that the specified digital signature is valid
for the specified message and key values.
"""
- pass
+
+ @overload
+ @staticmethod
+ @abstractmethod
+ def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict:
+ ... # pragma: no cover
+
+ @overload
+ @staticmethod
+ @abstractmethod
+ def to_jwk(key_obj, as_dict: Literal[False] = False) -> str:
+ ... # pragma: no cover
@staticmethod
@abstractmethod
- def to_jwk(key_obj, as_dict: bool=False) ->Union[JWKDict, str]:
+ def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]:
"""
Serializes a given key into a JWK
"""
- pass
@staticmethod
@abstractmethod
- def from_jwk(jwk: (str | JWKDict)) ->Any:
+ def from_jwk(jwk: str | JWKDict) -> Any:
"""
Deserializes a given key from JWK back into a key object
"""
- pass
class NoneAlgorithm(Algorithm):
@@ -106,54 +224,478 @@ class NoneAlgorithm(Algorithm):
operations are required.
"""
+ def prepare_key(self, key: str | None) -> None:
+ if key == "":
+ key = None
+
+ if key is not None:
+ raise InvalidKeyError('When alg = "none", key value must be None.')
+
+ return key
+
+ def sign(self, msg: bytes, key: None) -> bytes:
+ return b""
+
+ def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
+ return False
+
+ @staticmethod
+ def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
+ raise NotImplementedError()
+
+ @staticmethod
+ def from_jwk(jwk: str | JWKDict) -> NoReturn:
+ raise NotImplementedError()
+
class HMACAlgorithm(Algorithm):
"""
Performs signing and verification operations using HMAC
and the specified hash function.
"""
+
SHA256: ClassVar[HashlibHash] = hashlib.sha256
SHA384: ClassVar[HashlibHash] = hashlib.sha384
SHA512: ClassVar[HashlibHash] = hashlib.sha512
- def __init__(self, hash_alg: HashlibHash) ->None:
+ def __init__(self, hash_alg: HashlibHash) -> None:
self.hash_alg = hash_alg
+ def prepare_key(self, key: str | bytes) -> bytes:
+ key_bytes = force_bytes(key)
-if has_crypto:
+ if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
+ raise InvalidKeyError(
+ "The specified key is an asymmetric key or x509 certificate and"
+ " should not be used as an HMAC secret."
+ )
+
+ return key_bytes
+
+ @overload
+ @staticmethod
+ def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict:
+ ... # pragma: no cover
+ @overload
+ @staticmethod
+ def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str:
+ ... # pragma: no cover
+
+ @staticmethod
+ def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]:
+ jwk = {
+ "k": base64url_encode(force_bytes(key_obj)).decode(),
+ "kty": "oct",
+ }
+
+ if as_dict:
+ return jwk
+ else:
+ return json.dumps(jwk)
+
+ @staticmethod
+ def from_jwk(jwk: str | JWKDict) -> bytes:
+ try:
+ if isinstance(jwk, str):
+ obj: JWKDict = json.loads(jwk)
+ elif isinstance(jwk, dict):
+ obj = jwk
+ else:
+ raise ValueError
+ except ValueError:
+ raise InvalidKeyError("Key is not valid JSON")
+
+ if obj.get("kty") != "oct":
+ raise InvalidKeyError("Not an HMAC key")
+
+ return base64url_decode(obj["k"])
+
+ def sign(self, msg: bytes, key: bytes) -> bytes:
+ return hmac.new(key, msg, self.hash_alg).digest()
+
+ def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
+ return hmac.compare_digest(sig, self.sign(msg, key))
+
+
+if has_crypto:
class RSAAlgorithm(Algorithm):
"""
Performs signing and verification operations using
RSASSA-PKCS-v1_5 and the specified hash function.
"""
+
SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
- def __init__(self, hash_alg: type[hashes.HashAlgorithm]) ->None:
+ def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg
+ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
+ if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
+ return key
+
+ if not isinstance(key, (bytes, str)):
+ raise TypeError("Expecting a PEM-formatted key.")
+
+ key_bytes = force_bytes(key)
+
+ try:
+ if key_bytes.startswith(b"ssh-rsa"):
+ return cast(RSAPublicKey, load_ssh_public_key(key_bytes))
+ else:
+ return cast(
+ RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
+ )
+ except ValueError:
+ return cast(RSAPublicKey, load_pem_public_key(key_bytes))
+
+ @overload
+ @staticmethod
+ def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict:
+ ... # pragma: no cover
+
+ @overload
+ @staticmethod
+ def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str:
+ ... # pragma: no cover
+
+ @staticmethod
+ def to_jwk(
+ key_obj: AllowedRSAKeys, as_dict: bool = False
+ ) -> Union[JWKDict, str]:
+ obj: dict[str, Any] | None = None
+
+ if hasattr(key_obj, "private_numbers"):
+ # Private key
+ numbers = key_obj.private_numbers()
+
+ obj = {
+ "kty": "RSA",
+ "key_ops": ["sign"],
+ "n": to_base64url_uint(numbers.public_numbers.n).decode(),
+ "e": to_base64url_uint(numbers.public_numbers.e).decode(),
+ "d": to_base64url_uint(numbers.d).decode(),
+ "p": to_base64url_uint(numbers.p).decode(),
+ "q": to_base64url_uint(numbers.q).decode(),
+ "dp": to_base64url_uint(numbers.dmp1).decode(),
+ "dq": to_base64url_uint(numbers.dmq1).decode(),
+ "qi": to_base64url_uint(numbers.iqmp).decode(),
+ }
+
+ elif hasattr(key_obj, "verify"):
+ # Public key
+ numbers = key_obj.public_numbers()
+
+ obj = {
+ "kty": "RSA",
+ "key_ops": ["verify"],
+ "n": to_base64url_uint(numbers.n).decode(),
+ "e": to_base64url_uint(numbers.e).decode(),
+ }
+ else:
+ raise InvalidKeyError("Not a public or private key")
+
+ if as_dict:
+ return obj
+ else:
+ return json.dumps(obj)
+
+ @staticmethod
+ def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
+ try:
+ if isinstance(jwk, str):
+ obj = json.loads(jwk)
+ elif isinstance(jwk, dict):
+ obj = jwk
+ else:
+ raise ValueError
+ except ValueError:
+ raise InvalidKeyError("Key is not valid JSON")
+
+ if obj.get("kty") != "RSA":
+ raise InvalidKeyError("Not an RSA key")
+
+ if "d" in obj and "e" in obj and "n" in obj:
+ # Private key
+ if "oth" in obj:
+ raise InvalidKeyError(
+ "Unsupported RSA private key: > 2 primes not supported"
+ )
+
+ other_props = ["p", "q", "dp", "dq", "qi"]
+ props_found = [prop in obj for prop in other_props]
+ any_props_found = any(props_found)
+
+ if any_props_found and not all(props_found):
+ raise InvalidKeyError(
+ "RSA key must include all parameters if any are present besides d"
+ )
+
+ public_numbers = RSAPublicNumbers(
+ from_base64url_uint(obj["e"]),
+ from_base64url_uint(obj["n"]),
+ )
+
+ if any_props_found:
+ numbers = RSAPrivateNumbers(
+ d=from_base64url_uint(obj["d"]),
+ p=from_base64url_uint(obj["p"]),
+ q=from_base64url_uint(obj["q"]),
+ dmp1=from_base64url_uint(obj["dp"]),
+ dmq1=from_base64url_uint(obj["dq"]),
+ iqmp=from_base64url_uint(obj["qi"]),
+ public_numbers=public_numbers,
+ )
+ else:
+ d = from_base64url_uint(obj["d"])
+ p, q = rsa_recover_prime_factors(
+ public_numbers.n, d, public_numbers.e
+ )
+
+ numbers = RSAPrivateNumbers(
+ d=d,
+ p=p,
+ q=q,
+ dmp1=rsa_crt_dmp1(d, p),
+ dmq1=rsa_crt_dmq1(d, q),
+ iqmp=rsa_crt_iqmp(p, q),
+ public_numbers=public_numbers,
+ )
+
+ return numbers.private_key()
+ elif "n" in obj and "e" in obj:
+ # Public key
+ return RSAPublicNumbers(
+ from_base64url_uint(obj["e"]),
+ from_base64url_uint(obj["n"]),
+ ).public_key()
+ else:
+ raise InvalidKeyError("Not a public or private key")
+
+ def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
+ return key.sign(msg, padding.PKCS1v15(), self.hash_alg())
+
+ def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
+ try:
+ key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
+ return True
+ except InvalidSignature:
+ return False
class ECAlgorithm(Algorithm):
"""
Performs signing and verification operations using
ECDSA and the specified hash function
"""
+
SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
- def __init__(self, hash_alg: type[hashes.HashAlgorithm]) ->None:
+ def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg
+ def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
+ if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
+ return key
+
+ if not isinstance(key, (bytes, str)):
+ raise TypeError("Expecting a PEM-formatted key.")
+
+ key_bytes = force_bytes(key)
+
+ # Attempt to load key. We don't know if it's
+ # a Signing Key or a Verifying Key, so we try
+ # the Verifying Key first.
+ try:
+ if key_bytes.startswith(b"ecdsa-sha2-"):
+ crypto_key = load_ssh_public_key(key_bytes)
+ else:
+ crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment]
+ except ValueError:
+ crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
+
+ # Explicit check the key to prevent confusing errors from cryptography
+ if not isinstance(
+ crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
+ ):
+ raise InvalidKeyError(
+ "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
+ )
+
+ return crypto_key
+
+ def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
+ der_sig = key.sign(msg, ECDSA(self.hash_alg()))
+
+ return der_to_raw_signature(der_sig, key.curve)
+
+ def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool:
+ try:
+ der_sig = raw_to_der_signature(sig, key.curve)
+ except ValueError:
+ return False
+
+ try:
+ public_key = (
+ key.public_key()
+ if isinstance(key, EllipticCurvePrivateKey)
+ else key
+ )
+ public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
+ return True
+ except InvalidSignature:
+ return False
+
+ @overload
+ @staticmethod
+ def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict:
+ ... # pragma: no cover
+
+ @overload
+ @staticmethod
+ def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str:
+ ... # pragma: no cover
+
+ @staticmethod
+ def to_jwk(
+ key_obj: AllowedECKeys, as_dict: bool = False
+ ) -> Union[JWKDict, str]:
+ if isinstance(key_obj, EllipticCurvePrivateKey):
+ public_numbers = key_obj.public_key().public_numbers()
+ elif isinstance(key_obj, EllipticCurvePublicKey):
+ public_numbers = key_obj.public_numbers()
+ else:
+ raise InvalidKeyError("Not a public or private key")
+
+ if isinstance(key_obj.curve, SECP256R1):
+ crv = "P-256"
+ elif isinstance(key_obj.curve, SECP384R1):
+ crv = "P-384"
+ elif isinstance(key_obj.curve, SECP521R1):
+ crv = "P-521"
+ elif isinstance(key_obj.curve, SECP256K1):
+ crv = "secp256k1"
+ else:
+ raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
+
+ obj: dict[str, Any] = {
+ "kty": "EC",
+ "crv": crv,
+ "x": to_base64url_uint(public_numbers.x).decode(),
+ "y": to_base64url_uint(public_numbers.y).decode(),
+ }
+
+ if isinstance(key_obj, EllipticCurvePrivateKey):
+ obj["d"] = to_base64url_uint(
+ key_obj.private_numbers().private_value
+ ).decode()
+
+ if as_dict:
+ return obj
+ else:
+ return json.dumps(obj)
+
+ @staticmethod
+ def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
+ try:
+ if isinstance(jwk, str):
+ obj = json.loads(jwk)
+ elif isinstance(jwk, dict):
+ obj = jwk
+ else:
+ raise ValueError
+ except ValueError:
+ raise InvalidKeyError("Key is not valid JSON")
+
+ if obj.get("kty") != "EC":
+ raise InvalidKeyError("Not an Elliptic curve key")
+
+ if "x" not in obj or "y" not in obj:
+ raise InvalidKeyError("Not an Elliptic curve key")
+
+ x = base64url_decode(obj.get("x"))
+ y = base64url_decode(obj.get("y"))
+
+ curve = obj.get("crv")
+ curve_obj: EllipticCurve
+
+ if curve == "P-256":
+ if len(x) == len(y) == 32:
+ curve_obj = SECP256R1()
+ else:
+ raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
+ elif curve == "P-384":
+ if len(x) == len(y) == 48:
+ curve_obj = SECP384R1()
+ else:
+ raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
+ elif curve == "P-521":
+ if len(x) == len(y) == 66:
+ curve_obj = SECP521R1()
+ else:
+ raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
+ elif curve == "secp256k1":
+ if len(x) == len(y) == 32:
+ curve_obj = SECP256K1()
+ else:
+ raise InvalidKeyError(
+ "Coords should be 32 bytes for curve secp256k1"
+ )
+ else:
+ raise InvalidKeyError(f"Invalid curve: {curve}")
+
+ public_numbers = EllipticCurvePublicNumbers(
+ x=int.from_bytes(x, byteorder="big"),
+ y=int.from_bytes(y, byteorder="big"),
+ curve=curve_obj,
+ )
+
+ if "d" not in obj:
+ return public_numbers.public_key()
+
+ d = base64url_decode(obj.get("d"))
+ if len(d) != len(x):
+ raise InvalidKeyError(
+ "D should be {} bytes for curve {}", len(x), curve
+ )
+
+ return EllipticCurvePrivateNumbers(
+ int.from_bytes(d, byteorder="big"), public_numbers
+ ).private_key()
class RSAPSSAlgorithm(RSAAlgorithm):
"""
Performs a signature using RSASSA-PSS with MGF1
"""
+ def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
+ return key.sign(
+ msg,
+ padding.PSS(
+ mgf=padding.MGF1(self.hash_alg()),
+ salt_length=self.hash_alg().digest_size,
+ ),
+ self.hash_alg(),
+ )
+
+ def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
+ try:
+ key.verify(
+ sig,
+ msg,
+ padding.PSS(
+ mgf=padding.MGF1(self.hash_alg()),
+ salt_length=self.hash_alg().digest_size,
+ ),
+ self.hash_alg(),
+ )
+ return True
+ except InvalidSignature:
+ return False
class OKPAlgorithm(Algorithm):
"""
@@ -162,11 +704,35 @@ if has_crypto:
This class requires ``cryptography>=2.6`` to be installed.
"""
- def __init__(self, **kwargs: Any) ->None:
+ def __init__(self, **kwargs: Any) -> None:
pass
- def sign(self, msg: (str | bytes), key: (Ed25519PrivateKey |
- Ed448PrivateKey)) ->bytes:
+ def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
+ if isinstance(key, (bytes, str)):
+ key_str = key.decode("utf-8") if isinstance(key, bytes) else key
+ key_bytes = key.encode("utf-8") if isinstance(key, str) else key
+
+ if "-----BEGIN PUBLIC" in key_str:
+ key = load_pem_public_key(key_bytes) # type: ignore[assignment]
+ elif "-----BEGIN PRIVATE" in key_str:
+ key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
+ elif key_str[0:4] == "ssh-":
+ key = load_ssh_public_key(key_bytes) # type: ignore[assignment]
+
+ # Explicit check the key to prevent confusing errors from cryptography
+ if not isinstance(
+ key,
+ (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
+ ):
+ raise InvalidKeyError(
+ "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for EdDSA algorithms"
+ )
+
+ return key
+
+ def sign(
+ self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
+ ) -> bytes:
"""
Sign a message ``msg`` using the EdDSA private key ``key``
:param str|bytes msg: Message to sign
@@ -174,10 +740,12 @@ if has_crypto:
or :class:`.Ed448PrivateKey` isinstance
:return bytes signature: The signature, as bytes
"""
- pass
+ msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
+ return key.sign(msg_bytes)
- def verify(self, msg: (str | bytes), key: AllowedOKPKeys, sig: (str |
- bytes)) ->bool:
+ def verify(
+ self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
+ ) -> bool:
"""
Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
@@ -187,4 +755,108 @@ if has_crypto:
A private or public EdDSA key instance
:return bool verified: True if signature is valid, False if not.
"""
- pass
+ try:
+ msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
+ sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig
+
+ public_key = (
+ key.public_key()
+ if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
+ else key
+ )
+ public_key.verify(sig_bytes, msg_bytes)
+ return True # If no exception was raised, the signature is valid.
+ except InvalidSignature:
+ return False
+
+ @overload
+ @staticmethod
+ def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict:
+ ... # pragma: no cover
+
+ @overload
+ @staticmethod
+ def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str:
+ ... # pragma: no cover
+
+ @staticmethod
+ def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]:
+ if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
+ x = key.public_bytes(
+ encoding=Encoding.Raw,
+ format=PublicFormat.Raw,
+ )
+ crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
+
+ obj = {
+ "x": base64url_encode(force_bytes(x)).decode(),
+ "kty": "OKP",
+ "crv": crv,
+ }
+
+ if as_dict:
+ return obj
+ else:
+ return json.dumps(obj)
+
+ if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
+ d = key.private_bytes(
+ encoding=Encoding.Raw,
+ format=PrivateFormat.Raw,
+ encryption_algorithm=NoEncryption(),
+ )
+
+ x = key.public_key().public_bytes(
+ encoding=Encoding.Raw,
+ format=PublicFormat.Raw,
+ )
+
+ crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
+ obj = {
+ "x": base64url_encode(force_bytes(x)).decode(),
+ "d": base64url_encode(force_bytes(d)).decode(),
+ "kty": "OKP",
+ "crv": crv,
+ }
+
+ if as_dict:
+ return obj
+ else:
+ return json.dumps(obj)
+
+ raise InvalidKeyError("Not a public or private key")
+
+ @staticmethod
+ def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
+ try:
+ if isinstance(jwk, str):
+ obj = json.loads(jwk)
+ elif isinstance(jwk, dict):
+ obj = jwk
+ else:
+ raise ValueError
+ except ValueError:
+ raise InvalidKeyError("Key is not valid JSON")
+
+ if obj.get("kty") != "OKP":
+ raise InvalidKeyError("Not an Octet Key Pair")
+
+ curve = obj.get("crv")
+ if curve != "Ed25519" and curve != "Ed448":
+ raise InvalidKeyError(f"Invalid curve: {curve}")
+
+ if "x" not in obj:
+ raise InvalidKeyError('OKP should have "x" parameter')
+ x = base64url_decode(obj.get("x"))
+
+ try:
+ if "d" not in obj:
+ if curve == "Ed25519":
+ return Ed25519PublicKey.from_public_bytes(x)
+ return Ed448PublicKey.from_public_bytes(x)
+ d = base64url_decode(obj.get("d"))
+ if curve == "Ed25519":
+ return Ed25519PrivateKey.from_private_bytes(d)
+ return Ed448PrivateKey.from_private_bytes(d)
+ except ValueError as err:
+ raise InvalidKeyError("Invalid key parameter") from err
diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py
index 175e782..456c7f4 100644
--- a/jwt/api_jwk.py
+++ b/jwt/api_jwk.py
@@ -1,86 +1,132 @@
from __future__ import annotations
+
import json
import time
from typing import Any
+
from .algorithms import get_default_algorithms, has_crypto, requires_cryptography
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError, PyJWTError
from .types import JWKDict
class PyJWK:
-
- def __init__(self, jwk_data: JWKDict, algorithm: (str | None)=None) ->None:
+ def __init__(self, jwk_data: JWKDict, algorithm: str | None = None) -> None:
self._algorithms = get_default_algorithms()
self._jwk_data = jwk_data
- kty = self._jwk_data.get('kty', None)
+
+ kty = self._jwk_data.get("kty", None)
if not kty:
- raise InvalidKeyError(f'kty is not found: {self._jwk_data}')
+ raise InvalidKeyError(f"kty is not found: {self._jwk_data}")
+
if not algorithm and isinstance(self._jwk_data, dict):
- algorithm = self._jwk_data.get('alg', None)
+ algorithm = self._jwk_data.get("alg", None)
+
if not algorithm:
- crv = self._jwk_data.get('crv', None)
- if kty == 'EC':
- if crv == 'P-256' or not crv:
- algorithm = 'ES256'
- elif crv == 'P-384':
- algorithm = 'ES384'
- elif crv == 'P-521':
- algorithm = 'ES512'
- elif crv == 'secp256k1':
- algorithm = 'ES256K'
+ # Determine alg with kty (and crv).
+ crv = self._jwk_data.get("crv", None)
+ if kty == "EC":
+ if crv == "P-256" or not crv:
+ algorithm = "ES256"
+ elif crv == "P-384":
+ algorithm = "ES384"
+ elif crv == "P-521":
+ algorithm = "ES512"
+ elif crv == "secp256k1":
+ algorithm = "ES256K"
else:
- raise InvalidKeyError(f'Unsupported crv: {crv}')
- elif kty == 'RSA':
- algorithm = 'RS256'
- elif kty == 'oct':
- algorithm = 'HS256'
- elif kty == 'OKP':
+ raise InvalidKeyError(f"Unsupported crv: {crv}")
+ elif kty == "RSA":
+ algorithm = "RS256"
+ elif kty == "oct":
+ algorithm = "HS256"
+ elif kty == "OKP":
if not crv:
- raise InvalidKeyError(f'crv is not found: {self._jwk_data}'
- )
- if crv == 'Ed25519':
- algorithm = 'EdDSA'
+ raise InvalidKeyError(f"crv is not found: {self._jwk_data}")
+ if crv == "Ed25519":
+ algorithm = "EdDSA"
else:
- raise InvalidKeyError(f'Unsupported crv: {crv}')
+ raise InvalidKeyError(f"Unsupported crv: {crv}")
else:
- raise InvalidKeyError(f'Unsupported kty: {kty}')
+ raise InvalidKeyError(f"Unsupported kty: {kty}")
+
if not has_crypto and algorithm in requires_cryptography:
- raise PyJWKError(
- f"{algorithm} requires 'cryptography' to be installed.")
+ raise PyJWKError(f"{algorithm} requires 'cryptography' to be installed.")
+
self.Algorithm = self._algorithms.get(algorithm)
+
if not self.Algorithm:
- raise PyJWKError(
- f'Unable to find an algorithm for key: {self._jwk_data}')
+ raise PyJWKError(f"Unable to find an algorithm for key: {self._jwk_data}")
+
self.key = self.Algorithm.from_jwk(self._jwk_data)
+ @staticmethod
+ def from_dict(obj: JWKDict, algorithm: str | None = None) -> "PyJWK":
+ return PyJWK(obj, algorithm)
-class PyJWKSet:
+ @staticmethod
+ def from_json(data: str, algorithm: None = None) -> "PyJWK":
+ obj = json.loads(data)
+ return PyJWK.from_dict(obj, algorithm)
+
+ @property
+ def key_type(self) -> str | None:
+ return self._jwk_data.get("kty", None)
+
+ @property
+ def key_id(self) -> str | None:
+ return self._jwk_data.get("kid", None)
- def __init__(self, keys: list[JWKDict]) ->None:
+ @property
+ def public_key_use(self) -> str | None:
+ return self._jwk_data.get("use", None)
+
+
+class PyJWKSet:
+ def __init__(self, keys: list[JWKDict]) -> None:
self.keys = []
+
if not keys:
- raise PyJWKSetError('The JWK Set did not contain any keys')
+ raise PyJWKSetError("The JWK Set did not contain any keys")
+
if not isinstance(keys, list):
- raise PyJWKSetError('Invalid JWK Set value')
+ raise PyJWKSetError("Invalid JWK Set value")
+
for key in keys:
try:
self.keys.append(PyJWK(key))
except PyJWTError:
+ # skip unusable keys
continue
+
if len(self.keys) == 0:
raise PyJWKSetError(
"The JWK Set did not contain any usable keys. Perhaps 'cryptography' is not installed?"
- )
+ )
+
+ @staticmethod
+ def from_dict(obj: dict[str, Any]) -> "PyJWKSet":
+ keys = obj.get("keys", [])
+ return PyJWKSet(keys)
+
+ @staticmethod
+ def from_json(data: str) -> "PyJWKSet":
+ obj = json.loads(data)
+ return PyJWKSet.from_dict(obj)
- def __getitem__(self, kid: str) ->'PyJWK':
+ def __getitem__(self, kid: str) -> "PyJWK":
for key in self.keys:
if key.key_id == kid:
return key
- raise KeyError(f'keyset has no key for kid: {kid}')
+ raise KeyError(f"keyset has no key for kid: {kid}")
class PyJWTSetWithTimestamp:
-
def __init__(self, jwk_set: PyJWKSet):
self.jwk_set = jwk_set
self.timestamp = time.monotonic()
+
+ def get_jwk_set(self) -> PyJWKSet:
+ return self.jwk_set
+
+ def get_timestamp(self) -> float:
+ return self.timestamp
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index eeb5924..fa6708c 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -1,51 +1,89 @@
from __future__ import annotations
+
import binascii
import json
import warnings
from typing import TYPE_CHECKING, Any
-from .algorithms import Algorithm, get_default_algorithms, has_crypto, requires_cryptography
-from .exceptions import DecodeError, InvalidAlgorithmError, InvalidSignatureError, InvalidTokenError
+
+from .algorithms import (
+ Algorithm,
+ get_default_algorithms,
+ has_crypto,
+ requires_cryptography,
+)
+from .exceptions import (
+ DecodeError,
+ InvalidAlgorithmError,
+ InvalidSignatureError,
+ InvalidTokenError,
+)
from .utils import base64url_decode, base64url_encode
from .warnings import RemovedInPyjwt3Warning
+
if TYPE_CHECKING:
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
class PyJWS:
- header_typ = 'JWT'
+ header_typ = "JWT"
- def __init__(self, algorithms: (list[str] | None)=None, options: (dict[
- str, Any] | None)=None) ->None:
+ def __init__(
+ self,
+ algorithms: list[str] | None = None,
+ options: dict[str, Any] | None = None,
+ ) -> None:
self._algorithms = get_default_algorithms()
- self._valid_algs = set(algorithms) if algorithms is not None else set(
- self._algorithms)
+ self._valid_algs = (
+ set(algorithms) if algorithms is not None else set(self._algorithms)
+ )
+
+ # Remove algorithms that aren't on the whitelist
for key in list(self._algorithms.keys()):
if key not in self._valid_algs:
del self._algorithms[key]
+
if options is None:
options = {}
self.options = {**self._get_default_options(), **options}
- def register_algorithm(self, alg_id: str, alg_obj: Algorithm) ->None:
+ @staticmethod
+ def _get_default_options() -> dict[str, bool]:
+ return {"verify_signature": True}
+
+ def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
"""
Registers a new Algorithm for use when creating and verifying tokens.
"""
- pass
+ if alg_id in self._algorithms:
+ raise ValueError("Algorithm already has a handler.")
+
+ if not isinstance(alg_obj, Algorithm):
+ raise TypeError("Object is not of type `Algorithm`")
- def unregister_algorithm(self, alg_id: str) ->None:
+ self._algorithms[alg_id] = alg_obj
+ self._valid_algs.add(alg_id)
+
+ def unregister_algorithm(self, alg_id: str) -> None:
"""
Unregisters an Algorithm for use when creating and verifying tokens
Throws KeyError if algorithm is not registered.
"""
- pass
+ if alg_id not in self._algorithms:
+ raise KeyError(
+ "The specified algorithm could not be removed"
+ " because it is not registered."
+ )
+
+ del self._algorithms[alg_id]
+ self._valid_algs.remove(alg_id)
- def get_algorithms(self) ->list[str]:
+ def get_algorithms(self) -> list[str]:
"""
Returns a list of supported values for the 'alg' parameter.
"""
- pass
+ return list(self._valid_algs)
- def get_algorithm_by_name(self, alg_name: str) ->Algorithm:
+ def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
"""
For a given string name, return the matching Algorithm object.
@@ -53,15 +91,231 @@ class PyJWS:
>>> jws_obj.get_algorithm_by_name("RS256")
"""
- pass
+ try:
+ return self._algorithms[alg_name]
+ except KeyError as e:
+ if not has_crypto and alg_name in requires_cryptography:
+ raise NotImplementedError(
+ f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
+ ) from e
+ raise NotImplementedError("Algorithm not supported") from e
+
+ def encode(
+ self,
+ payload: bytes,
+ key: AllowedPrivateKeys | str | bytes,
+ algorithm: str | None = "HS256",
+ headers: dict[str, Any] | None = None,
+ json_encoder: type[json.JSONEncoder] | None = None,
+ is_payload_detached: bool = False,
+ sort_headers: bool = True,
+ ) -> str:
+ segments = []
+
+ # declare a new var to narrow the type for type checkers
+ algorithm_: str = algorithm if algorithm is not None else "none"
+
+ # Prefer headers values if present to function parameters.
+ if headers:
+ headers_alg = headers.get("alg")
+ if headers_alg:
+ algorithm_ = headers["alg"]
+
+ headers_b64 = headers.get("b64")
+ if headers_b64 is False:
+ is_payload_detached = True
+
+ # Header
+ header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}
+
+ if headers:
+ self._validate_headers(headers)
+ header.update(headers)
+
+ if not header["typ"]:
+ del header["typ"]
+
+ if is_payload_detached:
+ header["b64"] = False
+ elif "b64" in header:
+ # True is the standard value for b64, so no need for it
+ del header["b64"]
+
+ json_header = json.dumps(
+ header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
+ ).encode()
+
+ segments.append(base64url_encode(json_header))
+
+ if is_payload_detached:
+ msg_payload = payload
+ else:
+ msg_payload = base64url_encode(payload)
+ segments.append(msg_payload)
+
+ # Segments
+ signing_input = b".".join(segments)
+
+ alg_obj = self.get_algorithm_by_name(algorithm_)
+ key = alg_obj.prepare_key(key)
+ signature = alg_obj.sign(signing_input, key)
+
+ segments.append(base64url_encode(signature))
+
+ # Don't put the payload content inside the encoded token when detached
+ if is_payload_detached:
+ segments[1] = b""
+ encoded_string = b".".join(segments)
+
+ return encoded_string.decode("utf-8")
- def get_unverified_header(self, jwt: (str | bytes)) ->dict[str, Any]:
+ def decode_complete(
+ self,
+ jwt: str | bytes,
+ key: AllowedPublicKeys | str | bytes = "",
+ algorithms: list[str] | None = None,
+ options: dict[str, Any] | None = None,
+ detached_payload: bytes | None = None,
+ **kwargs,
+ ) -> dict[str, Any]:
+ if kwargs:
+ warnings.warn(
+ "passing additional kwargs to decode_complete() is deprecated "
+ "and will be removed in pyjwt version 3. "
+ f"Unsupported kwargs: {tuple(kwargs.keys())}",
+ RemovedInPyjwt3Warning,
+ )
+ if options is None:
+ options = {}
+ merged_options = {**self.options, **options}
+ verify_signature = merged_options["verify_signature"]
+
+ if verify_signature and not algorithms:
+ raise DecodeError(
+ 'It is required that you pass in a value for the "algorithms" argument when calling decode().'
+ )
+
+ payload, signing_input, header, signature = self._load(jwt)
+
+ if header.get("b64", True) is False:
+ if detached_payload is None:
+ raise DecodeError(
+ 'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
+ )
+ payload = detached_payload
+ signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload])
+
+ if verify_signature:
+ self._verify_signature(signing_input, header, signature, key, algorithms)
+
+ return {
+ "payload": payload,
+ "header": header,
+ "signature": signature,
+ }
+
+ def decode(
+ self,
+ jwt: str | bytes,
+ key: AllowedPublicKeys | str | bytes = "",
+ algorithms: list[str] | None = None,
+ options: dict[str, Any] | None = None,
+ detached_payload: bytes | None = None,
+ **kwargs,
+ ) -> Any:
+ if kwargs:
+ warnings.warn(
+ "passing additional kwargs to decode() is deprecated "
+ "and will be removed in pyjwt version 3. "
+ f"Unsupported kwargs: {tuple(kwargs.keys())}",
+ RemovedInPyjwt3Warning,
+ )
+ decoded = self.decode_complete(
+ jwt, key, algorithms, options, detached_payload=detached_payload
+ )
+ return decoded["payload"]
+
+ def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
"""Returns back the JWT header parameters as a dict()
Note: The signature is not verified so the header parameters
should not be fully trusted until signature verification is complete
"""
- pass
+ headers = self._load(jwt)[2]
+ self._validate_headers(headers)
+
+ return headers
+
+ def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
+ if isinstance(jwt, str):
+ jwt = jwt.encode("utf-8")
+
+ if not isinstance(jwt, bytes):
+ raise DecodeError(f"Invalid token type. Token must be a {bytes}")
+
+ try:
+ signing_input, crypto_segment = jwt.rsplit(b".", 1)
+ header_segment, payload_segment = signing_input.split(b".", 1)
+ except ValueError as err:
+ raise DecodeError("Not enough segments") from err
+
+ try:
+ header_data = base64url_decode(header_segment)
+ except (TypeError, binascii.Error) as err:
+ raise DecodeError("Invalid header padding") from err
+
+ try:
+ header = json.loads(header_data)
+ except ValueError as e:
+ raise DecodeError(f"Invalid header string: {e}") from e
+
+ if not isinstance(header, dict):
+ raise DecodeError("Invalid header string: must be a json object")
+
+ try:
+ payload = base64url_decode(payload_segment)
+ except (TypeError, binascii.Error) as err:
+ raise DecodeError("Invalid payload padding") from err
+
+ try:
+ signature = base64url_decode(crypto_segment)
+ except (TypeError, binascii.Error) as err:
+ raise DecodeError("Invalid crypto padding") from err
+
+ return (payload, signing_input, header, signature)
+
+ def _verify_signature(
+ self,
+ signing_input: bytes,
+ header: dict[str, Any],
+ signature: bytes,
+ key: AllowedPublicKeys | str | bytes = "",
+ algorithms: list[str] | None = None,
+ ) -> None:
+ try:
+ alg = header["alg"]
+ except KeyError:
+ raise InvalidAlgorithmError("Algorithm not specified")
+
+ if not alg or (algorithms is not None and alg not in algorithms):
+ raise InvalidAlgorithmError("The specified alg value is not allowed")
+
+ try:
+ alg_obj = self.get_algorithm_by_name(alg)
+ except NotImplementedError as e:
+ raise InvalidAlgorithmError("Algorithm not supported") from e
+ prepared_key = alg_obj.prepare_key(key)
+
+ if not alg_obj.verify(signing_input, prepared_key, signature):
+ raise InvalidSignatureError("Signature verification failed")
+
+ def _validate_headers(self, headers: dict[str, Any]) -> None:
+ if "kid" in headers:
+ self._validate_kid(headers["kid"])
+
+ def _validate_kid(self, kid: Any) -> None:
+ if not isinstance(kid, str):
+ raise InvalidTokenError("Key ID header parameter must be a string")
_jws_global_obj = PyJWS()
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index 961d144..48d739a 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -1,37 +1,172 @@
from __future__ import annotations
+
import json
import warnings
from calendar import timegm
from collections.abc import Iterable
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any
+
from . import api_jws
-from .exceptions import DecodeError, ExpiredSignatureError, ImmatureSignatureError, InvalidAudienceError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError
+from .exceptions import (
+ DecodeError,
+ ExpiredSignatureError,
+ ImmatureSignatureError,
+ InvalidAudienceError,
+ InvalidIssuedAtError,
+ InvalidIssuerError,
+ MissingRequiredClaimError,
+)
from .warnings import RemovedInPyjwt3Warning
+
if TYPE_CHECKING:
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
class PyJWT:
-
- def __init__(self, options: (dict[str, Any] | None)=None) ->None:
+ def __init__(self, options: dict[str, Any] | None = None) -> None:
if options is None:
options = {}
- self.options: dict[str, Any] = {**self._get_default_options(), **
- options}
+ self.options: dict[str, Any] = {**self._get_default_options(), **options}
+
+ @staticmethod
+ def _get_default_options() -> dict[str, bool | list[str]]:
+ return {
+ "verify_signature": True,
+ "verify_exp": True,
+ "verify_nbf": True,
+ "verify_iat": True,
+ "verify_aud": True,
+ "verify_iss": True,
+ "require": [],
+ }
+
+ def encode(
+ self,
+ payload: dict[str, Any],
+ key: AllowedPrivateKeys | str | bytes,
+ algorithm: str | None = "HS256",
+ headers: dict[str, Any] | None = None,
+ json_encoder: type[json.JSONEncoder] | None = None,
+ sort_headers: bool = True,
+ ) -> str:
+ # Check that we get a dict
+ if not isinstance(payload, dict):
+ raise TypeError(
+ "Expecting a dict object, as JWT only supports "
+ "JSON objects as payloads."
+ )
- def _encode_payload(self, payload: dict[str, Any], headers: (dict[str,
- Any] | None)=None, json_encoder: (type[json.JSONEncoder] | None)=None
- ) ->bytes:
+ # Payload
+ payload = payload.copy()
+ for time_claim in ["exp", "iat", "nbf"]:
+ # Convert datetime to a intDate value in known time-format claims
+ if isinstance(payload.get(time_claim), datetime):
+ payload[time_claim] = timegm(payload[time_claim].utctimetuple())
+
+ json_payload = self._encode_payload(
+ payload,
+ headers=headers,
+ json_encoder=json_encoder,
+ )
+
+ return api_jws.encode(
+ json_payload,
+ key,
+ algorithm,
+ headers,
+ json_encoder,
+ sort_headers=sort_headers,
+ )
+
+ def _encode_payload(
+ self,
+ payload: dict[str, Any],
+ headers: dict[str, Any] | None = None,
+ json_encoder: type[json.JSONEncoder] | None = None,
+ ) -> bytes:
"""
Encode a given payload to the bytes to be signed.
This method is intended to be overridden by subclasses that need to
encode the payload in a different way, e.g. compress the payload.
"""
- pass
+ return json.dumps(
+ payload,
+ separators=(",", ":"),
+ cls=json_encoder,
+ ).encode("utf-8")
+
+ def decode_complete(
+ self,
+ jwt: str | bytes,
+ key: AllowedPublicKeys | str | bytes = "",
+ algorithms: list[str] | None = None,
+ options: dict[str, Any] | None = None,
+ # deprecated arg, remove in pyjwt3
+ verify: bool | None = None,
+ # could be used as passthrough to api_jws, consider removal in pyjwt3
+ detached_payload: bytes | None = None,
+ # passthrough arguments to _validate_claims
+ # consider putting in options
+ audience: str | Iterable[str] | None = None,
+ issuer: str | None = None,
+ leeway: float | timedelta = 0,
+ # kwargs
+ **kwargs: Any,
+ ) -> dict[str, Any]:
+ if kwargs:
+ warnings.warn(
+ "passing additional kwargs to decode_complete() is deprecated "
+ "and will be removed in pyjwt version 3. "
+ f"Unsupported kwargs: {tuple(kwargs.keys())}",
+ RemovedInPyjwt3Warning,
+ )
+ options = dict(options or {}) # shallow-copy or initialize an empty dict
+ options.setdefault("verify_signature", True)
+
+ # If the user has set the legacy `verify` argument, and it doesn't match
+ # what the relevant `options` entry for the argument is, inform the user
+ # that they're likely making a mistake.
+ if verify is not None and verify != options["verify_signature"]:
+ warnings.warn(
+ "The `verify` argument to `decode` does nothing in PyJWT 2.0 and newer. "
+ "The equivalent is setting `verify_signature` to False in the `options` dictionary. "
+ "This invocation has a mismatch between the kwarg and the option entry.",
+ category=DeprecationWarning,
+ )
+
+ if not options["verify_signature"]:
+ options.setdefault("verify_exp", False)
+ options.setdefault("verify_nbf", False)
+ options.setdefault("verify_iat", False)
+ options.setdefault("verify_aud", False)
+ options.setdefault("verify_iss", False)
- def _decode_payload(self, decoded: dict[str, Any]) ->Any:
+ if options["verify_signature"] and not algorithms:
+ raise DecodeError(
+ 'It is required that you pass in a value for the "algorithms" argument when calling decode().'
+ )
+
+ decoded = api_jws.decode_complete(
+ jwt,
+ key=key,
+ algorithms=algorithms,
+ options=options,
+ detached_payload=detached_payload,
+ )
+
+ payload = self._decode_payload(decoded)
+
+ merged_options = {**self.options, **options}
+ self._validate_claims(
+ payload, merged_options, audience=audience, issuer=issuer, leeway=leeway
+ )
+
+ decoded["payload"] = payload
+ return decoded
+
+ def _decode_payload(self, decoded: dict[str, Any]) -> Any:
"""
Decode the payload from a JWS dictionary (payload, signature, header).
@@ -39,7 +174,196 @@ class PyJWT:
decode the payload in a different way, e.g. decompress compressed
payloads.
"""
- pass
+ try:
+ payload = json.loads(decoded["payload"])
+ except ValueError as e:
+ raise DecodeError(f"Invalid payload string: {e}")
+ if not isinstance(payload, dict):
+ raise DecodeError("Invalid payload string: must be a json object")
+ return payload
+
+ def decode(
+ self,
+ jwt: str | bytes,
+ key: AllowedPublicKeys | str | bytes = "",
+ algorithms: list[str] | None = None,
+ options: dict[str, Any] | None = None,
+ # deprecated arg, remove in pyjwt3
+ verify: bool | None = None,
+ # could be used as passthrough to api_jws, consider removal in pyjwt3
+ detached_payload: bytes | None = None,
+ # passthrough arguments to _validate_claims
+ # consider putting in options
+ audience: str | Iterable[str] | None = None,
+ issuer: str | None = None,
+ leeway: float | timedelta = 0,
+ # kwargs
+ **kwargs: Any,
+ ) -> Any:
+ if kwargs:
+ warnings.warn(
+ "passing additional kwargs to decode() is deprecated "
+ "and will be removed in pyjwt version 3. "
+ f"Unsupported kwargs: {tuple(kwargs.keys())}",
+ RemovedInPyjwt3Warning,
+ )
+ decoded = self.decode_complete(
+ jwt,
+ key,
+ algorithms,
+ options,
+ verify=verify,
+ detached_payload=detached_payload,
+ audience=audience,
+ issuer=issuer,
+ leeway=leeway,
+ )
+ return decoded["payload"]
+
+ def _validate_claims(
+ self,
+ payload: dict[str, Any],
+ options: dict[str, Any],
+ audience=None,
+ issuer=None,
+ leeway: float | timedelta = 0,
+ ) -> None:
+ if isinstance(leeway, timedelta):
+ leeway = leeway.total_seconds()
+
+ if audience is not None and not isinstance(audience, (str, Iterable)):
+ raise TypeError("audience must be a string, iterable or None")
+
+ self._validate_required_claims(payload, options)
+
+ now = datetime.now(tz=timezone.utc).timestamp()
+
+ if "iat" in payload and options["verify_iat"]:
+ self._validate_iat(payload, now, leeway)
+
+ if "nbf" in payload and options["verify_nbf"]:
+ self._validate_nbf(payload, now, leeway)
+
+ if "exp" in payload and options["verify_exp"]:
+ self._validate_exp(payload, now, leeway)
+
+ if options["verify_iss"]:
+ self._validate_iss(payload, issuer)
+
+ if options["verify_aud"]:
+ self._validate_aud(
+ payload, audience, strict=options.get("strict_aud", False)
+ )
+
+ def _validate_required_claims(
+ self,
+ payload: dict[str, Any],
+ options: dict[str, Any],
+ ) -> None:
+ for claim in options["require"]:
+ if payload.get(claim) is None:
+ raise MissingRequiredClaimError(claim)
+
+ def _validate_iat(
+ self,
+ payload: dict[str, Any],
+ now: float,
+ leeway: float,
+ ) -> None:
+ try:
+ iat = int(payload["iat"])
+ except ValueError:
+ raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.")
+ if iat > (now + leeway):
+ raise ImmatureSignatureError("The token is not yet valid (iat)")
+
+ def _validate_nbf(
+ self,
+ payload: dict[str, Any],
+ now: float,
+ leeway: float,
+ ) -> None:
+ try:
+ nbf = int(payload["nbf"])
+ except ValueError:
+ raise DecodeError("Not Before claim (nbf) must be an integer.")
+
+ if nbf > (now + leeway):
+ raise ImmatureSignatureError("The token is not yet valid (nbf)")
+
+ def _validate_exp(
+ self,
+ payload: dict[str, Any],
+ now: float,
+ leeway: float,
+ ) -> None:
+ try:
+ exp = int(payload["exp"])
+ except ValueError:
+ raise DecodeError("Expiration Time claim (exp) must be an" " integer.")
+
+ if exp <= (now - leeway):
+ raise ExpiredSignatureError("Signature has expired")
+
+ def _validate_aud(
+ self,
+ payload: dict[str, Any],
+ audience: str | Iterable[str] | None,
+ *,
+ strict: bool = False,
+ ) -> None:
+ if audience is None:
+ if "aud" not in payload or not payload["aud"]:
+ return
+ # Application did not specify an audience, but
+ # the token has the 'aud' claim
+ raise InvalidAudienceError("Invalid audience")
+
+ if "aud" not in payload or not payload["aud"]:
+ # Application specified an audience, but it could not be
+ # verified since the token does not contain a claim.
+ raise MissingRequiredClaimError("aud")
+
+ audience_claims = payload["aud"]
+
+ # In strict mode, we forbid list matching: the supplied audience
+ # must be a string, and it must exactly match the audience claim.
+ if strict:
+ # Only a single audience is allowed in strict mode.
+ if not isinstance(audience, str):
+ raise InvalidAudienceError("Invalid audience (strict)")
+
+ # Only a single audience claim is allowed in strict mode.
+ if not isinstance(audience_claims, str):
+ raise InvalidAudienceError("Invalid claim format in token (strict)")
+
+ if audience != audience_claims:
+ raise InvalidAudienceError("Audience doesn't match (strict)")
+
+ return
+
+ if isinstance(audience_claims, str):
+ audience_claims = [audience_claims]
+ if not isinstance(audience_claims, list):
+ raise InvalidAudienceError("Invalid claim format in token")
+ if any(not isinstance(c, str) for c in audience_claims):
+ raise InvalidAudienceError("Invalid claim format in token")
+
+ if isinstance(audience, str):
+ audience = [audience]
+
+ if all(aud not in audience_claims for aud in audience):
+ raise InvalidAudienceError("Audience doesn't match")
+
+ def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None:
+ if issuer is None:
+ return
+
+ if "iss" not in payload:
+ raise MissingRequiredClaimError("iss")
+
+ if payload["iss"] != issuer:
+ raise InvalidIssuerError("Invalid issuer")
_jwt_global_obj = PyJWT()
diff --git a/jwt/exceptions.py b/jwt/exceptions.py
index 43e9576..8ac6ecf 100644
--- a/jwt/exceptions.py
+++ b/jwt/exceptions.py
@@ -2,6 +2,7 @@ class PyJWTError(Exception):
"""
Base class for all exceptions
"""
+
pass
@@ -46,11 +47,10 @@ class InvalidAlgorithmError(InvalidTokenError):
class MissingRequiredClaimError(InvalidTokenError):
-
- def __init__(self, claim: str) ->None:
+ def __init__(self, claim: str) -> None:
self.claim = claim
- def __str__(self) ->str:
+ def __str__(self) -> str:
return f'Token is missing the "{self.claim}" claim'
diff --git a/jwt/help.py b/jwt/help.py
index c6b2173..80b0ca5 100644
--- a/jwt/help.py
+++ b/jwt/help.py
@@ -2,26 +2,63 @@ import json
import platform
import sys
from typing import Dict
+
from . import __version__ as pyjwt_version
+
try:
import cryptography
+
cryptography_version = cryptography.__version__
except ModuleNotFoundError:
- cryptography_version = ''
+ cryptography_version = ""
-def info() ->Dict[str, Dict[str, str]]:
+def info() -> Dict[str, Dict[str, str]]:
"""
Generate information for a bug report.
Based on the requests package help utility module.
"""
- pass
+ try:
+ platform_info = {
+ "system": platform.system(),
+ "release": platform.release(),
+ }
+ except OSError:
+ platform_info = {"system": "Unknown", "release": "Unknown"}
+
+ implementation = platform.python_implementation()
+
+ if implementation == "CPython":
+ implementation_version = platform.python_version()
+ elif implementation == "PyPy":
+ pypy_version_info = sys.pypy_version_info # type: ignore[attr-defined]
+ implementation_version = (
+ f"{pypy_version_info.major}."
+ f"{pypy_version_info.minor}."
+ f"{pypy_version_info.micro}"
+ )
+ if pypy_version_info.releaselevel != "final":
+ implementation_version = "".join(
+ [implementation_version, pypy_version_info.releaselevel]
+ )
+ else:
+ implementation_version = "Unknown"
+
+ return {
+ "platform": platform_info,
+ "implementation": {
+ "name": implementation,
+ "version": implementation_version,
+ },
+ "cryptography": {"version": cryptography_version},
+ "pyjwt": {"version": pyjwt_version},
+ }
-def main() ->None:
+def main() -> None:
"""Pretty-print the bug information as JSON."""
- pass
+ print(json.dumps(info(), sort_keys=True, indent=2))
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/jwt/jwk_set_cache.py b/jwt/jwk_set_cache.py
index 6b2abcd..2432563 100644
--- a/jwt/jwk_set_cache.py
+++ b/jwt/jwk_set_cache.py
@@ -1,10 +1,31 @@
import time
from typing import Optional
+
from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp
class JWKSetCache:
-
- def __init__(self, lifespan: int) ->None:
+ def __init__(self, lifespan: int) -> None:
self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
self.lifespan = lifespan
+
+ def put(self, jwk_set: PyJWKSet) -> None:
+ if jwk_set is not None:
+ self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
+ else:
+ # clear cache
+ self.jwk_set_with_timestamp = None
+
+ def get(self) -> Optional[PyJWKSet]:
+ if self.jwk_set_with_timestamp is None or self.is_expired():
+ return None
+
+ return self.jwk_set_with_timestamp.get_jwk_set()
+
+ def is_expired(self) -> bool:
+ return (
+ self.jwk_set_with_timestamp is not None
+ and self.lifespan > -1
+ and time.monotonic()
+ > self.jwk_set_with_timestamp.get_timestamp() + self.lifespan
+ )
diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py
index 6cc0010..f19b10a 100644
--- a/jwt/jwks_client.py
+++ b/jwt/jwks_client.py
@@ -4,6 +4,7 @@ from functools import lru_cache
from ssl import SSLContext
from typing import Any, Dict, List, Optional
from urllib.error import URLError
+
from .api_jwk import PyJWK, PyJWKSet
from .api_jwt import decode_complete as decode_token
from .exceptions import PyJWKClientConnectionError, PyJWKClientError
@@ -11,11 +12,17 @@ from .jwk_set_cache import JWKSetCache
class PyJWKClient:
-
- def __init__(self, uri: str, cache_keys: bool=False, max_cached_keys:
- int=16, cache_jwk_set: bool=True, lifespan: int=300, headers:
- Optional[Dict[str, Any]]=None, timeout: int=30, ssl_context:
- Optional[SSLContext]=None):
+ def __init__(
+ self,
+ uri: str,
+ cache_keys: bool = False,
+ max_cached_keys: int = 16,
+ cache_jwk_set: bool = True,
+ lifespan: int = 300,
+ headers: Optional[Dict[str, Any]] = None,
+ timeout: int = 30,
+ ssl_context: Optional[SSLContext] = None,
+ ):
if headers is None:
headers = {}
self.uri = uri
@@ -23,14 +30,95 @@ class PyJWKClient:
self.headers = headers
self.timeout = timeout
self.ssl_context = ssl_context
+
if cache_jwk_set:
+ # Init jwt set cache with default or given lifespan.
+ # Default lifespan is 300 seconds (5 minutes).
if lifespan <= 0:
raise PyJWKClientError(
f'Lifespan must be greater than 0, the input is "{lifespan}"'
- )
+ )
self.jwk_set_cache = JWKSetCache(lifespan)
else:
self.jwk_set_cache = None
+
if cache_keys:
- self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.
- get_signing_key)
+ # Cache signing keys
+ # Ignore mypy (https://github.com/python/mypy/issues/2427)
+ self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore
+
+ def fetch_data(self) -> Any:
+ jwk_set: Any = None
+ try:
+ r = urllib.request.Request(url=self.uri, headers=self.headers)
+ with urllib.request.urlopen(
+ r, timeout=self.timeout, context=self.ssl_context
+ ) as response:
+ jwk_set = json.load(response)
+ except (URLError, TimeoutError) as e:
+ raise PyJWKClientConnectionError(
+ f'Fail to fetch data from the url, err: "{e}"'
+ )
+ else:
+ return jwk_set
+ finally:
+ if self.jwk_set_cache is not None:
+ self.jwk_set_cache.put(jwk_set)
+
+ def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
+ data = None
+ if self.jwk_set_cache is not None and not refresh:
+ data = self.jwk_set_cache.get()
+
+ if data is None:
+ data = self.fetch_data()
+
+ if not isinstance(data, dict):
+ raise PyJWKClientError("The JWKS endpoint did not return a JSON object")
+
+ return PyJWKSet.from_dict(data)
+
+ def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
+ jwk_set = self.get_jwk_set(refresh)
+ signing_keys = [
+ jwk_set_key
+ for jwk_set_key in jwk_set.keys
+ if jwk_set_key.public_key_use in ["sig", None] and jwk_set_key.key_id
+ ]
+
+ if not signing_keys:
+ raise PyJWKClientError("The JWKS endpoint did not contain any signing keys")
+
+ return signing_keys
+
+ def get_signing_key(self, kid: str) -> PyJWK:
+ signing_keys = self.get_signing_keys()
+ signing_key = self.match_kid(signing_keys, kid)
+
+ if not signing_key:
+ # If no matching signing key from the jwk set, refresh the jwk set and try again.
+ signing_keys = self.get_signing_keys(refresh=True)
+ signing_key = self.match_kid(signing_keys, kid)
+
+ if not signing_key:
+ raise PyJWKClientError(
+ f'Unable to find a signing key that matches: "{kid}"'
+ )
+
+ return signing_key
+
+ def get_signing_key_from_jwt(self, token: str) -> PyJWK:
+ unverified = decode_token(token, options={"verify_signature": False})
+ header = unverified["header"]
+ return self.get_signing_key(header.get("kid"))
+
+ @staticmethod
+ def match_kid(signing_keys: List[PyJWK], kid: str) -> Optional[PyJWK]:
+ signing_key = None
+
+ for key in signing_keys:
+ if key.key_id == kid:
+ signing_key = key
+ break
+
+ return signing_key
diff --git a/jwt/types.py b/jwt/types.py
index 5aa6306..7d99352 100644
--- a/jwt/types.py
+++ b/jwt/types.py
@@ -1,3 +1,5 @@
from typing import Any, Callable, Dict
+
JWKDict = Dict[str, Any]
+
HashlibHash = Callable[..., Any]
diff --git a/jwt/utils.py b/jwt/utils.py
index 1e1c20d..81c5ee4 100644
--- a/jwt/utils.py
+++ b/jwt/utils.py
@@ -2,20 +2,155 @@ import base64
import binascii
import re
from typing import Union
+
try:
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
- from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, encode_dss_signature
+ from cryptography.hazmat.primitives.asymmetric.utils import (
+ decode_dss_signature,
+ encode_dss_signature,
+ )
except ModuleNotFoundError:
pass
-_PEMS = {b'CERTIFICATE', b'TRUSTED CERTIFICATE', b'PRIVATE KEY',
- b'PUBLIC KEY', b'ENCRYPTED PRIVATE KEY', b'OPENSSH PRIVATE KEY',
- b'DSA PRIVATE KEY', b'RSA PRIVATE KEY', b'RSA PUBLIC KEY',
- b'EC PRIVATE KEY', b'DH PARAMETERS', b'NEW CERTIFICATE REQUEST',
- b'CERTIFICATE REQUEST', b'SSH2 PUBLIC KEY',
- b'SSH2 ENCRYPTED PRIVATE KEY', b'X509 CRL'}
-_PEM_RE = re.compile(b'----[- ]BEGIN (' + b'|'.join(_PEMS) +
- b')[- ]----\r?\n.+?\r?\n----[- ]END \\1[- ]----\r?\n?', re.DOTALL)
-_CERT_SUFFIX = b'-cert-v01@openssh.com'
-_SSH_PUBKEY_RC = re.compile(b'\\A(\\S+)[ \\t]+(\\S+)')
-_SSH_KEY_FORMATS = [b'ssh-ed25519', b'ssh-rsa', b'ssh-dss',
- b'ecdsa-sha2-nistp256', b'ecdsa-sha2-nistp384', b'ecdsa-sha2-nistp521']
+
+
+def force_bytes(value: Union[bytes, str]) -> bytes:
+ if isinstance(value, str):
+ return value.encode("utf-8")
+ elif isinstance(value, bytes):
+ return value
+ else:
+ raise TypeError("Expected a string value")
+
+
+def base64url_decode(input: Union[bytes, str]) -> bytes:
+ input_bytes = force_bytes(input)
+
+ rem = len(input_bytes) % 4
+
+ if rem > 0:
+ input_bytes += b"=" * (4 - rem)
+
+ return base64.urlsafe_b64decode(input_bytes)
+
+
+def base64url_encode(input: bytes) -> bytes:
+ return base64.urlsafe_b64encode(input).replace(b"=", b"")
+
+
+def to_base64url_uint(val: int) -> bytes:
+ if val < 0:
+ raise ValueError("Must be a positive integer")
+
+ int_bytes = bytes_from_int(val)
+
+ if len(int_bytes) == 0:
+ int_bytes = b"\x00"
+
+ return base64url_encode(int_bytes)
+
+
+def from_base64url_uint(val: Union[bytes, str]) -> int:
+ data = base64url_decode(force_bytes(val))
+ return int.from_bytes(data, byteorder="big")
+
+
+def number_to_bytes(num: int, num_bytes: int) -> bytes:
+ padded_hex = "%0*x" % (2 * num_bytes, num)
+ return binascii.a2b_hex(padded_hex.encode("ascii"))
+
+
+def bytes_to_number(string: bytes) -> int:
+ return int(binascii.b2a_hex(string), 16)
+
+
+def bytes_from_int(val: int) -> bytes:
+ remaining = val
+ byte_length = 0
+
+ while remaining != 0:
+ remaining >>= 8
+ byte_length += 1
+
+ return val.to_bytes(byte_length, "big", signed=False)
+
+
+def der_to_raw_signature(der_sig: bytes, curve: "EllipticCurve") -> bytes:
+ num_bits = curve.key_size
+ num_bytes = (num_bits + 7) // 8
+
+ r, s = decode_dss_signature(der_sig)
+
+ return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes)
+
+
+def raw_to_der_signature(raw_sig: bytes, curve: "EllipticCurve") -> bytes:
+ num_bits = curve.key_size
+ num_bytes = (num_bits + 7) // 8
+
+ if len(raw_sig) != 2 * num_bytes:
+ raise ValueError("Invalid signature")
+
+ r = bytes_to_number(raw_sig[:num_bytes])
+ s = bytes_to_number(raw_sig[num_bytes:])
+
+ return bytes(encode_dss_signature(r, s))
+
+
+# Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252
+_PEMS = {
+ b"CERTIFICATE",
+ b"TRUSTED CERTIFICATE",
+ b"PRIVATE KEY",
+ b"PUBLIC KEY",
+ b"ENCRYPTED PRIVATE KEY",
+ b"OPENSSH PRIVATE KEY",
+ b"DSA PRIVATE KEY",
+ b"RSA PRIVATE KEY",
+ b"RSA PUBLIC KEY",
+ b"EC PRIVATE KEY",
+ b"DH PARAMETERS",
+ b"NEW CERTIFICATE REQUEST",
+ b"CERTIFICATE REQUEST",
+ b"SSH2 PUBLIC KEY",
+ b"SSH2 ENCRYPTED PRIVATE KEY",
+ b"X509 CRL",
+}
+
+_PEM_RE = re.compile(
+ b"----[- ]BEGIN ("
+ + b"|".join(_PEMS)
+ + b""")[- ]----\r?
+.+?\r?
+----[- ]END \\1[- ]----\r?\n?""",
+ re.DOTALL,
+)
+
+
+def is_pem_format(key: bytes) -> bool:
+ return bool(_PEM_RE.search(key))
+
+
+# Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46
+_CERT_SUFFIX = b"-cert-v01@openssh.com"
+_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
+_SSH_KEY_FORMATS = [
+ b"ssh-ed25519",
+ b"ssh-rsa",
+ b"ssh-dss",
+ b"ecdsa-sha2-nistp256",
+ b"ecdsa-sha2-nistp384",
+ b"ecdsa-sha2-nistp521",
+]
+
+
+def is_ssh_key(key: bytes) -> bool:
+ if any(string_value in key for string_value in _SSH_KEY_FORMATS):
+ return True
+
+ ssh_pubkey_match = _SSH_PUBKEY_RC.match(key)
+ if ssh_pubkey_match:
+ key_type = ssh_pubkey_match.group(1)
+ if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]:
+ return True
+
+ return False