18 #include <winpr/assert.h>
19 #include <winpr/cast.h>
21 #include <freerdp/freerdp.h>
22 #include <freerdp/server/proxy/proxy_log.h>
24 #include "proxy_modules.h"
25 #include "pf_channel.h"
/* WLog tag used by every WLog_* call in this file. */
27 #define TAG PROXY_TAG("channel")
/**
 * Per-channel state used to accumulate and inspect fragmented static
 * channel packets.
 *
 * NOTE(review): this excerpt does not show every member; functions in this
 * file also access currentPacket, pdata and trackerData on this struct.
 * NOTE(review): identifiers beginning with an underscore followed by an
 * uppercase letter are reserved for the implementation in C; consider
 * renaming the tag (together with every reference to it).
 */
30 struct _ChannelStateTracker
/* Channel this tracker belongs to; its channel_name appears in log output. */
32 pServerStaticChannelContext* channel;
/* How fragments are handled: CHANNEL_TRACKER_PEEK, _PASS or _DROP. */
33 ChannelTrackerMode mode;
/* Bytes seen so far for the current packet (sum of fragment sizes). */
35 size_t currentPacketReceived;
/* Total packet size recorded from the fragment carrying CHANNEL_FLAG_FIRST. */
36 size_t currentPacketSize;
/* Number of fragments seen for the current packet. */
37 size_t currentPacketFragments;
/* Callback run in PEEK mode after each fragment has been appended. */
39 ChannelTrackerPeekFn peekFn;
/**
 * Prepare tracker->currentPacket to receive a new packet.
 *
 * An existing stream is checked against a 1,000,000-byte capacity
 * threshold; Stream_Free(..., TRUE) releases a stream together with its
 * buffer. A newly created stream starts with a 10 KiB buffer, and the
 * stream position is finally rewound to 0.
 *
 * NOTE(review): the lines between the capacity check and Stream_Free, and
 * whatever guards the Stream_New call, are not part of this excerpt.
 * Presumably small streams are reused and only large ones are freed;
 * confirm that a reused stream is not overwritten (and leaked) by Stream_New.
 *
 * @param tracker tracker to reset (must not be NULL)
 * @return presumably FALSE when no stream could be allocated and TRUE
 *         otherwise (the return statements are not visible here)
 */
44 static BOOL channelTracker_resetCurrentPacket(ChannelStateTracker* tracker)
46 WINPR_ASSERT(tracker);
49 if (tracker->currentPacket)
51 const size_t cap = Stream_Capacity(tracker->currentPacket);
/* Presumably: streams under ~1 MB are kept, bigger ones are freed. */
52 if (cap < 1ULL * 1000ULL * 1000ULL)
55 Stream_Free(tracker->currentPacket, TRUE);
59 tracker->currentPacket = Stream_New(NULL, 10ULL * 1024ULL);
60 if (!tracker->currentPacket)
/* Rewind so the next packet is written from the start of the buffer. */
62 Stream_SetPosition(tracker->currentPacket, 0);
/**
 * Create a tracker for a static channel.
 *
 * Allocates a zero-initialized ChannelStateTracker and binds it to
 * `channel`. It then stores `data` via channelTracker_setCustomData and
 * allocates the initial packet stream via channelTracker_resetCurrentPacket.
 * On failure the tracker is released with channelTracker_free; presumably
 * that cleanup is reached from the two failed checks, but the jumps are not
 * in this excerpt.
 *
 * NOTE(review): two things are not visible in this excerpt; confirm both
 * exist:
 *  - the NULL check on calloc's result;
 *  - the storing of `fn`, presumably into ret->peekFn, which
 *    channelTracker_update calls in PEEK mode.
 *
 * @param channel static channel context the tracker is attached to
 * @param fn      peek callback
 * @param data    opaque user pointer, retrievable with channelTracker_getCustomData
 * @return the new tracker, or presumably NULL on failure
 */
66 ChannelStateTracker* channelTracker_new(pServerStaticChannelContext* channel,
67 ChannelTrackerPeekFn fn,
void* data)
69 ChannelStateTracker* ret = calloc(1,
sizeof(ChannelStateTracker));
75 ret->channel = channel;
78 if (!channelTracker_setCustomData(ret, data))
81 if (!channelTracker_resetCurrentPacket(ret))
/* ret came from calloc() but is released through channelTracker_free(),
 * which also frees the packet stream; the pragma silences the compiler's
 * mismatched-deallocation warning for that call. */
