Skip to content

back to Reference (Gold) summary

Reference (Gold): pyjwt

Pytest summary for the `tests` directory

status count
passed 259
skipped 2
xfailed 1
total 262
collected 262

Failed pytests (expected failures only; counted as "xfailed" above):

test_utils.py::test_to_base64url_uint[-1-]

test_utils.py::test_to_base64url_uint[-1-]
inputval = -1, expected = ''

    @pytest.mark.parametrize(
        "inputval,expected",
        [
            (0, b"AA"),
            (1, b"AQ"),
            (255, b"_w"),
            (65537, b"AQAB"),
            (123456789, b"B1vNFQ"),
            pytest.param(-1, "", marks=pytest.mark.xfail(raises=ValueError)),
        ],
    )
    def test_to_base64url_uint(inputval, expected):
>       actual = to_base64url_uint(inputval)

tests/test_utils.py:18: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

val = -1

    def to_base64url_uint(val: int) -> bytes:
        if val < 0:
>           raise ValueError("Must be a positive integer")
E           ValueError: Must be a positive integer

jwt/utils.py:42: ValueError

Patch diff

diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 15c200a..ed18715 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -1,49 +1,146 @@
 from __future__ import annotations
+
 import hashlib
 import hmac
 import json
 import sys
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload
+
 from .exceptions import InvalidKeyError
 from .types import HashlibHash, JWKDict
-from .utils import base64url_decode, base64url_encode, der_to_raw_signature, force_bytes, from_base64url_uint, is_pem_format, is_ssh_key, raw_to_der_signature, to_base64url_uint
+from .utils import (
+    base64url_decode,
+    base64url_encode,
+    der_to_raw_signature,
+    force_bytes,
+    from_base64url_uint,
+    is_pem_format,
+    is_ssh_key,
+    raw_to_der_signature,
+    to_base64url_uint,
+)
+
 if sys.version_info >= (3, 8):
     from typing import Literal
 else:
     from typing_extensions import Literal
+
+
 try:
     from cryptography.exceptions import InvalidSignature
     from cryptography.hazmat.backends import default_backend
     from cryptography.hazmat.primitives import hashes
     from cryptography.hazmat.primitives.asymmetric import padding
-    from cryptography.hazmat.primitives.asymmetric.ec import ECDSA, SECP256K1, SECP256R1, SECP384R1, SECP521R1, EllipticCurve, EllipticCurvePrivateKey, EllipticCurvePrivateNumbers, EllipticCurvePublicKey, EllipticCurvePublicNumbers
-    from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey, Ed448PublicKey
-    from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey
-    from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPrivateNumbers, RSAPublicKey, RSAPublicNumbers, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp, rsa_recover_prime_factors
-    from cryptography.hazmat.primitives.serialization import Encoding, NoEncryption, PrivateFormat, PublicFormat, load_pem_private_key, load_pem_public_key, load_ssh_public_key
+    from cryptography.hazmat.primitives.asymmetric.ec import (
+        ECDSA,
+        SECP256K1,
+        SECP256R1,
+        SECP384R1,
+        SECP521R1,
+        EllipticCurve,
+        EllipticCurvePrivateKey,
+        EllipticCurvePrivateNumbers,
+        EllipticCurvePublicKey,
+        EllipticCurvePublicNumbers,
+    )
+    from cryptography.hazmat.primitives.asymmetric.ed448 import (
+        Ed448PrivateKey,
+        Ed448PublicKey,
+    )
+    from cryptography.hazmat.primitives.asymmetric.ed25519 import (
+        Ed25519PrivateKey,
+        Ed25519PublicKey,
+    )
+    from cryptography.hazmat.primitives.asymmetric.rsa import (
+        RSAPrivateKey,
+        RSAPrivateNumbers,
+        RSAPublicKey,
+        RSAPublicNumbers,
+        rsa_crt_dmp1,
+        rsa_crt_dmq1,
+        rsa_crt_iqmp,
+        rsa_recover_prime_factors,
+    )
+    from cryptography.hazmat.primitives.serialization import (
+        Encoding,
+        NoEncryption,
+        PrivateFormat,
+        PublicFormat,
+        load_pem_private_key,
+        load_pem_public_key,
+        load_ssh_public_key,
+    )
+
     has_crypto = True
 except ModuleNotFoundError:
     has_crypto = False
+
+
 if TYPE_CHECKING:
+    # Type aliases for convenience in algorithms method signatures
     AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
     AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
-    AllowedOKPKeys = (Ed25519PrivateKey | Ed25519PublicKey |
-        Ed448PrivateKey | Ed448PublicKey)
+    AllowedOKPKeys = (
+        Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
+    )
     AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
-    AllowedPrivateKeys = (RSAPrivateKey | EllipticCurvePrivateKey |
-        Ed25519PrivateKey | Ed448PrivateKey)
-    AllowedPublicKeys = (RSAPublicKey | EllipticCurvePublicKey |
-        Ed25519PublicKey | Ed448PublicKey)
-requires_cryptography = {'RS256', 'RS384', 'RS512', 'ES256', 'ES256K',
-    'ES384', 'ES521', 'ES512', 'PS256', 'PS384', 'PS512', 'EdDSA'}
-
-
-def get_default_algorithms() ->dict[str, Algorithm]:
+    AllowedPrivateKeys = (
+        RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
+    )
+    AllowedPublicKeys = (
+        RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
+    )
+
+
+requires_cryptography = {
+    "RS256",
+    "RS384",
+    "RS512",
+    "ES256",
+    "ES256K",
+    "ES384",
+    "ES521",
+    "ES512",
+    "PS256",
+    "PS384",
+    "PS512",
+    "EdDSA",
+}
+
+
+def get_default_algorithms() -> dict[str, Algorithm]:
     """
     Returns the algorithms that are implemented by the library.
     """
-    pass
+    default_algorithms = {
+        "none": NoneAlgorithm(),
+        "HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
+        "HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
+        "HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
+    }
+
+    if has_crypto:
+        default_algorithms.update(
+            {
+                "RS256": RSAAlgorithm(RSAAlgorithm.SHA256),
+                "RS384": RSAAlgorithm(RSAAlgorithm.SHA384),
+                "RS512": RSAAlgorithm(RSAAlgorithm.SHA512),
+                "ES256": ECAlgorithm(ECAlgorithm.SHA256),
+                "ES256K": ECAlgorithm(ECAlgorithm.SHA256),
+                "ES384": ECAlgorithm(ECAlgorithm.SHA384),
+                "ES521": ECAlgorithm(ECAlgorithm.SHA512),
+                "ES512": ECAlgorithm(
+                    ECAlgorithm.SHA512
+                ),  # Backward compat for #219 fix
+                "PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
+                "PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
+                "PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
+                "EdDSA": OKPAlgorithm(),
+            }
+        )
+
+    return default_algorithms


 class Algorithm(ABC):
@@ -51,53 +148,74 @@ class Algorithm(ABC):
     The interface for an algorithm used to sign and verify tokens.
     """

-    def compute_hash_digest(self, bytestr: bytes) ->bytes:
+    def compute_hash_digest(self, bytestr: bytes) -> bytes:
         """
         Compute a hash digest using the specified algorithm's hash algorithm.

         If there is no hash algorithm, raises a NotImplementedError.
         """
-        pass
+        # lookup self.hash_alg if defined in a way that mypy can understand
+        hash_alg = getattr(self, "hash_alg", None)
+        if hash_alg is None:
+            raise NotImplementedError
+
+        if (
+            has_crypto
+            and isinstance(hash_alg, type)
+            and issubclass(hash_alg, hashes.HashAlgorithm)
+        ):
+            digest = hashes.Hash(hash_alg(), backend=default_backend())
+            digest.update(bytestr)
+            return bytes(digest.finalize())
+        else:
+            return bytes(hash_alg(bytestr).digest())

     @abstractmethod
-    def prepare_key(self, key: Any) ->Any:
+    def prepare_key(self, key: Any) -> Any:
         """
         Performs necessary validation and conversions on the key and returns
         the key value in the proper format for sign() and verify().
         """
-        pass

     @abstractmethod
-    def sign(self, msg: bytes, key: Any) ->bytes:
+    def sign(self, msg: bytes, key: Any) -> bytes:
         """
         Returns a digital signature for the specified message
         using the specified key value.
         """
-        pass

     @abstractmethod
-    def verify(self, msg: bytes, key: Any, sig: bytes) ->bool:
+    def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
         """
         Verifies that the specified digital signature is valid
         for the specified message and key values.
         """
-        pass
+
+    @overload
+    @staticmethod
+    @abstractmethod
+    def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict:
+        ...  # pragma: no cover
+
+    @overload
+    @staticmethod
+    @abstractmethod
+    def to_jwk(key_obj, as_dict: Literal[False] = False) -> str:
+        ...  # pragma: no cover

     @staticmethod
     @abstractmethod
-    def to_jwk(key_obj, as_dict: bool=False) ->Union[JWKDict, str]:
+    def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]:
         """
         Serializes a given key into a JWK
         """
-        pass

     @staticmethod
     @abstractmethod
-    def from_jwk(jwk: (str | JWKDict)) ->Any:
+    def from_jwk(jwk: str | JWKDict) -> Any:
         """
         Deserializes a given key from JWK back into a key object
         """
-        pass


 class NoneAlgorithm(Algorithm):
@@ -106,54 +224,478 @@ class NoneAlgorithm(Algorithm):
     operations are required.
     """

+    def prepare_key(self, key: str | None) -> None:
+        if key == "":
+            key = None
+
+        if key is not None:
+            raise InvalidKeyError('When alg = "none", key value must be None.')
+
+        return key
+
+    def sign(self, msg: bytes, key: None) -> bytes:
+        return b""
+
+    def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
+        return False
+
+    @staticmethod
+    def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
+        raise NotImplementedError()
+
+    @staticmethod
+    def from_jwk(jwk: str | JWKDict) -> NoReturn:
+        raise NotImplementedError()
+

 class HMACAlgorithm(Algorithm):
     """
     Performs signing and verification operations using HMAC
     and the specified hash function.
     """
