21#include <freerdp/log.h>
24#define TAG FREERDP_TAG("core.gateway.websocket")
/* Per-connection websocket framing and parser state. Only some members are
 * shown in this excerpt; the functions below also use opcode, masking,
 * payloadLength, closeSent and responseStreamBuffer, presumably declared in
 * lines not shown here. */
26struct s_websocket_context
/* Opcode byte of the first frame of a fragmented data message; continuation
 * frames are interpreted with this opcode (see websocket_handle_payload). */
33 BYTE fragmentOriginalOpcode;
/* Number of extended-length bytes consumed so far for the current frame. */
34 BYTE lengthAndMaskPosition;
/* Current step of the frame-parsing state machine in websocket_context_read. */
35 WEBSOCKET_STATE state;
/* Forward declaration of the helper that writes an entire buffer to a BIO,
 * looping over partial and retryable writes. Because this declaration is
 * `static`, the later definition (written without `static`) still has
 * internal linkage. */
39static int websocket_write_all(BIO* bio,
const BYTE* data,
size_t length);
/* Appends the payload of sDataPacket, masked with maskingKey, to sPacket and
 * sends the whole frame with websocket_write_all. The frame header must
 * already be in sPacket (it is written by websocket_context_packet_new).
 * Each payload byte is XORed with the matching byte of the 32-bit key
 * (RFC 6455, section 5.3).
 * Ownership: sPacket is freed here whether or not the write succeeds.
 * Result: a negative status or a short write is treated as failure
 * (presumably returning FALSE; that return is in lines not shown). */
41BOOL websocket_context_mask_and_send(BIO* bio,
wStream* sPacket,
wStream* sDataPacket,
44 const size_t len = Stream_Length(sDataPacket);
45 Stream_SetPosition(sDataPacket, 0);
/* Make room for the whole payload after the already-written header. */
47 if (!Stream_EnsureRemainingCapacity(sPacket, len))
/* Mask four payload bytes at a time using the Stream_* UINT32 accessors. */
52 for (; streamPos + 4 <= len; streamPos += 4)
54 const uint32_t data = Stream_Get_UINT32(sDataPacket);
55 Stream_Write_UINT32(sPacket, data ^ maskingKey);
/* Mask the remaining 0-3 tail bytes one at a time.
 * NOTE(review): partialMask takes the key's bytes in host memory order.
 * The 4-byte loop above instead uses the Stream_* accessors, which are
 * presumably fixed little-endian, like the key's serialization in
 * websocket_context_packet_new. On a big-endian host the tail would be masked
 * with the wrong key bytes. A portable form is
 * (BYTE)(maskingKey >> (8 * (streamPos % 4))) — confirm. */
59 for (; streamPos < len; streamPos++)
62 BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4);
63 Stream_Read_UINT8(sDataPacket, data);
64 Stream_Write_UINT8(sPacket, data ^ *partialMask);
/* Fix the stream length at header + masked payload before sending. */
67 Stream_SealLength(sPacket);
70 const size_t size = Stream_Length(sPacket);
71 const int status = websocket_write_all(bio, Stream_Buffer(sPacket), size);
72 Stream_Free(sPacket, TRUE);
/* A negative status or a short write counts as failure. */
74 if ((status < 0) || ((
size_t)status != size))
/* Allocates a stream for one websocket frame with the FIN bit set and writes
 * its header. The stream is sized for header + payload, but the payload
 * itself is written later by websocket_context_mask_and_send.
 *
 * Header layout:
 *   - the opcode byte;
 *   - the payload length, in the 7-bit, 126 + 16-bit or 127 + 64-bit form,
 *     always with the MASK bit set;
 *   - a freshly generated random 32-bit masking key.
 *
 * The key is also returned through pMaskingKey (must be non-NULL). Some
 * branches (computing fullLen and the NULL check of sWS) are in lines not
 * shown here. */
80wStream* websocket_context_packet_new(
size_t len, WEBSOCKET_OPCODE opcode, UINT32* pMaskingKey)
82 WINPR_ASSERT(pMaskingKey);
89 else if (len < 0x10000)
94 wStream* sWS = Stream_New(NULL, fullLen);
/* Every frame sent through here is masked with a new random key. */
98 UINT32 maskingKey = 0;
99 winpr_RAND(&maskingKey,
sizeof(maskingKey));
/* First byte: FIN + opcode. Second byte: MASK bit + 7-bit length, where
 * 126/127 announce a big-endian 16-bit/64-bit extended length. */
