20 #include "websocket.h"
21 #include <freerdp/log.h>
24 #define TAG FREERDP_TAG("core.gateway.websocket")
/*
 * Websocket framing state shared by the read (frame parser) and write paths.
 * Only some members are visible in this excerpt.
 */
26 struct s_websocket_context
/* Opcode of the frame that started the current data message. It is recorded
 * only for non-continuation opcodes below 0x08 (i.e. not for control frames)
 * and is used to interpret subsequent continuation frames. */
33 BYTE fragmentOriginalOpcode;
/* Number of extended-length bytes read so far for the current frame header;
 * reset to 0 when the 7-bit length byte is parsed. */
34 BYTE lengthAndMaskPosition;
/* Current state of the incoming-frame parser in websocket_context_read. */
35 WEBSOCKET_STATE state;
/* Forward declaration: websocket_write_all() is defined later in this file and
 * is used by websocket_context_mask_and_send(). Declaring it static here also
 * gives the later definition (which omits the keyword) internal linkage. */
39 static int websocket_write_all(BIO* bio,
const BYTE* data,
size_t length);
/*
 * Completes and sends one websocket frame.
 *
 * sPacket must already contain the frame header (as produced by
 * websocket_context_packet_new). The whole of sDataPacket, from position 0 up
 * to its length, is appended to it and XOR-masked with the masking key:
 * four bytes at a time with 32-bit stream reads/writes, then the remainder
 * byte by byte. The finished frame is written to bio with
 * websocket_write_all().
 *
 * Ownership: sPacket is freed with Stream_Free() after the write attempt,
 * whether or not the write succeeded. NOTE(review): the body of the capacity
 * check is not visible here; confirm sPacket is also released on that path.
 * sDataPacket is read (its position changes) but not freed by this code.
 *
 * Returns FALSE on a write error or short write; the success return is not
 * visible in this excerpt (presumably TRUE).
 */
41 BOOL websocket_context_mask_and_send(BIO* bio,
wStream* sPacket,
wStream* sDataPacket,
44 const size_t len = Stream_Length(sDataPacket);
/* Mask from the start of the payload, independent of its current position. */
45 Stream_SetPosition(sDataPacket, 0);
47 if (!Stream_EnsureRemainingCapacity(sPacket, len))
52 for (; streamPos + 4 <= len; streamPos += 4)
54 const uint32_t data = Stream_Get_UINT32(sDataPacket);
55 Stream_Write_UINT32(sPacket, data ^ maskingKey);
/*
 * Mask the remaining 0-3 bytes one at a time, using the key byte at the same
 * offset modulo 4.
 * NOTE(review): this indexes the key's in-memory bytes, while the 32-bit loop
 * above (and the header, which writes the key with Stream_Write_UINT32) use
 * the stream's fixed byte order, presumably little-endian in WinPR. The two
 * agree only on little-endian hosts; verify on big-endian targets.
 */
59 for (; streamPos < len; streamPos++)
62 BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4);
63 Stream_Read_UINT8(sDataPacket, data);
64 Stream_Write_UINT8(sPacket, data ^ *partialMask);
67 Stream_SealLength(sPacket);
/* Write the complete frame (header and masked payload), then release it. */
70 const size_t size = Stream_Length(sPacket);
71 const int status = websocket_write_all(bio, Stream_Buffer(sPacket), size);
72 Stream_Free(sPacket, TRUE);
/* Any error or short write is reported as failure. */
74 if ((status < 0) || ((
size_t)status != size))
/*
 * Allocates a stream for an outgoing masked frame and writes its header:
 *  - byte 0: FIN bit plus opcode (frames are always sent unfragmented);
 *  - byte 1: MASK bit plus either the 7-bit length (presumably for len < 126;
 *    that condition is not visible here), or 126 followed by a 16-bit
 *    big-endian length (len < 0x10000), or 127 followed by a 64-bit
 *    big-endian length;
 *  - a 32-bit masking key generated with winpr_RAND().
 * The key is also stored in *pMaskingKey so the caller can mask the payload;
 * websocket_context_mask_and_send() appends that payload to this stream.
 *
 * The stream is created with Stream_New(NULL, fullLen). The size computation
 * and the allocation-failure handling are not visible in this excerpt
 * (presumably NULL is returned on failure).
 */
