From 051ea341b5573fe3edcd53042f347929b92c2b92 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Padilla?= <jpadilla@webapplicate.com>
Date: Thu, 12 Mar 2026 12:46:08 -0400
Subject: [PATCH] Merge commit from fork
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: José Padilla <jpadilla@users.noreply.github.com>
---
 CHANGELOG.rst         |    1 
 jwt/api_jws.py        |   24 +++++++++++-
 tests/test_api_jws.py |  100 ++++++++++++++++++++++++++++++++++++++++++++++++++
 tests/test_api_jwt.py |   18 +++++++++
 4 files changed, 142 insertions(+), 1 deletion(-)

Index: PyJWT-2.8.0/CHANGELOG.rst
===================================================================
--- PyJWT-2.8.0.orig/CHANGELOG.rst	2023-07-18 22:00:08.000000000 +0200
+++ PyJWT-2.8.0/CHANGELOG.rst	2026-03-20 17:59:55.457796619 +0100
@@ -12,6 +12,7 @@
 
 Fixed
 ~~~~~
+- Validate the crit (Critical) Header Parameter defined in RFC 7515 §4.1.11. by @dmbs335 in `GHSA-752w-5fwx-jx9f <https://github.com/jpadilla/pyjwt/security/advisories/GHSA-752w-5fwx-jx9f>`__
 
 Added
 ~~~~~
