20 #include "websocket.h"
21 #include <freerdp/log.h>
24 #define TAG FREERDP_TAG("core.gateway.websocket")
/**
 * Frame the whole content of sPacket as one masked websocket frame (FIN bit
 * set, caller-supplied opcode) and send it with a single BIO_write call.
 *
 * @param bio     BIO the frame is written to
 * @param sPacket payload; it is read from position 0 up to Stream_Length()
 * @param opcode  websocket opcode placed in the first header byte
 * @return BOOL; the visible check compares the BIO_write result against the
 *         computed frame length fullLen (the exact return statements are not
 *         visible in this excerpt).
 *
 * NOTE(review): unlike websocket_write(), which goes through
 * websocket_write_all(), this function does not loop on partial writes. A
 * short write is simply a mismatch against fullLen.
 */
26 BOOL websocket_write_wstream(BIO* bio,
wStream* sPacket, WEBSOCKET_OPCODE opcode)
32 uint32_t maskingKey = 0;
37 WINPR_ASSERT(sPacket);
/* The payload is the sealed length of the input stream; rewind so it is read from the start. */
39 const size_t len = Stream_Length(sPacket);
40 Stream_SetPosition(sPacket, 0);
/* Frame-size selection by payload length (the other branches are not visible in this excerpt). */
47 else if (len < 0x10000)
52 sWS = Stream_New(NULL, fullLen);
/* Frames sent by the client are masked (RFC 6455 sections 5.1/5.3): draw a random 32-bit key. */
56 winpr_RAND(&maskingKey,
sizeof(maskingKey));
58 Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | opcode);
/* Short form: the length fits the 7-bit field (condition line not visible; presumably len < 126). */
60 Stream_Write_UINT8(sWS, (UINT8)len | WEBSOCKET_MASK_BIT);
61 else if (len < 0x10000)
/* Length marker 126: a 16-bit big-endian length follows. */
63 Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
64 Stream_Write_UINT16_BE(sWS, (UINT16)len);
/* Length marker 127: a 64-bit big-endian length follows. The upper 32 bits
 * are always written as 0 and len is cast to UINT32. Any upper bound check
 * on len is not visible in this excerpt. */
68 Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
69 Stream_Write_UINT32_BE(sWS, 0);
70 Stream_Write_UINT32_BE(sWS, (UINT32)len);
72 Stream_Write_UINT32(sWS, maskingKey);
/* Mask four bytes at a time. Reading the data word, writing the key and
 * writing the result all go through the same Stream_*_UINT32 byte order,
 * so this part lines up with the serialized key. */
75 for (streamPos = 0; streamPos + 4 <= len; streamPos += 4)
78 Stream_Read_UINT32(sPacket, data);
79 Stream_Write_UINT32(sWS, data ^ maskingKey);
/* Mask the remaining 0-3 bytes one at a time, using key byte streamPos % 4.
 * NOTE(review): this indexes maskingKey in host byte order, while the key
 * was serialized with Stream_Write_UINT32 above (little-endian in WinPR as
 * far as I know; confirm). On a big-endian host the tail bytes would be
 * masked with the wrong key byte. */
83 for (; streamPos < len; streamPos++)
86 BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4);
87 Stream_Read_UINT8(sPacket, data);
88 Stream_Write_UINT8(sWS, data ^ *partialMask);
91 Stream_SealLength(sWS);
/* BIO_write takes an int length, so only frames up to INT32_MAX bytes are written. */
94 const size_t size = Stream_Length(sWS);
95 if (size <= INT32_MAX)
96 status = BIO_write(bio, Stream_Buffer(sWS), (
int)size);
97 Stream_Free(sWS, TRUE);
/* Success requires the whole frame to have gone out in the single call above. */
99 if (status != (SSIZE_T)fullLen)
/**
 * Write `length` bytes of `data` to `bio`, looping while bytes remain.
 *
 * Lengths above INT32_MAX are rejected up front because BIO_write takes an
 * int. The rejection's return value is not visible in this excerpt.
 * Partial writes advance `offset`. Non-retryable errors stop the loop. When
 * the BIO reports it is write-blocked, BIO_wait_write(bio, 100) is called
 * (the 100 is presumably a timeout in milliseconds; confirm). Most of the
 * retry/return handling sits on lines not visible here.
 */