+
     SHA256: ClassVar[HashlibHash] = hashlib.sha256
     SHA384: ClassVar[HashlibHash] = hashlib.sha384
     SHA512: ClassVar[HashlibHash] = hashlib.sha512

-    def __init__(self, hash_alg: HashlibHash) ->None:
+    def __init__(self, hash_alg: HashlibHash) -> None:
         self.hash_alg = hash_alg

+    def prepare_key(self, key: str | bytes) -> bytes:
+        key_bytes = force_bytes(key)

-if has_crypto:
+        if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
+            raise InvalidKeyError(
+                "The specified key is an asymmetric key or x509 certificate and"
+                " should not be used as an HMAC secret."
+            )
+
+        return key_bytes
+
+    @overload
+    @staticmethod
+    def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict:
+        ...  # pragma: no cover

+    @overload
+    @staticmethod
+    def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str:
+        ...  # pragma: no cover
+
+    @staticmethod
+    def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]:
+        jwk = {
+            "k": base64url_encode(force_bytes(key_obj)).decode(),
+            "kty": "oct",
+        }
+
+        if as_dict:
+            return jwk
+        else:
+            return json.dumps(jwk)
+
+    @staticmethod
+    def from_jwk(jwk: str | JWKDict) -> bytes:
+        try:
+            if isinstance(jwk, str):
+                obj: JWKDict = json.loads(jwk)
+            elif isinstance(jwk, dict):
+                obj = jwk
+            else:
+                raise ValueError
+        except ValueError:
+            raise InvalidKeyError("Key is not valid JSON")
+
+        if obj.get("kty") != "oct":
+            raise InvalidKeyError("Not an HMAC key")
+
+        return base64url_decode(obj["k"])
+
+    def sign(self, msg: bytes, key: bytes) -> bytes:
+        return hmac.new(key, msg, self.hash_alg).digest()
+
+    def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
+        return hmac.compare_digest(sig, self.sign(msg, key))
+
+
+if has_crypto:

     class RSAAlgorithm(Algorithm):
         """
         Performs signing and verification operations using
         RSASSA-PKCS-v1_5 and the specified hash function.
         """
+
         SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
         SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
         SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

-        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) ->None:
+        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
             self.hash_alg = hash_alg

+        def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
+            if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
+                return key
+
+            if not isinstance(key, (bytes, str)):
+                raise TypeError("Expecting a PEM-formatted key.")
+
+            key_bytes = force_bytes(key)
+
+            try:
+                if key_bytes.startswith(b"ssh-rsa"):
+                    return cast(RSAPublicKey, load_ssh_public_key(key_bytes))
+                else:
+                    return cast(
+                        RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
+                    )
+            except ValueError:
+                return cast(RSAPublicKey, load_pem_public_key(key_bytes))
+
+        @overload
+        @staticmethod
+        def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict:
+            ...  # pragma: no cover
+
+        @overload
+        @staticmethod
+        def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str:
+            ...  # pragma: no cover
+
+        @staticmethod
+        def to_jwk(
+            key_obj: AllowedRSAKeys, as_dict: bool = False
+        ) -> Union[JWKDict, str]:
+            obj: dict[str, Any] | None = None
+
+            if hasattr(key_obj, "private_numbers"):
+                # Private key
+                numbers = key_obj.private_numbers()
+
+                obj = {
+                    "kty": "RSA",
+                    "key_ops": ["sign"],
+                    "n": to_base64url_uint(numbers.public_numbers.n).decode(),
+                    "e": to_base64url_uint(numbers.public_numbers.e).decode(),
+                    "d": to_base64url_uint(numbers.d).decode(),
+                    "p": to_base64url_uint(numbers.p).decode(),
+                    "q": to_base64url_uint(numbers.q).decode(),
+                    "dp": to_base64url_uint(numbers.dmp1).decode(),
+                    "dq": to_base64url_uint(numbers.dmq1).decode(),
+                    "qi": to_base64url_uint(numbers.iqmp).decode(),
+                }
+
+            elif hasattr(key_obj, "verify"):
+                # Public key
+                numbers = key_obj.public_numbers()
+
+                obj = {
+                    "kty": "RSA",
+                    "key_ops": ["verify"],
+                    "n": to_base64url_uint(numbers.n).decode(),
+                    "e": to_base64url_uint(numbers.e).decode(),
+                }
+            else:
+                raise InvalidKeyError("Not a public or private key")
+
+            if as_dict:
+                return obj
+            else:
+                return json.dumps(obj)
+
+        @staticmethod
+        def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
+            try:
+                if isinstance(jwk, str):
+                    obj = json.loads(jwk)
+                elif isinstance(jwk, dict):
+                    obj = jwk
+                else:
+                    raise ValueError
+            except ValueError:
+                raise InvalidKeyError("Key is not valid JSON")
+
+            if obj.get("kty") != "RSA":
+                raise InvalidKeyError("Not an RSA key")
+
+            if "d" in obj and "e" in obj and "n" in obj:
+                # Private key
+                if "oth" in obj:
+                    raise InvalidKeyError(
+                        "Unsupported RSA private key: > 2 primes not supported"
+                    )
+
+                other_props = ["p", "q", "dp", "dq", "qi"]
+                props_found = [prop in obj for prop in other_props]
+                any_props_found = any(props_found)
+
+                if any_props_found and not all(props_found):
+                    raise InvalidKeyError(
+                        "RSA key must include all parameters if any are present besides d"
+                    )
+
+                public_numbers = RSAPublicNumbers(
+                    from_base64url_uint(obj["e"]),
+                    from_base64url_uint(obj["n"]),
+                )
+
+                if any_props_found:
+                    numbers = RSAPrivateNumbers(
+                        d=from_base64url_uint(obj["d"]),
+                        p=from_base64url_uint(obj["p"]),
+                        q=from_base64url_uint(obj["q"]),
+                        dmp1=from_base64url_uint(obj["dp"]),
+                        dmq1=from_base64url_uint(obj["dq"]),
+                        iqmp=from_base64url_uint(obj["qi"]),
+                        public_numbers=public_numbers,
+                    )
+                else:
+                    d = from_base64url_uint(obj["d"])
+                    p, q = rsa_recover_prime_factors(
+                        public_numbers.n, d, public_numbers.e
+                    )
+
+                    numbers = RSAPrivateNumbers(
+                        d=d,
+                        p=p,
+                        q=q,
+                        dmp1=rsa_crt_dmp1(d, p),
+                        dmq1=rsa_crt_dmq1(d, q),
+                        iqmp=rsa_crt_iqmp(p, q),
+                        public_numbers=public_numbers,
+                    )
+
+                return numbers.private_key()
+            elif "n" in obj and "e" in obj:
+                # Public key
+                return RSAPublicNumbers(
+                    from_base64url_uint(obj["e"]),
+                    from_base64url_uint(obj["n"]),
+                ).public_key()
+            else:
+                raise InvalidKeyError("Not a public or private key")
+
+        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
+            return key.sign(msg, padding.PKCS1v15(), self.hash_alg())
+
+        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
+            try:
+                key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
+                return True
+            except InvalidSignature:
+                return False

     class ECAlgorithm(Algorithm):
         """
         Performs signing and verification operations using
         ECDSA and the specified hash function
         """
+
         SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
         SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
         SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

-        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) ->None:
+        def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
             self.hash_alg = hash_alg

