Skip to content

Commit 1da5d08

Browse files
committed
fix for spontaneous disconnects
1 parent 9865449 commit 1da5d08

2 files changed

Lines changed: 239 additions & 5 deletions

File tree

Source/WebSocket/hcwebsocket.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,13 @@ HRESULT WebSocket::ConnectAsync(
205205
HRESULT CALLBACK WebSocket::ConnectAsyncProvider(XAsyncOp op, XAsyncProviderData const* data)
206206
{
207207
ConnectContext* context{ static_cast<ConnectContext*>(data->context) };
208-
auto& ws{ context->observer->websocket };
209208

210209
switch (op)
211210
{
212211
case XAsyncOp::Begin:
213212
{
213+
assert(context->observer);
214+
auto& ws{ context->observer->websocket };
214215
std::unique_lock<std::mutex> lock{ ws->m_stateMutex };
215216

216217
RETURN_HR_IF(E_UNEXPECTED, ws->m_state != State::Initial);
@@ -255,11 +256,19 @@ void CALLBACK WebSocket::ConnectComplete(XAsyncBlock* async)
255256
ConnectContext* context{ static_cast<ConnectContext*>(async->context) };
256257
auto& ws{ context->observer->websocket };
257258

258-
assert(ws->m_state == State::Connecting);
259-
259+
// We can be put into the Disconnected state if a spontaneous error occurs between the connection process completing and this callback being invoked.
260+
// We need to be able to handle that scenario here.
260261
HRESULT hr = HCGetWebSocketConnectResult(&context->internalAsyncBlock, &context->result);
261-
262262
std::unique_lock<std::mutex> lock{ ws->m_stateMutex };
263+
const bool bIsDisconnected = (ws->m_state == State::Disconnected);
264+
if (bIsDisconnected && !FAILED(hr))
265+
{
266+
HC_TRACE_WARNING(WEBSOCKET, "WebSocket::ConnectComplete [%p] encountered a spontaneous disconnection. This implies the connection process was successful, but we otherwise had to close the connection (handle=%p)", ws.get(), context->observer.get());
267+
hr = E_FAIL;
268+
}
269+
270+
assert(ws->m_state == State::Connecting || bIsDisconnected);
271+
assert(context->observer.get() == context->result.websocket || FAILED(hr) || bIsDisconnected);
263272
if (SUCCEEDED(hr) && SUCCEEDED(context->result.errorCode))
264273
{
265274
// Connect was sucessful. Allocate ProviderContext to ensure WebSocket lifetime until it is reclaimed in WebSocket::CloseFunc
@@ -528,7 +537,10 @@ void CALLBACK WebSocket::CloseFunc(
528537
std::unique_lock<std::mutex> stateLock{ websocket->m_stateMutex };
529538
if (!websocket->m_providerContext)
530539
{
531-
HC_TRACE_ERROR(WEBSOCKET, "Unexpected call to WebSocket::CloseFunc will be ignored!");
540+
// It's possible for our websocket to get closed before we finish connecting. m_providerContext only gets populated when the connection process is 100% completed.
541+
// If we're still in the process of connecting, mark as disconnected and let ConnectComplete handle the cleanup.
542+
HC_TRACE_WARNING(WEBSOCKET, "Call to WebSocket::CloseFunc without providerContext. This means that we're aborting the connection process unexpectedly.");
543+
websocket->m_state = State::Disconnected;
532544
return;
533545
}
534546

Tests/UnitTests/Tests/WebsocketTests.cpp

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "DefineTestMacros.h"
88
#include "Utils.h"
99
#include "../global/global.h"
10+
#include "WebSocket/hcwebsocket.h"
1011
#include <httpClient/httpProvider.h>
1112

1213
#pragma warning(disable:4389)
@@ -135,6 +136,171 @@ HRESULT CALLBACK Test_Internal_HCWebSocketConnectAsync(
135136
});
136137
}
137138

139+
bool g_HCWebSocketConnectAndClose_Called = false;
140+
HRESULT CALLBACK Test_Internal_HCWebSocketConnectAsyncAndClose(
141+
_In_z_ PCSTR uri,
142+
_In_z_ PCSTR subProtocol,
143+
_In_ HCWebsocketHandle websocket,
144+
_Inout_ XAsyncBlock* asyncBlock,
145+
_In_opt_ void* context,
146+
_In_ HCPerformEnv env
147+
)
148+
{
149+
UNREFERENCED_PARAMETER(uri);
150+
UNREFERENCED_PARAMETER(subProtocol);
151+
UNREFERENCED_PARAMETER(context);
152+
UNREFERENCED_PARAMETER(env);
153+
154+
return XAsyncBegin(asyncBlock, websocket, nullptr, __FUNCTION__,
155+
[](XAsyncOp op, const XAsyncProviderData* data)
156+
{
157+
auto websocket = static_cast<HCWebsocketHandle>(data->context);
158+
159+
switch (op)
160+
{
161+
case XAsyncOp::Begin:
162+
{
163+
g_HCWebSocketConnectAndClose_Called = true;
164+
RETURN_IF_FAILED(XTaskQueueSubmitCallback(data->async->queue, XTaskQueuePort::Work, websocket,
165+
[](void* context, bool canceled)
166+
{
167+
if (canceled)
168+
{
169+
return;
170+
}
171+
172+
auto websocket = static_cast<HCWebsocketHandle>(context);
173+
HCWebSocketCloseEventFunction closeFunc = nullptr;
174+
void* closeContext = nullptr;
175+
HRESULT hr = HCWebSocketGetEventFunctions(websocket, nullptr, nullptr, &closeFunc, &closeContext);
176+
if (SUCCEEDED(hr) && closeFunc != nullptr)
177+
{
178+
closeFunc(websocket, HCWebSocketCloseStatus::AbnormalClose, closeContext);
179+
}
180+
}));
181+
182+
XAsyncComplete(data->async, S_OK, sizeof(WebSocketCompletionResult));
183+
return S_OK;
184+
}
185+
case XAsyncOp::GetResult:
186+
{
187+
RETURN_HR_IF(E_NOT_SUFFICIENT_BUFFER, data->bufferSize < sizeof(WebSocketCompletionResult));
188+
189+
auto result = static_cast<WebSocketCompletionResult*>(data->buffer);
190+
ZeroMemory(result, sizeof(WebSocketCompletionResult));
191+
result->errorCode = S_OK;
192+
result->websocket = websocket;
193+
return S_OK;
194+
}
195+
default:
196+
return S_OK;
197+
}
198+
});
199+
}
200+
201+
class TestWebSocketConnectAndCloseProvider : public IWebSocketProvider
202+
{
203+
public:
204+
HRESULT ConnectAsync(
205+
String const& uri,
206+
String const& subprotocol,
207+
HCWebsocketHandle websocketHandle,
208+
XAsyncBlock* async
209+
) noexcept override
210+
{
211+
UNREFERENCED_PARAMETER(uri);
212+
UNREFERENCED_PARAMETER(subprotocol);
213+
214+
m_websocket = websocketHandle;
215+
return XAsyncBegin(async, this, nullptr, __FUNCTION__,
216+
[](XAsyncOp op, const XAsyncProviderData* data)
217+
{
218+
auto provider = static_cast<TestWebSocketConnectAndCloseProvider*>(data->context);
219+
220+
switch (op)
221+
{
222+
case XAsyncOp::Begin:
223+
{
224+
provider->m_connectCalled = true;
225+
RETURN_IF_FAILED(XTaskQueueSubmitCallback(data->async->queue, XTaskQueuePort::Work, provider->m_websocket,
226+
[](void* context, bool canceled)
227+
{
228+
if (canceled)
229+
{
230+
return;
231+
}
232+
233+
auto websocket = static_cast<HCWebsocketHandle>(context);
234+
HCWebSocketCloseEventFunction closeFunc = nullptr;
235+
void* closeContext = nullptr;
236+
HRESULT hr = HCWebSocketGetEventFunctions(websocket, nullptr, nullptr, &closeFunc, &closeContext);
237+
if (SUCCEEDED(hr) && closeFunc != nullptr)
238+
{
239+
closeFunc(websocket, HCWebSocketCloseStatus::AbnormalClose, closeContext);
240+
}
241+
}));
242+
243+
XAsyncComplete(data->async, S_OK, sizeof(WebSocketCompletionResult));
244+
return S_OK;
245+
}
246+
case XAsyncOp::GetResult:
247+
{
248+
RETURN_HR_IF(E_NOT_SUFFICIENT_BUFFER, data->bufferSize < sizeof(WebSocketCompletionResult));
249+
250+
auto result = static_cast<WebSocketCompletionResult*>(data->buffer);
251+
ZeroMemory(result, sizeof(WebSocketCompletionResult));
252+
result->errorCode = S_OK;
253+
result->websocket = provider->m_websocket;
254+
return S_OK;
255+
}
256+
default:
257+
return S_OK;
258+
}
259+
});
260+
}
261+
262+
HRESULT SendAsync(
263+
HCWebsocketHandle websocketHandle,
264+
const char* message,
265+
XAsyncBlock* async
266+
) noexcept override
267+
{
268+
UNREFERENCED_PARAMETER(websocketHandle);
269+
UNREFERENCED_PARAMETER(message);
270+
UNREFERENCED_PARAMETER(async);
271+
return E_UNEXPECTED;
272+
}
273+
274+
HRESULT SendBinaryAsync(
275+
HCWebsocketHandle websocketHandle,
276+
const uint8_t* payloadBytes,
277+
uint32_t payloadSize,
278+
XAsyncBlock* async
279+
) noexcept override
280+
{
281+
UNREFERENCED_PARAMETER(websocketHandle);
282+
UNREFERENCED_PARAMETER(payloadBytes);
283+
UNREFERENCED_PARAMETER(payloadSize);
284+
UNREFERENCED_PARAMETER(async);
285+
return E_UNEXPECTED;
286+
}
287+
288+
HRESULT Disconnect(
289+
HCWebsocketHandle websocketHandle,
290+
HCWebSocketCloseStatus closeStatus
291+
) noexcept override
292+
{
293+
UNREFERENCED_PARAMETER(websocketHandle);
294+
UNREFERENCED_PARAMETER(closeStatus);
295+
return E_UNEXPECTED;
296+
}
297+
298+
bool m_connectCalled{ false };
299+
300+
private:
301+
HCWebsocketHandle m_websocket{ nullptr };
302+
};
303+
138304
bool g_HCWebSocketSendMessage_Called = false;
139305
HRESULT CALLBACK Test_Internal_HCWebSocketSendMessageAsync(
140306
_In_ HCWebsocketHandle websocket,
@@ -325,6 +491,62 @@ DEFINE_TEST_CLASS(WebsocketTests)
325491
HCCleanup();
326492
}
327493

494+
DEFINE_TEST_CASE(TestConnectFailsWhenDisconnectedDuringCompletion)
495+
{
496+
#ifdef UNITTEST_TE
497+
return;
498+
#else
499+
TestWebSocketConnectAndCloseProvider provider;
500+
auto websocket = std::make_shared<WebSocket>(1, provider);
501+
502+
XAsyncBlock asyncBlock{};
503+
VERIFY_SUCCEEDED(XTaskQueueCreate(XTaskQueueDispatchMode::Manual, XTaskQueueDispatchMode::Manual, &asyncBlock.queue));
504+
505+
VERIFY_ARE_EQUAL(S_OK, websocket->ConnectAsync(http_internal_string{ "test" }, http_internal_string{ "subProtoTest" }, &asyncBlock));
506+
WebSocketCompletionResult connectResult{};
507+
XAsyncBlock sendAsyncBlock{};
508+
HRESULT connectStatus = E_PENDING;
509+
HRESULT getConnectResultHr = E_PENDING;
510+
HRESULT sendHr = S_OK;
511+
HRESULT disconnectHr = S_OK;
512+
513+
for (uint32_t attempt = 0; attempt < 8 && connectStatus == E_PENDING; ++attempt)
514+
{
515+
auto timeout = attempt == 0 ? 100u : 0u;
516+
while (XTaskQueueDispatch(asyncBlock.queue, XTaskQueuePort::Work, timeout)) {}
517+
while (XTaskQueueDispatch(asyncBlock.queue, XTaskQueuePort::Completion, timeout)) {}
518+
connectStatus = XAsyncGetStatus(&asyncBlock, false);
519+
}
520+
521+
if (connectStatus == E_FAIL)
522+
{
523+
getConnectResultHr = HCGetWebSocketConnectResult(&asyncBlock, &connectResult);
524+
}
525+
else if (connectStatus == E_PENDING)
526+
{
527+
XAsyncCancel(&asyncBlock);
528+
while (XTaskQueueDispatch(asyncBlock.queue, XTaskQueuePort::Work, 0)) {}
529+
while (XTaskQueueDispatch(asyncBlock.queue, XTaskQueuePort::Completion, 0)) {}
530+
connectStatus = XAsyncGetStatus(&asyncBlock, false);
531+
}
532+
533+
getConnectResultHr = HCGetWebSocketConnectResult(&asyncBlock, &connectResult);
534+
sendHr = websocket->SendAsync("test", &sendAsyncBlock);
535+
disconnectHr = websocket->Disconnect();
536+
537+
while (XTaskQueueDispatch(asyncBlock.queue, XTaskQueuePort::Work, 0)) {}
538+
while (XTaskQueueDispatch(asyncBlock.queue, XTaskQueuePort::Completion, 0)) {}
539+
XTaskQueueCloseHandle(asyncBlock.queue);
540+
websocket.reset();
541+
542+
VERIFY_ARE_EQUAL(true, provider.m_connectCalled);
543+
VERIFY_IS_TRUE(FAILED(connectStatus));
544+
VERIFY_IS_TRUE(FAILED(getConnectResultHr));
545+
VERIFY_ARE_EQUAL(E_UNEXPECTED, sendHr);
546+
VERIFY_ARE_EQUAL(E_UNEXPECTED, disconnectHr);
547+
#endif
548+
}
549+
328550

329551
DEFINE_TEST_CASE(TestRequestHeaders)
330552
{

0 commit comments

Comments
 (0)