105 static int websocket_write_all(BIO* bio,
const BYTE* data,
size_t length)
111 if (length > INT32_MAX)
114 while (offset < length)
/* Try to push everything still pending; diff <= length <= INT32_MAX, so the int cast is safe. */
117 const size_t diff = length - offset;
118 int status = BIO_write(bio, &data[offset], (
int)diff);
/* Advance past the accepted bytes (presumably only reached when status > 0). */
121 offset += (size_t)status;
/* Failure path: an error the BIO does not flag as retryable ends the write. */
124 if (!BIO_should_retry(bio))
127 if (BIO_write_blocked(bio))
129 const long rstatus = BIO_wait_write(bio, 100);
/* NOTE(review): the body of the read-blocked branch is not visible in this excerpt. */
133 else if (BIO_read_blocked(bio))
/**
 * Send buf[0..isize) as one masked websocket frame with the FIN bit set.
 *
 * @param bio    BIO the frame is written to
 * @param buf    payload bytes (the signature gives no alignment guarantee)
 * @param isize  payload length in bytes
 * @param opcode websocket opcode placed in the first header byte
 * @return int; the frame is pushed out with websocket_write_all(), whose
 *         status feeds the result (the final return statement is not visible
 *         in this excerpt).
 */
143 int websocket_write(BIO* bio,
const BYTE* buf,
int isize, WEBSOCKET_OPCODE opcode)
149 uint32_t maskingKey = 0;
/* Random 32-bit key used to mask this outgoing (client) frame. */
154 winpr_RAND(&maskingKey,
sizeof(maskingKey));
/* NOTE(review): no guard against a negative isize is visible here. A
 * negative value converted to size_t would become huge; confirm it is
 * rejected on one of the hidden lines. */
159 const size_t payloadSize = (size_t)isize;
/* Frame size: 2-byte header + 4-byte masking key, plus 0, 2 or 8 bytes of extended length. */
160 if (payloadSize < 126)
161 fullLen = payloadSize + 6;
162 else if (payloadSize < 0x10000)
163 fullLen = payloadSize + 8;
165 fullLen = payloadSize + 14;
167 sWS = Stream_New(NULL, fullLen);
171 Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | opcode);
/* Payload length: inline 7-bit value, 126 + 16-bit BE length, or 127 + 64-bit BE length. */
172 if (payloadSize < 126)
173 Stream_Write_UINT8(sWS, (UINT8)payloadSize | WEBSOCKET_MASK_BIT);
174 else if (payloadSize < 0x10000)
176 Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
177 Stream_Write_UINT16_BE(sWS, (UINT16)payloadSize);
/* 64-bit length: the upper 32 bits are written as 0. isize is an int, so a
 * non-negative payload size always fits in the lower 32 bits. */
181 Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
183 Stream_Write_UINT32_BE(sWS, 0);
184 Stream_Write_UINT32_BE(sWS, (UINT32)payloadSize);
186 Stream_Write_UINT32(sWS, maskingKey);
/* Mask four bytes at a time.
 * NOTE(review): the loop reads buf through a cast uint32_t pointer. buf has
 * no alignment guarantee, so this is a potentially misaligned access, and it
 * violates strict aliasing. Copying the word with memcpy into a uint32_t
 * would be the portable form. The word is also read in host byte order but
 * written with Stream_Write_UINT32 (little-endian in WinPR; confirm), so the
 * masking only matches the serialized key on little-endian hosts. */
189 size_t streamPos = 0;
190 for (; streamPos + 4 <= payloadSize; streamPos += 4)
192 uint32_t masked = *((
const uint32_t*)(buf + streamPos)) ^ maskingKey;
193 Stream_Write_UINT32(sWS, masked);
/* Mask the remaining 0-3 bytes with key byte streamPos % 4 (host byte order, see the note above). */
197 for (; streamPos < payloadSize; streamPos++)
199 BYTE* partialMask = (BYTE*)(&maskingKey) + streamPos % 4;
200 BYTE masked = *((buf + streamPos)) ^ *partialMask;
201 Stream_Write_UINT8(sWS, masked);
204 Stream_SealLength(sWS);
/* Unlike websocket_write_wstream(), partial writes are handled by websocket_write_all(). */
206 status = websocket_write_all(bio, Stream_Buffer(sWS), Stream_Length(sWS));
207 Stream_Free(sWS, TRUE);
/**
 * Read payload bytes of the current frame directly into the caller's buffer.
 *
 * At most min(payloadLength, size) bytes are requested from the BIO. The
 * outstanding payloadLength is reduced by what was read. Once the frame's
 * payload is exhausted, the parser state is set back to
 * WebsocketStateOpcodeAndFin so the next frame header is parsed.
 */
