diff --git a/jwt/api_jws.py b/jwt/api_jws.py
index 9a20466..1a7235b 100644
--- a/jwt/api_jws.py
+++ b/jwt/api_jws.py
@@ -27,20 +27,26 @@ class PyJWS:
"""
Registers a new Algorithm for use when creating and verifying tokens.
"""
- pass
+ if alg_id in self._algorithms:
+ warnings.warn(f"Algorithm '{alg_id}' already exists. It will be overwritten.", UserWarning)
+ self._algorithms[alg_id] = alg_obj
+ self._valid_algs.add(alg_id)
def unregister_algorithm(self, alg_id: str) -> None:
"""
Unregisters an Algorithm for use when creating and verifying tokens
Throws KeyError if algorithm is not registered.
"""
- pass
+ if alg_id not in self._algorithms:
+ raise KeyError(f"Algorithm '{alg_id}' is not registered.")
+ del self._algorithms[alg_id]
+ self._valid_algs.remove(alg_id)
def get_algorithms(self) -> list[str]:
"""
Returns a list of supported values for the 'alg' parameter.
"""
- pass
+ return list(self._valid_algs)
def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
"""
@@ -50,7 +56,10 @@ class PyJWS:
>>> jws_obj.get_algorithm_by_name("RS256")
"""
- pass
+ try:
+ return self._algorithms[alg_name]
+ except KeyError:
+ raise InvalidAlgorithmError(f"Algorithm '{alg_name}' could not be found")
def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
"""Returns back the JWT header parameters as a dict()
@@ -58,7 +67,75 @@ class PyJWS:
Note: The signature is not verified so the header parameters
should not be fully trusted until signature verification is complete
"""
- pass
+ if isinstance(jwt, str):
+ jwt = jwt.encode('utf-8')
+ try:
+ header_segment = jwt.split(b'.')[0]
+ header_data = base64url_decode(header_segment)
+ return json.loads(header_data)
+ except (ValueError, TypeError, binascii.Error) as e:
+ raise DecodeError(f"Invalid header padding: {e}")
+
+ def encode(self, payload: dict[str, Any], key: str, algorithm: str = 'HS256', headers: dict[str, Any] | None = None, json_encoder: type[json.JSONEncoder] | None = None) -> str:
+ segments = []
+ if headers is None:
+ headers = {}
+ headers['typ'] = self.header_typ
+ headers['alg'] = algorithm
+
+ json_header = json.dumps(headers, separators=(',', ':'), cls=json_encoder).encode()
+ segments.append(base64url_encode(json_header))
+
+ json_payload = json.dumps(payload, separators=(',', ':'), cls=json_encoder).encode()
+ segments.append(base64url_encode(json_payload))
+
+ signing_input = b'.'.join(segments)
+ alg_obj = self.get_algorithm_by_name(algorithm)
+ signature = alg_obj.sign(signing_input, key)
+ segments.append(base64url_encode(signature))
+
+ return b'.'.join(segments).decode('utf-8')
+
+ def decode_complete(self, jwt: str | bytes, key: str | None = None, algorithms: list[str] | None = None, options: dict[str, Any] | None = None, **kwargs: Any) -> dict[str, Any]:
+ if isinstance(jwt, str):
+ jwt = jwt.encode('utf-8')
+
+ try:
+ header_segment, payload_segment, crypto_segment = jwt.split(b'.')
+ except ValueError:
+ raise DecodeError("Not enough segments")
+
+ header_data = base64url_decode(header_segment)
+ header = json.loads(header_data)
+
+ payload_data = base64url_decode(payload_segment)
+ payload = json.loads(payload_data)
+
+ if algorithms is None:
+ algorithms = self.get_algorithms()
+
+ if options is None:
+ options = {}
+ merged_options = {**self.options, **options}
+
+ if key is None and merged_options.get('verify_signature'):
+ raise DecodeError("Signature verification required but no key was provided.")
+
+ if merged_options.get('verify_signature'):
+ alg_obj = self.get_algorithm_by_name(header['alg'])
+ if header['alg'] not in algorithms:
+ raise InvalidAlgorithmError(f"The specified algorithm '{header['alg']}' is not allowed.")
+ try:
+ alg_obj.verify(f"{header_segment.decode('utf-8')}.{payload_segment.decode('utf-8')}".encode(), base64url_decode(crypto_segment), key)
+ except InvalidSignatureError:
+ raise InvalidSignatureError("Signature verification failed")
+
+ return {'header': header, 'payload': payload}
+
+ def decode(self, jwt: str | bytes, key: str | None = None, algorithms: list[str] | None = None, options: dict[str, Any] | None = None, **kwargs: Any) -> dict[str, Any]:
+ decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs)
+ return decoded['payload']
+
_jws_global_obj = PyJWS()
encode = _jws_global_obj.encode
decode_complete = _jws_global_obj.decode_complete
@@ -66,4 +143,4 @@ decode = _jws_global_obj.decode
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name
-get_unverified_header = _jws_global_obj.get_unverified_header
\ No newline at end of file
+get_unverified_header = _jws_global_obj.get_unverified_header