summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt16
-rw-r--r--src/common/socket_types.h16
-rw-r--r--src/core/CMakeLists.txt18
-rw-r--r--src/core/hle/service/sockets/bsd.cpp120
-rw-r--r--src/core/hle/service/sockets/bsd.h13
-rw-r--r--src/core/hle/service/sockets/nsd.cpp58
-rw-r--r--src/core/hle/service/sockets/nsd.h4
-rw-r--r--src/core/hle/service/sockets/sfdnsres.cpp388
-rw-r--r--src/core/hle/service/sockets/sfdnsres.h3
-rw-r--r--src/core/hle/service/sockets/sockets.h33
-rw-r--r--src/core/hle/service/sockets/sockets_translate.cpp114
-rw-r--r--src/core/hle/service/sockets/sockets_translate.h17
-rw-r--r--src/core/hle/service/ssl/ssl.cpp353
-rw-r--r--src/core/hle/service/ssl/ssl_backend.h45
-rw-r--r--src/core/hle/service/ssl/ssl_backend_none.cpp16
-rw-r--r--src/core/hle/service/ssl/ssl_backend_openssl.cpp351
-rw-r--r--src/core/hle/service/ssl/ssl_backend_schannel.cpp543
-rw-r--r--src/core/hle/service/ssl/ssl_backend_securetransport.cpp219
-rw-r--r--src/core/internal_network/network.cpp286
-rw-r--r--src/core/internal_network/network.h36
-rw-r--r--src/core/internal_network/socket_proxy.cpp22
-rw-r--r--src/core/internal_network/socket_proxy.h8
-rw-r--r--src/core/internal_network/sockets.h16
23 files changed, 2413 insertions, 282 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7f8febb90..647219052 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -63,6 +63,18 @@ option(YUZU_DOWNLOAD_TIME_ZONE_DATA "Always download time zone binaries" OFF)
63 63
64CMAKE_DEPENDENT_OPTION(YUZU_USE_FASTER_LD "Check if a faster linker is available" ON "NOT WIN32" OFF) 64CMAKE_DEPENDENT_OPTION(YUZU_USE_FASTER_LD "Check if a faster linker is available" ON "NOT WIN32" OFF)
65 65
66set(DEFAULT_ENABLE_OPENSSL ON)
67if (ANDROID OR WIN32 OR APPLE)
68 # - Windows defaults to the Schannel backend.
69 # - macOS defaults to the SecureTransport backend.
70 # - Android currently has no SSL backend as the NDK doesn't include any SSL
71 # library; a proper 'native' backend would have to go through Java.
72 # But you can force builds for those platforms to use OpenSSL if you have
73 # your own copy of it.
74 set(DEFAULT_ENABLE_OPENSSL OFF)
75endif()
76option(ENABLE_OPENSSL "Enable OpenSSL backend for ISslConnection" ${DEFAULT_ENABLE_OPENSSL})
77
66# On Android, fetch and compile libcxx before doing anything else 78# On Android, fetch and compile libcxx before doing anything else
67if (ANDROID) 79if (ANDROID)
68 set(CMAKE_SKIP_INSTALL_RULES ON) 80 set(CMAKE_SKIP_INSTALL_RULES ON)
@@ -322,6 +334,10 @@ if (MINGW)
322 find_library(MSWSOCK_LIBRARY mswsock REQUIRED) 334 find_library(MSWSOCK_LIBRARY mswsock REQUIRED)
323endif() 335endif()
324 336
337if(ENABLE_OPENSSL)
338 find_package(OpenSSL 1.1.1 REQUIRED)
339endif()
340
325# Please consider this as a stub 341# Please consider this as a stub
326if(ENABLE_QT6 AND Qt6_LOCATION) 342if(ENABLE_QT6 AND Qt6_LOCATION)
327 list(APPEND CMAKE_PREFIX_PATH "${Qt6_LOCATION}") 343 list(APPEND CMAKE_PREFIX_PATH "${Qt6_LOCATION}")
diff --git a/src/common/socket_types.h b/src/common/socket_types.h
index 0a801a443..b2191c2e8 100644
--- a/src/common/socket_types.h
+++ b/src/common/socket_types.h
@@ -5,15 +5,19 @@
5 5
6#include "common/common_types.h" 6#include "common/common_types.h"
7 7
8#include <optional>
9
8namespace Network { 10namespace Network {
9 11
10/// Address families 12/// Address families
11enum class Domain : u8 { 13enum class Domain : u8 {
12 INET, ///< Address family for IPv4 14 Unspecified, ///< Represents 0, used in getaddrinfo hints
15 INET, ///< Address family for IPv4
13}; 16};
14 17
15/// Socket types 18/// Socket types
16enum class Type { 19enum class Type {
20 Unspecified, ///< Represents 0, used in getaddrinfo hints
17 STREAM, 21 STREAM,
18 DGRAM, 22 DGRAM,
19 RAW, 23 RAW,
@@ -22,6 +26,7 @@ enum class Type {
22 26
23/// Protocol values for sockets 27/// Protocol values for sockets
24enum class Protocol : u8 { 28enum class Protocol : u8 {
29 Unspecified, ///< Represents 0, usable in various places
25 ICMP, 30 ICMP,
26 TCP, 31 TCP,
27 UDP, 32 UDP,
@@ -48,4 +53,13 @@ constexpr u32 FLAG_MSG_PEEK = 0x2;
48constexpr u32 FLAG_MSG_DONTWAIT = 0x80; 53constexpr u32 FLAG_MSG_DONTWAIT = 0x80;
49constexpr u32 FLAG_O_NONBLOCK = 0x800; 54constexpr u32 FLAG_O_NONBLOCK = 0x800;
50 55
56/// Cross-platform addrinfo structure
57struct AddrInfo {
58 Domain family;
59 Type socket_type;
60 Protocol protocol;
61 SockAddrIn addr;
62 std::optional<std::string> canon_name;
63};
64
51} // namespace Network 65} // namespace Network
diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt
index 28cb6f86f..c3b688c5d 100644
--- a/src/core/CMakeLists.txt
+++ b/src/core/CMakeLists.txt
@@ -723,6 +723,7 @@ add_library(core STATIC
723 hle/service/spl/spl_types.h 723 hle/service/spl/spl_types.h
724 hle/service/ssl/ssl.cpp 724 hle/service/ssl/ssl.cpp
725 hle/service/ssl/ssl.h 725 hle/service/ssl/ssl.h
726 hle/service/ssl/ssl_backend.h
726 hle/service/time/clock_types.h 727 hle/service/time/clock_types.h
727 hle/service/time/ephemeral_network_system_clock_context_writer.h 728 hle/service/time/ephemeral_network_system_clock_context_writer.h
728 hle/service/time/ephemeral_network_system_clock_core.h 729 hle/service/time/ephemeral_network_system_clock_core.h
@@ -864,6 +865,23 @@ if (ARCHITECTURE_x86_64 OR ARCHITECTURE_arm64)
864 target_link_libraries(core PRIVATE dynarmic::dynarmic) 865 target_link_libraries(core PRIVATE dynarmic::dynarmic)
865endif() 866endif()
866 867
868if(ENABLE_OPENSSL)
869 target_sources(core PRIVATE
870 hle/service/ssl/ssl_backend_openssl.cpp)
871 target_link_libraries(core PRIVATE OpenSSL::SSL)
872elseif (APPLE)
873 target_sources(core PRIVATE
874 hle/service/ssl/ssl_backend_securetransport.cpp)
875 target_link_libraries(core PRIVATE "-framework Security")
876elseif (WIN32)
877 target_sources(core PRIVATE
878 hle/service/ssl/ssl_backend_schannel.cpp)
879 target_link_libraries(core PRIVATE secur32)
880else()
881 target_sources(core PRIVATE
882 hle/service/ssl/ssl_backend_none.cpp)
883endif()
884
867if (YUZU_USE_PRECOMPILED_HEADERS) 885if (YUZU_USE_PRECOMPILED_HEADERS)
868 target_precompile_headers(core PRIVATE precompiled_headers.h) 886 target_precompile_headers(core PRIVATE precompiled_headers.h)
869endif() 887endif()
diff --git a/src/core/hle/service/sockets/bsd.cpp b/src/core/hle/service/sockets/bsd.cpp
index bce45d321..e63b0a357 100644
--- a/src/core/hle/service/sockets/bsd.cpp
+++ b/src/core/hle/service/sockets/bsd.cpp
@@ -20,6 +20,9 @@
20#include "core/internal_network/sockets.h" 20#include "core/internal_network/sockets.h"
21#include "network/network.h" 21#include "network/network.h"
22 22
23using Common::Expected;
24using Common::Unexpected;
25
23namespace Service::Sockets { 26namespace Service::Sockets {
24 27
25namespace { 28namespace {
@@ -265,16 +268,19 @@ void BSD::GetSockOpt(HLERequestContext& ctx) {
265 const u32 level = rp.Pop<u32>(); 268 const u32 level = rp.Pop<u32>();
266 const auto optname = static_cast<OptName>(rp.Pop<u32>()); 269 const auto optname = static_cast<OptName>(rp.Pop<u32>());
267 270
268 LOG_WARNING(Service, "(STUBBED) called. fd={} level={} optname=0x{:x}", fd, level, optname);
269
270 std::vector<u8> optval(ctx.GetWriteBufferSize()); 271 std::vector<u8> optval(ctx.GetWriteBufferSize());
271 272
273 LOG_DEBUG(Service, "called. fd={} level={} optname=0x{:x} len=0x{:x}", fd, level, optname,
274 optval.size());
275
276 const Errno err = GetSockOptImpl(fd, level, optname, optval);
277
272 ctx.WriteBuffer(optval); 278 ctx.WriteBuffer(optval);
273 279
274 IPC::ResponseBuilder rb{ctx, 5}; 280 IPC::ResponseBuilder rb{ctx, 5};
275 rb.Push(ResultSuccess); 281 rb.Push(ResultSuccess);
276 rb.Push<s32>(-1); 282 rb.Push<s32>(err == Errno::SUCCESS ? 0 : -1);
277 rb.PushEnum(Errno::NOTCONN); 283 rb.PushEnum(err);
278 rb.Push<u32>(static_cast<u32>(optval.size())); 284 rb.Push<u32>(static_cast<u32>(optval.size()));
279} 285}
280 286
@@ -436,6 +442,31 @@ void BSD::Close(HLERequestContext& ctx) {
436 BuildErrnoResponse(ctx, CloseImpl(fd)); 442 BuildErrnoResponse(ctx, CloseImpl(fd));
437} 443}
438 444
445void BSD::DuplicateSocket(HLERequestContext& ctx) {
446 struct InputParameters {
447 s32 fd;
448 u64 reserved;
449 };
450 static_assert(sizeof(InputParameters) == 0x10);
451
452 struct OutputParameters {
453 s32 ret;
454 Errno bsd_errno;
455 };
456 static_assert(sizeof(OutputParameters) == 0x8);
457
458 IPC::RequestParser rp{ctx};
459 auto input = rp.PopRaw<InputParameters>();
460
461 Expected<s32, Errno> res = DuplicateSocketImpl(input.fd);
462 IPC::ResponseBuilder rb{ctx, 4};
463 rb.Push(ResultSuccess);
464 rb.PushRaw(OutputParameters{
465 .ret = res.value_or(0),
466 .bsd_errno = res ? Errno::SUCCESS : res.error(),
467 });
468}
469
439void BSD::EventFd(HLERequestContext& ctx) { 470void BSD::EventFd(HLERequestContext& ctx) {
440 IPC::RequestParser rp{ctx}; 471 IPC::RequestParser rp{ctx};
441 const u64 initval = rp.Pop<u64>(); 472 const u64 initval = rp.Pop<u64>();
@@ -477,12 +508,12 @@ std::pair<s32, Errno> BSD::SocketImpl(Domain domain, Type type, Protocol protoco
477 508
478 auto room_member = room_network.GetRoomMember().lock(); 509 auto room_member = room_network.GetRoomMember().lock();
479 if (room_member && room_member->IsConnected()) { 510 if (room_member && room_member->IsConnected()) {
480 descriptor.socket = std::make_unique<Network::ProxySocket>(room_network); 511 descriptor.socket = std::make_shared<Network::ProxySocket>(room_network);
481 } else { 512 } else {
482 descriptor.socket = std::make_unique<Network::Socket>(); 513 descriptor.socket = std::make_shared<Network::Socket>();
483 } 514 }
484 515
485 descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(type, protocol)); 516 descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(protocol));
486 descriptor.is_connection_based = IsConnectionBased(type); 517 descriptor.is_connection_based = IsConnectionBased(type);
487 518
488 return {fd, Errno::SUCCESS}; 519 return {fd, Errno::SUCCESS};
@@ -538,7 +569,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con
538 std::transform(fds.begin(), fds.end(), host_pollfds.begin(), [this](PollFD pollfd) { 569 std::transform(fds.begin(), fds.end(), host_pollfds.begin(), [this](PollFD pollfd) {
539 Network::PollFD result; 570 Network::PollFD result;
540 result.socket = file_descriptors[pollfd.fd]->socket.get(); 571 result.socket = file_descriptors[pollfd.fd]->socket.get();
541 result.events = TranslatePollEventsToHost(pollfd.events); 572 result.events = Translate(pollfd.events);
542 result.revents = Network::PollEvents{}; 573 result.revents = Network::PollEvents{};
543 return result; 574 return result;
544 }); 575 });
@@ -547,7 +578,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con
547 578
548 const size_t num = host_pollfds.size(); 579 const size_t num = host_pollfds.size();
549 for (size_t i = 0; i < num; ++i) { 580 for (size_t i = 0; i < num; ++i) {
550 fds[i].revents = TranslatePollEventsToGuest(host_pollfds[i].revents); 581 fds[i].revents = Translate(host_pollfds[i].revents);
551 } 582 }
552 std::memcpy(write_buffer.data(), fds.data(), length); 583 std::memcpy(write_buffer.data(), fds.data(), length);
553 584
@@ -617,7 +648,8 @@ Errno BSD::GetPeerNameImpl(s32 fd, std::vector<u8>& write_buffer) {
617 } 648 }
618 const SockAddrIn guest_addrin = Translate(addr_in); 649 const SockAddrIn guest_addrin = Translate(addr_in);
619 650
620 ASSERT(write_buffer.size() == sizeof(guest_addrin)); 651 ASSERT(write_buffer.size() >= sizeof(guest_addrin));
652 write_buffer.resize(sizeof(guest_addrin));
621 std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); 653 std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin));
622 return Translate(bsd_errno); 654 return Translate(bsd_errno);
623} 655}
@@ -633,7 +665,8 @@ Errno BSD::GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer) {
633 } 665 }
634 const SockAddrIn guest_addrin = Translate(addr_in); 666 const SockAddrIn guest_addrin = Translate(addr_in);
635 667
636 ASSERT(write_buffer.size() == sizeof(guest_addrin)); 668 ASSERT(write_buffer.size() >= sizeof(guest_addrin));
669 write_buffer.resize(sizeof(guest_addrin));
637 std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); 670 std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin));
638 return Translate(bsd_errno); 671 return Translate(bsd_errno);
639} 672}
@@ -671,13 +704,47 @@ std::pair<s32, Errno> BSD::FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg) {
671 } 704 }
672} 705}
673 706
674Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) { 707Errno BSD::GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval) {
675 UNIMPLEMENTED_IF(level != 0xffff); // SOL_SOCKET 708 if (!IsFileDescriptorValid(fd)) {
709 return Errno::BADF;
710 }
711
712 if (level != static_cast<u32>(SocketLevel::SOCKET)) {
713 UNIMPLEMENTED_MSG("Unknown getsockopt level");
714 return Errno::SUCCESS;
715 }
716
717 Network::SocketBase* const socket = file_descriptors[fd]->socket.get();
718
719 switch (optname) {
720 case OptName::ERROR_: {
721 auto [pending_err, getsockopt_err] = socket->GetPendingError();
722 if (getsockopt_err == Network::Errno::SUCCESS) {
723 Errno translated_pending_err = Translate(pending_err);
724 ASSERT_OR_EXECUTE_MSG(
725 optval.size() == sizeof(Errno), { return Errno::INVAL; },
726 "Incorrect getsockopt option size");
727 optval.resize(sizeof(Errno));
728 memcpy(optval.data(), &translated_pending_err, sizeof(Errno));
729 }
730 return Translate(getsockopt_err);
731 }
732 default:
733 UNIMPLEMENTED_MSG("Unimplemented optname={}", optname);
734 return Errno::SUCCESS;
735 }
736}
676 737
738Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) {
677 if (!IsFileDescriptorValid(fd)) { 739 if (!IsFileDescriptorValid(fd)) {
678 return Errno::BADF; 740 return Errno::BADF;
679 } 741 }
680 742
743 if (level != static_cast<u32>(SocketLevel::SOCKET)) {
744 UNIMPLEMENTED_MSG("Unknown setsockopt level");
745 return Errno::SUCCESS;
746 }
747
681 Network::SocketBase* const socket = file_descriptors[fd]->socket.get(); 748 Network::SocketBase* const socket = file_descriptors[fd]->socket.get();
682 749
683 if (optname == OptName::LINGER) { 750 if (optname == OptName::LINGER) {
@@ -711,6 +778,9 @@ Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, con
711 return Translate(socket->SetSndTimeo(value)); 778 return Translate(socket->SetSndTimeo(value));
712 case OptName::RCVTIMEO: 779 case OptName::RCVTIMEO:
713 return Translate(socket->SetRcvTimeo(value)); 780 return Translate(socket->SetRcvTimeo(value));
781 case OptName::NOSIGPIPE:
782 LOG_WARNING(Service, "(STUBBED) setting NOSIGPIPE to {}", value);
783 return Errno::SUCCESS;
714 default: 784 default:
715 UNIMPLEMENTED_MSG("Unimplemented optname={}", optname); 785 UNIMPLEMENTED_MSG("Unimplemented optname={}", optname);
716 return Errno::SUCCESS; 786 return Errno::SUCCESS;
@@ -841,6 +911,28 @@ Errno BSD::CloseImpl(s32 fd) {
841 return bsd_errno; 911 return bsd_errno;
842} 912}
843 913
914Expected<s32, Errno> BSD::DuplicateSocketImpl(s32 fd) {
915 if (!IsFileDescriptorValid(fd)) {
916 return Unexpected(Errno::BADF);
917 }
918
919 const s32 new_fd = FindFreeFileDescriptorHandle();
920 if (new_fd < 0) {
921 LOG_ERROR(Service, "No more file descriptors available");
922 return Unexpected(Errno::MFILE);
923 }
924
925 file_descriptors[new_fd] = file_descriptors[fd];
926 return new_fd;
927}
928
929std::optional<std::shared_ptr<Network::SocketBase>> BSD::GetSocket(s32 fd) {
930 if (!IsFileDescriptorValid(fd)) {
931 return std::nullopt;
932 }
933 return file_descriptors[fd]->socket;
934}
935
844s32 BSD::FindFreeFileDescriptorHandle() noexcept { 936s32 BSD::FindFreeFileDescriptorHandle() noexcept {
845 for (s32 fd = 0; fd < static_cast<s32>(file_descriptors.size()); ++fd) { 937 for (s32 fd = 0; fd < static_cast<s32>(file_descriptors.size()); ++fd) {
846 if (!file_descriptors[fd]) { 938 if (!file_descriptors[fd]) {
@@ -911,7 +1003,7 @@ BSD::BSD(Core::System& system_, const char* name)
911 {24, &BSD::Write, "Write"}, 1003 {24, &BSD::Write, "Write"},
912 {25, &BSD::Read, "Read"}, 1004 {25, &BSD::Read, "Read"},
913 {26, &BSD::Close, "Close"}, 1005 {26, &BSD::Close, "Close"},
914 {27, nullptr, "DuplicateSocket"}, 1006 {27, &BSD::DuplicateSocket, "DuplicateSocket"},
915 {28, nullptr, "GetResourceStatistics"}, 1007 {28, nullptr, "GetResourceStatistics"},
916 {29, nullptr, "RecvMMsg"}, 1008 {29, nullptr, "RecvMMsg"},
917 {30, nullptr, "SendMMsg"}, 1009 {30, nullptr, "SendMMsg"},
diff --git a/src/core/hle/service/sockets/bsd.h b/src/core/hle/service/sockets/bsd.h
index 30ae9c140..430edb97c 100644
--- a/src/core/hle/service/sockets/bsd.h
+++ b/src/core/hle/service/sockets/bsd.h
@@ -8,6 +8,7 @@
8#include <vector> 8#include <vector>
9 9
10#include "common/common_types.h" 10#include "common/common_types.h"
11#include "common/expected.h"
11#include "common/socket_types.h" 12#include "common/socket_types.h"
12#include "core/hle/service/service.h" 13#include "core/hle/service/service.h"
13#include "core/hle/service/sockets/sockets.h" 14#include "core/hle/service/sockets/sockets.h"
@@ -29,12 +30,19 @@ public:
29 explicit BSD(Core::System& system_, const char* name); 30 explicit BSD(Core::System& system_, const char* name);
30 ~BSD() override; 31 ~BSD() override;
31 32
33 // These methods are called from SSL; the first two are also called from
34 // this class for the corresponding IPC methods.
35 // On the real device, the SSL service makes IPC calls to this service.
36 Common::Expected<s32, Errno> DuplicateSocketImpl(s32 fd);
37 Errno CloseImpl(s32 fd);
38 std::optional<std::shared_ptr<Network::SocketBase>> GetSocket(s32 fd);
39
32private: 40private:
33 /// Maximum number of file descriptors 41 /// Maximum number of file descriptors
34 static constexpr size_t MAX_FD = 128; 42 static constexpr size_t MAX_FD = 128;
35 43
36 struct FileDescriptor { 44 struct FileDescriptor {
37 std::unique_ptr<Network::SocketBase> socket; 45 std::shared_ptr<Network::SocketBase> socket;
38 s32 flags = 0; 46 s32 flags = 0;
39 bool is_connection_based = false; 47 bool is_connection_based = false;
40 }; 48 };
@@ -138,6 +146,7 @@ private:
138 void Write(HLERequestContext& ctx); 146 void Write(HLERequestContext& ctx);
139 void Read(HLERequestContext& ctx); 147 void Read(HLERequestContext& ctx);
140 void Close(HLERequestContext& ctx); 148 void Close(HLERequestContext& ctx);
149 void DuplicateSocket(HLERequestContext& ctx);
141 void EventFd(HLERequestContext& ctx); 150 void EventFd(HLERequestContext& ctx);
142 151
143 template <typename Work> 152 template <typename Work>
@@ -153,6 +162,7 @@ private:
153 Errno GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer); 162 Errno GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer);
154 Errno ListenImpl(s32 fd, s32 backlog); 163 Errno ListenImpl(s32 fd, s32 backlog);
155 std::pair<s32, Errno> FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg); 164 std::pair<s32, Errno> FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg);
165 Errno GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval);
156 Errno SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval); 166 Errno SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval);
157 Errno ShutdownImpl(s32 fd, s32 how); 167 Errno ShutdownImpl(s32 fd, s32 how);
158 std::pair<s32, Errno> RecvImpl(s32 fd, u32 flags, std::vector<u8>& message); 168 std::pair<s32, Errno> RecvImpl(s32 fd, u32 flags, std::vector<u8>& message);
@@ -161,7 +171,6 @@ private:
161 std::pair<s32, Errno> SendImpl(s32 fd, u32 flags, std::span<const u8> message); 171 std::pair<s32, Errno> SendImpl(s32 fd, u32 flags, std::span<const u8> message);
162 std::pair<s32, Errno> SendToImpl(s32 fd, u32 flags, std::span<const u8> message, 172 std::pair<s32, Errno> SendToImpl(s32 fd, u32 flags, std::span<const u8> message,
163 std::span<const u8> addr); 173 std::span<const u8> addr);
164 Errno CloseImpl(s32 fd);
165 174
166 s32 FindFreeFileDescriptorHandle() noexcept; 175 s32 FindFreeFileDescriptorHandle() noexcept;
167 bool IsFileDescriptorValid(s32 fd) const noexcept; 176 bool IsFileDescriptorValid(s32 fd) const noexcept;
diff --git a/src/core/hle/service/sockets/nsd.cpp b/src/core/hle/service/sockets/nsd.cpp
index 6491a73be..0dfb0f166 100644
--- a/src/core/hle/service/sockets/nsd.cpp
+++ b/src/core/hle/service/sockets/nsd.cpp
@@ -1,10 +1,15 @@
1// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project 1// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later 2// SPDX-License-Identifier: GPL-2.0-or-later
3 3
4#include "core/hle/service/ipc_helpers.h"
4#include "core/hle/service/sockets/nsd.h" 5#include "core/hle/service/sockets/nsd.h"
5 6
7#include "common/string_util.h"
8
6namespace Service::Sockets { 9namespace Service::Sockets {
7 10
11constexpr Result ResultOverflow{ErrorModule::NSD, 6};
12
8NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, name} { 13NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, name} {
9 // clang-format off 14 // clang-format off
10 static const FunctionInfo functions[] = { 15 static const FunctionInfo functions[] = {
@@ -15,8 +20,8 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na
15 {13, nullptr, "DeleteSettings"}, 20 {13, nullptr, "DeleteSettings"},
16 {14, nullptr, "ImportSettings"}, 21 {14, nullptr, "ImportSettings"},
17 {15, nullptr, "SetChangeEnvironmentIdentifierDisabled"}, 22 {15, nullptr, "SetChangeEnvironmentIdentifierDisabled"},
18 {20, nullptr, "Resolve"}, 23 {20, &NSD::Resolve, "Resolve"},
19 {21, nullptr, "ResolveEx"}, 24 {21, &NSD::ResolveEx, "ResolveEx"},
20 {30, nullptr, "GetNasServiceSetting"}, 25 {30, nullptr, "GetNasServiceSetting"},
21 {31, nullptr, "GetNasServiceSettingEx"}, 26 {31, nullptr, "GetNasServiceSettingEx"},
22 {40, nullptr, "GetNasRequestFqdn"}, 27 {40, nullptr, "GetNasRequestFqdn"},
@@ -40,6 +45,55 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na
40 RegisterHandlers(functions); 45 RegisterHandlers(functions);
41} 46}
42 47
48static ResultVal<std::string> ResolveImpl(const std::string& fqdn_in) {
49 // The real implementation makes various substitutions.
50 // For now we just return the string as-is, which is good enough when not
51 // connecting to real Nintendo servers.
52 LOG_WARNING(Service, "(STUBBED) called, fqdn_in={}", fqdn_in);
53 return fqdn_in;
54}
55
56static Result ResolveCommon(const std::string& fqdn_in, std::array<char, 0x100>& fqdn_out) {
57 const auto res = ResolveImpl(fqdn_in);
58 if (res.Failed()) {
59 return res.Code();
60 }
61 if (res->size() >= fqdn_out.size()) {
62 return ResultOverflow;
63 }
64 std::memcpy(fqdn_out.data(), res->c_str(), res->size() + 1);
65 return ResultSuccess;
66}
67
68void NSD::Resolve(HLERequestContext& ctx) {
69 const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0));
70
71 std::array<char, 0x100> fqdn_out{};
72 const Result res = ResolveCommon(fqdn_in, fqdn_out);
73
74 ctx.WriteBuffer(fqdn_out);
75 IPC::ResponseBuilder rb{ctx, 2};
76 rb.Push(res);
77}
78
79void NSD::ResolveEx(HLERequestContext& ctx) {
80 const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0));
81
82 std::array<char, 0x100> fqdn_out;
83 const Result res = ResolveCommon(fqdn_in, fqdn_out);
84
85 if (res.IsError()) {
86 IPC::ResponseBuilder rb{ctx, 2};
87 rb.Push(res);
88 return;
89 }
90
91 ctx.WriteBuffer(fqdn_out);
92 IPC::ResponseBuilder rb{ctx, 4};
93 rb.Push(ResultSuccess);
94 rb.Push(ResultSuccess);
95}
96
43NSD::~NSD() = default; 97NSD::~NSD() = default;
44 98
45} // namespace Service::Sockets 99} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/nsd.h b/src/core/hle/service/sockets/nsd.h
index 5cc12b855..a7379a8a9 100644
--- a/src/core/hle/service/sockets/nsd.h
+++ b/src/core/hle/service/sockets/nsd.h
@@ -15,6 +15,10 @@ class NSD final : public ServiceFramework<NSD> {
15public: 15public:
16 explicit NSD(Core::System& system_, const char* name); 16 explicit NSD(Core::System& system_, const char* name);
17 ~NSD() override; 17 ~NSD() override;
18
19private:
20 void Resolve(HLERequestContext& ctx);
21 void ResolveEx(HLERequestContext& ctx);
18}; 22};
19 23
20} // namespace Service::Sockets 24} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/sfdnsres.cpp b/src/core/hle/service/sockets/sfdnsres.cpp
index 132dd5797..84cc79de8 100644
--- a/src/core/hle/service/sockets/sfdnsres.cpp
+++ b/src/core/hle/service/sockets/sfdnsres.cpp
@@ -10,27 +10,18 @@
10#include "core/core.h" 10#include "core/core.h"
11#include "core/hle/service/ipc_helpers.h" 11#include "core/hle/service/ipc_helpers.h"
12#include "core/hle/service/sockets/sfdnsres.h" 12#include "core/hle/service/sockets/sfdnsres.h"
13#include "core/hle/service/sockets/sockets.h"
14#include "core/hle/service/sockets/sockets_translate.h"
15#include "core/internal_network/network.h"
13#include "core/memory.h" 16#include "core/memory.h"
14 17
15#ifdef _WIN32
16#include <ws2tcpip.h>
17#elif YUZU_UNIX
18#include <arpa/inet.h>
19#include <netdb.h>
20#include <netinet/in.h>
21#include <sys/socket.h>
22#ifndef EAI_NODATA
23#define EAI_NODATA EAI_NONAME
24#endif
25#endif
26
27namespace Service::Sockets { 18namespace Service::Sockets {
28 19
29SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"} { 20SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"} {
30 static const FunctionInfo functions[] = { 21 static const FunctionInfo functions[] = {
31 {0, nullptr, "SetDnsAddressesPrivateRequest"}, 22 {0, nullptr, "SetDnsAddressesPrivateRequest"},
32 {1, nullptr, "GetDnsAddressPrivateRequest"}, 23 {1, nullptr, "GetDnsAddressPrivateRequest"},
33 {2, nullptr, "GetHostByNameRequest"}, 24 {2, &SFDNSRES::GetHostByNameRequest, "GetHostByNameRequest"},
34 {3, nullptr, "GetHostByAddrRequest"}, 25 {3, nullptr, "GetHostByAddrRequest"},
35 {4, nullptr, "GetHostStringErrorRequest"}, 26 {4, nullptr, "GetHostStringErrorRequest"},
36 {5, nullptr, "GetGaiStringErrorRequest"}, 27 {5, nullptr, "GetGaiStringErrorRequest"},
@@ -38,11 +29,11 @@ SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"
38 {7, nullptr, "GetNameInfoRequest"}, 29 {7, nullptr, "GetNameInfoRequest"},
39 {8, nullptr, "RequestCancelHandleRequest"}, 30 {8, nullptr, "RequestCancelHandleRequest"},
40 {9, nullptr, "CancelRequest"}, 31 {9, nullptr, "CancelRequest"},
41 {10, nullptr, "GetHostByNameRequestWithOptions"}, 32 {10, &SFDNSRES::GetHostByNameRequestWithOptions, "GetHostByNameRequestWithOptions"},
42 {11, nullptr, "GetHostByAddrRequestWithOptions"}, 33 {11, nullptr, "GetHostByAddrRequestWithOptions"},
43 {12, &SFDNSRES::GetAddrInfoRequestWithOptions, "GetAddrInfoRequestWithOptions"}, 34 {12, &SFDNSRES::GetAddrInfoRequestWithOptions, "GetAddrInfoRequestWithOptions"},
44 {13, nullptr, "GetNameInfoRequestWithOptions"}, 35 {13, nullptr, "GetNameInfoRequestWithOptions"},
45 {14, nullptr, "ResolverSetOptionRequest"}, 36 {14, &SFDNSRES::ResolverSetOptionRequest, "ResolverSetOptionRequest"},
46 {15, nullptr, "ResolverGetOptionRequest"}, 37 {15, nullptr, "ResolverGetOptionRequest"},
47 }; 38 };
48 RegisterHandlers(functions); 39 RegisterHandlers(functions);
@@ -59,188 +50,285 @@ enum class NetDbError : s32 {
59 NoData = 4, 50 NoData = 4,
60}; 51};
61 52
62static NetDbError AddrInfoErrorToNetDbError(s32 result) { 53static NetDbError GetAddrInfoErrorToNetDbError(GetAddrInfoError result) {
63 // Best effort guess to map errors 54 // These combinations have been verified on console (but are not
55 // exhaustive).
64 switch (result) { 56 switch (result) {
65 case 0: 57 case GetAddrInfoError::SUCCESS:
66 return NetDbError::Success; 58 return NetDbError::Success;
67 case EAI_AGAIN: 59 case GetAddrInfoError::AGAIN:
68 return NetDbError::TryAgain; 60 return NetDbError::TryAgain;
69 case EAI_NODATA: 61 case GetAddrInfoError::NODATA:
70 return NetDbError::NoData; 62 return NetDbError::HostNotFound;
63 case GetAddrInfoError::SERVICE:
64 return NetDbError::Success;
71 default: 65 default:
72 return NetDbError::HostNotFound; 66 return NetDbError::HostNotFound;
73 } 67 }
74} 68}
75 69
76static std::vector<u8> SerializeAddrInfo(const addrinfo* addrinfo, s32 result_code, 70static Errno GetAddrInfoErrorToErrno(GetAddrInfoError result) {
71 // These combinations have been verified on console (but are not
72 // exhaustive).
73 switch (result) {
74 case GetAddrInfoError::SUCCESS:
75 // Note: Sometimes a successful lookup sets errno to EADDRNOTAVAIL for
76 // some reason, but that doesn't seem useful to implement.
77 return Errno::SUCCESS;
78 case GetAddrInfoError::AGAIN:
79 return Errno::SUCCESS;
80 case GetAddrInfoError::NODATA:
81 return Errno::SUCCESS;
82 case GetAddrInfoError::SERVICE:
83 return Errno::INVAL;
84 default:
85 return Errno::SUCCESS;
86 }
87}
88
89template <typename T>
90static void Append(std::vector<u8>& vec, T t) {
91 const size_t offset = vec.size();
92 vec.resize(offset + sizeof(T));
93 std::memcpy(vec.data() + offset, &t, sizeof(T));
94}
95
96static void AppendNulTerminated(std::vector<u8>& vec, std::string_view str) {
97 const size_t offset = vec.size();
98 vec.resize(offset + str.size() + 1);
99 std::memmove(vec.data() + offset, str.data(), str.size());
100}
101
102// We implement gethostbyname using the host's getaddrinfo rather than the
103// host's gethostbyname, because it simplifies portability: e.g., getaddrinfo
104// behaves the same on Unix and Windows, unlike gethostbyname where Windows
105// doesn't implement h_errno.
106static std::vector<u8> SerializeAddrInfoAsHostEnt(const std::vector<Network::AddrInfo>& vec,
107 std::string_view host) {
108
109 std::vector<u8> data;
110 // h_name: use the input hostname (append nul-terminated)
111 AppendNulTerminated(data, host);
112 // h_aliases: leave empty
113
114 Append<u32_be>(data, 0); // count of h_aliases
115 // (If the count were nonzero, the aliases would be appended as nul-terminated here.)
116 Append<u16_be>(data, static_cast<u16>(Domain::INET)); // h_addrtype
117 Append<u16_be>(data, sizeof(Network::IPv4Address)); // h_length
118 // h_addr_list:
119 size_t count = vec.size();
120 ASSERT(count <= UINT32_MAX);
121 Append<u32_be>(data, static_cast<uint32_t>(count));
122 for (const Network::AddrInfo& addrinfo : vec) {
123 // On the Switch, this is passed through htonl despite already being
124 // big-endian, so it ends up as little-endian.
125 Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip));
126
127 LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host,
128 Network::IPv4AddressToString(addrinfo.addr.ip));
129 }
130 return data;
131}
132
133static std::pair<u32, GetAddrInfoError> GetHostByNameRequestImpl(HLERequestContext& ctx) {
134 struct InputParameters {
135 u8 use_nsd_resolve;
136 u32 cancel_handle;
137 u64 process_id;
138 };
139 static_assert(sizeof(InputParameters) == 0x10);
140
141 IPC::RequestParser rp{ctx};
142 const auto parameters = rp.PopRaw<InputParameters>();
143
144 LOG_WARNING(
145 Service,
146 "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}",
147 parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id);
148
149 const auto host_buffer = ctx.ReadBuffer(0);
150 const std::string host = Common::StringFromBuffer(host_buffer);
151 // For now, ignore options, which are in input buffer 1 for GetHostByNameRequestWithOptions.
152
153 auto res = Network::GetAddressInfo(host, /*service*/ std::nullopt);
154 if (!res.has_value()) {
155 return {0, Translate(res.error())};
156 }
157
158 const std::vector<u8> data = SerializeAddrInfoAsHostEnt(res.value(), host);
159 const u32 data_size = static_cast<u32>(data.size());
160 ctx.WriteBuffer(data, 0);
161
162 return {data_size, GetAddrInfoError::SUCCESS};
163}
164
165void SFDNSRES::GetHostByNameRequest(HLERequestContext& ctx) {
166 auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx);
167
168 struct OutputParameters {
169 NetDbError netdb_error;
170 Errno bsd_errno;
171 u32 data_size;
172 };
173 static_assert(sizeof(OutputParameters) == 0xc);
174
175 IPC::ResponseBuilder rb{ctx, 5};
176 rb.Push(ResultSuccess);
177 rb.PushRaw(OutputParameters{
178 .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err),
179 .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
180 .data_size = data_size,
181 });
182}
183
184void SFDNSRES::GetHostByNameRequestWithOptions(HLERequestContext& ctx) {
185 auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx);
186
187 struct OutputParameters {
188 u32 data_size;
189 NetDbError netdb_error;
190 Errno bsd_errno;
191 };
192 static_assert(sizeof(OutputParameters) == 0xc);
193
194 IPC::ResponseBuilder rb{ctx, 5};
195 rb.Push(ResultSuccess);
196 rb.PushRaw(OutputParameters{
197 .data_size = data_size,
198 .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err),
199 .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
200 });
201}
202
203static std::vector<u8> SerializeAddrInfo(const std::vector<Network::AddrInfo>& vec,
77 std::string_view host) { 204 std::string_view host) {
78 // Adapted from 205 // Adapted from
79 // https://github.com/switchbrew/libnx/blob/c5a9a909a91657a9818a3b7e18c9b91ff0cbb6e3/nx/source/runtime/resolver.c#L190 206 // https://github.com/switchbrew/libnx/blob/c5a9a909a91657a9818a3b7e18c9b91ff0cbb6e3/nx/source/runtime/resolver.c#L190
80 std::vector<u8> data; 207 std::vector<u8> data;
81 208
82 auto* current = addrinfo; 209 for (const Network::AddrInfo& addrinfo : vec) {
83 while (current != nullptr) { 210 // serialized addrinfo:
84 struct SerializedResponseHeader { 211 Append<u32_be>(data, 0xBEEFCAFE); // magic
85 u32 magic; 212 Append<u32_be>(data, 0); // ai_flags
86 s32 flags; 213 Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.family))); // ai_family
87 s32 family; 214 Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.socket_type))); // ai_socktype
88 s32 socket_type; 215 Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.protocol))); // ai_protocol
89 s32 protocol; 216 Append<u32_be>(data, sizeof(SockAddrIn)); // ai_addrlen
90 u32 address_length; 217 // ^ *not* sizeof(SerializedSockAddrIn), not that it matters since they're the same size
91 }; 218
92 static_assert(sizeof(SerializedResponseHeader) == 0x18, 219 // ai_addr:
93 "Response header size must be 0x18 bytes"); 220 Append<u16_be>(data, static_cast<u16>(Translate(addrinfo.addr.family))); // sin_family
94 221 // On the Switch, the following fields are passed through htonl despite
95 constexpr auto header_size = sizeof(SerializedResponseHeader); 222 // already being big-endian, so they end up as little-endian.
96 const auto addr_size = 223 Append<u16_le>(data, addrinfo.addr.portno); // sin_port
97 current->ai_addr && current->ai_addrlen > 0 ? current->ai_addrlen : 4; 224 Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip)); // sin_addr
98 const auto canonname_size = current->ai_canonname ? strlen(current->ai_canonname) + 1 : 1; 225 data.resize(data.size() + 8, 0); // sin_zero
99 226
100 const auto last_size = data.size(); 227 if (addrinfo.canon_name.has_value()) {
101 data.resize(last_size + header_size + addr_size + canonname_size); 228 AppendNulTerminated(data, *addrinfo.canon_name);
102
103 // Header in network byte order
104 SerializedResponseHeader header{};
105
106 constexpr auto HEADER_MAGIC = 0xBEEFCAFE;
107 header.magic = htonl(HEADER_MAGIC);
108 header.family = htonl(current->ai_family);
109 header.flags = htonl(current->ai_flags);
110 header.socket_type = htonl(current->ai_socktype);
111 header.protocol = htonl(current->ai_protocol);
112 header.address_length = current->ai_addr ? htonl((u32)current->ai_addrlen) : 0;
113
114 auto* header_ptr = data.data() + last_size;
115 std::memcpy(header_ptr, &header, header_size);
116
117 if (header.address_length == 0) {
118 std::memset(header_ptr + header_size, 0, 4);
119 } else {
120 switch (current->ai_family) {
121 case AF_INET: {
122 struct SockAddrIn {
123 s16 sin_family;
124 u16 sin_port;
125 u32 sin_addr;
126 u8 sin_zero[8];
127 };
128
129 SockAddrIn serialized_addr{};
130 const auto addr = *reinterpret_cast<sockaddr_in*>(current->ai_addr);
131 serialized_addr.sin_port = htons(addr.sin_port);
132 serialized_addr.sin_family = htons(addr.sin_family);
133 serialized_addr.sin_addr = htonl(addr.sin_addr.s_addr);
134 std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn));
135
136 char addr_string_buf[64]{};
137 inet_ntop(AF_INET, &addr.sin_addr, addr_string_buf, std::size(addr_string_buf));
138 LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, addr_string_buf);
139 break;
140 }
141 case AF_INET6: {
142 struct SockAddrIn6 {
143 s16 sin6_family;
144 u16 sin6_port;
145 u32 sin6_flowinfo;
146 u8 sin6_addr[16];
147 u32 sin6_scope_id;
148 };
149
150 SockAddrIn6 serialized_addr{};
151 const auto addr = *reinterpret_cast<sockaddr_in6*>(current->ai_addr);
152 serialized_addr.sin6_family = htons(addr.sin6_family);
153 serialized_addr.sin6_port = htons(addr.sin6_port);
154 serialized_addr.sin6_flowinfo = htonl(addr.sin6_flowinfo);
155 serialized_addr.sin6_scope_id = htonl(addr.sin6_scope_id);
156 std::memcpy(serialized_addr.sin6_addr, &addr.sin6_addr,
157 sizeof(SockAddrIn6::sin6_addr));
158 std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn6));
159
160 char addr_string_buf[64]{};
161 inet_ntop(AF_INET6, &addr.sin6_addr, addr_string_buf, std::size(addr_string_buf));
162 LOG_INFO(Service, "Resolved host '{}' to IPv6 address {}", host, addr_string_buf);
163 break;
164 }
165 default:
166 std::memcpy(header_ptr + header_size, current->ai_addr, addr_size);
167 break;
168 }
169 }
170 if (current->ai_canonname) {
171 std::memcpy(header_ptr + addr_size, current->ai_canonname, canonname_size);
172 } else { 229 } else {
173 *(header_ptr + header_size + addr_size) = 0; 230 data.push_back(0);
174 } 231 }
175 232
176 current = current->ai_next; 233 LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host,
234 Network::IPv4AddressToString(addrinfo.addr.ip));
177 } 235 }
178 236
179 // 4-byte sentinel value 237 data.resize(data.size() + 4, 0); // 4-byte sentinel value
180 data.push_back(0);
181 data.push_back(0);
182 data.push_back(0);
183 data.push_back(0);
184 238
185 return data; 239 return data;
186} 240}
187 241
188static std::pair<u32, s32> GetAddrInfoRequestImpl(HLERequestContext& ctx) { 242static std::pair<u32, GetAddrInfoError> GetAddrInfoRequestImpl(HLERequestContext& ctx) {
189 struct Parameters { 243 struct InputParameters {
190 u8 use_nsd_resolve; 244 u8 use_nsd_resolve;
191 u32 unknown; 245 u32 cancel_handle;
192 u64 process_id; 246 u64 process_id;
193 }; 247 };
248 static_assert(sizeof(InputParameters) == 0x10);
194 249
195 IPC::RequestParser rp{ctx}; 250 IPC::RequestParser rp{ctx};
196 const auto parameters = rp.PopRaw<Parameters>(); 251 const auto parameters = rp.PopRaw<InputParameters>();
252
253 LOG_WARNING(
254 Service,
255 "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}",
256 parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id);
197 257
198 LOG_WARNING(Service, 258 // TODO: If use_nsd_resolve is true, pass the name through NSD::Resolve
199 "called with ignored parameters: use_nsd_resolve={}, unknown={}, process_id={}", 259 // before looking up.
200 parameters.use_nsd_resolve, parameters.unknown, parameters.process_id);
201 260
202 const auto host_buffer = ctx.ReadBuffer(0); 261 const auto host_buffer = ctx.ReadBuffer(0);
203 const std::string host = Common::StringFromBuffer(host_buffer); 262 const std::string host = Common::StringFromBuffer(host_buffer);
204 263
205 const auto service_buffer = ctx.ReadBuffer(1); 264 std::optional<std::string> service = std::nullopt;
206 const std::string service = Common::StringFromBuffer(service_buffer); 265 if (ctx.CanReadBuffer(1)) {
207 266 const std::span<const u8> service_buffer = ctx.ReadBuffer(1);
208 addrinfo* addrinfo; 267 service = Common::StringFromBuffer(service_buffer);
209 // Pass null for hints. Serialized hints are also passed in a buffer, but are ignored for now 268 }
210 s32 result_code = getaddrinfo(host.c_str(), service.c_str(), nullptr, &addrinfo);
211 269
212 u32 data_size = 0; 270 // Serialized hints are also passed in a buffer, but are ignored for now.
213 if (result_code == 0 && addrinfo != nullptr) {
214 const std::vector<u8>& data = SerializeAddrInfo(addrinfo, result_code, host);
215 data_size = static_cast<u32>(data.size());
216 freeaddrinfo(addrinfo);
217 271
218 ctx.WriteBuffer(data, 0); 272 auto res = Network::GetAddressInfo(host, service);
273 if (!res.has_value()) {
274 return {0, Translate(res.error())};
219 } 275 }
220 276
221 return std::make_pair(data_size, result_code); 277 const std::vector<u8> data = SerializeAddrInfo(res.value(), host);
278 const u32 data_size = static_cast<u32>(data.size());
279 ctx.WriteBuffer(data, 0);
280
281 return {data_size, GetAddrInfoError::SUCCESS};
222} 282}
223 283
224void SFDNSRES::GetAddrInfoRequest(HLERequestContext& ctx) { 284void SFDNSRES::GetAddrInfoRequest(HLERequestContext& ctx) {
225 auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx); 285 auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx);
286
287 struct OutputParameters {
288 Errno bsd_errno;
289 GetAddrInfoError gai_error;
290 u32 data_size;
291 };
292 static_assert(sizeof(OutputParameters) == 0xc);
226 293
227 IPC::ResponseBuilder rb{ctx, 4}; 294 IPC::ResponseBuilder rb{ctx, 5};
228 rb.Push(ResultSuccess); 295 rb.Push(ResultSuccess);
229 rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode 296 rb.PushRaw(OutputParameters{
230 rb.Push(result_code); // errno 297 .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
231 rb.Push(data_size); // serialized size 298 .gai_error = emu_gai_err,
299 .data_size = data_size,
300 });
232} 301}
233 302
234void SFDNSRES::GetAddrInfoRequestWithOptions(HLERequestContext& ctx) { 303void SFDNSRES::GetAddrInfoRequestWithOptions(HLERequestContext& ctx) {
235 // Additional options are ignored 304 // Additional options are ignored
236 auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx); 305 auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx);
306
307 struct OutputParameters {
308 u32 data_size;
309 GetAddrInfoError gai_error;
310 NetDbError netdb_error;
311 Errno bsd_errno;
312 };
313 static_assert(sizeof(OutputParameters) == 0x10);
314
315 IPC::ResponseBuilder rb{ctx, 6};
316 rb.Push(ResultSuccess);
317 rb.PushRaw(OutputParameters{
318 .data_size = data_size,
319 .gai_error = emu_gai_err,
320 .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err),
321 .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
322 });
323}
324
325void SFDNSRES::ResolverSetOptionRequest(HLERequestContext& ctx) {
326 LOG_WARNING(Service, "(STUBBED) called");
327
328 IPC::ResponseBuilder rb{ctx, 3};
237 329
238 IPC::ResponseBuilder rb{ctx, 5};
239 rb.Push(ResultSuccess); 330 rb.Push(ResultSuccess);
240 rb.Push(data_size); // serialized size 331 rb.Push<s32>(0); // bsd errno
241 rb.Push(result_code); // errno
242 rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode
243 rb.Push(0);
244} 332}
245 333
246} // namespace Service::Sockets 334} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/sfdnsres.h b/src/core/hle/service/sockets/sfdnsres.h
index 18e3cd60c..d99a9d560 100644
--- a/src/core/hle/service/sockets/sfdnsres.h
+++ b/src/core/hle/service/sockets/sfdnsres.h
@@ -17,8 +17,11 @@ public:
17 ~SFDNSRES() override; 17 ~SFDNSRES() override;
18 18
19private: 19private:
20 void GetHostByNameRequest(HLERequestContext& ctx);
21 void GetHostByNameRequestWithOptions(HLERequestContext& ctx);
20 void GetAddrInfoRequest(HLERequestContext& ctx); 22 void GetAddrInfoRequest(HLERequestContext& ctx);
21 void GetAddrInfoRequestWithOptions(HLERequestContext& ctx); 23 void GetAddrInfoRequestWithOptions(HLERequestContext& ctx);
24 void ResolverSetOptionRequest(HLERequestContext& ctx);
22}; 25};
23 26
24} // namespace Service::Sockets 27} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/sockets.h b/src/core/hle/service/sockets/sockets.h
index acd2dae7b..77426c46e 100644
--- a/src/core/hle/service/sockets/sockets.h
+++ b/src/core/hle/service/sockets/sockets.h
@@ -22,13 +22,35 @@ enum class Errno : u32 {
22 CONNRESET = 104, 22 CONNRESET = 104,
23 NOTCONN = 107, 23 NOTCONN = 107,
24 TIMEDOUT = 110, 24 TIMEDOUT = 110,
25 INPROGRESS = 115,
26};
27
28enum class GetAddrInfoError : s32 {
29 SUCCESS = 0,
30 ADDRFAMILY = 1,
31 AGAIN = 2,
32 BADFLAGS = 3,
33 FAIL = 4,
34 FAMILY = 5,
35 MEMORY = 6,
36 NODATA = 7,
37 NONAME = 8,
38 SERVICE = 9,
39 SOCKTYPE = 10,
40 SYSTEM = 11,
41 BADHINTS = 12,
42 PROTOCOL = 13,
43 OVERFLOW_ = 14, // avoid name collision with Windows macro
44 OTHER = 15,
25}; 45};
26 46
27enum class Domain : u32 { 47enum class Domain : u32 {
48 Unspecified = 0,
28 INET = 2, 49 INET = 2,
29}; 50};
30 51
31enum class Type : u32 { 52enum class Type : u32 {
53 Unspecified = 0,
32 STREAM = 1, 54 STREAM = 1,
33 DGRAM = 2, 55 DGRAM = 2,
34 RAW = 3, 56 RAW = 3,
@@ -36,12 +58,16 @@ enum class Type : u32 {
36}; 58};
37 59
38enum class Protocol : u32 { 60enum class Protocol : u32 {
39 UNSPECIFIED = 0, 61 Unspecified = 0,
40 ICMP = 1, 62 ICMP = 1,
41 TCP = 6, 63 TCP = 6,
42 UDP = 17, 64 UDP = 17,
43}; 65};
44 66
67enum class SocketLevel : u32 {
68 SOCKET = 0xffff, // i.e. SOL_SOCKET
69};
70
45enum class OptName : u32 { 71enum class OptName : u32 {
46 REUSEADDR = 0x4, 72 REUSEADDR = 0x4,
47 KEEPALIVE = 0x8, 73 KEEPALIVE = 0x8,
@@ -51,6 +77,8 @@ enum class OptName : u32 {
51 RCVBUF = 0x1002, 77 RCVBUF = 0x1002,
52 SNDTIMEO = 0x1005, 78 SNDTIMEO = 0x1005,
53 RCVTIMEO = 0x1006, 79 RCVTIMEO = 0x1006,
80 ERROR_ = 0x1007, // avoid name collision with Windows macro
81 NOSIGPIPE = 0x800, // at least according to libnx
54}; 82};
55 83
56enum class ShutdownHow : s32 { 84enum class ShutdownHow : s32 {
@@ -80,6 +108,9 @@ enum class PollEvents : u16 {
80 Err = 1 << 3, 108 Err = 1 << 3,
81 Hup = 1 << 4, 109 Hup = 1 << 4,
82 Nval = 1 << 5, 110 Nval = 1 << 5,
111 RdNorm = 1 << 6,
112 RdBand = 1 << 7,
113 WrBand = 1 << 8,
83}; 114};
84 115
85DECLARE_ENUM_FLAG_OPERATORS(PollEvents); 116DECLARE_ENUM_FLAG_OPERATORS(PollEvents);
diff --git a/src/core/hle/service/sockets/sockets_translate.cpp b/src/core/hle/service/sockets/sockets_translate.cpp
index 594e58f90..2f9a0e39c 100644
--- a/src/core/hle/service/sockets/sockets_translate.cpp
+++ b/src/core/hle/service/sockets/sockets_translate.cpp
@@ -29,6 +29,8 @@ Errno Translate(Network::Errno value) {
29 return Errno::TIMEDOUT; 29 return Errno::TIMEDOUT;
30 case Network::Errno::CONNRESET: 30 case Network::Errno::CONNRESET:
31 return Errno::CONNRESET; 31 return Errno::CONNRESET;
32 case Network::Errno::INPROGRESS:
33 return Errno::INPROGRESS;
32 default: 34 default:
33 UNIMPLEMENTED_MSG("Unimplemented errno={}", value); 35 UNIMPLEMENTED_MSG("Unimplemented errno={}", value);
34 return Errno::SUCCESS; 36 return Errno::SUCCESS;
@@ -39,8 +41,50 @@ std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value) {
39 return {value.first, Translate(value.second)}; 41 return {value.first, Translate(value.second)};
40} 42}
41 43
44GetAddrInfoError Translate(Network::GetAddrInfoError error) {
45 switch (error) {
46 case Network::GetAddrInfoError::SUCCESS:
47 return GetAddrInfoError::SUCCESS;
48 case Network::GetAddrInfoError::ADDRFAMILY:
49 return GetAddrInfoError::ADDRFAMILY;
50 case Network::GetAddrInfoError::AGAIN:
51 return GetAddrInfoError::AGAIN;
52 case Network::GetAddrInfoError::BADFLAGS:
53 return GetAddrInfoError::BADFLAGS;
54 case Network::GetAddrInfoError::FAIL:
55 return GetAddrInfoError::FAIL;
56 case Network::GetAddrInfoError::FAMILY:
57 return GetAddrInfoError::FAMILY;
58 case Network::GetAddrInfoError::MEMORY:
59 return GetAddrInfoError::MEMORY;
60 case Network::GetAddrInfoError::NODATA:
61 return GetAddrInfoError::NODATA;
62 case Network::GetAddrInfoError::NONAME:
63 return GetAddrInfoError::NONAME;
64 case Network::GetAddrInfoError::SERVICE:
65 return GetAddrInfoError::SERVICE;
66 case Network::GetAddrInfoError::SOCKTYPE:
67 return GetAddrInfoError::SOCKTYPE;
68 case Network::GetAddrInfoError::SYSTEM:
69 return GetAddrInfoError::SYSTEM;
70 case Network::GetAddrInfoError::BADHINTS:
71 return GetAddrInfoError::BADHINTS;
72 case Network::GetAddrInfoError::PROTOCOL:
73 return GetAddrInfoError::PROTOCOL;
74 case Network::GetAddrInfoError::OVERFLOW_:
75 return GetAddrInfoError::OVERFLOW_;
76 case Network::GetAddrInfoError::OTHER:
77 return GetAddrInfoError::OTHER;
78 default:
79 UNIMPLEMENTED_MSG("Unimplemented GetAddrInfoError={}", error);
80 return GetAddrInfoError::OTHER;
81 }
82}
83
42Network::Domain Translate(Domain domain) { 84Network::Domain Translate(Domain domain) {
43 switch (domain) { 85 switch (domain) {
86 case Domain::Unspecified:
87 return Network::Domain::Unspecified;
44 case Domain::INET: 88 case Domain::INET:
45 return Network::Domain::INET; 89 return Network::Domain::INET;
46 default: 90 default:
@@ -51,6 +95,8 @@ Network::Domain Translate(Domain domain) {
51 95
52Domain Translate(Network::Domain domain) { 96Domain Translate(Network::Domain domain) {
53 switch (domain) { 97 switch (domain) {
98 case Network::Domain::Unspecified:
99 return Domain::Unspecified;
54 case Network::Domain::INET: 100 case Network::Domain::INET:
55 return Domain::INET; 101 return Domain::INET;
56 default: 102 default:
@@ -61,39 +107,69 @@ Domain Translate(Network::Domain domain) {
61 107
62Network::Type Translate(Type type) { 108Network::Type Translate(Type type) {
63 switch (type) { 109 switch (type) {
110 case Type::Unspecified:
111 return Network::Type::Unspecified;
64 case Type::STREAM: 112 case Type::STREAM:
65 return Network::Type::STREAM; 113 return Network::Type::STREAM;
66 case Type::DGRAM: 114 case Type::DGRAM:
67 return Network::Type::DGRAM; 115 return Network::Type::DGRAM;
116 case Type::RAW:
117 return Network::Type::RAW;
118 case Type::SEQPACKET:
119 return Network::Type::SEQPACKET;
68 default: 120 default:
69 UNIMPLEMENTED_MSG("Unimplemented type={}", type); 121 UNIMPLEMENTED_MSG("Unimplemented type={}", type);
70 return Network::Type{}; 122 return Network::Type{};
71 } 123 }
72} 124}
73 125
74Network::Protocol Translate(Type type, Protocol protocol) { 126Type Translate(Network::Type type) {
127 switch (type) {
128 case Network::Type::Unspecified:
129 return Type::Unspecified;
130 case Network::Type::STREAM:
131 return Type::STREAM;
132 case Network::Type::DGRAM:
133 return Type::DGRAM;
134 case Network::Type::RAW:
135 return Type::RAW;
136 case Network::Type::SEQPACKET:
137 return Type::SEQPACKET;
138 default:
139 UNIMPLEMENTED_MSG("Unimplemented type={}", type);
140 return Type{};
141 }
142}
143
144Network::Protocol Translate(Protocol protocol) {
75 switch (protocol) { 145 switch (protocol) {
76 case Protocol::UNSPECIFIED: 146 case Protocol::Unspecified:
77 LOG_WARNING(Service, "Unspecified protocol, assuming protocol from type"); 147 return Network::Protocol::Unspecified;
78 switch (type) {
79 case Type::DGRAM:
80 return Network::Protocol::UDP;
81 case Type::STREAM:
82 return Network::Protocol::TCP;
83 default:
84 return Network::Protocol::TCP;
85 }
86 case Protocol::TCP: 148 case Protocol::TCP:
87 return Network::Protocol::TCP; 149 return Network::Protocol::TCP;
88 case Protocol::UDP: 150 case Protocol::UDP:
89 return Network::Protocol::UDP; 151 return Network::Protocol::UDP;
90 default: 152 default:
91 UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); 153 UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
92 return Network::Protocol::TCP; 154 return Network::Protocol::Unspecified;
155 }
156}
157
158Protocol Translate(Network::Protocol protocol) {
159 switch (protocol) {
160 case Network::Protocol::Unspecified:
161 return Protocol::Unspecified;
162 case Network::Protocol::TCP:
163 return Protocol::TCP;
164 case Network::Protocol::UDP:
165 return Protocol::UDP;
166 default:
167 UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
168 return Protocol::Unspecified;
93 } 169 }
94} 170}
95 171
96Network::PollEvents TranslatePollEventsToHost(PollEvents flags) { 172Network::PollEvents Translate(PollEvents flags) {
97 Network::PollEvents result{}; 173 Network::PollEvents result{};
98 const auto translate = [&result, &flags](PollEvents from, Network::PollEvents to) { 174 const auto translate = [&result, &flags](PollEvents from, Network::PollEvents to) {
99 if (True(flags & from)) { 175 if (True(flags & from)) {
@@ -107,12 +183,15 @@ Network::PollEvents TranslatePollEventsToHost(PollEvents flags) {
107 translate(PollEvents::Err, Network::PollEvents::Err); 183 translate(PollEvents::Err, Network::PollEvents::Err);
108 translate(PollEvents::Hup, Network::PollEvents::Hup); 184 translate(PollEvents::Hup, Network::PollEvents::Hup);
109 translate(PollEvents::Nval, Network::PollEvents::Nval); 185 translate(PollEvents::Nval, Network::PollEvents::Nval);
186 translate(PollEvents::RdNorm, Network::PollEvents::RdNorm);
187 translate(PollEvents::RdBand, Network::PollEvents::RdBand);
188 translate(PollEvents::WrBand, Network::PollEvents::WrBand);
110 189
111 UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags); 190 UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags);
112 return result; 191 return result;
113} 192}
114 193
115PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) { 194PollEvents Translate(Network::PollEvents flags) {
116 PollEvents result{}; 195 PollEvents result{};
117 const auto translate = [&result, &flags](Network::PollEvents from, PollEvents to) { 196 const auto translate = [&result, &flags](Network::PollEvents from, PollEvents to) {
118 if (True(flags & from)) { 197 if (True(flags & from)) {
@@ -127,13 +206,18 @@ PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) {
127 translate(Network::PollEvents::Err, PollEvents::Err); 206 translate(Network::PollEvents::Err, PollEvents::Err);
128 translate(Network::PollEvents::Hup, PollEvents::Hup); 207 translate(Network::PollEvents::Hup, PollEvents::Hup);
129 translate(Network::PollEvents::Nval, PollEvents::Nval); 208 translate(Network::PollEvents::Nval, PollEvents::Nval);
209 translate(Network::PollEvents::RdNorm, PollEvents::RdNorm);
210 translate(Network::PollEvents::RdBand, PollEvents::RdBand);
211 translate(Network::PollEvents::WrBand, PollEvents::WrBand);
130 212
131 UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags); 213 UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags);
132 return result; 214 return result;
133} 215}
134 216
135Network::SockAddrIn Translate(SockAddrIn value) { 217Network::SockAddrIn Translate(SockAddrIn value) {
136 ASSERT(value.len == 0 || value.len == sizeof(value)); 218 // Note: 6 is incorrect, but can be passed by homebrew (because libnx sets
219 // sin_len to 6 when deserializing getaddrinfo results).
220 ASSERT(value.len == 0 || value.len == sizeof(value) || value.len == 6);
137 221
138 return { 222 return {
139 .family = Translate(static_cast<Domain>(value.family)), 223 .family = Translate(static_cast<Domain>(value.family)),
diff --git a/src/core/hle/service/sockets/sockets_translate.h b/src/core/hle/service/sockets/sockets_translate.h
index c93291d3e..694868b37 100644
--- a/src/core/hle/service/sockets/sockets_translate.h
+++ b/src/core/hle/service/sockets/sockets_translate.h
@@ -17,6 +17,9 @@ Errno Translate(Network::Errno value);
17/// Translate abstract return value errno pair to guest return value errno pair 17/// Translate abstract return value errno pair to guest return value errno pair
18std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value); 18std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value);
19 19
20/// Translate abstract getaddrinfo error to guest getaddrinfo error
21GetAddrInfoError Translate(Network::GetAddrInfoError value);
22
20/// Translate guest domain to abstract domain 23/// Translate guest domain to abstract domain
21Network::Domain Translate(Domain domain); 24Network::Domain Translate(Domain domain);
22 25
@@ -26,14 +29,20 @@ Domain Translate(Network::Domain domain);
26/// Translate guest type to abstract type 29/// Translate guest type to abstract type
27Network::Type Translate(Type type); 30Network::Type Translate(Type type);
28 31
32/// Translate abstract type to guest type
33Type Translate(Network::Type type);
34
29/// Translate guest protocol to abstract protocol 35/// Translate guest protocol to abstract protocol
30Network::Protocol Translate(Type type, Protocol protocol); 36Network::Protocol Translate(Protocol protocol);
31 37
32/// Translate abstract poll event flags to guest poll event flags 38/// Translate abstract protocol to guest protocol
33Network::PollEvents TranslatePollEventsToHost(PollEvents flags); 39Protocol Translate(Network::Protocol protocol);
34 40
35/// Translate guest poll event flags to abstract poll event flags 41/// Translate guest poll event flags to abstract poll event flags
36PollEvents TranslatePollEventsToGuest(Network::PollEvents flags); 42Network::PollEvents Translate(PollEvents flags);
43
44/// Translate abstract poll event flags to guest poll event flags
45PollEvents Translate(Network::PollEvents flags);
37 46
38/// Translate guest socket address structure to abstract socket address structure 47/// Translate guest socket address structure to abstract socket address structure
39Network::SockAddrIn Translate(SockAddrIn value); 48Network::SockAddrIn Translate(SockAddrIn value);
diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp
index 2b99dd7ac..9c96f9763 100644
--- a/src/core/hle/service/ssl/ssl.cpp
+++ b/src/core/hle/service/ssl/ssl.cpp
@@ -1,10 +1,18 @@
1// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project 1// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later 2// SPDX-License-Identifier: GPL-2.0-or-later
3 3
4#include "common/string_util.h"
5
6#include "core/core.h"
4#include "core/hle/service/ipc_helpers.h" 7#include "core/hle/service/ipc_helpers.h"
5#include "core/hle/service/server_manager.h" 8#include "core/hle/service/server_manager.h"
6#include "core/hle/service/service.h" 9#include "core/hle/service/service.h"
10#include "core/hle/service/sm/sm.h"
11#include "core/hle/service/sockets/bsd.h"
7#include "core/hle/service/ssl/ssl.h" 12#include "core/hle/service/ssl/ssl.h"
13#include "core/hle/service/ssl/ssl_backend.h"
14#include "core/internal_network/network.h"
15#include "core/internal_network/sockets.h"
8 16
9namespace Service::SSL { 17namespace Service::SSL {
10 18
@@ -20,6 +28,18 @@ enum class ContextOption : u32 {
20 CrlImportDateCheckEnable = 1, 28 CrlImportDateCheckEnable = 1,
21}; 29};
22 30
31// This is nn::ssl::Connection::IoMode
32enum class IoMode : u32 {
33 Blocking = 1,
34 NonBlocking = 2,
35};
36
37// This is nn::ssl::sf::OptionType
38enum class OptionType : u32 {
39 DoNotCloseSocket = 0,
40 GetServerCertChain = 1,
41};
42
23// This is nn::ssl::sf::SslVersion 43// This is nn::ssl::sf::SslVersion
24struct SslVersion { 44struct SslVersion {
25 union { 45 union {
@@ -34,35 +54,42 @@ struct SslVersion {
34 }; 54 };
35}; 55};
36 56
57struct SslContextSharedData {
58 u32 connection_count = 0;
59};
60
37class ISslConnection final : public ServiceFramework<ISslConnection> { 61class ISslConnection final : public ServiceFramework<ISslConnection> {
38public: 62public:
39 explicit ISslConnection(Core::System& system_, SslVersion version) 63 explicit ISslConnection(Core::System& system_in, SslVersion ssl_version_in,
40 : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} { 64 std::shared_ptr<SslContextSharedData>& shared_data_in,
65 std::unique_ptr<SSLConnectionBackend>&& backend_in)
66 : ServiceFramework{system_in, "ISslConnection"}, ssl_version{ssl_version_in},
67 shared_data{shared_data_in}, backend{std::move(backend_in)} {
41 // clang-format off 68 // clang-format off
42 static const FunctionInfo functions[] = { 69 static const FunctionInfo functions[] = {
43 {0, nullptr, "SetSocketDescriptor"}, 70 {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"},
44 {1, nullptr, "SetHostName"}, 71 {1, &ISslConnection::SetHostName, "SetHostName"},
45 {2, nullptr, "SetVerifyOption"}, 72 {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"},
46 {3, nullptr, "SetIoMode"}, 73 {3, &ISslConnection::SetIoMode, "SetIoMode"},
47 {4, nullptr, "GetSocketDescriptor"}, 74 {4, nullptr, "GetSocketDescriptor"},
48 {5, nullptr, "GetHostName"}, 75 {5, nullptr, "GetHostName"},
49 {6, nullptr, "GetVerifyOption"}, 76 {6, nullptr, "GetVerifyOption"},
50 {7, nullptr, "GetIoMode"}, 77 {7, nullptr, "GetIoMode"},
51 {8, nullptr, "DoHandshake"}, 78 {8, &ISslConnection::DoHandshake, "DoHandshake"},
52 {9, nullptr, "DoHandshakeGetServerCert"}, 79 {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"},
53 {10, nullptr, "Read"}, 80 {10, &ISslConnection::Read, "Read"},
54 {11, nullptr, "Write"}, 81 {11, &ISslConnection::Write, "Write"},
55 {12, nullptr, "Pending"}, 82 {12, &ISslConnection::Pending, "Pending"},
56 {13, nullptr, "Peek"}, 83 {13, nullptr, "Peek"},
57 {14, nullptr, "Poll"}, 84 {14, nullptr, "Poll"},
58 {15, nullptr, "GetVerifyCertError"}, 85 {15, nullptr, "GetVerifyCertError"},
59 {16, nullptr, "GetNeededServerCertBufferSize"}, 86 {16, nullptr, "GetNeededServerCertBufferSize"},
60 {17, nullptr, "SetSessionCacheMode"}, 87 {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"},
61 {18, nullptr, "GetSessionCacheMode"}, 88 {18, nullptr, "GetSessionCacheMode"},
62 {19, nullptr, "FlushSessionCache"}, 89 {19, nullptr, "FlushSessionCache"},
63 {20, nullptr, "SetRenegotiationMode"}, 90 {20, nullptr, "SetRenegotiationMode"},
64 {21, nullptr, "GetRenegotiationMode"}, 91 {21, nullptr, "GetRenegotiationMode"},
65 {22, nullptr, "SetOption"}, 92 {22, &ISslConnection::SetOption, "SetOption"},
66 {23, nullptr, "GetOption"}, 93 {23, nullptr, "GetOption"},
67 {24, nullptr, "GetVerifyCertErrors"}, 94 {24, nullptr, "GetVerifyCertErrors"},
68 {25, nullptr, "GetCipherInfo"}, 95 {25, nullptr, "GetCipherInfo"},
@@ -80,21 +107,299 @@ public:
80 // clang-format on 107 // clang-format on
81 108
82 RegisterHandlers(functions); 109 RegisterHandlers(functions);
110
111 shared_data->connection_count++;
112 }
113
114 ~ISslConnection() {
115 shared_data->connection_count--;
116 if (fd_to_close.has_value()) {
117 const s32 fd = *fd_to_close;
118 if (!do_not_close_socket) {
119 LOG_ERROR(Service_SSL,
120 "do_not_close_socket was changed after setting socket; is this right?");
121 } else {
122 auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
123 if (bsd) {
124 auto err = bsd->CloseImpl(fd);
125 if (err != Service::Sockets::Errno::SUCCESS) {
126 LOG_ERROR(Service_SSL, "Failed to close duplicated socket: {}", err);
127 }
128 }
129 }
130 }
83 } 131 }
84 132
85private: 133private:
86 SslVersion ssl_version; 134 SslVersion ssl_version;
135 std::shared_ptr<SslContextSharedData> shared_data;
136 std::unique_ptr<SSLConnectionBackend> backend;
137 std::optional<int> fd_to_close;
138 bool do_not_close_socket = false;
139 bool get_server_cert_chain = false;
140 std::shared_ptr<Network::SocketBase> socket;
141 bool did_set_host_name = false;
142 bool did_handshake = false;
143
144 ResultVal<s32> SetSocketDescriptorImpl(s32 fd) {
145 LOG_DEBUG(Service_SSL, "called, fd={}", fd);
146 ASSERT(!did_handshake);
147 auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
148 ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; });
149 s32 ret_fd;
150 // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor
151 if (do_not_close_socket) {
152 auto res = bsd->DuplicateSocketImpl(fd);
153 if (!res.has_value()) {
154 LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd);
155 return ResultInvalidSocket;
156 }
157 fd = *res;
158 fd_to_close = fd;
159 ret_fd = fd;
160 } else {
161 ret_fd = -1;
162 }
163 std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd);
164 if (!sock.has_value()) {
165 LOG_ERROR(Service_SSL, "invalid socket fd {}", fd);
166 return ResultInvalidSocket;
167 }
168 socket = std::move(*sock);
169 backend->SetSocket(socket);
170 return ret_fd;
171 }
172
173 Result SetHostNameImpl(const std::string& hostname) {
174 LOG_DEBUG(Service_SSL, "called. hostname={}", hostname);
175 ASSERT(!did_handshake);
176 Result res = backend->SetHostName(hostname);
177 if (res == ResultSuccess) {
178 did_set_host_name = true;
179 }
180 return res;
181 }
182
183 Result SetVerifyOptionImpl(u32 option) {
184 ASSERT(!did_handshake);
185 LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option);
186 return ResultSuccess;
187 }
188
189 Result SetIoModeImpl(u32 input_mode) {
190 auto mode = static_cast<IoMode>(input_mode);
191 ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking);
192 ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; });
193
194 const bool non_block = mode == IoMode::NonBlocking;
195 const Network::Errno error = socket->SetNonBlock(non_block);
196 if (error != Network::Errno::SUCCESS) {
197 LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block);
198 }
199 return ResultSuccess;
200 }
201
202 Result SetSessionCacheModeImpl(u32 mode) {
203 ASSERT(!did_handshake);
204 LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode);
205 return ResultSuccess;
206 }
207
208 Result DoHandshakeImpl() {
209 ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; });
210 ASSERT_OR_EXECUTE_MSG(
211 did_set_host_name, { return ResultInternalError; },
212 "Expected SetHostName before DoHandshake");
213 Result res = backend->DoHandshake();
214 did_handshake = res.IsSuccess();
215 return res;
216 }
217
218 std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) {
219 struct Header {
220 u64 magic;
221 u32 count;
222 u32 pad;
223 };
224 struct EntryHeader {
225 u32 size;
226 u32 offset;
227 };
228 if (!get_server_cert_chain) {
229 // Just return the first one, unencoded.
230 ASSERT_OR_EXECUTE_MSG(
231 !certs.empty(), { return {}; }, "Should be at least one server cert");
232 return certs[0];
233 }
234 std::vector<u8> ret;
235 Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0};
236 ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1));
237 size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader);
238 for (auto& cert : certs) {
239 EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)};
240 data_offset += cert.size();
241 ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header),
242 reinterpret_cast<u8*>(&entry_header + 1));
243 }
244 for (auto& cert : certs) {
245 ret.insert(ret.end(), cert.begin(), cert.end());
246 }
247 return ret;
248 }
249
250 ResultVal<std::vector<u8>> ReadImpl(size_t size) {
251 ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
252 std::vector<u8> res(size);
253 ResultVal<size_t> actual = backend->Read(res);
254 if (actual.Failed()) {
255 return actual.Code();
256 }
257 res.resize(*actual);
258 return res;
259 }
260
261 ResultVal<size_t> WriteImpl(std::span<const u8> data) {
262 ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
263 return backend->Write(data);
264 }
265
266 ResultVal<s32> PendingImpl() {
267 LOG_WARNING(Service_SSL, "(STUBBED) called.");
268 return 0;
269 }
270
271 void SetSocketDescriptor(HLERequestContext& ctx) {
272 IPC::RequestParser rp{ctx};
273 const s32 fd = rp.Pop<s32>();
274 const ResultVal<s32> res = SetSocketDescriptorImpl(fd);
275 IPC::ResponseBuilder rb{ctx, 3};
276 rb.Push(res.Code());
277 rb.Push<s32>(res.ValueOr(-1));
278 }
279
280 void SetHostName(HLERequestContext& ctx) {
281 const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer());
282 const Result res = SetHostNameImpl(hostname);
283 IPC::ResponseBuilder rb{ctx, 2};
284 rb.Push(res);
285 }
286
287 void SetVerifyOption(HLERequestContext& ctx) {
288 IPC::RequestParser rp{ctx};
289 const u32 option = rp.Pop<u32>();
290 const Result res = SetVerifyOptionImpl(option);
291 IPC::ResponseBuilder rb{ctx, 2};
292 rb.Push(res);
293 }
294
295 void SetIoMode(HLERequestContext& ctx) {
296 IPC::RequestParser rp{ctx};
297 const u32 mode = rp.Pop<u32>();
298 const Result res = SetIoModeImpl(mode);
299 IPC::ResponseBuilder rb{ctx, 2};
300 rb.Push(res);
301 }
302
303 void DoHandshake(HLERequestContext& ctx) {
304 const Result res = DoHandshakeImpl();
305 IPC::ResponseBuilder rb{ctx, 2};
306 rb.Push(res);
307 }
308
309 void DoHandshakeGetServerCert(HLERequestContext& ctx) {
310 struct OutputParameters {
311 u32 certs_size;
312 u32 certs_count;
313 };
314 static_assert(sizeof(OutputParameters) == 0x8);
315
316 const Result res = DoHandshakeImpl();
317 OutputParameters out{};
318 if (res == ResultSuccess) {
319 auto certs = backend->GetServerCerts();
320 if (certs.Succeeded()) {
321 const std::vector<u8> certs_buf = SerializeServerCerts(*certs);
322 ctx.WriteBuffer(certs_buf);
323 out.certs_count = static_cast<u32>(certs->size());
324 out.certs_size = static_cast<u32>(certs_buf.size());
325 }
326 }
327 IPC::ResponseBuilder rb{ctx, 4};
328 rb.Push(res);
329 rb.PushRaw(out);
330 }
331
332 void Read(HLERequestContext& ctx) {
333 const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize());
334 IPC::ResponseBuilder rb{ctx, 3};
335 rb.Push(res.Code());
336 if (res.Succeeded()) {
337 rb.Push(static_cast<u32>(res->size()));
338 ctx.WriteBuffer(*res);
339 } else {
340 rb.Push(static_cast<u32>(0));
341 }
342 }
343
344 void Write(HLERequestContext& ctx) {
345 const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer());
346 IPC::ResponseBuilder rb{ctx, 3};
347 rb.Push(res.Code());
348 rb.Push(static_cast<u32>(res.ValueOr(0)));
349 }
350
351 void Pending(HLERequestContext& ctx) {
352 const ResultVal<s32> res = PendingImpl();
353 IPC::ResponseBuilder rb{ctx, 3};
354 rb.Push(res.Code());
355 rb.Push<s32>(res.ValueOr(0));
356 }
357
358 void SetSessionCacheMode(HLERequestContext& ctx) {
359 IPC::RequestParser rp{ctx};
360 const u32 mode = rp.Pop<u32>();
361 const Result res = SetSessionCacheModeImpl(mode);
362 IPC::ResponseBuilder rb{ctx, 2};
363 rb.Push(res);
364 }
365
366 void SetOption(HLERequestContext& ctx) {
367 struct Parameters {
368 OptionType option;
369 s32 value;
370 };
371 static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size");
372
373 IPC::RequestParser rp{ctx};
374 const auto parameters = rp.PopRaw<Parameters>();
375
376 switch (parameters.option) {
377 case OptionType::DoNotCloseSocket:
378 do_not_close_socket = static_cast<bool>(parameters.value);
379 break;
380 case OptionType::GetServerCertChain:
381 get_server_cert_chain = static_cast<bool>(parameters.value);
382 break;
383 default:
384 LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option,
385 parameters.value);
386 }
387
388 IPC::ResponseBuilder rb{ctx, 2};
389 rb.Push(ResultSuccess);
390 }
87}; 391};
88 392
89class ISslContext final : public ServiceFramework<ISslContext> { 393class ISslContext final : public ServiceFramework<ISslContext> {
90public: 394public:
91 explicit ISslContext(Core::System& system_, SslVersion version) 395 explicit ISslContext(Core::System& system_, SslVersion version)
92 : ServiceFramework{system_, "ISslContext"}, ssl_version{version} { 396 : ServiceFramework{system_, "ISslContext"}, ssl_version{version},
397 shared_data{std::make_shared<SslContextSharedData>()} {
93 static const FunctionInfo functions[] = { 398 static const FunctionInfo functions[] = {
94 {0, &ISslContext::SetOption, "SetOption"}, 399 {0, &ISslContext::SetOption, "SetOption"},
95 {1, nullptr, "GetOption"}, 400 {1, nullptr, "GetOption"},
96 {2, &ISslContext::CreateConnection, "CreateConnection"}, 401 {2, &ISslContext::CreateConnection, "CreateConnection"},
97 {3, nullptr, "GetConnectionCount"}, 402 {3, &ISslContext::GetConnectionCount, "GetConnectionCount"},
98 {4, &ISslContext::ImportServerPki, "ImportServerPki"}, 403 {4, &ISslContext::ImportServerPki, "ImportServerPki"},
99 {5, &ISslContext::ImportClientPki, "ImportClientPki"}, 404 {5, &ISslContext::ImportClientPki, "ImportClientPki"},
100 {6, nullptr, "RemoveServerPki"}, 405 {6, nullptr, "RemoveServerPki"},
@@ -111,6 +416,7 @@ public:
111 416
112private: 417private:
113 SslVersion ssl_version; 418 SslVersion ssl_version;
419 std::shared_ptr<SslContextSharedData> shared_data;
114 420
115 void SetOption(HLERequestContext& ctx) { 421 void SetOption(HLERequestContext& ctx) {
116 struct Parameters { 422 struct Parameters {
@@ -130,11 +436,24 @@ private:
130 } 436 }
131 437
132 void CreateConnection(HLERequestContext& ctx) { 438 void CreateConnection(HLERequestContext& ctx) {
133 LOG_WARNING(Service_SSL, "(STUBBED) called"); 439 LOG_WARNING(Service_SSL, "called");
440
441 auto backend_res = CreateSSLConnectionBackend();
134 442
135 IPC::ResponseBuilder rb{ctx, 2, 0, 1}; 443 IPC::ResponseBuilder rb{ctx, 2, 0, 1};
444 rb.Push(backend_res.Code());
445 if (backend_res.Succeeded()) {
446 rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data,
447 std::move(*backend_res));
448 }
449 }
450
451 void GetConnectionCount(HLERequestContext& ctx) {
452 LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count);
453
454 IPC::ResponseBuilder rb{ctx, 3};
136 rb.Push(ResultSuccess); 455 rb.Push(ResultSuccess);
137 rb.PushIpcInterface<ISslConnection>(system, ssl_version); 456 rb.Push(shared_data->connection_count);
138 } 457 }
139 458
140 void ImportServerPki(HLERequestContext& ctx) { 459 void ImportServerPki(HLERequestContext& ctx) {
diff --git a/src/core/hle/service/ssl/ssl_backend.h b/src/core/hle/service/ssl/ssl_backend.h
new file mode 100644
index 000000000..25c16bcc1
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend.h
@@ -0,0 +1,45 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#pragma once
5
6#include "core/hle/result.h"
7
8#include "common/common_types.h"
9
10#include <memory>
11#include <span>
12#include <string>
13#include <vector>
14
15namespace Network {
16class SocketBase;
17}
18
19namespace Service::SSL {
20
21constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103};
22constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106};
23constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205};
24constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up
25
26// ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake,
27// with no way in the latter case to distinguish whether the client should poll
28// for read or write. The one official client I've seen handles this by always
29// polling for read (with a timeout).
30constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204};
31
32class SSLConnectionBackend {
33public:
34 virtual ~SSLConnectionBackend() {}
35 virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0;
36 virtual Result SetHostName(const std::string& hostname) = 0;
37 virtual Result DoHandshake() = 0;
38 virtual ResultVal<size_t> Read(std::span<u8> data) = 0;
39 virtual ResultVal<size_t> Write(std::span<const u8> data) = 0;
40 virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0;
41};
42
43ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend();
44
45} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_none.cpp b/src/core/hle/service/ssl/ssl_backend_none.cpp
new file mode 100644
index 000000000..f2f0ef706
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_none.cpp
@@ -0,0 +1,16 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#include "core/hle/service/ssl/ssl_backend.h"
5
6#include "common/logging/log.h"
7
8namespace Service::SSL {
9
10ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
11 LOG_ERROR(Service_SSL,
12 "Can't create SSL connection because no SSL backend is available on this platform");
13 return ResultInternalError;
14}
15
16} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_openssl.cpp b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
new file mode 100644
index 000000000..f69674f77
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
@@ -0,0 +1,351 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#include "core/hle/service/ssl/ssl_backend.h"
5#include "core/internal_network/network.h"
6#include "core/internal_network/sockets.h"
7
8#include "common/fs/file.h"
9#include "common/hex_util.h"
10#include "common/string_util.h"
11
12#include <mutex>
13
14#include <openssl/bio.h>
15#include <openssl/err.h>
16#include <openssl/ssl.h>
17#include <openssl/x509.h>
18
19using namespace Common::FS;
20
21namespace Service::SSL {
22
23// Import OpenSSL's `SSL` type into the namespace. This is needed because the
24// namespace is also named `SSL`.
25using ::SSL;
26
27namespace {
28
29std::once_flag one_time_init_flag;
30bool one_time_init_success = false;
31
32SSL_CTX* ssl_ctx;
33IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment
34BIO_METHOD* bio_meth;
35
36Result CheckOpenSSLErrors();
37void OneTimeInit();
38void OneTimeInitLogFile();
39bool OneTimeInitBIO();
40
41} // namespace
42
43class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend {
44public:
45 Result Init() {
46 std::call_once(one_time_init_flag, OneTimeInit);
47
48 if (!one_time_init_success) {
49 LOG_ERROR(Service_SSL,
50 "Can't create SSL connection because OpenSSL one-time initialization failed");
51 return ResultInternalError;
52 }
53
54 ssl = SSL_new(ssl_ctx);
55 if (!ssl) {
56 LOG_ERROR(Service_SSL, "SSL_new failed");
57 return CheckOpenSSLErrors();
58 }
59
60 SSL_set_connect_state(ssl);
61
62 bio = BIO_new(bio_meth);
63 if (!bio) {
64 LOG_ERROR(Service_SSL, "BIO_new failed");
65 return CheckOpenSSLErrors();
66 }
67
68 BIO_set_data(bio, this);
69 BIO_set_init(bio, 1);
70 SSL_set_bio(ssl, bio, bio);
71
72 return ResultSuccess;
73 }
74
75 void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override {
76 socket = std::move(socket_in);
77 }
78
79 Result SetHostName(const std::string& hostname) override {
80 if (!SSL_set1_host(ssl, hostname.c_str())) { // hostname for verification
81 LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname);
82 return CheckOpenSSLErrors();
83 }
84 if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) { // hostname for SNI
85 LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname);
86 return CheckOpenSSLErrors();
87 }
88 return ResultSuccess;
89 }
90
91 Result DoHandshake() override {
92 SSL_set_verify_result(ssl, X509_V_OK);
93 const int ret = SSL_do_handshake(ssl);
94 const long verify_result = SSL_get_verify_result(ssl);
95 if (verify_result != X509_V_OK) {
96 LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}",
97 X509_verify_cert_error_string(verify_result));
98 return CheckOpenSSLErrors();
99 }
100 if (ret <= 0) {
101 const int ssl_err = SSL_get_error(ssl, ret);
102 if (ssl_err == SSL_ERROR_ZERO_RETURN ||
103 (ssl_err == SSL_ERROR_SYSCALL && got_read_eof)) {
104 LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
105 return ResultInternalError;
106 }
107 }
108 return HandleReturn("SSL_do_handshake", 0, ret).Code();
109 }
110
111 ResultVal<size_t> Read(std::span<u8> data) override {
112 size_t actual;
113 const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual);
114 return HandleReturn("SSL_read_ex", actual, ret);
115 }
116
117 ResultVal<size_t> Write(std::span<const u8> data) override {
118 size_t actual;
119 const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual);
120 return HandleReturn("SSL_write_ex", actual, ret);
121 }
122
123 ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) {
124 const int ssl_err = SSL_get_error(ssl, ret);
125 CheckOpenSSLErrors();
126 switch (ssl_err) {
127 case SSL_ERROR_NONE:
128 return actual;
129 case SSL_ERROR_ZERO_RETURN:
130 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what);
131 // DoHandshake special-cases this, but for Read and Write:
132 return size_t(0);
133 case SSL_ERROR_WANT_READ:
134 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what);
135 return ResultWouldBlock;
136 case SSL_ERROR_WANT_WRITE:
137 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what);
138 return ResultWouldBlock;
139 default:
140 if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) {
141 LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what);
142 return size_t(0);
143 }
144 LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err);
145 return ResultInternalError;
146 }
147 }
148
149 ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
150 STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl);
151 if (!chain) {
152 LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr");
153 return ResultInternalError;
154 }
155 std::vector<std::vector<u8>> ret;
156 int count = sk_X509_num(chain);
157 ASSERT(count >= 0);
158 for (int i = 0; i < count; i++) {
159 X509* x509 = sk_X509_value(chain, i);
160 ASSERT_OR_EXECUTE(x509 != nullptr, { continue; });
161 unsigned char* buf = nullptr;
162 int len = i2d_X509(x509, &buf);
163 ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; });
164 ret.emplace_back(buf, buf + len);
165 OPENSSL_free(buf);
166 }
167 return ret;
168 }
169
170 ~SSLConnectionBackendOpenSSL() {
171 // these are null-tolerant:
172 SSL_free(ssl);
173 BIO_free(bio);
174 }
175
176 static void KeyLogCallback(const SSL* ssl, const char* line) {
177 std::string str(line);
178 str.push_back('\n');
179 // Do this in a single WriteString for atomicity if multiple instances
180 // are running on different threads (though that can't currently
181 // happen).
182 if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) {
183 LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE");
184 }
185 LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line);
186 }
187
188 static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) {
189 auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
190 ASSERT_OR_EXECUTE_MSG(
191 self->socket, { return 0; }, "OpenSSL asked to send but we have no socket");
192 BIO_clear_retry_flags(bio);
193 auto [actual, err] = self->socket->Send({reinterpret_cast<const u8*>(buf), len}, 0);
194 switch (err) {
195 case Network::Errno::SUCCESS:
196 *actual_p = actual;
197 return 1;
198 case Network::Errno::AGAIN:
199 BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY);
200 return 0;
201 default:
202 LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
203 return -1;
204 }
205 }
206
207 static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) {
208 auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
209 ASSERT_OR_EXECUTE_MSG(
210 self->socket, { return 0; }, "OpenSSL asked to recv but we have no socket");
211 BIO_clear_retry_flags(bio);
212 auto [actual, err] = self->socket->Recv(0, {reinterpret_cast<u8*>(buf), len});
213 switch (err) {
214 case Network::Errno::SUCCESS:
215 *actual_p = actual;
216 if (actual == 0) {
217 self->got_read_eof = true;
218 }
219 return actual ? 1 : 0;
220 case Network::Errno::AGAIN:
221 BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY);
222 return 0;
223 default:
224 LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
225 return -1;
226 }
227 }
228
229 static long CtrlCallback(BIO* bio, int cmd, long l_arg, void* p_arg) {
230 switch (cmd) {
231 case BIO_CTRL_FLUSH:
232 // Nothing to flush.
233 return 1;
234 case BIO_CTRL_PUSH:
235 case BIO_CTRL_POP:
236#ifdef BIO_CTRL_GET_KTLS_SEND
237 case BIO_CTRL_GET_KTLS_SEND:
238 case BIO_CTRL_GET_KTLS_RECV:
239#endif
240 // We don't support these operations, but don't bother logging them
241 // as they're nothing unusual.
242 return 0;
243 default:
244 LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, l_arg, p_arg);
245 return 0;
246 }
247 }
248
249 SSL* ssl = nullptr;
250 BIO* bio = nullptr;
251 bool got_read_eof = false;
252
253 std::shared_ptr<Network::SocketBase> socket;
254};
255
256ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
257 auto conn = std::make_unique<SSLConnectionBackendOpenSSL>();
258 const Result res = conn->Init();
259 if (res.IsFailure()) {
260 return res;
261 }
262 return conn;
263}
264
265namespace {
266
267Result CheckOpenSSLErrors() {
268 unsigned long rc;
269 const char* file;
270 int line;
271 const char* func;
272 const char* data;
273 int flags;
274#if OPENSSL_VERSION_NUMBER >= 0x30000000L
275 while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags)))
276#else
277 // Can't get function names from OpenSSL on this version, so use mine:
278 func = __func__;
279 while ((rc = ERR_get_error_line_data(&file, &line, &data, &flags)))
280#endif
281 {
282 std::string msg;
283 msg.resize(1024, '\0');
284 ERR_error_string_n(rc, msg.data(), msg.size());
285 msg.resize(strlen(msg.data()), '\0');
286 if (flags & ERR_TXT_STRING) {
287 msg.append(" | ");
288 msg.append(data);
289 }
290 Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error,
291 Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}",
292 msg);
293 }
294 return ResultInternalError;
295}
296
297void OneTimeInit() {
298 ssl_ctx = SSL_CTX_new(TLS_client_method());
299 if (!ssl_ctx) {
300 LOG_ERROR(Service_SSL, "SSL_CTX_new failed");
301 CheckOpenSSLErrors();
302 return;
303 }
304
305 SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr);
306
307 if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) {
308 LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed");
309 CheckOpenSSLErrors();
310 return;
311 }
312
313 OneTimeInitLogFile();
314
315 if (!OneTimeInitBIO()) {
316 return;
317 }
318
319 one_time_init_success = true;
320}
321
322void OneTimeInitLogFile() {
323 const char* logfile = getenv("SSLKEYLOGFILE");
324 if (logfile) {
325 key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile,
326 FileShareFlag::ShareWriteOnly);
327 if (key_log_file.IsOpen()) {
328 SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback);
329 } else {
330 LOG_CRITICAL(Service_SSL,
331 "SSLKEYLOGFILE was set but file could not be opened; not logging keys!");
332 }
333 }
334}
335
336bool OneTimeInitBIO() {
337 bio_meth =
338 BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL");
339 if (!bio_meth ||
340 !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) ||
341 !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) ||
342 !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) {
343 LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD");
344 return false;
345 }
346 return true;
347}
348
349} // namespace
350
351} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
new file mode 100644
index 000000000..a1d6a186e
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
@@ -0,0 +1,543 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#include "core/hle/service/ssl/ssl_backend.h"
5#include "core/internal_network/network.h"
6#include "core/internal_network/sockets.h"
7
8#include "common/error.h"
9#include "common/fs/file.h"
10#include "common/hex_util.h"
11#include "common/string_util.h"
12
13#include <mutex>
14
15namespace {
16
17// These includes are inside the namespace to avoid a conflict on MinGW where
18// the headers define an enum containing Network and Service as enumerators
19// (which clash with the correspondingly named namespaces).
20#define SECURITY_WIN32
21#include <schnlsp.h>
22#include <security.h>
23
24std::once_flag one_time_init_flag;
25bool one_time_init_success = false;
26
27SCHANNEL_CRED schannel_cred{};
28CredHandle cred_handle;
29
30static void OneTimeInit() {
31 schannel_cred.dwVersion = SCHANNEL_CRED_VERSION;
32 schannel_cred.dwFlags =
33 SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols
34 SCH_CRED_AUTO_CRED_VALIDATION | // validate certs
35 SCH_CRED_NO_DEFAULT_CREDS; // don't automatically present a client certificate
36 // ^ I'm assuming that nobody would want to connect Yuzu to a
37 // service that requires some OS-provided corporate client
38 // certificate, and presenting one to some arbitrary server
39 // might be a privacy concern? Who knows, though.
40
41 const SECURITY_STATUS ret =
42 AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND,
43 nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr);
44 if (ret != SEC_E_OK) {
45 // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString.
46 LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}",
47 Common::NativeErrorToString(ret));
48 return;
49 }
50
51 if (getenv("SSLKEYLOGFILE")) {
52 LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting "
53 "keys; not logging keys!");
54 // Not fatal.
55 }
56
57 one_time_init_success = true;
58}
59
60} // namespace
61
62namespace Service::SSL {
63
64class SSLConnectionBackendSchannel final : public SSLConnectionBackend {
65public:
66 Result Init() {
67 std::call_once(one_time_init_flag, OneTimeInit);
68
69 if (!one_time_init_success) {
70 LOG_ERROR(
71 Service_SSL,
72 "Can't create SSL connection because Schannel one-time initialization failed");
73 return ResultInternalError;
74 }
75
76 return ResultSuccess;
77 }
78
79 void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override {
80 socket = std::move(socket_in);
81 }
82
83 Result SetHostName(const std::string& hostname_in) override {
84 hostname = hostname_in;
85 return ResultSuccess;
86 }
87
88 Result DoHandshake() override {
89 while (1) {
90 Result r;
91 switch (handshake_state) {
92 case HandshakeState::Initial:
93 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
94 (r = CallInitializeSecurityContext()) != ResultSuccess) {
95 return r;
96 }
97 // CallInitializeSecurityContext updated `handshake_state`.
98 continue;
99 case HandshakeState::ContinueNeeded:
100 case HandshakeState::IncompleteMessage:
101 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
102 (r = FillCiphertextReadBuf()) != ResultSuccess) {
103 return r;
104 }
105 if (ciphertext_read_buf.empty()) {
106 LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
107 return ResultInternalError;
108 }
109 if ((r = CallInitializeSecurityContext()) != ResultSuccess) {
110 return r;
111 }
112 // CallInitializeSecurityContext updated `handshake_state`.
113 continue;
114 case HandshakeState::DoneAfterFlush:
115 if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) {
116 return r;
117 }
118 handshake_state = HandshakeState::Connected;
119 return ResultSuccess;
120 case HandshakeState::Connected:
121 LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook");
122 return ResultInternalError;
123 case HandshakeState::Error:
124 return ResultInternalError;
125 }
126 }
127 }
128
129 Result FillCiphertextReadBuf() {
130 const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096;
131 read_buf_fill_size = 0;
132 // This unnecessarily zeroes the buffer; oh well.
133 const size_t offset = ciphertext_read_buf.size();
134 ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; });
135 ciphertext_read_buf.resize(offset + fill_size, 0);
136 const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size);
137 const auto [actual, err] = socket->Recv(0, read_span);
138 switch (err) {
139 case Network::Errno::SUCCESS:
140 ASSERT(static_cast<size_t>(actual) <= fill_size);
141 ciphertext_read_buf.resize(offset + actual);
142 return ResultSuccess;
143 case Network::Errno::AGAIN:
144 ciphertext_read_buf.resize(offset);
145 return ResultWouldBlock;
146 default:
147 ciphertext_read_buf.resize(offset);
148 LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
149 return ResultInternalError;
150 }
151 }
152
153 // Returns success if the write buffer has been completely emptied.
154 Result FlushCiphertextWriteBuf() {
155 while (!ciphertext_write_buf.empty()) {
156 const auto [actual, err] = socket->Send(ciphertext_write_buf, 0);
157 switch (err) {
158 case Network::Errno::SUCCESS:
159 ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size());
160 ciphertext_write_buf.erase(ciphertext_write_buf.begin(),
161 ciphertext_write_buf.begin() + actual);
162 break;
163 case Network::Errno::AGAIN:
164 return ResultWouldBlock;
165 default:
166 LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
167 return ResultInternalError;
168 }
169 }
170 return ResultSuccess;
171 }
172
173 Result CallInitializeSecurityContext() {
174 const unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY |
175 ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT |
176 ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM |
177 ISC_REQ_USE_SUPPLIED_CREDS;
178 unsigned long attr;
179 // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel
180 std::array<SecBuffer, 2> input_buffers{{
181 // only used if `initial_call_done`
182 {
183 // [0]
184 .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
185 .BufferType = SECBUFFER_TOKEN,
186 .pvBuffer = ciphertext_read_buf.data(),
187 },
188 {
189 // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is
190 // returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the
191 // whole buffer wasn't used)
192 .cbBuffer = 0,
193 .BufferType = SECBUFFER_EMPTY,
194 .pvBuffer = nullptr,
195 },
196 }};
197 std::array<SecBuffer, 2> output_buffers{{
198 {
199 .cbBuffer = 0,
200 .BufferType = SECBUFFER_TOKEN,
201 .pvBuffer = nullptr,
202 }, // [0]
203 {
204 .cbBuffer = 0,
205 .BufferType = SECBUFFER_ALERT,
206 .pvBuffer = nullptr,
207 }, // [1]
208 }};
209 SecBufferDesc input_desc{
210 .ulVersion = SECBUFFER_VERSION,
211 .cBuffers = static_cast<unsigned long>(input_buffers.size()),
212 .pBuffers = input_buffers.data(),
213 };
214 SecBufferDesc output_desc{
215 .ulVersion = SECBUFFER_VERSION,
216 .cBuffers = static_cast<unsigned long>(output_buffers.size()),
217 .pBuffers = output_buffers.data(),
218 };
219 ASSERT_OR_EXECUTE_MSG(
220 input_buffers[0].cbBuffer == ciphertext_read_buf.size(),
221 { return ResultInternalError; }, "read buffer too large");
222
223 bool initial_call_done = handshake_state != HandshakeState::Initial;
224 if (initial_call_done) {
225 LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext",
226 ciphertext_read_buf.size());
227 }
228
229 const SECURITY_STATUS ret =
230 InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr,
231 // Caller ensured we have set a hostname:
232 const_cast<char*>(hostname.value().c_str()), req,
233 0, // Reserved1
234 0, // TargetDataRep not used with Schannel
235 initial_call_done ? &input_desc : nullptr,
236 0, // Reserved2
237 initial_call_done ? nullptr : &ctxt, &output_desc, &attr,
238 nullptr); // ptsExpiry
239
240 if (output_buffers[0].pvBuffer) {
241 const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),
242 output_buffers[0].cbBuffer);
243 ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end());
244 FreeContextBuffer(output_buffers[0].pvBuffer);
245 }
246
247 if (output_buffers[1].pvBuffer) {
248 const std::span span(static_cast<u8*>(output_buffers[1].pvBuffer),
249 output_buffers[1].cbBuffer);
250 // The documentation doesn't explain what format this data is in.
251 LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(),
252 Common::HexToString(span));
253 }
254
255 switch (ret) {
256 case SEC_I_CONTINUE_NEEDED:
257 LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED");
258 if (input_buffers[1].BufferType == SECBUFFER_EXTRA) {
259 LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer);
260 ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size());
261 ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
262 ciphertext_read_buf.end() - input_buffers[1].cbBuffer);
263 } else {
264 ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY);
265 ciphertext_read_buf.clear();
266 }
267 handshake_state = HandshakeState::ContinueNeeded;
268 return ResultSuccess;
269 case SEC_E_INCOMPLETE_MESSAGE:
270 LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE");
271 ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING);
272 read_buf_fill_size = input_buffers[1].cbBuffer;
273 handshake_state = HandshakeState::IncompleteMessage;
274 return ResultSuccess;
275 case SEC_E_OK:
276 LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK");
277 ciphertext_read_buf.clear();
278 handshake_state = HandshakeState::DoneAfterFlush;
279 return GrabStreamSizes();
280 default:
281 LOG_ERROR(Service_SSL,
282 "InitializeSecurityContext failed (probably certificate/protocol issue): {}",
283 Common::NativeErrorToString(ret));
284 handshake_state = HandshakeState::Error;
285 return ResultInternalError;
286 }
287 }
288
289 Result GrabStreamSizes() {
290 const SECURITY_STATUS ret =
291 QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes);
292 if (ret != SEC_E_OK) {
293 LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}",
294 Common::NativeErrorToString(ret));
295 handshake_state = HandshakeState::Error;
296 return ResultInternalError;
297 }
298 return ResultSuccess;
299 }
300
301 ResultVal<size_t> Read(std::span<u8> data) override {
302 if (handshake_state != HandshakeState::Connected) {
303 LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
304 return ResultInternalError;
305 }
306 if (data.size() == 0 || got_read_eof) {
307 return size_t(0);
308 }
309 while (1) {
310 if (!cleartext_read_buf.empty()) {
311 const size_t read_size = std::min(cleartext_read_buf.size(), data.size());
312 std::memcpy(data.data(), cleartext_read_buf.data(), read_size);
313 cleartext_read_buf.erase(cleartext_read_buf.begin(),
314 cleartext_read_buf.begin() + read_size);
315 return read_size;
316 }
317 if (!ciphertext_read_buf.empty()) {
318 SecBuffer empty{
319 .cbBuffer = 0,
320 .BufferType = SECBUFFER_EMPTY,
321 .pvBuffer = nullptr,
322 };
323 std::array<SecBuffer, 5> buffers{{
324 {
325 .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
326 .BufferType = SECBUFFER_DATA,
327 .pvBuffer = ciphertext_read_buf.data(),
328 },
329 empty,
330 empty,
331 empty,
332 }};
333 ASSERT_OR_EXECUTE_MSG(
334 buffers[0].cbBuffer == ciphertext_read_buf.size(),
335 { return ResultInternalError; }, "read buffer too large");
336 SecBufferDesc desc{
337 .ulVersion = SECBUFFER_VERSION,
338 .cBuffers = static_cast<unsigned long>(buffers.size()),
339 .pBuffers = buffers.data(),
340 };
341 SECURITY_STATUS ret =
342 DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr);
343 switch (ret) {
344 case SEC_E_OK:
345 ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER,
346 { return ResultInternalError; });
347 ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA,
348 { return ResultInternalError; });
349 ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER,
350 { return ResultInternalError; });
351 cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer),
352 static_cast<u8*>(buffers[1].pvBuffer) +
353 buffers[1].cbBuffer);
354 if (buffers[3].BufferType == SECBUFFER_EXTRA) {
355 ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size());
356 ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
357 ciphertext_read_buf.end() - buffers[3].cbBuffer);
358 } else {
359 ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY);
360 ciphertext_read_buf.clear();
361 }
362 continue;
363 case SEC_E_INCOMPLETE_MESSAGE:
364 break;
365 case SEC_I_CONTEXT_EXPIRED:
366 // Server hung up by sending close_notify.
367 got_read_eof = true;
368 return size_t(0);
369 default:
370 LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
371 Common::NativeErrorToString(ret));
372 return ResultInternalError;
373 }
374 }
375 const Result r = FillCiphertextReadBuf();
376 if (r != ResultSuccess) {
377 return r;
378 }
379 if (ciphertext_read_buf.empty()) {
380 got_read_eof = true;
381 return size_t(0);
382 }
383 }
384 }
385
386 ResultVal<size_t> Write(std::span<const u8> data) override {
387 if (handshake_state != HandshakeState::Connected) {
388 LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
389 return ResultInternalError;
390 }
391 if (data.size() == 0) {
392 return size_t(0);
393 }
394 data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage));
395 if (!cleartext_write_buf.empty()) {
396 // Already in the middle of a write. It wouldn't make sense to not
397 // finish sending the entire buffer since TLS has
398 // header/MAC/padding/etc.
399 if (data.size() != cleartext_write_buf.size() ||
400 std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) {
401 LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
402 return ResultInternalError;
403 }
404 return WriteAlreadyEncryptedData();
405 } else {
406 cleartext_write_buf.assign(data.begin(), data.end());
407 }
408
409 std::vector<u8> header_buf(stream_sizes.cbHeader, 0);
410 std::vector<u8> tmp_data_buf = cleartext_write_buf;
411 std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0);
412
413 std::array<SecBuffer, 3> buffers{{
414 {
415 .cbBuffer = stream_sizes.cbHeader,
416 .BufferType = SECBUFFER_STREAM_HEADER,
417 .pvBuffer = header_buf.data(),
418 },
419 {
420 .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()),
421 .BufferType = SECBUFFER_DATA,
422 .pvBuffer = tmp_data_buf.data(),
423 },
424 {
425 .cbBuffer = stream_sizes.cbTrailer,
426 .BufferType = SECBUFFER_STREAM_TRAILER,
427 .pvBuffer = trailer_buf.data(),
428 },
429 }};
430 ASSERT_OR_EXECUTE_MSG(
431 buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; },
432 "temp buffer too large");
433 SecBufferDesc desc{
434 .ulVersion = SECBUFFER_VERSION,
435 .cBuffers = static_cast<unsigned long>(buffers.size()),
436 .pBuffers = buffers.data(),
437 };
438
439 const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0);
440 if (ret != SEC_E_OK) {
441 LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret));
442 return ResultInternalError;
443 }
444 ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(),
445 header_buf.end());
446 ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(),
447 tmp_data_buf.end());
448 ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(),
449 trailer_buf.end());
450 return WriteAlreadyEncryptedData();
451 }
452
453 ResultVal<size_t> WriteAlreadyEncryptedData() {
454 const Result r = FlushCiphertextWriteBuf();
455 if (r != ResultSuccess) {
456 return r;
457 }
458 // write buf is empty
459 const size_t cleartext_bytes_written = cleartext_write_buf.size();
460 cleartext_write_buf.clear();
461 return cleartext_bytes_written;
462 }
463
464 ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
465 PCCERT_CONTEXT returned_cert = nullptr;
466 const SECURITY_STATUS ret =
467 QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
468 if (ret != SEC_E_OK) {
469 LOG_ERROR(Service_SSL,
470 "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}",
471 Common::NativeErrorToString(ret));
472 return ResultInternalError;
473 }
474 PCCERT_CONTEXT some_cert = nullptr;
475 std::vector<std::vector<u8>> certs;
476 while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) {
477 certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded),
478 static_cast<u8*>(some_cert->pbCertEncoded) +
479 some_cert->cbCertEncoded);
480 }
481 std::reverse(certs.begin(),
482 certs.end()); // Windows returns certs in reverse order from what we want
483 CertFreeCertificateContext(returned_cert);
484 return certs;
485 }
486
487 ~SSLConnectionBackendSchannel() {
488 if (handshake_state != HandshakeState::Initial) {
489 DeleteSecurityContext(&ctxt);
490 }
491 }
492
493 enum class HandshakeState {
494 // Haven't called anything yet.
495 Initial,
496 // `SEC_I_CONTINUE_NEEDED` was returned by
497 // `InitializeSecurityContext`; must finish sending data (if any) in
498 // the write buffer, then read at least one byte before calling
499 // `InitializeSecurityContext` again.
500 ContinueNeeded,
501 // `SEC_E_INCOMPLETE_MESSAGE` was returned by
502 // `InitializeSecurityContext`; hopefully the write buffer is empty;
503 // must read at least one byte before calling
504 // `InitializeSecurityContext` again.
505 IncompleteMessage,
506 // `SEC_E_OK` was returned by `InitializeSecurityContext`; must
507 // finish sending data in the write buffer before having `DoHandshake`
508 // report success.
509 DoneAfterFlush,
510 // We finished the above and are now connected. At this point, writing
511 // and reading are separate 'state machines' represented by the
512 // nonemptiness of the ciphertext and cleartext read and write buffers.
513 Connected,
514 // Another error was returned and we shouldn't allow initialization
515 // to continue.
516 Error,
517 } handshake_state = HandshakeState::Initial;
518
519 CtxtHandle ctxt;
520 SecPkgContext_StreamSizes stream_sizes;
521
522 std::shared_ptr<Network::SocketBase> socket;
523 std::optional<std::string> hostname;
524
525 std::vector<u8> ciphertext_read_buf;
526 std::vector<u8> ciphertext_write_buf;
527 std::vector<u8> cleartext_read_buf;
528 std::vector<u8> cleartext_write_buf;
529
530 bool got_read_eof = false;
531 size_t read_buf_fill_size = 0;
532};
533
534ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
535 auto conn = std::make_unique<SSLConnectionBackendSchannel>();
536 const Result res = conn->Init();
537 if (res.IsFailure()) {
538 return res;
539 }
540 return conn;
541}
542
543} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_securetransport.cpp b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp
new file mode 100644
index 000000000..be40a5aeb
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp
@@ -0,0 +1,219 @@
1// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
2// SPDX-License-Identifier: GPL-2.0-or-later
3
4#include "core/hle/service/ssl/ssl_backend.h"
5#include "core/internal_network/network.h"
6#include "core/internal_network/sockets.h"
7
8#include <mutex>
9
10#include <Security/SecureTransport.h>
11
12// SecureTransport has been deprecated in its entirety in favor of
13// Network.framework, but that does not allow layering TLS on top of an
14// arbitrary socket.
15#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
16
17namespace {
18
19template <typename T>
20struct CFReleaser {
21 T ptr;
22
23 YUZU_NON_COPYABLE(CFReleaser);
24 constexpr CFReleaser() : ptr(nullptr) {}
25 constexpr CFReleaser(T ptr) : ptr(ptr) {}
26 constexpr operator T() {
27 return ptr;
28 }
29 ~CFReleaser() {
30 if (ptr) {
31 CFRelease(ptr);
32 }
33 }
34};
35
36std::string CFStringToString(CFStringRef cfstr) {
37 CFReleaser<CFDataRef> cfdata(
38 CFStringCreateExternalRepresentation(nullptr, cfstr, kCFStringEncodingUTF8, 0));
39 ASSERT_OR_EXECUTE(cfdata, { return "???"; });
40 return std::string(reinterpret_cast<const char*>(CFDataGetBytePtr(cfdata)),
41 CFDataGetLength(cfdata));
42}
43
44std::string OSStatusToString(OSStatus status) {
45 CFReleaser<CFStringRef> cfstr(SecCopyErrorMessageString(status, nullptr));
46 if (!cfstr) {
47 return "[unknown error]";
48 }
49 return CFStringToString(cfstr);
50}
51
52} // namespace
53
54namespace Service::SSL {
55
56class SSLConnectionBackendSecureTransport final : public SSLConnectionBackend {
57public:
58 Result Init() {
59 static std::once_flag once_flag;
60 std::call_once(once_flag, []() {
61 if (getenv("SSLKEYLOGFILE")) {
62 LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but SecureTransport does not "
63 "support exporting keys; not logging keys!");
64 // Not fatal.
65 }
66 });
67
68 context.ptr = SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType);
69 if (!context) {
70 LOG_ERROR(Service_SSL, "SSLCreateContext failed");
71 return ResultInternalError;
72 }
73
74 OSStatus status;
75 if ((status = SSLSetIOFuncs(context, ReadCallback, WriteCallback)) ||
76 (status = SSLSetConnection(context, this))) {
77 LOG_ERROR(Service_SSL, "SSLContext initialization failed: {}",
78 OSStatusToString(status));
79 return ResultInternalError;
80 }
81
82 return ResultSuccess;
83 }
84
85 void SetSocket(std::shared_ptr<Network::SocketBase> in_socket) override {
86 socket = std::move(in_socket);
87 }
88
89 Result SetHostName(const std::string& hostname) override {
90 OSStatus status = SSLSetPeerDomainName(context, hostname.c_str(), hostname.size());
91 if (status) {
92 LOG_ERROR(Service_SSL, "SSLSetPeerDomainName failed: {}", OSStatusToString(status));
93 return ResultInternalError;
94 }
95 return ResultSuccess;
96 }
97
98 Result DoHandshake() override {
99 OSStatus status = SSLHandshake(context);
100 return HandleReturn("SSLHandshake", 0, status).Code();
101 }
102
103 ResultVal<size_t> Read(std::span<u8> data) override {
104 size_t actual;
105 OSStatus status = SSLRead(context, data.data(), data.size(), &actual);
106 ;
107 return HandleReturn("SSLRead", actual, status);
108 }
109
110 ResultVal<size_t> Write(std::span<const u8> data) override {
111 size_t actual;
112 OSStatus status = SSLWrite(context, data.data(), data.size(), &actual);
113 ;
114 return HandleReturn("SSLWrite", actual, status);
115 }
116
117 ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) {
118 switch (status) {
119 case 0:
120 return actual;
121 case errSSLWouldBlock:
122 return ResultWouldBlock;
123 default: {
124 std::string reason;
125 if (got_read_eof) {
126 reason = "server hung up";
127 } else {
128 reason = OSStatusToString(status);
129 }
130 LOG_ERROR(Service_SSL, "{} failed: {}", what, reason);
131 return ResultInternalError;
132 }
133 }
134 }
135
136 ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
137 CFReleaser<SecTrustRef> trust;
138 OSStatus status = SSLCopyPeerTrust(context, &trust.ptr);
139 if (status) {
140 LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status));
141 return ResultInternalError;
142 }
143 std::vector<std::vector<u8>> ret;
144 for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) {
145 SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i);
146 CFReleaser<CFDataRef> data(SecCertificateCopyData(cert));
147 ASSERT_OR_EXECUTE(data, { return ResultInternalError; });
148 const u8* ptr = CFDataGetBytePtr(data);
149 ret.emplace_back(ptr, ptr + CFDataGetLength(data));
150 }
151 return ret;
152 }
153
154 static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) {
155 return ReadOrWriteCallback(connection, data, dataLength, true);
156 }
157
158 static OSStatus WriteCallback(SSLConnectionRef connection, const void* data,
159 size_t* dataLength) {
160 return ReadOrWriteCallback(connection, const_cast<void*>(data), dataLength, false);
161 }
162
163 static OSStatus ReadOrWriteCallback(SSLConnectionRef connection, void* data, size_t* dataLength,
164 bool is_read) {
165 auto self =
166 static_cast<SSLConnectionBackendSecureTransport*>(const_cast<void*>(connection));
167 ASSERT_OR_EXECUTE_MSG(
168 self->socket, { return 0; }, "SecureTransport asked to {} but we have no socket",
169 is_read ? "read" : "write");
170
171 // SecureTransport callbacks (unlike OpenSSL BIO callbacks) are
172 // expected to read/write the full requested dataLength or return an
173 // error, so we have to add a loop ourselves.
174 size_t requested_len = *dataLength;
175 size_t offset = 0;
176 while (offset < requested_len) {
177 std::span cur(reinterpret_cast<u8*>(data) + offset, requested_len - offset);
178 auto [actual, err] = is_read ? self->socket->Recv(0, cur) : self->socket->Send(cur, 0);
179 LOG_CRITICAL(Service_SSL, "op={}, offset={} actual={}/{} err={}", is_read, offset,
180 actual, cur.size(), static_cast<s32>(err));
181 switch (err) {
182 case Network::Errno::SUCCESS:
183 offset += actual;
184 if (actual == 0) {
185 ASSERT(is_read);
186 self->got_read_eof = true;
187 return errSecEndOfData;
188 }
189 break;
190 case Network::Errno::AGAIN:
191 *dataLength = offset;
192 return errSSLWouldBlock;
193 default:
194 LOG_ERROR(Service_SSL, "Socket {} returned Network::Errno {}",
195 is_read ? "recv" : "send", err);
196 return errSecIO;
197 }
198 }
199 ASSERT(offset == requested_len);
200 return 0;
201 }
202
203private:
204 CFReleaser<SSLContextRef> context = nullptr;
205 bool got_read_eof = false;
206
207 std::shared_ptr<Network::SocketBase> socket;
208};
209
210ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
211 auto conn = std::make_unique<SSLConnectionBackendSecureTransport>();
212 const Result res = conn->Init();
213 if (res.IsFailure()) {
214 return res;
215 }
216 return conn;
217}
218
219} // namespace Service::SSL
diff --git a/src/core/internal_network/network.cpp b/src/core/internal_network/network.cpp
index 75ac10a9c..28f89c599 100644
--- a/src/core/internal_network/network.cpp
+++ b/src/core/internal_network/network.cpp
@@ -27,6 +27,7 @@
27 27
28#include "common/assert.h" 28#include "common/assert.h"
29#include "common/common_types.h" 29#include "common/common_types.h"
30#include "common/expected.h"
30#include "common/logging/log.h" 31#include "common/logging/log.h"
31#include "common/settings.h" 32#include "common/settings.h"
32#include "core/internal_network/network.h" 33#include "core/internal_network/network.h"
@@ -97,6 +98,8 @@ bool EnableNonBlock(SOCKET fd, bool enable) {
97 98
98Errno TranslateNativeError(int e) { 99Errno TranslateNativeError(int e) {
99 switch (e) { 100 switch (e) {
101 case 0:
102 return Errno::SUCCESS;
100 case WSAEBADF: 103 case WSAEBADF:
101 return Errno::BADF; 104 return Errno::BADF;
102 case WSAEINVAL: 105 case WSAEINVAL:
@@ -121,6 +124,8 @@ Errno TranslateNativeError(int e) {
121 return Errno::MSGSIZE; 124 return Errno::MSGSIZE;
122 case WSAETIMEDOUT: 125 case WSAETIMEDOUT:
123 return Errno::TIMEDOUT; 126 return Errno::TIMEDOUT;
127 case WSAEINPROGRESS:
128 return Errno::INPROGRESS;
124 default: 129 default:
125 UNIMPLEMENTED_MSG("Unimplemented errno={}", e); 130 UNIMPLEMENTED_MSG("Unimplemented errno={}", e);
126 return Errno::OTHER; 131 return Errno::OTHER;
@@ -195,6 +200,8 @@ bool EnableNonBlock(int fd, bool enable) {
195 200
196Errno TranslateNativeError(int e) { 201Errno TranslateNativeError(int e) {
197 switch (e) { 202 switch (e) {
203 case 0:
204 return Errno::SUCCESS;
198 case EBADF: 205 case EBADF:
199 return Errno::BADF; 206 return Errno::BADF;
200 case EINVAL: 207 case EINVAL:
@@ -219,8 +226,10 @@ Errno TranslateNativeError(int e) {
219 return Errno::MSGSIZE; 226 return Errno::MSGSIZE;
220 case ETIMEDOUT: 227 case ETIMEDOUT:
221 return Errno::TIMEDOUT; 228 return Errno::TIMEDOUT;
229 case EINPROGRESS:
230 return Errno::INPROGRESS;
222 default: 231 default:
223 UNIMPLEMENTED_MSG("Unimplemented errno={}", e); 232 UNIMPLEMENTED_MSG("Unimplemented errno={} ({})", e, strerror(e));
224 return Errno::OTHER; 233 return Errno::OTHER;
225 } 234 }
226} 235}
@@ -234,15 +243,84 @@ Errno GetAndLogLastError() {
234 int e = errno; 243 int e = errno;
235#endif 244#endif
236 const Errno err = TranslateNativeError(e); 245 const Errno err = TranslateNativeError(e);
237 if (err == Errno::AGAIN || err == Errno::TIMEDOUT) { 246 if (err == Errno::AGAIN || err == Errno::TIMEDOUT || err == Errno::INPROGRESS) {
247 // These happen during normal operation, so only log them at debug level.
248 LOG_DEBUG(Network, "Socket operation error: {}", Common::NativeErrorToString(e));
238 return err; 249 return err;
239 } 250 }
240 LOG_ERROR(Network, "Socket operation error: {}", Common::NativeErrorToString(e)); 251 LOG_ERROR(Network, "Socket operation error: {}", Common::NativeErrorToString(e));
241 return err; 252 return err;
242} 253}
243 254
244int TranslateDomain(Domain domain) { 255GetAddrInfoError TranslateGetAddrInfoErrorFromNative(int gai_err) {
256 switch (gai_err) {
257 case 0:
258 return GetAddrInfoError::SUCCESS;
259#ifdef EAI_ADDRFAMILY
260 case EAI_ADDRFAMILY:
261 return GetAddrInfoError::ADDRFAMILY;
262#endif
263 case EAI_AGAIN:
264 return GetAddrInfoError::AGAIN;
265 case EAI_BADFLAGS:
266 return GetAddrInfoError::BADFLAGS;
267 case EAI_FAIL:
268 return GetAddrInfoError::FAIL;
269 case EAI_FAMILY:
270 return GetAddrInfoError::FAMILY;
271 case EAI_MEMORY:
272 return GetAddrInfoError::MEMORY;
273 case EAI_NONAME:
274 return GetAddrInfoError::NONAME;
275 case EAI_SERVICE:
276 return GetAddrInfoError::SERVICE;
277 case EAI_SOCKTYPE:
278 return GetAddrInfoError::SOCKTYPE;
279 // These codes may not be defined on all systems:
280#ifdef EAI_SYSTEM
281 case EAI_SYSTEM:
282 return GetAddrInfoError::SYSTEM;
283#endif
284#ifdef EAI_BADHINTS
285 case EAI_BADHINTS:
286 return GetAddrInfoError::BADHINTS;
287#endif
288#ifdef EAI_PROTOCOL
289 case EAI_PROTOCOL:
290 return GetAddrInfoError::PROTOCOL;
291#endif
292#ifdef EAI_OVERFLOW
293 case EAI_OVERFLOW:
294 return GetAddrInfoError::OVERFLOW_;
295#endif
296 default:
297#ifdef EAI_NODATA
298 // This can't be a case statement because it would create a duplicate
299 // case on Windows where EAI_NODATA is an alias for EAI_NONAME.
300 if (gai_err == EAI_NODATA) {
301 return GetAddrInfoError::NODATA;
302 }
303#endif
304 return GetAddrInfoError::OTHER;
305 }
306}
307
308Domain TranslateDomainFromNative(int domain) {
309 switch (domain) {
310 case 0:
311 return Domain::Unspecified;
312 case AF_INET:
313 return Domain::INET;
314 default:
315 UNIMPLEMENTED_MSG("Unhandled domain={}", domain);
316 return Domain::INET;
317 }
318}
319
320int TranslateDomainToNative(Domain domain) {
245 switch (domain) { 321 switch (domain) {
322 case Domain::Unspecified:
323 return 0;
246 case Domain::INET: 324 case Domain::INET:
247 return AF_INET; 325 return AF_INET;
248 default: 326 default:
@@ -251,20 +329,58 @@ int TranslateDomain(Domain domain) {
251 } 329 }
252} 330}
253 331
254int TranslateType(Type type) { 332Type TranslateTypeFromNative(int type) {
333 switch (type) {
334 case 0:
335 return Type::Unspecified;
336 case SOCK_STREAM:
337 return Type::STREAM;
338 case SOCK_DGRAM:
339 return Type::DGRAM;
340 case SOCK_RAW:
341 return Type::RAW;
342 case SOCK_SEQPACKET:
343 return Type::SEQPACKET;
344 default:
345 UNIMPLEMENTED_MSG("Unimplemented type={}", type);
346 return Type::STREAM;
347 }
348}
349
350int TranslateTypeToNative(Type type) {
255 switch (type) { 351 switch (type) {
352 case Type::Unspecified:
353 return 0;
256 case Type::STREAM: 354 case Type::STREAM:
257 return SOCK_STREAM; 355 return SOCK_STREAM;
258 case Type::DGRAM: 356 case Type::DGRAM:
259 return SOCK_DGRAM; 357 return SOCK_DGRAM;
358 case Type::RAW:
359 return SOCK_RAW;
260 default: 360 default:
261 UNIMPLEMENTED_MSG("Unimplemented type={}", type); 361 UNIMPLEMENTED_MSG("Unimplemented type={}", type);
262 return 0; 362 return 0;
263 } 363 }
264} 364}
265 365
266int TranslateProtocol(Protocol protocol) { 366Protocol TranslateProtocolFromNative(int protocol) {
367 switch (protocol) {
368 case 0:
369 return Protocol::Unspecified;
370 case IPPROTO_TCP:
371 return Protocol::TCP;
372 case IPPROTO_UDP:
373 return Protocol::UDP;
374 default:
375 UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
376 return Protocol::Unspecified;
377 }
378}
379
380int TranslateProtocolToNative(Protocol protocol) {
267 switch (protocol) { 381 switch (protocol) {
382 case Protocol::Unspecified:
383 return 0;
268 case Protocol::TCP: 384 case Protocol::TCP:
269 return IPPROTO_TCP; 385 return IPPROTO_TCP;
270 case Protocol::UDP: 386 case Protocol::UDP:
@@ -275,21 +391,10 @@ int TranslateProtocol(Protocol protocol) {
275 } 391 }
276} 392}
277 393
278SockAddrIn TranslateToSockAddrIn(sockaddr input_) { 394SockAddrIn TranslateToSockAddrIn(sockaddr_in input, size_t input_len) {
279 sockaddr_in input;
280 std::memcpy(&input, &input_, sizeof(input));
281
282 SockAddrIn result; 395 SockAddrIn result;
283 396
284 switch (input.sin_family) { 397 result.family = TranslateDomainFromNative(input.sin_family);
285 case AF_INET:
286 result.family = Domain::INET;
287 break;
288 default:
289 UNIMPLEMENTED_MSG("Unhandled sockaddr family={}", input.sin_family);
290 result.family = Domain::INET;
291 break;
292 }
293 398
294 result.portno = ntohs(input.sin_port); 399 result.portno = ntohs(input.sin_port);
295 400
@@ -301,22 +406,33 @@ SockAddrIn TranslateToSockAddrIn(sockaddr input_) {
301short TranslatePollEvents(PollEvents events) { 406short TranslatePollEvents(PollEvents events) {
302 short result = 0; 407 short result = 0;
303 408
304 if (True(events & PollEvents::In)) { 409 const auto translate = [&result, &events](PollEvents guest, short host) {
305 events &= ~PollEvents::In; 410 if (True(events & guest)) {
306 result |= POLLIN; 411 events &= ~guest;
307 } 412 result |= host;
308 if (True(events & PollEvents::Pri)) { 413 }
309 events &= ~PollEvents::Pri; 414 };
415
416 translate(PollEvents::In, POLLIN);
417 translate(PollEvents::Pri, POLLPRI);
418 translate(PollEvents::Out, POLLOUT);
419 translate(PollEvents::Err, POLLERR);
420 translate(PollEvents::Hup, POLLHUP);
421 translate(PollEvents::Nval, POLLNVAL);
422 translate(PollEvents::RdNorm, POLLRDNORM);
423 translate(PollEvents::RdBand, POLLRDBAND);
424 translate(PollEvents::WrBand, POLLWRBAND);
425
310#ifdef _WIN32 426#ifdef _WIN32
311 LOG_WARNING(Service, "Winsock doesn't support POLLPRI"); 427 short allowed_events = POLLRDBAND | POLLRDNORM | POLLWRNORM;
312#else 428 // Unlike poll on other OSes, WSAPoll will complain if any other flags are set on input.
313 result |= POLLPRI; 429 if (result & ~allowed_events) {
430 LOG_DEBUG(Network,
431 "Removing WSAPoll input events 0x{:x} because Windows doesn't support them",
432 result & ~allowed_events);
433 }
434 result &= allowed_events;
314#endif 435#endif
315 }
316 if (True(events & PollEvents::Out)) {
317 events &= ~PollEvents::Out;
318 result |= POLLOUT;
319 }
320 436
321 UNIMPLEMENTED_IF_MSG((u16)events != 0, "Unhandled guest events=0x{:x}", (u16)events); 437 UNIMPLEMENTED_IF_MSG((u16)events != 0, "Unhandled guest events=0x{:x}", (u16)events);
322 438
@@ -337,6 +453,10 @@ PollEvents TranslatePollRevents(short revents) {
337 translate(POLLOUT, PollEvents::Out); 453 translate(POLLOUT, PollEvents::Out);
338 translate(POLLERR, PollEvents::Err); 454 translate(POLLERR, PollEvents::Err);
339 translate(POLLHUP, PollEvents::Hup); 455 translate(POLLHUP, PollEvents::Hup);
456 translate(POLLNVAL, PollEvents::Nval);
457 translate(POLLRDNORM, PollEvents::RdNorm);
458 translate(POLLRDBAND, PollEvents::RdBand);
459 translate(POLLWRBAND, PollEvents::WrBand);
340 460
341 UNIMPLEMENTED_IF_MSG(revents != 0, "Unhandled host revents=0x{:x}", revents); 461 UNIMPLEMENTED_IF_MSG(revents != 0, "Unhandled host revents=0x{:x}", revents);
342 462
@@ -360,12 +480,51 @@ std::optional<IPv4Address> GetHostIPv4Address() {
360 return {}; 480 return {};
361 } 481 }
362 482
363 std::array<char, 16> ip_addr = {};
364 ASSERT(inet_ntop(AF_INET, &network_interface->ip_address, ip_addr.data(), sizeof(ip_addr)) !=
365 nullptr);
366 return TranslateIPv4(network_interface->ip_address); 483 return TranslateIPv4(network_interface->ip_address);
367} 484}
368 485
486std::string IPv4AddressToString(IPv4Address ip_addr) {
487 std::array<char, INET_ADDRSTRLEN> buf = {};
488 ASSERT(inet_ntop(AF_INET, &ip_addr, buf.data(), sizeof(buf)) == buf.data());
489 return std::string(buf.data());
490}
491
492u32 IPv4AddressToInteger(IPv4Address ip_addr) {
493 return static_cast<u32>(ip_addr[0]) << 24 | static_cast<u32>(ip_addr[1]) << 16 |
494 static_cast<u32>(ip_addr[2]) << 8 | static_cast<u32>(ip_addr[3]);
495}
496
497Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddressInfo(
498 const std::string& host, const std::optional<std::string>& service) {
499 addrinfo hints{};
500 hints.ai_family = AF_INET; // Switch only supports IPv4.
501 addrinfo* addrinfo;
502 s32 gai_err = getaddrinfo(host.c_str(), service.has_value() ? service->c_str() : nullptr,
503 &hints, &addrinfo);
504 if (gai_err != 0) {
505 return Common::Unexpected(TranslateGetAddrInfoErrorFromNative(gai_err));
506 }
507 std::vector<AddrInfo> ret;
508 for (auto* current = addrinfo; current; current = current->ai_next) {
509 // We should only get AF_INET results due to the hints value.
510 ASSERT_OR_EXECUTE(addrinfo->ai_family == AF_INET &&
511 addrinfo->ai_addrlen == sizeof(sockaddr_in),
512 continue;);
513
514 AddrInfo& out = ret.emplace_back();
515 out.family = TranslateDomainFromNative(current->ai_family);
516 out.socket_type = TranslateTypeFromNative(current->ai_socktype);
517 out.protocol = TranslateProtocolFromNative(current->ai_protocol);
518 out.addr = TranslateToSockAddrIn(*reinterpret_cast<sockaddr_in*>(current->ai_addr),
519 current->ai_addrlen);
520 if (current->ai_canonname != nullptr) {
521 out.canon_name = current->ai_canonname;
522 }
523 }
524 freeaddrinfo(addrinfo);
525 return ret;
526}
527
369std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) { 528std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) {
370 const size_t num = pollfds.size(); 529 const size_t num = pollfds.size();
371 530
@@ -411,9 +570,21 @@ Socket::Socket(Socket&& rhs) noexcept {
411} 570}
412 571
413template <typename T> 572template <typename T>
414Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) { 573std::pair<T, Errno> Socket::GetSockOpt(SOCKET fd_so, int option) {
574 T value{};
575 socklen_t len = sizeof(value);
576 const int result = getsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<char*>(&value), &len);
577 if (result != SOCKET_ERROR) {
578 ASSERT(len == sizeof(value));
579 return {value, Errno::SUCCESS};
580 }
581 return {value, GetAndLogLastError()};
582}
583
584template <typename T>
585Errno Socket::SetSockOpt(SOCKET fd_so, int option, T value) {
415 const int result = 586 const int result =
416 setsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value)); 587 setsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value));
417 if (result != SOCKET_ERROR) { 588 if (result != SOCKET_ERROR) {
418 return Errno::SUCCESS; 589 return Errno::SUCCESS;
419 } 590 }
@@ -421,7 +592,8 @@ Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) {
421} 592}
422 593
423Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { 594Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) {
424 fd = socket(TranslateDomain(domain), TranslateType(type), TranslateProtocol(protocol)); 595 fd = socket(TranslateDomainToNative(domain), TranslateTypeToNative(type),
596 TranslateProtocolToNative(protocol));
425 if (fd != INVALID_SOCKET) { 597 if (fd != INVALID_SOCKET) {
426 return Errno::SUCCESS; 598 return Errno::SUCCESS;
427 } 599 }
@@ -430,19 +602,17 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) {
430} 602}
431 603
432std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() { 604std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() {
433 sockaddr addr; 605 sockaddr_in addr;
434 socklen_t addrlen = sizeof(addr); 606 socklen_t addrlen = sizeof(addr);
435 const SOCKET new_socket = accept(fd, &addr, &addrlen); 607 const SOCKET new_socket = accept(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen);
436 608
437 if (new_socket == INVALID_SOCKET) { 609 if (new_socket == INVALID_SOCKET) {
438 return {AcceptResult{}, GetAndLogLastError()}; 610 return {AcceptResult{}, GetAndLogLastError()};
439 } 611 }
440 612
441 ASSERT(addrlen == sizeof(sockaddr_in));
442
443 AcceptResult result{ 613 AcceptResult result{
444 .socket = std::make_unique<Socket>(new_socket), 614 .socket = std::make_unique<Socket>(new_socket),
445 .sockaddr_in = TranslateToSockAddrIn(addr), 615 .sockaddr_in = TranslateToSockAddrIn(addr, addrlen),
446 }; 616 };
447 617
448 return {std::move(result), Errno::SUCCESS}; 618 return {std::move(result), Errno::SUCCESS};
@@ -458,25 +628,23 @@ Errno Socket::Connect(SockAddrIn addr_in) {
458} 628}
459 629
460std::pair<SockAddrIn, Errno> Socket::GetPeerName() { 630std::pair<SockAddrIn, Errno> Socket::GetPeerName() {
461 sockaddr addr; 631 sockaddr_in addr;
462 socklen_t addrlen = sizeof(addr); 632 socklen_t addrlen = sizeof(addr);
463 if (getpeername(fd, &addr, &addrlen) == SOCKET_ERROR) { 633 if (getpeername(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) {
464 return {SockAddrIn{}, GetAndLogLastError()}; 634 return {SockAddrIn{}, GetAndLogLastError()};
465 } 635 }
466 636
467 ASSERT(addrlen == sizeof(sockaddr_in)); 637 return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS};
468 return {TranslateToSockAddrIn(addr), Errno::SUCCESS};
469} 638}
470 639
471std::pair<SockAddrIn, Errno> Socket::GetSockName() { 640std::pair<SockAddrIn, Errno> Socket::GetSockName() {
472 sockaddr addr; 641 sockaddr_in addr;
473 socklen_t addrlen = sizeof(addr); 642 socklen_t addrlen = sizeof(addr);
474 if (getsockname(fd, &addr, &addrlen) == SOCKET_ERROR) { 643 if (getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) {
475 return {SockAddrIn{}, GetAndLogLastError()}; 644 return {SockAddrIn{}, GetAndLogLastError()};
476 } 645 }
477 646
478 ASSERT(addrlen == sizeof(sockaddr_in)); 647 return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS};
479 return {TranslateToSockAddrIn(addr), Errno::SUCCESS};
480} 648}
481 649
482Errno Socket::Bind(SockAddrIn addr) { 650Errno Socket::Bind(SockAddrIn addr) {
@@ -519,7 +687,7 @@ Errno Socket::Shutdown(ShutdownHow how) {
519 return GetAndLogLastError(); 687 return GetAndLogLastError();
520} 688}
521 689
522std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) { 690std::pair<s32, Errno> Socket::Recv(int flags, std::span<u8> message) {
523 ASSERT(flags == 0); 691 ASSERT(flags == 0);
524 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); 692 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
525 693
@@ -532,21 +700,20 @@ std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) {
532 return {-1, GetAndLogLastError()}; 700 return {-1, GetAndLogLastError()};
533} 701}
534 702
535std::pair<s32, Errno> Socket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) { 703std::pair<s32, Errno> Socket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) {
536 ASSERT(flags == 0); 704 ASSERT(flags == 0);
537 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); 705 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
538 706
539 sockaddr addr_in{}; 707 sockaddr_in addr_in{};
540 socklen_t addrlen = sizeof(addr_in); 708 socklen_t addrlen = sizeof(addr_in);
541 socklen_t* const p_addrlen = addr ? &addrlen : nullptr; 709 socklen_t* const p_addrlen = addr ? &addrlen : nullptr;
542 sockaddr* const p_addr_in = addr ? &addr_in : nullptr; 710 sockaddr* const p_addr_in = addr ? reinterpret_cast<sockaddr*>(&addr_in) : nullptr;
543 711
544 const auto result = recvfrom(fd, reinterpret_cast<char*>(message.data()), 712 const auto result = recvfrom(fd, reinterpret_cast<char*>(message.data()),
545 static_cast<int>(message.size()), 0, p_addr_in, p_addrlen); 713 static_cast<int>(message.size()), 0, p_addr_in, p_addrlen);
546 if (result != SOCKET_ERROR) { 714 if (result != SOCKET_ERROR) {
547 if (addr) { 715 if (addr) {
548 ASSERT(addrlen == sizeof(addr_in)); 716 *addr = TranslateToSockAddrIn(addr_in, addrlen);
549 *addr = TranslateToSockAddrIn(addr_in);
550 } 717 }
551 return {static_cast<s32>(result), Errno::SUCCESS}; 718 return {static_cast<s32>(result), Errno::SUCCESS};
552 } 719 }
@@ -597,6 +764,11 @@ Errno Socket::Close() {
597 return Errno::SUCCESS; 764 return Errno::SUCCESS;
598} 765}
599 766
767std::pair<Errno, Errno> Socket::GetPendingError() {
768 auto [pending_err, getsockopt_err] = GetSockOpt<int>(fd, SO_ERROR);
769 return {TranslateNativeError(pending_err), getsockopt_err};
770}
771
600Errno Socket::SetLinger(bool enable, u32 linger) { 772Errno Socket::SetLinger(bool enable, u32 linger) {
601 return SetSockOpt(fd, SO_LINGER, MakeLinger(enable, linger)); 773 return SetSockOpt(fd, SO_LINGER, MakeLinger(enable, linger));
602} 774}
diff --git a/src/core/internal_network/network.h b/src/core/internal_network/network.h
index 1e09a007a..badcb8369 100644
--- a/src/core/internal_network/network.h
+++ b/src/core/internal_network/network.h
@@ -5,6 +5,7 @@
5 5
6#include <array> 6#include <array>
7#include <optional> 7#include <optional>
8#include <vector>
8 9
9#include "common/common_funcs.h" 10#include "common/common_funcs.h"
10#include "common/common_types.h" 11#include "common/common_types.h"
@@ -16,6 +17,11 @@
16#include <netinet/in.h> 17#include <netinet/in.h>
17#endif 18#endif
18 19
20namespace Common {
21template <typename T, typename E>
22class Expected;
23}
24
19namespace Network { 25namespace Network {
20 26
21class SocketBase; 27class SocketBase;
@@ -36,6 +42,26 @@ enum class Errno {
36 NETUNREACH, 42 NETUNREACH,
37 TIMEDOUT, 43 TIMEDOUT,
38 MSGSIZE, 44 MSGSIZE,
45 INPROGRESS,
46 OTHER,
47};
48
49enum class GetAddrInfoError {
50 SUCCESS,
51 ADDRFAMILY,
52 AGAIN,
53 BADFLAGS,
54 FAIL,
55 FAMILY,
56 MEMORY,
57 NODATA,
58 NONAME,
59 SERVICE,
60 SOCKTYPE,
61 SYSTEM,
62 BADHINTS,
63 PROTOCOL,
64 OVERFLOW_,
39 OTHER, 65 OTHER,
40}; 66};
41 67
@@ -49,6 +75,9 @@ enum class PollEvents : u16 {
49 Err = 1 << 3, 75 Err = 1 << 3,
50 Hup = 1 << 4, 76 Hup = 1 << 4,
51 Nval = 1 << 5, 77 Nval = 1 << 5,
78 RdNorm = 1 << 6,
79 RdBand = 1 << 7,
80 WrBand = 1 << 8,
52}; 81};
53 82
54DECLARE_ENUM_FLAG_OPERATORS(PollEvents); 83DECLARE_ENUM_FLAG_OPERATORS(PollEvents);
@@ -82,4 +111,11 @@ constexpr IPv4Address TranslateIPv4(in_addr addr) {
82/// @return human ordered IPv4 address (e.g. 192.168.0.1) as an array 111/// @return human ordered IPv4 address (e.g. 192.168.0.1) as an array
83std::optional<IPv4Address> GetHostIPv4Address(); 112std::optional<IPv4Address> GetHostIPv4Address();
84 113
114std::string IPv4AddressToString(IPv4Address ip_addr);
115u32 IPv4AddressToInteger(IPv4Address ip_addr);
116
117// named to avoid name collision with Windows macro
118Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddressInfo(
119 const std::string& host, const std::optional<std::string>& service);
120
85} // namespace Network 121} // namespace Network
diff --git a/src/core/internal_network/socket_proxy.cpp b/src/core/internal_network/socket_proxy.cpp
index 7a77171c2..44e9e3093 100644
--- a/src/core/internal_network/socket_proxy.cpp
+++ b/src/core/internal_network/socket_proxy.cpp
@@ -98,7 +98,7 @@ Errno ProxySocket::Shutdown(ShutdownHow how) {
98 return Errno::SUCCESS; 98 return Errno::SUCCESS;
99} 99}
100 100
101std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) { 101std::pair<s32, Errno> ProxySocket::Recv(int flags, std::span<u8> message) {
102 LOG_WARNING(Network, "(STUBBED) called"); 102 LOG_WARNING(Network, "(STUBBED) called");
103 ASSERT(flags == 0); 103 ASSERT(flags == 0);
104 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); 104 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
@@ -106,7 +106,7 @@ std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) {
106 return {static_cast<s32>(0), Errno::SUCCESS}; 106 return {static_cast<s32>(0), Errno::SUCCESS};
107} 107}
108 108
109std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) { 109std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) {
110 ASSERT(flags == 0); 110 ASSERT(flags == 0);
111 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); 111 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
112 112
@@ -140,8 +140,8 @@ std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message,
140 } 140 }
141} 141}
142 142
143std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& message, 143std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr,
144 SockAddrIn* addr, std::size_t max_length) { 144 std::size_t max_length) {
145 ProxyPacket& packet = received_packets.front(); 145 ProxyPacket& packet = received_packets.front();
146 if (addr) { 146 if (addr) {
147 addr->family = Domain::INET; 147 addr->family = Domain::INET;
@@ -153,10 +153,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes
153 std::size_t read_bytes; 153 std::size_t read_bytes;
154 if (packet.data.size() > max_length) { 154 if (packet.data.size() > max_length) {
155 read_bytes = max_length; 155 read_bytes = max_length;
156 message.clear(); 156 memcpy(message.data(), packet.data.data(), max_length);
157 std::copy(packet.data.begin(), packet.data.begin() + read_bytes,
158 std::back_inserter(message));
159 message.resize(max_length);
160 157
161 if (protocol == Protocol::UDP) { 158 if (protocol == Protocol::UDP) {
162 if (!peek) { 159 if (!peek) {
@@ -171,9 +168,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes
171 } 168 }
172 } else { 169 } else {
173 read_bytes = packet.data.size(); 170 read_bytes = packet.data.size();
174 message.clear(); 171 memcpy(message.data(), packet.data.data(), read_bytes);
175 std::copy(packet.data.begin(), packet.data.end(), std::back_inserter(message));
176 message.resize(max_length);
177 if (!peek) { 172 if (!peek) {
178 received_packets.pop(); 173 received_packets.pop();
179 } 174 }
@@ -293,6 +288,11 @@ Errno ProxySocket::SetNonBlock(bool enable) {
293 return Errno::SUCCESS; 288 return Errno::SUCCESS;
294} 289}
295 290
291std::pair<Errno, Errno> ProxySocket::GetPendingError() {
292 LOG_DEBUG(Network, "(STUBBED) called");
293 return {Errno::SUCCESS, Errno::SUCCESS};
294}
295
296bool ProxySocket::IsOpened() const { 296bool ProxySocket::IsOpened() const {
297 return fd != INVALID_SOCKET; 297 return fd != INVALID_SOCKET;
298} 298}
diff --git a/src/core/internal_network/socket_proxy.h b/src/core/internal_network/socket_proxy.h
index 6e991fa38..e12c413d1 100644
--- a/src/core/internal_network/socket_proxy.h
+++ b/src/core/internal_network/socket_proxy.h
@@ -39,11 +39,11 @@ public:
39 39
40 Errno Shutdown(ShutdownHow how) override; 40 Errno Shutdown(ShutdownHow how) override;
41 41
42 std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override; 42 std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override;
43 43
44 std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override; 44 std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override;
45 45
46 std::pair<s32, Errno> ReceivePacket(int flags, std::vector<u8>& message, SockAddrIn* addr, 46 std::pair<s32, Errno> ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr,
47 std::size_t max_length); 47 std::size_t max_length);
48 48
49 std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override; 49 std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override;
@@ -74,6 +74,8 @@ public:
74 template <typename T> 74 template <typename T>
75 Errno SetSockOpt(SOCKET fd, int option, T value); 75 Errno SetSockOpt(SOCKET fd, int option, T value);
76 76
77 std::pair<Errno, Errno> GetPendingError() override;
78
77 bool IsOpened() const override; 79 bool IsOpened() const override;
78 80
79private: 81private:
diff --git a/src/core/internal_network/sockets.h b/src/core/internal_network/sockets.h
index 11e479e50..46a53ef79 100644
--- a/src/core/internal_network/sockets.h
+++ b/src/core/internal_network/sockets.h
@@ -59,10 +59,9 @@ public:
59 59
60 virtual Errno Shutdown(ShutdownHow how) = 0; 60 virtual Errno Shutdown(ShutdownHow how) = 0;
61 61
62 virtual std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) = 0; 62 virtual std::pair<s32, Errno> Recv(int flags, std::span<u8> message) = 0;
63 63
64 virtual std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, 64 virtual std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) = 0;
65 SockAddrIn* addr) = 0;
66 65
67 virtual std::pair<s32, Errno> Send(std::span<const u8> message, int flags) = 0; 66 virtual std::pair<s32, Errno> Send(std::span<const u8> message, int flags) = 0;
68 67
@@ -87,6 +86,8 @@ public:
87 86
88 virtual Errno SetNonBlock(bool enable) = 0; 87 virtual Errno SetNonBlock(bool enable) = 0;
89 88
89 virtual std::pair<Errno, Errno> GetPendingError() = 0;
90
90 virtual bool IsOpened() const = 0; 91 virtual bool IsOpened() const = 0;
91 92
92 virtual void HandleProxyPacket(const ProxyPacket& packet) = 0; 93 virtual void HandleProxyPacket(const ProxyPacket& packet) = 0;
@@ -126,9 +127,9 @@ public:
126 127
127 Errno Shutdown(ShutdownHow how) override; 128 Errno Shutdown(ShutdownHow how) override;
128 129
129 std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override; 130 std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override;
130 131
131 std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override; 132 std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override;
132 133
133 std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override; 134 std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override;
134 135
@@ -156,6 +157,11 @@ public:
156 template <typename T> 157 template <typename T>
157 Errno SetSockOpt(SOCKET fd, int option, T value); 158 Errno SetSockOpt(SOCKET fd, int option, T value);
158 159
160 std::pair<Errno, Errno> GetPendingError() override;
161
162 template <typename T>
163 std::pair<T, Errno> GetSockOpt(SOCKET fd, int option);
164
159 bool IsOpened() const override; 165 bool IsOpened() const override;
160 166
161 void HandleProxyPacket(const ProxyPacket& packet) override; 167 void HandleProxyPacket(const ProxyPacket& packet) override;