Skip to content

back to Claude Sonnet 3.5 - Fill-in summary

Claude Sonnet 3.5 - Fill-in: pyjwt

Pytest Summary for test tests

status count
passed 11
failed 1
xfailed 1
total 13
collected 13

Failed pytests:

test_utils.py::test_to_base64url_uint[0-AA]

test_utils.py::test_to_base64url_uint[0-AA]
inputval = 0, expected = b'AA'

    @pytest.mark.parametrize(
        "inputval,expected",
        [
            (0, b"AA"),
            (1, b"AQ"),
            (255, b"_w"),
            (65537, b"AQAB"),
            (123456789, b"B1vNFQ"),
            pytest.param(-1, "", marks=pytest.mark.xfail(raises=ValueError)),
        ],
    )
    def test_to_base64url_uint(inputval, expected):
        actual = to_base64url_uint(inputval)
>       assert actual == expected
E       AssertionError: assert b'' == b'AA'
E         
E         Use -v to get more diff

tests/test_utils.py:19: AssertionError

test_utils.py::test_to_base64url_uint[-1-]

test_utils.py::test_to_base64url_uint[-1-]
inputval = -1, expected = ''

    @pytest.mark.parametrize(
        "inputval,expected",
        [
            (0, b"AA"),
            (1, b"AQ"),
            (255, b"_w"),
            (65537, b"AQAB"),
            (123456789, b"B1vNFQ"),
            pytest.param(-1, "", marks=pytest.mark.xfail(raises=ValueError)),
        ],
    )
    def test_to_base64url_uint(inputval, expected):
>       actual = to_base64url_uint(inputval)

tests/test_utils.py:18: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

val = -1

    def to_base64url_uint(val: int) -> bytes:
        if val < 0:
>           raise ValueError("Must be a positive integer")
E           ValueError: Must be a positive integer

jwt/utils.py:51: ValueError

Patch diff

diff --git a/jwt/algorithms.py b/jwt/algorithms.py
index 15c200a..be04167 100644
--- a/jwt/algorithms.py
+++ b/jwt/algorithms.py
@@ -43,7 +43,28 @@ def get_default_algorithms() ->dict[str, Algorithm]:
     """
     Returns the algorithms that are implemented by the library.
     """
-    pass
+    default_algorithms = {
+        'none': NoneAlgorithm(),
+        'HS256': HMACAlgorithm(HMACAlgorithm.SHA256),
+        'HS384': HMACAlgorithm(HMACAlgorithm.SHA384),
+        'HS512': HMACAlgorithm(HMACAlgorithm.SHA512),
+    }
+    
+    if has_crypto:
+        default_algorithms.update({
+            'RS256': RSAAlgorithm(RSAAlgorithm.SHA256),
+            'RS384': RSAAlgorithm(RSAAlgorithm.SHA384),
+            'RS512': RSAAlgorithm(RSAAlgorithm.SHA512),
+            'ES256': ECAlgorithm(ECAlgorithm.SHA256),
+            'ES384': ECAlgorithm(ECAlgorithm.SHA384),
+            'ES512': ECAlgorithm(ECAlgorithm.SHA512),
+            'PS256': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
+            'PS384': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
+            'PS512': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
+            'EdDSA': OKPAlgorithm(),
+        })
+    
+    return default_algorithms


 class Algorithm(ABC):
@@ -57,7 +78,9 @@ class Algorithm(ABC):

         If there is no hash algorithm, raises a NotImplementedError.
         """
-        pass
+        if hasattr(self, 'hash_alg'):
+            return self.hash_alg(bytestr).digest()
+        raise NotImplementedError("Hash algorithm not specified for this algorithm.")

     @abstractmethod
     def prepare_key(self, key: Any) ->Any:
@@ -105,6 +128,28 @@ class NoneAlgorithm(Algorithm):
     Placeholder for use when no signing or verification
     operations are required.
     """
+    def prepare_key(self, key: Any) ->None:
+        return None
+
+    def sign(self, msg: bytes, key: Any) ->bytes:
+        return b''
+
+    def verify(self, msg: bytes, key: Any, sig: bytes) ->bool:
+        return False
+
+    @staticmethod
+    def to_jwk(key_obj, as_dict: bool=False) ->Union[JWKDict, str]:
+        if as_dict:
+            return {'kty': 'none'}
+        return json.dumps({'kty': 'none'})
+
+    @staticmethod
+    def from_jwk(jwk: Union[str, JWKDict]) ->None:
+        if isinstance(jwk, str):
+            jwk = json.loads(jwk)
+        if not isinstance(jwk, dict) or jwk.get('kty') != 'none':
+            raise InvalidKeyError('Invalid key: not a none key')
+        return None


 class HMACAlgorithm(Algorithm):
