From aa66d5b12322f6c326aa7a826db9451170956fa0 Mon Sep 17 00:00:00 2001
From: Eugene Kliuchnikov <eustas@google.com>
Date: Wed, 4 Feb 2026 10:05:33 +0100
Subject: [PATCH] Fix #4549 (#4579)

In case of grayscale on either side in CmsStage (if LCMS was used) wrong assumption was made on the length / structure of buffers.
---
 lib/jxl/render_pipeline/stage_cms.cc | 42 +++++++++++++++-------------
 1 file changed, 23 insertions(+), 19 deletions(-)

diff -urp libjxl-0.10.3.orig/lib/jxl/render_pipeline/stage_cms.cc libjxl-0.10.3/lib/jxl/render_pipeline/stage_cms.cc
--- libjxl-0.10.3.orig/lib/jxl/render_pipeline/stage_cms.cc	2024-06-27 07:10:08.000000000 -0500
+++ libjxl-0.10.3/lib/jxl/render_pipeline/stage_cms.cc	2026-02-24 13:49:14.774005966 -0600
@@ -45,27 +45,45 @@ class CmsStage : public RenderPipelineSt
                     size_t xextra, size_t xsize, size_t xpos, size_t ypos,
                     size_t thread_id) const final {
     JXL_ASSERT(xsize <= xsize_);
-    // TODO(firsching): handle grey case seperately
-    //  interleave
-    float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0);
-    float* JXL_RESTRICT row1 = GetInputRow(input_rows, 1, 0);
-    float* JXL_RESTRICT row2 = GetInputRow(input_rows, 2, 0);
+    bool gray_src = (c_src_.Channels() == 1);
+    bool gray_dst = (output_encoding_info_.color_encoding.Channels() == 1);
     float* mutable_buf_src = color_space_transform->BufSrc(thread_id);
+    float* JXL_RESTRICT buf_dst = color_space_transform->BufDst(thread_id);
+    //  interleave
+    if (gray_src) {
+      float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0);
+      memcpy(mutable_buf_src, row0, xsize * sizeof(float));
+    } else {
+      float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0);
+      float* JXL_RESTRICT row1 = GetInputRow(input_rows, 1, 0);
+      float* JXL_RESTRICT row2 = GetInputRow(input_rows, 2, 0);
 
-    for (size_t x = 0; x < xsize; x++) {
-      mutable_buf_src[3 * x + 0] = row0[x];
-      mutable_buf_src[3 * x + 1] = row1[x];
-      mutable_buf_src[3 * x + 2] = row2[x];
+      for (size_t x = 0; x < xsize; x++) {
+        mutable_buf_src[3 * x + 0] = row0[x];
+        mutable_buf_src[3 * x + 1] = row1[x];
+        mutable_buf_src[3 * x + 2] = row2[x];
+      }
     }
     const float* buf_src = mutable_buf_src;
-    float* JXL_RESTRICT buf_dst = color_space_transform->BufDst(thread_id);
     JXL_RETURN_IF_ERROR(
         color_space_transform->Run(thread_id, buf_src, buf_dst, xsize));
     // de-interleave
-    for (size_t x = 0; x < xsize; x++) {
-      row0[x] = buf_dst[3 * x + 0];
-      row1[x] = buf_dst[3 * x + 1];
-      row2[x] = buf_dst[3 * x + 2];
+    if (gray_dst) {
+      float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0);
+      float* JXL_RESTRICT row1 = GetInputRow(input_rows, 1, 0);
+      float* JXL_RESTRICT row2 = GetInputRow(input_rows, 2, 0);
+      memcpy(row0, buf_dst, xsize * sizeof(float));
+      memcpy(row1, buf_dst, xsize * sizeof(float));
+      memcpy(row2, buf_dst, xsize * sizeof(float));
+    } else {
+      float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0);
+      float* JXL_RESTRICT row1 = GetInputRow(input_rows, 1, 0);
+      float* JXL_RESTRICT row2 = GetInputRow(input_rows, 2, 0);
+      for (size_t x = 0; x < xsize; x++) {
+        row0[x] = buf_dst[3 * x + 0];
+        row1[x] = buf_dst[3 * x + 1];
+        row2[x] = buf_dst[3 * x + 2];
+      }
     }
     return true;
   }
