From 051ea341b5573fe3edcd53042f347929b92c2b92 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Padilla?= <jpadilla@webapplicate.com>
Date: Thu, 12 Mar 2026 12:46:08 -0400
Subject: [PATCH] Merge commit from fork
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: José Padilla <jpadilla@users.noreply.github.com>
---
 jwt/api_jws.py        |    6 ++
 tests/test_api_jws.py |  103 +++++++++++++++++++++++++++++++++++++++++++++++++-
 tests/test_api_jwt.py |   31 +++++++++++++++
 3 files changed, 139 insertions(+), 1 deletion(-)

Index: PyJWT-1.5.3/jwt/api_jws.py
===================================================================
--- PyJWT-1.5.3.orig/jwt/api_jws.py	2026-03-19 19:54:14.758037081 +0100
+++ PyJWT-1.5.3/jwt/api_jws.py	2026-03-20 14:43:08.817340305 +0100
@@ -131,6 +131,8 @@
 
         payload, signing_input, header, signature = self._load(jws)
 
+        self._validate_headers(header)
+
         if not verify:
             warnings.warn('The verify parameter is deprecated. '
                           'Please use verify_signature in options instead.',
@@ -212,11 +214,15 @@
     def _validate_headers(self, headers):
         if 'kid' in headers:
             self._validate_kid(headers['kid'])
+        if "crit" in headers:
+            self._validate_crit(headers)
 
     def _validate_kid(self, kid):
         if not isinstance(kid, string_types):
             raise InvalidTokenError('Key ID header parameter must be a string')
 
+    def _validate_crit(self, headers):
+        raise InvalidTokenError('Unsupported critical extension')
 
 _jws_global_obj = PyJWS()
 encode = _jws_global_obj.encode
Index: PyJWT-1.5.3/tests/test_api_jws.py
===================================================================
--- PyJWT-1.5.3.orig/tests/test_api_jws.py	2026-03-19 19:54:14.758948450 +0100
+++ PyJWT-1.5.3/tests/test_api_jws.py	2026-03-20 14:43:08.818340315 +0100
@@ -8,7 +8,7 @@
 from jwt.exceptions import (
     DecodeError, InvalidAlgorithmError, InvalidTokenError
 )
-from jwt.utils import base64url_decode, force_bytes, force_unicode
+from jwt.utils import base64url_decode, base64url_encode, force_bytes, force_unicode
 
 import pytest
 
@@ -36,6 +36,17 @@
     return force_bytes('hello world')
 
 
+def replace_header(token, header):
+    if isinstance(token, string_types):
+        token = token.encode('ascii')
+
+    signing_input, crypto_segment = token.rsplit(b'.', 1)
+    segments = signing_input.split(b'.', 1)
+    header_segment = base64url_encode(force_bytes(json.dumps(header, separators=(',', ':'))))
+
+    return b'.'.join([header_segment, segments[1], crypto_segment])
+
+
 class TestJWS:
     def test_register_algo_does_not_allow_duplicate_registration(self, jws):
         jws.register_algorithm('AAA', Algorithm())
@@ -674,3 +685,93 @@
             jws.encode(payload, 'secret', headers={'kid': None})
 
         assert 'Key ID header parameter must be a string' == str(exc.value)
+
+###################################################################
+
+    def test_decode_rejects_unknown_crit_extension(
+        self, jws, payload
+    ):
+        secret = "secret"
+        token = jws.encode(payload, secret, algorithm="HS256")
+        token = replace_header(
+            token,
+            {"typ": jws.header_typ, "alg": "HS256",
+             "crit": ["x-custom-policy"], "x-custom-policy": "require-mfa"}
+        )
+
+        with pytest.raises(InvalidTokenError, match="Unsupported critical extension"):
+            jws.decode(token, secret, algorithms=["HS256"])
+
+    def test_decode_rejects_empty_crit(self, jws, payload):
+        secret = "secret"
+        token = jws.encode(payload, secret, algorithm="HS256")
+        token = replace_header(
+            token,
+            {"typ": jws.header_typ, "alg": "HS256", "crit": []}
+        )
+
+        with pytest.raises(InvalidTokenError, match="Unsupported critical extension"):
+            jws.decode(token, secret, algorithms=["HS256"])
+
+    def test_decode_rejects_non_list_crit(self, jws, payload):
+        secret = "secret"
+        token = jws.encode(payload, secret, algorithm="HS256")
+        token = replace_header(
+            token,
+            {"typ": jws.header_typ, "alg": "HS256", "crit": "b64"}
+        )
+
+        with pytest.raises(InvalidTokenError, match="Unsupported critical extension"):
+            jws.decode(token, secret, algorithms=["HS256"])
+
+    def test_decode_rejects_crit_with_non_string_values(
+        self, jws, payload
+    ):
+        secret = "secret"
+        token = jws.encode(payload, secret, algorithm="HS256")
+        token = replace_header(
+            token,
+            {"typ": jws.header_typ, "alg": "HS256", "crit": [123]}
+        )
+
+        with pytest.raises(InvalidTokenError, match="Unsupported critical extension"):
+            jws.decode(token, secret, algorithms=["HS256"])
+
+    def test_decode_rejects_crit_extension_missing_from_header(
+        self, jws, payload
+    ):
+        secret = "secret"
+        token = jws.encode(payload, secret, algorithm="HS256")
+        token = replace_header(
+            token,
+            {"typ": jws.header_typ, "alg": "HS256", "crit": ["b64"]}
+        )
+
+        with pytest.raises(InvalidTokenError, match="Unsupported critical extension"):
+            jws.decode(token, secret, algorithms=["HS256"])
+
+    def test_encode_rejects_token_with_crit_header(
+        self, jws, payload
+    ):
+        secret = "secret"
+        with pytest.raises(InvalidTokenError, match="Unsupported critical extension"):
+            jws.encode(
+                payload,
+                secret,
+                algorithm="HS256",
+                headers={"crit": ["x-custom-policy"], "x-custom-policy": "require-mfa"},
+            )
+
+    def test_get_unverified_header_rejects_unknown_crit(
+        self, jws, payload
+    ):
+        secret = "secret"
+        token = jws.encode(payload, secret, algorithm="HS256")
+        token = replace_header(
+            token,
+            {"typ": jws.header_typ, "alg": "HS256",
+             "crit": ["x-unknown"], "x-unknown": "value"}
+        )
+
+        with pytest.raises(InvalidTokenError, match="Unsupported critical extension"):
+            jws.get_unverified_header(token)
Index: PyJWT-1.5.3/tests/test_api_jwt.py
===================================================================
--- PyJWT-1.5.3.orig/tests/test_api_jwt.py	2026-03-19 19:54:14.759284925 +0100
+++ PyJWT-1.5.3/tests/test_api_jwt.py	2026-03-20 14:37:17.728924364 +0100
@@ -12,6 +12,7 @@
     InvalidAudienceError, InvalidIssuedAtError, InvalidIssuerError,
     MissingRequiredClaimError
 )
