Back to Claude Sonnet 3.5 - Fill-in summary
Claude Sonnet 3.5 - Fill-in: pyjwt
Pytest summary for test directory `tests`
| status | count |
|---|---|
| passed | 11 |
| failed | 1 |
| xfailed | 1 |
| total | 13 |
| collected | 13 |
Failed and expected-failure (xfail) tests:
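test_utils.py::test_to_base64url_uint[0-AA]

```
inputval = 0, expected = b'AA'

    @pytest.mark.parametrize(
        "inputval,expected",
        [
            (0, b"AA"),
            (1, b"AQ"),
            (255, b"_w"),
            (65537, b"AQAB"),
            (123456789, b"B1vNFQ"),
            pytest.param(-1, "", marks=pytest.mark.xfail(raises=ValueError)),
        ],
    )
    def test_to_base64url_uint(inputval, expected):
        actual = to_base64url_uint(inputval)
>       assert actual == expected
E       AssertionError: assert b'' == b'AA'
E
E         Use -v to get more diff

tests/test_utils.py:19: AssertionError
```

Cause: in the patched `to_base64url_uint`, `(0).bit_length()` is 0, so `val.to_bytes(0, byteorder='big')` yields `b''`, which base64url-encodes to `b''`. The Base64urlUInt encoding requires zero to be represented by a single zero-valued octet, which encodes to `"AA"`. Allocating at least one byte, `max(1, (val.bit_length() + 7) // 8)`, makes this case pass.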
test_utils.py::test_to_base64url_uint[-1-] (xfail: expected failure)

```
inputval = -1, expected = ''

    @pytest.mark.parametrize(
        "inputval,expected",
        [
            (0, b"AA"),
            (1, b"AQ"),
            (255, b"_w"),
            (65537, b"AQAB"),
            (123456789, b"B1vNFQ"),
            pytest.param(-1, "", marks=pytest.mark.xfail(raises=ValueError)),
        ],
    )
    def test_to_base64url_uint(inputval, expected):
>       actual = to_base64url_uint(inputval)

tests/test_utils.py:18:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

val = -1

    def to_base64url_uint(val: int) -> bytes:
        if val < 0:
>           raise ValueError("Must be a positive integer")
E           ValueError: Must be a positive integer

jwt/utils.py:51: ValueError
```

This case is marked `xfail(raises=ValueError)`, so the `ValueError` shown above is the expected behavior. It is counted as "xfailed" in the summary, not as a failure.
Patch diff
diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 15c200a..be04167 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -43,7 +43,28 @@ def get_default_algorithms() ->dict[str, Algorithm]:
"""
Returns the algorithms that are implemented by the library.
"""
- pass
+ default_algorithms = {
+ 'none': NoneAlgorithm(),
+ 'HS256': HMACAlgorithm(HMACAlgorithm.SHA256),
+ 'HS384': HMACAlgorithm(HMACAlgorithm.SHA384),
+ 'HS512': HMACAlgorithm(HMACAlgorithm.SHA512),
+ }
+
+ if has_crypto:
+ default_algorithms.update({
+ 'RS256': RSAAlgorithm(RSAAlgorithm.SHA256),
+ 'RS384': RSAAlgorithm(RSAAlgorithm.SHA384),
+ 'RS512': RSAAlgorithm(RSAAlgorithm.SHA512),
+ 'ES256': ECAlgorithm(ECAlgorithm.SHA256),
+ 'ES384': ECAlgorithm(ECAlgorithm.SHA384),
+ 'ES512': ECAlgorithm(ECAlgorithm.SHA512),
+ 'PS256': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
+ 'PS384': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
+ 'PS512': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
+ 'EdDSA': OKPAlgorithm(),
+ })
+
+ return default_algorithms
class Algorithm(ABC):
@@ -57,7 +78,9 @@ class Algorithm(ABC):
If there is no hash algorithm, raises a NotImplementedError.
"""
- pass
+ if hasattr(self, 'hash_alg'):
+ return self.hash_alg(bytestr).digest()
+ raise NotImplementedError("Hash algorithm not specified for this algorithm.")
@abstractmethod
def prepare_key(self, key: Any) ->Any:
@@ -105,6 +128,28 @@ class NoneAlgorithm(Algorithm):
Placeholder for use when no signing or verification
operations are required.
"""
+ def prepare_key(self, key: Any) ->None:
+ return None
+
+ def sign(self, msg: bytes, key: Any) ->bytes:
+ return b''
+
+ def verify(self, msg: bytes, key: Any, sig: bytes) ->bool:
+ return False
+
+ @staticmethod
+ def to_jwk(key_obj, as_dict: bool=False) ->Union[JWKDict, str]:
+ if as_dict:
+ return {'kty': 'none'}
+ return json.dumps({'kty': 'none'})
+
+ @staticmethod
+ def from_jwk(jwk: Union[str, JWKDict]) ->None:
+ if isinstance(jwk, str):
+ jwk = json.loads(jwk)
+ if not isinstance(jwk, dict) or jwk.get('kty') != 'none':
+ raise InvalidKeyError('Invalid key: not a none key')
+ return None
class HMACAlgorithm(Algorithm):
@@ -119,6 +164,36 @@ class HMACAlgorithm(Algorithm):
def __init__(self, hash_alg: HashlibHash) ->None:
self.hash_alg = hash_alg
+ def prepare_key(self, key: Any) ->bytes:
+ key = force_bytes(key)
+ return key
+
+ def sign(self, msg: bytes, key: Any) ->bytes:
+ key = self.prepare_key(key)
+ return hmac.new(key, msg, self.hash_alg).digest()
+
+ def verify(self, msg: bytes, key: Any, sig: bytes) ->bool:
+ key = self.prepare_key(key)
+ return hmac.compare_digest(sig, self.sign(msg, key))
+
+ @staticmethod
+ def to_jwk(key_obj: bytes, as_dict: bool=False) ->Union[JWKDict, str]:
+ jwk = {
+ 'kty': 'oct',
+ 'k': base64url_encode(key_obj).decode('ascii')
+ }
+ if as_dict:
+ return jwk
+ return json.dumps(jwk)
+
+ @staticmethod
+ def from_jwk(jwk: Union[str, JWKDict]) ->bytes:
+ if isinstance(jwk, str):
+ jwk = json.loads(jwk)
+ if not isinstance(jwk, dict) or jwk.get('kty') != 'oct':
+ raise InvalidKeyError('Invalid key: not an octet key')
+ return base64url_decode(jwk['k'])
+
if has_crypto:
diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index eeb5924..fcc015e 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -30,20 +30,24 @@ class PyJWS:
"""
Registers a new Algorithm for use when creating and verifying tokens.
"""
- pass
+ self._algorithms[alg_id] = alg_obj
+ self._valid_algs.add(alg_id)
def unregister_algorithm(self, alg_id: str) ->None:
"""
Unregisters an Algorithm for use when creating and verifying tokens
Throws KeyError if algorithm is not registered.
"""
- pass
+ if alg_id not in self._algorithms:
+ raise KeyError(f"The algorithm '{alg_id}' is not registered.")
+ del self._algorithms[alg_id]
+ self._valid_algs.remove(alg_id)
def get_algorithms(self) ->list[str]:
"""
Returns a list of supported values for the 'alg' parameter.
"""
- pass
+ return list(self._valid_algs)
def get_algorithm_by_name(self, alg_name: str) ->Algorithm:
"""
@@ -53,7 +57,9 @@ class PyJWS:
>>> jws_obj.get_algorithm_by_name("RS256")
"""
- pass
+ if alg_name not in self._algorithms:
+ raise InvalidAlgorithmError(f"Algorithm '{alg_name}' could not be found")
+ return self._algorithms[alg_name]
def get_unverified_header(self, jwt: (str | bytes)) ->dict[str, Any]:
"""Returns back the JWT header parameters as a dict()
@@ -61,7 +67,20 @@ class PyJWS:
Note: The signature is not verified so the header parameters
should not be fully trusted until signature verification is complete
"""
- pass
+ if isinstance(jwt, str):
+ jwt = jwt.encode('utf-8')
+
+ try:
+ header_segment = jwt.split(b'.')[0]
+ header_data = base64url_decode(header_segment)
+ header = json.loads(header_data.decode('utf-8'))
+ except (ValueError, TypeError, binascii.Error) as e:
+ raise DecodeError(f"Invalid header padding: {str(e)}")
+
+ if not isinstance(header, dict):
+ raise DecodeError("Invalid header string: must be a json object")
+
+ return header
_jws_global_obj = PyJWS()
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index 961d144..7a3148d 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -29,7 +29,12 @@ class PyJWT:
This method is intended to be overridden by subclasses that need to
encode the payload in a different way, e.g. compress the payload.
"""
- pass
+ json_payload = json.dumps(
+ payload,
+ separators=(',', ':'),
+ cls=json_encoder
+ ).encode('utf-8')
+ return json_payload
def _decode_payload(self, decoded: dict[str, Any]) ->Any:
"""
@@ -39,7 +44,11 @@ class PyJWT:
decode the payload in a different way, e.g. decompress compressed
payloads.
"""
- pass
+ try:
+ payload = json.loads(decoded['payload'])
+ except ValueError as e:
+ raise DecodeError('Invalid payload string: %s' % e)
+ return payload
_jwt_global_obj = PyJWT()
diff --git a/jwt/help.py b/jwt/help.py
index c6b2173..b5660a7 100644
--- a/jwt/help.py
+++ b/jwt/help.py
@@ -15,12 +15,26 @@ def info() ->Dict[str, Dict[str, str]]:
Generate information for a bug report.
Based on the requests package help utility module.
"""
- pass
+ return {
+ "platform": {
+ "system": platform.system(),
+ "release": platform.release(),
+ "version": platform.version(),
+ "machine": platform.machine(),
+ "processor": platform.processor(),
+ "python_version": platform.python_version(),
+ "python_implementation": platform.python_implementation(),
+ },
+ "dependencies": {
+ "pyjwt": pyjwt_version,
+ "cryptography": cryptography_version,
+ },
+ }
def main() ->None:
"""Pretty-print the bug information as JSON."""
- pass
+ print(json.dumps(info(), sort_keys=True, indent=2))
if __name__ == '__main__':
diff --git a/jwt/utils.py b/jwt/utils.py
index 1e1c20d..20dfae4 100644
--- a/jwt/utils.py
+++ b/jwt/utils.py
@@ -19,3 +19,62 @@ _CERT_SUFFIX = b'-cert-v01@openssh.com'
_SSH_PUBKEY_RC = re.compile(b'\\A(\\S+)[ \\t]+(\\S+)')
_SSH_KEY_FORMATS = [b'ssh-ed25519', b'ssh-rsa', b'ssh-dss',
b'ecdsa-sha2-nistp256', b'ecdsa-sha2-nistp384', b'ecdsa-sha2-nistp521']
+
+def force_bytes(value: Union[str, bytes]) -> bytes:
+ if isinstance(value, str):
+ return value.encode('utf-8')
+ elif isinstance(value, bytes):
+ return value
+ else:
+ raise TypeError("Expected str or bytes, got %s" % type(value))
+
+def force_unicode(value: Union[str, bytes]) -> str:
+ if isinstance(value, bytes):
+ return value.decode('utf-8')
+ elif isinstance(value, str):
+ return value
+ else:
+ raise TypeError("Expected str or bytes, got %s" % type(value))
+
+def base64url_decode(input: Union[str, bytes]) -> bytes:
+ input = force_bytes(input)
+ rem = len(input) % 4
+ if rem > 0:
+ input += b'=' * (4 - rem)
+ return base64.urlsafe_b64decode(input)
+
+def base64url_encode(input: Union[str, bytes]) -> bytes:
+ return base64.urlsafe_b64encode(force_bytes(input)).rstrip(b'=')
+
+def to_base64url_uint(val: int) -> bytes:
+ if val < 0:
+ raise ValueError("Must be a positive integer")
+ int_bytes = val.to_bytes((val.bit_length() + 7) // 8, byteorder='big')
+ return base64url_encode(int_bytes)
+
+def from_base64url_uint(val: Union[str, bytes]) -> int:
+ int_bytes = base64url_decode(val)
+ return int.from_bytes(int_bytes, byteorder='big')
+
+def merge_dict(original: dict, updates: dict) -> dict:
+ if not updates:
+ return original
+
+ merged = original.copy()
+ for key, value in updates.items():
+ if isinstance(value, dict) and key in merged and isinstance(merged[key], dict):
+ merged[key] = merge_dict(merged[key], value)
+ else:
+ merged[key] = value
+ return merged
+
+def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes:
+ try:
+ r, s = decode_dss_signature(der_sig)
+ except ValueError:
+ raise ValueError("Invalid DER signature")
+
+ order = curve.order
+ r_bytes = r.to_bytes((order.bit_length() + 7) // 8, byteorder='big')
+ s_bytes = s.to_bytes((order.bit_length() + 7) // 8, byteorder='big')
+ return r_bytes + s_bytes
diff --git a/jwt/warnings.py b/jwt/warnings.py
index 8762a8c..3977d9c 100644
--- a/jwt/warnings.py
+++ b/jwt/warnings.py
@@ -1,2 +1,4 @@
class RemovedInPyjwt3Warning(DeprecationWarning):
- pass
+ """
+ Warning class to indicate functionality that will be removed in PyJWT version 3.
+ """