+        def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
+            if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
+                return key
+
+            if not isinstance(key, (bytes, str)):
+                raise TypeError("Expecting a PEM-formatted key.")
+
+            key_bytes = force_bytes(key)
+
+            # Attempt to load key. We don't know if it's
+            # a Signing Key or a Verifying Key, so we try
+            # the Verifying Key first.
+            try:
+                if key_bytes.startswith(b"ecdsa-sha2-"):
+                    crypto_key = load_ssh_public_key(key_bytes)
+                else:
+                    crypto_key = load_pem_public_key(key_bytes)  # type: ignore[assignment]
+            except ValueError:
+                crypto_key = load_pem_private_key(key_bytes, password=None)  # type: ignore[assignment]
+
+            # Explicit check the key to prevent confusing errors from cryptography
+            if not isinstance(
+                crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
+            ):
+                raise InvalidKeyError(
+                    "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
+                )
+
+            return crypto_key
+
+        def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
+            der_sig = key.sign(msg, ECDSA(self.hash_alg()))
+
+            return der_to_raw_signature(der_sig, key.curve)
+
+        def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool:
+            try:
+                der_sig = raw_to_der_signature(sig, key.curve)
+            except ValueError:
+                return False
+
+            try:
+                public_key = (
+                    key.public_key()
+                    if isinstance(key, EllipticCurvePrivateKey)
+                    else key
+                )
+                public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
+                return True
+            except InvalidSignature:
+                return False
+
+        @overload
+        @staticmethod
+        def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict:
+            ...  # pragma: no cover
+
+        @overload
+        @staticmethod
+        def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str:
+            ...  # pragma: no cover
+
+        @staticmethod
+        def to_jwk(
+            key_obj: AllowedECKeys, as_dict: bool = False
+        ) -> Union[JWKDict, str]:
+            if isinstance(key_obj, EllipticCurvePrivateKey):
+                public_numbers = key_obj.public_key().public_numbers()
+            elif isinstance(key_obj, EllipticCurvePublicKey):
+                public_numbers = key_obj.public_numbers()
+            else:
+                raise InvalidKeyError("Not a public or private key")
+
+            if isinstance(key_obj.curve, SECP256R1):
+                crv = "P-256"
+            elif isinstance(key_obj.curve, SECP384R1):
+                crv = "P-384"
+            elif isinstance(key_obj.curve, SECP521R1):
+                crv = "P-521"
+            elif isinstance(key_obj.curve, SECP256K1):
+                crv = "secp256k1"
+            else:
+                raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
+
+            obj: dict[str, Any] = {
+                "kty": "EC",
+                "crv": crv,
+                "x": to_base64url_uint(public_numbers.x).decode(),
+                "y": to_base64url_uint(public_numbers.y).decode(),
+            }
+
+            if isinstance(key_obj, EllipticCurvePrivateKey):
+                obj["d"] = to_base64url_uint(
+                    key_obj.private_numbers().private_value
+                ).decode()
+
+            if as_dict:
+                return obj
+            else:
+                return json.dumps(obj)
+
+        @staticmethod
+        def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
+            try:
+                if isinstance(jwk, str):
+                    obj = json.loads(jwk)
+                elif isinstance(jwk, dict):
+                    obj = jwk
+                else:
+                    raise ValueError
+            except ValueError:
+                raise InvalidKeyError("Key is not valid JSON")
+
+            if obj.get("kty") != "EC":
+                raise InvalidKeyError("Not an Elliptic curve key")
+
+            if "x" not in obj or "y" not in obj:
+                raise InvalidKeyError("Not an Elliptic curve key")
+
+            x = base64url_decode(obj.get("x"))
+            y = base64url_decode(obj.get("y"))
+
+            curve = obj.get("crv")
+            curve_obj: EllipticCurve
+
+            if curve == "P-256":
+                if len(x) == len(y) == 32:
+                    curve_obj = SECP256R1()
+                else:
+                    raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
+            elif curve == "P-384":
+                if len(x) == len(y) == 48:
+                    curve_obj = SECP384R1()
+                else:
+                    raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
+            elif curve == "P-521":
+                if len(x) == len(y) == 66:
+                    curve_obj = SECP521R1()
+                else:
+                    raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
+            elif curve == "secp256k1":
+                if len(x) == len(y) == 32:
+                    curve_obj = SECP256K1()
+                else:
+                    raise InvalidKeyError(
+                        "Coords should be 32 bytes for curve secp256k1"
+                    )
+            else:
+                raise InvalidKeyError(f"Invalid curve: {curve}")
+
+            public_numbers = EllipticCurvePublicNumbers(
+                x=int.from_bytes(x, byteorder="big"),
+                y=int.from_bytes(y, byteorder="big"),
+                curve=curve_obj,
+            )
+
+            if "d" not in obj:
+                return public_numbers.public_key()
+
+            d = base64url_decode(obj.get("d"))
+            if len(d) != len(x):
+                raise InvalidKeyError(
+                    "D should be {} bytes for curve {}", len(x), curve
+                )
+
+            return EllipticCurvePrivateNumbers(
+                int.from_bytes(d, byteorder="big"), public_numbers
+            ).private_key()

     class RSAPSSAlgorithm(RSAAlgorithm):
         """
         Performs a signature using RSASSA-PSS with MGF1
         """

+        def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
+            return key.sign(
+                msg,
+                padding.PSS(
+                    mgf=padding.MGF1(self.hash_alg()),
+                    salt_length=self.hash_alg().digest_size,
+                ),
+                self.hash_alg(),
+            )
+
+        def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
+            try:
+                key.verify(
+                    sig,
+                    msg,
+                    padding.PSS(
+                        mgf=padding.MGF1(self.hash_alg()),
+                        salt_length=self.hash_alg().digest_size,
+                    ),
+                    self.hash_alg(),
+                )
+                return True
+            except InvalidSignature:
+                return False

     class OKPAlgorithm(Algorithm):
         """
@@ -162,11 +704,35 @@ if has_crypto:
         This class requires ``cryptography>=2.6`` to be installed.
         """

-        def __init__(self, **kwargs: Any) ->None:
+        def __init__(self, **kwargs: Any) -> None:
             pass

-        def sign(self, msg: (str | bytes), key: (Ed25519PrivateKey |
-            Ed448PrivateKey)) ->bytes:
+        def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
+            if isinstance(key, (bytes, str)):
+                key_str = key.decode("utf-8") if isinstance(key, bytes) else key
+                key_bytes = key.encode("utf-8") if isinstance(key, str) else key
+
+                if "-----BEGIN PUBLIC" in key_str:
+                    key = load_pem_public_key(key_bytes)  # type: ignore[assignment]
+                elif "-----BEGIN PRIVATE" in key_str:
+                    key = load_pem_private_key(key_bytes, password=None)  # type: ignore[assignment]
+                elif key_str[0:4] == "ssh-":
+                    key = load_ssh_public_key(key_bytes)  # type: ignore[assignment]
+
+            # Explicit check the key to prevent confusing errors from cryptography
+            if not isinstance(
+                key,
+                (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
+            ):
+                raise InvalidKeyError(
+                    "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for EdDSA algorithms"
+                )
+
+            return key
+
+        def sign(
+            self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
+        ) -> bytes:
             """
             Sign a message ``msg`` using the EdDSA private key ``key``
             :param str|bytes msg: Message to sign
@@ -174,10 +740,12 @@ if has_crypto:
                 or :class:`.Ed448PrivateKey` isinstance
             :return bytes signature: The signature, as bytes
             """
-            pass
+            msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
+            return key.sign(msg_bytes)

-        def verify(self, msg: (str | bytes), key: AllowedOKPKeys, sig: (str |
-            bytes)) ->bool:
+        def verify(
+            self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
+        ) -> bool:
             """
             Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``

@@ -187,4 +755,108 @@ if has_crypto:
                 A private or public EdDSA key instance
             :return bool verified: True if signature is valid, False if not.
             """
-            pass
+            try:
+                msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
+                sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig
+
+                public_key = (
+                    key.public_key()
+                    if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
+                    else key
+                )
+                public_key.verify(sig_bytes, msg_bytes)
+                return True  # If no exception was raised, the signature is valid.
+            except InvalidSignature:
+                return False
+
+        @overload
+        @staticmethod
+        def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict:
+            ...  # pragma: no cover
+
+        @overload
+        @staticmethod
+        def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str:
+            ...  # pragma: no cover
+
+        @staticmethod
+        def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]:
+            if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
+                x = key.public_bytes(
+                    encoding=Encoding.Raw,
+                    format=PublicFormat.Raw,
+                )
+                crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
+
+                obj = {
+                    "x": base64url_encode(force_bytes(x)).decode(),
+                    "kty": "OKP",
+                    "crv": crv,
+                }
+
+                if as_dict:
+                    return obj
+                else:
+                    return json.dumps(obj)
+
+            if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
+                d = key.private_bytes(
+                    encoding=Encoding.Raw,
+                    format=PrivateFormat.Raw,
+                    encryption_algorithm=NoEncryption(),
+                )
+
+                x = key.public_key().public_bytes(
+                    encoding=Encoding.Raw,
+                    format=PublicFormat.Raw,
+                )
+
+                crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
+                obj = {
+                    "x": base64url_encode(force_bytes(x)).decode(),
+                    "d": base64url_encode(force_bytes(d)).decode(),
+                    "kty": "OKP",
+                    "crv": crv,
+                }
+
+                if as_dict:
+                    return obj
+                else:
+                    return json.dumps(obj)
+
+            raise InvalidKeyError("Not a public or private key")
+
+        @staticmethod
+        def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
+            try:
+                if isinstance(jwk, str):
+                    obj = json.loads(jwk)
+                elif isinstance(jwk, dict):
+                    obj = jwk
+                else:
+                    raise ValueError
+            except ValueError:
+                raise InvalidKeyError("Key is not valid JSON")
+
+            if obj.get("kty") != "OKP":
+                raise InvalidKeyError("Not an Octet Key Pair")
+
+            curve = obj.get("crv")
+            if curve != "Ed25519" and curve != "Ed448":
+                raise InvalidKeyError(f"Invalid curve: {curve}")
+
+            if "x" not in obj:
+                raise InvalidKeyError('OKP should have "x" parameter')
+            x = base64url_decode(obj.get("x"))
+
+            try:
+                if "d" not in obj:
+                    if curve == "Ed25519":
+                        return Ed25519PublicKey.from_public_bytes(x)
+                    return Ed448PublicKey.from_public_bytes(x)
+                d = base64url_decode(obj.get("d"))
+                if curve == "Ed25519":
+                    return Ed25519PrivateKey.from_private_bytes(d)
+                return Ed448PrivateKey.from_private_bytes(d)
+            except ValueError as err:
+                raise InvalidKeyError("Invalid key parameter") from err
diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py
index 175e782..456c7f4 100644
--- a/jwt/api_jwk.py
+++ b/jwt/api_jwk.py
@@ -1,86 +1,132 @@
 from __future__ import annotations
+
 import json
 import time
 from typing import Any
+
 from .algorithms import get_default_algorithms, has_crypto, requires_cryptography
 from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError, PyJWTError
 from .types import JWKDict


 class PyJWK:
-
-    def __init__(self, jwk_data: JWKDict, algorithm: (str | None)=None) ->None:
+    def __init__(self, jwk_data: JWKDict, algorithm: str | None = None) -> None:
         self._algorithms = get_default_algorithms()
         self._jwk_data = jwk_data
-        kty = self._jwk_data.get('kty', None)
+
+        kty = self._jwk_data.get("kty", None)
         if not kty:
-            raise InvalidKeyError(f'kty is not found: {self._jwk_data}')
+            raise InvalidKeyError(f"kty is not found: {self._jwk_data}")
+
         if not algorithm and isinstance(self._jwk_data, dict):