101 Stream_Write_UINT8(sWS, (UINT8)(WEBSOCKET_FIN_BIT | opcode));
103 Stream_Write_UINT8(sWS, (UINT8)len | WEBSOCKET_MASK_BIT);
104 else if (len < 0x10000)
106 Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
107 Stream_Write_UINT16_BE(sWS, (UINT16)len);
/* 64-bit extended length: the high 32 bits are always written as 0.
 * NOTE(review): a len >= 2^32 would be silently truncated here; confirm
 * that callers bound len (websocket_context_write takes an int size). */
111 Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
112 Stream_Write_UINT32_BE(sWS, 0);
113 Stream_Write_UINT32_BE(sWS, (UINT32)len);
/* The key uses the non-_BE writer; websocket_context_mask_and_send masks
 * with the matching non-_BE UINT32 accessors. */
115 Stream_Write_UINT32(sWS, maskingKey);
116 *pMaskingKey = maskingKey;
/* Sends the full contents of sPacket (Stream_Length bytes) as one masked
 * websocket frame with the given opcode. It builds the header with
 * websocket_context_packet_new, then masks and sends the payload with
 * websocket_context_mask_and_send. */
120BOOL websocket_context_write_wstream(websocket_context* context, BIO* bio,
wStream* sPacket,
121 WEBSOCKET_OPCODE opcode)
123 WINPR_ASSERT(context);
/* Once a Close frame has been sent, nothing more is sent (the early return
 * is in a line not shown here). closeSent is set before the Close frame is
 * actually written, so even a failed Close send blocks later writes. */
125 if (context->closeSent)
128 if (opcode == WebsocketCloseOpcode)
129 context->closeSent = TRUE;
/* NOTE(review): sPacket must be non-NULL. websocket_reply_pong passes NULL
 * through websocket_reply_close for an empty ping, which trips this
 * assertion — verify. */
132 WINPR_ASSERT(sPacket);
134 const size_t len = Stream_Length(sPacket);
135 uint32_t maskingKey = 0;
136 wStream* sWS = websocket_context_packet_new(len, opcode, &maskingKey);
140 return websocket_context_mask_and_send(bio, sWS, sPacket, maskingKey);
/* Writes all `length` bytes of `data` to `bio`, looping over partial writes.
 * - Lengths above INT32_MAX are checked up front, because BIO_write takes an
 *   int count (presumably rejected; the branch body is not shown).
 * - When a write does not make progress and BIO_should_retry() reports a
 *   retryable condition, it either waits for the BIO to become writable
 *   (BIO_wait_write with a timeout argument of 100, presumably milliseconds)
 *   or handles a read-blocked BIO (handling not shown).
 * - A non-retryable failure ends the loop.
 * The return values are in lines not shown here. Internal linkage comes from
 * the earlier `static` prototype. */
143int websocket_write_all(BIO* bio,
const BYTE* data,
size_t length)
149 if (length > INT32_MAX)
152 while (offset < length)
155 const size_t diff = length - offset;
156 int status = BIO_write(bio, &data[offset], (
int)diff);
159 offset += (size_t)status;
/* Any failure the BIO does not flag as retryable is fatal. */
162 if (!BIO_should_retry(bio))
165 if (BIO_write_blocked(bio))
167 const long rstatus = BIO_wait_write(bio, 100);
171 else if (BIO_read_blocked(bio))
/* Sends isize bytes from buf as a single websocket frame with the given
 * opcode. buf is wrapped, via Stream_StaticConstInit, in the stream object
 * sbuffer (declared in a line not shown, presumably a local) and handed to
 * websocket_context_write_wstream.
 * NOTE(review): isize is cast to size_t, so a negative isize must be rejected
 * beforehand — confirm in the lines not shown. The success/failure return
 * values are also not shown. */
181int websocket_context_write(websocket_context* context, BIO* bio,
const BYTE* buf,
int isize,
182 WEBSOCKET_OPCODE opcode)
191 wStream* s = Stream_StaticConstInit(&sbuffer, buf, (
size_t)isize);
192 if (!websocket_context_write_wstream(context, bio, s, opcode))
/* Reads up to `size` bytes of the current frame's remaining payload into
 * pBuffer. It never reads past the end of the frame, and subtracts the bytes
 * read from payloadLength. When the payload is exhausted, the state machine
 * goes back to expecting the next frame's opcode/FIN byte.
 * Returns the BIO_read status on success; error-path returns are in lines
 * not shown here. */
