#include <freerdp/config.h>

#include <errno.h>
#include <limits.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

#include <winpr/crt.h>
#include <winpr/assert.h>
#include <winpr/synch.h>
#include <winpr/thread.h>
#include <winpr/stream.h>

#include "sshagent_main.h"

#include <freerdp/freerdp.h>
#include <freerdp/client/channels.h>
#include <freerdp/channels/log.h>
61 #define TAG CHANNELS_TAG("sshagent.client")
65 IWTSListenerCallback iface;
68 IWTSVirtualChannelManager* channel_mgr;
70 rdpContext* rdpcontext;
71 const char* agent_uds_path;
72 } SSHAGENT_LISTENER_CALLBACK;
78 rdpContext* rdpcontext;
82 } SSHAGENT_CHANNEL_CALLBACK;
88 SSHAGENT_LISTENER_CALLBACK* listener_callback;
90 rdpContext* rdpcontext;
98 static int connect_to_sshagent(
const char* udspath)
100 WINPR_ASSERT(udspath);
102 int agent_fd = socket(AF_UNIX, SOCK_STREAM, 0);
106 WLog_ERR(TAG,
"Can't open Unix domain socket!");
110 struct sockaddr_un addr = { 0 };
112 addr.sun_family = AF_UNIX;
114 strncpy(addr.sun_path, udspath,
sizeof(addr.sun_path) - 1);
116 int rc = connect(agent_fd, (
struct sockaddr*)&addr,
sizeof(addr));
120 WLog_ERR(TAG,
"Can't connect to Unix domain socket \"%s\"!", udspath);
134 static DWORD WINAPI sshagent_read_thread(LPVOID data)
136 SSHAGENT_CHANNEL_CALLBACK* callback = (SSHAGENT_CHANNEL_CALLBACK*)data;
137 WINPR_ASSERT(callback);
139 BYTE buffer[4096] = { 0 };
141 UINT status = CHANNEL_RC_OK;
145 const ssize_t bytes_read = read(callback->agent_fd, buffer,
sizeof(buffer));
152 else if (bytes_read < 0)
156 WLog_ERR(TAG,
"Error reading from sshagent, errno=%d", errno);
157 status = ERROR_READ_FAULT;
161 else if ((
size_t)bytes_read > ULONG_MAX)
163 status = ERROR_READ_FAULT;
169 IWTSVirtualChannel* channel = callback->generic.channel;
170 status = channel->Write(channel, (ULONG)bytes_read, buffer, NULL);
172 if (status != CHANNEL_RC_OK)
179 close(callback->agent_fd);
181 if (status != CHANNEL_RC_OK)
182 setChannelError(callback->rdpcontext, status,
"sshagent_read_thread reported an error");
193 static UINT sshagent_on_data_received(IWTSVirtualChannelCallback* pChannelCallback,
wStream* data)
195 SSHAGENT_CHANNEL_CALLBACK* callback = (SSHAGENT_CHANNEL_CALLBACK*)pChannelCallback;
196 WINPR_ASSERT(callback);
198 BYTE* pBuffer = Stream_Pointer(data);
199 size_t cbSize = Stream_GetRemainingLength(data);
202 size_t bytes_to_write = cbSize;
205 while (bytes_to_write > 0)
207 const ssize_t bytes_written = write(callback->agent_fd, pos, bytes_to_write);
209 if (bytes_written < 0)
213 WLog_ERR(TAG,
"Error writing to sshagent, errno=%d", errno);
214 return ERROR_WRITE_FAULT;
219 bytes_to_write -= WINPR_ASSERTING_INT_CAST(
size_t, bytes_written);
220 pos += bytes_written;
225 Stream_Seek(data, cbSize);
226 return CHANNEL_RC_OK;
234 static UINT sshagent_on_close(IWTSVirtualChannelCallback* pChannelCallback)
236 SSHAGENT_CHANNEL_CALLBACK* callback = (SSHAGENT_CHANNEL_CALLBACK*)pChannelCallback;
237 WINPR_ASSERT(callback);
240 shutdown(callback->agent_fd, SHUT_RDWR);
241 EnterCriticalSection(&callback->lock);
243 if (WaitForSingleObject(callback->thread, INFINITE) == WAIT_FAILED)
245 UINT error = GetLastError();
246 WLog_ERR(TAG,
"WaitForSingleObject failed with error %" PRIu32
"!", error);
250 (void)CloseHandle(callback->thread);
251 LeaveCriticalSection(&callback->lock);
252 DeleteCriticalSection(&callback->lock);
254 return CHANNEL_RC_OK;
262 static UINT sshagent_on_new_channel_connection(IWTSListenerCallback* pListenerCallback,
263 IWTSVirtualChannel* pChannel, BYTE* Data,
265 IWTSVirtualChannelCallback** ppCallback)
267 SSHAGENT_LISTENER_CALLBACK* listener_callback = (SSHAGENT_LISTENER_CALLBACK*)pListenerCallback;
269 WINPR_UNUSED(pbAccept);
271 SSHAGENT_CHANNEL_CALLBACK* callback =
272 (SSHAGENT_CHANNEL_CALLBACK*)calloc(1,
sizeof(SSHAGENT_CHANNEL_CALLBACK));
276 WLog_ERR(TAG,
"calloc failed!");
277 return CHANNEL_RC_NO_MEMORY;
282 callback->agent_fd = connect_to_sshagent(listener_callback->agent_uds_path);
284 if (callback->agent_fd == -1)
287 return CHANNEL_RC_INITIALIZATION_ERROR;
290 InitializeCriticalSection(&callback->lock);
293 generic->iface.OnDataReceived = sshagent_on_data_received;
294 generic->iface.OnClose = sshagent_on_close;
295 generic->plugin = listener_callback->plugin;
296 generic->channel_mgr = listener_callback->channel_mgr;
297 generic->channel = pChannel;
298 callback->rdpcontext = listener_callback->rdpcontext;
299 callback->thread = CreateThread(NULL, 0, sshagent_read_thread, (
void*)callback, 0, NULL);
301 if (!callback->thread)
303 WLog_ERR(TAG,
"CreateThread failed!");
304 DeleteCriticalSection(&callback->lock);
306 return CHANNEL_RC_INITIALIZATION_ERROR;
309 *ppCallback = (IWTSVirtualChannelCallback*)callback;
310 return CHANNEL_RC_OK;
318 static UINT sshagent_plugin_initialize(IWTSPlugin* pPlugin, IWTSVirtualChannelManager* pChannelMgr)
320 SSHAGENT_PLUGIN* sshagent = (SSHAGENT_PLUGIN*)pPlugin;
321 WINPR_ASSERT(sshagent);
322 WINPR_ASSERT(pChannelMgr);
324 sshagent->listener_callback =
325 (SSHAGENT_LISTENER_CALLBACK*)calloc(1,
sizeof(SSHAGENT_LISTENER_CALLBACK));
327 if (!sshagent->listener_callback)
329 WLog_ERR(TAG,
"calloc failed!");
330 return CHANNEL_RC_NO_MEMORY;
333 sshagent->listener_callback->rdpcontext = sshagent->rdpcontext;
334 sshagent->listener_callback->iface.OnNewChannelConnection = sshagent_on_new_channel_connection;
335 sshagent->listener_callback->plugin = pPlugin;
336 sshagent->listener_callback->channel_mgr = pChannelMgr;
337 sshagent->listener_callback->agent_uds_path = getenv(
"SSH_AUTH_SOCK");
339 if (sshagent->listener_callback->agent_uds_path == NULL)
341 WLog_ERR(TAG,
"Environment variable $SSH_AUTH_SOCK undefined!");
342 free(sshagent->listener_callback);
343 sshagent->listener_callback = NULL;
344 return CHANNEL_RC_INITIALIZATION_ERROR;
347 return pChannelMgr->CreateListener(pChannelMgr,
"SSHAGENT", 0,
348 (IWTSListenerCallback*)sshagent->listener_callback, NULL);
356 static UINT sshagent_plugin_terminated(IWTSPlugin* pPlugin)
358 SSHAGENT_PLUGIN* sshagent = (SSHAGENT_PLUGIN*)pPlugin;
360 return CHANNEL_RC_OK;
/*
 * DVC plugin entry point for the "sshagent" channel.
 * It looks up an existing "sshagent" plugin via GetPlugin. It also allocates a
 * new SSHAGENT_PLUGIN, wires its Initialize/Terminated callbacks and registers
 * it under the name "sshagent".
 * NOTE(review): the end of this function (and the guard around the allocation)
 * lies outside the visible excerpt — confirm against the full file.
 */