-            algorithm = self._jwk_data.get('alg', None)
+            algorithm = self._jwk_data.get("alg", None)
+
         if not algorithm:
-            crv = self._jwk_data.get('crv', None)
-            if kty == 'EC':
-                if crv == 'P-256' or not crv:
-                    algorithm = 'ES256'
-                elif crv == 'P-384':
-                    algorithm = 'ES384'
-                elif crv == 'P-521':
-                    algorithm = 'ES512'
-                elif crv == 'secp256k1':
-                    algorithm = 'ES256K'
+            # Determine alg with kty (and crv).
+            crv = self._jwk_data.get("crv", None)
+            if kty == "EC":
+                if crv == "P-256" or not crv:
+                    algorithm = "ES256"
+                elif crv == "P-384":
+                    algorithm = "ES384"
+                elif crv == "P-521":
+                    algorithm = "ES512"
+                elif crv == "secp256k1":
+                    algorithm = "ES256K"
                 else:
-                    raise InvalidKeyError(f'Unsupported crv: {crv}')
-            elif kty == 'RSA':
-                algorithm = 'RS256'
-            elif kty == 'oct':
-                algorithm = 'HS256'
-            elif kty == 'OKP':
+                    raise InvalidKeyError(f"Unsupported crv: {crv}")
+            elif kty == "RSA":
+                algorithm = "RS256"
+            elif kty == "oct":
+                algorithm = "HS256"
+            elif kty == "OKP":
                 if not crv:
-                    raise InvalidKeyError(f'crv is not found: {self._jwk_data}'
-                        )
-                if crv == 'Ed25519':
-                    algorithm = 'EdDSA'
+                    raise InvalidKeyError(f"crv is not found: {self._jwk_data}")
+                if crv == "Ed25519":
+                    algorithm = "EdDSA"
                 else:
-                    raise InvalidKeyError(f'Unsupported crv: {crv}')
+                    raise InvalidKeyError(f"Unsupported crv: {crv}")
             else:
-                raise InvalidKeyError(f'Unsupported kty: {kty}')
+                raise InvalidKeyError(f"Unsupported kty: {kty}")
+
         if not has_crypto and algorithm in requires_cryptography:
-            raise PyJWKError(
-                f"{algorithm} requires 'cryptography' to be installed.")
+            raise PyJWKError(f"{algorithm} requires 'cryptography' to be installed.")
+
         self.Algorithm = self._algorithms.get(algorithm)
+
         if not self.Algorithm:
-            raise PyJWKError(
-                f'Unable to find an algorithm for key: {self._jwk_data}')
+            raise PyJWKError(f"Unable to find an algorithm for key: {self._jwk_data}")
+
         self.key = self.Algorithm.from_jwk(self._jwk_data)

+    @staticmethod
+    def from_dict(obj: JWKDict, algorithm: str | None = None) -> "PyJWK":
+        return PyJWK(obj, algorithm)

-class PyJWKSet:
+    @staticmethod
+    def from_json(data: str, algorithm: None = None) -> "PyJWK":
+        obj = json.loads(data)
+        return PyJWK.from_dict(obj, algorithm)
+
+    @property
+    def key_type(self) -> str | None:
+        return self._jwk_data.get("kty", None)
+
+    @property
+    def key_id(self) -> str | None:
+        return self._jwk_data.get("kid", None)

-    def __init__(self, keys: list[JWKDict]) ->None:
+    @property
+    def public_key_use(self) -> str | None:
+        return self._jwk_data.get("use", None)
+
+
+class PyJWKSet:
+    def __init__(self, keys: list[JWKDict]) -> None:
         self.keys = []
+
         if not keys:
-            raise PyJWKSetError('The JWK Set did not contain any keys')
+            raise PyJWKSetError("The JWK Set did not contain any keys")
+
         if not isinstance(keys, list):
-            raise PyJWKSetError('Invalid JWK Set value')
+            raise PyJWKSetError("Invalid JWK Set value")
+
         for key in keys:
             try:
                 self.keys.append(PyJWK(key))
             except PyJWTError:
+                # skip unusable keys
                 continue
+
         if len(self.keys) == 0:
             raise PyJWKSetError(
                 "The JWK Set did not contain any usable keys. Perhaps 'cryptography' is not installed?"
-                )
+            )
+
+    @staticmethod
+    def from_dict(obj: dict[str, Any]) -> "PyJWKSet":
+        keys = obj.get("keys", [])
+        return PyJWKSet(keys)
+
+    @staticmethod
+    def from_json(data: str) -> "PyJWKSet":
+        obj = json.loads(data)
+        return PyJWKSet.from_dict(obj)

-    def __getitem__(self, kid: str) ->'PyJWK':
+    def __getitem__(self, kid: str) -> "PyJWK":
         for key in self.keys:
             if key.key_id == kid:
                 return key
-        raise KeyError(f'keyset has no key for kid: {kid}')
+        raise KeyError(f"keyset has no key for kid: {kid}")


 class PyJWTSetWithTimestamp:
-
     def __init__(self, jwk_set: PyJWKSet):
         self.jwk_set = jwk_set
         self.timestamp = time.monotonic()
+
+    def get_jwk_set(self) -> PyJWKSet:
+        return self.jwk_set
+
+    def get_timestamp(self) -> float:
+        return self.timestamp
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index eeb5924..fa6708c 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -1,51 +1,89 @@
 from __future__ import annotations
+
 import binascii
 import json
 import warnings
 from typing import TYPE_CHECKING, Any
-from .algorithms import Algorithm, get_default_algorithms, has_crypto, requires_cryptography
-from .exceptions import DecodeError, InvalidAlgorithmError, InvalidSignatureError, InvalidTokenError
+
+from .algorithms import (
+    Algorithm,
+    get_default_algorithms,
+    has_crypto,
+    requires_cryptography,
+)
+from .exceptions import (
+    DecodeError,
+    InvalidAlgorithmError,
+    InvalidSignatureError,
+    InvalidTokenError,
+)
 from .utils import base64url_decode, base64url_encode
 from .warnings import RemovedInPyjwt3Warning
+
 if TYPE_CHECKING:
     from .algorithms import AllowedPrivateKeys, AllowedPublicKeys


 class PyJWS:
-    header_typ = 'JWT'
+    header_typ = "JWT"

-    def __init__(self, algorithms: (list[str] | None)=None, options: (dict[
-        str, Any] | None)=None) ->None:
+    def __init__(
+        self,
+        algorithms: list[str] | None = None,
+        options: dict[str, Any] | None = None,
+    ) -> None:
         self._algorithms = get_default_algorithms()
-        self._valid_algs = set(algorithms) if algorithms is not None else set(
-            self._algorithms)
+        self._valid_algs = (
+            set(algorithms) if algorithms is not None else set(self._algorithms)
+        )
+
+        # Remove algorithms that aren't on the whitelist
         for key in list(self._algorithms.keys()):
             if key not in self._valid_algs:
                 del self._algorithms[key]
+
         if options is None:
             options = {}
         self.options = {**self._get_default_options(), **options}

-    def register_algorithm(self, alg_id: str, alg_obj: Algorithm) ->None:
+    @staticmethod
+    def _get_default_options() -> dict[str, bool]:
+        return {"verify_signature": True}
+
+    def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
         """
         Registers a new Algorithm for use when creating and verifying tokens.
         """
-        pass
+        if alg_id in self._algorithms:
+            raise ValueError("Algorithm already has a handler.")
+
+        if not isinstance(alg_obj, Algorithm):
+            raise TypeError("Object is not of type `Algorithm`")

-    def unregister_algorithm(self, alg_id: str) ->None:
+        self._algorithms[alg_id] = alg_obj
+        self._valid_algs.add(alg_id)
+
+    def unregister_algorithm(self, alg_id: str) -> None:
         """
         Unregisters an Algorithm for use when creating and verifying tokens
         Throws KeyError if algorithm is not registered.
         """
-        pass
+        if alg_id not in self._algorithms:
+            raise KeyError(
+                "The specified algorithm could not be removed"
+                " because it is not registered."
+            )
+
+        del self._algorithms[alg_id]
+        self._valid_algs.remove(alg_id)

-    def get_algorithms(self) ->list[str]:
+    def get_algorithms(self) -> list[str]:
         """
         Returns a list of supported values for the 'alg' parameter.
         """
-        pass
+        return list(self._valid_algs)

-    def get_algorithm_by_name(self, alg_name: str) ->Algorithm:
+    def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
         """
         For a given string name, return the matching Algorithm object.

@@ -53,15 +91,231 @@ class PyJWS:

         >>> jws_obj.get_algorithm_by_name("RS256")
         """
-        pass
+        try:
+            return self._algorithms[alg_name]
+        except KeyError as e:
+            if not has_crypto and alg_name in requires_cryptography:
+                raise NotImplementedError(
+                    f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
+                ) from e
+            raise NotImplementedError("Algorithm not supported") from e
+
+    def encode(
+        self,
+        payload: bytes,
+        key: AllowedPrivateKeys | str | bytes,
+        algorithm: str | None = "HS256",
+        headers: dict[str, Any] | None = None,
+        json_encoder: type[json.JSONEncoder] | None = None,
+        is_payload_detached: bool = False,
+        sort_headers: bool = True,
+    ) -> str:
+        segments = []
+
+        # declare a new var to narrow the type for type checkers
+        algorithm_: str = algorithm if algorithm is not None else "none"
+
+        # Prefer headers values if present to function parameters.
+        if headers:
+            headers_alg = headers.get("alg")
+            if headers_alg:
+                algorithm_ = headers["alg"]
+
+            headers_b64 = headers.get("b64")
+            if headers_b64 is False:
+                is_payload_detached = True
+
+        # Header
+        header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}
+
+        if headers:
+            self._validate_headers(headers)
+            header.update(headers)
+
+        if not header["typ"]:
+            del header["typ"]
+
+        if is_payload_detached:
+            header["b64"] = False
+        elif "b64" in header:
+            # True is the standard value for b64, so no need for it
+            del header["b64"]
+
+        json_header = json.dumps(
+            header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
+        ).encode()
+
+        segments.append(base64url_encode(json_header))
+
+        if is_payload_detached:
+            msg_payload = payload
+        else:
+            msg_payload = base64url_encode(payload)
+        segments.append(msg_payload)
+
+        # Segments
+        signing_input = b".".join(segments)
+
+        alg_obj = self.get_algorithm_by_name(algorithm_)
+        key = alg_obj.prepare_key(key)
+        signature = alg_obj.sign(signing_input, key)
+
+        segments.append(base64url_encode(signature))
+
+        # Don't put the payload content inside the encoded token when detached
+        if is_payload_detached:
+            segments[1] = b""
+        encoded_string = b".".join(segments)
+
+        return encoded_string.decode("utf-8")