80 wStream* websocket_context_packet_new(
size_t len, WEBSOCKET_OPCODE opcode, UINT32* pMaskingKey)
82 WINPR_ASSERT(pMaskingKey);
89 else if (len < 0x10000)
94 wStream* sWS = Stream_New(NULL, fullLen);
/* A fresh random masking key for every frame (RFC 6455 section 5.3). */
98 UINT32 maskingKey = 0;
99 winpr_RAND(&maskingKey,
sizeof(maskingKey));
101 Stream_Write_UINT8(sWS, (UINT8)(WEBSOCKET_FIN_BIT | opcode));
103 Stream_Write_UINT8(sWS, (UINT8)len | WEBSOCKET_MASK_BIT);
104 else if (len < 0x10000)
106 Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
107 Stream_Write_UINT16_BE(sWS, (UINT16)len);
111 Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
/* NOTE(review): the high 32 bits are always written as 0 and len is
 * truncated to UINT32, so a payload of 2^32 bytes or more (possible where
 * size_t is 64-bit) would be encoded with the wrong length. No guard is
 * visible in this excerpt. */
112 Stream_Write_UINT32_BE(sWS, 0);
113 Stream_Write_UINT32_BE(sWS, (UINT32)len);
115 Stream_Write_UINT32(sWS, maskingKey);
116 *pMaskingKey = maskingKey;
/*
 * Sends the contents of sPacket (from position 0 to Stream_Length()) as a
 * single masked frame with the given opcode.
 *
 * Once a close frame has been sent on this context, no further frames are
 * sent. The early-exit body is not visible here; presumably it returns
 * FALSE. closeSent is set before the close frame is actually written, so a
 * failed close send also blocks later sends.
 *
 * sPacket must not be NULL (asserted) and is not freed by this code. The
 * header stream comes from websocket_context_packet_new() (a NULL check is
 * presumably in lines not visible here) and is released by
 * websocket_context_mask_and_send().
 */
120 BOOL websocket_context_write_wstream(websocket_context* context, BIO* bio,
wStream* sPacket,
121 WEBSOCKET_OPCODE opcode)
123 WINPR_ASSERT(context);
/* Nothing may be sent after a close frame. */
125 if (context->closeSent)
128 if (opcode == WebsocketCloseOpcode)
129 context->closeSent = TRUE;
132 WINPR_ASSERT(sPacket);
/* Build the header (with a fresh masking key), then mask and send the payload. */
134 const size_t len = Stream_Length(sPacket);
135 uint32_t maskingKey = 0;
136 wStream* sWS = websocket_context_packet_new(len, opcode, &maskingKey);
140 return websocket_context_mask_and_send(bio, sWS, sPacket, maskingKey);
/*
 * Writes all `length` bytes of data to bio, looping until everything has been
 * written. Callers treat a return value equal to `length` as success and
 * anything else (including negative values) as failure.
 *
 * Lengths above INT32_MAX are rejected up front because BIO_write() takes an
 * int count; the rejection value is not visible here. When BIO_write() makes
 * no progress (presumably a result <= 0), the function consults
 * BIO_should_retry(). If the BIO is write-blocked, it waits up to 100 ms
 * with BIO_wait_write() before trying again. The handling of non-retryable
 * errors and of the read-blocked case is not visible in this excerpt.
 *
 * Internal linkage comes from the earlier static forward declaration.
 */
143 int websocket_write_all(BIO* bio,
const BYTE* data,
size_t length)
149 if (length > INT32_MAX)
152 while (offset < length)
/* The int cast below cannot overflow: diff <= length <= INT32_MAX once the
 * check above has rejected larger inputs. */
155 const size_t diff = length - offset;
156 int status = BIO_write(bio, &data[offset], (
int)diff);
159 offset += (size_t)status;
162 if (!BIO_should_retry(bio))
165 if (BIO_write_blocked(bio))
167 const long rstatus = BIO_wait_write(bio, 100);
171 else if (BIO_read_blocked(bio))
/*
 * Sends isize bytes from buf as a single frame with the given opcode.
 *
 * The caller's buffer is wrapped in a read-only static stream (no copy) and
 * passed to websocket_context_write_wstream(). Validation of isize and the
 * returned value are not visible in this excerpt. Presumably a negative
 * isize is rejected before the cast to size_t, and the byte count is
 * returned on success.
 */
