summaryrefslogtreecommitdiff
path: root/src/core/hle/service/ssl
diff options
context:
space:
mode:
authorGravatar comex2023-06-19 18:17:43 -0700
committerGravatar comex2023-06-25 12:53:31 -0700
commit8e703e08dfcf735a08df2ceff6a05221b7cc981f (patch)
tree771ebe71883ff9e179156f2b38b21b05070d7667 /src/core/hle/service/ssl
parentMerge pull request #10825 from 8bitDream/vcpkg-zlib (diff)
downloadyuzu-8e703e08dfcf735a08df2ceff6a05221b7cc981f.tar.gz
yuzu-8e703e08dfcf735a08df2ceff6a05221b7cc981f.tar.xz
yuzu-8e703e08dfcf735a08df2ceff6a05221b7cc981f.zip
Implement SSL service
This implements some missing network APIs including a large chunk of the SSL service, enough for Mario Maker (with an appropriate mod applied) to connect to the fan server [Open Course World](https://opencourse.world/). Connecting to first-party servers is out of scope of this PR and is a minefield I'd rather not step into. ## TLS TLS is implemented with multiple backends depending on the system's 'native' TLS library. Currently there are two backends: Schannel for Windows, and OpenSSL for Linux. (In reality Linux is a bit of a free-for-all where there's no one 'native' library, but OpenSSL is the closest it gets.) On macOS the 'native' library is SecureTransport but that isn't implemented in this PR. (Instead, all non-Windows OSes will use OpenSSL unless disabled with `-DENABLE_OPENSSL=OFF`.) Why have multiple backends instead of just using a single library, especially given that Yuzu already embeds mbedtls for cryptographic algorithms? Well, I tried implementing this on mbedtls first, but the problem is TLS policies - mainly trusted certificate policies, and to a lesser extent trusted algorithms, SSL versions, etc. ...In practice, the chance that someone is going to conduct a man-in-the-middle attack on a third-party game server is pretty low, but I'm a security nerd so I like to do the right security things. My base assumption is that we want to use the host system's TLS policies. An alternative would be to more closely emulate the Switch's TLS implementation (which is based on NSS). But for one thing, I don't feel like reverse engineering it. And I'd argue that for third-party servers such as Open Course World, it's theoretically preferable to use the system's policies rather than the Switch's, for two reasons 1. Someday the Switch will stop being updated, and the trusted cert list, algorithms, etc. will start to go stale, but users will still want to connect to third-party servers, and there's no reason they shouldn't have up-to-date security when doing so. At that point, homebrew users on actual hardware may patch the TLS implementation, but for emulators it's simpler to just use the host's stack. 2. Also, it's good to respect any custom certificate policies the user may have added systemwide. For example, they may have added custom trusted CAs in order to use TLS debugging tools or pass through corporate MitM middleboxes. Or they may have removed some CAs that are normally trusted out of paranoia. Note that this policy wouldn't work as-is for connecting to first-party servers, because some of them serve certificates based on Nintendo's own CA rather than a publicly trusted one. However, this could probably be solved easily by using appropriate APIs to adding Nintendo's CA as an alternate trusted cert for Yuzu's connections. That is not implemented in this PR because, again, first-party servers are out of scope. (If anything I'd rather have an option to _block_ connections to Nintendo servers, but that's not implemented here.) To use the host's TLS policies, there are three theoretical options: a) Import the host's trusted certificate list into a cross-platform TLS library (presumably mbedtls). b) Use the native TLS library to verify certificates but use a cross-platform TLS library for everything else. c) Use the native TLS library for everything. Two problems with option a). First, importing the trusted certificate list at minimum requires a bunch of platform-specific code, which mbedtls does not have built in. Interestingly, OpenSSL recently gained the ability to import the Windows certificate trust store... but that leads to the second problem, which is that a list of trusted certificates is [not expressive enough](https://bugs.archlinux.org/task/41909) to express a modern certificate trust policy. For example, Windows has the concept of [explicitly distrusted certificates](https://learn.microsoft.com/en-us/previous-versions/windows/it-pro/windows-server-2012-r2-and-2012/dn265983(v=ws.11)), and macOS requires Certificate Transparency validation for some certificates with complex rules for when it's required. Option b) (using native library just to verify certs) is probably feasible, but it would miss aspects of TLS policy other than trusted certs (like allowed algorithms), and in any case it might well require writing more code, not less, compared to using the native library for everything. So I ended up at option c), using the native library for everything. What I'd *really* prefer would be to use a third-party library that does option c) for me. Rust has a good library for this, [native-tls](https://docs.rs/native-tls/latest/native_tls/). I did search, but I couldn't find a good option in the C or C++ ecosystem, at least not any that wasn't part of some much larger framework. I was surprised - isn't this a pretty common use case? Well, many applications only need TLS for HTTPS, and they can use libcurl, which has a TLS abstraction layer internally but doesn't expose it. Other applications only support a single TLS library, or use one of the aforementioned larger frameworks, or are platform-specific to begin with, or of course are written in a non-C/C++ language, most of which have some canonical choice for TLS. But there are also many applications that have a set of TLS backends just like this; it's just that nobody has gone ahead and abstracted the pattern into a library, at least not a widespread one. Amusingly, there is one TLS abstraction layer that Yuzu already bundles: the one in ffmpeg. But it is missing some features that would be needed to use it here (like reusing an existing socket rather than managing the socket itself). Though, that does mean that the wiki's build instructions for Linux (and macOS for some reason?) already recommend installing OpenSSL, so no need to update those. ## Other APIs implemented - Sockets: - GetSockOpt(`SO_ERROR`) - SetSockOpt(`SO_NOSIGPIPE`) (stub, I have no idea what this does on Switch) - `DuplicateSocket` (because the SSL sysmodule calls it internally) - More `PollEvents` values - NSD: - `Resolve` and `ResolveEx` (stub, good enough for Open Course World and probably most third-party servers, but not first-party) - SFDNSRES: - `GetHostByNameRequest` and `GetHostByNameRequestWithOptions` - `ResolverSetOptionRequest` (stub) ## Fixes - Parts of the socket code were previously allocating a `sockaddr` object on the stack when calling functions that take a `sockaddr*` (e.g. `accept`). This might seem like the right thing to do to avoid illegal aliasing, but in fact `sockaddr` is not guaranteed to be large enough to hold any particular type of address, only the header. This worked in practice because in practice `sockaddr` is the same size as `sockaddr_in`, but it's not how the API is meant to be used. I changed this to allocate an `sockaddr_in` on the stack and `reinterpret_cast` it. I could try to do something cleverer with `aligned_storage`, but casting is the idiomatic way to use these particular APIs, so it's really the system's responsibility to avoid any aliasing issues. - I rewrote most of the `GetAddrInfoRequest[WithOptions]` implementation. The old implementation invoked the host's getaddrinfo directly from sfdnsres.cpp, and directly passed through the host's socket type, protocol, etc. values rather than looking up the corresponding constants on the Switch. To be fair, these constants don't tend to actually vary across systems, but still... I added a wrapper for `getaddrinfo` in `internal_network/network.cpp` similar to the ones for other socket APIs, and changed the `GetAddrInfoRequest` implementation to use it. While I was at it, I rewrote the serialization to use the same approach I used to implement `GetHostByNameRequest`, because it reduces the number of size calculations. While doing so I removed `AF_INET6` support because the Switch doesn't support IPv6; it might be nice to support IPv6 anyway, but that would have to apply to all of the socket APIs. I also corrected the IPC wrappers for `GetAddrInfoRequest` and `GetAddrInfoRequestWithOptions` based on reverse engineering and hardware testing. Every call to `GetAddrInfoRequestWithOptions` returns *four* different error codes (IPC status, getaddrinfo error code, netdb error code, and errno), and `GetAddrInfoRequest` returns three of those but in a different order, and it doesn't really matter but the existing implementation was a bit off, as I discovered while testing `GetHostByNameRequest`. - The new serialization code is based on two simple helper functions: ```cpp template <typename T> static void Append(std::vector<u8>& vec, T t); void AppendNulTerminated(std::vector<u8>& vec, std::string_view str); ``` I was thinking there must be existing functions somewhere that assist with serialization/deserialization of binary data, but all I could find was the helper methods in `IOFile` and `HLERequestContext`, not anything that could be used with a generic byte buffer. If I'm not missing something, then maybe I should move the above functions to a new header in `common`... right now they're just sitting in `sfdnsres.cpp` where they're used. - Not a fix, but `SocketBase::Recv`/`Send` is changed to use `std::span<u8>` rather than `std::vector<u8>&` to avoid needing to copy the data to/from a vector when those methods are called from the TLS implementation.
Diffstat (limited to 'src/core/hle/service/ssl')
-rw-r--r--src/core/hle/service/ssl/ssl.cpp349
-rw-r--r--src/core/hle/service/ssl/ssl_backend.h44
-rw-r--r--src/core/hle/service/ssl/ssl_backend_none.cpp15
-rw-r--r--src/core/hle/service/ssl/ssl_backend_openssl.cpp342
-rw-r--r--src/core/hle/service/ssl/ssl_backend_schannel.cpp529
5 files changed, 1262 insertions, 17 deletions
diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp
index 2b99dd7ac..a3b54c7f0 100644
--- a/src/core/hle/service/ssl/ssl.cpp
+++ b/src/core/hle/service/ssl/ssl.cpp
@@ -1,10 +1,18 @@
1// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project 1// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later 2// SPDX-License-Identifier: GPL-2.0-or-later
3 3
4#include "common/string_util.h"
5
6#include "core/core.h"
4#include "core/hle/service/ipc_helpers.h" 7#include "core/hle/service/ipc_helpers.h"
5#include "core/hle/service/server_manager.h" 8#include "core/hle/service/server_manager.h"
6#include "core/hle/service/service.h" 9#include "core/hle/service/service.h"
10#include "core/hle/service/sm/sm.h"
11#include "core/hle/service/sockets/bsd.h"
7#include "core/hle/service/ssl/ssl.h" 12#include "core/hle/service/ssl/ssl.h"
13#include "core/hle/service/ssl/ssl_backend.h"
14#include "core/internal_network/network.h"
15#include "core/internal_network/sockets.h"
8 16
9namespace Service::SSL { 17namespace Service::SSL {
10 18
@@ -20,6 +28,18 @@ enum class ContextOption : u32 {
20 CrlImportDateCheckEnable = 1, 28 CrlImportDateCheckEnable = 1,
21}; 29};
22 30
31// This is nn::ssl::Connection::IoMode
32enum class IoMode : u32 {
33 Blocking = 1,
34 NonBlocking = 2,
35};
36
37// This is nn::ssl::sf::OptionType
38enum class OptionType : u32 {
39 DoNotCloseSocket = 0,
40 GetServerCertChain = 1,
41};
42
23// This is nn::ssl::sf::SslVersion 43// This is nn::ssl::sf::SslVersion
24struct SslVersion { 44struct SslVersion {
25 union { 45 union {
@@ -34,35 +54,42 @@ struct SslVersion {
34 }; 54 };
35}; 55};
36 56
57struct SslContextSharedData {
58 u32 connection_count = 0;
59};
60
37class ISslConnection final : public ServiceFramework<ISslConnection> { 61class ISslConnection final : public ServiceFramework<ISslConnection> {
38public: 62public:
39 explicit ISslConnection(Core::System& system_, SslVersion version) 63 explicit ISslConnection(Core::System& system_, SslVersion version,
40 : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} { 64 std::shared_ptr<SslContextSharedData>& shared_data,
65 std::unique_ptr<SSLConnectionBackend>&& backend)
66 : ServiceFramework{system_, "ISslConnection"}, ssl_version{version},
67 shared_data_{shared_data}, backend_{std::move(backend)} {
41 // clang-format off 68 // clang-format off
42 static const FunctionInfo functions[] = { 69 static const FunctionInfo functions[] = {
43 {0, nullptr, "SetSocketDescriptor"}, 70 {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"},
44 {1, nullptr, "SetHostName"}, 71 {1, &ISslConnection::SetHostName, "SetHostName"},
45 {2, nullptr, "SetVerifyOption"}, 72 {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"},
46 {3, nullptr, "SetIoMode"}, 73 {3, &ISslConnection::SetIoMode, "SetIoMode"},
47 {4, nullptr, "GetSocketDescriptor"}, 74 {4, nullptr, "GetSocketDescriptor"},
48 {5, nullptr, "GetHostName"}, 75 {5, nullptr, "GetHostName"},
49 {6, nullptr, "GetVerifyOption"}, 76 {6, nullptr, "GetVerifyOption"},
50 {7, nullptr, "GetIoMode"}, 77 {7, nullptr, "GetIoMode"},
51 {8, nullptr, "DoHandshake"}, 78 {8, &ISslConnection::DoHandshake, "DoHandshake"},
52 {9, nullptr, "DoHandshakeGetServerCert"}, 79 {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"},
53 {10, nullptr, "Read"}, 80 {10, &ISslConnection::Read, "Read"},
54 {11, nullptr, "Write"}, 81 {11, &ISslConnection::Write, "Write"},
55 {12, nullptr, "Pending"}, 82 {12, &ISslConnection::Pending, "Pending"},
56 {13, nullptr, "Peek"}, 83 {13, nullptr, "Peek"},
57 {14, nullptr, "Poll"}, 84 {14, nullptr, "Poll"},
58 {15, nullptr, "GetVerifyCertError"}, 85 {15, nullptr, "GetVerifyCertError"},
59 {16, nullptr, "GetNeededServerCertBufferSize"}, 86 {16, nullptr, "GetNeededServerCertBufferSize"},
60 {17, nullptr, "SetSessionCacheMode"}, 87 {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"},
61 {18, nullptr, "GetSessionCacheMode"}, 88 {18, nullptr, "GetSessionCacheMode"},
62 {19, nullptr, "FlushSessionCache"}, 89 {19, nullptr, "FlushSessionCache"},
63 {20, nullptr, "SetRenegotiationMode"}, 90 {20, nullptr, "SetRenegotiationMode"},
64 {21, nullptr, "GetRenegotiationMode"}, 91 {21, nullptr, "GetRenegotiationMode"},
65 {22, nullptr, "SetOption"}, 92 {22, &ISslConnection::SetOption, "SetOption"},
66 {23, nullptr, "GetOption"}, 93 {23, nullptr, "GetOption"},
67 {24, nullptr, "GetVerifyCertErrors"}, 94 {24, nullptr, "GetVerifyCertErrors"},
68 {25, nullptr, "GetCipherInfo"}, 95 {25, nullptr, "GetCipherInfo"},
@@ -80,21 +107,295 @@ public:
80 // clang-format on 107 // clang-format on
81 108
82 RegisterHandlers(functions); 109 RegisterHandlers(functions);
110
111 shared_data->connection_count++;
112 }
113
114 ~ISslConnection() {
115 shared_data_->connection_count--;
116 if (fd_to_close_.has_value()) {
117 s32 fd = *fd_to_close_;
118 if (!do_not_close_socket_) {
119 LOG_ERROR(Service_SSL,
120 "do_not_close_socket was changed after setting socket; is this right?");
121 } else {
122 auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
123 if (bsd) {
124 auto err = bsd->CloseImpl(fd);
125 if (err != Service::Sockets::Errno::SUCCESS) {
126 LOG_ERROR(Service_SSL, "failed to close duplicated socket: {}", err);
127 }
128 }
129 }
130 }
83 } 131 }
84 132
85private: 133private:
86 SslVersion ssl_version; 134 SslVersion ssl_version;
135 std::shared_ptr<SslContextSharedData> shared_data_;
136 std::unique_ptr<SSLConnectionBackend> backend_;
137 std::optional<int> fd_to_close_;
138 bool do_not_close_socket_ = false;
139 bool get_server_cert_chain_ = false;
140 std::shared_ptr<Network::SocketBase> socket_;
141 bool did_set_host_name_ = false;
142 bool did_handshake_ = false;
143
144 ResultVal<s32> SetSocketDescriptorImpl(s32 fd) {
145 LOG_DEBUG(Service_SSL, "called, fd={}", fd);
146 ASSERT(!did_handshake_);
147 auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
148 ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; });
149 s32 ret_fd;
150 // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor
151 if (do_not_close_socket_) {
152 auto res = bsd->DuplicateSocketImpl(fd);
153 if (!res.has_value()) {
154 LOG_ERROR(Service_SSL, "failed to duplicate socket");
155 return ResultInvalidSocket;
156 }
157 fd = *res;
158 fd_to_close_ = fd;
159 ret_fd = fd;
160 } else {
161 ret_fd = -1;
162 }
163 std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd);
164 if (!sock.has_value()) {
165 LOG_ERROR(Service_SSL, "invalid socket fd {}", fd);
166 return ResultInvalidSocket;
167 }
168 socket_ = std::move(*sock);
169 backend_->SetSocket(socket_);
170 return ret_fd;
171 }
172
173 Result SetHostNameImpl(const std::string& hostname) {
174 LOG_DEBUG(Service_SSL, "SetHostNameImpl({})", hostname);
175 ASSERT(!did_handshake_);
176 Result res = backend_->SetHostName(hostname);
177 if (res == ResultSuccess) {
178 did_set_host_name_ = true;
179 }
180 return res;
181 }
182
183 Result SetVerifyOptionImpl(u32 option) {
184 ASSERT(!did_handshake_);
185 LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option);
186 return ResultSuccess;
187 }
188
189 Result SetIOModeImpl(u32 _mode) {
190 auto mode = static_cast<IoMode>(_mode);
191 ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking);
192 ASSERT_OR_EXECUTE(socket_, { return ResultNoSocket; });
193
194 bool non_block = mode == IoMode::NonBlocking;
195 Network::Errno e = socket_->SetNonBlock(non_block);
196 if (e != Network::Errno::SUCCESS) {
197 LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block);
198 }
199 return ResultSuccess;
200 }
201
202 Result SetSessionCacheModeImpl(u32 mode) {
203 ASSERT(!did_handshake_);
204 LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode);
205 return ResultSuccess;
206 }
207
208 Result DoHandshakeImpl() {
209 ASSERT_OR_EXECUTE(!did_handshake_ && socket_, { return ResultNoSocket; });
210 ASSERT_OR_EXECUTE_MSG(
211 did_set_host_name_, { return ResultInternalError; },
212 "Expected SetHostName before DoHandshake");
213 Result res = backend_->DoHandshake();
214 did_handshake_ = res.IsSuccess();
215 return res;
216 }
217
218 std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) {
219 struct Header {
220 u64 magic;
221 u32 count;
222 u32 pad;
223 };
224 struct EntryHeader {
225 u32 size;
226 u32 offset;
227 };
228 if (!get_server_cert_chain_) {
229 // Just return the first one, unencoded.
230 ASSERT_OR_EXECUTE_MSG(
231 !certs.empty(), { return {}; }, "Should be at least one server cert");
232 return certs[0];
233 }
234 std::vector<u8> ret;
235 Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0};
236 ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1));
237 size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader);
238 for (auto& cert : certs) {
239 EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)};
240 data_offset += cert.size();
241 ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header),
242 reinterpret_cast<u8*>(&entry_header + 1));
243 }
244 for (auto& cert : certs) {
245 ret.insert(ret.end(), cert.begin(), cert.end());
246 }
247 return ret;
248 }
249
250 ResultVal<std::vector<u8>> ReadImpl(size_t size) {
251 ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; });
252 std::vector<u8> res(size);
253 ResultVal<size_t> actual = backend_->Read(res);
254 if (actual.Failed()) {
255 return actual.Code();
256 }
257 res.resize(*actual);
258 return res;
259 }
260
261 ResultVal<size_t> WriteImpl(std::span<const u8> data) {
262 ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; });
263 return backend_->Write(data);
264 }
265
266 ResultVal<s32> PendingImpl() {
267 LOG_WARNING(Service_SSL, "(STUBBED) called.");
268 return 0;
269 }
270
271 void SetSocketDescriptor(HLERequestContext& ctx) {
272 IPC::RequestParser rp{ctx};
273 const s32 fd = rp.Pop<s32>();
274 const ResultVal<s32> res = SetSocketDescriptorImpl(fd);
275 IPC::ResponseBuilder rb{ctx, 3};
276 rb.Push(res.Code());
277 rb.Push<s32>(res.ValueOr(-1));
278 }
279
280 void SetHostName(HLERequestContext& ctx) {
281 const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer());
282 const Result res = SetHostNameImpl(hostname);
283 IPC::ResponseBuilder rb{ctx, 2};
284 rb.Push(res);
285 }
286
287 void SetVerifyOption(HLERequestContext& ctx) {
288 IPC::RequestParser rp{ctx};
289 const u32 option = rp.Pop<u32>();
290 const Result res = SetVerifyOptionImpl(option);
291 IPC::ResponseBuilder rb{ctx, 2};
292 rb.Push(res);
293 }
294
295 void SetIoMode(HLERequestContext& ctx) {
296 IPC::RequestParser rp{ctx};
297 const u32 mode = rp.Pop<u32>();
298 const Result res = SetIOModeImpl(mode);
299 IPC::ResponseBuilder rb{ctx, 2};
300 rb.Push(res);
301 }
302
303 void DoHandshake(HLERequestContext& ctx) {
304 const Result res = DoHandshakeImpl();
305 IPC::ResponseBuilder rb{ctx, 2};
306 rb.Push(res);
307 }
308
309 void DoHandshakeGetServerCert(HLERequestContext& ctx) {
310 Result res = DoHandshakeImpl();
311 u32 certs_count = 0;
312 u32 certs_size = 0;
313 if (res == ResultSuccess) {
314 auto certs = backend_->GetServerCerts();
315 if (certs.Succeeded()) {
316 std::vector<u8> certs_buf = SerializeServerCerts(*certs);
317 ctx.WriteBuffer(certs_buf);
318 certs_count = static_cast<u32>(certs->size());
319 certs_size = static_cast<u32>(certs_buf.size());
320 }
321 }
322 IPC::ResponseBuilder rb{ctx, 4};
323 rb.Push(res);
324 rb.Push(certs_size);
325 rb.Push(certs_count);
326 }
327
328 void Read(HLERequestContext& ctx) {
329 const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize());
330 IPC::ResponseBuilder rb{ctx, 3};
331 rb.Push(res.Code());
332 if (res.Succeeded()) {
333 rb.Push(static_cast<u32>(res->size()));
334 ctx.WriteBuffer(*res);
335 } else {
336 rb.Push(static_cast<u32>(0));
337 }
338 }
339
340 void Write(HLERequestContext& ctx) {
341 const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer());
342 IPC::ResponseBuilder rb{ctx, 3};
343 rb.Push(res.Code());
344 rb.Push(static_cast<u32>(res.ValueOr(0)));
345 }
346
347 void Pending(HLERequestContext& ctx) {
348 const ResultVal<s32> res = PendingImpl();
349 IPC::ResponseBuilder rb{ctx, 3};
350 rb.Push(res.Code());
351 rb.Push<s32>(res.ValueOr(0));
352 }
353
354 void SetSessionCacheMode(HLERequestContext& ctx) {
355 IPC::RequestParser rp{ctx};
356 const u32 mode = rp.Pop<u32>();
357 const Result res = SetSessionCacheModeImpl(mode);
358 IPC::ResponseBuilder rb{ctx, 2};
359 rb.Push(res);
360 }
361
362 void SetOption(HLERequestContext& ctx) {
363 struct Parameters {
364 OptionType option;
365 s32 value;
366 };
367 static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size");
368
369 IPC::RequestParser rp{ctx};
370 const auto parameters = rp.PopRaw<Parameters>();
371
372 switch (parameters.option) {
373 case OptionType::DoNotCloseSocket:
374 do_not_close_socket_ = static_cast<bool>(parameters.value);
375 break;
376 case OptionType::GetServerCertChain:
377 get_server_cert_chain_ = static_cast<bool>(parameters.value);
378 break;
379 default:
380 LOG_WARNING(Service_SSL, "unrecognized option={}, value={}", parameters.option,
381 parameters.value);
382 }
383
384 IPC::ResponseBuilder rb{ctx, 2};
385 rb.Push(ResultSuccess);
386 }
87}; 387};
88 388
89class ISslContext final : public ServiceFramework<ISslContext> { 389class ISslContext final : public ServiceFramework<ISslContext> {
90public: 390public:
91 explicit ISslContext(Core::System& system_, SslVersion version) 391 explicit ISslContext(Core::System& system_, SslVersion version)
92 : ServiceFramework{system_, "ISslContext"}, ssl_version{version} { 392 : ServiceFramework{system_, "ISslContext"}, ssl_version{version},
393 shared_data_{std::make_shared<SslContextSharedData>()} {
93 static const FunctionInfo functions[] = { 394 static const FunctionInfo functions[] = {
94 {0, &ISslContext::SetOption, "SetOption"}, 395 {0, &ISslContext::SetOption, "SetOption"},
95 {1, nullptr, "GetOption"}, 396 {1, nullptr, "GetOption"},
96 {2, &ISslContext::CreateConnection, "CreateConnection"}, 397 {2, &ISslContext::CreateConnection, "CreateConnection"},
97 {3, nullptr, "GetConnectionCount"}, 398 {3, &ISslContext::GetConnectionCount, "GetConnectionCount"},
98 {4, &ISslContext::ImportServerPki, "ImportServerPki"}, 399 {4, &ISslContext::ImportServerPki, "ImportServerPki"},
99 {5, &ISslContext::ImportClientPki, "ImportClientPki"}, 400 {5, &ISslContext::ImportClientPki, "ImportClientPki"},
100 {6, nullptr, "RemoveServerPki"}, 401 {6, nullptr, "RemoveServerPki"},
@@ -111,6 +412,7 @@ public:
111 412
112private: 413private:
113 SslVersion ssl_version; 414 SslVersion ssl_version;
415 std::shared_ptr<SslContextSharedData> shared_data_;
114 416
115 void SetOption(HLERequestContext& ctx) { 417 void SetOption(HLERequestContext& ctx) {
116 struct Parameters { 418 struct Parameters {
@@ -130,11 +432,24 @@ private:
130 } 432 }
131 433
132 void CreateConnection(HLERequestContext& ctx) { 434 void CreateConnection(HLERequestContext& ctx) {
133 LOG_WARNING(Service_SSL, "(STUBBED) called"); 435 LOG_WARNING(Service_SSL, "called");
436
437 auto backend_res = CreateSSLConnectionBackend();
134 438
135 IPC::ResponseBuilder rb{ctx, 2, 0, 1}; 439 IPC::ResponseBuilder rb{ctx, 2, 0, 1};
440 rb.Push(backend_res.Code());
441 if (backend_res.Succeeded()) {
442 rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data_,
443 std::move(*backend_res));
444 }
445 }
446
447 void GetConnectionCount(HLERequestContext& ctx) {
448 LOG_WARNING(Service_SSL, "connection_count={}", shared_data_->connection_count);
449
450 IPC::ResponseBuilder rb{ctx, 3};
136 rb.Push(ResultSuccess); 451 rb.Push(ResultSuccess);
137 rb.PushIpcInterface<ISslConnection>(system, ssl_version); 452 rb.Push(shared_data_->connection_count);
138 } 453 }
139 454
140 void ImportServerPki(HLERequestContext& ctx) { 455 void ImportServerPki(HLERequestContext& ctx) {
diff --git a/src/core/hle/service/ssl/ssl_backend.h b/src/core/hle/service/ssl/ssl_backend.h
new file mode 100644
index 000000000..624e07d41
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend.h
@@ -0,0 +1,44 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#pragma once
5
6#include "core/hle/result.h"
7
8#include "common/common_types.h"
9
10#include <memory>
11#include <span>
12#include <string>
13#include <vector>
14
15namespace Network {
16class SocketBase;
17}
18
19namespace Service::SSL {
20
21constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103};
22constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106};
23constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205};
24constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up
25
26constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204};
27// ^ ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake,
28// with no way in the latter case to distinguish whether the client should poll
29// for read or write. The one official client I've seen handles this by always
30// polling for read (with a timeout).
31
32class SSLConnectionBackend {
33public:
34 virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0;
35 virtual Result SetHostName(const std::string& hostname) = 0;
36 virtual Result DoHandshake() = 0;
37 virtual ResultVal<size_t> Read(std::span<u8> data) = 0;
38 virtual ResultVal<size_t> Write(std::span<const u8> data) = 0;
39 virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0;
40};
41
42ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend();
43
44} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_none.cpp b/src/core/hle/service/ssl/ssl_backend_none.cpp
new file mode 100644
index 000000000..eb01561e2
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_none.cpp
@@ -0,0 +1,15 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#include "core/hle/service/ssl/ssl_backend.h"
5
6#include "common/logging/log.h"
7
8namespace Service::SSL {
9
10ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
11 LOG_ERROR(Service_SSL, "No SSL backend on this platform");
12 return ResultInternalError;
13}
14
15} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_openssl.cpp b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
new file mode 100644
index 000000000..cf9b904ac
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
@@ -0,0 +1,342 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#include "core/hle/service/ssl/ssl_backend.h"
5#include "core/internal_network/network.h"
6#include "core/internal_network/sockets.h"
7
8#include "common/fs/file.h"
9#include "common/hex_util.h"
10#include "common/string_util.h"
11
12#include <mutex>
13
14#include <openssl/bio.h>
15#include <openssl/err.h>
16#include <openssl/ssl.h>
17#include <openssl/x509.h>
18
19using namespace Common::FS;
20
21namespace Service::SSL {
22
23// Import OpenSSL's `SSL` type into the namespace. This is needed because the
24// namespace is also named `SSL`.
25using ::SSL;
26
27namespace {
28
29std::once_flag one_time_init_flag;
30bool one_time_init_success = false;
31
32SSL_CTX* ssl_ctx;
33IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment
34BIO_METHOD* bio_meth;
35
36Result CheckOpenSSLErrors();
37void OneTimeInit();
38void OneTimeInitLogFile();
39bool OneTimeInitBIO();
40
41} // namespace
42
43class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend {
44public:
45 Result Init() {
46 std::call_once(one_time_init_flag, OneTimeInit);
47
48 if (!one_time_init_success) {
49 LOG_ERROR(Service_SSL,
50 "Can't create SSL connection because OpenSSL one-time initialization failed");
51 return ResultInternalError;
52 }
53
54 ssl_ = SSL_new(ssl_ctx);
55 if (!ssl_) {
56 LOG_ERROR(Service_SSL, "SSL_new failed");
57 return CheckOpenSSLErrors();
58 }
59
60 SSL_set_connect_state(ssl_);
61
62 bio_ = BIO_new(bio_meth);
63 if (!bio_) {
64 LOG_ERROR(Service_SSL, "BIO_new failed");
65 return CheckOpenSSLErrors();
66 }
67
68 BIO_set_data(bio_, this);
69 BIO_set_init(bio_, 1);
70 SSL_set_bio(ssl_, bio_, bio_);
71
72 return ResultSuccess;
73 }
74
75 void SetSocket(std::shared_ptr<Network::SocketBase> socket) override {
76 socket_ = socket;
77 }
78
79 Result SetHostName(const std::string& hostname) override {
80 if (!SSL_set1_host(ssl_, hostname.c_str())) { // hostname for verification
81 LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname);
82 return CheckOpenSSLErrors();
83 }
84 if (!SSL_set_tlsext_host_name(ssl_, hostname.c_str())) { // hostname for SNI
85 LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname);
86 return CheckOpenSSLErrors();
87 }
88 return ResultSuccess;
89 }
90
91 Result DoHandshake() override {
92 SSL_set_verify_result(ssl_, X509_V_OK);
93 int ret = SSL_do_handshake(ssl_);
94 long verify_result = SSL_get_verify_result(ssl_);
95 if (verify_result != X509_V_OK) {
96 LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}",
97 X509_verify_cert_error_string(verify_result));
98 return CheckOpenSSLErrors();
99 }
100 if (ret <= 0) {
101 int ssl_err = SSL_get_error(ssl_, ret);
102 if (ssl_err == SSL_ERROR_ZERO_RETURN ||
103 (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_)) {
104 LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
105 return ResultInternalError;
106 }
107 }
108 return HandleReturn("SSL_do_handshake", 0, ret).Code();
109 }
110
111 ResultVal<size_t> Read(std::span<u8> data) override {
112 size_t actual;
113 int ret = SSL_read_ex(ssl_, data.data(), data.size(), &actual);
114 return HandleReturn("SSL_read_ex", actual, ret);
115 }
116
117 ResultVal<size_t> Write(std::span<const u8> data) override {
118 size_t actual;
119 int ret = SSL_write_ex(ssl_, data.data(), data.size(), &actual);
120 return HandleReturn("SSL_write_ex", actual, ret);
121 }
122
123 ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) {
124 int ssl_err = SSL_get_error(ssl_, ret);
125 CheckOpenSSLErrors();
126 switch (ssl_err) {
127 case SSL_ERROR_NONE:
128 return actual;
129 case SSL_ERROR_ZERO_RETURN:
130 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what);
131 // DoHandshake special-cases this, but for Read and Write:
132 return size_t(0);
133 case SSL_ERROR_WANT_READ:
134 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what);
135 return ResultWouldBlock;
136 case SSL_ERROR_WANT_WRITE:
137 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what);
138 return ResultWouldBlock;
139 default:
140 if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_) {
141 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what);
142 return size_t(0);
143 }
144 LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err);
145 return ResultInternalError;
146 }
147 }
148
149 ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
150 STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl_);
151 if (!chain) {
152 LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr");
153 return ResultInternalError;
154 }
155 std::vector<std::vector<u8>> ret;
156 int count = sk_X509_num(chain);
157 ASSERT(count >= 0);
158 for (int i = 0; i < count; i++) {
159 X509* x509 = sk_X509_value(chain, i);
160 ASSERT_OR_EXECUTE(x509 != nullptr, { continue; });
161 unsigned char* buf = nullptr;
162 int len = i2d_X509(x509, &buf);
163 ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; });
164 ret.emplace_back(buf, buf + len);
165 OPENSSL_free(buf);
166 }
167 return ret;
168 }
169
170 ~SSLConnectionBackendOpenSSL() {
171 // these are null-tolerant:
172 SSL_free(ssl_);
173 BIO_free(bio_);
174 }
175
176 static void KeyLogCallback(const SSL* ssl, const char* line) {
177 std::string str(line);
178 str.push_back('\n');
179 // Do this in a single WriteString for atomicity if multiple instances
180 // are running on different threads (though that can't currently
181 // happen).
182 if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) {
183 LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE");
184 }
185 LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line);
186 }
187
188 static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) {
189 auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
190 ASSERT_OR_EXECUTE_MSG(
191 self->socket_, { return 0; }, "OpenSSL asked to send but we have no socket");
192 BIO_clear_retry_flags(bio);
193 auto [actual, err] = self->socket_->Send({reinterpret_cast<const u8*>(buf), len}, 0);
194 switch (err) {
195 case Network::Errno::SUCCESS:
196 *actual_p = actual;
197 return 1;
198 case Network::Errno::AGAIN:
199 BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY);
200 return 0;
201 default:
202 LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
203 return -1;
204 }
205 }
206
207 static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) {
208 auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
209 ASSERT_OR_EXECUTE_MSG(
210 self->socket_, { return 0; }, "OpenSSL asked to recv but we have no socket");
211 BIO_clear_retry_flags(bio);
212 auto [actual, err] = self->socket_->Recv(0, {reinterpret_cast<u8*>(buf), len});
213 switch (err) {
214 case Network::Errno::SUCCESS:
215 *actual_p = actual;
216 if (actual == 0) {
217 self->got_read_eof_ = true;
218 }
219 return actual ? 1 : 0;
220 case Network::Errno::AGAIN:
221 BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY);
222 return 0;
223 default:
224 LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
225 return -1;
226 }
227 }
228
229 static long CtrlCallback(BIO* bio, int cmd, long larg, void* parg) {
230 switch (cmd) {
231 case BIO_CTRL_FLUSH:
232 // Nothing to flush.
233 return 1;
234 case BIO_CTRL_PUSH:
235 case BIO_CTRL_POP:
236 case BIO_CTRL_GET_KTLS_SEND:
237 case BIO_CTRL_GET_KTLS_RECV:
238 // We don't support these operations, but don't bother logging them
239 // as they're nothing unusual.
240 return 0;
241 default:
242 LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, larg, parg);
243 return 0;
244 }
245 }
246
247 SSL* ssl_ = nullptr;
248 BIO* bio_ = nullptr;
249 bool got_read_eof_ = false;
250
251 std::shared_ptr<Network::SocketBase> socket_;
252};
253
254ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
255 auto conn = std::make_unique<SSLConnectionBackendOpenSSL>();
256 Result res = conn->Init();
257 if (res.IsFailure()) {
258 return res;
259 }
260 return conn;
261}
262
263namespace {
264
265Result CheckOpenSSLErrors() {
266 unsigned long rc;
267 const char* file;
268 int line;
269 const char* func;
270 const char* data;
271 int flags;
272 while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags))) {
273 std::string msg;
274 msg.resize(1024, '\0');
275 ERR_error_string_n(rc, msg.data(), msg.size());
276 msg.resize(strlen(msg.data()), '\0');
277 if (flags & ERR_TXT_STRING) {
278 msg.append(" | ");
279 msg.append(data);
280 }
281 Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error,
282 Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}",
283 msg);
284 }
285 return ResultInternalError;
286}
287
288void OneTimeInit() {
289 ssl_ctx = SSL_CTX_new(TLS_client_method());
290 if (!ssl_ctx) {
291 LOG_ERROR(Service_SSL, "SSL_CTX_new failed");
292 CheckOpenSSLErrors();
293 return;
294 }
295
296 SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr);
297
298 if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) {
299 LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed");
300 CheckOpenSSLErrors();
301 return;
302 }
303
304 OneTimeInitLogFile();
305
306 if (!OneTimeInitBIO()) {
307 return;
308 }
309
310 one_time_init_success = true;
311}
312
313void OneTimeInitLogFile() {
314 const char* logfile = getenv("SSLKEYLOGFILE");
315 if (logfile) {
316 key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile,
317 FileShareFlag::ShareWriteOnly);
318 if (key_log_file.IsOpen()) {
319 SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback);
320 } else {
321 LOG_CRITICAL(Service_SSL,
322 "SSLKEYLOGFILE was set but file could not be opened; not logging keys!");
323 }
324 }
325}
326
327bool OneTimeInitBIO() {
328 bio_meth =
329 BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL");
330 if (!bio_meth ||
331 !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) ||
332 !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) ||
333 !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) {
334 LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD");
335 return false;
336 }
337 return true;
338}
339
340} // namespace
341
342} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
new file mode 100644
index 000000000..0a326b536
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
@@ -0,0 +1,529 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#include "core/hle/service/ssl/ssl_backend.h"
5#include "core/internal_network/network.h"
6#include "core/internal_network/sockets.h"
7
8#include "common/error.h"
9#include "common/fs/file.h"
10#include "common/hex_util.h"
11#include "common/string_util.h"
12
13#include <mutex>
14
15#define SECURITY_WIN32
16#include <Security.h>
17#include <schnlsp.h>
18
19namespace {
20
21std::once_flag one_time_init_flag;
22bool one_time_init_success = false;
23
24SCHANNEL_CRED schannel_cred{
25 .dwVersion = SCHANNEL_CRED_VERSION,
26 .dwFlags = SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols
27 SCH_CRED_AUTO_CRED_VALIDATION | // validate certs
28 SCH_CRED_NO_DEFAULT_CREDS, // don't automatically present a client certificate
29 // ^ I'm assuming that nobody would want to connect Yuzu to a
30 // service that requires some OS-provided corporate client
31 // certificate, and presenting one to some arbitrary server
32 // might be a privacy concern? Who knows, though.
33};
34
35CredHandle cred_handle;
36
37static void OneTimeInit() {
38 SECURITY_STATUS ret =
39 AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND,
40 nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr);
41 if (ret != SEC_E_OK) {
42 // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString.
43 LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}",
44 Common::NativeErrorToString(ret));
45 return;
46 }
47
48 one_time_init_success = true;
49}
50
51} // namespace
52
53namespace Service::SSL {
54
55class SSLConnectionBackendSchannel final : public SSLConnectionBackend {
56public:
57 Result Init() {
58 std::call_once(one_time_init_flag, OneTimeInit);
59
60 if (!one_time_init_success) {
61 LOG_ERROR(
62 Service_SSL,
63 "Can't create SSL connection because Schannel one-time initialization failed");
64 return ResultInternalError;
65 }
66
67 return ResultSuccess;
68 }
69
70 void SetSocket(std::shared_ptr<Network::SocketBase> socket) override {
71 socket_ = socket;
72 }
73
74 Result SetHostName(const std::string& hostname) override {
75 hostname_ = hostname;
76 return ResultSuccess;
77 }
78
79 Result DoHandshake() override {
80 while (1) {
81 Result r;
82 switch (handshake_state_) {
83 case HandshakeState::Initial:
84 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
85 (r = CallInitializeSecurityContext()) != ResultSuccess) {
86 return r;
87 }
88 // CallInitializeSecurityContext updated `handshake_state_`.
89 continue;
90 case HandshakeState::ContinueNeeded:
91 case HandshakeState::IncompleteMessage:
92 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
93 (r = FillCiphertextReadBuf()) != ResultSuccess) {
94 return r;
95 }
96 if (ciphertext_read_buf_.empty()) {
97 LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
98 return ResultInternalError;
99 }
100 if ((r = CallInitializeSecurityContext()) != ResultSuccess) {
101 return r;
102 }
103 // CallInitializeSecurityContext updated `handshake_state_`.
104 continue;
105 case HandshakeState::DoneAfterFlush:
106 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) {
107 return r;
108 }
109 handshake_state_ = HandshakeState::Connected;
110 return ResultSuccess;
111 case HandshakeState::Connected:
112 LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook");
113 return ResultInternalError;
114 case HandshakeState::Error:
115 return ResultInternalError;
116 }
117 }
118 }
119
120 Result FillCiphertextReadBuf() {
121 size_t fill_size = read_buf_fill_size_ ? read_buf_fill_size_ : 4096;
122 read_buf_fill_size_ = 0;
123 // This unnecessarily zeroes the buffer; oh well.
124 size_t offset = ciphertext_read_buf_.size();
125 ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; });
126 ciphertext_read_buf_.resize(offset + fill_size, 0);
127 auto read_span = std::span(ciphertext_read_buf_).subspan(offset, fill_size);
128 auto [actual, err] = socket_->Recv(0, read_span);
129 switch (err) {
130 case Network::Errno::SUCCESS:
131 ASSERT(static_cast<size_t>(actual) <= fill_size);
132 ciphertext_read_buf_.resize(offset + actual);
133 return ResultSuccess;
134 case Network::Errno::AGAIN:
135 ciphertext_read_buf_.resize(offset);
136 return ResultWouldBlock;
137 default:
138 ciphertext_read_buf_.resize(offset);
139 LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
140 return ResultInternalError;
141 }
142 }
143
144 // Returns success if the write buffer has been completely emptied.
145 Result FlushCiphertextWriteBuf() {
146 while (!ciphertext_write_buf_.empty()) {
147 auto [actual, err] = socket_->Send(ciphertext_write_buf_, 0);
148 switch (err) {
149 case Network::Errno::SUCCESS:
150 ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf_.size());
151 ciphertext_write_buf_.erase(ciphertext_write_buf_.begin(),
152 ciphertext_write_buf_.begin() + actual);
153 break;
154 case Network::Errno::AGAIN:
155 return ResultWouldBlock;
156 default:
157 LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
158 return ResultInternalError;
159 }
160 }
161 return ResultSuccess;
162 }
163
164 Result CallInitializeSecurityContext() {
165 unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_INTEGRITY |
166 ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM |
167 ISC_REQ_USE_SUPPLIED_CREDS;
168 unsigned long attr;
169 // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel
170 std::array<SecBuffer, 2> input_buffers{{
171 // only used if `initial_call_done`
172 {
173 // [0]
174 .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()),
175 .BufferType = SECBUFFER_TOKEN,
176 .pvBuffer = ciphertext_read_buf_.data(),
177 },
178 {
179 // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is
180 // returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the
181 // whole buffer wasn't used)
182 .BufferType = SECBUFFER_EMPTY,
183 },
184 }};
185 std::array<SecBuffer, 2> output_buffers{{
186 {
187 .BufferType = SECBUFFER_TOKEN,
188 }, // [0]
189 {
190 .BufferType = SECBUFFER_ALERT,
191 }, // [1]
192 }};
193 SecBufferDesc input_desc{
194 .ulVersion = SECBUFFER_VERSION,
195 .cBuffers = static_cast<unsigned long>(input_buffers.size()),
196 .pBuffers = input_buffers.data(),
197 };
198 SecBufferDesc output_desc{
199 .ulVersion = SECBUFFER_VERSION,
200 .cBuffers = static_cast<unsigned long>(output_buffers.size()),
201 .pBuffers = output_buffers.data(),
202 };
203 ASSERT_OR_EXECUTE_MSG(
204 input_buffers[0].cbBuffer == ciphertext_read_buf_.size(),
205 { return ResultInternalError; }, "read buffer too large");
206
207 bool initial_call_done = handshake_state_ != HandshakeState::Initial;
208 if (initial_call_done) {
209 LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext",
210 ciphertext_read_buf_.size());
211 }
212
213 SECURITY_STATUS ret =
214 InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt_ : nullptr,
215 // Caller ensured we have set a hostname:
216 const_cast<char*>(hostname_.value().c_str()), req,
217 0, // Reserved1
218 0, // TargetDataRep not used with Schannel
219 initial_call_done ? &input_desc : nullptr,
220 0, // Reserved2
221 initial_call_done ? nullptr : &ctxt_, &output_desc, &attr,
222 nullptr); // ptsExpiry
223
224 if (output_buffers[0].pvBuffer) {
225 std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),
226 output_buffers[0].cbBuffer);
227 ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), span.begin(), span.end());
228 FreeContextBuffer(output_buffers[0].pvBuffer);
229 }
230
231 if (output_buffers[1].pvBuffer) {
232 std::span span(static_cast<u8*>(output_buffers[1].pvBuffer),
233 output_buffers[1].cbBuffer);
234 // The documentation doesn't explain what format this data is in.
235 LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(),
236 Common::HexToString(span));
237 }
238
239 switch (ret) {
240 case SEC_I_CONTINUE_NEEDED:
241 LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED");
242 if (input_buffers[1].BufferType == SECBUFFER_EXTRA) {
243 LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer);
244 ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf_.size());
245 ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(),
246 ciphertext_read_buf_.end() - input_buffers[1].cbBuffer);
247 } else {
248 ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY);
249 ciphertext_read_buf_.clear();
250 }
251 handshake_state_ = HandshakeState::ContinueNeeded;
252 return ResultSuccess;
253 case SEC_E_INCOMPLETE_MESSAGE:
254 LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE");
255 ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING);
256 read_buf_fill_size_ = input_buffers[1].cbBuffer;
257 handshake_state_ = HandshakeState::IncompleteMessage;
258 return ResultSuccess;
259 case SEC_E_OK:
260 LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK");
261 ciphertext_read_buf_.clear();
262 handshake_state_ = HandshakeState::DoneAfterFlush;
263 return GrabStreamSizes();
264 default:
265 LOG_ERROR(Service_SSL,
266 "InitializeSecurityContext failed (probably certificate/protocol issue): {}",
267 Common::NativeErrorToString(ret));
268 handshake_state_ = HandshakeState::Error;
269 return ResultInternalError;
270 }
271 }
272
273 Result GrabStreamSizes() {
274 SECURITY_STATUS ret =
275 QueryContextAttributes(&ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_);
276 if (ret != SEC_E_OK) {
277 LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}",
278 Common::NativeErrorToString(ret));
279 handshake_state_ = HandshakeState::Error;
280 return ResultInternalError;
281 }
282 return ResultSuccess;
283 }
284
285 ResultVal<size_t> Read(std::span<u8> data) override {
286 if (handshake_state_ != HandshakeState::Connected) {
287 LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
288 return ResultInternalError;
289 }
290 if (data.size() == 0 || got_read_eof_) {
291 return size_t(0);
292 }
293 while (1) {
294 if (!cleartext_read_buf_.empty()) {
295 size_t read_size = std::min(cleartext_read_buf_.size(), data.size());
296 std::memcpy(data.data(), cleartext_read_buf_.data(), read_size);
297 cleartext_read_buf_.erase(cleartext_read_buf_.begin(),
298 cleartext_read_buf_.begin() + read_size);
299 return read_size;
300 }
301 if (!ciphertext_read_buf_.empty()) {
302 std::array<SecBuffer, 5> buffers{{
303 {
304 .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()),
305 .BufferType = SECBUFFER_DATA,
306 .pvBuffer = ciphertext_read_buf_.data(),
307 },
308 {
309 .BufferType = SECBUFFER_EMPTY,
310 },
311 {
312 .BufferType = SECBUFFER_EMPTY,
313 },
314 {
315 .BufferType = SECBUFFER_EMPTY,
316 },
317 }};
318 ASSERT_OR_EXECUTE_MSG(
319 buffers[0].cbBuffer == ciphertext_read_buf_.size(),
320 { return ResultInternalError; }, "read buffer too large");
321 SecBufferDesc desc{
322 .ulVersion = SECBUFFER_VERSION,
323 .cBuffers = static_cast<unsigned long>(buffers.size()),
324 .pBuffers = buffers.data(),
325 };
326 SECURITY_STATUS ret =
327 DecryptMessage(&ctxt_, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr);
328 switch (ret) {
329 case SEC_E_OK:
330 ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER,
331 { return ResultInternalError; });
332 ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA,
333 { return ResultInternalError; });
334 ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER,
335 { return ResultInternalError; });
336 cleartext_read_buf_.assign(static_cast<u8*>(buffers[1].pvBuffer),
337 static_cast<u8*>(buffers[1].pvBuffer) +
338 buffers[1].cbBuffer);
339 if (buffers[3].BufferType == SECBUFFER_EXTRA) {
340 ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf_.size());
341 ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(),
342 ciphertext_read_buf_.end() -
343 buffers[3].cbBuffer);
344 } else {
345 ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY);
346 ciphertext_read_buf_.clear();
347 }
348 continue;
349 case SEC_E_INCOMPLETE_MESSAGE:
350 break;
351 case SEC_I_CONTEXT_EXPIRED:
352 // Server hung up by sending close_notify.
353 got_read_eof_ = true;
354 return size_t(0);
355 default:
356 LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
357 Common::NativeErrorToString(ret));
358 return ResultInternalError;
359 }
360 }
361 Result r = FillCiphertextReadBuf();
362 if (r != ResultSuccess) {
363 return r;
364 }
365 if (ciphertext_read_buf_.empty()) {
366 got_read_eof_ = true;
367 return size_t(0);
368 }
369 }
370 }
371
372 ResultVal<size_t> Write(std::span<const u8> data) override {
373 if (handshake_state_ != HandshakeState::Connected) {
374 LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
375 return ResultInternalError;
376 }
377 if (data.size() == 0) {
378 return size_t(0);
379 }
380 data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes_.cbMaximumMessage));
381 if (!cleartext_write_buf_.empty()) {
382 // Already in the middle of a write. It wouldn't make sense to not
383 // finish sending the entire buffer since TLS has
384 // header/MAC/padding/etc.
385 if (data.size() != cleartext_write_buf_.size() ||
386 std::memcmp(data.data(), cleartext_write_buf_.data(), data.size())) {
387 LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
388 return ResultInternalError;
389 }
390 return WriteAlreadyEncryptedData();
391 } else {
392 cleartext_write_buf_.assign(data.begin(), data.end());
393 }
394
395 std::vector<u8> header_buf(stream_sizes_.cbHeader, 0);
396 std::vector<u8> tmp_data_buf = cleartext_write_buf_;
397 std::vector<u8> trailer_buf(stream_sizes_.cbTrailer, 0);
398
399 std::array<SecBuffer, 3> buffers{{
400 {
401 .cbBuffer = stream_sizes_.cbHeader,
402 .BufferType = SECBUFFER_STREAM_HEADER,
403 .pvBuffer = header_buf.data(),
404 },
405 {
406 .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()),
407 .BufferType = SECBUFFER_DATA,
408 .pvBuffer = tmp_data_buf.data(),
409 },
410 {
411 .cbBuffer = stream_sizes_.cbTrailer,
412 .BufferType = SECBUFFER_STREAM_TRAILER,
413 .pvBuffer = trailer_buf.data(),
414 },
415 }};
416 ASSERT_OR_EXECUTE_MSG(
417 buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; },
418 "temp buffer too large");
419 SecBufferDesc desc{
420 .ulVersion = SECBUFFER_VERSION,
421 .cBuffers = static_cast<unsigned long>(buffers.size()),
422 .pBuffers = buffers.data(),
423 };
424
425 SECURITY_STATUS ret = EncryptMessage(&ctxt_, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0);
426 if (ret != SEC_E_OK) {
427 LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret));
428 return ResultInternalError;
429 }
430 ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), header_buf.begin(),
431 header_buf.end());
432 ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), tmp_data_buf.begin(),
433 tmp_data_buf.end());
434 ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), trailer_buf.begin(),
435 trailer_buf.end());
436 return WriteAlreadyEncryptedData();
437 }
438
439 ResultVal<size_t> WriteAlreadyEncryptedData() {
440 Result r = FlushCiphertextWriteBuf();
441 if (r != ResultSuccess) {
442 return r;
443 }
444 // write buf is empty
445 size_t cleartext_bytes_written = cleartext_write_buf_.size();
446 cleartext_write_buf_.clear();
447 return cleartext_bytes_written;
448 }
449
450 ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
451 PCCERT_CONTEXT returned_cert = nullptr;
452 SECURITY_STATUS ret =
453 QueryContextAttributes(&ctxt_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
454 if (ret != SEC_E_OK) {
455 LOG_ERROR(Service_SSL,
456 "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}",
457 Common::NativeErrorToString(ret));
458 return ResultInternalError;
459 }
460 PCCERT_CONTEXT some_cert = nullptr;
461 std::vector<std::vector<u8>> certs;
462 while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) {
463 certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded),
464 static_cast<u8*>(some_cert->pbCertEncoded) +
465 some_cert->cbCertEncoded);
466 }
467 std::reverse(certs.begin(),
468 certs.end()); // Windows returns certs in reverse order from what we want
469 CertFreeCertificateContext(returned_cert);
470 return certs;
471 }
472
473 ~SSLConnectionBackendSchannel() {
474 if (handshake_state_ != HandshakeState::Initial) {
475 DeleteSecurityContext(&ctxt_);
476 }
477 }
478
479 enum class HandshakeState {
480 // Haven't called anything yet.
481 Initial,
482 // `SEC_I_CONTINUE_NEEDED` was returned by
483 // `InitializeSecurityContext`; must finish sending data (if any) in
484 // the write buffer, then read at least one byte before calling
485 // `InitializeSecurityContext` again.
486 ContinueNeeded,
487 // `SEC_E_INCOMPLETE_MESSAGE` was returned by
488 // `InitializeSecurityContext`; hopefully the write buffer is empty;
489 // must read at least one byte before calling
490 // `InitializeSecurityContext` again.
491 IncompleteMessage,
492 // `SEC_E_OK` was returned by `InitializeSecurityContext`; must
493 // finish sending data in the write buffer before having `DoHandshake`
494 // report success.
495 DoneAfterFlush,
496 // We finished the above and are now connected. At this point, writing
497 // and reading are separate 'state machines' represented by the
498 // nonemptiness of the ciphertext and cleartext read and write buffers.
499 Connected,
500 // Another error was returned and we shouldn't allow initialization
501 // to continue.
502 Error,
503 } handshake_state_ = HandshakeState::Initial;
504
505 CtxtHandle ctxt_;
506 SecPkgContext_StreamSizes stream_sizes_;
507
508 std::shared_ptr<Network::SocketBase> socket_;
509 std::optional<std::string> hostname_;
510
511 std::vector<u8> ciphertext_read_buf_;
512 std::vector<u8> ciphertext_write_buf_;
513 std::vector<u8> cleartext_read_buf_;
514 std::vector<u8> cleartext_write_buf_;
515
516 bool got_read_eof_ = false;
517 size_t read_buf_fill_size_ = 0;
518};
519
520ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
521 auto conn = std::make_unique<SSLConnectionBackendSchannel>();
522 Result res = conn->Init();
523 if (res.IsFailure()) {
524 return res;
525 }
526 return conn;
527}
528
529} // namespace Service::SSL