FreeRDP
websocket.c
1 
20 #include "websocket.h"
21 #include <freerdp/log.h>
22 #include "../tcp.h"
23 
24 #define TAG FREERDP_TAG("core.gateway.websocket")
25 
26 struct s_websocket_context
27 {
28  size_t payloadLength;
29  uint32_t maskingKey;
30  BOOL masking;
31  BOOL closeSent;
32  BYTE opcode;
33  BYTE fragmentOriginalOpcode;
34  BYTE lengthAndMaskPosition;
35  WEBSOCKET_STATE state;
36  wStream* responseStreamBuffer;
37 };
38 
39 static int websocket_write_all(BIO* bio, const BYTE* data, size_t length);
40 
41 BOOL websocket_context_mask_and_send(BIO* bio, wStream* sPacket, wStream* sDataPacket,
42  UINT32 maskingKey)
43 {
44  const size_t len = Stream_Length(sDataPacket);
45  Stream_SetPosition(sDataPacket, 0);
46 
47  if (!Stream_EnsureRemainingCapacity(sPacket, len))
48  return FALSE;
49 
50  /* mask as much as possible with 32bit access */
51  size_t streamPos = 0;
52  for (; streamPos + 4 <= len; streamPos += 4)
53  {
54  const uint32_t data = Stream_Get_UINT32(sDataPacket);
55  Stream_Write_UINT32(sPacket, data ^ maskingKey);
56  }
57 
58  /* mask the rest byte by byte */
59  for (; streamPos < len; streamPos++)
60  {
61  BYTE data = 0;
62  BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4);
63  Stream_Read_UINT8(sDataPacket, data);
64  Stream_Write_UINT8(sPacket, data ^ *partialMask);
65  }
66 
67  Stream_SealLength(sPacket);
68 
69  ERR_clear_error();
70  const size_t size = Stream_Length(sPacket);
71  const int status = websocket_write_all(bio, Stream_Buffer(sPacket), size);
72  Stream_Free(sPacket, TRUE);
73 
74  if ((status < 0) || ((size_t)status != size))
75  return FALSE;
76 
77  return TRUE;
78 }
79 
80 wStream* websocket_context_packet_new(size_t len, WEBSOCKET_OPCODE opcode, UINT32* pMaskingKey)
81 {
82  WINPR_ASSERT(pMaskingKey);
83  if (len > INT_MAX)
84  return NULL;
85 
86  size_t fullLen = 0;
87  if (len < 126)
88  fullLen = len + 6; /* 2 byte "mini header" + 4 byte masking key */
89  else if (len < 0x10000)
90  fullLen = len + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */
91  else
92  fullLen = len + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */
93 
94  wStream* sWS = Stream_New(NULL, fullLen);
95  if (!sWS)
96  return NULL;
97 
98  UINT32 maskingKey = 0;
99  winpr_RAND(&maskingKey, sizeof(maskingKey));
100 
101  Stream_Write_UINT8(sWS, (UINT8)(WEBSOCKET_FIN_BIT | opcode));
102  if (len < 126)
103  Stream_Write_UINT8(sWS, (UINT8)len | WEBSOCKET_MASK_BIT);
104  else if (len < 0x10000)
105  {
106  Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
107  Stream_Write_UINT16_BE(sWS, (UINT16)len);
108  }
109  else
110  {
111  Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
112  Stream_Write_UINT32_BE(sWS, 0); /* payload is limited to INT_MAX */
113  Stream_Write_UINT32_BE(sWS, (UINT32)len);
114  }
115  Stream_Write_UINT32(sWS, maskingKey);
116  *pMaskingKey = maskingKey;
117  return sWS;
118 }
119 
120 BOOL websocket_context_write_wstream(websocket_context* context, BIO* bio, wStream* sPacket,
121  WEBSOCKET_OPCODE opcode)
122 {
123  WINPR_ASSERT(context);
124 
125  if (context->closeSent)
126  return FALSE;
127 
128  if (opcode == WebsocketCloseOpcode)
129  context->closeSent = TRUE;
130 
131  WINPR_ASSERT(bio);
132  WINPR_ASSERT(sPacket);
133 
134  const size_t len = Stream_Length(sPacket);
135  uint32_t maskingKey = 0;
136  wStream* sWS = websocket_context_packet_new(len, opcode, &maskingKey);
137  if (!sWS)
138  return FALSE;
139 
140  return websocket_context_mask_and_send(bio, sWS, sPacket, maskingKey);
141 }
142 
143 int websocket_write_all(BIO* bio, const BYTE* data, size_t length)
144 {
145  WINPR_ASSERT(bio);
146  WINPR_ASSERT(data);
147  size_t offset = 0;
148 
149  if (length > INT32_MAX)
150  return -1;
151 
152  while (offset < length)
153  {
154  ERR_clear_error();
155  const size_t diff = length - offset;
156  int status = BIO_write(bio, &data[offset], (int)diff);
157 
158  if (status > 0)
159  offset += (size_t)status;
160  else
161  {
162  if (!BIO_should_retry(bio))
163  return -1;
164 
165  if (BIO_write_blocked(bio))
166  {
167  const long rstatus = BIO_wait_write(bio, 100);
168  if (rstatus < 0)
169  return -1;
170  }
171  else if (BIO_read_blocked(bio))
172  return -2; /* Abort write, there is data that must be read */
173  else
174  USleep(100);
175  }
176  }
177 
178  return (int)length;
179 }
180 
181 int websocket_context_write(websocket_context* context, BIO* bio, const BYTE* buf, int isize,
182  WEBSOCKET_OPCODE opcode)
183 {
184  WINPR_ASSERT(bio);
185  WINPR_ASSERT(buf);
186 
187  if (isize < 0)
188  return -1;
189 
190  wStream sbuffer = { 0 };
191  wStream* s = Stream_StaticConstInit(&sbuffer, buf, (size_t)isize);
192  if (!websocket_context_write_wstream(context, bio, s, opcode))
193  return -2;
194  return isize;
195 }
196 
197 static int websocket_read_data(BIO* bio, BYTE* pBuffer, size_t size,
198  websocket_context* encodingContext)
199 {
200  int status = 0;
201 
202  WINPR_ASSERT(bio);
203  WINPR_ASSERT(pBuffer);
204  WINPR_ASSERT(encodingContext);
205 
206  if (encodingContext->payloadLength == 0)
207  {
208  encodingContext->state = WebsocketStateOpcodeAndFin;
209  return 0;
210  }
211 
212  const size_t rlen =
213  (encodingContext->payloadLength < size ? encodingContext->payloadLength : size);
214  if (rlen > INT32_MAX)
215  return -1;
216 
217  ERR_clear_error();
218  status = BIO_read(bio, pBuffer, (int)rlen);
219  if ((status <= 0) || ((size_t)status > encodingContext->payloadLength))
220  return status;
221 
222  encodingContext->payloadLength -= (size_t)status;
223 
224  if (encodingContext->payloadLength == 0)
225  encodingContext->state = WebsocketStateOpcodeAndFin;
226 
227  return status;
228 }
229 
230 static int websocket_read_wstream(BIO* bio, websocket_context* encodingContext)
231 {
232  WINPR_ASSERT(bio);
233  WINPR_ASSERT(encodingContext);
234 
235  wStream* s = encodingContext->responseStreamBuffer;
236  WINPR_ASSERT(s);
237 
238  if (encodingContext->payloadLength == 0)
239  {
240  encodingContext->state = WebsocketStateOpcodeAndFin;
241  return 0;
242  }
243 
244  if (!Stream_EnsureRemainingCapacity(s, encodingContext->payloadLength))
245  {
246  WLog_WARN(TAG,
247  "wStream::capacity [%" PRIuz "] != encodingContext::paylaodLangth [%" PRIuz "]",
248  Stream_GetRemainingCapacity(s), encodingContext->payloadLength);
249  return -1;
250  }
251 
252  const int status = websocket_read_data(bio, Stream_Pointer(s), Stream_GetRemainingCapacity(s),
253  encodingContext);
254  if (status < 0)
255  return status;
256 
257  if (!Stream_SafeSeek(s, (size_t)status))
258  return -1;
259 
260  return status;
261 }
262 
263 static BOOL websocket_reply_close(BIO* bio, websocket_context* context, wStream* s)
264 {
265  WINPR_ASSERT(bio);
266 
267  return websocket_context_write_wstream(context, bio, s, WebsocketCloseOpcode);
268 }
269 
270 static BOOL websocket_reply_pong(BIO* bio, websocket_context* context, wStream* s)
271 {
272  WINPR_ASSERT(bio);
273  WINPR_ASSERT(s);
274 
275  if (Stream_GetPosition(s) != 0)
276  return websocket_context_write_wstream(context, bio, s, WebsocketPongOpcode);
277 
278  return websocket_reply_close(bio, context, NULL);
279 }
280 
281 static int websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size,
282  websocket_context* encodingContext)
283 {
284  int status = 0;
285 
286  WINPR_ASSERT(bio);
287  WINPR_ASSERT(pBuffer);
288  WINPR_ASSERT(encodingContext);
289 
290  const BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode
291  ? encodingContext->fragmentOriginalOpcode & 0xf
292  : encodingContext->opcode & 0xf);
293 
294  switch (effectiveOpcode)
295  {
296  case WebsocketBinaryOpcode:
297  {
298  status = websocket_read_data(bio, pBuffer, size, encodingContext);
299  if (status < 0)
300  return status;
301 
302  return status;
303  }
304  case WebsocketPingOpcode:
305  {
306  status = websocket_read_wstream(bio, encodingContext);
307  if (status < 0)
308  return status;
309 
310  if (encodingContext->payloadLength == 0)
311  {
312  websocket_reply_pong(bio, encodingContext, encodingContext->responseStreamBuffer);
313  Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
314  }
315  }
316  break;
317  case WebsocketPongOpcode:
318  {
319  status = websocket_read_wstream(bio, encodingContext);
320  if (status < 0)
321  return status;
322  /* We donĀ“t care about pong response data, discard. */
323  Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
324  }
325  break;
326  case WebsocketCloseOpcode:
327  {
328  status = websocket_read_wstream(bio, encodingContext);
329  if (status < 0)
330  return status;
331 
332  if (encodingContext->payloadLength == 0)
333  {
334  websocket_reply_close(bio, encodingContext, encodingContext->responseStreamBuffer);
335  encodingContext->closeSent = TRUE;
336  Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
337  }
338  }
339  break;
340  default:
341  WLog_WARN(TAG, "Unimplemented websocket opcode %" PRIx8 ". Dropping", effectiveOpcode);
342 
343  status = websocket_read_wstream(bio, encodingContext);
344  if (status < 0)
345  return status;
346  Stream_SetPosition(encodingContext->responseStreamBuffer, 0);
347  break;
348  }
349  /* return how many bytes have been written to pBuffer.
350  * Only WebsocketBinaryOpcode writes into it and it returns directly */
351  return 0;
352 }
353 
354 int websocket_context_read(websocket_context* encodingContext, BIO* bio, BYTE* pBuffer, size_t size)
355 {
356  int status = 0;
357  int effectiveDataLen = 0;
358 
359  WINPR_ASSERT(bio);
360  WINPR_ASSERT(pBuffer);
361  WINPR_ASSERT(encodingContext);
362 
363  while (TRUE)
364  {
365  switch (encodingContext->state)
366  {
367  case WebsocketStateOpcodeAndFin:
368  {
369  BYTE buffer[1] = { 0 };
370 
371  ERR_clear_error();
372  status = BIO_read(bio, (char*)buffer, sizeof(buffer));
373  if (status <= 0)
374  return (effectiveDataLen > 0 ? effectiveDataLen : status);
375 
376  encodingContext->opcode = buffer[0];
377  if (((encodingContext->opcode & 0xf) != WebsocketContinuationOpcode) &&
378  (encodingContext->opcode & 0xf) < 0x08)
379  encodingContext->fragmentOriginalOpcode = encodingContext->opcode;
380  encodingContext->state = WebsocketStateLengthAndMasking;
381  }
382  break;
383  case WebsocketStateLengthAndMasking:
384  {
385  BYTE buffer[1] = { 0 };
386 
387  ERR_clear_error();
388  status = BIO_read(bio, (char*)buffer, sizeof(buffer));
389  if (status <= 0)
390  return (effectiveDataLen > 0 ? effectiveDataLen : status);
391 
392  encodingContext->masking = ((buffer[0] & WEBSOCKET_MASK_BIT) == WEBSOCKET_MASK_BIT);
393  encodingContext->lengthAndMaskPosition = 0;
394  encodingContext->payloadLength = 0;
395  const BYTE len = buffer[0] & 0x7f;
396  if (len < 126)
397  {
398  encodingContext->payloadLength = len;
399  encodingContext->state = (encodingContext->masking ? WebSocketStateMaskingKey
400  : WebSocketStatePayload);
401  }
402  else if (len == 126)
403  encodingContext->state = WebsocketStateShortLength;
404  else
405  encodingContext->state = WebsocketStateLongLength;
406  }
407  break;
408  case WebsocketStateShortLength:
409  case WebsocketStateLongLength:
410  {
411  BYTE buffer[1] = { 0 };
412  const BYTE lenLength =
413  (encodingContext->state == WebsocketStateShortLength ? 2 : 8);
414  while (encodingContext->lengthAndMaskPosition < lenLength)
415  {
416  ERR_clear_error();
417  status = BIO_read(bio, (char*)buffer, sizeof(buffer));
418  if (status <= 0)
419  return (effectiveDataLen > 0 ? effectiveDataLen : status);
420 
421  encodingContext->payloadLength =
422  (encodingContext->payloadLength) << 8 | buffer[0];
423  encodingContext->lengthAndMaskPosition += status;
424  }
425  encodingContext->state =
426  (encodingContext->masking ? WebSocketStateMaskingKey : WebSocketStatePayload);
427  }
428  break;
429  case WebSocketStateMaskingKey:
430  {
431  WLog_WARN(
432  TAG, "Websocket Server sends data with masking key. This is against RFC 6455.");
433  return -1;
434  }
435  case WebSocketStatePayload:
436  {
437  status = websocket_handle_payload(bio, pBuffer, size, encodingContext);
438  if (status < 0)
439  return (effectiveDataLen > 0 ? effectiveDataLen : status);
440 
441  effectiveDataLen += status;
442 
443  if ((size_t)status >= size)
444  return effectiveDataLen;
445  pBuffer += status;
446  size -= WINPR_ASSERTING_INT_CAST(size_t, status);
447  }
448  break;
449  default:
450  break;
451  }
452  }
453  /* should be unreachable */
454 }
455 
456 websocket_context* websocket_context_new(void)
457 {
458  websocket_context* context = calloc(1, sizeof(websocket_context));
459  if (!context)
460  goto fail;
461 
462  context->responseStreamBuffer = Stream_New(NULL, 1024);
463  if (!context->responseStreamBuffer)
464  goto fail;
465 
466  if (!websocket_context_reset(context))
467  goto fail;
468 
469  return context;
470 fail:
471  websocket_context_free(context);
472  return NULL;
473 }
474 
475 void websocket_context_free(websocket_context* context)
476 {
477  if (!context)
478  return;
479 
480  Stream_Free(context->responseStreamBuffer, TRUE);
481  free(context);
482 }
483 
484 BOOL websocket_context_reset(websocket_context* context)
485 {
486  WINPR_ASSERT(context);
487 
488  context->state = WebsocketStateOpcodeAndFin;
489  return Stream_SetPosition(context->responseStreamBuffer, 0);
490 }