From 4598189a21ce60b15fb1fc506896cc27351d2473 Mon Sep 17 00:00:00 2001
From: Jacob Hoffman-Andrews <github@hoffman-andrews.com>
Date: Tue, 31 Mar 2026 16:33:50 -0700
Subject: [PATCH] Merge commit from fork

* cipher: fix panic on KeyUnwrap of too-short slice

* jwe: don't call KeyUnwrap on empty (encrypted) key

Also don't call `aead.decrypt` on an empty key.

* test: make asymmetric_test more precise

These two test cases were passing a nil recipient, and checking for "any error"
instead of a specific error, which meant that introducing a nil recipient check
in `decryptKey` caused the test to stop testing what it meant to test, but
continue passing. Now we check for a specific error.

* test: TestKeyUnwrapShort

* jwe: TestEmptyEncryptedKey

* test: add `shorten` and `empty` corruptors
---
 asymmetric.go      | 10 +++++++++-
 cipher/key_wrap.go | 10 +++++++++-
 symmetric.go       | 26 ++++++++++++++++++--------
 3 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/asymmetric.go b/asymmetric.go
index f8d5774..7784cd4 100644
--- a/asymmetric.go
+++ b/asymmetric.go
@@ -414,6 +414,9 @@ func (ctx ecKeyGenerator) genKey() ([]byte, rawHeader, error) {
 
 // Decrypt the given payload and return the content encryption key.
 func (ctx ecDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
+	if recipient == nil {
+		return nil, errors.New("go-jose/go-jose: missing recipient")
+	}
 	epk, err := headers.getEPK()
 	if err != nil {
 		return nil, errors.New("go-jose/go-jose: invalid epk header")
@@ -461,13 +464,18 @@ func (ctx ecDecrypterSigner) decryptKey(headers rawHeader, recipient *recipientI
 		return nil, ErrUnsupportedAlgorithm
 	}
 
+	encryptedKey := recipient.encryptedKey
+	if len(encryptedKey) == 0 {
+		return nil, errors.New("go-jose/go-jose: missing JWE Encrypted Key")
+	}
+
 	key := deriveKey(string(algorithm), keySize)
 	block, err := aes.NewCipher(key)
 	if err != nil {
 		return nil, err
 	}
 
-	return josecipher.KeyUnwrap(block, recipient.encryptedKey)
+	return josecipher.KeyUnwrap(block, encryptedKey)
 }
 
 func (ctx edDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error) {
diff --git a/cipher/key_wrap.go b/cipher/key_wrap.go
index b9effbc..a2f86e3 100644
--- a/cipher/key_wrap.go
+++ b/cipher/key_wrap.go
@@ -66,12 +66,20 @@ func KeyWrap(block cipher.Block, cek []byte) ([]byte, error) {
 }
 
 // KeyUnwrap implements NIST key unwrapping; it unwraps a content encryption key (cek) with the given block cipher.
+//
+// https://datatracker.ietf.org/doc/html/rfc7518#section-4.4
+// https://datatracker.ietf.org/doc/html/rfc7518#section-4.6
+// https://datatracker.ietf.org/doc/html/rfc7518#section-4.8
 func KeyUnwrap(block cipher.Block, ciphertext []byte) ([]byte, error) {
+	n := (len(ciphertext) / 8) - 1
+	if n <= 0 {
+		return nil, errors.New("go-jose/go-jose: JWE Encrypted Key too short")
+	}
+
 	if len(ciphertext)%8 != 0 {
 		return nil, errors.New("go-jose/go-jose: key wrap input must be 8 byte blocks")
 	}
 
-	n := (len(ciphertext) / 8) - 1
 	r := make([][]byte, n)
 
 	for i := range r {
diff --git a/symmetric.go b/symmetric.go
index 09efefb..f2ff29e 100644
--- a/symmetric.go
+++ b/symmetric.go
@@ -366,11 +366,21 @@ func (ctx *symmetricKeyCipher) encryptKey(cek []byte, alg KeyAlgorithm) (recipie
 
 // Decrypt the content encryption key.
 func (ctx *symmetricKeyCipher) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
-	switch headers.getAlgorithm() {
-	case DIRECT:
-		cek := make([]byte, len(ctx.key))
-		copy(cek, ctx.key)
-		return cek, nil
+	if recipient == nil {
+		return nil, fmt.Errorf("go-jose/go-jose: missing recipient")
+	}
+
+	alg := headers.getAlgorithm()
+	if alg == DIRECT {
+		return bytes.Clone(ctx.key), nil
+	}
+
+	encryptedKey := recipient.encryptedKey
+	if len(encryptedKey) == 0 {
+		return nil, fmt.Errorf("go-jose/go-jose: missing JWE Encrypted Key")
+	}
+
+	switch alg {
 	case A128GCMKW, A192GCMKW, A256GCMKW:
 		aead := newAESGCM(len(ctx.key))
 
@@ -385,7 +395,7 @@ func (ctx *symmetricKeyCipher) decryptKey(headers rawHeader, recipient *recipien
 
 		parts := &aeadParts{
 			iv:         iv.bytes(),
-			ciphertext: recipient.encryptedKey,
+			ciphertext: encryptedKey,
 			tag:        tag.bytes(),
 		}
 
@@ -401,7 +411,7 @@ func (ctx *symmetricKeyCipher) decryptKey(headers rawHeader, recipient *recipien
 			return nil, err
 		}
 
-		cek, err := josecipher.KeyUnwrap(block, recipient.encryptedKey)
+		cek, err := josecipher.KeyUnwrap(block, encryptedKey)
 		if err != nil {
 			return nil, err
 		}
@@ -445,7 +455,7 @@ func (ctx *symmetricKeyCipher) decryptKey(headers rawHeader, recipient *recipien
 			return nil, err
 		}
 
-		cek, err := josecipher.KeyUnwrap(block, recipient.encryptedKey)
+		cek, err := josecipher.KeyUnwrap(block, encryptedKey)
 		if err != nil {
 			return nil, err
 		}
-- 
2.53.0

