This repository was archived by the owner on Aug 24, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathSSLSocket.h
More file actions
executable file
·339 lines (266 loc) · 12.9 KB
/
SSLSocket.h
File metadata and controls
executable file
·339 lines (266 loc) · 12.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
#ifndef SSL_SOCKET_H_
#define SSL_SOCKET_H_
#include "TCPSocket.h"
#include <wincrypt.h>
#define min(a, b) ((a) < (b) ? (a) : (b))
// C-linkage declarations for the libtls API used by SSLSocket below. They appear to mirror
// LibreSSL's public <tls.h> (API version TLS_API 20200120). Keep them in sync with the
// library actually linked - TODO confirm against its header whenever the library is upgraded.
extern "C" {
// MSVC has no ssize_t, so derive it from SSIZE_T. Skipped when LIBRESSL_INTERNAL is defined,
// presumably because the library's own build already provides it.
#ifdef _MSC_VER
#ifndef LIBRESSL_INTERNAL
#include <basetsd.h>
typedef SSIZE_T ssize_t;
#endif
#endif
#include <sys/types.h>
#include <stddef.h>
#include <stdint.h>
// libtls API version these declarations correspond to.
#define TLS_API 20200120
// Protocol-version bit flags, OR-ed together (presumably for tls_config_set_protocols).
// TLS_PROTOCOL_TLSv1 / TLS_PROTOCOLS_ALL cover 1.0-1.3; the default enables only 1.2 and 1.3.
#define TLS_PROTOCOL_TLSv1_0 (1 << 1)
#define TLS_PROTOCOL_TLSv1_1 (1 << 2)
#define TLS_PROTOCOL_TLSv1_2 (1 << 3)
#define TLS_PROTOCOL_TLSv1_3 (1 << 4)
#define TLS_PROTOCOL_TLSv1 (TLS_PROTOCOL_TLSv1_0 | TLS_PROTOCOL_TLSv1_1 | TLS_PROTOCOL_TLSv1_2 | TLS_PROTOCOL_TLSv1_3)
#define TLS_PROTOCOLS_ALL TLS_PROTOCOL_TLSv1
#define TLS_PROTOCOLS_DEFAULT (TLS_PROTOCOL_TLSv1_2 | TLS_PROTOCOL_TLSv1_3)
// Special negative return values meaning "retry the same call once the socket is readable
// (POLLIN) / writable (POLLOUT)". Per the libtls docs they can come from tls_read, tls_write,
// tls_handshake and tls_close - verify against the linked version.
#define TLS_WANT_POLLIN -2
#define TLS_WANT_POLLOUT -3
// OCSP response status codes.
/* RFC 6960 Section 2.3 */
#define TLS_OCSP_RESPONSE_SUCCESSFUL 0
#define TLS_OCSP_RESPONSE_MALFORMED 1
#define TLS_OCSP_RESPONSE_INTERNALERROR 2
#define TLS_OCSP_RESPONSE_TRYLATER 3
#define TLS_OCSP_RESPONSE_SIGREQUIRED 4
#define TLS_OCSP_RESPONSE_UNAUTHORIZED 5
// OCSP per-certificate status.
/* RFC 6960 Section 2.2 */
#define TLS_OCSP_CERT_GOOD 0
#define TLS_OCSP_CERT_REVOKED 1
#define TLS_OCSP_CERT_UNKNOWN 2
// CRL revocation reason codes (value 7 is intentionally absent; it is unused in the RFC).
/* RFC 5280 Section 5.3.1 */
#define TLS_CRL_REASON_UNSPECIFIED 0
#define TLS_CRL_REASON_KEY_COMPROMISE 1
#define TLS_CRL_REASON_CA_COMPROMISE 2
#define TLS_CRL_REASON_AFFILIATION_CHANGED 3
#define TLS_CRL_REASON_SUPERSEDED 4
#define TLS_CRL_REASON_CESSATION_OF_OPERATION 5
#define TLS_CRL_REASON_CERTIFICATE_HOLD 6
#define TLS_CRL_REASON_REMOVE_FROM_CRL 8
#define TLS_CRL_REASON_PRIVILEGE_WITHDRAWN 9
#define TLS_CRL_REASON_AA_COMPROMISE 10
// Size limits, presumably for tls_config_set_session_id and tls_config_add_ticket_key - verify.
#define TLS_MAX_SESSION_ID_LENGTH 32
#define TLS_TICKET_KEY_SIZE 48
// Opaque handles: `tls` is a connection (or listening) context, `tls_config` a configuration
// that can be applied to contexts with tls_configure().
struct tls;
struct tls_config;
// Transport callbacks for tls_connect_cbs / tls_accept_cbs; _cb_arg is presumably passed
// through unchanged from the registering call.
typedef ssize_t (*tls_read_cb)(struct tls* _ctx, void* _buf, size_t _buflen, void* _cb_arg);
typedef ssize_t (*tls_write_cb)(struct tls* _ctx, const void* _buf, size_t _buflen, void* _cb_arg);
// Library initialization and last-error strings. SSLSocket reads tls_error() after a -1 return.
int tls_init(void);
const char* tls_config_error(struct tls_config* _config);
const char* tls_error(struct tls* _ctx);
// tls_config lifetime and key/certificate material (file-path and in-memory variants).
struct tls_config* tls_config_new(void);
void tls_config_free(struct tls_config* _config);
const char* tls_default_ca_cert_file(void);
int tls_config_add_keypair_file(struct tls_config* _config, const char* _cert_file, const char* _key_file);
int tls_config_add_keypair_mem(struct tls_config* _config, const uint8_t* _cert, size_t _cert_len, const uint8_t* _key,
size_t _key_len);
int tls_config_add_keypair_ocsp_file(struct tls_config* _config, const char* _cert_file, const char* _key_file,
const char* _ocsp_staple_file);
int tls_config_add_keypair_ocsp_mem(struct tls_config* _config, const uint8_t* _cert, size_t _cert_len, const uint8_t* _key,
size_t _key_len, const uint8_t* _staple, size_t _staple_len);
// Individual configuration setters (ALPN, trust anchors, certificates, ciphers, CRLs, key
// exchange, protocols, verification depth). SSLSocket uses only tls_config_set_ca_mem.
int tls_config_set_alpn(struct tls_config* _config, const char* _alpn);
int tls_config_set_ca_file(struct tls_config* _config, const char* _ca_file);
int tls_config_set_ca_path(struct tls_config* _config, const char* _ca_path);
int tls_config_set_ca_mem(struct tls_config* _config, const uint8_t* _ca, size_t _len);
int tls_config_set_cert_file(struct tls_config* _config, const char* _cert_file);
int tls_config_set_cert_mem(struct tls_config* _config, const uint8_t* _cert, size_t _len);
int tls_config_set_ciphers(struct tls_config* _config, const char* _ciphers);
int tls_config_set_crl_file(struct tls_config* _config, const char* _crl_file);
int tls_config_set_crl_mem(struct tls_config* _config, const uint8_t* _crl, size_t _len);
int tls_config_set_dheparams(struct tls_config* _config, const char* _params);
int tls_config_set_ecdhecurve(struct tls_config* _config, const char* _curve);
int tls_config_set_ecdhecurves(struct tls_config* _config, const char* _curves);
int tls_config_set_key_file(struct tls_config* _config, const char* _key_file);
int tls_config_set_key_mem(struct tls_config* _config, const uint8_t* _key, size_t _len);
int tls_config_set_keypair_file(struct tls_config* _config, const char* _cert_file, const char* _key_file);
int tls_config_set_keypair_mem(struct tls_config* _config, const uint8_t* _cert, size_t _cert_len, const uint8_t* _key,
size_t _key_len);
int tls_config_set_keypair_ocsp_file(struct tls_config* _config, const char* _cert_file, const char* _key_file,
const char* _staple_file);
int tls_config_set_keypair_ocsp_mem(struct tls_config* _config, const uint8_t* _cert, size_t _cert_len, const uint8_t* _key,
size_t _key_len, const uint8_t* _staple, size_t staple_len);
int tls_config_set_ocsp_staple_mem(struct tls_config* _config, const uint8_t* _staple, size_t _len);
int tls_config_set_ocsp_staple_file(struct tls_config* _config, const char* _staple_file);
int tls_config_set_protocols(struct tls_config* _config, uint32_t _protocols);
int tls_config_set_session_fd(struct tls_config* _config, int _session_fd);
int tls_config_set_verify_depth(struct tls_config* _config, int _verify_depth);
// Behavior toggles. The insecure_* calls switch off parts of peer verification;
// tls_config_verify presumably switches them back on.
void tls_config_prefer_ciphers_client(struct tls_config* _config);
void tls_config_prefer_ciphers_server(struct tls_config* _config);
void tls_config_insecure_noverifycert(struct tls_config* _config);
void tls_config_insecure_noverifyname(struct tls_config* _config);
void tls_config_insecure_noverifytime(struct tls_config* _config);
void tls_config_verify(struct tls_config* _config);
void tls_config_ocsp_require_stapling(struct tls_config* _config);
void tls_config_verify_client(struct tls_config* _config);
void tls_config_verify_client_optional(struct tls_config* _config);
void tls_config_clear_keys(struct tls_config* _config);
// Parses a protocol-list string into TLS_PROTOCOL_* bits (string format per libtls docs).
int tls_config_parse_protocols(uint32_t* _protocols, const char* _protostr);
// Session-resumption settings (session ids, lifetimes, ticket keys).
int tls_config_set_session_id(struct tls_config* _config, const unsigned char* _session_id, size_t _len);
int tls_config_set_session_lifetime(struct tls_config* _config, int _lifetime);
int tls_config_add_ticket_key(struct tls_config* _config, uint32_t _keyrev, unsigned char* _key, size_t _keylen);
// Context lifecycle. SSLSocket creates one client context, calls tls_configure() on every
// connect() and tls_reset() on disconnect().
struct tls* tls_client(void);
struct tls* tls_server(void);
int tls_configure(struct tls* _ctx, struct tls_config* _config);
void tls_reset(struct tls* _ctx);
void tls_free(struct tls* _ctx);
// Server side: derive a per-connection context (*_cctx) from a listening context.
int tls_accept_fds(struct tls* _ctx, struct tls** _cctx, int _fd_read, int _fd_write);
int tls_accept_socket(struct tls* _ctx, struct tls** _cctx, int _socket);
int tls_accept_cbs(struct tls* _ctx, struct tls** _cctx, tls_read_cb _read_cb, tls_write_cb _write_cb, void* _cb_arg);
// Client-side connection setup. The host/port variants presumably open the TCP connection
// themselves; the _fds/_socket/_cbs variants attach to an existing transport.
// SSLSocket uses tls_connect_socket on a socket it has already connected.
int tls_connect(struct tls* _ctx, const char* _host, const char* _port);
int tls_connect_fds(struct tls* _ctx, int _fd_read, int _fd_write, const char* _servername);
int tls_connect_servername(struct tls* _ctx, const char* _host, const char* _port, const char* _servername);
int tls_connect_socket(struct tls* _ctx, int _s, const char* _servername);
int tls_connect_cbs(struct tls* _ctx, tls_read_cb _read_cb, tls_write_cb _write_cb, void* _cb_arg, const char* _servername);
// Handshake, data transfer and shutdown. These may return TLS_WANT_POLLIN / TLS_WANT_POLLOUT
// (defined above) instead of a byte count or status.
int tls_handshake(struct tls* _ctx);
ssize_t tls_read(struct tls* _ctx, void* _buf, size_t _buflen);
ssize_t tls_write(struct tls* _ctx, const void* _buf, size_t _buflen);
int tls_close(struct tls* _ctx);
// Peer-certificate inspection (presumably meaningful only after the handshake completes).
int tls_peer_cert_provided(struct tls* _ctx);
int tls_peer_cert_contains_name(struct tls* _ctx, const char* _name);
const char* tls_peer_cert_hash(struct tls* _ctx);
const char* tls_peer_cert_issuer(struct tls* _ctx);
const char* tls_peer_cert_subject(struct tls* _ctx);
time_t tls_peer_cert_notbefore(struct tls* _ctx);
time_t tls_peer_cert_notafter(struct tls* _ctx);
const uint8_t* tls_peer_cert_chain_pem(struct tls* _ctx, size_t* _len);
// Negotiated connection parameters.
const char* tls_conn_alpn_selected(struct tls* _ctx);
const char* tls_conn_cipher(struct tls* _ctx);
int tls_conn_cipher_strength(struct tls* _ctx);
const char* tls_conn_servername(struct tls* _ctx);
int tls_conn_session_resumed(struct tls* _ctx);
const char* tls_conn_version(struct tls* _ctx);
// File-loading helpers (buffers from tls_load_file are released with tls_unload_file).
uint8_t* tls_load_file(const char* _file, size_t* _len, char* _password);
void tls_unload_file(uint8_t* _buf, size_t len);
// OCSP response processing and stapled-response inspection.
int tls_ocsp_process_response(struct tls* _ctx, const unsigned char* _response, size_t _size);
int tls_peer_ocsp_cert_status(struct tls* _ctx);
int tls_peer_ocsp_crl_reason(struct tls* _ctx);
time_t tls_peer_ocsp_next_update(struct tls* _ctx);
int tls_peer_ocsp_response_status(struct tls* _ctx);
const char* tls_peer_ocsp_result(struct tls* _ctx);
time_t tls_peer_ocsp_revocation_time(struct tls* _ctx);
time_t tls_peer_ocsp_this_update(struct tls* _ctx);
const char* tls_peer_ocsp_url(struct tls* _ctx);
} // extern "C"
class SSLSocket : public TCPSocket {
static inline tls_config* tlsConf = 0;
static inline vector<uint8_t> certificates;
void platformInit() override {
TCPSocket::platformInit();
if(!tlsConf) tlsConf = tls_config_new();
// load all ca's
HCERTSTORE hStore = CertOpenSystemStoreA(0, "ROOT");
PCCERT_CONTEXT pContext = NULL;
while((pContext = CertEnumCertificatesInStore(hStore, pContext)) != NULL) {
DWORD length;
CryptBinaryToStringA(pContext->pbCertEncoded, pContext->cbCertEncoded, CRYPT_STRING_BASE64HEADER, nullptr, &length);
vector<uint8_t> buf(length);
CryptBinaryToStringA(pContext->pbCertEncoded, pContext->cbCertEncoded, CRYPT_STRING_BASE64HEADER, (LPSTR)buf.data(),
&length);
buf.push_back('\n');
certificates.insert(certificates.end(), buf.begin(), buf.end());
}
CertFreeCertificateContext(pContext);
CertCloseStore(hStore, 0);
tls_config_set_ca_mem(tlsConf, certificates.data(), certificates.size());
}
tls* context = 0;
uint8_t surgeBuffer[8192];
uint32_t surgeUsed = 0;
public:
bool connect(string host, uint16_t port) override {
platformInit();
disconnect();
uint8_t ip[4] = {0};
uint8_t results = sscanf(host.c_str(), "%3u.%3u.%3u.%3u", &ip[0], &ip[1], &ip[2], &ip[3]);
// isn't an ip address - parse dns
if(results != 4) {
hostent* dnsResults = gethostbyname(host.c_str());
if(dnsResults == NULL) return false;
set_ip_sockaddr_in(remote, *(uint32_t*)dnsResults->h_addr_list[0]);
} else {
set_ip_sockaddr_in(remote, *(uint32_t*)ip);
}
// Address family
remote.sin_family = AF_INET;
// Set port
remote.sin_port = htons(port);
socketFd = socket(AF_INET, SOCK_STREAM, 0);
if(!context) context = tls_client();
if(tls_configure(context, tlsConf) == -1) throw std::runtime_error(tls_error(context));
if(::connect(socketFd, (const sockaddr*)&remote, sizeof(remote)) < 0) return false;
if(tls_connect_socket(context, socketFd, host.c_str()) == -1) throw std::runtime_error(tls_error(context));
return true;
}
void disconnect() override {
remote = {0};
if(socketFd == 0) return;
closeSocket(socketFd);
tls_reset(context);
surgeUsed = 0;
}
void send(vector<uint8_t>& bytes) override {
int sentTotal = 0;
while(sentTotal < bytes.size()) {
if(!writeReady(15 * 1000)) throw TimeoutException("send timed out after 15s");
int sentBytes = tls_write(context, (char*)bytes.data() + sentTotal, bytes.size() - sentTotal);
if(sentBytes > 0) sentTotal += sentBytes;
else if(sentBytes == 0)
throw CloseException("socket was closed during send");
else
throw NetworkException("connection was aborted during send");
}
}
vector<uint8_t> receiveAvailable() override {
if(!readReady(15 * 1000)) throw TimeoutException("receive timed out after 15s");
int receivedBytes = tls_read(context, surgeBuffer + surgeUsed, 8192 - surgeUsed);
if(receivedBytes == 0) throw CloseException("socket is closed");
else if(receivedBytes < 0)
throw NetworkException("connection was aborted during receive");
receivedBytes += surgeUsed;
surgeUsed = 0;
return vector<uint8_t>(surgeBuffer, surgeBuffer + receivedBytes);
}
vector<uint8_t> receiveUntil(vector<uint8_t> byteSequence) override {
vector<uint8_t> buffer;
while(1) {
try {
auto available = receiveAvailable();
buffer.insert(buffer.end(), available.begin(), available.end());
if(buffer.size() < byteSequence.size()) continue;
auto iter = std::search(buffer.end() - available.size() - byteSequence.size() + 1, buffer.end(),
byteSequence.begin(), byteSequence.end());
if(iter != buffer.end()) {
int readSize = iter - (buffer.end() - available.size()) + byteSequence.size();
buffer.erase(iter + byteSequence.size(), buffer.end());
memcpy(surgeBuffer, available.data() + readSize,
available.size() - readSize); // remove until found from receive queue
surgeUsed = available.size() - readSize;
return buffer;
}
} catch(...) { throw; }
}
}
vector<uint8_t> receive(uint64_t amount) override {
vector<uint8_t> buffer(amount);
memcpy(buffer.data(), surgeBuffer, min(amount, surgeUsed));
int receivedTotal = min(amount, surgeUsed);
surgeUsed -= receivedTotal;
while(receivedTotal < amount) {
if(!readReady(15 * 1000)) throw TimeoutException("receive timed out after 15s");
int receivedBytes = tls_read(context, (char*)buffer.data() + receivedTotal, amount - receivedTotal);
if(receivedBytes > 0) receivedTotal += receivedBytes;
else if(receivedBytes == 0)
throw CloseException("socket was closed during receive");
else
throw NetworkException("connection was aborted during receive");
}
return buffer;
}
};
#endif