Index: PyJWT-2.8.0/jwt/api_jws.py
===================================================================
--- PyJWT-2.8.0.orig/jwt/api_jws.py	2023-07-18 22:00:08.000000000 +0200
+++ PyJWT-2.8.0/jwt/api_jws.py	2026-03-20 18:01:37.220669725 +0100
@@ -3,7 +3,7 @@
 import binascii
 import json
 import warnings
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Dict, Set
 
 from .algorithms import (
     Algorithm,
@@ -197,6 +197,9 @@
 
         payload, signing_input, header, signature = self._load(jwt)
 
+        self._validate_headers(header)
+        self._validate_crit(header)
+
         if header.get("b64", True) is False:
             if detached_payload is None:
                 raise DecodeError(
@@ -243,6 +246,7 @@
         """
         headers = self._load(jwt)[2]
         self._validate_headers(headers)
+        self._validate_crit(headers)
 
         return headers
 
@@ -309,6 +313,9 @@
         if not alg_obj.verify(signing_input, prepared_key, signature):
             raise InvalidSignatureError("Signature verification failed")
 
+    # Extensions that PyJWT actually understands and supports
+    _supported_crit: Set[str] = {"b64"}
+
     def _validate_headers(self, headers: dict[str, Any]) -> None:
         if "kid" in headers:
             self._validate_kid(headers["kid"])
@@ -317,6 +324,21 @@
         if not isinstance(kid, str):
             raise InvalidTokenError("Key ID header parameter must be a string")
 
+    def _validate_crit(self, headers: Dict[str, Any]) -> None:
+        if "crit" not in headers:
+            return
+        crit = headers["crit"]
+        if not isinstance(crit, list) or len(crit) == 0:
+            raise InvalidTokenError("Invalid 'crit' header: must be a non-empty list")
+        for ext in crit:
+            if not isinstance(ext, str):
+                raise InvalidTokenError("Invalid 'crit' header: values must be strings")
+            if ext not in self._supported_crit:
+                raise InvalidTokenError(f"Unsupported critical extension: {ext}")
+            if ext not in headers:
+                raise InvalidTokenError(
+                    f"Critical extension '{ext}' is missing from headers"
+                )
 
 _jws_global_obj = PyJWS()
 encode = _jws_global_obj.encode
Index: PyJWT-2.8.0/tests/test_api_jws.py
===================================================================
--- PyJWT-2.8.0.orig/tests/test_api_jws.py	2023-07-18 22:00:08.000000000 +0200
+++ PyJWT-2.8.0/tests/test_api_jws.py	2026-03-20 18:04:23.193859936 +0100
@@ -782,6 +782,106 @@
             in str(exc.value)
         )
 
+    def test_decode_rejects_unknown_crit_extension(
+        self, jws: PyJWS, payload: bytes
+    ) -> None:
+        secret = "secret"
+        token = jws.encode(
+            payload,
+            secret,
+            algorithm="HS256",
+            headers={"crit": ["x-custom-policy"], "x-custom-policy": "require-mfa"},
+        )
+
+        with pytest.raises(InvalidTokenError, match="Unsupported critical extension"):
+            jws.decode(token, secret, algorithms=["HS256"])
+
+    def test_decode_rejects_empty_crit(self, jws: PyJWS, payload: bytes) -> None:
+        secret = "secret"
+        token = jws.encode(
+            payload,
+            secret,
+            algorithm="HS256",
+            headers={"crit": []},
+        )
+
+        with pytest.raises(InvalidTokenError, match="must be a non-empty list"):
+            jws.decode(token, secret, algorithms=["HS256"])
+
+    def test_decode_rejects_non_list_crit(self, jws: PyJWS, payload: bytes) -> None:
+        secret = "secret"
+        token = jws.encode(
+            payload,
+            secret,
+            algorithm="HS256",
+            headers={"crit": "b64"},
+        )
+
+        with pytest.raises(InvalidTokenError, match="must be a non-empty list"):
+            jws.decode(token, secret, algorithms=["HS256"])
+
+    def test_decode_rejects_crit_with_non_string_values(
+        self, jws: PyJWS, payload: bytes
+    ) -> None:
+        secret = "secret"
+        token = jws.encode(
+            payload,
+            secret,
+            algorithm="HS256",
+            headers={"crit": [123]},
+        )
+
+        with pytest.raises(InvalidTokenError, match="values must be strings"):
+            jws.decode(token, secret, algorithms=["HS256"])
+
+    def test_decode_rejects_crit_extension_missing_from_header(
+        self, jws: PyJWS, payload: bytes
+    ) -> None:
+        secret = "secret"
+        token = jws.encode(
+            payload,
+            secret,
+            algorithm="HS256",
+            headers={"crit": ["b64"]},
+        )
+
+        with pytest.raises(InvalidTokenError, match="missing from headers"):
+            jws.decode(token, secret, algorithms=["HS256"])
+
+    def test_decode_accepts_supported_crit_extension(
+        self, jws: PyJWS, payload: bytes
+    ) -> None:
+        secret = "secret"
+        token = jws.encode(
+            payload,
+            secret,
+            algorithm="HS256",
+            headers={"crit": ["b64"], "b64": False},
+            is_payload_detached=True,
+        )
+
+        decoded = jws.decode(
+            token,
+            secret,
+            algorithms=["HS256"],
+            detached_payload=payload,
+        )
+        assert decoded == payload
+
+    def test_get_unverified_header_rejects_unknown_crit(
+        self, jws: PyJWS, payload: bytes
+    ) -> None:
+        secret = "secret"
+        token = jws.encode(
+            payload,
+            secret,
+            algorithm="HS256",
+            headers={"crit": ["x-unknown"], "x-unknown": "value"},
+        )
+
+        with pytest.raises(InvalidTokenError, match="Unsupported critical extension"):
+            jws.get_unverified_header(token)
+
     def test_decode_warns_on_unsupported_kwarg(self, jws, payload):
         secret = "secret"
         jws_message = jws.encode(
Index: PyJWT-2.8.0/tests/test_api_jwt.py
===================================================================
--- PyJWT-2.8.0.orig/tests/test_api_jwt.py	2023-07-18 22:00:08.000000000 +0200
+++ PyJWT-2.8.0/tests/test_api_jwt.py	2026-03-20 18:05:13.580061515 +0100
@@ -706,6 +706,24 @@
         jwt.decode(jwt_message, secret, options=options, algorithms=["HS256"])
         assert options == orig_options
 
+    # -------------------- Crit Header Tests --------------------
+
+    def test_decode_rejects_token_with_unknown_crit_extension(self, jwt: PyJWT) -> None:
+        """RFC 7515 §4.1.11: tokens with unsupported critical extensions MUST be rejected."""
+        from jwt.exceptions import InvalidTokenError
+
+        secret = "secret"
+        payload = {"sub": "attacker", "role": "admin"}
+        token = jwt.encode(
+            payload,
+            secret,
+            algorithm="HS256",
+            headers={"crit": ["x-custom-policy"], "x-custom-policy": "require-mfa"},
+        )
+
+        with pytest.raises(InvalidTokenError, match="Unsupported critical extension"):
+            jwt.decode(token, secret, algorithms=["HS256"])
+
     def test_decode_warns_on_unsupported_kwarg(self, jwt, payload):
         secret = "secret"
         jwt_message = jwt.encode(payload, secret)
