From 4937fe6cf241139ddbfc16b0bdbb5b422798909d Mon Sep 17 00:00:00 2001
From: Jelle Zijlstra <jelle.zijlstra@gmail.com>
Date: Wed, 11 Mar 2026 19:57:24 -0700
Subject: [PATCH] Fix some shenanigans with the cache file and IPython (#5038)

---
 CHANGES.md                       |  7 +++++++
 src/black/handle_ipynb_magics.py | 21 +++++++++++++++++----
 src/black/mode.py                |  7 +++----
 tests/test_black.py              |  9 +++++++++
 tests/test_ipynb.py              | 20 ++++++++++++++++++--
 5 files changed, 54 insertions(+), 10 deletions(-)

Index: black-25.1.0/src/black/handle_ipynb_magics.py
===================================================================
--- black-25.1.0.orig/src/black/handle_ipynb_magics.py
+++ black-25.1.0/src/black/handle_ipynb_magics.py
@@ -5,6 +5,8 @@ import collections
 import dataclasses
 import re
 import secrets
+import string
+from collections.abc import Collection
 import sys
 from functools import lru_cache
 from importlib.util import find_spec
@@ -194,6 +196,13 @@ def mask_cell(src: str) -> tuple[str, li
 def create_token(n_chars: int) -> str:
     """Create a randomly generated token that is n_chars characters long."""
     assert n_chars > 0
+    if n_chars == 1:
+        return secrets.choice(string.ascii_letters)
+    if n_chars < 4:
+        return "_" + "".join(
+            secrets.choice(string.ascii_letters + string.digits + "_")
+            for _ in range(n_chars - 1)
+        )
     n_bytes = max(n_chars // 2 - 1, 1)
     token = secrets.token_hex(n_bytes)
     if len(token) + 3 > n_chars:
@@ -203,7 +212,7 @@ def create_token(n_chars: int) -> str:
     return f'b"{token}"'
 
 
-def get_token(src: str, magic: str) -> str:
+def get_token(src: str, magic: str, existing_tokens: Collection[str] = ()) -> str:
     """Return randomly generated token to mask IPython magic with.
 
     For example, if 'magic' was `%matplotlib inline`, then a possible
@@ -215,7 +224,7 @@ def get_token(src: str, magic: str) -> s
     n_chars = len(magic)
     token = create_token(n_chars)
     counter = 0
-    while token in src:
+    while token in src or token in existing_tokens:
         token = create_token(n_chars)
         counter += 1
         if counter > 100:
@@ -277,6 +286,7 @@ def replace_magics(src: str) -> tuple[st
     The replacement, along with the transformed code, are returned.
     """
     replacements = []
+    existing_tokens: set[str] = set()
     magic_finder = MagicFinder()
     magic_finder.visit(ast.parse(src))
     new_srcs = []
@@ -292,8 +302,9 @@ def replace_magics(src: str) -> tuple[st
                 offsets_and_magics[0].col_offset,
                 offsets_and_magics[0].magic,
             )
-            mask = get_token(src, magic)
+            mask = get_token(src, magic, existing_tokens)
             replacements.append(Replacement(mask=mask, src=magic))
+            existing_tokens.add(mask)
             line = line[:col_offset] + mask
         new_srcs.append(line)
     return "\n".join(new_srcs), replacements
@@ -313,7 +324,9 @@ def unmask_cell(src: str, replacements:
         foo = bar
     """
     for replacement in replacements:
-        src = src.replace(replacement.mask, replacement.src)
+        if src.count(replacement.mask) != 1:
+            raise NothingChanged
+        src = src.replace(replacement.mask, replacement.src, 1)
     return src
 
 
Index: black-25.1.0/src/black/mode.py
===================================================================
--- black-25.1.0.orig/src/black/mode.py
+++ black-25.1.0/src/black/mode.py
@@ -267,10 +267,9 @@ class Mode:
             + "@"
             + ",".join(sorted(self.python_cell_magics))
         )
-        if len(features_and_magics) > _MAX_CACHE_KEY_PART_LENGTH:
-            features_and_magics = sha256(features_and_magics.encode()).hexdigest()[
-                :_MAX_CACHE_KEY_PART_LENGTH
-            ]
+        features_and_magics = sha256(features_and_magics.encode()).hexdigest()[
+            :_MAX_CACHE_KEY_PART_LENGTH
+        ]
         parts = [
             version_str,
             str(self.line_length),
Index: black-25.1.0/tests/test_black.py
===================================================================
--- black-25.1.0.orig/tests/test_black.py
+++ black-25.1.0/tests/test_black.py
@@ -2129,6 +2129,15 @@ class TestCaching:
             # doesn't get too crazy.
             assert len(cache_file.name) <= 96
 
+    def test_cache_file_path_ignores_python_cell_magic_separators(self) -> None:
+        mode = replace(DEFAULT_MODE, python_cell_magics={"../../../tmp/pwned"})
+        with cache_dir() as workspace:
+            cache_file = get_cache_file(mode)
+            assert cache_file.parent == workspace
+            assert "/" not in cache_file.name
+            assert ".." not in cache_file.name
+            assert "../../../tmp/pwned" not in mode.get_cache_key()
+
     def test_cache_broken_file(self) -> None:
         mode = DEFAULT_MODE
         with cache_dir() as workspace:
Index: black-25.1.0/tests/test_ipynb.py
===================================================================
--- black-25.1.0.orig/tests/test_ipynb.py
+++ black-25.1.0/tests/test_ipynb.py
@@ -6,8 +6,8 @@ from contextlib import ExitStack as does
 from dataclasses import replace
 
 import pytest
-from _pytest.monkeypatch import MonkeyPatch
 from click.testing import CliRunner
+from pytest import MonkeyPatch
 
 from black import (
     Mode,
@@ -17,7 +17,12 @@ from black import (
     format_file_in_place,
     main,
 )
-from black.handle_ipynb_magics import jupyter_dependencies_are_installed
+from black.handle_ipynb_magics import (
+    Replacement,
+    create_token,
+    jupyter_dependencies_are_installed,
+    unmask_cell,
+)
 from tests.util import DATA_DIR, get_case_path, read_jupyter_notebook
 
 with contextlib.suppress(ModuleNotFoundError):
@@ -39,6 +44,17 @@ def test_noop() -> None:
         format_cell(src, fast=True, mode=JUPYTER_MODE)
 
 
+@pytest.mark.parametrize("n_chars", [1, 2, 3, 4, 5, 17])
+def test_create_token_uses_requested_length(n_chars: int) -> None:
+    assert len(create_token(n_chars)) == n_chars
+
+
+def test_unmask_cell_raises_when_token_is_not_unique() -> None:
+    replacement = Replacement(mask='b"dead"', src="%time")
+    with pytest.raises(NothingChanged):
+        unmask_cell(f"{replacement.mask}\nvalue = {replacement.mask}", [replacement])
+
+
 @pytest.mark.parametrize("fast", [True, False])
 def test_trailing_semicolon(fast: bool) -> None:
     src = 'foo = "a" ;'