197static int websocket_read_data(BIO* bio, BYTE* pBuffer,
size_t size,
198 websocket_context* encodingContext)
203 WINPR_ASSERT(pBuffer);
204 WINPR_ASSERT(encodingContext);
/* Empty payload: nothing to read; resume at the next frame header. */
206 if (encodingContext->payloadLength == 0)
208 encodingContext->state = WebsocketStateOpcodeAndFin;
213 (encodingContext->payloadLength < size ? encodingContext->payloadLength : size);
/* rlen is the smaller of the remaining payload and `size`, so bytes of the
 * next frame are never consumed. BIO_read takes an int count, hence the
 * INT32_MAX guard. */
214 if (rlen > INT32_MAX)
218 status = BIO_read(bio, pBuffer, (
int)rlen);
/* A non-positive status, or reading more than the payload holds, is an error. */
219 if ((status <= 0) || ((
size_t)status > encodingContext->payloadLength))
222 encodingContext->payloadLength -= (size_t)status;
/* Payload fully consumed: the next byte read is the next frame's opcode byte. */
224 if (encodingContext->payloadLength == 0)
225 encodingContext->state = WebsocketStateOpcodeAndFin;
/* Reads the current frame's remaining payload into the context's
 * responseStreamBuffer at its current position, then advances the position
 * by the bytes read. A payload split across several calls therefore
 * accumulates in the buffer.
 * Used for non-binary opcodes (ping, pong, close, and unhandled opcodes). */
230static int websocket_read_wstream(BIO* bio, websocket_context* encodingContext)
233 WINPR_ASSERT(encodingContext);
235 wStream* s = encodingContext->responseStreamBuffer;
/* Nothing left to read for this frame: resume at the next frame header. */
238 if (encodingContext->payloadLength == 0)
240 encodingContext->state = WebsocketStateOpcodeAndFin;
/* Grow the buffer so the entire remaining payload fits; on failure an error
 * is logged. The read below passes the remaining capacity as `size`, and
 * websocket_read_data clamps it to payloadLength.
 * NOTE(review): the log text says "!=" and misspells "paylaodLangth",
 * although the failure is a capacity growth failure — consider rewording. */
244 if (!Stream_EnsureRemainingCapacity(s, encodingContext->payloadLength))
247 "wStream::capacity [%" PRIuz
"] != encodingContext::paylaodLangth [%" PRIuz
"]",
248 Stream_GetRemainingCapacity(s), encodingContext->payloadLength);
252 const int status = websocket_read_data(bio, Stream_Pointer(s), Stream_GetRemainingCapacity(s),
257 if (!Stream_SafeSeek(s, (
size_t)status))
/* Sends a Close frame whose payload is the contents of s. When called from
 * websocket_handle_payload, s holds the peer's Close payload, which is
 * echoed back. Marks the context as closed via
 * websocket_context_write_wstream.
 * NOTE(review): that function asserts s is non-NULL, yet
 * websocket_reply_pong passes NULL here — verify. */
263static BOOL websocket_reply_close(BIO* bio, websocket_context* context,
wStream* s)
267 return websocket_context_write_wstream(context, bio, s, WebsocketCloseOpcode);
/* Answers a received Ping. If s holds ping payload bytes (non-zero stream
 * position), they are sent back in a Pong frame.
 * NOTE(review): websocket_context_write_wstream sizes the frame with
 * Stream_Length(s), while websocket_read_wstream only advances the position.
 * Confirm the length matches the bytes actually read. */
270static BOOL websocket_reply_pong(BIO* bio, websocket_context* context,
wStream* s)
275 if (Stream_GetPosition(s) != 0)
276 return websocket_context_write_wstream(context, bio, s, WebsocketPongOpcode);
/* NOTE(review): an empty Ping is answered with a Close frame rather than an
 * empty Pong, which RFC 6455 section 5.5.3 calls for. It also passes a NULL
 * stream, which trips WINPR_ASSERT(sPacket) in
 * websocket_context_write_wstream — verify intent. */
278 return websocket_reply_close(bio, context, NULL);
/* Consumes the current frame's payload according to its opcode.
 * - Binary payload is read directly into pBuffer (up to `size` bytes).
 * - Ping, pong, close and unhandled opcodes are accumulated in
 *   responseStreamBuffer. Once the payload is complete, a ping is answered
 *   with a pong and a close with a close; the buffer is then rewound.
 * The per-case return statements are in lines not shown here. */
281static int websocket_handle_payload(BIO* bio, BYTE* pBuffer,
size_t size,
282 websocket_context* encodingContext)
287 WINPR_ASSERT(pBuffer);
288 WINPR_ASSERT(encodingContext);
/* The opcode is the low nibble. A continuation frame is handled with the
 * opcode of the first fragment of its message. */
290 const BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode
291 ? encodingContext->fragmentOriginalOpcode & 0xf
292 : encodingContext->opcode & 0xf);
294 switch (effectiveOpcode)
296 case WebsocketBinaryOpcode:
298 status = websocket_read_data(bio, pBuffer, size, encodingContext);
304 case WebsocketPingOpcode:
306 status = websocket_read_wstream(bio, encodingContext);
/* Reply only once the complete ping payload has been buffered.
 * NOTE(review): the result of websocket_reply_pong is ignored. */