@@ -119,6 +164,36 @@ class HMACAlgorithm(Algorithm):
     def __init__(self, hash_alg: HashlibHash) ->None:
         self.hash_alg = hash_alg

+    def prepare_key(self, key: Any) ->bytes:
+        key = force_bytes(key)
+        return key
+
+    def sign(self, msg: bytes, key: Any) ->bytes:
+        key = self.prepare_key(key)
+        return hmac.new(key, msg, self.hash_alg).digest()
+
+    def verify(self, msg: bytes, key: Any, sig: bytes) ->bool:
+        key = self.prepare_key(key)
+        return hmac.compare_digest(sig, self.sign(msg, key))
+
+    @staticmethod
+    def to_jwk(key_obj: bytes, as_dict: bool=False) ->Union[JWKDict, str]:
+        jwk = {
+            'kty': 'oct',
+            'k': base64url_encode(key_obj).decode('ascii')
+        }
+        if as_dict:
+            return jwk
+        return json.dumps(jwk)
+
+    @staticmethod
+    def from_jwk(jwk: Union[str, JWKDict]) ->bytes:
+        if isinstance(jwk, str):
+            jwk = json.loads(jwk)
+        if not isinstance(jwk, dict) or jwk.get('kty') != 'oct':
+            raise InvalidKeyError('Invalid key: not an octet key')
+        return base64url_decode(jwk['k'])
+

 if has_crypto:

diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index eeb5924..fcc015e 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -30,20 +30,24 @@ class PyJWS:
         """
         Registers a new Algorithm for use when creating and verifying tokens.
         """
-        pass
+        self._algorithms[alg_id] = alg_obj
+        self._valid_algs.add(alg_id)

     def unregister_algorithm(self, alg_id: str) ->None:
         """
         Unregisters an Algorithm for use when creating and verifying tokens
         Throws KeyError if algorithm is not registered.
         """
-        pass
+        if alg_id not in self._algorithms:
+            raise KeyError(f"The algorithm '{alg_id}' is not registered.")
+        del self._algorithms[alg_id]
+        self._valid_algs.remove(alg_id)

     def get_algorithms(self) ->list[str]:
         """
         Returns a list of supported values for the 'alg' parameter.
         """
-        pass
+        return list(self._valid_algs)

     def get_algorithm_by_name(self, alg_name: str) ->Algorithm:
         """
@@ -53,7 +57,9 @@ class PyJWS:

         >>> jws_obj.get_algorithm_by_name("RS256")
         """
-        pass
+        if alg_name not in self._algorithms:
+            raise InvalidAlgorithmError(f"Algorithm '{alg_name}' could not be found")
+        return self._algorithms[alg_name]

     def get_unverified_header(self, jwt: (str | bytes)) ->dict[str, Any]:
         """Returns back the JWT header parameters as a dict()
@@ -61,7 +67,20 @@ class PyJWS:
         Note: The signature is not verified so the header parameters
         should not be fully trusted until signature verification is complete
         """
-        pass
+        if isinstance(jwt, str):
+            jwt = jwt.encode('utf-8')
+
+        try:
+            header_segment = jwt.split(b'.')[0]
+            header_data = base64url_decode(header_segment)
+            header = json.loads(header_data.decode('utf-8'))
+        except (ValueError, TypeError, binascii.Error) as e:
+            raise DecodeError(f"Invalid header padding: {str(e)}")
+
+        if not isinstance(header, dict):
+            raise DecodeError("Invalid header string: must be a json object")
+
+        return header


 _jws_global_obj = PyJWS()
diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py
index 961d144..7a3148d 100644
--- a/jwt/api_jwt.py
+++ b/jwt/api_jwt.py
@@ -29,7 +29,12 @@ class PyJWT:
         This method is intended to be overridden by subclasses that need to
         encode the payload in a different way, e.g. compress the payload.
         """
-        pass
+        json_payload = json.dumps(
+            payload,
+            separators=(',', ':'),
+            cls=json_encoder
+        ).encode('utf-8')
+        return json_payload

     def _decode_payload(self, decoded: dict[str, Any]) ->Any:
         """