215 static int websocket_read_data(BIO* bio, BYTE* pBuffer,
size_t size,
221 WINPR_ASSERT(pBuffer);
222 WINPR_ASSERT(encodingContext);
/* Nothing left in this frame: go back to parsing the next frame header. */
224 if (encodingContext->payloadLength == 0)
226 encodingContext->state = WebsocketStateOpcodeAndFin;
/* Never read past the end of the current frame or past the caller's buffer. */
231 (encodingContext->payloadLength < size ? encodingContext->payloadLength : size);
/* BIO_read takes an int length. */
232 if (rlen > INT32_MAX)
236 status = BIO_read(bio, pBuffer, (
int)rlen);
/* Presumably only reached when status > 0 (that check is on lines not visible here). */
240 encodingContext->payloadLength -= status;
242 if (encodingContext->payloadLength == 0)
243 encodingContext->state = WebsocketStateOpcodeAndFin;
/* Discard the payload of the current frame. This is presumably the body of
 * websocket_read_discard, which websocket_handle_payload calls; the function
 * header is not visible in this excerpt.
 * Payload is read into a 256-byte scratch buffer and dropped. The state
 * returns to WebsocketStateOpcodeAndFin once payloadLength reaches 0. */
250 char _dummy[256] = { 0 };
254 WINPR_ASSERT(encodingContext);
256 if (encodingContext->payloadLength == 0)
258 encodingContext->state = WebsocketStateOpcodeAndFin;
/* NOTE(review): this always requests sizeof(_dummy) bytes, with no visible
 * clamp to payloadLength. If fewer than 256 payload bytes remain, bytes of
 * the next frame can be consumed and thrown away. The subtraction below
 * would then wrap (payloadLength is printed with PRIuz elsewhere, so it is
 * unsigned). Requesting min(payloadLength, sizeof(_dummy)) would avoid both. */
263 status = BIO_read(bio, _dummy,
sizeof(_dummy));
267 encodingContext->payloadLength -= status;
269 if (encodingContext->payloadLength == 0)
270 encodingContext->state = WebsocketStateOpcodeAndFin;
/* Read the current frame's payload into stream s. This is presumably the
 * body of websocket_read_wstream, which websocket_handle_payload calls; the
 * function header is not visible in this excerpt.
 * The stream's remaining capacity must equal the outstanding payloadLength.
 * Each call requests all outstanding bytes and advances the stream by what
 * was read. When the payload is complete, the parser state is reset and the
 * stream is sealed and rewound to position 0 for the caller to parse. */
281 WINPR_ASSERT(encodingContext);
283 if (encodingContext->payloadLength == 0)
285 encodingContext->state = WebsocketStateOpcodeAndFin;
/* Invariant: the stream has room for exactly the rest of the payload. Seek
 * and the payloadLength decrement below shrink both sides by the same amount.
 * NOTE(review): the log text below misspells "paylaodLangth" (payloadLength). */
288 if (Stream_GetRemainingCapacity(s) != encodingContext->payloadLength)
291 "wStream::capacity [%" PRIuz
"] != encodingContext::paylaodLangth [%" PRIuz
"]",
292 Stream_GetRemainingCapacity(s), encodingContext->payloadLength);
/* BIO_read takes an int length. */
296 const size_t rlen = encodingContext->payloadLength;
297 if (rlen > INT32_MAX)
301 status = BIO_read(bio, Stream_Pointer(s), (
int)rlen);
/* Presumably only reached when status > 0 (that check is on lines not visible here). */
305 Stream_Seek(s, status);
307 encodingContext->payloadLength -= status;
/* Payload complete: resume header parsing and hand back a rewound, sealed stream. */
309 if (encodingContext->payloadLength == 0)
311 encodingContext->state = WebsocketStateOpcodeAndFin;
312 Stream_SealLength(s);
313 Stream_SetPosition(s, 0);
/**
 * Answer a received close frame with a masked close frame of our own.
 *
 * If s holds at least 2 bytes, the first 16-bit value of s is echoed back,
 * masked. closeDataLen is presumably set to 2 on a line not visible here.
 * Otherwise an empty close frame is sent. The 32-bit masking key is drawn
 * as two random 16-bit halves. The frame is sent with a single BIO_write,
 * with no partial-write retry.
 *
 * @param bio BIO the reply is written to
 * @param s   collected payload of the received close frame; may be NULL
 * @return BOOL; the exact return statements are not visible in this excerpt
 */
