From 20419cf784dbf349dac05ea8937b3ce84ecacde6 Mon Sep 17 00:00:00 2001
From: Nicola Murino <nicola.murino@gmail.com>
Date: Sun, 15 Dec 2024 18:08:57 +0100
Subject: [PATCH] ssh: limit the size of the internal packet queue while
 waiting for KEX

In the SSH protocol, clients and servers execute the key exchange to
generate one-time session keys used for encryption and authentication.
The key exchange is performed initially after the connection is
established and then periodically after a configurable amount of data.
While a key exchange is in progress, we add the received packets to an
internal queue until we receive SSH_MSG_KEXINIT from the other side.
This can result in high memory usage if the other party is slow to
respond to the SSH_MSG_KEXINIT packet, or memory exhaustion if a
malicious client never responds to an SSH_MSG_KEXINIT packet during a
large file transfer.
We now limit the internal queue to 64 packets: this means 2MB with the
typical 32KB packet size.
When the internal queue is full we block further writes until the
pending key exchange is completed or there is a read or write error.

Thanks to Yuichi Watanabe for reporting this issue.

Change-Id: I1ce2214cc16e08b838d4bc346c74c72addafaeec
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/652135
Reviewed-by: Neal Patel <nealpatel@google.com>
Auto-Submit: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
---
 ssh/handshake.go | 41 ++++++++++++++++++++++++++++++++++-------
 1 file changed, 34 insertions(+), 7 deletions(-)

diff --git a/ssh/handshake.go b/ssh/handshake.go
index 2b10b05..30ef0c3 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -24,6 +24,11 @@ const debugHandshake = false
 // quickly.
 const chanSize = 16
 
+// maxPendingPackets sets the maximum number of packets to queue while waiting
+// for KEX to complete. This limits the total pending data to maxPendingPackets
+// * maxPacket bytes, which is ~16.8MB.
+const maxPendingPackets = 64
+
 // keyingTransport is a packet based transport that supports key
 // changes. It need not be thread-safe. It should pass through
 // msgNewKeys in both directions.
@@ -58,11 +63,19 @@ type handshakeTransport struct {
 	incoming  chan []byte
 	readError error
 
-	mu             sync.Mutex
+	mu sync.Mutex
+	// Condition for the above mutex. It is used to notify a completed key
+	// exchange or a write failure. Writes can wait for this condition while a
+	// key exchange is in progress.
+	writeCond      *sync.Cond
 	writeError     error
 	sentInitPacket []byte
 	sentInitMsg    *kexInitMsg
-	pendingPackets [][]byte // Used when a key exchange is in progress.
+	// Used to queue writes when a key exchange is in progress. The length is
+	// limited by pendingPacketsSize. Once full, writes will block until the key
+	// exchange is completed or an error occurs. If not empty, it is emptied
+	// all at once when the key exchange is completed in kexLoop.
+	pendingPackets   [][]byte
 
 	// If the read loop wants to schedule a kex, it pings this
 	// channel, and the write loop will send out a kex
@@ -112,6 +125,7 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion,
 
 		config: config,
 	}
+	t.writeCond = sync.NewCond(&t.mu)
 	t.resetReadThresholds()
 	t.resetWriteThresholds()
 
@@ -234,6 +248,7 @@ func (t *handshakeTransport) recordWriteError(err error) {
 	defer t.mu.Unlock()
 	if t.writeError == nil && err != nil {
 		t.writeError = err
+		t.writeCond.Broadcast()
 	}
 }
 
@@ -337,6 +352,8 @@ write:
 			}
 		}
 		t.pendingPackets = t.pendingPackets[:0]
+		// Unblock writePacket if waiting for KEX.
+		t.writeCond.Broadcast()
 		t.mu.Unlock()
 	}
 
@@ -494,11 +511,20 @@ func (t *handshakeTransport) writePacket(p []byte) error {
 	}
 
 	if t.sentInitMsg != nil {
-		// Copy the packet so the writer can reuse the buffer.
-		cp := make([]byte, len(p))
-		copy(cp, p)
-		t.pendingPackets = append(t.pendingPackets, cp)
-		return nil
+		if len(t.pendingPackets) < maxPendingPackets {
+			// Copy the packet so the writer can reuse the buffer.
+			cp := make([]byte, len(p))
+			copy(cp, p)
+			t.pendingPackets = append(t.pendingPackets, cp)
+			return nil
+		}
+		for t.sentInitMsg != nil {
+			// Block and wait for KEX to complete or an error.
+			t.writeCond.Wait()
+			if t.writeError != nil {
+				return t.writeError
+			}
+		}
 	}
 
 	if t.writeBytesLeft > 0 {
@@ -515,6 +541,7 @@ func (t *handshakeTransport) writePacket(p []byte) error {
 
 	if err := t.pushPacket(p); err != nil {
 		t.writeError = err
+		t.writeCond.Broadcast()
 	}
 
 	return nil
-- 
2.53.0