310 if (encodingContext->payloadLength == 0)
312 websocket_reply_pong(bio, encodingContext, encodingContext->responseStreamBuffer);
313 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
/* Pong payloads are read and then discarded by rewinding the buffer. */
317 case WebsocketPongOpcode:
319 status = websocket_read_wstream(bio, encodingContext);
323 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
/* Once the complete close payload is buffered, echo it back in a Close
 * frame. websocket_context_write_wstream already sets closeSent for Close
 * frames, so the assignment below appears redundant. */
326 case WebsocketCloseOpcode:
328 status = websocket_read_wstream(bio, encodingContext);
332 if (encodingContext->payloadLength == 0)
334 websocket_reply_close(bio, encodingContext, encodingContext->responseStreamBuffer);
335 encodingContext->closeSent = TRUE;
336 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
/* Presumably the default case: any other opcode is logged, read and dropped. */
341 WLog_WARN(TAG,
"Unimplemented websocket opcode %" PRIx8
". Dropping", effectiveOpcode);
343 status = websocket_read_wstream(bio, encodingContext);
346 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
/* Parses websocket frames from bio and delivers binary payload into pBuffer
 * (at most `size` bytes). It is a resumable state machine driven by
 * encodingContext->state. The states are:
 *   1. opcode/FIN byte;
 *   2. length/mask byte;
 *   3. optional 2-byte (short) or 8-byte (long) extended length;
 *   4. masking key — logged as an error, since servers must not mask;
 *   5. payload.
 * Header bytes are read one at a time, and the state lives in the context,
 * so a short read can return and resume later from the same point.
 * Return value: on a read failure it returns the number of payload bytes
 * already delivered if any, otherwise presumably the failing status (that
 * branch is in lines not shown). */