319 static BOOL websocket_reply_close(BIO* bio,
wStream* s)
323 uint16_t maskingKey1 = 0;
324 uint16_t maskingKey2 = 0;
325 size_t closeDataLen = 0;
330 if (s != NULL && Stream_Length(s) >= 2)
/* 2-byte header + 4-byte masking key + 0 or 2 bytes of close data. */
331 closeFrame = Stream_New(NULL, 6 + closeDataLen);
337 Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketCloseOpcode);
/* closeDataLen is at most 2 here, so it fits the inline 7-bit length field. */
338 Stream_Write_UINT8(closeFrame, closeDataLen | WEBSOCKET_MASK_BIT);
339 winpr_RAND(&maskingKey1,
sizeof(maskingKey1));
340 winpr_RAND(&maskingKey2,
sizeof(maskingKey2));
341 Stream_Write_UINT16(closeFrame, maskingKey1);
342 Stream_Write_UINT16(closeFrame, maskingKey2);
/* A 2-byte payload is masked only by the first two key bytes (maskingKey1).
 * Stream_Read_UINT16/Stream_Write_UINT16 use the same byte order as the key
 * write above, so the echoed bytes line up with the serialized key. */
344 if (closeDataLen == 2)
347 Stream_Read_UINT16(s, data);
348 Stream_Write_UINT16(closeFrame, data ^ maskingKey1);
350 Stream_SealLength(closeFrame);
/* BIO_write takes an int length. */
352 const size_t rlen = Stream_Length(closeFrame);
355 if (rlen <= INT32_MAX)
358 status = BIO_write(bio, Stream_Buffer(closeFrame), (
int)rlen);
360 Stream_Free(closeFrame, TRUE);
/**
 * Answer a ping frame with a pong.
 *
 * When a payload stream is available, it is echoed back through
 * websocket_write_wstream() with the pong opcode. The guarding condition is
 * not visible here; presumably it is s != NULL. Otherwise an empty masked
 * pong (2-byte header + 4-byte key, zero-length payload) is built and sent
 * with a single BIO_write.
 *
 * @param bio BIO the reply is written to
 * @param s   collected ping payload; may be NULL for an empty ping
 * @return BOOL; the exact return statements of the empty-pong path are not
 *         visible in this excerpt
 *
 * NOTE(review): the local frame variable is named closeFrame although it
 * carries a pong frame.
 */
369 static BOOL websocket_reply_pong(BIO* bio,
wStream* s)
372 uint32_t maskingKey = 0;
377 return websocket_write_wstream(bio, s, WebsocketPongOpcode);
/* Empty pong: 2-byte header + 4-byte masking key, no payload. */
379 closeFrame = Stream_New(NULL, 6);
383 Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketPongOpcode);
384 Stream_Write_UINT8(closeFrame, 0 | WEBSOCKET_MASK_BIT);
385 winpr_RAND(&maskingKey,
sizeof(maskingKey));
386 Stream_Write_UINT32(closeFrame, maskingKey);
387 Stream_SealLength(closeFrame);
/* BIO_write takes an int length. */
389 const size_t rlen = Stream_Length(closeFrame);
391 if (rlen <= INT32_MAX)
394 status = BIO_write(bio, Stream_Buffer(closeFrame), (
int)rlen);
396 Stream_Free(closeFrame, TRUE);
/**
 * Dispatch the payload of the current frame according to its opcode.
 *
 * For continuation frames, the opcode of the frame that started the
 * fragmented message is used. websocket_read records it in
 * fragmentOriginalOpcode for non-continuation opcodes below 0x08.
 *  - binary: payload is read straight into pBuffer (up to size bytes).
 *  - ping:   payload is collected into responseStreamBuffer. Once complete,
 *            a pong is sent (unless a close was already sent) and the
 *            buffer is freed.
 *  - close:  payload is collected. Once complete, a close reply is sent,
 *            closeSent is set and the buffer is freed.
 *  - the final branch (presumably `default`): a warning is logged and the
 *    payload is discarded.
 *
 * NOTE(review): responseStreamBuffer is sized from the peer-supplied
 * payloadLength. RFC 6455 section 5.5 limits control-frame payloads to 125
 * bytes, but no such limit is visible here, so a peer could request a very
 * large allocation via the 64-bit length form. The results of
 * websocket_reply_pong() and websocket_reply_close() are also not checked.
 */