87 WINPR_PRAGMA_DIAG_PUSH
88 WINPR_PRAGMA_DIAG_IGNORED_MISMATCHED_DEALLOC
89 channelTracker_free(ret);
/**
 * Feed one fragment of a static channel packet through the tracker.
 *
 * When CHANNEL_FLAG_FIRST is set, the packet buffer is reset and `totalSize`
 * is recorded as the expected packet size. The received-byte and fragment
 * counters also restart at 0. Every fragment then updates those counters.
 * A cumulated size that disagrees with the recorded total is only logged at
 * INFO level; it is not treated as an error.
 *
 * Handling depends on the tracker mode:
 *  - CHANNEL_TRACKER_PEEK: the fragment is appended to the current packet
 *    and tracker->peekFn provides the result. If the packet buffer cannot
 *    grow, PF_CHANNEL_RESULT_ERROR is returned.
 *  - CHANNEL_TRACKER_PASS: PF_CHANNEL_RESULT_PASS.
 *  - CHANNEL_TRACKER_DROP: PF_CHANNEL_RESULT_DROP.
 * Afterwards the mode is put back to CHANNEL_TRACKER_PEEK. Presumably this
 * happens only for the CHANNEL_FLAG_LAST fragment; the guarding condition is
 * not visible in this excerpt.
 *
 * NOTE(review): the failure branch after channelTracker_resetCurrentPacket
 * is not in this excerpt. Make sure it returns PF_CHANNEL_RESULT_ERROR and
 * not a BOOL such as FALSE, which would silently convert to an unrelated
 * PfChannelResult value.
 * NOTE(review): the switch has no default case; an unexpected mode
 * presumably leaves `result` at its initial PF_CHANNEL_RESULT_ERROR.
 *
 * @param tracker   tracker for this channel (must not be NULL)
 * @param xdata     fragment bytes
 * @param xsize     number of bytes in xdata
 * @param flags     channel flags; CHANNEL_FLAG_FIRST / CHANNEL_FLAG_LAST are consulted
 * @param totalSize total size of the whole packet, recorded when CHANNEL_FLAG_FIRST is set
 * @return a PfChannelResult as described above
 */
94 PfChannelResult channelTracker_update(ChannelStateTracker* tracker,
const BYTE* xdata,
size_t xsize,
95 UINT32 flags,
size_t totalSize)
97 PfChannelResult result = PF_CHANNEL_RESULT_ERROR;
98 BOOL firstPacket = (flags & CHANNEL_FLAG_FIRST) != 0;
99 BOOL lastPacket = (flags & CHANNEL_FLAG_LAST) != 0;
101 WINPR_ASSERT(tracker);
103 WLog_VRB(TAG,
"channelTracker_update(%s): sz=%" PRIuz
" first=%d last=%d",
104 tracker->channel->channel_name, xsize, firstPacket, lastPacket);
/* Start of a new packet: reset the buffer, record the size, zero the counters. */
105 if (flags & CHANNEL_FLAG_FIRST)
107 if (!channelTracker_resetCurrentPacket(tracker))
109 channelTracker_setCurrentPacketSize(tracker, totalSize);
110 tracker->currentPacketReceived = 0;
111 tracker->currentPacketFragments = 0;
/* More data than announced is only reported, not rejected. */
115 const size_t currentPacketSize = channelTracker_getCurrentPacketSize(tracker);
116 if (tracker->currentPacketReceived + xsize > currentPacketSize)
117 WLog_INFO(TAG,
"cumulated size is bigger (%" PRIuz
") than total size (%" PRIuz
")",
118 tracker->currentPacketReceived + xsize, currentPacketSize);
/* Account for this fragment. */
121 tracker->currentPacketReceived += xsize;
122 tracker->currentPacketFragments++;
124 switch (channelTracker_getMode(tracker))
126 case CHANNEL_TRACKER_PEEK:
/* Accumulate the fragment, then let the peek callback decide. */
128 wStream* currentPacket = channelTracker_getCurrentPacket(tracker);
129 if (!Stream_EnsureRemainingCapacity(currentPacket, xsize))
130 return PF_CHANNEL_RESULT_ERROR;
132 Stream_Write(currentPacket, xdata, xsize);
134 WINPR_ASSERT(tracker->peekFn);
135 result = tracker->peekFn(tracker, firstPacket, lastPacket);
138 case CHANNEL_TRACKER_PASS:
139 result = PF_CHANNEL_RESULT_PASS;
141 case CHANNEL_TRACKER_DROP:
142 result = PF_CHANNEL_RESULT_DROP;
/* End-of-packet bookkeeping (presumably on CHANNEL_FLAG_LAST): return to
 * PEEK mode and report a size mismatch. */
