From 25db861d8b29434838669a94a843af03d29ea6ed Mon Sep 17 00:00:00 2001
From: Simo Sorce <simo@redhat.com>
Date: Mon, 6 Apr 2026 10:37:20 -0400
Subject: [PATCH] Limit max plaintext size for JWE decompression

This change introduces a maximum plaintext size limit (defaulting to 100MB)
during JWE decryption and updates the decompression logic to enforce it safely
using zlib.decompressobj. The decrypt method now accepts a max_plaintext
parameter to allow overriding the default limit.

This mitigates memory exhaustion and decompression bomb attacks when
processing highly compressed malicious JWE payloads.

Fixes CVE-2026-39373

Signed-off-by: Simo Sorce <simo@redhat.com>
---
 jwcrypto/jwe.py   | 27 +++++++++++++++++++++------
 jwcrypto/tests.py | 34 ++++++++++++++++++++++++++--------
 2 files changed, 47 insertions(+), 14 deletions(-)

Index: jwcrypto-1.5.6/jwcrypto/jwe.py
===================================================================
--- jwcrypto-1.5.6.orig/jwcrypto/jwe.py
+++ jwcrypto-1.5.6/jwcrypto/jwe.py
@@ -12,7 +12,8 @@ from jwcrypto.jwk import JWKSet
 
 # Limit the amount of data we are willing to decompress by default.
 default_max_compressed_size = 256 * 1024
-
+# Limit the maximum plaintext size to 100MB by default.
+default_max_plaintext_size = 100 * 1024 * 1024
 
 # RFC 7516 - 4.1
 # name: (description, supported?)
@@ -371,7 +372,7 @@ class JWE:
         return data
 
     # FIXME: allow to specify which algorithms to accept as valid
-    def _decrypt(self, key, ppe):
+    def _decrypt(self, key, ppe, max_plaintext=default_max_plaintext_size):
 
         jh = self._get_jose_header(ppe.get('header', None))
 
@@ -429,19 +430,29 @@ class JWE:
                 raise InvalidJWEData(
                     'Compressed data exceeds maximum allowed'
                     'size' + f' ({default_max_compressed_size})')
-            self.plaintext = zlib.decompress(data, -zlib.MAX_WBITS)
+            do = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
+            self.plaintext = do.decompress(data, max_plaintext)
+            if do.unconsumed_tail or not do.eof:
+                self.plaintext = None
+                raise InvalidJWEData(
+                    'Compressed data exceeds maximum allowed'
+                    'output size' + f' ({max_plaintext})')
         elif compress is None:
             self.plaintext = data
         else:
             raise ValueError('Unknown compression')
 
-    def decrypt(self, key):
+    def decrypt(self, key, max_plaintext=0):
         """Decrypt a JWE token.
 
         :param key: The (:class:`jwcrypto.jwk.JWK`) decryption key.
         :param key: A (:class:`jwcrypto.jwk.JWK`) decryption key,
          or a (:class:`jwcrypto.jwk.JWKSet`) that contains a key indexed
          by the 'kid' header or (deprecated) a string containing a password.
+        :param max_plaintext: Maximum plaintext size allowed, 0 means
+         the library default applies. Application writers are recommended
+         to set a limit here if they know what is the max plaintext size
+         for their application.
 
         :raises InvalidJWEOperation: if the key is not a JWK object.
         :raises InvalidJWEData: if the ciphertext can't be decrypted or
@@ -449,6 +460,10 @@ class JWE:
         :raises JWKeyNotFound: if key is a JWKSet and the key is not found.
         """
 
+        self.plaintext = None
+        if max_plaintext == 0:
+            max_plaintext = default_max_plaintext_size
+
         if 'ciphertext' not in self.objects:
             raise InvalidJWEOperation("No available ciphertext")
         self.decryptlog = []
@@ -457,14 +472,14 @@ class JWE:
         if 'recipients' in self.objects:
             for rec in self.objects['recipients']:
                 try:
-                    self._decrypt(key, rec)
+                    self._decrypt(key, rec, max_plaintext=max_plaintext)
                 except Exception as e:  # pylint: disable=broad-except
                     if isinstance(e, JWKeyNotFound):
                         missingkey = True
                     self.decryptlog.append('Failed: [%s]' % repr(e))
         else:
             try:
-                self._decrypt(key, self.objects)
+                self._decrypt(key, self.objects, max_plaintext=max_plaintext)
             except Exception as e:  # pylint: disable=broad-except
                 if isinstance(e, JWKeyNotFound):
                     missingkey = True
Index: jwcrypto-1.5.6/jwcrypto/tests.py
===================================================================
--- jwcrypto-1.5.6.orig/jwcrypto/tests.py
+++ jwcrypto-1.5.6/jwcrypto/tests.py
@@ -2124,18 +2124,36 @@ class ConformanceTests(unittest.TestCase
         enc = jwe.JWE(payload.encode('utf-8'),
                       recipient=key,
                       protected=protected_header).serialize(compact=True)
+        check = jwe.JWE()
+        check.deserialize(enc)
         with self.assertRaises(jwe.InvalidJWEData):
-            check = jwe.JWE()
-            check.deserialize(enc)
             check.decrypt(key)
 
-        defmax = jwe.default_max_compressed_size
-        jwe.default_max_compressed_size = 1000000000
-        # ensure we can eraise the limit and decrypt
-        check = jwe.JWE()
-        check.deserialize(enc)
+        # raise the limit on compressed token size so we can decrypt
+        defcmax = jwe.default_max_compressed_size
+        jwe.default_max_compressed_size = 10 * 1024 * 1024
+
+        # this passes if we explicitly allow larger plaintext via API
+        check.decrypt(key, max_plaintext=1000000000)
+
+        # this will still fail because the max plaintext length clamps this
+        with self.assertRaises(jwe.InvalidJWEData):
+            check.decrypt(key)
+
+        # ensure that now this can work with changed defaults
+        defpmax = jwe.default_max_plaintext_size
+        jwe.default_max_plaintext_size = 1000000000
         check.decrypt(key)
-        jwe.default_max_compressed_size = defmax
+
+        # restore limits
+        jwe.default_max_compressed_size = defcmax
+
+        # check that this fails the max compressed header limits
+        with self.assertRaises(jwe.InvalidJWEData):
+            check.decrypt(key)
+
+        # restore plaintext limits
+        jwe.default_max_plaintext_size = defpmax
 
 
 class JWATests(unittest.TestCase):