181 int websocket_context_write(websocket_context* context, BIO* bio,
const BYTE* buf,
int isize,
182 WEBSOCKET_OPCODE opcode)
191 wStream* s = Stream_StaticConstInit(&sbuffer, buf, (
size_t)isize);
192 if (!websocket_context_write_wstream(context, bio, s, opcode))
/*
 * Reads payload bytes of the current frame directly into pBuffer.
 *
 * At most min(payloadLength, size) bytes are requested. Requests larger than
 * INT32_MAX are rejected (handling not visible here) because BIO_read()
 * takes an int count. payloadLength is reduced by the number of bytes read.
 * Once it reaches zero, or if it already was zero on entry, the parser state
 * returns to WebsocketStateOpcodeAndFin so the next read starts a new frame
 * header. A BIO_read() result <= 0, or one larger than the remaining
 * payload, is treated specially (the handling is not visible in this
 * excerpt).
 */
197 static int websocket_read_data(BIO* bio, BYTE* pBuffer,
size_t size,
198 websocket_context* encodingContext)
203 WINPR_ASSERT(pBuffer);
204 WINPR_ASSERT(encodingContext);
/* Empty payload: nothing to read, go straight back to the header state. */
206 if (encodingContext->payloadLength == 0)
208 encodingContext->state = WebsocketStateOpcodeAndFin;
213 (encodingContext->payloadLength < size ? encodingContext->payloadLength : size);
/* BIO_read() takes an int length. */
214 if (rlen > INT32_MAX)
218 status = BIO_read(bio, pBuffer, (
int)rlen);
219 if ((status <= 0) || ((
size_t)status > encodingContext->payloadLength))
222 encodingContext->payloadLength -= (size_t)status;
/* Frame fully consumed: the next byte read is a new frame header. */
224 if (encodingContext->payloadLength == 0)
225 encodingContext->state = WebsocketStateOpcodeAndFin;
/*
 * Appends payload bytes of the current frame to
 * encodingContext->responseStreamBuffer, growing the buffer so that the
 * whole remaining payload fits. It is used for control frames (ping, pong,
 * close) and for unsupported opcodes whose payload is dropped. The buffer
 * position records how many payload bytes have been collected so far.
 *
 * NOTE(review): the error logged when Stream_EnsureRemainingCapacity() fails
 * describes a size mismatch rather than an allocation failure, and misspells
 * "payloadLength" ("paylaodLangth"). It is a runtime string, so it is left
 * unchanged here.
 */
230 static int websocket_read_wstream(BIO* bio, websocket_context* encodingContext)
233 WINPR_ASSERT(encodingContext);
235 wStream* s = encodingContext->responseStreamBuffer;
238 if (encodingContext->payloadLength == 0)
240 encodingContext->state = WebsocketStateOpcodeAndFin;
/* Make room for the whole remaining payload of this frame. */
244 if (!Stream_EnsureRemainingCapacity(s, encodingContext->payloadLength))
247 "wStream::capacity [%" PRIuz
"] != encodingContext::paylaodLangth [%" PRIuz
"]",
248 Stream_GetRemainingCapacity(s), encodingContext->payloadLength);
/* Read into the stream's free space, then advance its position by the number
 * of bytes actually read. */
252 const int status = websocket_read_data(bio, Stream_Pointer(s), Stream_GetRemainingCapacity(s),
257 if (!Stream_SafeSeek(s, (
size_t)status))
/*
 * Sends s back to the peer as a close frame. Through
 * websocket_context_write_wstream() this also marks the context as closed,
 * after which no further frames are written.
 */
263 static BOOL websocket_reply_close(BIO* bio, websocket_context* context,
wStream* s)
267 return websocket_context_write_wstream(context, bio, s, WebsocketCloseOpcode);
/*
 * Answers a received ping. If payload was collected into s (its position is
 * non-zero), s is echoed back as a pong frame.
 *
 * NOTE(review): websocket_context_write_wstream() sends Stream_Length(s)
 * bytes, not Stream_GetPosition(s). Confirm the length is sealed to the
 * bytes actually read (not visible here); otherwise stale buffer content
 * would be echoed.
 *
 * NOTE(review): an empty ping is answered with a close frame instead, and
 * the close path is given a NULL stream, although
 * websocket_context_write_wstream() asserts that its stream is non-NULL.
 * RFC 6455 section 5.5.2 requires a Pong in reply to a Ping, so this looks
 * like a bug unless lines not visible here handle it. Confirm.
 */
270 static BOOL websocket_reply_pong(BIO* bio, websocket_context* context,
wStream* s)
275 if (Stream_GetPosition(s) != 0)
276 return websocket_context_write_wstream(context, bio, s, WebsocketPongOpcode);
278 return websocket_reply_close(bio, context, NULL);
/*
 * Processes payload bytes of the current frame according to its effective
 * opcode:
 *  - binary: bytes go straight into the caller's buffer (pBuffer/size);
 *  - ping:   payload is collected; when complete it is echoed as a pong;
 *  - pong:   payload is collected and discarded;
 *  - close:  payload is collected; when complete it is echoed in a close
 *            frame and the context is marked closed;
 *  - other:  a warning is logged and the payload is collected and dropped.
 * Among the visible cases, only binary frames write into pBuffer.
 *
 * The return paths are not visible in this excerpt. websocket_context_read()
 * adds the result to its count of delivered bytes, so this function
 * presumably returns 0 for frames that deliver nothing to the caller.
 */
281 static int websocket_handle_payload(BIO* bio, BYTE* pBuffer,
size_t size,
282 websocket_context* encodingContext)
287 WINPR_ASSERT(pBuffer);
288 WINPR_ASSERT(encodingContext);
/* Continuation frames are interpreted with the opcode of the data frame that
 * started the message. */
290 const BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode
291 ? encodingContext->fragmentOriginalOpcode & 0xf
292 : encodingContext->opcode & 0xf);
294 switch (effectiveOpcode)
296 case WebsocketBinaryOpcode:
/* Data is read straight into the caller's buffer. */
298 status = websocket_read_data(bio, pBuffer, size, encodingContext);
304 case WebsocketPingOpcode:
/* Collect the ping payload. Once complete, echo it back and rewind the
 * buffer. The result of the pong reply is ignored. */