150 const size_t currentPacketSize = channelTracker_getCurrentPacketSize(tracker);
151 channelTracker_setMode(tracker, CHANNEL_TRACKER_PEEK);
153 if (tracker->currentPacketReceived != currentPacketSize)
154 WLog_INFO(TAG,
"cumulated size(%" PRIuz
") does not match total size (%" PRIuz
")",
155 tracker->currentPacketReceived, currentPacketSize);
/**
 * Release a tracker.
 *
 * Frees the current packet stream together with its buffer
 * (Stream_Free(..., TRUE)).
 * NOTE(review): the NULL check on `t` and the release of the tracker
 * struct itself are not visible in this excerpt.
 *
 * @param t tracker to release
 */
161 void channelTracker_free(ChannelStateTracker* t)
166 Stream_Free(t->currentPacket, TRUE);
/**
 * Send the packet accumulated by the tracker to the other side of the proxy.
 *
 * The direction shown in the log line is "F->B" when toBack is set and
 * "B->F" otherwise. The function returns PF_CHANNEL_RESULT_PASS early under
 * a condition that is not visible here (presumably when `first` is set).
 * `flags` starts as CHANNEL_FLAG_FIRST and gains CHANNEL_FLAG_LAST
 * (presumably when `last` is set). Then:
 *  - presumably when toBack is set: a proxyChannelDataEventInfo describing
 *    the buffered bytes is handed to pdata->pc->sendChannelData;
 *  - otherwise: the buffered bytes are written with the front peer's
 *    SendChannelPacket on channel->front_channel_id, using `flags`.
 * In both paths the announced total size is the tracker's current packet
 * size. The data length is the stream position, i.e. the bytes written so far.
 *
 * @param t      tracker whose current packet is flushed
 * @param first  see the early-return note above
 * @param last   presumably selects CHANNEL_FLAG_LAST
 * @param toBack selects the back ("F->B") or front ("B->F") direction
 * @return PF_CHANNEL_RESULT_DROP when the data was sent (presumably so the
 *         caller does not forward the original fragment as well);
 *         PF_CHANNEL_RESULT_ERROR when sending failed or no sendChannelData
 *         callback is installed; PF_CHANNEL_RESULT_PASS on the early return.
 *
 * NOTE(review): the event sent toward the back is tagged with
 * front_channel_id, while the generic back-data handler in this file uses
 * back_channel_id for the same field; confirm this is intended.
 * NOTE(review): where pdata and ps are assigned (presumably t->pdata and
 * pdata->ps) is not shown in this excerpt.
 */