-    def get_unverified_header(self, jwt: (str | bytes)) ->dict[str, Any]:
+    def decode_complete(
+        self,
+        jwt: str | bytes,
+        key: AllowedPublicKeys | str | bytes = "",
+        algorithms: list[str] | None = None,
+        options: dict[str, Any] | None = None,
+        detached_payload: bytes | None = None,
+        **kwargs,
+    ) -> dict[str, Any]:
+        if kwargs:
+            warnings.warn(
+                "passing additional kwargs to decode_complete() is deprecated "
+                "and will be removed in pyjwt version 3. "
+                f"Unsupported kwargs: {tuple(kwargs.keys())}",
+                RemovedInPyjwt3Warning,
+            )
+        if options is None:
+            options = {}
+        merged_options = {**self.options, **options}
+        verify_signature = merged_options["verify_signature"]
+
+        if verify_signature and not algorithms:
+            raise DecodeError(
+                'It is required that you pass in a value for the "algorithms" argument when calling decode().'
+            )
+
+        payload, signing_input, header, signature = self._load(jwt)
+
+        if header.get("b64", True) is False:
+            if detached_payload is None:
+                raise DecodeError(
+                    'It is required that you pass in a value for the "detached_payload" argument to decode a message having the b64 header set to false.'
+                )
+            payload = detached_payload
+            signing_input = b".".join([signing_input.rsplit(b".", 1)[0], payload])
+
+        if verify_signature:
+            self._verify_signature(signing_input, header, signature, key, algorithms)
+
+        return {
+            "payload": payload,
+            "header": header,
+            "signature": signature,
+        }
+
+    def decode(
+        self,
+        jwt: str | bytes,
+        key: AllowedPublicKeys | str | bytes = "",
+        algorithms: list[str] | None = None,
+        options: dict[str, Any] | None = None,
+        detached_payload: bytes | None = None,
+        **kwargs,
+    ) -> Any:
+        if kwargs:
+            warnings.warn(
+                "passing additional kwargs to decode() is deprecated "
+                "and will be removed in pyjwt version 3. "
+                f"Unsupported kwargs: {tuple(kwargs.keys())}",
+                RemovedInPyjwt3Warning,
+            )
+        decoded = self.decode_complete(
+            jwt, key, algorithms, options, detached_payload=detached_payload
+        )
+        return decoded["payload"]
+
+    def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
         """Returns back the JWT header parameters as a dict()

         Note: The signature is not verified so the header parameters
         should not be fully trusted until signature verification is complete
         """
-        pass
+        headers = self._load(jwt)[2]
+        self._validate_headers(headers)
+
+        return headers
+
+    def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
+        if isinstance(jwt, str):
+            jwt = jwt.encode("utf-8")
+
+        if not isinstance(jwt, bytes):
+            raise DecodeError(f"Invalid token type. Token must be a {bytes}")
+
+        try:
+            signing_input, crypto_segment = jwt.rsplit(b".", 1)
+            header_segment, payload_segment = signing_input.split(b".", 1)
+        except ValueError as err:
+            raise DecodeError("Not enough segments") from err
+
+        try:
+            header_data = base64url_decode(header_segment)
+        except (TypeError, binascii.Error) as err:
+            raise DecodeError("Invalid header padding") from err
+
+        try:
+            header = json.loads(header_data)
+        except ValueError as e:
+            raise DecodeError(f"Invalid header string: {e}") from e
+
+        if not isinstance(header, dict):
+            raise DecodeError("Invalid header string: must be a json object")
+
+        try:
+            payload = base64url_decode(payload_segment)
+        except (TypeError, binascii.Error) as err:
+            raise DecodeError("Invalid payload padding") from err
+
+        try:
+            signature = base64url_decode(crypto_segment)
+        except (TypeError, binascii.Error) as err:
+            raise DecodeError("Invalid crypto padding") from err
+
+        return (payload, signing_input, header, signature)
+
+    def _verify_signature(
+        self,
+        signing_input: bytes,
+        header: dict[str, Any],
+        signature: bytes,
+        key: AllowedPublicKeys | str | bytes = "",
+        algorithms: list[str] | None = None,
+    ) -> None:
+        try:
+            alg = header["alg"]
+        except KeyError:
+            raise InvalidAlgorithmError("Algorithm not specified")
+
+        if not alg or (algorithms is not None and alg not in algorithms):
+            raise InvalidAlgorithmError("The specified alg value is not allowed")
+
+        try:
+            alg_obj = self.get_algorithm_by_name(alg)
+        except NotImplementedError as e:
+            raise InvalidAlgorithmError("Algorithm not supported") from e
+        prepared_key = alg_obj.prepare_key(key)
+
+        if not alg_obj.verify(signing_input, prepared_key, signature):
+            raise InvalidSignatureError("Signature verification failed")
+
+    def _validate_headers(self, headers: dict[str, Any]) -> None:
+        if "kid" in headers:
+            self._validate_kid(headers["kid"])
+
+    def _validate_kid(self, kid: Any) -> None:
+        if not isinstance(kid, str):
+            raise InvalidTokenError("Key ID header parameter must be a string")


 _jws_global_obj = PyJWS()
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index 961d144..48d739a 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -1,37 +1,172 @@
 from __future__ import annotations
+
 import json
 import warnings
 from calendar import timegm
 from collections.abc import Iterable
 from datetime import datetime, timedelta, timezone
 from typing import TYPE_CHECKING, Any
+
 from . import api_jws
-from .exceptions import DecodeError, ExpiredSignatureError, ImmatureSignatureError, InvalidAudienceError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError
+from .exceptions import (
+    DecodeError,
+    ExpiredSignatureError,
+    ImmatureSignatureError,
+    InvalidAudienceError,
+    InvalidIssuedAtError,
+    InvalidIssuerError,
+    MissingRequiredClaimError,
+)
 from .warnings import RemovedInPyjwt3Warning
+
 if TYPE_CHECKING:
     from .algorithms import AllowedPrivateKeys, AllowedPublicKeys


 class PyJWT:
-
-    def __init__(self, options: (dict[str, Any] | None)=None) ->None:
+    def __init__(self, options: dict[str, Any] | None = None) -> None:
         if options is None:
             options = {}
-        self.options: dict[str, Any] = {**self._get_default_options(), **
-            options}
+        self.options: dict[str, Any] = {**self._get_default_options(), **options}
+
+    @staticmethod
+    def _get_default_options() -> dict[str, bool | list[str]]:
+        return {
+            "verify_signature": True,
+            "verify_exp": True,
+            "verify_nbf": True,
+            "verify_iat": True,
+            "verify_aud": True,
+            "verify_iss": True,
+            "require": [],
+        }
+
+    def encode(
+        self,
+        payload: dict[str, Any],
+        key: AllowedPrivateKeys | str | bytes,
+        algorithm: str | None = "HS256",
+        headers: dict[str, Any] | None = None,
+        json_encoder: type[json.JSONEncoder] | None = None,
+        sort_headers: bool = True,
+    ) -> str:
+        # Check that we get a dict
+        if not isinstance(payload, dict):
+            raise TypeError(
+                "Expecting a dict object, as JWT only supports "
+                "JSON objects as payloads."
+            )

-    def _encode_payload(self, payload: dict[str, Any], headers: (dict[str,
-        Any] | None)=None, json_encoder: (type[json.JSONEncoder] | None)=None
-        ) ->bytes:
+        # Payload
+        payload = payload.copy()
+        for time_claim in ["exp", "iat", "nbf"]:
+            # Convert datetime to a intDate value in known time-format claims
+            if isinstance(payload.get(time_claim), datetime):
+                payload[time_claim] = timegm(payload[time_claim].utctimetuple())
+
+        json_payload = self._encode_payload(
+            payload,
+            headers=headers,
+            json_encoder=json_encoder,
+        )
+
+        return api_jws.encode(
+            json_payload,
+            key,
+            algorithm,
+            headers,
+            json_encoder,
+            sort_headers=sort_headers,
+        )
+
+    def _encode_payload(
+        self,
+        payload: dict[str, Any],
+        headers: dict[str, Any] | None = None,
+        json_encoder: type[json.JSONEncoder] | None = None,
+    ) -> bytes:
         """
         Encode a given payload to the bytes to be signed.

         This method is intended to be overridden by subclasses that need to
         encode the payload in a different way, e.g. compress the payload.
         """
