FreeRDP
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Modules Pages
websocket.c
1
20#include "websocket.h"
21#include <freerdp/log.h>
22#include "../tcp.h"
23
24#define TAG FREERDP_TAG("core.gateway.websocket")
25
26struct s_websocket_context
27{
28 size_t payloadLength;
29 uint32_t maskingKey;
30 BOOL masking;
31 BOOL closeSent;
32 BYTE opcode;
33 BYTE fragmentOriginalOpcode;
34 BYTE lengthAndMaskPosition;
35 WEBSOCKET_STATE state;
36 wStream* responseStreamBuffer;
37};
38
39static int websocket_write_all(BIO* bio, const BYTE* data, size_t length);
40
41BOOL websocket_context_mask_and_send(BIO* bio, wStream* sPacket, wStream* sDataPacket,
42 UINT32 maskingKey)
43{
44 const size_t len = Stream_Length(sDataPacket);
45 Stream_SetPosition(sDataPacket, 0);
46
47 if (!Stream_EnsureRemainingCapacity(sPacket, len))
48 return FALSE;
49
50 /* mask as much as possible with 32bit access */
51 size_t streamPos = 0;
52 for (; streamPos + 4 <= len; streamPos += 4)
53 {
54 const uint32_t data = Stream_Get_UINT32(sDataPacket);
55 Stream_Write_UINT32(sPacket, data ^ maskingKey);
56 }
57
58 /* mask the rest byte by byte */
59 for (; streamPos < len; streamPos++)
60 {
61 BYTE data = 0;
62 BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4);
63 Stream_Read_UINT8(sDataPacket, data);
64 Stream_Write_UINT8(sPacket, data ^ *partialMask);
65 }
66
67 Stream_SealLength(sPacket);
68
69 ERR_clear_error();
70 const size_t size = Stream_Length(sPacket);
71 const int status = websocket_write_all(bio, Stream_Buffer(sPacket), size);
72 Stream_Free(sPacket, TRUE);
73
74 if ((status < 0) || ((size_t)status != size))
75 return FALSE;
76
77 return TRUE;
78}
79
80wStream* websocket_context_packet_new(size_t len, WEBSOCKET_OPCODE opcode, UINT32* pMaskingKey)
81{
82 WINPR_ASSERT(pMaskingKey);
83 if (len > INT_MAX)
84 return NULL;
85
86 size_t fullLen = 0;
87 if (len < 126)
88 fullLen = len + 6; /* 2 byte "mini header" + 4 byte masking key */
89 else if (len < 0x10000)
90 fullLen = len + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */
91 else
92 fullLen = len + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */
93
94 wStream* sWS = Stream_New(NULL, fullLen);
95 if (!sWS)
96 return NULL;
97
98 UINT32 maskingKey = 0;
99 winpr_RAND(&maskingKey, sizeof(maskingKey));
100
101 Stream_Write_UINT8(sWS, (UINT8)(WEBSOCKET_FIN_BIT | opcode));
102 if (len < 126)
103 Stream_Write_UINT8(sWS, (UINT8)len | WEBSOCKET_MASK_BIT);
104 else if (len < 0x10000)
105 {
106 Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
107 Stream_Write_UINT16_BE(sWS, (UINT16)len);
108 }
109 else
110 {
111 Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
112 Stream_Write_UINT32_BE(sWS, 0); /* payload is limited to INT_MAX */
113 Stream_Write_UINT32_BE(sWS, (UINT32)len);
114 }
115 Stream_Write_UINT32(sWS, maskingKey);
116 *pMaskingKey = maskingKey;
117 return sWS;
118}
119
120BOOL websocket_context_write_wstream(websocket_context* context, BIO* bio, wStream* sPacket,
121 WEBSOCKET_OPCODE opcode)
122{
123 WINPR_ASSERT(context);
124
125 if (context->closeSent)
126 return FALSE;
127
128 if (opcode == WebsocketCloseOpcode)
129 context->closeSent = TRUE;
130
131 WINPR_ASSERT(bio);
132 WINPR_ASSERT(sPacket);
133
134 const size_t len = Stream_Length(sPacket);
135 uint32_t maskingKey = 0;
136 wStream* sWS = websocket_context_packet_new(len, opcode, &maskingKey);
137 if (!sWS)
138 return FALSE;
139
140 return websocket_context_mask_and_send(bio, sWS, sPacket, maskingKey);
141}
142
143int websocket_write_all(BIO* bio, const BYTE* data, size_t length)
144{
145 WINPR_ASSERT(bio);
146 WINPR_ASSERT(data);
147 size_t offset = 0;
148
149 if (length > INT32_MAX)
150 return -1;
151
152 while (offset < length)
153 {
154 ERR_clear_error();
155 const size_t diff = length - offset;
156 int status = BIO_write(bio, &data[offset], (int)diff);
157
158 if (status > 0)
159 offset += (size_t)status;
160 else
161 {
162 if (!BIO_should_retry(bio))
163 return -1;
164
165 if (BIO_write_blocked(bio))
166 {
167 const long rstatus = BIO_wait_write(bio, 100);
168 if (rstatus < 0)
169 return -1;
170 }
171 else if (BIO_read_blocked(bio))
172 return -2; /* Abort write, there is data that must be read */
173 else
174 USleep(100);
175 }
176 }
177
178 return (int)length;
179}
180
181int websocket_context_write(websocket_context* context, BIO* bio, const BYTE* buf, int isize,
182 WEBSOCKET_OPCODE opcode)
183{
184 WINPR_ASSERT(bio);
185 WINPR_ASSERT(buf);
186
187 if (isize < 0)
188 return -1;
189
190 wStream sbuffer = { 0 };
191 wStream* s = Stream_StaticConstInit(&sbuffer, buf, (size_t)isize);
192 if (!websocket_context_write_wstream(context, bio, s, opcode))
193 return -2;
194 return isize;
195}
196
197static int websocket_read_data(BIO* bio, BYTE* pBuffer, size_t size,
198 websocket_context* encodingContext)
199{
200 int status = 0;
201
202 WINPR_ASSERT(bio);
203 WINPR_ASSERT(pBuffer);
204 WINPR_ASSERT(encodingContext);
205
206 if (encodingContext->payloadLength == 0)
207 {
208 encodingContext->state = WebsocketStateOpcodeAndFin;
209 return 0;
210 }
211
212 const size_t rlen =
213 (encodingContext->payloadLength < size ? encodingContext->payloadLength : size);
214 if (rlen > INT32_MAX)
215 return -1;
216
217 ERR_clear_error();
218 status = BIO_read(bio, pBuffer, (int)rlen);
219 if ((status <= 0) || ((size_t)status > encodingContext->payloadLength))
220 return status;
221
222 encodingContext->payloadLength -= (size_t)status;
223
224 if (encodingContext->payloadLength == 0)
225 encodingContext->state = WebsocketStateOpcodeAndFin;
226
227 return status;
228}
229
230static int websocket_read_wstream(BIO* bio, websocket_context* encodingContext)
231{
232 WINPR_ASSERT(bio);
233 WINPR_ASSERT(encodingContext);
234
235 wStream* s = encodingContext->responseStreamBuffer;
236 WINPR_ASSERT(s);
237
238 if (encodingContext->payloadLength == 0)
239 {
240 encodingContext->state = WebsocketStateOpcodeAndFin;
241 return 0;
242 }
243
244 if (!Stream_EnsureRemainingCapacity(s, encodingContext->payloadLength))
245 {
246 WLog_WARN(TAG,
247 "wStream::capacity [%" PRIuz "] != encodingContext::paylaodLangth [%" PRIuz "]",
248 Stream_GetRemainingCapacity(s), encodingContext->payloadLength);
249 return -1;
250 }
251
252 const int status = websocket_read_data(bio, Stream_Pointer(s), Stream_GetRemainingCapacity(s),
253 encodingContext);
254 if (status < 0)
255 return status;
256
257 if (!Stream_SafeSeek(s, (size_t)status))
258 return -1;
259
260 return status;
261}
262
263static BOOL websocket_reply_close(BIO* bio, websocket_context* context, wStream* s)
264{
265 WINPR_ASSERT(bio);
266
267 return websocket_context_write_wstream(context, bio, s, WebsocketCloseOpcode);
268}
269
270static BOOL websocket_reply_pong(BIO* bio, websocket_context* context, wStream* s)
271{
272 WINPR_ASSERT(bio);
273 WINPR_ASSERT(s);
274
275 if (Stream_GetPosition(s) != 0)
276 return websocket_context_write_wstream(context, bio, s, WebsocketPongOpcode);
277
278 return websocket_reply_close(bio, context, NULL);
279}
280
281static int websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size,
282 websocket_context* encodingContext)
283{
284 int status = 0;
285
286 WINPR_ASSERT(bio);
287 WINPR_ASSERT(pBuffer);
288 WINPR_ASSERT(encodingContext);
289
290 const BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode
291 ? encodingContext->fragmentOriginalOpcode & 0xf
292 : encodingContext->opcode & 0xf);
293
294 switch (effectiveOpcode)
295 {
296 case WebsocketBinaryOpcode:
297 {
298 status = websocket_read_data(bio, pBuffer, size, encodingContext);
299 if (status < 0)
300 return status;
301
302 return status;
303 }
304 case WebsocketPingOpcode:
305 {
306 status = websocket_read_wstream(bio, encodingContext);
307 if (status < 0)
308 return status;
309
310 if (encodingContext->payloadLength == 0)
311 {
312 websocket_reply_pong(bio, encodingContext, encodingContext->responseStreamBuffer);
313 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
314 }
315 }
316 break;
317 case WebsocketPongOpcode:
318 {
319 status = websocket_read_wstream(bio, encodingContext);
320 if (status < 0)
321 return status;
322 /* We don´t care about pong response data, discard. */
323 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
324 }
325 break;
326 case WebsocketCloseOpcode:
327 {
328 status = websocket_read_wstream(bio, encodingContext);
329 if (status < 0)
330 return status;
331
332 if (encodingContext->payloadLength == 0)
333 {
334 websocket_reply_close(bio, encodingContext, encodingContext->responseStreamBuffer);
335 encodingContext->closeSent = TRUE;
336 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
337 }
338 }
339 break;
340 default:
341 WLog_WARN(TAG, "Unimplemented websocket opcode %" PRIx8 ". Dropping", effectiveOpcode);
342
343 status = websocket_read_wstream(bio, encodingContext);
344 if (status < 0)
345 return status;
346 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
347 break;
348 }
349 /* return how many bytes have been written to pBuffer.
350 * Only WebsocketBinaryOpcode writes into it and it returns directly */
351 return 0;
352}
353
354int websocket_context_read(websocket_context* encodingContext, BIO* bio, BYTE* pBuffer, size_t size)
355{
356 int status = 0;
357 size_t effectiveDataLen = 0;
358
359 WINPR_ASSERT(bio);
360 WINPR_ASSERT(pBuffer);
361 WINPR_ASSERT(encodingContext);
362
363 while (TRUE)
364 {
365 switch (encodingContext->state)
366 {
367 case WebsocketStateOpcodeAndFin:
368 {
369 BYTE buffer[1] = { 0 };
370
371 ERR_clear_error();
372 status = BIO_read(bio, (char*)buffer, sizeof(buffer));
373 if (status <= 0)
374 return (effectiveDataLen > 0 ? WINPR_ASSERTING_INT_CAST(int, effectiveDataLen)
375 : status);
376
377 encodingContext->opcode = buffer[0];
378 if (((encodingContext->opcode & 0xf) != WebsocketContinuationOpcode) &&
379 (encodingContext->opcode & 0xf) < 0x08)
380 encodingContext->fragmentOriginalOpcode = encodingContext->opcode;
381 encodingContext->state = WebsocketStateLengthAndMasking;
382 }
383 break;
384 case WebsocketStateLengthAndMasking:
385 {
386 BYTE buffer[1] = { 0 };
387
388 ERR_clear_error();
389 status = BIO_read(bio, (char*)buffer, sizeof(buffer));
390 if (status <= 0)
391 return (effectiveDataLen > 0 ? WINPR_ASSERTING_INT_CAST(int, effectiveDataLen)
392 : status);
393
394 encodingContext->masking = ((buffer[0] & WEBSOCKET_MASK_BIT) == WEBSOCKET_MASK_BIT);
395 encodingContext->lengthAndMaskPosition = 0;
396 encodingContext->payloadLength = 0;
397 const BYTE len = buffer[0] & 0x7f;
398 if (len < 126)
399 {
400 encodingContext->payloadLength = len;
401 encodingContext->state = (encodingContext->masking ? WebSocketStateMaskingKey
402 : WebSocketStatePayload);
403 }
404 else if (len == 126)
405 encodingContext->state = WebsocketStateShortLength;
406 else
407 encodingContext->state = WebsocketStateLongLength;
408 }
409 break;
410 case WebsocketStateShortLength:
411 case WebsocketStateLongLength:
412 {
413 BYTE buffer[1] = { 0 };
414 const BYTE lenLength =
415 (encodingContext->state == WebsocketStateShortLength ? 2 : 8);
416 while (encodingContext->lengthAndMaskPosition < lenLength)
417 {
418 ERR_clear_error();
419 status = BIO_read(bio, (char*)buffer, sizeof(buffer));
420 if (status <= 0)
421 return (effectiveDataLen > 0
422 ? WINPR_ASSERTING_INT_CAST(int, effectiveDataLen)
423 : status);
424 if (status > UINT8_MAX)
425 return -1;
426 encodingContext->payloadLength =
427 (encodingContext->payloadLength) << 8 | buffer[0];
428 encodingContext->lengthAndMaskPosition +=
429 WINPR_ASSERTING_INT_CAST(BYTE, status);
430 }
431 encodingContext->state =
432 (encodingContext->masking ? WebSocketStateMaskingKey : WebSocketStatePayload);
433 }
434 break;
435 case WebSocketStateMaskingKey:
436 {
437 WLog_WARN(
438 TAG, "Websocket Server sends data with masking key. This is against RFC 6455.");
439 return -1;
440 }
441 case WebSocketStatePayload:
442 {
443 status = websocket_handle_payload(bio, pBuffer, size, encodingContext);
444 if (status < 0)
445 return (effectiveDataLen > 0 ? WINPR_ASSERTING_INT_CAST(int, effectiveDataLen)
446 : status);
447
448 effectiveDataLen += WINPR_ASSERTING_INT_CAST(size_t, status);
449
450 if (WINPR_ASSERTING_INT_CAST(size_t, status) >= size)
451 return WINPR_ASSERTING_INT_CAST(int, effectiveDataLen);
452 pBuffer += status;
453 size -= WINPR_ASSERTING_INT_CAST(size_t, status);
454 }
455 break;
456 default:
457 break;
458 }
459 }
460 /* should be unreachable */
461}
462
463websocket_context* websocket_context_new(void)
464{
465 websocket_context* context = calloc(1, sizeof(websocket_context));
466 if (!context)
467 goto fail;
468
469 context->responseStreamBuffer = Stream_New(NULL, 1024);
470 if (!context->responseStreamBuffer)
471 goto fail;
472
473 if (!websocket_context_reset(context))
474 goto fail;
475
476 return context;
477fail:
478 websocket_context_free(context);
479 return NULL;
480}
481
482void websocket_context_free(websocket_context* context)
483{
484 if (!context)
485 return;
486
487 Stream_Free(context->responseStreamBuffer, TRUE);
488 free(context);
489}
490
491BOOL websocket_context_reset(websocket_context* context)
492{
493 WINPR_ASSERT(context);
494
495 context->state = WebsocketStateOpcodeAndFin;
496 return Stream_SetPosition(context->responseStreamBuffer, 0);
497}