@@ -39,7 +44,11 @@ class PyJWT:
         decode the payload in a different way, e.g. decompress compressed
         payloads.
         """
-        pass
+        try:
+            payload = json.loads(decoded['payload'])
+        except ValueError as e:
+            raise DecodeError('Invalid payload string: %s' % e)
+        return payload


 _jwt_global_obj = PyJWT()
diff --git a/jwt/help.py b/jwt/help.py
index c6b2173..b5660a7 100644
--- a/jwt/help.py
+++ b/jwt/help.py
@@ -15,12 +15,26 @@ def info() ->Dict[str, Dict[str, str]]:
     Generate information for a bug report.
     Based on the requests package help utility module.
     """
-    pass
+    return {
+        "platform": {
+            "system": platform.system(),
+            "release": platform.release(),
+            "version": platform.version(),
+            "machine": platform.machine(),
+            "processor": platform.processor(),
+            "python_version": platform.python_version(),
+            "python_implementation": platform.python_implementation(),
+        },
+        "dependencies": {
+            "pyjwt": pyjwt_version,
+            "cryptography": cryptography_version,
+        },
+    }


 def main() ->None:
     """Pretty-print the bug information as JSON."""
-    pass
+    print(json.dumps(info(), sort_keys=True, indent=2))


 if __name__ == '__main__':
diff --git a/jwt/utils.py b/jwt/utils.py
index 1e1c20d..20dfae4 100644
--- a/jwt/utils.py
+++ b/jwt/utils.py
@@ -19,3 +19,62 @@ _CERT_SUFFIX = b'-cert-v01@openssh.com'
 _SSH_PUBKEY_RC = re.compile(b'\\A(\\S+)[ \\t]+(\\S+)')
 _SSH_KEY_FORMATS = [b'ssh-ed25519', b'ssh-rsa', b'ssh-dss',
     b'ecdsa-sha2-nistp256', b'ecdsa-sha2-nistp384', b'ecdsa-sha2-nistp521']
+
+def force_bytes(value: Union[str, bytes]) -> bytes:
+    if isinstance(value, str):
+        return value.encode('utf-8')
+    elif isinstance(value, bytes):
+        return value
+    else:
+        raise TypeError("Expected str or bytes, got %s" % type(value))
+
+def force_unicode(value: Union[str, bytes]) -> str:
+    if isinstance(value, bytes):
+        return value.decode('utf-8')
+    elif isinstance(value, str):
+        return value
+    else:
+        raise TypeError("Expected str or bytes, got %s" % type(value))
+
+def base64url_decode(input: Union[str, bytes]) -> bytes:
+    input = force_bytes(input)
+    rem = len(input) % 4
+    if rem > 0:
+        input += b'=' * (4 - rem)
+    return base64.urlsafe_b64decode(input)
+
+def base64url_encode(input: Union[str, bytes]) -> bytes:
+    return base64.urlsafe_b64encode(force_bytes(input)).rstrip(b'=')
+
+def to_base64url_uint(val: int) -> bytes:
+    if val < 0:
+        raise ValueError("Must be a positive integer")
+    int_bytes = val.to_bytes((val.bit_length() + 7) // 8, byteorder='big')
+    return base64url_encode(int_bytes)
+
+def from_base64url_uint(val: Union[str, bytes]) -> int:
+    int_bytes = base64url_decode(val)
+    return int.from_bytes(int_bytes, byteorder='big')
+
+def merge_dict(original: dict, updates: dict) -> dict:
+    if not updates:
+        return original
+
+    merged = original.copy()
+    for key, value in updates.items():
+        if isinstance(value, dict) and key in merged and isinstance(merged[key], dict):
+            merged[key] = merge_dict(merged[key], value)
+        else:
+            merged[key] = value
+    return merged
+
+def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes:
+    try:
+        r, s = decode_dss_signature(der_sig)
+    except ValueError:
+        raise ValueError("Invalid DER signature")
+
+    order = curve.order
+    r_bytes = r.to_bytes((order.bit_length() + 7) // 8, byteorder='big')
+    s_bytes = s.to_bytes((order.bit_length() + 7) // 8, byteorder='big')
+    return r_bytes + s_bytes
diff --git a/jwt/warnings.py b/jwt/warnings.py
index 8762a8c..3977d9c 100644
--- a/jwt/warnings.py
+++ b/jwt/warnings.py
@@ -1,2 +1,4 @@
 class RemovedInPyjwt3Warning(DeprecationWarning):
-    pass
+    """
+    Warning class to indicate functionality that will be removed in PyJWT version 3.
+    """