306 status = websocket_read_wstream(bio, encodingContext);
310 if (encodingContext->payloadLength == 0)
312 websocket_reply_pong(bio, encodingContext, encodingContext->responseStreamBuffer);
313 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
317 case WebsocketPongOpcode:
/* Pong payloads are read and discarded. */
319 status = websocket_read_wstream(bio, encodingContext);
323 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
326 case WebsocketCloseOpcode:
/* Collect the close payload. Once complete, echo it in a close frame, mark
 * the context closed and rewind the buffer. Setting closeSent here appears
 * redundant with websocket_context_write_wstream(), which already sets it. */
328 status = websocket_read_wstream(bio, encodingContext);
332 if (encodingContext->payloadLength == 0)
334 websocket_reply_close(bio, encodingContext, encodingContext->responseStreamBuffer);
335 encodingContext->closeSent = TRUE;
336 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
/* Any other opcode (including text, unless a case not visible here handles
 * it) is read and dropped. */
341 WLog_WARN(TAG,
"Unimplemented websocket opcode %" PRIx8
". Dropping", effectiveOpcode);
343 status = websocket_read_wstream(bio, encodingContext);
346 Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
/*
 * Reads up to `size` bytes of websocket payload into pBuffer.
 *
 * The frame header is parsed incrementally by a state machine stored in
 * encodingContext, with one BIO_read() per header byte, so a call can stop
 * partway through a header and resume on the next call:
 *   OpcodeAndFin -> LengthAndMasking
 *   -> (ShortLength | LongLength, only for extended lengths)
 *   -> (MaskingKey, only if the MASK bit is set)
 *   -> Payload -> back to OpcodeAndFin once the payload is consumed.
 *
 * Returns the number of payload bytes delivered into pBuffer. When a read
 * fails or returns no data (the guarding conditions are not visible here),
 * it returns the bytes already delivered if there are any, and otherwise
 * the status of the failing read.
 */
354 int websocket_context_read(websocket_context* encodingContext, BIO* bio, BYTE* pBuffer,
size_t size)
357 int effectiveDataLen = 0;
360 WINPR_ASSERT(pBuffer);
361 WINPR_ASSERT(encodingContext);
365 switch (encodingContext->state)
367 case WebsocketStateOpcodeAndFin:
/* First header byte: FIN/RSV bits and opcode. */
369 BYTE buffer[1] = { 0 };
372 status = BIO_read(bio, (
char*)buffer,
sizeof(buffer));
374 return (effectiveDataLen > 0 ? effectiveDataLen : status);
376 encodingContext->opcode = buffer[0];
/* Remember the opcode that starts a data message (neither a continuation
 * nor a control frame >= 0x08) so continuation frames can be interpreted. */