403 static int websocket_handle_payload(BIO* bio, BYTE* pBuffer,
size_t size,
409 WINPR_ASSERT(pBuffer);
410 WINPR_ASSERT(encodingContext);
/* Continuation frames inherit the opcode of the fragment that started the message. */
412 BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode
413 ? encodingContext->fragmentOriginalOpcode & 0xf
414 : encodingContext->opcode & 0xf);
416 switch (effectiveOpcode)
418 case WebsocketBinaryOpcode:
420 status = websocket_read_data(bio, pBuffer, size, encodingContext);
426 case WebsocketPingOpcode:
/* First chunk of a ping payload: allocate a buffer for the whole payload. */
428 if (encodingContext->responseStreamBuffer == NULL)
429 encodingContext->responseStreamBuffer =
430 Stream_New(NULL, encodingContext->payloadLength);
433 websocket_read_wstream(bio, encodingContext->responseStreamBuffer, encodingContext);
/* Whole ping payload received: echo it back and release the buffer. */
437 if (encodingContext->payloadLength == 0)
439 if (!encodingContext->closeSent)
440 websocket_reply_pong(bio, encodingContext->responseStreamBuffer);
442 Stream_Free(encodingContext->responseStreamBuffer, TRUE);
443 encodingContext->responseStreamBuffer = NULL;
447 case WebsocketCloseOpcode:
/* First chunk of a close payload: allocate a buffer for the whole payload. */
449 if (encodingContext->responseStreamBuffer == NULL)
450 encodingContext->responseStreamBuffer =
451 Stream_New(NULL, encodingContext->payloadLength);
454 websocket_read_wstream(bio, encodingContext->responseStreamBuffer, encodingContext);
/* Whole close payload received: answer it, remember that a close was sent, release the buffer. */
458 if (encodingContext->payloadLength == 0)
460 websocket_reply_close(bio, encodingContext->responseStreamBuffer);
461 encodingContext->closeSent = TRUE;
463 if (encodingContext->responseStreamBuffer)
464 Stream_Free(encodingContext->responseStreamBuffer, TRUE);
465 encodingContext->responseStreamBuffer = NULL;
/* Any other opcode: log it and drop the payload so the stream stays in sync. */
470 WLog_WARN(TAG,
"Unimplemented websocket opcode %x. Dropping", effectiveOpcode & 0xf);
472 status = websocket_read_discard(bio, encodingContext);
481 int websocket_read(BIO* bio, BYTE* pBuffer,
size_t size,
websocket_context* encodingContext)
484 int effectiveDataLen = 0;
487 WINPR_ASSERT(pBuffer);
488 WINPR_ASSERT(encodingContext);
492 switch (encodingContext->state)
494 case WebsocketStateOpcodeAndFin:
498 status = BIO_read(bio, (
char*)buffer,
sizeof(buffer));
500 return (effectiveDataLen > 0 ? effectiveDataLen : status);
502 encodingContext->opcode = buffer[0];
503 if (((encodingContext->opcode & 0xf) != WebsocketContinuationOpcode) &&
504 (encodingContext->opcode & 0xf) < 0x08)
505 encodingContext->fragmentOriginalOpcode = encodingContext->opcode;
506 encodingContext->state = WebsocketStateLengthAndMasking;
509 case WebsocketStateLengthAndMasking:
514 status = BIO_read(bio, (
char*)buffer,
sizeof(buffer));
516 return (effectiveDataLen > 0 ? effectiveDataLen : status);
518 encodingContext->masking = ((buffer[0] & WEBSOCKET_MASK_BIT) == WEBSOCKET_MASK_BIT);
519 encodingContext->lengthAndMaskPosition = 0;
520 encodingContext->payloadLength = 0;
521 len = buffer[0] & 0x7f;
524 encodingContext->payloadLength = len;
525 encodingContext->state = (encodingContext->masking ? WebSocketStateMaskingKey
526 : WebSocketStatePayload);
529 encodingContext->state = WebsocketStateShortLength;
531 encodingContext->state = WebsocketStateLongLength;
534 case WebsocketStateShortLength:
535 case WebsocketStateLongLength:
538 BYTE lenLength = (encodingContext->state == WebsocketStateShortLength ? 2 : 8);
539 while (encodingContext->lengthAndMaskPosition < lenLength)
542 status = BIO_read(bio, (
char*)buffer,
sizeof(buffer));
544 return (effectiveDataLen > 0 ? effectiveDataLen : status);
546 encodingContext->payloadLength =
547 (encodingContext->payloadLength) << 8 | buffer[0];
548 encodingContext->lengthAndMaskPosition += status;
550 encodingContext->state =
551 (encodingContext->masking ? WebSocketStateMaskingKey : WebSocketStatePayload);
554 case WebSocketStateMaskingKey:
557 TAG,
"Websocket Server sends data with masking key. This is against RFC 6455.");
560 case WebSocketStatePayload:
562 status = websocket_handle_payload(bio, pBuffer, size, encodingContext);
564 return (effectiveDataLen > 0 ? effectiveDataLen : status);
566 effectiveDataLen += status;
568 if ((
size_t)status == size)
569 return effectiveDataLen;