38 #include <freerdp/config.h>
42 #include <sys/types.h>
43 #include <sys/socket.h>
49 #include <winpr/crt.h>
50 #include <winpr/synch.h>
51 #include <winpr/thread.h>
52 #include <winpr/stream.h>
54 #include "sshagent_main.h"
56 #include <freerdp/freerdp.h>
57 #include <freerdp/client/channels.h>
58 #include <freerdp/channels/log.h>
60 #define TAG CHANNELS_TAG("sshagent.client")
64 IWTSListenerCallback iface;
67 IWTSVirtualChannelManager* channel_mgr;
69 rdpContext* rdpcontext;
70 const char* agent_uds_path;
71 } SSHAGENT_LISTENER_CALLBACK;
77 rdpContext* rdpcontext;
81 } SSHAGENT_CHANNEL_CALLBACK;
87 SSHAGENT_LISTENER_CALLBACK* listener_callback;
89 rdpContext* rdpcontext;
97 static int connect_to_sshagent(
const char* udspath)
99 int agent_fd = socket(AF_UNIX, SOCK_STREAM, 0);
103 WLog_ERR(TAG,
"Can't open Unix domain socket!");
107 struct sockaddr_un addr = { 0 };
109 addr.sun_family = AF_UNIX;
111 strncpy(addr.sun_path, udspath,
sizeof(addr.sun_path) - 1);
113 int rc = connect(agent_fd, (
struct sockaddr*)&addr,
sizeof(addr));
117 WLog_ERR(TAG,
"Can't connect to Unix domain socket \"%s\"!", udspath);
131 static DWORD WINAPI sshagent_read_thread(LPVOID data)
133 SSHAGENT_CHANNEL_CALLBACK* callback = (SSHAGENT_CHANNEL_CALLBACK*)data;
134 BYTE buffer[4096] = { 0 };
136 UINT status = CHANNEL_RC_OK;
140 const ssize_t bytes_read = read(callback->agent_fd, buffer,
sizeof(buffer));
147 else if (bytes_read < 0)
151 WLog_ERR(TAG,
"Error reading from sshagent, errno=%d", errno);
152 status = ERROR_READ_FAULT;
156 else if ((
size_t)bytes_read > ULONG_MAX)
158 status = ERROR_READ_FAULT;
164 IWTSVirtualChannel* channel = callback->generic.channel;
165 status = channel->Write(channel, (ULONG)bytes_read, buffer, NULL);
167 if (status != CHANNEL_RC_OK)
174 close(callback->agent_fd);
176 if (status != CHANNEL_RC_OK)
177 setChannelError(callback->rdpcontext, status,
"sshagent_read_thread reported an error");
188 static UINT sshagent_on_data_received(IWTSVirtualChannelCallback* pChannelCallback,
wStream* data)
190 SSHAGENT_CHANNEL_CALLBACK* callback = (SSHAGENT_CHANNEL_CALLBACK*)pChannelCallback;
191 BYTE* pBuffer = Stream_Pointer(data);
192 size_t cbSize = Stream_GetRemainingLength(data);
195 size_t bytes_to_write = cbSize;
198 while (bytes_to_write > 0)
200 const ssize_t bytes_written = write(callback->agent_fd, pos, bytes_to_write);
202 if (bytes_written < 0)
206 WLog_ERR(TAG,
"Error writing to sshagent, errno=%d", errno);
207 return ERROR_WRITE_FAULT;
212 bytes_to_write -= bytes_written;
213 pos += bytes_written;
218 Stream_Seek(data, cbSize);
219 return CHANNEL_RC_OK;
227 static UINT sshagent_on_close(IWTSVirtualChannelCallback* pChannelCallback)
229 SSHAGENT_CHANNEL_CALLBACK* callback = (SSHAGENT_CHANNEL_CALLBACK*)pChannelCallback;
231 shutdown(callback->agent_fd, SHUT_RDWR);
232 EnterCriticalSection(&callback->lock);
234 if (WaitForSingleObject(callback->thread, INFINITE) == WAIT_FAILED)
236 UINT error = GetLastError();
237 WLog_ERR(TAG,
"WaitForSingleObject failed with error %" PRIu32
"!", error);
241 (void)CloseHandle(callback->thread);
242 LeaveCriticalSection(&callback->lock);
243 DeleteCriticalSection(&callback->lock);
245 return CHANNEL_RC_OK;
253 static UINT sshagent_on_new_channel_connection(IWTSListenerCallback* pListenerCallback,
254 IWTSVirtualChannel* pChannel, BYTE* Data,
256 IWTSVirtualChannelCallback** ppCallback)
258 SSHAGENT_LISTENER_CALLBACK* listener_callback = (SSHAGENT_LISTENER_CALLBACK*)pListenerCallback;
259 SSHAGENT_CHANNEL_CALLBACK* callback =
260 (SSHAGENT_CHANNEL_CALLBACK*)calloc(1,
sizeof(SSHAGENT_CHANNEL_CALLBACK));
264 WLog_ERR(TAG,
"calloc failed!");
265 return CHANNEL_RC_NO_MEMORY;
270 callback->agent_fd = connect_to_sshagent(listener_callback->agent_uds_path);
272 if (callback->agent_fd == -1)
275 return CHANNEL_RC_INITIALIZATION_ERROR;
278 InitializeCriticalSection(&callback->lock);
281 generic->iface.OnDataReceived = sshagent_on_data_received;
282 generic->iface.OnClose = sshagent_on_close;
283 generic->plugin = listener_callback->plugin;
284 generic->channel_mgr = listener_callback->channel_mgr;
285 generic->channel = pChannel;
286 callback->rdpcontext = listener_callback->rdpcontext;
287 callback->thread = CreateThread(NULL, 0, sshagent_read_thread, (
void*)callback, 0, NULL);
289 if (!callback->thread)
291 WLog_ERR(TAG,
"CreateThread failed!");
292 DeleteCriticalSection(&callback->lock);
294 return CHANNEL_RC_INITIALIZATION_ERROR;
297 *ppCallback = (IWTSVirtualChannelCallback*)callback;
298 return CHANNEL_RC_OK;
306 static UINT sshagent_plugin_initialize(IWTSPlugin* pPlugin, IWTSVirtualChannelManager* pChannelMgr)
308 SSHAGENT_PLUGIN* sshagent = (SSHAGENT_PLUGIN*)pPlugin;
309 sshagent->listener_callback =
310 (SSHAGENT_LISTENER_CALLBACK*)calloc(1,
sizeof(SSHAGENT_LISTENER_CALLBACK));
312 if (!sshagent->listener_callback)
314 WLog_ERR(TAG,
"calloc failed!");
315 return CHANNEL_RC_NO_MEMORY;
318 sshagent->listener_callback->rdpcontext = sshagent->rdpcontext;
319 sshagent->listener_callback->iface.OnNewChannelConnection = sshagent_on_new_channel_connection;
320 sshagent->listener_callback->plugin = pPlugin;
321 sshagent->listener_callback->channel_mgr = pChannelMgr;
322 sshagent->listener_callback->agent_uds_path = getenv(
"SSH_AUTH_SOCK");
324 if (sshagent->listener_callback->agent_uds_path == NULL)
326 WLog_ERR(TAG,
"Environment variable $SSH_AUTH_SOCK undefined!");
327 free(sshagent->listener_callback);
328 sshagent->listener_callback = NULL;
329 return CHANNEL_RC_INITIALIZATION_ERROR;
332 return pChannelMgr->CreateListener(pChannelMgr,
"SSHAGENT", 0,
333 (IWTSListenerCallback*)sshagent->listener_callback, NULL);
341 static UINT sshagent_plugin_terminated(IWTSPlugin* pPlugin)
343 SSHAGENT_PLUGIN* sshagent = (SSHAGENT_PLUGIN*)pPlugin;
345 return CHANNEL_RC_OK;
353 FREERDP_ENTRY_POINT(UINT VCAPITYPE sshagent_DVCPluginEntry(IDRDYNVC_ENTRY_POINTS* pEntryPoints))
355 UINT status = CHANNEL_RC_OK;
356 SSHAGENT_PLUGIN* sshagent = (SSHAGENT_PLUGIN*)pEntryPoints->GetPlugin(pEntryPoints,
"sshagent");
360 sshagent = (SSHAGENT_PLUGIN*)calloc(1,
sizeof(SSHAGENT_PLUGIN));
364 WLog_ERR(TAG,
"calloc failed!");
365 return CHANNEL_RC_NO_MEMORY;
368 sshagent->iface.Initialize = sshagent_plugin_initialize;
369 sshagent->iface.Connected = NULL;
370 sshagent->iface.Disconnected = NULL;
371 sshagent->iface.Terminated = sshagent_plugin_terminated;
372 sshagent->rdpcontext = pEntryPoints->GetRdpContext(pEntryPoints);
373 status = pEntryPoints->RegisterPlugin(pEntryPoints,
"sshagent", &sshagent->iface);