18 #include <winpr/assert.h>
20 #include <freerdp/freerdp.h>
21 #include <freerdp/server/proxy/proxy_log.h>
23 #include "proxy_modules.h"
24 #include "pf_channel.h"
26 #define TAG PROXY_TAG("channel")
29 struct _ChannelStateTracker
31 pServerStaticChannelContext* channel;
32 ChannelTrackerMode mode;
34 size_t currentPacketReceived;
35 size_t currentPacketSize;
36 size_t currentPacketFragments;
38 ChannelTrackerPeekFn peekFn;
43 static BOOL channelTracker_resetCurrentPacket(ChannelStateTracker* tracker)
45 WINPR_ASSERT(tracker);
48 if (tracker->currentPacket)
50 const size_t cap = Stream_Capacity(tracker->currentPacket);
51 if (cap < 1ULL * 1000ULL * 1000ULL)
54 Stream_Free(tracker->currentPacket, TRUE);
58 tracker->currentPacket = Stream_New(NULL, 10ULL * 1024ULL);
59 if (!tracker->currentPacket)
61 Stream_SetPosition(tracker->currentPacket, 0);
65 ChannelStateTracker* channelTracker_new(pServerStaticChannelContext* channel,
66 ChannelTrackerPeekFn fn,
void* data)
68 ChannelStateTracker* ret = calloc(1,
sizeof(ChannelStateTracker));
74 ret->channel = channel;
77 if (!channelTracker_setCustomData(ret, data))
80 if (!channelTracker_resetCurrentPacket(ret))
86 WINPR_PRAGMA_DIAG_PUSH
87 WINPR_PRAGMA_DIAG_IGNORED_MISMATCHED_DEALLOC
88 channelTracker_free(ret);
93 PfChannelResult channelTracker_update(ChannelStateTracker* tracker,
const BYTE* xdata,
size_t xsize,
94 UINT32 flags,
size_t totalSize)
96 PfChannelResult result = PF_CHANNEL_RESULT_ERROR;
97 BOOL firstPacket = (flags & CHANNEL_FLAG_FIRST) != 0;
98 BOOL lastPacket = (flags & CHANNEL_FLAG_LAST) != 0;
100 WINPR_ASSERT(tracker);
102 WLog_VRB(TAG,
"channelTracker_update(%s): sz=%" PRIuz
" first=%d last=%d",
103 tracker->channel->channel_name, xsize, firstPacket, lastPacket);
104 if (flags & CHANNEL_FLAG_FIRST)
106 if (!channelTracker_resetCurrentPacket(tracker))
108 channelTracker_setCurrentPacketSize(tracker, totalSize);
109 tracker->currentPacketReceived = 0;
110 tracker->currentPacketFragments = 0;
114 const size_t currentPacketSize = channelTracker_getCurrentPacketSize(tracker);
115 if (tracker->currentPacketReceived + xsize > currentPacketSize)
116 WLog_INFO(TAG,
"cumulated size is bigger (%" PRIuz
") than total size (%" PRIuz
")",
117 tracker->currentPacketReceived + xsize, currentPacketSize);
120 tracker->currentPacketReceived += xsize;
121 tracker->currentPacketFragments++;
123 switch (channelTracker_getMode(tracker))
125 case CHANNEL_TRACKER_PEEK:
127 wStream* currentPacket = channelTracker_getCurrentPacket(tracker);
128 if (!Stream_EnsureRemainingCapacity(currentPacket, xsize))
129 return PF_CHANNEL_RESULT_ERROR;
131 Stream_Write(currentPacket, xdata, xsize);
133 WINPR_ASSERT(tracker->peekFn);
134 result = tracker->peekFn(tracker, firstPacket, lastPacket);
137 case CHANNEL_TRACKER_PASS:
138 result = PF_CHANNEL_RESULT_PASS;
140 case CHANNEL_TRACKER_DROP:
141 result = PF_CHANNEL_RESULT_DROP;
149 const size_t currentPacketSize = channelTracker_getCurrentPacketSize(tracker);
150 channelTracker_setMode(tracker, CHANNEL_TRACKER_PEEK);
152 if (tracker->currentPacketReceived != currentPacketSize)
153 WLog_INFO(TAG,
"cumulated size(%" PRIuz
") does not match total size (%" PRIuz
")",
154 tracker->currentPacketReceived, currentPacketSize);
160 void channelTracker_free(ChannelStateTracker* t)
165 Stream_Free(t->currentPacket, TRUE);
175 PfChannelResult channelTracker_flushCurrent(ChannelStateTracker* t, BOOL first, BOOL last,
178 proxyData* pdata = NULL;
179 pServerContext* ps = NULL;
180 pServerStaticChannelContext* channel = NULL;
181 UINT32 flags = CHANNEL_FLAG_FIRST;
183 const char* direction = toBack ?
"F->B" :
"B->F";
184 const size_t currentPacketSize = channelTracker_getCurrentPacketSize(t);
185 wStream* currentPacket = channelTracker_getCurrentPacket(t);
189 WLog_VRB(TAG,
"channelTracker_flushCurrent(%s): %s sz=%" PRIuz
" first=%d last=%d",
190 t->channel->channel_name, direction, Stream_GetPosition(currentPacket), first, last);
193 return PF_CHANNEL_RESULT_PASS;
196 channel = t->channel;
198 flags |= CHANNEL_FLAG_LAST;
204 ev.channel_id = channel->front_channel_id;
205 ev.channel_name = channel->channel_name;
206 ev.data = Stream_Buffer(currentPacket);
207 ev.data_len = Stream_GetPosition(currentPacket);
209 ev.total_size = currentPacketSize;
211 if (!pdata->pc->sendChannelData)
212 return PF_CHANNEL_RESULT_ERROR;
214 return pdata->pc->sendChannelData(pdata->pc, &ev) ? PF_CHANNEL_RESULT_DROP
215 : PF_CHANNEL_RESULT_ERROR;
219 r = ps->context.peer->SendChannelPacket(ps->context.peer, channel->front_channel_id,
220 currentPacketSize, flags, Stream_Buffer(currentPacket),
221 Stream_GetPosition(currentPacket));
223 return r ? PF_CHANNEL_RESULT_DROP : PF_CHANNEL_RESULT_ERROR;
226 static PfChannelResult pf_channel_generic_back_data(proxyData* pdata,
227 const pServerStaticChannelContext* channel,
228 const BYTE* xdata,
size_t xsize, UINT32 flags,
234 WINPR_ASSERT(channel);
236 switch (channel->channelMode)
238 case PF_UTILS_CHANNEL_PASSTHROUGH:
239 ev.channel_id = channel->back_channel_id;
240 ev.channel_name = channel->channel_name;
244 ev.total_size = totalSize;
246 if (!pf_modules_run_filter(pdata->module, FILTER_TYPE_CLIENT_PASSTHROUGH_CHANNEL_DATA,
248 return PF_CHANNEL_RESULT_DROP;
250 return PF_CHANNEL_RESULT_PASS;
252 case PF_UTILS_CHANNEL_INTERCEPT:
254 case PF_UTILS_CHANNEL_BLOCK:
256 return PF_CHANNEL_RESULT_DROP;
260 static PfChannelResult pf_channel_generic_front_data(proxyData* pdata,
261 const pServerStaticChannelContext* channel,
262 const BYTE* xdata,
size_t xsize, UINT32 flags,
268 WINPR_ASSERT(channel);
270 switch (channel->channelMode)
272 case PF_UTILS_CHANNEL_PASSTHROUGH:
273 ev.channel_id = channel->front_channel_id;
274 ev.channel_name = channel->channel_name;
278 ev.total_size = totalSize;
280 if (!pf_modules_run_filter(pdata->module, FILTER_TYPE_SERVER_PASSTHROUGH_CHANNEL_DATA,
282 return PF_CHANNEL_RESULT_DROP;
284 return PF_CHANNEL_RESULT_PASS;
286 case PF_UTILS_CHANNEL_INTERCEPT:
288 case PF_UTILS_CHANNEL_BLOCK:
290 return PF_CHANNEL_RESULT_DROP;
294 BOOL pf_channel_setup_generic(pServerStaticChannelContext* channel)
296 WINPR_ASSERT(channel);
297 channel->onBackData = pf_channel_generic_back_data;
298 channel->onFrontData = pf_channel_generic_front_data;
302 BOOL channelTracker_setMode(ChannelStateTracker* tracker, ChannelTrackerMode mode)
304 WINPR_ASSERT(tracker);
305 tracker->mode = mode;
309 ChannelTrackerMode channelTracker_getMode(ChannelStateTracker* tracker)
311 WINPR_ASSERT(tracker);
312 return tracker->mode;
315 BOOL channelTracker_setPData(ChannelStateTracker* tracker, proxyData* pdata)
317 WINPR_ASSERT(tracker);
318 tracker->pdata = pdata;
322 proxyData* channelTracker_getPData(ChannelStateTracker* tracker)
324 WINPR_ASSERT(tracker);
325 return tracker->pdata;
328 wStream* channelTracker_getCurrentPacket(ChannelStateTracker* tracker)
330 WINPR_ASSERT(tracker);
331 return tracker->currentPacket;
334 BOOL channelTracker_setCustomData(ChannelStateTracker* tracker,
void* data)
336 WINPR_ASSERT(tracker);
337 tracker->trackerData = data;
341 void* channelTracker_getCustomData(ChannelStateTracker* tracker)
343 WINPR_ASSERT(tracker);
344 return tracker->trackerData;
347 size_t channelTracker_getCurrentPacketSize(ChannelStateTracker* tracker)
349 WINPR_ASSERT(tracker);
350 return tracker->currentPacketSize;
/* Records the total size announced for the PDU currently being tracked;
 * channelTracker_update() calls this with totalSize on a CHANNEL_FLAG_FIRST fragment.
 * NOTE(review): the remainder of this definition (return value, closing brace)
 * lies outside the visible excerpt - confirm it returns TRUE. */
353 BOOL channelTracker_setCurrentPacketSize(ChannelStateTracker* tracker,
size_t size)
355 WINPR_ASSERT(tracker);
356 tracker->currentPacketSize = size;