From d9ca272dce7a776ab475e9b1a8e8c3d2968c8486 Mon Sep 17 00:00:00 2001
From: akallabeth <akallabeth@posteo.net>
Date: Mon, 26 Jan 2026 12:08:48 +0100
Subject: [PATCH] [channels,ainput] lock context when updating listener

---
 channels/ainput/client/ainput_main.c | 36 ++++++++++++++++++++--------
 1 file changed, 26 insertions(+), 10 deletions(-)

diff --git a/channels/ainput/client/ainput_main.c b/channels/ainput/client/ainput_main.c
index c291bd727..554575360 100644
--- a/channels/ainput/client/ainput_main.c
+++ b/channels/ainput/client/ainput_main.c
@@ -45,6 +45,7 @@ struct AINPUT_PLUGIN_
 	AInputClientContext* context;
 	UINT32 MajorVersion;
 	UINT32 MinorVersion;
+	CRITICAL_SECTION lock;
 };
 
 /**
@@ -85,18 +86,15 @@ static UINT ainput_on_data_received(IWTSVirtualChannelCallback* pChannelCallback
 
 static UINT ainput_send_input_event(AInputClientContext* context, UINT64 flags, INT32 x, INT32 y)
 {
-	AINPUT_PLUGIN* ainput = NULL;
-	GENERIC_CHANNEL_CALLBACK* callback = NULL;
 	BYTE buffer[32] = { 0 };
-	UINT64 time = 0;
 	wStream sbuffer = { 0 };
 	wStream* s = Stream_StaticInit(&sbuffer, buffer, sizeof(buffer));
 
 	WINPR_ASSERT(s);
 	WINPR_ASSERT(context);
 
-	time = GetTickCount64();
-	ainput = (AINPUT_PLUGIN*)context->handle;
+	const UINT64 time = GetTickCount64();
+	AINPUT_PLUGIN* ainput = (AINPUT_PLUGIN*)context->handle;
 	WINPR_ASSERT(ainput);
 
 	if (ainput->MajorVersion != AINPUT_VERSION_MAJOR)
@@ -105,8 +103,6 @@ static UINT ainput_send_input_event(AInputClientContext* context, UINT64 flags,
 		          ainput->MajorVersion, ainput->MinorVersion);
 		return CHANNEL_RC_UNSUPPORTED_VERSION;
 	}
-	callback = ainput->base.listener_callback->channel_callback;
-	WINPR_ASSERT(callback);
 
 	{
 		char ebuffer[128] = { 0 };
@@ -125,10 +121,15 @@ static UINT ainput_send_input_event(AInputClientContext* context, UINT64 flags,
 	Stream_SealLength(s);
 
 	/* ainput back what we have received. AINPUT does not have any message IDs. */
+	EnterCriticalSection(&ainput->lock);
+	GENERIC_CHANNEL_CALLBACK* callback = ainput->base.listener_callback->channel_callback;
+	WINPR_ASSERT(callback);
 	WINPR_ASSERT(callback->channel);
 	WINPR_ASSERT(callback->channel->Write);
-	return callback->channel->Write(callback->channel, (ULONG)Stream_Length(s), Stream_Buffer(s),
-	                                NULL);
+	const UINT rc = callback->channel->Write(callback->channel, (ULONG)Stream_Length(s),
+	                                         Stream_Buffer(s), NULL);
+	LeaveCriticalSection(&ainput->lock);
+	return rc;
 }
 
 /**
@@ -140,8 +141,16 @@ static UINT ainput_on_close(IWTSVirtualChannelCallback* pChannelCallback)
 {
 	GENERIC_CHANNEL_CALLBACK* callback = (GENERIC_CHANNEL_CALLBACK*)pChannelCallback;
 
-	free(callback);
+	if (callback)
+	{
+		AINPUT_PLUGIN* ainput = (AINPUT_PLUGIN*)callback->plugin;
+		WINPR_ASSERT(ainput);
 
+		/* Lock here to ensure that no ainput_send_input_event is in progress. */
+		EnterCriticalSection(&ainput->lock);
+		free(callback);
+		LeaveCriticalSection(&ainput->lock);
+	}
 	return CHANNEL_RC_OK;
 }
 
@@ -156,14 +165,21 @@ static UINT init_plugin_cb(GENERIC_DYNVC_PLUGIN* base, WINPR_ATTR_UNUSED rdpCont
 	context->handle = (void*)base;
 	context->AInputSendInputEvent = ainput_send_input_event;
 
+	InitializeCriticalSection(&ainput->lock);
+
+	EnterCriticalSection(&ainput->lock);
 	ainput->context = context;
 	ainput->base.iface.pInterface = context;
+	LeaveCriticalSection(&ainput->lock);
 	return CHANNEL_RC_OK;
 }
 
 static void terminate_plugin_cb(GENERIC_DYNVC_PLUGIN* base)
 {
 	AINPUT_PLUGIN* ainput = (AINPUT_PLUGIN*)base;
+	WINPR_ASSERT(ainput);
+
+	DeleteCriticalSection(&ainput->lock);
 	free(ainput->context);
 }
 
-- 
2.53.0