-        pass
+        return json.dumps(
+            payload,
+            separators=(",", ":"),
+            cls=json_encoder,
+        ).encode("utf-8")
+
+    def decode_complete(
+        self,
+        jwt: str | bytes,
+        key: AllowedPublicKeys | str | bytes = "",
+        algorithms: list[str] | None = None,
+        options: dict[str, Any] | None = None,
+        # deprecated arg, remove in pyjwt3
+        verify: bool | None = None,
+        # could be used as passthrough to api_jws, consider removal in pyjwt3
+        detached_payload: bytes | None = None,
+        # passthrough arguments to _validate_claims
+        # consider putting in options
+        audience: str | Iterable[str] | None = None,
+        issuer: str | None = None,
+        leeway: float | timedelta = 0,
+        # kwargs
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        if kwargs:
+            warnings.warn(
+                "passing additional kwargs to decode_complete() is deprecated "
+                "and will be removed in pyjwt version 3. "
+                f"Unsupported kwargs: {tuple(kwargs.keys())}",
+                RemovedInPyjwt3Warning,
+            )
+        options = dict(options or {})  # shallow-copy or initialize an empty dict
+        options.setdefault("verify_signature", True)
+
+        # If the user has set the legacy `verify` argument, and it doesn't match
+        # what the relevant `options` entry for the argument is, inform the user
+        # that they're likely making a mistake.
+        if verify is not None and verify != options["verify_signature"]:
+            warnings.warn(
+                "The `verify` argument to `decode` does nothing in PyJWT 2.0 and newer. "
+                "The equivalent is setting `verify_signature` to False in the `options` dictionary. "
+                "This invocation has a mismatch between the kwarg and the option entry.",
+                category=DeprecationWarning,
+            )
+
+        if not options["verify_signature"]:
+            options.setdefault("verify_exp", False)
+            options.setdefault("verify_nbf", False)
+            options.setdefault("verify_iat", False)
+            options.setdefault("verify_aud", False)
+            options.setdefault("verify_iss", False)

-    def _decode_payload(self, decoded: dict[str, Any]) ->Any:
+        if options["verify_signature"] and not algorithms:
+            raise DecodeError(
+                'It is required that you pass in a value for the "algorithms" argument when calling decode().'
+            )
+
+        decoded = api_jws.decode_complete(
+            jwt,
+            key=key,
+            algorithms=algorithms,
+            options=options,
+            detached_payload=detached_payload,
+        )
+
+        payload = self._decode_payload(decoded)
+
+        merged_options = {**self.options, **options}
+        self._validate_claims(
+            payload, merged_options, audience=audience, issuer=issuer, leeway=leeway
+        )
+
+        decoded["payload"] = payload
+        return decoded
+
+    def _decode_payload(self, decoded: dict[str, Any]) -> Any:
         """
         Decode the payload from a JWS dictionary (payload, signature, header).

@@ -39,7 +174,196 @@ class PyJWT:
         decode the payload in a different way, e.g. decompress compressed
         payloads.
         """
-        pass
+        try:
+            payload = json.loads(decoded["payload"])
+        except ValueError as e:
+            raise DecodeError(f"Invalid payload string: {e}")
+        if not isinstance(payload, dict):
+            raise DecodeError("Invalid payload string: must be a json object")
+        return payload
+
+    def decode(
+        self,
+        jwt: str | bytes,
+        key: AllowedPublicKeys | str | bytes = "",
+        algorithms: list[str] | None = None,
+        options: dict[str, Any] | None = None,
+        # deprecated arg, remove in pyjwt3
+        verify: bool | None = None,
+        # could be used as passthrough to api_jws, consider removal in pyjwt3
+        detached_payload: bytes | None = None,
+        # passthrough arguments to _validate_claims
+        # consider putting in options
+        audience: str | Iterable[str] | None = None,
+        issuer: str | None = None,
+        leeway: float | timedelta = 0,
+        # kwargs
+        **kwargs: Any,
+    ) -> Any:
+        if kwargs:
+            warnings.warn(
+                "passing additional kwargs to decode() is deprecated "
+                "and will be removed in pyjwt version 3. "
+                f"Unsupported kwargs: {tuple(kwargs.keys())}",
+                RemovedInPyjwt3Warning,
+            )
+        decoded = self.decode_complete(
+            jwt,
+            key,
+            algorithms,
+            options,
+            verify=verify,
+            detached_payload=detached_payload,
+            audience=audience,
+            issuer=issuer,
+            leeway=leeway,
+        )
+        return decoded["payload"]
+
+    def _validate_claims(
+        self,
+        payload: dict[str, Any],
+        options: dict[str, Any],
+        audience=None,
+        issuer=None,
+        leeway: float | timedelta = 0,
+    ) -> None:
+        if isinstance(leeway, timedelta):
+            leeway = leeway.total_seconds()
+
+        if audience is not None and not isinstance(audience, (str, Iterable)):
+            raise TypeError("audience must be a string, iterable or None")
+
+        self._validate_required_claims(payload, options)
+
+        now = datetime.now(tz=timezone.utc).timestamp()
+
+        if "iat" in payload and options["verify_iat"]:
+            self._validate_iat(payload, now, leeway)
+
+        if "nbf" in payload and options["verify_nbf"]:
+            self._validate_nbf(payload, now, leeway)
+
+        if "exp" in payload and options["verify_exp"]:
+            self._validate_exp(payload, now, leeway)
+
+        if options["verify_iss"]:
+            self._validate_iss(payload, issuer)
+
+        if options["verify_aud"]:
+            self._validate_aud(
+                payload, audience, strict=options.get("strict_aud", False)
+            )
+
+    def _validate_required_claims(
+        self,
+        payload: dict[str, Any],
+        options: dict[str, Any],
+    ) -> None:
+        for claim in options["require"]:
+            if payload.get(claim) is None:
+                raise MissingRequiredClaimError(claim)
+
+    def _validate_iat(
+        self,
+        payload: dict[str, Any],
+        now: float,
+        leeway: float,
+    ) -> None:
+        try:
+            iat = int(payload["iat"])
+        except ValueError:
+            raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.")
+        if iat > (now + leeway):
+            raise ImmatureSignatureError("The token is not yet valid (iat)")
+
+    def _validate_nbf(
+        self,
+        payload: dict[str, Any],
+        now: float,
+        leeway: float,
+    ) -> None:
+        try:
+            nbf = int(payload["nbf"])
+        except ValueError:
+            raise DecodeError("Not Before claim (nbf) must be an integer.")
+
+        if nbf > (now + leeway):
+            raise ImmatureSignatureError("The token is not yet valid (nbf)")
+
+    def _validate_exp(
+        self,
+        payload: dict[str, Any],
+        now: float,
+        leeway: float,
+    ) -> None:
+        try:
+            exp = int(payload["exp"])
+        except ValueError:
+            raise DecodeError("Expiration Time claim (exp) must be an" " integer.")
+
+        if exp <= (now - leeway):
+            raise ExpiredSignatureError("Signature has expired")
+
+    def _validate_aud(
+        self,
+        payload: dict[str, Any],
+        audience: str | Iterable[str] | None,
+        *,
+        strict: bool = False,
+    ) -> None:
+        if audience is None:
+            if "aud" not in payload or not payload["aud"]:
+                return
+            # Application did not specify an audience, but
+            # the token has the 'aud' claim
+            raise InvalidAudienceError("Invalid audience")
+
+        if "aud" not in payload or not payload["aud"]:
+            # Application specified an audience, but it could not be
+            # verified since the token does not contain a claim.
+            raise MissingRequiredClaimError("aud")
+
+        audience_claims = payload["aud"]
+
+        # In strict mode, we forbid list matching: the supplied audience
+        # must be a string, and it must exactly match the audience claim.
+        if strict:
+            # Only a single audience is allowed in strict mode.
+            if not isinstance(audience, str):
+                raise InvalidAudienceError("Invalid audience (strict)")
+
+            # Only a single audience claim is allowed in strict mode.
+            if not isinstance(audience_claims, str):
+                raise InvalidAudienceError("Invalid claim format in token (strict)")
+
+            if audience != audience_claims:
+                raise InvalidAudienceError("Audience doesn't match (strict)")
+
+            return
+
+        if isinstance(audience_claims, str):
+            audience_claims = [audience_claims]
+        if not isinstance(audience_claims, list):
+            raise InvalidAudienceError("Invalid claim format in token")
+        if any(not isinstance(c, str) for c in audience_claims):
+            raise InvalidAudienceError("Invalid claim format in token")
+
+        if isinstance(audience, str):
+            audience = [audience]
+
+        if all(aud not in audience_claims for aud in audience):
+            raise InvalidAudienceError("Audience doesn't match")
+
+    def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None:
+        if issuer is None:
+            return
+
+        if "iss" not in payload:
+            raise MissingRequiredClaimError("iss")
+
+        if payload["iss"] != issuer:
+            raise InvalidIssuerError("Invalid issuer")


 _jwt_global_obj = PyJWT()
diff --git a/jwt/exceptions.py b/jwt/exceptions.py
index 43e9576..8ac6ecf 100644
--- a/jwt/exceptions.py
+++ b/jwt/exceptions.py
@@ -2,6 +2,7 @@ class PyJWTError(Exception):
     """
     Base class for all exceptions
     """
+
     pass


@@ -46,11 +47,10 @@ class InvalidAlgorithmError(InvalidTokenError):


 class MissingRequiredClaimError(InvalidTokenError):
+    """Raised when the token payload lacks a claim the caller requires.
+
+    ``claim`` holds the name of the missing claim (e.g. "aud" or "iss").
+    """
-
-    def __init__(self, claim: str) ->None:
+    def __init__(self, claim: str) -> None:
         self.claim = claim

-    def __str__(self) ->str:
+    def __str__(self) -> str:
         return f'Token is missing the "{self.claim}" claim'


diff --git a/jwt/help.py b/jwt/help.py
index c6b2173..80b0ca5 100644
--- a/jwt/help.py
+++ b/jwt/help.py
@@ -2,26 +2,63 @@ import json
 import platform
 import sys
 from typing import Dict
+
 from . import __version__ as pyjwt_version
+
 try:
     import cryptography
+
     cryptography_version = cryptography.__version__
 except ModuleNotFoundError:
-    cryptography_version = ''
+    # cryptography is an optional dependency; report an empty version
+    # string when it is not installed.
+    cryptography_version = ""


-def info() ->Dict[str, Dict[str, str]]:
+def info() -> Dict[str, Dict[str, str]]:
     """
     Generate information for a bug report.
     Based on the requests package help utility module.
     """
-    pass
+    # Returns four sections -- platform, implementation, cryptography and
+    # pyjwt -- each a small dict of string fields.
+    #
+    # platform.system()/release() are guarded so an OSError degrades to
+    # "Unknown" rather than aborting the whole report.
+    try:
+        platform_info = {
+            "system": platform.system(),
+            "release": platform.release(),
+        }
+    except OSError:
+        platform_info = {"system": "Unknown", "release": "Unknown"}
+
+    implementation = platform.python_implementation()
+
+    if implementation == "CPython":
+        implementation_version = platform.python_version()
+    elif implementation == "PyPy":
+        # sys.pypy_version_info exists only on PyPy (hence the type: ignore).
+        pypy_version_info = sys.pypy_version_info  # type: ignore[attr-defined]
+        implementation_version = (
+            f"{pypy_version_info.major}."
+            f"{pypy_version_info.minor}."
+            f"{pypy_version_info.micro}"
+        )
+        # Append the release level (e.g. "beta") for non-final PyPy builds.
+        if pypy_version_info.releaselevel != "final":
+            implementation_version = "".join(
+                [implementation_version, pypy_version_info.releaselevel]
+            )
+    else:
+        implementation_version = "Unknown"
+
+    return {
+        "platform": platform_info,
+        "implementation": {
+            "name": implementation,
+            "version": implementation_version,
+        },
+        "cryptography": {"version": cryptography_version},
+        "pyjwt": {"version": pyjwt_version},
+    }


-def main() ->None:
+def main() -> None:
     """Pretty-print the bug information as JSON."""
-    pass
+    # sort_keys and indent make the output deterministic and readable.
+    print(json.dumps(info(), sort_keys=True, indent=2))


-if __name__ == '__main__':
+if __name__ == "__main__":
+    # Script entry point, e.g. ``python -m jwt.help``.
     main()
diff --git a/jwt/jwk_set_cache.py b/jwt/jwk_set_cache.py
index 6b2abcd..2432563 100644
--- a/jwt/jwk_set_cache.py
+++ b/jwt/jwk_set_cache.py
@@ -1,10 +1,31 @@
 import time
 from typing import Optional
+
 from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp


 class JWKSetCache:
+    """Cache for a single fetched JWK set with time-based expiry.
+
+    ``lifespan`` is compared, in seconds, against time.monotonic(); a value
+    of -1 or below disables expiry.
+    """
-
-    def __init__(self, lifespan: int) ->None:
+    def __init__(self, lifespan: int) -> None:
         self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
         self.lifespan = lifespan
+
+    def put(self, jwk_set: PyJWKSet) -> None:
+        """Store ``jwk_set``; passing None clears the cache instead."""
+        # NOTE(review): PyJWKClient.fetch_data passes the decoded JSON document
+        # (or None after a failed fetch) here, not a PyJWKSet, so this
+        # annotation looks too narrow -- confirm and widen (e.g. Optional[Any]).
+        if jwk_set is not None:
+            self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
+        else:
+            # clear cache
+            self.jwk_set_with_timestamp = None
+
+    def get(self) -> Optional[PyJWKSet]:
+        """Return the cached set, or None if nothing is cached or it expired."""
+        if self.jwk_set_with_timestamp is None or self.is_expired():
+            return None
+
+        return self.jwk_set_with_timestamp.get_jwk_set()
+
+    def is_expired(self) -> bool:
+        """True when an entry exists, expiry is enabled, and it outlived lifespan."""
+        # Assumes get_timestamp() is a time.monotonic() reading taken when the
+        # entry was stored -- TODO confirm in PyJWTSetWithTimestamp.
+        return (
+            self.jwk_set_with_timestamp is not None
+            and self.lifespan > -1
+            and time.monotonic()
+            > self.jwk_set_with_timestamp.get_timestamp() + self.lifespan
+        )
diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py
index 6cc0010..f19b10a 100644
--- a/jwt/jwks_client.py
+++ b/jwt/jwks_client.py
@@ -4,6 +4,7 @@ from functools import lru_cache
 from ssl import SSLContext
 from typing import Any, Dict, List, Optional
 from urllib.error import URLError
+
 from .api_jwk import PyJWK, PyJWKSet
 from .api_jwt import decode_complete as decode_token
 from .exceptions import PyJWKClientConnectionError, PyJWKClientError
@@ -11,11 +12,17 @@ from .jwk_set_cache import JWKSetCache


 class PyJWKClient:
-
-    def __init__(self, uri: str, cache_keys: bool=False, max_cached_keys:
-        int=16, cache_jwk_set: bool=True, lifespan: int=300, headers:
-        Optional[Dict[str, Any]]=None, timeout: int=30, ssl_context:
-        Optional[SSLContext]=None):
+    def __init__(
+        self,
+        uri: str,
+        cache_keys: bool = False,
+        max_cached_keys: int = 16,
+        cache_jwk_set: bool = True,
+        lifespan: int = 300,
+        headers: Optional[Dict[str, Any]] = None,
+        timeout: int = 30,
+        ssl_context: Optional[SSLContext] = None,
+    ):
+        """Create a client for the JWKS endpoint at ``uri``.
+
+        ``cache_jwk_set``/``lifespan`` control caching of the fetched key set
+        (lifespan in seconds, must be > 0); ``cache_keys``/``max_cached_keys``
+        memoize get_signing_key lookups by kid. ``headers``, ``timeout`` and
+        ``ssl_context`` are used for the HTTP request in fetch_data.
+
+        Raises PyJWKClientError if set caching is enabled with lifespan <= 0.
+        """
         if headers is None:
             headers = {}
         self.uri = uri
@@ -23,14 +30,95 @@ class PyJWKClient:
         self.headers = headers
         self.timeout = timeout
         self.ssl_context = ssl_context
+
         if cache_jwk_set:
+            # Init jwt set cache with default or given lifespan.
+            # Default lifespan is 300 seconds (5 minutes).
             if lifespan <= 0:
                 raise PyJWKClientError(
                     f'Lifespan must be greater than 0, the input is "{lifespan}"'
-                    )
+                )
             self.jwk_set_cache = JWKSetCache(lifespan)
         else:
             self.jwk_set_cache = None
