summaryrefslogtreecommitdiff
path: root/src/core/hle/service/ssl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/hle/service/ssl')
-rw-r--r--src/core/hle/service/ssl/ssl.cpp349
-rw-r--r--src/core/hle/service/ssl/ssl_backend.h44
-rw-r--r--src/core/hle/service/ssl/ssl_backend_none.cpp15
-rw-r--r--src/core/hle/service/ssl/ssl_backend_openssl.cpp342
-rw-r--r--src/core/hle/service/ssl/ssl_backend_schannel.cpp529
5 files changed, 1262 insertions, 17 deletions
diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp
index 2b99dd7ac..a3b54c7f0 100644
--- a/src/core/hle/service/ssl/ssl.cpp
+++ b/src/core/hle/service/ssl/ssl.cpp
@@ -1,10 +1,18 @@
1// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project 1// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later 2// SPDX-License-Identifier: GPL-2.0-or-later
3 3
4#include "common/string_util.h"
5
6#include "core/core.h"
4#include "core/hle/service/ipc_helpers.h" 7#include "core/hle/service/ipc_helpers.h"
5#include "core/hle/service/server_manager.h" 8#include "core/hle/service/server_manager.h"
6#include "core/hle/service/service.h" 9#include "core/hle/service/service.h"
10#include "core/hle/service/sm/sm.h"
11#include "core/hle/service/sockets/bsd.h"
7#include "core/hle/service/ssl/ssl.h" 12#include "core/hle/service/ssl/ssl.h"
13#include "core/hle/service/ssl/ssl_backend.h"
14#include "core/internal_network/network.h"
15#include "core/internal_network/sockets.h"
8 16
9namespace Service::SSL { 17namespace Service::SSL {
10 18
@@ -20,6 +28,18 @@ enum class ContextOption : u32 {
20 CrlImportDateCheckEnable = 1, 28 CrlImportDateCheckEnable = 1,
21}; 29};
22 30
31// This is nn::ssl::Connection::IoMode
32enum class IoMode : u32 {
33 Blocking = 1,
34 NonBlocking = 2,
35};
36
37// This is nn::ssl::sf::OptionType
38enum class OptionType : u32 {
39 DoNotCloseSocket = 0,
40 GetServerCertChain = 1,
41};
42
23// This is nn::ssl::sf::SslVersion 43// This is nn::ssl::sf::SslVersion
24struct SslVersion { 44struct SslVersion {
25 union { 45 union {
@@ -34,35 +54,42 @@ struct SslVersion {
34 }; 54 };
35}; 55};
36 56
57struct SslContextSharedData {
58 u32 connection_count = 0;
59};
60
37class ISslConnection final : public ServiceFramework<ISslConnection> { 61class ISslConnection final : public ServiceFramework<ISslConnection> {
38public: 62public:
39 explicit ISslConnection(Core::System& system_, SslVersion version) 63 explicit ISslConnection(Core::System& system_, SslVersion version,
40 : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} { 64 std::shared_ptr<SslContextSharedData>& shared_data,
65 std::unique_ptr<SSLConnectionBackend>&& backend)
66 : ServiceFramework{system_, "ISslConnection"}, ssl_version{version},
67 shared_data_{shared_data}, backend_{std::move(backend)} {
41 // clang-format off 68 // clang-format off
42 static const FunctionInfo functions[] = { 69 static const FunctionInfo functions[] = {
43 {0, nullptr, "SetSocketDescriptor"}, 70 {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"},
44 {1, nullptr, "SetHostName"}, 71 {1, &ISslConnection::SetHostName, "SetHostName"},
45 {2, nullptr, "SetVerifyOption"}, 72 {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"},
46 {3, nullptr, "SetIoMode"}, 73 {3, &ISslConnection::SetIoMode, "SetIoMode"},
47 {4, nullptr, "GetSocketDescriptor"}, 74 {4, nullptr, "GetSocketDescriptor"},
48 {5, nullptr, "GetHostName"}, 75 {5, nullptr, "GetHostName"},
49 {6, nullptr, "GetVerifyOption"}, 76 {6, nullptr, "GetVerifyOption"},
50 {7, nullptr, "GetIoMode"}, 77 {7, nullptr, "GetIoMode"},
51 {8, nullptr, "DoHandshake"}, 78 {8, &ISslConnection::DoHandshake, "DoHandshake"},
52 {9, nullptr, "DoHandshakeGetServerCert"}, 79 {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"},
53 {10, nullptr, "Read"}, 80 {10, &ISslConnection::Read, "Read"},
54 {11, nullptr, "Write"}, 81 {11, &ISslConnection::Write, "Write"},
55 {12, nullptr, "Pending"}, 82 {12, &ISslConnection::Pending, "Pending"},
56 {13, nullptr, "Peek"}, 83 {13, nullptr, "Peek"},
57 {14, nullptr, "Poll"}, 84 {14, nullptr, "Poll"},
58 {15, nullptr, "GetVerifyCertError"}, 85 {15, nullptr, "GetVerifyCertError"},
59 {16, nullptr, "GetNeededServerCertBufferSize"}, 86 {16, nullptr, "GetNeededServerCertBufferSize"},
60 {17, nullptr, "SetSessionCacheMode"}, 87 {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"},
61 {18, nullptr, "GetSessionCacheMode"}, 88 {18, nullptr, "GetSessionCacheMode"},
62 {19, nullptr, "FlushSessionCache"}, 89 {19, nullptr, "FlushSessionCache"},
63 {20, nullptr, "SetRenegotiationMode"}, 90 {20, nullptr, "SetRenegotiationMode"},
64 {21, nullptr, "GetRenegotiationMode"}, 91 {21, nullptr, "GetRenegotiationMode"},
65 {22, nullptr, "SetOption"}, 92 {22, &ISslConnection::SetOption, "SetOption"},
66 {23, nullptr, "GetOption"}, 93 {23, nullptr, "GetOption"},
67 {24, nullptr, "GetVerifyCertErrors"}, 94 {24, nullptr, "GetVerifyCertErrors"},
68 {25, nullptr, "GetCipherInfo"}, 95 {25, nullptr, "GetCipherInfo"},
@@ -80,21 +107,295 @@ public:
80 // clang-format on 107 // clang-format on
81 108
82 RegisterHandlers(functions); 109 RegisterHandlers(functions);
110
111 shared_data->connection_count++;
112 }
113
114 ~ISslConnection() {
115 shared_data_->connection_count--;
116 if (fd_to_close_.has_value()) {
117 s32 fd = *fd_to_close_;
118 if (!do_not_close_socket_) {
119 LOG_ERROR(Service_SSL,
120 "do_not_close_socket was changed after setting socket; is this right?");
121 } else {
122 auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
123 if (bsd) {
124 auto err = bsd->CloseImpl(fd);
125 if (err != Service::Sockets::Errno::SUCCESS) {
126 LOG_ERROR(Service_SSL, "failed to close duplicated socket: {}", err);
127 }
128 }
129 }
130 }
83 } 131 }
84 132
85private: 133private:
86 SslVersion ssl_version; 134 SslVersion ssl_version;
135 std::shared_ptr<SslContextSharedData> shared_data_;
136 std::unique_ptr<SSLConnectionBackend> backend_;
137 std::optional<int> fd_to_close_;
138 bool do_not_close_socket_ = false;
139 bool get_server_cert_chain_ = false;
140 std::shared_ptr<Network::SocketBase> socket_;
141 bool did_set_host_name_ = false;
142 bool did_handshake_ = false;
143
144 ResultVal<s32> SetSocketDescriptorImpl(s32 fd) {
145 LOG_DEBUG(Service_SSL, "called, fd={}", fd);
146 ASSERT(!did_handshake_);
147 auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
148 ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; });
149 s32 ret_fd;
150 // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor
151 if (do_not_close_socket_) {
152 auto res = bsd->DuplicateSocketImpl(fd);
153 if (!res.has_value()) {
154 LOG_ERROR(Service_SSL, "failed to duplicate socket");
155 return ResultInvalidSocket;
156 }
157 fd = *res;
158 fd_to_close_ = fd;
159 ret_fd = fd;
160 } else {
161 ret_fd = -1;
162 }
163 std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd);
164 if (!sock.has_value()) {
165 LOG_ERROR(Service_SSL, "invalid socket fd {}", fd);
166 return ResultInvalidSocket;
167 }
168 socket_ = std::move(*sock);
169 backend_->SetSocket(socket_);
170 return ret_fd;
171 }
172
173 Result SetHostNameImpl(const std::string& hostname) {
174 LOG_DEBUG(Service_SSL, "SetHostNameImpl({})", hostname);
175 ASSERT(!did_handshake_);
176 Result res = backend_->SetHostName(hostname);
177 if (res == ResultSuccess) {
178 did_set_host_name_ = true;
179 }
180 return res;
181 }
182
183 Result SetVerifyOptionImpl(u32 option) {
184 ASSERT(!did_handshake_);
185 LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option);
186 return ResultSuccess;
187 }
188
189 Result SetIOModeImpl(u32 _mode) {
190 auto mode = static_cast<IoMode>(_mode);
191 ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking);
192 ASSERT_OR_EXECUTE(socket_, { return ResultNoSocket; });
193
194 bool non_block = mode == IoMode::NonBlocking;
195 Network::Errno e = socket_->SetNonBlock(non_block);
196 if (e != Network::Errno::SUCCESS) {
197 LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block);
198 }
199 return ResultSuccess;
200 }
201
202 Result SetSessionCacheModeImpl(u32 mode) {
203 ASSERT(!did_handshake_);
204 LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode);
205 return ResultSuccess;
206 }
207
208 Result DoHandshakeImpl() {
209 ASSERT_OR_EXECUTE(!did_handshake_ && socket_, { return ResultNoSocket; });
210 ASSERT_OR_EXECUTE_MSG(
211 did_set_host_name_, { return ResultInternalError; },
212 "Expected SetHostName before DoHandshake");
213 Result res = backend_->DoHandshake();
214 did_handshake_ = res.IsSuccess();
215 return res;
216 }
217
218 std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) {
219 struct Header {
220 u64 magic;
221 u32 count;
222 u32 pad;
223 };
224 struct EntryHeader {
225 u32 size;
226 u32 offset;
227 };
228 if (!get_server_cert_chain_) {
229 // Just return the first one, unencoded.
230 ASSERT_OR_EXECUTE_MSG(
231 !certs.empty(), { return {}; }, "Should be at least one server cert");
232 return certs[0];
233 }
234 std::vector<u8> ret;
235 Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0};
236 ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1));
237 size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader);
238 for (auto& cert : certs) {
239 EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)};
240 data_offset += cert.size();
241 ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header),
242 reinterpret_cast<u8*>(&entry_header + 1));
243 }
244 for (auto& cert : certs) {
245 ret.insert(ret.end(), cert.begin(), cert.end());
246 }
247 return ret;
248 }
249
250 ResultVal<std::vector<u8>> ReadImpl(size_t size) {
251 ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; });
252 std::vector<u8> res(size);
253 ResultVal<size_t> actual = backend_->Read(res);
254 if (actual.Failed()) {
255 return actual.Code();
256 }
257 res.resize(*actual);
258 return res;
259 }
260
261 ResultVal<size_t> WriteImpl(std::span<const u8> data) {
262 ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; });
263 return backend_->Write(data);
264 }
265
266 ResultVal<s32> PendingImpl() {
267 LOG_WARNING(Service_SSL, "(STUBBED) called.");
268 return 0;
269 }
270
271 void SetSocketDescriptor(HLERequestContext& ctx) {
272 IPC::RequestParser rp{ctx};
273 const s32 fd = rp.Pop<s32>();
274 const ResultVal<s32> res = SetSocketDescriptorImpl(fd);
275 IPC::ResponseBuilder rb{ctx, 3};
276 rb.Push(res.Code());
277 rb.Push<s32>(res.ValueOr(-1));
278 }
279
280 void SetHostName(HLERequestContext& ctx) {
281 const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer());
282 const Result res = SetHostNameImpl(hostname);
283 IPC::ResponseBuilder rb{ctx, 2};
284 rb.Push(res);
285 }
286
287 void SetVerifyOption(HLERequestContext& ctx) {
288 IPC::RequestParser rp{ctx};
289 const u32 option = rp.Pop<u32>();
290 const Result res = SetVerifyOptionImpl(option);
291 IPC::ResponseBuilder rb{ctx, 2};
292 rb.Push(res);
293 }
294
295 void SetIoMode(HLERequestContext& ctx) {
296 IPC::RequestParser rp{ctx};
297 const u32 mode = rp.Pop<u32>();
298 const Result res = SetIOModeImpl(mode);
299 IPC::ResponseBuilder rb{ctx, 2};
300 rb.Push(res);
301 }
302
303 void DoHandshake(HLERequestContext& ctx) {
304 const Result res = DoHandshakeImpl();
305 IPC::ResponseBuilder rb{ctx, 2};
306 rb.Push(res);
307 }
308
309 void DoHandshakeGetServerCert(HLERequestContext& ctx) {
310 Result res = DoHandshakeImpl();
311 u32 certs_count = 0;
312 u32 certs_size = 0;
313 if (res == ResultSuccess) {
314 auto certs = backend_->GetServerCerts();
315 if (certs.Succeeded()) {
316 std::vector<u8> certs_buf = SerializeServerCerts(*certs);
317 ctx.WriteBuffer(certs_buf);
318 certs_count = static_cast<u32>(certs->size());
319 certs_size = static_cast<u32>(certs_buf.size());
320 }
321 }
322 IPC::ResponseBuilder rb{ctx, 4};
323 rb.Push(res);
324 rb.Push(certs_size);
325 rb.Push(certs_count);
326 }
327
328 void Read(HLERequestContext& ctx) {
329 const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize());
330 IPC::ResponseBuilder rb{ctx, 3};
331 rb.Push(res.Code());
332 if (res.Succeeded()) {
333 rb.Push(static_cast<u32>(res->size()));
334 ctx.WriteBuffer(*res);
335 } else {
336 rb.Push(static_cast<u32>(0));
337 }
338 }
339
340 void Write(HLERequestContext& ctx) {
341 const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer());
342 IPC::ResponseBuilder rb{ctx, 3};
343 rb.Push(res.Code());
344 rb.Push(static_cast<u32>(res.ValueOr(0)));
345 }
346
347 void Pending(HLERequestContext& ctx) {
348 const ResultVal<s32> res = PendingImpl();
349 IPC::ResponseBuilder rb{ctx, 3};
350 rb.Push(res.Code());
351 rb.Push<s32>(res.ValueOr(0));
352 }
353
354 void SetSessionCacheMode(HLERequestContext& ctx) {
355 IPC::RequestParser rp{ctx};
356 const u32 mode = rp.Pop<u32>();
357 const Result res = SetSessionCacheModeImpl(mode);
358 IPC::ResponseBuilder rb{ctx, 2};
359 rb.Push(res);
360 }
361
362 void SetOption(HLERequestContext& ctx) {
363 struct Parameters {
364 OptionType option;
365 s32 value;
366 };
367 static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size");
368
369 IPC::RequestParser rp{ctx};
370 const auto parameters = rp.PopRaw<Parameters>();
371
372 switch (parameters.option) {
373 case OptionType::DoNotCloseSocket:
374 do_not_close_socket_ = static_cast<bool>(parameters.value);
375 break;
376 case OptionType::GetServerCertChain:
377 get_server_cert_chain_ = static_cast<bool>(parameters.value);
378 break;
379 default:
380 LOG_WARNING(Service_SSL, "unrecognized option={}, value={}", parameters.option,
381 parameters.value);
382 }
383
384 IPC::ResponseBuilder rb{ctx, 2};
385 rb.Push(ResultSuccess);
386 }
87}; 387};
88 388
89class ISslContext final : public ServiceFramework<ISslContext> { 389class ISslContext final : public ServiceFramework<ISslContext> {
90public: 390public:
91 explicit ISslContext(Core::System& system_, SslVersion version) 391 explicit ISslContext(Core::System& system_, SslVersion version)
92 : ServiceFramework{system_, "ISslContext"}, ssl_version{version} { 392 : ServiceFramework{system_, "ISslContext"}, ssl_version{version},
393 shared_data_{std::make_shared<SslContextSharedData>()} {
93 static const FunctionInfo functions[] = { 394 static const FunctionInfo functions[] = {
94 {0, &ISslContext::SetOption, "SetOption"}, 395 {0, &ISslContext::SetOption, "SetOption"},
95 {1, nullptr, "GetOption"}, 396 {1, nullptr, "GetOption"},
96 {2, &ISslContext::CreateConnection, "CreateConnection"}, 397 {2, &ISslContext::CreateConnection, "CreateConnection"},
97 {3, nullptr, "GetConnectionCount"}, 398 {3, &ISslContext::GetConnectionCount, "GetConnectionCount"},
98 {4, &ISslContext::ImportServerPki, "ImportServerPki"}, 399 {4, &ISslContext::ImportServerPki, "ImportServerPki"},
99 {5, &ISslContext::ImportClientPki, "ImportClientPki"}, 400 {5, &ISslContext::ImportClientPki, "ImportClientPki"},
100 {6, nullptr, "RemoveServerPki"}, 401 {6, nullptr, "RemoveServerPki"},
@@ -111,6 +412,7 @@ public:
111 412
112private: 413private:
113 SslVersion ssl_version; 414 SslVersion ssl_version;
415 std::shared_ptr<SslContextSharedData> shared_data_;
114 416
115 void SetOption(HLERequestContext& ctx) { 417 void SetOption(HLERequestContext& ctx) {
116 struct Parameters { 418 struct Parameters {
@@ -130,11 +432,24 @@ private:
130 } 432 }
131 433
132 void CreateConnection(HLERequestContext& ctx) { 434 void CreateConnection(HLERequestContext& ctx) {
133 LOG_WARNING(Service_SSL, "(STUBBED) called"); 435 LOG_WARNING(Service_SSL, "called");
436
437 auto backend_res = CreateSSLConnectionBackend();
134 438
135 IPC::ResponseBuilder rb{ctx, 2, 0, 1}; 439 IPC::ResponseBuilder rb{ctx, 2, 0, 1};
440 rb.Push(backend_res.Code());
441 if (backend_res.Succeeded()) {
442 rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data_,
443 std::move(*backend_res));
444 }
445 }
446
447 void GetConnectionCount(HLERequestContext& ctx) {
448 LOG_WARNING(Service_SSL, "connection_count={}", shared_data_->connection_count);
449
450 IPC::ResponseBuilder rb{ctx, 3};
136 rb.Push(ResultSuccess); 451 rb.Push(ResultSuccess);
137 rb.PushIpcInterface<ISslConnection>(system, ssl_version); 452 rb.Push(shared_data_->connection_count);
138 } 453 }
139 454
140 void ImportServerPki(HLERequestContext& ctx) { 455 void ImportServerPki(HLERequestContext& ctx) {
diff --git a/src/core/hle/service/ssl/ssl_backend.h b/src/core/hle/service/ssl/ssl_backend.h
new file mode 100644
index 000000000..624e07d41
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend.h
@@ -0,0 +1,44 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#pragma once
5
6#include "core/hle/result.h"
7
8#include "common/common_types.h"
9
10#include <memory>
11#include <span>
12#include <string>
13#include <vector>
14
15namespace Network {
16class SocketBase;
17}
18
19namespace Service::SSL {
20
21constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103};
22constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106};
23constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205};
24constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up
25
26constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204};
27// ^ ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake,
28// with no way in the latter case to distinguish whether the client should poll
29// for read or write. The one official client I've seen handles this by always
30// polling for read (with a timeout).
31
32class SSLConnectionBackend {
33public:
34 virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0;
35 virtual Result SetHostName(const std::string& hostname) = 0;
36 virtual Result DoHandshake() = 0;
37 virtual ResultVal<size_t> Read(std::span<u8> data) = 0;
38 virtual ResultVal<size_t> Write(std::span<const u8> data) = 0;
39 virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0;
40};
41
42ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend();
43
44} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_none.cpp b/src/core/hle/service/ssl/ssl_backend_none.cpp
new file mode 100644
index 000000000..eb01561e2
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_none.cpp
@@ -0,0 +1,15 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#include "core/hle/service/ssl/ssl_backend.h"
5
6#include "common/logging/log.h"
7
8namespace Service::SSL {
9
10ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
11 LOG_ERROR(Service_SSL, "No SSL backend on this platform");
12 return ResultInternalError;
13}
14
15} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_openssl.cpp b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
new file mode 100644
index 000000000..cf9b904ac
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
@@ -0,0 +1,342 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#include "core/hle/service/ssl/ssl_backend.h"
5#include "core/internal_network/network.h"
6#include "core/internal_network/sockets.h"
7
8#include "common/fs/file.h"
9#include "common/hex_util.h"
10#include "common/string_util.h"
11
12#include <mutex>
13
14#include <openssl/bio.h>
15#include <openssl/err.h>
16#include <openssl/ssl.h>
17#include <openssl/x509.h>
18
19using namespace Common::FS;
20
21namespace Service::SSL {
22
23// Import OpenSSL's `SSL` type into the namespace. This is needed because the
24// namespace is also named `SSL`.
25using ::SSL;
26
27namespace {
28
29std::once_flag one_time_init_flag;
30bool one_time_init_success = false;
31
32SSL_CTX* ssl_ctx;
33IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment
34BIO_METHOD* bio_meth;
35
36Result CheckOpenSSLErrors();
37void OneTimeInit();
38void OneTimeInitLogFile();
39bool OneTimeInitBIO();
40
41} // namespace
42
43class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend {
44public:
45 Result Init() {
46 std::call_once(one_time_init_flag, OneTimeInit);
47
48 if (!one_time_init_success) {
49 LOG_ERROR(Service_SSL,
50 "Can't create SSL connection because OpenSSL one-time initialization failed");
51 return ResultInternalError;
52 }
53
54 ssl_ = SSL_new(ssl_ctx);
55 if (!ssl_) {
56 LOG_ERROR(Service_SSL, "SSL_new failed");
57 return CheckOpenSSLErrors();
58 }
59
60 SSL_set_connect_state(ssl_);
61
62 bio_ = BIO_new(bio_meth);
63 if (!bio_) {
64 LOG_ERROR(Service_SSL, "BIO_new failed");
65 return CheckOpenSSLErrors();
66 }
67
68 BIO_set_data(bio_, this);
69 BIO_set_init(bio_, 1);
70 SSL_set_bio(ssl_, bio_, bio_);
71
72 return ResultSuccess;
73 }
74
75 void SetSocket(std::shared_ptr<Network::SocketBase> socket) override {
76 socket_ = socket;
77 }
78
79 Result SetHostName(const std::string& hostname) override {
80 if (!SSL_set1_host(ssl_, hostname.c_str())) { // hostname for verification
81 LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname);
82 return CheckOpenSSLErrors();
83 }
84 if (!SSL_set_tlsext_host_name(ssl_, hostname.c_str())) { // hostname for SNI
85 LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname);
86 return CheckOpenSSLErrors();
87 }
88 return ResultSuccess;
89 }
90
91 Result DoHandshake() override {
92 SSL_set_verify_result(ssl_, X509_V_OK);
93 int ret = SSL_do_handshake(ssl_);
94 long verify_result = SSL_get_verify_result(ssl_);
95 if (verify_result != X509_V_OK) {
96 LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}",
97 X509_verify_cert_error_string(verify_result));
98 return CheckOpenSSLErrors();
99 }
100 if (ret <= 0) {
101 int ssl_err = SSL_get_error(ssl_, ret);
102 if (ssl_err == SSL_ERROR_ZERO_RETURN ||
103 (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_)) {
104 LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
105 return ResultInternalError;
106 }
107 }
108 return HandleReturn("SSL_do_handshake", 0, ret).Code();
109 }
110
111 ResultVal<size_t> Read(std::span<u8> data) override {
112 size_t actual;
113 int ret = SSL_read_ex(ssl_, data.data(), data.size(), &actual);
114 return HandleReturn("SSL_read_ex", actual, ret);
115 }
116
117 ResultVal<size_t> Write(std::span<const u8> data) override {
118 size_t actual;
119 int ret = SSL_write_ex(ssl_, data.data(), data.size(), &actual);
120 return HandleReturn("SSL_write_ex", actual, ret);
121 }
122
123 ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) {
124 int ssl_err = SSL_get_error(ssl_, ret);
125 CheckOpenSSLErrors();
126 switch (ssl_err) {
127 case SSL_ERROR_NONE:
128 return actual;
129 case SSL_ERROR_ZERO_RETURN:
130 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what);
131 // DoHandshake special-cases this, but for Read and Write:
132 return size_t(0);
133 case SSL_ERROR_WANT_READ:
134 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what);
135 return ResultWouldBlock;
136 case SSL_ERROR_WANT_WRITE:
137 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what);
138 return ResultWouldBlock;
139 default:
140 if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_) {
141 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what);
142 return size_t(0);
143 }
144 LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err);
145 return ResultInternalError;
146 }
147 }
148
149 ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
150 STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl_);
151 if (!chain) {
152 LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr");
153 return ResultInternalError;
154 }
155 std::vector<std::vector<u8>> ret;
156 int count = sk_X509_num(chain);
157 ASSERT(count >= 0);
158 for (int i = 0; i < count; i++) {
159 X509* x509 = sk_X509_value(chain, i);
160 ASSERT_OR_EXECUTE(x509 != nullptr, { continue; });
161 unsigned char* buf = nullptr;
162 int len = i2d_X509(x509, &buf);
163 ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; });
164 ret.emplace_back(buf, buf + len);
165 OPENSSL_free(buf);
166 }
167 return ret;
168 }
169
170 ~SSLConnectionBackendOpenSSL() {
171 // these are null-tolerant:
172 SSL_free(ssl_);
173 BIO_free(bio_);
174 }
175
176 static void KeyLogCallback(const SSL* ssl, const char* line) {
177 std::string str(line);
178 str.push_back('\n');
179 // Do this in a single WriteString for atomicity if multiple instances
180 // are running on different threads (though that can't currently
181 // happen).
182 if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) {
183 LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE");
184 }
185 LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line);
186 }
187
188 static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) {
189 auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
190 ASSERT_OR_EXECUTE_MSG(
191 self->socket_, { return 0; }, "OpenSSL asked to send but we have no socket");
192 BIO_clear_retry_flags(bio);
193 auto [actual, err] = self->socket_->Send({reinterpret_cast<const u8*>(buf), len}, 0);
194 switch (err) {
195 case Network::Errno::SUCCESS:
196 *actual_p = actual;
197 return 1;
198 case Network::Errno::AGAIN:
199 BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY);
200 return 0;
201 default:
202 LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
203 return -1;
204 }
205 }
206
207 static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) {
208 auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
209 ASSERT_OR_EXECUTE_MSG(
210 self->socket_, { return 0; }, "OpenSSL asked to recv but we have no socket");
211 BIO_clear_retry_flags(bio);
212 auto [actual, err] = self->socket_->Recv(0, {reinterpret_cast<u8*>(buf), len});
213 switch (err) {
214 case Network::Errno::SUCCESS:
215 *actual_p = actual;
216 if (actual == 0) {
217 self->got_read_eof_ = true;
218 }
219 return actual ? 1 : 0;
220 case Network::Errno::AGAIN:
221 BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY);
222 return 0;
223 default:
224 LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
225 return -1;
226 }
227 }
228
229 static long CtrlCallback(BIO* bio, int cmd, long larg, void* parg) {
230 switch (cmd) {
231 case BIO_CTRL_FLUSH:
232 // Nothing to flush.
233 return 1;
234 case BIO_CTRL_PUSH:
235 case BIO_CTRL_POP:
236 case BIO_CTRL_GET_KTLS_SEND:
237 case BIO_CTRL_GET_KTLS_RECV:
238 // We don't support these operations, but don't bother logging them
239 // as they're nothing unusual.
240 return 0;
241 default:
242 LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, larg, parg);
243 return 0;
244 }
245 }
246
247 SSL* ssl_ = nullptr;
248 BIO* bio_ = nullptr;
249 bool got_read_eof_ = false;
250
251 std::shared_ptr<Network::SocketBase> socket_;
252};
253
254ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
255 auto conn = std::make_unique<SSLConnectionBackendOpenSSL>();
256 Result res = conn->Init();
257 if (res.IsFailure()) {
258 return res;
259 }
260 return conn;
261}
262
263namespace {
264
265Result CheckOpenSSLErrors() {
266 unsigned long rc;
267 const char* file;
268 int line;
269 const char* func;
270 const char* data;
271 int flags;
272 while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags))) {
273 std::string msg;
274 msg.resize(1024, '\0');
275 ERR_error_string_n(rc, msg.data(), msg.size());
276 msg.resize(strlen(msg.data()), '\0');
277 if (flags & ERR_TXT_STRING) {
278 msg.append(" | ");
279 msg.append(data);
280 }
281 Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error,
282 Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}",
283 msg);
284 }
285 return ResultInternalError;
286}
287
288void OneTimeInit() {
289 ssl_ctx = SSL_CTX_new(TLS_client_method());
290 if (!ssl_ctx) {
291 LOG_ERROR(Service_SSL, "SSL_CTX_new failed");
292 CheckOpenSSLErrors();
293 return;
294 }
295
296 SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr);
297
298 if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) {
299 LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed");
300 CheckOpenSSLErrors();
301 return;
302 }
303
304 OneTimeInitLogFile();
305
306 if (!OneTimeInitBIO()) {
307 return;
308 }
309
310 one_time_init_success = true;
311}
312
313void OneTimeInitLogFile() {
314 const char* logfile = getenv("SSLKEYLOGFILE");
315 if (logfile) {
316 key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile,
317 FileShareFlag::ShareWriteOnly);
318 if (key_log_file.IsOpen()) {
319 SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback);
320 } else {
321 LOG_CRITICAL(Service_SSL,
322 "SSLKEYLOGFILE was set but file could not be opened; not logging keys!");
323 }
324 }
325}
326
327bool OneTimeInitBIO() {
328 bio_meth =
329 BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL");
330 if (!bio_meth ||
331 !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) ||
332 !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) ||
333 !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) {
334 LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD");
335 return false;
336 }
337 return true;
338}
339
340} // namespace
341
342} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
new file mode 100644
index 000000000..0a326b536
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
@@ -0,0 +1,529 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#include "core/hle/service/ssl/ssl_backend.h"
5#include "core/internal_network/network.h"
6#include "core/internal_network/sockets.h"
7
8#include "common/error.h"
9#include "common/fs/file.h"
10#include "common/hex_util.h"
11#include "common/string_util.h"
12
13#include <mutex>
14
15#define SECURITY_WIN32
16#include <Security.h>
17#include <schnlsp.h>
18
19namespace {
20
21std::once_flag one_time_init_flag;
22bool one_time_init_success = false;
23
24SCHANNEL_CRED schannel_cred{
25 .dwVersion = SCHANNEL_CRED_VERSION,
26 .dwFlags = SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols
27 SCH_CRED_AUTO_CRED_VALIDATION | // validate certs
28 SCH_CRED_NO_DEFAULT_CREDS, // don't automatically present a client certificate
29 // ^ I'm assuming that nobody would want to connect Yuzu to a
30 // service that requires some OS-provided corporate client
31 // certificate, and presenting one to some arbitrary server
32 // might be a privacy concern? Who knows, though.
33};
34
35CredHandle cred_handle;
36
37static void OneTimeInit() {
38 SECURITY_STATUS ret =
39 AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND,
40 nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr);
41 if (ret != SEC_E_OK) {
42 // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString.
43 LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}",
44 Common::NativeErrorToString(ret));
45 return;
46 }
47
48 one_time_init_success = true;
49}
50
51} // namespace
52
53namespace Service::SSL {
54
55class SSLConnectionBackendSchannel final : public SSLConnectionBackend {
56public:
57 Result Init() {
58 std::call_once(one_time_init_flag, OneTimeInit);
59
60 if (!one_time_init_success) {
61 LOG_ERROR(
62 Service_SSL,
63 "Can't create SSL connection because Schannel one-time initialization failed");
64 return ResultInternalError;
65 }
66
67 return ResultSuccess;
68 }
69
70 void SetSocket(std::shared_ptr<Network::SocketBase> socket) override {
71 socket_ = socket;
72 }
73
74 Result SetHostName(const std::string& hostname) override {
75 hostname_ = hostname;
76 return ResultSuccess;
77 }
78
79 Result DoHandshake() override {
80 while (1) {
81 Result r;
82 switch (handshake_state_) {
83 case HandshakeState::Initial:
84 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
85 (r = CallInitializeSecurityContext()) != ResultSuccess) {
86 return r;
87 }
88 // CallInitializeSecurityContext updated `handshake_state_`.
89 continue;
90 case HandshakeState::ContinueNeeded:
91 case HandshakeState::IncompleteMessage:
92 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
93 (r = FillCiphertextReadBuf()) != ResultSuccess) {
94 return r;
95 }
96 if (ciphertext_read_buf_.empty()) {
97 LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
98 return ResultInternalError;
99 }
100 if ((r = CallInitializeSecurityContext()) != ResultSuccess) {
101 return r;
102 }
103 // CallInitializeSecurityContext updated `handshake_state_`.
104 continue;
105 case HandshakeState::DoneAfterFlush:
106 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) {
107 return r;
108 }
109 handshake_state_ = HandshakeState::Connected;
110 return ResultSuccess;
111 case HandshakeState::Connected:
112 LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook");
113 return ResultInternalError;
114 case HandshakeState::Error:
115 return ResultInternalError;
116 }
117 }
118 }
119
120 Result FillCiphertextReadBuf() {
121 size_t fill_size = read_buf_fill_size_ ? read_buf_fill_size_ : 4096;
122 read_buf_fill_size_ = 0;
123 // This unnecessarily zeroes the buffer; oh well.
124 size_t offset = ciphertext_read_buf_.size();
125 ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; });
126 ciphertext_read_buf_.resize(offset + fill_size, 0);
127 auto read_span = std::span(ciphertext_read_buf_).subspan(offset, fill_size);
128 auto [actual, err] = socket_->Recv(0, read_span);
129 switch (err) {
130 case Network::Errno::SUCCESS:
131 ASSERT(static_cast<size_t>(actual) <= fill_size);
132 ciphertext_read_buf_.resize(offset + actual);
133 return ResultSuccess;
134 case Network::Errno::AGAIN:
135 ciphertext_read_buf_.resize(offset);
136 return ResultWouldBlock;
137 default:
138 ciphertext_read_buf_.resize(offset);
139 LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
140 return ResultInternalError;
141 }
142 }
143
144 // Returns success if the write buffer has been completely emptied.
145 Result FlushCiphertextWriteBuf() {
146 while (!ciphertext_write_buf_.empty()) {
147 auto [actual, err] = socket_->Send(ciphertext_write_buf_, 0);
148 switch (err) {
149 case Network::Errno::SUCCESS:
150 ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf_.size());
151 ciphertext_write_buf_.erase(ciphertext_write_buf_.begin(),
152 ciphertext_write_buf_.begin() + actual);
153 break;
154 case Network::Errno::AGAIN:
155 return ResultWouldBlock;
156 default:
157 LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
158 return ResultInternalError;
159 }
160 }
161 return ResultSuccess;
162 }
163
164 Result CallInitializeSecurityContext() {
165 unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_INTEGRITY |
166 ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM |
167 ISC_REQ_USE_SUPPLIED_CREDS;
168 unsigned long attr;
169 // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel
170 std::array<SecBuffer, 2> input_buffers{{
171 // only used if `initial_call_done`
172 {
173 // [0]
174 .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()),
175 .BufferType = SECBUFFER_TOKEN,
176 .pvBuffer = ciphertext_read_buf_.data(),
177 },
178 {
179 // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is
180 // returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the
181 // whole buffer wasn't used)
182 .BufferType = SECBUFFER_EMPTY,
183 },
184 }};
185 std::array<SecBuffer, 2> output_buffers{{
186 {
187 .BufferType = SECBUFFER_TOKEN,
188 }, // [0]
189 {
190 .BufferType = SECBUFFER_ALERT,
191 }, // [1]
192 }};
193 SecBufferDesc input_desc{
194 .ulVersion = SECBUFFER_VERSION,
195 .cBuffers = static_cast<unsigned long>(input_buffers.size()),
196 .pBuffers = input_buffers.data(),
197 };
198 SecBufferDesc output_desc{
199 .ulVersion = SECBUFFER_VERSION,
200 .cBuffers = static_cast<unsigned long>(output_buffers.size()),
201 .pBuffers = output_buffers.data(),
202 };
203 ASSERT_OR_EXECUTE_MSG(
204 input_buffers[0].cbBuffer == ciphertext_read_buf_.size(),
205 { return ResultInternalError; }, "read buffer too large");
206
207 bool initial_call_done = handshake_state_ != HandshakeState::Initial;
208 if (initial_call_done) {
209 LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext",
210 ciphertext_read_buf_.size());
211 }
212
213 SECURITY_STATUS ret =
214 InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt_ : nullptr,
215 // Caller ensured we have set a hostname:
216 const_cast<char*>(hostname_.value().c_str()), req,
217 0, // Reserved1
218 0, // TargetDataRep not used with Schannel
219 initial_call_done ? &input_desc : nullptr,
220 0, // Reserved2
221 initial_call_done ? nullptr : &ctxt_, &output_desc, &attr,
222 nullptr); // ptsExpiry
223
224 if (output_buffers[0].pvBuffer) {
225 std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),
226 output_buffers[0].cbBuffer);
227 ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), span.begin(), span.end());
228 FreeContextBuffer(output_buffers[0].pvBuffer);
229 }
230
231 if (output_buffers[1].pvBuffer) {
232 std::span span(static_cast<u8*>(output_buffers[1].pvBuffer),
233 output_buffers[1].cbBuffer);
234 // The documentation doesn't explain what format this data is in.
235 LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(),
236 Common::HexToString(span));
237 }
238
239 switch (ret) {
240 case SEC_I_CONTINUE_NEEDED:
241 LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED");
242 if (input_buffers[1].BufferType == SECBUFFER_EXTRA) {
243 LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer);
244 ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf_.size());
245 ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(),
246 ciphertext_read_buf_.end() - input_buffers[1].cbBuffer);
247 } else {
248 ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY);
249 ciphertext_read_buf_.clear();
250 }
251 handshake_state_ = HandshakeState::ContinueNeeded;
252 return ResultSuccess;
253 case SEC_E_INCOMPLETE_MESSAGE:
254 LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE");
255 ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING);
256 read_buf_fill_size_ = input_buffers[1].cbBuffer;
257 handshake_state_ = HandshakeState::IncompleteMessage;
258 return ResultSuccess;
259 case SEC_E_OK:
260 LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK");
261 ciphertext_read_buf_.clear();
262 handshake_state_ = HandshakeState::DoneAfterFlush;
263 return GrabStreamSizes();
264 default:
265 LOG_ERROR(Service_SSL,
266 "InitializeSecurityContext failed (probably certificate/protocol issue): {}",
267 Common::NativeErrorToString(ret));
268 handshake_state_ = HandshakeState::Error;
269 return ResultInternalError;
270 }
271 }
272
273 Result GrabStreamSizes() {
274 SECURITY_STATUS ret =
275 QueryContextAttributes(&ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_);
276 if (ret != SEC_E_OK) {
277 LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}",
278 Common::NativeErrorToString(ret));
279 handshake_state_ = HandshakeState::Error;
280 return ResultInternalError;
281 }
282 return ResultSuccess;
283 }
284
285 ResultVal<size_t> Read(std::span<u8> data) override {
286 if (handshake_state_ != HandshakeState::Connected) {
287 LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
288 return ResultInternalError;
289 }
290 if (data.size() == 0 || got_read_eof_) {
291 return size_t(0);
292 }
293 while (1) {
294 if (!cleartext_read_buf_.empty()) {
295 size_t read_size = std::min(cleartext_read_buf_.size(), data.size());
296 std::memcpy(data.data(), cleartext_read_buf_.data(), read_size);
297 cleartext_read_buf_.erase(cleartext_read_buf_.begin(),
298 cleartext_read_buf_.begin() + read_size);
299 return read_size;
300 }
301 if (!ciphertext_read_buf_.empty()) {
302 std::array<SecBuffer, 5> buffers{{
303 {
304 .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()),
305 .BufferType = SECBUFFER_DATA,
306 .pvBuffer = ciphertext_read_buf_.data(),
307 },
308 {
309 .BufferType = SECBUFFER_EMPTY,
310 },
311 {
312 .BufferType = SECBUFFER_EMPTY,
313 },
314 {
315 .BufferType = SECBUFFER_EMPTY,
316 },
317 }};
318 ASSERT_OR_EXECUTE_MSG(
319 buffers[0].cbBuffer == ciphertext_read_buf_.size(),
320 { return ResultInternalError; }, "read buffer too large");
321 SecBufferDesc desc{
322 .ulVersion = SECBUFFER_VERSION,
323 .cBuffers = static_cast<unsigned long>(buffers.size()),
324 .pBuffers = buffers.data(),
325 };
326 SECURITY_STATUS ret =
327 DecryptMessage(&ctxt_, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr);
328 switch (ret) {
329 case SEC_E_OK:
330 ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER,
331 { return ResultInternalError; });
332 ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA,
333 { return ResultInternalError; });
334 ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER,
335 { return ResultInternalError; });
336 cleartext_read_buf_.assign(static_cast<u8*>(buffers[1].pvBuffer),
337 static_cast<u8*>(buffers[1].pvBuffer) +
338 buffers[1].cbBuffer);
339 if (buffers[3].BufferType == SECBUFFER_EXTRA) {
340 ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf_.size());
341 ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(),
342 ciphertext_read_buf_.end() -
343 buffers[3].cbBuffer);
344 } else {
345 ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY);
346 ciphertext_read_buf_.clear();
347 }
348 continue;
349 case SEC_E_INCOMPLETE_MESSAGE:
350 break;
351 case SEC_I_CONTEXT_EXPIRED:
352 // Server hung up by sending close_notify.
353 got_read_eof_ = true;
354 return size_t(0);
355 default:
356 LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
357 Common::NativeErrorToString(ret));
358 return ResultInternalError;
359 }
360 }
361 Result r = FillCiphertextReadBuf();
362 if (r != ResultSuccess) {
363 return r;
364 }
365 if (ciphertext_read_buf_.empty()) {
366 got_read_eof_ = true;
367 return size_t(0);
368 }
369 }
370 }
371
372 ResultVal<size_t> Write(std::span<const u8> data) override {
373 if (handshake_state_ != HandshakeState::Connected) {
374 LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
375 return ResultInternalError;
376 }
377 if (data.size() == 0) {
378 return size_t(0);
379 }
380 data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes_.cbMaximumMessage));
381 if (!cleartext_write_buf_.empty()) {
382 // Already in the middle of a write. It wouldn't make sense to not
383 // finish sending the entire buffer since TLS has
384 // header/MAC/padding/etc.
385 if (data.size() != cleartext_write_buf_.size() ||
386 std::memcmp(data.data(), cleartext_write_buf_.data(), data.size())) {
387 LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
388 return ResultInternalError;
389 }
390 return WriteAlreadyEncryptedData();
391 } else {
392 cleartext_write_buf_.assign(data.begin(), data.end());
393 }
394
395 std::vector<u8> header_buf(stream_sizes_.cbHeader, 0);
396 std::vector<u8> tmp_data_buf = cleartext_write_buf_;
397 std::vector<u8> trailer_buf(stream_sizes_.cbTrailer, 0);
398
399 std::array<SecBuffer, 3> buffers{{
400 {
401 .cbBuffer = stream_sizes_.cbHeader,
402 .BufferType = SECBUFFER_STREAM_HEADER,
403 .pvBuffer = header_buf.data(),
404 },
405 {
406 .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()),
407 .BufferType = SECBUFFER_DATA,
408 .pvBuffer = tmp_data_buf.data(),
409 },
410 {
411 .cbBuffer = stream_sizes_.cbTrailer,
412 .BufferType = SECBUFFER_STREAM_TRAILER,
413 .pvBuffer = trailer_buf.data(),
414 },
415 }};
416 ASSERT_OR_EXECUTE_MSG(
417 buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; },
418 "temp buffer too large");
419 SecBufferDesc desc{
420 .ulVersion = SECBUFFER_VERSION,
421 .cBuffers = static_cast<unsigned long>(buffers.size()),
422 .pBuffers = buffers.data(),
423 };
424
425 SECURITY_STATUS ret = EncryptMessage(&ctxt_, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0);
426 if (ret != SEC_E_OK) {
427 LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret));
428 return ResultInternalError;
429 }
430 ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), header_buf.begin(),
431 header_buf.end());
432 ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), tmp_data_buf.begin(),
433 tmp_data_buf.end());
434 ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), trailer_buf.begin(),
435 trailer_buf.end());
436 return WriteAlreadyEncryptedData();
437 }
438
439 ResultVal<size_t> WriteAlreadyEncryptedData() {
440 Result r = FlushCiphertextWriteBuf();
441 if (r != ResultSuccess) {
442 return r;
443 }
444 // write buf is empty
445 size_t cleartext_bytes_written = cleartext_write_buf_.size();
446 cleartext_write_buf_.clear();
447 return cleartext_bytes_written;
448 }
449
450 ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
451 PCCERT_CONTEXT returned_cert = nullptr;
452 SECURITY_STATUS ret =
453 QueryContextAttributes(&ctxt_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
454 if (ret != SEC_E_OK) {
455 LOG_ERROR(Service_SSL,
456 "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}",
457 Common::NativeErrorToString(ret));
458 return ResultInternalError;
459 }
460 PCCERT_CONTEXT some_cert = nullptr;
461 std::vector<std::vector<u8>> certs;
462 while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) {
463 certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded),
464 static_cast<u8*>(some_cert->pbCertEncoded) +
465 some_cert->cbCertEncoded);
466 }
467 std::reverse(certs.begin(),
468 certs.end()); // Windows returns certs in reverse order from what we want
469 CertFreeCertificateContext(returned_cert);
470 return certs;
471 }
472
473 ~SSLConnectionBackendSchannel() {
474 if (handshake_state_ != HandshakeState::Initial) {
475 DeleteSecurityContext(&ctxt_);
476 }
477 }
478
479 enum class HandshakeState {
480 // Haven't called anything yet.
481 Initial,
482 // `SEC_I_CONTINUE_NEEDED` was returned by
483 // `InitializeSecurityContext`; must finish sending data (if any) in
484 // the write buffer, then read at least one byte before calling
485 // `InitializeSecurityContext` again.
486 ContinueNeeded,
487 // `SEC_E_INCOMPLETE_MESSAGE` was returned by
488 // `InitializeSecurityContext`; hopefully the write buffer is empty;
489 // must read at least one byte before calling
490 // `InitializeSecurityContext` again.
491 IncompleteMessage,
492 // `SEC_E_OK` was returned by `InitializeSecurityContext`; must
493 // finish sending data in the write buffer before having `DoHandshake`
494 // report success.
495 DoneAfterFlush,
496 // We finished the above and are now connected. At this point, writing
497 // and reading are separate 'state machines' represented by the
498 // nonemptiness of the ciphertext and cleartext read and write buffers.
499 Connected,
500 // Another error was returned and we shouldn't allow initialization
501 // to continue.
502 Error,
503 } handshake_state_ = HandshakeState::Initial;
504
505 CtxtHandle ctxt_;
506 SecPkgContext_StreamSizes stream_sizes_;
507
508 std::shared_ptr<Network::SocketBase> socket_;
509 std::optional<std::string> hostname_;
510
511 std::vector<u8> ciphertext_read_buf_;
512 std::vector<u8> ciphertext_write_buf_;
513 std::vector<u8> cleartext_read_buf_;
514 std::vector<u8> cleartext_write_buf_;
515
516 bool got_read_eof_ = false;
517 size_t read_buf_fill_size_ = 0;
518};
519
520ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
521 auto conn = std::make_unique<SSLConnectionBackendSchannel>();
522 Result res = conn->Init();
523 if (res.IsFailure()) {
524 return res;
525 }
526 return conn;
527}
528
529} // namespace Service::SSL