+from jwt.utils import base64url_encode, force_bytes
 
 import pytest
 
@@ -34,6 +35,17 @@
     }
 
 
+def replace_header(token, header):
+    if isinstance(token, str):
+        token = token.encode('ascii')
+
+    signing_input, crypto_segment = token.rsplit(b'.', 1)
+    segments = signing_input.split(b'.', 1)
+    header_segment = base64url_encode(force_bytes(json.dumps(header, separators=(',', ':'))))
+
+    return b'.'.join([header_segment, segments[1], crypto_segment])
+
+
 class TestJWT:
     def test_decodes_valid_jwt(self, jwt):
         example_payload = {'hello': 'world'}
@@ -496,3 +508,22 @@
             pass
         else:
             assert False, "Unexpected DeprecationWarning raised."
+
+
+    # -------------------- Crit Header Tests --------------------
+
+    def test_decode_rejects_token_with_unknown_crit_extension(self, jwt):
+        """RFC 7515 par 4.1.11: tokens with unsupported critical extensions MUST be rejected."""
+        from jwt.exceptions import InvalidTokenError
+
+        secret = "secret"
+        payload = {"sub": "attacker", "role": "admin"}
+        token = jwt.encode(payload, secret, algorithm="HS256")
+        token = replace_header(
+            token,
+            {"typ": jwt.header_type, "alg": "HS256",
+             "crit": ["x-custom-policy"], "x-custom-policy": "require-mfa"},
+        )
+
+        with pytest.raises(InvalidTokenError, match="Unsupported critical extension"):
+            jwt.decode(token, secret, algorithms=["HS256"])