+
         if cache_keys:
-            self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.
-                get_signing_key)
+            # Cache signing keys
+            # Ignore mypy (https://github.com/python/mypy/issues/2427)
+            # Wrapping the bound method and assigning it on the instance gives
+            # each client its own LRU cache keyed by kid.
+            self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key)  # type: ignore
+
+    def fetch_data(self) -> Any:
+        """Download and JSON-decode the JWKS document at ``self.uri``.
+
+        When set caching is enabled, the result is also stored in the cache.
+        If the fetch or decode fails, None is stored instead, which clears any
+        previously cached set.
+
+        Raises PyJWKClientConnectionError (chained to the original error) on
+        URLError or TimeoutError.
+        """
+        jwk_set: Any = None
+        try:
+            r = urllib.request.Request(url=self.uri, headers=self.headers)
+            with urllib.request.urlopen(
+                r, timeout=self.timeout, context=self.ssl_context
+            ) as response:
+                jwk_set = json.load(response)
+        except (URLError, TimeoutError) as e:
+            # Chain explicitly so the network error is recorded as the cause.
+            raise PyJWKClientConnectionError(
+                f'Fail to fetch data from the url, err: "{e}"'
+            ) from e
+        else:
+            return jwk_set
+        finally:
+            # Runs on success and failure alike; on failure jwk_set is still
+            # None, which JWKSetCache.put treats as "clear the cache".
+            if self.jwk_set_cache is not None:
+                self.jwk_set_cache.put(jwk_set)
+
+    def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
+        """Return the JWK set, served from the cache unless ``refresh`` is set.
+
+        Falls back to fetch_data() on a cache miss. Raises PyJWKClientError
+        when the endpoint's payload is not a JSON object.
+        """
+        cache = self.jwk_set_cache
+        payload = None
+        if cache is not None and not refresh:
+            payload = cache.get()
+        if payload is None:
+            payload = self.fetch_data()
+        if isinstance(payload, dict):
+            return PyJWKSet.from_dict(payload)
+        raise PyJWKClientError("The JWKS endpoint did not return a JSON object")
+
+    def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
+        """Return the keys usable for signature verification.
+
+        A key qualifies when its ``public_key_use`` is "sig" or None and its
+        key id is truthy. Raises PyJWKClientError when no key qualifies.
+        """
+        key_set = self.get_jwk_set(refresh)
+        usable = [
+            candidate
+            for candidate in key_set.keys
+            if candidate.public_key_use in ("sig", None) and candidate.key_id
+        ]
+        if not usable:
+            raise PyJWKClientError("The JWKS endpoint did not contain any signing keys")
+        return usable
+
+    def get_signing_key(self, kid: str) -> PyJWK:
+        """Return the signing key whose key id equals ``kid``.
+
+        Searches the (possibly cached) key set first, then refetches once with
+        refresh=True before giving up with PyJWKClientError.
+        """
+        for refresh in (False, True):
+            candidates = self.get_signing_keys(refresh=refresh)
+            found = self.match_kid(candidates, kid)
+            if found:
+                return found
+        raise PyJWKClientError(f'Unable to find a signing key that matches: "{kid}"')
+
+    def get_signing_key_from_jwt(self, token: str) -> PyJWK:
+        """Look up the signing key named by the ``kid`` header of ``token``.
+
+        The token is decoded without signature verification only to read its
+        header.
+        """
+        decoded = decode_token(token, options={"verify_signature": False})
+        kid = decoded["header"].get("kid")
+        return self.get_signing_key(kid)
+
+    @staticmethod
+    def match_kid(signing_keys: List[PyJWK], kid: str) -> Optional[PyJWK]:
+        """Return the first key in ``signing_keys`` whose key_id equals ``kid``.
+
+        Returns None when no key matches.
+        """
+        return next((key for key in signing_keys if key.key_id == kid), None)
diff --git a/jwt/types.py b/jwt/types.py
index 5aa6306..7d99352 100644
--- a/jwt/types.py
+++ b/jwt/types.py
@@ -1,3 +1,5 @@
 from typing import Any, Callable, Dict