377 if (((encodingContext->opcode & 0xf) != WebsocketContinuationOpcode) &&
378 (encodingContext->opcode & 0xf) < 0x08)
379 encodingContext->fragmentOriginalOpcode = encodingContext->opcode;
380 encodingContext->state = WebsocketStateLengthAndMasking;
383 case WebsocketStateLengthAndMasking:
/* Second header byte: MASK bit and 7-bit payload length. */
385 BYTE buffer[1] = { 0 };
388 status = BIO_read(bio, (
char*)buffer,
sizeof(buffer));
390 return (effectiveDataLen > 0 ? effectiveDataLen : status);
392 encodingContext->masking = ((buffer[0] & WEBSOCKET_MASK_BIT) == WEBSOCKET_MASK_BIT);
393 encodingContext->lengthAndMaskPosition = 0;
394 encodingContext->payloadLength = 0;
395 const BYTE len = buffer[0] & 0x7f;
/* A short length is final. The values 126/127 presumably select the 16- or
 * 64-bit extended length (RFC 6455 section 5.2); the branch conditions are
 * not visible here. */
398 encodingContext->payloadLength = len;
399 encodingContext->state = (encodingContext->masking ? WebSocketStateMaskingKey
400 : WebSocketStatePayload);
403 encodingContext->state = WebsocketStateShortLength;
405 encodingContext->state = WebsocketStateLongLength;
/* Extended length: 2 (ShortLength) or 8 (LongLength) bytes, most significant
 * byte first. lengthAndMaskPosition tracks progress so a partial read can
 * resume on the next call.
 * NOTE(review): there is no visible check that the most significant bit of
 * the 64-bit length is 0 (RFC 6455 section 5.2), or that the value fits in
 * payloadLength. */
408 case WebsocketStateShortLength:
409 case WebsocketStateLongLength:
411 BYTE buffer[1] = { 0 };
412 const BYTE lenLength =
413 (encodingContext->state == WebsocketStateShortLength ? 2 : 8);
414 while (encodingContext->lengthAndMaskPosition < lenLength)
417 status = BIO_read(bio, (
char*)buffer,
sizeof(buffer));
419 return (effectiveDataLen > 0 ? effectiveDataLen : status);
421 encodingContext->payloadLength =
422 (encodingContext->payloadLength) << 8 | buffer[0];
423 encodingContext->lengthAndMaskPosition += status;
425 encodingContext->state =
426 (encodingContext->masking ? WebSocketStateMaskingKey : WebSocketStatePayload);
/* RFC 6455 section 5.1: a server must not mask frames it sends to a client,
 * so a masked frame is reported as an error (subsequent handling is not
 * visible here). */
429 case WebSocketStateMaskingKey:
432 TAG,
"Websocket Server sends data with masking key. This is against RFC 6455.");
/* Hand the payload to websocket_handle_payload() and add its result to the
 * delivered count. When the caller's buffer is full, the total is returned;
 * otherwise the remaining space shrinks and parsing continues (the loop
 * structure is not visible in this excerpt). */
435 case WebSocketStatePayload:
437 status = websocket_handle_payload(bio, pBuffer, size, encodingContext);
439 return (effectiveDataLen > 0 ? effectiveDataLen : status);
441 effectiveDataLen += status;
443 if ((
size_t)status >= size)
444 return effectiveDataLen;
446 size -= WINPR_ASSERTING_INT_CAST(
size_t, status);
/*
 * Creates a zero-initialised websocket context with a 1024-byte response
 * buffer and the parser reset to its initial state. On failure, the
 * partially constructed context is released with websocket_context_free().
 * Presumably NULL is returned then; the return statements are not visible
 * in this excerpt.
 */
456 websocket_context* websocket_context_new(
void)
458 websocket_context* context = calloc(1,
sizeof(websocket_context));
/* Initial capacity only: websocket_read_wstream() grows it on demand. */
462 context->responseStreamBuffer = Stream_New(NULL, 1024);
463 if (!context->responseStreamBuffer)
466 if (!websocket_context_reset(context))
471 websocket_context_free(context);
/*
 * Releases the context's response buffer. Freeing the context itself, and
 * any NULL check on the argument, happen in lines not visible in this
 * excerpt.
 */
475 void websocket_context_free(websocket_context* context)
480 Stream_Free(context->responseStreamBuffer, TRUE);
/*
 * Returns the frame parser to the start of a frame header and rewinds the
 * response buffer. Returns the result of Stream_SetPosition().
 * NOTE(review): the visible lines leave payloadLength, closeSent and the
 * fragment opcode untouched. Confirm that this is intended for a reset.
 */
484 BOOL websocket_context_reset(websocket_context* context)
486 WINPR_ASSERT(context);
488 context->state = WebsocketStateOpcodeAndFin;
489 return Stream_SetPosition(context->responseStreamBuffer, 0);