354int websocket_context_read(websocket_context* encodingContext, BIO* bio, BYTE* pBuffer,
size_t size)
357 size_t effectiveDataLen = 0;
360 WINPR_ASSERT(pBuffer);
361 WINPR_ASSERT(encodingContext);
365 switch (encodingContext->state)
367 case WebsocketStateOpcodeAndFin:
369 BYTE buffer[1] = { 0 };
372 status = BIO_read(bio, (
char*)buffer,
sizeof(buffer));
374 return (effectiveDataLen > 0 ? WINPR_ASSERTING_INT_CAST(
int, effectiveDataLen)
377 encodingContext->opcode = buffer[0];
/* Record the opcode of the first frame of a data message. Continuation
 * frames and control opcodes (>= 0x08) leave it untouched, so later
 * continuation frames still resolve to the original opcode. */
378 if (((encodingContext->opcode & 0xf) != WebsocketContinuationOpcode) &&
379 (encodingContext->opcode & 0xf) < 0x08)
380 encodingContext->fragmentOriginalOpcode = encodingContext->opcode;
381 encodingContext->state = WebsocketStateLengthAndMasking;
384 case WebsocketStateLengthAndMasking:
386 BYTE buffer[1] = { 0 };
389 status = BIO_read(bio, (
char*)buffer,
sizeof(buffer));
391 return (effectiveDataLen > 0 ? WINPR_ASSERTING_INT_CAST(
int, effectiveDataLen)
394 encodingContext->masking = ((buffer[0] & WEBSOCKET_MASK_BIT) == WEBSOCKET_MASK_BIT);
395 encodingContext->lengthAndMaskPosition = 0;
396 encodingContext->payloadLength = 0;
/* The low 7 bits hold either the payload length itself, or a marker that a
 * 2-byte (short) or 8-byte (long) extended length follows. */
397 const BYTE len = buffer[0] & 0x7f;
400 encodingContext->payloadLength = len;
401 encodingContext->state = (encodingContext->masking ? WebSocketStateMaskingKey
402 : WebSocketStatePayload);
405 encodingContext->state = WebsocketStateShortLength;
407 encodingContext->state = WebsocketStateLongLength;
410 case WebsocketStateShortLength:
411 case WebsocketStateLongLength:
413 BYTE buffer[1] = { 0 };
414 const BYTE lenLength =
415 (encodingContext->state == WebsocketStateShortLength ? 2 : 8);
/* Accumulate the big-endian extended length one byte at a time.
 * lengthAndMaskPosition persists across calls, so a partial read resumes.
 * NOTE(review): payloadLength is a size_t; with a 32-bit size_t, the high
 * bytes of an 8-byte length are shifted out silently — confirm whether a
 * bound check is needed. */
416 while (encodingContext->lengthAndMaskPosition < lenLength)
419 status = BIO_read(bio, (
char*)buffer,
sizeof(buffer));
421 return (effectiveDataLen > 0
422 ? WINPR_ASSERTING_INT_CAST(
int, effectiveDataLen)
424 if (status > UINT8_MAX)
426 encodingContext->payloadLength =
427 (encodingContext->payloadLength) << 8 | buffer[0];
428 encodingContext->lengthAndMaskPosition +=
429 WINPR_ASSERTING_INT_CAST(BYTE, status);
/* Extended length complete: continue with the masking key or the payload. */
431 encodingContext->state =
432 (encodingContext->masking ? WebSocketStateMaskingKey : WebSocketStatePayload);
435 case WebSocketStateMaskingKey:
438 TAG,
"Websocket Server sends data with masking key. This is against RFC 6455.");
441 case WebSocketStatePayload:
443 status = websocket_handle_payload(bio, pBuffer, size, encodingContext);
445 return (effectiveDataLen > 0 ? WINPR_ASSERTING_INT_CAST(
int, effectiveDataLen)
448 effectiveDataLen += WINPR_ASSERTING_INT_CAST(
size_t, status);
/* Stop once the caller's buffer is full. Otherwise shrink the remaining room
 * and keep parsing the next frame. pBuffer is presumably advanced by status
 * in a line not shown — confirm, or the next payload would overwrite this
 * one. */
450 if (WINPR_ASSERTING_INT_CAST(
size_t, status) >= size)
451 return WINPR_ASSERTING_INT_CAST(
int, effectiveDataLen);
453 size -= WINPR_ASSERTING_INT_CAST(
size_t, status);
/* Allocates a zero-initialized websocket context with a 1024-byte response
 * stream, then puts it into the initial parsing state via
 * websocket_context_reset. The visible failure paths end in
 * websocket_context_free, and presumably return NULL.
 * NOTE(review): the NULL check of calloc's result is in lines not shown —
 * confirm it exists. */
463websocket_context* websocket_context_new(
void)
465 websocket_context* context = calloc(1,
sizeof(websocket_context));
469 context->responseStreamBuffer = Stream_New(NULL, 1024);
470 if (!context->responseStreamBuffer)
473 if (!websocket_context_reset(context))
478 websocket_context_free(context);
/* Releases the context's response stream. The NULL check and the release of
 * the context itself are presumably in lines not shown here. */
482void websocket_context_free(websocket_context* context)
487 Stream_Free(context->responseStreamBuffer, TRUE);
/* Puts the parser back at the start of a frame (expecting the opcode/FIN
 * byte) and rewinds the response stream to position 0. context must be
 * non-NULL. Returns the result of Stream_SetPosition. */
491BOOL websocket_context_reset(websocket_context* context)
493 WINPR_ASSERT(context);
495 context->state = WebsocketStateOpcodeAndFin;
496 return Stream_SetPosition(context->responseStreamBuffer, 0);