368 FREERDP_ENTRY_POINT(UINT VCAPITYPE sshagent_DVCPluginEntry(IDRDYNVC_ENTRY_POINTS* pEntryPoints))
370 UINT status = CHANNEL_RC_OK;
372 WINPR_ASSERT(pEntryPoints);
374 SSHAGENT_PLUGIN* sshagent = (SSHAGENT_PLUGIN*)pEntryPoints->GetPlugin(pEntryPoints,
"sshagent");
/* Presumably only reached when GetPlugin returned NULL — TODO confirm. */
378 sshagent = (SSHAGENT_PLUGIN*)calloc(1,
sizeof(SSHAGENT_PLUGIN));
382 WLog_ERR(TAG,
"calloc failed!");
383 return CHANNEL_RC_NO_MEMORY;
/* Connected/Disconnected are left unset; only Initialize and Terminated are used. */
386 sshagent->iface.Initialize = sshagent_plugin_initialize;
387 sshagent->iface.Connected = NULL;
388 sshagent->iface.Disconnected = NULL;
389 sshagent->iface.Terminated = sshagent_plugin_terminated;
390 sshagent->rdpcontext = pEntryPoints->GetRdpContext(pEntryPoints);
391 status = pEntryPoints->RegisterPlugin(pEntryPoints,
"sshagent", &sshagent->iface);