summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Liam2023-07-16 18:55:27 -0400
committerGravatar Liam2023-08-08 11:09:37 -0400
commit83eee1d2266a1de374be0a8b2c0f2827f5e25bcf (patch)
tree7dfbdfb32b6671a79f433c3cb59acc7ba0e5cc9c /src
parentcore: remove ResultVal type (diff)
downloadyuzu-83eee1d2266a1de374be0a8b2c0f2827f5e25bcf.tar.gz
yuzu-83eee1d2266a1de374be0a8b2c0f2827f5e25bcf.tar.xz
yuzu-83eee1d2266a1de374be0a8b2c0f2827f5e25bcf.zip
ssl: remove ResultVal use
Diffstat (limited to 'src')
-rw-r--r--src/core/hle/service/sockets/nsd.cpp9
-rw-r--r--src/core/hle/service/ssl/ssl.cpp86
-rw-r--r--src/core/hle/service/ssl/ssl_backend.h8
-rw-r--r--src/core/hle/service/ssl/ssl_backend_none.cpp2
-rw-r--r--src/core/hle/service/ssl/ssl_backend_openssl.cpp45
-rw-r--r--src/core/hle/service/ssl/ssl_backend_schannel.cpp62
-rw-r--r--src/core/hle/service/ssl/ssl_backend_securetransport.cpp39
7 files changed, 127 insertions, 124 deletions
diff --git a/src/core/hle/service/sockets/nsd.cpp b/src/core/hle/service/sockets/nsd.cpp
index 5dfcaabb1..bac21752a 100644
--- a/src/core/hle/service/sockets/nsd.cpp
+++ b/src/core/hle/service/sockets/nsd.cpp
@@ -54,7 +54,7 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na
54 RegisterHandlers(functions); 54 RegisterHandlers(functions);
55} 55}
56 56
57static ResultVal<std::string> ResolveImpl(const std::string& fqdn_in) { 57static std::string ResolveImpl(const std::string& fqdn_in) {
58 // The real implementation makes various substitutions. 58 // The real implementation makes various substitutions.
59 // For now we just return the string as-is, which is good enough when not 59 // For now we just return the string as-is, which is good enough when not
60 // connecting to real Nintendo servers. 60 // connecting to real Nintendo servers.
@@ -64,13 +64,10 @@ static ResultVal<std::string> ResolveImpl(const std::string& fqdn_in) {
64 64
65static Result ResolveCommon(const std::string& fqdn_in, std::array<char, 0x100>& fqdn_out) { 65static Result ResolveCommon(const std::string& fqdn_in, std::array<char, 0x100>& fqdn_out) {
66 const auto res = ResolveImpl(fqdn_in); 66 const auto res = ResolveImpl(fqdn_in);
67 if (res.Failed()) { 67 if (res.size() >= fqdn_out.size()) {
68 return res.Code();
69 }
70 if (res->size() >= fqdn_out.size()) {
71 return ResultOverflow; 68 return ResultOverflow;
72 } 69 }
73 std::memcpy(fqdn_out.data(), res->c_str(), res->size() + 1); 70 std::memcpy(fqdn_out.data(), res.c_str(), res.size() + 1);
74 return ResultSuccess; 71 return ResultSuccess;
75} 72}
76 73
diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp
index 9c96f9763..2cba9e5c9 100644
--- a/src/core/hle/service/ssl/ssl.cpp
+++ b/src/core/hle/service/ssl/ssl.cpp
@@ -4,6 +4,7 @@
4#include "common/string_util.h" 4#include "common/string_util.h"
5 5
6#include "core/core.h" 6#include "core/core.h"
7#include "core/hle/result.h"
7#include "core/hle/service/ipc_helpers.h" 8#include "core/hle/service/ipc_helpers.h"
8#include "core/hle/service/server_manager.h" 9#include "core/hle/service/server_manager.h"
9#include "core/hle/service/service.h" 10#include "core/hle/service/service.h"
@@ -141,12 +142,12 @@ private:
141 bool did_set_host_name = false; 142 bool did_set_host_name = false;
142 bool did_handshake = false; 143 bool did_handshake = false;
143 144
144 ResultVal<s32> SetSocketDescriptorImpl(s32 fd) { 145 Result SetSocketDescriptorImpl(s32* out_fd, s32 fd) {
145 LOG_DEBUG(Service_SSL, "called, fd={}", fd); 146 LOG_DEBUG(Service_SSL, "called, fd={}", fd);
146 ASSERT(!did_handshake); 147 ASSERT(!did_handshake);
147 auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); 148 auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
148 ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; }); 149 ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; });
149 s32 ret_fd; 150
150 // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor 151 // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor
151 if (do_not_close_socket) { 152 if (do_not_close_socket) {
152 auto res = bsd->DuplicateSocketImpl(fd); 153 auto res = bsd->DuplicateSocketImpl(fd);
@@ -156,9 +157,9 @@ private:
156 } 157 }
157 fd = *res; 158 fd = *res;
158 fd_to_close = fd; 159 fd_to_close = fd;
159 ret_fd = fd; 160 *out_fd = fd;
160 } else { 161 } else {
161 ret_fd = -1; 162 *out_fd = -1;
162 } 163 }
163 std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd); 164 std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd);
164 if (!sock.has_value()) { 165 if (!sock.has_value()) {
@@ -167,7 +168,7 @@ private:
167 } 168 }
168 socket = std::move(*sock); 169 socket = std::move(*sock);
169 backend->SetSocket(socket); 170 backend->SetSocket(socket);
170 return ret_fd; 171 return ResultSuccess;
171 } 172 }
172 173
173 Result SetHostNameImpl(const std::string& hostname) { 174 Result SetHostNameImpl(const std::string& hostname) {
@@ -247,34 +248,36 @@ private:
247 return ret; 248 return ret;
248 } 249 }
249 250
250 ResultVal<std::vector<u8>> ReadImpl(size_t size) { 251 Result ReadImpl(std::vector<u8>* out_data, size_t size) {
251 ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); 252 ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
252 std::vector<u8> res(size); 253 size_t actual_size{};
253 ResultVal<size_t> actual = backend->Read(res); 254 Result res = backend->Read(&actual_size, *out_data);
254 if (actual.Failed()) { 255 if (res != ResultSuccess) {
255 return actual.Code(); 256 return res;
256 } 257 }
257 res.resize(*actual); 258 out_data->resize(actual_size);
258 return res; 259 return res;
259 } 260 }
260 261
261 ResultVal<size_t> WriteImpl(std::span<const u8> data) { 262 Result WriteImpl(size_t* out_size, std::span<const u8> data) {
262 ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); 263 ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
263 return backend->Write(data); 264 return backend->Write(out_size, data);
264 } 265 }
265 266
266 ResultVal<s32> PendingImpl() { 267 Result PendingImpl(s32* out_pending) {
267 LOG_WARNING(Service_SSL, "(STUBBED) called."); 268 LOG_WARNING(Service_SSL, "(STUBBED) called.");
268 return 0; 269 *out_pending = 0;
270 return ResultSuccess;
269 } 271 }
270 272
271 void SetSocketDescriptor(HLERequestContext& ctx) { 273 void SetSocketDescriptor(HLERequestContext& ctx) {
272 IPC::RequestParser rp{ctx}; 274 IPC::RequestParser rp{ctx};
273 const s32 fd = rp.Pop<s32>(); 275 const s32 in_fd = rp.Pop<s32>();
274 const ResultVal<s32> res = SetSocketDescriptorImpl(fd); 276 s32 out_fd{-1};
277 const Result res = SetSocketDescriptorImpl(&out_fd, in_fd);
275 IPC::ResponseBuilder rb{ctx, 3}; 278 IPC::ResponseBuilder rb{ctx, 3};
276 rb.Push(res.Code()); 279 rb.Push(res);
277 rb.Push<s32>(res.ValueOr(-1)); 280 rb.Push<s32>(out_fd);
278 } 281 }
279 282
280 void SetHostName(HLERequestContext& ctx) { 283 void SetHostName(HLERequestContext& ctx) {
@@ -313,14 +316,15 @@ private:
313 }; 316 };
314 static_assert(sizeof(OutputParameters) == 0x8); 317 static_assert(sizeof(OutputParameters) == 0x8);
315 318
316 const Result res = DoHandshakeImpl(); 319 Result res = DoHandshakeImpl();
317 OutputParameters out{}; 320 OutputParameters out{};
318 if (res == ResultSuccess) { 321 if (res == ResultSuccess) {
319 auto certs = backend->GetServerCerts(); 322 std::vector<std::vector<u8>> certs;
320 if (certs.Succeeded()) { 323 res = backend->GetServerCerts(&certs);
321 const std::vector<u8> certs_buf = SerializeServerCerts(*certs); 324 if (res == ResultSuccess) {
325 const std::vector<u8> certs_buf = SerializeServerCerts(certs);
322 ctx.WriteBuffer(certs_buf); 326 ctx.WriteBuffer(certs_buf);
323 out.certs_count = static_cast<u32>(certs->size()); 327 out.certs_count = static_cast<u32>(certs.size());
324 out.certs_size = static_cast<u32>(certs_buf.size()); 328 out.certs_size = static_cast<u32>(certs_buf.size());
325 } 329 }
326 } 330 }
@@ -330,29 +334,32 @@ private:
330 } 334 }
331 335
332 void Read(HLERequestContext& ctx) { 336 void Read(HLERequestContext& ctx) {
333 const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize()); 337 std::vector<u8> output_bytes;
338 const Result res = ReadImpl(&output_bytes, ctx.GetWriteBufferSize());
334 IPC::ResponseBuilder rb{ctx, 3}; 339 IPC::ResponseBuilder rb{ctx, 3};
335 rb.Push(res.Code()); 340 rb.Push(res);
336 if (res.Succeeded()) { 341 if (res == ResultSuccess) {
337 rb.Push(static_cast<u32>(res->size())); 342 rb.Push(static_cast<u32>(output_bytes.size()));
338 ctx.WriteBuffer(*res); 343 ctx.WriteBuffer(output_bytes);
339 } else { 344 } else {
340 rb.Push(static_cast<u32>(0)); 345 rb.Push(static_cast<u32>(0));
341 } 346 }
342 } 347 }
343 348
344 void Write(HLERequestContext& ctx) { 349 void Write(HLERequestContext& ctx) {
345 const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer()); 350 size_t write_size{0};
351 const Result res = WriteImpl(&write_size, ctx.ReadBuffer());
346 IPC::ResponseBuilder rb{ctx, 3}; 352 IPC::ResponseBuilder rb{ctx, 3};
347 rb.Push(res.Code()); 353 rb.Push(res);
348 rb.Push(static_cast<u32>(res.ValueOr(0))); 354 rb.Push(static_cast<u32>(write_size));
349 } 355 }
350 356
351 void Pending(HLERequestContext& ctx) { 357 void Pending(HLERequestContext& ctx) {
352 const ResultVal<s32> res = PendingImpl(); 358 s32 pending_size{0};
359 const Result res = PendingImpl(&pending_size);
353 IPC::ResponseBuilder rb{ctx, 3}; 360 IPC::ResponseBuilder rb{ctx, 3};
354 rb.Push(res.Code()); 361 rb.Push(res);
355 rb.Push<s32>(res.ValueOr(0)); 362 rb.Push<s32>(pending_size);
356 } 363 }
357 364
358 void SetSessionCacheMode(HLERequestContext& ctx) { 365 void SetSessionCacheMode(HLERequestContext& ctx) {
@@ -438,13 +445,14 @@ private:
438 void CreateConnection(HLERequestContext& ctx) { 445 void CreateConnection(HLERequestContext& ctx) {
439 LOG_WARNING(Service_SSL, "called"); 446 LOG_WARNING(Service_SSL, "called");
440 447
441 auto backend_res = CreateSSLConnectionBackend(); 448 std::unique_ptr<SSLConnectionBackend> backend;
449 const Result res = CreateSSLConnectionBackend(&backend);
442 450
443 IPC::ResponseBuilder rb{ctx, 2, 0, 1}; 451 IPC::ResponseBuilder rb{ctx, 2, 0, 1};
444 rb.Push(backend_res.Code()); 452 rb.Push(res);
445 if (backend_res.Succeeded()) { 453 if (res == ResultSuccess) {
446 rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data, 454 rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data,
447 std::move(*backend_res)); 455 std::move(backend));
448 } 456 }
449 } 457 }
450 458
diff --git a/src/core/hle/service/ssl/ssl_backend.h b/src/core/hle/service/ssl/ssl_backend.h
index 409f4367c..a2ec8e694 100644
--- a/src/core/hle/service/ssl/ssl_backend.h
+++ b/src/core/hle/service/ssl/ssl_backend.h
@@ -35,11 +35,11 @@ public:
35 virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0; 35 virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0;
36 virtual Result SetHostName(const std::string& hostname) = 0; 36 virtual Result SetHostName(const std::string& hostname) = 0;
37 virtual Result DoHandshake() = 0; 37 virtual Result DoHandshake() = 0;
38 virtual ResultVal<size_t> Read(std::span<u8> data) = 0; 38 virtual Result Read(size_t* out_size, std::span<u8> data) = 0;
39 virtual ResultVal<size_t> Write(std::span<const u8> data) = 0; 39 virtual Result Write(size_t* out_size, std::span<const u8> data) = 0;
40 virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0; 40 virtual Result GetServerCerts(std::vector<std::vector<u8>>* out_certs) = 0;
41}; 41};
42 42
43ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend(); 43Result CreateSSLConnectionBackend(std::unique_ptr<SSLConnectionBackend>* out_backend);
44 44
45} // namespace Service::SSL 45} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_none.cpp b/src/core/hle/service/ssl/ssl_backend_none.cpp
index 2f4f23c42..a7fafd0a3 100644
--- a/src/core/hle/service/ssl/ssl_backend_none.cpp
+++ b/src/core/hle/service/ssl/ssl_backend_none.cpp
@@ -7,7 +7,7 @@
7 7
8namespace Service::SSL { 8namespace Service::SSL {
9 9
10ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { 10Result CreateSSLConnectionBackend(std::unique_ptr<SSLConnectionBackend>* out_backend) {
11 LOG_ERROR(Service_SSL, 11 LOG_ERROR(Service_SSL,
12 "Can't create SSL connection because no SSL backend is available on this platform"); 12 "Can't create SSL connection because no SSL backend is available on this platform");
13 return ResultInternalError; 13 return ResultInternalError;
diff --git a/src/core/hle/service/ssl/ssl_backend_openssl.cpp b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
index 6ca869dbf..b2dd37cd4 100644
--- a/src/core/hle/service/ssl/ssl_backend_openssl.cpp
+++ b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
@@ -105,31 +105,30 @@ public:
105 return ResultInternalError; 105 return ResultInternalError;
106 } 106 }
107 } 107 }
108 return HandleReturn("SSL_do_handshake", 0, ret).Code(); 108 return HandleReturn("SSL_do_handshake", 0, ret);
109 } 109 }
110 110
111 ResultVal<size_t> Read(std::span<u8> data) override { 111 Result Read(size_t* out_size, std::span<u8> data) override {
112 size_t actual; 112 const int ret = SSL_read_ex(ssl, data.data(), data.size(), out_size);
113 const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual); 113 return HandleReturn("SSL_read_ex", out_size, ret);
114 return HandleReturn("SSL_read_ex", actual, ret);
115 } 114 }
116 115
117 ResultVal<size_t> Write(std::span<const u8> data) override { 116 Result Write(size_t* out_size, std::span<const u8> data) override {
118 size_t actual; 117 const int ret = SSL_write_ex(ssl, data.data(), data.size(), out_size);
119 const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual); 118 return HandleReturn("SSL_write_ex", out_size, ret);
120 return HandleReturn("SSL_write_ex", actual, ret);
121 } 119 }
122 120
123 ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) { 121 Result HandleReturn(const char* what, size_t* actual, int ret) {
124 const int ssl_err = SSL_get_error(ssl, ret); 122 const int ssl_err = SSL_get_error(ssl, ret);
125 CheckOpenSSLErrors(); 123 CheckOpenSSLErrors();
126 switch (ssl_err) { 124 switch (ssl_err) {
127 case SSL_ERROR_NONE: 125 case SSL_ERROR_NONE:
128 return actual; 126 return ResultSuccess;
129 case SSL_ERROR_ZERO_RETURN: 127 case SSL_ERROR_ZERO_RETURN:
130 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what); 128 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what);
131 // DoHandshake special-cases this, but for Read and Write: 129 // DoHandshake special-cases this, but for Read and Write:
132 return size_t(0); 130 *actual = 0;
131 return ResultSuccess;
133 case SSL_ERROR_WANT_READ: 132 case SSL_ERROR_WANT_READ:
134 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what); 133 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what);
135 return ResultWouldBlock; 134 return ResultWouldBlock;
@@ -139,20 +138,20 @@ public:
139 default: 138 default:
140 if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) { 139 if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) {
141 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what); 140 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what);
142 return size_t(0); 141 *actual = 0;
142 return ResultSuccess;
143 } 143 }
144 LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err); 144 LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err);
145 return ResultInternalError; 145 return ResultInternalError;
146 } 146 }
147 } 147 }
148 148
149 ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { 149 Result GetServerCerts(std::vector<std::vector<u8>>* out_certs) override {
150 STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl); 150 STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl);
151 if (!chain) { 151 if (!chain) {
152 LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr"); 152 LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr");
153 return ResultInternalError; 153 return ResultInternalError;
154 } 154 }
155 std::vector<std::vector<u8>> ret;
156 int count = sk_X509_num(chain); 155 int count = sk_X509_num(chain);
157 ASSERT(count >= 0); 156 ASSERT(count >= 0);
158 for (int i = 0; i < count; i++) { 157 for (int i = 0; i < count; i++) {
@@ -161,10 +160,10 @@ public:
161 unsigned char* buf = nullptr; 160 unsigned char* buf = nullptr;
162 int len = i2d_X509(x509, &buf); 161 int len = i2d_X509(x509, &buf);
163 ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; }); 162 ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; });
164 ret.emplace_back(buf, buf + len); 163 out_certs->emplace_back(buf, buf + len);
165 OPENSSL_free(buf); 164 OPENSSL_free(buf);
166 } 165 }
167 return ret; 166 return ResultSuccess;
168 } 167 }
169 168
170 ~SSLConnectionBackendOpenSSL() { 169 ~SSLConnectionBackendOpenSSL() {
@@ -253,13 +252,13 @@ public:
253 std::shared_ptr<Network::SocketBase> socket; 252 std::shared_ptr<Network::SocketBase> socket;
254}; 253};
255 254
256ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { 255Result CreateSSLConnectionBackend(std::unique_ptr<SSLConnectionBackend>* out_backend) {
257 auto conn = std::make_unique<SSLConnectionBackendOpenSSL>(); 256 auto conn = std::make_unique<SSLConnectionBackendOpenSSL>();
258 const Result res = conn->Init(); 257
259 if (res.IsFailure()) { 258 R_TRY(conn->Init());
260 return res; 259
261 } 260 *out_backend = std::move(conn);
262 return conn; 261 return ResultSuccess;
263} 262}
264 263
265namespace { 264namespace {
diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
index d8074339a..bda12b761 100644
--- a/src/core/hle/service/ssl/ssl_backend_schannel.cpp
+++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
@@ -299,21 +299,22 @@ public:
299 return ResultSuccess; 299 return ResultSuccess;
300 } 300 }
301 301
302 ResultVal<size_t> Read(std::span<u8> data) override { 302 Result Read(size_t* out_size, std::span<u8> data) override {
303 *out_size = 0;
303 if (handshake_state != HandshakeState::Connected) { 304 if (handshake_state != HandshakeState::Connected) {
304 LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake"); 305 LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
305 return ResultInternalError; 306 return ResultInternalError;
306 } 307 }
307 if (data.size() == 0 || got_read_eof) { 308 if (data.size() == 0 || got_read_eof) {
308 return size_t(0); 309 return ResultSuccess;
309 } 310 }
310 while (1) { 311 while (1) {
311 if (!cleartext_read_buf.empty()) { 312 if (!cleartext_read_buf.empty()) {
312 const size_t read_size = std::min(cleartext_read_buf.size(), data.size()); 313 *out_size = std::min(cleartext_read_buf.size(), data.size());
313 std::memcpy(data.data(), cleartext_read_buf.data(), read_size); 314 std::memcpy(data.data(), cleartext_read_buf.data(), *out_size);
314 cleartext_read_buf.erase(cleartext_read_buf.begin(), 315 cleartext_read_buf.erase(cleartext_read_buf.begin(),
315 cleartext_read_buf.begin() + read_size); 316 cleartext_read_buf.begin() + *out_size);
316 return read_size; 317 return ResultSuccess;
317 } 318 }
318 if (!ciphertext_read_buf.empty()) { 319 if (!ciphertext_read_buf.empty()) {
319 SecBuffer empty{ 320 SecBuffer empty{
@@ -366,7 +367,8 @@ public:
366 case SEC_I_CONTEXT_EXPIRED: 367 case SEC_I_CONTEXT_EXPIRED:
367 // Server hung up by sending close_notify. 368 // Server hung up by sending close_notify.
368 got_read_eof = true; 369 got_read_eof = true;
369 return size_t(0); 370 *out_size = 0;
371 return ResultSuccess;
370 default: 372 default:
371 LOG_ERROR(Service_SSL, "DecryptMessage failed: {}", 373 LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
372 Common::NativeErrorToString(ret)); 374 Common::NativeErrorToString(ret));
@@ -379,18 +381,21 @@ public:
379 } 381 }
380 if (ciphertext_read_buf.empty()) { 382 if (ciphertext_read_buf.empty()) {
381 got_read_eof = true; 383 got_read_eof = true;
382 return size_t(0); 384 *out_size = 0;
385 return ResultSuccess;
383 } 386 }
384 } 387 }
385 } 388 }
386 389
387 ResultVal<size_t> Write(std::span<const u8> data) override { 390 Result Write(size_t* out_size, std::span<const u8> data) override {
391 *out_size = 0;
392
388 if (handshake_state != HandshakeState::Connected) { 393 if (handshake_state != HandshakeState::Connected) {
389 LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake"); 394 LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
390 return ResultInternalError; 395 return ResultInternalError;
391 } 396 }
392 if (data.size() == 0) { 397 if (data.size() == 0) {
393 return size_t(0); 398 return ResultSuccess;
394 } 399 }
395 data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage)); 400 data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage));
396 if (!cleartext_write_buf.empty()) { 401 if (!cleartext_write_buf.empty()) {
@@ -402,7 +407,7 @@ public:
402 LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer"); 407 LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
403 return ResultInternalError; 408 return ResultInternalError;
404 } 409 }
405 return WriteAlreadyEncryptedData(); 410 return WriteAlreadyEncryptedData(out_size);
406 } else { 411 } else {
407 cleartext_write_buf.assign(data.begin(), data.end()); 412 cleartext_write_buf.assign(data.begin(), data.end());
408 } 413 }
@@ -448,21 +453,21 @@ public:
448 tmp_data_buf.end()); 453 tmp_data_buf.end());
449 ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(), 454 ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(),
450 trailer_buf.end()); 455 trailer_buf.end());
451 return WriteAlreadyEncryptedData(); 456 return WriteAlreadyEncryptedData(out_size);
452 } 457 }
453 458
454 ResultVal<size_t> WriteAlreadyEncryptedData() { 459 Result WriteAlreadyEncryptedData(size_t* out_size) {
455 const Result r = FlushCiphertextWriteBuf(); 460 const Result r = FlushCiphertextWriteBuf();
456 if (r != ResultSuccess) { 461 if (r != ResultSuccess) {
457 return r; 462 return r;
458 } 463 }
459 // write buf is empty 464 // write buf is empty
460 const size_t cleartext_bytes_written = cleartext_write_buf.size(); 465 *out_size = cleartext_write_buf.size();
461 cleartext_write_buf.clear(); 466 cleartext_write_buf.clear();
462 return cleartext_bytes_written; 467 return ResultSuccess;
463 } 468 }
464 469
465 ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { 470 Result GetServerCerts(std::vector<std::vector<u8>>* out_certs) override {
466 PCCERT_CONTEXT returned_cert = nullptr; 471 PCCERT_CONTEXT returned_cert = nullptr;
467 const SECURITY_STATUS ret = 472 const SECURITY_STATUS ret =
468 QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert); 473 QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
@@ -473,16 +478,15 @@ public:
473 return ResultInternalError; 478 return ResultInternalError;
474 } 479 }
475 PCCERT_CONTEXT some_cert = nullptr; 480 PCCERT_CONTEXT some_cert = nullptr;
476 std::vector<std::vector<u8>> certs;
477 while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) { 481 while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) {
478 certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded), 482 out_certs->emplace_back(static_cast<u8*>(some_cert->pbCertEncoded),
479 static_cast<u8*>(some_cert->pbCertEncoded) + 483 static_cast<u8*>(some_cert->pbCertEncoded) +
480 some_cert->cbCertEncoded); 484 some_cert->cbCertEncoded);
481 } 485 }
482 std::reverse(certs.begin(), 486 std::reverse(out_certs->begin(),
483 certs.end()); // Windows returns certs in reverse order from what we want 487 out_certs->end()); // Windows returns certs in reverse order from what we want
484 CertFreeCertificateContext(returned_cert); 488 CertFreeCertificateContext(returned_cert);
485 return certs; 489 return ResultSuccess;
486 } 490 }
487 491
488 ~SSLConnectionBackendSchannel() { 492 ~SSLConnectionBackendSchannel() {
@@ -532,13 +536,13 @@ public:
532 size_t read_buf_fill_size = 0; 536 size_t read_buf_fill_size = 0;
533}; 537};
534 538
535ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { 539Result CreateSSLConnectionBackend(std::unique_ptr<SSLConnectionBackend>* out_backend) {
536 auto conn = std::make_unique<SSLConnectionBackendSchannel>(); 540 auto conn = std::make_unique<SSLConnectionBackendSchannel>();
537 const Result res = conn->Init(); 541
538 if (res.IsFailure()) { 542 R_TRY(conn->Init());
539 return res; 543
540 } 544 *out_backend = std::move(conn);
541 return conn; 545 return ResultSuccess;
542} 546}
543 547
544} // namespace Service::SSL 548} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_securetransport.cpp b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp
index b3083cbad..5f9e6bef7 100644
--- a/src/core/hle/service/ssl/ssl_backend_securetransport.cpp
+++ b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp
@@ -103,24 +103,20 @@ public:
103 return HandleReturn("SSLHandshake", 0, status).Code(); 103 return HandleReturn("SSLHandshake", 0, status).Code();
104 } 104 }
105 105
106 ResultVal<size_t> Read(std::span<u8> data) override { 106 Result Read(size_t* out_size, std::span<u8> data) override {
107 size_t actual; 107 OSStatus status = SSLRead(context, data.data(), data.size(), &out_size);
108 OSStatus status = SSLRead(context, data.data(), data.size(), &actual); 108 return HandleReturn("SSLRead", out_size, status);
109 ;
110 return HandleReturn("SSLRead", actual, status);
111 } 109 }
112 110
113 ResultVal<size_t> Write(std::span<const u8> data) override { 111 Result Write(size_t* out_size, std::span<const u8> data) override {
114 size_t actual; 112 OSStatus status = SSLWrite(context, data.data(), data.size(), &out_size);
115 OSStatus status = SSLWrite(context, data.data(), data.size(), &actual); 113 return HandleReturn("SSLWrite", out_size, status);
116 ;
117 return HandleReturn("SSLWrite", actual, status);
118 } 114 }
119 115
120 ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) { 116 Result HandleReturn(const char* what, size_t* actual, OSStatus status) {
121 switch (status) { 117 switch (status) {
122 case 0: 118 case 0:
123 return actual; 119 return ResultSuccess;
124 case errSSLWouldBlock: 120 case errSSLWouldBlock:
125 return ResultWouldBlock; 121 return ResultWouldBlock;
126 default: { 122 default: {
@@ -136,22 +132,21 @@ public:
136 } 132 }
137 } 133 }
138 134
139 ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { 135 Result GetServerCerts(std::vector<std::vector<u8>>* out_certs) override {
140 CFReleaser<SecTrustRef> trust; 136 CFReleaser<SecTrustRef> trust;
141 OSStatus status = SSLCopyPeerTrust(context, &trust.ptr); 137 OSStatus status = SSLCopyPeerTrust(context, &trust.ptr);
142 if (status) { 138 if (status) {
143 LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status)); 139 LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status));
144 return ResultInternalError; 140 return ResultInternalError;
145 } 141 }
146 std::vector<std::vector<u8>> ret;
147 for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) { 142 for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) {
148 SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i); 143 SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i);
149 CFReleaser<CFDataRef> data(SecCertificateCopyData(cert)); 144 CFReleaser<CFDataRef> data(SecCertificateCopyData(cert));
150 ASSERT_OR_EXECUTE(data, { return ResultInternalError; }); 145 ASSERT_OR_EXECUTE(data, { return ResultInternalError; });
151 const u8* ptr = CFDataGetBytePtr(data); 146 const u8* ptr = CFDataGetBytePtr(data);
152 ret.emplace_back(ptr, ptr + CFDataGetLength(data)); 147 out_certs->emplace_back(ptr, ptr + CFDataGetLength(data));
153 } 148 }
154 return ret; 149 return ResultSuccess;
155 } 150 }
156 151
157 static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) { 152 static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) {
@@ -210,13 +205,13 @@ private:
210 std::shared_ptr<Network::SocketBase> socket; 205 std::shared_ptr<Network::SocketBase> socket;
211}; 206};
212 207
213ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { 208Result CreateSSLConnectionBackend(std::unique_ptr<SSLConnectionBackend>* out_backend) {
214 auto conn = std::make_unique<SSLConnectionBackendSecureTransport>(); 209 auto conn = std::make_unique<SSLConnectionBackendSecureTransport>();
215 const Result res = conn->Init(); 210
216 if (res.IsFailure()) { 211 R_TRY(conn->Init());
217 return res; 212
218 } 213 *out_backend = std::move(conn);
219 return conn; 214 return ResultSuccess;
220} 215}
221 216
222} // namespace Service::SSL 217} // namespace Service::SSL