176 PfChannelResult channelTracker_flushCurrent(ChannelStateTracker* t, BOOL first, BOOL last,
179 proxyData* pdata = NULL;
180 pServerContext* ps = NULL;
181 pServerStaticChannelContext* channel = NULL;
182 UINT32 flags = CHANNEL_FLAG_FIRST;
184 const char* direction = toBack ?
"F->B" :
"B->F";
185 const size_t currentPacketSize = channelTracker_getCurrentPacketSize(t);
186 wStream* currentPacket = channelTracker_getCurrentPacket(t);
190 WLog_VRB(TAG,
"channelTracker_flushCurrent(%s): %s sz=%" PRIuz
" first=%d last=%d",
191 t->channel->channel_name, direction, Stream_GetPosition(currentPacket), first, last);
/* Early exit; the guard is not visible (presumably `first`). */
194 return PF_CHANNEL_RESULT_PASS;
197 channel = t->channel;
199 flags |= CHANNEL_FLAG_LAST;
/* Describe the buffered bytes for the proxy client's sendChannelData. */
205 ev.channel_id = WINPR_ASSERTING_INT_CAST(UINT16, channel->front_channel_id);
206 ev.channel_name = channel->channel_name;
207 ev.data = Stream_Buffer(currentPacket);
208 ev.data_len = Stream_GetPosition(currentPacket);
210 ev.total_size = currentPacketSize;
212 if (!pdata->pc->sendChannelData)
213 return PF_CHANNEL_RESULT_ERROR;
215 return pdata->pc->sendChannelData(pdata->pc, &ev) ? PF_CHANNEL_RESULT_DROP
216 : PF_CHANNEL_RESULT_ERROR;
/* Other direction: write the buffered bytes to the front peer. */
220 r = ps->context.peer->SendChannelPacket(
221 ps->context.peer, WINPR_ASSERTING_INT_CAST(UINT16, channel->front_channel_id),
222 currentPacketSize, flags, Stream_Buffer(currentPacket), Stream_GetPosition(currentPacket));
224 return r ? PF_CHANNEL_RESULT_DROP : PF_CHANNEL_RESULT_ERROR;
/**
 * Generic handler installed as channel->onBackData by
 * pf_channel_setup_generic.
 *
 * PF_UTILS_CHANNEL_PASSTHROUGH: a proxyChannelDataEventInfo is filled with
 * the channel's back_channel_id, its name and `totalSize`. It is then run
 * through the FILTER_TYPE_CLIENT_PASSTHROUGH_CHANNEL_DATA module filter. The
 * data is dropped if the filter rejects it and passed through otherwise.
 * PF_UTILS_CHANNEL_BLOCK: dropped. PF_UTILS_CHANNEL_INTERCEPT appears to
 * share the BLOCK handling; the line between the two labels is not in this
 * excerpt.
 *
 * NOTE(review): the remaining parameter (totalSize), the event's other
 * fields and the filter's extra arguments are on lines not shown here.
 *
 * @return PF_CHANNEL_RESULT_PASS or PF_CHANNEL_RESULT_DROP for the modes shown
 */
227 static PfChannelResult pf_channel_generic_back_data(proxyData* pdata,
228 const pServerStaticChannelContext* channel,
229 const BYTE* xdata,
size_t xsize, UINT32 flags,
235 WINPR_ASSERT(channel);
237 switch (channel->channelMode)
239 case PF_UTILS_CHANNEL_PASSTHROUGH:
240 ev.channel_id = WINPR_ASSERTING_INT_CAST(UINT16, channel->back_channel_id);
241 ev.channel_name = channel->channel_name;
245 ev.total_size = totalSize;
/* Modules can veto passthrough data; a rejected packet is dropped. */
247 if (!pf_modules_run_filter(pdata->module, FILTER_TYPE_CLIENT_PASSTHROUGH_CHANNEL_DATA,
249 return PF_CHANNEL_RESULT_DROP;
251 return PF_CHANNEL_RESULT_PASS;
253 case PF_UTILS_CHANNEL_INTERCEPT:
255 case PF_UTILS_CHANNEL_BLOCK:
257 return PF_CHANNEL_RESULT_DROP;
/**
 * Generic handler installed as channel->onFrontData by
 * pf_channel_setup_generic.
 *
 * PF_UTILS_CHANNEL_PASSTHROUGH: a proxyChannelDataEventInfo is filled with
 * the channel's front_channel_id, its name and `totalSize`. It is then run
 * through the FILTER_TYPE_SERVER_PASSTHROUGH_CHANNEL_DATA module filter. The
 * data is dropped if the filter rejects it and passed through otherwise.
 * PF_UTILS_CHANNEL_BLOCK: dropped. PF_UTILS_CHANNEL_INTERCEPT appears to
 * share the BLOCK handling; the line between the two labels is not in this
 * excerpt.
 *
 * NOTE(review): the remaining parameter (totalSize), the event's other
 * fields and the filter's extra arguments are on lines not shown here.
 *
 * @return PF_CHANNEL_RESULT_PASS or PF_CHANNEL_RESULT_DROP for the modes shown
 */
261 static PfChannelResult pf_channel_generic_front_data(proxyData* pdata,
262 const pServerStaticChannelContext* channel,
263 const BYTE* xdata,
size_t xsize, UINT32 flags,
269 WINPR_ASSERT(channel);
271 switch (channel->channelMode)
273 case PF_UTILS_CHANNEL_PASSTHROUGH:
274 ev.channel_id = WINPR_ASSERTING_INT_CAST(UINT16, channel->front_channel_id);
275 ev.channel_name = channel->channel_name;
279 ev.total_size = totalSize;
/* Modules can veto passthrough data; a rejected packet is dropped. */
281 if (!pf_modules_run_filter(pdata->module, FILTER_TYPE_SERVER_PASSTHROUGH_CHANNEL_DATA,
283 return PF_CHANNEL_RESULT_DROP;
285 return PF_CHANNEL_RESULT_PASS;
287 case PF_UTILS_CHANNEL_INTERCEPT:
289 case PF_UTILS_CHANNEL_BLOCK:
291 return PF_CHANNEL_RESULT_DROP;
/**
 * Install the generic data handlers on a static channel:
 * onBackData = pf_channel_generic_back_data,
 * onFrontData = pf_channel_generic_front_data.
 *
 * @param channel channel to configure (must not be NULL)
 * @return presumably TRUE (the return statement is not in this excerpt)
 */
295 BOOL pf_channel_setup_generic(pServerStaticChannelContext* channel)
297 WINPR_ASSERT(channel);
298 channel->onBackData = pf_channel_generic_back_data;
299 channel->onFrontData = pf_channel_generic_front_data;
/**
 * Set how channelTracker_update handles subsequent fragments
 * (CHANNEL_TRACKER_PEEK, _PASS or _DROP).
 * The return value is not visible here; presumably it is TRUE.
 */
303 BOOL channelTracker_setMode(ChannelStateTracker* tracker, ChannelTrackerMode mode)
305 WINPR_ASSERT(tracker);
306 tracker->mode = mode;
/** @return the tracker's current handling mode (tracker must not be NULL). */
310 ChannelTrackerMode channelTracker_getMode(ChannelStateTracker* tracker)
312 WINPR_ASSERT(tracker);
313 return tracker->mode;
/**
 * Attach proxy session data to the tracker; it can be read back with
 * channelTracker_getPData. The return value is not visible here;
 * presumably it is TRUE.
 */
316 BOOL channelTracker_setPData(ChannelStateTracker* tracker, proxyData* pdata)
318 WINPR_ASSERT(tracker);
319 tracker->pdata = pdata;
/** @return the proxy session data attached with channelTracker_setPData. */
323 proxyData* channelTracker_getPData(ChannelStateTracker* tracker)
325 WINPR_ASSERT(tracker);
326 return tracker->pdata;
/**
 * @return the stream holding the bytes accumulated for the current packet.
 *         The stream belongs to the tracker and is released by
 *         channelTracker_free.
 */
329 wStream* channelTracker_getCurrentPacket(ChannelStateTracker* tracker)
331 WINPR_ASSERT(tracker);
332 return tracker->currentPacket;
/**
 * Store an opaque user pointer on the tracker; it can be read back with
 * channelTracker_getCustomData. The return value is not visible here;
 * presumably it is TRUE.
 */
335 BOOL channelTracker_setCustomData(ChannelStateTracker* tracker,
void* data)
337 WINPR_ASSERT(tracker);
338 tracker->trackerData = data;
/** @return the opaque user pointer stored with channelTracker_setCustomData. */
342 void* channelTracker_getCustomData(ChannelStateTracker* tracker)
344 WINPR_ASSERT(tracker);
345 return tracker->trackerData;
/**
 * @return the total size recorded for the current packet (channelTracker_update
 *         sets it from totalSize on a CHANNEL_FLAG_FIRST fragment).
 */
348 size_t channelTracker_getCurrentPacketSize(ChannelStateTracker* tracker)
350 WINPR_ASSERT(tracker);
351 return tracker->currentPacketSize;
/**
 * Record the expected total size of the current packet.
 * The return value is not visible here; presumably it is TRUE.
 */
354 BOOL channelTracker_setCurrentPacketSize(ChannelStateTracker* tracker,
size_t size)
356 WINPR_ASSERT(tracker);
357 tracker->currentPacketSize = size;