+
+# Dictionary form of a JWK: string member names mapped to arbitrary values.
 JWKDict = Dict[str, Any]
+
+# Any callable; presumably a hashlib-style constructor such as hashlib.sha256
+# given the name -- confirm against the callers in algorithms.py.
 HashlibHash = Callable[..., Any]
diff --git a/jwt/utils.py b/jwt/utils.py
index 1e1c20d..81c5ee4 100644
--- a/jwt/utils.py
+++ b/jwt/utils.py
@@ -2,20 +2,155 @@ import base64
 import binascii
 import re
 from typing import Union
+
 try:
     from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
-    from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, encode_dss_signature
+    from cryptography.hazmat.primitives.asymmetric.utils import (
+        decode_dss_signature,
+        encode_dss_signature,
+    )
 except ModuleNotFoundError:
+    # cryptography is optional. Without it the names above stay undefined, so
+    # der_to_raw_signature / raw_to_der_signature raise NameError if called.
     pass
-_PEMS = {b'CERTIFICATE', b'TRUSTED CERTIFICATE', b'PRIVATE KEY',
-    b'PUBLIC KEY', b'ENCRYPTED PRIVATE KEY', b'OPENSSH PRIVATE KEY',
-    b'DSA PRIVATE KEY', b'RSA PRIVATE KEY', b'RSA PUBLIC KEY',
-    b'EC PRIVATE KEY', b'DH PARAMETERS', b'NEW CERTIFICATE REQUEST',
-    b'CERTIFICATE REQUEST', b'SSH2 PUBLIC KEY',
-    b'SSH2 ENCRYPTED PRIVATE KEY', b'X509 CRL'}
-_PEM_RE = re.compile(b'----[- ]BEGIN (' + b'|'.join(_PEMS) +
-    b')[- ]----\r?\n.+?\r?\n----[- ]END \\1[- ]----\r?\n?', re.DOTALL)
-_CERT_SUFFIX = b'-cert-v01@openssh.com'
-_SSH_PUBKEY_RC = re.compile(b'\\A(\\S+)[ \\t]+(\\S+)')
-_SSH_KEY_FORMATS = [b'ssh-ed25519', b'ssh-rsa', b'ssh-dss',
-    b'ecdsa-sha2-nistp256', b'ecdsa-sha2-nistp384', b'ecdsa-sha2-nistp521']
+
+
+def force_bytes(value: Union[bytes, str]) -> bytes:
+    """Return ``value`` as bytes, UTF-8 encoding it when it is a str.
+
+    Raises TypeError for any other type (including bytearray).
+    """
+    if isinstance(value, bytes):
+        return value
+    if isinstance(value, str):
+        return value.encode("utf-8")
+    raise TypeError("Expected a string value")
+
+
+def base64url_decode(input: Union[bytes, str]) -> bytes:
+    """Decode unpadded base64url data (str or bytes).
+
+    The ``=`` padding that JWT encoding strips is restored before decoding.
+    """
+    data = force_bytes(input)
+    missing_padding = -len(data) % 4
+    return base64.urlsafe_b64decode(data + b"=" * missing_padding)
+
+
+def base64url_encode(input: bytes) -> bytes:
+    """Encode ``input`` as base64url with the trailing ``=`` padding removed."""
+    encoded = base64.urlsafe_b64encode(input)
+    return encoded.rstrip(b"=")
+
+
+def to_base64url_uint(val: int) -> bytes:
+    """Encode a non-negative integer as unpadded big-endian base64url.
+
+    Zero is encoded as a single zero byte. Raises ValueError for negatives.
+    """
+    if val < 0:
+        raise ValueError("Must be a positive integer")
+    # bytes_from_int(0) is empty; substitute one zero byte.
+    return base64url_encode(bytes_from_int(val) or b"\x00")
+
+
+def from_base64url_uint(val: Union[bytes, str]) -> int:
+    """Decode base64url data (str or bytes) into an unsigned big-endian int."""
+    # base64url_decode already normalizes str input via force_bytes.
+    raw = base64url_decode(val)
+    return int.from_bytes(raw, "big")
+
+
+def number_to_bytes(num: int, num_bytes: int) -> bytes:
+    """Return ``num`` as big-endian bytes, zero-padded to ``num_bytes`` bytes."""
+    hex_digits = f"{num:0{2 * num_bytes}x}"
+    return binascii.a2b_hex(hex_digits.encode("ascii"))
+
+
+def bytes_to_number(string: bytes) -> int:
+    """Interpret ``string`` as an unsigned big-endian integer."""
+    hex_digits = binascii.hexlify(string)
+    return int(hex_digits, 16)
+
+
+def bytes_from_int(val: int) -> bytes:
+    """Return the minimal big-endian unsigned encoding of ``val``.
+
+    Zero encodes as ``b""``. Raises ValueError for negative input.
+    """
+    # Reject negatives explicitly: they have no unsigned encoding, and an
+    # arithmetic right-shift loop over them would never reach zero.
+    if val < 0:
+        raise ValueError("Must be a non-negative integer")
+    byte_length = (val.bit_length() + 7) // 8
+    return val.to_bytes(byte_length, "big", signed=False)
+
+
+def der_to_raw_signature(der_sig: bytes, curve: "EllipticCurve") -> bytes:
+    """Convert a DER-encoded ECDSA signature to fixed-width raw ``r || s``.
+
+    Each integer is zero-padded to the curve's key size in bytes.
+    """
+    width = (curve.key_size + 7) // 8
+    r, s = decode_dss_signature(der_sig)
+    return b"".join(number_to_bytes(part, width) for part in (r, s))
+
+
+def raw_to_der_signature(raw_sig: bytes, curve: "EllipticCurve") -> bytes:
+    """Convert a raw ``r || s`` ECDSA signature back to DER encoding.
+
+    Raises ValueError unless ``raw_sig`` is exactly twice the curve's key
+    size in bytes.
+    """
+    half = (curve.key_size + 7) // 8
+    if len(raw_sig) != half * 2:
+        raise ValueError("Invalid signature")
+    r_part, s_part = raw_sig[:half], raw_sig[half:]
+    der = encode_dss_signature(bytes_to_number(r_part), bytes_to_number(s_part))
+    return bytes(der)
+
+
+# Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252
+# PEM labels recognized by is_pem_format() (via _PEM_RE below).
+_PEMS = {
+    b"CERTIFICATE",
+    b"TRUSTED CERTIFICATE",
+    b"PRIVATE KEY",
+    b"PUBLIC KEY",
+    b"ENCRYPTED PRIVATE KEY",
+    b"OPENSSH PRIVATE KEY",
+    b"DSA PRIVATE KEY",
+    b"RSA PRIVATE KEY",
+    b"RSA PUBLIC KEY",
+    b"EC PRIVATE KEY",
+    b"DH PARAMETERS",
+    b"NEW CERTIFICATE REQUEST",
+    b"CERTIFICATE REQUEST",
+    b"SSH2 PUBLIC KEY",
+    b"SSH2 ENCRYPTED PRIVATE KEY",
+    b"X509 CRL",
+}
+
+# Matches one armored block: a "-----BEGIN <label>-----" line, a non-empty
+# body, and an END line whose label must repeat the BEGIN label
+# (backreference \1). "----[- ]" also tolerates a space as the fifth
+# character, and "\r?" accepts LF or CRLF line endings. The triple-quoted
+# literal embeds real newline characters, which the pattern matches
+# literally; re.DOTALL lets ".+?" span the multi-line body.
+_PEM_RE = re.compile(
+    b"----[- ]BEGIN ("
+    + b"|".join(_PEMS)
+    + b""")[- ]----\r?
+.+?\r?
+----[- ]END \\1[- ]----\r?\n?""",
+    re.DOTALL,
+)
+
+
+def is_pem_format(key: bytes) -> bool:
+    """Return True if ``key`` contains a PEM-armored block with a known label."""
+    return _PEM_RE.search(key) is not None
+
+
+# Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46
+# Key-type suffix used by OpenSSH certificates (e.g. "ssh-rsa-cert-v01@openssh.com").
+_CERT_SUFFIX = b"-cert-v01@openssh.com"
+# Leading "<key-type> <blob>" tokens of a public-key line; group 1 is the key type.
+_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
+# Key-type names; is_ssh_key() treats an occurrence of any of them as an SSH key.
+_SSH_KEY_FORMATS = [
+    b"ssh-ed25519",
+    b"ssh-rsa",
+    b"ssh-dss",
+    b"ecdsa-sha2-nistp256",
+    b"ecdsa-sha2-nistp384",
+    b"ecdsa-sha2-nistp521",
+]
+
+
+def is_ssh_key(key: bytes) -> bool:
+    """Heuristically decide whether ``key`` looks like an SSH key.
+
+    True if any known key-type name occurs anywhere in ``key``, or if the
+    leading key-type token ends with the OpenSSH certificate suffix.
+    """
+    if any(marker in key for marker in _SSH_KEY_FORMATS):
+        return True
+    leading = _SSH_PUBKEY_RC.match(key)
+    return leading is not None and leading.group(1).endswith(_CERT_SUFFIX)